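"""Perturbations for rhodent (module ``rhodent.perturbation``).

Provides the :class:`Perturbation` base class, its implementations
:class:`NoPerturbation`, :class:`DeltaKick` and :class:`PulsePerturbation`,
and :func:`create_perturbation`, which builds one from a perturbation-like
object.
"""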

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Union
from numbers import Number
import numpy as np
from numpy.typing import NDArray

from gpaw.lcaotddft.laser import Laser, create_laser


def create_perturbation(perturbation: PerturbationLike) -> Perturbation:
    """Create a :class:`Perturbation` from a perturbation-like object.

    Parameters
    ----------
    perturbation
        One of

        * a :class:`Perturbation`, which is returned unchanged;
        * ``None``, which gives a :class:`NoPerturbation`;
        * a GPAW ``Laser``, which is wrapped in a :class:`PulsePerturbation`;
        * a dict with a ``'name'`` key. ``'none'`` gives a
          :class:`NoPerturbation`. ``'deltakick'`` gives a :class:`DeltaKick`,
          and the dict must then also contain ``'strength'``. Any other name
          is passed to ``create_laser`` and wrapped in a
          :class:`PulsePerturbation`.

    Returns
    -------
    The perturbation object.

    Raises
    ------
    TypeError
        If ``perturbation`` is none of the types listed above.
    """
    if isinstance(perturbation, Perturbation):
        return perturbation
    if perturbation is None:
        return NoPerturbation()
    if isinstance(perturbation, Laser):
        return PulsePerturbation(perturbation)

    # Explicit check instead of ``assert``, which is stripped under ``python -O``
    if not isinstance(perturbation, dict):
        raise TypeError('Cannot create a perturbation from an object of type '
                        f'{type(perturbation).__name__!r}')

    name = perturbation['name']
    if name == 'none':
        # Return an instance; returning the class itself breaks every caller
        # that uses the result as a Perturbation object
        return NoPerturbation()
    if name == 'deltakick':
        return DeltaKick(strength=perturbation['strength'])
    pulse = create_laser(perturbation)
    return PulsePerturbation(pulse)


[docs] class Perturbation(ABC): """ Perturbation. """ def timestep(self, times: NDArray[np.float64]): if len(times) < 2: raise ValueError('At least two times must be given to get a timestep.') dt = times[1] - times[0] if not np.allclose(times[1:] - dt, times[:-1]): raise ValueError('The time step may not vary.') return dt
[docs] def frequencies(self, times: NDArray[np.float64], padnt: int | None = None) -> NDArray[np.float64]: """ Get the frequencies grid. Parameters ---------- times Time grid in atomic units. padnt Length of data, including zero padding. Default is not zero padding. Returns ------- Frequencies grid in atomic units. """ timestep = self.timestep(times) if padnt is None: padnt = len(times) return 2 * np.pi * np.fft.rfftfreq(padnt, timestep)
[docs] @abstractmethod def normalize_frequency_response(self, data: NDArray[np.float64], times: NDArray[np.float64], padnt: int, axis: int = -1) -> NDArray[np.complex128]: """ Calculate a normalized response in the frequency domain, i.e., the response to a unity strength delta kick. For example, polarizability. Parameters ---------- data Real valued response in the time domain to this perturbation. times Time grid in atomic units. axis Axis corresponding to time dimension. padnt Length of data, including zero padding. Returns ------- Normalized response in the frequency domain. """ raise NotImplementedError
[docs] def normalize_time_response(self, data: NDArray[np.float64], times: NDArray[np.float64], axis: int = -1) -> NDArray[np.float64]: """ Transform response in the time domain to a "normalized response", which is the response to a unity strength delta kick. Parameters ---------- data Real valued response in the time domain to this perturbation. times Time grid in atomic units. axis Axis corresponding to time dimension. Returns ------- Normalized response in the time domain. """ from .utils import fast_pad dt = times[1] - times[0] nt = len(times) padnt = fast_pad(nt) data = data.swapaxes(axis, -1) # Put the time dimension last # Calculate the normalized response in the frequency domain data_w = self.normalize_frequency_response(data, times, padnt, axis=-1) # Fourier transform back to time tomain data_t = np.fft.irfft(data_w, n=padnt, axis=-1)[..., :nt] / dt data_t = data_t.swapaxes(axis, -1) return data_t
[docs] @abstractmethod def amplitude(self, times: NDArray[np.float64]) -> NDArray[np.float64]: """ Perturbation amplitudes in time domain. Parameters ---------- times Time grid in atomic units. Returns ------- Perturbation at the given times. """ raise NotImplementedError
[docs] @abstractmethod def fourier(self, times: NDArray[np.float64], padnt: int | None = None) -> NDArray[np.complex128]: """ Fourier transform of perturbation. Parameters ---------- times Time grid in atomic units. padnt Length of data, including zero padding. Default is no added zero padding. Returns ------- Fourier transform of the perturbation at the frequency grid \ given by :func:`frequencies`. """ raise NotImplementedError
@abstractmethod def todict(self) -> dict[str, Any]: raise NotImplementedError def __eq__(self, other) -> bool: """ Equal if dicts are identical (up to numerical tolerance). """ try: d1 = self.todict() d2 = other.todict() except AttributeError: return False if d1.keys() != d2.keys(): return False for key in d1.keys(): if isinstance(d1[key], Number) and isinstance(d2[key], Number): if not np.isclose(d1[key], d2[key]): return False else: if not d1[key] == d2[key]: return False return True
# Anything accepted by create_perturbation: a ready Perturbation, a GPAW
# Laser, a dict description, or None (no perturbation).
PerturbationLike = Union[
    Perturbation,
    Laser,
    dict,
    None,
]
class NoPerturbation(Perturbation):
    """No perturbation.

    Placeholder used when the perturbation is not known and is not supposed
    to matter. Every method that would need the actual perturbation raises
    :class:`RuntimeError`.
    """

    def __init__(self):
        """Create the placeholder. It stores no state."""

    def todict(self) -> dict[str, Any]:
        """Return ``{'name': 'none'}``."""
        return {'name': 'none'}

    def __str__(self) -> str:
        return 'No perturbation'

    def amplitude(self, times: NDArray[np.float64]) -> NDArray[np.float64]:
        """Not available; always raises :class:`RuntimeError`."""
        message = 'Not possible for no perturbation.'
        raise RuntimeError(message)

    def fourier(self,
                times: NDArray[np.float64],
                padnt: int | None = None) -> NDArray[np.complex128]:
        """Not available; always raises :class:`RuntimeError`."""
        message = 'Not possible for no perturbation'
        raise RuntimeError(message)

    def normalize_frequency_response(self,
                                     data: NDArray[np.float64],
                                     times: NDArray[np.float64],
                                     padnt: int,
                                     axis: int = -1) -> NDArray[np.complex128]:
        """Not available; always raises :class:`RuntimeError`."""
        message = 'Not possible for no perturbation'
        raise RuntimeError(message)
class DeltaKick(Perturbation):
    """Delta-kick perturbation.

    Parameters
    ----------
    strength
        Strength of the perturbation in the frequency domain.
    """

    def __init__(self, strength: float):
        self.strength = strength

    def amplitude(self, times: NDArray[np.float64]) -> NDArray[np.float64]:
        """Perturbation amplitudes in the time domain.

        The kick is a single sample of height ``strength / dt`` at the time
        lying within ``1e-3 * dt`` of zero, and zero elsewhere. If no time
        lies that close to zero, all amplitudes are zero.

        Parameters
        ----------
        times
            Time grid in atomic units (uniform, see :meth:`timestep`).

        Returns
        -------
        Perturbation at the given times.
        """
        dt = self.timestep(times)
        amplitudes = np.abs(times) < 1e-3 * dt  # 1 if zero, else 0
        return self.strength / dt * amplitudes

    def fourier(self,
                times: NDArray[np.float64],
                padnt: int | None = None) -> NDArray[np.complex128]:
        """Fourier transform of the kick: ``strength`` at every frequency.

        Parameters
        ----------
        times
            Time grid in atomic units.
        padnt
            Length of data, including zero padding. Default is no added zero
            padding.

        Returns
        -------
        Array of length ``len(frequencies(times, padnt))`` filled with
        ``strength``.
        """
        nw = len(self.frequencies(times, padnt))  # Length of frequencies grid
        return self.strength * np.ones(nw)  # type: ignore

    def normalize_frequency_response(self,
                                     data: NDArray[np.float64],
                                     times: NDArray[np.float64],
                                     padnt: int,
                                     axis: int = -1) -> NDArray[np.complex128]:
        """Normalized response in the frequency domain.

        Parameters
        ----------
        data
            Real valued response in the time domain to this kick.
        times
            Time grid in atomic units.
        padnt
            Length of data, including zero padding.
        axis
            Axis corresponding to time dimension.

        Returns
        -------
        ``rfft(data) * dt / strength`` along ``axis``.
        """
        # Transform along the requested time axis (previously always the last
        # axis, regardless of ``axis``). The time step makes the discrete
        # transform approximate the continuous one.
        data_w = np.fft.rfft(data, n=padnt, axis=axis) * self.timestep(times)
        # The strength is specified in the frequency domain, so dividing by it
        # directly gives the response to a unity strength kick
        return data_w / self.strength

    def normalize_time_response(self,
                                data: NDArray[np.float64],
                                times: NDArray[np.float64],
                                axis: int = -1) -> NDArray[np.float64]:
        """Normalized response in the time domain: ``data / strength``.

        ``times`` and ``axis`` are not needed because the division is
        elementwise.
        """
        # The strength is specified in the frequency domain, hence no
        # multiplication by timestep
        return data / self.strength

    def todict(self) -> dict[str, Any]:
        """Return ``{'name': 'deltakick', 'strength': strength}``."""
        return {'name': 'deltakick', 'strength': self.strength}

    def __str__(self) -> str:
        return f'Delta-kick perturbation (strength {self.strength:.1e})'
class PulsePerturbation(Perturbation):
    """Perturbation given by a time-dependent pulse.

    Parameters
    ----------
    pulse
        Object representing the pulse: a GPAW ``Laser`` or a dict, either of
        which is passed through ``create_laser``.
    """

    def __init__(self, pulse: Laser | dict):
        self.pulse = create_laser(pulse)

    def amplitude(self, times: NDArray[np.float64]) -> NDArray[np.float64]:
        """Pulse strength at the given times (``pulse.strength(times)``)."""
        return self.pulse.strength(times)

    def fourier(self,
                times: NDArray[np.float64],
                padnt: int | None = None) -> NDArray[np.complex128]:
        """Real FFT of the pulse, multiplied by the time step.

        ``padnt`` is the FFT length including zero padding. It defaults to
        ``len(times)``.
        """
        signal = self.amplitude(times)
        n_fft = len(times) if padnt is None else padnt
        return np.fft.rfft(signal, n=n_fft) * self.timestep(times)

    def normalize_frequency_response(self,
                                     data: NDArray[np.float64],
                                     times: NDArray[np.float64],
                                     padnt: int,
                                     axis: int = -1) -> NDArray[np.complex128]:
        """Divide the response spectrum by the pulse spectrum.

        Neither spectrum is scaled by the time step, so that factor cancels.
        At frequencies where the pulse spectrum is at most 0.005 times its
        maximum magnitude, the result is set to zero.
        """
        data_last = data.swapaxes(axis, -1)  # Time dimension last

        relative_cutoff = 0.005
        pulse_w = np.fft.rfft(self.pulse.strength(times), n=padnt)
        response_w = np.fft.rfft(data_last, n=padnt)

        # Only trust frequencies where the pulse carries enough weight
        usable = np.abs(pulse_w) > relative_cutoff * np.abs(pulse_w).max()
        response_w[..., ~usable] = 0
        response_w[..., usable] /= pulse_w[usable]

        # Restore the original axis order
        return response_w.swapaxes(axis, -1)

    def todict(self) -> dict[str, Any]:
        """Return the pulse's own ``todict()``, or just its class name."""
        try:
            return self.pulse.todict()
        except AttributeError:
            return {'name': self.pulse.__class__.__name__}

    def __str__(self) -> str:
        # Join "key: value" items with ", ", wrapping lines at ~50 characters
        width = 50
        rows: list[str] = []
        for key, value in self.todict().items():
            item = f'{key}: {value}'
            if rows and len(rows[-1]) + len(item) + 2 < width:
                rows[-1] = rows[-1] + ', ' + item
            else:
                rows.append(item)
        return '\n'.join(rows)