Coverage for rhodent/utils/logging.py: 85%

113 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3import os 

4import sys 

5import time 

6 

7import numpy as np 

8 

9import ase 

10import gpaw 

11from gpaw.mpi import world 

12from gpaw.tddft.units import au_to_eV, au_to_as 

13 

14from .. import __version__ 

15from ..typing import Communicator 

16 

# Artwork printed by Logger.startup_message. That method splits both strings
# on '\n', appends the package version to element -2 of the wordmark's lines
# (ascii_logo), and splices the wordmark line by line into the icon's lines
# (ascii_icon) starting at index 3, so the number of lines in each string
# matters.
# NOTE(review): in this copy the whitespace inside both strings appears to be
# collapsed (the art no longer lines up); restore the spacing from the
# original source before relying on the rendered banner.
17ascii_icon = r""" 

18 ### ##### 

19 ############### 

20 ######## ##### ## 

21 ########### ### 

22 ####### ### 

23 #### ## 

24 # ## 

25 # ## 

26 # ## 

27 # ## 

28 ## ### 

29 #### ### 

30 ########## ## 

31 ######## ### 

32 ### #### 

33 ## #### 

34 ## ###### 

35 ################# 

36""" 

37 

38 

39ascii_logo = r""" 

40 _ _ _ 

41 _ __| |__ ___ __| | ___ _ __ | |_ 

42| '__| '_ \ / _ \ / _` |/ _ \ '_ \| __| 

43| | | | | | (_) | (_| | __/ | | | |_ 

44|_| |_| |_|\___/ \__,_|\___|_| |_|\__| 

45""" 

46 

47 

class Logger:

    """ Print time-stamped log messages.

    Every message is prefixed with the time elapsed since :attr:`t0`,
    formatted as ``[hh:mm:ss.s]``. Named timers can be started with
    :meth:`start` and read back with :meth:`elapsed`.

    Parameters
    ----------
    t0
        Start time on the :func:`time.time` clock (default is current time).
        Any real number is accepted; it is stored as a float.
    """
    _t0: float
    _starttimes: dict[str, float]
    _time_of_last_log: float

    def __init__(self,
                 t0: float | None = None):
        # Start times of named timers, recorded by start()
        self._starttimes = dict()
        self._t0 = self._as_timestamp(t0)
        # Read by __call__ to throttle messages through its if_elapsed argument
        self._time_of_last_log = self._t0

    @staticmethod
    def _as_timestamp(value: float | None) -> float:
        """ Return :attr:`value` as a float, or the current time if it is ``None``.

        Raises
        ------
        TypeError
            If :attr:`value` is neither ``None`` nor a real number.
        """
        if value is None:
            return time.time()
        is_number = isinstance(value, (int, float, np.integer, np.floating))
        # bool is a subclass of int; reject it explicitly like other non-numbers
        if isinstance(value, bool) or not is_number:
            raise TypeError('Start time must be a real number or None, '
                            f'not {type(value).__name__}')
        return float(value)

    @property
    def t0(self) -> float:
        """ Start time that log timestamps are measured from. """
        return self._t0

    @t0.setter
    def t0(self,
           value: float | None):
        self._t0 = self._as_timestamp(value)

    def __getitem__(self, key) -> float:
        """ Start time of timer :attr:`key`, or :attr:`t0` if it was never started. """
        return self._starttimes.get(key, self.t0)

    def __call__(self,
                 *args,
                 who: str | None = None,
                 rank: int | None = None,
                 if_elapsed: float = 0,
                 comm: Communicator | None = None,
                 **kwargs):
        """ Log message.

        Parameters
        ----------
        args
            Message parts, forwarded to :meth:`log`.
        rank
            Only log if rank is :attr:`rank`. ``None`` to always log.
        who
            Sender of the message.
        comm
            Communicator. If included, rank and size are included in the message
            (when the communicator has more than one rank). Its rank is also the
            one compared against :attr:`rank`; otherwise the world rank is used.
        if_elapsed
            Only log if at least :attr:`if_elapsed` seconds have passed since the
            last logged message.
        kwargs
            Forwarded to :meth:`log`.
        """

        myrank = world.rank if comm is None else comm.rank
        if rank is not None and myrank != rank:
            return
        if time.time() < self._time_of_last_log + if_elapsed:
            return
        if comm is not None and comm.size > 1:
            commstr = f'{comm.rank:04.0f}/{comm.size:04.0f}'
            who = commstr if who is None else f'{who} {commstr}'
        _args = list(args)
        if who is not None:
            _args.insert(0, f'[{who}]')
        return self.log(*_args, **kwargs)

    def __str__(self) -> str:
        s = f'{self.__class__.__name__} t0: {self.t0}'
        return s

    def log(self, *args, **kwargs):
        """ Log message, prepending a timestamp.

        The timestamp is the time elapsed since :attr:`t0`, formatted as
        ``[hh:mm:ss.s]``. Positional and keyword arguments are forwarded to
        :func:`print`.
        """
        self._time_of_last_log = time.time()
        # Round to whole tenths of a second *before* splitting into h/m/s;
        # formatting the seconds with one decimal afterwards could otherwise
        # print e.g. 59.97 s as "00:00:60.0" instead of "00:01:00.0".
        tenths = round((self._time_of_last_log - self.t0) * 10)
        hh, rem = divmod(tenths, 36000)
        mm, rem = divmod(rem, 600)
        timestr = f'[{hh:02d}:{mm:02d}:{rem / 10:04.1f}]'
        print(timestr, *args, **kwargs)

    def start(self, key):
        """ Start (or restart) the timer :attr:`key`. """
        self._starttimes[key] = time.time()

    def elapsed(self, key) -> float:
        """ Seconds since timer :attr:`key` was started (since :attr:`t0` if never started). """
        return time.time() - self[key]

    def startup_message(self):
        """ Print a start up message.

        Only the world rank 0 prints. The message consists of the icon with the
        logotype and version spliced in, followed by the date, working directory,
        number of cores and the locations/versions of Python, NumPy, ASE and GPAW.
        """
        if world.rank != 0:
            return

        # Piece together logotype and version number
        logo_lines = ascii_logo.split('\n')
        width = max(len(line) for line in logo_lines) + 2
        i = -2
        logo_lines[i] += (width - len(logo_lines[i])) * ' '  # Pad to width
        logo_lines[i] += __version__

        # Piece together icon and logotype
        lines = ascii_icon.split('\n')
        width = max(len(line) for line in lines)

        for i, logoline in enumerate(logo_lines, start=3):
            line = lines[i]
            line += (width - len(line)) * ' '  # Pad to same length
            lines[i] = line + logoline

        print('\n'.join(lines))
        print('Date:  ', time.asctime())
        print('CWD:   ', os.getcwd())
        print('cores: ', world.size)
        print('Python: {}.{}.{}'.format(*sys.version_info[:3]))
        print(f'numpy: {os.path.dirname(np.__file__)} (version {np.version.version})')
        print(f'ASE:   {os.path.dirname(ase.__file__)} (version {ase.__version__})')
        print(f'GPAW:  {os.path.dirname(gpaw.__file__)} (version {gpaw.__version__})')
        print(flush=True)

169 

class NoLogger(Logger):

    """ Logger that discards every message.

    :meth:`log` does nothing, so messages sent through it (including those
    routed via :meth:`Logger.__call__`) produce no output. All other behavior,
    including the timers and :meth:`Logger.startup_message`, is inherited from
    :class:`Logger` unchanged.
    """

    def __str__(self) -> str:
        return type(self).__name__

    def log(self, *args, **kwargs):
        # Intentionally a no-op: swallow the message and its print() options
        return None

177 

178 

def _format_short_list(values: np.ndarray,
                       unit: str) -> str:
    """ Format a 1D array of numbers as a short, comma-separated string.

    Each value is written with one decimal. Arrays with fewer than five
    elements are written in full; longer arrays are abbreviated to the first
    three values, an ellipsis and the last value. :attr:`unit` is appended
    after the last value. An empty array gives an empty string.
    """
    if len(values) == 0:
        return ''
    if len(values) < 5:
        strings = [f'{value:.1f}' for value in values]
    else:
        strings = [f'{value:.1f}' for value in values[[0, 1, 2, -1]]]
        strings.insert(-1, '...')
    strings[-1] += f' {unit}'
    return ', '.join(strings)


def format_times(times: np.typing.ArrayLike,
                 units: str = 'as') -> str:
    """ Write a short list of times for pretty printing.

    Parameters
    ----------
    times
        Time or list of times in units of :attr:`units`.
    units
        Units of the supplied times.

        * ``au`` - atomic units
        * ``as`` - attoseconds

    Returns
    -------
    Formatted list of times in units of as. Lists of five or more times are
    abbreviated to the first three and the last time. Empty string if
    :attr:`times` is empty.

    Raises
    ------
    ValueError
        If :attr:`units` is not ``au`` or ``as``.
    """
    # A float dtype lets the unit conversion below work for integer input too,
    # and flattening allows a single scalar time.
    times = np.asarray(times, dtype=float).ravel()
    if units == 'au':
        # Not in-place: asarray may return the caller's own array
        times = times * au_to_as
    elif units != 'as':
        raise ValueError(f'Unknown units {units}. Must be "au" or "as".')
    return _format_short_list(times, 'as')


def format_frequencies(frequencies: np.typing.ArrayLike,
                       units: str = 'eV') -> str:
    """ Write a short list of frequencies for pretty printing.

    Parameters
    ----------
    frequencies
        Frequency or list of frequencies in units of :attr:`units`.
    units
        Units of the supplied frequencies.

        * ``au`` - atomic units
        * ``eV`` - electron volts

    Returns
    -------
    Formatted list of frequencies in units of eV. Lists of five or more
    frequencies are abbreviated to the first three and the last frequency.
    Empty string if :attr:`frequencies` is empty.

    Raises
    ------
    ValueError
        If :attr:`units` is not ``au`` or ``eV``.
    """
    # A float dtype lets the unit conversion below work for integer input too,
    # and flattening allows a single scalar frequency.
    frequencies = np.asarray(frequencies, dtype=float).ravel()
    if units == 'au':
        # Not in-place: asarray may return the caller's own array
        frequencies = frequencies * au_to_eV
    elif units != 'eV':
        raise ValueError(f'Unknown units {units}. Must be "au" or "eV".')
    return _format_short_list(frequencies, 'eV')