Coverage for rhodent/density_matrices/buffer.py: 93%

270 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3from typing import Generator, Generic, Sequence 

4import numpy as np 

5from numpy.typing import NDArray 

6from numpy._typing import _DTypeLike as DTypeLike # parametrizable wrt generic 

7 

8from ..utils import DTypeT, Logger, env 

9from ..typing import ArrayIndex 

10 

11 

class DensityMatrixBuffer(Generic[DTypeT]):

    """ Buffer for the density matrix.

    Objects of this class can hold buffers for real and imaginary parts
    and various derivatives at the same time.

    Each buffer has two dimensions corresponding to (part of) the
    density matrix, and optionally additional dimensions for e.g. time.

    Parameters
    ----------
    nnshape
        Shape of the dimensions corresponding to the density matrix. Must
        have length 2.
    xshape
        Shape of the additional dimensions corresponding to, e.g., time.
    dtype
        Data type of the density matrices.
    re_buffers
        Buffers for the different derivatives of the real part of the
        density matrix as a dictionary, where the key is the derivative
        order (0, 1, 2, etc.) and the value is a numpy array of shape
        ``nnshape + xshape``. The arrays are stored by reference (not copied).
    im_buffers
        Same as :attr:`re_buffers` but for the imaginary part.
    """

    def __init__(self,
                 nnshape: tuple[int, int],
                 xshape: tuple[int, ...],
                 dtype: DTypeLike[DTypeT],
                 re_buffers: dict[int, NDArray[DTypeT]] | None = None,
                 im_buffers: dict[int, NDArray[DTypeT]] | None = None):
        assert len(nnshape) == 2
        assert all(isinstance(dim, (int, np.integer)) and dim >= 0 for dim in nnshape)
        assert all(isinstance(dim, (int, np.integer)) and dim >= 0 for dim in xshape)
        assert isinstance(np.dtype(dtype), np.dtype)
        # Normalize to tuples so that concatenation (see shape) and comparison
        # with ndarray.shape also work when lists are passed
        self._nnshape: tuple[int, int] = (nnshape[0], nnshape[1])
        self._xshape: tuple[int, ...] = tuple(xshape)
        self._dtype = np.dtype(dtype)
        self._re_buffers: dict[int, NDArray[DTypeT]] = dict()
        self._im_buffers: dict[int, NDArray[DTypeT]] = dict()

        # The defaults are None rather than a shared mutable dict()
        if re_buffers is not None:
            for derivative, buffer_nnx in re_buffers.items():
                self.store(True, derivative, buffer_nnx)

        if im_buffers is not None:
            for derivative, buffer_nnx in im_buffers.items():
                self.store(False, derivative, buffer_nnx)

    @property
    def real(self) -> NDArray[DTypeT]:
        """ Buffer of shape nnshape + xshape corresponding to real data. """
        return self._get_real(0)

    @property
    def real1(self) -> NDArray[DTypeT]:
        """ Buffer of shape nnshape + xshape corresponding to real part of first derivative. """
        return self._get_real(1)

    @property
    def real2(self) -> NDArray[DTypeT]:
        """ Buffer of shape nnshape + xshape corresponding to real part of second derivative. """
        return self._get_real(2)

    @property
    def imag(self) -> NDArray[DTypeT]:
        """ Buffer of shape nnshape + xshape corresponding to imaginary data. """
        return self._get_imag(0)

    @property
    def imag1(self) -> NDArray[DTypeT]:
        """ Buffer of shape nnshape + xshape corresponding to imaginary part of first derivative. """
        return self._get_imag(1)

    @property
    def imag2(self) -> NDArray[DTypeT]:
        """ Buffer of shape nnshape + xshape corresponding to imaginary part of second derivative. """
        return self._get_imag(2)

    def _get_real(self,
                  derivative: int) -> NDArray[DTypeT]:
        """ Fetch density matrix buffer for real data.

        Parameters
        ----------
        derivative
            Derivative order.

        Raises
        ------
        KeyError
            If no real buffer is stored for this derivative order.
        """
        return self._re_buffers[derivative]

    def _get_imag(self,
                  derivative: int) -> NDArray[DTypeT]:
        """ Fetch density matrix buffer for imaginary data.

        Parameters
        ----------
        derivative
            Derivative order.

        Raises
        ------
        KeyError
            If no imaginary buffer is stored for this derivative order.
        """
        return self._im_buffers[derivative]

    def _get_data(self,
                  real: bool,
                  derivative: int) -> NDArray[DTypeT]:
        """ Fetch density matrix buffer.

        Parameters
        ----------
        real
            ``True`` if real, ``False`` if imaginary.
        derivative
            Derivative order.
        """
        return self._get_real(derivative) if real else self._get_imag(derivative)

    def copy(self) -> DensityMatrixBuffer:
        """ Return a deep copy of this object (buffers are copied). """
        re_buffers = {derivative: np.array(buffer_nnx)
                      for derivative, buffer_nnx in self._re_buffers.items()}
        im_buffers = {derivative: np.array(buffer_nnx)
                      for derivative, buffer_nnx in self._im_buffers.items()}

        return DensityMatrixBuffer(self.nnshape, self.xshape,
                                   dtype=self.dtype,
                                   re_buffers=re_buffers,
                                   im_buffers=im_buffers)

    def new(self) -> DensityMatrixBuffer:
        """ Return a new buffer with the same shape and dtype, but no stored arrays. """
        return DensityMatrixBuffer(self.nnshape, self.xshape,
                                   dtype=self.dtype)

    def __getitem__(self,
                    value) -> DensityMatrixBuffer:
        """ Index the x-dimensions of all buffers.

        The nn-dimensions are kept whole and ``value`` is applied to the
        x-dimensions. Returns a new DensityMatrixBuffer holding the indexed
        buffers. With basic indexing (integers, slices) these are views of the
        buffers of this DensityMatrixBuffer; as usual in numpy, advanced
        (array) indexing produces copies.
        """
        # Wrap in a tuple
        if not isinstance(value, tuple):
            value = (value, )
        s = self.nnellipsis + value
        re_buffers = {derivative: buffer_nnx[s]
                      for derivative, buffer_nnx in self._re_buffers.items()}
        im_buffers = {derivative: buffer_nnx[s]
                      for derivative, buffer_nnx in self._im_buffers.items()}

        indexed_buffers = list(re_buffers.values()) + list(im_buffers.values())
        if len(indexed_buffers) > 0:
            # All buffers share the same shape; take the x-shape from any of them
            xshape = indexed_buffers[0].shape[2:]
        else:
            # No buffers to inspect: apply the index to a zero-stride dummy of
            # shape xshape to obtain the resulting x-shape without allocating a
            # full-size array
            dummy_x = np.broadcast_to(np.zeros((), dtype=bool), self.xshape)
            xshape = dummy_x[value].shape
        return DensityMatrixBuffer(self.nnshape, xshape, dtype=self.dtype,
                                   re_buffers=re_buffers, im_buffers=im_buffers)

    @property
    def narrays(self) -> int:
        """ Number of arrays stored in this object. """
        return len(self.derivatives_imag) + len(self.derivatives_real)

    @property
    def nnshape(self) -> tuple[int, int]:
        """ Shape of the part of the density matrix. """
        return self._nnshape

    @property
    def xshape(self) -> tuple[int, ...]:
        """ Shape of the additional dimensions of the buffers. """
        return self._xshape

    @property
    def shape(self) -> tuple[int, ...]:
        """ Shape of the buffers (``nnshape + xshape``). """
        return self.nnshape + self.xshape

    @property
    def dtype(self) -> np.dtype[DTypeT]:
        """ Dtype of the buffers. """
        return self._dtype

    def store(self,
              real: bool,
              derivative: int,
              buffer_nnx: NDArray[DTypeT]):
        """ Store buffer for part of density matrix.

        The array is stored by reference (not copied).

        Parameters
        ----------
        real
            ``True`` if real, ``False`` if imaginary.
        derivative
            Derivative order.
        buffer_nnx
            Buffer of shape ``nnshape + xshape`` and dtype :attr:`dtype`.
        """
        assert isinstance(derivative, int) and derivative >= 0, derivative
        assert isinstance(buffer_nnx, np.ndarray)
        assert buffer_nnx.shape == self.shape, f'{buffer_nnx.shape} != {self.shape}'
        assert buffer_nnx.dtype == self.dtype
        if real:
            self._re_buffers[derivative] = buffer_nnx
        else:
            self._im_buffers[derivative] = buffer_nnx

    def zero_buffers(self,
                     real: bool,
                     imag: bool,
                     derivative_order_s: list[int]):
        """ Allocate zero-filled buffers for several derivative orders at once.

        Existing buffers for the same parts and derivative orders are replaced.

        Parameters
        ----------
        real
            Initialize buffers for real parts.
        imag
            Initialize buffers for imaginary parts.
        derivative_order_s
            Initialize buffers for these derivative orders.
        """
        for derivative in derivative_order_s:
            if real:
                self.zeros(True, derivative)
            if imag:
                self.zeros(False, derivative)

    def zeros(self,
              real: bool,
              derivative: int):
        """ Initialize buffer with zeros for part of density matrix.

        Parameters
        ----------
        real
            ``True`` if real, ``False`` if imaginary.
        derivative
            Derivative order.
        """
        self.store(real, derivative, np.zeros(self.shape, dtype=self.dtype))

    def broadcast_x(self,
                    data_nnx: NDArray[DTypeT]) -> NDArray[DTypeT]:
        """ Broadcast the x-dimensions of data_nnx to :attr:`xshape`.

        The first two dimensions of data_nnx (the nn-dimensions) are kept as
        they are. The remaining dimensions are broadcast against :attr:`xshape`
        following the numpy broadcasting rules, so trailing dimensions are aligned.

        Returns
        -------
        Read-only view of shape ``data_nnx.shape[:2] + xshape``.
        """
        nnshape = data_nnx.shape[:2]
        # Move the nn-axes to the end so that numpy broadcasting (which aligns
        # trailing dimensions) acts on the x-dimensions
        data_xnn = np.moveaxis(np.moveaxis(data_nnx, 0, -1), 0, -1)
        data_xnn = np.broadcast_to(data_xnn, self.xshape + nnshape)
        # Move the nn-axes back to the front
        return np.moveaxis(np.moveaxis(data_xnn, -1, 0), -1, 0)

    @property
    def nnellipsis(self) -> tuple[slice, slice]:
        """ Index tuple selecting all elements of the two nn-dimensions. """
        return (slice(None), slice(None))

    @property
    def xellipsis(self) -> tuple[slice, ...]:
        """ Index tuple selecting all elements of each x-dimension. """
        return tuple(len(self.xshape) * [slice(None)])

    def _safe_index(self,
                    data_nnx: NDArray[DTypeT]) -> tuple[tuple[slice, ...], NDArray[DTypeT]]:
        """ Validate data_nnx against the buffer shape and broadcast its x-dimensions.

        Shared by :meth:`safe_fill` and :meth:`safe_add`.

        Returns
        -------
        Tuple of the index selecting the leading ``data_nnx.shape[:2]`` block
        of a buffer, and data_nnx with the x-dimensions broadcast to :attr:`xshape`.

        Raises
        ------
        AssertionError
            If data_nnx has more dimensions than the buffer, nn-dimensions larger
            than :attr:`nnshape`, or x-dimensions not broadcastable to :attr:`xshape`.
        """
        assert len(data_nnx.shape) <= len(self.shape), f'{data_nnx.shape} > {self.shape}'
        assert all([dima >= dimb for dima, dimb in zip(self.nnshape, data_nnx.shape[:2])]), \
            f'{self.nnshape} < {data_nnx.shape[:2]}'
        data_nnx = self.broadcast_x(data_nnx)  # Broadcast the last dimensions
        assert data_nnx.shape[2:] == self.xshape, f'{data_nnx.shape[2:]} != {self.xshape}'
        s = tuple([slice(dim) for dim in data_nnx.shape[:2]]) + self.xellipsis
        return s, data_nnx

    def safe_fill(self,
                  real: bool,
                  derivative: int,
                  data_nnx: NDArray[DTypeT]):
        """ Write data_nnx to the corresponding buffer, if the dimensions of data_nnx
        are equal to or smaller than the buffer.

        If the nn-dimensions of data_nnx are smaller than or equal to those
        of the buffer, write to the first elements of the buffer. The
        x-dimensions of data_nnx are broadcast to :attr:`xshape`.
        Otherwise raise an error.

        Parameters
        ----------
        real
            ``True`` if real, ``False`` if imaginary.
        derivative
            Derivative order.
        data_nnx
            Data of shape ``(nnshape, xshape)``.
        """
        s, data_nnx = self._safe_index(data_nnx)
        buffer_nnx = self._get_data(real, derivative)
        buffer_nnx[s] = data_nnx

    def safe_add(self,
                 real: bool,
                 derivative: int,
                 data_nnx: NDArray[DTypeT]):
        """ Add data_nnx to the corresponding buffer, if the dimensions of data_nnx
        are equal to or smaller than the buffer.

        If the nn-dimensions of data_nnx are smaller than or equal to those
        of the buffer, add to the first elements of the buffer. The
        x-dimensions of data_nnx are broadcast to :attr:`xshape`.
        Otherwise raise an error.

        Parameters
        ----------
        real
            ``True`` if real, ``False`` if imaginary.
        derivative
            Derivative order.
        data_nnx
            Data of shape ``(nnshape, xshape)``.
        """
        s, data_nnx = self._safe_index(data_nnx)
        buffer_nnx = self._get_data(real, derivative)
        # Regarding ignore:
        # https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array/74634650#74634650
        buffer_nnx[s] += data_nnx  # type: ignore

    @property
    def derivatives_real(self) -> list[int]:
        """ Stored derivative orders of real density matrices in sorted order """
        return sorted(self._re_buffers.keys())

    @property
    def derivatives_imag(self) -> list[int]:
        """ Stored derivative orders of imaginary density matrices in sorted order """
        return sorted(self._im_buffers.keys())

    def _iter_buffers(self) -> Generator[NDArray[DTypeT], None, None]:
        """ Loop over all real and imaginary buffers in a sorted order.

        Real buffers come first (by increasing derivative order), then imaginary ones.
        """
        for derivative in self.derivatives_real:
            yield self._re_buffers[derivative]
        for derivative in self.derivatives_imag:
            yield self._im_buffers[derivative]

    def _iter_reim_derivatives(self) -> Generator[tuple[bool, int], None, None]:
        """ Loop over tuples (real, derivative) in sorted order.

        The parameter real is ``True`` for real buffers and the parameter derivative denotes the
        derivative order of the buffer. The order matches :meth:`_iter_buffers`.
        """
        for derivative in self.derivatives_real:
            yield (True, derivative)
        for derivative in self.derivatives_imag:
            yield (False, derivative)

    def ensure_contiguous_buffers(self):
        """ Make the buffers contiguous if they are not already (non-contiguous ones are copied). """
        for derivative in self.derivatives_real:
            self._re_buffers[derivative] = np.ascontiguousarray(self._re_buffers[derivative])
        for derivative in self.derivatives_imag:
            self._im_buffers[derivative] = np.ascontiguousarray(self._im_buffers[derivative])

    def send_arrays(self,
                    comm,
                    rank: int,
                    log: Logger | None = None):
        """ Send all buffers to another MPI rank.

        The buffers are sent with non-blocking sends in the order of
        :meth:`_iter_buffers`, using consecutive MPI tags starting at 987. This
        matches :meth:`recv_arrays` on the receiving rank. Nothing is done if
        ``rank`` is the rank of this process.

        Parameters
        ----------
        comm
            Communicator.
        rank
            Send to this rank.
        log
            Optional logger.
        """
        if comm.rank == rank:
            # Sending to self
            return

        if log is not None:
            log.start('send_to_root')

        requests = []
        # Keep references to the (possibly copied) contiguous arrays until all
        # non-blocking sends have completed
        sendbuffers = []
        for mpitag, buffer_nnx in enumerate(self._iter_buffers(), start=987):
            buffer_nnx = np.ascontiguousarray(buffer_nnx)
            sendbuffers.append(buffer_nnx)
            requests.append(comm.send(buffer_nnx, rank, tag=mpitag, block=False))
        comm.waitall(requests)

        if log is not None:
            log(f'Sending to root {log.elapsed("send_to_root"):.1f}s', who='Response', flush=True)

    def recv_arrays(self,
                    comm,
                    rank: int,
                    log: Logger | None = None):
        """ Receive data from another MPI rank into the stored buffers (in place).

        The buffers are received in the order of :meth:`_iter_buffers`, using
        consecutive MPI tags starting at 987, matching :meth:`send_arrays`.
        Nothing is done if ``rank`` is the rank of this process.

        NOTE(review): the stored buffers are passed directly to ``comm.receive``;
        presumably they must be contiguous (see :meth:`ensure_contiguous_buffers`).

        Parameters
        ----------
        comm
            Communicator.
        rank
            Receive from this rank.
        log
            Optional logger.
        """
        if comm.rank == rank:
            # Receiving from self
            return

        if log is not None:
            log.start('root_recv')

        requests = []
        for mpitag, buffer_nnx in enumerate(self._iter_buffers(), start=987):
            requests.append(comm.receive(buffer_nnx, rank, tag=mpitag, block=False))
        comm.waitall(requests)

        if log is not None:
            log(f'Root received {log.elapsed("root_recv"):.1f}s from {rank}', who='Response', flush=True)

    def redistribute(self,
                     target: DensityMatrixBuffer,
                     comm,
                     source_indices_r: Sequence[tuple[ArrayIndex, ...] | ArrayIndex | None],
                     target_indices_r: Sequence[tuple[ArrayIndex, ...] | ArrayIndex | None],
                     log: Logger | None = None,
                     ):
        """ Redistribute this DensityMatrixBuffer to another according to user specified options.

        The nn dimensions of the self and target buffers should be the same,
        but the x dimensions can be different. However, self needs to have the same shape on
        all ranks and target needs to have the same shape on all ranks.

        The data is communicated in a series of ``comm.alltoallv`` calls. The nn-dimensions
        are split into chunks, so that the number of elements communicated per call (summed over
        all ranks) is approximately limited by the ``REDISTRIBUTE_MAXSIZE`` setting (in bytes).

        Parameters
        ----------
        target
            Target :class:`DensityMatrixBuffer`. Must store the same real/imaginary parts
            and derivative orders, and have the same dtype, as this object.
        comm
            MPI communicator.
        source_indices_r
            List of x-indices. The length of the list must equal the communicator size.
            The x-index that is element r of the list corresponds
            to the data from the source that will be sent to rank r.
            If the x-index is None, then the rank corresponding to that element will not be
            receiving data. This argument must be the same on all ranks.
        target_indices_r
            List of x-indices. The length of the list must equal the communicator size.
            The x-index that is element r of the list corresponds
            to the data in the target that will be received from rank r.
            If the x-index is None, then the rank corresponding to that element will not be
            sending data. This argument must be the same on all ranks.
        log
            Optional logger.
        """
        # Size of each density matrix (the nn-dimensions)
        nnsize = int(np.prod(self.nnshape))
        # Convert the maximal size in bytes to the maximal number of elements
        maxsize = env.get_float('REDISTRIBUTE_MAXSIZE')
        maxtotalelems = int(np.ceil(maxsize / self.dtype.itemsize))

        assert len(source_indices_r) == comm.size, len(source_indices_r)
        assert len(target_indices_r) == comm.size, len(target_indices_r)

        # Extract source and target indices that are not None and make sure they are tuples
        source_indices_by_rank = {rank: x_indices if isinstance(x_indices, tuple) else (x_indices, )
                                  for rank, x_indices in enumerate(source_indices_r)
                                  if x_indices is not None}
        target_indices_by_rank = {rank: x_indices if isinstance(x_indices, tuple) else (x_indices, )
                                  for rank, x_indices in enumerate(target_indices_r)
                                  if x_indices is not None}
        assert len(source_indices_by_rank) > 0
        assert len(target_indices_by_rank) > 0

        # This rank sends data if some rank receives from it (it is among the target keys),
        # and receives data if some rank sends to it (it is among the source keys)
        is_sender = comm.rank in target_indices_by_rank
        is_receiver = comm.rank in source_indices_by_rank

        # Make sure that same derivatives and real/imaginary parts are stored and that dtypes are same
        assert tuple(self.derivatives_real) == tuple(target.derivatives_real)
        assert tuple(self.derivatives_imag) == tuple(target.derivatives_imag)
        assert self.dtype == target.dtype, f'{self.dtype} != {target.dtype}'

        # Get the xshapes of the data this rank sends to each receiving rank
        source_xshape_by_rank: dict[int, tuple[int, ...]] = dict()
        if is_sender:
            for buf_nnx in self._iter_buffers():
                # All buffers have the same shape; the first one is enough
                source_xshape_by_rank = {rank: buf_nnx[self.nnellipsis + x_indices].shape[2:]
                                         for rank, x_indices in source_indices_by_rank.items()}
                break

        # Get the xshapes of the targets by an alltoall operation with the sources.
        # Encoding: -2 means nothing, -1 in first field means empty tuple. At least one
        # field is needed to be able to encode the empty tuple.
        xdims = max(1, len(self.xshape), len(target.xshape))
        pad_target_xshape_r = -2 * np.ones((comm.size, xdims), dtype=int)
        pad_source_xshape_r = -2 * np.ones((comm.size, xdims), dtype=int)
        for rank, xshape in source_xshape_by_rank.items():
            pad_source_xshape_r[rank, :len(xshape)] = xshape
            if xshape == ():
                pad_source_xshape_r[rank, 0] = -1
        comm.alltoallv(pad_source_xshape_r,
                       xdims * np.ones(comm.size, dtype=int),
                       xdims * np.arange(comm.size, dtype=int),
                       pad_target_xshape_r,
                       xdims * np.ones(comm.size, dtype=int),
                       xdims * np.arange(comm.size, dtype=int))
        target_xshape_by_rank = {rank: tuple(xshape[xshape > -1])
                                 for rank, xshape in enumerate(pad_target_xshape_r) if xshape[0] > -2}
        # Check that the incoming data (shapes from the alltoall operation) fit in the
        # target x-slices supplied by the user, i.e. that the user-supplied sizes are
        # larger than or equal to the sizes from the alltoall operation
        if is_receiver:
            for buf_nnx in target._iter_buffers():
                for rank, x_indices in target_indices_by_rank.items():
                    target_xshape = target_xshape_by_rank[rank]
                    xshape = buf_nnx[self.nnellipsis + x_indices].shape[2:]
                    assert all(np.less_equal(target_xshape, xshape))
                break

        # Get the total number of density matrices that this rank sends/receives
        source_xsize_by_rank = {rank: np.prod(xshape, dtype=int)
                                for rank, xshape in source_xshape_by_rank.items()}
        target_xsize_by_rank = {rank: np.prod(xshape, dtype=int)
                                for rank, xshape in target_xshape_by_rank.items()}
        my_sourcexsize = sum(source_xsize_by_rank.values())
        my_targetxsize = sum(target_xsize_by_rank.values())
        # Get the total number of array elements to be sent across all ranks
        sizes = np.array([my_sourcexsize, my_targetxsize], dtype=int)
        comm.sum(sizes, root=-1)
        total_sourcexsize, total_targetxsize = sizes
        totalsize = max(total_sourcexsize, total_targetxsize) * nnsize

        # Split the data across the nn-dimensions since they are always the same; how many times?
        factortoolarge = totalsize / maxtotalelems
        if factortoolarge > 0:
            nnstride = int(np.ceil(nnsize / factortoolarge))
        else:
            # Nothing to communicate; no need to split
            nnstride = nnsize
        # Clamp to [1, nnsize]. A stride of at least 1 avoids division by zero and an
        # invalid range step below (e.g. for empty nn-dimensions or a zero size limit)
        nnstride = max(1, min(nnsize, nnstride))
        nsplits = int(np.ceil(nnsize / nnstride))
        if log is not None and comm.rank == 0:
            total_MiB = totalsize * self.dtype.itemsize / (1024 ** 2)
            buftotal_MiB = totalsize / max(nsplits, 1) * self.dtype.itemsize / (1024 ** 2)

            log(f'Redistribute: {len(target_indices_by_rank)} sending '
                f'and {len(source_indices_by_rank)} receiving. '
                f'Total size on all ranks ({total_MiB:.1f} MiB) '
                f'splitting in {nsplits} parts ({buftotal_MiB:.1f} MiB on all ranks)',
                who='Response', flush=True)

        # Prepare buffers for sending and receiving
        # counts - Number of elements to send to (s) or receive from (r) each rank
        # displs - Position of data to send to (s) or receive from (r) each rank
        sendbuf = np.zeros(my_sourcexsize * nnstride, dtype=self.dtype)
        recvbuf = np.zeros(my_targetxsize * nnstride, dtype=self.dtype)
        scounts_r = np.zeros(comm.size, dtype=int)
        sdispls_r = np.zeros(comm.size, dtype=int)
        rcounts_r = np.zeros(comm.size, dtype=int)
        rdispls_r = np.zeros(comm.size, dtype=int)

        displ = 0
        if is_sender and self.narrays > 0:
            # This rank has some data to send. It will send to the ranks that are among the source keys
            sendbuf_by_rank = dict()
            for destrank, xsize in source_xsize_by_rank.items():
                size = xsize * nnstride
                sendbuf_by_rank[destrank] = sendbuf[displ:displ+size]
                scounts_r[destrank] = size
                if size > 0:
                    sdispls_r[destrank] = displ
                displ += size
        displ = 0
        if is_receiver and self.narrays > 0:
            # This rank has some data to receive. It will receive from the ranks that are among the target keys
            recvbuf_by_rank = dict()
            for destrank, xsize in target_xsize_by_rank.items():
                size = xsize * nnstride
                recvbuf_by_rank[destrank] = recvbuf[displ:displ+size]
                rcounts_r[destrank] = size
                if size > 0:
                    rdispls_r[destrank] = displ
                displ += size

        # Flattened nn-dimensions for splitting
        flatslices = [slice(start, start + nnstride, 1) for start in range(0, nnsize, nnstride)]
        grid = np.mgrid[:self.nnshape[0], :self.nnshape[1]]
        flatrows, flatcols = grid[0].ravel(), grid[1].ravel()

        # Loop over real and imaginary parts and derivatives
        for (real, derivative), sendbuf_nnx, recvbuf_nnx in zip(
                self._iter_reim_derivatives(), self._iter_buffers(), target._iter_buffers()):

            # List of data to send and list of buffers where data should be received
            if is_sender:
                senddata_by_rank = {rank: sendbuf_nnx[self.nnellipsis + x_indices]
                                    for rank, x_indices in source_indices_by_rank.items()}
            if is_receiver:
                recvdata_by_rank = {rank: recvbuf_nnx[self.nnellipsis + x_indices]
                                    for rank, x_indices in target_indices_by_rank.items()}
                # The target data may be smaller than what is given by the user
                recvdata_by_rank = {rank: recvdata_by_rank[rank][
                                    self.nnellipsis + tuple([slice(dim) for dim in xshape])]
                                    for rank, xshape in target_xshape_by_rank.items()}
                for data, xshape in zip(recvdata_by_rank.values(), target_xshape_by_rank.values()):
                    assert data.shape[2:] == xshape, str(data.shape[2:]) + ' != ' + str(xshape)

            # Loop over the data in splits
            for flatslice in flatslices:
                # Slices of nn
                nnslices = (flatrows[flatslice], flatcols[flatslice])

                # Copy data to the contiguous send buffer
                if is_sender:
                    for destrank, buf in sendbuf_by_rank.items():
                        data = senddata_by_rank[destrank][nnslices].ravel()
                        buf[:len(data)] = data

                # Send the data
                comm.alltoallv(sendbuf, scounts_r, sdispls_r,
                               recvbuf, rcounts_r, rdispls_r)

                # Copy data from the contiguous receive buffer
                if is_receiver:
                    for destrank, buf in recvbuf_by_rank.items():
                        # Copy the first elements of the receive buffer to the data position
                        datashape = recvdata_by_rank[destrank][nnslices].shape
                        datalen = np.prod(datashape, dtype=int)
                        buf = buf[:datalen]
                        recvdata_by_rank[destrank][nnslices] = buf.reshape(datashape)