Source code for rhodent.utils.memory

from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import dataclass, field

import numpy as np
from numpy._typing import _DTypeLike as DTypeLike


[docs] @dataclass class MemoryEntry: shape: tuple[int, ...] dtype: np.dtype on_num_ranks: int = 1 total_size: int | None = None
[docs] def get_total_size(self) -> int: """ Get the total number of elements on all ranks. """ if self.total_size is not None: return self.total_size return int(np.prod(self.shape)) * self.on_num_ranks
[docs] @dataclass class MemoryEstimate: comment: str = '' children: dict[str, MemoryEstimate] = field(default_factory=dict) arrays: dict[str, MemoryEntry] = field(default_factory=dict) def __str__(self) -> str: if len(self.children) == 0 and len(self.arrays) == 0: return 'Unknown' to_MiB = 1024 ** -2 totalstr = f'{self.grand_total * to_MiB:.1f} MiB' lines = [] if self.comment != '': lines += ['Note: ' + line for line in self.comment.split('\n')] lines.append('') for key, entry in self.arrays.items(): size_per_rank_MiB = int(np.prod(entry.shape)) * entry.dtype.itemsize * to_MiB size_total_MiB = entry.get_total_size() * entry.dtype.itemsize * to_MiB lines.append(f'{key}: {entry.shape} {entry.dtype}') lines.append(f'. {size_per_rank_MiB:.1f} MiB ' f'per rank on {entry.on_num_ranks} ranks') lines.append(f'. {size_total_MiB:.1f} MiB in total on all ranks') lines.append('') for name, child in self.children.items(): lines.append(f'{name}:') lines += [' ' + line for line in str(child).split('\n')] lines.append('') lines.append(f'{"":.^24}') lines.append(f'{" Total on all ranks ":.^24}') lines.append(f'{totalstr:.^24}') return '\n'.join(lines) @property def grand_total(self) -> int: """ Grand total of bytes. """ total = 0 for entry in self.arrays.values(): total += entry.get_total_size() * entry.dtype.itemsize for child in self.children.values(): total += child.grand_total return total def add_key(self, key: str, shape: tuple[int, ...] | int = (), dtype: DTypeLike = float, *, total_size: int | None = None, on_num_ranks: int = 1): assert isinstance(key, str) if isinstance(shape, int): shape = (shape, ) assert all([isinstance(d, (int, np.integer)) for d in shape]) shape = tuple(int(d) for d in shape) assert isinstance(on_num_ranks, int) dtype = np.dtype(dtype) self.arrays[key] = MemoryEntry(shape, dtype, total_size=total_size, on_num_ranks=on_num_ranks) def add_child(self, name: str, child: MemoryEstimate): assert isinstance(child, MemoryEstimate) self.children[name] = child
[docs] class HasMemoryEstimate(ABC): """ Classes inheriting from this class are able to provide a memory estimate """ @abstractmethod def get_memory_estimate(self) -> MemoryEstimate: raise NotImplementedError