aihwkit.nn.modules.rnn.cells module

Analog cells for RNNs.

class aihwkit.nn.modules.rnn.cells.AnalogGRUCell(input_size, hidden_size, bias, rpu_config=None, tile_module_class=None)

Bases: Module

Analog GRU Cell.

Parameters:
  • input_size (int) – number of features in the input x (in_features size of the W_ih matrix)

  • hidden_size (int) – number of features in the hidden state h (in_features and out_features size of the W_hh matrix)

  • bias (bool) – whether to use a bias row on the analog tile or not

  • rpu_config (RPUConfigBase | None) – configuration for an analog resistive processing unit. If not given, a native torch model will be constructed instead.

  • tile_module_class (Type | None) – class for the analog tile module (if not given, the default specified by the RPUConfig is used).
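The input_size and hidden_size parameters fix the shapes of the two weight matrices. The sketch below computes those shapes under the assumption that the gates are stacked as in PyTorch (3 gates for a GRU, 4 for an LSTM, 1 for a vanilla RNN) and ignoring any extra bias row on the analog tile. weight_shapes and GATES_PER_CELL are hypothetical helpers, not part of aihwkit.

```python
# Assumed number of stacked gates per cell type (PyTorch convention), not read from aihwkit.
GATES_PER_CELL = {"gru": 3, "lstm": 4, "vanilla": 1}


def weight_shapes(cell_type: str, input_size: int, hidden_size: int) -> dict:
    """Return (out_features, in_features) for W_ih and W_hh of a recurrent cell.

    Hypothetical helper: bias rows on the analog tile are not counted.
    """
    gates = GATES_PER_CELL[cell_type]
    return {
        "W_ih": (gates * hidden_size, input_size),
        "W_hh": (gates * hidden_size, hidden_size),
    }


# Example: a GRU cell mapping 10 input features to a 20-dimensional hidden state.
shapes = weight_shapes("gru", input_size=10, hidden_size=20)
# shapes == {"W_ih": (60, 10), "W_hh": (60, 20)}
```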

forward(input_, state)

Forward pass.

Parameters:
  • input_ (Tensor) – input tensor

  • state (Tensor) – GRU hidden state tensor

Returns:

the output h_y and the output state h_y (identical here)

Return type:

Tuple[Tensor, Tensor]
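For reference, the recurrence behind one GRU step can be written in plain Python. This is a sketch that assumes AnalogGRUCell follows the standard formulation used by torch.nn.GRUCell (reset, update and candidate gates stacked in that order, bias omitted); gru_step, matvec and sigmoid are hypothetical names, and in the analog cell the two matrix-vector products would run on the tiles.

```python
import math


def matvec(w, v):
    """Multiply a matrix w (list of rows) by a vector v (list of floats)."""
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in w]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gru_step(x, h, w_ih, w_hh):
    """One GRU step for a single sample, bias omitted.

    w_ih has 3 * H rows of length len(x); w_hh has 3 * H rows of length H.
    Rows are assumed stacked as [reset | update | candidate] gates.
    """
    hidden = len(h)
    g_i = matvec(w_ih, x)  # input contributions to all three gates
    g_h = matvec(w_hh, h)  # hidden-state contributions to all three gates
    h_y = []
    for k in range(hidden):
        r = sigmoid(g_i[k] + g_h[k])                    # reset gate
        z = sigmoid(g_i[hidden + k] + g_h[hidden + k])  # update gate
        n = math.tanh(g_i[2 * hidden + k] + r * g_h[2 * hidden + k])  # candidate
        h_y.append(n + z * (h[k] - n))                  # (1 - z) * n + z * h
    return h_y, h_y  # output and new state are the same for a GRU
```

Returning h_y twice mirrors the documented (output, state) pair, in which both entries are the new hidden state.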

get_zero_state(batch_size)

Returns a zeroed state.

Parameters:

batch_size (int) – batch size of the input

Returns:

Zeroed state tensor

Return type:

Tensor
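get_zero_state supplies the initial state for the first time step. The sketch below shows how such a state is typically threaded through a sequence by calling the cell once per step; it uses nested lists instead of tensors, assumes a zero state of shape (batch_size, hidden_size), and get_zero_state, unroll and toy_step here are hypothetical stand-ins rather than the aihwkit methods.

```python
def get_zero_state(batch_size, hidden_size):
    """Hypothetical stand-in: one all-zero hidden vector per sample in the batch."""
    return [[0.0] * hidden_size for _ in range(batch_size)]


def unroll(step, inputs, hidden_size):
    """Apply a cell step over a sequence, starting from a zeroed state.

    inputs: list over time steps; each element is a batch (list of per-sample feature lists).
    step(x_t, state) -> (output, new_state), matching the documented forward() contract.
    """
    state = get_zero_state(len(inputs[0]), hidden_size)
    outputs = []
    for x_t in inputs:
        out, state = step(x_t, state)
        outputs.append(out)
    return outputs, state


def toy_step(x_t, state):
    """Toy recurrence for illustration: add each sample's input sum to its state."""
    new_state = [[h + sum(x) for h in row] for row, x in zip(state, x_t)]
    return new_state, new_state


# Two time steps, a batch of two samples, one input feature each.
outputs, final_state = unroll(toy_step, [[[1.0], [2.0]], [[3.0], [4.0]]], hidden_size=2)
# final_state == [[4.0, 4.0], [6.0, 6.0]]
```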

class aihwkit.nn.modules.rnn.cells.AnalogLSTMCell(input_size, hidden_size, bias, rpu_config=None, tile_module_class=None)

Bases: Module

Analog LSTM Cell.

Parameters:
  • input_size (int) – number of features in the input x (in_features size of the W_ih matrix)

  • hidden_size (int) – number of features in the hidden state h (in_features and out_features size of the W_hh matrix)

  • bias (bool) – whether to use a bias row on the analog tile or not

  • rpu_config (RPUConfigBase | None) – configuration for an analog resistive processing unit. If not given, a native torch model will be constructed instead.

  • tile_module_class (Type | None) – class for the analog tile module (if not given, the default specified by the RPUConfig is used).

forward(input_, state)

Forward pass.

Parameters:
  • input_ (Tensor) – input tensor

  • state (Tuple[Tensor, Tensor]) – LSTM state tuple (hidden state, cell state)

Returns:

the output h_y and the output state tuple (h_y, c_y)

Return type:

Tuple[Tensor, Tuple[Tensor, Tensor]]
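A plain-Python sketch of one LSTM step makes the (h_y, (h_y, c_y)) return shape concrete. It assumes the standard formulation used by torch.nn.LSTMCell (input, forget, cell and output gates stacked in that order, bias omitted); lstm_step, matvec and sigmoid are hypothetical names, not aihwkit API.

```python
import math


def matvec(w, v):
    """Multiply a matrix w (list of rows) by a vector v (list of floats)."""
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in w]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x, state, w_ih, w_hh):
    """One LSTM step for a single sample, bias omitted.

    state is (h_x, c_x); w_ih and w_hh each have 4 * H rows,
    assumed stacked as [input | forget | cell | output] gates.
    """
    h_x, c_x = state
    hidden = len(h_x)
    gates = [a + b for a, b in zip(matvec(w_ih, x), matvec(w_hh, h_x))]
    h_y, c_y = [], []
    for k in range(hidden):
        i = sigmoid(gates[k])                 # input gate
        f = sigmoid(gates[hidden + k])        # forget gate
        g = math.tanh(gates[2 * hidden + k])  # cell candidate
        o = sigmoid(gates[3 * hidden + k])    # output gate
        c = f * c_x[k] + i * g                # new cell state
        c_y.append(c)
        h_y.append(o * math.tanh(c))          # new hidden state
    return h_y, (h_y, c_y)
```

Unlike the GRU, the state carries a second component (the cell state c_y), which is why the documented return type nests a tuple.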

get_zero_state(batch_size)

Returns a zeroed state.

Parameters:

batch_size (int) – batch size of the input

Returns:

Zeroed state tensor

Return type:

Tensor

class aihwkit.nn.modules.rnn.cells.AnalogLSTMCellCombinedWeight(input_size, hidden_size, bias, rpu_config=None, tile_module_class=None)

Bases: Module

Analog LSTM Cell that uses a combined weight for storing gates and inputs.

Parameters:
  • input_size (int) – The number of expected features in the input x

  • hidden_size (int) – The number of features in the hidden state h

  • bias (bool) – whether to use a bias row on the analog tile or not.

  • rpu_config (RPUConfigBase | None) – configuration for an analog resistive processing unit. If not given, a native torch model will be constructed instead.

  • tile_module_class (Type | None) – class for the analog tile module (if not given, the default specified by the RPUConfig is used).

forward(input_, state)

Forward pass.

Parameters:
  • input_ (Tensor) – input tensor

  • state (Tuple[Tensor, Tensor]) – LSTM state tuple (hidden state, cell state)

Returns:

the output h_y and the output state tuple (h_y, c_y)

Return type:

Tuple[Tensor, Tuple[Tensor, Tensor]]
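One way to read "combined weight" is that W_ih and W_hh are placed side by side in a single matrix, which is then applied to the concatenation of the input and the previous hidden state, so one analog tile produces all gate pre-activations at once. The sketch below checks that this gives the same result as two separate products; the column-wise concatenation is an assumption about this class, and all function names are hypothetical.

```python
def matvec(w, v):
    """Multiply a matrix w (list of rows) by a vector v (list of floats)."""
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in w]


def combine_weights(w_ih, w_hh):
    """Place W_ih and W_hh side by side: each row becomes [row of W_ih | row of W_hh]."""
    return [row_ih + row_hh for row_ih, row_hh in zip(w_ih, w_hh)]


def preactivations_separate(x, h, w_ih, w_hh):
    """Gate pre-activations with two matrices: W_ih x + W_hh h."""
    return [a + b for a, b in zip(matvec(w_ih, x), matvec(w_hh, h))]


def preactivations_combined(x, h, w_combined):
    """Gate pre-activations with one matrix applied to the concatenation [x, h]."""
    return matvec(w_combined, x + h)  # list + list concatenates


# hidden_size = 1, so 4 stacked gate rows; input_size = 2.
w_ih = [[1.0, 2.0], [0.0, -1.0], [3.0, 1.0], [2.0, 2.0]]
w_hh = [[4.0], [1.0], [-2.0], [0.5]]
x, h = [1.0, 2.0], [3.0]
separate = preactivations_separate(x, h, w_ih, w_hh)
combined = preactivations_combined(x, h, combine_weights(w_ih, w_hh))
# separate == combined == [17.0, 1.0, -1.0, 7.5]
```

The trade-off in this reading is one larger matrix (4 * hidden_size rows by input_size + hidden_size columns) instead of two smaller ones; the gate nonlinearities that follow are unchanged.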

get_zero_state(batch_size)

Returns a zeroed state.

Parameters:

batch_size (int) – batch size of the input

Returns:

Zeroed state tensor

Return type:

Tensor

class aihwkit.nn.modules.rnn.cells.AnalogVanillaRNNCell(input_size, hidden_size, bias, rpu_config=None, tile_module_class=None)

Bases: Module

Analog Vanilla RNN Cell.

Parameters:
  • input_size (int) – number of features in the input x (in_features size of the W_ih matrix)

  • hidden_size (int) – number of features in the hidden state h (in_features and out_features size of the W_hh matrix)

  • bias (bool) – whether to use a bias row on the analog tile or not

  • rpu_config (RPUConfigBase | None) – configuration for an analog resistive processing unit. If not given, a native torch model will be constructed instead.

  • tile_module_class (Type | None) – class for the analog tile module (if not given, the default specified by the RPUConfig is used).

forward(input_, state)

Forward pass.

Parameters:
  • input_ (Tensor) – input tensor

  • state (Tensor) – hidden state tensor

Returns:

the output and the output state (identical here)

Return type:

Tuple[Tensor, Tensor]
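The vanilla cell is the simplest case: a single nonlinearity applied to the sum of the two matrix-vector products. The sketch below assumes tanh as the activation (the default of torch.nn.RNNCell) and omits the bias; vanilla_rnn_step and matvec are hypothetical names, and the activation parameter exists only for illustration.

```python
import math


def matvec(w, v):
    """Multiply a matrix w (list of rows) by a vector v (list of floats)."""
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in w]


def vanilla_rnn_step(x, h, w_ih, w_hh, activation=math.tanh):
    """One vanilla RNN step for a single sample, bias omitted.

    Computes h_y = activation(W_ih x + W_hh h); tanh is an assumed default.
    """
    pre = [a + b for a, b in zip(matvec(w_ih, x), matvec(w_hh, h))]
    h_y = [activation(p) for p in pre]
    return h_y, h_y  # output and new state coincide


# Usage: one hidden unit, two input features.
out, state = vanilla_rnn_step([0.25, 9.0], [0.5], [[1.0, 0.0]], [[0.5]])
# out == state == [math.tanh(0.5)]
```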

get_zero_state(batch_size)

Returns a zeroed state.

Parameters:

batch_size (int) – batch size of the input

Returns:

Zeroed state tensor

Return type:

Tensor

class aihwkit.nn.modules.rnn.cells.LSTMState(hx, cx)

Bases: tuple

cx

Alias for field number 1

hx

Alias for field number 0
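The "Bases: tuple" line and the "Alias for field number" entries match what Python's collections.namedtuple generates, so LSTMState appears to be a named tuple pairing hx (presumably the hidden state) with cx (presumably the cell state). The stand-in below is an assumption rather than the library's own definition; it shows that fields can be read by name or by position and that unpacking follows the field order.

```python
from collections import namedtuple

# A named tuple with the same field names and order as the documented LSTMState.
LSTMState = namedtuple("LSTMState", ["hx", "cx"])

state = LSTMState(hx=[0.0, 0.0], cx=[1.0, 1.0])
h, c = state  # unpacks in field order: hx (field 0), then cx (field 1)
```

Because it is still a plain tuple, such a state can be passed anywhere a Tuple[Tensor, Tensor] is expected, for example as the state argument of the LSTM cells above.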