# Empirical Bayes

To wrap up our discussion on GPs, we talk about **model selection**. Recall that we must pick a kernel for our Gaussian Process, for instance the Squared Exponential:

$$
k(\bx,\bx') = \sigma_f^2\exp\left(\displaystyle-\frac{1}{2\ell^2}\lVert\bx-\bx'\rVert^2\right)
$$(empiricalbayeskernel)

Once the form of the kernel is fixed, model selection reduces to finding suitable values for the hyperparameters: the signal scale $\sigma_f$, the length scale $\ell$ and the noise precision $\beta$.
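As a point of reference for Eq. {eq}`empiricalbayeskernel`, the kernel matrix can be evaluated without explicit loops. The following is a minimal NumPy sketch, not the implementation used for the figures on this page; the function name `se_kernel` and the demo values are assumptions made only for illustration.

```python
import numpy as np

def se_kernel(X1, X2, sigma_f, ell):
    """Squared Exponential kernel between the rows of X1 (n1, d) and X2 (n2, d)."""
    # Pairwise differences via broadcasting, shape (n1, n2, d)
    diff = X1[:, None, :] - X2[None, :, :]
    # Squared Euclidean distances, shape (n1, n2)
    sq_dist = np.sum(diff**2, axis=-1)
    return sigma_f**2 * np.exp(-0.5 * sq_dist / ell**2)

# Hypothetical inputs: five equally spaced 1-D points stored as a column
X_demo = np.linspace(0.0, 2.0 * np.pi, 5).reshape(-1, 1)
K_demo = se_kernel(X_demo, X_demo, sigma_f=1.5, ell=0.8)
print(K_demo.shape)
```

The diagonal of the resulting matrix equals $\sigma_f^2$, since the distance from a point to itself is zero, and the matrix is symmetric by construction.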

In {doc}`../../2-bayesregression/lectures/model_selection`, we did the same for weight-space models by marginalizing $\bw$ to obtain an expression for $p(\mbf{t})$, the **marginal likelihood**, or **evidence** function. We then used *Empirical Bayes* to compute $\alpha$ and $\beta$ that maximized this evidence.

For GPs the operation is exactly the same, but getting to the evidence is much easier now. Recall from Eq. {eq}`gprjoint` that we already have an expression for $p(\mbf{t})$:

$$
p(\mbf{t}) = 
\gauss\left(
\mbf{t}\left\vert\right.\mbf{0},
K\left(\mbf{X},\mbf{X}\right) + \beta^{-1}\mbf{I}
\right)
$$(gprevidence)

Since this is nothing more than a multivariate Gaussian, we can easily compute the log marginal likelihood of our training dataset:

$$
\ln p(\mbf{t}\vert\sigma_f,\ell,\beta) = \displaystyle
-\frac{1}{2}\ln\vert\mbf{K}+\beta^{-1}\mbf{I}\vert
-\frac{1}{2}\mbf{t}^\T\left(\mbf{K}+\beta^{-1}\mbf{I}\right)^{-1}\mbf{t}
-\frac{N}{2}\ln\left(2\pi\right)
$$(gploglikelihood)

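where $N$ is the size of our dataset, $\mbf{K}=K\left(\mbf{X},\mbf{X}\right)$ is the kernel matrix evaluated at the training inputs, and the dependencies on $\sigma_f$ and $\ell$ come from $\mbf{K}$. Because this expression is differentiable with respect to the hyperparameters, we can use a gradient-based optimizer to maximize it.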
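To make Eq. {eq}`gploglikelihood` concrete, here is a hedged NumPy sketch that evaluates it through a Cholesky factorization and checks the result against a direct evaluation of the same three terms. The toy data, the hyperparameter values and the name `log_marginal_likelihood` are assumptions introduced for this illustration only.

```python
import numpy as np

def log_marginal_likelihood(K, t, beta):
    """Evaluate ln p(t) for a zero-mean GP with kernel matrix K and noise precision beta."""
    N = len(t)
    C = K + np.eye(N) / beta              # covariance of the noisy targets
    L = np.linalg.cholesky(C)             # C = L L^T, with L lower triangular
    # alpha = C^{-1} t from two successive solves with the Cholesky factors
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, t))
    # 0.5 * ln|C| equals the sum of the logs of the diagonal of L
    return -np.sum(np.log(np.diag(L))) - 0.5 * t @ alpha - 0.5 * N * np.log(2.0 * np.pi)

# Hypothetical toy data, used only to compare the two evaluations
rng = np.random.default_rng(0)
X_toy = rng.uniform(0.0, 2.0 * np.pi, 5)
t_toy = np.sin(X_toy) + rng.normal(0.0, 0.1, 5)
sigma_f, ell, beta = 1.0, 1.0, 100.0
K_toy = sigma_f**2 * np.exp(-0.5 * (X_toy[:, None] - X_toy[None, :])**2 / ell**2)

lml_chol = log_marginal_likelihood(K_toy, t_toy, beta)

# Direct evaluation of the determinant and quadratic terms
C_toy = K_toy + np.eye(5) / beta
_, logdet = np.linalg.slogdet(C_toy)
lml_direct = (-0.5 * logdet
              - 0.5 * t_toy @ np.linalg.solve(C_toy, t_toy)
              - 0.5 * 5 * np.log(2.0 * np.pi))
print(lml_chol, lml_direct)
```

The Cholesky route avoids forming an explicit inverse and gives the log determinant for free, which is why it is a common choice for this computation.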

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from myst_nb import glue
from matplotlib.lines import Line2D
import torch
from cycler import cycler
import seaborn as sns

# Set the color scheme
sns.set_theme()
colors = ['#0076C2', '#EC6842', '#A50034', '#009B77', '#FFB81C', '#E03C31', '#6CC24A', '#EF60A3', '#0C2340', '#00B8C8', '#6F1D77']
plt.rcParams['axes.prop_cycle'] = cycler(color=colors)

c_mean = c_std = 'C0'

# Use double precision throughout (set_default_tensor_type is deprecated in recent PyTorch releases)
torch.set_default_dtype(torch.float64)


def f (X):
    return np.sin(X)

def kernel (X1,X2,sigf,length):
    n1 = len(X1)
    n2 = len(X2)
        
    K = torch.zeros((n1,n2))
    
    for i in range(n1):
        for j in range(n2):
            # NOTE: sigf enters linearly here, so it plays the role of sigma_f^2 in the kernel equation
            K[i,j] = sigf * torch.exp(-0.5 / length / length * (X1[i]-X2[j])**2.0)
    
    return K

def gp(X,y,logsig,loglen,lognoise,Xhat=None):
    sigf   = torch.exp(logsig)
    length = torch.exp(loglen)
    noise  = torch.exp(lognoise)
        
    # Covariance of the noisy targets: K + beta^{-1} I
    K = kernel(X,X,sigf,length) + 1./noise * torch.eye(len(X))

    L = torch.linalg.cholesky(K)
    alpha = torch.cholesky_solve(torch.reshape(y,(-1,1)),L)
        
    if Xhat is not None:
        Ks = kernel(Xhat,X,sigf,length)
        mean = Ks @ alpha

        # Posterior covariance of the latent function y(x), without observation noise
        cov_f = kernel(Xhat,Xhat,sigf,length) - Ks @ torch.cholesky_solve(Ks.T,L)

        # The 95% interval is for noisy targets t, so the noise variance is added
        std_t = torch.sqrt(torch.diagonal(cov_f) + 1./noise)

        return mean, cov_f, mean[:,0] - 1.96*std_t, mean[:,0] + 1.96*std_t
    else:       
        # Log marginal likelihood; sum(log(diag(L))) equals 0.5 * ln|K + beta^{-1} I|
        lml = -0.5 * torch.inner(y,alpha[:,0]) - torch.sum(torch.log(torch.diagonal(L))) - float(len(y)) / 2. * torch.log(torch.tensor(2.*torch.pi))
        
        return lml
   
torch.manual_seed(10)
np.random.seed(10)

N = 5

beta_true = 100.0

X = np.random.uniform(0,2*np.pi,N)
y = f(X) + np.random.normal(0.,np.sqrt(1./beta_true),N)

X = torch.tensor(X)
y = torch.tensor(y)

lsig = torch.tensor([np.log(100.)],requires_grad=True)
llen = torch.tensor([np.log(100.)],requires_grad=True)
lnoi = torch.tensor([np.log(100.)],requires_grad=True)

bfgs = torch.optim.LBFGS((lsig,llen,lnoi),max_iter=20)

losses = []
sigs   = []
lens   = []
noises = []

def closure():
    bfgs.zero_grad()
    loss = -gp(X,y,lsig,llen,lnoi)
    loss.backward()
    
    losses.append(loss.detach().numpy())
    sigs.append(lsig.clone().detach())
    lens.append(llen.clone().detach())
    noises.append(lnoi.clone().detach())

    return loss

bfgs.step(closure)

Xhat = np.linspace(0,2.*np.pi,100)

print('Final loss',losses[-1],'Sigf',sigs[-1],'Length scale',lens[-1],'Noise',noises[-1])

# print(losses)

n_samples = 10

model = 0

lsig = sigs[model]
llen = lens[model]
lnoi = noises[model]

print('Optimizer progress',float(model/len(losses)),'Log likelihood',losses[model],'sigf',torch.exp(lsig),'length',torch.exp(llen),'noise',torch.exp(lnoi))

fig0, (ax0,ax1) = plt.subplots(1,2,figsize=(10,3),dpi=400)
ax0.set_xlabel('x')
ax0.set_ylabel('y')
ax1.set_xlabel('x')
ax1.set_ylabel('t')
ax0.set_title('Prior')
ax1.set_title('Posterior')

prior_mean = torch.zeros(len(Xhat))
prior_cov = kernel(Xhat,Xhat,torch.exp(lsig),torch.exp(llen))

prior_stdm = prior_mean - 1.96*torch.sqrt(torch.diagonal(prior_cov) + 1./torch.exp(lnoi))
prior_stdp = prior_mean + 1.96*torch.sqrt(torch.diagonal(prior_cov) + 1./torch.exp(lnoi))

prior_samples = np.random.multivariate_normal(prior_mean.detach().numpy(),prior_cov.detach().numpy(),n_samples)

post_mean, post_cov, post_stdm, post_stdp = gp(X,y,lsig,llen,lnoi,Xhat)

post_samples = np.random.multivariate_normal(post_mean[:,0].detach().numpy(),post_cov.detach().numpy(),n_samples)

for s in prior_samples:
    ax0.plot(Xhat,s,linewidth=1,alpha=0.5)

ax0.plot(Xhat,f(Xhat),'k--',label='Ground truth',linewidth=1)

ax0.fill_between(Xhat,prior_stdm.detach().numpy(),prior_stdp.detach().numpy(),
                 alpha=0.15,label='95% conf. interval',color=c_std)
ax0.plot(Xhat,prior_mean.detach().numpy(),label='Mean',color=c_mean,linewidth=2.5)

handles, labels = ax0.get_legend_handles_labels()
line = Line2D([0], [0], color='gray', lw=1,alpha=0.5)
handles.append(line)
labels.append('y(x) samples')
ax0.legend(handles=handles, labels=labels, fontsize=7)

for s in post_samples:
    ax1.plot(Xhat,s,linewidth=1,alpha=0.5)

ax1.plot(Xhat,f(Xhat),'k--',label='Ground truth',linewidth=1)

ax1.fill_between(Xhat,post_stdm.detach().numpy(),post_stdp.detach().numpy(),
                 alpha=0.15,label='95% conf. interval',color=c_std)
ax1.plot(Xhat,post_mean.detach().numpy(),label='Mean',color=c_mean,linewidth=2.5)

ax1.plot(X,y,'k.',markersize=8,label='Observations')


handles, labels = ax1.get_legend_handles_labels()
line = Line2D([0], [0], color='gray', lw=1,alpha=0.5)
handles.append(line)
labels.append('y(x) samples')
ax1.legend(handles=handles, labels=labels, fontsize=7)

glue("fig0",fig0,display=False)

model = 5

lsig = sigs[model]
llen = lens[model]
lnoi = noises[model]

print('Optimizer progress',float(model/len(losses)),'Log likelihood',losses[model],'sigf',torch.exp(lsig),'length',torch.exp(llen),'noise',torch.exp(lnoi))

fig1, (ax0,ax1) = plt.subplots(1,2,figsize=(10,3),dpi=400)
ax0.set_xlabel('x')
ax0.set_ylabel('y')
ax1.set_xlabel('x')
ax1.set_ylabel('t')
ax0.set_title('Prior')
ax1.set_title('Posterior')

prior_mean = torch.zeros(len(Xhat))
prior_cov = kernel(Xhat,Xhat,torch.exp(lsig),torch.exp(llen))

prior_stdm = prior_mean - 1.96*torch.sqrt(torch.diagonal(prior_cov) + 1./torch.exp(lnoi))
prior_stdp = prior_mean + 1.96*torch.sqrt(torch.diagonal(prior_cov) + 1./torch.exp(lnoi))

prior_samples = np.random.multivariate_normal(prior_mean.detach().numpy(),prior_cov.detach().numpy(),n_samples)

post_mean, post_cov, post_stdm, post_stdp = gp(X,y,lsig,llen,lnoi,Xhat)

post_samples = np.random.multivariate_normal(post_mean[:,0].detach().numpy(),post_cov.detach().numpy(),n_samples)

for s in prior_samples:
    ax0.plot(Xhat,s,linewidth=1,alpha=0.5)

ax0.plot(Xhat,f(Xhat),'k--',label='Ground truth',linewidth=1)

ax0.fill_between(Xhat,prior_stdm.detach().numpy(),prior_stdp.detach().numpy(),
                 alpha=0.15,label='95% conf. interval',color=c_std)
ax0.plot(Xhat,prior_mean.detach().numpy(),label='Mean',color=c_mean,linewidth=2.5)

handles, labels = ax0.get_legend_handles_labels()
line = Line2D([0], [0], color='gray', lw=1,alpha=0.5)
handles.append(line)
labels.append('y(x) samples')
ax0.legend(handles=handles, labels=labels, fontsize=7)

for s in post_samples:
    ax1.plot(Xhat,s,linewidth=1,alpha=0.5)

ax1.plot(Xhat,f(Xhat),'k--',label='Ground truth',linewidth=1)

ax1.fill_between(Xhat,post_stdm.detach().numpy(),post_stdp.detach().numpy(),
                 alpha=0.15,label='95% conf. interval',color=c_std)
ax1.plot(Xhat,post_mean.detach().numpy(),label='Mean',color=c_mean,linewidth=2.5)

ax1.plot(X,y,'k.',markersize=8,label='Observations')


handles, labels = ax1.get_legend_handles_labels()
line = Line2D([0], [0], color='gray', lw=1,alpha=0.5)
handles.append(line)
labels.append('y(x) samples')
ax1.legend(handles=handles, labels=labels, fontsize=7)

glue("fig1",fig1,display=False)

model = 10

lsig = sigs[model]
llen = lens[model]
lnoi = noises[model]

print('Optimizer progress',float(model/len(losses)),'Log likelihood',losses[model],'sigf',torch.exp(lsig),'length',torch.exp(llen),'noise',torch.exp(lnoi))

fig2, (ax0,ax1) = plt.subplots(1,2,figsize=(10,3),dpi=400)
ax0.set_xlabel('x')
ax0.set_ylabel('y')
ax1.set_xlabel('x')
ax1.set_ylabel('t')
ax0.set_title('Prior')
ax1.set_title('Posterior')

prior_mean = torch.zeros(len(Xhat))
prior_cov = kernel(Xhat,Xhat,torch.exp(lsig),torch.exp(llen))

prior_stdm = prior_mean - 1.96*torch.sqrt(torch.diagonal(prior_cov) + 1./torch.exp(lnoi))
prior_stdp = prior_mean + 1.96*torch.sqrt(torch.diagonal(prior_cov) + 1./torch.exp(lnoi))

prior_samples = np.random.multivariate_normal(prior_mean.detach().numpy(),prior_cov.detach().numpy(),n_samples)

post_mean, post_cov, post_stdm, post_stdp = gp(X,y,lsig,llen,lnoi,Xhat)

post_samples = np.random.multivariate_normal(post_mean[:,0].detach().numpy(),post_cov.detach().numpy(),n_samples)

for s in prior_samples:
    ax0.plot(Xhat,s,linewidth=1,alpha=0.5)

ax0.plot(Xhat,f(Xhat),'k--',label='Ground truth',linewidth=1)

ax0.fill_between(Xhat,prior_stdm.detach().numpy(),prior_stdp.detach().numpy(),
                 alpha=0.15,label='95% conf. interval',color=c_std)
ax0.plot(Xhat,prior_mean.detach().numpy(),label='Mean',color=c_mean,linewidth=2.5)

handles, labels = ax0.get_legend_handles_labels()
line = Line2D([0], [0], color='gray', lw=1,alpha=0.5)
handles.append(line)
labels.append('y(x) samples')
ax0.legend(handles=handles, labels=labels, fontsize=7)

for s in post_samples:
    ax1.plot(Xhat,s,linewidth=1,alpha=0.5)

ax1.plot(Xhat,f(Xhat),'k--',label='Ground truth',linewidth=1)

ax1.fill_between(Xhat,post_stdm.detach().numpy(),post_stdp.detach().numpy(),
                 alpha=0.15,label='95% conf. interval',color=c_std)
ax1.plot(Xhat,post_mean.detach().numpy(),label='Mean',color=c_mean,linewidth=2.5)

ax1.plot(X,y,'k.',markersize=8,label='Observations')


handles, labels = ax1.get_legend_handles_labels()
line = Line2D([0], [0], color='gray', lw=1,alpha=0.5)
handles.append(line)
labels.append('y(x) samples')
ax1.legend(handles=handles, labels=labels, fontsize=7)

glue("fig2",fig2,display=False)

model = 15

lsig = sigs[model]
llen = lens[model]
lnoi = noises[model]

print('Optimizer progress',float(model/len(losses)),'Log likelihood',losses[model],'sigf',torch.exp(lsig),'length',torch.exp(llen),'noise',torch.exp(lnoi))

fig3, (ax0,ax1) = plt.subplots(1,2,figsize=(10,3),dpi=400)
ax0.set_xlabel('x')
ax0.set_ylabel('y')
ax1.set_xlabel('x')
ax1.set_ylabel('t')
ax0.set_title('Prior')
ax1.set_title('Posterior')

prior_mean = torch.zeros(len(Xhat))
prior_cov = kernel(Xhat,Xhat,torch.exp(lsig),torch.exp(llen))

prior_stdm = prior_mean - 1.96*torch.sqrt(torch.diagonal(prior_cov) + 1./torch.exp(lnoi))
prior_stdp = prior_mean + 1.96*torch.sqrt(torch.diagonal(prior_cov) + 1./torch.exp(lnoi))

prior_samples = np.random.multivariate_normal(prior_mean.detach().numpy(),prior_cov.detach().numpy(),n_samples)

post_mean, post_cov, post_stdm, post_stdp = gp(X,y,lsig,llen,lnoi,Xhat)

post_samples = np.random.multivariate_normal(post_mean[:,0].detach().numpy(),post_cov.detach().numpy(),n_samples)

for s in prior_samples:
    ax0.plot(Xhat,s,linewidth=1,alpha=0.5)

ax0.plot(Xhat,f(Xhat),'k--',label='Ground truth',linewidth=1)

ax0.fill_between(Xhat,prior_stdm.detach().numpy(),prior_stdp.detach().numpy(),
                 alpha=0.15,label='95% conf. interval',color=c_std)
ax0.plot(Xhat,prior_mean.detach().numpy(),label='Mean',color=c_mean,linewidth=2.5)

handles, labels = ax0.get_legend_handles_labels()
line = Line2D([0], [0], color='gray', lw=1,alpha=0.5)
handles.append(line)
labels.append('y(x) samples')
ax0.legend(handles=handles, labels=labels, fontsize=7)

for s in post_samples:
    ax1.plot(Xhat,s,linewidth=1,alpha=0.5)

ax1.plot(Xhat,f(Xhat),'k--',label='Ground truth',linewidth=1)

ax1.fill_between(Xhat,post_stdm.detach().numpy(),post_stdp.detach().numpy(),
                 alpha=0.15,label='95% conf. interval',color=c_std)
ax1.plot(Xhat,post_mean.detach().numpy(),label='Mean',color=c_mean,linewidth=2.5)

ax1.plot(X,y,'k.',markersize=8,label='Observations')


handles, labels = ax1.get_legend_handles_labels()
line = Line2D([0], [0], color='gray', lw=1,alpha=0.5)
handles.append(line)
labels.append('y(x) samples')
ax1.legend(handles=handles, labels=labels, fontsize=7)

glue("fig3",fig3,display=False)

model = len(losses)-1

lsig = sigs[model]
llen = lens[model]
lnoi = noises[model]

print('Optimizer progress',float(model/len(losses)),'Log likelihood',losses[model],'sigf',torch.exp(lsig),'length',torch.exp(llen),'noise',torch.exp(lnoi))

fig4, (ax0,ax1) = plt.subplots(1,2,figsize=(10,3),dpi=400)
ax0.set_xlabel('x')
ax0.set_ylabel('y')
ax1.set_xlabel('x')
ax1.set_ylabel('t')
ax0.set_title('Prior')
ax1.set_title('Posterior')

prior_mean = torch.zeros(len(Xhat))
prior_cov = kernel(Xhat,Xhat,torch.exp(lsig),torch.exp(llen))

prior_stdm = prior_mean - 1.96*torch.sqrt(torch.diagonal(prior_cov) + 1./torch.exp(lnoi))
prior_stdp = prior_mean + 1.96*torch.sqrt(torch.diagonal(prior_cov) + 1./torch.exp(lnoi))

prior_samples = np.random.multivariate_normal(prior_mean.detach().numpy(),prior_cov.detach().numpy(),n_samples)

post_mean, post_cov, post_stdm, post_stdp = gp(X,y,lsig,llen,lnoi,Xhat)

post_samples = np.random.multivariate_normal(post_mean[:,0].detach().numpy(),post_cov.detach().numpy(),n_samples)

for s in prior_samples:
    ax0.plot(Xhat,s,linewidth=1,alpha=0.5)

ax0.plot(Xhat,f(Xhat),'k--',label='Ground truth',linewidth=1)

ax0.fill_between(Xhat,prior_stdm.detach().numpy(),prior_stdp.detach().numpy(),
                 alpha=0.15,label='95% conf. interval',color=c_std)
ax0.plot(Xhat,prior_mean.detach().numpy(),label='Mean',color=c_mean,linewidth=2.5)

handles, labels = ax0.get_legend_handles_labels()
line = Line2D([0], [0], color='gray', lw=1,alpha=0.5)
handles.append(line)
labels.append('y(x) samples')
ax0.legend(handles=handles, labels=labels, fontsize=7)

for s in post_samples:
    ax1.plot(Xhat,s,linewidth=1,alpha=0.5)

ax1.plot(Xhat,f(Xhat),'k--',label='Ground truth',linewidth=1)

ax1.fill_between(Xhat,post_stdm.detach().numpy(),post_stdp.detach().numpy(),
                 alpha=0.15,label='95% conf. interval',color=c_std)
ax1.plot(Xhat,post_mean.detach().numpy(),label='Mean',color=c_mean,linewidth=2.5)

ax1.plot(X,y,'k.',markersize=8,label='Observations')


handles, labels = ax1.get_legend_handles_labels()
line = Line2D([0], [0], color='gray', lw=1,alpha=0.5)
handles.append(line)
labels.append('y(x) samples')
ax1.legend(handles=handles, labels=labels, fontsize=7)

glue("fig4",fig4,display=False)

Click through the tabs below to observe the optimization progress for the regression example with $N=5$ data points we have shown before. The figure captions give the hyperparameter values and the log marginal likelihood as more optimizer iterations are run. Note how the likelihood gradually increases, starting from a severely underfit model, passing through a somewhat overfit model and ending at a well-balanced model. Crucially, we do this without having to set aside a validation dataset, and can therefore **use all of our data** to make predictions.

`````{tab-set}
````{tab-item} Initial guess

```{glue:figure} fig0
:figwidth: 750px

Prior and posterior distributions, with $\sigma_f=100$, $\ell=100$, $\beta=100$, $\ln p(\mbf{t}) = -19.91$
```

````

````{tab-item} 25% optimized

```{glue:figure} fig1
:figwidth: 750px

Prior and posterior distributions, with $\sigma_f=68.9$, $\ell=40.8$, $\beta=8.3$, $\ln p(\mbf{t}) = -6.21$
```

````

````{tab-item} 50% optimized

```{glue:figure} fig2
:figwidth: 750px

Prior and posterior distributions, with $\sigma_f=2.1$, $\ell=0.4282$, $\beta=16.1$, $\ln p(\mbf{t}) = -5.98$
```

````

````{tab-item} 75% optimized

```{glue:figure} fig3
:figwidth: 750px

Prior and posterior distributions, with $\sigma_f=0.4$, $\ell=0.9$, $\beta=16.7575$, $\ln p(\mbf{t}) = -2.98$
```

````

````{tab-item} Final values

```{glue:figure} fig4
:figwidth: 750px

Prior and posterior distributions, with $\sigma_f=0.4$, $\ell=1.2$, $\beta=167.1$, $\ln p(\mbf{t}) = -2.25$
```

````

`````
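The idea of picking hyperparameters by maximizing the evidence can also be sketched without a gradient-based optimizer. The example below swaps in a coarse grid search over $\ln\sigma_f$, $\ln\ell$ and $\ln\beta$ purely for illustration; it is not the L-BFGS procedure used for the figures above. The data are freshly generated with a different random seed and the kernel follows Eq. {eq}`empiricalbayeskernel`, so the numbers it prints are not expected to match the captions.

```python
import itertools
import numpy as np

# Hypothetical dataset in the spirit of the example above: noisy samples of sin(x)
rng = np.random.default_rng(0)
N = 5
beta_true = 100.0
X = rng.uniform(0.0, 2.0 * np.pi, N)
t = np.sin(X) + rng.normal(0.0, np.sqrt(1.0 / beta_true), N)

def log_evidence(log_sigma_f, log_ell, log_beta):
    """Log marginal likelihood as a function of the log-hyperparameters."""
    sigma_f, ell, beta = np.exp([log_sigma_f, log_ell, log_beta])
    K = sigma_f**2 * np.exp(-0.5 * (X[:, None] - X[None, :])**2 / ell**2)
    L = np.linalg.cholesky(K + np.eye(N) / beta)
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, t))
    return -np.sum(np.log(np.diag(L))) - 0.5 * t @ alpha - 0.5 * N * np.log(2.0 * np.pi)

# Coarse grid in log space, from 0.1 to 100 for each hyperparameter
grid = np.linspace(np.log(1e-1), np.log(1e2), 13)
candidates = list(itertools.product(grid, repeat=3))
lmls = np.array([log_evidence(*c) for c in candidates])

# Keep the grid point with the largest evidence
best = int(np.argmax(lmls))
best_lml = lmls[best]
sigma_f_hat, ell_hat, beta_hat = np.exp(candidates[best])

# Evidence at the corner of the grid where all three hyperparameters equal 100
lml_init = log_evidence(grid[-1], grid[-1], grid[-1])

print("evidence at (100, 100, 100):", lml_init)
print("best evidence on the grid:  ", best_lml)
print("sigma_f, ell, beta:", sigma_f_hat, ell_hat, beta_hat)
```

A grid search scales poorly with the number of hyperparameters, which is one reason a gradient-based optimizer is the more practical choice; working in log space, as done here and in the code above, keeps all three hyperparameters positive.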