Coverage for rhodent/writers/tcm.py: 92%

60 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3import numpy as np 

4from numpy.typing import NDArray 

5 

6from ..density_matrices.frequency import FrequencyDensityMatrices 

7from ..density_matrices.base import WorkMetadata 

8from ..density_matrices.time import ConvolutionDensityMatrices 

9from ..utils import Result, get_gaussian_pulse_values 

10from .writer import ResultsCollector, Writer 

11 

12 

class DipoleWriter(Writer):

    """Calculate dipole moment contributions, optionally broadened onto
    an energy grid as a transition contribution map.

    Parameters
    ----------
    collector
        ResultsCollector object
    only_one_pulse
        If ``True``, pulse-convolved density matrices may contain only a
        single pulse, and (for Gaussian pulses) the pulse frequency and FWHM
        are stored as the scalars ``pulsefreq`` and ``pulsefwhm``.
        If ``False``, they are stored as the arrays ``pulsefreq_p`` and
        ``pulsefwhm_p``, and the ULM tag (``'Time TCM'`` or ``'TCM'``) is set.
    """

    def __init__(self,
                 collector: ResultsCollector,
                 only_one_pulse: bool):
        super().__init__(collector)
        self.only_one_pulse = only_one_pulse
        if only_one_pulse:
            if isinstance(self.density_matrices, ConvolutionDensityMatrices):
                # NOTE(review): ``assert`` is stripped under ``python -O``; it is kept
                # so the exception type raised to callers (AssertionError) is unchanged.
                assert len(self.density_matrices.pulses) == 1, 'Only one pulse allowed'
        else:
            # The ULM tag is only assigned when several pulses are allowed;
            # presumably ULM output is only produced in that mode (confirm in Writer).
            if isinstance(self.density_matrices, ConvolutionDensityMatrices):
                self._ulm_tag = 'Time TCM'
            else:
                assert isinstance(self.density_matrices, FrequencyDensityMatrices)
                self._ulm_tag = 'TCM'

    @property
    def _yield_total_ou(self) -> bool:
        """Whether ``yield_total_ou`` is set (truthy) in the collector's ``calc_kwargs``,
        i.e. whether occupied-unoccupied resolved dipoles are written to ULM."""
        return bool(self.collector.calc_kwargs.get('yield_total_ou', False))

    @property
    def common_arrays(self) -> dict[str, NDArray[np.float64] | NDArray[np.int64] | int | float]:
        """Arrays and scalars shared by all results, extending the base-class entries.

        Adds, in this order:

        * ``sigma``, ``energy_o`` and ``energy_u`` when the calculator has a
          ``sigma`` (i.e. results are broadened onto an energy grid);
        * for pulse-convolved (time-domain) density matrices, ``time_t`` and,
          if every pulse is Gaussian, the pulse frequency/FWHM entries;
        * for frequency-domain density matrices, ``freq_w``,
          ``frequency_broadening`` and ``osc_prefactor_w``.
        """
        from ..calculators import DipoleCalculator
        assert isinstance(self.calc, DipoleCalculator)

        common = super().common_arrays

        if self.calc.sigma is not None:
            # There is an energy grid
            common['sigma'] = self.calc.sigma
            common['energy_o'] = np.array(self.calc.energies_occ)
            common['energy_u'] = np.array(self.calc.energies_unocc)

        if isinstance(self.density_matrices, ConvolutionDensityMatrices):
            # Times are scaled by 1e-3, presumably a unit conversion (e.g. as -> fs); confirm.
            common['time_t'] = self.density_matrices.times * 1e-3
            common.update(self._gaussian_pulse_arrays())
        else:
            assert isinstance(self.density_matrices, FrequencyDensityMatrices)
            common['freq_w'] = self.density_matrices.frequencies
            common['frequency_broadening'] = self.density_matrices.frequency_broadening
            common['osc_prefactor_w'] = self.calc.oscillator_strength_prefactor

        return common

    def _gaussian_pulse_arrays(self) -> dict[str, NDArray[np.float64] | float]:
        """Return the pulse frequency/FWHM entries, or ``{}`` if the pulses are not Gaussian.

        With ``only_one_pulse`` the entries are the scalars ``pulsefreq`` and
        ``pulsefwhm``; otherwise they are the arrays ``pulsefreq_p`` and
        ``pulsefwhm_p`` with one element per pulse.
        """
        assert isinstance(self.density_matrices, ConvolutionDensityMatrices)
        pulsedicts = [get_gaussian_pulse_values(pulse) for pulse in self.density_matrices.pulses]
        try:
            pulsefreqs = [d['pulsefreq'] for d in pulsedicts]
            pulsefwhms = [d['pulsefwhm'] for d in pulsedicts]
        except KeyError:
            # Not GaussianPulses: no 'pulsefreq'/'pulsefwhm' values to store
            return {}

        if self.only_one_pulse:
            return {'pulsefreq': pulsefreqs[0],
                    'pulsefwhm': pulsefwhms[0]}
        return {'pulsefreq_p': np.array(pulsefreqs),
                'pulsefwhm_p': np.array(pulsefwhms)}

    def fill_ulm(self,
                 writer,
                 work: WorkMetadata,
                 result: Result):
        """Pass the occupied-unoccupied resolved dipole ``result['dm_ouv']`` to ``writer.fill``.

        Does nothing unless ``yield_total_ou`` is set in the collector's
        ``calc_kwargs``. Presumably this fills the array declared by
        :meth:`write_empty_arrays_ulm`, one time/frequency entry per call
        (confirm against the ULM writer). ``work`` is not used here.
        """
        if self._yield_total_ou:
            writer.fill(result['dm_ouv'])

    def write_empty_arrays_ulm(self, writer):
        """Declare the ULM array for occupied-unoccupied resolved dipoles, if requested.

        Only done when ``yield_total_ou`` is set in the collector's ``calc_kwargs``.
        The array is ``dm_touv`` with shape ``(Nt, No, Nu, 3)`` and real dtype for
        pulse-convolved (time-domain) density matrices, or ``dm_wouv`` with shape
        ``(Nw, No, Nu, 3)`` and complex dtype for frequency-domain ones.
        """
        if not self._yield_total_ou:
            return
        shape_ou = (len(self.calc.energies_occ), len(self.calc.energies_unocc))
        if isinstance(self.density_matrices, ConvolutionDensityMatrices):
            Nt = len(self.density_matrices.times)
            # Real dipole
            writer.add_array('dm_touv', (Nt, ) + shape_ou + (3, ), dtype=float)
        else:
            assert isinstance(self.density_matrices, FrequencyDensityMatrices)
            Nw = len(self.density_matrices.frequencies)
            # Complex polarizability
            writer.add_array('dm_wouv', (Nw, ) + shape_ou + (3, ), dtype=complex)