# Source code for aihwkit.nn.modules.rnn.layers

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

""" Analog RNN layers """

from typing import Any, List, Tuple, Type, Union
from torch import Tensor, stack, jit, cat
from torch.nn import ModuleList
from aihwkit.nn import AnalogSequential


class AnalogRNNLayer(AnalogSequential):
    """Analog RNN Layer.

    Applies a single recurrent cell step by step over dimension 0 of the
    input, threading the state returned by each step into the next one.

    Args:
        cell: RNNCell type (AnalogLSTMCell/AnalogGRUCell/AnalogVanillaRNNCell)
        cell_args: arguments to RNNCell (e.g. input_size, hidden_size, rpu_configs)
    """
    # pylint: disable=abstract-method

    def __init__(self, cell: Type, *cell_args: Any):
        super().__init__()
        # The cell class is instantiated here with the forwarded arguments.
        self.cell = cell(*cell_args)

    def get_zero_state(self, batch_size: int) -> Union[Tuple[Tensor, Tensor], Tensor]:
        """Returns a zeroed state.

        Args:
            batch_size: batch size of the input

        Returns:
            Zeroed state, as produced by the wrapped cell's ``get_zero_state``.
            It is meant to be passed as the ``state`` argument of ``forward``.
        """
        return self.cell.get_zero_state(batch_size)

    def forward(
            self,
            input_: Tensor,
            state: Union[Tuple[Tensor, Tensor], Tensor]
    ) -> Tuple[Tensor, Union[Tuple[Tensor, Tensor], Tensor]]:
        """Runs the cell over every step of ``input_``.

        Args:
            input_: input sequence. It is split along dimension 0, so that
                dimension is treated as the step axis.
            state: initial state handed to the cell on the first step.

        Returns:
            A tuple of (per-step outputs stacked along dimension 0, the state
            returned by the cell after the last step). The state has the
            same form as the ``state`` argument; it may be a single tensor,
            not necessarily a tuple.
        """
        # pylint: disable=arguments-differ
        inputs = input_.unbind(0)
        outputs = jit.annotate(List[Tensor], [])
        for input_item in inputs:
            # The cell returns (output, new_state); the new state feeds the next step.
            out, state = self.cell(input_item, state)
            outputs += [out]
        return stack(outputs), state
class AnalogReverseRNNLayer(AnalogSequential):
    """Analog RNN layer that walks the input sequence back to front.

    Args:
        cell: RNNCell type (AnalogLSTMCell/AnalogGRUCell/AnalogVanillaRNNCell)
        cell_args: arguments to RNNCell (e.g. input_size, hidden_size, rpu_configs)
    """

    def __init__(self, cell: Type, *cell_args: Any):
        super().__init__()
        self.cell = cell(*cell_args)

    @staticmethod
    def reverse(lst: List[Tensor]) -> List[Tensor]:
        """Return the given sequence of tensors in the opposite order."""
        flipped = lst[::-1]
        return flipped

    def get_zero_state(self, batch_size: int) -> Tensor:
        """Build the initial (all-zero) state by delegating to the cell.

        Args:
            batch_size: batch size of the input

        Returns:
            Whatever the wrapped cell's ``get_zero_state`` produces.
        """
        zero_state = self.cell.get_zero_state(batch_size)
        return zero_state

    def forward(self, input_: Tensor, state: Union[Tuple[Tensor, Tensor], Tensor]
                ) -> Tuple[Tensor, Union[Tuple[Tensor, Tensor], Tensor]]:
        """Feed the steps of ``input_`` (split along dim 0) to the cell last-to-first.

        Args:
            input_: input sequence; dimension 0 is iterated as the step axis.
            state: state handed to the cell for the last step, which is processed first.

        Returns:
            The stacked per-step outputs (restored to original step order) and
            the state produced after processing the first step.
        """
        # pylint: disable=arguments-differ
        time_steps = self.reverse(input_.unbind(0))
        step_outputs = jit.annotate(List[Tensor], [])
        for step_input in time_steps:
            step_output, state = self.cell(step_input, state)
            step_outputs.append(step_output)
        # Flip the collected outputs back so they align with the original step order.
        ordered_outputs = self.reverse(step_outputs)
        return stack(ordered_outputs), state
class AnalogBidirRNNLayer(AnalogSequential):
    """Bi-directional analog RNN layer.

    Holds a forward-order layer and a reverse-order layer, both built from the
    same cell type and arguments, and concatenates their outputs along the
    last dimension.

    Args:
        cell: RNNCell type (AnalogLSTMCell/AnalogGRUCell/AnalogVanillaRNNCell)
        cell_args: arguments to RNNCell (e.g. input_size, hidden_size, rpu_configs)
    """
    __constants__ = ['directions']

    def __init__(self, cell: Type, *cell_args: Any):
        super().__init__()
        # Index 0 processes the sequence in order, index 1 in reverse.
        self.directions = ModuleList([
            AnalogRNNLayer(cell, *cell_args),
            AnalogReverseRNNLayer(cell, *cell_args),
        ])

    def get_zero_state(
            self, batch_size: int
    ) -> List[Union[Tuple[Tensor, Tensor], Tensor]]:
        """Returns zeroed states for both directions.

        Args:
            batch_size: batch size of the input

        Returns:
            A two-element list: [forward-direction zero state,
            reverse-direction zero state].
        """
        return [self.directions[0].get_zero_state(batch_size),
                self.directions[1].get_zero_state(batch_size)]

    def forward(self, input_: Tensor,
                states: List[Union[Tuple[Tensor, Tensor], Tensor]]
                ) -> Tuple[Tensor, List[Union[Tuple[Tensor, Tensor], Tensor]]]:
        """Runs both directions over ``input_`` and joins their outputs.

        Args:
            input_: input sequence, passed unchanged to each direction.
            states: one initial state per direction,
                [forward state, backward state].

        Returns:
            The two directions' outputs concatenated along the last dimension,
            and the list of final states in the same order as ``states``.

        Raises:
            ValueError: if ``states`` does not hold exactly one state per
                direction.
        """
        # pylint: disable=arguments-differ
        # zip() would silently drop a direction if too few states were given,
        # producing an output that is too narrow; fail loudly instead.
        if len(states) != len(self.directions):
            raise ValueError('Expected exactly one state per direction '
                             '(forward and reverse).')
        outputs = jit.annotate(List[Tensor], [])
        # Per-direction states may be a (h, c) tuple or a single tensor.
        output_states = jit.annotate(List[Union[Tuple[Tensor, Tensor], Tensor]], [])
        for direction, state in zip(self.directions, states):
            out, out_state = direction(input_, state)
            outputs += [out]
            output_states += [out_state]
        return cat(outputs, -1), output_states