
Function-space Parameterization of Neural Networks for Sequential Learning
Aidan Scannell*, Riccardo Mereu*, Paul Chang, Ella Tamir, Joni Pajarinen, Arno Solin
International Conference on Learning Representations (ICLR 2024)
Sparse Function-space Representation of Neural Networks
Aidan Scannell*, Riccardo Mereu*, Paul Chang, Ella Tamir, Joni Pajarinen, Arno Solin
ICML 2023 Workshop on Duality Principles for Modern Machine Learning

SFR

Abstract

Sequential learning paradigms pose challenges for gradient-based deep learning due to difficulties incorporating new data and retaining prior knowledge. While Gaussian processes elegantly tackle these problems, they struggle with scalability and handling rich inputs, such as images. To address these issues, we introduce a technique that converts neural networks from weight space to function space, through a dual parameterization. Our parameterization offers: (i) a way to scale function-space methods to large data sets via sparsification, (ii) retention of prior knowledge when access to past data is limited, and (iii) a mechanism to incorporate new data without retraining. Our experiments demonstrate that we can retain knowledge in continual learning and incorporate new data efficiently. We further show its strengths in uncertainty quantification and guiding exploration in model-based RL.
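
Points (i) and (iii) can be illustrated with a small sketch. It assumes the general dual form used by function-space methods. Each training point contributes two dual parameters, α̂_i (the derivative of log p(y_i | f_i) with respect to the function value f_i) and β̂_i (the negative second derivative). These are projected onto M inducing inputs Z as α_u = Σ_i k_zi α̂_i and B_u = Σ_i k_zi β̂_i k_ziᵀ. The helpers `sparse_dual_stats` and `sparse_predict` are hypothetical, and an RBF kernel stands in for the network-derived kernel so the block runs on its own. This is not the code in `src`.

```python
import torch

torch.set_default_dtype(torch.float64)


def rbf(a, b, lengthscale=0.2):
    # Stand-in kernel (assumption); SFR derives its kernel from the trained network.
    return torch.exp(-0.5 * torch.cdist(a, b) ** 2 / lengthscale**2)


def sparse_dual_stats(kernel, Z, X, alpha, beta, batch_size=None):
    """alpha_u = sum_i k_zi alpha_i and B_u = sum_i k_zi beta_i k_zi^T, built in batches."""
    alpha_u = torch.zeros(Z.shape[0])
    B_u = torch.zeros(Z.shape[0], Z.shape[0])
    batch_size = batch_size or X.shape[0]
    for start in range(0, X.shape[0], batch_size):
        rows = slice(start, start + batch_size)
        K_zb = kernel(Z, X[rows])  # (M, batch)
        alpha_u += K_zb @ alpha[rows]
        B_u += (K_zb * beta[rows]) @ K_zb.T
    return alpha_u, B_u


def sparse_predict(kernel, Z, alpha_u, B_u, X_test, jitter=1e-8):
    """Mean k_zx^T K_zz^-1 alpha_u; variance k_xx - k_zx^T [K_zz^-1 - (K_zz + B_u)^-1] k_zx."""
    K_zz = kernel(Z, Z) + jitter * torch.eye(Z.shape[0])
    K_zx = kernel(Z, X_test)
    mean = K_zx.T @ torch.linalg.solve(K_zz, alpha_u)
    var = (
        torch.diagonal(kernel(X_test, X_test))
        - (K_zx * torch.linalg.solve(K_zz, K_zx)).sum(0)
        + (K_zx * torch.linalg.solve(K_zz + B_u, K_zx)).sum(0)
    )
    return mean, var


# Toy regression problem with Gaussian noise variance sigma2.
X = torch.linspace(0, 2, 10).unsqueeze(-1)
y = torch.sin(5 * X[:, 0])
sigma2 = 0.1
K = rbf(X, X)
# Stand-in for the trained model's outputs at X: the exact MAP of this toy GP.
f_map = K @ torch.linalg.solve(K + sigma2 * torch.eye(10), y)
alpha = (y - f_map) / sigma2          # d log p(y|f) / df for a Gaussian likelihood
beta = torch.full((10,), 1 / sigma2)  # -d^2 log p(y|f) / df^2

Z = X[::2]  # 5 inducing inputs
alpha_u, B_u = sparse_dual_stats(rbf, Z, X, alpha, beta, batch_size=4)
X_test = torch.linspace(-0.5, 2.5, 7).unsqueeze(-1)
mean, var = sparse_predict(rbf, Z, alpha_u, B_u, X_test)
```

α_u and B_u are plain sums over data points. They can therefore be accumulated in batches, which is presumably what `dual_batch_size` controls, and extended with the contributions of new data without retraining. As a sanity check, using all training inputs as Z makes the predictions match the exact GP posterior.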

TL;DR

  • SFR is a “post hoc” Bayesian deep learning method
    • Equips any trained NN with uncertainty estimates
  • SFR can be viewed as a function-space Laplace approximation for NNs
  • SFR has several benefits over the weight-space Laplace approximation for NNs:
    • Its function-space representation is effective for regularization in continual learning (CL)
    • It has good uncertainty estimates
      • We use them to guide exploration in model-based reinforcement learning (RL)
    • It can incorporate new data without retraining the NN
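
|                           | SFR   | GP  | Laplace BNN             |
|---------------------------|-------|-----|-------------------------|
| Function-space            | ✅    | ✅  | ❌ (weight space)        |
| Image inputs              | ✅    | ❌  | ✅                       |
| Large data                | ✅    | ❌  | ✅                       |
| Incorporate new data fast | ✅/❌ | ✅  | ❌ (requires retraining) |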
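
The function-space Laplace view can be made concrete. Linearize the trained network around its weights and place an isotropic Gaussian prior N(0, δ⁻¹I) on them. This gives a GP whose kernel is κ(x, x′) = J(x) J(x′)ᵀ / δ, where J is the Jacobian of the network output with respect to the weights. Below is a minimal sketch for a single-output network using `torch.func`. `laplace_kernel` is a hypothetical helper, and SFR's own implementation (for example, for multiple outputs or memory-efficient batching) may differ.

```python
import torch
from torch.func import functional_call, jacrev

torch.set_default_dtype(torch.float64)


def laplace_kernel(network, X1, X2, delta):
    """kappa(x, x') = J(x) J(x')^T / delta for a single-output network.

    J(x) stacks the derivatives of the network output with respect to every
    weight, evaluated at the current (trained) weights; delta is the prior precision.
    """
    params = {name: p.detach() for name, p in network.named_parameters()}

    def f(p, X):
        return functional_call(network, p, (X,)).squeeze(-1)  # (N,)

    def jacobian(X):
        jac = jacrev(f)(params, X)  # name -> (N, *param.shape)
        return torch.cat([j.reshape(X.shape[0], -1) for j in jac.values()], dim=1)

    return jacobian(X1) @ jacobian(X2).T / delta


# Same architecture as the minimal example below.
network = torch.nn.Sequential(
    torch.nn.Linear(1, 64), torch.nn.Tanh(),
    torch.nn.Linear(64, 64), torch.nn.Tanh(),
    torch.nn.Linear(64, 1),
)
X = torch.rand(5, 1) * 2
K = laplace_kernel(network, X, X, delta=5e-5)  # (5, 5) Gram matrix
```

For a plain linear model, J(x) = [x, 1], so the kernel reduces to (x x′ᵀ + 1) / δ. This is a quick way to check the sketch.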

Usage

See the notebooks for how to use our code for both regression and classification.
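
Moving between regression and classification changes the likelihood. In the usual dual parameterization, each data point's likelihood enters only through the first derivative of log p(y | f) with respect to the network output f and through its negative second derivative, both evaluated at the trained network. The sketch below gives these standard derivatives for three cases: a Gaussian likelihood (regression, as with `src.likelihoods.Gaussian(sigma_noise=...)`), a Bernoulli likelihood on a logit output (binary classification), and a softmax categorical likelihood. The function names are hypothetical. It is an assumption, not something confirmed here, that `src.likelihoods` computes exactly these quantities.

```python
import torch

torch.set_default_dtype(torch.float64)


def gaussian_duals(f, y, sigma_noise):
    # log p(y|f) = -0.5 * (y - f)^2 / sigma^2 + const
    alpha = (y - f) / sigma_noise**2
    beta = torch.full_like(f, 1.0 / sigma_noise**2)
    return alpha, beta


def bernoulli_duals(f, y):
    # log p(y|f) = y * f - log(1 + exp(f)), with y in {0, 1} and f a logit
    p = torch.sigmoid(f)
    return y - p, p * (1 - p)


def categorical_duals(f, y):
    # f: (N, C) logits, y: (N,) integer labels; log p(y|f) = f_y - logsumexp(f)
    p = torch.softmax(f, dim=-1)
    alpha = torch.nn.functional.one_hot(y, f.shape[-1]).to(f.dtype) - p
    beta = torch.diag_embed(p) - p.unsqueeze(-1) * p.unsqueeze(-2)  # (N, C, C)
    return alpha, beta


logits = torch.tensor([-2.0, 0.0, 3.0])
labels = torch.tensor([0.0, 1.0, 1.0])
alpha, beta = bernoulli_duals(logits, labels)
```

For the Gaussian case β is constant (1/σ²), whereas for classification it depends on the network's current predictions.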

Minimal example

Here’s a short example:

import src
import torch

torch.set_default_dtype(torch.float64)

def func(x, noise=True):
    y = torch.sin(x * 5) / x + torch.cos(x * 10)
    if noise:
        y = y + 0.1 * torch.randn_like(y)  # 0.1 is an illustrative noise level
    return y

# Toy data set
X_train = torch.rand((100, 1)) * 2
Y_train = func(X_train, noise=True)
data = (X_train, Y_train)

# Training config
width = 64
num_epochs = 1000
batch_size = 16
learning_rate = 1e-3
delta = 0.00005  # prior precision
data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(*data), batch_size=batch_size, shuffle=True
)

# Create a neural network
network = torch.nn.Sequential(
    torch.nn.Linear(1, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, 1),
)

# Instantiate SFR (handles NN training/prediction as they're coupled via the prior/likelihood)
sfr = src.SFR(
    network=network,
    prior=src.priors.Gaussian(params=network.parameters, delta=delta),
    likelihood=src.likelihoods.Gaussian(sigma_noise=2),
    output_dim=1,
    num_inducing=32,
    dual_batch_size=None,  # an integer computes the dual parameters in batches, reducing memory
    jitter=1e-4,
)

sfr.train()
optimizer = torch.optim.Adam([{"params": sfr.parameters()}], lr=learning_rate)
for epoch_idx in range(num_epochs):
    for x, y in data_loader:
        loss = sfr.loss(x, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

sfr.set_data(data) # This builds the dual parameters

# Make predictions in function space
X_test = torch.linspace(-0.7, 3.5, 300, dtype=torch.float64).reshape(-1, 1)
f_mean, f_var = sfr.predict_f(X_test)

# Make predictions in output space
y_mean, y_var = sfr.predict(X_test)
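
In the example, `delta` is the prior precision and `sigma_noise` is the likelihood's noise scale. Suppose `sfr.loss` implements the usual MAP objective: the negative log-likelihood plus the negative log of an isotropic Gaussian prior N(0, δ⁻¹I) on the weights. It would then correspond to the hypothetical `map_loss` below, up to an additive constant and possibly a per-batch rescaling that we have not checked. The sketch uses plain PyTorch, not the `src` API.

```python
import torch

torch.set_default_dtype(torch.float64)


def map_loss(network, x, y, sigma_noise, delta):
    """Negative log-joint: -sum_i log N(y_i | f(x_i), sigma^2) + (delta / 2) * ||w||^2."""
    f = network(x)
    nll = -torch.distributions.Normal(f, sigma_noise).log_prob(y).sum()
    # Gaussian prior N(0, delta^-1 I) on all weights, dropping its normalizing constant
    sq_norm = sum((p**2).sum() for p in network.parameters())
    return nll + 0.5 * delta * sq_norm


network = torch.nn.Sequential(torch.nn.Linear(1, 8), torch.nn.Tanh(), torch.nn.Linear(8, 1))
x = torch.rand(16, 1) * 2
y = torch.sin(5 * x)
loss = map_loss(network, x, y, sigma_noise=2.0, delta=5e-5)
loss.backward()  # gradients reach every weight, as in the training loop above
```

A small `delta` such as 5e-5 means a weak prior, so the data term dominates training.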

Citation

Please consider citing our conference paper:

@inproceedings{scannell2024functionspace,
    title       = {Function-space Parameterization of Neural Networks for Sequential Learning},
    author      = {Aidan Scannell and Riccardo Mereu and Paul Edmund Chang and Ella Tamir and Joni Pajarinen and Arno Solin},
    booktitle   = {The Twelfth International Conference on Learning Representations},
    year        = {2024},
    url         = {https://openreview.net/forum?id=2dhxxIKhqz}
}

Or our workshop paper:

@inproceedings{scannellSparse2023,
  title           = {Sparse Function-space Representation of Neural Networks},
  author          = {Aidan Scannell and Riccardo Mereu and Paul Chang and Ella Tamir and Joni Pajarinen and Arno Solin},
  booktitle       = {ICML 2023 Workshop on Duality Principles for Modern Machine Learning},
  year            = {2023},
  month           = {7},
}