Coverage for tests/conftest.py: 79%

137 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3import os 

4import warnings 

5from contextlib import contextmanager 

6from pathlib import Path 

7 

8from ase.io.ulm import Reader 

9from gpaw.mpi import world 

10 

11import pytest 

12 

13from tests import frho, get_permanent_test_file, permanent_test_files 

14from tests.mock import (MockKohnShamRhoWfsReader, MockTimeDensityMatrices, 

15 MockConvolutionDensityMatrices, MockVoronoiWeights, 

16 MockFrequencyDensityMatrices, MockResponse) 

17 

18 

@contextmanager
def execute_in_tmp_path(request, tmp_path_factory):
    """Temporarily switch the working directory to a fresh temporary directory.

    Rank 0 of ``world`` creates the directory, naming it after the object
    that owns the fixture scope, and broadcasts the path to the other ranks.
    The previous working directory is restored on exit, even if the body
    raises.

    Parameters
    ----------
    request
        The pytest ``request`` object of the calling fixture.
    tmp_path_factory
        The pytest ``tmp_path_factory`` fixture.

    Yields
    ------
    The temporary directory path, identical on all ranks.
    """
    from gpaw.mpi import broadcast

    tmp_dir = None
    if world.rank == 0:
        # request.function for function-scoped fixtures,
        # request.module for module-scoped fixtures
        scope_owner = getattr(request, request.scope)
        tmp_dir = tmp_path_factory.mktemp(scope_owner.__name__)
    tmp_dir = broadcast(tmp_dir)

    previous_cwd = os.getcwd()
    os.chdir(tmp_dir)
    try:
        yield tmp_dir
    finally:
        os.chdir(previous_cwd)

38 

39 

@pytest.fixture(scope='function')
def in_tmp_dir(request, tmp_path_factory):
    """Run the test function inside a fresh temporary working directory.

    Yields the directory path; the original working directory is restored
    once the test has finished.
    """
    tmp_context = execute_in_tmp_path(request, tmp_path_factory)
    with tmp_context as tmp_dir:
        yield tmp_dir

45 

46 

@pytest.fixture(scope='session')
def cache_path(request, tmp_path_factory) -> Path:
    """Directory in which cached test data is written to and read from.

    If the ``GPW_TEST_FILES`` environment variable is set, its value is
    used, so cached files can be reused across pytest sessions. Otherwise a
    warning is emitted. Rank 0 then creates a numbered temporary directory
    and broadcasts its path to the other ranks.
    """
    import re
    from gpaw.mpi import broadcast

    env_dir: str | None = os.environ.get('GPW_TEST_FILES')
    if env_dir is not None:
        return Path(env_dir)

    warnings.warn(
        'Note that you can speed up the tests by reusing gpw-files '
        'from an earlier pytest session: '
        'set the $GPW_TEST_FILES environment variable and the '
        'files will be written to/read from that folder. ')

    if world.rank != 0:
        # Non-root ranks just receive the path chosen by rank 0
        return broadcast(None)

    # Turn the session node name into a short, filesystem-safe prefix
    max_len = 30
    safe_name = re.sub(r'[\W]', '_', request.node.name)
    if safe_name == '':
        safe_name = 'rhodent'  # Explicitly needed in pytest 8.0.2
    tmp_dir = tmp_path_factory.mktemp(safe_name[:max_len], numbered=True)
    broadcast(tmp_dir)
    return tmp_dir

74 

75 

@pytest.fixture
def cached_file_if_existing(cache_path, test_system):
    """Factory fixture for looking up files in the per-system cache directory.

    The returned callable maps a file name to its location under
    ``cache_path / test_system``. It also reports whether that entry is
    already present.
    """

    def _cached_file_if_existing(fname: str, check_empty: bool = False) -> tuple[Path, bool]:
        """ Return path to datafile in cache directory and whether it exists

        Parameters
        ----------
        fname
            Name of the file or directory inside the per-system cache folder.
        check_empty
            If True, ``fname`` is treated as a directory and only counts as
            existing when it contains at least one entry.

        Returns
        -------
        The path, and a flag determined on rank 0 and broadcast to all ranks.
        """

        from gpaw.mpi import broadcast

        parent = cache_path / test_system
        # parents=True: $GPW_TEST_FILES may point at a folder that has not
        # been created yet; exist_ok keeps concurrent calls from ranks safe.
        parent.mkdir(parents=True, exist_ok=True)
        path = parent / fname
        if world.rank == 0:
            exists = path.exists()
            if exists and check_empty:
                # Any entry at all means the directory is non-empty
                exists = any(path.iterdir())
            broadcast(exists)
        else:
            exists = broadcast(None)

        return path, exists

    return _cached_file_if_existing

98 

99 

@pytest.fixture
def gpw_fname(test_system):
    """Permanent test file registered under ``'gpw_fname'`` for the system."""
    file_key = 'gpw_fname'
    return get_permanent_test_file(test_system, file_key)

103 

104 

@pytest.fixture
def fdm_fname(test_system):
    """Permanent test file registered under ``'fdm_fname'`` for the system."""
    file_key = 'fdm_fname'
    return get_permanent_test_file(test_system, file_key)

108 

109 

@pytest.fixture
def wfs_fname(test_system):
    """Permanent test file registered under ``'wfs_fname'`` for the system."""
    file_key = 'wfs_fname'
    return get_permanent_test_file(test_system, file_key)

113 

114 

@pytest.fixture
def wfssnap_fname(test_system):
    """Permanent test file registered under ``'wfssnap_fname'`` for the system."""
    file_key = 'wfssnap_fname'
    return get_permanent_test_file(test_system, file_key)

118 

119 

@pytest.fixture
def ref_density_matrix(test_system):
    """Permanent reference data registered under ``'ref_density_matrix'``."""
    file_key = 'ref_density_matrix'
    return get_permanent_test_file(test_system, file_key)

123 

124 

@pytest.fixture
def ref_voronoi(test_system):
    """Permanent reference data registered under ``'ref_voronoi'``."""
    file_key = 'ref_voronoi'
    return get_permanent_test_file(test_system, file_key)

128 

129 

@pytest.fixture
def ref_hcdist(test_system):
    """Permanent reference data registered under ``'ref_hcdist'``."""
    file_key = 'ref_hcdist'
    return get_permanent_test_file(test_system, file_key)

133 

134 

@pytest.fixture
def ref_energy(test_system):
    """Permanent reference data registered under ``'ref_energy'``."""
    file_key = 'ref_energy'
    return get_permanent_test_file(test_system, file_key)

138 

139 

@pytest.fixture
def ksd_fname(test_system):
    """Permanent test file registered under ``'ksd_fname'`` for the system.

    Only ``test_system`` is needed. The previously requested (but unused)
    ``cached_file_if_existing`` dependency pulled in the session-wide
    ``cache_path`` fixture, and with it a warning and a temporary directory.
    That happened for every test that merely needed this file name.
    """
    return get_permanent_test_file(test_system, 'ksd_fname')

143 

144 

@pytest.fixture
def dm_sinc(test_system):
    """Permanent test data registered under ``'dm_sinc'`` for the system."""
    file_key = 'dm_sinc'
    return get_permanent_test_file(test_system, file_key)

148 

149 

@pytest.fixture
def dm_delta(test_system):
    """Permanent test data registered under ``'dm_delta'`` for the system."""
    file_key = 'dm_delta'
    return get_permanent_test_file(test_system, file_key)

153 

154 

@pytest.fixture
def frho_dname(cached_file_if_existing, test_system, ksd_fname):
    """Cached ``frho`` directory for the current test system.

    If the directory already exists in the cache and is non-empty, it is
    reused. Otherwise ``frho`` is run to populate it, using ``ksd_fname``
    and the system's ``'fdm_fname'`` entry in ``permanent_test_files``.
    """
    dname, already_cached = cached_file_if_existing('frho', check_empty=True)
    if already_cached:
        return dname

    # Extract the frho
    fdm = permanent_test_files[test_system]['fdm_fname']
    frho(ksd_fname=ksd_fname, fdm_fname=fdm, frho_dname=dname)
    return dname

164 

165 

@pytest.fixture
def mock_voronoi(test_system, ksd_fname):
    """Factory fixture building ``MockVoronoiWeights`` instances.

    ``nn`` is read once from the KSD file as ``eig_un.shape[2]``. All
    keyword arguments given to the factory are forwarded unchanged.
    """
    with Reader(ksd_fname) as ksd_reader:
        nn = ksd_reader.eig_un.shape[2]

    def build(**kwargs):
        return MockVoronoiWeights(nn=nn, **kwargs)

    return build

176 

177 

@pytest.fixture
def mock_ks_rho_reader(test_system, ksd_fname):
    """Factory fixture building ``MockKohnShamRhoWfsReader`` objects bound to ``ksd_fname``."""

    def build(**kwargs):
        # ksd is fixed here; passing it again raises TypeError as before
        return MockKohnShamRhoWfsReader(ksd=ksd_fname, **kwargs)

    return build

186 

187 

@pytest.fixture
def mock_time_density_matrices(test_system, ksd_fname):
    """Factory fixture building ``MockTimeDensityMatrices`` bound to ``ksd_fname``."""

    def build(**kwargs):
        return MockTimeDensityMatrices(ksd=ksd_fname, **kwargs)

    return build

196 

197 

@pytest.fixture
def mock_convolution_density_matrices(test_system, ksd_fname):
    """Factory fixture building ``MockConvolutionDensityMatrices`` bound to ``ksd_fname``."""

    def build(**kwargs):
        return MockConvolutionDensityMatrices(ksd=ksd_fname, **kwargs)

    return build

206 

207 

@pytest.fixture
def mock_frequency_density_matrices(test_system, ksd_fname):
    """Factory fixture building ``MockFrequencyDensityMatrices`` bound to ``ksd_fname``."""

    def build(**kwargs):
        return MockFrequencyDensityMatrices(ksd=ksd_fname, **kwargs)

    return build

216 

217 

@pytest.fixture
def mock_response(test_system, ksd_fname):
    """Factory fixture building ``MockResponse`` objects bound to ``ksd_fname``."""

    def build(**kwargs):
        return MockResponse(ksd=ksd_fname, **kwargs)

    return build