"""
sparse_gp.py
============
Variational inference implementation of sparse Gaussian process regression
Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com)
"""
from typing import Callable, Dict, Optional, Tuple, Type
import jax
import jaxlib
import jax.numpy as jnp
from jax.scipy.linalg import cholesky, solve_triangular
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from .vigp import viGP
from ..utils import initialize_inducing_points
[docs]class viSparseGP(viGP):
"""
Variational inference-based sparse Gaussian process
Args:
input_dim:
Number of input dimensions
kernel:
Kernel function ('RBF', 'Matern', 'Periodic', or custom function)
mean_fn:
Optional deterministic mean function (use 'mean_fn_priors' to make it probabilistic)
kernel_prior:
Optional custom priors over kernel hyperparameters; uses LogNormal(0,1) by default
mean_fn_prior:
Optional priors over mean function parameters
noise_prior_dist:
Optional custom prior distribution over the observational noise variance.
Defaults to LogNormal(0,1).
lengthscale_prior_dist:
Optional custom prior distribution over kernel lengthscale.
Defaults to LogNormal(0, 1).
guide:
Auto-guide option, use 'delta' (default) or 'normal'
"""
def __init__(self, input_dim: int, kernel: str,
mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
noise_prior_dist: Optional[dist.Distribution] = None,
lengthscale_prior_dist: Optional[dist.Distribution] = None,
guide: str = 'delta') -> None:
args = (input_dim, kernel, mean_fn, kernel_prior, mean_fn_prior, noise_prior,
noise_prior_dist, lengthscale_prior_dist, guide)
super(viSparseGP, self).__init__(*args)
self.Xu = None
[docs] def model(self,
X: jnp.ndarray,
y: jnp.ndarray = None,
Xu: jnp.ndarray = None,
**kwargs: float) -> None:
"""
Probabilistic sparse Gaussian process regression model
"""
if Xu is not None:
Xu = numpyro.param("Xu", Xu)
# Initialize mean function at zeros
f_loc = jnp.zeros(X.shape[0])
# Sample kernel parameters
if self.kernel_prior:
kernel_params = self.kernel_prior()
else:
kernel_params = self._sample_kernel_params()
# Sample noise
if self.noise_prior: # this will be removed in the future releases
noise = self.noise_prior()
else:
noise = self._sample_noise()
D = jnp.broadcast_to(noise, (X.shape[0],) )
# Add mean function (if any)
if self.mean_fn is not None:
args = [X]
if self.mean_fn_prior is not None:
args += [self.mean_fn_prior()]
f_loc += self.mean_fn(*args).squeeze()
# Compute kernel between inducing points
Kuu = self.kernel(Xu, Xu, kernel_params, **kwargs)
# Cholesky decomposition
Luu = cholesky(Kuu).T
# Compute kernel between inducing and training points
Kuf = self.kernel(Xu, X, kernel_params)
# Solve triangular system
W = solve_triangular(Luu, Kuf, lower=True).T
# Diagonal of the kernel matrix
Kffdiag = jnp.diag(self.kernel(X, X, kernel_params, jitter=0))
# Sum of squares computation
Qffdiag = jnp.square(W).sum(axis=-1)
# Trace term computation
trace_term = (Kffdiag - Qffdiag).sum() / noise
# Clamping the trace term
trace_term = jnp.clip(trace_term, a_min=0)
# VFE approximation
numpyro.factor("trace_term", -trace_term / 2.0)
numpyro.sample(
"y",
dist.LowRankMultivariateNormal(loc=f_loc, cov_factor=W, cov_diag=D),
obs=y)
[docs] def fit(self,
rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
inducing_points_ratio: float = 0.1, inducing_points_selection: str = 'random',
num_steps: int = 1000, step_size: float = 5e-3,
progress_bar: bool = True, print_summary: bool = True,
device: Type[jaxlib.xla_client.Device] = None,
**kwargs: float
) -> None:
"""
Run variational inference to learn sparse GP (hyper)parameters
Args:
rng_key: random number generator key
X: 2D feature vector with *(number of points, number of features)* dimensions
y: 1D target vector with *(n,)* dimensions
Xu: Inducing points ratio. Must be a float between 0 and 1. Default value is 0.1.
num_steps: number of SVI steps
step_size: step size schedule for Adam optimizer
progress_bar: show progress bar
print_summary: print summary at the end of training
device:
optionally specify a cpu or gpu device on which to run the inference;
e.g., ``device=jax.devices("cpu")[0]``
**jitter:
Small positive term added to the diagonal part of a covariance
matrix for numerical stability (Default: 1e-6)
"""
X, y = self._set_data(X, y)
if device:
X = jax.device_put(X, device)
y = jax.device_put(y, device)
Xu = initialize_inducing_points(
X.copy(), inducing_points_ratio,
inducing_points_selection, rng_key)
self.X_train = X
self.y_train = y
optim = numpyro.optim.Adam(step_size=step_size, b1=0.5)
self.svi = SVI(
self.model,
guide=self.guide_type(self.model),
optim=optim,
loss=Trace_ELBO(),
X=X,
y=y,
Xu=Xu,
**kwargs
)
self.kernel_params = self.svi.run(
rng_key, num_steps, progress_bar=progress_bar)[0]
self.Xu = self.kernel_params['Xu']
if print_summary:
self._print_summary()
[docs] def get_mvn_posterior(self, X_new: jnp.ndarray,
params: Dict[str, jnp.ndarray],
noiseless: bool = False,
**kwargs: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""
Returns parameters (mean and cov) of multivariate normal posterior
for a single sample of GP parameters
"""
noise = params["noise"]
N = self.X_train.shape[0]
D = jnp.broadcast_to(noise, (N,))
noise_p = noise * (1 - jnp.array(noiseless, int))
y_residual = self.y_train.copy()
if self.mean_fn is not None:
args = [self.X_train, params] if self.mean_fn_prior else [self.X_train]
y_residual -= self.mean_fn(*args).squeeze()
# Compute self- and cross-covariance matrices
Kuu = self.kernel(self.Xu, self.Xu, params, **kwargs)
Luu = cholesky(Kuu, lower=True)
Kuf = self.kernel(self.Xu, self.X_train, params, jitter=0)
W = solve_triangular(Luu, Kuf, lower=True)
W_Dinv = W / D
K = W_Dinv @ W.T
K = K.at[jnp.diag_indices(K.shape[0])].add(1)
L = cholesky(K, lower=True)
y_2D = y_residual.reshape(-1, N).T
W_Dinv_y = W_Dinv @ y_2D
Kus = self.kernel(self.Xu, X_new, params, jitter=0)
Ws = solve_triangular(Luu, Kus, lower=True)
pack = jnp.concatenate((W_Dinv_y, Ws), axis=1)
Linv_pack = solve_triangular(L, pack, lower=True)
Linv_W_Dinv_y = Linv_pack[:, :W_Dinv_y.shape[1]]
Linv_Ws = Linv_pack[:, W_Dinv_y.shape[1]:]
mean = (Linv_W_Dinv_y.T @ Linv_Ws).squeeze()
Kss = self.kernel(X_new, X_new, params, noise_p, **kwargs)
Qss = Ws.T @ Ws
cov = Kss - Qss + Linv_Ws.T @ Linv_Ws
if self.mean_fn is not None:
args = [X_new, params] if self.mean_fn_prior else [X_new]
mean += self.mean_fn(*args).squeeze()
return mean, cov