# -*- coding: utf-8 -*-
# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.
"""Conversion script of legacy checkpoints (pre v0.8) to the new format."""
# pylint: disable=too-many-locals, too-many-statements, too-many-branches
from typing import Tuple, Optional
from collections import OrderedDict
from copy import deepcopy
from torch import Tensor, float32, ones
from torch.nn import Module
from aihwkit.simulator.configs.configs import InferenceRPUConfig
from aihwkit.simulator.parameters.base import RPUConfigBase
from aihwkit.simulator.parameters import PrePostProcessingParameter, WeightRemapParameter
from aihwkit.simulator.presets.web import OldWebComposerMappingParameter
from aihwkit.inference.noise.pcm import PCMLikeNoiseModel
def convert_legacy_checkpoint(
    legacy_chkpt: "OrderedDict", model: Optional[Module] = None
) -> Tuple["OrderedDict", Optional["RPUConfigBase"]]:
    """Attempt to convert a legacy (pre v0.8) checkpoint to the new tile structure.

    The given checkpoint is deep-copied first, so ``legacy_chkpt`` itself
    is never modified.

    Caution:
        Might not be fully functional in all cases.

    Important:
        Only one of the RPUConfigs of the tiles is returned. If tiles
        have different RPUConfigs, any of them might be returned, in no
        particular order.

    Note:
        If ``model`` is given and none of its analog layers has
        ``Mapped`` in its class name, the ``mapping.max_input_size``
        and ``mapping.max_output_size`` of all tiles of ``model`` are
        set to 0 in place.

    Args:
        legacy_chkpt: loaded checkpoint (state_dict) from a pre v0.8 version.
        model: instantiated model that will load the state_dict. More
            issues can be resolved if it is given, since the layer
            class names are then known.

    Returns:
        Tuple of the converted checkpoint and one of the RPUConfigs
        found in the tiles. The RPUConfig is ``None`` if the checkpoint
        does not contain any analog tile state.
    """

    def check_conv_mapped(prefix1: str) -> bool:
        """Return whether the layer addressed by ``prefix1`` is a mapped conv layer."""
        # The prefix carries a trailing "." that is not part of the layer name.
        class_name = layer_dic.get(prefix1[:-1])
        return class_name is not None and "Conv" in class_name and "Mapped" in class_name

    def get_key_from_ending(key_name: str, par_name: str, prefix: str) -> str:
        """Build the new (prefix-less) key from the tile indices appended to ``par_name``."""
        ending = key_name.split(par_name)[-1]
        arr = [int(val) for val in ending.split("_") if val]

        if not arr:
            # No tile index found. Not expected, just keep the same ending.
            return "analog_module." + par_name + ending

        if len(arr) == 1:
            if arr[0] == 0:
                return "analog_module." + par_name
            # Use the index itself: str(arr) would put the list repr
            # (e.g. "[1]") into the key.
            return "analog_module." + par_name + "." + str(arr[0])

        # Multiple indices: the tile is part of a tile array.
        new_key = "array" if check_conv_mapped(prefix) else "analog_module.array"
        for val in arr:
            new_key += "." + str(val)
        return new_key + "." + par_name

    has_mapped = False
    layer_dic = {}
    if model is not None:
        for layer_name, analog_layer in model.named_analog_layers():
            class_name = analog_layer.__class__.__name__
            has_mapped = has_mapped or "Mapped" in class_name
            layer_dic[layer_name] = class_name

        if not has_mapped:
            for tile in model.analog_tiles():
                tile.rpu_config.mapping.max_input_size = 0
                tile.rpu_config.mapping.max_output_size = 0

    legacy_chkpt = deepcopy(legacy_chkpt)

    # Stays None if the checkpoint holds no analog tile state at all
    # (otherwise the final return would hit an unbound local).
    rpu_config = None

    for key, value in legacy_chkpt.items():
        # "analog_module.analog_tile_state" is the key layout this function
        # produces itself; the "analog_model" spelling of the original
        # check is kept for backward compatibility.
        # NOTE(review): the original spelling looks like a typo - confirm.
        if "analog_module.analog_tile_state" in key or "analog_model.analog_tile_state" in key:
            # This is actually a new checkpoint. Abort.
            rpu_config = deepcopy(value["rpu_config"])
            return legacy_chkpt, rpu_config

    # First pass: bring the RPUConfigs stored in the tile states up to date.
    for key, value in legacy_chkpt.items():
        if "analog_tile_state" not in key:
            continue

        rpu_config = value["rpu_config"]
        if not isinstance(rpu_config, InferenceRPUConfig):
            continue

        if not hasattr(rpu_config, "mapping"):
            rpu_config.mapping = OldWebComposerMappingParameter()

        mapping_dic = rpu_config.mapping.__dict__
        if "weight_scaling_omega_columnwise" in mapping_dic:
            rpu_config.mapping.weight_scaling_columnwise = mapping_dic.pop(
                "weight_scaling_omega_columnwise"
            )
        if "learn_out_scaling_alpha" in mapping_dic:
            rpu_config.mapping.learn_out_scaling = mapping_dic.pop("learn_out_scaling_alpha")
            rpu_config.mapping.out_scaling_columnwise = (
                rpu_config.mapping.weight_scaling_columnwise
            )

        if not has_mapped:
            # Need to set tile sizes to full since otherwise
            # mapping would now still occur.
            rpu_config.mapping.max_input_size = 0
            rpu_config.mapping.max_output_size = 0

        noise_model = rpu_config.noise_model
        if isinstance(noise_model, PCMLikeNoiseModel) and not hasattr(
            noise_model, "prog_coeff_g_max_reference"
        ):
            noise_model.prog_coeff_g_max_reference = noise_model.g_max

        if not hasattr(rpu_config, "pre_post"):
            rpu_config.pre_post = PrePostProcessingParameter()

        if not hasattr(rpu_config, "remap"):
            rpu_config.remap = WeightRemapParameter()

        if not hasattr(rpu_config.modifier, "coeffs"):
            # The individual coefficient fields are merged into one list.
            dic = rpu_config.modifier.__dict__
            rpu_config.modifier.coeffs = [
                dic.pop("coeff0"),
                dic.pop("coeff1"),
                dic.pop("coeff2"),
            ]

    # Fields of the legacy tile state that are dropped during conversion.
    dropped_state_keys = (
        "noise_model",
        "drift_compensation",
        "drift_baseline",
        "drift_readout_tensor",
        "reference_combined_weights",
        "programmed_weights",
        "nu_drift_list",
        "shared_weights",
        "image_sizes",
    )

    # Second pass: rename the keys and restructure the tile states.
    chkpt = OrderedDict()
    for key, value in legacy_chkpt.items():
        name = key.split(".")[-1]
        # Slice the name off the end: splitting on the name would cut the
        # prefix short whenever the name also occurs earlier in the key.
        prefix = key[: len(key) - len(name)]

        if name == "bias":
            tile_state_start = prefix + "analog_tile_state"
            if any(k.startswith(tile_state_start) for k in legacy_chkpt):
                # Digital bias of an analog module. Now handled inside the TileModule.
                if check_conv_mapped(prefix):
                    new_key = prefix + name
                else:
                    new_key = prefix + "analog_module." + name
                chkpt[new_key] = value
                continue
            # Otherwise a regular bias: falls through and is copied as is.

        if name.startswith("analog_tile_state"):
            # Tile (array) numbers.
            new_key = prefix + get_key_from_ending(name, "analog_tile_state", prefix)

            state = value
            for legacy_key in dropped_state_keys:
                state.pop(legacy_key, None)

            # Rename bias flag.
            state["use_bias"] = state.pop("bias", False)

            # Drift compensation: alpha gets its own checkpoint entry.
            alpha = state.get("alpha", None)
            if isinstance(alpha, Tensor):
                chkpt[new_key.replace("analog_tile_state", "alpha")] = alpha

            out_size = state["out_size"]
            rpu_config = state["rpu_config"]
            device = state["analog_ctx"].device

            # Mapping scales.
            scales = ones((out_size,), dtype=float32, device=device)
            if "mapping_scales" in state:
                scales *= state.pop("mapping_scales", 1.0)
            # In case of the very old alpha scale.
            scales *= state.pop("analog_alpha_scale", 1.0)
            chkpt[new_key.replace("analog_tile_state", "mapping_scales")] = scales

            chkpt[new_key] = state
            continue

        if name.startswith("analog_out_scaling_alpha"):
            # This is used for the mapped layers.
            new_name = name.split("analog_")[-1] + "_0"
            new_key = prefix + get_key_from_ending(new_name, "out_scaling_alpha", prefix)
            chkpt[new_key] = value
            continue

        if name.startswith("analog_shared_weights"):
            # Shared weights are not carried over to the new checkpoint.
            continue

        if name.startswith("analog_ctx"):
            new_key = prefix + get_key_from_ending(name, "analog_ctx", prefix)
            chkpt[new_key] = value
            continue

        chkpt[key] = value

    return chkpt, rpu_config