aihwkit.inference.compensation.base module
Base drift compensation for inference.
- class aihwkit.inference.compensation.base.BaseDriftCompensation
Bases: object
Base class for drift compensations.
- apply(forward_output, ref_value)
Read out the current value from the output of the forward pass and return the drift compensation alpha scale.
- Parameters:
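forward_output (Tensor) – output of the forward pass from which the current value is read out
ref_value (Tensor) – reference readout recorded by init_baseline()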
- Return type:
Tensor
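The docstring states only that apply() reads out a current value and returns an alpha scale. It does not say how the readout is computed. The framework-free sketch below makes two assumptions. First, it uses a hypothetical readout that takes the mean absolute output. Second, it treats alpha as the ratio of the reference readout to the current one. Plain Python lists stand in for torch Tensors, and the names `readout` and `apply` are illustrative, not the library's code.

```python
from typing import List


def readout(forward_output: List[float]) -> float:
    """Hypothetical readout: mean absolute value of the forward output."""
    return sum(abs(v) for v in forward_output) / len(forward_output)


def apply(forward_output: List[float], ref_value: float) -> float:
    """Return an alpha scale that brings the drifted output back to the reference level."""
    current_value = readout(forward_output)
    # Assumption: alpha is the ratio of the reference readout to the current readout.
    return ref_value / current_value


# Outputs for the same input, measured at programming time and after drift.
reference_output = [2.0, -4.0, 6.0]
drifted_output = [1.0, -2.0, 3.0]

ref_value = readout(reference_output)  # 4.0
alpha = apply(drifted_output, ref_value)  # 4.0 / 2.0
compensated = [alpha * v for v in drifted_output]
print(alpha, compensated)  # 2.0 [2.0, -4.0, 6.0]
```

Because alpha is a single scalar, this kind of compensation can only undo drift that scales all outputs roughly uniformly.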
- get_readout_tensor(in_size)
Return the read-out tensor.
Called once during init_baseline().
- Parameters:
in_size (int)
- Return type:
Tensor
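The init_baseline() docstring says the read-out input is an all-ones tensor. The minimal sketch below builds such an input of length `in_size`. It assumes a batch-of-one shape (1 x in_size) and uses a nested Python list in place of a torch Tensor. Neither the shape nor the list representation is confirmed by the source.

```python
from typing import List


def get_readout_tensor(in_size: int) -> List[List[float]]:
    """Return an all-ones read-out input as one row (assumed shape: 1 x in_size)."""
    if in_size <= 0:
        raise ValueError("in_size must be positive")
    return [[1.0] * in_size]


readout = get_readout_tensor(4)
print(readout)  # [[1.0, 1.0, 1.0, 1.0]]
```

An all-ones input is a natural probe for a matrix-vector product y = W x. Each output then equals the sum of its weight row, so every weight contributes to the read-out.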
- init_baseline(tile)
Initialize the baseline for applying the compensation.
Uses an all-ones tensor as the read-out input.
- Parameters:
tile (InferenceTileWithPeriphery) – the tile whose forward output for the read-out vector serves as the reference to compensate against
- Returns:
the read-out tensor and the reference readout
- Return type:
Tuple[Tensor, Tensor]
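The baseline flow can be illustrated with a toy model. The sketch below is framework-free and uses several assumptions. `ToyTile` is a hypothetical stand-in for InferenceTileWithPeriphery that computes y = W x for each input row. `readout` is a hypothetical mean-absolute-value reduction. Plain lists and floats replace torch Tensors. The function returns the all-ones read-out input together with the reference readout, matching the two-element return type. Drift is then simulated by halving the weights, and the stored read-out input is reused to obtain the current readout.

```python
from typing import List, Tuple

Matrix = List[List[float]]


class ToyTile:
    """Hypothetical stand-in for InferenceTileWithPeriphery: y = W x per input row."""

    def __init__(self, weights: Matrix) -> None:
        self.weights = weights
        self.in_size = len(weights[0])

    def forward(self, x: Matrix) -> Matrix:
        # One output row per input row; one output value per weight row.
        return [[sum(w * v for w, v in zip(row, xi)) for row in self.weights] for xi in x]


def readout(out: Matrix) -> float:
    """Hypothetical readout: mean absolute value over all outputs."""
    flat = [abs(v) for row in out for v in row]
    return sum(flat) / len(flat)


def init_baseline(tile: ToyTile) -> Tuple[Matrix, float]:
    """Build the all-ones read-out input and record the reference readout."""
    readout_tensor = [[1.0] * tile.in_size]  # batch of one all-ones row
    ref_value = readout(tile.forward(readout_tensor))
    return readout_tensor, ref_value


tile = ToyTile([[0.5, -0.25], [1.0, 0.75]])
readout_tensor, ref_value = init_baseline(tile)
print(readout_tensor, ref_value)  # [[1.0, 1.0]] 1.0

# Simulated drift: all weights are halved. The same read-out input is reused.
tile.weights = [[0.5 * w for w in row] for row in tile.weights]
current_value = readout(tile.forward(readout_tensor))
print(ref_value / current_value)  # 2.0
```

Storing the read-out input alongside the reference keeps later comparisons consistent, because the current readout is always measured with the same probe used for the baseline.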