# Eight schools - hierarchical modelling with HMC

import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats

from gaul import hmc

import matplotlib.pyplot as plt
import seaborn as sns
import pears

plt.rcParams['figure.figsize'] = (11, 7)

# Observed per-school treatment effects and their standard deviations; the
# latter are used as the (known) likelihood scale in ln_posterior.
treatment_effects = jnp.array(
    [28., 8, -3, 7, -1, 1, 18, 12],
)
treatment_stddevs = jnp.array(
    [15., 10, 16, 11, 9, 11, 10, 18],
)
# Derived from the data rather than hard-coded so the count cannot drift out
# of sync with the arrays above.
num_schools = treatment_effects.shape[0]

# Bar chart of the observed effects with symmetric error bars of one standard
# deviation. Schools are placed at 1..num_schools so the axis matches the
# "School {i+1}" labels used by the later plots.
school_ids = list(range(1, num_schools + 1))
fig, ax = plt.subplots(figsize=(10, 8))
ax.bar(school_ids, treatment_effects, yerr=treatment_stddevs)
ax.set_xticks(school_ids)
ax.set_title("8 Schools treatment effects")
ax.set_xlabel("School")
ax.set_ylabel("Treatment effect")
plt.show()
../_images/03569baeb90dc2cad55b6348dca02f87ce7207ff9b85c02d0ec190be6c79ea1a.png
# Parameter values handed to hmc.sample (presumably the starting position and
# shape template for every chain -- NOTE(review): confirm against gaul docs).
# Non-centred parameterisation: theta_i = mu + exp(logtau) * theta_prime_i.
params = dict(
    mu=jnp.zeros(1),
    logtau=jnp.zeros(1),
    theta_prime=jnp.zeros(num_schools),  # one standardised effect per school
)

# Observations consumed by ln_posterior via data['treatment_effect'] and
# data['treatment_std'].
data = dict(
    treatment_effect=treatment_effects,
    treatment_std=treatment_stddevs,
)
def ln_posterior(params, data):
    """Log of the unnormalised posterior for the non-centred eight-schools model.

    Priors: mu ~ N(0, 10), logtau ~ N(5, 1), theta_prime_i ~ N(0, 1).
    Likelihood: treatment_effect_i ~ N(mu + exp(logtau) * theta_prime_i,
    treatment_std_i).

    Args:
        params: dict with arrays under 'mu', 'logtau' and 'theta_prime'.
        data: dict with arrays under 'treatment_effect' and 'treatment_std'.

    Returns:
        A scalar JAX array: the summed log-prior and log-likelihood terms.
    """
    mu = params['mu']
    logtau = params['logtau']
    theta_prime = params['theta_prime']

    log_prior = (
        stats.norm.logpdf(mu, 0., 10.)
        + stats.norm.logpdf(logtau, 5., 1.)
        + stats.norm.logpdf(theta_prime).sum()
    )

    # Map the standardised effects back to per-school effects.
    theta = mu + jnp.exp(logtau) * theta_prime
    log_likelihood = stats.norm.logpdf(
        data['treatment_effect'], loc=theta, scale=data['treatment_std']
    ).sum()

    # mu/logtau are shape-(1,) arrays, so collapse the total to a scalar.
    return (log_prior + log_likelihood).sum()
# Sampler configuration, forwarded to gaul's hmc.sample as keyword arguments
# (semantics as named: chain count, retained draws, warm-up iterations,
# leapfrog step size and steps per trajectory -- confirm against gaul docs).
hmc_settings = dict(
    n_chains=10,
    n_samples=1000,
    n_warmup=3000,
    step_size=0.2,
    leapfrog_steps=5,
)
samples = hmc.sample(ln_posterior, params, data=data, **hmc_settings)
# Recover per-school effects theta_i = mu + exp(logtau) * theta_prime_i from
# the sampled non-centred parameters.
# NOTE(review): indexed below as (chain, school, iteration); this layout is
# assumed from how gaul's hmc.sample output is used here -- confirm.
school_effect_samples = samples['mu'] + jnp.exp(samples['logtau']) * samples['theta_prime']

# Draw only the first few chains so the panels stay readable.
n_chains_plotted = 4

fig, axes = plt.subplots(num_schools, 2, sharex='col')
fig.set_size_inches(12, 10)
for i in range(num_schools):
    trace_ax, kde_ax = axes[i]
    # Titles depend only on the school, so set them once per row.
    trace_ax.title.set_text(f"School {i+1} treatment effect chain")
    kde_ax.title.set_text(f"School {i+1} treatment effect distribution")
    for j in range(n_chains_plotted):
        trace_ax.plot(school_effect_samples[j, i])
        # `shade` is deprecated since seaborn 0.11; `fill` is its replacement.
        sns.kdeplot(school_effect_samples[j, i], ax=kde_ax, fill=True)
axes[num_schools - 1][0].set_xlabel("Iteration")
axes[num_schools - 1][1].set_xlabel("School effect")
fig.tight_layout()
../_images/166369dbb354c0d9f5500677c9a7c6d4fbd68ae08331afe2266987f7e20c7091.png
# One row per school (matching `labels` and `truths`), pooling all chains and
# iterations. school_effect_samples is indexed as (chain, school, iteration)
# by the trace plots, so move the school axis to the front before flattening.
# A plain reshape(num_schools, -1) would interleave samples from different
# schools within each row.
per_school_samples = jnp.moveaxis(school_effect_samples, 1, 0).reshape(num_schools, -1)
pears.pears(
    per_school_samples,
    truths=treatment_effects,
    labels=[f'School {i+1}' for i in range(num_schools)],
    scatter_thin=10,
);
../_images/5f44c71de4359df0e3ca0c318d6b6d277a72100deb771bcbb775a90015d36341.png