# Source code for aihwkit.simulator.tiles.array

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# Licensed under the MIT license. See LICENSE file in the project root for details.

# mypy: disable-error-code=attr-defined

"""Implements the analog tile module array."""

from typing import Any, Optional, Tuple, List, TYPE_CHECKING

from torch import Tensor, cat, split, zeros, full
from torch.nn import ModuleList, Parameter, Module
from torch.autograd import no_grad

from aihwkit.simulator.tiles.base import TileModuleBase
from aihwkit.exceptions import TileModuleError
from aihwkit.simulator.digital_low_precision.base_quantized_classes import QuantizedActivation
from aihwkit.simulator.digital_low_precision.config_utils import convert_act_config_to_kwargs_dict
from aihwkit.simulator.digital_low_precision.base_quantized_classes import QuantizationManager
from aihwkit.simulator.digital_low_precision.quantizers import QMethods
from aihwkit.simulator.digital_low_precision.range_estimators import RangeEstimators

if TYPE_CHECKING:
    from aihwkit.simulator.configs.configs import MappableRPU
    from aihwkit.simulator.configs import QuantizedTorchInferenceRPUConfig


class TileModuleArray(Module, TileModuleBase):
    """Logical array of tile modules.

    A logical ``out_size x in_size`` weight matrix is partitioned into a grid
    of physical tiles. Each tile is no larger than
    ``rpu_config.mapping.max_output_size x rpu_config.mapping.max_input_size``.
    A non-positive maximum disables splitting in that dimension. The tiles are
    kept in ``self.array``, a nested ``ModuleList`` indexed as
    ``[input_split][output_split]``.

    Note:
        The bias in the RPUConfig does not have any effect since the bias is
        always concatenated for the logical array and added at the end in
        digital.
    """

    supports_indexed = False

    def __init__(
        self,
        out_size: int,
        in_size: int,
        rpu_config: "MappableRPU",
        bias: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create the grid of tiles.

        Args:
            out_size: output size of the logical layer.
            in_size: input size of the logical layer.
            rpu_config: resource configuration. ``rpu_config.mapping`` provides
                the maximal tile sizes, and ``rpu_config.tile_class`` is
                instantiated once per physical tile.
            bias: whether to create a digital bias of size ``out_size``.
            **kwargs: forwarded to every ``rpu_config.tile_class`` constructor.
        """
        super().__init__()

        # The bias lives on the array (in digital), never on the individual
        # tiles, which are therefore always created with ``bias=False``.
        if bias:
            self.bias = Parameter(zeros(out_size), requires_grad=True)
        else:
            self.bias = None  # type: ignore

        self.in_size = in_size
        self.out_size = out_size

        # More than one tile may be needed: the weight matrix is divided into
        # near-equal pieces along both the input and the output dimension.
        self.in_sizes = self.get_split_sizes(in_size, rpu_config.mapping.max_input_size)
        self.out_sizes = self.get_split_sizes(out_size, rpu_config.mapping.max_output_size)
        self.analog_tile_count = len(self.in_sizes) * len(self.out_sizes)

        self.array = ModuleList()
        for in_tile_size in self.in_sizes:
            in_tiles = ModuleList()
            for out_tile_size in self.out_sizes:
                in_tiles.append(
                    rpu_config.tile_class(
                        out_tile_size, in_tile_size, rpu_config, bias=False, **kwargs
                    )
                )
            self.array.append(in_tiles)

    @no_grad()
    def get_split_sizes(self, size: int, split_max_size: int) -> List[int]:
        """Compute near-equal split sizes for one dimension.

        The dimension is divided into the minimal number of chunks such that
        no chunk exceeds ``split_max_size``. Chunk sizes differ by at most one,
        with the larger chunks first.

        Args:
            size: number of elements of the layer in one dimension.
            split_max_size: max size of the split. A non-positive value
                disables splitting.

        Returns:
            List of split sizes (summing to ``size``).
        """
        # No splitting requested, or it fits in one chunk. This also avoids a
        # division by zero in ``divmod`` below when ``size == 0``.
        if split_max_size <= 0 or size <= split_max_size:
            return [size]

        n_splits = (size + split_max_size - 1) // split_max_size  # ceil division
        base, extra = divmod(size, n_splits)
        # The first ``extra`` chunks take one extra element each.
        return [base + (i < extra) for i in range(n_splits)]

    @no_grad()
    def set_weights(self, weight: Tensor, bias: Optional[Tensor] = None, **kwargs: Any) -> None:
        """Set the weight (and bias) values with given tensors to the analog crossbar(s).

        Args:
            weight: weight matrix, reshaped to ``[out_size, in_size]``.
            bias: bias vector. It is only stored if this array was created
                with a bias.
            **kwargs: see tile level,
                e.g. :meth:`~aihwkit.simulator.tiles.analog.AnalogTile.set_weights`
        """
        weight = weight.detach().reshape([self.out_size, self.in_size])

        in_start = 0
        for in_size, in_tiles in zip(self.in_sizes, self.array):
            in_end = in_start + in_size
            out_start = 0
            for out_size, analog_tile in zip(self.out_sizes, in_tiles):
                out_end = out_start + out_size
                # Tiles never hold a bias; it is handled digitally by the array.
                analog_tile.set_weights(weight[out_start:out_end, in_start:in_end], None, **kwargs)
                out_start = out_end
            in_start = in_end

        if self.bias is not None and bias is not None:
            self.bias.data = bias.detach().to(self.bias.device)

    @no_grad()
    def get_weights(self, **kwargs: Any) -> Tuple[Tensor, Optional[Tensor]]:
        """Get the (combined) weight (and bias) tensors from the analog crossbar(s).

        Args:
            kwargs: see tile level,
                e.g. :meth:`~aihwkit.simulator.tiles.analog.AnalogTile.get_weights`

        Returns:
            tuple: weight matrix, bias vector. The bias is a CPU copy, or
            ``None`` if the array has no bias.
        """
        weight_cols = []
        for in_tiles in self.array:
            # Stack the tiles of one input split along the output dimension (0).
            tile_weights = [analog_tile.get_weights(**kwargs)[0] for analog_tile in in_tiles]
            weight_cols.append(cat(tile_weights, 0))
        # Then join the input splits along the input dimension (1).
        weight = cat(weight_cols, 1)

        if self.bias is not None:
            return weight, self.bias.clone().cpu()
        return weight, None

    def forward(self, x_input: Tensor, tensor_view: Optional[Tuple] = None) -> Tensor:
        """Compute the forward pass.

        Args:
            x_input: input tensor. Its last dimension is split according to
                ``self.in_sizes`` when more than one tile is used.
            tensor_view: optional view shape forwarded to every tile and used
                to broadcast the bias. If ``None``, the bias view is obtained
                from the tile's ``get_tensor_view``.

        Returns:
            Accumulated tile outputs, plus the digital bias if present.
        """
        # pylint: disable=arguments-differ,arguments-renamed
        if self.analog_tile_count == 1:
            analog_tile = self.array[0][0]  # pylint: disable=unsubscriptable-object
            # Forward ``tensor_view`` exactly as in the mapped branch below, so a
            # caller-provided view is honoured for single-tile arrays as well.
            result = analog_tile(x_input, tensor_view=tensor_view)
        else:
            # Mapped version: split the input features across the input splits,
            # concatenate each input split's tile outputs along the feature
            # dimension, and accumulate the partial sums digitally.
            last_dim = x_input.ndim - 1
            splits = split(x_input, self.in_sizes, dim=last_dim)
            result = None
            for x, in_tiles in zip(splits, self.array):
                out_result = []
                for analog_tile in in_tiles:
                    out_result.append(analog_tile(x, tensor_view=tensor_view))
                partial = cat(out_result, last_dim)
                if result is None:
                    result = partial
                else:
                    result.add_(partial)

        if self.bias is None:
            return result
        if tensor_view is None:
            tensor_view = analog_tile.get_tensor_view(result.dim())
        return result + self.bias.view(*tensor_view)

    def _apply_to_tiles(self, method_name: str, *args: Any, **kwargs: Any) -> List[List[Any]]:
        """Call ``method_name`` on every tile and collect the results.

        Args:
            method_name: name of the tile method to call.
            *args: positional arguments passed to every call.
            **kwargs: keyword arguments passed to every call.

        Returns:
            Nested list of return values, indexed ``[input_split][output_split]``
            like ``self.array``.

        Raises:
            TileModuleError: if method does not exist
        """
        out_values = []
        for in_tiles in self.array:
            out_values_row = []
            for analog_tile in in_tiles:
                if not hasattr(analog_tile, method_name):
                    raise TileModuleError(f"Tile does not have method '{method_name}'")
                out_values_row.append(getattr(analog_tile, method_name)(*args, **kwargs))
            out_values.append(out_values_row)
        return out_values
class QuantizedTileModuleArray(TileModuleArray):
    """Logical array of quantized torch inference tile modules.

    It extends the functionality of ``TileModuleArray`` with two kinds of
    quantization:

    * the bias, which is applied here instead of in the individual tiles;
    * the final result of the array, after all the partial results from the
      tiles have been accumulated.

    Only the forward pass differs from ``TileModuleArray``; it adds the bias
    and output quantization.
    """

    def __init__(
        self,
        out_size: int,
        in_size: int,
        rpu_config: "QuantizedTorchInferenceRPUConfig",
        bias: bool = False,
        **kwargs: Any,
    ) -> None:
        """Create the tile grid plus the output and bias quantizers.

        Args:
            out_size: output size of the logical layer.
            in_size: input size of the logical layer.
            rpu_config: quantized inference configuration. Its
                ``act_quant_config`` configures the output quantizer, and
                ``pre_post.periph_quant`` configures the bias quantizer.
            bias: whether to create a digital bias of size ``out_size``.
            **kwargs: forwarded to every tile constructor.
        """
        super().__init__(out_size, in_size, rpu_config, bias, **kwargs)
        self.periph_quant = rpu_config.pre_post.periph_quant

        # Output quantization is only needed when more than one tile is used.
        # Tall layers must be requantized because their output accumulates
        # partial results, so it is no longer in the same bit precision. Wide
        # layers are quantized to equalize the range across the tiles and to
        # keep the precision down in case of bias addition.
        if (
            self.analog_tile_count > 1
            and rpu_config.act_quant_config is not None
            and rpu_config.act_quant_config.n_bits > 0
        ):
            self.module_out_quantizer = QuantizedActivation(
                **convert_act_config_to_kwargs_dict(rpu_config.act_quant_config)
            )
            # Enable the quantization
            self.module_out_quantizer.quantized_acts()
        else:
            # Always define the attribute: ``forward`` reads it unconditionally.
            self.module_out_quantizer = None

        # Initialize the bias quantizer if a quantized periphery is defined.
        periph_quant = self.periph_quant
        if self.bias is not None and periph_quant.n_bits > 0:
            self.bias_quantizer = QuantizationManager(
                qmethod=(
                    QMethods.symmetric_uniform
                    if periph_quant.symmetric
                    else QMethods.asymmetric_uniform
                ),
                qparams={"n_bits": periph_quant.n_bits},
                init=RangeEstimators.current_minmax,
            )
            if periph_quant.learn_quant_params:
                # Counter of training forward passes, used to decide when to
                # switch from estimated to learned ranges. It is a non-trainable
                # Parameter, presumably so it is saved and moved with the module.
                # The bias' device is used since it is guaranteed to exist here.
                self.bias_quant_update_idx = Parameter(
                    full((1,), 0.0, device=self.bias.device), requires_grad=False
                )
        else:
            # Always define the attribute; ``None`` means "add the bias as is".
            self.bias_quantizer = None

    def _digital_bias(self) -> Tensor:
        """Return the bias to add in digital, quantized if a bias quantizer is set.

        This also advances the range handling of the bias quantizer:

        * evaluation with an uninitialized quantizer: estimate the ranges from
          the current bias once, then fix them;
        * training with ``learn_quant_params``: count forward passes, and switch
          to learned ranges once the count exceeds
          ``periph_quant.init_learning_after``.
        """
        bias_quantizer = self.bias_quantizer
        if bias_quantizer is None:
            return self.bias

        if not self.training and not bias_quantizer.quantizer.is_initialized:
            bias_quantizer.estimate_ranges()
            q_bias = bias_quantizer(self.bias)
            bias_quantizer.fix_ranges()
            return q_bias

        if (
            self.training
            and self.periph_quant.learn_quant_params
            and not bias_quantizer.is_learning()
        ):
            # Estimate until ``init_learning_after``, then switch to learned.
            self.bias_quant_update_idx.data += 1
            if self.bias_quant_update_idx > self.periph_quant.init_learning_after:
                bias_quantizer.learn_ranges()
        return bias_quantizer(self.bias)

    def forward(self, x_input: Tensor, tensor_view: Optional[Tuple] = None) -> Tensor:
        """Compute the forward pass, quantizing the final result as appropriate.

        Args:
            x_input: input tensor. Its last dimension is split according to
                ``self.in_sizes`` when more than one tile is used.
            tensor_view: optional view shape forwarded to every tile and used
                to broadcast the bias. If ``None``, the bias view is obtained
                from the tile's ``get_tensor_view``.

        Returns:
            Accumulated tile outputs, plus the (possibly quantized) bias,
            passed through the output quantizer if one is configured.
        """
        # pylint: disable=arguments-differ,arguments-renamed
        if self.analog_tile_count == 1:
            analog_tile = self.array[0][0]  # pylint: disable=unsubscriptable-object
            # Forward ``tensor_view`` as in the mapped branch, so a
            # caller-provided view is honoured for single-tile arrays too.
            result = analog_tile(x_input, tensor_view=tensor_view)
        else:
            # Mapped version: perform the intermediate accumulation of the
            # tall splits here.
            last_dim = x_input.ndim - 1
            splits = split(x_input, self.in_sizes, dim=last_dim)
            result = None
            for x, in_tiles in zip(splits, self.array):
                out_result = []
                for analog_tile in in_tiles:
                    out_result.append(analog_tile(x, tensor_view=tensor_view))
                partial = cat(out_result, last_dim)
                if result is None:
                    result = partial
                else:
                    result.add_(partial)

        if self.bias is not None:
            if tensor_view is None:
                tensor_view = analog_tile.get_tensor_view(result.dim())
            # Out-of-place add: avoids modifying a tile's autograd output in place.
            result = result + self._digital_bias().view(*tensor_view)

        # Quantize the final result after accumulation and bias addition.
        if self.module_out_quantizer is not None:
            result = self.module_out_quantizer(result)
        return result