# -*- coding: utf-8 -*-
# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Utilities for resistive processing units configurations."""
from sys import version_info
from typing import Any, List, Optional, Type
from dataclasses import Field, fields, is_dataclass
from enum import Enum
from textwrap import indent
from aihwkit.simulator import rpu_base
from aihwkit.exceptions import ConfigError
from .enums import RPUDataType
# ``typing.get_origin`` was added in Python 3.8. Compare the version as a
# tuple so that the check stays correct for any later major/minor version
# (comparing major and minor separately would reject e.g. a ``4.0``).
if version_info >= (3, 8):
    # pylint: disable=no-name-in-module, ungrouped-imports
    from typing import get_origin  # type: ignore

    # Field type checking in ``parameters_to_bindings`` is only enabled when
    # ``get_origin`` is available.
    HAS_ORIGIN = True
else:
    HAS_ORIGIN = False

# Name of the attribute that, when truthy, hides all other fields in the
# pretty-printed representation of a config dataclass.
ALL_SKIP_FIELD = "is_perfect"

# Tile field names that are stored under a different name in the bindings.
FIELD_MAP = {"forward": "forward_io", "backward": "backward_io"}

# Tile fields that are always transferred to the bindings, regardless of
# their ``bindings_include`` metadata.
ALWAYS_INCLUDE = ["forward", "backward", "update"]
def get_bindings_class(params: Any, data_type: RPUDataType) -> Optional[Type]:
    """Return the bindings class referenced by a parameter dataclass.

    The class is taken from the ``bindings_class`` attribute of ``params``.
    If that attribute is a string, it is looked up by name in the
    ``rpu_base`` sub-module given by ``params.bindings_module`` (``"devices"``
    if not set). For data types other than ``RPUDataType.FLOAT`` the lookup
    is done in the nested namespace named after ``data_type.value``.

    Args:
        params: parameter dataclass
        data_type: RPUDataType to use

    Returns:
        the C++ binding class, or ``None`` if ``params`` has no
        ``bindings_class`` (or it is set to ``None``)

    Raises:
        ConfigError: if the requested data type or class is not found
    """
    bindings_class = getattr(params, "bindings_class", None)
    if bindings_class is None:
        return None

    if not isinstance(bindings_class, str):
        # Not given by name: hand it back unchanged.
        return bindings_class

    # Given by name: resolve it inside the bindings module.
    module = getattr(rpu_base, getattr(params, "bindings_module", "devices"))
    if data_type != RPUDataType.FLOAT:
        if not hasattr(module, data_type.value):
            raise ConfigError(
                f"Cannot find requested data_type '{data_type.value}' in rpu_base module. "
            )
        module = getattr(module, data_type.value)

    param_class = getattr(module, bindings_class, None)
    if param_class is None:
        # The exception was previously constructed but never raised, so a
        # missing class silently resulted in ``None``.
        raise ConfigError(f"Cannot find requested class '{bindings_class}' in rpu_base module. ")
    return param_class
def parameters_to_bindings(params: Any, data_type: RPUDataType, check_fields: bool = True) -> Any:
    """Translate a parameter dataclass into its bindings object.

    Every dataclass field of ``params`` is copied onto a fresh instance of
    the bindings class, except for ``unit_cell_devices``, ``device`` and the
    names listed in ``params.bindings_ignore``. Enum values are replaced by
    the equally named enum found in ``rpu_base.tiles`` (or
    ``rpu_base.devices`` otherwise), and nested dataclasses that have a
    ``bindings_class`` attribute are converted recursively.

    Args:
        params: parameter dataclass
        data_type: RPUDataType to use
        check_fields: whether to check for the correct attributes

    Returns:
        the C++ bindings, or ``params`` itself if no bindings class is found

    Raises:
        ConfigError: if ``check_fields`` is set and the instance carries an
            attribute that is neither a dataclass field nor ignored, or if
            a field type mismatches (int to float conversion is ignored)
    """
    # pylint: disable=no-name-in-module, too-many-branches
    bindings_cls = get_bindings_class(params, data_type)
    if bindings_cls is None:
        return params
    bindings = bindings_cls()

    declared = {item.name: (item, getattr(params, item.name)) for item in fields(params)}
    ignored = getattr(params, "bindings_ignore", [])

    if check_fields:
        # Attributes set on the instance that are not declared fields
        # usually point to a misspelled parameter name.
        for attr_name in params.__dict__.keys():
            if attr_name in declared or attr_name in ignored:
                continue
            raise ConfigError(
                f"Cannot find '{attr_name}' in params "
                f"'{params.__class__.__name__}'. "
                "Wrong attribute name?"
            )

    for name, (dc_field, value) in declared.items():
        if name in ("unit_cell_devices", "device") or name in ignored:
            # Special fields that are not present in the bindings.
            continue

        if isinstance(value, Enum):
            # Convert enums to the bindings enums.
            enum_name = value.__class__.__name__
            namespace = rpu_base.tiles if hasattr(rpu_base.tiles, enum_name) else rpu_base.devices
            setattr(bindings, name, getattr(getattr(namespace, enum_name), value.value))
            continue

        if is_dataclass(value):
            # Nested dataclasses without bindings are not transferred.
            if hasattr(value, "bindings_class"):
                setattr(bindings, name, parameters_to_bindings(value, data_type=data_type))
            continue

        if HAS_ORIGIN:
            expected_type = get_origin(dc_field.type) or dc_field.type
            accepted = isinstance(value, expected_type) or (
                expected_type == float and isinstance(value, int) and not isinstance(value, bool)
            )
            if not accepted:
                raise ConfigError(f"Expected type {expected_type} for field {name}")

        setattr(bindings, name, value)

    return bindings
def tile_parameters_to_bindings(params: Any, data_type: RPUDataType) -> Any:
    """Build the bindings object for a tile parameter dataclass.

    Only the fields named in ``ALWAYS_INCLUDE`` and the fields whose
    metadata has a truthy ``bindings_include`` entry are transferred; all
    other fields are ignored. Field names are translated through
    ``FIELD_MAP`` where a mapping exists.

    Args:
        params: parameter dataclass
        data_type: RPUDataType to use

    Returns:
        the C++ bindings, or ``params`` itself if no bindings class is found
    """
    bindings_cls = get_bindings_class(params, data_type)
    if bindings_cls is None:
        return params

    bindings = bindings_cls()
    for dc_field in fields(params):
        name = dc_field.name
        wanted = name in ALWAYS_INCLUDE or dc_field.metadata.get("bindings_include", False)
        if not wanted:
            continue

        value = params.__dict__[name]
        # Get the mapped field name, if needed.
        target_name = FIELD_MAP.get(name, name)

        if isinstance(value, Enum):
            # Use the equally named enum of the bindings.
            enum_name = value.__class__.__name__
            namespace = rpu_base.tiles if hasattr(rpu_base.tiles, enum_name) else rpu_base.devices
            setattr(bindings, target_name, getattr(getattr(namespace, enum_name), value.value))
            continue

        if not is_dataclass(value):
            setattr(bindings, target_name, value)
            continue

        # Nested dataclasses are only transferred if they have bindings.
        if getattr(value, "bindings_class", None) is not None:
            setattr(bindings, target_name, parameters_to_bindings(value, data_type=data_type))

    return bindings
class _PrintableMixin:
    """Helper class for pretty-printing of config dataclasses."""

    # pylint: disable=too-few-public-methods

    def __str__(self) -> str:
        """Return a pretty-print representation.

        Only fields that differ from their default are shown, unless the
        field metadata sets ``always_show``. Fields whose value equals the
        ``hide_if`` metadata entry are hidden as well. If the attribute named
        by ``ALL_SKIP_FIELD`` is truthy, every other field is hidden.

        Returns:
            The string ``ClassName(field=value, ...)``, spread over multiple
            lines if there are more than three entries or any entry is
            itself multi-line.
        """

        def lines_list_to_str(
            lines_list: List[str],
            prefix: str = "",
            suffix: str = "",
            indent_: int = 0,
            force_multiline: bool = False,
        ) -> str:
            """Convert a list of lines into a string.

            Args:
                lines_list: the list of lines to be converted.
                prefix: an optional prefix to be appended at the beginning of
                    the string.
                suffix: an optional suffix to be appended at the end of the string.
                indent_: the optional number of spaces to indent the code.
                force_multiline: force the output to be multiline.

            Returns:
                The lines collapsed into a single string (potentially with line
                breaks).
            """
            if force_multiline or len(lines_list) > 3 or any("\n" in line for line in lines_list):
                # Return a multi-line string.
                lines_str = indent(",\n".join(lines_list), " " * indent_)
                prefix = f"{prefix}\n" if prefix else prefix
                suffix = f"\n{suffix}" if suffix else suffix
            else:
                # Return an inline string.
                lines_str = ", ".join(lines_list)

            return f"{prefix}{lines_str}{suffix}"

        def field_to_str(field_value: Any) -> str:
            """Return a string representation of the value of a field.

            Args:
                field_value: the object that contains a field value.

            Returns:
                The string representation of the field (potentially with line
                breaks).
            """
            # Test the argument itself, not a variable of the enclosing
            # scope that merely happens to hold the same object.
            if isinstance(field_value, list) and field_value:
                # For non-empty lists, always use multiline, with one item per line.
                item_lines = [indent(str(item), " " * 4) for item in field_value]
                return lines_list_to_str(item_lines, "[", "]", force_multiline=True)

            return str(field_value)

        def is_skippable(field: Field, value: Any) -> bool:
            """Return whether a field should be skipped."""
            if field.metadata.get("always_show", False):
                return False
            if value == field.default:
                # Skip fields with the default value.
                return True
            if "hide_if" in field.metadata and field.metadata.get("hide_if") == value:
                return True
            return False

        # Special case for global skip: only the skip field itself is shown.
        all_skip = getattr(self, ALL_SKIP_FIELD, False)

        # Build the list of lines.
        fields_lines = []
        for field in fields(self):  # type: ignore[arg-type]
            value = getattr(self, field.name)

            # Exclude fields.
            if (all_skip and field.name != ALL_SKIP_FIELD) or is_skippable(field, value):
                continue

            # Convert the value into a string, falling back to plain ``str``
            # if the pretty-printer fails (deliberately best-effort).
            try:
                value_str = field_to_str(value)
            except Exception:  # pylint: disable=broad-except
                value_str = str(value)

            fields_lines.append(f"{field.name}={value_str}")

        # Convert the full object to str.
        return lines_list_to_str(fields_lines, f"{self.__class__.__name__}(", ")", 4)