Coverage for rhodent/density_matrices/distributed/base.py: 84%
223 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3from abc import ABC, abstractmethod
4from typing import Generator, Generic, NamedTuple, Iterable
5from itertools import product, zip_longest
7import numpy as np
9from gpaw.mpi import world
10from gpaw.lcaotddft.ksdecomposition import KohnShamDecomposition
12from ..buffer import DensityMatrixBuffer
13from ..readers.gpaw import KohnShamRhoWfsReader
14from ...utils import DTypeT, Logger, concatenate_indices, env
15from ...typing import Communicator
16from ...utils.memory import HasMemoryEstimate
class BaseDistributor(HasMemoryEstimate, ABC, Generic[DTypeT]):

    """ Distribute density matrices over time, frequency or other dimensions across MPI ranks.

    The density matrix is divided into chunks as described by a
    :class:`RhoParameters` object. In each iteration of
    :meth:`work_loop_by_ranks` every rank handles at most one chunk.

    Parameters
    ----------
    rho_reader
        Reader of the time-dependent wave functions from which the density
        matrices are obtained.
    parameters
        Description of the density matrix indices and chunks. If ``None``, it is
        set up with :meth:`RhoParameters.from_ksd` from the Kohn-Sham
        decomposition of :attr:`rho_reader` and the communicator.
    comm
        MPI communicator. Defaults to ``world``.
    """

    def __init__(self,
                 rho_reader: KohnShamRhoWfsReader,
                 parameters: RhoParameters | None = None,
                 comm: Communicator | None = None):
        self.rho_wfs_reader = rho_reader

        self._comm = world if comm is None else comm
        if parameters is None:
            # self.ksd is looked up through self.rho_wfs_reader, which is set above
            parameters = RhoParameters.from_ksd(self.ksd, self.comm)
        self._parameters = parameters

        # Derivative orders included in the buffers (see describe_derivatives)
        self.derivative_order_s = [0]

    @property
    @abstractmethod
    def dtype(self) -> np.dtype[DTypeT]:
        """ Dtype of buffers. """
        raise NotImplementedError

    @property
    @abstractmethod
    def xshape(self) -> tuple[int, ...]:
        """ Shape of x-dimension in buffers. """
        raise NotImplementedError

    @property
    def ksd(self) -> KohnShamDecomposition:
        """ Kohn-Sham decomposition object. """
        return self.rho_wfs_reader.ksd

    @property
    def comm(self) -> Communicator:
        """ MPI communicator. """
        return self._comm

    @property
    def yield_re(self) -> bool:
        """ Whether real part of density matrices is calculated. """
        return self.rho_wfs_reader.yield_re

    @property
    def yield_im(self) -> bool:
        """ Whether imaginary part of density matrices is calculated. """
        return self.rho_wfs_reader.yield_im

    @property
    def log(self) -> Logger:
        """ Logger object. """
        return self.rho_wfs_reader.log

    @abstractmethod
    def __iter__(self) -> Generator[DensityMatrixBuffer, None, None]:
        """ Yield density matrices in parts. Different data is
        yielded on different ranks

        Yields
        ------
        Part of the density matrix
        """
        raise NotImplementedError

    def work_loop(self,
                  rank: int) -> Generator[RhoIndices | None, None, None]:
        """ Like :meth:`work_loop_by_ranks` but for one particular rank.

        Parameters
        ----------
        rank
            Rank in :attr:`comm` whose work is yielded.

        Yields
        ------
        Indices of the chunk handled by ``rank`` in each iteration, or ``None``
        if that rank has nothing to do in the iteration.
        """
        for chunks_r in self.work_loop_by_ranks():
            yield chunks_r[rank]

    @property
    def niters(self) -> int:
        """ Number of iterations needed to read all chunks. """
        return sum(1 for _ in self.work_loop_by_ranks())

    @property
    def maxntimes(self) -> int:
        """ Maximum number of ranks participating in reading of times.

        Counted in the first iteration of the wave functions reader's work loop
        (assumed to be the fullest one).

        Raises
        ------
        RuntimeError
            If the work loop of the wave functions reader is empty.
        """
        for t_r in self.rho_wfs_reader.work_loop_by_ranks():
            return sum(1 for t in t_r if t is not None)

        raise RuntimeError('Work loop of the wave functions reader is empty')

    @property
    def maxnchunks(self) -> int:
        """ Maximum number of ranks participating in reading of chunks.

        Counted in the first iteration of :meth:`work_loop_by_ranks`; only the
        last iteration of that loop can leave ranks idle.

        Raises
        ------
        RuntimeError
            If there are no chunks to read.
        """
        for chunks_r in self.work_loop_by_ranks():
            return sum(1 for chunk in chunks_r if chunk is not None)

        raise RuntimeError('Work loop over density matrix chunks is empty')

    def describe_reim(self) -> str:
        """ Human-readable description of which parts (real/imaginary) are yielded. """
        if self.yield_re and self.yield_im:
            return 'Real and imaginary parts'
        elif self.yield_re:
            return 'Real part'
        else:
            return 'Imaginary part'

    def describe_derivatives(self) -> str:
        """ Human-readable list of the derivative orders that are computed. """
        return 'derivative orders: ' + ', '.join([str(d) for d in self.derivative_order_s])

    def work_loop_by_ranks(self) -> Generator[list[RhoIndices | None], None, None]:
        """ Yield slice objects corresponding to the chunk of the density matrix
        that is gathered on each rank.

        New indices are yielded until the entire density matrix is processed
        (across all ranks). Chunks are handed out in order, starting at rank 0,
        so only the last iteration can contain ``None`` entries.

        Yields
        ------
        List of slice objects corresponding to part of the density matrix
        yielded on each ranks. None in place of the slice object if there is
        nothing yielded for that rank.
        """
        gen = self._parameters.iterate_indices()

        while True:
            # range comes first in zip so that no element of gen is consumed
            # and discarded once all ranks have been assigned a chunk
            chunks_r: list[RhoIndices | None] = [indices for _, indices
                                                 in zip(range(self.comm.size), gen)]

            remaining = self.comm.size - len(chunks_r)
            if remaining == 0:
                yield chunks_r
            elif remaining == self.comm.size:
                # There is nothing left to do for any rank
                break
            else:
                # Append Nones for the ranks that are not doing anything
                chunks_r += remaining * [None]
                yield chunks_r
                break

    def gather_on_root(self) -> Generator[DensityMatrixBuffer | None, None, None]:
        """ Send the chunks of the density matrix to the root rank, one at a time.

        For each iteration of :meth:`work_loop_by_ranks`, the root rank first
        yields a copy of its own chunk and then a copy of the chunk received
        from every other rank that has work in that iteration. All other ranks
        yield ``None`` the same number of times, sending their own chunk to the
        root rank when it is their turn. Must therefore be iterated on all ranks.

        Yields
        ------
        Copy of a chunk of the density matrix on the root rank, ``None`` on
        other ranks.
        """
        self.rho_wfs_reader.C0S_sknM  # Make sure to read this synchronously

        for indices_r, dm_buffer in zip_longest(self.work_loop_by_ranks(),
                                                self, fillvalue=None):
            assert indices_r is not None, 'Work loop shorter than work'

            # Yield root's own work
            if self.comm.rank == 0:
                assert indices_r[0] is not None
                assert dm_buffer is not None
                dm_buffer.ensure_contiguous_buffers()

                yield dm_buffer.copy()
            else:
                yield None

            # Yield the work of non-root
            for recvrank, recvindices in enumerate(indices_r[1:], start=1):
                if recvindices is None:
                    # No work on this recvrank
                    continue

                if self.comm.rank == 0:
                    # Receive work; root's own buffer (already yielded as a
                    # copy) is reused as the receive buffer
                    assert dm_buffer is not None
                    dm_buffer.recv_arrays(self.comm, recvrank, log=self.log)
                    yield dm_buffer.copy()
                else:
                    # Send work to root if there is any
                    if self.comm.rank == recvrank:
                        assert dm_buffer is not None
                        dm_buffer.send_arrays(self.comm, 0, log=self.log)
                    yield None

    def collect_on_root(self) -> DensityMatrixBuffer | None:
        """ Assemble the full density matrix on the root rank.

        Must be called on all ranks, as it iterates :meth:`gather_on_root`.

        Returns
        -------
        On the root rank, a buffer covering the full ``(n1size, n2size)``
        density matrix (times :attr:`xshape`); ``None`` on other ranks.
        """
        gen = self._parameters.iterate_indices()

        nnshape = (self._parameters.n1size, self._parameters.n2size)
        full_dm = DensityMatrixBuffer(nnshape, self.xshape, dtype=self.dtype)
        full_dm.zero_buffers(real=self.yield_re, imag=self.yield_im, derivative_order_s=self.derivative_order_s)

        for indices, dm_buffer in zip_longest(gen,
                                              self.gather_on_root(), fillvalue=None):
            if self.comm.rank != 0:
                continue

            assert indices is not None, 'Iterators must be same length'
            assert dm_buffer is not None, 'Iterators must be same length'

            s, k, n1, n2 = indices
            assert s == 0
            assert k == 0

            for partial_data, full_data in zip(dm_buffer._iter_buffers(), full_dm._iter_buffers()):
                # The last chunk in each dimension may extend past the end of
                # the full matrix; only copy the part that fits
                _nn1, _nn2 = full_data[n1, n2].shape[:2]
                # Numpy struggles with the static type below
                full_data[n1, n2, :] += partial_data[:_nn1, :_nn2]  # type: ignore
            self.log(f'Collected on root: density matrix slice [s={s}, k={k}, n1={n1}, n2={n2}].',
                     flush=True, who='Response')

        if self.comm.rank != 0:
            return None

        return full_dm

    @classmethod
    @abstractmethod
    def from_reader(cls,
                    rho_nn_reader: KohnShamRhoWfsReader,
                    parameters: RhoParameters,
                    **kwargs) -> BaseDistributor:
        """ Set up this class from a density matrix reader and parameters object

        """
        raise NotImplementedError

    @classmethod
    def from_parameters(cls,
                        wfs_fname: str,
                        ksd: KohnShamDecomposition | str,
                        comm=world,
                        yield_re: bool = True,
                        yield_im: bool = True,
                        stridet: int = 1,
                        log: Logger | None = None,
                        verbose: bool = False,
                        **kwargs):
        """ Set up this class, trying to enforce memory limit.

        The density matrix is split into progressively more chunks (1, 2, ...
        times the number of ranks) until the memory estimate is below the
        limit. The search stops early once the last three attempts improved the
        estimate by less than 2 % on average; the last attempt is then returned.

        Parameters
        ----------
        wfs_fname
            File name of the time-dependent wave functions file.
        ksd
            KohnShamDecomposition object or file name to the ksd file.
        comm
            MPI communicator.
        yield_re
            Whether to read and yield the real part of wave functions/density matrices.
        yield_im
            Whether to read and yield the imaginary part of wave functions/density matrices.
        stridet
            Skip this many steps when reading the time-dependent wave functions file.
        log
            Logger object.
        verbose
            Be verbose in the attempts to satisfy memory requirement.
        kwargs
            Options passed through to :func:`from_reader`.
        """
        # Set up the time-dependent wave functions reader which is always needed
        rho_reader = KohnShamRhoWfsReader(
            wfs_fname=wfs_fname, ksd=ksd, comm=comm,
            yield_re=yield_re, yield_im=yield_im, log=log, stridet=stridet)

        log = rho_reader.log

        # Get the target memory limit. Multiplying by to_MiB converts bytes to
        # MiB, so dividing by it converts the limit (presumably given in MiB by
        # env) to bytes, the unit used by the memory estimates below
        to_MiB = 1024 ** -2
        mem_limit = env.get_response_max_mem(comm.size) / to_MiB
        log('Attempting to set up response calculation with memory limit of '
            f'{mem_limit * to_MiB:.1f} MiB across all ranks.', who='Setup', rank=0)

        totals = []
        for iterations in range(1, 100):
            # Try setting up the distributor such that `iterations` iterations are
            # needed to process all chunks
            parameters = RhoParameters.from_ksd(rho_reader.ksd, comm, chunk_iterations=iterations)
            distributor = cls.from_reader(rho_reader, parameters, **kwargs)
            total = distributor.get_memory_estimate().grand_total
            totals.append(total)
            compare = totals[:-5:-1]  # Last 4 totals in reverse order
            # Ratio of each estimate to the one before it (newest first)
            last_changes = [tot_new / tot_old for tot_new, tot_old in zip(compare, compare[1:])]
            if len(last_changes) == 0:
                improvement = ''
            else:
                changes_str = ', '.join([f'{(1 - change)*100:.1f}%' for change in last_changes])
                improvement = f'Last improvements {changes_str}'

            if verbose:
                log(f'Trying splitting in {distributor.niters:3} chunks -- estimate {total * to_MiB:.1f} MiB. '
                    f'{improvement}', who='Setup', rank=0)
            if total < mem_limit:
                log(f'Found suitable set of parameters limiting the memory to {total * to_MiB:.1f} MiB.',
                    who='Setup', rank=0)
                return distributor
            if len(last_changes) == 3 and sum(last_changes) / 3 > 0.98:
                # Further splitting hardly reduces the memory; give up
                break

        # `distributor` and `total` belong to the last attempt, which is the
        # configuration we settle on; no need to set it up again
        log(f'Cannot satisfy memory limit. Estimate is {total * to_MiB:.1f} MiB.',
            who='Setup', rank=0)

        return distributor
class RhoIndices(NamedTuple):

    """ Indices selecting one chunk of the density matrix. """

    s: int  # Spin index
    k: int  # k-point index
    n1: slice  # Rows of the chunk (first dimension)
    n2: slice  # Columns of the chunk (second dimension)

    @staticmethod
    def concatenate_indices(indices_list: Iterable[RhoIndices],
                            ) -> tuple[RhoIndices, list[RhoIndices]]:
        """ Concatenate the row and column slices of several chunks.

        All chunks must share the same spin and k-point index.

        Parameters
        ----------
        indices_list
            Indices of the chunks; at least one is required.

        Returns
        -------
        The concatenated indices, and the list of the input indices as reduced
        by the module-level ``concatenate_indices`` utility, all carrying the
        common spin and k-point index.
        """
        indices_list = list(indices_list)
        assert len(indices_list) > 0
        head = indices_list[0]
        s, k = head.s, head.k
        assert all(idx.s == s for idx in indices_list), f'All s must be identical {indices_list}'
        assert all(idx.k == k for idx in indices_list), f'All k must be identical {indices_list}'

        nn_pairs = [(idx.n1, idx.n2) for idx in indices_list]
        concat_nn, reduced_nn_list = concatenate_indices(nn_pairs)

        merged = RhoIndices(s, k, *concat_nn)
        return merged, [RhoIndices(s, k, *nn) for nn in reduced_nn_list]
class _RhoParametersFields(NamedTuple):

    """ Fields of :class:`RhoParameters`.

    Kept as a separate base because a :class:`typing.NamedTuple` class body may
    not define ``__new__``; :class:`RhoParameters` subclasses this tuple so that
    it can clamp the strides on construction.
    """

    ns: int
    nk: int
    n1min: int
    n1max: int
    n2min: int
    n2max: int
    striden1: int = 4
    striden2: int = 4


class RhoParameters(_RhoParametersFields):

    """ Utility class to describe density matrix indices.

    Parameters
    ----------
    ns
        Number of spins.
    nk
        Number of kpoints.
    n1min
        Smallest index of row to read.
    n1max
        Largest index of row to read (inclusive).
    n2min
        Smallest index of column to read.
    n2max
        Largest index of column to read (inclusive).
    striden1
        Stride for reading rows. Each chunk will be this size in the first dimension.
        Clamped to the number of rows on construction.
    striden2
        Stride for reading columns. Each chunk will be this size in the second dimension.
        Clamped to the number of columns on construction.
    """

    __slots__ = ()

    def __new__(cls,
                ns: int,
                nk: int,
                n1min: int,
                n1max: int,
                n2min: int,
                n2max: int,
                striden1: int = 4,
                striden2: int = 4) -> RhoParameters:
        # A chunk is never larger than the density matrix itself.
        # NOTE: _make/_replace construct the tuple directly and bypass this
        n1size = n1max + 1 - n1min
        n2size = n2max + 1 - n2min
        return super().__new__(cls, ns, nk, n1min, n1max, n2min, n2max,
                               min(striden1, n1size), min(striden2, n2size))

    @property
    def full_nnshape(self) -> tuple[int, int]:
        """ Shape of the full density matrix to be read. """
        return (self.n1size, self.n2size)

    @property
    def nnshape(self) -> tuple[int, int]:
        """ Shape of each density matrix chunk. """
        return (self.striden1, self.striden2)

    @property
    def n1size(self) -> int:
        """ Size of full density matrix in the first dimension. """
        return self.n1max + 1 - self.n1min

    @property
    def n2size(self) -> int:
        """ Size of full density matrix in the second dimension. """
        return self.n2max + 1 - self.n2min

    def iterate_indices(self) -> Generator[RhoIndices, None, None]:
        """ Iteratively yield indices slicing chunks of the density matrix.

        Chunks are yielded for every spin and k-point, row blocks in the outer
        and column blocks in the inner loop. The row/column slices are offsets
        relative to (n1min, n2min); the last chunk in each dimension may extend
        past the end of the density matrix.
        """
        for s, k, n1, n2 in product(range(self.ns), range(self.nk),
                                    range(0, self.n1size, self.striden1),
                                    range(0, self.n2size, self.striden2)):
            yield RhoIndices(s=s, k=k,
                             n1=slice(n1, n1 + self.striden1),
                             n2=slice(n2, n2 + self.striden2))

    @classmethod
    def from_ksd(cls,
                 ksd: KohnShamDecomposition,
                 comm: Communicator | None = None,
                 only_ia: bool = True,
                 chunk_iterations: int = 1,
                 **kwargs) -> RhoParameters:
        """ Initialize from KohnShamDecomposition.

        Parameters
        ----------
        ksd
            KohnShamDecomposition.
        comm
            MPI Communicator.
        only_ia
            ``True`` if the parameters should be set up such that
            the electron-hole part of the density matrix is read,
            otherwise full density matrix.
        chunk_iterations
            Attempt to set up the strides so that the total number of
            chunks is as close as possible but not more than the number
            of MPI ranks times :attr:`chunk_iterations`.
        kwargs
            Options passed through to the constructor.
        """
        if comm is None:
            comm = world

        # Number of spins, kpoints and states
        ns, nk, nn, _ = ksd.reader.proxy('C0_unM', 0).shape

        params = dict()
        if only_ia:
            # Dimensions of electron-hole part
            imin, imax, amin, amax = [int(i) for i in ksd.ialims()]

            params['n1min'], params['n2min'] = imin, amin
            params['n1max'], params['n2max'] = imax, amax
        else:
            params['n1min'], params['n2min'] = 0, 0
            params['n1max'], params['n2max'] = nn - 1, nn - 1

        # Set up a helper object to get the sizes
        helper = cls(ns, nk, **params)

        # We want this many chunks in total
        target_nchunks = chunk_iterations * comm.size
        ar = helper.n2size / helper.n1size  # Aspect ratio of density matrix

        # Split roughly proportionally to the aspect ratio, so chunks are
        # approximately square
        nsplits1 = max(int(np.floor(np.sqrt(target_nchunks / ar))), 1)
        nsplits2 = (target_nchunks + nsplits1 - 1) // nsplits1

        # Defaults (ceiling division)
        params['striden1'] = (helper.n1size + nsplits1 - 1) // nsplits1
        params['striden2'] = (helper.n2size + nsplits2 - 1) // nsplits2

        # Overwrite the default options in params with explicitly set options
        params.update(**kwargs)

        return cls(ns, nk, **params)