aihwkit.nn.modules.rnn.rnn module

Analog RNN modules.

class aihwkit.nn.modules.rnn.rnn.AnalogRNN(cell, input_size, hidden_size, bias=True, rpu_config=None, tile_module_class=None, xavier=False, num_layers=1, bidir=False, dropout=0.0)[source]

Bases: AnalogContainerBase, Module

Modular RNN that uses analog tiles.

Parameters:
  • cell (Type) – type of Analog RNN cell (AnalogLSTMCell/AnalogGRUCell/AnalogVanillaRNNCell)

  • input_size (int) – in_features of the W_{ih} matrix of the first layer

  • hidden_size (int) – in_features and out_features of the W_{hh} matrices

  • bias (bool) – whether to use a bias row on the analog tile or not

  • rpu_config (RPUConfigBase | None) – configuration for an analog resistive processing unit. If not given, a native torch model will be constructed instead.

  • tile_module_class (Type | None) – Class for the analog tile module (default will be specified from the RPUConfig).

  • xavier (bool) – if True, use Xavier weight initialization; otherwise use the standard PyTorch LSTM initialization (default)

  • num_layers (int) – number of serially connected RNN layers

  • bidir (bool) – if True, becomes a bidirectional RNN

  • dropout (float) – dropout probability applied to the output of every RNN layer except the last
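
A minimal construction sketch (the import paths follow the aihwkit package layout; the sizes are illustrative):

    from aihwkit.nn import AnalogRNN, AnalogLSTMCell
    from aihwkit.simulator.configs import SingleRPUConfig

    # Two stacked analog LSTM layers: 64 input features, 128 hidden
    # units, mapped onto analog tiles via the RPU configuration.
    rnn = AnalogRNN(
        AnalogLSTMCell,
        input_size=64,
        hidden_size=128,
        num_layers=2,
        rpu_config=SingleRPUConfig(),
    )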

forward(input, states=None)[source]

Forward pass.

Parameters:
  • input (Tensor) – input tensor

  • states (List | None) – list of LSTM state tensors

Returns:

outputs and states

Return type:

Tuple[Tensor, List]
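
A usage sketch, continuing the construction example above (the sequence-first (seq_len, batch, input_size) input layout is an assumption based on the stacked-LSTM design this module follows):

    import torch

    x = torch.randn(10, 8, 64)           # (seq_len, batch, input_size)
    states = rnn.get_zero_state(8)       # one state entry per layer
    output, new_states = rnn(x, states)  # output: (seq_len, 8, 128)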

get_zero_state(batch_size)[source]

Returns a zeroed RNN state based on the cell type and layer type.

Parameters:

batch_size (int) – batch size of the input

Returns:

List of zeroed state tensors for each layer

Return type:

List[Tensor]
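
For example, continuing the sketch above (the per-layer state layout is an assumption based on the cell types):

    states = rnn.get_zero_state(batch_size=8)
    # For AnalogLSTMCell each layer's entry holds a (hidden, cell)
    # state pair; for GRU/vanilla cells it holds a single hidden tensor.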

init_layers(weight_init_fn, bias_init_fn=None)[source]

Init the analog layers with custom functions.

Parameters:
  • weight_init_fn (Callable) – in-place tensor function applied to weight of AnalogLinear layers

  • bias_init_fn (Callable | None) – in-place tensor function applied to bias of AnalogLinear layers

Return type:

None

Note

If no bias init function is provided, the weight init function is used for the bias as well.
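
A short sketch using in-place initializers from torch.nn.init, which match the expected in-place tensor signature:

    from torch.nn.init import xavier_uniform_, zeros_

    # Xavier-initialize the weights and zero the biases; both functions
    # modify the passed tensor in place, as init_layers expects.
    rnn.init_layers(xavier_uniform_, zeros_)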

reset_parameters(xavier=False)[source]

Weight and bias initialization.

Parameters:

xavier (bool) – if True, use Xavier weight initialization; otherwise use the standard PyTorch LSTM initialization (default)

Return type:

None
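
For instance, to re-initialize an existing model with Xavier initialization:

    rnn.reset_parameters(xavier=True)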

class aihwkit.nn.modules.rnn.rnn.ModularRNN(num_layers, layer, dropout, first_layer_args, other_layer_args)[source]

Bases: Module

Helper class to create a modular RNN.

Parameters:
  • num_layers (int) – number of serially connected RNN layers

  • layer (Type) – RNN layer type (e.g. AnalogLSTMLayer)

  • dropout (float) – dropout probability applied to the output of every RNN layer except the last

  • first_layer_args (Any) – arguments for the first layer: RNN cell type, input_size, hidden_size, rpu_config, etc.

  • other_layer_args (Any) – arguments for the remaining layers: RNN cell type, hidden_size, hidden_size, rpu_config, etc.
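
ModularRNN is normally constructed internally by AnalogRNN, but a direct construction sketch looks roughly as follows (the layer class, its import path, and the exact argument order are assumptions based on the parameter descriptions above):

    from aihwkit.nn.modules.rnn.cells import AnalogLSTMCell
    from aihwkit.nn.modules.rnn.layers import AnalogRNNLayer  # assumed path
    from aihwkit.nn.modules.rnn.rnn import ModularRNN
    from aihwkit.simulator.configs import SingleRPUConfig

    rpu_config = SingleRPUConfig()
    stacked = ModularRNN(
        num_layers=2,
        layer=AnalogRNNLayer,
        dropout=0.0,
        # First layer maps input_size -> hidden_size; the remaining
        # layers map hidden_size -> hidden_size. Trailing ("etc.")
        # arguments are omitted in this sketch.
        first_layer_args=[AnalogLSTMCell, 64, 128, rpu_config],
        other_layer_args=[AnalogLSTMCell, 128, 128, rpu_config],
    )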

forward(input, states)[source]

Forward pass.

Parameters:
  • input (Tensor) – input tensor

  • states (List) – list of LSTM state tensors

Returns:

outputs and states

Return type:

Tuple[Tensor, List]

get_zero_state(batch_size)[source]

Returns a zeroed state.

Parameters:

batch_size (int) – batch size of the input

Returns:

List of zeroed state tensors for each layer

Return type:

List[Tensor]

static init_stacked_analog_lstm(num_layers, layer, first_layer_args, other_layer_args)[source]

Construct a list of LSTM layers over which to iterate.

Parameters:
  • num_layers (int) – number of serially connected LSTM layers

  • layer (Type) – RNN layer type (e.g. AnalogLSTMLayer)

  • first_layer_args (Any) – arguments for the first layer: RNN cell type, input_size, hidden_size, rpu_config, etc.

  • other_layer_args (Any) – arguments for the remaining layers: RNN cell type, hidden_size, hidden_size, rpu_config, etc.

Returns:

torch.nn.ModuleList, which behaves like a regular Python list but to which torch.nn.Module methods can be applied

Return type:

ModuleList
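
Functionally, this helper is roughly equivalent to the following sketch (an illustration, not the actual implementation; the function name is hypothetical):

    from torch.nn import ModuleList

    def init_stacked(num_layers, layer, first_layer_args, other_layer_args):
        # The first layer consumes the input features; the remaining
        # num_layers - 1 layers consume the previous layer's hidden state.
        layers = [layer(*first_layer_args)]
        layers += [layer(*other_layer_args) for _ in range(num_layers - 1)]
        return ModuleList(layers)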