Source code for aihwkit.optim.context

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# Licensed under the MIT license. See LICENSE file in the project root for details.

"""Parameter context for analog tiles."""

# pylint: disable=attribute-defined-outside-init

from typing import Optional, Type, Union, Any, List, TYPE_CHECKING

from torch import dtype, Tensor
from torch._C import DisableTorchFunction
from torch.nn import Parameter
from torch import device as torch_device
from torch.utils._pytree import tree_map

from aihwkit.optim.weight_view import (
    ReadOnlyWeightView,
    raise_if_readonly_write_target,
    PlaceholderDataView,
    _PLACEHOLDER_METADATA_FUNCTIONS,
    _raise_placeholder_read_error,
)
from aihwkit.simulator.parameters.enums import AnalogContextDataViewMode

if TYPE_CHECKING:
    from aihwkit.simulator.tiles.base import SimulatorTileWrapper


# Tensor properties that materialize weight values (transpose / conjugate views
# and complex parts). They must honor the data-view mode like their method
# equivalents (``t()``, ``conj()``); otherwise the getset descriptors fall
# through to ``__torch_function__`` as the whitelisted ``__get__`` and silently
# return uninitialized placeholder memory in PLACEHOLDER mode.
_VALUE_VIEW_PROPERTIES = ("T", "mT", "H", "mH", "real", "imag")


class AnalogContext(Parameter):
    """Context for analog optimizer.

    The context is a ``torch.nn.Parameter`` whose internal (raw) data is the
    weight tensor of an analog tile.

    If ``analog_bias`` (which is provided by ``analog_tile``) is False, the raw
    data has the same meaning as the data of a ``torch.nn.Parameter``.

    If ``analog_bias`` (which is provided by ``analog_tile``) is True, the last
    column of the raw data is the ``bias`` term. The public views described
    below only expose the first ``in_size`` columns, i.e. the bias column is
    sliced off.

    For diagnostic purposes, ``AnalogContext`` provides three public data view
    modes. Consider the code:

    .. code-block:: python

        layer = AnalogLinear(4, 3, bias=False, rpu_config=rpu_config)
        analog_tile = layer.analog_module
        analog_ctx = analog_tile.analog_ctx
        weight = analog_tile.get_weights()[0]

    where ``weight`` is the logical weight view, which is already
    ``physical weights x scaling``.

    Data view modes are controlled by ``analog_ctx.data_view_mode`` and the
    corresponding methods:

    .. code-block:: python

        analog_ctx.enable_placeholder()  # PLACEHOLDER mode (default)
        analog_ctx.enable_data_view()    # DATA_VIEW mode
        analog_ctx.enable_buffer()       # BUFFER mode

    * PLACEHOLDER (default): only metadata, such as ``size()`` and ``shape``,
      is available. Since the RPU conductance values are not directly
      accessible in physical hardware, the weight values, as well as
      value-based operations such as ``norm()``, are blocked by default.
      Accessing them raises ``RuntimeError``.

      .. code-block:: python

          # inspect metadata without reading values:
          analog_ctx.size()    # torch.Size([3, 4]), i.e. (out_size, in_size)
          analog_ctx.device    # device(type='cpu')
          analog_ctx.norm()    # RuntimeError

    * DATA_VIEW: exposes a read-only logical weight view through the ``data``
      attribute, which is equivalent to ``analog_tile.get_weights()[0]``. This
      allows users to inspect the effective weights.

      Since changes of both weights and scaling affect the logical weights, we
      adopt the convention that this logical view is read-only. Therefore,
      in-place operations, such as ``add_``, ``mul_``, etc., are blocked.

      .. code-block:: python

          # The following three lines will print the same value:
          analog_ctx.size()
          analog_ctx.data.size()
          weight.size()

          # Accessing values is allowed, but they are read-only:
          analog_ctx.norm()                        # Successfully returns the norm
          analog_ctx.norm() == weight.norm()       # True
          analog_ctx.add_(1.0)                     # RuntimeError

    * BUFFER: exposes a zero-initialized tensor with the logical weight shape
      through the ``data`` attribute. In that mode, ``data`` is an independent
      buffer that is not connected to the analog tile. It is intended for
      optimizers with digital auxiliary state, such as mixed-precision
      training or TT-v2. Every assignment of the BUFFER mode (including
      ``enable_buffer()``) allocates a fresh zero buffer, and leaving the
      BUFFER mode discards the buffer.

      .. code-block:: python

          # Typically False, since the buffer is independent:
          analog_ctx.norm() == weight.norm()
          # Successfully adds 1.0 to the buffer, but does not affect the
          # analog tile weights:
          analog_ctx.add_(1.0)

    To update the internal analog weights, use the following update methods
    instead of writing ``data`` directly in the analog optimizer:

    .. code-block:: python

        analog_ctx.analog_tile.update(...)
        analog_ctx.analog_tile.update_indexed(...)

    Caution:
        Even though DATA_VIEW mode allows us to access the weights directly,
        always keep in mind that it is used only for diagnostic purposes.
        To simulate the real reading, call the ``read_weights`` method
        instead, i.e. given ``analog_ctx: AnalogContext``:

        .. code-block:: python

            estimated_weights, estimated_bias = analog_ctx.analog_tile.read_weights()
    """

    def __new__(
        cls: Type["AnalogContext"],
        analog_tile: "SimulatorTileWrapper",
        parameter: Optional[Parameter] = None,
    ) -> "AnalogContext":
        """Create the context or re-class an existing parameter.

        Args:
            analog_tile: tile providing the raw weight reference.
            parameter: optional existing parameter. If given, no new tensor is
                created; the parameter object itself is turned into an
                ``AnalogContext`` by re-assigning its class.

        Returns:
            The new (or re-classed) context.
        """
        # pylint: disable=signature-differs
        if parameter is None:
            # analog_tile.tile can come from different classes:
            #    aihwkit.simulator.rpu_base.devices.AnalogTile (C++)
            #    TorchInferenceTile (Python)
            # It stores the raw tile matrix; if analog_tile.analog_bias is True,
            # the last raw column stores the bias.
            weights_ref = analog_tile._get_tile_weights_ref()
            return Parameter.__new__(
                cls,
                data=weights_ref,
                requires_grad=True,
            )
        parameter.__class__ = cls
        return parameter

    def __init__(
        self, analog_tile: "SimulatorTileWrapper", parameter: Optional[Parameter] = None
    ) -> None:  # pylint: disable=unused-argument
        """Initialize the context state and bind it to the tile.

        Args:
            analog_tile: tile this context belongs to.
            parameter: unused here; only consumed by ``__new__``.
        """
        super().__init__()
        self.analog_tile = analog_tile
        self.use_torch_update = False
        self.use_indexed = False
        self._data_view_mode = AnalogContextDataViewMode.PLACEHOLDER
        self._data_buffer: Optional[Tensor] = None
        self.analog_input: List[Any] = []
        self.analog_grad_output: List[Any] = []
        # Also registers this context on the tile (``analog_tile.analog_ctx``).
        self.reset(analog_tile)

    @classmethod
    def __torch_function__(
        cls, func: Any, _types: Any, args: Any = (), kwargs: Optional[Any] = None
    ) -> Any:
        """Dispatch torch operations according to the data view mode.

        ``requires_grad_`` is applied directly to the context. For every other
        function the in-place write check is delegated to
        ``raise_if_readonly_write_target``, then each ``AnalogContext`` found
        in ``args`` / ``kwargs`` is replaced by the tensor of its active data
        view mode before ``func`` is called.

        Args:
            func: the torch function being dispatched.
            _types: types implementing the protocol (unused).
            args: positional arguments of the call.
            kwargs: keyword arguments of the call.

        Returns:
            The result of ``func`` on the substituted arguments, or the
            context itself for ``requires_grad_``.
        """
        kwargs = kwargs or {}
        func_name = getattr(func, "__name__", "")

        if func_name == "requires_grad_" and args and isinstance(args[0], AnalogContext):
            # ``requires_grad_`` toggles the autograd flag, not weight values, so
            # the read-only in-place guard below must not reject it and the data
            # view redirection must not send it to a throwaway placeholder.
            # ``nn.Module.requires_grad_`` calls this on every parameter to
            # (un)freeze a layer, so route it straight to the real Parameter,
            # mirroring the ``requires_grad`` attribute setter.
            target = args[0]
            requested = args[1] if len(args) > 1 else kwargs.get("requires_grad", True)
            target.requires_grad = bool(requested)
            return target

        def is_readonly(value: Any) -> bool:
            # BUFFER mode exposes an independent, writable digital buffer
            # (used by mixed-precision optimizers), so in-place ops are allowed.
            if isinstance(value, AnalogContext):
                return value._get_data_view_mode() != AnalogContextDataViewMode.BUFFER
            return isinstance(value, ReadOnlyWeightView)

        raise_if_readonly_write_target(func_name, args, kwargs, is_readonly)

        def to_public_tensor(value: Any) -> Any:
            # Swap a context for the tensor of its active mode; in PLACEHOLDER
            # mode this raises for functions outside the metadata whitelist.
            if isinstance(value, AnalogContext):
                return value._torch_function_data(func_name)
            return value

        args = tree_map(to_public_tensor, args)
        kwargs = tree_map(to_public_tensor, kwargs)
        return func(*args, **kwargs)

    def __setitem__(self, key: Any, value: Any) -> None:
        """Block direct item assignment."""
        raise RuntimeError(
            "Direct item assignment on analog weights is not allowed. "
            "Use analog_tile.set_weights() instead."
        )

    def __dir__(self) -> List[str]:
        """List attribute names for interactive tab-completion.

        ``Tensor.__dir__`` dispatches through ``__torch_function__`` (this class
        defines one), which in PLACEHOLDER mode routes ``__dir__`` through the
        weight-read guard and raises. Completers (rlcompleter, IPython) swallow
        that error and show nothing. Listing the class and instance attribute
        names directly avoids the value-read dispatch, so ``analog_ctx.<tab>``
        offers the same tensor ops as a plain tensor in every data-view mode.
        """
        keys = set(dir(type(self)))
        keys.update(object.__getattribute__(self, "__dict__"))
        return sorted(keys)

    @staticmethod
    def _coerce_data_view_mode(value: Any) -> AnalogContextDataViewMode:
        """Convert public mode inputs to ``AnalogContextDataViewMode``.

        Enum members are returned unchanged. Strings are matched against the
        member values exactly and against the member names case-insensitively.

        Raises:
            ValueError: if ``value`` matches no mode.
        """
        if isinstance(value, AnalogContextDataViewMode):
            return value
        if isinstance(value, str):
            for mode in AnalogContextDataViewMode:
                if value == mode.value or value.upper() == mode.name:
                    return mode
        raise ValueError(
            "data_view_mode must be an AnalogContextDataViewMode value, "
            "or one of: placeholder, data_view, buffer."
        )

    def _get_data_view_mode(self) -> AnalogContextDataViewMode:
        """Return the active public data view mode.

        Falls back to PLACEHOLDER while the mode attribute has not been set
        yet (attribute reads can happen before ``__init__`` has assigned it).
        """
        try:
            return object.__getattribute__(self, "_data_view_mode")
        except AttributeError:
            return AnalogContextDataViewMode.PLACEHOLDER

    @property
    def data_view_mode(self) -> AnalogContextDataViewMode:
        """Return the active public data access mode."""
        return self._get_data_view_mode()

    @data_view_mode.setter
    def data_view_mode(self, value: Any) -> None:
        """Set the active public data access mode.

        Selecting BUFFER allocates a fresh zero buffer; any other mode drops
        the buffer.

        Raises:
            ValueError: if ``value`` is not a valid mode.
        """
        mode = self._coerce_data_view_mode(value)
        # Allocate before committing the mode, so that a failing allocation
        # cannot leave the context in BUFFER mode without a buffer.
        buffer = self._new_data_buffer() if mode == AnalogContextDataViewMode.BUFFER else None
        self._data_view_mode = mode
        self._data_buffer = buffer

    def enable_data_view(self) -> "AnalogContext":
        """Enable read-only logical weight reads for diagnostics."""
        self.data_view_mode = AnalogContextDataViewMode.DATA_VIEW
        return self

    def enable_placeholder(self) -> "AnalogContext":
        """Enable metadata-only placeholder mode."""
        self.data_view_mode = AnalogContextDataViewMode.PLACEHOLDER
        return self

    def enable_buffer(self) -> "AnalogContext":
        """Enable an independent zero-initialized digital data buffer."""
        self.data_view_mode = AnalogContextDataViewMode.BUFFER
        return self

    def _raw_data(self) -> Tensor:
        """Return the internal raw tile backing tensor."""
        # Bypass both the ``data`` interception of ``__getattribute__`` and
        # the ``__torch_function__`` dispatch.
        with DisableTorchFunction():  # pylint: disable=not-context-manager
            return super().__getattribute__("data")

    def _logical_shape(self) -> Any:
        """Return the logical public weight shape without reading values."""
        raw = self._raw_data()
        try:
            analog_tile = object.__getattribute__(self, "analog_tile")
        except AttributeError:
            # No tile bound (yet): the raw shape is all there is.
            return raw.shape

        if getattr(analog_tile, "analog_bias", False) and raw.dim() >= 2:
            # Drop the trailing bias column.
            return raw[:, : analog_tile.in_size].shape
        return raw.shape

    def _logical_data(self) -> Tensor:
        """Return logical weights equivalent to ``analog_tile.get_weights()[0]``."""
        raw = self._raw_data()
        try:
            analog_tile = object.__getattribute__(self, "analog_tile")
        except AttributeError:
            return raw

        logical = raw
        if getattr(analog_tile, "analog_bias", False) and raw.dim() >= 2:
            # Drop the trailing bias column.
            logical = raw[:, : analog_tile.in_size]

        get_scales = getattr(analog_tile, "get_scales", None)
        if get_scales is None:
            return logical
        scales = get_scales()
        if scales is None:
            return logical
        scales = scales.to(device=logical.device, dtype=logical.dtype)
        # Scales are applied as a column vector, i.e. one factor per row.
        return logical * scales.view(-1, 1)

    def _placeholder_data(self) -> PlaceholderDataView:
        """Return a metadata-only public placeholder."""
        return PlaceholderDataView(self._raw_data().new_empty(self._logical_shape()))

    def _new_data_buffer(self) -> Tensor:
        """Create a zero digital buffer with the logical public weight shape."""
        return self._raw_data().new_zeros(self._logical_shape())

    def _buffer_data(self) -> Tensor:
        """Return the independent digital buffer."""
        return object.__getattribute__(self, "_data_buffer")

    def _public_data(self) -> Tensor:
        """Return the tensor exposed by the active public data mode.

        Raises:
            RuntimeError: if the active mode is not one of the known modes.
        """
        mode = self._get_data_view_mode()
        if mode == AnalogContextDataViewMode.PLACEHOLDER:
            return self._placeholder_data()
        if mode == AnalogContextDataViewMode.DATA_VIEW:
            return ReadOnlyWeightView(self._logical_data())
        if mode == AnalogContextDataViewMode.BUFFER:
            return self._buffer_data()
        raise RuntimeError(f"Unsupported AnalogContext data view mode: {mode}")

    def _torch_function_data(self, func_name: str) -> Tensor:
        """Return the tensor used to dispatch public torch operations.

        Args:
            func_name: name of the torch function being dispatched; in
                PLACEHOLDER mode only names listed in
                ``_PLACEHOLDER_METADATA_FUNCTIONS`` are let through.

        Raises:
            RuntimeError: if the active mode is not one of the known modes.
        """
        mode = self._get_data_view_mode()
        if mode == AnalogContextDataViewMode.PLACEHOLDER:
            if func_name not in _PLACEHOLDER_METADATA_FUNCTIONS:
                _raise_placeholder_read_error(func_name)
            return self._placeholder_data()
        if mode == AnalogContextDataViewMode.DATA_VIEW:
            # Plain logical tensor here (not wrapped in ``ReadOnlyWeightView``).
            return self._logical_data()
        if mode == AnalogContextDataViewMode.BUFFER:
            return self._buffer_data()
        raise RuntimeError(f"Unsupported AnalogContext data view mode: {mode}")

    def __getattribute__(self, name: str) -> Any:
        """Intercept public tensor reads according to ``data_view_mode``."""
        if name == "grad_fn":
            return None
        if name in ("device", "dtype", "is_cuda", "is_leaf", "layout"):
            # Pure metadata: always answered by the raw backing tensor.
            return getattr(self._raw_data(), name)
        if name == "requires_grad":
            with DisableTorchFunction():  # pylint: disable=not-context-manager
                return self.as_subclass(Tensor).requires_grad
        if name == "grad":
            # Mixed-precision optimizers SET param.grad (mpmixin.prepare_grad) and
            # torch optimizers READ it; both must hit the Parameter's own grad slot.
            # DisableTorchFunction bypasses the data-view dispatch that would
            # otherwise redirect the read to the raw data view and return None.
            with DisableTorchFunction():  # pylint: disable=not-context-manager
                return super().__getattribute__("grad")
        if name == "data":
            return self._public_data()
        if name == "shape":
            return self._logical_shape()
        if name == "ndim":
            return len(self._logical_shape())
        if name in _VALUE_VIEW_PROPERTIES:
            return self._value_view_property(name)
        return super().__getattribute__(name)

    def _value_view_property(self, name: str) -> Any:
        """Return a value-bearing view property honoring the data-view mode.

        ``T`` / ``mT`` / ``H`` / ``mH`` / ``real`` / ``imag`` read weight values,
        so they follow the same rules as ``t()`` / ``conj()``: blocked in
        PLACEHOLDER mode, served from the read-only logical view in DATA_VIEW
        mode, and from the digital buffer in BUFFER mode.
        """
        if self._get_data_view_mode() == AnalogContextDataViewMode.PLACEHOLDER:
            _raise_placeholder_read_error(name)
        return getattr(self._public_data(), name)

    def __setattr__(self, name: str, value: Any) -> None:
        """Block user-level replacement of ``.data``.

        Raises:
            RuntimeError: if ``name`` is ``"data"``.
        """
        if name == "data":
            raise RuntimeError(
                "Direct replacement of analog_ctx.data is not allowed. "
                "Use analog_tile.set_weights(new_weight) for programmatic writes."
            )
        if name in ("grad", "requires_grad"):
            # Both must bypass the data-view dispatch: ``grad`` to hit the
            # Parameter's own grad slot, and ``requires_grad`` because its
            # setter would otherwise be routed through ``__torch_function__``
            # (as ``__set__``) and raise in PLACEHOLDER mode. Toggling
            # ``requires_grad`` is how analog layers are frozen/unfrozen.
            with DisableTorchFunction():  # pylint: disable=not-context-manager
                super().__setattr__(name, value)
            return
        super().__setattr__(name, value)

    def _replace_raw_data(self, data: Tensor) -> None:
        """Replace the internal raw ``Parameter.data`` for tile rebinding.

        In BUFFER mode the digital buffer is kept (moved to the device and
        dtype of ``data``) when its shape still matches the new logical shape,
        and re-created as zeros otherwise.

        Args:
            data: the new raw backing tensor.
        """
        if isinstance(data, (ReadOnlyWeightView, PlaceholderDataView)):
            # Never store one of the public wrapper types as raw data.
            data = data.as_subclass(Tensor)
        with DisableTorchFunction():  # pylint: disable=not-context-manager
            # Calls the parent setter directly; the ``data`` guard of
            # ``__setattr__`` above only applies to user-level assignment.
            super().__setattr__("data", data)

        if self._get_data_view_mode() != AnalogContextDataViewMode.BUFFER:
            return

        buffer = object.__getattribute__(self, "_data_buffer")
        logical_shape = self._logical_shape()
        if buffer is not None and buffer.shape == logical_shape:
            self._data_buffer = buffer.to(device=data.device, dtype=data.dtype)
        else:
            self._data_buffer = self._new_data_buffer()

    # -- existing API ---------------------------------------------------------

    def set_indexed(self, value: bool = True) -> None:
        """Set the context to forward_indexed."""
        self.use_indexed = value

    def get_data(self) -> Tensor:
        """Get a detached tensor from the active public data view.

        Raises:
            RuntimeError: in PLACEHOLDER mode, where values are not readable.
        """
        if self._get_data_view_mode() == AnalogContextDataViewMode.PLACEHOLDER:
            _raise_placeholder_read_error("get_data")
        return self.data.detach()

    def reset(self, analog_tile: Optional["SimulatorTileWrapper"] = None) -> None:
        """Reset the gradient trace and optionally sets the tile pointer.

        Args:
            analog_tile: if given, becomes the tile of this context, and this
                context is registered on it as ``analog_tile.analog_ctx``.
        """
        if analog_tile is not None:
            self.analog_tile = analog_tile
            self.analog_tile.analog_ctx = self

        self.analog_input = []
        self.analog_grad_output = []

    def has_gradient(self) -> bool:
        """Return whether a gradient trace was stored."""
        return bool(self.analog_input)

    def __copy__(self) -> Parameter:
        """Turn off copying of the pointers.

        Context will be re-created when tile is created. A plain ``Parameter``
        built from the raw tile tensor is returned instead of a context.
        """
        return Parameter(self._raw_data())

    def __deepcopy__(self, memo: Any) -> Parameter:
        """Turn off deep copying.

        Context will be re-created when tile is created. A plain ``Parameter``
        built from the raw tile tensor is returned instead of a context.
        """
        return Parameter(self._raw_data())

    def cuda(self, device: Optional[Union[torch_device, str, int]] = None) -> "AnalogContext":
        """Move the context to a cuda device.

        Nothing happens if the analog tile is already on CUDA.

        Args:
            device: the desired device of the tile.

        Returns:
            This context in the specified device.
        """
        if not self.analog_tile.is_cuda:
            # The raw data is re-bound to the weight reference of the current
            # tile before that tile is moved.
            # NOTE(review): whether the moved tile re-binds the context data
            # once more is not visible from here - confirm against the tile.
            self._replace_raw_data(self.analog_tile._get_tile_weights_ref())
            self.analog_tile = self.analog_tile.cuda(device)
            self.reset(self.analog_tile)
        return self

    def cpu(self) -> "AnalogContext":
        """Move the context to CPU.

        The raw backing tensor is always replaced by its CPU version; the
        analog tile is only moved (and the context reset) if it is on CUDA.

        Returns:
            self
        """
        self._replace_raw_data(self._raw_data().cpu())
        if self.analog_tile is not None and self.analog_tile.is_cuda:
            self.analog_tile = self.analog_tile.cpu()
            self.reset(self.analog_tile)
        return self

    def to(self, *args: Any, **kwargs: Any) -> "AnalogContext":
        """Move analog tiles of the current context to a device.

        The arguments are first applied to the raw backing tensor with
        ``Tensor.to``. A target device is then taken from the ``device``
        keyword, or from the first positional argument unless that is a
        ``Tensor`` or a ``dtype``. A CUDA target on a CPU tile is forwarded to
        :meth:`cuda`, a CPU target on a CUDA tile to :meth:`cpu`.

        Note:
            NOTE(review): an earlier version of this note stated that moving
            analog tiles from GPU to CPU is not supported; the code does call
            ``cpu()`` for that case - confirm against the tile implementation.

        Caution:
            Only the device is forwarded to the analog tile. Other tensor
            conversions, such as changing the data type, only affect the raw
            backing tensor of the context and are not forwarded to the tile.

        Returns:
            This context in the specified device.
        """
        # pylint: disable=invalid-name
        self._replace_raw_data(self._raw_data().to(*args, **kwargs))
        device = None
        if "device" in kwargs:
            device = kwargs["device"]
        elif len(args) > 0 and not isinstance(args[0], (Tensor, dtype)):
            device = torch_device(args[0])

        if device is not None:
            device = torch_device(device)
            if device.type == "cuda" and not self.analog_tile.is_cuda:
                self.cuda(device)
            elif device.type == "cpu" and self.analog_tile.is_cuda:
                self.cpu()
        return self

    def __repr__(self) -> str:
        """Return a short description based on the tile's brief info."""
        return "AnalogContext of " + self.analog_tile.get_brief_info()