Coverage for tests/integration/test_spectrum.py: 85%
62 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3import pytest
4import numpy as np
6from gpaw.mpi import world
7from tests import get_permanent_test_file
8from rhodent.spectrum import SpectrumCalculator
def write_reference_data(ref_spectrum, dipolefile, perturbation, *,
                         frequencies=None, frequency_broadening=0.1):
    """Compute a spectrum from a dipole moment file and write it as reference data.

    The default frequency grid (``np.arange(2, 5, 0.05)``) and broadening (0.1)
    are the values ``test_spectrum`` uses when it compares its output against
    the reference file. Regenerated reference data therefore stays comparable.

    Parameters
    ----------
    ref_spectrum
        Path of the reference spectrum file to write.
    dipolefile
        Path of the dipole moment file to read.
    perturbation
        Perturbation description passed to ``SpectrumCalculator.from_file``.
    frequencies
        Frequency grid. Defaults to ``np.arange(2, 5, 0.05)``.
    frequency_broadening
        Broadening passed to ``calculate_spectrum_and_write``.

    Raises
    ------
    RuntimeError
        If run on more than one MPI rank.
    """
    # An explicit check rather than ``assert``, so the guard survives ``python -O``
    if world.size != 1:
        raise RuntimeError('Run me in serial mode')

    if frequencies is None:
        frequencies = np.arange(2, 5, 0.05)

    calc = SpectrumCalculator.from_file(dipolefile, perturbation)
    calc.calculate_spectrum_and_write(ref_spectrum,
                                      frequencies=frequencies,
                                      frequency_broadening=frequency_broadening)
def _write_and_check_spectrum(calc, tmp_path, frequencies, frequency_broadening,
                              broadening_label):
    """Write the spectrum of ``calc`` to spec.npz and spec.dat and cross-check them.

    ``calculate_spectrum_and_write`` is called on every rank. The checks on the
    written files run on rank 0 only.

    Parameters
    ----------
    calc
        Spectrum calculator whose spectrum is written.
    tmp_path
        Directory in which ``spec.npz`` and ``spec.dat`` are written.
    frequencies
        Frequency grid passed to ``calculate_spectrum_and_write``.
    frequency_broadening
        Broadening passed to ``calculate_spectrum_and_write``.
    broadening_label
        Text that must appear in the .dat file describing the broadening.
    """
    npz_path = tmp_path / 'spec.npz'
    dat_path = tmp_path / 'spec.dat'
    # The output format is chosen from the file extension: write npz, then dat
    for path in (npz_path, dat_path):
        calc.calculate_spectrum_and_write(path,
                                          frequencies=frequencies,
                                          frequency_broadening=frequency_broadening)

    if world.rank != 0:
        return

    # Header information in the dat file
    text = dat_path.read_text()
    assert 'Total time = 30.0000 fs, Time steps = 20.00 as' in text
    assert broadening_label in text
    assert str(calc.perturbation).split('\n')[0] in text

    data = np.loadtxt(dat_path)
    # Context manager closes the NpzFile before the file is overwritten again
    with np.load(npz_path) as archive:
        # Compare dat to npz
        np.testing.assert_allclose(archive['freq_w'], data[:, 0], rtol=0, atol=1e-12)
        np.testing.assert_allclose(archive['osc_wv'], data[:, 1:])

        # Check that perturbation parameters are written to the npz file
        for key, value in calc.perturbation.todict().items():
            np.testing.assert_equal(archive[f'perturbation_{key}'], value)


@pytest.mark.parametrize('test_system', ['Ag201CO'])
def test_spectrum(tmp_path, test_system):
    """Test that we can compute the spectrum from a dipole moment file.

    For each perturbation, the spectrum is written without broadening and with
    a broadening of 0.1. In both cases the .npz and .dat outputs are checked
    against each other. The broadened spectrum is also compared to stored
    reference data.
    """
    frequencies = np.arange(2, 5, 0.05)

    perturbations = [
        ('gauss',
         {'name': 'GaussianPulse', 'strength': 1e-5, 'frequency': 3.8,
          'time0': 10e3, 'sigma': 0.34, 'sincos': 'cos'}),
        ('sinc',
         {'name': 'SincPulse', 'strength': 1e-6,
          'cutoff_freq': 8, 'time0': 5, 'relative_t0': True}),
        ('delta',
         {'name': 'deltakick', 'strength': 1e-5}),
    ]

    for name, perturbation in perturbations:
        dipolefile = get_permanent_test_file(test_system, f'dm_{name}')
        calc = SpectrumCalculator.from_file(dipolefile, perturbation)

        _write_and_check_spectrum(calc, tmp_path, frequencies,
                                  frequency_broadening=0,
                                  broadening_label='No broadening')
        _write_and_check_spectrum(calc, tmp_path, frequencies,
                                  frequency_broadening=0.1,
                                  broadening_label='Gaussian broadening')

        if world.rank != 0:
            continue

        # spec.npz now holds the broadened spectrum; compare it to the reference
        ref_file = get_permanent_test_file(test_system, 'ref_spectrum')
        with np.load(tmp_path / 'spec.npz') as archive, np.load(ref_file) as ref_archive:
            freq_w = archive['freq_w']
            np.testing.assert_equal(freq_w, ref_archive['freq_w'])
            np.testing.assert_equal(archive['frequency_broadening'],
                                    ref_archive['frequency_broadening'])

            # Allow for some mismatch, since the reference is not computed from
            # exactly the same data. The __main__ block writes it from the
            # delta-kick dipole file.
            ref_osc_wv = ref_archive['osc_wv']
            osc_wv = archive['osc_wv']
            # Absolute tolerance is scaled by the largest reference value on the full grid
            vmax = np.abs(ref_osc_wv).max()
            if name == 'gauss':
                # The Gaussian-pulse spectrum is only reliable in a narrow window
                # around the pulse frequency (3.8), so compare only there
                flt_w = (freq_w > 3.65) & (freq_w < 3.9)
                ref_osc_wv = ref_osc_wv[flt_w]
                osc_wv = osc_wv[flt_w]
            np.testing.assert_allclose(ref_osc_wv, osc_wv,
                                       err_msg=name, atol=1e-3 * vmax, rtol=0.1)
if __name__ == '__main__':
    # Regenerate the reference spectrum that test_spectrum compares against.
    # write_reference_data refuses to run on more than one MPI rank.
    for test_system in ['Ag201CO']:
        dipolefile = get_permanent_test_file(test_system, 'dm_delta')
        ref_spectrum = get_permanent_test_file(test_system, 'ref_spectrum')
        write_reference_data(ref_spectrum, dipolefile,
                             {'name': 'deltakick', 'strength': 1e-5})