# Source code for aihwkit.simulator.tiles.utils

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Low level implementation of torch-based tile."""

from typing import Union, Tuple
from numpy import ndarray

from torch import Tensor, tensor
from torch import isinf as torch_isinf

from torch.autograd.function import FunctionCtx, InplaceFunction


class UniformQuantize(InplaceFunction):
    """Uniform quantization function with a straight-through gradient.

    The forward pass snaps every element of the input onto a uniform grid
    whose step size is derived from ``res`` and ``bound``; the backward pass
    hands the incoming gradient through unchanged.
    """

    # pylint: disable=abstract-method, redefined-builtin, arguments-differ

    @staticmethod
    def forward(
        ctx: FunctionCtx, inp: Tensor, res: float, bound: float, stochastic: bool = False
    ) -> Tensor:
        """Quantizes the input tensor and performs straight-through estimation.

        The grid step is ``res * 2 * bound``, where a ``res`` larger than one
        is first interpreted as a number of states and inverted. The input
        tensor itself is not modified; a new tensor is returned.

        Args:
            ctx (FunctionCtx): Context. The ``stochastic`` flag is stored on it.
            inp (torch.Tensor): Input to be discretized.
            res (float): Resolution. Values ``> 1.0`` are treated as the number
                of states (``1 / res`` is used), values in ``(0, 1]`` are used
                as given.
            bound (float): Input bounds w.r.t. which we quantize.
            stochastic (bool, optional): Stochastic rounding? Defaults to False.

        Returns:
            torch.Tensor: Quantized input.

        Raises:
            AssertionError: If the resolution is not strictly positive.
        """
        # - Compute 1 / states if the number of states are provided
        res = 1 / res if res > 1.0 else res
        # Raised explicitly (instead of ``assert``) so that the check is not
        # stripped when Python runs with ``-O``; the exception type is kept.
        if not res > 0:
            raise AssertionError("resolution is <= 0")

        # - Scale res by range. A separate local is used so that ``res`` is
        #   never updated in place.
        step = res * (2 * bound)

        # The division is out-of-place and already yields a fresh tensor, so
        # no explicit clone of the input is required.
        output = inp / step

        ctx.stochastic = stochastic
        if ctx.stochastic:
            # - Stochastic rounding: dither by U(-0.5, 0.5) before rounding
            noise = output.new_empty(output.shape).uniform_(-0.5, 0.5)
            output.add_(noise)

        # - Round to the nearest grid index, then scale back down
        output = output.round()
        output *= step
        return output

    @staticmethod
    def backward(ctx: FunctionCtx, grad_output: Tensor) -> Tuple[Tensor, None, None, None]:
        """Straight-through estimator.

        Args:
            ctx: Context.
            grad_output: Gradient w.r.t. the output of forward.

        Returns:
            Gradients w.r.t. inputs to forward: ``grad_output`` unchanged for
            ``inp`` and ``None`` for ``res``, ``bound`` and ``stochastic``.
        """
        # - Straight-through estimator: rounding is treated as the identity
        return grad_output, None, None, None
def isinf(x: Union[float, str, Tensor, ndarray]) -> Tensor:
    """Checks if the input is inf.

    Args:
        x (Union[float, str, torch.Tensor, ndarray]): Input. A string is
            parsed with ``float()`` first (e.g. ``"inf"`` or ``"-inf"``).

    Returns:
        torch.Tensor: Boolean tensor where tensor is inf.

    Raises:
        ValueError: If ``x`` is a string that ``float()`` cannot parse.
    """
    if isinstance(x, Tensor):
        # Already a tensor: test it directly instead of copy-constructing a
        # new one, which is wasteful and triggers a PyTorch UserWarning.
        return torch_isinf(x)
    if isinstance(x, str):
        # ``torch.tensor`` cannot be built from a string; convert it first.
        x = float(x)
    return torch_isinf(tensor(x))