# Source code for aihwkit.experiments.runners.i_metrics

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# Licensed under the MIT license. See LICENSE file in the project root for details.

"""Helper for retrieving Metrics of an Experiment."""

from datetime import datetime, timezone
import json
from typing import Dict


class InferenceLocalMetric:
    """Metric used by the InferenceWorker Runner.

    Receives the ``INFERENCE_REPEAT_START`` / ``INFERENCE_REPEAT_END`` hooks
    and builds a JSON-serializable "partial" result dictionary for every
    finished repeat, optionally echoing it to stdout as one JSON line.
    """

    def __init__(self, stdout: bool = False) -> None:
        """Create the metric.

        Args:
            stdout: if ``True``, each partial result returned by
                :meth:`receive_repeat_end` is also printed to stdout as JSON.
        """
        # State of the repeat currently in progress; filled by
        # ``receive_repeat_start``.
        self.current_repeat: Dict = {}
        # Naive UTC timestamp used as the origin for ``time_elapsed``. Kept
        # naive (same semantics as the former ``datetime.utcnow()``) so any
        # existing consumer of this attribute is unaffected.
        self.time_init = self._utcnow()
        self.stdout = stdout

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Drop-in replacement for ``datetime.utcnow()``, which is deprecated
        since Python 3.12.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def receive_repeat_start(self, repeat: int) -> None:
        """Hook for `INFERENCE_REPEAT_START`.

        Args:
            repeat: number of the repeat that is starting; stored and later
                reported (plus one) by :meth:`receive_repeat_end`.
        """
        self.current_repeat = {"number": repeat, "inference_results": []}

    def receive_repeat_end(
        self,
        t_inference_array: list,
        avg_acc_arr: list,
        std_acc_arr: list,
        avg_err_arr: list,
        avg_loss_arr: list,
        inference_repeats: int,
    ) -> Dict:
        """Hook for `INFERENCE_REPEAT_END`.

        Builds one result entry per element of ``t_inference_array``, pairing
        it with the value at the same position in each metric array.

        Args:
            t_inference_array: inference times, one entry per result.
            avg_acc_arr: average accuracy for each inference time.
            std_acc_arr: standard deviation of the accuracy for each
                inference time.
            avg_err_arr: average error for each inference time.
            avg_loss_arr: average loss for each inference time.
            inference_repeats: total number of repeats; the result is flagged
                as partial while the current repeat number is below it.

        Returns:
            ``{"inference_runs": {"inference_repeat", "is_partial",
            "time_elapsed", "inference_results"}}``, where ``time_elapsed`` is
            the number of seconds since this metric was created.

        Raises:
            KeyError: if called before :meth:`receive_repeat_start`.
            IndexError: if a metric array is shorter than
                ``t_inference_array``.
        """
        # Index the metric arrays by position (rather than zip) so a
        # too-short array raises instead of silently truncating the results.
        inf_results = [
            {
                "t_inference": t_inference,
                "avg_accuracy": avg_acc_arr[idx],
                "std_accuracy": std_acc_arr[idx],
                "avg_error": avg_err_arr[idx],
                "avg_loss": avg_loss_arr[idx],
            }
            for idx, t_inference in enumerate(t_inference_array)
        ]

        # Reported repeat number is the stored one plus one
        # (presumably the start hook receives a zero-based index).
        repeat = self.current_repeat["number"] + 1
        time_elapsed = (self._utcnow() - self.time_init).total_seconds()
        # bool() normalizes e.g. numpy booleans so json.dumps can serialize it.
        is_partial = bool(repeat < inference_repeats)

        partial = {
            "inference_runs": {
                "inference_repeat": repeat,
                "is_partial": is_partial,
                "time_elapsed": time_elapsed,
                "inference_results": inf_results,
            }
        }

        if self.stdout:
            print(json.dumps(partial))

        return partial