aihwkit.simulator.tiles.utils module
Low-level implementation of the torch-based tile.
- class aihwkit.simulator.tiles.utils.UniformQuantize(inplace=False)
Bases: InplaceFunction
Quantization in-place function.
- static backward(ctx, grad_output)
Straight-through estimator.
- Parameters:
ctx (FunctionCtx) – Context.
grad_output (Tensor) – Gradient w.r.t. the output of forward.
- Returns:
Gradients w.r.t. inputs to forward.
- Return type:
Tuple[Tensor, None, None, None]
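A minimal sketch of the straight-through idea, not the aihwkit implementation: the forward pass rounds, and the backward pass hands the incoming gradient to the tensor input unchanged. The class name RoundSTE, the `step` parameter, and the exact rounding rule are hypothetical and chosen only for illustration. The four-element return mirrors the documented return type, with one entry per forward argument.

```python
import torch
from torch.autograd import Function


class RoundSTE(Function):
    """Hypothetical example: round in forward, identity gradient in backward."""

    @staticmethod
    def forward(ctx, inp, step, bound, stochastic=False):
        # Assumed rule for this sketch: clip to [-bound, bound], then snap to
        # multiples of `step`. `stochastic` is accepted but ignored here.
        return torch.round(inp.clamp(-bound, bound) / step) * step

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through: treat the rounding as the identity function.
        # One return value per forward argument; non-tensor arguments get None.
        return grad_output, None, None, None


x = torch.tensor([0.3, -0.26, 0.9], requires_grad=True)
y = RoundSTE.apply(x, 0.5, 1.0, False)  # apply() takes positional arguments
y.sum().backward()
print(y.detach())  # values snapped to the 0.5 grid
print(x.grad)      # gradient passed through unchanged
```

Rounding has zero gradient almost everywhere, so without this override no learning signal would reach the input during backpropagation.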
- static forward(ctx, inp, res, bound, stochastic=False)
Quantizes the input tensor. Gradients are handled by the straight-through estimator in backward.
- Parameters:
ctx (FunctionCtx) – Context.
inp (torch.Tensor) – Input to be discretized.
res (float) – Resolution (number of states).
bound (float) – Input bounds w.r.t. which we quantize.
stochastic (bool, optional) – Whether to use stochastic rounding. Defaults to False.
- Returns:
Quantized input.
- Return type:
torch.Tensor
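To make the difference between deterministic and stochastic rounding concrete, here is a plain-Python sketch of uniform quantization. It is an illustration under stated assumptions, not the library's code: the function name uniform_quantize is hypothetical, and it assumes `n_states` evenly spaced levels spanning [-bound, bound], which may differ from how aihwkit interprets `res`.

```python
import math
import random


def uniform_quantize(values, n_states, bound, stochastic=False, rng=None):
    """Snap values to n_states evenly spaced levels on [-bound, bound]."""
    rng = rng or random.Random()
    step = 2.0 * bound / (n_states - 1)  # assumed spacing between levels
    quantized = []
    for value in values:
        clipped = max(-bound, min(bound, value))
        scaled = (clipped + bound) / step  # position in units of levels
        if stochastic:
            # Round up with probability equal to the fractional part, so the
            # expected output equals the clipped input.
            lower = math.floor(scaled)
            index = lower + (1 if rng.random() < scaled - lower else 0)
        else:
            index = math.floor(scaled + 0.5)  # round to nearest level
        index = min(max(index, 0), n_states - 1)  # guard against float drift
        quantized.append(-bound + index * step)
    return quantized


deterministic = uniform_quantize([0.3, -0.26, 2.0], n_states=5, bound=1.0)
print(deterministic)  # [0.5, -0.5, 1.0]

samples = uniform_quantize([0.3] * 1000, 5, 1.0, stochastic=True,
                           rng=random.Random(0))
print(sum(samples) / len(samples))  # close to 0.3 on average
```

Deterministic rounding always maps 0.3 to the nearest level, 0.5, while stochastic rounding returns 0.0 or 0.5 at random so that the average stays near the original value.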