Source code for aihwkit.experiments.runners.local

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Runner that executes Experiments locally."""

from typing import Dict, List, Optional

from torch import device as torch_device
from torchvision.datasets import FashionMNIST, SVHN

from aihwkit.experiments.experiments.base import Signals
from aihwkit.experiments.runners.base import Runner
from aihwkit.experiments.experiments.training import BasicTraining
from aihwkit.experiments.runners.metrics import LocalMetric


[docs]class LocalRunner(Runner): """Runner that executes Experiments locally. Class that allows executing Experiments locally. """ # pylint: disable=too-few-public-methods def __init__(self, device: Optional[torch_device] = None): """Create a new ``LocalRunner``. Args: device: the device where the model will be moved to. """ self.device = device
[docs] def run( # type: ignore[override] self, experiment: BasicTraining, max_elements_train: int = 0, dataset_root: str = "/tmp/datasets", stdout: bool = False, ) -> List[Dict]: """Run a single Experiment. Executes an experiment locally, in the device specified by ``self.device``, optionally printing information to stdout. Note: If using a dataset different than ``FashionMNIST`` or ``SVHN``, the runner assumes that the files for the dataset are downloaded at ``dataset_root``. For those two datasets, the downloading will take place automatically if the files are not present. Args: experiment: the experiment to be executed. max_elements_train: limit on the amount of samples to use from the dataset. If ``0``, no limit is applied. dataset_root: path for the dataset files. stdout: enable printing to stdout during the execution of the experiment. Returns: A list of dictionaries with information about each epoch. """ # pylint: disable=arguments-differ # Setup the metric helper for the experiment. metric = LocalMetric(stdout=stdout) experiment.clear_hooks() experiment.add_hook(Signals.EPOCH_START, metric.receive_epoch_start) experiment.add_hook(Signals.EPOCH_END, metric.receive_epoch_end) experiment.add_hook(Signals.TRAIN_EPOCH_END, metric.receive_train_epoch_end) experiment.add_hook(Signals.TRAIN_EPOCH_BATCH_END, metric.receive_train_epoch_batch_end) experiment.add_hook( Signals.VALIDATION_EPOCH_BATCH_END, metric.receive_validation_epoch_batch_end ) experiment.add_hook(Signals.VALIDATION_EPOCH_END, metric.receive_validation_epoch_end) # Download the dataset if needed. if experiment.dataset == FashionMNIST: _ = experiment.dataset(dataset_root, download=True) elif experiment.dataset == SVHN: _ = experiment.dataset(dataset_root, download=True, split="train") _ = experiment.dataset(dataset_root, download=True, split="test") return experiment.run(max_elements_train, dataset_root, self.device)