Coverage for rhodent/utils/memory.py: 97%
68 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
from __future__ import annotations

import math
from abc import ABC, abstractmethod
from dataclasses import dataclass, field

import numpy as np
from numpy._typing import _DTypeLike as DTypeLike
@dataclass
class MemoryEntry:
    """ Size description of a single array in a memory estimate. """

    # Shape of the array; its product is reported as the per-rank element
    # count and is multiplied by ``on_num_ranks`` to get the total.
    shape: tuple[int, ...]
    # Data type of the array elements.
    dtype: np.dtype
    # Number of ranks holding an array of this shape.
    on_num_ranks: int = 1
    # Optional explicit total number of elements on all ranks. When given,
    # it takes precedence over ``prod(shape) * on_num_ranks``.
    total_size: int | None = None

    def get_total_size(self) -> int:
        """ Get the total number of elements on all ranks.

        Returns ``total_size`` if it was given explicitly, otherwise the
        product of the dimensions in ``shape`` times ``on_num_ranks``.
        """
        if self.total_size is not None:
            return self.total_size
        # Multiply as Python ints: np.prod works in a fixed-width integer
        # type and silently wraps around for very large shapes.
        num_elements = math.prod(int(dim) for dim in self.shape)
        return num_elements * self.on_num_ranks
@dataclass
class MemoryEstimate:
    """ Hierarchical estimate of memory usage.

    An estimate holds named array entries (``arrays``) and named nested
    estimates (``children``). ``str()`` renders a human-readable report and
    ``grand_total`` gives the number of bytes, including all children.
    """

    # Free-text note; each line is printed with a 'Note: ' prefix at the top
    # of the report.
    comment: str = ''
    # Nested estimates by name; they are included recursively in the report
    # and in the grand total.
    children: dict[str, MemoryEstimate] = field(default_factory=dict)
    # Array entries by name.
    arrays: dict[str, MemoryEntry] = field(default_factory=dict)

    def __str__(self) -> str:
        """ Render the estimate as a multi-line report with sizes in MiB.

        Returns 'Unknown' when the estimate has neither arrays nor children.
        """
        if not self.children and not self.arrays:
            return 'Unknown'
        to_MiB = 1024 ** -2
        totalstr = f'{self.grand_total * to_MiB:.1f} MiB'

        lines = []
        if self.comment:
            lines += ['Note: ' + line for line in self.comment.split('\n')]
            lines.append('')

        for key, entry in self.arrays.items():
            itemsize = entry.dtype.itemsize
            # Multiply as Python ints: np.prod works in a fixed-width integer
            # type and silently wraps around for very large shapes.
            num_per_rank = math.prod(int(dim) for dim in entry.shape)
            size_per_rank_MiB = num_per_rank * itemsize * to_MiB
            size_total_MiB = entry.get_total_size() * itemsize * to_MiB

            lines.append(f'{key}: {entry.shape} {entry.dtype}')
            lines.append(f'. {size_per_rank_MiB:.1f} MiB '
                         f'per rank on {entry.on_num_ranks} ranks')
            lines.append(f'. {size_total_MiB:.1f} MiB in total on all ranks')
            lines.append('')

        for name, child in self.children.items():
            lines.append(f'{name}:')
            # NOTE(review): the width of this indent (and of the '. ' prefixes
            # above) is copied from a text rendering that may have collapsed
            # repeated spaces -- confirm against the repository.
            lines += [' ' + line for line in str(child).split('\n')]
            lines.append('')

        # Footer: the grand total, centred in a 24 character dotted banner.
        lines.append(f'{"":.^24}')
        lines.append(f'{" Total on all ranks ":.^24}')
        lines.append(f'{totalstr:.^24}')

        return '\n'.join(lines)

    @property
    def grand_total(self) -> int:
        """ Grand total of bytes.

        Sum over all array entries of total element count times item size,
        plus the grand totals of all children.
        """
        own_bytes = sum(entry.get_total_size() * entry.dtype.itemsize
                        for entry in self.arrays.values())
        child_bytes = sum(child.grand_total
                          for child in self.children.values())
        return own_bytes + child_bytes

    def add_key(self,
                key: str,
                shape: tuple[int, ...] | int = (),
                dtype: DTypeLike = float,
                *,
                total_size: int | None = None,
                on_num_ranks: int = 1) -> None:
        """ Add (or replace) an array entry.

        Parameters
        ----------
        key
            Name of the entry.
        shape
            Shape of the array; a single integer is treated as a
            one-dimensional shape. NumPy integers are accepted and
            converted to Python ints.
        dtype
            Anything accepted by ``np.dtype``.
        total_size
            Optional explicit total number of elements on all ranks.
        on_num_ranks
            Number of ranks holding an array of this shape. NumPy integers
            are accepted and converted to Python ints.

        Raises
        ------
        AssertionError
            If ``key`` is not a string, or ``shape`` / ``on_num_ranks``
            are not integers.
        """
        assert isinstance(key, str)
        # Accept NumPy integer scalars here as well, consistent with the
        # check on the individual dimensions below.
        if isinstance(shape, (int, np.integer)):
            shape = (shape, )
        assert all(isinstance(d, (int, np.integer)) for d in shape)
        shape = tuple(int(d) for d in shape)
        assert isinstance(on_num_ranks, (int, np.integer))
        on_num_ranks = int(on_num_ranks)
        dtype = np.dtype(dtype)
        self.arrays[key] = MemoryEntry(shape, dtype,
                                       total_size=total_size,
                                       on_num_ranks=on_num_ranks)

    def add_child(self,
                  name: str,
                  child: MemoryEstimate) -> None:
        """ Add (or replace) a nested estimate under the given name.

        Raises ``AssertionError`` if ``child`` is not a ``MemoryEstimate``.
        """
        assert isinstance(child, MemoryEstimate)
        self.children[name] = child
class HasMemoryEstimate(ABC):
    """ Abstract interface for objects that can provide a memory estimate.

    Subclasses must implement ``get_memory_estimate``.
    """

    @abstractmethod
    def get_memory_estimate(self) -> MemoryEstimate:
        """ Return the memory estimate of this object. """
        raise NotImplementedError