Source code for aihwkit.simulator.digital_low_precision.base_quantized_classes

# Copyright (c) 2021 Qualcomm Technologies, Inc.
# All Rights Reserved.

# pylint: skip-file
# type: ignore

from torch import nn

from aihwkit.simulator.digital_low_precision.quantization_manager import QuantizationManager
from aihwkit.simulator.digital_low_precision.quantizers import QMethods
from aihwkit.simulator.digital_low_precision.range_estimators import RangeEstimators


def _set_layer_learn_ranges(layer):
    """Call ``learn_ranges()`` on an initialized ``QuantizationManager``.

    Used as the callback handed to ``nn.Module.apply``. Layers that are not
    a ``QuantizationManager``, and managers whose quantizer reports that it
    is not initialized, are left untouched.
    """
    if not isinstance(layer, QuantizationManager):
        return
    if not layer.quantizer.is_initialized:
        return
    layer.learn_ranges()


def _set_layer_fix_ranges(layer):
    """Call ``fix_ranges()`` on an initialized ``QuantizationManager``.

    Used as the callback handed to ``nn.Module.apply``. Layers that are not
    a ``QuantizationManager``, and managers whose quantizer reports that it
    is not initialized, are left untouched.
    """
    if not isinstance(layer, QuantizationManager):
        return
    if not layer.quantizer.is_initialized:
        return
    layer.fix_ranges()


def _set_layer_estimate_ranges(layer):
    """Call ``estimate_ranges()`` on ``layer`` if it is a ``QuantizationManager``.

    Used as the callback handed to ``nn.Module.apply``; every other layer
    type is ignored.
    """
    if not isinstance(layer, QuantizationManager):
        return
    # NOTE: ``layer.quantizer.is_initialized`` is not consulted here; a guard
    # on it used to sit at this point but has been commented out.
    layer.estimate_ranges()


def _set_layer_estimate_ranges_train(layer):
    """Call ``estimate_ranges_train()`` on ``layer`` if it is a ``QuantizationManager``.

    Used as the callback handed to ``nn.Module.apply``; every other layer
    type is ignored.
    """
    if not isinstance(layer, QuantizationManager):
        return
    # NOTE: ``layer.quantizer.is_initialized`` is not consulted here; a guard
    # on it used to sit at this point but has been commented out.
    layer.estimate_ranges_train()


class QuantizedModule(nn.Module):
    """
    Parent class for a quantized module.

    It provides the switches that move a module between quantized and full
    precision mode (separately for weights and activations), stores the
    quantization settings passed to the constructor, and owns the
    ``cached_params`` slot, clearing it whenever the weight mode changes,
    caching is turned off, training mode is entered, or ``_apply`` runs.
    """

    def __init__(
        self,
        *args,
        method=QMethods.asymmetric_uniform,
        act_method=None,
        n_bits=8,
        n_bits_act=None,
        per_channel_weights=False,
        per_channel_acts=False,
        percentile=None,
        weight_range_method=RangeEstimators.current_minmax,
        weight_range_options=None,
        act_range_method=RangeEstimators.running_minmax,
        act_range_options=None,
        scale_domain="linear",
        **kwargs,
    ):
        """Store the quantization settings and start in full precision mode.

        Positional arguments and any remaining keyword arguments are
        forwarded to the parent constructor. The keys ``quant_dict``,
        ``module_map`` and ``quantization_map`` are dropped from ``kwargs``
        before forwarding.

        ``act_method`` falls back to ``method`` when it is falsy, and
        ``n_bits_act`` falls back to ``n_bits`` when it is ``None``. Falsy
        ``weight_range_options`` / ``act_range_options`` become a fresh
        empty dict.
        """
        # These keys are accepted but must not reach the parent constructor.
        for dropped_key in ("quant_dict", "module_map", "quantization_map"):
            kwargs.pop(dropped_key, None)
        super().__init__(*args, **kwargs)

        # Quantizer methods and bit widths, with activation fallbacks.
        self.method = method
        self.act_method = act_method if act_method else method
        self.n_bits = n_bits
        if n_bits_act is None:
            self.n_bits_act = n_bits
        else:
            self.n_bits_act = n_bits_act

        self.per_channel_weights = per_channel_weights
        self.per_channel_acts = per_channel_acts
        self.percentile = percentile

        # Range-estimation settings for weights and activations.
        self.weight_range_method = weight_range_method
        self.weight_range_options = weight_range_options or {}
        self.act_range_method = act_range_method
        self.act_range_options = act_range_options or {}
        self.scale_domain = scale_domain

        # Cache bookkeeping; caching is on by default.
        self.cached_params = None
        self._caching = True

        # Quantization state: both weights and activations start unquantized.
        self.quant_params = None
        self._quant_w = False
        self._quant_a = False

    @property
    def caching(self):
        """Current value of the caching flag."""
        return self._caching

    @caching.setter
    def caching(self, value: bool):
        """Set the caching flag; a falsy value also clears ``cached_params``."""
        self._caching = value
        if value:
            return
        self.cached_params = None

    def quantized_weights(self):
        """Clear ``cached_params`` and turn weight quantization on."""
        self.cached_params = None
        self._quant_w = True

    def full_precision_weights(self):
        """Clear ``cached_params`` and turn weight quantization off."""
        self.cached_params = None
        self._quant_w = False

    def quantized_acts(self):
        """Turn activation quantization on."""
        self._quant_a = True

    def full_precision_acts(self):
        """Turn activation quantization off."""
        self._quant_a = False

    def quantized(self):
        """Turn on quantization for weights, then for activations."""
        self.quantized_weights()
        self.quantized_acts()

    def full_precision(self):
        """Turn off quantization for weights, then for activations."""
        self.full_precision_weights()
        self.full_precision_acts()

    def learn_ranges(self):
        """Apply ``_set_layer_learn_ranges`` to this module via ``apply``."""
        self.apply(_set_layer_learn_ranges)

    def fix_ranges(self):
        """Apply ``_set_layer_fix_ranges`` to this module via ``apply``."""
        self.apply(_set_layer_fix_ranges)

    def estimate_ranges(self):
        """Apply ``_set_layer_estimate_ranges`` to this module via ``apply``."""
        self.apply(_set_layer_estimate_ranges)

    def estimate_ranges_train(self):
        """Apply ``_set_layer_estimate_ranges_train`` to this module via ``apply``."""
        self.apply(_set_layer_estimate_ranges_train)

    def train(self, mode=True):
        """Delegate to the parent ``train`` and clear the cache when ``mode`` is truthy.

        Returns:
            This module.
        """
        super().train(mode)
        if mode:
            self.cached_params = None
        return self

    def _apply(self, *args, **kwargs):
        """Clear ``cached_params`` and then defer to the parent ``_apply``."""
        self.cached_params = None
        return super()._apply(*args, **kwargs)

    def extra_repr(self):
        """Return the parent's extra repr followed by the quantization flags.

        When the parent contributes nothing, only the flags are returned.
        """
        quant_state = "weight_quant={}, act_quant={}".format(self._quant_w, self._quant_a)
        parent_repr = super().extra_repr()
        if not parent_repr:
            return quant_state
        return "{},\n{}".format(parent_repr, quant_state)


class QuantizedActivation(QuantizedModule):
    """Quantized module that owns an activation quantizer.

    The quantizer is a ``QuantizationManager`` built from the activation
    settings stored by the parent constructor. It is only invoked while
    activation quantization is switched on.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the parent and build ``self.activation_quantizer``.

        All arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        quantizer_params = {
            "n_bits": self.n_bits_act,
            "scale_domain": self.scale_domain,
        }
        self.activation_quantizer = QuantizationManager(
            qmethod=self.act_method,
            qparams=quantizer_params,
            init=self.act_range_method,
            init_params=self.act_range_options,
        )

    def quantize_activations(self, x):
        """Return ``x`` passed through the activation quantizer, or ``x`` itself.

        The quantizer is used only when the activation-quantization flag is
        set; otherwise the input is returned unchanged.
        """
        return self.activation_quantizer(x) if self._quant_a else x

    def forward(self, x):
        """Forward pass: delegate to ``quantize_activations``."""
        return self.quantize_activations(x)


class FP32Acts(nn.Module):
    """Pass-through module: ``forward`` returns its input unchanged."""

    def forward(self, x):
        """Return ``x`` as-is."""
        return x

    def reset_ranges(self):
        """Do nothing; returns ``None``."""
        return None