Coverage for rhodent/utils/result.py: 72%
101 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3import numpy as np
4from numpy.typing import NDArray, DTypeLike
6from typing import Iterator
class ResultKeys:

    """ List of result keys.

    Each key is registered together with the shape and dtype of the
    corresponding result (at one time or frequency instance). Keys are
    iterated in insertion order.

    Parameters
    ----------
    scalar_keys
        Names of results that are scalars of dtype float.
    """

    _keys_dimensions_dtypes: dict[str, tuple[tuple[int, ...], np.dtype]]

    def __init__(self,
                 *scalar_keys):
        self._keys_dimensions_dtypes = dict()

        for key in scalar_keys:
            self.add_key(key, (), float)

    def add_key(self,
                key: str,
                shape: tuple[int, ...] | int = (),
                dtype: DTypeLike = float):
        """ Add a new result key.

        Adding a key that is already registered replaces its shape and dtype.

        Parameters
        ----------
        key
            Name of result.
        shape
            Shape of result (at one time or frequency instance). Default is scalar.
            An integer ``n`` is shorthand for ``(n, )``. Any sequence of integers
            (including NumPy integers) is accepted; it is stored as a tuple of
            Python ints so that it compares equal to ``ndarray.shape``.
        dtype
            Result dtype.
        """

        assert isinstance(key, str)
        if isinstance(shape, (int, np.integer)):
            shape = (shape, )
        assert all([isinstance(d, (int, np.integer)) for d in shape])
        # Normalize: a list (or NumPy integers) would otherwise be stored as-is
        # and never compare equal to the tuple returned by ndarray.shape
        shape = tuple(int(d) for d in shape)
        dtype = np.dtype(dtype)
        self._keys_dimensions_dtypes[key] = (shape, dtype)

    def remove(self,
               key: str):
        """ Remove a result key.

        Parameters
        ----------
        key
            Name of result. Must be registered.
        """
        assert key in self
        self._keys_dimensions_dtypes.pop(key)

    def __contains__(self,
                     key: str) -> bool:
        return key in self._keys_dimensions_dtypes

    def __iter__(self) -> Iterator[tuple[str, tuple[int, ...], np.dtype]]:
        """ Iterate over ``(key, shape, dtype)`` in insertion order. """
        for key, (shape, dtype) in self._keys_dimensions_dtypes.items():
            yield key, shape, dtype

    def __getitem__(self,
                    key: str) -> tuple[tuple[int, ...], np.dtype]:
        """ Return ``(shape, dtype)`` registered for ``key``. """
        assert key in self._keys_dimensions_dtypes, f'Key {key} not among keys'
        return self._keys_dimensions_dtypes[key]

    def __copy__(self) -> ResultKeys:
        # The stored (shape, dtype) values are immutable, so copying the dict
        # makes the copy independent of the original
        cpy = ResultKeys()
        cpy._keys_dimensions_dtypes.update(self._keys_dimensions_dtypes)
        return cpy
class Result:

    """ Class holding results.

    Each result is stored as a C-contiguous NumPy array under a string key.
    Scalar (0-dimensional) values are stored as arrays of shape ``(1, )``.

    Parameters
    ----------
    mutable
        If ``False`` (default), assigning to a key that is already present
        raises an ``AssertionError``. If ``True``, existing results may be
        overwritten.
    """

    _data: dict[str, NDArray[np.generic]]

    def __init__(self,
                 mutable: bool = False):
        self._data = dict()
        self._mutable = mutable

    def __contains__(self,
                     key: str) -> bool:
        return key in self._data

    def __setitem__(self,
                    key: str,
                    value: np.typing.ArrayLike | int):
        """ Store ``value`` under ``key``.

        Scalars are wrapped into new arrays of shape ``(1, )``. Other values
        are passed through ``np.ascontiguousarray``, which may return ``value``
        itself when it already is a contiguous array, so the stored result can
        share memory with the caller's array.
        """
        if not self._mutable:
            assert key not in self._data, f'Key {key} is already among results'
        if np.ndim(value) == 0:
            value = np.array([value])
        self._data[key] = np.ascontiguousarray(value)

    def __getitem__(self,
                    key: str) -> NDArray[np.generic]:
        assert key in self._data, f'Key {key} not among results'
        return self._data[key]

    def __str__(self) -> str:
        lines = [f'{self.__class__.__name__} with arrays (dimensions)']

        for key, data in self._data.items():
            lines.append(f' {key} {data.shape}')

        return '\n'.join(lines)

    def _element_value(self,
                       key: str,
                       idx,
                       value: np.typing.ArrayLike | int | float):
        """ Return ``value`` unwrapped to a scalar if ``self[key][idx]`` is a single element. """
        if np.ndim(self._data[key][idx]) == 0:
            # A single element can only take a single value; accept e.g. [x]
            assert np.size(value) == 1
            value = np.atleast_1d(value)[0]
        return value

    def set_to(self,
               key: str,
               idx,
               value: np.typing.ArrayLike | int | float):
        """ Assign ``value`` to ``self[key][idx]`` in place.

        If ``idx`` selects a single element, ``value`` must have size 1.
        """
        value = self._element_value(key, idx, value)
        self._data[key][idx] = value

    def add_to(self,
               key: str,
               idx,
               value: np.typing.ArrayLike | int | float):
        """ Add ``value`` to ``self[key][idx]`` in place.

        If ``idx`` selects a single element, ``value`` must have size 1.
        """
        value = self._element_value(key, idx, value)
        self._data[key][idx] += value

    def _create_all(self,
                    keys: ResultKeys,
                    allocate):
        """ Store ``allocate(shape, dtype=dtype)`` for every key in ``keys`` not already present. """
        for key, shape, dtype in keys:
            if key in self:
                continue
            self[key] = allocate(shape, dtype=dtype)

    def create_all_empty(self,
                         keys: ResultKeys):
        """ Allocate uninitialized arrays for all keys that are not yet present. """
        self._create_all(keys, np.empty)

    def create_all_zeros(self,
                         keys: ResultKeys):
        """ Allocate zero-filled arrays for all keys that are not yet present. """
        self._create_all(keys, np.zeros)

    def remove(self,
               key: str):
        """ Remove the result stored under ``key``, which must be present. """
        assert key in self._data
        self._data.pop(key)

    def empty(self,
              key: str,
              keys: ResultKeys):
        """ Store an uninitialized array for ``key`` with the shape and dtype from ``keys``. """
        shape, dtype = keys[key]
        self[key] = np.empty(shape, dtype=dtype)

    def assert_keys(self,
                    keys: ResultKeys):
        """ Assert that the stored results match ``keys`` exactly.

        Every key must be present with the registered dtype and shape (scalar
        keys must have shape ``(1, )``), and no other results may be stored.

        Raises
        ------
        AssertionError
            If a key is missing, a shape or dtype differs, or there are
            additional results.
        """
        remaining = dict(self._data)
        for key, shape, dtype in keys:
            # Explicit check instead of catching KeyError, so the raised
            # AssertionError does not carry an unrelated exception context
            if key not in remaining:
                raise AssertionError(f'Key {key} missing from Result')
            array = remaining.pop(key)
            # Scalars are stored with shape (1, ), see __setitem__. The shape
            # is converted to a tuple so that e.g. [2, 3] matches (2, 3).
            expected = (1, ) if len(shape) == 0 else tuple(shape)
            assert array.shape == expected, f'{array.shape} != {expected}'
            assert array.dtype == dtype, f'{array.dtype} != {dtype}'
        assert len(remaining) == 0, f'Result has additional keys {remaining.keys()}'

    def send(self,
             keys: ResultKeys,
             rank: int,
             comm):
        """ Send all results listed in ``keys`` to ``rank``.

        One message is sent per key, tagged ``100 + i`` where ``i`` is the
        position of the key in ``keys``. The receiving side must therefore
        iterate the same keys in the same order.

        ``comm`` must provide ``send(array, rank, tag=...)`` (presumably a
        GPAW-style MPI communicator; TODO confirm).
        """
        self.assert_keys(keys)
        for vi, (key, _, _) in enumerate(keys):
            comm.send(self._data[key], rank, tag=100 + vi)

    def inplace_receive(self,
                        keys: ResultKeys,
                        rank: int,
                        comm):
        """ Receive all results listed in ``keys`` from ``rank`` into the stored arrays.

        The arrays must already exist with the shapes and dtypes from ``keys``.
        Tags follow the same ``100 + i`` scheme as :meth:`send`.

        ``comm`` must provide ``receive(array, rank, tag=...)``, which is
        assumed to write into ``array`` in place (TODO confirm).
        """
        self.assert_keys(keys)
        for vi, (key, _, _) in enumerate(keys):
            comm.receive(self._data[key], rank, tag=100 + vi)