# Source code for gpax.models.hskgp

"""
hskgp.py
=========

Fully Bayesian implementation of heteroskedastic Gaussian process regression

Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com)
"""

from typing import Callable, Dict, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

from . import ExactGP
from ..kernels import get_kernel
from ..utils import _set_noise_kernel_fn

# Type alias used to annotate the `kernel` / `noise_kernel` arguments when a
# custom callable is supplied instead of a kernel name. The callable takes two
# input arrays, a dict of hyperparameters keyed by name, and a fourth
# array-typed argument (the class below passes 0 in that position --
# presumably the noise level; confirm against the kernels module), and
# returns the kernel matrix as an array.
kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]


class VarNoiseGP(ExactGP):
    """
    Heteroskedastic Gaussian process class

    The observation noise is not a single scalar: the log-variance of every
    data point is itself modelled by a second ("noise") Gaussian process, and
    the exponentiated log-variance is added to the diagonal of the main
    kernel matrix.

    Args:
        input_dim:
            Number of input dimensions
        kernel:
            Main kernel function ('RBF', 'Matern', 'Periodic', or custom function)
        noise_kernel:
            Noise kernel function ('RBF', 'Matern', 'Periodic', or custom function)
        mean_fn:
            Optional deterministic mean function (use 'mean_fn_prior' to make it probabilistic)
        kernel_prior:
            Optional custom priors over main kernel hyperparameters.
            Use it when passing your custom kernel.
        mean_fn_prior:
            Optional priors over mean function parameters
        noise_kernel_prior:
            Optional custom priors over noise kernel hyperparameters.
            Use it when passing your custom kernel.
        lengthscale_prior_dist:
            Optional custom prior distribution over main kernel lengthscale.
            Defaults to LogNormal(0, 1).
        noise_mean_fn:
            Optional noise mean function. Its output is passed through a
            logarithm before being used as the prior mean of the log-variance,
            so it must return strictly positive values.
        noise_mean_fn_prior:
            Optional priors over noise mean function
        noise_lengthscale_prior_dist:
            Optional custom prior distribution over noise kernel lengthscale.
            Defaults to LogNormal(0, 1).

    Examples:

        Use two different kernels with default priors for main and noise processes

        >>> # Get random number generator keys for training and prediction
        >>> rng_key, rng_key_predict = gpax.utils.get_keys()
        >>> # Initialize model
        >>> gp_model = gpax.VarNoiseGP(input_dim=1, kernel='RBF', noise_kernel='Matern')
        >>> # Run HMC to obtain posterior samples for the GP model parameters
        >>> gp_model.fit(rng_key, X, y)
        >>> # Make a prediction on new inputs
        >>> y_pred, y_samples = gp_model.predict(rng_key_predict, X_new)
        >>> # Get the inferred noise samples (for training data)
        >>> data_variance = gp_model.get_data_var_samples()

        Specify custom kernel lengthscale priors for main and noise kernels

        >>> lscale_prior = gpax.utils.gamma_dist(5, 1)  # equivalent to numpyro.distributions.Gamma(5, 1)
        >>> noise_lscale_prior = gpax.utils.halfnormal_dist(1)  # equivalent to numpyro.distributions.HalfNormal(1)
        >>> # Initialize model
        >>> gp_model = gpax.VarNoiseGP(
        >>>     input_dim=1, kernel='RBF', noise_kernel='Matern',
        >>>     lengthscale_prior_dist=lscale_prior, noise_lengthscale_prior_dist=noise_lscale_prior)
        >>> # Run HMC to obtain posterior samples for the GP model parameters
        >>> gp_model.fit(rng_key, X, y)
        >>> # Make a prediction on new inputs
        >>> y_pred, y_samples = gp_model.predict(rng_key_predict, X_new)
        >>> # Get the inferred noise samples (for training data)
        >>> data_variance = gp_model.get_data_var_samples()
    """

    def __init__(
        self,
        input_dim: int,
        kernel: Union[str, kernel_fn_type],
        noise_kernel: Union[str, kernel_fn_type] = 'RBF',
        mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
        kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
        mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
        noise_kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
        lengthscale_prior_dist: Optional[dist.Distribution] = None,
        noise_mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
        noise_mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
        noise_lengthscale_prior_dist: Optional[dist.Distribution] = None
    ) -> None:
        # The parent constructor is called positionally, exactly as before.
        # NOTE(review): the two ``None`` slots are presumably the parent's
        # scalar-noise prior arguments, which this class replaces with the
        # noise GP -- confirm against ExactGP.__init__.
        super().__init__(
            input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior,
            None, None, lengthscale_prior_dist
        )
        resolved_noise_kernel = get_kernel(noise_kernel)
        # Only kernels selected by name are wrapped with the noise-kernel
        # adapter; a user-supplied callable is used as returned by get_kernel.
        if isinstance(noise_kernel, str):
            self.noise_kernel = _set_noise_kernel_fn(resolved_noise_kernel)
        else:
            self.noise_kernel = resolved_noise_kernel
        self.noise_mean_fn = noise_mean_fn
        self.noise_mean_fn_prior = noise_mean_fn_prior
        self.noise_kernel_prior = noise_kernel_prior
        self.noise_lengthscale_prior_dist = noise_lengthscale_prior_dist

    def model(self, X: jnp.ndarray, y: Optional[jnp.ndarray] = None, **kwargs: float) -> None:
        """
        Heteroskedastic GP probabilistic model with inputs X and targets y

        Args:
            X: Input points; one row per data point.
            y: Observed targets. When None the "y" site is sampled rather
                than conditioned on.
            **kwargs: Forwarded to the main and noise kernel functions.
        """
        # Initialize mean functions at zeros
        f_loc = jnp.zeros(X.shape[0])
        noise_f_loc = jnp.zeros(X.shape[0])

        # Sample noise kernel parameters
        if self.noise_kernel_prior is not None:
            noise_kernel_params = self.noise_kernel_prior()
        else:
            noise_kernel_params = self._sample_noise_kernel_params()
        # Add noise prior mean function (if any). The noise GP lives in
        # log-variance space, hence the logarithm of the mean function output.
        if self.noise_mean_fn is not None:
            noise_mean_args = [X]
            if self.noise_mean_fn_prior is not None:
                noise_mean_args += [self.noise_mean_fn_prior()]
            noise_f_loc += jnp.log(self.noise_mean_fn(*noise_mean_args)).squeeze()
        # Compute noise kernel
        k_noise = self.noise_kernel(X, X, noise_kernel_params, 0, **kwargs)
        # Compute log variance of the data points. The sampled value already
        # contains the noise prior mean because it enters through ``loc``.
        points_log_var = numpyro.sample(
            "log_var",
            dist.MultivariateNormal(loc=noise_f_loc, covariance_matrix=k_noise)
        )

        # Sample main kernel parameters
        if self.kernel_prior is not None:
            kernel_params = self.kernel_prior()
        else:
            kernel_params = self._sample_kernel_params()
        # Add mean function (if any)
        if self.mean_fn is not None:
            mean_args = [X]
            if self.mean_fn_prior is not None:
                mean_args += [self.mean_fn_prior()]
            f_loc += self.mean_fn(*mean_args).squeeze()
        # Compute main kernel
        k = self.kernel(X, X, kernel_params, 0, **kwargs)

        # Sample y according to the standard Gaussian process formula. Note that
        # instead of adding a fixed noise term to the kernel as in regular GP,
        # we exponentiate and take a diagonal of the log_var samples to get the
        # variance at each data point and add that variance to the main kernel
        numpyro.sample(
            "y",
            dist.MultivariateNormal(loc=f_loc, covariance_matrix=k + jnp.diag(jnp.exp(points_log_var))),
            obs=y,
        )

    def _sample_noise_kernel_params(self) -> Dict[str, jnp.ndarray]:
        """
        Sample noise kernel parameters

        Returns:
            Dict with "k_noise_length" and "k_noise_scale" entries. The scale
            always has a LogNormal(0, 1) prior; the lengthscale uses
            ``noise_lengthscale_prior_dist`` when set, else LogNormal(0, 1).
        """
        if self.noise_lengthscale_prior_dist is not None:
            noise_length_dist = self.noise_lengthscale_prior_dist
        else:
            noise_length_dist = dist.LogNormal(0, 1)
        # Sampling order (scale first, then lengthscale) is kept as-is.
        noise_scale = numpyro.sample("k_noise_scale", dist.LogNormal(0, 1))
        noise_length = numpyro.sample("k_noise_length", noise_length_dist)
        return {"k_noise_length": noise_length, "k_noise_scale": noise_scale}

    def get_mvn_posterior(
        self, X_new: jnp.ndarray, params: Dict[str, jnp.ndarray], *args, **kwargs
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Returns parameters (mean and cov) of multivariate normal posterior
        for a single sample of heteroskedastic GP parameters

        Args:
            X_new: New input points at which to predict.
            params: One posterior sample of model parameters; must contain
                "log_var" (log-variance at the training points) in addition
                to whatever the kernels and mean functions read.
            *args: Accepted for signature compatibility; not used.
            **kwargs: Forwarded to the kernel functions.

        Returns:
            Tuple of the predictive mean of the main GP at ``X_new`` and the
            predictive covariance with the predicted noise variance added to
            its diagonal.
        """
        # Main GP part
        y_residual = self.y_train.copy()
        if self.mean_fn is not None:
            mean_args = [self.X_train, params] if self.mean_fn_prior else [self.X_train]
            y_residual -= self.mean_fn(*mean_args).squeeze()
        # Compute main kernel matrices for train and test data
        k_pp = self.kernel(X_new, X_new, params, 0, **kwargs)
        k_pX = self.kernel(X_new, self.X_train, params, jitter=0.0)
        # The likelihood in `model` is N(f_loc, K + diag(exp(log_var))), so the
        # training covariance used for conditioning must carry the same
        # per-point noise variance; omitting it makes the predictive mean
        # interpolate the noisy observations.
        train_noise_variance = jnp.exp(params["log_var"])
        k_XX = self.kernel(self.X_train, self.X_train, params, 0, **kwargs)
        k_XX = k_XX + jnp.diag(train_noise_variance)
        # Compute the predictive covariance and mean
        K_xx_inv = jnp.linalg.inv(k_XX)
        cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
        mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))
        if self.mean_fn is not None:
            mean_args = [X_new, params] if self.mean_fn_prior else [X_new]
            mean += self.mean_fn(*mean_args).squeeze()

        # Noise GP part
        # Compute noise kernel matrices
        k_pX_noise = self.noise_kernel(X_new, self.X_train, params, jitter=0.0)
        k_XX_noise = self.noise_kernel(self.X_train, self.X_train, params, 0, **kwargs)
        # Compute noise predictive mean. "log_var" includes the noise prior
        # mean, so it is removed before conditioning and re-added at X_new.
        log_var_residual = params["log_var"].copy()
        if self.noise_mean_fn is not None:
            noise_args = [self.X_train, params] if self.noise_mean_fn_prior else [self.X_train]
            log_var_residual -= jnp.log(self.noise_mean_fn(*noise_args)).squeeze()
        K_xx_noise_inv = jnp.linalg.inv(k_XX_noise)
        predicted_log_var = jnp.matmul(k_pX_noise, jnp.matmul(K_xx_noise_inv, log_var_residual))
        if self.noise_mean_fn is not None:
            noise_args = [X_new, params] if self.noise_mean_fn_prior else [X_new]
            predicted_log_var += jnp.log(self.noise_mean_fn(*noise_args)).squeeze()
        predicted_noise_variance = jnp.exp(predicted_log_var)

        # Return the main GP's predictive mean and combined (main + noise) covariance matrix
        return mean, cov + jnp.diag(predicted_noise_variance)

    def get_data_var_samples(self) -> jnp.ndarray:
        """
        Returns samples with inferred (training) data variance - aka noise

        The "log_var" site is drawn in `model` from a multivariate normal whose
        location is already log(noise_mean_fn(X)), and `model` uses
        exp(log_var) directly as the per-point variance. The variance samples
        are therefore just the exponentiated "log_var" samples; adding the
        noise mean function here again would count it twice.
        """
        samples = self.mcmc.get_samples()
        return jnp.exp(samples["log_var"])

    def _print_summary(self) -> None:
        """
        Print the numpyro summary for all sampled sites except those whose
        name contains 'log_var' (one value per training point).
        """
        # NOTE(review): the positional 1 is passed straight to the inherited
        # get_samples; presumably it keeps the chain dimension -- confirm.
        samples = self.get_samples(1)
        summary_sites = {name: value for name, value in samples.items() if 'log_var' not in name}
        numpyro.diagnostics.print_summary(summary_sites)