# -*- coding: utf-8 -*-
# (C) Copyright 2020, 2021, 2022 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Session handler for the AIHW Composer API."""
from typing import Any, Optional, Text, Union
from requests import HTTPError, Session
import urllib3
from urllib3.exceptions import InsecureRequestWarning
from aihwkit.version import __version__
from aihwkit.cloud.client.exceptions import ApiResponseError, ResponseError
class ObjectStorageSession(Session):
    """Session for requests to object storage.

    Behaves like a plain ``requests.Session``, except that every response is
    checked with ``raise_for_status()``. A failing response is reported as a
    :class:`ResponseError` instead of being returned.
    """

    def request(
            self,
            method: str,
            url: Union[str, bytes, Text],
            *args: Any,
            **kwargs: Any
    ) -> Any:
        """Send a request and validate the status of its response.

        Args:
            method: method for the new ``Request`` object.
            url: URL for the new ``Request`` object, forwarded unchanged.
            args: extra positional arguments forwarded to
                ``requests.Session.request``.
            kwargs: extra keyword arguments forwarded to
                ``requests.Session.request``.

        Returns:
            The ``Response`` object, when ``raise_for_status()`` accepts it.

        Raises:
            ResponseError: if ``raise_for_status()`` rejected the response.
        """
        # pylint: disable=signature-differs
        reply = super().request(method, url, *args, **kwargs)
        try:
            reply.raise_for_status()
        except HTTPError:
            # Report the failure through the client's own exception type;
            # the original HTTPError context is intentionally suppressed.
            raise ResponseError(reply) from None
        else:
            return reply
class ApiSession(Session):
    """Session handler for requests to the AIHW Composer API.

    Custom ``Session`` for interfacing with the AIHW Composer API, using:

    * authorization based on jwt token.
    * custom user agent for the requests.

    Additionally, this class stores information about the API URL and base
    token.

    Args:
        api_url: base URL of the API. The ``url`` passed to ``request()`` is
            treated as a path relative to it.
        api_token: base API token, stored on the session as ``api_token``.
        verify: value assigned to the session's ``verify`` attribute. When
            ``False``, ``urllib3``'s ``InsecureRequestWarning`` is disabled
            (process-wide, via ``urllib3.disable_warnings``).
    """

    def __init__(
            self,
            api_url: str,
            api_token: str,
            verify: bool = True
    ):
        super().__init__()
        self.api_url = api_url
        self.api_token = api_token
        self.verify = verify
        if not verify:
            # Unverified HTTPS requests would otherwise trigger urllib3's
            # InsecureRequestWarning.
            urllib3.disable_warnings(InsecureRequestWarning)
        # Set by ``update_jwt_token()``; ``None`` until then.
        self.jwt_token = None  # type: Optional[str]
        self.headers.update({'User-Agent': 'aihwkit/{}'.format(__version__)})

    def update_jwt_token(self, jwt_token: str) -> None:
        """Set the jwt token for the session.

        The token is stored in ``jwt_token`` and sent on subsequent requests
        as a ``Bearer`` ``Authorization`` header.

        Args:
            jwt_token: the jwt token to use.
        """
        self.jwt_token = jwt_token
        self.headers.update({'Authorization': 'Bearer {}'.format(jwt_token)})

    def request(
            self,
            method: str,
            url: Union[str, bytes, Text],
            *args: Any,
            **kwargs: Any
    ) -> Any:
        """Construct a Request, prepares it and sends it.

        The ``url`` is interpreted as a path relative to ``api_url``. It is
        joined to it with exactly one ``/`` separator.

        Args:
            method: method for the new ``Request`` object.
            url: path, relative to ``api_url``, for the new ``Request``
                object. ``bytes`` values are decoded as UTF-8.
            args: additional arguments for the original ``requests`` method.
            kwargs: additional arguments for the original ``requests`` method.

        Returns:
            A new ``Response`` object.

        Raises:
            ApiResponseError: if the response did not have a valid status code.
        """
        # pylint: disable=signature-differs
        if isinstance(url, bytes):
            # ``str(b'jobs')`` would yield "b'jobs'" and hit a bogus endpoint.
            path = url.decode('utf-8')
        else:
            path = str(url)
        # Avoid "//" when api_url has a trailing slash or path a leading one.
        full_url = '{}/{}'.format(self.api_url.rstrip('/'), path.lstrip('/'))
        response = super().request(method, full_url, *args, **kwargs)
        try:
            response.raise_for_status()
        except HTTPError:
            raise ApiResponseError(response) from None
        return response