Coverage for tests/unittests/test_utils.py: 82%

49 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

1from __future__ import annotations 

2 

3import numpy as np 

4import pytest 

5 

6 

7from rhodent.utils import concatenate_indices, ParallelMatrix, partial_format, format_string_to_regex 

8 

9 

@pytest.mark.parametrize('dims, indices_list, ref_indices_concat, ref_new_indices_list', [
    ((20, ), [slice(10, 12), slice(12, 19)], slice(10, 19), [slice(0, 2), slice(2, 9)]),
    ((20, 15),
     [(slice(10, 12), slice(3, 6)), (slice(12, 19), slice(3, 6))],
     (slice(10, 19), slice(3, 6)),
     [(slice(0, 2), slice(0, 3)), (slice(2, 9), slice(0, 3))]),
    ((20, 15),
     [(slice(10, 12), slice(3, 6)), (slice(12, 19), slice(1, 2))],
     (slice(10, 19), slice(1, 6)),
     [(slice(0, 2), slice(2, 5)), (slice(2, 9), slice(0, 1))]),
])
def test_concatenate_indices(dims, indices_list, ref_indices_concat, ref_new_indices_list):
    """Check ``concatenate_indices`` against reference indices and data sums.

    The function is expected to return one set of indices spanning all
    entries of ``indices_list``, together with a list of indices that
    address the same data relative to that span. Both return values are
    compared with the references, and the sum of the data selected
    through the new indices must agree with the sum selected through
    the original ones.
    """
    data = np.random.rand(*dims)

    # Reference total, taken directly from the full array
    total_direct = sum(np.sum(data[indices]) for indices in indices_list)

    indices_concat, new_indices_list = concatenate_indices(indices_list)

    assert indices_concat == ref_indices_concat
    assert ref_new_indices_list == new_indices_list

    # Same total, but going through the concatenated sub-array
    sub_data = data[indices_concat]
    total_via_concat = sum(np.sum(sub_data[indices]) for indices in new_indices_list)

    assert abs(total_direct - total_via_concat) < 1e-10

39 

40 

@pytest.mark.parametrize('shape', [
    (2503, 2400, 165),
    (2503, 2400, 1),
    (2503, 2400, 3),
    (3, 6, 333, 123, 940),
])
def test_parallel_matrix(shape):
    """Compare the ``ParallelMatrix`` product with a serial NumPy product.

    ``shape`` encodes both operands: the left operand uses all but the
    last dimension, the right operand uses the leading (batch) dimensions
    followed by the last two. The test is skipped unless it runs on at
    least two ranks.

    The speed-up of the parallel product becomes noticeable for sizes
    such as shape = (1079, 7618, 7618), corresponding to Al586, or
    shape = (3423, 10548, 10548), corresponding to Ag586. Ranks other
    than the root return right after the parallel product, while the
    root rank goes on to compute the serial reference.
    """
    from gpaw.mpi import world

    if world.size < 2:
        pytest.skip('Parallel only test')

    on_root = world.rank == 0

    lhs_shape = shape[:-1]
    rhs_shape = shape[:-3] + shape[-2:]

    # Input data is only supplied on the root rank
    lhs = ParallelMatrix(lhs_shape, float,
                         array=np.random.rand(*lhs_shape) if on_root else None)
    rhs = ParallelMatrix(rhs_shape, float,
                         array=np.random.rand(*rhs_shape) if on_root else None)

    # Product computed in parallel
    product = lhs @ rhs

    if not on_root:
        return

    # Serial reference on the root rank
    reference = lhs.array @ rhs.array
    np.testing.assert_allclose(product.array, reference, atol=1e-12)

76 

77 

@pytest.mark.parametrize('fmt, kwargs', [
    ('pulserho_pf{pulsefreq:.2f}/t{time:09.1f}{tag}.npy', dict(pulsefreq=3.8, time=30e3, tag='')),
    ('{which}_density_w{freq:05.2f}.cube', dict(freq=3.8, which='induced')),
    ('{foo:.2f}_dens{bar}/{baz}{qux}', dict(foo=4321, bar='fdsa', baz='eee', qux=342.643)),
])
def test_partial_format(fmt, kwargs):
    """Check that formatting in two steps equals formatting in one step.

    The keyword arguments are split into two groups: the last item of
    ``kwargs`` and everything else. ``partial_format`` is applied with one
    group and ``str.format`` with the other, in both orders, and each
    result must equal the string produced by a single ``str.format`` call.
    """
    # Reference: format everything at once
    expected = fmt.format(**kwargs)

    # Split off the last item; work on a copy so that kwargs is untouched
    most_kwargs = dict(kwargs)
    last_key, last_value = most_kwargs.popitem()
    one_kwarg = {last_key: last_value}

    # Partially format with one group, finish with the other; then swap
    for first, second in ((most_kwargs, one_kwarg), (one_kwarg, most_kwargs)):
        result = partial_format(fmt, **first).format(**second)
        assert expected == result

99 

100 

@pytest.mark.parametrize('fmt, kwargs', [
    ('pulserho_pf{pulsefreq:.2f}/t{time:09.1f}{tag}.npy', dict(pulsefreq=3.8, time=30e3, tag='')),
    ('pulserho_pf{pulsefreq:.2f}/t{time:09.1f}{tag}.npy', dict(pulsefreq=3.8, time=30e3, tag='-Iomega')),
    ('pulserho_pf{pulsefreq:.2f}pw{pulsefwhm:.2f}/t{time:09.1f}{tag}.npy',
     dict(pulsefreq=4.0, pulsefwhm=5.0, time=12e3, tag='-Iomega')),
    ('{which}_density_w{freq:05.2f}.cube', dict(freq=3.8, which='induced')),
])
def test_format_string_to_regex(fmt, kwargs):
    """Round-trip values through a format string and its derived regex.

    The format string is filled in with ``kwargs``. The regex built by
    ``format_string_to_regex`` must match the resulting string in full,
    and its named groups, converted back to the types of the original
    values, must reproduce ``kwargs``.
    """
    filled_in = fmt.format(**kwargs)

    # Parse the formatted string with the regex built from the same format
    pattern = format_string_to_regex(fmt)
    match = pattern.fullmatch(filled_in)
    assert match is not None

    # The groups are strings; cast each one to the type of the original value
    parsed_kwargs = {}
    for name, text in match.groupdict().items():
        caster = type(kwargs[name])
        parsed_kwargs[name] = caster(text)

    assert kwargs == parsed_kwargs

120 

121 assert kwargs == parsed_kwargs