AR(2) model with variational inference

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import jax.random as random

from gaul import advi

import matplotlib.pyplot as plt
import seaborn as sns

plt.rcParams['figure.figsize'] = (11, 7)

# Ground-truth AR(2) parameters and series length for the synthetic data.
n = 50
beta1, beta2 = 0.5, -0.3
sigma = 1.3

# Generate y[t] = beta1 * y[t-1] + beta2 * y[t-2] + sigma * eps[t] for t >= 2,
# starting from y[0] = y[1] = 0. The noise for step t comes from PRNGKey(t).
y = jnp.zeros(n)
for t in range(2, n):
    lag1, lag2 = y[t - 1], y[t - 2]
    eps_t = random.normal(random.PRNGKey(t))
    y = y.at[t].set(beta1 * lag1 + beta2 * lag2 + sigma * eps_t)

plt.plot(y)
[<matplotlib.lines.Line2D at 0x7f654ea6d030>]
../_images/7ace401700a085bf14e421a079350f2df5578caf90db53c12436ee136558a358.png
def ln_posterior(params, data):
    """Unnormalized log posterior (log joint density) of a Gaussian AR(2) model.

    Model::

        y[t]  ~ Normal(beta1 * y[t-1] + beta2 * y[t-2], sigma),  t >= 2
        beta1 ~ Normal(0, 1)
        beta2 ~ Normal(0, 1)
        sigma ~ Exponential(1)

    The first two observations are conditioned on, not modelled.

    ``sigma`` is parameterized on the log scale (``params['logstd']``). The
    log-Jacobian of ``sigma = exp(logstd)``, which equals ``logstd``, is added
    so the Exponential(1) prior on ``sigma`` becomes a proper density over the
    unconstrained ``logstd``.

    Args:
        params: dict with keys ``'beta1'``, ``'beta2'`` and ``'logstd'``.
            In this notebook each value is initialised with shape (1,).
        data: dict with key ``'y'``, the observed series. It is presumably
            1-D; slicing is along the first axis.

    Returns:
        The log posterior with all terms summed to a single value.
    """
    coef1 = params['beta1']
    coef2 = params['beta2']
    log_std = params['logstd']
    std = jnp.exp(log_std)
    series = data['y']

    target = 0

    # Priors.
    target += stats.norm.logpdf(coef1, 0., 1.)
    target += stats.norm.logpdf(coef2, 0., 1.)
    target += stats.expon.logpdf(std, scale=1.)
    # Change of variables: log|d std / d log_std| = log_std.
    target += log_std

    # Likelihood of each y[t] (t >= 2) given its two predecessors:
    # y[2:] is the response, y[1:-1] is lag 1 and y[:-2] is lag 2.
    mean = coef1 * series[1:-1] + coef2 * series[:-2]
    target += stats.norm.logpdf(series[2:], mean, std).sum()

    return target.sum()
# Starting values for the unconstrained parameters: the two AR coefficients
# and the log innovation standard deviation, one element each.
params = {
    'beta1': jnp.zeros(1),
    'beta2': jnp.zeros(1),
    'logstd': jnp.zeros(1),
}

# Observations handed to ln_posterior through its `data` argument.
data = {'y': y}

# Run gaul's ADVI on the log posterior and collect the returned samples.
samples = advi.sample(ln_posterior, params, lr=0.2, data=data)

# Flatten each returned array so every parameter is a single 1-D trace.
samples = jax.tree_util.tree_map(lambda leaf: leaf.reshape(-1), samples)
fig, ax = plt.subplots(3, 2, figsize=(16, 6))

# Posterior draws on the natural scale (sigma = exp(logstd)), paired with the
# values used to simulate y. A distinct name is used so the `data` dict passed
# to advi.sample is not overwritten.
draws = [samples['beta1'], samples['beta2'], jnp.exp(samples['logstd'])]
truths = [beta1, beta2, sigma]

for row, (draw, truth) in enumerate(zip(draws, truths)):
    # Left column: sample trace, with the true value as a dashed line.
    ax[row, 0].plot(draw, alpha=0.5)
    ax[row, 0].axhline(truth, c='k', ls='--')
    # Right column: kernel density estimate of the draws vs. the true value.
    sns.kdeplot(draw, ax=ax[row, 1])
    ax[row, 1].axvline(truth, c='k', ls='--')
../_images/8d60eb26b35990b43483cc5efd209df50c434d0987c9242a5822454166d7a587.png
def sim_ar2(beta1, beta2, sigma, rng, n_steps=None):
    """Simulate one AR(2) series starting from zeros.

    y[0] = y[1] = 0 and, for t >= 2,
    y[t] = beta1 * y[t-1] + beta2 * y[t-2] + sigma * eps[t], eps[t] ~ N(0, 1).

    Uses ``jax.lax.scan`` rather than a Python loop of full-array
    ``.at[].set`` updates. Tracing under ``jax.vmap``/``jit`` therefore yields
    one loop instead of ``n_steps - 2`` unrolled scatters. At every step the
    key is split as ``key, subkey = split(key)`` and ``subkey`` drives the noise.

    Args:
        beta1, beta2: AR coefficients (scalars).
        sigma: innovation standard deviation (scalar).
        rng: JAX PRNG key.
        n_steps: length of the series. Defaults to the module-level ``n``.

    Returns:
        Array of shape (n_steps,).
    """
    if n_steps is None:
        n_steps = n

    def step(carry, _):
        key, lag2, lag1 = carry
        key, subkey = random.split(key)
        y_t = beta1 * lag1 + beta2 * lag2 + sigma * random.normal(subkey)
        # Shift the lag window: old lag1 becomes lag2, y_t becomes lag1.
        return (key, lag1, y_t), y_t

    init = (rng, jnp.zeros(()), jnp.zeros(()))
    _, tail = jax.lax.scan(step, init, None, length=max(n_steps - 2, 0))
    # Prepend the two fixed zero initial values (fewer if n_steps < 2).
    return jnp.concatenate([jnp.zeros(min(n_steps, 2)), tail])

# Posterior predictive check: simulate one AR(2) series per posterior draw.
# The number of PRNG keys comes from the draws themselves, so it always
# matches the leading axis that jax.vmap maps over.
n_draws = samples['beta1'].shape[0]
ysim = jax.vmap(sim_ar2)(
    samples['beta1'],
    samples['beta2'],
    jnp.exp(samples['logstd']),
    random.split(random.PRNGKey(0), n_draws),
)

# Overlay up to the first 100 simulated series on the observed data.
for trace in ysim[:100]:
    plt.plot(trace, c='grey', alpha=0.3)
plt.plot(y, lw=4)
[<matplotlib.lines.Line2D at 0x7f656cbc4730>]
../_images/5878d0c1944c543225edd8d2cb741c1560b606a35354ce0ec9af4368585fc2ba.png