"""
dkl.py
=======
Fully Bayesian implementation of deep kernel learning
Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com)
"""
from functools import partial
from typing import Callable, Dict, Optional, Tuple, Union, List
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import jit
from .gp import ExactGP
class DKL(ExactGP):
    """
    Fully Bayesian implementation of deep kernel learning

    Args:
        input_dim:
            Number of input dimensions
        z_dim:
            Latent space dimensionality (defaults to 2)
        kernel:
            Kernel function ('RBF', 'Matern', 'Periodic', or custom function)
        kernel_prior:
            Optional priors over kernel hyperparameters; uses LogNormal(0,1) by default
        nn:
            Custom neural network ('feature extractor'); uses a 3-layer MLP
            with hyperbolic tangent activations by default
        nn_prior:
            Priors over the weights and biases in 'nn'; by default the weights get
            standard normal priors and the biases get standard Cauchy priors
        latent_prior:
            Optional prior over the latent space (BNN embedding); uses none by default
        hidden_dim:
            Optional custom MLP architecture. For example [16, 8, 4] corresponds to a 3-layer
            neural network backbone containing 16, 8, and 4 neurons activated by tanh(). The latent
            layer is added automatically and doesn't have to be specified here. Defaults to [64, 32].
        **kwargs:
            Optional custom prior distributions over observational noise (noise_dist_prior)
            and kernel lengthscale (lengthscale_prior_dist); forwarded unchanged to ExactGP

    Examples:

        DKL with image patches as inputs and a 1-d vector as targets

        >>> # Get random number generator keys for training and prediction
        >>> key1, key2 = gpax.utils.get_keys()
        >>> # input data dimensions are (n, height*width*channels)
        >>> data_dim = X.shape[-1]
        >>> # Initialize DKL model with 2 latent dimensions
        >>> dkl = gpax.DKL(data_dim, z_dim=2, kernel='RBF')
        >>> # Train model by parallelizing HMC chains on a single GPU
        >>> dkl.fit(key1, X, y, num_warmup=333, num_samples=333, num_chains=3, chain_method='vectorized')
        >>> # Obtain posterior mean and samples from DKL posterior at new inputs
        >>> # using batches to avoid memory overflow
        >>> y_pred, y_samples = dkl.predict_in_batches(key2, X_new)
    """

    def __init__(self, input_dim: int, z_dim: int = 2, kernel: str = 'RBF',
                 kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                 nn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None,
                 nn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                 latent_prior: Optional[Callable[[jnp.ndarray], Dict[str, jnp.ndarray]]] = None,
                 hidden_dim: Optional[List[int]] = None, **kwargs
                 ) -> None:
        # The third positional argument of ExactGP.__init__ is deliberately None
        # (presumably the mean function -- confirm against ExactGP).
        super().__init__(input_dim, kernel, None, kernel_prior, **kwargs)
        hdim = hidden_dim if hidden_dim is not None else [64, 32]
        # Compare against None explicitly so that a user-supplied callable object
        # that happens to be falsy is not silently replaced by the default.
        self.nn = nn if nn is not None else get_mlp(hdim)
        self.nn_prior = nn_prior if nn_prior is not None else get_mlp_prior(input_dim, z_dim, hdim)
        # The GP kernel operates on the latent embedding, not on the raw inputs
        self.kernel_dim = z_dim
        self.latent_prior = latent_prior

    def model(self,
              X: jnp.ndarray,
              y: Optional[jnp.ndarray] = None,
              **kwargs: float
              ) -> None:
        """
        DKL probabilistic model

        Samples the neural network parameters, embeds X into the latent space,
        samples the kernel hyperparameters and noise, and registers the
        observation site "y" as a zero-mean multivariate normal whose covariance
        is the kernel evaluated on the embedded inputs.

        Args:
            X: Input features
            y: Observed targets (None when no observations are conditioned on)
            **kwargs: Optional 'jitter' passed to the kernel (defaults to 1e-6)
        """
        jitter = kwargs.get("jitter", 1e-6)
        # BNN part
        nn_params = self.nn_prior()
        z = self.nn(X, nn_params)
        if self.latent_prior is not None:  # Sample latent variable
            z = self.latent_prior(z)
        # Sample GP kernel parameters
        if self.kernel_prior:
            kernel_params = self.kernel_prior()
        else:
            kernel_params = self._sample_kernel_params()
        # Sample noise
        noise = self._sample_noise()
        # GP's mean function
        f_loc = jnp.zeros(z.shape[0])
        # compute kernel
        k = self.kernel(z, z, kernel_params, noise, jitter=jitter)
        # Sample y according to the standard Gaussian process formula
        numpyro.sample(
            "y",
            dist.MultivariateNormal(loc=f_loc, covariance_matrix=k),
            obs=y,
        )

    def get_mvn_posterior(self,
                          X_new: jnp.ndarray,
                          params: Dict[str, jnp.ndarray],
                          noiseless: bool = False,
                          **kwargs: float
                          ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Returns the mean and covariance of the multivariate normal posterior
        at new inputs for a single sample of the DKL parameters

        Args:
            X_new: New ('test') inputs
            params: One posterior sample holding the neural network parameters,
                the kernel hyperparameters and the "noise" entry
            noiseless: If True, the noise term is zeroed out in the covariance
                between the new points
            **kwargs: Extra keyword arguments forwarded to the kernel (e.g. jitter)

        Returns:
            Tuple of (posterior mean, posterior covariance)
        """
        # Local import: keeps this block self-contained with respect to the
        # file-level import list.
        from jax.scipy.linalg import cho_factor, cho_solve

        noise = params["noise"]
        noise_p = noise * (1 - jnp.array(noiseless, int))
        # embed data into the latent space
        z_train = self.nn(self.X_train, params)
        z_new = self.nn(X_new, params)
        # compute kernel matrices for train and new ('test') data
        k_pp = self.kernel(z_new, z_new, params, noise_p, **kwargs)
        k_pX = self.kernel(z_new, z_train, params, jitter=0.0)
        k_XX = self.kernel(z_train, z_train, params, noise, **kwargs)
        # compute the predictive covariance and mean.
        # One Cholesky factorization of the training covariance is reused for
        # both solves instead of forming an explicit inverse; the same matrix
        # form is already used as the covariance of the MultivariateNormal in model().
        chol = cho_factor(k_XX, lower=True)
        cov = k_pp - jnp.matmul(k_pX, cho_solve(chol, jnp.transpose(k_pX)))
        mean = jnp.matmul(k_pX, cho_solve(chol, self.y_train))
        return mean, cov

    def embed(self, X_new: jnp.ndarray) -> jnp.ndarray:
        """
        Embeds data into the latent space using the inferred weights
        of the DKL's Bayesian neural network

        Args:
            X_new: Inputs to embed

        Returns:
            Embeddings with one entry along the leading axis per posterior sample
        """
        # Fetch the samples outside of the jitted function. Reading them inside
        # a function jitted with a static `self` would bake them into the
        # compiled program, so a later re-fit of the same object would keep
        # returning embeddings computed from the old samples.
        samples = self.get_samples(chain_dim=False)
        return self._embed_with_samples(X_new, samples)

    @partial(jit, static_argnames='self')
    def _embed_with_samples(self,
                            X_new: jnp.ndarray,
                            samples: Dict[str, jnp.ndarray]
                            ) -> jnp.ndarray:
        """Applies the neural network to X_new once per posterior sample"""
        predictive = jax.vmap(lambda params: self.nn(X_new, params))
        return predictive(samples)

    def _print_summary(self):
        """Prints summary statistics for the GP kernel and noise parameters only"""
        # The neural network weights and biases are filtered out of the summary
        list_of_keys = ["k_scale", "k_length", "noise", "period"]
        samples = self.get_samples(1)
        numpyro.diagnostics.print_summary(
            {k: v for (k, v) in samples.items() if k in list_of_keys})
def sample_weights(name: str, in_channels: int, out_channels: int) -> jnp.ndarray:
    """
    Draws a weight matrix from an element-wise Normal(0, 1) prior

    Args:
        name: Name of the numpyro sample site
        in_channels: Number of rows of the weight matrix
        out_channels: Number of columns of the weight matrix

    Returns:
        Value of the sample site; the prior has shape (in_channels, out_channels)
    """
    shape = (in_channels, out_channels)
    prior = dist.Normal(loc=jnp.zeros(shape), scale=jnp.ones(shape))
    return numpyro.sample(name=name, fn=prior)
def sample_biases(name: str, channels: int) -> jnp.ndarray:
    """
    Draws a bias vector from an element-wise Cauchy(0, 1) prior

    Args:
        name: Name of the numpyro sample site
        channels: Length of the bias vector

    Returns:
        Value of the sample site; the prior has `channels` entries
    """
    location = jnp.zeros(channels)
    spread = jnp.ones(channels)
    return numpyro.sample(name=name, fn=dist.Cauchy(loc=location, scale=spread))
def get_mlp(architecture: List[int]) -> Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]:
    """
    Builds the forward pass of an MLP with one tanh layer per entry of
    `architecture`, followed by a linear output layer

    Only the number of entries in `architecture` is used here; the layer
    sizes themselves are determined by the shapes of the supplied parameters.
    """
    def mlp(X: jnp.ndarray, params: Dict[str, jnp.ndarray]) -> jnp.ndarray:
        """
        Forward pass for a single set of weights and biases

        Expects `params` to hold "w{i}" and "b{i}" for every hidden layer i
        and for the output layer, whose index equals the number of hidden layers.
        """
        n_hidden = len(architecture)
        activations = X
        layer = 0
        while layer < n_hidden:
            pre_activation = jnp.matmul(activations, params[f"w{layer}"]) + params[f"b{layer}"]
            activations = jnp.tanh(pre_activation)
            layer += 1
        # The output layer is affine only (no tanh)
        return jnp.matmul(activations, params[f"w{n_hidden}"]) + params[f"b{n_hidden}"]

    return mlp
def get_mlp_prior(input_dim: int, output_dim: int, architecture: List[int]) -> Callable[[], Dict[str, jnp.ndarray]]:
    """
    Builds a function that samples the weights and biases of a Bayesian MLP

    The returned function creates sample sites "w{i}" and "b{i}" for each
    hidden layer listed in `architecture` and for the final output layer.
    """
    def mlp_prior():
        """Samples all layers in order and returns them keyed by site name"""
        # Consecutive pairs of sizes give the (in, out) shape of every layer,
        # with the output layer last.
        sizes = [input_dim, *architecture, output_dim]
        params = {}
        for idx, (fan_in, fan_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            w_name, b_name = f"w{idx}", f"b{idx}"
            params[w_name] = sample_weights(w_name, fan_in, fan_out)
            params[b_name] = sample_biases(b_name, fan_out)
        return params

    return mlp_prior