Source code for aihwkit.nn.modules.rnn.layers

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# Licensed under the MIT license. See LICENSE file in the project root for details.

""" Analog RNN layers """

from typing import Any, List, Tuple, Type, Union
from torch import Tensor, stack, jit, cat
from torch.nn import ModuleList, Module


class AnalogRNNLayer(Module):
    """Analog RNN layer that applies one cell step by step over its input.

    Args:
        cell: RNNCell type (AnalogLSTMCell/AnalogGRUCell/AnalogVanillaRNNCell/
            AnalogLSTMCellSingleRPU)
        cell_args: positional arguments handed to the cell constructor
            (e.g. input_size, hidden_size, rpu_configs)
    """

    # pylint: disable=abstract-method

    def __init__(self, cell: Type, *cell_args: Any):
        super().__init__()
        # A single cell instance is built here and reused for every step.
        self.cell = cell(*cell_args)

    def get_zero_state(self, batch_size: int) -> Tensor:
        """Ask the wrapped cell for a zeroed state.

        Args:
            batch_size: batch size of the input

        Returns:
            Whatever ``self.cell.get_zero_state(batch_size)`` returns
        """
        zero_state = self.cell.get_zero_state(batch_size)
        return zero_state

    def forward(
        self, input_: Tensor, state: Union[Tuple[Tensor, Tensor], Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        """Run the cell over every slice of the input's first dimension.

        Args:
            input_: input tensor; it is split along dimension 0 and each
                slice is fed to the cell in order
            state: state passed to the cell for the first slice

        Returns:
            the per-step cell outputs stacked along a new first dimension,
            and the state returned by the cell for the last slice
        """
        # pylint: disable=arguments-differ
        step_outputs = jit.annotate(List[Tensor], [])
        for step_input in input_.unbind(0):
            # The state produced by one step is the input state of the next.
            step_output, state = self.cell(step_input, state)
            step_outputs.append(step_output)
        return stack(step_outputs), state
class AnalogReverseRNNLayer(Module):
    """Analog RNN layer that walks its input in reverse order.

    Args:
        cell: RNNCell type (AnalogLSTMCell/AnalogGRUCell/AnalogVanillaRNNCell)
        cell_args: positional arguments handed to the cell constructor
            (e.g. input_size, hidden_size, rpu_configs)
    """

    def __init__(self, cell: Type, *cell_args: Any):
        super().__init__()
        # A single cell instance is built here and reused for every step.
        self.cell = cell(*cell_args)

    @staticmethod
    def reverse(lst: List[Tensor]) -> List[Tensor]:
        """Return a reversed copy of the given sequence of tensors."""
        return lst[::-1]

    def get_zero_state(self, batch_size: int) -> Tensor:
        """Ask the wrapped cell for a zeroed state.

        Args:
            batch_size: batch size of the input

        Returns:
            Whatever ``self.cell.get_zero_state(batch_size)`` returns
        """
        zero_state = self.cell.get_zero_state(batch_size)
        return zero_state

    def forward(
        self, input_: Tensor, state: Union[Tuple[Tensor, Tensor], Tensor]
    ) -> Tuple[Tensor, Union[Tuple[Tensor, Tensor], Tensor]]:
        """Run the cell over the input slices from last to first.

        Args:
            input_: input tensor; it is split along dimension 0 and the
                slices are fed to the cell in reversed order
            state: state passed to the cell for the first processed slice

        Returns:
            the cell outputs, re-ordered to line up with the original input
            order and stacked along a new first dimension, and the state
            returned by the cell after the final processed slice
        """
        # pylint: disable=arguments-differ
        reversed_steps = self.reverse(input_.unbind(0))
        collected = jit.annotate(List[Tensor], [])
        for step_input in reversed_steps:
            # The state produced by one step is the input state of the next.
            step_output, state = self.cell(step_input, state)
            collected.append(step_output)
        # Flip back so that output i corresponds to input slice i.
        realigned = self.reverse(collected)
        return stack(realigned), state
class AnalogBidirRNNLayer(Module):
    """Bi-directional analog RNN layer.

    Holds one forward-running and one reverse-running layer, each with its
    own cell built from the same arguments, and concatenates their outputs.

    Args:
        cell: RNNCell type (AnalogLSTMCell/AnalogGRUCell/AnalogVanillaRNNCell)
        cell_args: arguments to RNNCell (e.g. input_size, hidden_size, rpu_configs)
    """

    __constants__ = ["directions"]

    def __init__(self, cell: Type, *cell_args: Any):
        super().__init__()
        # Index 0 is the forward direction, index 1 the reverse direction.
        self.directions = ModuleList(
            [AnalogRNNLayer(cell, *cell_args), AnalogReverseRNNLayer(cell, *cell_args)]
        )

    def get_zero_state(self, batch_size: int) -> List[Union[Tuple[Tensor, Tensor], Tensor]]:
        """Returns zeroed states for both directions.

        Args:
            batch_size: batch size of the input

        Returns:
            Two-element list ``[forward zero state, reverse zero state]``,
            each obtained from the corresponding direction layer
        """
        return [
            self.directions[0].get_zero_state(batch_size),
            self.directions[1].get_zero_state(batch_size),
        ]

    def forward(
        self, input_: Tensor, states: List[Union[Tuple[Tensor, Tensor], Tensor]]
    ) -> Tuple[Tensor, List[Union[Tuple[Tensor, Tensor], Tensor]]]:
        """Forward pass.

        Args:
            input_: input tensor, passed unchanged to both direction layers
            states: list of initial states, ordered
                ``[forward state, reverse state]``

        Returns:
            the outputs of both directions concatenated along the last
            dimension, and the list of states returned by each direction
        """
        # pylint: disable=arguments-differ
        # List[RNNState]: [forward RNNState, backward RNNState]
        outputs = jit.annotate(List[Tensor], [])
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        # zip pairs each direction with its state; it stops at the shorter
        # of the two, so ``states`` is expected to hold one entry per direction.
        for direction, state in zip(self.directions, states):
            out, out_state = direction(input_, state)
            outputs.append(out)
            output_states.append(out_state)
        return cat(outputs, -1), output_states