API Reference

gaul.hmc module

gaul.quap module

gaul.quap.sample(ln_posterior, params, n_steps=5000, n_samples=2000, rng=jax.random.PRNGKey, opt=None, lr=None, *args, **kwargs)
Return type:

Any

gaul.advi module

gaul.advi.batch_elbo(ln_prob, rng, vi_params, nsamples)
Return type:

float

gaul.advi.diag_gaussian_logpdf(x, mean, log_std)
Return type:

float

gaul.advi.diag_gaussian_sample(rng, mean, log_std)
Return type:

Any

gaul.advi.elbo(ln_prob, rng, mean, log_std)
Return type:

float

gaul.advi.sample(ln_posterior, params, n_steps=10000, n_samples=2000, rng=jax.random.PRNGKey, opt=None, lr=None, *args, **kwargs)

Run mean-field variational inference to sample from the posterior distribution.

Return type:

Any

gaul.utils.pbar module

gaul.utils.pbar.progress_bar_scan(num_samples, message=None)

Progress bar for a JAX scan.

gaul.utils.tree_utils module

gaul.utils.tree_utils.dense_hessian(ln_posterior, params, *args, **kwargs)
Return type:

ndarray

gaul.utils.tree_utils.make_tree_hessian(hess_fn)

Makes a function that computes a block-diagonal Hessian of a function on a tree. The blocks are organized into the same structure as the tree.

Parameters:

hess_fn (Callable) – A function that takes a tree and returns a tree of Hessians. This can be made with, for example, jax.hessian(fn).

Return type:

Callable

gaul.utils.tree_utils.tree_ones_like(tree)

Return a new tree with the same structure as `tree`, but with all values set to 1.

Return type:

Any

gaul.utils.tree_utils.tree_random_normal_like(key, tree, mean=0.0, std=1.0)

Return a new tree with the same structure as `tree`, but with all values set to random normal variates.

Return type:

Any

gaul.utils.tree_utils.tree_split_keys_like(key, tree)

Split the key into multiple keys, one for each leaf of the tree.

Return type:

Any

gaul.utils.tree_utils.tree_stack(trees, axis=0)

Stack a list of trees along a given axis.

Return type:

Any

gaul.utils.tree_utils.tree_zeros_like(tree)

Return a new tree with the same structure as `tree`, but with all values set to 0.

Return type:

Any