Source code for gpax.models.spm

"""
spm.py
=======

Bayesian inference with structured probabilistic model. Serves as a wrapper
over the NumPyro's functions for Bayesian inference on probabilistic models.
While it has no direct connection to Gaussian processes, this module may come
in handy for comparisons between predicitons of GP and classical Bayesian models.

Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
"""

import warnings
from typing import Callable, Optional, Tuple, Type, Dict

import jax
import jaxlib
import jax.numpy as jnp
import jax.random as jra
from jax import vmap
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive, init_to_median

model_type = Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]
prior_type = Callable[[], Dict[str, jnp.ndarray]]


[docs]class sPM: """ Bayesian inference with structured probabilistic model. Serves as a wrapper over the NumPyro's functions for Bayesian inference on probabilistic models. Args: model: Deterministic model of expected system's behavior. model_prior: Priors over model parameters noise_prior: Optional custom prior for observation noise; uses LogNormal(0,1) by default. """ def __init__(self, model: model_type, model_prior: prior_type, noise_prior: Optional[prior_type] = None, noise_prior_dist: Optional[dist.Distribution] = None,) -> None: self._model = model self.model_prior = model_prior if noise_prior is not None: warnings.warn( "`noise_prior` is deprecated and will be removed in a future version. " "Please use `noise_prior_dist` instead, which accepts an instance of a " "numpyro.distributions Distribution object, e.g., `dist.HalfNormal(scale=0.1)`, " "rather than a function that calls `numpyro.sample`.", FutureWarning, ) self.noise_prior = noise_prior self.noise_prior_dist = noise_prior_dist self.mcmc = None
[docs] def model(self, X: jnp.ndarray, y: jnp.ndarray = None) -> None: """ Full probabilistic model """ # Sample model parameters params = self.model_prior() # Compute the function's value mu = numpyro.deterministic("mu", self._model(X, params)) # Sample observational noise if self.noise_prior: # this will be removed in the future releases sig = self.noise_prior() else: sig = self._sample_noise() # Score against the observed data points numpyro.sample("y", dist.Normal(mu, sig), obs=y)
def _sample_noise(self) -> jnp.ndarray: if self.noise_prior_dist is not None: noise_dist = self.noise_prior_dist else: noise_dist = dist.LogNormal(0, 1) return numpyro.sample("noise", noise_dist)
[docs] def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray, num_warmup: int = 2000, num_samples: int = 2000, num_chains: int = 1, chain_method: str = 'sequential', progress_bar: bool = True, print_summary: bool = True, device: Type[jaxlib.xla_client.Device] = None) -> None: """ Run HMC to infer parameters of the structured probabilistic model Args: rng_key: random number generator key X: 1D or 2D 'feature vector' with :math:`(n,)` or :math:`n x num_features` dimensions y: 1D 'target vector' with :math:`(n,)` dimensions num_warmup: number of HMC warmup states num_samples: number of HMC samples num_chains: number of HMC chains chain_method: 'sequential', 'parallel' or 'vectorized' progress_bar: show progress bar print_summary: print summary at the end of sampling device: optionally specify a cpu or gpu device on which to run the inference; e.g., ``device=jax.devices("cpu")[0]`` """ X, y = self._set_data(X, y) if device: X = jax.device_put(X, device) y = jax.device_put(y, device) init_strategy = init_to_median(num_samples=10) kernel = NUTS(self.model, init_strategy=init_strategy) self.mcmc = MCMC( kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains, chain_method=chain_method, progress_bar=progress_bar, jit_model_args=False ) self.mcmc.run(rng_key, X, y) if print_summary: self._print_summary()
[docs] def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]: """Get posterior samples (after running the MCMC chains)""" return self.mcmc.get_samples(group_by_chain=chain_dim)
[docs] def get_param_means(self): """ Returns mean value for each probabilistic parameter in the model """ samples = self.get_samples() param_means = { k: v.mean(0).item() for (k, v) in samples.items() if k != 'mu' } return param_means
[docs] def sample_from_prior(self, rng_key: jnp.ndarray, X: jnp.ndarray, num_samples: int = 10): """ Samples from prior predictive distribution at X """ prior_predictive = Predictive(self.model, num_samples=num_samples) samples = prior_predictive(rng_key, X) return samples['y']
[docs] def sample_single_posterior_predictive(self, rng_key, X_new, params, n_draws): sigma = params["noise"] loc = self._model(X_new, params) sample = dist.Normal(loc, sigma).sample(rng_key, (n_draws,)).mean(0) return loc, sample
def _vmap_predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray, samples: Optional[Dict[str, jnp.ndarray]] = None, n_draws: int = 1, ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Helper method to vectorize predictions over posterior samples """ if samples is None: samples = self.get_samples(chain_dim=False) num_samples = len(next(iter(samples.values()))) vmap_args = (jra.split(rng_key, num_samples), samples) predictive = lambda p1, p2: self.sample_single_posterior_predictive(p1, X_new, p2, n_draws) loc, f_samples = vmap(predictive)(*vmap_args) return loc, f_samples
[docs] def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray, samples: Optional[Dict[str, jnp.ndarray]] = None, n: int = 1, filter_nans: bool = False, take_point_predictions_mean: bool = True, device: Type[jaxlib.xla_client.Device] = None ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Make prediction at X_new points using posterior model parameters Args: rng_key: random number generator key X_new: 2D vector with new/'test' data of :math:`n x num_features` dimensionality samples: optional posterior samples n: number of samples to draw from normal distribution per single HMC sample filter_nans: filter out samples containing NaN values (if any) take_point_predictions_mean: take a mean of point predictions (without sampling from the normal distribution) device: optionally specify a cpu or gpu device on which to make a prediction; e.g., ```device=jax.devices("gpu")[0]``` Returns: Point predictions (or their mean) and posterior predictive distribution """ X_new = self._set_data(X_new) if samples is None: samples = self.get_samples(chain_dim=False) if device: X_new = jax.device_put(X_new, device) samples = jax.device_put(samples, device) y_pred, y_sampled = self._vmap_predict(rng_key, X_new, samples, n) if filter_nans: y_sampled_ = [y_i for y_i in y_sampled if not jnp.isnan(y_i).any()] y_sampled = jnp.array(y_sampled_) if take_point_predictions_mean: y_pred = y_pred.mean(0) return y_pred, y_sampled
def _print_summary(self): self.mcmc.print_summary() def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None, ) -> Tuple[jnp.ndarray]: if y is not None: return X, y return X