Source code for aihwkit.inference.compensation.drift

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# Licensed under the MIT license. See LICENSE file in the project root for details.

"""Global drift compensation for inference."""

from typing import Tuple, TYPE_CHECKING

from torch.autograd import no_grad
from torch import abs as torch_abs
from torch import clamp, Tensor, eye

from aihwkit.inference.compensation.base import BaseDriftCompensation

if TYPE_CHECKING:
    from aihwkit.simulator.tiles import InferenceTileWithPeriphery


[docs] class GlobalDriftCompensation(BaseDriftCompensation): """Global drift compensation. Uses a constant factor for compensating the drift. """
[docs] @no_grad() def readout(self, out_tensor: Tensor) -> Tensor: """Read outs the mean abs.""" return clamp(torch_abs(out_tensor).mean(), min=0.0001)
[docs] @no_grad() def get_readout_tensor(self, in_size: int) -> Tensor: """Return the read-out tensor. Uses the set of one-hot vectors (eye). """ return eye(in_size)
def __str__(self) -> str: return "{}()".format(self.__class__.__name__)
[docs] class GlobalDriftCompensationWithExactReference(GlobalDriftCompensation): """Global drift compensation using an exact (ideal) reference readout. Uses a constant factor for compensating the drift. """
[docs] @no_grad() def init_baseline(self, tile: "InferenceTileWithPeriphery") -> Tuple[Tensor, Tensor]: """Initialize the base line for applying the compensation. Uses a all one tensor for read_out. Args: tile: forward output of the read out vector to compensate Returns: reference tensor readout """ forward_output = tile._forward_drift_readout_tensor(True, exact_reference=True) ref_value = self.readout(forward_output) return ref_value
[docs] class PerColumnDriftCompensation(BaseDriftCompensation): """Per column drift compensation. Uses a vector for compensating the drift. """
[docs] @no_grad() def readout(self, out_tensor: Tensor) -> Tensor: """Read outs the per-column mean abs.""" return clamp(torch_abs(out_tensor).mean(dim=0), min=0.0001)
[docs] @no_grad() def get_readout_tensor(self, in_size: int) -> Tensor: """Return the read-out tensor. Uses the set of one-hot vectors (eye). """ return eye(in_size)
def __str__(self) -> str: return "{}()".format(self.__class__.__name__)