# Source code for aihwkit.nn.low_precision_modules.quantization_states

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# Licensed under the MIT license. See LICENSE file in the project root for details.

# Copyright (c) 2021 Qualcomm Technologies, Inc.
# All Rights Reserved.

# mypy: disable-error-code=attr-defined

"""Functions to manipulate the quantization state of a module"""

from torch.nn import Module

from aihwkit.simulator.digital_low_precision.base_quantized_classes import (
    QuantizedModule,
    _set_layer_estimate_ranges,
    _set_layer_estimate_ranges_train,
    _set_layer_fix_ranges,
    _set_layer_learn_ranges,
)


def quantized_weights(model: Module) -> None:
    """Turn on weight quantization throughout a model.

    Parameters
    ----------
    model : Module
        Model whose ``QuantizedModule`` instances (the model itself included,
        if it is one) get weight quantization enabled, recursively.
    """

    def _enable_weight_quantization(submodule: Module) -> None:
        # Only QuantizedModule instances expose the quantization toggles.
        if not isinstance(submodule, QuantizedModule):
            return
        submodule.quantized_weights()

    model.apply(_enable_weight_quantization)
def full_precision_weights(model: Module) -> None:
    """Switch the weights of a model back to full precision.

    Parameters
    ----------
    model : Module
        Model whose ``QuantizedModule`` instances (the model itself included,
        if it is one) have their weights put in full precision, recursively.
    """

    def _disable_weight_quantization(submodule: Module) -> None:
        # Only QuantizedModule instances expose the quantization toggles.
        if not isinstance(submodule, QuantizedModule):
            return
        submodule.full_precision_weights()

    model.apply(_disable_weight_quantization)
def quantized_acts(model: Module) -> None:
    """Turn on activation quantization throughout a model.

    Parameters
    ----------
    model : Module
        Model whose ``QuantizedModule`` instances (the model itself included,
        if it is one) get activation quantization enabled, recursively.
    """

    def _enable_act_quantization(submodule: Module) -> None:
        # Only QuantizedModule instances expose the quantization toggles.
        if not isinstance(submodule, QuantizedModule):
            return
        submodule.quantized_acts()

    model.apply(_enable_act_quantization)
def full_precision_acts(model: Module) -> None:
    """Switch the activations of a model back to full precision.

    Parameters
    ----------
    model : Module
        Model whose ``QuantizedModule`` instances (the model itself included,
        if it is one) have their activations put in full precision,
        recursively.
    """

    def _disable_act_quantization(submodule: Module) -> None:
        # Only QuantizedModule instances expose the quantization toggles.
        if not isinstance(submodule, QuantizedModule):
            return
        submodule.full_precision_acts()

    model.apply(_disable_act_quantization)
def quantized(model: Module) -> None:
    """Turn on quantization of both weights and activations in a model.

    Parameters
    ----------
    model : Module
        Model whose ``QuantizedModule`` instances (the model itself included,
        if it is one) get weight and activation quantization enabled,
        recursively.
    """

    def _enable_quantization(submodule: Module) -> None:
        # Only QuantizedModule instances expose the quantization toggles.
        if not isinstance(submodule, QuantizedModule):
            return
        submodule.quantized()

    model.apply(_enable_quantization)
def full_precision(model: Module) -> None:
    """Put both the weights and the activations of a model in full precision.

    Parameters
    ----------
    model : Module
        Model whose ``QuantizedModule`` instances (the model itself included,
        if it is one) have weights and activations put in full precision,
        recursively.
    """

    def _disable_quantization(submodule: Module) -> None:
        # Only QuantizedModule instances expose the quantization toggles.
        if not isinstance(submodule, QuantizedModule):
            return
        submodule.full_precision()

    model.apply(_disable_quantization)
# Functions for switching the quantization state of the quantizers
def learn_ranges(model: Module) -> None:
    """Switch the quantizers of a model to ``learn_ranges`` mode.

    Parameters
    ----------
    model : Module
        Model on which the mode change is applied, recursively, to the
        model itself and every submodule.
    """
    # NOTE(review): every module is handed to the setter, so it presumably
    # ignores modules that are not quantizers; confirm in base_quantized_classes.
    mode_setter = _set_layer_learn_ranges
    model.apply(mode_setter)
def fix_ranges(model: Module) -> None:
    """Switch the quantizers of a model to ``fix_ranges`` mode.

    Parameters
    ----------
    model : Module
        Model on which the mode change is applied, recursively, to the
        model itself and every submodule.
    """
    # NOTE(review): every module is handed to the setter, so it presumably
    # ignores modules that are not quantizers; confirm in base_quantized_classes.
    mode_setter = _set_layer_fix_ranges
    model.apply(mode_setter)
def fix_act_ranges(model: Module) -> None:
    """Switch only the activation quantizers of a model to ``fix_ranges`` mode.

    Parameters
    ----------
    model : Module
        Model whose ``QuantizedModule`` instances that own an
        ``activation_quantizer`` attribute have that quantizer fixed,
        recursively.
    """

    def _fix_activation_quantizer(submodule: Module) -> None:
        owns_act_quantizer = isinstance(submodule, QuantizedModule) and hasattr(
            submodule, "activation_quantizer"
        )
        if not owns_act_quantizer:
            return
        _set_layer_fix_ranges(submodule.activation_quantizer)

    model.apply(_fix_activation_quantizer)
def fix_weight_ranges(model: Module) -> None:
    """Switch only the weight quantizers of a model to ``fix_ranges`` mode.

    Parameters
    ----------
    model : Module
        Model whose ``QuantizedModule`` instances that own a
        ``weight_quantizer`` attribute have that quantizer fixed,
        recursively.
    """

    def _fix_weight_quantizer(submodule: Module) -> None:
        owns_weight_quantizer = isinstance(submodule, QuantizedModule) and hasattr(
            submodule, "weight_quantizer"
        )
        if not owns_weight_quantizer:
            return
        _set_layer_fix_ranges(submodule.weight_quantizer)

    model.apply(_fix_weight_quantizer)
def estimate_ranges(model: Module) -> None:
    """Switch the quantizers of a model to ``estimate_ranges`` mode.

    Parameters
    ----------
    model : Module
        Model on which the mode change is applied, recursively, to the
        model itself and every submodule.
    """
    # NOTE(review): every module is handed to the setter, so it presumably
    # ignores modules that are not quantizers; confirm in base_quantized_classes.
    mode_setter = _set_layer_estimate_ranges
    model.apply(mode_setter)
def estimate_act_ranges(model: Module) -> None:
    """Switch only the activation quantizers of a model to ``estimate_ranges`` mode.

    Parameters
    ----------
    model : Module
        Model whose ``QuantizedModule`` instances that own an
        ``activation_quantizer`` attribute have that quantizer put in
        ``estimate_ranges`` mode, recursively.
    """

    def _estimate_activation_quantizer(submodule: Module) -> None:
        owns_act_quantizer = isinstance(submodule, QuantizedModule) and hasattr(
            submodule, "activation_quantizer"
        )
        if not owns_act_quantizer:
            return
        _set_layer_estimate_ranges(submodule.activation_quantizer)

    model.apply(_estimate_activation_quantizer)
def estimate_ranges_train(model: Module) -> None:
    """Switch the quantizers of a model to ``estimate_ranges_train`` mode.

    Parameters
    ----------
    model : Module
        Model on which the mode change is applied, recursively, to the
        model itself and every submodule.
    """
    # NOTE(review): every module is handed to the setter, so it presumably
    # ignores modules that are not quantizers; confirm in base_quantized_classes.
    mode_setter = _set_layer_estimate_ranges_train
    model.apply(mode_setter)
def reset_act_ranges(model: Module) -> None:
    """Return the activation ranges of a model to their uninitialized state.

    Parameters
    ----------
    model : Module
        Model whose ``QuantizedModule`` instances that own an
        ``activation_quantizer`` attribute have ``reset_ranges()`` called on
        that quantizer, recursively.
    """

    def _reset_activation_quantizer(submodule: Module) -> None:
        owns_act_quantizer = isinstance(submodule, QuantizedModule) and hasattr(
            submodule, "activation_quantizer"
        )
        if not owns_act_quantizer:
            return
        submodule.activation_quantizer.reset_ranges()

    model.apply(_reset_activation_quantizer)
def set_quant_state(model: Module, weight_quant: bool, act_quant: bool) -> None:
    """Configure the weight and activation quantizers of a model in one call.

    Weights and activations are set independently to either quantized or
    full precision. Activations are configured first, then weights.

    Parameters
    ----------
    model : Module
        Model to configure.
    weight_quant : bool
        True enables weight quantization for all modules in the model;
        False keeps the weights in full precision.
    act_quant : bool
        True enables activation quantization for all modules in the model;
        False keeps the activations in full precision.
    """
    # Pick the matching toggle for each half and invoke it immediately, so
    # the activation setting is applied before the weight setting.
    (quantized_acts if act_quant else full_precision_acts)(model)
    (quantized_weights if weight_quant else full_precision_weights)(model)
def enable_quant_states(model: Module) -> None:
    """Set each ``QuantizedModule``'s quantization state from its bit widths.

    For every ``QuantizedModule`` in the model, the module itself included,
    the weights are quantized when ``n_bits`` is positive and are otherwise
    kept in full precision. The activations follow the same rule, using
    ``n_bits_act``.

    Parameters
    ----------
    model : Module
        Model to configure.
    """

    def _apply_bit_width_states(submodule: Module) -> None:
        if not isinstance(submodule, QuantizedModule):
            return
        # Weights first, then activations; each toggle is chosen from the
        # sign of its bit width.
        weight_toggle = (
            submodule.quantized_weights
            if submodule.n_bits > 0
            else submodule.full_precision_weights
        )
        weight_toggle()
        act_toggle = (
            submodule.quantized_acts
            if submodule.n_bits_act > 0
            else submodule.full_precision_acts
        )
        act_toggle()

    model.apply(_apply_bit_width_states)