Coverage for rhodent/calculators/density.py: 83%

191 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3from pathlib import Path 

4import numpy as np 

5from numpy.typing import NDArray 

6from typing import Any, Generator, Sequence, Collection 

7 

8from ase.units import Bohr 

9from gpaw import GPAW 

10from gpaw.grid_descriptor import GridDescriptor 

11from gpaw.lcaotddft.densitymatrix import get_density 

12 

13from .base import BaseObservableCalculator 

14from ..typing import GPAWCalculator 

15from ..density_matrices.base import WorkMetadata 

16from ..density_matrices.frequency import FrequencyDensityMatrixMetadata 

17from ..density_matrices.time import ConvolutionDensityMatrixMetadata 

18from ..perturbation import PerturbationLike 

19from ..response import BaseResponse 

20from ..typing import ArrayIsOnRootRank, Array1D, DistributedArray 

21from ..utils import ResultKeys, Result, get_gaussian_pulse_values, ParallelMatrix 

22 

23 

24class DensityCalculator(BaseObservableCalculator): 

25 

26 r""" Calculate densities in the time or frequency domain. 

27 

28 The induced density (i.e. the density minus the ground state density) is to first 

29 order given by 

30 

31 .. math:: 

32 

33 \delta n(\boldsymbol{r}) = -2 \sum_{ia}^\text{eh} 

34 n_{ia}(\boldsymbol{r}) \mathrm{Re}\:\delta\rho_{ia} 

35 

36 plus PAW corrections, where :math:`n_{ia}(\boldsymbol{r})` is the density of 

37 ground state Kohn-Sham pair :math:`ia` 

38 

39 .. math:: 

40 

41 n_{ia}(\boldsymbol{r}) = \psi^{(0)}_i(\boldsymbol{r}) \psi^{(0)}_a(\boldsymbol{r}). 

42 

43 In the time domain, electrons and holes densities can be computed. 

44 

45 .. math:: 

46 

47 \begin{align} 

48 n^\text{holes}(\boldsymbol{r}) &= \sum_{ii'} 

49 n_{ii'}(\boldsymbol{r}) \delta\rho_{ii'} \\ 

50 n^\text{electrons}(\boldsymbol{r}) &= \sum_{aa'} 

51 n_{aa'}(\boldsymbol{r}) \delta\rho_{aa'}. 

52 \end{align} 

53 

54 Refer to the documentation of 

55 :class:`HotCarriersCalculator <rhodent.calculators.HotCarriersCalculator>` for definitions 

56 of :math:`\delta\rho_{ii'}` and :math:`\delta\rho_{aa'}`. 

57 

58 Parameters 

59 ---------- 

60 gpw_file 

61 File name of GPAW ground state file. 

62 response 

63 Response object. 

64 filter_occ 

65 Filters for occupied states (holes). Provide a list of tuples (low, high) 

66 to compute the density of holes with energies within the interval low-high. 

67 filter_unocc 

68 Filters for unoccupied states (electrons). Provide a list of tuples (low, high) 

69 to compute the density of excited electrons with energies within the interval low-high. 

70 times 

71 Compute densities in the time domain, for these times (or as close to them as possible). 

72 In units of as. 

73 

74 May not be used together with :attr:`frequencies` or :attr:`frequency_broadening`. 

75 pulses 

76 Compute densities in the time domain, in response to these pulses. 

77 If none, then no pulse convolution is performed. 

78 

79 May not be used together with :attr:`frequencies` or :attr:`frequency_broadening`. 

80 frequencies 

81 Compute densities in the frequency domain, for these frequencies. In units of eV. 

82 

83 May not be used together with :attr:`times` or :attr:`pulses`. 

84 frequency_broadening 

85 Compute densities in the frequency domain, with Gaussian broadening of this width. 

86 In units of eV. 

87 

88 May not be used together with :attr:`times` or :attr:`pulses`. 

89 """ 

90 

91 def __init__(self, 

92 gpw_file: str, 

93 response: BaseResponse, 

94 filter_occ: Sequence[tuple[float, float]] = [], 

95 filter_unocc: Sequence[tuple[float, float]] = [], 

96 *, 

97 times: list[float] | Array1D[np.float64] | None = None, 

98 pulses: Collection[PerturbationLike] | None = None, 

99 frequencies: list[float] | Array1D[np.float64] | None = None, 

100 frequency_broadening: float = 0, 

101 ): 

102 super().__init__(response=response, 

103 times=times, 

104 pulses=pulses, 

105 frequencies=frequencies, 

106 frequency_broadening=frequency_broadening) 

107 self._occ_filters = [self._build_single_filter('o', low, high) for low, high in filter_occ] 

108 self._unocc_filters = [self._build_single_filter('u', low, high) for low, high in filter_unocc] 

109 

110 self.log.start('load_gpaw') 

111 self._calc = GPAW(gpw_file, txt=None, communicator=self.calc_comm, 

112 parallel={'domain': self.calc_comm.size}) 

113 msg = f'Loaded/initialized GPAW in {self.log.elapsed("load_gpaw"):.1f}' 

114 self.log.start('init_gpaw') 

115 

116 self.calc.initialize_positions() # Initialize in order to calculate density 

117 msg += f'/{self.log.elapsed("init_gpaw"):.1f} s' 

118 if self.calc_comm.rank == 0: 

119 self.log_parallel(msg) 

120 self.ksd.density = self.calc.density 

121 

@property
def gdshape(self) -> tuple[int, int, int]:
    """ Shape of the real space grid.

    Every entry is the corresponding entry of :attr:`N_c` reduced by one.
    """
    npoints_c = [int(npoints) - 1 for npoints in self.N_c]
    return tuple(npoints_c)  # type: ignore

128 

@property
def gd(self) -> GridDescriptor:
    """ Real space grid, taken from the ``finegd`` attribute of the density. """
    density = self.ksd.density
    return density.finegd

133 

@property
def N_c(self) -> NDArray[np.int_]:
    """ Number of points in each Cartesian direction of the grid.

    Read from the grid returned by :attr:`gd`.
    """
    grid = self.gd
    return grid.N_c

139 

@property
def cell_cv(self) -> NDArray[np.float64]:
    """ Cell vectors of the grid, multiplied by the ASE constant ``Bohr``. """
    # presumably converts from atomic units to Angstrom -- TODO confirm
    grid_cell_cv = self.gd.cell_cv
    return grid_cell_cv * Bohr

144 

@property
def occ_filters(self) -> list[slice]:
    """ Energy filters for the occupied states, one slice per filter. """
    filters = self._occ_filters
    return filters

149 

@property
def unocc_filters(self) -> list[slice]:
    """ Energy filters for the unoccupied states, one slice per filter. """
    filters = self._unocc_filters
    return filters

154 

@property
def calc(self) -> GPAWCalculator:
    """ The GPAW calculator instance held by this object. """
    gpaw_calc = self._calc
    return gpaw_calc  # type: ignore

159 

def get_result_keys(self,
                    yield_total: bool = True,
                    yield_electrons: bool = False,
                    yield_holes: bool = False) -> ResultKeys:
    """ Describe the keys and array shapes of the results for the given options.

    Parameters
    ----------
    yield_total
        Include the key `rho_g`.
    yield_electrons
        Include the key `unocc_rho_g`, and, if there are any unoccupied filters,
        the keys `unocc_rho_rows_fg` and `unocc_rho_diag_fg`.
    yield_holes
        Include the key `occ_rho_g`, and, if there are any occupied filters,
        the keys `occ_rho_rows_fg` and `occ_rho_diag_fg`.

    Returns
    -------
    Object holding the keys together with their shapes.

    Raises
    ------
    ValueError
        If electron or hole densities are requested outside the time domain.
    """
    n_occ_filters = len(self.occ_filters)
    n_unocc_filters = len(self.unocc_filters)

    carriers_requested = yield_electrons or yield_holes
    if carriers_requested and not self._is_time_density_matrices:
        raise ValueError('Electron or hole densities can only be computed in the time domain.')

    keys = ResultKeys()
    if yield_total:
        keys.add_key('rho_g', self.gdshape)

    # Holes are registered before electrons; each row is
    # (requested, number of filters, total key, rows key, diagonal key)
    carriers = ((yield_holes, n_occ_filters,
                 'occ_rho_g', 'occ_rho_rows_fg', 'occ_rho_diag_fg'),
                (yield_electrons, n_unocc_filters,
                 'unocc_rho_g', 'unocc_rho_rows_fg', 'unocc_rho_diag_fg'))
    for requested, nfilters, total_key, rows_key, diag_key in carriers:
        if not requested:
            continue
        keys.add_key(total_key, self.gdshape)
        if nfilters > 0:
            # Filtered densities carry a leading filter axis
            filtered_shape = (nfilters, ) + self.gdshape
            keys.add_key(rows_key, filtered_shape)
            keys.add_key(diag_key, filtered_shape)

    return keys

186 

187 @property 

188 def _need_derivatives_real_imag(self) -> tuple[list[int], bool, bool]: 

189 # Time domain: We only need the real part of the density matrix. 

190 # Frequency domain: We need the (complex) Fourier transform of 

191 # the real part of the density matrix. 

192 return ([0], True, False) 

193 

194 def _find_limit(self, 

195 lim: float) -> int: 

196 """ Find the first eigenvalue larger than :attr:`lim`. 

197 

198 Parameters 

199 ---------- 

200 lim 

201 Threshold value in units of eV. 

202 

203 Returns 

204 ------- 

205 The index of the first eigenvalue larger than :attr:`lim`. 

206 Returns `len(eig_n)` if :attr:`lim` is larger than all eigenvalues. 

207 """ 

208 if lim > self.eig_n[-1]: 

209 return len(self.eig_n) 

210 return int(np.argmax(self.eig_n > lim)) 

211 

212 def _build_single_filter(self, 

213 key: str, 

214 low: float, 

215 high: float) -> slice: 

216 imin, imax, amin, amax = self.ksd.ialims() 

217 

218 if key == 'o': 

219 nlow = min(self._find_limit(low), imax) - imin 

220 nhigh = min(self._find_limit(high), imax) - imin 

221 elif key == 'u': 

222 nlow = min(self._find_limit(low), amax) - amin 

223 nhigh = min(self._find_limit(high), amax) - amin 

224 else: 

225 raise RuntimeError(f'Unknown key {key}. Key must be "o" or "u"') 

226 return slice(nlow, nhigh) 

227 

def get_density(self,
                rho_nn: DistributedArray,
                nn_indices: str,
                fltn1: slice | NDArray[np.bool_] = slice(None),
                fltn2: slice | NDArray[np.bool_] = slice(None),
                u: int = 0) -> DistributedArray:
    r""" Calculate a real space density from a density matrix in the Kohn-Sham basis.

    The density matrix is transformed to the LCAO basis using the ground state
    coefficients, symmetrized, and passed to GPAW's ``get_density``.

    Parameters
    ----------
    rho_nn
        Density matrix :math:`\delta\rho_{ia}`, :math:`\delta\rho_{ii'}`, or
        :math:`\delta\rho_{aa'}`.
    nn_indices
        Indices describing the density matrices :attr:`rho_nn`. One of

        - `ia` for induced density :math:`\delta\rho_{ia}`.
        - `ii` for holes density :math:`\delta\rho_{ii'}`.
        - `aa` for electrons density :math:`\delta\rho_{aa'}`.
    fltn1
        Filter selecting rows of the density matrix.
    fltn2
        Filter selecting columns of the density matrix.
    u
        k-point index.

    Returns
    -------
    Distributed array with the density in real space on the root rank.
    On the other ranks of the calculation communicator a placeholder
    ``ArrayIsOnRootRank`` object is returned.

    Raises
    ------
    ValueError
        If :attr:`nn_indices` is not one of `ia`, `ii` or `aa`.
    """
    imin, imax, amin, amax = self.ksd.ialims()
    if nn_indices not in ['ia', 'ii', 'aa']:
        raise ValueError(f'Parameter nn_indices must be either "ia", "ii" or "aa". Is {nn_indices}.')
    # The limits are used as inclusive indices, hence the + 1 for the slice stop.
    # n1 spans the states of the rows and n2 the states of the columns.
    n1 = slice(imin, imax + 1) if nn_indices[0] == 'i' else slice(amin, amax + 1)
    n2 = slice(imin, imax + 1) if nn_indices[1] == 'i' else slice(amin, amax + 1)

    # Number of states in the row and column blocks, and number of basis functions
    nn1, nn2 = n1.stop - n1.start, n2.stop - n2.start
    nM = self.ksd.C0_unM[0].shape[-1]

    # Only the root rank supplies data to the parallel matrices; the other
    # ranks construct them from the shape and dtype alone.
    # NOTE(review): the shapes are declared as the full (nn1, nn2), (nn1, nM) and
    # (nn2, nM) even when fltn1/fltn2 select a subset of the rows/columns --
    # confirm that ParallelMatrix accepts the smaller filtered arrays.
    if self.calc_comm.rank == 0:
        C0_nM = self.ksd.C0_unM[u]
        rho_n1n2 = ParallelMatrix((nn1, nn2), np.float64, comm=self.calc_comm,
                                  array=rho_nn[fltn1][:, fltn2])
        C0_n1M = ParallelMatrix((nn1, nM), np.float64, comm=self.calc_comm,
                                array=C0_nM[n1][fltn1])
        C0_n2M = ParallelMatrix((nn2, nM), np.float64, comm=self.calc_comm,
                                array=C0_nM[n2][fltn2])
    else:
        rho_n1n2 = ParallelMatrix((nn1, nn2), np.float64, comm=self.calc_comm)
        C0_n1M = ParallelMatrix((nn1, nM), np.float64, comm=self.calc_comm)
        C0_n2M = ParallelMatrix((nn2, nM), np.float64, comm=self.calc_comm)

    # Transform to LCAO basis C0_n1M.T @ rho_n1n2 @ C0_n2M
    self.log.start('transform_dm')

    # presumably broadcast() makes the product available on every rank -- TODO confirm
    rho_MM = (C0_n1M.T @ rho_n1n2 @ C0_n2M).broadcast()
    # assert np.issubdtype(rho_nn.dtype, float)
    # Symmetrize the density matrix in the LCAO basis
    rho_MM = 0.5 * (rho_MM + rho_MM.T)

    msg = f'Transformed DM and constructed density in {self.log.elapsed("transform_dm"):.1f}s'
    self.log.start('get_density')
    rho_g = get_density(rho_MM, self.calc.wfs, self.calc.density, u=u)
    msg += f'+{self.log.elapsed("get_density"):.1f}s'
    if self.calc_comm.rank == 0:
        self.log_parallel(msg, flush=True)

    # presumably gathers the domain-decomposed density on the root rank -- TODO confirm
    big_rho_g = self.gd.collect(rho_g)

    if self.calc_comm.rank == 0:
        return big_rho_g
    else:
        return ArrayIsOnRootRank()

299 

def icalculate(self,
               yield_total: bool = True,
               yield_electrons: bool = False,
               yield_holes: bool = False) -> Generator[tuple[WorkMetadata, Result], None, None]:
    """ Iteratively calculate results.

    Parameters
    ----------
    yield_total
        The results should include the total induced density (key `rho_g`).
    yield_electrons
        The results should include the electrons densities (key `unocc_rho_g`),
        optionally decomposed by `filter_unocc`.
    yield_holes
        The results should include the holes densities (key `occ_rho_g`),
        optionally decomposed by `filter_occ`.

    Yields
    ------
    Tuple (work, result), only where the rank of the density matrix object is 0.

    work
        An object representing the metadata (time, frequency or pulse) for the work done.
    result
        Object containing the calculation results for this time, frequency or pulse.

    Raises
    ------
    ValueError
        If electron or hole densities are requested outside the time domain.
    """
    n_occ_filters = len(self.occ_filters)
    n_unocc_filters = len(self.unocc_filters)

    carriers_requested = yield_electrons or yield_holes
    if carriers_requested and not self._is_time_density_matrices:
        raise ValueError('Electron or hole densities can only be computed in the time domain.')

    # Every density returned by get_density is scaled by this factor
    inverse_volume = Bohr ** -3

    # Iterate over the pulses and times, or frequencies
    for work, dm in self.density_matrices:
        # Time domain: the real part contributes to the density.
        # Frequency domain: the imaginary part gives the absorption contribution.
        rho_ia = dm.rho_ia.real if self._is_time_density_matrices else -dm.rho_ia.imag

        self.log.start('calculate')

        # Non-root ranks on calc_comm will write empty arrays to result, but will not be yielded
        result = Result()

        if yield_total:
            induced = self.get_density(rho_ia.real, 'ia')
            result['rho_g'] = induced * inverse_volume

        if yield_holes:
            # Density matrix of the holes (occupied-occupied block)
            hole_dm = 0.5 * (dm.Q_ia @ dm.Q_ia.T + dm.P_ia @ dm.P_ia.T)

            result['occ_rho_g'] = self.get_density(hole_dm, 'ii') * inverse_volume

            if n_occ_filters > 0:
                # Filter applied to the rows only, then to rows and columns
                rows = [self.get_density(hole_dm, 'ii', fltn1=flt)
                        for flt in self.occ_filters]
                result['occ_rho_rows_fg'] = np.array(rows) * inverse_volume
                diag = [self.get_density(hole_dm, 'ii', fltn1=flt, fltn2=flt)
                        for flt in self.occ_filters]
                result['occ_rho_diag_fg'] = np.array(diag) * inverse_volume

        if yield_electrons:
            # Density matrix of the electrons (unoccupied-unoccupied block)
            electron_dm = 0.5 * (dm.Q_ia.T @ dm.Q_ia + dm.P_ia.T @ dm.P_ia)

            result['unocc_rho_g'] = self.get_density(electron_dm, 'aa') * inverse_volume

            if n_unocc_filters > 0:
                rows = [self.get_density(electron_dm, 'aa', fltn1=flt)
                        for flt in self.unocc_filters]
                result['unocc_rho_rows_fg'] = np.array(rows) * inverse_volume
                diag = [self.get_density(electron_dm, 'aa', fltn1=flt, fltn2=flt)
                        for flt in self.unocc_filters]
                result['unocc_rho_diag_fg'] = np.array(diag) * inverse_volume

        # All ranks take part in the calculation above, but only rank 0 yields
        if dm.rank > 0:
            continue

        self.log_parallel(f'Calculated density in {self.log.elapsed("calculate"):.2f}s '
                          f'for {work.desc}', flush=True)

        yield work, result

    if self.calc_comm.rank == 0:
        self.log_parallel('Finished calculating density contributions', flush=True)

381 

def calculate_and_write(self,
                        out_fname: str,
                        which: str | list[str] = 'induced',
                        write_extra: dict[str, Any] | None = None):
    """ Calculate density contributions.

    Densities are saved in a numpy archive, ULM file or cube file depending on
    whether the file extension is `.npz`, `.ulm`, or `.cube`.

    If the file extension is `.cube` then :attr:`out_fname` is taken to be a formatting string.

    The formatting string should be a plain string containing variable
    placeholders within curly brackets `{}`. It should not be confused with
    a formatted string literal (f-string).

    It accepts the variables:

    * `{time}` - Time in units of as (time domain only).
    * `{freq}` - Frequency in units of eV (frequency domain only).
    * `{which}` - The :attr:`which` argument.
    * `{pulsefreq}` - Pulse frequency in units of eV (time domain only).
    * `{pulsefwhm}` - Pulse FWHM in units of fs (time domain only).

    Examples:

    * out_fname = `{which}_density_t{time:09.1f}.cube`.
    * out_fname = `{which}_density_w{freq:05.2f}.cube`.

    Parameters
    ----------
    out_fname
        File name of the resulting data file.
    which
        String, or list of strings specifying the types of density to compute:

        * `induced` - Induced density.
        * `holes` - Holes density (only allowed in the time domain).
        * `electrons` - Electrons density (only allowed in the time domain).
    write_extra
        Dictionary of extra key-value pairs to write to the data file.
        Defaults to an empty dictionary. Not used when writing cube files.

    Raises
    ------
    ValueError
        If an entry of :attr:`which` is not recognized or not allowed in the
        frequency domain, or if the file extension is not supported. The
        arguments are checked before any calculation is started.
    """
    if isinstance(which, str):
        which = [which]

    for which_key in which:
        if which_key in ['holes', 'electrons'] and not self._is_time_density_matrices:
            raise ValueError(f'Option which={which_key} not allowed in the frequency domain.')
        if which_key not in ['induced', 'holes', 'electrons']:
            # Report the offending entry rather than the whole list
            raise ValueError(f'Option which={which_key} not recognized. '
                             'Must be one of: induced, holes, electrons')

    out_fname = str(out_fname)
    if not out_fname.endswith(('.npz', '.ulm', '.cube')):
        raise ValueError(f'output-file must have ending .npz, .ulm or .cube, is {out_fname}')

    # A fresh dictionary per call; a mutable default would be shared between calls
    if write_extra is None:
        write_extra = {}

    from ..writers.density import DensityWriter, write_density
    from ..writers.writer import FrequencyResultsCollector, TimeResultsCollector

    calc_kwargs = dict(yield_total='induced' in which,
                       yield_holes='holes' in which,
                       yield_electrons='electrons' in which)

    if out_fname.endswith(('.npz', '.ulm')):
        cls = TimeResultsCollector if self._is_time_density_matrices else FrequencyResultsCollector
        # The densities decomposed by the energy filters are never collected
        exclude = ['occ_rho_rows_fg', 'occ_rho_diag_fg', 'unocc_rho_rows_fg', 'unocc_rho_diag_fg']
        if out_fname.endswith('.npz'):
            writer = DensityWriter(cls(self, calc_kwargs=calc_kwargs, exclude=exclude))
            writer.calculate_and_save_npz(out_fname=out_fname, write_extra=write_extra)
        else:
            # NOTE(review): the ULM branch additionally excludes 'rho_g' from the
            # collector -- presumably the writer stores it separately; confirm
            exclude = ['rho_g'] + exclude
            writer = DensityWriter(cls(self, calc_kwargs=calc_kwargs, exclude=exclude))
            writer.calculate_and_save_ulm(out_fname=out_fname, write_extra=write_extra)
        return

    # Cube files: one file per yielded work item and requested density type
    keys = {'induced': 'rho_g',
            'holes': 'occ_rho_g',
            'electrons': 'unocc_rho_g'}
    atoms = self.calc.atoms
    for work, res in self.icalculate(**calc_kwargs):
        if self.calc_comm.rank > 0:
            continue
        for which_key in which:
            data = res[keys[which_key]]
            fname_kw: dict[str, float | str] = dict(which=which_key)
            if self._is_time_density_matrices:
                assert isinstance(work, ConvolutionDensityMatrixMetadata)
                fname_kw.update(time=work.time, **get_gaussian_pulse_values(work.pulse))
            else:
                assert isinstance(work, FrequencyDensityMatrixMetadata)
                fname_kw.update(freq=work.freq)
            fpath = Path(out_fname.format(**fname_kw))
            # Create the directory of the formatted file name if needed
            fpath.parent.mkdir(parents=True, exist_ok=True)
            write_density(str(fpath), atoms, data)
            self.log_parallel(f'Written {fpath}', flush=True)