Source code for fastar.synthesis.ssp.semiresolved

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax
from jax.scipy.integrate import trapezoid

from .base import BaseSSPSynthesizer
from fastar.interpolate.color import color_interpolation


class SemiresolvedSSPSynthesizer(BaseSSPSynthesizer):
    """
    Class for generating synthetic semi-resolved stellar populations
    spectroscopic and photometric predictions with a PCA-based stellar
    spectral model and a stochastic IMF sampling.
    """

    @partial(jax.jit, static_argnames=['self', 'num_stars', 'out_masses'])
    def synthesize(self, age, met, num_stars, key, imf_params=None,
                   out_masses=False):
        """
        Generate synthetic semi-resolved population spectrum for a given age
        and metallicity.

        Parameters
        ----------
        age : float
            Stellar population age (in Gyr).
        met : float
            Metallicity [M/H].
        num_stars : int
            Number of stars to sample. Static under ``jax.jit``: each distinct
            value triggers a recompilation.
        key : PRNGKey
            Random key for JAX sampling.
        imf_params : dict, optional
            Parameters for the initial mass function. ``None`` is treated as
            an empty dict.
        out_masses : bool, optional
            If True, return array of sampled stellar masses instead of the
            total mass. Default is False.

        Returns
        -------
        tuple
            (spectrum, total stellar mass) or (spectrum, sampled stellar
            masses) depending on `out_masses`.
        """
        # Ensure we always pass a dict to the IMF function
        imf_params = imf_params or {}

        # Interpolate the isochrones at the desired age and metallicity
        imass, iteff, ilogg, ilum = self._get_isochrone(age, met)

        # Stochastically sample the IMF
        sampled_masses = self._stochastic_imf_sampling(
            imass, imf_params, num_stars, key
        )

        # Sum the luminosity-scaled spectra of every sampled star (shared
        # with synthesize_large so both paths use the same physics)
        spec = self._synthesize_massgiven(
            met=met,
            imass=imass,
            iteff=iteff,
            ilogg=ilogg,
            ilum=ilum,
            sampled_masses=sampled_masses,
        )

        # Return the spectrum plus either the individual sampled masses or
        # their sum (the total stellar mass of the population)
        if out_masses:
            return spec, sampled_masses
        return spec, jnp.sum(sampled_masses)

    @partial(jax.jit, static_argnames=['self', 'num_stars'])
    def _stochastic_imf_sampling(self, imass, imf_params, num_stars, key):
        """
        Stochastically sample stellar masses from an IMF treated as a
        probability density function (inverse-transform sampling).

        Parameters
        ----------
        imass : array
            Isochrone masses; only their min and max are used, to bound the
            sampling grid.
        imf_params : dict or None
            Keyword parameters forwarded to ``self.imf_function``.
        num_stars : int
            Number of masses to draw (static under ``jax.jit``).
        key : PRNGKey
            Random key for the uniform draws.

        Returns
        -------
        array, shape (num_stars,)
            Sampled stellar masses in ``[imass.min(), imass.max()]``.
        """
        imf_params = imf_params or {}
        # Normalize the IMF to a total number of stars equal to 1.
        # Note this normalization is different than the one used for the
        # population synthesis, as there the normalizing quantity is the
        # total mass (not the number of stars).
        mass_grid = jnp.linspace(imass.min(), imass.max(), 5000)
        imf_vals = self.imf_function(mass_grid, imf_params)
        # IMF to PDF (via normalization)
        pdf = imf_vals / trapezoid(imf_vals, x=mass_grid)

        # CDF as the cumulative trapezoid integral of the PDF, anchored at 0
        # on the first grid node. A plain cumsum(pdf) would start at
        # pdf[0]/sum(pdf) > 0 and be shifted by one bin, clamping a sizable
        # fraction of draws to the minimum mass for steep IMFs.
        increments = 0.5 * (pdf[1:] + pdf[:-1]) * jnp.diff(mass_grid)
        cdf = jnp.concatenate(
            [jnp.zeros((1,), dtype=increments.dtype), jnp.cumsum(increments)]
        )
        # Renormalize so cdf[-1] is exactly 1 (absorbs quadrature round-off)
        cdf = cdf / cdf[-1]

        # Inverse-transform sampling of the CDF
        uniform_samples = jax.random.uniform(key, shape=(num_stars,))
        return jnp.interp(uniform_samples, cdf, mass_grid)

    @partial(jax.jit, static_argnames=['self'])
    def _synthesize_massgiven(self, met, imass, iteff, ilogg, ilum,
                              sampled_masses):
        """
        Generate the integrated spectrum of a stellar population using a
        pre-sampled set of stellar masses and isochrone quantities.

        Parameters
        ----------
        met : float
            Metallicity [M/H], applied to every star.
        imass, iteff, ilogg, ilum : array
            Isochrone masses and the matching Teff, log g and luminosity
            values; ``imass`` is used as the x-grid for ``jnp.interp``.
        sampled_masses : array
            Masses of the individual stars to include.

        Returns
        -------
        array
            Sum over stars of the luminosity-scaled predicted spectra.
        """
        # Evaluate the isochrone at the sampled masses
        iteff_interp = jnp.interp(sampled_masses, imass, iteff)
        ilogg_interp = jnp.interp(sampled_masses, imass, ilogg)
        ilum_interp = jnp.interp(sampled_masses, imass, ilum)
        met_arr = jnp.full_like(iteff_interp, met)

        # Calculate the stellar spectra (one row per star)
        spectra = self.predict_spectrum(ilogg_interp, iteff_interp, met_arr)

        # Get the V-band magnitudes of the predicted stellar spectra (they are
        # normalized to have a mean flux of 1)
        magnitudes = self._compute_ab_magnitudes(spectra)

        # Calculate the bolometric corrections
        bcv_val = color_interpolation(
            ilogg_interp,
            iteff_interp,
            met_arr,
            self.logg_array,
            self.teff_log10_array,
            self.fmet_array,
            self.bcv_grid,
        )

        # Scale the predicted stellar spectra so they match their theoretical
        # luminosities (assumes ilum is log10 luminosity in solar units --
        # TODO confirm against the isochrone loader)
        vmags = -2.5 * ilum_interp - bcv_val
        mtarg = vmags + self.sun_vmag
        corr = 1 / (10 ** ((magnitudes - mtarg) / -2.5))

        # Add the flux of all the spectra. There is no IMF weighting here
        # since it naturally comes from the stochastic sampling
        return jnp.sum(spectra * corr[:, None], axis=0)

    @partial(
        jax.jit,
        static_argnames=['self', 'num_stars', 'batch_size', 'out_masses'],
    )
    def synthesize_large(
        self,
        age,
        met,
        num_stars,
        key,
        batch_size=10000,
        out_masses=False,
        imf_params=None,
    ):
        """
        Batched variant of :meth:`synthesize` for large star counts.

        All ``num_stars`` masses are drawn at once, but spectra are computed
        ``batch_size`` stars at a time and accumulated, which bounds the size
        of the per-star spectra array held at any point.

        Parameters
        ----------
        age : float
            Stellar population age (in Gyr).
        met : float
            Metallicity [M/H].
        num_stars : int
            Number of stars to sample (static under ``jax.jit``).
        key : PRNGKey
            Random key for JAX sampling.
        batch_size : int, optional
            Stars per batch (static under ``jax.jit``). Must be positive.
            Default is 10000.
        out_masses : bool, optional
            If True, return the sampled masses instead of their sum.
        imf_params : dict, optional
            Parameters for the initial mass function. ``None`` is treated as
            an empty dict.

        Returns
        -------
        tuple
            (spectrum, total stellar mass) or (spectrum, sampled stellar
            masses) depending on `out_masses`.

        Raises
        ------
        ValueError
            If ``batch_size`` is not positive.
        """
        # batch_size is static, so this check runs at trace time
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )
        imf_params = imf_params or {}

        # Isochrone interpolation (same for all batches)
        imass, iteff, ilogg, ilum = self._get_isochrone(age, met)

        # Sample the IMF once for the whole population
        sampled_masses = self._stochastic_imf_sampling(
            imass, imf_params, num_stars, key
        )

        # num_stars and batch_size are static, so these are Python ints
        n_batches, remainder = divmod(num_stars, batch_size)
        n_full = n_batches * batch_size

        def batch_spectrum(batch_masses):
            # Integrated spectrum of one chunk of stars
            return self._synthesize_massgiven(
                met=met,
                imass=imass,
                iteff=iteff,
                ilogg=ilogg,
                ilum=ilum,
                sampled_masses=batch_masses,
            )

        # Spectrum accumulator (assumes predicted spectra share the shape of
        # self.wave)
        spec_total = jnp.zeros_like(self.wave)

        if n_batches > 0:
            batches = sampled_masses[:n_full].reshape((n_batches, batch_size))

            def scan_fn(spec_accum, batch_masses):
                return spec_accum + batch_spectrum(batch_masses), None

            spec_total, _ = lax.scan(scan_fn, spec_total, batches)

        # The remainder size is known at trace time, so a Python branch is
        # used instead of lax.cond: lax.cond would trace both branches, and
        # with remainder == 0 a slice like sampled_masses[-0:] would select
        # (and compile a pass over) the whole population.
        if remainder > 0:
            spec_total = spec_total + batch_spectrum(sampled_masses[n_full:])

        if out_masses:
            return spec_total, sampled_masses
        return spec_total, jnp.sum(sampled_masses)