import numpy as np
import matplotlib.pyplot as plt
import datetime as dt
import pandas as pd
import numpy.lib.recfunctions as rec

from scipy.interpolate import splev, splrep
from sklearn import datasets
# sklearn.cross_validation and sklearn.grid_search were removed in
# scikit-learn 0.20; their contents live in sklearn.model_selection.
try:
    from sklearn.model_selection import (GridSearchCV, RandomizedSearchCV,
                                         train_test_split)
except ImportError:  # scikit-learn < 0.18
    from sklearn.cross_validation import train_test_split
    from sklearn.grid_search import GridSearchCV, RandomizedSearchCV
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.svm import SVC, OneClassSVM
from sklearn import preprocessing as prep


def import_data(
    scada_data='Source Data/SCADA_data.csv',
        status_data_wec='Source Data/status_data_wec.csv',
        status_data_rtu='Source Data/status_data_rtu.csv',
        warning_data_wec='Source Data/warning_data_wec.csv',
        warning_data_rtu='Source Data/warning_data_rtu.csv'):
    """This imports the data, and returns arrays of SCADA & status data.
    Dates are converted to unix time, and strings are encoded in the
    correct format (unicode). Two new fields, "Inverter_averages" and
    "Inverter_std_dev", are also added to the SCADA data. These are the
    average and standard deviation of all Inverter Temperature fields.

    Parameters
    ----------
    scada_data: str, optional
        The raw SCADA data csv file.
    status_data_wec: str, optional
        The status/fault csv file for the WEC
    status_data_rtu: str, optional
        The status/fault csv file for the RTU
    warning_data_wec: str, optional
        The warning/information csv file for the WEC
    warning_data_rtu: str, optional
        The warning/information csv file for the RTU

    Returns
    -------
    scada_data: ndarray
        The imported and correctly formatted SCADA data
    status_data_wec: ndarray
        The imported and correctly formatted WEC status data
    status_data_rtu: ndarray
        The imported and correctly formatted RTU status data
    warning_data_wec: ndarray
        The imported and correctly formatted WEC warning data
    warning_data_rtu: ndarray
        The imported and correctly formatted RTU warning data

    Extra Notes
    -----------
    Both status_wec.csv & status_rtu.csv originally come from
    pes_extrainfo.csv, filtered according to their plant number.
    SCADA_data.csv contains the wsd, 03d and 04d data files all combined
    together.
    """
    # Column layouts: a 19-char timestamp string followed by typed data.
    scada_dtype = ('<U19',) + ('<f4',) * 62
    status_dtype = ('<U19', '<i4', '<i4', '<U9', '<U63', '<i4', '|b1',
                    '|b1', '<f4')
    warning_dtype = ('<U19', '<i4', '<i4', '<U9', '<U63', '|b1', '<f4')

    def _read_csv(path, dtypes):
        # names=True takes the field names from the csv header row.
        # Passing the path (rather than an open handle) lets genfromtxt
        # open and close the file itself, so no file handle is leaked.
        return np.genfromtxt(path, dtype=dtypes, delimiter=",", names=True)

    SCADA = _read_csv(scada_data, scada_dtype)
    status_wec = _read_csv(status_data_wec, status_dtype)
    status_rtu = _read_csv(status_data_rtu, status_dtype)
    warning_wec = _read_csv(warning_data_wec, warning_dtype)
    warning_rtu = _read_csv(warning_data_rtu, warning_dtype)

    # Convert "dd/mm/yyyy hh:mm:ss" strings to Unix timestamps, written
    # back into the (string) 'Time' field; the field is re-typed below.
    # NOTE(review): dt.datetime.fromtimestamp(3600) is local-timezone
    # dependent. This reproduces the original behaviour, which appears
    # to assume timestamps one hour ahead of the epoch -- confirm.
    for data_file in (SCADA, status_rtu, status_wec, warning_rtu,
                      warning_wec):
        time = data_file['Time']
        for i in range(len(time)):
            t = dt.datetime.strptime(time[i], "%d/%m/%Y %H:%M:%S")
            t = (t - dt.datetime.fromtimestamp(3600)).total_seconds()
            time[i] = t

    def _time_to_float(arr):
        # The 'Time' field now holds stringified floats; re-declare the
        # first field as float32 and cast the whole array. (Casting in
        # the loop above doesn't work, hence this second pass.)
        descr = arr.dtype.descr
        descr[0] = (descr[0][0], '<f4')
        return arr.astype(np.dtype(descr))

    SCADA = _time_to_float(SCADA)
    status_wec = _time_to_float(status_wec)
    status_rtu = _time_to_float(status_rtu)
    warning_wec = _time_to_float(warning_wec)
    warning_rtu = _time_to_float(warning_rtu)

    # Add 2 extra columns - Inverter_averages and Inverter_std_dev - as
    # features: per-row mean/std over all inverter cabinet temperatures.
    # (A plain list of names: modern NumPy no longer accepts an ndarray
    # of field names for multifield indexing.)
    inverters = (
        ['CS101__Sys_1_inverter_%d_cabinet_temp' % n for n in range(1, 8)] +
        ['CS101__Sys_2_inverter_%d_cabinet_temp' % n for n in range(1, 5)])
    # pandas (rather than plain numpy) keeps the original semantics:
    # NaNs are skipped and std uses ddof=1 (sample standard deviation).
    inverter_temps = pd.DataFrame({name: SCADA[name] for name in inverters})
    means = inverter_temps.mean(axis=1).values
    stds = inverter_temps.std(axis=1).values
    SCADA = rec.append_fields(SCADA, ['Inverter_averages', 'Inverter_std_dev'],
                              data=[means, stds], usemask=False)

    return SCADA, status_wec, status_rtu, warning_wec, warning_rtu


# ---------------Filtering Functions-----------------------------
def power_curve_filtering(SCADA):
    """The algorithm, taken from [1], is used to label the data by
    filtering using the power curve. It primarily takes a SCADA
    argument, and this is then filtered according to a visual inspection
    of the power curve. The primary outputs are the filtered good and
    bad points on the power curve. All other outputs relate to plotting
    this data nicely in the power_curve_filtered_plot function.

    Parameters
    ----------
    SCADA: ndarray
        The SCADA data to be filtered. Note this must be either the
        SCADA data imported using import_data, or a subset of this data

    Returns
    -------
    SCADA_good_pc: ndarray
        The SCADA data marked as part of the nominal power curve by the
        algorithm
    SCADA_bad_pc: ndarray
        The SCADA data marked as anomalous by the algorithm
    SCADA_bin_averages: ndarray
        The average wind speed for each wind speed bin
    bins: ndarray
        The different wind speed bins
    x2: ndarray
        The x points for the generated interpolated power curve
    y2: ndarray
        The y points for the generated interpolated power curve
    upper_limit_ud:
        The upper limit of the power curve, above which points were
        marked as abnormal
    lower_limit_ud:
        The lower limit of the power curve, below which points were
        marked as abnormal

    References
    ----------
    [1] J. Park, J. Lee, K. Oh, and J. Lee, "Development of a Novel
    Power Curve Monitoring Method for Wind Turbines and Its Field
    Tests",, IEEE Trans. Energy Convers., vol. 29, no. 1, pp. 119-128,
    2014.
    """

    # basic filtering for SCADA_real: valid timestamps, windspeed < 20
    SCADA_real = SCADA[np.where((SCADA['Time'] > 0) &
                                (SCADA['WEC_ava_windspeed'] < 20))]

    # ------------------------------------------------------------------
    # --------------------------Algorithm loop start--------------------
    # ------------------------------------------------------------------
    SCADA_loop = SCADA_real
    # when terminate = True, the algorithm terminates
    terminate = False
    # initialise the standard deviation of the bins (empty array):
    SCADA_bin_stds_avg = np.zeros(200)
    # this increases on every loop iteration:
    k = 1

    while terminate is False:
        # --------1: set windspeed bins, width=1/loop iter no. (k)------
        max_wind = np.ceil(np.nanmax(SCADA_loop['WEC_ava_windspeed']))
        bin_width = 1 / k
        bins = np.arange(0, max_wind + bin_width, bin_width)

        # initialise average and std arrays
        SCADA_bin_averages = np.zeros(len(bins))
        SCADA_bin_stds = np.zeros(len(bins))

        # --------2: get average and std for each bin ------------------
        for i in range(0, len(bins)):
            SCADA_bin_averages[i] = np.mean(SCADA_loop[
                np.where((SCADA_loop['WEC_ava_windspeed'] >=
                          (bins[i] - bin_width / 2)) &
                         (SCADA_loop['WEC_ava_windspeed'] <
                          (bins[i] + bin_width / 2)))]['WEC_ava_Power'])
            SCADA_bin_stds[i] = np.std(SCADA_loop[
                np.where((SCADA_loop['WEC_ava_windspeed'] >=
                          (bins[i] - bin_width / 2)) &
                         (SCADA_loop['WEC_ava_windspeed'] <
                          (bins[i] + bin_width / 2)))]['WEC_ava_Power'])
        SCADA_bin_stds_avg[k] = np.nanmean(SCADA_bin_stds)

        # --------3: create splines-------------------------------------
        x = bins
        y = SCADA_bin_averages

        x2 = np.round(np.arange(0.0, 20.1, 0.1), 1)

        tck = splrep(x, y)

        # replace the spline coefficients with the bin averages
        # themselves (padded to the knot-vector length)
        tck_list = list(tck)
        yl = y.tolist()
        tck_list[1] = yl + [0.0, 0.0, 0.0, 0.0]

        y2 = splev(x2, tck_list)

        # --------4: Find left/right shifts:----------------------------

        # initialise SCADA_cur & PDL (percentage of data inside limits)
        SCADA_cur = SCADA_loop
        PDL = np.zeros(100)
        PDL[0] = 20
        j = 0
        b_shift = .8
        dv = 0.1
        upper_limit_lr = np.zeros(len(x2))
        lower_limit_lr = np.zeros(len(x2))

        while (PDL[j] - PDL[j - 1]) >= b_shift:
            j += 1

            # create shifts
            right_shift = np.round(x2 + (dv * j), 1)
            left_shift = np.round(x2 - (dv * j), 1)

            t_inds = np.array([])

            # find the points which lie inside the upper and lower power
            # limits for the shifted power curves
            for i in range(0, len(x2)):
                # find where the upper and lower power limits are on the
                # left/right curves at the current windspeed. These are
                # the upper/lower "y" values for the corresponding
                # current x value
                upper_limit_lr[i] = np.nansum(
                    y2[np.where(left_shift == x2[i])])

                # this fixes the problem whereby the power curve is
                # shifted left, so the final few "upper_limit_lr" values
                # don't exist, so are shown as zeroes
                if (i >= 150) & (upper_limit_lr[i] == 0):
                    upper_limit_lr[i] = y2[i]

                lower_limit_lr[i] = np.nansum(
                    y2[np.where(right_shift == x2[i])])

                # get indices of points inside these lines
                t_inds_cur = np.where(
                    (SCADA_cur['WEC_ava_windspeed'] == x2[i]) &
                    (SCADA_cur['WEC_ava_Power'] >= lower_limit_lr[i]))
                t_inds = np.concatenate([t_inds, t_inds_cur[0]])

            t_inds = t_inds.astype(int)

            # make an array of these points
            mask = np.zeros(len(SCADA_cur), dtype=bool)
            mask[t_inds] = True
            SCADA_inside_wind = SCADA_cur[mask]

            # calculate PDL
            PDL[j] = (len(SCADA_inside_wind) / len(SCADA_cur)) * 100

        # --------5: Find up/down shifts:-------------------------------

        # initialise SCADA_cur & PDL
        dP = 5
        y_offset = .03
        j = 0
        PDL = np.zeros(300)
        PDL[0] = 1
        SCADA_cur2 = SCADA_cur

        while (PDL[j] - PDL[j - 1]) >= y_offset:
            j += 1

            # create shifts
            upper_limit_ud = upper_limit_lr + (dP * j)
            lower_limit_ud = lower_limit_lr - (dP * j)

            t_inds = np.array([])

            # find the points which lie inside the upper and lower power
            # limits for the shifted power curves
            for i in range(0, len(x2)):
                # get indices of points inside these lines
                t_inds_cur = np.where(
                    (SCADA_cur2['WEC_ava_windspeed'] == x2[i]) &
                    (SCADA_cur2['WEC_ava_Power'] >= lower_limit_ud[i]))
                t_inds = np.concatenate([t_inds, t_inds_cur[0]])

            t_inds = t_inds.astype(int)

            # make an array of these points
            mask = np.zeros(len(SCADA_cur2), dtype=bool)
            mask[t_inds] = True
            SCADA_inside_wind2 = SCADA_cur2[mask]

            # calculate PDL
            PDL[j] = (len(SCADA_inside_wind2) / len(SCADA_cur2)) * 100

        # set the output as the input of the next loop
        SCADA_loop = SCADA_inside_wind2

        # Check if the loop will be terminated: stop once the average
        # bin standard deviation has stopped improving by >= 1
        a_loop = SCADA_bin_stds_avg[k] - SCADA_bin_stds_avg[k - 1]

        if a_loop < 1:
            terminate = True

        k += 1

    # ------------------------------------------------------------------
    # ------------------------Algorithm loop end------------------------
    # ------------------------------------------------------------------

    # list out good SCADA indices:
    SCADA_good_pc = SCADA_inside_wind2

    # list out bad SCADA indices: everything in SCADA_real whose
    # timestamp is not in the good set
    bad_mask = np.ones(len(SCADA_real), dtype=bool)
    for time in SCADA_good_pc['Time']:
        bad_mask[np.where(SCADA_real['Time'] == time)] = False
    SCADA_bad_pc = SCADA_real[bad_mask]

    # NOTE: the original return statement was split across two physical
    # lines without a continuation, silently dropping the last two
    # values; all eight documented values are returned here.
    return (SCADA_good_pc, SCADA_bad_pc, SCADA_bin_averages, bins, x2, y2,
            upper_limit_ud, lower_limit_ud)


def filtering(
    SCADA, filter_file, column_name, time_diff_before=3600,
        time_diff_after=3600, good=True, *filter_codes):
    """This function filters the SCADA data obtained in import_data, or
    a subset of that data, by matching the data with the timestamps, and
    a band around the timestamps, of certain status or warning code
    messages. The end result is SCADA data which corresponds to certain
    operating states or faults.

    Parameters
    ----------
    SCADA: ndarray
        The SCADA data to be filtered. Must be the SCADA data obtained
        from import_data, or a subset of that data
    filter_file: ndarray
        The is one of: warning_rtu, warning_wec, status_rtu, status_wec.
    column_name: string
        Refers to the column being filtered (i.e. "Main_Status" or
        "Full_Status").
    time_diff_before: integer, optional
        The timeband before which to be filtered
    time_diff_after: integer, optional
        Timeband after which to be filtered
    good: Boolean
        Refers to whether the passed filter codes refer to fault data or
        fault-free data. If good=True, then the function will assume the
        filter_codes provided correspond to nominal operation, and will
        be filtered according to [time_of_status + time_diff_before,
        time_of_next_status - time_diff_after]. If good=False, it's
        assumed filter_codes refer to faults, and will be filtered
        according to[time_of_status - time_diff_before,
        time_of_next_status + time_diff_after].
    *filter_codes: array of str or int
        The set of codes to be filtered. Can be a single value or an
        array. If column_name is "Full_Status", then it must be a set of
        strings (e.g. '0 : 0' for nominal operation). Otherwise, an
        integer referring to the Main_Status.

    Returns
    -------
    SCADA_good: ndarray
        If good=True, SCADA_good is data strictly corresponding to
        fault-free data.
        If good=False, SCADA_good is data which isn't faulty data (but
        not necessarily fault-free, i.e. it could include times when the
        turbine was down for routine maintenance or curtailed power
        output, etc.).
    SCADA_bad: ndarray
        If good=True, SCADA_bad is data which isn't strictly fault-free
        (but not necessarily definitely faulty, i.e. it could include
        times when the turbine was down for routine maintenance or
        curtailed power output, etc.).
        If good=False, SCADA_bad is date strictly corresponding to fault
        data.

    """

    # Get the (sorted) indices of filter_file entries which match any of
    # the passed filter codes
    filter_file_indices = np.array([], dtype='i4')
    for filter_code in filter_codes:
        f = np.where((filter_file[column_name] == filter_code))
        filter_file_indices = np.sort(
            np.concatenate((filter_file_indices, f[0]), axis=0))

    # Each matched status entry spans from its own timestamp to the next
    # entry's timestamp. If the last matched entry is also the very last
    # entry overall, there is no "next" entry, so drop it to avoid an
    # out-of-bounds index below.
    if filter_file[-1] == filter_file[filter_file_indices][-1]:
        index_range = range(0, len(filter_file_indices) - 1)
    else:
        index_range = range(0, len(filter_file_indices))

    # For good codes the window is shrunk inwards (start later, finish
    # earlier); for fault codes it is grown outwards (start earlier,
    # finish later), so borderline data is treated as faulty.
    sign = 1 if good is True else -1

    # Collect SCADA indices falling inside any matched window
    SCADA_filtered_indices = np.array([], dtype='i4')
    for i in index_range:
        g1 = np.where(
            (SCADA['Time'] >=
                filter_file['Time'][filter_file_indices[i]] +
                sign * time_diff_before) &
            (SCADA['Time'] <
                filter_file['Time'][filter_file_indices[i] + 1] -
                sign * time_diff_after))
        SCADA_filtered_indices = np.concatenate(
            (SCADA_filtered_indices, g1[0]), axis=0)

    SCADA_filtered_indices = np.unique(SCADA_filtered_indices)

    # Boolean mask of SCADA rows inside any matched window
    in_window = np.zeros(len(SCADA), dtype=bool)
    in_window[SCADA_filtered_indices] = True

    # good codes: matched windows ARE the good data;
    # fault codes: matched windows ARE the bad data
    if good is True:
        SCADA_good = SCADA[in_window]
        SCADA_bad = SCADA[~in_window]
    else:
        SCADA_good = SCADA[~in_window]
        SCADA_bad = SCADA[in_window]

    return SCADA_good, SCADA_bad


def get_fault_data(before, after, scada=None, status=None):
    """This function is a shortcut to get a list of faults using the
    filtering() function. Returns a bunch of different faults.

    Parameters
    ----------
    before: integer
        The time_diff_before to be passed to the filtering() function
    after: integer
        The time_diff_after to be passed to the filtering() function
    scada: ndarray, optional
        The SCADA data to filter. If None, falls back to a global named
        SCADA (the original behaviour, which assumed import_data had
        been run interactively).
    status: ndarray, optional
        The WEC status data. If None, falls back to a global named
        status_wec.

    Returns
    -------
    SCADA_all_faults: ndarray
        An array of all SCADA data corresponding to fault times
    SCADA_feeding_faults: ndarray
        An array of SCADA data corresponding to feeding faults
    SCADA_aircooling_faults: ndarray
        An array of SCADA data corresponding to aircooling faults
    SCADA_excitation_faults: ndarray
        An array of SCADA data corresponding to excitation faults
    SCADA_generator_heating_faults: ndarray
        An array of SCADA data corresponding to generator heating faults
    SCADA_mains_failure_faults: ndarray
        An array of SCADA data corresponding to mains_failure faults

    """
    if scada is None:
        scada = SCADA
    if status is None:
        status = status_wec

    def _faults(*codes):
        # [1] selects SCADA_bad: rows strictly matching the fault codes
        return filtering(scada, status, 'Main_Status', before, after,
                         False, *codes)[1]

    # Main_Status codes: 80=excitation, 62=feeding, 228=aircooling,
    # 60=mains failure, 9=generator heating
    faults = (80, 62, 228, 60, 9)
    SCADA_all_faults = _faults(*faults)
    SCADA_feeding_faults = _faults(62)
    SCADA_mains_failure_faults = _faults(60)
    SCADA_aircooling_faults = _faults(228)
    SCADA_excitation_faults = _faults(80)
    SCADA_generator_heating_faults = _faults(9)

    return SCADA_all_faults, SCADA_feeding_faults, SCADA_aircooling_faults, \
        SCADA_excitation_faults, SCADA_generator_heating_faults, \
        SCADA_mains_failure_faults

# ---------------------Exporting Data-----------------------------------


def export_data(filenames, data):
    """Export a csv of a subset of the SCADA data (e.g. relating to a
    specific fault, or fault-free)

    Parameters
    ----------
    filenames: str or array of strings
        The file name(s) to be exported
    data: ndarray or array of ndarrays
        The corresponding SCADA data to be exported
    """
    for f, d in zip(filenames, data):
        # comma-joined field names form the csv header row
        headings = ",".join(d.dtype.names)
        np.savetxt(
            f, d, delimiter=',', newline='\r\n', header=headings, fmt='%s')

# -------------------------Plot Functions-------------------------------

# These are various different plotting functions used for testing, etc.


def power_curve_filtered_plot(
    SCADA_good, SCADA_bad, SCADA_bin_averages, bins, x2, y2, upper_limit_ud,
        lower_limit_ud):
    """This function generates a nice plot from the data generated in
    power_curve_filtering()

    Parameters
    ----------
    SCADA_good: ndarray
        The "good" SCADA data from power_curve_filtering() to be plotted
    SCADA_bad: ndarray
        The "bad" SCADA data from power_curve_filtering() to be plotted
    SCADA_bin_averages, bins, x2, y2, upper_limit_ud, lower_limit_ud:
        These are all variables from power_curve_filtering().

    See help(power_curve_filtering) for more info, and an explanation of
    the algorithm)

    """
    plt.figure(figsize=(40, 20))

    # the interpolated power curve itself
    plt.plot(x2, y2, 'g', linewidth=3.0)

    # colour good/bad points by the average nacelle ambient temperature
    ava_good_temp = (SCADA_good['CS101__Nacelle_ambient_temp_1'] +
                     SCADA_good['CS101__Nacelle_ambient_temp_2']) / 2
    ava_bad_temp = (SCADA_bad['CS101__Nacelle_ambient_temp_1'] +
                    SCADA_bad['CS101__Nacelle_ambient_temp_2']) / 2

    # linewidth is numeric: the original passed the string '0', which
    # newer matplotlib versions reject
    good_plt = plt.scatter(
        SCADA_good['WEC_ava_windspeed'], SCADA_good['WEC_ava_Power'],
        c=ava_good_temp, cmap=plt.cm.Blues, linewidth=0, s=50)
    plt.scatter(
        SCADA_bad['WEC_ava_windspeed'], SCADA_bad['WEC_ava_Power'],
        c=ava_bad_temp, cmap=plt.cm.Reds, linewidth=0, s=50)

    # upper and lower limits
    plt.plot(x2, upper_limit_ud, 'black', linewidth=1.0)
    plt.plot(x2, lower_limit_ud, 'black', linewidth=1.0)

    # show the bins on top!
    plt.scatter(bins, SCADA_bin_averages, c='r', label='bins', s=100)

    # grid visibility is passed positionally: the keyword was renamed
    # from "b" to "visible" in matplotlib 3.5, positional works in both
    plt.grid(True, which='major', color='b', linestyle='-')
    plt.minorticks_on()
    plt.grid(True, which='minor', color='r', linestyle='--')

    # legend, title, colorbar
    plt.legend(loc='upper left')
    plt.title("Filtered Power Curve")
    plt.colorbar(good_plt)

    plt.show()


def standard_plot(
    SCADA_good, SCADA_fault, title='Power Curve Plot',
        temp=False):
    """Scatter-plot two SCADA subsets on power-curve axes
    (windspeed vs. power).

    Parameters
    ----------
    SCADA_good: ndarray
        SCADA data plotted in blue (or on the Blues colormap)
    SCADA_fault: ndarray
        SCADA data plotted in red (or on the Reds colormap)
    title: str, optional
        The plot title. Default is 'Power Curve Plot'
    temp: Boolean, optional
        If True, colour each point by the average of the two nacelle
        ambient temperatures and draw a colorbar; otherwise use flat
        blue/red. Default is False
    """
    plt.figure(figsize=(40, 20))

    if temp is True:
        # colour by average nacelle ambient temperature
        good_colour = (SCADA_good['CS101__Nacelle_ambient_temp_1'] +
                       SCADA_good['CS101__Nacelle_ambient_temp_2']) / 2
        bad_colour = (SCADA_fault['CS101__Nacelle_ambient_temp_1'] +
                      SCADA_fault['CS101__Nacelle_ambient_temp_2']) / 2
    else:
        good_colour = 'b'
        bad_colour = 'r'

    # linewidth is numeric: the original passed the string '0', which
    # newer matplotlib versions reject
    good_plt = plt.scatter(
        SCADA_good['WEC_ava_windspeed'], SCADA_good['WEC_ava_Power'],
        c=good_colour, cmap=plt.cm.Blues, linewidth=0, s=50)
    plt.scatter(
        SCADA_fault['WEC_ava_windspeed'], SCADA_fault['WEC_ava_Power'],
        c=bad_colour, cmap=plt.cm.Reds, linewidth=0, s=50)

    # grid visibility is passed positionally: the keyword was renamed
    # from "b" to "visible" in matplotlib 3.5, positional works in both
    plt.grid(True, which='major', color='b', linestyle='-')
    plt.minorticks_on()
    plt.grid(True, which='minor', color='r', linestyle='--')

    # legend, title, and (when temperature-coloured) a colorbar
    plt.legend(loc='upper left')
    plt.title(title)

    if temp:
        plt.colorbar(good_plt)

    plt.show()


# --------------------SVM Functions-------------------------------------

def generate_labels(
    no_fault_data, fault_data, features, normalize=True,
        split=0.2):
    """Generate labels for the SCADA data.

    Parameters
    ----------
    no_fault_data: ndarray
        Subset of SCADA data imported using import_data which
        corresponds to fault-free data
    fault_data: ndarray
        Subset of SCADA data imported using import_data which
        corresponds to fault data. Examples include SCADA_all_faults,
        SCADA_feeding_faults, etc.
    features: array of strings
        set of features used in the dataset.
    normalize: Boolean, optional
        Whether or not to normalize the training data. Default is True.
    split: float, optional
        The ratio of testing : training data to use. Default is 0.2

    Returns
    -------
    X_train: ndarray
        The set of data to be trained on
    y_train: ndarray
        The associated labels
    X_test: ndarray
        The set of data to be tested on
    y_test: ndarray
        The associated labels
    X_train_bal: ndarray
        Used for balanced training data (i.e. no. fault class=no.
        of fault-free class)
    y_train_bal: ndarray
        Used for balanced testing data
    """

    # label 0 = fault-free, 1 = faulty. (dtype=int rather than np.int:
    # the np.int alias was removed in NumPy >= 1.24)
    good_labels = np.zeros(len(no_fault_data), dtype=int)
    bad_labels = np.ones(len(fault_data), dtype=int)

    # append the appropriate labels to the data
    good = rec.append_fields(
        no_fault_data[features], ['label'], data=[good_labels], usemask=False)
    bad = rec.append_fields(
        fault_data[features], ['label'], data=[bad_labels], usemask=False)

    # join all the data together and shuffle it
    dataset = np.concatenate([good, bad])
    np.random.shuffle(dataset)

    # separate the training data from the labels, and normalize.
    # NOTE(review): the .view(np.float32) trick assumes every feature
    # field is a 4-byte float ('<f4'), as produced by import_data --
    # confirm if used with other dtypes
    y = dataset['label']
    X = rec.drop_fields(dataset, ['label'], False).view(
        np.float32).reshape(len(dataset), len(dataset.dtype) - 1)

    # Create Training and Test Sets
    if normalize is True:
        X = prep.normalize(X)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=split)

    # Create the balanced training sets: take as many fault-free rows
    # as there are fault rows
    X_train_bad = X_train[np.where(y_train == 1)]
    y_train_bad = y_train[np.where(y_train == 1)]
    n_bad = len(X_train_bad)

    X_train_good_bal = X_train[np.where(y_train == 0)][0:n_bad]
    y_train_good_bal = y_train[np.where(y_train == 0)][0:n_bad]

    # stack features and labels into one array so they shuffle together
    X_train_bal_unshuffled = np.concatenate([X_train_good_bal, X_train_bad])
    y_train_bal_unshuffled = np.concatenate([y_train_good_bal, y_train_bad])
    balanced_training_data = np.append(
        X_train_bal_unshuffled, np.array([y_train_bal_unshuffled]).T, axis=1)
    np.random.shuffle(balanced_training_data)

    # last column is the label, everything before it is a feature
    # (the original hard-coded column 29, breaking for any other
    # feature count)
    y_train_bal = balanced_training_data[:, -1]
    X_train_bal = balanced_training_data[:, :-1]

    return X_train, X_test, y_train, y_test, X_train_bal, y_train_bal