Source code for yaw.config.backend
from __future__ import annotations
import os
from dataclasses import dataclass, field
from yaw.config import DEFAULT
from yaw.config.abc import BaseConfig
from yaw.core.docs import Parameter
__all__ = ["BackendConfig"]
[docs]
@dataclass(frozen=True)
class BackendConfig(BaseConfig):
"""Configuration of backends used for correlation measurements.
Args:
thread_num (:obj:`int`, optional):
Number of threads to use for parallel processing.
crosspatch (:obj:`bool`, optional):
Whether to count pairs across patch boundaries (``scipy`` backend
only).
rbin_slop (:obj:`int`, optional):
`TreeCorr` ``rbin_slop`` parameter (``treecorr`` backend only).
"""
# general
thread_num: int | None = field(
default=DEFAULT.Backend.thread_num,
metadata=Parameter(
type=int,
help="default number of threads to use",
default_text="(default: all)",
),
)
"""Number of threads to use for parallel processing."""
# scipy
crosspatch: bool = field(
default=DEFAULT.Backend.crosspatch,
metadata=Parameter(
type=bool,
help="whether to count pairs across patch boundaries (scipy backend only)",
),
)
"""Whether to count pairs across patch boundaries (``scipy`` backend only).
"""
# treecorr
rbin_slop: float = field(
default=DEFAULT.Backend.rbin_slop,
metadata=Parameter(
type=float,
help="TreeCorr 'rbin_slop' parameter",
default_text="(default: %(default)s), without 'rweight' this just "
"a single radial bin, otherwise 'rbin_num'",
),
)
"""`TreeCorr` ``rbin_slop`` parameter (``treecorr`` backend only)."""
def __post_init__(self) -> None:
if self.thread_num is None:
object.__setattr__(self, "thread_num", os.cpu_count())
[docs]
def modify(
self,
thread_num: int | None = DEFAULT.NotSet,
crosspatch: bool = DEFAULT.NotSet,
rbin_slop: float = DEFAULT.NotSet,
) -> BackendConfig:
return super().modify(
thread_num=thread_num, crosspatch=crosspatch, rbin_slop=rbin_slop
)
[docs]
def get_threads(self, max=None) -> int:
"""Get the number of threads for parallel processing.
The value is capped at an optional maximum value.
Args:
max (:obj:`int`, optional):
Maximum number to return.
Returns:
:obj:`int`
"""
thread_num = self.thread_num
if max is not None:
if max < 1:
raise ValueError("'max' must be positive")
thread_num = min(max, thread_num)
return thread_num