aihwkit.experiments.experiments.training module

Basic training Experiment.

class aihwkit.experiments.experiments.training.BasicTraining(dataset, model, batch_size=64, loss_function=<class 'torch.nn.modules.loss.NLLLoss'>, epochs=30, learning_rate=0.05)

Bases: Experiment

Experiment for training a neural network.

Experiment that represents training a neural network using a basic training loop.

This class contains:

  • the data needed for an experiment. The recommended way to set this data is through the constructor arguments. Some of the items also have getter methods, which are used by the Workers that execute the experiments and by the training loop.

  • the training algorithm, whose main entry point is train().

Note

When a BasicTraining is executed in the cloud, additional constraints apply to the data. For example, the model is restricted to sequential layers of specific types, and only a limited set of datasets is supported. See the CloudRunner documentation for details.

Parameters:
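  • dataset (Type[Dataset]) – the dataset class to be used.

  • model (Module) – the neural network to be trained.

  • batch_size (int) – the batch size used for training.

  • loss_function (type) – the loss function class used for training.

  • epochs (int) – the number of epochs for the training.

  • learning_rate (float) – the learning rate used by the optimizer.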
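As a rough, framework-free illustration of how the experiment data and its constructor defaults fit together, the sketch below mirrors the documented BasicTraining signature in a plain dataclass. The class name TrainingSettingsSketch and the getter get_learning_rate() are hypothetical, and the dataset and loss function are represented by names rather than by the torch classes the real constructor takes.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class TrainingSettingsSketch:
    """Hypothetical, framework-free mirror of BasicTraining's constructor data."""

    dataset: str                    # a dataset name instead of a Dataset class
    model: Any                      # the network to be trained
    batch_size: int = 64            # defaults copied from the documented signature
    loss_function: str = "NLLLoss"  # a name instead of torch.nn.NLLLoss
    epochs: int = 30
    learning_rate: float = 0.05

    def get_learning_rate(self) -> float:
        """A getter of the kind Workers and the training loop could call."""
        return self.learning_rate
```

Setting everything through the constructor keeps an experiment a plain data record, which is what makes it straightforward to hand to different Workers.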

get_data_loaders(dataset, batch_size, max_elements_train=0, dataset_root='/tmp/datasets')

Return DataLoaders for the selected dataset.

Parameters:
  • dataset (type) – the dataset class to be used.

  • batch_size (int) – the batch size used for training.

  • max_elements_train (int) – the maximum number of elements of the dataset to be used. If 0, the full dataset is used.

  • dataset_root (str) – the path to the folder where the files from the dataset are stored.

Returns:

A tuple with the training and validation loaders.

Return type:

Tuple[DataLoader, DataLoader]
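The loader behaviour described above can be approximated with plain lists, which makes the role of max_elements_train easy to see. This is a minimal sketch, not the aihwkit implementation: the function names are hypothetical, batches are lists instead of DataLoader objects, and it assumes the limit applies only to the training set (the text above does not say whether validation data is truncated as well).

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def batches(items: Sequence[T], batch_size: int) -> List[List[T]]:
    """Split items into consecutive batches; the last one may be shorter."""
    return [list(items[i:i + batch_size]) for i in range(0, len(items), batch_size)]


def get_data_loaders_sketch(
    train_items: Sequence[T],
    val_items: Sequence[T],
    batch_size: int,
    max_elements_train: int = 0,
) -> Tuple[List[List[T]], List[List[T]]]:
    """Hypothetical stand-in for get_data_loaders() using plain lists."""
    # 0 means "use the full training set", as in the documented parameter.
    if max_elements_train:
        train_items = train_items[:max_elements_train]
    return batches(train_items, batch_size), batches(val_items, batch_size)
```

For example, ten training items with batch_size=4 give batches of sizes 4, 4 and 2, while max_elements_train=5 shrinks that to batches of sizes 4 and 1.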

get_dataset_arguments(dataset)

Return the dataset constructor arguments for selecting the training and validation subsets.

Parameters:

dataset (type) – the dataset class to be used.

Return type:

Tuple[Dict, Dict]
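A minimal sketch of what such per-dataset arguments can look like, assuming datasets are identified by name. In torchvision, SVHN selects its subset with a split string, while MNIST-style datasets such as FashionMNIST take a boolean train flag. The function name and the set of split-style datasets are illustrative assumptions, not necessarily the values aihwkit uses.

```python
from typing import Dict, Tuple

# Assumption: datasets that select their subset with split="train"/"test"
# (as torchvision's SVHN does); all others use a boolean train flag.
SPLIT_STYLE_DATASETS = {"SVHN"}


def get_dataset_arguments_sketch(dataset_name: str) -> Tuple[Dict, Dict]:
    """Return (training kwargs, validation kwargs) for a dataset constructor."""
    if dataset_name in SPLIT_STYLE_DATASETS:
        return {"split": "train"}, {"split": "test"}
    return {"train": True}, {"train": False}
```

Each dictionary would then be unpacked into the dataset constructor, for example `FashionMNIST(root, download=True, **training_arguments)` in torchvision.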

get_dataset_transform(dataset)

Return the dataset transform.

Parameters:

dataset (type) – the dataset class to be used.

Return type:

Any
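The transform is the preprocessing the dataset applies to each sample. As a hypothetical, framework-free sketch, the lookup below maps a dataset name to a function that rescales 8-bit pixels to [0, 1] and optionally normalizes them. The mean and standard deviation of 0.5 are placeholders, not the values aihwkit uses, and a real implementation would return torchvision transform objects instead of plain functions.

```python
from typing import Callable, Dict, List

Transform = Callable[[List[int]], List[float]]


def scale_to_unit(pixels: List[int]) -> List[float]:
    """Map 8-bit pixel values to [0, 1]."""
    return [p / 255.0 for p in pixels]


def normalize(mean: float, std: float) -> Transform:
    """Scale to [0, 1], then subtract the mean and divide by the std."""
    def apply(pixels: List[int]) -> List[float]:
        return [(v - mean) / std for v in scale_to_unit(pixels)]
    return apply


# Hypothetical per-dataset table; 0.5 / 0.5 are placeholder statistics.
TRANSFORMS: Dict[str, Transform] = {"FashionMNIST": normalize(0.5, 0.5)}


def get_dataset_transform_sketch(dataset_name: str) -> Transform:
    """Return the transform registered for a dataset, else plain scaling."""
    return TRANSFORMS.get(dataset_name, scale_to_unit)
```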

get_optimizer(learning_rate, model)

Return the Optimizer for the experiment.

Parameters:
  • learning_rate (float) – the learning rate used by the optimizer.

  • model (Module) – the neural network to be trained.

Returns:

the optimizer to be used in the experiment.

Return type:

Optimizer
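To show what the returned object is expected to do, here is a minimal plain-SGD sketch that exposes the same zero_grad()/step() protocol as a torch Optimizer. The class and function names are hypothetical, parameters are named floats, and gradients are filled in by hand. Which optimizer class get_optimizer() actually returns is not stated here; aihwkit does provide AnalogSGD in aihwkit.optim for use with analog layers.

```python
from typing import Dict


class SGDSketch:
    """Minimal stochastic gradient descent over named scalar parameters."""

    def __init__(self, params: Dict[str, float], lr: float) -> None:
        self.params = params  # shared with the model, so updates are visible there
        self.lr = lr
        self.grads: Dict[str, float] = {name: 0.0 for name in params}

    def zero_grad(self) -> None:
        """Clear accumulated gradients before the next batch."""
        for name in self.grads:
            self.grads[name] = 0.0

    def step(self) -> None:
        """Plain SGD update: p <- p - lr * grad."""
        for name, grad in self.grads.items():
            self.params[name] -= self.lr * grad


def get_optimizer_sketch(learning_rate: float, model: Dict[str, float]) -> SGDSketch:
    """Hypothetical stand-in for get_optimizer(): bind the learning rate to the model."""
    return SGDSketch(model, lr=learning_rate)
```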

run(max_elements=0, dataset_root='/tmp/data', device=None)

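Set up and run the training.

The training results are returned, and the experiment's internal model is updated in the process.

Parameters:
  • max_elements (int) – the maximum number of elements of the dataset to be used. If 0, the full dataset is used.

  • dataset_root (str) – the path to the folder where the files from the dataset are stored.

  • device (device | None) – the torch device used for the model.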

Return type:

List[Dict]
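The description of run() suggests that it wires the other methods together: build the loaders, create the optimizer, then hand everything to train(). The outline below shows that flow under this assumption, with callables standing in for the real methods so that it runs without aihwkit. Device selection and model updates are omitted, and run_sketch is a hypothetical name.

```python
from typing import Callable, Dict, List, Sequence, Tuple


def run_sketch(
    get_data_loaders: Callable[[int], Tuple[Sequence, Sequence]],
    get_optimizer: Callable[[], object],
    train: Callable[[Sequence, Sequence, object], List[Dict]],
    max_elements: int = 0,
) -> List[Dict]:
    """Hypothetical outline of run(): build loaders and optimizer, then train."""
    training_loader, validation_loader = get_data_loaders(max_elements)
    optimizer = get_optimizer()
    # The per-epoch metrics produced by train() are passed back unchanged.
    return train(training_loader, validation_loader, optimizer)
```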

train(training_loader, validation_loader, model, optimizer, loss_function, epochs, device)

Run the training loop.

Parameters:
  • training_loader (DataLoader) – the data loader for the training data.

  • validation_loader (DataLoader) – the data loader for the validation data.

  • model (Module) – the neural network to be trained.

  • optimizer (Optimizer) – the optimizer used for the training.

  • loss_function (_Loss) – the loss function used for training.

  • epochs (int) – the number of epochs for the training.

  • device (device) – the torch device used for the model.

Returns:

A list of the metrics for each epoch.

Return type:

List[Dict]
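A framework-free sketch of an epoch loop in the spirit of train(): each epoch runs a training pass, then a validation pass, and appends one metrics dictionary to the returned list. It assumes a one-weight linear model, a mean-squared-error loss and a hand-written SGD update in place of the loss_function, optimizer and device arguments. All function names and the metric key validation_loss are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple

Batch = Sequence[Tuple[float, float]]  # (x, y) pairs


def training_pass(loader: Sequence[Batch], model: Dict[str, float], lr: float) -> None:
    """One pass over all training batches with SGD on the MSE of y_hat = w * x."""
    for batch in loader:
        # d/dw of mean((w * x - y) ** 2) over the batch.
        grad = sum(2 * (model["w"] * x - y) * x for x, y in batch) / len(batch)
        model["w"] -= lr * grad


def validation_loss(loader: Sequence[Batch], model: Dict[str, float]) -> float:
    """Mean squared error over every validation sample; the model is not changed."""
    errors = [(model["w"] * x - y) ** 2 for batch in loader for x, y in batch]
    return sum(errors) / len(errors)


def train_sketch(
    training_loader: Sequence[Batch],
    validation_loader: Sequence[Batch],
    model: Dict[str, float],
    lr: float,
    epochs: int,
) -> List[Dict[str, float]]:
    """Hypothetical train(): alternate training and validation, log once per epoch."""
    metrics: List[Dict[str, float]] = []
    for epoch in range(epochs):
        training_pass(training_loader, model, lr)
        metrics.append({"epoch": epoch, "validation_loss": validation_loss(validation_loader, model)})
    return metrics
```

Returning one dictionary per epoch matches the documented List[Dict] return type and leaves room for whatever metrics a real loop records.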

training_step(training_loader, model, optimizer, loss_function, device)

Run a single training step.

Parameters:
  • training_loader (DataLoader) – the data loader for the training data.

  • model (Module) – the neural network to be trained.

  • optimizer (Optimizer) – the optimizer used for the training.

  • loss_function (_Loss) – the loss function used for training.

  • device (device) – the torch device used for the model.

Return type:

None
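Because training_step() receives the whole training_loader, this sketch assumes that one call makes a pass over every batch, running the usual clear-gradients, forward, backward and update sequence per batch. The model y = w * x + b, the MSE loss and the hand-derived gradients are assumptions chosen so the example runs without PyTorch; the comments name the torch calls each stage would correspond to.

```python
from typing import Sequence, Tuple

Batch = Sequence[Tuple[float, float]]  # (x, y) pairs


class LinearModel:
    """y_hat = w * x + b, with gradients stored next to the parameters."""

    def __init__(self) -> None:
        self.w, self.b = 0.0, 0.0
        self.grad_w, self.grad_b = 0.0, 0.0


def training_step_sketch(training_loader: Sequence[Batch], model: LinearModel, lr: float) -> None:
    """Hypothetical training_step(): one SGD update per batch, returns None."""
    for batch in training_loader:
        # optimizer.zero_grad(): clear gradients from the previous batch.
        model.grad_w, model.grad_b = 0.0, 0.0
        n = len(batch)
        for x, y in batch:
            # Forward pass, then the derivative of the batch-mean MSE (loss.backward()).
            residual = model.w * x + model.b - y
            model.grad_w += 2 * residual * x / n
            model.grad_b += 2 * residual / n
        # optimizer.step(): move the parameters against the gradient.
        model.w -= lr * model.grad_w
        model.b -= lr * model.grad_b
```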

validation_step(validation_loader, model, loss_function, device)

Run a single evaluation step.

Parameters:
  • validation_loader (DataLoader) – the data loader for the validation data.

  • model (Module) – the neural network to be evaluated.

  • loss_function (_Loss) – the loss function used to compute the validation loss.

  • device (device) – the torch device used for the model.

Return type:

None
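Validation follows the same batch loop without any parameter updates. In this hypothetical sketch, predict returns one score per class, treated as log-probabilities, so the per-sample loss is -scores[label]; this is how NLLLoss (the documented default loss) is computed, averaged over samples with its default mean reduction. Because the documented return type is None, the sketch hands its results to a report callback (a hypothetical parameter that defaults to print) instead of returning them.

```python
from typing import Callable, List, Sequence, Tuple

Batch = Sequence[Tuple[List[float], int]]  # (features, class label) pairs


def validation_step_sketch(
    validation_loader: Sequence[Batch],
    predict: Callable[[List[float]], List[float]],
    report: Callable[[float, float], None] = print,
) -> None:
    """Hypothetical validation_step(): average NLL loss and accuracy, no updates."""
    total_loss = 0.0
    correct = 0
    total = 0
    for batch in validation_loader:
        for features, label in batch:
            scores = predict(features)  # forward pass only; nothing is trained here
            total_loss += -scores[label]  # negative log-likelihood of the true class
            predicted = max(range(len(scores)), key=scores.__getitem__)
            correct += int(predicted == label)
            total += 1
    # Reported instead of returned, matching the documented None return type.
    report(total_loss / total, correct / total)
```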