aihwkit.simulator.tiles.rpucuda module

Wrapper for the RPUCuda C++ tiles.

class aihwkit.simulator.tiles.rpucuda.RPUCudaSimulatorTileWrapper(out_size, in_size, rpu_config, bias=True, in_trans=False, out_trans=False, shared_weights=False)

Bases: SimulatorTileWrapper

Wraps the RPUCuda simulator tile.

This class adds functionality to the minimalistic SimulatorTileWrapper that is specific to the RPUCuda tiles, which are implemented in C++ and accessed through Python bindings.

Parameters:
  • out_size (int) – output size

  • in_size (int) – input size

  • rpu_config (InferenceRPUConfig | SingleRPUConfig | UnitCellRPUConfig | TorchInferenceRPUConfig | DigitalRankUpdateRPUConfig) – resistive processing unit configuration.

  • bias (bool) – whether to add a bias column to the tile.

  • in_trans (bool) – Whether to assume a transposed input (batch first)

  • out_trans (bool) – Whether to assume a transposed output (batch first)

  • shared_weights (bool) – whether to use shared weight tensor memory.

cpu()

Return a copy of this tile in CPU memory.

Returns:

self, if the tile is already in CPU memory

Return type:

SimulatorTileWrapper

cuda(device=None)

Return a copy of the tile in CUDA memory.

Parameters:

device (str | device | int | None) – CUDA device

Returns:

Self with the underlying C++ tile moved to CUDA memory.

Raises:

CudaError – if the library has not been compiled with CUDA.

Return type:

SimulatorTileWrapper

decay_weights(alpha=1.0)

Decays the weights once according to the decay parameters of the tile.

Parameters:

alpha (float) – additional decay scale (such as the learning rate). The base decay rate is set during tile init.

Returns:

None.

Return type:

None
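A minimal sketch of a single decay step, assuming a simple multiplicative rule w <- w * (1 - alpha * r) with a hypothetical base rate r. The helper `decay_once` is illustrative only and not part of aihwkit; the actual decay rule is defined by the device parameters of the RPU config.

```python
def decay_once(weights, base_rate, alpha=1.0):
    """Hypothetical sketch: shrink every weight toward zero by alpha * base_rate.

    Assumes the multiplicative rule w <- w * (1 - alpha * base_rate);
    this is an illustration, not aihwkit's device-specific decay.
    """
    factor = 1.0 - alpha * base_rate
    return [[w * factor for w in row] for row in weights]


weights = [[0.5, -0.25], [1.0, 0.0]]
# alpha plays the role of the additional scale (e.g. a learning rate).
decayed = decay_once(weights, base_rate=0.1, alpha=0.5)
print(decayed)
```

Setting alpha to 0 leaves the weights unchanged under this assumed rule, which matches alpha being a scale on top of the base rate.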

diffuse_weights()

Diffuses the weights once according to the diffusion parameters of the tile.

The base diffusion rate is set during tile init.

Returns:

None

Return type:

None
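A minimal sketch of one diffusion step, assuming diffusion adds zero-mean Gaussian noise w <- w + s * xi with xi ~ N(0, 1) and a hypothetical diffusion standard deviation s. `diffuse_once` is not an aihwkit function; the real diffusion model is set by the device configuration.

```python
import random


def diffuse_once(weights, diffusion_std, rng=None):
    """Hypothetical sketch: add independent Gaussian noise to every weight.

    Assumes w <- w + diffusion_std * xi with xi ~ N(0, 1); the actual
    diffusion model depends on the device parameters of the tile.
    """
    if rng is None:
        rng = random.Random()
    return [[w + diffusion_std * rng.gauss(0.0, 1.0) for w in row] for row in weights]


weights = [[0.5, -0.25], [1.0, 0.0]]
# A seeded generator makes the random perturbation reproducible.
diffused = diffuse_once(weights, diffusion_std=0.01, rng=random.Random(0))
```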

drift_weights(delta_t=1.0)

Drifts the weights once according to the drift parameters of the tile.

See also DriftParameter.

Parameters:

delta_t (float) – Time since the last drift call.

Returns:

None.

Return type:

None
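A minimal sketch of repeated drift calls, assuming the power-law drift model w(t) = w(t0) * (t / t0)^(-nu) that is commonly used for phase-change memory. Whether this matches the exact DriftParameter semantics is an assumption; the function `drift` and the values of t0 and nu are hypothetical.

```python
def drift(weights_at_t0, t, t0=1.0, nu=0.05):
    """Hypothetical sketch of power-law conductance drift.

    Assumes w(t) = w(t0) * (t / t0) ** (-nu) for t >= t0; the parameters of a
    real tile come from its DriftParameter settings.
    """
    if t < t0:
        raise ValueError("t must be >= t0")
    scale = (t / t0) ** (-nu)
    return [[w * scale for w in row] for row in weights_at_t0]


w0 = [[1.0, -0.5]]
elapsed = 0.0
for delta_t in (1.0, 9.0, 90.0):  # time since the previous drift call
    elapsed += delta_t
    w = drift(w0, t=elapsed)
```

The loop accumulates each delta_t into a total elapsed time because, in this power-law form, the drift factor depends on the time since programming rather than on the last interval alone.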

dump_extra()

Dumps any extra states / attributes necessary for checkpointing.

For tiles based on torch Modules, this should normally be handled automatically by torch.

Return type:

Dict[str, Any] | None

ensure_shared_weights(shared_weights=None)

Ensure that the shared weights are set properly.

Caution

This is only called from the analog function (AnalogFunction).

No-op if shared weights is not used.

Parameters:

shared_weights (Tensor | None) – shared weight tensor to use, if any.

Return type:

None

get_backward_parameters()

Get the additional parameters generated for the backward pass.

Returns:

Dictionary of the backward parameters set.

Return type:

Dict[str, Tensor]

get_forward_out_bound()

Helper for getting the output bound that the AnalogFunction uses to correct the gradients.

Return type:

float | None

get_forward_parameters()

Get the additional parameters generated for the forward pass.

Returns:

Dictionary of the forward parameters set.

Return type:

Dict[str, Tensor]

get_hidden_update_index()

Get the index of the hidden device that is currently being updated.

Usually this is 0, since many tile RPU configs have only one device per cross-point. However, some RPU configs internally maintain multiple devices per cross-point (e.g. VectorUnitCell).

Returns:

The index of the device to be updated in the next mini-batch.

Return type:

int

Note

Depending on the update and learning policy implemented in the tile, updated devices might switch internally as well.

load_extra(extra, strict=False)

Load any extra states / attributes necessary for loading from a checkpoint.

For tiles based on torch Modules, this should normally be handled automatically by torch.

Note

Expects exactly the same RPUConfig, device, etc. when applying the states. Cross-loading of state dicts is not supported for extra states; they are simply ignored.

Parameters:
  • extra (Dict[str, Any]) – dictionary of states from dump_extra.

  • strict (bool) – Whether to throw an error if keys are not found.

Return type:

None
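A minimal sketch of the dump_extra / load_extra contract described above, using a hypothetical `ToyTile` with one hypothetical extra state. It assumes that missing keys raise only when strict is True and that unknown keys are ignored; it is not the aihwkit implementation.

```python
from typing import Any, Dict, Optional


class ToyTile:
    """Hypothetical stand-in for a tile, illustrating the dump/load contract."""

    EXTRA_KEYS = ("hidden_update_index",)

    def __init__(self) -> None:
        self.hidden_update_index = 0  # example of an extra, non-tensor state

    def dump_extra(self) -> Optional[Dict[str, Any]]:
        # Collect states that a torch state_dict would not capture.
        return {key: getattr(self, key) for key in self.EXTRA_KEYS}

    def load_extra(self, extra: Dict[str, Any], strict: bool = False) -> None:
        # Only known keys are read, so unknown keys in `extra` are ignored.
        for key in self.EXTRA_KEYS:
            if key not in extra:
                if strict:
                    raise KeyError(f"missing extra state: {key}")
                continue  # non-strict: silently skip missing keys
            setattr(self, key, extra[key])


source = ToyTile()
source.hidden_update_index = 2
target = ToyTile()
target.load_extra(source.dump_extra())
```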

post_update_step()

Operations that need to be called once per mini-batch.

Note

This function is called by the analog optimizer.

Caution

If no analog optimizer is used, the post update steps will not be performed.

Return type:

None
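A minimal sketch of why the caution matters: post_update_step only runs if something calls it once per mini-batch. Both `CountingTile` and `ToyAnalogOptimizer` are hypothetical stand-ins for a tile and an analog optimizer; they only model the calling pattern.

```python
class CountingTile:
    """Hypothetical tile that records how often post_update_step runs."""

    def __init__(self) -> None:
        self.post_update_calls = 0

    def post_update_step(self) -> None:
        # A real tile would apply its per-mini-batch operations here;
        # this stand-in only counts the calls.
        self.post_update_calls += 1


class ToyAnalogOptimizer:
    """Hypothetical optimizer that calls post_update_step after each step."""

    def __init__(self, tiles):
        self.tiles = list(tiles)

    def step(self) -> None:
        # ... the weight update itself would happen here ...
        for tile in self.tiles:
            tile.post_update_step()


tile = CountingTile()
optimizer = ToyAnalogOptimizer([tile])
for _ in range(3):  # three mini-batches
    optimizer.step()
```

A plain optimizer without the loop over tiles would leave the counter at 0, mirroring the case where no analog optimizer is used.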

reset(reset_prob=1.0)

Reset the updated device tile according to the reset parameters of the tile.

Resets the weights with device-to-device and cycle-to-cycle variability (depending on device type), typically:

\[W_{ij} = \xi\,\sigma_\text{reset} + b^\text{reset}_{ij}\]

The reset parameters are set during tile init.

Parameters:

reset_prob (float) – probability with which each weight is individually reset.

Returns:

None

Return type:

None
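A minimal sketch of the reset formula, assuming xi is drawn from a standard normal distribution independently for each weight, sigma_reset is a scalar reset standard deviation, b^reset_ij is a per-weight reset bias, and each weight is reset independently with probability reset_prob. `reset_weights` is a hypothetical helper, not aihwkit code.

```python
import random


def reset_weights(weights, reset_bias, reset_std, reset_prob=1.0, rng=None):
    """Hypothetical sketch of W_ij = xi * sigma_reset + b_ij applied per weight.

    Assumes xi ~ N(0, 1) per weight and an independent reset decision with
    probability reset_prob; weights that are not reset keep their value.
    """
    if rng is None:
        rng = random.Random()
    out = []
    for w_row, b_row in zip(weights, reset_bias):
        row = []
        for w, b in zip(w_row, b_row):
            if rng.random() < reset_prob:
                row.append(rng.gauss(0.0, 1.0) * reset_std + b)
            else:
                row.append(w)
        out.append(row)
    return out


weights = [[0.8, -0.6], [0.3, 0.9]]
bias = [[0.0, 0.1], [-0.1, 0.0]]
# With zero variability every reset weight lands exactly on its bias.
reset = reset_weights(weights, bias, reset_std=0.0, rng=random.Random(1))
```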

reset_columns(start_column_idx=0, num_columns=1, reset_prob=1.0)

Reset (a number of) columns according to the reset parameters of the tile.

Resets the weights with device-to-device and cycle-to-cycle variability (depending on device type), typically:

\[W_{ij} = \xi\,\sigma_\text{reset} + b^\text{reset}_{ij}\]

The reset parameters are set during tile init.

Parameters:
  • start_column_idx (int) – index of the first column to reset (0..x_size-1)

  • num_columns (int) – how many consecutive columns to reset (with circular wrapping)

  • reset_prob (float) – probability with which each weight is individually reset.

Returns:

None

Return type:

None
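A minimal sketch of which column indices a call covers under circular wrapping, assuming x_size is the number of columns. `columns_to_reset` is a hypothetical helper for illustration only.

```python
def columns_to_reset(start_column_idx, num_columns, x_size):
    """Hypothetical helper: consecutive column indices with circular wrapping."""
    if not 0 <= start_column_idx < x_size:
        raise ValueError("start_column_idx must be in 0..x_size-1")
    # Indices past the last column wrap around to column 0.
    return [(start_column_idx + k) % x_size for k in range(num_columns)]


print(columns_to_reset(6, 4, x_size=8))  # -> [6, 7, 0, 1]
```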

reset_delta_weights()

Reset the weight gradient tensor to the default update behavior (i.e. adding the update directly to the weight).

No-op if shared weights is not used.

Return type:

None

set_backward_parameters(dic, **kwargs)

Set the additional parameters generated for the backward pass.

Parameters:
  • dic (Dict[str, Tensor] | None) – dictionary of parameters to set (from get_backward_parameters())

  • kwargs (Dict[str, Tensor]) – parameter names can alternatively be given directly as keywords

Return type:

None

set_delta_weights(delta_weights=None)

Set the weight gradient tensor and direct the update to it (instead of adding the update directly to the weight).

No-op if shared weights is not used.

Parameters:

delta_weights (Tensor | None) – tensor that receives the weight update.

Return type:

None

set_forward_parameters(dic=None, **kwargs)

Set the additional parameters generated for the forward pass.

Parameters:
  • dic (Dict[str, Tensor] | None) – dictionary of parameters to set (from get_forward_parameters())

  • kwargs (Dict[str, Tensor]) – parameter names can alternatively be given directly as keywords

Return type:

None
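A minimal sketch of the get/set round trip and the dic-or-keywords calling convention, using a hypothetical `ParameterStore` with hypothetical parameter names and plain floats in place of tensors. Letting keywords override dictionary entries is a choice made in this sketch, not documented aihwkit behavior.

```python
from typing import Dict, Optional


class ParameterStore:
    """Hypothetical sketch of the dic/kwargs convention for forward parameters."""

    def __init__(self) -> None:
        self._forward: Dict[str, object] = {}

    def get_forward_parameters(self) -> Dict[str, object]:
        return dict(self._forward)  # a copy, so callers cannot mutate internal state

    def set_forward_parameters(self, dic: Optional[Dict[str, object]] = None, **kwargs) -> None:
        # Dictionary entries are applied first; keywords override them.
        params = dict(dic or {})
        params.update(kwargs)
        self._forward.update(params)


store = ParameterStore()
store.set_forward_parameters({"param_a": 1.0, "param_b": 2.0})
store.set_forward_parameters(param_b=3.0)  # keyword form
saved = store.get_forward_parameters()

# Round trip: restore the saved parameters on another instance.
other = ParameterStore()
other.set_forward_parameters(saved)
```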

set_hidden_update_index(index)

Set the current updated hidden device index.

Usually this is ignored and fixed to 0, since only one device is present per cross-point. Other devices might not allow explicit setting, as it would interfere with the implemented learning rule. However, some tiles internally have multiple devices per cross-point (e.g. unit cell) that can be chosen depending on the update policy.

Parameters:

index (int) – device index to be updated in the next mini-batch

Return type:

None

Note

Depending on the update and learning policy implemented in the tile, updated devices might switch internally as well.
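A minimal sketch of a single cross-point with several hidden devices, where the hidden update index selects which device the next update modifies. It assumes the effective weight is the plain sum of the device weights; `ToyVectorCell` is hypothetical and not the VectorUnitCell implementation.

```python
class ToyVectorCell:
    """Hypothetical cross-point holding several devices; the weight is their sum."""

    def __init__(self, n_devices: int) -> None:
        self.devices = [0.0] * n_devices
        self.update_index = 0  # which device receives the next update

    def set_hidden_update_index(self, index: int) -> None:
        if not 0 <= index < len(self.devices):
            raise IndexError("device index out of range")
        self.update_index = index

    def get_hidden_update_index(self) -> int:
        return self.update_index

    def update(self, delta: float) -> None:
        # Only the currently selected device is modified.
        self.devices[self.update_index] += delta

    @property
    def weight(self) -> float:
        return sum(self.devices)


cell = ToyVectorCell(n_devices=2)
cell.update(0.5)               # goes to device 0
cell.set_hidden_update_index(1)
cell.update(-0.25)             # goes to device 1
```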

set_verbosity_level(verbose)

Set the verbosity level of the tile.

Parameters:

verbose (int) – level of verbosity

Return type:

None