# -*- coding: utf-8 -*-
# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
""" Analog RNN modules. """
import warnings
import math
from typing import Any, List, Optional, Tuple, Type, Callable
from torch import Tensor, jit
from torch.nn import Dropout, ModuleList, init, Module, Linear
from torch.autograd import no_grad
from aihwkit.nn.modules.container import AnalogContainerBase
from aihwkit.nn.modules.linear import AnalogLinear
from aihwkit.simulator.parameters.base import RPUConfigBase
from .layers import AnalogRNNLayer, AnalogBidirRNNLayer
class ModularRNN(Module):
    """Helper class to create a modular, stacked RNN.

    The first layer is built as ``layer(*first_layer_args)`` and every
    following layer as ``layer(*other_layer_args)``; the layers are applied
    serially in :meth:`forward`.

    Args:
        num_layers: number of serially connected RNN layers (must be >= 1)
        layer: RNN layer type (e.g. AnalogLSTMLayer)
        dropout: dropout probability in ``[0, 1]`` applied to the output of
            all RNN layers except the last
        first_layer_args: positional arguments of the first layer
            (RNNCell type, input_size, hidden_size, rpu_config, etc.)
        other_layer_args: positional arguments of all remaining layers
            (RNNCell type, hidden_size, hidden_size, rpu_config, etc.)

    Raises:
        ValueError: if ``num_layers`` is smaller than 1 or ``dropout`` is
            not in ``[0, 1]``.
    """

    # pylint: disable=abstract-method

    # Necessary for iterating through self.layers and dropout support
    __constants__ = ["layers", "num_layers"]

    def __init__(
        self,
        num_layers: int,
        layer: Type,
        dropout: float,
        first_layer_args: Any,
        other_layer_args: Any,
    ):
        super().__init__()

        # Without this check ``num_layers <= 0`` would still build one layer
        # (the first layer is always created) while ``self.num_layers``
        # disagreed with ``len(self.layers)``.
        if num_layers < 1:
            raise ValueError(
                f"num_layers must be greater than or equal to 1, got {num_layers}"
            )
        # A negative dropout used to be silently ignored and values > 1 were
        # only rejected later by ``Dropout``; validate up front instead.
        if not 0.0 <= dropout <= 1.0:
            raise ValueError(
                f"dropout should be a number in range [0, 1], got {dropout}"
            )

        self.layers = self.init_stacked_analog_lstm(
            num_layers, layer, first_layer_args, other_layer_args
        )

        # Introduce a Dropout layer on the outputs of each RNN layer except
        # the last layer.
        self.num_layers = num_layers
        if num_layers == 1 and dropout > 0:
            warnings.warn(
                "dropout lstm adds dropout layers after all but last "
                "recurrent layer, it expects num_layers greater than "
                "1, but got num_layers = 1"
            )
        self.dropout_layer = Dropout(dropout) if dropout > 0.0 else None

    @staticmethod
    def init_stacked_analog_lstm(
        num_layers: int, layer: Type, first_layer_args: Any, other_layer_args: Any
    ) -> ModuleList:
        """Construct the list of RNN layers over which to iterate.

        Args:
            num_layers: number of serially connected RNN layers
            layer: RNN layer type (e.g. AnalogLSTMLayer)
            first_layer_args: RNNCell type, input_size, hidden_size, rpu_config, etc.
            other_layer_args: RNNCell type, hidden_size, hidden_size, rpu_config, etc.

        Returns:
            torch.nn.ModuleList holding ``layer(*first_layer_args)`` followed
            by ``num_layers - 1`` instances of ``layer(*other_layer_args)``.
            Note that the first layer is always created, even for
            ``num_layers < 1``.
        """
        layers = [layer(*first_layer_args)] + [
            layer(*other_layer_args) for _ in range(num_layers - 1)
        ]
        return ModuleList(layers)

    def get_zero_state(self, batch_size: int) -> List[Tensor]:
        """Returns a zeroed state.

        Args:
            batch_size: batch size of the input

        Returns:
            List with the zeroed state of each layer, in layer order (as
            returned by each layer's own ``get_zero_state``).
        """
        return [lay.get_zero_state(batch_size) for lay in self.layers]

    def forward(  # pylint: disable=arguments-differ
        self, input: Tensor, states: List  # pylint: disable=redefined-builtin
    ) -> Tuple[Tensor, List]:
        """Forward pass through all stacked layers.

        Args:
            input: input tensor of the first layer
            states: one state per layer, indexed in layer order

        Returns:
            Tuple of the output of the last layer and the list of new
            states (one per layer).
        """
        # List[RNNState]: One state per layer.
        output_states = jit.annotate(List, [])
        output = input
        for i, rnn_layer in enumerate(self.layers):
            state = states[i]
            output, out_state = rnn_layer(output, state)
            # Apply the dropout layer except the last layer.
            if i < self.num_layers - 1 and self.dropout_layer is not None:
                output = self.dropout_layer(output)
            output_states += [out_state]
        return output, output_states
class AnalogRNN(AnalogContainerBase, Module):
    """Modular RNN that uses analog tiles.

    Args:
        cell: type of Analog RNN cell (AnalogLSTMCell/AnalogGRUCell/AnalogVanillaRNNCell)
        input_size: in_features to W_{ih} matrix of first layer
        hidden_size: in_features and out_features for W_{hh} matrices
        bias: whether to use a bias row on the analog tile or not
        rpu_config: configuration for an analog resistive processing
            unit; forwarded unchanged to every RNN layer (see the layer /
            tile classes for how ``None`` is handled).
        tile_module_class: Class for the analog tile module (default
            will be specified from the ``RPUConfig``).
        xavier: if ``True`` use Xavier initialization, otherwise (default)
            the standard PyTorch RNN-style uniform initialization
        num_layers: number of serially connected RNN layers
        bidir: if True, becomes a bidirectional RNN
        dropout: dropout applied to output of all RNN layers except last
    """

    # pylint: disable=abstract-method, too-many-arguments

    def __init__(
        self,
        cell: Type,
        input_size: int,
        hidden_size: int,
        bias: bool = True,
        rpu_config: Optional[RPUConfigBase] = None,
        tile_module_class: Optional[Type] = None,
        xavier: bool = False,
        num_layers: int = 1,
        bidir: bool = False,
        dropout: float = 0.0,
    ):
        super().__init__()

        if bidir:
            layer = AnalogBidirRNNLayer
            num_dirs = 2
        else:
            layer = AnalogRNNLayer
            num_dirs = 1

        # Layers after the first consume the previous layer's output, whose
        # width is ``num_dirs * hidden_size`` (presumably both directions
        # concatenated for the bidirectional layer).
        self.rnn = ModularRNN(
            num_layers,
            layer,
            dropout,
            first_layer_args=[cell, input_size, hidden_size, bias, rpu_config, tile_module_class],
            other_layer_args=[
                cell,
                num_dirs * hidden_size,
                hidden_size,
                bias,
                rpu_config,
                tile_module_class,
            ],
        )
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # ``reset_parameters`` reads ``self.hidden_size``, so it must run last.
        self.reset_parameters(xavier)

    @no_grad()
    def init_layers(
        self, weight_init_fn: Callable, bias_init_fn: Optional[Callable] = None
    ) -> None:
        """Init the analog layers with custom functions.

        Args:
            weight_init_fn: in-place tensor function applied to weight of
                ``AnalogLinear`` layers
            bias_init_fn: in-place tensor function applied to bias of
                ``AnalogLinear`` layers

        Note:
            If no bias init function is provided the weight init
            function is taken for the bias as well. Plain ``torch.nn.Linear``
            submodules, if any, are initialized with the same functions.
        """

        def init_weight_and_bias(weight: Tensor, bias: Optional[Tensor]) -> None:
            """Init the weight and (if present) the bias in place."""
            weight_init_fn(weight.data)
            if bias is not None:
                if bias_init_fn is None:
                    weight_init_fn(bias.data)
                else:
                    bias_init_fn(bias.data)

        for module in self.modules():
            # The analog check comes first: analog weights must be read out
            # with ``get_weights`` and written back with ``set_weights``.
            if isinstance(module, AnalogLinear):
                weight, bias = module.get_weights()
                init_weight_and_bias(weight, bias)
                module.set_weights(weight, bias)
            elif isinstance(module, Linear):
                # init torch layers if any
                init_weight_and_bias(module.weight, module.bias)

    def reset_parameters(self, xavier: bool = False) -> None:
        """Weight and bias initialization.

        Args:
            xavier: if ``True``, weights are initialized with
                ``xavier_uniform_`` and biases with zeros. Otherwise
                (default), weights and biases are both drawn from
                ``U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))``, as in
                standard PyTorch RNNs.
        """
        if xavier:
            self.init_layers(init.xavier_uniform_, init.zeros_)
        else:
            stdv = 1.0 / math.sqrt(self.hidden_size)
            self.init_layers(lambda x: x.uniform_(-stdv, stdv))

    def get_zero_state(self, batch_size: int) -> List[Tensor]:
        """Returns a zeroed RNN state based on cell type and layer type.

        Args:
            batch_size: batch size of the input

        Returns:
            List of zeroed state tensors for each layer
        """
        return self.rnn.get_zero_state(batch_size)

    def forward(
        self, input: Tensor, states: Optional[List] = None  # pylint: disable=redefined-builtin
    ) -> Tuple[Tensor, List]:
        """Forward pass.

        Args:
            input: input tensor. When ``states`` is omitted, the batch size
                is read from dimension 1, i.e. a sequence-first layout is
                assumed.
            states: list with one state per layer; zeroed states are used
                when ``None``

        Returns:
            outputs and states
        """
        if states is None:
            # TODO: batch_first.
            states = self.get_zero_state(input.shape[1])
        return self.rnn(input, states)