Linear regression with Laplace approximation
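The Laplace (quadratic) approximation fits a posterior by finding the maximum a posteriori (MAP) point and approximating the posterior with a Gaussian whose covariance is the inverse Hessian of the negative log posterior at that point:

$$ p(\theta \mid y) \;\approx\; \mathcal{N}\!\left(\hat{\theta}_{\mathrm{MAP}},\; H^{-1}\right), \qquad H = -\left.\nabla^2_{\theta} \log p(\theta \mid y)\right|_{\theta = \hat{\theta}_{\mathrm{MAP}}}. $$

This notebook simulates data from a straight line with Gaussian noise and fits it with gaul's quap module.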

import jax
import numpy as np
import jax.numpy as jnp
import jax.scipy.stats as stats

from gaul import quap

import matplotlib.pyplot as plt
import seaborn as sns
import pears

plt.rcParams['figure.figsize'] = (11, 7)
beta0 = 0.5
beta1 = -0.4
sigma = 1.6
n = 50

key = jax.random.PRNGKey(0)
xkey, ykey = jax.random.split(key)

x = jax.random.uniform(xkey, shape=(n,)) * 10.
x = jnp.sort(x)
ymean = beta0 + beta1 * x
y = jax.random.normal(ykey, shape=(n,)) * sigma + ymean

plt.scatter(x, y)
plt.plot(x, ymean)
plt.fill_between(x, ymean - sigma, ymean + sigma, alpha=0.5)
[Figure: scatter of the simulated data with the true mean line and a ±1σ band.]
def ln_posterior(params, data):
    target = 0.

    # Priors: Normal(0, 2) on the intercept and slope, Exponential(1) on the
    # scale; sigma is stored on the log scale, so it is exponentiated here.
    target += stats.norm.logpdf(params['beta0'], 0., 2.)
    target += stats.norm.logpdf(params['beta1'], 0., 2.)
    target += stats.expon.logpdf(jnp.exp(params['sigma']), scale=1.)

    # Gaussian likelihood around the regression line.
    ymean = params['beta0'] + params['beta1'] * data['x']
    target += stats.norm.logpdf(data['y'], ymean, jnp.exp(params['sigma'])).sum()

    return target.sum()
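In other words, Normal(0, 2) priors on the intercept and slope, an exponential prior on the scale, and a Gaussian likelihood:

$$ \beta_0, \beta_1 \sim \mathcal{N}(0, 2), \qquad \sigma \sim \mathrm{Exponential}(1), \qquad y_i \sim \mathcal{N}(\beta_0 + \beta_1 x_i,\; \sigma). $$

The scale parameter is kept on the log scale so the optimization is unconstrained; note the prior term above is evaluated at exp(sigma) without a change-of-variables Jacobian, so the implied prior on the scale is not exactly Exponential(1).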
params = dict(
    beta0=jnp.zeros(1),
    beta1=jnp.zeros(1),
    sigma=jnp.ones(1),
)

data = dict(
    x=x,
    y=y,
)
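Before optimizing, it is worth checking that the log posterior returns a finite scalar at these starting values. A minimal sanity check using the objects defined above:

lp = ln_posterior(params, data)
print(lp, jnp.isfinite(lp))  # expect a finite scalar and True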
samples = quap.sample(
    ln_posterior,
    params,
    n_steps=1000,
    n_samples=1000,
    data=data
)
samples = jax.tree_util.tree_map(lambda x: x.reshape(-1), samples)
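quap.sample handles the approximation internally. A rough sketch of the underlying idea, written for a single flat parameter vector rather than gaul's dict-of-arrays interface (the optimizer, step size, and function names here are illustrative, not gaul's actual implementation):

def laplace_sketch(ln_post, theta0, data, key, n_steps=1000, lr=1e-3, n_samples=1000):
    neg_lp = lambda theta: -ln_post(theta, data)
    grad_fn = jax.grad(neg_lp)
    theta = theta0
    for _ in range(n_steps):               # gradient descent towards the MAP
        theta = theta - lr * grad_fn(theta)
    hess = jax.hessian(neg_lp)(theta)      # curvature at the MAP
    cov = jnp.linalg.inv(hess)             # Gaussian covariance = inverse Hessian
    return jax.random.multivariate_normal(key, theta, cov, shape=(n_samples,))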
fig, ax = plt.subplots(3, 2, figsize=(16, 6))

# sigma was sampled on the log scale, so exponentiate it back for plotting.
plot_data = [samples['beta0'], samples['beta1'], jnp.exp(samples['sigma'])]
truths = [beta0, beta1, sigma]

# Trace (left) and marginal density (right) for each parameter,
# with the true value as a dashed line.
for i in range(3):
    ax[i,0].plot(plot_data[i], alpha=0.5)
    ax[i,0].axhline(truths[i], c='k', ls='--')
    sns.kdeplot(plot_data[i], ax=ax[i,1])
    ax[i,1].axvline(truths[i], c='k', ls='--')
[Figure: traces (left) and marginal densities (right) for beta0, beta1, and sigma, with true values marked.]
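Posterior means and standard deviations give a quick numeric check against the simulated truths (sigma is compared on the log scale, since that is how it was sampled):

for name, truth in [('beta0', beta0), ('beta1', beta1), ('sigma', jnp.log(sigma))]:
    s = samples[name]
    print(f'{name}: mean={float(s.mean()):.3f}, sd={float(s.std()):.3f}, truth={float(truth):.3f}')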
pears.pears(samples, truths=[beta0, beta1, jnp.log(sigma)], scatter_thin=5, hspace=0.03, wspace=0.03);
[Figure: pair plot of the posterior samples with the true values marked.]
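The pair plot visualizes the joint posterior; the corresponding correlation matrix can be computed directly (rows and columns in the order beta0, beta1, log sigma):

chain = jnp.stack([samples['beta0'], samples['beta1'], samples['sigma']])
print(jnp.corrcoef(chain))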
def sim_linreg(beta0, beta1, sigma):
    # Only the mean line is simulated here; sigma is accepted but unused.
    ymean_sim = beta0 + beta1 * x
    return ymean_sim

# One mean line per posterior sample.
y_sim = jax.vmap(sim_linreg)(
    samples['beta0'],
    samples['beta1'],
    samples['sigma'],
)

for i in range(100):
    plt.plot(x, y_sim[i], c='grey', alpha=0.3)
plt.plot(x, ymean, lw=4)
[Figure: 100 posterior mean lines (grey) over the true regression line.]
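The grey lines show uncertainty in the mean only. A posterior-predictive band, which also includes observation noise, could be sketched along these lines (reusing the names above; the PRNG key and percentile levels are arbitrary choices):

pkey = jax.random.PRNGKey(1)
eps = jax.random.normal(pkey, shape=y_sim.shape)
y_pred = y_sim + eps * jnp.exp(samples['sigma'])[:, None]  # add noise using each draw's sigma

lo, hi = jnp.percentile(y_pred, jnp.array([5.5, 94.5]), axis=0)
plt.scatter(x, y)
plt.plot(x, ymean, lw=4)
plt.fill_between(x, lo, hi, alpha=0.3)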