Coverage for tests/unittests/writers/test_tcm_writers.py: 96%

46 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3import pytest 

4import numpy as np 

5from ase.io.ulm import Reader 

6 

7from gpaw.mpi import world 

8from rhodent.calculators.dipole import DipoleCalculator 

9from rhodent.utils import create_pulse 

10 

11 

def load_ulm(fname):
    """Read the ULM file ``fname`` and return its contents as a dictionary.

    The file is opened with :class:`ase.io.ulm.Reader` used as a context
    manager, so the reader is closed again before this function returns.

    Parameters
    ----------
    fname
        Location of the ULM file; passed unchanged to ``Reader``.

    Returns
    -------
    The value of ``reader.asdict()``.
    """
    with Reader(fname) as reader:
        return reader.asdict()

15 

16 

def compare(leftdata, rightdata, keys):
    """Assert that two mappings hold numerically close values for ``keys``.

    Each key is looked up in both ``leftdata`` and ``rightdata`` and the two
    values are compared with :func:`numpy.testing.assert_allclose` using its
    default tolerances.

    Parameters
    ----------
    leftdata, rightdata
        Objects supporting ``obj[key]`` lookup for every key in ``keys``.
    keys
        Iterable of keys to compare.

    Raises
    ------
    AssertionError
        If the values for some key differ; the key is used as error message.
    """
    for key in keys:
        np.testing.assert_allclose(leftdata[key], rightdata[key], err_msg=key)

20 

21 

def compare_known(leftdata, rightdata, keys, reference):
    """Assert that two mappings agree with each other and with known values.

    For every ``(key, ref)`` pair, ``leftdata[key]`` is compared against
    ``rightdata[key]`` and against ``ref`` using
    :func:`numpy.testing.assert_allclose` with ``atol=1e-5``.

    Parameters
    ----------
    leftdata, rightdata
        Objects supporting ``obj[key]`` lookup for every key in ``keys``.
    keys
        Keys to compare.
    reference
        Expected values, one per key and in the same order as ``keys``.

    Raises
    ------
    ValueError
        If ``keys`` and ``reference`` have different lengths.
    AssertionError
        If any of the comparisons fails.
    """
    keys = list(keys)
    reference = list(reference)
    # zip() stops at the shorter input, which would silently skip checks
    if len(keys) != len(reference):
        raise ValueError(f'Got {len(keys)} keys but {len(reference)} reference values')

    # Loose tolerance to allow for tcm and time grids which are written out with less precision
    for key, ref in zip(keys, reference):
        np.testing.assert_allclose(leftdata[key], rightdata[key], atol=1e-5, err_msg=key)
        np.testing.assert_allclose(leftdata[key], ref, atol=1e-5, err_msg=f'{key} != reference')

27 

28 

@pytest.mark.parametrize('test_system', ['Na55'])
def test_tcm_time(tmp_path, mock_response):
    """Check that time-domain output written as ULM and as NPZ is consistent.

    A ``DipoleCalculator`` is set up with a list of times and one pulse, and
    its results are written (with ``include_tcm=True``) to ``tcm.ulm`` and
    ``tcm.npz`` inside ``tmp_path``. On rank 0 both files are read back; the
    grids and parameters are compared to each other and to the values passed
    in, and the ``dm_tv`` and ``dm_touv`` arrays are compared between the
    two files.
    """
    time_t = np.linspace(0, 30, 7)  # In fs

    # Set up mock objects
    pulse = create_pulse(4.0)
    calc_kwargs = dict(response=mock_response(),
                       voronoi=None,
                       # NOTE(review): the factor 1e3 presumably converts fs
                       # to the unit expected by DipoleCalculator -- confirm
                       times=time_t * 1e3,
                       pulses=[pulse],
                       energies_occ=np.linspace(-5, 1, 10),
                       energies_unocc=np.linspace(-1, 5, 10),
                       sigma=0.05)
    calc = DipoleCalculator(**calc_kwargs)
    kwargs = dict(include_tcm=True)

    # Every rank takes part in the calculation and the writing
    calc.calculate_and_write(str(tmp_path / 'tcm.ulm'), **kwargs)
    calc.calculate_and_write(str(tmp_path / 'tcm.npz'), **kwargs)

    # Only rank 0 reads the files back and checks them
    if world.rank != 0:
        return

    data_ulm = load_ulm(tmp_path / 'tcm.ulm')

    # np.load returns an NpzFile that keeps the archive open; the context
    # manager makes sure it is closed even if a comparison fails
    with np.load(tmp_path / 'tcm.npz') as data_npz:
        # NOTE(review): 4.0 is the value given to create_pulse above; 5.0 is
        # presumably the default FWHM of create_pulse -- confirm
        compare_known(data_npz, data_ulm,
                      ['pulsefreq', 'pulsefwhm', 'energy_o', 'energy_u', 'time_t', 'sigma'],
                      [4.0, 5.0, calc_kwargs['energies_occ'], calc_kwargs['energies_unocc'],
                       time_t, calc_kwargs['sigma']])
        compare(data_npz, data_ulm, ['dm_tv', 'dm_touv'])

58 

59 

@pytest.mark.parametrize('test_system', ['Na55'])
def test_tcm_frequency(tmp_path, mock_response):
    """Check that frequency-domain output written as ULM and as NPZ is consistent.

    A ``DipoleCalculator`` is set up with a list of frequencies, and its
    results are written (with ``include_tcm=True``) to ``tcm.ulm`` and
    ``tcm.npz`` inside ``tmp_path``. On rank 0 both files are read back; the
    grids and parameters are compared to each other and to the values passed
    in, and the ``dm_wv`` and ``dm_wouv`` arrays are compared between the
    two files.
    """
    freq_w = np.linspace(2, 6, 7)  # In units of eV

    # Set up mock objects
    calc_kwargs = dict(response=mock_response(),
                       voronoi=None,
                       frequencies=freq_w,
                       energies_occ=np.linspace(-5, 1, 10),
                       energies_unocc=np.linspace(-1, 5, 10),
                       sigma=0.05)
    calc = DipoleCalculator(**calc_kwargs)
    kwargs = dict(include_tcm=True)

    # Every rank takes part in the calculation and the writing
    calc.calculate_and_write(str(tmp_path / 'tcm.ulm'), **kwargs)
    calc.calculate_and_write(str(tmp_path / 'tcm.npz'), **kwargs)

    # Only rank 0 reads the files back and checks them
    if world.rank != 0:
        return

    data_ulm = load_ulm(tmp_path / 'tcm.ulm')

    # np.load returns an NpzFile that keeps the archive open; the context
    # manager makes sure it is closed even if a comparison fails
    with np.load(tmp_path / 'tcm.npz') as data_npz:
        compare_known(data_npz, data_ulm,
                      ['energy_o', 'energy_u', 'freq_w', 'sigma'],
                      [calc_kwargs['energies_occ'], calc_kwargs['energies_unocc'],
                       freq_w, calc_kwargs['sigma']])
        compare(data_npz, data_ulm, ['dm_wv', 'dm_wouv'])