Coverage for rhodent/density_matrices/base.py: 68%

148 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3from abc import ABC, abstractmethod 

4from typing import Generic, Generator, NamedTuple, TypeVar 

5 

6from gpaw.lcaotddft.ksdecomposition import KohnShamDecomposition 

7 

8from .density_matrix import DensityMatrix 

9from ..utils import Logger, add_fake_kpts, two_communicators 

10from ..utils.memory import HasMemoryEstimate, MemoryEstimate 

11from ..typing import Communicator 

12 

13 

class WorkMetadata(NamedTuple):
    """ Metadata describing one unit of work on a collection of density matrices.

    Subclasses provide ``global_indices`` (used to identify the work and to
    build the string representation) and ``desc`` (used in log and error messages).
    """
    density_matrices: BaseDensityMatrices

    @property
    def global_indices(self) -> tuple[int, ...]:
        """ Unique index for this work. Must be provided by subclasses. """
        raise NotImplementedError

    @property
    @abstractmethod
    def desc(self) -> str:
        """ Short description of this work. Must be provided by subclasses. """
        raise NotImplementedError

    def __str__(self) -> str:
        # Class name followed by the tuple of global indices, e.g. 'Name(0, 3)'
        cls_name = type(self).__name__
        indices = self.global_indices
        return f'{cls_name}{indices}'

    # repr and str are intentionally the same
    __repr__ = __str__


# Type of the work metadata yielded by a particular BaseDensityMatrices subclass
WorkMetadataT = TypeVar('WorkMetadataT', bound=WorkMetadata)

36 

37 

class BaseDensityMatrices(HasMemoryEstimate, ABC, Generic[WorkMetadataT]):
    """
    Collection of density matrices in the Kohn-Sham basis for different times
    or frequencies, possibly after convolution with various pulses.

    Plain density matrices and/or derivatives thereof may be represented.

    Parameters
    ----------
    ksd
        KohnShamDecomposition object or file name.
    real
        Calculate the real part of density matrices.
    imag
        Calculate the imaginary part of density matrices.
    calc_size
        Size of the calculation communicator.
    log
        Logger object. If None, a new Logger is created.

    Notes
    -----
    At least one of ``real`` and ``imag`` must be True.
    """

    # NOTE: the docstring above must be the first statement of the class body;
    # placed after these annotations it would not become ``__doc__``.
    _log: Logger
    _ksd: KohnShamDecomposition

    def __init__(self,
                 ksd: KohnShamDecomposition | str,
                 real: bool = True,
                 imag: bool = True,
                 calc_size: int = 1,
                 log: Logger | None = None):
        assert real or imag
        self._reim_r: list[str] = []
        if real:
            self._reim_r.append('Re')
        if imag:
            self._reim_r.append('Im')

        self._log = Logger() if log is None else log

        # NOTE(review): -1 presumably lets two_communicators derive the loop
        # communicator size from calc_size — confirm against two_communicators
        self._loop_comm, self._calc_comm = two_communicators(-1, calc_size)
        if isinstance(ksd, KohnShamDecomposition):
            self._ksd = ksd
        else:
            # ksd is a file name; add_fake_kpts is only applied to decompositions
            # read here, not to objects passed in by the caller
            self._ksd = KohnShamDecomposition(filename=ksd)
            add_fake_kpts(self._ksd)

        # Do a quick sanity check at runtime. This calls the abstract work_loop,
        # so any subclass state work_loop relies on must exist at this point.
        self._runtime_verify_work_loop()

    @abstractmethod
    def __str__(self) -> str:
        raise NotImplementedError

    def get_memory_estimate(self) -> MemoryEstimate:
        """ Memory estimate. The base class does not know its memory usage. """
        return MemoryEstimate(comment='Unknown')

    def parallel_prepare(self):
        """ Read everything necessary synchronously on all ranks.

        The base implementation does nothing; subclasses may override it.
        """

    @abstractmethod
    def __iter__(self) -> Generator[tuple[WorkMetadataT, DensityMatrix], None, None]:
        """ Obtain density matrices for various times, pulses or frequencies.

        Yields
        ------
        Tuple (work, dm) on the root rank of the calculation communicator:

        work
            An object representing the metadata (time, frequency or pulse) for the work done.
        dm
            Density matrix for this time, frequency or pulse.
        """
        raise NotImplementedError

    def iread_gather_on_root(self) -> Generator[tuple[WorkMetadataT, DensityMatrix], None, None]:
        """ Obtain density matrices for various times, pulses or frequencies and gather to the root rank.

        Yields
        ------
        Tuple (work, dm) on the root rank of the loop and calculation communicators:

        work
            An object representing the metadata (time, frequency or pulse) for the work done.
        dm
            Density matrix for this time, frequency or pulse.
        """
        work: WorkMetadataT | None
        gen = iter(self)

        # Loop over the work to be done, and the ranks that are supposed to do it
        self.parallel_prepare()
        for rank, work in self.global_work_loop_with_idle():
            if work is None:
                # Rank rank will not do any work at this point
                continue

            if rank == self.loop_comm.rank:
                # This rank owns the current work item: advance the local generator
                # and check that it agrees with the global work plan
                mywork, mydm = next(gen)
                if self.calc_comm.rank == 0:
                    self.log(f'Read {mywork.desc} in {self.log.elapsed("read"):.1f}s',
                             who='Response', if_elapsed=5)
                assert work.global_indices == mywork.global_indices, f'{work.desc} != {mywork.desc}'

            # All ranks reach this call; only the owning rank supplies the matrix
            dm = DensityMatrix.broadcast(
                mydm if self.loop_comm.rank == rank else None,
                ksd=self.ksd,
                root=rank, dm_comm=self.calc_comm, comm=self.loop_comm)

            yield work, dm

        # The local generator must not hold more work than the plan assigned
        _exhausted = object()
        rem = next(gen, _exhausted)
        assert rem is _exhausted, rem

    @property
    def ksd(self) -> KohnShamDecomposition:
        """ Kohn-Sham decomposition object. """
        return self._ksd

    @property
    def log(self) -> Logger:
        """ Logger. """
        return self._log

    def log_parallel(self, *args, **kwargs) -> Logger:
        """ Log message with communicator information. """
        return self._log(*args, **kwargs, comm=self.loop_comm, who='Response')

    @property
    def reim(self) -> list[str]:
        """ List of strings ``'Re'`` and ``'Im'``, depending on whether real, and/or imaginary parts are computed. """
        return self._reim_r

    @abstractmethod
    def work_loop(self,
                  rank: int) -> Generator[WorkMetadataT | None, None, None]:
        """ The work to be done by a certain rank of the loop communicator.

        Parameters
        ----------
        rank
            Rank of the loop communicator.

        Yields
        ------
        Objects representing the time, frequency or pulse to be computed by rank ``rank``.
        None is yielded when `rank` does not do any work while other ranks are doing work.
        """
        raise NotImplementedError

    def _runtime_verify_work_loop(self):
        """ Verify that the description of work to be done is consistent across ranks.

        Checks that every rank's work loop has the same length (including idle
        None entries) and that no two work items share the same global indices.
        """
        local_work_r = [list(self.work_loop(rank)) for rank in range(self.loop_comm.size)]
        work_lengths = [len(local_work) for local_work in local_work_r]
        assert all(length == work_lengths[0] for length in work_lengths), \
            f'The work loop has different length across the different ranks. {work_lengths}'
        concat_work_list = [work.global_indices for local_work in local_work_r for work in local_work
                            if work is not None]
        assert len(concat_work_list) == len(set(concat_work_list)), \
            f'Different ranks do duplicate work {concat_work_list}'

    @property
    def local_work_plan(self) -> tuple[WorkMetadataT, ...]:
        """ The work to be done by this rank of the loop communicator.

        Returns
        -------
        Tuple of objects representing the time, frequency or pulse to be computed by this rank.
        Idle (None) entries are excluded.
        """
        return tuple(work for work in self.work_loop(self.loop_comm.rank)
                     if work is not None)

    @property
    def local_work_plan_with_idle(self) -> tuple[WorkMetadataT | None, ...]:
        """ The work to be done by this rank of the loop communicator.

        This property includes None values for when this rank does not do any work
        in order to synchronize the execution.

        Returns
        -------
        Tuple of objects representing the time, frequency or pulse to be computed by this rank.
        None marks steps where this rank does not do any work while other ranks are doing work.
        """
        return tuple(self.work_loop(self.loop_comm.rank))

    def global_work_loop_with_idle(self) -> Generator[tuple[int, WorkMetadataT | None], None, None]:
        """ The work to be done by all ranks of the loop communicator.

        This function includes None values for when ranks do not do any work
        in order to synchronize the execution.

        Yields
        ------
        Tuples ``(rank, work)``, cycling over all ranks of the loop communicator in
        order. ``work`` is the next item of that rank's work loop, or None when the
        rank is idle at this step. See `local_work_plan_with_idle`.

        Raises
        ------
        RuntimeError
            If a rank other than rank 0 runs out of work first.
        """
        work_loop_r = [self.work_loop(rank) for rank in range(self.loop_comm.size)]
        _exhausted = object()
        while True:
            for rank, work_loop in enumerate(work_loop_r):
                work = next(work_loop, _exhausted)
                if work is _exhausted:
                    if rank == 0:
                        # No more work to do
                        return
                    raise RuntimeError(f'Ranks have different amount of work. Exited on {rank}')
                yield rank, work

    def global_work_loop(self) -> Generator[tuple[int, WorkMetadataT], None, None]:
        """ The work to be done by all ranks of the loop communicator.

        Yields
        ------
        Tuples ``(rank, work)`` for every non-idle work item, in the same order as
        :func:`global_work_loop_with_idle`. See :func:`local_work_plan`.
        """
        for rank, work in self.global_work_loop_with_idle():
            if work is None:
                continue
            yield rank, work

    @property
    def localn(self) -> int:
        """ Total number of density matrices this rank will work with. """
        return len(self.local_work_plan)

    @property
    def globaln(self) -> int:
        """ Total number of density matrices to work with across all ranks. """
        return sum(1 for rank in range(self.loop_comm.size)
                   for work in self.work_loop(rank)
                   if work is not None)

    @property
    def calc_comm(self) -> Communicator:
        """ Calculation communicator.

        Each rank of this communicator calculates the observables corresponding to
        a part (in electron-hole space) of the density matrices. """
        return self._calc_comm

    @calc_comm.setter
    def calc_comm(self, value: Communicator | None):
        # None selects the gpaw world communicator
        from gpaw.mpi import world
        if value is None:
            self.calc_comm = world
            return

        assert hasattr(value, 'rank')
        assert hasattr(value, 'size')
        self._calc_comm = value

    @property
    def loop_comm(self) -> Communicator:
        """ Loop communicator.

        Each rank of this communicator calculates the density matrices corresponding to
        different times, frequencies or after convolution with a different pulse. """
        return self._loop_comm

    @loop_comm.setter
    def loop_comm(self, value: Communicator | None):
        # None selects the gpaw world communicator. Mirrors the calc_comm setter;
        # the previous unconditional `raise NotImplementedError` after the
        # assignment left the object modified while reporting failure.
        from gpaw.mpi import world
        if value is None:
            self.loop_comm = world
            return

        assert hasattr(value, 'rank')
        assert hasattr(value, 'size')
        self._loop_comm = value

    @abstractmethod
    def write_to_disk(self,
                      fmt: str):
        """ Calculate the density matrices and save to disk.

        Parameters
        ----------
        fmt
            Formatting string.
        """
        raise NotImplementedError