aihwkit.inference.compensation.base module

Base drift compensation for inference.

class aihwkit.inference.compensation.base.BaseDriftCompensation

Bases: object

Base class for drift compensations.

Return type

None

apply(forward_output, ref_value)

Read out the current value from the output of the forward pass and return the drift compensation alpha scale.

Parameters
  • forward_output (torch.Tensor) –

  • ref_value (torch.Tensor) –

Return type

torch.Tensor
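As a rough illustration of the arithmetic, the sketch below assumes the alpha scale is the ratio of the stored reference readout to the current readout; the text above does not spell out the formula. Plain Python floats stand in for torch.Tensor, and the names alpha_scale and compensate are hypothetical.

```python
from typing import List


def alpha_scale(current_value: float, ref_value: float) -> float:
    """Assumed rule: reference readout divided by the current readout."""
    return ref_value / current_value


def compensate(outputs: List[float], alpha: float) -> List[float]:
    """Rescale drifted outputs by the alpha scale."""
    return [alpha * v for v in outputs]


# Made-up numbers: the readout was 1.0 at baseline and has drifted to 0.5.
alpha = alpha_scale(current_value=0.5, ref_value=1.0)
restored = compensate([0.25, -0.5, 0.75], alpha)
```

Under this assumption an alpha above 1 means the outputs have shrunk since the baseline was taken, and multiplying by alpha brings them back to their original scale.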

get_readout_tensor(in_size)

Return the read-out tensor.

Called once during init_baseline().

Parameters

in_size (int) –

Return type

torch.Tensor

init_baseline(forward_output)

Initialize the baseline for applying the compensation.

Uses an all-ones tensor for the read-out.

Parameters

forward_output (torch.Tensor) – forward output of the read-out vector to compensate

Returns

reference tensor readout

Return type

Tuple[torch.Tensor, torch.Tensor]
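A minimal sketch of the baseline step, assuming the all-ones read-out vector mentioned above is pushed through the layer once and the result is reduced to a single reference value. The toy forward function, the weights, and the mean-absolute-value reduction are all assumptions made for illustration, with lists standing in for torch.Tensor.

```python
from typing import List


def get_readout_tensor(in_size: int) -> List[float]:
    """Assumed probe: an all-ones input vector of length in_size."""
    return [1.0] * in_size


def forward(weights: List[List[float]], x: List[float]) -> List[float]:
    """Toy matrix-vector product standing in for the analog forward pass."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


def init_baseline(forward_output: List[float]) -> float:
    """Assumed reduction: mean absolute value of the probe's output."""
    return sum(abs(v) for v in forward_output) / len(forward_output)


weights = [[0.5, -0.25], [1.0, 0.5]]  # made-up 2x2 layer
probe = get_readout_tensor(in_size=2)
ref_value = init_baseline(forward(weights, probe))
```

The reference value is recorded once, before any drift has occurred, so that later readouts of the same probe can be compared against it.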

readout(out_tensor)

Implement the read-out math.

Parameters

out_tensor (torch.Tensor) –

Return type

torch.Tensor
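Since readout is where the read-out math lives, a concrete compensation would presumably override it. The sketch below uses a stand-in base class rather than importing aihwkit, and it assumes that the base readout is left unimplemented and that the alpha scale is the reference readout divided by the current one. Both readout formulas (mean and maximum of absolute values) are hypothetical examples, not the library's implementations.

```python
from typing import List


class SketchBase:
    """Stand-in for the documented base class (not the aihwkit code)."""

    def readout(self, out_tensor: List[float]) -> float:
        # Assumption: concrete compensations must override this.
        raise NotImplementedError

    def apply(self, forward_output: List[float], ref_value: float) -> float:
        # Assumed alpha scale: reference readout over current readout.
        return ref_value / self.readout(forward_output)


class MeanAbsCompensation(SketchBase):
    def readout(self, out_tensor: List[float]) -> float:
        # Floor the result so apply() never divides by zero.
        return max(sum(abs(v) for v in out_tensor) / len(out_tensor), 1e-4)


class MaxAbsCompensation(SketchBase):
    def readout(self, out_tensor: List[float]) -> float:
        return max(max(abs(v) for v in out_tensor), 1e-4)


drifted = [0.125, -0.75]  # made-up drifted probe output
alpha_mean = MeanAbsCompensation().apply(drifted, ref_value=0.875)
alpha_max = MaxAbsCompensation().apply(drifted, ref_value=1.5)
```

Keeping the reduction in readout means the baseline and the later comparison use the same statistic, whichever one a subclass picks.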