Utilities

Automatic function setters

gpax.utils.set_fn(func)

Transforms the given deterministic function to use a params dictionary for its parameters, excluding the first one (assumed to be the independent variable).

Parameters:
  • func (Callable) – The deterministic function to be transformed.

Returns:

The transformed function, where parameters are accessed from a params dictionary.

Return type:

Callable
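The idea behind set_fn can be illustrated with a small wrapper: every argument after the first is looked up by name in a params dictionary, which is convenient when each parameter is sampled by name inside a probabilistic model. This is a minimal sketch, not the gpax implementation; the helper `to_params_fn` and the example function `linear` are hypothetical.

```python
import inspect
from typing import Any, Callable, Dict


def to_params_fn(func: Callable) -> Callable:
    """Hypothetical helper: wrap func(x, a, b, ...) as f(x, params)."""
    # Every parameter after the first is treated as a model parameter.
    param_names = list(inspect.signature(func).parameters)[1:]

    def wrapped(x: Any, params: Dict[str, Any]) -> Any:
        # Look up each parameter by name and pass it positionally.
        return func(x, *(params[name] for name in param_names))

    wrapped.__name__ = func.__name__
    return wrapped


def linear(x, slope, intercept):
    return slope * x + intercept


linear_p = to_params_fn(linear)
print(linear_p(2.0, {"slope": 3.0, "intercept": 1.0}))  # 7.0
```

A wrapper keeps the original function intact and simply reroutes its arguments; the library may take a different route internally.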

gpax.utils.set_kernel_fn(func, independent_vars=['X', 'Z'], jit_decorator=True, docstring=None)

Transforms the given kernel function to use a params dictionary for its hyperparameters. The resultant function will always add jitter before returning the computed kernel.

Parameters:
  • func (Callable) – The kernel function to be transformed.

  • independent_vars (List[str], optional) – List of independent variable names in the function. Defaults to ["X", "Z"].

  • jit_decorator (bool, optional) – Whether to apply the @jax.jit decorator to the transformed function. Defaults to True.

  • docstring (Optional[str], optional) – Docstring to be added to the transformed function. Defaults to None.

Returns:

The transformed kernel function where hyperparameters are accessed from a params dictionary.

Return type:

Callable
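The kernel version of the same idea can be sketched as follows: the independent variables X and Z are bound by name, every other argument is read from params, and a small jitter is added to the diagonal to keep the kernel matrix well conditioned. This is a minimal sketch under stated assumptions, not the gpax implementation. It uses NumPy and leaves out the @jax.jit step; the helper `to_kernel_fn`, its `jitter` argument, and the example kernel `rbf` are hypothetical, and adding jitter only when the kernel matrix is square (as for k(X, X)) is an assumption of the sketch.

```python
import inspect
from typing import Callable, Dict, Sequence

import numpy as np


def to_kernel_fn(func: Callable,
                 independent_vars: Sequence[str] = ("X", "Z"),
                 jitter: float = 1e-6) -> Callable:
    """Hypothetical helper: wrap func(X, Z, **hyperparams) as k(X, Z, params)."""
    # Hyperparameters are all arguments that are not independent variables.
    hyper_names = [name for name in inspect.signature(func).parameters
                   if name not in independent_vars]

    def kernel(X: np.ndarray, Z: np.ndarray, params: Dict[str, float]) -> np.ndarray:
        # Bind the inputs by name, then pull each hyperparameter from params.
        kwargs = dict(zip(independent_vars, (X, Z)))
        kwargs.update({name: params[name] for name in hyper_names})
        k = func(**kwargs)
        # Sketch assumption: add jitter to the diagonal only for a square
        # kernel matrix, e.g. when evaluating k(X, X).
        if k.shape[0] == k.shape[1]:
            k = k + jitter * np.eye(k.shape[0])
        return k

    kernel.__name__ = func.__name__
    return kernel


def rbf(X, Z, length, scale):
    # Squared Euclidean distance between every row of X and every row of Z.
    d2 = ((X[:, None, :] - Z[None, :, :]) ** 2).sum(axis=-1)
    return scale * np.exp(-0.5 * d2 / length**2)


rbf_p = to_kernel_fn(rbf)
X = np.array([[0.0], [1.0]])
K = rbf_p(X, X, {"length": 1.0, "scale": 2.0})
```

A JAX version would swap numpy for jax.numpy and could wrap the returned function in jax.jit.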

Other utilities

gpax.utils.dviz(d, samples=1000)

Utility function for visualizing a numpyro distribution from samples drawn from it.

Parameters:
  • d (Type[Distribution]) – A numpyro distribution instance, e.g. numpyro.distributions.Gamma(2, 2).

  • samples (int) – Number of samples to draw. Defaults to 1000.

Return type:

None

gpax.utils.get_keys(seed=0)

Simple wrapper around jax.random.split that derives RNG keys from seed, one for model inference and one for prediction.
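Splitting one root key is the standard JAX way to obtain independent random streams from a single seed. A minimal sketch of such a wrapper, assuming it returns exactly two keys (one per use); the name `make_keys` is hypothetical.

```python
import jax


def make_keys(seed: int = 0):
    """Hypothetical stand-in for get_keys: two independent PRNG keys."""
    # Split one root key into two: one for inference, one for prediction.
    key_infer, key_predict = jax.random.split(jax.random.PRNGKey(seed))
    return key_infer, key_predict


key1, key2 = make_keys(seed=0)
```

The same seed always yields the same pair of keys, so a run can be reproduced by reusing the seed, while the two keys themselves produce different random streams.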

gpax.utils.enable_x64()

Enables double (64-bit) precision for JAX arrays.
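JAX defaults to 32-bit floats; its own switch for 64-bit mode is the jax_enable_x64 configuration flag. A minimal sketch of turning that flag on directly, assuming this is what enable_x64 wraps. Arrays created before the switch keep their original dtype, so the call belongs at the start of a program.

```python
import jax
import jax.numpy as jnp

# Turn on 64-bit mode before creating the arrays that need it.
jax.config.update("jax_enable_x64", True)

x = jnp.ones(3)
print(x.dtype)  # float64
```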