# -*- coding: utf-8 -*-
# (C) Copyright 2020, 2021, 2022 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Analog-aware inference optimizer."""
from types import new_class
from typing import Any, Callable, Dict, Optional, Type
from torch.optim import Optimizer, SGD
from aihwkit.optim.context import AnalogContext
class AnalogOptimizerMixin:
    """Mixin for analog optimizers.

    This class contains the methods needed for enabling analog in an existing
    ``Optimizer``. It is designed to be used as a mixin in conjunction with an
    ``AnalogOptimizer`` or torch ``Optimizer``. It has to be listed *before*
    the torch ``Optimizer`` in the bases, so that its ``step`` runs first and
    delegates the digital update to the optimizer through ``super()``.
    """

    def regroup_param_groups(self, *_: Any) -> None:
        """Reorganize the parameter groups, isolating analog layers.

        Update the ``param_groups`` of the optimizer, moving each analog
        parameter (``AnalogContext``) to a new group of its own. The new group
        inherits the hyperparameters (the keys of ``self.defaults``, e.g.
        ``lr``) of the group the parameter was taken from, so per-group
        settings are not replaced by the optimizer defaults. The learning rate
        of the analog tile is set to that group's ``lr``. Groups that are left
        without any parameter are dropped.

        Args:
            *_: extra positional arguments, accepted and ignored.
        """
        analog_param_groups = []
        emptied_group_ids = set()
        for group in self.param_groups:  # type: ignore[has-type]
            moved_ids = set()
            for param in group['params']:
                if not isinstance(param, AnalogContext):
                    continue
                # Carry over the hyperparameters of the source group: a group
                # holding only ``params`` would be filled from the optimizer
                # defaults by ``add_param_group``, losing per-group settings.
                new_group = {key: group[key]
                             for key in self.defaults  # type: ignore[attr-defined]
                             if key in group}
                new_group['params'] = [param]
                learning_rate = (group['lr'] if 'lr' in group
                                 else self.defaults['lr'])  # type: ignore[attr-defined]
                param.analog_tile.set_learning_rate(learning_rate)
                analog_param_groups.append(new_group)
                moved_ids.add(id(param))
            group['params'] = [p for p in group['params'] if id(p) not in moved_ids]
            if not group['params']:
                emptied_group_ids.add(id(group))

        self.param_groups = [
            group for group in self.param_groups  # type: ignore[has-type]
            if id(group) not in emptied_group_ids
        ]

        # Add the analog groups, one per analog parameter.
        for group in analog_param_groups:
            self.add_param_group(group)  # type: ignore[attr-defined]

    def step(self, closure: Optional[Callable] = None) -> Optional[float]:
        """Perform an analog-aware single optimization step.

        The wrapped torch optimizer's ``step`` is called first, updating the
        regular (digital) parameters. Then, for every analog parameter whose
        tile does not rely on a separate torch weight parameter, the tile
        learning rate is synchronized with its group's ``lr`` and the stored
        inputs / output gradients are fed to the tile update. Finally, the
        post-update step of every analog tile is applied.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The loss, if ``closure`` has been passed as a parameter (as
            returned by the wrapped optimizer's ``step``).
        """
        # Update non-analog parameters using the given optimizer.
        loss = super().step(closure)  # type: ignore[misc]

        # Update analog parameters.
        for group in self.param_groups:
            learning_rate = group.get('lr')
            for param in group['params']:
                if not isinstance(param, AnalogContext):
                    continue
                if param.use_torch_update:
                    # A separate weight parameter exists: nothing to do here.
                    continue
                analog_tile = param.analog_tile
                if learning_rate is not None:
                    analog_tile.set_learning_rate(learning_rate)
                if not param.has_gradient():
                    # Forward never used.
                    continue
                update = (analog_tile.update_indexed if param.use_indexed
                          else analog_tile.update)
                for x_input, d_input in zip(param.analog_input,
                                            param.analog_grad_output):
                    update(x_input, d_input)
                param.reset()

        # Apply post-update step operations (diffuse, decay, etc).
        # (only here because of unknown params order and shared weights)
        for group in self.param_groups:
            for param in group['params']:
                if isinstance(param, AnalogContext):
                    param.analog_tile.post_update_step()

        return loss

    def set_learning_rate(self, learning_rate: float = 0.1) -> None:
        """Update the learning rate to a new value.

        Set ``lr`` of every parameter group of the optimizer, propagating the
        change to the tiles of the analog parameters.

        Args:
            learning_rate: learning rate for the optimizer.
        """
        for param_group in self.param_groups:
            param_group['lr'] = learning_rate
            for param in param_group['params']:
                if isinstance(param, AnalogContext):
                    # Update learning rate on the tile.
                    param.analog_tile.set_learning_rate(learning_rate)
class AnalogOptimizer(AnalogOptimizerMixin, Optimizer):
    """Generic optimizer that wraps an existing ``Optimizer`` with analog updates.

    This class wraps an existing ``Optimizer``, customizing the optimization
    step for triggering the analog update needed for analog tiles. All other
    (digital) parameters are governed by the given torch optimizer. In case of
    hardware-aware training (``InferenceTile``) the tile weight update is also
    governed by the given optimizer, otherwise it is using the internal analog
    update as defined in the ``rpu_config``.

    The ``AnalogOptimizer`` constructor expects the wrapped optimizer class as
    the first parameter, followed by any arguments required by the wrapped
    optimizer.

    Note:
        The instances returned are of a *new* type that is a subclass of:

        * the wrapped ``Optimizer`` (allowing access to all their methods and
          attributes).
        * this ``AnalogOptimizer``.

    Example:
        The following block illustrates how to create an optimizer that wraps
        standard SGD:

        >>> from torch.optim import SGD
        >>> from aihwkit.nn import AnalogLinear
        >>> from aihwkit.simulator.configs.configs import InferenceRPUConfig
        >>> from aihwkit.optim import AnalogOptimizer
        >>> model = AnalogLinear(3, 4, rpu_config=InferenceRPUConfig())
        >>> optimizer = AnalogOptimizer(SGD, model.parameters(), lr=0.02)
    """

    # Registry of the created subclasses, keyed by the generated class name.
    SUBCLASSES = {}  # type: Dict[str, Type]

    def __new__(
            cls,
            optimizer_cls: Optional[Type] = None,
            *_: Any,
            **__: Any
    ) -> 'AnalogOptimizer':
        """Create an instance of the subclass combining ``cls`` and ``optimizer_cls``.

        Args:
            optimizer_cls: the torch ``Optimizer`` class to wrap. It may only
                be omitted when ``cls`` is already one of the generated
                subclasses, as happens when an instance is re-created without
                arguments (e.g. by ``copy.deepcopy``, which calls
                ``cls.__new__(cls)``).
            *_: remaining positional arguments, consumed by ``__init__``.
            **__: keyword arguments, consumed by ``__init__``.

        Returns:
            A new (not yet initialized) instance of the generated subclass.

        Raises:
            TypeError: if ``optimizer_cls`` is omitted and ``cls`` is not a
                generated subclass.
        """
        if optimizer_cls is None:
            if cls in cls.SUBCLASSES.values():
                return super().__new__(cls)
            raise TypeError(
                '{}() missing required argument: optimizer_cls'.format(cls.__name__))

        subclass_name = '{}{}'.format(cls.__name__, optimizer_cls.__name__)

        # Retrieve or create a new subclass, that inherits both from
        # `AnalogOptimizer` and for the specific torch optimizer
        # (`optimizer_cls`).
        if subclass_name not in cls.SUBCLASSES:
            cls.SUBCLASSES[subclass_name] = new_class(subclass_name, (cls, optimizer_cls), {})

        return super().__new__(cls.SUBCLASSES[subclass_name])

    def __init__(
            self,
            optimizer_cls: Type,  # pylint: disable=unused-argument
            *args: Any,
            **kwargs: Any
    ):
        """Initialize the wrapped optimizer.

        Args:
            optimizer_cls: the wrapped optimizer class; already used by
                ``__new__`` to select the subclass, ignored here.
            *args: positional arguments forwarded to the wrapped optimizer.
            **kwargs: keyword arguments forwarded to the wrapped optimizer.
        """
        super().__init__(*args, **kwargs)
class AnalogSGD(AnalogOptimizerMixin, SGD):
    """Stochastic gradient descent with support for analog tiles.

    Combines ``torch.optim.SGD`` with ``AnalogOptimizerMixin``: each ``step``
    first lets SGD update the digital parameters and then triggers the update
    of the analog tiles. The constructor arguments are those of
    ``torch.optim.SGD``.
    """