# Source code for yaw.catalogs.scipy.kdtree

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from scipy.spatial import cKDTree

from yaw.core.coordinates import Coordinate, DistSky

if TYPE_CHECKING:  # pragma: no cover
    from numpy.typing import NDArray

# Names exported by `from <module> import *`; InvalidScalesError is not listed.
__all__ = ["SphericalKDTree"]


class InvalidScalesError(Exception):
    """Raised by :meth:`SphericalKDTree.count` when the requested angular
    scales are malformed (wrong shape, non-positive, or above 180 deg)."""


class SphericalKDTree:
    """Wrapper around :obj:`scipy.spatial.cKDTree` that represents angular
    coordinates as points on the unit sphere.

    The only implemented operation is counting pairs in a fixed angular
    annulus. Angular distances are converted to the corresponding Euclidean
    distance on the unit sphere. Individual weights for points are supported.
    """

    # cached sum of the point weights, computed lazily by the `total` property
    _total = None

    def __init__(
        self,
        position: Coordinate,
        weights: NDArray[np.float64] | None = None,
        leafsize: int = 16,
    ) -> None:
        """Build a new tree from a set of coordinates.

        Args:
            position (:obj:`yaw.coordinates.Coordinate`):
                A vector of coordinates in either angular or 3D coordinates,
                is converted to 3D coordinates if needed.
            weights (:obj:`NDArray`, optional):
                Individual weights for the points.
            leafsize (:obj:`int`, optional):
                Size at which branches of the KDTree are considered leaf
                nodes with no further children.

        Raises:
            ValueError:
                If the number of weights does not match the number of points.
        """
        position = np.atleast_2d(position.to_3d().values)
        self.tree = cKDTree(position, leafsize)
        if weights is None:
            self.weights = np.ones(len(position))
        else:
            weights = np.asarray(weights)
            # explicit check instead of `assert`, which is stripped under -O
            if len(weights) != len(position):
                raise ValueError(
                    f"number of weights ({len(weights)}) does not match "
                    f"number of points ({len(position)})"
                )
            self.weights = weights

    def __len__(self) -> int:
        return len(self.weights)

    @property
    def total(self) -> float:
        """Sum of weights or total number of objects if not provided."""
        if self._total is None:
            self._total = self.weights.sum()
        return self._total

    def count(
        self,
        other: SphericalKDTree,
        scales: NDArray[np.float64],
        dist_weight_scale: float | None = None,
        weight_res: int = 50,
    ) -> NDArray:
        """Count pairs on a set of angular scales.

        Pairs are counted within a range of minimum and maximum angle in
        radian. The limits of all input scales are combined into a single set
        of radial bin edges. After counting, the binned counts are summed to
        obtain the counts for each of the (potentially overlapping) input
        scales.

        The method also supports weighting the pairs radially by a simple
        power-law :math:`r^\\alpha`, where :math:`r` is the pair separation.
        To speed up computation, the weight is not computed for each pair
        individually. Instead, it is computed once per angular bin, evaluated
        at the logarithmic center of the bin. If radial weights are requested,
        the angular binning is refined with ``weight_res`` logarithmically
        spaced edges, in addition to the edges obtained by combining the scale
        limits (see above).

        Args:
            other (:obj:`SphericalKDTree`):
                Second tree used to count pairs.
            scales (:obj:`NDArray`):
                Array with angular scales in radian with shape (N, 2). The
                scales are provided as at least one tuple of minimum and
                maximum angular scale.
            dist_weight_scale (:obj:`float`, optional):
                The power-law index for the radial weighting.
            weight_res (:obj:`int`, optional):
                The number of logarithmically spaced bin edges used to
                compute the radial weights. Ignored if no power-law index is
                set.

        Returns:
            :obj:`NDArray`:
                The pair counts for each input scale, with optional individual
                point weights and radial weights applied. A scale whose
                minimum is not smaller than its maximum yields zero.

        Raises:
            InvalidScalesError:
                If ``scales`` does not have shape (N, 2), contains
                non-positive values, or contains values exceeding
                :math:`\\pi`.

        .. Warning::
            For autocorrelation measurements, ``other`` must be the same tree
            as the calling instance itself. This results in pairs being
            counted twice, as they normally would be in the cross-tree
            counting case.
        """
        # validate and unpack the query scales, shape (N, 2)
        scales = np.atleast_2d(scales)
        if scales.ndim != 2 or scales.shape[1] != 2:
            raise InvalidScalesError(
                "'scales' must be composed of tuples of length 2"
            )
        if np.any(scales <= 0.0):
            raise InvalidScalesError("scales must be positive (r > 0)")
        if np.any(scales > np.pi):
            raise InvalidScalesError("scales exceed 180 deg")
        if len(scales) == 0:
            return np.empty(0)
        log_scales = np.log10(scales)

        # construct the bin edges in log-space: the exact scale limits are
        # always edges; extra edges are only needed to resolve the radial
        # weights, so `weight_res` is used only if a power-law index is set
        rlog_edges = log_scales.ravel()
        if dist_weight_scale is not None:
            refined = np.linspace(
                log_scales.min(), log_scales.max(), weight_res
            )
            rlog_edges = np.concatenate([rlog_edges, refined])
        rlog_edges = np.unique(rlog_edges)  # sorted, duplicates removed
        r_edges = 10**rlog_edges

        # count pairs per bin
        try:
            counts = self.tree.count_neighbors(
                other.tree,
                DistSky(r_edges).to_3d().values,
                weights=(self.weights, other.weights),
                cumulative=False,
            )
        except IndexError:
            # NOTE(review): presumably raised by scipy for degenerate (e.g.
            # empty) trees; treated as "no pairs" -- confirm
            counts = np.zeros_like(r_edges)
        # with cumulative=False, counts[0] holds pairs with separation
        # <= r_edges[0] (below every requested scale) and is discarded; then
        # counts[j] covers the interval (r_edges[j], r_edges[j + 1]]
        counts = counts[1:]

        # apply the distance weights, evaluated at the log-center of each bin
        if dist_weight_scale is not None:
            rlog_centers = (rlog_edges[:-1] + rlog_edges[1:]) / 2.0
            counts = counts * (10**rlog_centers) ** dist_weight_scale

        # every scale limit is an exact member of `rlog_edges`, so
        # searchsorted returns its exact edge index (no float tolerance)
        edge_idx = np.searchsorted(rlog_edges, log_scales)
        return np.array(
            [counts[i_lo:i_hi].sum() for i_lo, i_hi in edge_idx],
            dtype=np.float64,
        )