Source code for gpax.models.vi_mtdkl

"""
vi_mtdkl.py
===========

Variational inference-based implementation of multi-task deep kernel learning

Created by Maxim Ziatdinov (email: maxim.ziatdinov@gmail.com)
"""

from functools import partial
from typing import Callable, Dict, Optional, Tuple

import jax
from jax import jit
import jax.numpy as jnp
import numpy as onp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.module import random_haiku_module, haiku_module

from ..kernels import LCMKernel
from .vidkl import viDKL


class viMTDKL(viDKL):
    """
    Implementation of the variational inference-based deep kernel learning
    for multi-task/fidelity problems

    Args:
        input_dim:
            Number of input dimensions, not counting the column
            with task indices (if any)
        z_dim:
            Latent space dimensionality (defaults to 2)
        data_kernel:
            Kernel function operating on data inputs
            ('RBF', 'Matern', 'Periodic', or a custom function)
        num_latents:
            Number of latent functions. Typically equal to or less than
            the number of tasks. Must be specified when `shared_input_space`
            is False; when `shared_input_space` is True and `num_latents`
            is not given, it is set equal to `num_tasks`.
        shared_input_space:
            If True, assumes that all tasks share the same input space and
            uses a multivariate kernel (Kronecker product).
            If False (default), assumes that different tasks have different
            number of observations and uses a multitask kernel
            (elementwise multiplication). In that case, the task indices
            must be appended as the last column of the input vector.
        num_tasks:
            Number of tasks. This is only needed if `shared_input_space`
            is True. Otherwise, when not given, it is set inside
            :meth:`model` to the number of unique values in the last column
            of the stored training inputs.
        rank:
            Rank of the weight matrix in the task kernel. Cannot be larger
            than the number of tasks. Higher rank implies higher correlation.
            Uses *(num_tasks - 1)* when not specified.
        data_kernel_prior:
            Optional priors over kernel hyperparameters;
            uses LogNormal(0,1) by default
        nn:
            Custom neural network ('feature extractor');
            uses a 3-layer MLP with ReLU activations by default
        nn_prior:
            Places probabilistic priors over NN weights and biases
            (Default: True)
        guide:
            Auto-guide option, use 'delta' (default) or 'normal'
        W_prior_dist:
            Optional custom prior distribution over W in the task kernel,
            :math:`WW^T + diag(v)`. Defaults to Normal(0, 10).
        v_prior_dist:
            Optional custom prior distribution over v in the task kernel,
            :math:`WW^T + diag(v)`. Must be non-negative.
            Defaults to LogNormal(0, 1)
        task_kernel_prior:
            Optional custom priors over task kernel parameters;
            Defaults to Normal(0, 10) for weights W and
            LogNormal(0, 1) for variances v.
        **kwargs:
            Keyword arguments forwarded unchanged to both the parent
            ``viDKL`` constructor and ``LCMKernel``. Intended for optional
            custom prior distributions over observational noise and kernel
            lengthscale (lengthscale_prior_dist).
            NOTE(review): this class reads the noise prior from
            ``self.noise_prior_dist``; confirm the exact keyword name
            against the parent class.

    Raises:
        ValueError:
            If `shared_input_space` is True and `num_tasks` is None, or
            if `shared_input_space` is False and `num_latents` is None.
    """

    def __init__(self,
                 input_dim: int,
                 z_dim: int = 2,
                 data_kernel: str = 'RBF',
                 num_latents: Optional[int] = None,
                 shared_input_space: bool = False,
                 num_tasks: Optional[int] = None,
                 rank: Optional[int] = None,
                 data_kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                 nn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
                 nn_prior: bool = True,
                 guide: str = 'delta',
                 W_prior_dist: Optional[dist.Distribution] = None,
                 v_prior_dist: Optional[dist.Distribution] = None,
                 task_kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
                 **kwargs) -> None:
        # Positional arguments for the parent constructor. The None entries
        # leave the corresponding parent options unset, since this class
        # builds its own kernel and kernel priors below.
        # NOTE(review): the meaning of each slot depends on viDKL.__init__,
        # which is not visible here - confirm against its signature.
        args = (input_dim, z_dim, None, None, nn, nn_prior, None, guide)
        super().__init__(*args, **kwargs)

        if shared_input_space:
            if num_tasks is None:
                raise ValueError("Please specify num_tasks")
        elif num_latents is None:
            raise ValueError("Please specify num_latents")

        self.num_tasks = num_tasks
        # Only reachable with num_latents=None when the input space is shared
        self.num_latents = num_tasks if num_latents is None else num_latents
        self.rank = rank
        self.kernel = LCMKernel(
            data_kernel, shared_input_space, num_tasks, **kwargs)
        self.data_kernel_prior = data_kernel_prior
        self.task_kernel_prior = task_kernel_prior
        self.shared_input = shared_input_space
        self.W_prior_dist = W_prior_dist
        self.v_prior_dist = v_prior_dist

    def model(self,
              X: jnp.ndarray,
              y: Optional[jnp.ndarray] = None,
              **kwargs) -> None:
        """
        Multitask DKL probabilistic model

        Args:
            X:
                Input features. When the input space is not shared, the last
                column is treated as task indices: it is excluded from
                the neural network input and appended to the embedding.
            y:
                Observed targets passed as ``obs`` to the "y" sample site
            **kwargs:
                Extra keyword arguments forwarded to the kernel call

        Note:
            On first use this method may set ``self.num_tasks`` and
            ``self.rank`` if they were not provided to the constructor.
        """
        # Check that we have necessary info for sampling kernel params.
        # The task count is taken from the stored training inputs rather
        # than from X (presumably because X may be a traced array here).
        if not self.shared_input and self.num_tasks is None:
            self.num_tasks = len(onp.unique(self.X_train[:, -1]))
        if self.rank is None:
            self.rank = self.num_tasks - 1

        # NN part
        if self.nn_prior:  # MAP
            feature_extractor = random_haiku_module(
                "feature_extractor", self.nn_module,
                input_shape=(1, *self.data_dim),
                prior=(lambda name, shape: dist.Cauchy()
                       if name.startswith("b") else dist.Normal()))
        else:  # MLE
            feature_extractor = haiku_module(
                "feature_extractor", self.nn_module,
                input_shape=(1, *self.data_dim))

        # Embed the features; task indices (if present) bypass the network
        z = feature_extractor(X if self.shared_input else X[:, :-1])
        if not self.shared_input:
            z = jnp.column_stack((z, X[:, -1]))

        # Initialize GP kernel mean function at zeros
        if self.shared_input:
            f_loc = jnp.zeros(self.num_tasks * X.shape[0])
        else:
            f_loc = jnp.zeros(X.shape[0])

        # Sample data kernel parameters
        if self.data_kernel_prior:
            data_kernel_params = self.data_kernel_prior()
        else:
            data_kernel_params = self._sample_kernel_params()

        # Sample task kernel parameters
        if self.task_kernel_prior:
            task_kernel_params = self.task_kernel_prior()
        else:
            task_kernel_params = self._sample_task_kernel_params()

        # Combine two dictionaries with parameters
        kernel_params = {**data_kernel_params, **task_kernel_params}

        # Sample noise
        if self.noise_prior:  # this will be removed in the future releases
            noise = self.noise_prior()
        else:
            noise = self._sample_noise()

        # Compute multitask_kernel
        k = self.kernel(z, z, kernel_params, noise, **kwargs)

        # Sample y according to the standard Gaussian process formula
        numpyro.sample(
            "y",
            dist.MultivariateNormal(loc=f_loc, covariance_matrix=k),
            obs=y,
        )

    def _sample_noise(self):
        """
        Sample observational noise

        Uses ``self.noise_prior_dist`` when it is set; otherwise uses
        a LogNormal(0, 1) prior with one entry per task.
        """
        if self.noise_prior_dist is not None:
            noise_dist = self.noise_prior_dist
        else:
            noise_dist = dist.LogNormal(
                jnp.zeros(self.num_tasks), jnp.ones(self.num_tasks))
        return numpyro.sample("noise", noise_dist.to_event(1))

    def _sample_task_kernel_params(self):
        """
        Sample task kernel parameters with default weakly-informative
        priors or custom priors for all the latent functions

        Returns:
            Dictionary with the sampled "W" and "v" values
        """
        if self.W_prior_dist is not None:
            W_dist = self.W_prior_dist
        else:
            w_shape = (self.num_latents, self.num_tasks, self.rank)
            W_dist = dist.Normal(
                jnp.zeros(shape=w_shape),  # loc
                10 * jnp.ones(shape=w_shape)  # scale
            )
        if self.v_prior_dist is not None:
            v_dist = self.v_prior_dist
        else:
            v_shape = (self.num_latents, self.num_tasks)
            v_dist = dist.LogNormal(
                jnp.zeros(shape=v_shape),  # loc
                jnp.ones(shape=v_shape)  # scale
            )
        # The leading (num_latents) batch dimension is matched by the plate;
        # the remaining dimensions are declared as event dimensions.
        with numpyro.plate("latent_plate_task", self.num_latents):
            W = numpyro.sample("W", W_dist.to_event(2))
            v = numpyro.sample("v", v_dist.to_event(1))
        return {"W": W, "v": v}

    def _sample_kernel_params(self):
        """
        Sample data ("base") kernel parameters with default weakly-informative
        priors for all the latent functions

        Returns:
            Dictionary with the sampled "k_length" and "k_scale" values
        """
        def squeezer(x):
            # Size-1 axes are dropped only when there are several latents
            return x.squeeze() if self.num_latents > 1 else x

        with numpyro.plate("latent_plate_data", self.num_latents, dim=-2):
            with numpyro.plate("ard", self.kernel_dim, dim=-1):
                length = numpyro.sample("k_length", dist.LogNormal(0.0, 1.0))
            scale = numpyro.sample("k_scale", dist.Normal(1.0, 1e-4))
        return {"k_length": squeezer(length), "k_scale": squeezer(scale)}

    # NOTE: jit-compilation of this method is currently disabled:
    # @partial(jit, static_argnames='self')
    def get_mvn_posterior(self,
                          X_new: jnp.ndarray,
                          nn_params: Dict[str, jnp.ndarray],
                          k_params: Dict[str, jnp.ndarray],
                          noiseless: bool = False,
                          y_residual: Optional[jnp.ndarray] = None,
                          **kwargs
                          ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        """
        Returns predictive mean and covariance at new points
        (mean and cov, where cov.diagonal() is 'uncertainty')
        given a single set of DKL parameters

        Args:
            X_new:
                New inputs. When the input space is not shared, the last
                column is treated as task indices (as in the training data).
            nn_params:
                Parameters passed to ``self.nn_module.apply``
            k_params:
                Kernel parameters; must contain the "noise" entry
            noiseless:
                If True, the noise passed to the test-test kernel
                is multiplied by zero
            y_residual:
                Targets used to compute the predictive mean;
                defaults to ``self.y_train``
            **kwargs:
                Extra keyword arguments forwarded to the train-train
                and test-test kernel calls

        Returns:
            Predictive mean and covariance
        """
        if y_residual is None:
            y_residual = self.y_train
        noise = k_params["noise"]
        noise_p = noise * (1 - jnp.array(noiseless, int))

        # embed data into the latent space
        if self.shared_input:
            nn_train_inputs, nn_test_inputs = self.X_train, X_new
        else:
            nn_train_inputs, nn_test_inputs = self.X_train[:, :-1], X_new[:, :-1]
        z_train = self.nn_module.apply(
            nn_params, jax.random.PRNGKey(0), nn_train_inputs)
        z_test = self.nn_module.apply(
            nn_params, jax.random.PRNGKey(0), nn_test_inputs)
        if not self.shared_input:
            # re-attach the task indices that bypassed the network
            z_train = jnp.column_stack((z_train, self.X_train[:, -1]))
            z_test = jnp.column_stack((z_test, X_new[:, -1]))

        # compute kernel matrices for train and test data
        k_pp = self.kernel(z_test, z_test, k_params, noise_p, **kwargs)
        k_pX = self.kernel(z_test, z_train, k_params, jitter=0.0)
        k_XX = self.kernel(z_train, z_train, k_params, noise, **kwargs)

        # compute the predictive covariance and mean; linear solves are used
        # instead of forming the explicit inverse of k_XX, which gives
        # the same result in exact arithmetic with better numerical accuracy
        cov = k_pp - jnp.matmul(
            k_pX, jnp.linalg.solve(k_XX, jnp.transpose(k_pX)))
        mean = jnp.matmul(k_pX, jnp.linalg.solve(k_XX, y_residual))
        return mean, cov