import h5py
import math
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from bokeh.io import export_png
from bokeh.models import Label, Span, Title, ColumnDataSource
from bokeh.plotting import figure
from argparse import ArgumentParser
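
# This script renders one annotated raw-signal plot (PNG) per fused read
# listed in a whale_watch.py fused-read file, using an ONT bulk fast5 file.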

def get_annotations(path, fields, enum_field):
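    """Load the given fields of an HDF5 compound dataset into a DataFrame.

    Returns the DataFrame and, when enum_field is enum-typed, a dict
    mapping each enum value to its name (an empty dict otherwise).
    """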
    data_labels = {field: path[field] for field in fields}
    data_dtypes = {}
    enum_dtype = h5py.check_dtype(enum=path.dtype[enum_field])
    if enum_dtype:
        # Invert the enum mapping from {name: value} to {value: name};
        # if two names share a value, only one survives the inversion.
        data_dtypes = {v: k for k, v in enum_dtype.items()}
    labels_df = pd.DataFrame(data=data_labels)
    return labels_df, data_dtypes


def create_figure(ch, s, e, sf, file, filename):
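    """Plot the raw signal for one channel over a time window, with
    read and channel-state annotations.

    ch -- channel number
    s, e -- window start and end, in seconds
    sf -- sample frequency, in samples per second
    file -- an open h5py.File handle for the bulk fast5 file
    filename -- path to the bulk file, shown in the plot title
    """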
    # convert the start/end times (seconds) to sample indices
    sq = int(s * sf)
    eq = int(e * sf)
    n = e - s
    ch_str = 'Channel_{n}'.format(n=ch)
    # read only the requested slice of the signal from the bulk file
    y_data = file["Raw"][ch_str]["Signal"][sq:eq]
    # build the time axis (in seconds) from sample indices so x and y
    # lengths always match
    x_data = np.arange(sq, sq + y_data.size) / sf
    # thin the trace so the number of plotted points stays near
    # sf * e^2.5 regardless of the window length
    thin_factor = math.ceil(n / math.e ** 2.5)
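    # drop samples outside the expected signal range (above 1500 or below 0)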
    greater_delete_index = np.argwhere(y_data > 1500)
    x_data = np.delete(x_data, greater_delete_index)
    y_data = np.delete(y_data, greater_delete_index)

    lesser_delete_index = np.argwhere(y_data < 0)
    x_data = np.delete(x_data, lesser_delete_index)
    y_data = np.delete(y_data, lesser_delete_index)
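    # downsample both axes by the thinning factor before plotting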
    data = {
        'x': x_data[::thin_factor],
        'y': y_data[::thin_factor],
    }
    source = ColumnDataSource(data=data)

    p = figure(
        plot_height=965,
        plot_width=800
    )
    p.output_backend = 'canvas'
    p.add_layout(
        Title(text="Channel: {ch} Start: {st} End: {ed} Sample rate: {sf}".format(
            ch=ch,
            st=s,
            ed=e,
            sf=sf
        )),
        'above'
    )
    p.add_layout(
        Title(text="Bulk-file: {s}".format(s=Path(filename).name)),
        'above'
    )

    p.toolbar.logo = None
    p.yaxis.axis_label = "Raw signal"
    p.yaxis.major_label_orientation = "horizontal"
    p.xaxis.axis_label = "Time (seconds)"
    p.line(source=source, x='x', y='y', line_width=1)
    p.xaxis.major_label_orientation = math.radians(45)
    p.x_range.range_padding = 0.05

    # read annotations from the per-channel read table
    path = file["IntermediateData"][ch_str]["Reads"]
    fields = ['read_id', 'read_start', 'modal_classification']
    label_df, label_dt = get_annotations(path, fields, 'modal_classification')
    label_df = label_df.drop_duplicates(subset=['read_id', 'modal_classification'], keep="first")
    # convert read_start from a sample index to seconds
    label_df['read_start'] = label_df['read_start'] / sf
    label_df['read_id'] = label_df['read_id'].str.decode('utf8')

    path = file["StateData"][ch_str]["States"]
    fields = ['acquisition_raw_index', 'summary_state']
    state_label_df, state_label_dtypes = get_annotations(path, fields, 'summary_state')
    state_label_df.acquisition_raw_index = state_label_df.acquisition_raw_index / sf
    state_label_df = state_label_df.rename(
        columns={'acquisition_raw_index': 'read_start', 'summary_state': 'modal_classification'}
    )
    label_df = label_df.append(state_label_df, ignore_index=True)
    label_df.sort_values(by='read_start', ascending=True, inplace=True)
    label_dt.update(state_label_dtypes)
    # Here labels are thinned out
    slim_label_df = label_df[(label_df['read_start'] >= s) & (label_df['read_start'] <= e)]
    for index, label in slim_label_df.iterrows():
        # dashed vertical line marking where each read/state starts
        event_line = Span(
            location=label.read_start,
            dimension='height',
            line_color='green',
            line_dash='dashed',
            line_width=1
        )
        p.add_layout(event_line)
        labels = Label(
            x=label.read_start,
            y=800,
            text="{cl} - {ri}".format(cl=label_dt[label.modal_classification], ri=label.read_id),
            level='glyph',
            x_offset=0,
            y_offset=0,
            render_mode='canvas',
            angle=-300,
            # bokeh interprets angles in radians by default; this one is meant in degrees
            angle_units='deg'
        )
        p.add_layout(labels)
    return p


def get_args():
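    """Parse and return command-line arguments."""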
    parser = ArgumentParser(
        description="""Generate plots for all reads in a fused_reads.txt file.
        This uses bokeh to render each plot and requires selenium, phantomjs,
        and Pillow to be installed; these are available via conda/pip.""",
        add_help=False)
    general = parser.add_argument_group(
        title='General options')
    general.add_argument("-h", "--help",
                         action="help",
                         help="Show this help and exit"
                         )
    in_args = parser.add_argument_group(
        title='Input sources'
    )
    in_args.add_argument("-f", "--fused",
                         help="A fused read file generated by whale_watch.py",
                         type=str,
                         default="",
                         required=True,
                         metavar=''
                         )
    in_args.add_argument("-b", "--bulk-file",
                         help="An ONT bulk-fast5-file",
                         type=str,
                         default='',
                         required=True,
                         metavar=''
                         )
    out_args = parser.add_argument_group(
        title='Output files'
    )
    out_args.add_argument('-D', '--out-dir',
                          help='''Specify the output directory where plots will be saved. Defaults to current working
                                  directory''',
                          type=str,
                          default='',
                          metavar=''
                          )
    return parser.parse_args()


def main():
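    """Plot every fused read that matches the bulk file's run_id and save one PNG per read."""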
    args = get_args()
    # create the output directory if it doesn't exist
    out_dir = Path(args.out_dir).resolve()
    out_dir.mkdir(parents=True, exist_ok=True)
    # open the bulk fast5 file
    bulkfile = h5py.File(args.bulk_file, "r")
    # get the run_id and sample frequency from the file metadata
    sf = int(bulkfile["UniqueGlobalKey"]["context_tags"].attrs["sample_frequency"].decode('utf8'))
    run_id = bulkfile["UniqueGlobalKey"]["tracking_id"].attrs["run_id"].decode('utf8')
    fields = ['coords', 'run_id']
    fused_df = pd.read_csv(args.fused, sep='\t', usecols=fields)
    # keep only reads from the run that matches this bulk file
    fused_df = fused_df[fused_df['run_id'] == run_id]
    for index, row in tqdm(fused_df.iterrows(), total=fused_df.shape[0]):
        # coords are formatted as <channel>:<start>-<end>
        line = row['coords']
        coords = line.split(":")
        times = coords[1].split("-")
        channel_num = coords[0]
        start_time, end_time = int(times[0]), int(times[1])
        plot = create_figure(channel_num, start_time, end_time, sf, bulkfile, args.bulk_file)
        write_path = out_dir / '{c}.png'.format(c=line.replace(':', '_'))
        export_png(plot, filename=str(write_path))
    bulkfile.close()


if __name__ == '__main__':
    main()