Source code for yaw.correlation.estimators

"""This module implements the correlation estimators and a way to symbolically
represent paircounts (e.g. data-data counts are represented by :obj:`CtsDD`).

The latter is used in :obj:`~yaw.correlation.CorrFunc` to determine which
correlation estimator can be computed for a possibly incomplete set of pair
counts (e.g. only data-data and random-random). Their key property is that they
can be compared, e.g.

>>> CtsDR() == CtsRR()  # random-data is random-random?
False

A special case is the :obj:`DavisPeebles` Estimator, which is the ratio
:math:`DD/DR-1`. In the case of a crosscorrelation it is irrelevant which of the
samples provides a random sample. Therefore, there is the special class
:obj:`CtsMix`, with the property:

>>> CtsMix() == CtsDR()
True
>>> CtsMix() == CtsRD()
True
"""

from __future__ import annotations

import warnings
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:  # pragma: no cover
    from numpy.typing import NDArray

__all__ = [
    "CorrelationEstimator",
    "PeeblesHauser",
    "DavisPeebles",
    "Hamilton",
    "LandySzalay",
]


class EstimatorError(Exception):
    pass


class Cts(ABC):
    """Base class to symbolically represent pair counts."""

    @property
    @abstractmethod
    def _hash(self) -> int:
        """Used for comparison.

        Equivalent pair counts types should have the same hash value, see
        :obj:`CtsMix`."""
        pass

    @property
    @abstractmethod
    def _str(self) -> str:
        pass

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__}>"

    def __str__(self) -> str:
        return self._str

    def __hash__(self) -> int:
        return self._hash

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Cts):
            var1 = set(self._str.split("_"))
            var2 = set(other._str.split("_"))
            return not var1.isdisjoint(var2)
        return NotImplemented


class CtsDD(Cts):
    """Symbolic representation of data-data paircounts."""

    _str = "dd"
    _hash = 1


class CtsDR(Cts):
    """Symbolic representation of data-random paircounts."""

    _str = "dr"
    _hash = 2


class CtsRD(Cts):
    """Symbolic representation of random-data paircounts."""

    _str = "rd"
    _hash = 2


class CtsMix(Cts):
    """Symbolic representation of either data-random or random-data paircounts.

    >>> CtsDR() == CtsMix()
    True
    >>> CtsRD() == CtsMix()
    True
    """

    _str = "dr_rd"
    _hash = 2


class CtsRR(Cts):
    """Symbolic representation of random-random paircounts."""

    _str = "rr"
    _hash = 3


[docs] def cts_from_code(code: str) -> Cts: """Instantiate the correct :obj:`Cts` subclass from a string. Valid options are ``dd``, ``dr``, ``rd``, ``rr``, e.g.: >>> cts_from_code("dr") <CtsDR> """ codes = dict(dd=CtsDD, dr=CtsDR, rd=CtsRD, rr=CtsRR) if code not in codes: raise ValueError(f"unknown pair counts '{code}'") return codes[code]()
class CorrelationEstimator(ABC): name: str """Full name of the estimator.""" short: str """Get a short form representation of the estimator name.""" requires: list[Cts] """Get a symbolic list of pair counts required to evaluate the estimator. """ optional: list[Cts] """Get a symbolic list of optional pair counts that may be used to evaluate the estimator. """ variants: list[CorrelationEstimator] = [] """List of all implemented correlation estimators classes.""" def __init_subclass__(cls, **kwargs): """Add all subclusses to the :obj:`variants` attribute list.""" super().__init_subclass__(**kwargs) cls.variants.append(cls) @classmethod def _warn_enum_zero(cls, counts: NDArray): """Raise a warning if any value in the expression enumerator is zero""" if np.any(counts == 0.0): warnings.warn(f"estimator {cls.short} encontered zeros in enumerator") @classmethod @abstractmethod def eval( cls, *, dd: NDArray, dr: NDArray, rr: NDArray, rd: NDArray | None = None ) -> NDArray: """Method that implements the estimator. Should call :meth:`_warn_enum_zero` to raise warnings on zero-division. """ pass
[docs] class PeeblesHauser(CorrelationEstimator): """Implementation of the Peebles-Hauser correlation estimator :math:`\\frac{DD}{RR} - 1`. """ name = "PeeblesHauser" short = "PH" requires = [CtsDD(), CtsRR()] optional = []
[docs] @classmethod def eval(cls, *, dd: NDArray, rr: NDArray, **kwargs) -> NDArray: """Evaluate the estimator with the given pair counts. Args: dd (:obj:`NDArray`): Data-data pair counts (normalised). rr (:obj:`NDArray`): Random-random pair counts (normalised). Returns: :obj:`NDArray` """ cls._warn_enum_zero(rr) return dd / rr - 1.0
[docs] class DavisPeebles(CorrelationEstimator): """Implementation of the Davis-Peebles correlation estimator :math:`\\frac{DD}{DR} - 1`. .. Note:: Accepts both :math:`DR` and :math:`RD` for the denominator. """ name = "DavisPeebles" short = "DP" requires = [CtsDD(), CtsMix()] optional = []
[docs] @classmethod def eval(cls, *, dd: NDArray, dr_rd: NDArray, **kwargs) -> NDArray: """Evaluate the estimator with the given pair counts. Args: dd (:obj:`NDArray`): Data-data pair counts (normalised). dr_rd (:obj:`NDArray`): Either data-random or random-data pair counts (normalised). Returns: :obj:`NDArray` """ cls._warn_enum_zero(dr_rd) return dd / dr_rd - 1.0
[docs] class Hamilton(CorrelationEstimator): """Implementation of the Hamilton correlation estimator :math:`\\frac{DD \\times RR}{DR \\times RD} - 1`. """ name = "Hamilton" short = "HM" requires = [CtsDD(), CtsDR(), CtsRR()] optional = [CtsRD()]
[docs] @classmethod def eval( cls, *, dd: NDArray, dr: NDArray, rr: NDArray, rd: NDArray | None = None ) -> NDArray: """Evaluate the estimator with the given pair counts. Args: dd (:obj:`NDArray`): Data-data pair counts (normalised). dr (:obj:`NDArray`): Data-random pair counts (normalised). rd (:obj:`NDArray`, optional): Random-data pair counts (normalised). If not provided, use ``dr`` instead. rr (:obj:`NDArray`): Random-random pair counts (normalised). Returns: :obj:`NDArray` """ rd = dr if rd is None else rd enum = dr * rd cls._warn_enum_zero(enum) return (dd * rr) / enum - 1.0
[docs] class LandySzalay(CorrelationEstimator): """Implementation of the Landy-Szalay correlation estimator :math:`\\frac{DD - (DR + RD)}{RR} + 1`. """ name = "LandySzalay" short = "LS" requires = [CtsDD(), CtsDR(), CtsRR()] optional = [CtsRD()]
[docs] @classmethod def eval( cls, *, dd: NDArray, dr: NDArray, rr: NDArray, rd: NDArray | None = None ) -> NDArray: """Evaluate the estimator with the given pair counts. Args: dd (:obj:`NDArray`): Data-data pair counts (normalised). dr (:obj:`NDArray`): Data-random pair counts (normalised). rd (:obj:`NDArray`, optional): Random-data pair counts (normalised). If not provided, use ``dr`` instead. rr (:obj:`NDArray`): Random-random pair counts (normalised). Returns: :obj:`NDArray` """ rd = dr if rd is None else rd cls._warn_enum_zero(rr) return (dd - (dr + rd) + rr) / rr