# -*- coding: utf-8 -*-
# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""High level analog tiles (base)."""
# pylint: disable=too-few-public-methods, abstract-method, too-many-instance-attributes
from collections import OrderedDict
from typing import Optional, Union, Tuple, Any, Dict, List
from copy import deepcopy
from numpy import array
from numpy.typing import ArrayLike
from torch import Tensor, from_numpy, float32, unsqueeze, cat, empty, stack, dtype
from torch import device as torch_device
from torch.cuda import device as cuda_device
from torch.autograd import no_grad
from aihwkit import __version__
from aihwkit.exceptions import TileError
from aihwkit.simulator.parameters.mapping import MappingParameter
from aihwkit.simulator.parameters.base import RPUConfigGeneric
from aihwkit.simulator.parameters.runtime import RuntimeParameter
from aihwkit.simulator.parameters.enums import RPUDataType
from aihwkit.optim.context import AnalogContext
[docs]class TileModuleBase:
"""Base class for (logical) tile modules that can be used in
layers, e.g. array of TileModules.
"""
[docs]class AnalogTileStateNames:
"""Class defining analog tile state name constants.
Caution:
Do *not* edit. Some names are attribute names of the tile.
"""
VERSION = "aihwkit_version"
WEIGHTS = "analog_tile_weights"
HIDDEN_PARAMETERS = "analog_tile_hidden_parameters"
HIDDEN_PARAMETER_NAMES = "analog_tile_hidden_parameter_names"
CLASS = "analog_tile_class"
LR = "analog_lr"
SHARED_WEIGHTS = "shared_weights"
CONTEXT = "analog_ctx"
OUT_SCALING = "out_scaling_alpha"
MAPPING_SCALES = "mapping_scales"
RPU_CONFIG = "rpu_config"
ANALOG_STATE_PREFIX = "analog_tile_state_"
ANALOG_STATE_NAME = "analog_tile_state"
EXTRA = "state_extra"
[docs] @staticmethod
def get_field_names() -> List[str]:
"""Returns expected field names."""
return [
getattr(AnalogTileStateNames, key)
for key in AnalogTileStateNames.__dict__
if not key.startswith("_")
]
[docs]class BaseTile:
"""Base class for tile classes (without ``torch.Module`` dependence)."""
[docs] def joint_forward(self, x_input: Tensor, is_test: bool = False, ctx: Any = None) -> Tensor:
"""Perform the joint forward method.
Calls first the ``pre_forward``, then the tile forward, and
finally the ``post_forward`` step.
Note:
The full forward pass is not using autograd, thus all pre
and post functions need to be handled appropriately in the
pre/post backward functions.
Args:
x_input: ``[N, in_size]`` tensor. If ``in_trans`` is set, transposed.
is_test: whether to assume testing mode.
ctx: torch auto-grad context [Optional]
Returns:
torch.Tensor: ``[N, out_size]`` tensor. If ``out_trans`` is set, transposed.
"""
raise NotImplementedError
[docs] def backward(self, d_input: Tensor, ctx: Any = None) -> Tensor:
"""Perform the backward pass.
Args:
d_input: ``[N, out_size]`` tensor. If ``out_trans`` is set, transposed.
ctx: torch auto-grad context [Optional]
Returns:
torch.Tensor: ``[N, in_size]`` tensor. If ``in_trans`` is set, transposed.
"""
raise NotImplementedError
[docs] def update(self, x_input: Tensor, d_input: Tensor) -> None:
"""Perform the update pass.
Calls the ``pre_update`` method to pre-process the inputs.
Args:
x_input: ``[..., in_size]`` tensor. If ``in_trans`` is set, ``[in_size, ...]``.
d_input: ``[..., out_size]`` tensor. If ``out_trans`` is set, ``[out_size, ...]``.
Returns:
None
"""
raise NotImplementedError
[docs]class SimulatorTile:
"""Minimal class interface for implementing the simulator tile.
Note:
This tile is generated by ``_create_simulator_tile`` in the
``SimulatorTileWrapper``.
"""
[docs] def forward(
self,
x_input: Tensor,
bias: bool = False,
in_trans: bool = False,
out_trans: bool = False,
is_test: bool = False,
non_blocking: bool = False,
) -> Tensor:
"""General simulator tile forward."""
raise NotImplementedError
[docs] def backward(
self,
d_input: Tensor,
bias: bool = False,
in_trans: bool = False,
out_trans: bool = False,
non_blocking: bool = False,
) -> Tensor:
"""Backward pass.
Only needs to be implemented if torch autograd is `not` used.
"""
raise NotImplementedError
[docs] def update(
self,
x_input: Tensor,
d_input: Tensor,
bias: bool = False,
in_trans: bool = False,
out_trans: bool = False,
non_blocking: bool = False,
) -> Tensor:
"""Update.
Only needs to be implemented if torch autograd update is `not` used.
"""
raise NotImplementedError
[docs] def get_brief_info(self) -> str:
"""Returns a brief info"""
raise NotImplementedError
[docs] def get_weights(self) -> Tensor:
"""Returns the analog weights."""
raise NotImplementedError
[docs] def set_weights(self, weight: Tensor) -> None:
"""Stets the analog weights."""
raise NotImplementedError
[docs] def get_x_size(self) -> int:
"""Returns input size of tile"""
raise NotImplementedError
[docs] def get_d_size(self) -> int:
"""Returns output size of tile"""
raise NotImplementedError
[docs] def get_hidden_parameters(self) -> Tensor:
"""Get the hidden parameters of the tile.
Returns:
Hidden parameter tensor.
"""
return empty(0, dtype=float32)
[docs] def get_hidden_parameter_names(self) -> List[str]:
"""Get the hidden parameters names.
Each name corresponds to a slice in the Tensor slice of the
``get_hidden_parameters`` tensor.
Returns:
List of names.
"""
return []
[docs] def set_hidden_parameters(self, params: Tensor) -> None:
"""Set the hidden parameters of the tile."""
[docs] def get_learning_rate(self) -> Optional[float]:
"""Get the learning rate of the tile.
Returns:
learning rate if exists.
"""
[docs] def set_learning_rate(self, learning_rate: Optional[float]) -> None:
"""Set the learning rate of the tile.
No-op for tiles that do not need a learning rate.
Args:
learning rate: learning rate to set
"""
[docs]class SimulatorTileWrapper:
"""Wrapper base class for defining the necessary tile
functionality.
Will be overloaded extended for C++ or for any TorchTile.
Args:
out_size: output size
in_size: input size
rpu_config: resistive processing unit configuration.
bias: whether to add a bias column to the tile.
in_trans: Whether to assume an transposed input (batch first)
out_trans: Whether to assume an transposed output (batch first)
shared_weights: optional shared weights tensor memory that
should be used.
handle_output_bound: whether the bound clamp gradient should be inserted
ignore_analog_state: whether to ignore the analog state when __getstate__ is called
"""
def __init__(
self,
out_size: int,
in_size: int,
rpu_config: RPUConfigGeneric,
bias: bool = True,
in_trans: bool = False,
out_trans: bool = False,
torch_update: bool = False,
handle_output_bound: bool = False,
ignore_analog_state: bool = False,
):
self.is_cuda = False
self.device = torch_device("cpu")
self.out_size = out_size
self.in_size = in_size
self.rpu_config = deepcopy(rpu_config)
self.in_trans = in_trans
self.out_trans = out_trans
self.handle_output_bound = handle_output_bound
self.ignore_analog_state = ignore_analog_state
self.shared_weights = None
# handling the bias
if hasattr(rpu_config, "mapping"):
mapping = rpu_config.mapping
else:
mapping = MappingParameter()
self.digital_bias = bias and mapping.digital_bias
self.use_bias = bias
self.analog_bias = bias and not mapping.digital_bias
x_size = self.in_size + 1 if self.analog_bias else self.in_size
d_size = self.out_size
self.tile = self._create_simulator_tile(x_size, d_size, rpu_config)
self.analog_ctx = AnalogContext(self)
self.analog_ctx.use_torch_update = torch_update
[docs] def get_runtime(self) -> RuntimeParameter:
"""Returns the runtime parameter."""
if not hasattr(self.rpu_config, "runtime"):
self.rpu_config.runtime = RuntimeParameter()
return self.rpu_config.runtime
[docs] def get_data_type(self) -> RPUDataType:
"""Return data_type setting of the RPUConfig"""
return self.get_runtime().data_type
[docs] def get_dtype(self) -> dtype:
"""Return dtype setting of the RPUConfig"""
return self.get_runtime().data_type.as_torch()
def _create_simulator_tile(
self, x_size: int, d_size: int, rpu_config: "RPUConfigGeneric"
) -> Any: # just use Any instead of Union["SimulatorTile", tiles.AnalogTile, ..]
"""Create a simulator tile.
Args:
x_size: input size
d_size: output size
rpu_config: resistive processing unit configuration
Returns:
a simulator tile based on the specified configuration.
"""
raise NotImplementedError
def _recreate_simulator_tile(
self, x_size: int, d_size: int, rpu_config: "RPUConfigGeneric"
) -> Any: # just use Any instead of Union["SimulatorTile", tiles.AnalogTile, ..]
"""Re-create a simulator tile in __setstate__.
Args:
x_size: input size
d_size: output size
rpu_config: resistive processing unit configuration
Returns:
a simulator tile based on the specified configuration.
"""
return self._create_simulator_tile(x_size, d_size, rpu_config)
[docs] def get_tensor_view(self, ndim: int, dim: Optional[int] = None) -> tuple:
"""Return the tensor view for ndim vector at dim.
Args:
ndim: number of dimensions
dim: the dimension to set to -1
Returns:
Tuple of ones with the `dim`` index sets to -1
"""
if dim is None:
dim = 0 if self.out_trans else ndim - 1
tensor_view = [1] * ndim
tensor_view[dim] = -1
return tuple(tensor_view)
[docs] def get_forward_out_bound(self) -> Optional[float]:
"""Helper for getting the output bound to correct the
gradients using the AnalogFunction.
"""
return None
[docs] def set_verbosity_level(self, verbose: int) -> None:
"""Set the verbosity level.
Args:
verbose: level of verbosity
"""
[docs] def get_analog_ctx(self) -> AnalogContext:
"""Return the analog context of the tile to be used in ``AnalogFunction``."""
return self.analog_ctx
[docs] def get_brief_info(self) -> str:
"""Return short info about the underlying C++ tile."""
return self.tile.get_brief_info().rstrip()
[docs] def update(self, x_input: Tensor, d_input: Tensor) -> None:
"""Implements tile update (e.g. using pulse trains)."""
raise NotImplementedError
[docs] def update_indexed(self, x_input: Tensor, d_input: Tensor) -> None:
"""Implements indexed interface to the tile update
(e.g. using pulse trains)."""
raise NotImplementedError
[docs] def get_analog_state(self) -> Dict:
"""Get the analog state for the state_dict.
Excludes the non-analog state names that might be added for
pickling. Only fields defined in ``AnalogTileStateNames`` are
returned.
"""
state = self.__getstate__()
fields = AnalogTileStateNames.get_field_names()
rm_fields = []
for key in state:
if key not in fields:
rm_fields.append(key)
for key in rm_fields:
state.pop(key)
return state
def __getstate__(self) -> Dict:
"""Get the state for pickling.
This method removes the ``tile`` member, as the binding Tiles are not
serializable.
"""
# Caution: all attributes of the tile will be saved.
current_dict = self.__dict__.copy()
if getattr(self, "ignore_analog_state", False):
return current_dict
SN = AnalogTileStateNames
current_dict[SN.WEIGHTS] = self.tile.get_weights()
current_dict[SN.HIDDEN_PARAMETERS] = self.tile.get_hidden_parameters().data
current_dict[SN.HIDDEN_PARAMETER_NAMES] = self.tile.get_hidden_parameter_names()
current_dict[SN.CLASS] = type(self).__name__
current_dict[SN.LR] = self.tile.get_learning_rate()
current_dict.pop("tile", None)
current_dict[SN.CONTEXT] = self.analog_ctx.data
current_dict[SN.EXTRA] = self.tile.dump_extra()
current_dict[SN.VERSION] = __version__
# don't save device. Will be determined by loading object
current_dict.pop("stream", None)
current_dict.pop("is_cuda", None)
current_dict.pop("device", None)
# this is should not be saved.
current_dict.pop("image_sizes", None)
return current_dict
def __setstate__(self, state: Dict) -> None:
"""Set the state after unpickling.
This method recreates the ``tile`` member, creating a new one from
scratch, as the binding Tiles are not serializable.
Caution:
RPU configs are overwritten by loading the state.
Note:
Some RPUCuda (analog training) compounds have some extra
internal states that should be set if checkpointing to
continue training. To support this, extra states are
extracted and stored. However, these are _not_ applied if
cross-loading is done, e.g. map location is different for
inference or tile type is changed. It will not throw any
notice is they are not applied.
Raises:
TileError: if tile class does not match or hidden parameters do not match
"""
# pylint: disable=too-many-locals, too-many-statements, too-many-branches
if getattr(self, "ignore_analog_state", False) or state.get("ignore_analog_state", False):
self.__dict__.update(state)
analog_ctx = self.analog_ctx
else:
SN = AnalogTileStateNames
current_dict = state.copy()
tile_class = current_dict.pop(SN.CLASS, type(self).__name__)
analog_lr = current_dict.pop(SN.LR, 0.01)
analog_ctx = current_dict.pop(SN.CONTEXT, None)
weights = current_dict.pop(SN.WEIGHTS)
extra = current_dict.pop(SN.EXTRA, None)
hidden_parameters = current_dict.pop(SN.HIDDEN_PARAMETERS)
hidden_parameters_names = current_dict.pop(SN.HIDDEN_PARAMETER_NAMES, [])
current_dict.pop("analog_alpha_scale", None) # legacy
current_dict.pop("image_sizes", None) # should not be saved
# legacy
if "non_blocking" not in current_dict:
current_dict["non_blocking"] = False
# Check for tile mismatch
rpu_config = current_dict.pop("rpu_config")
if hasattr(self, "rpu_config"):
# only for state-dict load. Might not yet be defined (in
# case of pickle load or deepcopy)
if not self.rpu_config.compatible_with(tile_class):
raise TileError(
"Error creating tile"
f". Possible mismatch between {tile_class} and {type(self).__name__}"
)
# Need to always keep the same tile class
rpu_config.tile_class = self.rpu_config.tile_class
self.rpu_config = rpu_config
self.__dict__.update(current_dict)
self.device = torch_device("cpu")
self.is_cuda = False
# recreate attributes not saved
# always first create on CPU
x_size = self.in_size + 1 if self.analog_bias else self.in_size
d_size = self.out_size
# Recreate the tile.
self.tile = self._recreate_simulator_tile(x_size, d_size, self.rpu_config)
names = self.tile.get_hidden_parameter_names()
if len(hidden_parameters_names) > 0 and names != hidden_parameters_names:
# Check whether names match
raise TileError(
"Mismatch with loaded analog state: Hidden parameter structure is unexpected."
)
if not isinstance(hidden_parameters, Tensor):
hidden_parameters = from_numpy(array(hidden_parameters))
self.tile.set_hidden_parameters(hidden_parameters)
if not isinstance(weights, Tensor):
weights = from_numpy(array(weights))
self.tile.set_weights(weights)
if analog_lr is not None:
self.tile.set_learning_rate(analog_lr)
# finally set the extra stuff (without complaining if keys not
# found. Note that these extra states are only needed for some
# tiles (compounds) if training needs to be continued without
# resetting counters etc.)
if extra is not None:
self.tile.load_extra(extra, False)
# map location should be applied to tensors in state_dict
self.analog_ctx = AnalogContext(self)
if analog_ctx is not None:
# Keep the object ID and device
to_device = analog_ctx.device
if self.device != to_device:
self.analog_ctx = self.analog_ctx.to(to_device)
self.analog_ctx.set_data(analog_ctx.data)
[docs] @no_grad()
def post_update_step(self) -> None:
"""Operators that need to be called once per mini-batch.
Note:
This function is called by the analog optimizer.
Caution:
If no analog optimizer is used, the post update steps will
not be performed.
"""
def _combine_weights(
self, weight: Union[Tensor, ArrayLike], bias: Optional[Union[Tensor, ArrayLike]] = None
) -> Tensor:
"""Helper to combines weights and biases
In any case, a detached cpu weight and bias copy will be returned.
Args:
weight: weights without the bias
bias: The bias vector if available
Returns:
combined weights with biases
Raises:
ValueError: if the tile has bias but ``bias`` has not been
specified.
"""
d_type = self.get_dtype()
if not isinstance(weight, Tensor):
weight = from_numpy(array(weight))
weight = weight.clone().detach().cpu().to(d_type)
shape = [self.out_size, self.in_size]
weight = weight.reshape(shape)
if self.analog_bias:
# Create a ``[out_size, in_size (+ 1)]`` matrix.
if bias is None:
raise ValueError("Analog tile has a bias, but no bias given")
if not isinstance(bias, Tensor):
bias = from_numpy(array(bias))
bias = unsqueeze(bias.clone().detach().cpu().to(d_type), 1)
return cat((weight, bias), dim=1)
# Use only the ``[out_size, in_size]`` matrix.
return weight
def _separate_weights(self, combined_weights: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
"""Helper to separate the combined weights and biases"""
# Split the internal weights (and potentially biases) matrix.
if self.analog_bias:
# combined_weights is [out_size, in_size (+ 1)].
return Tensor(combined_weights[:, :-1]), Tensor(combined_weights[:, -1])
return combined_weights, None
[docs] @no_grad()
def cpu(self) -> "SimulatorTileWrapper":
"""Return a copy of this tile in CPU memory.
Returns:
self in case of CPU
"""
if not self.is_cuda:
return self
self.is_cuda = False
self.device = torch_device("cpu")
self.analog_ctx.data = self.analog_ctx.data.cpu()
self.analog_ctx.reset(self)
return self
[docs] @no_grad()
def cuda(
self, device: Optional[Union[torch_device, str, int]] = None
) -> "SimulatorTileWrapper":
"""Return a copy of the tile in CUDA memory.
Args:
device: CUDA device
Returns:
Self with the underlying C++ tile moved to CUDA memory.
Raises:
CudaError: if the library has not been compiled with CUDA.
"""
device = torch_device("cuda", cuda_device(device).idx)
self.is_cuda = True
self.device = device
self.analog_ctx.data = self.analog_ctx.data.cuda(device)
self.analog_ctx.reset(self)
return self
[docs] def get_hidden_parameters(self) -> "OrderedDict":
"""Get the hidden parameters of the tile.
Returns:
Ordered dictionary of hidden parameter tensors.
"""
names = self.tile.get_hidden_parameter_names()
hidden_parameters = self.tile.get_hidden_parameters().detach_()
ordered_parameters = OrderedDict()
for idx, name in enumerate(names):
ordered_parameters[name] = hidden_parameters[idx].clone()
return ordered_parameters
[docs] def set_hidden_parameters(self, ordered_parameters: "OrderedDict") -> None:
"""Set the hidden parameters of the tile.
Caution:
Usually the hidden parameters are drawn according to the
parameter definitions (those given in the RPU config). If
the hidden parameters are arbitrary set by the user, then
this correspondence might be broken. This might cause problems
in the learning, in particular, the `weight granularity`
(usually ``dw_min``, depending on the device) is needed for
the dynamic adjustment of the bit length
(``update_bl_management``, see
:class:`~aihwkit.simulator.parameters.utils.UpdateParameters`).
Currently, the new ``dw_min`` parameter is tried to be
estimated from the average of hidden parameters if the
discrepancy with the ``dw_min`` from the definition is too
large.
Args:
ordered_parameters: Ordered dictionary of hidden parameter tensors.
Raises:
TileError: In case the ordered dict keys do not conform
with the current rpu config tile structure of the hidden
parameters
"""
if len(ordered_parameters) == 0:
return
hidden_parameters = stack(list(ordered_parameters.values()), dim=0)
names = self.tile.get_hidden_parameter_names()
if names != list(ordered_parameters.keys()):
raise TileError(
"Mismatch with loaded analog state: Hidden parameter structure is unexpected."
)
self.tile.set_hidden_parameters(hidden_parameters)
[docs] def set_learning_rate(self, learning_rate: Optional[float]) -> None:
"""Set the tile learning rate.
Set the tile learning rate to ``-learning_rate``. Note that the
learning rate is always taken to be negative (because of the meaning in
gradient descent) and positive learning rates are not supported.
Args:
learning_rate: the desired learning rate.
"""
raise NotImplementedError
[docs] def get_learning_rate(self) -> float:
"""Return the tile learning rate.
Returns:
float: the tile learning rate.
"""
raise NotImplementedError