Source code for aihwkit.cloud.client.v1.parsers

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Parsers for the AIHW Composer API."""

from datetime import datetime, timezone
from typing import Any, Dict

from aihwkit.cloud.client.entities import (
    CloudExperiment,
    CloudExperimentCategory,
    CloudJobStatus,
    CloudJob,
)
from aihwkit.cloud.client.exceptions import InvalidResponseFieldError


class ExperimentParser:
    """Parser for Experiment API responses.

    All methods are static: each one receives the decoded API response (or
    a fragment of it) and converts it into the matching client entity or
    enum value.
    """

    @staticmethod
    def parse_experiment(api_response: Dict, api_client: Any) -> CloudExperiment:
        """Return a CloudExperiment from an API response.

        The ``id``, ``name``, ``category`` and ``createdAt`` entries are
        always read. The ``input`` and ``job`` entries are optional and are
        only used when they are present and truthy.

        Args:
            api_response: the response from the API.
            api_client: the client to be used in API requests.

        Returns:
            A `CloudExperiment` based on the response. Some of the fields
            might not be populated if they are not present in the response.

        Raises:
            KeyError: if one of the mandatory entries is missing.
            ValueError: if ``createdAt`` does not match the expected date
                format.
            InvalidResponseFieldError: if the category, or the status of the
                nested job, is not recognized.
        """
        experiment = CloudExperiment(
            _api_client=api_client,
            id_=api_response["id"],
            name=api_response["name"],
            category=ExperimentParser.parse_experiment_category(api_response),
            created_at=ExperimentParser.parse_date_string(api_response["createdAt"]),
            input_id=None,
            job=None,
        )

        # The input id is only set when both the ``input`` entry and its
        # ``id`` are present and truthy; otherwise it stays ``None``.
        input_info = api_response.get("input")
        if input_info and input_info.get("id"):
            experiment.input_id = input_info["id"]

        job_info = api_response.get("job")
        if job_info:
            experiment.job = ExperimentParser.parse_job(job_info)

        return experiment

    @staticmethod
    def parse_job(api_response: Dict) -> CloudJob:
        """Return a CloudJob from an API response.

        Args:
            api_response: the response from the API.

        Returns:
            A `CloudJob` based on the response. Some of the fields might not
            be populated if they are not present in the response.

        Raises:
            KeyError: if the ``id`` or ``status`` entry is missing.
            InvalidResponseFieldError: if the status is not recognized.
        """
        job = CloudJob(
            id_=api_response["id"],
            output_id=None,
            status=ExperimentParser.parse_experiment_status(api_response),
        )

        # ``output`` is optional; the value is stored as-is when truthy.
        output = api_response.get("output")
        if output:
            job.output_id = output

        return job

    @staticmethod
    def parse_experiment_status(api_response: Dict) -> CloudJobStatus:
        """Return an Experiment status from an API response.

        Several API status strings collapse into one enum member:
        ``waiting``, ``validating`` and ``validated`` map to ``WAITING``;
        ``failed`` and ``cancelled`` map to ``FAILED``.

        Args:
            api_response: the response from the API.

        Returns:
            A value from the `CloudJobStatus` enum.

        Raises:
            KeyError: if the ``status`` entry is missing.
            InvalidResponseFieldError: if the API response contains an
                unrecognized status code.
        """
        job_status = api_response["status"]

        if job_status in ("waiting", "validating", "validated"):
            return CloudJobStatus.WAITING
        if job_status in ("running",):
            return CloudJobStatus.RUNNING
        if job_status in ("failed", "cancelled"):
            return CloudJobStatus.FAILED
        if job_status in ("completed",):
            return CloudJobStatus.COMPLETED

        raise InvalidResponseFieldError("Unsupported job status: {}".format(job_status))

    @staticmethod
    def parse_experiment_category(api_response: Dict) -> CloudExperimentCategory:
        """Return an Experiment category from an API response.

        ``train`` and ``trainweb`` map to ``BASIC_TRAINING``; ``inference``
        and ``inferenceweb`` map to ``BASIC_INFERENCE``.

        Args:
            api_response: the response from the API.

        Returns:
            A value from the `CloudExperimentCategory` enum.

        Raises:
            KeyError: if the ``category`` entry is missing.
            InvalidResponseFieldError: if the API response contains an
                unrecognized category.
        """
        job_category = api_response["category"]

        if job_category in ("train", "trainweb"):
            return CloudExperimentCategory.BASIC_TRAINING
        if job_category in ("inference", "inferenceweb"):
            return CloudExperimentCategory.BASIC_INFERENCE

        raise InvalidResponseFieldError(
            "Unsupported experiment category: {}".format(job_category)
        )

    @staticmethod
    def parse_date_string(date_string: str) -> datetime:
        """Return a datetime from a date string.

        The string must look like ``2021-03-04T05:06:07.123Z``: fractional
        seconds and the literal ``Z`` suffix are both required by the
        format used for parsing.

        Args:
            date_string: the date string from the API.

        Returns:
            A timezone-aware `datetime` with its timezone set to UTC.

        Raises:
            ValueError: if the string does not match the expected format.
        """
        tmp_datetime = datetime.strptime(date_string, "%Y-%m-%dT%H:%M:%S.%fZ")
        # ``strptime`` yields a naive datetime here (the ``Z`` is matched as
        # a literal character), so the UTC timezone is attached explicitly.
        return tmp_datetime.replace(tzinfo=timezone.utc)
class GeneralParser:
    """Parser for responses that are not tied to a specific entity."""

    # pylint: disable=too-few-public-methods

    @staticmethod
    def parse_login(api_response: Dict) -> str:
        """Extract the jwt token from a login API response.

        Args:
            api_response: the response from the API.

        Returns:
            The value stored under the ``jwt`` key of the response.
        """
        token = api_response["jwt"]
        return token