Coverage for tests/unittests/density_matrices/readers/test_gpaw.py: 89%
427 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3import numpy as np
4import pytest
5from itertools import product
7from gpaw.mpi import world, SerialCommunicator
9from rhodent.density_matrices.readers.gpaw import KohnShamRhoWfsReader
10from rhodent.density_matrices.distributed.time import TimeDistributor, AlltoallvTimeDistributor, RhoParameters
12from tests import wrap_test, get_permanent_test_file
def write_reference_data(ref_density_matrix, gpw_fname, ksd_fname, wfssnap_fname):
    """Compute reference density matrices and store them in a compressed npz archive.

    Has to be run on a single rank. The archive contains the keys ``f_n``,
    ``rho0_nn``, ``rho_nnt`` and ``testt``.

    Parameters
    ----------
    ref_density_matrix
        Output file name, handed to ``np.savez_compressed``.
    gpw_fname
        Not used by this function; kept so the call signature is unchanged.
    ksd_fname
        Passed as ``ksd`` to ``KohnShamRhoWfsReader``.
    wfssnap_fname
        Passed as ``wfs_fname`` to ``KohnShamRhoWfsReader``.
    """
    assert world.size == 1, 'Run me in serial mode'

    testt = [0, 1, 3]

    wfs_strides = dict(stridet=2, striden=0)
    stride_opts = dict(striden1=300, striden2=300, only_ia=False)

    rho_reader = KohnShamRhoWfsReader(wfs_fname=wfssnap_fname, ksd=ksd_fname, **wfs_strides)
    parameters = RhoParameters.from_ksd(rho_reader.ksd, **stride_opts)
    time_distributor = AlltoallvTimeDistributor(rho_reader, parameters)

    # Get the ground state stuff.
    # rho0 is taken from the reader held by the distributor, the same way
    # the tests in this file access it.
    f_n = time_distributor.ksd.reader.occ_un[0, 0]
    rho0_nn = time_distributor.rho_wfs_reader.rho0_sknn[0, 0]
    nn = rho0_nn.shape[0]

    rho_nnt = np.zeros((nn, nn, len(testt)), dtype=complex)
    collect_rho_nn(time_distributor, rho_nnt, testt)

    save = dict(f_n=f_n, rho0_nn=rho0_nn, rho_nnt=rho_nnt, testt=testt)

    np.savez_compressed(ref_density_matrix, **save)
    print(f'Saved data to {ref_density_matrix}')
def read_ksd(ksd_fname: str):
    """Read the overlap matrix and ground state coefficients from a ksd file.

    Returns a dict with the keys ``'S_MM'`` and ``'C0_nM'``, taken from the
    ``S_uMM`` and ``C0_unM`` entries of the file.
    """
    from ase.io.ulm import Reader

    with Reader(ksd_fname) as ksd:
        S_MM = ksd.S_uMM[0, 0, 0][:]
        C0_nM = ksd.C0_unM[0, 0, 0][:]

    return {'S_MM': S_MM, 'C0_nM': C0_nM}
def assert_orthonormal(C_nM, S_MM):
    """Assert that the rows of ``C_nM`` are orthonormal with respect to ``S_MM``.

    Every element of ``C_nM @ S_MM @ C_nM^H`` must be within 1e-12 of the
    corresponding element of the identity matrix.
    """
    overlap_nn = C_nM @ S_MM @ C_nM.T.conj()
    deviation_nn = np.abs(overlap_nn - np.eye(len(overlap_nn)))
    assert np.max(deviation_nn) < 1e-12
def _test_read_C_nM(ksd_fname, wfssnap_fname, strideM, striden):
    """Read wave function coefficients with ``WfsReader`` and check them.

    The ground state coefficients and the coefficients at the times
    ``testt`` must be orthonormal with respect to the overlap matrix of the
    ksd file, and the two ways of collecting the data (``collect_C_nM`` and
    ``collect_C_nM_gather_on_root``) must agree. The checks run on rank 0.
    """
    from rhodent.density_matrices.readers.gpaw import WfsReader

    is_root = world.rank == 0
    if is_root:
        S_MM = read_ksd(ksd_fname)['S_MM']

    testt = [0, 1, 3]
    ntimes = len(testt)

    wfs_reader = WfsReader(wfs_fname=wfssnap_fname, stridet=2)

    # Ground state coefficients
    C0_nM = wfs_reader.C0_sknM[0, 0]
    nn, nM = C0_nM.shape
    if is_root:
        assert_orthonormal(C0_nM, S_MM)

    # Read block by block, gather by world.sum
    C_nMt = np.zeros((nn, nM, ntimes), dtype=complex)
    collect_C_nM(wfs_reader, striden, strideM, C_nMt, testt)
    if is_root:
        for t in range(ntimes):
            assert_orthonormal(C_nMt[..., t], S_MM)

    # Gather through the reader; only rank 0 holds a receive buffer
    C_root_nMt = None
    if is_root:
        C_root_nMt = np.zeros((nn, nM, ntimes), dtype=complex)
    collect_C_nM_gather_on_root(wfs_reader, striden, strideM, C_root_nMt, testt)
    if is_root:
        assert np.allclose(C_nMt, C_root_nMt)
def _test_read_rho_MM(ksd_fname, wfssnap_fname, strideM1, strideM2, striden):
    """Read density matrices with ``LCAORhoWfsReader`` and check them.

    Checks on rank 0: the ground state density matrix reproduces the
    occupation numbers, the matrices collected at the times ``testt`` pass
    the Hermiticity check, and the two ways of collecting the data
    (``collect_rho_MM`` and ``collect_rho_MM_gather_on_root``) agree.
    """
    from rhodent.density_matrices.readers.gpaw import LCAORhoWfsReader

    is_root = world.rank == 0
    if is_root:
        S_MM = read_ksd(ksd_fname)['S_MM']

    testt = [0, 1, 3]
    ntimes = len(testt)

    rho_reader = LCAORhoWfsReader(wfs_fname=wfssnap_fname, stridet=2, striden=striden)

    # Ground state quantities
    f_n = rho_reader.f_skn[0, 0]
    C0_nM = rho_reader.C0_sknM[0, 0]
    rho0_MM = rho_reader.rho0_skMM[0, 0]
    assert np.isrealobj(rho0_MM)
    assert np.isrealobj(C0_nM)

    nn, nM = C0_nM.shape
    diag_mask_nn = np.eye(nn, dtype=bool)

    if is_root:
        C0S_nM = C0_nM @ S_MM

        # The diagonal of the transformed ground state density matrix
        # should equal the occupation numbers
        rho0_nn = C0S_nM @ (C0S_nM @ rho0_MM).swapaxes(0, 1)
        rho0_n = rho0_nn[diag_mask_nn]
        assert np.allclose(rho0_n, f_n, rtol=1e-10), rho0_n

    # Read block by block, gather by world.sum
    drho_MMt = np.zeros((nM, nM, ntimes), dtype=complex)
    collect_rho_MM(rho_reader, strideM1, strideM2, drho_MMt, testt)

    if is_root:
        # Matrix should be Hermitian
        # NOTE(review): the differences are summed before the absolute value
        # is taken, so deviations of opposite sign can cancel. Consider
        # np.max(np.abs(...)) after confirming the tolerance on real data.
        antihermitian_MMt = drho_MMt.swapaxes(0, 1).conj() - drho_MMt
        error = np.abs(np.sum(antihermitian_MMt))
        assert error < 1e-15, f'rho not Hermitian, {error:.5e}'

    # Gather through the reader; only rank 0 holds a receive buffer
    drho_root_MMt = None
    if is_root:
        drho_root_MMt = np.zeros((nM, nM, ntimes), dtype=complex)
    collect_rho_MM_gather_on_root(rho_reader, strideM1, strideM2, drho_root_MMt, testt)

    if is_root:
        assert np.allclose(drho_MMt, drho_root_MMt)
def collect_C_nM(wfs_reader, striden, strideM, C_nMt, testt):
    """Fill ``C_nMt`` block by block with the coefficients at the times ``testt``.

    The n and M ranges are covered in blocks of ``striden`` and ``strideM``.
    Each block is read with ``wfs_reader.iread`` for s = 0 and k = 0, and the
    buffers whose global time index is in ``testt`` are accumulated in
    ``C_nMt``. Finally ``world.sum(C_nMt, 0)`` combines the contributions of
    the ranks. ``C_nMt`` is modified in place; nothing is returned.
    """
    C_nMt[:] = 0
    nn, nM = wfs_reader.C0_sknM.shape[2:]
    sentinel = object()

    blocks = product(range(1), range(1), range(0, nn, striden), range(0, nM, strideM))
    for s, k, n_start, M_start in blocks:
        n = slice(n_start, n_start + striden)
        M = slice(M_start, M_start + strideM)
        buffers = wfs_reader.iread(s, k, n, M)

        for globalt in wfs_reader.work_loop(wfs_reader.comm.rank):
            if globalt is None:
                continue
            # A buffer has to be consumed for every time, also the ones
            # that are not stored
            dm_buffer = next(buffers)
            if globalt in testt:
                t = testt.index(globalt)
                C_x = dm_buffer.real + 1j * dm_buffer.imag
                # The last block in each direction can be smaller than the stride
                rows, cols, _ = C_nMt[n, M, :].shape
                C_nMt[n, M, t] += C_x[:rows, :cols]

        # The reader must not yield more buffers than the work loop asked for
        assert next(buffers, sentinel) is sentinel

    world.sum(C_nMt, 0)
def collect_C_nM_gather_on_root(wfs_reader, striden, strideM, C_nMt, testt):
    """Fill ``C_nMt`` on rank 0 using ``wfs_reader.collect_on_root``.

    The full n and M ranges are collected for s = 0 and k = 0. On rank 0 the
    entries selected by ``testt`` are written to ``C_nMt`` in place; on the
    other ranks nothing is collected and ``C_nMt`` is not touched.
    ``striden`` and ``strideM`` are not used.
    """
    nn, nM = wfs_reader.C0_sknM.shape[2:]
    gathered = wfs_reader.collect_on_root(0, 0, slice(0, nn), slice(0, nM))

    if world.rank == 0:
        assert gathered is not None
        selected = gathered[testt]
        C_nMt[:] = selected.real + 1.0j * selected.imag
    else:
        assert gathered is None
def collect_rho_MM(rho_MM_reader, strideM1, strideM2, rho_MMt, testt):
    """Fill ``rho_MMt`` block by block with the density matrices at the times ``testt``.

    Both M ranges are covered in blocks of ``strideM1`` and ``strideM2``.
    Each block is read with ``rho_MM_reader.iread`` for s = 0 and k = 0, and
    the buffers whose global time index is in ``testt`` are accumulated in
    ``rho_MMt``. Finally ``world.sum(rho_MMt, 0)`` combines the contributions
    of the ranks. ``rho_MMt`` is modified in place; nothing is returned.
    """
    rho_MMt[:] = 0
    nM = rho_MM_reader.C0_sknM.shape[3]
    sentinel = object()

    blocks = product(range(1), range(1), range(0, nM, strideM1), range(0, nM, strideM2))
    for s, k, M1_start, M2_start in blocks:
        M1 = slice(M1_start, M1_start + strideM1)
        M2 = slice(M2_start, M2_start + strideM2)
        buffers = rho_MM_reader.iread(s, k, M1, M2)

        for globalt in rho_MM_reader.work_loop(rho_MM_reader.comm.rank):
            if globalt is None:
                continue
            # A buffer has to be consumed for every time, also the ones
            # that are not stored
            dm_buffer = next(buffers)
            if globalt in testt:
                t = testt.index(globalt)
                rho_x = dm_buffer.real + 1j * dm_buffer.imag
                # The last block in each direction can be smaller than the stride
                rows, cols, _ = rho_MMt[M1, M2, :].shape
                rho_MMt[M1, M2, t] += rho_x[:rows, :cols]

        # The reader must not yield more buffers than the work loop asked for
        assert next(buffers, sentinel) is sentinel

    world.sum(rho_MMt, 0)
def collect_rho_MM_gather_on_root(rho_MM_reader, strideM1, strideM2, rho_MMt, testt):
    """Fill ``rho_MMt`` on rank 0 using ``rho_MM_reader.collect_on_root``.

    The full M ranges are collected for s = 0 and k = 0. On rank 0 the
    entries selected by ``testt`` are written to ``rho_MMt`` in place; on the
    other ranks nothing is collected and ``rho_MMt`` is not touched.
    ``strideM1`` and ``strideM2`` are not used.
    """
    nM = rho_MM_reader.C0_sknM.shape[3]
    gathered = rho_MM_reader.collect_on_root(0, 0, slice(0, nM), slice(0, nM))

    if world.rank == 0:
        assert gathered is not None
        selected = gathered[testt]
        rho_MMt[:] = selected.real + 1.0j * selected.imag
    else:
        assert gathered is None
def collect_rho_nn(time_distributor, rho_nnt, testt):
    """Fill ``rho_nnt`` by iterating over ``time_distributor``.

    For each work item of this rank one buffer is taken from the
    distributor, the times ``testt`` are selected, and the result is
    accumulated in the ``(n1, n2)`` block given by the work item. Finally
    ``world.sum(rho_nnt, 0)`` combines the contributions of the ranks.
    ``rho_nnt`` is modified in place; nothing is returned.
    """
    rho_nnt[:] = 0
    sentinel = object()

    buffers = iter(time_distributor)
    for indices in time_distributor.work_loop(time_distributor.comm.rank):
        if indices is None:
            continue
        selected = next(buffers)[testt]
        rho_xt = selected.real + 1j*selected.imag
        s, k, n1, n2 = indices
        assert s == 0
        assert k == 0
        # The last block in each direction can be smaller than the buffer
        rows, cols, _ = rho_nnt[n1, n2, :].shape
        rho_nnt[n1, n2, :] += rho_xt[:rows, :cols]

    # The distributor must not yield more buffers than the work loop asked for
    assert next(buffers, sentinel) is sentinel
    world.sum(rho_nnt, 0)
def collect_rho_nn_gather_on_root(time_distributor, rho_nnt, testt):
    """Fill ``rho_nnt`` on rank 0 using ``time_distributor.collect_on_root``.

    On rank 0 the entries selected by ``testt`` are written to ``rho_nnt``
    in place; on the other ranks nothing is collected and ``rho_nnt`` is not
    touched.
    """
    gathered = time_distributor.collect_on_root()

    if world.rank == 0:
        assert gathered is not None
        selected = gathered[testt]
        rho_nnt[:] = selected.real + 1.0j * selected.imag
    else:
        assert gathered is None
def _test_read_rho_nn(ksd_fname, wfssnap_fname, striden, striden1, striden2):
    """Read density matrices in the Kohn-Sham basis and check them on rank 0.

    With ``striden == 0`` the ``AlltoallvTimeDistributor`` is used; otherwise
    the reader is created with a ``SerialCommunicator`` and wrapped in a
    ``TimeDistributor``. The ground state density matrix must be diagonal
    with the occupation numbers on the diagonal, and the matrices collected
    at the times ``testt`` must have a constant trace and small elements.
    """
    testt = [0, 1, 3]
    stride_opts = dict(striden1=striden1, striden2=striden2, only_ia=False)

    if striden != 0:
        rho_reader = KohnShamRhoWfsReader(wfs_fname=wfssnap_fname, ksd=ksd_fname,
                                          comm=SerialCommunicator(), stridet=2, striden=striden)
        parameters = RhoParameters.from_ksd(rho_reader.ksd, comm=rho_reader.comm, **stride_opts)
        time_distributor = TimeDistributor(rho_reader, parameters)
    else:
        rho_reader = KohnShamRhoWfsReader(wfs_fname=wfssnap_fname, ksd=ksd_fname,
                                          stridet=2, striden=striden)
        parameters = RhoParameters.from_ksd(rho_reader.ksd, **stride_opts)
        time_distributor = AlltoallvTimeDistributor(rho_reader, parameters)

    # Ground state quantities
    f_n = time_distributor.ksd.reader.occ_un[0, 0]
    rho0_nn = time_distributor.rho_wfs_reader.rho0_sknn[0, 0]
    nn = rho0_nn.shape[0]
    assert np.isrealobj(rho0_nn)

    # Note that here, nn is the index of the states in the full KS basis
    diag_mask_nn = np.eye(nn, dtype=bool)
    is_root = world.rank == 0

    if is_root:
        # Occupations on the diagonal, zeros everywhere else
        np.testing.assert_allclose(rho0_nn[diag_mask_nn], f_n, rtol=1e-20, atol=1e-20)
        np.testing.assert_allclose(rho0_nn[~diag_mask_nn], 0, rtol=1e-20, atol=1e-20)

    rho_nnt = np.zeros((nn, nn, len(testt)), dtype=complex)
    collect_rho_nn_gather_on_root(time_distributor, rho_nnt, testt)

    if is_root:
        # Trace should be constant
        trace_t = np.trace(rho_nnt, axis1=0, axis2=1)
        np.testing.assert_allclose(trace_t, trace_t[0], rtol=1e-05, atol=1e-08)

        # The following will depend very much on the particular system
        # Close to 1e-10 on the diagonal
        np.testing.assert_allclose(rho_nnt[diag_mask_nn], 0, rtol=0, atol=1e-7)
        # Close to 1e-5 on the off-diagonal
        np.testing.assert_allclose(rho_nnt[~diag_mask_nn], 0, rtol=0, atol=1e-3)
def _test_read_rho_nn_against_reference(ref_density_matrix, ksd_fname, wfssnap_fname,
                                        striden, striden1, striden2):
    """Compare density matrices in the Kohn-Sham basis against stored reference data.

    The reference archive must contain the keys ``testt``, ``f_n``,
    ``rho0_nn`` and ``rho_nnt``. The data is read with a ``TimeDistributor``
    on a ``SerialCommunicator`` and, when ``striden == 0``, also with an
    ``AlltoallvTimeDistributor``. In both cases it is collected with
    ``collect_rho_nn`` and with ``collect_rho_nn_gather_on_root``. The
    comparisons run on rank 0.
    """
    # np.load keeps an npz archive open and reads its members on access.
    # Read what is needed up front and close the file again.
    with np.load(ref_density_matrix) as reference:
        testt = reference['testt']
        ref_f_n = reference['f_n']
        ref_rho0_nn = reference['rho0_nn']
        ref_rho_nnt = reference['rho_nnt']

    # Work array that is reused (and overwritten) by every comparison
    rho_nnt = np.zeros_like(ref_rho_nnt)

    wfs_strides = dict(stridet=2, striden=striden)
    stride_opts = dict(striden1=striden1, striden2=striden2, only_ia=False)

    def read_and_compare(time_distributor, collect):
        # Get the ground state stuff
        f_n = time_distributor.ksd.reader.occ_un[0, 0]
        rho0_nn = time_distributor.rho_wfs_reader.rho0_sknn[0, 0]
        assert np.isrealobj(rho0_nn)

        collect(time_distributor, rho_nnt, testt)

        if world.rank == 0:
            assert np.allclose(f_n, ref_f_n)
            assert np.allclose(rho0_nn, ref_rho0_nn)
            assert np.allclose(rho_nnt, ref_rho_nnt)

    # Reader on a serial communicator, gather by using world.sum
    rho_reader = KohnShamRhoWfsReader(wfs_fname=wfssnap_fname, ksd=ksd_fname,
                                      comm=SerialCommunicator(), **wfs_strides)
    parameters = RhoParameters.from_ksd(rho_reader.ksd, comm=rho_reader.comm, **stride_opts)
    time_distributor = TimeDistributor(rho_reader, parameters)
    read_and_compare(time_distributor, collect_rho_nn)

    # Same distributor, gather properly
    read_and_compare(time_distributor, collect_rho_nn_gather_on_root)

    # The alltoallv distributor is only exercised for striden == 0
    if striden != 0:
        return

    rho_reader = KohnShamRhoWfsReader(wfs_fname=wfssnap_fname, ksd=ksd_fname, **wfs_strides)
    parameters = RhoParameters.from_ksd(rho_reader.ksd, **stride_opts)
    time_distributor = AlltoallvTimeDistributor(rho_reader, parameters)
    read_and_compare(time_distributor, collect_rho_nn)

    read_and_compare(time_distributor, collect_rho_nn_gather_on_root)
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
def test_read_C_nM_strides(wfssnap_fname):
    """ Test reading the time-dependent wave functions file with different strides """
    from rhodent.density_matrices.readers.gpaw import WfsReader

    iread_args = (0, 0, slice(0, 10), slice(0, 10))  # s, k, n, M to read. 10x10 elements is enough

    def gather(reader):
        # Collect all times; the list stays empty on ranks other than 0
        C_tnM = []
        for dm_buffer in reader.gather_on_root(*iread_args):
            if reader.comm.rank > 0:
                assert dm_buffer is None
            else:
                C_tnM.append(dm_buffer.real + 1j * dm_buffer.imag)
        return np.array(C_tnM)

    # Read every time step as reference
    ref_reader = WfsReader(wfs_fname=wfssnap_fname, stridet=1)
    dt = ref_reader.dt
    reftimes = ref_reader.time_t
    ref_C_tnM = gather(ref_reader)

    if ref_reader.comm.rank == 0:
        nreftimes = len(ref_C_tnM)
        assert nreftimes > 8, nreftimes  # Make sure that it is actually reading

    # Read, skipping few steps at a time
    for stridet in [2, 3, 4]:
        reader = WfsReader(wfs_fname=wfssnap_fname, stridet=stridet)
        test_C_tnM = gather(reader)

        if reader.comm.rank > 0:
            continue

        # Make sure that exactly the right values are being read
        expected_C_tnM = ref_C_tnM[::stridet]
        np.testing.assert_almost_equal(reader.dt, dt * stridet)
        np.testing.assert_equal(len(expected_C_tnM), len(test_C_tnM))
        np.testing.assert_array_equal(reftimes[::stridet], reader.time_t)
        np.testing.assert_array_equal(expected_C_tnM, test_C_tnM)
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
def test_read_rho_nn_strides(ksd_fname, wfssnap_fname):
    """ Test reading the density matrix from the time-dependent wave functions file with different strides """
    from rhodent.density_matrices.readers.gpaw import KohnShamRhoWfsReader

    iread_args = (0, 0, slice(0, 10), slice(0, 10))  # s, k, n1, n2 to read. 10x10 elements is enough

    def gather(reader):
        # Collect all times; the list stays empty on ranks other than 0
        rho_tnn = []
        for dm_buffer in reader.gather_on_root(*iread_args):
            if reader.comm.rank > 0:
                assert dm_buffer is None
            else:
                rho_tnn.append(dm_buffer.real + 1j * dm_buffer.imag)
        return np.array(rho_tnn)

    # Read every time step as reference
    ref_reader = KohnShamRhoWfsReader(wfs_fname=wfssnap_fname, ksd=ksd_fname, stridet=1)
    dt = ref_reader.dt
    reftimes = ref_reader.time_t
    ref_rho_tnn = gather(ref_reader)

    if ref_reader.comm.rank == 0:
        nreftimes = len(ref_rho_tnn)
        assert nreftimes >= 8, nreftimes  # Make sure that it is actually reading

    # Read, skipping few steps at a time
    for stridet in [2, 3, 4]:
        reader = KohnShamRhoWfsReader(wfs_fname=wfssnap_fname, ksd=ksd_fname, stridet=stridet)
        test_rho_tnn = gather(reader)

        if reader.comm.rank > 0:
            continue

        # Make sure that exactly the right values are being read
        expected_rho_tnn = ref_rho_tnn[::stridet]
        np.testing.assert_almost_equal(reader.dt, dt * stridet)
        np.testing.assert_equal(len(expected_rho_tnn), len(test_rho_tnn))
        np.testing.assert_array_equal(reftimes[::stridet], reader.time_t)
        np.testing.assert_array_equal(expected_rho_tnn, test_rho_tnn)
# Parametrization of _test_read_C_nM for the 2H2 system.
# The body is empty; presumably wrap_test supplies the actual test - verify
# against tests.wrap_test.
@pytest.mark.parametrize('strideM', [1, 2, 3, 4])
@pytest.mark.parametrize('striden', [1, 2, 3, 4])
@pytest.mark.parametrize('test_system', ['2H2'])
@wrap_test(_test_read_C_nM)
def test_read_C_nM_small():
    pass
# Parametrization of _test_read_C_nM for the Na8 and Ag8 systems.
# The body is empty; presumably wrap_test supplies the actual test - verify
# against tests.wrap_test.
@pytest.mark.parametrize('strideM', [80])
@pytest.mark.parametrize('striden', [5, 9])
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
@wrap_test(_test_read_C_nM)
def test_read_C_nM_big():
    pass
# Parametrization of _test_read_rho_MM for the 2H2 system.
# The body is empty; presumably wrap_test supplies the actual test - verify
# against tests.wrap_test.
@pytest.mark.parametrize('strideM1', [1, 3, 4])
@pytest.mark.parametrize('strideM2', [1, 2, 4])
@pytest.mark.parametrize('striden', [1, 2, 3, 4])
@pytest.mark.parametrize('test_system', ['2H2'])
@wrap_test(_test_read_rho_MM)
def test_read_rho_MM_small():
    pass
# Parametrization of _test_read_rho_MM for the Na8 and Ag8 systems.
# The body is empty; presumably wrap_test supplies the actual test - verify
# against tests.wrap_test.
@pytest.mark.parametrize('strideM1', [80])
@pytest.mark.parametrize('strideM2', [100])
@pytest.mark.parametrize('striden', [1, 4, 9])
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
@wrap_test(_test_read_rho_MM)
def test_read_rho_MM_big():
    pass
# Parametrization of _test_read_rho_nn for the 2H2 system (striden != 0).
# The body is empty; presumably wrap_test supplies the actual test - verify
# against tests.wrap_test.
@pytest.mark.parametrize('striden1', [1, 3, 4])
@pytest.mark.parametrize('striden2', [3, 4])
@pytest.mark.parametrize('striden', [1, 4])
@pytest.mark.parametrize('test_system', ['2H2'])
@wrap_test(_test_read_rho_nn)
def test_read_rho_nn_small():
    pass
# Parametrization of _test_read_rho_nn for the Na8 system (striden == 0).
# The body is empty; presumably wrap_test supplies the actual test - verify
# against tests.wrap_test.
@pytest.mark.parametrize('striden1', [8])
@pytest.mark.parametrize('striden2', [8])
@pytest.mark.parametrize('striden', [0])
@pytest.mark.parametrize('test_system', ['Na8'])
@wrap_test(_test_read_rho_nn)
def test_read_rho_nn_big():
    pass
# Parametrization of _test_read_rho_nn_against_reference for the 2H2 system
# (striden != 0).
# The body is empty; presumably wrap_test supplies the actual test - verify
# against tests.wrap_test.
@pytest.mark.parametrize('striden1', [1, 3, 4])
@pytest.mark.parametrize('striden2', [3, 4])
@pytest.mark.parametrize('striden', [1, 4])
@pytest.mark.parametrize('test_system', ['2H2'])
@wrap_test(_test_read_rho_nn_against_reference)
def test_read_rho_nn_against_reference_small():
    pass
# Parametrization of _test_read_rho_nn_against_reference for the Na8 and Ag8
# systems (striden == 0).
# The body is empty; presumably wrap_test supplies the actual test - verify
# against tests.wrap_test.
@pytest.mark.parametrize('striden1', [8])
@pytest.mark.parametrize('striden2', [8])
@pytest.mark.parametrize('striden', [0])
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
@wrap_test(_test_read_rho_nn_against_reference)
def test_read_rho_nn_against_reference_big():
    pass
@pytest.mark.parametrize('striden1', [8])
@pytest.mark.parametrize('striden2', [8])
@pytest.mark.parametrize('striden', [0])
@pytest.mark.parametrize('test_system', ['Na8'])
def test_read_rho_nn_parts(ksd_fname, wfssnap_fname, striden, striden1, striden2):
    """ Test that reading a sub-range of n1 and n2 gives the same values as slicing the full range """

    def calc(n1min, n1max, n2min, n2max):
        """Collect the first time of the density matrix for the given index ranges.

        Returns the complex array on rank 0 of the distributor and None on
        the other ranks.
        """
        wfs_strides = dict(stridet=3, striden=striden)
        # The block sizes are capped at the size of the requested range
        stride_opts = dict(striden1=min(striden1, n1max + 1 - n1min),
                           striden2=min(striden2, n2max + 1 - n2min),
                           n1min=n1min, n1max=n1max, n2min=n2min, n2max=n2max)

        # Read on many ranks, gather properly
        rho_reader = KohnShamRhoWfsReader(wfs_fname=wfssnap_fname, ksd=ksd_fname, **wfs_strides)
        parameters = RhoParameters.from_ksd(rho_reader.ksd, **stride_opts)
        time_distributor = AlltoallvTimeDistributor(rho_reader, parameters)

        full_dm = time_distributor.collect_on_root()
        if time_distributor.comm.rank != 0:
            assert full_dm is None
            return None
        assert full_dm is not None
        full_dm = full_dm[0]
        # Keep real and imaginary parts apart (as in the other tests of this
        # file), so that both are compared separately
        rho_root_nn = full_dm.real + 1j * full_dm.imag
        return rho_root_nn

    # Reference: the range 0..63 in both indices
    ref_rho_nn = calc(0, 63, 0, 63)

    for n1min, n1max in [(0, 4), (2, 13), (7, 14)]:
        for n2min, n2max in [(0, 4), (24, 53), (24, 38)]:
            # calc is called on every rank; only rank 0 compares
            test_rho_nn = calc(n1min, n1max, n2min, n2max)
            if world.rank != 0:
                continue
            cmp_rho_nn = ref_rho_nn[n1min:n1max+1, n2min:n2max+1]
            assert np.allclose(test_rho_nn, cmp_rho_nn)
@pytest.mark.parametrize('stridet', [1, 2])
@pytest.mark.parametrize('test_system', ['2H2', 'Na8'])
def test_prepare_wave_function_readers(wfssnap_fname, stridet):
    """ Test that prepare_wave_function_readers gives the same times with parallel=True and parallel=False """
    from gpaw.lcaotddft.wfwriter import WaveFunctionReader
    from rhodent.density_matrices.readers.gpaw import prepare_wave_function_readers

    def silent_log(*args, **kwargs):
        # Logger that discards everything
        pass

    common_args = (WaveFunctionReader(wfssnap_fname), world, silent_log)

    parallel_time_t, _, _ = prepare_wave_function_readers(*common_args, stridet=stridet, parallel=True)
    serial_time_t, _, _ = prepare_wave_function_readers(*common_args, stridet=stridet, parallel=False)

    assert np.array_equal(parallel_time_t, serial_time_t)
@pytest.mark.parametrize('test_system', ['2H2', 'Na8', 'Ag8', 'Na55'])
def test_rho0(ksd_fname):
    """ Test that the ground state density matrix built from the ksd file is diagonal with the occupations on the diagonal """
    from ase.io.ulm import Reader
    from rhodent.density_matrices.readers.gpaw import read_C0S_parallel
    from rhodent.utils import proxy_sknX_slicen

    # The reader is kept open for the whole test, since the proxies may read
    # from the file when they are accessed, and is closed afterwards
    with Reader(ksd_fname) as ksdreader:
        C0S_sknM = read_C0S_parallel(ksdreader, comm=world)

        full_f_skn = ksdreader.occ_un
        f_skn = ksdreader.proxy('occ_un')
        occ_C0_sknM = ksdreader.proxy('C0_unM', 0)
        f_skn = proxy_sknX_slicen(f_skn, comm=world)
        occ_C0_sknM = proxy_sknX_slicen(occ_C0_sknM, comm=world)
        if occ_C0_sknM.size > 0:
            assert np.max(np.abs(occ_C0_sknM.imag)) < 1e-20
            occ_C0_sknM = occ_C0_sknM.real

        nn = C0S_sknM.shape[2]

        if world.rank == 0:
            print('(all ranks) Constructing rho0_sknn')
        if f_skn.size > 0:
            rho0_sknn = np.einsum('skn,sknM,sknO,skmM,skoO->skmo',
                                  f_skn, occ_C0_sknM, occ_C0_sknM, C0S_sknM, C0S_sknM, optimize=True)
        else:
            # This rank holds no data and contributes zeros to the sum
            rho0_sknn = np.zeros(f_skn.shape[:2] + (nn, nn))
        world.sum(rho0_sknn, -1)
        if world.rank == 0:
            print('(all ranks) Constructed rho0_sknn')

        assert np.all(rho0_sknn.imag == 0)

        # Diagonal: the occupation numbers
        diag_nn = np.eye(nn, dtype=bool)
        rho_skn = rho0_sknn[..., diag_nn]
        assert np.allclose(rho_skn, full_f_skn)

        # Off-diagonal: zero
        rho0_sknn[..., diag_nn] = 0
        assert np.allclose(rho0_sknn, 0)
if __name__ == '__main__':
    # Regenerate the reference data of every test system.
    # The keys are in the order of the arguments of write_reference_data.
    file_keys = ('ref_density_matrix', 'gpw_fname', 'ksd_fname', 'wfssnap_fname')
    for test_system in ['2H2', 'Na8', 'Ag8']:
        write_reference_data(*[get_permanent_test_file(test_system, key) for key in file_keys])