Linear Basis Function Models#
# pip install packages that are not in Pyodide
%pip install ipympl==0.9.3
%pip install seaborn==0.12.2
# Import the necessary libraries
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from sklearn.neighbors import KNeighborsRegressor
from sklearn.preprocessing import StandardScaler
from mude_tools import magicplotter
from cycler import cycler
import seaborn as sns
%matplotlib widget
# Set the color scheme
sns.set_theme()
colors = [
"#0076C2",
"#EC6842",
"#A50034",
"#009B77",
"#FFB81C",
"#E03C31",
"#6CC24A",
"#EF60A3",
"#0C2340",
"#00B8C8",
"#6F1D77",
]
plt.rcParams["axes.prop_cycle"] = cycler(color=colors)
Introduction#
So far, we have been using a non-parametric model with k-nearest neighbors, meaning we needed access to the whole training dataset for each prediction. We will now focus on parametric models, namely linear models with basis functions. Parametric models are defined by a finite set of parameters calibrated in a training step. All we need for a prediction then are the parameter values. There is no longer a need to carry the whole dataset with us; the information used to make predictions is encoded in the model parameters. Once again, we will employ the simple sine function to demonstrate the concepts presented in this page.
# The true function relating t to x
def f_truth(x, freq=1, **kwargs):
    # Return a sine with the given frequency
    return np.sin(x * freq)
# The data generation function
def f_data(epsilon=0.7, N=100, **kwargs):
# Apply a seed if one is given
if "seed" in kwargs:
np.random.seed(kwargs["seed"])
# Get the minimum and maximum
xmin = kwargs.get("xmin", 0)
xmax = kwargs.get("xmax", 2 * np.pi)
# Generate N evenly spaced observation locations
x = np.linspace(xmin, xmax, N)
# Generate N noisy observations (1 at each location)
t = f_truth(x, **kwargs) + np.random.normal(0, epsilon, N)
# Return both the locations and the observations
return x, t
# Get the observed data
x, t = f_data()
# Plot the data and the ground truth
fig, ax = plt.subplots(figsize=(8, 4.5))
fig.canvas.toolbar_visible = False
ax.set_position([0.2, 0.1, 0.7, 0.8])
plt.plot(x, f_truth(x), "k-", label=r"Ground truth $f(x)$")
plt.plot(x, t, "x", label=r"Noisy data $(x,t)$")
plt.xlabel("x")
plt.ylabel("t")
plt.legend()
plt.show()
Linear model#
The key idea behind linear regression models is that they are linear in their parameters \(\mathbf{w}\). They might be linear in their inputs as well, although this does not necessarily need to be the case, as we will see later on in this page. The simplest approach is to model our target function \(y(x)\) as a linear combination of the coordinates \(x\):

\[
y(x, \mathbf{w}) = w_0 + w_1 x.
\]
In the one-dimensional case, this is equivalent to fitting a straight line through our datapoints. The parameter \(w_0\), also referred to as the bias (not to be confused with the model bias from the previous page), determines the intercept, and \(w_1\) determines the slope. The introduction of a dummy input \(x_0 = 1\) allows us to write the model in a more concise way:

\[
y(\mathbf{x}, \mathbf{w}) = w_0 x_0 + w_1 x_1 = \mathbf{w}^T \mathbf{x}, \qquad \mathbf{x} = [1, x]^T.
\]
We will use the least-squares error function from the previous page to fit our model, but we will first show how this choice is motivated by a maximum likelihood approach.
Maximum Likelihood Estimation#
Oftentimes, it is assumed that the target \(t\) is given by a deterministic function \(y(\mathbf{x}, \mathbf{w})\) with additive Gaussian noise, so that

\[
t = y(\mathbf{x}, \mathbf{w}) + \epsilon, \qquad \epsilon \sim \mathcal{N}(0, \beta^{-1}),
\]

with precision \(\beta\) (which is defined as \(1/\sigma^2\)). Note that this assumption is only justified in the case of a unimodal conditional distribution for \(t\). It can have a strong influence on the model accuracy, and its validity must therefore be assessed carefully.
As seen in the previous page, for the square loss function, the optimal prediction for some value \(\mathbf{x}\) is given by the conditional mean of the target variable \(t\). In this case, this conditional mean is given by:

\[
\mathbb{E}[t | \mathbf{x}] = \int t \, p(t | \mathbf{x}) \, \mathrm{d}t = y(\mathbf{x}, \mathbf{w}).
\]
Given a new input \(\mathbf{x}\), the distribution of \(t\) that follows from our model is

\[
p(t | \mathbf{x}, \mathbf{w}, \beta) = \mathcal{N}\left(t \,|\, y(\mathbf{x}, \mathbf{w}), \beta^{-1}\right).
\]
Consider now a dataset \(\mathcal{D}\) consisting of inputs \(\mathbf{X} = \{ \mathbf{x}_1, \dots, \mathbf{x}_N \}\) and targets \(\mathbf{t} = \{ t_1, \dots, t_N \}\). Assuming our datapoints are drawn independently from the same distribution (i.i.d. assumption), the likelihood of drawing this dataset from our model is

\[
p(\mathbf{t} | \mathbf{X}, \mathbf{w}, \beta) = \prod_{n=1}^{N} \mathcal{N}\left(t_n \,|\, y(\mathbf{x}_n, \mathbf{w}), \beta^{-1}\right),
\]

also referred to as the likelihood function. Taking the logarithm and expanding the classic expression for a multivariate Gaussian distribution gives:

\[
\ln p(\mathbf{t} | \mathbf{X}, \mathbf{w}, \beta) = \frac{N}{2} \ln \beta - \frac{N}{2} \ln (2 \pi) - \frac{\beta}{2} \sum_{n=1}^{N} \left( t_n - y(\mathbf{x}_n, \mathbf{w}) \right)^2,
\]

where we can identify our square-error loss function in the last term. Note that the first two terms are constant for a given dataset and have no influence on the parameter setting \(\bar{\mathbf{w}}\) that maximizes the likelihood. These optimal parameter values \(\bar{\mathbf{w}}\) can be obtained by setting the gradient of our loss function w.r.t. \(\mathbf{w}\) to zero and solving for \(\mathbf{w}\).
It is convenient to concatenate all inputs into a design matrix \(\mathbf{X} = [\mathbf{x}_1^T, \dots, \mathbf{x}_N^T]^T\). Solving for \(\mathbf{w}\) gives

\[
\bar{\mathbf{w}} = \left( \mathbf{X}^T \mathbf{X} \right)^{-1} \mathbf{X}^T \mathbf{t},
\]
which is the classical expression for a least-squares solution you have by now seen many times during the course.
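To make this result concrete, here is a minimal sketch (our own addition, reusing the f_data generator defined above) that computes the maximum likelihood weights of the straight-line model via the normal equations, cross-checks them against numpy's least-squares routine, and estimates the noise precision \(\beta\) from the residuals:

# Minimal sketch: maximum likelihood weights for the straight-line model,
# obtained from the normal equations and cross-checked with np.linalg.lstsq
x_lin, t_lin = f_data(N=100, seed=0)
X_lin = np.column_stack((np.ones_like(x_lin), x_lin))  # design matrix with dummy input x_0 = 1
w_ml = np.linalg.solve(X_lin.T @ X_lin, X_lin.T @ t_lin)
w_ref, *_ = np.linalg.lstsq(X_lin, t_lin, rcond=None)
print(w_ml, w_ref)  # the two solutions should agree up to round-off
# The ML estimate of the precision follows from the mean squared residual
beta_ml = 1.0 / np.mean((t_lin - X_lin @ w_ml) ** 2)
print("estimated precision beta:", beta_ml)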
Data normalization#
Before we move on to fitting the model, we normalize our data. This step is recommended for most machine learning techniques and is often even necessary: a non-normalized dataset almost always leads to a numerically more challenging optimization problem. Part of model selection is to ensure that our basis functions show the desired behavior in the relevant parts of the domain. We center and rescale our data, a process referred to as standardization, so that we operate in the vicinity of the origin only. We use the StandardScaler class from the sklearn.preprocessing library to carry out the standardization. The standardized dataset \(\hat{\mathcal{D}} = ( \hat{x}, \hat{t} )\) is obtained by subtracting the sample mean \(\mu\) and dividing by the sample standard deviation \(\sigma\) of the data:

\[
\hat{x} = \frac{x - \mu_x}{\sigma_x}, \qquad \hat{t} = \frac{t - \mu_t}{\sigma_t}.
\]
Take a look below at the standardized and unstandardized data. Note that the standardization of the target \(t\) has only a marginal effect, as the sine function is already centered at 0 and has a standard deviation close to 1. A 4-by-4 square has been added to indicate the region from \(-2 \hat{\sigma}\) to \(2 \hat{\sigma}\). As you can see, all input and output variables fall roughly in this interval, and this property allows for more stability when applying numerical solvers to the problem.
Note that there is no strictly correct way to shift and scale input data. Depending on the distribution of the data, a min-max scaling or a quantile scaling might lead to a better numerical setup. The dataset’s structure needs to be assessed carefully to make an informed decision on normalization.
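As a quick sanity check (our own sketch, not part of the original workflow), the StandardScaler transform should reproduce the formula above when it is applied by hand:

# Sketch: verify that StandardScaler matches (x - mu) / sigma computed by hand
x_chk, _ = f_data(N=100, seed=0)
x_scaled = StandardScaler().fit_transform(x_chk[:, None]).ravel()
x_manual = (x_chk - x_chk.mean()) / x_chk.std()  # sample mean and (biased) standard deviation
print(np.allclose(x_scaled, x_manual))  # expected: True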
# Generate data, instantiate the scalers, and fit-transform
np.random.seed(0)
x, t = f_data(N=100)
xscaler, tscaler = StandardScaler(), StandardScaler()
x_norm, t_norm = xscaler.fit_transform(x[:, None]), tscaler.fit_transform(t[:, None])
# plot
fig, ax = plt.subplots(figsize=(8, 4.5))
fig.canvas.toolbar_visible = False
ax.set_position([0.2, 0.1, 0.7, 0.8])
ax.plot(x, t, "x", label="unnormalized data")
ax.plot(x_norm, t_norm, "x", label="data after normalization")
# Create a Rectangle patch
rect = patches.Rectangle((-2, -2), 4, 4, linewidth=1.0, edgecolor="k", facecolor="none")
# Add the patch to the Axes
ax.add_patch(rect)
ax.set_aspect("equal", "datalim")
plt.legend(loc="upper right")
plt.show()
With that out of the way, let us now define a few tools we need to state, solve, and visualize our problem.
# Function for the linear basis functions
def LinearBasis(x, **kwargs):
    """
    Represents a 1D linear basis: two basis functions, the dummy input
    phi_0(x) = 1 and the identity phi_1(x) = x.
    """
    x = x.reshape(-1, 1)
    return np.hstack((np.ones_like(x), x))
# Let's test our implementation
print("Design matrix X given by:\n\n", LinearBasis(np.arange(0, 5)))
# Define a function that makes a prediction at the given locations x_pred,
# based on the given (x, t) data. Note that the prediction locations do not
# need to coincide with the locations where we observed our data.
def predict(x, t, x_pred, basis, normalize=True, **kwargs):
    # Reshape to column vectors if necessary for the scalers
    x = x[:, None] if len(x.shape) == 1 else x
    t = t[:, None] if len(t.shape) == 1 else t
    x_pred = x_pred[:, None] if len(x_pred.shape) == 1 else x_pred
    # Normalize the data (you will see why this matters further below)
    xscaler, tscaler = StandardScaler(), StandardScaler()
    if normalize:
        x_sc, t_sc = xscaler.fit_transform(x), tscaler.fit_transform(t)
        x_pred_sc = xscaler.transform(x_pred)
    else:
        x_sc, t_sc = x, t
        x_pred_sc = x_pred
    # Assemble the design matrix from the basis functions
    Phi = basis(x_sc.reshape(-1), **kwargs)
    t_sc = t_sc.reshape(-1)
    # Get the coefficient vector by solving the normal equations
    w = np.linalg.solve(Phi.T @ Phi, Phi.T @ t_sc)
    # Make a prediction at the prediction locations
    Phi_pred = basis(x_pred_sc.reshape(-1), **kwargs)
    t_pred = Phi_pred @ w
    # Map the prediction back to the original (unnormalized) scale
    if normalize:
        return tscaler.inverse_transform(t_pred[:, None]).reshape(-1)
    else:
        return t_pred.reshape(-1)
Note that we are not inverting \(\mathbf{X}^T \mathbf{X}\), which is extremely expensive for large amounts of data. A more efficient way to obtain \(\mathbf{w}\) is to solve \(\mathbf{X}^T \mathbf{X} \mathbf{w} = \mathbf{X}^T \mathbf{t}\). It is important to note that this system can only be solved if \(\mathbf{X}\) has full column rank.
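As an illustration (our own sketch, using the LinearBasis function and data generator defined above), both routes give the same weights here, but the solve-based route avoids forming an explicit inverse, which is more expensive and numerically less stable:

# Sketch: solving the normal equations vs. forming the explicit inverse
x_chk2, t_chk2 = f_data(N=50, seed=1)
Phi_chk = LinearBasis(x_chk2)
print("full column rank:", np.linalg.matrix_rank(Phi_chk) == Phi_chk.shape[1])
w_solve = np.linalg.solve(Phi_chk.T @ Phi_chk, Phi_chk.T @ t_chk2)
w_inv = np.linalg.inv(Phi_chk.T @ Phi_chk) @ Phi_chk.T @ t_chk2
print(np.allclose(w_solve, w_inv))  # expected: True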
# Let's run our model with linear basis functions and plot the results
x_pred = np.linspace(-1, 2 * np.pi + 1, 1000)
plot = magicplotter(
f_data,
f_truth,
predict,
x_pred,
basis=LinearBasis,
pred_label="Prediction $y(x)$",
height=4.5,
)
plot.fig.canvas.toolbar_visible = False
plot.show()
It is clear from the plot that a linear model with linear features lacks the flexibility to fit the data well. A bias-variance decomposition analysis for this model would show that it has little variance but a strong bias. We now consider nonlinear functions of the input \(x\) as features/regressors to increase the flexibility of our linear model. A common approach is to use a set of polynomial basis functions,

\[
\phi_j(x) = x^j,
\]

but numerous other choices are possible. The full formulation for a model with \(M\) polynomial basis functions is thus

\[
y(x, \mathbf{w}) = \sum_{j=0}^{M-1} w_j \phi_j(x) = w_0 + w_1 x + w_2 x^2 + \dots + w_{M-1} x^{M-1},
\]

which shows how the model is still linear w.r.t. \(\mathbf{w}\), even though it is no longer linear in the inputs. The design matrix for this more general case reads

\[
\boldsymbol{\Phi} =
\begin{bmatrix}
\phi_0(x_1) & \phi_1(x_1) & \cdots & \phi_{M-1}(x_1) \\
\phi_0(x_2) & \phi_1(x_2) & \cdots & \phi_{M-1}(x_2) \\
\vdots & \vdots & \ddots & \vdots \\
\phi_0(x_N) & \phi_1(x_N) & \cdots & \phi_{M-1}(x_N)
\end{bmatrix}.
\]

As was the case for \(\mathbf{X}\), we need to ensure that \(\boldsymbol{\Phi}\) has full column rank. This is not always the case, for example when we have more basis functions than data points, or when our basis functions are not linearly independent.
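For instance (a small sketch of our own, using numpy's np.vander as a stand-in polynomial design matrix), with more basis functions (columns) than data points (rows) the design matrix cannot have full column rank:

# Sketch: more basis functions than data points gives a rank-deficient design matrix
x_few = np.linspace(-1, 1, 5)                       # only 5 data points
Phi_poly = np.vander(x_few, N=8, increasing=True)   # 8 polynomial basis functions x^0, ..., x^7
print(Phi_poly.shape)                               # (5, 8)
print(np.linalg.matrix_rank(Phi_poly))              # at most 5, i.e. less than 8 columns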
Let’s implement a PolynomialBasis function.
# Here is a function for the polynomial basis functions:
def PolynomialBasis(x, degree, **kwargs):
"""
A function that computes polynomial basis functions.
Arguments:
x - The datapoints
degree - The degree of the polynomial
"""
return np.array([x**i for i in range(degree + 1)]).transpose()
# Let's test our implementation and visualize the polynomial basis functions
degree = 5
x_test = np.linspace(-1, 1, 100)
Phi_p_test = PolynomialBasis(x_test, degree=degree)[:, 1::]
# Plot the data and the ground truth
fig, ax = plt.subplots(figsize=(8, 4.5))
fig.canvas.toolbar_visible = False
ax.set_position([0.2, 0.1, 0.7, 0.8])
for i, row in enumerate(Phi_p_test.transpose()):
plt.plot(x_test, row, label=r"$\phi_{}(x)$".format(i + 1))
plt.xlabel(r"$x$")
plt.ylabel(r"$\phi(x)$")
plt.legend()
plt.show()
We obtain the linear model with nonlinear basis functions by replacing the coordinate vector \(\mathbf{x}\) with the feature vector \(\boldsymbol{\phi}(x)\):

\[
y(x, \mathbf{w}) = \mathbf{w}^T \boldsymbol{\phi}(x).
\]

The solution procedure remains the same, and we can solve for \(\bar{\mathbf{w}}\) directly:

\[
\bar{\mathbf{w}} = \left( \boldsymbol{\Phi}^T \boldsymbol{\Phi} \right)^{-1} \boldsymbol{\Phi}^T \mathbf{t}.
\]
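As a small worked example (our own sketch, reusing the PolynomialBasis helper and the standardized data x_norm, t_norm from above), the weights of a cubic model follow directly from this expression:

# Sketch: weights of a cubic polynomial model from the normal equations
Phi_cubic = PolynomialBasis(x_norm.ravel(), degree=3)  # N x 4 design matrix
w_bar = np.linalg.solve(Phi_cubic.T @ Phi_cubic, Phi_cubic.T @ t_norm.ravel())
print("fitted weights:", w_bar)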
Let’s take a look at the linear model with polynomial regressors.
# Plot the resulting predictions
fig, ax = plt.subplots(2, 2, figsize=(9, 6), sharex="all", sharey="all")
fig.canvas.toolbar_visible = False
plt.suptitle(r"generalized linear regression for polynomials of degree $p$")
# Plot for degree=2
magicplotter(
f_data,
f_truth,
predict,
x_pred,
basis=PolynomialBasis,
degree=2,
ax=ax[0][0],
hide_legend=True,
pred_label=r"Prediction $y(x)$",
title=r"$degree={degree}$",
)
# Plot for degree=5
magicplotter(
f_data,
f_truth,
predict,
x_pred,
basis=PolynomialBasis,
degree=5,
ax=ax[0][1],
hide_legend=True,
title=r"$degree={degree}$",
)
# Plot for degree=10
magicplotter(
f_data,
f_truth,
predict,
x_pred,
basis=PolynomialBasis,
degree=10,
ax=ax[1][0],
hide_legend=True,
title=r"$degree={degree}$",
)
# Plot for degree=25
magicplotter(
f_data,
f_truth,
predict,
x_pred,
basis=PolynomialBasis,
degree=25,
ax=ax[1][1],
hide_legend=True,
title=r"$degree={degree}$",
)
# Add a general legend at the bottom of the plot
plt.subplots_adjust(bottom=0.2)
handles, labels = ax[0][0].get_legend_handles_labels()
fig.legend(handles, labels, loc="lower center")
plt.show()
That is looking much better already. However, the quality of the fit varies significantly with the degree of the polynomial basis. There seems to be an ideal model complexity for this specific problem. Try out the interactive tool below to get an idea of the interplay of the following variables:
\(p\), the degree of the polynomial basis
\(N\), the size of the training data set
\(freq\), the frequency of the underlying truth
\(\varepsilon\), the level of noise associated with the data
The seed can be updated to generate new random data sets
The truth can be hidden to simulate a situation that is closer to a practical setting
plot1 = magicplotter(
f_data,
f_truth,
predict,
x_pred,
basis=PolynomialBasis,
pred_label=r"Prediction $y(x)$, $p={degree}$",
)
plot1.fig.canvas.toolbar_visible = False
plot1.add_sliders("epsilon", "degree", "N", "freq")
plot1.add_buttons("truth", "seed", "reset")
plot1.show()
A few questions that might have crossed your mind while playing with the tool:
With a small amount of data (\(N \leq 11\)), what happens if we have as many data points as parameters? \((p + 1 = N)\)
With a small amount of data (\(N \leq 11\)), what happens if we have more model parameters than data? \((p + 1 > N)\)
We only have access to data in the interval \([0,2\pi]\). How well does our model extrapolate beyond the data range?
Other choices of basis functions#
As mentioned previously, the polynomial basis is just one choice among many to define our model. Depending on the problem setting, a different set of basis functions might lead to better results. Another popular choice is the set of radial basis functions (also called Gaussian basis functions), given by

\[
\phi_j(x) = \exp \left( - \frac{(x - \mu_j)^2}{2 l^2} \right), \qquad j = 1, \dots, M,
\]

where \(\phi_j\) is centered around \(\mu_j\), \(l\) determines the width, and \(M\) refers to the number of basis functions. Let’s implement a RadialBasisFunctions function:
# Here is a function for the radial basis functions:
def RadialBasisFunctions(x, M_radial, l_radial, **kwargs):
    """
    A function that computes radial (Gaussian) basis functions.
    Arguments:
    x - The datapoints
    M_radial - The number of basis functions
    l_radial - The width of each basis function
    """
    # Place the centers evenly over [-2, 2], i.e. roughly the +/- 2 sigma
    # range of the standardized inputs
    mu = np.linspace(-2, 2, M_radial)
    num_basis = mu.shape[0]
    Phi = np.ndarray((x.shape[0], num_basis))
    for i in range(num_basis):
        Phi[:, i] = np.exp(-0.5 * (x - mu[i]) ** 2 / l_radial**2)
    return Phi
# Let's test our implementation
l_radial = 0.5
M_radial = 9
x_test = np.linspace(-2, 2, 200)
Phi_radial_test = RadialBasisFunctions(x_test, M_radial=M_radial, l_radial=l_radial)
# Plot the data and the ground truth
fig, ax = plt.subplots(figsize=(8, 4.5))
fig.canvas.toolbar_visible = False
ax.set_position([0.2, 0.1, 0.7, 0.8])
for i, row in enumerate(Phi_radial_test.transpose()):
plt.plot(x_test, row, label=r"$\phi_{}(x)$".format(i + 1))
plt.xlabel(r"$x$")
plt.ylabel(r"$\phi(x)$")
plt.legend()
plt.show()
One of the key attributes of this model is the locality of its individual basis functions: data in one part of the domain has little impact on predictions in distant parts of the domain. Periodicity can be achieved with a Fourier basis. Wavelets are popular in signal processing since they are localized in both frequency and space. It is up to the user to determine which basis function properties are desired for a given problem, and this is an important part of model selection. Try to implement some of these basis functions yourself (a sketch of a Fourier basis is given below) and assess how well they compare with the pre-implemented ones.
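As a starting point, here is a possible sketch of a Fourier-type basis that follows the same interface as PolynomialBasis and RadialBasisFunctions above; the name FourierBasis and the choice of integer frequencies are our own illustrative assumptions:

# Sketch of a possible Fourier basis, compatible with the predict function above
def FourierBasis(x, M_fourier, **kwargs):
    """
    Computes a truncated Fourier basis: a constant term plus sine/cosine
    pairs with increasing integer frequencies.
    Arguments:
    x - The datapoints
    M_fourier - The number of sine/cosine pairs
    """
    columns = [np.ones_like(x)]
    for k in range(1, M_fourier + 1):
        columns.append(np.sin(k * x))
        columns.append(np.cos(k * x))
    return np.stack(columns, axis=1)

# It can be passed to the existing predict function just like the other bases, e.g.
# t_pred = predict(x, t, x_pred, basis=FourierBasis, M_fourier=3)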
Let’s see how well the linear model with radial basis functions performs on the sine wave problem. Keep in mind that the lengthscale parameter corresponds to the lengthscale in the standardized space.
# Plot the resulting predictions
fig, ax = plt.subplots(2, 2, figsize=(9, 6), sharex="all", sharey="all")
fig.canvas.toolbar_visible = False
plt.suptitle(
r"generalized linear regression for radial basis functions with varying $M$ and $l$"
)
# Plot for l=0.5, M=5
magicplotter(
f_data,
f_truth,
predict,
x_pred,
basis=RadialBasisFunctions,
M_radial=5,
l_radial=0.5,
ax=ax[0][0],
hide_legend=True,
pred_label=r"Prediction $y(x)$",
title=r"$l = {l_radial}, M = {M_radial}$",
)
# Plot for l=0.5, M=15
magicplotter(
f_data,
f_truth,
predict,
x_pred,
basis=RadialBasisFunctions,
M_radial=15,
l_radial=0.5,
ax=ax[0][1],
hide_legend=True,
title=r"$l = {l_radial}, M = {M_radial}$",
)
# Plot for l=1.5, M=5
magicplotter(
f_data,
f_truth,
predict,
x_pred,
basis=RadialBasisFunctions,
M_radial=5,
l_radial=1.5,
ax=ax[1][0],
hide_legend=True,
title=r"$l = {l_radial}, M = {M_radial}$",
)
# Plot for l=1.5, M=15
magicplotter(
f_data,
f_truth,
predict,
x_pred,
basis=RadialBasisFunctions,
M_radial=15,
l_radial=1.5,
ax=ax[1][1],
hide_legend=True,
title=r"$l = {l_radial}, M = {M_radial}$",
)
# Add a general legend at the bottom of the plot
plt.subplots_adjust(bottom=0.2)
handles, labels = ax[0][0].get_legend_handles_labels()
fig.legend(handles, labels, loc="lower center")
plt.show()
The figure above shows four different combinations of the hyperparameters (number of basis functions and length scale). The quality of the fit depends strongly on the parameter setting, but a visual inspection indicates our model can replicate the general trend.
Think about the following questions:
Do you notice any major differences in the plots?
Do you think normalization improves model fitting in this particular case? Compare with the polynomial basis.
If so, why does this happen?
Final remarks#
This page introduced generalized linear models with arbitrarily high flexibility. We have seen that increased flexibility is not always good if we perform a simple least-squares analysis. We know from the previous page that we can introduce a validation set to prevent our model from overfitting; however, removing features is not always trivial. The following page will introduce you to ridge regression, an elegant method for controlling the model complexity.