Source code for gpax.hypo

"""
hypo.py
========

Utility functions for hypothesis learning based on arXiv:2112.06649

Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
"""

from typing import Callable, Dict, Optional, Union

import jax.numpy as jnp
import numpy as np
import numpyro

from .models.gp import ExactGP
from .models.spm import sPM
from .utils import get_keys


def step(model: Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], jnp.ndarray],
         model_prior: Callable[[], Dict[str, jnp.ndarray]],
         X_measured: jnp.ndarray,
         y_measured: jnp.ndarray,
         X_unmeasured: Optional[jnp.ndarray] = None,
         gp_wrap: Optional[bool] = False,
         noise_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
         gp_kernel: str = 'Matern',
         gp_kernel_prior: Optional[Callable[[], Dict[str, jnp.ndarray]]] = None,
         gp_input_dim: Optional[int] = 1,
         num_warmup: Optional[int] = 2000,
         num_samples: Optional[int] = 2000,
         num_chains: Optional[int] = 1,
         num_restarts: Optional[int] = 1,
         print_summary: Optional[bool] = True):
    """
    Compute model posterior and use it to derive predictive uncertainty

    Args:
        model: Parametric model in jax.numpy
        model_prior: Prior over model parameters using numpyro.distributions
        X_measured: Measured points
        y_measured: Measured values
        X_unmeasured: Unmeasured points
        gp_wrap: Wrap probabilistic model into a Gaussian process (Default: False)
        noise_prior: Custom prior for observation noise. Defaults to LogNormal(0,1)
        gp_kernel: Gaussian process kernel (if gp_wrap is True). Defaults to Matern
        gp_kernel_prior: Custom priors over kernel hyperparameters.
            Defaults to LogNormal(0,1)
        gp_input_dim: Number of lengthscale dimensions in GP kernel.
            Equals to number of input dimensions or 1 (default)
        num_warmup: Number of warmup steps for HMC. Defaults to 2000
        num_samples: Number of HMC samples. Defaults to 2000
        num_chains: Number of HMC chains. Defaults to 1
        num_restarts: Maximum number of fitting attempts; a new attempt is
            made if r_hat values are not acceptable (>= 1.1). Must be >= 1.
            Defaults to 1
        print_summary: Verbose parameter

    Returns:
        Tuple ``(obj, model_)``. ``obj`` is the variance of the predictive
        samples along their first axis, or 0 when ``X_unmeasured`` is None.
        ``model_`` is the trained model object from the last fitting attempt.

    Raises:
        ValueError: If ``num_restarts`` is smaller than 1 (no model would
            be trained, so there would be nothing to return).
    """
    # Validate first: with zero iterations the loop below never binds `model_`.
    if not num_restarts or num_restarts < 1:
        raise ValueError("num_restarts must be a positive integer")

    sgr = numpyro.diagnostics.split_gelman_rubin
    for i in range(num_restarts):
        # A fresh pair of keys per attempt, derived from the attempt index
        rng_key, rng_key_predict = get_keys(i)
        # Get/update model posterior
        if gp_wrap:
            # wrap model into a gaussian process (gives more flexibility)
            model_ = ExactGP(
                gp_input_dim, gp_kernel, model,
                gp_kernel_prior, model_prior, noise_prior)
        else:
            # use a standalone model
            model_ = sPM(model, model_prior, noise_prior)
        model_.fit(
            rng_key, X_measured, y_measured, num_warmup,
            num_samples, num_chains, print_summary=print_summary)
        # Worst r_hat per sampled site; np.max also covers vector-valued
        # parameters, for which the diagnostic returns more than one value.
        # NOTE(review): get_samples(1) presumably keeps the chain dimension
        # that the split Gelman-Rubin diagnostic needs - confirm.
        rhats = [
            np.max(np.asarray(sgr(v)))
            for k, v in model_.get_samples(1).items() if k != 'mu'
        ]
        if all(r < 1.1 for r in rhats):
            break

    # compute predictive uncertainty for the unmeasured part of the parameter space
    obj = 0
    if X_unmeasured is not None:
        # Use the dedicated prediction key rather than re-using the one
        # already consumed by `fit`.
        _, samples = model_.predict(rng_key_predict, X_unmeasured)
        obj = samples.squeeze().var(0)
    return obj, model_
def sample_next(rewards: Union[np.ndarray, jnp.ndarray],
                method: Optional[str] = 'softmax',
                temperature: Optional[float] = 1.0,
                eps: Optional[float] = 0.4) -> int:
    """
    Sample model or input channel based on softmax or epsilon-greedy policy

    Args:
        rewards: Array of shape (N,) with running rewards
        method: Selection policy, choose between 'softmax' and 'eps-greedy'
            ('epsilon-greedy' is accepted as an alias of 'eps-greedy')
        temperature: Optional temperature parameter for softmax selection policy
        eps: Optional epsilon parameter for epsilon-greedy policy

    Returns:
        The index of model or input channel to sample next

    Raises:
        NotImplementedError: If ``method`` is not one of the supported policies
        AttributeError: If ``rewards`` is not a 1-dimensional array
    """
    if method not in ('softmax', 'eps-greedy', 'epsilon-greedy'):
        raise NotImplementedError(
            "The currently implemented sampling methods are 'softmax' and 'eps-greedy'")
    if rewards.ndim != 1:
        raise AttributeError("Pass rewards as 1-dimensional array")
    if method == 'softmax':
        return softmax(rewards, temperature)
    # Both spellings of the epsilon-greedy policy end up here
    return eps_greedy(rewards, eps)
def softmax(logits: Union[np.ndarray, jnp.ndarray],
            temperature: Optional[float] = 1.0) -> int:
    """
    Softmax selection policy.
    Based on Zai, A., Brown, B. (2020). Deep reinforcement learning in action.
    Manning Publications

    Args:
        logits: 1-dimensional array of scores (e.g. running rewards)
        temperature: Divisor applied to the logits before exponentiation

    Returns:
        Index drawn at random with probabilities given by the softmax of
        ``logits / temperature``
    """
    scaled = np.asarray(logits, dtype=float) / temperature
    # Softmax is invariant to a constant shift; subtracting the maximum keeps
    # np.exp from overflowing (inf / inf -> NaN probabilities) when logits are
    # large or the temperature is small.
    weights = np.exp(scaled - scaled.max())
    probs = weights / weights.sum()
    x = np.arange(len(probs))
    return np.random.choice(x, p=probs)
def eps_greedy(rewards: Union[np.array, jnp.array],
               eps: Optional[float] = 0.4) -> int:
    """
    Epsilon-greedy selection policy.
    Based on Zai, A., Brown, B. (2020). Deep reinforcement learning in action.
    Manning Publications

    Args:
        rewards: Array with running rewards
        eps: Threshold compared against a uniform random draw; a draw above
            it selects the best-rewarded index, otherwise a random index

    Returns:
        The selected index
    """
    exploit = np.random.random() > eps
    if exploit:
        # Greedy choice: index of the highest reward
        return rewards.argmax()
    # Exploration: any index, uniformly at random
    return np.random.randint(len(rewards))
def update_record(record: np.array,
                  action: int,
                  r: Union[int, float]) -> np.array:
    """
    Update the reward record.
    Based on Zai, A., Brown, B. (2020). Deep reinforcement learning in action.
    Manning Publications

    Args:
        record: 2-dimensional array indexed as ``record[action, 0]`` (counter)
            and ``record[action, 1]`` (running average reward)
        action: Row of the record to update
        r: Newly observed reward

    Returns:
        The same ``record`` array, modified in place
    """
    count, average = record[action, 0], record[action, 1]
    # Fold the new reward into the running average before bumping the counter
    updated_average = (count * average + r) / (count + 1)
    record[action, 0] += 1
    record[action, 1] = updated_average
    return record