"""
vidkl.py
========
Variational inference-based implementation of deep kernel learning
Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com)
"""
from functools import partial
from typing import Callable, Dict, Optional, Tuple, Union
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoDelta, AutoNormal
from numpyro.contrib.module import random_haiku_module, haiku_module
from jax import jit
import haiku as hk
from .gp import ExactGP
from ..utils import get_haiku_dict
class viDKL(ExactGP):
    """
    Implementation of the variational inference-based deep kernel learning

    Args:
        input_dim:
            Input features dimensions (e.g. 64*64 for a stack of flattened 64-by-64 images)
        z_dim:
            Latent space dimensionality (defaults to 2)
        kernel:
            Kernel function ('RBF', 'Matern', 'Periodic', or custom function)
        kernel_prior:
            Optional priors over kernel hyperparameters; uses LogNormal(0,1) by default
        nn:
            Custom neural network ('feature extractor'); uses a 3-layer MLP
            with ReLU activations by default
        nn_prior:
            Places probabilistic priors over NN weights and biases (Default: True)
        latent_prior:
            Optional prior over the latent space (NN embedding); uses none by default
        guide:
            Auto-guide option, use 'delta' (default) or 'normal'
        **kwargs:
            Optional custom prior distributions over observational noise (noise_dist_prior)
            and kernel lengthscale (lengthscale_prior_dist)

    Examples:

        vi-DKL with image patches as inputs and a 1-d vector as targets

        >>> # Get random number generator keys for training and prediction
        >>> key1, key2 = gpax.utils.get_keys()
        >>> # Input data dimensions are (n, height*width*channels)
        >>> data_dim = X_train.shape[-1]
        >>> # Initialize vi-DKL model with 2 latent dimensions
        >>> dkl = gpax.viDKL(input_dim=data_dim, z_dim=2, kernel='RBF')
        >>> # Train a model
        >>> dkl.fit(key1, X_train, y_train, num_steps=1000, step_size=0.005)
        >>> # Obtain posterior mean and variance ('uncertainty') at new inputs
        >>> y_mean, y_var = dkl.predict(key2, X_new)
    """

    def __init__(self, input_dim: Union[int, Tuple[int]], z_dim: int = 2, kernel: str = 'RBF',
                 kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                 nn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, nn_prior: bool = True,
                 latent_prior: Optional[Callable[[jnp.ndarray], Dict[str, jnp.ndarray]]] = None,
                 guide: str = 'delta', **kwargs
                 ) -> None:
        super(viDKL, self).__init__(input_dim, kernel, None, kernel_prior, **kwargs)
        if guide not in ['delta', 'normal']:
            raise NotImplementedError("Select guide between 'delta' and 'normal'")
        nn_module = nn if nn else MLP
        # Haiku modules have to be instantiated inside a transformed function;
        # the network is built with z_dim outputs (the kernel's input dimension)
        self.nn_module = hk.transform(lambda x: nn_module(z_dim)(x))
        self.nn_prior = nn_prior
        self.kernel_dim = z_dim
        self.data_dim = (input_dim,) if isinstance(input_dim, int) else input_dim
        self.latent_prior = latent_prior
        self.guide_type = AutoNormal if guide == 'normal' else AutoDelta
        # Populated by fit()
        self.kernel_params = None
        self.nn_params = None

    def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs) -> None:
        """
        DKL probabilistic model

        Embeds X with the neural network, builds a GP kernel on the embedding
        and observes y under a zero-mean multivariate normal.
        Extra keyword arguments are forwarded to the kernel function.
        """
        # NN part
        if self.nn_prior:
            # Priors over NN weights: Cauchy for parameters whose name starts
            # with "b" (biases), standard Normal for everything else
            feature_extractor = random_haiku_module(
                "feature_extractor", self.nn_module, input_shape=(1, *self.data_dim),
                prior=(lambda name, shape: dist.Cauchy() if name.startswith("b") else dist.Normal()))
        else:
            # NN weights are plain (non-random) parameters
            feature_extractor = haiku_module(
                "feature_extractor", self.nn_module, input_shape=(1, *self.data_dim))
        z = feature_extractor(X)
        if self.latent_prior:  # Sample latent variable
            z = self.latent_prior(z)
        # Sample GP kernel parameters
        if self.kernel_prior:
            kernel_params = self.kernel_prior()
        else:
            kernel_params = self._sample_kernel_params()
        # Sample noise
        noise = self._sample_noise()
        # GP's mean function
        f_loc = jnp.zeros(z.shape[0])
        # compute kernel
        k = self.kernel(
            z, z,
            kernel_params,
            noise,
            **kwargs
        )
        # sample y according to the standard Gaussian process formula
        numpyro.sample(
            "y",
            dist.MultivariateNormal(loc=f_loc, covariance_matrix=k),
            obs=y,
        )

    def single_fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
                   num_steps: int = 1000, step_size: float = 5e-3,
                   print_summary: bool = True, progress_bar=True,
                   **kwargs
                   ) -> Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray], jnp.ndarray]:
        """
        Optimizes parameters of a single DKL model

        Args:
            rng_key: random number generator key
            X: Input features
            y: Target output for a single channel
            num_steps: number of SVI steps
            step_size: step size for the Adam optimizer
            print_summary: not used here; kept for signature parity with fit
            progress_bar: show the SVI progress bar
            **kwargs: forwarded to SVI as static keyword arguments of the model

        Returns:
            Tuple of (NN weights, GP kernel parameters, per-step SVI losses)
        """
        # Setup optimizer and SVI
        optim = numpyro.optim.Adam(step_size=step_size, b1=0.5)
        svi = SVI(
            self.model,
            guide=self.guide_type(self.model),
            optim=optim,
            loss=Trace_ELBO(),
            X=X,
            y=y,
            **kwargs
        )
        params, _, losses = svi.run(rng_key, num_steps, progress_bar=progress_bar)
        # Get DKL trained parameters from the guide
        if self.nn_prior:
            # NN weights are latent sites here, so they come out of the guide
            # together with the kernel hyperparameters; split them apart
            params_map = svi.guide.median(params)
            # Get NN weights
            nn_params = get_haiku_dict(params_map)
            # Get GP kernel hyperparameters
            kernel_params = {k: v for (k, v) in params_map.items()
                             if not k.startswith("feature_extractor")}
        else:
            # NN weights are registered as a plain parameter by haiku_module
            nn_params = params["feature_extractor$params"]
            # Get kernel parameters from the guide
            kernel_params = svi.guide.median(params)
        return nn_params, kernel_params, losses

    def fit(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
            num_steps: int = 1000, step_size: float = 5e-3,
            print_summary: bool = True, progress_bar=True, **kwargs) -> None:
        """
        Run stochastic variational inference to learn a DKL model(s) parameters

        Args:
            rng_key: random number generator key
            X: Input high-dimensional features
            y:
                Target output. A 1D array (samples,) fits a single model; a 2D
                array (channels, samples) fits one model per channel in
                parallel (via jax.vmap), all sharing the same inputs X
            num_steps: number of SVI steps
            step_size: step size schedule for Adam optimizer
            print_summary: print summary at the end of sampling
            progress_bar: show progress bar (works only for scalar outputs)
            **kwargs:
                Forwarded to the model (and from there to the kernel function)
                in both the single- and multi-channel regimes
        """
        self.X_train = X
        self.y_train = y
        if y.ndim == 2:  # y has shape (channels, samples): fit all channels in parallel
            def _single_fit(yi):
                return self.single_fit(
                    rng_key, X, yi, num_steps, step_size,
                    print_summary=False, progress_bar=False, **kwargs)
            vfit = jax.vmap(_single_fit)
            # Outputs gain a leading channel axis; self.loss is (channels, num_steps)
            self.nn_params, self.kernel_params, self.loss = vfit(y)
            # Poor man version of the progress bar (SVI's own bar cannot run under vmap)
            if progress_bar:
                avg_bw = [num_steps - num_steps // 20, num_steps]
                print("init loss: {}, final loss (avg) [{}-{}]: {} ".format(
                    self.loss[:, 0].mean(), avg_bw[0], avg_bw[1],
                    self.loss.mean(0)[avg_bw[0]:avg_bw[1]].mean().round(4)))
        else:  # no channel dimension so we use the regular single_fit
            self.nn_params, self.kernel_params, self.loss = self.single_fit(
                rng_key, X, y, num_steps, step_size, print_summary, progress_bar,
                **kwargs
            )
        if print_summary:
            self._print_summary()

    # Deliberately not jitted with a static `self`: this method reads
    # self.X_train / self.y_train, which would be baked into the compiled
    # function and silently go stale after a refit.
    def get_mvn_posterior(self,
                          X_new: jnp.ndarray,
                          nn_params: Dict[str, jnp.ndarray],
                          k_params: Dict[str, jnp.ndarray],
                          noiseless: bool = False,
                          y_residual: jnp.ndarray = None,
                          **kwargs
                          ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Returns predictive mean and covariance at new points
        (mean and cov, where cov.diagonal() is 'uncertainty')
        given a single set of DKL parameters

        Args:
            X_new: New inputs
            nn_params: NN weights of a single model
            k_params: GP kernel parameters of a single model (must contain "noise")
            noiseless: if True, the noise term is dropped from the test-test covariance
            y_residual: training targets to condition on (defaults to self.y_train)
            **kwargs: forwarded to the kernel for the test-test and train-train matrices
        """
        if y_residual is None:
            y_residual = self.y_train
        noise = k_params["noise"]
        noise_p = noise * (1 - jnp.array(noiseless, int))
        # embed data into the latent space
        z_train = self.nn_module.apply(
            nn_params, jax.random.PRNGKey(0), self.X_train)
        z_test = self.nn_module.apply(
            nn_params, jax.random.PRNGKey(0), X_new)
        # compute kernel matrices for train and test data
        k_pp = self.kernel(z_test, z_test, k_params, noise_p, **kwargs)
        k_pX = self.kernel(z_test, z_train, k_params, jitter=0.0)
        k_XX = self.kernel(z_train, z_train, k_params, noise, **kwargs)
        # compute the predictive covariance and mean
        K_xx_inv = jnp.linalg.inv(k_XX)
        cov = k_pp - jnp.matmul(k_pX, jnp.matmul(K_xx_inv, jnp.transpose(k_pX)))
        mean = jnp.matmul(k_pX, jnp.matmul(K_xx_inv, y_residual))
        return mean, cov

    def sample_from_posterior(self, rng_key: jnp.ndarray,
                              X_new: jnp.ndarray, n: int = 1000,
                              noiseless: bool = False,
                              **kwargs
                              ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Samples from the DKL posterior at X_new points

        Returns:
            Tuple of (predictive mean, n samples drawn from the predictive MVN)

        Raises:
            NotImplementedError: if the model was trained on multi-channel targets
        """
        if self.y_train.ndim > 1:
            raise NotImplementedError("Currently does not support a multi-channel regime")
        y_mean, K = self.get_mvn_posterior(
            X_new, self.nn_params, self.kernel_params, noiseless, **kwargs)
        y_sampled = dist.MultivariateNormal(y_mean, K).sample(rng_key, sample_shape=(n,))
        return y_mean, y_sampled

    def get_samples(self) -> Tuple[Dict[str, jnp.ndarray], Dict[str, jnp.ndarray]]:
        """Returns a tuple with trained NN weights and kernel hyperparameters"""
        return self.nn_params, self.kernel_params

    def predict_in_batches(self, rng_key: jnp.ndarray,
                           X_new: jnp.ndarray, batch_size: int = 100,
                           params: Optional[Dict[str, jnp.ndarray]] = None,
                           noiseless: bool = False,
                           **kwargs
                           ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Make prediction at X_new with sampled DKL parameters
        by splitting the input array into chunks ("batches") and running
        self.predict on each of them one-by-one to avoid a memory overflow
        """
        predict_fn = lambda xi: self.predict(
            rng_key, xi, params, noiseless=noiseless, **kwargs)
        # Multi-channel predictions carry a leading channel axis, so the
        # per-batch results are joined along the sample axis (1) instead of 0
        cat_dim = 1 if self.y_train.ndim == 2 else 0
        # NOTE(review): the parent helper is assumed to return per-batch
        # lists of means and variances, given the concatenation below
        mean, var = self._predict_in_batches(
            rng_key, X_new, batch_size, 0, params, predict_fn=predict_fn)
        mean = jnp.concatenate(mean, cat_dim)
        var = jnp.concatenate(var, cat_dim)
        return mean, var

    def predict(self, rng_key: jnp.ndarray, X_new: jnp.ndarray,
                params: Optional[Tuple[Dict[str, jnp.ndarray]]] = None,
                noiseless: bool = False, *args, **kwargs
                ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Make prediction at X_new points using a trained DKL model(s)

        Args:
            rng_key: random number generator key
            X_new: New inputs
            params: Tuple with neural network weights and kernel parameters (optional)
            noiseless:
                Noise-free prediction. It is set to False by default as new/unseen data is assumed
                to follow the same distribution as the training data. Hence, since we introduce a model noise
                for the training data, we also want to include that noise in our prediction.
            **kwargs:
                Forwarded to get_mvn_posterior (and from there to the kernel),
                consistent with sample_from_posterior

        Returns:
            Predictive mean and variance (with a leading channel axis in the
            multi-channel regime)
        """
        if params is None:
            nn_params = self.nn_params
            k_params = self.kernel_params
        else:
            nn_params, k_params = params
        if self.y_train.ndim == 2:  # y has shape (channels, samples)
            def _get_mvn_posterior(nn_params_i, k_params_i, yi):
                mean, cov = self.get_mvn_posterior(
                    X_new, nn_params_i, k_params_i, noiseless, yi, **kwargs)
                return mean, cov.diagonal()
            # vectorize posterior predictive computation over the y's channel dimension
            predictive = jax.vmap(_get_mvn_posterior)
            mean, var = predictive(nn_params, k_params, self.y_train)
        else:  # y has shape (samples,)
            mean, cov = self.get_mvn_posterior(
                X_new, nn_params, k_params, noiseless, **kwargs)
            var = cov.diagonal()
        return mean, var

    def fit_predict(self, rng_key: jnp.array, X: jnp.ndarray, y: jnp.ndarray,
                    X_new: jnp.ndarray, num_steps: int = 1000, step_size: float = 5e-3,
                    n_models: int = 1, batch_size: int = 100, noiseless: bool = False,
                    ensemble_method: str = 'vectorized',
                    print_summary: bool = True, progress_bar=True, **kwargs
                    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Run SVI to learn DKL model(s) parameters and make a prediction with
        trained model(s) on new data. Allows using an ensemble of models.

        Args:
            rng_key: random number generator key
            X: Input high-dimensional features
            y: Target output (scalar or vector)
            X_new: New ('test') data
            num_steps: number of SVI steps
            step_size: step size schedule for Adam optimizer
            n_models: number of models in the ensemble (defaults to 1)
            batch_size: prediction batch size (to avoid memory overflows)
            noiseless:
                Noise-free prediction. It is set to False by default as new/unseen data is assumed
                to follow the same distribution as the training data. Hence, since we introduce a model noise
                for the training data, we also want to include that noise in our prediction.
            ensemble_method: 'vectorized' (single GPU) or 'parallel' (multiple GPUs)
            print_summary: print summary at the end of sampling
            progress_bar: show progress bar (works only for scalar outputs)

        Returns:
            Predictive mean and variance (with a leading model axis when n_models > 1)

        Raises:
            ValueError: if n_models > 1 and ensemble_method is not recognized
        """
        def single_fit_predict(key):
            # print_summary / progress_bar are read at call time (closure),
            # so the ensemble branch below can switch them off
            self.fit(key, X, y, num_steps, step_size,
                     print_summary, progress_bar, **kwargs)
            mean, var = self.predict_in_batches(
                key, X_new, batch_size, None, noiseless, **kwargs)
            return mean, var

        if n_models > 1 and ensemble_method not in ["vectorized", "parallel"]:
            raise ValueError(
                "For the ensemble_method, select between 'vectorized' and 'parallel'.")
        keys = jax.random.split(rng_key, num=n_models)
        if n_models > 1:
            pstrategy = jax.vmap if ensemble_method == 'vectorized' else jax.pmap
            # Printing/progress bars cannot run inside vmap/pmap
            print_summary = progress_bar = 0
            mean, var = pstrategy(single_fit_predict)(keys)
        else:
            mean, var = single_fit_predict(keys[0])
        return mean, var

    def embed(self, X_new: jnp.ndarray) -> jnp.ndarray:
        """
        Use trained neural network(s) to embed the input data
        into the latent space(s)

        In the multi-channel regime (2D y_train) there is one set of NN weights
        per channel and the result gains a leading channel axis. X_new is shared
        across channels unless it already carries its own leading channel axis.
        """
        if self.y_train.ndim == 2:
            per_channel_x = X_new.ndim == len(self.data_dim) + 2
            in_axes = (0, 0) if per_channel_x else (0, None)
        else:
            in_axes = None
        # The weights are passed as a traced argument so the compiled function
        # is reused across refits without capturing stale parameters
        return self._embed(self.nn_params, X_new, in_axes=in_axes)

    @partial(jit, static_argnames=('self', 'in_axes'))
    def _embed(self, nn_params, X_new, in_axes=None):
        """Jitted NN forward pass; only self.nn_module (fixed at init) is read from self."""
        def single_embed(nnpar_i, x_i):
            return self.nn_module.apply(nnpar_i, jax.random.PRNGKey(0), x_i)
        if in_axes is None:
            return single_embed(nn_params, X_new)
        return jax.vmap(single_embed, in_axes=in_axes)(nn_params, X_new)

    def _print_summary(self) -> None:
        """Print the inferred kernel parameters (one line per channel when multi-channel)."""
        if not isinstance(self.kernel_params, dict):
            return
        print('\nInferred GP kernel parameters')
        # Same multi-channel criterion as fit/predict: values then carry a leading channel axis
        multi_channel = self.y_train.ndim == 2
        for (k, vals) in self.kernel_params.items():
            spaces = " " * (15 - len(k))
            if not multi_channel:
                print(k, spaces, jnp.around(vals, 4))
            else:
                for i, v in enumerate(vals):
                    print(k+"[{}]".format(i), spaces, jnp.around(v, 4))


class MLP(hk.Module):
    """
    Default feature extractor: a small fully connected network.

    Two hidden layers of 64 units, each followed by a ReLU, then a linear
    projection onto an ``embedim``-dimensional output.
    """

    def __init__(self, embedim=2):
        super().__init__()
        self._embedim = embedim

    def __call__(self, x):
        hidden = x
        # Layers are created in the same order as before, so Haiku's
        # auto-generated parameter names are unchanged
        for width in (64, 64):
            hidden = jax.nn.relu(hk.Linear(width)(hidden))
        return hk.Linear(self._embedim)(hidden)