Coverage for rhodent/utils/__init__.py: 88%
282 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-08-01 16:57 +0000
1from __future__ import annotations
3import re
4from contextlib import nullcontext
5from operator import itemgetter
6from pathlib import Path
7from typing import Any, Callable, Generic, Iterable, NamedTuple, TypeVar
8import numpy as np
9from numpy.typing import NDArray
10from numpy._typing import _DTypeLike as DTypeLike # parametrizable wrt generic
12from ase.io.ulm import open
13from ase.parallel import parprint
14from gpaw.lcaotddft.ksdecomposition import KohnShamDecomposition
15from gpaw.lcaotddft.laser import GaussianPulse
16from gpaw.mpi import SerialCommunicator, world
17from gpaw.tddft.units import fs_to_au, au_to_eV
19from .logging import Logger, NoLogger
20from .result import Result, ResultKeys
21from ..perturbation import Perturbation
22from ..typing import Array1D, Communicator
24__all__ = [
25 'Logger',
26 'Result',
27 'ResultKeys',
28]
DTypeT = TypeVar('DTypeT', bound=np.generic, covariant=True)


class ParallelMatrix(Generic[DTypeT]):

    """ Distributed array, with data on the root rank.

    Parameters
    ----------
    shape
        Shape of array.
    dtype
        Dtype of array.
    comm
        MPI communicator. Defaults to the world communicator.
    array
        Array on root rank of the communicator. Must be ``None`` on other ranks.
    """

    def __init__(self,
                 shape: tuple[int, ...],
                 dtype: DTypeLike[DTypeT],
                 comm: Communicator | None = None,
                 array: NDArray[DTypeT] | None = None):
        if comm is None:
            comm = world
        self.comm = comm
        self.shape = shape
        self.dtype = np.dtype(dtype)

        self._array: NDArray[DTypeT] | None
        if self.root:
            # The root rank owns the data; it must agree with the metadata
            assert array is not None
            assert array.shape == shape
            assert array.dtype == np.dtype(dtype)
            self._array = array
        else:
            assert array is None
            self._array = None

    @property
    def array(self) -> NDArray[DTypeT]:
        """ Array with data. May only be called on the root rank. """
        if not self.root:
            raise RuntimeError('May only be called on root')
        assert self._array is not None
        return self._array

    @property
    def root(self) -> bool:
        """ Whether this rank is the root rank. """
        return self.comm.rank == 0

    @property
    def T(self) -> ParallelMatrix[DTypeT]:
        """ Matrix transpose, i.e. the last two axes are swapped.

        Leading (batch) axes are left in place, consistent with
        how :meth:`__matmul__` treats them.
        """
        shape = self.shape[:-2] + self.shape[-2:][::-1]
        array_T: NDArray[DTypeT] | None = None
        if self.root:
            array = self.array
            if array.ndim >= 2:
                # ndarray.T would reverse *all* axes, which disagrees with
                # ``shape`` above whenever there are batch dimensions
                array_T = np.swapaxes(array, -1, -2)
            else:
                array_T = array.T
        return ParallelMatrix(shape=shape, dtype=self.dtype, comm=self.comm,
                              array=array_T)

    def broadcast(self) -> NDArray[DTypeT]:
        """ Broadcasted data.

        Returns
        -------
        Array holding the data of the root rank, available on every rank.
        """
        if self.root:
            A = np.ascontiguousarray(self.array)
        else:
            # Receive buffer on the non-root ranks
            A = np.zeros(self.shape, self.dtype)

        self.comm.broadcast(A, 0)

        return A

    def __matmul__(self, other) -> ParallelMatrix[DTypeT]:
        """ Perform matrix multiplication in parallel.

        Both operands are broadcast to all ranks, each rank computes a
        slice of the columns of the product, and the slices are summed
        to the root rank.

        Raises
        ------
        NotImplementedError
            If :attr:`other` is not a :class:`ParallelMatrix`.
        """
        if not isinstance(other, ParallelMatrix):
            raise NotImplementedError

        assert self.dtype == other.dtype

        A = self.broadcast()
        B = other.broadcast()

        # Allocate array for result
        ni, nj = A.shape[-2:]
        nk = B.shape[-1]
        C_shape = np.broadcast_shapes(A.shape[:-2], B.shape[:-2]) + (ni, nk)
        C = np.zeros(C_shape, self.dtype)

        # Determine slice for ranks (ceiling division; trailing ranks may
        # end up with an empty slice)
        stridek = (nk + self.comm.size - 1) // self.comm.size
        slicek = slice(stridek * self.comm.rank, stridek * (self.comm.rank + 1))

        # Perform the matrix multiplication
        C[..., :, slicek] = A @ B[..., :, slicek]

        # Sum to root rank
        self.comm.sum(C, 0)

        result = ParallelMatrix(C_shape, self.dtype, comm=self.comm,
                                array=C if self.root else None)
        return result
def gauss_ij_with_filter(energy_i: np.typing.ArrayLike,
                         energy_j: np.typing.ArrayLike,
                         sigma: float,
                         fltthresh: float | None = None,
                         ) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    r""" Evaluate the normalized Gaussian of all pairwise energy differences

    .. math::

        M_{ij}
        = \frac{1}{\sqrt{2 \pi \sigma^2}}
        \exp\left(-\frac{
            \left(\varepsilon_i - \varepsilon_j\right)^2
        }{
            2 \sigma^2
        }\right)

    This is the kernel used for Gaussian broadening. When a threshold is
    given, only rows :math:`i` in which at least one exponent exceeds the
    threshold are evaluated; the remaining rows are left as zeros.

    Parameters
    ----------
    energy_i
        Energies :math:`\varepsilon_i`.
    energy_j
        Energies :math:`\varepsilon_j`.
    sigma
        Gaussian broadening width :math:`\sigma`.
    fltthresh
        Filtering threshold, compared against the exponent.

    Returns
    -------
    Matrix :math:`M_{ij}`, and boolean filter over the index :math:`i`.
    """
    eps_i = np.asarray(energy_i)
    eps_j = np.asarray(energy_j)

    prefactor = 1.0 / (sigma * np.sqrt(2 * np.pi))

    scaled_ij = (eps_i[:, np.newaxis] - eps_j[np.newaxis, :]) / sigma
    exponent_ij = -0.5 * scaled_ij ** 2

    if fltthresh is None:
        # No filtering; every row is kept
        keep_i = np.ones(eps_i.shape, dtype=bool)
        return prefactor * np.exp(exponent_ij), keep_i  # type: ignore

    # Rows where no exponent exceeds the threshold are never exponentiated
    keep_i = np.any(exponent_ij > fltthresh, axis=1)
    M_ij = np.zeros_like(exponent_ij)
    M_ij[keep_i] = prefactor * np.exp(exponent_ij[keep_i])

    return M_ij, keep_i  # type: ignore
def gauss_ij(energy_i: np.typing.ArrayLike,
             energy_j: np.typing.ArrayLike,
             sigma: float) -> NDArray[np.float64]:
    r""" Evaluate the normalized Gaussian of all pairwise energy differences

    .. math::

        M_{ij}
        = \frac{1}{\sqrt{2 \pi \sigma^2}}
        \exp\left(-\frac{
            \left(\varepsilon_i - \varepsilon_j\right)^2
        }{
            2 \sigma^2
        }\right),

    which is the kernel used for Gaussian broadening.

    Parameters
    ----------
    energy_i
        Energies :math:`\varepsilon_i`.
    energy_j
        Energies :math:`\varepsilon_j`.
    sigma
        Gaussian broadening width :math:`\sigma`.

    Returns
    -------
    Matrix :math:`M_{ij}`.
    """
    # Without a threshold the filter is trivial, so only the matrix is of interest
    return gauss_ij_with_filter(energy_i, energy_j, sigma)[0]
def broaden_n2e(M_n: np.typing.ArrayLike,
                energy_n: np.typing.ArrayLike,
                energy_e: np.typing.ArrayLike,
                sigma: float) -> NDArray[np.float64]:
    r""" Broaden discrete values onto an energy grid

    .. math::

        M(\varepsilon_e)
        = \sum_n M_n \frac{1}{\sqrt{2 \pi \sigma^2}}
        \exp\left(-\frac{
            \left(\varepsilon_n - \varepsilon_e\right)^2
        }{
            2 \sigma^2
        }\right).

    Parameters
    ----------
    M_n
        Values :math:`M_n`, one per energy :math:`\varepsilon_n`.
    energy_n
        Energies :math:`\varepsilon_n`.
    energy_e
        Energy grid :math:`\varepsilon_e`.
    sigma
        Gaussian broadening width :math:`\sigma`.

    Returns
    -------
    :math:`M(\varepsilon_e)`
    """
    values_n = np.asarray(M_n)
    weights_ne, keep_n = gauss_ij_with_filter(energy_n, energy_e, sigma)

    # Contract over n, restricted to the rows selected by the filter
    return np.einsum('n,ne->e', values_n[keep_n], weights_ne[keep_n], optimize=True)
def broaden_xn2e(M_xn: np.typing.ArrayLike,
                 energy_n: np.typing.ArrayLike,
                 energy_e: np.typing.ArrayLike,
                 sigma: float) -> NDArray[np.float64]:
    r""" Broaden the last axis of an array onto an energy grid

    .. math::

        M(\varepsilon_e)
        = \sum_n M_n \frac{1}{\sqrt{2 \pi \sigma^2}}
        \exp\left(-\frac{
            \left(\varepsilon_n - \varepsilon_e\right)^2
        }{
            2 \sigma^2
        }\right).

    All leading axes of :attr:`M_xn` are carried through unchanged.

    Parameters
    ----------
    M_xn
        Values, with the last axis corresponding to :math:`\varepsilon_n`.
    energy_n
        Energies :math:`\varepsilon_n`.
    energy_e
        Energy grid :math:`\varepsilon_e`.
    sigma
        Gaussian broadening width :math:`\sigma`.

    Returns
    -------
    :math:`M(\varepsilon_e)`, with the same leading axes as :attr:`M_xn`.
    """
    values_xn = np.asarray(M_xn)
    weights_ne, keep_n = gauss_ij_with_filter(energy_n, energy_e, sigma)

    # Collapse all leading axes into one, so that a single contraction suffices
    flat_xn = values_xn.reshape((-1, len(keep_n)))
    flat_xe = np.einsum('xn,ne->xe',
                        flat_xn[:, keep_n],
                        weights_ne[keep_n],
                        optimize=True)

    # Restore the leading axes
    leading_shape = values_xn.shape[:-1]
    return flat_xe.reshape(leading_shape + (-1, ))
def broaden_ia2ou(M_ia: np.typing.ArrayLike,
                  energy_i: np.typing.ArrayLike,
                  energy_a: np.typing.ArrayLike,
                  energy_o: np.typing.ArrayLike,
                  energy_u: np.typing.ArrayLike,
                  sigma: float) -> NDArray[np.float64]:
    r""" Broaden the first two axes of an array onto two energy grids.

    .. math::

        M(\varepsilon_o, \varepsilon_u)
        = \sum_{ia} M_{ia} \frac{1}{\sqrt{2 \pi \sigma^2}}
        \exp\left(-\frac{
            (\varepsilon_i - \varepsilon_o)^2
        }{
            2 \sigma^2
        }\right)
        \exp\left(-\frac{
            (\varepsilon_a - \varepsilon_u)^2
        }{
            2 \sigma^2
        }\right)

    Any axes of :attr:`M_ia` beyond the first two are carried through unchanged.

    Returns
    -------
    :math:`M(\varepsilon_o, \varepsilon_u)`.
    """
    values = np.asarray(M_ia)
    trailing_shape = values.shape[2:]

    # Collapse all trailing axes into a single axis x
    values_iax = values.reshape(values.shape[:2] + (-1, ))

    weights_io, keep_i = gauss_ij_with_filter(energy_i, energy_o, sigma)
    weights_au, keep_a = gauss_ij_with_filter(energy_a, energy_u, sigma)

    # Restrict to the filtered rows along i, then along a
    selected_iax = values_iax[keep_i, :][:, keep_a]
    M_oux = np.einsum('iax,io,au->oux', selected_iax,
                      weights_io[keep_i], weights_au[keep_a],
                      optimize=True, order='C')

    # Restore the trailing axes
    return M_oux.reshape(M_oux.shape[:2] + trailing_shape)
def get_array_filter(values: Array1D[np.float64] | list[float],
                     filter_values: Array1D[np.float64] | list[float] | None,
                     ) -> slice | Array1D[np.bool_]:
    """ Build an index object selecting the entries closest to the requested values.

    Parameters
    ----------
    values
        Array of values, e.g. linspace of times or frequencies.
    filter_values
        List of values that one wishes to extract. For each of them, the
        closest entry of :attr:`values` is selected.

    Returns
    -------
    Object that can be used to index values and arrays with the same shape as values.
    A slice selecting everything is returned if either argument is empty, or if
    :attr:`filter_values` is ``None``.
    """
    grid = np.array(values)

    # Nothing to select from, or nothing requested: select everything
    if len(values) == 0 or filter_values is None or len(filter_values) == 0:
        return slice(None)

    mask: NDArray[np.bool_] = np.zeros(len(values), dtype=bool)
    for target in filter_values:
        # Mark the entry closest to the requested value
        closest = np.argmin(np.abs(grid - target))
        mask[closest] = True

    return mask
def filter_array(values: Array1D[np.float64] | list[float],
                 filter_values: Array1D[np.float64] | list[float] | None,
                 ) -> Array1D[np.float64]:
    """ Pick the entries of an array that are closest to :attr:`filter_values`.

    Parameters
    ----------
    values
        Array of values, e.g. linspace of times or frequencies.
    filter_values
        List of values that one wishes to extract. For each of them, the
        closest entry of :attr:`values` is selected.

    Returns
    -------
    Filtered array.
    """
    as_array = np.array(values)
    selection = get_array_filter(as_array, filter_values)
    return as_array[selection]  # type: ignore
def two_communicator_sizes(*comm_sizes) -> tuple[int, int]:
    """ Resolve the sizes of two communicators that factorize the world size.

    Exactly two sizes must be given. The string ``'world'`` stands for the
    world size, and ``-1`` for the world size divided by the other size.

    Returns
    -------
    The two resolved sizes, whose product equals ``world.size``.
    """
    assert len(comm_sizes) == 2
    first, second = (world.size if size == 'world' else size for size in comm_sizes)

    # Only the first placeholder found is resolved
    if first == -1:
        first = world.size // second
    elif second == -1:
        second = world.size // first

    resolved: list[int] = [first, second]
    assert np.prod(resolved) == world.size, \
        f'Communicator sizes must factorize world size {world.size} '\
        'but they are ' + ' and '.join([str(s) for s in resolved]) + '.'
    return first, second
def two_communicators(*comm_sizes) -> tuple[Communicator, Communicator]:
    """ Split the world communicator into two communicators.

    The sizes must satisfy ``comm_sizes[0] * comm_sizes[1] = world.size``.

    The second communicator has the ranks in sequence, while the ranks of
    the first communicator are spaced by the size of the second one.

    Example
    -------

    >>> world.size == 8
    >>> two_communicators(2, 4)

    This gives::

        [0, 4]
        [1, 5]
        [2, 6]
        [3, 7]

    and::

        [0, 1, 2, 3]
        [4, 5, 6, 7]
    """
    size_first, size_second = two_communicator_sizes(*comm_sizes)

    # Trivial splits need no new communicators
    if size_first == 1:
        return (SerialCommunicator(), world)  # type: ignore
    if size_first == world.size:
        return (world, SerialCommunicator())  # type: ignore

    assert world.size % size_first == 0, world.size

    # Position of this rank within its (sequential) second communicator
    offset = world.rank % size_second

    # First communicator: ranks sharing the same offset, spaced by size_second
    ranks_first = list(range(offset, offset + size_first * size_second, size_second))
    # Second communicator: block of consecutive ranks containing this rank
    block_start = world.rank - offset
    ranks_second = list(range(block_start, block_start + size_second * 1, 1))

    # Created in the same order on all ranks: first communicator, then second
    comm_first = world.new_communicator(ranks_first)
    comm_second = world.new_communicator(ranks_second)
    return comm_first, comm_second
def detect_repeatrange(n: int,
                       stride: int,
                       verbose: bool = True) -> slice | None:
    """ If an array of length :attr:`n` is not divisible by the stride :attr:`stride`
    then some work will have to be repeated.

    Parameters
    ----------
    n
        Length of the array.
    stride
        Stride in which the array is processed.
    verbose
        Whether to print a message when a repeat range is detected.

    Returns
    -------
    Slice covering the trailing elements that do not fill an entire stride,
    or ``None`` if :attr:`n` is divisible by :attr:`stride`.
    """
    # Start of the last, incomplete, stride
    final_start = (n // stride) * stride
    if final_start == n:
        return None

    repeatrange = slice(final_start, n)
    if verbose:
        print(f'Detected repeatrange {repeatrange}', flush=True)
    return repeatrange
def safe_fill(a: NDArray[DTypeT],
              b: NDArray[DTypeT]):
    """ Copy :attr:`b` into the leading corner of :attr:`a`.

    This is the operation ``a[:] = b``, with a check of the dimensions.
    Both arrays must have the same number of dimensions.

    If the dimensions of :attr:`b` are larger than the dimensions of :attr:`a`, raise an error.

    If the dimensions of :attr:`b` are smaller than the dimensions of :attr:`a`, write to
    the first elements of :attr:`a`.
    """
    assert a.ndim == b.ndim, f'{a.shape} != {b.shape}'
    b_fits_in_a = all(dimb <= dima for dima, dimb in zip(a.shape, b.shape))
    assert b_fits_in_a, f'{a.shape} < {b.shape}'

    # Region of a with the same extent as b along every axis
    region = tuple(slice(dim) for dim in b.shape)
    a[region] = b
def safe_fill_larger(a: NDArray[DTypeT],
                     b: NDArray[DTypeT]):
    """ Fill all of :attr:`a` with the leading corner of :attr:`b`.

    This is the operation ``a[:] = b``, with a check of the dimensions.
    Both arrays must have the same number of dimensions.

    If the dimensions of :attr:`b` are smaller than the dimensions of :attr:`a`, raise an error.

    If the dimensions of :attr:`b` are larger than the dimensions of :attr:`a`, write the first
    elements of :attr:`b` to :attr:`a`.
    """
    assert a.ndim == b.ndim, f'{a.shape} != {b.shape}'
    a_fits_in_b = all(dima <= dimb for dima, dimb in zip(a.shape, b.shape))
    assert a_fits_in_b, f'{a.shape} > {b.shape}'

    # Region of b with the same extent as a along every axis
    region = tuple(slice(dim) for dim in a.shape)
    a[:] = b[region]
IND = TypeVar('IND', slice, tuple[slice, ...])


def concatenate_indices(indices_list: Iterable[IND],
                        ) -> tuple[IND, list[IND]]:
    """ Concatenate indices.

    Given an array A and a list of indices indices_list such that A can be indexed

    >>> for indices in indices_list:
    >>>     A[indices]

    this function concatenates the indices into indices_concat so that the array
    can be indexed in one go. It also gives a new list of indices
    new_indices_list that can be used to index ``A[indices_concat]``. The following
    snippet is equivalent to the previous snippet.

    >>> B = A[indices_concat]
    >>> for indices in new_indices_list:
    >>>     B[indices]

    Note that the indices need not be ordered, nor contiguous, but the returned
    indices_concat is made up of slices, and thus contiguous.

    Example
    -------

    >>> A = np.random.rand(100)
    >>> value = 0
    >>> new_value = 0
    >>>
    >>> indices_list = [slice(10, 12), slice(12, 19)]
    >>> for indices in indices_list:
    >>>     value += np.sum(A[indices])
    >>>
    >>> indices_concat, new_indices_list = concatenate_indices(indices_list)
    >>> new_value = np.sum(A[indices_concat])
    >>>
    >>> assert abs(value - new_value) < 1e-10
    >>>
    >>> B = A[indices_concat]
    >>> assert B.shape == (9, )
    >>> new_value = 0
    >>> for indices in new_indices_list:
    >>>     new_value += np.sum(B[indices])
    >>>
    >>> assert abs(value - new_value) < 1e-10

    Returns
    -------
    ``(indices_concat, new_indices_list)``
    """
    all_indices = list(indices_list)
    if not all_indices:
        return slice(0), []  # type: ignore

    if isinstance(all_indices[0], tuple):
        # Every element is expected to be a tuple of slices already
        assert all([isinstance(indices, tuple) for indices in all_indices])
        return _concatenate_indices(all_indices)  # type: ignore

    # Bare slices: wrap each in a 1-tuple, concatenate, and unwrap the results
    assert all([not isinstance(indices, tuple) for indices in all_indices])
    wrapped = [(indices, ) for indices in all_indices]
    wrapped_concat, wrapped_new_list = _concatenate_indices(wrapped)
    return wrapped_concat[0], [indices[0] for indices in wrapped_new_list]
549def _concatenate_indices(indices_list: Iterable[tuple[slice, ...]],
550 ) -> tuple[tuple[slice, ...], list[tuple[slice, ...]]]:
551 """ See :func:`concatenate_indices`
552 """
553 limits_jis = np.array([[(index.start, index.stop, index.step) for index in indices]
554 for indices in indices_list])
556 start_i = np.min(limits_jis[..., 0], axis=0)
557 stop_i = np.max(limits_jis[..., 1], axis=0)
559 indices_concat = tuple([slice(start, stop) for start, stop in zip(start_i, stop_i)])
560 new_indices_list = [tuple([slice(start - startcat, stop - startcat, step)
561 for (startcat, (start, stop, step)) in zip(start_i, limits_is)])
562 for limits_is in limits_jis]
564 return indices_concat, new_indices_list
def parulmopen(fname: str, comm: Communicator, *args, **kwargs):
    """ Open a file on the root rank of the communicator only.

    Parameters
    ----------
    fname
        File name.
    comm
        Communicator; the file is opened on its rank 0.
    *args, **kwargs
        Passed on to ``open``.

    Returns
    -------
    The opened file on the root rank, and an empty context manager on all other ranks.
    """
    if comm.rank != 0:
        # Non-root ranks get a do-nothing context manager
        return nullcontext()
    return open(fname, *args, **kwargs)
def proxy_sknX_slicen(reader, *args, comm: Communicator) -> NDArray[np.complex128]:
    """ Read the part of an array that belongs to this rank.

    The third axis (n) of the array is divided into equally sized chunks,
    one per rank of the communicator, and only the chunk of this rank is read.

    Parameters
    ----------
    reader
        The array itself if no further arguments are given, otherwise an
        object whose ``proxy(*args)`` gives the array.
    *args
        Passed on to ``reader.proxy``.
    comm
        Communicator over which the third axis is distributed.

    Returns
    -------
    Array with the chunk of this rank along the third axis.
    """
    A_sknX = reader.proxy(*args) if len(args) != 0 else reader

    # Chunk size by ceiling division; chunks of trailing ranks may be short or empty
    nn = A_sknX.shape[2]
    chunk = (nn + comm.size - 1) // comm.size
    myslicen = slice(comm.rank * chunk, (comm.rank + 1) * chunk)

    # Slice each (s, k) entry separately, so that only the needed data is touched
    my_A_sknX = []
    for A_knX in A_sknX:
        my_A_sknX.append([A_nX[myslicen] for A_nX in A_knX])

    return np.array(my_A_sknX)
def add_fake_kpts(ksd: KohnShamDecomposition):
    """ Attach placeholder k-point and layout objects to :attr:`ksd`.

    This function is necessary to read some fields without having a
    calculator attached. The attributes ``kpt_u`` and ``ksl`` of :attr:`ksd`
    are overwritten.
    """

    class FakeKpt(NamedTuple):
        s: int
        k: int

    class FakeKsl(NamedTuple):
        using_blacs: bool = False

    # The first two axes of eig_un give the number of spins and k-points
    skshape = ksd.reader.eig_un.shape[:2]

    # One entry per (spin, k-point) pair, with the k-point index running fastest
    fake_kpt_u = []
    for s in range(skshape[0]):
        for k in range(skshape[1]):
            fake_kpt_u.append(FakeKpt(s=s, k=k))

    ksd.kpt_u = fake_kpt_u
    ksd.ksl = FakeKsl()
def create_pulse(frequency: float,
                 fwhm: float = 5.0,
                 t0: float = 10.0,
                 print: Callable | None = None) -> GaussianPulse:
    """ Create Gaussian laser pulse.

    Parameters
    ----------
    frequency
        Pulse frequency in units of eV.
    fwhm
        Full width at half maximum in time domain in units of fs.
    t0
        Maximum of pulse envelope in units of fs.
    print
        Printing function to control verbosity. Defaults to ``parprint``.

    Returns
    -------
    Cosine pulse with the fixed strength ``1e-6``.
    """
    if print is None:
        print = parprint

    # Pulse
    fwhm_eV = 8 * np.log(2) / (fwhm * fs_to_au) * au_to_eV
    tau = fwhm / (2 * np.sqrt(2 * np.log(2)))
    sigma = 1 / (tau * fs_to_au) * au_to_eV  # eV
    strength = 1e-6
    # The pulse is constructed with the time scaled by 1e3 (presumably fs -> as;
    # confirm against GaussianPulse). Keep t0 itself in fs, so that the message
    # below reports the value that was actually requested.
    t0_scaled = t0 * 1e3
    sincos = 'cos'
    print(f'Creating pulse at {frequency:.3f}eV with FWHM {fwhm:.2f}fs '
          f'({fwhm_eV:.2f}eV) t0 {t0:.1f}fs', flush=True)

    return GaussianPulse(strength, t0_scaled, frequency, sigma, sincos)
def get_gaussian_pulse_values(pulse: Perturbation) -> dict[str, float]:
    """ Extract the frequency and FWHM of a Gaussian pulse.

    The values are read from the dictionary representation ``pulse.todict()``.

    Returns
    -------
    Empty dictionary if pulse is not `GaussianPulse`, otherwise dictionary with keys:

    ``pulsefreq`` - pulse frequency in units of eV.
    ``pulsefwhm`` - pulse full width at half-maximum in time domain in units of fs.
    """
    from gpaw.tddft.units import eV_to_au, au_to_fs

    description = pulse.todict()
    if description['name'] != 'GaussianPulse':
        return {}

    # Inverse of the frequency-domain width gives the time-domain width,
    # which is then converted from standard deviation to FWHM
    tau_fs = 1 / (description['sigma'] * eV_to_au) * au_to_fs
    fwhm_fs = tau_fs * (2 * np.sqrt(2 * np.log(2)))

    return {'pulsefreq': description['frequency'],
            'pulsefwhm': fwhm_fs}
# Sorted table of lengths for which the FFT is considered fast.
# NOTE(review): the last entry 8096 is not a power of two (8192 is), and 4096
# is absent; this looks like a typo, but it is kept as is because changing it
# would alter the padded lengths of existing workflows.
fast_pad_nice_factors = np.array([4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 8096])


def fast_pad(nt: int) -> int:
    """ Return a length that is at least twice as large as the given input,
    and the FFT of data of such length is fast.

    The smallest entry of ``fast_pad_nice_factors`` that is at least ``2 * nt``
    is returned. If no entry is large enough, ``2 * nt`` itself is returned.
    """
    padnt = 2 * nt
    insert = np.searchsorted(fast_pad_nice_factors, padnt)
    # searchsorted gives len(table) when every entry is too small; in that
    # case there is nothing to look up and 2 * nt is used as is
    if insert < len(fast_pad_nice_factors):
        padnt = fast_pad_nice_factors[insert]
    assert padnt >= 2 * nt
    # Return a plain int rather than a numpy integer
    return int(padnt)
def format_string_to_glob(fmt: str) -> str:
    """ Turn a format string into a glob-type expression.

    Every replacement field ``{...}`` in the format string is substituted
    by a glob ``*``.

    Example
    -------
    >>> format_string_to_glob('pulserho_pf{pulsefreq:.2f}/t{time:09.1f}{tag}.npy')
    pulserho_pf*/t*.npy

    Parameters
    ---------
    fmt
        Format string.

    Returns
    -------
    Glob-type expression.
    """
    # The trailing + makes a run of adjacent replacement fields
    # collapse into one single *
    adjacent_fields = re.compile(r'({[^{}]*})+')
    return adjacent_fields.sub('*', fmt)
def format_string_to_regex(fmt: str) -> re.Pattern:
    r""" Turn a format string into a compiled regular expression.

    Every replacement field ``{...}`` in the format string is substituted
    by a named group, and all special characters outside the replacement
    fields are escaped.

    Replacement fields for variables ``time``, ``freq``, ``pulsefreq``
    and ``pulsefwhm`` (followed by a format specification) become groups
    matching floating point numbers.
    Replacement fields for variables ``reim`` and ``tag`` become groups
    matching alphabetic characters (and dashes, for ``tag``).
    Remaining replacement fields become groups matching alphabetic characters.

    This can be used to parse a formatted string in order to get back the original
    values.

    Example
    -------
    >>> fmt = 'pulserho_pf{pulsefreq:.2f}/t{time:09.1f}{tag}.npy'
    >>> s = fmt.format(pulsefreq=3.8, time=30000, tag='-Iomega')
    pulserho_pf3.80/t0030000.0-Iomega.npy
    >>> regex = format_string_to_regex(fmt)
    re.compile('pulserho_pf(?P<pulsefreq>[-+]?[\d.]+)/t(?P<time>[-+]?[\d.]+)(?P<tag>[-A-za-z]*)\.npy')
    >>> regex.fullmatch(s).groupdict()
    {'pulsefreq': '3.80', 'time': '0030000.0', 'tag': '-Iomega'}

    Parameters
    ---------
    fmt
        Format string.

    Returns
    -------
    Compiled regex pattern.

    Notes
    -----
    Replacement fields should be named and not contain any attributes or indexing.
    """
    # Splitting on a capturing group keeps the replacement fields, at the odd
    # positions; the literal text in between sits at the even positions
    pieces = re.split(r'({[^{}]+})', str(fmt))
    # Only the literal text is escaped
    pieces = [piece if position % 2 else re.escape(piece)
              for position, piece in enumerate(pieces)]
    pattern = ''.join(pieces)

    # Substitutions are applied in order, so the most specific fields come first.
    # NOTE(review): the class [A-za-z] spans the ASCII range A..z, so it also
    # matches the characters [ \ ] ^ _ and `. Narrowing it to [A-Za-z] would stop
    # matching e.g. tags containing underscores, hence it is left untouched.
    substitutions = (
        # Float variables
        (r'{(time|freq|pulsefreq|pulsefwhm)[-:.\w]+}', r'(?P<\1>[-+]?[\\d.]+)'),
        # reim and tag
        (r'{(reim)}', r'(?P<\1>[A-za-z]+)'),
        (r'{(tag)}', r'(?P<\1>[-A-za-z]*)'),
        # Everything else
        (r'{(\w*)}', r'(?P<\1>[A-za-z]*)'),
    )
    for field_pattern, group_pattern in substitutions:
        pattern = re.sub(field_pattern, group_pattern, pattern)

    return re.compile(pattern)
def partial_format(fmt, **kwargs) -> str:
    """ Fill in only some of the replacement fields of a format string.

    Equivalent to calling ``fmt.format(**kwargs)`` but replacement fields
    that are not present in the ``**kwargs`` will be left in the format string.

    Parameters
    ----------
    fmt
        Format string.
    **kwargs
        Passed to the :py:meth:`str.format` call.

    Returns
    -------
    Partially formatted string.

    Example
    -------
    >>> fmt = 'pulserho_pf{pulsefreq:.2f}/t{time:09.1f}{tag}.npy'
    >>> partial_format(fmt, pulsefreq=3.8)
    pulserho_pf3.80/t{time:09.1f}{tag}.npy
    """
    def format_field(match):
        field = match.group(0)
        try:
            return field.format(**kwargs)
        except KeyError:
            # The variable of this field was not supplied; keep the field as is
            return field

    # Each replacement field is formatted on its own; the text in
    # between the fields is not touched
    return re.sub(r'{[^{}]+}', format_field, fmt)
def find_files(fmt: str,
               log: Logger | None = None,
               *,
               expected_keys: list[str]):
    """ Find files in file system matching the format string :attr:`fmt`.

    This function walks the file tree and looks for file names matching the
    format string :attr:`fmt`.

    Parameters
    ----------
    fmt
        Format string.
    log
        Optional logger object.
    expected_keys
        List of replacement fields ``{...}`` that are expected to be parsed from
        the file names. Unexpected fields raise :py:exc:`ValueError`.

    Returns
    -------
    Dictionary with keys, sorted by the parsed values matching :attr:`expected_keys`:

    ``filename`` - List of filenames found.
    **key** - List of parsed value for each key in :attr:`expected_keys`.

    Raises
    ------
    ValueError
        If a file name contains a replacement field that is not among
        :attr:`expected_keys`.

    Example
    -------
    >>> fmt = 'pulserho_pf3.80/t{time:09.1f}{tag}.npy'
    >>> find_files(fmt, expected_keys=['time', 'tag'])
    {'filename': ['pulserho_pf3.80/t0000010.0.npy',
                  'pulserho_pf3.80/t0000010.0-Iomega.npy',
                  'pulserho_pf3.80/t0000060.0.npy',
                  'pulserho_pf3.80/t0000060.0-Iomega.npy'],
     'time': [10.0, 10.0, 60.0, 60.0],
     'tag': ['', '-Iomega', '', '-Iomega']}
    """
    if log is None:
        log = NoLogger()

    # Find base (containing no format string replacement fields).
    # The entire parent path is inspected, not only its last component, since
    # a replacement field may sit in a directory further up than the deepest one
    non_format_parents = [parent for parent in Path(fmt).parents
                          if '{' not in str(parent)]
    base = non_format_parents[0] if len(non_format_parents) > 0 else Path('.')
    log(str(base), who='Find files', rank=0)

    # Express the format string relative to the base
    rel_pulserho_fmt = str(Path(fmt).relative_to(base))
    log(rel_pulserho_fmt, who='Find files', rank=0)

    # Convert format specifier to glob, and to regex
    pulserho_glob = format_string_to_glob(rel_pulserho_fmt)
    pulserho_regex = format_string_to_regex(rel_pulserho_fmt)
    log(pulserho_glob, who='Find files', rank=0)
    log(pulserho_regex.pattern, who='Find files', rank=0)

    matching = base.glob(pulserho_glob)

    found: list[dict[str, Any]] = []

    # Loop over the matching files
    for match in matching:
        relmatch = match.relative_to(base)
        m = pulserho_regex.fullmatch(str(relmatch))
        if m is None:
            # The glob is more permissive than the regex
            continue
        metadata = {key: float(value) if key not in ['tag', 'reim'] else value
                    for key, value in m.groupdict().items()}
        fname = str(base / relmatch)
        # Sanity check: formatting the parsed values must give back the file name
        test = fmt.format(**metadata)
        assert fname == test, fname + ' != ' + test
        # Any parsed key that is not expected is an error, also when
        # some of the expected keys are absent from the file name
        unexpected_keys = set(metadata.keys()) - set(expected_keys)
        if unexpected_keys:
            raise ValueError(f'Found unexpected key in file name {base / relmatch}:\n'
                             f'Found {metadata}\nExpected {expected_keys}')
        log(relmatch, metadata, who='Find files', rank=0)
        metadata['filename'] = fname
        found.append(metadata)

    # Sort list of found files by expected_keys
    # (itemgetter needs at least one key; without keys there is nothing to sort by)
    if len(expected_keys) > 0:
        found = sorted(found, key=itemgetter(*expected_keys))

    # Unwrap so that we return one dictionary of lists
    ret = {key: [f.get(key, None) for f in found]
           for key in ['filename'] + expected_keys}

    return ret