Coverage for rhodent/density_matrices/readers/gpaw.py: 92%

436 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2from abc import ABC, abstractmethod 

3from typing import Callable, Generator 

4 

5import numpy as np 

6from numpy.typing import NDArray 

7from itertools import zip_longest 

8 

9from ase.units import Bohr 

10from ase.io.ulm import Reader 

11 

12from gpaw.mpi import broadcast, world 

13from gpaw.lcaotddft.ksdecomposition import KohnShamDecomposition 

14from gpaw.lcaotddft.wfwriter import WaveFunctionReader 

15 

16from ..buffer import DensityMatrixBuffer 

17from ...utils import Logger, add_fake_kpts, get_array_filter 

18from ...utils.logging import format_times 

19from ...utils.memory import HasMemoryEstimate, MemoryEstimate 

20from ...typing import Array1D, Communicator 

21 

22 

class BaseWfsReader(ABC):

    """ Read wave functions or density matrices from the time-dependent wave functions file.

    Parameters
    ----------
    wfs_fname
        File name of the time-dependent wave functions file written by GPAW.
    comm
        MPI communicator.
    yield_re
        Whether to read and yield the real part of wave functions/density matrices.
    yield_im
        Whether to read and yield the imaginary part of wave functions/density matrices.
    stridet
        Skip this many steps when reading.
    tmax
        Last time index to read.
    filter_times
        A list of times to read in atomic units. The closest times in the time-dependent
        wave functions file will be read. Applied after skipping the stridet steps to tmax.
    log
        Logger object.
    """

    def __init__(self,
                 wfs_fname: str,
                 comm=world,
                 yield_re: bool = True,
                 yield_im: bool = True,
                 stridet: int = 1,
                 tmax: int = -1,
                 filter_times: list[float] | Array1D[np.float64] | None = None,
                 log: Logger | None = None):
        self._comm = comm
        self._yield_re = yield_re
        self._yield_im = yield_im
        if log is None:
            log = Logger()
        self._log = log

        # The main reader is closed when it is garbage collected
        # Hence, we need to keep it in the scope
        self.mainreader = WaveFunctionReader(wfs_fname)
        self._time_t, self.initreader, self._full_reader_t = prepare_wave_function_readers(
            self.mainreader, comm, self.log, stridet=stridet, tmax=tmax)
        # Boolean/index filter selecting the requested times among all read times
        self._flt_t = get_array_filter(self._time_t, filter_times)
        self.reader_t = [self._full_reader_t[r]
                         for r in np.arange(len(self._full_reader_t), dtype=int)[self._flt_t]]

    @property
    def comm(self) -> Communicator:
        """ MPI communicator. """
        return self._comm

    @property
    def yield_re(self) -> bool:
        """ Whether this object should read real parts. """
        return self._yield_re

    @property
    def yield_im(self) -> bool:
        """ Whether this object should read imaginary parts. """
        return self._yield_im

    @property
    def log(self) -> Logger:
        """ Logger object. """
        return self._log

    @property
    def time_t(self) -> Array1D[np.float64]:
        """ Array of times to read; in atomic units. """
        return self._time_t[self._flt_t]  # type: ignore

    @property
    def nt(self) -> int:
        """ Number of times to read. """
        return len(self.time_t)

    @property
    def dt(self) -> float:
        """ Time step in atomic units.

        Raises
        ------
        ValueError
            If fewer than two times are read, or if the time step is not constant.
        """
        time_t = self.time_t
        if len(time_t) < 2:
            fname = self.mainreader.filename
            raise ValueError(f'Unable to get a time step. Fewer than two times read from {fname}.')

        dt = time_t[1] - time_t[0]
        if not np.allclose(time_t[1:] - dt, time_t[:-1]):
            fname = self.mainreader.filename
            raise ValueError(f'Unable to get a time step. Variable time step in {fname}.')

        return dt

    def work_loop(self,
                  rank: int) -> Generator[int | None, None, None]:
        """ Yield the time indices that this rank will read.

        New indices are yielded until the end of self.reader_t is reached
        (across all ranks).

        Yields
        ------
        Time index between 0 and len(self.reader_t) - 1 corresponding to
        the time being read by this rank. Or None if this rank has nothing
        to read.
        """
        for t_r in self.work_loop_by_ranks():
            yield t_r[rank]

    def work_loop_by_ranks(self) -> Generator[list[int | None], None, None]:
        """ Yield, for each iteration, the time index read by every rank.

        Time indices are distributed round-robin: in iteration ``localt``,
        rank ``r`` reads index ``r + localt * comm.size``. Ranks whose index
        would be beyond the last time get None.

        Yields
        ------
        List of length comm.size with a time index (or None) for each rank.
        """
        nt = self.nt
        ntperrank = (nt + self.comm.size - 1) // self.comm.size

        for localt in range(ntperrank):
            globalt_r = [rank + localt * self.comm.size for rank in range(self.comm.size)]
            globalt_r = [globalt if globalt < nt else None for globalt in globalt_r]
            yield globalt_r

    def global_work_loop(self) -> Generator[int, None, None]:
        """ Yield all time indices in the order in which the ranks read them.

        Within each iteration of work_loop_by_ranks the indices are yielded
        in rank order; ranks without work are skipped.
        """
        for chunks_r in self.work_loop_by_ranks():
            for chunk in chunks_r:
                if chunk is None:
                    continue
                yield chunk

    @abstractmethod
    def iread(self, *args, **kwargs) -> Generator[DensityMatrixBuffer, None, None]:
        """ Iteratively read wave functions or density matrices time by time. """
        raise NotImplementedError

    @abstractmethod
    def nnshape(self, *args, **kwargs) -> tuple[int, int]:
        """ Shape of the density matrices or wave functions. """
        raise NotImplementedError

    def gather_on_root(self, *args, **kwargs) -> Generator[DensityMatrixBuffer | None, None, None]:
        """ Send the data read by every rank to the root rank, one time at a time.

        The arguments are passed on to iread. This must be called on all ranks
        of comm, since it sends/receives arrays between ranks.

        Yields
        ------
        On the root rank, a copy of the buffer for every time read by any rank,
        in the order of global_work_loop. On other ranks, None the same number
        of times.
        """
        for indices_r, dm_buffer in zip_longest(self.work_loop_by_ranks(),
                                                self.iread(*args, **kwargs), fillvalue=None):
            assert indices_r is not None, 'Work loop shorter than work'

            # Yield root's own work
            if self.comm.rank == 0:
                assert indices_r[0] is not None
                assert dm_buffer is not None
                dm_buffer.ensure_contiguous_buffers()

                yield dm_buffer.copy()
            else:
                yield None

            # Yield the work of non-root
            for recvrank, recvindices in enumerate(indices_r[1:], start=1):
                if recvindices is None:
                    # No work on this recvrank
                    continue

                if self.comm.rank == 0:
                    # Receive work (reusing root's buffer as receive buffer)
                    assert dm_buffer is not None
                    dm_buffer.recv_arrays(self.comm, recvrank, log=self.log)
                    yield dm_buffer.copy()
                else:
                    # Send work to root if there is any
                    if self.comm.rank == recvrank:
                        assert dm_buffer is not None
                        dm_buffer.send_arrays(self.comm, 0, log=self.log)
                    yield None

    def collect_on_root(self, *args, **kwargs) -> DensityMatrixBuffer | None:
        """ Collect the data for all times in one buffer on the root rank.

        The arguments are passed on to nnshape and iread. This must be called
        on all ranks of comm.

        Returns
        -------
        On the root rank, a buffer with one entry per time (length nt).
        On other ranks, None.
        """
        nnshape = self.nnshape(*args, **kwargs)
        full_dm = DensityMatrixBuffer(nnshape, (self.nt, ), np.float64)
        if self.yield_re:
            full_dm.zeros(True, 0)
        if self.yield_im:
            full_dm.zeros(False, 0)

        for t, dm_buffer in zip_longest(self.global_work_loop(),
                                        self.gather_on_root(*args, **kwargs), fillvalue=None):
            if self.comm.rank != 0:
                continue

            assert t is not None, 'Iterators must be same length'
            assert dm_buffer is not None, 'Iterators must be same length'

            for partial_data_nn, full_data_nn in zip(dm_buffer._iter_buffers(),
                                                     full_dm[t]._iter_buffers()):
                full_data_nn[:] += partial_data_nn

        if self.comm.rank != 0:
            return None

        return full_dm

213 

214 

class KohnShamRhoWfsReader(HasMemoryEstimate, BaseWfsReader):

    """ Read density matrices from the time-dependent wave functions file.

    Yield density matrices time by time. The density matrices are read in the
    LCAO basis by an internal LCAORhoWfsReader and transformed to the basis of
    ground-state Kohn-Sham orbitals using the C0S matrix of the ksd file.

    Parameters
    ----------
    wfs_fname
        File name of the time-dependent wave functions file.
    ksd
        KohnShamDecomposition object or file name to the ksd file.
    comm
        MPI communicator.
    yield_re
        Whether to read and yield the real part of wave functions/density matrices.
    yield_im
        Whether to read and yield the imaginary part of wave functions/density matrices.
    stridet
        Skip this many steps when reading.
    tmax
        Last time index to read.
    filter_times
        A list of times to read in atomic units. The closest times in the time-dependent wave functions file
        will be read.
    log
        Logger object.
    striden
        Option passed through to the LCAORhoWfsReader.
    """

    def __init__(self,
                 wfs_fname: str,
                 ksd: str | KohnShamDecomposition,
                 comm=world,
                 yield_re: bool = True,
                 yield_im: bool = True,
                 stridet: int = 1,
                 tmax: int = -1,
                 filter_times: list[float] | Array1D[np.float64] | None = None,
                 log: Logger | None = None,
                 striden: int = 0):
        # Set up an internal LCAO density matrix reader
        self.lcao_rho_reader = LCAORhoWfsReader(
            wfs_fname=wfs_fname, comm=comm,
            yield_re=yield_re, yield_im=yield_im, log=log,
            stridet=stridet, tmax=tmax,
            filter_times=filter_times, striden=striden)

        # And copy its attributes to self
        # (BaseWfsReader.__init__ is deliberately not called, so the file is
        # opened and scanned only once)
        self._yield_re = yield_re
        self._yield_im = yield_im
        self._comm = self.lcao_rho_reader.comm
        self._log = self.lcao_rho_reader.log
        self.mainreader = self.lcao_rho_reader.mainreader
        self.initreader = self.lcao_rho_reader.initreader
        self._full_reader_t = self.lcao_rho_reader._full_reader_t
        self.reader_t = self.lcao_rho_reader.reader_t
        self._flt_t = self.lcao_rho_reader._flt_t
        self._time_t = self.lcao_rho_reader._time_t

        # Set up ksd
        if isinstance(ksd, KohnShamDecomposition):
            self._ksd = ksd
        else:
            self._ksd = KohnShamDecomposition(filename=ksd)
            add_fake_kpts(self._ksd)

        # Lazily computed arrays
        self._C0S_sknM: NDArray[np.float64] | None = None
        self._rho0_sknn: NDArray[np.float64] | None = None

    def __str__(self) -> str:
        nn, nM = proxy_coefficients(self.initreader).shape[2:]
        ntperrank = (self.nt + self.comm.size - 1) // self.comm.size

        lines = []
        lines.append('Time-dependent wave functions reader')
        lines.append('  Constructing density matrices in basis of ground state orbitals.')
        lines.append('')
        lines.append(f'  file: {self.mainreader.filename}')
        lines.append(f'  wave function dimensions {(nn, nM)}')
        lines.append(f'  {self.nt} times')
        lines.append(f'    {format_times(self.time_t, units="au")}')
        lines.append(f'  {self.comm.size} ranks reading in {ntperrank} iterations')

        return '\n'.join(lines)

    def get_memory_estimate(self) -> MemoryEstimate:
        """ Estimate of the memory used by the C0S and ground-state density matrix arrays. """
        nn, nM = proxy_coefficients(self.initreader).shape[2:]

        memory_estimate = MemoryEstimate()
        memory_estimate.add_key('C0S_nM', (nn, nM), float,
                                on_num_ranks=self.comm.size)
        memory_estimate.add_key('rho0_MM', (nM, nM), float,
                                on_num_ranks=self.comm.size)

        return memory_estimate

    @property
    def ksd(self) -> KohnShamDecomposition:
        """ Kohn-Sham decomposition object. """
        return self._ksd

    @property
    def C0S_sknM(self) -> NDArray[np.float64]:
        """ Product of ground-state LCAO coefficients and overlap matrix from the ksd file.

        Computed with read_C0S_parallel on first access; that is a collective
        operation over comm, so the first access must happen on all ranks.
        """
        if self._C0S_sknM is None:
            self.log(f'Constructing C0_sknM on {self.comm.size} ranks',
                     who='Reader', rank=0, flush=True)
            self._C0S_sknM = read_C0S_parallel(self.ksd.reader, comm=self.comm)
            self.log('Constructed C0_sknM',
                     who='Reader', rank=0, flush=True)
        assert self._C0S_sknM is not None
        return self._C0S_sknM

    @property
    def rho0_sknn(self) -> NDArray[np.float64]:
        """ Ground-state density matrix in the Kohn-Sham basis.

        Diagonal matrices with the occupation numbers of the ksd file on the diagonal.
        """
        if self._rho0_sknn is None:
            f_skn = self.ksd.reader.occ_un[:]
            nn = f_skn.shape[2]
            rho0_sknn = np.zeros(f_skn.shape[:2] + (nn, nn))
            diag_nn = np.eye(nn, dtype=bool)
            rho0_sknn[..., diag_nn] = f_skn
            self._rho0_sknn = rho0_sknn
        return self._rho0_sknn

    def nnshape(self,
                s: int,
                k: int,
                n1: slice,
                n2: slice) -> tuple[int, int]:
        """ Shape of the density matrix block selected by the slices n1 and n2.

        The slices must have explicit integer start and stop.
        """
        n1size = n1.stop - n1.start
        n2size = n2.stop - n2.start
        nnshape = (n1size, n2size)
        return nnshape

    def parallel_prepare(self):
        """ Read everything necessary synchronously on all ranks. """
        self.C0S_sknM
        self.lcao_rho_reader.rho0_skMM

    def iread(self,
              s: int,
              k: int,
              n1: slice,
              n2: slice) -> Generator[DensityMatrixBuffer, None, None]:
        """ Read the density matrices time by time.

        The LCAO density matrix rho_NM of each time is transformed as
        rho_nm = sum_NM C0S_nN C0S_mM rho_NM.

        Parameters
        ----------
        s, k, n1, n2
            Read these indices.

        Yields
        ------
        The same DensityMatrixBuffer object for every time read by this rank,
        overwritten with the data of the current time.
        """
        dm_buffer = DensityMatrixBuffer(self.nnshape(s, k, n1, n2), (), np.float64)
        if self.yield_re:
            dm_buffer.zeros(True, 0)
        if self.yield_im:
            dm_buffer.zeros(False, 0)

        einsumstr = 'nN,mM,NM->nm'
        C0S_sknM = self.C0S_sknM  # Read this on all ranks (collective)

        nM = C0S_sknM.shape[3]
        sliceM = slice(0, nM)

        # The transformation matrices are independent of time; slice them once
        C0S_nM1 = C0S_sknM[s, k, n1, :]  # Here n is full KS basis
        C0S_nM2 = C0S_sknM[s, k, n2, :]

        for lcao_dm in self.lcao_rho_reader.iread(s, k, sliceM, sliceM):
            self.log.start('read')

            # The same transformation is applied to the real and imaginary parts.
            # No complex conjugation is applied to C0S; presumably C0S is real
            # (Gamma point) - TODO confirm for complex ksd coefficients
            if self.yield_re:
                Rerho_MM = lcao_dm.real
                Rerho_x = np.einsum(einsumstr, C0S_nM1, C0S_nM2, Rerho_MM, optimize=True)
                dm_buffer.safe_fill(True, 0, Rerho_x)
            if self.yield_im:
                Imrho_MM = lcao_dm.imag
                Imrho_x = np.einsum(einsumstr, C0S_nM1, C0S_nM2, Imrho_MM, optimize=True)
                dm_buffer.safe_fill(False, 0, Imrho_x)

            yield dm_buffer

396 

397 

class LCAORhoWfsReader(BaseWfsReader):

    """ Read density matrices in the LCAO basis from the time-dependent wave functions file.

    Yield density matrices time by time. The ground-state density matrix
    (constructed from the 'init' entry of the file) is subtracted from the
    real part.

    Parameters
    ----------
    wfs_fname
        File name of the time-dependent wave functions file.
    comm
        MPI communicator.
    yield_re
        Whether to read and yield the real part of the density matrices.
    yield_im
        Whether to read and yield the imaginary part of the density matrices.
    stridet
        Skip this many steps when reading.
    tmax
        Last time index to read.
    filter_times
        A list of times to read in atomic units. The closest times in the time-dependent
        wave functions file will be read.
    log
        Logger object.
    striden
        Number of states whose contributions are read and summed in one step.
        0 means all states in a single step.
    """

    def __init__(self,
                 wfs_fname: str,
                 comm=world,
                 yield_re: bool = True,
                 yield_im: bool = True,
                 stridet: int = 1,
                 tmax: int = -1,
                 filter_times: list[float] | Array1D[np.float64] | None = None,
                 log: Logger | None = None,
                 striden: int = 4):

        super().__init__(wfs_fname=wfs_fname, comm=comm,
                         yield_re=yield_re, yield_im=yield_im, log=log,
                         stridet=stridet, tmax=tmax,
                         filter_times=filter_times)
        # Lazily computed arrays
        self._f_skn: NDArray[np.float64] | None = None
        self._C0_sknM: NDArray[np.float64] | None = None
        self._rho0_skMM: NDArray[np.float64] | None = None
        self._striden = striden

    @property
    def nn(self) -> int:
        """ Number of states in the wave functions file. """
        return self.f_skn.shape[2]

    @property
    def striden(self) -> int:
        """ Number of states read in one step, as given by the user (0 means all). """
        return self._striden

    @property
    def true_striden(self) -> int:
        """ Number of states read in one step, with 0 resolved to all states. """
        if self.striden == 0:
            return self.nn
        return self.striden

    @property
    def f_skn(self) -> NDArray[np.float64]:
        """ Occupations numbers. """
        if self._f_skn is None:
            self._f_skn = proxy_occupations(self.initreader)[:]
        return self._f_skn

    @property
    def C0_sknM(self) -> NDArray[np.float64]:
        """ Ground-state LCAO coefficients of the 'init' entry; must be real. """
        if self._C0_sknM is None:
            C0_sknM = proxy_coefficients(self.initreader)[:]
            assert np.max(np.abs(C0_sknM.imag)) < 1e-20
            self._C0_sknM = C0_sknM.real
        return self._C0_sknM

    @property
    def rho0_skMM(self) -> NDArray[np.float64]:
        """ Ground-state density matrix in the LCAO basis.

        Computed with calculate_rho0_parallel on first access; that is a
        collective operation over comm, so the first access must happen on all ranks.
        """
        if self._rho0_skMM is None:
            self.log(f'Constructing rho0_skMM on {self.comm.size} ranks',
                     who='Reader', rank=0, flush=True)
            self._rho0_skMM = calculate_rho0_parallel(self.f_skn, self.C0_sknM, comm=self.comm)
            self.log('Constructed rho0_skMM',
                     who='Reader', rank=0, flush=True)
        assert self._rho0_skMM is not None
        return self._rho0_skMM

    def inner_work_loop(self) -> Generator[slice, None, None]:
        """ Yield slices of true_striden consecutive states covering all states. """
        for n in range(0, self.nn, self.true_striden):
            yield slice(n, n + self.true_striden)

    def subtract_ground_state(self,
                              dm_buffer: DensityMatrixBuffer,
                              s: int,
                              k: int,
                              M1: slice,
                              M2: slice):
        """ Subtract the ground-state density matrix from the real part of dm_buffer. """
        rhs = -self.rho0_skMM[s, k, M1, M2]
        dm_buffer.safe_add(True, 0, rhs)

    def nnshape(self,
                s: int,
                k: int,
                M1: slice,
                M2: slice) -> tuple[int, int]:
        """ Shape of the density matrix block selected by the slices M1 and M2.

        The slices must have explicit integer start and stop.
        """
        M1size = M1.stop - M1.start
        M2size = M2.stop - M2.start
        MMshape = (M1size, M2size)
        return MMshape

    def iread(self,
              s: int,
              k: int,
              M1: slice,
              M2: slice) -> Generator[DensityMatrixBuffer, None, None]:
        """ Read the density matrices time by time.

        rho_MO = sum_n f_n C_nM C*_nO - rho0_MO, accumulated over chunks of
        true_striden states.

        Parameters
        ----------
        s, k, M1, M2
            Read these indices.

        Yields
        ------
        The same DensityMatrixBuffer object for every time read by this rank,
        overwritten with the data of the current time.
        """
        dm_buffer = DensityMatrixBuffer(self.nnshape(s, k, M1, M2), (), np.float64)

        einsumstr = 'n,nM,nO->MO'

        self.rho0_skMM  # Construct synchronously on all ranks
        for globalt in self.work_loop(self.comm.rank):
            if globalt is None:
                continue

            if self.yield_re:
                dm_buffer.zeros(True, 0)
            if self.yield_im:
                dm_buffer.zeros(False, 0)

            reader = self.reader_t[globalt]
            for n in self.inner_work_loop():
                C_nM1 = proxy_C_nM(reader, s, k, n, M1)  # Here n is occupied states only
                # Do not read the same coefficients twice for a square block
                C_nM2 = C_nM1 if M1 == M2 else proxy_C_nM(reader, s, k, n, M2)
                f_n = proxy_occupations(reader, s, k)[n]
                path = np.einsum_path(einsumstr, f_n, C_nM1.real, C_nM2.real, optimize='optimal')[0]

                # Conjugate C_nM2:
                # Re(C1 C2*) = Re C1 Re C2 + Im C1 Im C2
                # Im(C1 C2*) = Im C1 Re C2 - Re C1 Im C2
                if self.yield_re:
                    Rerho_x = np.einsum(einsumstr, f_n, C_nM1.real, C_nM2.real, optimize=path)
                    Rerho_x += np.einsum(einsumstr, f_n, C_nM1.imag, C_nM2.imag, optimize=path)
                    dm_buffer.safe_add(True, 0, Rerho_x)
                if self.yield_im:
                    Imrho_x = np.einsum(einsumstr, f_n, C_nM1.imag, C_nM2.real, optimize=path)
                    Imrho_x -= np.einsum(einsumstr, f_n, C_nM1.real, C_nM2.imag, optimize=path)
                    dm_buffer.safe_add(False, 0, Imrho_x)

            if self.yield_re:
                # The ground-state density matrix is real (C0_sknM asserts this),
                # so it only affects the real part, which is only present if requested
                self.subtract_ground_state(dm_buffer, s, k, M1, M2)

            yield dm_buffer

534 

535 

class WfsReader(BaseWfsReader):

    """ Read wave function LCAO coefficients from the time-dependent wave functions file.

    Yield wave functions time by time.

    Parameters
    ----------
    wfs_fname
        File name of the time-dependent wave functions file.
    comm
        MPI communicator.
    yield_re
        Whether to read and yield the real part of the wave functions.
    yield_im
        Whether to read and yield the imaginary part of the wave functions.
    stridet
        Skip this many steps when reading.
    tmax
        Last time index to read.
    filter_times
        A list of times to read in atomic units. The closest times in the time-dependent
        wave functions file will be read.
    log
        Logger object.
    """

    def __init__(self,
                 wfs_fname: str,
                 comm=world,
                 yield_re: bool = True,
                 yield_im: bool = True,
                 stridet: int = 1,
                 tmax: int = -1,
                 filter_times: list[float] | Array1D[np.float64] | None = None,
                 log: Logger | None = None):
        super().__init__(wfs_fname=wfs_fname, comm=comm,
                         yield_re=yield_re, yield_im=yield_im, log=log,
                         stridet=stridet, tmax=tmax,
                         filter_times=filter_times)
        # Lazily read arrays
        self._f_skn: NDArray[np.float64] | None = None
        self._C0_sknM: NDArray[np.float64] | None = None

    @property
    def nn(self) -> int:
        """ Number of states in the wave functions file. """
        return self.C0_sknM.shape[2]

    @property
    def nM(self) -> int:
        """ Number of LCAO basis functions. """
        return self.C0_sknM.shape[3]

    @property
    def f_skn(self) -> NDArray[np.float64]:
        """ Occupations """
        if self._f_skn is None:
            self._f_skn = proxy_occupations(self.initreader)[:]
        return self._f_skn

    @property
    def C0_sknM(self) -> NDArray[np.float64]:
        """ Ground-state LCAO coefficients of the 'init' entry; must be real. """
        if self._C0_sknM is None:
            C0_sknM = proxy_coefficients(self.initreader)[:]
            assert np.max(np.abs(C0_sknM.imag)) < 1e-20
            self._C0_sknM = C0_sknM.real
        return self._C0_sknM

    def nnshape(self,
                s: int,
                k: int,
                n: slice,
                M: slice) -> tuple[int, int]:
        """ Shape of the coefficient block selected by the slices n and M.

        The slices must have explicit integer start and stop.
        """
        nsize = n.stop - n.start
        Msize = M.stop - M.start
        nMshape = (nsize, Msize)
        return nMshape

    def iread(self,
              s: int,
              k: int,
              n: slice,
              M: slice) -> Generator[DensityMatrixBuffer, None, None]:
        """ Read the wave function LCAO coefficients time by time.

        Parameters
        ----------
        s, k, n, M
            Read these indices.

        Yields
        ------
        The same DensityMatrixBuffer object for every time read by this rank,
        holding the real and/or imaginary parts of the coefficients C_nM at
        the current time.
        """
        dm_buffer = DensityMatrixBuffer(self.nnshape(s, k, n, M), (), np.float64)

        for globalt in self.work_loop(self.comm.rank):
            if globalt is None:
                continue

            if self.yield_re:
                dm_buffer.zeros(True, 0)
            if self.yield_im:
                dm_buffer.zeros(False, 0)

            reader = self.reader_t[globalt]
            C_nM = proxy_C_nM(reader, s, k, n, M)
            if self.yield_re:
                dm_buffer.safe_fill(True, 0, C_nM.real)
            if self.yield_im:
                dm_buffer.safe_fill(False, 0, C_nM.imag)

            yield dm_buffer

623 

624 

def prepare_wave_function_readers(mainreader,
                                  comm,
                                  log: Callable | None = None,
                                  stridet: int = 1,
                                  tmax: int = -1,
                                  parallel: bool = True,
                                  ) -> tuple[Array1D[np.float64], WaveFunctionReader, list[WaveFunctionReader]]:
    """ Scan the time-dependent wave functions file and select the entries to read.

    The time and action of every entry of mainreader are read (in parallel over
    comm, or only on the root rank if parallel is False) and collected on the
    root rank. The root rank selects the entries and broadcasts the selection,
    so this must be called on all ranks of comm.

    Selection rules:

    * entries with action None are skipped;
    * entries whose time does not exceed the time of the previous
      'kick'/'propagate' entry are skipped (with a log message);
    * the first 'init' entry becomes the initial reader;
    * of the 'kick'/'propagate' entries, every stridet-th one is kept,
      and scanning stops after tmax of them if tmax > 0.

    Parameters
    ----------
    mainreader
        WaveFunctionReader of the time-dependent wave functions file.
    comm
        MPI communicator.
    log
        Logger-like callable; it is called with the keyword arguments
        who, rank, comm, if_elapsed and flush. Defaults to a new Logger.
    stridet
        Keep every stridet-th 'kick'/'propagate' entry.
    tmax
        Stop after this many 'kick'/'propagate' entries; ignored if not positive.
    parallel
        Whether to read the metadata of the entries in parallel over comm.

    Returns
    -------
    Tuple of the times of the kept entries, the reader of the first 'init'
    entry, and the list of readers of the kept entries.
    """
    if log is None:
        # The calls below pass Logger keyword arguments (who, rank, comm,
        # if_elapsed, flush), which the builtin print does not accept
        log = Logger()

    readerlen = len(mainreader)
    log(f'Opening time-dependent wave functions file {mainreader.filename}', who='Reader', rank=0, flush=True)

    # Encode the action as int according to the following list
    int2action = [None, 'init', 'kick', 'propagate']
    action2int = {a: i for i, a in enumerate(int2action)}

    if parallel:
        # Read in parallel
        readrange = range(0, readerlen, comm.size)
    else:
        # Read everything on root
        readrange = range(readerlen) if comm.rank == 0 else range(0)

    if comm.rank == 0 or parallel:
        # Buffers to read to on all ranks when parallel
        time_1 = np.array([0], dtype=float)
        action_1 = np.array([0], dtype=int)

    def round_up(val):
        # Round up to a multiple of comm.size, so that the gathers fit
        v = (val + comm.size - 1)
        v = v // comm.size
        return v * comm.size

    if comm.rank == 0:
        # Buffers for gathering
        time_t = np.zeros(round_up(readerlen), dtype=time_1.dtype)
        action_t = np.zeros(round_up(readerlen), dtype=action_1.dtype)

    # Read all times in parallel, gather on root
    for roott in readrange:
        t = roott + comm.rank
        if t < readerlen:
            log(f'Opened item #{t}', who='Reader', comm=comm, if_elapsed=5, flush=True)

            reader = mainreader[t]
            action = getattr(reader, 'action', None)
            action_1[:] = action2int[action]
            try:
                time_1[:] = getattr(reader, 'time')
            except AttributeError:
                # Depending on the action these might not be set
                time_1[:] = np.nan

        # Collective: every rank takes part, also those past the end of the file
        if parallel:
            comm.gather(time_1, 0, time_t[roott:roott + comm.size] if comm.rank == 0 else None)
            comm.gather(action_1, 0, action_t[roott:roott + comm.size] if comm.rank == 0 else None)
        else:
            time_t[t] = time_1[0]
            action_t[t] = action_1[0]

    if comm.rank == 0:
        # The root rank must count the number of time entries
        nreadt = 0
        lasttime = -np.inf

        readtimes = []
        timereadert = []
        initreadert = None
        for t in range(readerlen):
            action = int2action[action_t[t]]
            curtime = time_t[t]

            if curtime <= lasttime:
                log(f'Negative time difference at #{t} '
                    f'({curtime:.1f} <= {lasttime:.1f}). Skipping',
                    who='Reader', comm=comm, flush=True)
                continue

            if action is None:
                # This is just some dummy entry
                continue

            # Find the first 'init' entry
            if action == 'init':
                assert curtime >= lasttime, f'Times not matching {t}:{curtime} !>= {lasttime}'
                if initreadert is None:
                    initreadert = t
                continue

            assert action in ['kick', 'propagate']
            # NaN compares unequal to everything, so test with isnan
            assert not np.isnan(curtime), f'Missing time for entry #{t}'
            if nreadt % stridet == 0:
                # Save multiples of stridet
                # Note that 0 will always be saved. That contains the kick
                readtimes.append(curtime)
                timereadert.append(t)
            nreadt += 1
            if tmax > 0 and nreadt >= tmax:
                break
            lasttime = curtime

        assert initreadert is not None
        broadcast_obj = (initreadert, timereadert, np.array(readtimes))
    else:
        broadcast_obj = None

    # Distribute
    (initreadert, timereadert, time_t) = broadcast(broadcast_obj, comm=comm, root=0)
    initreader = mainreader[initreadert]
    timereader_t = [mainreader[t] for t in timereadert]

    log(f'Opened time-dependent wave functions file with {len(timereader_t)} times',
        who='Reader', rank=0, flush=True)

    return time_t, initreader, timereader_t

738 

739 

def read_C0S_parallel(ksdreader: Reader,
                      comm=None) -> NDArray[np.float64]:
    """ Compute the product of the ground-state LCAO coefficients and the overlap matrix.

    The contracted basis index is split into contiguous chunks, one per rank.
    Each rank contracts its chunk, and the partial products are summed over
    comm, so this is a collective call and every rank receives the full result.

    Parameters
    ----------
    ksdreader
        Reader providing the arrays 'C0_unM' and 'S_uMM' through proxy().
    comm
        MPI communicator; world is used if None.

    Returns
    -------
    Array C0S_sknM = sum_O C0_sknO S_skOM.
    """
    comm = world if comm is None else comm

    coeff_proxy = ksdreader.proxy('C0_unM', 0)
    overlap_proxy = ksdreader.proxy('S_uMM', 0)
    basis_size = coeff_proxy.shape[3]

    # Ceiling division: every rank gets a chunk of at most this many basis functions
    chunk = -(-basis_size // comm.size)
    first = chunk * comm.rank
    my_chunk = slice(first, first + chunk)

    coeff_local = coeff_proxy[:][..., my_chunk]
    overlap_local = overlap_proxy[:][..., my_chunk, :]

    C0S_sknM = np.einsum('sknO,skOM->sknM', coeff_local, overlap_local, optimize=True)

    # In-place sum of the partial products over all ranks
    comm.sum(C0S_sknM)
    return C0S_sknM

763 

764 

def calculate_rho0_parallel(f_skn: NDArray[np.float64],
                            C0_sknM: NDArray[np.float64],
                            comm=None) -> NDArray[np.float64]:
    """ Compute the ground-state density matrix in the LCAO basis.

    The states are split into contiguous chunks, one per rank. Each rank sums
    the contributions of its chunk, and the partial matrices are summed over
    comm, so this is a collective call and every rank receives the full result.

    Parameters
    ----------
    f_skn
        Occupation numbers.
    C0_sknM
        Real ground-state LCAO coefficients.
    comm
        MPI communicator; world is used if None.

    Returns
    -------
    Array rho0_skMO = sum_n f_skn C0_sknM C0_sknO.
    """
    if comm is None:
        comm = world

    nstates = f_skn.shape[2]
    # Ceiling division: every rank gets at most this many states
    chunk = -(-nstates // comm.size)
    first = chunk * comm.rank
    my_states = slice(first, first + chunk)

    occ_local = f_skn[:][..., my_states]
    coeff_local = C0_sknM[:][..., my_states, :]

    rho0_skMM = np.einsum('skn,sknM,sknO->skMO',
                          occ_local, coeff_local, coeff_local,
                          optimize=True)

    # In-place sum of the partial density matrices over all ranks
    comm.sum(rho0_skMM)
    return rho0_skMM

786 

787 

def proxy_coefficients(reader, *indices):
    """ Proxy the wave function coefficients, with the correct units.

    Returns the lazy proxy of reader.wave_functions.coefficients at the given
    leading indices, with its scale set to Bohr ** 1.5.
    """
    wave_functions = reader.wave_functions
    coefficients_proxy = wave_functions.proxy('coefficients', *indices)
    # Unit conversion factor applied by the proxy when data is read
    coefficients_proxy.scale = Bohr ** 1.5
    return coefficients_proxy

794 

795 

def proxy_occupations(reader, *indices):
    """ Proxy the wave function occupations with the correct scale.

    Returns the lazy proxy of reader.wave_functions.occupations at the given
    leading indices, with its scale set to 2 divided by the length of the
    first axis of the occupations array (presumably the number of spin
    channels - TODO confirm).
    """
    wave_functions = reader.wave_functions
    occupations_proxy = wave_functions.proxy('occupations', *indices)
    first_axis_len = wave_functions.occupations.shape[0]
    occupations_proxy.scale = 2 / first_axis_len
    return occupations_proxy

802 

803 

def proxy_C_nM(reader,
               *indices):
    """ Proxy the wave function coefficients, with the correct units.

    The last two indices select states (n) and basis functions (M); any
    indices before them are passed on to proxy_coefficients. When n is a
    single index, the result is promoted to a 2D array, so the returned
    array always has exactly two dimensions."""
    leading, n, M = indices[:-2], indices[-2], indices[-1]

    coefficients = proxy_coefficients(reader, *leading)
    if not isinstance(n, slice):
        # Single state: proxy one row, then restore the state axis
        C_x = np.atleast_2d(coefficients.proxy(n)[M])
    else:
        C_x = coefficients[n][:, M]

    assert len(C_x.shape) == 2
    return C_x

820 return C_x