# Source code for aihwkit.inference.calibration.calibration

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# Licensed under the MIT license. See LICENSE file in the project root for details.

"""Calibration for inference."""

from typing import Optional, Dict, Tuple, TYPE_CHECKING, cast
from collections.abc import Iterator
from functools import partial
from enum import Enum

from tqdm import tqdm

from torch import tensor, Tensor, cat, randperm, no_grad
from torch.nn import Module
from torch.utils.data.dataloader import DataLoader

from aihwkit.exceptions import ConfigError, ArgumentError
from aihwkit.simulator.parameters.enums import NoiseManagementType
from aihwkit.simulator.parameters.pre_post import PrePostProcessingRPU
from aihwkit.simulator.tiles.base import AnalogTileStateNames
from aihwkit.nn.modules.base import AnalogLayerBase
from aihwkit.nn.low_precision_modules.quantization_states import (
    estimate_ranges,
    fix_act_ranges,
    fix_ranges,
    fix_weight_ranges,
)


if TYPE_CHECKING:
    from aihwkit.simulator.parameters import IOParameters


class InputRangeCalibrationType(Enum):
    """Strategies for calibrating the input (DAC) ranges after training.

    Each member names one way of deriving the input range from the
    calibration data.
    """

    #: Calibration is switched off.
    NONE = "None"

    #: Moving average of a multiple of the standard deviation of the inputs.
    MOVING_STD = "MovingStd"

    #: Moving average of the input quantiles. Saves memory.
    MOVING_QUANTILE = "MovingQuantile"

    #: Inputs are cached and the requested quantile of the cache gives the
    #: input range.
    CACHE_QUANTILE = "CacheQuantile"

    #: The ``abs().max()`` over the inputs gives the input range.
    MAX = "Max"
def _calibration_pre_forward(
    mod: Module,
    input_args: Tuple,
    calibration_type: InputRangeCalibrationType,
    cache_key: str,
    global_cache: Dict[str, Tensor],
    max_samples: int = 1000,
    ir_quantile: float = 0.99,
) -> None:
    """Forward pre-hook that gathers input statistics for calibrating the input ranges.

    Depending on ``calibration_type`` the hook either stores information in
    ``global_cache[cache_key]`` or updates the input range of ``mod`` directly:

    * ``CACHE_QUANTILE``: the input rows are appended to the cache, which is
      then shuffled and truncated to at most ``max_samples`` rows.
    * ``MAX``: the cache holds the running maximum of the absolute inputs.
    * ``MOVING_QUANTILE`` / ``MOVING_STD``: the input range of ``mod`` is
      replaced by a running average of the per-call statistic. The cache is
      left untouched.

    Input rows that are all-zero are ignored. If no row remains, the call
    does nothing.

    Args:
        mod: Module the hook is registered on. The moving strategies read and
            update its ``input_range`` and ``input_range_update_idx``.
        input_args: Forward inputs. Only the first positional input is used.
        calibration_type: type used for calibration
        cache_key: key of global cache
        global_cache: Cache shared between the hooks. The entry at
            ``cache_key`` is read and updated by the cache-based strategies.
        max_samples: Maximal number of cache samples. For ``MOVING_STD`` the
            update is skipped once ``input_range_update_idx`` reaches this value.
        ir_quantile: Quantile used for ``MOVING_QUANTILE``. A value of 1.0
            selects the maximum of the absolute inputs instead.

    Raises:
        ConfigError: If ``calibration_type`` is not one of the four
            strategies handled here.
    """
    cached_types = (
        InputRangeCalibrationType.CACHE_QUANTILE,
        InputRangeCalibrationType.MAX,
    )
    moving_types = (
        InputRangeCalibrationType.MOVING_QUANTILE,
        InputRangeCalibrationType.MOVING_STD,
    )
    if calibration_type not in cached_types and calibration_type not in moving_types:
        raise ConfigError(f"Unknown InputRangeCalibrationType {calibration_type}")

    # get rid of entries that are all-zeros
    x_input = input_args[0]
    x_input = x_input.reshape(-1, x_input.size(-1))
    x_input = x_input[~(x_input == 0.0).all(-1)]
    if x_input.numel() == 0:
        # Only all-zero rows in this call. Reductions such as max() or
        # quantile() raise on empty tensors, so there is nothing to record.
        return

    if calibration_type == InputRangeCalibrationType.CACHE_QUANTILE:
        # Add new samples to the cache (cat already copies the data)
        cache = cat([global_cache[cache_key], x_input.detach().cpu()])
        # Shuffle and limit the number
        global_cache[cache_key] = cache[randperm(cache.size(0))[:max_samples]]
        return

    if calibration_type == InputRangeCalibrationType.MAX:
        # Keep the running maximum; an empty cache marks the first call
        batch_max = x_input.abs().max().detach()
        cache = global_cache[cache_key]
        global_cache[cache_key] = batch_max if cache.numel() == 0 else max(cache, batch_max)
        return

    # Moving strategies: update the running average stored on the module
    idx = mod.input_range_update_idx
    val = 0
    if calibration_type == InputRangeCalibrationType.MOVING_QUANTILE:
        val = (
            x_input.abs().max()
            if ir_quantile == 1.0
            else x_input.flatten().quantile(ir_quantile)
        ).item()
    elif idx < max_samples:
        ir_params = mod.rpu_config.pre_post.input_range  # type: ignore
        val = ir_params.init_std_alpha * x_input.std().item()

    # A NaN statistic (e.g. std of a single value) also fails this test
    if val > 0:
        old_val = mod.input_range.item()
        new_val = (old_val * idx + val) / (idx + 1)
        mod.set_input_range(new_val)
        mod.input_range_update_idx += 1
@no_grad()
def calibrate_input_ranges(
    model: Module,
    calibration_type: InputRangeCalibrationType,
    dataloader: Iterator,
    quantile: float = 0.99995,
    max_samples: int = 1000,
    std_alpha: Optional[float] = None,
    force_all_layers: bool = True,
    verbose: bool = False,
) -> None:
    """Calibrate the input ranges according to the defined strategy.

    Only tiles whose RPUConfig is a ``PrePostProcessingRPU`` with a
    ``forward`` field are calibrated. For each such tile the on-the-fly
    initialization (``init_from_data``) is turned off. For the ``MAX`` and
    ``CACHE_QUANTILE`` strategies the forward pass is temporarily set to
    ``is_perfect`` while the data is passed through; the original setting is
    restored afterwards, also if the forward pass raises.

    If noise management is turned on, it is switched off when
    ``force_all_layers`` is set; otherwise an error is raised.

    Note:
        This implementation transiently registers a new
        `forward_pre_hook` on the analog tile level. It assumes that
        the user has not defined any other forward prehooks.

    Args:
        model: The analog model for which to calibrate the input ranges.
        calibration_type: Strategy of the calibration. See
            :class:`~InputRangeCalibrationType`
        dataloader: Iterable that yields ``(args, kwargs)`` pairs. Is used
            like this ``for args, kwargs in dataloader: model(*args, **kwargs)``
        quantile: Quantile used for hard-coded quantile setting.
            Defaults to 0.99995.
        max_samples: Max batch samples to cache in each tile.
            Defaults to 1000.
        std_alpha: Number of standard deviations for moving standard
            deviation strategy. Defaults to ``init_std_alpha`` from RPUConfig
        force_all_layers: Whether to force all layers to be
            (re)-calibrated (default). Otherwise only the layer having
            ``input_range.enable = True`` will be calibrated.
        verbose: Whether to print verbose output.

    Raises:
        ConfigError: If noise management is turned on for a tile and
            ``force_all_layers`` is not set
        ArgumentError: If non-analog model is given
    """
    # pylint: disable=too-many-statements, too-many-locals, too-many-branches
    if calibration_type == InputRangeCalibrationType.NONE:
        return
    if not isinstance(model, AnalogLayerBase) or not isinstance(model, Module):
        raise ArgumentError("Expect an analog module")

    # Strategies that compute the range from the hook cache after the pass
    uses_cache = calibration_type in (
        InputRangeCalibrationType.CACHE_QUANTILE,
        InputRangeCalibrationType.MAX,
    )

    was_training = cast(Module, model).training
    model = cast(Module, model).eval()

    handles = []
    is_perfect_dic = {}
    cache: Dict[str, Tensor] = {}
    try:
        try:
            for tile_name, tile in model.named_analog_tiles():
                rpu_config = tile.rpu_config
                if not isinstance(rpu_config, PrePostProcessingRPU) or not hasattr(
                    rpu_config, "forward"
                ):
                    continue
                ir_params = rpu_config.pre_post.input_range
                if not force_all_layers and not ir_params.enable:
                    continue

                # Reset / modify the necessary tile fields
                if not ir_params.enable:
                    ir_params.enable = True
                    ir_params.learn_input_range = False
                    tile.init_input_processing()

                ir_params.init_from_data = 0  # turn off on-the-fly mechanism
                if std_alpha is not None:
                    ir_params.init_std_alpha = std_alpha

                needs_set_state = False
                io_pars = rpu_config.forward  # type: IOParameters
                if io_pars.noise_management != NoiseManagementType.NONE:
                    if not force_all_layers:
                        raise ConfigError(
                            "Noise management should be turned off for input_range calibration."
                        )
                    io_pars.noise_management = NoiseManagementType.NONE
                    needs_set_state = True

                # Remember the original setting so that it can be restored
                is_perfect_dic[tile_name] = io_pars.is_perfect
                if uses_cache and not io_pars.is_perfect:
                    io_pars.is_perfect = True
                    needs_set_state = True

                if needs_set_state:
                    # need to recreate tile to apply rpu config changes to tile
                    tile_state = tile.__getstate__()
                    tile_state[AnalogTileStateNames.RPU_CONFIG] = rpu_config
                    tile.__setstate__(tile_state)

                # generate hook
                cache[tile_name] = tensor([])
                hook = partial(
                    _calibration_pre_forward,
                    ir_quantile=quantile,
                    calibration_type=calibration_type,
                    cache_key=tile_name,
                    global_cache=cache,
                    max_samples=max_samples,
                )
                handles.append(tile.register_forward_pre_hook(hook))

            # Pass through the samples
            batches = tqdm(dataloader) if verbose else dataloader
            for args, kwargs in batches:  # type: ignore
                model(*args, **kwargs)
        finally:
            # Remove hooks, also when the forward pass failed
            for handle in handles:
                handle.remove()

            # Restore the tiles if necessary. This includes tiles that never
            # received any input and the error case.
            for tile_name, tile in model.named_analog_tiles():
                if tile_name not in is_perfect_dic:
                    continue
                if tile.rpu_config.forward.is_perfect != is_perfect_dic[tile_name]:
                    tile_state = tile.__getstate__()
                    tile_state[AnalogTileStateNames.RPU_CONFIG].forward.is_perfect = (
                        is_perfect_dic[tile_name]
                    )
                    tile.__setstate__(tile_state)

        # Only tiles that got a hook above are finalized. Looking them up by
        # cache key avoids re-deriving the selection from fields that were
        # modified in the meantime.
        for tile_name, tile in model.named_analog_tiles():
            if tile_name not in cache:
                continue
            rpu_config = tile.rpu_config

            if uses_cache:
                # Compute on the cache
                inputs = cache[tile_name]
                if inputs.numel() == 0:
                    if verbose:
                        print(f"Warning: Tile {tile_name} cached inputs is empty")
                    continue
                if calibration_type == InputRangeCalibrationType.CACHE_QUANTILE:
                    input_range = inputs.flatten().quantile(quantile).item()
                else:
                    input_range = inputs.item()
            else:
                # The moving strategies already updated the tile in the hook
                input_range = tile.input_range.item()

            # set the input range
            tile.set_input_range(input_range)

            if verbose:
                print(f"Calibrated tile {tile_name}: {input_range: .5f}.")

            # Store calibration info
            rpu_config.pre_post.input_range.init_value = tile.input_range.item()
            rpu_config.pre_post.input_range.calibration_info = calibration_type.value
    finally:
        if was_training:
            cast(Module, model).train()
def calibrate_quantization_ranges(
    model: Module, loader: DataLoader, max_num_batches: int = 20
) -> None:
    """
    Calibrate the scales and zero-point (if applicable) for a model.

    The algorithm for estimating the ranges is directly defined inside the
    quantizers of this library. The primary use of this function is for PTQ
    or for initializing the quantizers with a value before proceeding to QAT.

    The model is put into eval mode and is left in eval mode on return.

    Parameters
    ----------
    model : Module
        The model to be calibrated
    loader : DataLoader
        The dataloader object to generate the batches of data for the estimation
    max_num_batches : int, optional
        The maximum number of batches to use for estimation, by default 20.
        At most this many batches are drawn from ``loader``.
    """

    def pass_data_for_range_estimation(
        loader: DataLoader, model: Module, max_num_batches: int = 20, dataloader_inp_idx: int = 0
    ) -> None:
        """
        Places the model in eval mode and passes a number of batches so that
        the quantizers can estimate their ranges.

        Parameters
        ----------
        loader : DataLoader
            The dataloader object to generate the batches of data for the estimation
        model : Module
            The model to be calibrated
        max_num_batches : int, optional
            The maximum number of batches to use for estimation, by default 20
        dataloader_inp_idx : int, optional
            Only applicable when the dataloader returns a list or tuple. Signifies
            which position of the list or tuple are the input data to be provided to
            the model, by default 0
        """
        model.eval()

        # Inputs are moved to the device of the first parameter. A model
        # without parameters gets the inputs as they come from the loader.
        first_param = next(model.parameters(), None)

        def to_model_device(value: Tensor) -> Tensor:
            if first_param is None:
                return value
            return value.to(device=first_param.device)

        # zip() stops as soon as the range is exhausted, so exactly
        # `max_num_batches` batches are drawn from the loader (none if <= 0).
        for _, data in zip(range(max_num_batches), loader):
            if isinstance(data, (tuple, list)):
                # The case that the dataloader returns a list/tuple of (data, targets)
                # (e.g., CIFAR dataloaders)
                model(to_model_device(data[dataloader_inp_idx]))
            else:
                # The case that the dataloader returns a dictionary, because the model
                # takes kwargs.
                model(**{key: to_model_device(value) for key, value in data.items()})

    # Place the model in `estimate_ranges` mode
    estimate_ranges(model)

    # Pass batches of data for estimation
    pass_data_for_range_estimation(loader=loader, model=model, max_num_batches=max_num_batches)

    # Fix all the quantizer ranges after they've been estimated
    fix_ranges(model)
    fix_act_ranges(model)
    fix_weight_ranges(model)