Source code for gpax.models.vgp

"""
vgp.py
=======

Fully Bayesian implementation of Gaussian process regression for vector-valued functions

Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
"""

from functools import partial
from typing import Callable, Dict, Optional, Tuple, Union, Type

import jax
import jaxlib
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import jit

from .gp import ExactGP


[docs]class vExactGP(ExactGP): """ Gaussian process class for vector-valued targets Args: input_dim: number of input dimensions kernel: type of kernel ('RBF', 'Matern', 'Periodic') mean_fn: optional deterministic mean function (use 'mean_fn_priors' to make it probabilistic) kernel_prior: optional custom priors over kernel hyperparameters (uses LogNormal(0,1) by default) mean_fn_prior: optional priors over mean function parameters noise_prior: optional custom prior for observation noise noise_prior_dist: Optional custom prior distribution over the observational noise variance. Defaults to LogNormal(0,1). lengthscale_prior_dist: Optional custom prior distribution over kernel lengthscale. Defaults to LogNormal(0, 1). """ def __init__(self, input_dim: int, kernel: str, mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None, kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, noise_prior_dist: Optional[dist.Distribution] = None, lengthscale_prior_dist: Optional[dist.Distribution] = None ) -> None: args = (input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior, noise_prior) super(vExactGP, self).__init__(*args) self.noise_prior_dist = noise_prior_dist self.lengthscale_prior_dist = lengthscale_prior_dist
[docs] def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float ) -> None: """GP probabilistic model with inputs X and vector-valued targets y""" jitter = kwargs.get("jitter", 1e-6) task_dim = X.shape[0] # Initialize mean function at zeros f_loc = jnp.zeros(X.shape[:2]) # Sample parameters of kernels if self.kernel_prior: kernel_params = self.kernel_prior() else: kernel_params = self._sample_kernel_params(task_dim=task_dim) # Sample noise for each task if self.noise_prior: # this will be removed in the future releases noise = self.noise_prior() else: noise = self._sample_noise(task_dim) # Add mean function (if any) if self.mean_fn is not None: args = [X] if self.mean_fn_prior is not None: args += [self.mean_fn_prior()] f_loc += self.mean_fn(*args).squeeze() # Compute kernels for each task in parallel jitter = jnp.array(jitter).repeat(task_dim) k_args = (X, X, kernel_params, noise) k = jax.vmap(self.kernel)(*k_args, jitter=jitter) # Sample y according to the standard Gaussian process formula numpyro.sample( "y", dist.MultivariateNormal(loc=f_loc, covariance_matrix=k), obs=y, )
def _sample_noise(self, task_dim) -> jnp.ndarray: if self.noise_prior_dist is not None: noise_dist = self.noise_prior_dist else: noise_dist = dist.LogNormal(0, 1) with numpyro.plate("noise_plate", task_dim): noise = numpyro.sample("noise", noise_dist) return noise def _sample_kernel_params(self, task_dim: int = None) -> Dict[str, jnp.ndarray]: """ Sample kernel parameters with default weakly-informative log-normal priors """ if self.lengthscale_prior_dist is not None: length_dist = self.lengthscale_prior_dist else: length_dist = dist.LogNormal(0.0, 1.0) with numpyro.plate("plate_1", task_dim, dim=-2): # task dimension with numpyro.plate('lengthscale', self.kernel_dim, dim=-1): # allows using ARD kernel for kernel_dim > 1 length = numpyro.sample("k_length", dist.LogNormal(0.0, 1.0)) with numpyro.plate("plate_2", task_dim): # task dimension' scale = numpyro.sample("k_scale", length_dist) if self.kernel_name == 'Periodic': period = numpyro.sample("period", dist.LogNormal(0.0, 1.0)) kernel_params = { "k_length": length, "k_scale": scale, "period": period if self.kernel_name == "Periodic" else None} return kernel_params #@partial(jit, static_argnames='self') def _get_mvn_posterior(self, X_train: jnp.ndarray, y_train: jnp.ndarray, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], m_X: Optional[jnp.ndarray] = None, m_p: Optional[jnp.ndarray] = None, noiseless: bool = False, **kwargs: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: noise = params["noise"] noise_p = noise * (1 - jnp.array(noiseless, int)) y_residual = y_train if m_X is not None: y_residual -= m_X # compute kernel matrices for train and test data k_pp = self.kernel(X_new, X_new, params, noise_p, **kwargs) k_pX = self.kernel(X_new, X_train, params, jitter=0.0) k_XX = self.kernel(X_train, X_train, params, noise, **kwargs) # compute the predictive covariance and mean K_xx_inv = jnp.linalg.inv(k_XX) cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX))) mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual)) if m_p is not None: mean += m_p return mean, cov
[docs] def get_mvn_posterior(self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], noiseless: bool = False, **kwargs: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Returns parameters (mean and cov) of multivariate normal posterior for a single sample of GP parameters. Wrapper over self._get_mvn_posterior. """ task_dim = X_new.shape[0] jitter = kwargs.get("jitter", 1e-6) if self.mean_fn is not None: # Compute mean function for training data and new data get_args = lambda x: [x, params] if self.mean_fn_prior else [x] m_X = self.mean_fn(*get_args(self.X_train)).squeeze() m_p = self.mean_fn(*get_args(X_new)).squeeze() params_unsqueezed = { # ensure all params has the same 'batch dimension' for vmap to work k: p[None].repeat(X_new.shape[0], axis=0) if p.ndim == 0 else p for k, p in params.items() } vmap_args = ( self.X_train, self.y_train, X_new, params_unsqueezed, m_X, m_p) else: vmap_args = (self.X_train, self.y_train, X_new, params) noiseless = jnp.array(noiseless, int).repeat(task_dim) jitter = jnp.array(jitter).repeat(task_dim) mean, cov = jax.vmap( self._get_mvn_posterior)(*vmap_args, noiseless=noiseless, jitter=jitter) return mean, cov
[docs] def predict_in_batches(self, rng_key: jnp.ndarray, X_new: jnp.ndarray, batch_size: int = 100, samples: Optional[Dict[str, jnp.ndarray]] = None, n: int = 1, filter_nans: bool = False, predict_fn: Callable[[jnp.ndarray, int], Tuple[jnp.ndarray]] = None, noiseless: bool = False, device: Type[jaxlib.xla_client.Device] = None, **kwargs: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Make prediction at X_new with sampled GP parameters by spitting the input array into chunks ("batches") and running predict_fn (defaults to self.predict) on each of them one-by-one to avoid a memory overflow """ X_new = self._set_data(X_new) y_pred, y_sampled = self._predict_in_batches( rng_key, X_new, batch_size, 1, samples, n, filter_nans, predict_fn, noiseless, device, **kwargs) y_pred = jnp.concatenate(y_pred, -1) y_sampled = jnp.concatenate(y_sampled, -1) return y_pred, y_sampled
def _set_data(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None ) -> Union[Tuple[jnp.ndarray], jnp.ndarray]: X = X[..., None] if X.ndim == 2 else X # add feature pseudo-dimension if y is not None: if y.shape[0] != X.shape[0]: raise AssertionError( "Task dimensions must be identical in inputs and targets") return X, y return X