Source code for gpax.models.gp

"""
gp.py
=======

Fully Bayesian implementation of Gaussian process regression

Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
"""

import warnings
from typing import Callable, Dict, Optional, Tuple, Type, Union

import jax
import jaxlib
import jax.numpy as jnp
import jax.random as jra
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, init_to_median, Predictive

from ..kernels import get_kernel
from ..utils import split_in_batches

kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]

clear_cache = jax._src.dispatch.xla_primitive_callable.cache_clear


[docs]class ExactGP: """ Gaussian process class Args: input_dim: Number of input dimensions kernel: Kernel function ('RBF', 'Matern', 'Periodic', or custom function) mean_fn: Optional deterministic mean function (use 'mean_fn_priors' to make it probabilistic) kernel_prior: Optional custom priors over kernel hyperparameters. Use it when passing your custom kernel. mean_fn_prior: Optional priors over mean function parameters noise_prior_dist: Optional custom prior distribution over the observational noise variance. Defaults to LogNormal(0,1). lengthscale_prior_dist: Optional custom prior distribution over kernel lengthscale. Defaults to LogNormal(0, 1). Examples: Regular GP for sparse noisy obervations >>> # Get random number generator keys for training and prediction >>> rng_key, rng_key_predict = gpax.utils.get_keys() >>> # Initialize model >>> gp_model = gpax.ExactGP(input_dim=1, kernel='Matern') >>> # Run HMC to obtain posterior samples for the GP model parameters >>> gp_model.fit(rng_key, X, y) # X and y are arrays with dimensions (n, 1) and (n,) >>> # Make a noiseless prediction on new inputs >>> y_pred, y_samples = gp_model.predict(rng_key_predict, X_new, noiseless=True) GP with custom noise prior >>> gp_model = gpax.ExactGP( >>> input_dim=1, kernel='RBF', >>> noise_prior_dist = numpyro.distributions.HalfNormal(.1) >>> ) >>> # Run HMC to obtain posterior samples for the GP model parameters >>> gp_model.fit(rng_key, X, y) # X and y are arrays with dimensions (n, 1) and (n,) >>> # Make a noiselsess prediction on new inputs >>> y_pred, y_samples = gp_model.predict(rng_key_predict, X_new, noiseless=True) GP with custom probabilistic model as its mean function >>> # Define a deterministic mean function >>> mean_fn = lambda x, param: param["a"]*x + param["b"] >>> >>> # Define priors over the mean function parameters (to make it probabilistic) >>> def mean_fn_prior(): >>> a = numpyro.sample("a", numpyro.distributions.Normal(3, 1)) >>> b = numpyro.sample("b", numpyro.distributions.Normal(0, 1)) >>> return {"a": a, "b": b} >>> >>> # Initialize structural GP model >>> sgp_model = gpax.ExactGP( input_dim=1, kernel='Matern', mean_fn=mean_fn, mean_fn_prior=mean_fn_prior) >>> # Run HMC to obtain posterior samples for the GP model parameters >>> sgp_model.fit(rng_key, X, y) # X and y are numpy arrays with dimensions (n, d) and (n,) >>> # Make a noiselsess prediction on new inputs >>> y_pred, y_samples = gp_model.predict(rng_key_predict, X_new, noiseless=True) """ def __init__( self, input_dim: int, kernel: Union[str, kernel_fn_type], mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None, kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, noise_prior_dist: Optional[dist.Distribution] = None, lengthscale_prior_dist: Optional[dist.Distribution] = None, ) -> None: clear_cache() if noise_prior is not None: warnings.warn( "`noise_prior` is deprecated and will be removed in a future version. " "Please use `noise_prior_dist` instead, which accepts an instance of a " "numpyro.distributions Distribution object, e.g., `dist.HalfNormal(scale=0.1)`, " "rather than a function that calls `numpyro.sample`.", FutureWarning, ) if kernel_prior is not None: warnings.warn( "`kernel_prior` will remain available for complex priors. However, for " "modifying only the lengthscales, it is recommended to use `lengthscale_prior_dist` instead. " "`lengthscale_prior_dist` accepts an instance of a numpyro.distributions Distribution object, " "e.g., `dist.Gamma(2, 5)`, rather than a function that calls `numpyro.sample`.", UserWarning, ) self.kernel_dim = input_dim self.kernel = get_kernel(kernel) self.kernel_name = kernel if isinstance(kernel, str) else None self.mean_fn = mean_fn self.kernel_prior = kernel_prior self.mean_fn_prior = mean_fn_prior self.noise_prior = noise_prior self.noise_prior_dist = noise_prior_dist self.lengthscale_prior_dist = lengthscale_prior_dist self.X_train = None self.y_train = None self.mcmc = None
[docs] def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float) -> None: """GP probabilistic model with inputs X and targets y""" # Initialize mean function at zeros f_loc = jnp.zeros(X.shape[0]) # Sample kernel parameters if self.kernel_prior: kernel_params = self.kernel_prior() else: kernel_params = self._sample_kernel_params() # Sample noise if self.noise_prior: # this will be removed in the future releases noise = self.noise_prior() else: noise = self._sample_noise() # Add mean function (if any) if self.mean_fn is not None: args = [X] if self.mean_fn_prior is not None: args += [self.mean_fn_prior()] f_loc += self.mean_fn(*args).squeeze() # compute kernel k = self.kernel(X, X, kernel_params, noise, **kwargs) # sample y according to the standard Gaussian process formula numpyro.sample( "y", dist.MultivariateNormal(loc=f_loc, covariance_matrix=k), obs=y, )
[docs] def fit( self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray, num_warmup: int = 2000, num_samples: int = 2000, num_chains: int = 1, chain_method: str = "sequential", progress_bar: bool = True, print_summary: bool = True, device: Type[jaxlib.xla_client.Device] = None, **kwargs: float ) -> None: """ Run Hamiltonian Monter Carlo to infer the GP parameters Args: rng_key: random number generator key X: 2D feature vector y: 1D target vector num_warmup: number of HMC warmup states num_samples: number of HMC samples num_chains: number of HMC chains chain_method: 'sequential', 'parallel' or 'vectorized' progress_bar: show progress bar print_summary: print summary at the end of sampling device: optionally specify a cpu or gpu device on which to run the inference; e.g., ``device=jax.devices("cpu")[0]`` **jitter: Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) """ X, y = self._set_data(X, y) if device: X = jax.device_put(X, device) y = jax.device_put(y, device) self.X_train = X self.y_train = y init_strategy = init_to_median(num_samples=10) kernel = NUTS(self.model, init_strategy=init_strategy) self.mcmc = MCMC( kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=num_chains, chain_method=chain_method, progress_bar=progress_bar, jit_model_args=False, ) self.mcmc.run(rng_key, X, y, **kwargs) if print_summary: self._print_summary()
def _sample_noise(self) -> jnp.ndarray: if self.noise_prior_dist is not None: noise_dist = self.noise_prior_dist else: noise_dist = dist.LogNormal(0, 1) return numpyro.sample("noise", noise_dist) def _sample_kernel_params(self, output_scale=True) -> Dict[str, jnp.ndarray]: """ Sample kernel parameters with default weakly-informative log-normal priors """ if self.lengthscale_prior_dist is not None: length_dist = self.lengthscale_prior_dist else: length_dist = dist.LogNormal(0.0, 1.0) with numpyro.plate("ard", self.kernel_dim): # allows using ARD kernel for kernel_dim > 1 length = numpyro.sample("k_length", length_dist) if output_scale: scale = numpyro.sample("k_scale", dist.LogNormal(0.0, 1.0)) else: scale = numpyro.deterministic("k_scale", jnp.array(1.0)) if self.kernel_name == "Periodic": period = numpyro.sample("period", dist.LogNormal(0.0, 1.0)) kernel_params = {"k_length": length, "k_scale": scale, "period": period if self.kernel_name == "Periodic" else None} return kernel_params
[docs] def get_samples(self, chain_dim: bool = False) -> Dict[str, jnp.ndarray]: """Get posterior samples (after running the MCMC chains)""" return self.mcmc.get_samples(group_by_chain=chain_dim)
[docs] def get_mvn_posterior( self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Returns parameters (mean and cov) of multivariate normal posterior for a single sample of GP parameters """ noise = params["noise"] noise_p = noise * (1 - jnp.array(noiseless, int)) y_residual = self.y_train.copy() if self.mean_fn is not None: args = [self.X_train, params] if self.mean_fn_prior else [self.X_train] y_residual -= self.mean_fn(*args).squeeze() # compute kernel matrices for train and test data k_pp = self.kernel(X_new, X_new, params, noise_p, **kwargs) k_pX = self.kernel(X_new, self.X_train, params, jitter=0.0) k_XX = self.kernel(self.X_train, self.X_train, params, noise, **kwargs) # compute the predictive covariance and mean K_xx_inv = jnp.linalg.inv(k_XX) cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX))) mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual)) if self.mean_fn is not None: args = [X_new, params] if self.mean_fn_prior else [X_new] mean += self.mean_fn(*args).squeeze() return mean, cov
def _predict( self, rng_key: jnp.ndarray, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], n: int, noiseless: bool = False, **kwargs: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Prediction with a single sample of GP parameters""" # Get the predictive mean and covariance y_mean, K = self.get_mvn_posterior(X_new, params, noiseless, **kwargs) # draw samples from the posterior predictive for a given set of parameters y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,)) return y_mean, y_sampled def _predict_in_batches( self, rng_key: jnp.ndarray, X_new: jnp.ndarray, batch_size: int = 100, batch_dim: int = 0, samples: Optional[Dict[str, jnp.ndarray]] = None, n: int = 1, filter_nans: bool = False, predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None, noiseless: bool = False, device: Type[jaxlib.xla_client.Device] = None, **kwargs: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: if predict_fn is None: predict_fn = lambda xi: self.predict(rng_key, xi, samples, n, filter_nans, noiseless, device, **kwargs) def predict_batch(Xi): out1, out2 = predict_fn(Xi) out1 = jax.device_put(out1, jax.devices("cpu")[0]) out2 = jax.device_put(out2, jax.devices("cpu")[0]) return out1, out2 y_out1, y_out2 = [], [] for Xi in split_in_batches(X_new, batch_size, dim=batch_dim): out1, out2 = predict_batch(Xi) y_out1.append(out1) y_out2.append(out2) return y_out1, y_out2
[docs] def predict_in_batches( self, rng_key: jnp.ndarray, X_new: jnp.ndarray, batch_size: int = 100, samples: Optional[Dict[str, jnp.ndarray]] = None, n: int = 1, filter_nans: bool = False, predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None, noiseless: bool = False, device: Type[jaxlib.xla_client.Device] = None, **kwargs: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Make prediction at X_new with sampled GP parameters by spitting the input array into chunks ("batches") and running predict_fn (defaults to self.predict) on each of them one-by-one to avoid a memory overflow """ y_pred, y_sampled = self._predict_in_batches( rng_key, X_new, batch_size, 0, samples, n, filter_nans, predict_fn, noiseless, device, **kwargs ) y_pred = jnp.concatenate(y_pred, 0) y_sampled = jnp.concatenate(y_sampled, -1) return y_pred, y_sampled
[docs] def predict( self, rng_key: jnp.ndarray, X_new: jnp.ndarray, samples: Optional[Dict[str, jnp.ndarray]] = None, n: int = 1, filter_nans: bool = False, noiseless: bool = False, device: Type[jaxlib.xla_client.Device] = None, **kwargs: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Make prediction at X_new points using posterior samples for GP parameters Args: rng_key: random number generator key X_new: new inputs with *(number of points, number of features)* dimensions samples: optional (different) samples with GP parameters n: number of samples from Multivariate Normal posterior for each HMC sample with GP parameters filter_nans: filter out samples containing NaN values (if any) noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed to follow the same distribution as the training data. Hence, since we introduce a model noise by default for the training data, we also want to include that noise in our prediction. device: optionally specify a cpu or gpu device on which to make a prediction; e.g., ```device=jax.devices("gpu")[0]``` **jitter: Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) Returns Center of the mass of sampled means and all the sampled predictions """ X_new = self._set_data(X_new) if samples is None: samples = self.get_samples(chain_dim=False) if device: self._set_training_data(device=device) X_new = jax.device_put(X_new, device) samples = jax.device_put(samples, device) num_samples = len(next(iter(samples.values()))) vmap_args = (jra.split(rng_key, num_samples), samples) predictive = jax.vmap(lambda prms: self._predict(prms[0], X_new, prms[1], n, noiseless, **kwargs)) y_means, y_sampled = predictive(vmap_args) if filter_nans: y_sampled_ = [y_i for y_i in y_sampled if not jnp.isnan(y_i).any()] y_sampled = jnp.array(y_sampled_) return y_means.mean(0), y_sampled
[docs] def sample_from_prior(self, rng_key: jnp.ndarray, X: jnp.ndarray, num_samples: int = 10): """ Samples from prior predictive distribution at X """ X = self._set_data(X) prior_predictive = Predictive(self.model, num_samples=num_samples) samples = prior_predictive(rng_key, X) return samples["y"]
def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None) -> Union[Tuple[jnp.ndarray], jnp.ndarray]: X = X if X.ndim > 1 else X[:, None] if y is not None: return X, y.squeeze() return X def _set_training_data( self, X_train_new: jnp.ndarray = None, y_train_new: jnp.ndarray = None, device: Type[jaxlib.xla_client.Device] = None, ) -> None: X_train = self.X_train if X_train_new is None else X_train_new y_train = self.y_train if y_train_new is None else y_train_new if device: X_train = jax.device_put(X_train, device) y_train = jax.device_put(y_train, device) self.X_train = X_train self.y_train = y_train def _print_summary(self): samples = self.get_samples(1) numpyro.diagnostics.print_summary(samples)