```"""Estimate models with the method of simulated moments (MSM).

The method of simulated moments is developed by [1], [2], and [3] and an estimation
technique where the distance between the moments of the actual data and the moments
implied by the model parameters is minimized.

References
----------
.. [1] McFadden, D. (1989). A method of simulated moments for estimation of discrete
response models without numerical integration. Econometrica: Journal of the
Econometric Society, 995-1026.
.. [2] Lee, B. S., & Ingram, B. F. (1991). Simulation estimation of time-series models.
Journal of Econometrics, 47(2-3), 197-205.
.. [3] Duffie, D., & Singleton, K. (1993). Simulated Moments Estimation of Markov Models
of Asset Prices. Econometrica, 61(4), 929-952.

"""
import copy
import functools
import itertools

import numpy as np
import pandas as pd

from respy.simulate import get_simulate_func

def get_msm_func(
params,
options,
calc_moments,
replace_nans,
empirical_moments,
weighting_matrix,
n_simulation_periods=None,
return_scalar=True,
return_simulated_moments=False,
return_comparison_plot_data=False,
):
"""Get the MSM function.

Parameters
----------
params : pandas.DataFrame or pandas.Series
Contains parameters.
options : dict
Dictionary containing model options.
calc_moments : callable or list
Function(s) used to calculate simulated moments. Must match structure
of empirical moments i.e. if empirical_moments is a list of
pandas.DataFrames, calc_moments must be a list of the same length
containing functions that correspond to the moments in
empirical_moments.
replace_nans : callable or list
Functions(s) specifying how to handle missings in simulated_moments.
Must match structure of empirical_moments.
Exception: If only one replacement function is specified, it will be
used on all sets of simulated moments.
empirical_moments : pandas.DataFrame or pandas.Series or dict or list
Contains the empirical moments calculated for the observed data. Moments
should be saved to pandas.DataFrame or pandas.Series that can either be
passed to the function directly or as items of a list or dictionary.
Index of pandas.DataFrames can be of type MultiIndex, but columns cannot.
weighting_matrix : numpy.ndarray
Square matrix of dimension (NxN) with N denoting the number of
empirical_moments. Used to weight squared moment errors.
n_simulation_periods : int, default None
Dictates the number of periods in the simulated dataset.
This option does not affect ``options["n_periods"]`` which controls the
number of periods for which decision rules are computed.
return_scalar : bool, default True
Indicates whether to return moment error vector (False) or weighted
square product of moment error vectors (True).
return_simulated_moments: bool, default False
Indicates whether simulated moments should be returned with other output.
If True will return simulated moments of the same type as empirical_moments.
return_comparison_plot_data: bool, default False
Indicator for whether a :class:`pandas.DataFrame` with empirical and simulated
moments for the visualization with estimagic should be returned. Data contains
the following columns:
- moment_column: Contains the column names of the moment DataFrames/Series
names.
- moment_index: Contains the index of the moment DataFrames/Series. MultiIndex
indices will be joined to one string.
- value: Contains moment values.
- moment_set: Indicator for each set of moments, will use keys if
empirical_moments are specified in a dict. Moments input as lists will
be numbered according to position.
- kind: Indicates whether moments are empirical or simulated.
Returns
-------
msm_func: callable
MSM function where all arguments except the parameter vector are set.

"""
empirical_moments = copy.deepcopy(empirical_moments)

# Save keys of dictionary for comparison plot if applicable.
return_comparison_plot_data = [return_comparison_plot_data]
if isinstance(empirical_moments, dict):
moment_keys = sorted(empirical_moments)
return_comparison_plot_data.append(moment_keys)
else:
return_comparison_plot_data.append(None)

simulate = get_simulate_func(
params=params, options=options, n_simulation_periods=n_simulation_periods
)

empirical_moments = _harmonize_input(empirical_moments)
calc_moments = _harmonize_input(calc_moments)
replace_nans = _harmonize_input(replace_nans)

# If only one replacement function is given for multiple sets of moments, duplicate
# replacement function for all sets of simulated moments.
if len(replace_nans) == 1 and len(empirical_moments) > 1:
replace_nans = replace_nans * len(empirical_moments)

elif 1 < len(replace_nans) < len(empirical_moments):
raise ValueError(
"Replacement functions can only be matched 1:1 or 1:n with sets of "
"empirical moments."
)

elif len(replace_nans) > len(empirical_moments):
raise ValueError(
"There are more replacement functions than sets of empirical moments."
)

else:
pass

if len(calc_moments) != len(empirical_moments):
raise ValueError(
"Number of functions to calculate simulated moments must be equal to "
"the number of sets of empirical moments."
)

if return_simulated_moments and return_comparison_plot_data[0]:
raise ValueError(
"Can only return either simulated moments or comparison plot data, not both."
)

msm_func = functools.partial(
msm,
simulate=simulate,
calc_moments=calc_moments,
replace_nans=replace_nans,
empirical_moments=empirical_moments,
weighting_matrix=weighting_matrix,
return_scalar=return_scalar,
return_simulated_moments=return_simulated_moments,
return_comparison_plot_data=return_comparison_plot_data,
)

return msm_func

def msm(
params,
simulate,
calc_moments,
replace_nans,
empirical_moments,
weighting_matrix,
return_scalar,
return_simulated_moments,
return_comparison_plot_data,
):
"""Loss function for MSM estimation.

Parameters
----------
params : pandas.DataFrame or pandas.Series
Contains model parameters.
simulate : callable
Function used to simulate data for MSM estimation.
calc_moments : list
List of function(s) used to calculate simulated moments. Must match length of
empirical_moments i.e. calc_moments contains a moments function for each item in
empirical_moments.
replace_nans : list
List of functions(s) specifying how to handle missings in simulated_moments.
Must match length of empirical_moments.
empirical_moments : list
Contains the empirical moments calculated for the observed data. Each item in
the list constitutes a set of moments saved to a pandas.DataFrame or
pandas.Series. Index of pandas.DataFrames can be of type MultiIndex, but columns
cannot.
weighting_matrix : numpy.ndarray
Square matrix of dimension (NxN) with N denoting the number of
empirical_moments. Used to weight squared moment errors.
return_scalar : bool
Indicates whether to return moment error vector (False) or weighted square
product of moment error vector (True).
return_simulated_moments: bool
Indicates whether simulated moments should be returned with other output.
If True will return simulated moments of the same type as empirical_moments.
return_comparison_plot_data: list
Will output moments in a tidy data format if True. Expects a list as input where
the first element is a boolean indicating whether to return the comparison plot
data. The second element in the list can be a list of keys used to identify sets
of moments which otherwise will be numbered.

Returns
-------
out : pandas.Series or float or tuple
Scalar or moment error vector depending on value of return_scalar. Will be a
tuple containing simulated moments of same type as empirical_moments or a tidy
pandas.DataFrame if either return_simulated_moments or the first element in
return_comparison_plot_data is True.

"""
empirical_moments = copy.deepcopy(empirical_moments)

df = simulate(params)

simulated_moments = [func(df) for func in calc_moments]

simulated_moments = [
sim_mom.reindex_like(emp_mom)
for emp_mom, sim_mom in zip(empirical_moments, simulated_moments)
]

simulated_moments = [
func(mom) for mom, func in zip(simulated_moments, replace_nans)
]

flat_empirical_moments = _flatten_index(empirical_moments)
flat_simulated_moments = _flatten_index(simulated_moments)

moment_errors = flat_empirical_moments - flat_simulated_moments

# Return moment errors as indexed DataFrame or calculate weighted square product of
# moment errors depending on return_scalar.
if return_scalar:
out = moment_errors.T @ weighting_matrix @ moment_errors
else:
out = moment_errors

if return_simulated_moments:
simulated_moments = _reconstruct_inputs(
simulated_moments, return_comparison_plot_data[1]
)
out = (out, simulated_moments)

elif return_comparison_plot_data[0]:
tidy_moments = _create_comparison_plot_data_msm(
empirical_moments, simulated_moments, return_comparison_plot_data[1]
)
out = (out, tidy_moments)

else:
pass

return out

def get_diag_weighting_matrix(empirical_moments, weights=None):
"""Create a diagonal weighting matrix from weights.

Parameters
----------
empirical_moments : pandas.DataFrame or pandas.Series or dict or list
Contains the empirical moments calculated for the observed data. Moments should
be saved to pandas.DataFrame or pandas.Series that can either be passed to the
function directly or as items of a list or dictionary.
weights : pandas.DataFrame or pandas.Series or dict or list
Contains weights (usually variances) of empirical moments. Must match structure
of empirical_moments i.e. if empirical_moments is a list of pandas.DataFrames,
weights be list of pandas.DataFrames as well where each DataFrame entry contains
the weight for the corresponding moment in empirical_moments.

Returns
-------
numpy.ndarray
Array contains a diagonal weighting matrix.

"""
weights = copy.deepcopy(weights)
empirical_moments = copy.deepcopy(empirical_moments)
empirical_moments = _harmonize_input(empirical_moments)

# Use identity matrix if no weights are specified.
if weights is None:
flat_weights = _flatten_index(empirical_moments)
flat_weights[:] = 1

# Harmonize input weights.
else:
weights = _harmonize_input(weights)

# Reindex weights to ensure they are assigned to the correct moments in
# the MSM function.
weights = [
weight.reindex_like(emp_mom)
for emp_mom, weight in zip(empirical_moments, weights)
]

flat_weights = _flatten_index(weights)

return np.diag(flat_weights)

def get_flat_moments(empirical_moments):
"""Compute the empirical moments flat indexes.

Parameters
----------
empirical_moments : pandas.DataFrame or pandas.Series or dict or list
containing pandas.DataFrame or pandas.Series. Contains the empirical moments
calculated for the observed data. Moments should be saved to pandas.DataFrame or
pandas.Series that can either be passed to the function directly or as items of
a list or dictionary.

Returns
-------
flat_empirical_moments : pandas.DataFrame
Vector of empirical_moments with flat index.

"""
empirical_moments = copy.deepcopy(empirical_moments)
empirical_moments = _harmonize_input(empirical_moments)
flat_empirical_moments = _flatten_index(empirical_moments)

return flat_empirical_moments

def _harmonize_input(data):
"""Harmonize different types of inputs by turning all inputs into lists.

- pandas.DataFrames/Series and callable functions will turn into a list containing a
single item (i.e. the input).
- Dictionaries will be sorted according to keys and then turn into a list containing
the dictionary entries.

"""
# Convert single pandas.DataFrames, pandas.Series or function into list containing
# one item.
if isinstance(data, (pd.DataFrame, pd.Series)) or callable(data):
data = [data]

# Sort dictionary according to keys and turn into list.
elif isinstance(data, dict):
data = [data[key] for key in sorted(data)]

elif isinstance(data, list):
pass

else:
raise TypeError(
"Function only accepts lists, dictionaries, functions, Series and "
"DataFrames as inputs."
)

return data

def _flatten_index(data):
"""Flatten the index as a combination of the former index and the columns."""
data = copy.deepcopy(data)
data_flat = []
counter = itertools.count()

for series_or_df in data:
series_or_df.index = series_or_df.index.map(str)
# Unstack pandas.DataFrames and pandas.Series to add
# columns/name to index.
if isinstance(series_or_df, pd.DataFrame):
df = series_or_df.rename(columns=str)
# pandas.Series without a name are named using a counter to avoid duplicate
# indexes.
elif isinstance(series_or_df, pd.Series) and series_or_df.name is None:
df = series_or_df.to_frame(name=str(next(counter)))
else:
df = series_or_df.to_frame(str(series_or_df.name))

# Columns to the index.
df = df.unstack()
df.index = df.index.to_flat_index().str.join("_")
data_flat.append(df)

return pd.concat(data_flat)

def _create_comparison_plot_data_msm(
empirical_moments, simulated_moments, moment_set_labels
):
"""Create pandas.DataFrame for estimagic's comparison plot."""
if moment_set_labels is None:
moment_set_labels = list(range(0, len(empirical_moments)))

tidy_empirical_moments = _create_tidy_data(empirical_moments, moment_set_labels)
tidy_simulated_moments = _create_tidy_data(simulated_moments, moment_set_labels)

tidy_simulated_moments["kind"] = "simulated"
tidy_empirical_moments["kind"] = "empirical"

return pd.concat(
[tidy_empirical_moments, tidy_simulated_moments], ignore_index=True
)

def _create_tidy_data(data, moment_set_labels):
"""Create tidy data from list of pandas.DataFrames."""
counter = itertools.count()
tidy_data = []
for series_or_df, label in zip(data, moment_set_labels):
# Join index levels for MultiIndex objects.
if isinstance(series_or_df.index, pd.MultiIndex):
series_or_df = series_or_df.rename(index=str)
series_or_df.index = series_or_df.index.to_flat_index().str.join("_")
# If moments are a pandas.Series, convert into pandas.DataFrame.
if isinstance(series_or_df, pd.Series):
# Unnamed pandas.Series receive a name based on a counter.
if series_or_df.name is None:
series_or_df = series_or_df.to_frame(name=next(counter))
else:
series_or_df = series_or_df.to_frame()

# Create pandas.DataFrame in tidy format.
tidy_df = series_or_df.unstack()
tidy_df.index.names = ("moment_column", "moment_index")
tidy_df.rename("value", inplace=True)
tidy_df = tidy_df.reset_index()
tidy_df["moment_set"] = label
tidy_data.append(tidy_df)

return pd.concat(tidy_data, ignore_index=True)

def _reconstruct_inputs(inputs, dict_keys=None):
"""Reconstruct inputs from lists back to a dictionary or single object."""
if dict_keys is not None:
output = dict(zip(dict_keys, inputs))
elif len(inputs) == 1:
output = inputs[0]
else:
output = inputs

return output
```