Coverage for tests/integration/test_energy.py: 82%
83 statements
coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3from pathlib import Path
4import pytest
5import numpy as np
7from gpaw.mpi import world
8from tests import get_permanent_test_file
11def energy_pulse_sweep(gpw_fname: str,
12 ksd_fname: str,
13 *,
14 wfs_fname: str | None = None,
15 frho_dname: Path | None = None,
16 pulses, times, energy_o, energy_u, sigma,
17 reduce: bool = False):
18 """ Helper function to compute energies and energy decompositions
20 Allows testing that the following works:
22 - Obtaining density matrices, either from *_wfs.ulm
23 files, or from frho files.
24 - Calculating energy using EnergyCalculator
25 - Writing results to file using EnergyWriter
27 """
28 from rhodent.calculators import EnergyCalculator
29 from rhodent.response import ResponseFromWaveFunctions, ResponseFromFourierTransform
30 from rhodent.writers.writer import PulseConvolutionResultsCollector
31 from rhodent.writers.energy import EnergyWriter
33 if wfs_fname is not None:
34 assert frho_dname is None
35 assert wfs_fname is not None
36 perturbation = {'name': 'SincPulse', 'strength': 1e-6,
37 'cutoff_freq': 4, 'time0': 5, 'relative_t0': True}
38 response = ResponseFromWaveFunctions(wfs_fname=wfs_fname,
39 ksd=ksd_fname,
40 perturbation=perturbation)
41 else:
42 assert frho_dname is not None
43 assert wfs_fname is None
44 frho_fmt = str(frho_dname / 'w{freq:05.2f}-{reim}.npy')
45 perturbation = {'name': 'deltakick', 'strength': 1e-5}
46 response = ResponseFromFourierTransform(frho_fmt=frho_fmt,
47 ksd=ksd_fname,
48 perturbation=perturbation)
50 calc_kwargs = dict(yield_total_E_ia=True,
51 yield_total_E_ou=True,
52 yield_total_dists=True,
53 direction=2)
55 calc = EnergyCalculator(response=response,
56 times=times,
57 pulses=pulses,
58 voronoi=None,
59 energies_occ=energy_o,
60 energies_unocc=energy_u,
61 sigma=sigma,
62 )
64 writer = EnergyWriter(PulseConvolutionResultsCollector(calc, calc_kwargs), only_one_pulse=False)
66 data = dict(**writer.common_arrays)
67 data.update(writer.calculate_data()._data)
68 if world.rank == 0:
69 return data
70 else:
71 return None
def get_keys():
    """Return the names of the arrays saved by the energy writer.

    The names are split into two groups: those that represent grids or
    parameters, and those that represent calculated data.

    Returns
    -------
    tuple of (list of str, list of str)
        ``(grid_keys, data_keys)``.
    """
    grids = [
        'eig_n', 'eig_i', 'eig_a',
        'imin', 'imax', 'amin', 'amax',
        'time_t', 'pulsefreq_p', 'pulsefwhm_p',
        'sigma', 'energy_o', 'energy_u',
    ]
    computed = [
        'dm_pt', 'field_pt', 'Epulse_pt',
        'resonant_high', 'resonant_low',
        'total_pt', 'total_Hxc_pt',
        'total_resonant_pt', 'total_resonant_Hxc_pt',
        'E_ptia', 'Ec_ptia',
        'E_transition_ptu', 'Ec_transition_ptu',
        'E_ptou', 'Ec_ptou',
        'total_proj_ptII', 'total_Hxc_proj_ptII',
    ]
    return grids, computed
def get_reference_pulses_and_times():
    """Return the keyword arguments (pulses, times, grids) for the reference data.

    The returned dict holds ``energy_o``, ``energy_u``, ``times``, ``sigma``
    and ``pulses`` and is meant to be unpacked into ``energy_pulse_sweep``.
    """
    from rhodent.utils import create_pulse

    pulse_frequencies = (1.1, 1.3)
    pulses = [create_pulse(freq, 5.0, 10.0) for freq in pulse_frequencies]

    # The time grid is constructed so that every time coincides with a time
    # stored in the wave function trajectory *_wfs.ulm. Convolution using the
    # *_fdm.ulm files can interpolate to any time, but convolution using the
    # *_wfs.ulm files picks the closest available time, so some of the tests
    # rely on this alignment.
    coarse_times = np.arange(1, 3000, 20)  # 1, 21, ..., 2981
    times = coarse_times[40::20] * 10  # 6 times: 8010, 12010, ..., 28010 as

    return {
        # Very rough energy grids, sufficient for tests
        'energy_o': np.arange(-1, 3.01, 0.5),
        'energy_u': np.arange(-3, 1.01, 0.5),
        'times': times,
        'sigma': 0.7,
        'pulses': pulses,
    }
def write_reference_data(ref_energy, gpw_fname, ksd_fname, wfs_fname):
    """Compute reference energies from wave functions and save them.

    The data is saved with ``np.savez_compressed`` to ``ref_energy``. This
    must run in serial, because ``energy_pulse_sweep`` only returns the data
    on rank 0.

    Parameters
    ----------
    ref_energy
        Output file passed to ``np.savez_compressed``.
    gpw_fname, ksd_fname, wfs_fname
        Passed on to ``energy_pulse_sweep``.

    Raises
    ------
    RuntimeError
        If run on more than one MPI rank.
    """
    # An explicit check (not ``assert``) keeps this guard under ``python -O``.
    # Without it, non-root ranks would fail later on ``**None``.
    if world.size != 1:
        raise RuntimeError('Run me in serial mode')

    kwargs = get_reference_pulses_and_times()
    data = energy_pulse_sweep(gpw_fname, ksd_fname, wfs_fname=wfs_fname,
                              **kwargs)

    np.savez_compressed(ref_energy, **data)
    print(f'Saved data to {ref_energy}')
@pytest.mark.bigdata
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
def test_energy_frho_against_reference(ref_energy, gpw_fname, ksd_fname, frho_dname):
    """Test that energies computed from the frho files match the reference file.

    Note that the reference file is computed from *_wfs.ulm files. The grids
    and parameters must therefore match tightly, while the calculated data is
    only compared within a looser tolerance.
    """
    kwargs = get_reference_pulses_and_times()
    data = energy_pulse_sweep(gpw_fname, ksd_fname, frho_dname=frho_dname, **kwargs)

    if world.rank != 0:
        assert data is None
        return

    grid_keys, data_keys = get_keys()

    # Use the NpzFile as a context manager so its file handle gets closed.
    with np.load(ref_energy) as ref_data:
        assert set(grid_keys + data_keys) == set(ref_data.files)

        for key in grid_keys:
            np.testing.assert_allclose(ref_data[key], data[key], err_msg=key,
                                       rtol=0, atol=1e-12)

        for key in data_keys:
            # Indexing an NpzFile reads the member from disk, so read it once
            ref = ref_data[key]
            if np.array(ref).size == 0:
                # No voronoi
                continue
            # Allow for some mismatch, since it is not exactly the same data
            vmax = np.abs(ref).max()
            np.testing.assert_allclose(ref, data[key], err_msg=key,
                                       atol=1e-3 * vmax, rtol=0.01)
@pytest.mark.bigdata
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
def test_energy_wave_functions_against_reference(ref_energy, gpw_fname, ksd_fname, wfs_fname):
    """Test that energies computed from *_wfs.ulm files match the reference file.

    The reference file is produced from the same kind of input, so all
    non-empty arrays must agree to within an absolute tolerance of 1e-12.
    """
    kwargs = get_reference_pulses_and_times()
    data = energy_pulse_sweep(gpw_fname, ksd_fname, wfs_fname=wfs_fname,
                              **kwargs)

    if world.rank != 0:
        assert data is None
        return

    grid_keys, data_keys = get_keys()

    # Use the NpzFile as a context manager so its file handle gets closed.
    with np.load(ref_energy) as ref_data:
        assert set(grid_keys + data_keys) == set(ref_data.files)

        for key in grid_keys + data_keys:
            # Indexing an NpzFile reads the member from disk, so read it once
            ref = ref_data[key]
            if np.array(ref).size == 0:
                # No voronoi
                continue
            np.testing.assert_allclose(ref, data[key], err_msg=key,
                                       rtol=0, atol=1e-12)
if __name__ == '__main__':
    # Regenerate the reference energy files for every test system.
    for test_system in ['Na8', 'Ag8']:
        wfs_fname = get_permanent_test_file(test_system, 'wfs_fname')
        # Explicit check rather than ``assert`` so it also runs under ``python -O``.
        if not wfs_fname.exists():
            raise FileNotFoundError('Run me on a system where the time-dependent wave functions file exists')
        write_reference_data(get_permanent_test_file(test_system, 'ref_energy'),
                             get_permanent_test_file(test_system, 'gpw_fname'),
                             get_permanent_test_file(test_system, 'ksd_fname'),
                             wfs_fname)