Source code for yaw.config.resampling

from __future__ import annotations

from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any

import numpy as np

from yaw.config import OPTIONS
from yaw.config import default as DEFAULT
from yaw.config.abc import BaseConfig
from yaw.config.utils import ConfigError
from yaw.core.docs import Parameter

if TYPE_CHECKING:  # pragma: no cover
    from numpy.typing import NDArray

__all__ = ["ResamplingConfig"]


@dataclass(frozen=True)
class ResamplingConfig(BaseConfig):
    """Configuration for error estimation from spatial resampling.

    Used for all functions and methods that use spatial patches for error
    estimation. Use the :meth:`get_samples` method to generate samples from
    the spatial patches, which can be reused to ensure consistent error
    estimates for different data products that use the same patches.

    Args:
        method (:obj:`str`):
            Resampling method to use, see
            :obj:`~yaw.config.options.Options.method`.
        crosspatch (:obj:`bool`):
            Whether to use cross-patch pair count measurements.
        n_boot (:obj:`int`):
            Number of samples to generate for the ``bootstrap`` method.
        global_norm (:obj:`bool`):
            Whether to normalise paircounts globally or for each sample.
            Usually not recommended.
        seed (:obj:`int`):
            Random seed to use.
    """

    method: str = field(
        default=DEFAULT.Resampling.method,
        metadata=Parameter(type=str, help="resampling method to use"),
    )
    """Resampling method to use, see :obj:`~yaw.config.options.Options.method`.
    """
    crosspatch: bool = field(
        default=DEFAULT.Resampling.crosspatch,
        metadata=Parameter(
            type=bool, help="whether to use cross-patch pair count measurements"
        ),
    )
    """Whether to use cross-patch pair count measurements."""
    n_boot: int = field(
        default=DEFAULT.Resampling.n_boot,
        metadata=Parameter(
            type=int, help="number of samples to generate if method='bootstrap'"
        ),
    )
    """Number of samples to generate for the ``bootstrap`` method."""
    global_norm: bool = field(
        default=DEFAULT.Resampling.global_norm,
        metadata=Parameter(
            type=bool,
            help="whether to normalise paircounts globally or for each sample",
        ),
    )
    """Whether to normalise paircounts globally or for each sample."""
    seed: int = field(
        default=DEFAULT.Resampling.seed,
        metadata=Parameter(type=int, help="random seed to use"),
    )
    """Random seed to use."""
    # Lazily populated cache of the patch indices produced by get_samples().
    # It is excluded from the generated __eq__/__hash__ (compare=False): the
    # value is a numpy array, which is unhashable and whose elementwise "=="
    # cannot be reduced to a single bool by the dataclass comparison.
    _resampling_idx: NDArray[np.int64] | None = field(
        default=None, init=False, repr=False, compare=False
    )

    def __post_init__(self) -> None:
        """Validate that ``method`` is one of the supported resampling methods.

        Raises:
            ConfigError: If ``method`` is not contained in ``OPTIONS.method``.
        """
        if self.method not in OPTIONS.method:
            opts = ", ".join(f"'{s}'" for s in OPTIONS.method)
            raise ConfigError(
                f"invalid resampling method '{self.method}', must be either of {opts}"
            )

    def modify(
        self,
        method: str = DEFAULT.NotSet,
        crosspatch: bool = DEFAULT.NotSet,
        n_boot: int = DEFAULT.NotSet,
        global_norm: bool = DEFAULT.NotSet,
        seed: int = DEFAULT.NotSet,
    ) -> ResamplingConfig:
        """Forward the given parameter values to the parent ``modify`` method.

        Every parameter defaults to the ``DEFAULT.NotSet`` sentinel and is
        passed on as-is; how the sentinel is interpreted is up to the parent
        implementation (presumably "keep the current value" -- confirm in
        :class:`BaseConfig`).

        Returns:
            The value returned by the parent ``modify`` call.
        """
        return super().modify(
            method=method,
            crosspatch=crosspatch,
            n_boot=n_boot,
            global_norm=global_norm,
            seed=seed,
        )

    @property
    def n_patches(self) -> int | None:
        """The number of spatial patches for which this configuration is
        valid.

        Available only after generating samples with :meth:`get_samples`.

        Returns:
            int if samples have been generated, else None.
        """
        if self._resampling_idx is None:
            return None
        elif self.method == "bootstrap":
            # bootstrap index array has shape (n_boot, n_patches)
            return self._resampling_idx.shape[1]
        else:
            # jackknife index array has shape (n_patches, n_patches - 1)
            return self._resampling_idx.shape[0]

    def _generate_bootstrap(self, n_patches: int) -> NDArray[np.int64]:
        """Generate samples for the bootstrap resampling method.

        For N patches, draw M realisations each containing N randomly chosen
        patches.

        Returns:
            Index array of shape ``(n_boot, n_patches)``.
        """
        N = n_patches
        # a fresh generator seeded with self.seed is created on every call
        rng = np.random.default_rng(seed=self.seed)
        return rng.integers(0, N, size=(self.n_boot, N))

    def _generate_jackknife(self, n_patches: int) -> NDArray[np.int64]:
        """Generate samples for the jackknife resampling method.

        For N patches, draw N realisations by leaving out one of the N
        patches.

        Returns:
            Index array of shape ``(n_patches, n_patches - 1)``.
        """
        N = n_patches
        # The tiled array is the flattened (N, N) matrix whose rows are all
        # [0, ..., N-1]. Its diagonal sits at flat positions i * (N + 1), so
        # deleting every (N + 1)-th element removes index i from row i.
        idx = np.delete(np.tile(np.arange(0, N), N), np.s_[:: N + 1])
        return idx.reshape((N, N - 1))

    def get_samples(self, n_patches: int) -> NDArray[np.int64]:
        """Generate a list of patch indices that produces samples for the
        selected resampling method.

        Args:
            n_patches (:obj:`int`):
                Total number of patches for which the samples are generated.

        Returns:
            Array of patch indices with one row per sample.

        Raises:
            ValueError: If samples have been generated before with a
                different number of patches.

        .. Note::
            Samples are generated only once for each instance. Later calls to
            this method will only check if the number of patches agree with
            the first call and return the initially generated index list.
            Raises a :exc:`ValueError` otherwise.

            The reason is, that the ``bootstrap`` method produces random
            samples, which must be consistent if the resampling is applied to
            different pair count measurements.
        """
        if self._resampling_idx is None:
            if self.method == "jackknife":
                idx = self._generate_jackknife(n_patches)
            else:
                idx = self._generate_bootstrap(n_patches)
            # the dataclass is frozen, so bypass its __setattr__ guard
            object.__setattr__(self, "_resampling_idx", idx)
        elif n_patches != self.n_patches:
            raise ValueError(
                f"'n_patches' does not match, expected {self.n_patches}, but "
                f"got {n_patches}"
            )
        return self._resampling_idx

    def reset(self) -> None:
        """Reset the internally stored patch indices generated by
        :meth:`get_samples`."""
        object.__setattr__(self, "_resampling_idx", None)

    def to_dict(self) -> dict[str, Any]:
        """Return the configuration parameters as a dictionary.

        For the ``jackknife`` method, ``n_boot`` and ``seed`` are omitted. The
        internally cached sample indices are never included.
        """
        if self.method == "jackknife":
            return dict(
                method=self.method,
                crosspatch=self.crosspatch,
                global_norm=self.global_norm,
            )
        # Built explicitly (in field declaration order) instead of using
        # asdict(), which would deep-copy the cached index array only for it
        # to be discarded again.
        return dict(
            method=self.method,
            crosspatch=self.crosspatch,
            n_boot=self.n_boot,
            global_norm=self.global_norm,
            seed=self.seed,
        )