# -*- coding: utf-8 -*-
# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
# pylint: disable=too-many-locals, too-many-arguments
"""Low level implementation of torch-based tile."""
from typing import Any, Tuple, List, Optional
from math import log2
from torch import (
Tensor,
zeros,
sum as torch_sum,
flip,
abs as torch_abs,
sign,
floor,
fmod,
allclose,
linspace,
)
from torch.autograd import no_grad
from torch.nn.functional import pad
from aihwkit.extension import extension_ops # type: ignore
from aihwkit.inference.converter.conductance import SinglePairConductanceConverter
from aihwkit.exceptions import ConfigError
from aihwkit.simulator.tiles.analog_mvm import AnalogMVM
from aihwkit.simulator.parameters.enums import (
NoiseManagementType,
BoundManagementType,
AnalogMVType,
WeightNoiseType,
)
from aihwkit.simulator.parameters.io import IOParameters, IOParametersIRDropT
class AnalogMVMIRDropT(AnalogMVM):
"""Torch Perseus implementation of (part of) the IO-managed forward /
backward pass in RPUCuda.
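
    Example:
        This tile is normally selected indirectly through an RPU
        configuration rather than instantiated by hand; a minimal sketch,
        assuming the ``TorchInferenceRPUConfigIRDropT`` config class is
        available::

            from aihwkit.nn import AnalogLinear
            from aihwkit.simulator.configs import TorchInferenceRPUConfigIRDropT

            rpu_config = TorchInferenceRPUConfigIRDropT()
            rpu_config.forward.ir_drop = 1.0  # enable time-sliced IR drop
            layer = AnalogLinear(512, 256, rpu_config=rpu_config)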
"""
# pylint: disable=arguments-differ
@classmethod
def _get_res(cls, res: float) -> float:
"""Return resolution as number less than 1
Args:
res: resolution specified either less than or greater than 1
Returns:
float resolution specified as a number less than 1
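
        Example:
            An 8-bit input resolution can be specified either way; both calls
            below are equivalent (illustrative values)::

                AnalogMVMIRDropT._get_res(256.0)      # -> 1 / 256
                AnalogMVMIRDropT._get_res(1 / 256.0)  # -> 1 / 256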
"""
res = 1.0 / res if res > 1.0 else res
assert res > 0, "resolution is <= 0"
return res
@classmethod
def _interleave_cols_2d(cls, mvm1: Tensor, mvm2: Tensor) -> Tensor:
"""Returns 2D matrix with columns interleaved, starting with mvm1
Args:
mvm1: ``[batch_size, out_size/2]`` output activations (to south ADCs)
mvm2: ``[batch_size, out_size/2]`` output activations (to north ADCs)
Returns:
Column-wise interleaved 2D matrix of output activations that captures
IR drop in both directions (to north and south ADCs) in a
symmetric tile design.
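
        Example:
            Columns are interleaved one by one (illustrative values)::

                from torch import tensor

                mvm1 = tensor([[1.0, 2.0]])
                mvm2 = tensor([[3.0, 4.0]])
                AnalogMVMIRDropT._interleave_cols_2d(mvm1, mvm2)
                # -> tensor([[1.0, 3.0, 2.0, 4.0]])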
"""
mvm = zeros((mvm1.shape[0], mvm1.shape[1] + mvm2.shape[1])).to(mvm1.device)
mvm[:, 0::2] = mvm1
mvm[:, 1::2] = mvm2
return mvm
@classmethod
def _pad_symmetric(
cls, input_: Tensor, weight: Tensor, phys_input_size: int = 512
) -> Tuple[Tensor, Tensor]:
"""Return input_ (activations) and weights symmetrically padded
with zeros to mimic symmetric ADCs (north and south).
Args:
input_: ``[batch_size, in_size]`` input_ (activations).
weight: ``[in_size, out_size]`` weight matrix.
phys_input_size: number of hardware tile rows
Returns:
Tuple[Tensor] containing the symmetrically 0-padded input_ and weight
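
        Example:
            A 510-row weight matrix on a 512-row tile gets one zero row added
            on each side, and the activations are padded to match (illustrative
            shapes)::

                from torch import zeros

                input_, weight = AnalogMVMIRDropT._pad_symmetric(
                    zeros((8, 510)), zeros((510, 32)), phys_input_size=512
                )
                # input_.shape -> (8, 512), weight.shape -> (512, 32)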
"""
        pad_total = phys_input_size - weight.shape[0]
        if pad_total <= 0:
            return input_, weight
        pad1 = pad_total // 2
        pad2 = pad_total - pad1
input_ = pad(input_, (pad1, pad2, 0, 0), "constant", 0)
weight = pad(weight, (0, 0, pad1, pad2), "constant", 0)
return input_, weight
@classmethod
def _prepare_inputs(
cls,
input_: Tensor,
scale: Tensor,
scaling: bool,
with_asymmetry: bool,
io_pars: IOParametersIRDropT,
) -> List[Tensor]:
"""
Returns list of activations that will be applied to MVM tile
depending on the mode of operation. For instance, ONE_PASS
will simply return a list containing one set of activations
to be applied. SPLIT_MODE will return a list of two sets
of activations that will be applied to the MVM tile. This
results in higher throughput/energy efficiency, but may
amplify some output noise and sacrifice accuracy. BIT_WISE
will return a list of activations for each bit, which
depends on the inp_res. SPLIT_MODE and BIT_WISE will
appropriately bit-shift the output results to perform
the MVM operation correctly.

        Args:
            input_: ``[N, in_size]`` input activations.
            scale: scale for rescaling input activations.
            scaling: whether to rescale input activations.
            with_asymmetry: whether to apply the input asymmetry
                (``io_pars.inp_asymmetry``) to negative inputs.
            io_pars: forward pass configuration.

        Returns:
            List[Tensor] containing the input activations to be sequentially
            applied to the MVM tile.

        Raises:
            NotImplementedError: If choices in the rpu config were made that
                are not supported.
            ConfigError: If an unknown ``AnalogMVType`` is given.
        """
res = cls._get_res(io_pars.inp_res)
prepared_input = super(AnalogMVMIRDropT, cls)._prepare_input(
input_, scale, scaling, with_asymmetry, io_pars
)
if io_pars.mv_type == AnalogMVType.ONE_PASS:
prepared_input = [prepared_input]
elif io_pars.mv_type == AnalogMVType.SPLIT_MODE:
n_bits = int(log2(1.0 / res))
if not log2(1.0 / res) % 1 == 0:
raise ConfigError(
f"inp_res={1. / res} must be power of 2 (or 1/2**n_bits) "
"for AnalogMVType.SPLIT_MODE"
)
if not io_pars.split_mode_bit_shift % 1 == 0:
raise ConfigError(
f"split_mode_bit_shift={io_pars.split_mode_bit_shift}" " must be integer"
)
if not io_pars.split_mode_bit_shift < n_bits:
raise ConfigError(
f"split_mode_bit_shift={io_pars.split_mode_bit_shift} "
f"cannot exceed equivalent bits specified by inp_res={n_bits}."
)
int_input = prepared_input / (2 * res)
upper_bits = sign(int_input) * floor(
torch_abs(int_input) / (2**io_pars.split_mode_bit_shift)
)
prepared_input_msb = upper_bits * (2 * res)
# +/- remainders
lower_bits = fmod(int_input, 2**io_pars.split_mode_bit_shift)
prepared_input_lsb = lower_bits * (2 * res)
if not allclose(
(2**io_pars.split_mode_bit_shift) * prepared_input_msb + prepared_input_lsb,
prepared_input,
):
raise ConfigError("Split mode pwm conversion error")
prepared_input = [prepared_input_lsb, prepared_input_msb]
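            # Illustrative sanity check (assuming res = 1/256 and
            # split_mode_bit_shift = 4): an input of 100 / 128 maps to
            # int_input = 100, which splits into upper_bits = 6 (MSB pass) and
            # lower_bits = 4 (LSB pass); 2**4 * 6 + 4 == 100 recovers the
            # original integer, which is exactly what the allclose check above
            # verifies.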
elif io_pars.mv_type == AnalogMVType.BIT_WISE:
int_input = prepared_input / (2 * res)
n_bits = int(log2(1.0 / res))
prepared_input = []
for _ in range(n_bits):
# fmod for +/- remainders
lsb = fmod(int_input, 2)
int_input = sign(int_input) * floor(torch_abs(int_input) / 2)
prepared_input.append(lsb * (2 * res))
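            # Illustrative example (assuming res = 1/8, i.e. n_bits = 3): an
            # input of 5/4 maps to int_input = 5 and is streamed LSB-first as
            # the bit planes [1, 0, 1]; _compute_analog_mv later re-weights
            # each one-bit pass by 2**bit_pos when summing the outputs.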
elif io_pars.mv_type in [
AnalogMVType.POS_NEG_SEPARATE,
AnalogMVType.POS_NEG_SEPARATE_DIGITAL_SUM,
]:
raise NotImplementedError
else:
raise ConfigError(f"Unknown AnalogMVType {io_pars.mv_type}")
return prepared_input
@classmethod
@no_grad()
def _thev_equiv(
cls,
input_: Tensor,
weight: Tensor,
g_converter: Optional[SinglePairConductanceConverter] = None,
time_steps: int = 128,
t_max: float = 1.0,
segments: int = 8,
r_s: float = 0.15,
phys_input_size: int = 512,
use_extension: bool = True,
) -> Tuple[Tensor, Tensor]:
"""Returns the Thevenin voltages and resistances for an entire 2D MVM
tile as a function of time. First computes the atomic time-varying
Thevenin equivalents for each unit cell. These are combined into
synthetic (non-physical) segments, which serve as an approximation to
limit computational expense. Lastly, the wire series resistance between
segments is included when collapsing the time-varying segment Thevenin
equivalents into one time-varying Thevenin equivalent circuit per MVM
tile column (i.e. each element of out_size). This part is represented by
the for loop at the end of the method. Largest indices are closest to
the ADC.

        Args:
            input_: ``[N, in_size]`` MVM tile input activations
            weight: ``[in_size, out_size]`` MVM tile weights
            g_converter: specifies the weight programming scheme, which
                determines the conductances of the Thevenin equivalent circuit.
            time_steps: number of discrete time steps for the time-varying
                Thevenin equivalent (approximation). A higher value is more
                accurate, whereas a lower value computes faster at the cost of
                IR-drop accuracy. SPLIT_MODE and BIT_WISE automatically set
                this parameter to an appropriate value.
            t_max: maximal simulation time, beyond which all activations are
                zero and the computation can stop.
            segments: number of synthetic segments for the IR-drop calculation
                (approximation). Ideally the number of segments matches the
                number of unit cells per column (i.e. ``in_size``). The default
                value is 8. Increasing it beyond 8 for an ``in_size`` of 512
                yields diminishing returns in IR-drop accuracy while incurring
                higher computational overhead.

                Note:
                    When using the C++ extension, ``segments`` is always
                    maximized.

            r_s: wire series resistance in units of Ohm
            phys_input_size: maximal number of hardware MVM tile rows
                (inputs and weights are zero-padded up to this size)
            use_extension: whether to use the C++ extension operator for a
                speedup, if available

        Returns:
            Tuple[Tensor, Tensor] containing the Thevenin voltages ``vth_3d``
            and resistances ``rth_3d``, both of dimension
            ``[batch_size, out_size/2, time_steps]``. ``vth_3d`` is given in
            volts and ``rth_3d`` in units of MOhm.
        """
        seg_rows = phys_input_size // segments
        assert seg_rows * segments == phys_input_size, (
            f"Error: phys_input_size ({phys_input_size}) must be evenly "
            f"divisible by number of segments ({segments})"
        )
input_, weight = cls._pad_symmetric(input_, weight, phys_input_size=phys_input_size)
if g_converter is None:
g_converter = SinglePairConductanceConverter()
[gp_2d, gm_2d], _ = g_converter.convert_to_conductances(weight)
if use_extension and extension_ops is not None:
# use C++ code for speedup if available
output = extension_ops.thevenin_equiv(
input_, gp_2d.T.contiguous(), gm_2d.T.contiguous(), r_s, t_max, time_steps
)
vth_3d = output[0, :, :, :]
rth_3d = output[1, :, :, :]
return vth_3d, rth_3d
gp_4d = gp_2d[None, :, :, None]
gm_4d = gm_2d[None, :, :, None]
t_4d = linspace(0.0, t_max, time_steps)[None, None, None, :].to(input_.device)
x_4d = input_[:, :, None, None]
def sum_segs(g_values: Tensor) -> Tensor:
if seg_rows == 1:
return g_values
shape = g_values.shape
return g_values.view(
(shape[0], shape[1] // seg_rows, seg_rows, shape[2], shape[3])
).sum(dim=2)
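        # The three branches below ("pp", "mm", "zz") accumulate the
        # per-segment Thevenin conductances and conductance-weighted voltage
        # contributions: while an input pulse is still active, a positive input
        # drives the positive conductance at 0.6 V and the negative one at
        # 0.2 V (reversed for negative inputs), while inactive cells sit at
        # 0.4 V.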
# pp
pos_msk = x_4d > t_4d
neg_msk = x_4d < -t_4d
g_4d = gp_4d * pos_msk + gm_4d * neg_msk
g_4d = sum_segs(g_4d)
gth_4d_segs = g_4d
vth_4d_segs = 0.6 * g_4d
# mm
g_4d = gm_4d * pos_msk + gp_4d * neg_msk
g_4d = sum_segs(g_4d)
gth_4d_segs += g_4d
vth_4d_segs += 0.2 * g_4d
# zz
g_4d = (torch_abs(x_4d) <= t_4d) * (gp_4d + gm_4d)
g_4d = sum_segs(g_4d)
        # regularized to avoid divide-by-zero
        gth_4d_segs += g_4d + 1e-12  # atomic Thevenin equivalent conductance [uS]
        vth_4d_segs += 0.4 * g_4d  # conductance-weighted voltage sum [V * uS]
        vth_4d_segs /= gth_4d_segs  # atomic Thevenin equivalent voltage [V]
g_4d = None
# wire resistance depends on segmentation
rw_segs = 1e-6 * r_s * seg_rows
vth_3d = vth_4d_segs[:, 0, :, :]
rth_3d = 1.0 / gth_4d_segs[:, 0, :, :]
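        # Collapse the segments toward the ADC: at each step the running
        # Thevenin source (vth_3d, rth_3d) is pushed through the inter-segment
        # wire resistance rw_segs and merged with the next segment via a
        # parallel-resistance / Millman combination (v = (v1/r1 + v2/r2) * r_p),
        # plus half a segment of wire to lump the distributed resistance.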
for seg in range(1, segments, 1):
r_1 = rth_3d + rw_segs
r_2 = 1.0 / gth_4d_segs[:, seg, :, :]
rth_3d = (r_1 * r_2) / (r_1 + r_2) # parallel R
vth_3d = (vth_3d / r_1 + vth_4d_segs[:, seg, :, :] / r_2) * rth_3d
rth_3d += 0.5 * rw_segs
return vth_3d, rth_3d # rth_3d in MOhm
@classmethod
def _matmul_irdrop(
cls,
weight: Tensor,
input_: Tensor,
trans: bool,
io_pars: IOParametersIRDropT,
ir_drop: float,
t_max: float,
time_steps: int,
phys_input_size: int,
g_converter: SinglePairConductanceConverter,
info: Optional[str] = None, # pylint: disable=unused-argument
) -> Tensor:
"""The inner FP GEMM.

        Args:
            weight: ``[in_size, out_size]`` MVM tile weights
            input_: ``[N, in_size]`` MVM tile input activations
            trans: whether to transpose the weight
            io_pars: parameters defining the mat-mul nonlinearities and the
                time-dependent IR drop
            ir_drop: scale of the IR-drop wire resistance
            t_max: maximal time, beyond which all activations are zero and the
                computation can stop
            time_steps: number of discrete time steps for the time-varying
                Thevenin equivalent
            phys_input_size: physical size of the tile in the input dimension
            g_converter: specifies the weight programming scheme, which
                determines the conductances of the Thevenin equivalent circuit.
            info: info string.

        Returns:
            Tensor with the 2D matmul result
"""
ir_drop_rs = ir_drop * io_pars.ir_drop_rs
if ir_drop == 0.0:
return super(AnalogMVMIRDropT, cls)._matmul(weight, input_, trans)
if weight.dim() != 2:
raise ConfigError("Only 2-d weights are supported for time-sliced IR-drop")
if not trans:
new_weight = weight.T
else:
new_weight = weight
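        # Even columns are read out by the south ADCs with the input applied
        # as-is; odd columns are read out by the north ADCs, which is emulated
        # by flipping both the inputs and the weight rows before computing
        # their Thevenin equivalent.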
vth_3d, rth_3d = cls._thev_equiv(
input_,
new_weight[:, 0::2], # even cols
g_converter,
time_steps=time_steps,
t_max=t_max,
segments=io_pars.ir_drop_segments,
r_s=ir_drop_rs,
phys_input_size=phys_input_size,
)
i_out_3d = (vth_3d - io_pars.ir_drop_v_read) / rth_3d # uA
mvm_even_col_down_adc = torch_sum(i_out_3d, dim=2) # batch_size x n_cols/2
vth_3d, rth_3d = cls._thev_equiv(
flip(input_, (1,)), # flip input
flip(new_weight[:, 1::2], (0,)), # odd cols
g_converter,
time_steps=time_steps,
t_max=t_max,
segments=io_pars.ir_drop_segments,
r_s=ir_drop_rs,
phys_input_size=phys_input_size,
)
i_out_3d = (vth_3d - io_pars.ir_drop_v_read) / rth_3d # uA
mvm_odd_col_up_adc = torch_sum(i_out_3d, dim=2) # batch_size x n_cols/2
mvm = cls._interleave_cols_2d(mvm_even_col_down_adc, mvm_odd_col_up_adc) # symmetric ADCs
mvm /= g_converter.g_max - g_converter.g_min # conductance normalization
mvm /= 0.2 # hardware normalization
return mvm
@classmethod
def _compute_analog_mv( # type: ignore
cls,
weight: Tensor,
input_: Tensor,
trans: bool,
scale: float,
scaling: bool,
is_test: bool,
io_pars: IOParametersIRDropT,
phys_input_size: int,
g_converter: SinglePairConductanceConverter,
**fwd_pars: Any,
    ) -> Tuple[Tensor, Tensor]:
"""
Prepare input, perform noisy MVM and finalize output. Takes care of noise/bound
management and discretization.

        Args:
            weight: weight tensor.
            input_: input tensor in format [N, in_size].
            trans: whether to transpose the weight
            scale: scale for scaling the input.
            scaling: whether to scale.
            is_test: whether the tile is run in evaluation mode; the
                time-sliced IR drop is only applied during evaluation.
            io_pars: forward pass configuration.
            phys_input_size: physical input size
            g_converter: conductance programming scheme for calculating the
                Thevenin equivalent circuit
            fwd_pars: additional forward parameters

        Returns:
            Whether the bound management test passed, and the result.

        Raises:
            NotImplementedError: If choices in the rpu config were made that
                are not supported.
            ConfigError: If an unknown ``AnalogMVType`` is given.
        """
ir_drop = io_pars.ir_drop if is_test else 0.0
prepared_input = cls._prepare_inputs(
input_=input_,
scale=scale,
scaling=scaling,
with_asymmetry=io_pars.inp_asymmetry != 0.0,
io_pars=io_pars,
)
res = cls._get_res(io_pars.inp_res)
bit_res = io_pars.inp_bound / res
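        # Note (derived from the definitions above): with inp_bound = 1 and an
        # n-bit input resolution, bit_res / 2 = 2**(n - 1) is the number of PWM
        # slices per polarity, which sets the default ONE_PASS time_steps and
        # the output normalization used below.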
if io_pars.mv_type == AnalogMVType.ONE_PASS:
# - Perform the noisy MVM
out_values = cls._matmul_irdrop(
weight,
prepared_input[0],
trans,
io_pars,
ir_drop=ir_drop,
t_max=1.0,
time_steps=int((bit_res / 2) * io_pars.ir_drop_time_step_resolution_scale),
phys_input_size=phys_input_size,
g_converter=g_converter,
)
out_values /= bit_res / 2.0
bound_test_passed, finalized_outputs = cls._finalize_output(
out_values=out_values, io_pars=io_pars, **fwd_pars
)
elif io_pars.mv_type == AnalogMVType.SPLIT_MODE:
            prepared_input_lsb, prepared_input_msb = prepared_input
time_steps = int(2**io_pars.split_mode_bit_shift)
t_max = (2 * res) * (2**io_pars.split_mode_bit_shift - 1)
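            # LSB pass: the lower split_mode_bit_shift bits are streamed first;
            # t_max covers (2**shift - 1) steps of width 2 * res each, one time
            # step per LSB level.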
out_values_lsb = cls._matmul_irdrop(
weight,
prepared_input_lsb,
trans,
io_pars,
ir_drop=ir_drop,
t_max=t_max,
time_steps=int(time_steps * io_pars.ir_drop_time_step_resolution_scale),
phys_input_size=phys_input_size,
g_converter=g_converter,
info="LSB",
)
out_values_lsb /= bit_res / 2.0 # normalize
bound_test_passed_lsb, finalized_outputs_lsb = cls._finalize_output(
out_values=out_values_lsb, io_pars=io_pars, **fwd_pars
)
time_steps = 2 ** int(
log2(bit_res + 1) - io_pars.split_mode_bit_shift - 1
) # minus 1 for sign bit
t_max = io_pars.inp_bound / (2**io_pars.split_mode_bit_shift)
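            # MSB pass: the remaining magnitude bits (minus the sign bit) are
            # streamed with the input bound scaled down by 2**shift; the
            # digital recombination below undoes this by re-weighting the MSB
            # output with 2**split_mode_bit_shift.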
out_values_msb = cls._matmul_irdrop(
weight,
prepared_input_msb,
trans,
io_pars,
ir_drop=ir_drop,
t_max=t_max,
time_steps=int(time_steps * io_pars.ir_drop_time_step_resolution_scale),
phys_input_size=phys_input_size,
g_converter=g_converter,
info="MSB",
)
out_values_msb /= bit_res / 2.0 # normalize
bound_test_passed_msb, finalized_outputs_msb = cls._finalize_output(
out_values=out_values_msb, io_pars=io_pars, **fwd_pars
)
finalized_outputs = (
finalized_outputs_lsb + (2**io_pars.split_mode_bit_shift) * finalized_outputs_msb
)
bound_test_passed = bound_test_passed_lsb * bound_test_passed_msb
elif io_pars.mv_type == AnalogMVType.BIT_WISE:
finalized_outputs, bound_test_passed = 0.0, True
for bit_pos, prepared_input_1b in enumerate(prepared_input):
out_values_1b = cls._matmul_irdrop(
weight,
prepared_input_1b,
trans,
io_pars,
ir_drop=ir_drop,
t_max=2 * res,
time_steps=2,
phys_input_size=phys_input_size,
g_converter=g_converter,
info=str(bit_pos),
)
out_values_1b /= bit_res / 2.0 # normalize
bound_test_passed_1b, finalized_outputs_1b = cls._finalize_output(
out_values=out_values_1b, io_pars=io_pars, **fwd_pars
)
finalized_outputs += (2**bit_pos) * finalized_outputs_1b
bound_test_passed *= bound_test_passed_1b
elif io_pars.mv_type in [
AnalogMVType.POS_NEG_SEPARATE,
AnalogMVType.POS_NEG_SEPARATE_DIGITAL_SUM,
]:
raise NotImplementedError
else:
raise ConfigError(f"Unknown AnalogMVType {io_pars.mv_type}")
return bound_test_passed, finalized_outputs
    @classmethod
def check_support(cls, io_pars: IOParameters) -> None:
"""Check whether the IO settings are supported.
Throws an assertion error when there is an incompatibility
Args:
io_pars: the IOParametersDropT to be checked
Raises:
ConfigError: in case a feature is not supported
"""
# pylint: disable=too-many-branches
if io_pars.mv_type not in [
AnalogMVType.ONE_PASS,
AnalogMVType.SPLIT_MODE,
AnalogMVType.BIT_WISE,
]:
raise ConfigError(
"Only AnalogMVType.ONE_PASS, "
"AnalogMVType.SPLIT_MODE, and "
"AnalogMVType.BIT_WISE supported as forward.mv_type"
)
if io_pars.bound_management == BoundManagementType.SHIFT:
raise ConfigError("Shift bound management not supported in torch tile")
if io_pars.noise_management in [
NoiseManagementType.AVERAGE_ABS_MAX,
NoiseManagementType.ABS_MAX_NP_SUM,
]:
raise ConfigError("Special noise mangement types not supported.")
if io_pars.w_noise > 0.0 or io_pars.w_noise_type != WeightNoiseType.NONE:
raise ConfigError("forward.w_noise not supported in torch tile")
if io_pars.out_nonlinearity > 0.0:
raise ConfigError("S-shaped non-linearity not supported in torch tile")
if io_pars.slope_calibration > 0.0:
raise ConfigError("Slope calibration not supported in torch tile")
if io_pars.v_offset_std > 0.0 or io_pars.v_offset_w_min > 0.0 or io_pars.r_series > 0.0:
raise ConfigError("Voltage offset or R-series not supported in torch tile")
if io_pars.w_read_asymmetry_dtod > 0.0:
raise ConfigError("Device polarity read dependence is not supported in torch tile")