Source code for fastar.imf.named_imf.broken_power_law

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# pylint: disable=duplicate-code
# *** Duplicate code will be addressed in future IMF refactoring ***
"""
Broken power-law
"""

import jax.numpy as jnp
import jax.scipy.integrate as jsp_integrate


def broken_power_law_raw(
    mass,
    m_min=0.1,
    m_max=100.0,
    m_break1=0.5,
    m_break2=1.0,
    alpha1=1.3,
    alpha2=1.8,
    alpha3=2.3,
):
    """
    Return the mass-normalized three-segment broken power-law IMF at `mass`.

    The un-normalized shape is continuous across both breaks:

        m < m_break1              ->  m^(-alpha1)
        m_break1 <= m < m_break2  ->  m_break1^(alpha2 - alpha1) * m^(-alpha2)
        m >= m_break2             ->  m_break1^(alpha2 - alpha1)
                                      * m_break2^(alpha3 - alpha2) * m^(-alpha3)

    The shape is divided by a trapezoidal estimate of the integral of
    ``m * imf(m)`` over ``[m_min, m_max]`` (5000 linearly spaced samples),
    i.e. the normalization is by total mass, not by number of stars.
    Masses outside ``[m_min, m_max]`` evaluate to 0.

    Parameters
    ----------
    mass : array-like
        Stellar mass or array of masses.
    m_min : float, optional
        Lower mass limit. Default is 0.1.
    m_max : float, optional
        Upper mass limit. Default is 100.0.
    m_break1, m_break2 : float
        Break masses. They are expected to satisfy
        ``m_min < m_break1 < m_break2 < m_max``; this is not checked here.
    alpha1, alpha2, alpha3 : float
        Slopes of the three segments (applied with a negative sign).

    Returns
    -------
    jnp.ndarray
        Normalized IMF values. A scalar input or a one-element 1-D input
        yields a 0-d array; any other input (including empty and
        multi-dimensional arrays) keeps the shape of `mass`.
    """
    mass = jnp.atleast_1d(mass)

    def imf_piecewise(m):
        # Low-mass segment, taken as the reference amplitude.
        seg1 = m ** (-alpha1)

        # Middle segment, scaled so it matches seg1 at m_break1.
        coeff2 = m_break1 ** (alpha2 - alpha1)
        seg2 = coeff2 * m ** (-alpha2)

        # High-mass segment, scaled so it matches seg2 at m_break2.
        coeff3 = coeff2 * (m_break2 ** (alpha3 - alpha2))
        seg3 = coeff3 * m ** (-alpha3)

        return jnp.where(
            m < m_break1,
            seg1,
            jnp.where(m < m_break2, seg2, seg3),
        )

    # Numerical mass normalization: integral of m * imf(m) over [m_min, m_max].
    m_vals = jnp.linspace(m_min, m_max, 5000)
    norm = jsp_integrate.trapezoid(imf_piecewise(m_vals) * m_vals, x=m_vals)

    imf_vals = imf_piecewise(mass) / norm
    imf_vals = jnp.where((mass >= m_min) & (mass <= m_max), imf_vals, 0.0)

    # Unwrap only a lone value (scalar or one-element 1-D input). Checking
    # just the leading axis length would strip a dimension from inputs shaped
    # (1, N) and would index into an empty array for zero-length input.
    if imf_vals.ndim == 1 and imf_vals.shape[0] == 1:
        return imf_vals[0]
    return imf_vals
def broken_power_law(mass, params):
    """
    Evaluate the broken power-law IMF with parameters taken from a mapping.

    Parameters
    ----------
    mass : array-like
        Stellar mass or array of masses.
    params : dict
        Keyword arguments forwarded unchanged to `broken_power_law_raw`;
        each key must be the name of one of its parameters.

    Returns
    -------
    jnp.ndarray or float
        Normalized IMF values, exactly as returned by
        `broken_power_law_raw`.
    """
    imf_kwargs = {**params}
    return broken_power_law_raw(mass, **imf_kwargs)