Coverage for rhodent/density_matrices/buffer.py: 93%
270 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3from typing import Generator, Generic, Sequence
4import numpy as np
5from numpy.typing import NDArray
6from numpy._typing import _DTypeLike as DTypeLike # parametrizable wrt generic
8from ..utils import DTypeT, Logger, env
9from ..typing import ArrayIndex
class DensityMatrixBuffer(Generic[DTypeT]):

    """ Buffer for the density matrix.

    Objects of this class can hold buffers for real and imaginary parts
    and various derivatives at the same time.

    Each buffer has two dimensions corresponding to (part of) the
    density matrix, and optionally additional dimensions for e.g. time.

    Parameters
    ----------
    nnshape
        Shape of the dimensions corresponding to the density matrix. Must
        have length 2.
    xshape
        Shape of the additional dimensions corresponding to, e.g., time.
        May be the empty tuple.
    dtype
        Data type of the density matrices.
    re_buffers
        Buffers for the different derivatives of the real part of the
        density matrix as a dictionary, where the key is the derivative
        order (0, 1, 2, etc.) and the value is a numpy array of shape
        ``nnshape + xshape``.
    im_buffers
        Same as ``re_buffers`` but for the imaginary part.
    """
40 def __init__(self,
41 nnshape: tuple[int, int],
42 xshape: tuple[int, ...],
43 dtype: DTypeLike[DTypeT],
44 re_buffers: dict[int, NDArray[DTypeT]] = dict(),
45 im_buffers: dict[int, NDArray[DTypeT]] = dict()):
46 assert len(nnshape) == 2
47 assert all(isinstance(dim, (int, np.integer)) and dim >= 0 for dim in nnshape)
48 assert all(isinstance(dim, (int, np.integer)) and dim >= 0 for dim in xshape)
49 assert isinstance(np.dtype(dtype), np.dtype)
50 self._nnshape = nnshape
51 self._xshape = xshape
52 self._dtype = np.dtype(dtype)
53 self._re_buffers: dict[int, NDArray[DTypeT]] = dict()
54 self._im_buffers: dict[int, NDArray[DTypeT]] = dict()
56 for derivative, buffer_nnx in re_buffers.items():
57 self.store(True, derivative, buffer_nnx)
59 for derivative, buffer_nnx in im_buffers.items():
60 self.store(False, derivative, buffer_nnx)
62 @property
63 def real(self) -> NDArray[DTypeT]:
64 """ Buffer of shape nnshape + xshape corresponding to real data. """
65 return self._get_real(0)
67 @property
68 def real1(self) -> NDArray[DTypeT]:
69 """ Buffer of shape nnshape + xshape corresponding to real part of first derivative. """
70 return self._get_real(1)
72 @property
73 def real2(self) -> NDArray[DTypeT]:
74 """ Buffer of shape nnshape + xshape corresponding to real part of second derivative. """
75 return self._get_real(2)
77 @property
78 def imag(self) -> NDArray[DTypeT]:
79 """ Buffer of shape nnshape + xshape corresponding to imaginary data. """
80 return self._get_imag(0)
82 @property
83 def imag1(self) -> NDArray[DTypeT]:
84 """ Buffer of shape nnshape + xshape corresponding to imaginary part of first derivative. """
85 return self._get_imag(1)
87 @property
88 def imag2(self) -> NDArray[DTypeT]:
89 """ Buffer of shape nnshape + xshape corresponding to imaginary part of second derivative. """
90 return self._get_imag(2)
92 def _get_real(self,
93 derivative: int) -> NDArray[DTypeT]:
94 """ Fetch density matrix buffer for real data.
96 Parameters
97 ----------
98 derivative
99 Derivative order.
100 """
101 return self._re_buffers[derivative]
103 def _get_imag(self,
104 derivative: int) -> NDArray[DTypeT]:
105 """ Fetch density matrix buffer for imaginary data.
107 Parameters
108 ----------
109 derivative
110 Derivative order.
111 """
112 return self._im_buffers[derivative]
114 def _get_data(self,
115 real: bool,
116 derivative: int) -> NDArray[DTypeT]:
117 """ Fetch density matrix buffer.
119 Parameters
120 ----------
121 real
122 ``True`` if real, ``False`` if imaginary.
123 derivative
124 Derivative order.
125 """
126 return self._get_real(derivative) if real else self._get_imag(derivative)
128 def copy(self) -> DensityMatrixBuffer:
129 """ Return a deep copy of this object (buffers are copied). """
130 re_buffers = {derivative: np.array(buffer_nnx)
131 for derivative, buffer_nnx in self._re_buffers.items()}
132 im_buffers = {derivative: np.array(buffer_nnx)
133 for derivative, buffer_nnx in self._im_buffers.items()}
135 dm_buffer = DensityMatrixBuffer(self.nnshape, self.xshape,
136 dtype=self.dtype,
137 re_buffers=re_buffers,
138 im_buffers=im_buffers)
139 return dm_buffer
141 def new(self) -> DensityMatrixBuffer:
142 """ Return a new buffer with the same shape. """
143 dm_buffer = DensityMatrixBuffer(self.nnshape, self.xshape,
144 dtype=self.dtype)
145 return dm_buffer
147 def __getitem__(self,
148 value) -> DensityMatrixBuffer:
149 """ Index the buffers and return a new DensityMatrixBuffer
150 with buffers that are views of the buffers of this DensityMatrixBuffer.
151 """
152 if len(self._im_buffers) == 0 and len(self._re_buffers) == 0:
153 # This case needs some special handing to get the dimension of
154 # the output
155 raise NotImplementedError
157 # Wrap in a tuple
158 if not isinstance(value, tuple):
159 value = (value, )
160 s = (slice(None), slice(None)) + value
161 re_buffers = {derivative: buffer_nnx[s]
162 for derivative, buffer_nnx in self._re_buffers.items()}
163 im_buffers = {derivative: buffer_nnx[s]
164 for derivative, buffer_nnx in self._im_buffers.items()}
166 # Ugly hack. Get any of the buffers
167 xshape = (list(re_buffers.values()) + list(im_buffers.values()))[0].shape[2:]
168 return DensityMatrixBuffer(self.nnshape, xshape, dtype=self.dtype,
169 re_buffers=re_buffers, im_buffers=im_buffers)
171 @property
172 def narrays(self) -> int:
173 """ Number of arrays stored in this object. """
174 return len(self.derivatives_imag) + len(self.derivatives_real)
176 @property
177 def nnshape(self) -> tuple[int, int]:
178 """ Shape of the part of the density matrix. """
179 return self._nnshape
181 @property
182 def xshape(self) -> tuple[int, ...]:
183 """ Shape of the additional dimension of the buffers. """
184 return self._xshape
186 @property
187 def shape(self) -> tuple[int, ...]:
188 """ Shape of the buffers. """
189 return self.nnshape + self.xshape
191 @property
192 def dtype(self) -> np.dtype[DTypeT]:
193 """ Dtype of the buffers. """
194 return self._dtype
196 def store(self,
197 real: bool,
198 derivative: int,
199 buffer_nnx: NDArray[DTypeT]):
200 """ Store buffer for part of density matrix.
202 Parameters
203 ----------
204 real
205 ``True`` if real, ``False`` if imaginary.
206 derivative
207 Derivative order.
208 buffer_nnx
209 Buffer of shape ``(nnshape, xshape)``.
210 """
211 assert isinstance(derivative, int) and derivative >= 0, derivative
212 assert isinstance(buffer_nnx, np.ndarray)
213 assert buffer_nnx.shape == self.shape, f'{buffer_nnx.shape} != {self.shape}'
214 assert buffer_nnx.dtype == self.dtype
215 if real:
216 self._re_buffers[derivative] = buffer_nnx
217 else:
218 self._im_buffers[derivative] = buffer_nnx
220 def zero_buffers(self,
221 real: bool,
222 imag: bool,
223 derivative_order_s: list[int]):
224 """ Initialize many buffers at once to and write zeros.
226 Parameters
227 ----------
228 real
229 Initialize buffers for real parts.
230 imag
231 Initialize buffers for imaginary parts.
232 derivative_order_s
233 Initialize buffers for these derivative orders.
234 """
235 for derivative in derivative_order_s:
236 if real:
237 self.zeros(True, derivative)
238 if imag:
239 self.zeros(False, derivative)
241 def zeros(self,
242 real: bool,
243 derivative: int):
244 """ Initialize buffer with zeros for part of density matrix.
246 Parameters
247 ----------
248 real
249 ``True`` if real, ``False`` if imaginary.
250 derivative
251 Derivative order.
252 """
253 self.store(real, derivative, np.zeros(self.shape, dtype=self.dtype))
255 def broadcast_x(self,
256 data_nnx: NDArray[DTypeT]) -> NDArray[DTypeT]:
257 """ Broadcast the x dimensions of data_nnx. """
258 nnshape = data_nnx.shape[:2]
259 data_xnn = np.moveaxis(np.moveaxis(data_nnx, 0, -1), 0, -1)
260 data_xnn = np.broadcast_to(data_xnn, self.xshape + nnshape)
261 data_nnx = np.moveaxis(np.moveaxis(data_xnn, -1, 0), -1, 0)
262 return data_nnx
264 @property
265 def nnellipsis(self) -> tuple[slice, slice]:
266 return (slice(None), slice(None))
268 @property
269 def xellipsis(self) -> tuple[slice, ...]:
270 return tuple(len(self.xshape) * [slice(None)])
272 def safe_fill(self,
273 real: bool,
274 derivative: int,
275 data_nnx: NDArray[DTypeT]):
276 """ Write data_nnx to the corrsponding buffer, if the dimensions of data_nnx
277 are equal to or smaller than the buffer.
279 If the dimensions of data_nnx are smaller than or equal to the dimensions
280 of the buffer, write to the first elements of the buffer.
281 Otherwise raise and error.
283 Parameters
284 ----------
285 real
286 ``True`` if real, ``False`` if imaginary.
287 derivative
288 Derivative order.
289 buffer_nnx
290 Data of shape ``(nnshape, xshape)``.
291 """
292 assert len(data_nnx.shape) <= len(self.shape), f'{data_nnx.shape} > {self.shape}'
293 assert all([dima >= dimb for dima, dimb in zip(self.nnshape, data_nnx.shape[:2])]), \
294 f'{self.nnshape} < {data_nnx.shape[:2]}'
295 data_nnx = self.broadcast_x(data_nnx) # Broadcast the last dimensions
296 assert data_nnx.shape[2:] == self.xshape, f'{data_nnx.shape[2:]} != {self.xshape}'
297 s = tuple([slice(dim) for dim in data_nnx.shape[:2]]) + self.xellipsis
298 buffer_nnx = self._get_data(real, derivative)
299 buffer_nnx[s] = data_nnx
301 def safe_add(self,
302 real: bool,
303 derivative: int,
304 data_nnx: NDArray[DTypeT]):
305 """ Add data_nnx to the corrsponding buffer, if the dimensions of data_nnx
306 are equal to or smaller than the buffer
308 If the dimensions of data_nnx are smaller than or equal to the dimensions
309 of the buffer, add to the first elements of the buffer.
310 Otherwise raise and error.
312 Parameters
313 ----------
314 real
315 ``True`` if real, ``False`` if imaginary.
316 derivative
317 Derivative order.
318 buffer_nnx
319 Data of shape ``(nnshape, xshape)``.
320 """
321 assert len(data_nnx.shape) <= len(self.shape), f'{data_nnx.shape} > {self.shape}'
322 assert all([dima >= dimb for dima, dimb in zip(self.nnshape, data_nnx.shape[:2])]), \
323 f'{self.nnshape} < {data_nnx.shape[:2]}'
324 data_nnx = self.broadcast_x(data_nnx) # Broadcast the last dimensions
325 assert data_nnx.shape[2:] == self.xshape, f'{data_nnx.shape[2:]} != {self.xshape}'
326 s = tuple([slice(dim) for dim in data_nnx.shape[:2]]) + self.xellipsis
327 buffer_nnx = self._get_data(real, derivative)
328 # Regarding ignore:
329 # https://stackoverflow.com/questions/74633074/how-to-type-hint-a-generic-numpy-array/74634650#74634650
330 buffer_nnx[s] += data_nnx # type: ignore
332 @property
333 def derivatives_real(self) -> list[int]:
334 """ Stored derivative order of real density matrices in sorted order """
335 return list(sorted(self._re_buffers.keys()))
337 @property
338 def derivatives_imag(self) -> list[int]:
339 """ Stored derivative order of real density matrices in sorted order """
340 return list(sorted(self._im_buffers.keys()))
342 def _iter_buffers(self) -> Generator[NDArray[DTypeT], None, None]:
343 """ Loop over all real imaginary buffers in a sorted order """
344 for derivative in self.derivatives_real:
345 yield self._re_buffers[derivative]
346 for derivative in self.derivatives_imag:
347 yield self._im_buffers[derivative]
349 def _iter_reim_derivatives(self) -> Generator[tuple[bool, int], None, None]:
350 """ Loop over tuples (real, derivative) in sorted order.
352 The parameter real is ``True`` for real buffers and the parameter derivative denotes the
353 derivative order of the buffer.
354 """
355 for derivative in self.derivatives_real:
356 yield (True, derivative)
357 for derivative in self.derivatives_imag:
358 yield (False, derivative)
360 def ensure_contiguous_buffers(self):
361 """ Make the buffers contiguous if they are not already. """
362 for derivative in self.derivatives_real:
363 self._re_buffers[derivative] = np.ascontiguousarray(self._re_buffers[derivative])
364 for derivative in self.derivatives_imag:
365 self._im_buffers[derivative] = np.ascontiguousarray(self._im_buffers[derivative])
367 def send_arrays(self,
368 comm,
369 rank: int,
370 log: Logger | None = None):
371 """ Send data to another MPI rank.
373 Parameters
374 ----------
375 comm
376 Communicator.
377 rank
378 Send to this rank.
379 log
380 Optional logger.
381 """
382 if comm.rank == rank:
383 # Sending to send
384 return
386 if log is not None:
387 log.start('send_to_root')
389 requests = []
390 for mpitag, buffer_nnx in enumerate(self._iter_buffers(), start=987):
391 buffer_nnx = np.ascontiguousarray(buffer_nnx)
392 requests.append(comm.send(buffer_nnx, 0, tag=mpitag, block=False))
393 comm.waitall(requests)
395 if log is not None:
396 log(f'Sending to root {log.elapsed("send_to_root"):.1f}s', who='Response', flush=True)
398 def recv_arrays(self,
399 comm,
400 rank: int,
401 log: Logger | None = None):
402 """ Receive data to another MPI rank.
404 Parameters
405 ----------
406 comm
407 Communicator.
408 rank
409 Send to this rank.
410 log
411 Optional logger.
412 """
413 if comm.rank == rank:
414 # Receiving from self
415 return
417 if log is not None:
418 log.start('root_recv')
420 requests = []
421 for mpitag, buffer_nnx in enumerate(self._iter_buffers(), start=987):
422 requests.append(comm.receive(buffer_nnx, rank, tag=mpitag, block=False))
423 comm.waitall(requests)
425 if log is not None:
426 log(f'Root received {log.elapsed("root_recv"):.1f}s from {rank}', who='Response', flush=True)
    def redistribute(self,
                     target: DensityMatrixBuffer,
                     comm,
                     source_indices_r: Sequence[tuple[ArrayIndex, ...] | ArrayIndex | None],
                     target_indices_r: Sequence[tuple[ArrayIndex, ...] | ArrayIndex | None],
                     log: Logger | None = None,
                     ):
        """ Redistribute this DensityMatrixBuffer to another according to user specified options.

        The nn dimensions of the self and target buffers should be the same,
        but the x dimensions can be different. However, self needs to have the same shape on all
        ranks and target needs to have the same shape on all ranks.

        Data is exchanged with ``comm.alltoallv`` and sizes are summed with ``comm.sum``;
        presumably every rank of ``comm`` has to call this method (collective operation) —
        TODO confirm. To bound the size of the communication buffers, the nn dimensions are
        split into chunks so that the data exchanged per ``alltoallv`` call (summed over
        all ranks) is roughly limited by the setting ``REDISTRIBUTE_MAXSIZE``.

        Parameters
        ----------
        target
            Target :class:`DensityMatrixBuffer`. Must store the same real/imaginary
            derivative orders and have the same dtype as this object (asserted).
        comm
            MPI communicator.
        source_indices_r
            List of x-indices. The length of the list must equal the communicator size.
            The x-index that is element r of the list corresponds
            to the data from the source that will be sent to rank r.
            If the x-index is None, then the rank corresponding to that element will not be
            receiving data. This argument must be the same on all ranks.
        target_indices_r
            List of x-indices. The length of the list must equal the communicator size.
            The x-index that is element r of the list corresponds
            to the data in the target that will be received from rank r.
            If the x-index is None, then the rank corresponding to that element will not be
            sending data. This argument must be the same on all ranks.
        log
            Optional logger.
        """
        # Size of each density matrix (the nn-dimensions)
        nnsize = int(np.prod(self.nnshape))
        # Convert maxsize to maximum number of elements
        # (maxsize is divided by the itemsize, so it is presumably a size in bytes)
        maxsize = env.get_float('REDISTRIBUTE_MAXSIZE')
        maxtotalelems = int(np.ceil(maxsize / self.dtype.itemsize))

        assert len(source_indices_r) == comm.size, len(source_indices_r)
        assert len(target_indices_r) == comm.size, len(target_indices_r)

        # Extract source and target indices that are not None and make sure they are tuples
        # Keys of source_indices_by_rank: the ranks that receive data
        # Keys of target_indices_by_rank: the ranks that send data
        source_indices_by_rank = {rank: x_indices if isinstance(x_indices, tuple) else (x_indices, )
                                  for rank, x_indices in enumerate(source_indices_r)
                                  if x_indices is not None}
        target_indices_by_rank = {rank: x_indices if isinstance(x_indices, tuple) else (x_indices, )
                                  for rank, x_indices in enumerate(target_indices_r)
                                  if x_indices is not None}
        assert len(source_indices_by_rank) > 0
        assert len(target_indices_by_rank) > 0

        # Make sure that same derivatives and real/imaginary parts are stored and that dtypes are same
        assert tuple(self.derivatives_real) == tuple(target.derivatives_real)
        assert tuple(self.derivatives_imag) == tuple(target.derivatives_imag)
        assert self.dtype == target.dtype, f'{self.dtype} != {target.dtype}'

        # Get the xshapes of all sources
        # Only sending ranks compute the x shapes of the data they send to each receiving rank.
        # All buffers have the same shape, so only the first buffer is inspected (break);
        # if no buffers are stored, the dictionary stays empty.
        source_xshape_by_rank: dict[int, tuple[int, ...]] = dict()
        if comm.rank in target_indices_by_rank.keys():
            for buf_nnx in self._iter_buffers():
                source_xshape_by_rank = {rank: buf_nnx[self.nnellipsis + x_indices].shape[2:]
                                         for rank, x_indices in source_indices_by_rank.items()}
                break

        # Get the xshapes of the targets by an alltoall operation with the sources
        # -2 means nothing, -1 in first field means empty tuple
        # Row r of pad_source_xshape_r: x shape of the data sent to rank r.
        # After the exchange, row r of pad_target_xshape_r: x shape of the data received from rank r.
        xdims = max(len(self.xshape), len(target.xshape))
        pad_target_xshape_r = -2 * np.ones((comm.size, xdims), dtype=int)
        pad_source_xshape_r = -2 * np.ones((comm.size, xdims), dtype=int)
        for rank, xshape in source_xshape_by_rank.items():
            pad_source_xshape_r[rank, :len(xshape)] = xshape
            if xshape == ():
                pad_source_xshape_r[rank, 0] = -1
        comm.alltoallv(pad_source_xshape_r,
                       xdims * np.ones(comm.size, dtype=int),
                       xdims * np.arange(comm.size, dtype=int),
                       pad_target_xshape_r,
                       xdims * np.ones(comm.size, dtype=int),
                       xdims * np.arange(comm.size, dtype=int))
        target_xshape_by_rank = {rank: tuple(xshape[xshape > -1])
                                 for rank, xshape in enumerate(pad_target_xshape_r) if xshape[0] > -2}
        # Check that target sizes supplied by the user are shorter than
        # or equal to the sizes from the alltoall operation
        # (only the first target buffer is inspected, see break)
        if comm.rank in source_indices_by_rank.keys():
            for buf_nnx in target._iter_buffers():
                for rank, x_indices in target_indices_by_rank.items():
                    target_xshape = target_xshape_by_rank[rank]
                    xshape = buf_nnx[self.nnellipsis + x_indices].shape[2:]
                    assert all(np.less_equal(target_xshape, xshape))
                break

        # Get the total number of density matrices that this rank sends/receives
        source_xsize_by_rank = {rank: np.prod(xshape, dtype=int)
                                for rank, xshape in source_xshape_by_rank.items()}
        target_xsize_by_rank = {rank: np.prod(xshape, dtype=int)
                                for rank, xshape in target_xshape_by_rank.items()}
        my_sourcexsize = sum(source_xsize_by_rank.values())
        my_targetxsize = sum(target_xsize_by_rank.values())
        # Get the total number of array elements to be sent across all ranks
        # (root=-1 presumably makes the sum available on all ranks — TODO confirm)
        sizes = np.array([my_sourcexsize, my_targetxsize], dtype=int)
        comm.sum(sizes, root=-1)
        total_sourcexsize, total_targetxsize = sizes
        totalsize = max(total_sourcexsize, total_targetxsize) * nnsize

        # Split the data across the nn-dimensions since they are always the same; how many times?
        # nnstride is the number of density-matrix elements (out of nnsize) exchanged per split.
        # NOTE(review): if totalsize is 0 (nothing to exchange) factortoolarge is 0.0, and if
        # maxtotalelems is 0 the division below fails; both raise ZeroDivisionError — confirm
        # callers never hit these cases.
        factortoolarge = totalsize / maxtotalelems
        nnstride = int(np.ceil(nnsize / factortoolarge))
        nnstride = min(nnsize, nnstride)
        nsplits = int(np.ceil(nnsize / nnstride))
        if log is not None and comm.rank == 0:
            total_MiB = totalsize * self.dtype.itemsize / (1024 ** 2)
            buftotal_MiB = totalsize / nsplits * self.dtype.itemsize / (1024 ** 2)

            log(f'Redistribute: {len(target_indices_by_rank)} sending '
                f'and {len(source_indices_by_rank)} receiving. '
                f'Total size on all ranks ({total_MiB:.1f} MiB) '
                f'splitting in {nsplits} parts ({buftotal_MiB:.1f} MiB on all ranks)',
                who='Response', flush=True)

        # Prepare buffers for sending and receiving
        # counts - Number of elements to send to (s) or receive from (r) each rank
        # displs - Position of data to send to (s) or receive from (r) each rank
        # Both buffers hold one split (nnstride density-matrix elements) for every x element.
        sendbuf = np.zeros(my_sourcexsize * nnstride, dtype=self.dtype)
        recvbuf = np.zeros(my_targetxsize * nnstride, dtype=self.dtype)
        scounts_r = np.zeros(comm.size, dtype=int)
        sdispls_r = np.zeros(comm.size, dtype=int)
        rcounts_r = np.zeros(comm.size, dtype=int)
        rdispls_r = np.zeros(comm.size, dtype=int)

        displ = 0
        if comm.rank in target_indices_by_rank.keys():
            # This rank has some data to send. It will send to the ranks that are among the source keys
            # Each destination rank gets its own consecutive slice (view) of sendbuf.
            # The outer loop body runs at most once (break), and not at all without stored buffers.
            sendbuf_by_rank = dict()
            for buf_nnx in self._iter_buffers():
                for destrank, xsize in source_xsize_by_rank.items():
                    size = xsize * nnstride
                    sendbuf_by_rank[destrank] = sendbuf[displ:displ+size]
                    scounts_r[destrank] = size
                    if size > 0:
                        sdispls_r[destrank] = displ
                    displ += size
                break
        displ = 0
        if comm.rank in source_indices_by_rank.keys():
            # This rank has some data to receive. It will receive from the ranks that are among the target keys
            # Each source rank gets its own consecutive slice (view) of recvbuf (same layout as above).
            recvbuf_by_rank = dict()
            for buf_nnx in self._iter_buffers():
                for destrank, xsize in target_xsize_by_rank.items():
                    size = xsize * nnstride
                    recvbuf_by_rank[destrank] = recvbuf[displ:displ+size]
                    rcounts_r[destrank] = size
                    if size > 0:
                        rdispls_r[destrank] = displ
                    displ += size
                break

        # Flattened nn-dimensions for splitting
        # grid[0] / grid[1] hold the row / column index of every density-matrix element
        flatslices = [slice(start, start + nnstride, 1) for start in range(0, nnsize, nnstride)]
        grid = np.mgrid[:self.nnshape[0], :self.nnshape[1]]

        # Loop over real and imaginary parts and derivatives
        # (self and target store the same ones, asserted above, so zip pairs matching buffers;
        # the (real, derivative) labels are not used below)
        for (real, derivative), sendbuf_nnx, recvbuf_nnx in zip(
                self._iter_reim_derivatives(), self._iter_buffers(), target._iter_buffers()):

            # List of data to send and list of buffers where data should be received
            if comm.rank in target_indices_by_rank.keys():
                senddata_by_rank = {rank: sendbuf_nnx[self.nnellipsis + x_indices]
                                    for rank, x_indices in source_indices_by_rank.items()}
            if comm.rank in source_indices_by_rank.keys():
                # NOTE(review): the received data is written into these arrays below. That only
                # reaches the target buffer if they are views, i.e. if x_indices use basic
                # indexing; advanced (array) indices would give copies — confirm callers.
                recvdata_by_rank = {rank: recvbuf_nnx[self.nnellipsis + x_indices]
                                    for rank, x_indices in target_indices_by_rank.items()}
                # The target data may be smaller than what is given by the user
                recvdata_by_rank = {rank: recvdata_by_rank[rank][
                    self.nnellipsis + tuple([slice(dim) for dim in xshape])]
                    for rank, xshape in target_xshape_by_rank.items()}
                for data, xshape in zip(recvdata_by_rank.values(), target_xshape_by_rank.values()):
                    assert data.shape[2:] == xshape, str(data.shape[2:]) + ' != ' + str(xshape)

            # Loop over the data in splits
            for flatslice in flatslices:
                # Slices of nn
                # (integer index arrays; indexing with them gives shape (elements in split, *xshape))
                nnslices = (grid[0].ravel()[flatslice], grid[1].ravel()[flatslice])

                # Copy data to the contiguous send buffer
                # (the last split may be shorter; the stale tail of the buffer is not read)
                if comm.rank in target_indices_by_rank.keys():
                    for destrank, buf in sendbuf_by_rank.items():
                        data = senddata_by_rank[destrank][nnslices].ravel()
                        buf[:len(data)] = data

                # Send the data
                comm.alltoallv(sendbuf, scounts_r, sdispls_r,
                               recvbuf, rcounts_r, rdispls_r)

                # Copy data from the contiguous receive buffer
                if comm.rank in source_indices_by_rank.keys():
                    for destrank, buf in recvbuf_by_rank.items():
                        # Copy the first elements of the receive buffer to the data position
                        datashape = recvdata_by_rank[destrank][nnslices].shape
                        datalen = np.prod(datashape, dtype=int)
                        buf = buf[:datalen]
                        recvdata_by_rank[destrank][nnslices] = buf.reshape(datashape)