Solution: Combining Bayesian linear models and neural networks#

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from matplotlib.widgets import Slider

%matplotlib widget
from cycler import cycler
import seaborn as sns

# Set the color scheme
sns.set_theme()
colors = [
    "#0076C2",
    "#EC6842",
    "#A50034",
    "#009B77",
    "#FFB81C",
    "#E03C31",
    "#6CC24A",
    "#EF60A3",
    "#0C2340",
    "#00B8C8",
    "#6F1D77",
]
plt.rcParams["axes.prop_cycle"] = cycler(color=colors)
torch.set_default_dtype(torch.float64)
# Class that normalizes data to follow Normal(0, 1) distribution.
class normUnitvar:
    def __init__(self, fullDataset):
        self.normmean = fullDataset.mean(axis=0)
        self.normstd = fullDataset.std(axis=0)

    def normalize(self, data):
        return (data - self.normmean) / self.normstd

    def denormalize(self, data):
        return data * self.normstd + self.normmean

Introduction#

In this exercise notebook, we will combine a Bayesian linear model and a neural network to create a very flexible Bayesian model.

In a standard Bayesian linear model, the basis functions are pre-defined. Here, we instead use the flexibility of neural networks to find a good set of basis functions from data. Take a look at the proposed model in this schematic figure:

The model is created by first training a neural network to predict \(y\) based on \(x\). Then, keeping the network fixed, we replace the output layer of the network with the Bayesian linear model, which uses the output of the last hidden layer as input.

After creating this model, we will use empirical Bayes, as discussed in the lecture, to fit the hyperparameters.

First, we will create a simple dataset. Our data is generated by f_data, which takes the ground truth and corrupts it with i.i.d. Gaussian noise, imitating a realistic setting where we do not have perfect measurements.

After generating our data, we standardize it and wrap it with a dataloader.

# The true function relating t to x
def f_truth(x):
    # Return a sine
    return torch.sin(x)


# The data is generated from the ground truth with i.i.d. gaussian noise
def f_data(x, beta_true, rng):
    # Generate N noisy observations (1 at each location)
    t = f_truth(x) + torch.normal(
        mean=0, std=math.sqrt(1 / beta_true), size=x.shape, generator=rng
    )

    # Return the observations
    return t
x = torch.arange(0, 10, 0.2).reshape(-1, 1)  # Reshape necessary for network prediction

# We manually remove data to create some gaps in the data.
x = torch.cat((x[:10], x[18:29], x[43:]), dim=0)

rng = torch.manual_seed(1)
beta_true = 50

t = f_data(x, beta_true, rng)

fig, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.scatter(x, t, marker="x")
ax.set_xlabel("x")
ax.set_ylabel("y")
plt.show()
# Create normalizers and normalize the data
x_normalizer = normUnitvar(x)
t_normalizer = normUnitvar(t)

x_norm = x_normalizer.normalize(x)
t_norm = t_normalizer.normalize(t)

# We create a simple dataset with only training data and a single batch.
training_dataset = torch.utils.data.TensorDataset(x_norm, t_norm)

train_loader = DataLoader(training_dataset, batch_size=len(x), shuffle=True)

Exercise: Defining the neural network#

The first step in making the model is to create the neural network. The skeleton of the model is in place, but the following tasks still need to be done:

  • Define a neural network. You can make it any size you want, but it should have at least 1 hidden layer.

  • Create an additional function learned_phi, which uses the same layers as the forward function but returns the output of the last hidden layer, with an additional term for the bias. This output will act as the basis for the Bayesian model.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # ---------------------- student exercise --------------------------------- #
        # Define 2 Linear network layers
        self.hid_n = 10
        self.fc1 = nn.Linear(1, self.hid_n)
        self.fc2 = nn.Linear(self.hid_n, 1)
        # ---------------------- student exercise --------------------------------- #

    def forward(self, x):
        # ---------------------- student exercise --------------------------------- #
        # First network layer + Activation
        x = F.sigmoid(self.fc1(x))
        # Second network layer + Linear Activation
        out = self.fc2(x)
        # ---------------------- student exercise --------------------------------- #
        return out

    def learned_phi(self, x):
        # ---------------------- student exercise --------------------------------- #
        # First network layer + Activation
        x = F.sigmoid(self.fc1(x))
        # No second layer
        # ---------------------- student exercise --------------------------------- #
        # Add bias term        
        phi = torch.cat((torch.ones(x.shape[0], 1), x), dim=1)
        return phi


net = Net()
print(net)

# Define the optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=0.1)
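As a quick sanity check (a minimal, illustrative cell; phi_check is just a throwaway name), the basis returned by learned_phi should have one column per hidden node plus one bias column:

# The learned basis should have shape (N, hid_n + 1) because of the prepended bias column.
phi_check = net.learned_phi(x_norm)
print(phi_check.shape)  # one column per hidden node, plus one for the bias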

Exercise: Training loop#

Your next task is to train the neural network, which can be broken down into the following steps:

  • Implement a neural network training loop. Think about which parts are necessary when optimizing a Bayesian model.

  • Have the function return the trained network and the loss curve.

  • Call your function to train the model.

def train_nn(network, train_loader, epochs, optimizer):
    """
    Function that trains a network
    :param network: torch.nn class with forward function
    :param train_loader: DataLoader object containing all training batches
    :param epochs: Int, number of epochs to train the model for
    :param optimizer: The optimizer used (e.g. ADAM/SGD)
    :return: [Trained network,
              Tensor of MSE losses per epoch]
    """

    # ---------------------- student exercise --------------------------------- #
    train_losses = torch.zeros(epochs)
    for epoch in range(epochs):
        if epoch % 200 == 0:
            print(f"Training epoch {epoch}")
        network.train()
        train_loss = 0
        for i, data in enumerate(train_loader):
            # Prediction
            x, y = data
            pred = network(x)
            # Loss
            loss = F.mse_loss(pred, y)

            train_loss += float(loss)

            optimizer.zero_grad()  # Reset gradients
            loss.backward()  # Backprop
            optimizer.step()  # Update parameters
        train_losses[epoch] = train_loss / len(
            train_loader
        )  # Store loss (independent of number of batches)
    # ---------------------- student exercise --------------------------------- #

    return network, train_losses
# Train the network
# Call your neural network training function, store the network as "net" and train_losses as "losses"
# ---------------------- student exercise --------------------------------- #
num_epochs = 500
net, losses = train_nn(net, train_loader, num_epochs, optimizer)
# ---------------------- student exercise --------------------------------- #

# Plotting the loss curve
fig, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.plot(torch.arange(len(losses)), losses, label="Training")
ax.legend()
ax.set_xlabel("Epochs")
ax.set_ylabel("MSE")
ax.set_title("Loss curve")
plt.show()
# Plot the result
x, t = next(iter(train_loader))  # Get data

# Create prediction array
x_pred = torch.arange(-2, 12, 0.1).reshape(
    -1, 1
)  # Reshape necessary for network prediction

x_pred_norm = x_normalizer.normalize(x_pred)  # normalized data

# Predictions
pred = net(x_pred_norm).detach()  # Make prediction

# Denormalize
x_denorm = x_normalizer.denormalize(x)
t_denorm = t_normalizer.denormalize(t)
pred_denorm = t_normalizer.denormalize(pred)

# Scatter plot for shuffled data
fig, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.scatter(x_denorm, t_denorm, label="True", color="C0", marker="x")
ax.plot(x_pred, pred_denorm, label="Prediction", color="C1")
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.legend()
plt.show()

If you’ve successfully trained your model, you should see a reasonable overlap between the true values and model predictions.

Bayesian Linear Model#

With the neural network trained, we can turn our attention to the Bayesian Linear Model. We will create a class here that computes the mean and covariance of the posterior.

Recall from the lecture that the posterior mean is defined as

\[ \mathbf{m} = \beta\mathbf{S}\boldsymbol{\Phi}^\mathrm{T}\mathbf{t}, \]

and its covariance as

\[ \mathbf{S}^{-1} = \alpha\mathbf{I} + \beta\boldsymbol{\Phi}^\mathrm{T}\boldsymbol{\Phi}. \]

In addition, we create a function which returns the output distribution — again defined by its mean and covariance — for a new input:

\[ p(\hat{t}\vert\hat{\mathbf{x}},\beta) = \displaystyle\mathcal{N}\left( \hat{t}\vert \mathbf{m}^\mathrm{T}\boldsymbol{\phi}(\hat{\mathbf{x}}), \frac{1}{\beta} + \boldsymbol{\phi}(\hat{\mathbf{x}})^\mathrm{T}\mathbf{S}\boldsymbol{\phi}(\hat{\mathbf{x}}) \right). \]

The set of basis functions is passed as an input, allowing the model to handle arbitrary basis functions, such as the ones learned by the neural network defined earlier.

class bayesmodel:
    def __init__(self, basisfunc, x, t, alpha, beta):
        """
        Class for a general Bayesian Linear Model
        :param basisfunc: The basis function to be used
        :param x: Dataset inputs to build the model with
        :param t: Dataset targets to build the model with
        :param alpha: prior precision
        :param beta: noise precision
        """
        with torch.no_grad():
            self.phi = basisfunc
            # Initial values
            self.alpha = alpha
            self.beta = beta

            # Compute m_N & S_N
            self.phi_x = self.phi(x)
            phi_x_t = torch.transpose(self.phi_x, 0, 1)

            self.cov_inverse = self.alpha * torch.eye(
                self.phi_x.shape[1]
            ) + self.beta * torch.matmul(phi_x_t, self.phi_x)
            self.covariance = torch.inverse(self.cov_inverse)  # Bishop 3.54
            self.mean = self.beta * torch.matmul(
                self.covariance, torch.matmul(phi_x_t, t)
            )  # Bishop 3.53

    def pred(self, x_new):
        """
        :param x_new: Inputs for which to compute the mean and std.
        :return: [mean, variance] for all x_new
        """
        with torch.no_grad():
            # Make new predictions
            phi_x_new = self.phi(x_new)
            phi_x_new_t = torch.transpose(phi_x_new, 0, 1)
            pred_m = torch.matmul(phi_x_new, self.mean)

            S_N = (1 / self.beta) * torch.eye(phi_x_new.shape[0]) + torch.matmul(
                phi_x_new, torch.matmul(self.covariance, phi_x_new_t)
            )  # Bishop 3.59

            pred_var = torch.unsqueeze(torch.diag(S_N), 1)

        return pred_m, pred_var
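Because bayesmodel only requires a callable that maps inputs to a feature matrix, any basis can be plugged in, not just the network. As a minimal illustrative sketch (poly_phi and the degree of 3 are arbitrary choices, not part of the exercise), a fixed polynomial basis could be used instead:

# Illustrative alternative basis: columns [1, x, x^2, ..., x^degree]; the ones column acts as the bias.
def poly_phi(x, degree=3):
    return torch.cat([x**d for d in range(degree + 1)], dim=1)


poly_model = bayesmodel(poly_phi, x_norm, t_norm, alpha=1.0, beta=50.0)
poly_mean, poly_var = poly_model.pred(x_normalizer.normalize(x_pred))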

Our Bayesian Linear Model is complete, and we can use it to make predictions. The final task for us is to find which values of \(\alpha\) and \(\beta\) are appropriate. Before we do that, let’s see our model in action assuming that we know these hyperparameters.

Assuming known hyperparameters#

Let’s test our model to see if it can make accurate predictions if we a priori fix the hyperparameters \(\alpha\) and \(\beta\). Try different values for \(\alpha\) and \(\beta\). What happens if you set either to very large or very small values?

# Load data for making bayesian model (same as NN training data)
x, t = next(iter(train_loader))

# To show the outcome, we use fixed values for alpha and beta here:
alpha = 1
beta = 50

# We pass the hidden network layer as basis function
bayesNet = bayesmodel(net.learned_phi, x, t, alpha, beta)

# Get truth
target = f_truth(x_pred)

# Predictions
pred_mean, pred_var = bayesNet.pred(x_pred_norm)

# Denormalize outputs
pred = t_normalizer.denormalize(pred_mean.squeeze())
# The variance scales with the square of the normalization std (it is not shifted by the mean)
variance = pred_var.squeeze() * t_normalizer.normstd**2
t_data = t_normalizer.denormalize(t)

# Plotting
fig, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.plot(x_pred.squeeze(), target, label="True", color="black")
pred_plt = ax.plot(x_pred.squeeze(), pred, label="Prediction", color="C1")
if variance is not None:
    std = torch.sqrt(variance)
    ax.fill_between(
        x_pred.squeeze(),
        pred - 1.96 * std,
        pred + 1.96 * std,
        color="C1",
        alpha=0.3,
        linewidth=0,
        label="95% conf. interval",
    )
ax.scatter(x_normalizer.denormalize(x), t_data, label="Data", color="C0", marker="x")
ax.legend()
plt.title(f"alpha={alpha:.2f}, beta={beta:.2f}")
plt.show()
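To explore the question above more systematically, a minimal sketch (the grid of values is an arbitrary choice) loops over a few fixed settings and reports the average predictive standard deviation in normalized space:

# Scan a small grid of fixed hyperparameters; larger alpha shrinks the weights towards zero,
# larger beta corresponds to assuming less observation noise.
for alpha_try in [0.01, 1.0, 100.0]:
    for beta_try in [1.0, 50.0, 1000.0]:
        model_try = bayesmodel(net.learned_phi, x, t, alpha_try, beta_try)
        _, var_try = model_try.pred(x_pred_norm)
        print(
            f"alpha={alpha_try:8.2f}  beta={beta_try:8.2f}  "
            f"mean predictive std={var_try.sqrt().mean().item():.3f}"
        )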

Dealing with unknown hyperparameters#

In most realistic cases, \(\alpha\) and \(\beta\) are not known beforehand. Ideally, we would introduce priors for \(\alpha\) and \(\beta\), and make predictions by marginalizing with respect to these hyperparameters as well as with respect to the parameters \(\mathbf{w}\). However, the complete marginalization over all of these variables is analytically intractable. As discussed in the lecture, we can use an approximation, where we first integrate over the parameters \(\mathbf{w}\), and then set the hyperparameters \(\alpha\) and \(\beta\) by maximizing the marginal likelihood. This is known as empirical Bayes or the evidence approximation.

Exercise:#

As discussed in the lecture, the marginal likelihood can be computed as:

\[ \ln p(\mathbf{t}\vert\alpha,\beta) = \displaystyle\frac{N}{2}\ln\beta -\frac{\beta}{2}\sum_{n=1}^{N}\left[t_n-\mathbf{m}_N^\mathrm{T}\boldsymbol{\phi}(\mathbf{x}_n)\right]^2 -\frac{\alpha}{2}\mathbf{m}_N^\mathrm{T}\mathbf{m}_N + \frac{M}{2}\ln\alpha - \frac{1}{2}\ln\vert\mathbf{S}_N^{-1}\vert - \frac{N}{2}\ln(2\pi), \]

where \(N\) is the number of samples, \(M\) the number of basis functions, and \(\mathbf{m}_N\) and \(\mathbf{S}_N\) come from the posterior of the weights.

It is your task to finish the implementation of this function below.

(Note: M = net.hid_n + 1 may raise an error if you named things differently. Change hid_n to the name of your own variable for the number of hidden nodes in the Net class from the first exercise.)

def logmarginal(alpha, beta, x, t):
    # Create a new model
    model = bayesmodel(net.learned_phi, x, t, alpha, beta)

    # We need to compute the determinant of the matrix of second derivatives of the error function (the Hessian) (see Bishop 3.81).
    # This is equivalent to S_N^{-1}:
    A = model.cov_inverse
    detA = torch.linalg.det(A)

    M = net.hid_n + 1  # Number of 'basis functions' (hidden nodes + bias)
    N = x.shape[0]  # Number of samples

    # Compute the error
    diff = (t - model.phi_x @ model.mean).flatten()

    model_m = model.mean.flatten()

    # ---------------------- student exercise --------------------------------- #
    error = beta / 2 * torch.inner(diff, diff) + alpha / 2 * torch.inner(
        model_m, model_m
    )  # Bishop 3.79

    log_marg_like = (
        M / 2 * torch.log(alpha)
        + N / 2 * torch.log(beta)
        - error
        - 1 / 2 * torch.log(detA)
        - N / 2 * torch.log(torch.tensor(2.0 * torch.pi))
    )
    # ---------------------- student exercise --------------------------------- #

    return log_marg_like
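As a quick check that the implementation runs, it can be evaluated once at the fixed values used earlier (alpha and beta are passed as tensors here because torch.log is applied to them):

# Evaluate the log marginal likelihood at the fixed hyperparameters from before.
print(logmarginal(torch.tensor(1.0), torch.tensor(50.0), x, t))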

With the marginal likelihood implemented, we can optimize our hyperparameters. We take advantage of automatic differentiation, allowing us to use a gradient-based optimizer (L-BFGS), which finds the values of \(\alpha\) and \(\beta\) that maximize the model evidence. Note that we optimize \(\log\alpha\) and \(\log\beta\) to ensure that the hyperparameters remain positive.

Finally, we plot the results for each iteration in an interactive plot.

# Initialize logalpha and logbeta, track the gradients.
logalpha = torch.tensor([0.0], requires_grad=True)  # Log(1) = 0
logbeta = torch.tensor([0.0], requires_grad=True)

# We use the bfgs optimizer.
bfgs = torch.optim.LBFGS((logalpha, logbeta), max_iter=20)

alphas = []
betas = []


def update():  # This closure is called by the optimizer; it cannot take arguments, so we use the global variables.
    bfgs.zero_grad()
    loss = -logmarginal(torch.exp(logalpha), torch.exp(logbeta), x, t)
    loss.backward()

    alphas.append(torch.exp(logalpha).detach()[0])
    betas.append(torch.exp(logbeta).detach()[0])
    return loss


# We perform a single step. In this step, multiple iterations of the optimizer are performed.
bfgs.step(update)

print("BFGS optimal result: alpha", alphas[-1], "beta", betas[-1])
# Initialize a new model with the alpha and beta from the first BFGS iteration
bayesNet = bayesmodel(net.learned_phi, x, t, alphas[0], betas[0])
# Predictions
pred_mean, pred_var = bayesNet.pred(x_pred_norm)

fig, ax = plt.subplots(1, 1, figsize=(6, 4))
fig.subplots_adjust(bottom=0.25)

# Denormalize outputs
pred = t_normalizer.denormalize(pred_mean.squeeze())
# The variance scales with the square of the normalization std (it is not shifted by the mean)
variance = pred_var.squeeze() * t_normalizer.normstd**2
std = torch.sqrt(variance)
t_data = t_normalizer.denormalize(t)

# Plotting
(truth_l,) = ax.plot(x_pred.squeeze(), target, label="True", color="black")

(pred_l,) = ax.plot(x_pred.squeeze(), pred, label="Prediction", color="C1")
var_l = ax.fill_between(
    x_pred.squeeze(),
    pred - 1.96 * std,
    pred + 1.96 * std,
    color="C1",
    alpha=0.3,
    linewidth=0,
    label="95% conf. interval",
)

ax.scatter(x_normalizer.denormalize(x), t_data, label="Data", color="C0", marker="x")
ax.legend(loc="lower left")
ax.set_title(f"Iteration {0}, alpha={alphas[0].item():.2f}, beta={betas[0].item():.2f}")


def update_plot(val):
    val = int(val)
    # Initialize a new model with the alpha and beta
    bayesNet = bayesmodel(net.learned_phi, x, t, alphas[val], betas[val])

    # Predictions
    pred_mean, pred_var = bayesNet.pred(x_pred_norm)

    # Denormalize outputs
    pred = t_normalizer.denormalize(pred_mean.squeeze())
    # The variance scales with the square of the normalization std (it is not shifted by the mean)
    variance = pred_var.squeeze() * t_normalizer.normstd**2
    std = torch.sqrt(variance)

    # Adjust plot
    pred_l.set_ydata(pred)

    global var_l
    var_l.remove()
    var_l = ax.fill_between(
        x_pred.squeeze(),
        pred - 1.96 * std,
        pred + 1.96 * std,
        color="C1",
        alpha=0.3,
        linewidth=0,
        label="95% conf. interval",
    )

    ax.set_title(
        f"Iteration {val}, alpha={alphas[val].item():.2f}, beta={betas[val].item():.2f}"
    )
    ax.set_ylim([-2, 2])

    fig.canvas.draw_idle()


# Make a horizontal slider to step through the BFGS iterations.
axiter = fig.add_axes([0.25, 0.1, 0.65, 0.03])
it_slider = Slider(
    ax=axiter,
    label="BFGS iteration",
    valmin=0,
    valmax=len(betas) - 1,
    valinit=0,
    valfmt="%0.0f",
)

# Register the update function with the slider
it_slider.on_changed(update_plot)
plt.show()

Based on the outputs for the various BFGS iterations, how do \(\alpha\) and \(\beta\) influence the mean and variance of the posterior predictive distribution?
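To connect this question to the recorded values, a small sketch plots how \(\alpha\) and \(\beta\) evolve over the BFGS iterations, using the alphas and betas lists filled in the closure above:

# Plot the recorded hyperparameter values over the BFGS iterations.
fig, ax = plt.subplots(1, 1, figsize=(6, 4))
ax.plot(torch.stack(alphas), label="alpha")
ax.plot(torch.stack(betas), label="beta")
ax.set_xlabel("BFGS iteration")
ax.set_ylabel("Value")
ax.set_yscale("log")
ax.legend()
plt.show()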

Reflection:#

With the code complete, it takes little work to try different networks. Try different network sizes and see how this affects the results.
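For example, a minimal sketch (WideNet and the hidden size of 30 are arbitrary, illustrative choices) retrains a wider variant and rebuilds the Bayesian output layer on top of it, reusing train_nn and bayesmodel from above:

# Illustrative wider variant of the network from the first exercise.
class WideNet(nn.Module):
    def __init__(self, hid_n=30):
        super().__init__()
        self.hid_n = hid_n
        self.fc1 = nn.Linear(1, hid_n)
        self.fc2 = nn.Linear(hid_n, 1)

    def forward(self, x):
        return self.fc2(torch.sigmoid(self.fc1(x)))

    def learned_phi(self, x):
        # Hidden-layer output with a prepended bias column, as in Net
        h = torch.sigmoid(self.fc1(x))
        return torch.cat((torch.ones(h.shape[0], 1), h), dim=1)


wide_net = WideNet()
wide_optimizer = torch.optim.Adam(wide_net.parameters(), lr=0.1)
wide_net, _ = train_nn(wide_net, train_loader, 500, wide_optimizer)

# Rebuild the Bayesian linear model on the new learned basis, keeping the previously
# optimized alpha and beta (ideally these would be re-optimized for the new basis).
wide_bayes = bayesmodel(wide_net.learned_phi, x, t, alphas[-1], betas[-1])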

The main takeaway from this exercise should be a good conceptual understanding of how this model works. Answering the following questions can help you develop this understanding:

  • What happens if you use a neural network without nonlinear activation functions as a basis?

  • What happens if you use a very flexible neural network?

  • We did not use a validation set here. Why are we still not overfitting?