# -*- coding: utf-8 -*-
# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# Licensed under the MIT license. See LICENSE file in the project root for details.
"""Helper for generating presets."""
from typing import Union, Type, Callable, Any
from copy import deepcopy
from aihwkit.exceptions import ArgumentError
from aihwkit.simulator.configs.configs import (
UnitCellRPUConfig,
SingleRPUConfig,
DigitalRankUpdateRPUConfig,
)
from aihwkit.simulator.configs.devices import PulsedDevice
from aihwkit.simulator.configs.compounds import (
TransferCompound,
ChoppedTransferCompound,
DynamicTransferCompound,
MixedPrecisionCompound,
VectorUnitCell,
)
from aihwkit.simulator.parameters.training import UpdateParameters
from aihwkit.simulator.parameters.io import IOParameters
from aihwkit.simulator.parameters.enums import (
VectorUnitCellUpdatePolicy,
NoiseManagementType,
BoundManagementType,
)
# Preset helper: maps a training-algorithm name to a matching RPU configuration.
def build_config(
    algorithm: str,
    device: Union[Type[PulsedDevice], PulsedDevice, Callable],
    io_parameters: Union[Type[IOParameters], IOParameters, Callable] = IOParameters,
    up_parameters: Union[Type[UpdateParameters], UpdateParameters, Callable] = UpdateParameters,
    n_devices: int = 1,
    construction_seed: int = 0,
    **kwargs: Any,
) -> Union[UnitCellRPUConfig, SingleRPUConfig, DigitalRankUpdateRPUConfig]:
    """Generate a RPU configuration for analog training using a
    specific device model and a given training algorithm.

    Args:
        algorithm: The type of the training algorithm (matched
            case-insensitively). Valid choices are:

            "sgd": Random pulsed (naive) SGD on analog crossbars. See
                `Gokmen & Vlasov, Front. Neurosci. 2016`_ for details.

            "mp", "mixed-precision": Mixed-precision analog, where the
                gradient is computed in digital and only the forward
                and backward pass is in analog. Uses
                :class:`~aihwkit.simulator.configs.compounds.MixedPrecisionCompound`. See
                also `Nandakumar et al. Front. in Neurosci. (2020)`_
                for details.

            "tiki-taka", "ttv1", "tt": Tiki-taka I algorithm. Uses
                :class:`~aihwkit.simulator.configs.compounds.TransferCompound`.
                See `Gokmen & Haensch, Front. Neurosci. 2020`_ for
                details.

            "ttv2": second version of the Tiki-taka algorithm
                (TTv2). Uses
                :class:`~aihwkit.simulator.configs.compounds.ChoppedTransferCompound`
                with chopper probability set to 0. See `Gokmen,
                Front. Artif. Intell. 2021`_ for details.

            "chopped-ttv2", "ttv3", "c-ttv2": Chopped version of TTv2
                algorithm. Uses
                :class:`~aihwkit.simulator.configs.compounds.ChoppedTransferCompound`.
                See `Rasch et al., ArXiv 2023`_ for details.

            "agad", "ttv4": Analog gradient accumulation with dynamic
                reference computation. Uses
                :class:`~aihwkit.simulator.configs.compounds.DynamicTransferCompound`. See
                `Rasch et al., ArXiv 2023`_ for details.

        device: Device configuration of the analog devices. Can be the
            class, a callable returning a device, or an actual device
            instance. All devices that are created will have the same
            configuration. A given instance is never modified; a deep
            copy is made for every device that is needed.

        io_parameters: IOParameters class (or actual instance) that
            are used for forward / backward and transfer. Default is
            :class:`~aihwkit.simulator.parameters.io.IOParameters`.

        up_parameters: UpdateParameters class (or actual instance) that are used for update and
            transfer update. Default is
            :class:`~aihwkit.simulator.parameters.training.UpdateParameters`.

        n_devices: In case of SGD, how many device pairs are used in the unit cell.

            Note:
                This option is only applied for ``algorithm="sgd"``
                and ignored for all other algorithm choices.

        construction_seed: Seed of the construction

        kwargs: Other RPUConfig fields to assign explicitly (e.g. ``mapping``).

    Returns:
        RPU config according to the algorithm and device settings.

    Raises:
        ArgumentError: in case algorithm is not known, or in case
            ``n_devices`` is smaller than 1 for ``algorithm="sgd"``.

    .. _`Gokmen & Vlasov, Front. Neurosci. 2016`: \
        https://www.frontiersin.org/articles/10.3389/fnins.2016.00333/full
    .. _`Gokmen & Haensch, Front. Neurosci. 2020`: \
        https://www.frontiersin.org/articles/10.3389/fnins.2020.00103/full
    .. _`Gokmen, Front. Artif. Intell. 2021`: \
        https://www.frontiersin.org/articles/10.3389/frai.2021.699148/full
    .. _`Rasch et al., ArXiv 2023`: \
        https://arxiv.org/abs/2303.04721
    .. _`Nandakumar et al. Front. in Neurosci. (2020)`: \
        https://doi.org/10.3389/fnins.2020.00406
    """
    # pylint: disable=too-many-statements, too-many-return-statements

    def _as_factory(template: Any) -> Callable:
        """Return a callable that can be used like the class of ``template``.

        Every call yields an independent deep copy of ``template`` with
        the given keyword overrides written onto its attributes, so
        the caller's instance is never shared or mutated.
        """

        def factory(**overrides: Any) -> Any:
            clone = deepcopy(template)
            clone.__dict__.update(overrides)
            return clone

        return factory

    # Accept ready-made instances wherever a class / factory is expected:
    # below, all three are uniformly *called* to create fresh objects.
    if isinstance(device, PulsedDevice):
        device = _as_factory(device)
    if isinstance(io_parameters, IOParameters):
        io_parameters = _as_factory(io_parameters)
    if isinstance(up_parameters, UpdateParameters):
        up_parameters = _as_factory(up_parameters)

    algo = algorithm.lower()

    if algo == "sgd":
        if n_devices < 1:
            # An empty unit cell cannot hold any weights; fail early with
            # a clear message instead of building an unusable config.
            raise ArgumentError(f"n_devices must be at least 1, got {n_devices}")

        if n_devices == 1:
            # Plain single device: the seed is set on the device itself.
            return SingleRPUConfig(
                device=device(construction_seed=construction_seed),
                forward=io_parameters(),
                backward=io_parameters(),
                update=up_parameters(),
                **kwargs,
            )

        # Several devices per cross-point: one randomly selected device
        # is updated at a time, the seed is set on the unit cell.
        return UnitCellRPUConfig(
            device=VectorUnitCell(
                unit_cell_devices=[device() for _ in range(n_devices)],
                update_policy=VectorUnitCellUpdatePolicy.SINGLE_RANDOM,
                construction_seed=construction_seed,
            ),
            forward=io_parameters(),
            backward=io_parameters(),
            update=up_parameters(),
            **kwargs,
        )

    if algo in ("tiki-taka", "tt", "ttv1"):
        return UnitCellRPUConfig(
            device=TransferCompound(
                unit_cell_devices=[device(), device()],
                # The transfer read is done without noise / bound management.
                transfer_forward=io_parameters(
                    noise_management=NoiseManagementType.NONE,
                    bound_management=BoundManagementType.NONE,
                ),
                transfer_update=up_parameters(),
                units_in_mbatch=True,
                construction_seed=construction_seed,
            ),
            forward=io_parameters(),
            backward=io_parameters(),
            update=up_parameters(),
            **kwargs,
        )

    if algo == "ttv2":
        return UnitCellRPUConfig(
            device=ChoppedTransferCompound(
                unit_cell_devices=[device(), device()],
                transfer_forward=io_parameters(
                    noise_management=NoiseManagementType.NONE,
                    bound_management=BoundManagementType.NONE,
                ),
                transfer_update=up_parameters(
                    desired_bl=1, update_bl_management=False, update_management=False
                ),
                # Zero chopper probability turns the chopped compound into plain TTv2.
                in_chop_prob=0.0,
                units_in_mbatch=False,
                auto_scale=False,
                construction_seed=construction_seed,
            ),
            forward=io_parameters(),
            backward=io_parameters(),
            update=up_parameters(desired_bl=5),
            **kwargs,
        )

    if algo in ("chopped-ttv2", "ttv3", "c-ttv2"):
        return UnitCellRPUConfig(
            device=ChoppedTransferCompound(
                unit_cell_devices=[device(), device()],
                transfer_forward=io_parameters(
                    noise_management=NoiseManagementType.NONE,
                    bound_management=BoundManagementType.NONE,
                ),
                transfer_update=up_parameters(
                    desired_bl=1, update_bl_management=False, update_management=False
                ),
                units_in_mbatch=False,
                fast_lr=0.1,
                auto_scale=True,
                construction_seed=construction_seed,
            ),
            forward=io_parameters(),
            backward=io_parameters(),
            update=up_parameters(desired_bl=5),
            **kwargs,
        )

    if algo in ("agad", "ttv4"):
        return UnitCellRPUConfig(
            device=DynamicTransferCompound(
                unit_cell_devices=[device(), device()],
                # Unlike the Tiki-taka variants, the transfer read keeps
                # the given IO parameters unchanged.
                transfer_forward=io_parameters(),
                transfer_update=up_parameters(
                    desired_bl=1, update_bl_management=True, update_management=True
                ),
                auto_scale=True,
                fast_lr=0.1,
                units_in_mbatch=False,
                construction_seed=construction_seed,
            ),
            forward=io_parameters(),
            backward=io_parameters(),
            update=up_parameters(desired_bl=5),
            **kwargs,
        )

    if algo in ("mp", "mixed-precision"):
        return DigitalRankUpdateRPUConfig(
            device=MixedPrecisionCompound(device=device(construction_seed=construction_seed)),
            forward=io_parameters(),
            backward=io_parameters(),
            update=up_parameters(),
            **kwargs,
        )

    raise ArgumentError(f"Algorithm {algorithm} is not known")