# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Autograd functions for aihwkit."""

from typing import Any, Optional, Tuple

from torch import Tensor, empty_like
from torch.autograd import Function, no_grad

from aihwkit.optim.context import AnalogContext


class AnalogFunction(Function):
    """Function for analog functions."""

    # pylint: disable=arguments-differ, protected-access, abstract-method
    @staticmethod
    @no_grad()
    def forward(
        ctx: Any,
        analog_ctx: AnalogContext,
        analog_tile: Any,
        input_: Tensor,
        shared_weights: Optional[Tensor] = None,
        is_test: bool = False,
    ) -> Tensor:
        """Execute the forward pass in the analog tile.

        Note:
            Indexed versions can be used when ``analog_ctx.use_indexed`` is set
            to ``True``.
        """
        # Store in context for use during `backward()`.
        ctx.analog_ctx = analog_ctx
        ctx.analog_tile = analog_tile
        ctx.shared_weights = None
        ctx.saved_analog_tensors = [input_]

        runtime = analog_tile.get_runtime()
        use_indexed = analog_ctx.use_indexed

        if shared_weights is not None:
            ctx.shared_weights = shared_weights
            analog_tile.ensure_shared_weights(shared_weights)
            analog_ctx.use_torch_update = True
        else:
            analog_ctx.use_torch_update = False

        # Invoke the forward pass in the tile instance.
        if use_indexed:
            out = analog_tile.joint_forward_indexed(input_, is_test, ctx)
        else:
            out = analog_tile.joint_forward(input_, is_test, ctx)

        if runtime.offload_input:
            ctx.saved_analog_tensors[0] = ctx.saved_analog_tensors[0].cpu()

        ctx.save_for_backward(*ctx.saved_analog_tensors)
        ctx.saved_analog_tensors = []

        return out
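
    # Usage sketch: an analog layer's ``forward()`` would typically dispatch to
    # this Function via ``AnalogFunction.apply``.  The accessor names below
    # (``get_analog_ctx``, ``shared_weights``, ``module.training``) are assumed
    # for illustration and may differ from the actual tile/layer API:
    #
    #     out = AnalogFunction.apply(
    #         analog_tile.get_analog_ctx(),  # AnalogContext bound to the tile
    #         analog_tile,                   # tile performing the analog MVM
    #         x_input,                       # activations
    #         analog_tile.shared_weights,    # None for pure analog training
    #         not module.training,           # is_test: skip updates in eval mode
    #     )
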
    @staticmethod
    @no_grad()
    def backward(
        ctx: Any, grad_output: Tensor
    ) -> Tuple[
        Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor], Optional[Tensor]
    ]:
        """Execute the backward pass in the analog tile."""
        analog_ctx = ctx.analog_ctx
        analog_tile = ctx.analog_tile
        ctx.saved_analog_tensors = ctx.saved_tensors
        input_ = ctx.saved_analog_tensors[0]
        runtime = analog_tile.get_runtime()
        shared_weights_grad = None
        use_indexed = analog_ctx.use_indexed

        if ctx.shared_weights is not None:
            analog_tile.ensure_shared_weights(ctx.shared_weights)

        # Call the backward function in the tile instance.
        if use_indexed:
            grad_input = analog_tile.backward_indexed(grad_output, ctx)
        else:
            grad_input = analog_tile.backward(grad_output, ctx)

        if analog_ctx.use_torch_update:
            if runtime.offload_input:
                input_ = input_.to(analog_tile.device)

            # Gradient is computed directly (for inference training).
            shared_weights_grad = empty_like(ctx.shared_weights)
            analog_tile.set_delta_weights(shared_weights_grad)
            if use_indexed:
                analog_tile.update_indexed(input_, grad_output)
            else:
                analog_tile.update(input_, grad_output)
            analog_tile.reset_delta_weights()
        else:
            # Store activations and error gradients for the optimizer (for analog training).
            analog_ctx.analog_input.append(input_)
            if runtime.offload_gradient:
                store_gradients = grad_output.cpu()
            else:
                store_gradients = grad_output
            analog_ctx.analog_grad_output.append(store_gradients)

        ctx.saved_analog_tensors = []

        return None, None, grad_input, shared_weights_grad, None
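
# Sketch of how the tensors recorded during pure analog training are consumed:
# ``backward()`` only appends the activations and error gradients to the
# ``AnalogContext``; an analog optimizer step is then expected to replay each
# stored pair through the tile update, conceptually along these lines (the
# exact optimizer implementation may differ):
#
#     for x_input, d_output in zip(analog_ctx.analog_input, analog_ctx.analog_grad_output):
#         if analog_ctx.use_indexed:
#             analog_tile.update_indexed(x_input, d_output)
#         else:
#             analog_tile.update(x_input, d_output)
#     analog_ctx.analog_input.clear()
#     analog_ctx.analog_grad_output.clear()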