aihwkit.nn.functions module
Autograd functions for aihwkit.
- class aihwkit.nn.functions.AnalogFunction(*args, **kwargs)
Bases: aihwkit.nn.functions.AnalogFunctionBase
Function that delegates to an RPU unit.
- static forward(ctx, analog_ctx, input_, shared_weights=None, is_test=False)
Execute the forward pass in the analog tile.
- Parameters
ctx (Any) –
analog_ctx (aihwkit.optim.context.AnalogContext) –
input_ (torch.Tensor) –
shared_weights (Optional[torch.Tensor]) –
is_test (bool) –
- Return type
torch.Tensor
- class aihwkit.nn.functions.AnalogFunctionBase(*args, **kwargs)
Bases: torch.autograd.function.Function
Base class for the analog autograd functions.
- static backward(ctx, grad_output)
Execute the backward pass in the analog tile.
- Parameters
ctx (Any) –
grad_output (torch.Tensor) –
- Return type
Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]
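A minimal sketch of the pattern these functions follow, written against plain `torch.autograd.Function`. `DigitalTileStub` and `SketchAnalogFunction` are hypothetical names: the stub performs an exact matrix product in place of an analog tile, and the tile object is passed directly where the real functions take an `AnalogContext`. It illustrates why `backward` returns a 4-tuple: one entry per `forward` argument after `ctx`, with `None` for arguments that receive no gradient. How the real implementation updates the weights is not shown here.

```python
import torch
from torch.autograd import Function


class DigitalTileStub:
    """Hypothetical stand-in for an analog tile: an exact, noise-free matmul."""

    def __init__(self, weight: torch.Tensor) -> None:
        self.weight = weight  # shape: (out_features, in_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = x W^T
        return x @ self.weight.t()

    def backward(self, d: torch.Tensor) -> torch.Tensor:
        # Input gradient: d W
        return d @ self.weight


class SketchAnalogFunction(Function):
    """Sketch mirroring the documented forward/backward signatures."""

    @staticmethod
    def forward(ctx, tile, input_, shared_weights=None, is_test=False):
        # shared_weights and is_test are accepted for signature parity only;
        # this sketch ignores them.
        ctx.tile = tile  # keep the tile for the backward pass
        return tile.forward(input_)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = ctx.tile.backward(grad_output)
        # One entry per forward argument after ctx:
        # (tile, input_, shared_weights, is_test).
        return None, grad_input, None, None
```

With weight `[[1, 2], [3, 4]]` and input `[[1, 1]]`, `SketchAnalogFunction.apply(DigitalTileStub(w), x, None, False)` gives `[[3, 7]]`, and backpropagating `y.sum()` leaves `x.grad` equal to `[[4, 6]]`.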
- static forward(ctx, analog_ctx, input_, shared_weights=None, is_test=False)
Execute the forward pass in the analog tile.
Note: Indexed versions can be used when analog_ctx.use_indexed is set to True.
- Parameters
ctx (Any) –
analog_ctx (aihwkit.optim.context.AnalogContext) –
input_ (torch.Tensor) –
shared_weights (Optional[torch.Tensor]) –
is_test (bool) –
- Return type
torch.Tensor
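The note on `use_indexed` implies a choice between the plain and the indexed function. A minimal sketch of such a selection, assuming only that `analog_ctx` exposes a boolean `use_indexed` attribute. `pick_function` and both stub classes are hypothetical stand-ins for `AnalogFunction` and `AnalogIndexedFunction`, not aihwkit code, and `SimpleNamespace` objects stand in for `AnalogContext`.

```python
from types import SimpleNamespace


class AnalogFunctionStub:
    """Hypothetical stand-in for AnalogFunction (non-indexed path)."""


class AnalogIndexedFunctionStub:
    """Hypothetical stand-in for AnalogIndexedFunction (indexed path)."""


def pick_function(analog_ctx):
    """Return the indexed variant only when the context requests it."""
    # Treat a missing attribute as False so plain contexts keep the default path.
    if getattr(analog_ctx, "use_indexed", False):
        return AnalogIndexedFunctionStub
    return AnalogFunctionStub


print(pick_function(SimpleNamespace(use_indexed=True)).__name__)   # AnalogIndexedFunctionStub
print(pick_function(SimpleNamespace(use_indexed=False)).__name__)  # AnalogFunctionStub
```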
- class aihwkit.nn.functions.AnalogIndexedFunction(*args, **kwargs)
Bases: aihwkit.nn.functions.AnalogFunctionBase
Function that delegates to an RPU unit, using its indexed forward, backward, and update passes.
- static forward(ctx, analog_ctx, input_, shared_weights=None, is_test=False)
Execute the forward pass in the analog tile.
- Parameters
ctx (Any) –
analog_ctx (aihwkit.optim.context.AnalogContext) –
input_ (torch.Tensor) –
shared_weights (Optional[torch.Tensor]) –
is_test (bool) –
- Return type
torch.Tensor
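The reference states only that this variant uses the indexed forward/backward/update. One plausible reading, offered here as an assumption, is that "indexed" means the input patches of a convolution are gathered through a precomputed index tensor, so a single weight matrix can be applied to every patch. The sketch below illustrates that idea for a 1-D convolution in plain torch; `conv1d_indices` and `indexed_forward` are hypothetical helpers and do not reproduce aihwkit's actual index layout.

```python
import torch
import torch.nn.functional as F


def conv1d_indices(length: int, kernel_size: int) -> torch.Tensor:
    """Hypothetical helper: input positions read by each output position."""
    starts = torch.arange(length - kernel_size + 1).unsqueeze(1)  # (L_out, 1)
    offsets = torch.arange(kernel_size).unsqueeze(0)              # (1, k)
    return starts + offsets                                       # (L_out, k)


def indexed_forward(weight: torch.Tensor, input_: torch.Tensor) -> torch.Tensor:
    """Gather patches by index, then apply one matrix product to all patches.

    weight: (out_channels, in_channels * k); input_: (batch, in_channels, length).
    """
    batch, in_channels, length = input_.shape
    k = weight.shape[1] // in_channels
    idx = conv1d_indices(length, k)                     # (L_out, k)
    patches = input_[:, :, idx]                         # (batch, in_channels, L_out, k)
    # Flatten each patch channel-major so it lines up with weight's columns.
    patches = patches.permute(0, 2, 1, 3).reshape(batch, idx.shape[0], in_channels * k)
    out = patches @ weight.t()                          # (batch, L_out, out_channels)
    return out.permute(0, 2, 1)                         # (batch, out_channels, L_out)


torch.manual_seed(0)
x = torch.randn(2, 3, 10)
w = torch.randn(4, 3, 5)
# Same result as a standard (unpadded, stride-1) 1-D convolution.
print(torch.allclose(indexed_forward(w.reshape(4, 15), x), F.conv1d(x, w), atol=1e-4))
```

Keeping the weight as one flat matrix is what lets a tile that only knows matrix-vector products serve a convolution layer.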