Source code for aihwkit.nn.modules.rnn.cells

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

""" Analog cells for RNNs. """

from typing import Optional, Tuple
from collections import namedtuple

from torch import Tensor, sigmoid, tanh, zeros

from aihwkit.nn import AnalogLinear, AnalogSequential
from aihwkit.simulator.configs import InferenceRPUConfig
from aihwkit.nn.modules.base import RPUConfigAlias

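# Named tuple bundling the LSTM hidden state ('hx') and cell state ('cx').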
LSTMState = namedtuple('LSTMState', ['hx', 'cx'])


class AnalogVanillaRNNCell(AnalogSequential):
    """Analog Vanilla RNN Cell.

    Args:
        input_size: in_features size for W_ih matrix
        hidden_size: in_features and out_features size for W_hh matrix
        bias: whether to use a bias row on the analog tile or not
        rpu_config: configuration for an analog resistive processing unit
        realistic_read_write: whether to enable realistic read/write
            for setting initial weights and reading out the weights
    """
    # pylint: disable=abstract-method

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            bias: bool,
            rpu_config: Optional[RPUConfigAlias] = None,
            realistic_read_write: bool = False,
    ):
        super().__init__()

        # Default to InferenceRPUConfig if no config is given.
        if not rpu_config:
            rpu_config = InferenceRPUConfig()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = AnalogLinear(input_size, hidden_size, bias=bias,
                                      rpu_config=rpu_config,
                                      realistic_read_write=realistic_read_write)
        self.weight_hh = AnalogLinear(hidden_size, hidden_size, bias=bias,
                                      rpu_config=rpu_config,
                                      realistic_read_write=realistic_read_write)
    def get_zero_state(self, batch_size: int) -> Tensor:
        """Returns a zeroed state.

        Args:
            batch_size: batch size of the input

        Returns:
            Zeroed state tensor
        """
        device = self.weight_ih.get_analog_tile_devices()[0]
        return zeros(batch_size, self.hidden_size, device=device)
    def forward(
            self,
            input_: Tensor,
            state: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # pylint: disable=arguments-differ
        igates = self.weight_ih(input_)
        hgates = self.weight_hh(state)
        out = tanh(igates + hgates)
        return out, out  # output will also be the hidden state
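
# Usage sketch (illustrative only; the sizes, batch, and random input below
# are arbitrary assumptions, not part of this module). It shows a single step
# through the cell; ``rpu_config`` defaults to ``InferenceRPUConfig()``:
#
#     from torch import randn
#
#     cell = AnalogVanillaRNNCell(input_size=8, hidden_size=16, bias=True)
#     state = cell.get_zero_state(batch_size=4)   # shape (4, 16)
#     x = randn(4, 8)                             # one time step of input
#     out, state = cell(x, state)                 # out is also the new state
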
class AnalogLSTMCell(AnalogSequential):
    """Analog LSTM Cell.

    Args:
        input_size: in_features size for W_ih matrix
        hidden_size: in_features and out_features size for W_hh matrix
        bias: whether to use a bias row on the analog tile or not
        rpu_config: configuration for an analog resistive processing unit
        realistic_read_write: whether to enable realistic read/write
            for setting initial weights and reading out the weights
    """
    # pylint: disable=abstract-method

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            bias: bool,
            rpu_config: Optional[RPUConfigAlias] = None,
            realistic_read_write: bool = False,
    ):
        super().__init__()

        # Default to InferenceRPUConfig if no config is given.
        if not rpu_config:
            rpu_config = InferenceRPUConfig()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = AnalogLinear(input_size, 4 * hidden_size, bias=bias,
                                      rpu_config=rpu_config,
                                      realistic_read_write=realistic_read_write)
        self.weight_hh = AnalogLinear(hidden_size, 4 * hidden_size, bias=bias,
                                      rpu_config=rpu_config,
                                      realistic_read_write=realistic_read_write)
    def get_zero_state(self, batch_size: int) -> LSTMState:
        """Returns a zeroed state.

        Args:
            batch_size: batch size of the input

        Returns:
            Zeroed LSTM state (hidden and cell state tensors)
        """
        device = self.weight_ih.get_analog_tile_devices()[0]
        return LSTMState(zeros(batch_size, self.hidden_size, device=device),
                         zeros(batch_size, self.hidden_size, device=device))
    def forward(
            self,
            input_: Tensor,
            state: Tuple[Tensor, Tensor]
    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        # pylint: disable=arguments-differ
        h_x, c_x = state
        gates = self.weight_ih(input_) + self.weight_hh(h_x)
        in_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)

        in_gate = sigmoid(in_gate)
        forget_gate = sigmoid(forget_gate)
        cell_gate = tanh(cell_gate)
        out_gate = sigmoid(out_gate)

        c_y = (forget_gate * c_x) + (in_gate * cell_gate)
        h_y = out_gate * tanh(c_y)

        return h_y, (h_y, c_y)
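
# Usage sketch (illustrative; sizes are arbitrary assumptions). The state is
# the ``LSTMState`` named tuple produced by ``get_zero_state``:
#
#     from torch import randn
#
#     cell = AnalogLSTMCell(input_size=8, hidden_size=16, bias=True)
#     state = cell.get_zero_state(batch_size=4)   # LSTMState(hx, cx), each (4, 16)
#     x = randn(4, 8)
#     out, (h_y, c_y) = cell(x, state)            # out == h_y
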
class AnalogGRUCell(AnalogSequential):
    """Analog GRU Cell.

    Args:
        input_size: in_features size for W_ih matrix
        hidden_size: in_features and out_features size for W_hh matrix
        bias: whether to use a bias row on the analog tile or not
        rpu_config: configuration for an analog resistive processing unit
        realistic_read_write: whether to enable realistic read/write
            for setting initial weights and reading out the weights
    """
    # pylint: disable=abstract-method

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            bias: bool,
            rpu_config: Optional[RPUConfigAlias] = None,
            realistic_read_write: bool = False,
    ):
        super().__init__()

        # Default to InferenceRPUConfig if no config is given.
        if not rpu_config:
            rpu_config = InferenceRPUConfig()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = AnalogLinear(input_size, 3 * hidden_size, bias=bias,
                                      rpu_config=rpu_config,
                                      realistic_read_write=realistic_read_write)
        self.weight_hh = AnalogLinear(hidden_size, 3 * hidden_size, bias=bias,
                                      rpu_config=rpu_config,
                                      realistic_read_write=realistic_read_write)
    def get_zero_state(self, batch_size: int) -> Tensor:
        """Returns a zeroed state.

        Args:
            batch_size: batch size of the input

        Returns:
            Zeroed state tensor
        """
        device = self.weight_ih.get_analog_tile_devices()[0]
        return zeros(batch_size, self.hidden_size, device=device)
    def forward(
            self,
            input_: Tensor,
            state: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # pylint: disable=arguments-differ
        g_i = self.weight_ih(input_)
        g_h = self.weight_hh(state)
        i_r, i_i, i_n = g_i.chunk(3, 1)
        h_r, h_i, h_n = g_h.chunk(3, 1)

        reset_gate = sigmoid(i_r + h_r)
        input_gate = sigmoid(i_i + h_i)
        new_gate = tanh(i_n + reset_gate * h_n)

        # Interpolate between candidate and previous state; equivalent to
        # (1 - input_gate) * new_gate + input_gate * state.
        hidden_y = new_gate + input_gate * (state - new_gate)

        return hidden_y, hidden_y  # output will also be the hidden state
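
# Usage sketch (illustrative; sizes are arbitrary assumptions). The GRU keeps
# a single hidden tensor as its state, so it is stepped like the vanilla RNN
# cell above:
#
#     from torch import randn
#
#     cell = AnalogGRUCell(input_size=8, hidden_size=16, bias=True)
#     state = cell.get_zero_state(batch_size=4)   # shape (4, 16)
#     x = randn(4, 8)
#     out, state = cell(x, state)                 # out coincides with the new state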