aihwkit.simulator.tiles.rpucuda module
Wrapper for the RPUCuda C++ tiles.
- class aihwkit.simulator.tiles.rpucuda.RPUCudaSimulatorTileWrapper(out_size, in_size, rpu_config, bias=True, in_trans=False, out_trans=False, shared_weights=False)[source]
Bases:
SimulatorTileWrapper
Wraps the RPUCuda simulator tile.
This class adds some functionality to the minimalistic
SimulatorTileWrapper
specific to the RPUCuda tiles that are handled in C++ through Python bindings.
- Parameters:
out_size (int) – output size
in_size (int) – input size
rpu_config (InferenceRPUConfig | SingleRPUConfig | UnitCellRPUConfig | TorchInferenceRPUConfig | DigitalRankUpdateRPUConfig) – resistive processing unit configuration.
bias (bool) – whether to add a bias column to the tile.
in_trans (bool) – whether to assume a transposed input (batch first)
out_trans (bool) – whether to assume a transposed output (batch first)
shared_weights (bool) – whether an external shared weights tensor memory should be used.
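A minimal construction sketch, assuming the default SingleRPUConfig. In practice this wrapper is usually created indirectly by the higher-level analog tile classes; direct construction is shown here only to illustrate the signature:

    from aihwkit.simulator.configs import SingleRPUConfig
    from aihwkit.simulator.tiles.rpucuda import RPUCudaSimulatorTileWrapper

    # Default single-device configuration (assumed here for illustration).
    rpu_config = SingleRPUConfig()

    # Wrap a C++ simulator tile with 8 outputs, 16 inputs and a bias column.
    tile = RPUCudaSimulatorTileWrapper(
        out_size=8,
        in_size=16,
        rpu_config=rpu_config,
        bias=True,
    )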
- cuda(device=None)[source]
Return a copy of the tile in CUDA memory.
- Parameters:
device (str | device | int | None) – CUDA device
- Returns:
Self with the underlying C++ tile moved to CUDA memory.
- Raises:
CudaError – if the library has not been compiled with CUDA.
- Return type:
RPUCudaSimulatorTileWrapper
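A hedged usage sketch, assuming the CudaError exception exported by aihwkit.exceptions:

    from aihwkit.exceptions import CudaError

    try:
        tile = tile.cuda("cuda:0")  # move the underlying C++ tile to GPU 0
    except CudaError:
        # The library was compiled without CUDA support; keep the CPU tile.
        print("CUDA not available in this aihwkit build.")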
- decay_weights(alpha=1.0)[source]
Decays the weights once according to the decay parameters of the tile.
- Parameters:
alpha (float) – additional decay scale (e.g. the learning rate). The base decay rate is set during tile init.
- Returns:
None.
- Return type:
None
- diffuse_weights()[source]
Diffuses the weights once according to the diffusion parameters of the tile.
The base diffusion rate is set during tile init.
- Returns:
None
- Return type:
None
- drift_weights(delta_t=1.0)[source]
Drifts the weights once according to the drift parameters of the tile.
See also
DriftParameter.
- Parameters:
delta_t (float) – Time since last drift call.
- Returns:
None.
- Return type:
None
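The three weight-dynamics calls above are typically applied once per step; a short sketch (the base rates themselves come from the rpu_config passed at construction):

    # Scale the configured decay rate by an additional factor of 0.5.
    tile.decay_weights(alpha=0.5)

    # One diffusion step at the rate configured during tile init.
    tile.diffuse_weights()

    # Drift the weights for 1.0 time units since the last drift call.
    tile.drift_weights(delta_t=1.0)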
- dump_extra()[source]
Dumps any extra states / attributes necessary for checkpointing.
For tiles based on Modules, this is normally handled by torch automatically.
- Return type:
Dict[str, Any] | None
- ensure_shared_weights(shared_weights=None)[source]
Ensure that the shared weights are set properly.
Caution
This is only called from the analog function.
No-op if shared weights is not used.
- Parameters:
shared_weights (Tensor | None) – the shared weights tensor memory to use, if any.
- Return type:
None
- get_backward_parameters()[source]
Get the additional parameters generated for the backward pass.
- Returns:
Dictionary of the backward parameters set.
- Return type:
Dict[str, Tensor]
- get_forward_out_bound()[source]
Helper to get the output bound, used for correcting the gradients in the AnalogFunction.
- Return type:
float | None
- get_forward_parameters()[source]
Get the additional parameters generated for the forward pass.
- Returns:
Dictionary of the forward parameters set.
- Return type:
Dict[str, Tensor]
- get_hidden_update_index()[source]
Get the current updated device index of the hidden devices.
Usually this is 0 as only one device is present per cross-point for many tile RPU configs. However, some RPU configs maintain internally multiple devices per cross-point (e.g.
VectorUnitCell
).
- Returns:
The next mini-batch updated device index.
- Return type:
int
Note
Depending on the update and learning policy implemented in the tile, updated devices might switch internally as well.
- load_extra(extra, strict=False)[source]
Load any extra states / attributes necessary for loading from a checkpoint.
For tiles based on Modules, this is normally handled by torch automatically.
Note
Expects the exact same RPUConfig / device etc. when applying the states. Cross-loading of state dicts is not supported for extra states; they will simply be ignored.
- Parameters:
extra (Dict[str, Any]) – dictionary of states from dump_extra.
strict (bool) – Whether to throw an error if keys are not found.
- Return type:
None
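A checkpointing sketch combining dump_extra() and load_extra(); the file name and dictionary layout are illustrative only:

    import torch

    # Save the extra C++ tile state next to the usual torch state.
    extra = tile.dump_extra()  # may be None if there is nothing to dump
    torch.save({"tile_extra": extra}, "checkpoint_extra.pt")

    # Later: restore into a tile built with the exact same RPUConfig/device.
    loaded = torch.load("checkpoint_extra.pt")
    if loaded["tile_extra"] is not None:
        tile.load_extra(loaded["tile_extra"], strict=True)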
- post_update_step()[source]
Operators that need to be called once per mini-batch.
Note
This function is called by the analog optimizer.
Caution
If no analog optimizer is used, the post update steps will not be performed.
- Return type:
None
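A sketch for the case without an analog optimizer:

    # The analog optimizer calls this automatically; without one, invoke
    # it yourself once per mini-batch after the weight update.
    tile.post_update_step()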
- reset(reset_prob=1.0)[source]
Reset the updated device tile according to the reset parameters of the tile.
Resets the weights with device-to-device and cycle-to-cycle variability (depending on device type), typically:
\[W_{ij} = \xi \cdot \sigma_\text{reset} + b^\text{reset}_{ij}\]
The reset parameters are set during tile init.
- Parameters:
reset_prob (float) – individual probability of reset.
- Returns:
None
- Return type:
None
- reset_columns(start_column_idx=0, num_columns=1, reset_prob=1.0)[source]
Reset (a number of) columns according to the reset parameters of the tile.
Resets the weights with device-to-device and cycle-to-cycle variability (depending on device type), typically:
\[W_{ij} = \xi \cdot \sigma_\text{reset} + b^\text{reset}_{ij}\]
The reset parameters are set during tile init.
- Parameters:
start_column_idx (int) – a start index of columns (0..x_size-1)
num_columns (int) – how many consecutive columns to reset (with circular wrapping)
reset_prob (float) – individual probability of reset.
- Returns:
None
- Return type:
None
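A usage sketch for both reset variants (the column indices and probabilities are illustrative):

    # Reset every device of the tile with probability 1.
    tile.reset(reset_prob=1.0)

    # Reset three consecutive columns starting at index 2, where each
    # device is reset with 50% probability (indices wrap circularly).
    tile.reset_columns(start_column_idx=2, num_columns=3, reset_prob=0.5)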
- reset_delta_weights()[source]
Reset the weight grad tensor to default update behavior (i.e. adding the update directly to the weight).
No-op if shared weights is not used.
- Return type:
None
- set_backward_parameters(dic, **kwargs)[source]
Set the additional parameters generated for the backward pass.
- Parameters:
dic (Dict[str, Tensor] | None) – dictionary of parameters to set (from get_backward_parameters())
kwargs (Dict[str, Tensor]) – parameter names can alternatively be given directly as keywords
- Return type:
None
- set_delta_weights(delta_weights=None)[source]
Set the weight grad tensor, so that subsequent updates are applied to it instead of directly to the weights.
No-op if shared weights is not used.
- Parameters:
delta_weights (Tensor | None) – the tensor in which updates will be accumulated.
- Return type:
None
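A sketch of the delta-weights round trip together with reset_delta_weights(); the buffer shape assumes the 8 x 16 tile with a bias column from the construction example:

    import torch

    # Hypothetical buffer: out_size x (in_size + 1 bias column).
    delta = torch.zeros(8, 17)

    tile.set_delta_weights(delta)   # updates now accumulate in `delta`
    # ... perform update steps here ...
    tile.reset_delta_weights()      # updates are applied to the weights again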
- set_forward_parameters(dic=None, **kwargs)[source]
Set the additional parameters generated for the forward pass.
- Parameters:
dic (Dict[str, Tensor] | None) – dictionary of parameters to set (from get_forward_parameters())
kwargs (Dict[str, Tensor]) – parameter names can alternatively be given directly as keywords
- Return type:
None
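A get/set round-trip sketch; the available parameter names depend on the rpu_config, so none are assumed here:

    # Read the generated forward parameters, then write them back.
    params = tile.get_forward_parameters()
    for name, tensor in params.items():
        print(name, tuple(tensor.shape))

    tile.set_forward_parameters(params)     # dictionary form
    # tile.set_forward_parameters(**params) # equivalent keyword form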
- set_hidden_update_index(index)[source]
Set the current updated hidden device index.
Usually this is ignored and fixed to 0, as only one device is present per cross-point. Other devices might not allow explicit setting, as it would interfere with the implemented learning rule. However, some tiles maintain internally multiple devices per cross-point (e.g. unit cell) that can be chosen depending on the update policy.
- Parameters:
index (int) – device index to be updated in the next mini-batch
- Return type:
None
Note
Depending on the update and learning policy implemented in the tile, updated devices might switch internally as well.
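A sketch for a multi-device tile; the device count of 2 is an assumption for illustration:

    # Returns 0 for the common single-device-per-cross-point configurations.
    current = tile.get_hidden_update_index()

    # Cycle to the other device of a hypothetical 2-device unit cell.
    tile.set_hidden_update_index((current + 1) % 2)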