
# Learning and Model Selection

Up until now we have simply assumed values for the hyperparameters $\alpha$ and $\beta$. Recall that the MAP solution is equivalent to an MLE model with an $L_2$ regularization term weighted by $\lambda={\alpha}/{\beta}$, which suggests that $\alpha$ and $\beta$ control the flexibility of our model.
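As a quick numerical check of this equivalence, the sketch below computes the posterior mean $\mathbf{m}_N$ and the regularized least-squares weights with $\lambda=\alpha/\beta$ and confirms that they coincide. The data, the radial basis design matrix and the values of $\alpha$ and $\beta$ are assumptions chosen only for illustration.

```python
import numpy as np

# Hypothetical toy data: noisy samples of sin(x) on [0, 2*pi]
rng = np.random.default_rng(0)
N, M = 10, 9
x = rng.uniform(0.0, 2.0 * np.pi, N)
t = np.sin(x) + rng.normal(0.0, 0.1, N)

# Assumed design matrix: a bias column followed by M Gaussian radial basis functions
centers = np.linspace(0.0, 2.0 * np.pi, M)
s = 2.0 * np.pi / M
Phi = np.hstack([np.ones((N, 1)), np.exp(-((x[:, None] - centers) ** 2) / (2.0 * s**2))])

alpha, beta = 1.0, 100.0  # assumed hyperparameter values

# Posterior mean: m_N = beta * S_N Phi^T t, with S_N^{-1} = alpha I + beta Phi^T Phi
A = alpha * np.eye(M + 1) + beta * Phi.T @ Phi
m_N = beta * np.linalg.solve(A, Phi.T @ t)

# Regularized least squares with lambda = alpha / beta
lam = alpha / beta
w_ridge = np.linalg.solve(lam * np.eye(M + 1) + Phi.T @ Phi, Phi.T @ t)

print(np.allclose(m_N, w_ridge))  # True: the MAP weights equal the regularized LS weights
```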

In an MLE approach, simply minimizing the training loss with respect to the regularization parameter $\lambda=\alpha/\beta$ would always drive it to zero and result in the most flexible model possible, which is prone to overfitting. In {doc}`frequentist model selection<../../1-regression/decision_theory_interactive>` we solved this dilemma by keeping some data aside as a **validation dataset** and adopting the regularization parameter that minimizes the validation loss.
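For comparison, here is a minimal sketch of this validation-based approach. It assumes regularized least squares with the same kind of radial basis design matrix used later on this page and a hypothetical 20/10 train/validation split; the helper names `design_matrix` and `ridge` are illustrative and not part of the original notebook.

```python
import numpy as np

rng = np.random.default_rng(1)

def design_matrix(x, M=9):
    """Bias column plus M Gaussian RBFs on [0, 2*pi] (assumed basis, mirroring this page)."""
    centers = np.linspace(0.0, 2.0 * np.pi, M)
    s = 2.0 * np.pi / M
    rbf = np.exp(-((x[:, None] - centers) ** 2) / (2.0 * s**2))
    return np.hstack([np.ones((len(x), 1)), rbf])

def ridge(Phi, t, lam):
    """Regularized least-squares weights for regularization parameter lam."""
    return np.linalg.solve(lam * np.eye(Phi.shape[1]) + Phi.T @ Phi, Phi.T @ t)

# Hypothetical data set: 20 points for training, 10 held out for validation
x = rng.uniform(0.0, 2.0 * np.pi, 30)
t = np.sin(x) + rng.normal(0.0, 0.1, x.size)
Phi_tr, t_tr = design_matrix(x[:20]), t[:20]
Phi_val, t_val = design_matrix(x[20:]), t[20:]

lambdas = np.logspace(-4, 2, 40)
train_loss, val_loss = [], []
for lam in lambdas:
    w = ridge(Phi_tr, t_tr, lam)
    train_loss.append(np.mean((t_tr - Phi_tr @ w) ** 2))
    val_loss.append(np.mean((t_val - Phi_val @ w) ** 2))

# The training loss always favours the smallest lambda (the most flexible model),
# so the validation loss is used to pick the regularization strength instead.
best_lam = lambdas[np.argmin(val_loss)]
print(f"lambda with the lowest validation loss: {best_lam:.2e}")
```

The price of this approach is that the 10 validation points never contribute to fitting the weights.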

Bayesian models, on the other hand, tend to be more resistant to overfitting, and, as we will see on this page, there is in principle no need to hold out data when calibrating hyperparameters. We first set up this **model selection** problem in a probabilistic way.

In [None]:
import matplotlib.pyplot as plt
from myst_nb import glue
from matplotlib.lines import Line2D
import numpy as np
import torch
from cycler import cycler
import seaborn as sns

# Set the color scheme
sns.set_theme()
colors = ['#0076C2', '#EC6842', '#A50034', '#009B77', '#FFB81C', '#E03C31', '#6CC24A', '#EF60A3', '#0C2340', '#00B8C8', '#6F1D77']
plt.rcParams['axes.prop_cycle'] = cycler(color=colors)

# Use double precision for all newly created tensors
torch.set_default_dtype(torch.float64)

def f(x):
    # Ground truth function
    return np.sin(x)

def gram(x):
    # Design matrix: a bias column followed by M Gaussian radial basis functions of width s
    centers = torch.linspace(0.,2.*torch.pi,M)
    
    gram = torch.zeros((len(x),M+1))
    gram[:,0] = 1.0

    for i in range(len(x)):
        for j in range(M):
            gram[i,j+1] = torch.exp(-(x[i]-centers[j])**2.0/2.0/s/s)
            
    return gram

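def bayes(alpha,beta,x,y,xval=None):
    # Posterior mean m_post and posterior precision L_post = S_N^{-1};
    # if xval is given, also the predictive mean and variance at those inputs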
    Phi = gram(x)
    
    L_post =  alpha * torch.eye(M+1) + beta * torch.transpose(Phi,0,1) @ Phi
    
    S_post = torch.linalg.inv(L_post)
    
    m_post = beta * S_post @ torch.transpose(Phi,0,1) @ y
                      
    if xval is not None:
        yval = torch.zeros(len(xval))
        yvar = torch.zeros(len(xval))
        
        for i in range(len(xval)):
            centers = torch.linspace(0.,2.*torch.pi,M)
            
            features = torch.zeros(M+1)
            features[0] = 1.0
            for j in range(M):
                features[j+1] = torch.exp(-(xval[i]-centers[j])**2.0/2.0/s/s)
            yval[i] = torch.inner(features,m_post)
            yvar[i] = 1./beta + torch.inner(features,S_post @ features)

        return m_post, L_post, yval, yvar
    else:
        return m_post, L_post

def logmarginal(logalpha,logbeta,x,y):
    # Log marginal likelihood ln p(t|alpha,beta); optimizing over the
    # log-hyperparameters keeps alpha and beta positive
    Phi = gram(x)
    
    alpha = torch.exp(logalpha)
    beta  = torch.exp(logbeta)
    
    m_post, L_post = bayes(alpha,beta,x,y)
    
    # ln|A| of the posterior precision A = S_N^{-1} (more stable than log(det(A)))
    logdetA = torch.logdet(L_post)
    
    diff = y - Phi @ m_post
    
    # Data misfit plus weight penalty, evaluated at the posterior mean
    error = beta / 2 * torch.inner(diff,diff) + alpha / 2 * torch.inner(m_post,m_post)
    
    # w has M+1 entries (a bias plus M radial basis functions)
    lml = (M+1)/2 * torch.log(alpha) + N/2 * torch.log(beta) - error - 1/2 * logdetA - N / 2 * torch.log(torch.tensor(2.*torch.pi))
    
    return lml

def em(alpha0,beta0,x,y):
    # Iterative re-estimation of alpha and beta that maximizes the evidence,
    # based on the effective number of well-determined parameters gam
    alpha = alpha0
    beta = beta0
    
    N = len(x)
    
    Phi = gram(x)
    # Phi^T Phi is symmetric, so its eigenvalues are real
    lam = torch.linalg.eigvalsh(torch.transpose(Phi,0,1) @ Phi)
    
    for i in range(100):
        m_post, L_post = bayes(alpha,beta,x,y)
        
        gam = torch.sum(lam*beta/(alpha+lam*beta))
        
        alpha = gam / torch.inner(m_post,m_post)
        
        diff = y - Phi @ m_post
        
        error = 1 / float(N-gam) * torch.inner(diff,diff)
        beta = 1/error
        
    return alpha, beta

# Fix the random seeds for reproducibility
np.random.seed(100)
torch.manual_seed(0)

N = 10

M = 9
s = 2.*torch.pi/float(M)

beta_true = 100.0

xval = np.linspace(0,2.*np.pi,1000)

x = np.random.uniform(0,2*np.pi,N)
# x = np.linspace(0,2*np.pi,N)
y = f(x) + np.random.normal(0.,np.sqrt(1./beta_true),N)

x = torch.tensor(x)
y = torch.tensor(y)

lalpha = torch.tensor([np.log(1.)],requires_grad=True)
lbeta = torch.tensor([np.log(1.)],requires_grad=True)

bfgs = torch.optim.LBFGS((lalpha,lbeta),max_iter=20)

alphas = []
betas  = []

# L-BFGS calls this closure to re-evaluate the loss; alpha and beta are recorded at every call
def closure():
    bfgs.zero_grad()
    loss = -logmarginal(lalpha,lbeta,x,y)
    loss.backward()
    
    alphas.append(np.exp(lalpha.detach().numpy()[0]))
    betas.append(np.exp(lbeta.detach().numpy()[0]))
    return loss
    
bfgs.step(closure)

print('BFGS result: alpha',torch.exp(lalpha),'beta',torch.exp(lbeta),'Evidence',logmarginal(lalpha,lbeta,x,y))

alpha, beta = em(1e-1,1e-1,x,y)

print('EM result: alpha',alpha,'beta',beta,'Evidence',logmarginal(torch.log(alpha),torch.log(beta),x,y))

fig0, (ax0,ax1,ax2) = plt.subplots(1,3,figsize=(12,3),dpi=400)

ax0.set_xlabel('x')
ax1.set_xlabel('x')
ax2.set_xlabel('x')
ax0.set_ylabel('t')

ax0.set_ylim([-2,2])
ax1.set_ylim([-2,2])
ax2.set_ylim([-2,2])

_,_,y_map,y_var = bayes(alphas[0],betas[0],x,y,xval)

ax0.plot(x,y,'.',label='Observations')
ax0.plot(xval,f(xval),'black',label='Ground truth',linewidth=1)
ax0.plot(xval,y_map,label='Mean (no learning)')
ax0.fill_between(xval,y_map-1.96*np.sqrt(y_var),y_map+1.96*np.sqrt(y_var),
                 alpha=0.3,label='95% conf. interval',color='C1')
ax0.legend(fontsize=7)

_,_,y_map,y_var = bayes(alphas[4],betas[4],x,y,xval)

ax1.plot(x,y,'.',label='Observations')
ax1.plot(xval,f(xval),'black',label='Ground truth',linewidth=1)
ax1.plot(xval,y_map,label='Mean (5 iterations)')
ax1.fill_between(xval,y_map-1.96*np.sqrt(y_var),y_map+1.96*np.sqrt(y_var),
                 alpha=0.3,label='95% conf. interval',color='C1')
ax1.legend(fontsize=7)

_,_,y_map,y_var = bayes(alphas[-1],betas[-1],x,y,xval)

ax2.plot(x,y,'.',label='Observations')
ax2.plot(xval,f(xval),'black',label='Ground truth',linewidth=1)
ax2.plot(xval,y_map,label='Mean (20 iterations)')
ax2.fill_between(xval,y_map-1.96*np.sqrt(y_var),y_map+1.96*np.sqrt(y_var),
                 alpha=0.3,label='95% conf. interval',color='C1')
ax2.legend(fontsize=7)

glue("fig0", fig0, display=False)


## Bayesian model selection

We have seen before that we can draw models with different $\mathbf{w}$ from our posterior distribution. Now imagine a bag of models that differ not only in their weights but also in their **complexity, structure and assumptions**. In practical engineering, the options could be, e.g., simple analytical calculations, a Finite Difference model and a Finite Element model. Let us assume for now that we have a discrete number of model choices $\mathcal{M}_i$.

Because we are working in a Bayesian setting, it is natural to treat the model choice $\mathcal{M}_i$ probabilistically as well. Adopting a prior $p(\mathcal{M}_i)$ for each candidate model, we can use Bayes' Theorem to compute posteriors:

$$
p(\mathcal{M}_i\vert\mathcal{D}) =
\displaystyle
\frac{p(\mathcal{M}_i)p(\mathcal{D}\vert\mathcal{M}_i)}{p(\mathcal{D})}
$$(bayesovermodels)

These posterior probabilities answer the question *given our observed data, which model is most likely to explain what we see?* Once they are computed, we can follow the same approach we used for the {doc}`predictive distribution<linear_models>` $p(\hat{t})$: just as we marginalized over $\mathbf{w}$ there, we can now average the predictions of the different models:

$$
p(\hat{t}\vert\hat{\mathbf{x}},\mathcal{D}) =
\displaystyle\sum_{i=1}^{L}
p(\hat{t}\vert\hat{\mathbf{x}},\mathcal{M}_i,\mathcal{D})
p(\mathcal{M}_i\vert\mathcal{D})
$$(predictiveovermodels)

where averaging becomes a summation because $\mathcal{M}$ is a discrete variable and $L$ is the number of candidate models. The issue is that even if all conditionals $p(\hat{t}\vert\hat{\mathbf{x}},\mathcal{M}_i,\mathcal{D})$ are Gaussian, this average is a **mixture of Gaussians** and is therefore, in general, not Gaussian itself.
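To make Eqs. {eq}`bayesovermodels` and {eq}`predictiveovermodels` concrete, the sketch below uses made-up log evidences, a uniform prior and Gaussian predictive distributions for three hypothetical models at a single test input. It normalizes the posterior model probabilities in log space and evaluates the resulting mixture density; all numbers are assumptions for illustration only.

```python
import numpy as np

# Hypothetical values for three candidate models M_1, M_2, M_3 (illustration only)
log_evidence = np.array([-12.0, -10.5, -11.0])   # ln p(D | M_i)
prior = np.full(3, 1.0 / 3.0)                     # p(M_i), uniform
# Gaussian predictive distributions p(t_hat | x_hat, M_i, D) at a single test input
pred_mean = np.array([0.2, 0.8, 0.5])
pred_var = np.array([0.01, 0.02, 0.05])

# Posterior model probabilities; p(D) only acts as a normalizing constant.
# Subtracting the maximum log value before exponentiating avoids underflow.
log_unnorm = np.log(prior) + log_evidence
weights = np.exp(log_unnorm - log_unnorm.max())
post = weights / weights.sum()

def mixture_pdf(t):
    """Model-averaged predictive density: a posterior-weighted mixture of Gaussians."""
    t = np.atleast_1d(t)[:, None]
    comps = np.exp(-((t - pred_mean) ** 2) / (2 * pred_var)) / np.sqrt(2 * np.pi * pred_var)
    return comps @ post

# First two moments of the mixture (the mixture itself is generally not Gaussian)
mix_mean = post @ pred_mean
mix_var = post @ (pred_var + pred_mean**2) - mix_mean**2

grid = np.linspace(-1.0, 2.0, 3001)
density = mixture_pdf(grid)
```

For these particular numbers the mixture density has two local maxima, so no single Gaussian can represent it, even though `mix_mean` and `mix_var` reproduce its first two moments.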

In practice we often prefer not to deal with this non-Gaussianity. As always, it is possible to take a shortcut, and as we will see next it is a familiar one.

## Empirical Bayes

Going back to Eq. {eq}`bayesovermodels`, the second term in the numerator is the so-called **model evidence**, i.e. how likely the observed data is under the model. It plays the same role as the likelihood function $p(\mathbf{t}\vert\mathbf{w})$ from before (Eq. {eq}`bayes1`). There we could take a shortcut and simply find the $\mathbf{w}$ that maximizes this likelihood (MLE).

So why not do the same here? To compute $p(\mathcal{D}\vert\mathcal{M}_i)$ we need to marginalize over $\mathbf{w}$:

$$
p(\mathcal{D}\vert\mathcal{M}_i)=
\displaystyle\int
p(\mathcal{D}\vert\mathcal{M}_i,\mathbf{w})p(\mathbf{w}\vert\mathcal{M}_i)\,\mathrm{d}\mathbf{w}
$$(modelevidence)

where the only variable we still condition on is the model $\mathcal{M}_i$ itself. We therefore call this the **marginal likelihood**. In the spirit of MLE, the best model choice is the one that maximizes it.
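The integral in Eq. {eq}`modelevidence` can be read as an average of the likelihood over draws of $\mathbf{w}$ from the prior. The sketch below estimates it by plain Monte Carlo for a small, hypothetical linear-Gaussian model with design matrix $\boldsymbol{\Phi}$ and compares the result with the closed-form value, which for this model is the Gaussian density $\mathcal{N}(\mathbf{t}\vert\mathbf{0},\beta^{-1}\mathbf{I}+\alpha^{-1}\boldsymbol{\Phi}\boldsymbol{\Phi}^\mathrm{T})$. Model sizes, hyperparameters and the sample count are assumptions; naive prior sampling becomes very inefficient for larger models, so this is only an illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical small linear-Gaussian model t = Phi w + noise, with assumed alpha and beta
N, M = 4, 2
alpha, beta = 1.0, 4.0                      # prior precision of w, noise precision
Phi = rng.normal(size=(N, M))
w_true = rng.normal(0.0, 1.0 / np.sqrt(alpha), M)
t = Phi @ w_true + rng.normal(0.0, 1.0 / np.sqrt(beta), N)

# Monte Carlo estimate: draw w from the prior p(w | M) and average the likelihood p(D | M, w)
S = 200_000
W = rng.normal(0.0, 1.0 / np.sqrt(alpha), size=(S, M))
resid = t - W @ Phi.T                        # shape (S, N), one row per sampled w
log_lik = 0.5 * N * np.log(beta / (2 * np.pi)) - 0.5 * beta * np.sum(resid**2, axis=1)
# log-mean-exp computed stably by factoring out the largest term
log_evidence_mc = log_lik.max() + np.log(np.mean(np.exp(log_lik - log_lik.max())))

# Exact value for this model: marginally, t ~ N(0, beta^{-1} I + alpha^{-1} Phi Phi^T)
C = np.eye(N) / beta + Phi @ Phi.T / alpha
_, logdet_C = np.linalg.slogdet(C)
log_evidence_exact = -0.5 * (N * np.log(2 * np.pi) + logdet_C + t @ np.linalg.solve(C, t))

print(log_evidence_mc, log_evidence_exact)
```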

More concretely, we can further specify $\mathcal{M}_i$ as models having different values for an arbitrary set of hyperparameters $\mathbf{h}$. We can then formally define our goal:

````{card}
**Empirical Bayes (Type-2 Maximum Likelihood Estimation)**
^^^
Given a **marginal** likelihood function $p(\mathcal{D}\vert\mathbf{h})$ where model parameters $\mathbf{w}$ have been marginalized but hyperparameters $\mathbf{h}$ remain, compute $\mathbf{h}$ that maximizes this function.
````

The crucial point is that, because of the marginalization over $\mathbf{w}$, the most flexible model **usually does not** have the highest marginal likelihood: extra terms appear in $p(\mathcal{D}\vert\mathcal{M}_i)$ that tend to penalize very flexible models. This means we do not necessarily need a validation set to learn $\mathbf{h}$ and can therefore use our complete dataset for learning!
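A brute-force sketch of Empirical Bayes for a radial basis function model like the one on this page: with only two hyperparameters $\mathbf{h}=[\alpha,\beta]$, one can evaluate the log marginal likelihood on a grid and keep the best point, using all data points for learning. The data, basis and grid ranges are assumptions, and the evidence is evaluated through the Gaussian marginal distribution of $\mathbf{t}$ rather than with the notebook's own `logmarginal` function.

```python
import numpy as np

rng = np.random.default_rng(3)

def design_matrix(x, M=9):
    """Bias column plus M Gaussian RBFs on [0, 2*pi] (assumed basis, mirroring this page)."""
    centers = np.linspace(0.0, 2.0 * np.pi, M)
    s = 2.0 * np.pi / M
    rbf = np.exp(-((x[:, None] - centers) ** 2) / (2.0 * s**2))
    return np.hstack([np.ones((len(x), 1)), rbf])

def log_evidence(Phi, t, alpha, beta):
    """ln p(t | alpha, beta), using that t ~ N(0, beta^{-1} I + alpha^{-1} Phi Phi^T)."""
    N = len(t)
    C = np.eye(N) / beta + Phi @ Phi.T / alpha
    _, logdet_C = np.linalg.slogdet(C)
    return -0.5 * (N * np.log(2 * np.pi) + logdet_C + t @ np.linalg.solve(C, t))

# Hypothetical data: all 10 points are used for learning, none are held out
x = rng.uniform(0.0, 2.0 * np.pi, 10)
t = np.sin(x) + rng.normal(0.0, 0.1, x.size)
Phi = design_matrix(x)

# Evaluate the marginal likelihood on a log-spaced grid of (alpha, beta) and take the maximum
alphas = np.logspace(-3, 3, 61)
betas = np.logspace(-1, 4, 51)
grid = np.array([[log_evidence(Phi, t, a, b) for b in betas] for a in alphas])
i, j = np.unravel_index(np.argmax(grid), grid.shape)
print(f"alpha = {alphas[i]:.3g}, beta = {betas[j]:.3g}, ln p(t|h) = {grid[i, j]:.3f}")
```

A grid works for two hyperparameters but scales poorly with their number; gradient-based optimization is the usual alternative.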

```{admonition} Further Reading    
:class: tip    
This subject is also treated in detail in Section 3.4 with a few additional insights.
+++                         
{bdg-danger}`bishop-prml`     
``` 

## Back to basis function models

Coming back to our original setup with $\mathbf{h}=[\alpha,\beta]$, our marginal likelihood is:

$$
p(\mathbf{t}\vert\alpha,\beta) =
\displaystyle\int
p(\mathbf{t}\vert\mathbf{w},\beta)p(\mathbf{w}\vert\alpha)
\,\mathrm{d}\mathbf{w}
$$(bayesbasisfuncsevidence)

where we omit $\mathbf{X}$ for simplicity. Note that this is exactly the denominator of Bayes' Theorem when computing $p(\mathbf{w}\vert\mathbf{t})$. We can therefore reuse the {ref}`standard expressions<bayes-stdexpressions>` from before and compute the evidence as:

$$
\ln p(\mathbf{t}\vert\alpha,\beta) = \displaystyle\frac{N}{2}\ln\beta
-\frac{\beta}{2}\sum_{n=1}^{N}\left[t_n-\mathbf{m}_N^\mathrm{T}\boldsymbol{\phi}(\mathbf{x}_n)\right]^2
-\frac{\alpha}{2}\mathbf{m}_N^\mathrm{T}\mathbf{m}_N
{\color{red}
+ \frac{M}{2}\ln\alpha + \frac{1}{2}\ln\vert\mathbf{S}_N\vert - \frac{N}{2}\ln(2\pi)
}
$$(bayesbasisfuncslogmarginal)

where $M$ is the number of weights in $\mathbf{w}$, $\mathbf{m}_N$ and $\mathbf{S}_N$ are the mean and covariance of the posterior $p(\mathbf{w}\vert\mathbf{t})$, and we have taken the natural logarithm of the evidence. The terms in **red** tend to penalize models that are too flexible. Together they are often referred to as the **Occam factor**, in reference to the famous <a href="https://en.wikipedia.org/wiki/Occam%27s_razor" target="_blank">Occam's Razor</a>.
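The sketch below checks Eq. {eq}`bayesbasisfuncslogmarginal` numerically on a random, hypothetical design matrix $\boldsymbol{\Phi}$. It assembles the evidence from the posterior quantities $\mathbf{m}_N$ and $\mathbf{A}=\mathbf{S}_N^{-1}$ (note that $-\tfrac{1}{2}\ln\vert\mathbf{A}\vert=+\tfrac{1}{2}\ln\vert\mathbf{S}_N\vert$) and compares it with the log density of $\mathcal{N}(\mathbf{t}\vert\mathbf{0},\beta^{-1}\mathbf{I}+\alpha^{-1}\boldsymbol{\Phi}\boldsymbol{\Phi}^\mathrm{T})$. Here $M$ is the number of columns of $\boldsymbol{\Phi}$, and all numerical values are assumptions.

```python
import numpy as np

rng = np.random.default_rng(2)

# Hypothetical design matrix and targets; M is the number of weights (columns of Phi)
N, M = 12, 10
Phi = rng.normal(size=(N, M))
t = rng.normal(size=N)
alpha, beta = 0.5, 20.0   # assumed hyperparameter values

# Posterior precision A = S_N^{-1} and posterior mean m_N
A = alpha * np.eye(M) + beta * Phi.T @ Phi
m_N = beta * np.linalg.solve(A, Phi.T @ t)
_, logdet_A = np.linalg.slogdet(A)

# Log evidence assembled term by term; -0.5 * ln|A| equals +0.5 * ln|S_N|
data_terms = N / 2 * np.log(beta) - beta / 2 * np.sum((t - Phi @ m_N) ** 2) - alpha / 2 * (m_N @ m_N)
occam_terms = M / 2 * np.log(alpha) - 0.5 * logdet_A - N / 2 * np.log(2 * np.pi)
log_ev_formula = data_terms + occam_terms

# The same quantity as the log density of t ~ N(0, beta^{-1} I + alpha^{-1} Phi Phi^T)
C = np.eye(N) / beta + Phi @ Phi.T / alpha
_, logdet_C = np.linalg.slogdet(C)
log_ev_direct = -0.5 * (N * np.log(2 * np.pi) + logdet_C + t @ np.linalg.solve(C, t))

print(np.isclose(log_ev_formula, log_ev_direct))  # True
```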

Note how Eq. {eq}`bayesbasisfuncslogmarginal` is now a function of only $\alpha$ and $\beta$. We can therefore use an optimization algorithm to find the values that maximize our model evidence.
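The figure below is produced with a gradient-based optimizer. A gradient-free alternative, shown here as a sketch, is the fixed-point re-estimation described in Bishop's PRML (Section 3.5): at a maximum of the evidence, $\alpha=\gamma/\mathbf{m}_N^\mathrm{T}\mathbf{m}_N$ and $1/\beta=\frac{1}{N-\gamma}\sum_{n}\left[t_n-\mathbf{m}_N^\mathrm{T}\boldsymbol{\phi}(\mathbf{x}_n)\right]^2$, where $\gamma=\sum_i\lambda_i/(\alpha+\lambda_i)$ and $\lambda_i$ are the eigenvalues of $\beta\boldsymbol{\Phi}^\mathrm{T}\boldsymbol{\Phi}$. The hidden code cell's `em` function applies the same kind of update. The data, basis, starting values and tolerance below are assumptions.

```python
import numpy as np

rng = np.random.default_rng(4)

def design_matrix(x, M=9):
    """Bias column plus M Gaussian RBFs on [0, 2*pi] (assumed basis, mirroring this page)."""
    centers = np.linspace(0.0, 2.0 * np.pi, M)
    s = 2.0 * np.pi / M
    rbf = np.exp(-((x[:, None] - centers) ** 2) / (2.0 * s**2))
    return np.hstack([np.ones((len(x), 1)), rbf])

def posterior_mean(Phi, t, alpha, beta):
    A = alpha * np.eye(Phi.shape[1]) + beta * Phi.T @ Phi
    return beta * np.linalg.solve(A, Phi.T @ t)

def log_evidence(Phi, t, alpha, beta):
    """ln p(t | alpha, beta), using that t ~ N(0, beta^{-1} I + alpha^{-1} Phi Phi^T)."""
    N = len(t)
    C = np.eye(N) / beta + Phi @ Phi.T / alpha
    _, logdet_C = np.linalg.slogdet(C)
    return -0.5 * (N * np.log(2 * np.pi) + logdet_C + t @ np.linalg.solve(C, t))

def fixed_point(Phi, t, alpha=1.0, beta=1.0, n_iter=1000, rtol=1e-12):
    """Re-estimate alpha and beta until they stop changing (a stationary point of the evidence)."""
    N = len(t)
    eig = np.linalg.eigvalsh(Phi.T @ Phi)   # eigenvalues of Phi^T Phi; scaled by beta below
    for _ in range(n_iter):
        m_N = posterior_mean(Phi, t, alpha, beta)
        gamma = np.sum(beta * eig / (alpha + beta * eig))   # effective number of parameters
        alpha_new = gamma / (m_N @ m_N)
        beta_new = (N - gamma) / np.sum((t - Phi @ m_N) ** 2)
        done = abs(alpha_new - alpha) <= rtol * alpha and abs(beta_new - beta) <= rtol * beta
        alpha, beta = alpha_new, beta_new
        if done:
            break
    return alpha, beta

# Hypothetical data set of 25 noisy samples of sin(x)
x = rng.uniform(0.0, 2.0 * np.pi, 25)
t = np.sin(x) + rng.normal(0.0, 0.1, x.size)
Phi = design_matrix(x)

alpha_opt, beta_opt = fixed_point(Phi, t)
print(alpha_opt, beta_opt, log_evidence(Phi, t, alpha_opt, beta_opt))
```

Because the updates are derived by setting the derivatives of the evidence to zero, the gradient of the log evidence with respect to $\ln\alpha$ and $\ln\beta$ should vanish (up to round-off) once the iteration has converged.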

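The example below demonstrates this learning process for a dataset of 10 points fitted with 9 radial basis functions plus a bias term, using PyTorch's L-BFGS optimizer to maximize the log marginal likelihood with respect to $\ln\alpha$ and $\ln\beta$. Each panel shows a different stage of the optimization: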

```{glue:} fig0
```

We can see that the initial values $\alpha=\beta=1$ sacrifice too much flexibility for parsimony; in particular, the assumed noise variance $1/\beta=1$ alone makes the confidence interval very wide. After 20 iterations, the optimizer finds that the model with $\alpha=4.4$ and $\beta=175.9$ gives the highest value of the log marginal likelihood. Crucially, we reached these values without sacrificing any of our data points to serve as validation data.