Coverage for rhodent/density_matrices/readers/gpaw.py: 92%
436 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
2from abc import ABC, abstractmethod
3from typing import Callable, Generator
5import numpy as np
6from numpy.typing import NDArray
7from itertools import zip_longest
9from ase.units import Bohr
10from ase.io.ulm import Reader
12from gpaw.mpi import broadcast, world
13from gpaw.lcaotddft.ksdecomposition import KohnShamDecomposition
14from gpaw.lcaotddft.wfwriter import WaveFunctionReader
16from ..buffer import DensityMatrixBuffer
17from ...utils import Logger, add_fake_kpts, get_array_filter
18from ...utils.logging import format_times
19from ...utils.memory import HasMemoryEstimate, MemoryEstimate
20from ...typing import Array1D, Communicator
class BaseWfsReader(ABC):

    """ Read wave functions or density matrices from the time-dependent wave functions file.

    Parameters
    ----------
    wfs_fname
        File name of the time-dependent wave functions file written by GPAW.
    comm
        MPI communicator.
    yield_re
        Whether to read and yield the real part of wave functions/density matrices.
    yield_im
        Whether to read and yield the imaginary part of wave functions/density matrices.
    stridet
        Skip this many steps when reading.
    tmax
        Last time index to read.
    filter_times
        A list of times to read in atomic units. The closest times in the time-dependent
        wave functions file will be read. Applied after skipping the stridet steps to tmax.
    log
        Logger object.
    """

    def __init__(self,
                 wfs_fname: str,
                 comm=world,
                 yield_re: bool = True,
                 yield_im: bool = True,
                 stridet: int = 1,
                 tmax: int = -1,
                 filter_times: list[float] | Array1D[np.float64] | None = None,
                 log: Logger | None = None):
        self._comm = comm
        self._yield_re = yield_re
        self._yield_im = yield_im
        if log is None:
            log = Logger()
        self._log = log

        # The main reader is closed when it is garbage collected
        # Hence, we need to keep it in the scope
        self.mainreader = WaveFunctionReader(wfs_fname)
        self._time_t, self.initreader, self._full_reader_t = prepare_wave_function_readers(
            self.mainreader, comm, self.log, stridet=stridet, tmax=tmax)
        self._flt_t = get_array_filter(self._time_t, filter_times)
        self.reader_t = [self._full_reader_t[r]
                         for r in np.arange(len(self._full_reader_t), dtype=int)[self._flt_t]]

    @property
    def comm(self) -> Communicator:
        """ MPI communicator. """
        return self._comm

    @property
    def yield_re(self) -> bool:
        """ Whether this object should read real parts. """
        return self._yield_re

    @property
    def yield_im(self) -> bool:
        """ Whether this object should read imaginary parts. """
        return self._yield_im

    @property
    def log(self) -> Logger:
        """ Logger object. """
        return self._log

    @property
    def time_t(self) -> Array1D[np.float64]:
        """ Array of times to read; in atomic units. """
        return self._time_t[self._flt_t]  # type: ignore

    @property
    def nt(self) -> int:
        """ Number of times to read. """
        return len(self.time_t)

    @property
    def dt(self) -> float:
        """ Time step in atomic units.

        Raises
        ------
        ValueError
            If fewer than two times are read, or if the time step is not constant.
        """
        time_t = self.time_t
        if len(time_t) < 2:
            # Previously this surfaced as an opaque IndexError from time_t[1]
            fname = self.mainreader.filename
            raise ValueError(f'Unable to get a time step. Fewer than two times read from {fname}.')
        dt = time_t[1] - time_t[0]
        if not np.allclose(time_t[1:] - dt, time_t[:-1]):
            fname = self.mainreader.filename
            raise ValueError(f'Unable to get a time step. Variable time step in {fname}.')

        return dt

    def work_loop(self,
                  rank: int) -> Generator[int | None, None, None]:
        """ Yield the time indices that this rank will read.

        New indices are yielded until the end of self.reader_t is reached
        (across all ranks).

        Yields
        ------
        Time index between 0 and len(self.reader_t) - 1 corresponding to
        the time being read by this rank. Or None if this rank has nothing
        to read.
        """
        for t_r in self.work_loop_by_ranks():
            yield t_r[rank]

    def work_loop_by_ranks(self) -> Generator[list[int | None], None, None]:
        """ Yield, for each local iteration, the time index read by every rank.

        Times are distributed round-robin: in local iteration ``localt`` rank ``r``
        reads time index ``r + localt * comm.size``. Indices past the last time
        are replaced by None.

        Yields
        ------
        List of length comm.size with a time index (or None) for each rank.
        """
        nt = self.nt
        ntperrank = (nt + self.comm.size - 1) // self.comm.size

        for localt in range(ntperrank):
            globalt_r = [rank + localt * self.comm.size for rank in range(self.comm.size)]
            globalt_r = [globalt if globalt < nt else None for globalt in globalt_r]
            yield globalt_r

    def global_work_loop(self) -> Generator[int, None, None]:
        """ Yield all time indices in the order they are handled by gather_on_root.

        This is the order of work_loop_by_ranks with the None entries dropped.
        """
        for chunks_r in self.work_loop_by_ranks():
            for chunk in chunks_r:
                if chunk is None:
                    continue
                yield chunk

    @abstractmethod
    def iread(self, *args, **kwargs) -> Generator[DensityMatrixBuffer, None, None]:
        """ Iteratively read wave functions or density matrices time by time. """
        raise NotImplementedError

    @abstractmethod
    def nnshape(self, *args, **kwargs) -> tuple[int, int]:
        """ Shape of the density matrices or wave functions. """
        raise NotImplementedError

    def gather_on_root(self, *args, **kwargs) -> Generator[DensityMatrixBuffer | None, None, None]:
        """ Read with iread on all ranks and pass every result to the root rank.

        Must be called on all ranks (it sends and receives between ranks).
        Arguments are passed to iread.

        Yields
        ------
        On the root rank, a copy of the buffer for each time, in the order of
        global_work_loop. On other ranks, None the same number of times.
        """
        for indices_r, dm_buffer in zip_longest(self.work_loop_by_ranks(),
                                                self.iread(*args, **kwargs), fillvalue=None):
            assert indices_r is not None, 'Work loop shorter than work'

            # Yield root's own work
            if self.comm.rank == 0:
                assert indices_r[0] is not None
                assert dm_buffer is not None
                dm_buffer.ensure_contiguous_buffers()

                yield dm_buffer.copy()
            else:
                yield None

            # Yield the work of non-root
            for recvrank, recvindices in enumerate(indices_r[1:], start=1):
                if recvindices is None:
                    # No work on this recvrank
                    continue

                if self.comm.rank == 0:
                    # Receive work
                    assert dm_buffer is not None
                    dm_buffer.recv_arrays(self.comm, recvrank, log=self.log)
                    yield dm_buffer.copy()
                else:
                    # Send work to root if there is any
                    if self.comm.rank == recvrank:
                        assert dm_buffer is not None
                        dm_buffer.send_arrays(self.comm, 0, log=self.log)
                    yield None

    def collect_on_root(self, *args, **kwargs) -> DensityMatrixBuffer | None:
        """ Read all times and collect them in a single buffer on the root rank.

        Must be called on all ranks. Arguments are passed to nnshape and iread.

        Returns
        -------
        On the root rank, a buffer holding all nt times (indexed as full_dm[t]).
        None on all other ranks.
        """
        is_root = self.comm.rank == 0

        # Only the root rank stores the full time series; the other ranks only
        # take part in the communication inside gather_on_root
        full_dm: DensityMatrixBuffer | None = None
        if is_root:
            nnshape = self.nnshape(*args, **kwargs)
            full_dm = DensityMatrixBuffer(nnshape, (self.nt, ), np.float64)
            if self.yield_re:
                full_dm.zeros(True, 0)
            if self.yield_im:
                full_dm.zeros(False, 0)

        for t, dm_buffer in zip_longest(self.global_work_loop(),
                                        self.gather_on_root(*args, **kwargs), fillvalue=None):
            if not is_root:
                continue

            assert t is not None, 'Iterators must be same length'
            assert dm_buffer is not None, 'Iterators must be same length'
            assert full_dm is not None

            for partial_data_nn, full_data_nn in zip(dm_buffer._iter_buffers(),
                                                     full_dm[t]._iter_buffers()):
                full_data_nn[:] += partial_data_nn

        if not is_root:
            return None

        return full_dm
class KohnShamRhoWfsReader(HasMemoryEstimate, BaseWfsReader):

    """ Read density matrices from the time-dependent wave functions file.

    Yield density matrices time by time, in the basis of ground state
    Kohn-Sham orbitals. Internally, density matrices are read in the LCAO basis
    by an LCAORhoWfsReader and then transformed with C0S from the ksd file.

    Parameters
    ----------
    wfs_fname
        File name of the time-dependent wave functions file.
    ksd
        KohnShamDecomposition object or file name to the ksd file.
    comm
        MPI communicator.
    yield_re
        Whether to read and yield the real part of wave functions/density matrices.
    yield_im
        Whether to read and yield the imaginary part of wave functions/density matrices.
    stridet
        Skip this many steps when reading.
    tmax
        Last time index to read.
    filter_times
        A list of times to read in atomic units. The closest times in the time-dependent wave functions file
        will be read.
    striden
        Option passed through to the LCAORhoWfsReader: the number of states contracted
        at a time when building the LCAO density matrix. 0 means all states at once.
    log
        Logger object.
    """

    def __init__(self,
                 wfs_fname: str,
                 ksd: str | KohnShamDecomposition,
                 comm=world,
                 yield_re: bool = True,
                 yield_im: bool = True,
                 stridet: int = 1,
                 tmax: int = -1,
                 filter_times: list[float] | Array1D[np.float64] | None = None,
                 log: Logger | None = None,
                 striden: int = 0):
        # Set up an internal LCAO density matrix reader
        self.lcao_rho_reader = LCAORhoWfsReader(
            wfs_fname=wfs_fname, comm=comm,
            yield_re=yield_re, yield_im=yield_im, log=log,
            stridet=stridet, tmax=tmax,
            filter_times=filter_times, striden=striden)

        # And copy its attributes to self
        self._yield_re = yield_re
        self._yield_im = yield_im
        self._comm = self.lcao_rho_reader.comm
        self._log = self.lcao_rho_reader.log
        self.mainreader = self.lcao_rho_reader.mainreader
        self.initreader = self.lcao_rho_reader.initreader
        self._full_reader_t = self.lcao_rho_reader._full_reader_t
        self.reader_t = self.lcao_rho_reader.reader_t
        self._flt_t = self.lcao_rho_reader._flt_t
        self._time_t = self.lcao_rho_reader._time_t

        # Set up ksd
        if isinstance(ksd, KohnShamDecomposition):
            self._ksd = ksd
        else:
            self._ksd = KohnShamDecomposition(filename=ksd)
            add_fake_kpts(self._ksd)

        self._C0S_sknM: NDArray[np.float64] | None = None
        self._rho0_sknn: NDArray[np.float64] | None = None

    def __str__(self) -> str:
        nn, nM = proxy_coefficients(self.initreader).shape[2:]
        ntperrank = (self.nt + self.comm.size - 1) // self.comm.size

        lines = []
        lines.append('Time-dependent wave functions reader')
        lines.append('  Constructing density matrices in basis of ground state orbitals.')
        lines.append('')
        lines.append(f'  file: {self.mainreader.filename}')
        lines.append(f'  wave function dimensions {(nn, nM)}')
        lines.append(f'  {self.nt} times')
        lines.append(f'    {format_times(self.time_t, units="au")}')
        lines.append(f'  {self.comm.size} ranks reading in {ntperrank} iterations')

        return '\n'.join(lines)

    def get_memory_estimate(self) -> MemoryEstimate:
        nn, nM = proxy_coefficients(self.initreader).shape[2:]

        memory_estimate = MemoryEstimate()
        memory_estimate.add_key('C0S_nM', (nn, nM), float,
                                on_num_ranks=self.comm.size)
        memory_estimate.add_key('rho0_MM', (nM, nM), float,
                                on_num_ranks=self.comm.size)

        return memory_estimate

    @property
    def ksd(self) -> KohnShamDecomposition:
        """ Kohn-Sham decomposition object. """
        return self._ksd

    @property
    def C0S_sknM(self) -> NDArray[np.float64]:
        """ C0S_sknM computed from the ksd file (see read_C0S_parallel).

        Computed on first access; the computation involves all ranks of comm,
        so the first access must happen on all ranks.
        """
        if self._C0S_sknM is None:
            self.log(f'Constructing C0_sknM on {self.comm.size} ranks',
                     who='Reader', rank=0, flush=True)
            self._C0S_sknM = read_C0S_parallel(self.ksd.reader, comm=self.comm)
            self.log('Constructed C0_sknM',
                     who='Reader', rank=0, flush=True)
        assert self._C0S_sknM is not None
        return self._C0S_sknM

    @property
    def rho0_sknn(self) -> NDArray[np.float64]:
        """ Ground state density matrix in the Kohn-Sham basis.

        Diagonal matrices with the occupations ``occ_un`` of the ksd file on the diagonal.
        """
        if self._rho0_sknn is None:
            f_skn = self.ksd.reader.occ_un[:]
            nn = f_skn.shape[2]
            rho0_sknn = np.zeros(f_skn.shape[:2] + (nn, nn))
            diag_nn = np.eye(nn, dtype=bool)
            rho0_sknn[..., diag_nn] = f_skn
            self._rho0_sknn = rho0_sknn
        return self._rho0_sknn

    def nnshape(self,
                s: int,
                k: int,
                n1: slice,
                n2: slice) -> tuple[int, int]:
        """ Shape (len(n1), len(n2)) of the density matrix block read by iread.

        The slices must have explicit start and stop.
        """
        n1size = n1.stop - n1.start
        n2size = n2.stop - n2.start
        nnshape = (n1size, n2size)
        return nnshape

    def parallel_prepare(self):
        """ Read everything necessary synchronously on all ranks. """
        self.C0S_sknM
        self.lcao_rho_reader.rho0_skMM

    def iread(self,
              s: int,
              k: int,
              n1: slice,
              n2: slice) -> Generator[DensityMatrixBuffer, None, None]:
        """ Read the density matrices time by time.

        Parameters
        ----------
        s, k, n1, n2
            Read these indices.
        """
        dm_buffer = DensityMatrixBuffer(self.nnshape(s, k, n1, n2), (), np.float64)
        if self.yield_re:
            dm_buffer.zeros(True, 0)
        if self.yield_im:
            dm_buffer.zeros(False, 0)

        # rho_nm = sum_NM C0S_nN rho_NM C0S_mM. C0S is stored as a real (float64)
        # array, so the real and imaginary parts transform independently and no
        # conjugation is needed
        einsumstr = 'nN,mM,NM->nm'
        C0S_sknM = self.C0S_sknM  # Read this on all ranks

        nM = C0S_sknM.shape[3]
        sliceM = slice(0, nM)

        # The transformation does not depend on time; slice it once outside the loop
        C0S_nM1 = C0S_sknM[s, k, n1, :]  # Here n is full KS basis
        C0S_nM2 = C0S_sknM[s, k, n2, :]

        for lcao_dm in self.lcao_rho_reader.iread(s, k, sliceM, sliceM):
            self.log.start('read')

            if self.yield_re:
                Rerho_MM = lcao_dm.real
                Rerho_x = np.einsum(einsumstr, C0S_nM1, C0S_nM2, Rerho_MM, optimize=True)
                dm_buffer.safe_fill(True, 0, Rerho_x)
            if self.yield_im:
                Imrho_MM = lcao_dm.imag
                Imrho_x = np.einsum(einsumstr, C0S_nM1, C0S_nM2, Imrho_MM, optimize=True)
                dm_buffer.safe_fill(False, 0, Imrho_x)

            yield dm_buffer
class LCAORhoWfsReader(BaseWfsReader):

    """ Read density matrices in the LCAO basis from the time-dependent wave functions file.

    Yield density matrices time by time. The ground state density matrix
    (rho0_skMM) is subtracted from the real part, so the yielded matrices are
    changes with respect to the ground state.

    Parameters
    ----------
    wfs_fname
        File name of the time-dependent wave functions file written by GPAW.
    comm
        MPI communicator.
    yield_re
        Whether to read and yield the real part of the density matrices.
    yield_im
        Whether to read and yield the imaginary part of the density matrices.
    stridet
        Skip this many steps when reading.
    tmax
        Last time index to read.
    filter_times
        A list of times to read in atomic units. The closest times in the time-dependent
        wave functions file will be read.
    log
        Logger object.
    striden
        Number of states contracted at a time when building the density matrix.
        0 means all states at once.
    """

    def __init__(self,
                 wfs_fname: str,
                 comm=world,
                 yield_re: bool = True,
                 yield_im: bool = True,
                 stridet: int = 1,
                 tmax: int = -1,
                 filter_times: list[float] | Array1D[np.float64] | None = None,
                 log: Logger | None = None,
                 striden: int = 4):

        super().__init__(wfs_fname=wfs_fname, comm=comm,
                         yield_re=yield_re, yield_im=yield_im, log=log,
                         stridet=stridet, tmax=tmax,
                         filter_times=filter_times)
        self._f_skn: NDArray[np.float64] | None = None
        self._C0_sknM: NDArray[np.float64] | None = None
        self._rho0_skMM: NDArray[np.float64] | None = None
        self._striden = striden

    @property
    def nn(self) -> int:
        """ Number of states (length of the state axis of f_skn). """
        return self.f_skn.shape[2]

    @property
    def striden(self) -> int:
        """ Number of states contracted at a time, as given (0 means all). """
        return self._striden

    @property
    def true_striden(self) -> int:
        """ Number of states contracted at a time, with 0 resolved to nn. """
        if self.striden == 0:
            return self.nn
        return self.striden

    @property
    def f_skn(self) -> NDArray[np.float64]:
        """ Occupation numbers of the 'init' entry of the file. """
        if self._f_skn is None:
            self._f_skn = proxy_occupations(self.initreader)[:]
        return self._f_skn

    @property
    def C0_sknM(self) -> NDArray[np.float64]:
        """ LCAO coefficients of the 'init' entry; the imaginary part must be negligible. """
        if self._C0_sknM is None:
            C0_sknM = proxy_coefficients(self.initreader)[:]
            assert np.max(np.abs(C0_sknM.imag)) < 1e-20
            self._C0_sknM = C0_sknM.real
        return self._C0_sknM

    @property
    def rho0_skMM(self) -> NDArray[np.float64]:
        """ Ground state density matrix in the LCAO basis.

        Computed on first access from f_skn and C0_sknM; the computation involves
        all ranks of comm, so the first access must happen on all ranks.
        """
        if self._rho0_skMM is None:
            self.log(f'Constructing rho0_skMM on {self.comm.size} ranks',
                     who='Reader', rank=0, flush=True)
            self._rho0_skMM = calculate_rho0_parallel(self.f_skn, self.C0_sknM, comm=self.comm)
            self.log('Constructed rho0_skMM',
                     who='Reader', rank=0, flush=True)
        assert self._rho0_skMM is not None
        return self._rho0_skMM

    def inner_work_loop(self) -> Generator[slice, None, None]:
        """ Yield consecutive slices of true_striden states covering all nn states.

        The last slice may extend past nn.
        """
        for n in range(0, self.nn, self.true_striden):
            yield slice(n, n + self.true_striden)

    def subtract_ground_state(self,
                              dm_buffer: DensityMatrixBuffer,
                              s: int,
                              k: int,
                              M1: slice,
                              M2: slice):
        """ Add -rho0_skMM[s, k, M1, M2] to the real part of dm_buffer. """
        rhs = -self.rho0_skMM[s, k, M1, M2]
        dm_buffer.safe_add(True, 0, rhs)

    def nnshape(self,
                s: int,
                k: int,
                M1: slice,
                M2: slice) -> tuple[int, int]:
        """ Shape (len(M1), len(M2)) of the density matrix block read by iread.

        The slices must have explicit start and stop.
        """
        M1size = M1.stop - M1.start
        M2size = M2.stop - M2.start
        MMshape = (M1size, M2size)
        return MMshape

    def iread(self,
              s: int,
              k: int,
              M1: slice,
              M2: slice) -> Generator[DensityMatrixBuffer, None, None]:
        """ Read the density matrices time by time.

        Parameters
        ----------
        s, k, M1, M2
            Read these indices.
        """
        dm_buffer = DensityMatrixBuffer(self.nnshape(s, k, M1, M2), (), np.float64)

        # rho_MO = sum_n f_n C_nM C_nO^*
        einsumstr = 'n,nM,nO->MO'

        self.rho0_skMM  # Construct synchronously on all ranks
        for globalt in self.work_loop(self.comm.rank):
            if globalt is None:
                continue

            if self.yield_re:
                dm_buffer.zeros(True, 0)
            if self.yield_im:
                dm_buffer.zeros(False, 0)

            reader = self.reader_t[globalt]
            for n in self.inner_work_loop():
                C_nM1 = proxy_C_nM(reader, s, k, n, M1)  # Here n is occupied states only
                C_nM2 = proxy_C_nM(reader, s, k, n, M2)
                f_n = proxy_occupations(reader, s, k)[n]
                path = np.einsum_path(einsumstr, f_n, C_nM1.real, C_nM2.real, optimize='optimal')[0]

                # Re(C1 C2^*) = Re1 Re2 + Im1 Im2;  Im(C1 C2^*) = Im1 Re2 - Re1 Im2
                if self.yield_re:
                    Rerho_x = np.einsum(einsumstr, f_n, C_nM1.real, C_nM2.real, optimize=path)
                    Rerho_x += np.einsum(einsumstr, f_n, C_nM1.imag, C_nM2.imag, optimize=path)
                    dm_buffer.safe_add(True, 0, Rerho_x)
                if self.yield_im:
                    Imrho_x = np.einsum(einsumstr, f_n, C_nM1.imag, C_nM2.real, optimize=path)
                    Imrho_x -= np.einsum(einsumstr, f_n, C_nM1.real, C_nM2.imag, optimize=path)
                    dm_buffer.safe_add(False, 0, Imrho_x)

            # The ground state density matrix only has a real part. When the real
            # part is not requested its buffer is never zeroed, so it must not be touched
            if self.yield_re:
                self.subtract_ground_state(dm_buffer, s, k, M1, M2)

            yield dm_buffer
class WfsReader(BaseWfsReader):

    """ Read wave function LCAO coefficients from the time-dependent wave functions file.

    Yield wave functions time by time.

    Parameters
    ----------
    wfs_fname
        File name of the time-dependent wave functions file written by GPAW.
    comm
        MPI communicator.
    yield_re
        Whether to read and yield the real part of the coefficients.
    yield_im
        Whether to read and yield the imaginary part of the coefficients.
    stridet
        Skip this many steps when reading.
    tmax
        Last time index to read.
    filter_times
        A list of times to read in atomic units. The closest times in the time-dependent
        wave functions file will be read.
    log
        Logger object.
    """

    def __init__(self,
                 wfs_fname: str,
                 comm=world,
                 yield_re: bool = True,
                 yield_im: bool = True,
                 stridet: int = 1,
                 tmax: int = -1,
                 filter_times: list[float] | Array1D[np.float64] | None = None,
                 log: Logger | None = None):
        super().__init__(wfs_fname=wfs_fname, comm=comm,
                         yield_re=yield_re, yield_im=yield_im, log=log,
                         stridet=stridet, tmax=tmax,
                         filter_times=filter_times)
        self._f_skn: NDArray[np.float64] | None = None
        self._C0_sknM: NDArray[np.float64] | None = None

    @property
    def nn(self) -> int:
        """ Number of states (axis 2 of the stored coefficients). """
        # Only the shape is needed; use the proxy instead of reading the full array
        return proxy_coefficients(self.initreader).shape[2]

    @property
    def nM(self) -> int:
        """ Number of LCAO basis functions (axis 3 of the stored coefficients). """
        return proxy_coefficients(self.initreader).shape[3]

    @property
    def f_skn(self) -> NDArray[np.float64]:
        """ Occupation numbers of the 'init' entry of the file. """
        if self._f_skn is None:
            self._f_skn = proxy_occupations(self.initreader)[:]
        return self._f_skn

    @property
    def C0_sknM(self) -> NDArray[np.float64]:
        """ LCAO coefficients of the 'init' entry; the imaginary part must be negligible. """
        if self._C0_sknM is None:
            C0_sknM = proxy_coefficients(self.initreader)[:]
            assert np.max(np.abs(C0_sknM.imag)) < 1e-20
            self._C0_sknM = C0_sknM.real
        return self._C0_sknM

    def nnshape(self,
                s: int,
                k: int,
                n: slice,
                M: slice) -> tuple[int, int]:
        """ Shape (len(n), len(M)) of the coefficient block read by iread.

        The slices must have explicit start and stop.
        """
        nsize = n.stop - n.start
        Msize = M.stop - M.start
        nMshape = (nsize, Msize)
        return nMshape

    def iread(self,
              s: int,
              k: int,
              n: slice,
              M: slice) -> Generator[DensityMatrixBuffer, None, None]:
        """ Read the wave function LCAO coefficients time by time.

        The real and imaginary parts of C_nM are stored in the real and imaginary
        parts of the yielded buffer, as requested by yield_re and yield_im.

        Parameters
        ----------
        s, k, n, M
            Read these indices.
        """
        dm_buffer = DensityMatrixBuffer(self.nnshape(s, k, n, M), (), np.float64)

        for globalt in self.work_loop(self.comm.rank):
            if globalt is None:
                continue

            if self.yield_re:
                dm_buffer.zeros(True, 0)
            if self.yield_im:
                dm_buffer.zeros(False, 0)

            reader = self.reader_t[globalt]
            C_nM = proxy_C_nM(reader, s, k, n, M)
            if self.yield_re:
                dm_buffer.safe_fill(True, 0, C_nM.real)
            if self.yield_im:
                dm_buffer.safe_fill(False, 0, C_nM.imag)

            yield dm_buffer
def prepare_wave_function_readers(mainreader,
                                  comm,
                                  log: Callable | None = None,
                                  stridet: int = 1,
                                  tmax: int = -1,
                                  parallel: bool = True,
                                  ) -> tuple[Array1D[np.float64], WaveFunctionReader, list[WaveFunctionReader]]:
    """ Scan the time-dependent wave functions file and set up readers for the times.

    The action ('init', 'kick', 'propagate' or none) and the time of every entry
    in the file are read, either in parallel over the ranks or on the root rank
    only, and gathered on the root rank. The root rank then selects:

    - the first 'init' entry;
    - 'kick'/'propagate' entries whose time is larger than that of the previously
      accepted entry, keeping every stridet-th of them and stopping after tmax
      of them have been processed (if tmax > 0).

    The selection is broadcast to all ranks.

    Parameters
    ----------
    mainreader
        Reader of the whole file; indexed by entry and with a ``filename`` attribute.
    comm
        MPI communicator.
    log
        Logger-like callable; it is called with keyword arguments such as
        ``who``, ``rank``, ``comm`` and ``flush``. A new Logger is created if None.
    stridet
        Keep every stridet-th time entry.
    tmax
        Stop after this many time entries. Non-positive means no limit.
    parallel
        Read the entries in parallel over the ranks if True; else on the root rank only.

    Returns
    -------
    Tuple of the array of selected times, the reader of the 'init' entry and
    the list of readers of the selected time entries.
    """
    if log is None:
        # The log calls below pass keyword arguments (who=, rank=, ...) that the
        # builtin print does not accept, so default to a Logger
        log = Logger()

    readerlen = len(mainreader)
    log(f'Opening time-dependent wave functions file {mainreader.filename}', who='Reader', rank=0, flush=True)

    # Encode the action as int according to the following list
    int2action = [None, 'init', 'kick', 'propagate']
    action2int = {a: i for i, a in enumerate(int2action)}

    if parallel:
        # Read in parallel
        readrange = range(0, readerlen, comm.size)
    else:
        # Read everything on root
        readrange = range(readerlen) if comm.rank == 0 else range(0)

    if comm.rank == 0 or parallel:
        # Buffers to read to on all ranks when parallel
        time_1 = np.array([0], dtype=float)
        action_1 = np.array([0], dtype=int)

    def round_up(val):
        # Round val up to the nearest multiple of comm.size
        v = (val + comm.size - 1)
        v = v // comm.size
        return v * comm.size

    if comm.rank == 0:
        # Buffers for gathering; padded so that every gather of comm.size values fits
        time_t = np.zeros(round_up(readerlen), dtype=time_1.dtype)
        action_t = np.zeros(round_up(readerlen), dtype=action_1.dtype)

    # Read all times in parallel, gather on root
    for roott in readrange:
        t = roott + comm.rank
        if t < readerlen:
            log(f'Opened item #{t}', who='Reader', comm=comm, if_elapsed=5, flush=True)

            reader = mainreader[t]
            action = getattr(reader, 'action', None)
            action_1[:] = action2int[action]
            try:
                time_1[:] = getattr(reader, 'time')
            except AttributeError:
                # Depending on the action these might not be set
                time_1[:] = np.nan

        if parallel:
            # All ranks take part in the gather, also those with t >= readerlen;
            # their values land in the padding beyond readerlen and are ignored
            comm.gather(time_1, 0, time_t[roott:roott + comm.size] if comm.rank == 0 else None)
            comm.gather(action_1, 0, action_t[roott:roott + comm.size] if comm.rank == 0 else None)
        else:
            time_t[t] = time_1[0]
            action_t[t] = action_1[0]

    if comm.rank == 0:
        # The root rank must count the number of time entries
        nreadt = 0
        lasttime = -np.inf

        readtimes = []
        timereadert = []
        initreadert = None
        for t in range(readerlen):
            action = int2action[action_t[t]]
            curtime = time_t[t]

            if curtime <= lasttime:
                log(f'Negative time difference at #{t} '
                    f'({curtime:.1f} <= {lasttime:.1f}). Skipping',
                    who='Reader', comm=comm, flush=True)
                continue

            if action is None:
                # This is just some dummy entry
                continue

            # Find the first 'init' entry
            if action == 'init':
                assert curtime >= lasttime, f'Times not matching {t}:{curtime} !>= {lasttime}'
                if initreadert is None:
                    initreadert = t
                continue

            assert action in ['kick', 'propagate']
            # NaN compares unequal to everything, so `curtime != np.nan` is always
            # True; use isnan to actually reject time entries without a time
            assert not np.isnan(curtime), f'Missing time for {action} entry #{t}'
            if nreadt % stridet == 0:
                # Save multiples of stridet
                # Note that 0 will always be saved. That contains the kick
                readtimes.append(curtime)
                timereadert.append(t)
            nreadt += 1
            if tmax > 0 and nreadt >= tmax:
                break
            lasttime = curtime

        assert initreadert is not None, f'No init entry in {mainreader.filename}'
        broadcast_obj = (initreadert, timereadert, np.array(readtimes))
    else:
        broadcast_obj = None

    # Distribute
    (initreadert, timereadert, time_t) = broadcast(broadcast_obj, comm=comm, root=0)
    initreader = mainreader[initreadert]
    timereader_t = [mainreader[t] for t in timereadert]

    log(f'Opened time-dependent wave functions file with {len(timereader_t)} times',
        who='Reader', rank=0, flush=True)

    return time_t, initreader, timereader_t
def read_C0S_parallel(ksdreader: Reader,
                      comm=None) -> NDArray[np.float64]:
    """ Compute C0S_sknM = sum_O C0_sknO S_skOM from the Kohn-Sham decomposition file.

    The contracted index O is split into contiguous chunks, one per rank. Each
    rank contracts its own chunk and the partial products are summed over comm.

    Parameters
    ----------
    ksdreader
        Reader of the ksd file, providing the arrays ``C0_unM`` and ``S_uMM``.
    comm
        MPI communicator. Defaults to ``world``.

    Returns
    -------
    The array C0S_sknM, summed over all ranks of comm.
    """
    if comm is None:
        comm = world

    coef_proxy = ksdreader.proxy('C0_unM', 0)
    overlap_proxy = ksdreader.proxy('S_uMM', 0)
    nbasis = coef_proxy.shape[3]

    # Ceiling division: size of the chunk of the contracted index owned by each rank
    chunk = -(-nbasis // comm.size)
    first = chunk * comm.rank
    owned = slice(first, first + chunk)

    coef_part = coef_proxy[:][..., owned]
    overlap_part = overlap_proxy[:][..., owned, :]

    partial_sknM = np.einsum('sknO,skOM->sknM', coef_part, overlap_part, optimize=True)

    # Sum the partial contractions over the ranks (in place)
    comm.sum(partial_sknM)
    return partial_sknM
def calculate_rho0_parallel(f_skn: NDArray[np.float64],
                            C0_sknM: NDArray[np.float64],
                            comm=None) -> NDArray[np.float64]:
    """ Compute rho0_skMO = sum_n f_skn C0_sknM C0_sknO.

    The sum over states n is split into contiguous chunks, one per rank. Each rank
    sums its own chunk and the partial results are summed over comm.

    Parameters
    ----------
    f_skn
        Occupation numbers.
    C0_sknM
        Real LCAO coefficients.
    comm
        MPI communicator. Defaults to ``world``.

    Returns
    -------
    The array rho0_skMM, summed over all ranks of comm.
    """
    if comm is None:
        comm = world

    nstates = f_skn.shape[2]

    # Ceiling division: number of states owned by each rank
    chunk = -(-nstates // comm.size)
    first = chunk * comm.rank
    owned = slice(first, first + chunk)

    occ_part = f_skn[:][..., owned]
    coef_part = C0_sknM[:][..., owned, :]

    partial_skMM = np.einsum('skn,sknM,sknO->skMO', occ_part, coef_part, coef_part, optimize=True)

    # Sum the partial density matrices over the ranks (in place)
    comm.sum(partial_skMM)
    return partial_skMM
def proxy_coefficients(reader, *indices):
    """ Return a proxy to the wave function coefficients, with the correct units.

    The proxy's ``scale`` is set to ``Bohr ** 1.5``, so values read through it
    are multiplied by that factor.
    """
    coef_proxy = reader.wave_functions.proxy('coefficients', *indices)
    coef_proxy.scale = Bohr ** 1.5
    return coef_proxy
def proxy_occupations(reader, *indices):
    """ Return a proxy to the wave function occupations, with the correct scale.

    The proxy's ``scale`` is 2 divided by the length of the first axis of the
    stored occupations.
    """
    wave_functions = reader.wave_functions
    occ_proxy = wave_functions.proxy('occupations', *indices)
    first_axis_len = wave_functions.occupations.shape[0]
    occ_proxy.scale = 2 / first_axis_len
    return occ_proxy
def proxy_C_nM(reader,
               *indices):
    """ Read wave function coefficients as a two-dimensional array, with the correct units.

    The last two indices select the states (n) and the basis functions (M). All
    preceding indices are passed on to proxy_coefficients. If n is a single
    index rather than a slice, the result is promoted to two dimensions.
    """
    leading = indices[:-2]
    state_index = indices[-2]
    basis_index = indices[-1]

    coef_proxy = proxy_coefficients(reader, *leading)
    if not isinstance(state_index, slice):
        # Single state: proxy it, read the requested basis functions, make it 2D
        row = coef_proxy.proxy(state_index)[basis_index]
        result = np.atleast_2d(row)
    else:
        result = coef_proxy[state_index][:, basis_index]
    assert result.ndim == 2
    return result