Coverage for rhodent/density_matrices/density_matrix.py: 65%
115 statements
coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3import numpy as np
4from numpy.typing import NDArray
6from gpaw.mpi import broadcast, world, SerialCommunicator
7from gpaw.lcaotddft.ksdecomposition import KohnShamDecomposition
9from ..typing import ArrayIsOnRootRank, DistributedArray, Communicator
class DensityMatrix:

    """ Wrapper for the density matrix in the Kohn-Sham basis at one moment
    in time or at one frequency.

    The plain density matrix and/or derivatives thereof may be stored.
    The arrays live on rank 0 of the communicator only; every other rank
    holds an ``ArrayIsOnRootRank`` placeholder in their place.

    Parameters
    ----------
    ksd
        KohnShamDecomposition object.
    matrices
        Dictionary mapping derivative orders (0, 1, 2) for zeroth,
        first, second derivative, .. to arrays storing the matrices.
        The values must be arrays on rank 0 and ``None`` on all other ranks.
    comm
        MPI communicator. Serial communicator by default.
    """

    def __init__(self,
                 ksd: KohnShamDecomposition,
                 matrices: dict[int, NDArray[np.complex128] | None],
                 comm: Communicator | None = None):
        self._ksd = ksd
        if comm is None:
            comm = SerialCommunicator()  # type: ignore
        self._comm = comm

        # Occupation number difference between the states in the index range
        # imin..imax (rows) and the states in amin..amax (columns).
        # Negative differences are clipped to zero.
        f_n = self.ksd.occ_un[0]
        imin, imax, amin, amax = self.ksd.ialims()
        self._f_ia = f_n[imin:imax+1, None] - f_n[None, amin:amax+1]
        self._f_ia[self._f_ia < 0] = 0

        # Pairs whose occupation number difference is below the smallest
        # value in ksd.f_p are zeroed in every stored matrix.
        min_occdiff = min(self.ksd.f_p)
        mask_ia = self.f_ia >= min_occdiff

        self._matrices: dict[int, DistributedArray] = dict()
        self.derivative_desc = {0: 'Plain DM', 1: '1st DM derivative', 2: '2nd DM derivative'}

        # Save the (masked) matrices on rank 0, placeholders elsewhere
        for derivative, rho in matrices.items():
            assert isinstance(derivative, int)
            if self.rank == 0:
                assert isinstance(rho, np.ndarray), rho
                self._matrices[derivative] = rho * mask_ia
            else:
                assert rho is None
                self._matrices[derivative] = ArrayIsOnRootRank()

    @property
    def ksd(self) -> KohnShamDecomposition:
        """ Kohn-Sham decomposition object. """
        return self._ksd

    @property
    def rank(self) -> int:
        """ MPI rank of the communicator. """
        return self.comm.rank

    @property
    def comm(self) -> Communicator:
        """ MPI communicator of this density matrix. """
        return self._comm  # type: ignore

    @property
    def f_ia(self) -> DistributedArray:
        """ Occupation number difference :math:`f_{ia}`. """
        return self._f_ia

    def _get_matrix(self,
                    derivative: int,
                    description: str) -> DistributedArray:
        """ Return the stored matrix of the given derivative order.

        Raises
        ------
        ValueError
            If no matrix of that derivative order is stored. The message
            lists the derivative orders that are available.
        """
        try:
            return self._matrices[derivative]
        except KeyError:
            # The message must be an f-string, otherwise the placeholder is
            # shown verbatim instead of the available derivative orders.
            raise ValueError(f'{description} not in {self._matrices.keys()}') from None

    @property
    def rho_ia(self) -> DistributedArray:
        r""" Electron-hole part of induced density matrix :math:`\delta\rho_{ia}`.

        Raises ``ValueError`` if the plain density matrix is not stored.
        """
        return self._get_matrix(0, 'Plain density matrix')

    @property
    def drho_ia(self) -> DistributedArray:
        r""" First time derivative of :math:`\delta\rho_{ia}`.

        Raises ``ValueError`` if the first derivative is not stored.
        """
        return self._get_matrix(1, 'First derivative of density matrix')

    @property
    def ddrho_ia(self) -> DistributedArray:
        r""" Second time derivative of :math:`\delta\rho_{ia}`.

        Raises ``ValueError`` if the second derivative is not stored.
        """
        return self._get_matrix(2, 'Second derivative of density matrix')

    @property
    def Q_ia(self) -> DistributedArray:
        r""" The quantity

        .. math::

            Q_{ia} = \frac{2 \mathrm{Re}\:\delta\rho_{ia}}{\sqrt{2 f_{ia}}}

        where :math:`f_{ia}` is the occupation number difference of pair :math:`ia`.
        """
        return self._divide_by_sqrt_fia(np.sqrt(2) * self.rho_ia.real)

    @property
    def P_ia(self) -> DistributedArray:
        r""" The quantity

        .. math::

            P_{ia} = \frac{2 \mathrm{Im}\:\delta\rho_{ia}}{\sqrt{2 f_{ia}}}

        where :math:`f_{ia}` is the occupation number difference of pair :math:`ia`.
        """
        return self._divide_by_sqrt_fia(np.sqrt(2) * self.rho_ia.imag)

    @property
    def dQ_ia(self) -> DistributedArray:
        r""" First time derivative of :math:`Q_{ia}`. """
        return self._divide_by_sqrt_fia(np.sqrt(2) * self.drho_ia.real)

    @property
    def dP_ia(self) -> DistributedArray:
        r""" First time derivative of :math:`P_{ia}`. """
        return self._divide_by_sqrt_fia(np.sqrt(2) * self.drho_ia.imag)

    @property
    def ddQ_ia(self) -> DistributedArray:
        r""" Second time derivative of :math:`Q_{ia}`. """
        return self._divide_by_sqrt_fia(np.sqrt(2) * self.ddrho_ia.real)

    @property
    def ddP_ia(self) -> DistributedArray:
        r""" Second time derivative of :math:`P_{ia}`. """
        return self._divide_by_sqrt_fia(np.sqrt(2) * self.ddrho_ia.imag)

    def _divide_by_sqrt_fia(self,
                            X_ia: DistributedArray) -> DistributedArray:
        r""" Divide by :math:`\sqrt{f_{ia}}` where :math:`f_{ia} \neq 0`.
        Leave everything else as 0.

        On ranks other than 0 the input must be an ``ArrayIsOnRootRank``
        placeholder and a new placeholder is returned.
        """
        if self.rank > 0:
            assert isinstance(X_ia, ArrayIsOnRootRank)
            return ArrayIsOnRootRank()

        assert not isinstance(X_ia, ArrayIsOnRootRank)
        flt_ia = self.f_ia != 0
        # Start from zeros so that pairs with f_ia == 0 stay zero
        Y_ia = np.zeros_like(X_ia)
        Y_ia[flt_ia] = X_ia[flt_ia] / np.sqrt(self.f_ia[flt_ia])

        return Y_ia

    def copy(self) -> DensityMatrix:
        """ Copy the density matrix.

        The arrays are copied on rank 0. All other ranks only hold
        placeholders, so ``None`` is passed on for them, which is what
        the constructor requires on those ranks.
        """
        matrices: dict[int, NDArray[np.complex128] | None]
        if self.rank == 0:
            matrices = {derivative: np.array(matrix)
                        for derivative, matrix in self._matrices.items()}
        else:
            matrices = {derivative: None for derivative in self._matrices}
        dm = DensityMatrix(ksd=self.ksd, matrices=matrices, comm=self.comm)
        return dm

    @classmethod
    def broadcast(cls,
                  density_matrix: DensityMatrix | None,
                  ksd: KohnShamDecomposition,
                  root: int,
                  dm_comm,
                  comm) -> DensityMatrix:
        """ Broadcast a density matrix object which is on one rank to all other ranks.

        Parameters
        ----------
        density_matrix
            The density matrix to be broadcast on the root rank, and ``None`` on other ranks.
        ksd
            KohnShamDecomposition object.
        root
            Rank of the process that has the original data.
        dm_comm
            Must be identical to communicator of :attr:`density_matrix`.
        comm
            MPI communicator. Must be complementary to the communicator of :attr:`density_matrix`.

        Returns
        -------
        On the root rank of :attr:`comm` (where that rank is also rank 0 of
        :attr:`dm_comm`) the original :attr:`density_matrix` object,
        otherwise a newly constructed density matrix.
        """
        matrices: dict[int, NDArray[np.complex128] | None]

        # Broadcast necessary metadata (shapes and dtypes), and set up the
        # send buffers on the root rank / receive buffers on the other ranks
        if comm.rank == root:
            assert density_matrix is not None
            matrix_shapes_dtypes = {derivative: (matrix.shape, matrix.dtype)
                                    for derivative, matrix in density_matrix._matrices.items()}
            broadcast(matrix_shapes_dtypes, root=root, comm=comm)
            matrices = {derivative: None if isinstance(arr, ArrayIsOnRootRank) else arr
                        for derivative, arr in density_matrix._matrices.items()}
        else:
            assert density_matrix is None
            matrix_shapes_dtypes = broadcast(None, root=root, comm=comm)
            if dm_comm.rank == 0:
                # Only rank 0 of the density matrix communicator holds data
                matrices = {derivative: np.empty(shape, dtype=dtype)
                            for derivative, (shape, dtype) in matrix_shapes_dtypes.items()}
            else:
                matrices = {derivative: None for derivative in matrix_shapes_dtypes}

        if dm_comm.size > 1 and comm.size > 1:
            # Make sure communicators are complementary, i.e. that they
            # have no member in common apart from the calling rank
            comm_members = comm.get_members()
            dm_members = dm_comm.get_members()

            intersect = set(comm_members) & set(dm_members)
            intersect.remove(world.rank)
            assert len(intersect) == 0, f'{comm_members} / {dm_members}'

        # On density matrix non-root ranks, return ArrayIsOnRootRank()
        if dm_comm.rank > 0:
            return DensityMatrix(ksd=ksd, matrices=matrices, comm=dm_comm)

        # Broadcast the matrices. The receive buffers created by np.empty are
        # already contiguous, so they are filled in place.
        for matrix in matrices.values():
            comm.broadcast(np.ascontiguousarray(matrix), root)

        if comm.rank == root:
            assert density_matrix is not None
            return density_matrix

        return DensityMatrix(ksd=ksd, matrices=matrices, comm=dm_comm)