Coverage for rhodent/writers/writer.py: 96%
185 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3from abc import ABC, abstractmethod
4from typing import Any, Generic, TypeVar
5import numpy as np
6from numpy.typing import NDArray
8from gpaw.mpi import world
9from gpaw.io import Writer as GPAWWriter
11from ..density_matrices.base import WorkMetadata, WorkMetadataT, BaseDensityMatrices
12from ..density_matrices.frequency import FrequencyDensityMatrices, FrequencyDensityMatrixMetadata
13from ..density_matrices.time import ConvolutionDensityMatrices, ConvolutionDensityMatrixMetadata
14from ..calculators.base import BaseObservableCalculator
15from ..voronoi import VoronoiWeights, EmptyVoronoiWeights, atom_projections_to_numpy
16from ..utils import Result, ResultKeys
class ResultsCollector(ABC, Generic[WorkMetadataT]):

    """ Utility class to collect result arrays for different
    times, pulses, or frequencies.

    Parameters
    ----------
    calc
        Calculator.
    calc_kwargs
        Keyword arguments passed to the icalculate function.
    resultkeys
        Result keys to be collected.
    additional_suffix
        String prepended to the suffix of each key (see :meth:`format_key`).
    additional_dimension
        Shape of additional dimension(s) due to the different times, frequencies, etc.
        It is prepended to the shape of each collected array.
    exclude
        Keys that are excluded from collection. Keys that are not among
        ``resultkeys`` are ignored.
    """

    def __init__(self,
                 calc: BaseObservableCalculator,
                 calc_kwargs: dict[str, Any],
                 resultkeys: ResultKeys,
                 additional_suffix: str,
                 additional_dimension: tuple[int, ...],
                 exclude: list[str] | None = None):
        if exclude is None:
            exclude = []

        self.calc = calc
        self.calc_kwargs = calc_kwargs
        # Copy before removing excluded keys, presumably so that the
        # caller's ResultKeys object is left untouched
        self.resultkeys = resultkeys.__copy__()
        for key in exclude:
            if key in self.resultkeys:
                self.resultkeys.remove(key)
        self.additional_dimension = additional_dimension
        self.additional_suffix = additional_suffix

        # Create the new result keys for the aggregated data
        self.collect_resultkeys = ResultKeys()
        for key, shape, dtype in self.resultkeys:
            newkey = self.format_key(key)
            self.collect_resultkeys.add_key(newkey, additional_dimension + shape, dtype)

        self.result = Result(mutable=True)

    def empty_results(self):
        """ Allocate zero-filled arrays for all collected keys (root rank only). """
        if world.rank == 0:
            self.result.create_all_zeros(self.collect_resultkeys)

    def finalize_results(self):
        """ Hook called after all results have been accumulated. Does nothing by default. """
        pass

    def format_key(self,
                   key: str) -> str:
        """ Add the new suffix to the key.

        For scalar results, ``_<additional_suffix>`` is appended to the key.
        Otherwise the key must contain an underscore, and the additional suffix
        is prepended to the part after the last underscore, e.g. ``abc_ia``
        becomes ``abc_tia`` for the suffix ``t``. An empty suffix leaves the
        key unchanged.

        Parameters
        ----------
        key
            Original result key.

        Returns
        -------
        New result key with the added suffix.
        """
        shape, _ = self.resultkeys[key]
        if len(shape) == 0:
            if not self.additional_suffix:
                # No suffix to add; avoid producing a key with a trailing underscore
                return key
            return f'{key}_{self.additional_suffix}'

        s = key.split('_')
        assert len(s) > 1, f'Result key {key!r} of non-scalar result has no suffix'
        s[-1] = self.additional_suffix + s[-1]
        return '_'.join(s)

    @abstractmethod
    def accumulate_results(self,
                           work: WorkMetadataT,
                           result: Result):
        """ Store the result of one piece of work in the collected arrays. """
        pass


# Type variable for the collector type held by a Writer
ResultsCollectorT = TypeVar('ResultsCollectorT', bound=ResultsCollector)
class TimeResultsCollector(ResultsCollector):

    """ Collect results after convolution with a single pulse, for different times.

    The letter t is prepended to the suffix of the result keys to indicate
    an additional dimension of time.

    Parameters
    ----------
    calc
        Calculator.
    calc_kwargs
        Keyword arguments passed to the icalculate function.
    exclude
        Keys that are excluded from collection.
    """

    def __init__(self,
                 calc: BaseObservableCalculator,
                 calc_kwargs: dict[str, Any],
                 exclude: list[str] | None = None):
        assert isinstance(calc.density_matrices, ConvolutionDensityMatrices)
        # Only one pulse is supported; the collected arrays have no pulse axis
        assert len(calc.density_matrices.pulses) == 1
        Nt = len(calc.times)

        resultkeys = calc.get_result_keys(**calc_kwargs)
        super().__init__(calc, calc_kwargs, resultkeys,
                         additional_suffix='t', additional_dimension=(Nt, ),
                         exclude=[] if exclude is None else exclude)

    def accumulate_results(self,
                           work: ConvolutionDensityMatrixMetadata,
                           result: Result):
        """ Store the result for one time in the collected arrays (root rank only). """
        assert isinstance(work, ConvolutionDensityMatrixMetadata)
        assert world.rank == 0

        for key, _, _ in self.resultkeys:
            newkey = self.format_key(key)
            # work.globalt indexes the added time axis
            self.result.set_to(newkey, work.globalt, result[key])
class TimeAverageResultsCollector(ResultsCollector):

    """ Collect results after convolution with a single pulse and average over times.

    No letter is prepended to the suffix of the result keys, since no
    additional dimension is added.

    Parameters
    ----------
    calc
        Calculator.
    calc_kwargs
        Keyword arguments passed to the icalculate function.
    exclude
        Keys that are excluded from collection.
    """

    def __init__(self,
                 calc: BaseObservableCalculator,
                 calc_kwargs: dict[str, Any],
                 exclude: list[str] | None = None):
        assert isinstance(calc.density_matrices, ConvolutionDensityMatrices)
        assert len(calc.density_matrices.pulses) == 1

        resultkeys = calc.get_result_keys(**calc_kwargs)
        super().__init__(calc, calc_kwargs, resultkeys,
                         additional_suffix='', additional_dimension=(),
                         exclude=[] if exclude is None else exclude)

    def accumulate_results(self,
                           work: ConvolutionDensityMatrixMetadata,
                           result: Result):
        """ Add the result for one time to the running sums (root rank only). """
        assert isinstance(work, ConvolutionDensityMatrixMetadata)
        assert world.rank == 0

        for key, _, _ in self.resultkeys:
            newkey = self.format_key(key)
            self.result.add_to(newkey, slice(None), result[key])

    def finalize_results(self):
        """ Turn the running sums into averages by dividing by the number of times. """
        if world.rank > 0:
            return

        # NOTE(review): divides by the number of times of the density matrices;
        # assumes each of them was accumulated exactly once — confirm
        nt = len(self.calc.density_matrices.times)
        for key, _, _ in self.collect_resultkeys:
            self.result[key] /= nt
class PulseConvolutionResultsCollector(ResultsCollector):

    """ Collect results after convolution with different pulses.

    The letters pt are prepended to the suffix of the result keys to indicate
    an additional dimension of pulse and time.

    Parameters
    ----------
    calc
        Calculator.
    calc_kwargs
        Keyword arguments passed to the icalculate function.
    exclude
        Keys that are excluded from collection.
    """

    def __init__(self,
                 calc: BaseObservableCalculator,
                 calc_kwargs: dict[str, Any],
                 exclude: list[str] | None = None):
        assert isinstance(calc.density_matrices, ConvolutionDensityMatrices)
        Np = len(calc.pulses)
        Nt = len(calc.times)

        resultkeys = calc.get_result_keys(**calc_kwargs)
        super().__init__(calc, calc_kwargs, resultkeys,
                         additional_suffix='pt', additional_dimension=(Np, Nt),
                         exclude=[] if exclude is None else exclude)

    def accumulate_results(self,
                           work: ConvolutionDensityMatrixMetadata,
                           result: Result):
        """ Store the result for one pulse and time in the collected arrays (root rank only). """
        assert isinstance(work, ConvolutionDensityMatrixMetadata)
        assert world.rank == 0

        for key, _, _ in self.resultkeys:
            newkey = self.format_key(key)
            # (work.globalp, work.globalt) index the added pulse and time axes
            self.result.set_to(newkey, (work.globalp, work.globalt), result[key])
class PulseConvolutionAverageResultsCollector(ResultsCollector):

    """ Collect results after convolution with different pulses, average over times.

    The letter p is prepended to the suffix of the result keys to indicate
    an additional dimension of pulse.

    Parameters
    ----------
    calc
        Calculator.
    calc_kwargs
        Keyword arguments passed to the icalculate function.
    exclude
        Keys that are excluded from collection.
    """

    def __init__(self,
                 calc: BaseObservableCalculator,
                 calc_kwargs: dict[str, Any],
                 exclude: list[str] | None = None):
        assert isinstance(calc.density_matrices, ConvolutionDensityMatrices)
        Np = len(calc.pulses)

        resultkeys = calc.get_result_keys(**calc_kwargs)
        super().__init__(calc, calc_kwargs, resultkeys,
                         additional_suffix='p', additional_dimension=(Np, ),
                         exclude=[] if exclude is None else exclude)

    def accumulate_results(self,
                           work: ConvolutionDensityMatrixMetadata,
                           result: Result):
        """ Add the result for one time to the running sum of its pulse (root rank only). """
        assert isinstance(work, ConvolutionDensityMatrixMetadata)
        assert world.rank == 0

        for key, _, _ in self.resultkeys:
            newkey = self.format_key(key)
            self.result.add_to(newkey, work.globalp, result[key])

    def finalize_results(self):
        """ Turn the running sums into averages by dividing by the number of times. """
        if world.rank > 0:
            return

        # NOTE(review): divides by the number of times of the density matrices;
        # assumes each of them was accumulated exactly once per pulse — confirm
        nt = len(self.calc.density_matrices.times)
        for key, _, _ in self.collect_resultkeys:
            self.result[key] /= nt
class FrequencyResultsCollector(ResultsCollector):

    """ Collect results in the frequency domain.

    This class should work with the Fourier transform of
    the real part of density matrices.

    The letter w is prepended to the suffix of the result keys to indicate
    an additional dimension of frequency.

    Parameters
    ----------
    calc
        Calculator.
    calc_kwargs
        Keyword arguments passed to the icalculate function.
    exclude
        Keys that are excluded from collection.
    """

    def __init__(self,
                 calc: BaseObservableCalculator,
                 calc_kwargs: dict[str, Any],
                 exclude: list[str] | None = None):
        assert isinstance(calc.density_matrices, FrequencyDensityMatrices)
        Nw = len(calc.frequencies)
        # Only the Fourier transform of the real part is supported
        assert 'Im' not in calc.density_matrices.reim

        resultkeys = calc.get_result_keys(**calc_kwargs)
        super().__init__(calc, calc_kwargs, resultkeys,
                         additional_suffix='w', additional_dimension=(Nw, ),
                         exclude=[] if exclude is None else exclude)

    def accumulate_results(self,
                           work: FrequencyDensityMatrixMetadata,
                           result: Result):
        """ Store the result for one frequency in the collected arrays (root rank only). """
        assert isinstance(work, FrequencyDensityMatrixMetadata)
        assert world.rank == 0

        for key, _, _ in self.resultkeys:
            newkey = self.format_key(key)
            # work.globalw indexes the added frequency axis
            self.result.set_to(newkey, work.globalw, result[key])
class Writer(Generic[ResultsCollectorT]):

    """ Calculate observables using a results collector and write them to file.

    Results can be saved to a numpy ``.npz`` file or to a ULM file. Saving to
    a ULM file requires a subclass implementing :meth:`fill_ulm` and
    :meth:`write_empty_arrays_ulm`.

    Parameters
    ----------
    collector
        Results collector, holding the calculator and the keyword
        arguments passed to its icalculate function.
    """

    def __init__(self, collector: ResultsCollectorT):
        self._collector = collector
        self._ulm_tag = 'RhodentResults'

    @property
    def collector(self) -> ResultsCollectorT:
        """ Results collector. """
        return self._collector

    @property
    def calc(self) -> BaseObservableCalculator:
        """ Calculator of the results collector. """
        return self.collector.calc

    @property
    def density_matrices(self) -> BaseDensityMatrices:
        """ Density matrices object of the calculator. """
        return self.collector.calc.density_matrices

    @property
    def voronoi(self) -> VoronoiWeights:
        """ Voronoi weights of the calculator, or empty weights if it has none. """
        voronoi = self.calc.voronoi
        if voronoi is None:
            return EmptyVoronoiWeights()
        return voronoi

    @property
    def common_arrays(self) -> dict[str, NDArray[np.float64] | NDArray[np.int64] | int | float]:
        """ Dictionary of eigenvalues and limits.

        Contains the eigenvalues ``eig_n``, ``eig_i`` and ``eig_a`` of the
        calculator and the limits ``imin``, ``imax``, ``amin`` and ``amax``
        obtained from ``calc.ksd.ialims()``.
        """
        imin, imax, amin, amax = self.calc.ksd.ialims()
        arrays: dict[str, NDArray[np.float64] | NDArray[np.int64] | int | float] = dict()
        arrays['eig_n'] = self.calc.eig_n
        arrays['eig_i'] = self.calc.eig_i
        arrays['eig_a'] = self.calc.eig_a
        arrays['imin'] = imin
        arrays['imax'] = imax
        arrays['amin'] = amin
        arrays['amax'] = amax

        return arrays

    @property
    def icalculate_kwargs(self) -> dict:
        """ Keyword arguments to icalculate. """
        return self.collector.calc_kwargs

    def fill_ulm(self,
                 writer,
                 work: WorkMetadata,
                 result: Result):
        """ Fill one entry of the ULM file. To be implemented by subclasses.

        Parameters
        ----------
        writer
            Open ULM writer object.
        work
            Metadata to current piece of data.
        result
            Result containing the current observables.
        """
        raise NotImplementedError

    def write_empty_arrays_ulm(self, writer):
        """ Add empty arrays to the ULM file. To be implemented by subclasses.

        Parameters
        ----------
        writer
            Open ULM writer object.
        """
        raise NotImplementedError

    def calculate_data(self) -> Result:
        """ Calculate results on all ranks and return Result object.

        Returns
        -------
        Result object. Is empty on non-root ranks.
        """
        self.collector.empty_results()

        # accumulate_results asserts rank 0, so the generator presumably
        # yields only on the root rank
        for work, res in self.calc.icalculate_gather_on_root(**self.icalculate_kwargs):
            self.collector.accumulate_results(work, res)

        self.collector.finalize_results()

        return self.collector.result

    def calculate_and_save_npz(self,
                               out_fname: str,
                               write_extra: dict[str, Any] | None = None):
        """ Calculate results on all ranks and save to npz file.

        The file is written by the root rank only.

        Parameters
        ----------
        out_fname
            File name.
        write_extra
            Additional data to store in the file, keyed by name.
        """
        if write_extra is None:
            write_extra = dict()

        result = self.calculate_data()

        if world.rank > 0:
            return

        atom_projections = atom_projections_to_numpy(self.voronoi.atom_projections)
        np.savez(out_fname, **self.common_arrays, **result._data,  # type: ignore
                 atom_projections=atom_projections, **write_extra)
        self.calc.log_parallel(f'Written {out_fname}', flush=True)

    def calculate_and_save_ulm(self,
                               out_fname: str,
                               write_extra: dict[str, Any] | None = None):
        """ Calculate results on all ranks and save to ULM file.

        Parameters
        ----------
        out_fname
            File name.
        write_extra
            Additional data to store in the file, keyed by name.
        """
        if write_extra is None:
            write_extra = dict()

        self.collector.empty_results()

        with GPAWWriter(out_fname, world, mode='w', tag=self._ulm_tag[:16]) as writer:
            writer.write(version=1)
            writer.write('atom_projections', self.voronoi.atom_projections)
            writer.write(**(self.common_arrays if world.rank == 0 else dict()))
            writer.write(**write_extra)

            self.write_empty_arrays_ulm(writer)

            for work, res in self.calc.icalculate_gather_on_root(**self.icalculate_kwargs):
                self.fill_ulm(writer, work, res)
                self.collector.accumulate_results(work, res)

            self.collector.finalize_results()
            writer.write(**self.collector.result._data)

        if world.rank == 0:
            self.calc.log_parallel(f'Written {out_fname}', flush=True)