# -*- coding: utf-8 -*-
# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=too-many-locals, too-many-arguments
"""Low level implementation of torch-based tile."""
from typing import TYPE_CHECKING
from torch import Tensor
from aihwkit.exceptions import TorchTileConfigError
from aihwkit.simulator.tiles.torch_tile import TorchSimulatorTile
from aihwkit.simulator.tiles.analog_mvm_irdrop_t import AnalogMVMIRDropT
if TYPE_CHECKING:
from aihwkit.simulator.configs.configs import TorchInferenceRPUConfigIRDropT
class TorchSimulatorTileIRDropT(TorchSimulatorTile):
    """Torch simulator tile whose MVM uses the time-dependent IR-drop model.

    The tile behaves like its ``TorchSimulatorTile`` base class but installs
    ``AnalogMVMIRDropT`` as the analog matrix-vector-multiplication backend
    and hands it two extra settings read from the RPU configuration.

    Args:
        x_size: input size
        d_size: output size
        rpu_config: resistive processing unit configuration. Its
            ``mapping.max_input_size`` and ``noise_model.g_converter``
            fields are read here.
        bias: passed through unchanged to the base-class initializer.
    """

    # pylint: disable=abstract-method

    def __init__(
        self,
        x_size: int,
        d_size: int,
        rpu_config: "TorchInferenceRPUConfigIRDropT",
        bias: bool = False,
    ):
        # Both attributes are assigned ahead of the base-class initializer,
        # which keeps the original statement order.
        max_inputs = rpu_config.mapping.max_input_size
        # A falsy mapping limit (e.g. 0 or None) falls back to the logical
        # input size of this tile.
        self._phys_input_size = max_inputs if max_inputs else x_size
        self._g_converter = rpu_config.noise_model.g_converter

        super().__init__(x_size, d_size, rpu_config, bias, analog_mvm=AnalogMVMIRDropT)

    def forward(
        self,
        x_input: Tensor,
        bias: bool = False,
        in_trans: bool = False,
        out_trans: bool = False,
        is_test: bool = False,
        non_blocking: bool = False,
    ) -> Tensor:
        """Run the forward MVM through the IR-drop analog backend.

        Args:
            x_input: input tensor; its first dimension is handed to
                ``modify_weight`` when not in test mode.
            bias: accepted for signature compatibility; not used here.
            in_trans: must be ``False``; transposed input is unsupported.
            out_trans: must be ``False``; transposed output is unsupported.
            is_test: when ``True`` the stored weights are used as-is,
                otherwise they are first passed through ``modify_weight``.
            non_blocking: accepted for signature compatibility; not used here.

        Returns:
            The tensor produced by the analog MVM backend's ``matmul``.

        Raises:
            TorchTileConfigError: if ``in_trans`` or ``out_trans`` is set.
        """
        if in_trans or out_trans:
            raise TorchTileConfigError("Non-trans MVMs supported only.")

        # Weight modification is applied only outside of test mode.
        weights = (
            self.weight
            if is_test
            else self.modify_weight(self.weight, self._modifier, x_input.shape[0])
        )

        # NOTE(review): the literal ``False`` is presumably the backend's
        # transpose flag -- confirm against ``AnalogMVMIRDropT.matmul``.
        result = self._analog_mvm.matmul(
            weights,
            x_input,
            self._f_io,
            False,
            is_test,
            phys_input_size=self._phys_input_size,
            g_converter=self._g_converter,
            out_noise_values=self.out_noise_values,
        )
        return result