Coverage for rhodent/density_matrices/base.py: 68%
148 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3from abc import ABC, abstractmethod
4from typing import Generic, Generator, NamedTuple, TypeVar
6from gpaw.lcaotddft.ksdecomposition import KohnShamDecomposition
8from .density_matrix import DensityMatrix
9from ..utils import Logger, add_fake_kpts, two_communicators
10from ..utils.memory import HasMemoryEstimate, MemoryEstimate
11from ..typing import Communicator
class WorkMetadata(NamedTuple):
    """ Metadata describing one piece of work on a collection of density matrices.

    Concrete subclasses provide :attr:`global_indices` and :attr:`desc`.
    """
    density_matrices: BaseDensityMatrices

    @property
    def global_indices(self) -> tuple[int, ...]:
        """ Unique index for this work.

        Not implemented here; subclasses must override this property.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def desc(self) -> str:
        """ Description of this work (used in log and assertion messages). """
        raise NotImplementedError

    def _label(self) -> str:
        # Class name immediately followed by the tuple of global indices,
        # e.g. ``SomeWork(0, 3)``.
        cls_name = type(self).__name__
        return f'{cls_name}{self.global_indices}'

    def __str__(self) -> str:
        return self._label()

    def __repr__(self) -> str:
        return self._label()
# Type variable for the work-metadata type used by a BaseDensityMatrices
# subclass; bound to WorkMetadata, so it must be WorkMetadata or a subclass.
WorkMetadataT = TypeVar('WorkMetadataT', bound=WorkMetadata)
class BaseDensityMatrices(HasMemoryEstimate, ABC, Generic[WorkMetadataT]):
    """
    Collection of density matrices in the Kohn-Sham basis for different times
    or frequencies, possibly after convolution with various pulses.

    Plain density matrices and/or derivatives thereof may be represented.

    Parameters
    ----------
    ksd
        KohnShamDecomposition object or file name.
    real
        Calculate the real part of density matrices.
    imag
        Calculate the imaginary part of density matrices.
    calc_size
        Size of the calculation communicator.
    log
        Logger object. If not given, a new :class:`Logger` is created.
    """

    # These annotations must come *after* the docstring: a string literal that
    # follows any other statement in the class body is not stored in __doc__.
    _log: Logger
    _ksd: KohnShamDecomposition

    def __init__(self,
                 ksd: KohnShamDecomposition | str,
                 real: bool = True,
                 imag: bool = True,
                 calc_size: int = 1,
                 log: Logger | None = None):
        assert real or imag, 'At least one of real and imag must be True'
        # Parts of the density matrices to compute, in the order 'Re', 'Im'
        self._reim_r: list[str] = []
        if real:
            self._reim_r.append('Re')
        if imag:
            self._reim_r.append('Im')

        self._log = Logger() if log is None else log

        # Loop communicator (distributes work items) and calculation
        # communicator of size calc_size.
        # NOTE(review): the meaning of the -1 argument is defined by
        # two_communicators (presumably "all remaining ranks") - confirm.
        self._loop_comm, self._calc_comm = two_communicators(-1, calc_size)
        if isinstance(ksd, KohnShamDecomposition):
            self._ksd = ksd
        else:
            self._ksd = KohnShamDecomposition(filename=ksd)
            add_fake_kpts(self._ksd)

        # Do a quick sanity check at runtime
        self._runtime_verify_work_loop()

    @abstractmethod
    def __str__(self) -> str:
        raise NotImplementedError

    def get_memory_estimate(self) -> MemoryEstimate:
        """ Memory estimate; the base class does not know its memory usage. """
        return MemoryEstimate(comment='Unknown')

    def parallel_prepare(self):
        """ Read everything necessary synchronously on all ranks.

        Does nothing in the base class; subclasses may override it.
        """

    @abstractmethod
    def __iter__(self) -> Generator[tuple[WorkMetadataT, DensityMatrix], None, None]:
        """ Obtain density matrices for various times, pulses or frequencies.

        Yields
        ------
        Tuple (work, dm) on the root rank of the calculation communicator:

        work
            An object representing the metadata (time, frequency or pulse) for the work done.
        dm
            Density matrix for this time, frequency or pulse.
        """
        raise NotImplementedError

    def iread_gather_on_root(self) -> Generator[tuple[WorkMetadataT, DensityMatrix], None, None]:
        """ Obtain density matrices for various times, pulses or frequencies and gather to the root rank.

        Yields
        ------
        Tuple (work, dm) on the root rank of the loop and calculation communicators:

        work
            An object representing the metadata (time, frequency or pulse) for the work done.
        dm
            Density matrix for this time, frequency or pulse.
        """
        work: WorkMetadataT | None
        gen = iter(self)

        # Loop over the work to be done, and the ranks that are supposed to do it
        self.parallel_prepare()
        for rank, work in self.global_work_loop_with_idle():
            if work is None:
                # Rank rank will not do any work at this point
                continue

            if rank == self.loop_comm.rank:
                # This rank owns this piece of work: take it from our own generator
                mywork, mydm = next(gen)
                if self.calc_comm.rank == 0:
                    self.log(f'Read {mywork.desc} in {self.log.elapsed("read"):.1f}s',
                             who='Response', if_elapsed=5)
                # The local generator must produce work in the same order as the plan
                assert work.global_indices == mywork.global_indices, f'{work.desc} != {mywork.desc}'

            # Broadcast from the owning loop rank (root=rank) over the loop communicator
            dm = DensityMatrix.broadcast(
                mydm if self.loop_comm.rank == rank else None,
                ksd=self.ksd,
                root=rank, dm_comm=self.calc_comm, comm=self.loop_comm)

            yield work, dm

        # The local generator must have been consumed completely
        _exhausted = object()
        rem = next(gen, _exhausted)
        assert rem is _exhausted, rem

    @property
    def ksd(self) -> KohnShamDecomposition:
        """ Kohn-Sham decomposition object. """
        return self._ksd

    @property
    def log(self) -> Logger:
        """ Logger. """
        return self._log

    def log_parallel(self, *args, **kwargs) -> Logger:
        """ Log message with communicator information. """
        return self._log(*args, **kwargs, comm=self.loop_comm, who='Response')

    @property
    def reim(self) -> list[str]:
        """ List of strings ``'Re'`` and ``'Im'``, depending on whether real, and/or imaginary parts are computed. """
        return self._reim_r

    @abstractmethod
    def work_loop(self,
                  rank: int) -> Generator[WorkMetadataT | None, None, None]:
        """ The work to be done by a certain rank of the loop communicator.

        Parameters
        ----------
        rank
            Rank of the loop communicator.

        Yields
        ------
        Objects representing the time, frequency or pulse to be computed by rank ``rank``.
        None is yielded when `rank` does not do any work while other ranks are doing work.
        """
        raise NotImplementedError

    def _runtime_verify_work_loop(self):
        """ Verify that the description of work to be done is consistent across ranks.

        Checks that every rank's work loop has the same length (idle slots
        included) and that no two ranks share the same global indices.
        """
        local_work_r = [list(self.work_loop(rank)) for rank in range(self.loop_comm.size)]
        work_lengths = [len(local_work) for local_work in local_work_r]
        assert all(work_length == work_lengths[0] for work_length in work_lengths), \
            f'The work loop has different length across the different ranks. {work_lengths}'
        concat_work_list = [work.global_indices for local_work in local_work_r for work in local_work
                            if work is not None]
        assert len(concat_work_list) == len(set(concat_work_list)), \
            f'Different ranks do duplicate work {concat_work_list}'

    @property
    def local_work_plan(self) -> tuple[WorkMetadataT, ...]:
        """ The work to be done by this rank of the loop communicator.

        Returns
        -------
        Tuple of objects representing the time, frequency or pulse to be computed by this rank.
        Idle slots (None) are excluded.
        """
        return tuple(work for work in self.work_loop(self.loop_comm.rank)
                     if work is not None)

    @property
    def local_work_plan_with_idle(self) -> tuple[WorkMetadataT | None, ...]:
        """ The work to be done by this rank of the loop communicator.

        This includes None values for when this rank does not do any work,
        in order to synchronize the execution.

        Returns
        -------
        Tuple of objects representing the time, frequency or pulse to be computed by this rank.
        None marks a slot where this rank is idle while other ranks are doing work.
        """
        return tuple(self.work_loop(self.loop_comm.rank))

    def global_work_loop_with_idle(self) -> Generator[tuple[int, WorkMetadataT | None], None, None]:
        """ The work to be done by all ranks of the loop communicator.

        The work loops of all ranks are interleaved round-robin, rank 0 first.
        None values are included for when ranks do not do any work, in order
        to synchronize the execution.

        Yields
        ------
        Tuples ``(rank, work)``, where ``work`` is the next item of the work loop
        of ``rank`` (or None). See `local_work_plan_with_idle`.

        Raises
        ------
        RuntimeError
            If the ranks' work loops have different lengths.
        """
        work_loop_r = [self.work_loop(rank) for rank in range(self.loop_comm.size)]
        _exhausted = object()
        while True:
            for rank, work_loop in enumerate(work_loop_r):
                work = next(work_loop, _exhausted)
                if work is _exhausted:
                    if rank == 0:
                        # No more work to do; all other ranks must be finished too
                        leftover = [other for other, other_loop in enumerate(work_loop_r[1:], start=1)
                                    if next(other_loop, _exhausted) is not _exhausted]
                        if leftover:
                            raise RuntimeError('Ranks have different amount of work. '
                                               f'Ranks {leftover} have work left after rank 0 finished')
                        return
                    raise RuntimeError(f'Ranks have different amount of work. Exited on {rank}')
                yield rank, work

    def global_work_loop(self) -> Generator[tuple[int, WorkMetadataT], None, None]:
        """ The work to be done by all ranks of the loop communicator.

        Same as :func:`global_work_loop_with_idle`, but idle slots are skipped.

        Yields
        ------
        Tuples ``(rank, work)``, where ``work`` is never None. See :func:`local_work_plan`.
        """
        for rank, work in self.global_work_loop_with_idle():
            if work is None:
                continue
            yield rank, work

    @property
    def localn(self) -> int:
        """ Total number of density matrices this rank will work with. """
        return len(self.local_work_plan)

    @property
    def globaln(self) -> int:
        """ Total number of density matrices to work with across all ranks. """
        return sum(1 for rank in range(self.loop_comm.size)
                   for work in self.work_loop(rank)
                   if work is not None)

    @property
    def calc_comm(self) -> Communicator:
        """ Calculation communicator.

        Each rank of this communicator calculates the observables corresponding to
        a part (in electron-hole space) of the density matrices. """
        return self._calc_comm

    @calc_comm.setter
    def calc_comm(self, value: Communicator | None):
        from gpaw.mpi import world
        if value is None:
            # Fall back to the world communicator
            self.calc_comm = world
            return

        assert hasattr(value, 'rank')
        assert hasattr(value, 'size')
        self._calc_comm = value

    @property
    def loop_comm(self) -> Communicator:
        """ Loop communicator.

        Each rank of this communicator calculates the density matrices corresponding to
        different times, frequencies or after convolution with a different pulse. """
        return self._loop_comm

    @loop_comm.setter
    def loop_comm(self, value: Communicator | None):
        from gpaw.mpi import world
        if value is None:
            # Fall back to the world communicator
            self.loop_comm = world
            return

        assert hasattr(value, 'rank')
        assert hasattr(value, 'size')
        self._loop_comm = value

    @abstractmethod
    def write_to_disk(self,
                      fmt: str):
        """ Calculate the density matrices and save to disk.

        Parameters
        ----------
        fmt
            Formatting string.
        """
        raise NotImplementedError