Coverage for rhodent/writers/energy.py: 94%

51 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3import numpy as np 

4from numpy.typing import NDArray 

5from gpaw.tddft.units import au_to_eV 

6 

7from ..density_matrices.base import WorkMetadata 

8from ..density_matrices.time import ConvolutionDensityMatrices 

9from .writer import Writer, ResultsCollector 

10from ..utils import Result, get_gaussian_pulse_values 

11 

12 

class EnergyWriter(Writer):

    """ Writer for energy contributions gathered by a results collector.

    Parameters
    ----------
    collector
        ResultsCollector object
    only_one_pulse
        Whether the results correspond to a single pulse. If true, and the
        density matrices are ConvolutionDensityMatrices, exactly one pulse is
        required; pulse metadata is then stored as scalars and the ULM methods
        may be used. If false, the density matrices must be
        ConvolutionDensityMatrices and pulse metadata is stored per pulse.
    """

    def __init__(self,
                 collector: ResultsCollector,
                 only_one_pulse: bool):
        super().__init__(collector)
        # Tag for the ULM output; presumably consumed by the Writer base class
        self._ulm_tag = 'EnergyDecomposition'
        self.only_one_pulse = only_one_pulse

        if not only_one_pulse:
            # Per-pulse metadata needs the pulse list of convolution density matrices
            assert isinstance(self.density_matrices, ConvolutionDensityMatrices)
        elif isinstance(self.density_matrices, ConvolutionDensityMatrices):
            assert len(self.density_matrices.pulses) == 1, 'Only one pulse allowed'

    @property
    def common_arrays(self) -> dict[str, NDArray[np.float64] | NDArray[np.int64] | int | float]:
        """ Arrays and scalars shared by all results, extending the base class.

        Added entries:

        - ``sigma``, ``energy_o``, ``energy_u``: only when ``calc.sigma`` is set
        - ``time_t``: the density-matrix times multiplied by 1e-3
        - ``pulsefreq``/``pulsefwhm`` (single pulse) or ``pulsefreq_p``/
          ``pulsefwhm_p`` (one entry per pulse): only if every pulse yields
          these values from ``get_gaussian_pulse_values``
        - ``resonant_low``, ``resonant_high``: the calculator's pair-filter
          bounds multiplied by ``au_to_eV``
        """
        arrays = super().common_arrays

        if self.calc.sigma is not None:
            # A broadening width means there is an energy grid to store
            arrays.update(sigma=self.calc.sigma,
                          energy_o=np.array(self.calc.energies_occ),
                          energy_u=np.array(self.calc.energies_unocc))

        dm = self.density_matrices
        assert isinstance(dm, ConvolutionDensityMatrices)
        arrays['time_t'] = dm.times * 1e-3

        pulse_values = [get_gaussian_pulse_values(pulse) for pulse in dm.pulses]
        try:
            freqs = [values['pulsefreq'] for values in pulse_values]
            fwhms = [values['pulsefwhm'] for values in pulse_values]
        except KeyError:
            # At least one pulse is not a Gaussian pulse; omit pulse metadata
            pass
        else:
            if self.only_one_pulse:
                arrays['pulsefreq'], arrays['pulsefwhm'] = freqs[0], fwhms[0]
            else:
                arrays['pulsefreq_p'], arrays['pulsefwhm_p'] = np.array(freqs), np.array(fwhms)

        arrays['resonant_low'] = self.calc._filter_pair_low * au_to_eV  # type: ignore
        arrays['resonant_high'] = self.calc._filter_pair_high * au_to_eV  # type: ignore
        return arrays

    def fill_ulm(self,
                 writer,
                 work: WorkMetadata,
                 result: Result):
        """ Append the ``E_ou`` entry of ``result`` to the ULM writer.

        Nothing is written unless the collector's ``calc_kwargs`` enables
        ``yield_total_E_ou``. Only valid for a single-pulse writer.
        ``work`` is accepted for interface compatibility and is not used here.
        """
        assert self.only_one_pulse
        if not self.collector.calc_kwargs['yield_total_E_ou']:
            return
        writer.fill(result['E_ou'])

    def write_empty_arrays_ulm(self, writer):
        """ Declare the ``E_tou`` array in the ULM writer.

        The array has shape (number of times, number of occupied energies,
        number of unoccupied energies) and dtype float. Nothing is declared
        unless the collector's ``calc_kwargs`` enables ``yield_total_E_ou``.
        Only valid for a single-pulse writer.
        """
        assert self.only_one_pulse
        if self.collector.calc_kwargs['yield_total_E_ou']:
            n_times = len(self.density_matrices.times)
            n_occ = len(self.calc.energies_occ)
            n_unocc = len(self.calc.energies_unocc)
            writer.add_array('E_tou', (n_times, n_occ, n_unocc), dtype=float)