Coverage for rhodent/utils/result.py: 72%

101 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3import numpy as np 

4from numpy.typing import NDArray, DTypeLike 

5 

6from typing import Iterator 

7 

8 

class ResultKeys():

    """ List of result keys.

    Each key is associated with the shape of the result (at one time or
    frequency instance) and its dtype. Keys are iterated in insertion order;
    ``Result.send`` and ``Result.inplace_receive`` derive message tags from
    that order.

    Parameters
    ----------
    scalar_keys
        Names of results that are scalars of dtype float.
    """

    # Maps key -> (shape, dtype). Shapes are tuples of plain ints.
    _keys_dimensions_dtypes: dict[str, tuple[tuple[int, ...], np.dtype]]

    def __init__(self,
                 *scalar_keys: str):
        self._keys_dimensions_dtypes = dict()

        for key in scalar_keys:
            self.add_key(key, (), float)

    def add_key(self,
                key: str,
                shape: tuple[int, ...] | int = (),
                dtype: DTypeLike = float):
        """ Add a new result key.

        Adding a key that already exists replaces its shape and dtype.

        Parameters
        ----------
        key
            Name of result.
        shape
            Shape of result (at one time or frequency instance). Default is scalar.
            A single integer ``n`` is shorthand for ``(n, )``. Any sequence of
            integers (Python or NumPy integers) is accepted and stored as a
            tuple of Python ints.
        dtype
            Result dtype; normalized with ``np.dtype``.
        """

        assert isinstance(key, str)
        if isinstance(shape, (int, np.integer)):
            shape = (shape, )
        dims = tuple(shape)
        assert all(isinstance(d, (int, np.integer)) for d in dims)
        # Store a tuple of plain ints so that comparisons against
        # ``ndarray.shape`` (always a tuple of ints) succeed even when the
        # caller passed a list or NumPy integers.
        shape = tuple(int(d) for d in dims)
        self._keys_dimensions_dtypes[key] = (shape, np.dtype(dtype))

    def remove(self,
               key: str):
        """ Remove ``key``; the key must be present. """
        assert key in self
        self._keys_dimensions_dtypes.pop(key)

    def __contains__(self,
                     key: str) -> bool:
        return key in self._keys_dimensions_dtypes

    def __iter__(self) -> Iterator[tuple[str, tuple[int, ...], np.dtype]]:
        """ Yield ``(key, shape, dtype)`` triples in insertion order. """
        for key, (shape, dtype) in self._keys_dimensions_dtypes.items():
            yield key, shape, dtype

    def __getitem__(self,
                    key: str) -> tuple[tuple[int, ...], np.dtype]:
        """ Return ``(shape, dtype)`` for ``key``. """
        assert key in self._keys_dimensions_dtypes, f'Key {key} not among keys'
        return self._keys_dimensions_dtypes[key]

    def __copy__(self):
        # The stored (shape, dtype) tuples are immutable, so a shallow copy
        # of the mapping fully decouples the copy from the original.
        cpy = ResultKeys()
        cpy._keys_dimensions_dtypes.update(self._keys_dimensions_dtypes)
        return cpy

69 

70 

class Result:

    """ Class holding results.

    Results are stored as C-contiguous NumPy arrays keyed by name. Scalar
    values are stored as arrays of shape ``(1, )``.

    Parameters
    ----------
    mutable
        If ``False`` (default), a key can be assigned with ``result[key] = value``
        only once; assigning it again fails an assertion. If ``True``, keys
        may be reassigned.
    """

    # Arrays may have any dtype (e.g. complex), not only float64.
    _data: dict[str, NDArray[np.generic]]

    def __init__(self,
                 mutable: bool = False):
        self._data = dict()
        self._mutable = mutable

    def __contains__(self,
                     key: str) -> bool:
        return key in self._data

    def __setitem__(self,
                    key: str,
                    value: np.typing.ArrayLike | int):
        """ Store ``value`` under ``key``.

        Zero-dimensional values are wrapped into arrays of shape ``(1, )``.
        ``np.ascontiguousarray`` does not copy inputs that are already
        C-contiguous arrays, so the stored array may share memory with
        ``value``.
        """
        if not self._mutable:
            assert key not in self._data, f'Key {key} is already among results'
        if np.ndim(value) == 0:
            value = np.array([value])
        self._data[key] = np.ascontiguousarray(value)

    def __getitem__(self,
                    key: str) -> NDArray[np.generic]:
        """ Return the stored array for ``key`` (the array itself, not a copy). """
        assert key in self._data, f'Key {key} not among results'
        return self._data[key]

    def __str__(self) -> str:
        lines = [f'{self.__class__.__name__} with arrays (dimensions)']
        lines += [f' {key} {data.shape}' for key, data in self._data.items()]
        return '\n'.join(lines)

    def _value_for_slot(self, key: str, idx, value):
        """ Reduce ``value`` to a scalar when ``self._data[key][idx]`` is a scalar. """
        if np.ndim(self._data[key][idx]) == 0:
            assert np.size(value) == 1
            # np.ravel flattens any size-1 input (e.g. shape (1, 1)) so that
            # indexing yields a true scalar; np.atleast_1d(value)[0] would
            # leave a size-1 array for inputs with more than one dimension.
            value = np.ravel(value)[0]
        return value

    def set_to(self,
               key: str,
               idx,
               value: np.typing.ArrayLike | int | float):
        """ Assign ``value`` to ``self[key][idx]`` in place.

        If the indexed element is a scalar, ``value`` must contain exactly one
        element.
        """
        self._data[key][idx] = self._value_for_slot(key, idx, value)

    def add_to(self,
               key: str,
               idx,
               value: np.typing.ArrayLike | int | float):
        """ Add ``value`` to ``self[key][idx]`` in place.

        If the indexed element is a scalar, ``value`` must contain exactly one
        element.
        """
        self._data[key][idx] += self._value_for_slot(key, idx, value)

    def _create_missing(self, keys: ResultKeys, allocate):
        """ Store ``allocate(shape, dtype=dtype)`` for each key in ``keys`` not yet present. """
        for key, shape, dtype in keys:
            if key not in self:
                self[key] = allocate(shape, dtype=dtype)

    def create_all_empty(self,
                         keys: ResultKeys):
        """ Allocate uninitialized arrays for all keys that are not yet present. """
        self._create_missing(keys, np.empty)

    def create_all_zeros(self,
                         keys: ResultKeys):
        """ Allocate zero-filled arrays for all keys that are not yet present. """
        self._create_missing(keys, np.zeros)

    def remove(self,
               key: str):
        """ Remove ``key``; the key must be present. """
        assert key in self._data
        self._data.pop(key)

    def empty(self,
              key: str,
              keys: ResultKeys):
        """ Store an uninitialized array for ``key`` with the shape and dtype from ``keys``. """
        shape, dtype = keys[key]
        self[key] = np.empty(shape, dtype=dtype)

    def assert_keys(self,
                    keys: ResultKeys):
        """ Assert that this result holds exactly the arrays described by ``keys``.

        Scalar keys (empty shape) must be stored with shape ``(1, )``.

        Raises
        ------
        AssertionError
            If a key is missing, an array has the wrong shape or dtype, or the
            result holds keys that are not in ``keys``.
        """
        remaining = dict(self._data)
        for key, shape, dtype in keys:
            try:
                array = remaining.pop(key)
            except KeyError:
                # The assertion message is self-explanatory; drop the KeyError context.
                raise AssertionError(f'Key {key} missing from Result') from None
            expected = (1, ) if len(shape) == 0 else tuple(shape)
            assert array.shape == expected, f'{array.shape} != {expected}'
            assert array.dtype == dtype, f'{array.dtype} != {dtype}'
        assert len(remaining) == 0, f'Result has additional keys {remaining.keys()}'

    def send(self,
             keys: ResultKeys,
             rank,
             comm):
        """ Send every array in ``keys`` to ``rank`` via ``comm.send``.

        One message is sent per key, tagged ``100 + i`` where ``i`` is the
        key's position in ``keys``; the receiver must iterate the same keys in
        the same order. ``comm`` is presumably an MPI-style communicator with a
        ``send(array, rank, tag=...)`` method (verify against callers).
        """
        self.assert_keys(keys)
        for tag, (key, _, _) in enumerate(keys, start=100):
            comm.send(self._data[key], rank, tag=tag)

    def inplace_receive(self,
                        keys: ResultKeys,
                        rank: int,
                        comm):
        """ Receive every array in ``keys`` from ``rank`` into the existing arrays.

        The arrays must already exist with the shapes and dtypes of ``keys``
        (checked by ``assert_keys``) because they are passed to
        ``comm.receive`` as buffers. Tags are ``100 + i``, matching ``send``.
        """
        self.assert_keys(keys)
        for tag, (key, _, _) in enumerate(keys, start=100):
            comm.receive(self._data[key], rank, tag=tag)