"""
DeepLabCut2.2 Toolbox (deeplabcut.org)
© A. & M. Mathis Labs
https://github.com/AlexEMG/DeepLabCut
Please see AUTHORS for contributors.

https://github.com/AlexEMG/DeepLabCut/blob/master/AUTHORS
Licensed under GNU Lesser General Public License v3.0
"""

import os
import warnings

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.collections import LineCollection
from matplotlib.path import Path
from matplotlib.widgets import Button, LassoSelector
from ruamel.yaml import YAML
from scipy.spatial import cKDTree as KDTree
from skimage import io


def read_config(configname):
    if not os.path.exists(configname):
        raise FileNotFoundError(
            f"Config {configname} is not found. Please make sure that the file exists."
        )
    with open(configname) as file:
        return YAML().load(file)


def write_config(configname, cfg):
    with open(configname, "w") as file:
        YAML().dump(cfg, file)


class SkeletonBuilder:
    def __init__(self, config_path):
        self.config_path = config_path
        self.cfg = read_config(config_path)
        # Find uncropped labeled data
        root = os.path.join(self.cfg["project_path"], "labeled-data")
        for dir_ in os.listdir(root):
            folder = os.path.join(root, dir_)
            if os.path.isdir(folder) and not any(
                folder.endswith(s) for s in ("cropped", "labeled")
            ):
                break
        self.df = pd.read_hdf(
            os.path.join(folder, f'CollectedData_{self.cfg["scorer"]}.h5')
        )
        # Handle data previously annotated on a different platform
        sep = "/" if "/" in self.df.index[0] else "\\"
        if sep != os.path.sep:
            self.df.index = self.df.index.str.replace(sep, os.path.sep)
        row, col = self.pick_labeled_frame()
        if "individuals" in self.df.columns.names:
            self.df = self.df.xs(col, axis=1, level="individuals")
        self.bpts = self.df.columns.get_level_values("bodyparts").unique()
        self.xy = self.df.loc[row].values.reshape((-1, 2))
        missing = np.flatnonzero(np.isnan(self.xy).all(axis=1))
        if missing.any():
            warnings.warn(
                f"A fully labeled animal could not be found. "
                f"{', '.join(self.bpts[missing])} will need to be manually connected in the config.yaml."
            )
        self.tree = KDTree(self.xy)
        self.image = io.imread(os.path.join(self.cfg["project_path"], row))
        self.inds = set()
        self.segs = set()
        # Draw the skeleton if already existent
        if self.cfg["skeleton"]:
            for bone in self.cfg["skeleton"]:
                pair = np.flatnonzero(self.bpts.isin(bone))
                if len(pair) != 2:
                    continue
                pair_sorted = tuple(sorted(pair))
                self.inds.add(pair_sorted)
                self.segs.add(tuple(map(tuple, self.xy[pair_sorted, :])))
        self.lines = LineCollection(
            self.segs, colors=mcolors.to_rgba(self.cfg["skeleton_color"])
        )
        self.lines.set_picker(True)
        self.show()

    def pick_labeled_frame(self):
        # Find the most 'complete' animal
        try:
            count = self.df.groupby(level="individuals", axis=1).count()
            if "single" in count:
                count.drop("single", axis=1, inplace=True)
        except KeyError:
            count = self.df.count(axis=1).to_frame()
        mask = count.where(count == count.values.max())
        kept = mask.stack().index.to_list()
        np.random.shuffle(kept)
        row, col = kept.pop()
        return row, col

    def show(self):
        self.fig = plt.figure()
        ax = self.fig.add_subplot(111)
        ax.axis("off")
        lo = np.nanmin(self.xy, axis=0)
        hi = np.nanmax(self.xy, axis=0)
        center = (hi + lo) / 2
        w, h = hi - lo
        ampl = 1.3
        w *= ampl
        h *= ampl
        ax.set_xlim(center[0] - w / 2, center[0] + w / 2)
        ax.set_ylim(center[1] - h / 2, center[1] + h / 2)
        ax.imshow(self.image)
        ax.scatter(*self.xy.T, s=self.cfg["dotsize"] ** 2)
        ax.add_collection(self.lines)
        ax.invert_yaxis()

        self.lasso = LassoSelector(ax, onselect=self.on_select)
        ax_clear = self.fig.add_axes([0.85, 0.55, 0.1, 0.1])
        ax_export = self.fig.add_axes([0.85, 0.45, 0.1, 0.1])
        self.clear_button = Button(ax_clear, "Clear")
        self.clear_button.on_clicked(self.clear)
        self.export_button = Button(ax_export, "Export")
        self.export_button.on_clicked(self.export)
        self.fig.canvas.mpl_connect("pick_event", self.on_pick)
        plt.show()

    def clear(self, *args):
        self.inds.clear()
        self.segs.clear()
        self.lines.set_segments(self.segs)

    def export(self, *args):
        inds_flat = set(ind for pair in self.inds for ind in pair)
        unconnected = [i for i in range(len(self.xy)) if i not in inds_flat]
        if len(unconnected):
            warnings.warn(
                f'Unconnected {", ".join(self.bpts[unconnected])}. '
                f"It is desirable that all bodyparts be connected for multi-animal projects."
            )
        self.cfg["skeleton"] = [tuple(self.bpts[list(pair)]) for pair in self.inds]
        write_config(self.config_path, self.cfg)

    def on_pick(self, event):
        if event.mouseevent.button == 3:
            removed = event.artist.get_segments().pop(event.ind[0])
            self.segs.remove(tuple(map(tuple, removed)))
            self.inds.remove(tuple(self.tree.query(removed)[1]))

    def on_select(self, verts):
        self.path = Path(verts)
        self.verts = verts
        inds = self.tree.query_ball_point(verts, 5)
        inds_unique = []
        for lst in inds:
            if len(lst) and lst[0] not in inds_unique:
                inds_unique.append(lst[0])
        for pair in zip(inds_unique, inds_unique[1:]):
            pair_sorted = tuple(sorted(pair))
            self.inds.add(pair_sorted)
            self.segs.add(tuple(map(tuple, self.xy[pair_sorted, :])))
        self.lines.set_segments(self.segs)
        self.fig.canvas.draw_idle()