# Source code for aihwkit.simulator.tiles.inference_torch

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""High level analog tiles (inference)."""

from typing import Optional, Any, Dict, Tuple, TYPE_CHECKING

from torch import Tensor, zeros_like, clamp
from torch.autograd import no_grad, Function

from aihwkit.simulator.tiles.inference import InferenceTileWithPeriphery
from aihwkit.simulator.tiles.torch_tile import TorchSimulatorTile
from aihwkit.simulator.tiles.module import TileModule
from aihwkit.simulator.tiles.base import SimulatorTileWrapper

from aihwkit.exceptions import TorchTileConfigError, AnalogBiasConfigError, ArgumentError
from aihwkit.simulator.parameters.enums import WeightRemapType, WeightClipType

if TYPE_CHECKING:
    from torch.nn import BackwardHook
    from aihwkit.simulator.configs import TorchInferenceRPUConfig
    from aihwkit.simulator.parameters import InputRangeParameter

class InputRangeForward(Function):
    """Autograd function providing a custom gradient for the input range.

    The forward pass returns the input unchanged; the backward pass passes
    the output gradient straight through to the input and computes a
    separate gradient for ``input_range``.
    """

    # pylint: disable=abstract-method, redefined-builtin, arguments-differ

    @staticmethod
    def forward(
        ctx: Any, x_input: Tensor, input_range: Tensor, ir_params: "InputRangeParameter"
    ) -> Tensor:
        """Return ``x_input`` as is and store what ``backward`` needs.

        Args:
            ctx: torch autograd context.
            x_input: input tensor.
            input_range: input range tensor.
            ir_params: input range parameters, read in ``backward``.

        Returns:
            ``x_input``, unmodified.
        """
        ctx.ir_params = ir_params
        ctx.save_for_backward(x_input, input_range)
        return x_input

    @staticmethod
    def backward(ctx: Any, d_output: Tensor) -> Tuple[Tensor, Tensor, None]:
        """Compute the gradients w.r.t. the input and the input range.

        Args:
            ctx: torch autograd context filled in ``forward``.
            d_output: gradient w.r.t. the output of ``forward``.

        Returns:
            Tuple of ``d_output`` (gradient for ``x_input``), the gradient
            for ``input_range`` (``None`` if no input range was saved) and
            ``None`` for ``ir_params``.

        Raises:
            NotImplementedError: if ``manage_output_clipping`` is enabled.
        """
        x_input, input_range = ctx.saved_tensors
        if input_range is None:
            return d_output, None, None

        params = ctx.ir_params
        clipped_high = x_input >= input_range  # pylint: disable=invalid-unary-operand-type
        clipped_low = x_input <= -input_range  # pylint: disable=invalid-unary-operand-type

        # Only negative gradients at the upper bound and positive gradients
        # at the lower bound contribute (the clamps zero out the rest).
        grad_range = zeros_like(input_range)
        grad_range += clamp(clipped_high * d_output, min=None, max=0.0).sum()
        grad_range -= clamp(clipped_low * d_output, min=0.0, max=None).sum()

        if params.gradient_relative:
            grad_range *= input_range
        grad_range *= params.gradient_scale

        if params.manage_output_clipping:
            raise NotImplementedError

        if params.decay > 0:
            # Add a decay term proportional to the input range whenever the
            # fraction of inputs lying strictly inside the range exceeds
            # ``input_min_percentage``.
            inside_fraction = (x_input.abs() < input_range).float().mean()
            exceeds_minimum = inside_fraction > params.input_min_percentage
            grad_range += params.decay * input_range * exceeds_minimum

        return d_output, grad_range, None
class TorchInferenceTile(TileModule, InferenceTileWithPeriphery, SimulatorTileWrapper):
    """InferenceTile using a torch-based simulator tile (and not a tile from RPUCuda).

    Args:
        out_size: output size of the tile.
        in_size: input size of the tile.
        rpu_config: resistive processing unit configuration. A default
            ``TorchInferenceRPUConfig`` is created if omitted.
        bias: bias flag, forwarded to ``SimulatorTileWrapper``.
        in_trans: not supported, has to be ``False``.
        out_trans: not supported, has to be ``False``.

    Raises:
        TorchTileConfigError: if ``in_trans`` or ``out_trans`` is set.
        AnalogBiasConfigError: if the tile ends up with an analog bias.
    """

    supports_indexed: bool = False
    supports_ddp: bool = True

    def __init__(
        self,
        out_size: int,
        in_size: int,
        rpu_config: Optional["TorchInferenceRPUConfig"] = None,
        bias: bool = False,
        in_trans: bool = False,
        out_trans: bool = False,
    ):
        if in_trans or out_trans:
            raise TorchTileConfigError("in/out trans is not supported.")

        if not rpu_config:
            # Import dynamically to avoid import cycles.
            # pylint: disable=import-outside-toplevel
            from aihwkit.simulator.configs import TorchInferenceRPUConfig

            rpu_config = TorchInferenceRPUConfig()

        TileModule.__init__(self)
        SimulatorTileWrapper.__init__(
            self, out_size, in_size, rpu_config, bias, in_trans, out_trans, torch_update=True
        )
        InferenceTileWithPeriphery.__init__(self)

        if self.analog_bias:
            raise AnalogBiasConfigError("Analog bias is not supported for the torch tile")

        # Hooks for input range grad computation. Will not be saved in state_dict
        self._tile_input_grad_hook = None
        self._tile_input = None  # type: Tensor
        self._x_input_grad = None  # type: Tensor
        self._backward_hook_handle = None  # type: BackwardHook

    def _create_simulator_tile(  # type: ignore
        self, x_size: int, d_size: int, rpu_config: "TorchInferenceRPUConfig"
    ) -> "TorchSimulatorTile":
        """Create a simulator tile.

        Args:
            x_size: input size
            d_size: output size
            rpu_config: resistive processing unit configuration

        Returns:
            a simulator tile based on the specified configuration.
        """
        return rpu_config.simulator_tile_class(x_size=x_size, d_size=d_size, rpu_config=rpu_config)

    def _recreate_simulator_tile(  # type: ignore[override]
        self, x_size: int, d_size: int, rpu_config: "TorchInferenceRPUConfig"
    ) -> Any:  # just use Any instead of Union["SimulatorTile", tiles.AnalogTile, ..]
        """Re-create a simulator tile in __setstate__.

        The arguments are ignored: the already existing ``self.tile`` is
        returned as is.

        Args:
            x_size: input size
            d_size: output size
            rpu_config: resistive processing unit configuration

        Returns:
            the current simulator tile.
        """
        return self.tile

    def init_input_processing(self) -> bool:
        """Helper function to initialize the input processing.

        Note:
            This method is called from the constructor.

        Returns:
            whether input processing is enabled

        Raises:
            ConfigError in case ``manage_output_clipping`` is
            enabled but not supported.
        """
        return bool(super().init_input_processing())

    def set_scales(self, scales: Tensor) -> None:
        """Set all scales with a new scale.

        This will set the mapping scales to ``scales`` and set all other scales to 1.

        In addition, a weight gradient hook dividing the gradient by the
        squared scales is (re-)registered on the tile if the tile weight
        requires a gradient.

        Args:
            scales: scales to set.
        """
        super().set_scales(scales)

        # - Remove old hook
        if self._backward_hook_handle is not None:
            self._backward_hook_handle.remove()

        def hook(grad: Tensor) -> Tensor:
            # NOTE(review): this expression was truncated in the extracted
            # source ("grad /, 1) ** 2"); reconstructed as a division by the
            # squared scales viewed as a column -- confirm against upstream.
            return grad / scales.view(-1, 1) ** 2

        if self.tile.weight.requires_grad:
            self._backward_hook_handle = self.tile.register_weight_hook(hook)

    def pre_forward(
        self, x_input: Tensor, dim: int, is_test: bool = False, ctx: Any = None
    ) -> Tensor:
        """Operations before the actual forward step for pre processing.

        If an input range is set, the input range is applied to the input,
        the input is routed through :class:`InputRangeForward` (custom input
        range gradient) and finally divided by the input range. Without an
        input range the input is returned unchanged.

        Args:
            x_input: input tensor for the analog MVM of the tile.
            dim: input channel dimension, ie the x_size dimension (unused here)
            is_test: whether in eval mode
            ctx: torch auto-grad context [Optional] (unused here)

        Returns:
            Output tensor of the same shape
        """
        # pylint: disable=unused-argument

        if self.input_range is not None:
            x_input = self.apply_input_range(x_input, not is_test)
            x_input = InputRangeForward.apply(
                x_input, self.input_range, self.rpu_config.pre_post.input_range
            )

        if self.input_range is not None:
            x_input = x_input / self.input_range
        return x_input

    def forward(self, x_input: Tensor, tensor_view: Optional[Tuple] = None) -> Tensor:
        """Torch forward function that calls the analog forward.

        Args:
            x_input: input tensor.
            tensor_view: view used for the output scaling and the digital
                bias. Derived from the output dimension if not given.

        Returns:
            The scaled output, with the digital bias added if enabled.
        """
        # pylint: disable=arguments-differ

        # Note: this is now called with autograd enabled and thus will
        # not use the BaseTile.backward functionality. It will call
        # the tile.forward pass internally

        self.tile.set_config(self.rpu_config)  # to allow on-the-fly changes

        # NOTE(review): the ``is_test`` argument was truncated in the
        # extracted source ("is_test=not"); reconstructed as the negated
        # module training flag -- confirm against upstream.
        out = self.joint_forward(x_input, is_test=not self.training)

        if tensor_view is None:
            tensor_view = self.get_tensor_view(out.dim())
        out = self.apply_out_scaling(out, tensor_view)

        if self.digital_bias:
            return out + self.bias.view(*tensor_view)
        return out

    @no_grad()
    def post_update_step(self) -> None:
        """Clip and remap weights after weights have been updated.

        Clipping and remapping are each skipped if the RPU config has no
        corresponding attribute or its type is ``NONE``. After remapping, the
        scales returned by the tile are set as the new scales.
        """
        if hasattr(self.rpu_config, "clip") and self.rpu_config.clip.type != WeightClipType.NONE:
            self.tile.clip_weights(self.rpu_config.clip)

        if hasattr(self.rpu_config, "remap") and self.rpu_config.remap.type != WeightRemapType.NONE:
            scales = self.get_scales()
            scales = self.tile.remap_weights(self.rpu_config.remap, scales)
            self.set_scales(scales)

    def get_forward_parameters(self) -> Dict[str, Tensor]:
        """Get the additional parameters generated for the forward pass.

        Returns:
            Dictionary of the forward parameters set. Empty if the tile has
            no ``out_noise_values``.
        """
        dic = {}  # type: Dict[str, Tensor]
        if self.tile.out_noise_values is not None:
            dic["out_noise_values"] = self.tile.out_noise_values
        return dic

    def set_forward_parameters(
        self, dic: Optional[Dict[str, Tensor]] = None, **kwargs: Tensor
    ) -> None:
        """Set the additional parameters generated for the forward pass.

        Currently only ``out_noise_values`` is implemented.

        Args:
            dic: dictionary of parameters to set (from :meth:`get_forward_parameters`)
            kwargs: parameter names can alternatively given directly as keywords

        Raises:
            ArgumentError: If size are mismatched or keyword unknown
        """
        if dic is None:
            dic = kwargs

        par_lst = ["out_noise_values"]
        for par in par_lst:
            if par in dic:
                current_value = getattr(self.tile, par)
                new_value = dic[par]
                if not isinstance(new_value, Tensor):
                    raise ArgumentError(f"{par} type mismatch. Expected tensor!")
                if current_value is None or current_value.size() != new_value.size():
                    raise ArgumentError(f"{par} size mismatch!")
                setattr(self.tile, par, new_value.reshape(*current_value.shape))

        # The given keys have to match the supported parameters exactly, so a
        # missing supported key raises here as well as an unknown one.
        if set(par_lst) != set(dic):
            raise ArgumentError("Unknown parameter keys given!")