aihwkit.inference.compensation.base module

Base drift compensation for inference.

class aihwkit.inference.compensation.base.BaseDriftCompensation[source]

Bases: object

Base class for drift compensations.

apply(forward_output, ref_value)[source]

Read out the current value from the output of the forward pass and return the drift compensation alpha scale.

Return type

torch.Tensor
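
As a rough illustration of what an "alpha scale" could look like, the sketch below assumes the scale is the ratio of the stored reference read-out to the current read-out, so that multiplying a drifted output by alpha restores the baseline magnitude. The helper name alpha_scale and the use of plain floats in place of torch.Tensor are assumptions made for illustration; they are not part of the documented API.

```python
def alpha_scale(ref_value: float, current_value: float) -> float:
    """Hypothetical helper: ratio of baseline read-out to current read-out."""
    if current_value == 0.0:
        raise ValueError("current read-out must be non-zero")
    return ref_value / current_value


# Baseline read-out of 2.0; after drift the same read-out yields 1.6.
alpha = alpha_scale(2.0, 1.6)

# Scaling the drifted read-out by alpha brings it back to the baseline level.
compensated = alpha * 1.6
```

A scale above 1 indicates that the outputs have shrunk relative to the baseline; a scale below 1 indicates that they have grown.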

get_readout_tensor(in_size)[source]

Return the read-out tensor.

Called once during init_baseline().

Return type

torch.Tensor

init_baseline(forward_output)[source]

Initialize the baseline for applying the compensation.

Uses an all-ones tensor for read_out.

Parameters

forward_output – forward output of the read-out vector to compensate

Returns

reference tensor readout

Return type

Tuple[torch.Tensor, torch.Tensor]

readout(out_tensor)[source]

Implement the read out math.

Return type

torch.Tensor
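
The four methods above fit together as a two-phase workflow: record a reference read-out once, then compare later read-outs against it. The following is a minimal, dependency-free sketch of that pattern, not the library's implementation. The class MeanAbsCompensation, the toy forward function, the identity-matrix probe, and the mean-absolute-value read-out are all assumptions chosen for illustration, and nested lists stand in for torch.Tensor.

```python
from typing import List

Matrix = List[List[float]]


class MeanAbsCompensation:
    """Hypothetical stand-in that mirrors the documented method names."""

    def get_readout_tensor(self, in_size: int) -> Matrix:
        # Assumed probe: identity matrix, one one-hot input per row.
        return [[1.0 if i == j else 0.0 for j in range(in_size)]
                for i in range(in_size)]

    def readout(self, out_tensor: Matrix) -> float:
        # Assumed read-out math: mean absolute value over all outputs,
        # clamped from below so it can safely be used as a divisor.
        values = [abs(v) for row in out_tensor for v in row]
        return max(sum(values) / len(values), 1e-4)

    def init_baseline(self, forward_output: Matrix) -> float:
        # Reference value: read-out of the forward output before drift.
        return self.readout(forward_output)

    def apply(self, forward_output: Matrix, ref_value: float) -> float:
        # Alpha scale: baseline read-out divided by the current read-out.
        return ref_value / self.readout(forward_output)


def forward(inputs: Matrix, weights: Matrix) -> Matrix:
    """Toy forward pass: plain matrix product of inputs and weights."""
    return [[sum(a * w for a, w in zip(row, col)) for col in zip(*weights)]
            for row in inputs]


comp = MeanAbsCompensation()
weights = [[0.5, -1.0], [2.0, 0.5]]
probe = comp.get_readout_tensor(2)

# Phase 1: record the baseline from the undrifted weights.
ref = comp.init_baseline(forward(probe, weights))

# Phase 2: simulate drift as a uniform 20% decay, then compute alpha.
drifted = [[0.8 * w for w in row] for row in weights]
alpha = comp.apply(forward(probe, drifted), ref)
```

In this sketch a concrete compensation only needs to decide what probe to send through the forward pass and how to reduce the resulting output to a read-out; the baseline and ratio logic stay the same.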