Coverage for rhodent/writers/energy.py: 94%
51 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3import numpy as np
4from numpy.typing import NDArray
5from gpaw.tddft.units import au_to_eV
7from ..density_matrices.base import WorkMetadata
8from ..density_matrices.time import ConvolutionDensityMatrices
9from .writer import Writer, ResultsCollector
10from ..utils import Result, get_gaussian_pulse_values
class EnergyWriter(Writer):

    """ Writer for energy contributions

    Parameters
    ----------
    collector
        ResultsCollector object
    only_one_pulse
        If ``True``: when the density matrices are
        ``ConvolutionDensityMatrices`` they must contain exactly one pulse,
        the Gaussian pulse parameters (if any) are stored as the scalars
        ``pulsefreq`` / ``pulsefwhm``, and the ULM methods
        (``fill_ulm``, ``write_empty_arrays_ulm``) may be used.
        If ``False``: the density matrices must be
        ``ConvolutionDensityMatrices`` and the pulse parameters (if any) are
        stored as the arrays ``pulsefreq_p`` / ``pulsefwhm_p``.
    """

    def __init__(self,
                 collector: ResultsCollector,
                 only_one_pulse: bool):
        super().__init__(collector)
        self.only_one_pulse = only_one_pulse
        self._ulm_tag = 'EnergyDecomposition'
        # NOTE(review): these checks use assert, so they are skipped when
        # Python runs with -O; consider explicit raises if that matters.
        if only_one_pulse:
            if isinstance(self.density_matrices, ConvolutionDensityMatrices):
                assert len(self.density_matrices.pulses) == 1, 'Only one pulse allowed'
        else:
            assert isinstance(self.density_matrices, ConvolutionDensityMatrices)

    @property
    def common_arrays(self) -> dict[str, NDArray[np.float64] | NDArray[np.int64] | int | float]:
        """ Arrays and scalars shared by all outputs of this writer.

        Extends the base-class ``common_arrays`` with:

        - ``sigma``, ``energy_o``, ``energy_u`` when ``calc.sigma`` is set
        - ``time_t`` (density-matrix times multiplied by ``1e-3``)
        - pulse frequency/FWHM when every pulse provides them
          (scalars if ``only_one_pulse``, otherwise per-pulse arrays)
        - ``resonant_low`` / ``resonant_high`` (the calculator's pair-filter
          limits multiplied by ``au_to_eV``)
        """
        common = super().common_arrays

        if self.calc.sigma is not None:
            # There is an energy grid
            common['sigma'] = self.calc.sigma
            common['energy_o'] = np.array(self.calc.energies_occ)
            common['energy_u'] = np.array(self.calc.energies_unocc)

        assert isinstance(self.density_matrices, ConvolutionDensityMatrices)
        # NOTE(review): the 1e-3 factor presumably converts the stored time
        # unit (e.g. as -> fs); confirm against ConvolutionDensityMatrices.
        common['time_t'] = self.density_matrices.times * 1e-3

        # If pulses are Gaussian pulses, then get dictionaries of 'pulsefreq' and 'pulsefwhm'
        pulsedicts = [get_gaussian_pulse_values(pulse) for pulse in self.density_matrices.pulses]

        try:
            # Only these lookups are expected to raise KeyError; keep the
            # try body minimal so unrelated KeyErrors are not swallowed.
            pulsefreqs = [d['pulsefreq'] for d in pulsedicts]
            pulsefwhms = [d['pulsefwhm'] for d in pulsedicts]
        except KeyError:
            # Not GaussianPulses: some pulse dict lacks the keys, so no pulse
            # parameters are stored
            pass
        else:
            if self.only_one_pulse:
                common['pulsefreq'] = pulsefreqs[0]
                common['pulsefwhm'] = pulsefwhms[0]
            else:
                common['pulsefreq_p'] = np.array(pulsefreqs)
                common['pulsefwhm_p'] = np.array(pulsefwhms)

        common['resonant_low'] = self.calc._filter_pair_low * au_to_eV  # type: ignore
        common['resonant_high'] = self.calc._filter_pair_high * au_to_eV  # type: ignore
        return common

    def fill_ulm(self,
                 writer,
                 work: WorkMetadata,
                 result: Result):
        """ Fill one piece of ULM output from ``result``.

        Only valid with ``only_one_pulse=True``. When the collector's
        ``calc_kwargs['yield_total_E_ou']`` is truthy, ``result['E_ou']`` is
        passed to ``writer.fill``; otherwise nothing is written.
        ``work`` is not used here.
        """
        assert self.only_one_pulse
        if self.collector.calc_kwargs['yield_total_E_ou']:
            writer.fill(result['E_ou'])

    def write_empty_arrays_ulm(self, writer):
        """ Declare the ULM arrays that ``fill_ulm`` will later fill.

        Only valid with ``only_one_pulse=True``. When the collector's
        ``calc_kwargs['yield_total_E_ou']`` is truthy, adds a float array
        ``'E_tou'`` of shape (number of times, number of occupied energies,
        number of unoccupied energies); otherwise adds nothing.
        """
        assert self.only_one_pulse
        if not self.collector.calc_kwargs['yield_total_E_ou']:
            return
        shape_ou = (len(self.calc.energies_occ), len(self.calc.energies_unocc))
        Nt = len(self.density_matrices.times)
        writer.add_array('E_tou', (Nt, ) + shape_ou, dtype=float)