Coverage for tests/unittests/test_utils.py: 82%
49 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3import numpy as np
4import pytest
7from rhodent.utils import concatenate_indices, ParallelMatrix, partial_format, format_string_to_regex
@pytest.mark.parametrize('dims, indices_list, ref_indices_concat, ref_new_indices_list', [
    ((20, ), [slice(10, 12), slice(12, 19)], slice(10, 19), [slice(0, 2), slice(2, 9)]),
    ((20, 15),
     [(slice(10, 12), slice(3, 6)), (slice(12, 19), slice(3, 6))],
     (slice(10, 19), slice(3, 6)),
     [(slice(0, 2), slice(0, 3)), (slice(2, 9), slice(0, 3))]),
    ((20, 15),
     [(slice(10, 12), slice(3, 6)), (slice(12, 19), slice(1, 2))],
     (slice(10, 19), slice(1, 6)),
     [(slice(0, 2), slice(2, 5)), (slice(2, 9), slice(0, 1))]),
])
def test_concatenate_indices(dims, indices_list, ref_indices_concat, ref_new_indices_list):
    """Check that concatenate_indices merges a list of indices into one index.

    The merged index returned by ``concatenate_indices`` must equal
    ``ref_indices_concat`` and the returned per-item indices must equal
    ``ref_new_indices_list``. In addition, the new indices applied to the
    array sliced with the merged index must select data with the same total
    as the original indices applied to the full array.
    """
    # Fixed seed so that any failure can be reproduced exactly.
    rng = np.random.default_rng(42)
    A = rng.random(dims)

    # Reference total over the original selections; computed before calling
    # concatenate_indices, exactly as the check was originally ordered.
    value = sum(np.sum(A[indices]) for indices in indices_list)

    indices_concat, new_indices_list = concatenate_indices(indices_list)

    assert indices_concat == ref_indices_concat
    assert new_indices_list == ref_new_indices_list

    # The new indices are relative to the sub-array selected by the merged index.
    B = A[indices_concat]
    new_value = sum(np.sum(B[indices]) for indices in new_indices_list)

    assert abs(value - new_value) < 1e-10
@pytest.mark.parametrize('shape', [
    (2503, 2400, 165),
    (2503, 2400, 1),
    (2503, 2400, 3),
    (3, 6, 333, 123, 940),
])
def test_parallel_matrix(shape):
    """Compare a distributed ParallelMatrix product with the serial NumPy product.

    ``shape`` encodes a (possibly batched) product ``(..., n, k) @ (..., k, m)``
    as the tuple ``(..., n, k, m)``. The left operand uses every axis except the
    last one. The right operand uses the leading batch axes plus the final two.
    Only rank 0 supplies input arrays and checks the result.

    To see where the parallel speed-up matters, try e.g.
    shape = (1079, 7618, 7618), corresponding to Al586, or
    shape = (3423, 10548, 10548), corresponding to Ag586. Non-root ranks
    finish the test while the root rank computes the serial comparison.
    """
    from gpaw.mpi import world

    if world.size < 2:
        pytest.skip('Parallel only test')

    is_root = world.rank == 0

    def distributed(matrix_shape):
        # Random data lives on the root rank only; other ranks pass None.
        data = np.random.rand(*matrix_shape) if is_root else None
        return ParallelMatrix(matrix_shape, float, array=data)

    lhs = distributed(shape[:-1])
    rhs = distributed(shape[:-3] + shape[-2:])

    # Every rank takes part in the parallel product.
    product = lhs @ rhs

    if is_root:
        reference = lhs.array @ rhs.array
        np.testing.assert_allclose(product.array, reference, atol=1e-12)
@pytest.mark.parametrize('fmt, kwargs', [
    ('pulserho_pf{pulsefreq:.2f}/t{time:09.1f}{tag}.npy', dict(pulsefreq=3.8, time=30e3, tag='')),
    ('{which}_density_w{freq:05.2f}.cube', dict(freq=3.8, which='induced')),
    ('{foo:.2f}_dens{bar}/{baz}{qux}', dict(foo=4321, bar='fdsa', baz='eee', qux=342.643)),
])
def test_partial_format(fmt, kwargs):
    """Check that partial_format followed by str.format equals a full str.format.

    The keyword arguments are split into two groups: the last-inserted pair on
    its own, and all remaining pairs. One group is substituted with
    ``partial_format`` and the other with ``str.format``, in both orders.
    """
    expected = fmt.format(**kwargs)

    head = dict(kwargs)
    # dict.popitem() removes and returns the most recently inserted pair.
    tail = dict([head.popitem()])

    for first, second in ((head, tail), (tail, head)):
        assert partial_format(fmt, **first).format(**second) == expected
@pytest.mark.parametrize('fmt, kwargs', [
    ('pulserho_pf{pulsefreq:.2f}/t{time:09.1f}{tag}.npy', dict(pulsefreq=3.8, time=30e3, tag='')),
    ('pulserho_pf{pulsefreq:.2f}/t{time:09.1f}{tag}.npy', dict(pulsefreq=3.8, time=30e3, tag='-Iomega')),
    ('pulserho_pf{pulsefreq:.2f}pw{pulsefwhm:.2f}/t{time:09.1f}{tag}.npy',
     dict(pulsefreq=4.0, pulsefwhm=5.0, time=12e3, tag='-Iomega')),
    ('{which}_density_w{freq:05.2f}.cube', dict(freq=3.8, which='induced')),
])
def test_format_string_to_regex(fmt, kwargs):
    """Check that format_string_to_regex recovers the fields of a formatted string.

    The string produced by ``fmt.format(**kwargs)`` must fully match the regex
    built from ``fmt``. Each named group, converted with the type of the
    corresponding input value, must equal that input value.
    """
    rendered = fmt.format(**kwargs)

    pattern = format_string_to_regex(fmt)
    match = pattern.fullmatch(rendered)
    assert match is not None

    # Captured groups are strings; convert each one using the type of the
    # value that was originally formatted into that field.
    parsed_kwargs = {}
    for name, text in match.groupdict().items():
        parsed_kwargs[name] = type(kwargs[name])(text)

    assert parsed_kwargs == kwargs