Coverage for rhodent/utils/__init__.py: 88%
282 statements · coverage.py v7.9.1, created at 2025-08-01 16:57 +0000

from __future__ import annotations

import re
from contextlib import nullcontext
from operator import itemgetter
from pathlib import Path
from typing import Any, Callable, Generic, Iterable, NamedTuple, TypeVar
import numpy as np
from numpy.typing import NDArray
from numpy._typing import _DTypeLike as DTypeLike  # parametrizable wrt generic

from ase.io.ulm import open
from ase.parallel import parprint
from gpaw.lcaotddft.ksdecomposition import KohnShamDecomposition
from gpaw.lcaotddft.laser import GaussianPulse
from gpaw.mpi import SerialCommunicator, world
from gpaw.tddft.units import fs_to_au, au_to_eV

from .logging import Logger, NoLogger
from .result import Result, ResultKeys
from ..perturbation import Perturbation
from ..typing import Array1D, Communicator

__all__ = [
    'Logger',
    'Result',
    'ResultKeys',
]


DTypeT = TypeVar('DTypeT', bound=np.generic, covariant=True)


class ParallelMatrix(Generic[DTypeT]):

    """ Distributed array, with data on the root rank.

    Parameters
    ----------
    shape
        Shape of the array.
    dtype
        Dtype of the array.
    comm
        MPI communicator.
    array
        Array on the root rank of the communicator. Must be ``None`` on all
        other ranks.
    """

    def __init__(self,
                 shape: tuple[int, ...],
                 dtype: DTypeLike[DTypeT],
                 comm: Communicator | None = None,
                 array: NDArray[DTypeT] | None = None):
        if comm is None:
            comm = world
        self.comm = comm
        self.shape = shape
        self.dtype = np.dtype(dtype)

        self._array: NDArray[DTypeT] | None
        if self.root:
            assert array is not None
            assert array.shape == shape
            assert array.dtype == np.dtype(dtype)
            self._array = array
        else:
            assert array is None
            self._array = None

    @property
    def array(self) -> NDArray[DTypeT]:
        """ Array with data. May only be accessed on the root rank. """
        if not self.root:
            raise RuntimeError('May only be called on root')
        assert self._array is not None
        return self._array

    @property
    def root(self) -> bool:
        """ Whether this rank is the root rank. """
        return self.comm.rank == 0

    @property
    def T(self) -> ParallelMatrix:
        """ Transpose of the last two dimensions; data stays on the root rank. """
        shape = self.shape[:-2] + self.shape[-2:][::-1]
        array = np.swapaxes(self.array, -2, -1) if self.root else None
        return ParallelMatrix(shape=shape, dtype=self.dtype, comm=self.comm,
                              array=array)

    def broadcast(self) -> NDArray[DTypeT]:
        """ Broadcast the data from the root rank and return it on all ranks. """
        if self.root:
            A = np.ascontiguousarray(self.array)
        else:
            A = np.zeros(self.shape, self.dtype)

        self.comm.broadcast(A, 0)

        return A

    def __matmul__(self, other) -> ParallelMatrix[DTypeT]:
        """ Perform matrix multiplication in parallel.

        Each rank computes a slice of the columns of the result, and the
        slices are summed to the root rank.
        """
        if not isinstance(other, ParallelMatrix):
            raise NotImplementedError

        assert self.dtype == other.dtype

        A = self.broadcast()
        B = other.broadcast()

        # Allocate array for the result
        ni, nj = A.shape[-2:]
        nk = B.shape[-1]
        C_shape = np.broadcast_shapes(A.shape[:-2], B.shape[:-2]) + (ni, nk)
        C = np.zeros(C_shape, self.dtype)

        # Determine the slice of columns treated by this rank
        stridek = (nk + self.comm.size - 1) // self.comm.size
        slicek = slice(stridek * self.comm.rank, stridek * (self.comm.rank + 1))

        # Perform this rank's share of the matrix multiplication
        C[..., :, slicek] = A @ B[..., :, slicek]

        # Sum to the root rank
        self.comm.sum(C, 0)

        result = ParallelMatrix(C_shape, self.dtype, comm=self.comm,
                                array=C if self.root else None)
        return result

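# Usage sketch (editor's illustration, not part of the module): multiply two
# matrices whose data lives on the root rank of ``world``. Shapes and values
# are made up for the example.
#
#     if world.rank == 0:
#         A = np.arange(12, dtype=np.float64).reshape((3, 4))
#         B = np.ones((4, 2), dtype=np.float64)
#     else:
#         A = B = None
#     PA = ParallelMatrix((3, 4), np.float64, array=A)
#     PB = ParallelMatrix((4, 2), np.float64, array=B)
#     PC = PA @ PB              # every rank computes a slice of the columns
#     if PC.root:
#         print(PC.array)       # the full product, on the root rank only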

def gauss_ij_with_filter(energy_i: np.typing.ArrayLike,
                         energy_j: np.typing.ArrayLike,
                         sigma: float,
                         fltthresh: float | None = None,
                         ) -> tuple[NDArray[np.float64], NDArray[np.bool_]]:
    r""" Computes the matrix

    .. math::

        M_{ij}
        = \frac{1}{\sqrt{2 \pi \sigma^2}}
        \exp\left(-\frac{
            \left(\varepsilon_i - \varepsilon_j\right)^2
        }{
            2 \sigma^2
        }\right)

    Useful for Gaussian broadening. Optionally, the exponential is evaluated
    only for the rows where the exponent exceeds a threshold somewhere, and
    the boolean filter marking those rows is returned.

    Parameters
    ----------
    energy_i
        Energies :math:`\varepsilon_i`.
    energy_j
        Energies :math:`\varepsilon_j`.
    sigma
        Gaussian broadening width :math:`\sigma`.
    fltthresh
        Filtering threshold for the exponent.

    Returns
    -------
    Matrix :math:`M_{ij}`, filter.
    """
    energy_i = np.asarray(energy_i)
    energy_j = np.asarray(energy_j)

    norm = 1.0 / (sigma * np.sqrt(2 * np.pi))

    denergy_ij = energy_i[:, np.newaxis] - energy_j[np.newaxis, :]
    exponent_ij = -0.5 * (denergy_ij / sigma) ** 2

    if fltthresh is not None:
        flt_i = np.any(exponent_ij > fltthresh, axis=1)
        M_ij = np.zeros_like(exponent_ij)
        M_ij[flt_i] = norm * np.exp(exponent_ij[flt_i])
    else:
        flt_i = np.ones(energy_i.shape, dtype=bool)
        M_ij = norm * np.exp(exponent_ij)

    return M_ij, flt_i

def gauss_ij(energy_i: np.typing.ArrayLike,
             energy_j: np.typing.ArrayLike,
             sigma: float) -> NDArray[np.float64]:
    r""" Computes the matrix

    .. math::

        M_{ij}
        = \frac{1}{\sqrt{2 \pi \sigma^2}}
        \exp\left(-\frac{
            \left(\varepsilon_i - \varepsilon_j\right)^2
        }{
            2 \sigma^2
        }\right),

    which is useful for Gaussian broadening.

    Parameters
    ----------
    energy_i
        Energies :math:`\varepsilon_i`.
    energy_j
        Energies :math:`\varepsilon_j`.
    sigma
        Gaussian broadening width :math:`\sigma`.

    Returns
    -------
    Matrix :math:`M_{ij}`.
    """
    M_ij, _ = gauss_ij_with_filter(energy_i, energy_j, sigma)
    return M_ij

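# Illustration (editor's sketch, not part of the module): a broadening matrix
# between two small energy grids, in the same arbitrary units as ``sigma``.
#
#     energy_i = np.array([0.0, 1.0, 2.0])
#     energy_j = np.linspace(-1.0, 3.0, 401)
#     M_ij = gauss_ij(energy_i, energy_j, sigma=0.5)
#     # M_ij.shape == (3, 401); row k is a normalized Gaussian centred
#     # at energy_i[k], sampled on the energy_j grid.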

def broaden_n2e(M_n: np.typing.ArrayLike,
                energy_n: np.typing.ArrayLike,
                energy_e: np.typing.ArrayLike,
                sigma: float) -> NDArray[np.float64]:
    r""" Broaden matrix onto an energy grid

    .. math::

        M(\varepsilon_e)
        = \sum_n M_n \frac{1}{\sqrt{2 \pi \sigma^2}}
        \exp\left(-\frac{
            \left(\varepsilon_n - \varepsilon_e\right)^2
        }{
            2 \sigma^2
        }\right).

    Returns
    -------
    :math:`M(\varepsilon_e)`.
    """
    M_n = np.asarray(M_n)
    gauss_ne, flt_n = gauss_ij_with_filter(energy_n, energy_e, sigma)

    M_e = np.einsum('n,ne->e', M_n[flt_n], gauss_ne[flt_n], optimize=True)

    return M_e

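# Illustration (editor's sketch): broaden discrete weights M_n at energies
# energy_n onto a uniform grid. With a fine grid that covers the peaks, the
# broadened curve integrates to approximately ``M_n.sum()``.
#
#     energy_n = np.array([1.0, 1.5, 2.2])
#     M_n = np.array([0.5, 1.0, 0.25])
#     energy_e = np.linspace(0.0, 3.5, 351)
#     M_e = broaden_n2e(M_n, energy_n, energy_e, sigma=0.1)
#     # np.trapz(M_e, energy_e) is close to M_n.sum() == 1.75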

def broaden_xn2e(M_xn: np.typing.ArrayLike,
                 energy_n: np.typing.ArrayLike,
                 energy_e: np.typing.ArrayLike,
                 sigma: float) -> NDArray[np.float64]:
    r""" Broaden matrix onto an energy grid

    .. math::

        M_x(\varepsilon_e)
        = \sum_n M_{xn} \frac{1}{\sqrt{2 \pi \sigma^2}}
        \exp\left(-\frac{
            \left(\varepsilon_n - \varepsilon_e\right)^2
        }{
            2 \sigma^2
        }\right).

    Returns
    -------
    :math:`M_x(\varepsilon_e)`.
    """
    M_xn = np.asarray(M_xn)
    gauss_ne, flt_n = gauss_ij_with_filter(energy_n, energy_e, sigma)

    M_xe = np.einsum('xn,ne->xe',
                     M_xn.reshape((-1, len(flt_n)))[:, flt_n],
                     gauss_ne[flt_n],
                     optimize=True).reshape(M_xn.shape[:-1] + (-1, ))

    return M_xe

def broaden_ia2ou(M_ia: np.typing.ArrayLike,
                  energy_i: np.typing.ArrayLike,
                  energy_a: np.typing.ArrayLike,
                  energy_o: np.typing.ArrayLike,
                  energy_u: np.typing.ArrayLike,
                  sigma: float) -> NDArray[np.float64]:
    r""" Broaden matrix onto occupied and unoccupied energy grids.

    .. math::

        M(\varepsilon_o, \varepsilon_u)
        = \sum_{ia} M_{ia} \frac{1}{2 \pi \sigma^2}
        \exp\left(-\frac{
            (\varepsilon_i - \varepsilon_o)^2
        }{
            2 \sigma^2
        }\right)
        \exp\left(-\frac{
            (\varepsilon_a - \varepsilon_u)^2
        }{
            2 \sigma^2
        }\right).

    Returns
    -------
    :math:`M(\varepsilon_o, \varepsilon_u)`.
    """
    M_ia = np.asarray(M_ia)
    ia_shape = M_ia.shape[:2]
    x_shape = M_ia.shape[2:]
    M_iax = M_ia.reshape(ia_shape + (-1, ))
    gauss_io, flt_i = gauss_ij_with_filter(energy_i, energy_o, sigma)
    gauss_au, flt_a = gauss_ij_with_filter(energy_a, energy_u, sigma)

    M_oux = np.einsum('iax,io,au->oux', M_iax[flt_i, :][:, flt_a],
                      gauss_io[flt_i], gauss_au[flt_a],
                      optimize=True, order='C')

    return M_oux.reshape(M_oux.shape[:2] + x_shape)

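# Illustration (editor's sketch): broaden a transition matrix over pairs of
# occupied (i) and unoccupied (a) states onto two energy grids.
#
#     energy_i = np.array([-2.0, -1.0])         # occupied eigenvalues
#     energy_a = np.array([0.5, 1.5, 2.5])      # unoccupied eigenvalues
#     energy_o = np.linspace(-3.0, 0.0, 31)     # occupied energy grid
#     energy_u = np.linspace(0.0, 3.0, 31)      # unoccupied energy grid
#     M_ia = np.ones((2, 3))
#     M_ou = broaden_ia2ou(M_ia, energy_i, energy_a, energy_o, energy_u,
#                          sigma=0.2)
#     # M_ou.shape == (31, 31)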

def get_array_filter(values: Array1D[np.float64] | list[float],
                     filter_values: Array1D[np.float64] | list[float] | None,
                     ) -> slice | Array1D[np.bool_]:
    """ Get an array filter that can be used to select data.

    Parameters
    ----------
    values
        Array of values, e.g. a linspace of times or frequencies.
    filter_values
        Values that one wishes to extract. The closest matches from
        :attr:`values` are selected by the filter.

    Returns
    -------
    Object that can be used to index :attr:`values`, or any array of the
    same shape.
    """
    _values = np.array(values)
    flt_x: slice | NDArray[np.bool_]
    if len(values) == 0:
        # Empty array of values
        return slice(None)

    if filter_values is None or len(filter_values) == 0:
        # No filter
        return slice(None)

    flt_x = np.zeros(len(values), dtype=bool)
    for filter_value in filter_values:
        # Search for the closest value
        idx = np.argmin(np.abs(_values - filter_value))
        flt_x[idx] = True

    return flt_x

def filter_array(values: Array1D[np.float64] | list[float],
                 filter_values: Array1D[np.float64] | list[float] | None,
                 ) -> Array1D[np.float64]:
    """ Filter array, picking the values closest to :attr:`filter_values`.

    Parameters
    ----------
    values
        Array of values, e.g. a linspace of times or frequencies.
    filter_values
        Values that one wishes to extract. The closest matches from
        :attr:`values` are selected.

    Returns
    -------
    Filtered array.
    """
    array = np.array(values)
    return array[get_array_filter(array, filter_values)]  # type: ignore

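# Illustration (editor's sketch): pick the grid points closest to the
# requested values.
#
#     times = np.linspace(0.0, 30.0, 301)       # e.g. times in fs, spacing 0.1
#     filter_array(times, [4.99, 10.02, 25.0])
#     # -> array([ 5., 10., 25.])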

def two_communicator_sizes(*comm_sizes) -> tuple[int, int]:
    """ Resolve two communicator sizes, where a size may be ``-1``
    (inferred from the other size) or ``'world'`` (the world size). """
    assert len(comm_sizes) == 2
    comm_size_c: list[int] = [world.size if size == 'world' else size for size in comm_sizes]
    if comm_size_c[0] == -1:
        comm_size_c[0] = world.size // comm_size_c[1]
    elif comm_size_c[1] == -1:
        comm_size_c[1] = world.size // comm_size_c[0]

    assert np.prod(comm_size_c) == world.size, \
        f'Communicator sizes must factorize world size {world.size} '\
        'but they are ' + ' and '.join([str(s) for s in comm_size_c]) + '.'
    return comm_size_c[0], comm_size_c[1]

def two_communicators(*comm_sizes) -> tuple[Communicator, Communicator]:
    """ Create two MPI communicators.

    Must satisfy ``comm_sizes[0] * comm_sizes[1] == world.size``.

    The second communicator has its ranks in sequence.

    Example
    -------
    With ``world.size == 8``, the call ``two_communicators(2, 4)``
    gives a first communicator with rank groups::

        [0, 4]
        [1, 5]
        [2, 6]
        [3, 7]

    and a second communicator with rank groups::

        [0, 1, 2, 3]
        [4, 5, 6, 7]
    """
    comm_size_c = two_communicator_sizes(*comm_sizes)

    # Create communicators
    if comm_size_c[0] == 1:
        return (SerialCommunicator(), world)  # type: ignore
    elif comm_size_c[0] == world.size:
        return (world, SerialCommunicator())  # type: ignore
    else:
        assert world.size % comm_size_c[0] == 0, world.size
        # Comm 2 has ranks in sequence; comm 1 ranks step by the size of comm 2
        first_rank_in_comm_c = [world.rank % comm_size_c[1],
                                world.rank - world.rank % comm_size_c[1]]
        step_c = [comm_size_c[1], 1]
        comm_ranks_cr = [list(range(start, start + size * step, step))
                         for start, size, step in zip(first_rank_in_comm_c, comm_size_c, step_c)]
        comm_c = [world.new_communicator(comm_ranks_r) for comm_ranks_r in comm_ranks_cr]
        return comm_c[0], comm_c[1]

def detect_repeatrange(n: int,
                       stride: int,
                       verbose: bool = True) -> slice | None:
    """ If an array of length :attr:`n` is not divisible by the stride
    :attr:`stride`, then some work will have to be repeated. Return the
    slice of repeated work, or ``None`` if there is none.
    """
    final_start = (n // stride) * stride
    repeatrange = slice(final_start, n)
    if repeatrange.start == repeatrange.stop:
        return None
    if verbose:
        print(f'Detected repeatrange {repeatrange}', flush=True)
    return repeatrange

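# Illustration (editor's sketch): with 10 items processed in strides of 4, the
# last two items do not fill a whole stride and must be repeated.
#
#     detect_repeatrange(10, 4, verbose=False)  # -> slice(8, 10)
#     detect_repeatrange(12, 4, verbose=False)  # -> None (divides evenly)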

def safe_fill(a: NDArray[DTypeT],
              b: NDArray[DTypeT]):
    """ Perform the operation ``a[:] = b``, checking that the dimensions are compatible.

    If any dimension of :attr:`b` is larger than the corresponding dimension
    of :attr:`a`, raise an error.

    If the dimensions of :attr:`b` are smaller than the dimensions of :attr:`a`,
    write to the first elements of :attr:`a`.
    """
    assert len(a.shape) == len(b.shape), f'{a.shape} != {b.shape}'
    assert all([dima >= dimb for dima, dimb in zip(a.shape, b.shape)]), f'{a.shape} < {b.shape}'
    s = tuple([slice(dim) for dim in b.shape])
    a[s] = b

def safe_fill_larger(a: NDArray[DTypeT],
                     b: NDArray[DTypeT]):
    """ Perform the operation ``a[:] = b``, checking that the dimensions are compatible.

    If any dimension of :attr:`b` is smaller than the corresponding dimension
    of :attr:`a`, raise an error.

    If the dimensions of :attr:`b` are larger than the dimensions of :attr:`a`,
    write the first elements of :attr:`b` to :attr:`a`.
    """
    assert len(a.shape) == len(b.shape), f'{a.shape} != {b.shape}'
    assert all([dimb >= dima for dima, dimb in zip(a.shape, b.shape)]), f'{a.shape} > {b.shape}'
    s = tuple([slice(dim) for dim in a.shape])
    a[:] = b[s]

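# Illustration (editor's sketch): copy a small block into a larger buffer, and
# crop a large block into a smaller buffer.
#
#     buf = np.zeros((4, 4))
#     safe_fill(buf, np.ones((2, 3)))           # writes buf[:2, :3]
#     small = np.zeros((2, 2))
#     safe_fill_larger(small, np.ones((3, 3)))  # copies the (2, 2) corner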

IND = TypeVar('IND', slice, tuple[slice, ...])


def concatenate_indices(indices_list: Iterable[IND],
                        ) -> tuple[IND, list[IND]]:
    """ Concatenate indices.

    Given an array ``A`` and a list of indices ``indices_list`` such that
    ``A`` can be indexed

    >>> for indices in indices_list:
    ...     A[indices]

    this function concatenates the indices into ``indices_concat``, so that
    the array can be indexed in one go. It also gives a new list of indices
    ``new_indices_list`` that can be used to index ``A[indices_concat]``.
    The following snippet is equivalent to the previous one.

    >>> B = A[indices_concat]
    >>> for indices in new_indices_list:
    ...     B[indices]

    Note that the indices need not be ordered, nor contiguous, but the
    returned ``indices_concat`` consists of slices, and is thus contiguous.

    Example
    -------

    >>> A = np.random.rand(100)
    >>> value = 0
    >>> new_value = 0
    >>>
    >>> indices_list = [slice(10, 12), slice(12, 19)]
    >>> for indices in indices_list:
    ...     value += np.sum(A[indices])
    >>>
    >>> indices_concat, new_indices_list = concatenate_indices(indices_list)
    >>> new_value = np.sum(A[indices_concat])
    >>>
    >>> assert abs(value - new_value) < 1e-10
    >>>
    >>> B = A[indices_concat]
    >>> assert B.shape == (9, )
    >>> new_value = 0
    >>> for indices in new_indices_list:
    ...     new_value += np.sum(B[indices])
    >>>
    >>> assert abs(value - new_value) < 1e-10

    Returns
    -------
    ``(indices_concat, new_indices_list)``
    """
    indices_list = list(indices_list)
    if len(indices_list) == 0:
        return slice(0), []  # type: ignore

    if not isinstance(indices_list[0], tuple):
        # If indices are not tuples, then wrap everything in tuples and recurse
        assert all([not isinstance(indices, tuple) for indices in indices_list])
        _indices_concat, _new_indices_list = _concatenate_indices([(indices, ) for indices in indices_list])
        return _indices_concat[0], [indices[0] for indices in _new_indices_list]

    # All indices are wrapped in tuples
    assert all([isinstance(indices, tuple) for indices in indices_list])
    return _concatenate_indices(indices_list)  # type: ignore

def _concatenate_indices(indices_list: Iterable[tuple[slice, ...]],
                         ) -> tuple[tuple[slice, ...], list[tuple[slice, ...]]]:
    """ See :func:`concatenate_indices`. """
    limits_jis = np.array([[(index.start, index.stop, index.step) for index in indices]
                           for indices in indices_list])

    start_i = np.min(limits_jis[..., 0], axis=0)
    stop_i = np.max(limits_jis[..., 1], axis=0)

    indices_concat = tuple([slice(start, stop) for start, stop in zip(start_i, stop_i)])
    new_indices_list = [tuple([slice(start - startcat, stop - startcat, step)
                               for (startcat, (start, stop, step)) in zip(start_i, limits_is)])
                        for limits_is in limits_jis]

    return indices_concat, new_indices_list

def parulmopen(fname: str, comm: Communicator, *args, **kwargs):
    """ Open an ULM file on the root rank; return a null context on other ranks. """
    if comm.rank == 0:
        return open(fname, *args, **kwargs)
    else:
        return nullcontext()

def proxy_sknX_slicen(reader, *args, comm: Communicator) -> NDArray[np.complex128]:
    """ Read the array A_sknX from a ULM reader (or proxy), with each rank
    taking its own slice of the n-dimension. """
    if len(args) == 0:
        A_sknX = reader
    else:
        A_sknX = reader.proxy(*args)
    nn = A_sknX.shape[2]
    nlocaln = (nn + comm.size - 1) // comm.size
    myslicen = slice(comm.rank * nlocaln, (comm.rank + 1) * nlocaln)
    my_A_sknX = np.array([[A_nX[myslicen] for A_nX in A_knX] for A_knX in A_sknX])

    return my_A_sknX

def add_fake_kpts(ksd: KohnShamDecomposition):
    """ This function is necessary to read some fields without having a
    calculator attached.
    """

    class FakeKpt(NamedTuple):
        s: int
        k: int

    class FakeKsl(NamedTuple):
        using_blacs: bool = False

    # Figure out the spin and k-point indices from the shape of the eigenvalues
    ksdreader = ksd.reader
    skshape = ksdreader.eig_un.shape[:2]
    kpt_u = [FakeKpt(s=s, k=k)
             for s in range(skshape[0])
             for k in range(skshape[1])]
    ksd.kpt_u = kpt_u
    ksd.ksl = FakeKsl()

def create_pulse(frequency: float,
                 fwhm: float = 5.0,
                 t0: float = 10.0,
                 print: Callable | None = None) -> GaussianPulse:
    """ Create Gaussian laser pulse.

    Parameters
    ----------
    frequency
        Pulse frequency in units of eV.
    fwhm
        Full width at half maximum in time domain in units of fs.
    t0
        Maximum of pulse envelope in units of fs.
    print
        Printing function to control verbosity.
    """
    if print is None:
        print = parprint

    # Pulse parameters
    fwhm_eV = 8 * np.log(2) / (fwhm * fs_to_au) * au_to_eV
    tau = fwhm / (2 * np.sqrt(2 * np.log(2)))
    sigma = 1 / (tau * fs_to_au) * au_to_eV  # eV
    strength = 1e-6
    t0_as = t0 * 1e3  # fs -> as
    sincos = 'cos'
    print(f'Creating pulse at {frequency:.3f}eV with FWHM {fwhm:.2f}fs '
          f'({fwhm_eV:.2f}eV) t0 {t0:.1f}fs', flush=True)

    return GaussianPulse(strength, t0_as, frequency, sigma, sincos)

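# Illustration (editor's sketch): a pulse centred at 3.8 eV with the default
# 5 fs FWHM envelope peaking at 10 fs.
#
#     pulse = create_pulse(frequency=3.8)
#     # pulse.todict()['frequency'] == 3.8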

def get_gaussian_pulse_values(pulse: Perturbation) -> dict[str, float]:
    """ Get pulse frequency and FWHM of a Gaussian pulse.

    Returns
    -------
    Empty dictionary if the pulse is not a ``GaussianPulse``, otherwise
    dictionary with keys:

    ``pulsefreq`` - pulse frequency in units of eV.
    ``pulsefwhm`` - pulse full width at half maximum in time domain in units of fs.
    """
    from gpaw.tddft.units import eV_to_au, au_to_fs

    d = pulse.todict()
    ret: dict[str, float] = dict()
    if d['name'] == 'GaussianPulse':
        ret['pulsefreq'] = d['frequency']
        ret['pulsefwhm'] = (1 / (d['sigma'] * eV_to_au) * au_to_fs *
                            (2 * np.sqrt(2 * np.log(2))))
    return ret

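# Illustration (editor's sketch): round trip with create_pulse. A GaussianPulse
# is passed directly here; it provides the same ``todict()`` interface that
# this function reads.
#
#     pulse = create_pulse(frequency=3.8, fwhm=5.0)
#     get_gaussian_pulse_values(pulse)
#     # -> {'pulsefreq': 3.8, 'pulsefwhm': 5.0} (up to floating point rounding)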

# Padded FFT lengths that are fast (powers of two)
fast_pad_nice_factors = np.array([4, 8, 16, 32, 64, 128, 256, 512, 1024,
                                  2048, 4096, 8192])


def fast_pad(nt: int) -> int:
    """ Return a length that is at least twice the given input and for which
    the FFT of data of that length is fast.
    """
    padnt = 2 * nt
    insert = np.searchsorted(fast_pad_nice_factors, padnt)
    if insert < len(fast_pad_nice_factors):
        padnt = fast_pad_nice_factors[insert]
    assert padnt >= 2 * nt
    return padnt

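# Illustration (editor's sketch): pad a time series before an FFT.
#
#     fast_pad(300)    # -> 1024, the smallest nice length >= 2 * 300
#     fast_pad(5000)   # -> 10000, beyond the table it falls back to 2 * nt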

def format_string_to_glob(fmt: str) -> str:
    """ Convert a format string to a glob-type expression.

    Replaces all the replacement fields ``{...}`` in the format string
    with a glob ``*``.

    Example
    -------
    >>> format_string_to_glob('pulserho_pf{pulsefreq:.2f}/t{time:09.1f}{tag}.npy')
    'pulserho_pf*/t*.npy'

    Parameters
    ----------
    fmt
        Format string.

    Returns
    -------
    Glob-type expression.
    """
    # Replace replacement fields by *.
    # Note how several replacement fields next to each other are
    # replaced by only one *, thanks to the (...)+
    glob_expr = re.sub(r'({[^{}]*})+', '*', fmt)
    return glob_expr

def format_string_to_regex(fmt: str) -> re.Pattern:
    r""" Convert a format string to a regex expression.

    Replaces all the replacement fields ``{...}`` in the format string
    with a regular expression and escapes all special characters outside
    the replacement fields.

    Replacement fields for the variables ``time``, ``freq``, ``pulsefreq``
    and ``pulsefwhm`` are replaced by regex matching floating point numbers.
    Replacement fields for the variables ``reim`` and ``tag`` are replaced by
    regex matching alphabetic characters and dashes.
    Remaining replacement fields are replaced by regex matching alphabetic
    characters.

    This can be used to parse a formatted string in order to get back the
    original values.

    Example
    -------
    >>> fmt = 'pulserho_pf{pulsefreq:.2f}/t{time:09.1f}{tag}.npy'
    >>> s = fmt.format(pulsefreq=3.8, time=30000, tag='-Iomega')
    >>> s
    'pulserho_pf3.80/t0030000.0-Iomega.npy'
    >>> regex = format_string_to_regex(fmt)
    >>> regex.pattern
    'pulserho_pf(?P<pulsefreq>[-+]?[\\d.]+)/t(?P<time>[-+]?[\\d.]+)(?P<tag>[-A-Za-z]*)\\.npy'
    >>> regex.fullmatch(s).groupdict()
    {'pulsefreq': '3.80', 'time': '0030000.0', 'tag': '-Iomega'}

    Parameters
    ----------
    fmt
        Format string.

    Returns
    -------
    Compiled regex pattern.

    Notes
    -----
    Replacement fields should be named and not contain any attributes or indexing.
    """
    regex_expr = str(fmt)

    # Split the expression into parts separated by replacement fields
    split = re.split(r'({[^{}]+})', regex_expr)
    # Every other element is guaranteed to be a replacement field.
    # Escape everything that is not a replacement field
    split[::2] = [re.escape(s) for s in split[::2]]
    # Join the expression back together
    regex_expr = ''.join(split)

    # Replace float variables
    regex_expr = re.sub(r'{(time|freq|pulsefreq|pulsefwhm)[-:.\w]+}',
                        r'(?P<\1>[-+]?[\\d.]+)',
                        regex_expr)

    # Replace reim and tag
    regex_expr = re.sub(r'{(reim)}', r'(?P<\1>[A-Za-z]+)', regex_expr)
    regex_expr = re.sub(r'{(tag)}', r'(?P<\1>[-A-Za-z]*)', regex_expr)

    # Replace other fields
    regex_expr = re.sub(r'{(\w*)}', r'(?P<\1>[A-Za-z]*)', regex_expr)

    compiled = re.compile(regex_expr)

    return compiled

def partial_format(fmt, **kwargs) -> str:
    """ Partially format the format string.

    Equivalent to calling ``fmt.format(**kwargs)``, except that replacement
    fields not present in the ``**kwargs`` are left in the format string.

    Parameters
    ----------
    fmt
        Format string.
    **kwargs
        Passed to the :py:meth:`str.format` call.

    Returns
    -------
    Partially formatted string.

    Example
    -------
    >>> fmt = 'pulserho_pf{pulsefreq:.2f}/t{time:09.1f}{tag}.npy'
    >>> partial_format(fmt, pulsefreq=3.8)
    'pulserho_pf3.80/t{time:09.1f}{tag}.npy'
    """
    def partial_format_single(s):
        try:
            # Try to format
            return s.format(**kwargs)
        except KeyError:
            # If the replacement field is not among the kwargs, return unchanged
            return s

    # Split the expression into parts separated by replacement fields
    split = re.split(r'({[^{}]+})', fmt)
    # Every other element is guaranteed to be a replacement field.
    # Try to format each field separately
    split[1::2] = [partial_format_single(s) for s in split[1::2]]
    # Join the expression back together
    fmt = ''.join(split)

    return fmt

def find_files(fmt: str,
               log: Logger | None = None,
               *,
               expected_keys: list[str]):
    """ Find files in the file system matching the format string :attr:`fmt`.

    This function walks the file tree and looks for file names matching the
    format string :attr:`fmt`.

    Parameters
    ----------
    fmt
        Format string.
    log
        Optional logger object.
    expected_keys
        List of replacement fields ``{...}`` that are expected to be parsed
        from the file names. Unexpected fields raise :py:exc:`ValueError`.

    Returns
    -------
    Dictionary of lists, sorted by the parsed values of :attr:`expected_keys`,
    with keys:

    ``filename`` - list of file names found.
    **key** - list of parsed values for each key in :attr:`expected_keys`.

    Example
    -------
    >>> fmt = 'pulserho_pf3.80/t{time:09.1f}{tag}.npy'
    >>> find_files(fmt, expected_keys=['time', 'tag'])
    {'filename': ['pulserho_pf3.80/t0000010.0.npy',
                  'pulserho_pf3.80/t0000010.0-Iomega.npy',
                  'pulserho_pf3.80/t0000060.0.npy',
                  'pulserho_pf3.80/t0000060.0-Iomega.npy'],
     'time': [10.0, 10.0, 60.0, 60.0],
     'tag': ['', '-Iomega', '', '-Iomega']}
    """
    if log is None:
        log = NoLogger()

    # Find the base directory (containing no replacement fields)
    non_format_parents = [parent for parent in Path(fmt).parents
                          if '{' not in parent.name]
    base = non_format_parents[0] if len(non_format_parents) > 0 else Path('.')
    log(str(base), who='Find files', rank=0)

    # Express the format string relative to the base
    rel_pulserho_fmt = str(Path(fmt).relative_to(base))
    log(rel_pulserho_fmt, who='Find files', rank=0)

    # Convert the format string to a glob, and to a regex
    pulserho_glob = format_string_to_glob(rel_pulserho_fmt)
    pulserho_regex = format_string_to_regex(rel_pulserho_fmt)
    log(pulserho_glob, who='Find files', rank=0)
    log(pulserho_regex.pattern, who='Find files', rank=0)

    matching = base.glob(pulserho_glob)

    found: list[dict[str, Any]] = []

    # Loop over the matching files
    for match in matching:
        relmatch = match.relative_to(base)
        m = pulserho_regex.fullmatch(str(relmatch))
        if m is None:
            continue
        metadata = {key: float(value) if key not in ['tag', 'reim'] else value
                    for key, value in m.groupdict().items()}
        fname = str(base / relmatch)
        test = fmt.format(**metadata)
        assert fname == test, fname + ' != ' + test
        if set(metadata.keys()) > set(expected_keys):
            raise ValueError(f'Found unexpected key in file name {base / relmatch}:\n'
                             f'Found {metadata}\nExpected {expected_keys}')
        log(relmatch, metadata, who='Find files', rank=0)
        metadata['filename'] = fname
        found.append(metadata)

    # Sort the list of found files by expected_keys
    found = sorted(found, key=itemgetter(*expected_keys))

    # Unwrap so that we return one dictionary of lists
    ret = {key: [f.get(key, None) for f in found]
           for key in ['filename'] + expected_keys}

    return ret