Source code for aihwkit.cloud.client.v1.parsers

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Parsers for the AIHW Composer API."""

from datetime import datetime, timezone
from typing import Any, Dict

from aihwkit.cloud.client.entities import (
    CloudExperiment, CloudExperimentCategory, CloudJobStatus, CloudJob
)
from aihwkit.cloud.client.exceptions import InvalidResponseFieldError


class ExperimentParser:
    """Parser for Experiment API responses."""

    @staticmethod
    def parse_experiment(api_response: Dict, api_client: Any) -> CloudExperiment:
        """Return a CloudExperiment from an API response.

        Args:
            api_response: the response from the API.
            api_client: the client to be used in API requests.

        Returns:
            A `CloudExperiment` based on the response. Some of the fields
            (``input_id``, ``job``) might not be populated if they are not
            present in the response.
        """
        experiment = CloudExperiment(
            _api_client=api_client,
            id_=api_response['id'],
            name=api_response['name'],
            category=ExperimentParser.parse_experiment_category(api_response),
            created_at=ExperimentParser.parse_date_string(api_response['createdAt']),
            input_id=None,
            job=None
        )

        # Use `or {}` rather than a `.get()` default: the key may be present
        # with a null value, in which case `.get('input', {})` returns None.
        input_data = api_response.get('input') or {}
        if input_data.get('id'):
            experiment.input_id = input_data['id']

        job_data = api_response.get('job')
        if job_data:
            experiment.job = ExperimentParser.parse_job(job_data)

        return experiment

    @staticmethod
    def parse_job(api_response: Dict) -> CloudJob:
        """Return a CloudJob from an API response.

        Args:
            api_response: the response from the API.

        Returns:
            A `CloudJob` based on the response. The ``output_id`` field
            might not be populated if it is not present in the response.

        Raises:
            InvalidResponseFieldError: if the response contains an
                unrecognized job status.
        """
        job = CloudJob(
            id_=api_response['id'],
            output_id=None,
            status=ExperimentParser.parse_experiment_status(api_response)
        )

        # NOTE(review): unlike the experiment ``input`` (where ``['id']`` is
        # read), the ``output`` value is stored as-is; presumably the API
        # returns the id directly here - confirm against the API schema.
        output = api_response.get('output')
        if output:
            job.output_id = output

        return job

    @staticmethod
    def parse_experiment_status(api_response: Dict) -> CloudJobStatus:
        """Return a job status from an API response.

        Args:
            api_response: the response from the API.

        Returns:
            A value from the `CloudJobStatus` enum.

        Raises:
            InvalidResponseFieldError: if the API response contains an
                unrecognized status code.
        """
        job_status = api_response['status']

        if job_status in ('waiting', 'validating', 'validated'):
            return CloudJobStatus.WAITING
        if job_status in ('running',):
            return CloudJobStatus.RUNNING
        # Both failed and cancelled jobs are reported as FAILED.
        if job_status in ('failed', 'cancelled'):
            return CloudJobStatus.FAILED
        if job_status in ('completed',):
            return CloudJobStatus.COMPLETED

        raise InvalidResponseFieldError('Unsupported job status: {}'.format(job_status))

    @staticmethod
    def parse_experiment_category(api_response: Dict) -> CloudExperimentCategory:
        """Return an Experiment category from an API response.

        Args:
            api_response: the response from the API.

        Returns:
            A value from the `CloudExperimentCategory` enum.

        Raises:
            InvalidResponseFieldError: if the API response contains an
                unrecognized category.
        """
        job_category = api_response['category']

        if job_category in ('train', 'trainweb'):
            return CloudExperimentCategory.BASIC_TRAINING

        raise InvalidResponseFieldError(
            'Unsupported experiment category: {}'.format(job_category))

    @staticmethod
    def parse_date_string(date_string: str) -> datetime:
        """Return a datetime from a date string.

        The string is expected to be in ``YYYY-MM-DDTHH:MM:SS[.ffffff]Z``
        form; the trailing ``Z`` is matched literally and the result is
        tagged as UTC.

        Args:
            date_string: the date string from the API.

        Returns:
            A timezone-aware `datetime` in UTC.

        Raises:
            ValueError: if the string matches none of the accepted formats.
        """
        try:
            tmp_datetime = datetime.strptime(date_string, '%Y-%m-%dT%H:%M:%S.%fZ')
        except ValueError:
            # Also accept timestamps serialized without fractional seconds,
            # e.g. '2021-01-01T00:00:00Z'.
            tmp_datetime = datetime.strptime(date_string, '%Y-%m-%dT%H:%M:%SZ')

        return tmp_datetime.replace(tzinfo=timezone.utc)
class GeneralParser:
    """Parsers for API responses that are not tied to a specific entity."""

    # pylint: disable=too-few-public-methods

    @staticmethod
    def parse_login(api_response: Dict) -> str:
        """Extract the jwt token from a login API response.

        Args:
            api_response: the decoded response returned by the API.

        Returns:
            The value stored under the ``jwt`` key of the response.

        Raises:
            KeyError: if the response has no ``jwt`` key.
        """
        token = api_response['jwt']
        return token