aihwkit.inference.compensation.drift module

Global drift compensation for inference.

class aihwkit.inference.compensation.drift.GlobalDriftCompensation

Bases: aihwkit.inference.compensation.base.BaseDriftCompensation

Global drift compensation.

Uses a constant factor for compensating the drift.

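The idea of compensating drift with a single constant factor can be sketched as follows. This is a minimal illustration, not the library's implementation: the helper `mean_abs`, the example weight matrix, and the uniform 20% decay are all hypothetical, and the scaling rule (reference read-out divided by current read-out) is an assumption about how such a factor would typically be derived.

```python
import torch


def mean_abs(out: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: mean absolute value over all elements."""
    return out.abs().mean()


# Hypothetical 3-input, 2-output weight matrix standing in for an analog tile.
weights = torch.tensor([[0.5, -1.0, 2.0],
                        [1.5, 0.0, -0.5]])
probe = torch.eye(3)  # one-hot read-out inputs

# Reference read-out taken before any drift has occurred.
baseline = mean_abs(probe @ weights.T)

# Assumed drift model for illustration only: every weight decays by 20%.
drifted_weights = 0.8 * weights
current = mean_abs(probe @ drifted_weights.T)

# One scalar factor for the whole tile (the "global" part).
alpha = baseline / current

x = torch.tensor([[1.0, 2.0, 3.0]])
compensated = alpha * (x @ drifted_weights.T)
ideal = x @ weights.T
```

Because the assumed drift is uniform, a single scalar restores the ideal output here. With non-uniform drift, a global factor would only correct the average decay.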

get_readout_tensor(in_size)

Return the read-out tensor.

Uses the set of one-hot vectors, i.e. an identity matrix (eye).

Parameters

in_size (int) – Size of the read-out tensor, i.e. the number of one-hot vectors.

Return type

torch.Tensor
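A minimal sketch of what such a read-out tensor looks like, assuming "eye" refers to `torch.eye`. The function below is a stand-in written for illustration, and the weight matrix is hypothetical. It shows why one-hot inputs are useful: each one probes a single input line.

```python
import torch


def get_readout_tensor(in_size: int) -> torch.Tensor:
    """Illustrative stand-in: one one-hot vector per input line."""
    return torch.eye(in_size)


readout_inputs = get_readout_tensor(3)

# Hypothetical 3-input, 2-output weight matrix standing in for an analog tile.
weights = torch.tensor([[0.5, -1.0, 2.0],
                        [1.5, 0.0, -0.5]])

# Row i of the result is the response to the i-th one-hot input,
# which is the i-th column of the weight matrix.
outputs = readout_inputs @ weights.T
```

Passing the full identity matrix through the layer therefore reads back every weight once, which gives a summary of the whole tile.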

readout(out_tensor)

Reads out the mean absolute value of the output tensor.

Parameters

out_tensor (torch.Tensor) – Output tensor whose mean absolute value is read out.

Return type

torch.Tensor
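The mean-absolute read-out can be sketched as below. This is a stand-in under the assumption that the mean is taken over all elements. The library's method may differ in detail, for example in which dimensions it reduces or in how it guards against a zero result.

```python
import torch


def readout(out_tensor: torch.Tensor) -> torch.Tensor:
    """Illustrative stand-in: mean absolute value over all elements."""
    return out_tensor.abs().mean()


out = torch.tensor([[0.5, -1.0],
                    [2.0, -0.5]])

# (0.5 + 1.0 + 2.0 + 0.5) / 4
value = readout(out)
```

Taking the absolute value first keeps positive and negative outputs from cancelling, so the result tracks the overall output magnitude.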