# -*- coding: utf-8 -*-
# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Visualization utilities (web)."""
import argparse
from typing import Optional, Union
from pathlib import Path
from cycler import cycler
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.axes import Axes
from aihwkit.simulator.presets.devices import (
ReRamSBPresetDevice,
ReRamESPresetDevice,
CapacitorPresetDevice,
EcRamPresetDevice,
IdealizedPresetDevice,
)
from aihwkit.simulator.presets.compounds import PCMPresetUnitCell
from aihwkit.simulator.configs.devices import PulsedDevice
from aihwkit.utils.visualization import plot_device_compact
# Palette shared with the web frontend. Plot lines are colored by walking
# this list in order (see the color-cycle / per-line recoloring below).
WEB_COLORS = [
    "#8A3FFC", "#33B1FF", "#007D79", "#FF7EB6", "#FA4D56",
    "#FFF1F1", "#6FDC8C", "#4589FF", "#D12771", "#D2A106",
    "#08BDBA", "#BAE6FF", "#BA4E00", "#D4BBFF",
]
# Preset classes to render for the web, each mapped to the ``n_steps`` value
# used when plotting its step response.
DEVICES = {
    ReRamESPresetDevice: 1_000,
    ReRamSBPresetDevice: 1_000,
    CapacitorPresetDevice: 400,
    EcRamPresetDevice: 1_000,
    IdealizedPresetDevice: 10_000,
    PCMPresetUnitCell: 80,
}
def set_dark_style(axes: Axes) -> None:
    """Apply the web dark style to ``axes`` in place.

    Sets the axes property cycle so that successive lines use the colors in
    ``WEB_COLORS``, and sets the axes background to a dark grey (``#262626``).

    Args:
        axes: the matplotlib axes to style.
    """
    axes.set_prop_cycle(cycler(color=WEB_COLORS))
    axes.set_facecolor("#262626")
def plot_device_compact_web(
    device: PulsedDevice, w_noise: float = 0.0, n_steps: Optional[int] = None, n_traces: int = 3
) -> Union[Figure, Axes]:
    """Plots a compact step response figure for a given device (preset).

    The figure is produced by ``plot_device_compact`` and then tuned for the
    web frontend: every line of every axes is recolored with ``WEB_COLORS``.

    Note:
        It will use an amount of read weight noise ``w_noise`` for
        reading the weights.

        The ``dark_background`` matplotlib style is applied globally (not
        scoped to this call), so it also affects figures created afterwards.

    Args:
        device: PulsedDevice parameters
        w_noise: Weight noise standard deviation during read
        n_steps: Number of steps for up/down cycle
        n_traces: Number of traces to plot (for device-to-device variation)

    Returns:
        the compact step response figure. If ``plot_device_compact`` returns
        an ``Axes`` instead, it is returned unchanged without recoloring.
    """
    plt.style.use("dark_background")
    figure = plot_device_compact(device, w_noise, n_steps, n_traces)
    if isinstance(figure, Axes):
        return figure

    # Tune for web: recolor lines with the frontend palette. Wrap around the
    # palette so axes with more lines than colors do not raise IndexError.
    n_colors = len(WEB_COLORS)
    for axes in figure.get_axes():
        for i, line in enumerate(axes.get_lines()):
            line.set_color(WEB_COLORS[i % n_colors])
    return figure
def save_plots_for_web(path: Path = Path("/tmp"), file_format: str = "svg") -> None:
    """Create the plots for the web.

    For every device class in ``DEVICES`` two images are written into
    ``path``:

    * ``<snake_name>.<file_format>``: the detailed figure (default traces).
    * ``<snake_name>-mini.<file_format>``: a single-trace figure with titles,
      labels, ticks and axes hidden, and thicker spines and lines.

    Each figure is closed after saving so pyplot does not keep it alive.

    Args:
        path: the path where the images will be stored.
        file_format: the image format.
    """

    def camel_to_snake(source: str) -> str:
        """Convert a CamelCase string into snake_case (one ``_`` per capital)."""
        return "".join(["_" + char.lower() if char.isupper() else char for char in source]).lstrip(
            "_"
        )

    def simplify_for_mini(figure: Figure) -> None:
        """Strip texts/ticks/axes and thicken spines and lines, in place."""
        for axes in figure.get_axes():
            # Disable texts.
            axes.set_title("")
            axes.set_xlabel("")
            axes.set_ylabel("")
            # Disable tick labels.
            axes.xaxis.set_ticklabels([])
            axes.yaxis.set_ticklabels([])
            # Disable axis entirely.
            axes.get_xaxis().set_visible(False)
            axes.get_yaxis().set_visible(False)
            # Increase axis width.
            for axis in ["top", "bottom", "left", "right"]:
                axes.spines[axis].set_linewidth(3)
            for line in axes.get_lines():
                line.set_linewidth(4)

    output_dir = path.absolute()
    for device, n_steps in DEVICES.items():
        snake_name = camel_to_snake(device.__name__)

        # Images for the detailed modal.
        file_path = output_dir / "{}.{}".format(snake_name, file_format)
        figure = plot_device_compact_web(device(), n_steps=n_steps)  # type: ignore
        figure.savefig(  # type: ignore
            file_path, format=file_format, transparent=True, bbox_inches="tight"
        )
        # pyplot keeps a reference to every figure until it is closed.
        plt.close(figure)  # type: ignore

        # Images for the mini leftbar.
        file_path = output_dir / "{}-mini.{}".format(snake_name, file_format)
        figure = plot_device_compact_web(device(), n_traces=1, n_steps=n_steps)
        simplify_for_mini(figure)  # type: ignore
        figure.savefig(  # type: ignore
            file_path, format=file_format, transparent=True, bbox_inches="tight"
        )
        plt.close(figure)  # type: ignore
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate plots for frontend.")
    parser.add_argument("destination", type=str, help="folder where the plots will be stored")
    parser.add_argument(
        "--format",
        dest="file_format",
        default="svg",
        help="image format passed to savefig (default: svg)",
    )
    args = parser.parse_args()
    destination_path = Path(args.destination)
    # parents=True so a nested destination whose parents do not exist yet
    # is created instead of raising FileNotFoundError.
    destination_path.mkdir(parents=True, exist_ok=True)
    save_plots_for_web(destination_path, file_format=args.file_format)