Coverage for rhodent/density_matrices/readers/numpy.py: 87%

171 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3from typing import Collection 

4import numpy as np 

5from numpy.typing import ArrayLike, NDArray 

6 

7from gpaw.mpi import world 

8from gpaw.lcaotddft.ksdecomposition import KohnShamDecomposition 

9 

10from ...perturbation import Perturbation, PerturbationLike, create_perturbation 

11from ...typing import Array1D, Communicator 

12from ...utils import Logger, add_fake_kpts, find_files, partial_format, get_gaussian_pulse_values, filter_array 

13from ...utils.logging import format_times, format_frequencies 

14 

15 

class TimeDensityMatrixReader:

    """ Finds density matrices in the time domain saved to disk and reads them.

    Parameters
    ----------
    pulserho_fmt
        Formatting string for the density matrices saved to disk.

        The formatting string should be a plain string containing variable
        placeholders within curly brackets ``{}``. It should not be confused with
        a formatted string literal (f-string).

        Example:

        * pulserho_fmt = ``pulserho/t{time:09.1f}{tag}.npy``.

        Accepts variables

        * ``{time}`` - Time in units of as.
        * ``{tag}`` - Derivative tag, ``''``, ``'-Iomega'``, or ``'-omega2'``.
        * ``{pulsefreq}`` - Pulse frequency in units of eV.
        * ``{pulsefwhm}`` - Pulse FWHM in units of fs.
    ksd
        KohnShamDecomposition object or file name to the ksd file.
    filter_times
        Look for these times (or as close to them as possible). In units of as.
    pulses
        Density matrices in response to these pulses. By default, no information about the
        pulse.
    derivative_order_s
        List of derivative orders. Each order is used as an index into the
        derivative tags ``''``, ``'-Iomega'``, ``'-omega2'``.
    log
        Logger object.
    comm
        MPI communicator. By default the GPAW world communicator is used.
    """

    # File name tag for each derivative order; shared by __init__ and read
    _TAG_S: tuple[str, ...] = ('', '-Iomega', '-omega2')

    def __init__(self,
                 pulserho_fmt: str,
                 ksd: str | KohnShamDecomposition,
                 filter_times: Array1D[np.float64] | list[float] | None = None,
                 pulses: Collection[PerturbationLike] = [None],
                 derivative_order_s: list[int] = [0],
                 log: Logger | None = None,
                 comm: Communicator | None = None):
        if log is None:
            log = Logger()
        self._log = log
        self._comm = world if comm is None else comm

        if isinstance(ksd, KohnShamDecomposition):
            self._ksd = ksd
        else:
            self._ksd = KohnShamDecomposition(filename=ksd)
            # NOTE(review): assumed to apply only to a KSD freshly loaded from file
            # (indentation was lost in the extracted source) - confirm
            add_fake_kpts(self._ksd)

        self._pulses = [create_perturbation(pulse) for pulse in pulses]
        # Store a copy, so that the instance never aliases the shared default list
        # (or a list that the caller keeps mutating)
        self.derivative_order_s = list(derivative_order_s)
        self.pulserho_fmt = pulserho_fmt

        # Only the root rank looks for files; the other ranks pass None
        nested_times: list[Array1D[np.float64]] | None = None
        if self.comm.rank == 0:
            nested_times = []
            for pulse in self.pulses:
                for derivative in self.derivative_order_s:
                    # Partially format the format string, i.e. fill out the pulsefreq,
                    # pulsefwhm, and tag fields
                    tag = self._TAG_S[derivative]
                    fmt = partial_format(pulserho_fmt, tag=tag, **get_gaussian_pulse_values(pulse))

                    # Search the file tree for files
                    f = find_files(fmt, expected_keys=['time'])
                    nested_times.append(f['time'])

        # Times common to all pulses/derivatives, and times available for some only
        self._time_t, self._part_time_t = extract_common(nested_times,
                                                         filter_times,
                                                         self.comm)

    @property
    def ksd(self) -> KohnShamDecomposition:
        """ Kohn-Sham decomposition object. """
        return self._ksd

    @property
    def comm(self) -> Communicator:
        """ MPI communicator. """
        return self._comm

    @property
    def log(self) -> Logger:
        """ Logger object. """
        return self._log

    @property
    def times(self) -> Array1D[np.float64]:
        """ Simulation time in units of as. """
        return self._time_t

    @property
    def nt(self) -> int:
        """ Number of times. """
        return len(self.times)

    @property
    def pulses(self) -> list[Perturbation]:
        """ Pulses with which density matrices are convoluted. """
        return self._pulses

    def __str__(self) -> str:
        lines = ['Response from density matrices on disk']

        lines.append('')
        lines.append(f'Format string: {self.pulserho_fmt}')
        lines.append(f'Calculating response for {self.nt} times and {len(self.pulses)} pulses')
        lines.append(f' times: {format_times(self.times)}')
        npartt = len(self._part_time_t)
        if npartt > 0:
            lines.append(f'Additionally {npartt} times are available for some '
                         'pulses/derivatives only')
            lines.append(f' times: {format_times(self._part_time_t)}')

        return '\n'.join(lines)

    def read(self,
             time: float,
             pulse: Perturbation,
             derivative: int) -> NDArray[np.complex128]:
        r""" Read single density matrix from disk.

        Parameters
        ----------
        time
            Simulation time in units of as.
        pulse
            Pulse which this density matrix is in response to.
        derivative
            Read derivative of this order.

        Returns
        -------
        Density matrix :math:`\rho_ia`.
        """
        fname_kw = dict(time=time, tag=self._TAG_S[derivative],
                        **get_gaussian_pulse_values(pulse))
        fname = self.pulserho_fmt.format(**fname_kw)

        rho = read_numpy(fname)
        if len(rho.shape) == 1:
            # Transform from ravelled form
            return self.ksd.M_ia_from_M_p(rho)

        # Already in non-ravelled form
        return rho

174 

175 

class FrequencyDensityMatrixReader:

    """ Finds density matrices in the frequency domain saved to disk and reads them.

    Parameters
    ----------
    frho_fmt
        Formatting string for the density matrices
        in frequency space saved to disk.

        The formatting string should be a plain string containing variable
        placeholders within curly brackets ``{}``. It should not be confused with
        a formatted string literal (f-string).

        Example:

        * frho_fmt = ``frho/w{freq:05.2f}-{reim}.npy``.

        Accepts variables:

        * ``{freq}`` - Frequency in units of eV.
        * ``{reim}`` - ``'Re'`` or ``'Im'`` for Fourier transform of real/imaginary
          part of density matrix.
    ksd
        KohnShamDecomposition object or file name to the ksd file.
    filter_frequencies
        Look for these frequencies (or as close to them as possible). In units of eV.
    real
        Look for files tagged ``'Re'``.
    imag
        Look for files tagged ``'Im'``.
    log
        Logger object.
    comm
        MPI communicator. By default the GPAW world communicator is used.

    Raises
    ------
    ValueError
        If both ``real`` and ``imag`` are false.
    """

    def __init__(self,
                 frho_fmt: str,
                 ksd: str | KohnShamDecomposition,
                 filter_frequencies: ArrayLike | None = None,
                 real: bool = True,
                 imag: bool = True,
                 log: Logger | None = None,
                 comm: Communicator | None = None):
        # Validate the cheap arguments first, before any file is read
        if not real and not imag:
            raise ValueError('At least one of real or imag must be true')

        if log is None:
            log = Logger()
        self._log = log
        if isinstance(ksd, KohnShamDecomposition):
            self._ksd = ksd
        else:
            self._ksd = KohnShamDecomposition(filename=ksd)
            # NOTE(review): assumed to apply only to a KSD freshly loaded from file
            # (indentation was lost in the extracted source) - confirm
            add_fake_kpts(self._ksd)

        self._comm = world if comm is None else comm
        self.frho_fmt = frho_fmt
        reim_r = [reim for reim, wanted in (('Re', real), ('Im', imag)) if wanted]

        # Only the root rank looks for files; the other ranks pass None
        nested_freqs: list[Array1D[np.float64]] | None = None
        if self.comm.rank == 0:
            nested_freqs = []
            for reim in reim_r:
                # Partially format the format string, i.e. fill out the reim field
                fmt = partial_format(frho_fmt, reim=reim)

                # Search the file tree for files
                f = find_files(fmt, expected_keys=['freq'])
                nested_freqs.append(f['freq'])

        # Frequencies common to all requested parts, and frequencies available for some only
        self._freq_w, self._part_freq_w = extract_common(nested_freqs,
                                                         filter_frequencies,
                                                         self.comm)

    @property
    def ksd(self) -> KohnShamDecomposition:
        """ Kohn-Sham decomposition object. """
        return self._ksd

    @property
    def comm(self) -> Communicator:
        """ MPI communicator. """
        return self._comm

    @property
    def log(self) -> Logger:
        """ Logger object. """
        return self._log

    @property
    def frequencies(self) -> Array1D[np.float64]:
        """ Frequencies in units of eV. """
        return self._freq_w

    @property
    def nw(self) -> int:
        """ Number of frequencies. """
        return len(self.frequencies)

    def __str__(self) -> str:
        lines = ['Response from Fourier transform of density matrices on disk']

        lines.append('')
        lines.append(f'Format string: {self.frho_fmt}')
        lines.append(f'Calculating response for {self.nw} frequencies')
        lines.append(f' frequencies: {format_frequencies(self.frequencies)}')
        npartw = len(self._part_freq_w)
        if npartw > 0:
            lines.append(f'Additionally {npartw} frequencies are available for real '
                         'or imaginary parts but not both')
            lines.append(f' frequencies: {format_frequencies(self._part_freq_w)}')

        return '\n'.join(lines)

    def read(self,
             frequency: float,
             real: bool) -> NDArray[np.complex128]:
        r""" Read single density matrix from disk.

        Parameters
        ----------
        frequency
            Frequency in units of eV.
        real
            If true, read the file tagged ``'Re'``, otherwise the file tagged ``'Im'``.

        Returns
        -------
        Density matrix :math:`\rho_ia`.
        """
        fname_kw = dict(freq=frequency, reim='Re' if real else 'Im')
        fname = self.frho_fmt.format(**fname_kw)

        rho = read_numpy(fname)

        if len(rho.shape) != 1:
            # Already in non-ravelled form
            return rho

        if self.ksd.only_ia:
            # Twice the rho is saved by the KohnShamDecomposition transform.
            # Out-of-place division, so that integer-typed data does not raise
            rho = rho / 2

        # Transform from ravelled form
        return self.ksd.M_ia_from_M_p(rho)

307 

308 

def read_numpy(fname: str) -> NDArray[np.complex128]:
    r""" Read density matrix from numpy binary file or archive.

    Supports data stored in non-ravelled form (preferred; indices :math:`ia`
    for electron-hole pairs) and in ravelled form (legacy; single index :math:`p`
    for electron-hole pairs).

    Parameters
    ----------
    fname
        File name.

    Returns
    -------
    Density matrix :math:`\rho_ia` or :math:`\rho_p`.

    Raises
    ------
    RuntimeError
        If the array does not have the expected number of dimensions, or if
        an archive contains neither ``rho_p`` nor ``rho_ia``.
    """

    f = np.load(fname)

    if not isinstance(f, np.lib.npyio.NpzFile):
        # Read npy file
        assert isinstance(f, np.ndarray)
        if len(f.shape) not in [1, 2]:
            raise RuntimeError(f'Expected 1D or 2D array, got shape {f.shape}.')
        return f

    # Read npz file. The context manager closes the archive also when
    # one of the checks below raises
    with f:
        if 'rho_p' in f.files:
            rho = f['rho_p']
            if len(rho.shape) != 1:
                raise RuntimeError(f'Expected 1D array, got shape {rho.shape}.')
        elif 'rho_ia' in f.files:
            rho = f['rho_ia']
            if len(rho.shape) != 2:
                raise RuntimeError(f'Expected 2D array, got shape {rho.shape}.')
        else:
            raise RuntimeError("Expected file 'rho_p', or 'rho_ia' in file")

    return rho

348 

349 

def extract_common(nested_values: list[Array1D[np.float64]] | None,
                   filter_values: ArrayLike | None,
                   comm: Communicator) -> tuple[Array1D[np.float64], Array1D[np.float64]]:
    """ From a list of arrays, extract array elements that are present in all arrays.

    Parameters
    ----------
    nested_values
        List of arrays on root rank, ``None`` on other ranks.
    filter_values
        Filter the values, keeping only the values closest to these.
    comm
        MPI communicator.

    Returns
    -------
    Tuple of values present in all arrays, and values present in some
    (at least one, but not all) arrays. Both are sorted and broadcast to all ranks.
    """
    if comm.rank == 0:
        assert nested_values is not None
        values_all: set[float] = set()
        values_some: set[float] = set()

        # An empty list has no values at all. Handle it here, instead of failing
        # with an IndexError on the root rank before the broadcast below
        if nested_values:
            # Values present in any array
            values_any: set[float] = set(nested_values[0])
            # Values present in all arrays
            values_all = set(values_any)

            for values in nested_values[1:]:
                values_any |= set(values)
                values_all &= set(values)

            # Filter the values
            values_any = set(filter_array(sorted(values_any), filter_values))  # type: ignore
            values_all = set(filter_array(sorted(values_all), filter_values))  # type: ignore

            # Values present in some arrays only
            values_some = values_any - values_all

        # The array sizes, to be broadcast to all ranks
        shapes = np.array([len(values_all), len(values_some)], dtype=int)
    else:
        assert nested_values is None
        shapes = np.array([0, 0], dtype=int)

    # The other ranks need the sizes in order to allocate the receive buffers
    comm.broadcast(shapes, 0)

    # Values in all arrays: as array
    array_all = np.zeros(shapes[0], dtype=float)
    # Values in some arrays: as array
    array_some = np.zeros(shapes[1], dtype=float)

    if comm.rank == 0:
        array_all[:] = sorted(values_all)
        array_some[:] = sorted(values_some)

    comm.broadcast(array_all, 0)
    if array_some.size > 0:
        comm.broadcast(array_some, 0)

    return array_all, array_some