Coverage for rhodent/perturbation.py: 78%
142 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3from abc import ABC, abstractmethod
4from typing import Any, Union
5from numbers import Number
6import numpy as np
7from numpy.typing import NDArray
9from gpaw.lcaotddft.laser import Laser, create_laser
def create_perturbation(perturbation: PerturbationLike) -> Perturbation:
    """ Create a perturbation object from a perturbation-like description.

    Parameters
    ----------
    perturbation
        One of

        * a :class:`Perturbation` instance, which is returned unchanged;
        * ``None``, giving :class:`NoPerturbation`;
        * a GPAW :class:`Laser` object, wrapped in :class:`PulsePerturbation`;
        * a dictionary. ``{'name': 'none'}`` gives :class:`NoPerturbation`,
          ``{'name': 'deltakick', 'strength': ...}`` gives :class:`DeltaKick`,
          and any other dictionary is passed to :func:`create_laser` and the
          result is wrapped in :class:`PulsePerturbation`.

    Returns
    -------
    Perturbation object.

    Raises
    ------
    TypeError
        If perturbation is none of the types listed above.
    """
    if isinstance(perturbation, Perturbation):
        return perturbation
    if perturbation is None:
        return NoPerturbation()
    if isinstance(perturbation, Laser):
        return PulsePerturbation(perturbation)

    # Explicit check instead of assert, which is stripped under python -O
    if not isinstance(perturbation, dict):
        raise TypeError('Cannot create a perturbation from object of type '
                        f'{type(perturbation).__name__}')

    name = perturbation['name']
    if name == 'none':
        # Return an instance, consistent with the perturbation=None branch
        return NoPerturbation()
    if name == 'deltakick':
        return DeltaKick(strength=perturbation['strength'])

    # Any other dictionary is interpreted by GPAW as a laser pulse
    pulse = create_laser(perturbation)
    return PulsePerturbation(pulse)
29class Perturbation(ABC):
31 """ Perturbation. """
33 def timestep(self,
34 times: NDArray[np.float64]):
35 if len(times) < 2:
36 raise ValueError('At least two times must be given to get a timestep.')
37 dt = times[1] - times[0]
38 if not np.allclose(times[1:] - dt, times[:-1]):
39 raise ValueError('The time step may not vary.')
40 return dt
42 def frequencies(self,
43 times: NDArray[np.float64],
44 padnt: int | None = None) -> NDArray[np.float64]:
45 """ Get the frequencies grid.
47 Parameters
48 ----------
49 times
50 Time grid in atomic units.
51 padnt
52 Length of data, including zero padding. Default is not zero padding.
54 Returns
55 -------
56 Frequencies grid in atomic units.
57 """
58 timestep = self.timestep(times)
59 if padnt is None:
60 padnt = len(times)
61 return 2 * np.pi * np.fft.rfftfreq(padnt, timestep)
63 @abstractmethod
64 def normalize_frequency_response(self,
65 data: NDArray[np.float64],
66 times: NDArray[np.float64],
67 padnt: int,
68 axis: int = -1) -> NDArray[np.complex128]:
69 """
70 Calculate a normalized response in the frequency domain, i.e., the
71 response to a unity strength delta kick. For example, polarizability.
73 Parameters
74 ----------
75 data
76 Real valued response in the time domain to this perturbation.
77 times
78 Time grid in atomic units.
79 axis
80 Axis corresponding to time dimension.
81 padnt
82 Length of data, including zero padding.
84 Returns
85 -------
86 Normalized response in the frequency domain.
87 """
88 raise NotImplementedError
90 def normalize_time_response(self,
91 data: NDArray[np.float64],
92 times: NDArray[np.float64],
93 axis: int = -1) -> NDArray[np.float64]:
94 """
95 Transform response in the time domain to a "normalized response",
96 which is the response to a unity strength delta kick.
98 Parameters
99 ----------
100 data
101 Real valued response in the time domain to this perturbation.
102 times
103 Time grid in atomic units.
104 axis
105 Axis corresponding to time dimension.
107 Returns
108 -------
109 Normalized response in the time domain.
110 """
111 from .utils import fast_pad
113 dt = times[1] - times[0]
114 nt = len(times)
115 padnt = fast_pad(nt)
117 data = data.swapaxes(axis, -1) # Put the time dimension last
119 # Calculate the normalized response in the frequency domain
120 data_w = self.normalize_frequency_response(data, times, padnt, axis=-1)
122 # Fourier transform back to time tomain
123 data_t = np.fft.irfft(data_w, n=padnt, axis=-1)[..., :nt] / dt
125 data_t = data_t.swapaxes(axis, -1)
126 return data_t
128 @abstractmethod
129 def amplitude(self,
130 times: NDArray[np.float64]) -> NDArray[np.float64]:
131 """
132 Perturbation amplitudes in time domain.
134 Parameters
135 ----------
136 times
137 Time grid in atomic units.
139 Returns
140 -------
141 Perturbation at the given times.
142 """
143 raise NotImplementedError
145 @abstractmethod
146 def fourier(self,
147 times: NDArray[np.float64],
148 padnt: int | None = None) -> NDArray[np.complex128]:
149 """
150 Fourier transform of perturbation.
152 Parameters
153 ----------
154 times
155 Time grid in atomic units.
156 padnt
157 Length of data, including zero padding. Default is no added zero padding.
159 Returns
160 -------
161 Fourier transform of the perturbation at the frequency grid \
162 given by :func:`frequencies`.
163 """
164 raise NotImplementedError
166 @abstractmethod
167 def todict(self) -> dict[str, Any]:
168 raise NotImplementedError
170 def __eq__(self, other) -> bool:
171 """ Equal if dicts are identical (up to numerical tolerance).
172 """
173 try:
174 d1 = self.todict()
175 d2 = other.todict()
176 except AttributeError:
177 return False
179 if d1.keys() != d2.keys():
180 return False
182 for key in d1.keys():
183 if isinstance(d1[key], Number) and isinstance(d2[key], Number):
184 if not np.isclose(d1[key], d2[key]):
185 return False
186 else:
187 if not d1[key] == d2[key]:
188 return False
190 return True
# Everything that create_perturbation accepts and turns into a Perturbation
PerturbationLike = Union[Perturbation, Laser, dict, type(None)]
class NoPerturbation(Perturbation):

    """ Placeholder for an unknown perturbation.

    Marks that the perturbation is not known and is assumed not to matter.
    Only the string and dictionary representations are available; every
    method that would need the actual perturbation raises RuntimeError.
    """

    def __init__(self):
        """ Takes no parameters; there is no state to store. """

    def todict(self) -> dict[str, Any]:
        return dict(name='none')

    def __str__(self) -> str:
        return 'No perturbation'

    def amplitude(self,
                  times: NDArray[np.float64]) -> NDArray[np.float64]:
        message = 'Not possible for no perturbation.'
        raise RuntimeError(message)

    def fourier(self,
                times: NDArray[np.float64],
                padnt: int | None = None) -> NDArray[np.complex128]:
        message = 'Not possible for no perturbation'
        raise RuntimeError(message)

    def normalize_frequency_response(self,
                                     data: NDArray[np.float64],
                                     times: NDArray[np.float64],
                                     padnt: int,
                                     axis: int = -1) -> NDArray[np.complex128]:
        message = 'Not possible for no perturbation'
        raise RuntimeError(message)
class DeltaKick(Perturbation):

    """ Delta-kick perturbation.

    Parameters
    ----------
    strength
        Strength of the perturbation in the frequency domain.
    """

    def __init__(self,
                 strength: float):
        self.strength = strength

    def amplitude(self,
                  times: NDArray[np.float64]) -> NDArray[np.float64]:
        """ Delta kick sampled on the time grid.

        The amplitude is ``strength / dt`` at time points with
        ``|t| < 1e-3 * dt`` and zero elsewhere. If the grid contains t = 0,
        the discrete integral (sum times dt) therefore equals ``strength``.
        """
        dt = self.timestep(times)
        at_zero = np.abs(times) < 1e-3 * dt  # True only at t == 0 (within tolerance)

        return self.strength / dt * at_zero

    def fourier(self,
                times: NDArray[np.float64],
                padnt: int | None = None) -> NDArray[np.complex128]:
        """ Flat spectrum equal to ``strength`` on the grid from :meth:`frequencies`. """
        nw = len(self.frequencies(times, padnt))  # Length of frequencies grid
        return self.strength * np.ones(nw)  # type: ignore

    def normalize_frequency_response(self,
                                     data: NDArray[np.float64],
                                     times: NDArray[np.float64],
                                     padnt: int,
                                     axis: int = -1) -> NDArray[np.complex128]:
        """ Fourier transform the response along ``axis`` and divide by the strength. """
        # Transform along the time axis given by the caller (previously the
        # last axis was always used, regardless of axis)
        data_w = np.fft.rfft(data, n=padnt, axis=axis) * self.timestep(times)
        # The strength is specified in the frequency domain, so the timestep is included in strength
        return data_w / self.strength

    def normalize_time_response(self,
                                data: NDArray[np.float64],
                                times: NDArray[np.float64],
                                axis: int = -1) -> NDArray[np.float64]:
        """ Divide the response by the strength.

        This is a shortcut for the frequency-domain round trip of the base
        class, whose dt factors cancel for a delta kick.
        """
        # The strength is specified in the frequency domain, hence no multiplication by timestep
        return data / self.strength  # type: ignore

    def todict(self) -> dict[str, Any]:
        return {'name': 'deltakick', 'strength': self.strength}

    def __str__(self) -> str:
        return f'Delta-kick perturbation (strength {self.strength:.1e})'
class PulsePerturbation(Perturbation):

    """ Perturbation as a time-dependent function.

    Parameters
    ----------
    pulse
        Object representing the pulse.
    """

    # Relative threshold used in normalize_frequency_response: frequencies
    # where |FT of pulse| is not above this fraction of its maximum are set to
    # zero instead of being divided by a (nearly) vanishing pulse spectrum.
    # Override on a subclass or instance to change the filtering.
    frequency_threshold: float = 0.005

    def __init__(self,
                 pulse: Laser | dict):
        self.pulse = create_laser(pulse)

    def amplitude(self,
                  times: NDArray[np.float64]) -> NDArray[np.float64]:
        """ Pulse strength evaluated at the given times. """
        return self.pulse.strength(times)

    def fourier(self,
                times: NDArray[np.float64],
                padnt: int | None = None) -> NDArray[np.complex128]:
        """ Discrete Fourier transform of the pulse, scaled by the time step. """
        pulse_t = self.amplitude(times)
        if padnt is None:
            padnt = len(times)
        return np.fft.rfft(pulse_t, n=padnt) * self.timestep(times)

    def normalize_frequency_response(self,
                                     data: NDArray[np.float64],
                                     times: NDArray[np.float64],
                                     padnt: int,
                                     axis: int = -1) -> NDArray[np.complex128]:
        """ Divide the Fourier transform of the response by that of the pulse.

        Frequencies where the pulse spectrum is weak (see
        ``frequency_threshold``) are set to zero.
        """
        data = data.swapaxes(axis, -1)  # Put the time dimension last

        # Fourier transform of perturbation and of data
        perturb_w = np.fft.rfft(self.pulse.strength(times), n=padnt)
        data_w = np.fft.rfft(data, n=padnt)

        # Mask where perturbation is below threshold
        abs_perturb_w = np.abs(perturb_w)
        flt_w = abs_perturb_w > self.frequency_threshold * abs_perturb_w.max()
        data_w[..., ~flt_w] = 0

        # Divide by the perturbation where it is significant
        data_w[..., flt_w] /= perturb_w[flt_w]

        # Move back the time/frequency dimension
        return data_w.swapaxes(axis, -1)

    def todict(self) -> dict[str, Any]:
        """ Dictionary of the pulse, or only its class name if it has no todict. """
        # Look up the method explicitly so that AttributeErrors raised *inside*
        # pulse.todict() are not silently mistaken for a missing method
        pulse_todict = getattr(self.pulse, 'todict', None)
        if pulse_todict is None:
            return {'name': self.pulse.__class__.__name__}
        return pulse_todict()

    def __str__(self) -> str:
        # Join 'key: value' entries with ', ' while the line stays shorter
        # than width characters; otherwise start a new line
        lines: list[str] = []
        width = 50
        for key, value in self.todict().items():
            line = f'{key}: {value}'
            if len(lines) == 0:
                lines.append(line)
                continue
            if len(lines[-1]) + len(line) + 2 < width:
                lines[-1] = lines[-1] + ', ' + line
            else:
                lines.append(line)
        return '\n'.join(lines)