Coverage for tests/unittests/writers/test_energy_writers.py: 98%

44 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3import pytest 

4import numpy as np 

5from ase.io.ulm import Reader 

6 

7from gpaw.mpi import world 

8from rhodent.calculators.energy import EnergyCalculator 

9from rhodent.utils import create_pulse 

10 

11 

def load_ulm(fname):
    """Return the contents of the ULM file *fname* as a dict.

    The file is opened with ``ase.io.ulm.Reader`` and the result of
    ``Reader.asdict()`` is returned; the reader's ``with`` block has been
    exited by the time this function returns.
    """
    with Reader(fname) as ulm_reader:
        contents = ulm_reader.asdict()
    return contents

15 

16 

def compare(leftdata, rightdata, keys):
    """Assert that *leftdata* and *rightdata* agree for every name in *keys*.

    Each pair of entries is checked with ``np.testing.assert_allclose`` using
    its default tolerances. The offending key is passed as the error message,
    so a failure identifies which entry differs.
    """
    for name in keys:
        lhs = leftdata[name]
        rhs = rightdata[name]
        np.testing.assert_allclose(lhs, rhs, err_msg=name)

20 

21 

def compare_known(leftdata, rightdata, keys, reference, atol=1e-5):
    """Assert that two data sets agree with each other and with known values.

    For each key in *keys*, paired positionally with an entry of *reference*,
    ``leftdata[key]`` is compared both against ``rightdata[key]`` and against
    the reference value using ``np.testing.assert_allclose``.

    Parameters
    ----------
    leftdata, rightdata
        Mappings indexable by the names in *keys*.
    keys
        Names of the entries to check.
    reference
        Expected values, one per key, in the same order as *keys*.
    atol
        Absolute tolerance for both comparisons (default ``1e-5``).

    Raises
    ------
    ValueError
        If *keys* and *reference* have different lengths. Without this check
        ``zip`` would silently drop the unmatched trailing entries, and those
        keys would go unchecked.
    AssertionError
        If any comparison fails; the message names the offending key.
    """
    keys = list(keys)
    reference = list(reference)
    if len(keys) != len(reference):
        raise ValueError(f'Got {len(keys)} keys but {len(reference)} reference values')

    # Loose tolerance to allow for energy and time grids which are written out with less precision
    for key, ref in zip(keys, reference):
        np.testing.assert_allclose(leftdata[key], rightdata[key], atol=atol, err_msg=key)
        np.testing.assert_allclose(leftdata[key], ref, atol=atol, err_msg=f'{key} != reference')

27 

28 

@pytest.mark.parametrize('test_system', ['Na55'])
def test_energy(tmp_path, mock_voronoi, mock_response):
    """Check that EnergyCalculator writes consistent .npz and .ulm output.

    The calculation is written to both formats once without Voronoi
    projections and once with them. On MPI rank 0, the two formats are
    compared with each other and, where known, with the input parameters.
    The totals from the run without Voronoi serve as references for the
    run with Voronoi.
    """
    time_t = np.linspace(0, 30, 7)  # In fs

    # Set up mock objects
    voronoi = mock_voronoi(atom_projections=[[0, 1, 2], [3]])
    pulse = create_pulse(4.0)
    calc_kwargs = dict(response=mock_response(),
                       voronoi=None,
                       times=time_t * 1e3,  # NOTE(review): presumably fs -> as; confirm calculator units
                       pulses=[pulse],
                       energies_occ=np.linspace(-5, 1, 10),
                       energies_unocc=np.linspace(-1, 5, 9),
                       sigma=0.05)
    calc = EnergyCalculator(**calc_kwargs)

    kwargs = dict(include_tcm=True, save_dist=True)

    # Calculate with and without Voronoi
    calc.calculate_and_write(str(tmp_path / 'energy.npz'), **kwargs)
    calc.calculate_and_write(str(tmp_path / 'energy.ulm'), **kwargs)

    calc_kwargs['voronoi'] = voronoi
    calc = EnergyCalculator(**calc_kwargs)
    calc.calculate_and_write(str(tmp_path / 'energy_voronoi.npz'), **kwargs)
    calc.calculate_and_write(str(tmp_path / 'energy_voronoi.ulm'), **kwargs)

    # Only rank 0 reads back and verifies the written files
    if world.rank != 0:
        return

    # Known scalar/grid references shared by both verifications.
    # NOTE(review): 5.0 is presumably create_pulse's default FWHM — confirm.
    known_keys = ['pulsefreq', 'pulsefwhm', 'energy_o', 'energy_u', 'time_t', 'sigma']
    known_values = [4.0, 5.0, calc_kwargs['energies_occ'], calc_kwargs['energies_unocc'],
                    time_t, calc_kwargs['sigma']]

    # Verify without Voronoi
    data_ulm = load_ulm(tmp_path / 'energy.ulm')
    # NpzFile keeps the archive open until closed; the context manager releases it
    with np.load(tmp_path / 'energy.npz') as data_npz:
        compare_known(data_npz, data_ulm, known_keys, known_values)
        compare(data_npz, data_ulm, ['dm_t', 'total_t', 'total_Hxc_t', 'Epulse_t', 'E_tou'])

        # Read into memory now so they remain usable after the archive closes
        total_t = data_npz['total_t']
        total_Hxc_t = data_npz['total_Hxc_t']
        Epulse_t = data_npz['Epulse_t']

    # Verify with Voronoi
    data_ulm = load_ulm(tmp_path / 'energy_voronoi.ulm')
    with np.load(tmp_path / 'energy_voronoi.npz') as data_npz:
        compare_known(data_npz, data_ulm,
                      known_keys + ['total_t', 'total_Hxc_t', 'Epulse_t'],
                      known_values + [total_t, total_Hxc_t, Epulse_t])
        compare(data_npz, data_ulm, ['total_proj_tII', 'total_Hxc_proj_tII'])