# -*- coding: utf-8 -*-
# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# Licensed under the MIT license. See LICENSE file in the project root for details.
"""High level analog tiles (base)."""
# pylint: disable=too-few-public-methods, abstract-method, too-many-instance-attributes
from collections import OrderedDict
from typing import Optional, Union, Tuple, Any, Dict, List
from copy import deepcopy
from numpy import array
from numpy.typing import ArrayLike
from torch import Tensor, from_numpy, float32, unsqueeze, cat, empty, stack, dtype, zeros
from torch import device as torch_device
from torch.cuda import device as cuda_device
from torch.autograd import no_grad
from aihwkit import __version__
from aihwkit.exceptions import TileError
from aihwkit.simulator.parameters.mapping import MappingParameter
from aihwkit.simulator.parameters.base import RPUConfigGeneric
from aihwkit.simulator.parameters.runtime import RuntimeParameter
from aihwkit.simulator.parameters.enums import RPUDataType
from aihwkit.optim.context import AnalogContext
[docs]
class TileModuleBase:
    """Common base class for (logical) tile modules.

    Such tile modules can be used inside layers, for example as an
    array of TileModules.
    """
[docs]
class AnalogTileStateNames:
    """Class defining analog tile state name constants.

    Caution:
        Do *not* edit. Some names are attribute names of the tile.
    """

    VERSION = "aihwkit_version"
    WEIGHTS = "analog_tile_weights"
    HIDDEN_PARAMETERS = "analog_tile_hidden_parameters"
    HIDDEN_PARAMETER_NAMES = "analog_tile_hidden_parameter_names"
    CLASS = "analog_tile_class"
    LR = "analog_lr"
    SHARED_WEIGHTS = "shared_weights"
    CONTEXT = "analog_ctx"
    OUT_SCALING = "out_scaling_alpha"
    MAPPING_SCALES = "mapping_scales"
    RPU_CONFIG = "rpu_config"
    ANALOG_STATE_PREFIX = "analog_tile_state_"
    ANALOG_STATE_NAME = "analog_tile_state"
    EXTRA = "state_extra"

    @staticmethod
    def get_field_names() -> List[str]:
        """Return the values of all state-name constants of this class.

        Only public, string-valued class attributes are collected. The
        string check matters: this static method is itself a public
        class attribute and must not end up in a list of field names.

        Returns:
            The state field names, in definition order.
        """
        return [
            value
            for key, value in vars(AnalogTileStateNames).items()
            if not key.startswith("_") and isinstance(value, str)
        ]
[docs]
class BaseTile:
    """Interface of tile classes that do not depend on ``torch.Module``.

    All methods are abstract here and raise ``NotImplementedError``.
    """

    def joint_forward(self, x_input: Tensor, is_test: bool = False, ctx: Any = None) -> Tensor:
        """Run the joint forward pass.

        The joint pass consists of ``pre_forward``, followed by the tile
        forward and finally the ``post_forward`` step.

        Note:
            Autograd is not used for the full forward pass. All pre and
            post functions therefore have to be handled appropriately in
            the pre/post backward functions.

        Args:
            x_input: ``[N, in_size]`` tensor. If ``in_trans`` is set, transposed.
            is_test: whether to assume testing mode.
            ctx: torch auto-grad context [Optional]

        Returns:
            torch.Tensor: ``[N, out_size]`` tensor. If ``out_trans`` is set, transposed.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def backward(self, d_input: Tensor, ctx: Any = None) -> Tensor:
        """Run the backward pass.

        Args:
            d_input: ``[N, out_size]`` tensor. If ``out_trans`` is set, transposed.
            ctx: torch auto-grad context [Optional]

        Returns:
            torch.Tensor: ``[N, in_size]`` tensor. If ``in_trans`` is set, transposed.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def update(self, x_input: Tensor, d_input: Tensor) -> None:
        """Run the update pass.

        The inputs are pre-processed by the ``pre_update`` method.

        Args:
            x_input: ``[..., in_size]`` tensor. If ``in_trans`` is set, ``[in_size, ...]``.
            d_input: ``[..., out_size]`` tensor. If ``out_trans`` is set, ``[out_size, ...]``.

        Returns:
            None

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
[docs]
class SimulatorTile:
    """Minimal interface that a simulator tile has to implement.

    Note:
        This tile is generated by ``_create_simulator_tile`` in the
        ``SimulatorTileWrapper``.
    """

    def forward(
        self,
        x_input: Tensor,
        bias: bool = False,
        in_trans: bool = False,
        out_trans: bool = False,
        is_test: bool = False,
        non_blocking: bool = False,
    ) -> Tensor:
        """General forward pass of the simulator tile (abstract)."""
        raise NotImplementedError

    def backward(
        self,
        d_input: Tensor,
        bias: bool = False,
        in_trans: bool = False,
        out_trans: bool = False,
        non_blocking: bool = False,
    ) -> Tensor:
        """Backward pass (abstract).

        An implementation is only required if torch autograd is `not` used.
        """
        raise NotImplementedError

    def update(
        self,
        x_input: Tensor,
        d_input: Tensor,
        bias: bool = False,
        in_trans: bool = False,
        out_trans: bool = False,
        non_blocking: bool = False,
    ) -> Tensor:
        """Update pass (abstract).

        An implementation is only required if the torch autograd update
        is `not` used.
        """
        raise NotImplementedError

    def get_brief_info(self) -> str:
        """Return a brief info string (abstract)."""
        raise NotImplementedError

    def get_weights(self, as_ref: bool = False) -> Tensor:
        """Return the analog weights (abstract).

        Args:
            as_ref: if True, return a reference to the internal weight tensor
                (not detached, stays on the current device). If False (default),
                return a detached CPU copy. Not all tile types support true
                references; C++ tiles always return a copy regardless.
        """
        raise NotImplementedError

    def set_weights(self, weight: Tensor) -> None:
        """Set the analog weights (abstract)."""
        raise NotImplementedError

    def get_x_size(self) -> int:
        """Return the input size of the tile (abstract)."""
        raise NotImplementedError

    def get_d_size(self) -> int:
        """Return the output size of the tile (abstract)."""
        raise NotImplementedError

    def get_hidden_parameters(self) -> Tensor:
        """Return the hidden parameters of the tile.

        Returns:
            Hidden parameter tensor; by default an empty ``float32`` tensor.
        """
        no_parameters = empty(0, dtype=float32)
        return no_parameters

    def get_hidden_parameter_names(self) -> List[str]:
        """Return the names of the hidden parameters.

        Each name corresponds to a slice of the tensor returned by
        ``get_hidden_parameters``.

        Returns:
            List of names; empty by default.
        """
        no_names: List[str] = []
        return no_names

    def set_hidden_parameters(self, params: Tensor) -> None:
        """Set the hidden parameters of the tile (no-op by default)."""
        return None

    def get_learning_rate(self) -> Optional[float]:
        """Return the learning rate of the tile.

        Returns:
            The learning rate if it exists; ``None`` by default.
        """
        return None

    def set_learning_rate(self, learning_rate: Optional[float]) -> None:
        """Set the learning rate of the tile.

        No-op for tiles that do not need a learning rate.

        Args:
            learning_rate: learning rate to set
        """
        return None
# pylint: disable=too-many-public-methods
[docs]
class SimulatorTileWrapper:
    """Wrapper base class for defining the necessary tile
    functionality.

    Will be overloaded or extended for C++ tiles or for any TorchTile.

    Args:
        out_size: output size
        in_size: input size
        rpu_config: resistive processing unit configuration.
        bias: whether to add a bias column to the tile.
        in_trans: Whether to assume a transposed input (batch first)
        out_trans: Whether to assume a transposed output (batch first)
        torch_update: value stored in ``analog_ctx.use_torch_update``.
        handle_output_bound: whether the bound clamp gradient should be inserted
        ignore_analog_state: whether to ignore the analog state when __getstate__ is called

    Attributes:
        tile: A simulator tile object that handles the computations
            for the given input/output sizes.
            It is created by `self._create_simulator_tile` method,
            which is provided by the derived class.
            E.g., `aihwkit.simulator.tiles.analog.AnalogTile` and
            `aihwkit.simulator.tiles.inference_torch.TorchInferenceTile`
            implement this method.
            The weight data is stored in the tile object.
        analog_ctx: `AnalogContext`, which exposes the tile's logical
            weights as a read-only `torch.nn.Parameter` view.
        shared_weights: optional shared weights tensor memory that
            should be used. Initialized to ``None`` by the constructor.
    """
def __init__(
    self,
    out_size: int,
    in_size: int,
    rpu_config: RPUConfigGeneric,
    bias: bool = True,
    in_trans: bool = False,
    out_trans: bool = False,
    torch_update: bool = False,
    handle_output_bound: bool = False,
    ignore_analog_state: bool = False,
):
    """Store the settings, create the simulator tile and its analog context."""
    self.out_size, self.in_size = out_size, in_size
    # The tile keeps its own private copy of the configuration.
    self.rpu_config = deepcopy(rpu_config)
    self.in_trans, self.out_trans = in_trans, out_trans
    self.handle_output_bound = handle_output_bound
    self.ignore_analog_state = ignore_analog_state
    self.shared_weights = None

    # Bias handling: the bias is either digital or an extra analog column.
    mapping = rpu_config.mapping if hasattr(rpu_config, "mapping") else MappingParameter()
    self.digital_bias = bias and mapping.digital_bias
    self.use_bias = bias
    self.analog_bias = bias and not mapping.digital_bias

    # An analog bias occupies one additional input column of the tile.
    x_size = self.in_size
    if self.analog_bias:
        x_size = self.in_size + 1
    d_size = self.out_size
    self.tile = self._create_simulator_tile(x_size, d_size, rpu_config)

    # Set up zero-copy shared weight tensor for C++ tiles. This has to
    # happen before the analog context is created.
    self._shared_weight_tensor: Optional[Tensor] = None
    self._bind_shared_weights()

    self.analog_ctx = AnalogContext(self)
    self.analog_ctx.use_torch_update = torch_update
def _bind_shared_weights(self) -> None:
    """Bind a PyTorch tensor as the C++ tile's weight storage.

    For C++ tiles that expose ``set_shared_weights``, this allocates a
    contiguous tensor and passes its ``data_ptr`` to the C++ side so that
    both Python and C++ operate on the same memory. After this call
    ``tile.update()`` / ``tile.set_weights()`` modify the tensor
    in-place — no explicit sync is needed.

    For pure-Python tiles (which already store weights as
    ``torch.Tensor``), this is a no-op.

    On success the newly allocated tensor is stored in
    ``self._shared_weight_tensor``. On both early-return paths that
    attribute is left untouched.
    """
    # Tiles without the binding hook cannot share memory at all.
    if not hasattr(self.tile, "set_shared_weights"):
        return
    # Probe whether the tile is a pure-Python tile by trying to call
    # get_weights(as_ref=True).
    #
    # Tiles that accept ``as_ref``:
    #   TorchSimulatorTile      — returns self.weight.data
    #   CustomSimulatorTile     — returns self._analog_weight.data
    #   TransferSimulatorTile   — accepts but delegates to C++ tile
    #
    # C++ tiles (pybind11 bindings) do NOT accept keyword arguments
    # and raise TypeError. These are the ones that need binding:
    #   tiles.AnalogTile / CudaAnalogTile
    #   tiles.FloatingPointTile / CudaFloatingPointTile
    #   (and their half/double/bfloat16 variants)
    try:
        self.tile.get_weights(as_ref=True)
        return  # Pure-Python tile — already backed by torch.Tensor.
    except TypeError:
        pass  # C++ tile — proceed with shared weight binding below.
    d_size = self.tile.get_d_size()
    x_size = self.tile.get_x_size()
    # C++ get_weights() always returns CPU. For CUDA tiles, use the
    # private raw analog context backing when available: AIHWKIT keeps that
    # context on the same concrete device as the tile during
    # .cuda(device)/.to(device) moves. Fall back to the active CUDA context
    # only during initialization windows where analog_ctx is not CUDA-backed
    # yet.
    #
    # CUDA C++ tiles expect transposed layout (x_size, d_size) for
    # set_shared_weights, while CPU tiles expect (d_size, x_size).
    # CUDA tiles are recognized by the class name of the tile object.
    is_cuda = "Cuda" in type(self.tile).__name__
    if is_cuda:
        # Read the current weights before rebinding; they are copied into
        # the shared buffer after ``set_shared_weights`` below.
        w_cpu = self.tile.get_weights()  # (d_size, x_size) on CPU
        # Normal path after .cuda(device) / .to(device): analog_ctx has
        # already been refreshed onto the target GPU, so it is the most
        # reliable source of the tile's concrete cuda:N placement.
        if hasattr(self, "analog_ctx") and self.analog_ctx._raw_data().is_cuda:
            tile_device = self.analog_ctx._raw_data().device
        else:
            # Fallback for partial-initialization / transition windows
            # where analog_ctx does not exist yet or still points to CPU.
            # In particular, __init__() binds shared weights before
            # analog_ctx is created, so CUDA tiles must use the active
            # CUDA context established by the caller.
            tile_device = torch_device("cuda", cuda_device(None).idx)
        shared = zeros(x_size, d_size, dtype=self.get_dtype(), device=tile_device)
    else:
        tile_device = torch_device("cpu")
        shared = zeros(d_size, x_size, dtype=self.get_dtype(), device=tile_device)
    self.tile.set_shared_weights(shared)
    # CUDA set_shared_weights does not auto-populate the buffer (unlike CPU).
    # Force-sync from the tile's internal device weights into the shared tensor.
    if is_cuda:
        shared.copy_(w_cpu.t().to(tile_device))
    self._shared_weight_tensor = shared
def _get_tile_weights_ref(self) -> Tensor:
"""Get raw tile weights, preferring a reference if the tile supports it.
This returns the backing weights stored by the simulator tile. It does
not apply digital mapping/output scales. The higher-level
``AnalogContext.data`` accessor applies those scales when it presents
read-only logical weights to users.
Possible sources for the returned tensor are:
- ``self._shared_weight_tensor``: allocated and bound by
:meth:`_bind_shared_weights` in this class for C++ tiles exposing
``set_shared_weights``.
- ``self.tile.get_weights_cuda()``: native CUDA C++ binding that returns
device weights in transposed layout when available.
- ``self.tile.get_weights(as_ref=True)``: Python simulator tiles define
this reference-returning path, for example ``TorchSimulatorTile`` in
``torch_tile.py`` and ``CustomSimulatorTile`` in ``custom.py``.
- ``self.tile.get_weights()``: fallback for C++ bindings that cannot
return a reference and therefore provide only a detached copy.
"""
if self._shared_weight_tensor is not None:
# CUDA C++ tiles store shared weights in transposed layout
# (x_size, d_size). Return .t() so callers always see the
# standard (d_size, x_size) shape — still zero-copy because
# .t() on a 2-D tensor is a stride-only view.
if "Cuda" in type(self.tile).__name__:
return self._shared_weight_tensor.t()
return self._shared_weight_tensor
# Fast path: CUDA C++ tiles with native GPU weight access.
# get_weights_cuda() returns [x_size, d_size] on device; .t() gives
# the standard [d_size, x_size] view without any CPU roundtrip.
if hasattr(self.tile, "get_weights_cuda"):
return self.tile.get_weights_cuda().t()
try:
return self.tile.get_weights(as_ref=True)
except TypeError:
# C++ tile bindings don't accept as_ref
return self.tile.get_weights()
def _sync_analog_ctx_weights(self) -> None:
"""Sync the private analog context raw backing with tile weights.
With shared weight tensors, the context raw backing and the tile's
internal weights already share the same memory, so during normal
training (forward -> update) this is a no-op (same ``data_ptr``).
This method is still necessary for device moves (cpu <-> cuda): moving
the tile to a different device replaces its backing store, which
invalidates the old ``data_ptr``. Public ``analog_ctx.data`` is a
logical read-only view and must not be used for this raw rebinding.
"""
if not hasattr(self, "analog_ctx"):
return
current = self.analog_ctx._raw_data()
target_device = current.device
ref = self._get_tile_weights_ref()
if current.data_ptr() != ref.data_ptr() or current.device != ref.device:
self.analog_ctx._replace_raw_data(ref.to(target_device))
def _copy_weights_to_shared_tensors(self, weights: Tensor) -> None:
    """Copy loaded raw weights into Python-side shared backing tensors.

    ``weights`` is the canonical raw tile matrix saved in CPU layout
    ``[d_size, x_size]``. During pickle loading, Python-side
    ``shared_weights`` may be restored from a CUDA checkpoint with the CUDA
    internal layout ``[x_size, d_size]`` while the C++ tile is first
    recreated on CPU. Normalize that tensor before rebinding it to the CPU
    tile.

    Two targets are updated when present: ``self._shared_weight_tensor``
    and the ``shared_weights`` attribute. All copies run under
    ``no_grad``.

    Args:
        weights: the loaded raw weight matrix.

    Raises:
        TileError: if a target tensor matches neither the shape of
            ``weights`` nor the shape of its transpose.
    """

    def copy_to(target: Optional[Tensor]) -> None:
        # In-place copy into ``target``; transposes ``weights`` if that is
        # what makes the shapes agree. A missing target is skipped.
        if target is None:
            return
        source = weights
        if target.shape != source.shape:
            transposed = weights.t()
            if target.shape == transposed.shape:
                source = transposed
            else:
                raise TileError(
                    "Mismatch with loaded analog state: shared weight shape is unexpected."
                )
        # Device and dtype follow the target so that copy_ never has to convert.
        target.copy_(source.to(device=target.device, dtype=target.dtype))

    with no_grad():
        copy_to(self._shared_weight_tensor)
        shared_weights = getattr(self, "shared_weights", None)
        if shared_weights is not None:
            target = shared_weights.data
            # Only tiles without ``get_weights_cuda`` get their shared
            # weights re-shaped to the layout of ``weights``.
            if (
                not hasattr(self.tile, "get_weights_cuda")
                and target.shape != weights.shape
            ):
                # Replacing .data is intentional here: a CUDA checkpoint can
                # restore the Parameter with transposed CUDA layout, and
                # copy_ cannot change the target tensor shape.
                transposed = weights.t()
                if target.shape != transposed.shape:
                    raise TileError(
                        "Mismatch with loaded analog state: shared weight shape is unexpected."
                    )
                shared_weights.data = weights.to(
                    device=target.device, dtype=target.dtype
                ).clone()
            # Runs in every case, also right after the replacement above.
            copy_to(shared_weights.data)
@property
def device(self) -> torch_device:
    """Device of the tile, as reported by its analog context."""
    ctx = self.analog_ctx
    return ctx.device
@property
def is_cuda(self) -> bool:
    """The ``is_cuda`` state of the tile, as reported by its analog context."""
    ctx = self.analog_ctx
    return ctx.is_cuda
[docs]
def get_runtime(self) -> RuntimeParameter:
    """Return the runtime parameter of the RPU config.

    A default ``RuntimeParameter`` is attached to the config first if the
    config does not have a ``runtime`` attribute yet.
    """
    config = self.rpu_config
    if not hasattr(config, "runtime"):
        config.runtime = RuntimeParameter()
    return config.runtime
[docs]
def get_data_type(self) -> RPUDataType:
    """Return the ``data_type`` setting of the RPUConfig runtime."""
    runtime = self.get_runtime()
    return runtime.data_type
[docs]
def get_dtype(self) -> dtype:
    """Return the RPUConfig ``data_type`` setting converted with ``as_torch()``."""
    data_type = self.get_runtime().data_type
    return data_type.as_torch()
def _create_simulator_tile(
self, x_size: int, d_size: int, rpu_config: "RPUConfigGeneric"
) -> Any: # just use Any instead of Union["SimulatorTile", tiles.AnalogTile, ..]
"""Create a simulator tile.
Args:
x_size: input size
d_size: output size
rpu_config: resistive processing unit configuration
Returns:
a simulator tile based on the specified configuration.
"""
raise NotImplementedError
def _recreate_simulator_tile(
self, x_size: int, d_size: int, rpu_config: "RPUConfigGeneric"
) -> Any: # just use Any instead of Union["SimulatorTile", tiles.AnalogTile, ..]
"""Re-create a simulator tile in __setstate__.
Args:
x_size: input size
d_size: output size
rpu_config: resistive processing unit configuration
Returns:
a simulator tile based on the specified configuration.
"""
return self._create_simulator_tile(x_size, d_size, rpu_config)
[docs]
def get_tensor_view(self, ndim: int, dim: Optional[int] = None) -> tuple:
    """Build the view shape for an ``ndim`` vector placed at ``dim``.

    Args:
        ndim: number of dimensions
        dim: the dimension to set to -1. If not given, the first
            dimension is used when ``out_trans`` is set and the last
            one otherwise.

    Returns:
        Tuple of ones with the ``dim`` index set to -1
    """
    position = dim
    if position is None:
        position = 0 if self.out_trans else ndim - 1

    view_shape = [1 for _ in range(ndim)]
    view_shape[position] = -1
    return tuple(view_shape)
[docs]
def get_forward_out_bound(self) -> Optional[float]:
    """Return the output bound used by the AnalogFunction to correct
    the gradients.

    This base implementation has no bound and returns ``None``.
    """
    out_bound: Optional[float] = None
    return out_bound
[docs]
def set_verbosity_level(self, verbose: int) -> None:
    """Set the verbosity level (no-op in this base class).

    Args:
        verbose: level of verbosity
    """
    return None
[docs]
def get_analog_ctx(self) -> AnalogContext:
    """Return the tile's analog context, for use in ``AnalogFunction``."""
    ctx = self.analog_ctx
    return ctx
[docs]
def get_brief_info(self) -> str:
    """Return the underlying tile's short info with trailing whitespace removed."""
    info = self.tile.get_brief_info()
    return info.rstrip()
[docs]
def update(self, x_input: Tensor, d_input: Tensor) -> None:
    """Tile update, e.g. using pulse trains (abstract).

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError
[docs]
def update_indexed(self, x_input: Tensor, d_input: Tensor) -> None:
    """Indexed interface to the tile update, e.g. using pulse
    trains (abstract).

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError
[docs]
def get_analog_state(self) -> Dict:
    """Return the analog state for the state_dict.

    The state is taken from ``__getstate__``. Entries that were only added
    for pickling are dropped, so that only fields defined in
    ``AnalogTileStateNames`` remain.
    """
    state = self.__getstate__()
    known_fields = AnalogTileStateNames.get_field_names()
    # Collect the keys first: the dict must not change size while iterating.
    for key in [name for name in state if name not in known_fields]:
        state.pop(key)
    return state
def __getstate__(self) -> Dict:
    """Build the state dictionary used for pickling.

    The ``tile`` member is removed, as the binding Tiles are not
    serializable; its weights, hidden parameters, learning rate and extra
    state are stored under the ``AnalogTileStateNames`` keys instead.

    If ``ignore_analog_state`` is set, a plain copy of the attribute
    dictionary is returned without any of these changes.
    """
    # Caution: all attributes of the tile will be saved.
    state = self.__dict__.copy()
    if getattr(self, "ignore_analog_state", False):
        return state

    names = AnalogTileStateNames
    state.update(
        {
            names.WEIGHTS: self.tile.get_weights(),
            names.HIDDEN_PARAMETERS: self.tile.get_hidden_parameters().data,
            names.HIDDEN_PARAMETER_NAMES: self.tile.get_hidden_parameter_names(),
            names.CLASS: type(self).__name__,
            names.LR: self.tile.get_learning_rate(),
            names.CONTEXT: self.analog_ctx._raw_data().detach(),
            names.EXTRA: self.tile.dump_extra(),
            names.VERSION: __version__,
        }
    )

    # Entries that must not be part of the saved state:
    #   tile                  - the binding tiles are not serializable
    #   stream                - device is determined by the loading object
    #   image_sizes           - should not be saved
    #   _shared_weight_tensor - rebuilt by _bind_shared_weights()
    for transient in ("tile", "stream", "image_sizes", "_shared_weight_tensor"):
        state.pop(transient, None)
    return state
def __setstate__(self, state: Dict) -> None:
    """Set the state after unpickling.

    This method recreates the ``tile`` member, creating a new one from
    scratch, as the binding Tiles are not serializable.

    Caution:
        RPU configs are overwritten by loading the state.

    Note:
        Some RPUCuda (analog training) compounds have some extra
        internal states that should be set if checkpointing to
        continue training. To support this, extra states are
        extracted and stored. However, these are _not_ applied if
        cross-loading is done, e.g. map location is different for
        inference or tile type is changed. It will not throw any
        notice if they are not applied.

    Args:
        state: the state dictionary to load. The keys consumed here are
            the ``AnalogTileStateNames`` entries plus ``rpu_config``;
            all remaining entries are set as instance attributes.

    Raises:
        TileError: if tile class does not match or hidden parameters do not match
    """
    # pylint: disable=too-many-locals, too-many-statements, too-many-branches
    if getattr(self, "ignore_analog_state", False) or state.get("ignore_analog_state", False):
        # No analog state handling: take over the attributes as they are.
        self.__dict__.update(state)
        analog_ctx = self.analog_ctx
    else:
        SN = AnalogTileStateNames
        # Work on a copy so that the caller's dictionary stays untouched.
        current_dict = state.copy()
        tile_class = current_dict.pop(SN.CLASS, type(self).__name__)
        analog_lr = current_dict.pop(SN.LR, 0.01)
        analog_ctx = current_dict.pop(SN.CONTEXT, None)
        weights = current_dict.pop(SN.WEIGHTS)
        extra = current_dict.pop(SN.EXTRA, None)
        hidden_parameters = current_dict.pop(SN.HIDDEN_PARAMETERS)
        hidden_parameters_names = current_dict.pop(SN.HIDDEN_PARAMETER_NAMES, [])
        current_dict.pop("analog_alpha_scale", None)  # legacy
        current_dict.pop("image_sizes", None)  # should not be saved
        # legacy
        if "non_blocking" not in current_dict:
            current_dict["non_blocking"] = False
        # Check for tile mismatch
        rpu_config = current_dict.pop("rpu_config")
        if hasattr(self, "rpu_config"):
            # only for state-dict load. Might not yet be defined (in
            # case of pickle load or deepcopy)
            if not self.rpu_config.compatible_with(tile_class):
                raise TileError(
                    "Error creating tile"
                    f". Possible mismatch between {tile_class} and {type(self).__name__}"
                )
            # Need to always keep the same tile class
            rpu_config.tile_class = self.rpu_config.tile_class
        self.rpu_config = rpu_config
        self.__dict__.update(current_dict)
        # recreate attributes not saved
        # always first create on CPU
        x_size = self.in_size + 1 if self.analog_bias else self.in_size
        d_size = self.out_size
        # Recreate the tile.
        self.tile = self._recreate_simulator_tile(x_size, d_size, self.rpu_config)
        self._shared_weight_tensor = None
        self._bind_shared_weights()
        names = self.tile.get_hidden_parameter_names()
        if len(hidden_parameters_names) > 0 and names != hidden_parameters_names:
            # Check whether names match
            raise TileError(
                "Mismatch with loaded analog state: Hidden parameter structure is unexpected."
            )
        # Non-tensor values are converted to tensors via numpy.
        if not isinstance(hidden_parameters, Tensor):
            hidden_parameters = from_numpy(array(hidden_parameters))
        self.tile.set_hidden_parameters(hidden_parameters)
        if not isinstance(weights, Tensor):
            weights = from_numpy(array(weights))
        self.tile.set_weights(weights)
        self._copy_weights_to_shared_tensors(weights)
        # set_weights() fills the fresh CPU C++ tile from the serialized raw
        # weights. _copy_weights_to_shared_tensors() then makes the Python
        # Parameter hold the same CPU-layout data. ensure_shared_weights()
        # completes the hand-off by making the C++ tile use that Parameter
        # as its shared backing store, so tile weights, AnalogContext, and
        # optimizer state all observe the same storage.
        #
        # This rebinding is only valid while the recreated tile is CPU-side.
        # If the loaded object is later moved to CUDA, cuda() recreates the
        # CUDA tile and binds CUDA-layout shared weights there. The method is
        # only present on RPUCudaSimulatorTileWrapper, so keep the dynamic
        # callable guard for non-shared tile wrappers.
        shared_weights = getattr(self, "shared_weights", None)
        ensure_shared_weights = getattr(self, "ensure_shared_weights", None)
        if (
            shared_weights is not None
            and not shared_weights.is_cuda
            and callable(ensure_shared_weights)
        ):
            ensure_shared_weights()  # pylint: disable=not-callable
        if analog_lr is not None:
            self.tile.set_learning_rate(analog_lr)
        # finally set the extra stuff (without complaining if keys not
        # found. Note that these extra states are only needed for some
        # tiles (compounds) if training needs to be continued without
        # resetting counters etc.)
        if extra is not None:
            self.tile.load_extra(extra, False)
    # map location should be applied to tensors in state_dict
    # A fresh context is created in both branches; the loaded one is only
    # used below to determine the target device.
    self.analog_ctx = AnalogContext(self)
    if analog_ctx is not None:
        # Keep the object ID and device
        to_device = analog_ctx.device
        if self.analog_ctx.device != to_device:
            # aihwkit implements analog tiles in both CPU and CUDA versions,
            # e.g. FloatingPointTile(RPUSimple<float>(4, 3))
            # v.s. FloatingPointTile(RPUCudaSimple<float>(4, 3))
            # Here we need to manually convert the tile to the corresponding version
            self.to(to_device)
            # Note: `self.to(to_device)` will rebind the private raw
            # analog context backing, so no additional context copy
            # is needed.
            # self.analog_ctx = self.analog_ctx.to(to_device)
[docs]
@no_grad()
def post_update_step(self) -> None:
    """Hook for operators that have to run once per mini-batch.

    This base implementation does nothing.

    Note:
        This function is called by the analog optimizer.

    Caution:
        If no analog optimizer is used, the post update steps will
        not be performed.
    """
    return None
def _combine_weights(
self, weight: Union[Tensor, ArrayLike], bias: Optional[Union[Tensor, ArrayLike]] = None
) -> Tensor:
"""Helper to combines weights and biases
In any case, a detached cpu weight and bias copy will be returned.
Args:
weight: weights without the bias
bias: The bias vector if available
Returns:
combined weights with biases
Raises:
ValueError: if the tile has bias but ``bias`` has not been
specified.
"""
d_type = self.get_dtype()
if not isinstance(weight, Tensor):
weight = from_numpy(array(weight))
weight = weight.clone().detach().cpu().to(d_type)
shape = [self.out_size, self.in_size]
weight = weight.reshape(shape)
if self.analog_bias:
# Create a ``[out_size, in_size (+ 1)]`` matrix.
if bias is None:
raise ValueError("Analog tile has a bias, but no bias given")
if not isinstance(bias, Tensor):
bias = from_numpy(array(bias))
bias = unsqueeze(bias.clone().detach().cpu().to(d_type), 1) # type: ignore
return cat((weight, bias), dim=1)
# Use only the ``[out_size, in_size]`` matrix.
return weight
def _combine_weights_cuda(
self, weight: Union[Tensor, "ArrayLike"], bias: Optional[Union[Tensor, "ArrayLike"]] = None
) -> Tensor:
"""Like _combine_weights but keeps tensors on the tile's CUDA device.
Returns a **contiguous** ``[x_size, d_size]`` CUDA tensor in the
internal transposed layout expected by ``set_weights_cuda``.
"""
d_type = self.get_dtype()
device = self.device # the tile's CUDA device
if not isinstance(weight, Tensor):
weight = from_numpy(array(weight))
weight = weight.detach().to(dtype=d_type, device=device).reshape(
self.out_size, self.in_size
)
if self.analog_bias:
if bias is None:
raise ValueError("Analog tile has a bias, but no bias given")
if not isinstance(bias, Tensor):
bias = from_numpy(array(bias))
bias = unsqueeze(bias.detach().to(dtype=d_type, device=device), 1) # type: ignore
combined = cat((weight, bias), dim=1) # [out_size, in_size+1]
else:
combined = weight # [out_size, in_size]
# Transpose to [x_size, d_size] (the internal CUDA storage layout).
return combined.t().contiguous()
def _separate_weights(self, combined_weights: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
"""Helper to separate the combined weights and biases"""
# Split the internal weights (and potentially biases) matrix.
if self.analog_bias:
# combined_weights is [out_size, in_size (+ 1)].
return Tensor(combined_weights[:, :-1]), Tensor(combined_weights[:, -1])
return combined_weights, None
# pylint: disable=invalid-name
[docs]
def to(self, device: Union[torch_device, str, int]) -> "SimulatorTileWrapper":
    """Move the tile to a device.

    Args:
        device: target device. Besides ``torch.device`` objects, anything
            accepted by the ``torch.device`` constructor (e.g. ``"cpu"``
            or ``"cuda:0"``) is allowed, in line with :meth:`cuda`.

    Returns:
        This tile, after ``cuda(device)`` was called for CUDA devices
        or ``cpu()`` for all other devices.
    """
    # Normalize first: strings and ints have no ``type`` attribute.
    if not isinstance(device, torch_device):
        device = torch_device(device)

    if device.type == "cuda":
        self.cuda(device)
    else:
        self.cpu()
    return self
[docs]
@no_grad()
def cpu(self) -> "SimulatorTileWrapper":
    """Move this tile to CPU memory.

    Nothing is done if the tile is not on CUDA.

    Returns:
        self
    """
    if self.is_cuda:
        # Move the raw context backing first, then let the context reset
        # itself for this tile.
        ctx = self.analog_ctx
        ctx._replace_raw_data(ctx._raw_data().cpu())
        ctx.reset(self)
        # The old shared buffer is stale now: bind a new one and re-sync.
        self._shared_weight_tensor = None
        self._bind_shared_weights()
        self._sync_analog_ctx_weights()
    return self
[docs]
@no_grad()
def cuda(
    self, device: Optional[Union[torch_device, str, int]] = None
) -> "SimulatorTileWrapper":
    """Move the tile to CUDA memory.

    Args:
        device: CUDA device

    Returns:
        Self with the underlying C++ tile moved to CUDA memory.

    Raises:
        CudaError: if the library has not been compiled with CUDA.
    """
    # Resolve whatever was passed to a concrete ``cuda:<index>`` device.
    target = torch_device("cuda", cuda_device(device).idx)

    ctx = self.analog_ctx
    ctx._replace_raw_data(ctx._raw_data().cuda(target))
    ctx.reset(self)

    # The old shared buffer is stale now: bind a new one and re-sync.
    self._shared_weight_tensor = None
    self._bind_shared_weights()
    self._sync_analog_ctx_weights()
    return self
[docs]
def get_hidden_parameters(self) -> "OrderedDict":
"""Get the hidden parameters of the tile.
Returns:
Ordered dictionary of hidden parameter tensors.
"""
names = self.tile.get_hidden_parameter_names()
hidden_parameters = self.tile.get_hidden_parameters().detach_()
ordered_parameters = OrderedDict()
for idx, name in enumerate(names):
ordered_parameters[name] = hidden_parameters[idx].clone()
return ordered_parameters
[docs]
def set_hidden_parameters(self, ordered_parameters: "OrderedDict") -> None:
    """Set the hidden parameters of the tile.

    Caution:
        Usually the hidden parameters are drawn according to the
        parameter definitions (those given in the RPU config). If
        the hidden parameters are arbitrarily set by the user, then
        this correspondence might be broken. This might cause problems
        in the learning, in particular, the `weight granularity`
        (usually ``dw_min``, depending on the device) is needed for
        the dynamic adjustment of the bit length
        (``update_bl_management``, see
        :class:`~aihwkit.simulator.parameters.utils.UpdateParameters`).

        Currently, the new ``dw_min`` parameter is tried to be
        estimated from the average of hidden parameters if the
        discrepancy with the ``dw_min`` from the definition is too
        large.

    Args:
        ordered_parameters: Ordered dictionary of hidden parameter tensors.
            An empty dictionary is ignored.

    Raises:
        TileError: In case the ordered dict keys do not conform
            with the current rpu config tile structure of the hidden
            parameters
    """
    if len(ordered_parameters) == 0:
        return

    # Check the structure before stacking: with a wrong structure the
    # tensors may not be stackable at all, and the caller should then get
    # the documented TileError rather than the error raised by ``stack``.
    names = self.tile.get_hidden_parameter_names()
    if names != list(ordered_parameters.keys()):
        raise TileError(
            "Mismatch with loaded analog state: Hidden parameter structure is unexpected."
        )

    hidden_parameters = stack(list(ordered_parameters.values()), dim=0)
    self.tile.set_hidden_parameters(hidden_parameters)
[docs]
def set_learning_rate(self, learning_rate: Optional[float]) -> None:
    """Set the learning rate of the tile (abstract).

    The tile learning rate is set to ``-learning_rate``. The learning rate
    is always taken to be negative (because of its meaning in gradient
    descent); positive learning rates are not supported.

    Args:
        learning_rate: the desired learning rate.

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError
[docs]
def get_learning_rate(self) -> float:
    """Return the learning rate of the tile (abstract).

    Returns:
        float: the tile learning rate.

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError