# -*- coding: utf-8 -*-
# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Tile module base."""
# pylint: disable=too-few-public-methods, abstract-method, bad-super-call
from typing import Dict, List, Optional, Union, Any, Callable, Tuple, NamedTuple
from collections import OrderedDict
from copy import deepcopy
from torch import Tensor
from torch.nn import Module
from torch import dtype as torch_dtype
from torch import device as torch_device
from aihwkit.exceptions import TileModuleError
from aihwkit.simulator.tiles.base import AnalogTileStateNames, BaseTile, TileModuleBase
from aihwkit.optim.context import AnalogContext
from aihwkit.simulator.parameters.base import RPUConfigBase
from aihwkit.simulator.parameters.mapping import MappableRPU
class TileModule(Module, TileModuleBase):
"""Class of all tiles based on ``torch.Module``.
    A ``TileModule`` class inherits from three base classes::

        class MyTile(TileModule, MyTile(BaseTile), MySimulatorTileWrapper)

    Assuming this structure, utility methods are defined here that help
    to traverse all these base classes, such as ``cuda``.
"""
supports_ddp: bool = False
def __init__(self) -> None:
super().__init__()
self.load_rpu_config = True # whether to load the rpu_config from the state dict
self.strict_rpu_config_check = True
        self.use_state_dict_hooks = True
        if not hasattr(self, "register_load_state_dict_post_hook"):
            # the post hook is only available for pytorch >= 1.12; fall back to
            # the explicit calls in ``state_dict`` / ``_load_from_state_dict``
            self.use_state_dict_hooks = False
if self.use_state_dict_hooks:
self._register_state_dict_hook(TileModule._state_dict_hook)
self._register_load_state_dict_pre_hook(
TileModule._load_state_dict_pre_hook, with_module=True
)
self.register_load_state_dict_post_hook(TileModule._load_state_dict_post_hook)
    def set_load_rpu_config_state(
        self, load_rpu_config: Optional[bool], strict_rpu_config_check: Optional[bool] = None
    ) -> None:
        """Sets the behavior when using ``load_state_dict``.
Caution:
If ``load_rpu_config=False`` the RPU config can
be changed from the stored model. However, the user has to
make sure that the changed RPU config makes sense.
For instance, changing the device type might
change the expected fields in the hidden
parameters and result in an error.
Args:
load_rpu_config: Whether to load the saved RPU
config or use the current RPU config of the model.
strict_rpu_config_check: Whether to check and throw an
error if the current ``rpu_config`` is not of the same
class type when setting ``load_rpu_config`` to
                ``False``. If set to ``False``, the user has to make sure
                that the ``rpu_config`` objects are compatible.
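        Example:
            A minimal sketch; ``tile`` and ``state`` are illustrative names
            for an already constructed analog tile module and a previously
            saved state dict::

                # keep the tile's current rpu_config instead of the stored one
                # and skip the strict class-type compatibility check
                tile.set_load_rpu_config_state(
                    load_rpu_config=False, strict_rpu_config_check=False
                )
                tile.load_state_dict(state)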
"""
if load_rpu_config is not None:
self.load_rpu_config = load_rpu_config
if strict_rpu_config_check is not None:
self.strict_rpu_config_check = strict_rpu_config_check
def __getstate__(self) -> Dict:
# pylint: disable=no-member
state = {}
if hasattr(super(Module, self), "__getstate__"):
state = super(Module, self).__getstate__()
# if isinstance(self, BaseTile) and hasattr(super(BaseTile, self), '__getstate__'):
# state.update(super(BaseTile, self).__getstate__()) # type: ignore
return state
def __setstate__(self, state: Dict) -> None:
# pylint: disable=no-member
if hasattr(super(Module, self), "__setstate__"):
# The TileWrapper is handling all the attributes
super(Module, self).__setstate__(state)
else:
Module.__setstate__(self, state)
# update parameter IDs
for name in self._parameters: # type: ignore
self._parameters[name] = getattr(self, name) # type: ignore
for name in self._buffers:
self._buffers[name] = getattr(self, name)
for name in self._modules:
self._modules[name] = getattr(self, name)
def _apply_without_context(self, fn: Callable) -> None:
"""Loops through parameters (excluding the tile module's own AnalogContext)."""
_parameters = self._parameters # type: ignore
self._parameters = OrderedDict()
for name, param in _parameters.items():
if isinstance(param, AnalogContext):
if param.analog_tile == self:
continue
self._parameters[name] = param
Module._apply(self, fn)
self._parameters = _parameters
def _apply(self, fn: Callable) -> None:
# pylint: disable=arguments-differ
"""Delegates to the module level.
This avoids looping through parameters which would cause
unlimited recursions.
Raises
TileModuleError in case the fucntion does not exist on Module level
"""
try:
fn(self)
except Exception as exception:
raise TileModuleError(
"Applied function is not supported for TileModule: {}".format(str(exception))
) from exception
    def is_floating_point(self) -> bool:
"""Dummy for .to to work."""
return True
    def cuda(self, device: Optional[Union[torch_device, str, int]] = None) -> "TileModule":
"""Return a copy of this tile in CUDA memory.
Args:
device: CUDA device
Returns:
Self with the underlying C++ tile moved to CUDA memory.
Raises:
CudaError: if the library has not been compiled with CUDA.
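        Example:
            A minimal sketch; ``tile`` is an illustrative name for an already
            constructed analog tile module::

                import torch

                if torch.cuda.is_available():
                    tile = tile.cuda()  # move to the default CUDA device
                    # or, for a specific device:
                    tile = tile.cuda(torch.device("cuda:0"))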
"""
# handle the SimulatorTileWrapper
if hasattr(super(Module, self), "cuda"):
super(Module, self).cuda(device)
if isinstance(self, BaseTile) and hasattr(super(BaseTile, self), "cuda"):
super(BaseTile, self).cuda(device) # type: ignore
        # apply at the end: the shared weights might have been updated above,
        # which might yield issues if the tile is not updated first
self._apply_without_context(lambda t: t.cuda(device))
return self
    def cpu(self) -> "TileModule":
        """Return a copy of this tile in CPU memory.

        Returns:
            Self with the underlying C++ tile moved to CPU memory.
Raises:
CudaError: if the library has not been compiled with CUDA.
"""
if hasattr(super(Module, self), "cpu"):
super(Module, self).cpu()
if isinstance(self, BaseTile) and hasattr(super(BaseTile, self), "cpu"):
super(BaseTile, self).cpu() # type: ignore
self._apply_without_context(lambda t: t.cpu())
return self
    def to(self, *args: Any, **kwargs: Any) -> "TileModule":
"""Move analog tile module to a device.
RPUConfig conversions can be done as well.
Note:
Please be aware that moving analog tiles from GPU to CPU is
currently not supported.
Returns:
This module in the specified device and converted to the
specified data type.
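        Example:
            A minimal sketch; ``tile`` and ``new_rpu_config`` are illustrative
            names for an analog tile module and a compatible ``RPUConfig``::

                import torch

                tile = tile.to(torch.device("cuda"))        # move to GPU
                tile = tile.to(rpu_config=new_rpu_config)   # in-place RPUConfig swap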
"""
rpu_config = kwargs.pop("rpu_config", None)
new_args = list(args)
if len(new_args) > 0 and isinstance(new_args[0], RPUConfigBase):
rpu_config = args[0]
del new_args[0]
if rpu_config is not None:
self.replace_with(rpu_config)
device = kwargs.pop("device", None)
dtype = kwargs.pop("dtype", None)
for arg in args:
if isinstance(arg, bool):
continue
if isinstance(arg, torch_dtype):
dtype = arg
if isinstance(arg, (str, torch_device)):
device = torch_device(arg) if isinstance(arg, str) else arg
if device is not None:
if device.type == "cuda":
self.cuda(device)
else:
self.cpu()
if len(new_args) > 0 or len(kwargs) > 0:
self._apply_without_context(lambda t: t.to(*new_args, **kwargs))
if dtype is not None:
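            # re-apply the scales after a dtype change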
scales = self.get_scales()
if scales is not None:
self.set_scales(scales)
return self
    @staticmethod
def get_analog_state_name(prefix: str) -> str:
"""Returns the analog state name."""
return prefix + AnalogTileStateNames.ANALOG_STATE_NAME
@staticmethod
def _state_dict_hook(
analog_tile: "TileModule", state_dict: Dict, prefix: str, local_metadata: Dict
) -> None:
# pylint: disable=unused-argument
analog_state_name = TileModule.get_analog_state_name(prefix)
analog_state = analog_tile.get_analog_state()
state_dict[analog_state_name] = analog_state
    def state_dict(  # pylint: disable=arguments-differ
        self, destination: Optional[Dict] = None, prefix: str = "", keep_vars: bool = False
    ) -> Dict:
        """Overload to call the state dict hook explicitly for pytorch < 1.12."""
        destination = Module.state_dict(
            self, destination=destination, prefix=prefix, keep_vars=keep_vars
        )
        if not self.use_state_dict_hooks:
            TileModule._state_dict_hook(self, destination, prefix, {})
        return destination
def _load_from_state_dict(
self,
state_dict: Dict,
prefix: str,
local_metadata: Dict,
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
r"""Overloads :meth:`~torch.nn.Module.load_state_dict` to
include state dict hooks prior to pytorch 1.12
"""
if not self.use_state_dict_hooks:
            # the hooks of pytorch >= 1.12 are not registered, so call them
            # explicitly here
TileModule._load_state_dict_pre_hook(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
Module._load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
if not self.use_state_dict_hooks:
TileModule._load_state_dict_post_hook(
self, (missing_keys, unexpected_keys) # type: ignore
)
@staticmethod
def _load_state_dict_post_hook(
analog_tile: "TileModule", incompatible_keys: NamedTuple
) -> None:
"""Handle some cross-load situation. These are handled on the
__setstate__ level."""
# pylint: disable=unused-argument
missing_keys, unexpected_keys = incompatible_keys
for key in missing_keys.copy():
if AnalogTileStateNames.SHARED_WEIGHTS in key:
missing_keys.remove(key)
for key in unexpected_keys.copy():
if ".tile." in key:
# this needs to be handled in the analog_module level
unexpected_keys.remove(key)
    def compatible_with(self, rpu_config: RPUConfigBase) -> Tuple[bool, Optional[str]]:
        """Checks whether the current `RPUConfig` is compatible with the
        given one.
Args:
rpu_config: New `RPUConfig` to check against
Returns:
success: Whether the given `RPUConfig` is compatible
msg: Error message if not
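        Example:
            A minimal sketch; ``tile`` and ``candidate_config`` are
            illustrative names::

                success, msg = tile.compatible_with(candidate_config)
                if not success:
                    print("Cannot use the new RPU config:", msg)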
"""
if self.strict_rpu_config_check:
if not isinstance(self.rpu_config, type(rpu_config)) and not isinstance(
rpu_config, type(self.rpu_config)
):
return False, (
"RPU config mismatch: "
"Cannot replace "
f"{rpu_config.__class__.__name__} "
f"with {self.rpu_config.__class__.__name__}"
)
if (
isinstance(rpu_config, MappableRPU)
and isinstance(self.rpu_config, MappableRPU)
and rpu_config.mapping != self.rpu_config.mapping
):
if not self.rpu_config.mapping.compatible_with(rpu_config.mapping):
return False, (
"MappingParameter mismatch. Cannot in-place change mapping parameters"
"as it might change the model structure."
)
return True, None
    def replace_with(self, rpu_config: RPUConfigBase) -> None:
        """Replaces the current `RPUConfig` with the given one.

        Args:
            rpu_config: New `RPUConfig` to replace the current one with.
Raises:
TileModuleError: if given `RPUConfig` is not compatible.
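        Example:
            A minimal sketch; ``tile`` and ``new_rpu_config`` are illustrative
            names, where ``new_rpu_config`` is assumed to be compatible::

                tile.replace_with(new_rpu_config)  # raises TileModuleError otherwise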
"""
success, msg = self.compatible_with(rpu_config)
if not success:
raise TileModuleError(msg)
analog_state = self.__getstate__()
analog_state[AnalogTileStateNames.RPU_CONFIG] = deepcopy(rpu_config)
self.__setstate__(analog_state)
@staticmethod
def _load_state_dict_pre_hook(
analog_tile: "TileModule",
state_dict: Dict,
prefix: str,
local_metadata: Dict,
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
) -> None:
# pylint: disable=unused-argument
analog_state_name = TileModule.get_analog_state_name(prefix)
if analog_state_name not in state_dict:
missing_keys.append(analog_state_name)
return
        # handle the shared weights here (they need to be transposed on CUDA);
        # pop them so that torch does not throw an error on incompatible shapes
shared_weights_key = prefix + AnalogTileStateNames.SHARED_WEIGHTS
if shared_weights_key in state_dict and isinstance(state_dict[shared_weights_key], Tensor):
state_dict.pop(shared_weights_key, None)
analog_state = state_dict.pop(analog_state_name).copy()
if not analog_tile.load_rpu_config:
success, msg = analog_tile.compatible_with(analog_state["rpu_config"])
if not success:
raise TileModuleError(msg)
analog_state["rpu_config"] = analog_tile.rpu_config
analog_tile.__setstate__(analog_state)