# -*- coding: utf-8 -*-
# (C) Copyright 2020, 2021, 2022 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Autograd functions for aihwkit."""
from typing import Any, Optional, Tuple
from torch import Tensor, empty_like
from torch.autograd import Function
from aihwkit.optim.context import AnalogContext


class AnalogFunctionBase(Function):
    """Base function for analog functions."""

    # pylint: disable=arguments-differ, protected-access, abstract-method

    @staticmethod
    def forward(
            ctx: Any,
            analog_ctx: AnalogContext,
            input_: Tensor,
            shared_weights: Optional[Tensor] = None,
            is_test: bool = False) -> Tensor:
        """Execute the forward pass in the analog tile.

        Note: the indexed versions can be used when
        ``analog_ctx.use_indexed`` is set to ``True``.
        """
        # Store objects in the context for use during `backward()`.
        analog_tile = analog_ctx.analog_tile
        ctx.analog_ctx = analog_ctx
        ctx.shared_weights = None
        ctx.save_for_backward(input_)

        use_indexed = analog_ctx.use_indexed
        if shared_weights is not None:
            ctx.shared_weights = shared_weights
            analog_tile.ensure_shared_weights(shared_weights)
            analog_ctx.use_torch_update = True
        else:
            analog_ctx.use_torch_update = False

        # Invoke the forward pass in the tile instance.
        if use_indexed:
            return analog_tile.forward_indexed(input_, is_test)
        return analog_tile.forward(input_, is_test)
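
    # A minimal invocation sketch (names are illustrative): as a
    # `torch.autograd.Function`, this class is never called directly but
    # through `apply()`, which records the op in the autograd graph, e.g.
    #
    #     out = AnalogFunction.apply(tile.analog_ctx, x, tile.shared_weights, False)
    #
    # assuming the tile exposes its `AnalogContext` as `tile.analog_ctx` and
    # its (optional) shared weight tensor as `tile.shared_weights`.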

    @staticmethod
    def backward(
            ctx: Any,
            grad_output: Tensor,
    ) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
        """Execute the backward pass in the analog tile."""
        analog_ctx = ctx.analog_ctx
        analog_tile = analog_ctx.analog_tile
        input_, = ctx.saved_tensors
        shared_weights_grad = None
        use_indexed = analog_ctx.use_indexed

        if ctx.shared_weights is not None:
            analog_tile.ensure_shared_weights(ctx.shared_weights)

        # Call the backward function in the tile instance.
        if use_indexed:
            grad_input = analog_tile.backward_indexed(grad_output)
        else:
            grad_input = analog_tile.backward(grad_output)

        if analog_ctx.use_torch_update:
            # Compute the weight gradient directly (for inference training):
            # the tile writes its update into `shared_weights_grad`, which a
            # regular torch optimizer then applies to the shared weights.
            shared_weights_grad = empty_like(ctx.shared_weights)
            analog_tile.set_delta_weights(shared_weights_grad)
            if use_indexed:
                analog_tile.update_indexed(input_, grad_output)
            else:
                analog_tile.update(input_, grad_output)
            analog_tile.reset_delta_weights()
        else:
            # Store activations and errors for the analog optimizer (for
            # analog training), which performs the update on the tile itself.
            analog_ctx.analog_input.append(input_)
            analog_ctx.analog_grad_output.append(grad_output)

        # One gradient per forward input: `analog_ctx`, `input_`,
        # `shared_weights` and `is_test` (the last is not differentiable).
        return None, grad_input, shared_weights_grad, None
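
# Sketch of the two update paths above (hypothetical optimizer setup, not
# part of this module): with `shared_weights`, the weight gradient lands in
# `shared_weights.grad` through autograd, so a plain torch optimizer can
# step on that tensor directly:
#
#     opt = torch.optim.SGD([tile.shared_weights], lr=0.1)
#     loss.backward()   # backward() above fills the weight gradient
#     opt.step()
#
# Without `shared_weights`, no weight gradient is returned; the tensors
# stashed on `analog_ctx.analog_input` / `analog_ctx.analog_grad_output`
# are instead consumed later by an analog-aware optimizer (such as
# aihwkit's `AnalogSGD`) to trigger the tile's in-memory update.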


class AnalogFunction(AnalogFunctionBase):
    """Function that delegates into an `RPU` unit."""

    # pylint: disable=arguments-differ, abstract-method

    @staticmethod
    def forward(
            ctx: Any,
            analog_ctx: AnalogContext,
            input_: Tensor,
            shared_weights: Optional[Tensor] = None,
            is_test: bool = False) -> Tensor:
        """Execute the forward pass in the analog tile."""
        analog_ctx.use_indexed = False
        return AnalogFunctionBase.forward(
            ctx, analog_ctx, input_, shared_weights, is_test)


class AnalogIndexedFunction(AnalogFunctionBase):
    """Function that delegates into an `RPU` unit using the indexed forward/backward/update."""

    # pylint: disable=arguments-differ, abstract-method

    @staticmethod
    def forward(
            ctx: Any,
            analog_ctx: AnalogContext,
            input_: Tensor,
            shared_weights: Optional[Tensor] = None,
            is_test: bool = False) -> Tensor:
        """Execute the forward pass in the analog tile."""
        analog_ctx.use_indexed = True
        return AnalogFunctionBase.forward(
            ctx, analog_ctx, input_, shared_weights, is_test)
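

# A minimal end-to-end sketch (not part of the module API): it assumes a
# default-configured `AnalogTile` from `aihwkit.simulator.tiles` and that the
# tile exposes its `AnalogContext` as `tile.analog_ctx`; sizes are illustrative.
if __name__ == '__main__':
    from torch import randn
    from aihwkit.simulator.tiles import AnalogTile

    tile = AnalogTile(3, 4)  # analog tile with 4 inputs and 3 outputs

    # Forward a batch of 5 input vectors through the custom autograd function;
    # `is_test=True` marks an inference (evaluation) pass.
    out = AnalogFunction.apply(tile.analog_ctx, randn(5, 4), None, True)
    print(out.shape)  # expected: torch.Size([5, 3])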