Coverage for rhodent/density_matrices/distributed/base.py: 84%

223 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3from abc import ABC, abstractmethod 

4from typing import Generator, Generic, NamedTuple, Iterable 

5from itertools import product, zip_longest 

6 

7import numpy as np 

8 

9from gpaw.mpi import world 

10from gpaw.lcaotddft.ksdecomposition import KohnShamDecomposition 

11 

12from ..buffer import DensityMatrixBuffer 

13from ..readers.gpaw import KohnShamRhoWfsReader 

14from ...utils import DTypeT, Logger, concatenate_indices, env 

15from ...typing import Communicator 

16from ...utils.memory import HasMemoryEstimate 

17 

18 

class BaseDistributor(HasMemoryEstimate, ABC, Generic[DTypeT]):

    """ Distribute density matrices over time, frequency or other dimensions across MPI ranks.

    Parameters
    ----------
    rho_reader
        Reader of Kohn-Sham density matrices. It also supplies the
        KohnShamDecomposition object, the logger and the flags telling
        whether real and/or imaginary parts are calculated.
    parameters
        Description of how the density matrix is split into chunks. If ``None``,
        the parameters are set up from the KohnShamDecomposition with default options.
    comm
        MPI communicator. If ``None``, the GPAW world communicator is used.
    """

    def __init__(self,
                 rho_reader: KohnShamRhoWfsReader,
                 parameters: RhoParameters | None = None,
                 comm: Communicator | None = None):
        self.rho_wfs_reader = rho_reader

        # The communicator must be set before the default parameters are
        # built, because the ``comm`` property is read when building them.
        self._comm = world if comm is None else comm
        if parameters is None:
            parameters = RhoParameters.from_ksd(self.ksd, self.comm)
        self._parameters = parameters

        # Derivative orders held in the buffers (see collect_on_root)
        self.derivative_order_s = [0]

    @property
    @abstractmethod
    def dtype(self) -> np.dtype[DTypeT]:
        """ Dtype of buffers. """
        raise NotImplementedError

    @property
    @abstractmethod
    def xshape(self) -> tuple[int, ...]:
        """ Shape of x-dimension in buffers. """
        raise NotImplementedError

    @property
    def ksd(self) -> KohnShamDecomposition:
        """ Kohn-Sham decomposition object. """
        return self.rho_wfs_reader.ksd

    @property
    def comm(self) -> Communicator:
        """ MPI communicator. """
        return self._comm

    @property
    def yield_re(self) -> bool:
        """ Whether real part of density matrices is calculated. """
        return self.rho_wfs_reader.yield_re

    @property
    def yield_im(self) -> bool:
        """ Whether imaginary part of density matrices is calculated. """
        return self.rho_wfs_reader.yield_im

    @property
    def log(self) -> Logger:
        """ Logger object. """
        return self.rho_wfs_reader.log

    @abstractmethod
    def __iter__(self) -> Generator[DensityMatrixBuffer, None, None]:
        """ Yield density matrices in parts. Different data is
        yielded on different ranks

        Yields
        ------
        Part of the density matrix
        """
        raise NotImplementedError

    def work_loop(self,
                  rank: int) -> Generator[RhoIndices | None, None, None]:
        """ Like :func:`work_loop_by_ranks` but for one particular rank.

        Parameters
        ----------
        rank
            Rank to yield the work for.

        Yields
        ------
        Indices of the chunk handled by :attr:`rank` in each iteration, or
        None if that rank has nothing to do in the iteration.
        """
        for chunks_r in self.work_loop_by_ranks():
            yield chunks_r[rank]

    @property
    def niters(self) -> int:
        """ Number of iterations needed to read all chunks. """
        # Count lazily instead of materializing every list of chunks
        return sum(1 for _ in self.work_loop_by_ranks())

    @property
    def maxntimes(self) -> int:
        """ Maximum number of ranks participating in reading of times.

        Counted as the number of ranks with work in the first iteration
        of the work loop of the reader.

        Raises
        ------
        RuntimeError
            If the work loop of the reader is empty.
        """
        for t_r in self.rho_wfs_reader.work_loop_by_ranks():
            return sum(1 for t in t_r if t is not None)

        raise RuntimeError

    @property
    def maxnchunks(self) -> int:
        """ Maximum number of ranks participating in reading of chunks.

        Counted as the number of ranks with work in the first iteration of
        :func:`work_loop_by_ranks`; only the last iteration of that loop can
        have idle ranks, so no later iteration has more participants.

        Raises
        ------
        RuntimeError
            If there are no chunks at all.
        """
        for chunks_r in self.work_loop_by_ranks():
            return sum(1 for chunk in chunks_r if chunk is not None)

        raise RuntimeError

    def describe_reim(self) -> str:
        """ Human-readable description of which parts (real/imaginary) are calculated. """
        if self.yield_re and self.yield_im:
            return 'Real and imaginary parts'
        elif self.yield_re:
            return 'Real part'
        else:
            return 'Imaginary part'

    def describe_derivatives(self) -> str:
        """ Human-readable description of the derivative orders in the buffers. """
        return 'derivative orders: ' + ', '.join([str(d) for d in self.derivative_order_s])

    def work_loop_by_ranks(self) -> Generator[list[RhoIndices | None], None, None]:
        """ Yield slice objects corresponding to the chunk of the density matrix
        that is gathered on each rank.

        New indices are yielded until the entire density matrix is processed
        (across all ranks).

        Yields
        ------
        List of slice objects corresponding to part of the density matrix
        yielded on each ranks. None in place of the slice object if there is
        nothing yielded for that rank.
        """
        gen = self._parameters.iterate_indices()
        size = self.comm.size

        while True:
            # Take up to one chunk per rank. The range comes first in the zip
            # so that no extra chunk is consumed from gen and then dropped.
            chunks_r: list[RhoIndices | None] = [indices for _, indices
                                                 in zip(range(size), gen)]
            if not chunks_r:
                # There is nothing left to do for any rank
                break

            remaining = size - len(chunks_r)
            # Append Nones for the ranks that are not doing anything
            chunks_r += remaining * [None]
            yield chunks_r

            if remaining > 0:
                # The generator ran dry while filling this batch
                break

    def gather_on_root(self) -> Generator[DensityMatrixBuffer | None, None, None]:
        """ Iterate over all chunks, sending each of them to the root rank.

        Must be iterated on all ranks of the communicator, since the chunks
        are exchanged by matching send and receive calls.

        Yields
        ------
        On the root rank, a copy of the buffer for each chunk; within every
        iteration of :func:`work_loop_by_ranks` in order of the rank that
        produced it. On other ranks None is yielded in place of each chunk.
        """
        self.rho_wfs_reader.C0S_sknM   # Make sure to read this synchronously

        for indices_r, dm_buffer in zip_longest(self.work_loop_by_ranks(),
                                                self, fillvalue=None):
            assert indices_r is not None, 'Work loop shorter than work'

            # Yield root's own work
            if self.comm.rank == 0:
                assert indices_r[0] is not None
                assert dm_buffer is not None
                dm_buffer.ensure_contiguous_buffers()

                yield dm_buffer.copy()
            else:
                yield None

            # Yield the work of non-root
            for recvrank, recvindices in enumerate(indices_r[1:], start=1):
                if recvindices is None:
                    # No work on this recvrank
                    continue

                if self.comm.rank == 0:
                    # Receive work; the buffer of root is reused as receive
                    # buffer, which is why copies are yielded
                    assert dm_buffer is not None
                    dm_buffer.recv_arrays(self.comm, recvrank, log=self.log)
                    yield dm_buffer.copy()
                else:
                    # Send work to root if there is any
                    if self.comm.rank == recvrank:
                        assert dm_buffer is not None
                        dm_buffer.send_arrays(self.comm, 0, log=self.log)
                    yield None

    def collect_on_root(self) -> DensityMatrixBuffer | None:
        """ Assemble the full density matrix on the root rank.

        Must be called on all ranks of the communicator.

        Returns
        -------
        Buffer with the full density matrix on the root rank, None on other ranks.
        """
        gen = self._parameters.iterate_indices()

        full_dm = DensityMatrixBuffer(self._parameters.full_nnshape, self.xshape, dtype=self.dtype)
        full_dm.zero_buffers(real=self.yield_re, imag=self.yield_im, derivative_order_s=self.derivative_order_s)

        for indices, dm_buffer in zip_longest(gen,
                                              self.gather_on_root(), fillvalue=None):
            if self.comm.rank != 0:
                # Non-root ranks only need to drive gather_on_root
                continue

            assert indices is not None, 'Iterators must be same length'
            assert dm_buffer is not None, 'Iterators must be same length'

            s, k, n1, n2 = indices
            assert s == 0
            assert k == 0

            for partial_data, full_data in zip(dm_buffer._iter_buffers(), full_dm._iter_buffers()):
                # Chunks at the upper edges can be smaller than the stride,
                # so only the part fitting in the full matrix is taken
                _nn1, _nn2 = full_data[n1, n2].shape[:2]
                # Numpy struggles with the static type below
                full_data[n1, n2, :] += partial_data[:_nn1, :_nn2:]  # type: ignore
            self.log(f'Collected on root: density matrix slice [s={s}, k={k}, n1={n1}, n2={n2}].',
                     flush=True, who='Response')

        if self.comm.rank != 0:
            return None

        return full_dm

    @classmethod
    @abstractmethod
    def from_reader(cls,
                    rho_nn_reader: KohnShamRhoWfsReader,
                    parameters: RhoParameters,
                    **kwargs) -> BaseDistributor:
        """ Set up this class from a density matrix reader and parameters object

        """
        raise NotImplementedError

    @classmethod
    def from_parameters(cls,
                        wfs_fname: str,
                        ksd: KohnShamDecomposition | str,
                        comm=world,
                        yield_re: bool = True,
                        yield_im: bool = True,
                        stridet: int = 1,
                        log: Logger | None = None,
                        verbose: bool = False,
                        **kwargs):
        """ Set up this class, trying to enforce memory limit.

        The density matrix is split in successively more chunks until the
        memory estimate is below the limit, or until splitting further
        stops giving a significant improvement.

        Parameters
        ----------
        wfs_fname
            File name of the time-dependent wave functions file.
        ksd
            KohnShamDecomposition object or file name to the ksd file.
        comm
            MPI communicator. ``None`` is treated as the world communicator.
        yield_re
            Whether to read and yield the real part of wave functions/density matrices.
        yield_im
            Whether to read and yield the imaginary part of wave functions/density matrices.
        stridet
            Skip this many steps when reading the time-dependent wave functions file.
        log
            Logger object.
        verbose
            Be verbose in the attempts to satisfy memory requirement.
        kwargs
            Options passed through to :func:`from_reader`.

        Returns
        -------
        Distributor satisfying the memory limit if one is found, otherwise
        the distributor of the last attempt.
        """
        # Same fallback as in the constructor and in RhoParameters.from_ksd
        if comm is None:
            comm = world

        # Set up the time-dependent wave functions reader which is always needed
        rho_reader = KohnShamRhoWfsReader(
            wfs_fname=wfs_fname, ksd=ksd, comm=comm,
            yield_re=yield_re, yield_im=yield_im, log=log, stridet=stridet)

        log = rho_reader.log

        # Get the target memory limit. It is kept in the same units as the
        # memory estimates; multiply by to_MiB to print a value in MiB.
        to_MiB = 1024 ** -2
        mem_limit = env.get_response_max_mem(comm.size) / to_MiB
        log('Attempting to set up response calculation with memory limit of '
            f'{mem_limit * to_MiB:.1f} MiB across all ranks.', who='Setup', rank=0)

        totals = []
        for iterations in range(1, 100):
            # Try setting up the distributor such that `iterations` iterations are
            # needed to process all chunks
            parameters = RhoParameters.from_ksd(rho_reader.ksd, comm, chunk_iterations=iterations)
            distributor = cls.from_reader(rho_reader, parameters, **kwargs)
            total = distributor.get_memory_estimate().grand_total
            totals.append(total)
            compare = totals[:-5:-1]  # Last 4 totals in reverse order
            # Ratios new/old of consecutive estimates, most recent first
            last_changes = [tot_new / tot_old for tot_new, tot_old in zip(compare, compare[1:])]
            if not last_changes:
                improvement = ''
            else:
                s = ', '.join([f'{(1 - change)*100:.1f}%' for change in last_changes])
                improvement = f'Last improvements {s}'

            if verbose:
                log(f'Trying splitting in {distributor.niters:3} chunks -- estimate {total * to_MiB:.1f} MiB. '
                    f'{improvement}', who='Setup', rank=0)
            if total < mem_limit:
                log(f'Found suitable set of parameters limiting the memory to {total * to_MiB:.1f} MiB.',
                    who='Setup', rank=0)
                return distributor
            # Give up when the last three steps improved the estimate by
            # less than 2% on average
            if len(last_changes) == 3 and sum(last_changes) / 3 > 0.98:
                break

        # `distributor` and `total` already correspond to the last attempted
        # number of iterations, so there is no need to set them up again.
        log(f'Cannot satisfy memory limit. Estimate is {total * to_MiB:.1f} MiB.',
            who='Setup', rank=0)

        return distributor

313 

314 

class RhoIndices(NamedTuple):

    """ Indices addressing one chunk of the density matrix.

    Parameters
    ----------
    s
        Spin index.
    k
        Kpoint index.
    n1
        Slice selecting the rows of the chunk.
    n2
        Slice selecting the columns of the chunk.
    """

    s: int
    k: int
    n1: slice
    n2: slice

    @staticmethod
    def concatenate_indices(indices_list: Iterable[RhoIndices],
                            ) -> tuple[RhoIndices, list[RhoIndices]]:
        """ Concatenate several chunks sharing the same spin and kpoint.

        The row and column slices are handed to the module-level helper
        :func:`concatenate_indices`; spin and kpoint indices are carried over.

        Parameters
        ----------
        indices_list
            Non-empty collection of indices, all with identical s and k.

        Returns
        -------
        The concatenated indices, and one reduced set of indices per result of
        the helper (presumably relative to the concatenated block -- confirm
        against the helper in utils).
        """
        chunks = list(indices_list)
        assert len(chunks) > 0
        first = chunks[0]
        s, k = first.s, first.k
        assert all(chunk.s == s for chunk in chunks), f'All s must be identical {chunks}'
        assert all(chunk.k == k for chunk in chunks), f'All k must be identical {chunks}'

        # Only the (n1, n2) slice pairs take part in the concatenation
        nn_pairs = [(chunk.n1, chunk.n2) for chunk in chunks]
        nn_concat, nn_reduced_list = concatenate_indices(nn_pairs)

        reduced_indices_list = [RhoIndices(s, k, *nn) for nn in nn_reduced_list]

        return RhoIndices(s, k, *nn_concat), reduced_indices_list

337 

338 

class RhoParameters(NamedTuple):

    """ Utility class to describe density matrix indices.

    Parameters
    ----------
    ns
        Number of spins.
    nk
        Number of kpoints.
    n1min
        Smallest index of row to read.
    n1max
        Largest index of row to read.
    n2min
        Smallest index of column to read.
    n2max
        Largest index of column to read.
    striden1
        Stride for reading rows. Each chunk will be this size in the first dimension.
    striden2
        Stride for reading columns. Each chunk will be this size in the second dimension.
    """

    ns: int
    nk: int
    n1min: int
    n1max: int
    n2min: int
    n2max: int
    striden1: int = 4
    striden2: int = 4

    # NOTE: There is deliberately no __post_init__ here. That hook belongs to
    # dataclasses; a NamedTuple never calls it and its fields cannot be
    # reassigned. Strides are limited to the matrix size in from_ksd instead.

    @property
    def full_nnshape(self) -> tuple[int, int]:
        """ Shape of the full density matrix to be read. """
        return (self.n1size, self.n2size)

    @property
    def nnshape(self) -> tuple[int, int]:
        """ Shape of each density matrix chunk. """
        return (self.striden1, self.striden2)

    @property
    def n1size(self) -> int:
        """ Size of full density matrix in the first dimension. """
        return self.n1max + 1 - self.n1min

    @property
    def n2size(self) -> int:
        """ Size of full density matrix in the second dimension. """
        return self.n2max + 1 - self.n2min

    def iterate_indices(self) -> Generator[RhoIndices, None, None]:
        """ Iteratively yield indices slicing chunks of the density matrix.

        The slices are relative to the part of the density matrix that is
        read, i.e. they start from 0 and not from n1min and n2min. The last
        slice in each dimension may extend past the size of the matrix.

        Yields
        ------
        Indices for every combination of spin, kpoint, row chunk and column chunk.
        """
        for s, k, n1, n2 in product(range(self.ns), range(self.nk),
                                    range(0, self.n1size, self.striden1),
                                    range(0, self.n2size, self.striden2)):
            # Carry the spin and kpoint of the loop; hard-coding them to 0
            # would yield every chunk ns * nk times under the same label
            indices = RhoIndices(s=s, k=k,
                                 n1=slice(n1, n1 + self.striden1),
                                 n2=slice(n2, n2 + self.striden2))
            yield indices

    @classmethod
    def from_ksd(cls,
                 ksd: KohnShamDecomposition,
                 comm: Communicator | None = None,
                 only_ia: bool = True,
                 chunk_iterations: int = 1,
                 **kwargs) -> RhoParameters:
        """ Initialize from KohnShamDecomposition.

        Parameters
        ----------
        ksd
            KohnShamDecomposition.
        comm
            MPI Communicator.
        only_ia
            ``True`` if the parameters should be set up such that
            the electron-hole part of the density matrix is read,
            otherwise full density matrix.
        chunk_iterations
            Attempt to set up the strides so that the total number of
            chunks is as close as possible but not more than the number
            of MPI ranks times :attr:`chunk_iterations`.
        kwargs
            Options passed through to the constructor. Strides larger than
            the size of the density matrix are reduced to that size.
        """
        if comm is None:
            comm = world

        # Number of spins, kpoints and states
        ns, nk, nn, _ = ksd.reader.proxy('C0_unM', 0).shape

        params = dict()
        if only_ia:
            # Dimensions of electron-hole part
            imin, imax, amin, amax = [int(i) for i in ksd.ialims()]

            params['n1min'], params['n2min'] = imin, amin
            params['n1max'], params['n2max'] = imax, amax
        else:
            params['n1min'], params['n2min'] = 0, 0
            params['n1max'], params['n2max'] = nn - 1, nn - 1

        # Set up a helper object get the size
        helper = cls(ns, nk, **params)

        # We want this many chunks in total
        target_nchunks = chunk_iterations * comm.size
        ar = helper.n2size / helper.n1size  # Aspect ratio of density matrix

        # Split so that the chunks have roughly the aspect ratio of the
        # full matrix; the second number of splits is rounded up
        nsplits1 = max(int(np.floor(np.sqrt(target_nchunks / ar))), 1)
        nsplits2 = (target_nchunks + nsplits1 - 1) // nsplits1

        # Defaults (sizes divided by number of splits, rounded up)
        params['striden1'] = (helper.n1size + nsplits1 - 1) // nsplits1
        params['striden2'] = (helper.n2size + nsplits2 - 1) // nsplits2

        # Overwrite the default options in params with explicitly set options
        params.update(**kwargs)

        result = cls(ns, nk, **params)

        # A chunk never needs to be larger than the matrix itself. The sizes
        # are taken from the final object since kwargs may change the limits.
        striden1, striden2 = result.striden1, result.striden2
        if 0 < result.n1size < striden1:
            striden1 = result.n1size
        if 0 < result.n2size < striden2:
            striden2 = result.n2size

        return result._replace(striden1=striden1, striden2=striden2)