layer.RNNLayer
Parameters
-
hidden_layers
(List[nn.Module]):
A list of hidden layers forming the recurrent network. Each layer should be an instance ofnn.Module
. -
readout_layer
(LinearLayer, optional, default:None
):
An optional readout layer to map the final hidden state to the output space. -
device
(str, default:"cpu"
):
The device where the network and its components are initialized ("cpu"
or"cuda"
).
Methods
to(device)
Moves the entire recurrent layer, including its hidden and readout layers, to the specified device.
Parameters:
device
(torch.device):
Target device (e.g.,torch.device("cuda")
or"cpu"
).
Usage:
layer.to(torch.device("cuda"))
_generate_init_state(dim, batch_size, i_val=0)
Generates an initial state tensor with a specified dimension, batch size, and initialization value.
Parameters:
dim
(int):
Dimensionality of the hidden state.batch_size
(int):
Number of batches.i_val
(float, default:0
):
Initialization value for the state.
Returns:
torch.Tensor
: Initialized state tensor of shape(batch_size, dim)
.
Usage:
init_state = layer._generate_init_state(hidden_size, batch_size)
forward(x, init_states=None)
Performs a forward pass through the recurrent layer over a sequence of timesteps.
Parameters:
x
(torch.Tensor):
Input tensor of shape(batch_size, n_timesteps, input_dim)
.init_states
(List[torch.Tensor], optional, default:None
):
A list of initial states for the hidden layers. Each element has shape(batch_size, hidden_size)
for the corresponding hidden layer.
Returns:
output
(torch.Tensor or None):
Output tensor from the readout layer if it exists; otherwise,None
.hidden_states
(List[torch.Tensor]):
A list of tensors representing the hidden states across all timesteps for each hidden layer.
Usage:
output, hidden_states = layer.forward(input_tensor, initial_states)
train()
Sets the layer to training mode by enabling pre-activation and post-activation noise in hidden layers and enforcing constraints.
Usage:
layer.train()
eval()
Sets the layer to evaluation mode by disabling pre-activation and post-activation noise in hidden layers and pausing constraint enforcement.
Usage:
layer.eval()
apply_plasticity()
Applies plasticity masks to the weight gradients of all hidden layers and the readout layer (if it exists).
Usage:
layer.apply_plasticity()
enforce_constraints()
Enforces constraints, such as sparsity or excitatory/inhibitory balance, on all hidden layers and the readout layer (if it exists).
This method is called automatically during training but can be invoked manually if needed.
Usage:
layer.enforce_constraints()
plot_layer(**kwargs)
Plots the weight matrices and distributions of all hidden layers. Accepts additional keyword arguments for customization.
Usage:
layer.plot_layer()
print_layer()
Placeholder for printing the weight matrices and distributions of all hidden layers. Not implemented.
Usage:
layer.print_layer()