Source code for gpax.models.mtgp

from typing import Callable, Dict, Optional

import jax.numpy as jnp
import numpy as onp
import numpyro
import numpyro.distributions as dist

from .gp import ExactGP
from ..kernels import LCMKernel


[docs]class MultiTaskGP(ExactGP): """ Gaussian process for multi-task/fidelity learning Args: input_dim: Number of input dimensions data_kernel: Kernel function operating on data inputs ('RBF', 'Matern', 'Periodic', or a custom function) num_latents: Number of latent functions. Typically equal to or less than the number of tasks shared_input_space: If True, assumes that all tasks share the same input space and uses a multivariate kernel (Kronecker product). If False (default), assumes that different tasks have different number of observations and uses a multitask kernel (elementwise multiplication). In that case, the task indices must be appended as the last column of the input vector. num_tasks: Number of tasks. This is only needed if `shared_input_space` is True. rank: Rank of the weight matrix in the task kernel. Cannot be larger than the number of tasks. Higher rank implies higher correlation. Uses *(num_tasks - 1)* when not specified. mean_fn: Optional deterministic mean function (use 'mean_fn_priors' to make it probabilistic) data_kernel_prior: Optional custom priors over the data kernel hyperparameters mean_fn_prior: Optional priors over mean function parameters noise_prior_dist: Optional custom prior distribution over the observational noise variance. Defaults to LogNormal(0,1). lengthscale_prior_dist: Optional custom prior distribution over kernel lengthscale. Defaults to LogNormal(0, 1) W_prior_dist: Optional custom prior distribution over W in the task kernel, :math:`WW^T + diag(v)`. Defaults to Normal(0, 10). v_prior_dist: Optional custom prior distribution over v in the task kernel, :math:`WW^T + diag(v)`. Must be non-negative. Defaults to LogNormal(0, 1) task_kernel_prior: Optional custom priors over task kernel parameters; Defaults to Normal(0, 10) for weights W and LogNormal(0, 1) for variances v. output_scale: Option to sample data kernel's output scale. Defaults to False to avoid over-parameterization (the scale is already absorbed into task kernel) """ def __init__(self, input_dim: int, data_kernel: str, num_latents: int = None, shared_input_space: bool = False, num_tasks: int = None, rank: Optional[int] = None, mean_fn: Optional[Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]] = None, data_kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, mean_fn_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None, noise_prior_dist: Optional[dist.Distribution] = None, lengthscale_prior_dist: Optional[dist.Distribution] = None, W_prior_dist: Optional[dist.Distribution] = None, v_prior_dist: Optional[dist.Distribution] = None, output_scale: bool = False, **kwargs) -> None: args = (input_dim, None, mean_fn, None, mean_fn_prior, noise_prior) super(MultiTaskGP, self).__init__(*args) if shared_input_space: if num_tasks is None: raise ValueError("Please specify num_tasks") else: if num_latents is None: raise ValueError("Please specify num_latents") self.num_tasks = num_tasks self.num_latents = num_tasks if num_latents is None else num_latents self.rank = rank self.kernel = LCMKernel( data_kernel, shared_input_space, num_tasks, **kwargs) self.data_kernel_name = data_kernel if isinstance(data_kernel, str) else None self.data_kernel_prior = data_kernel_prior self.noise_prior = noise_prior # will be removed self.noise_prior_dist = noise_prior_dist self.lengthscale_prior_dist = lengthscale_prior_dist self.W_prior_dist = W_prior_dist self.v_prior_dist = v_prior_dist self.shared_input = shared_input_space self.output_scale = output_scale
[docs] def model(self, X: jnp.ndarray, y: jnp.ndarray = None, **kwargs: float ) -> None: """Multitask GP probabilistic model with inputs X and targets y""" # Initialize mean function at zeros if self.shared_input: f_loc = jnp.zeros(self.num_tasks * X.shape[0]) else: f_loc = jnp.zeros(X.shape[0]) # Check that we have necessary info for sampling kernel params if not self.shared_input and self.num_tasks is None: self.num_tasks = len(onp.unique(self.X_train[:, -1])) if self.rank is None: self.rank = self.num_tasks - 1 # Sample data kernel parameters if self.data_kernel_prior: data_kernel_params = self.data_kernel_prior() else: data_kernel_params = self._sample_kernel_params() # Sample task kernel parameters task_kernel_params = self._sample_task_kernel_params() # Combine two dictionaries with parameters kernel_params = {**data_kernel_params, **task_kernel_params} # Sample noise if self.noise_prior: # this will be removed in the future releases noise = self.noise_prior() else: noise = self._sample_noise() # Compute multitask_kernel k = self.kernel(X, X, kernel_params, noise, **kwargs) # Add mean function (if any) if self.mean_fn is not None: args = [X] if self.mean_fn_prior is not None: args += [self.mean_fn_prior()] f_loc += self.mean_fn(*args).squeeze() # Sample y according to the standard Gaussian process formula numpyro.sample( "y", dist.MultivariateNormal(loc=f_loc, covariance_matrix=k), obs=y, )
def _sample_noise(self): """Sample observational noise""" if self.noise_prior_dist is not None: noise_dist = self.noise_prior_dist else: noise_dist = dist.LogNormal( jnp.zeros(self.num_tasks), jnp.ones(self.num_tasks)) noise = numpyro.sample("noise", noise_dist.to_event(1)) return noise def _sample_task_kernel_params(self): """ Sample task kernel parameters with default weakly-informative priors or custom priors for all the latent functions """ if self.W_prior_dist is not None: W_dist = self.W_prior_dist else: W_dist = dist.Normal( jnp.zeros(shape=(self.num_latents, self.num_tasks, self.rank)), # loc 10*jnp.ones(shape=(self.num_latents, self.num_tasks, self.rank)) # var ) if self.v_prior_dist is not None: v_dist = self.v_prior_dist else: v_dist = dist.LogNormal( jnp.zeros(shape=(self.num_latents, self.num_tasks)), # loc jnp.ones(shape=(self.num_latents, self.num_tasks)) # var ) with numpyro.plate("latent_plate_task", self.num_latents): W = numpyro.sample("W", W_dist.to_event(2)) v = numpyro.sample("v", v_dist.to_event(1)) return {"W": W, "v": v} def _sample_kernel_params(self): """ Sample data ("base") kernel parameters with default weakly-informative priors for all the latent functions. Optionally allows to specify a custom prior over the kernel lengthscale. """ squeezer = lambda x: x.squeeze() if self.num_latents > 1 else x if self.lengthscale_prior_dist is not None: length_dist = self.lengthscale_prior_dist else: length_dist = dist.LogNormal(0.0, 1.0) with numpyro.plate("latent_plate_data", self.num_latents, dim=-2): with numpyro.plate("ard", self.kernel_dim, dim=-1): length = numpyro.sample("k_length", length_dist) if self.output_scale: scale = numpyro.sample("k_scale", dist.LogNormal(0.0, 1.0)) else: scale = numpyro.deterministic("k_scale", jnp.ones(self.num_latents)) if self.data_kernel_name == 'Periodic': period = numpyro.sample("period", dist.LogNormal(0.0, 1.0)) kernel_params = { "k_length": squeezer(length), "k_scale": squeezer(scale), "period": squeezer(period) if self.data_kernel_name == "Periodic" else None } return kernel_params