import os
import re
import itertools as itt
import numpy as np
import xarray as xr
import pandas as pd
import dask as da
from dask.diagnostics import ProgressBar
from scipy.ndimage import center_of_mass
from scipy.stats import pearsonr
from scipy.spatial.distance import cdist
from .preprocessing import remove_background
from .motion_correction import estimate_shift_fft, apply_shifts
from .utilities import xrconcat_recursive
from .visualization import centroid
from IPython.core.debugger import set_trace
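
# Cross-session registration utilities. A typical workflow (a sketch only,
# assuming `minian_df` is a pandas DataFrame indexed by animal/session with a
# 'minian' column of opened datasets, and `A_shifted`/`window` stand for the
# shifted spatial footprints and a validity mask) might chain the functions
# defined below roughly as follows:
#
#   shifts_ds = estimate_shifts(minian_df, by='session', temp_var='org')
#   cents = calculate_centroids(A_shifted, window)
#   dist = group_by_session(
#       calculate_centroid_distance(cents, by='session', index_dim=['animal']))
#   mappings = fill_mapping(resolve_mapping(calculate_mapping(dist)), cents)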


def load_cnm_dataset_mf(path, pattern=r'^cnm.nc$', concat_dim='session'):
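    """Recursively find CNMF result files matching `pattern` under `path` and
    open them lazily as one dataset concatenated along `concat_dim` via
    `xr.open_mfdataset`. Returns None if no file is found."""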
    path = os.path.normpath(path)
    cnmlist = []
    for dirpath, dirnames, fnames in os.walk(path):
        cnmnames = filter(lambda fn: re.search(pattern, fn), fnames)
        cnmpath = [os.path.join(dirpath, cnm) for cnm in cnmnames]
        cnmlist += cnmpath
    if cnmlist:
        return xr.open_mfdataset(cnmlist, concat_dim=concat_dim, combine='nested')
    else:
        print("No CNMF dataset found under path: {}".format(path))
        return None


def load_cnm_dataset(path, pattern=r'^cnm.nc$', concat_dim='session'):
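    """Recursively find CNMF result files matching `pattern` under `path`, open
    each one, cast the 'animal', 'session' and 'session_id' coordinates to str,
    subset units by the dataset's 'unit_mask' attribute, and concatenate the
    results along `concat_dim`. Returns None if no file is found."""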
    path = os.path.normpath(path)
    cnmlist = []
    for dirpath, dirnames, fnames in os.walk(path):
        cnmnames = filter(lambda fn: re.search(pattern, fn), fnames)
        for cnm in cnmnames:
            cnmpath = os.path.join(dirpath, cnm)
            cnmds = xr.open_dataset(cnmpath, chunks={})
            cnmds = cnmds.assign_coords(
                animal=cnmds.coords['animal'].astype(str))
            cnmds = cnmds.assign_coords(
                session=cnmds.coords['session'].astype(str))
            cnmds = cnmds.assign_coords(
                session_id=cnmds.coords['session_id'].astype(str))
            cnmds = cnmds.sel(unit_id=cnmds.attrs['unit_mask'])
            cnmlist.append(cnmds)
    if cnmlist:
        return xr.concat(cnmlist, dim=concat_dim)
    else:
        print("No CNMF dataset found under path: {}".format(path))
        return None


def get_minian_list(path, pattern=r'^minian.nc$'):
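    """Return a list of paths to files matching `pattern` under `path`."""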
    path = os.path.normpath(path)
    mnlist = []
    for dirpath, dirnames, fnames in os.walk(path):
        mnames = filter(lambda fn: re.search(pattern, fn), fnames)
        mn_paths = [os.path.join(dirpath, mn) for mn in mnames]
        mnlist += mn_paths
    return mnlist

def estimate_shifts(minian_df, by='session', to='first', temp_var='org', template=None, rm_background=False):
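    """Estimate translational shifts between templates across `by` (e.g. sessions).

    For each row of `minian_df`, a template is taken from the variable
    `temp_var` according to the 'template' column (or the `template` argument):
    either an integer frame index, or 'mean'/'max' over frames. Shifts and
    correlations are estimated with `estimate_shift_fft` within each group of
    the remaining index levels, and shifted templates are produced with
    `apply_shifts`. Returns a merged dataset with 'temps', 'shifts', 'corrs'
    and 'temps_shifted'. The `rm_background` argument is currently unused."""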
    if template is not None:
        minian_df['template'] = template

    def get_temp(row):
        ds, temp = row['minian'], row['template']
        try:
            return ds.isel(frame=temp).drop_vars('frame')
        except TypeError:
            func_dict = {
                'mean': lambda v: v.mean('frame'),
                'max': lambda v: v.max('frame')}
            try:
                return func_dict[temp](ds)
            except KeyError:
                raise NotImplementedError(
                    "template {} not understood".format(temp))

    minian_df['template'] = minian_df.apply(get_temp, axis='columns')
    grp_dims = list(minian_df.index.names)
    grp_dims.remove(by)
    temp_dict, shift_dict, corr_dict, tempsh_dict = [dict() for _ in range(4)]
    for idxs, df in minian_df.groupby(level=grp_dims):
        try:
            temp_ls = [t[temp_var] for t in df['template']]
        except KeyError:
            raise KeyError(
                "variable {} not found in dataset".format(temp_var))
        temps = (xr.concat(temp_ls, dim=by).expand_dims(grp_dims)
                 .reset_coords(drop=True))
        res = estimate_shift_fft(temps, dim=by, on=to)
        shifts = res.sel(variable=['height', 'width'])
        corrs = res.sel(variable='corr')
        temps_sh = apply_shifts(temps, shifts)
        temp_dict[idxs] = temps
        shift_dict[idxs] = shifts
        corr_dict[idxs] = corrs
        tempsh_dict[idxs] = temps_sh
    temps = xrconcat_recursive(temp_dict, grp_dims).rename('temps')
    shifts = xrconcat_recursive(shift_dict, grp_dims).rename('shifts')
    corrs = xrconcat_recursive(corr_dict, grp_dims).rename('corrs')
    temps_sh = xrconcat_recursive(tempsh_dict, grp_dims).rename('temps_shifted')
    with ProgressBar():
        temps = temps.compute()
        shifts = shifts.compute()
        corrs = corrs.compute()
        temps_sh = temps_sh.compute()
    return xr.merge([temps, shifts, corrs, temps_sh])


def estimate_shifts_old(mn_list,
                    temp_list,
                    z_thres=None,
                    rm_background=False,
                    method='first',
                    concat_dim='session'):
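    """Legacy shift estimation kept for reference; superseded by `estimate_shifts`.

    Loads a template ('first', 'last' or 'mean' frame) from each dataset in
    `mn_list`, optionally removes background, and estimates FFT-based shifts of
    every template relative to the first or last one. Note that it calls a
    module-level `shift_fft` helper that is not imported in this file."""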
    temps = []
    for imn, mn_path in enumerate(mn_list):
        print(
            "loading template: {:2d}/{:2d}".format(imn, len(mn_list)))
        try:
            with xr.open_dataset(
                mn_path, chunks=dict(width='auto', height='auto'))['org'] as cur_va:
                if temp_list[imn] == 'first':
                    cur_temp = cur_va.isel(frame=0).load().copy()
                elif temp_list[imn] == 'last':
                    cur_temp = cur_va.isel(frame=-1).load().copy()
                elif temp_list[imn] == 'mean':
                    cur_temp = (cur_va.mean('frame'))
                    with ProgressBar():
                        cur_temp = cur_temp.compute()
                else:
                    print("unrecognized template")
                    continue
                if rm_background:
                    cur_temp = remove_background(cur_temp, 'uniform', wnd=51)
                temps.append(cur_temp)
        except KeyError:
            print("no video found for path {}".format(mn_path))
    if concat_dim:
        temps = xr.concat(temps, dim=concat_dim).rename('temps')
        window = ~temps.isnull().sum(concat_dim).astype(bool)
        temps = temps.where(window, drop=True)
    shifts = []
    corrs = []
    for itemp, temp_dst in temps.rolling(**{concat_dim: 1}):
        print("processing: {}".format(itemp.values))
        if method == 'first':
            temp_src = temps.isel(**{concat_dim: 0})
        elif method == 'last':
            temp_src = temps.isel(**{concat_dim: -1})
        # common = (temp_src.isnull() + temp_dst.isnull())
        # temp_src = temp_src.reindex_like(common)
        # temp_dst = temp_dst.reindex_like(common)
        temp_src, temp_dst = temp_src.squeeze(), temp_dst.squeeze()
        src_fft = np.fft.fft2(temp_src)
        dst_fft = np.fft.fft2(temp_dst)
        cur_res = shift_fft(src_fft, dst_fft)
        cur_sh = cur_res[0:2]
        cur_cor = cur_res[2]
        cur_anm = temp_dst.coords['animal']
        cur_ss = temp_dst.coords['session']
        cur_ssid = temp_dst.coords['session_id']
        cur_sh = xr.DataArray(
            cur_sh,
            coords=dict(shift_dim=list(temp_dst.dims)),
            dims=['shift_dim'])
        cur_cor = xr.DataArray(cur_cor)
        cur_sh = cur_sh.assign_coords(
            animal=cur_anm, session=cur_ss, session_id=cur_ssid)
        cur_cor = cur_cor.assign_coords(
            animal=cur_anm, session=cur_ss, session_id=cur_ssid)
        shifts.append(cur_sh)
        corrs.append(cur_cor)
    if concat_dim:
        shifts = xr.concat(shifts, dim=concat_dim).rename('shifts')
        corrs = xr.concat(corrs, dim=concat_dim).rename('corrs')
        # `temps` was already concatenated and renamed above; re-concatenating
        # it along the same dimension would only duplicate work.
    return shifts, corrs, temps


def apply_shifts_old(var, shifts, inplace=False, dim='session'):
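    """Legacy helper: integer-shift `var` along its spatial dimensions according
    to `shifts`, one entry per value of `dim`."""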
    shifts = shifts.dropna(dim)
    var_list = []
    for dim_n, sh in shifts.groupby(dim):
        sh_dict = (sh.astype(int).to_series().reset_index()
                   .set_index('shift_dim')['shifts'].to_dict())
        var_list.append((var.sel(**{dim: dim_n})
                         .shift(**sh_dict).rename(var.name + "_shifted")))
    return xr.concat(var_list, dim=dim)

def calculate_centroids(A, window):
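    """Mask the spatial footprints `A` with `window` and compute per-unit
    centroids via `centroid` from the visualization module."""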
    A = A.where(window, 0)
    return centroid(A, verbose=True)


def calculate_centroids_old(cnmds, window, grp_dim=['animal', 'session']):
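    """Legacy implementation: compute centroids per animal/session in parallel
    with dask.delayed and return a single concatenated DataFrame."""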
    print("computing centroids")
    cnt_list = []
    for anm, cur_anm in cnmds.groupby('animal'):
        for ss, cur_ss in cur_anm.groupby('session'):
            # cnt = centroids(cur_ss['A_shifted'], window.sel(animal=anm))
            cnt = da.delayed(centroids)(
                cur_ss['A_shifted'], window.sel(animal=anm))
            cnt_list.append(cnt)
    with ProgressBar():
        cnt_list, = da.compute(cnt_list)
    cnts_ds = pd.concat(cnt_list, ignore_index=True)
    cnts_ds.height = cnts_ds.height.astype(float)
    cnts_ds.width = cnts_ds.width.astype(float)
    cnts_ds.unit_id = cnts_ds.unit_id.astype(int)
    cnts_ds.animal = cnts_ds.animal.astype(str)
    cnts_ds.session = cnts_ds.session.astype(str)
    cnts_ds.session_id = cnts_ds.session_id.astype(str)
    return cnts_ds


def centroids(A, window=None):
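    """Compute the center of mass of each unit's spatial footprint in `A`,
    restricted to `window` (by default, pixels that are valid for every unit).
    Returns a DataFrame with one row per unit containing the centroid
    coordinates, the 'unit_id' and any non-dimensional coordinates."""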
    A = A.load().dropna('unit_id', how='all')
    if not A.size > 0:
        return pd.DataFrame()
    if window is None:
        window = A.isnull().sum('unit_id') == 0
    try:
        A = A.where(window, drop=True)
    except Exception:
        # Drop into the debugger on an unexpected window/footprint mismatch.
        set_trace()
    A = A.fillna(0)
    meta_dims = set(A.coords.keys()) - set(A.dims)
    meta_dict = {dim: A.coords[dim].values for dim in meta_dims}
    cur_meta = pd.Series(meta_dict)
    cts_list = []
    for uid, cur_uA in A.groupby('unit_id'):
        cur_A = cur_uA.values
        if not (cur_A > 0).any():
            continue
        cur_idxs = cur_uA.dims
        cur_cts = center_of_mass(cur_A)
        cur_cts = pd.Series(cur_cts, index=cur_idxs)
        # Series.append was removed in pandas 2.0; use pd.concat instead.
        cur_cts = pd.concat(
            [cur_cts, pd.Series(dict(unit_id=uid)), cur_meta])
        cts_list.append(cur_cts)
    try:
        cts_df = pd.concat(cts_list, axis=1, ignore_index=True).T
    except ValueError:
        cts_df = pd.DataFrame()
    return cts_df


def calculate_centroid_distance(cents, by='session', index_dim=['animal'], tile=(50, 50)):
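    """Compute pairwise centroid distances between all pairs of `by` values
    (e.g. sessions) within each group defined by `index_dim`. To avoid an
    all-to-all comparison, candidate unit pairs are restricted to those falling
    within a common spatial tile (see `subset_pairs`). Returns a DataFrame with
    MultiIndex columns ('session', ...), ('variable', 'distance') and
    ('meta', ...)."""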
    res_list = []

    def cent_pair(grp):
        dist_ls = []
        len_df = 0
        for (byA, grpA), (byB, grpB) in itt.combinations(list(grp.groupby(by)), 2):
            cur_pairs = subset_pairs(grpA, grpB, tile)
            pairs_ls = list(cur_pairs)
            len_df = len_df + len(pairs_ls)
            subA = (grpA.set_index('unit_id')
                    .loc[[p[0] for p in pairs_ls]]
                    .reset_index())
            subB = (grpB.set_index('unit_id')
                    .loc[[p[1] for p in pairs_ls]]
                    .reset_index())
            dist = da.delayed(pd_dist)(subA, subB).rename('distance')
            dist_df = da.delayed(pd.concat)(
                [subA['unit_id'].rename(byA), subB['unit_id'].rename(byB), dist], axis='columns')
            dist_df = dist_df.rename(columns={
                'distance': ('variable', 'distance'),
                byA: (by, byA),
                byB: (by, byB)})
            dist_ls.append(dist_df)
        # Accumulate distances over every pair of sessions; returning inside the
        # loop above would silently drop all but the first pair.
        dist_df = da.delayed(pd.concat)(dist_ls, ignore_index=True)
        return dist_df, len_df

    print("creating parallel schedule")
    if index_dim:
        for idxs, grp in cents.groupby(index_dim):
            dist_df, len_df = cent_pair(grp)
            if type(idxs) is not tuple:
                idxs = (idxs,)
            meta_df = pd.concat(
                [pd.Series([idx] * len_df, name=('meta', dim))
                 for idx, dim in zip(idxs, index_dim)],
                axis='columns')
            res_df = da.delayed(pd.concat)([meta_df, dist_df], axis='columns')
            res_list.append(res_df)
    else:
        res_list = [cent_pair(cents)[0]]
    print("computing distances")
    res_list = da.compute(res_list)[0]
    res_df = pd.concat(res_list, ignore_index=True)
    res_df.columns = pd.MultiIndex.from_tuples(res_df.columns)
    return res_df


def subset_pairs(A, B, tile):
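    """Return the set of (unit_A, unit_B) candidate pairs whose centroids fall
    within a common tile of size `tile`, tiling the joint spatial extent of `A`
    and `B` with half-tile spacing."""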
    Ah, Aw, Bh, Bw = A['height'], A['width'], B['height'], B['width']
    hh = (min(Ah.min(), Bh.min()), max(Ah.max(), Bh.max()))
    ww = (min(Aw.min(), Bw.min()), max(Aw.max(), Bw.max()))
    dh, dw = np.ceil(tile[0] / 2), np.ceil(tile[1] / 2)
    # np.linspace requires an integer number of samples.
    tile_h = np.linspace(hh[0], hh[1], int(np.ceil((hh[1] - hh[0]) * 2 / tile[0])))
    tile_w = np.linspace(ww[0], ww[1], int(np.ceil((ww[1] - ww[0]) * 2 / tile[1])))
    pairs = set()
    for h, w in itt.product(tile_h, tile_w):
        curA = A[
            Ah.between(h - dh, h + dh)
            & Aw.between(w - dw, w + dw)]
        curB = B[
            Bh.between(h - dh, h + dh)
            & Bw.between(w - dw, w + dw)]
        Au, Bu = curA['unit_id'].values, curB['unit_id'].values
        pairs.update(
            set(map(tuple, cartesian(Au, Bu).tolist())))
    return pairs


def pd_dist(A, B):
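    """Euclidean distance between the ('height', 'width') columns of two
    row-aligned DataFrames."""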
    return np.sqrt(
        ((A[['height', 'width']] - B[['height', 'width']])**2)
        .sum('columns'))

def cartesian(*args):
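    """Cartesian product of the input 1-d arrays, returned as an array of shape
    (n_combinations, n_args)."""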
    n = len(args)
    return np.array(np.meshgrid(*args)).T.reshape((-1, n))


def calculate_centroid_distance_old(cents,
                                A,
                                window,
                                grp_dim=['animal'],
                                tile=(50, 50),
                                shift=True,
                                hamming=True,
                                corr=False):
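    """Legacy implementation of pairwise centroid distance, dispatching per
    animal to `centroids_distance_old`."""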
    dist_list = []
    A = da.delayed(A)
    for cur_anm, cur_grp in cents.groupby('animal'):
        print("processing animal: {}".format(cur_anm))
        cur_A = A.sel(animal=cur_anm)
        cur_wnd = window.sel(animal=cur_anm)
        dist = centroids_distance_old(cur_grp, cur_A, cur_wnd, shift, hamming,
                                      corr, tile)
        dist['meta', 'animal'] = cur_anm
        dist_list.append(dist)
    dist = pd.concat(dist_list, ignore_index=True)
    return dist


def centroids_distance_old(cents,
                       A,
                       window,
                       shift,
                       hamming,
                       corr,
                       tile=(50, 50)):
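    """Legacy helper: compute delayed pairwise distances between all session
    pairs for one animal."""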
    sessions = cents['session'].unique()
    dim_h = (np.min(cents['height']), np.max(cents['height']))
    dim_w = (np.min(cents['width']), np.max(cents['width']))
    dist_list = []
    for ssA, ssB in itt.combinations(sessions, 2):
        # dist = _calc_cent_dist(ssA, ssB, cents, cnmds, window, tile, dim_h, dim_w)
        dist = da.delayed(_calc_cent_dist_old)(ssA, ssB, cents, A, window,
                                               tile, dim_h, dim_w, shift, hamming,
                                               corr)
        dist_list.append(dist)
    with ProgressBar():
        dist_list, = da.compute(dist_list)
    dists = pd.concat(dist_list, ignore_index=True)
    return dists


def _calc_cent_dist_old(ssA, ssB, cents, A, window, tile, dim_h, dim_w, shift,
                    hamming, corr):
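    """Legacy helper: centroid distance (and optionally Pearson correlation and
    Hamming distance of footprints) between candidate unit pairs of two
    sessions."""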
    ssA_df = cents[cents['session'] == ssA]
    ssB_df = cents[cents['session'] == ssB]
    ssA_uids = ssA_df['unit_id'].unique()
    ssB_uids = ssB_df['unit_id'].unique()
    ssA_h = ssA_df['height']
    ssA_w = ssA_df['width']
    ssB_h = ssB_df['height']
    ssB_w = ssB_df['width']
    # np.linspace requires an integer number of samples.
    tile_ct_h = np.linspace(dim_h[0], dim_h[1],
                            int(np.ceil((dim_h[1] - dim_h[0]) * 2.0 / tile[0])) + 1)
    tile_ct_w = np.linspace(dim_w[0], dim_w[1],
                            int(np.ceil((dim_w[1] - dim_w[0]) * 2.0 / tile[1])) + 1)
    dh = np.ceil(tile[0] / 2.0)
    dw = np.ceil(tile[1] / 2.0)
    pairs = set()
    for ct_h, ct_w in itt.product(tile_ct_h, tile_ct_w):
        ssA_uid_inrange = ssA_uids[(ct_h - dh < ssA_h)
                                   & (ssA_h < ct_h + dh)
                                   & (ct_w - dw < ssA_w) & (ssA_w < ct_w + dw)]
        ssB_uid_inrange = ssB_uids[(ct_h - dh < ssB_h)
                                   & (ssB_h < ct_h + dh)
                                   & (ct_w - dw < ssB_w) & (ssB_w < ct_w + dw)]
        for pair in itt.product(ssA_uid_inrange, ssB_uid_inrange):
            pairs.add(pair)
    dist_list = []
    for ip, (uidA, uidB) in enumerate(pairs):
        idxarr = [[
            'session', 'session', 'variable', 'variable', 'variable',
            'variable'
        ], [ssA, ssB, 'distance', 'coeff', 'p', 'hamming']]
        mulidx = pd.MultiIndex.from_arrays(
            idxarr, names=('var_class', 'var_name'))
        centA = ssA_df[ssA_df['unit_id'] == uidA][['height', 'width']]
        centB = ssB_df[ssB_df['unit_id'] == uidB][['height', 'width']]
        diff = centA.reset_index(drop=True) - centB.reset_index(drop=True)
        diff = diff.T.squeeze()
        cur_dist = np.sqrt((diff**2).sum())
        if corr or hamming:
            cur_A_A = A.sel(
                session=ssA, unit_id=uidA).where(
                    window, drop=True)
            cur_A_B = A.sel(
                session=ssB, unit_id=uidB).where(
                    window, drop=True)
        if shift:
            cur_A_B = cur_A_B.shift(**diff.round().astype(int).to_dict())
            # wnd_new = cur_A_B.notnull()
            wnd_new = (cur_A_B + cur_A_B) > 0
            cur_A_A = cur_A_A.where(wnd_new, drop=True).fillna(0)
            cur_A_B = cur_A_B.where(wnd_new, drop=True).fillna(0)
        if corr:
            cur_coef, cur_p = pearsonr(cur_A_A.values.flatten(),
                                       cur_A_B.values.flatten())
        else:
            cur_coef, cur_p = np.nan, np.nan
        if hamming:
            ham = xr.apply_ufunc(
                np.absolute, (cur_A_A > 0) - (cur_A_B > 0),
                dask='allowed').sum()
            uni = ((cur_A_A + cur_A_B) > 0).sum()
            # np.asscalar was removed in numpy 1.23; use .item() instead.
            ham = (ham / uni).values.item()
        else:
            ham = np.nan
        dist = pd.Series(
            [uidA, uidB, cur_dist, cur_coef, cur_p, ham], index=mulidx)
        dist_list.append(dist)
    dists = pd.concat(dist_list, axis=1, ignore_index=True).T
    return dists


def group_by_session(df):
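    """Add a ('group', 'group') column holding, for each row, the tuple of
    sessions with a non-null unit id."""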
    ss = df['session'].notnull()
    grp = ss.apply(lambda r: tuple(r.index[r].tolist()), axis=1)
    df['group', 'group'] = grp
    return df


def calculate_mapping(dist):
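    """Select, per animal if a ('meta', 'animal') column is present, the rows
    corresponding to mutually closest centroid pairs (see `mapping`)."""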
    map_idxs = set()
    try:
        for anm, grp in dist.groupby(dist['meta', 'animal']):
            map_idxs.update(mapping(grp))
    except KeyError:
        map_idxs = mapping(dist)
    return dist.loc[list(map_idxs)]


def mapping(dist):
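    """Within each session group, keep only the rows that minimize distance for
    every session involved (the intersection of per-session argmin indices)."""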
    map_list = set()
    for sess, grp in dist.groupby(dist['group', 'group']):
        minidx_list = []
        for ss in sess:
            minidx = set()
            for uid, uid_grp in grp.groupby(grp['session', ss]):
                minidx.add(uid_grp['variable', 'distance'].idxmin())
            minidx_list.append(minidx)
        minidxs = set.intersection(*minidx_list)
        map_list.update(minidxs)
    return map_list


def resolve_mapping(mapping):
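    """Merge mappings that share units across session pairs into multi-session
    mappings, per animal if a ('meta', 'animal') column is present (see
    `resolve`)."""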
    map_list = []
    try:
        for anm, grp in mapping.groupby(mapping['meta', 'animal']):
            map_list.append(resolve(grp))
    except KeyError:
        map_list = [resolve(mapping)]
    return pd.concat(map_list, ignore_index=True)


def resolve(mapping):
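    """Collapse rows that share a unit in some session into a single row when
    they agree on the unit id in every other session; conflicting rows are
    dropped. The distance of merged rows is set to NaN."""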
    mapping = mapping.reset_index(drop=True)
    map_ss = mapping['session']
    for ss in map_ss.columns:
        del_idx = []
        for ss_uid, ss_grp in mapping.groupby(mapping['session', ss]):
            if ss_grp.shape[0] > 1:
                del_idx.extend(ss_grp.index)
                new_sess = []
                for s in ss_grp['session']:
                    uval = ss_grp['session', s].dropna().unique()
                    if len(uval) == 0:
                        new_sess.append(np.nan)
                    elif len(uval) == 1:
                        new_sess.append(uval[0])
                    elif len(uval) > 1:
                        break
                else:
                    new_row = ss_grp.iloc[0].copy()
                    new_row['session'] = new_sess
                    new_row['variable', 'distance'] = np.nan
                    # DataFrame.append was removed in pandas 2.0; use pd.concat.
                    mapping = pd.concat(
                        [mapping, new_row.to_frame().T], ignore_index=True)
        mapping = mapping.drop(del_idx).reset_index(drop=True)
    return group_by_session(mapping)


def fill_mapping(mappings,
                 cents):
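    """Append, for each session, the unit ids present in `cents` but absent from
    `mappings`, so that unmatched units are represented as single-session
    rows."""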

    def fill(cur_grp, cur_cent):
        fill_ls = []
        for cur_ss in list(cur_grp['session']):
            cur_ss_grp = cur_grp['session'][cur_ss].dropna()
            cur_ss_all = cur_cent[cur_cent['session'] == cur_ss][
                'unit_id'].dropna()
            cur_fill_set = set(cur_ss_all.unique()) - set(
                cur_ss_grp.unique())
            cur_fill_df = pd.DataFrame({
                ('session', cur_ss):
                list(cur_fill_set),
            })
            fill_ls.append(cur_fill_df)
        return pd.concat(fill_ls, ignore_index=True)

    try:
        for cur_id, cur_grp in mappings.groupby(list(mappings['meta'])):
            cur_cent = (cents.set_index(list(mappings['meta']))
                        .loc[cur_id].reset_index())
            cur_grp_fill = fill(cur_grp, cur_cent)
            mappings = pd.concat([mappings, cur_grp_fill], ignore_index=True)
    except KeyError:
        map_fill = fill(mappings, cents)
        mappings = pd.concat([mappings, map_fill], ignore_index=True)
    return mappings