Source code for aihwkit.experiments.runners.metrics

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Helper for retrieving Metrics of an Experiment."""

from datetime import datetime
from typing import Dict, List


class LocalMetric:
    """Metric for local experiments.

    Metric for local execution of experiments. Output to stdout can be
    controlled by the ``stdout`` parameter to the constructor.

    The ``receive_*`` hooks accumulate per-epoch counters in
    ``current_epoch``; ``receive_epoch_end`` appends that dictionary to
    ``epochs`` and returns the averaged results.

    Args:
        stdout: if ``True``, the ``*_end`` hooks print a summary line.
    """

    def __init__(self, stdout: bool = True) -> None:
        # One accumulator dictionary per finished epoch, in completion order.
        self.epochs: List[Dict] = []
        # Accumulator of the running epoch; reset by ``receive_epoch_start``.
        self.current_epoch: Dict = {}
        self.stdout = stdout

    @staticmethod
    def _ratio(numerator: float, denominator: float) -> float:
        """Return ``numerator / denominator``, or ``0.0`` for an empty denominator.

        An epoch may legitimately finish without any training or validation
        batch (for example a train-only run), in which case the image
        counters are still zero and a plain division would raise
        ``ZeroDivisionError``.
        """
        if not denominator:
            return 0.0
        return numerator / denominator

    def receive_epoch_start(self, epoch: int) -> None:
        """Hook for `EPOCH_START`.

        Resets the accumulator of the current epoch.

        Args:
            epoch: number of the epoch that is starting.
        """
        self.current_epoch = {
            'number': epoch,
            # NOTE: ``datetime.utcnow()`` is kept on purpose so that the
            # stored ``start_time`` stays a naive datetime, as before.
            'start_time': datetime.utcnow(),
            'total_loss': 0,
            'training_images': 0,
            'validation_images': 0,
            'validation_correct': 0,
            'validation_loss': 0
        }

    def receive_train_epoch_batch_end(
            self,
            total: int,
            train_loss: float
    ) -> None:
        """Hook for `TRAIN_EPOCH_BATCH_END`.

        Args:
            total: number of images in the batch.
            train_loss: loss value to add to the epoch total.
        """
        self.current_epoch['training_images'] += total
        self.current_epoch['total_loss'] += train_loss

    def receive_validation_epoch_batch_end(
            self,
            total: int,
            correct: int,
            validation_loss: float
    ) -> None:
        """Hook for `VALIDATION_EPOCH_BATCH_END`.

        Args:
            total: number of images in the batch.
            correct: number of correct predictions in the batch; it is
                converted with ``int()`` before being accumulated.
            validation_loss: loss value to add to the epoch total.
        """
        self.current_epoch['validation_images'] += total
        self.current_epoch['validation_correct'] += int(correct)
        self.current_epoch['validation_loss'] += validation_loss

    def receive_train_epoch_end(self) -> None:
        """Hook for `TRAIN_EPOCH_END`.

        Prints the average training loss of the epoch when ``stdout`` is
        enabled (``0.0`` if no training image was received).
        """
        if not self.stdout:
            return

        print('Epoch: {}, loss: {:.8f}'.format(
            self.current_epoch['number'],
            self._ratio(self.current_epoch['total_loss'],
                        self.current_epoch['training_images']),
        ))

    def receive_validation_epoch_end(self) -> None:
        """Hook for `VALIDATION_EPOCH_END`.

        Prints the validation summary of the epoch when ``stdout`` is
        enabled (ratios are ``0.0`` if no validation image was received).
        """
        if not self.stdout:
            return

        images = self.current_epoch['validation_images']
        print('Number of images: {}, accuracy: {:.6%}, validation loss: {:.8f}'.format(
            images,
            self._ratio(self.current_epoch['validation_correct'], images),
            self._ratio(self.current_epoch['validation_loss'], images),
        ))

    def receive_epoch_end(self) -> Dict:
        """Hook for `EPOCH_END`.

        Stores the accumulator of the finished epoch in ``epochs`` and,
        when ``stdout`` is enabled, prints the elapsed time.

        Returns:
            A dictionary with the keys ``epoch``, ``time_epoch`` (seconds),
            ``accuracy``, ``train_loss`` and ``valid_loss``. The three
            averages are ``0.0`` when the matching image counter is zero.
        """
        end_time = datetime.utcnow()
        time_epoch = (end_time - self.current_epoch['start_time']).total_seconds()

        if self.stdout:
            print('Time for epoch {}: {:.4}s'.format(self.current_epoch['number'], time_epoch))

        self.epochs.append(self.current_epoch)

        training_images = self.current_epoch['training_images']
        validation_images = self.current_epoch['validation_images']
        return {
            'epoch': self.current_epoch['number'],
            'time_epoch': time_epoch,
            'accuracy': self._ratio(self.current_epoch['validation_correct'],
                                    validation_images),
            'train_loss': self._ratio(self.current_epoch['total_loss'],
                                      training_images),
            'valid_loss': self._ratio(self.current_epoch['validation_loss'],
                                      validation_images)
        }