# Source code for aihwkit.utils.fitting

# -*- coding: utf-8 -*-

# (C) Copyright 2020, 2021, 2022, 2023, 2024 IBM. All Rights Reserved.
#
# This code is licensed under the Apache License, Version 2.0. You may
# obtain a copy of this license in the LICENSE.txt file in the root directory
# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0.
#
# Any modifications or derivative works of this code must retain this
# copyright notice, and modified files need to carry a notice indicating
# that they have been altered from the originals.

# pylint: disable=too-many-locals, too-many-branches

"""Fitting utilities.

This module includes fitting utilities for ``aihwkit``.
Using this module has extra dependencies that can be installed via the
extras mechanism::

    pip install aihwkit[fitting]
"""

from typing import Union, Dict, TypeVar, Tuple, Optional, List, Any
from copy import deepcopy
from dataclasses import fields

from numpy import array, concatenate, newaxis, ndarray
from torch import from_numpy, ones, stack, float32
from lmfit import minimize, Parameters, report_fit

from aihwkit.exceptions import ArgumentError, ConfigError
from aihwkit.simulator.configs.devices import PulsedDevice
from aihwkit.simulator.configs.configs import SingleRPUConfig
from aihwkit.simulator.tiles.analog import AnalogTile

# Placeholder type variable used in the annotations of this module for
# configuration objects that are accepted alongside ``PulsedDevice``.
RPUConfigGeneric = TypeVar("RPUConfigGeneric")


def _apply_parameters_to_config(
    device_config: Union[PulsedDevice, RPUConfigGeneric], params: Parameters
) -> None:
    """Apply the fit parameters to the device config (in place).

    All parameter names are checked before anything is written, so a
    failing call leaves the device untouched instead of partially
    updated.

    Args:
        device_config: device config to be set (in place). Either a
            ``PulsedDevice`` itself or an object exposing the device to
            modify as its ``device`` attribute.
        params: ``lmfit.Parameters`` structure. Each parameter name is
            used as an attribute name of the device.

    Raises:
        ConfigError: if a parameter name is not an attribute of the device
    """
    if isinstance(device_config, PulsedDevice):
        device = device_config
    else:
        device = getattr(device_config, "device")

    parvals = params.valuesdict()

    # Validate first: raising half-way through the assignments below would
    # leave the (in-place modified) device in an inconsistent state.
    for par in parvals:
        if not hasattr(device, par):
            raise ConfigError(f" Cannot find parameter '{par}' in device config.")

    for par, value in parvals.items():
        setattr(device, par, value)


def fit_measurements(
    parameters: Union[Dict, Parameters],
    pulse_data: Union[Tuple[ndarray], ndarray],
    response_data: Union[Tuple[ndarray], ndarray],
    device_config: Union[PulsedDevice, RPUConfigGeneric],
    suppress_device_noise: bool = True,
    max_pulses: Optional[int] = 1,
    n_traces: int = 1,
    fit_weights: Optional[Union[Tuple[int], int]] = None,
    method: str = "powell",
    verbose: bool = False,
    **fit_kwargs: Any,
) -> Tuple[Any, Union[PulsedDevice, RPUConfigGeneric], Union[ndarray, List[ndarray]]]:
    """Fit pulse response measurement to the given device model using lmfit.

    For example:

    .. code-block:: python

        # responses are conductance data in response to pulses (-1, 1)

        # choose device model and parameter to fit
        device_config = SoftBoundsDevice(w_min=-1.0, w_max=1.0)
        params = {'dw_min': (0.1, 0.001, 5.0),
                  'up_down': (0.0, -0.99, 0.99),
                  'w_max': (1.0, 0.1, 5.0)}

        # fit the response
        fit_res, fit_device_config, model_response = fit_measurements(
            params,
            pulses,
            responses,
            device_config=device_config,
            suppress_device_noise=True,
            method='powell',
            fit_weights=fit_weights,
        )

        # fit parameter
        print(fit_res.params.valuesdict())

        # device of best fit
        print(fit_device_config)

    Args:
        parameters: Parameter to vary. Dictionary with parameter names
            (attributes of the device config). Each value is either a
            single value (thus only set, not varied) or a tuple (or
            list) ``(x_init, x_min, x_max)``. ``lmfit.Parameters``
            class can also be given directly.
        pulse_data: Pulse data, ie array of number of pulses in up
            (pos) or down (neg) direction. Can be a tuple of multiple
            measurements
        response_data: Corresponding measured responses to the pulses
            given by ``pulse_data`` as numpy array or list.

            Caution:
                ``axes=1`` can be used for multiple device fit.
                However, then all pulse data needs to have the same
                axis=0 dimension

        device_config: base device configuration. It is deep-copied,
            the object passed in is not modified.
        suppress_device_noise: sets all dtod and std parameters of the
            device to 0
        max_pulses: constrain the number of pulses given (only applied
            when ``device_config`` is a ``PulsedDevice``).
        n_traces: how many traces to simulate simultaneously
        fit_weights: the weighting of the individual response traces
            in the loss function
        method: fitting method from ``lmfit`` (default "powell")
        verbose: whether to print fitting results
        fit_kwargs: additional parameter passed to ``lmfit.minimize``

    Returns:
        fit_results: Result of the fit in ``lmfit`` format
        device_config: Device config (copy) with found parameter applied
        model_response: Model response of parameter fit. A list with one
            array per measurement if the data was given as tuple,
            otherwise a single array.

    Raises:
        ArgumentError: in case wrong arguments are given
    """
    if isinstance(pulse_data, tuple) != isinstance(response_data, tuple):
        raise ArgumentError("Either all data inputs need to be tuples or None. ")

    # Work on a copy so that the caller's configuration stays untouched.
    device_config = deepcopy(device_config)
    if isinstance(device_config, PulsedDevice):
        rpu_config = SingleRPUConfig(device=device_config)
        # single pulse mode
        if max_pulses is not None:
            rpu_config.update.desired_bl = max_pulses
        # NOTE(review): nesting reconstructed from a flattened listing;
        # confirm that both flags are meant to be cleared regardless of
        # ``max_pulses``.
        rpu_config.update.update_bl_management = False
        rpu_config.update.update_management = False
    else:
        rpu_config = device_config  # type: ignore

    if suppress_device_noise:
        for field in fields(rpu_config.device):
            if field.name.endswith(("dtod", "std")):
                setattr(rpu_config.device, field.name, 0.0)

    if isinstance(parameters, Parameters):
        params = parameters
    elif isinstance(parameters, dict):
        params = Parameters()
        for par, values in parameters.items():
            # Lists are accepted as well, since bounds loaded from
            # JSON/YAML arrive as lists rather than tuples.
            if isinstance(values, (tuple, list)):
                x_init, x_min, x_max = values
                params.add(par, value=x_init, min=x_min, max=x_max, vary=True)
            else:
                params.add(par, value=values, vary=False)
    else:
        raise ArgumentError("Expect dict or Parameters for parmeters.")

    # fit parameters
    args = (pulse_data, response_data, rpu_config, n_traces, fit_weights, verbose)
    result = minimize(model_response, params, args=args, method=method, **fit_kwargs)

    if verbose:
        report_fit(result)

    # Re-run the model once with the best parameters to get the traces.
    best_model_res = model_response(result.params, *args, only_response=True)
    _apply_parameters_to_config(device_config, result.params)

    return result, device_config, best_model_res  # type: ignore
def model_response(
    params: Parameters,
    pulse_data: Union[Tuple[ndarray], ndarray],
    response_data: Union[Tuple[ndarray], ndarray],
    rpu_config: RPUConfigGeneric,
    n_traces: int = 1,
    fit_weights: Optional[Union[Tuple[int], int]] = None,
    verbose: bool = True,
    only_response: bool = False,
) -> Union[ndarray, List[ndarray]]:
    """Compute the model responses given the pulses.

    Args:
        params: ``lmfit.Parameters`` of the current parameter setting
        pulse_data: Pulse data, ie array of number of pulses in up
            (pos) or down (neg) direction. Can be a tuple of multiple
            measurements
        response_data: Corresponding measured responses to the pulses
            given by ``pulse_data`` as numpy array or list.

            Caution:
                ``axes=1`` can be used for multiple device fit.
                However, then all pulse data needs to have the same
                axis=0 dimension

        rpu_config: base device configuration (will be modified)
        n_traces: how many traces to simulate simultaneously
        fit_weights: the weighting of the individual response traces
            in the loss function. A scalar or an array per measurement.
        verbose: whether to print std of deviation
        only_response: whether to return a list of model responses
            instead of the deviation

    Returns:
        deviation vector or list of model responses (weight traces). With
        ``only_response`` and data that was not given as tuple, a single
        array is returned instead of a list.

    Note:
        overwrites the given rpu_config
    """
    _apply_parameters_to_config(rpu_config, params)

    # likely somewhat inefficient since we need to always create a new
    # tile, repeats are quick though
    no_list = False
    if not isinstance(pulse_data, tuple):
        pulse_data = (pulse_data,)
        response_data = (response_data,)  # type: ignore
        if fit_weights is not None:
            fit_weights = (fit_weights,)  # type: ignore
        no_list = True

    # The tile is sized from the first measurement only.
    first_pulses = array(pulse_data[0])
    n_devices = 1
    if first_pulses.ndim > 1:
        n_devices = first_pulses.shape[1]

    analog_tile = AnalogTile(n_traces, n_devices, rpu_config)  # type: ignore
    analog_tile.set_learning_rate(1)

    # Seeded with an empty float array so that the concatenated result has
    # the same dtype (and the same empty result) as an incremental build.
    deviations = [array([], "float")]
    model_responses = []
    for idx, (raw_pulses, raw_response) in enumerate(zip(pulse_data, response_data)):
        # Inputs may be plain lists; convert before using array attributes.
        numpy_pulses = array(raw_pulses)
        response = array(raw_response)
        if numpy_pulses.ndim == 1:
            numpy_pulses = numpy_pulses.reshape(-1, 1)
        if response.ndim == 1:
            response = response.reshape(-1, 1)

        # Start every simulated trace from the first measured response.
        w_init = response[0, :]
        weights = from_numpy(array(w_init).flatten()[newaxis, :]).to(dtype=float32) * ones(
            (n_traces, n_devices), dtype=float32
        )
        analog_tile.set_weights(weights)

        pulses = from_numpy(numpy_pulses).to(dtype=float32)
        w_trace = [weights]
        # The initial weights are the first trace entry, hence the last
        # pulse is not applied.
        for pulse in pulses[:-1]:
            analog_tile.update(
                pulse * ones(n_devices, dtype=float32), -ones((n_traces), dtype=float32)
            )
            w_trace.append(analog_tile.tile.get_weights())

        stacked_w_trace = stack(w_trace).cpu().numpy()

        # compute square error
        num_samples = response.shape[0]
        avg_w_trace = stacked_w_trace.mean(axis=1)[:num_samples, :]
        model_responses.append(avg_w_trace)

        dev = avg_w_trace - response
        if fit_weights is not None:
            # ``ndmin=1`` keeps scalar weights sliceable (a 0-d array
            # cannot be indexed with a slice).
            trace_weights = array(fit_weights[idx], ndmin=1)  # type: ignore
            dev = dev * trace_weights[: dev.shape[1]]
        deviations.append(dev.flatten())

    if only_response:
        if no_list:
            return model_responses[0]
        return model_responses

    deviation = concatenate(deviations)
    if verbose:
        print(deviation.std())
    return deviation