# Source code for fast_lisa_subtraction.priors.population

"""
Standard population priors. See Abbott et al. (2021), "Population Properties of
Compact Objects" (arXiv:2010.14533, https://arxiv.org/pdf/2010.14533).
"""

import torch
import numpy as np

from ..utils import interp1d
from .analytical import Prior

class PowerLaw(Prior):
    r"""Power-law prior distribution.

    .. math::

        p(\theta) =
        \begin{cases}
            \theta^{\alpha}, & \theta_{\min} \le \theta \le \theta_{\max} \\
            0, & \text{otherwise}
        \end{cases}

    where :math:`\theta_{\min}` and :math:`\theta_{\max}` are the minimum and
    maximum values of the prior, respectively.

    Parameters
    ----------
    alpha : float
        Power-law index :math:`\alpha`.
    minimum : float
        Minimum value of the prior :math:`\theta_{\min}`.
    maximum : float
        Maximum value of the prior :math:`\theta_{\max}`.
    name : str, optional
        Name of the prior parameter.
    device : str, optional
        Torch device used for sampling.
    """

    def __init__(self, alpha, minimum, maximum, name=None, device='cpu'):
        """Initialize a power-law prior.

        Parameters
        ----------
        alpha : float
            Power-law index.
        minimum : float
            Minimum value of the support.
        maximum : float
            Maximum value of the support.
        name : str, optional
            Name of the prior parameter.
        device : str, optional
            Torch device used for sampling.
        """
        self.alpha = alpha
        super().__init__(minimum, maximum, name, device)
        # Precompute the discretized PDF/CDF grids used by `sample`.
        self._compute_pdf()
        self._compute_cdf()

    @property
    def pdf(self):
        """Probability density function evaluated on the internal grid.

        Returns
        -------
        torch.Tensor
            Discretized PDF values over ``self.x``.
        """
        return self._pdf

    @property
    def pdf_max(self):
        """Maximum value of the discretized PDF.

        Returns
        -------
        torch.Tensor
            Maximum PDF value.
        """
        return self._pdf.max()

    @property
    def cdf(self):
        """Cumulative distribution function evaluated on the internal grid.

        Returns
        -------
        torch.Tensor
            Discretized CDF values over ``self.x``.
        """
        return self._cdf

    def _compute_pdf(self):
        """Compute a discretized power-law PDF on a fixed grid.

        Returns
        -------
        None
        """
        self.x = torch.linspace(self.minimum, self.maximum, int(1e4),
                                device=self.device)
        pdf = self.x.pow(self.alpha)
        # Normalize numerically so the grid integrates to 1.
        pdf = pdf / torch.trapz(pdf, self.x)
        self._pdf = pdf

    def _compute_cdf(self):
        """Compute a discretized CDF from the cached PDF.

        Uses a cumulative-trapezoid rule so that ``CDF[0] == 0`` and
        ``CDF[-1] == 1`` (consistent with :class:`BrokenPowerLaw`).  A plain
        ``cumsum`` of the PDF left ``CDF[0] = pdf[0] > 0``, which biased
        inverse-CDF sampling away from the lower edge of the support.

        Returns
        -------
        None
        """
        x = self.x
        p = self._pdf
        cdf = torch.zeros_like(p)
        # Trapezoid increments between consecutive grid points.
        incr = 0.5 * (p[1:] + p[:-1]) * (x[1:] - x[:-1])
        cdf[1:] = torch.cumsum(incr, dim=-1)
        self._cdf = cdf / cdf[-1]

    def log_prob(self, x, standardize=False):
        """Evaluate the log-probability of the power-law prior.

        Parameters
        ----------
        x : array-like or torch.Tensor
            Points at which to evaluate ``log p(x)``.
        standardize : bool, optional
            If True, interpret ``x`` as standardized and de-standardize it
            before evaluating.

        Returns
        -------
        torch.Tensor
            Log-probability values (``-inf`` outside the support).
        """
        x = torch.as_tensor(x, device=self.device, dtype=torch.float32)
        if standardize:
            x = self.destandardize(x)
        xmin = torch.as_tensor(self.minimum, device=self.device, dtype=x.dtype)
        xmax = torch.as_tensor(self.maximum, device=self.device, dtype=x.dtype)
        alpha = torch.as_tensor(self.alpha, device=self.device, dtype=x.dtype)

        inside = (x >= xmin) & (x <= xmax)
        logp = torch.full_like(x, -float("inf"))

        # alpha = -1 needs the logarithmic normalization (the analytic power
        # form divides by alpha + 1, which would be zero).
        if torch.isclose(alpha, torch.tensor(-1.0, device=self.device)):
            norm = torch.log(xmax) - torch.log(xmin)
            logp_inside = -torch.log(x) - torch.log(norm)
        else:
            norm = (xmax.pow(alpha + 1) - xmin.pow(alpha + 1)) / (alpha + 1)
            logp_inside = alpha * torch.log(x) - torch.log(norm)

        logp[inside] = logp_inside[inside]
        return logp

    def sample(self, num_samples, standardize=False):
        """Sample from the power-law prior via inverse-CDF interpolation.

        Parameters
        ----------
        num_samples : int
            Number of samples to draw.
        standardize : bool, optional
            Whether to standardize the samples to zero mean and unit variance
            (for training purposes). Default: False

        Returns
        -------
        torch.Tensor
            Samples drawn from the prior.
        """
        u = torch.rand(num_samples, device=self.device)
        samples = interp1d(self._cdf, self.x, u).flatten()
        # Guard against numerical wiggles at the grid edges (consistent with
        # BrokenPowerLaw.sample).
        samples = torch.clamp(samples, min=self.minimum, max=self.maximum)
        if standardize:
            samples = self.standardize(samples)
        return samples
class BrokenPowerLaw(Prior):
    r"""Broken power-law prior distribution.

    Notes
    -----
    The density is piecewise:

    .. math::

        p(\theta) \propto \theta^{\alpha},
        \quad \theta_{\min} \le \theta < \theta_{\mathrm{break}},

    .. math::

        p(\theta) \propto \theta^{\beta}\, \theta_{\mathrm{break}}^{\alpha-\beta},
        \quad \theta_{\mathrm{break}} \le \theta \le \theta_{\max}.

    The break point is determined by ``b``. If ``0 <= b <= 1``, then
    :math:`\theta_{\mathrm{break}} = \theta_{\min} + b(\theta_{\max}-\theta_{\min})`;
    otherwise ``b`` is treated as an absolute break value within the support.

    Optional low-end smoothing :math:`S(\theta; \delta)` can be applied in the
    interval :math:`[\theta_{\min}, \theta_{\min}+\delta]`.

    Parameters
    ----------
    alpha : float
        Power-law index for :math:`\theta < \theta_{\mathrm{break}}`.
    beta : float
        Power-law index for :math:`\theta \ge \theta_{\mathrm{break}}`.
    b : float
        Break parameter (fraction or absolute).
    minimum : float
        Lower bound of the support.
    maximum : float
        Upper bound of the support.
    delta_p : float or None, optional
        Optional smoothing width :math:`\delta`. If None or 0, smoothing is
        disabled.
    name : str or None, optional
        Parameter name.
    device : str, optional
        Torch device used for sampling.
    """

    def __init__(self, alpha, beta, b, minimum, maximum, delta_p=None,
                 name=None, device='cpu'):
        r"""Initialize a broken power-law prior.

        Parameters
        ----------
        alpha : float
            Power-law index for :math:`\theta < \theta_{\mathrm{break}}`.
        beta : float
            Power-law index for :math:`\theta \ge \theta_{\mathrm{break}}`.
        b : float
            Break parameter (fraction or absolute).
        minimum : float
            Lower bound of the support.
        maximum : float
            Upper bound of the support.
        delta_p : float or None, optional
            Optional smoothing width :math:`\delta`. If None or 0, smoothing
            is disabled.
        name : str or None, optional
            Parameter name.
        device : str, optional
            Torch device used for sampling.

        Raises
        ------
        ValueError
            If ``b`` is provided as an absolute value outside the support.
        """
        self.alpha = float(alpha)
        self.beta = float(beta)
        self.device = device
        super().__init__(minimum, maximum, name, device)

        # Resolve break: allow fraction in [0,1] or absolute within [min,max]
        if 0.0 <= b <= 1.0:
            self.b = float(b)
            self.break_point = self.minimum + self.b * (self.maximum - self.minimum)
        else:
            if not (self.minimum < b < self.maximum):
                raise ValueError(
                    f"`b` must be in [0,1] (fraction) or within "
                    f"({self.minimum},{self.maximum}) as absolute.")
            self.break_point = float(b)
            self.b = (self.break_point - self.minimum) / (self.maximum - self.minimum)

        # Optional smoothing width (delta); treat None/0 as disabled
        self.delta_p = None if (delta_p is None or float(delta_p) <= 0.0) else float(delta_p)

        # Precompute PDF/CDF grids
        self._compute_pdf()
        self._compute_cdf()

    @property
    def pdf(self):
        """Probability density function evaluated on the internal grid.

        Returns
        -------
        torch.Tensor
            Discretized PDF values over ``self.x``.
        """
        return self._pdf

    @property
    def pdf_max(self):
        """Maximum value of the discretized PDF.

        Returns
        -------
        torch.Tensor
            Maximum PDF value.
        """
        return self._pdf.max()

    @property
    def cdf(self):
        """Cumulative distribution function evaluated on the internal grid.

        Returns
        -------
        torch.Tensor
            Discretized CDF values over ``self.x``.
        """
        return self._cdf

    def _compute_pdf(self):
        """Compute a discretized broken power-law PDF on a fixed grid.

        Returns
        -------
        None
        """
        # Discretization grid (same density as PowerLaw)
        self.x = torch.linspace(self.minimum, self.maximum, int(1e4),
                                device=self.device)
        pdf = torch.zeros_like(self.x)

        # continuity factor to match values at theta_break
        fix_factor = (self.break_point ** (self.alpha - self.beta))

        mask_lo = self.x < self.break_point
        mask_hi = ~mask_lo  # includes theta == break

        # piecewise power law
        if mask_lo.any():
            pdf[mask_lo] = self.x[mask_lo].pow(self.alpha)
        if mask_hi.any():
            pdf[mask_hi] = self.x[mask_hi].pow(self.beta) * fix_factor

        # optional low-end smoothing near theta_min
        if self.delta_p is not None:
            pdf = pdf * self._smoothing(self.x)

        # normalize by integral over theta
        norm = torch.trapz(pdf, self.x)
        if not torch.isfinite(norm) or norm <= 0:
            raise RuntimeError("BrokenPowerLaw PDF normalization failed (non-positive or non-finite integral).")
        self._pdf = pdf / norm

    def _compute_cdf(self):
        """Compute a discretized CDF from the cached PDF.

        Returns
        -------
        None
        """
        # cum-trapezoid so that CDF[0] = 0 and CDF[-1] = 1
        x = self.x
        p = self.pdf
        cdf = torch.zeros_like(p)
        # trapezoid increments
        dx = x[1:] - x[:-1]
        incr = 0.5 * (p[1:] + p[:-1]) * dx
        cdf[1:] = torch.cumsum(incr, dim=-1)
        # normalize
        cdf = cdf / cdf[-1]
        self._cdf = cdf

    def _f(self, p: torch.Tensor) -> torch.Tensor:
        r"""Compute the smoothing helper function.

        .. math::

            f(\theta) = \exp\!\left(\frac{\delta}{\theta-\theta_{\min}}
            + \frac{\delta}{(\theta-\theta_{\min})-\delta}\right).

        Callers mask inputs to the smoothing window
        :math:`[\theta_{\min}, \theta_{\min}+\delta)`, where the first
        denominator is non-negative and the second is strictly NEGATIVE.

        Parameters
        ----------
        p : torch.Tensor
            Input samples (restricted to the smoothing window).

        Returns
        -------
        torch.Tensor
            ``f(p)`` values.
        """
        p1 = p - self.minimum
        eps = torch.finfo(p.dtype).eps
        # FIX: the second denominator (p1 - delta) is negative throughout the
        # smoothing window; the previous code clamped it to +eps, which made
        # exp(delta/d2) overflow to +inf and forced S = 1/(f+1) = 0 over the
        # whole window, zeroing the PDF there instead of tapering it.
        # Clamp each denominator towards its correct sign, away from zero.
        d1 = torch.clamp(p1, min=eps)                   # in (0, delta)
        d2 = torch.clamp(p1 - self.delta_p, max=-eps)   # in (-delta, 0)
        return torch.exp(self.delta_p / d1 + self.delta_p / d2)

    def _smoothing(self, samples: torch.Tensor) -> torch.Tensor:
        r"""Compute the low-end smoothing factor.

        .. math::

            S(\theta) =
            \begin{cases}
                0, & \theta < \theta_{\min}, \\
                \frac{1}{f(\theta)+1}, & \theta_{\min} \le \theta < \theta_{\min}+\delta, \\
                1, & \theta \ge \theta_{\min}+\delta.
            \end{cases}

        Parameters
        ----------
        samples : torch.Tensor
            Input samples.

        Returns
        -------
        torch.Tensor
            Smoothing factors for each sample.
        """
        S = torch.zeros_like(samples)
        p_thr = self.minimum + self.delta_p
        mask_mid = (samples >= self.minimum) & (samples < p_thr)
        if mask_mid.any():
            S[mask_mid] = 1.0 / (self._f(samples[mask_mid]) + 1.0)
        S[samples >= p_thr] = 1.0
        # (samples < minimum) stay at 0.0, but those are out of support anyway
        return S

    def sample(self, num_samples: int, standardize: bool = False):
        """Sample from the broken power-law prior.

        Parameters
        ----------
        num_samples : int
            Number of samples to draw.
        standardize : bool, optional
            Whether to standardize the samples to zero mean and unit variance
            (for training purposes). Default: False

        Returns
        -------
        torch.Tensor
            Samples drawn from the prior.
        """
        # explicit endpoints ensure no extrapolation outside [min,max]
        cdf_aug = torch.cat([torch.tensor([0.0], device=self.device), self._cdf])
        x_aug = torch.cat([torch.tensor([self.minimum], device=self.device), self.x])
        u = torch.rand(num_samples, device=self.device)  # in [0,1)
        samples = interp1d(cdf_aug, x_aug, u).flatten()
        # final tiny safety (handles any numerical wiggles)
        samples = torch.clamp(samples, min=self.minimum, max=self.maximum)
        if standardize:
            samples = self.standardize(samples)
        return samples
class PowerLawPlusPeak(Prior):
    """Power-law plus peak prior distribution.

    Placeholder prior combining a power-law background with a peaked
    component; only parameter storage is implemented — sampling raises
    :class:`NotImplementedError`.

    Parameters
    ----------
    alpha : float
        Power-law index for the background.
    beta : float
        Power-law index for the peak component.
    minimum : float
        Lower bound of the support.
    maximum : float
        Upper bound of the support.
    peak : float
        Peak location or scale parameter (implementation-specific).
    name : str, optional
        Name of the prior parameter.
    """

    def __init__(self, alpha, beta, minimum, maximum, peak, name='Parameter'):
        """Store the component parameters and initialize the base prior.

        Parameters
        ----------
        alpha : float
            Power-law index for the background.
        beta : float
            Power-law index for the peak component.
        minimum : float
            Lower bound of the support.
        maximum : float
            Upper bound of the support.
        peak : float
            Peak location or scale parameter (implementation-specific).
        name : str, optional
            Name of the prior parameter.
        """
        # Keep the component parameters on the instance; the base class
        # handles the support bounds and the parameter name.
        self.alpha = alpha
        self.beta = beta
        self.peak = peak
        super().__init__(minimum, maximum, name)

    def sample(self, N):
        """Sample from the power-law plus peak prior.

        Parameters
        ----------
        N : int
            Number of samples to draw.

        Raises
        ------
        NotImplementedError
            Sampling for this prior is not implemented.
        """
        raise NotImplementedError("Sampling from PowerLawPlusPeak is not implemented")