mask.MultiIO
Description
MultiIO supports the scenario in which different groups of inputs are projected to different areas of the hidden layer and different groups of readouts are read out from different areas, producing a multi-area RNN. The class primarily serves to generate the input and readout masks that implement this routing.
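The idea of routing groups to areas can be illustrated with a small NumPy sketch that expands group-level tables (the input_table and readout_table parameters described under Parameters) into per-dimension 0/1 masks. This is a hypothetical illustration, not nn4n's implementation: the function name build_masks is invented here, and it assumes that each group occupies a contiguous block of dimensions in the order listed, that the input mask has shape (hidden_size, input_dim), and that the readout mask has shape (readout_dim, hidden_size).

```python
import numpy as np


def build_masks(dims, input_dims, readout_dims, input_table, readout_table):
    """Hypothetical sketch: expand group-level tables into per-dimension 0/1 masks.

    Assumes groups occupy contiguous dimensions in the order listed, an input
    mask of shape (hidden_size, input_dim) and a readout mask of shape
    (readout_dim, hidden_size).
    """
    input_dim, hidden_size, readout_dim = dims
    input_table = np.asarray(input_table)
    readout_table = np.asarray(readout_table)

    input_mask = np.zeros((hidden_size, input_dim))
    start = 0
    for g, size in enumerate(input_dims):
        # Every dimension of input group g reaches exactly the hidden nodes flagged in row g.
        input_mask[:, start:start + size] = input_table[g][:, None]
        start += size

    readout_mask = np.zeros((readout_dim, hidden_size))
    start = 0
    for g, size in enumerate(readout_dims):
        # Every dimension of readout group g reads only from the hidden nodes flagged in row g.
        readout_mask[start:start + size, :] = readout_table[g][None, :]
        start += size

    return input_mask, readout_mask


# Toy configuration: 3 inputs in groups [1, 2], 4 hidden nodes, 2 readouts in groups [1, 1].
in_mask, out_mask = build_masks(
    [3, 4, 2], [1, 2], [1, 1],
    input_table=[[1, 1, 0, 0], [0, 0, 1, 1]],
    readout_table=[[1, 0, 1, 0], [0, 1, 0, 1]],
)
```

Here the single input dimension of group 0 reaches hidden nodes 0 and 1, while both dimensions of group 1 reach nodes 2 and 3; broadcasting each table row across its group's block keeps the sketch free of explicit inner loops.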
Parameters
- input_dims (list, default: [dims[0]]):
  A list giving the dimension of each group of input signals.
  - Example: for a 1-dimensional olfactory signal and a 100-dimensional visual signal, use [1, 100].
  - Must sum to the total input dimension specified in dims[0].
- readout_dims (list, default: [dims[2]]):
  A list giving the dimension of each group of readout signals.
  - Example: for two readout groups of dimension 1 and 100, use [1, 100].
  - Must sum to the total readout dimension specified in dims[2].
- input_table (np.ndarray, default: all ones):
  A table specifying whether each input group projects to each hidden layer node.
  - Shape: (n_input_groups, hidden_size), where n_input_groups = len(input_dims).
  - Values: 0 (no connection) or 1 (connection).
- readout_table (np.ndarray, default: all ones):
  A table specifying whether each hidden layer node contributes to each readout group.
  - Shape: (n_readout_groups, hidden_size), where n_readout_groups = len(readout_dims).
  - Values: 0 (no contribution) or 1 (contribution).
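The defaults and constraints listed above can be summarized in a short validation sketch. It is a hypothetical helper (resolve_multiio_params is not part of nn4n and may not match how the library checks its arguments); it simply fills in the documented defaults and raises ValueError when a group list does not sum to its total or a table has the wrong shape or non-binary values.

```python
import numpy as np


def resolve_multiio_params(dims, input_dims=None, readout_dims=None,
                           input_table=None, readout_table=None):
    """Hypothetical helper: apply the documented defaults and check the constraints."""
    input_dim, hidden_size, readout_dim = dims
    # Default: a single group spanning the whole input (or readout) dimension.
    input_dims = [input_dim] if input_dims is None else list(input_dims)
    readout_dims = [readout_dim] if readout_dims is None else list(readout_dims)
    if sum(input_dims) != input_dim:
        raise ValueError(f"input_dims must sum to dims[0] = {input_dim}")
    if sum(readout_dims) != readout_dim:
        raise ValueError(f"readout_dims must sum to dims[2] = {readout_dim}")
    # Default tables: all ones, i.e. every group connects to every hidden node.
    if input_table is None:
        input_table = np.ones((len(input_dims), hidden_size))
    if readout_table is None:
        readout_table = np.ones((len(readout_dims), hidden_size))
    input_table = np.asarray(input_table)
    readout_table = np.asarray(readout_table)
    for name, table, n_groups in (("input_table", input_table, len(input_dims)),
                                  ("readout_table", readout_table, len(readout_dims))):
        if table.shape != (n_groups, hidden_size):
            raise ValueError(f"{name} must have shape {(n_groups, hidden_size)}")
        if not np.isin(table, (0, 1)).all():
            raise ValueError(f"{name} may only contain 0 and 1")
    return input_dims, readout_dims, input_table, readout_table
```

Calling it with only dims yields one input group, one readout group and all-ones tables of shape (1, hidden_size), which corresponds to a single-area network.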
Methods
get_specs()
Returns the specifications of the network, including:
- "dims"
- "hidden_size"
- "input_dim"
- "readout_dim"
- "input_dims"
- "readout_dims"
- "input_table_shape"
- "readout_table_shape"
Returns:
- dict: Specifications of the network.
Usage:
specs = multi_io.get_specs()
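Judging from the example output below, each key appears to follow directly from the constructor arguments (hidden_size is dims[1], input_dim is dims[0], readout_dim is dims[2]). The sketch below spells out that relationship; specs_from_params is a hypothetical function written for illustration, not nn4n's get_specs() implementation.

```python
import numpy as np


def specs_from_params(dims, input_dims, readout_dims, input_table, readout_table):
    """Hypothetical sketch of the values get_specs() reports, derived from the parameters."""
    return {
        "dims": list(dims),
        "hidden_size": dims[1],   # middle entry of dims
        "input_dim": dims[0],     # total input dimension
        "readout_dim": dims[2],   # total readout dimension
        "input_dims": list(input_dims),
        "readout_dims": list(readout_dims),
        "input_table_shape": np.shape(input_table),
        "readout_table_shape": np.shape(readout_table),
    }


specs = specs_from_params([20, 200, 10], [5, 15], [5, 5],
                          np.ones((2, 200)), np.ones((2, 200)))
```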
Example
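import numpy as np
from nn4n.mask import MultiIO

# Group-level tables: one row per group, one column per hidden node (hidden_size = 200)
input_table = np.ones((2, 200))
readout_table = np.ones((2, 200))
input_table[0, 0:100] = 0      # input group 0 skips hidden nodes 0-99
input_table[1, 100:150] = 0    # input group 1 skips hidden nodes 100-149
readout_table[0, 100:200] = 0  # readout group 0 ignores hidden nodes 100-199

mask_params = {
    "dims": [20, 200, 10],     # [input_dim, hidden_size, readout_dim]
    "input_dims": [5, 15],     # sums to dims[0] = 20
    "readout_dims": [5, 5],    # sums to dims[2] = 10
    "input_table": input_table,
    "readout_table": readout_table,
}
mask = MultiIO(**mask_params)
mask.print_specs()
mask.plot_masks()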
Output:
 
This mask does not change the hidden layer connectivity. Skipped.
 
MultiIO Specs: 
   | dims:               [20, 200, 10]
   | hidden_size:        200
   | input_dim:          20
   | readout_dim:         10
   | input_dims:         [5, 15]
   | readout_dims:        [5, 5]
   | input_table_shape:  (2, 200)
   | readout_table_shape: (2, 200)
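The example's tables are plain NumPy arrays, so their effect can be checked without nn4n. With these settings, input group 0 projects to hidden nodes 100-199, input group 1 projects to nodes 0-99 and 150-199 (so nodes 150-199 receive both groups), readout group 0 reads from nodes 0-99, and readout group 1 reads from all 200 nodes. The snippet below only inspects the tables; it does not call the MultiIO class.

```python
import numpy as np

# Rebuild the example's group-level tables (2 groups x 200 hidden nodes).
input_table = np.ones((2, 200))
readout_table = np.ones((2, 200))
input_table[0, 0:100] = 0
input_table[1, 100:150] = 0
readout_table[0, 100:200] = 0

# How many hidden nodes each input group projects to, and each readout group reads from.
input_fanout = input_table.sum(axis=1).astype(int).tolist()     # [100, 150]
readout_fanin = readout_table.sum(axis=1).astype(int).tolist()  # [100, 200]

# Hidden nodes that receive both input groups (columns where every row is 1).
shared = np.flatnonzero(input_table.all(axis=0))
print(input_fanout, readout_fanin, shared.min(), shared.max(), shared.size)
# prints: [100, 150] [100, 200] 150 199 50
```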