Coverage for tests/integration/test_spectrum.py: 85%

62 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3import pytest 

4import numpy as np 

5 

6from gpaw.mpi import world 

7from tests import get_permanent_test_file 

8from rhodent.spectrum import SpectrumCalculator 

9 

10 

def write_reference_data(ref_spectrum, dipolefile, perturbation, *,
                         frequencies=None, frequency_broadening=0.1):
    """Compute a spectrum from a dipole moment file and write it as reference.

    ``test_spectrum`` compares its broadened spectrum against this file, so
    the default frequency grid and broadening are the ones that test uses.

    Parameters
    ----------
    ref_spectrum
        File to write the reference spectrum to.
    dipolefile
        Dipole moment file to compute the spectrum from.
    perturbation
        Perturbation description, passed on to
        ``SpectrumCalculator.from_file``.
    frequencies
        Frequency grid. Defaults to ``np.arange(2, 5, 0.05)``.
    frequency_broadening
        Frequency broadening passed to ``calculate_spectrum_and_write``
        (0 means no broadening). Defaults to 0.1.

    Raises
    ------
    RuntimeError
        If run on more than one MPI rank.
    """
    # Uses the module-level ``world``. An explicit raise (unlike ``assert``)
    # also protects runs made with ``python -O``.
    if world.size != 1:
        raise RuntimeError('Run me in serial mode')

    if frequencies is None:
        frequencies = np.arange(2, 5, 0.05)

    calc = SpectrumCalculator.from_file(dipolefile, perturbation)
    calc.calculate_spectrum_and_write(ref_spectrum,
                                      frequencies=frequencies,
                                      frequency_broadening=frequency_broadening)

24 

25 

def _write_spectrum_files(calc, npzfname, datfname, frequencies,
                          frequency_broadening):
    """Write the spectrum of ``calc`` both as an npz archive and a dat file."""
    for fname in (npzfname, datfname):
        calc.calculate_spectrum_and_write(
            fname,
            frequencies=frequencies,
            frequency_broadening=frequency_broadening)


def _check_spectrum_files(calc, npzfname, datfname, broadening_text, name):
    """Check that the npz and dat spectrum files are consistent.

    Checks the header of the dat file (total time and time step, the
    broadening description ``broadening_text`` and the first line of the
    perturbation description). Checks that both files hold the same
    ``freq_w`` / ``osc_wv`` data. Checks that every perturbation parameter
    is stored in the npz file. ``name`` labels failure messages.
    """
    text = datfname.read_text()
    for expected in ('Total time = 30.0000 fs, Time steps = 20.00 as',
                     broadening_text,
                     str(calc.perturbation).split('\n')[0]):
        assert expected in text, f'{name}: {expected!r} not in {datfname}'

    data = np.loadtxt(datfname)
    # Context manager closes the NpzFile, so the file can be safely rewritten
    with np.load(npzfname) as archive:
        # Compare dat to npz: first column is frequency, the rest is osc_wv
        np.testing.assert_allclose(archive['freq_w'], data[:, 0],
                                   rtol=0, atol=1e-12, err_msg=name)
        np.testing.assert_allclose(archive['osc_wv'], data[:, 1:],
                                   err_msg=name)

        # Check that perturbation parameters are written to the npz file
        for key, value in calc.perturbation.todict().items():
            np.testing.assert_equal(archive[f'perturbation_{key}'], value,
                                    err_msg=f'{name}: perturbation_{key}')


def _compare_to_reference(npzfname, ref_fname, name):
    """Compare the broadened spectrum in ``npzfname`` to the reference file."""
    with np.load(npzfname) as archive, np.load(ref_fname) as ref_archive:
        np.testing.assert_equal(archive['freq_w'], ref_archive['freq_w'],
                                err_msg=name)
        np.testing.assert_equal(archive['frequency_broadening'],
                                ref_archive['frequency_broadening'],
                                err_msg=name)

        # Allow for some mismatch, since it is not exactly the same data.
        # The absolute tolerance is scaled by the largest reference value
        # over the full frequency range (taken before any filtering).
        ref_osc_wv = ref_archive['osc_wv']
        osc_wv = archive['osc_wv']
        vmax = np.abs(ref_osc_wv).max()
        if name == 'gauss':
            # The spectrum obtained from the Gaussian pulse is only reasonable
            # in a narrow window (around the pulse frequency 3.8), so only
            # that window is compared
            flt_w = (archive['freq_w'] > 3.65) & (archive['freq_w'] < 3.9)
            ref_osc_wv = ref_osc_wv[flt_w]
            osc_wv = osc_wv[flt_w]
        np.testing.assert_allclose(ref_osc_wv, osc_wv,
                                   err_msg=name, atol=1e-3 * vmax, rtol=0.1)


@pytest.mark.parametrize('test_system', ['Ag201CO'])
def test_spectrum(tmp_path, test_system):
    """Test that we can compute the spectrum from a dipole moment file.

    For a Gaussian pulse, a sinc pulse and a delta kick, the spectrum is
    written as npz and as dat, first without and then with broadening. The
    two files must agree and record the perturbation. The broadened
    spectrum is also compared to the reference spectrum, which the
    ``__main__`` block of this module generates from the delta-kick data.
    """
    frequencies = np.arange(2, 5, 0.05)
    # Must match the broadening used by write_reference_data, since the
    # broadened spectrum is compared to the reference file
    frequency_broadening = 0.1
    npzfname = tmp_path / 'spec.npz'
    datfname = tmp_path / 'spec.dat'

    perturbations = [
        ('gauss',
         {'name': 'GaussianPulse', 'strength': 1e-5, 'frequency': 3.8,
          'time0': 10e3, 'sigma': 0.34, 'sincos': 'cos'}),
        ('sinc',
         {'name': 'SincPulse', 'strength': 1e-6,
          'cutoff_freq': 8, 'time0': 5, 'relative_t0': True}),
        ('delta',
         {'name': 'deltakick', 'strength': 1e-5}),
    ]

    for name, perturbation in perturbations:
        dipolefile = get_permanent_test_file(test_system, f'dm_{name}')
        calc = SpectrumCalculator.from_file(dipolefile, perturbation)

        # Files are written on every rank; only rank 0 reads them back
        # Test no broadening
        _write_spectrum_files(calc, npzfname, datfname, frequencies, 0)
        if world.rank == 0:
            _check_spectrum_files(calc, npzfname, datfname,
                                  'No broadening', name)

        # Test with broadening
        _write_spectrum_files(calc, npzfname, datfname, frequencies,
                              frequency_broadening)
        if world.rank == 0:
            _check_spectrum_files(calc, npzfname, datfname,
                                  'Gaussian broadening', name)
            _compare_to_reference(
                npzfname,
                get_permanent_test_file(test_system, 'ref_spectrum'),
                name)

116 

117 

if __name__ == '__main__':
    # Regenerate the reference spectrum from the delta-kick dipole moment
    # file. write_reference_data requires a serial (single-rank) run.
    for test_system in ['Ag201CO']:
        dipolefile = get_permanent_test_file(test_system, 'dm_delta')
        ref_spectrum = get_permanent_test_file(test_system, 'ref_spectrum')
        write_reference_data(ref_spectrum, dipolefile,
                             {'name': 'deltakick', 'strength': 1e-5})