# -*- coding: utf-8 -*-
# (C) Copyright 2020, 2021, 2022 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Base class for analog Modules."""
from typing import (
Any, Dict, List, Optional, Tuple, NamedTuple, Union,
Generator, TYPE_CHECKING
)
from torch import Tensor, no_grad, ones, float32
from torch.nn import Module, Parameter
from torch import device as torch_device
from aihwkit.exceptions import ModuleError
from aihwkit.simulator.configs.configs import (
FloatingPointRPUConfig, InferenceRPUConfig, SingleRPUConfig,
UnitCellRPUConfig, DigitalRankUpdateRPUConfig
)
from aihwkit.simulator.configs.utils import MappingParameter
from aihwkit.simulator.tiles import InferenceTile
from aihwkit.optim.context import AnalogContext
if TYPE_CHECKING:
from aihwkit.simulator.tiles import BaseTile
from collections import OrderedDict
RPUConfigAlias = Union[FloatingPointRPUConfig, SingleRPUConfig,
UnitCellRPUConfig, InferenceRPUConfig,
DigitalRankUpdateRPUConfig]
[docs]class AnalogModuleBase(Module):
"""Base class for analog Modules.
Base ``Module`` for analog layers that use analog tiles. When subclassing,
please note:
* the :meth:`_setup_tile()` method is expected to be called by the subclass
constructor, and it does not only create a tile.
* :meth:`register_analog_tile` needs to be called for each created analog tile
* this module does *not* call torch's ``Module`` init as the child is
likely again derived from Module
* the ``weight`` and ``bias`` Parameters are not guaranteed to be in
sync with the tile weights and biases during the lifetime of the instance,
for performance reasons. The canonical way of reading and writing
weights is via the :meth:`set_weights()` and :meth:`get_weights()` as opposed
to using the attributes directly.
* the ``BaseTile`` subclass that is created is retrieved from the
``rpu_config.tile_class`` attribute.
Args:
in_features: input vector size (number of columns).
out_features: output vector size (number of rows).
bias: whether to use a bias row on the analog tile or not.
realistic_read_write: whether to enable realistic read/write
for setting initial weights and during reading of the weights.
mapping: Configuration of the hardware architecture (e.g. tile size).
"""
# pylint: disable=abstract-method, too-many-instance-attributes
ANALOG_CTX_PREFIX: str = 'analog_ctx_'
ANALOG_SHARED_WEIGHT_PREFIX: str = 'analog_shared_weights_'
ANALOG_STATE_PREFIX: str = 'analog_tile_state_'
ANALOG_OUT_SCALING_ALPHA_PREFIX: str = 'analog_out_scaling_alpha_'
def __init__(
self,
in_features: int,
out_features: int,
bias: bool,
realistic_read_write: bool = False,
mapping: Optional[MappingParameter] = None,
) -> None:
# pylint: disable=super-init-not-called
self._analog_tile_counter = 0
self._registered_helper_parameter = [] # type: list
self._load_rpu_config = True
if mapping is None:
mapping = MappingParameter()
self.use_bias = bias
self.digital_bias = bias and mapping.digital_bias
self.analog_bias = bias and not mapping.digital_bias
self.realistic_read_write = realistic_read_write
self.in_features = in_features
self.out_features = out_features
[docs] def register_analog_tile(self, tile: 'BaseTile', name: Optional[str] = None) -> None:
"""Register the analog context of the tile.
Note:
Needs to be called at the end init to register the tile
for the analog optimizers.
Args:
tile: tile to register
name: Optional tile name used as the parameter name
"""
if name is None:
name = str(self._analog_tile_counter)
ctx_name = self.ANALOG_CTX_PREFIX + name
if ctx_name not in self._registered_helper_parameter:
self._registered_helper_parameter.append(ctx_name)
self.register_parameter(ctx_name, tile.get_analog_ctx())
if tile.shared_weights is not None:
if not isinstance(tile.shared_weights, Parameter):
tile.shared_weights = Parameter(tile.shared_weights)
par_name = self.ANALOG_SHARED_WEIGHT_PREFIX + str(self._analog_tile_counter)
self.register_parameter(par_name, tile.shared_weights)
if par_name not in self._registered_helper_parameter:
self._registered_helper_parameter.append(par_name)
mapping = tile.rpu_config.mapping
if mapping.learn_out_scaling_alpha:
if tile.out_scaling_alpha is None:
tile.out_scaling_alpha = Parameter(ones([1], device=tile.device, dtype=float32))
elif not isinstance(tile.out_scaling_alpha, Parameter):
tile.out_scaling_alpha = Parameter(tile.out_scaling_alpha)
par_name = self.ANALOG_OUT_SCALING_ALPHA_PREFIX + str(self._analog_tile_counter)
self.register_parameter(par_name, tile.out_scaling_alpha)
if par_name not in self._registered_helper_parameter:
self._registered_helper_parameter.append(par_name)
self._analog_tile_counter += 1
[docs] def unregister_parameter(self, param_name: str) -> None:
"""Unregister module parameter from parameters.
Raises:
ModuleError: In case parameter is not found
"""
param = getattr(self, param_name, None)
if not isinstance(param, Parameter):
raise ModuleError(f"Cannot find parameter {param_name} to unregister")
param_data = param.detach().clone()
delattr(self, param_name)
setattr(self, param_name, param_data)
[docs] def analog_tiles(self) -> Generator['BaseTile', None, None]:
""" Generator to loop over all registered analog tiles of the module """
for param in self.parameters():
if isinstance(param, AnalogContext):
yield param.analog_tile
[docs] def named_analog_tiles(self) -> Generator[Tuple[str, 'BaseTile'], None, None]:
""" Generator to loop over all registered analog tiles of the module with names. """
for name, param in self.named_parameters():
if isinstance(param, AnalogContext):
new_name = name.split(self.ANALOG_CTX_PREFIX)[-1]
yield (new_name, param.analog_tile)
[docs] def get_analog_tile_devices(self) -> List[Optional[Union[torch_device, str, int]]]:
""" Return a list of the devices used by the analog tiles.
Returns:
List of torch devices
"""
return [d.device for d in self.analog_tiles()]
[docs] def analog_tile_count(self) -> int:
"""Return the number of registered tiles.
Returns:
Number of registered tiles
"""
return getattr(self, '_analog_tile_counter', 0)
def _setup_tile(
self,
rpu_config: RPUConfigAlias,
) -> 'BaseTile':
"""Create a single analog tile with the given RPU configuration.
Create an analog tile to be used for the basis of this layer
operations, while using the attributes ``(in_features,
out_features, bias)`` given to this instance during init.
After tile creation, the tile needs to be registered using
:meth:`register_analog_tile`.
Args:
rpu_config: resistive processing unit configuration.
Returns:
An analog tile with the requested parameters.
"""
# pylint: disable=protected-access
# Create the tile.
return rpu_config.tile_class(self.out_features, self.in_features, rpu_config,
bias=self.analog_bias)
[docs] def set_weights(
self,
weight: Tensor,
bias: Optional[Tensor] = None,
force_exact: bool = False,
remap_weights: bool = True,
weight_scaling_omega: float = None
) -> None:
"""Set the weight (and bias) values with given tensors.
This uses an realistic write if the property ``realistic_read_write``
of the layer is set, unless it is overwritten by ``force_exact``.
If ``weight_scaling_omega`` is larger than 0, the weights are set in a
scaled manner (assuming a digital output scale). See
:meth:`~aihwkit.simulator.tiles.base.BaseTile.set_weights_scaled`
for details.
Note:
This is the recommended way for setting the weight/bias matrix of
the analog tile, as it will correctly store the weights into the
internal memory. Directly writing to ``self.weight`` and
``self.bias`` might yield wrong results as they are not always in
sync with the analog tile Parameters, for performance reasons.
Args:
weight: weight matrix
bias: bias vector
force_exact: forces an exact write to the analog tiles
remap_weights: Whether to rescale the given weight matrix
and populate the digital output scaling factors as
specified in the configuration
:class:`~aihwkit.configs.utils.MappingParameter`. A
new ``weight_scaling_omega`` can be given. Note that
this will overwrite the existing digital out scaling
factors.
weight_scaling_omega: The weight scaling omega factor (see
:class:`~aihwkit.configs.utils.MappingParameter`). If
given explicitly here, it will overwrite the value in
the mapping field.
Raises:
ModuleError: in case of multiple defined analog tiles in the module
"""
shape = [self.out_features, self.in_features]
weight = weight.clone().reshape(shape)
realistic = self.realistic_read_write and not force_exact
analog_tiles = list(self.analog_tiles())
if len(analog_tiles) != 1:
raise ModuleError("AnalogModuleBase.set_weights only supports a single tile.")
analog_tile = analog_tiles[0]
if remap_weights:
omega = weight_scaling_omega
if omega is None:
omega = analog_tile.rpu_config.mapping.weight_scaling_omega
analog_tile.set_weights_scaled(
weight, bias if self.analog_bias else None,
realistic=realistic,
weight_scaling_omega=omega)
else:
analog_tile.set_weights(weight, bias if self.analog_bias else None,
realistic=realistic)
if bias is not None and self.digital_bias:
with no_grad():
self.bias.data[:] = bias[:]
self._sync_weights_from_tile()
[docs] def get_weights(
self,
force_exact: bool = False,
apply_out_scales: bool = True,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Get the weight (and bias) tensors.
This uses an realistic read if the property ``realistic_read_write`` of
the layer is set, unless it is overwritten by ``force_exact``. It
scales the analog weights by the digital alpha scale if
``weight_scaling_omega`` is positive (see
:meth:`~aihwkit.simulator.tiles.base.BaseTile.get_weights_scaled`).
Note:
This is the recommended way for setting the weight/bias matrix from
the analog tile, as it will correctly fetch the weights from the
internal memory. Accessing ``self.weight`` and ``self.bias`` might
yield wrong results as they are not always in sync with the
analog tile library, for performance reasons.
Args:
force_exact: Forces an exact read to the analog tiles
apply_out_scales: Whether to return the weights with the
(digital) output scaling factors applied. Note the
"logical" weights of the layer which the DNN is
effectively using are those with the output scales
applied. If ``apply_out_scales`` is set to False, then
only the weight values that is programmed onto the
crossbar array are returned, without applying the
digital scales.
Returns:
tuple: weight matrix, bias vector
Raises:
ModuleError: in case of multiple defined analog tiles in the module
"""
analog_tiles = list(self.analog_tiles())
if len(analog_tiles) != 1:
raise ModuleError("AnalogModuleBase.get_weights only supports a single tile.")
analog_tile = analog_tiles[0]
realistic = self.realistic_read_write and not force_exact
if apply_out_scales:
weight, analog_bias = analog_tile.get_weights_scaled(realistic=realistic)
else:
weight, analog_bias = analog_tile.get_weights(realistic=realistic)
digital_bias = None
if self.digital_bias:
with no_grad():
digital_bias = self.bias.data.clone().detach().cpu()
if (digital_bias is not None) and (analog_bias is not None):
bias = digital_bias + analog_bias
elif digital_bias is not None:
bias = digital_bias
else:
bias = analog_bias
return weight, bias
def _sync_weights_from_tile(self) -> None:
"""Update the layer weight and bias from the values on the analog tile.
Update the ``self.weight`` and ``self.bias`` Parameters with an
exact copy of the internal analog tile weights.
"""
tile_weight, tile_bias = self.get_weights(force_exact=True) # type: Tuple[Tensor, Tensor]
self.weight.data[:] = tile_weight.reshape(self.weight.shape)
if self.analog_bias:
with no_grad():
self.bias.data[:] = tile_bias.reshape(self.bias.shape)
def _sync_weights_to_tile(self) -> None:
"""Update the tile values from the layer weights and bias.
Update the internal tile weights with an exact copy of the values of
the ``self.weight`` and ``self.bias`` Parameters.
"""
self.set_weights(self.weight, self.bias if self.analog_bias else None,
force_exact=True)
def _set_load_rpu_config_state(self, load_rpu_config: bool = True) -> None:
self._load_rpu_config = load_rpu_config
[docs] def load_state_dict(self, # pylint: disable=arguments-differ
state_dict: 'OrderedDict[str, Tensor]',
strict: bool = True,
load_rpu_config: bool = True) -> NamedTuple:
"""Specializes torch's ``load_state_dict`` to add a flag whether to
load the RPU config from the saved state.
Args:
state_dict: see torch's ``load_state_dict``
strict: see torch's ``load_state_dict``
load_rpu_config: Whether to load the saved RPU
config or use the current RPU config of the model.
Caution:
If ``load_rpu_config=False`` the RPU config can
be changed from the stored model. However, the user has to
make sure that the changed RPU config makes sense.
For instance, changing the device type might
change the expected fields in the hidden
parameters and result in an error.
Returns:
see torch's ``load_state_dict``
Raises: ModuleError: in case the rpu_config class mismatches
for ``load_rpu_config=False``.
"""
self._set_load_rpu_config_state(load_rpu_config)
return super().load_state_dict(state_dict, strict)
def _load_from_state_dict(
self,
state_dict: Dict,
prefix: str,
local_metadata: Dict,
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str]) -> None:
"""Copy parameters and buffers from `state_dict` into only this
module, but not its descendants.
This method is a specialization of ``Module._load_from_state_dict``
that takes into account the extra ``analog_tile_state`` key used by
analog layers.
Raises:
ModuleError: in case the rpu_config class mismatches.
"""
for name, analog_tile in list(self.named_analog_tiles()):
key = prefix + self.ANALOG_STATE_PREFIX + name
if key not in state_dict: # legacy
key = prefix + 'analog_tile_state'
if key in state_dict:
analog_state = state_dict.pop(key).copy()
if not self._load_rpu_config:
if analog_tile.rpu_config.__class__ != analog_state['rpu_config'].__class__:
raise ModuleError("RPU config mismatch during loading: "
"Tried to replace "
f"{analog_state['rpu_config'].__class__.__name__} "
f"with {analog_tile.rpu_config.__class__.__name__}")
analog_state['rpu_config'] = analog_tile.rpu_config
analog_tile.__setstate__(analog_state)
elif strict:
missing_keys.append(key)
# update the weight / analog bias (not saved explicitly)
self._sync_weights_from_tile()
# remove helper parameters.
rm_keys = []
for par_name in self._registered_helper_parameter:
key = prefix + par_name
if key in state_dict:
state_dict.pop(key)
rm_keys.append(key)
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys,
unexpected_keys, error_msgs)
# remove the missing keys of the helper parameters
for key in rm_keys:
missing_keys.remove(key)
[docs] def state_dict(
self,
destination: Any = None,
prefix: str = '',
keep_vars: bool = False
) -> Dict:
"""Return a dictionary containing a whole state of the module."""
self._sync_weights_from_tile()
current_state = super().state_dict(destination, prefix, keep_vars)
for name, analog_tile in self.named_analog_tiles():
analog_state = analog_tile.__getstate__()
analog_state_name = prefix + self.ANALOG_STATE_PREFIX + name
current_state[analog_state_name] = analog_state
return current_state
[docs] def drift_analog_weights(self, t_inference: float = 0.0) -> None:
"""(Program) and drift the analog weights.
Args:
t_inference: assumed time of inference (in sec)
Raises:
ModuleError: if the layer is not in evaluation mode.
"""
if self.training:
raise ModuleError('drift_analog_weights can only be applied in '
'evaluation mode')
for analog_tile in self.analog_tiles():
if isinstance(analog_tile, InferenceTile):
analog_tile.drift_weights(t_inference)
[docs] def program_analog_weights(self) -> None:
"""Program the analog weights.
Raises:
ModuleError: if the layer is not in evaluation mode.
"""
if self.training:
raise ModuleError('program_analog_weights can only be applied in '
'evaluation mode')
for analog_tile in self.analog_tiles():
if isinstance(analog_tile, InferenceTile):
analog_tile.program_weights()