# Source code for aihwkit.nn.conversion

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Digital/analog model conversion utilities.

This module includes tools for converting a given torch model to a
model containing analog layers.
"""

from typing import TypeVar, Optional, Dict, Callable, Set, List
from copy import deepcopy

from torch.nn import Module, Linear, Conv1d, Conv2d, Conv3d, Sequential

from aihwkit.exceptions import ArgumentError
from aihwkit.simulator.tiles.module import TileModule
from aihwkit.nn.modules.container import AnalogWrapper
from aihwkit.nn.modules.base import AnalogLayerBase

from aihwkit.nn import (
    AnalogLinear,
    AnalogConv1d,
    AnalogConv2d,
    AnalogConv3d,
    AnalogLinearMapped,
    AnalogConv1dMapped,
    AnalogConv2dMapped,
    AnalogConv3dMapped,
    AnalogSequential,
)

RPUConfigGeneric = TypeVar("RPUConfigGeneric")

_DEFAULT_CONVERSION_MAP = {
    Linear: AnalogLinear,
    Conv1d: AnalogConv1d,
    Conv2d: AnalogConv2d,
    Conv3d: AnalogConv3d,
    Sequential: AnalogSequential,
}

_DEFAULT_MAPPED_CONVERSION_MAP = {
    Linear: AnalogLinearMapped,
    Conv1d: AnalogConv1dMapped,
    Conv2d: AnalogConv2dMapped,
    Conv3d: AnalogConv3dMapped,
    Sequential: AnalogSequential,
}

_DEFAULT_DIGITAL_CONVERSION_SET = {
    *_DEFAULT_CONVERSION_MAP.values(),
    *_DEFAULT_MAPPED_CONVERSION_MAP.values(),
}


def specific_rpu_config_id(
    module_name: str, module: Module, rpu_config: "RPUConfigGeneric"
) -> "RPUConfigGeneric":
    """Identity default for the per-layer ``RPUConfig`` hook used during conversion.

    A function with this signature can be handed to the conversion routines
    (as ``specific_rpu_config_fun``) to customize the ``RPUConfig`` of
    individual layers. This default simply returns the config unmodified.

    Args:
        module_name: The name of the module currently converted to analog.
        module: The actual digital module to be converted.
        rpu_config: A copy of the generic ``RPUConfig`` given to the overall
            conversion.

    Returns:
        The (possibly modified) ``RPUConfig``; unchanged in this default.
    """
    # pylint: disable=unused-argument
    return rpu_config
def convert_to_analog(
    module: Module,
    rpu_config: "RPUConfigGeneric",
    tile_module_class: Optional["TileModule"] = None,
    conversion_map: Optional[Dict] = None,
    specific_rpu_config_fun: Optional[Callable] = None,
    module_name: str = "",
    ensure_analog_root: bool = True,
    exclude_modules: Optional[List[str]] = None,
    inplace: bool = False,
    verbose: bool = False,
) -> Module:
    """Convert a given digital model to its analog counterpart.

    Note:
        The torch device (cuda/cpu) is inferred from the parameters of the
        original model. However, if multiple torch devices are used in a
        given module, the corresponding analog module is not moved to any
        device.

    Args:
        module: The torch module to convert. All layers whose classes appear
            in the ``conversion_map`` are replaced.
        rpu_config: RPU config applied to all converted tiles.
        tile_module_class: Custom tile module class.
        conversion_map: Dictionary of module classes to be replaced, in case
            of custom replacement rules. By default all ``Conv`` and
            ``Linear`` layers are replaced with their analog counterparts.

            Note:
                Each analog layer class needs a class method ``from_digital``
                which is called during the conversion.
        specific_rpu_config_fun: Function that modifies the generic RPUConfig
            for specific modules. See :func:`~specific_rpu_config_id` as an
            example of how to specify it.
        module_name: Explicitly given name of the base (root) module, given
            to ``specific_rpu_config_fun``.
        ensure_analog_root: Whether to ensure that the root module is of
            layer type `AnalogLayerBase`, so that custom analog methods such
            as `drift_analog_weights` are available. If set, the model is
            wrapped in an `AnalogWrapper` if necessary.

            Note:
                Since the module structure changes when wrapped, the
                checkpoint names will also change if this is enabled (for
                legacy load this might need to be disabled).
        exclude_modules: List of module names that are in the conversion map
            but should be excluded from the conversion.
        inplace: Whether to perform in-place conversion (without deepcopy).
        verbose: Increase verbosity. Will print converted layers.

    Returns:
        Module where all the digital layers are replaced with analog layers.

    Raises:
        ArgumentError: in case conversion map is not a dictionary.
    """
    # pylint: disable=too-many-branches, too-many-locals
    exclude_modules = [] if exclude_modules is None else exclude_modules
    if not inplace:
        module = deepcopy(module)
    if conversion_map is None:
        conversion_map = _DEFAULT_CONVERSION_MAP
    if not isinstance(conversion_map, dict):
        raise ArgumentError("Conversion map should be a dictionary")
    if specific_rpu_config_fun is None:
        specific_rpu_config_fun = specific_rpu_config_id

    # Replace the root module itself when its class has a conversion rule
    # and it is not explicitly excluded.
    if module.__class__ in conversion_map and module_name not in exclude_modules:
        analog_class = conversion_map[module.__class__]
        if verbose:
            print(f"Converted '{module_name}' to '{analog_class.__name__}'.")
        module = analog_class.from_digital(
            module,
            specific_rpu_config_fun(module_name, module, deepcopy(rpu_config)),
            tile_module_class,
        )

    # Convert the children, collecting replacements first so the iteration
    # over ``named_children`` is not disturbed by the swaps.
    replacements = {}
    for name, child in module.named_children():
        full_name = module_name + "." + name if module_name else name
        if full_name in exclude_modules:
            continue
        if next(child.named_children(), None) is not None:
            # Child is itself a container: recurse in place (the root was
            # already deep-copied above when requested).
            replacements[name] = convert_to_analog(
                child,
                rpu_config,
                tile_module_class,
                conversion_map,
                specific_rpu_config_fun,
                full_name,
                True,
                exclude_modules,
                True,
                verbose,
            )
            continue
        if child.__class__ not in conversion_map:
            continue
        analog_class = conversion_map[child.__class__]
        if verbose:
            print(f"Converted '{full_name}' to '{analog_class.__name__}'.")
        replacements[name] = analog_class.from_digital(
            child,
            specific_rpu_config_fun(full_name, child, deepcopy(rpu_config)),
            tile_module_class,
        )

    for name, new_child in replacements.items():
        module._modules[name] = new_child  # pylint: disable=protected-access

    # In case of the root module, make sure it is wrapped as analog.
    if ensure_analog_root and not module_name and not isinstance(module, AnalogLayerBase):
        module = AnalogWrapper(module)
    return module
def convert_to_analog_mapped(
    module: Module,
    rpu_config: "RPUConfigGeneric",
    tile_module_class: Optional["TileModule"] = None,
    specific_rpu_config_fun: Optional[Callable] = None,
    module_name: str = "",
    ensure_analog_root: bool = True,
    exclude_modules: Optional[List[str]] = None,
    inplace: bool = False,
    verbose: bool = False,
) -> Module:
    """Convert a given digital model to its analog counterpart with tile
    mapping support.

    Thin wrapper around :func:`convert_to_analog` that uses the mapped
    conversion map (``AnalogLinearMapped``, ``AnalogConv*Mapped``, ...)
    instead of the plain analog layers.

    Note:
        The torch device (cuda/cpu) is inferred from the parameters of the
        original model. However, if multiple torch devices are used in a
        given module, the corresponding analog module is not moved to any
        device.

    Args:
        module: The torch module to convert. All layers that are defined in
            the conversion map are replaced.
        rpu_config: RPU config to apply to all converted tiles.
        tile_module_class: Custom tile module class.
        specific_rpu_config_fun: Function that modifies the generic RPUConfig
            for specific modules. See :func:`~specific_rpu_config_id` as an
            example of how to specify it.
        module_name: Explicitly given name of the base (root) module, given
            to ``specific_rpu_config_fun``.
        ensure_analog_root: Whether to ensure that the root module is of
            layer type `AnalogLayerBase`, so that custom analog methods such
            as `drift_analog_weights` are available. If set, the model is
            wrapped in an `AnalogWrapper` if necessary.

            Note:
                Since the module structure changes when wrapped, the
                checkpoint names will also change if this is enabled (for
                legacy load this might need to be disabled).
        exclude_modules: List of module names that are in the conversion map
            but should be excluded from the conversion.
        inplace: Whether to perform in-place conversion (without deepcopy).
        verbose: Increase verbosity. Will print converted layers.

    Returns:
        Module where all the digital layers are replaced with analog mapped
        layers.
    """
    return convert_to_analog(
        module,
        rpu_config,
        tile_module_class,
        _DEFAULT_MAPPED_CONVERSION_MAP,
        specific_rpu_config_fun,
        module_name,
        ensure_analog_root,
        exclude_modules,
        inplace,
        verbose,
    )
def convert_to_digital(
    module: Module,
    conversion_set: Optional[Set] = None,
    module_name: str = "",
    inplace: bool = False,
) -> Module:
    """Convert a given analog model to its digital counterparts.

    Note:
        The torch device (cuda/cpu) is inferred from the parameters of the
        original model. However, if multiple torch devices are used in a
        given module, the corresponding module is not moved to any device.

    Args:
        module: The torch module (containing analog layers) to convert. All
            analog layer classes to be converted are defined in the
            ``conversion_set``.
        conversion_set: Set of analog module classes to be replaced. By
            default all ``AnalogConv*`` and ``AnalogLinear*`` layers are
            replaced with their digital counterparts.

            Note:
                Each analog layer class needs a class method ``to_digital``
                which is called during the conversion.
        module_name: Explicitly given name of the base (root) module.
        inplace: Whether to perform in-place conversion (without deepcopy).

    Returns:
        Module where all the analog layers are replaced with digital layers.

    Raises:
        ArgumentError: in case the conversion set is not a set or contains
            layers that cannot be converted.
    """
    # pylint: disable=too-many-branches
    if not inplace:
        module = deepcopy(module)
    if conversion_set is None:
        conversion_set = _DEFAULT_DIGITAL_CONVERSION_SET
    if not isinstance(conversion_set, set):
        raise ArgumentError("Conversion set should be a set")

    # Convert the root module itself (an ``AnalogWrapper`` root is always
    # unwrapped, even if not listed in the conversion set).
    if module.__class__ in conversion_set or isinstance(module, AnalogWrapper):
        if not hasattr(module.__class__, "to_digital"):
            raise ArgumentError(
                "The conversion set should only contain "
                "analog layers that support `to_digital` conversion."
            )
        module = module.__class__.to_digital(module)

    # Convert the children, collecting replacements first so the iteration
    # over ``named_children`` is not disturbed by the swaps.
    replacements = {}
    for name, child in module.named_children():
        full_name = module_name + "." + name if module_name else name
        if next(child.named_children(), None) is not None:
            # Child is itself a container: recurse in place (the root was
            # already deep-copied above when requested).
            replacements[name] = convert_to_digital(child, conversion_set, full_name, inplace=True)
            continue
        if child.__class__ not in conversion_set:
            continue
        if not hasattr(child, "to_digital"):
            raise ArgumentError(
                "The conversion set should only contain "
                "analog layers that support `to_digital` conversion."
            )
        replacements[name] = child.to_digital(child)  # type: ignore

    for name, new_child in replacements.items():
        module._modules[name] = new_child  # pylint: disable=protected-access
    return module