Coverage for tests/unittests/density_matrices/readers/test_gpaw.py: 89%

427 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3import numpy as np 

4import pytest 

5from itertools import product 

6 

7from gpaw.mpi import world, SerialCommunicator 

8 

9from rhodent.density_matrices.readers.gpaw import KohnShamRhoWfsReader 

10from rhodent.density_matrices.distributed.time import TimeDistributor, AlltoallvTimeDistributor, RhoParameters 

11 

12from tests import wrap_test, get_permanent_test_file 

13 

14 

def write_reference_data(ref_density_matrix, gpw_fname, ksd_fname, wfssnap_fname):
    """Create the compressed reference file used by the *against_reference* tests.

    Must be run with a single MPI rank. The KS-basis density matrix is read with
    ``stridet=2`` at time indices 0, 1 and 3. The following are written with
    ``np.savez_compressed`` to ``ref_density_matrix``:

    - the occupations ``f_n``,
    - the ground-state density matrix ``rho0_nn``,
    - the collected ``rho_nnt``,
    - the time indices ``testt``.

    ``gpw_fname`` is accepted but not used.
    """
    assert world.size == 1, 'Run me in serial mode'

    sample_times = [0, 1, 3]

    reader = KohnShamRhoWfsReader(wfs_fname=wfssnap_fname, ksd=ksd_fname,
                                  stridet=2, striden=0)
    params = RhoParameters.from_ksd(reader.ksd, striden1=300, striden2=300, only_ia=False)
    distributor = AlltoallvTimeDistributor(reader, params)

    # Ground-state quantities
    occupations_n = distributor.ksd.reader.occ_un[0, 0]
    ground_rho_nn = distributor.rho0_sknn[0, 0]
    nstates = ground_rho_nn.shape[0]

    collected_nnt = np.zeros((nstates, nstates, len(sample_times)), dtype=complex)
    collect_rho_nn(distributor, collected_nnt, sample_times)

    np.savez_compressed(ref_density_matrix,
                        f_n=occupations_n,
                        rho0_nn=ground_rho_nn,
                        rho_nnt=collected_nnt,
                        testt=sample_times)
    print(f'Saved data to {ref_density_matrix}')

40 

41 

def read_ksd(ksd_fname: str):
    """Read the overlap matrix and ground-state LCAO coefficients from a KSD file.

    Returns a dict with keys ``'S_MM'`` and ``'C0_nM'``. Each value is the
    ``[0, 0, 0]`` entry of the stored ``S_uMM`` / ``C0_unM`` array. The arrays are
    materialized with ``[:]`` while the file is still open.
    """
    from ase.io.ulm import Reader

    with Reader(ksd_fname) as ksd:
        overlap_MM = ksd.S_uMM[0, 0, 0][:]
        coefficients_nM = ksd.C0_unM[0, 0, 0][:]

    return {'S_MM': overlap_MM, 'C0_nM': coefficients_nM}

51 

52 

def assert_orthonormal(C_nM, S_MM, atol=1e-12):
    """Assert that the rows of ``C_nM`` are orthonormal with respect to ``S_MM``.

    Computes ``C_nM @ S_MM @ C_nM^H`` and requires every element to differ from
    the identity matrix by less than ``atol``.

    Parameters
    ----------
    C_nM
        Coefficient matrix; one row per state.
    S_MM
        Overlap matrix of the basis.
    atol
        Maximum allowed absolute deviation of any element from the identity.
        The default of ``1e-12`` matches the previously hard-coded threshold.

    Raises
    ------
    AssertionError
        If the largest deviation is not below ``atol``.
    """
    test_nn = C_nM @ S_MM @ C_nM.T.conj()
    if test_nn.size == 0:
        # An empty set of states is trivially orthonormal; np.max would fail on it
        return

    eye_nn = np.eye(len(test_nn))
    max_error = np.max(np.abs(test_nn - eye_nn))
    assert max_error < atol, f'Not orthonormal: max deviation from identity {max_error:.5e}'

60 

61 

def _test_read_C_nM(ksd_fname, wfssnap_fname, strideM, striden):
    """Check the LCAO coefficients C_nM read from a wave-function snapshot file.

    The reader is created with ``stridet=2``, and time indices 0, 1 and 3 are
    checked. On the root rank:

    - the ground-state coefficients and the coefficients at every checked time
      must be orthonormal with respect to the overlap matrix ``S_MM`` of the
      KSD file;
    - a blockwise read (blocks of ``striden`` x ``strideM``, summed over ranks)
      must agree with a full read gathered directly on the root rank.
    """
    from rhodent.density_matrices.readers.gpaw import WfsReader

    on_root = world.rank == 0
    # The overlap matrix is only needed (and only read) on the root rank
    S_MM = read_ksd(ksd_fname)['S_MM'] if on_root else None

    times = [0, 1, 3]

    wfs_reader = WfsReader(wfs_fname=wfssnap_fname, stridet=2)

    # Ground-state coefficients
    C0_nM = wfs_reader.C0_sknM[0, 0]
    nn, nM = C0_nM.shape
    if on_root:
        assert_orthonormal(C0_nM, S_MM)

    # Blockwise read on all ranks, combined with world.sum
    C_nMt = np.zeros((nn, nM, len(times)), dtype=complex)
    collect_C_nM(wfs_reader, striden, strideM, C_nMt, times)
    if on_root:
        for C_nM in np.moveaxis(C_nMt, -1, 0):
            assert_orthonormal(C_nM, S_MM)

    # Full read gathered directly on the root rank; must match the blockwise result
    C_root_nMt = np.zeros((nn, nM, len(times)), dtype=complex) if on_root else None
    collect_C_nM_gather_on_root(wfs_reader, striden, strideM, C_root_nMt, times)
    if on_root:
        assert np.allclose(C_nMt, C_root_nMt)

97 

98 

def _test_read_rho_MM(ksd_fname, wfssnap_fname, strideM1, strideM2, striden):
    """Check the LCAO-basis density matrix rho_MM read from a wave-function snapshot file.

    Parameters
    ----------
    ksd_fname
        KSD file; its overlap matrix ``S_MM`` is read on the root rank only.
    wfssnap_fname
        Wave-function snapshot file, read with ``stridet=2``.
    strideM1, strideM2
        Block sizes of the two basis-function indices for the blockwise read.
    striden
        Forwarded to ``LCAORhoWfsReader``.

    On the root rank this asserts that:

    - projecting the ground-state rho0_MM onto the Kohn-Sham states gives the
      occupations f_n on the diagonal;
    - the blockwise-read density matrix at time indices 0, 1 and 3 is Hermitian
      element by element;
    - the blockwise read matches a full read gathered on the root rank.
    """
    from rhodent.density_matrices.readers.gpaw import LCAORhoWfsReader

    if world.rank == 0:
        arrays = read_ksd(ksd_fname)
        S_MM = arrays['S_MM']

    testt = [0, 1, 3]

    stride_opts = dict(stridet=2, striden=striden)
    rho_reader = LCAORhoWfsReader(wfs_fname=wfssnap_fname, **stride_opts)

    # Ground-state quantities
    f_n = rho_reader.f_skn[0, 0]
    C0_nM = rho_reader.C0_sknM[0, 0]
    rho0_MM = rho_reader.rho0_skMM[0, 0]
    assert np.isrealobj(rho0_MM)
    assert np.isrealobj(C0_nM)

    nn, nM = C0_nM.shape
    diag_mask_nn = np.eye(nn, dtype=bool)

    if world.rank == 0:
        C0S_nM = C0_nM @ S_MM

        # Project rho0 onto the KS states; its diagonal must reproduce the occupations
        rho0_nn = C0S_nM @ (C0S_nM @ rho0_MM).swapaxes(0, 1)
        rho0_n = rho0_nn[diag_mask_nn]
        assert np.allclose(rho0_n, f_n, rtol=1e-10), rho0_n

    # Read on many ranks, gather by world.sum
    drho_MMt = np.zeros((nM, nM, len(testt)), dtype=complex)
    collect_rho_MM(rho_reader, strideM1, strideM2, drho_MMt, testt)

    if world.rank == 0:
        # Matrix should be Hermitian element by element. Summing the signed
        # differences before taking abs would only test Im(sum of all elements),
        # because the conjugate-transpose terms cancel in the sum.
        error = np.max(np.abs(drho_MMt.swapaxes(0, 1).conj() - drho_MMt))
        assert error < 1e-10, f'rho not Hermitian, {error:.5e}'

    # Read on many ranks, gather properly
    drho_root_MMt = np.zeros((nM, nM, len(testt)), dtype=complex) if world.rank == 0 else None
    collect_rho_MM_gather_on_root(rho_reader, strideM1, strideM2, drho_root_MMt, testt)

    if world.rank == 0:
        assert np.allclose(drho_MMt, drho_root_MMt)

144 

145 

def collect_C_nM(wfs_reader, striden, strideM, C_nMt, testt):
    """Fill ``C_nMt`` block by block via ``wfs_reader.iread`` and sum it over ranks.

    The (n, M) index space is tiled with blocks of ``striden`` x ``strideM`` (s = k = 0).
    For each block:

    - one item of the ``iread`` generator is consumed per non-None entry of
      ``work_loop``;
    - only time indices contained in ``testt`` are stored, at position
      ``testt.index(globalt)``;
    - the generator must be exhausted afterwards.

    ``C_nMt`` is zeroed first and finally reduced with ``world.sum(C_nMt, 0)``.
    This mainly serves as a cross-check of the gather-on-root code path.
    """
    C_nMt[:] = 0
    nn, nM = wfs_reader.C0_sknM.shape[2:]
    sentinel = object()

    blocks = product([0], [0], range(0, nn, striden), range(0, nM, strideM))
    for s, k, nstart, Mstart in blocks:
        nslice = slice(nstart, nstart + striden)
        Mslice = slice(Mstart, Mstart + strideM)
        buffers = wfs_reader.iread(s, k, nslice, Mslice)

        for globalt in wfs_reader.work_loop(wfs_reader.comm.rank):
            if globalt is None:
                continue
            # Always advance the generator so it stays in step with the work loop,
            # even for time steps that are not stored
            dm_buffer = next(buffers)
            if globalt in testt:
                t = testt.index(globalt)
                nrows, ncols = C_nMt[nslice, Mslice, t].shape
                C_nMt[nslice, Mslice, t] += (dm_buffer.real + 1j * dm_buffer.imag)[:nrows, :ncols]

        assert next(buffers, sentinel) is sentinel

    world.sum(C_nMt, 0)

173 

174 

def collect_C_nM_gather_on_root(wfs_reader, striden, strideM, C_nMt, testt):
    """Gather the full coefficient array on the root rank and store the ``testt`` times.

    Calls ``wfs_reader.collect_on_root`` for the full (n, M) range with s = k = 0.
    On the root rank, the ``testt`` entries are written into ``C_nMt`` as
    ``real + 1j * imag``. On other ranks the collected value must be ``None`` and
    ``C_nMt`` is left untouched. ``striden`` and ``strideM`` are not used; they keep
    the signature parallel to ``collect_C_nM``.
    """
    nn, nM = wfs_reader.C0_sknM.shape[2:]
    full_dm = wfs_reader.collect_on_root(0, 0, slice(0, nn), slice(0, nM))

    if world.rank == 0:
        assert full_dm is not None
        selected = full_dm[testt]
        C_nMt[:] = selected.real + 1.0j * selected.imag
    else:
        assert full_dm is None

188 

189 

def collect_rho_MM(rho_MM_reader, strideM1, strideM2, rho_MMt, testt):
    """Fill ``rho_MMt`` block by block via ``rho_MM_reader.iread`` and sum it over ranks.

    The (M1, M2) index space is tiled with blocks of ``strideM1`` x ``strideM2``
    (s = k = 0). For each block:

    - one item of the ``iread`` generator is consumed per non-None entry of
      ``work_loop``;
    - only time indices contained in ``testt`` are stored, at position
      ``testt.index(globalt)``;
    - the generator must be exhausted afterwards.

    ``rho_MMt`` is zeroed first and finally reduced with ``world.sum(rho_MMt, 0)``.
    This mainly serves as a cross-check of the gather-on-root code path.
    """
    rho_MMt[:] = 0
    nM = rho_MM_reader.C0_sknM.shape[3]
    sentinel = object()

    blocks = product([0], [0], range(0, nM, strideM1), range(0, nM, strideM2))
    for s, k, M1start, M2start in blocks:
        M1slice = slice(M1start, M1start + strideM1)
        M2slice = slice(M2start, M2start + strideM2)
        buffers = rho_MM_reader.iread(s, k, M1slice, M2slice)

        for globalt in rho_MM_reader.work_loop(rho_MM_reader.comm.rank):
            if globalt is None:
                continue
            # Always advance the generator so it stays in step with the work loop
            dm_buffer = next(buffers)
            if globalt in testt:
                t = testt.index(globalt)
                nrows, ncols = rho_MMt[M1slice, M2slice, t].shape
                rho_MMt[M1slice, M2slice, t] += (dm_buffer.real + 1j * dm_buffer.imag)[:nrows, :ncols]

        assert next(buffers, sentinel) is sentinel

    world.sum(rho_MMt, 0)

217 

218 

def collect_rho_MM_gather_on_root(rho_MM_reader, strideM1, strideM2, rho_MMt, testt):
    """Gather the full LCAO density matrix on the root rank and store the ``testt`` times.

    Calls ``rho_MM_reader.collect_on_root`` for the full (M, M) range with s = k = 0.
    On the root rank, the ``testt`` entries are written into ``rho_MMt`` as
    ``real + 1j * imag``. On other ranks the collected value must be ``None`` and
    ``rho_MMt`` is left untouched. ``strideM1`` and ``strideM2`` are not used; they
    keep the signature parallel to ``collect_rho_MM``.
    """
    nM = rho_MM_reader.C0_sknM.shape[3]
    everything = slice(0, nM)
    full_dm = rho_MM_reader.collect_on_root(0, 0, everything, everything)

    if world.rank == 0:
        assert full_dm is not None
        selected = full_dm[testt]
        rho_MMt[:] = selected.real + 1.0j * selected.imag
    else:
        assert full_dm is None

232 

233 

def collect_rho_nn(time_distributor, rho_nnt, testt):
    """Accumulate the KS-basis density matrix block by block and sum it over ranks.

    Iterating over ``time_distributor`` yields one buffer per non-None entry of
    ``time_distributor.work_loop(...)``; that entry is ``(s, k, n1, n2)`` and places
    the block (s and k must be 0). ``testt`` selects the time entries of each buffer.
    The iterator must be exhausted afterwards. ``rho_nnt`` is zeroed first and
    finally reduced with ``world.sum(rho_nnt, 0)``.
    """
    rho_nnt[:] = 0
    sentinel = object()

    buffers = iter(time_distributor)
    for indices in time_distributor.work_loop(time_distributor.comm.rank):
        if indices is None:
            continue
        selected = next(buffers)[testt]
        rho_xt = selected.real + 1j * selected.imag
        s, k, n1, n2 = indices
        assert s == 0
        assert k == 0
        nrows, ncols, _ = rho_nnt[n1, n2, :].shape
        rho_nnt[n1, n2, :] += rho_xt[:nrows, :ncols]

    assert next(buffers, sentinel) is sentinel
    # NOTE(review): the reduction is over `world`, not `time_distributor.comm`. If the
    # distributor was built with a SerialCommunicator while world.size > 1, every rank
    # contributes a full copy to the sum — confirm this is only used in that way in serial.
    world.sum(rho_nnt, 0)

254 

255 

def collect_rho_nn_gather_on_root(time_distributor, rho_nnt, testt):
    """Gather the full KS-basis density matrix on the root rank and store the ``testt`` times.

    On the root rank the result of ``time_distributor.collect_on_root()`` must be
    non-None. Its ``testt`` entries are written into ``rho_nnt`` as
    ``real + 1j * imag``. On other ranks the result must be ``None`` and ``rho_nnt``
    is left untouched.
    """
    full_dm = time_distributor.collect_on_root()

    if world.rank == 0:
        assert full_dm is not None
        selected = full_dm[testt]
        rho_nnt[:] = selected.real + 1.0j * selected.imag
    else:
        assert full_dm is None

268 

269 

def _test_read_rho_nn(ksd_fname, wfssnap_fname, striden, striden1, striden2):
    """Sanity-check the KS-basis density matrix read from a wave-function snapshot file.

    The setup depends on ``striden``:

    - ``striden == 0``: an ``AlltoallvTimeDistributor`` is used;
    - otherwise: a ``TimeDistributor`` on a ``SerialCommunicator`` is used.

    On the root rank this asserts that:

    - the ground-state density matrix is (to within 1e-20) diagonal, with the
      occupations on the diagonal;
    - for the collected density matrix at time indices 0, 1 and 3, the trace is
      constant in time;
    - its diagonal is within 1e-7 of zero;
    - its off-diagonal is within 1e-3 of zero.

    The last two thresholds are system dependent.
    """
    sample_times = [0, 1, 3]
    reader_kwargs = dict(wfs_fname=wfssnap_fname, ksd=ksd_fname, stridet=2, striden=striden)
    param_kwargs = dict(striden1=striden1, striden2=striden2, only_ia=False)

    if striden != 0:
        rho_reader = KohnShamRhoWfsReader(comm=SerialCommunicator(), **reader_kwargs)
        parameters = RhoParameters.from_ksd(rho_reader.ksd, comm=rho_reader.comm, **param_kwargs)
        time_distributor = TimeDistributor(rho_reader, parameters)
    else:
        rho_reader = KohnShamRhoWfsReader(**reader_kwargs)
        parameters = RhoParameters.from_ksd(rho_reader.ksd, **param_kwargs)
        time_distributor = AlltoallvTimeDistributor(rho_reader, parameters)

    # Ground-state quantities; n indexes the states of the full KS basis
    f_n = time_distributor.ksd.reader.occ_un[0, 0]
    rho0_nn = time_distributor.rho_wfs_reader.rho0_sknn[0, 0]
    nn = rho0_nn.shape[0]
    assert np.isrealobj(rho0_nn)
    on_diagonal = np.eye(nn, dtype=bool)

    if world.rank == 0:
        # Ground state: occupations on the diagonal, zero elsewhere
        np.testing.assert_allclose(rho0_nn[on_diagonal], f_n, rtol=1e-20, atol=1e-20)
        np.testing.assert_allclose(rho0_nn[~on_diagonal], 0, rtol=1e-20, atol=1e-20)

    # Collective call: all ranks take part, only the root receives the data
    rho_nnt = np.zeros((nn, nn, len(sample_times)), dtype=complex)
    collect_rho_nn_gather_on_root(time_distributor, rho_nnt, sample_times)

    if world.rank != 0:
        return

    # Trace should be constant in time
    trace_t = np.trace(rho_nnt, axis1=0, axis2=1)
    np.testing.assert_allclose(trace_t, trace_t[0], rtol=1e-05, atol=1e-08)

    # System-dependent bounds: diagonal within 1e-7, off-diagonal within 1e-3 of zero
    np.testing.assert_allclose(rho_nnt[on_diagonal], 0, rtol=0, atol=1e-7)
    np.testing.assert_allclose(rho_nnt[~on_diagonal], 0, rtol=0, atol=1e-3)

317 

318 

def _test_read_rho_nn_against_reference(ref_density_matrix, ksd_fname, wfssnap_fname, striden, striden1, striden2):
    """Compare the KS-basis density matrix with stored reference data.

    ``ref_density_matrix`` is an ``.npz`` file with keys ``f_n``, ``rho0_nn``,
    ``rho_nnt`` and ``testt`` (see ``write_reference_data``). The data is read with
    a ``TimeDistributor`` on a ``SerialCommunicator``. Both ``collect_rho_nn`` and
    ``collect_rho_nn_gather_on_root`` are exercised. When ``striden == 0``, the same
    is repeated with an ``AlltoallvTimeDistributor``. Comparisons happen on the root
    rank only.
    """
    # Load every array eagerly so the NpzFile (and its file handle) is closed right away
    with np.load(ref_density_matrix) as npz:
        reference = {key: npz[key] for key in npz.files}
    testt = reference['testt']

    rho_nnt = np.zeros_like(reference['rho_nnt'])

    wfs_strides = dict(stridet=2, striden=striden)
    stride_opts = dict(striden1=striden1, striden2=striden2, only_ia=False)

    def read_and_compare(time_distributor, collect):
        """Collect rho_nnt with ``collect`` and compare with the reference on the root rank."""
        # Ground-state quantities
        f_n = time_distributor.ksd.reader.occ_un[0, 0]
        rho0_nn = time_distributor.rho_wfs_reader.rho0_sknn[0, 0]
        assert np.isrealobj(rho0_nn)

        # Must be called on every rank
        collect(time_distributor, rho_nnt, testt)

        if world.rank == 0:
            assert np.allclose(f_n, reference['f_n'])
            assert np.allclose(rho0_nn, reference['rho0_nn'])
            assert np.allclose(rho_nnt, reference['rho_nnt'])

    # Each rank reads with its own serial communicator
    rho_reader = KohnShamRhoWfsReader(wfs_fname=wfssnap_fname, ksd=ksd_fname,
                                      comm=SerialCommunicator(), **wfs_strides)
    parameters = RhoParameters.from_ksd(rho_reader.ksd, comm=rho_reader.comm, **stride_opts)
    time_distributor = TimeDistributor(rho_reader, parameters)

    # Gather by using world.sum
    read_and_compare(time_distributor, collect_rho_nn)

    # Gather properly on the root rank
    read_and_compare(time_distributor, collect_rho_nn_gather_on_root)

    # The alltoallv distributor is only exercised for striden == 0
    if striden != 0:
        return

    rho_reader = KohnShamRhoWfsReader(wfs_fname=wfssnap_fname, ksd=ksd_fname, **wfs_strides)
    parameters = RhoParameters.from_ksd(rho_reader.ksd, **stride_opts)
    time_distributor = AlltoallvTimeDistributor(rho_reader, parameters)
    read_and_compare(time_distributor, collect_rho_nn)

    read_and_compare(time_distributor, collect_rho_nn_gather_on_root)

363 

364 

@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
def test_read_C_nM_strides(wfssnap_fname):
    """ Test reading the time-dependent wave functions file with different strides """
    from rhodent.density_matrices.readers.gpaw import WfsReader

    # s, k, n, M to read. 10x10 elements is enough
    iread_args = (0, 0, slice(0, 10), slice(0, 10))

    def read_C_tnM(stridet):
        """Create a reader with time stride ``stridet`` and gather the block on the root rank.

        Returns the reader and an array of the gathered complex blocks, one per time
        step. The array is empty on non-root ranks, which must receive ``None``.
        """
        reader = WfsReader(wfs_fname=wfssnap_fname, stridet=stridet)
        C_tnM = []
        for dm_buffer in reader.gather_on_root(*iread_args):
            if reader.comm.rank > 0:
                assert dm_buffer is None
                continue
            C_tnM.append(dm_buffer.real + 1j * dm_buffer.imag)
        return reader, np.array(C_tnM)

    # Read every time step as reference
    ref_reader, ref_C_tnM = read_C_tnM(1)
    dt = ref_reader.dt
    reftimes = ref_reader.time_t

    if ref_reader.comm.rank == 0:
        nreftimes = len(ref_C_tnM)
        assert nreftimes > 8, nreftimes  # Make sure that it is actually reading

    # Read, skipping a few steps at a time
    for stridet in [2, 3, 4]:
        reader, test_C_tnM = read_C_tnM(stridet)
        if reader.comm.rank > 0:
            continue

        # Make sure that exactly the right values are being read
        np.testing.assert_almost_equal(reader.dt, dt * stridet)
        np.testing.assert_equal(len(ref_C_tnM[::stridet]), len(test_C_tnM))
        np.testing.assert_array_equal(reftimes[::stridet], reader.time_t)
        np.testing.assert_array_equal(ref_C_tnM[::stridet], test_C_tnM)

412 

413 

@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
def test_read_rho_nn_strides(ksd_fname, wfssnap_fname):
    """ Test reading the density matrix from the time-dependent wave functions file with different strides """
    # s, k, n1, n2 to read. 10x10 elements is enough
    iread_args = (0, 0, slice(0, 10), slice(0, 10))

    def read_rho_tnn(stridet):
        """Create a KS density-matrix reader with time stride ``stridet`` and gather on root.

        Returns the reader and an array of the gathered complex blocks, one per time
        step. The array is empty on non-root ranks, which must receive ``None``.
        """
        reader = KohnShamRhoWfsReader(wfs_fname=wfssnap_fname, ksd=ksd_fname, stridet=stridet)
        rho_tnn = []
        for dm_buffer in reader.gather_on_root(*iread_args):
            if reader.comm.rank > 0:
                assert dm_buffer is None
                continue
            rho_tnn.append(dm_buffer.real + 1j * dm_buffer.imag)
        return reader, np.array(rho_tnn)

    # Read every time step as reference
    ref_reader, ref_rho_tnn = read_rho_tnn(1)
    dt = ref_reader.dt
    reftimes = ref_reader.time_t

    if ref_reader.comm.rank == 0:
        nreftimes = len(ref_rho_tnn)
        assert nreftimes >= 8, nreftimes  # Make sure that it is actually reading

    # Read, skipping a few steps at a time
    for stridet in [2, 3, 4]:
        reader, test_rho_tnn = read_rho_tnn(stridet)
        if reader.comm.rank > 0:
            continue

        # Make sure that exactly the right values are being read
        np.testing.assert_almost_equal(reader.dt, dt * stridet)
        np.testing.assert_equal(len(ref_rho_tnn[::stridet]), len(test_rho_tnn))
        np.testing.assert_array_equal(reftimes[::stridet], reader.time_t)
        np.testing.assert_array_equal(ref_rho_tnn[::stridet], test_rho_tnn)

461 

462 

@pytest.mark.parametrize('strideM', [1, 2, 3, 4])
@pytest.mark.parametrize('striden', [1, 2, 3, 4])
@pytest.mark.parametrize('test_system', ['2H2'])
@wrap_test(_test_read_C_nM)
def test_read_C_nM_small():
    """Small-system (2H2) variant; the test logic is supplied by ``wrap_test(_test_read_C_nM)``."""

469 

470 

@pytest.mark.parametrize('strideM', [80])
@pytest.mark.parametrize('striden', [5, 9])
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
@wrap_test(_test_read_C_nM)
def test_read_C_nM_big():
    """Larger-system (Na8, Ag8) variant; the test logic is supplied by ``wrap_test(_test_read_C_nM)``."""

477 

478 

@pytest.mark.parametrize('strideM1', [1, 3, 4])
@pytest.mark.parametrize('strideM2', [1, 2, 4])
@pytest.mark.parametrize('striden', [1, 2, 3, 4])
@pytest.mark.parametrize('test_system', ['2H2'])
@wrap_test(_test_read_rho_MM)
def test_read_rho_MM_small():
    """Small-system (2H2) variant; the test logic is supplied by ``wrap_test(_test_read_rho_MM)``."""

486 

487 

@pytest.mark.parametrize('strideM1', [80])
@pytest.mark.parametrize('strideM2', [100])
@pytest.mark.parametrize('striden', [1, 4, 9])
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
@wrap_test(_test_read_rho_MM)
def test_read_rho_MM_big():
    """Larger-system (Na8, Ag8) variant; the test logic is supplied by ``wrap_test(_test_read_rho_MM)``."""

495 

496 

@pytest.mark.parametrize('striden1', [1, 3, 4])
@pytest.mark.parametrize('striden2', [3, 4])
@pytest.mark.parametrize('striden', [1, 4])
@pytest.mark.parametrize('test_system', ['2H2'])
@wrap_test(_test_read_rho_nn)
def test_read_rho_nn_small():
    """Small-system (2H2) variant; the test logic is supplied by ``wrap_test(_test_read_rho_nn)``."""

504 

505 

@pytest.mark.parametrize('striden1', [8])
@pytest.mark.parametrize('striden2', [8])
@pytest.mark.parametrize('striden', [0])
@pytest.mark.parametrize('test_system', ['Na8'])
@wrap_test(_test_read_rho_nn)
def test_read_rho_nn_big():
    """Larger-system (Na8) variant; the test logic is supplied by ``wrap_test(_test_read_rho_nn)``."""

513 

514 

@pytest.mark.parametrize('striden1', [1, 3, 4])
@pytest.mark.parametrize('striden2', [3, 4])
@pytest.mark.parametrize('striden', [1, 4])
@pytest.mark.parametrize('test_system', ['2H2'])
@wrap_test(_test_read_rho_nn_against_reference)
def test_read_rho_nn_against_reference_small():
    """Small-system (2H2) variant; logic supplied by ``wrap_test(_test_read_rho_nn_against_reference)``."""

522 

523 

@pytest.mark.parametrize('striden1', [8])
@pytest.mark.parametrize('striden2', [8])
@pytest.mark.parametrize('striden', [0])
@pytest.mark.parametrize('test_system', ['Na8', 'Ag8'])
@wrap_test(_test_read_rho_nn_against_reference)
def test_read_rho_nn_against_reference_big():
    """Larger-system (Na8, Ag8) variant; logic supplied by ``wrap_test(_test_read_rho_nn_against_reference)``."""

531 

532 

@pytest.mark.parametrize('striden1', [8])
@pytest.mark.parametrize('striden2', [8])
@pytest.mark.parametrize('striden', [0])
@pytest.mark.parametrize('test_system', ['Na8'])
def test_read_rho_nn_parts(ksd_fname, wfssnap_fname, striden, striden1, striden2):
    """Reading a sub-block of the KS density matrix must match the same block of a full read.

    Only entry 0 of the collected buffer (presumably the first read time step) is
    compared, on the root rank.
    """

    def calc(n1min, n1max, n2min, n2max):
        """Read states [n1min, n1max] x [n2min, n2max] and return the complex block on root.

        Returns ``None`` on non-root ranks. Must be called on all ranks because
        ``collect_on_root`` is collective.
        """
        wfs_strides = dict(stridet=3, striden=striden)
        # Block strides must not exceed the size of the requested range
        stride_opts = dict(striden1=min(striden1, n1max + 1 - n1min),
                           striden2=min(striden2, n2max + 1 - n2min),
                           n1min=n1min, n1max=n1max, n2min=n2min, n2max=n2max)

        # Read on many ranks, gather properly
        rho_reader = KohnShamRhoWfsReader(wfs_fname=wfssnap_fname, ksd=ksd_fname, **wfs_strides)
        parameters = RhoParameters.from_ksd(rho_reader.ksd, **stride_opts)
        time_distributor = AlltoallvTimeDistributor(rho_reader, parameters)

        full_dm = time_distributor.collect_on_root()
        if time_distributor.comm.rank != 0:
            assert full_dm is None
            return None
        assert full_dm is not None
        first_dm = full_dm[0]
        # Build the complex matrix. Adding real and imaginary parts directly would mix
        # them and could hide discrepancies that cancel between the two.
        return first_dm.real + 1j * first_dm.imag

    # NOTE(review): assumes the system has at least 64 KS states (indices 0..63)
    ref_rho_nn = calc(0, 63, 0, 63)

    for n1min, n1max in [(0, 4), (2, 13), (7, 14)]:
        for n2min, n2max in [(0, 4), (24, 53), (24, 38)]:
            test_rho_nn = calc(n1min, n1max, n2min, n2max)
            if world.rank != 0:
                continue
            cmp_rho_nn = ref_rho_nn[n1min:n1max+1, n2min:n2max+1]
            assert np.allclose(test_rho_nn, cmp_rho_nn)

568 

569 

@pytest.mark.parametrize('stridet', [1, 2])
@pytest.mark.parametrize('test_system', ['2H2', 'Na8'])
def test_prepare_wave_function_readers(wfssnap_fname, stridet):
    """Parallel and serial reader preparation must produce identical time arrays."""
    from gpaw.lcaotddft.wfwriter import WaveFunctionReader
    from rhodent.density_matrices.readers.gpaw import prepare_wave_function_readers

    def silent_log(*args, **kwargs):
        """Discard all log output."""

    mainreader = WaveFunctionReader(wfssnap_fname)

    # Keyed by the value of `parallel`; the parallel variant runs first
    time_t = {}
    for parallel in (True, False):
        time_t[parallel], _, _ = prepare_wave_function_readers(
            mainreader, world, silent_log, stridet=stridet, parallel=parallel)

    assert np.array_equal(time_t[True], time_t[False])

587 

588 

@pytest.mark.parametrize('test_system', ['2H2', 'Na8', 'Ag8', 'Na55'])
def test_rho0(ksd_fname):
    """The ground-state density matrix in the KS basis must be diagonal with the occupations.

    rho0_sknn is built by einsum from three pieces:

    - the occupations,
    - the occupied ground-state coefficients C0,
    - ``C0S`` from ``read_C0S_parallel``.

    The occupations and C0 are distributed over ranks with ``proxy_sknX_slicen``,
    and the partial results are summed over all ranks.
    """
    from ase.io.ulm import Reader
    from rhodent.density_matrices.readers.gpaw import read_C0S_parallel
    from rhodent.utils import proxy_sknX_slicen

    # Keep the reader open for the whole test and close it afterwards: `occ_un`
    # and the proxies may read from the file lazily.
    with Reader(ksd_fname) as ksdreader:
        C0S_sknM = read_C0S_parallel(ksdreader, comm=world)

        full_f_skn = ksdreader.occ_un
        f_skn = ksdreader.proxy('occ_un')
        occ_C0_sknM = ksdreader.proxy('C0_unM', 0)
        f_skn = proxy_sknX_slicen(f_skn, comm=world)
        occ_C0_sknM = proxy_sknX_slicen(occ_C0_sknM, comm=world)
        if occ_C0_sknM.size > 0:
            # Ground-state coefficients are expected to be real
            assert np.max(np.abs(occ_C0_sknM.imag)) < 1e-20
            occ_C0_sknM = occ_C0_sknM.real

        nn = C0S_sknM.shape[2]

        if world.rank == 0:
            print('(all ranks) Constructing rho0_sknn')
        if f_skn.size > 0:
            rho0_sknn = np.einsum('skn,sknM,sknO,skmM,skoO->skmo',
                                  f_skn, occ_C0_sknM, occ_C0_sknM, C0S_sknM, C0S_sknM, optimize=True)
        else:
            # This rank holds no states; contribute zeros to the sum
            rho0_sknn = np.zeros(f_skn.shape[:2] + (nn, nn))
        world.sum(rho0_sknn, -1)
        if world.rank == 0:
            print('(all ranks) Constructed rho0_sknn')

        assert np.all(rho0_sknn.imag == 0)

        diag_nn = np.eye(nn, dtype=bool)
        rho_skn = rho0_sknn[..., diag_nn]
        assert np.allclose(rho_skn, full_f_skn)

        # After removing the diagonal, nothing may remain
        rho0_sknn[..., diag_nn] = 0
        assert np.allclose(rho0_sknn, 0)

627 

628 

if __name__ == '__main__':
    # Regenerate the stored reference density matrices for the small test systems
    reference_keys = ('ref_density_matrix', 'gpw_fname', 'ksd_fname', 'wfssnap_fname')
    for test_system in ['2H2', 'Na8', 'Ag8']:
        paths = [get_permanent_test_file(test_system, key) for key in reference_keys]
        write_reference_data(*paths)