criterion.RNNLoss
This document provides a detailed description of the RNNLoss module, its functionality, and the mathematical formulation of each loss component.
Overview
The RNNLoss module computes a composite loss for a CTRNN model. It combines several loss components that constrain the model's behavior and promote specific properties such as sparsity, firing-rate regularization, and output accuracy.
Key Features:
- Modular loss components, each associated with a lambda weight.
- Loss functions for sparsity in input, hidden, and readout layers.
- Firing rate regularization with multiple metrics (mean, standard deviation, coefficient of variation).
- Mean Squared Error (MSE) for output predictions.
Parameters
- model (CTRNN): The recurrent neural network model for which the loss is computed. Must be an instance of CTRNN.
- Keyword Arguments: coefficients for the individual loss components:
  - lambda_mse (float, default: 1): Weight for the Mean Squared Error loss.
  - lambda_in (float, default: 0): Weight for the sparsity loss of the input layer.
  - lambda_hid (float, default: 0): Weight for the sparsity loss of the hidden layer.
  - lambda_out (float, default: 0): Weight for the sparsity loss of the readout layer.
  - lambda_fr (float, default: 0): Weight for the firing rate loss.
  - lambda_fr_sd (float, default: 0): Weight for the standard deviation of firing rates.
  - lambda_fr_cv (float, default: 0): Weight for the coefficient of variation of firing rates.
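A minimal construction sketch. The import paths below follow the page title criterion.RNNLoss but are assumptions, as are the CTRNN constructor arguments; adjust them to match your installation.

```python
from model import CTRNN        # assumed import path
from criterion import RNNLoss  # assumed import path, per the page title

model = CTRNN()  # constructor arguments omitted; any valid CTRNN instance works

# Components with a zero weight contribute nothing to the total loss,
# so only the terms you care about need a nonzero lambda.
rnn_loss = RNNLoss(
    model,
    lambda_mse=1.0,   # output accuracy (default 1)
    lambda_hid=1e-4,  # hidden-layer sparsity
    lambda_fr=1e-3,   # firing-rate magnitude regularization
)
```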
Loss Components
Input Layer Sparsity
Lambda Key: lambda_in
Mathematical Formulation:
\[\mathcal{L}_{\text{in}} = \frac{\lambda_{\text{in}}}{N_{\text{in}} N_{\text{hid}}} ||\mathbf{W}_{\text{in}}||_F^2\]
- $ N_{\text{in}} $: Number of input neurons.
- $ N_{\text{hid}} $: Number of hidden neurons.
- $ \mathbf{W}_{\text{in}} $: Weight matrix of the input layer.
Hidden Layer Sparsity
Lambda Key: lambda_hid
Mathematical Formulation:
\[\mathcal{L}_{\text{hid}} = \frac{\lambda_{\text{hid}}}{N_{\text{hid}}^2} ||\mathbf{W}_{\text{hid}}||_F^2\]
- $ N_{\text{hid}} $: Number of hidden neurons.
- $ \mathbf{W}_{\text{hid}} $: Weight matrix of the hidden layer.
Readout Layer Sparsity
Lambda Key: lambda_out
Mathematical Formulation:
\[\mathcal{L}_{\text{out}} = \frac{\lambda_{\text{out}}}{N_{\text{out}} N_{\text{hid}}} ||\mathbf{W}_{\text{out}}||_F^2\]
- $ N_{\text{out}} $: Number of readout neurons.
- $ N_{\text{hid}} $: Number of hidden neurons.
- $ \mathbf{W}_{\text{out}} $: Weight matrix of the readout layer.
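All three sparsity terms apply the same penalty to different weight matrices: a lambda-weighted squared Frobenius norm, normalized by the number of entries in the matrix, which is just the mean squared weight. A minimal sketch of that computation (the helper name and the stand-in weight shapes are assumptions, not the module's internals):

```python
import torch

def sparsity_loss(weight: torch.Tensor, lam: float) -> torch.Tensor:
    # lam / (rows * cols) * ||W||_F^2  ==  lam * mean(W**2)
    return lam * (weight ** 2).mean()

# Hypothetical weight tensors standing in for the CTRNN's layers.
w_in = torch.randn(100, 10)    # (N_hid, N_in)
w_hid = torch.randn(100, 100)  # (N_hid, N_hid)
w_out = torch.randn(2, 100)    # (N_out, N_hid)

l_in = sparsity_loss(w_in, lam=1e-4)
l_hid = sparsity_loss(w_hid, lam=1e-4)
l_out = sparsity_loss(w_out, lam=1e-4)
```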
Firing Rate
Lambda Key: lambda_fr
Mathematical Formulation:
\[\mathcal{L}_{\text{fr}} = \frac{\lambda_{\text{fr}}}{B \cdot T \cdot N_{\text{hid}}} \sum_{b,t,i=1}^{B,T,N_{\text{hid}}} r_{b,t,i}^2\]
- $ B $: Batch size.
- $ T $: Number of time steps.
- $ N_{\text{hid}} $: Number of hidden neurons.
- $ r_{b,t,i} $: Firing rate of the $ i $-th neuron at time $ t $ in the $ b $-th batch.
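In code, this is simply the mean of the squared firing rates over every batch, time step, and neuron. A sketch, assuming the rates arrive as a single tensor with neurons on the last axis (the shape convention is an assumption):

```python
import torch

def firing_rate_loss(states: torch.Tensor, lambda_fr: float) -> torch.Tensor:
    # lambda_fr / (B * T * N_hid) * sum(r^2)  ==  lambda_fr * mean(r**2)
    return lambda_fr * (states ** 2).mean()
```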
Firing Rate Standard Deviation (SD)
Lambda Key: lambda_fr_sd
Mathematical Formulation:
\[\mathcal{L}_{\text{fr\_sd}} = \lambda_{\text{fr\_sd}} \sqrt{\frac{1}{N_{\text{hid}}} \sum_{i=1}^{N_{\text{hid}}} \left( \bar{r}_i - \mu \right)^2}\]
\[\mu = \frac{1}{B \cdot T \cdot N_{\text{hid}}} \sum_{b,t,i=1}^{B,T,N_{\text{hid}}} r_{b,t,i}\]
- $ \bar{r}_i = \frac{1}{B T} \sum_{b,t=1}^{B,T} r_{b,t,i} $: Mean firing rate of neuron $ i $.
- $ \mu $: Overall mean firing rate.
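A sketch of the SD term under an assumed (T, B, N_hid) layout for the firing rates; the reduction axes are assumptions:

```python
import torch

def firing_rate_sd_loss(states: torch.Tensor, lambda_fr_sd: float) -> torch.Tensor:
    r_bar = states.mean(dim=(0, 1))  # per-neuron mean rate, shape (N_hid,)
    mu = states.mean()               # overall mean rate
    # Population standard deviation of per-neuron means around the overall mean.
    return lambda_fr_sd * torch.sqrt(((r_bar - mu) ** 2).mean())
```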
Firing Rate Coefficient of Variation (CV)
Lambda Key: lambda_fr_cv
Mathematical Formulation:
\[\mathcal{L}_{\text{fr\_cv}} = \lambda_{\text{fr\_cv}} \frac{\sigma}{\mu}\]
- $ \sigma $: Standard deviation of firing rates across neurons.
- $ \mu $: Mean firing rate of the network.
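A sketch of the CV term, again assuming per-neuron mean rates are taken over batch and time:

```python
import torch

def firing_rate_cv_loss(states: torch.Tensor, lambda_fr_cv: float) -> torch.Tensor:
    r_bar = states.mean(dim=(0, 1))                 # per-neuron mean rates
    mu = r_bar.mean()                               # network mean rate
    sigma = torch.sqrt(((r_bar - mu) ** 2).mean())  # SD across neurons
    return lambda_fr_cv * sigma / mu
```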
Mean Squared Error (MSE)
Lambda Key: lambda_mse
Mathematical Formulation:
\[\mathcal{L}_{\text{mse}} = \frac{\lambda_{\text{mse}}}{B \cdot T \cdot N_{\text{out}}} \sum_{b,t,k=1}^{B,T,N_{\text{out}}} \left( \hat{y}_{b,t,k} - y_{b,t,k} \right)^2\]
- $ B $: Batch size.
- $ T $: Number of time steps.
- $ N_{\text{out}} $: Dimension of the outputs.
- $ \hat{y}_{b,t,k} $: Predicted output.
- $ y_{b,t,k} $: Ground truth output.
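Because the normalization runs over every element of the output tensor, this term reduces to PyTorch's mean-reduced MSE scaled by the weight; a sketch:

```python
import torch
import torch.nn.functional as F

def mse_loss(pred: torch.Tensor, label: torch.Tensor, lambda_mse: float) -> torch.Tensor:
    # F.mse_loss with the default reduction="mean" averages over B, T, and
    # N_out, matching the 1 / (B * T * N_out) normalization above.
    return lambda_mse * F.mse_loss(pred, label)
```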
Methods
forward(pred, label, **kwargs)
Computes the total loss and individual component losses.
Parameters:
- pred (torch.Tensor): Predicted outputs of shape $(-1, B, 2)$.
- label (torch.Tensor): Ground truth labels of shape $(-1, B, 2)$.
- kwargs (dict): Additional inputs for specific loss components (e.g., states).
Returns:
- total_loss (torch.Tensor): Sum of all weighted losses.
- loss_components (torch.Tensor): Individual loss components.
Usage:

```python
total_loss, losses = rnn_loss(pred=predicted_outputs, label=ground_truth, states=hidden_states)
```
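In a training loop, the returned total_loss is what you backpropagate. A sketch, assuming the model's forward pass returns both predictions and hidden states (that return signature is an assumption) and that a dataloader yielding (inputs, ground_truth) pairs already exists:

```python
import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for inputs, ground_truth in dataloader:  # dataloader is assumed to exist
    predicted_outputs, hidden_states = model(inputs)  # assumed return signature
    total_loss, losses = rnn_loss(
        pred=predicted_outputs,
        label=ground_truth,
        states=hidden_states,  # required by the firing-rate components
    )
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
```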