Source code for aihwkit.nn.low_precision_modules.conversion_utils

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# Licensed under the MIT license. See LICENSE file in the project root for details.

# Copyright (c) 2021 Qualcomm Technologies, Inc.
# All Rights Reserved.

# mypy: disable-error-code=attr-defined

"""Utility functions for converting a module to its quantized counterpart"""

from copy import deepcopy
from typing import Any, Dict, Optional

from torch.nn import Conv2d, Embedding, LayerNorm, Linear, Module

from aihwkit.nn.low_precision_modules.quantized_base_modules import (
    QuantConv2d,
    QuantEmbedding,
    QuantLayerNorm,
    QuantLinear,
)
from aihwkit.simulator.digital_low_precision.quantizers import QuantizerBase
from aihwkit.simulator.digital_low_precision.range_estimators import RangeEstimatorBase
from aihwkit.simulator.parameters.quantization import QuantizationMap, QuantizedModuleConfig

# Maps each supported `torch.nn` layer class to the quantized module class
# (imported from `quantized_base_modules`) that replaces it. Iterated by
# `append_default_conversions` to fill in conversions a `QuantizationMap`
# does not define itself.
DEFAULT_CONVERSIONS = {
    Linear: QuantLinear,
    LayerNorm: QuantLayerNorm,
    Embedding: QuantEmbedding,
    Conv2d: QuantConv2d,
}
# Tuple of the range-estimator and quantizer base classes.
# NOTE(review): not used in this part of the file; presumably these mark
# module types that conversion treats as leaves -- confirm at the use site.
LEAF_MODULES = (RangeEstimatorBase, QuantizerBase)


def append_default_conversions(quantization_map: QuantizationMap) -> None:
    """
    Fill a `QuantizationMap` with the conversions listed in `DEFAULT_CONVERSIONS`.

    Layer types that already have an entry in
    `quantization_map.module_qconfig_map` are left untouched. Every entry that
    gets added receives its own deep copy of `quantization_map.default_qconfig`
    as its quantization config.

    Parameters
    ----------
    quantization_map : QuantizationMap
        The QuantizationMap instance to extend in place with the default
        conversions
    """
    for layer_type, quantized_type in DEFAULT_CONVERSIONS.items():
        # Only add conversions that are not already defined in the map.
        if module_missing := layer_type not in quantization_map.module_qconfig_map:
            assert module_missing
            default_entry = QuantizedModuleConfig(
                quantized_module=quantized_type,
                # Deep copy so entries never share one config object.
                module_qconfig=deepcopy(quantization_map.default_qconfig),
            )
            quantization_map.module_qconfig_map[layer_type] = default_entry


def get_module_args(module: Module, activation: Optional[Module] = None) -> Dict[str, Any]:
    """
    Get the arguments from a pytorch module to provide it to the initialization
    function of the quantized modules.

    The way to retrieve the arguments for each type of module is defined by
    helper functions nested inside this function, following the naming
    convention `get_{module_type}_args`.

    Parameters
    ----------
    module : Module
        The module to extract the arguments from
    activation : Optional[Module], optional
        The activation function for the `QuantizationHijacker` if applicable,
        by default None

    Returns
    -------
    Dict[str, Any]
        The keyword arguments read from `module`, plus an ``"activation"``
        entry holding the `activation` argument as given.

    Raises
    ------
    ValueError
        If the function has not been taught how to handle a given module.
    """

    def get_linear_args(module: Module) -> Dict[str, Any]:
        """Quantization arguments for `QuantLinear`"""
        return {
            "in_features": module.in_features,
            "out_features": module.out_features,
            "bias": module.bias is not None,
        }

    def get_layernorm_args(module: Module) -> Dict[str, Any]:
        """Quantization arguments for `QuantLayerNorm`"""
        # NOTE(review): only `normalized_shape` and `eps` are forwarded;
        # `elementwise_affine` is not read from the source module -- confirm
        # this is intended for layers built with non-default settings.
        return {"normalized_shape": module.normalized_shape, "eps": module.eps}

    def get_embedding_args(module: Module) -> Dict[str, Any]:
        """Quantization arguments for `QuantEmbedding`"""
        return {
            "num_embeddings": module.num_embeddings,
            "embedding_dim": module.embedding_dim,
            "padding_idx": module.padding_idx,
            "max_norm": module.max_norm,
            "norm_type": module.norm_type,
            "scale_grad_by_freq": module.scale_grad_by_freq,
            "sparse": module.sparse,
        }

    def get_conv2d_args(module: Module) -> Dict[str, Any]:
        """Quantization arguments for `QuantConv2d`"""
        return {
            "in_channels": module.in_channels,
            "out_channels": module.out_channels,
            "kernel_size": module.kernel_size,
            "stride": module.stride,
            "padding": module.padding,
            "dilation": module.dilation,
            "groups": module.groups,
            "bias": module.bias is not None,
            "padding_mode": module.padding_mode,
        }

    if isinstance(module, Linear):
        kwargs = get_linear_args(module)
    elif isinstance(module, LayerNorm):
        kwargs = get_layernorm_args(module)
    elif isinstance(module, Embedding):
        kwargs = get_embedding_args(module)
    elif isinstance(module, Conv2d):
        kwargs = get_conv2d_args(module)
    else:
        # Name the offending type so the caller can tell which layer of the
        # model has no argument extractor.
        raise ValueError(
            "Cannot extract quantized module arguments from unsupported module "
            f"type '{type(module).__name__}'"
        )

    kwargs["activation"] = activation
    return kwargs