# Source code for aihwkit.nn.low_precision_conversion

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# Licensed under the MIT license. See LICENSE file in the project root for details.

"""Functions to convert a given model to a quantized counterpart"""

from copy import deepcopy
from typing import Optional

from torch.nn import Module, Linear, Conv2d

from aihwkit.nn.conversion import convert_to_digital
from aihwkit.nn.low_precision_modules.conversion_utils import (
    DEFAULT_CONVERSIONS,
    LEAF_MODULES,
    get_module_args,
)
from aihwkit.nn.low_precision_modules.quantization_states import enable_quant_states
from aihwkit.nn.low_precision_modules.quantized_input_module import QuantizedInputModule
from aihwkit.nn.modules.base import AnalogLayerBase
from aihwkit.simulator.digital_low_precision.config_utils import convert_configs_to_kwargs_dict
from aihwkit.simulator.parameters.quantization import QuantizationMap


def convert_to_quantized(model: Module, quantization_map: QuantizationMap) -> Module:
    """High level function to perform the quantization of a model according to the
    QuantizationMap defined by the user.

    See the `QuantizationMap` dataclass for instructions on how to define its fields.
    This function works on a deep copy of ``model`` (the input model is not modified),
    calls the recursive `quantize_model` function which performs the actual conversion
    between the individual modules, and then enables the quantization states for all
    the quantized modules.

    Parameters
    ----------
    model : Module
        The model to quantize
    quantization_map : QuantizationMap
        The dataclass that contains instructions on how to quantize the model

    Returns
    -------
    torch.nn.Module
        The quantized model. If the root module itself is excluded from quantization
        (its name ``""`` is in ``quantization_map.excluded_modules``) or is a leaf
        quantizer/range-estimator module, the unconverted deep copy is returned.
    """
    quant_model = deepcopy(model)
    converted = quantize_model(quant_model, quantization_map)
    # quantize_model returns None to signal "no replacement for this module"
    # (excluded or leaf module). For children the recursive caller then keeps the
    # existing module; do the same for the root instead of propagating None into
    # enable_quant_states and to the caller.
    if converted is not None:
        quant_model = converted
    enable_quant_states(quant_model)
    return quant_model
def _quantize_primitive(model: Module, q_model, q_config) -> Module:
    """Instantiate the quantized counterpart of a primitive module and attach its parameters.

    The quantized module is constructed with the original module's constructor
    arguments (from ``get_module_args``) plus the quantization kwargs derived from
    ``q_config``. The original ``weight`` data (and ``bias`` data, when the original
    has a non-None bias) is then assigned to it (no explicit copy is made).

    Parameters
    ----------
    model : Module
        The primitive (digital) module to convert
    q_model
        The quantized module class to instantiate (``quantized_module`` of the
        selected module config); presumably a QuantizationHijacker-based class
    q_config
        The quantization parameters (``module_qconfig`` of the selected module config)

    Returns
    -------
    Module
        The newly built quantized module
    """
    kwargs = get_module_args(model, None)
    quant_model = q_model(**kwargs, **convert_configs_to_kwargs_dict(q_config))
    quant_model.weight.data = model.weight.data
    if getattr(model, "bias", None) is not None:
        quant_model.bias.data = model.bias.data
    return quant_model


def quantize_model(
    model: Module, quantization_map: QuantizationMap, model_name: str = ""
) -> Optional[Module]:
    """Traverses a model recursively and replaces a module with a quantized counterpart,
    if such a conversion is defined in the quantization map.

    It realizes all the capabilities of the QuantizationMap, namely:

    - Excluding specific modules from the quantization, identified by exact state_dict
      string (`quantization_map.excluded_modules`)
    - Instance specific quantization, identified by exact state_dict string
      (`quantization_map.instance_qconfig_map`)
    - Module specific quantization, identified by the type of a module
      (`quantization_map.module_qconfig_map`)
    - Input quantization on specified modules, identified by the state_dict string
      (`quantization_map.input_activation_qconfig_map`)

    (see `QuantizationMap` and the related examples for more details on how to define
    these fields)

    Parameters
    ----------
    model : Module
        Model to recursively quantize
    quantization_map : QuantizationMap
        The dataclass that defines the quantization conversions as well as the
        quantization parameters for each conversion
    model_name : str, optional
        The name of the current module in the state dict of the original model,
        by default ""

    Returns
    -------
    Module or None
        - ``None`` if ``model_name`` is excluded or ``model`` is a leaf
          (quantizer / range estimator) module; callers keep the existing module.
        - ``model`` itself if its selected config quantizes neither activations
          nor weights (both ``n_bits <= 0``).
        - A new quantized module if a conversion is defined for ``model``.
        - Otherwise ``model``, with its children replaced in place by their
          quantized (and optionally input-quantized) counterparts.
    """
    # pylint: disable=too-many-branches

    # Exclude modules based on state_dict string
    if model_name in quantization_map.excluded_modules:
        return None

    # Early exit for quantizers and range estimators
    if isinstance(model, LEAF_MODULES):
        return None

    instance_map = quantization_map.instance_qconfig_map
    module_map = quantization_map.module_qconfig_map

    # Quantization code for instance- or module-specific conversions
    if model_name in instance_map or type(model) in module_map:
        # The instance-specific configuration takes priority over the type-based one
        qmodule_config = (
            instance_map[model_name] if model_name in instance_map else module_map[type(model)]
        )
        q_config = qmodule_config.module_qconfig
        q_model = qmodule_config.quantized_module

        # Keep the original module if it's not going to be used as a quantized module
        if not (q_config.activation_quant.n_bits > 0 or q_config.weight_quant.n_bits > 0):
            return model

        if type(model) in DEFAULT_CONVERSIONS:
            # Primitive modules: their quantized counterparts use the
            # QuantizationHijacker, so the standard infrastructure applies.
            return _quantize_primitive(model, q_model, q_config)

        if isinstance(model, AnalogLayerBase):
            # Analog layers: convert_to_digital first merges the scales into the
            # parameters, then the digital module is quantized like a primitive.
            # NOTE: Only Linear and Conv2d are supported at the moment.
            model = convert_to_digital(model)
            assert type(model) in [Linear, Conv2d], "Only Linear and Conv2d supported at the moment"
            return _quantize_primitive(model, q_model, q_config)

        # Every other custom conversion must accept the original module in its
        # constructor, along with the quantization map and the config kwargs.
        kwargs = {
            "quantization_map": quantization_map,
            **convert_configs_to_kwargs_dict(q_config),
        }
        return q_model(model, **kwargs)

    # No conversion is defined for this module: recurse into its children and
    # swap each one for its converted counterpart.
    for name, child in model.named_children():
        full_name = f"{model_name}.{name}" if model_name else name
        if full_name in quantization_map.excluded_modules:
            continue
        new_child = quantize_model(child, quantization_map, model_name=full_name)
        # None means "keep the existing child" (excluded or leaf module)
        if new_child is None:
            continue
        # Optionally wrap the child with an input quantizer
        input_qconfig_map = quantization_map.input_activation_qconfig_map
        if full_name in input_qconfig_map:
            new_child = QuantizedInputModule(new_child, input_qconfig_map[full_name])
        setattr(model, name, new_child)
    return model