# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Base class for an Experiment Runner."""

from typing import Any, Dict, List, Optional

from import CloudExperiment
from import CredentialsError
from import ApiSession
from import ClientConfiguration
from import InferenceApiClient
from aihwkit.experiments import BasicInferencing
from aihwkit.experiments.runners.base import Runner

class InferenceCloudRunner(Runner):
    """Runner that executes Experiments in the AIHW Composer cloud.

    Class that allows executing Experiments in the cloud.
    """

    # pylint: disable=too-few-public-methods

    def __init__(
        self,
        api_url: Optional[str] = None,
        api_token: Optional[str] = None,
        verify: bool = False,
    ):
        """Create a new ``InferenceCloudRunner``.

        Note:
            If no ``api_token`` or ``api_url`` is provided, this class will
            attempt to read them from the local configuration file (by
            default, at ``~/.config/aihwkit/aihwkit.conf``) or environment
            variables (``AIHW_API_TOKEN``).

        Args:
            api_url: the URL of the AIHW Composer API.
            api_token: the API token for authentication.
            verify: if ``False``, disable the remote server TLS verification.

        Raises:
            CredentialsError: if no credentials could be found.
        """
        # NOTE(review): ``verify`` defaults to ``False``, i.e. TLS verification
        # is off unless the caller opts in. The default is kept for backward
        # compatibility; confirm whether it should be flipped.

        # Attempt to load credentials if not present. Explicit arguments take
        # precedence; the configuration only fills in the missing values.
        if not api_url or not api_token:
            config = ClientConfiguration()
            api_url = api_url or config.url
            api_token = api_token or config.token

        # Still incomplete after consulting the configuration: give up.
        if not api_url or not api_token:
            raise CredentialsError("No credentials could be found")

        self.api_url = api_url
        self.api_token = api_token

        # Authenticate.
        self.session = ApiSession(self.api_url, self.api_token, verify)
        self.api_client = InferenceApiClient(self.session)

    def get_cloud_experiment(self, id_: str) -> CloudExperiment:
        """Return a single cloud experiment by id.

        Args:
            id_: the identifier of the cloud experiment.

        Returns:
            A ``CloudExperiment``.
        """
        return self.api_client.experiment_get(id_)

    def list_cloud_experiments(self) -> List[CloudExperiment]:
        """Return a list of cloud experiments.

        Returns:
            A list of ``CloudExperiments``.
        """
        return self.api_client.experiments_list()

    def run(  # type: ignore[override]
        self,
        experiment: BasicInferencing,
        analog_info: Dict,
        noise_model_info: Dict,
        name: str = "",
        device: str = "gpu",
        **_: Any,
    ) -> CloudExperiment:
        """Run a single Experiment.

        Starts the execution of an Experiment in the cloud. Upon successful
        invocation, this method will return a ``CloudExperiment`` object that
        can be used for inspecting the status of the remote execution.

        Note:
            Please be aware that the ``experiment`` is subjected to some
            constraints compared to local running of experiments.

        Args:
            experiment: the experiment to be executed.
            analog_info: analog information, passed through to the API client.
            noise_model_info: noise model information, passed through to the
                API client.
            name: an optional name for the experiment. If empty, a name is
                generated from the experiment's dataset name and the number
                of layers in its model.
            device: the desired device.
            _: extra keyword arguments; accepted but not used by this runner.

        Returns:
            A ``CloudExperiment`` which represents the remote experiment.
        """
        # pylint: disable=arguments-differ

        # Generate an experiment name if not given.
        if not name:
            name = (
                f"aihwkit inference cloud experiment ({experiment.dataset.__name__}, "
                f"{len(experiment.model)} layers)"
            )

        return self.api_client.experiment_create(
            experiment, analog_info, noise_model_info, name, device
        )