aihwkit.inference.compensation.base module
Base drift compensation for inference.
- class aihwkit.inference.compensation.base.BaseDriftCompensation
Bases: object
Base class for drift compensations.
- Return type
None
- apply(forward_output, ref_value)
Read out the current value from the output of the forward pass and return the drift compensation scale alpha.
- Parameters
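forward_output (torch.Tensor) – output of the forward pass from which the current value is read out
ref_value (torch.Tensor) – reference readout recorded by init_baseline()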
- Return type
torch.Tensor
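As a rough illustration of what apply() computes, the following framework-free sketch makes two assumptions that this page does not state: the read-out reduces the forward output to a single scalar (here the mean absolute value, clamped away from zero), and alpha is the ratio of the reference readout to the current one, so that scaling the drifted output by alpha restores the baseline magnitude. Plain Python lists stand in for torch.Tensor, and the helper names mean_abs_readout and compute_alpha are hypothetical, not part of aihwkit.

```python
from typing import Sequence


def mean_abs_readout(forward_output: Sequence[float], eps: float = 1e-4) -> float:
    """Hypothetical readout: mean absolute value, clamped to at least eps."""
    value = sum(abs(x) for x in forward_output) / len(forward_output)
    return max(value, eps)


def compute_alpha(forward_output: Sequence[float], ref_value: float) -> float:
    """Assumed alpha: reference readout divided by the current readout."""
    return ref_value / mean_abs_readout(forward_output)


# The baseline readout was 2.0; drift has since halved the outputs.
drifted = [1.0, -1.0, 1.0, -1.0]
alpha = compute_alpha(drifted, ref_value=2.0)
print(alpha)  # 2.0
```

The clamp keeps the division defined when the drifted output is close to zero; any real subclass may choose a different reduction.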
- get_readout_tensor(in_size)
Return the read-out tensor.
Called once during init_baseline().
- Parameters
in_size (int) – input size of the layer, which sets the length of the read-out vector
- Return type
torch.Tensor
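The init_baseline() description states that the read-out tensor is all ones. A minimal sketch of building one, assuming a single-row (1, in_size) layout so it can be passed through the layer as one input vector; the layout and the function name make_readout_tensor are assumptions, and nested lists stand in for torch.Tensor:

```python
from typing import List


def make_readout_tensor(in_size: int) -> List[List[float]]:
    """Hypothetical stand-in: an all-ones read-out tensor of shape (1, in_size)."""
    if in_size <= 0:
        raise ValueError("in_size must be positive")
    # One row containing in_size ones.
    return [[1.0] * in_size]


readout = make_readout_tensor(4)
print(readout)  # [[1.0, 1.0, 1.0, 1.0]]
```

With PyTorch, the equivalent construction would be torch.ones((1, in_size)).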
- init_baseline(forward_output)
Initialize the baseline for applying the compensation.
Uses an all-ones tensor as the read-out vector.
- Parameters
forward_output (torch.Tensor) – output of the forward pass of the read-out vector through the layer to compensate
- Returns
the reference readout tensor
- Return type
Tuple[torch.Tensor, torch.Tensor]
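Taken together, these methods suggest a cycle: record a baseline once before drift sets in, then recompute alpha later and rescale the outputs. The sketch below is a framework-free toy built on several assumptions. ToyDriftCompensation, its readout() hook (mean absolute value), and the uniformly drifting toy_layer are all hypothetical. For simplicity, init_baseline() here returns a single scalar reference.

```python
from typing import List, Sequence


class ToyDriftCompensation:
    """Hypothetical pure-Python mirror of the documented method roles."""

    def readout(self, forward_output: Sequence[float]) -> float:
        # Assumed reduction: mean absolute value, clamped away from zero.
        value = sum(abs(x) for x in forward_output) / len(forward_output)
        return max(value, 1e-4)

    def get_readout_tensor(self, in_size: int) -> List[float]:
        # All-ones read-out vector, as described for init_baseline().
        return [1.0] * in_size

    def init_baseline(self, forward_output: Sequence[float]) -> float:
        # Reference readout recorded before any drift has occurred.
        return self.readout(forward_output)

    def apply(self, forward_output: Sequence[float], ref_value: float) -> float:
        # Assumed alpha: reference readout over current readout.
        return ref_value / self.readout(forward_output)


def toy_layer(weights: List[List[float]], x: Sequence[float], drift: float) -> List[float]:
    """Hypothetical analog layer whose weights all decay by the same factor."""
    return [drift * sum(w * xi for w, xi in zip(row, x)) for row in weights]


weights = [[0.5, -0.25, 1.0], [-1.0, 0.75, 0.5]]
comp = ToyDriftCompensation()
probe = comp.get_readout_tensor(len(weights[0]))

# Record the reference with no drift; later the weights have halved.
ref = comp.init_baseline(toy_layer(weights, probe, drift=1.0))
alpha = comp.apply(toy_layer(weights, probe, drift=0.5), ref)

x = [1.0, 2.0, -1.0]
restored = [alpha * y for y in toy_layer(weights, x, drift=0.5)]
print(alpha)     # 2.0
print(restored)  # [-1.0, 0.0], matching the undrifted output
```

Because the toy drift scales every weight by the same factor, one global alpha recovers the undrifted outputs exactly. Non-uniform drift would only be corrected on average by a single scalar.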