Eight schools - hierarchical modelling with HMC

Eight schools - hierarchical modelling with HMC

import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats

from gaul import hmc

import matplotlib.pyplot as plt
import seaborn as sns
import pears

# Default size (width, height) applied to figures that do not set their own.
plt.rcParams.update({'figure.figsize': (11, 7)})
# Observed treatment effect for each school and the standard deviation
# reported alongside it; both arrays are ordered by school.
treatment_effects = jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
treatment_stddevs = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

# Number of schools in the study (matches the length of the arrays above).
num_schools = 8

# Bar chart of the observed effect per school, with the reported standard
# deviation drawn as error bars.
fig, ax = plt.subplots(figsize=(10, 8))
school_positions = range(num_schools)
ax.bar(school_positions, treatment_effects, yerr=treatment_stddevs)
ax.set_title("8 Schools treatment effects")
ax.set_xlabel("School")
ax.set_ylabel("Treatment effect")
plt.show()
../_images/03569baeb90dc2cad55b6348dca02f87ce7207ff9b85c02d0ec190be6c79ea1a.png
# Parameter values handed to the sampler as its second argument (presumably
# the starting point of every chain -- confirm against gaul's hmc.sample).
# The keys are the ones ln_posterior reads.
params = {
    'mu': jnp.zeros(1),
    'logtau': jnp.zeros(1),
    # One standardised effect per school: sized from the observations instead
    # of a hard-coded 8 so it cannot drift out of sync with the data.
    'theta_prime': jnp.zeros(treatment_effects.shape),
}

# Observations forwarded to ln_posterior through the sampler's `data` keyword.
data = {
    'treatment_effect': treatment_effects,
    'treatment_std': treatment_stddevs,
}
def ln_posterior(params, data):
    """Return the unnormalised log posterior of the eight-schools model.

    The model is written in non-centred form: each school's effect is
    ``mu + exp(logtau) * theta_prime``.

    Args:
        params: mapping with entries ``'mu'``, ``'logtau'`` and
            ``'theta_prime'``.
        data: mapping with entries ``'treatment_effect'`` (observed effects)
            and ``'treatment_std'`` (their standard deviations).

    Returns:
        The log prior plus the log likelihood, reduced to a scalar.
    """
    mu = params['mu']
    logtau = params['logtau']
    theta_prime = params['theta_prime']

    # Priors: mu ~ N(0, 10), logtau ~ N(5, 1), theta_prime ~ N(0, 1).
    log_prior_mu = stats.norm.logpdf(mu, 0., 10.)
    log_prior_logtau = stats.norm.logpdf(logtau, 5., 1.)
    log_prior_theta = stats.norm.logpdf(theta_prime).sum()

    # Per-school effect implied by the non-centred parameters.
    theta_i = mu + jnp.exp(logtau) * theta_prime

    # Each observed effect is normal around its school's effect, with the
    # reported standard deviation as the scale.
    log_likelihood = stats.norm.logpdf(
        data['treatment_effect'],
        loc=theta_i,
        scale=data['treatment_std'],
    ).sum()

    target = log_prior_mu + log_prior_logtau + log_prior_theta + log_likelihood
    return target.sum()
# Settings forwarded to gaul's HMC sampler as keyword arguments.
sampler_settings = dict(
    n_chains=10,
    n_samples=1000,
    n_warmup=3000,
    step_size=0.2,
    leapfrog_steps=5,
)
samples = hmc.sample(ln_posterior, params, data=data, **sampler_settings)

# Convert the sampled parameters into per-school effects using the same
# expression ln_posterior uses for theta_i: mu + exp(logtau) * theta_prime.
tau_samples = jnp.exp(samples['logtau'])
school_effect_samples = samples['mu'] + tau_samples * samples['theta_prime']
# One row per school: the left panel shows the sampled effect against
# iteration, the right panel a kernel density estimate of the same draws.
n_chains_shown = 4  # only the first few chains are overlaid in each panel

# squeeze=False keeps `axes` two-dimensional even for a single row.
fig, axes = plt.subplots(num_schools, 2, sharex='col', squeeze=False)
fig.set_size_inches(12, 10)
for i in range(num_schools):
    trace_ax, density_ax = axes[i]
    # Titles depend only on the school, so set them once per row rather than
    # once per chain.
    trace_ax.set_title(f"School {i+1} treatment effect chain")
    density_ax.set_title(f"School {i+1} treatment effect distribution")
    for j in range(n_chains_shown):
        # Indexed as [j, i]; assumes the sample array is laid out as
        # (chain, school, draw) -- TODO confirm against hmc.sample.
        chain_draws = school_effect_samples[j, i]
        trace_ax.plot(chain_draws)
        # `fill` replaces seaborn's deprecated `shade` keyword.
        sns.kdeplot(chain_draws, ax=density_ax, fill=True)
axes[num_schools - 1][0].set_xlabel("Iteration")
axes[num_schools - 1][1].set_xlabel("School effect")
fig.tight_layout()
../_images/166369dbb354c0d9f5500677c9a7c6d4fbd68ae08331afe2266987f7e20c7091.png
# Pairwise plot of the per-school effects, one row of draws per school.
# The trace plots index school_effect_samples as [chain, school], so the
# school axis sits second. It is moved to the front before the remaining axes
# are flattened; a bare reshape(8, -1) on that layout would put draws from
# different schools into the same row.
# NOTE(review): assumes the (chain, school, draw) layout -- confirm against
# hmc.sample's return shape.
pears.pears(
    jnp.moveaxis(school_effect_samples, 1, 0).reshape(num_schools, -1),
    truths=treatment_effects,
    labels=[f'School {i+1}' for i in range(num_schools)],
    scatter_thin=10,
);  # trailing semicolon keeps the notebook from echoing the return value
../_images/5f44c71de4359df0e3ca0c318d6b6d277a72100deb771bcbb775a90015d36341.png