Coverage for tests/integration/density_matrices/distributed/test_pulse.py: 96%

47 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3import pytest 

4import numpy as np 

5 

6 

@pytest.mark.bigdata
# NOTE(review): 'test_system' is not an argument of this function; presumably it
# overrides a fixture consumed by ksd_fname/frho_dname/wfs_fname — confirm in conftest.
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
def test_density_matrix_frho_against_wave_functions(ksd_fname, frho_dname, wfs_fname):
    """ Test that we can use the 'distributed' classes to compute pulse response
    density matrices from a wave function file. Tests that the following works:

    - Wave functions are read in time and transformed
      into the Kohn-Sham basis (KohnShamRhoWfsReader).
      Each MPI rank holds the same part of the density
      matrix at different times.
    - The MPI ranks exchange data so that each rank now
      holds a smaller part of the density matrix for
      contiguous times (AlltoallvTimeDistributor).
    - The matrices are elementwise convolved with the
      pulse (PulseConvolver).
    - Everything is collected to a large buffer on the
      root rank (collect_on_root).

    The result, together with its first and second time derivatives, is
    compared against the density matrices convolved from the
    FrequencyDensityMatrix files.

    Reading with different strides etc. is tested in unit tests with mock data.
    """
    from gpaw.lcaotddft.ksdecomposition import KohnShamDecomposition
    from gpaw.mpi import world
    from gpaw.tddft.units import as_to_au, au_to_as

    from rhodent.density_matrices.time import ConvolutionDensityMatricesFromFrequency
    from rhodent.density_matrices.distributed.pulse import PulseConvolver
    from rhodent.utils import add_fake_kpts, create_pulse

    ksd = KohnShamDecomposition(filename=ksd_fname)
    add_fake_kpts(ksd)
    imin, imax, amin, amax = ksd.ialims()
    # Number of i and a indices; used both for the buffer shape and the mask below
    n_i = imax - imin + 1
    n_a = amax - amin + 1

    # Set up pulses
    fwhm = 5.0
    pulses = [create_pulse(pf, fwhm, 10.0) for pf in [1.1, 1.3]]
    # Set up times to compute. This picks a few times at the end
    # which are guaranteed to align with the times in the wave function file
    times = np.arange(1, 3000, 20)[-10::2] * 10  # In units of as
    shape = (n_i, n_a, len(pulses), len(times))

    # Set up the KohnShamRhoWfsReader, AlltoallvTimeDistributor, and PulseConvolver
    perturbation = {'name': 'SincPulse', 'strength': 1e-6,
                    'cutoff_freq': 4, 'time0': 5, 'relative_t0': True}
    rho_nn_conv = PulseConvolver.from_parameters(
        wfs_fname=wfs_fname,
        ksd=ksd_fname,
        yield_re=True,
        yield_im=True,
        perturbation=perturbation,
        pulses=pulses,
        derivative_order_s=[0, 1, 2],
        filter_times=times * as_to_au)

    # Read, redistribute and convolve with the pulses; collect on the root rank
    wfs_buffer = rho_nn_conv.collect_on_root()

    if world.rank == 0:
        # Copy over data to plain numpy arrays, merging real and imaginary parts
        wfs_rho_iapt = wfs_buffer.real + 1j * wfs_buffer.imag  # Density matrix
        wfs_drho_iapt = wfs_buffer.real1 + 1j * wfs_buffer.imag1  # First derivative
        wfs_ddrho_iapt = wfs_buffer.real2 + 1j * wfs_buffer.imag2  # Second derivative
    else:
        # Only the root rank receives the collected buffer
        assert wfs_buffer is None

    # Set up the reader of the frequency-domain (frho) reference files
    perturbation = {'name': 'deltakick', 'strength': 1e-5}
    density_matrices = ConvolutionDensityMatricesFromFrequency(
        ksd=ksd_fname,
        frho_fmt=str(frho_dname / 'w{freq:05.2f}-{reim}.npy'),
        perturbation=perturbation,
        pulses=pulses,
        real=True,
        imag=True,
        derivative_order_s=[0, 1, 2],
        times=times)

    # Make sure that the times match
    np.testing.assert_almost_equal(rho_nn_conv.time_t * au_to_as, density_matrices.times)

    frho_rho_iapt = np.zeros(shape, dtype=complex)  # Density matrix
    frho_drho_iapt = np.zeros(shape, dtype=complex)  # First derivative
    frho_ddrho_iapt = np.zeros(shape, dtype=complex)  # Second derivative

    # Read files; each rank fills the (pulse, time) slots of the work it is given
    for work, dm in density_matrices:
        index = (..., work.globalp, work.globalt)
        # Need to take conjugate here
        # NOTE(review): presumably a sign-convention difference between the two
        # data sources — confirm.
        frho_rho_iapt[index] = dm.rho_ia.conj()
        frho_drho_iapt[index] = dm.drho_ia.conj()
        frho_ddrho_iapt[index] = dm.ddrho_ia.conj()

    # And sum to the root rank
    world.sum(frho_rho_iapt, 0)
    world.sum(frho_drho_iapt, 0)
    world.sum(frho_ddrho_iapt, 0)

    if world.rank > 0:
        # Let the root rank test
        return

    # Check that for each pulse, at each time, all the density matrix elements
    # are reasonably close
    comparisons = [
        ('plain density matrix', wfs_rho_iapt, frho_rho_iapt),
        ('first derivative', wfs_drho_iapt, frho_drho_iapt),
        ('second derivative', wfs_ddrho_iapt, frho_ddrho_iapt),
    ]
    for label, wfs_iapt, frho_iapt in comparisons:
        vmax = np.abs(frho_iapt).max()

        # Zero the wave-function elements at (row, col) with col <= row - amin,
        # i.e. a <= i - imin (a <= i when imin == 0). The frho reference is not
        # masked. NOTE(review): presumably these are the non-transition elements
        # the frho data does not contain — confirm.
        wfs_iapt[np.tril_indices(n=n_i, m=n_a, k=-amin)] = 0
        # NOTE(review): with rtol=1.0 the tolerance is 0.01 * vmax + |wfs_iapt|,
        # which is very permissive; kept as in the original.
        np.testing.assert_allclose(frho_iapt, wfs_iapt,
                                   err_msg=label, atol=0.01 * vmax, rtol=1.0)