# -*- coding: utf-8 -*-
# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# Licensed under the MIT license. See LICENSE file in the project root for details.
# Copyright (c) 2021 Qualcomm Technologies, Inc.
# All Rights Reserved.
# mypy: disable-error-code=attr-defined
# pylint: disable=not-callable
"""Basic quantized modules"""
from typing import Any, Optional
from torch import Tensor, nn
from torch.nn.functional import linear, conv2d, layer_norm, embedding
from aihwkit.simulator.digital_low_precision.base_quantized_classes import (
FP32Acts,
QuantizedActivation,
)
from aihwkit.simulator.digital_low_precision.base_quantized_model import QuantizedModel
from aihwkit.simulator.digital_low_precision.hijacker import QuantizationHijacker
[docs]
class QuantLinear(QuantizationHijacker, nn.Linear):
    """``torch.nn.Linear`` variant with weight/activation quantization.

    The layer mixes ``QuantizationHijacker`` into ``nn.Linear``; this class only
    supplies the actual linear computation via :meth:`run_forward`.
    """

    def run_forward(
        self, x: Tensor, weight: Tensor, bias: Tensor, offsets: Optional[Any] = None
    ) -> Tensor:
        """Compute the linear transformation of ``x``.

        Args:
            x: Input tensor; made contiguous before the call.
            weight: Weight tensor to apply; made contiguous before the call.
                (Presumably already quantized by the hijacker -- confirm there.)
            bias: Bias forwarded unchanged to ``linear``.
            offsets: Accepted for interface compatibility; not used here.

        Returns:
            The output of ``torch.nn.functional.linear``.
        """
        dense_input = x.contiguous()
        dense_weight = weight.contiguous()
        return linear(dense_input, dense_weight, bias=bias)
[docs]
class QuantConv2d(QuantizationHijacker, nn.Conv2d):
    """Quantized layer of torch.nn.Conv2d with weight/act quantization"""

    def run_forward(
        self, x: Tensor, weight: Tensor, bias: Tensor, offsets: Optional[Any] = None
    ) -> Tensor:
        """Compute the 2D convolution of ``x`` with the given parameters.

        The layer's ``padding_mode`` is honoured the same way ``nn.Conv2d`` does:
        for any mode other than ``"zeros"`` the input is padded explicitly and the
        convolution itself is run without padding.

        Args:
            x: Input tensor.
            weight: Convolution kernel to apply; made contiguous before the call.
                (Presumably already quantized by the hijacker -- confirm there.)
            bias: Bias forwarded unchanged to ``conv2d``.
            offsets: Accepted for interface compatibility; not used here.

        Returns:
            The output of ``torch.nn.functional.conv2d`` using this layer's
            stride, padding, dilation and groups.
        """
        padding = self.padding
        if self.padding_mode != "zeros":
            # ``conv2d`` can only zero-pad, so non-zero padding modes have to be
            # applied up front; otherwise e.g. "reflect" would silently behave
            # like "zeros".
            x = nn.functional.pad(
                x, self._reversed_padding_repeated_twice, mode=self.padding_mode
            )
            padding = 0

        return conv2d(
            x.contiguous(),
            weight.contiguous(),
            bias,
            self.stride,
            padding,
            self.dilation,
            self.groups,
        )
[docs]
class QuantLayerNorm(QuantizationHijacker, nn.LayerNorm):
    """Quantized layer of torch.nn.LayerNorm with input and weight quantization"""

    def run_forward(
        self,
        x: Tensor,
        weight: Optional[Tensor],
        bias: Optional[Tensor],
        offsets: Optional[Any] = None,
    ) -> Tensor:
        """Apply layer normalization to ``x``.

        Args:
            x: Input tensor; made contiguous before the call.
            weight: Elementwise scale, or ``None`` when the layer has no affine
                weight (e.g. ``elementwise_affine=False``).
            bias: Elementwise shift, or ``None`` when the layer has no bias.
            offsets: Accepted for interface compatibility; not used here.

        Returns:
            The output of ``torch.nn.functional.layer_norm`` using this layer's
            ``normalized_shape`` and ``eps``.
        """
        # ``layer_norm`` accepts ``None`` for weight/bias, but calling
        # ``.contiguous()`` on ``None`` would raise, so guard both.
        return layer_norm(
            input=x.contiguous(),
            normalized_shape=self.normalized_shape,
            weight=None if weight is None else weight.contiguous(),
            bias=None if bias is None else bias.contiguous(),
            eps=self.eps,
        )
[docs]
class QuantEmbedding(QuantizationHijacker, nn.Embedding):
    """Embedding layer with weight quantization only.

    Note: activations are deliberately left unquantized. An embedding is just a
    lookup into the (already quantized) weight table, so its outputs are already
    quantized values.
    """

    def __init__(self, *args: Any, activation: Optional[Any] = None, **kwargs: Any):
        """Build the layer and replace its activation quantizer with ``FP32Acts``.

        Args:
            *args: Positional arguments forwarded to the parent constructors.
            activation: Forwarded to the parent constructors as ``activation``.
            **kwargs: Keyword arguments forwarded to the parent constructors.
        """
        super().__init__(*args, activation=activation, **kwargs)
        # The lookup output comes straight from the weight table, so the
        # activation quantizer is swapped for ``FP32Acts``.
        self.activation_quantizer = FP32Acts()

    def run_forward(
        self, x: Tensor, weight: Tensor, bias: Tensor, offsets: Optional[Any] = None
    ) -> Tensor:
        """Look up the rows of ``weight`` selected by ``x``.

        Args:
            x: Index tensor; made contiguous before the call.
            weight: Embedding table; made contiguous before the call.
            bias: Accepted for interface compatibility; not used here.
            offsets: Accepted for interface compatibility; not used here.

        Returns:
            The output of ``torch.nn.functional.embedding`` using this layer's
            ``padding_idx``, ``max_norm``, ``norm_type``, ``scale_grad_by_freq``
            and ``sparse`` settings.
        """
        indices = x.contiguous()
        table = weight.contiguous()
        return embedding(
            indices,
            table,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
[docs]
class QuantBatchNorm2d(QuantizedModel):
    """Wrapper for a BatchNorm2d module that quantizes the module's output activations."""

    def __init__(self, org_model: nn.Module, **quant_params: Any):
        """Wrap ``org_model`` and create the output activation quantizer.

        Args:
            org_model: Module to wrap; stored as ``self.module``.
            **quant_params: Keyword arguments forwarded to ``QuantizedActivation``.
        """
        super().__init__()
        self.module = org_model
        self.act_bn_quantizer = QuantizedActivation(**quant_params)

    def forward(self, x: Tensor) -> Tensor:
        """Run the wrapped module on ``x`` and quantize its output.

        Args:
            x: Input tensor handed to the wrapped module.

        Returns:
            The wrapped module's output after passing through ``act_bn_quantizer``.
        """
        return self.act_bn_quantizer(self.module(x))