Coverage for tests/unittests/density_matrices/distributed/test_distr_time.py: 96%

26 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3import os 

4import pytest 

5import numpy as np 

6 

7from gpaw.tddft.units import fs_to_au 

8from gpaw.mpi import world, SerialCommunicator 

9from rhodent.density_matrices.distributed.time import TimeDistributor, AlltoallvTimeDistributor, RhoParameters 

10 

11 

@pytest.mark.parametrize('maxsize', [5000, None])
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
def test_alltoallv(mock_ks_rho_reader, ksd_fname, maxsize, monkeypatch):
    """Check that AlltoallvTimeDistributor reproduces the TimeDistributor result.

    A reference density matrix is collected with ``TimeDistributor`` fed by a
    reader built with a ``SerialCommunicator`` and coarse strides (300).  It is
    compared, on the root rank only, against the result of
    ``AlltoallvTimeDistributor`` fed by a reader on the default communicator
    with fine strides (30).

    Parameters
    ----------
    mock_ks_rho_reader
        Fixture; a callable that builds mock KS density-matrix readers
        (accepts ``filter_times`` and optionally ``comm``).
    ksd_fname
        Fixture requested only for its setup; presumably it provides the KSD
        data used by ``mock_ks_rho_reader`` -- confirm in conftest.
    maxsize
        Value exported as ``RHODENT_REDISTRIBUTE_MAXSIZE``; ``None`` means the
        variable is unset.
    monkeypatch
        Pytest fixture, used so the environment change is undone after the test.

    Notes
    -----
    ``test_system`` is parametrized but is not an argument of this function.
    Presumably it is consumed by one of the fixtures (e.g. ``ksd_fname``) --
    confirm in conftest.
    """
    from itertools import zip_longest

    # monkeypatch restores the environment after the test; assigning to
    # os.environ directly would leak the value into later tests.
    if maxsize is None:
        monkeypatch.delenv('RHODENT_REDISTRIBUTE_MAXSIZE', raising=False)
    else:
        monkeypatch.setenv('RHODENT_REDISTRIBUTE_MAXSIZE', str(maxsize))

    filter_times = np.linspace(0, 30, 15) * fs_to_au

    # Set up mock KS density matrices readers
    serial_rho_chunk_reader = mock_ks_rho_reader(filter_times=filter_times,
                                                 comm=SerialCommunicator())
    rho_chunk_reader = mock_ks_rho_reader(filter_times=filter_times)

    # Set up the serial (reference) and alltoallv distributors
    serial_parameters = RhoParameters.from_ksd(rho_chunk_reader.ksd, comm=rho_chunk_reader.comm,
                                               striden1=300, striden2=300)
    serial_distributor = TimeDistributor(serial_rho_chunk_reader, serial_parameters)
    # NOTE(review): the reason for this barrier is not evident from this file;
    # presumably it synchronizes ranks before the alltoallv distributor is built.
    world.barrier()
    alltoallv_parameters = RhoParameters.from_ksd(rho_chunk_reader.ksd, striden1=30, striden2=30)
    time_distributor = AlltoallvTimeDistributor(rho_chunk_reader, alltoallv_parameters)

    # Collect on root. This calculates different chunks on different ranks
    # and gathers and groups the results on the root rank
    ref_full_dm = serial_distributor.collect_on_root()
    test_full_dm = time_distributor.collect_on_root()

    if world.rank != 0:
        return

    # Check that data is the same. zip() would silently stop at the shorter
    # iterator and hide a missing buffer. zip_longest with a sentinel catches
    # a count mismatch while still consuming both iterators in lockstep.
    missing = object()
    for test_array_iat, ref_array_iat in zip_longest(test_full_dm._iter_buffers(),
                                                      ref_full_dm._iter_buffers(),
                                                      fillvalue=missing):
        assert test_array_iat is not missing and ref_array_iat is not missing, \
            'alltoallv and reference distributors yield different numbers of buffers'
        np.testing.assert_allclose(test_array_iat, ref_array_iat)