API Reference
gaul.hmc module
- gaul.hmc.accept_reject(state_old, state_new, ln_posterior, mass_matrix_fn, key)
- Return type
Tuple[Any, Any]
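The entry gives only a signature. In HMC, each trajectory usually ends with a Metropolis correction based on the change in total energy (the Hamiltonian). The plain-Python sketch below illustrates that step under several assumptions that are not from the gaul documentation: a state is a (position, momentum) pair of float lists, the mass matrix is diagonal and passed as a list of inverse masses (standing in for mass_matrix_fn), and a seeded random.Random stands in for the JAX key. The name accept_reject_sketch is hypothetical; this is not gaul's implementation.

```python
import math
import random


def hamiltonian(q, p, ln_posterior, inv_mass):
    """Potential energy -ln p(q) plus kinetic energy 0.5 * sum(p_i^2 / m_i)."""
    kinetic = 0.5 * sum(pi * pi * mi for pi, mi in zip(p, inv_mass))
    return -ln_posterior(q) + kinetic


def accept_reject_sketch(state_old, state_new, ln_posterior, inv_mass, rng):
    """Hypothetical Metropolis correction: keep state_new with prob. min(1, exp(H_old - H_new))."""
    h_old = hamiltonian(*state_old, ln_posterior, inv_mass)
    h_new = hamiltonian(*state_new, ln_posterior, inv_mass)
    accept_prob = math.exp(min(0.0, h_old - h_new))  # capped at 1, never overflows
    accepted = rng.random() < accept_prob
    return (state_new if accepted else state_old), accepted


def ln_post(q):
    # Standard normal log-density, up to an additive constant.
    return -0.5 * sum(x * x for x in q)


rng = random.Random(0)
old = ([1.0], [0.5])
downhill = ([0.0], [0.5])   # lower energy than old: always accepted
uphill = ([100.0], [0.0])   # enormous energy increase: effectively never accepted
state_a, ok_a = accept_reject_sketch(old, downhill, ln_post, [1.0], rng)
state_b, ok_b = accept_reject_sketch(old, uphill, ln_post, [1.0], rng)
```

Working with the energy difference rather than a ratio of densities keeps the computation in log space, so very unlikely proposals simply produce an acceptance probability of 0.0 instead of overflowing.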
- gaul.hmc.generate_momentum(key, tree)
- Return type
Tuple[PRNGKeyArray, Any]
- gaul.hmc.leapfrog_step(params, momentum, step_size, grad_fn, mass_matrix_fn)
- Return type
Tuple[Any, Any]
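A leapfrog step is the standard symplectic integrator used to simulate Hamiltonian dynamics in HMC: a half momentum update, a full position update, then another half momentum update. The sketch below assumes grad_fn returns the gradient of the log posterior and that the mass matrix is diagonal, given as a list of inverse masses (a stand-in for mass_matrix_fn). Parameters and momenta are flat float lists rather than JAX pytrees, and leapfrog_step_sketch is a hypothetical name.

```python
def leapfrog_step_sketch(q, p, step_size, grad_fn, inv_mass):
    """One leapfrog step for H(q, p) = -ln_post(q) + 0.5 * sum(inv_mass_i * p_i^2).

    grad_fn returns the gradient of ln_post, so each half momentum kick adds it.
    """
    g = grad_fn(q)
    p_half = [pi + 0.5 * step_size * gi for pi, gi in zip(p, g)]                    # half kick
    q_new = [qi + step_size * mi * pi for qi, mi, pi in zip(q, inv_mass, p_half)]  # full drift
    g_new = grad_fn(q_new)
    p_new = [pi + 0.5 * step_size * gi for pi, gi in zip(p_half, g_new)]           # half kick
    return q_new, p_new


def grad_ln_post(q):
    # Gradient of the standard normal log-density -0.5 * q^2.
    return [-x for x in q]


def energy(q, p):
    return 0.5 * sum(x * x for x in q) + 0.5 * sum(x * x for x in p)


# A single step from q = 1, p = 0.
q1, p1 = leapfrog_step_sketch([1.0], [0.0], 0.1, grad_ln_post, [1.0])

# Many steps: the energy should stay close to its starting value.
q, p = [1.0], [0.0]
for _ in range(100):
    q, p = leapfrog_step_sketch(q, p, 0.1, grad_ln_post, [1.0])
drift = abs(energy(q, p) - energy([1.0], [0.0]))
```

Because leapfrog is symplectic and time-reversible, the energy error stays bounded over long trajectories instead of accumulating, which is what keeps HMC acceptance rates high.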
- gaul.hmc.sample(ln_posterior, init_params, n_chains=4, leapfrog_steps=10, step_size=0.001, n_samples=1000, n_warmup=1000, key=jax.random.PRNGKey, return_momentum=False, *args, **kwargs)
- Return type
Union[Any, Tuple[Any, Any]]
- gaul.hmc.transpose_samples(samples, shape)
- Return type
Any
gaul.quap module
- gaul.quap.sample(ln_posterior, params, n_steps=5000, n_samples=2000, rng=jax.random.PRNGKey, opt=None, lr=None, *args, **kwargs)
- Return type
Any
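No description is given here. The module name suggests a quadratic (Laplace) approximation of the posterior, as in the quap function of the R rethinking package, and the n_steps, opt and lr parameters suggest an optimization phase; neither point is confirmed by this reference. Under that assumption, the one-dimensional sketch below finds the mode by plain gradient ascent (instead of an optimizer object), fits a Gaussian whose variance is the negative inverse curvature at the mode, and samples from it. Derivatives use finite differences rather than JAX autodiff, and laplace_sketch is a hypothetical name.

```python
import math
import random


def laplace_sketch(ln_posterior, x0, n_steps=5000, lr=0.01, n_samples=2000, seed=0, h=1e-3):
    """1-D quadratic (Laplace) approximation of a posterior.

    1. Gradient ascent on ln_posterior to find the mode (finite-difference gradient).
    2. Fit a Gaussian there with variance -1 / ln_posterior''(mode).
    3. Draw n_samples from that Gaussian.
    """
    x = x0
    for _ in range(n_steps):
        grad = (ln_posterior(x + h) - ln_posterior(x - h)) / (2 * h)
        x += lr * grad
    curv = (ln_posterior(x + h) - 2 * ln_posterior(x) + ln_posterior(x - h)) / h ** 2
    if curv >= 0:
        raise ValueError("no local maximum found: curvature is not negative")
    std = math.sqrt(-1.0 / curv)
    rng = random.Random(seed)
    return x, std, [rng.gauss(x, std) for _ in range(n_samples)]


# Target: N(1.5, 0.5^2), known only up to an additive constant in log space.
mode, std, draws = laplace_sketch(lambda x: -0.5 * (x - 1.5) ** 2 / 0.25, x0=0.0)
```

For an exactly Gaussian target the approximation is exact; for skewed or multimodal posteriors it only captures the local shape around the mode found by the optimizer.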
gaul.advi module
- gaul.advi.batch_elbo(ln_prob, rng, vi_params, nsamples)
- Return type
float
- gaul.advi.diag_gaussian_logpdf(x, mean, log_std)
- Return type
float
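The name and arguments point to the log-density of a Gaussian with diagonal covariance, parameterized by per-coordinate log standard deviations. With std_i = exp(log_std_i), that density is the sum of independent normal terms, -0.5 * ((x_i - mean_i) / std_i)^2 - log_std_i - 0.5 * log(2 * pi). The plain-Python sketch below (hypothetical name diag_gaussian_logpdf_sketch, flat float lists instead of JAX arrays) evaluates exactly that sum.

```python
import math


def diag_gaussian_logpdf_sketch(x, mean, log_std):
    """Log-density of N(mean, diag(exp(log_std)^2)) at x, as a sum of 1-D normal terms."""
    total = 0.0
    for xi, mi, si in zip(x, mean, log_std):
        z = (xi - mi) / math.exp(si)  # standardized coordinate
        total += -0.5 * z * z - si - 0.5 * math.log(2 * math.pi)
    return total


v1 = diag_gaussian_logpdf_sketch([0.0], [0.0], [0.0])             # -0.5 * log(2 * pi)
v2 = diag_gaussian_logpdf_sketch([1.0, 0.0], [0.0, 0.0], [0.0, 0.0])
```

Parameterizing by log_std keeps the standard deviation positive without constraints, which is convenient when the parameters are updated by unconstrained gradient steps.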
- gaul.advi.diag_gaussian_sample(rng, mean, log_std)
- Return type
Any
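Sampling from a diagonal Gaussian in variational inference is usually done with the reparameterization trick: draw standard normal noise eps and return mean + exp(log_std) * eps, so the sample is a differentiable function of mean and log_std. The sketch below assumes that is what this function does. A seeded random.Random replaces the JAX rng key, and diag_gaussian_sample_sketch is a hypothetical name.

```python
import math
import random


def diag_gaussian_sample_sketch(rng, mean, log_std):
    """Reparameterized draw: mean + exp(log_std) * eps, with eps ~ N(0, 1) per coordinate."""
    return [m + math.exp(s) * rng.gauss(0.0, 1.0) for m, s in zip(mean, log_std)]


rng = random.Random(0)
draws = [diag_gaussian_sample_sketch(rng, [3.0, -1.0], [0.0, math.log(0.1)])
         for _ in range(5000)]
mean0 = sum(d[0] for d in draws) / len(draws)
mean1 = sum(d[1] for d in draws) / len(draws)
std1 = math.sqrt(sum((d[1] - mean1) ** 2 for d in draws) / len(draws))
```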
- gaul.advi.elbo(ln_prob, rng, mean, log_std)
- Return type
float
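The evidence lower bound (ELBO) for a variational density q is E_q[ln p(z) - ln q(z)]. Given the signatures, elbo plausibly returns a single-sample Monte Carlo estimate and batch_elbo an average over nsamples draws. The sketch below makes those assumptions explicit. It also assumes vi_params is a (mean, log_std) pair, which this reference does not state. The *_sketch names are hypothetical, and random.Random stands in for a JAX key.

```python
import math
import random


def _logq(z, mean, log_std):
    # Diagonal Gaussian log-density with std_i = exp(log_std_i).
    return sum(-0.5 * ((zi - m) / math.exp(s)) ** 2 - s - 0.5 * math.log(2 * math.pi)
               for zi, m, s in zip(z, mean, log_std))


def elbo_sketch(ln_prob, rng, mean, log_std):
    """Single-sample Monte Carlo ELBO estimate: ln p(z) - ln q(z), with z ~ q."""
    z = [m + math.exp(s) * rng.gauss(0.0, 1.0) for m, s in zip(mean, log_std)]
    return ln_prob(z) - _logq(z, mean, log_std)


def batch_elbo_sketch(ln_prob, rng, vi_params, nsamples):
    """Average of nsamples single-sample estimates; vi_params = (mean, log_std) here."""
    mean, log_std = vi_params
    return sum(elbo_sketch(ln_prob, rng, mean, log_std) for _ in range(nsamples)) / nsamples


def ln_prob(z):
    # Normalized standard normal log-density.
    return sum(-0.5 * zi * zi - 0.5 * math.log(2 * math.pi) for zi in z)


rng = random.Random(0)
exact_fit = batch_elbo_sketch(ln_prob, rng, ([0.0], [0.0]), 100)    # q equals p: ELBO = 0
shifted = batch_elbo_sketch(ln_prob, rng, ([1.0], [0.0]), 2000)     # true ELBO = -KL = -0.5
```

When q matches a normalized target exactly, every per-sample term is zero, so the estimate has no variance; otherwise the ELBO equals the log evidence minus KL(q || p), which is why shifting q by one standard deviation gives -0.5 in expectation.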
- gaul.advi.sample(ln_posterior, params, n_steps=10000, n_samples=2000, rng=jax.random.PRNGKey, opt=None, lr=None, *args, **kwargs)
Run mean-field variational inference to sample from the posterior distribution.
- Return type
Any
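Mean-field variational inference fits a fully factorized Gaussian q by maximizing the ELBO, then draws samples from q. The sketch below shows the idea with reparameterized stochastic gradient ascent and one Monte Carlo draw per step. Because it uses no autodiff, it takes a hand-written gradient of the log posterior (grad_ln_post) instead of ln_posterior, uses plain SGD in place of the opt argument, and starts from mean 0 and log_std 0. All of these choices, and the name advi_sketch, are assumptions rather than gaul's behaviour.

```python
import math
import random


def advi_sketch(grad_ln_post, dim, n_steps=10000, n_samples=2000, lr=0.01, seed=0):
    """Mean-field Gaussian VI by reparameterized stochastic gradient ascent on the ELBO.

    q(z) = prod_i N(z_i | mean_i, exp(log_std_i)^2). With z = mean + exp(log_std) * eps:
      d ELBO / d mean_i    = E[g_i]
      d ELBO / d log_std_i = E[g_i * exp(log_std_i) * eps_i] + 1   (+1 from the entropy)
    where g = grad ln p(z). One Monte Carlo draw is used per step.
    """
    rng = random.Random(seed)
    mean, log_std = [0.0] * dim, [0.0] * dim
    for _ in range(n_steps):
        eps = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        z = [m + math.exp(s) * e for m, s, e in zip(mean, log_std, eps)]
        g = grad_ln_post(z)
        for i in range(dim):
            mean[i] += lr * g[i]
            log_std[i] += lr * (g[i] * math.exp(log_std[i]) * eps[i] + 1.0)
    draws = [[m + math.exp(s) * rng.gauss(0.0, 1.0) for m, s in zip(mean, log_std)]
             for _ in range(n_samples)]
    return mean, log_std, draws


# Target: N(3, 2^2), so ln p(z) = -(z - 3)^2 / 8 + const and its gradient is -(z - 3) / 4.
mean, log_std, draws = advi_sketch(lambda z: [-(z[0] - 3.0) / 4.0], dim=1)
sigma = math.exp(log_std[0])
draw_mean = sum(d[0] for d in draws) / len(draws)
```

With a constant learning rate the final parameters keep fluctuating slightly around the optimum; a decaying step size or an adaptive optimizer would reduce that noise.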
gaul.utils.pbar module
- gaul.utils.pbar.progress_bar_scan(num_samples, message=None)
Progress bar for a JAX scan.
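A common way to add a progress bar to a JAX scan is a decorator factory that wraps the scan body and reports progress every so often. Inside jit-compiled JAX code that report has to go through a host callback such as jax.debug.callback. The plain-Python sketch below only illustrates the decorator pattern. It emulates lax.scan with a loop and assumes the scanned values are iteration indices. The names scan and progress_bar_scan_sketch, and the print_every parameter, are hypothetical and do not describe gaul's implementation.

```python
def scan(f, init, xs):
    """Plain-Python stand-in for jax.lax.scan: thread a carry through f over xs."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def progress_bar_scan_sketch(num_samples, message=None, print_every=None):
    """Decorator factory for a scan body whose second argument is the iteration index."""
    message = message or f"Running for {num_samples:,} iterations"
    print_every = print_every or max(num_samples // 10, 1)

    def decorator(body):
        def wrapped(carry, i):
            if i == 0:
                print(message)
            if (i + 1) % print_every == 0 or i == num_samples - 1:
                print(f"  {i + 1}/{num_samples}")
            return body(carry, i)
        return wrapped
    return decorator


@progress_bar_scan_sketch(100)
def body(total, i):
    return total + i, total  # running sum as the carry, previous sum as the output


final, history = scan(body, 0, range(100))
```

Keeping the progress logic in a decorator leaves the scan body itself unchanged, so the same body can be run with or without progress reporting.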
gaul.utils.tree_utils module
- gaul.utils.tree_utils.dense_hessian(ln_posterior, params, *args, **kwargs)
- Return type
ndarray
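Judging by its name and ndarray return type, dense_hessian presumably returns the full (dense) matrix of second derivatives of ln_posterior with respect to every parameter. In JAX that is commonly done by flattening the parameter tree, for example with jax.flatten_util.ravel_pytree, and applying jax.hessian; whether gaul does this is not stated. The sketch below assumes parameters are already flattened to a list of floats and uses central finite differences instead of autodiff. dense_hessian_sketch is a hypothetical name.

```python
def dense_hessian_sketch(fn, x, h=1e-4):
    """Central finite-difference Hessian of fn at the flat point x (a list of floats)."""
    n = len(x)

    def f_shift(i, di, j, dj):
        # Evaluate fn with coordinate i shifted by di and coordinate j shifted by dj.
        y = list(x)
        y[i] += di
        y[j] += dj
        return fn(y)

    hess = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            hess[i][j] = (f_shift(i, h, j, h) - f_shift(i, h, j, -h)
                          - f_shift(i, -h, j, h) + f_shift(i, -h, j, -h)) / (4 * h * h)
    return hess


# f(x) = x0^2 + 3 x0 x1 + 2 x1^2 has the constant Hessian [[2, 3], [3, 4]].
H = dense_hessian_sketch(lambda x: x[0] ** 2 + 3 * x[0] * x[1] + 2 * x[1] ** 2, [1.0, -2.0])
```

A dense Hessian needs O(n^2) entries, which is what motivates the block-diagonal alternative that keeps only the blocks within each parameter leaf.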
- gaul.utils.tree_utils.make_tree_hessian(hess_fn)
Makes a function that computes a block diagonal Hessian of a function on a tree. The blocks are organized into the same structure as the tree.
- Parameters
hess_fn (Callable) – A function that takes a tree and returns a tree of Hessians. This can be made with, for example, jax.hessian(fn).
- Return type
Callable
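For a dict of parameters, jax.hessian(fn) returns a nested dict in which entry [k1][k2] holds the second derivatives with respect to leaves k1 and k2. A block-diagonal Hessian organized like the tree then keeps only the [k][k] blocks. The sketch below illustrates that reading of the description. fd_tree_hessian is a hypothetical finite-difference stand-in for jax.hessian on a dict of float lists, and make_tree_hessian_sketch is a hypothetical name, not gaul's code.

```python
def fd_tree_hessian(fn, tree, h=1e-4):
    """Finite-difference stand-in for jax.hessian on a dict of float lists.

    Returns hess[k1][k2][i][j] = d^2 fn / (d tree[k1][i] d tree[k2][j]).
    """
    def shifted(k1, i, d1, k2, j, d2):
        t = {k: list(v) for k, v in tree.items()}
        t[k1][i] += d1
        t[k2][j] += d2
        return fn(t)

    hess = {}
    for k1, v1 in tree.items():
        hess[k1] = {}
        for k2, v2 in tree.items():
            hess[k1][k2] = [[(shifted(k1, i, h, k2, j, h) - shifted(k1, i, h, k2, j, -h)
                              - shifted(k1, i, -h, k2, j, h) + shifted(k1, i, -h, k2, j, -h))
                             / (4 * h * h) for j in range(len(v2))] for i in range(len(v1))]
    return hess


def make_tree_hessian_sketch(hess_fn):
    """Wrap hess_fn so that only the diagonal blocks hess[k][k] are kept."""
    def tree_hessian(tree):
        full = hess_fn(tree)
        return {k: full[k][k] for k in tree}
    return tree_hessian


def fn(t):
    a0, a1 = t["a"]
    (b0,) = t["b"]
    # The a0 * b0 cross term lives in an off-diagonal block and is dropped.
    return a0 ** 2 + a0 * a1 + 3 * a1 ** 2 + 5 * b0 ** 2 + a0 * b0


tree_hessian = make_tree_hessian_sketch(lambda t: fd_tree_hessian(fn, t))
blocks = tree_hessian({"a": [1.0, 2.0], "b": [0.5]})
```

Dropping the cross-leaf blocks trades accuracy for size: only the within-leaf correlations survive, which is often enough for a preconditioner or a diagonal-block mass matrix.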
- gaul.utils.tree_utils.tree_ones_like(tree)
Return a new tree with the same structure as tree, but with all values set to 1.
- Return type
Any
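Functions like tree_ones_like are typically a tree map: walk the container structure and replace each leaf while keeping the structure. In JAX one natural version is jax.tree_util.tree_map(jnp.ones_like, tree), although this reference does not say how gaul does it. The plain-Python sketch below handles nested dicts, lists and tuples with scalar leaves. tree_map and tree_ones_like_sketch are hypothetical helpers.

```python
def tree_map(fn, tree):
    """Apply fn to every leaf of a nested dict / list / tuple, keeping the structure."""
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, v) for v in tree)
    return fn(tree)


def tree_ones_like_sketch(tree):
    # Scalar leaves only; type(leaf)(1) keeps ints as ints and floats as floats.
    return tree_map(lambda leaf: type(leaf)(1), tree)


params = {"w": [0.5, -2.0], "b": (3, 4)}
ones = tree_ones_like_sketch(params)  # {"w": [1.0, 1.0], "b": (1, 1)}
```

tree_zeros_like follows the same pattern with 0 in place of 1.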
- gaul.utils.tree_utils.tree_random_normal_like(key, tree, mean=0.0, std=1.0)
Return a new tree with the same structure as tree, but with all values replaced by draws from a normal distribution with the given mean and std.
- Return type
Any
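The sketch below replaces every scalar leaf of a nested dict, list or tuple with a normal draw with the given mean and std. A stateful, seeded random.Random stands in for the JAX key. With real JAX keys each leaf would need its own key, which is the job of a key-splitting helper such as tree_split_keys_like. tree_random_normal_like_sketch is a hypothetical name.

```python
import random


def tree_random_normal_like_sketch(rng, tree, mean=0.0, std=1.0):
    """Replace each scalar leaf of a nested dict/list/tuple with a N(mean, std^2) draw."""
    if isinstance(tree, dict):
        return {k: tree_random_normal_like_sketch(rng, v, mean, std) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_random_normal_like_sketch(rng, v, mean, std) for v in tree)
    return rng.gauss(mean, std)


template = {"a": [0.0] * 2000, "b": (0.0, 0.0)}
noise = tree_random_normal_like_sketch(random.Random(1), template, mean=5.0, std=0.1)
again = tree_random_normal_like_sketch(random.Random(1), template, mean=5.0, std=0.1)
avg_a = sum(noise["a"]) / len(noise["a"])
```

Seeding the generator the same way reproduces the same tree, mirroring how reusing a JAX key reproduces the same draws.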
- gaul.utils.tree_utils.tree_split_keys_like(key, tree)
Split the key into multiple keys, one for each leaf of the tree.
- Return type
Any
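Splitting one key into one key per leaf can be done by flattening the tree, splitting the key into as many children as there are leaves, and rebuilding the original structure with the keys as leaves. In JAX that corresponds roughly to jax.tree_util.tree_flatten, jax.random.split and jax.tree_util.tree_unflatten; the reference does not confirm gaul's approach. The plain-Python sketch below mirrors those steps. split_key is a hypothetical hash-based stand-in for jax.random.split, and the other names are hypothetical too.

```python
import hashlib


def _leaves(tree):
    """Flatten a nested dict/list/tuple into its leaves, in a fixed order."""
    if isinstance(tree, dict):
        return [leaf for k in tree for leaf in _leaves(tree[k])]
    if isinstance(tree, (list, tuple)):
        return [leaf for v in tree for leaf in _leaves(v)]
    return [tree]


def _unflatten(tree, it):
    """Rebuild tree's structure, taking new leaves from the iterator it."""
    if isinstance(tree, dict):
        return {k: _unflatten(v, it) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(_unflatten(v, it) for v in tree)
    return next(it)


def split_key(key, num):
    """Hypothetical stand-in for jax.random.split: derive num child keys from key."""
    return [int.from_bytes(hashlib.sha256(f"{key}/{i}".encode()).digest()[:8], "big")
            for i in range(num)]


def tree_split_keys_like_sketch(key, tree):
    keys = split_key(key, len(_leaves(tree)))
    return _unflatten(tree, iter(keys))


params = {"w": [0.0, 0.0], "b": 0.0}
keys = tree_split_keys_like_sketch(42, params)   # {"w": [k0, k1], "b": k2}
keys_again = tree_split_keys_like_sketch(42, params)
```

Giving every leaf its own key means draws for different leaves never reuse the same randomness.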
- gaul.utils.tree_utils.tree_stack(trees, axis=0)
Stack a list of trees along a given axis.
- Return type
Any
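Stacking a list of trees means producing one tree of the same structure whose leaves collect the corresponding leaves of every input, for example to turn a list of per-iteration parameter trees into one tree of traces. A common JAX idiom is jax.tree_util.tree_map(lambda *xs: jnp.stack(xs, axis=axis), *trees), though the reference does not say gaul uses it. The plain-Python sketch below assumes scalar leaves and supports only stacking along a new leading axis (axis=0). tree_stack_sketch is a hypothetical name.

```python
def tree_stack_sketch(trees):
    """Stack same-structure trees along a new leading axis (axis=0 only).

    Each leaf of the result is the list of the corresponding leaves of the inputs.
    """
    first = trees[0]
    if isinstance(first, dict):
        return {k: tree_stack_sketch([t[k] for t in trees]) for k in first}
    if isinstance(first, (list, tuple)):
        return type(first)(tree_stack_sketch([t[i] for t in trees]) for i in range(len(first)))
    return list(trees)


stacked = tree_stack_sketch([{"w": 1.0, "b": (2, 3)}, {"w": 4.0, "b": (5, 6)}])
# {"w": [1.0, 4.0], "b": ([2, 5], [3, 6])}
```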
- gaul.utils.tree_utils.tree_zeros_like(tree)
Return a new tree with the same structure as tree, but with all values set to 0.
- Return type
Any