from .myqt import QT
import pyqtgraph as pg

import numpy as np
import time

from .base import WidgetBase
from .tools import TimeSeeker
from ..tools import median_mad
from ..dataio import _signal_types
from ..peeler_tools import make_prediction_signals


class MyViewBox(pg.ViewBox):
    """ViewBox that forwards mouse gestures as signals.

    None of the handlers call the ``pg.ViewBox`` implementation; they only
    accept the event and emit one of the signals below.

    Signals:
      * ``doubleclicked()``: emitted on a double click.
      * ``gain_zoom(float)``: emitted on a wheel event with a multiplicative
        factor (10 or 1/10 when Ctrl is the modifier, 1.3 or 1/1.3 otherwise).
      * ``xsize_zoom(float)``: emitted on a drag event with the horizontal
        displacement between ``ev.lastPos()`` and ``ev.pos()``.
    """
    doubleclicked = QT.pyqtSignal()
    gain_zoom = QT.pyqtSignal(float)
    xsize_zoom = QT.pyqtSignal(float)

    def __init__(self, *args, **kwds):
        pg.ViewBox.__init__(self, *args, **kwds)

    def mouseClickEvent(self, ev):
        # Single clicks are accepted and otherwise ignored.
        ev.accept()

    def mouseDoubleClickEvent(self, ev):
        self.doubleclicked.emit()
        ev.accept()

    def wheelEvent(self, ev, axis=None):
        # Ctrl + wheel zooms by a factor of 10, plain wheel by a factor of 1.3.
        if ev.modifiers() == QT.Qt.ControlModifier:
            z = 10 if ev.delta()>0 else 1/10.
        else:
            z = 1.3 if ev.delta()>0 else 1/1.3
        self.gain_zoom.emit(z)
        ev.accept()

    # A first ``mouseDragEvent`` that only called ``ev.ignore()`` used to be
    # defined above; it was shadowed by this definition and never ran.
    def mouseDragEvent(self, ev):
        ev.accept()
        self.xsize_zoom.emit((ev.pos()-ev.lastPos()).x())


class BaseTraceViewer(WidgetBase):
    
    # Viewer settings. NOTE(review): presumably turned into ``self.params`` by
    # WidgetBase (the methods of this class read these names from
    # ``self.params``) -- confirm in WidgetBase.
    #   auto_zoom_on_select : zoom on the spike when exactly one spike is selected
    #   zoom_size           : x size applied by that automatic zoom
    #   plot_threshold      : show or hide the per-channel threshold lines
    #   alpha               : opacity of the spike markers (scaled to 0-255 in seek)
    #   xsize_max           : upper bound of the "X size" spinbox
    _params = [{'name': 'auto_zoom_on_select', 'type': 'bool', 'value': True },
                       {'name': 'zoom_size', 'type': 'float', 'value':  0.08, 'step' : 0.001 },
                      {'name': 'plot_threshold', 'type': 'bool', 'value':  True },
                      {'name': 'alpha', 'type': 'float', 'value' : 0.8, 'limits':(0, 1.), 'step':0.05 },
                      {'name': 'xsize_max', 'type': 'float', 'value': 4.0, 'step': 1.0, 'limits':(1.0, np.inf)},
                      ]
    
    def __init__(self, controller=None, signal_type='initial', parent=None):
        """Build the toolbar, the plot area with its two scroll bars, then show segment 0.

        ``controller`` must provide ``dataio``; ``signal_type`` is the signal
        type displayed at start-up.
        """
        WidgetBase.__init__(self, parent=parent, controller=controller)

        self.signal_type = signal_type
        self.dataio = controller.dataio

        main_layout = QT.QVBoxLayout()
        self.layout = main_layout
        self.setLayout(main_layout)

        self.create_toolbar()

        # Graphics view in the centre, channel scroll bar on its left,
        # time scroll bar underneath.
        grid = QT.QGridLayout()
        main_layout.addLayout(grid)

        chan_bar = QT.QScrollBar()
        self.scroll_chan = chan_bar
        grid.addWidget(chan_bar, 0, 0)
        chan_bar.valueChanged.connect(self.on_scroll_chan)

        self.graphicsview = pg.GraphicsView()
        grid.addWidget(self.graphicsview, 0, 1)
        # initialize_plot() configures self.scroll_chan, so the bar has to exist first.
        self.initialize_plot()

        time_bar = QT.QScrollBar(orientation=QT.Qt.Horizontal)
        self.scroll_time = time_bar
        grid.addWidget(time_bar, 1, 1)
        time_bar.valueChanged.connect(self.on_scroll_time)

        # Last displayed time, remembered separately for every segment.
        self.time_by_seg = np.zeros(self.dataio.nb_segment, dtype='float64')

        self.change_segment(0)
        self.refresh()
    
    # Fallback marker color used by seek() for cluster labels that are missing
    # from ``controller.qcolors``.
    _default_color = QT.QColor( 'white')
    
    def create_toolbar(self):
        """Create the main toolbar and add it to the widget layout.

        The toolbar holds, in order: the segment combo box, the signal type
        combo box, the time seeker, the "X size" spinbox and the
        'auto scale' / 'settings' / 'select' buttons. Subclasses add their own
        widgets through ``_create_other_toolbar``, called last.
        """
        toolbar = QT.QToolBar()
        self.toolbar = toolbar

        # Segment selection
        self.combo_seg = QT.QComboBox()
        toolbar.addWidget(self.combo_seg)
        segment_names = ['Segment {}'.format(seg_num) for seg_num in range(self.dataio.nb_segment)]
        self.combo_seg.addItems(segment_names)
        self._seg_pos = 0
        self.seg_num = self._seg_pos
        self.combo_seg.currentIndexChanged.connect(self.on_combo_seg_changed)
        toolbar.addSeparator()

        # Signal type selection
        self.combo_type = QT.QComboBox()
        toolbar.addWidget(self.combo_type)
        self.combo_type.addItems(list(_signal_types))
        self.combo_type.setCurrentIndex(_signal_types.index(self.signal_type))
        self.combo_type.currentIndexChanged.connect(self.on_combo_type_changed)

        # Time seeker, created without its slider
        self.timeseeker = TimeSeeker(show_slider=False)
        toolbar.addWidget(self.timeseeker)
        self.timeseeker.time_changed.connect(self.seek)

        # Width of the displayed window
        self.xsize = .5
        toolbar.addWidget(QT.QLabel(u'X size (s)'))
        xsize_bounds = [0.001, self.params['xsize_max']]
        self.spinbox_xsize = pg.SpinBox(value=self.xsize, bounds=xsize_bounds, suffix='s',
                                        siPrefix=True, step=0.1, dec=True)
        self.spinbox_xsize.sigValueChanged.connect(self.on_xsize_changed)
        toolbar.addWidget(self.spinbox_xsize)
        toolbar.addSeparator()
        # NOTE(review): refresh is connected in addition to on_xsize_changed,
        # which itself refreshes when the widget is visible. This second
        # connection also stays active while other methods temporarily
        # disconnect on_xsize_changed -- confirm it is intended.
        self.spinbox_xsize.sigValueChanged.connect(self.refresh)

        # Action buttons
        for text, slot in (('auto scale', self.auto_scale), ('settings', self.open_settings)):
            button = QT.QPushButton(text)
            button.clicked.connect(slot)
            toolbar.addWidget(button)
        self.select_button = QT.QPushButton('select', checkable=True)
        toolbar.addWidget(self.select_button)

        self.layout.addWidget(toolbar)

        self._create_other_toolbar()
    

    def initialize_plot(self):
        """Create the plot item and every graphic item the viewer draws into.

        Also decides which channels are visible at start-up: all of them when
        there are at most 16, otherwise the first 16 with the channel scroll
        bar shown. Gains and offsets are left to ``None`` so that the first
        ``seek`` estimates them.
        """
        self.viewBox = MyViewBox()
        self.plot = pg.PlotItem(viewBox=self.viewBox)
        self.graphicsview.setCentralItem(self.plot)
        self.plot.hideButtons()
        self.plot.showAxis('left', False)

        self.viewBox.gain_zoom.connect(self.gain_zoom)
        self.viewBox.xsize_zoom.connect(self.xsize_zoom)

        nb_channel = self.controller.nb_channel
        self.visible_channels = np.zeros(nb_channel, dtype='bool')
        self.max_channel = min(16, nb_channel)
        if nb_channel <= self.max_channel:
            # Everything fits: no need for the channel scroll bar.
            self.visible_channels[:] = True
            self.scroll_chan.hide()
        else:
            self.visible_channels[:self.max_channel] = True
            self.scroll_chan.show()
            self.scroll_chan.setMinimum(0)
            self.scroll_chan.setMaximum(nb_channel - self.max_channel)
            self.scroll_chan.setPageStep(self.max_channel)

        # A single curve holds all visible channels.
        self.signals_curve = pg.PlotCurveItem(pen='#7FFF00', connect='finite')
        self.plot.addItem(self.signals_curve)

        # Spike markers
        self.scatter = pg.ScatterPlotItem(size=10, pxMode=True)
        self.plot.addItem(self.scatter)
        self.scatter.sigClicked.connect(self.scatter_item_clicked)

        # One text label per channel (visible or not)
        self.channel_labels = []
        self.threshold_lines = []
        for chan_index, chan_name in enumerate(self.controller.channel_names):
            text = '{}: {}'.format(chan_index, chan_name)
            chan_label = pg.TextItem(text, color='#FFFFFF', anchor=(0, 0.5), border=None,
                                     fill=pg.mkColor((128, 128, 128, 180)))
            self.plot.addItem(chan_label)
            self.channel_labels.append(chan_label)

        # One horizontal threshold line per displayable channel, hidden for now
        for _ in range(self.max_channel):
            line = pg.InfiniteLine(angle=0., movable=False, pen=pg.mkPen(color=(128, 128, 128, 120)))
            line.setPos(0.)
            self.threshold_lines.append(line)
            self.plot.addItem(line)
            line.hide()

        # Vertical dashed line marking the selected spike, hidden for now
        selection_pen = pg.mkPen(color=(128, 0, 128, 120), width=3, style=QT.Qt.DashLine)
        self.selection_line = pg.InfiniteLine(pos=0., angle=90, movable=False, pen=selection_pen)
        self.plot.addItem(self.selection_line)
        self.selection_line.hide()

        # Hook for subclass specific items
        self._initialize_plot()

        self.gains = None
        self.offsets = None

    def prev_segment(self):
        """Switch to the segment just before the current one."""
        target = self._seg_pos - 1
        self.change_segment(target)
        
    def next_segment(self):
        """Switch to the segment just after the current one."""
        target = self._seg_pos + 1
        self.change_segment(target)

    def change_segment(self, seg_pos):
        """Make ``seg_pos`` the displayed segment and update the time widgets.

        A negative position wraps to the last segment and a position past the
        end wraps to the first one. The segment combo box, the time seeker
        range and the time scroll bar range are updated, and the plot is
        refreshed when the widget is visible.
        """
        #TODO: dirty because now seg_pos IS seg_num
        nb_segment = self.dataio.nb_segment
        if seg_pos < 0:
            seg_pos = nb_segment - 1
        elif seg_pos >= nb_segment:
            # ">=" rather than "==": any position past the end wraps to 0
            # instead of being used as an invalid segment number.
            seg_pos = 0
        self._seg_pos = seg_pos
        self.seg_num = self._seg_pos
        self.combo_seg.setCurrentIndex(self._seg_pos)

        length = self.dataio.get_segment_length(self.seg_num)
        t_start = 0.
        t_stop = length/self.dataio.sample_rate
        self.timeseeker.set_start_stop(t_start, t_stop, seek = False)

        # The time scroll bar works in samples.
        self.scroll_time.setMinimum(0)
        self.scroll_time.setMaximum(length)

        if self.isVisible():
            self.refresh()
    
    def on_params_changed(self):
        """Apply the current settings: update the x size bounds, then redraw."""
        xsize_max = self.params['xsize_max']
        spinbox = self.spinbox_xsize

        spinbox.opts['bounds'] = [0.001, xsize_max]
        if self.xsize > xsize_max:
            # The current x size no longer fits: clamp it, with
            # on_xsize_changed unplugged while the spinbox value is set.
            spinbox.sigValueChanged.disconnect(self.on_xsize_changed)
            spinbox.setValue(xsize_max)
            self.xsize = xsize_max
            spinbox.sigValueChanged.connect(self.on_xsize_changed)

        self.refresh()
    
    def on_combo_seg_changed(self):
        """Display the segment picked in the segment combo box."""
        self.change_segment(self.combo_seg.currentIndex())
    
    def on_combo_type_changed(self):
        """Display the signal type picked in the signal type combo box."""
        type_index = self.combo_type.currentIndex()
        self.signal_type = _signal_types[type_index]
        # Median/MAD depend on the signal type, so the scale is re-estimated
        # before the segment is reloaded.
        self.estimate_auto_scale()
        self.change_segment(self._seg_pos)
    

    
    def on_xsize_changed(self):
        """Store the spinbox value as the new x size and redraw if visible."""
        self.xsize = self.spinbox_xsize.value()
        if not self.isVisible():
            return
        self.refresh()
    
    def refresh(self):
        """Redraw the current segment at the time stored for it in ``time_by_seg``."""
        last_time = self.time_by_seg[self.seg_num]
        self.seek(last_time)

    def xsize_zoom(self, xmove):
        """Grow or shrink the x size proportionally to a horizontal mouse move.

        A move of 100 doubles the x size. The new size is applied through the
        spinbox only if it is strictly positive and strictly below the
        spinbox upper bound.
        """
        upper_bound = self.spinbox_xsize.opts['bounds'][1]
        candidate = self.xsize * (xmove / 100. + 1.)
        if not (0. < candidate < upper_bound):
            return
        self.spinbox_xsize.setValue(candidate)
    
    def auto_scale(self):
        """Slot of the 'auto scale' button: re-estimate the scaling and redraw."""
        # Reset median/MAD and the zoom factor ...
        self.estimate_auto_scale()
        # ... then draw with the new gains.
        self.refresh()
    
    def estimate_auto_scale(self):
        """Set ``self.med`` / ``self.mad`` for the current signal type and reset the zoom.

        * 'processed': median 0 and MAD 1 for every channel.
        * 'initial': median and MAD computed with ``median_mad`` on the start
          of the current segment (at most 60 s worth of samples).

        The zoom factor is then reset to 1 and ``gain_zoom(15.)`` is applied.
        """
        if self.signal_type == 'processed':
            # in that case it should be already normalized
            nb_channel = self.controller.nb_channel
            self.med = np.zeros(nb_channel, dtype='float32')
            self.mad = np.ones(nb_channel, dtype='float32')
        elif self.signal_type == 'initial':
            chan_grp = self.controller.chan_grp
            seg_length = self.dataio.get_segment_shape(self.seg_num, chan_grp=chan_grp)[0]
            i_stop = min(int(60.*self.dataio.sample_rate), seg_length)
            sigs = self.dataio.get_signals_chunk(seg_num=self.seg_num, chan_grp=chan_grp,
                                                 i_start=0, i_stop=i_stop,
                                                 signal_type=self.signal_type)
            self.med, self.mad = median_mad(sigs.astype('float32'), axis=0)

        self.factor = 1.
        self.gain_zoom(15.)
    
    def gain_zoom(self, factor_ratio):
        """Multiply the zoom factor by ``factor_ratio``, rebuild gains/offsets and redraw.

        All visible channels share one gain, ``1 / (factor * max(mad))``, so a
        larger factor shrinks the traces. Visible channels are stacked one
        unit apart, the first visible channel on top, with the channel median
        removed. Hidden channels keep a gain and an offset of 0.
        """
        self.factor *= factor_ratio

        visible = self.visible_channels
        nb_channel = self.controller.nb_channel
        nb_visible = int(np.sum(visible))

        # np.max instead of the builtin max (no Python-level loop over the
        # array). A flat or invalid signal gives a zero / non-finite MAD, which
        # would turn the gain into inf or NaN: fall back to a scale of 1.
        mad_max = float(np.max(self.mad))
        if not np.isfinite(mad_max) or mad_max <= 0.:
            mad_max = 1.
        gain = 1. / (self.factor * mad_max)

        self.gains = np.zeros(nb_channel, dtype='float32')
        self.offsets = np.zeros(nb_channel, dtype='float32')
        self.gains[visible] = gain
        self.offsets[visible] = np.arange(nb_visible)[::-1] - self.med[visible]*self.gains[visible]
        self.refresh()

    def on_scroll_time(self, val):
        """Slot of the time scroll bar: convert its value (samples) to seconds and seek."""
        sample_rate = self.controller.dataio.sample_rate
        self.timeseeker.seek(val / sample_rate)
    
    def on_scroll_chan(self, val):
        """Slot of the channel scroll bar: show ``max_channel`` channels starting at ``val``."""
        self.visible_channels[:] = False
        self.visible_channels[val:val+self.max_channel] = True
        # gain_zoom(1) leaves the zoom factor unchanged, recomputes the gains
        # and offsets for the new visible set and ends with a refresh, so no
        # second refresh (a second chunk load and redraw) is done here.
        self.gain_zoom(1)
    
    def center_scrollbar_on_channel(self, c):
        """Scroll the channel view so that channel ``c`` is about in the middle.

        The first visible channel is ``c - max_channel//2`` clipped to the
        valid scroll range. ``c`` may be a numpy integer; it is converted to a
        plain ``int`` before being handed to the Qt scroll bar.
        """
        first = int(c) - self.max_channel//2
        first = int(min(max(first, 0), self.controller.nb_channel-self.max_channel))

        # Move the scroll bar with on_scroll_chan unplugged: the visible
        # channels are updated just below.
        self.scroll_chan.valueChanged.disconnect(self.on_scroll_chan)
        self.scroll_chan.setValue(first)
        self.scroll_chan.valueChanged.connect(self.on_scroll_chan)

        self.visible_channels[:] = False
        self.visible_channels[first:first+self.max_channel] = True
        # Recompute gains/offsets for the new visible set (this also redraws).
        self.gain_zoom(1)
    
    def scatter_item_clicked(self, plot, points):
        """Select the spike of the current segment nearest to a clicked marker.

        Nothing happens unless the 'select' button is checked and exactly one
        point was clicked. The previous selection is replaced, then
        ``spike_selection_changed`` is emitted and the plot is redrawn.
        """
        if not self.select_button.isChecked() or len(points) != 1:
            return

        clicked_time = points[0].pos().x()
        self.controller.spike_selection[:] = False

        clicked_sample = int(clicked_time*self.dataio.sample_rate )
        in_segment = self.controller.spikes['segment']==self.seg_num
        distances = np.abs(self.controller.spikes[in_segment]['index'] - clicked_sample)
        nearest = np.argmin(distances)

        # Map the position among this segment's spikes back to the full spike array.
        selected = np.nonzero(in_segment)[0][nearest]
        self.controller.spike_selection[selected] = True

        self.spike_selection_changed.emit()
        self.refresh()
    
    def on_spike_selection_changed(self):
        """React to a change of ``controller.spike_selection``.

        When 'auto_zoom_on_select' is enabled and exactly one spike is
        selected, the viewer jumps to the segment of that spike, sets the x
        size to 'zoom_size', centers the channel view on the spike's channel
        and seeks to the spike time. In every other case the plot is simply
        redrawn.
        """
        ind_selected, = np.nonzero(self.controller.spike_selection)
        if not (self.params['auto_zoom_on_select'] and ind_selected.size == 1):
            self.refresh()
            return

        ind = ind_selected[0]
        peak_ind = self.controller.spikes[ind]['index']
        # Plain int: the value is handed to a Qt widget below.
        seg_num = int(self.controller.spikes[ind]['segment'])
        peak_time = peak_ind/self.dataio.sample_rate

        if seg_num != self.seg_num:
            self.combo_seg.setCurrentIndex(seg_num)

        # Apply the zoom size with on_xsize_changed unplugged.
        self.spinbox_xsize.sigValueChanged.disconnect(self.on_xsize_changed)
        self.spinbox_xsize.setValue(self.params['zoom_size'])
        self.xsize = self.params['zoom_size']
        self.spinbox_xsize.sigValueChanged.connect(self.on_xsize_changed)

        label = self.controller.spikes[ind]['cluster_label']
        c = self.controller.get_extremum_channel(label)

        if c is None:
            # No extremum channel for this label: take the channel with the
            # largest absolute value on the processed signal at the peak sample.
            wf = self.controller.dataio.get_signals_chunk(seg_num=seg_num, chan_grp=self.controller.chan_grp,
                    i_start=peak_ind, i_stop=peak_ind+1,
                    signal_type='processed')
            c = int(np.argmax(np.abs(wf)))

        self.center_scrollbar_on_channel(c)

        self.seek(peak_time)
    
    def seek(self, t):
        """Draw the current segment around time ``t`` (in seconds).

        The displayed window is ``[t - xsize/3, t + 2*xsize/3]``. The time
        seeker and the time scroll bar are synchronised on ``t``, the matching
        signal chunk is loaded, and the signal curve, channel labels,
        threshold lines, spike markers and selection line are updated.
        Subclasses then draw their own items through ``_plot_specific_items``.

        Nothing is drawn when ``get_signals_chunk`` returns ``None``.
        """
        # When the call does not come from the time seeker, update it with
        # emit=False (presumably to avoid being called back -- see TimeSeeker).
        if self.sender() is not self.timeseeker:
            self.timeseeker.seek(t, emit = False)

        self.time_by_seg[self.seg_num] = t
        t1,t2 = t-self.xsize/3. , t+self.xsize*2/3.
        t_start = 0.
        sr = self.dataio.sample_rate

        # Move the time scroll bar with its slot unplugged so that it does not
        # trigger another seek.
        self.scroll_time.valueChanged.disconnect(self.on_scroll_time)
        self.scroll_time.setValue(int(sr*t))
        self.scroll_time.setPageStep(int(sr*self.xsize))
        self.scroll_time.valueChanged.connect(self.on_scroll_time)

        # Sample range of the window, clipped at the start of the segment
        ind1 = max(0, int((t1-t_start)*sr))
        ind2 = int((t2-t_start)*sr)

        sigs_chunk = self.dataio.get_signals_chunk(seg_num=self.seg_num, chan_grp=self.controller.chan_grp,
                i_start=ind1, i_stop=ind2, signal_type=self.signal_type)

        if sigs_chunk is None:
            return

        # First display: no scaling yet. estimate_auto_scale() ends up calling
        # refresh(), so a nested seek() runs before this one continues.
        if self.gains is None:
            self.estimate_auto_scale()

        nb_visible = np.sum(self.visible_channels)

        # Signals: one row per visible channel, scaled and stacked
        data_curves = sigs_chunk[:, self.visible_channels].T.copy()
        if data_curves.dtype!='float32':
            data_curves = data_curves.astype('float32')

        data_curves *= self.gains[self.visible_channels, None]
        data_curves += self.offsets[self.visible_channels, None]

        # All channels go into one curve; the last sample of each channel is
        # not connected to the first sample of the next one.
        connect = np.ones(data_curves.shape, dtype='bool')
        connect[:, -1] = 0

        times_chunk = np.arange(sigs_chunk.shape[0], dtype='float64')/self.dataio.sample_rate+max(t1, 0)
        times_chunk_tile = np.tile(times_chunk, nb_visible)
        self.signals_curve.setData(times_chunk_tile, data_curves.flatten(), connect=connect.flatten())

        # Channel labels: visible channels from top to bottom, the others hidden
        i = 1
        for c in range(self.controller.nb_channel):
            if self.visible_channels[c]:
                self.channel_labels[c].setPos(t1, nb_visible-i)
                self.channel_labels[c].show()
                i +=1
            else:
                self.channel_labels[c].hide()

        # Threshold lines: the setting and the threshold are read once, not
        # once per visible channel.
        index_visible, = np.nonzero(self.visible_channels)
        if self.params['plot_threshold']:
            threshold = self.controller.get_threshold()
            for i, c in enumerate(index_visible):
                self.threshold_lines[i].setPos(nb_visible-i-1 + self.gains[c]*threshold)
                self.threshold_lines[i].show()
        else:
            for i in range(index_visible.size):
                self.threshold_lines[i].hide()

        # Spike markers on the signals
        all_spikes = self.controller.spikes
        if len(all_spikes)>0:
            keep = (all_spikes['segment']==self.seg_num) & (all_spikes['index']>=ind1) & (all_spikes['index']<ind2)
            spikes_chunk = np.array(all_spikes[keep], copy=True)
            # Spike indexes become relative to the start of the chunk.
            spikes_chunk['index'] -= ind1
            inwindow_ind = spikes_chunk['index']
            inwindow_label = spikes_chunk['cluster_label']
            inwindow_chan = spikes_chunk['channel']
            # A channel of -1 is treated as "channel unknown" for the whole window.
            if np.any(inwindow_chan==-1):
                inwindow_chan = None
            inwindow_selected = np.array(self.controller.spike_selection[keep])

            self.scatter.clear()
            all_x = []
            all_y = []
            all_brush = []
            for k in self.controller.cluster_labels:

                if not self.controller.cluster_visible[k]: continue
                mask = inwindow_label==k
                if np.sum(mask)==0: continue

                color = QT.QColor(self.controller.qcolors.get(k, self._default_color))
                color.setAlpha(int(self.params['alpha']*255))

                x = times_chunk[inwindow_ind[mask]]
                sigs_chunk_in = sigs_chunk[inwindow_ind[mask], :]

                # Channel on which each spike is drawn: the extremum channel of
                # the cluster when there is one, else the channel stored with
                # the spike, else the channel with the largest absolute value.
                chan_inds = None
                if k >=0:
                    c = self.controller.get_extremum_channel(k)
                    if c is not None:
                        chan_inds = np.array([c]*np.sum(mask), dtype='int64')
                if chan_inds is None:
                    if inwindow_chan is None:
                        chan_inds = np.argmax(np.abs(sigs_chunk_in), axis=1)
                    else:
                        chan_inds = inwindow_chan[mask]

                # Only spikes whose channel is currently displayed are drawn.
                mask_visible = self.visible_channels[chan_inds]
                if np.sum(mask_visible)==0: continue

                chan_inds = chan_inds[mask_visible]
                x = x[mask_visible]
                y = sigs_chunk_in[mask_visible, :][np.arange(chan_inds.size), chan_inds]*self.gains[chan_inds]+self.offsets[chan_inds]

                all_x.append(x)
                all_y.append(y)
                all_brush.append(np.array([pg.mkBrush(color)]*len(x)))

            if len(all_x) > 0:
                all_x = np.concatenate(all_x)
                all_y = np.concatenate(all_y)
                all_brush = np.concatenate(all_brush)
                self.scatter.setData(x=all_x, y=all_y, brush=all_brush)

            # Vertical line on the selected spike when exactly one is in the window
            if np.sum(inwindow_selected)==1:
                selected_times = times_chunk[inwindow_ind[inwindow_selected]]
                # Pass a plain float, not a 1-element array, to setPos.
                self.selection_line.setPos(float(selected_times[0]))
                self.selection_line.show()
            else:
                self.selection_line.hide()
        else:
            spikes_chunk = None

        # plot prediction or residuals ...
        self._plot_specific_items(sigs_chunk, times_chunk, spikes_chunk)

        # ranges
        self.plot.setXRange( t1, t2, padding = 0.0)
        self.plot.setYRange(-.5, nb_visible-.5, padding = 0.0)
        


class CatalogueTraceViewer(BaseTraceViewer):
    """
    **Trace viewer** allows browsing the raw signals and the preprocessed signals.
    
    Note that this viewer does not load the entire signals in memory but loads
    chunks on demand from the hard drive, which is why, depending on the drive,
    it can be quite slow. All zoom and scale factors for signals are computed on
    CPU and not on GPU (it is not vispy!!), so this is not the fastest viewer but
    many tips help the user to navigate very efficiently.
    
    What you can do:
      * On the bottom there is a slider over time
      * On the left there is a slider over channels (if nb_channel>16)
      * If there are several segments you can switch between them.
      * You can manually jump to any time
      * You can zoom the X (time) axis with the spinbox or by **right click** with the mouse.
      * **The mouse wheel makes a global zoom on the signals**
      * You can "select" a spike manually with the "select" button and this will be shown in **ND Scatter**
      * The threshold is a line for each channel. This is very important to understand why some peaks are not detected.
      * Setting button:
        * "auto_zoom_on_select" : auto zoom when selecting on ndscatter or on peak list
        * "zoom_size" in s
        * disable plot threshold
    
    Important:
      * The "preprocessed" signals are normalized (robust Z-score) so that the noise variance is 1.
        So the apparent noise **must be** in between [-3, 3]
    
    """
    def __init__(self, controller=None, signal_type = 'processed', parent=None):
        # Same as the base viewer, but starts on the 'processed' signals.
        BaseTraceViewer.__init__(self, controller=controller, signal_type=signal_type, parent=parent)
    
    def _create_other_toolbar(self):
        """Hook called at the end of ``create_toolbar``: no extra toolbar here."""
        pass
        
    def _initialize_plot(self):
        """Hook called by ``initialize_plot``: no extra plot item here."""
        pass
    
    def _plot_specific_items(self, sigs_chunk, times_chunk, spikes_chunk):
        """Hook called by ``seek``: nothing is drawn besides the base items."""
        pass


class PeelerTraceViewer(BaseTraceViewer):
    """Trace viewer that can overlay a prediction and a residual on the signals.

    A second toolbar holds three checkable buttons ('signals', 'prediction',
    'residual') that choose which curves are drawn. The prediction and the
    residual are only available for the 'processed' signal type.
    """
    def __init__(self, controller=None, signal_type = 'processed', parent=None):
        BaseTraceViewer.__init__(self, controller=controller, signal_type=signal_type, parent=parent)

    def _create_other_toolbar(self):
        """Add the second toolbar with the three curve toggle buttons."""
        self.toolbar2 = QT.QToolBar()
        # Position 1 in the layout: right after the main toolbar.
        self.layout.insertWidget(1, self.toolbar2)

        self.plot_buttons = {}
        for name in ['signals', 'prediction', 'residual']:
            self.plot_buttons[name] = but = QT.QPushButton(name,  checkable = True)
            but.clicked.connect(self.refresh)
            self.toolbar2.addWidget(but)

            # Signals and prediction are shown by default, the residual is not.
            if name in ['signals', 'prediction']:
                but.setChecked(True)

    def _initialize_plot(self):
        """Create the curves used for the prediction and the residual."""
        self.curve_predictions = pg.PlotCurveItem(pen='#FF00FF', connect='finite')
        self.plot.addItem(self.curve_predictions)
        self.curve_residuals = pg.PlotCurveItem(pen='#FFFF00', connect='finite')
        self.plot.addItem(self.curve_residuals)

    def _plot_specific_items(self, sigs_chunk, times_chunk, spikes_chunk):
        """Draw or clear the prediction and residual curves, and apply the 'signals' toggle.

        ``spikes_chunk`` is ``None`` when the controller has no spike at all;
        both overlays are then cleared. The prediction is only computed when
        at least one overlay is actually displayed.
        """
        overlays_available = spikes_chunk is not None and self.signal_type == 'processed'
        show_prediction = overlays_available and self.plot_buttons['prediction'].isChecked()
        show_residual = overlays_available and self.plot_buttons['residual'].isChecked()

        visible = self.visible_channels

        if show_prediction or show_residual:
            #TODO make prediction only on visible!!!!
            prediction = make_prediction_signals(spikes_chunk, sigs_chunk.dtype, sigs_chunk.shape, self.controller.catalogue)
            times_chunk_tile = np.tile(times_chunk, np.sum(visible))

        def plot_curves(curve, data):
            # Same scaling and stacking as the signal curve drawn by seek().
            data = data[:, visible].T.copy()
            data *= self.gains[visible, None]
            data += self.offsets[visible, None]

            # The last sample of a channel is not connected to the next channel.
            connect = np.ones(data.shape, dtype='bool')
            connect[:, -1] = 0

            curve.setData(times_chunk_tile, data.flatten(), connect=connect.flatten())

        if show_prediction:
            plot_curves(self.curve_predictions, prediction)
        else:
            self.curve_predictions.setData([], [])

        if show_residual:
            plot_curves(self.curve_residuals, sigs_chunk - prediction)
        else:
            self.curve_residuals.setData([], [])

        # The 'signals' toggle is applied even when there is no spike.
        if not self.plot_buttons['signals'].isChecked():
            self.signals_curve.setData([], [])