Coverage for tests/unittests/writers/test_energy_writers.py: 98%
44 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3import pytest
4import numpy as np
5from ase.io.ulm import Reader
7from gpaw.mpi import world
8from rhodent.calculators.energy import EnergyCalculator
9from rhodent.utils import create_pulse
def load_ulm(fname):
    """Read a ULM file and return its entire contents as a plain dict.

    The reader is used as a context manager, so it is closed again once the
    contents have been converted with ``asdict()``.
    """
    reader = Reader(fname)
    with reader:
        contents = reader.asdict()
    return contents
def compare(leftdata, rightdata, keys):
    """Assert that ``leftdata`` and ``rightdata`` agree for every key in ``keys``.

    Uses the default tolerances of ``np.testing.assert_allclose``; the key
    being checked is passed as the error message so failures identify it.
    """
    assert_close = np.testing.assert_allclose
    for name in keys:
        left = leftdata[name]
        right = rightdata[name]
        assert_close(left, right, err_msg=name)
def compare_known(leftdata, rightdata, keys, reference):
    """Assert that two datasets agree with each other and with known values.

    For each key, ``leftdata[key]`` is compared against ``rightdata[key]`` and
    against the matching entry of ``reference``.

    Parameters
    ----------
    leftdata, rightdata
        Mappings from key to array-like data.
    keys
        Keys to check.
    reference
        Expected values, one per entry in ``keys`` and in the same order.

    Raises
    ------
    ValueError
        If ``keys`` and ``reference`` have different lengths.
    AssertionError
        If any comparison fails.
    """
    keys = list(keys)
    reference = list(reference)
    # zip() silently stops at the shorter sequence, which would let keys go
    # unchecked without any failure; require the lengths to match instead.
    if len(keys) != len(reference):
        raise ValueError(f'Got {len(keys)} keys but {len(reference)} reference values')

    # Loose tolerance to allow for energy and time grids which are written out with less precision
    for key, ref in zip(keys, reference):
        np.testing.assert_allclose(leftdata[key], rightdata[key], atol=1e-5, err_msg=key)
        np.testing.assert_allclose(leftdata[key], ref, atol=1e-5, err_msg=f'{key} != reference')
@pytest.mark.parametrize('test_system', ['Na55'])
def test_energy(tmp_path, mock_voronoi, mock_response):
    """Check that EnergyCalculator writes consistent .npz and .ulm output.

    Results are written in both formats, first without and then with Voronoi
    projections. All ranks run the calculations; only rank 0 reads the files
    back. There, the two formats are compared with each other, the metadata is
    compared with the known inputs, and the Voronoi run's total energies are
    compared with those of the run without Voronoi.
    """
    time_t = np.linspace(0, 30, 7)  # In fs
    pulsefreq = 4.0
    # NOTE(review): 5.0 is assumed to be the default FWHM used by create_pulse; confirm
    pulsefwhm = 5.0

    # Set up mock objects
    voronoi = mock_voronoi(atom_projections=[[0, 1, 2], [3]])
    pulse = create_pulse(pulsefreq)
    calc_kwargs = dict(response=mock_response(),
                       voronoi=None,
                       times=time_t * 1e3,  # presumably the calculator expects as; confirm
                       pulses=[pulse],
                       energies_occ=np.linspace(-5, 1, 10),
                       energies_unocc=np.linspace(-1, 5, 9),
                       sigma=0.05)
    calc = EnergyCalculator(**calc_kwargs)

    kwargs = dict(include_tcm=True, save_dist=True)

    # Calculate with and without Voronoi
    calc.calculate_and_write(str(tmp_path / 'energy.npz'), **kwargs)
    calc.calculate_and_write(str(tmp_path / 'energy.ulm'), **kwargs)

    calc_kwargs['voronoi'] = voronoi
    calc = EnergyCalculator(**calc_kwargs)
    calc.calculate_and_write(str(tmp_path / 'energy_voronoi.npz'), **kwargs)
    calc.calculate_and_write(str(tmp_path / 'energy_voronoi.ulm'), **kwargs)

    if world.rank != 0:
        return

    # Metadata written by both runs, paired with the values it must match
    meta_keys = ['pulsefreq', 'pulsefwhm', 'energy_o', 'energy_u', 'time_t', 'sigma']
    meta_ref = [pulsefreq, pulsefwhm, calc_kwargs['energies_occ'], calc_kwargs['energies_unocc'],
                time_t, calc_kwargs['sigma']]

    # Verify without Voronoi
    data_ulm = load_ulm(tmp_path / 'energy.ulm')
    # An NpzFile keeps its file open; the context manager closes it deterministically
    with np.load(tmp_path / 'energy.npz') as data_npz:
        compare_known(data_npz, data_ulm, meta_keys, meta_ref)
        compare(data_npz, data_ulm, ['dm_t', 'total_t', 'total_Hxc_t', 'Epulse_t', 'E_tou'])

        # Indexing an NpzFile loads the array into memory, so these remain valid after closing
        total_t = data_npz['total_t']
        total_Hxc_t = data_npz['total_Hxc_t']
        Epulse_t = data_npz['Epulse_t']

    # Verify with Voronoi; total energies must match the run without Voronoi
    data_ulm = load_ulm(tmp_path / 'energy_voronoi.ulm')
    with np.load(tmp_path / 'energy_voronoi.npz') as data_npz:
        compare_known(data_npz, data_ulm,
                      meta_keys + ['total_t', 'total_Hxc_t', 'Epulse_t'],
                      meta_ref + [total_t, total_Hxc_t, Epulse_t])
        compare(data_npz, data_ulm, ['total_proj_tII', 'total_Hxc_proj_tII'])