Coverage for rhodent/density_matrices/distributed/pulse.py: 95%
203 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3from typing import Collection, Generator
4import numpy as np
5from numpy.typing import NDArray
7from gpaw.tddft.units import au_to_as
9from .base import BaseDistributor, RhoParameters
10from .time import TimeDistributor, AlltoallvTimeDistributor
11from ..buffer import DensityMatrixBuffer
12from ..readers.gpaw import KohnShamRhoWfsReader
13from ...utils import get_array_filter, safe_fill_larger, fast_pad
14from ...perturbation import create_perturbation, PerturbationLike, PulsePerturbation
15from ...utils.logging import format_times
16from ...utils.memory import MemoryEstimate
17from ...typing import Array1D
class PulseConvolver(BaseDistributor):

    r""" Class performing pulse convolution of density matrices.

    The procedure of the pulse convolution is the following:

    - The entire time series of (real and/or imaginary parts of) density matrices
      is read, in several chunks of indices (n1, n2). Each MPI rank works on
      different chunks.
    - Each chunk is Fourier transformed, divided by the Fourier transform of the
      original perturbation and multiplied by the Fourier transform of the
      new pulse(s).
    - Optionally, derivatives are computed by multiplying the density matrices
      in the frequency domain by factors of :math:`i \omega`.
    - Each chunk is inverse Fourier transformed, and only selected times are kept.

    Additionally, this class can redistribute the resulting convoluted density matrices
    so that each rank holds the entire density matrix, for a few times.

    Parameters
    ----------
    rho_nn_reader
        Object being able to iteratively read density matrices in the time domain.
        Density matrices are split in chunks and distributed among ranks.
    perturbation
        The perturbation which the density matrices read by :attr:`rho_nn_reader`
        are a response to.
    pulses
        List of pulses to perform the convolution with.
    derivative_order_s
        List of derivative orders to compute.
    filter_times
        After convolution keep only these times (or the times closest to them).
        In atomic units.
    result_on_ranks
        List of ranks among which the resulting arrays will be distributed.
        Empty list (default) to distribute among all ranks.
    """

    def __init__(self,
                 rho_nn_reader: TimeDistributor,
                 perturbation: PerturbationLike,
                 pulses: Collection[PerturbationLike],
                 derivative_order_s: list[int] = [0],
                 filter_times: list[float] | Array1D[np.float64] | None = None,
                 result_on_ranks: list[int] = []):
        super().__init__(rho_nn_reader.rho_wfs_reader,
                         rho_nn_reader._parameters,
                         comm=rho_nn_reader.comm)
        self.rho_nn_reader = rho_nn_reader

        # Check if we need to perform upscaling
        wfs_time_t = self.rho_nn_reader.time_t
        dt = self.rho_nn_reader.dt
        self._warn_too_small_dt = False
        # Always defined, so that the attribute can be inspected regardless of
        # which branch below is taken. None means that no spacing was derived.
        self._requested_dt: float | None = None
        self._upscaling = 1

        if filter_times is not None and len(filter_times) >= 2:
            # See if there is a mismatch between wanted and existing time grids.
            # If so, then upscale the data in the Fourier transform step.
            # Only strictly positive spacings are considered: duplicate requested
            # times give a spacing of zero, which would make the upscaling factor
            # infinite (and the conversion to int fail).
            spacing_t = np.diff(np.sort(filter_times))
            spacing_t = spacing_t[spacing_t > 0]
            if spacing_t.size > 0:
                self._requested_dt = float(spacing_t.min())
                upscaling = int(np.round(dt / self._requested_dt))
                if upscaling > 100:
                    # Unreasonably dense grid requested; warn and do not upscale
                    self._warn_too_small_dt = True
                elif upscaling >= 1:
                    self._upscaling = upscaling

        # Construct an upscaled times grid
        self._time_t = wfs_time_t[0] + dt/self.upscaling * np.arange(len(wfs_time_t) * self.upscaling)

        # And filter it
        self._flt_t = get_array_filter(self._time_t, filter_times)

        # Set up pulses and perturbation
        self.pulses = [create_perturbation(pulse) for pulse in pulses]
        if not all(isinstance(pulse, PulsePerturbation) for pulse in self.pulses):
            raise ValueError('Pulse convolution can only be performed with pulses of type PulsePerturbation.')
        self.perturbation = create_perturbation(perturbation)

        # Set up derivatives
        assert all(order in [0, 1, 2] for order in derivative_order_s)
        assert all(np.diff(derivative_order_s) > 0), 'Derivative orders must be strictly increasing'
        # Store a copy, so that the (shared) default argument list is never aliased
        self.derivative_order_s = list(derivative_order_s)

        # Check which ranks the result should be stored on
        if len(result_on_ranks) == 0:
            self._result_on_ranks = set(range(self.comm.size))
        else:
            assert all(isinstance(rank, int) and rank >= 0 and rank < self.comm.size
                       for rank in result_on_ranks), result_on_ranks
            self._result_on_ranks = set(result_on_ranks)

        # Lazily filled by the dist_buffer property
        self._dist_buffer: DensityMatrixBuffer | None = None

    @property
    def dtype(self):
        """ Data type of the convoluted density matrices. """
        return np.float64

    @property
    def xshape(self):
        """ Shape of the extra dimensions; (number of pulses, number of times). """
        return (len(self.pulses), self.nt)

    @property
    def time_t(self) -> NDArray[np.float64]:
        """ Array of times corresponding to convoluted density matrices; in atomic units. """
        return self._time_t[self._flt_t]

    @property
    def nt(self) -> int:
        """ Number of times for which convoluted density matrices are calculated. """
        return len(self.time_t)

    @property
    def nlocalt(self) -> int:
        """ Number of times per rank after redistribution of the result.

        This is the total number of times divided by the number of ranks
        holding the result, rounded up.
        """
        return (self.nt + self.nranks_result - 1) // self.nranks_result

    @property
    def upscaling(self) -> int:
        """ Upscaling factor.

        Data is upscaled in time by this factor during the Fourier transformation step,
        in order to calculate convoluted density matrices at a denser grid of times than
        what is present in the time-dependent wave functions file.
        """
        return self._upscaling

    @property
    def result_on_ranks(self) -> list[int]:
        """ Sorted list of ranks among which the result will be distributed. """
        return sorted(self._result_on_ranks)

    @property
    def nranks_result(self) -> int:
        """ Number of ranks storing part of the result after redistribution. """
        return len(self._result_on_ranks)

    def distributed_work(self) -> list[list[int]]:
        """ Time indices of the result held by each rank after redistribution.

        Returns
        -------
        List with one element per rank in the communicator. Element r is the list
        of time indices assigned to rank r (round-robin among the ranks holding the
        result); empty for ranks that will not have any part of the result.
        """
        # One fresh list per rank, so that no two ranks share the same list object
        timet_r: list[list[int]] = [[] for _ in range(self.comm.size)]
        for r, rank in enumerate(self.result_on_ranks):
            timet_r[rank] = list(range(r, self.nt, self.nranks_result))

        return timet_r

    def my_work(self) -> list[int]:
        """ Time indices of the result held by this rank after redistribution. """
        timet_r = self.distributed_work()
        return timet_r[self.comm.rank]

    def __str__(self) -> str:
        """ Multi-line summary of the convolution and redistribution steps. """
        wfs_nt = len(self.rho_nn_reader.time_t)
        dt = self.rho_nn_reader.dt

        lines = []
        lines.append('Pulse convolver')
        lines.append(f' Performing convolution trick on {self.maxnchunks} ranks')
        lines.append(' Fast Fourier transform')
        lines.append(f' matrix dimensions {self.rho_nn_reader._parameters.nnshape}')
        lines.append(f' grid of {wfs_nt} times')
        lines.append(f' {self.describe_reim()}')
        lines.append(' In frequency domain')
        lines.append(f' calculating {self.describe_derivatives()}')
        if self._warn_too_small_dt:
            # The warning flag is only raised after a requested spacing was found
            assert self._requested_dt is not None
            lines.append('WARNING: the smallest spacing between requested times is ')
            lines.append(f'{self._requested_dt * au_to_as:.2f} as. This is much smaller than the time step ')
            lines.append(f'in the time-dependent wave functions file ({dt * au_to_as:.2f} as). ')
            lines.append('No upscaling will be done.')
        elif self.upscaling == 1:
            lines.append(' not upscaling data')
        else:
            # An upscaling factor other than 1 is only set from a requested spacing
            assert self._requested_dt is not None
            lines.append(f' upscaling by factor {self.upscaling}')
            lines.append(f' requested time step {self._requested_dt * au_to_as:.2f} as')
            lines.append(f' time step in file {dt * au_to_as:.2f} as.')
        lines.append(f' convolution with {len(self.pulses)} pulses')
        lines.append(' Fast inverse Fourier transform')
        lines.append(f' keeping time grid of {self.nt} times')
        lines.append(f' {format_times(self.time_t, units="au")}')
        lines.append('')

        lines.append(' Redistributing into full density matrices')
        lines.append(f' {self.niters} iterations to process all chunks')
        lines.append(f' matrix dimensions {self.rho_nn_reader._parameters.full_nnshape}')
        lines.append(f' result stored on {self.nranks_result} ranks')

        return '\n'.join(lines)

    def get_memory_estimate(self) -> MemoryEstimate:
        """ Estimate of the memory usage of this object and the readers it wraps.

        Returns
        -------
        Memory estimate with three children: the time-dependent wave functions
        reader, the parallel density matrices reader and the buffers owned by
        the pulse convolver itself.
        """
        parameters = self.rho_nn_reader._parameters

        # One array per (real/imaginary part, derivative order) combination
        narrays = (2 if self.yield_re and self.yield_im else 1) * len(self.derivative_order_s)
        temp_shape = parameters.nnshape + (self.maxnchunks, len(self.pulses), self.nlocalt, narrays)
        result_shape = parameters.full_nnshape + (len(self.pulses), self.nlocalt, narrays)

        total_result_size = int(np.prod(parameters.full_nnshape + (len(self.pulses), self.nt))) * narrays

        comment = f'Buffers hold {narrays} arrays ({self.describe_reim()}, {self.describe_derivatives()})'
        own_memory_estimate = MemoryEstimate(comment=comment)
        own_memory_estimate.add_key('Temporary buffer', temp_shape, float,
                                    on_num_ranks=self.nranks_result)
        own_memory_estimate.add_key('Result buffer', result_shape, float,
                                    total_size=total_result_size,
                                    on_num_ranks=self.nranks_result)

        memory_estimate = MemoryEstimate()
        memory_estimate.add_child('Time-dependent wave functions reader',
                                  self.rho_nn_reader.rho_wfs_reader.get_memory_estimate())
        memory_estimate.add_child('Parallel density matrices reader',
                                  self.rho_nn_reader.get_memory_estimate())
        memory_estimate.add_child('Pulse convolver',
                                  own_memory_estimate)

        return memory_estimate

    def _freq_domain_derivative(self,
                                order: int) -> NDArray[np.complex128 | np.float64]:
        r""" Factor which takes the derivative in frequency space.

        The derivative is taken by multiplying by

        .. math::

            (i \omega)^n.

        Parameters
        ----------
        order
            Order :math:`n` of the derivative.

        Returns
        -------
        Array of factors on the (zero-padded) real FFT frequency grid, or
        a length-one array holding 1 when :attr:`order` is zero.
        """
        if order == 0:
            return np.array([1])

        padnt = fast_pad(self.rho_nn_reader.nt)
        dt = self.rho_nn_reader.dt
        omega_w = 2 * np.pi * np.fft.rfftfreq(padnt, dt)

        return (1.0j * omega_w) ** order  # type: ignore

    @property
    def dist_buffer(self) -> DensityMatrixBuffer:
        """ Buffer of density matrices on this rank after redistribution.

        The redistribution is performed on first access and the result is cached.
        """
        if self._dist_buffer is None:
            self._dist_buffer = self.redistribute()
        return self._dist_buffer

    def __iter__(self) -> Generator[DensityMatrixBuffer, None, None]:
        """ Iteratively read density matrices and perform the pulse convolution.

        Each iteration performs the calculation for a different chunk of the
        density matrix. Each rank works on a different set of chunks.
        All ranks always work on the entire grid of times, during all iterations.

        Yields
        ------
        A chunk of the convoluted density matrices, for the requested times.
        """
        wfs_time_t = self.rho_nn_reader.time_t  # Times in wave functions file
        padnt = fast_pad(len(wfs_time_t))  # Pad with zeros

        # Take Fourier transform of pulses
        pulse_pt = [pulse.pulse.strength(wfs_time_t) for pulse in self.pulses]
        pulse_pw = np.fft.rfft(pulse_pt, axis=-1, n=padnt)

        # Create buffer for result
        dm_buffer = DensityMatrixBuffer(self.rho_nn_reader._parameters.nnshape,
                                        (len(self.pulses), self.nt),
                                        np.float64)
        dm_buffer.zero_buffers(real=self.yield_re, imag=self.yield_im, derivative_order_s=self.derivative_order_s)

        for read_buffer in self.rho_nn_reader:
            # Pairs of (time domain input data, output buffers by derivative order)
            x = []
            if self.yield_re:
                x.append((read_buffer._re_buffers[0], dm_buffer._re_buffers))
            if self.yield_im:
                x.append((read_buffer._im_buffers[0], dm_buffer._im_buffers))
            for data_nnt, buffers in x:
                # Take the Fourier transform of the data (Rerho or Imrho) and divide by
                # the Fourier transform of the perturbation
                # The data is padded by zeros, circumventing the periodicity of the Fourier transform
                data_nnw = self.perturbation.normalize_frequency_response(data_nnt, wfs_time_t, padnt, axis=-1)

                # Loop over the desired outputs (and which derivative orders they are)
                for derivative, buffer_nnpt in buffers.items():
                    deriv_w = self._freq_domain_derivative(derivative)
                    for p, pulse_w in enumerate(pulse_pw):
                        # Multiply factor for derivative (power of I*omega)
                        # All timesteps cancel when taking fft->ifft, so do not scale by it
                        _data_nnw = data_nnw * (deriv_w * pulse_w)
                        # Inverse Fourier transform
                        # Optionally, the data is upscaled by padding with even more zeros
                        conv_nnt = np.fft.irfft(_data_nnw, n=padnt * self.upscaling, axis=-1) * self.upscaling
                        buffer_nnpt[..., p, :] = conv_nnt[..., :len(self._time_t)][..., self._flt_t]

            # Yield a copy, since the same buffer is overwritten in the next iteration
            yield dm_buffer.copy()

    def create_out_buffer(self) -> DensityMatrixBuffer:
        """ Create the DensityMatrixBuffer to hold the temporary density matrix after each redistribution """
        parameters = self.rho_nn_reader._parameters
        # Ranks that do not hold any part of the result get a buffer without times
        nlocalt = self.nlocalt if self.comm.rank in self._result_on_ranks else 0
        out_dm = DensityMatrixBuffer(nnshape=parameters.nnshape,
                                     xshape=(self.maxnchunks, len(self.pulses), nlocalt),
                                     dtype=np.float64)
        out_dm.zero_buffers(real=self.yield_re, imag=self.yield_im, derivative_order_s=self.derivative_order_s)

        return out_dm

    def create_result_buffer(self) -> DensityMatrixBuffer:
        """ Create the DensityMatrixBuffer to hold the resulting density matrix """
        parameters = self.rho_nn_reader._parameters
        nnshape = parameters.full_nnshape
        full_dm = DensityMatrixBuffer(nnshape=nnshape,
                                      xshape=(len(self.pulses), len(self.my_work())),
                                      dtype=np.float64)
        full_dm.zero_buffers(real=self.yield_re, imag=self.yield_im, derivative_order_s=self.derivative_order_s)

        return full_dm

    def redistribute(self) -> DensityMatrixBuffer:
        """ Perform the pulse convolution and redistribute the resulting density matrices.

        During the pulse convolution step, the data is distributed such that each rank
        stores the entire time series for one chunk of the density matrices, i.e. indices n1, n2.

        This function then performs a redistribution of the data such that each rank stores full
        density matrices, for certain times.

        If the density matrices are split into more chunks than there are ranks, then the
        chunks are read, convoluted with pulses and distributed in a loop several times until all
        data has been processed.

        Returns
        -------
        Density matrix buffer with x-dimensions (number of pulses, number of local times)
        where the number of local times varies between the ranks.
        """
        local_work = iter(self)
        parameters = self.rho_nn_reader._parameters
        log = self.log
        assert self.rho_nn_reader.rho_wfs_reader.lcao_rho_reader.striden == 0, \
            'n stride must be 0 (index all) for redistribute'

        # Time indices of result on each rank
        timet_r = self.distributed_work()

        out_dm = self.create_out_buffer()
        full_dm = self.create_result_buffer()

        # Sentinel marking that the local work generator has no more chunks
        _exhausted = object()

        # Loop over the chunks of the density matrix
        for chunki, indices_r in enumerate(self.work_loop_by_ranks()):
            # At this point, each rank stores one unique chunk of the density matrix.
            # All ranks have the entire time series of data for their own chunk.
            # If there are more chunks than ranks, then this loop will run
            # for several iterations. If the number of chunks is not divisible by the number of
            # ranks then, during the last iteration, some of the chunks are None (meaning the rank
            # currently has no data).

            # List of chunks that each rank currently stores, where element r of the list
            # contains the chunk that rank r works with. Ranks higher than the length of the list
            # currently store no chunks.
            # The list itself is identical on all ranks.
            chunks_by_rank = [indices[2:] for indices in indices_r if indices is not None]

            ntargets = len(chunks_by_rank)

            if self.comm.rank < ntargets:
                # This rank has data to send. Compute the pulse convolution and store the result
                dm_buffer = next(local_work)
            else:
                # This rank has no data to send.
                # The generator is advanced outside of the assert statement, so that
                # it is advanced also when assertions are disabled (python -O)
                leftover = next(local_work, _exhausted)
                assert leftover is _exhausted
                # Still, we need to create a dummy buffer
                dm_buffer = DensityMatrixBuffer(nnshape=parameters.nnshape,
                                                xshape=(0, 0), dtype=np.float64)
                dm_buffer.zero_buffers(real=self.yield_re, imag=self.yield_im,
                                       derivative_order_s=self.derivative_order_s)

            log.start('alltoallv')

            # Redistribute the data:
            # - dm_buffer stores single chunks of density matrices, for all times and pulses.
            # - out_dm will store several chunks, for a few times
            # source_indices_r describes which slices of dm_buffer should be sent to which rank
            # target_indices_r describes to which positions of the out_dm buffer should be received
            # from which rank
            source_indices_r = [None if len(t) == 0 else (slice(None), t) for t in timet_r]
            target_indices_r = [r if r < ntargets else None for r in range(self.comm.size)]
            dm_buffer.redistribute(out_dm,
                                   comm=self.comm,
                                   source_indices_r=source_indices_r,
                                   target_indices_r=target_indices_r,
                                   log=log)

            if self.comm.rank == 0:
                log(f'Chunk {chunki+1}/{self.niters}: distributed convoluted response in '
                    f'{log.elapsed("alltoallv"):.1f}s', who='Response', flush=True)

            # Copy the redistributed data into the aggregated results buffer
            for array_nnrpt, full_array_nnpt in zip(out_dm._iter_buffers(), full_dm._iter_buffers()):
                for r, nn_indices in enumerate(chunks_by_rank):
                    safe_fill_larger(full_array_nnpt[nn_indices], array_nnrpt[:, :, r])

        # All chunks must have been consumed. As above, advance the generator
        # outside of the assert statement.
        leftover = next(local_work, _exhausted)
        assert leftover is _exhausted

        return full_dm

    @classmethod
    def from_reader(cls,  # type: ignore
                    rho_nn_reader: KohnShamRhoWfsReader,
                    parameters: RhoParameters,
                    perturbation: PerturbationLike,
                    pulses: Collection[PerturbationLike],
                    derivative_order_s: list[int] = [0],
                    filter_times: list[float] | Array1D[np.float64] | None = None,
                    result_on_ranks: list[int] = []) -> PulseConvolver:
        """ Set up the pulse convolver directly from a density matrix reader.

        The reader is wrapped in an :class:`AlltoallvTimeDistributor` constructed
        from :attr:`rho_nn_reader` and :attr:`parameters`; the remaining arguments
        are passed on to the constructor of this class.
        """
        time_distributor = AlltoallvTimeDistributor(rho_nn_reader, parameters)
        pulse_convolver = cls(time_distributor,
                              pulses=pulses,
                              perturbation=perturbation,
                              derivative_order_s=derivative_order_s,
                              filter_times=filter_times,
                              result_on_ranks=result_on_ranks)
        return pulse_convolver