Coverage for tests/unittests/writers/test_tcm_writers.py: 96%
46 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3import pytest
4import numpy as np
5from ase.io.ulm import Reader
7from gpaw.mpi import world
8from rhodent.calculators.dipole import DipoleCalculator
9from rhodent.utils import create_pulse
def load_ulm(fname):
    """Return the contents of the ULM file ``fname`` as a dictionary.

    The file is opened through ``Reader`` used as a context manager, so the
    reader is exited before the dictionary is handed back to the caller.
    """
    with Reader(fname) as ulm_file:
        contents = ulm_file.asdict()
    return contents
def compare(leftdata, rightdata, keys):
    """Assert that two mappings hold numerically close values for ``keys``.

    Each key is looked up in both ``leftdata`` and ``rightdata`` and the two
    values are checked with ``np.testing.assert_allclose`` using its default
    tolerances. The key is used as the failure message so that a mismatch
    can be traced back to the offending entry.
    """
    for name in keys:
        actual = leftdata[name]
        desired = rightdata[name]
        np.testing.assert_allclose(actual, desired, err_msg=name)
def compare_known(leftdata, rightdata, keys, reference):
    """Assert that two mappings agree with each other and with known values.

    For every key, the value in ``leftdata`` is compared against the value in
    ``rightdata`` and against the matching entry of ``reference``.

    Parameters
    ----------
    leftdata, rightdata
        Mappings supporting ``obj[key]`` lookup for every entry of ``keys``.
    keys
        Sequence of keys to check.
    reference
        Sequence of expected values, one per key and in the same order.

    Raises
    ------
    ValueError
        If ``keys`` and ``reference`` do not have the same number of entries.
    AssertionError
        If any comparison fails.
    """
    keys = list(keys)
    reference = list(reference)
    # zip() silently stops at the shorter input, which would leave trailing
    # keys unchecked; insist on a one-to-one pairing instead.
    if len(keys) != len(reference):
        raise ValueError(
            f'Got {len(keys)} keys but {len(reference)} reference values')

    # Loose tolerance to allow for tcm and time grids which are written out
    # with less precision
    for key, ref in zip(keys, reference):
        np.testing.assert_allclose(leftdata[key], rightdata[key], atol=1e-5, err_msg=key)
        np.testing.assert_allclose(leftdata[key], ref, atol=1e-5, err_msg=f'{key} != reference')
@pytest.mark.parametrize('test_system', ['Na55'])
def test_tcm_time(tmp_path, mock_response):
    """Write time-domain dipole data as ULM and NPZ and check they agree.

    The same calculator output is written to both file formats. The grids
    and parameters are compared to each other and to the known inputs, and
    the dipole arrays are compared between the two formats.

    NOTE(review): ``test_system`` is not an argument of this function;
    presumably it is consumed by the ``mock_response`` fixture - confirm.
    """
    time_t = np.linspace(0, 30, 7)  # In fs

    # Set up mock objects
    pulse = create_pulse(4.0)
    calc_kwargs = dict(response=mock_response(),
                       voronoi=None,
                       # NOTE(review): factor 1e3 presumably converts fs to the
                       # calculator's time unit - confirm
                       times=time_t * 1e3,
                       pulses=[pulse],
                       energies_occ=np.linspace(-5, 1, 10),
                       energies_unocc=np.linspace(-1, 5, 10),
                       sigma=0.05)
    calc = DipoleCalculator(**calc_kwargs)
    kwargs = dict(include_tcm=True)

    calc.calculate_and_write(str(tmp_path / 'tcm.ulm'), **kwargs)
    calc.calculate_and_write(str(tmp_path / 'tcm.npz'), **kwargs)

    # Only rank 0 reads the files back and performs the comparisons
    if world.rank != 0:
        return

    data_ulm = load_ulm(tmp_path / 'tcm.ulm')
    # np.load on an .npz archive returns an object holding an open file
    # handle; use it as a context manager so the handle is released.
    with np.load(tmp_path / 'tcm.npz') as data_npz:
        # NOTE(review): 5.0 is presumably the default FWHM of create_pulse
        compare_known(data_npz, data_ulm,
                      ['pulsefreq', 'pulsefwhm', 'energy_o', 'energy_u', 'time_t', 'sigma'],
                      [4.0, 5.0, calc_kwargs['energies_occ'], calc_kwargs['energies_unocc'],
                       time_t, calc_kwargs['sigma']])
        compare(data_npz, data_ulm, ['dm_tv', 'dm_touv'])
@pytest.mark.parametrize('test_system', ['Na55'])
def test_tcm_frequency(tmp_path, mock_response):
    """Write frequency-domain dipole data as ULM and NPZ and check they agree.

    The same calculator output is written to both file formats. The grids
    and parameters are compared to each other and to the known inputs, and
    the dipole arrays are compared between the two formats.

    NOTE(review): ``test_system`` is not an argument of this function;
    presumably it is consumed by the ``mock_response`` fixture - confirm.
    """
    freq_w = np.linspace(2, 6, 7)  # In units of eV

    # Set up mock objects
    calc_kwargs = dict(response=mock_response(),
                       voronoi=None,
                       frequencies=freq_w,
                       energies_occ=np.linspace(-5, 1, 10),
                       energies_unocc=np.linspace(-1, 5, 10),
                       sigma=0.05)
    calc = DipoleCalculator(**calc_kwargs)
    kwargs = dict(include_tcm=True)

    calc.calculate_and_write(str(tmp_path / 'tcm.ulm'), **kwargs)
    calc.calculate_and_write(str(tmp_path / 'tcm.npz'), **kwargs)

    # Only rank 0 reads the files back and performs the comparisons
    if world.rank != 0:
        return

    data_ulm = load_ulm(tmp_path / 'tcm.ulm')
    # np.load on an .npz archive returns an object holding an open file
    # handle; use it as a context manager so the handle is released.
    with np.load(tmp_path / 'tcm.npz') as data_npz:
        compare_known(data_npz, data_ulm,
                      ['energy_o', 'energy_u', 'freq_w', 'sigma'],
                      [calc_kwargs['energies_occ'], calc_kwargs['energies_unocc'],
                       freq_w, calc_kwargs['sigma']])
        compare(data_npz, data_ulm, ['dm_wv', 'dm_wouv'])