Coverage for tests/integration/test_hotcarriers.py: 82%

85 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3from pathlib import Path 

4import pytest 

5import numpy as np 

6 

7from gpaw.mpi import world 

8from tests import get_permanent_test_file 

9 

10 

def hcdist_pulse_sweep(gpw_fname: str,
                       ksd_fname: str,
                       *,
                       wfs_fname: str | None = None,
                       frho_dname: Path | None = None,
                       pulses, times, energy_o, energy_u, sigma):
    """ Helper function to compute hot carrier distributions.

    Allows testing that the following works:

    - Obtaining density matrices, either from *_wfs.ulm
      files, or from frho files.
    - Calculating Voronoi weights using
      VoronoiLCAOWeightCalculator and
      VoronoiWeightCalculator.
    - Calculating hot carriers using HotCarriersCalculator
    - Writing results to file using HotCarriersWriter

    Exactly one of ``wfs_fname`` and ``frho_dname`` must be given.

    Parameters
    ----------
    gpw_fname
        File passed to get_voronoi to set up the Voronoi weight calculator.
    ksd_fname
        Passed as ``ksd`` to the response object.
    wfs_fname
        Time-dependent wave functions file (*_wfs.ulm); the response is then
        built with ResponseFromWaveFunctions.
    frho_dname
        Directory with files named ``w{freq:05.2f}-{reim}.npy``; the response
        is then built with ResponseFromFourierTransform.
    pulses, times, energy_o, energy_u, sigma
        Forwarded to HotCarriersCalculator (``energy_o``/``energy_u`` as
        ``energies_occ``/``energies_unocc``).

    Returns
    -------
    On rank 0, a dict with the writer's common arrays updated with the
    calculated data arrays; None on all other ranks.

    Raises
    ------
    ValueError
        If neither or both of ``wfs_fname`` and ``frho_dname`` are given.
    """
    # Validate before any expensive setup. An explicit exception is used
    # because asserts are stripped under ``python -O``.
    if (wfs_fname is None) == (frho_dname is None):
        raise ValueError('Exactly one of wfs_fname and frho_dname must be given')

    from rhodent.calculators import HotCarriersCalculator
    from rhodent.response import ResponseFromWaveFunctions, ResponseFromFourierTransform
    from rhodent.writers.writer import PulseConvolutionResultsCollector
    from rhodent.writers.hcdist import HotCarriersWriter

    voronoi = get_voronoi(gpw_fname)

    if wfs_fname is not None:
        # NOTE(review): presumably this must match the perturbation of the
        # time propagation that wrote wfs_fname -- confirm.
        perturbation = {'name': 'SincPulse', 'strength': 1e-6,
                        'cutoff_freq': 4, 'time0': 5, 'relative_t0': True}
        response = ResponseFromWaveFunctions(wfs_fname=wfs_fname,
                                             ksd=ksd_fname,
                                             perturbation=perturbation)
    else:
        # File names are formatted from a frequency and a real/imaginary tag.
        frho_fmt = str(frho_dname / 'w{freq:05.2f}-{reim}.npy')
        perturbation = {'name': 'deltakick', 'strength': 1e-5}
        response = ResponseFromFourierTransform(frho_fmt=frho_fmt,
                                                ksd=ksd_fname,
                                                perturbation=perturbation)

    # Request both total and projected quantities from the calculator.
    calc_kwargs: dict[str, bool] = dict(
        yield_total_hcdists=True,
        yield_proj_hcdists=True,
        yield_total_P=True,
        yield_proj_P=True,
        yield_total_P_ou=True,
    )
    calc = HotCarriersCalculator(response=response,
                                 times=times,
                                 pulses=pulses,
                                 voronoi=voronoi,
                                 energies_occ=energy_o,
                                 energies_unocc=energy_u,
                                 sigma=sigma,
                                 )

    writer = HotCarriersWriter(PulseConvolutionResultsCollector(calc, calc_kwargs),
                               only_one_pulse=False)

    # The data is computed on every rank; only rank 0 hands it back.
    data = dict(**writer.common_arrays)
    data.update(writer.calculate_data()._data)
    return data if world.rank == 0 else None

78 

79 

def get_keys():
    """ Return the keys saved by the hot-carrier writer, split in two groups.

    Returns
    -------
    grid_keys
        List of keys of the arrays describing grids and parameters.
    data_keys
        List of keys of the computed data arrays, the total quantities
        followed by the projected (``proj``) ones.
    """
    grid_names = ('eig_n eig_i eig_a imin imax amin amax '
                  'time_t pulsefreq_p pulsefwhm_p sigma energy_o energy_u')
    total_names = 'sumocc_pt sumunocc_pt P_pti P_pta hcdist_pto hcdist_ptu P_ptou'
    projected_names = ('sumocc_proj_ptI sumunocc_proj_ptI P_proj_ptIi P_proj_ptIa '
                       'hcdist_proj_ptIo hcdist_proj_ptIu')

    return grid_names.split(), total_names.split() + projected_names.split()

92 

93 

def get_reference_pulses_and_times():
    """ Return the settings used to compute the hot-carrier reference data.

    Returns
    -------
    Dict with the keys ``energy_o``, ``energy_u``, ``times``, ``sigma`` and
    ``pulses``, meant to be passed as keyword arguments to hcdist_pulse_sweep.
    """
    from rhodent.utils import create_pulse

    pulse_frequencies = [1.1, 1.3]
    pulse_list = [create_pulse(freq, 5.0, 10.0) for freq in pulse_frequencies]

    # The time grid is built so that it aligns with the wave function
    # trajectory *_wfs.ulm. Convolution using the *_fdm.ulm files can
    # interpolate to any time, but convolution using the *_wfs.ulm files
    # picks the closest available time, so several tests below rely on
    # this alignment.
    # The slice keeps 6 times: 8010, 12010, 16010, 20010, 24010, 28010 as.
    time_grid = 10 * np.arange(1, 3000, 20)[40::20]

    grids = {
        'energy_o': np.arange(-1, 3.01, 0.5),  # Very rough energy grid for tests
        'energy_u': np.arange(-3, 1.01, 0.5),
        'times': time_grid,
        'sigma': 0.7,
        'pulses': pulse_list,
    }
    return grids

113 

114 

def get_voronoi(gpw_fname):
    """ Build the Voronoi weight calculator used by the hot-carrier tests.

    A VoronoiLCAOWeightCalculator with a single projection group consisting
    of atom 1 is created from ``gpw_fname``. It is then wrapped in a
    VoronoiWeightCalculator, which is returned.
    """
    from rhodent.voronoi import VoronoiLCAOWeightCalculator, VoronoiWeightCalculator

    return VoronoiWeightCalculator(
        VoronoiLCAOWeightCalculator(atom_projections=[[1]], gpw_file=gpw_fname))

123 

124 

def write_reference_data(ref_hcdist, gpw_fname, ksd_fname, wfs_fname):
    """ Compute the hot-carrier reference data from wave functions and save it.

    The data is computed with hcdist_pulse_sweep from ``wfs_fname``, using
    the settings from get_reference_pulses_and_times. It is then written to
    ``ref_hcdist`` with ``np.savez_compressed``.

    Raises
    ------
    RuntimeError
        If run with more than one MPI rank.
    """
    # Uses the module-level ``world``. An explicit exception is used because
    # asserts are stripped under ``python -O``.
    if world.size != 1:
        raise RuntimeError('Run me in serial mode')

    kwargs = get_reference_pulses_and_times()
    data = hcdist_pulse_sweep(gpw_fname, ksd_fname, wfs_fname=wfs_fname,
                              **kwargs)

    np.savez_compressed(ref_hcdist, **data)
    print(f'Saved data to {ref_hcdist}')

136 

137 

@pytest.mark.bigdata
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
def test_hcdist_frho_against_reference(ref_hcdist, gpw_fname, ksd_fname, frho_dname):
    """ Test hot carriers computed from the frho files in ``frho_dname``
    against the reference file.

    The reference file is computed from *_wfs.ulm files, so it is not
    exactly the same data. Grid arrays must agree to within 1e-12; data
    arrays only to within rtol=0.01 and atol=1e-3 of their largest
    absolute reference value.
    """
    kwargs = get_reference_pulses_and_times()
    data = hcdist_pulse_sweep(gpw_fname, ksd_fname, frho_dname=frho_dname, **kwargs)

    if world.rank != 0:
        assert data is None
        return

    grid_keys, data_keys = get_keys()

    # Context manager so that the underlying npz (zip) file is closed.
    with np.load(ref_hcdist) as ref_data:
        assert set(grid_keys + data_keys) == set(ref_data.files)

        for key in grid_keys:
            np.testing.assert_allclose(ref_data[key], data[key], err_msg=key, rtol=0, atol=1e-12)

        for key in data_keys:
            # NpzFile decompresses the array on every item access; read it once.
            ref = ref_data[key]
            # Allow for some mismatch, since it is not exactly the same data
            vmax = np.abs(ref).max()
            np.testing.assert_allclose(ref, data[key],
                                       err_msg=key, atol=1e-3 * vmax, rtol=0.01)

167 

168 

@pytest.mark.bigdata
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
def test_hcdist_wave_functions_against_reference(ref_hcdist, gpw_fname, ksd_fname, wfs_fname):
    """ Test hot carriers computed from the *_wfs.ulm file against the
    reference file.

    The reference file is produced by write_reference_data from the same
    wave functions file. Therefore every array must agree to within
    atol=1e-12.
    """
    kwargs = get_reference_pulses_and_times()
    data = hcdist_pulse_sweep(gpw_fname, ksd_fname, wfs_fname=wfs_fname,
                              **kwargs)

    if world.rank != 0:
        assert data is None
        return

    grid_keys, data_keys = get_keys()

    # Context manager so that the underlying npz (zip) file is closed.
    with np.load(ref_hcdist) as ref_data:
        assert set(grid_keys + data_keys) == set(ref_data.files)

        for key in grid_keys + data_keys:
            np.testing.assert_allclose(ref_data[key], data[key], err_msg=key, rtol=0, atol=1e-12)

188 

189 

if __name__ == '__main__':
    # Regenerate the reference hot-carrier file of every test system. This
    # needs the time-dependent wave functions file to be present locally.
    for system in ('Na8', 'Ag8'):
        wfs_file = get_permanent_test_file(system, 'wfs_fname')
        assert wfs_file.exists(), 'Run me on a system where the time-dependent wave functions file exists'

        ref_file = get_permanent_test_file(system, 'ref_hcdist')
        gpw_file = get_permanent_test_file(system, 'gpw_fname')
        ksd_file = get_permanent_test_file(system, 'ksd_fname')
        write_reference_data(ref_file, gpw_file, ksd_file, wfs_file)

197 wfs_fname)