Source code for gpax.kernels.kernels

"""
kernels.py
==========

Kernel functions

Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
"""

from typing import Union, Dict, Callable

import math

import jax.numpy as jnp
from jax import jit, vmap

# Type alias for a kernel callable: four positional inputs (two arrays, a
# dictionary of array-valued hyperparameters, one more array) mapped to an array.
kernel_fn_type = Callable[
    [jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray], jnp.ndarray],
    jnp.ndarray,
]


def _sqrt(x, eps=1e-12):
    """
    Square root of ``x`` after shifting it by a small constant ``eps``.

    Because the offset is added before the root is taken, an input of
    exactly zero yields ``sqrt(eps)`` rather than ``0``.
    NOTE(review): presumably this keeps the gradient finite at zero
    distance -- confirm against the callers.

    Args:
        x: value(s) to take the square root of
        eps: offset added to ``x`` before the square root

    Returns:
        ``sqrt(x + eps)``
    """
    shifted = x + eps
    return jnp.sqrt(shifted)


def add_jitter(x, jitter=1e-6):
    """
    Offsets ``x`` by a small constant.

    Args:
        x: value (scalar or array) to be offset
        jitter: constant added to ``x``

    Returns:
        ``x + jitter``
    """
    jittered = x + jitter
    return jittered


def square_scaled_distance(X: jnp.ndarray, Z: jnp.ndarray,
                           lengthscale: Union[jnp.ndarray, float] = 1.
                           ) -> jnp.ndarray:
    r"""
    Computes the squared scaled distance, :math:`\|\frac{X-Z}{l}\|^2`,
    between every row of X and every row of Z, where X and Z are 2D arrays
    with *(number of points, number of features)* dimensions.

    Args:
        X: 2D array whose rows are points
        Z: 2D array whose rows are points
        lengthscale: divisor applied to both X and Z before the distance
            is computed

    Returns:
        Array with one row per point in X and one column per point in Z,
        with negative entries clipped to zero
    """
    X_over_l = X / lengthscale
    Z_over_l = Z / lengthscale
    # Expand ||x - z||^2 = ||x||^2 - 2 x.z + ||z||^2 so that all pairs are
    # obtained from a single matrix product.
    x_sq_norm = jnp.sum(X_over_l ** 2, axis=1, keepdims=True)
    z_sq_norm = jnp.sum(Z_over_l ** 2, axis=1, keepdims=True)
    cross_term = X_over_l @ Z_over_l.T
    sq_dist = x_sq_norm - 2 * cross_term + z_sq_norm.T
    # Clip at zero: the expansion above can produce small negative values.
    return jnp.clip(sq_dist, 0)


@jit
def RBFKernel(X: jnp.ndarray, Z: jnp.ndarray,
              params: Dict[str, jnp.ndarray],
              noise: int = 0, jitter: float = 1e-6, **kwargs
              ) -> jnp.ndarray:
    """
    Radial basis function kernel

    Args:
        X: 2D vector with *(number of points, number of features)* dimension
        Z: 2D vector with *(number of points, number of features)* dimension
        params: Dictionary with kernel hyperparameters 'k_length' and 'k_scale'
        noise: optional noise vector with dimension (n,); it is added to the
            diagonal only when X and Z have identical shapes
        jitter: small constant added to the noise before it is placed
            on the diagonal

    Returns:
        Computed kernel matrix between X and Z
    """
    sq_dist = square_scaled_distance(X, Z, params["k_length"])
    gram = params["k_scale"] * jnp.exp(-0.5 * sq_dist)
    if X.shape != Z.shape:
        return gram
    # Same-shaped inputs: add the (noise + jitter) term on the diagonal.
    return gram + add_jitter(noise, jitter) * jnp.eye(X.shape[0])
@jit
def MaternKernel(X: jnp.ndarray, Z: jnp.ndarray,
                 params: Dict[str, jnp.ndarray],
                 noise: int = 0, jitter: float = 1e-6, **kwargs
                 ) -> jnp.ndarray:
    """
    Matern52 kernel

    Args:
        X: 2D vector with *(number of points, number of features)* dimension
        Z: 2D vector with *(number of points, number of features)* dimension
        params: Dictionary with kernel hyperparameters 'k_length' and 'k_scale'
        noise: optional noise vector with dimension (n,); it is added to the
            diagonal only when X and Z have identical shapes
        jitter: small constant added to the noise before it is placed
            on the diagonal

    Returns:
        Computed kernel matrix between X and Z
    """
    sq_dist = square_scaled_distance(X, Z, params["k_length"])
    dist = _sqrt(sq_dist)
    sqrt5_dist = 5**0.5 * dist
    # k_scale * (1 + sqrt(5) r + 5/3 r^2) * exp(-sqrt(5) r)
    polynomial = 1 + sqrt5_dist + (5/3) * sq_dist
    gram = params["k_scale"] * polynomial * jnp.exp(-sqrt5_dist)
    if X.shape != Z.shape:
        return gram
    # Same-shaped inputs: add the (noise + jitter) term on the diagonal.
    return gram + add_jitter(noise, jitter) * jnp.eye(X.shape[0])
@jit
def PeriodicKernel(X: jnp.ndarray, Z: jnp.ndarray,
                   params: Dict[str, jnp.ndarray],
                   noise: int = 0, jitter: float = 1e-6, **kwargs
                   ) -> jnp.ndarray:
    """
    Periodic kernel

    Args:
        X: 2D vector with *(number of points, number of features)* dimension
        Z: 2D vector with *(number of points, number of features)* dimension
        params: Dictionary with kernel hyperparameters 'k_length', 'k_scale',
            and 'period'
        noise: optional noise vector with dimension (n,); it is added to the
            diagonal only when X and Z have identical shapes
        jitter: small constant added to the noise before it is placed
            on the diagonal

    Returns:
        Computed kernel matrix between X and Z
    """
    # Pairwise differences between every row of X and every row of Z.
    pairwise_diff = X[:, None] - Z[None]
    phase = math.pi * pairwise_diff / params["period"]
    scaled_sin = jnp.sin(phase) / params["k_length"]
    # Squared scaled sines are summed over the last (feature) axis.
    exponent = -2 * (scaled_sin ** 2).sum(-1)
    gram = params["k_scale"] * jnp.exp(exponent)
    if X.shape != Z.shape:
        return gram
    # Same-shaped inputs: add the (noise + jitter) term on the diagonal.
    return gram + add_jitter(noise, jitter) * jnp.eye(X.shape[0])
def _nngp_propagate(x1, x2, var_b, var_w, depth, layer_update):
    """
    Pushes the input-layer covariances of a pair of inputs through
    ``depth`` layers of an NNGP recursion.

    The three covariances K(x1, x2), K(x1, x1) and K(x2, x2) are carried
    forward together, so the cost is linear in ``depth``. A naive recursive
    formulation re-derives the lower layers three times per level and grows
    as 3**depth while performing the same arithmetic.

    Args:
        x1: First input vector.
        x2: Second input vector.
        var_b: Bias variance.
        var_w: Weight variance.
        depth: Number of layer updates to apply (non-negative integer).
        layer_update: Function ``(k_12, k_11, k_22) -> k_12_next`` that
            implements a single layer of the recursion.

    Returns:
        Covariance between x1 and x2 after ``depth`` layer updates.

    Raises:
        ValueError: if ``depth`` is negative or not an integer value.
    """
    n_layers = int(depth)
    if n_layers != depth or n_layers < 0:
        raise ValueError(
            f"depth must be a non-negative integer, got {depth!r}")
    d = x1.shape[-1]

    def input_cov(a, b):
        return var_b + var_w * jnp.sum(a * b, axis=-1) / d

    k_12, k_11, k_22 = input_cov(x1, x2), input_cov(x1, x1), input_cov(x2, x2)
    for _ in range(n_layers):
        # The right-hand side is evaluated in full before assignment, so
        # every update sees the covariances of the previous layer.
        # A self-covariance depends only on itself at the previous layer.
        k_12, k_11, k_22 = (
            layer_update(k_12, k_11, k_22),
            layer_update(k_11, k_11, k_11),
            layer_update(k_22, k_22, k_22),
        )
    return k_12


def nngp_erf(x1: jnp.ndarray,
             x2: jnp.ndarray,
             var_b: jnp.ndarray,
             var_w: jnp.ndarray,
             depth: int = 3) -> jnp.ndarray:
    """
    Computes the Neural Network Gaussian Process (NNGP) kernel value for
    a single pair of inputs using the Erf activation.

    Args:
        x1: First input vector.
        x2: Second input vector.
        var_b: Bias variance.
        var_w: Weight variance.
        depth: The number of layers in the corresponding infinite-width
            neural network. Controls the number of layer updates applied
            on top of the input-layer covariance.

    Returns:
        Kernel value for the pair of inputs.

    Raises:
        ValueError: if ``depth`` is negative or not an integer value.
    """
    epsilon = 1e-7

    def erf_layer(k_12, k_11, k_22):
        sqrt_term = jnp.sqrt((1 + 2 * k_11) * (1 + 2 * k_22))
        fraction = 2 * k_12 / sqrt_term
        # Keep the arcsin argument strictly inside (-1, 1). The bounds are
        # passed positionally because the `a_min`/`a_max` keywords of
        # jnp.clip are deprecated in newer JAX releases.
        theta = jnp.arcsin(jnp.clip(fraction, -1 + epsilon, 1 - epsilon))
        return var_b + 2 * var_w / jnp.pi * theta

    return _nngp_propagate(x1, x2, var_b, var_w, depth, erf_layer)


def nngp_relu(x1: jnp.ndarray,
              x2: jnp.ndarray,
              var_b: jnp.ndarray,
              var_w: jnp.ndarray,
              depth: int = 3) -> jnp.ndarray:
    """
    Computes the Neural Network Gaussian Process (NNGP) kernel value for
    a single pair of inputs using RELU activation.

    Args:
        x1: First input vector.
        x2: Second input vector.
        var_b: Bias variance.
        var_w: Weight variance.
        depth: The number of layers in the corresponding infinite-width
            neural network. Controls the number of layer updates applied
            on top of the input-layer covariance.

    Returns:
        Kernel value for the pair of inputs.

    Raises:
        ValueError: if ``depth`` is negative or not an integer value.
    """
    eps = 1e-7

    def relu_layer(k_12, k_11, k_22):
        sqrt_term = jnp.sqrt(k_11 * k_22)
        fraction = k_12 / sqrt_term
        # Keep the arccos argument strictly inside (-1, 1); bounds are
        # positional for compatibility across JAX versions (see nngp_erf).
        theta = jnp.arccos(jnp.clip(fraction, -1 + eps, 1 - eps))
        theta_term = jnp.sin(theta) + (jnp.pi - theta) * fraction
        return var_b + var_w / (2 * jnp.pi) * sqrt_term * theta_term

    return _nngp_propagate(x1, x2, var_b, var_w, depth, relu_layer)
def NNGPKernel(activation: str = 'erf', depth: int = 3
               ) -> Callable[[jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray]:
    """
    Neural Network Gaussian Process (NNGP) kernel function

    Args:
        activation: activation function ('erf' or 'relu'); any value other
            than 'relu' selects the erf kernel
        depth: The number of layers in the corresponding infinite-width
            neural network. Controls the level of recursion in the
            computation.

    Returns:
        Function for computing kernel matrix between X and Z.
    """
    if activation == 'relu':
        pair_kernel = nngp_relu
    else:
        pair_kernel = nngp_erf

    def NNGPKernel_func(X: jnp.ndarray, Z: jnp.ndarray,
                        params: Dict[str, jnp.ndarray],
                        noise: jnp.ndarray = 0,
                        jitter: float = 1e-6,
                        **kwargs
                        ) -> jnp.ndarray:
        """
        Computes the Neural Network Gaussian Process (NNGP) kernel.

        Args:
            X: First set of input vectors.
            Z: Second set of input vectors.
            params: Dictionary containing bias variance ('var_b')
                and weight variance ('var_w')
            noise: noise term added, together with jitter, to the diagonal
                when X and Z have identical shapes
            jitter: small constant added to the noise before it is placed
                on the diagonal

        Returns:
            Computed kernel matrix between X and Z.
        """
        var_b = params["var_b"]
        var_w = params["var_w"]

        def kernel_row(x):
            # Kernel values between a single point x and every point in Z.
            return vmap(lambda z: pair_kernel(x, z, var_b, var_w, depth))(Z)

        gram = vmap(kernel_row)(X)
        if X.shape != Z.shape:
            return gram
        # Same-shaped inputs: add the (noise + jitter) term on the diagonal.
        return gram + add_jitter(noise, jitter) * jnp.eye(X.shape[0])

    return NNGPKernel_func
def get_kernel(kernel: Union[str, kernel_fn_type] = 'RBF', **kwargs):
    """
    Resolves a kernel specification into a kernel function.

    Args:
        kernel: either the name of a built-in kernel ('RBF', 'Matern',
            'Periodic' or 'NNGP') or a kernel function, which is returned
            unchanged
        **kwargs: forwarded to :func:`NNGPKernel` when ``kernel`` is 'NNGP';
            not used otherwise

    Returns:
        The kernel function.

    Raises:
        KeyError: if ``kernel`` is a string that does not name a
            built-in kernel
    """
    if not isinstance(kernel, str):
        return kernel
    # Kernels are looked up through zero-argument factories so that
    # NNGPKernel(**kwargs) is built only when 'NNGP' is actually requested;
    # building it unconditionally made unrelated kwargs raise a TypeError
    # even for the other kernels.
    kernel_book = {
        'RBF': lambda: RBFKernel,
        'Matern': lambda: MaternKernel,
        'Periodic': lambda: PeriodicKernel,
        'NNGP': lambda: NNGPKernel(**kwargs),
    }
    try:
        make_kernel = kernel_book[kernel]
    except KeyError:
        print('Select one of the currently available kernels:',
              *kernel_book.keys())
        raise
    return make_kernel()