"""
Implements CorrFunc, which stores all the pair counts needed to compute a
correlation function amplitude, measured in bins of redshift.
Pair counts are stored separately for the data-data, data-random, etc. catalog
combinations used when running the measurements.
"""
from __future__ import annotations
import logging
from abc import abstractmethod
from functools import wraps
from pathlib import Path
from typing import TYPE_CHECKING, Generic, TypeVar
import h5py
from yaw.binning import Binning
from yaw.correlation.corrdata import CorrData
from yaw.correlation.paircounts import (
BaseNormalisedCounts,
NormalisedCounts,
NormalisedScalarCounts,
)
from yaw.utils import parallel, write_version_tag
from yaw.utils.abc import BinwiseData, HdfSerializable, PatchwiseData, Serialisable
from yaw.utils.parallel import Broadcastable, bcast_instance
if TYPE_CHECKING:
from collections.abc import Callable
from typing import Any
from h5py import Group
from numpy.typing import NDArray
from typing_extensions import Self
from yaw.utils.abc import TypeSliceIndex
# Type variable for the pair count container handled by a correlation function
# class; bound to BaseNormalisedCounts.
T = TypeVar("T", bound=BaseNormalisedCounts)
# Names exported by a star-import of this module.
__all__ = [
"CorrFunc",
"ScalarCorrFunc",
]
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class EstimatorError(Exception):
    """Raised when pair counts required by a correlation estimator are missing."""
def named(key):
    """
    Build a decorator that attaches a ``.name`` attribute to a function.

    The decorated function is replaced by a thin forwarding wrapper that keeps
    the original metadata (via :func:`functools.wraps`) and carries `key` in
    its ``name`` attribute.
    """

    def attach_name(func):
        @wraps(func)
        def forward(*args, **kwargs):
            return func(*args, **kwargs)

        # the label lives on the wrapper, the wrapped function stays untouched
        forward.name = key
        return forward

    return attach_name
@named("DP")
def davis_peebles(
    *, dd: NDArray, dr: NDArray | None = None, rd: NDArray | None = None
) -> NDArray:
    """
    Davis-Peebles estimator, ``(dd - mixed) / mixed``.

    Either the DR or the RD pair counts may be omitted, but not both. If both
    are given, the RD pair counts are used as the mixed term.

    Raises:
        EstimatorError:
            If neither `dr` nor `rd` is provided.
    """
    if rd is not None:
        cross = rd
    elif dr is not None:
        cross = dr
    else:
        raise EstimatorError("either 'dr' or 'rd' are required")
    return (dd - cross) / cross
@named("LS")
def landy_szalay(
    *, dd: NDArray, dr: NDArray, rd: NDArray | None = None, rr: NDArray
) -> NDArray:
    """
    Landy-Szalay estimator, ``((dd - dr) + (rr - rd)) / rr``.

    The RD pair counts are optional; when they are omitted, the DR pair counts
    are used in their place.
    """
    second_cross = dr if rd is None else rd
    numerator = (dd - dr) + (rr - second_cross)
    return numerator / rr
@named("SC")
def scalar_correlation(*, dd: NDArray, dr: NDArray | None = None) -> NDArray:
    """
    Scalar field estimator.

    Returns `dd` unchanged if no DR pair counts are given, otherwise the
    difference ``dd - dr``.
    """
    return dd if dr is None else dd - dr
class BaseCorrFunc(
    Generic[T], BinwiseData, PatchwiseData, Serialisable, HdfSerializable, Broadcastable
):
    """
    Base class for correlation function containers that are built on pair
    counts.

    Subclasses should expose their optional pair counts (e.g. dr/rd/rr) as
    properties where needed for backwards-compatibility. `_counts_name` must
    use the same keys as `_counts_dict`; its values are the group names used
    when an instance is serialised to or restored from an HDF5 file.
    """

    __slots__ = ("_counts_dict",)

    _counts_dict: dict[str, T]
    """Normalised pair counts obtained from the data/random catalogs, by kind."""
    _counts_type: type[T]
    """Container type of the values stored in `_counts_dict`."""
    _counts_name: dict[str, str]
    """Maps the keys of `_counts_dict` to the HDF5 group names used when
    serialising data."""

    def _init(self, dd: T, **counts: T | None) -> None:
        """
        Validate the pair counts and fill `_counts_dict`.

        Args:
            dd:
                The data-data pair counts, must be exactly of type
                `_counts_type`.
            **counts:
                Additional pair counts by kind, ``None`` values are skipped.

        Raises:
            TypeError:
                If `dd` is not of type `_counts_type`.
            EstimatorError:
                If no additional keyword arguments are passed at all.
            ValueError:
                If any of the additional pair counts is not compatible with
                `dd`.
        """
        if type(dd) is not self._counts_type:
            raise TypeError(f"pair counts must be of type {self._counts_type}")
        if not counts:
            raise EstimatorError("missing at least one additional pair count")

        self._counts_dict = {"dd": dd}
        for kind, count in counts.items():
            if count is None:
                continue  # optional pair count that was not provided
            try:
                dd.is_compatible(count, require=True)
            except ValueError as err:
                raise ValueError(
                    f"pair counts '{kind}' and 'dd' are not compatible"
                ) from err
            self._counts_dict[kind] = count

    def __repr__(self) -> str:
        kinds = "|".join(self._counts_dict.keys())
        fields = ", ".join(
            (
                f"counts={kinds}",
                f"auto={self.auto}",
                f"binning={self.binning}",
                f"num_patches={self.num_patches}",
            )
        )
        return f"{type(self).__name__}({fields})"

    @property
    def binning(self) -> Binning:
        """The binning, taken from the data-data pair counts."""
        data_data = self.dd
        return data_data.binning

    @property
    def auto(self) -> bool:
        """Whether the pair counts describe an autocorrelation function."""
        data_data = self.dd
        return data_data.auto

    @classmethod
    def from_hdf(cls: type[Self], source: Group) -> Self:
        """
        Restore an instance from an HDF5 group.

        Raises:
            TypeError:
                If the class name stored in the file differs from the name of
                this class.
        """
        try:
            stored_kind = source["kind"][()].decode("utf-8")
        except KeyError:
            # no "kind" dataset stored, fall back to the default class name
            stored_kind = "CorrFunc"
        if stored_kind != cls.__name__:
            raise TypeError(f"input file stores pair counts for type '{stored_kind}'")

        # ignore "version" since this method did not change from legacy
        loaded = {}
        for kind, name in cls._counts_name.items():
            if name in source:
                loaded[kind] = cls._counts_type.from_hdf(source[name])
            else:
                loaded[kind] = None  # group missing in the file
        return cls.from_dict(loaded)

    def to_hdf(self, dest: Group) -> None:
        """Write the class name and one group per stored pair count to `dest`."""
        write_version_tag(dest)
        dest.create_dataset("kind", data=type(self).__name__)
        for kind, count in self._counts_dict.items():
            group = dest.create_group(self._counts_name[kind])
            count.to_hdf(group)

    @classmethod
    def from_file(cls: type[Self], path: Path | str) -> Self:
        """Read an instance from `path` on the root process and pass the
        result through `bcast_instance`."""
        if parallel.on_root():
            logger.info("reading %s from: %s", cls.__name__, path)
            loaded = super().from_file(path)
        else:
            loaded = None
        return bcast_instance(loaded)

    @parallel.broadcasted
    def to_file(self, path: Path | str) -> None:
        """Log the destination and delegate writing to the parent class."""
        logger.info("writing %s to: %s", type(self).__name__, path)
        super().to_file(path)

    def to_dict(self) -> dict[str, Any]:
        """Return a shallow copy of the stored pair counts, keyed by kind."""
        return dict(self._counts_dict)

    @property
    def num_patches(self) -> int:
        """The number of patches, taken from the data-data pair counts."""
        data_data = self.dd
        return data_data.num_patches

    def __eq__(self, other: Any) -> bool:
        """Element-wise comparison on all data attributes, recursive."""
        if type(self) is not type(other):
            return NotImplemented

        mine = self.to_dict()
        theirs = other.to_dict()
        # a pair count that is present on one side only compares against None
        all_kinds = set(mine.keys()) | set(theirs.keys())
        return not any(
            mine.get(kind, None) != theirs.get(kind, None) for kind in all_kinds
        )

    def _make_bin_slice(self, item: TypeSliceIndex) -> Self:
        """Build a new instance from the bin-wise slice of each pair count."""
        sliced = {}
        for kind, count in self._counts_dict.items():
            sliced[kind] = count.bins[item]
        return type(self).from_dict(sliced)

    def _make_patch_slice(self, item: TypeSliceIndex) -> Self:
        """Build a new instance from the patch-wise slice of each pair count."""
        sliced = {}
        for kind, count in self._counts_dict.items():
            sliced[kind] = count.patches[item]
        return type(self).from_dict(sliced)

    def is_compatible(self, other: Any, *, require: bool = False) -> bool:
        """
        Check compatibility with another instance.

        Instances of the same type are compared through their data-data pair
        counts. For differing types the result is ``False``, or a
        ``TypeError`` if `require` is set.
        """
        if type(self) is type(other):
            return self.dd.is_compatible(other.dd, require=require)
        if require:
            raise TypeError(f"{type(other)} is not compatible with {type(self)}")
        return False

    @abstractmethod
    def get_estimator(self) -> Callable[..., NDArray]:
        """Get the most appropriate correlation estimator for evaluating the
        pair counts."""

    def sample(self) -> CorrData:
        """
        Compute an estimate of the correlation function in bins of redshift.

        Sums the pair counts over all spatial patches and evaluates them with
        the estimator returned by :meth:`get_estimator`. The uncertainty of the
        correlation function is obtained from jackknife samples of the spatial
        patches.

        Returns:
            The correlation function estimate with jackknife samples wrapped in
            a :obj:`~yaw.CorrData` instance.
        """
        estimator = self.get_estimator()
        if parallel.on_root():
            logger.debug(
                "sampling correlation function with estimator '%s'", estimator.name
            )

        resampled = {
            kind: paircounts.sample_patch_sum()
            for kind, paircounts in self._counts_dict.items()
        }
        # the estimator is applied once to the values and once to the samples
        values = {kind: result.data for kind, result in resampled.items()}
        samples = {kind: result.samples for kind, result in resampled.items()}

        corr_data = estimator(**values)
        corr_samples = estimator(**samples)
        return CorrData(self.binning, corr_data, corr_samples)

    @property
    def dd(self) -> T:
        """The data-data pair counts."""
        return self._counts_dict["dd"]
class CorrFunc(BaseCorrFunc[NormalisedCounts]):
    """
    Container for correlation function amplitude pair counts.

    The container is typically created by :func:`~yaw.crosscorrelate` or
    :func:`~yaw.autocorrelate` and stores pair counts in bins of redshift and
    per spatial patch of the input :obj:`~yaw.Catalog` s. The data-data,
    data-random, etc. pair counts are stored in separate attributes.

    .. note::
        While the pair counts ``dr``, ``rd``, or ``rr`` are all optional, at
        least one of these pair counts must be provided.

    Additionally implements comparison with the ``==`` operator, addition with
    ``+`` and scaling of the pair counts by a scalar with ``*``.

    Args:
        dd:
            The data-data pair counts as
            :obj:`~yaw.correlation.paircounts.NormalisedCounts`.

    Keyword Args:
        dr:
            The optional data-random pair counts as
            :obj:`~yaw.correlation.paircounts.NormalisedCounts`.
        rd:
            The optional random-data pair counts as
            :obj:`~yaw.correlation.paircounts.NormalisedCounts`.
        rr:
            The optional random-random pair counts as
            :obj:`~yaw.correlation.paircounts.NormalisedCounts`.

    Raises:
        ValueError:
            If any of the pair counts are not compatible (by binning or number
            of patches).
        EstimatorError:
            If none of the optional pair counts are provided.
    """

    __slots__ = ("_counts_dict",)

    _counts_type = NormalisedCounts
    _counts_name = dict(
        dd="data_data", dr="data_random", rd="random_data", rr="random_random"
    )

    def __init__(
        self,
        dd: NormalisedCounts,
        dr: NormalisedCounts | None = None,
        rd: NormalisedCounts | None = None,
        rr: NormalisedCounts | None = None,
    ) -> None:
        # NOTE(review): all three optional counts are always forwarded as
        # keywords, so the "no additional pair count" check in `_init` cannot
        # trigger from here; with only `dd` given, the EstimatorError surfaces
        # when the Davis-Peebles estimator is evaluated - confirm whether the
        # check should happen at construction time instead.
        self._init(dd=dd, dr=dr, rd=rd, rr=rr)

    def get_estimator(self) -> Callable[..., NDArray]:
        """Return the Landy-Szalay estimator if random-random pair counts are
        stored, otherwise the Davis-Peebles estimator."""
        return davis_peebles if self.rr is None else landy_szalay

    @property
    def dr(self) -> NormalisedCounts | None:
        """The data-random pair counts."""
        return self._counts_dict.get("dr", None)

    @property
    def rd(self) -> NormalisedCounts | None:
        """The random-data pair counts."""
        return self._counts_dict.get("rd", None)

    @property
    def rr(self) -> NormalisedCounts | None:
        """The random-random pair counts."""
        return self._counts_dict.get("rr", None)
# Derives directly from BaseCorrFunc: the scalar pair counts use their own
# container type and have no rd/rr pair counts, so the CorrFunc interface does
# not apply here.
class ScalarCorrFunc(BaseCorrFunc[NormalisedScalarCounts]):
    """
    Container for scalar field correlation function amplitude pair counts.

    The container is typically created by :func:`~yaw.crosscorrelate_scalar` or
    :func:`~yaw.autocorrelate_scalar` and stores pair counts in bins of redshift
    and per spatial patch of the input :obj:`~yaw.Catalog` s. The data-data and
    data-random pair counts are stored in separate attributes.

    Additionally implements comparison with the ``==`` operator, addition with
    ``+`` and scaling of the pair counts by a scalar with ``*``.

    Args:
        dd:
            The data-data pair counts as
            :obj:`~yaw.correlation.paircounts.NormalisedScalarCounts`.
        dr:
            The optional data-random pair counts as
            :obj:`~yaw.correlation.paircounts.NormalisedScalarCounts`.

    Raises:
        ValueError:
            If any of the pair counts are not compatible (by binning or number
            of patches).
    """

    __slots__ = ("_counts_dict",)

    _counts_type = NormalisedScalarCounts
    _counts_name = dict(dd="data_data", dr="data_random")

    def __init__(
        self,
        dd: NormalisedScalarCounts,
        dr: NormalisedScalarCounts | None = None,
    ) -> None:
        self._init(dd=dd, dr=dr)

    def get_estimator(self) -> Callable[..., NDArray]:
        """Return the scalar field estimator."""
        return scalar_correlation

    @property
    def dr(self) -> NormalisedScalarCounts | None:
        """The data-random pair counts."""
        return self._counts_dict.get("dr", None)
def load_corrfunc(path: Path | str) -> BaseCorrFunc:
    """
    Read back correlation function pair counts from a HDF5 file.

    Automatically determines, based on the file's metadata, which correlation
    data class to use. Every (direct or indirect) subclass of
    :class:`BaseCorrFunc` is tried in turn until one accepts the file.

    Args:
        path:
            Input HDF5 to read from.

    Returns:
        Correlation function pair count data wrapped in an appropriate instance
        of :class:`BaseCorrFunc`.

    Raises:
        ValueError:
            If none of the correlation data classes accepts the file.
        TypeError:
            If loading fails with a type error that is unrelated to a mismatch
            of the stored class name.
    """
    # collect the whole subclass tree, direct subclasses first, so that classes
    # deriving from another container are considered as well
    candidates: list[type[BaseCorrFunc]] = []
    pending = list(BaseCorrFunc.__subclasses__())
    while pending:
        candidate = pending.pop(0)
        if candidate in candidates:
            continue
        candidates.append(candidate)
        pending.extend(candidate.__subclasses__())

    with h5py.File(str(path)) as f:
        for cls in candidates:
            try:
                return cls.from_hdf(f)
            except TypeError as err:
                # only a mismatch of the stored class name means "try the next
                # class"; any other TypeError is a real error and must not be
                # hidden behind the generic "not compatible" message below
                if "stores pair counts" not in str(err):
                    raise

    raise ValueError(
        "input file is not compatible with any correlation data implementation: "
        + str(path)
    )