Source code for fastar.synthesis.stellar
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from functools import partial
import flax.serialization as flax_ser
import h5py
import jax
import jax.numpy as jnp
from fastar.nn.pca_regressor import PCARegressor
from fastar.tools.assets import get_asset_path
[docs]
class StellarSynthesizer:
    """
    Base class. Synthesizer of stellar spectra.

    On construction, loads a trained PCA regressor (Flax parameters) and the
    associated training artifacts (input/output scalers, PCA basis, mean
    spectrum and wavelength grid) from the package assets. Spectra are then
    predicted from (logg, Teff, [Fe/H]) with :meth:`predict_spectrum`.

    Parameters
    ----------
    model_label : {None, 'spec', 'phot'}, optional
        Which trained model to load. ``None`` (default) or ``'spec'`` selects
        the ``*_spec`` asset files; ``'phot'`` selects the ``*_phot`` ones.

    Raises
    ------
    ValueError
        If ``model_label`` is not one of the supported values.
    """

    # Maps the public model label to the suffix used in the asset file names.
    _ASSET_SUFFIXES = {None: '_spec', 'spec': '_spec', 'phot': '_phot'}

    def __init__(self, model_label=None):
        # Passed to PCARegressor as its output_dim (number of PCA coefficients).
        self.npc = 16
        self.activation_type = 'gelu'
        try:
            self.rlabel = self._ASSET_SUFFIXES[model_label]
        except (KeyError, TypeError) as err:
            # Previously an unknown label left `rlabel` unset and failed later
            # with an AttributeError inside _load_model; fail early instead.
            raise ValueError(
                f"Unsupported model_label {model_label!r}; "
                "expected None, 'spec' or 'phot'."
            ) from err
        self._load_model()

    def _load_model(self):
        """
        Load trained PCA regressor, scalers, and PCA components.

        Reads ``pca_regressor{rlabel}.flax`` (serialized Flax parameters) and
        ``training_artifacts{rlabel}.h5`` from the package assets and stores
        the model, its parameters and every artifact array on ``self``.
        """
        model = PCARegressor(output_dim=self.npc, activation_type=self.activation_type)
        with open(
            get_asset_path(f'pca_regressor{self.rlabel}.flax'),
            'rb',
        ) as pca_regressor_file:
            # from_bytes needs a target pytree with the right structure, which
            # is obtained by initializing the model on a dummy (1, 3) input.
            self.params = flax_ser.from_bytes(
                model.init(jax.random.PRNGKey(0), jnp.ones((1, 3))),
                pca_regressor_file.read(),
            )
        self.model = model
        with h5py.File(
            get_asset_path(f'training_artifacts{self.rlabel}.h5'), 'r'
        ) as training_artifacts_file:
            # `[:]` materializes each dataset in memory so it stays valid after
            # the HDF5 file is closed.
            self.scaler_x_mean = training_artifacts_file['scaler_X/mean_'][:]
            self.scaler_x_scale = training_artifacts_file['scaler_X/scale_'][:]
            self.scaler_y_mean = training_artifacts_file['scaler_Y/mean_'][:]
            self.scaler_y_scale = training_artifacts_file['scaler_Y/scale_'][:]
            self.pca_components = training_artifacts_file['pca/components_'][:]
            self.pca_mean = training_artifacts_file['pca/mean_'][:]
            self.mean_spectrum = training_artifacts_file['mean_spectrum'][:]
            self.wave = training_artifacts_file['wave'][:]

    @staticmethod
    @jax.jit
    def _softplus(input_flux, beta=100.0):
        """
        Smooth activation function with a soft floor at zero.

        Computes ``log(1 + exp(beta * x)) / beta``. This tends to ``x`` for
        large positive inputs and to zero for large negative ones, which
        prevents negative flux in the spectrum of extreme stars. ``beta``
        controls how sharp the transition is.

        It does not depend on instance state, so it is a jitted static method.
        A single compiled version is therefore shared by all instances,
        instead of one per instance as with a static ``self``.
        """
        # logaddexp(0, z) == log(1 + exp(z)), computed without overflow.
        return (1.0 / beta) * jnp.logaddexp(0.0, beta * input_flux)

    @partial(jax.jit, static_argnames=['self'])
    def predict_spectrum(self, logg, teff, fmet):
        """
        Predict stellar spectra given logg, Teff, and [Fe/H] using the
        PCA regressor.

        Parameters
        ----------
        logg, teff, fmet : array_like
            Stellar parameters. They are stacked along a new last axis, so
            they must share a shape (scalars or equal-shaped arrays).

        Returns
        -------
        jax.Array
            Non-negative spectrum, computed as ``_softplus`` of the
            reconstructed PCA spectrum.

        Notes
        -----
        ``self`` is a static argument, so JAX caches the compiled function per
        instance. Instance arrays are captured at trace time, and mutating
        them after the first call is not reflected in later calls.
        """
        inputs = jnp.stack([logg, teff, fmet], axis=-1)
        input_scaled = (inputs - self.scaler_x_mean) / self.scaler_x_scale
        pca_scaled = self.model.apply(self.params, input_scaled)
        # Undo the output scaler to recover raw PCA coefficients.
        pca_coeffs = pca_scaled * self.scaler_y_scale + self.scaler_y_mean
        # Reconstruct from the PCA basis, then add back the mean spectrum.
        spectra = (
            jnp.dot(pca_coeffs, self.pca_components) + self.pca_mean
        ) + self.mean_spectrum
        return self._softplus(spectra)