Source code for aihwkit.simulator.tiles.floating_point

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""High level analog tiles (floating point)."""

from typing import Optional, Tuple, TYPE_CHECKING

from torch import Tensor

from aihwkit.simulator.rpu_base import tiles
from aihwkit.simulator.tiles.rpucuda import RPUCudaSimulatorTileWrapper
from aihwkit.simulator.tiles.module import TileModule
from aihwkit.simulator.tiles.periphery import TileWithPeriphery
from aihwkit.simulator.tiles.functions import AnalogFunction
from aihwkit.simulator.parameters.base import RPUConfigGeneric

if TYPE_CHECKING:
    from aihwkit.simulator.configs import FloatingPointRPUConfig


class FloatingPointTile(TileModule, TileWithPeriphery, RPUCudaSimulatorTileWrapper):
    r"""Floating point tile.

    Implements a floating point or ideal analog tile. A linear layer with
    this tile is perfectly linear; it simply uses the RPUCuda library for
    execution.

    **Forward pass**:

    .. math:: \mathbf{y} = W\mathbf{x}

    :math:`W` are the weights, :math:`\mathbf{x}` is the input vector and
    :math:`\mathbf{y}` is the output of the vector-matrix multiplication.
    Note that if bias is used, :math:`\mathbf{x}` is concatenated with 1 so
    that the last column of :math:`W` are the biases.

    **Backward pass**:

    Typical backward pass with transposed weights:

    .. math:: \mathbf{d}' = W^T\mathbf{d}

    where :math:`\mathbf{d}` is the error vector and :math:`\mathbf{d}'` is
    the output of the backward matrix-vector multiplication.

    **Weight update**:

    Usual learning rule for back-propagation:

    .. math:: w_{ij} \leftarrow w_{ij} + \lambda d_i\,x_j

    **Decay**:

    .. math:: w_{ij} \leftarrow w_{ij}(1 - \alpha r_\text{decay})

    Weight decay is applied by explicitly calling the analog tile decay.
    Note that the ``life_time`` parameter is set during initialization,
    while :math:`\alpha` is a scaling factor that can be given at run-time.

    **Diffusion**:

    .. math:: w_{ij} \leftarrow w_{ij} + \xi\,r_\text{diffusion}

    Similar to the decay, diffusion is only applied when explicitly called.
    However, the parameters of the diffusion process are set during
    initialization and are fixed for the remainder. :math:`\xi` is a
    standard Gaussian random variable.

    Args:
        out_size: output vector size of the tile, i.e. the dimension of
            :math:`\mathbf{y}` in case of :math:`\mathbf{y} = W\mathbf{x}`
            (or equivalently the dimension of the :math:`\boldsymbol{\delta}`
            of the backward pass).
        in_size: input vector size, i.e. the dimension of the vector
            :math:`\mathbf{x}` in case of :math:`\mathbf{y} = W\mathbf{x}`.
        rpu_config: resistive processing unit configuration.
        bias: whether to add a bias column to the tile, i.e. :math:`W` has an
            extra column to code the biases. Internally, the input
            :math:`\mathbf{x}` will be automatically expanded by an extra
            dimension which will always be set to 1.
        in_trans: whether to assume a transposed input (batch first).
        out_trans: whether to assume a transposed output (batch first).
    """

    def __init__(
        self,
        out_size: int,
        in_size: int,
        rpu_config: "FloatingPointRPUConfig",
        bias: bool = False,
        in_trans: bool = False,
        out_trans: bool = False,
    ):
        TileModule.__init__(self)
        RPUCudaSimulatorTileWrapper.__init__(
            self, out_size, in_size, rpu_config, bias, in_trans, out_trans  # type: ignore
        )
        TileWithPeriphery.__init__(self)

    def _create_simulator_tile(
        self, x_size: int, d_size: int, rpu_config: RPUConfigGeneric
    ) -> tiles.FloatingPointTile:
        """Create a simulator tile.

        Args:
            x_size: input size
            d_size: output size
            rpu_config: resistive processing unit configuration

        Returns:
            a simulator tile based on the specified configuration.
        """
        meta_parameter = rpu_config.device.as_bindings(self.get_data_type())

        return meta_parameter.create_array(x_size, d_size)
    def forward(
        self, x_input: Tensor, tensor_view: Optional[Tuple] = None  # type: ignore
    ) -> Tensor:
        """Torch forward function that calls the analog forward."""
        # pylint: disable=arguments-differ

        out = AnalogFunction.apply(
            self.get_analog_ctx(), self, x_input, self.shared_weights, not self.training
        )

        if tensor_view is None:
            tensor_view = self.get_tensor_view(out.dim())
        out = self.apply_out_scaling(out, tensor_view)

        if self.digital_bias:
            return out + self.bias.view(*tensor_view)
        return out
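

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the library API above): how a
# ``FloatingPointTile`` might be constructed and used for a forward pass. It
# assumes ``FloatingPointRPUConfig`` from ``aihwkit.simulator.configs``, the
# ``set_weights`` helper inherited from ``TileWithPeriphery``, and that the
# tile can be called like any ``torch.nn.Module``; the sizes and tensor
# shapes are arbitrary example values.
# ---------------------------------------------------------------------------
def _example_forward_pass() -> None:
    """Construct an ideal (floating point) tile and run y = W x on a batch."""
    import torch
    from aihwkit.simulator.configs import FloatingPointRPUConfig

    # Tile computing y = W x with W of shape [out_size, in_size] = [3, 5].
    tile = FloatingPointTile(out_size=3, in_size=5, rpu_config=FloatingPointRPUConfig())

    # Program known weights onto the tile; the floating point tile is ideal,
    # so no programming noise is added.
    weights = torch.randn(3, 5)
    tile.set_weights(weights)

    # Forward pass: x of shape [batch, in_size] -> y of shape [batch, out_size].
    x_batch = torch.randn(2, 5)
    y_batch = tile(x_batch)
    assert y_batch.shape == (2, 3)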
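

# ---------------------------------------------------------------------------
# Worked example of the decay and diffusion rules from the class docstring,
# written in plain torch arithmetic rather than through the tile API:
# w <- w * (1 - alpha * r_decay) and w <- w + xi * r_diffusion, with xi a
# standard Gaussian sample. The rates used here are arbitrary example values.
# ---------------------------------------------------------------------------
def _example_decay_and_diffusion() -> None:
    """Apply one explicit decay and one diffusion step to a weight matrix."""
    import torch

    r_decay, r_diffusion, alpha = 0.01, 0.02, 1.0
    weights = torch.randn(3, 5)

    # Decay: every weight shrinks towards zero by the (alpha-scaled) decay rate.
    weights = weights * (1.0 - alpha * r_decay)

    # Diffusion: add independent Gaussian noise scaled by the diffusion rate.
    weights = weights + torch.randn_like(weights) * r_diffusion


if __name__ == "__main__":
    _example_forward_pass()
    _example_decay_and_diffusion()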