Coverage for rhodent/writers/writer.py: 96%

185 statements  

coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Generic, TypeVar
import numpy as np
from numpy.typing import NDArray

from gpaw.mpi import world
from gpaw.io import Writer as GPAWWriter

from ..density_matrices.base import WorkMetadata, WorkMetadataT, BaseDensityMatrices
from ..density_matrices.frequency import FrequencyDensityMatrices, FrequencyDensityMatrixMetadata
from ..density_matrices.time import ConvolutionDensityMatrices, ConvolutionDensityMatrixMetadata
from ..calculators.base import BaseObservableCalculator
from ..voronoi import VoronoiWeights, EmptyVoronoiWeights, atom_projections_to_numpy
from ..utils import Result, ResultKeys


class ResultsCollector(ABC, Generic[WorkMetadataT]):

    """ Utility class to collect result arrays for different
    times, pulses, or frequencies.

    Parameters
    ----------
    calc
        Calculator.
    calc_kwargs
        Keyword arguments passed to the icalculate function.
    resultkeys
        Result keys to be collected.
    additional_suffix
        String prepended to the suffix of each key.
    additional_dimension
        Shape of additional dimension(s) due to the different times, frequencies, etc.
    exclude
        Keys that are excluded from collection.
    """

    def __init__(self,
                 calc: BaseObservableCalculator,
                 calc_kwargs: dict[str, Any],
                 resultkeys: ResultKeys,
                 additional_suffix: str,
                 additional_dimension: tuple[int, ...],
                 exclude: list[str] = []):
        self.calc = calc
        self.calc_kwargs = calc_kwargs
        self.resultkeys = resultkeys.__copy__()
        for key in exclude:
            if key in self.resultkeys:
                self.resultkeys.remove(key)
        self.additional_dimension = additional_dimension
        self.additional_suffix = additional_suffix

        # Create the new result keys for the aggregated data
        self.collect_resultkeys = ResultKeys()
        for key, shape, dtype in self.resultkeys:
            newkey = self.format_key(key)
            self.collect_resultkeys.add_key(newkey, additional_dimension + shape, dtype)

        self.result = Result(mutable=True)

    def empty_results(self):
        if world.rank == 0:
            self.result.create_all_zeros(self.collect_resultkeys)

    def finalize_results(self):
        pass

    def format_key(self,
                   key: str) -> str:
        """ Add the new suffix to the key.

        Parameters
        ----------
        key
            Original result key.

        Returns
        -------
        New result key with the added suffix.
        """
        shape, _ = self.resultkeys[key]
        if len(shape) == 0:
            return key + f'_{self.additional_suffix}'

        s = key.split('_')
        assert len(s) > 1
        s[-1] = self.additional_suffix + s[-1]
        return '_'.join(s)

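    # Illustration (hypothetical key names, not from the original source):
    # with additional_suffix='w', a scalar key 'energy' (empty shape) would
    # become 'energy_w', while an array-valued key such as 'dipole_v' would
    # become 'dipole_wv', since the suffix is prepended to the last
    # underscore-separated part of the key.
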

    @abstractmethod
    def accumulate_results(self,
                           work: WorkMetadataT,
                           result: Result):
        pass


ResultsCollectorT = TypeVar('ResultsCollectorT', bound=ResultsCollector)


class TimeResultsCollector(ResultsCollector):

    """ Collect results for different times after convolution with a pulse.

    The letter t is prepended to the suffix of the result keys to indicate
    an additional dimension of time.

    Parameters
    ----------
    calc
        Calculator.
    calc_kwargs
        Keyword arguments passed to the icalculate function.
    exclude
        Keys that are excluded from collection.
    """

    def __init__(self,
                 calc: BaseObservableCalculator,
                 calc_kwargs: dict[str, Any],
                 exclude: list[str] = []):
        assert isinstance(calc.density_matrices, ConvolutionDensityMatrices)
        assert len(calc.density_matrices.pulses) == 1
        Nt = len(calc.times)

        resultkeys = calc.get_result_keys(**calc_kwargs)
        super().__init__(calc, calc_kwargs, resultkeys,
                         additional_suffix='t', additional_dimension=(Nt, ), exclude=exclude)

    def accumulate_results(self,
                           work: ConvolutionDensityMatrixMetadata,
                           result: Result):
        assert isinstance(work, ConvolutionDensityMatrixMetadata)
        assert world.rank == 0

        for key, _, _ in self.resultkeys:
            newkey = self.format_key(key)
            self.result.set_to(newkey, work.globalt, result[key])


class TimeAverageResultsCollector(ResultsCollector):

    """ Collect results and average over times.

    Parameters
    ----------
    calc
        Calculator.
    calc_kwargs
        Keyword arguments passed to the icalculate function.
    exclude
        Keys that are excluded from collection.
    """

    def __init__(self,
                 calc: BaseObservableCalculator,
                 calc_kwargs: dict[str, Any],
                 exclude: list[str] = []):
        assert isinstance(calc.density_matrices, ConvolutionDensityMatrices)
        assert len(calc.density_matrices.pulses) == 1

        resultkeys = calc.get_result_keys(**calc_kwargs)
        super().__init__(calc, calc_kwargs, resultkeys,
                         additional_suffix='', additional_dimension=(), exclude=exclude)

    def accumulate_results(self,
                           work: ConvolutionDensityMatrixMetadata,
                           result: Result):
        assert isinstance(work, ConvolutionDensityMatrixMetadata)
        assert world.rank == 0

        for key, _, _ in self.resultkeys:
            newkey = self.format_key(key)
            self.result.add_to(newkey, slice(None), result[key])

    def finalize_results(self):
        if world.rank > 0:
            return

        nt = len(self.calc.density_matrices.times)
        for key, _, _ in self.collect_resultkeys:
            self.result[key] /= nt


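# The collectors follow two accumulation patterns: per-slice collectors store
# each result with Result.set_to at a global index (work.globalt, work.globalp
# or work.globalw), while the averaging collectors accumulate with
# Result.add_to and divide by the number of times in finalize_results.
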

class PulseConvolutionResultsCollector(ResultsCollector):

    """ Collect results after convolution with different pulses.

    The letters pt are prepended to the suffix of the result keys to indicate
    additional dimensions of pulse and time.

    Parameters
    ----------
    calc
        Calculator.
    calc_kwargs
        Keyword arguments passed to the icalculate function.
    exclude
        Keys that are excluded from collection.
    """

    def __init__(self,
                 calc: BaseObservableCalculator,
                 calc_kwargs: dict[str, Any],
                 exclude: list[str] = []):
        assert isinstance(calc.density_matrices, ConvolutionDensityMatrices)
        Np = len(calc.pulses)
        Nt = len(calc.times)

        resultkeys = calc.get_result_keys(**calc_kwargs)
        super().__init__(calc, calc_kwargs, resultkeys,
                         additional_suffix='pt', additional_dimension=(Np, Nt), exclude=exclude)

    def accumulate_results(self,
                           work: ConvolutionDensityMatrixMetadata,
                           result: Result):
        assert isinstance(work, ConvolutionDensityMatrixMetadata)
        assert world.rank == 0

        for key, _, _ in self.resultkeys:
            newkey = self.format_key(key)
            self.result.set_to(newkey, (work.globalp, work.globalt), result[key])


class PulseConvolutionAverageResultsCollector(ResultsCollector):

    """ Collect results after convolution with different pulses, averaged over times.

    The letter p is prepended to the suffix of the result keys to indicate
    an additional dimension of pulse.

    Parameters
    ----------
    calc
        Calculator.
    calc_kwargs
        Keyword arguments passed to the icalculate function.
    exclude
        Keys that are excluded from collection.
    """

    def __init__(self,
                 calc: BaseObservableCalculator,
                 calc_kwargs: dict[str, Any],
                 exclude: list[str] = []):
        assert isinstance(calc.density_matrices, ConvolutionDensityMatrices)
        Np = len(calc.pulses)

        resultkeys = calc.get_result_keys(**calc_kwargs)
        super().__init__(calc, calc_kwargs, resultkeys,
                         additional_suffix='p', additional_dimension=(Np, ), exclude=exclude)

    def accumulate_results(self,
                           work: ConvolutionDensityMatrixMetadata,
                           result: Result):
        assert isinstance(work, ConvolutionDensityMatrixMetadata)
        assert world.rank == 0

        for key, _, _ in self.resultkeys:
            newkey = self.format_key(key)
            self.result.add_to(newkey, work.globalp, result[key])

    def finalize_results(self):
        if world.rank > 0:
            return

        nt = len(self.calc.density_matrices.times)
        for key, _, _ in self.collect_resultkeys:
            self.result[key] /= nt


class FrequencyResultsCollector(ResultsCollector):

    """ Collect results in the frequency domain.

    This class is intended to work with the Fourier transform of the real
    part of the density matrices.

    The letter w is prepended to the suffix of the result keys to indicate
    an additional dimension of frequency.

    Parameters
    ----------
    calc
        Calculator.
    calc_kwargs
        Keyword arguments passed to the icalculate function.
    exclude
        Keys that are excluded from collection.
    """

    def __init__(self,
                 calc: BaseObservableCalculator,
                 calc_kwargs: dict[str, Any],
                 exclude: list[str] = []):
        assert isinstance(calc.density_matrices, FrequencyDensityMatrices)
        Nw = len(calc.frequencies)
        assert 'Im' not in calc.density_matrices.reim

        resultkeys = calc.get_result_keys(**calc_kwargs)
        super().__init__(calc, calc_kwargs, resultkeys,
                         additional_suffix='w', additional_dimension=(Nw, ), exclude=exclude)

    def accumulate_results(self,
                           work: FrequencyDensityMatrixMetadata,
                           result: Result):
        assert isinstance(work, FrequencyDensityMatrixMetadata)
        assert world.rank == 0

        for key, _, _ in self.resultkeys:
            newkey = self.format_key(key)
            self.result.set_to(newkey, work.globalw, result[key])


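# Overview of the collectors defined above (suffix letter and extra leading
# dimension added to each result key):
#
#   TimeResultsCollector                     't'   (Nt,)
#   TimeAverageResultsCollector              ''    ()        averaged over times
#   PulseConvolutionResultsCollector         'pt'  (Np, Nt)
#   PulseConvolutionAverageResultsCollector  'p'   (Np,)     averaged over times
#   FrequencyResultsCollector                'w'   (Nw,)
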

class Writer(Generic[ResultsCollectorT]):

    def __init__(self, collector: ResultsCollectorT):
        self._collector = collector
        self._ulm_tag = 'RhodentResults'

    @property
    def collector(self) -> ResultsCollectorT:
        return self._collector

    @property
    def calc(self) -> BaseObservableCalculator:
        return self.collector.calc

    @property
    def density_matrices(self) -> BaseDensityMatrices:
        return self.collector.calc.density_matrices

    @property
    def voronoi(self) -> VoronoiWeights:
        voronoi = self.calc.voronoi
        if voronoi is None:
            return EmptyVoronoiWeights()
        return voronoi

    @property
    def common_arrays(self) -> dict[str, NDArray[np.float64] | NDArray[np.int64] | int | float]:
        """ Dictionary of eigenvalues and limits. """
        imin, imax, amin, amax = self.calc.ksd.ialims()
        arrays: dict[str, NDArray[np.float64] | NDArray[np.int64] | int | float] = dict()
        arrays['eig_n'] = self.calc.eig_n
        arrays['eig_i'] = self.calc.eig_i
        arrays['eig_a'] = self.calc.eig_a
        arrays['imin'] = imin
        arrays['imax'] = imax
        arrays['amin'] = amin
        arrays['amax'] = amax

        return arrays

    @property
    def icalculate_kwargs(self) -> dict:
        """ Keyword arguments to icalculate. """
        return self.collector.calc_kwargs

    def fill_ulm(self,
                 writer,
                 work: WorkMetadata,
                 result: Result):
        """ Fill one entry of the ULM file.

        Parameters
        ----------
        writer
            Open ULM writer object.
        work
            Metadata for the current piece of data.
        result
            Result containing the current observables.
        """
        raise NotImplementedError

    def write_empty_arrays_ulm(self, writer):
        """ Add empty arrays to the ULM file.

        Parameters
        ----------
        writer
            Open ULM writer object.
        """
        raise NotImplementedError

    def calculate_data(self) -> Result:
        """ Calculate results on all ranks and return a Result object.

        Returns
        -------
        Result object. Empty on non-root ranks.
        """
        self.collector.empty_results()

        for work, res in self.calc.icalculate_gather_on_root(**self.icalculate_kwargs):
            self.collector.accumulate_results(work, res)

        self.collector.finalize_results()

        return self.collector.result

    def calculate_and_save_npz(self,
                               out_fname: str,
                               write_extra: dict[str, Any] = dict()):
        """ Calculate results on all ranks and save to an npz file.

        Parameters
        ----------
        out_fname
            File name.
        """
        result = self.calculate_data()

        if world.rank > 0:
            return

        atom_projections = atom_projections_to_numpy(self.voronoi.atom_projections)
        np.savez(out_fname, **self.common_arrays, **result._data,  # type: ignore
                 atom_projections=atom_projections)
        self.calc.log_parallel(f'Written {out_fname}', flush=True)

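    # A hypothetical way to inspect a file written by calculate_and_save_npz
    # (the file name is illustrative; the stored keys come from common_arrays,
    # the collected result keys and 'atom_projections'):
    #
    #     import numpy as np
    #     data = np.load('results.npz')
    #     print(data.files)
    #     print(data['eig_n'], data['atom_projections'])
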

    def calculate_and_save_ulm(self,
                               out_fname: str,
                               write_extra: dict[str, Any] = dict()):
        """ Calculate results on all ranks and save to a ULM file.

        Parameters
        ----------
        out_fname
            File name.
        """
        self.collector.empty_results()

        with GPAWWriter(out_fname, world, mode='w', tag=self._ulm_tag[:16]) as writer:
            writer.write(version=1)
            writer.write('atom_projections', self.voronoi.atom_projections)
            writer.write(**(self.common_arrays if world.rank == 0 else dict()))

            self.write_empty_arrays_ulm(writer)

            for work, res in self.calc.icalculate_gather_on_root(**self.icalculate_kwargs):
                self.fill_ulm(writer, work, res)
                self.collector.accumulate_results(work, res)

            self.collector.finalize_results()
            writer.write(**self.collector.result._data)

        if world.rank == 0:
            self.calc.log_parallel(f'Written {out_fname}', flush=True)
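

# A minimal usage sketch (hypothetical; assumes an already constructed
# BaseObservableCalculator instance `calc` whose density matrices are
# ConvolutionDensityMatrices with a single pulse):
#
#     collector = TimeResultsCollector(calc, calc_kwargs={})
#     Writer(collector).calculate_and_save_npz('observables.npz')
#
# The collector gathers one result slice per time on the root rank, and the
# Writer saves them to disk together with the eigenvalue arrays returned by
# common_arrays.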