Coverage for tests/conftest.py: 79%
137 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3import os
4import warnings
5from contextlib import contextmanager
6from pathlib import Path
8from ase.io.ulm import Reader
9from gpaw.mpi import world
11import pytest
13from tests import frho, get_permanent_test_file, permanent_test_files
14from tests.mock import (MockKohnShamRhoWfsReader, MockTimeDensityMatrices,
15 MockConvolutionDensityMatrices, MockVoronoiWeights,
16 MockFrequencyDensityMatrices, MockResponse)
@contextmanager
def execute_in_tmp_path(request, tmp_path_factory):
    """Temporarily change into a fresh temporary directory shared by all ranks.

    Rank 0 creates the directory via ``tmp_path_factory.mktemp``, naming it
    after the object ``getattr(request, request.scope)`` (the test function
    for function-scoped fixtures, the module for module-scoped ones). The
    path is then broadcast to the other ranks. The previous working
    directory is restored on exit, also when the body raises.
    """
    from gpaw.mpi import broadcast

    tmp_dir = None
    if world.rank == 0:
        owner = getattr(request, request.scope)
        tmp_dir = tmp_path_factory.mktemp(owner.__name__)
    # Non-root ranks contribute None and receive rank 0's path
    tmp_dir = broadcast(tmp_dir)

    previous_cwd = os.getcwd()
    os.chdir(tmp_dir)
    try:
        yield tmp_dir
    finally:
        os.chdir(previous_cwd)
@pytest.fixture(scope='function')
def in_tmp_dir(request, tmp_path_factory):
    """Yield a per-test temporary directory that is also the working directory."""
    with execute_in_tmp_path(request, tmp_path_factory) as tmp_dir:
        yield tmp_dir
@pytest.fixture(scope='session')
def cache_path(request, tmp_path_factory) -> Path:
    """ Path for cached test data

    If the ``GPW_TEST_FILES`` environment variable is set to a non-empty
    value, that directory is returned so cached files can be reused between
    pytest sessions. Otherwise a warning is issued and rank 0 creates a
    numbered temporary directory, named after the sanitized ``request.node.name``
    (truncated to 30 characters). That path is broadcast to all ranks.
    """
    import re
    from gpaw.mpi import broadcast

    pathstr: str | None = os.environ.get('GPW_TEST_FILES')
    # Treat an empty value as unset: Path('') would resolve to the current
    # working directory and silently scatter cache files there.
    if pathstr:
        return Path(pathstr)

    warnings.warn(
        'Note that you can speed up the tests by reusing gpw-files '
        'from an earlier pytest session: '
        'set the $GPW_TEST_FILES environment variable and the '
        'files will be written to/read from that folder. ')
    if world.rank == 0:
        name = request.node.name
        # Replace characters that are not word characters so the name is a safe directory name
        name = re.sub(r'[\W]', '_', name)
        name = 'rhodent' if name == '' else name  # Explicitly needed in pytest 8.0.2
        MAXVAL = 30
        name = name[:MAXVAL]
        path = tmp_path_factory.mktemp(name, numbered=True)
        broadcast(path)
    else:
        path = broadcast(None)

    return path
@pytest.fixture
def cached_file_if_existing(cache_path, test_system):
    """Provide a helper that locates a file in the per-system cache directory.

    The returned callable maps ``fname`` to
    ``(cache_path / test_system / fname, exists)``. The existence check is
    done on rank 0 only and broadcast, so all ranks get the same answer.
    """

    def _cached_file_if_existing(fname: str, check_empty: bool = False) -> tuple[Path, bool]:
        """ Return path to datafile in cache directory and whether it exists

        With ``check_empty=True`` the path is expected to be a directory, and
        an existing but empty directory is reported as not existing.
        """
        from gpaw.mpi import broadcast

        parent = cache_path / test_system
        # parents=True: cache_path may come from $GPW_TEST_FILES and need not
        # exist yet; exist_ok=True since every rank runs this line.
        parent.mkdir(parents=True, exist_ok=True)
        path = parent / fname
        if world.rank == 0:
            exists = path.exists()
            if exists and check_empty:
                # An empty directory counts as a cache miss
                exists = next(path.iterdir(), None) is not None
            broadcast(exists)
        else:
            exists = broadcast(None)

        return path, exists

    return _cached_file_if_existing
@pytest.fixture
def gpw_fname(test_system):
    """Return the permanent test file stored under ``'gpw_fname'`` for this system."""
    fname = get_permanent_test_file(test_system, 'gpw_fname')
    return fname
@pytest.fixture
def fdm_fname(test_system):
    """Return the permanent test file stored under ``'fdm_fname'`` for this system."""
    fname = get_permanent_test_file(test_system, 'fdm_fname')
    return fname
@pytest.fixture
def wfs_fname(test_system):
    """Return the permanent test file stored under ``'wfs_fname'`` for this system."""
    fname = get_permanent_test_file(test_system, 'wfs_fname')
    return fname
@pytest.fixture
def wfssnap_fname(test_system):
    """Return the permanent test file stored under ``'wfssnap_fname'`` for this system."""
    fname = get_permanent_test_file(test_system, 'wfssnap_fname')
    return fname
@pytest.fixture
def ref_density_matrix(test_system):
    """Return the permanent reference data stored under ``'ref_density_matrix'``."""
    reference = get_permanent_test_file(test_system, 'ref_density_matrix')
    return reference
@pytest.fixture
def ref_voronoi(test_system):
    """Return the permanent reference data stored under ``'ref_voronoi'``."""
    reference = get_permanent_test_file(test_system, 'ref_voronoi')
    return reference
@pytest.fixture
def ref_hcdist(test_system):
    """Return the permanent reference data stored under ``'ref_hcdist'``."""
    reference = get_permanent_test_file(test_system, 'ref_hcdist')
    return reference
@pytest.fixture
def ref_energy(test_system):
    """Return the permanent reference data stored under ``'ref_energy'``."""
    reference = get_permanent_test_file(test_system, 'ref_energy')
    return reference
@pytest.fixture
def ksd_fname(cached_file_if_existing, test_system):
    """Return the permanent test file stored under ``'ksd_fname'`` for this system."""
    # NOTE(review): cached_file_if_existing is requested but unused here; it
    # still forces the cache_path fixture to be set up. Kept so the fixture's
    # dependencies stay unchanged.
    fname = get_permanent_test_file(test_system, 'ksd_fname')
    return fname
@pytest.fixture
def dm_sinc(test_system):
    """Return the permanent test file stored under ``'dm_sinc'`` for this system."""
    fname = get_permanent_test_file(test_system, 'dm_sinc')
    return fname
@pytest.fixture
def dm_delta(test_system):
    """Return the permanent test file stored under ``'dm_delta'`` for this system."""
    fname = get_permanent_test_file(test_system, 'dm_delta')
    return fname
@pytest.fixture
def frho_dname(cached_file_if_existing, test_system, ksd_fname):
    """Return the cached ``frho`` directory for the current test system.

    The directory is treated as missing when it is absent or empty. In that
    case ``frho`` is run with the KSD file and the system's fdm file to
    populate it.
    """
    path, exists = cached_file_if_existing('frho', check_empty=True)
    if not exists:
        # Resolve the fdm file the same way the ``fdm_fname`` fixture does,
        # instead of indexing the raw permanent_test_files table.
        fdm = get_permanent_test_file(test_system, 'fdm_fname')
        # Extract the frho
        frho(ksd_fname=ksd_fname, fdm_fname=fdm, frho_dname=path)
    return path
@pytest.fixture
def mock_voronoi(test_system, ksd_fname):
    """Return a factory for ``MockVoronoiWeights`` sized from the KSD file.

    ``nn`` is taken as ``eig_un.shape[2]`` of the KSD file; any keyword
    arguments given to the factory are forwarded to ``MockVoronoiWeights``.
    """
    with Reader(ksd_fname) as ksd_reader:
        nn = ksd_reader.eig_un.shape[2]

    return lambda **kwargs: MockVoronoiWeights(nn=nn, **kwargs)
@pytest.fixture
def mock_ks_rho_reader(test_system, ksd_fname):
    """Return a factory building ``MockKohnShamRhoWfsReader`` bound to ``ksd_fname``.

    Keyword arguments given to the factory are forwarded to the constructor.
    """
    return lambda **kwargs: MockKohnShamRhoWfsReader(ksd=ksd_fname, **kwargs)
@pytest.fixture
def mock_time_density_matrices(test_system, ksd_fname):
    """Return a factory building ``MockTimeDensityMatrices`` bound to ``ksd_fname``.

    Keyword arguments given to the factory are forwarded to the constructor.
    """
    return lambda **kwargs: MockTimeDensityMatrices(ksd=ksd_fname, **kwargs)
@pytest.fixture
def mock_convolution_density_matrices(test_system, ksd_fname):
    """Return a factory building ``MockConvolutionDensityMatrices`` bound to ``ksd_fname``.

    Keyword arguments given to the factory are forwarded to the constructor.
    """
    return lambda **kwargs: MockConvolutionDensityMatrices(ksd=ksd_fname, **kwargs)
@pytest.fixture
def mock_frequency_density_matrices(test_system, ksd_fname):
    """Return a factory building ``MockFrequencyDensityMatrices`` bound to ``ksd_fname``.

    Keyword arguments given to the factory are forwarded to the constructor.
    """
    return lambda **kwargs: MockFrequencyDensityMatrices(ksd=ksd_fname, **kwargs)
@pytest.fixture
def mock_response(test_system, ksd_fname):
    """Return a factory building ``MockResponse`` bound to ``ksd_fname``.

    Keyword arguments given to the factory are forwarded to the constructor.
    """
    return lambda **kwargs: MockResponse(ksd=ksd_fname, **kwargs)