# Source code for yaw.catalogs.treecorr.catalog

from __future__ import annotations

import itertools
import os
import sys
from collections.abc import Iterator
from typing import TYPE_CHECKING, Dict, NoReturn, Tuple

try:  # pragma: no cover
    from typing import TypeAlias
except ImportError:  # pragma: no cover
    from typing_extensions import TypeAlias

import numpy as np
import pandas as pd
from numpy.typing import NDArray
from treecorr import Catalog, NNCorrelation

from yaw.catalogs import BaseCatalog
from yaw.config import Configuration, ResamplingConfig
from yaw.core.coordinates import Coord3D, Coordinate, CoordSky, DistSky
from yaw.core.logging import TimedLog
from yaw.correlation.paircounts import (
    NormalisedCounts,
    PatchedCount,
    PatchedTotal,
    pack_results,
)
from yaw.redshifts import HistData

if TYPE_CHECKING:  # pragma: no cover
    from pandas import DataFrame, Interval

    from yaw.catalogs import PatchLinkage

# public names exported by this module
__all__ = ["EmptyCatalog", "TreecorrCatalog"]


# Type of ``NNCorrelation.results`` as consumed in ``TreecorrCatalog.correlate``:
# maps a pair of patch indices (pid1, pid2) to the NNCorrelation holding the
# pair counts of that patch pair. ``typing.Dict``/``Tuple`` are used instead of
# builtin generics because this module-level alias is evaluated at runtime.
TypeNNResult: TypeAlias = Dict[Tuple[int, int], NNCorrelation]  # supports py3.8


def _iter_bin_masks(
    data: NDArray, bins: NDArray, closed: str = "left"
) -> Iterator[tuple[Interval, NDArray[np.bool_]]]:
    """Walk the bins defined by ``bins`` and yield one selection mask per bin.

    Args:
        data: Values to assign to the bins.
        bins: Bin edges as accepted by :func:`numpy.digitize` (monotonic);
            ``len(bins) - 1`` bins are produced.
        closed: Inclusive side of every bin, either ``"left"`` or ``"right"``.

    Yields:
        Pairs of the :obj:`pandas.Interval` describing the bin and a boolean
        array that is ``True`` for the elements of ``data`` inside that bin.
        Values outside the outermost edges are selected by no mask.

    Raises:
        ValueError: If ``closed`` is neither ``"left"`` nor ``"right"`` (raised
            on the first iteration, since this is a generator).
    """
    if closed != "left" and closed != "right":
        raise ValueError("'closed' must be either of 'left', 'right'")
    bin_intervals = pd.IntervalIndex.from_breaks(bins, closed=closed)
    right_inclusive = closed == "right"
    index_of_bin = np.digitize(data, bins, right=right_inclusive)
    for position, bin_interval in enumerate(bin_intervals):
        # np.digitize labels the first bin with 1; 0 and len(bins) mark values
        # below/above the edges, which therefore never match a bin
        yield bin_interval, index_of_bin == position + 1


def take_subset(
    cat: TreecorrCatalog, items: NDArray[np.bool_] | NDArray[np.int64] | slice
) -> TreecorrCatalog | EmptyCatalog:
    """Build a new catalogue from a selection of the entries of ``cat``.

    Args:
        cat: Catalogue to select from.
        items: Boolean mask, integer indices, or slice applied to every
            per-object column of ``cat``.

    Returns:
        A :obj:`TreecorrCatalog` holding the selected positions (passed on as
        radian), patch assignments, and redshifts/weights if ``cat`` has them;
        or an :obj:`EmptyCatalog` with ``cat.n_patches`` patches if the
        selection contains no objects.
    """
    ra_subset = cat.ra[items]
    if len(ra_subset) == 0:
        return EmptyCatalog(n_patches=cat.n_patches)

    columns = {
        "ra": ra_subset,
        "ra_units": "radian",
        "dec": cat.dec[items],
        "dec_units": "radian",
        # NOTE(review): the new catalogue presumably infers its number of
        # patches from the patch IDs present in the subset; if the
        # highest-numbered patches are absent, its get_totals() may be shorter
        # than cat.n_patches -- verify against treecorr's ``patch`` handling.
        "patch": cat.patch[items],
    }
    if cat.has_redshifts():
        columns["r"] = cat.redshifts[items]
    if cat.has_weights():
        columns["w"] = cat.weights[items]

    subset = Catalog(**columns)
    return TreecorrCatalog.from_treecorr(subset)


class EmptyCatalog:
    """A minimal stand-in for a :obj:`TreecorrCatalog` that contains no data.

    Instances are returned by :func:`take_subset` (and therefore by
    :meth:`TreecorrCatalog.iter_bins`) when a selection contains no objects,
    so that callers can still query the per-patch totals of that bin.

    Args:
        n_patches (:obj:`int`):
            Number of patches of the catalogue this empty subset was taken
            from; sets the length of the array returned by :meth:`get_totals`.
    """

    def __init__(self, n_patches: int) -> None:
        self.n_patches = n_patches

    def __repr__(self) -> str:
        return f"{type(self).__name__}(n_patches={self.n_patches})"

    def __len__(self) -> int:
        """Number of objects, always zero."""
        return 0

    @property
    def total(self) -> float:
        """Sum of weights, always zero.

        A property (not a method) to match ``TreecorrCatalog.total``, so that
        code handling either type can read ``cat.total`` uniformly.
        """
        return 0.0

    def get_totals(self) -> NDArray[np.float64]:
        """Per-patch sum of weights: an array of ``n_patches`` zeros."""
        return np.zeros(self.n_patches)


class TreecorrCatalog(BaseCatalog):
    """An implementation of the :obj:`BaseCatalog` using ``TreeCorr`` for the
    pair counting.

    .. Note::
        The current implementation is not very efficient, because the internal
        fields of the underlying :obj:`treecorr.Catalog` must be rebuilt every
        time the catalog is iterated in redshift bins for the pair counting.

    .. Warning::
        Currently this backend does not support restoration from cache and
        raises an :obj:`NotImplementedError` when calling the
        :meth:`from_cache` method.
    """

    def __init__(
        self,
        data: DataFrame,
        ra_name: str,
        dec_name: str,
        *,
        patch_name: str | None = None,
        patch_centers: BaseCatalog | Coordinate | None = None,
        n_patches: int | None = None,
        redshift_name: str | None = None,
        weight_name: str | None = None,
        cache_directory: str | None = None,
        progress: bool = False,
    ) -> None:
        """Build the catalogue from the columns of a data frame.

        Only one patch definition is used; if several are given, the first of
        ``n_patches``, ``patch_name``, ``patch_centers`` takes precedence.

        Args:
            data (:obj:`pandas.DataFrame`):
                Input data holding the columns referenced below.
            ra_name (:obj:`str`):
                Column with right ascension, in degrees.
            dec_name (:obj:`str`):
                Column with declination, in degrees.

        Keyword Args:
            patch_name (:obj:`str`, optional):
                Column with the patch index of each object.
            patch_centers (:obj:`BaseCatalog`, :obj:`Coordinate`, optional):
                Catalogue whose patch centers, or explicit coordinates, are
                used as patch centers.
            n_patches (:obj:`int`, optional):
                Number of patches for ``TreeCorr`` to create (``npatch``).
            redshift_name (:obj:`str`, optional):
                Column with redshifts, passed to ``TreeCorr`` as ``r``.
            weight_name (:obj:`str`, optional):
                Column with object weights, passed to ``TreeCorr`` as ``w``.
            cache_directory (:obj:`str`, optional):
                Existing directory passed as ``save_patch_dir``; if given, the
                catalogue is unloaded after construction.
            progress (:obj:`bool`, optional):
                Not used by this backend.

        Raises:
            FileNotFoundError: If ``cache_directory`` does not exist.
            ValueError: If no patch definition is provided.
        """
        # construct the underlying TreeCorr catalogue
        kwargs = {}
        if cache_directory is not None:
            kwargs["save_patch_dir"] = cache_directory
            if not os.path.exists(cache_directory):
                raise FileNotFoundError(
                    f"patch directory does not exist: '{cache_directory}'"
                )
            self._logger.info("using cache directory '%s'", cache_directory)

        # select the patch definition, first match wins
        if n_patches is not None:
            kwargs["npatch"] = n_patches
            log_msg = f"creating {n_patches} patches"
        elif patch_name is not None:
            kwargs["patch"] = data[patch_name]
            log_msg = "splitting data into predefined patches"
        elif isinstance(patch_centers, BaseCatalog):
            kwargs["patch_centers"] = patch_centers.centers.to_3d().values
            n_patches = patch_centers.n_patches
            log_msg = f"applying {n_patches} patches from external data"
        elif isinstance(patch_centers, Coordinate):
            centers = patch_centers.to_3d()
            kwargs["patch_centers"] = centers.values
            n_patches = len(centers)
            log_msg = f"applying {n_patches} patches from external data"
        else:
            raise ValueError(
                "either of 'patch_name', 'patch_centers', or 'n_patches' "
                "must be provided"
            )

        with TimedLog(self._logger.info, log_msg):
            self._catalog = Catalog(
                ra=data[ra_name],
                ra_units="degrees",
                dec=data[dec_name],
                dec_units="degrees",
                r=None if redshift_name is None else data[redshift_name],
                w=None if weight_name is None else data[weight_name],
                **kwargs,
            )
            self._make_patches()

        if cache_directory is not None:
            self.unload()

    @classmethod
    def from_cache(cls, cache_directory: str, progress: bool = False) -> NoReturn:
        """
        Raises:
            NotImplementedError

        .. Warning::
            Currently this backend does not support restoration from cache.
        """
        raise NotImplementedError("restoring from cache is currently not supported")

    @classmethod
    def from_treecorr(cls, cat: Catalog) -> TreecorrCatalog:
        """Create a new instance from a :obj:`treecorr.Catalog`.

        ``__init__`` is bypassed; the given catalogue is wrapped as-is and its
        patches are built if they do not exist yet.
        """
        new = cls.__new__(cls)
        new._catalog = cat
        new._make_patches()
        return new

    def _make_patches(self) -> None:
        """Build the per-patch sub-catalogues of the wrapped catalogue once.

        ``low_mem`` is requested only when the catalogue is unloaded and has a
        patch directory to write to.
        """
        c = self._catalog
        if c._patches is None:
            low_mem = (not self.is_loaded()) and (c.save_patch_dir is not None)
            c.get_patches(low_mem=low_mem)

    def to_treecorr(self) -> Catalog:
        """Get the internal :obj:`treecorr.Catalog` instance."""
        return self._catalog

    def __len__(self) -> int:
        return self._catalog.ntot

    def __getitem__(self, item: int) -> Catalog:
        """Get the sub-catalogue of a single patch."""
        return self._catalog._patches[item]

    @property
    def ids(self) -> list[int]:
        """Patch indices, ``0`` to ``n_patches - 1``."""
        return list(range(self.n_patches))

    @property
    def n_patches(self) -> int:
        return self._catalog.npatch

    def __iter__(self) -> Iterator[Catalog]:
        """Iterate the per-patch sub-catalogues."""
        yield from self._catalog._patches

    def is_loaded(self) -> bool:
        return self._catalog.loaded

    def load(self) -> None:
        super().load()
        self._catalog.load()

    def unload(self) -> None:
        super().unload()
        self._catalog.unload()

    def has_redshifts(self) -> bool:
        return self.redshifts is not None

    def has_weights(self) -> bool:
        return self.weights is not None

    @property
    def ra(self) -> NDArray[np.float64]:
        return self._catalog.ra

    @property
    def dec(self) -> NDArray[np.float64]:
        return self._catalog.dec

    @property
    def redshifts(self) -> NDArray[np.float64] | None:
        return self._catalog.r

    @property
    def weights(self) -> NDArray[np.float64] | None:
        return self._catalog.w

    @property
    def patch(self) -> NDArray[np.int64]:
        return self._catalog.patch

    def get_min_redshift(self) -> float | None:
        """Get the smallest redshift, or ``None`` without redshifts.

        The value is computed once and cached on the instance.
        """
        if not hasattr(self, "_zmin"):
            self._zmin = self.redshifts.min() if self.has_redshifts() else None
        return self._zmin

    def get_max_redshift(self) -> float | None:
        """Get the largest redshift, or ``None`` without redshifts.

        The value is computed once and cached on the instance.
        """
        if not hasattr(self, "_zmax"):
            self._zmax = self.redshifts.max() if self.has_redshifts() else None
        return self._zmax

    @property
    def total(self) -> float:
        """Sum of weights of the wrapped catalogue (``treecorr`` ``sumw``)."""
        return self._catalog.sumw

    def get_totals(self) -> NDArray[np.float64]:
        """Sum of weights (``sumw``) of each patch, in patch order."""
        return np.array([patch.sumw for patch in self])

    @property
    def centers(self) -> CoordSky:
        """Patch centers, converted from 3D to sky coordinates."""
        centers = Coord3D.from_array(self._catalog.get_patch_centers())
        return centers.to_sky()

    @property
    def radii(self) -> DistSky:
        """Angular radius of each patch: the largest 3D distance between the
        patch center and any of its objects, converted to a sky distance."""
        radii = []
        cls = self.__class__
        for cat in self:
            # wrap the patch without building sub-patches (bypasses __init__)
            patch = cls.__new__(cls)
            patch._catalog = cat
            # ``pos`` is presumably provided by BaseCatalog from ra/dec
            position = patch.pos.to_3d()
            radius = patch.centers.to_3d().distance(position).max()
            radii.append(radius.to_sky())
        return DistSky.from_dists(radii)

    def iter_bins(
        self, z_bins: NDArray[np.float64], allow_no_redshift: bool = False
    ) -> Iterator[tuple[Interval, TreecorrCatalog | EmptyCatalog]]:
        """Iterate the catalogue in bins of redshift.

        Args:
            z_bins (:obj:`NDArray`):
                Edges of the redshift bins (bins are closed on the left).
            allow_no_redshift (:obj:`bool`):
                If true, the iterator yields the whole catalogue at each
                iteration step without binning, whether or not redshifts are
                available. If false, redshifts are required.

        Yields:
            (tuple): tuple containing:
                - **intv** (:obj:`pandas.Interval`): the selection for this
                  bin.
                - **cat** (:obj:`TreecorrCatalog`): instance containing the
                  data for this bin, or an :obj:`EmptyCatalog` if the bin
                  holds no objects.

        Raises:
            ValueError: If there are no redshifts and ``allow_no_redshift`` is
                false (raised on the first iteration).
        """
        # NOTE(review): earlier docs said the whole catalogue is yielded only
        # "if the data has no redshifts"; the code yields it whenever the flag
        # is set. Behaviour kept as-is -- confirm which one callers expect.
        if not allow_no_redshift and not self.has_redshifts():
            raise ValueError("no redshifts for iteration provided")
        if allow_no_redshift:
            for intv in pd.IntervalIndex.from_breaks(z_bins, closed="left"):
                yield intv, self
        else:
            for interval, bin_mask in _iter_bin_masks(self.redshifts, z_bins):
                yield interval, take_subset(self, bin_mask)

    def correlate(
        self,
        config: Configuration,
        binned: bool,
        other: TreecorrCatalog | None = None,
        linkage: PatchLinkage | None = None,
        progress: bool = False,
    ) -> NormalisedCounts | dict[str, NormalisedCounts]:
        """Count pairs in bins of redshift for each configured angular scale.

        Args:
            config: Measurement configuration (scales, redshift binning,
                cosmology, backend settings).
            binned: Whether ``other`` is split into the same redshift bins as
                this catalogue; irrelevant for the autocorrelation.
            other: Second catalogue for a cross-correlation, ``None`` for the
                autocorrelation of this catalogue.
            linkage: Passed on to :meth:`BaseCatalog.correlate`, otherwise not
                used by this backend.
            progress: Write a per-bin progress message to ``stderr``.

        Returns:
            The output of :func:`pack_results` for the per-scale pair counts
            and the per-patch totals.

        Raises:
            TypeError: If ``other`` is given but is not a
                :obj:`TreecorrCatalog`.
        """
        super().correlate(config, binned, other, linkage)
        auto = other is None
        if not auto and not isinstance(other, TreecorrCatalog):
            raise TypeError(
                f"'other' must be a {type(self).__name__}, "
                f"got {type(other).__name__}"
            )

        nncorr_config = dict(
            sep_units="radian",
            metric="Arc",
            nbins=(1 if config.scales.rweight is None else config.scales.rbin_num),
            bin_slop=config.backend.rbin_slop,
            num_threads=config.backend.get_threads(),
        )

        # bin the catalogues if necessary
        cats1 = self.iter_bins(config.binning.zbins)
        if auto:
            cats2 = itertools.repeat((None, None))
        elif binned:
            cats2 = other.iter_bins(config.binning.zbins)
        else:
            cats2 = itertools.repeat((None, other))

        # allocate output data containers
        binning = pd.IntervalIndex.from_breaks(config.binning.zbins)
        n_bins = len(binning)
        n_patches = self.n_patches
        totals1 = np.zeros((n_patches, n_bins))
        totals2 = np.zeros((n_patches, n_bins))
        count_dict = {
            str(scale): PatchedCount.zeros(binning, n_patches, auto=auto)
            for scale in config.scales
        }

        # iterate the bins and compute the correlation
        self._logger.debug(
            "running treecorr on %i threads", config.backend.get_threads()
        )
        # initialised so that clearing the line works even without any bins
        prog_msg = ""
        for i, ((intv, bincat1), (_, bincat2)) in enumerate(zip(cats1, cats2)):
            if progress:
                prog_msg = f"processing bin {i+1} / {n_bins}\r"
                sys.stderr.write(prog_msg)
                sys.stderr.flush()
            angles = [
                scale.to_radian(intv.mid, config.cosmology) for scale in config.scales
            ]

            # extract the total number of objects per patch
            totals1[:, i] = bincat1.get_totals()
            if bincat2 is None:  # autocorrelation
                totals2[:, i] = totals1[:, i]
            else:
                totals2[:, i] = bincat2.get_totals()

            # trivial case: no data in redshift interval, no counts to update
            if isinstance(bincat1, EmptyCatalog) or isinstance(bincat2, EmptyCatalog):
                continue

            for scale, (ang_min, ang_max) in zip(config.scales, angles):
                # run the correlation measurement
                correlation = NNCorrelation(
                    min_sep=ang_min, max_sep=ang_max, **nncorr_config
                )
                correlation.process(
                    bincat1.to_treecorr(),
                    None if bincat2 is None else bincat2.to_treecorr(),
                )
                # extract the pair counts of each patch pair
                # NOTE(review): with rweight set, nbins > 1 and corr_result.weight
                # presumably holds rbin_num values, but a single slot per
                # redshift bin is written here -- verify
                scale_counts = count_dict[str(scale)]
                result: TypeNNResult = correlation.results
                for (pid1, pid2), corr_result in result.items():
                    scale_counts.counts[pid1, pid2, i] = corr_result.weight

        if progress:
            sys.stderr.write((" " * len(prog_msg)) + "\r")  # clear line

        total = PatchedTotal(  # not scale-dependent
            binning=binning, totals1=totals1, totals2=totals2, auto=auto
        )
        return pack_results(count_dict, total)

    def true_redshifts(
        self,
        config: Configuration,
        sampling_config: ResamplingConfig | None = None,
        progress: bool = False,
    ) -> HistData:
        """Compute the weighted redshift histogram and its resampled versions.

        The histogram is computed per patch in the bins of
        ``config.binning.zbins``; the data vector is the sum over all patches,
        the samples are sums over the patch selections returned by
        ``sampling_config.get_samples``.

        Args:
            config: Configuration providing the redshift binning.
            sampling_config: Resampling settings, defaults to
                :obj:`ResamplingConfig` with default values.
            progress: Not used by this backend.

        Returns:
            :obj:`HistData` with the binning, data, samples, and resampling
            method.

        Raises:
            ValueError: If the catalogue has no redshifts.
        """
        if sampling_config is None:
            sampling_config = ResamplingConfig()  # default values
        super().true_redshifts(config)
        if not self.has_redshifts():
            raise ValueError("catalog has no redshifts")

        # compute the redshift histogram in each patch, shape (patches, bins)
        hist_counts = np.array(
            [
                np.histogram(patch.r, config.binning.zbins, weights=patch.w)[0]
                for patch in self
            ]
        )

        # construct the output data samples; patch_idx presumably has shape
        # (n_samples, patches per sample) -- summed over the patch axis
        binning = pd.IntervalIndex.from_breaks(config.binning.zbins)
        patch_idx = sampling_config.get_samples(self.n_patches)
        nz_data = hist_counts.sum(axis=0)
        nz_samp = np.sum(hist_counts[patch_idx], axis=1)
        return HistData(
            binning=binning,
            data=nz_data,
            samples=nz_samp,
            method=sampling_config.method,
        )