Source code for yaw.catalog.catalog

"""
Implements data catalogs, which are the centeral container for row-data and
facilitate correlation measurements.

Catalogs are dictionary-like collections of patches, which each hold a portion
of the catalog data. Data is not permanently held in memory, instead each
catalog is tied to a cache directory on disk. To retrive data, access a patch
and manually load the data from its cache. This design allows flexibility while
minimising the memory footprint of large datasets.

Catalogs can be constructed directly from input files or random generators.
"""

from __future__ import annotations

import logging
from collections import deque
from collections.abc import Mapping
from contextlib import AbstractContextManager
from enum import Enum
from pathlib import Path
from shutil import rmtree
from typing import TYPE_CHECKING

import numpy as np
import treecorr
from scipy.cluster import vq

from yaw.binning import Binning
from yaw.catalog.patch import Patch, PatchWriter
from yaw.catalog.readers import (
    DataChunkReader,
    DataFrameReader,
    RandomReader,
    new_filereader,
)
from yaw.catalog.trees import BinnedTrees, groupby
from yaw.coordinates import AngularCoordinates, AngularDistances
from yaw.datachunk import (
    PATCH_ID_DTYPE,
    DataChunk,
    DataChunkInfo,
    HandlesDataChunk,
    check_patch_ids,
)
from yaw.options import Closed
from yaw.randoms import RandomsBase
from yaw.utils import format_long_num, parallel
from yaw.utils.logging import Indicator
from yaw.utils.parallel import EndOfQueue

if TYPE_CHECKING:
    from collections.abc import Iterator

    from numpy.typing import NDArray
    from typing_extensions import Self

    from yaw.catalog.readers import DataFrame
    from yaw.datachunk import TypeDataChunk, TypePatchIDs

__all__ = [
    "Catalog",
    "create_patch_centers",
    "assign_patch_centers",
    "load_patches",
    "write_patches",
]


PATCH_NAME_TEMPLATE = "patch_{:d}"
"""Template to name patch directories in catalog cache directory."""

PATCH_INFO_FILE = "patch_ids.bin"
"""Name of file listing patch IDs in catalog cache directory."""

logger = logging.getLogger(__name__)


class InconsistentPatchesError(Exception):
    pass


class PatchMode(Enum):
    """Enumeration to specify the patch creation method."""

    apply = 0
    divide = 1
    create = 2

    @classmethod
    def determine(
        cls,
        patch_centers: AngularCoordinates | Catalog | None,
        patch_name: str | None,
        patch_num: int | None,
    ) -> PatchMode:
        """
        Determine the patch creation method to use.

        Method is determined from the three possible input parameters in the
        :obj:`~yaw.Catalog` creation routines, by checking which of the
        parameters are set in the following order of precedence:
        ``patch_centers`` > ``patch_name`` > ``patch_num``

        Args:
            patch_centers:
                A list of patch centers to use when creating the patches. Can be
                either :obj:`~yaw.AngularCoordinates` or an other
                :obj:`~yaw.Catalog` as reference.
            patch_name:
                Optional column name in the data frame for a column with integer
                patch indices. Indices must be contiguous and starting from 0.
                Ignored if ``patch_centers`` is given.
            patch_num:
                Automatically compute patch centers from a sparse sample of the
                input data using `treecorr`. Requires an additional scan of the
                input file to read a sparse sampling of the object coordinates.
                Ignored if ``patch_centers`` or ``patch_name`` is given.

        Returns:
            The Enum value indicating which patch creation method to use.

        Raises:
            TypeError:
                If the input values are of an invalid type.
            ValueError:
                If the number of patches exceeds the maximum allowed number or
                none of the input parameters are provided.
        """
        log_sink = logger.debug if parallel.on_root() else lambda *x: x

        if patch_centers is not None:
            if not isinstance(patch_centers, (AngularCoordinates, Catalog)):
                raise TypeError(
                    "'patch_centers' must be a set of coordinates or another catalog"
                )
            check_patch_ids(len(patch_centers))

            log_sink("applying %d patches", len(patch_centers))
            return PatchMode.apply

        if patch_name is not None:
            if not isinstance(patch_name, str):
                raise TypeError("'patch_name' must be a string")

            log_sink("dividing patches based on '%s'", patch_name)
            return PatchMode.divide

        elif patch_num is not None:
            if not isinstance(patch_num, int):
                raise TypeError("'patch_num' must be an integer")
            check_patch_ids(patch_num)

            log_sink("creating %d patches", patch_num)
            return PatchMode.create

        raise ValueError("no patch method specified")


def get_patch_centers(instance: AngularCoordinates | Catalog) -> AngularCoordinates:
    """Extract the patch centers from a set of angular coordinates or a catalog
    instance, raises ``TypeError`` otherwise."""
    try:
        return instance.get_centers()
    except AttributeError as err:
        if isinstance(instance, AngularCoordinates):
            return instance
        raise TypeError(
            "'patch_centers' must be of type 'Catalog' or 'AngularCoordinates'"
        ) from err


def create_patch_centers(
    reader: DataChunkReader, patch_num: int, probe_size: int
) -> AngularCoordinates:
    """
    Automatically create new patch centers from a data source.

    Data source can be a file reader or random generator. Patch centers are
    computed from a small data subset using ``treecorr`` for optimal efficiency.

    Args:
        reader:
            A :obj:`DataChunkReader` instance that exposes a random generator or
            file reader.
        patch_num:
            The number of patches to create.
        probe_size:
            The size of the subsample from which the patches are computed.

    Returns:
        A new set of angular coordinates of the patch centers.
    """
    if probe_size < 10 * patch_num:
        probe_size = int(100_000 * np.sqrt(patch_num))
    if parallel.on_root():
        logger.info(
            "computing patch centers from subset of %s records",
            format_long_num(probe_size),
        )

    data_probe = reader.get_probe(probe_size)

    patch_centers = None
    if parallel.on_root():
        cat = treecorr.Catalog(
            ra=DataChunk.getattr(data_probe, "ra"),
            ra_units="radians",
            dec=DataChunk.getattr(data_probe, "dec"),
            dec_units="radians",
            w=DataChunk.getattr(data_probe, "weights", None),
            npatch=patch_num,
            config=dict(num_threads=parallel.get_size()),
        )
        patch_centers = AngularCoordinates.from_3d(cat.patch_centers)
    return parallel.COMM.bcast(patch_centers, root=0)


def assign_patch_centers(patch_centers: NDArray, data: TypeDataChunk) -> TypePatchIDs:
    """
    Computes the patch ID for a set of objects and patch center coordinates.

    Objects are assigned to the nearest patch center, expressed by the index of
    the patch center in the input list of patch centers.

    Args:
        patch_centers:
            Numpy array of patch center coordinates in radian and shape
            `(N, 2)`.
        data:
            Numpy array holding input object coordinates, i.e. a chunk of
            catalog data, must contain with fields ``ra`` and ``dec``.

    Returns:
        Array of 16-bit integer patch IDs for each input obejct.
    """
    coords = DataChunk.get_coords(data)
    ids, _ = vq.vq(coords.to_3d(), patch_centers)
    return ids.astype(PATCH_ID_DTYPE)


def split_into_patches(
    chunk: TypeDataChunk, patch_centers: NDArray | None
) -> dict[int, TypeDataChunk]:
    """
    Split a numpy array of catalog data into patches.

    If patch centers are provided, assigns patch IDs from nearest patch center.
    If a patch ID column is contained in the input data, uses that to assign
    objects to patches.

    Args:
        chunk:
            Numpy array holding input object coordinates, i.e. a chunk of
            catalog data, must contain with fields ``ra`` and ``dec``, and
            optionally ``patch_ids``.
        patch_centers:
            Optional, numpy array of patch center coordinates in radian and
            shape `(N, 2)`.

    Returns:
        Dictionary with patch IDs as keys and subset of input data chunk with
        objects belonging to the corresponding patch ID.

    Raises:
        RuntimeError:
            If neither patch centers nor patch IDs per object are provided.
    """
    has_patch_ids = DataChunk.hasattr(chunk, "patch_ids")

    # statement order matters
    if patch_centers is not None:
        patch_ids = assign_patch_centers(patch_centers, chunk)
        if has_patch_ids:
            chunk, _ = DataChunk.pop(chunk, "patch_ids")
    elif has_patch_ids:
        # patch IDs will be redundant information so we delete them
        chunk, patch_ids = DataChunk.pop(chunk, "patch_ids")
    else:  # pragma: no cover
        raise RuntimeError("found no way to obtain patch centers")

    return {
        int(patch_id): patch_data for patch_id, patch_data in groupby(patch_ids, chunk)
    }


def get_patch_path_from_id(cache_directory: Path | str, patch_id: int) -> Path:
    """
    Get the patch to a specific patch cache directory.

    Args:
        cache_directory:
            The cache directory used by the parent catalog.
        patch_id:
            ID of the patch for which to optain the patch cache directory.

    Returns:
        Path as a :obj:`pathlib.Path`.
    """
    return Path(cache_directory) / PATCH_NAME_TEMPLATE.format(patch_id)


def get_id_from_patch_path(cache_path: Path | str) -> int:
    """
    Extract the integer patch ID from a patch cache path.

    .. caution::
        This will fail if the patch has not been created through a
        :obj:`CatalogWriter` instance, which manages the patch creation.
    """
    _, id_str = Path(cache_path).name.split("_")
    return int(id_str)


def read_patch_ids(cache_directory: Path) -> list[int]:
    """Reads a list of patch IDs in a catalog from a metadata file stored in the
    catalog's cache directory."""
    path = cache_directory / PATCH_INFO_FILE
    if not path.exists():
        raise InconsistentPatchesError("patch info file not found")
    return np.fromfile(path, dtype=PATCH_ID_DTYPE).tolist()


def load_patches(
    cache_directory: Path,
    *,
    patch_centers: AngularCoordinates | Catalog | None,
    progress: bool,
    max_workers: int | None = None,
) -> dict[int, Patch]:
    """
    Instantiate all patches stored in a catalog's cache directory.

    Function is MPI aware, patches are loaded on the root worker and broadcasted
    to all workers. Computes patch metadata if not present. If patch centers
    are provided, only the patch radius is computed but not the patch centers.

    Args:
        cache_directory:
            Cache directory of the parent catalog.

    Keyword Args:
        patch_centers:
            Optional set of angular coordinates or catalog instance that
            defines the exact patch centers to use.
        progress:
            Show a progress on the terminal (disabled by default).
        max_workers:
            Limit the  number of parallel workers for this operation (all by
            default).
    """
    patch_ids = None
    if parallel.on_root():
        patch_ids = read_patch_ids(cache_directory)
    patch_ids = parallel.COMM.bcast(patch_ids, root=0)

    # instantiate patches, which triggers computing the patch meta-data
    path_template = str(cache_directory / PATCH_NAME_TEMPLATE)
    patch_paths = map(path_template.format, patch_ids)

    if patch_centers is not None:
        if isinstance(patch_centers, Catalog):
            patch_centers = patch_centers.get_centers()
        patch_arg_iter = zip(patch_paths, patch_centers)

    else:
        patch_arg_iter = zip(patch_paths)

    patch_iter = parallel.iter_unordered(
        Patch, patch_arg_iter, unpack=True, max_workers=max_workers
    )
    if progress:
        patch_iter = Indicator(patch_iter, len(patch_ids))

    patches = {get_id_from_patch_path(patch.cache_path): patch for patch in patch_iter}
    return parallel.COMM.bcast(patches, root=0)


class CatalogWriter(AbstractContextManager, HandlesDataChunk):
    """
    A helper class that handles a stream of input catalog data and splits and
    writes it to patches.

    Args:
        cache_directory:
            Cache directory of the catalog.

    Keyword Args:
        overwrite:
            Whether to overwrite an existing catalog at the given cache
            location.
        chunk_info:
            An instance of :obj:`yaw.datachunk.DataChunkInfo` indicating which
            optional data attributes are processed by the pipeline.
        buffersize:
            Optional, maximum number of records to store in the internal cache
            of each patch writer.

    Attributes:
        cache_directory:
            Cache directory to use when creating the patches.
        writers:
            Dictionary of patch IDs / :obj:`~yaw.catalog.patch.PatchWriters`
            that delegates writing data for an individual patch.
        buffersize:
            Optional, maximum number of records to store in the internal cache
            of each patch writer.

    Raises:
        FileExistsError:
            If the cache directory already exists and ``overwrite==False``.
    """

    __slots__ = (
        "_chunk_info",
        "cache_directory",
        "buffersize",
        "writers",
    )

    def __init__(
        self,
        cache_directory: Path | str,
        *,
        chunk_info: DataChunkInfo,
        overwrite: bool = True,
        buffersize: int = -1,
    ) -> None:
        self._chunk_info = chunk_info
        self.cache_directory = Path(cache_directory)
        cache_exists = self.cache_directory.exists()

        if parallel.on_root():
            logger.info(
                "%s cache directory: %s",
                "overwriting" if cache_exists and overwrite else "using",
                cache_directory,
            )

        if self.cache_directory.exists():
            if overwrite:
                rmtree(self.cache_directory)
            else:
                raise FileExistsError(f"cache directory exists: {cache_directory}")

        self.buffersize = buffersize
        self.cache_directory.mkdir()
        self.writers: dict[int, PatchWriter] = {}

    def __repr__(self) -> str:
        items = (
            f"num_patches={self.num_patches}",
            f"max_buffersize={self.buffersize * self.num_patches}",
        )
        attrs = self._chunk_info.format()
        return f"{type(self).__name__}({', '.join(items)}, {attrs}) @ {self.cache_directory}"

    def __enter__(self) -> Self:
        return self

    def __exit__(self, *args, **kwargs) -> None:
        self.finalize()

    @property
    def num_patches(self) -> int:
        """The number of unique patch IDs encountered so far."""
        return len(self.writers)

    def get_writer(self, patch_id: int) -> PatchWriter:
        """Get the patch writer for the given patch ID and create it if it does
        not yet exist."""
        try:
            return self.writers[patch_id]

        except KeyError:
            writer = PatchWriter(
                get_patch_path_from_id(self.cache_directory, patch_id),
                chunk_info=self.copy_chunk_info(),
                buffersize=self.buffersize,
            )
            self.writers[patch_id] = writer
            return writer

    def process_patches(self, patches: dict[int, TypeDataChunk]) -> None:
        """
        Process a dictionary of catalog data split into patches.

        Dictionary values are sent to the individual patch data writers which
        cache the data in memory temporarily or write them to disk.

        Args:
            patches:
                A dictionary of patch ID / numpy array with catalog data
                (containing ``ra``, ``dec``, or optionally ``weights`` and
                ``redshifts`` fields).
        """
        for patch_id, patch in patches.items():
            self.get_writer(patch_id).process_chunk(patch)

    def finalize(self) -> None:
        """
        Finalise the catalog cache directory.

        Flushes all patch writer caches and writes a list of patch IDs to
        the cache directory that simplifes loading the catalog instance later.

        Raises:
            ValueError:
                If any of the patches does not contain any data.
        """
        empty_patches = set()
        for patch_id, writer in self.writers.items():
            writer.close()
            if writer.num_processed == 0:
                empty_patches.add(patch_id)

        for patch_id in empty_patches:
            raise ValueError(f"patch with ID {patch_id} contains no data")

        patch_ids = np.fromiter(self.writers.keys(), dtype=np.int16)
        np.sort(patch_ids).tofile(self.cache_directory / PATCH_INFO_FILE)


def write_patches_unthreaded(
    path: Path | str,
    reader: DataChunkReader,
    patch_centers: AngularCoordinates | Catalog | None,
    *,
    overwrite: bool,
    progress: bool,
    buffersize: int = -1,
) -> None:
    """
    Read catalog from an input source and write the data to catalog cache
    directory.

    Creates patch centers automatically from data source if none are provided.
    This is a fallback implementation if parallel workers are disabled.

    Args:
        path:
            The target cache directory.
        reader:
            A :obj:`DataChunkReader` instance that exposes a random generator or
            file reader.
        patch_centers:
            Optional set of angular coordinates or catalog instance that
            defines the exact patch centers to use.

    Keyword Args:
        overwrite:
            Whether to overwrite an existing catalog at the given cache
            location. If the directory is not a valid catalog, a
            ``FileExistsError`` is raised.
        progress:
            Show a progress on the terminal (disabled by default).
        buffersize:
            Optional, maximum number of records to store in the internal cache
            of each patch writer.

    """
    with reader:
        if patch_centers is not None:
            patch_centers = get_patch_centers(patch_centers).to_3d()

        with CatalogWriter(
            cache_directory=path,
            chunk_info=reader.copy_chunk_info(drop_patch_ids=True),
            overwrite=overwrite,
            buffersize=buffersize,
        ) as writer:
            chunk_iter = Indicator(reader) if progress else iter(reader)
            for chunk in chunk_iter:
                patches = split_into_patches(chunk, patch_centers)
                writer.process_patches(patches)


if parallel.use_mpi():
    """Implementation of parallel input data processing based on OpenMPI."""

    from mpi4py import MPI

    if TYPE_CHECKING:
        from mpi4py.MPI import Comm

    class WorkerManager:
        """Contains information required by the MPI workers to coordinate
        parallel processing: rank that is responsible for reading, rank that is
        responsible for writing and which ranks are responsible for processing
        chunk data in parallel."""

        def __init__(self, max_workers: int | None, reader_rank: int = 0) -> None:
            self.reader_rank = reader_rank

            max_workers = parallel.get_size(max_workers)
            self.active_ranks = parallel.ranks_on_same_node(reader_rank, max_workers)

            self.active_ranks.discard(reader_rank)
            self.writer_rank = self.active_ranks.pop()
            self.active_ranks.add(reader_rank)

        def get_comm(self) -> Comm:
            rank = parallel.COMM.Get_rank()
            if rank in self.active_ranks:
                return parallel.COMM.Split(1, rank)
            else:
                return parallel.COMM.Split(MPI.UNDEFINED, rank)

    def scatter_data_chunk(comm: Comm, reader_rank: int, chunk: DataChunk) -> DataChunk:
        """Takes a chunk of catalog data, splits it into chunks and broadcasts
        the chunks to the parallel chunk processing tasks."""
        num_ranks = comm.Get_size()

        if comm.Get_rank() == reader_rank:
            splits = np.array_split(chunk, num_ranks)

            for rank, split in enumerate(splits):
                if rank != reader_rank:
                    comm.send(split, dest=rank, tag=2)

            return splits[reader_rank]

        else:
            return comm.recv(source=0, tag=2)

    def chunk_processing_task(
        comm: Comm,
        worker_config: WorkerManager,
        patch_centers: AngularCoordinates | Catalog | None,
        chunk_iter: Iterator[DataChunk],
    ) -> None:
        """A dedicated parallel worker task which splits catalog data into
        paches and sends the data to the writer process."""
        if patch_centers is not None:
            patch_centers = patch_centers.to_3d()

        reader_rank = parallel.world_to_comm_rank(comm, worker_config.reader_rank)

        for chunk in chunk_iter:
            worker_chunk = scatter_data_chunk(comm, reader_rank, chunk)
            patches = split_into_patches(worker_chunk, patch_centers)
            parallel.COMM.send(patches, dest=worker_config.writer_rank, tag=1)

        comm.Barrier()

    def writer_task(
        cache_directory: Path | str,
        *,
        chunk_info: DataChunkInfo,
        overwrite: bool = True,
        buffersize: int = -1,
    ) -> None:
        """A dedicated writer process that recieves a dictionary with patch IDs
        and patch data to write using a :obj:`CatalogWriter`, terminated when
        receiving :obj:`EndOfQueue` sentinel."""
        recv = parallel.COMM.recv
        with CatalogWriter(
            cache_directory,
            chunk_info=chunk_info,
            overwrite=overwrite,
            buffersize=buffersize,
        ) as writer:
            while (patches := recv(source=MPI.ANY_SOURCE, tag=1)) is not EndOfQueue:
                writer.process_patches(patches)

    def write_patches(
        path: Path | str,
        reader: DataChunkReader,
        patch_centers: AngularCoordinates | Catalog | None,
        *,
        overwrite: bool,
        progress: bool,
        max_workers: int | None = None,
        buffersize: int = -1,
    ) -> None:
        """
        Read catalog from an input source and write the data to catalog cache
        directory.

        Creates patch centers automatically from data source if none are
        provided. This is an implementation with MPI parallelsim. The root rank
        is responsible for reading data from the source, one rank is responsible
        for writing to the cache directory, any remaining ranks process the
        input data.

        .. Note::
            The code tries to schedule all work only on the same node that
            hosts the root tasks to avoid inter-node communication.

        Args:
            path:
                The target cache directory.
            reader:
                A :obj:`DataChunkReader` instance that exposes a random
                generator or file reader.
            patch_centers:
                Optional set of angular coordinates or catalog instance that
                defines the exact patch centers to use.

        Keyword Args:
            overwrite:
                Whether to overwrite an existing catalog at the given cache
                location. If the directory is not a valid catalog, a
                ``FileExistsError`` is raised.
            progress:
                Show a progress on the terminal (disabled by default).
            max_workers:
                Limit the  number of parallel workers for this operation (all by
                default).
            buffersize:
                Optional, maximum number of records to store in the internal
                cache of each patch writer.
        """
        max_workers = parallel.get_size(max_workers)
        if max_workers < 2:
            raise ValueError("catalog creation requires at least two workers")
        if parallel.on_root():
            logger.debug("running preprocessing on %d workers", max_workers)

        rank = parallel.COMM.Get_rank()
        worker_config = WorkerManager(max_workers, 0)
        worker_comm = worker_config.get_comm()

        if rank == worker_config.writer_rank:
            writer_task(
                cache_directory=path,
                chunk_info=reader.copy_chunk_info(drop_patch_ids=True),
                overwrite=overwrite,
                buffersize=buffersize,
            )

        elif rank in worker_config.active_ranks:
            if patch_centers is not None:
                patch_centers = get_patch_centers(patch_centers)

            with reader:
                chunk_iter = Indicator(reader) if progress else iter(reader)
                chunk_processing_task(
                    worker_comm,
                    worker_config,
                    patch_centers,
                    chunk_iter,
                )

            worker_comm.Free()

        if parallel.COMM.Get_rank() == worker_config.reader_rank:
            parallel.COMM.send(EndOfQueue, dest=worker_config.writer_rank, tag=1)
        parallel.COMM.Barrier()

else:
    """Implementation of parallel input data processing based on python's
    multiprocessing."""

    import multiprocessing
    from dataclasses import dataclass, field

    if TYPE_CHECKING:
        from multiprocessing import Queue

    class ChunkProcessingTask:
        """Defines the worker task which splits catalog data into paches and
        puts the data into the writer process queue."""

        def __init__(
            self,
            patch_queue: Queue[dict[int, TypeDataChunk] | EndOfQueue],
            patch_centers: AngularCoordinates | None,
        ) -> None:
            self.patch_queue = patch_queue

            if isinstance(patch_centers, AngularCoordinates):
                self.patch_centers = patch_centers.to_3d()
            else:
                self.patch_centers = None

        def __call__(self, chunk: DataChunk) -> dict[int, TypeDataChunk]:
            patches = split_into_patches(chunk, self.patch_centers)
            self.patch_queue.put(patches)

    @dataclass
    class WriterProcess(AbstractContextManager):
        """A dedicated writer process that recieves a dictionary with patch IDs
        and patch data to write using a :obj:`CatalogWriter`, terminated when
        receiving :obj:`EndOfQueue` sentinel."""

        patch_queue: Queue[dict[int, TypeDataChunk] | EndOfQueue]
        cache_directory: Path | str
        chunk_info: DataChunkInfo = field(kw_only=True)
        overwrite: bool = field(default=True, kw_only=True)
        buffersize: int = field(default=-1, kw_only=True)

        def __post_init__(self) -> None:
            self.process = multiprocessing.Process(target=self.task)

        def __enter__(self) -> Self:
            self.start()
            return self

        def __exit__(self, *args, **kwargs) -> None:
            self.join()

        def task(self) -> None:
            with CatalogWriter(
                self.cache_directory,
                overwrite=self.overwrite,
                chunk_info=self.chunk_info,
                buffersize=self.buffersize,
            ) as writer:
                while (patches := self.patch_queue.get()) is not EndOfQueue:
                    writer.process_patches(patches)

        def start(self) -> None:
            self.process.start()

        def join(self) -> None:
            self.process.join()

    def write_patches(
        path: Path | str,
        reader: DataChunkReader,
        patch_centers: AngularCoordinates | Catalog | None,
        *,
        overwrite: bool,
        progress: bool,
        max_workers: int | None = None,
        buffersize: int = -1,
    ) -> None:
        """
        Read catalog from an input source and write the data to catalog cache
        directory.

        Creates patch centers automatically from data source if none are
        provided. This is an implementation with MPI parallelsim. There is a
        dedicated process that handles writing data to the catalog cache
        directory.

        Args:
            path:
                The target cache directory.
            reader:
                A :obj:`DataChunkReader` instance that exposes a random
                generator or file reader.
            patch_centers:
                Optional set of angular coordinates or catalog instance that
                defines the exact patch centers to use.

        Keyword Args:
            overwrite:
                Whether to overwrite an existing catalog at the given cache
                location. If the directory is not a valid catalog, a
                ``FileExistsError`` is raised.
            progress:
                Show a progress on the terminal (disabled by default).
            max_workers:
                Limit the  number of parallel workers for this operation (all by
                default).
            buffersize:
                Optional, maximum number of records to store in the internal
                cache of each patch writer.
        """
        max_workers = parallel.get_size(max_workers)

        if max_workers == 1:
            logger.debug("running preprocessing sequentially")
            return write_patches_unthreaded(
                path,
                reader,
                patch_centers,
                overwrite=overwrite,
                progress=progress,
                buffersize=buffersize,
            )

        else:
            logger.debug("running preprocessing on %d workers", max_workers)

        with (
            reader,
            multiprocessing.Manager() as manager,
            multiprocessing.Pool(max_workers) as pool,
        ):
            patch_queue = manager.Queue()

            if patch_centers is not None:
                patch_centers = get_patch_centers(patch_centers)
            chunk_processing_task = ChunkProcessingTask(patch_queue, patch_centers)

            with WriterProcess(
                patch_queue,
                cache_directory=path,
                chunk_info=reader.copy_chunk_info(drop_patch_ids=True),
                overwrite=overwrite,
                buffersize=buffersize,
            ):
                chunk_iter = Indicator(reader) if progress else iter(reader)
                for chunk in chunk_iter:
                    pool.map(chunk_processing_task, np.array_split(chunk, max_workers))

                patch_queue.put(EndOfQueue)


[docs] class Catalog(Mapping[int, Patch]): """ A container for catalog data. Catalogs are the core data structure for managing point data catalogs. Besides right ascension and declination coordinates, catalogs may have additional per-object weights and redshifts. Catalogs divided into spatial :obj:`~yaw.catalog.Patch` es, which each cache a portion of the data on disk to minimise the memory footprint when dealing with large data-sets, allowing to process the data in a patch-wise manner, only loading data from disk when they are needed. Additionally, the patches are used to estimate uncertainties using jackknife resampling. .. note:: The number of patches should be sufficently large to support the redshift binning used for correlation measurements. The number of patches is also a trade-off between runtime and memory footprint during correlation measurements. The cached data is organised in a single directory, with one sub-directory for each spatial :obj:`~yaw.Patch`:: [cache_directory]/ ├╴ patch_ids.bin # list of patch IDs for this catalog ├╴ patch_0/ │ └╴ ... # patch data ├╴ patch_1/ │ ... └╴ patch_N/ .. caution:: Empty patches are currently not supported and the catalog creation will fail if a patch without any data is encountered (e.g. if the input catalog is too sparse or inhomogeneous). Args: cache_directory: The cache directory to use for this catalog, must exist and contain a valid catalog cache. Keyword Args: max_workers: Limit the number of parallel workers for this operation (all by default). """ __slots__ = ("cache_directory", "_patches") _patches: dict[int, Patch] def __init__( self, cache_directory: Path | str, *, max_workers: int | None = None ) -> None: if parallel.on_root(): logger.info("restoring from cache directory: %s", cache_directory) self.cache_directory = Path(cache_directory) if not self.cache_directory.exists(): raise OSError(f"cache directory not found: {self.cache_directory}") self._patches = load_patches( self.cache_directory, patch_centers=None, progress=False, max_workers=max_workers, )
[docs] @classmethod def from_dataframe( cls, cache_directory: Path | str, dataframe: DataFrame, *, ra_name: str, dec_name: str, weight_name: str | None = None, redshift_name: str | None = None, patch_centers: AngularCoordinates | Catalog | None = None, patch_name: str | None = None, patch_num: int | None = None, degrees: bool = True, overwrite: bool = False, progress: bool = False, max_workers: int | None = None, chunksize: int | None = None, probe_size: int = -1, **reader_kwargs, ) -> Catalog: """ Create a new catalog instance from a :obj:`pandas.DataFrame`. Assign objects from the input data frame to spatial patches, write the patches to a cache on disk, and compute the patch meta data. .. note:: One of the optional patch creation arguments (``patch_centers``, ``patch_name``, or ``patch_num``) must be provided. Args: cache_directory: The cache directory to use for this catalog. Created automatically or overwritten if requested. dataframe: The input data frame. May also be an object that supports mapping from string (column name) to data (numpy array-like). Keyword Args: ra_name: Column name in the data frame for right ascension. dec_name: Column name in the data frame for declination. weight_name: Optional column name in the data frame for weights. redshift_name: Optional column name in the data frame for redshifts. patch_centers: A list of patch centers to use when creating the patches. Can be either :obj:`~yaw.AngularCoordinates` or an other :obj:`~yaw.Catalog` as reference. patch_name: Optional column name in the data frame for a column with integer patch indices. Indices must be contiguous and starting from 0. Ignored if ``patch_centers`` is given. patch_num: Automatically compute patch centers from a sparse sample of the input data using `treecorr`. Requires an additional scan of the input file to read a sparse sampling of the object coordinates. Ignored if ``patch_centers`` or ``patch_name`` is given. degrees: Whether the input coordinates are given in degreees (default). overwrite: Whether to overwrite an existing catalog at the given cache location. If the directory is not a valid catalog, a ``FileExistsError`` is raised. progress: Show a progress on the terminal (disabled by default). max_workers: Limit the number of parallel workers for this operation (all by default). chunksize: The maximum number of records to load into memory at once when processing the input file in chunks. probe_size: The approximate number of records to read when generating patch centers (``patch_num``). Returns: A new catalog instance. Raises: FileExistsError: If the cache directory exists or is not a valid catalog when providing ``overwrite=True``. """ reader = DataFrameReader( dataframe, ra_name=ra_name, dec_name=dec_name, weight_name=weight_name, redshift_name=redshift_name, patch_name=patch_name, chunksize=chunksize, degrees=degrees, **reader_kwargs, ) mode = PatchMode.determine(patch_centers, patch_name, patch_num) if mode == PatchMode.create: patch_centers = create_patch_centers(reader, patch_num, probe_size) # split the data into patches and create the cached Patch instances. write_patches( cache_directory, reader, patch_centers, overwrite=overwrite, progress=progress, max_workers=max_workers, buffersize=-1, ) if parallel.on_root(): logger.info("computing patch metadata") new = cls.__new__(cls) new.cache_directory = Path(cache_directory) new._patches = load_patches( new.cache_directory, patch_centers=patch_centers, progress=progress, max_workers=max_workers, ) return new
[docs] @classmethod def from_file( cls, cache_directory: Path | str, path: Path | str, *, ra_name: str, dec_name: str, weight_name: str | None = None, redshift_name: str | None = None, patch_centers: AngularCoordinates | Catalog | None = None, patch_name: str | None = None, patch_num: int | None = None, degrees: bool = True, overwrite: bool = False, progress: bool = False, max_workers: int | None = None, chunksize: int | None = None, probe_size: int = -1, **reader_kwargs, ) -> Catalog: """ Create a new catalog instance from a data file. Processes the input file in chunks, assign objects to spatial patches, write the patches to a cache on disk, and compute the patch meta data. Supported file formats are `FITS`, `Parquet`, and `HDF5`. .. note:: One of the optional patch creation arguments (``patch_centers``, ``patch_name``, or ``patch_num``) must be provided. Args: cache_directory: The cache directory to use for this catalog. Created automatically or overwritten if requested. path: The path to the input data file. Keyword Args: ra_name: Column or path name in the file for right ascension. dec_name: Column or path name in the file for declination. weight_name: Optional column or path name in the file for weights. redshift_name: Optional column or path name in the file for redshifts. patch_centers: A list of patch centers to use when creating the patches. Can be either :obj:`~yaw.AngularCoordinates` or an other :obj:`~yaw.Catalog` as reference. patch_name: Optional column or path name for a column with integer patch indices. Indices must be contiguous and starting from 0. Ignored if ``patch_centers`` is given. patch_num: Automatically compute patch centers from a sparse sample of the input data using `treecorr`. Requires an additional scan of the input file to read a sparse sampling of the object coordinates. Ignored if ``patch_centers`` or ``patch_name`` is given. degrees: Whether the input coordinates are given in degreees (default). overwrite: Whether to overwrite an existing catalog at the given cache location. If the directory is not a valid catalog, a ``FileExistsError`` is raised. progress: Show a progress on the terminal (disabled by default). max_workers: Limit the number of parallel workers for this operation (all by default). chunksize: The maximum number of records to load into memory at once when processing the input file in chunks. probe_size: The approximate number of records to read when generating patch centers (``patch_num``). Returns: A new catalog instance. Raises: FileExistsError: If the cache directory exists or is not a valid catalog when providing ``overwrite=True``. Additional reader keyword arguments are passed on to the file reader class constuctor. """ reader = new_filereader( path, ra_name=ra_name, dec_name=dec_name, weight_name=weight_name, redshift_name=redshift_name, patch_name=patch_name, chunksize=chunksize, degrees=degrees, **reader_kwargs, ) mode = PatchMode.determine(patch_centers, patch_name, patch_num) if mode == PatchMode.create: patch_centers = create_patch_centers(reader, patch_num, probe_size) # split the data into patches and create the cached Patch instances. write_patches( cache_directory, reader, patch_centers, overwrite=overwrite, progress=progress, max_workers=max_workers, buffersize=-1, ) if parallel.on_root(): logger.info("computing patch metadata") new = cls.__new__(cls) new.cache_directory = Path(cache_directory) new._patches = load_patches( new.cache_directory, patch_centers=patch_centers, progress=progress, max_workers=max_workers, ) return new
[docs] @classmethod def from_random( cls, cache_directory: Path | str, generator: RandomsBase, num_randoms: int, *, patch_centers: AngularCoordinates | Catalog | None = None, patch_num: int | None = None, overwrite: bool = False, progress: bool = False, max_workers: int | None = None, chunksize: int | None = None, probe_size: int = -1, ) -> Catalog: """ Create a new catalog instance from a data file. Generate a catalog from uniform random data points in chunks, assign objects to spatial patches, write the patches to a cache on disk, and compute the patch meta data. The :ref:`generator object<generator>` must be created separately by the user. .. note:: One of the optional patch creation arguments (``patch_centers``, or ``patch_num``) must be provided (``patch_name`` is not supported). Args: cache_directory: The cache directory to use for this catalog. Created automatically or overwritten if requested. generator: A random generator (:obj:`~yaw.catalog.generator.RandomsBase`) instance from which samples are drawn. num_randoms: The number of randoms to generate. Keyword Args: patch_centers: A list of patch centers to use when creating the patches. Can be either :obj:`~yaw.AngularCoordinates` or an other :obj:`~yaw.Catalog` as reference. patch_num: Automatically compute patch centers from a sparse sample of the input data using `treecorr`. Requires an additional scan of the input file to read a sparse sampling of the object coordinates. Ignored if ``patch_centers`` or ``patch_name`` is given. overwrite: Whether to overwrite an existing catalog at the given cache location. If the directory is not a valid catalog, a ``FileExistsError`` is raised. progress: Show a progress on the terminal (disabled by default). max_workers: Limit the number of parallel workers for this operation (all by default). chunksize: The maximum number of records to generate and write at once. probe_size: The number of initial random samples to draw read when generating patch centers (``patch_num``). Returns: A new catalog instance. Raises: FileExistsError: If the cache directory exists or is not a valid catalog when providing ``overwrite=True``. """ rand_iter = RandomReader(generator, num_randoms, chunksize) mode = PatchMode.determine(patch_centers, None, patch_num) if mode == PatchMode.create: patch_centers = create_patch_centers(rand_iter, patch_num, probe_size) # split the data into patches and create the cached Patch instances. write_patches( cache_directory, rand_iter, patch_centers, overwrite=overwrite, progress=progress, max_workers=max_workers, buffersize=-1, ) if parallel.on_root(): logger.info("computing patch metadata") new = cls.__new__(cls) new.cache_directory = Path(cache_directory) new._patches = load_patches( new.cache_directory, patch_centers=patch_centers, progress=progress, max_workers=max_workers, ) return new
def __repr__(self) -> str: items = ( f"num_patches={self.num_patches}", f"num_records={sum(self.get_num_records())}", ) patch = next(iter(self.values())) attrs = patch._chunk_info.format() return f"{type(self).__name__}({', '.join(items)}, {attrs}) @ {self.cache_directory}" def __len__(self) -> int: return len(self._patches) def __getitem__(self, patch_id: int) -> Patch: return self._patches[patch_id] def __iter__(self) -> Iterator[int]: yield from sorted(self._patches.keys()) @property def num_patches(self) -> int: """The number of patches of this catalog.""" return len(self) @property def has_weights(self) -> bool: """Whether weights are available.""" has_weights = tuple(patch.has_weights for patch in self.values()) if all(has_weights): return True elif not any(has_weights): return False raise InconsistentPatchesError("'weights' not consistent") @property def has_redshifts(self) -> bool: """Whether redshifts are available.""" has_redshifts = tuple(patch.has_redshifts for patch in self.values()) if all(has_redshifts): return True elif not any(has_redshifts): return False raise InconsistentPatchesError("'redshifts' not consistent")
[docs] def get_num_records(self) -> tuple[int, ...]: """Get the number of records in each patches.""" return tuple(patch.meta.num_records for patch in self.values())
[docs] def get_sum_weights(self) -> tuple[float, ...]: """Get the sum of weights of the patches.""" return tuple(patch.meta.sum_weights for patch in self.values())
[docs] def get_centers(self) -> AngularCoordinates: """Get the center coordinates of the patches.""" return AngularCoordinates.from_coords( patch.meta.center for patch in self.values() )
[docs] def get_radii(self) -> AngularDistances: """Get the radii of the patches.""" return AngularDistances.from_dists(patch.meta.radius for patch in self.values())
[docs] def build_trees( self, binning: NDArray | None = None, *, closed: Closed | str = Closed.right, leafsize: int = 16, force: bool = False, progress: bool = False, max_workers: int | None = None, ) -> None: """ Build binary search trees on for each patch. The trees are cached in the patches' cache directory and can be retrieved through ``yaw.trees.BinnedTrees(patch)``. Args: binning: Optional array with redshift bin edges to apply to the data before building trees. Keyword Args: closed: Indicating which side of the bin edges is a closed interval, see :obj:`~yaw.options.Closed` for valid options. leafsize: Leafsize when building trees. force: Whether to overwrite any existing, cached trees. progress: Show a progress on the terminal (disabled by default). max_workers: Limit the number of parallel workers for this operation (all by default). Takes precedence over the value in the configuration. """ if binning is not None: binning = Binning(binning, closed=closed) if parallel.on_root(): logger.debug( "building patch-wise trees (%s)", "unbinned" if binning is None else f"using {len(binning)} bins", ) patch_tree_iter = parallel.iter_unordered( BinnedTrees.build, self.values(), func_args=(binning,), func_kwargs=dict(leafsize=leafsize, force=force), max_workers=max_workers, ) if progress: patch_tree_iter = Indicator(patch_tree_iter, len(self)) deque(patch_tree_iter, maxlen=0)
Catalog.get.__doc__ = "Return the :obj:`~yaw.Patch` for ID if exists, else default." Catalog.keys.__doc__ = "A set-like object providing a view of all patch IDs." Catalog.values.__doc__ = ( "A set-like object providing a view of all :obj:`~yaw.Patch` es." ) Catalog.items.__doc__ = "A set-like object providing a view of `(key, value)` pairs."