aihwkit.nn.modules.rnn.rnn module
Analog RNN modules.
- class aihwkit.nn.modules.rnn.rnn.AnalogRNN(cell, input_size, hidden_size, bias=True, rpu_config=None, tile_module_class=None, xavier=False, num_layers=1, bidir=False, dropout=0.0)[source]
Bases: AnalogContainerBase, Module
Modular RNN that uses analog tiles.
- Parameters:
cell (Type) – type of Analog RNN cell (AnalogLSTMCell/AnalogGRUCell/AnalogVanillaRNNCell)
input_size (int) – in_features to W_{ih} matrix of first layer
hidden_size (int) – in_features and out_features for W_{hh} matrices
bias (bool) – whether to use a bias row on the analog tile or not
rpu_config (RPUConfigBase | None) – configuration for an analog resistive processing unit. If not given, a native torch model is constructed instead.
tile_module_class (Type | None) – class for the analog tile module (the default is determined by the RPUConfig)
xavier (bool) – whether to use the standard PyTorch LSTM weight initialization (default) or Xavier initialization
num_layers (int) – number of serially connected RNN layers
bidir (bool) – if True, becomes a bidirectional RNN
dropout (float) – dropout applied to output of all RNN layers except last
- forward(input, states=None)[source]
Forward pass.
- Parameters:
input (Tensor) – input tensor
states (List | None) – list of LSTM state tensors
- Returns:
outputs and states
- Return type:
Tuple[Tensor, List]
- get_zero_state(batch_size)[source]
Returns a zeroed RNN state based on the cell type and layer type.
- Parameters:
batch_size (int) – batch size of the input
- Returns:
List of zeroed state tensors for each layer
- Return type:
List[Tensor]
- init_layers(weight_init_fn, bias_init_fn=None)[source]
Init the analog layers with custom functions.
- Parameters:
weight_init_fn (Callable) – in-place tensor function applied to the weights of the AnalogLinear layers
bias_init_fn (Callable | None) – in-place tensor function applied to the biases of the AnalogLinear layers
- Return type:
None
Note
If no bias init function is provided, the weight init function is used for the bias as well.
- class aihwkit.nn.modules.rnn.rnn.ModularRNN(num_layers, layer, dropout, first_layer_args, other_layer_args)[source]
Bases: Module
Helper class to create a modular RNN.
- Parameters:
num_layers (int) – number of serially connected RNN layers
layer (Type) – RNN layer type (e.g. AnalogLSTMLayer)
dropout (float) – dropout applied to output of all RNN layers except last
first_layer_args (Any) – RNNCell type, input_size, hidden_size, rpu_config, etc.
other_layer_args (Any) – RNNCell type, hidden_size, hidden_size, rpu_config, etc.
- forward(input, states)[source]
Forward pass.
- Parameters:
input (Tensor) – input tensor
states (List) – list of LSTM state tensors
- Returns:
outputs and states
- Return type:
Tuple[Tensor, List]
- get_zero_state(batch_size)[source]
Returns a zeroed state.
- Parameters:
batch_size (int) – batch size of the input
- Returns:
List of zeroed state tensors for each layer
- Return type:
List[Tensor]
- static init_stacked_analog_lstm(num_layers, layer, first_layer_args, other_layer_args)[source]
Construct a list of LSTMLayers over which to iterate.
- Parameters:
num_layers (int) – number of serially connected LSTM layers
layer (Type) – RNN layer type (e.g. AnalogLSTMLayer)
first_layer_args (Any) – RNNCell type, input_size, hidden_size, rpu_config, etc.
other_layer_args (Any) – RNNCell type, hidden_size, hidden_size, rpu_config, etc.
- Returns:
a torch.nn.ModuleList, which behaves like a regular Python list but registers its entries as submodules, so torch.nn.Module methods can be applied to them
- Return type:
ModuleList
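The first-layer/other-layer split follows a simple pattern: the first layer consumes input_size features, while every subsequent layer consumes hidden_size features. A framework-free sketch of that construction, using a plain Python stand-in rather than the aihwkit layer classes:

```python
class FakeLayer:
    """Stand-in for an RNN layer; just records its constructor arguments."""

    def __init__(self, *args):
        self.args = args


def init_stacked(num_layers, layer, first_layer_args, other_layer_args):
    # First layer: input_size -> hidden_size; remaining layers:
    # hidden_size -> hidden_size, mirroring init_stacked_analog_lstm.
    return [layer(*first_layer_args)] + [
        layer(*other_layer_args) for _ in range(num_layers - 1)
    ]


layers = init_stacked(3, FakeLayer, ("cell", 4, 8), ("cell", 8, 8))
print([l.args for l in layers])
# [('cell', 4, 8), ('cell', 8, 8), ('cell', 8, 8)]
```

In the real implementation the resulting list is wrapped in a torch.nn.ModuleList so that the layers are registered as submodules.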