Gaussian process regression

Adapted from NumPyro's Gaussian process regression example, with inference done via gaul's quadratic approximation (quap) rather than MCMC.

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import jax.random as random

from functools import partial
from gaul import quap

import matplotlib.pyplot as plt
import seaborn as sns
import pears

plt.rcParams['figure.figsize'] = (11, 7)
rng = random.PRNGKey(0)

n_train = 25
n_test = 400
sigma_obs = 0.3

def fn(x):
    # True latent function used to generate the synthetic data.
    return (
        20. * jnp.sin(x / 40.)
        + 0.1 * jnp.power(x, 3.0)
        + 0.1 * jnp.power(1. + x, 3.0)
        + 1.2 * jnp.cos(4.0 * x)
    )

# Noisy observations of fn, standardized to zero mean and unit variance.
x = jnp.linspace(-1, 1, n_train)
y = fn(x)
y += sigma_obs * random.normal(rng, shape=(n_train,))
y -= jnp.mean(y)
y /= jnp.std(y)

x_test = jnp.linspace(-1.3, 1.3, n_test)

x_true = jnp.linspace(-1.3, 1.3, 1000)
y_true = fn(x_true)

plt.scatter(x, y)
plt.plot(x_true, y_true)
[Figure: scatter of the standardized training data with the true function overlaid]
@partial(jax.jit, static_argnames=['include_noise'])
def kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
    # Squared-exponential (RBF) kernel; optionally add observation noise
    # plus a small jitter to the diagonal for numerical stability.
    deltaXsq = jnp.power((X[:, None] - Z) / length, 2.0)
    k = var * jnp.exp(-0.5 * deltaXsq)
    if include_noise:
        k += (noise + jitter) * jnp.eye(X.shape[0])
    return k
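As a quick sanity check of the kernel (a sketch with made-up hyperparameter values, not part of the original example): the Gram matrix should be square, symmetric, and carry var + noise + jitter on its diagonal.

x_demo = jnp.linspace(0., 1., 5)
k_demo = kernel(x_demo, x_demo, var=1.0, length=0.5, noise=0.1)
print(k_demo.shape)                    # (5, 5)
print(jnp.allclose(k_demo, k_demo.T))  # True
print(k_demo[0, 0])                    # 1.0 + 0.1 + 1e-6 ≈ 1.1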
@jax.jit
def ln_posterior(params, data):
    target = 0.

    # Wide Normal priors on the log-scale kernel hyperparameters.
    target += stats.norm.logpdf(params['log_kernel_var'], 0., 10.)
    target += stats.norm.logpdf(params['log_kernel_noise'], 0., 10.)
    target += stats.norm.logpdf(params['log_kernel_length'], 0., 10.)

    k = kernel(
        data['X'], data['X'],
        jnp.exp(params['log_kernel_var']),
        jnp.exp(params['log_kernel_length']),
        jnp.exp(params['log_kernel_noise'])
    )

    # GP marginal likelihood: Y ~ MultivariateNormal(0, k).
    target += stats.multivariate_normal.logpdf(
        data['Y'],
        jnp.zeros(data['X'].shape[0]),
        k
    ).sum()

    return target.sum()
# Initial log-scale hyperparameters (zeros, i.e. 1.0 on the natural scale).
params = dict(
    log_kernel_var=jnp.zeros(1),
    log_kernel_noise=jnp.zeros(1),
    log_kernel_length=jnp.zeros(1),
)

data = dict(
    X=x,
    Y=y,
)
# Sample from the quadratic approximation to the posterior, then flatten
# each parameter's draws to a 1-D array.
samples = quap.sample(ln_posterior, params, data=data)
samples = jax.tree_util.tree_map(lambda x: x.reshape(-1), samples)
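Since the hyperparameters were sampled on the log scale, it is natural to summarize them after exponentiating. A minimal sketch:

for name, draws in samples.items():
    print(name, jnp.exp(draws).mean(), jnp.exp(draws).std())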
@jax.jit
def predict(rng_key, X, Y, X_test, var, length, noise):
    # Standard GP conditioning: posterior mean k_pX K_XX^-1 Y and covariance
    # k_pp - k_pX K_XX^-1 k_pX^T, plus one noisy predictive draw.
    k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)
    k_pX = kernel(X_test, X, var, length, noise, include_noise=False)
    k_XX = kernel(X, X, var, length, noise, include_noise=True)
    K_xx_inv = jnp.linalg.inv(k_XX)
    K = k_pp - k_pX @ (K_xx_inv @ k_pX.T)
    # Clip tiny negative variances from round-off before the square root.
    sigma_noise = jnp.sqrt(jnp.clip(jnp.diag(K), a_min=0.0)) * jax.random.normal(
        rng_key, X_test.shape[:1]
    )
    mean = k_pX @ (K_xx_inv @ Y)
    return mean, mean + sigma_noise
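jnp.linalg.inv is fine at n_train = 25, but for larger problems a Cholesky solve is the cheaper and better-conditioned route for a positive-definite Gram matrix. A sketch of the same conditioning step, not part of the original example:

from jax.scipy.linalg import cho_factor, cho_solve

@jax.jit
def predict_chol(rng_key, X, Y, X_test, var, length, noise):
    k_pp = kernel(X_test, X_test, var, length, noise, include_noise=True)
    k_pX = kernel(X_test, X, var, length, noise, include_noise=False)
    k_XX = kernel(X, X, var, length, noise, include_noise=True)
    # Factor K_XX once and reuse it for both solves.
    c, low = cho_factor(k_XX)
    K = k_pp - k_pX @ cho_solve((c, low), k_pX.T)
    mean = k_pX @ cho_solve((c, low), Y)
    draw = mean + jnp.sqrt(jnp.clip(jnp.diag(K), a_min=0.0)) * random.normal(
        rng_key, X_test.shape[:1]
    )
    return mean, draw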
# Vectorize prediction over posterior draws, with one PRNG key per draw.
means, predictions = jax.vmap(
    predict,
    in_axes=(0, None, None, None, 0, 0, 0)
)(
    random.split(rng, samples["log_kernel_var"].shape[0]),
    x, y, x_test,
    jnp.exp(samples["log_kernel_var"]),
    jnp.exp(samples["log_kernel_length"]),
    jnp.exp(samples["log_kernel_noise"]),
)
mean_prediction = jnp.mean(means, axis=0)
percentiles = jnp.percentile(predictions, jnp.array([5., 95.]), axis=0)
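The 5th and 95th percentiles of the noisy predictive draws give a 90% predictive band. The means array varies only through hyperparameter uncertainty, so the analogous band on the posterior mean function (a sketch) would be narrower:

mean_band = jnp.percentile(means, jnp.array([5., 95.]), axis=0)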
plt.plot(x_true, y_true, label='truth')
plt.scatter(x, y, label='data')
plt.plot(x_test, mean_prediction, c='g', lw=2, label='prediction')
plt.fill_between(x_test, percentiles[0, :], percentiles[1, :], alpha=0.2, color='g', label='90% CI')
plt.legend(fontsize=15)
[Figure: truth, training data, posterior mean prediction, and 90% CI band]
fig, ax = plt.subplots(3, 2, figsize=(16, 6))

# Trace (left) and kernel density estimate (right) for each hyperparameter.
for i, (name, draws) in enumerate(samples.items()):
    ax[i, 0].plot(draws, alpha=0.5)
    ax[i, 0].set_ylabel(name)
    sns.kdeplot(draws, ax=ax[i, 1])
[Figure: trace and KDE panels for the three log-scale kernel hyperparameters]
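The pair plot below visualizes the joint posterior; the pairwise structure can also be checked numerically (a sketch):

stacked = jnp.stack([samples[k] for k in samples.keys()])
print(jnp.corrcoef(stacked))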
# Visualize the joint posterior with pears.
pears.pears(samples, scatter_thin=3, fontsize_labels=12);
[Figure: pears pair plot of the posterior samples]