aihwkit.inference.compensation.base module
Base drift compensation for inference.
- class aihwkit.inference.compensation.base.BaseDriftCompensation
Bases: object
Base class for drift compensations.
- apply(forward_output, ref_value)
Read out the current value from the output of the forward pass and return the drift compensation alpha scale.
- Parameters:
forward_output (Tensor)
ref_value (Tensor)
- Return type:
Tensor
- get_readout_tensor(in_size)
Return the read-out tensor.
Called once during init_baseline().
- Parameters:
in_size (int)
- Return type:
Tensor
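The following sketch illustrates how the two methods above might fit together. It is not the aihwkit implementation: `SketchDriftCompensation`, its `readout` helper, and the `forward` function are hypothetical names, and plain Python lists stand in for Tensor so the example runs without dependencies. It also assumes that the read-out tensor is an identity matrix, that the read-out value is the mean absolute output, and that alpha is the ratio of the baseline read-out to the current one.

```python
from typing import List

Matrix = List[List[float]]


class SketchDriftCompensation:
    """Hypothetical stand-in mirroring the two documented methods.

    Plain Python lists replace torch Tensors to keep the sketch dependency-free.
    """

    def get_readout_tensor(self, in_size: int) -> Matrix:
        # Assumption: probe each input line once with a one-hot vector,
        # i.e. use an identity matrix as the read-out input.
        return [[1.0 if i == j else 0.0 for j in range(in_size)]
                for i in range(in_size)]

    def readout(self, forward_output: Matrix) -> float:
        # Assumption: summarize the forward output by its mean absolute value.
        values = [abs(v) for row in forward_output for v in row]
        return sum(values) / len(values)

    def apply(self, forward_output: Matrix, ref_value: float) -> float:
        # Assumption: alpha is the baseline read-out divided by the current one.
        return ref_value / self.readout(forward_output)


def forward(weights: Matrix, inputs: Matrix) -> Matrix:
    # Toy forward pass: dot product of each input row with each weight row.
    return [[sum(x * w for x, w in zip(row, w_row)) for w_row in weights]
            for row in inputs]


comp = SketchDriftCompensation()
weights = [[0.5, -1.0], [2.0, 0.25]]

# Baseline: read out once with the probe tensor and keep the reference value.
probe = comp.get_readout_tensor(2)
baseline = comp.readout(forward(weights, probe))

# Later: the weights have decayed to half their value (simulated drift).
drifted = [[0.5 * w for w in row] for row in weights]
alpha = comp.apply(forward(drifted, probe), baseline)
```

Under these assumptions, multiplying the drifted output by `alpha` restores the baseline output scale. In the documented API both `forward_output` and `ref_value` are Tensors, whereas the sketch uses a nested list and a float.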