Source code for gpax.acquisition.acquisition

"""
acquisition.py
==============

Acquisition functions

Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
"""

from typing import Type, Optional, Tuple

import jax.numpy as jnp
import jax.random as jra
from jax import vmap
import numpy as onp

from ..models.gp import ExactGP
from .base_acq import ei, ucb, poi, ue, kg
from .penalties import compute_penalty


def _compute_mean_and_var(
        rng_key: jnp.ndarray, model: Type[ExactGP], X: jnp.ndarray,
        n: int, noiseless: bool, **kwargs) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Computes predictive mean and variance
    """
    if model.mcmc is not None:
        _, y_sampled = model.predict(
            rng_key, X, n=n, noiseless=noiseless, **kwargs)
        y_sampled = y_sampled.reshape(n * y_sampled.shape[0], -1)
        mean, var = y_sampled.mean(0), y_sampled.var(0)
    else:
        mean, var = model.predict(rng_key, X, noiseless=noiseless, **kwargs)
    return mean, var


def _compute_penalties(
        X: jnp.ndarray, recent_points: jnp.ndarray, penalty: str,
        penalty_factor: float, grid_indices: jnp.ndarray) -> jnp.ndarray:
    """
    Computes penaltes for recent points to be substracted
    from acqusition function values
    """
    X_ = grid_indices if grid_indices is not None else X
    return compute_penalty(X_, recent_points, penalty, penalty_factor)


[docs]def EI(rng_key: jnp.ndarray, model: Type[ExactGP], X: jnp.ndarray, best_f: float = None, maximize: bool = False, n: int = 1, noiseless: bool = False, penalty: Optional[str] = None, recent_points: jnp.ndarray = None, grid_indices: jnp.ndarray = None, penalty_factor: float = 1.0, **kwargs) -> jnp.ndarray: r""" Expected Improvement Given a probabilistic model :math:`m` that models the objective function :math:`f`, the Expected Improvement at an input point :math:`x` is defined as: .. math:: EI(x) = \begin{cases} (\mu(x) - f^+) \Phi(Z) + \sigma(x) \phi(Z) & \text{if } \sigma(x) > 0 \\ 0 & \text{if } \sigma(x) = 0 \end{cases} where :math:`\mu(x)` is the predictive mean, :math:`\sigma(x)` is the predictive standard deviation, :math:`f^+` is the value of the best observed sample. :math:`Z` is defined as: .. math:: Z = \frac{\mu(x) - f^+}{\sigma(x)} provided :math:`\sigma(x) > 0`. In the case of HMC, the function leverages multiple predictive posteriors, each associated with a different HMC sample of the GP model parameters, to capture both prediction uncertainty and hyperparameter uncertainty. In this setup, the uncertainty in parameters of probabilistic mean function (if any) also contributes to the acquisition function values. Args: rng_key: JAX random number generator key model: trained model X: new inputs best_f: Best function value observed so far. Derived from the predictive mean when not provided by a user. maximize: If True, assumes that BO is solving maximization problem n: number of samples drawn from each MVN distribution (number of distributions is equal to the number of HMC samples) noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed to follow the same distribution as the training data. Hence, since we introduce a model noise for the training data, we also want to include that noise in our prediction. penalty: Penalty applied to the acquisition function to discourage re-evaluation at or near points that were recently evaluated. Options are: - 'delta': The infinite penalty is applied to the recently visited points. - 'inverse_distance': Modifies the acquisition function by penalizing points near the recent points. For the 'inverse_distance', the acqusition function is penalized as: .. math:: \alpha - \lambda \cdot \pi(X, r) where :math:`\pi(X, r)` computes a penalty for points in :math:`X` based on their distance to recent points :math:`r`, :math:`\alpha` represents the acquisition function, and :math:`\lambda` represents the penalty factor. recent_points: An array of recently visited points [oldest, ..., newest] provided by user grid_indices: Grid indices of data points in X array for the penalty term calculation. For example, if each data point is an image patch, the indices could correspond to the (i, j) pixel coordinates of their centers in the original image. penalty_factor: Penalty factor :math:`\lambda` in :math:`\alpha - \lambda \cdot \pi(X, r)` **jitter: Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) """ if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): raise ValueError("Please provide an array of recently visited points") X = X[:, None] if X.ndim < 2 else X moments = _compute_mean_and_var(rng_key, model, X, n, noiseless, **kwargs) acq = ei(moments, best_f, maximize) if penalty: acq -= _compute_penalties(X, recent_points, penalty, penalty_factor, grid_indices) return acq
[docs]def UCB(rng_key: jnp.ndarray, model: Type[ExactGP], X: jnp.ndarray, beta: float = .25, maximize: bool = False, n: int = 1, noiseless: bool = False, penalty: Optional[str] = None, recent_points: jnp.ndarray = None, grid_indices: jnp.ndarray = None, penalty_factor: float = 1.0, **kwargs) -> jnp.ndarray: r""" Upper confidence bound Given a probabilistic model :math:`m` that models the objective function :math:`f`, the Upper Confidence Bound at an input point :math:`x` is defined as: .. math:: UCB(x) = \mu(x) + \kappa \sigma(x) where :math:`\mu(x)` is the predictive mean, :math:`\sigma(x)` is the predictive standard deviation, and :math:`\kappa` is the exploration-exploitation trade-off parameter. In the case of HMC, the function leverages multiple predictive posteriors, each associated with a different HMC sample of the GP model parameters, to capture both prediction uncertainty and hyperparameter uncertainty. In this setup, the uncertainty in parameters of probabilistic mean function (if any) also contributes to the acquisition function values. Args: rng_key: JAX random number generator key model: trained model X: new inputs beta: coefficient balancing exploration-exploitation trade-off maximize: If True, assumes that BO is solving maximization problem n: number of samples drawn from each MVN distribution (number of distributions is equal to the number of HMC samples) noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed to follow the same distribution as the training data. Hence, since we introduce a model noise for the training data, we also want to include that noise in our prediction. penalty: Penalty applied to the acquisition function to discourage re-evaluation at or near points that were recently evaluated. Options are: - 'delta': The infinite penalty is applied to the recently visited points. - 'inverse_distance': Modifies the acquisition function by penalizing points near the recent points. For the 'inverse_distance', the acqusition function is penalized as: .. math:: \alpha - \lambda \cdot \pi(X, r) where :math:`\pi(X, r)` computes a penalty for points in :math:`X` based on their distance to recent points :math:`r`, :math:`\alpha` represents the acquisition function, and :math:`\lambda` represents the penalty factor. recent_points: An array of recently visited points [oldest, ..., newest] provided by user grid_indices: Grid indices of data points in X array for the penalty term calculation. For example, if each data point is an image patch, the indices could correspond to the (i, j) pixel coordinates of their centers in the original image. penalty_factor: Penalty factor :math:`\lambda` in :math:`\alpha - \lambda \cdot \pi(X, r)` **jitter: Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) """ if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): raise ValueError("Please provide an array of recently visited points") X = X[:, None] if X.ndim < 2 else X moments = _compute_mean_and_var(rng_key, model, X, n, noiseless, **kwargs) acq = ucb(moments, beta, maximize) if penalty: acq -= _compute_penalties(X, recent_points, penalty, penalty_factor, grid_indices) return acq
[docs]def POI(rng_key: jnp.ndarray, model: Type[ExactGP], X: jnp.ndarray, best_f: float = None, xi: float = 0.01, maximize: bool = False, n: int = 1, noiseless: bool = False, penalty: Optional[str] = None, recent_points: jnp.ndarray = None, grid_indices: jnp.ndarray = None, penalty_factor: float = 1.0, **kwargs) -> jnp.ndarray: r""" Probability of Improvement Given a probabilistic model :math:`m` that models the objective function :math:`f`, the Probability of Improvement at an input point :math:`x` is defined as: .. math:: PI(x) = \Phi\left(\frac{\mu(x) - f^+ - \xi}{\sigma(x)}\right) where :math:`\mu(x)` is the predictive mean, :math:`\sigma(x)` is the predictive standard deviation, :math:`f^+` is the value of the best observed sample, :math:`\xi` is a small positive "jitter" term to encourage more exploration, and :math:`\Phi` is the cumulative distribution function (CDF) of the standard normal distribution. In the case of HMC, the function leverages multiple predictive posteriors, each associated with a different HMC sample of the GP model parameters, to capture both prediction uncertainty and hyperparameter uncertainty. In this setup, the uncertainty in parameters of probabilistic mean function (if any) also contributes to the acquisition function values. Args: rng_key: JAX random number generator key model: trained model X: new inputs best_f: Best function value observed so far. Derived from the predictive mean when not provided by a user. xi: coefficient affecting exploration-exploitation trade-off maximize: If True, assumes that BO is solving maximization problem n: number of samples drawn from each MVN distribution (number of distributions is equal to the number of HMC samples) noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed to follow the same distribution as the training data. Hence, since we introduce a model noise for the training data, we also want to include that noise in our prediction. penalty: Penalty applied to the acquisition function to discourage re-evaluation at or near points that were recently evaluated. Options are: - 'delta': The infinite penalty is applied to the recently visited points. - 'inverse_distance': Modifies the acquisition function by penalizing points near the recent points. For the 'inverse_distance', the acqusition function is penalized as: .. math:: \alpha - \lambda \cdot \pi(X, r) where :math:`\pi(X, r)` computes a penalty for points in :math:`X` based on their distance to recent points :math:`r`, :math:`\alpha` represents the acquisition function, and :math:`\lambda` represents the penalty factor. recent_points: An array of recently visited points [oldest, ..., newest] provided by user grid_indices: Grid indices of data points in X array for the penalty term calculation. For example, if each data point is an image patch, the indices could correspond to the (i, j) pixel coordinates of their centers in the original image. penalty_factor: Penalty factor :math:`\lambda` in :math:`\alpha - \lambda \cdot \pi(X, r)` **jitter: Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) """ if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): raise ValueError("Please provide an array of recently visited points") X = X[:, None] if X.ndim < 2 else X moments = _compute_mean_and_var(rng_key, model, X, n, noiseless, **kwargs) acq = poi(moments, best_f, xi, maximize) if penalty: acq -= _compute_penalties(X, recent_points, penalty, penalty_factor, grid_indices) return acq
[docs]def UE(rng_key: jnp.ndarray, model: Type[ExactGP], X: jnp.ndarray, n: int = 1, noiseless: bool = False, penalty: Optional[str] = None, recent_points: jnp.ndarray = None, grid_indices: jnp.ndarray = None, penalty_factor: float = 1.0, **kwargs) -> jnp.ndarray: r""" Uncertainty-based exploration Given a probabilistic model :math:`m` that models the objective function :math:`f`, the Uncertainty-based Exploration (UE) at an input point :math:`x` targets regions where the model's predictions are most uncertain. It quantifies this uncertainty as: .. math:: UE(x) = \sigma^2(x) where :math:`\sigma^2(x)` is the predictive variance of the model at the input point :math:`x`. In the case of HMC, the function leverages multiple predictive posteriors, each associated with a different HMC sample of the GP model parameters, to capture both prediction uncertainty and hyperparameter uncertainty. In this setup, the uncertainty in parameters of probabilistic mean function (if any) also contributes to the acquisition function values. Args: rng_key: JAX random number generator key model: trained model X: new inputs n: number of samples drawn from each MVN distribution (number of distributions is equal to the number of HMC samples) noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed to follow the same distribution as the training data. Hence, since we introduce a model noise for the training data, we also want to include that noise in our prediction. penalty: Penalty applied to the acquisition function to discourage re-evaluation at or near points that were recently evaluated. Options are: - 'delta': The infinite penalty is applied to the recently visited points. - 'inverse_distance': Modifies the acquisition function by penalizing points near the recent points. For the 'inverse_distance', the acqusition function is penalized as: .. math:: \alpha - \lambda \cdot \pi(X, r) where :math:`\pi(X, r)` computes a penalty for points in :math:`X` based on their distance to recent points :math:`r`, :math:`\alpha` represents the acquisition function, and :math:`\lambda` represents the penalty factor. recent_points: An array of recently visited points [oldest, ..., newest] provided by user grid_indices: Grid indices of data points in X array for the penalty term calculation. For example, if each data point is an image patch, the indices could correspond to the (i, j) pixel coordinates of their centers in the original image. penalty_factor: Penalty factor :math:`\lambda` in :math:`\alpha - \lambda \cdot \pi(X, r)` **jitter: Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) """ if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): raise ValueError("Please provide an array of recently visited points") X = X[:, None] if X.ndim < 2 else X moments = _compute_mean_and_var(rng_key, model, X, n, noiseless, **kwargs) acq = ue(moments) if penalty: X_ = grid_indices if grid_indices is not None else X penalties = compute_penalty(X_, recent_points, penalty, penalty_factor) acq -= penalties return acq
[docs]def KG(rng_key: jnp.ndarray, model: Type[ExactGP], X: jnp.ndarray, n: int = 1, maximize: bool = False, noiseless: bool = False, penalty: Optional[str] = None, recent_points: jnp.ndarray = None, grid_indices: jnp.ndarray = None, penalty_factor: float = 1.0, **kwargs) -> jnp.ndarray: r""" Knowledge gradient Given a probabilistic model :math:`m` that models the objective function :math:`f`, the Knowledge Gradient (KG) at an input point :math:`x` quantifies the expected improvement in the optimal decision after observing the function value at :math:`x`. The KG value is defined as: .. math:: KG(x) = \mathbb{E}[V_{n+1}^* - V_n^* | x] where :math:`V_{n+1}^*` is the optimal expected value of the objective function after \(n+1\) observations and :math:`V_n^*` is the optimal expected value of the objective function based on the current \(n\) observations. Args: rng_key: JAX random number generator key for sampling simulated observations model: Trained model X: New inputs n: Number of simulated samples for each point in X maximize: If True, assumes that BO is solving maximization problem noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed to follow the same distribution as the training data. Hence, since we introduce a model noise for the training data, we also want to include that noise in our prediction. penalty: Penalty applied to the acquisition function to discourage re-evaluation at or near points that were recently evaluated. Options are: - 'delta': The infinite penalty is applied to the recently visited points. - 'inverse_distance': Modifies the acquisition function by penalizing points near the recent points. For the 'inverse_distance', the acqusition function is penalized as: .. math:: \alpha - \lambda \cdot \pi(X, r) where :math:`\pi(X, r)` computes a penalty for points in :math:`X` based on their distance to recent points :math:`r`, :math:`\alpha` represents the acquisition function, and :math:`\lambda` represents the penalty factor. recent_points: An array of recently visited points [oldest, ..., newest] provided by user grid_indices: Grid indices of data points in X array for the penalty term calculation. For example, if each data point is an image patch, the indices could correspond to the (i, j) pixel coordinates of their centers in the original image. penalty_factor: Penalty factor :math:`\lambda` in :math:`\alpha - \lambda \cdot \pi(X, r)` **jitter: Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) """ if penalty and not isinstance(recent_points, (onp.ndarray, jnp.ndarray)): raise ValueError("Please provide an array of recently visited points") X = X[:, None] if X.ndim < 2 else X samples = model.get_samples() if model.mcmc is None: acq = kg(model, X, samples, rng_key, n, maximize, noiseless, **kwargs) else: vec_kg = vmap(kg, in_axes=(None, None, 0, 0, None, None, None)) samples = model.get_samples() keys = jra.split(rng_key, num=len(next(iter(samples.values())))) acq = vec_kg(model, X, samples, keys, n, maximize, noiseless, **kwargs) if penalty: acq -= _compute_penalties(X, recent_points, penalty, penalty_factor, grid_indices) return acq
[docs]def Thompson(rng_key: jnp.ndarray, model: Type[ExactGP], X: jnp.ndarray, n: int = 1, noiseless: bool = False, **kwargs) -> jnp.ndarray: """ Thompson sampling. For MAP approximation, it draws a single sample of a function from the posterior predictive distribution. In the case of HMC, it draws a single posterior sample from the HMC samples of GP model parameters and then samples a function from it. Args: rng_key: JAX random number generator key model: trained model X: new inputs n: number of samples drawn from the randomly selected MVN distribution noiseless: Noise-free prediction. It is set to False by default as new/unseen data is assumed to follow the same distribution as the training data. Hence, since we introduce a model noise for the training data, we also want to include that noise in our prediction. **jitter: Small positive term added to the diagonal part of a covariance matrix for numerical stability (Default: 1e-6) """ if model.mcmc is not None: posterior_samples = model.get_samples() idx = jra.randint(rng_key, (1,), 0, len(posterior_samples["k_length"])) samples = {k: v[idx] for (k, v) in posterior_samples.items()} _, tsample = model.predict( rng_key, X, samples, n, noiseless=noiseless, **kwargs) if n > 1: tsample = tsample.mean(1).squeeze() else: _, tsample = model.sample_from_posterior( rng_key, X, n=1, noiseless=noiseless, **kwargs) return tsample