# Source code for aihwkit.cloud.converter.v1.noise_model_info

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

"""Noise model info in rpu_config to neural network model"""

# pylint: disable=no-name-in-module,import-error
from aihwkit.cloud.converter.definitions.i_input_file_pb2 import (  # type: ignore[attr-defined]
    NoiseModelProto,
)


# pylint: disable=too-many-instance-attributes
class NoiseModelInfo:
    """Data-only holder for the fields of a protobuf ``NoiseModelProto`` message.

    The message carries either a ``pcm`` or a ``generic`` sub-message (selected
    through the ``item`` oneof). Both share the common noise fields copied onto
    this object. Only the generic variant has the two extra fields
    ``drift_mean`` and ``drift_std``. Those are exposed through properties that
    refuse access unless ``device_id`` equals ``GENERIC``.
    """

    PCM = "pcm"
    GENERIC = "generic"

    def __init__(self, nm_proto: "NoiseModelProto"):  # type: ignore[valid-type]
        """Copy the fields of ``nm_proto`` onto this object.

        Args:
            nm_proto: the ``NoiseModelProto`` message to read. Its ``item``
                oneof selects which sub-message the fields are taken from.
        """
        type_ = nm_proto.WhichOneof("item")  # type: ignore[attr-defined]
        if type_ == self.PCM:
            # pcm does NOT have the 2 extra drift fields.
            info = nm_proto.pcm  # type: ignore[attr-defined]
        else:
            # Any other oneof result is read from the generic sub-message,
            # which HAS the 2 extra drift fields.
            info = nm_proto.generic  # type: ignore[attr-defined]

        self.device_id = info.device_id
        self.programming_noise_scale = info.programming_noise_scale
        self.read_noise_scale = info.read_noise_scale
        self.drift_scale = info.drift_scale
        self.drift_compensation = info.drift_compensation
        self.poly_first_order_coef = info.poly_first_order_coef
        self.poly_second_order_coef = info.poly_second_order_coef
        self.poly_constant_coef = info.poly_constant_coef

        # Placeholders for non-generic devices. The public properties refuse
        # access to them unless device_id is GENERIC.
        self._drift_mean = -1.1
        self._drift_std = -1.1
        # The generic device has two extra fields.
        # NOTE(review): this keys off info.device_id rather than the oneof
        # type; a pcm sub-message reporting device_id "generic" would fail
        # here on the missing drift_mean field — confirm that cannot occur.
        if info.device_id == self.GENERIC:
            self._drift_mean = info.drift_mean
            self._drift_std = info.drift_std

    def _assert_generic(self) -> None:
        """Ensure this noise model describes a generic device.

        Raises:
            AssertionError: if ``device_id`` is not ``GENERIC``. The error is
                raised explicitly, not through an ``assert`` statement, so the
                check still runs when Python is started with ``-O``.
        """
        if self.device_id != self.GENERIC:
            raise AssertionError("device_id does not have value '{}'".format(self.GENERIC))

    @property
    def drift_mean(self) -> float:
        """Return ``drift_mean``; only allowed for a generic device.

        Raises:
            AssertionError: if this is not a generic device.
        """
        self._assert_generic()
        return self._drift_mean

    @property
    def drift_std(self) -> float:
        """Return ``drift_std``; only allowed for a generic device.

        Raises:
            AssertionError: if this is not a generic device.
        """
        self._assert_generic()
        return self._drift_std
# pylint: enable=too-many-instance-attributes