# -*- coding: utf-8 -*-
# (C) Copyright 2020, 2021, 2022 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Helper for retrieving Metrics of an Experiment."""
from datetime import datetime, timezone
from typing import Dict, List
class LocalMetric:
    """Metric for local experiments.

    Metric for local execution of experiments. Output to stdout can be
    controlled by the ``stdout`` parameter to the constructor.

    The ``receive_*`` hooks accumulate per-batch counters into
    ``current_epoch``; ``receive_epoch_end`` appends the finished epoch to
    ``epochs`` and returns a summary with the averaged metrics.

    Args:
        stdout: if ``True``, print progress information at the end of the
            training, validation and overall epoch phases.
    """

    def __init__(self, stdout: bool = True) -> None:
        # Finished epochs, in the order ``receive_epoch_end`` was called.
        self.epochs: List[Dict] = []
        # Accumulators for the epoch in progress (replaced on EPOCH_START).
        self.current_epoch: Dict = {}
        self.stdout = stdout

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Equivalent to ``datetime.utcnow()`` (deprecated since Python 3.12)
        while keeping the naive representation stored in ``start_time``.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _average(total: float, count: int) -> float:
        """Return ``total / count``, or NaN when ``count`` is zero.

        An epoch can end without any training or validation batches (for
        example when validation is skipped); NaN marks the metric as
        undefined instead of raising ``ZeroDivisionError``.
        """
        if count == 0:
            return float('nan')
        return total / count

    def receive_epoch_start(self, epoch: int) -> None:
        """Hook for `EPOCH_START`.

        Start a fresh set of per-epoch accumulators and record the start time.

        Args:
            epoch: number of the epoch that is starting.
        """
        self.current_epoch = {
            'number': epoch,
            'start_time': self._utcnow(),
            'total_loss': 0,
            'training_images': 0,
            'validation_images': 0,
            'validation_correct': 0,
            'validation_loss': 0
        }

    def receive_train_epoch_batch_end(
            self,
            total: int,
            train_loss: float
    ) -> None:
        """Hook for `TRAIN_EPOCH_BATCH_END`.

        Args:
            total: number of images in the training batch.
            train_loss: loss reported for the batch. It is summed as-is and
                later divided by the number of training images.
        """
        # NOTE(review): the division by image count yields a per-image mean
        # only if ``train_loss`` is the batch's summed loss — confirm caller.
        self.current_epoch['training_images'] += total
        self.current_epoch['total_loss'] += train_loss

    def receive_validation_epoch_batch_end(
            self,
            total: int,
            correct: int,
            validation_loss: float
    ) -> None:
        """Hook for `VALIDATION_EPOCH_BATCH_END`.

        Args:
            total: number of images in the validation batch.
            correct: number of correctly classified images in the batch;
                converted with ``int()`` before accumulating.
            validation_loss: loss reported for the batch, summed as-is.
        """
        self.current_epoch['validation_images'] += total
        self.current_epoch['validation_correct'] += int(correct)
        self.current_epoch['validation_loss'] += validation_loss

    def receive_train_epoch_end(self) -> None:
        """Hook for `TRAIN_EPOCH_END`.

        Print the average training loss of the epoch (if ``stdout``).
        """
        if not self.stdout:
            return
        print('Epoch: {}, loss: {:.8f}'.format(
            self.current_epoch['number'],
            self._average(self.current_epoch['total_loss'],
                          self.current_epoch['training_images']),
        ))

    def receive_validation_epoch_end(self) -> None:
        """Hook for `VALIDATION_EPOCH_END`.

        Print the validation image count, accuracy and average validation
        loss of the epoch (if ``stdout``).
        """
        if not self.stdout:
            return
        images = self.current_epoch['validation_images']
        print('Number of images: {}, accuracy: {:.6%}, validation loss: {:.8f}'.format(
            images,
            self._average(self.current_epoch['validation_correct'], images),
            self._average(self.current_epoch['validation_loss'], images),
        ))

    def receive_epoch_end(self) -> Dict:
        """Hook for `EPOCH_END`.

        Store the finished epoch in ``epochs`` and summarize it.

        Returns:
            dict with keys ``epoch``, ``time_epoch`` (seconds since
            ``EPOCH_START``), ``accuracy``, ``train_loss`` and ``valid_loss``.
            Averages over zero images are reported as NaN.
        """
        end_time = self._utcnow()
        time_epoch = (end_time - self.current_epoch['start_time']).total_seconds()
        if self.stdout:
            print('Time for epoch {}: {:.4}s'.format(self.current_epoch['number'],
                                                     time_epoch))

        self.epochs.append(self.current_epoch)

        validation_images = self.current_epoch['validation_images']
        return {
            'epoch': self.current_epoch['number'],
            'time_epoch': time_epoch,
            'accuracy': self._average(self.current_epoch['validation_correct'],
                                      validation_images),
            'train_loss': self._average(self.current_epoch['total_loss'],
                                        self.current_epoch['training_images']),
            'valid_loss': self._average(self.current_epoch['validation_loss'],
                                        validation_images)
        }