Source code for fastar.interpolate.isochrone

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


# =============================================================================
# Isochrone Interpolation over (age, [Fe/H])
# =============================================================================


def isochrone_interpolation(
    age,
    met,
    ages,
    mets,
    mass_ini_data,
    teff_out_data,
    logg_out_data,
    lumi_out_data,
):
    """
    Bilinearly interpolate isochrone grids at a single (age, [Fe/H]) point.

    Every track (first axis of the data grids) is sampled at the same
    fractional (age, metallicity) grid position using
    ``jax.scipy.ndimage.map_coordinates`` with ``order=1``.

    Parameters
    ----------
    age : float
        Target age in Gyr. It is multiplied by 1000 (Gyr -> Myr) before the
        lookup, so ``ages`` is expected to be expressed in Myr.
    met : float
        Target metallicity [Fe/H] value.
    ages : array-like
        1D array of age grid points (Myr, see ``age``). ``jnp.interp``
        requires these to be increasing.
    mets : array-like
        1D array of metallicity grid points. ``jnp.interp`` requires these
        to be increasing.
    mass_ini_data : array-like, shape (n_tracks, len(ages), len(mets))
        Grid of initial stellar masses.
    teff_out_data : array-like, shape (n_tracks, len(ages), len(mets))
        Grid of effective temperatures.
    logg_out_data : array-like, shape (n_tracks, len(ages), len(mets))
        Grid of surface gravities.
    lumi_out_data : array-like, shape (n_tracks, len(ages), len(mets))
        Grid of luminosities.

    Returns
    -------
    tuple of ndarray
        Interpolated 1D arrays of length ``n_tracks``:
        (mass_ini, teff_out, logg_out, lumi_out).

    Notes
    -----
    Targets outside the grid are clamped to the nearest grid edge: the
    fractional indices come from ``jnp.interp`` (which saturates at the end
    points) and sampling uses ``mode='nearest'``.
    """
    # Accept plain Python sequences as advertised by "array-like"; this is a
    # no-op for inputs that are already JAX arrays.
    ages = jnp.asarray(ages)
    mets = jnp.asarray(mets)
    mass_ini_data = jnp.asarray(mass_ini_data)
    teff_out_data = jnp.asarray(teff_out_data)
    logg_out_data = jnp.asarray(logg_out_data)
    lumi_out_data = jnp.asarray(lumi_out_data)

    # Convert age units (Gyr -> Myr) to match the age grid.
    age_myr = age * 1000

    # Fractional index positions along the age and metallicity axes.
    aidx = jnp.interp(age_myr, ages, jnp.arange(len(ages)))
    midx = jnp.interp(met, mets, jnp.arange(len(mets)))

    # Each track (first axis) is sampled exactly at its integer position,
    # while the last two axes are sampled at the fractional (aidx, midx).
    n_tracks = mass_ini_data.shape[0]
    t_idx = jnp.arange(n_tracks, dtype=float)
    age_idx = jnp.full(n_tracks, aidx)
    met_idx = jnp.full(n_tracks, midx)

    # Coordinates of shape (3, n_tracks): one (track, age, met) triple per track.
    coords = jnp.stack([t_idx, age_idx, met_idx], axis=0)

    def _sample(grid):
        # Linear interpolation; out-of-range coordinates snap to the edge.
        return map_coordinates(grid, coords, order=1, mode='nearest')

    mass_ini = _sample(mass_ini_data)
    teff_out = _sample(teff_out_data)
    logg_out = _sample(logg_out_data)
    lumi_out = _sample(lumi_out_data)

    return mass_ini, teff_out, logg_out, lumi_out