SFR: Sparse Function-space Representation of Neural Networks
Function-space Parameterization of Neural Networks for Sequential Learning
Aidan Scannell*, Riccardo Mereu*, Paul Chang, Ella Tamir, Joni Pajarinen, Arno Solin
International Conference on Learning Representations (ICLR 2024)

Sparse Function-space Representation of Neural Networks
Aidan Scannell*, Riccardo Mereu*, Paul Chang, Ella Tamir, Joni Pajarinen, Arno Solin
ICML 2023 Workshop on Duality Principles for Modern Machine Learning
Abstract
Sequential learning paradigms pose challenges for gradient-based deep learning due to difficulties incorporating new data and retaining prior knowledge. While Gaussian processes elegantly tackle these problems, they struggle with scalability and handling rich inputs, such as images. To address these issues, we introduce a technique that converts neural networks from weight space to function space, through a dual parameterization. Our parameterization offers: (i) a way to scale function-space methods to large data sets via sparsification, (ii) retention of prior knowledge when access to past data is limited, and (iii) a mechanism to incorporate new data without retraining. Our experiments demonstrate that we can retain knowledge in continual learning and incorporate new data efficiently. We further show its strengths in uncertainty quantification and guiding exploration in model-based RL.
TL;DR
- SFR is a "post-hoc" Bayesian deep learning method: it equips any trained NN with uncertainty estimates.
- SFR can be viewed as a function-space Laplace approximation for NNs.
- SFR has several benefits over a weight-space Laplace approximation for NNs:
  - Its function-space representation is effective for regularization in continual learning (CL).
  - It has good uncertainty estimates, which we use to guide exploration in model-based reinforcement learning (RL).
  - It can incorporate new data without retraining the NN (see the sketch after the table below).
| | SFR | GP | Laplace BNN |
|---|---|---|---|
| Function-space | ✅ | ✅ | ❌ (weight space) |
| Image inputs | ✅ | ❌ | ✅ |
| Large data | ✅ | ❌ | ✅ |
| Incorporate new data fast | ✅/❌ | ✅ | ❌ (requires retraining) |
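To make the "incorporate new data fast" row more concrete, here is a minimal sketch of folding new observations into a fitted SFR model by rebuilding its dual parameters with `set_data`, without taking any further gradient steps on the network. It assumes the `sfr`, `func`, `X_train`, and `Y_train` objects from the minimal example in the Usage section below; `X_new`/`Y_new` are placeholder names. Note that this simple version recomputes the dual parameters from the combined data, whereas the paper describes updates that use only the new points.
import torch
# New observations arriving after the network has been trained
# (X_new / Y_new are placeholder names for this sketch)
X_new = torch.rand((20, 1)) * 2
Y_new = func(X_new, noise=True)
# Rebuild the dual parameters on the combined data set;
# no gradient-based training of the network is performed
X_all = torch.cat([X_train, X_new], dim=0)
Y_all = torch.cat([Y_train, Y_new], dim=0)
sfr.set_data((X_all, Y_all))
# Predictions now reflect the new observations
f_mean_new, f_var_new = sfr.predict_f(X_new)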
Usage
See the notebooks for how to use our code for both regression and classification.
Minimal example
Here’s a short example:
import src
import torch
torch.set_default_dtype(torch.float64)
def func(x, noise=True):
    f = torch.sin(x * 5) / x + torch.cos(x * 10)
    if noise:
        f = f + 0.2 * torch.randn_like(f)  # add observation noise (scale chosen for this toy example)
    return f
# Toy data set
X_train = torch.rand((100, 1)) * 2
Y_train = func(X_train, noise=True)
data = (X_train, Y_train)
# Training config
width = 64
num_epochs = 1000
batch_size = 16
learning_rate = 1e-3
delta = 0.00005 # prior precision
data_loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(*data), batch_size=batch_size
)
# Create a neural network
network = torch.nn.Sequential(
    torch.nn.Linear(1, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, width),
    torch.nn.Tanh(),
    torch.nn.Linear(width, 1),
)
# Instantiate SFR (handles NN training/prediction as they're coupled via the prior/likelihood)
sfr = src.SFR(
    network=network,
    prior=src.priors.Gaussian(params=network.parameters, delta=delta),
    likelihood=src.likelihoods.Gaussian(sigma_noise=2),
    output_dim=1,
    num_inducing=32,
    dual_batch_size=None,  # setting a batch size reduces the memory required for computing dual parameters (None = no batching)
    jitter=1e-4,
)
sfr.train()
optimizer = torch.optim.Adam([{"params": sfr.parameters()}], lr=learning_rate)
for epoch_idx in range(num_epochs):
    for batch_idx, batch in enumerate(data_loader):
        x, y = batch
        loss = sfr.loss(x, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
sfr.set_data(data) # This builds the dual parameters
# Make predictions in function space
X_test = torch.linspace(-0.7, 3.5, 300, dtype=torch.float64).reshape(-1, 1)
f_mean, f_var = sfr.predict_f(X_test)
# Make predictions in output space
y_mean, y_var = sfr.predict(X_test)
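As a follow-up to the example (not part of the repository itself), the function-space predictions can be visualized in the usual way. This sketch assumes matplotlib is installed and that `predict_f` returns tensors of shape `(N, 1)` or `(N,)`; it plots the predictive mean with a two-standard-deviation band.
import matplotlib.pyplot as plt
# Plot the training data, the predictive mean, and a ~95% band
# (mean ± 2 standard deviations) from the function-space predictions
x = X_test.reshape(-1).detach().numpy()
mu = f_mean.reshape(-1).detach().numpy()
std = f_var.reshape(-1).sqrt().detach().numpy()
plt.scatter(X_train.reshape(-1), Y_train.reshape(-1), s=10, color="k", label="train data")
plt.plot(x, mu, label="SFR mean")
plt.fill_between(x, mu - 2 * std, mu + 2 * std, alpha=0.3, label="mean ± 2 std")
plt.legend()
plt.show()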
Citation
Please consider citing our conference paper:
@inproceedings{scannell2024functionspace,
title = {Function-space Parameterization of Neural Networks for Sequential Learning},
author = {Aidan Scannell and Riccardo Mereu and Paul Edmund Chang and Ella Tamir and Joni Pajarinen and Arno Solin},
booktitle = {The Twelfth International Conference on Learning Representations},
year = {2024},
url = {https://openreview.net/forum?id=2dhxxIKhqz}
}
Or our workshop paper:
@inproceedings{scannellSparse2023,
title = {Sparse Function-space Representation of Neural Networks},
author = {Aidan Scannell and Riccardo Mereu and Paul Chang and Ella Tamir and Joni Pajarinen and Arno Solin},
booktitle = {ICML 2023 Workshop on Duality Principles for Modern Machine Learning},
year = {2023},
month = {7},
}