Source code for aihwkit.nn.modules.conv

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Convolution layers."""

# pylint: disable=too-many-arguments, too-many-locals, too-many-instance-attributes

from typing import Optional, Tuple, Union, List, Type

from torch import Tensor, arange, cat, float64, int32, ones
from torch.autograd import no_grad
from torch.nn.functional import pad, unfold
from torch.nn.modules.conv import _ConvNd, Conv1d, Conv2d, Conv3d
from torch.nn.modules.utils import _single, _pair, _triple

from aihwkit.exceptions import ModuleError
from aihwkit.nn.modules.base import AnalogLayerBase
from aihwkit.simulator.parameters.base import RPUConfigBase


class _AnalogConvNd(AnalogLayerBase, _ConvNd):
    """Base class for convolution layers."""

    NEEDS_INDEXED = False

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple[int, ...],
        stride: Tuple[int, ...],
        padding: Tuple[int, ...],
        dilation: Tuple[int, ...],
        transposed: bool,
        output_padding: Tuple[int, ...],
        groups: int,
        bias: bool,
        padding_mode: str,
        rpu_config: Optional[RPUConfigBase] = None,
        tile_module_class: Optional[Type] = None,
        use_indexed: Optional[bool] = None,
    ):
        if groups != 1:
            raise ValueError("Only one group is supported")
        if padding_mode != "zeros":
            raise ValueError('Only "zeros" padding mode is supported')

        # Call super()
        _ConvNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            bias,
            padding_mode,
        )

        # Create the tile and set the analog.
        AnalogLayerBase.__init__(self)

        if rpu_config is None:
            # pylint: disable=import-outside-toplevel
            from aihwkit.simulator.configs.configs import SingleRPUConfig

            rpu_config = SingleRPUConfig()

        if tile_module_class is None:
            tile_module_class = rpu_config.get_default_tile_module_class()
        self.in_features = self.get_tile_size(in_channels, groups, kernel_size)
        self.out_features = out_channels
        self.analog_module = tile_module_class(
            self.out_features, self.in_features, rpu_config, bias
        )

        # Set the index matrices.
        self.use_indexed = use_indexed
        if not self.analog_module.supports_indexed:
            self.use_indexed = False

        self.fold_indices = Tensor().detach()
        self.input_size = 0
        self.tensor_view = (-1,)  # type: Tuple[int, ...]

        # Unregister weight/bias as a parameter but keep it for syncs
        self.unregister_parameter("weight")
        if bias:
            self.unregister_parameter("bias")
        else:
            # seems to be a torch bug
            self._parameters.pop("bias", None)
        self.bias = bias

        self.reset_parameters()

    def get_tile_size(self, in_channels: int, groups: int, kernel_size: Tuple[int, ...]) -> int:
        """Calculate the tile size."""
        raise NotImplementedError

    def get_image_size(self, size: int, i: int) -> int:
        """Calculate the output image sizes."""
        # pylint: disable=superfluous-parens
        nom = size + 2 * self.padding[i] - self.dilation[i] * (self.kernel_size[i] - 1) - 1
        return nom // self.stride[i] + 1

    def reset_parameters(self) -> None:
        """Reset the parameters (weight and bias)."""
        if hasattr(self, "analog_module"):
            bias = self.bias
            self.weight, self.bias = self.get_weights()  # type: ignore
            super().reset_parameters()
            self.set_weights(self.weight, self.bias)
            self.weight, self.bias = None, bias

    @no_grad()
    def _recalculate_indexes(self, x_input: Tensor) -> None:
        """Calculate and set the indexes of the analog tile."""

        self.fold_indices, image_sizes, self.input_size = self._calculate_indexes(
            x_input, self.in_channels
        )
        self.analog_module.set_indexed(self.fold_indices, image_sizes)

    @no_grad()
    def _calculate_indexes(
        self, x_input: Tensor, in_channels: int
    ) -> Tuple[Tensor, List[int], int]:
        """Calculate and return the fold indexes and sizes.

        Args:
            x_input: input matrix
            in_channels: number of input channel

        Returns:
            fold_indices: indices for the analog tile
            image_sizes: image sizes for the analog tile
            input_size: size of the current input
        """
        raise NotImplementedError

    def forward(self, x_input: Tensor) -> Tensor:
        """Compute the forward pass.

        Raises:
            ModuleError: in case indexed convolution is needed but not supported by the TileModule.
        """

        # Use indexed only in case of cuda.
        use_indexed = self.use_indexed
        if use_indexed is None and not self.NEEDS_INDEXED:
            use_indexed = self.analog_module.is_cuda
        if not use_indexed and self.NEEDS_INDEXED:
            raise ModuleError("Tile module does not support indexed computation.")
        if use_indexed:
            input_size = x_input.numel() / x_input.size(0)
            if self.input_size != input_size or not self.analog_module.is_indexed():
                self._recalculate_indexes(x_input)

            return self.analog_module(x_input, tensor_view=self.tensor_view)

        # Brute-force unfold.
        im_shape = x_input.shape
        x_input_ = unfold(
            x_input,
            kernel_size=self.kernel_size,
            dilation=self.dilation,
            padding=self.padding,
            stride=self.stride,
        ).transpose(1, 2)

        out = self.analog_module(x_input_).transpose(1, 2)
        out_size = (
            im_shape[2] + 2 * self.padding[0] - self.dilation[0] * (self.kernel_size[0] - 1) - 1
        ) // self.stride[0] + 1
        return out.view(im_shape[0], self.out_channels, out_size, -1)


[docs]class AnalogConv1d(_AnalogConvNd): """1D convolution layer that uses an analog tile. Applies a 1D convolution over an input signal composed of several input planes, using an analog tile for its forward, backward and update passes. Note: The tensor parameters of this layer (``.weight`` and ``.bias``) are not guaranteed to contain the same values as the internal weights and biases stored in the analog tile. Please use ``set_weights`` and ``get_weights`` when attempting to read or modify the weight/bias. This read/write process can simulate the (noisy and inexact) analog writing and reading of the resistive elements. Args: in_channels: number of channels in the input image. out_channels: number of channels produced by the convolution. kernel_size: size of the convolving kernel. stride: stride of the convolution. padding: zero-padding added to both sides of the input. dilation: spacing between kernel elements. groups: number of blocked connections from input channels to output channels. bias: whether to use a bias row on the analog tile or not. padding_mode: padding strategy. Only ``'zeros'`` is supported. rpu_config: resistive processing unit configuration. tile_module_class: Class for the tile module (default will be specified from the ``RPUConfig``). """ # pylint: disable=abstract-method NEEDS_INDEXED = True def __init__( self, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple], stride: Union[int, Tuple] = 1, padding: Union[int, Tuple] = 0, dilation: Union[int, Tuple] = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", rpu_config: Optional["RPUConfigBase"] = None, tile_module_class: Optional[Type] = None, ): # pylint: disable=too-many-arguments kernel_size = _single(kernel_size) stride = _single(stride) padding = _single(padding) dilation = _single(dilation) if dilation != _single(1): raise ValueError("Only dilation = 1 is supported") super().__init__( in_channels, out_channels, kernel_size, # type: ignore stride, # type: ignore padding, # type: ignore dilation, # type: ignore False, _single(0), groups, bias, padding_mode, rpu_config, tile_module_class, True, ) self.tensor_view = (-1, 1)
[docs] @classmethod def from_digital( cls, module: Conv1d, rpu_config: "RPUConfigBase", tile_module_class: Optional[Type] = None ) -> "AnalogConv1d": """Return an AnalogConv1d layer from a torch Conv1d layer. Args: module: The torch module to convert. All layers that are defined in the ``conversion_map``. rpu_config: RPU config to apply to all converted tiles. Applied to all converted tiles. tile_module_class: Class for the tile module (default will be specified from the ``RPUConfig``). Returns: an AnalogConv1d layer based on the digital Conv1d ``module``. """ analog_layer = cls( module.in_channels, module.out_channels, module.kernel_size, module.stride, module.padding, module.dilation, module.groups, module.bias is not None, module.padding_mode, rpu_config, tile_module_class, ) analog_layer.set_weights(module.weight, module.bias) return analog_layer.to(module.weight.device)
[docs] @classmethod def to_digital(cls, module: "AnalogConv1d", realistic: bool = False) -> Conv1d: """Return an nn.Conv1d layer from an AnalogConv1d layer. Args: module: The analog module to convert. realistic: whehter to estimate the weights with the non-ideal forward pass. If not set, analog weights are (unrealistically) copies exactly Returns: an torch Linear layer with the same dimension and weights as the analog linear layer. """ weight, bias = module.get_weights(realistic=realistic) digital_layer = Conv1d( module.in_channels, module.out_channels, module.kernel_size, module.stride, module.padding, module.dilation, module.groups, bias is not None, module.padding_mode, ) digital_layer.weight.data = weight.data.view(-1, module.in_channels, *module.kernel_size) if bias is not None: digital_layer.bias.data = bias.data analog_tile = next(module.analog_tiles()) return digital_layer.to(device=analog_tile.device, dtype=analog_tile.get_dtype())
[docs] def get_tile_size(self, in_channels: int, groups: int, kernel_size: Tuple[int, ...]) -> int: """Calculate the tile size.""" return (in_channels // groups) * kernel_size[0]
def _calculate_indexes( self, x_input: Tensor, in_channels: int ) -> Tuple[Tensor, List[int], int]: """Calculate and return the fold indexes and sizes. Args: x_input: input matrix in_channels: number of input channel Returns: fold_indices: indices for the analog tile image_sizes: image sizes for the analog tile input_size: size of the current input """ input_size = x_input.numel() / x_input.size(0) # pytorch just always uses NCHW order fold_indices = arange(2, x_input.size(2) + 2, dtype=float64).detach() shape = [1, 1] + list(x_input.shape[2:]) fold_indices = fold_indices.reshape(*shape) if not all(item == 0 for item in self.padding): fold_indices = pad( fold_indices, pad=[self.padding[0], self.padding[0]], mode="constant", value=0 ) unfolded = fold_indices.unfold(2, self.kernel_size[0], self.stride[0]).clone() fold_indices = unfolded.reshape(-1, self.kernel_size[0]).transpose(0, 1).flatten().round() # concatenate the matrix index for different channels fold_indices_orig = fold_indices.clone() for i in range(in_channels - 1): fold_indices_tmp = fold_indices_orig.clone() for j in range(fold_indices_orig.size(0)): if fold_indices_orig[j] != 0: fold_indices_tmp[j] += (input_size / in_channels) * (i + 1) fold_indices = cat([fold_indices, fold_indices_tmp], dim=0).clone() fold_indices = fold_indices.to(dtype=int32) if self.analog_module.analog_bias: out_image_size = fold_indices.numel() // (self.kernel_size[0]) fold_indices = cat((fold_indices, ones(out_image_size, dtype=int32)), 0) fold_indices = fold_indices.to(x_input.device) x_height = x_input.size(2) d_height = self.get_image_size(x_height, 0) image_sizes = [in_channels, x_height, d_height] return (fold_indices, image_sizes, input_size)
[docs]class AnalogConv2d(_AnalogConvNd): """2D convolution layer that uses an analog tile. Applies a 2D convolution over an input signal composed of several input planes, using an analog tile for its forward, backward and update passes. Note: The tensor parameters of this layer (``.weight`` and ``.bias``) are not guaranteed to contain the same values as the internal weights and biases stored in the analog tile. Please use ``set_weights`` and ``get_weights`` when attempting to read or modify the weight/bias. This read/write process can simulate the (noisy and inexact) analog writing and reading of the resistive elements. Args: in_channels: number of channels in the input image. out_channels: number of channels produced by the convolution. kernel_size: size of the convolving kernel. stride: stride of the convolution. padding: zero-padding added to both sides of the input. dilation: spacing between kernel elements. groups: number of blocked connections from input channels to output channels. bias: whether to use a bias row on the analog tile or not. padding_mode: padding strategy. Only ``'zeros'`` is supported. rpu_config: resistive processing unit configuration. tile_module_class: Class for the tile module (default will be specified from the ``RPUConfig``). use_indexed: Whether to use explicit unfolding or implicit indexing. If None (default), it will use implicit indexing for CUDA and explicit unfolding for CPU """ # pylint: disable=abstract-method def __init__( self, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple], stride: Union[int, Tuple] = 1, padding: Union[int, Tuple] = 0, dilation: Union[int, Tuple] = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", rpu_config: Optional["RPUConfigBase"] = None, tile_module_class: Optional[Type] = None, use_indexed: Optional[bool] = None, ): # pylint: disable=too-many-arguments kernel_size = _pair(kernel_size) stride = _pair(stride) padding = _pair(padding) dilation = _pair(dilation) super().__init__( in_channels, out_channels, kernel_size, # type: ignore stride, # type: ignore padding, # type: ignore dilation, # type: ignore False, _pair(0), groups, bias, padding_mode, rpu_config, tile_module_class, use_indexed, ) self.tensor_view = (-1, 1, 1)
[docs] @classmethod def from_digital( cls, module: Conv2d, rpu_config: "RPUConfigBase", tile_module_class: Optional[Type] = None ) -> "AnalogConv2d": """Return an AnalogConv2d layer from a torch Conv2d layer. Args: module: The torch module to convert. All layers that are defined in the ``conversion_map``. rpu_config: RPU config to apply to all converted tiles. Applied to all converted tiles. tile_module_class: Class for the tile module (default will be specified from the ``RPUConfig``). Returns: an AnalogConv2d layer based on the digital Conv2d ``module``. """ analog_layer = cls( module.in_channels, module.out_channels, module.kernel_size, module.stride, module.padding, module.dilation, module.groups, module.bias is not None, module.padding_mode, rpu_config, tile_module_class, ) analog_layer.set_weights(module.weight, module.bias) return analog_layer.to(module.weight.device)
[docs] @classmethod def to_digital(cls, module: "AnalogConv2d", realistic: bool = False) -> Conv2d: """Return an nn.Conv2d layer from an AnalogConv2d layer. Args: module: The analog module to convert. realistic: whehter to estimate the weights with the non-ideal forward pass. If not set, analog weights are (unrealistically) copies exactly Returns: an torch Linear layer with the same dimension and weights as the analog linear layer. """ weight, bias = module.get_weights(realistic=realistic) digital_layer = Conv2d( module.in_channels, module.out_channels, module.kernel_size, module.stride, module.padding, module.dilation, module.groups, bias is not None, module.padding_mode, ) digital_layer.weight.data = weight.data.view(-1, module.in_channels, *module.kernel_size) if bias is not None: digital_layer.bias.data = bias.data analog_tile = next(module.analog_tiles()) return digital_layer.to(device=analog_tile.device, dtype=analog_tile.get_dtype())
[docs] def get_tile_size(self, in_channels: int, groups: int, kernel_size: Tuple[int, ...]) -> int: """Calculate the tile size.""" return (in_channels // groups) * kernel_size[0] * kernel_size[1]
def _calculate_indexes( self, x_input: Tensor, in_channels: int ) -> Tuple[Tensor, List[int], int]: """Calculate and return the fold indexes and sizes. Args: x_input: input matrix in_channels: number of input channel Returns: fold_indices: indices for the analog tile image_sizes: image sizes for the analog tile input_size: size of the current input """ input_size = x_input.numel() / x_input.size(0) # pytorch just always uses NCHW order fold_indices = arange(2, input_size + 2, dtype=float64).detach() shape = [1] + list(x_input.shape[1:]) fold_indices = fold_indices.reshape(*shape) fold_indices = ( unfold( fold_indices, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.dilation, ) .flatten() .round() .to(dtype=int32) ) if self.analog_module.analog_bias: out_image_size = fold_indices.numel() // (self.kernel_size[0] * self.kernel_size[1]) fold_indices = cat((fold_indices, ones(out_image_size, dtype=int32)), 0) fold_indices = fold_indices.to(x_input.device) x_height = x_input.size(2) x_width = x_input.size(3) d_height = self.get_image_size(x_height, 0) d_width = self.get_image_size(x_width, 1) image_sizes = [in_channels, x_height, x_width, d_height, d_width] return (fold_indices, image_sizes, input_size)
[docs]class AnalogConv3d(_AnalogConvNd): """3D convolution layer that uses an analog tile. Applies a 3D convolution over an input signal composed of several input planes, using an analog tile for its forward, backward and update passes. Note: The tensor parameters of this layer (``.weight`` and ``.bias``) are not guaranteed to contain the same values as the internal weights and biases stored in the analog tile. Please use ``set_weights`` and ``get_weights`` when attempting to read or modify the weight/bias. This read/write process can simulate the (noisy and inexact) analog writing and reading of the resistive elements. Args: in_channels: number of channels in the input image. out_channels: number of channels produced by the convolution. kernel_size: size of the convolving kernel. stride: stride of the convolution. padding: zero-padding added to both sides of the input. dilation: spacing between kernel elements. groups: number of blocked connections from input channels to output channels. bias: whether to use a bias row on the analog tile or not. padding_mode: padding strategy. Only ``'zeros'`` is supported. rpu_config: resistive processing unit configuration. tile_module_class: Class for the tile module (default will be specified from the ``RPUConfig``). """ # pylint: disable=abstract-method NEEDS_INDEXED = True def __init__( self, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple], stride: Union[int, Tuple] = 1, padding: Union[int, Tuple] = 0, dilation: Union[int, Tuple] = 1, groups: int = 1, bias: bool = True, padding_mode: str = "zeros", rpu_config: Optional["RPUConfigBase"] = None, tile_module_class: Optional[Type] = None, ): # pylint: disable=too-many-arguments kernel_size = _triple(kernel_size) stride = _triple(stride) padding = _triple(padding) dilation = _triple(dilation) if dilation != _triple(1): raise ValueError("Only dilation = 1 is supported") super().__init__( in_channels, out_channels, kernel_size, # type: ignore stride, # type: ignore padding, # type: ignore dilation, # type: ignore False, _triple(0), groups, bias, padding_mode, rpu_config, tile_module_class, True, ) self.tensor_view = (-1, 1, 1, 1)
[docs] @classmethod def from_digital( cls, module: Conv3d, rpu_config: "RPUConfigBase", tile_module_class: Optional[Type] = None ) -> "AnalogConv3d": """Return an AnalogConv3d layer from a torch Conv3d layer. Args: module: The torch module to convert. All layers that are defined in the ``conversion_map``. rpu_config: RPU config to apply to all converted tiles. Applied to all converted tiles. tile_module_class: Class for the tile module (default will be specified from the ``RPUConfig``). Returns: an AnalogConv3d layer based on the digital Conv3d ``module``. """ analog_layer = cls( module.in_channels, module.out_channels, module.kernel_size, module.stride, module.padding, module.dilation, module.groups, module.bias is not None, module.padding_mode, rpu_config, tile_module_class, ) analog_layer.set_weights(module.weight, module.bias) return analog_layer.to(module.weight.device)
[docs] @classmethod def to_digital(cls, module: "AnalogConv3d", realistic: bool = False) -> Conv3d: """Return an nn.Conv3d layer from an AnalogConv3d layer. Args: module: The analog module to convert. realistic: whehter to estimate the weights with the non-ideal forward pass. If not set, analog weights are (unrealistically) copies exactly Returns: an torch Linear layer with the same dimension and weights as the analog linear layer. """ weight, bias = module.get_weights(realistic=realistic) digital_layer = Conv3d( module.in_channels, module.out_channels, module.kernel_size, module.stride, module.padding, module.dilation, module.groups, bias is not None, module.padding_mode, ) digital_layer.weight.data = weight.data.view(-1, module.in_channels, *module.kernel_size) if bias is not None: digital_layer.bias.data = bias.data analog_tile = next(module.analog_tiles()) return digital_layer.to(device=analog_tile.device, dtype=analog_tile.get_dtype())
[docs] def get_tile_size(self, in_channels: int, groups: int, kernel_size: Tuple[int, ...]) -> int: """Calculate the tile size.""" return (in_channels // groups) * (kernel_size[0] * kernel_size[1] * kernel_size[2])
def _calculate_indexes( self, x_input: Tensor, in_channels: int ) -> Tuple[Tensor, List[int], int]: """Calculate and return the fold indexes and sizes. Args: x_input: input matrix in_channels: then number of in channels Returns: fold_indices: indices for the analog tile image_sizes: image sizes for the analog tile input_size: size of the current input """ # pylint: disable=too-many-locals input_size = x_input.numel() / x_input.size(0) # pytorch just always uses NCDHW order fold_indices = arange( 2, x_input.size(2) * x_input.size(3) * x_input.size(4) + 2, dtype=float64 ).detach() shape = [1] + [1] + list(x_input.shape[2:]) fold_indices = fold_indices.reshape(*shape) if not all(item == 0 for item in self.padding): fold_indices = pad( fold_indices, pad=[ self.padding[2], self.padding[2], self.padding[1], self.padding[1], self.padding[0], self.padding[0], ], mode="constant", value=0, ) unfolded = ( fold_indices.unfold(2, self.kernel_size[0], self.stride[0]) .unfold(3, self.kernel_size[1], self.stride[1]) .unfold(4, self.kernel_size[2], self.stride[2]) .clone() ) fold_indices = ( unfolded.reshape(-1, self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2]) .transpose(0, 1) .flatten() .round() ) # concatenate the matrix index for different channels fold_indices_orig = fold_indices.clone() for i in range(in_channels - 1): fold_indices_tmp = fold_indices_orig.clone() for j in range(fold_indices_orig.size(0)): if fold_indices_orig[j] != 0: fold_indices_tmp[j] += (input_size / in_channels) * (i + 1) fold_indices = cat([fold_indices, fold_indices_tmp], dim=0).clone() fold_indices = fold_indices.to(dtype=int32) if self.analog_module.analog_bias: out_image_size = fold_indices.numel() // ( self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2] ) fold_indices = cat((fold_indices, ones(out_image_size, dtype=int32)), 0) fold_indices = fold_indices.to(x_input.device) x_depth = x_input.size(2) x_height = x_input.size(3) x_width = x_input.size(4) d_depth = self.get_image_size(x_depth, 0) d_height = self.get_image_size(x_height, 1) d_width = self.get_image_size(x_width, 2) image_sizes = [in_channels, x_depth, x_height, x_width, d_depth, d_height, d_width] return (fold_indices, image_sizes, input_size)