Source code for yaw.correlation.corrfuncs

"""Implements the two primary data containers for correlation function data,
:obj:`CorrFunc` which stores the pair counts and :obj:`CorrData` which stores
the (resampled) values of the correlation function in bins of redshift.
"""

from __future__ import annotations

import logging
from collections.abc import Sequence
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, Any, Type, TypeVar

import h5py
import numpy as np
import pandas as pd
from deprecated import deprecated

from yaw.catalogs import PatchLinkage
from yaw.config import OPTIONS, ResamplingConfig
from yaw.core.abc import BinnedQuantity, HDFSerializable, PatchedQuantity
from yaw.core.containers import Indexer, SampledData
from yaw.core.logging import TimedLog
from yaw.core.utils import TypePathStr
from yaw.core.utils import format_float_fixed_width as fmt_num
from yaw.correlation.estimators import (
    CorrelationEstimator,
    CtsMix,
    EstimatorError,
    cts_from_code,
)
from yaw.correlation.paircounts import NormalisedCounts, TypeIndex

if TYPE_CHECKING:  # pragma: no cover
    from matplotlib.axis import Axis
    from numpy.typing import NDArray
    from pandas import IntervalIndex

    from yaw.catalogs import BaseCatalog
    from yaw.config import Configuration
    from yaw.correlation.estimators import Cts

# Names exported by ``from yaw.correlation.corrfuncs import *``.
__all__ = [
    "CorrData",
    "CorrFunc",
    "add_corrfuncs",
    "autocorrelate",
    "crosscorrelate",
]

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Type variable used to annotate classmethods of CorrData (or subclasses)
# that return an instance of ``cls``.
_Tdata = TypeVar("_Tdata", bound="CorrData")


@dataclass(frozen=True, repr=False, eq=False)
class CorrData(SampledData):
    """Container class for sampled correlation function data.

    Contains the redshift binning, correlation function amplitudes, and
    resampled amplitudes (e.g. jackknife or bootstrap). The resampled values
    are used to compute error estimates and covariance/correlation matrices.
    Provides some plotting methods for convenience.

    The comparison, addition and subtraction and indexing rules are inherited
    from :obj:`~yaw.core.containers.SampledData`, see some examples below.

    .. rubric:: Examples

    Create a new instance by sampling a correlation function:

    >>> from yaw.examples import w_sp
    >>> data = w_sp.sample()  # uses the default ResamplingConfig
    >>> data
    CorrData(n_bins=30, z='0.070...1.420', n_samples=64, method='jackknife')

    View the data for a subset of the redshift bins:

    >>> data.bins[5:9].data
    array([0.10158809, 0.08079947, 0.03876175, 0.02715336])

    View the same subset as series:

    >>> data.bins[5:9].get_data()
    (0.295, 0.34]    0.101588
    (0.34, 0.385]    0.080799
    (0.385, 0.43]    0.038762
    (0.43, 0.475]    0.027153
    dtype: float64

    Get the redshift bin centers for these bins:

    >>> data.bins[5:9].mids
    array([0.3175, 0.3625, 0.4075, 0.4525])

    Args:
        binning (:obj:`pandas.IntervalIndex`):
            The redshift bin edges used for this correlation function.
        data (:obj:`NDArray`):
            The correlation function values.
        samples (:obj:`NDArray`):
            The resampled correlation function values.
        method (:obj:`str`):
            The resampling method used, see :class:`~yaw.ResamplingConfig` for
            available options.
        info (:obj:`str`, optional):
            Descriptive text included in the headers of output files produced
            by :func:`CorrData.to_files`.
    """

    info: str | None = None
    """Optional descriptive text for the contained data."""

    def __post_init__(self) -> None:
        super().__post_init__()

    @classmethod
    def from_files(cls: Type[_Tdata], path_prefix: TypePathStr) -> _Tdata:
        """Create a new instance by loading the data from ASCII files.

        The data is restored from a set of three input files produced by
        :meth:`to_files`.

        .. Note::
            These files have the same names but different file extensions,
            therefore only provide the base name without any extension to
            specify the input files.

        Args:
            path_prefix (:obj:`str`):
                The base name of the input files without any file extension.

        Returns:
            :obj:`CorrData`

        Raises:
            ValueError:
                If the header of the samples file cannot be parsed or names an
                unknown resampling method.
        """
        name = cls.__name__.lower()[:-4]
        logger.debug("reading %s data from '%s.*'", name, path_prefix)

        # load data and errors; ``ndmin=2`` keeps the row/column layout even
        # if the file contains only a single redshift bin (a single row)
        data_error = np.loadtxt(f"{path_prefix}.dat", ndmin=2)
        # restore index
        binning = pd.IntervalIndex.from_arrays(data_error[:, 0], data_error[:, 1])

        # load samples, one row per redshift bin (same single-row caveat)
        smp_path = f"{path_prefix}.smp"
        samples = np.loadtxt(smp_path, ndmin=2)

        # load header: optional info line, then the column name line
        info = None
        with open(smp_path) as f:
            for line in f:
                if "extra info" in line:
                    _, info = line.split(":", maxsplit=1)
                    info = info.strip()
                if "z_low" in line:
                    header = line[2:].split()  # remove leading '# '
                    break
            else:
                raise ValueError("sample file header misformatted")
        # the last column is named '<method prefix>_<sample index>'
        method_key, _ = header[-1].rsplit("_", 1)

        # reconstruct sampling method
        for method in OPTIONS.method:
            if method.startswith(method_key):
                break
        else:
            raise ValueError(f"invalid sampling method key '{method_key}'")

        return cls(
            binning=binning,
            data=data_error[:, 2],  # take data column
            samples=samples.T[2:],  # remove redshift bin columns
            method=method,
            info=info,
        )

    @property
    def _dat_desc(self) -> str:
        """Description included in the data file."""
        return (
            "# correlation function estimate with symmetric 68% percentile "
            "confidence"
        )

    @property
    def _smp_desc(self) -> str:
        """Description included in the samples file."""
        return f"# {self.n_samples} {self.method} correlation function samples"

    @property
    def _cov_desc(self) -> str:
        """Description included in the covariance file."""
        return (
            f"# correlation function estimate covariance matrix "
            f"({self.n_bins}x{self.n_bins})"
        )

    def to_files(self, path_prefix: TypePathStr) -> None:
        """Store the data in a set of ASCII files on disk.

        These files can be loaded with the :meth:`from_files` method. There are
        three files with the same name but different file extension.

        .. rubric:: Files

        ``[path_prefix].dat``: Contains the redshift bin edges, the data values
        and their standard error. Additionally there is information about the
        error estimate and the :obj:`info` attribute.

        ``[path_prefix].smp``: Contains one row for each redshift bin. The first
        two columns list the lower and upper edge of the redshift bin, the
        remaining columns list the values of the samples, i.e. there are
        ``N+2`` columns. Additionally contains the :obj:`info` attribute.

        ``[path_prefix].cov``: Contains the covariance matrix and additionally
        the :obj:`info` attribute.

        Args:
            path_prefix (:obj:`str`):
                The base name of the output files without any file extension.
        """
        name = self.__class__.__name__.lower()[:-4]
        logger.info("writing %s data to '%s.*'", name, path_prefix)

        PREC = 10
        DELIM = " "

        def comment(string: str) -> str:
            # append the optional info text as an additional header line
            if self.info is not None:
                string = f"{string}\n# extra info: {self.info}"
            return string

        def write_head(f, description, header, delim=DELIM):
            f.write(f"{description}\n")
            line = delim.join(f"{h:>{PREC}s}" for h in header)
            # the first two padding characters are replaced by '# ' so that
            # the column names stay aligned with the values below
            f.write(f"# {line[2:]}\n")

        # write data and errors
        ext = "dat"
        header = ["z_low", "z_high", "nz", "nz_err"]
        with open(f"{path_prefix}.{ext}", "w") as f:
            write_head(f, comment(self._dat_desc), header, delim=DELIM)
            for zlow, zhigh, nz, nz_err in zip(
                self.edges[:-1], self.edges[1:], self.data, self.error
            ):
                values = [fmt_num(val, PREC) for val in (zlow, zhigh, nz, nz_err)]
                f.write(DELIM.join(values) + "\n")

        # write samples
        ext = "smp"
        header = ["z_low", "z_high"]
        header.extend(f"{self.method[:4]}_{i}" for i in range(self.n_samples))
        with open(f"{path_prefix}.{ext}", "w") as f:
            write_head(f, comment(self._smp_desc), header, delim=DELIM)
            for zlow, zhigh, samples in zip(
                self.edges[:-1], self.edges[1:], self.samples.T
            ):
                values = [fmt_num(zlow, PREC), fmt_num(zhigh, PREC)]
                values.extend(fmt_num(val, PREC) for val in samples)
                f.write(DELIM.join(values) + "\n")

        # write covariance (just for convenience)
        ext = "cov"
        fmt_str = DELIM.join("{: .{prec}e}" for _ in range(self.n_bins)) + "\n"
        with open(f"{path_prefix}.{ext}", "w") as f:
            f.write(f"{comment(self._cov_desc)}\n")
            for values in self.covariance:
                f.write(fmt_str.format(*values, prec=PREC - 3))

    def _make_plot(
        self,
        x: NDArray[np.float64],
        y: NDArray[np.float64],
        yerr: NDArray[np.float64],
        *,
        color: str | NDArray | None = None,
        label: str | None = None,
        error_bars: bool = True,
        ax: Axis | None = None,
        plot_kwargs: dict[str, Any] | None = None,
        zero_line: bool = False,
    ) -> Axis:
        """Plot ``y`` with uncertainties ``yerr`` at positions ``x``.

        Draws error bars or, if ``error_bars`` is false, a line with a shaded
        uncertainty band. The ``color`` and ``label`` arguments take precedence
        over matching entries in ``plot_kwargs``; the caller's dictionary is
        not modified.
        """
        from matplotlib import pyplot as plt

        # configure plot
        if ax is None:
            ax = plt.gca()
        # work on a copy so that the caller's dictionary is left untouched
        plot_kwargs = dict(plot_kwargs) if plot_kwargs is not None else {}
        plot_kwargs.update(dict(color=color, label=label))
        ebar_kwargs = dict(fmt=".", ls="none")
        ebar_kwargs.update(plot_kwargs)

        # plot zero line, matching the line width of the axis spines
        if zero_line:
            lw = 0.7
            for spine in ax.spines.values():
                lw = spine.get_linewidth()
            ax.axhline(0.0, color="k", lw=lw, zorder=-2)

        # plot data
        if error_bars:
            ax.errorbar(x, y, yerr, **ebar_kwargs)
        else:
            color = ax.plot(x, y, **plot_kwargs)[0].get_color()
            ax.fill_between(x, y - yerr, y + yerr, color=color, alpha=0.2)
        return ax

    def plot(
        self,
        *,
        color: str | NDArray | None = None,
        label: str | None = None,
        error_bars: bool = True,
        ax: Axis | None = None,
        xoffset: float = 0.0,
        plot_kwargs: dict[str, Any] | None = None,
        zero_line: bool = False,
        scale_by_dz: bool = False,
    ) -> Axis:
        """Create a plot of the correlation data as a function of redshift.

        Create a new axis or plot to an existing one, add x-axis offsets, if
        plotting multiple instances, or specify if the values should be
        represented as points with errorbars (default) or as line plot with
        shaded area to represent uncertainties.

        Args:
            color:
                Valid :mod:`matplotlib` color used for the error bars or the
                line and the shaded uncertainty area.
            label (:obj:`str`, optional):
                Plot label for the legend.
            error_bars (:obj:`bool`, optional):
                Whether to plot error bars (the default) or a line plot with
                shaded area.
            ax (plot axis, optional):
                Optional :mod:`matplotlib` axis to plot into.
            xoffset (:obj:`float`, optional):
                Shift to apply to the x-axis (redshift) values.
            plot_kwargs (:obj:`dict`, optional):
                Parameters passed to the :func:`errorbar` or :func:`plot`
                plotting functions.
            zero_line (:obj:`bool`, optional):
                Whether to draw a thin black line that indicates ``y=0``.
            scale_by_dz (:obj:`bool`, optional):
                Whether to multiply the y-values by the redshift bin width
                :obj:`dz`.

        Returns:
            The :mod:`matplotlib` axis that was plotted into.
        """
        x = self.mids + xoffset
        # ``astype`` returns copies, so the in-place scaling below is safe
        y = self.data.astype(np.float64)
        yerr = self.error.astype(np.float64)
        if scale_by_dz:
            y *= self.dz
            yerr *= self.dz
        return self._make_plot(
            x,
            y,
            yerr,
            color=color,
            label=label,
            error_bars=error_bars,
            ax=ax,
            plot_kwargs=plot_kwargs,
            zero_line=zero_line,
        )

    def plot_corr(
        self, *, redshift: bool = False, cmap: str = "RdBu_r", ax: Axis | None = None
    ) -> Axis:
        """Plot the correlation matrix of the data.

        Create a new axis or plot to an existing one.

        Args:
            redshift (:obj:`bool`, optional):
                Whether to map the matrix onto redshifts or as regular matrix
                plot (the default).
            cmap (:obj:`str`, optional):
                Name of a :mod:`matplotlib` colormap to use.
            ax (plot axis, optional):
                Optional :mod:`matplotlib` axis to plot into.

        Returns:
            The :mod:`matplotlib` axis that was plotted into.
        """
        from matplotlib import pyplot as plt

        if ax is None:
            ax = plt.gca()
        corr = self.get_correlation()
        cmap_kwargs = dict(cmap=cmap, vmin=-1.0, vmax=1.0)
        if redshift:
            ticks = self.mids
            ax.pcolormesh(ticks, ticks, np.flipud(corr), **cmap_kwargs)
            ax.xaxis.tick_top()
            ax.set_aspect("equal")
        else:
            ax.matshow(corr, **cmap_kwargs)
        return ax
def check_mergable(cfs: Sequence[CorrFunc | None]) -> None:
    """Verify that correlation functions can be merged.

    All inputs must have the same set of pair counts: for each of ``dd``,
    ``dr``, ``rd`` and ``rr``, the attribute must be of the same type (e.g.
    either :obj:`None` for all inputs or pair counts for all inputs) as the
    first input.

    Raises:
        ValueError: If any kind of pair counts differs from the first input.
    """
    first = cfs[0]
    others = cfs[1:]
    for kind in ("dd", "dr", "rd", "rr"):
        expected_type = type(getattr(first, kind))
        mismatch = any(
            type(getattr(other, kind)) is not expected_type for other in others
        )
        if mismatch:
            raise ValueError(f"cannot merge, '{kind}' incompatible")
@dataclass(frozen=True)
class CorrFunc(PatchedQuantity, BinnedQuantity, HDFSerializable):
    """Container object for measured correlation pair counts.

    Container returned by :meth:`~yaw.catalogs.BaseCatalog.correlate` that
    computes the correlations between data catalogs. The correlation function
    can be computed from four kinds of pair counts, data-data (DD), data-random
    (DR), random-data (RD), and random-random (RR).

    .. Note::
        DD is always required, but DR, RD, and RR are optional as long as at
        least one is provided.

    Provides methods to read and write data to disk and compute the actual
    correlation function values (see :class:`~yaw.CorrData`) using spatial
    resampling (see :class:`~yaw.ResamplingConfig`).

    The container supports comparison with ``==`` and ``!=`` on the pair count
    level. The supported arithmetic operations between two correlation
    functions, addition and subtraction, are applied between all internally
    stored pair counts data. The same applies to rescaling of the counts by a
    scalar, see some examples below.

    .. rubric:: Examples

    Create a new instance from example pair counts:

    >>> import yaw
    >>> from yaw.examples import w_sp
    >>> dd, dr = w_sp.dd, w_sp.dr  # get example data-data and data-rand counts
    >>> corr = yaw.CorrFunc(dd=dd, dr=dr)
    >>> corr
    CorrFunc(n_bins=30, z='0.070...1.420', dd=True, dr=True, rd=False, rr=False, n_patches=64)

    Access the pair counts:

    >>> corr.dd
    NormalisedCounts(n_bins=30, z='0.070...1.420', n_patches=64)

    Check if it is an autocorrelation function measurement:

    >>> corr.auto
    False

    Check which pair counts are available to compute the correlation function:

    >>> corr.estimators
    {'DP': yaw.correlation.estimators.DavisPeebles}

    Sample the correlation function

    >>> corr.sample()  # uses the default ResamplingConfig
    CorrData(n_bins=30, z='0.070...1.420', n_samples=64, method='jackknife')

    Note how the indicated shape changes when a patch subset is selected:

    >>> corr.patches[:10]
    CorrFunc(n_bins=30, z='0.070...1.420', dd=True, dr=True, rd=False, rr=False, n_patches=10)

    Note how the indicated redshift range and shape change when a bin subset
    is selected:

    >>> corr.bins[:3]
    CorrFunc(n_bins=3, z='0.070...0.205', dd=True, dr=True, rd=False, rr=False, n_patches=64)

    Args:
        dd (:obj:`~yaw.correlation.paircounts.NormalisedCounts`):
            Pair counts from a data-data count measurement.
        dr (:obj:`~yaw.correlation.paircounts.NormalisedCounts`, optional):
            Pair counts from a data-random count measurement.
        rd (:obj:`~yaw.correlation.paircounts.NormalisedCounts`, optional):
            Pair counts from a random-data count measurement.
        rr (:obj:`~yaw.correlation.paircounts.NormalisedCounts`, optional):
            Pair counts from a random-random count measurement.

    Raises:
        ValueError:
            If none of ``dr``, ``rd``, ``rr`` is given or if any of them is not
            compatible with ``dd`` (binning or number of patches).
    """

    dd: NormalisedCounts
    """Pair counts for a data-data correlation measurement"""
    dr: NormalisedCounts | None = field(default=None)
    """Pair counts from a data-random count measurement."""
    rd: NormalisedCounts | None = field(default=None)
    """Pair counts from a random-data count measurement."""
    rr: NormalisedCounts | None = field(default=None)
    """Pair counts from a random-random count measurement."""

    def __post_init__(self) -> None:
        # check if any random pairs are required
        if self.dr is None and self.rd is None and self.rr is None:
            raise ValueError("either 'dr', 'rd' or 'rr' is required")
        # check that the pair counts are compatible; explicit checks are used
        # instead of ``assert``, which is removed when running with ``-O``
        for kind in ("dr", "rd", "rr"):
            pairs: NormalisedCounts | None = getattr(self, kind)
            if pairs is None:
                continue
            message = f"pair counts '{kind}' and 'dd' are not compatible"
            try:
                self.dd.is_compatible(pairs, require=True)
            except (ValueError, AssertionError) as e:
                raise ValueError(message) from e
            if self.dd.n_patches != pairs.n_patches:
                raise ValueError(message)

    def __repr__(self) -> str:
        string = super().__repr__()[:-1]
        pairs = f"dd=True, dr={self.dr is not None}, "
        pairs += f"rd={self.rd is not None}, rr={self.rr is not None}"
        other = f"n_patches={self.n_patches}"
        return f"{string}, {pairs}, {other})"

    def __eq__(self, other: object) -> bool:
        if isinstance(other, self.__class__):
            for cfield in fields(self):
                kind = cfield.name
                if getattr(self, kind) != getattr(other, kind):
                    return False
            return True
        return NotImplemented

    def __add__(self, other: object) -> CorrFunc:
        """Add the pair counts of two correlation functions.

        Raises:
            ValueError: If a kind of pair counts is set in only one operand.
        """
        if isinstance(other, self.__class__):
            # check that the pair counts are set consistently
            kinds = []
            for cfield in fields(self):
                kind = cfield.name
                self_set = getattr(self, kind) is not None
                other_set = getattr(other, kind) is not None
                if self_set != other_set:
                    raise ValueError(
                        f"pair counts for '{kind}' not set for both operands"
                    )
                if self_set:
                    kinds.append(kind)

            kwargs = {
                kind: getattr(self, kind) + getattr(other, kind) for kind in kinds
            }
            return self.__class__(**kwargs)
        return NotImplemented

    def __radd__(self, other: object) -> CorrFunc:
        # allows ``sum()`` and ``0 + corrfunc``: adding to scalar zero is a
        # no-op; other left operands already declined in their ``__add__``
        if np.isscalar(other) and other == 0:
            return self
        return NotImplemented

    def __mul__(self, other: object) -> CorrFunc:
        """Rescale all stored pair counts by a (non-boolean) scalar."""
        if np.isscalar(other) and not isinstance(other, (bool, np.bool_)):
            # rescale every kind of pair counts that is set
            kwargs = {}
            for cfield in fields(self):
                kind = cfield.name
                counts = getattr(self, kind)
                if counts is not None:
                    kwargs[kind] = counts * other
            return self.__class__(**kwargs)
        return NotImplemented

    @property
    def auto(self) -> bool:
        """Whether the stored data are from an autocorrelation measurement."""
        return self.dd.auto

    @property
    def bins(self) -> Indexer[TypeIndex, CorrFunc]:
        """Indexer that selects a subset of redshift bins of all pair counts.

        A single integer index is converted to a list so that the bin
        dimension is preserved.
        """

        def builder(inst: CorrFunc, item: TypeIndex) -> CorrFunc:
            if isinstance(item, int):
                item = [item]
            kwargs = {}
            for cfield in fields(inst):
                pairs: NormalisedCounts | None = getattr(inst, cfield.name)
                if pairs is None:
                    kwargs[cfield.name] = None
                else:
                    kwargs[cfield.name] = pairs.bins[item]
            return CorrFunc(**kwargs)

        return Indexer(self, builder)

    @property
    def patches(self) -> Indexer[TypeIndex, CorrFunc]:
        """Indexer that selects a subset of spatial patches of all pair counts."""

        def builder(inst: CorrFunc, item: TypeIndex) -> CorrFunc:
            kwargs = {}
            for cfield in fields(inst):
                counts: NormalisedCounts | None = getattr(inst, cfield.name)
                if counts is not None:
                    counts = counts.patches[item]
                kwargs[cfield.name] = counts
            return CorrFunc(**kwargs)

        return Indexer(self, builder)

    def get_binning(self) -> IntervalIndex:
        """Get the redshift binning of the data-data pair counts."""
        return self.dd.get_binning()

    @property
    def n_patches(self) -> int:
        """Number of spatial patches of the data-data pair counts."""
        return self.dd.n_patches

    def is_compatible(self, other: CorrFunc, require: bool = False) -> bool:
        """Check whether this instance is compatible with another instance.

        Ensures that the redshift binning and the number of patches are
        identical.

        Args:
            other (:obj:`BinnedQuantity`):
                Object instance to compare to.
            require (:obj:`bool`)
                Raise a ValueError if any of the checks fail.

        Returns:
            :obj:`bool`
        """
        if self.dd.n_patches != other.dd.n_patches:
            if require:
                raise ValueError("number of patches does not agree")
            return False
        return self.dd.is_compatible(other.dd, require)

    @property
    def estimators(self) -> dict[str, type[CorrelationEstimator]]:
        """Get a listing of correlation estimators implemented, depending on
        which pair counts are available.

        Returns:
            :obj:`dict`: Mapping from correlation estimator name abbreviation to
            correlation function class.
        """
        # figure out which of dd, dr, ... are not None
        available = set()
        # iterate all dataclass attributes that are in __init__
        for attr in fields(self):
            if getattr(self, attr.name) is not None:
                available.add(cts_from_code(attr.name))
        # check which estimators are supported
        estimators = {}
        for estimator in CorrelationEstimator.variants:  # registered estimators
            if set(estimator.requires) <= available:
                estimators[estimator.short] = estimator
        return estimators

    def _check_and_select_estimator(
        self, estimator: str | None = None
    ) -> type[CorrelationEstimator]:
        """Resolve an estimator name to a usable estimator class.

        If ``estimator`` is None, the first available of ``LS``, ``DP``,
        ``PH`` is selected.

        Raises:
            ValueError: If the estimator name is not registered.
            EstimatorError: If the available pair counts do not support the
                requested (or any preferred) estimator.
        """
        options = self.estimators
        if estimator is None:
            for shortname in ["LS", "DP", "PH"]:  # preferred hierarchy
                if shortname in options:
                    estimator = shortname
                    break
            else:
                raise EstimatorError(
                    "no default estimator available for the given pair counts"
                )
        estimator = estimator.upper()

        if estimator not in options:
            try:
                index = [e.short for e in CorrelationEstimator.variants].index(
                    estimator
                )
                est_class = CorrelationEstimator.variants[index]
            except ValueError as e:
                raise ValueError(f"invalid estimator '{estimator}'") from e
            # determine which pair counts are missing
            for attr in fields(self):
                name = attr.name
                cts = cts_from_code(name)
                if getattr(self, name) is None and cts in est_class.requires:
                    raise EstimatorError(f"estimator requires {name}")
            # requirements could not be attributed to a single missing field
            raise EstimatorError(
                f"estimator '{estimator}' not supported by the available pair counts"
            )

        # select the correct estimator
        cls = options[estimator]
        logger.debug(
            "selecting estimator '%s' from %s", cls.short, "/".join(self.estimators)
        )
        return cls

    def _getattr_from_cts(self, cts: Cts) -> NormalisedCounts | None:
        """Get the pair counts matching an estimator term.

        For a mixed term (:obj:`CtsMix`), whose string form lists alternative
        codes separated by ``_``, the first pair counts that are set are
        returned (or None if none are set).
        """
        if isinstance(cts, CtsMix):
            value = None
            for code in str(cts).split("_"):
                value = getattr(self, code)
                if value is not None:
                    break
            return value
        return getattr(self, str(cts))

    @deprecated(reason="renamed to CorrFunc.sample", version="2.3.1")
    def get(self, *args, **kwargs):
        """
        .. deprecated:: 2.3.1
            Renamed to :meth:`sample`.
        """
        return self.sample(*args, **kwargs)  # pragma: no cover

    def _sample_terms(
        self, terms, config: ResamplingConfig, *, required: bool
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Resample the pair counts for a set of estimator terms.

        Returns two dictionaries keyed by the term name, one with the data
        values and one with the resampled values. Terms without pair counts
        are skipped if optional, otherwise an :obj:`EstimatorError` is raised.
        """
        data = {}
        samples = {}
        for cts in terms:
            counts = self._getattr_from_cts(cts)
            if counts is None:
                if required:
                    raise EstimatorError(f"estimator requires {cts}")
                continue
            sampled = counts.sample(config)
            data[str(cts)] = sampled.data
            samples[str(cts)] = sampled.samples
        return data, samples

    def sample(
        self,
        config: ResamplingConfig | None = None,
        *,
        estimator: str | None = None,
        info: str | None = None,
    ) -> CorrData:
        """Compute the correlation function from the stored pair counts,
        including an error estimate from spatial resampling of patches.

        Args:
            config (:obj:`~yaw.ResamplingConfig`):
                Specify the resampling method and its configuration.

        Keyword Args:
            estimator (:obj:`str`, optional):
                The name abbreviation for the correlation estimator to use.
                Defaults to the first available of ``LS`` (Landy-Szalay),
                ``DP`` (Davis-Peebles) and ``PH``.
            info (:obj:`str`, optional):
                Descriptive text passed on to the output :obj:`CorrData`
                object.

        Returns:
            :obj:`CorrData`:
                Correlation function data, including redshift binning, function
                values and samples.
        """
        if config is None:
            config = ResamplingConfig()
        est_fun = self._check_and_select_estimator(estimator)
        logger.debug("computing correlation and %s samples", config.method)

        # get the pair counts for the required terms (DD, maybe DR and/or RR)
        required_data, required_samples = self._sample_terms(
            est_fun.requires, config, required=True
        )
        # get the pair counts for the optional terms (e.g. RD), if available
        optional_data, optional_samples = self._sample_terms(
            est_fun.optional, config, required=False
        )

        # evaluate the correlation estimator
        data = est_fun.eval(**required_data, **optional_data)
        samples = est_fun.eval(**required_samples, **optional_samples)
        return CorrData(
            binning=self.get_binning(),
            data=data,
            samples=samples,
            method=config.method,
            info=info,
        )

    @classmethod
    def from_hdf(cls, source: h5py.File | h5py.Group) -> CorrFunc:
        """Restore the pair counts from an HDF5 file or group.

        The ``data_data`` group is required; missing ``data_random``,
        ``random_data`` or ``random_random`` groups are restored as None.
        """

        def _try_load(root: h5py.Group, name: str) -> NormalisedCounts | None:
            try:
                return NormalisedCounts.from_hdf(root[name])
            except KeyError:
                return None

        dd = NormalisedCounts.from_hdf(source["data_data"])
        dr = _try_load(source, "data_random")
        rd = _try_load(source, "random_data")
        rr = _try_load(source, "random_random")
        return cls(dd=dd, dr=dr, rd=rd, rr=rr)

    def to_hdf(self, dest: h5py.File | h5py.Group) -> None:
        """Write the pair counts to an HDF5 file or group, one group per
        available kind of pair counts, plus the number of patches."""
        group = dest.create_group("data_data")
        self.dd.to_hdf(group)
        group_names = dict(dr="data_random", rd="random_data", rr="random_random")
        for kind, name in group_names.items():
            data: NormalisedCounts | None = getattr(self, kind)
            if data is not None:
                group = dest.create_group(name)
                data.to_hdf(group)
        dest.create_dataset("n_patches", data=self.n_patches)

    @classmethod
    def from_file(cls, path: TypePathStr) -> CorrFunc:
        """Load the pair counts from an HDF5 file written by :meth:`to_file`."""
        logger.debug("reading pair counts from '%s'", path)
        # open explicitly read-only, older h5py versions default to append mode
        with h5py.File(str(path), mode="r") as f:
            return cls.from_hdf(f)

    def to_file(self, path: TypePathStr) -> None:
        """Write the pair counts to an HDF5 file, overwriting existing files."""
        logger.info("writing pair counts to '%s'", path)
        with h5py.File(str(path), mode="w") as f:
            self.to_hdf(f)

    def _concatenate(self, method: str, cfs: Sequence[CorrFunc]) -> CorrFunc:
        """Apply the pair count concatenation ``method`` to every kind of
        pair counts that is set, after checking that all inputs match."""
        check_mergable([self, *cfs])
        merged = {}
        for kind in ("dd", "dr", "rd", "rr"):
            self_pcounts = getattr(self, kind)
            if self_pcounts is not None:
                other_pcounts = [getattr(cf, kind) for cf in cfs]
                merged[kind] = getattr(self_pcounts, method)(*other_pcounts)
        return self.__class__(**merged)

    def concatenate_patches(self, *cfs: CorrFunc) -> CorrFunc:
        """Concatenate the patches of this and other correlation functions."""
        return self._concatenate("concatenate_patches", cfs)

    def concatenate_bins(self, *cfs: CorrFunc) -> CorrFunc:
        """Concatenate the redshift bins of this and other correlation functions."""
        return self._concatenate("concatenate_bins", cfs)
def _create_dummy_counts(counts: Any | dict[str, Any]) -> dict[str, None]: """Duplicate a the return values of :meth:`yaw.catalogs.BaseCatalog.correlate`, but replace the :obj:`CorrFunc` instances by :obj:`None`.""" if isinstance(counts, dict): dummy = {scale_key: None for scale_key in counts} else: dummy = None return dummy
def add_corrfuncs(
    corrfuncs: Sequence[CorrFunc], weights: Sequence[np.number] | None = None
) -> CorrFunc:
    """Add correlation functions that are measured at different scales.

    The correlation functions are added by summing together their pair counts.
    They can be weighted prior to summation by effectively scaling their pair
    counts with a set of scalar weights, one for each input correlation
    function.

    .. Note::
        The actual scales are not checked, but the number of patches and the
        redshift binning of the inputs must be identical.

    This operation is effectively equivalent to:

    >>> corrfunc1 * weight1 + corrfunc2 * weight2  # + ...

    Args:
        corrfuncs (sequence of :obj:`CorrFunc`):
            A list of correlation functions to add.
        weights (sequence of :obj:`int` or :obj:`float`, optional):
            An optional list of weights, one for each correlation function.

    Returns:
        :obj:`CorrFunc`:
            The combined correlation function after summing the pairs.

    Raises:
        ValueError:
            If no correlation function is given or the number of weights does
            not match the number of correlation functions.
    """
    if len(corrfuncs) == 0:
        raise ValueError("at least one correlation function is required")
    if weights is None:
        weights = [1.0] * len(corrfuncs)
    elif len(corrfuncs) != len(weights):
        raise ValueError(
            "number of weights must match number of correlation functions"
        )

    # run summation, rescaling by weights; start from the first weighted term
    # so that the result is never the scalar start value of a plain sum
    terms = zip(corrfuncs, weights)
    first_corrfunc, first_weight = next(terms)
    combined = first_corrfunc * first_weight
    for corrfunc, weight in terms:
        combined = combined + (corrfunc * weight)
    return combined
class PatchError(Exception): pass def _check_patch_centers(catalogues: Sequence[BaseCatalog]) -> None: """Check whether the patch centers of a set of data catalogues are seperated by no more than the radius of the patches.""" refcat = catalogues[0] for cat in catalogues[1:]: if refcat.n_patches != cat.n_patches: raise PatchError("number of patches does not agree") ref_coord = refcat.centers.to_sky() cat_coord = cat.centers.to_sky() dist = ref_coord.distance(cat_coord) if np.any(dist.values > refcat.radii.values): raise PatchError("the patch centers are inconsistent")
def autocorrelate(
    config: Configuration,
    data: BaseCatalog,
    random: BaseCatalog,
    *,
    linkage: PatchLinkage | None = None,
    compute_rr: bool = True,
    progress: bool = False,
) -> CorrFunc | dict[str, CorrFunc]:
    """Compute an angular autocorrelation function in bins of redshift.

    The correlation is measured on fixed physical scales that are converted to
    angles for each redshift bin. All parameters (binning, scales, etc.) are
    bundled in the input configuration, see :mod:`yaw.config`.

    .. Note::
        Both the data and random catalogue require redshift point estimates.

    Args:
        config (:obj:`~yaw.config.Configuration`):
            Provides all major run parameters, such as scales, binning, and for
            the correlation measurement backend.
        data (:obj:`~yaw.catalogs.BaseCatalog`):
            The data sample catalogue.
        random (:obj:`~yaw.catalogs.BaseCatalog`):
            Random catalogue for the data sample.

    Keyword Args:
        linkage (:obj:`~yaw.catalogs.PatchLinkage`, optional):
            Provide a linkage object that determines which spatial patches must
            be correlated given the measurement scales. Ensures consistency
            when measuring correlations repeatedly for a fixed set of input
            catalogues. Generated automatically from the random catalogue by
            default.
        compute_rr (:obj:`bool`):
            Whether the random-random (RR) pair counts are computed.
        progress (:obj:`bool`):
            Display a progress bar.

    Returns:
        :obj:`CorrFunc` or :obj:`dict[str, CorrFunc]`:
            Container that holds the measured pair counts, or a dictionary of
            containers if multiple scales are configured. Dictionary keys have
            a ``kpcXXtXX`` pattern, where ``XX`` are the lower and upper scale
            limit as integers, in kpc (see :obj:`yaw.core.cosmology.Scale`).
    """
    _check_patch_centers([data, random])

    scale_limits = config.scales.as_array()
    logger.info(
        "running autocorrelation (%i scales, %.0f<r<=%.0fkpc)",
        len(scale_limits),
        scale_limits.min(),
        scale_limits.max(),
    )

    if linkage is None:
        linkage = PatchLinkage.from_setup(config, random)
    count_kwargs = {"linkage": linkage, "progress": progress}

    logger.debug("scheduling DD, DR" + (", RR" if compute_rr else ""))
    with TimedLog(logger.info, "counting data-data pairs"):
        counts_dd = data.correlate(config, binned=True, **count_kwargs)
    with TimedLog(logger.info, "counting data-rand pairs"):
        counts_dr = data.correlate(config, binned=True, other=random, **count_kwargs)
    if not compute_rr:
        # placeholders with the same structure (single value or per-scale dict)
        counts_rr = _create_dummy_counts(counts_dd)
    else:
        with TimedLog(logger.info, "counting rand-rand pairs"):
            counts_rr = random.correlate(config, binned=True, **count_kwargs)

    if not isinstance(counts_dd, dict):
        return CorrFunc(dd=counts_dd, dr=counts_dr, rr=counts_rr)
    return {
        scale: CorrFunc(dd=counts_dd[scale], dr=counts_dr[scale], rr=counts_rr[scale])
        for scale in counts_dd
    }
def crosscorrelate(
    config: Configuration,
    reference: BaseCatalog,
    unknown: BaseCatalog,
    *,
    ref_rand: BaseCatalog | None = None,
    unk_rand: BaseCatalog | None = None,
    linkage: PatchLinkage | None = None,
    progress: bool = False,
) -> CorrFunc | dict[str, CorrFunc]:
    """Compute an angular crosscorrelation function in bins of redshift.

    The correlation is measured on fixed physical scales that are converted to
    angles for each redshift bin. All parameters (binning, scales, etc.) are
    bundled in the input configuration, see :mod:`yaw.config`.

    At least one random catalogue (either for the reference or the unknown
    sample) must be provided. The unknown random catalogue triggers counting
    the DR (reference vs. unknown random) pairs, the reference random
    catalogue triggers counting the RD (reference random vs. unknown) pairs.
    If both random catalogues are provided, the random-random pairs (RR) are
    counted as well, this is equivalent to enabling the ``compute_rr``
    parameter in :func:`autocorrelate`.

    .. Note::
        The reference catalogue requires redshift point estimates. If the
        reference random catalogue is provided, it also requires redshifts.

    Args:
        config (:obj:`~yaw.config.Configuration`):
            Provides all major run parameters.
        reference (:obj:`yaw.catalogs.BaseCatalog`):
            The reference sample.
        unknown (:obj:`yaw.catalogs.BaseCatalog`):
            The sample with unknown redshift distribution.

    Keyword Args:
        ref_rand (:obj:`yaw.catalogs.BaseCatalog`, optional):
            Random catalog for the reference sample, requires redshifts
            configured.
        unk_rand (:obj:`yaw.catalogs.BaseCatalog`, optional):
            Random catalog for the unknown sample.
        linkage (:obj:`yaw.catalogs.PatchLinkage`, optional):
            Provide a linkage object that determines which spatial patches must
            be correlated given the measurement scales. Ensures consistency
            when measuring multiple correlations, otherwise generated
            automatically from the unknown sample.
        progress (:obj:`bool`):
            Display a progress bar.

    Returns:
        :obj:`CorrFunc` or :obj:`dict[str, CorrFunc]`:
            Container that holds the measured pair counts, or a dictionary of
            containers if multiple scales are configured. Dictionary keys have
            a ``kpcXXtXX`` pattern, where ``XX`` are the lower and upper scale
            limit as integers, in kpc (see :obj:`yaw.core.cosmology.Scale`).

    Raises:
        ValueError:
            If neither ``ref_rand`` nor ``unk_rand`` is provided.
    """
    compute_dr = unk_rand is not None
    compute_rd = ref_rand is not None
    compute_rr = compute_dr and compute_rd
    # fail early: CorrFunc requires at least one of DR, RD or RR, so without
    # any random catalogue the expensive data-data counting would be wasted
    if not (compute_dr or compute_rd):
        raise ValueError(
            "at least one random catalogue ('ref_rand' or 'unk_rand') is required"
        )

    # make sure that the patch centers are consistent
    all_cats = [reference, unknown]
    if compute_dr:
        all_cats.append(unk_rand)
    if compute_rd:
        all_cats.append(ref_rand)
    _check_patch_centers(all_cats)

    scales = config.scales.as_array()
    logger.info(
        "running crosscorrelation (%i scales, %.0f<r<=%.0fkpc)",
        len(scales),
        scales.min(),
        scales.max(),
    )
    if linkage is None:
        linkage = PatchLinkage.from_setup(config, unknown)
    logger.debug(
        "scheduling DD"
        + (", DR" if compute_dr else "")
        + (", RD" if compute_rd else "")
        + (", RR" if compute_rr else "")
    )
    kwargs = dict(linkage=linkage, progress=progress)

    with TimedLog(logger.info, "counting data-data pairs"):
        DD = reference.correlate(config, binned=False, other=unknown, **kwargs)
    if compute_dr:
        with TimedLog(logger.info, "counting data-rand pairs"):
            DR = reference.correlate(config, binned=False, other=unk_rand, **kwargs)
    else:
        DR = _create_dummy_counts(DD)
    if compute_rd:
        with TimedLog(logger.info, "counting rand-data pairs"):
            RD = ref_rand.correlate(config, binned=False, other=unknown, **kwargs)
    else:
        RD = _create_dummy_counts(DD)
    if compute_rr:
        with TimedLog(logger.info, "counting rand-rand pairs"):
            RR = ref_rand.correlate(config, binned=False, other=unk_rand, **kwargs)
    else:
        RR = _create_dummy_counts(DD)

    if isinstance(DD, dict):
        result = {
            scale: CorrFunc(dd=DD[scale], dr=DR[scale], rd=RD[scale], rr=RR[scale])
            for scale in DD
        }
    else:
        result = CorrFunc(dd=DD, dr=DR, rd=RD, rr=RR)
    return result