Coverage for tests/integration/test_hotcarriers.py: 82%
85 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3from pathlib import Path
4import pytest
5import numpy as np
7from gpaw.mpi import world
8from tests import get_permanent_test_file
def hcdist_pulse_sweep(gpw_fname: str,
                       ksd_fname: str,
                       *,
                       wfs_fname: str | None = None,
                       frho_dname: Path | str | None = None,
                       pulses, times, energy_o, energy_u, sigma,
                       ) -> dict | None:
    """ Helper function to compute hot carrier distributions.

    Allows testing that the following works:

    - Obtaining density matrices, either from *_wfs.ulm
      files, or from frho files.
    - Calculating Voronoi weights using
      VoronoiLCAOWeightCalculator and
      VoronoiWeightCalculator.
    - Calculating hot carriers using HotCarriersCalculator
    - Writing results to file using HotCarriersWriter

    Exactly one of ``wfs_fname`` and ``frho_dname`` must be given.

    Parameters
    ----------
    gpw_fname
        File passed to :func:`get_voronoi` to set up the Voronoi weights.
    ksd_fname
        Passed as ``ksd`` to the response object.
    wfs_fname
        Time-dependent wave functions file (``*_wfs.ulm``). The response is
        read with a ``SincPulse`` perturbation of strength 1e-6.
    frho_dname
        Directory containing files named ``w{freq:05.2f}-{reim}.npy``. The
        response is read with a ``deltakick`` perturbation of strength 1e-5.
        May be a :class:`pathlib.Path` or a string.
    pulses, times, sigma
        Forwarded unchanged to HotCarriersCalculator.
    energy_o, energy_u
        Forwarded to HotCarriersCalculator as ``energies_occ`` and
        ``energies_unocc``.

    Returns
    -------
    On rank 0, a dict combining ``writer.common_arrays`` with the arrays
    produced by ``writer.calculate_data()``. ``None`` on all other ranks.
    """
    from rhodent.calculators import HotCarriersCalculator
    from rhodent.response import ResponseFromWaveFunctions, ResponseFromFourierTransform
    from rhodent.writers.writer import PulseConvolutionResultsCollector
    from rhodent.writers.hcdist import HotCarriersWriter

    voronoi = get_voronoi(gpw_fname)

    if wfs_fname is not None:
        assert frho_dname is None, 'Give either wfs_fname or frho_dname, not both'
        perturbation = {'name': 'SincPulse', 'strength': 1e-6,
                        'cutoff_freq': 4, 'time0': 5, 'relative_t0': True}
        response = ResponseFromWaveFunctions(wfs_fname=wfs_fname,
                                             ksd=ksd_fname,
                                             perturbation=perturbation)
    else:
        assert frho_dname is not None, 'Either wfs_fname or frho_dname must be given'
        # Path() makes plain string directory names work as well.
        frho_fmt = str(Path(frho_dname) / 'w{freq:05.2f}-{reim}.npy')
        perturbation = {'name': 'deltakick', 'strength': 1e-5}
        response = ResponseFromFourierTransform(frho_fmt=frho_fmt,
                                                ksd=ksd_fname,
                                                perturbation=perturbation)

    # Request every kind of output the calculator can yield, so the test
    # covers all code paths of the writer.
    calc_kwargs: dict[str, bool] = dict(
        yield_total_hcdists=True,
        yield_proj_hcdists=True,
        yield_total_P=True,
        yield_proj_P=True,
        yield_total_P_ou=True,
    )
    calc = HotCarriersCalculator(response=response,
                                 times=times,
                                 pulses=pulses,
                                 voronoi=voronoi,
                                 energies_occ=energy_o,
                                 energies_unocc=energy_u,
                                 sigma=sigma,
                                 )

    writer = HotCarriersWriter(PulseConvolutionResultsCollector(calc, calc_kwargs), only_one_pulse=False)

    # Computed on every rank (presumably a collective MPI operation);
    # only rank 0 hands the result back.
    data = dict(**writer.common_arrays)
    data.update(writer.calculate_data()._data)
    return data if world.rank == 0 else None
def get_keys():
    """Return the array names written by the hot-carrier writer, in two groups.

    Returns
    -------
    grid_keys
        Names of the grid and parameter arrays.
    data_keys
        Names of the computed data arrays.
    """
    grid_keys = ('eig_n eig_i eig_a imin imax amin amax '
                 'time_t pulsefreq_p pulsefwhm_p sigma energy_o energy_u').split()
    data_keys = ('sumocc_pt sumunocc_pt P_pti P_pta hcdist_pto hcdist_ptu '
                 'P_ptou sumocc_proj_ptI sumunocc_proj_ptI '
                 'P_proj_ptIi P_proj_ptIa hcdist_proj_ptIo hcdist_proj_ptIu').split()
    return grid_keys, data_keys
def get_reference_pulses_and_times():
    """Return the pulses and grids used for the reference hot-carrier data.

    The returned dict has the keys ``energy_o``, ``energy_u``, ``times``,
    ``sigma`` and ``pulses``, and is unpacked as keyword arguments into
    hcdist_pulse_sweep by the callers in this file.
    """
    from rhodent.utils import create_pulse

    # Two pulses at frequencies 1.1 and 1.3 that share the other
    # create_pulse arguments (5.0, 10.0).
    pulses = [create_pulse(freq, 5.0, 10.0) for freq in (1.1, 1.3)]

    # These times are chosen to coincide with the times stored in the wave
    # function trajectory *_wfs.ulm. Convolution using the *_fdm.ulm files can
    # interpolate to any time, but convolution using the *_wfs.ulm files picks
    # the closest available time, so some of the tests below rely on the
    # two grids aligning exactly.
    # Every 20th entry of 1, 21, 41, ..., from index 40, scaled by 10, gives
    # 6 times: 8010, 12010, 16010, 20010, 24010, 28010 as.
    times = 10 * np.arange(1, 3000, 20)[40::20]

    return {
        'energy_o': np.arange(-1, 3.01, 0.5),  # Very rough energy grid for tests
        'energy_u': np.arange(-3, 1.01, 0.5),
        'times': times,
        'sigma': 0.7,
        'pulses': pulses,
    }
def get_voronoi(gpw_fname):
    """Build a VoronoiWeightCalculator for the system stored in ``gpw_fname``.

    The underlying LCAO calculator is set up with a single atom projection,
    ``[[1]]`` (presumably atom index 1 on its own — confirm with rhodent docs).
    """
    from rhodent.voronoi import VoronoiLCAOWeightCalculator, VoronoiWeightCalculator

    return VoronoiWeightCalculator(
        VoronoiLCAOWeightCalculator(atom_projections=[[1]], gpw_file=gpw_fname))
def write_reference_data(ref_hcdist, gpw_fname, ksd_fname, wfs_fname):
    """Compute hot-carrier data from a *_wfs.ulm file and save it as reference.

    The data is computed by hcdist_pulse_sweep using the reference pulses and
    grids, then written to ``ref_hcdist`` with ``np.savez_compressed``.

    Raises
    ------
    RuntimeError
        If run with more than one MPI rank.
    """
    # An explicit check rather than ``assert``, so the guard survives
    # ``python -O``. In parallel, only rank 0 would receive the data and the
    # other ranks would fail when unpacking ``None``.
    if world.size != 1:
        raise RuntimeError('Run me in serial mode')

    kwargs = get_reference_pulses_and_times()
    data = hcdist_pulse_sweep(gpw_fname, ksd_fname, wfs_fname=wfs_fname,
                              **kwargs)

    np.savez_compressed(ref_hcdist, **data)
    print(f'Saved data to {ref_hcdist}')
@pytest.mark.bigdata
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
def test_hcdist_frho_against_reference(ref_hcdist, gpw_fname, ksd_fname, frho_dname):
    """ Test that we can compute hot carriers from *_fdm.ulm files and
    get the same results as the reference file.

    Here the density matrices are read from the frho files in ``frho_dname``.
    Note that the reference file is computed from *_wfs.ulm files.
    """
    kwargs = get_reference_pulses_and_times()
    data = hcdist_pulse_sweep(gpw_fname, ksd_fname, frho_dname=frho_dname, **kwargs)

    if world.rank != 0:
        # hcdist_pulse_sweep only returns results on rank 0.
        assert data is None
        return

    grid_keys, data_keys = get_keys()

    # np.load on an .npz archive returns an NpzFile that keeps the file open;
    # the context manager makes sure it is closed again.
    with np.load(ref_hcdist) as ref_data:
        assert set(grid_keys + data_keys) == set(ref_data.files)

        # Grids and parameters must match to numerical precision.
        for key in grid_keys:
            np.testing.assert_allclose(ref_data[key], data[key], err_msg=key, rtol=0, atol=1e-12)

        for key in data_keys:
            # Allow for some mismatch, since it is not exactly the same data.
            # The absolute tolerance scales with the largest reference value.
            ref = ref_data[key]  # read (and decompress) each array only once
            vmax = np.abs(ref).max()
            np.testing.assert_allclose(ref, data[key],
                                       err_msg=key, atol=1e-3 * vmax, rtol=0.01)
@pytest.mark.bigdata
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
def test_hcdist_wave_functions_against_reference(ref_hcdist, gpw_fname, ksd_fname, wfs_fname):
    """ Test that computing hot carriers from *_wfs.ulm files reproduces the
    reference file.

    The reference file is generated from the same *_wfs.ulm files by
    write_reference_data, so both grids and data must match to 1e-12.
    """
    kwargs = get_reference_pulses_and_times()
    data = hcdist_pulse_sweep(gpw_fname, ksd_fname, wfs_fname=wfs_fname,
                              **kwargs)

    if world.rank != 0:
        # hcdist_pulse_sweep only returns results on rank 0.
        assert data is None
        return

    grid_keys, data_keys = get_keys()

    # np.load on an .npz archive returns an NpzFile that keeps the file open;
    # the context manager makes sure it is closed again.
    with np.load(ref_hcdist) as ref_data:
        assert set(grid_keys + data_keys) == set(ref_data.files)

        for key in grid_keys + data_keys:
            np.testing.assert_allclose(ref_data[key], data[key], err_msg=key, rtol=0, atol=1e-12)
if __name__ == '__main__':
    # Regenerate the reference files for every test system. This requires
    # the time-dependent wave functions files to be present locally.
    for test_system in ['Na8', 'Ag8']:
        wfs_fname = get_permanent_test_file(test_system, 'wfs_fname')
        # An explicit check rather than ``assert``, so it also runs under ``python -O``.
        if not wfs_fname.exists():
            raise FileNotFoundError('Run me on a system where the time-dependent wave functions file exists')
        write_reference_data(get_permanent_test_file(test_system, 'ref_hcdist'),
                             get_permanent_test_file(test_system, 'gpw_fname'),
                             get_permanent_test_file(test_system, 'ksd_fname'),
                             wfs_fname)