aihwkit.nn.modules.base module

Base class for analog Modules.

class aihwkit.nn.modules.base.AnalogModuleBase(in_features, out_features, bias, realistic_read_write=False, mapping=None)[source]

Bases: torch.nn.modules.module.Module

Base class for analog Modules.

Base Module for analog layers that use analog tiles. When subclassing, please note:

  • the _setup_tile() method is expected to be called by the subclass constructor; note that it does more than merely create a tile.

  • register_analog_tile() needs to be called for each created analog tile.

  • this module does not itself call torch’s Module init, as the subclass is typically derived from Module as well.

  • the weight and bias Parameters are not guaranteed to be in sync with the tile weights and biases during the lifetime of the instance, for performance reasons. The canonical way of reading and writing weights is via set_weights() and get_weights(), as opposed to accessing the attributes directly.

  • the BaseTile subclass that is created is retrieved from the rpu_config.tile_class attribute.

Parameters
  • in_features (int) – input vector size (number of columns).

  • out_features (int) – output vector size (number of rows).

  • bias (bool) – whether to use a bias row on the analog tile or not.

  • realistic_read_write (bool) – whether to enable realistic read/write for setting initial weights and during reading of the weights.

  • mapping (Optional[aihwkit.simulator.configs.utils.MappingParameter]) – Configuration of the hardware architecture (e.g. tile size).

Return type

None
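
As AnalogModuleBase is an abstract base class, the sketch below uses the concrete AnalogLinear subclass; the layer sizes and RPU configuration are arbitrary illustrative choices:

    from torch import randn
    from aihwkit.nn import AnalogLinear
    from aihwkit.simulator.configs import SingleRPUConfig

    # 4 input columns, 2 output rows, with a bias row on the analog tile.
    model = AnalogLinear(4, 2, bias=True, rpu_config=SingleRPUConfig())

    # Forward pass through the analog tile.
    out = model(randn(3, 4))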

ANALOG_CTX_PREFIX: str = 'analog_ctx_'
ANALOG_INPUT_RANGE_PREFIX: str = 'analog_input_range_'
ANALOG_OUT_SCALING_ALPHA_PREFIX: str = 'analog_out_scaling_alpha_'
ANALOG_SHARED_WEIGHT_PREFIX: str = 'analog_shared_weights_'
ANALOG_STATE_PREFIX: str = 'analog_tile_state_'
analog_tile_count()[source]

Return the number of registered tiles.

Returns

Number of registered tiles.

Return type

int

analog_tiles()[source]

Generator to loop over all registered analog tiles of the module.

Return type

Generator[BaseTile, None, None]
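
A short usage sketch, continuing the AnalogLinear example above (which holds a single tile):

    print(model.analog_tile_count())  # -> 1 for a small AnalogLinear

    # Loop over all registered analog tiles of the module.
    for tile in model.analog_tiles():
        print(type(tile).__name__, tile.out_size, tile.in_size)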

drift_analog_weights(t_inference=0.0)[source]

(Program) and drift the analog weights.

Parameters

t_inference (float) – assumed time of inference (in sec)

Raises

ModuleError – if the layer is not in evaluation mode.

Return type

None
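
A minimal sketch, assuming an inference-oriented RPU configuration (e.g. InferenceRPUConfig, whose noise model includes programming noise and drift):

    from aihwkit.nn import AnalogLinear
    from aihwkit.simulator.configs import InferenceRPUConfig

    model = AnalogLinear(4, 2, rpu_config=InferenceRPUConfig())
    model.eval()  # must be in evaluation mode, else ModuleError is raised

    # Program the weights and drift them to one hour after programming.
    model.drift_analog_weights(t_inference=3600.0)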

extra_repr()[source]

Set the extra representation of the module.

Returns

A string with the extra representation.

Return type

str

get_analog_tile_devices()[source]

Return a list of the devices used by the analog tiles.

Returns

List of torch devices

Return type

List[Optional[Union[torch.device, str, int]]]

get_weights(force_exact=False, apply_weight_scaling=True)[source]

Get the weight (and bias) tensors.

This uses a realistic read if the realistic_read_write property of the layer is set, unless it is overridden by force_exact. It scales the analog weights by the digital output scales by default.

Note

This is the recommended way of reading the weight/bias matrix from the analog tile, as it will correctly fetch the weights from the internal memory. Accessing self.weight and self.bias directly might yield wrong results, as they are not always in sync with the analog tiles, for performance reasons.

Parameters
  • force_exact (bool) – Forces an exact read of the analog tiles.

  • apply_weight_scaling (bool) – Whether to return the weights with the (digital) output scaling factors applied. Note that the “logical” weights of the layer, which the DNN effectively uses, are those with the output scales applied. If apply_weight_scaling is set to False, only the weight values that are programmed onto the crossbar array are returned, without the digital scales applied. Default is True.

Returns

weight matrix, bias vector

Return type

tuple

Raises

ModuleError – in case of multiple defined analog tiles in the module
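
A usage sketch, continuing the example above:

    # Read back the layer weights; force an exact (noise-free) read even
    # if realistic_read_write is enabled on the layer.
    weight, bias = model.get_weights(force_exact=True)

    # The raw values as programmed onto the crossbar array, without the
    # digital output scales applied.
    raw_weight, raw_bias = model.get_weights(apply_weight_scaling=False)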

load_state_dict(state_dict, strict=True, load_rpu_config=True, strict_rpu_config_check=True)[source]

Specializes torch’s load_state_dict by adding a flag that controls whether to load the RPU config from the saved state.

Parameters
  • state_dict (OrderedDict[str, Tensor]) – see torch’s load_state_dict

  • strict (bool) – see torch’s load_state_dict

  • load_rpu_config (bool) –

    Whether to load the saved RPU config or use the current RPU config of the model.

    Caution

    If load_rpu_config=False, the RPU config can differ from the one in the stored model. However, the user has to make sure that the changed RPU config makes sense.

    For instance, changing the device type might change the expected fields in the hidden parameters and result in an error.

  • strict_rpu_config_check (bool) – Whether to check, and raise an error, if the current rpu_config is not of the same class type when load_rpu_config is set to False. If False, the user has to make sure that the RPU configs are compatible.

Returns

see torch’s load_state_dict

Raises
ModuleError – in case of an rpu_config class mismatch, or a mapping parameter mismatch, when load_rpu_config=False.

Return type

NamedTuple
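
A sketch of saving a checkpoint and restoring it while keeping the model’s current RPU config (the file name is illustrative):

    import torch

    # The state dict includes the analog tile states.
    torch.save(model.state_dict(), 'analog_model.pth')

    # Keep the current RPU config instead of the stored one; the user is
    # responsible for the two configs being compatible.
    model.load_state_dict(torch.load('analog_model.pth'),
                          load_rpu_config=False)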

named_analog_tiles()[source]

Generator to loop over all registered analog tiles of the module, together with their names.

Return type

Generator[Tuple[str, BaseTile], None, None]
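
Analogous to analog_tiles(), but yielding the registered name along with each tile:

    for name, tile in model.named_analog_tiles():
        print(name, '->', type(tile).__name__)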

program_analog_weights()[source]

Program the analog weights.

Raises

ModuleError – if the layer is not in evaluation mode.

Return type

None
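
A sketch, again assuming an inference RPU configuration so that programming noise is modeled:

    model.eval()  # must be in evaluation mode, else ModuleError is raised
    model.program_analog_weights()  # re-program tiles with programming noise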

register_analog_tile(tile, name=None, update_only=False)[source]

Register the analog context of the tile.

Note

Needs to be called at the end of the constructor to register the tile for the analog optimizers.

Parameters
  • tile (BaseTile) – tile to register

  • name (Optional[str]) – Optional tile name used as the parameter name.

  • update_only (bool) – Whether to re-register the tile (does not advance the tile counter).

Return type

None

register_helper(name)[source]

Register a helper name that is not saved to the state dict.

Parameters

name (str) –

Return type

None

remap_weights(weight_scaling_omega=1.0)[source]

Gets and re-sets the weights when weight scaling is used.

This re-sets the weights with the mapping scales applied, so that the weight mapping scales are updated.

In case of hardware-aware training, this would update the weight mapping scales so that the absolute max analog weights are set to 1 (as specified in the weight_scaling configuration of MappingParameter).

Note

By default the weight scaling omega factor is set to 1 here (overriding any setting in the rpu_config). This means that the max weight value is set to 1 internally for the analog weights.

Caution

This should typically not be called for analog training unless realistic_read_write is set. In this case, it would perform a full re-write of the weights.

Parameters

weight_scaling_omega (Optional[float]) – The weight scaling omega factor (see MappingParameter). If set to None, the value from the mapping parameters is used. The default here, however, is 1.0.

Return type

None
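
A sketch of re-mapping during hardware-aware training, assuming weight scaling is enabled in the MappingParameter of the rpu_config:

    # Re-map so that the absolute max analog weight becomes 1; the
    # remainder is carried by the digital output scales.
    model.remap_weights()

    # Use the weight_scaling_omega from the mapping parameters instead.
    model.remap_weights(weight_scaling_omega=None)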

set_weights(weight, bias=None, force_exact=False, apply_weight_scaling=True, weight_scaling_omega=None)[source]

Set the weight (and bias) values with given tensors.

This uses a realistic write if the realistic_read_write property of the layer is set, unless it is overridden by force_exact.

If weight_scaling_omega is larger than 0, the weights are set in a scaled manner (assuming a digital output scale). See apply_weight_scaling() for details.

Note

This is the recommended way of setting the weight/bias matrix of the analog tile, as it will correctly store the weights into the internal memory. Directly writing to self.weight and self.bias might yield wrong results, as they are not always in sync with the analog tile Parameters, for performance reasons.

Parameters
  • weight (torch.Tensor) – weight matrix

  • bias (Optional[torch.Tensor]) – bias vector

  • force_exact (bool) – forces an exact write to the analog tiles

  • apply_weight_scaling (bool) – Whether to rescale the given weight matrix and populate the digital output scaling factors, as specified in the configuration MappingParameter. A new weight_scaling_omega can be given. Note that this will overwrite the existing digital output scaling factors.

  • weight_scaling_omega (Optional[float]) – The weight scaling omega factor (see MappingParameter). If given explicitly here, it will overwrite the value in the mapping field.

Raises

ModuleError – in case of multiple defined analog tiles in the module

Return type

None
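
A usage sketch, continuing the example above:

    import torch

    # Logical weights of shape (out_features, in_features), plus bias.
    weight = torch.randn(2, 4)
    bias = torch.zeros(2)

    # Rescale the weights and populate the digital output scaling
    # factors according to the MappingParameter settings.
    model.set_weights(weight, bias, apply_weight_scaling=True)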

state_dict(destination=None, prefix='', keep_vars=False)[source]

Return a dictionary containing the whole state of the module.

Parameters
  • destination (Optional[Any]) –

  • prefix (str) –

  • keep_vars (bool) –

Return type

Dict
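
The returned dictionary also contains the serialized analog tile states, stored under keys that use the prefixes listed above (e.g. ANALOG_STATE_PREFIX); a quick way to inspect them:

    state = model.state_dict()
    # Analog tile states appear under 'analog_tile_state_...' keys.
    print([key for key in state if 'analog_tile_state_' in key])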

unregister_parameter(param_name)[source]

Unregister a module parameter from the parameters.

Raises

ModuleError – in case the parameter is not found.

Parameters

param_name (str) –

Return type

None