Coverage for rhodent/density_matrices/density_matrix.py: 65%

115 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3import numpy as np 

4from numpy.typing import NDArray 

5 

6from gpaw.mpi import broadcast, world, SerialCommunicator 

7from gpaw.lcaotddft.ksdecomposition import KohnShamDecomposition 

8 

9from ..typing import ArrayIsOnRootRank, DistributedArray, Communicator 

10 

11 

class DensityMatrix:

    r""" Wrapper for the density matrix in the Kohn-Sham basis at one moment
    in time or at one frequency.

    The plain density matrix and/or derivatives thereof may be stored.

    Parameters
    ----------
    ksd
        KohnShamDecomposition object.
    matrices
        Dictionary mapping derivative orders (0, 1, 2) for zeroth,
        first, second derivative, .. to arrays storing the matrices.
        On rank 0 of :attr:`comm` every value must be a NumPy array; on all
        other ranks every value must be ``None``.
    comm
        MPI communicator. Serial communicator by default.
    """

    def __init__(self,
                 ksd: KohnShamDecomposition,
                 matrices: dict[int, NDArray[np.complex128] | None],
                 comm: Communicator | None = None):
        self._ksd = ksd
        if comm is None:
            comm = SerialCommunicator()  # type: ignore
        self._comm = comm

        # Occupation number difference f_i - f_a for every (i, a) pair inside
        # the window given by ksd.ialims(); negative differences are clipped
        # to zero.
        f_n = self.ksd.occ_un[0]
        imin, imax, amin, amax = self.ksd.ialims()
        self._f_ia = f_n[imin:imax+1, None] - f_n[None, amin:amax+1]
        self._f_ia[self._f_ia < 0] = 0

        # Mask keeping only pairs whose occupation difference is at least the
        # smallest value in ksd.f_p; masked-out elements are zeroed below.
        min_occdiff = min(self.ksd.f_p)
        mask_ia = self.f_ia >= min_occdiff

        self._matrices: dict[int, DistributedArray] = dict()
        self.derivative_desc = {0: 'Plain DM',
                                1: '1st DM derivative',
                                2: '2nd DM derivative'}

        # Save the (masked) matrices. Only rank 0 holds actual data; the other
        # ranks store an ArrayIsOnRootRank placeholder.
        for derivative, rho in matrices.items():
            assert isinstance(derivative, int)
            if self.rank == 0:
                assert isinstance(rho, np.ndarray), rho
                self._matrices[derivative] = rho * mask_ia
            else:
                assert rho is None
                self._matrices[derivative] = ArrayIsOnRootRank()

    @property
    def ksd(self) -> KohnShamDecomposition:
        """ Kohn-Sham decomposition object. """
        return self._ksd

    @property
    def rank(self) -> int:
        """ MPI rank of the communicator. """
        return self.comm.rank

    @property
    def comm(self) -> Communicator:
        """ MPI communicator over which the matrices are distributed. """
        return self._comm  # type: ignore

    @property
    def f_ia(self) -> DistributedArray:
        """ Occupation number difference :math:`f_{ia}`. """
        return self._f_ia

    def _get_matrix(self, derivative: int, what: str) -> DistributedArray:
        """ Return the stored matrix of the given derivative order.

        Parameters
        ----------
        derivative
            Derivative order (0 for the plain density matrix).
        what
            Human-readable name of the matrix, used in the error message.

        Raises
        ------
        ValueError
            If no matrix of that derivative order was stored.
        """
        try:
            return self._matrices[derivative]
        except KeyError:
            # The KeyError adds nothing beyond the message below.
            raise ValueError(f'{what} not in {self._matrices.keys()}') from None

    @property
    def rho_ia(self) -> DistributedArray:
        r""" Electron-hole part of induced density matrix :math:`\delta\rho_{ia}`. """
        return self._get_matrix(0, 'Plain density matrix')

    @property
    def drho_ia(self) -> DistributedArray:
        r""" First time derivative of :math:`\delta\rho_{ia}`. """
        return self._get_matrix(1, 'First derivative of density matrix')

    @property
    def ddrho_ia(self) -> DistributedArray:
        r""" Second time derivative of :math:`\delta\rho_{ia}`. """
        return self._get_matrix(2, 'Second derivative of density matrix')

    @property
    def Q_ia(self) -> DistributedArray:
        r""" The quantity

        .. math::
            Q_{ia} = \frac{2 \mathrm{Re}\:\delta\rho_{ia}}{\sqrt{2 f_{ia}}}

        where :math:`f_{ia}` is the occupation number difference of pair :math:`ia`.
        """
        # 2 / sqrt(2 f) == sqrt(2) / sqrt(f)
        return self._divide_by_sqrt_fia(np.sqrt(2) * self.rho_ia.real)

    @property
    def P_ia(self) -> DistributedArray:
        r""" The quantity

        .. math::
            P_{ia} = \frac{2 \mathrm{Im}\:\delta\rho_{ia}}{\sqrt{2 f_{ia}}}

        where :math:`f_{ia}` is the occupation number difference of pair :math:`ia`.
        """
        return self._divide_by_sqrt_fia(np.sqrt(2) * self.rho_ia.imag)

    @property
    def dQ_ia(self) -> DistributedArray:
        r""" First time derivative of :math:`Q_{ia}`. """
        return self._divide_by_sqrt_fia(np.sqrt(2) * self.drho_ia.real)

    @property
    def dP_ia(self) -> DistributedArray:
        r""" First time derivative of :math:`P_{ia}`. """
        return self._divide_by_sqrt_fia(np.sqrt(2) * self.drho_ia.imag)

    @property
    def ddQ_ia(self) -> DistributedArray:
        r""" Second time derivative of :math:`Q_{ia}`. """
        return self._divide_by_sqrt_fia(np.sqrt(2) * self.ddrho_ia.real)

    @property
    def ddP_ia(self) -> DistributedArray:
        r""" Second time derivative of :math:`P_{ia}`. """
        return self._divide_by_sqrt_fia(np.sqrt(2) * self.ddrho_ia.imag)

    def _divide_by_sqrt_fia(self,
                            X_ia: DistributedArray) -> DistributedArray:
        r""" Divide by :math:`\sqrt{f_{ia}}` where :math:`f_{ia} \neq 0`.
        Leave everything else as 0.

        On ranks other than 0 the input must be a placeholder, and a new
        placeholder is returned.
        """
        if self.rank > 0:
            assert isinstance(X_ia, ArrayIsOnRootRank)
            return ArrayIsOnRootRank()
        assert not isinstance(X_ia, ArrayIsOnRootRank)
        # Only divide where f_ia is nonzero to avoid division by zero.
        flt_ia = self.f_ia != 0
        Y_ia = np.zeros_like(X_ia)
        Y_ia[flt_ia] = X_ia[flt_ia] / np.sqrt(self.f_ia[flt_ia])

        return Y_ia

    def copy(self) -> DensityMatrix:
        """ Copy the density matrix.

        On rank 0 the stored arrays are copied. On the other ranks the stored
        values are ArrayIsOnRootRank placeholders, which are passed on as
        ``None`` because the constructor requires ``None`` on those ranks.
        """
        matrices: dict[int, NDArray[np.complex128] | None] = {
            derivative: None if isinstance(matrix, ArrayIsOnRootRank) else np.array(matrix)
            for derivative, matrix in self._matrices.items()}
        return DensityMatrix(ksd=self.ksd, matrices=matrices, comm=self.comm)

    @classmethod
    def broadcast(cls,
                  density_matrix: DensityMatrix | None,
                  ksd: KohnShamDecomposition,
                  root: int,
                  dm_comm: Communicator,
                  comm: Communicator) -> DensityMatrix:
        """ Broadcast a density matrix object which is on one rank to all other ranks.

        Parameters
        ----------
        density_matrix
            The density matrix to be broadcast on the root rank, and ``None`` on other ranks.
        ksd
            KohnShamDecomposition object.
        root
            Rank of the process that has the original data.
        dm_comm
            Must be identical to communicator of :attr:`density_matrix`.
        comm
            MPI communicator. Must be complementary to the communicator of :attr:`density_matrix`.

        Returns
        -------
        On the root rank of ``comm`` with ``dm_comm.rank == 0``, the input
        object itself; otherwise a new :class:`DensityMatrix` on ``dm_comm``.
        """
        matrices: dict[int, NDArray[np.complex128] | None]
        # Broadcast necessary metadata (shape and dtype per derivative order)
        # so that receiving ranks can allocate buffers.
        if comm.rank == root:
            assert density_matrix is not None
            matrix_shapes_dtypes = {derivative: (matrix.shape, matrix.dtype)
                                    for derivative, matrix in density_matrix._matrices.items()}
            broadcast(matrix_shapes_dtypes, root=root, comm=comm)
            matrices = {derivative: None if isinstance(arr, ArrayIsOnRootRank) else arr
                        for derivative, arr in density_matrix._matrices.items()}
        else:
            assert density_matrix is None
            matrix_shapes_dtypes = broadcast(None, root=root, comm=comm)

        if comm.rank != root:
            if dm_comm.rank == 0:
                # Receive buffers; filled in place by comm.broadcast below.
                matrices = {derivative: np.empty(shape, dtype=dtype)
                            for derivative, (shape, dtype) in matrix_shapes_dtypes.items()}
            else:
                matrices = {derivative: None
                            for derivative in matrix_shapes_dtypes}

        if dm_comm.size > 1 and comm.size > 1:
            # Make sure communicators are complementary: the only process they
            # share should be this one.
            # NOTE(review): assumes get_members() returns world ranks.
            comm_members = comm.get_members()
            dm_members = dm_comm.get_members()

            intersect = set(comm_members) & set(dm_members)
            intersect.remove(world.rank)
            assert len(intersect) == 0, f'{comm_members} / {dm_members}'

        # On density matrix non-root ranks no data is needed: return a
        # DensityMatrix holding ArrayIsOnRootRank placeholders.
        if dm_comm.rank > 0:
            return DensityMatrix(ksd=ksd, matrices=matrices, comm=dm_comm)

        # Broadcast the matrices. On receiving ranks the np.empty buffers are
        # already contiguous, so ascontiguousarray returns the same object and
        # the data lands directly in `matrices`.
        for derivative, matrix in matrices.items():
            comm.broadcast(np.ascontiguousarray(matrix), root)

        if comm.rank == root:
            assert density_matrix is not None
            return density_matrix
        else:
            return DensityMatrix(ksd=ksd, matrices=matrices, comm=dm_comm)