aihwkit.nn.modules.rnn.rnn module¶
Analog RNN modules.
- class aihwkit.nn.modules.rnn.rnn.AnalogRNN(cell, input_size, hidden_size, bias=True, rpu_config=None, realistic_read_write=False, xavier=False, num_layers=1, bidir=False, dropout=0.0)[source]¶
Bases:
aihwkit.nn.modules.container.AnalogSequential
Modular RNN that uses analog tiles.
- Parameters
cell (Type) – type of Analog RNN cell (AnalogLSTMCell/AnalogGRUCell/AnalogVanillaRNNCell)
input_size (int) – in_features of the W_{ih} matrix of the first layer
hidden_size (int) – in_features and out_features of the W_{hh} matrices
bias (bool) – whether to use a bias row on the analog tile or not
rpu_config (Optional[Union[aihwkit.simulator.configs.configs.FloatingPointRPUConfig, aihwkit.simulator.configs.configs.SingleRPUConfig, aihwkit.simulator.configs.configs.UnitCellRPUConfig, aihwkit.simulator.configs.configs.InferenceRPUConfig, aihwkit.simulator.configs.configs.DigitalRankUpdateRPUConfig]]) – resistive processing unit configuration.
realistic_read_write (bool) – whether to enable realistic read/write when setting initial weights and when reading out weights
xavier (bool) – whether to use Xavier initialization instead of the standard PyTorch LSTM weight initialization (default)
num_layers (int) – number of serially connected RNN layers
bidir (bool) – if True, becomes a bidirectional RNN
dropout (float) – dropout applied to output of all RNN layers except last
- forward(x, states=None)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters
x (torch.Tensor) –
states (Optional[List]) –
- Return type
Tuple[torch.Tensor, List]
- get_zero_state(batch_size)[source]¶
Returns a zeroed RNN state based on the cell type and layer type.
- Parameters
batch_size (int) – batch size of the input
- Returns
List of zeroed state tensors for each layer
- Return type
List[torch.Tensor]
- init_layers(weight_init_fn, bias_init_fn=None)[source]¶
Init the analog layers with custom functions.
- Parameters
weight_init_fn (Callable) – in-place tensor function applied to the weights of
AnalogLinear
layers
bias_init_fn (Optional[Callable]) – in-place tensor function applied to the biases of
AnalogLinear
layers
- Return type
None
Note
If no bias init function is provided the weight init function is taken for the bias as well.
- class aihwkit.nn.modules.rnn.rnn.ModularRNN(num_layers, layer, dropout, first_layer_args, other_layer_args)[source]¶
Bases:
aihwkit.nn.modules.container.AnalogSequential
Helper class to create a modular RNN.
- Parameters
num_layers (int) – number of serially connected RNN layers
layer (Type) – RNN layer type (e.g. AnalogLSTMLayer)
dropout (float) – dropout applied to output of all RNN layers except last
first_layer_args (Any) – RNNCell type, input_size, hidden_size, rpu_config, etc.
other_layer_args (Any) – RNNCell type, hidden_size, hidden_size, rpu_config, etc.
- forward(input_, states)[source]¶
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module
instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- Parameters
input_ (torch.Tensor) –
states (List) –
- Return type
Tuple[torch.Tensor, List]
- get_zero_state(batch_size)[source]¶
Returns a zeroed state.
- Parameters
batch_size (int) – batch size of the input
- Returns
List of zeroed state tensors for each layer
- Return type
List[torch.Tensor]
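In outline, the modular wrapper collects one zero state per stacked layer. The sketch below uses a hypothetical `DummyCell` in plain Python to show the structure; the real method returns a list of `torch.Tensor` objects instead of nested lists:

```python
class DummyCell:
    """Hypothetical stand-in for an analog RNN layer."""
    def __init__(self, hidden_size):
        self.hidden_size = hidden_size

    def get_zero_state(self, batch_size):
        # One zero vector of length hidden_size per batch element.
        return [[0.0] * self.hidden_size for _ in range(batch_size)]

def get_zero_state(layers, batch_size):
    # One zeroed state entry per stacked layer.
    return [layer.get_zero_state(batch_size) for layer in layers]

states = get_zero_state([DummyCell(8), DummyCell(8)], batch_size=3)
print(len(states), len(states[0]), len(states[0][0]))  # 2 3 8
```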
- static init_stacked_analog_lstm(num_layers, layer, first_layer_args, other_layer_args)[source]¶
Construct a list of LSTMLayers over which to iterate.
- Parameters
num_layers (int) – number of serially connected LSTM layers
layer (Type) – RNN layer type (e.g. AnalogLSTMLayer)
first_layer_args (Any) – RNNCell type, input_size, hidden_size, rpu_config, etc.
other_layer_args (Any) – RNNCell type, hidden_size, hidden_size, rpu_config, etc.
- Returns
a torch.nn.ModuleList, which behaves like a regular Python list but additionally registers its entries as submodules, so torch.nn.Module methods can be applied to them
- Return type
torch.nn.modules.container.ModuleList
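The stacking rule above — one layer built from first_layer_args, the remaining num_layers - 1 from other_layer_args — can be sketched in plain Python (the `DummyLayer` class is a hypothetical stand-in for an analog RNN layer, and the real method returns a torch.nn.ModuleList rather than a list):

```python
def init_stacked(num_layers, layer, first_layer_args, other_layer_args):
    # The first layer consumes the input features; every later layer
    # consumes the previous layer's hidden features.
    layers = [layer(*first_layer_args)]
    layers += [layer(*other_layer_args) for _ in range(num_layers - 1)]
    return layers

class DummyLayer:
    """Hypothetical stand-in for an analog RNN layer."""
    def __init__(self, in_features, hidden_size):
        self.in_features = in_features
        self.hidden_size = hidden_size

# Three layers: (input_size=4, hidden_size=8), then (8, 8) twice.
stack = init_stacked(3, DummyLayer, (4, 8), (8, 8))
print([l.in_features for l in stack])  # [4, 8, 8]
```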