"""
mngp.py
=======
Fully Bayesian Gaussian Process model that incorporates measured noise.
Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com)
"""
from typing import Callable, Dict, Optional, Tuple, Type, Union
import jax
import jaxlib
import jax.numpy as jnp
import jax.random as jra
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, init_to_median
from .gp import ExactGP
from .vigp import viGP
from .linreg import LinReg
from ..utils import get_keys
kernel_fn_type = Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray], jnp.ndarray]
[docs]class MeasuredNoiseGP(ExactGP):
"""
Gaussian Process model that incorporates measured noise.
This class extends the ExactGP model by allowing the inclusion of measured noise variances
in the GP framework. Unlike standard GP models where noise is typically inferred, this model
uses noise values obtained from repeated measurements at the same input points.
Args:
input_dim:
Number of input dimensions
kernel:
Kernel function ('RBF', 'Matern', 'Periodic', or custom function)
mean_fn:
Optional deterministic mean function (use 'mean_fn_priors' to make it probabilistic)
kernel_prior:
Optional custom priors over kernel hyperparameters. Use it when passing your custom kernel.
mean_fn_prior:
Optional priors over mean function parameters
lengthscale_prior_dist:
Optional custom prior distribution over kernel lengthscale. Defaults to LogNormal(0, 1).
Examples:
>>> # Get random number generator keys for training and prediction
>>> key1, key2 = gpax.utils.get_keys()
>>> # Initialize model
>>> gp_model = gpax.MeasuredNoiseGP(input_dim=1, kernel='Matern')
>>> # Run HMC to obtain posterior samples for the GP model parameters
>>> gp_model.fit(key1, X, y_mean, noise) # X, y_mean, and noise have dimensions (n, 1), (n,), and (n,)
>>> # Make a prediction on new inputs by extrapolating noise variance with either linear regression or gaussian process
>>> y_pred, y_samples = gp_model.predict(key2, X_new, noise_prediction_method='linreg')
"""
def __init__(self,
input_dim: int,
kernel: Union[str, kernel_fn_type],
mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
lengthscale_prior_dist: Optional[dist.Distribution] = None
) -> None:
args = (input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior, None, None, lengthscale_prior_dist)
super(MeasuredNoiseGP, self).__init__(*args)
self.measured_noise = None
self.noise_predicted = None
[docs] def model(self, X: jnp.ndarray, y: jnp.ndarray = None, measured_noise: jnp.ndarray = None, **kwargs) -> None:
"""GP model that accepts measured noise"""
# Initialize mean function at zeros
f_loc = jnp.zeros(X.shape[0])
# Sample kernel parameters
if self.kernel_prior:
kernel_params = self.kernel_prior()
else:
kernel_params = self._sample_kernel_params()
# Since we provide a measured noise, we don't infer it
noise = numpyro.deterministic("noise", jnp.array(0.0))
# Add mean function (if any)
if self.mean_fn is not None:
args = [X]
if self.mean_fn_prior is not None:
args += [self.mean_fn_prior()]
f_loc += self.mean_fn(*args).squeeze()
# compute kernel (with zero noise)
k = self.kernel(X, X, kernel_params, 0, **kwargs)
# Sample y according to the standard Gaussian process formula. Add measured noise to the covariance matrix
numpyro.sample(
"y",
dist.MultivariateNormal(loc=f_loc, covariance_matrix=k+jnp.diag(measured_noise)),
obs=y,
)
[docs] def fit(
self,
rng_key: jnp.array,
X: jnp.ndarray,
y: jnp.ndarray,
measured_noise: jnp.ndarray,
num_warmup: int = 2000,
num_samples: int = 2000,
num_chains: int = 1,
chain_method: str = "sequential",
progress_bar: bool = True,
print_summary: bool = True,
device: Type[jaxlib.xla_client.Device] = None,
**kwargs: float
) -> None:
"""
Run Hamiltonian Monter Carlo to infer the GP parameters
Args:
rng_key: random number generator key
X: 2D feature vector
y: 1D target vector
measured_noise: 1D vector with measured noise
num_warmup: number of HMC warmup states
num_samples: number of HMC samples
num_chains: number of HMC chains
chain_method: 'sequential', 'parallel' or 'vectorized'
progress_bar: show progress bar
print_summary: print summary at the end of sampling
device:
optionally specify a cpu or gpu device on which to run the inference;
e.g., ``device=jax.devices("cpu")[0]``
**jitter:
Small positive term added to the diagonal part of a covariance
matrix for numerical stability (Default: 1e-6)
"""
X, y = self._set_data(X, y)
if device:
X = jax.device_put(X, device)
y = jax.device_put(y, device)
self.X_train = X
self.y_train = y
self.measured_noise = measured_noise
init_strategy = init_to_median(num_samples=10)
kernel = NUTS(self.model, init_strategy=init_strategy)
self.mcmc = MCMC(
kernel,
num_warmup=num_warmup,
num_samples=num_samples,
num_chains=num_chains,
chain_method=chain_method,
progress_bar=progress_bar,
jit_model_args=False,
)
self.mcmc.run(rng_key, X, y, measured_noise, **kwargs)
if print_summary:
self._print_summary()
def _predict(
self,
rng_key: jnp.ndarray,
X_new: jnp.ndarray,
params: Dict[str, jnp.ndarray],
noise_predicted: jnp.ndarray,
n: int,
noiseless: bool = False,
**kwargs: float) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Prediction with a single sample of GP parameters"""
def sigma_sample(rng_key, K, X_new_shape):
sig = jnp.sqrt(jnp.clip(jnp.diag(K), a_min=0.0))
return sig * jra.normal(rng_key, X_new_shape[:1])
# Get the predictive mean and covariance
y_mean, K = self.get_mvn_posterior(X_new, params, noiseless, **kwargs)
# Add predicted noise to K's diagonal
K += jnp.diag(noise_predicted)
# Draw samples from the posterior predictive for a given set of parameters
rng_keys = jra.split(rng_key, n)
sig = jax.vmap(sigma_sample, in_axes=(0, None, None))(rng_keys, K, X_new.shape)
y_sampled = y_mean + sig
return y_mean, y_sampled
[docs] def predict(
self,
rng_key: jnp.ndarray,
X_new: jnp.ndarray,
samples: Optional[Dict[str, jnp.ndarray]] = None,
n: int = 1,
filter_nans: bool = False,
noiseless: bool = True,
device: Type[jaxlib.xla_client.Device] = None,
noise_prediction_method: str = 'linreg',
**kwargs: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Make prediction at X_new points using posterior samples for GP parameters
Args:
rng_key: random number generator key
X_new: new inputs with *(number of points, number of features)* dimensions
samples: optional (different) samples with GP parameters
n: number of samples from Multivariate Normal posterior for each HMC sample with GP parameters
filter_nans: filter out samples containing NaN values (if any)
noiseless:
Noise-free prediction. It is set to False by default as new/unseen data is assumed
to follow the same distribution as the training data. Hence, since we introduce a model noise
by default for the training data, we also want to include that noise in our prediction.
device:
optionally specify a cpu or gpu device on which to make a prediction;
e.g., ```device=jax.devices("gpu")[0]```
noise_prediction_method:
Method for extrapolating noise variance to new/test data.
Choose between 'linreg' and 'gpreg'. Defaults to 'linreg'.
**jitter:
Small positive term added to the diagonal part of a covariance
matrix for numerical stability (Default: 1e-6)
Returns
Center of the mass of sampled means and all the sampled predictions
"""
if noise_prediction_method not in ["linreg", "gpreg"]:
raise NotImplementedError(
"For noise prediction method, select between 'linreg' and 'gpreg'")
noise_pred_fn = self.linreg if noise_prediction_method == "linreg" else self.gpreg
X_new = self._set_data(X_new)
# Predict noise for X_new
if self.noise_predicted is not None:
noise_predicted = self.noise_predicted
else:
noise_predicted = noise_pred_fn(self.X_train, self.measured_noise, X_new, **kwargs)
self.noise_predicted = noise_predicted
if samples is None:
samples = self.get_samples(chain_dim=False)
if device:
self._set_training_data(device=device)
X_new = jax.device_put(X_new, device)
samples = jax.device_put(samples, device)
num_samples = len(next(iter(samples.values())))
vmap_args = (jra.split(rng_key, num_samples), samples)
predictive = jax.vmap(lambda prms: self._predict(prms[0], X_new, prms[1], noise_predicted, n, noiseless, **kwargs))
y_means, y_sampled = predictive(vmap_args)
if filter_nans:
y_sampled_ = [y_i for y_i in y_sampled if not jnp.isnan(y_i).any()]
y_sampled = jnp.array(y_sampled_)
return y_means.mean(0), y_sampled
[docs] def linreg(self, x, y, x_new, **kwargs):
lreg = LinReg()
lreg.train(x, y, **kwargs)
return lreg.predict(x_new)
[docs] def gpreg(self, x, y, x_new, **kwargs):
keys = get_keys()
vigp = viGP(self.kernel_dim, 'RBF', **kwargs)
vigp.fit(keys[0], x, y, progress_bar=False, print_summary=False, **kwargs)
return vigp.predict(keys[1], x_new, noiseless=True)[0]