#!/usr/bin/env python
"""Processing input and plotting."""

from __future__ import division, print_function
import numpy as np
import segmentator.config as cfg
import matplotlib
matplotlib.use(cfg.matplotlib_backend)
print("Matplotlib backend: {}".format(matplotlib.rcParams['backend']))
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib.widgets import Slider, Button, LassoSelector
from matplotlib import path
from nibabel import load
from segmentator.utils import map_ima_to_2D_hist, prep_2D_hist
from segmentator.utils import truncate_range, scale_range, check_data
from segmentator.utils import set_gradient_magnitude
from segmentator.utils import export_gradient_magnitude_image
from segmentator.gui_utils import sector_mask, responsiveObj
from segmentator.config_gui import palette, axcolor, hovcolor

#
"""Data Processing"""
# Load the input volume. nibabel deprecated ``get_data()`` (removed in 5.0);
# ``np.asanyarray(img.dataobj)`` is the documented replacement that returns
# the same array without forcing a float conversion.
nii = load(cfg.filename)
orig, dims = check_data(np.asanyarray(nii.dataobj),
                        cfg.force_original_precision)
# Save min and max truncation thresholds to be used in axis labels
if np.isnan(cfg.valmin) or np.isnan(cfg.valmax):
    # No explicit value range configured: truncate by percentiles instead.
    orig, pMin, pMax = truncate_range(orig, percMin=cfg.perc_min,
                                      percMax=cfg.perc_max)
else:  # TODO: integrate this into truncate range function
    # Explicit range configured: clamp the data in place to [valmin, valmax].
    orig[orig < cfg.valmin] = cfg.valmin
    orig[orig > cfg.valmax] = cfg.valmax
    pMin, pMax = cfg.valmin, cfg.valmax

# Continue with scaling the original truncated image and recomputing gradient
orig = scale_range(orig, scale_factor=cfg.scale, delta=0.0001)
gra = set_gradient_magnitude(orig, cfg.gramag)
if cfg.export_gramag:
    export_gradient_magnitude_image(gra, nii.get_filename(), cfg.gramag,
                                    nii.affine)
# Reshape for voxel-wise operations. ``flatten()`` always returns a copy, so
# ``orig`` keeps its shape and no additional ``np.copy`` is required.
ima = orig.flatten()
gra = gra.flatten()

#
"""Plots"""
print("Preparing GUI...")
# Plot 2D histogram (left panel)
fig = plt.figure(facecolor='0.775')
ax = fig.add_subplot(121)

counts, volHistH, d_min, d_max, nr_bins, bin_edges \
    = prep_2D_hist(ima, gra, discard_zeros=cfg.discard_zeros)

# Set x-y axis range to the same (x-axis range)
ax.set_xlim(d_min, d_max)
ax.set_ylim(d_min, d_max)
ax.set_xlabel("Intensity f(x)")
ax.set_ylabel("Gradient Magnitude f'(x)")
ax.set_title("2D Histogram")

# Plot colorbar for 2D hist (logarithmic, upper limit 10**cbar_init)
volHistH.set_norm(LogNorm(vmax=np.power(10, cfg.cbar_init)))
fig.colorbar(volHistH, fraction=0.046, pad=0.04)  # magical scaling

# Plot 3D ima by default (right panel), starting at the middle slice of the
# third axis
ax2 = fig.add_subplot(122)
sliceNr = int(0.5*dims[2])
imaSlcH = ax2.imshow(orig[:, :, sliceNr], cmap=plt.cm.gray, vmin=ima.min(),
                     vmax=ima.max(), interpolation='none',
                     extent=[0, dims[1], dims[0], 0], zorder=0)

# Semi-transparent mask overlay drawn on top of the image slice (zorder=1)
imaSlcMsk = np.ones(dims[0:2])
imaSlcMskH = ax2.imshow(imaSlcMsk, cmap=palette, vmin=0.1,
                        interpolation='none', alpha=0.5,
                        extent=[0, dims[1], dims[0], 0], zorder=1)

# Adjust subplots on figure (leave room at the bottom for the widgets)
bottom = 0.30
fig.subplots_adjust(bottom=bottom)
# Set the window title through the figure manager: the canvas-level
# ``set_window_title`` was deprecated in matplotlib 3.4 and removed in 3.6,
# while the manager method is available in old and new releases alike.
figManager = getattr(fig.canvas, 'manager', None)
if figManager is not None:
    figManager.set_window_title(nii.get_filename())
# Acts on the current axes, which is ax2 (the last axes created above)
plt.axis('off')

#
"""Initialisation"""
# Create first instance of sector mask, sized like the 2D histogram
# (nr_bins x nr_bins) and configured from the initial centre/radius/theta
sectorObj = sector_mask((nr_bins, nr_bins), cfg.init_centre, cfg.init_radius,
                        cfg.init_theta)

# Draw sector mask for the first time on the histogram axes; returns the
# image handle and the mask array
volHistMaskH, volHistMask = sectorObj.draw(ax, cmap=palette, alpha=0.2,
                                           vmin=0.1, interpolation='nearest',
                                           origin='lower', zorder=1,
                                           extent=[0, nr_bins, 0, nr_bins])

# Initiate a flexible figure object, pass to it useful properties
# idxLasso: flat boolean selection, one entry per histogram bin
idxLasso = np.zeros(nr_bins*nr_bins, dtype=bool)
# lassoSwitchCount: 0 while the lasso tool is off, 1 while it is on
# (toggled by lassoSwitch)
lassoSwitchCount = 0
lassoErase = 1  # 1 for drawing, 0 for erasing
flexFig = responsiveObj(figure=ax.figure, axes=ax.axes, axes2=ax2.axes,
                        segmType='main', orig=orig, nii=nii,
                        sectorObj=sectorObj,
                        nrBins=nr_bins,
                        sliceNr=sliceNr,
                        imaSlcH=imaSlcH,
                        imaSlcMsk=imaSlcMsk, imaSlcMskH=imaSlcMskH,
                        volHistMask=volHistMask, volHistMaskH=volHistMaskH,
                        contains=volHistMaskH.contains,
                        counts=counts,
                        idxLasso=idxLasso,
                        lassoSwitchCount=lassoSwitchCount,
                        lassoErase=lassoErase)

# Make the figure responsive to clicks
flexFig.connect()
# Map every voxel (intensity, gradient magnitude) onto the 2D histogram bins
# and reshape the result back to the image dimensions.
# NOTE(review): presumably holds one histogram bin index per voxel - confirm
# against map_ima_to_2D_hist.
ima2volHistMap = map_ima_to_2D_hist(xinput=ima, yinput=gra, bins_arr=bin_edges)
flexFig.invHistVolume = np.reshape(ima2volHistMap, dims)
# Drop the references to the flattened arrays; they are not used below
# (presumably to release memory - confirm)
ima, gra = None, None

#
"""Sliders and Buttons"""
# Every widget gets its own axes; the rectangles are given as
# [left, bottom, width, height] in figure coordinates and are placed relative
# to ``bottom``, the margin reserved below the subplots. The widgets are
# stored as attributes on flexFig so its callbacks can reach them.

# Colorbar slider (range 1..cfg.cbar_max, starts at cfg.cbar_init)
axHistC = plt.axes([0.15, bottom-0.20, 0.25, 0.025], facecolor=axcolor)
flexFig.sHistC = Slider(axHistC, 'Colorbar', 1, cfg.cbar_max,
                        valinit=cfg.cbar_init, valfmt='%0.1f')

# Image browser slider
# NOTE(review): the 0..0.999 range presumably expresses the slice position as
# a fraction of the volume depth - confirm in updateImaBrowser.
axSliceNr = plt.axes([0.6, bottom-0.15, 0.25, 0.025], facecolor=axcolor)
flexFig.sSliceNr = Slider(axSliceNr, 'Slice', 0, 0.999, valinit=0.5,
                          valfmt='%0.2f')

# Theta sliders (angles in the range 0..359.9)
aThetaMin = plt.axes([0.15, bottom-0.10, 0.25, 0.025], facecolor=axcolor)
flexFig.sThetaMin = Slider(aThetaMin, 'ThetaMin', 0, 359.9,
                           valinit=cfg.init_theta[0], valfmt='%0.1f')
aThetaMax = plt.axes([0.15, bottom-0.15, 0.25, 0.025], facecolor=axcolor)
# The initial maximum is lowered by 0.1 (presumably so that an initial theta
# of 360 still fits the slider's 359.9 upper limit - confirm)
flexFig.sThetaMax = Slider(aThetaMax, 'ThetaMax', 0, 359.9,
                           valinit=cfg.init_theta[1]-0.1, valfmt='%0.1f')

# Cycle button
cycleax = plt.axes([0.55, bottom-0.2475, 0.075, 0.0375])
flexFig.bCycle = Button(cycleax, 'Cycle',
                        color=axcolor, hovercolor=hovcolor)

# Rotate button
rotateax = plt.axes([0.55, bottom-0.285, 0.075, 0.0375])
flexFig.bRotate = Button(rotateax, 'Rotate',
                         color=axcolor, hovercolor=hovcolor)

# Reset button
resetax = plt.axes([0.65, bottom-0.285, 0.075, 0.075])
flexFig.bReset = Button(resetax, 'Reset', color=axcolor, hovercolor=hovcolor)

# Export nii button
exportax = plt.axes([0.75, bottom-0.285, 0.075, 0.075])
flexFig.bExport = Button(exportax, 'Export\nNifti',
                         color=axcolor, hovercolor=hovcolor)

# Export nyp button
# NOTE: ``exportax`` is rebound here; the Nifti export button above was
# already given the previous axes, so only this module-level name changes.
exportax = plt.axes([0.85, bottom-0.285, 0.075, 0.075])
flexFig.bExportNyp = Button(exportax, 'Export\nHist',
                            color=axcolor, hovercolor=hovcolor)

#
"""Updates"""
# Slider callbacks, registered in the order listed
for _slider, _callback in (
        (flexFig.sHistC, flexFig.updateColorBar),
        (flexFig.sSliceNr, flexFig.updateImaBrowser),
        (flexFig.sThetaMin, flexFig.updateThetaMin),
        (flexFig.sThetaMax, flexFig.updateThetaMax)):
    _slider.on_changed(_callback)

# Button callbacks, registered in the order listed
for _button, _callback in (
        (flexFig.bCycle, flexFig.cycleView),
        (flexFig.bRotate, flexFig.changeRotation),
        (flexFig.bExport, flexFig.exportNifti),
        (flexFig.bExportNyp, flexFig.exportNyp),
        (flexFig.bReset, flexFig.resetGlobal)):
    _button.on_clicked(_callback)


# TODO: Temporary solution for displaying original x-y axis labels
def update_axis_labels(event):
    """Relabel the histogram ticks with values of the original data range.

    The current number of x tick labels is kept; the labels themselves are
    replaced by evenly spaced values between pMin and pMax. The same labels
    are applied to the y axis, whose limits are assumed to equal those of
    the x axis. ``event`` is the matplotlib event that triggered the call and
    is not used.
    """
    nr_ticks = len(ax.get_xticklabels())
    tick_values = np.linspace(pMin, pMax, nr_ticks)

    # Fewer decimals for wider data ranges (thresholds are arbitrary)
    span = pMax - pMin
    precision_table = ((200, '%i'), (20, '%.1f'), (2, '%.2f'))
    fmt = next((f for limit, f in precision_table if span > limit), '%.3f')

    labels = [fmt % value for value in tick_values]
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)


# Refresh the tick labels whenever the figure is resized
fig.canvas.mpl_connect('resize_event', update_axis_labels)

#
"""Lasso selection"""
# Lasso button (toggles the lasso tool, see lassoSwitch)
lassoax = plt.axes([0.15, bottom-0.285, 0.075, 0.075])
bLasso = Button(lassoax, 'Lasso\nOff', color=axcolor, hovercolor=hovcolor)

# Lasso draw/erase button (toggles erase mode, see lassoEraseSwitch)
lassoEraseAx = plt.axes([0.25, bottom-0.285, 0.075, 0.075])
bLassoErase = Button(lassoEraseAx, 'Erase\nOff', color=axcolor,
                     hovercolor=hovcolor)
# Start with the erase button hidden; lassoSwitch shows it while the lasso
# tool is enabled
bLassoErase.ax.patch.set_visible(False)
bLassoErase.label.set_visible(False)
bLassoErase.ax.axis('off')


def lassoSwitch(event):
    """Toggle the lasso tool on or off.

    While the lasso is on, the sector mask drag handling of flexFig is
    disconnected, a LassoSelector is attached to the histogram axes and the
    erase button is shown. Switching it off reverses all of that. ``event``
    is the button click event and is not used.
    """
    global lasso
    lasso = []  # drop any previous selector
    flexFig.lassoSwitchCount = (flexFig.lassoSwitchCount+1) % 2
    lassoOn = flexFig.lassoSwitchCount == 1

    if lassoOn:
        flexFig.disconnect()  # disable drag function of sector mask
        lasso = LassoSelector(ax, onselect)
        bLasso.label.set_text("Lasso\nOn")
    else:
        flexFig.connect()  # enable drag function of sector mask
        bLasso.label.set_text("Lasso\nOff")

    # The erase button is only visible while the lasso tool is enabled
    bLassoErase.ax.patch.set_visible(lassoOn)
    bLassoErase.label.set_visible(lassoOn)
    bLassoErase.ax.axis('on' if lassoOn else 'off')

# Pixel coordinates: one (x, y) row per histogram bin, in row-major order
binIdx = np.arange(nr_bins)
xv, yv = np.meshgrid(binIdx, binIdx)
pix = np.column_stack((xv.ravel(), yv.ravel()))


def onselect(verts):
    """Apply a finished lasso stroke to the histogram selection.

    ``verts`` are the vertices of the drawn lasso. Every bin coordinate in
    ``pix`` that lies inside the lasso path is set to the current
    flexFig.lassoErase value (1 for drawing, 0 for erasing), after which the
    masks and panels are refreshed.
    """
    lassoPath = path.Path(verts)
    insideLasso = lassoPath.contains_points(pix, radius=1.5)
    flexFig.idxLasso[insideLasso] = flexFig.lassoErase
    # Update volume histogram mask, then redraw without changing the slice
    flexFig.remapMsks()
    flexFig.updatePanels(update_slice=False, update_rotation=True,
                         update_extent=True)


def lassoEraseSwitch(event):
    """Toggle the lasso between drawing and erasing.

    flexFig.lassoErase alternates between 1 (drawing) and 0 (erasing); the
    button label is updated to match. ``event`` is the button click event
    and is not used.
    """
    flexFig.lassoErase = (flexFig.lassoErase + 1) % 2
    # Compare by value: ``is`` against an int literal tests object identity,
    # which only works by accident of CPython's small-integer caching and
    # raises a SyntaxWarning on Python 3.8+.
    if flexFig.lassoErase == 1:
        bLassoErase.label.set_text("Erase\nOff")
    elif flexFig.lassoErase == 0:
        bLassoErase.label.set_text("Erase\nOn")


# Hook up the lasso buttons to their toggle handlers
bLasso.on_clicked(lassoSwitch)  # lasso on/off
bLassoErase.on_clicked(lassoEraseSwitch)  # lasso erase on/off
# Compute the masks once and draw the initial state of the panels before
# the window is shown
flexFig.remapMsks()
flexFig.updatePanels(update_slice=True, update_rotation=False,
                     update_extent=False)

print("GUI is ready.")
# Hand control to the matplotlib event loop
plt.show()