Coverage for tests/integration/test_energy.py: 82%

83 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3from pathlib import Path 

4import pytest 

5import numpy as np 

6 

7from gpaw.mpi import world 

8from tests import get_permanent_test_file 

9 

10 

def energy_pulse_sweep(gpw_fname: str,
                       ksd_fname: str,
                       *,
                       wfs_fname: str | None = None,
                       frho_dname: Path | None = None,
                       pulses, times, energy_o, energy_u, sigma,
                       reduce: bool = False):
    """ Compute energies and energy decompositions for a sweep of pulses.

    Exercises the following chain:

    - Building a response object, either from a *_wfs.ulm file
      (``wfs_fname``) or from a directory of frho files (``frho_dname``).
      Exactly one of the two must be given.
    - Calculating energies using EnergyCalculator.
    - Collecting the results through EnergyWriter.

    Parameters
    ----------
    gpw_fname
        Accepted for signature compatibility; not used in this function.
    ksd_fname
        Passed as ``ksd`` to the response object.
    wfs_fname
        Wave functions file; selects ResponseFromWaveFunctions.
    frho_dname
        Directory with frho files; selects ResponseFromFourierTransform.
    pulses, times, energy_o, energy_u, sigma
        Forwarded to EnergyCalculator (``energy_o``/``energy_u`` as
        ``energies_occ``/``energies_unocc``).
    reduce
        Accepted for signature compatibility; not used in this function.

    Returns
    -------
    On MPI rank 0, a dict with the common arrays of the writer merged with
    the calculated data. On all other ranks, None.
    """
    from rhodent.calculators import EnergyCalculator
    from rhodent.response import ResponseFromWaveFunctions, ResponseFromFourierTransform
    from rhodent.writers.writer import PulseConvolutionResultsCollector
    from rhodent.writers.energy import EnergyWriter

    if wfs_fname is not None:
        # Wave function route; the frho route must not be requested as well
        assert frho_dname is None
        response = ResponseFromWaveFunctions(
            wfs_fname=wfs_fname,
            ksd=ksd_fname,
            perturbation={'name': 'SincPulse', 'strength': 1e-6,
                          'cutoff_freq': 4, 'time0': 5, 'relative_t0': True})
    else:
        # Fourier transform route; the directory is mandatory here
        assert frho_dname is not None
        response = ResponseFromFourierTransform(
            frho_fmt=str(frho_dname / 'w{freq:05.2f}-{reim}.npy'),
            ksd=ksd_fname,
            perturbation={'name': 'deltakick', 'strength': 1e-5})

    calculator = EnergyCalculator(response=response,
                                  times=times,
                                  pulses=pulses,
                                  voronoi=None,
                                  energies_occ=energy_o,
                                  energies_unocc=energy_u,
                                  sigma=sigma,
                                  )

    collector_kwargs = dict(yield_total_E_ia=True,
                            yield_total_E_ou=True,
                            yield_total_dists=True,
                            direction=2)
    collector = PulseConvolutionResultsCollector(calculator, collector_kwargs)
    writer = EnergyWriter(collector, only_one_pulse=False)

    # Every rank takes part in the calculation; only rank 0 hands back data
    data = dict(**writer.common_arrays)
    data.update(writer.calculate_data()._data)

    return data if world.rank == 0 else None

72 

73 

def get_keys():
    """ Return the keys saved by the writer, split into two lists.

    Returns
    -------
    Tuple ``(grid_keys, data_keys)``: the keys that represent
    grids/parameters, followed by the keys that represent data.
    """
    grid_keys = [
        'eig_n', 'eig_i', 'eig_a',
        'imin', 'imax', 'amin', 'amax',
        'time_t', 'pulsefreq_p', 'pulsefwhm_p',
        'sigma', 'energy_o', 'energy_u',
    ]
    data_keys = [
        'dm_pt', 'field_pt', 'Epulse_pt',
        'resonant_high', 'resonant_low',
        'total_pt', 'total_Hxc_pt',
        'total_resonant_pt', 'total_resonant_Hxc_pt',
        'E_ptia', 'Ec_ptia',
        'E_transition_ptu', 'Ec_transition_ptu',
        'E_ptou', 'Ec_ptou',
        'total_proj_ptII', 'total_Hxc_proj_ptII',
    ]

    return grid_keys, data_keys

89 

90 

def get_reference_pulses_and_times():
    """ Return the pulses, times and energy grids shared by the tests.

    The returned dict has the keys ``energy_o``, ``energy_u``, ``times``,
    ``sigma`` and ``pulses``, and is meant to be unpacked as keyword
    arguments to ``energy_pulse_sweep``.
    """
    from rhodent.utils import create_pulse

    pulses = []
    for pulsefreq in [1.1, 1.3]:
        pulses.append(create_pulse(pulsefreq, 5.0, 10.0))

    # The times are constructed so that they align to the wave function
    # trajectory *_wfs.ulm.
    # Convolution using the *_fdm.ulm files is able to interpolate to any
    # time, but convolution using the *_wfs.ulm files picks the closest
    # available time. Some of the tests rely on the two agreeing, so the
    # times must coincide with the trajectory.
    #
    # Every 20th element, from index 40, of the 150-element base grid gives
    # 6 times: 8010, 12010, ..., 28010 (in 'as', presumably attoseconds).
    base_times = np.arange(1, 3000, 20)
    times = base_times[40::20] * 10

    return {
        'energy_o': np.arange(-1, 3.01, 0.5),  # Very rough energy grid for tests
        'energy_u': np.arange(-3, 1.01, 0.5),
        'times': times,
        'sigma': 0.7,
        'pulses': pulses,
    }

110 

111 

def write_reference_data(ref_energy, gpw_fname, ksd_fname, wfs_fname):
    """ Compute the reference data from the wave functions file and save it.

    The data is computed with ``energy_pulse_sweep`` using the shared
    reference pulses and times, and written to ``ref_energy`` as a
    compressed npz archive. Must be run on a single MPI rank.
    """
    from gpaw.mpi import world

    assert world.size == 1, 'Run me in serial mode'

    data = energy_pulse_sweep(gpw_fname, ksd_fname, wfs_fname=wfs_fname,
                              **get_reference_pulses_and_times())

    np.savez_compressed(ref_energy, **data)
    print(f'Saved data to {ref_energy}')

123 

124 

@pytest.mark.bigdata
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
def test_energy_frho_against_reference(ref_energy, gpw_fname, ksd_fname, frho_dname):
    """ Test that we can compute energies from the frho files and
    get the same results as the reference file.

    Note that the reference file is computed from *_wfs.ulm files, so the
    grids must match to numerical precision, while the data is only
    required to agree within a tolerance.
    """
    kwargs = get_reference_pulses_and_times()
    data = energy_pulse_sweep(gpw_fname, ksd_fname, frho_dname=frho_dname, **kwargs)

    # Only rank 0 receives the data; the other ranks have nothing to compare
    if world.rank != 0:
        assert data is None
        return

    grid_keys, data_keys = get_keys()

    # The npz archive keeps a file handle open until it is closed, so read
    # it inside a context manager
    with np.load(ref_energy) as ref_data:
        assert set(grid_keys + data_keys) == set(ref_data.files)

        for key in grid_keys:
            np.testing.assert_allclose(ref_data[key], data[key], err_msg=key, rtol=0, atol=1e-12)

        for key in data_keys:
            # Each lookup reads the array from the archive; do it only once
            ref_value = ref_data[key]
            if np.array(ref_value).size == 0:
                # No voronoi
                continue
            # Allow for some mismatch, since it is not exactly the same data
            vmax = np.abs(ref_value).max()
            np.testing.assert_allclose(ref_value, data[key],
                                       err_msg=key, atol=1e-3 * vmax, rtol=0.01)

157 

158 

@pytest.mark.bigdata
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
def test_energy_wave_functions_against_reference(ref_energy, gpw_fname, ksd_fname, wfs_fname):
    """ Test that energies computed from the wave functions file
    reproduce the reference file.

    Both grids and data are required to match to within an absolute
    tolerance of 1e-12.
    """
    kwargs = get_reference_pulses_and_times()
    data = energy_pulse_sweep(gpw_fname, ksd_fname, wfs_fname=wfs_fname,
                              **kwargs)

    # Only rank 0 receives the data; the other ranks have nothing to compare
    if world.rank != 0:
        assert data is None
        return

    grid_keys, data_keys = get_keys()

    # The npz archive keeps a file handle open until it is closed, so read
    # it inside a context manager
    with np.load(ref_energy) as ref_data:
        assert set(grid_keys + data_keys) == set(ref_data.files)

        for key in grid_keys + data_keys:
            # Each lookup reads the array from the archive; do it only once
            ref_value = ref_data[key]
            if np.array(ref_value).size == 0:
                # No voronoi
                continue
            np.testing.assert_allclose(ref_value, data[key], err_msg=key, rtol=0, atol=1e-12)

181 

182 

if __name__ == '__main__':
    # Regenerate the reference data of each test system from its
    # time-dependent wave functions file
    for test_system in ['Na8', 'Ag8']:
        wfs_fname = get_permanent_test_file(test_system, 'wfs_fname')
        assert wfs_fname.exists(), 'Run me on a system where the time-dependent wave functions file exists'
        write_reference_data(
            *(get_permanent_test_file(test_system, key)
              for key in ('ref_energy', 'gpw_fname', 'ksd_fname')),
            wfs_fname)