import numpy as np
import xarray as xr
import itertools as itt
import functools as fct
import cv2
import skimage as ski
import scipy.ndimage as ndi
import scipy.stats as stat
import numba as nb
from dask import delayed, compute
from dask.diagnostics import ProgressBar
from collections import OrderedDict
from scipy.ndimage import uniform_filter
from skimage.morphology import disk
from medpy.filter.smoothing import anisotropic_diffusion
from scipy.stats import zscore
from warnings import warn
from .utilities import scale_varr
from IPython.core.debugger import set_trace


class HashableDict(dict):
    def __hash__(self):
        return hash(frozenset(self.items()))


def corr_coeff_pixelwise(varray):
    if varray.sizes['frame'] % 2 > 0:
        varr = varray.isel(frame=slice(None, -1))
    else:
        varr = varray

    def corr(a, axis):
        return np.apply_along_axis(
            lambda t: np.corrcoef(np.split(t, 2)[0], np.split(t, 2)[1])[0, 1],
            axis, a)

    return varr.reduce(corr, dim='frame')


# def mask_movie_framewise(mov, mask, vals):
#     mov_re = mov.reshape((mov.shape[0], -1))
#     mask_re = mask.flatten()
#     mov_masked = mov_re.copy()
#     np.apply_along_axis(lambda f: np.place(f, mask_re, vals), 1, mov_masked)
#     return mov_masked.reshape(mov.shape)

# def zscore_xr(xarr, dim=None):
#     mean = xarr.mean(dim=dim)
#     std = xarr.std(dim=dim)
#     return (xarr - mean) / std


def detect_brightspot(varray, thres=None, window=50, step=10):
    print("detecting brightspot")
    spots = xr.DataArray(
        varray.sel(frame=0)).reset_coords(drop=True).astype(int)
    spots.values = np.zeros_like(spots.values)
    meanfm = varray.mean(dim='frame')
    for ih, ph in meanfm.rolling(height=window):
        if ih % step == 0:
            for iw, pw in ph.rolling(width=window):
                if (iw % step == 0 and pw.sizes['height'] == window
                        and pw.sizes['width'] == window):
                    mean_z = xr.apply_ufunc(zscore, pw)
                    if not thres:
                        cur_thres = -mean_z.min().values
                    else:
                        cur_thres = thres
                    spots.loc[{
                        'height': slice(ih - window + 1, ih),
                        'width': slice(iw - window + 1, iw)
                    }] += mean_z > cur_thres
                    print(
                        ("processing window at {:3d}, {:3d}"
                         " using threshold: {:03.2f}").format(
                             int(ih), int(iw), float(cur_thres)),
                        end='\r')
    print("\nbrightspot detection done")
    return spots


def detect_brightspot_perframe(varray, thres=0.95):
    print("creating parallel schedule")
    spots = []
    for fid, fm in varray.rolling(frame=1):
        sp = delayed(lambda f: f > f.quantile(thres, interpolation='lower'))(
            fm)
        spots.append(sp)
    with ProgressBar():
        print("detecting bright spots by frame")
        spots, = compute(spots)
    print("concatenating results")
    spots = xr.concat(spots, dim='frame')
    return spots


# def correct_dust(varray, dust):
#     mov_corr = varray.values
#     nz = np.nonzero(dust)
#     nz_tp = [(d0, d1) for d0, d1 in zip(nz[0], nz[1])]
#     for i in range(np.count_nonzero(dust)):
#         cur_dust = (nz[0][i], nz[1][i])
#         cur_sur = set(
#             itt.product(
#                 range(cur_dust[0] - 1, cur_dust[0] + 2),
#                 range(cur_dust[1] - 1, cur_dust[1] + 2))) - set(
#                     cur_dust) - set(nz_tp)
#         cur_sur = list(
#             filter(
#                 lambda d: 0 < d[0] < mov.shape[1] and 0 < d[1] < mov.shape[2],
#                 cur_sur))
#         if len(cur_sur) > 0:
#             sur_arr = np.empty((mov.shape[0], len(cur_sur)))
#             for si, sur in enumerate(cur_sur):
#                 sur_arr[:, si] = mov[:, sur[0], sur[1]]
#             mov_corr[:, cur_dust[0], cur_dust[1]] = np.mean(sur_arr, axis=1)
#         else:
#             print("unable to correct for point ({}, {})".format(
#                 cur_dust[0], cur_dust[1]))
#     return mov_corr


def correct_brightspot(varray, spots, window=2, spot_thres=10, inplace=True):
    print("correcting brightspot")
    if not spots.sum() > 0:
        print("no bright spots to be corrected, returning input")
        return varray
    if not inplace:
        varr_ds = varray.copy()
    else:
        varr_ds = varray
    spot_dim = spots.dims
    red_dim = tuple(set(varray.dims) - set(spot_dim))
    if len(spot_dim) > 2:
        spot_thres = 0
    brt = np.nonzero(spots.values > spot_thres)
    brt_list = [
        HashableDict((dm, int(spots.coords[dm][brt[idm][ib]].values))
                     for idm, dm in enumerate(spot_dim))
        for ib in range(len(brt[0]))
    ]
    sur_list = []
    for ibrt, brt_cord in enumerate(brt_list):
        cur_sur = [(dim,
                    list(
                        set(range(co - window, co + window + 1)).intersection(
                            set(varr_ds.coords[dim].values.tolist()))))
                   for dim, co in brt_cord.items()]
        cur_sur_list = []
        for cord in itt.product(*[cord_rg[1] for cord_rg in cur_sur]):
            cur_sur_list.append(
                HashableDict(
                    (cur_sur[i][0], cord[i]) for i in range(len(cord))))
        cur_sur = list(set(cur_sur_list) - set(brt_list))
        sur_list.append(cur_sur)
    for ibrt, cur_brt in enumerate(brt_list):
        print(
            "processing spot {:3d} of {:3d}".format(ibrt, len(brt_list)),
            end='\r')
        if len(sur_list[ibrt]) > 0:
            cur_sur = xr.DataArray(
                np.zeros((len(sur_list[ibrt]), ) +
                         tuple([varr_ds.sizes[rd] for rd in red_dim])),
                dims=('sample', ) + red_dim,
                coords=dict({
                    'sample': range(len(sur_list[ibrt]))
                }, **{r: varr_ds.coords[r]
                      for r in red_dim}))
            for isamp, cord_samp in enumerate(sur_list[ibrt]):
                cur_sur.loc[{'sample': isamp}] = varr_ds.loc[cord_samp]
            varr_ds.loc[cur_brt] = cur_sur.mean(dim='sample')
        else:
            print("unable to correct for point {}, coordinates: {}".format(
                ibrt, cur_brt))
    print("\nbrightspot correction done")
    return varr_ds.rename(varray.name + "_DeSpotted")


def remove_background_old(varray, window=51):
    print("creating parallel schedule")
    varr_ft = varray.astype(np.float32)
    compute_list = []
    for fid in varr_ft.coords['frame'].values:
        fm = varr_ft.loc[dict(frame=fid)]
        _ = delayed(remove_background_perframe_old)(fid, fm, varr_ft, window)
        compute_list.append(_)
    with ProgressBar():
        print("removing background")
        compute(compute_list)
    print("normalizing result")
    varr_ft = scale_varr(varr_ft, (0, 255)).astype(varray.dtype, copy=False)
    print("background removal done")
    return varr_ft.rename(varray.name + "_Filtered")


def remove_background_perframe_old(fid, fm, varr, window):
    f = fm - uniform_filter(fm, window)
    varr.loc[dict(frame=fid)] = f


def remove_background(varr, method, wnd):
    selem = disk(wnd)
    res = xr.apply_ufunc(
        remove_background_perframe,
        varr.chunk(dict(height=-1, width=-1)),
        input_core_dims=[['height', 'width']],
        output_core_dims=[['height', 'width']],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[varr.dtype],
        kwargs=dict(method=method, wnd=wnd, selem=selem))
    return res.rename(varr.name + "_subtracted")


def remove_background_perframe(fm, method, wnd, selem):
    if method == 'uniform':
        return fm - uniform_filter(fm, wnd)
    elif method == 'tophat':
        return cv2.morphologyEx(fm, cv2.MORPH_TOPHAT, selem)


def stripe_correction(varr, reduce_dim='height', on='mean'):
    if on == 'mean':
        temp = varr.mean(dim='frame')
    elif on == 'max':
        temp = varr.max(dim='frame')
    elif on == 'perframe':
        temp = varr
    else:
        raise NotImplementedError("on {} not understood".format(on))
    mean1d = temp.mean(dim=reduce_dim)
    varr_sc = varr - mean1d
    return varr_sc.rename(varr.name + "_Stripe_Corrected")


def gaussian_blur(varray, ksize=(3, 3), sigmaX=0):
    return varray.groupby('frame').apply(
        lambda fm: cv2.GaussianBlur(fm.values, ksize, sigmaX))


def denoise(varr, method, **kwargs):
    if method == 'gaussian':
        func = cv2.GaussianBlur
    elif method == 'anisotropic':
        func = anisotropic_diffusion
    elif method == 'median':
        func = cv2.medianBlur
    elif method == 'bilateral':
        func = cv2.bilateralFilter
    else:
        raise NotImplementedError(
            "denoise method {} not understood".format(method))
    res = xr.apply_ufunc(
        func,
        varr,
        input_core_dims=[['height', 'width']],
        output_core_dims=[['height', 'width']],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[varr.dtype],
        kwargs=kwargs)
    return res.rename(varr.name + "_denoised")


def denoise_perframe(fm, method, **kwargs):
    if method == 'gaussian':
        return cv2.GaussianBlur(fm, **kwargs)
    elif method == 'anisotropic':
        return anisotropic_diffusion(fm, **kwargs)


def gradient_norm(varr):
    return xr.apply_ufunc(
        gradient_norm_perframe,
        varr,
        input_core_dims=[['height', 'width']],
        output_core_dims=[['height', 'width']],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[varr.dtype]).rename(varr.name + '_gradient')


def gradient_norm_perframe(f):
    x, y = np.gradient(f)
    return np.sqrt(x**2 + y**2)


def remove_brightspot(varr, thres=3):
    k_mean = ski.morphology.diamond(1)
    k_mean[1, 1] = 0
    k_mean = k_mean / 4
    return xr.apply_ufunc(
        remove_brightspot_perframe,
        varr.chunk(dict(height=-1, width=-1)),
        input_core_dims=[['height', 'width']],
        output_core_dims=[['height', 'width']],
        vectorize=True,
        dask='parallelized',
        kwargs=dict(k_mean=k_mean, thres=thres),
        output_dtypes=[varr.dtype]).rename(varr.name + '_clean')


def remove_brightspot_perframe(fm, k_mean, thres):
    f_mean = ndi.convolve(fm, k_mean)
    f_diff = np.nan_to_num(stat.zscore(fm - f_mean))
    if thres == 'min':
        f_mask = f_diff > -np.min(f_diff)
    else:
        f_mask = f_diff > thres
    return np.ma.masked_where(f_mask, fm).filled(0) + np.ma.masked_where(
        ~f_mask, f_mean).filled(0)