""" Utilities for estimating and reporting the memory needed by arrays
(module rhodent.utils.memory). """
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
import numpy as np
from numpy._typing import _DTypeLike as DTypeLike
@dataclass
class MemoryEntry:
    """ Size information for a single (possibly distributed) array.

    Parameters
    ----------
    shape
        Shape of the array on each rank.
    dtype
        Data type of the array elements.
    on_num_ranks
        Number of ranks holding an array of this shape.
    total_size
        Total number of elements on all ranks. If given, it takes precedence
        over the value computed from :attr:`shape` and :attr:`on_num_ranks`.
    """

    shape: tuple[int, ...]
    dtype: np.dtype
    on_num_ranks: int = 1
    total_size: int | None = None

    def get_total_size(self) -> int:
        """ Get the total number of elements on all ranks.

        Returns
        -------
        The explicit :attr:`total_size` if it is set. Otherwise, the number
        of elements per rank times :attr:`on_num_ranks`.
        """
        if self.total_size is not None:
            return self.total_size
        # np.prod of an empty (scalar) shape is the float 1.0, hence the cast
        return int(np.prod(self.shape)) * self.on_num_ranks
 
@dataclass
class MemoryEstimate:
    """ Hierarchical estimate of the memory needed to store arrays.

    An estimate holds named array entries (added with :meth:`add_key`) and
    named child estimates (added with :meth:`add_child`). Converting it to a
    string gives a human-readable report. :attr:`grand_total` gives the total
    number of bytes on all ranks, children included.

    Parameters
    ----------
    comment
        Optional note. Each of its lines is printed prefixed with ``Note:``.
    children
        Named child estimates.
    arrays
        Named array entries.
    """

    comment: str = ''
    children: dict[str, MemoryEstimate] = field(default_factory=dict)
    arrays: dict[str, MemoryEntry] = field(default_factory=dict)

    def __str__(self) -> str:
        # Nothing has been recorded, so no estimate can be reported
        if len(self.children) == 0 and len(self.arrays) == 0:
            return 'Unknown'
        to_MiB = 1024 ** -2
        totalstr = f'{self.grand_total * to_MiB:.1f} MiB'
        lines = []
        if self.comment != '':
            lines += ['Note: ' + line for line in self.comment.split('\n')]
            lines.append('')
        for key, entry in self.arrays.items():
            size_per_rank_MiB = (int(np.prod(entry.shape))
                                 * entry.dtype.itemsize * to_MiB)
            size_total_MiB = (entry.get_total_size()
                              * entry.dtype.itemsize * to_MiB)
            lines.append(f'{key}: {entry.shape} {entry.dtype}')
            lines.append(f'. {size_per_rank_MiB:.1f} MiB '
                         f'per rank on {entry.on_num_ranks} ranks')
            lines.append(f'. {size_total_MiB:.1f} MiB in total on all ranks')
            lines.append('')
        for name, child in self.children.items():
            # Indent the child's report below its name
            lines.append(f'{name}:')
            lines += ['    ' + line for line in str(child).split('\n')]
            lines.append('')
        # Footer: a row of dots, then the centered grand total
        lines.append(f'{"":.^24}')
        lines.append(f'{" Total on all ranks ":.^24}')
        lines.append(f'{totalstr:.^24}')
        return '\n'.join(lines)

    @property
    def grand_total(self) -> int:
        """ Grand total of bytes on all ranks, including all children. """
        total = 0
        for entry in self.arrays.values():
            total += entry.get_total_size() * entry.dtype.itemsize
        for child in self.children.values():
            total += child.grand_total
        return total

    def add_key(self,
                key: str,
                shape: tuple[int, ...] | int = (),
                dtype: DTypeLike = float,
                *,
                total_size: int | None = None,
                on_num_ranks: int = 1):
        """ Add (or replace) the array entry named :attr:`key`.

        Parameters
        ----------
        key
            Name of the array.
        shape
            Shape of the array on each rank. A single integer (Python or
            NumPy) is interpreted as a one-dimensional shape.
        dtype
            Data type of the array; anything accepted by :func:`numpy.dtype`.
        total_size
            Total number of elements on all ranks. If not given, it is
            computed from :attr:`shape` and :attr:`on_num_ranks`.
        on_num_ranks
            Number of ranks holding an array of this shape (Python or NumPy
            integer).
        """
        assert isinstance(key, str)
        # Accept NumPy integer scalars (e.g. results of NumPy arithmetic)
        # as well as Python ints for a one-dimensional shape
        if isinstance(shape, (int, np.integer)):
            shape = (shape, )
        assert all(isinstance(d, (int, np.integer)) for d in shape)
        shape = tuple(int(d) for d in shape)
        # Normalize to Python ints so that sizes and totals are Python ints
        assert isinstance(on_num_ranks, (int, np.integer))
        on_num_ranks = int(on_num_ranks)
        if isinstance(total_size, np.integer):
            total_size = int(total_size)
        dtype = np.dtype(dtype)
        self.arrays[key] = MemoryEntry(shape, dtype,
                                       total_size=total_size,
                                       on_num_ranks=on_num_ranks)

    def add_child(self,
                  name: str,
                  child: MemoryEstimate):
        """ Add (or replace) the child estimate named :attr:`name`. """
        assert isinstance(child, MemoryEstimate)
        self.children[name] = child
class HasMemoryEstimate(ABC):
    """ Interface for classes that can report a memory estimate.

    Concrete subclasses must implement :meth:`get_memory_estimate`;
    the class cannot be instantiated until they do.
    """

    @abstractmethod
    def get_memory_estimate(self) -> MemoryEstimate:
        """ Return a memory estimate for this object. """
        raise NotImplementedError