Coverage for rhodent/density_matrices/distributed/frequency.py: 98%

168 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3from typing import Generator 

4import numpy as np 

5 

6from gpaw.tddft.units import au_to_eV 

7 

8from .base import BaseDistributor, RhoParameters 

9from .time import TimeDistributor, AlltoallvTimeDistributor 

10from ..buffer import DensityMatrixBuffer 

11from ..readers.gpaw import KohnShamRhoWfsReader 

12from ...utils import get_array_filter, safe_fill_larger, fast_pad 

13from ...utils.logging import format_frequencies 

14from ...utils.memory import MemoryEstimate 

15from ...perturbation import create_perturbation, PerturbationLike 

16from ...typing import Array1D 

17 

18 

class FourierTransformer(BaseDistributor):

    """ Iteratively take the Fourier transform of density matrices.

    Parameters
    ----------
    rho_nn_reader
        Object that can iteratively read density matrices in the time domain,
        distributed such that different ranks have different chunks of the density
        matrix, but each rank has all times for the same chunk.
    perturbation
        The perturbation which the density matrices are a response to.
    filter_frequencies
        After Fourier transformation keep only these frequencies (or the frequencies
        closest to them). In atomic units.
    frequency_broadening
        Gaussian broadening width in atomic units. Default (0) is no broadening.
    result_on_ranks
        List of ranks among which the resulting arrays will be distributed.
        Empty list or ``None`` (default) to distribute among all ranks.
    """

    def __init__(self,
                 rho_nn_reader: TimeDistributor,
                 perturbation: PerturbationLike,
                 filter_frequencies: list[float] | Array1D[np.float64] | None = None,
                 frequency_broadening: float = 0,
                 result_on_ranks: list[int] | None = None):
        super().__init__(rho_nn_reader.rho_wfs_reader,
                         rho_nn_reader._parameters,
                         comm=rho_nn_reader.comm)
        self.rho_nn_reader = rho_nn_reader
        self.perturbation = create_perturbation(perturbation)
        self.frequency_broadening = frequency_broadening
        # _omega_w reads self.rho_nn_reader, so it must be assigned before this line
        self._flt_w = get_array_filter(self._omega_w, filter_frequencies)

        if result_on_ranks is None or len(result_on_ranks) == 0:
            self._result_on_ranks = set(range(self.comm.size))
        else:
            assert all(isinstance(rank, int) and 0 <= rank < self.comm.size
                       for rank in result_on_ranks), result_on_ranks
            self._result_on_ranks = set(result_on_ranks)

        # Filled lazily by the dist_buffer property
        self._dist_buffer: DensityMatrixBuffer | None = None

    @property
    def dtype(self):
        """ Data type of the Fourier-transformed density matrices """
        return np.complex128

    @property
    def xshape(self):
        """ Shape of the non-matrix dimensions: (number of kept frequencies, ) """
        return (self.nw, )

    @property
    def freq_w(self) -> Array1D[np.float64]:
        """ Kept frequencies, i.e. the frequency grid after applying the filter """
        return self._omega_w[self.flt_w]  # type: ignore

    @property
    def _omega_w(self) -> Array1D[np.float64]:
        """ Full grid of angular frequencies of the real FFT of the zero-padded time series """
        padnt = fast_pad(self.rho_nn_reader.nt)
        dt = self.rho_nn_reader.dt
        omega_w = 2 * np.pi * np.fft.rfftfreq(padnt, dt)

        return omega_w  # type: ignore

    @property
    def nw(self) -> int:
        """ Number of kept frequencies """
        return len(self.freq_w)

    @property
    def nlocalw(self) -> int:
        """ Maximum number of frequencies stored on any result rank (ceiling division) """
        return (self.nw + self.nranks_result - 1) // self.nranks_result

    @property
    def flt_w(self) -> slice | Array1D[np.bool_]:
        """ Filter (slice or boolean mask) selecting the kept frequencies from _omega_w """
        return self._flt_w

    @property
    def result_on_ranks(self) -> list[int]:
        """ Sorted list of ranks among which the result will be distributed """
        return sorted(self._result_on_ranks)

    @property
    def nranks_result(self) -> int:
        """ Number of ranks that the resulting arrays will be distributed among """
        return len(self._result_on_ranks)

    def distributed_work(self) -> list[list[int]]:
        """ Frequency indices of the result on each rank

        Frequencies are dealt round-robin among the ranks in result_on_ranks.
        Ranks that are not in result_on_ranks get an empty list.

        Returns
        -------
        List with one element per rank of the communicator, where element r is the
        list of frequency indices (into freq_w) stored on rank r.
        """
        nw = self.nw
        nranks = self.nranks_result
        # One independent list per rank (`size * [[]]` would alias a single list object)
        freqw_r: list[list[int]] = [[] for _ in range(self.comm.size)]
        for r, rank in enumerate(self.result_on_ranks):
            freqw_r[rank] = list(range(r, nw, nranks))

        return freqw_r

    def my_work(self) -> list[int]:
        """ Frequency indices of the result stored on this rank """
        freqw_r = self.distributed_work()
        return freqw_r[self.comm.rank]

    def __str__(self) -> str:
        nt = len(self.rho_nn_reader.time_t)
        niters = len(list(self.work_loop_by_ranks()))

        lines = []
        lines.append('Fourier transformer')
        lines.append(f' Calculating Fourier transform on {self.maxnchunks} ranks')
        lines.append(' Fast Fourier transform')
        lines.append(f' matrix dimensions {self.rho_nn_reader._parameters.nnshape}')
        lines.append(f' grid of {nt} times')
        lines.append(f' {self.describe_reim()}')
        if self.frequency_broadening == 0:
            lines.append(' No frequency broadening')
        else:
            lines.append(f' Applying frequency broadening of {self.frequency_broadening * au_to_eV:.2f}eV')
        lines.append(f' keeping frequency grid of {self.nw} frequencies')
        lines.append(f' {format_frequencies(self.freq_w, units="au")}')
        lines.append('')

        lines.append(' Redistributing into full density matrices')
        lines.append(f' {niters} iterations to process all chunks')
        lines.append(f' matrix dimensions {self.rho_nn_reader._parameters.full_nnshape}')
        lines.append(f' result stored on {self.nranks_result} ranks')

        return '\n'.join(lines)

    def get_memory_estimate(self) -> MemoryEstimate:
        """ Estimate the memory of the readers and of this transformer's own buffers

        Returns
        -------
        Memory estimate with children for the wave functions reader, the parallel
        density matrices reader and the temporary/result buffers of this object.
        """
        parameters = self.rho_nn_reader._parameters

        narrays = 2 if self.yield_re and self.yield_im else 1
        temp_shape = parameters.nnshape + (self.maxnchunks, self.nlocalw, narrays)
        result_shape = parameters.full_nnshape + (self.nlocalw, narrays)

        total_result_size = int(np.prod(parameters.full_nnshape)) * self.nw * narrays

        comment = f'Buffers hold {narrays} arrays ({self.describe_reim()})'
        own_memory_estimate = MemoryEstimate(comment=comment)
        own_memory_estimate.add_key('Temporary buffer', temp_shape, complex,
                                    on_num_ranks=self.nranks_result)
        own_memory_estimate.add_key('Result buffer', result_shape, complex,
                                    total_size=total_result_size,
                                    on_num_ranks=self.nranks_result)

        memory_estimate = MemoryEstimate()
        memory_estimate.add_child('Time-dependent wave functions reader',
                                  self.rho_nn_reader.rho_wfs_reader.get_memory_estimate())
        memory_estimate.add_child('Parallel density matrices reader',
                                  self.rho_nn_reader.get_memory_estimate())
        memory_estimate.add_child('Fourier transformer',
                                  own_memory_estimate)

        return memory_estimate

    def __iter__(self) -> Generator[DensityMatrixBuffer, None, None]:
        """ Yield the Fourier transform of each chunk read by rho_nn_reader

        For each buffer read from rho_nn_reader, the time series (last axis) is
        normalized by the perturbation, optionally multiplied by a Gaussian window
        exp(-0.5 * frequency_broadening**2 * t**2), Fourier transformed with the
        time series zero-padded to fast_pad(nt) points, filtered to freq_w and
        complex conjugated. A copy of the resulting buffer is yielded per chunk.
        """
        time_t = self.rho_nn_reader.time_t  # Times in wave functions file
        dt = self.rho_nn_reader.dt  # Time step
        nt = len(time_t)
        padnt = fast_pad(nt)  # Pad with zeros

        dm_buffer = DensityMatrixBuffer(self.rho_nn_reader._parameters.nnshape,
                                        (self.nw, ),
                                        np.complex128)
        if self.yield_re:
            dm_buffer.zeros(True, 0)
        if self.yield_im:
            dm_buffer.zeros(False, 0)

        # The Gaussian window does not depend on the data; compute it once
        if self.frequency_broadening != 0:
            window_t = np.exp(-0.5 * self.frequency_broadening ** 2 * time_t**2)

        for read_buffer in self.rho_nn_reader:
            for data_nnt, buffer_nnw in zip(read_buffer._iter_buffers(), dm_buffer._iter_buffers()):
                if self.frequency_broadening == 0:
                    data_nnw = self.perturbation.normalize_frequency_response(data_nnt, time_t, padnt, axis=-1)
                else:
                    data_nnt = self.perturbation.normalize_time_response(data_nnt, time_t, axis=-1)
                    # NOTE(review): multiplies in place; assumes the array returned by
                    # normalize_time_response may be modified by the caller - confirm
                    data_nnt[..., :nt] *= window_t
                    data_nnw = np.fft.rfft(data_nnt, n=padnt, axis=-1) * dt
                buffer_nnw[:] = data_nnw[..., self.flt_w].conj()  # Change sign convention

            yield dm_buffer.copy()

    @property
    def dist_buffer(self) -> DensityMatrixBuffer:
        """ Buffer of density matrices on this rank after redistribution """
        if self._dist_buffer is None:
            self._dist_buffer = self.redistribute()
        return self._dist_buffer

    def create_out_buffer(self) -> DensityMatrixBuffer:
        """ Create the DensityMatrixBuffer to hold the temporary density matrix after each redistribution """
        parameters = self.rho_nn_reader._parameters
        # Ranks that do not store results receive nothing
        nlocalw = self.nlocalw if self.comm.rank in self._result_on_ranks else 0
        out_dm = DensityMatrixBuffer(nnshape=parameters.nnshape,
                                     xshape=(self.maxnchunks, nlocalw),
                                     dtype=np.complex128)
        out_dm.zero_buffers(real=self.yield_re, imag=self.yield_im, derivative_order_s=[0])

        return out_dm

    def create_result_buffer(self) -> DensityMatrixBuffer:
        """ Create the DensityMatrixBuffer to hold the resulting density matrix """
        parameters = self.rho_nn_reader._parameters
        nnshape = parameters.full_nnshape
        full_dm = DensityMatrixBuffer(nnshape=nnshape,
                                      xshape=(len(self.my_work()), ),
                                      dtype=np.complex128)
        full_dm.zero_buffers(real=self.yield_re, imag=self.yield_im, derivative_order_s=[0])

        return full_dm

    def redistribute(self) -> DensityMatrixBuffer:
        """ Perform the Fourier transform and redistribute the data

        When the Fourier transform is performed, the data is distributed such that each rank
        stores the entire time/frequency series for one chunk of the density matrices, i.e. indices n1, n2.

        This function then performs a redistribution of the data such that each rank stores full
        density matrices, for certain frequencies.

        If the density matrices are split into more chunks than there are ranks, then the
        chunks are read, Fourier transformed and distributed in a loop several times until all
        data has been processed.

        Returns
        -------
        Density matrix buffer with x-dimensions (Number of local frequencies, )
        where the Number of local frequencies varies between the ranks.
        """
        local_work = iter(self)
        parameters = self.rho_nn_reader._parameters
        log = self.log
        # Precondition: all n indices must be read (no striding)
        assert self.rho_nn_reader.rho_wfs_reader.lcao_rho_reader.striden == 0, \
            'n stride must be 0 (index all) for redistribute'

        # Frequency indices of result on each rank
        freqw_r = self.distributed_work()
        niters = len(list(self.work_loop_by_ranks()))

        out_dm = self.create_out_buffer()
        full_dm = self.create_result_buffer()

        # Which frequency slices of dm_buffer are sent to which rank. This is the
        # same for every chunk, so it is computed once.
        source_indices_r = [None if len(w) == 0 else w for w in freqw_r]

        _exhausted = object()

        # Loop over the chunks of the density matrix
        for chunki, indices_r in enumerate(self.work_loop_by_ranks()):
            # At this point, each rank stores one unique chunk of the density matrix.
            # All ranks have the entire time series of data for their own chunk.
            # If there are more chunks than ranks, then this loop will run
            # for several iterations. If the number of chunks is not divisible by the number of
            # ranks then, during the last iteration, some of the chunks are None (meaning the rank
            # currently has no data).

            # List of chunks that each rank currently stores, where element r of the list
            # contains the chunk that rank r works with. Ranks higher than the length of the list
            # currently store no chunks.
            # The list itself is identical on all ranks.
            chunks_by_rank = [indices[2:] for indices in indices_r if indices is not None]

            ntargets = len(chunks_by_rank)

            if self.comm.rank < ntargets:
                # This rank has data to send. Compute the Fourier transform and store the result
                dm_buffer = next(local_work)
            else:
                # This rank has no data to send. The generator is still advanced
                # (outside of the assert, so that this also happens under `python -O`)
                leftover = next(local_work, _exhausted)
                assert leftover is _exhausted
                # Still, we need to create a dummy buffer
                dm_buffer = DensityMatrixBuffer(nnshape=parameters.nnshape,
                                                xshape=(0, ), dtype=np.complex128)
                dm_buffer.zero_buffers(real=self.yield_re, imag=self.yield_im,
                                       derivative_order_s=[0])

            log.start('alltoallv')

            # Redistribute the data:
            # - dm_buffer stores single chunks of density matrices, for all frequencies.
            # - out_dm will store several chunks, for a few frequencies.
            # source_indices_r describes which slices of dm_buffer should be sent to which rank
            # target_indices_r describes to which positions of the out_dm buffer should be received
            # from which rank
            target_indices_r = [r if r < ntargets else None for r in range(self.comm.size)]
            dm_buffer.redistribute(out_dm,
                                   comm=self.comm,
                                   source_indices_r=source_indices_r,
                                   target_indices_r=target_indices_r,
                                   log=log)

            if self.comm.rank == 0:
                log(f'Chunk {chunki+1}/{niters}: distributed frequency response in '
                    f'{log.elapsed("alltoallv"):.1f}s', flush=True, who='Response')

            # Copy each received chunk into its place in the full density matrices
            for array_nnrw, full_array_nnw in zip(out_dm._iter_buffers(), full_dm._iter_buffers()):
                for r, nn_indices in enumerate(chunks_by_rank):
                    safe_fill_larger(full_array_nnw[nn_indices], array_nnrw[:, :, r])

        # All chunks must have been consumed; advance outside of the assert (see above)
        leftover = next(local_work, _exhausted)
        assert leftover is _exhausted

        return full_dm

    @classmethod
    def from_reader(cls,  # type: ignore
                    rho_nn_reader: KohnShamRhoWfsReader,
                    parameters: RhoParameters,
                    *,
                    perturbation: PerturbationLike,
                    filter_frequencies: list[float] | Array1D[np.float64] | None = None,
                    frequency_broadening: float = 0,
                    result_on_ranks: list[int] | None = None) -> FourierTransformer:
        """ Construct a Fourier transformer on top of an AlltoallvTimeDistributor

        Parameters
        ----------
        rho_nn_reader
            Reader of time-dependent density matrices, wrapped in an AlltoallvTimeDistributor.
        parameters
            Density matrix parameters passed to the AlltoallvTimeDistributor.
        perturbation, filter_frequencies, frequency_broadening, result_on_ranks
            Passed on to the constructor.
        """
        time_distributor = AlltoallvTimeDistributor(rho_nn_reader, parameters)
        # Use cls so that subclasses construct instances of themselves
        fourier_transformer = cls(time_distributor,
                                  perturbation=perturbation,
                                  filter_frequencies=filter_frequencies,
                                  frequency_broadening=frequency_broadening,
                                  result_on_ranks=result_on_ranks)
        return fourier_transformer