API Reference

gaul.hmc module

gaul.hmc.accept_reject(state_old, state_new, ln_posterior, mass_matrix_fn, key)
Return type

Tuple[Any, Any]
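This reference does not say exactly what accept_reject computes. For orientation only: a standard HMC Metropolis correction compares the Hamiltonian H = -ln p(theta) + 0.5 * p^T M^-1 p before and after the leapfrog trajectory, and accepts the proposal with probability min(1, exp(H_old - H_new)). The sketch below makes several assumptions. A state is a (params, momentum) pair of pytrees. A tree of inverse diagonal mass entries (inv_mass_diag) stands in for gaul's mass_matrix_fn, whose contract is not documented here. metropolis_accept and kinetic_energy are hypothetical names, not gaul functions.

```python
import jax
import jax.numpy as jnp


def kinetic_energy(momentum, inv_mass_diag):
    # K(p) = 0.5 * sum_i p_i^2 / m_i, summed over every leaf of the momentum tree.
    terms = jax.tree_util.tree_map(
        lambda p, m_inv: 0.5 * jnp.sum(p**2 * m_inv), momentum, inv_mass_diag
    )
    return sum(jax.tree_util.tree_leaves(terms))


def metropolis_accept(state_old, state_new, ln_posterior, inv_mass_diag, key):
    """Hypothetical sketch: each state is a (params, momentum) pair of pytrees."""
    params_old, mom_old = state_old
    params_new, mom_new = state_new
    # Hamiltonian H = potential (-ln posterior) + kinetic energy.
    h_old = -ln_posterior(params_old) + kinetic_energy(mom_old, inv_mass_diag)
    h_new = -ln_posterior(params_new) + kinetic_energy(mom_new, inv_mass_diag)
    # Accept with probability min(1, exp(H_old - H_new)), compared in log space.
    log_accept = jnp.minimum(0.0, h_old - h_new)
    accept = jnp.log(jax.random.uniform(key)) < log_accept
    # Select the new state leaf by leaf if accepted, otherwise keep the old one.
    chosen = jax.tree_util.tree_map(
        lambda new, old: jnp.where(accept, new, old), state_new, state_old
    )
    return chosen, accept
```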

gaul.hmc.generate_momentum(key, tree)
Return type

Tuple[PRNGKeyArray, Any]
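Judging by its return type, generate_momentum returns a fresh key together with a momentum tree. Here is a minimal sketch of that pattern. It assumes momenta are standard normal draws (identity mass matrix) with the same structure and leaf shapes as tree. generate_momentum_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def generate_momentum_sketch(key, tree):
    # Split off a new key for the caller, and a subkey used only for this draw.
    new_key, subkey = jax.random.split(key)
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    # One independent key per leaf so the leaves receive uncorrelated draws.
    keys = jax.random.split(subkey, len(leaves))
    momenta = [jax.random.normal(k, jnp.shape(x)) for k, x in zip(keys, leaves)]
    # Rebuild a tree with the same structure as the parameters.
    return new_key, jax.tree_util.tree_unflatten(treedef, momenta)
```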

gaul.hmc.leapfrog_step(params, momentum, step_size, grad_fn, mass_matrix_fn)
Return type

Tuple[Any, Any]
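leapfrog_step presumably advances HMC's leapfrog integrator by one step. The sketch below shows the usual half-step/full-step/half-step update applied to pytrees. It assumes grad_fn returns the gradient of the log posterior, hence the plus sign in the momentum update. It also replaces mass_matrix_fn with a hypothetical tree of inverse diagonal mass entries. leapfrog_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def leapfrog_sketch(params, momentum, step_size, grad_fn, inv_mass_diag):
    tm = jax.tree_util.tree_map
    # Half step for momentum: p <- p + (eps / 2) * grad ln p(theta).
    g = grad_fn(params)
    momentum = tm(lambda p, gi: p + 0.5 * step_size * gi, momentum, g)
    # Full step for position: theta <- theta + eps * M^-1 p (diagonal M).
    params = tm(lambda th, p, m_inv: th + step_size * m_inv * p,
                params, momentum, inv_mass_diag)
    # Second half step for momentum, using the gradient at the new position.
    g = grad_fn(params)
    momentum = tm(lambda p, gi: p + 0.5 * step_size * gi, momentum, g)
    return params, momentum
```

The leapfrog integrator is symplectic, so the energy error stays bounded over many steps instead of drifting. This is what makes long trajectories followed by a single accept/reject test practical.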

gaul.hmc.sample(ln_posterior, init_params, n_chains=4, leapfrog_steps=10, step_size=0.001, n_samples=1000, n_warmup=1000, key=jax.random.PRNGKey, return_momentum=False, *args, **kwargs)
Return type

Union[Any, Tuple[Any, Any]]

gaul.hmc.transpose_samples(samples, shape)
Return type

Any

gaul.quap module

gaul.quap.sample(ln_posterior, params, n_steps=5000, n_samples=2000, rng=jax.random.PRNGKey, opt=None, lr=None, *args, **kwargs)
Return type

Any
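The name suggests quap.sample implements a quadratic (Laplace) approximation. Under that reading, it finds the posterior mode, then approximates the posterior by a Gaussian centred there. The Gaussian's covariance is the inverse of the negative Hessian of ln_posterior at the mode. The sketch below assumes that reading. It uses plain gradient ascent in place of the opt and lr arguments, whose expected types are not documented here. It flattens the parameter tree with jax.flatten_util.ravel_pytree. quap_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def quap_sketch(ln_posterior, params, n_steps=500, n_samples=2000, lr=0.1, key=None):
    key = jax.random.PRNGKey(0) if key is None else key
    # Work on a flat vector; `unravel` maps vectors back to the original tree.
    flat, unravel = ravel_pytree(params)
    f = lambda z: ln_posterior(unravel(z))
    grad_f = jax.grad(f)
    # 1. Find the mode by plain gradient ascent (an assumption, not gaul's optimizer).
    mode = jax.lax.fori_loop(0, n_steps, lambda _, z: z + lr * grad_f(z), flat)
    # 2. Gaussian covariance = inverse of the negative Hessian at the mode.
    cov = jnp.linalg.inv(-jax.hessian(f)(mode))
    # 3. Sample N(mode, cov) and map every draw back to the tree structure.
    draws = jax.random.multivariate_normal(key, mode, cov, shape=(n_samples,))
    return jax.vmap(unravel)(draws)
```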

gaul.advi module

gaul.advi.batch_elbo(ln_prob, rng, vi_params, nsamples)
Return type

float

gaul.advi.diag_gaussian_logpdf(x, mean, log_std)
Return type

float

gaul.advi.diag_gaussian_sample(rng, mean, log_std)
Return type

Any
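The two helpers above describe a diagonal (mean-field) Gaussian parameterized by mean and log_std. Below is a sketch of the standard formulas: log q(x) = sum_i [ -0.5 * ((x_i - mean_i) / exp(log_std_i))^2 - log_std_i - 0.5 * log(2 pi) ], with sampling via the reparameterization trick. It assumes x, mean and log_std are arrays of one shape. The _sketch names are hypothetical.

```python
import jax
import jax.numpy as jnp


def diag_gaussian_logpdf_sketch(x, mean, log_std):
    # Sum of independent univariate normal log densities with std = exp(log_std).
    z = (x - mean) / jnp.exp(log_std)
    return jnp.sum(-0.5 * z**2 - log_std - 0.5 * jnp.log(2 * jnp.pi))


def diag_gaussian_sample_sketch(rng, mean, log_std):
    # Reparameterization: mean + std * eps with eps ~ N(0, I), so the draw
    # is differentiable with respect to mean and log_std.
    eps = jax.random.normal(rng, jnp.shape(mean))
    return mean + jnp.exp(log_std) * eps
```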

gaul.advi.elbo(ln_prob, rng, mean, log_std)
Return type

float
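elbo presumably returns a Monte Carlo estimate of the evidence lower bound, ELBO = E_q[ln p(theta) - ln q(theta)], for the diagonal Gaussian q = N(mean, diag(exp(2 * log_std))). batch_elbo presumably averages nsamples such estimates. Here is a sketch under those assumptions, with vi_params taken to be a (mean, log_std) pair. elbo_sketch and batch_elbo_sketch are hypothetical names.

```python
import jax
import jax.numpy as jnp


def elbo_sketch(ln_prob, rng, mean, log_std):
    # Single-draw estimate of E_q[ln p(theta) - ln q(theta)] via reparameterization.
    eps = jax.random.normal(rng, jnp.shape(mean))
    theta = mean + jnp.exp(log_std) * eps
    # ln q(theta); note that (theta - mean) / std is exactly eps.
    ln_q = jnp.sum(-0.5 * eps**2 - log_std - 0.5 * jnp.log(2 * jnp.pi))
    return ln_prob(theta) - ln_q


def batch_elbo_sketch(ln_prob, rng, vi_params, nsamples):
    # Average `nsamples` independent single-draw estimates to reduce variance.
    mean, log_std = vi_params
    rngs = jax.random.split(rng, nsamples)
    return jnp.mean(jax.vmap(lambda r: elbo_sketch(ln_prob, r, mean, log_std))(rngs))
```

When ln_prob is a normalized log density, the ELBO equals -KL(q || p). It is therefore zero when q matches p and negative otherwise.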

gaul.advi.sample(ln_posterior, params, n_steps=10000, n_samples=2000, rng=jax.random.PRNGKey, opt=None, lr=None, *args, **kwargs)

Run mean-field variational inference to sample from the posterior distribution.

Return type

Any
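Here is a compact sketch of the mean-field procedure. It fits mean and log_std by stochastic gradient ascent on the ELBO, then draws n_samples from the fitted Gaussian. Assumptions: the parameters form a flat vector of length dim (gaul accepts a parameter tree), and plain SGD replaces the documented opt and lr arguments. The closed-form Gaussian entropy sum(log_std) (up to a constant) stands in for -E_q[ln q]. advi_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def advi_sketch(ln_posterior, dim, n_steps=2000, n_samples=2000, lr=0.01, n_mc=16, seed=0):
    """Hypothetical mean-field VI: fit q = N(mean, diag(std^2)) by ascent on the ELBO."""

    def neg_elbo(vi_params, rng):
        mean, log_std = vi_params
        eps = jax.random.normal(rng, (n_mc, dim))
        theta = mean + jnp.exp(log_std) * eps        # reparameterized draws
        ln_p = jax.vmap(ln_posterior)(theta)
        entropy = jnp.sum(log_std)                    # Gaussian entropy up to a constant
        return -(jnp.mean(ln_p) + entropy)

    def step(vi_params, rng):
        # Gradient descent on -ELBO, i.e. gradient ascent on the ELBO.
        grads = jax.grad(neg_elbo)(vi_params, rng)
        new = jax.tree_util.tree_map(lambda p, g: p - lr * g, vi_params, grads)
        return new, None

    key = jax.random.PRNGKey(seed)
    init = (jnp.zeros(dim), jnp.zeros(dim))
    (mean, log_std), _ = jax.lax.scan(step, init, jax.random.split(key, n_steps))
    # Draw from the fitted approximation with a key not used during fitting.
    draw_key = jax.random.fold_in(key, n_steps)
    return mean + jnp.exp(log_std) * jax.random.normal(draw_key, (n_samples, dim))
```

A diagonal q cannot represent correlations between parameters, so mean-field VI tends to underestimate posterior variance when the true posterior is correlated.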

gaul.utils.pbar module

gaul.utils.pbar.progress_bar_scan(num_samples, message=None)

Progress bar for a JAX scan (jax.lax.scan) over num_samples iterations.

gaul.utils.tree_utils module

gaul.utils.tree_utils.dense_hessian(ln_posterior, params, *args, **kwargs)
Return type

ndarray
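dense_hessian returns a single ndarray. That suggests the parameter tree is flattened into one vector before differentiating, so second derivatives across different leaves appear as off-diagonal entries. Below is a sketch under that assumption, using jax.flatten_util.ravel_pytree and jax.hessian. It also assumes extra *args and **kwargs are forwarded to ln_posterior. dense_hessian_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def dense_hessian_sketch(ln_posterior, params, *args, **kwargs):
    # Flatten all leaves into one vector so the result is a single (d, d) matrix,
    # including cross-terms between different leaves.
    flat, unravel = ravel_pytree(params)
    f = lambda z: ln_posterior(unravel(z), *args, **kwargs)
    return jax.hessian(f)(flat)
```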

gaul.utils.tree_utils.make_tree_hessian(hess_fn)

Make a function that computes the block-diagonal Hessian of a function of a tree. The diagonal blocks are arranged in the same structure as the tree.

Parameters

hess_fn (Callable) – A function that takes a tree and returns a tree of Hessians, for example jax.hessian(fn).

Return type

Callable
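Applied to a tree, jax.hessian(fn) returns a nested tree in which entry [i][j] holds the second derivatives with respect to leaves i and j. Keeping only the [i][i] entries gives the block-diagonal form described above. Below is a sketch of such a wrapper. make_tree_hessian_sketch is a hypothetical name, and gaul's own version may differ.

```python
import jax
import jax.numpy as jnp


def make_tree_hessian_sketch(hess_fn):
    def tree_hessian(params, *args, **kwargs):
        # full[i][j] = d^2 f / (d leaf_i d leaf_j), nested in the params structure.
        full = hess_fn(params, *args, **kwargs)
        treedef = jax.tree_util.tree_structure(params)
        # A tree holding each leaf's position, used to pick the diagonal block.
        idx_tree = jax.tree_util.tree_unflatten(treedef, list(range(treedef.num_leaves)))
        # `full` has idx_tree as a prefix, so `row` is the inner tree for leaf i.
        return jax.tree_util.tree_map(
            lambda i, row: jax.tree_util.tree_leaves(row)[i], idx_tree, full
        )

    return tree_hessian
```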

gaul.utils.tree_utils.tree_ones_like(tree)

Return a new tree with the same structure as the input tree, but with all values set to 1.

Return type

Any
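Here is a one-line sketch using jax.tree_util.tree_map. It assumes the leaves are arrays and keeps each leaf's shape and dtype, as jnp.ones_like does. tree_zeros_like can be written the same way with jnp.zeros_like. tree_ones_like_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_ones_like_sketch(tree):
    # Replace every leaf with ones of the same shape and dtype.
    return jax.tree_util.tree_map(jnp.ones_like, tree)
```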

gaul.utils.tree_utils.tree_random_normal_like(key, tree, mean=0.0, std=1.0)

Return a new tree with the same structure as the input tree, but with all values replaced by normal random variates with the given mean and std.

Return type

Any

gaul.utils.tree_utils.tree_split_keys_like(key, tree)

Split the key into multiple keys, one for each leaf of the tree.

Return type

Any
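Below is a sketch of tree_split_keys_like and tree_random_normal_like together. It assumes the keys come from jax.random.split and that each leaf is filled with mean + std * N(0, 1) variates of that leaf's shape, drawn with the leaf's own key. The _sketch names are hypothetical.

```python
import jax
import jax.numpy as jnp


def tree_split_keys_like_sketch(key, tree):
    # One fresh key per leaf, arranged in the same structure as `tree`.
    treedef = jax.tree_util.tree_structure(tree)
    keys = jax.random.split(key, treedef.num_leaves)
    return jax.tree_util.tree_unflatten(treedef, list(keys))


def tree_random_normal_like_sketch(key, tree, mean=0.0, std=1.0):
    keys = tree_split_keys_like_sketch(key, tree)
    # Each leaf gets N(mean, std^2) draws of its own shape from its own key.
    return jax.tree_util.tree_map(
        lambda x, k: mean + std * jax.random.normal(k, jnp.shape(x)), tree, keys
    )
```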

gaul.utils.tree_utils.tree_stack(trees, axis=0)

Stack a list of trees along a given axis.

Return type

Any
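Below is a sketch using tree_map over the whole list of trees. It assumes all trees share one structure and that matching leaves have the same shape. tree_stack_sketch is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def tree_stack_sketch(trees, axis=0):
    # tree_map receives the matching leaf from every tree and stacks them.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves, axis=axis), *trees)
```

A typical use is turning a list of per-iteration parameter trees into one tree whose leaves carry a new sample axis.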

gaul.utils.tree_utils.tree_zeros_like(tree)

Return a new tree with the same structure as the input tree, but with all values set to 0.

Return type

Any