Coverage for rhodent/density_matrices/distributed/frequency.py: 98%
168 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3from typing import Generator
4import numpy as np
6from gpaw.tddft.units import au_to_eV
8from .base import BaseDistributor, RhoParameters
9from .time import TimeDistributor, AlltoallvTimeDistributor
10from ..buffer import DensityMatrixBuffer
11from ..readers.gpaw import KohnShamRhoWfsReader
12from ...utils import get_array_filter, safe_fill_larger, fast_pad
13from ...utils.logging import format_frequencies
14from ...utils.memory import MemoryEstimate
15from ...perturbation import create_perturbation, PerturbationLike
16from ...typing import Array1D
class FourierTransformer(BaseDistributor):

    """ Iteratively take the Fourier transform of density matrices.

    Parameters
    ----------
    rho_nn_reader
        Object that can iteratively read density matrices in the time domain,
        distributed such that different ranks have different chunks of the density
        matrix, but each rank has all times for the same chunk.
    perturbation
        The perturbation which the density matrices are a response to.
    filter_frequencies
        After Fourier transformation keep only these frequencies (or the frequencies
        closest to them). In atomic units.
    frequency_broadening
        Gaussian broadening width in atomic units. Default (0) is no broadening.
    result_on_ranks
        List of ranks among which the resulting arrays will be distributed.
        ``None`` (default) or an empty list to distribute among all ranks.
    """

    def __init__(self,
                 rho_nn_reader: TimeDistributor,
                 perturbation: PerturbationLike,
                 filter_frequencies: list[float] | Array1D[np.float64] | None = None,
                 frequency_broadening: float = 0,
                 result_on_ranks: list[int] | None = None):
        super().__init__(rho_nn_reader.rho_wfs_reader,
                         rho_nn_reader._parameters,
                         comm=rho_nn_reader.comm)
        self.rho_nn_reader = rho_nn_reader
        self.perturbation = create_perturbation(perturbation)
        self.frequency_broadening = frequency_broadening
        # _omega_w reads self.rho_nn_reader, so it must be assigned before this line
        self._flt_w = get_array_filter(self._omega_w, filter_frequencies)

        if result_on_ranks is None or len(result_on_ranks) == 0:
            self._result_on_ranks = set(range(self.comm.size))
        else:
            assert all(isinstance(rank, int) and 0 <= rank < self.comm.size
                       for rank in result_on_ranks), result_on_ranks
            self._result_on_ranks = set(result_on_ranks)

        # Populated lazily by the dist_buffer property
        self._dist_buffer: DensityMatrixBuffer | None = None

    @property
    def dtype(self):
        return np.complex128

    @property
    def xshape(self):
        return (self.nw, )

    @property
    def freq_w(self) -> Array1D[np.float64]:
        """ Angular frequencies kept after filtering """
        return self._omega_w[self.flt_w]  # type: ignore

    @property
    def _omega_w(self) -> Array1D[np.float64]:
        """ Angular frequencies of the real FFT of the zero-padded time series """
        padnt = fast_pad(self.rho_nn_reader.nt)
        dt = self.rho_nn_reader.dt
        omega_w = 2 * np.pi * np.fft.rfftfreq(padnt, dt)

        return omega_w  # type: ignore

    @property
    def nw(self) -> int:
        """ Number of frequencies kept after filtering """
        return len(self.freq_w)

    @property
    def nlocalw(self) -> int:
        """ Largest number of frequencies stored on any result rank, ceil(nw / nranks_result) """
        return (self.nw + self.nranks_result - 1) // self.nranks_result

    @property
    def flt_w(self) -> slice | Array1D[np.bool_]:
        """ Filter selecting the kept frequencies out of the full FFT frequency grid """
        return self._flt_w

    @property
    def result_on_ranks(self) -> list[int]:
        """ Sorted list of ranks among which the result will be distributed """
        return sorted(self._result_on_ranks)

    @property
    def nranks_result(self) -> int:
        """ Number of ranks that the resulting arrays will be distributed among """
        return len(self._result_on_ranks)

    def distributed_work(self) -> list[list[int]]:
        """ Frequency indices assigned to each rank.

        Returns
        -------
        List with one element per rank of the communicator. Element ``rank`` holds the
        indices (into :attr:`freq_w`) of the frequencies stored on that rank. Frequencies
        are dealt out round-robin over :attr:`result_on_ranks`; ranks that do not hold
        results get an empty list.
        """
        # One independent list per rank; ``size * [[]]`` would alias a single list
        freqw_r: list[list[int]] = [[] for _ in range(self.comm.size)]
        for r, rank in enumerate(self.result_on_ranks):
            freqw_r[rank] = list(range(r, self.nw, self.nranks_result))

        return freqw_r

    def my_work(self) -> list[int]:
        """ Frequency indices (into :attr:`freq_w`) stored on this rank """
        freqw_r = self.distributed_work()
        return freqw_r[self.comm.rank]

    def __str__(self) -> str:
        nt = len(self.rho_nn_reader.time_t)
        niters = len(list(self.work_loop_by_ranks()))

        lines = []
        lines.append('Fourier transformer')
        lines.append(f'  Calculating Fourier transform on {self.maxnchunks} ranks')
        lines.append('    Fast Fourier transform')
        lines.append(f'    matrix dimensions {self.rho_nn_reader._parameters.nnshape}')
        lines.append(f'    grid of {nt} times')
        lines.append(f'    {self.describe_reim()}')
        if self.frequency_broadening == 0:
            lines.append('    No frequency broadening')
        else:
            lines.append(f'    Applying frequency broadening of {self.frequency_broadening * au_to_eV:.2f}eV')
        lines.append(f'    keeping frequency grid of {self.nw} frequencies')
        lines.append(f'    {format_frequencies(self.freq_w, units="au")}')
        lines.append('')

        lines.append('  Redistributing into full density matrices')
        lines.append(f'    {niters} iterations to process all chunks')
        lines.append(f'    matrix dimensions {self.rho_nn_reader._parameters.full_nnshape}')
        lines.append(f'    result stored on {self.nranks_result} ranks')

        return '\n'.join(lines)

    def get_memory_estimate(self) -> MemoryEstimate:
        """ Estimate memory of the reader chain and of this transformer's buffers """
        parameters = self.rho_nn_reader._parameters

        narrays = 2 if self.yield_re and self.yield_im else 1
        temp_shape = parameters.nnshape + (self.maxnchunks, self.nlocalw, narrays)
        result_shape = parameters.full_nnshape + (self.nlocalw, narrays)

        total_result_size = int(np.prod(parameters.full_nnshape)) * self.nw * narrays

        comment = f'Buffers hold {narrays} arrays ({self.describe_reim()})'
        own_memory_estimate = MemoryEstimate(comment=comment)
        own_memory_estimate.add_key('Temporary buffer', temp_shape, complex,
                                    on_num_ranks=self.nranks_result)
        own_memory_estimate.add_key('Result buffer', result_shape, complex,
                                    total_size=total_result_size,
                                    on_num_ranks=self.nranks_result)

        memory_estimate = MemoryEstimate()
        memory_estimate.add_child('Time-dependent wave functions reader',
                                  self.rho_nn_reader.rho_wfs_reader.get_memory_estimate())
        memory_estimate.add_child('Parallel density matrices reader',
                                  self.rho_nn_reader.get_memory_estimate())
        memory_estimate.add_child('Fourier transformer',
                                  own_memory_estimate)

        return memory_estimate

    def __iter__(self) -> Generator[DensityMatrixBuffer, None, None]:
        """ Fourier transform each chunk yielded by :attr:`rho_nn_reader`.

        Yields
        ------
        For every chunk read, a copy of a density matrix buffer with the local chunk
        shape and x-dimensions ``(nw, )``, holding the response at :attr:`freq_w`.
        """
        time_t = self.rho_nn_reader.time_t  # Times in wave functions file
        dt = self.rho_nn_reader.dt  # Time step
        padnt = fast_pad(len(time_t))  # Pad with zeros

        # Gaussian time-domain envelope for broadening; identical for every chunk,
        # so compute it once rather than inside the loop
        if self.frequency_broadening == 0:
            envelope_t = None
        else:
            envelope_t = np.exp(-0.5 * self.frequency_broadening ** 2 * time_t**2)

        dm_buffer = DensityMatrixBuffer(self.rho_nn_reader._parameters.nnshape,
                                        (self.nw, ),
                                        np.complex128)
        if self.yield_re:
            dm_buffer.zeros(True, 0)
        if self.yield_im:
            dm_buffer.zeros(False, 0)

        for read_buffer in self.rho_nn_reader:
            for data_nnt, buffer_nnw in zip(read_buffer._iter_buffers(), dm_buffer._iter_buffers()):
                if envelope_t is None:
                    data_nnw = self.perturbation.normalize_frequency_response(data_nnt, time_t, padnt, axis=-1)
                else:
                    data_nnt = self.perturbation.normalize_time_response(data_nnt, time_t, axis=-1)
                    data_nnt[..., :len(time_t)] *= envelope_t
                    data_nnw = np.fft.rfft(data_nnt, n=padnt, axis=-1) * dt
                buffer_nnw[:] = data_nnw[..., self.flt_w].conj()  # Change sign convention

            yield dm_buffer.copy()

    @property
    def dist_buffer(self) -> DensityMatrixBuffer:
        """ Buffer of density matrices on this rank after redistribution """
        if self._dist_buffer is None:
            self._dist_buffer = self.redistribute()
        return self._dist_buffer

    def create_out_buffer(self) -> DensityMatrixBuffer:
        """ Create the DensityMatrixBuffer to hold the temporary density matrix after each redistribution """
        parameters = self.rho_nn_reader._parameters
        # Ranks that do not store results receive nothing (zero local frequencies)
        nlocalw = self.nlocalw if self.comm.rank in self._result_on_ranks else 0
        out_dm = DensityMatrixBuffer(nnshape=parameters.nnshape,
                                     xshape=(self.maxnchunks, nlocalw),
                                     dtype=np.complex128)
        out_dm.zero_buffers(real=self.yield_re, imag=self.yield_im, derivative_order_s=[0])

        return out_dm

    def create_result_buffer(self) -> DensityMatrixBuffer:
        """ Create the DensityMatrixBuffer to hold the resulting density matrix """
        parameters = self.rho_nn_reader._parameters
        nnshape = parameters.full_nnshape
        full_dm = DensityMatrixBuffer(nnshape=nnshape,
                                      xshape=(len(self.my_work()), ),
                                      dtype=np.complex128)
        full_dm.zero_buffers(real=self.yield_re, imag=self.yield_im, derivative_order_s=[0])

        return full_dm

    def redistribute(self) -> DensityMatrixBuffer:
        """ Perform the Fourier transform and redistribute the data

        When the Fourier transform is performed, the data is distributed such that each rank
        stores the entire time/frequency series for one chunk of the density matrices, i.e. indices n1, n2.

        This function then performs a redistribution of the data such that each rank stores full
        density matrices, for certain frequencies.

        If the density matrices are split into more chunks than there are ranks, then the
        chunks are read, Fourier transformed and distributed in a loop several times until all
        data has been processed.

        Returns
        -------
        Density matrix buffer with x-dimensions (number of local frequencies, )
        where the number of local frequencies varies between the ranks.

        Raises
        ------
        AssertionError
            If the n stride of the underlying LCAO reader is not 0.
        """
        local_work = iter(self)
        parameters = self.rho_nn_reader._parameters
        log = self.log
        # Full density matrices can only be assembled if every n index is read
        assert self.rho_nn_reader.rho_wfs_reader.lcao_rho_reader.striden == 0, \
            'n stride must be 0 (index all) for redistribute'

        # Frequency indices of result on each rank
        freqw_r = self.distributed_work()
        niters = len(list(self.work_loop_by_ranks()))

        out_dm = self.create_out_buffer()
        full_dm = self.create_result_buffer()

        _exhausted = object()

        # Loop over the chunks of the density matrix
        for chunki, indices_r in enumerate(self.work_loop_by_ranks()):
            # At this point, each rank stores one unique chunk of the density matrix.
            # All ranks have the entire time series of data for their own chunk.
            # If there are more chunks than ranks, then this loop will run
            # for several iterations. If the number of chunks is not divisible by the number of
            # ranks then, during the last iteration, some of the chunks are None (meaning the rank
            # currently has no data).

            # List of chunks that each rank currently stores, where element r of the list
            # contains the chunk that rank r works with. Ranks higher than the length of the list
            # currently store no chunks.
            # The list itself is identical on all ranks.
            chunks_by_rank = [indices[2:] for indices in indices_r if indices is not None]

            ntargets = len(chunks_by_rank)

            if self.comm.rank < ntargets:
                # This rank has data to send. Compute the Fourier transform and store the result
                dm_buffer = next(local_work)
            else:
                # This rank has no data to send
                assert next(local_work, _exhausted) is _exhausted
                # Still, we need to create a dummy buffer
                dm_buffer = DensityMatrixBuffer(nnshape=parameters.nnshape,
                                                xshape=(0, ), dtype=np.complex128)
                dm_buffer.zero_buffers(real=self.yield_re, imag=self.yield_im,
                                       derivative_order_s=[0])

            log.start('alltoallv')

            # Redistribute the data:
            # - dm_buffer stores single chunks of density matrices, for all frequencies.
            # - out_dm will store several chunks, for a few frequencies.
            # source_indices_r describes which slices of dm_buffer should be sent to which rank
            # target_indices_r describes to which positions of the out_dm buffer should be received
            # from which rank
            source_indices_r = [None if len(w) == 0 else w for w in freqw_r]
            target_indices_r = [r if r < ntargets else None for r in range(self.comm.size)]
            dm_buffer.redistribute(out_dm,
                                   comm=self.comm,
                                   source_indices_r=source_indices_r,
                                   target_indices_r=target_indices_r,
                                   log=log)

            if self.comm.rank == 0:
                log(f'Chunk {chunki+1}/{niters}: distributed frequency response in '
                    f'{log.elapsed("alltoallv"):.1f}s', flush=True, who='Response')

            # Copy each received chunk into its place in the full density matrices
            for array_nnrw, full_array_nnw in zip(out_dm._iter_buffers(), full_dm._iter_buffers()):
                for r, nn_indices in enumerate(chunks_by_rank):
                    safe_fill_larger(full_array_nnw[nn_indices], array_nnrw[:, :, r])

        # Every chunk produced by the Fourier transform must have been consumed
        assert next(local_work, _exhausted) is _exhausted

        return full_dm

    @classmethod
    def from_reader(cls,  # type: ignore
                    rho_nn_reader: KohnShamRhoWfsReader,
                    parameters: RhoParameters,
                    *,
                    perturbation: PerturbationLike,
                    filter_frequencies: list[float] | Array1D[np.float64] | None = None,
                    frequency_broadening: float = 0,
                    result_on_ranks: list[int] | None = None) -> FourierTransformer:
        """ Build a transformer on top of an AlltoallvTimeDistributor for ``rho_nn_reader``.

        Keyword arguments are forwarded to the constructor; see the class docstring.
        """
        time_distributor = AlltoallvTimeDistributor(rho_nn_reader, parameters)
        fourier_transformer = cls(time_distributor,
                                  perturbation=perturbation,
                                  filter_frequencies=filter_frequencies,
                                  frequency_broadening=frequency_broadening,
                                  result_on_ranks=result_on_ranks)
        return fourier_transformer