Source code for aihwkit.nn.low_precision_modules.quantized_input_module

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# Licensed under the MIT license. See LICENSE file in the project root for details.

# mypy: disable-error-code=attr-defined

"""Module wrapper to quantize its input"""

from typing import Any

from torch import Tensor
from torch.nn import Module

from aihwkit.simulator.digital_low_precision.base_quantized_classes import QuantizedActivation
from aihwkit.simulator.digital_low_precision.config_utils import convert_act_config_to_kwargs_dict
from aihwkit.simulator.parameters.quantization import ActivationQuantConfig


class QuantizedInputModule(Module):
    """Wrap a module so that its input is quantized before it is processed.

    An activation quantizer is placed in front of the wrapped module, thus
    offering the capability to quantize both the inputs and the outputs of
    that module. (NOTE: output quantization is not performed here; it is
    considered to be taken care of by the activation quantization of the
    module being wrapped.)

    Typical use cases are a first layer that consumes data directly from the
    dataloader, or a layer that follows a functional operation (e.g., an
    activation function or an addition) which did not quantize its own
    output down to the required size.

    Args:
        module: The original module whose input should be quantized.
        act_quant_config: Activation quantization configuration; it is
            converted into keyword arguments for ``QuantizedActivation``.
    """

    def __init__(self, module: Module, act_quant_config: ActivationQuantConfig):
        super().__init__()

        # Translate the config object into constructor keyword arguments
        quantizer_kwargs = convert_act_config_to_kwargs_dict(act_quant_config)

        # Keep the wrapped module registered first and the quantizer second,
        # so the submodule ordering stays as it has always been
        self.module = module
        self.input_quantizer = QuantizedActivation(**quantizer_kwargs)

    def forward(self, inp: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        """Quantize ``inp`` and run the wrapped module on the result.

        Only ``inp`` goes through the input quantizer; any additional
        positional or keyword arguments are handed to the wrapped module
        unchanged.

        Args:
            inp: Input tensor to be quantized.
            *args: Extra positional arguments for the wrapped module.
            **kwargs: Extra keyword arguments for the wrapped module.

        Returns:
            Whatever the wrapped module returns for the quantized input.
        """
        return self.module(self.input_quantizer(inp), *args, **kwargs)