Coverage for rhodent/perturbation.py: 78%

142 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3from abc import ABC, abstractmethod 

4from typing import Any, Union 

5from numbers import Number 

6import numpy as np 

7from numpy.typing import NDArray 

8 

9from gpaw.lcaotddft.laser import Laser, create_laser 

10 

11 

def create_perturbation(perturbation: PerturbationLike):
    """ Create a perturbation object from a flexible description.

    Parameters
    ----------
    perturbation
        One of the following:

        - :class:`Perturbation` instance, which is returned unchanged.
        - ``None``, which gives a :class:`NoPerturbation`.
        - :class:`Laser` instance, which is wrapped in a :class:`PulsePerturbation`.
        - Dictionary with the key ``'name'``. The names ``'none'`` and ``'deltakick'``
          (the latter also requires the key ``'strength'``) are handled here. Any
          other dictionary is passed on to :func:`create_laser`.

    Returns
    -------
    Perturbation object.
    """
    if isinstance(perturbation, Perturbation):
        return perturbation
    if perturbation is None:
        return NoPerturbation()
    if isinstance(perturbation, Laser):
        return PulsePerturbation(perturbation)

    assert isinstance(perturbation, dict)
    name = perturbation['name']
    if name == 'none':
        # Return an instance (not the class itself), exactly like
        # the branch handling None above
        return NoPerturbation()
    if name == 'deltakick':
        return DeltaKick(strength=perturbation['strength'])
    pulse = create_laser(perturbation)
    return PulsePerturbation(pulse)

27 

28 

class Perturbation(ABC):

    """ Abstract base class for perturbations.

    Subclasses must implement :func:`normalize_frequency_response`,
    :func:`amplitude`, :func:`fourier` and :func:`todict`.
    """

    def timestep(self,
                 times: NDArray[np.float64]):
        """ Get the constant time step of a time grid.

        Parameters
        ----------
        times
            Time grid.

        Returns
        -------
        Difference between the first two times.

        Raises
        ------
        ValueError
            If fewer than two times are given, or if the time step varies.
        """
        if len(times) < 2:
            raise ValueError('At least two times must be given to get a timestep.')
        dt = times[1] - times[0]
        # Every time, shifted back by one step, must coincide with the previous time
        if not np.allclose(times[1:] - dt, times[:-1]):
            raise ValueError('The time step may not vary.')
        return dt

    def frequencies(self,
                    times: NDArray[np.float64],
                    padnt: int | None = None) -> NDArray[np.float64]:
        """ Get the frequencies grid.

        Parameters
        ----------
        times
            Time grid in atomic units.
        padnt
            Length of data, including zero padding. Default is no zero padding.

        Returns
        -------
        Frequencies grid in atomic units.
        """
        timestep = self.timestep(times)
        if padnt is None:
            padnt = len(times)
        # rfftfreq gives ordinary frequencies; the factor 2 pi converts to angular frequencies
        return 2 * np.pi * np.fft.rfftfreq(padnt, timestep)

    @abstractmethod
    def normalize_frequency_response(self,
                                     data: NDArray[np.float64],
                                     times: NDArray[np.float64],
                                     padnt: int,
                                     axis: int = -1) -> NDArray[np.complex128]:
        """
        Calculate a normalized response in the frequency domain, i.e., the
        response to a unity strength delta kick. For example, polarizability.

        Parameters
        ----------
        data
            Real valued response in the time domain to this perturbation.
        times
            Time grid in atomic units.
        padnt
            Length of data, including zero padding.
        axis
            Axis corresponding to time dimension.

        Returns
        -------
        Normalized response in the frequency domain.
        """
        raise NotImplementedError

    def normalize_time_response(self,
                                data: NDArray[np.float64],
                                times: NDArray[np.float64],
                                axis: int = -1) -> NDArray[np.float64]:
        """
        Transform response in the time domain to a "normalized response",
        which is the response to a unity strength delta kick.

        Parameters
        ----------
        data
            Real valued response in the time domain to this perturbation.
        times
            Time grid in atomic units.
        axis
            Axis corresponding to time dimension.

        Returns
        -------
        Normalized response in the time domain.

        Raises
        ------
        ValueError
            If fewer than two times are given, or if the time step varies.
        """
        # Validate the time grid through timestep() before doing any work, so that
        # too short or non-equidistant grids are rejected like in frequencies()
        dt = self.timestep(times)

        from .utils import fast_pad

        nt = len(times)
        padnt = fast_pad(nt)  # Length of data including zero padding

        data = data.swapaxes(axis, -1)  # Put the time dimension last

        # Calculate the normalized response in the frequency domain
        data_w = self.normalize_frequency_response(data, times, padnt, axis=-1)

        # Fourier transform back to time domain, discard the padding and scale by 1/dt
        data_t = np.fft.irfft(data_w, n=padnt, axis=-1)[..., :nt] / dt

        # Move the time dimension back to its original position
        data_t = data_t.swapaxes(axis, -1)
        return data_t

    @abstractmethod
    def amplitude(self,
                  times: NDArray[np.float64]) -> NDArray[np.float64]:
        """
        Perturbation amplitudes in time domain.

        Parameters
        ----------
        times
            Time grid in atomic units.

        Returns
        -------
        Perturbation at the given times.
        """
        raise NotImplementedError

    @abstractmethod
    def fourier(self,
                times: NDArray[np.float64],
                padnt: int | None = None) -> NDArray[np.complex128]:
        """
        Fourier transform of perturbation.

        Parameters
        ----------
        times
            Time grid in atomic units.
        padnt
            Length of data, including zero padding. Default is no added zero padding.

        Returns
        -------
        Fourier transform of the perturbation at the frequency grid
        given by :func:`frequencies`.
        """
        raise NotImplementedError

    @abstractmethod
    def todict(self) -> dict[str, Any]:
        """ Dictionary representation of the perturbation.

        The dictionary is used by :func:`__eq__` to compare perturbations.
        """
        raise NotImplementedError

    def __eq__(self, other) -> bool:
        """ Equal if dicts are identical (up to numerical tolerance).

        Objects without a ``todict`` method are never equal to a perturbation.
        Note that, since no ``__hash__`` is defined next to ``__eq__``,
        instances are not hashable.
        """
        try:
            d1 = self.todict()
            d2 = other.todict()
        except AttributeError:
            return False

        if d1.keys() != d2.keys():
            return False

        for key in d1.keys():
            if isinstance(d1[key], Number) and isinstance(d2[key], Number):
                # Numbers are compared with tolerance
                if not np.isclose(d1[key], d2[key]):
                    return False
            else:
                # Everything else must compare exactly equal
                if not d1[key] == d2[key]:
                    return False

        return True

191 

192 

# Everything that create_perturbation accepts: a Perturbation instance,
# a Laser instance, a dictionary description, or None (no perturbation)
PerturbationLike = Union[Perturbation, Laser, dict, None]

194 

195 

class NoPerturbation(Perturbation):

    """ Placeholder for an unknown perturbation.

    Used when the perturbation is not known and is not expected to matter.
    The methods that need an actual perturbation raise ``RuntimeError``.
    """

    def __init__(self):
        """ There is nothing to set up. """

    def __str__(self) -> str:
        return 'No perturbation'

    def todict(self) -> dict[str, Any]:
        """ Dictionary representation, which holds the name ``'none'`` only. """
        return {'name': 'none'}

    def amplitude(self,
                  times: NDArray[np.float64]) -> NDArray[np.float64]:
        """ Not available; always raises ``RuntimeError``. """
        raise RuntimeError('Not possible for no perturbation.')

    def fourier(self,
                times: NDArray[np.float64],
                padnt: int | None = None) -> NDArray[np.complex128]:
        """ Not available; always raises ``RuntimeError``. """
        raise RuntimeError('Not possible for no perturbation')

    def normalize_frequency_response(self,
                                     data: NDArray[np.float64],
                                     times: NDArray[np.float64],
                                     padnt: int,
                                     axis: int = -1) -> NDArray[np.complex128]:
        """ Not available; always raises ``RuntimeError``. """
        raise RuntimeError('Not possible for no perturbation')

228 

229 

class DeltaKick(Perturbation):

    """ Delta-kick perturbation.

    Parameters
    ----------
    strength
        Strength of the perturbation in the frequency domain.
    """

    def __init__(self,
                 strength: float):
        self.strength = strength

    def amplitude(self,
                  times: NDArray[np.float64]) -> NDArray[np.float64]:
        """ Perturbation amplitudes in time domain.

        Parameters
        ----------
        times
            Time grid in atomic units.

        Returns
        -------
        Array that is ``strength / dt`` where the time is zero, and zero elsewhere.
        """
        dt = self.timestep(times)
        # Times closer to zero than a small fraction of the time step count as zero
        amplitudes = np.abs(times) < 1e-3 * dt  # 1 if zero, else 0

        return self.strength / dt * amplitudes

    def fourier(self,
                times: NDArray[np.float64],
                padnt: int | None = None) -> NDArray[np.complex128]:
        """ Fourier transform of perturbation.

        Parameters
        ----------
        times
            Time grid in atomic units.
        padnt
            Length of data, including zero padding. Default is no added zero padding.

        Returns
        -------
        The constant value ``strength`` at every frequency of the grid
        given by :func:`frequencies`.
        """
        nw = len(self.frequencies(times, padnt))  # Length of frequencies grid
        return self.strength * np.ones(nw)  # type: ignore

    def normalize_frequency_response(self,
                                     data: NDArray[np.float64],
                                     times: NDArray[np.float64],
                                     padnt: int,
                                     axis: int = -1) -> NDArray[np.complex128]:
        """ Calculate the response to a unity strength delta kick in the frequency domain.

        Parameters
        ----------
        data
            Real valued response in the time domain to this perturbation.
        times
            Time grid in atomic units.
        padnt
            Length of data, including zero padding.
        axis
            Axis corresponding to time dimension.

        Returns
        -------
        Normalized response in the frequency domain. The frequency
        dimension is at the position given by :attr:`axis`.
        """
        # The transform must be taken along the time dimension given by axis,
        # which is not necessarily the last one
        data_w = np.fft.rfft(data, n=padnt, axis=axis) * self.timestep(times)
        # The strength is specified in the frequency domain, so the timestep is included in strength
        return data_w / self.strength

    def normalize_time_response(self,
                                data: NDArray[np.float64],
                                times: NDArray[np.float64],
                                axis: int = -1) -> NDArray[np.float64]:
        """ Transform response in the time domain to the response to a unity strength delta kick.

        Parameters
        ----------
        data
            Real valued response in the time domain to this perturbation.
        times
            Time grid in atomic units. Not used for the delta kick.
        axis
            Axis corresponding to time dimension. Not used for the delta kick.

        Returns
        -------
        Normalized response in the time domain.
        """
        # The strength is specified in the frequency domain, hence no multiplication by timestep
        return data / self.strength  # type: ignore

    def todict(self) -> dict[str, Any]:
        """ Dictionary representation, holding the name ``'deltakick'`` and the strength. """
        return {'name': 'deltakick', 'strength': self.strength}

    def __str__(self) -> str:
        return f'Delta-kick perturbation (strength {self.strength:.1e})'

278 

279 

class PulsePerturbation(Perturbation):

    """ Perturbation described by a pulse, i.e. a function of time.

    Parameters
    ----------
    pulse
        Pulse object, or anything that :func:`create_laser` can turn into one.
    """

    def __init__(self,
                 pulse: Laser | dict):
        self.pulse = create_laser(pulse)

    def amplitude(self,
                  times: NDArray[np.float64]) -> NDArray[np.float64]:
        """ Strength of the pulse at the given times.

        Parameters
        ----------
        times
            Time grid in atomic units.

        Returns
        -------
        Perturbation at the given times.
        """
        return self.pulse.strength(times)

    def fourier(self,
                times: NDArray[np.float64],
                padnt: int | None = None) -> NDArray[np.complex128]:
        """ Discrete Fourier transform of the pulse, scaled by the time step.

        Parameters
        ----------
        times
            Time grid in atomic units.
        padnt
            Length of data, including zero padding. Default is no added zero padding.

        Returns
        -------
        Fourier transform of the perturbation at the frequency grid
        given by :func:`frequencies`.
        """
        amplitude_t = self.amplitude(times)
        npoints = len(times) if padnt is None else padnt
        return np.fft.rfft(amplitude_t, n=npoints) * self.timestep(times)

    def normalize_frequency_response(self,
                                     data: NDArray[np.float64],
                                     times: NDArray[np.float64],
                                     padnt: int,
                                     axis: int = -1) -> NDArray[np.complex128]:
        """ Divide the Fourier transform of the response by the one of the pulse.

        Frequencies where the pulse is weak, compared to its
        maximum in the frequency domain, are set to zero.

        Parameters
        ----------
        data
            Real valued response in the time domain to this perturbation.
        times
            Time grid in atomic units.
        padnt
            Length of data, including zero padding.
        axis
            Axis corresponding to time dimension.

        Returns
        -------
        Normalized response in the frequency domain.
        """
        rel_threshold = 0.005  # Relative to the maximum of the pulse spectrum

        # Work with the time dimension as the last dimension
        data = data.swapaxes(axis, -1)

        # Transform the pulse and the data, using the same padded length
        pulse_w = np.fft.rfft(self.pulse.strength(times), n=padnt)
        response_w = np.fft.rfft(data, n=padnt)

        # Frequencies where the pulse is strong enough to divide by
        pulse_abs_w = np.abs(pulse_w)
        keep_w = pulse_abs_w > rel_threshold * pulse_abs_w.max()

        # Zero the response outside of the mask, normalize it inside
        response_w[..., ~keep_w] = 0
        response_w[..., keep_w] /= pulse_w[keep_w]

        # Restore the position of the time/frequency dimension
        return response_w.swapaxes(axis, -1)

    def todict(self) -> dict[str, Any]:
        """ Dictionary representation of the pulse.

        Pulses without a ``todict`` method are represented by their class name only.
        """
        try:
            return self.pulse.todict()
        except AttributeError:
            return {'name': self.pulse.__class__.__name__}

    def __str__(self) -> str:
        """ The ``key: value`` pairs of the dictionary representation.

        Pairs are joined by commas, as long as the line stays shorter
        than 50 characters. Otherwise a new line is started.
        """
        max_width = 50
        rows: list[str] = []
        for key, value in self.todict().items():
            entry = f'{key}: {value}'
            if rows and len(rows[-1]) + len(entry) + 2 < max_width:
                # There is room for the entry on the current line
                rows[-1] = rows[-1] + ', ' + entry
            else:
                rows.append(entry)
        return '\n'.join(rows)