Coverage for rhodent/density_matrices/readers/numpy.py: 87%
171 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3from typing import Collection
4import numpy as np
5from numpy.typing import ArrayLike, NDArray
7from gpaw.mpi import world
8from gpaw.lcaotddft.ksdecomposition import KohnShamDecomposition
10from ...perturbation import Perturbation, PerturbationLike, create_perturbation
11from ...typing import Array1D, Communicator
12from ...utils import Logger, add_fake_kpts, find_files, partial_format, get_gaussian_pulse_values, filter_array
13from ...utils.logging import format_times, format_frequencies
class TimeDensityMatrixReader:
    """ Finds density matrices in the time domain saved to disk and reads them.

    Parameters
    ----------
    pulserho_fmt
        Formatting string for the density matrices saved to disk.

        The formatting string should be a plain string containing variable
        placeholders within curly brackets ``{}``. It should not be confused with
        a formatted string literal (f-string).

        Example:

        * pulserho_fmt = ``pulserho/t{time:09.1f}{tag}.npy``.

        Accepts variables

        * ``{time}`` - Time in units of as.
        * ``{tag}`` - Derivative tag, ``''``, ``'-Iomega'``, or ``'-omega2'``.
        * ``{pulsefreq}`` - Pulse frequency in units of eV.
        * ``{pulsefwhm}`` - Pulse FWHM in units of fs.
    ksd
        KohnShamDecomposition object or file name to the ksd file.
    filter_times
        Look for these times (or as close to them as possible). In units of as.
    pulses
        Density matrices in response to these pulses. By default, no information about the
        pulse.
    derivative_order_s
        List of derivative orders. The list is copied, so later changes to the
        passed list do not affect the reader.
    log
        Logger object. A new ``Logger`` is created when omitted.
    comm
        MPI communicator. The GPAW ``world`` communicator is used when omitted.
    """

    # File name tags, indexed by derivative order
    _tag_s = ('', '-Iomega', '-omega2')

    def __init__(self,
                 pulserho_fmt: str,
                 ksd: str | KohnShamDecomposition,
                 filter_times: Array1D[np.float64] | list[float] | None = None,
                 pulses: Collection[PerturbationLike] = [None],
                 derivative_order_s: list[int] = [0],
                 log: Logger | None = None,
                 comm: Communicator | None = None):
        if log is None:
            log = Logger()
        self._log = log
        self._comm = world if comm is None else comm

        if isinstance(ksd, KohnShamDecomposition):
            self._ksd = ksd
        else:
            self._ksd = KohnShamDecomposition(filename=ksd)
            # NOTE(review): assumed to apply only to a decomposition loaded from
            # file here (an object passed in is left untouched) - confirm
            add_fake_kpts(self._ksd)

        self._pulses = [create_perturbation(pulse) for pulse in pulses]
        # Copy, so that neither the shared default list nor the caller's list is
        # aliased by the public attribute
        self.derivative_order_s = list(derivative_order_s)
        self.pulserho_fmt = pulserho_fmt

        # Only the root rank searches the file tree; other ranks pass None
        nested_times: list[Array1D[np.float64]] | None = None
        if self.comm.rank == 0:
            nested_times = []
            for pulse in self.pulses:
                for derivative in self.derivative_order_s:
                    # Partially format the format string, i.e. fill out the pulsefreq,
                    # pulsefwhm, and tag fields
                    tag = self._tag_s[derivative]
                    fmt = partial_format(pulserho_fmt, tag=tag, **get_gaussian_pulse_values(pulse))

                    # Search the file tree for files
                    f = find_files(fmt, expected_keys=['time'])
                    nested_times.append(f['time'])

        # Times found for every pulse/derivative, and times found for some only
        self._time_t, self._part_time_t = extract_common(nested_times,
                                                         filter_times,
                                                         self.comm)

    @property
    def ksd(self) -> KohnShamDecomposition:
        """ Kohn-Sham decomposition object. """
        return self._ksd

    @property
    def comm(self) -> Communicator:
        """ MPI communicator. """
        return self._comm

    @property
    def log(self) -> Logger:
        """ Logger object. """
        return self._log

    @property
    def times(self) -> Array1D[np.float64]:
        """ Simulation time in units of as. """
        return self._time_t

    @property
    def nt(self) -> int:
        """ Number of times. """
        return len(self.times)

    @property
    def pulses(self) -> list[Perturbation]:
        """ Pulses with which density matrices are convoluted. """
        return self._pulses

    def __str__(self) -> str:
        lines = ['Response from density matrices on disk']

        lines.append('')
        lines.append(f'Format string: {self.pulserho_fmt}')
        lines.append(f'Calculating response for {self.nt} times and {len(self.pulses)} pulses')
        lines.append(f' times: {format_times(self.times)}')
        npartt = len(self._part_time_t)
        if npartt > 0:
            lines.append(f'Additionally {npartt} times are available for some '
                         'pulses/derivatives only')
            lines.append(f' times: {format_times(self._part_time_t)}')

        return '\n'.join(lines)

    def read(self,
             time: float,
             pulse: Perturbation,
             derivative: int) -> NDArray[np.complex128]:
        r""" Read single density matrix from disk.

        Parameters
        ----------
        time
            Simulation time in units of as.
        pulse
            Pulse which this density matrix is in response to.
        derivative
            Read derivative of this order.

        Returns
        -------
        Density matrix :math:`\rho_ia`.
        """
        fname_kw = dict(time=time, tag=self._tag_s[derivative],
                        **get_gaussian_pulse_values(pulse))
        fname = self.pulserho_fmt.format(**fname_kw)

        rho = read_numpy(fname)
        if len(rho.shape) == 1:
            # Transform from ravelled form
            rho_ia = self.ksd.M_ia_from_M_p(rho)
        else:
            rho_ia = rho

        return rho_ia
class FrequencyDensityMatrixReader:
    """ Locates density matrices in the frequency domain on disk and reads them.

    Parameters
    ----------
    frho_fmt
        Formatting string for the density matrices
        in frequency space saved to disk.

        This is a plain string with variable placeholders inside curly
        brackets ``{}``; it is not a formatted string literal (f-string).

        Example:

        * frho_fmt = ``frho/w{freq:05.2f}-{reim}.npy``.

        Accepts variables:

        * ``{freq}`` - Frequency in units of eV.
        * ``{reim}`` - ``'Re'`` or ``'Im'`` for Fourier transform of real/imaginary
          part of density matrix.
    ksd
        KohnShamDecomposition object or file name to the ksd file.
    filter_frequencies
        Look for these frequencies (or as close to them as possible). In units of eV.
    real
        Search for files with ``{reim}`` equal to ``'Re'``.
    imag
        Search for files with ``{reim}`` equal to ``'Im'``.
    log
        Logger object. A new ``Logger`` is created when omitted.
    comm
        MPI communicator. The GPAW ``world`` communicator is used when omitted.

    Raises
    ------
    ValueError
        If both ``real`` and ``imag`` are false.
    """

    def __init__(self,
                 frho_fmt: str,
                 ksd: str | KohnShamDecomposition,
                 filter_frequencies: ArrayLike | None = None,
                 real: bool = True,
                 imag: bool = True,
                 log: Logger | None = None,
                 comm: Communicator | None = None):
        self._log = Logger() if log is None else log

        if not isinstance(ksd, KohnShamDecomposition):
            ksd = KohnShamDecomposition(filename=ksd)
            # NOTE(review): assumed to apply only to a decomposition loaded from
            # file here (an object passed in is left untouched) - confirm
            add_fake_kpts(ksd)
        self._ksd = ksd

        self._comm = comm if comm is not None else world
        self.frho_fmt = frho_fmt

        if not (real or imag):
            raise ValueError('At least one of real or imag must be true')
        wanted_parts = [label for label, enabled in (('Re', real), ('Im', imag)) if enabled]

        # Only the root rank searches the file tree; other ranks pass None
        found_freqs: list[Array1D[np.float64]] | None = None
        if self.comm.rank == 0:
            found_freqs = []
            for label in wanted_parts:
                # Fill out the reim field, leaving freq as the only placeholder
                pattern = partial_format(frho_fmt, reim=label)
                # Search the file tree for matching files
                matches = find_files(pattern, expected_keys=['freq'])
                found_freqs.append(matches['freq'])

        # Frequencies found for every requested part, and for some parts only
        common, partial = extract_common(found_freqs, filter_frequencies, self.comm)
        self._freq_w = common
        self._part_freq_w = partial

    @property
    def ksd(self) -> KohnShamDecomposition:
        """ Kohn-Sham decomposition object. """
        return self._ksd

    @property
    def comm(self) -> Communicator:
        """ MPI communicator. """
        return self._comm

    @property
    def log(self) -> Logger:
        """ Logger object. """
        return self._log

    @property
    def frequencies(self) -> Array1D[np.float64]:
        """ Frequencies in units of eV. """
        return self._freq_w

    @property
    def nw(self) -> int:
        """ Number of frequencies. """
        return len(self.frequencies)

    def __str__(self) -> str:
        report = [
            'Response from Fourier transform of density matrices on disk',
            '',
            f'Format string: {self.frho_fmt}',
            f'Calculating response for {self.nw} frequencies',
            f' frequencies: {format_frequencies(self.frequencies)}',
        ]

        partial_count = len(self._part_freq_w)
        if partial_count > 0:
            report.append(f'Additionally {partial_count} frequencies are available for real '
                          'or imaginary parts but not both')
            report.append(f' frequencies: {format_frequencies(self._part_freq_w)}')

        return '\n'.join(report)

    def read(self,
             frequency: float,
             real: bool) -> NDArray[np.complex128]:
        """ Read single density matrix from disk.

        Parameters
        ----------
        frequency
            Frequency in units of eV.
        real
            Read the file with ``{reim}`` equal to ``'Re'`` if true,
            otherwise the one with ``'Im'``.

        Returns
        -------
        Density matrix. Data stored in ravelled form is converted
        by ``ksd.M_ia_from_M_p``; other data is returned as read.
        """
        part = 'Re' if real else 'Im'
        fname = self.frho_fmt.format(freq=frequency, reim=part)

        rho = read_numpy(fname)
        if len(rho.shape) != 1:
            # Already in non-ravelled form
            return rho

        if self.ksd.only_ia:
            # Twice the rho is saved by the KohnShamDecomposition transform
            rho /= 2

        # Transform from ravelled form
        return self.ksd.M_ia_from_M_p(rho)
def read_numpy(fname: str) -> NDArray[np.complex128]:
    r""" Read density matrix from numpy binary file or archive.

    Supports data stored in non-ravelled form (preferred; indices :math:`ia`
    for electron-hole pairs) and in ravelled form (legacy; single index :math:`p`
    for electron-hole pairs).

    In an archive, the entry ``rho_p`` takes precedence over ``rho_ia``.

    Parameters
    ----------
    fname
        File name.

    Returns
    -------
    Density matrix :math:`\rho_ia` or :math:`\rho_p`.

    Raises
    ------
    RuntimeError
        If an archive contains neither ``rho_p`` nor ``rho_ia``, or if the
        array read does not have the expected number of dimensions.
    """
    f = np.load(fname)
    if isinstance(f, np.lib.npyio.NpzFile):
        # Read npz file. The context manager closes the archive also when
        # one of the checks below raises.
        with f:
            if 'rho_p' in f.files:
                rho = f['rho_p']
                if len(rho.shape) != 1:
                    raise RuntimeError(f'Expected 1D array, got shape {rho.shape}.')
            elif 'rho_ia' in f.files:
                rho = f['rho_ia']
                if len(rho.shape) != 2:
                    raise RuntimeError(f'Expected 2D array, got shape {rho.shape}.')
            else:
                raise RuntimeError("Expected file 'rho_p', or 'rho_ia' in file")
    else:
        # Read npy file
        assert isinstance(f, np.ndarray)
        rho = f
        if len(rho.shape) not in [1, 2]:
            raise RuntimeError(f'Expected 1D or 2D array, got shape {rho.shape}.')

    return rho
def extract_common(nested_values: list[Array1D[np.float64]] | None,
                   filter_values: ArrayLike | None,
                   comm: Communicator) -> tuple[Array1D[np.float64], Array1D[np.float64]]:
    """ From a list of arrays, extract array elements that are present in all arrays.

    The work is done on the root rank (rank 0) and the result is broadcast.
    This is a collective operation; all ranks of ``comm`` must call it.

    Parameters
    ----------
    nested_values
        List of arrays on root rank, ``None`` on other ranks.
        An empty list gives two empty arrays.
    filter_values
        Filter the values, keeping only the values closest to these.
    comm
        MPI communicator.

    Returns
    -------
    Tuple of sorted values present in all arrays, and sorted values present in \
    some arrays but not all of them. Broadcast to all ranks.
    """
    if comm.rank == 0:
        assert nested_values is not None
        if nested_values:
            # Values present in any array
            values_any: set[float] = set(nested_values[0])
            # Values present in all arrays
            values_all: set[float] = set(values_any)

            for values in nested_values[1:]:
                values_any |= set(values)
                values_all &= set(values)

            # Filter the values
            values_any = set(filter_array(sorted(values_any), filter_values))  # type: ignore
            values_all = set(filter_array(sorted(values_all), filter_values))  # type: ignore
        else:
            # No arrays were given, so there are no values to report
            values_any = set()
            values_all = set()

        # Values present in some arrays only
        values_some = values_any - values_all

        shapes = np.array([len(values_all), len(values_some)], dtype=int)
    else:
        assert nested_values is None
        shapes = np.array([0, 0], dtype=int)

    # Broadcast the sizes first, so that every rank can allocate the buffers
    comm.broadcast(shapes, 0)

    # Values in all arrays: as array
    array_all = np.zeros(shapes[0], dtype=float)
    # Values in some arrays: as array
    array_some = np.zeros(shapes[1], dtype=float)

    if comm.rank == 0:
        array_all[:] = sorted(values_all)
        array_some[:] = sorted(values_some)

    # Empty buffers are not broadcast. The sizes are the same on every rank
    # after the first broadcast, so all ranks take the same decision.
    for array in (array_all, array_some):
        if array.size > 0:
            comm.broadcast(array, 0)

    return array_all, array_some