"""Classes & functions for managing parameters for simulating power spectra."""
import numpy as np
from fooof.core.utils import group_three, check_flat
from fooof.core.info import get_indices
from fooof.core.funcs import infer_ap_func
from fooof.core.errors import InconsistentDataError
from fooof.data import SimParams
###################################################################################################
###################################################################################################
def collect_sim_params(aperiodic_params, periodic_params, nlv):
    """Bundle simulation settings into a SimParams object.

    Parameters
    ----------
    aperiodic_params : list of float
        Aperiodic component parameters. A copy is stored, so the caller's list is untouched.
    periodic_params : list of float or list of list of float
        Periodic component parameters. These are flattened, grouped into triplets and sorted.
    nlv : float
        Noise level of the power spectrum.

    Returns
    -------
    SimParams
        Object containing the simulation parameters.
    """
    ap_copy = aperiodic_params.copy()
    peak_groups = sorted(group_three(check_flat(periodic_params)))
    return SimParams(ap_copy, peak_groups, nlv)
def update_sim_ap_params(sim_params, delta, field=None):
    """Update the aperiodic parameter definition in a SimParams object.

    Parameters
    ----------
    sim_params : SimParams
        Object storing the current parameter definition.
    delta : float or list of float
        Value(s) by which to update the parameters.
    field : {'offset', 'knee', 'exponent'} or list of string, optional
        Field of the aperiodic parameter(s) to update.
        If not given, `delta` must contain one value per aperiodic parameter,
        applied to the parameters in order.

    Returns
    -------
    new_sim_params : SimParams
        Updated object storing the new parameter definition.

    Raises
    ------
    InconsistentDataError
        If the number of update values does not match the number of parameters,
        or the number of given fields.
    """

    # Copy the aperiodic parameters, so the input object is never altered
    ap_params = sim_params.aperiodic_params.copy()

    # If field isn't specified, check shapes line up and update across parameters
    if not field:
        if len(ap_params) != len(delta):
            raise InconsistentDataError("The number of items to update and "
                                        "number of new values is inconsistent.")
        ap_params = [param + update for param, update in zip(ap_params, delta)]

    # If labels are given, update deltas according to their labels
    else:
        # Cast single values to lists, so single or multiple passed in values work the same
        fields = field if isinstance(field, list) else [field]
        deltas = delta if isinstance(delta, list) else [delta]

        # zip would silently drop unmatched entries, so require equal lengths
        if len(fields) != len(deltas):
            raise InconsistentDataError("The number of items to update and "
                                        "number of new values is inconsistent.")

        for cur_field, cur_delta in zip(fields, deltas):
            data_ind = get_indices(infer_ap_func(ap_params))[cur_field]
            ap_params[data_ind] = ap_params[data_ind] + cur_delta

    # Replace parameters. Note that this copies a new object, as data objects are immutable
    return sim_params._replace(aperiodic_params=ap_params)
class Stepper():
    """Object for stepping across parameter values.

    Parameters
    ----------
    start : float
        Start value to iterate from.
    stop : float
        End value to iterate to (exclusive).
    step : float
        Increment of each iteration. Must be greater than zero.

    Attributes
    ----------
    start, stop, step : float
        The values used to define the range.
    len : int
        Number of values in the generated range.
    data : iterator
        Set of specified parameters to iterate across.

    Examples
    --------
    Define a stepper object for center frequency values for an alpha peak:

    >>> alpha_cf_steps = Stepper(8, 12.5, 0.5)
    """

    def __init__(self, start, stop, step):
        """Initialize a Stepper object."""

        self._check_values(start, stop, step)

        self.start = start
        self.stop = stop
        self.step = step

        # Take the length from the generated range itself: with float steps,
        # round((stop - start) / step) can disagree with what np.arange yields
        values = np.arange(start, stop, step)
        self.len = len(values)
        self.data = iter(values)

    def __len__(self):
        return self.len

    def __next__(self):
        # Round, to suppress floating point accumulation artifacts from np.arange
        return round(next(self.data), 4)

    def __iter__(self):
        # Return self, so that plain iteration (e.g. a for loop) also goes
        # through __next__ and gets the same rounded values
        return self

    @staticmethod
    def _check_values(start, stop, step):
        """Checks if provided values are valid.

        Parameters
        ----------
        start, stop, step : float
            Definition of the parameter range to iterate over.

        Raises
        ------
        ValueError
            If the given values for defining the iteration range are inconsistent.
        """

        if any(ii < 0 for ii in [start, stop, step]):
            raise ValueError("Inputs 'start', 'stop', and 'step' should all be positive values.")

        if not start < stop:
            raise ValueError("Input 'start' should be less than 'stop'.")

        # A zero step would otherwise fail later with a ZeroDivisionError
        if step == 0:
            raise ValueError("Input 'step' should be greater than zero.")

        if not step < (stop - start):
            raise ValueError("Input 'step' is too large given values for 'start' and 'stop'.")
def param_iter(params):
    """Create a generator to iterate across parameter ranges.

    Parameters
    ----------
    params : list of floats and Stepper
        Parameters over which to iterate, including a Stepper object.
        The Stepper defines the iterated parameter and its range.
        Floats define the other parameters, that will be held constant.
        The input list is not modified.

    Yields
    ------
    list of floats
        Next generated list of parameters. Each yield is a new list object.

    Raises
    ------
    ValueError
        If the number of Stepper objects given is not exactly one.

    Examples
    --------
    Iterate across exponent values from 1 to 2, in steps of 0.1:

    >>> aps = param_iter([Stepper(1, 2, 0.1), 1])

    Iterate over center frequency values from 8 to 12 in increments of 0.25:

    >>> peaks = param_iter([Stepper(8, 12, .25), 0.5, 1])
    """

    # If input is a list of lists, flatten it (this builds a new list)
    if isinstance(params[0], list):
        params = [item for sublist in params for item in sublist]
    else:
        # Work on a copy, so the caller's list is not overwritten with stepped values
        params = list(params)

    # Find where the Stepper object is: exactly one is required
    iter_inds = [ind for ind, param in enumerate(params) if isinstance(param, Stepper)]
    if len(iter_inds) > 1:
        raise ValueError("Iteration is only supported across one parameter at a time.")
    if not iter_inds:
        raise ValueError("No Stepper object found in the given parameters.")
    iter_ind = iter_inds[0]

    # Generate and yield next set of parameters
    gen = params[iter_ind]
    while True:
        try:
            value = next(gen)
        except StopIteration:
            return
        params[iter_ind] = value
        # Yield a fresh list, so earlier yielded parameter sets are not overwritten
        yield list(params)
def param_sampler(params, probs=None):
    """Create a generator to sample randomly from possible parameters.

    Parameters
    ----------
    params : list of lists or list of float
        Possible parameter values.
    probs : list of float, optional
        Probabilities with which to sample each parameter option.
        If None, each parameter option is sampled uniformly.

    Yields
    ------
    list of float
        A randomly sampled set of parameters.

    Raises
    ------
    ValueError
        If `probs` is given and its length differs from the number of options.
        Raised on the first draw from the generator.

    Examples
    --------
    Sample from aperiodic definitions with high and low exponents, with 50% probability of each:

    >>> aps = param_sampler([[1, 1], [2, 1]], probs=[0.5, 0.5])

    Sample from peak definitions of alpha or alpha & beta, with 75% chance of sampling just alpha:

    >>> peaks = param_sampler([[10, 1, 1], [[10, 1, 1], [20, 0.5, 1]]], probs=[0.75, 0.25])
    """

    # If input is a list of lists, check each element, and flatten if needed
    if isinstance(params[0], list):
        params = [check_flat(lst) for lst in params]

    # In order to use numpy's choice with probabilities, choices are made on indices.
    # This is because params can be a ragged list, which numpy's choice does not accept
    inds = np.arange(len(params))

    # Check that length of options is same as length of probs, if provided.
    # Compare against None explicitly: np.any(probs) is False for all-zero probabilities
    if probs is not None and len(inds) != len(probs):
        raise ValueError("The number of options must match the number of probabilities.")

    # While loop allows the generator to be called an arbitrary number of times
    while True:
        yield params[np.random.choice(inds, p=probs)]
def param_jitter(params, jitters):
    """Create a generator that adds jitter to parameter definitions.

    Parameters
    ----------
    params : list of lists or list of float
        Possible parameter values.
    jitters : list of lists or list of float
        The scale of the jitter for each parameter.
        Must be the same shape and organization as `params`.

    Yields
    ------
    list of float
        A jittered set of parameters.

    Raises
    ------
    ValueError
        If `params` and `jitters` (after flattening) differ in length.
        Raised on the first draw from the generator.

    Notes
    -----
    - Jitter is added as random samples from a normal (gaussian) distribution.

        - The jitter specified corresponds to the standard deviation of the normal distribution.

    - For any parameter for which there should be no jitter, set the corresponding value to zero.

    Examples
    --------
    Jitter aperiodic definitions, for offset and exponent, each with the same amount of jitter:

    >>> aps = param_jitter([1, 1], [0.1, 0.1])

    Jitter center frequency of peak definitions, by different amounts for alpha & beta:

    >>> peaks = param_jitter([[10, 1, 1], [20, 1, 1]], [[0.1, 0, 0], [0.5, 0, 0]])
    """

    # Check if inputs are list of lists, and flatten if so
    if isinstance(params[0], list):
        params = check_flat(params)
        jitters = check_flat(jitters)

    # zip would silently truncate on a mismatch, so require one jitter per parameter
    if len(params) != len(jitters):
        raise ValueError("The number of parameters and jitter values must match.")

    # While loop allows the generator to be called an arbitrary number of times
    while True:
        yield [param + np.random.normal(0, jitter) for param, jitter in zip(params, jitters)]