SUQ Documentation

Base Functions

Classes

SUQ_Base

Bases: Module

Base class for SUQ models.

Provides core functionality for:

  • Managing likelihood type (regression or classification)
  • Probit-based approximation for classification
  • NLPD-based fitting of the scale factor

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| likelihood | str | Either classification or regression. | required |
| scale_init | float | Initial value for the scale factor parameter. | required |
Source code in suq/base_suq.py
class SUQ_Base(nn.Module):
    """
    Base class for SUQ models.

    Provides core functionality for:

    - Managing likelihood type (regression or classification)
    - Probit-based approximation for classification
    - NLPD-based fitting of the scale factor

    Args:
        likelihood (str): Either `classification` or `regression`.
        scale_init (float): Initial value for the scale factor parameter.
    """

    def __init__(self, likelihood, scale_init):
        super().__init__()

        if likelihood not in ['classification', 'regression']:
            raise ValueError(f"Invalid likelihood type {likelihood}")

        self.likelihood = likelihood
        self.scale_factor = nn.Parameter(torch.Tensor([scale_init]).to(device))

    def probit_approximation(self, out_mean, out_var):
        """
        Applies a probit approximation to compute class probabilities from the latent Gaussian distribution.

        Args:
            out_mean (Tensor): Latent function mean, shape [B, num_classes]
            out_var (Tensor): Latent function variance, shape [B, num_classes] or [B, num_classes, num_classes]

        Returns:
            posterior_predict_mean (Tensor): Predicted class probabilities, shape [B, num_classes]
        """

        if out_var.dim() == 3:
            kappa = 1 / torch.sqrt(1. + np.pi / 8 * out_var.diagonal(dim1=1, dim2=2))
        else:
            kappa = 1 / torch.sqrt(1. + np.pi / 8 * out_var)

        posterior_predict_mean = torch.softmax(kappa * out_mean, dim=-1)
        return posterior_predict_mean

    def fit_scale_factor(self, data_loader, n_epoches, lr, speedup = True, verbose = False):
        """
        Fits the scale factor for predictive variance using negative log predictive density (NLPD).

        Args:
            data_loader (DataLoader): Dataloader containing (input, target) pairs
            n_epoches (int): Number of epochs for optimization
            lr (float): Learning rate for scale optimizer
            speedup (bool): If True (classification only), caches forward pass outputs to accelerate fitting
            verbose (bool): If True, prints NLPD at each epoch

        Returns:
            total_train_nlpd (List[float]): Average NLPD per epoch over training data
        """
        print("fit scale factor")
        optimizer = torch.optim.Adam([self.scale_factor], lr)
        total_train_nlpd = []

        # store intermediate result and pack it into a data loader, so we only need to do one forward pass
        if speedup:

            if self.likelihood == 'regression':
                raise ValueError(f"Speed up not supported for regression atm")

            if self.likelihood == 'classification':

                f_mean = []
                f_var = []
                labels = []

                for (X, y) in tqdm(data_loader, desc= "packing f_mean f_var into a dataloader"):
                    out_mean, out_var = self.forward_latent(X.to(device))
                    f_mean.append(out_mean.detach().cpu().numpy())
                    f_var.append(out_var.detach().cpu().numpy())
                    if y.dim() == 2:
                        labels.append(y.numpy().argmax(1).reshape(-1, 1))
                    if y.dim() == 1:
                        labels.append(y.numpy().reshape(-1, 1))

                f_mean = np.vstack(f_mean)
                f_var = np.vstack(f_var)
                labels = np.vstack(labels)

                scale_fit_dataset = torch_dataset(f_mean, f_var, labels)
                scale_fit_dataloader = DataLoader(scale_fit_dataset, batch_size=16, shuffle=True)

                for epoch in tqdm(range(n_epoches), desc="fitting scaling factor"):
                    running_nlpd = 0
                    for data_pair in scale_fit_dataloader:
                        x_mean, x_var_label = data_pair
                        num_class = x_mean.shape[1]
                        x_mean = x_mean.to(device)
                        x_var, label = x_var_label.split(num_class, dim=1)
                        x_var = x_var.to(device)
                        label = label.to(device)

                        optimizer.zero_grad()
                        # make prediction
                        x_var = x_var / self.scale_factor.to(device)
                        posterior_predict_mean = self.probit_approximation(x_mean, x_var)
                        # construct log posterior predictive distribution
                        posterior_predictive_dist = Categorical(posterior_predict_mean)
                        # calculate nlpd and update
                        nlpd = -posterior_predictive_dist.log_prob(label).mean()
                        nlpd.backward()
                        optimizer.step()
                        # log nlpd
                        running_nlpd += nlpd.item()
                    total_train_nlpd.append(running_nlpd / len(scale_fit_dataloader))
                    if verbose:
                        print(f"epoch {epoch + 1}, nlpd {total_train_nlpd[-1]}")

                del scale_fit_dataloader
                del scale_fit_dataset

        else:

            if self.likelihood == 'classification':
                for epoch in tqdm(range(n_epoches), desc="fitting scaling factor"):
                    running_nlpd = 0
                    for (data, label) in data_loader:

                        data = data.to(device)
                        label = label.to(device)

                        optimizer.zero_grad()
                        # make prediction
                        posterior_predict_mean = self.forward(data)
                        # construct log posterior predictive distribution
                        posterior_predictive_dist = Categorical(posterior_predict_mean)
                        # calculate nlpd and update
                        nlpd = -posterior_predictive_dist.log_prob(label).mean()
                        nlpd.backward()
                        optimizer.step()
                        # log nlpd
                        running_nlpd += nlpd.item()
                    total_train_nlpd.append(running_nlpd / len(data_loader))
                    if verbose:
                        print(f"epoch {epoch + 1}, nlpd {total_train_nlpd[-1]}")


            if self.likelihood == 'regression':
                for epoch in tqdm(range(n_epoches), desc="fitting scaling factor"):
                    running_nlpd = 0
                    for (data, label) in data_loader:
                        data = data.to(device)
                        label = label.to(device)

                        optimizer.zero_grad()
                        # make prediction
                        posterior_predict_mean, posterior_predict_var = self.forward(data)
                        # construct log posterior predictive distribution
                        posterior_predictive_dist = Normal(posterior_predict_mean, posterior_predict_var.sqrt())
                        # calculate nlpd and update
                        nlpd = -posterior_predictive_dist.log_prob(label).mean()
                        nlpd.backward()
                        optimizer.step()
                        # log nlpd
                        running_nlpd += nlpd.item()

                    total_train_nlpd.append(running_nlpd / len(data_loader))

                    if verbose:
                        print(f"epoch {epoch + 1}, nlpd {total_train_nlpd[-1]}")

        return total_train_nlpd
Functions
probit_approximation(out_mean, out_var)

Applies a probit approximation to compute class probabilities from the latent Gaussian distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| out_mean | Tensor | Latent function mean, shape [B, num_classes] | required |
| out_var | Tensor | Latent function variance, shape [B, num_classes] or [B, num_classes, num_classes] | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| posterior_predict_mean | Tensor | Predicted class probabilities, shape [B, num_classes] |

Source code in suq/base_suq.py
def probit_approximation(self, out_mean, out_var):
    """
    Applies a probit approximation to compute class probabilities from the latent Gaussian distribution.

    Args:
        out_mean (Tensor): Latent function mean, shape [B, num_classes]
        out_var (Tensor): Latent function variance, shape [B, num_classes] or [B, num_classes, num_classes]

    Returns:
        posterior_predict_mean (Tensor): Predicted class probabilities, shape [B, num_classes]
    """

    if out_var.dim() == 3:
        kappa = 1 / torch.sqrt(1. + np.pi / 8 * out_var.diagonal(dim1=1, dim2=2))
    else:
        kappa = 1 / torch.sqrt(1. + np.pi / 8 * out_var)

    posterior_predict_mean = torch.softmax(kappa * out_mean, dim=-1)
    return posterior_predict_mean
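
A minimal numeric sketch of the scaling performed above (not part of the library): larger latent variance shrinks kappa, which pulls the softmax probabilities towards the uniform distribution.

```python
import numpy as np
import torch

out_mean = torch.tensor([[2.0, 0.5, -1.0]])           # latent mean, [B=1, num_classes=3]

for var in (0.1, 10.0):                                # small vs. large latent variance
    out_var = torch.full_like(out_mean, var)
    kappa = 1 / torch.sqrt(1. + np.pi / 8 * out_var)   # same formula as in probit_approximation
    probs = torch.softmax(kappa * out_mean, dim=-1)
    print(var, probs)                                  # variance 10 gives visibly flatter probabilities
```
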
fit_scale_factor(data_loader, n_epoches, lr, speedup=True, verbose=False)

Fits the scale factor for predictive variance using negative log predictive density (NLPD).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data_loader | DataLoader | Dataloader containing (input, target) pairs | required |
| n_epoches | int | Number of epochs for optimization | required |
| lr | float | Learning rate for the scale optimizer | required |
| speedup | bool | If True (classification only), caches forward pass outputs to accelerate fitting | True |
| verbose | bool | If True, prints NLPD at each epoch | False |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| total_train_nlpd | List[float] | Average NLPD per epoch over training data |

Source code in suq/base_suq.py
def fit_scale_factor(self, data_loader, n_epoches, lr, speedup = True, verbose = False):
    """
    Fits the scale factor for predictive variance using negative log predictive density (NLPD).

    Args:
        data_loader (DataLoader): Dataloader containing (input, target) pairs
        n_epoches (int): Number of epochs for optimization
        lr (float): Learning rate for scale optimizer
        speedup (bool): If True (classification only), caches forward pass outputs to accelerate fitting
        verbose (bool): If True, prints NLPD at each epoch

    Returns:
        total_train_nlpd (List[float]): Average NLPD per epoch over training data
    """
    print("fit scale factor")
    optimizer = torch.optim.Adam([self.scale_factor], lr)
    total_train_nlpd = []

    # store intermediate result and pack it into a data loader, so we only need to do one forward pass
    if speedup:

        if self.likelihood == 'regression':
            raise ValueError(f"Speed up not supported for regression atm")

        if self.likelihood == 'classification':

            f_mean = []
            f_var = []
            labels = []

            for (X, y) in tqdm(data_loader, desc= "packing f_mean f_var into a dataloader"):
                out_mean, out_var = self.forward_latent(X.to(device))
                f_mean.append(out_mean.detach().cpu().numpy())
                f_var.append(out_var.detach().cpu().numpy())
                if y.dim() == 2:
                    labels.append(y.numpy().argmax(1).reshape(-1, 1))
                if y.dim() == 1:
                    labels.append(y.numpy().reshape(-1, 1))

            f_mean = np.vstack(f_mean)
            f_var = np.vstack(f_var)
            labels = np.vstack(labels)

            scale_fit_dataset = torch_dataset(f_mean, f_var, labels)
            scale_fit_dataloader = DataLoader(scale_fit_dataset, batch_size=16, shuffle=True)

            for epoch in tqdm(range(n_epoches), desc="fitting scaling factor"):
                running_nlpd = 0
                for data_pair in scale_fit_dataloader:
                    x_mean, x_var_label = data_pair
                    num_class = x_mean.shape[1]
                    x_mean = x_mean.to(device)
                    x_var, label = x_var_label.split(num_class, dim=1)
                    x_var = x_var.to(device)
                    label = label.to(device)

                    optimizer.zero_grad()
                    # make prediction
                    x_var = x_var / self.scale_factor.to(device)
                    posterior_predict_mean = self.probit_approximation(x_mean, x_var)
                    # construct log posterior predictive distribution
                    posterior_predictive_dist = Categorical(posterior_predict_mean)
                    # calculate nlpd and update
                    nlpd = -posterior_predictive_dist.log_prob(label).mean()
                    nlpd.backward()
                    optimizer.step()
                    # log nlpd
                    running_nlpd += nlpd.item()
                total_train_nlpd.append(running_nlpd / len(scale_fit_dataloader))
                if verbose:
                    print(f"epoch {epoch + 1}, nlpd {total_train_nlpd[-1]}")

            del scale_fit_dataloader
            del scale_fit_dataset

    else:

        if self.likelihood == 'classification':
            for epoch in tqdm(range(n_epoches), desc="fitting scaling factor"):
                running_nlpd = 0
                for (data, label) in data_loader:

                    data = data.to(device)
                    label = label.to(device)

                    optimizer.zero_grad()
                    # make prediction
                    posterior_predict_mean = self.forward(data)
                    # construct log posterior predictive distribution
                    posterior_predictive_dist = Categorical(posterior_predict_mean)
                    # calculate nlpd and update
                    nlpd = -posterior_predictive_dist.log_prob(label).mean()
                    nlpd.backward()
                    optimizer.step()
                    # log nlpd
                    running_nlpd += nlpd.item()
                total_train_nlpd.append(running_nlpd / len(data_loader))
                if verbose:
                    print(f"epoch {epoch + 1}, nlpd {total_train_nlpd[-1]}")


        if self.likelihood == 'regression':
            for epoch in tqdm(range(n_epoches), desc="fitting scaling factor"):
                running_nlpd = 0
                for (data, label) in data_loader:
                    data = data.to(device)
                    label = label.to(device)

                    optimizer.zero_grad()
                    # make prediction
                    posterior_predict_mean, posterior_predict_var = self.forward(data)
                    # construct log posterior predictive distribution
                    posterior_predictive_dist = Normal(posterior_predict_mean, posterior_predict_var.sqrt())
                    # calculate nlpd and update
                    nlpd = -posterior_predictive_dist.log_prob(label).mean()
                    nlpd.backward()
                    optimizer.step()
                    # log nlpd
                    running_nlpd += nlpd.item()

                total_train_nlpd.append(running_nlpd / len(data_loader))

                if verbose:
                    print(f"epoch {epoch + 1}, nlpd {total_train_nlpd[-1]}")

    return total_train_nlpd
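
The snippet below is a self-contained sketch of the objective optimised in the speedup branch above (it does not call the library): a single scalar divides the cached latent variances, and that scalar is fitted by minimising the NLPD of the probit-approximated Categorical predictive. The tensors are toy stand-ins for the cached f_mean, f_var, and labels.

```python
import numpy as np
import torch
from torch.distributions import Categorical

f_mean = torch.tensor([[2.0, 0.0], [-1.0, 1.5]])   # cached latent means,     [B, C]
f_var = torch.tensor([[4.0, 4.0], [4.0, 4.0]])     # cached latent variances, [B, C]
labels = torch.tensor([0, 1])

scale = torch.nn.Parameter(torch.tensor([1.0]))     # the scale factor being fitted
opt = torch.optim.Adam([scale], lr=1e-1)

for _ in range(100):
    opt.zero_grad()
    var = f_var / scale                              # scaled variance, as in the speedup branch
    kappa = 1 / torch.sqrt(1. + np.pi / 8 * var)
    probs = torch.softmax(kappa * f_mean, dim=-1)    # probit approximation
    nlpd = -Categorical(probs).log_prob(labels).mean()
    nlpd.backward()
    opt.step()

print(scale.item(), nlpd.item())                     # fitted scale and final NLPD
```
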

MLP Functions with Diagonal Covariance

Classes

SUQ_Linear_Diag

Bases: Module

Linear layer with uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

Wraps a standard nn.Linear layer and applies closed-form mean and variance propagation. See the SUQ paper for theoretical background and assumptions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| org_linear | Linear | The original linear layer to wrap | required |
| w_var | Tensor | Element-wise variance of the weights W. Shape: [D_out, D_in] | required |
| b_var | Tensor | Element-wise variance of the bias b. Shape: [D_out, ] | required |
Source code in suq/diag_suq_mlp.py
class SUQ_Linear_Diag(nn.Module):
    """
    Linear layer with uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

    Wraps a standard `nn.Linear` layer and applies closed-form mean and variance propagation. See the SUQ paper for theoretical background and assumptions.

    Args:
        org_linear (nn.Linear): The original linear layer to wrap      
        w_var (Tensor): Element-wise variance of the weights `W`. Shape: `[D_out, D_in]`
        b_var (Tensor): Element-wise variance of the bias `b`. Shape: `[D_out, ]`
    """
    def __init__(self, org_linear, w_var, b_var):
        super().__init__()

        self.weight = org_linear.weight.data
        self.bias = org_linear.bias.data
        self.w_var = w_var
        self.b_var = b_var

    def forward(self, a_mean, a_var): 
        """
        Forward pass with uncertainty propagation through a SUQ linear layer.

        Args:
            a_mean (Tensor): Input mean. Shape: `[B, D_in]`
            a_var (Tensor): Input element-wise variance. Shape: `[B, D_in]`

        Returns:
            h_mean (Tensor): Mean of the output `h`. Shape: `[B, D_out]`
            h_var (Tensor): Element-wise variance of the output `h`. Shape: `[B, D_out]`
        """

        if a_var is None:
            a_var = torch.zeros_like(a_mean).to(a_mean.device)

        h_mean, h_var = forward_aW_diag(a_mean, a_var, self.weight, self.bias, self.w_var, self.b_var)

        return h_mean, h_var
Functions
forward(a_mean, a_var)

Forward pass with uncertainty propagation through a SUQ linear layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| a_mean | Tensor | Input mean. Shape: [B, D_in] | required |
| a_var | Tensor | Input element-wise variance. Shape: [B, D_in] | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| h_mean | Tensor | Mean of the output h. Shape: [B, D_out] |
| h_var | Tensor | Element-wise variance of the output h. Shape: [B, D_out] |

Source code in suq/diag_suq_mlp.py
def forward(self, a_mean, a_var): 
    """
    Forward pass with uncertainty propagation through a SUQ linear layer.

    Args:
        a_mean (Tensor): Input mean. Shape: `[B, D_in]`
        a_var (Tensor): Input element-wise variance. Shape: `[B, D_in]`

    Returns:
        h_mean (Tensor): Mean of the output `h`. Shape: `[B, D_out]`
        h_var (Tensor): Element-wise variance of the output `h`. Shape: `[B, D_out]`
    """

    if a_var is None:
        a_var = torch.zeros_like(a_mean).to(a_mean.device)

    h_mean, h_var = forward_aW_diag(a_mean, a_var, self.weight, self.bias, self.w_var, self.b_var)

    return h_mean, h_var
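
A short usage sketch (assuming SUQ_Linear_Diag is importable from suq.diag_suq_mlp, as the source path above suggests); the weight and bias variances are illustrative constants rather than a real posterior.

```python
import torch
import torch.nn as nn
from suq.diag_suq_mlp import SUQ_Linear_Diag

layer = nn.Linear(4, 3)
w_var = 0.01 * torch.ones_like(layer.weight)   # illustrative posterior variances
b_var = 0.01 * torch.ones_like(layer.bias)

suq_layer = SUQ_Linear_Diag(layer, w_var, b_var)

a_mean = torch.randn(8, 4)                     # input mean,     [B, D_in]
a_var = 0.1 * torch.ones(8, 4)                 # input variance, [B, D_in]
h_mean, h_var = suq_layer(a_mean, a_var)       # both [B, D_out]
print(h_mean.shape, h_var.shape)
```
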

SUQ_Activation_Diag

Bases: Module

Activation layer with closed-form uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

Wraps a standard activation function and applies a first-order approximation to propagate input variance through the nonlinearity. See the SUQ paper for theoretical background and assumptions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| afun | Callable | A PyTorch activation function (e.g. nn.ReLU()) | required |
Source code in suq/diag_suq_mlp.py
class SUQ_Activation_Diag(nn.Module):
    """
    Activation layer with closed-form uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

    Wraps a standard activation function and applies a first-order approximation to propagate input variance through the nonlinearity. See the SUQ paper for theoretical background and assumptions.

    Args:
        afun (Callable): A PyTorch activation function (e.g. `nn.ReLU()`)
    """

    def __init__(self, afun):        
        super().__init__()
        self.afun = afun

    def forward(self, h_mean, h_var):
        """
        Forward pass with uncertainty propagation through a SUQ activation layer.

        Args:
            h_mean (Tensor): Mean of the pre-activations `h`. Shape: `[B, D]`
            h_var (Tensor): Element-wise variance of the pre-activation `h`. Shape: `[B, D]`

        Returns:
            a_mean (Tensor): Mean of the activation `a`. Shape: [B, D]
            a_var (Tensor): Element-wise variance of the activation `a`. Shape: `[B, D]`
        """
        a_mean, a_var = forward_activation_implicit_diag(self.afun, h_mean, h_var)
        return a_mean, a_var
Functions
forward(h_mean, h_var)

Forward pass with uncertainty propagation through a SUQ activation layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| h_mean | Tensor | Mean of the pre-activations h. Shape: [B, D] | required |
| h_var | Tensor | Element-wise variance of the pre-activations h. Shape: [B, D] | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| a_mean | Tensor | Mean of the activations a. Shape: [B, D] |
| a_var | Tensor | Element-wise variance of the activations a. Shape: [B, D] |

Source code in suq/diag_suq_mlp.py
def forward(self, h_mean, h_var):
    """
    Forward pass with uncertainty propagation through a SUQ activation layer.

    Args:
        h_mean (Tensor): Mean of the pre-activations `h`. Shape: `[B, D]`
        h_var (Tensor): Element-wise variance of the pre-activation `h`. Shape: `[B, D]`

    Returns:
        a_mean (Tensor): Mean of the activation `a`. Shape: [B, D]
        a_var (Tensor): Element-wise variance of the activation `a`. Shape: `[B, D]`
    """
    a_mean, a_var = forward_activation_implicit_diag(self.afun, h_mean, h_var)
    return a_mean, a_var
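
A short usage sketch (same importability assumption as above). With ReLU, the first-order rule zeroes the output variance wherever the pre-activation mean is negative.

```python
import torch
import torch.nn as nn
from suq.diag_suq_mlp import SUQ_Activation_Diag

act = SUQ_Activation_Diag(nn.ReLU())
h_mean = torch.randn(8, 16)                 # pre-activation mean
h_var = 0.05 * torch.ones(8, 16)            # pre-activation variance
a_mean, a_var = act(h_mean, h_var)          # a_var is 0 where h_mean < 0
print(a_mean.shape, a_var.shape)
```
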

SUQ_BatchNorm_Diag

Bases: Module

BatchNorm layer with closed-form uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

Wraps nn.BatchNorm1d and adjusts input variance using batch normalization statistics and scale parameters. See the SUQ paper for theoretical background and assumptions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| BatchNorm | BatchNorm1d | The original batch norm layer | required |
Source code in suq/diag_suq_mlp.py
class SUQ_BatchNorm_Diag(nn.Module):
    """
    BatchNorm layer with closed-form uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

    Wraps `nn.BatchNorm1d` and adjusts input variance using batch normalization statistics and scale parameters. See the SUQ paper for theoretical background and assumptions.

    Args:
        BatchNorm (nn.BatchNorm1d): The original batch norm layer
    """

    def __init__(self, BatchNorm):
        super().__init__()

        self.BatchNorm = BatchNorm

    def forward(self, x_mean, x_var):
        """
        Forward pass with uncertainty propagation through a SUQ BatchNorm layer.

        Args:
            x_mean (Tensor): Input mean. Shape: [B, D]
            x_var (Tensor): Input element-wise variance. Shape: [B, D]

        Returns:
            out_mean (Tensor): Output mean after batch normalization. Shape: [B, D]
            out_var (Tensor): Output element-wise variance after batch normalization. Shape: [B, D]
        """

        with torch.no_grad():

            out_mean = self.BatchNorm.forward(x_mean)
            out_var = forward_batch_norm_diag(x_mean, x_var, self.BatchNorm.weight, 1e-5)

        return out_mean, out_var
Functions
forward(x_mean, x_var)

Forward pass with uncertainty propagation through a SUQ BatchNorm layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x_mean | Tensor | Input mean. Shape: [B, D] | required |
| x_var | Tensor | Input element-wise variance. Shape: [B, D] | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| out_mean | Tensor | Output mean after batch normalization. Shape: [B, D] |
| out_var | Tensor | Output element-wise variance after batch normalization. Shape: [B, D] |

Source code in suq/diag_suq_mlp.py
def forward(self, x_mean, x_var):
    """
    Forward pass with uncertainty propagation through a SUQ BatchNorm layer.

    Args:
        x_mean (Tensor): Input mean. Shape: [B, D]
        x_var (Tensor): Input element-wise variance. Shape: [B, D]

    Returns:
        out_mean (Tensor): Output mean after batch normalization. Shape: [B, D]
        out_var (Tensor): Output element-wise variance after batch normalization. Shape: [B, D]
    """

    with torch.no_grad():

        out_mean = self.BatchNorm.forward(x_mean)
        out_var = forward_batch_norm_diag(x_mean, x_var, self.BatchNorm.weight, 1e-5)

    return out_mean, out_var

SUQ_MLP_Diag

Bases: SUQ_Base

Multilayer perceptron model with closed-form uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

Wraps a standard MLP, converting its layers into SUQ-compatible components. Supports both classification and regression via predictive Gaussian approximation.

Note

The input model should correspond to the latent function only:

  • For regression, this is the full model (including the final output layer).
  • For classification, exclude the softmax layer and pass only the logit-producing part.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| org_model | Module | The original MLP model to convert | required |
| posterior_variance | Tensor | Flattened posterior variance vector | required |
| likelihood | str | Either 'classification' or 'regression' | required |
| scale_init | float | Initial scale factor | 1.0 |
| sigma_noise | float | Noise level (for regression) | None |
Source code in suq/diag_suq_mlp.py
class SUQ_MLP_Diag(SUQ_Base):
    """
    Multilayer perceptron model with closed-form uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

    Wraps a standard MLP, converting its layers into SUQ-compatible components.
    Supports both classification and regression via predictive Gaussian approximation.

    Note:
        The input model should correspond to the latent function only:
        - For regression, this is the full model (including final output layer).
        - For classification, exclude the softmax layer and pass only the logit-producing part.

    Args:
        org_model (nn.Module): The original MLP model to convert
        posterior_variance (Tensor): Flattened posterior variance vector
        likelihood (str): Either 'classification' or 'regression'
        scale_init (float, optional): Initial scale factor
        sigma_noise (float, optional): noise level (for regression)
    """

    def __init__(self, org_model, posterior_variance, likelihood, scale_init = 1.0, sigma_noise = None):
        super().__init__(likelihood, scale_init)

        self.sigma_noise = sigma_noise
        self.convert_model(org_model, posterior_variance)

    def forward_latent(self, data, out_var = None):
        """
        Compute the predictive mean and variance of the latent function before applying the likelihood.

        Traverses the model layer by layer, propagating mean and variance through each SUQ-wrapped layer.

        Args:
            data (Tensor): Input data. Shape: [B, D_in]
            out_var (Tensor or None): Optional input variance. Shape: [B, D_in]

        Returns:
            out_mean (Tensor): Output mean after final layer. Shape: [B, D_out]
            out_var (Tensor): Output element-wise variance after final layer. Shape: [B, D_out]
        """

        out_mean = data

        if isinstance(self.model, nn.Sequential):
            for layer in self.model:
                out_mean, out_var = layer.forward(out_mean, out_var)
        ##TODO: other type of model            

        out_var = out_var / self.scale_factor

        return out_mean, out_var

    def forward(self, data):
        """
        Compute the predictive distribution based on the model's likelihood setting.

        For classification, applies a probit approximation and returns class probabilities.
        For regression, returns the latent mean and the total predictive variance.

        Args:
            data (Tensor): Input data. Shape: [B, D]

        Returns:
            If classification:
                Tensor: Class probabilities. Shape: [B, num_classes]
            If regression:
                Tuple[Tensor, Tensor]: Output mean and element-wise variance. Shape: [B, D_out]
        """

        out_mean, out_var = self.forward_latent(data)

        if self.likelihood == 'classification':
            kappa = 1 / torch.sqrt(1. + np.pi / 8 * out_var)
            return torch.softmax(kappa * out_mean, dim=-1)

        if self.likelihood == 'regression':
            return out_mean, out_var + self.sigma_noise ** 2

    def convert_model(self, org_model, posterior_variance):
        """
        Converts a deterministic MLP into a SUQ-compatible model with diagonal posterior.

        Each layer is replaced with its corresponding SUQ module (e.g. linear, activation, batchnorm), using the provided flattened posterior variance vector.

        Args:
            org_model (nn.Module): The original model to convert (latent function only)
            posterior_variance (Tensor): Flattened posterior variance for Bayesian parameters
        """

        p_model = copy.deepcopy(org_model)

        loc = 0
        for n, layer in p_model.named_modules():
            if isinstance(layer, nn.Linear):

                D_out, D_in = layer.weight.data.shape
                num_param = torch.numel(parameters_to_vector(layer.parameters()))
                num_weight_param = D_out * D_in

                covariance_block = posterior_variance[loc : loc + num_param]

                b_var = torch.zeros_like(layer.bias.data).to(layer.bias.data.device)
                w_var = torch.zeros_like(layer.weight.data).to(layer.bias.data.device)

                for k in range(D_out):
                    b_var[k] = covariance_block[num_weight_param + k]
                    for i in range(D_in):
                        w_var[k][i] = covariance_block[k * D_in + i]

                new_layer = SUQ_Linear_Diag(layer, w_var, b_var)

                loc += num_param
                setattr(p_model, n, new_layer)

            if isinstance(layer, nn.BatchNorm1d):
                new_layer = SUQ_BatchNorm_Diag(layer)
                setattr(p_model, n, new_layer)

            if type(layer).__name__ in torch.nn.modules.activation.__all__:
                new_layer = SUQ_Activation_Diag(layer)
                setattr(p_model, n, new_layer)

        self.model = p_model
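
An end-to-end usage sketch, assuming SUQ_MLP_Diag is importable from suq.diag_suq_mlp. In practice posterior_variance would come from an approximate posterior over the flattened parameters (for example a diagonal Laplace fit); a constant vector is used here purely for illustration, and inputs may need to be moved to whatever device the library selects internally.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector
from suq.diag_suq_mlp import SUQ_MLP_Diag

latent_model = nn.Sequential(        # latent function only: no softmax at the end
    nn.Linear(10, 32),
    nn.ReLU(),
    nn.Linear(32, 3),
)

n_params = parameters_to_vector(latent_model.parameters()).numel()
posterior_variance = 0.01 * torch.ones(n_params)   # illustrative stand-in for a fitted posterior

suq_mlp = SUQ_MLP_Diag(latent_model, posterior_variance,
                       likelihood='classification', scale_init=1.0)

x = torch.randn(5, 10)
probs = suq_mlp(x)                          # class probabilities, [5, 3]
f_mean, f_var = suq_mlp.forward_latent(x)   # latent mean/variance, both [5, 3]
print(probs.shape, f_mean.shape, f_var.shape)
```
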
Functions
forward_latent(data, out_var=None)

Compute the predictive mean and variance of the latent function before applying the likelihood.

Traverses the model layer by layer, propagating mean and variance through each SUQ-wrapped layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data | Tensor | Input data. Shape: [B, D_in] | required |
| out_var | Tensor or None | Optional input variance. Shape: [B, D_in] | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| out_mean | Tensor | Output mean after the final layer. Shape: [B, D_out] |
| out_var | Tensor | Output element-wise variance after the final layer. Shape: [B, D_out] |

Source code in suq/diag_suq_mlp.py
def forward_latent(self, data, out_var = None):
    """
    Compute the predictive mean and variance of the latent function before applying the likelihood.

    Traverses the model layer by layer, propagating mean and variance through each SUQ-wrapped layer.

    Args:
        data (Tensor): Input data. Shape: [B, D_in]
        out_var (Tensor or None): Optional input variance. Shape: [B, D_in]

    Returns:
        out_mean (Tensor): Output mean after final layer. Shape: [B, D_out]
        out_var (Tensor): Output element-wise variance after final layer. Shape: [B, D_out]
    """

    out_mean = data

    if isinstance(self.model, nn.Sequential):
        for layer in self.model:
            out_mean, out_var = layer.forward(out_mean, out_var)
    ##TODO: other type of model            

    out_var = out_var / self.scale_factor

    return out_mean, out_var
forward(data)

Compute the predictive distribution based on the model's likelihood setting.

For classification, applies a probit approximation and returns class probabilities. For regression, returns the latent mean and the total predictive variance.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data | Tensor | Input data. Shape: [B, D] | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | If classification: class probabilities. Shape: [B, num_classes] |
| Tuple[Tensor, Tensor] | If regression: output mean and element-wise variance. Shape: [B, D_out] |
Source code in suq/diag_suq_mlp.py
def forward(self, data):
    """
    Compute the predictive distribution based on the model's likelihood setting.

    For classification, applies a probit approximation and returns class probabilities.
    For regression, returns the latent mean and the total predictive variance.

    Args:
        data (Tensor): Input data. Shape: [B, D]

    Returns:
        If classification:
            Tensor: Class probabilities. Shape: [B, num_classes]
        If regression:
            Tuple[Tensor, Tensor]: Output mean and element-wise variance. Shape: [B, D_out]
    """

    out_mean, out_var = self.forward_latent(data)

    if self.likelihood == 'classification':
        kappa = 1 / torch.sqrt(1. + np.pi / 8 * out_var)
        return torch.softmax(kappa * out_mean, dim=-1)

    if self.likelihood == 'regression':
        return out_mean, out_var + self.sigma_noise ** 2
convert_model(org_model, posterior_variance)

Converts a deterministic MLP into a SUQ-compatible model with diagonal posterior.

Each layer is replaced with its corresponding SUQ module (e.g. linear, activation, batchnorm), using the provided flattened posterior variance vector.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| org_model | Module | The original model to convert (latent function only) | required |
| posterior_variance | Tensor | Flattened posterior variance for the Bayesian parameters | required |
Source code in suq/diag_suq_mlp.py
def convert_model(self, org_model, posterior_variance):
    """
    Converts a deterministic MLP into a SUQ-compatible model with diagonal posterior.

    Each layer is replaced with its corresponding SUQ module (e.g. linear, activation, batchnorm), using the provided flattened posterior variance vector.

    Args:
        org_model (nn.Module): The original model to convert (latent function only)
        posterior_variance (Tensor): Flattened posterior variance for Bayesian parameters
    """

    p_model = copy.deepcopy(org_model)

    loc = 0
    for n, layer in p_model.named_modules():
        if isinstance(layer, nn.Linear):

            D_out, D_in = layer.weight.data.shape
            num_param = torch.numel(parameters_to_vector(layer.parameters()))
            num_weight_param = D_out * D_in

            covariance_block = posterior_variance[loc : loc + num_param]

            b_var = torch.zeros_like(layer.bias.data).to(layer.bias.data.device)
            w_var = torch.zeros_like(layer.weight.data).to(layer.bias.data.device)

            for k in range(D_out):
                b_var[k] = covariance_block[num_weight_param + k]
                for i in range(D_in):
                    w_var[k][i] = covariance_block[k * D_in + i]

            new_layer = SUQ_Linear_Diag(layer, w_var, b_var)

            loc += num_param
            setattr(p_model, n, new_layer)

        if isinstance(layer, nn.BatchNorm1d):
            new_layer = SUQ_BatchNorm_Diag(layer)
            setattr(p_model, n, new_layer)

        if type(layer).__name__ in torch.nn.modules.activation.__all__:
            new_layer = SUQ_Activation_Diag(layer)
            setattr(p_model, n, new_layer)

    self.model = p_model
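
A small sketch of the parameter layout the loop above assumes: each nn.Linear consumes D_out * D_in weight variances in row-major order, followed by D_out bias variances, from the flattened posterior_variance vector (the element-wise loop is equivalent to a reshape).

```python
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector

layer = nn.Linear(3, 2)                                        # D_in = 3, D_out = 2
num_param = parameters_to_vector(layer.parameters()).numel()   # 2*3 weights + 2 biases = 8

covariance_block = torch.arange(float(num_param))              # stand-in variance slice
w_var = covariance_block[:2 * 3].reshape(2, 3)                 # matches w_var[k][i] = block[k * D_in + i]
b_var = covariance_block[2 * 3:]                               # matches b_var[k] = block[num_weight_param + k]
print(w_var)
print(b_var)
```
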

Functions

forward_aW_diag(a_mean, a_var, weight, bias, w_var, b_var)

Compute the mean and element-wise variance of h = a @ W^T + b when the posterior has diagonal covariance.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| a_mean | Tensor | Mean of the input a. Shape: [B, D_in] | required |
| a_var | Tensor | Variance of the input a. Shape: [B, D_in] | required |
| weight | Tensor | Mean of the weights W. Shape: [D_out, D_in] | required |
| bias | Tensor | Mean of the bias b. Shape: [D_out, ] | required |
| w_var | Tensor | Element-wise variance of the weights W. Shape: [D_out, D_in] | required |
| b_var | Tensor | Element-wise variance of the bias b. Shape: [D_out, ] | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| h_mean | Tensor | Mean of the pre-activations h. Shape: [B, D_out] |
| h_var | Tensor | Element-wise variance of the pre-activations h. Shape: [B, D_out] |

Source code in suq/diag_suq_mlp.py
def forward_aW_diag(a_mean, a_var, weight, bias, w_var, b_var):
    """
    Compute the mean and element-wise variance of `h = a @ W^T + b` when the posterior has diagonal covariance.

    Args:
        a_mean (Tensor): Mean of the input `a`. Shape: `[B, D_in]`.
        a_var (Tensor): Variance of the input `a`. Shape: `[B, D_in] `
        weight (Tensor): Mean of the weights `W`. Shape: `[D_out, D_in]`
        bias (Tensor): Mean of the bias `b`. Shape: `[D_out, ]`
        b_var (Tensor): Element-wise variance of the bias `b`. Shape: `[D_out, ]`
        w_var (Tensor): Element-wise variance of the weights `W`. Shape: `[D_out, D_in]`

    Returns: 
        h_mean (Tensor): Mean of the pre-activations `h`. Shape: `[B, D_out]`
        h_var (Tensor): Element-wise variance of the pre-activations `h`. Shape: `[B, D_out]`
    """

    # calculate mean(h)
    h_mean = F.linear(a_mean, weight, bias)

    # calculate var(h)
    weight_mean2_var_sum = weight ** 2 + w_var # [D_out, D_in]
    h_var = a_mean **2 @ w_var.T + a_var @ weight_mean2_var_sum.T + b_var

    return h_mean, h_var
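
A quick Monte Carlo sanity check of the closed-form moments above (not part of the library): sample a, W, and b independently from Gaussians with the stated means and variances, and compare the empirical moments of h = a @ W^T + b with the analytic ones.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
S, D_in, D_out = 100_000, 4, 3

a_mean, a_var = torch.randn(1, D_in), 0.3 * torch.ones(1, D_in)
weight, w_var = torch.randn(D_out, D_in), 0.2 * torch.ones(D_out, D_in)
bias, b_var = torch.randn(D_out), 0.1 * torch.ones(D_out)

# closed form, exactly as in forward_aW_diag
h_mean = F.linear(a_mean, weight, bias)
h_var = a_mean ** 2 @ w_var.T + a_var @ (weight ** 2 + w_var).T + b_var

# Monte Carlo estimate
a = a_mean + a_var.sqrt() * torch.randn(S, D_in)
W = weight + w_var.sqrt() * torch.randn(S, D_out, D_in)
b = bias + b_var.sqrt() * torch.randn(S, D_out)
h = torch.einsum('sij,sj->si', W, a) + b

print(h_mean, h.mean(0))   # analytic vs. empirical mean (close)
print(h_var, h.var(0))     # analytic vs. empirical variance (close)
```
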

forward_activation_implicit_diag(activation_func, h_mean, h_var)

Approximate the distribution of a = g(h) given h ~ N(h_mean, h_var), where h_var is the element-wise variance of the pre-activation h. Uses a first-order Taylor expansion, applied element-wise: a ~ N(g(h_mean), g'(h_mean)^2 * h_var).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| activation_func | Callable | A PyTorch activation function g(·) (e.g. nn.ReLU()) | required |
| h_mean | Tensor | Mean of the pre-activations h. Shape: [B, D] | required |
| h_var | Tensor | Element-wise variance of the pre-activations h. Shape: [B, D] | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| a_mean | Tensor | Mean of the activations a. Shape: [B, D] |
| a_var | Tensor | Element-wise variance of the activations a. Shape: [B, D] |

Source code in suq/diag_suq_mlp.py
def forward_activation_implicit_diag(activation_func, h_mean, h_var):

    """
    Approximate the distribution of `a = g(h)` given `h ~ N(h_mean, h_var)`, where `h_var`
    is the element-wise variance of the pre-activation `h`.
    Uses a first-order Taylor expansion, applied element-wise: `a ~ N(g(h_mean), g'(h_mean)^2 * h_var)`.

    Args:
        activation_func (Callable): A PyTorch activation function `g(·)` (e.g. `nn.ReLU()`)
        h_mean (Tensor): Mean of the pre-activations `h`. Shape: `[B, D]`
        h_var (Tensor): Element-wise variance of the pre-activations `h`. Shape: `[B, D]`

    Returns:
        a_mean (Tensor): Mean of the activations `a`. Shape: `[B, D]`
        a_var (Tensor): Element-wise variance of the activations `a`. Shape: `[B, D]`
    """

    h_mean_grad = h_mean.detach().clone().requires_grad_()

    a_mean = activation_func(h_mean_grad)
    a_mean.retain_grad()
    a_mean.backward(torch.ones_like(a_mean)) #[N, D]

    nabla = h_mean_grad.grad #[N, D]
    a_var = nabla ** 2 * h_var

    return a_mean.detach(), a_var
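
A small check of the first-order (delta method) rule a_var = g'(h_mean)^2 * h_var, using torch.tanh, whose derivative 1 - tanh(x)^2 is known in closed form; the import assumes the function is exposed by suq.diag_suq_mlp as the source path suggests.

```python
import torch
from suq.diag_suq_mlp import forward_activation_implicit_diag

h_mean = torch.randn(4, 5)
h_var = 0.2 * torch.ones(4, 5)

a_mean, a_var = forward_activation_implicit_diag(torch.tanh, h_mean, h_var)

expected_var = (1 - torch.tanh(h_mean) ** 2) ** 2 * h_var   # g'(x) = 1 - tanh(x)^2
print(torch.allclose(a_var, expected_var, atol=1e-6))       # True
```
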

forward_batch_norm_diag(h_var, bn_weight, bn_running_var, bn_eps)

Compute the output variance when a distribution h ~ N(h_mean, h_var) is passed through a BatchNorm layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| h_var | Tensor | Element-wise variance of the input h. Shape: [B, D] | required |
| bn_weight | Tensor | Batch normalization scale factor (gamma). Shape: [D,] | required |
| bn_running_var | Tensor | Running variance used in batch normalization. Shape: [D,] | required |
| bn_eps | float | Small constant added to the denominator for numerical stability. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| output_var | Tensor | Element-wise variance of the output after batch normalization. Shape: [B, D] |

Source code in suq/diag_suq_mlp.py
def forward_batch_norm_diag(h_var, bn_weight, bn_running_var, bn_eps):

    """
    Compute the output variance when a distribution `h ~ N(h_mean, h_var)`
    is passed through a BatchNorm layer.

    Args:
        h_var (Tensor): Element-wise variance of the input `h`. Shape: `[B, D]`.
        bn_weight (Tensor): Batch normalization scale factor (gamma). Shape: `[D,]`.
        bn_running_var (Tensor): Running variance used in batch normalization. Shape: `[D,]`.
        bn_eps (float): Small constant added to the denominator for numerical stability.

    Returns:
        output_var (Tensor): Element-wise variance of the output after batch normalization. Shape: `[B, D]`.
    """

    scale_factor = (1 / (bn_running_var.reshape(1, -1) + bn_eps)) * bn_weight.reshape(1, -1) **2 # [B, D]
    output_var = scale_factor * h_var # [B, D]

    return output_var
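
The same scaling written out directly as a minimal numeric sketch (not a library call): each feature's variance is multiplied by gamma_d^2 / (running_var_d + eps).

```python
import torch

h_var = torch.tensor([[0.5, 1.0, 2.0]])          # input variance, [B=1, D=3]
bn_weight = torch.tensor([1.0, 0.5, 2.0])        # gamma
bn_running_var = torch.tensor([1.0, 4.0, 0.25])
bn_eps = 1e-5

scale = bn_weight ** 2 / (bn_running_var + bn_eps)
output_var = scale.reshape(1, -1) * h_var
print(output_var)                                 # approx [[0.5, 0.0625, 32.0]]
```
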

Transformer Functions with Diagonal Covariance

Classes

SUQ_LayerNorm_Diag

Bases: Module

LayerNorm module with uncertainty propagation under SUQ.

Wraps nn.LayerNorm and propagates input variance analytically using running statistics. See the SUQ paper for theoretical background and assumptions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| LayerNorm | LayerNorm | The original layer norm module to wrap | required |
Source code in suq/diag_suq_transformer.py
class SUQ_LayerNorm_Diag(nn.Module):
    """
    LayerNorm module with uncertainty propagation under SUQ.

    Wraps `nn.LayerNorm` and propagates input variance analytically using running statistics. See the SUQ paper for theoretical background and assumptions.

    Args:
        LayerNorm (nn.LayerNorm): The original layer norm module to wrap
    """

    def __init__(self, LayerNorm):
        super().__init__()

        self.LayerNorm = LayerNorm

    def forward(self, x_mean, x_var):
        """
        Forward pass with uncertainty propagation through a SUQ LayerNorm layer.

        Args:
            x_mean (Tensor): Input mean. Shape: [B, T, D]
            x_var (Tensor): Input element-wise variance. Shape: [B, T, D]

        Returns:
            out_mean (Tensor): Output mean after layer normalization. Shape: [B, T, D]
            out_var (Tensor): Output element-wise variance after layer normalization. Shape: [B, T, D]
        """

        with torch.no_grad():

            out_mean = self.LayerNorm.forward(x_mean)
            out_var = forward_layer_norm_diag(x_mean, x_var, self.LayerNorm.weight, 1e-5)

        return out_mean, out_var
Functions
forward(x_mean, x_var)

Forward pass with uncertainty propagation through a SUQ LayerNorm layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x_mean | Tensor | Input mean. Shape: [B, T, D] | required |
| x_var | Tensor | Input element-wise variance. Shape: [B, T, D] | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| out_mean | Tensor | Output mean after layer normalization. Shape: [B, T, D] |
| out_var | Tensor | Output element-wise variance after layer normalization. Shape: [B, T, D] |

Source code in suq/diag_suq_transformer.py
def forward(self, x_mean, x_var):
    """
    Forward pass with uncertainty propagation through a SUQ LayerNorm layer.

    Args:
        x_mean (Tensor): Input mean. Shape: [B, T, D]
        x_var (Tensor): Input element-wise variance. Shape: [B, T, D]

    Returns:
        out_mean (Tensor): Output mean after layer normalization. Shape: [B, T, D]
        out_var (Tensor): Output element-wise variance after layer normalization. Shape: [B, T, D]
    """

    with torch.no_grad():

        out_mean = self.LayerNorm.forward(x_mean)
        out_var = forward_layer_norm_diag(x_mean, x_var, self.LayerNorm.weight, 1e-5)

    return out_mean, out_var

SUQ_Classifier_Diag

Bases: Module

Classifier head with uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

Wraps a standard linear classifier and applies closed-form mean and variance propagation. See the SUQ paper for theoretical background and assumptions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| classifier | Linear | The final classification head | required |
| w_var | Tensor | Element-wise variance of the weight. Shape: [D_out, D_in] | required |
| b_var | Tensor | Element-wise variance of the bias. Shape: [D_out] | required |
Source code in suq/diag_suq_transformer.py
class SUQ_Classifier_Diag(nn.Module):
    """
    Classifier head with uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

    Wraps a standard linear classifier and applies closed-form mean and variance propagation.
    See the SUQ paper for theoretical background and assumptions.

    Args:
        classifier (nn.Linear): The final classification head
        w_var (Tensor): Element-wise variance of weight. Shape: `[D_out, D_in]`
        b_var (Tensor): Element-wise variance of bias. Shape: `[D_out]`
    """

    def __init__(self, classifier, w_var, b_var):
        super().__init__()

        self.weight = classifier.weight
        self.bias = classifier.bias
        self.w_var = w_var.reshape(self.weight.shape)
        self.b_var = b_var.reshape(self.bias.shape)

    def forward(self, x_mean, x_var):
        """
        Forward pass with uncertainty propagation through a SUQ linear layer.

        Args:
            x_mean (Tensor): Input mean. Shape: `[B, D_in]`
            x_var (Tensor): Input element-wise variance. Shape: `[B, D_in]`

        Returns:
            h_mean (Tensor): Output mean. Shape: `[B, D_out]`
            h_var (Tensor): Output element-wise variance. Shape: `[B, D_out]`
        """
        with torch.no_grad():
            h_mean, h_var = forward_aW_diag(x_mean, x_var, self.weight.data, self.bias.data, self.w_var, self.b_var)
        return h_mean, h_var
Functions
forward(x_mean, x_var)

Forward pass with uncertainty propagation through a SUQ linear layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x_mean | Tensor | Input mean. Shape: [B, D_in] | required |
| x_var | Tensor | Input element-wise variance. Shape: [B, D_in] | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| h_mean | Tensor | Output mean. Shape: [B, D_out] |
| h_var | Tensor | Output element-wise variance. Shape: [B, D_out] |

Source code in suq/diag_suq_transformer.py
def forward(self, x_mean, x_var):
    """
    Forward pass with uncertainty propagation through a SUQ linear layer.

    Args:
        x_mean (Tensor): Input mean. Shape: `[B, D_in]`
        x_var (Tensor): Input element-wise variance. Shape: `[B, D_in]`

    Returns:
        h_mean (Tensor): Output mean. Shape: `[B, D_out]`
        h_var (Tensor): Output element-wise variance. Shape: `[B, D_out]`
    """
    with torch.no_grad():
        h_mean, h_var = forward_aW_diag(x_mean, x_var, self.weight.data, self.bias.data, self.w_var, self.b_var)
    return h_mean, h_var

SUQ_TransformerMLP_Diag

Bases: Module

MLP submodule of a transformer block with uncertainty propagation under SUQ.

Supports both deterministic and Bayesian forward modes with closed-form variance propagation. Used internally in SUQ_Transformer_Block_Diag.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| MLP | Module | Original MLP submodule | required |
| determinstic | bool | Whether to treat the MLP weights as deterministic | True |
| w_fc_var | Tensor | Variance of the first linear layer in the MLP (if Bayesian) | None |
| w_proj_var | Tensor | Variance of the second linear layer in the MLP (if Bayesian) | None |
Source code in suq/diag_suq_transformer.py
class SUQ_TransformerMLP_Diag(nn.Module):
    """
    MLP submodule of a transformer block with uncertainty propagation under SUQ.

    Supports both deterministic and Bayesian forward modes with closed-form variance propagation.
    Used internally in `SUQ_Transformer_Block_Diag`.

    Args:
        MLP (nn.Module): Original MLP submodule
        determinstic (bool): Whether to treat the MLP weights as deterministic
        w_fc_var (Tensor, optional): Variance of the first linear layer in the MLP (if Bayesian)
        w_proj_var (Tensor, optional): Variance of the second linear layer in the MLP (if Bayesian)
    """

    def __init__(self, MLP, determinstic = True, w_fc_var = None, w_proj_var = None):
        super().__init__()

        self.MLP = MLP
        self.determinstic = determinstic
        if not determinstic:
            self.w_fc_var = w_fc_var.reshape(self.MLP.c_fc.weight.shape)
            self.w_proj_var = w_proj_var.reshape(self.MLP.c_proj.weight.shape)

    def forward(self, x_mean, x_var):
        """
        Forward pass with uncertainty propagation through a SUQ Transformer MLP layer.

        Args:
            x_mean (Tensor): Input mean. Shape [B, T, D]
            x_var (Tensor): Input element-wise variance. Shape [B, T, D]

        Returns:
            h_mean (Tensor): Output mean. Shape [B, T, D]
            h_var (Tensor): Output element-wise variance. Shape [B, T, D]
        """

        # first fc layer
        with torch.no_grad():
            if self.determinstic:
                h_mean, h_var = forward_linear_diag_determinstic_weight(x_mean, x_var, self.MLP.c_fc.weight.data, self.MLP.c_fc.bias.data)
            else:
                h_mean, h_var = forward_linear_diag_Bayesian_weight(x_mean, x_var, self.MLP.c_fc.weight.data, self.w_fc_var, self.MLP.c_fc.bias.data)
        # activation function
        h_mean, h_var = forward_activation_diag(self.MLP.gelu, h_mean, h_var)
        # second fc layer
        with torch.no_grad():
            if self.determinstic:
                h_mean, h_var = forward_linear_diag_determinstic_weight(h_mean, h_var, self.MLP.c_proj.weight.data, self.MLP.c_proj.bias.data)
            else:
                h_mean, h_var = forward_linear_diag_Bayesian_weight(h_mean, h_var, self.MLP.c_proj.weight.data, self.w_proj_var, self.MLP.c_proj.bias.data)

        return h_mean, h_var
Functions
forward(x_mean, x_var)

Forward pass with uncertainty propagation through a SUQ Transformer MLP layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x_mean | Tensor | Input mean. Shape: [B, T, D] | required |
| x_var | Tensor | Input element-wise variance. Shape: [B, T, D] | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| h_mean | Tensor | Output mean. Shape: [B, T, D] |
| h_var | Tensor | Output element-wise variance. Shape: [B, T, D] |

Source code in suq/diag_suq_transformer.py
def forward(self, x_mean, x_var):
    """
    Forward pass with uncertainty propagation through a SUQ Transformer MLP layer.

    Args:
        x_mean (Tensor): Input mean. Shape [B, T, D]
        x_var (Tensor): Input element-wise variance. Shape [B, T, D]

    Returns:
        h_mean (Tensor): Output mean. Shape [B, T, D]
        h_var (Tensor): Output element-wise variance. Shape [B, T, D]
    """

    # first fc layer
    with torch.no_grad():
        if self.determinstic:
            h_mean, h_var = forward_linear_diag_determinstic_weight(x_mean, x_var, self.MLP.c_fc.weight.data, self.MLP.c_fc.bias.data)
        else:
            h_mean, h_var = forward_linear_diag_Bayesian_weight(x_mean, x_var, self.MLP.c_fc.weight.data, self.w_fc_var, self.MLP.c_fc.bias.data)
    # activation function
    h_mean, h_var = forward_activation_diag(self.MLP.gelu, h_mean, h_var)
    # second fc layer
    with torch.no_grad():
        if self.determinstic:
            h_mean, h_var = forward_linear_diag_determinstic_weight(h_mean, h_var, self.MLP.c_proj.weight.data, self.MLP.c_proj.bias.data)
        else:
            h_mean, h_var = forward_linear_diag_Bayesian_weight(h_mean, h_var, self.MLP.c_proj.weight.data, self.w_proj_var, self.MLP.c_proj.bias.data)

    return h_mean, h_var

SUQ_Attention_Diag

Bases: Module

Self-attention module with uncertainty propagation under SUQ.

Supports deterministic and Bayesian value projections, with optional diagonal covariance assumptions. For details, see Section A.6 of the SUQ paper. Used internally in SUQ_Transformer_Block_Diag.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| Attention | Module | The original attention module | required |
| determinstic | bool | Whether to treat value projections as deterministic | True |
| diag_cov | bool | If True, only compute the diagonal covariance for the value | False |
| W_v_var | Tensor | Posterior variance for the value matrix (if Bayesian) | None |
Source code in suq/diag_suq_transformer.py
class SUQ_Attention_Diag(nn.Module):
    """
    Self-attention module with uncertainty propagation under SUQ.

    Supports deterministic and Bayesian value projections, with optional diagonal covariance assumptions.
    For details, see Section A.6 of the SUQ paper. Used internally in `SUQ_Transformer_Block_Diag`.

    Args:
        Attention (nn.Module): The original attention module
        determinstic (bool): Whether to treat value projections as deterministic
        diag_cov (bool): If True, only compute the diagonal covariance for the value
        W_v_var (Tensor, optional): Posterior variance for the value matrix (if Bayesian)
    """

    def __init__(self, Attention, determinstic = True, diag_cov = False, W_v_var = None):
        super().__init__()

        self.Attention = Attention
        self.determinstic = determinstic
        self.diag_cov = diag_cov

        if not self.determinstic:
            self.W_v_var = W_v_var # [D * D]

    def forward(self, x_mean, x_var):
        """
        Forward pass with uncertainty propagation through a SUQ Attention layer.

        Args:
            x_mean (Tensor): Input mean. Shape [B, T, D]
            x_var (Tensor): Input element-wise variance. Shape [B, T, D]

        Returns:
            output_mean (Tensor): Output mean. Shape [B, T, D]
            output_var (Tensor): Output element-wise variance. Shape [B, T, D]
        """

        with torch.no_grad():

            output_mean, attention_score = self.Attention.forward(x_mean, True)

            n_h = self.Attention.n_head
            B, T, D = x_mean.size()
            D_v = D // n_h

            W_v = self.Attention.c_attn_v.weight.data
            project_W = self.Attention.c_proj.weight.data

            if self.determinstic:
                v_cov = forward_value_cov_determinstic_W(W_v, x_var, n_h, D_v)
            else:
                v_cov = forward_value_cov_Bayesian_W(W_v, self.W_v_var.reshape(D, D), x_mean, x_var, n_h, D_v, self.diag_cov)

            QKV_cov = forward_QKV_cov(attention_score, v_cov, self.diag_cov)
            output_var = forward_fuse_multi_head_cov(QKV_cov, project_W, self.diag_cov)

            return output_mean, output_var
Functions
forward(x_mean, x_var)

Forward pass with uncertainty propagation through a SUQ Attention layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x_mean | Tensor | Input mean. Shape: [B, T, D] | required |
| x_var | Tensor | Input element-wise variance. Shape: [B, T, D] | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| output_mean | Tensor | Output mean. Shape: [B, T, D] |
| output_var | Tensor | Output element-wise variance. Shape: [B, T, D] |

Source code in suq/diag_suq_transformer.py
def forward(self, x_mean, x_var):
    """
    Forward pass with uncertainty propagation through a SUQ Attention layer.

    Args:
        x_mean (Tensor): Input mean. Shape [B, T, D]
        x_var (Tensor): Input element-wise variance. Shape [B, T, D]

    Returns:
        output_mean (Tensor): Output mean. Shape [B, T, D]
        output_var (Tensor): Output element-wise variance. Shape [B, T, D]
    """

    with torch.no_grad():

        output_mean, attention_score = self.Attention.forward(x_mean, True)

        n_h = self.Attention.n_head
        B, T, D = x_mean.size()
        D_v = D // n_h

        W_v = self.Attention.c_attn_v.weight.data
        project_W = self.Attention.c_proj.weight.data

        if self.determinstic:
            v_cov = forward_value_cov_determinstic_W(W_v, x_var, n_h, D_v)
        else:
            v_cov = forward_value_cov_Bayesian_W(W_v, self.W_v_var.reshape(D, D), x_mean, x_var, n_h, D_v, self.diag_cov)

        QKV_cov = forward_QKV_cov(attention_score, v_cov, self.diag_cov)
        output_var = forward_fuse_multi_head_cov(QKV_cov, project_W, self.diag_cov)

        return output_mean, output_var

SUQ_Transformer_Block_Diag

Bases: Module

Single transformer block with uncertainty propagation under SUQ.

Wraps LayerNorm, attention, and MLP submodules with uncertainty-aware versions. Used in SUQ_ViT_Diag to form a full transformer stack.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| MLP | Module | Original MLP submodule | required |
| Attention | Module | Original attention submodule | required |
| LN_1 | LayerNorm | Pre-attention layer norm | required |
| LN_2 | LayerNorm | Pre-MLP layer norm | required |
| MLP_determinstic | bool | Whether to treat the MLP as deterministic | required |
| Attn_determinstic | bool | Whether to treat the attention as deterministic | required |
| diag_cov | bool | If True, only compute the diagonal covariance for the value | False |
| w_fc_var | Tensor or None | Posterior variance of the MLP input projection (if Bayesian) | None |
| w_proj_var | Tensor or None | Posterior variance of the MLP output projection (if Bayesian) | None |
| W_v_var | Tensor or None | Posterior variance of the value matrix (if Bayesian) | None |
Source code in suq/diag_suq_transformer.py
class SUQ_Transformer_Block_Diag(nn.Module):
    """
    Single transformer block with uncertainty propagation under SUQ.

    Wraps LayerNorm, attention, and MLP submodules with uncertainty-aware versions.
    Used in `SUQ_ViT_Diag` to form a full transformer stack.

    Args:
        MLP (nn.Module): Original MLP submodule
        Attention (nn.Module): Original attention submodule
        LN_1 (nn.LayerNorm): Pre-attention layer norm
        LN_2 (nn.LayerNorm): Pre-MLP layer norm
        MLP_determinstic (bool): Whether to treat MLP as deterministic
        Attn_determinstic (bool): Whether to treat attention as deterministic
        diag_cov (bool): If True, only compute the diagonal covariance for the value
        w_fc_var (Tensor or None): Posterior variance of MLP input projection (if Bayesian)
        w_proj_var (Tensor or None): Posterior variance of MLP output projection (if Bayesian)
        W_v_var (Tensor or None): Posterior variance of value matrix (if Bayesian)
    """


    def __init__(self, MLP, Attention, LN_1, LN_2, MLP_determinstic, Attn_determinstic, diag_cov = False, w_fc_var = None, w_proj_var = None, W_v_var = None):
        super().__init__()

        self.ln_1 = SUQ_LayerNorm_Diag(LN_1)
        self.ln_2 = SUQ_LayerNorm_Diag(LN_2)
        self.attn = SUQ_Attention_Diag(Attention, Attn_determinstic, diag_cov, W_v_var)
        self.mlp = SUQ_TransformerMLP_Diag(MLP, MLP_determinstic, w_fc_var, w_proj_var)

    def forward(self, x_mean, x_var):
        """
        Forward pass with uncertainty propagation through a SUQ Transformer block.    

        Args:
            x_mean (Tensor): Input mean. Shape [B, T, D]
            x_var (Tensor): Input element-wise variance. Shape [B, T, D]

        Returns:
            h_mean (Tensor): Output mean. Shape [B, T, D]
            h_var (Tensor): Output element-wise variance. Shape [B, T, D]
        """

        h_mean, h_var = self.ln_1(x_mean, x_var)
        h_mean, h_var = self.attn(h_mean, h_var)
        h_mean = h_mean + x_mean
        h_var = h_var + x_var

        old_h_mean, old_h_var = h_mean, h_var

        h_mean, h_var = self.ln_2(h_mean, h_var)
        h_mean, h_var = self.mlp(h_mean, h_var)
        h_mean = h_mean + old_h_mean
        h_var = h_var + old_h_var

        return h_mean, h_var
Functions
forward(x_mean, x_var)

Forward pass with uncertainty propagation through a SUQ Transformer block.

Parameters:

Name Type Description Default
x_mean Tensor

Input mean. Shape [B, T, D]

required
x_var Tensor

Input element-wise variance. Shape [B, T, D]

required

Returns:

Name Type Description
h_mean Tensor

Output mean. Shape [B, T, D]

h_var Tensor

Output element-wise variance. Shape [B, T, D]

Source code in suq/diag_suq_transformer.py
def forward(self, x_mean, x_var):
    """
    Forward pass with uncertainty propagation through a SUQ Transformer block.    

    Args:
        x_mean (Tensor): Input mean. Shape [B, T, D]
        x_var (Tensor): Input element-wise variance. Shape [B, T, D]

    Returns:
        h_mean (Tensor): Output mean. Shape [B, T, D]
        h_var (Tensor): Output element-wise variance. Shape [B, T, D]
    """

    h_mean, h_var = self.ln_1(x_mean, x_var)
    h_mean, h_var = self.attn(h_mean, h_var)
    h_mean = h_mean + x_mean
    h_var = h_var + x_var

    old_h_mean, old_h_var = h_mean, h_var

    h_mean, h_var = self.ln_2(h_mean, h_var)
    h_mean, h_var = self.mlp(h_mean, h_var)
    h_mean = h_mean + old_h_mean
    h_var = h_var + old_h_var

    return h_mean, h_var
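
As a usage sketch (not taken from the library docs; the `vit` layout follows `examples/vit_model.py`, the import path is assumed, and the posterior variances are made-up placeholders), one block can be wrapped so that its MLP is Bayesian while the attention stays deterministic:

import torch
from suq.diag_suq_transformer import SUQ_Transformer_Block_Diag  # assumed import path

blk = vit.transformer.h[-1]                      # original transformer block (assumed layout)
n_fc = blk.mlp.c_fc.weight.numel()
n_proj = blk.mlp.c_proj.weight.numel()

# hypothetical flattened diagonal posterior variances for the two MLP weight matrices
w_fc_var = torch.full((n_fc,), 1e-3)
w_proj_var = torch.full((n_proj,), 1e-3)

suq_block = SUQ_Transformer_Block_Diag(
    blk.mlp, blk.attn, blk.ln_1, blk.ln_2,
    MLP_determinstic=False, Attn_determinstic=True,
    diag_cov=False, w_fc_var=w_fc_var, w_proj_var=w_proj_var, W_v_var=None,
)

D = blk.mlp.c_fc.weight.shape[1]                 # embedding dimension
x_mean = torch.randn(2, 16, D)
x_var = torch.zeros_like(x_mean)                 # no input uncertainty at the first block
h_mean, h_var = suq_block(x_mean, x_var)         # both [B, T, D]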

SUQ_ViT_Diag

Bases: SUQ_Base

Vision Transformer model with uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

Wraps a ViT architecture into a structured uncertainty-aware model by replacing parts of the network with SUQ-compatible blocks. Allows selective Bayesian treatment of MLP and attention modules within each transformer block.

Currently supports classification only. See the SUQ paper for theoretical background and assumptions.

Parameters:

Name Type Description Default
ViT Module

A Vision Transformer model structured like examples/vit_model.py

required
posterior_variance Tensor

Flattened posterior variance vector

required
MLP_determinstic bool

Whether MLP submodules are treated as deterministic

required
Attn_determinstic bool

Whether attention submodules are treated as deterministic

required
scale_init float

Initial value for the scale factor

1.0
attention_diag_cov bool

If True, only compute the diagonal covariance for the value

False
likelihood str

Currently only supports 'classification'

'classification'
num_det_blocks int

Number of transformer blocks to leave deterministic (from the bottom up)

10
Source code in suq/diag_suq_transformer.py
class SUQ_ViT_Diag(SUQ_Base):
    """
    Vision Transformer model with uncertainty propagation under SUQ, with a diagonal Gaussian posterior.

    Wraps a ViT architecture into a structured uncertainty-aware model by replacing parts
    of the network with SUQ-compatible blocks. Allows selective Bayesian treatment of MLP
    and attention modules within each transformer block.

    Currently supports classification only. See the SUQ paper for theoretical background and assumptions.

    Args:
        ViT (nn.Module): A Vision Transformer model structured like `examples/vit_model.py`
        posterior_variance (Tensor): Flattened posterior variance vector
        MLP_determinstic (bool): Whether MLP submodules are treated as deterministic
        Attn_determinstic (bool): Whether attention submodules are treated as deterministic
        scale_init (float, optional): Initial value for the scale factor
        attention_diag_cov (bool): If True, only compute the diagonal covariance for the value
        likelihood (str): Currently only supports 'classification'
        num_det_blocks (int): Number of transformer blocks to leave deterministic (from the bottom up)
    """

    def __init__(self, ViT, posterior_variance, MLP_determinstic, Attn_determinstic, scale_init = 1.0, attention_diag_cov = False, likelihood = 'classification', num_det_blocks = 10):
        super().__init__(likelihood, scale_init)

        if likelihood not in ['classification']:
            raise ValueError(f"{likelihood} not supported for ViT")


        self.transformer = nn.ModuleDict(dict(
            pte = ViT.transformer.pte,
            h = nn.ModuleList(),
            ln_f = SUQ_LayerNorm_Diag(ViT.transformer.ln_f)
        ))

        self.scale_factor = nn.Parameter(torch.Tensor([scale_init]))

        num_param_c_fc = ViT.transformer.h[0].mlp.c_fc.weight.numel()
        num_param_c_proj = ViT.transformer.h[0].mlp.c_proj.weight.numel()
        num_param_value_matrix = ViT.transformer.h[0].attn.c_proj.weight.numel()

        index = 0
        for block_index in range(len(ViT.transformer.h)):

            if block_index < num_det_blocks:
                self.transformer.h.append(ViT.transformer.h[block_index])
            else:
                if not MLP_determinstic:
                    w_fc_var = posterior_variance[index: index + num_param_c_fc]
                    index += num_param_c_fc
                    w_proj_var = posterior_variance[index: index + num_param_c_proj]
                    index += num_param_c_proj
                    self.transformer.h.append(
                        SUQ_Transformer_Block_Diag(ViT.transformer.h[block_index].mlp, 
                                                    ViT.transformer.h[block_index].attn, 
                                                    ViT.transformer.h[block_index].ln_1, 
                                                    ViT.transformer.h[block_index].ln_2, 
                                                    MLP_determinstic,
                                                    Attn_determinstic,
                                                    attention_diag_cov,
                                                    w_fc_var, 
                                                    w_proj_var,
                                                    None))

                if not Attn_determinstic:
                    w_v_var = posterior_variance[index : index + num_param_value_matrix]
                    index += num_param_value_matrix
                    self.transformer.h.append(
                        SUQ_Transformer_Block_Diag(ViT.transformer.h[block_index].mlp, 
                                                    ViT.transformer.h[block_index].attn, 
                                                    ViT.transformer.h[block_index].ln_1, 
                                                    ViT.transformer.h[block_index].ln_2, 
                                                    MLP_determinstic,
                                                    Attn_determinstic,
                                                    attention_diag_cov,
                                                    None, 
                                                    None,
                                                    w_v_var))

        num_param_classifier_weight = ViT.classifier.weight.numel()
        self.classifier = SUQ_Classifier_Diag(ViT.classifier, posterior_variance[index: index + num_param_classifier_weight], posterior_variance[index + num_param_classifier_weight:])

    def forward_latent(self, pixel_values, interpolate_pos_encoding = None):

        """
        Compute the predictive mean and variance of the ViT's latent output before applying the final likelihood layer.

        Traverses the full transformer stack with uncertainty propagation.

        Args:
            pixel_values (Tensor): Input image tensor, shape [B, C, H, W]
            interpolate_pos_encoding (optional): Optional positional embedding interpolation

        Returns:
            x_mean (Tensor): Predicted latent mean at the [CLS] token, shape [B, D]
            x_var (Tensor): Predicted latent variance at the [CLS] token, shape [B, D]
        """

        device = pixel_values.device

        x_mean = self.transformer.pte(
            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
        )

        # pass through model
        x_var = torch.zeros_like(x_mean, device = device)

        for i, block in enumerate(self.transformer.h):

            if isinstance(block, SUQ_Transformer_Block_Diag):

                x_mean, x_var = block(x_mean, x_var)
            else:
                x_mean = block(x_mean)

        x_mean, x_var = self.transformer.ln_f(x_mean, x_var)

        x_mean, x_var = self.classifier(x_mean[:, 0, :], x_var[:, 0, :])
        x_var = x_var / self.scale_factor.to(device)

        return x_mean, x_var

    def forward(self, pixel_values, interpolate_pos_encoding = None):
        """
        Compute predictive class probabilities using a probit approximation.

        Performs a full forward pass through the ViT with uncertainty propagation, and
        produces softmax-normalized class probabilities for classification.

        Args:
            pixel_values (Tensor): Input image tensor, shape [B, C, H, W]
            interpolate_pos_encoding (optional): Optional positional embedding interpolation

        Returns:
            Tensor: Predicted class probabilities, shape [B, num_classes]
        """

        x_mean, x_var = self.forward_latent(pixel_values, interpolate_pos_encoding)
        kappa = 1 / torch.sqrt(1. + np.pi / 8 * x_var)

        return torch.softmax(kappa * x_mean, dim=-1)
Functions
forward_latent(pixel_values, interpolate_pos_encoding=None)

Compute the predictive mean and variance of the ViT's latent output before applying the final likelihood layer.

Traverses the full transformer stack with uncertainty propagation.

Parameters:

Name Type Description Default
pixel_values Tensor

Input image tensor, shape [B, C, H, W]

required
interpolate_pos_encoding optional

Optional positional embedding interpolation

None

Returns:

Name Type Description
x_mean Tensor

Predicted latent mean at the [CLS] token, shape [B, D]

x_var Tensor

Predicted latent variance at the [CLS] token, shape [B, D]

Source code in suq/diag_suq_transformer.py
def forward_latent(self, pixel_values, interpolate_pos_encoding = None):

    """
    Compute the predictive mean and variance of the ViT's latent output before applying the final likelihood layer.

    Traverses the full transformer stack with uncertainty propagation.

    Args:
        pixel_values (Tensor): Input image tensor, shape [B, C, H, W]
        interpolate_pos_encoding (optional): Optional positional embedding interpolation

    Returns:
        x_mean (Tensor): Predicted latent mean at the [CLS] token, shape [B, D]
        x_var (Tensor): Predicted latent variance at the [CLS] token, shape [B, D]
    """

    device = pixel_values.device

    x_mean = self.transformer.pte(
        pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
    )

    # pass through model
    x_var = torch.zeros_like(x_mean, device = device)

    for i, block in enumerate(self.transformer.h):

        if isinstance(block, SUQ_Transformer_Block_Diag):

            x_mean, x_var = block(x_mean, x_var)
        else:
            x_mean = block(x_mean)

    x_mean, x_var = self.transformer.ln_f(x_mean, x_var)

    x_mean, x_var = self.classifier(x_mean[:, 0, :], x_var[:, 0, :])
    x_var = x_var / self.scale_factor.to(device)

    return x_mean, x_var
forward(pixel_values, interpolate_pos_encoding=None)

Compute predictive class probabilities using a probit approximation.

Performs a full forward pass through the ViT with uncertainty propagation, and produces softmax-normalized class probabilities for classification.

Parameters:

Name Type Description Default
pixel_values Tensor

Input image tensor, shape [B, C, H, W]

required
interpolate_pos_encoding optional

Optional positional embedding interpolation

None

Returns:

Name Type Description
Tensor

Predicted class probabilities, shape [B, num_classes]

Source code in suq/diag_suq_transformer.py
def forward(self, pixel_values, interpolate_pos_encoding = None):
    """
    Compute predictive class probabilities using a probit approximation.

    Performs a full forward pass through the ViT with uncertainty propagation, and
    produces softmax-normalized class probabilities for classification.

    Args:
        pixel_values (Tensor): Input image tensor, shape [B, C, H, W]
        interpolate_pos_encoding (optional): Optional positional embedding interpolation

    Returns:
        Tensor: Predicted class probabilities, shape [B, num_classes]
    """

    x_mean, x_var = self.forward_latent(pixel_values, interpolate_pos_encoding)
    kappa = 1 / torch.sqrt(1. + np.pi / 8 * x_var)

    return torch.softmax(kappa * x_mean, dim=-1)
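
A minimal end-to-end sketch, not a verbatim recipe from the library: the `vit` definition, the `posterior_variance` vector, `pixel_values`, `val_loader`, and the import path are assumptions. The variance vector would typically come from a diagonal Laplace or variational posterior over the modules chosen to be Bayesian.

import torch
from suq.diag_suq_transformer import SUQ_ViT_Diag  # assumed import path

# `vit` is a model structured like examples/vit_model.py; `posterior_variance`
# is the flattened diagonal posterior variance over the Bayesian parameters.
suq_vit = SUQ_ViT_Diag(
    vit, posterior_variance,
    MLP_determinstic=False, Attn_determinstic=True,
    scale_init=1.0, likelihood='classification', num_det_blocks=10,
)

# optionally calibrate the predictive variance scale on held-out data
# suq_vit.fit_scale_factor(val_loader, n_epoches=3, lr=1e-2)

probs = suq_vit(pixel_values)        # [B, num_classes], probit-softmax probabilities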

Functions

forward_linear_diag_Bayesian_weight(e_mean, e_var, w_mean, w_var, bias=None)

Compute the mean and element-wise variance of h = e @ W^T + b when e ~ N(e_mean, e_var) and W ~ N(w_mean, w_var).

Note
  • Only the weight is Bayesian; the bias is treated deterministically.
  • We always assume the input to the next layer has diagonal covariance, so we only compute the element-wise variance of h here.

Parameters:

Name Type Description Default
e_mean Tensor

Mean of the input embeddings e. Shape: [B, T, D_in]

required
e_var Tensor

Element-wise variance of the input embeddings e. Shape: [B, T, D_in]

required
w_mean Tensor

Mean of the weights W. Shape: [D_out, D_in]

required
w_var Tensor

Element-wise variance of the weights W. Shape: [D_out, D_in]

required
bias Tensor

Bias term b. Shape: [D_out,]

None

Returns:

Name Type Description
h_mean Tensor

Mean of the output h. Shape: [B, T, D_out]

h_var Tensor

Element-wise variance of the output h. Shape: [B, T, D_out]

Source code in suq/diag_suq_transformer.py
def forward_linear_diag_Bayesian_weight(e_mean, e_var, w_mean, w_var, bias = None):
    """
    Compute the mean and element-wise variance of `h = e @ W^T + b` when `e ~ N(e_mean, e_var)` and `W ~ N(w_mean, w_var)`.

    Note:
        - Only the weight is Bayesian; the bias is treated deterministically.
        - We always assume the input to the next layer has diagonal covariance, so we only compute the element-wise variance of `h` here.

    Args:
        e_mean (Tensor): Mean of the input embeddings `e`. Shape: `[B, T, D_in]`
        e_var (Tensor): Element-wise variance of the input embeddings `e`. Shape: `[B, T, D_in]`
        w_mean (Tensor): Mean of the weights `W`. Shape: `[D_out, D_in]`
        w_var (Tensor): Element-wise variance of the weights `W`. Shape: `[D_out, D_in]`
        bias (Tensor, optional): Bias term `b`. Shape: `[D_out,]`

    Returns:
        h_mean (Tensor): Mean of the output `h`. Shape: `[B, T, D_out]`
        h_var (Tensor): Element-wise variance of the output `h`. Shape: `[B, T, D_out]`
    """

    # calculate mean(h)
    h_mean = F.linear(e_mean, w_mean, bias)

    # calculate var(h)
    weight_mean2_var_sum = w_mean ** 2 + w_var # [D_out, D_in]
    h_var = e_mean **2 @ w_var.T + e_var @ weight_mean2_var_sum.T

    return h_mean, h_var
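
As a sanity check (self-contained, with made-up shapes; the import path is an assumption), the closed-form moments above can be compared against Monte Carlo sampling of `e` and `W`:

import torch
from suq.diag_suq_transformer import forward_linear_diag_Bayesian_weight  # assumed import path

torch.manual_seed(0)
B, T, D_in, D_out = 1, 1, 8, 4
e_mean, e_var = torch.randn(B, T, D_in), torch.rand(B, T, D_in)
w_mean, w_var = torch.randn(D_out, D_in), torch.rand(D_out, D_in) * 0.1

h_mean, h_var = forward_linear_diag_Bayesian_weight(e_mean, e_var, w_mean, w_var)

# Monte Carlo estimate of the same moments
S = 200_000
e = e_mean + e_var.sqrt() * torch.randn(S, B, T, D_in)
w = w_mean + w_var.sqrt() * torch.randn(S, D_out, D_in)
h = torch.einsum('sbti,soi->sbto', e, w)             # [S, B, T, D_out]

print(torch.allclose(h.mean(0), h_mean, atol=5e-2))  # True up to MC noise
print(torch.allclose(h.var(0), h_var, rtol=5e-2))    # True up to MC noise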

forward_linear_diag_determinstic_weight(e_mean, e_var, weight, bias=None)

Compute the mean and element-wise variance of h = e @ W^T + b when e ~ N(e_mean, e_var) and both W and b are deterministic.

Note
  • We always assume the input to the next layer has diagonal covariance, so we only compute the element-wise variance of h here.

Parameters:

Name Type Description Default
e_mean Tensor

Mean of the input embeddings e. Shape: [B, T, D_in].

required
e_var Tensor

Element-wise variance of the input embeddings e. Shape: [B, T, D_in].

required
weight Tensor

Weights W. Shape: [D_out, D_in].

required
bias Tensor

Bias term b. Shape: [D_out,].

None

Returns:

Name Type Description
h_mean Tensor

Mean of the output h. Shape: [B, T, D_out]

h_var Tensor

Element-wise variance of the output h. Shape: [B, T, D_out]

Source code in suq/diag_suq_transformer.py
def forward_linear_diag_determinstic_weight(e_mean, e_var, weight, bias = None):
    """
    Compute the mean and element-wise variance of `h = e @ W^T + b` when `e ~ N(e_mean, e_var)` and both `W` and `b` are deterministic.

    Note:
        - We always assume the input to the next layer has diagonal covariance, so we only compute the element-wise variance of `h` here.

    Args:
        e_mean (Tensor): Mean of the input embeddings `e`. Shape: `[B, T, D_in]`.
        e_var (Tensor): Element-wise variance of the input embeddings `e`. Shape: `[B, T, D_in]`.
        weight (Tensor): Weights `W`. Shape: `[D_out, D_in]`.
        bias (Tensor, optional): Bias term `b`. Shape: `[D_out,]`.

    Returns:
        h_mean (Tensor): Mean of the output `h`. Shape: `[B, T, D_out]`
        h_var (Tensor): Element-wise variance of the output `h`. Shape: `[B, T, D_out]`
    """

    h_mean = F.linear(e_mean, weight, bias)
    h_var = F.linear(e_var, weight ** 2, None)

    return h_mean, h_var
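
A quick consistency sketch (shapes made up, import path assumed): setting `w_var` to zero in the Bayesian rule above recovers this deterministic rule.

import torch
from suq.diag_suq_transformer import (  # assumed import path
    forward_linear_diag_Bayesian_weight,
    forward_linear_diag_determinstic_weight,
)

e_mean, e_var = torch.randn(2, 3, 8), torch.rand(2, 3, 8)
W, b = torch.randn(4, 8), torch.randn(4)

m_bayes, v_bayes = forward_linear_diag_Bayesian_weight(e_mean, e_var, W, torch.zeros_like(W), bias=b)
m_det, v_det = forward_linear_diag_determinstic_weight(e_mean, e_var, W, bias=b)

assert torch.allclose(m_bayes, m_det)
assert torch.allclose(v_bayes, v_det, atol=1e-5)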

forward_activation_diag(activation_func, h_mean, h_var)

Approximate the distribution of a = g(h) given h ~ N(h_mean, h_var), where h_var is the element-wise variance of the pre-activation h. Uses a first-order Taylor (delta-method) expansion, which for a diagonal h_var gives a ≈ N(g(h_mean), g'(h_mean)^2 * h_var) element-wise.

Parameters:

Name Type Description Default
activation_func Callable

A PyTorch activation function g(·) (e.g. nn.ReLU())

required
h_mean Tensor

Mean of the pre-activations h. Shape: [B, T, D]

required
h_var Tensor

Element-wise variance of the pre-activations h. Shape: [B, T, D]

required

Returns:

Name Type Description
a_mean Tensor

Mean of the activations a. Shape: [B, T, D]

a_var Tensor

Element-wise variance of the activations a. Shape: [B, T, D]

Source code in suq/diag_suq_transformer.py
@torch.enable_grad()
def forward_activation_diag(activation_func, h_mean, h_var):
    """
    Approximate the distribution of `a = g(h)` given `h ~ N(h_mean, h_var)`, where `h_var`
    is the element-wise variance of the pre-activation `h`.
    Uses a first-order Taylor (delta-method) expansion, which for a diagonal `h_var` gives `a ≈ N(g(h_mean), g'(h_mean)**2 * h_var)` element-wise.

    Args:
        activation_func (Callable): A PyTorch activation function `g(·)` (e.g. `nn.ReLU()`)
        h_mean (Tensor): Mean of the pre-activations `h`. Shape: `[B, T, D]`
        h_var (Tensor): Element-wise variance of the pre-activations `h`. Shape: `[B, T, D]`

    Returns:
        a_mean (Tensor): Mean of the activations `a`. Shape: `[B, T, D]`
        a_var (Tensor): Element-wise variance of the activations `a`. Shape: `[B, T, D]`
    """

    h_mean_grad = h_mean.detach().clone().requires_grad_()

    a_mean = activation_func(h_mean_grad)
    a_mean.retain_grad()
    a_mean.backward(torch.ones_like(a_mean)) #[B, T, D]

    nabla = h_mean_grad.grad #[B, T, D]
    a_var = nabla ** 2 * h_var

    return a_mean.detach(), a_var
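
A small self-contained sketch (import path assumed) of the delta method this helper implements: for g = tanh the propagated variance is (1 - tanh(h_mean)^2)^2 * h_var.

import torch
import torch.nn as nn
from suq.diag_suq_transformer import forward_activation_diag  # assumed import path

h_mean, h_var = torch.randn(2, 4, 8), torch.rand(2, 4, 8)
a_mean, a_var = forward_activation_diag(nn.Tanh(), h_mean, h_var)

# the autograd-computed derivative matches the analytic tanh derivative
expected_var = (1 - torch.tanh(h_mean) ** 2) ** 2 * h_var
assert torch.allclose(a_mean, torch.tanh(h_mean))
assert torch.allclose(a_var, expected_var, atol=1e-6)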

forward_layer_norm_diag(e_mean, e_var, ln_weight, ln_eps)

Compute the output variance when a distribution e ~ N(e_mean, e_var) is passed through a LayerNorm layer.

Parameters:

Name Type Description Default
e_mean Tensor

Mean of the input distribution. Shape: [B, T, D]

required
e_var Tensor

Element-wise variance of the input distribution. Shape: [B, T, D]

required
ln_weight Tensor

LayerNorm scale factor (gamma). Shape: [D,]

required
ln_eps float

Small constant added for numerical stability.

required

Returns:

Name Type Description
output_var Tensor

Element-wise variance after LayerNorm. Shape: [B, T, D]

Source code in suq/diag_suq_transformer.py
def forward_layer_norm_diag(e_mean, e_var, ln_weight, ln_eps):
    """
    Compute the output variance when a distribution `e ~ N(e_mean, e_var)`
    is passed through a LayerNorm layer.

    Args:
        e_mean (Tensor): Mean of the input distribution. Shape: `[B, T, D]`
        e_var (Tensor): Element-wise variance of the input distribution. Shape: `[B, T, D]`
        ln_weight (Tensor): LayerNorm scale factor (gamma). Shape: `[D,]`
        ln_eps (float): Small constant added for numerical stability.

    Returns:
        output_var (Tensor): Element-wise variance after LayerNorm. Shape: `[B, T, D]`
    """

    # calculate the var
    input_mean_var = e_mean.var(dim=-1, keepdim=True, unbiased=False) # [B, T, 1]
    scale_factor = (1 / (input_mean_var + ln_eps)) * ln_weight **2 # [B, T, D]
    output_var = scale_factor * e_var # [B, T, D]

    return output_var
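
A minimal usage sketch (shapes made up, import path assumed): this helper returns only the propagated variance, while the mean comes from the ordinary LayerNorm forward; the SUQ_LayerNorm_Diag wrapper presumably combines the two steps.

import torch
import torch.nn as nn
from suq.diag_suq_transformer import forward_layer_norm_diag  # assumed import path

ln = nn.LayerNorm(8)
e_mean, e_var = torch.randn(2, 4, 8), torch.rand(2, 4, 8)

out_mean = ln(e_mean)                                                # [2, 4, 8]
out_var = forward_layer_norm_diag(e_mean, e_var, ln.weight, ln.eps)  # [2, 4, 8]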

forward_value_cov_Bayesian_W(W_v, W_v_var, e_mean, e_var, n_h, D_v, diag_cov=False)

Given a Bayesian value weight W_v ~ N(W_v, W_v_var) and input E ~ N(e_mean, e_var), compute the covariance of the output v = W_v @ E.

Parameters:

Name Type Description Default
n_h int

Number of attention heads.

required
D_v int

Dimension per head. Must satisfy n_h * D_v = D

required
W_v Tensor

Mean of the value weight W_v. Shape: [D, D]

required
W_v_var Tensor

Element-wise variance of the value weight W_v. Shape: [D, D]

required
e_mean Tensor

Mean of the input embeddings e. Shape: [B, T, D]

required
e_var Tensor

Element-wise variance of the input embeddings e. Shape: [B, T, D]

required
diag_cov bool

If True, only compute and return diagonal of the output covariance.

False

Returns:

Name Type Description
v_var Tensor

Returned if diag_cov=True. Element-wise variance of the output v. Shape: [B, T, n_h, D_v]

v_cov Tensor

Returned if diag_cov=False. Full covariance matrices of the output v. Shape: [B, T, n_h, D_v, D_v]

Source code in suq/diag_suq_transformer.py
def forward_value_cov_Bayesian_W(W_v, W_v_var, e_mean, e_var, n_h, D_v, diag_cov = False):
    """
    Given a Bayesian value weight W_v ~ N(W_v, W_v_var) and input E ~ N(e_mean, e_var),
    compute the covariance of the output `v = W_v @ E`.

    Args:
        n_h (int): Number of attention heads.
        D_v (int): Dimension per head. Must satisfy `n_h * D_v = D`
        W_v (Tensor): Mean of the value weight `W_v`. Shape: `[D, D]`
        W_v_var (Tensor): Element-wise variance of the value weight `W_v`. Shape: `[D, D]`
        e_mean (Tensor): Mean of the input embeddings `e`. Shape: `[B, T, D]`
        e_var (Tensor): Element-wise variance of the input embeddings `e`. Shape: `[B, T, D]`
        diag_cov (bool): If `True`, only compute and return diagonal of the output covariance.

    Returns:
        v_var (Tensor): Returned if `diag_cov=True`. Element-wise variance of the output `v`.  Shape: `[B, T, n_h, D_v]`
        v_cov (Tensor): Returned if `diag_cov=False`. Full covariance matrices of the output `v`.  Shape: `[B, T, n_h, D_v, D_v]`
    """

    B, T, D = e_var.size()

    if not diag_cov:
        ## compute general covariance 
        W_v_reshaped = W_v.reshape(1, 1, n_h, D_v, D) 
            # [D, D] -> [1, 1, n_h, D_v, D]
        input_var_reshaped = e_var.reshape(B, T, 1, 1, D)
            # [B, T, D] -> [B, T, 1, 1, D]
        v_cov = (W_v_reshaped * input_var_reshaped).transpose(3, 4)
            # [1, 1, n_h, D_v, D] * [B, T, 1, 1, D] -> [B, T, n_h, D_v, D] -> [B, T, n_h, D, D_v]
        v_cov = torch.matmul(W_v_reshaped, v_cov)
            #  [1, 1, n_h, D_v, D] @ [B, T, n_h, D, D_v]  -> [B, T, n_h, D_v, D_v]

        ## add missing part for variance
        W_v_var_reshaped = W_v_var.reshape(1, 1, n_h, D_v, D) 
            #[D, D] -> [1, 1, n_h, D_v, D]
        input_var_plus_mean_square = input_var_reshaped + e_mean.reshape(B, T, 1, 1, D)**2 #[B, T, 1, 1, D]
        extra_var_term = torch.sum(input_var_plus_mean_square * W_v_var_reshaped, dim=[4]) # [B, T, n_h, D_v, D] -> [B, T, n_h, D_v]
        v_cov = v_cov + torch.diag_embed(extra_var_term) 

        return v_cov

    else:
        weight_mean2_var_sum = W_v **2 + W_v_var # [D, D]
        v_var = e_mean **2 @ W_v_var.T + e_var @ weight_mean2_var_sum.T # [B, T, D]

        return v_var.reshape(B, T, n_h, D_v)

forward_value_cov_determinstic_W(W_v, e_var, n_h, D_v, diag_cov=False)

Given a deterministic value weight W_v and input E ~ N(e_mean, e_var), compute the covariance of the output v = W_v @ E.

Parameters:

Name Type Description Default
n_h int

Number of attention heads.

required
D_v int

Dimension per head. Must satisfy n_h * D_v = D

required
W_v Tensor

Value weight W_v. Shape: [D, D]

required
e_var Tensor

Element-wise variance of the input embeddings e. Shape: [B, T, D]

required
diag_cov bool

If True, only compute and return diagonal of the output covariance.

False

Returns:

Name Type Description
v_var Tensor

Returned if diag_cov=True. Element-wise variance of the output v. Shape: [B, T, n_h, D_v]

v_cov Tensor

Returned if diag_cov=False. Full covariance matrices of the output v. Shape: [B, T, n_h, D_v, D_v]

Source code in suq/diag_suq_transformer.py
def forward_value_cov_determinstic_W(W_v, e_var, n_h, D_v, diag_cov = False):
    """
    Given a deterministic value weight W_v and input E ~ N(e_mean, e_var),
    compute the covariance of the output `v = W_v @ E`.

    Args:
        n_h (int): Number of attention heads.
        D_v (int): Dimension per head. Must satisfy `n_h * D_v = D`
        W_v (Tensor): Value weight `W_v`. Shape: `[D, D]`
        e_var (Tensor): Element-wise variance of the input embeddings `e`. Shape: `[B, T, D]`
        diag_cov (bool): If `True`, only compute and return diagonal of the output covariance.

    Returns:
        v_var (Tensor): Returned if `diag_cov=True`. Element-wise variance of the output `v`.  Shape: `[B, T, n_h, D_v]`
        v_cov (Tensor): Returned if `diag_cov=False`. Full covariance matrices of the output `v`.  Shape: `[B, T, n_h, D_v, D_v]`

    """

    B, T, D = e_var.size()

    if not diag_cov:
        W_v_reshaped = W_v.reshape(1, 1, n_h, D_v, D) 
            # [D, D] -> [1, 1, n_h, D_v, D]
        input_var_reshaped = e_var.reshape(B, T, 1, 1, D)
            # [B, T, D] -> [B, T, 1, 1, D]
        v_cov = (W_v_reshaped * input_var_reshaped).transpose(3, 4)
            # [1, 1, n_h, D_v, D] * [B, T, 1, 1, D] -> [B, T, n_h, D_v, D] -> [B, T, n_h, D, D_v]
        v_cov = torch.matmul(W_v_reshaped, v_cov)
            #  [1, 1, n_h, D_v, D] @ [B, T, n_h, D, D_v]  -> [B, T, n_h, D_v, D_v]

        return v_cov

    else:
        v_var = e_var @ (W_v ** 2).T

        return v_var.reshape(B, T, n_h, D_v)
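
A quick consistency sketch (shapes made up, import path assumed): the diagonal of the full per-head covariance should match the diag_cov=True output.

import torch
from suq.diag_suq_transformer import forward_value_cov_determinstic_W  # assumed import path

B, T, n_h, D_v = 2, 5, 4, 8
D = n_h * D_v
W_v, e_var = torch.randn(D, D), torch.rand(B, T, D)

v_cov = forward_value_cov_determinstic_W(W_v, e_var, n_h, D_v, diag_cov=False)  # [B, T, n_h, D_v, D_v]
v_var = forward_value_cov_determinstic_W(W_v, e_var, n_h, D_v, diag_cov=True)   # [B, T, n_h, D_v]

assert torch.allclose(v_cov.diagonal(dim1=-2, dim2=-1), v_var, atol=1e-4)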

forward_QKV_cov(attention_score, v_cov, diag_cov=False)

Given the attention score matrix (QK^T) and V ~ N(v_mean, v_cov), compute the covariance of the output E = (QK^T) V.

Parameters:

Name Type Description Default
attention_score Tensor

Attention weights A = QK^T. Shape: [B, n_h, T, T]

required
v_cov Tensor

Covariance of the value V. Shape: [B, T, n_h, D_v, D_v] if diag_cov=False. [B, T, n_h, D_v] if diag_cov=True

required
diag_cov bool

If True, value V will have diagonal covariance

False

Returns:

Name Type Description
QKV_var Tensor

Returned if diag_cov=True. Element-wise variance of the output E. Shape: [B, n_h, T, D_v]

QKV_cov Tensor

Returned if diag_cov=False. Full covariance matrices of the output E. Shape: [B, n_h, T, D_v, D_v]

Source code in suq/diag_suq_transformer.py
def forward_QKV_cov(attention_score, v_cov, diag_cov = False):
    """
    Given the attention score matrix (QK^T) and `V ~ N(v_mean, v_cov)`,
    compute the covariance of the output `E = (QK^T) V`.

    Args:
        attention_score (Tensor): Attention weights `A = QK^T`. Shape: `[B, n_h, T, T]`
        v_cov (Tensor): Covariance of the value `V`.  Shape: `[B, T, n_h, D_v, D_v]` if `diag_cov=False`. `[B, T, n_h, D_v]` if `diag_cov=True`
        diag_cov (bool): If `True`, value `V` will have diagonal covariance

    Returns:
        QKV_var (Tensor): Returned if `diag_cov=True`. Element-wise variance of the output `E`. Shape: `[B, n_h, T, D_v]`
        QKV_cov (Tensor): Returned if `diag_cov=False`. Full covariance matrices of the output `E`. Shape: `[B, n_h, T, D_v, D_v]`
    """
    if diag_cov:
        B, T, n_h, D_v = v_cov.size()
        QKV_cov = attention_score **2 @ v_cov.transpose(1, 2) # [B, n_h, T, D_v]
            # v_cov [B, T, n_h, D_v] -> [B, n_h, T, D_v]
            # [B, n_h, T, T] @ [B, n_h, T, D_v]  -> [B, n_h, T, D_v]
    else:

        B, T, n_h, D_v, _ = v_cov.size()

        QKV_cov = attention_score **2 @ v_cov.permute(0, 2, 1, 3, 4).reshape(B, n_h, T, D_v * D_v) # [B, n_h, T, D_v * D_v]
        # v_cov [B, T, n_h, D_v, D_v] -> [B, n_h, T, D_v * D_v]
        # [B, n_h, T, T] @ [B, n_h, T, D_v * D_v]  -> [B, n_h, T, D_v * D_v]
        QKV_cov = QKV_cov.reshape(B, n_h, T, D_v, D_v)

    return QKV_cov

forward_fuse_multi_head_cov(QKV_cov, project_W, diag_cov=False)

Given the concatenated multi-head embedding E ~ N(e_mean, e_cov) and a deterministic projection weight matrix W, compute the variance of each output dimension.

Parameters:

Name Type Description Default
QKV_cov Tensor

Covariance of the concatenated multi-head output E. Shape: [B, n_h, T, D_v, D_v] if diag_cov=False, [B, n_h, T, D_v] if diag_cov=True

required
project_W Tensor

Projection weight matrix W. Shape: [D_out, D_in], where D_in = n_h * D_v

required
diag_cov bool

If True, QKV_cov will have diagonal covariance

False

Returns:

Name Type Description
output_var Tensor

Element-wise variance of the projected output. Shape: [B, T, D_out]

Source code in suq/diag_suq_transformer.py
def forward_fuse_multi_head_cov(QKV_cov, project_W, diag_cov = False):
    """
    Given the concatenated multi-head embedding `E ~ N(e_mean, e_cov)` and a deterministic projection weight matrix `W`,
    compute the variance of each output dimension.

    Args:
        QKV_cov (Tensor): Covariance of the concatenated multi-head output `E`. Shape: `[B, n_h, T, D_v, D_v]` if `diag_cov=False`, `[B, n_h, T, D_v]` if `diag_cov=True`
        project_W (Tensor): Projection weight matrix `W`. Shape: `[D_out, D_in]`, where `D_in = n_h * D_v`
        diag_cov (bool): If `True`, `QKV_cov` will have diagonal covariance

    Returns: 
        output_var (Tensor): Element-wise variance of the projected output. Shape: `[B, T, D_out]`
    """
    if diag_cov:
        B, n_h, T, D_v = QKV_cov.size()
        output_var = QKV_cov.permute(0, 2, 1, 3).reshape(B, T, n_h * D_v) @ project_W.T ** 2
            # QKV_cov [B, n_h, T, D_v] -> [B, T, n_h, D_v] -> [B, T, n_h * D_v]

        return output_var

    else:
        B, n_h, T, D_v, _ = QKV_cov.size()
        D, _ = project_W.shape

        project_W_reshaped_1 = project_W.T.reshape(n_h, D_v, D).permute(0, 2, 1).reshape(n_h * D, D_v, 1)
            # [n_h, D_v, D] -> [n_h, D, D_v] -> [n_h * D, D_v, 1]
        project_W_reshaped_2 = project_W.T.reshape(n_h, D_v, D).permute(0, 2, 1).reshape(n_h * D, 1, D_v)
            # [n_h, D_v, D] -> [n_h, D, D_v] -> [n_h * D, 1, D_v]

        project_W_outer = torch.bmm(project_W_reshaped_1, project_W_reshaped_2).reshape(n_h, D, D_v, D_v).permute(1, 0, 2, 3) # [D, n_h, D_v, D_v]
        # [n_h * D, D_v, 1] @ [n_h * D, 1, D_v] -> [n_h * D, D_v, D_v] -> [D, n_h, D_v, D_v]

        output_var_einsum = torch.einsum('dhij,bthij->dbt', project_W_outer, QKV_cov.permute(0, 2, 1, 3, 4))

        return output_var_einsum.permute(1, 2, 0)
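
Finally, a synthetic sketch (all tensors made up, import path assumed) of how the three attention helpers compose, mirroring the chain used inside SUQ_Attention_Diag.forward:

import torch
from suq.diag_suq_transformer import (  # assumed import path
    forward_value_cov_determinstic_W,
    forward_QKV_cov,
    forward_fuse_multi_head_cov,
)

B, T, n_h, D_v = 2, 5, 4, 8
D = n_h * D_v
x_var = torch.rand(B, T, D)                                 # input element-wise variance
W_v = torch.randn(D, D)                                     # value weight
project_W = torch.randn(D, D)                               # output projection weight
attention_score = torch.softmax(torch.randn(B, n_h, T, T), dim=-1)

v_cov = forward_value_cov_determinstic_W(W_v, x_var, n_h, D_v)   # [B, T, n_h, D_v, D_v]
QKV_cov = forward_QKV_cov(attention_score, v_cov)                # [B, n_h, T, D_v, D_v]
output_var = forward_fuse_multi_head_cov(QKV_cov, project_W)     # [B, T, D]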