aihwkit.simulator.tiles.utils module

Low-level implementation of the torch-based tile.

class aihwkit.simulator.tiles.utils.UniformQuantize(inplace=False)

Bases: InplaceFunction

In-place quantization function.

static backward(ctx, grad_output)

Straight-through estimator.

Parameters:
  • ctx (FunctionCtx) – Context.

  • grad_output (Tensor) – Gradient w.r.t. the output of forward.

Returns:

Gradients w.r.t. the inputs to forward: the straight-through gradient for inp, followed by None for res, bound, and stochastic.

Return type:

Tuple[Tensor, None, None, None]
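Rounding has zero gradient almost everywhere, so a quantizer needs a surrogate gradient to be trainable. The following is a minimal sketch of the straight-through idea, not the aihwkit code. It assumes a plain `torch.autograd.Function` instead of `InplaceFunction`, and a forward that takes only `inp`, `res`, and `bound` (so backward returns three entries rather than the four documented above, which also cover `stochastic`). The class name `RoundSTE` and the scaling formula are hypothetical.

```python
import torch
from torch.autograd import Function


class RoundSTE(Function):
    """Hypothetical quantizer with a straight-through backward (not aihwkit code)."""

    @staticmethod
    def forward(ctx, inp, res, bound):
        # Assumed formula: round inp / bound to a grid of spacing res, then rescale.
        return torch.round(inp / bound / res) * res * bound

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: treat rounding as the identity, so the incoming
        # gradient flows to inp unchanged; res and bound receive None.
        return grad_output, None, None


x = torch.tensor([0.26, -0.74], requires_grad=True)
y = RoundSTE.apply(x, 0.5, 1.0)
y.sum().backward()
print(x.grad.tolist())  # [1.0, 1.0]
```

Without the custom backward, `x.grad` would be all zeros, because the true derivative of `torch.round` is zero away from the rounding thresholds.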

static forward(ctx, inp, res, bound, stochastic=False)

Quantizes the input tensor; the matching backward pass uses straight-through estimation.

Parameters:
  • ctx (FunctionCtx) – Context.

  • inp (torch.Tensor) – Input to be discretized.

  • res (float) – Resolution (number of states).

  • bound (float) – Input bounds w.r.t. which we quantize.

  • stochastic (bool, optional) – Whether to use stochastic rounding. Defaults to False.

Returns:

Quantized input.

Return type:

torch.Tensor
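The arithmetic behind deterministic versus stochastic rounding can be sketched in plain Python. This is a hypothetical illustration, not the aihwkit formula: it assumes `res` is the grid spacing in units of `bound` (the docstring describes `res` as a number of states, so the library's exact mapping may differ), and it does not clip values to the bound. Stochastic rounding here rounds up with probability equal to the fractional part, which keeps the result unbiased in expectation.

```python
import math
import random
from typing import List, Optional


def uniform_quantize_sketch(
    values: List[float],
    res: float,
    bound: float,
    stochastic: bool = False,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Hypothetical uniform quantizer; not the aihwkit implementation.

    Assumes ``res`` is the grid spacing in units of ``bound``.
    """
    if rng is None:
        rng = random.Random(0)
    out = []
    for v in values:
        # Position of the value on an integer grid with spacing res * bound.
        scaled = v / bound / res
        if stochastic:
            # Round up with probability equal to the fractional part.
            level = math.floor(scaled + rng.random())
        else:
            # Nearest level; Python's round() breaks exact ties to even.
            level = round(scaled)
        out.append(level * res * bound)
    return out


print(uniform_quantize_sketch([0.26, -0.74, 1.0], res=0.5, bound=1.0))  # [0.5, -0.5, 1.0]
```

With `stochastic=True`, an input of 0.3 on a 0.5 grid maps to 0.5 with probability 0.6 and to 0.0 otherwise, so the average over many draws approaches 0.3.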

aihwkit.simulator.tiles.utils.isinf(x)

Checks if the input is inf.

Parameters:

x (Union[float, str, torch.Tensor, ndarray]) – Input.

Returns:

Boolean tensor that is True where the input is inf.

Return type:

torch.Tensor
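A helper accepting the documented input types could look like the sketch below. It is a hypothetical stand-in, not the aihwkit source: it assumes strings such as "inf" are parsed with `float()`, and it relies on `torch.isinf`, which flags both positive and negative infinity. The name `isinf_sketch` is illustrative.

```python
from typing import Union

import numpy as np
import torch


def isinf_sketch(x: Union[float, str, torch.Tensor, np.ndarray]) -> torch.Tensor:
    """Hypothetical stand-in for ``isinf``; not the aihwkit source."""
    if isinstance(x, str):
        # Assumption: strings such as "inf" or "-inf" are parsed as floats.
        x = float(x)
    # torch.as_tensor accepts Python floats, NumPy arrays, and tensors alike.
    return torch.isinf(torch.as_tensor(x))


print(isinf_sketch(np.array([1.0, np.inf])).tolist())  # [False, True]
```

Scalar inputs produce a zero-dimensional Boolean tensor, so `.item()` recovers a plain Python `bool`.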