aihwkit.inference.compensation.base module

Base drift compensation for inference.

class aihwkit.inference.compensation.base.BaseDriftCompensation

Bases: object

Base class for drift compensations.

apply(forward_output, ref_value)

Read out the current value from the output of the forward pass and return the drift compensation scale alpha.

Parameters:
  • forward_output (Tensor) – current forward output of the read-out vector

  • ref_value (Tensor) – reference value returned by init_baseline()

Return type:

Tensor
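A minimal sketch of the arithmetic with plain floats, assuming alpha is the ratio of the reference read-out to the current read-out, so that multiplying the drifted outputs by alpha brings them back to the reference level. The helper `drift_alpha` is hypothetical and not part of aihwkit; the real method works on `Tensor` objects.

```python
def drift_alpha(ref_value: float, current_value: float) -> float:
    # Hypothetical helper: ratio of the reference read-out to the current one.
    return ref_value / current_value


ref_value = 2.0       # read-out taken right after programming
current_value = 1.6   # same read-out after the conductances have drifted down

alpha = drift_alpha(ref_value, current_value)

# Assumed usage: alpha multiplies the drifted layer outputs.
drifted_outputs = [0.8, -0.4, 1.2]
compensated = [alpha * y for y in drifted_outputs]
```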

get_readout_tensor(in_size)

Return the read-out tensor.

Called once during init_baseline().

Parameters:

in_size (int) – input size, i.e. the length of the read-out vector

Return type:

Tensor
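Why an all-ones read-out vector is a useful probe can be sketched with plain Python lists. Assuming a linear layer computes y = x Wᵀ, feeding the all-ones vector returns the row sums of W, so a uniform drift that scales every weight by a common factor scales every probe output by that same factor. The (1, in_size) shape and the helpers `ones_readout` and `forward` are illustrative stand-ins, not the aihwkit implementation.

```python
from typing import List

Matrix = List[List[float]]


def ones_readout(in_size: int) -> Matrix:
    # Hypothetical stand-in: one all-ones input row of shape (1, in_size).
    return [[1.0] * in_size]


def forward(inputs: Matrix, weights: Matrix) -> Matrix:
    # y = x @ W^T, with W of shape (out_size, in_size) as in a linear layer.
    return [[sum(x * w for x, w in zip(row, w_row)) for w_row in weights] for row in inputs]


weights = [[0.5, -0.25, 1.0], [0.2, 0.3, -0.1]]
probe = ones_readout(in_size=3)
baseline_out = forward(probe, weights)   # row sums of the programmed weights

# Uniform drift: every weight decays by the same factor of 0.7.
drifted = [[0.7 * w for w in row] for row in weights]
drifted_out = forward(probe, drifted)    # each probe output shrinks by 0.7 as well
```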

init_baseline(forward_output)

Initialize the baseline for applying the compensation.

Uses an all-ones tensor as the read-out vector.

Parameters:

forward_output (Tensor) – forward output of the read-out vector to compensate

Returns:

the reference read-out tensor

Return type:

Tuple[Tensor, Tensor]

readout(out_tensor)

Implement the read-out math.

Parameters:

out_tensor (Tensor) – output of the forward pass to read out

Return type:

Tensor
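How the methods fit together can be sketched as a small stdlib-only class hierarchy, assuming `readout()` is the only hook a subclass has to provide and that `apply()` returns the reference divided by the current read-out. `SketchDriftCompensation` and `MeanAbsCompensation` are hypothetical names; the mean-absolute-value read-out with a small clamp is one plausible choice, not necessarily what aihwkit's subclasses use.

```python
from typing import List

Vector = List[float]


class SketchDriftCompensation:
    """Hypothetical stdlib mirror of the base-class contract (not aihwkit code)."""

    def init_baseline(self, forward_output: Vector) -> float:
        # The reference is the read-out of the baseline forward output.
        return self.readout(forward_output)

    def apply(self, forward_output: Vector, ref_value: float) -> float:
        # Alpha rescales the current read-out back to the reference level.
        return ref_value / self.readout(forward_output)

    def readout(self, out_tensor: Vector) -> float:
        # Assumed: the base class leaves the read-out math to subclasses.
        raise NotImplementedError


class MeanAbsCompensation(SketchDriftCompensation):
    """Hypothetical subclass: read out the mean absolute output."""

    def readout(self, out_tensor: Vector) -> float:
        # Clamp away from zero so that apply() never divides by zero.
        return max(sum(abs(v) for v in out_tensor) / len(out_tensor), 1e-4)


comp = MeanAbsCompensation()
out_t0 = [1.25, -0.4, 0.75]          # probe output right after programming
ref = comp.init_baseline(out_t0)
out_t1 = [0.7 * v for v in out_t0]   # same output after a uniform 30% decay
alpha = comp.apply(out_t1, ref)      # approximately 1 / 0.7
```

Keeping the read-out math in a single overridable method means `init_baseline()` and `apply()` always reduce the output the same way, so the ratio compares like with like.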