Coverage for rhodent/utils/memory.py: 97%

68 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3from abc import ABC, abstractmethod 

4from dataclasses import dataclass, field 

5 

6import numpy as np 

7from numpy._typing import _DTypeLike as DTypeLike 

8 

9 

@dataclass
class MemoryEntry:
    """Shape and dtype of one array, possibly held on several ranks.

    When ``total_size`` is set it is used as the element count over all
    ranks; otherwise that count is derived from ``shape`` and
    ``on_num_ranks``.
    """

    # Shape of the array held on each rank.
    shape: tuple[int, ...]
    # Element type; its itemsize converts element counts into bytes.
    dtype: np.dtype
    # Number of ranks that each hold an array of ``shape``.
    on_num_ranks: int = 1
    # Explicit element count summed over all ranks, if known.
    total_size: int | None = None

    def get_total_size(self) -> int:
        """Return the number of elements summed over all ranks.

        Returns ``total_size`` unchanged when it is set; otherwise returns
        the per-rank element count multiplied by ``on_num_ranks``.
        """
        override = self.total_size
        if override is None:
            elements_per_rank = int(np.prod(self.shape))
            return elements_per_rank * self.on_num_ranks
        return override

23 

24 

@dataclass
class MemoryEstimate:
    """Hierarchical estimate of the memory needed by an object.

    Holds named arrays (``arrays``) and named sub-estimates (``children``).
    :attr:`grand_total` sums the bytes of all arrays, recursing into the
    children, and ``str()`` renders the whole tree as a report in MiB.
    """

    # Free-text note; each of its lines is printed prefixed with 'Note: '.
    comment: str = ''
    # Nested estimates, printed under their name and included in the total.
    children: dict[str, MemoryEstimate] = field(default_factory=dict)
    # Arrays owned directly by this estimate, keyed by name.
    arrays: dict[str, MemoryEntry] = field(default_factory=dict)

    def __str__(self) -> str:
        """Return a multi-line, human-readable report of the estimate.

        Returns ``'Unknown'`` when there are neither arrays nor children
        (``comment`` is then not printed). Otherwise lists the comment, each
        array's shape, dtype, per-rank and all-rank size in MiB, then each
        child's report indented by one space, and finally the grand total.
        """
        if not self.children and not self.arrays:
            return 'Unknown'
        to_MiB = 1024 ** -2
        totalstr = f'{self.grand_total * to_MiB:.1f} MiB'

        lines = []
        if self.comment:
            lines += ['Note: ' + line for line in self.comment.split('\n')]
            lines.append('')

        for key, entry in self.arrays.items():
            itemsize = entry.dtype.itemsize
            # Per-rank size comes from the shape alone; the all-rank size may
            # come from an explicit total_size (see MemoryEntry).
            size_per_rank_MiB = int(np.prod(entry.shape)) * itemsize * to_MiB
            size_total_MiB = entry.get_total_size() * itemsize * to_MiB

            lines.append(f'{key}: {entry.shape} {entry.dtype}')
            lines.append(f'. {size_per_rank_MiB:.1f} MiB '
                         f'per rank on {entry.on_num_ranks} ranks')
            lines.append(f'. {size_total_MiB:.1f} MiB in total on all ranks')
            lines.append('')

        for name, child in self.children.items():
            lines.append(f'{name}:')
            # Indent the child's full report, including its own total footer.
            lines += [' ' + line for line in str(child).split('\n')]
            lines.append('')
        lines.append(f'{"":.^24}')
        lines.append(f'{" Total on all ranks ":.^24}')
        lines.append(f'{totalstr:.^24}')

        return '\n'.join(lines)

    @property
    def grand_total(self) -> int:
        """Total number of bytes over all ranks, including all children."""
        arrays_total = sum(entry.get_total_size() * entry.dtype.itemsize
                           for entry in self.arrays.values())
        children_total = sum(child.grand_total
                             for child in self.children.values())
        return arrays_total + children_total

    def add_key(self,
                key: str,
                shape: tuple[int, ...] | int = (),
                dtype: DTypeLike = float,
                *,
                total_size: int | None = None,
                on_num_ranks: int = 1):
        """Register an array named ``key`` in this estimate.

        Parameters
        ----------
        key
            Name of the array; an existing entry with that name is replaced.
        shape
            Shape of the array on each rank. A single integer (Python or
            NumPy scalar) is treated as a 1-D shape.
        dtype
            Anything accepted by :class:`numpy.dtype`.
        total_size
            Total number of elements on all ranks; when given it overrides
            ``prod(shape) * on_num_ranks``.
        on_num_ranks
            Number of ranks that each hold an array of ``shape``.

        Raises
        ------
        AssertionError
            If ``key``, ``shape`` or ``on_num_ranks`` have the wrong type.
            These are ``assert`` checks, so they are skipped under ``-O``.
        """
        assert isinstance(key, str)
        # A NumPy integer scalar is not iterable, so it must be wrapped just
        # like a Python int (the per-dimension check already accepts it).
        if isinstance(shape, (int, np.integer)):
            shape = (shape, )
        assert all(isinstance(d, (int, np.integer)) for d in shape)
        shape = tuple(int(d) for d in shape)
        assert isinstance(on_num_ranks, (int, np.integer))
        # Store plain Python ints so downstream arithmetic returns int.
        if isinstance(total_size, np.integer):
            total_size = int(total_size)
        dtype = np.dtype(dtype)
        self.arrays[key] = MemoryEntry(shape, dtype,
                                       total_size=total_size,
                                       on_num_ranks=int(on_num_ranks))

    def add_child(self,
                  name: str,
                  child: MemoryEstimate):
        """Attach ``child`` under ``name``, replacing any child of that name.

        Raises AssertionError if ``child`` is not a :class:`MemoryEstimate`.
        """
        assert isinstance(child, MemoryEstimate)
        self.children[name] = child

99 

100 

class HasMemoryEstimate(ABC):
    """Interface for objects that can report a :class:`MemoryEstimate`.

    Subclasses must implement :meth:`get_memory_estimate`; because that
    method is abstract, this base class cannot be instantiated directly.
    """

    @abstractmethod
    def get_memory_estimate(self) -> MemoryEstimate:
        """Return a memory estimate for this object.

        The base implementation only raises ``NotImplementedError``.
        """
        raise NotImplementedError()