# Source code for aihwkit.simulator.configs.compounds

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

# pylint: disable=too-many-instance-attributes, too-many-lines

"""Compound configuration for Analog (Resistive Device) tiles."""

from copy import deepcopy
from dataclasses import dataclass, field
from typing import ClassVar, List, Union, Any, TYPE_CHECKING
from warnings import warn

from aihwkit.exceptions import ConfigError
from aihwkit.simulator.parameters.helpers import _PrintableMixin, parameters_to_bindings
from aihwkit.simulator.parameters.training import UpdateParameters
from aihwkit.simulator.parameters.io import IOParameters
from aihwkit.simulator.parameters.enums import VectorUnitCellUpdatePolicy, RPUDataType

if TYPE_CHECKING:
    from aihwkit.simulator.configs.devices import PulsedDevice


@dataclass
class UnitCell(_PrintableMixin):
    """Parameters that modify the behaviour of a unit cell."""

    bindings_class: ClassVar[str] = "VectorResistiveDeviceParameter"
    bindings_module: ClassVar[str] = "devices"
    bindings_typed: ClassVar[bool] = True
    bindings_ignore: ClassVar[List] = ["diffusion", "lifetime"]

    unit_cell_devices: List = field(default_factory=list)
    """Devices that compose this unit cell."""

    construction_seed: int = 0
    """If not ``0``, set a unique seed for hidden parameters during construction.

    Applies to all ``unit_cell_devices``.
    """

    def as_bindings(self, data_type: RPUDataType) -> Any:
        """Return a representation of this instance as a simulator bindings
        object of a given data type.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def requires_diffusion(self) -> bool:
        """Return whether any constituent device has diffusion enabled."""
        return any(dev.requires_diffusion() for dev in self.unit_cell_devices)

    def requires_decay(self) -> bool:
        """Return whether any constituent device has decay enabled."""
        return any(dev.requires_decay() for dev in self.unit_cell_devices)
###############################################################################
# Specific devices based on ``unit cell``.
###############################################################################
@dataclass
class VectorUnitCell(UnitCell):
    """Abstract resistive device that combines multiple pulsed resistive
    devices in a single 'unit cell'.

    For instance, a vector device can consist of 2 resistive devices
    where the sum of the two resistive values are coded for each weight
    of a cross point.
    """

    bindings_class: ClassVar[str] = "VectorResistiveDeviceParameter"

    update_policy: VectorUnitCellUpdatePolicy = VectorUnitCellUpdatePolicy.ALL
    """The update policy of which of the devices will be receiving the
    update of a mini-batch."""

    first_update_idx: int = 0
    """Device that receives the first mini-batch.

    Useful only for ``VectorUnitCellUpdatePolicy.SINGLE_FIXED``.
    """

    gamma_vec: List[float] = field(default_factory=list, metadata={"hide_if": []})
    """Weighting of the unit cell devices to reduce to final weight.

    User-defined weighting can be given as a list of factors. If not
    given, each device index of the unit cell is weighted by equal
    amounts (:math:`1/n`).
    """

    def as_bindings(self, data_type: RPUDataType) -> Any:
        """Return a representation of this instance as a simulator bindings object.

        Raises:
            ConfigError: if ``unit_cell_devices`` is not a list or a
                device parameter could not be appended.
        """
        vector_parameters = parameters_to_bindings(self, data_type)
        if not isinstance(self.unit_cell_devices, list):
            raise ConfigError("unit_cell_devices should be a list of devices")
        for param in self.unit_cell_devices:
            device_parameters = param.as_bindings(data_type)
            if not vector_parameters.append_parameter(device_parameters):
                raise ConfigError("Could not add unit cell device parameter")
        return vector_parameters
@dataclass
class ReferenceUnitCell(UnitCell):
    """Abstract device model that takes two arbitrary devices per
    cross-point and implements a device with a reference pair.

    The update will only be on the 0-th device whereas the other will
    stay fixed. The resulting effective weight is the difference of the
    two.

    Note:
        Exactly 2 devices are used: if more are given, they are
        discarded; if only one is given, the same device will be used
        twice.

    Note:
        The reference device weights will all be zero on default. To set
        the reference device with a particular value one can select the
        device update index::

            analog_tile.set_hidden_update_index(1)
            analog_tile.set_weights(W)
            analog_tile.set_hidden_update_index(0)  # set back to 0 for the following updates
    """

    bindings_class: ClassVar[str] = "VectorResistiveDeviceParameter"

    update_policy: VectorUnitCellUpdatePolicy = VectorUnitCellUpdatePolicy.SINGLE_FIXED
    """The update policy of which of the devices will be receiving the
    update of a mini-batch.

    Caution:
        This parameter should be kept to SINGLE_FIXED for this device.
    """

    first_update_idx: int = 0
    """Device that receives the update."""

    gamma_vec: List[float] = field(
        default_factory=lambda: [1.0, -1.0], metadata={"hide_if": [1.0, -1.0]}
    )
    """Weighting of the unit cell devices to reduce to final weight.

    Note:
        While user-defined weighting can be given it is suggested to
        keep it to the default ``[1, -1]`` to implement the reference
        device subtraction.
    """

    def as_bindings(self, data_type: RPUDataType) -> Any:
        """Return a representation of this instance as a simulator bindings object.

        Normalizes ``unit_cell_devices`` to exactly two entries (see
        class docstring) before building the bindings.

        Raises:
            ConfigError: if ``unit_cell_devices`` is not a list, is
                empty, or a device parameter could not be appended.
        """
        vector_parameters = parameters_to_bindings(self, data_type)
        if not isinstance(self.unit_cell_devices, list):
            raise ConfigError("unit_cell_devices should be a list of devices")
        if len(self.unit_cell_devices) > 2:
            # discard extra devices
            self.unit_cell_devices = self.unit_cell_devices[:2]
        elif len(self.unit_cell_devices) == 1:
            # duplicate the single device for the reference pair
            self.unit_cell_devices = [
                self.unit_cell_devices[0],
                deepcopy(self.unit_cell_devices[0]),
            ]
        elif len(self.unit_cell_devices) != 2:
            raise ConfigError("ReferenceUnitCell expects two unit_cell_devices")
        for param in self.unit_cell_devices:
            device_parameters = param.as_bindings(data_type)
            if not vector_parameters.append_parameter(device_parameters):
                raise ConfigError("Could not add unit cell device parameter")
        return vector_parameters
@dataclass
class OneSidedUnitCell(UnitCell):
    """Abstract device model that takes an arbitrary device per
    crosspoint and implements an explicit plus-minus device pair with
    one-sided update.

    One device will receive all positive updates and the other all
    negative updates. Since the devices will quickly saturate, the
    device implements a refresh strategy.

    With fixed frequency per update call (``refresh_every``, in units of
    single vector updates) a refresh is performed. During the refresh,
    each column will be read using a forward pass (parameters are
    specified with ``refresh_forward``) to read out the positive and
    negative device weights.

    Whether a weight needs refreshing is determined by the following
    criterion: The larger weight (normalized by the tile-wise fixed
    w_max setting) is tested against the upper threshold. If larger than
    the upper threshold, and the normalized lower weight is larger than
    the lower threshold, then a reset and rewriting will be performed.

    Note that this abstract device needs single devices that are derived
    from :class:`~PulsedDevice`. The reset properties (bias and
    cycle-to-cycle noise) can be thus adjusted (see
    :class:`~PulsedDevice`).

    The rewriting of the computed difference is only done onto one of
    the two devices using the update properties defined in
    ``refresh_update``.

    Note:
        This device will take only the first ``unit_cell_device`` to
        generate two devices. Both positive and negative device will
        thus have the same (reversed) parameters, e.g. the specified
        ``w_min``, will become the w_max of the negative device.
    """

    bindings_class: ClassVar[str] = "OneSidedResistiveDeviceParameter"

    refresh_every: int = 0
    """How often a refresh is performed (in units of the number of
    vector updates).

    Note:
        If a refresh is done, full reads of both positive and negative
        devices are performed. Additionally, if single devices deemed to
        be refreshed, an (open-loop) re-write is done (once per column).
        Thus, refresh might have considerable runtime impacts.
    """

    units_in_mbatch: bool = True
    """If set, the ``refresh_every`` counter is given in ``m_batch``
    which is the re-use factor. Smaller numbers are not possible.

    Caution:
        For CUDA devices, refresh is always done in ``m_batch`` (ie the
        number of re-use per layer for a mini-batch). Smaller numbers
        will have no effect.
    """

    refresh_upper_thres: float = 0.75
    """Upper threshold for determining the refresh, see above."""

    refresh_lower_thres: float = 0.25
    """Lower threshold for determining the refresh, see above."""

    refresh_forward: IOParameters = field(default_factory=IOParameters)
    """Input-output parameters that define the read during a refresh event.

    :class:`~aihwkit.simulator.configs.IOParameters` that define the
    read (forward) of an refresh event. For instance the amount of noise
    or whether refresh is done using a ADC/DAC etc.
    """

    refresh_update: UpdateParameters = field(default_factory=UpdateParameters)
    """Update parameters that define the type of update used for each
    refresh event.

    Update parameters :class:`~aihwkit.simulator.configs.UpdateParameters`
    that define the type of update used for each refresh event.
    """

    copy_inverted: bool = False
    """Whether the use the "down" update behavior of the first device
    for the negative updates instead of the positive half of the second
    device."""

    def as_bindings(self, data_type: RPUDataType) -> Any:
        """Return a representation of this instance as a simulator bindings object.

        Raises:
            ConfigError: if ``unit_cell_devices`` is not a list, has an
                unsupported length, or a device parameter could not be
                appended (both devices need to be of the same type).
        """
        if not isinstance(self.unit_cell_devices, list):
            raise ConfigError("unit_cell_devices should be a list of devices")

        # Validate the length BEFORE indexing, so an empty list raises a
        # ConfigError rather than an IndexError.
        if len(self.unit_cell_devices) == 0 or len(self.unit_cell_devices) > 2:
            raise ConfigError("Need 1 or 2 unit_cell_devices")

        onesided_parameters = parameters_to_bindings(self, data_type)
        device_parameter0 = self.unit_cell_devices[0].as_bindings(data_type)

        if len(self.unit_cell_devices) == 1:
            device_parameter1 = device_parameter0
        else:
            device_parameter1 = self.unit_cell_devices[1].as_bindings(data_type)

        # need to be exactly 2 and same parameters
        if not onesided_parameters.append_parameter(device_parameter0):
            raise ConfigError("Could not add unit cell device parameter")

        if not onesided_parameters.append_parameter(device_parameter1):
            raise ConfigError(
                "Could not add unit cell device parameter "
                + "(both devices need to be of the same type)"
            )

        return onesided_parameters
@dataclass
class DifferenceUnitCell(OneSidedUnitCell):
    """Deprecated alias to ``OneSidedUnitCell``."""

    def __post_init__(self) -> None:
        # Fixed from ``__post__init__`` (double underscore typo): the
        # dataclass machinery only calls ``__post_init__``, so the
        # deprecation warning was never emitted.
        warn(
            "The DifferenceUnitCell class is deprecated. Please use OneSidedUnitCell instead.",
            DeprecationWarning,
        )
@dataclass
class TransferCompound(UnitCell):
    r"""Abstract device model that takes 2 or more devices and implements
    a transfer-based learning rule.

    It uses a (partly) hidden weight (where the SGD update is
    accumulated), which then is transferred partly and occasionally to
    the visible weight. This can implement an analog friendly variant of
    stochastic gradient descent (Tiki-taka), as described in
    `Gokmen & Haensch (2020)`_.

    The hidden weight is always the first in the list of
    ``unit_cell_devices`` given, and the transfer is done from left to
    right. The first of the ``unit_cell_devices`` can have different HW
    specifications from the rest, but the others need to be of identical
    specs. In detail, when specifying the list of devices only the first
    two will actually be used and the rest discarded and instead
    replaced by the second device specification. In this manner, the
    *fast* crossbar (receiving the SGD updates) and the *slow* crossbar
    (receiving the occasional partial transfers from the fast) can have
    different specs, but all additional slow crossbars (receiving
    transfers from the left neighboring crossbar in the list of
    ``unit_cell_devices``) need to be of the same spec.

    The rate of transfer (e.g. learning rate and how often and how many
    columns/rows per transfer) and the type (ie. with ADC or without,
    with noise etc.) can be adjusted.

    Each transfer event that is triggered by counting the update cycles
    (in units of either mini-batch or single mat-vecs),
    ``n_reads_per_transfer`` columns/rows are read from the left device
    using the forward pass with transfer vectors as input and
    transferred to the right (taking the order of the
    ``unit_cell_devices`` list) using the outer-product update with the
    read-out vectors and the transfer vectors. Currently, transfer
    vectors are fixed to be one-hot vectors. The columns/rows to take
    are in sequential order and warped around at the edge of the
    crossbar. The learning rate and forward and update specs of the
    transfer can be user-defined.

    The weight that is seen in the forward and backward pass is governed
    by the :math:`\gamma` weightening setting.

    Note:
        Here the devices could be either transferred in analog
        (essentially within the unit cell) or on separate arrays (using
        the usual (non-ideal) forward pass and update steps. This can be
        set with ``transfer_forward`` and ``transfer_update``.

    .. _Gokmen & Haensch (2020): https://www.frontiersin.org/articles/10.3389/fnins.2020.00103/full
    """

    bindings_class: ClassVar[str] = "TransferResistiveDeviceParameter"

    gamma: float = 0.0
    r"""Weighting factor to compute the effective SGD weight from the
    hidden matrices.

    The default scheme is:

    .. math:: g^{n-1} W_0 + g^{n-2} W_1 + \ldots + g^0 W_{n-1}
    """

    gamma_vec: List[float] = field(default_factory=list, metadata={"hide_if": []})
    """User-defined weightening.

    User-defined weightening can be given as a list of weights in which
    case the default weightening scheme with ``gamma`` is not used.
    """

    transfer_every: float = 1.0
    """Transfers every :math:`n` mat-vec operations or :math:`n` batches.

    Transfers every :math:`n` mat-vec operations (rounded to
    multiples/ratios of ``m_batch`` for CUDA). If ``units_in_mbatch`` is
    set, then the units are in ``m_batch`` instead of mat-vecs, which is
    equal to the overall the weight re-use during a while mini-batch.

    Note:
        If ``transfer_every`` is 0.0 *no transfer* will be made.

    If not given explicitely with ``transfer_every_vec``, then the
    higher transfer cycles are geometrically scaled, the first is set to
    transfer_every. Each next transfer cycle is multiplied by ``x_size /
    n_reads_per_transfer``.
    """

    no_self_transfer: bool = True
    """Whether to set the transfer rate of the last device (which is
    applied to itself) to zero."""

    transfer_every_vec: List[float] = field(default_factory=list, metadata={"hide_if": []})
    """Transfer cycles lengths.

    A list of :math:`n` entries, to explicitly set the transfer cycles
    lengths. In this case, the above defaults are ignored.
    """

    units_in_mbatch: bool = True
    """Units for ``transfer_every``.

    If set, then the cycle length units of ``transfer_every`` are in
    ``m_batch`` instead of mat-vecs, which is equal to the overall of
    the weight re-use during a while mini-batch.
    """

    n_reads_per_transfer: int = 1
    """Number of consecutive reads to use during transfer events.

    How many consecutive columns or rows to read (from one tile) and
    write (to the next tile) every transfer event. For read, the input
    is a 1-hot vector. Once the final columns or row is reached, reading
    starts again from the first.
    """

    transfer_columns: bool = True
    """Whether to read and transfer columns or rows.

    If set, read is done with an additional forward pass determined by
    the ``transfer_forward`` settings. If not set, rows are transferred
    instead, that is, the read is done internally with a backward pass
    instead. However, the parameters defining the backward are still
    given by setting the ``transfer_forward`` field for convenience.
    """

    with_reset_prob: float = 0.0
    """Whether to apply reset of the columns that were transferred with
    a given probability.

    Note:
        Reset is only available in case of column reads
        (``transfer_columns==True``).
    """

    random_selection: bool = False
    """Whether to select a random starting column or row.

    Whether to select a random starting column or row for each transfer
    event and not take the next column or row that was previously not
    transferred as a starting column or row (the default).
    """

    fast_lr: float = 1.0
    """Whether to set the `fast` tile's learning rate.

    If set, then the SGD gradient update onto the first (fast) tile is
    set to this learning rate and is kept constant even when the SGD
    learning rate is scheduled. The SGD learning rate is then only used
    to scale the transfer LR (see ``scale_transfer_lr``).
    """

    transfer_lr: float = 1.0
    """Learning rate (LR) for the update step of the transfer event.

    Per default all learning rates are identical. If
    ``scale_transfer_lr`` is set, the transfer LR is scaled by current
    learning rate of the SGD.

    Note:
        LR is always a positive number, sign will be correctly applied
        internally.
    """

    transfer_lr_vec: List[float] = field(default_factory=list, metadata={"hide_if": []})
    """Transfer LR for each individual transfer in the device chain can
    be given."""

    scale_transfer_lr: bool = True
    """Whether to give the transfer_lr in relative units.

    ie. whether to scale the transfer LR with the current LR of the SGD.
    """

    transfer_forward: IOParameters = field(default_factory=IOParameters)
    """Input-output parameters that define the read of a transfer event.

    :class:`~aihwkit.simulator.configs.IOParameters` that define the
    read (forward or backward) of an transfer event. For instance the
    amount of noise or whether transfer is done using a ADC/DAC etc.
    """

    transfer_update: UpdateParameters = field(default_factory=UpdateParameters)
    """Update parameters that define the type of update used for each
    transfer event.

    Update parameters :class:`~aihwkit.simulator.configs.UpdateParameters`
    that define the type of update used for each transfer event.
    """

    def as_bindings(self, data_type: RPUDataType) -> Any:
        """Return a representation of this instance as a simulator bindings object.

        Raises:
            ConfigError: if ``unit_cell_devices`` is not a list, has
                fewer than 2 entries, or a device parameter could not be
                appended.
        """
        if not isinstance(self.unit_cell_devices, list):
            raise ConfigError("unit_cell_devices should be a list of devices")

        n_devices = len(self.unit_cell_devices)
        # Guard before indexing below: fewer than 2 devices would raise
        # a bare IndexError instead of a ConfigError.
        if n_devices < 2:
            raise ConfigError("TransferCompound needs at least 2 unit_cell_devices")

        transfer_parameters = parameters_to_bindings(self, data_type)

        param_fast = self.unit_cell_devices[0].as_bindings(data_type)
        param_slow = self.unit_cell_devices[1].as_bindings(data_type)

        if not transfer_parameters.append_parameter(param_fast):
            raise ConfigError("Could not add unit cell device parameter")

        # All slow crossbars share the second device's specification.
        for _ in range(n_devices - 1):
            if not transfer_parameters.append_parameter(param_slow):
                raise ConfigError("Could not add unit cell device parameter")

        return transfer_parameters
@dataclass
class BufferedTransferCompound(TransferCompound):
    r"""Abstract device model that takes 2 or more devices and implements
    a buffered transfer-based learning rule.

    Different to :class:`TransferCompound`, however, readout is done
    first onto a digital buffer (in floating point precision), from
    which then the second analog matrix is updated. This second step is
    very similar to the analog update in
    :class:`MixedPrecisionCompound`.

    Note, however, that in contrast to :class:`MixedPrecisionCompound`
    the rank-update is still done in analog with parallel update using
    pulse trains.

    The buffer is assumed to be in floating point precision and only one
    row/column at a time needs to be processed in one update cycle, thus
    greatly reducing on-chip memory requirements.

    For details, see `Gokmen (2021)`_.

    .. _Gokmen (2021): https://www.frontiersin.org/articles/10.3389/frai.2021.699148/full
    """

    bindings_class: ClassVar[str] = "BufferedTransferResistiveDeviceParameter"

    thres_scale: float = 1.0
    """Threshold scale for buffer to determine whether to transfer to
    next device. Will be multiplied by the device granularity to get the
    threshold.
    """

    step: float = 1.0
    """Value to fill the ``d`` vector for the update if buffered value
    is above threshold.
    """

    momentum: float = 0.1
    """Momentum of the buffer.

    After transfer, this momentum fraction stays on the buffer instead
    of subtracting all of what was transferred.
    """

    forget_buffer: bool = True
    """Whether to forget the value of the buffer after transfer.

    If enabled, the buffer is reset to the momentum times the
    transferred value. Thus, if the number of pulses is limited to e.g.
    1 (``desired_BL`` in the ``transfer_update``) the transfer might be
    clipped and the potentially larger buffer values are forgotten.

    If disabled, then the buffer values are faithfully subtracted by the
    amount transferred (times one minus momentum).
    """

    transfer_update: UpdateParameters = field(
        default_factory=lambda: UpdateParameters(
            desired_bl=1, update_bl_management=False, update_management=False
        )
    )
    """Update parameters that define the type of update used for each
    transfer event.

    Update parameters :class:`~aihwkit.simulator.configs.UpdateParameters`
    that define the type of update used for each transfer event.
    """
@dataclass
class ChoppedTransferCompound(TransferCompound):
    r"""Abstract device model that takes exactly two devices and
    implements a chopped and buffered transfer-based learning rule.

    Similar to :class:`BufferedTransferCompound`, however, the gradient
    update onto the fast tile is done with `choppers`, that is random
    sign changes. These sign changes reduce any potential bias or
    long-term correlation that might be present on the fast device, as
    gradients are written in both directions and signs recovered after
    readout (averaging out any existing correlations during the update
    on the devices).

    Here the choices of transfer are more restricted, to enable a fast
    CUDA optimization of the transfer. In particular, only 2 devices are
    supported, transfer has to be sequential with exactly one read at
    each transfer event (that is the settings
    ``random_selection=False``, ``with_reset_prob=0.0``,
    ``n_reads_per_transfer=1``).

    Note:
        This device is identical to :class:`BufferedTransferCompound`
        if the chopper probabilities are set to 0 (with the above
        restrictions), but will run up to 40x faster on GPU for larger
        batches and small settings of ``transfer_every``, because of a
        fused CUDA kernel design.
    """

    bindings_class: ClassVar[str] = "ChoppedTransferResistiveDeviceParameter"

    in_chop_prob: float = 0.1
    """Switching probability of the input choppers.

    The chopper will be switched with the given probability once after
    the corresponding vector read (column or row).
    """

    in_chop_random: bool = True
    """Whether to switch randomly (default) or regular.

    If regular, then the ``in_chop_prob`` sets the frequency of
    switching, ie. ``MIN(1/in_chop_prob, 2)`` is the period of switching
    in terms of number of reads of a particular row / col. All rows/cols
    will switch at the same matrix update cycle.
    """

    out_chop_prob: float = 0.0
    """Switching probability of the output choppers.

    The chopper will be switched with the given probability once after a
    full matrix update has been accomplished.
    """

    buffer_granularity: float = 1.0
    """Granularity of the buffer."""

    auto_granularity: float = 0.0
    """If set, scales the ``buffer_granularity`` based on the expected
    number of MVMs needed to cross the buffer, ie::

        buffer_granularity *= auto_granularity / (in_size * transfer_every) * weight_granularity

    Typical value would be e.g. 1000. But see
    ``auto_granularity_scaled_with_weight_granularity`` switch.
    """

    step: float = 1.0
    """Value to fill the ``d`` vector for the update if buffered value
    is above threshold.
    """

    momentum: float = 0.0
    """Momentum of the buffer.

    After transfer, this momentum fraction stays on the buffer instead
    of subtracting all of what was transferred.
    """

    forget_buffer: bool = True
    """Whether to forget the value of the buffer after transfer.

    If enabled, the buffer is reset to the momentum times the
    transferred value. Thus, if the number of pulses is limited to e.g.
    1 (``desired_BL`` in the ``transfer_update``) the transfer might be
    clipped and the potentially larger buffer values are forgotten.

    If disabled, then the buffer values are faithfully subtracted by the
    amount transferred (times one minus momentum).
    """

    no_buffer: bool = False
    """Turn off the usage of the digital buffer.

    This is identical to the Tiki-taka (version 1) algorithm with
    `gamma=0`
    """

    units_in_mbatch: bool = False
    """Units for ``transfer_every``.

    If set, then the cycle length units of ``transfer_every`` are in
    ``m_batch`` instead of mat-vecs, which is equal to the overall of
    the weight re-use during a while mini-batch.
    """

    auto_scale: bool = False
    """Scaling the weight gradient onto the fast matrix by the averaged
    recent past of the maximum gradient.

    This will dynamically compute a reasonable update strength onto the
    fast matrix. ``fast_lr`` can be used to scale the gradient update
    further.
    """

    auto_momentum: float = 0.99
    """Momentum of the gradient when using auto scale"""

    correct_gradient_magnitudes: bool = False
    """Scale the transfer LR with the fast LR to yield the correct
    gradient magnitudes.

    Note:
        ``auto_granularity`` has no effect in this case
    """

    transfer_columns: bool = True
    """Whether to read and transfer columns or rows.

    If set, read is done with an additional forward pass determined by
    the ``transfer_forward`` settings. If not set, rows are transferred
    instead, that is, the read is done internally with a backward pass
    instead. However, the parameters defining the backward are still
    given by setting the ``transfer_forward`` field for convenience.
    """

    fast_lr: float = 1.0
    """Whether to set the `fast` tile's learning rate.

    If set, then the SGD gradient update onto the first (fast) tile is
    set to this learning rate and is kept constant even when the SGD
    learning rate is scheduled. The SGD learning rate is then only used
    to scale the transfer LR (see ``scale_transfer_lr``).
    """

    transfer_lr: float = 1.0
    """Learning rate (LR) for the update step of the transfer event.

    Per default all learning rates are identical. If
    ``scale_transfer_lr`` is set, the transfer LR is scaled by current
    learning rate of the SGD.

    Note:
        LR is always a positive number, sign will be correctly applied
        internally.
    """

    scale_transfer_lr: bool = True
    """Whether to give the transfer_lr in relative units.

    ie. whether to scale the transfer LR with the current LR of the SGD.
    """

    transfer_forward: IOParameters = field(default_factory=IOParameters)
    """Input-output parameters that define the read of a transfer event.

    :class:`~aihwkit.simulator.configs.IOParameters` that define the
    read (forward or backward) of an transfer event. For instance the
    amount of noise or whether transfer is done using a ADC/DAC etc.
    """

    transfer_update: UpdateParameters = field(
        default_factory=lambda: UpdateParameters(
            desired_bl=1, update_bl_management=False, update_management=False
        )
    )
    """Update parameters that define the type of update used for each
    transfer event.

    Update parameters :class:`~aihwkit.simulator.configs.UpdateParameters`
    that define the type of update used for each transfer event.
    """

    def as_bindings(self, data_type: RPUDataType) -> Any:
        """Return a representation of this instance as a simulator bindings object.

        Raises:
            ConfigError: if ``unit_cell_devices`` is not a list, does
                not have exactly 2 entries, or a device parameter could
                not be appended.
        """
        if not isinstance(self.unit_cell_devices, list):
            raise ConfigError("unit_cell_devices should be a list of devices")

        n_devices = len(self.unit_cell_devices)
        if n_devices != 2:
            raise ConfigError("Only 2 devices supported for ChoppedTransferCompound")

        transfer_parameters = parameters_to_bindings(self, data_type)

        param_fast = self.unit_cell_devices[0].as_bindings(data_type)
        param_slow = self.unit_cell_devices[1].as_bindings(data_type)

        if not transfer_parameters.append_parameter(param_fast):
            raise ConfigError("Could not add unit cell device parameter")

        for _ in range(n_devices - 1):
            if not transfer_parameters.append_parameter(param_slow):
                raise ConfigError("Could not add unit cell device parameter")

        return transfer_parameters
@dataclass
class DynamicTransferCompound(ChoppedTransferCompound):
    r"""Abstract device model that takes exactly two devices and
    implements a chopped and buffered transfer-based learning rule.

    Similar to :class:`ChoppedTransferCompound`, however, the gradient
    update onto the fast tile is done with a statistically motivated
    gradient computation: The mean of the reads during the last chopper
    period is compared with the current (switched) chopper period and
    the update (transfer) onto the slow matrix is proportionally to the
    difference.

    In addition, no update is done if the difference is not
    significantly different from zero, judged by the running std
    estimation (thus computing the standard error of the mean).

    Note that the choices of transfer are similarly restricted as in the
    :class:`ChoppedTransferCompound`, to enable a fast CUDA optimization
    of the transfer. In particular, only 2 devices are supported,
    transfer has to be sequential with exactly one read at each transfer
    event (that is the settings ``random_selection=False``,
    ``with_reset_prob=0.0``, ``n_reads_per_transfer=1``).
    """

    bindings_class: ClassVar[str] = "DynamicTransferResistiveDeviceParameter"

    in_chop_prob: float = 0.1
    """Switching probability of the input choppers.

    The chopper will be switched with the given frequency once after the
    corresponding vector read (column or row).

    Note:
        In contrast to :class:`ChoppedTransferCompound` here the
        chopping periods are regular with the switches occurring with
        frequency of the given value
    """

    in_chop_random: bool = False
    """Whether to switch randomly or regular (default).

    If regular, then the ``in_chop_prob`` sets the frequency of
    switching, ie. ``MIN(1/in_chop_prob, 2)`` is the period of switching
    in terms of number of reads of a particular row / col. All rows/cols
    will switch at the same matrix update cycle.
    """

    experimental_correct_accumulation: bool = False
    """Correct the gradient accumulation for the multiple reads.

    Caution:
        This feature is only approximately computed in the CPU version
        if ``in_chop_random`` is set.
    """

    step: float = 1.0
    """Value to fill the ``d`` vector for the update if buffered value
    is above threshold.
    """

    momentum: float = 0.0
    """Momentum.

    If enabled an additional momentum matrix is used that is filtering
    the computed weight update in the usual manner.
    """

    transfer_columns: bool = True
    """Whether to read and transfer columns or rows.

    If set, read is done with an additional forward pass determined by
    the ``transfer_forward`` settings. If not set, rows are transferred
    instead, that is, the read is done internally with a backward pass
    instead. However, the parameters defining the backward are still
    given by setting the ``transfer_forward`` field for convenience.
    """

    buffer_granularity: float = 1.0
    """Granularity of the buffer."""

    buffer_cap: float = 0.0
    """Capacity of buffer.

    Capacity in times of max steps (``transfer_update.desired_bl``).
    Only applied in case of ``forget_buffer=False``
    """

    forget_buffer: bool = True
    """Whether to forget the value of the buffer after transfer.

    If enabled, the buffer is reset to the momentum times the
    transferred value. Thus, if the number of pulses is limited to e.g.
    1 (``desired_BL`` in the ``transfer_update``) the transfer might be
    clipped and the potentially larger buffer values are forgotten.

    If disabled, then the buffer values are faithfully subtracted by the
    amount transferred (times one minus momentum).
    """

    fast_lr: float = 1.0
    """Whether to set the `fast` tile's learning rate.

    If set, then the SGD gradient update onto the first (fast) tile is
    set to this learning rate and is kept constant even when the SGD
    learning rate is scheduled. The SGD learning rate is then only used
    to scale the transfer LR (see ``scale_transfer_lr``).
    """

    transfer_lr: float = 1.0
    """Learning rate (LR) for the update step of the transfer event.

    Per default all learning rates are identical. If
    ``scale_transfer_lr`` is set, the transfer LR is scaled by current
    learning rate of the SGD.

    Note:
        LR is always a positive number, sign will be correctly applied
        internally.
    """

    scale_transfer_lr: bool = True
    """Whether to give the transfer_lr in relative units.

    ie. whether to scale the transfer LR with the current LR of the SGD.
    """

    tail_weightening: float = 5.0
    """Weight the tail of the chopper period more (if larger than 1).

    This helps to reduce the impact of the transient period"""

    transfer_forward: IOParameters = field(default_factory=IOParameters)
    """Input-output parameters that define the read of a transfer event.

    :class:`~aihwkit.simulator.configs.IOParameters` that define the
    read (forward or backward) of an transfer event. For instance the
    amount of noise or whether transfer is done using a ADC/DAC etc.
    """

    transfer_update: UpdateParameters = field(
        default_factory=lambda: UpdateParameters(
            desired_bl=1, update_bl_management=False, update_management=False
        )
    )
    """Update parameters that define the type of update used for each
    transfer event.

    Update parameters :class:`~aihwkit.simulator.configs.UpdateParameters`
    that define the type of update used for each transfer event.
    """

    def as_bindings(self, data_type: RPUDataType) -> Any:
        """Return a representation of this instance as a simulator bindings object.

        Raises:
            ConfigError: if ``unit_cell_devices`` is not a list, does
                not have exactly 2 entries, or a device parameter could
                not be appended.
        """
        if not isinstance(self.unit_cell_devices, list):
            raise ConfigError("unit_cell_devices should be a list of devices")

        n_devices = len(self.unit_cell_devices)
        if n_devices != 2:
            # Fixed: the error previously named ChoppedTransferCompound.
            raise ConfigError("Only 2 devices supported for DynamicTransferCompound")

        transfer_parameters = parameters_to_bindings(self, data_type)

        param_fast = self.unit_cell_devices[0].as_bindings(data_type)
        param_slow = self.unit_cell_devices[1].as_bindings(data_type)

        if not transfer_parameters.append_parameter(param_fast):
            raise ConfigError("Could not add unit cell device parameter")

        for _ in range(n_devices - 1):
            if not transfer_parameters.append_parameter(param_slow):
                raise ConfigError("Could not add unit cell device parameter")

        return transfer_parameters
############################################################################### # Specific compound-devices with digital rank update ###############################################################################
@dataclass
class DigitalRankUpdateCell(_PrintableMixin):
    """Parameters that modify the behavior of the digital rank update cell.

    This is the base class for devices that compute the rank update in
    digital and then (occasionally) transfer the information to the
    (analog) crossbar array that is used during forward and backward.
    """

    # Name of the C++ bindings parameter class this maps onto.
    bindings_class: ClassVar[str] = "AbstractResistiveDeviceParameter"
    # Fields excluded when converting to bindings (handled elsewhere).
    bindings_ignore: ClassVar[List] = ["diffusion", "lifetime"]

    device: Union["PulsedDevice", OneSidedUnitCell, VectorUnitCell, ReferenceUnitCell] = field(
        default_factory=VectorUnitCell
    )
    """(Analog) device that is used for forward and backward."""

    construction_seed: int = 0
    """If not ``0``, set a unique seed for hidden parameters during
    construction.

    Applies to ``device``.
    """

    def as_bindings(self, data_type: RPUDataType) -> Any:
        """Return a representation of this instance as a simulator bindings
        object of the given data type.

        Subclasses implement the concrete conversion.
        """
        raise NotImplementedError

    def requires_diffusion(self) -> bool:
        """Return whether device has diffusion enabled."""
        # Delegates to the wrapped analog device.
        return self.device.requires_diffusion()

    def requires_decay(self) -> bool:
        """Return whether device has decay enabled."""
        # Delegates to the wrapped analog device.
        return self.device.requires_decay()
@dataclass
class MixedPrecisionCompound(DigitalRankUpdateCell):
    r"""Abstract device model that takes one (analog) device and implements
    a transfer-based learning rule with the outer product computed in digital.

    The outer product of the activations and the error is accumulated on a
    full-precision floating-point :math:`\chi` matrix. Using a threshold
    given by the ``granularity``, pulses are then applied to transfer the
    information row-by-row to the analog matrix.

    For details, see `Nandakumar et al. Front. in Neurosci. (2020)`_.

    Note:
        This update differs from the parallel analog update that other
        devices implement with stochastic pulsing: here
        :math:`{\cal O}(n^2)` digital computations are needed for the
        rank update. This digital compute (potentially in high precision)
        may yield inferior run time and power estimates in real-world
        applications, although sparse integer products could speed this
        up. See the discussion in
        `Nandakumar et al. Front. in Neurosci. (2020)`_.

    .. _`Nandakumar et al. Front. in Neurosci. (2020)`: https://doi.org/10.3389/fnins.2020.00406
    """

    bindings_class: ClassVar[str] = "MixedPrecResistiveDeviceParameter"

    transfer_every: int = 1
    """Transfers every :math:`n` mat-vec operations (rounded to
    multiples/ratios of ``m_batch``).

    Standard setting is 1.0 for mixed precision, but it could potentially
    be reduced to get better run time estimates.
    """

    n_rows_per_transfer: int = -1
    r"""How many consecutive rows to write to the tile from the
    :math:`\chi` matrix.

    ``-1`` means the full matrix is read each transfer event.
    """

    random_row: bool = False
    """Whether to select a random starting row for each transfer event,
    instead of taking the next not-yet-transferred row (the default).
    """

    granularity: float = 0.0
    r"""Granularity :math:`\varepsilon` of the device, used to calculate
    the number of pulses transferred from :math:`\chi` to analog.

    If 0, the granularity is taken from the analog device used.
    """

    transfer_lr: float = 1.0
    r"""Scale of the transfer to analog.

    The update onto the analog tile is proportional to
    :math:`\langle\chi/\varepsilon\rangle\varepsilon\lambda_\text{tr}`,
    where :math:`\lambda_\text{tr}` is given by ``transfer_lr`` and
    :math:`\varepsilon` is the granularity.
    """

    n_x_bins: int = 0
    """Number of bins to discretize (symmetrically around zero) the
    activation before computing the outer product.

    Dynamic quantization is used, based on the absolute max value of each
    input. Set to 0 to turn quantization off.
    """

    n_d_bins: int = 0
    """Number of bins to discretize (symmetrically around zero) the error
    before computing the outer product.

    Dynamic quantization is used, based on the absolute max value of each
    error vector. Set to 0 to turn quantization off.
    """

    stoc_round_x: bool = True
    """Whether to use stochastic rounding when quantizing the input x."""

    stoc_round_d: bool = True
    """Whether to use stochastic rounding when quantizing the error d."""

    def as_bindings(self, data_type: RPUDataType) -> Any:
        """Return a representation of this instance as a simulator bindings object.

        Args:
            data_type: floating point precision of the bindings.

        Returns:
            The populated mixed-precision bindings parameter.

        Raises:
            ConfigError: if the analog device parameter cannot be set.
        """
        bindings = parameters_to_bindings(self, data_type)
        device_param = self.device.as_bindings(data_type)
        if not bindings.set_device_parameter(device_param):
            raise ConfigError("Could not add device parameter")
        return bindings