Source code for aihwkit.nn.modules.container

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Analog Modules that contain children Modules."""

# pylint: disable=unused-argument, arguments-differ
from types import new_class
from typing import Any, Optional, Tuple, Dict, Type
from collections import OrderedDict

from torch import Tensor
from torch.nn import Sequential, Module

from aihwkit.nn.modules.base import AnalogLayerBase
from aihwkit.exceptions import ModuleError


class AnalogContainerBase(AnalogLayerBase):
    """Base class for analog containers."""

    IS_CONTAINER: bool = True

    def get_weights(  # type: ignore
        self, **kwargs: Any
    ) -> "OrderedDict[str, Tuple[Tensor, Optional[Tensor]]]":
        """Return all (weight, bias) tuples in an ordered dictionary.

        Args:
            kwargs: passed to the TileModule ``get_weights`` call.

        Returns:
            All analog weights of all layers.
        """
        weights_dic = OrderedDict()
        for name, analog_tile in self.named_analog_tiles():
            weights_dic[name] = analog_tile.get_weights(**kwargs)
        return weights_dic

    def set_weights(  # type: ignore
        self, weights_dic: "OrderedDict[str, Tuple[Tensor, Optional[Tensor]]]", **kwargs: Any
    ) -> None:
        """Set all analog weights that are part of this parent module.

        Args:
            weights_dic: Ordered dictionary of weight data.
            kwargs: passed to the TileModule ``set_weights`` call.

        Raises:
            ModuleError: in case a tile name cannot be found in the dictionary.
        """
        for name, analog_tile in self.named_analog_tiles():
            if name not in weights_dic:
                raise ModuleError("Cannot find tile weight {} in given dictionary.".format(name))
            analog_tile.set_weights(*weights_dic[name], **kwargs)
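

# Usage sketch (illustrative, not part of the original source): ``get_weights``
# and ``set_weights`` operate on an ``OrderedDict`` keyed by the tile names
# reported by ``named_analog_tiles()``. Assuming ``AnalogLinear`` and
# ``AnalogSequential`` from ``aihwkit.nn``, a weight round-trip looks like:
#
#     from aihwkit.nn import AnalogLinear, AnalogSequential
#
#     model = AnalogSequential(AnalogLinear(4, 3), AnalogLinear(3, 1))
#     weights = model.get_weights()   # OrderedDict: tile name -> (weight, bias)
#     model.set_weights(weights)      # write the same tensors back to the tiles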


class AnalogSequential(AnalogContainerBase, Sequential):
    """An analog-aware sequential container.

    Specialization of torch ``nn.Sequential`` with extra functionality for
    handling analog layers:

    * apply analog-specific functions to all its children (drift and
      program weights).

    Note:
        This class is recommended to be used in place of ``nn.Sequential``
        in order to correctly propagate the actions to all the children
        analog layers. If using regular containers, please be aware that
        operations need to be applied manually to the children analog
        layers when needed.
    """

    @classmethod
    def from_digital(cls, module: Sequential, *args: Any, **kwargs: Any) -> "AnalogSequential":
        """Construct AnalogSequential in-place from Sequential."""
        return cls(OrderedDict(mod for mod in module.named_children()))
[docs] @classmethod def to_digital(cls, module: "AnalogSequential", *args: Any, **kwargs: Any) -> Sequential: """Construct Sequential in-place from AnalogSequential.""" return Sequential(OrderedDict(mod for mod in module.named_children()))


class AnalogWrapper(AnalogContainerBase):
    """Generic wrapper over a given Module.

    Will add the AnalogLayerBase functionality to the given Module (as an
    added subclass).

    Note:
        Here the state dictionary of the given module will simply be copied
        by reference. The original model therefore should not be used any
        more, as the underlying tensor data is shared.

    Args:
        module: module to wrap with the analog wrapper.
    """

    SUBCLASSES = {}  # type: Dict[str, Type]
    """Registry of the created subclasses."""

    def __new__(cls, module: Optional[Module] = None, **__: Any) -> "AnalogWrapper":
        if module is None:
            # for deepcopy and the like
            return super().__new__(cls)

        module_cls = module.__class__
        subclass_name = "{}{}".format(cls.__name__, module_cls.__name__)

        # Retrieve or create a new subclass that inherits both from
        # `AnalogWrapper` and from the specific torch module (`module_cls`).
        if subclass_name not in cls.SUBCLASSES:
            cls.SUBCLASSES[subclass_name] = new_class(subclass_name, (cls, module_cls), {})

        return super().__new__(cls.SUBCLASSES[subclass_name])

    def __init__(self, module: Module):
        self.__dict__.update(module.__dict__)

    @classmethod
    def from_digital(cls, module: Module, *args: Any, **kwargs: Any) -> "AnalogWrapper":
        """Construct AnalogWrapper in-place from any module."""
        return cls(module)
[docs] @classmethod def to_digital(cls, module: "AnalogWrapper", *args: Any, **kwargs: Any) -> Module: """Construct Sequential in-place from AnalogSequential.""" digital_class = module.__class__.__bases__[1] new_module = digital_class.__new__(digital_class) # type: ignore new_module.__dict__ = module.__dict__ return new_module