# Source code for aihwkit.simulator.configs.helpers

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# Licensed under the MIT license. See LICENSE file in the project root for details.

"""Helper for generating presets."""

from typing import Union, Type, Callable, Any
from copy import deepcopy

from aihwkit.exceptions import ArgumentError
from aihwkit.simulator.configs.configs import (
    UnitCellRPUConfig,
    SingleRPUConfig,
    DigitalRankUpdateRPUConfig,
)
from aihwkit.simulator.configs.devices import PulsedDevice
from aihwkit.simulator.configs.compounds import (
    TransferCompound,
    ChoppedTransferCompound,
    DynamicTransferCompound,
    MixedPrecisionCompound,
    VectorUnitCell,
)
from aihwkit.simulator.parameters.training import UpdateParameters
from aihwkit.simulator.parameters.io import IOParameters
from aihwkit.simulator.parameters.enums import (
    VectorUnitCellUpdatePolicy,
    NoiseManagementType,
    BoundManagementType,
)


def _as_factory(value: Any, base_type: type) -> Callable[..., Any]:
    """Turn a configuration argument into something callable like a constructor.

    ``build_config`` calls its ``device``, ``io_parameters`` and
    ``up_parameters`` arguments with optional keyword overrides. Classes and
    user callables already support that and are returned unchanged. An
    instance of ``base_type`` is instead wrapped in a function that returns a
    deep copy of the instance with the keyword overrides assigned as
    attributes. Every call therefore yields an independent object, and the
    caller's instance is never modified.

    Args:
        value: A class, a callable, or an instance of ``base_type``.
        base_type: The configuration type whose instances get wrapped.

    Returns:
        A callable that accepts keyword overrides and returns a new object.
    """
    if not isinstance(value, base_type):
        return value

    def factory(**overrides: Any) -> Any:
        # Deep copy so that devices / parameter sets created from the same
        # template do not share (mutable) state with each other.
        obj = deepcopy(value)
        obj.__dict__.update(overrides)
        return obj

    return factory


def build_config(
    algorithm: str,
    device: Union[Type[PulsedDevice], PulsedDevice, Callable],
    io_parameters: Union[Type[IOParameters], IOParameters, Callable] = IOParameters,
    up_parameters: Union[Type[UpdateParameters], UpdateParameters, Callable] = UpdateParameters,
    n_devices: int = 1,
    construction_seed: int = 0,
    **kwargs: Any,
) -> Union[UnitCellRPUConfig, SingleRPUConfig, DigitalRankUpdateRPUConfig]:
    """Generate an RPU configuration for analog training.

    The configuration is built from a specific device model and a given
    training algorithm.

    Args:
        algorithm: The type of the training algorithm (case-insensitive).
            Valid choices are:

            "sgd": Random pulsed (naive) SGD on analog crossbars. See
            `Gokmen & Vlasov, Front. Neurosci. 2016`_ for details.

            "mp", "mixed-precision": Mixed-precision analog training, where
            the gradient is computed in digital and only the forward and
            backward passes are in analog. Uses
            :class:`~aihwkit.simulator.configs.compounds.MixedPrecisionCompound`.
            See also `Nandakumar et al. Front. in Neurosci. (2020)`_ for
            details.

            "tiki-taka", "ttv1", "tt": Tiki-taka I algorithm. Uses
            :class:`~aihwkit.simulator.configs.compounds.TransferCompound`.
            See `Gokmen & Haensch, Front. Neurosci. 2020`_ for details.

            "ttv2": Second version of the Tiki-taka algorithm (TTv2). Uses
            :class:`~aihwkit.simulator.configs.compounds.ChoppedTransferCompound`
            with the chopper probability set to 0. See
            `Gokmen, Front. Artif. Intell. 2021`_ for details.

            "chopped-ttv2", "ttv3", "c-ttv2": Chopped version of the TTv2
            algorithm. Uses
            :class:`~aihwkit.simulator.configs.compounds.ChoppedTransferCompound`.
            See `Rasch et al., ArXiv 2023`_ for details.

            "agad", "ttv4": Analog gradient accumulation with dynamic
            reference computation. Uses
            :class:`~aihwkit.simulator.configs.compounds.DynamicTransferCompound`.
            See `Rasch et al., ArXiv 2023`_ for details.

        device: Device configuration of the analog devices. Can be the class,
            an instance (deep-copied for every device that is created) or a
            callable accepting keyword arguments. All devices that are
            created will have the same configuration.
        io_parameters: IOParameters class (or instance, or callable) used for
            the forward / backward passes and, where applicable, the
            transfer. Default is
            :class:`~aihwkit.simulator.parameters.io.IOParameters`.
        up_parameters: UpdateParameters class (or instance, or callable) used
            for the update and, where applicable, the transfer update.
            Default is
            :class:`~aihwkit.simulator.parameters.training.UpdateParameters`.
        n_devices: In case of SGD, how many devices are used in the unit
            cell. With ``n_devices == 1`` a single device is used directly;
            otherwise the devices are combined in a
            :class:`~aihwkit.simulator.configs.compounds.VectorUnitCell`.

            Note:
                This option is only applied for ``algorithm="sgd"`` and
                ignored for all other algorithm choices.

        construction_seed: Seed for the device construction. It is passed to
            the device itself for "sgd" with a single device and for "mp" /
            "mixed-precision", and to the compound device otherwise.
        kwargs: Other RPUConfig fields to assign explicitly (e.g.
            ``mapping``).

    Returns:
        RPU config according to the algorithm and device settings.

    Raises:
        ArgumentError: in case the algorithm is not known, or if
            ``n_devices`` is smaller than 1 for ``algorithm="sgd"``.

    .. _`Gokmen & Vlasov, Front. Neurosci. 2016`: \
       https://www.frontiersin.org/articles/10.3389/fnins.2016.00333/full
    .. _`Gokmen & Haensch, Front. Neurosci. 2020`: \
       https://www.frontiersin.org/articles/10.3389/fnins.2020.00103/full
    .. _`Gokmen, Front. Artif. Intell. 2021`: \
       https://www.frontiersin.org/articles/10.3389/frai.2021.699148/full
    .. _`Rasch et al., ArXiv 2023`: \
       https://arxiv.org/abs/2303.04721
    .. _`Nandakumar et al. Front. in Neurosci. (2020)`: \
       https://doi.org/10.3389/fnins.2020.00406
    """
    # pylint: disable=too-many-statements, too-many-return-statements

    # Normalize all three configuration arguments to "call me with keyword
    # overrides" factories, regardless of whether a class, an instance or a
    # custom callable was passed.
    make_device = _as_factory(device, PulsedDevice)
    make_io = _as_factory(io_parameters, IOParameters)
    make_up = _as_factory(up_parameters, UpdateParameters)

    algo = algorithm.lower()

    if algo == "sgd":
        if n_devices < 1:
            raise ArgumentError(
                "n_devices must be at least 1 (got {})".format(n_devices)
            )
        if n_devices == 1:
            return SingleRPUConfig(
                device=make_device(construction_seed=construction_seed),
                forward=make_io(),
                backward=make_io(),
                update=make_up(),
                **kwargs,
            )
        return UnitCellRPUConfig(
            device=VectorUnitCell(
                unit_cell_devices=[make_device() for _ in range(n_devices)],
                update_policy=VectorUnitCellUpdatePolicy.SINGLE_RANDOM,
                construction_seed=construction_seed,
            ),
            forward=make_io(),
            backward=make_io(),
            update=make_up(),
            **kwargs,
        )

    if algo in ("tiki-taka", "tt", "ttv1"):
        return UnitCellRPUConfig(
            device=TransferCompound(
                unit_cell_devices=[make_device(), make_device()],
                transfer_forward=make_io(
                    noise_management=NoiseManagementType.NONE,
                    bound_management=BoundManagementType.NONE,
                ),
                transfer_update=make_up(),
                units_in_mbatch=True,
                construction_seed=construction_seed,
            ),
            forward=make_io(),
            backward=make_io(),
            update=make_up(),
            **kwargs,
        )

    if algo == "ttv2":
        return UnitCellRPUConfig(
            device=ChoppedTransferCompound(
                unit_cell_devices=[make_device(), make_device()],
                transfer_forward=make_io(
                    noise_management=NoiseManagementType.NONE,
                    bound_management=BoundManagementType.NONE,
                ),
                transfer_update=make_up(
                    desired_bl=1, update_bl_management=False, update_management=False
                ),
                # TTv2 is the chopped compound with chopping switched off.
                in_chop_prob=0.0,
                units_in_mbatch=False,
                auto_scale=False,
                construction_seed=construction_seed,
            ),
            forward=make_io(),
            backward=make_io(),
            update=make_up(desired_bl=5),
            **kwargs,
        )

    if algo in ("chopped-ttv2", "ttv3", "c-ttv2"):
        return UnitCellRPUConfig(
            device=ChoppedTransferCompound(
                unit_cell_devices=[make_device(), make_device()],
                transfer_forward=make_io(
                    noise_management=NoiseManagementType.NONE,
                    bound_management=BoundManagementType.NONE,
                ),
                transfer_update=make_up(
                    desired_bl=1, update_bl_management=False, update_management=False
                ),
                units_in_mbatch=False,
                fast_lr=0.1,
                auto_scale=True,
                construction_seed=construction_seed,
            ),
            forward=make_io(),
            backward=make_io(),
            update=make_up(desired_bl=5),
            **kwargs,
        )

    if algo in ("agad", "ttv4"):
        return UnitCellRPUConfig(
            device=DynamicTransferCompound(
                unit_cell_devices=[make_device(), make_device()],
                transfer_forward=make_io(),
                transfer_update=make_up(
                    desired_bl=1, update_bl_management=True, update_management=True
                ),
                auto_scale=True,
                fast_lr=0.1,
                units_in_mbatch=False,
                construction_seed=construction_seed,
            ),
            forward=make_io(),
            backward=make_io(),
            update=make_up(desired_bl=5),
            **kwargs,
        )

    if algo in ("mp", "mixed-precision"):
        return DigitalRankUpdateRPUConfig(
            device=MixedPrecisionCompound(
                device=make_device(construction_seed=construction_seed)
            ),
            forward=make_io(),
            backward=make_io(),
            update=make_up(),
            **kwargs,
        )

    raise ArgumentError("Algorithm {} is not known".format(algorithm))