"""Utility functions for mediation experiment."""
import matplotlib.pylab as plt
import numpy as np


def error_bar_plot(experiment_data, results, title="", ylabel=""):
    """Plot each estimator's effect estimate with its confidence interval.

    Opens a new 12x5 pyplot figure, draws one marker per entry of ``results``
    with asymmetric error bars, and adds a horizontal "True Effect" reference
    line at the mean of ``experiment_data.true_effects``. Nothing is returned;
    the figure is left as the current pyplot figure.

    Args:
        experiment_data: Object exposing ``true_effects``; its ``.mean()`` is
            used as the reference value.
        results: Mapping whose keys are used as the x-axis values and whose
            values expose ``ate`` (the point estimate) and ``ci`` (either
            ``None`` or a two-element sequence treated as lower/upper bound).
        title: Title of the plot.
        ylabel: Label of the y-axis.
    """
    true_effect = experiment_data.true_effects.mean()
    # Keep ``.keys()`` rather than iterating ``results`` directly so that any
    # mapping-like object accepted before is still accepted.
    estimators = list(results.keys())
    estimates = [results[name].ate for name in estimators]

    # errorbar expects distances from the point, one row for the lower side
    # and one for the upper side. A missing CI is drawn as a zero-length bar.
    lower_err = []
    upper_err = []
    for name in estimators:
        result = results[name]
        if result.ci is None:
            lower_err.append(0)
            upper_err.append(0)
            continue
        offsets = np.array(result.ci) - result.ate
        # abs() keeps the distances non-negative, as errorbar requires.
        lower_err.append(abs(offsets[0]))
        upper_err.append(abs(offsets[1]))

    plt.figure(figsize=(12, 5))
    _, caps, _ = plt.errorbar(
        estimators,
        estimates,
        yerr=[lower_err, upper_err],
        fmt="o",
        markersize=8,
        capsize=5,
    )
    for cap in caps:
        cap.set_markeredgewidth(2)

    # A horizontal line spanning the axes stays visible even with a single
    # estimator, where a line through one data point would draw nothing.
    # "C1" is the second colour of the cycle, i.e. distinct from the error bars.
    plt.axhline(true_effect, color="C1", label="True Effect")
    plt.legend(fontsize=12, loc="lower right")
    plt.ylabel(ylabel)
    plt.title(title)