Source code for fastar.synthesis.ssp.integrated

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
from functools import partial

import h5py
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
from jax.scipy.integrate import trapezoid

from .base import BaseSSPSynthesizer
from fastar.interpolate.color import color_interpolation


class IntegratedSSPSynthesizer(BaseSSPSynthesizer):
    """
    Class for generating synthetic integrated SSP spectroscopic and
    photometric predictions with a PCA-based stellar spectral model.
    """

    def _base_population(self, age, met, imf_params):
        """
        Interpolate the isochrone and evaluate the IMF for one SSP.

        Returns
        -------
        tuple
            ``(imass, iteff, ilogg, ilum, imet, imf_val)``: the four arrays
            returned by ``self._get_isochrone``, the population metallicity
            broadcast to every isochrone point, and ``self.imf_function``
            evaluated at ``imass``.
        """
        imass, iteff, ilogg, ilum = self._get_isochrone(age, met)
        # Every isochrone point is assigned the single population metallicity.
        imet = jnp.full_like(iteff, met)
        imf_val = self.imf_function(imass, imf_params)
        return imass, iteff, ilogg, ilum, imet, imf_val

    @partial(jax.jit, static_argnames=['self'])
    def synthesize(self, age, met, imf_params=None):
        """
        Generate an SSP spectrum for a given age and metallicity.

        Parameters
        ----------
        age : float
            Age of the population (in Gyr).
        met : float
            Metallicity [M/H].
        imf_params : dict, optional
            Parameters forwarded to ``self.imf_function``. Default is an
            empty dict.

        Returns
        -------
        array
            Synthesized spectrum: the IMF-weighted, luminosity-scaled
            stellar spectra integrated over the isochrone initial masses.
        """
        # Normalize ``None`` to an empty dict before it reaches the IMF.
        imf_params = imf_params or {}
        imass, iteff, ilogg, ilum, imet, imf_val = self._base_population(
            age, met, imf_params
        )
        return self._wrapper(imass, iteff, ilogg, ilum, imet, imf_val)

    @partial(jax.jit, static_argnames=['self'])
    def _wrapper(self, imass, iteff, ilogg, ilum, imet, imf_val):
        """
        Build an integrated spectrum from per-star isochrone quantities.

        One spectrum is predicted per isochrone point and rescaled so that
        its magnitude matches the target magnitude implied by ``ilum`` and
        the interpolated bolometric correction; the rescaled spectra are
        then IMF-weighted and integrated over ``imass``.
        """
        # Stellar spectra for every isochrone point.
        spectra = self.predict_spectrum(ilogg, iteff, imet)
        # Bolometric corrections interpolated on the (logg, Teff, [M/H]) grid.
        bcv_val = color_interpolation(
            ilogg,
            iteff,
            imet,
            self.logg_array,
            self.teff_log10_array,
            self.fmet_array,
            self.bcv_grid,
        )
        # V-band magnitudes of the predicted stellar spectra (they are
        # normalized to have a mean flux of 1).
        magnitudes = self._compute_ab_magnitudes(spectra)
        # Target magnitudes implied by the isochrone luminosities.
        # NOTE(review): assumes ``ilum`` is log10(L / L_sun) — confirm.
        vmags = -2.5 * ilum - bcv_val
        mtarg = vmags + self.sun_vmag
        # Flux factor 10**(0.4 * (magnitudes - mtarg)) that scales each
        # spectrum so it matches its theoretical luminosity.
        corr = 1 / jnp.power(10.0, (magnitudes - mtarg) / -2.5)
        return self._population_synthesis_integrate(spectra, corr, imf_val, imass)

    @partial(jax.jit, static_argnames=['self'])
    def _population_synthesis_integrate(self, spectra, corr, imf_val, imass):
        """
        Integrate IMF-weighted, corrected spectra over initial mass grid.

        ``spectra`` is indexed as (star, wavelength); each row is multiplied
        by ``corr * imf_val`` for that star before a trapezoidal integral
        along the star axis with abscissa ``imass``.
        """
        weights = corr * imf_val
        integrand = spectra * weights[:, None]
        return trapezoid(integrand, x=imass, axis=0)

    def _spectral_scatter(self, imass, iteff_p, ilogg_p, ilum, imet_p, imf_val):
        """
        Synthesize one spectrum per perturbed realization and return the
        per-pixel standard deviation across realizations.

        ``iteff_p``, ``ilogg_p`` and ``imet_p`` carry a leading realization
        axis; the remaining inputs are shared by all realizations.
        """
        specs = jax.vmap(self._wrapper, in_axes=(None, 0, 0, None, 0, None))(
            imass, iteff_p, ilogg_p, ilum, imet_p, imf_val
        )
        return jnp.std(specs, axis=0)

    @partial(jax.jit, static_argnames=['self', 'nsim'])
    def synthesize_nsim(
        self,
        age,
        met,
        imf_params=None,
        dmet=0.1,
        dteff=0.005,
        dlogg=0.2,
        nsim=50,
        key=None,
    ):
        """
        Estimate SSP spectral uncertainties via Monte Carlo perturbation of
        the stellar parameters.

        Each isochrone point receives its own independent Gaussian
        perturbation in every realization.

        Parameters
        ----------
        age : float
            Age of the population (in Gyr).
        met : float
            Metallicity [M/H].
        imf_params : dict, optional
            Parameters forwarded to ``self.imf_function``. Default is an
            empty dict.
        dmet, dteff, dlogg : float, optional
            Standard deviations of the Gaussian perturbations added to the
            metallicity, effective-temperature and surface-gravity arrays.
            NOTE(review): ``dteff`` is added to the isochrone Teff values
            as returned by ``_get_isochrone`` (presumably log10 Teff) —
            confirm the units.
        nsim : int, optional
            Number of Monte Carlo realizations (static under ``jax.jit``).
        key : jax.random.PRNGKey, optional
            Random key. Defaults to ``jax.random.PRNGKey(0)``, so results
            are deterministic unless a different key is supplied.

        Returns
        -------
        array
            Per-pixel standard deviation of the ``nsim`` synthesized
            spectra.
        """
        imf_params = imf_params or {}
        if key is None:
            # Created here rather than as a default argument so importing
            # the module does not allocate a JAX array.
            key = jr.PRNGKey(0)
        imass, iteff, ilogg, ilum, imet, imf_val = self._base_population(
            age, met, imf_params
        )

        # Independent per-star perturbations for every realization.
        k1, k2, k3 = jr.split(key, 3)
        ilogg_p = ilogg + jr.normal(k1, (nsim,) + ilogg.shape) * dlogg
        iteff_p = iteff + jr.normal(k2, (nsim,) + iteff.shape) * dteff
        imet_p = imet + jr.normal(k3, (nsim,) + imet.shape) * dmet

        return self._spectral_scatter(
            imass, iteff_p, ilogg_p, ilum, imet_p, imf_val
        )

    @partial(jax.jit, static_argnames=['self', 'nsim'])
    def synthesize_nsim_systematic(
        self,
        age,
        met,
        imf_params=None,
        dmet=0.1,
        dteff=0.005,
        dlogg=0.2,
        nsim=50,
        key=None,
    ):
        """
        Estimate SSP spectral uncertainties via Monte Carlo perturbation of
        the stellar parameters.

        In contrast to ``synthesize_nsim``, each realization shifts the
        parameters of all stars in the SSP by the same amount.

        Parameters
        ----------
        age : float
            Age of the population (in Gyr).
        met : float
            Metallicity [M/H].
        imf_params : dict, optional
            Parameters forwarded to ``self.imf_function``. Default is an
            empty dict.
        dmet, dteff, dlogg : float, optional
            Standard deviations of the Gaussian shifts applied to the
            metallicity, effective-temperature and surface-gravity arrays.
            NOTE(review): ``dteff`` is in the units ``_get_isochrone``
            returns for Teff (presumably log10 Teff) — confirm.
        nsim : int, optional
            Number of Monte Carlo realizations (static under ``jax.jit``).
        key : jax.random.PRNGKey, optional
            Random key. Defaults to ``jax.random.PRNGKey(0)``.

        Returns
        -------
        array
            Per-pixel standard deviation of the ``nsim`` synthesized
            spectra.
        """
        imf_params = imf_params or {}
        if key is None:
            key = jr.PRNGKey(0)
        imass, iteff, ilogg, ilum, imet, imf_val = self._base_population(
            age, met, imf_params
        )

        # One scalar shift per realization, broadcast over all stars.
        k1, k2, k3 = jr.split(key, 3)
        logg_shift = jr.normal(k1, (nsim,)) * dlogg
        teff_shift = jr.normal(k2, (nsim,)) * dteff
        met_shift = jr.normal(k3, (nsim,)) * dmet
        ilogg_p = ilogg + logg_shift[:, jnp.newaxis]
        iteff_p = iteff + teff_shift[:, jnp.newaxis]
        imet_p = imet + met_shift[:, jnp.newaxis]

        return self._spectral_scatter(
            imass, iteff_p, ilogg_p, ilum, imet_p, imf_val
        )

    def stellar_mass(self, age, met, imf_params=None):
        """
        Compute the stellar mass still contributing to the flux in the SSP.

        Parameters
        ----------
        age : float
            Age of the population (in Gyr).
        met : float
            Metallicity [M/H].
        imf_params : dict, optional
            Parameters forwarded to ``self.imf_function``. Default is an
            empty dict.

        Returns
        -------
        float
            Trapezoidal integral of ``imf * outmass`` over the isochrone
            initial masses (M_sun).
        """
        # Normalize ``None`` to an empty dict before it reaches the IMF.
        imf_params = imf_params or {}
        imass, _, _, _ = self._get_isochrone(age, met)
        omass = self._get_outmass(age, met)
        imf_val = self.imf_function(imass, imf_params)
        return trapezoid(imf_val * omass, x=imass)

    @partial(jax.jit, static_argnames=['self'])
    def mass_to_light_ratio(
        self, age, met, imf_params=None, filter_response=None, solar_mag=None
    ):
        """
        Compute the mass-to-light ratio of an SSP in any photometric filter.

        This function synthesizes an SSP spectrum for the given `age` and
        `met`, integrates the stellar mass from the IMF, and computes the AB
        magnitude of the integrated spectrum using the specified filter
        response. If no solar magnitude is provided (`solar_mag=None`), the
        magnitude of the Sun in the same filter is computed from the stored
        reference solar spectrum, so the M/L ratio is returned in solar
        units.

        Parameters
        ----------
        age : float
            Age of the stellar population in Gyr.
        met : float
            Metallicity [M/H] of the population.
        imf_params : dict, optional
            Parameters for the initial mass function. Default is an empty
            dict.
        filter_response : array-like or None, optional
            Response curve sampled over the wavelength grid. If None,
            ``self.filter_response`` (the default V band) is used.
        solar_mag : float or None, optional
            AB magnitude of the Sun in the same filter. If None, it is
            computed from ``self.sun_spec``.

        Returns
        -------
        dict
            - "stars" : stellar mass-to-light ratio (M*/L) in solar units.
            - "total" : total mass-to-light ratio (M_total/L), assuming a
              total mass of 1 solar mass.
        """
        # Normalize ``None`` to an empty dict before it reaches the IMF.
        imf_params = imf_params or {}
        response = (
            filter_response if filter_response is not None else self.filter_response
        )

        stellar_mass = self.stellar_mass(age, met, imf_params)
        spectrum = self.synthesize(age, met, imf_params)

        if solar_mag is None:
            m_sun = self._compute_ab_magnitudes(
                self.sun_spec[None, :], filter_response=response
            )[0]
        else:
            m_sun = solar_mag

        # AB magnitude of the integrated spectrum in the same filter.
        ab_mag = self._compute_ab_magnitudes(
            spectrum[None, :], filter_response=response
        )[0]
        luminosity = 10 ** (-0.4 * (ab_mag - m_sun))

        return {
            'stars': stellar_mass / luminosity,  # M*/L
            'total': 1.0 / luminosity,  # M_total/L
        }

    def load_precomputed_models(
        self,
        age_range=None,
        met_range=None,
        imf_range=None,
        cache_dir='ssp_cache',
        user_label=None,
    ):
        """
        Compute or load a grid of precomputed SSP spectra and cache it on disk.

        The SSP model is evaluated on a grid of age, metallicity and IMF
        parameters. If a cached HDF5 file with the expected name exists and
        its contents match the requested grid (spectra shape, and the stored
        ages/metallicities when present), it is loaded. Otherwise the grid
        is computed, written to HDF5 (overwriting any stale file) and
        returned.

        Note that the auto-generated file name only encodes the min/max of
        each range (IMF values rounded to one decimal), so two IMF grids
        with the same number of entries and bounds share a cache file.

        Parameters
        ----------
        age_range : array-like or None, optional
            SSP ages in Gyr. If None, uses ``self.iso_ages``.
        met_range : array-like or None, optional
            SSP metallicities [M/H]. If None, uses ``self.iso_mets``.
        imf_range : list of dict or None, optional
            Parameter dictionaries for the IMF function, one per IMF
            configuration. If None, defaults to ``[{}]`` (default IMF).
        cache_dir : str, optional
            Directory where SSP grids are stored. Default is "ssp_cache".
        user_label : str, optional
            Cache file name (without the ``.hdf5`` extension) to use instead
            of the auto-generated one.

        Returns
        -------
        wave : ndarray
            Wavelength grid of the synthesized SSP spectra.
        spec_grid : ndarray
            SSP spectra with shape (n_ages, n_mets, n_imfs, n_wave).

        Raises
        ------
        ValueError
            If the age or metallicity range lies outside the isochrone grid.
        """
        age_range = jnp.array(age_range if age_range is not None else self.iso_ages)
        met_range = jnp.array(met_range if met_range is not None else self.iso_mets)
        imf_range = imf_range if imf_range is not None else [{}]

        # Validate ranges. ``self.ages`` is divided by 1000 to compare with
        # ages in Gyr (presumably it is stored in Myr).
        if (jnp.min(age_range) < jnp.min(self.ages / 1000.0)) or (
            jnp.max(age_range) > jnp.max(self.ages / 1000.0)
        ):
            raise ValueError('Age range outside isochrone limits.')
        if (jnp.min(met_range) < jnp.min(self.mets)) or (
            jnp.max(met_range) > jnp.max(self.mets)
        ):
            raise ValueError('Metallicity range outside isochrone limits.')

        def _format_range(arr):
            """Return a string like '0.10-13.00' from an array's min and max."""
            return f'{np.min(arr):.2f}-{np.max(arr):.2f}'

        def _format_imf_range(imf_list):
            """Create a descriptive label for the IMF parameter grid."""
            if all(not params for params in imf_list):
                # Only default IMFs: label with the IMF function's name.
                imf_func_name = getattr(self.imf_function, '__name__', 'imf')
                return imf_func_name.lower()
            # Union of keys, so dicts that omit a parameter do not break
            # min()/max() (previously a None value raised TypeError).
            keys = sorted({key for params in imf_list for key in params})
            key_strs = []
            for key in keys:
                vals = [params[key] for params in imf_list if key in params]
                key_strs.append(f'{key}{min(vals):.1f}-{max(vals):.1f}')
            return '_'.join(key_strs)

        def _same_values(stored, requested):
            """True if a stored 1-D axis equals the requested one."""
            requested = np.asarray(requested)
            return stored.shape == requested.shape and np.allclose(stored, requested)

        na, nm, ni, nw = (
            len(age_range),
            len(met_range),
            len(imf_range),
            len(self.wave),
        )

        if user_label:
            fname = str(user_label) + '.hdf5'
        else:
            # No user label here, so nothing is appended after ``rlabel``
            # (previously the literal string 'None' was).
            fname = (
                f'sspgrid_age{_format_range(age_range)}'
                f'_met{_format_range(met_range)}'
                f'_imf{_format_imf_range(imf_range)}'
                + self.rlabel
                + '.hdf5'
            )
        cache_path = os.path.join(cache_dir, fname)

        # Load if a cache exists and actually matches the requested grid;
        # file names only encode range bounds, so they can collide.
        if os.path.exists(cache_path):
            print(f'Loading SSP grid from {cache_path}')
            with h5py.File(cache_path, 'r') as f:
                wave = f['wavelength'][()]
                spec_grid = f['spectra'][()]
                grid_matches = spec_grid.shape == (na, nm, ni, nw)
                # Older cache files lack these axes; only the shape is
                # checked for them.
                for name, requested in (('ages', age_range), ('mets', met_range)):
                    if grid_matches and name in f:
                        grid_matches = _same_values(f[name][()], requested)
            if grid_matches:
                return wave, spec_grid
            print(
                f'Cached SSP grid at {cache_path} does not match the '
                'requested grid; recomputing'
            )

        print('Calculating SSP predictions')
        if cache_dir:
            # os.makedirs('') raises, and '' means the current directory.
            os.makedirs(cache_dir, exist_ok=True)

        # Collect spectra in (age, met, imf) order and stack once, instead
        # of copying the whole grid on every ``.at[...].set`` update.
        spectra = [
            np.asarray(self.synthesize(age, met, imf_params))
            for age in age_range
            for met in met_range
            for imf_params in imf_range
        ]
        spec_grid = np.stack(spectra).reshape(na, nm, ni, nw)

        with h5py.File(cache_path, 'w') as f:
            f.create_dataset('wavelength', data=np.array(self.wave))
            f.create_dataset('spectra', data=spec_grid)
            # Stored so later loads can detect grids that share a file name.
            f.create_dataset('ages', data=np.asarray(age_range))
            f.create_dataset('mets', data=np.asarray(met_range))

        return self.wave, jnp.asarray(spec_grid)