import numpy as np
import xarray as xr
import pandas as pd
import dask as da
import numba as nb
import dask.array.fft as dafft
import dask.array.linalg as dalin
import dask.array as darr
import functools as fct
# import dask_ml.joblib
from dask import delayed, compute
from dask.diagnostics import Profiler
from dask.distributed import progress
from scipy.ndimage import gaussian_filter, label
from scipy.signal import welch, butter, lfilter
from scipy.sparse import diags, dia_matrix
from scipy.linalg import toeplitz, lstsq
from scipy.spatial.distance import pdist, squareform
from sklearn.linear_model import LassoLars
from joblib import parallel_backend
from numba import jit, guvectorize
from skimage import morphology as moph
from statsmodels.tsa.stattools import acovf
import networkx as nx
import cvxpy as cvx
import pyfftw.interfaces.dask_fft as fftw
from timeit import timeit
import warnings
from .utilities import get_chk, rechunk_like

def get_noise_fft(varr, noise_range=(0.25, 0.5), noise_method='logmexp'):
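    """Estimate per-pixel noise as the averaged power of the FFT spectrum
    within ``noise_range`` (normalized frequency, 0.5 = Nyquist).

    A minimal usage sketch, assuming ``varr`` is an ``xr.DataArray`` movie
    with a ``frame`` dimension (the variable name is illustrative)::

        sn = get_noise_fft(varr, noise_range=(0.25, 0.5),
                           noise_method='logmexp').compute()
    """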
    sn = xr.apply_ufunc(
        _noise_fft,
        varr.chunk(dict(frame=-1)),
        input_core_dims=[['frame']],
        output_core_dims=[[]],
        dask='parallelized',
        vectorize=True,
        kwargs=dict(
            noise_range=noise_range,
            noise_method=noise_method),
        output_dtypes=[float]
    )
    return sn

def _noise_fft(px, noise_range=(0.25, 0.5), noise_method='logmexp'):
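    """Single-pixel kernel behind :func:`get_noise_fft`: compute the
    periodogram ``psd = |rfft(px)|**2 / T``, keep the bins falling inside
    ``noise_range``, and reduce them with ``noise_method`` ('logmexp' takes
    ``sqrt(exp(mean(log(psd))))``, a robust geometric-mean estimate).
    """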
    _T = len(px)
    # rfft bins are spaced 1 / _T apart in normalized frequency
    # (0.5 = Nyquist), so frequency f maps to bin index f * _T
    nr = np.around(np.array(noise_range) * _T).astype(int)
    px_fft = np.fft.rfft(px)
    px_psd = 1 / _T * np.abs(px_fft)**2
    px_band = px_psd[nr[0]:nr[1]]
    if noise_method == 'mean':
        return np.sqrt(px_band.mean())
    elif noise_method == 'median':
        return np.sqrt(np.median(px_band))
    elif noise_method == 'logmexp':
        eps = np.finfo(px_band.dtype).eps
        return np.sqrt(np.exp(np.log(px_band + eps).mean()))
    elif noise_method == 'sum':
        return np.sqrt(px_band.sum())

def psd_fft(varr):
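    """Compute the full periodogram of ``varr`` along ``frame`` via rfft,
    returning power indexed by a normalized ``freq`` coordinate in
    [0, 0.5] (0.5 = Nyquist).
    """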
    _T = len(varr.coords['frame'])
    ns = _T // 2 + 1
    if _T % 2 == 0:
        freq_crd = np.linspace(0, 0.5, ns)
    else:
        freq_crd = np.linspace(0, 0.5 * (_T - 1) / _T, ns)
    print("computing psd of input")
    varr_fft = xr.apply_ufunc(
        fftw.rfft,
        varr.chunk(dict(frame=-1)),
        input_core_dims=[['frame']],
        output_core_dims=[['freq']],
        dask='allowed',
        output_sizes=dict(freq=ns),
        output_dtypes=[np.complex128])
    varr_fft = varr_fft.assign_coords(freq=freq_crd)
    varr_psd = 1 / _T * np.abs(varr_fft)**2
    return varr_psd

def psd_welch(varr):
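    """Per-pixel power spectral density via scipy's welch estimator; with
    ``nperseg`` set to the full trace length this is a single windowed
    periodogram, matching the ``freq`` convention of :func:`psd_fft`.
    """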
    _T = len(varr.coords['frame'])
    ns = _T // 2 + 1
    if _T % 2 == 0:
        freq_crd = np.linspace(0, 0.5, ns)
    else:
        freq_crd = np.linspace(0, 0.5 * (_T - 1) / _T, ns)
    varr_psd = xr.apply_ufunc(
            _welch,
            varr.chunk(dict(frame=-1)),
            input_core_dims=[['frame']],
            output_core_dims=[['freq']],
            dask='parallelized',
            vectorize=True,
            kwargs=dict(nperseg=_T),
            output_sizes=dict(freq=ns),
            output_dtypes=[varr.dtype])
    varr_psd = varr_psd.assign_coords(freq=freq_crd)
    return varr_psd

def _welch(x, **kwargs):
    return welch(x, **kwargs)[1]

def get_noise(psd, noise_range=(0.25, 0.5), noise_method='logmexp'): 
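    """Reduce a precomputed power spectral density (from :func:`psd_fft` or
    :func:`psd_welch`) to a per-pixel noise level over ``noise_range``.

    A minimal sketch of the intended call chain (names illustrative)::

        psd = psd_welch(varr)
        sn = get_noise(psd, noise_range=(0.25, 0.5))
    """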
    psd_band = psd.sel(freq=slice(*noise_range))
    print("estimating noise using method {}".format(noise_method))
    if noise_method == 'mean':
        sn = np.sqrt(psd_band.mean('freq'))
    elif noise_method == 'median':
        sn = np.sqrt(psd_band.median('freq'))
    elif noise_method == 'logmexp':
        eps = np.finfo(psd_band.dtype).eps
        sn = np.sqrt(np.exp(np.log(psd_band + eps).mean('freq')))
    sn = sn.persist()
    return sn 

def get_noise_welch(varr,
                    noise_range=(0.25, 0.5),
                    noise_method='logmexp',
                    compute=True):
    print("estimating noise")
    sn = xr.apply_ufunc(
        noise_welch,
        varr.chunk(dict(frame=-1)),
        input_core_dims=[['frame']],
        dask='parallelized',
        vectorize=True,
        kwargs=dict(noise_range=noise_range, noise_method=noise_method),
        output_dtypes=[varr.dtype])
    if compute:
        sn = sn.compute()
    return sn


def noise_welch(y, noise_range, noise_method):
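    """Welch-based counterpart of :func:`_noise_fft`, operating on a single
    trace ``y``; the ``Pxx / 2`` scaling keeps the estimate comparable with
    the FFT-based one.
    """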
    ff, Pxx = welch(y)
    mask0, mask1 = ff > noise_range[0], ff < noise_range[1]
    mask = np.logical_and(mask0, mask1)
    Pxx_ind = Pxx[mask]
    sn = {
        'mean': lambda x: np.sqrt(np.mean(x / 2)),
        'median': lambda x: np.sqrt(np.median(x / 2)),
        'logmexp': lambda x: np.sqrt(np.exp(np.mean(np.log(x / 2))))
    }[noise_method](Pxx_ind)
    return sn


def update_spatial(Y,
                   A,
                   b,
                   C,
                   f,
                   sn,
                   gs_sigma=6,
                   dl_wnd=5,
                   sparse_penal=0.5,
                   update_background=True,
                   post_scal=False,
                   normalize=True,
                   zero_thres='eps'):
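    """Update spatial footprints ``A`` (and optionally background ``b``) by
    regressing each pixel's trace on the temporal components with a
    non-negative LassoLars fit; ``dl_wnd`` dilates current footprints to
    restrict which units each pixel may load onto, and ``sparse_penal``
    scales the l1 penalty.

    A minimal call sketch, assuming the usual CNMF variables (``Y`` movie,
    ``A``/``b`` spatial terms, ``C``/``f`` temporal terms, ``sn`` from
    :func:`get_noise_fft`)::

        A_new, b_new, C_new, f_new = update_spatial(
            Y, A, b, C, f, sn, dl_wnd=5, sparse_penal=0.5)
    """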
    _T = len(Y.coords['frame'])
    print("estimating penalty parameter")
    cct = C.dot(C, 'frame')
    alpha = sparse_penal * sn * np.sqrt(np.max(np.diag(cct))) / _T
    alpha = alpha.persist()
    print("computing subsetting matrix")
    if dl_wnd:
        selem = moph.disk(dl_wnd)
        sub = xr.apply_ufunc(
            moph.dilation,
            A.fillna(0).chunk(dict(height=-1, width=-1)),
            input_core_dims=[['height', 'width']],
            output_core_dims=[['height', 'width']],
            vectorize=True,
            kwargs=dict(selem=selem),
            dask='parallelized',
            output_dtypes=[A.dtype])
        sub = (sub > 0)
    else:
        sub = xr.apply_ufunc(np.ones_like, A.compute())
    sub = sub.compute().astype(bool).transpose(*A.dims).chunk(A.chunks)
    if update_background:
        A = xr.concat([A, b.assign_coords(unit_id=-1)], 'unit_id')
        b_erd = xr.apply_ufunc(
            moph.erosion,
            b.chunk(dict(height=-1, width=-1)),
            input_core_dims=[['height', 'width']],
            output_core_dims=[['height', 'width']],
            # selem is only defined when dl_wnd is set; fall back to the
            # default footprint otherwise
            kwargs=dict(selem=selem if dl_wnd else None),
            dask='parallelized',
            output_dtypes=[b.dtype])
        sub = xr.concat([
            sub, (b_erd > 0).astype(bool).assign_coords(unit_id=-1)],
                        'unit_id')
        C = xr.concat([C, f.assign_coords(unit_id=-1)], 'unit_id')
    print("fitting spatial matrix")
    gu_update = darr.gufunc(
        fct.partial(
            update_spatial_perpx,
            C=C.transpose('frame', 'unit_id').values),
        signature="(f),(),(u)->(u)",
        output_dtypes=A.dtype,
        vectorize=True)
    A_new = xr.apply_ufunc(
        gu_update,
        Y.chunk(dict(frame=-1)),
        alpha,
        sub.chunk(dict(unit_id=-1)),
        input_core_dims=[['frame'], [], ['unit_id']],
        output_core_dims=[['unit_id']],
        dask='allowed')
    A_new = A_new.persist()
    print("removing empty units")
    if zero_thres == 'eps':
        zero_thres = np.finfo(A_new.dtype).eps
    A_new = A_new.where(A_new > zero_thres).fillna(0)
    non_empty = A_new.sum(['width', 'height']) > 0
    A_new = A_new.where(non_empty, drop=True)
    C_new = C.where(non_empty, drop=True)
    A_new = rechunk_like(A_new, A).persist()
    C_new = rechunk_like(C_new, C).persist()
    if post_scal and len(A_new) > 0:
        print("post-hoc scaling")
        A_new_flt = (A_new.stack(spatial=['height', 'width'])
                     .compute())
        Y_flt = (Y.mean('frame').stack(spatial=['height', 'width'])
                 .compute())
        def lstsq(a, b):
            return np.linalg.lstsq(a, b, rcond=-1)[0]
        try:
            scale = xr.apply_ufunc(
                lstsq,
                A_new_flt,
                Y_flt,
                input_core_dims=[['spatial', 'unit_id'], ['spatial']],
                output_core_dims=[['unit_id']])
            C_mean = C.mean('frame').compute()
            scale = scale / C_mean
            A_new = (A_new * scale).persist()
        except np.linalg.LinAlgError:
            warnings.warn("post-hoc scaling failed", RuntimeWarning)
    if update_background:
        print("updating background")
        try:
            b_new = A_new.sel(unit_id=-1)
            b_new = b_new / da.array.linalg.norm(b_new.data)
            f_new = xr.apply_ufunc(
                da.array.tensordot, Y, b_new,
                input_core_dims=[['frame', 'height', 'width'], ['height', 'width']],
                output_core_dims=[['frame']],
                kwargs=dict(axes=[(1, 2), (0, 1)]),
                dask='allowed').persist()
            A_new = A_new.drop_sel(unit_id=-1)
            C_new = C_new.drop_sel(unit_id=-1)
        except KeyError:
            print("background terms are empty")
            b_new = xr.zeros_like(b)
            f_new = xr.zeros_like(f)
    else:
        b_new = b
        f_new = f
    if normalize and len(A_new) > 0:
        print("normalizing result")
        A_norm = xr.apply_ufunc(
            darr.linalg.norm,
            A_new.stack(spatial=['height', 'width']),
            input_core_dims=[['spatial', 'unit_id']],
            output_core_dims=[['unit_id']],
            kwargs=dict(axis=0),
            dask='allowed')
        A_new = (A_new / A_norm).persist()
    return A_new, b_new, C_new, f_new


def update_spatial_perpx(y, alpha, sub, C):
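    """Per-pixel worker for :func:`update_spatial`: fit the pixel trace
    ``y`` with a positive LassoLars over the units selected by the boolean
    mask ``sub``, returning a dense coefficient vector over all units.
    """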
    res = np.zeros_like(sub, dtype=y.dtype)
    if np.sum(sub) > 0:
        C = C[:, sub]
        clf = LassoLars(alpha=alpha, positive=True)
        coef = clf.fit(C, y).coef_
        res[np.where(sub)[0]] = coef
    return res

def compute_trace(Y, A, b, C, f, noise_freq=None):
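    """Compute the residual-corrected temporal traces
    ``YrA = (Y - b f) A diag(1/||a_i||^2) - C (A^T A) diag(1/||a_i||^2) + C``,
    optionally low-pass filtered at ``noise_freq``.
    """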
    nunits = len(A.coords['unit_id'])
    A_rechk = A.chunk(dict(height=-1, width=-1))
    C_rechk = C.chunk(dict(unit_id=-1))
    Y_rechk = Y.chunk(dict(height=-1, width=-1))
    AA = xr.apply_ufunc(
        da.array.tensordot,
        A_rechk,
        A_rechk.rename(dict(unit_id='unit_id_cp')),
        input_core_dims=[['unit_id', 'height', 'width'], ['height', 'width', 'unit_id_cp']],
        output_core_dims=[['unit_id', 'unit_id_cp']],
        dask='allowed',
        kwargs=dict(axes=([1, 2], [0, 1])),
        output_dtypes=[A.dtype])
    nA = (A_rechk**2).sum(['height', 'width']).compute()
    nA_inv = xr.apply_ufunc(
        lambda x: np.asarray(diags(x).todense()),
        1 / nA,
        input_core_dims=[['unit_id']],
        output_core_dims=[['unit_id', 'unit_id_cp']],
        dask='parallelized',
        output_dtypes=[nA.dtype],
        output_sizes=dict(unit_id_cp=nunits)).compute()
    nA_inv = nA_inv.assign_coords(unit_id_cp=AA.coords['unit_id_cp'])
    b = b.fillna(0).expand_dims('dot').chunk(dict(height=-1, width=-1))
    f = f.fillna(0).expand_dims('dot')
    B = xr.apply_ufunc(
        da.array.dot,
        b,
        f,
        input_core_dims=[['height', 'width', 'dot'], ['dot', 'frame']],
        output_core_dims=[['height', 'width', 'frame']],
        dask='allowed',
        output_dtypes=[b.dtype])
    Y = Y_rechk - B
    YA = (xr.apply_ufunc(
        da.array.tensordot,
        Y,
        A_rechk,
        input_core_dims=[['frame', 'height', 'width'], ['height', 'width', 'unit_id']],
        output_core_dims=[['frame', 'unit_id']],
        dask='allowed',
        kwargs=dict(axes=([1, 2], [0, 1])),
        output_dtypes=[A.dtype])
          .rename(dict(unit_id='unit_id_cp')))
    YA_norm = xr.apply_ufunc(
        da.array.dot,
        YA,
        nA_inv,
        input_core_dims=[['frame', 'unit_id_cp'], ['unit_id_cp', 'unit_id']],
        output_core_dims=[['frame', 'unit_id']],
        dask='allowed',
        output_dtypes=[YA.dtype])
    CA = xr.apply_ufunc(
        da.array.dot,
        C_rechk,
        AA.chunk(dict(unit_id=-1, unit_id_cp=-1)),
        input_core_dims=[['frame', 'unit_id'], ['unit_id', 'unit_id_cp']],
        output_core_dims=[['frame', 'unit_id_cp']],
        dask='allowed',
        output_dtypes=[C.dtype])
    CA_norm = xr.apply_ufunc(
        da.array.dot,
        CA,
        nA_inv,
        input_core_dims=[['frame', 'unit_id_cp'], ['unit_id_cp', 'unit_id']],
        output_core_dims=[['frame', 'unit_id']],
        dask='allowed',
        output_dtypes=[CA.dtype])
    YrA = YA_norm - CA_norm + C_rechk
    if noise_freq:
        print("smoothing signals")
        but_b, but_a = butter(2, noise_freq, btype='low', analog=False)
        YrA_smth = xr.apply_ufunc(
            lambda x: lfilter(but_b, but_a, x),
            YrA.chunk(dict(frame=-1)),
            input_core_dims=[['frame']],
            output_core_dims=[['frame']],
            vectorize=True,
            dask='parallelized',
            output_dtypes=[YrA.dtype])
    else:
        YrA_smth = YrA
    return YrA_smth


def update_temporal(Y,
                    A,
                    b,
                    C,
                    f,
                    sn_spatial,
                    YrA=None,
                    noise_freq=0.25,
                    p=None,
                    add_lag='p',
                    jac_thres = 0.1,
                    use_spatial=False,
                    sparse_penal=1,
                    zero_thres=1e-8,
                    max_iters=200,
                    use_smooth=True,
                    compute=True,
                    normalize=True,
                    post_scal=True,
                    scs_fallback=False):
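    """Update temporal components by grouping spatially overlapping units
    (Jaccard index above ``jac_thres``), estimating per-unit noise and AR
    coefficients, then solving the constrained deconvolution problem in
    :func:`update_temporal_cvxpy` for each group.

    A minimal call sketch with the usual CNMF variables (names follow the
    rest of this module)::

        (YrA, C_new, S_new, B_new, C0_new,
         sig_new, g_new, scal) = update_temporal(
            Y, A, b, C, f, sn_spatial, noise_freq=0.25)
    """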
    print("grouping overlaping units")
    A_pos = (A > 0).astype(int)
    A_neg = (A == 0).astype(int)
    A_inter = xr.apply_ufunc(
        da.array.tensordot,
        A_pos,
        A_pos.rename(unit_id='unit_id_cp'),
        input_core_dims=[['unit_id', 'height', 'width'], ['height', 'width', 'unit_id_cp']],
        output_core_dims=[['unit_id', 'unit_id_cp']],
        dask='allowed',
        kwargs=dict(axes=([1, 2], [0, 1])),
        output_dtypes=[A_pos.dtype])
    A_union = xr.apply_ufunc(
        da.array.tensordot,
        A_neg,
        A_neg.rename(unit_id='unit_id_cp'),
        input_core_dims=[['unit_id', 'height', 'width'], ['height', 'width', 'unit_id_cp']],
        output_core_dims=[['unit_id', 'unit_id_cp']],
        dask='allowed',
        kwargs=dict(axes=([1, 2], [0, 1])),
        output_dtypes=[A_neg.dtype])
    A_jac = A_inter / (A.sizes['height'] * A.sizes['width'] - A_union)
    if compute:
        A_jac = A_jac.compute()
    unit_labels = xr.apply_ufunc(
        label_connected,
        A_jac > jac_thres,
        input_core_dims=[['unit_id', 'unit_id_cp']],
        kwargs=dict(only_connected=True),
        output_core_dims=[['unit_id']])
    if YrA is None:
        print("computing trace")
        YrA = compute_trace(Y, A, b, C, f).persist()
    YrA = YrA.chunk(dict(frame=-1, unit_id=1))
    YrA = YrA.assign_coords(unit_labels=unit_labels)
    if normalize:
        print("normalizing traces")
        YrA_norm = (YrA / YrA.sum('frame') * YrA.sizes['frame']).persist()
    else:
        YrA_norm = YrA
    sn_temp = get_noise_fft(
        YrA_norm, noise_range=(noise_freq, 1),
        noise_method='sum').persist()
    sn_temp = sn_temp.assign_coords(unit_labels=unit_labels)
    if use_spatial:
        print("flattening spatial dimensions")
        Y_flt = Y.stack(spatial=('height', 'width'))
        A_flt = A.stack(spatial=(
            'height', 'width')).assign_coords(unit_labels=unit_labels)
        sn_spatial = sn_spatial.stack(spatial=('height', 'width'))
    if use_smooth:
        print("smoothing signals")
        but_b, but_a = butter(2, noise_freq, btype='low', analog=False)
        YrA_smth = xr.apply_ufunc(
            lambda x: lfilter(but_b, but_a, x),
            YrA_norm,
            input_core_dims=[['frame']],
            output_core_dims=[['frame']],
            vectorize=True,
            dask='parallelized',
            output_dtypes=[YrA.dtype])
        if compute:
            YrA_smth = YrA_smth.persist()
            sn_temp_smth = get_noise_fft(
                YrA_smth, noise_range=(noise_freq, 1)).persist()
            sn_temp_smth = sn_temp_smth.assign_coords(unit_labels=unit_labels)
    else:
        YrA_smth = YrA_norm
        sn_temp_smth = sn_temp
    if p is None:
        print("estimating order p for each neuron")
        p = xr.apply_ufunc(
            get_p,
            YrA_smth,
            input_core_dims=[['frame']],
            vectorize=True,
            dask='parallelized',
            output_dtypes=[int]).clip(1)
        if compute:
            p = p.compute()
        p_max = p.max().values
    else:
        p_max = p
    print("estimating AR coefficients")
    g = xr.apply_ufunc(
        get_ar_coef,
        YrA_smth.chunk(dict(frame=-1)),
        sn_temp_smth,
        p,
        input_core_dims=[['frame'], [], []],
        output_core_dims=[['lag']],
        kwargs=dict(pad=p_max, add_lag=add_lag),
        vectorize=True,
        dask='parallelized',
        output_dtypes=[sn_temp_smth.dtype],
        output_sizes=dict(lag=p_max))
    g = g.assign_coords(lag=np.arange(1, p_max + 1), unit_labels=unit_labels)
    if compute:
        g = g.persist()
    print("updating isolated temporal components")
    if use_spatial == 'full':
        result_iso = xr.apply_ufunc(
            update_temporal_cvxpy,
            Y_flt.chunk(dict(spatial=-1, frame=-1)),
            g.where(unit_labels == -1, drop=True).chunk(dict(lag=-1)),
            sn_spatial.chunk(dict(spatial=-1)),
            A_flt.where(unit_labels == -1, drop=True).chunk(dict(spatial=-1)),
            input_core_dims=[['spatial', 'frame'], ['lag'], ['spatial'],
                             ['spatial']],
            output_core_dims=[['trace', 'frame']],
            vectorize=True,
            dask='parallelized',
            kwargs=dict(
                sparse_penal=sparse_penal,
                max_iters=max_iters,
                scs_fallback=scs_fallback),
            output_sizes=dict(trace=5),
            output_dtypes=[YrA.dtype])
    else:
        gu_update = darr.gufunc(
            fct.partial(
                update_temporal_cvxpy,
                sparse_penal=sparse_penal,
                max_iters=max_iters,
                scs_fallback=scs_fallback),
            signature="(f),(l),()->(t,f)",
            vectorize=True,
            output_dtypes=[YrA.dtype],
            output_sizes=dict(t=5))
        result_iso = xr.apply_ufunc(
            gu_update,
            YrA_norm.where(unit_labels == -1, drop=True).persist(),
            g.where(unit_labels == -1, drop=True).chunk(dict(lag=-1)).persist(),
            sn_temp.where(unit_labels == -1, drop=True).persist(),
            input_core_dims=[['frame'], ['lag'], []],
            output_core_dims=[['trace', 'frame']],
            dask='allowed')
    if compute:
        with da.config.set(scheduler='processes'):
            result_iso = result_iso.compute()
    print("updating overlapping temporal components")
    res_list = []
    g_ovlp = g.where(unit_labels >= 0, drop=True)
    if len(g_ovlp) > 0:
        for cur_labl, cur_g in g_ovlp.groupby('unit_labels'):
            if use_spatial:
                cur_A = A_flt.where(unit_labels == cur_labl, drop=True)
                cur_res = delayed(xr.apply_ufunc)(
                    update_temporal_cvxpy,
                    Y_flt.chunk(dict(spatial=-1, frame=-1)),
                    cur_g.chunk(dict(lag=-1)),
                    sn_spatial.chunk(dict(spatial=-1)),
                    cur_A.chunk(dict(spatial=-1)),
                    input_core_dims=[['spatial', 'frame'], ['unit_id', 'lag'],
                                     ['spatial'], ['unit_id', 'spatial']],
                    output_core_dims=[['trace', 'unit_id', 'frame']],
                    dask='parallelized',
                    kwargs=dict(
                        sparse_penal=sparse_penal,
                        max_iters=max_iters,
                        scs_fallback=scs_fallback),
                    output_sizes=dict(trace=5),
                    output_dtypes=[YrA.dtype])
            else:
                cur_YrA = YrA_norm.where(unit_labels == cur_labl, drop=True)
                cur_sn_temp = sn_temp.where(unit_labels == cur_labl, drop=True)
                cur_res = delayed(xr.apply_ufunc)(
                    update_temporal_cvxpy,
                    cur_YrA.compute(),
                    cur_g.compute(),
                    cur_sn_temp.compute(),
                    input_core_dims=[['unit_id', 'frame'], ['unit_id', 'lag'],
                                     ['unit_id']],
                    output_core_dims=[['trace', 'unit_id', 'frame']],
                    dask='forbidden',
                    kwargs=dict(
                        sparse_penal=sparse_penal,
                        max_iters=max_iters,
                        scs_fallback=scs_fallback),
                    output_sizes=dict(trace=5),
                    output_dtypes=[YrA.dtype])
            res_list.append(cur_res)
        if compute:
            with da.config.set(scheduler='processes'):
                result_ovlp, = da.compute(res_list)
                result = (xr.concat(result_ovlp + [result_iso], 'unit_id')
                          .sortby('unit_id').drop('unit_labels'))
    else:
        result = result_iso.sortby('unit_id').drop('unit_labels')
    C_new = result.isel(trace=0).dropna('unit_id')
    S_new = result.isel(trace=1).dropna('unit_id')
    B_new = result.isel(trace=2, frame=0).dropna('unit_id').squeeze()
    C0_new = result.isel(trace=3, frame=0).dropna('unit_id').squeeze()
    dc_new = result.isel(trace=4).dropna('unit_id')
    g_new = g.sel(unit_id=C_new.coords['unit_id']).drop('unit_labels')
    if zero_thres:
        mask = S_new.where(S_new > zero_thres).fillna(0).sum('frame').astype(bool)
        mask_coord = mask.where(~mask, drop=True).coords['unit_id'].values
        print("{} units dropped due to poor fit:\n {}".format(len(mask_coord), str(mask_coord)))
    else:
        mask_coord = S_new.coords['unit_id'].values
        mask = xr.DataArray(np.ones(len(mask_coord), dtype=bool),
                            dims=['unit_id'],
                            coords=dict(unit_id=mask_coord))
    C_new, S_new, C0_new, B_new, dc_new = (
        C_new.where(mask, drop=True), S_new.where(mask, drop=True),
        C0_new.where(mask, drop=True), B_new.where(mask, drop=True),
        dc_new.where(mask, drop=True))
    YrA_new = YrA.drop('unit_labels').sel(unit_id=C_new.coords['unit_id'])
    sig_new = (C0_new * dc_new + B_new + C_new).persist()
    if post_scal and len(sig_new) > 0:
        print("post-hoc scaling")
        def lstsq(a, b):
            # scalar least-squares fit per unit; return a python scalar so
            # the vectorized gufunc sees a true scalar output
            a = np.atleast_2d(a).T
            return np.linalg.lstsq(a, b, rcond=-1)[0].item()
        scal = xr.apply_ufunc(
            lstsq,
            sig_new.chunk(dict(frame=-1)),
            YrA_new.chunk(dict(frame=-1)),
            input_core_dims=[['frame'], ['frame']],
            output_core_dims=[[]],
            vectorize=True,
            dask='parallelized',
            output_dtypes=[C_new.dtype])
        scal = scal.persist()
        C_new = (C_new * scal).persist()
        S_new = (S_new * scal).persist()
        B_new = (B_new * scal).persist()
        C0_new = (C0_new * scal).persist()
        sig_new = (sig_new * scal).persist()
    else:
        scal = None
    if len(sig_new) > 0:
        C_new = rechunk_like(C_new.persist(), C)
        S_new = rechunk_like(S_new.persist(), C)
        B_new = rechunk_like(B_new.persist(), C)
        C0_new = rechunk_like(C0_new.persist(), C)
        g_new = rechunk_like(g_new.persist(), C)
        sig_new = rechunk_like(sig_new.persist(), C)
    return (YrA_norm, C_new, S_new, B_new, C0_new, sig_new, g_new, scal)


def get_ar_coef(y, sn, p, add_lag, pad=None):
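    """Estimate AR coefficients for one trace by solving the noise-corrected
    Yule-Walker system: the autocovariance Toeplitz matrix (with ``sn**2``
    subtracted on its main diagonal) against the lagged autocovariances.
    ``pad`` right-pads the result with zeros so traces with different
    orders ``p`` can share one ``lag`` dimension.
    """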
    if add_lag == 'p':
        max_lag = p * 2
    else:
        max_lag = p + add_lag
    cov = acovf(y, fft=True)
    C_mat = toeplitz(cov[:max_lag], cov[:p]) - sn**2 * np.eye(max_lag, p)
    g = lstsq(C_mat, cov[1:max_lag + 1])[0]
    if pad:
        res = np.zeros(pad)
        res[:len(g)] = g
        return res
    else:
        return g


def get_p(y):
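    """Heuristic AR-order estimate: find the rising period of the trace
    with the largest total rise and return its length in frames.
    """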
    dif = np.append(np.diff(y), 0)
    rising = dif > 0
    prd_ris, num_ris = label(rising)
    ext_prd = np.zeros(num_ris)
    for id_prd in range(num_ris):
        prd = y[prd_ris == id_prd + 1]
        ext_prd[id_prd] = prd[-1] - prd[0]
    id_max_prd = np.argmax(ext_prd)
    return np.sum(rising[prd_ris == id_max_prd + 1])


def update_temporal_cvxpy(y, g, sn, A=None, **kwargs):
    """
    spatial:
    (d, f), (u, p), (d), (d, u)
    (d, f), (p), (d), (d)
    trace:
    (u, f), (u, p), (u)
    (f), (p), ()
    """
    # get_parameters
    sparse_penal = kwargs.get('sparse_penal')
    max_iters = kwargs.get('max_iters')
    use_cons = kwargs.get('use_cons', False)
    scs = kwargs.get('scs_fallback')
    # conform variables to generalize multiple unit case
    if y.ndim < 2:
        y = y.reshape((1, -1))
    if g.ndim < 2:
        g = g.reshape((1, -1))
    sn = np.atleast_1d(sn)
    if A is not None:
        if A.ndim < 2:
            A = A.reshape((-1, 1))
    # get count of frames and units
    _T = y.shape[-1]
    _u = g.shape[0]
    if A is not None:
        _d = A.shape[0]
    # construct G matrix and decay vector per unit
    dc_vec = np.zeros((_u, _T))
    G_ls = []
    for cur_u in range(_u):
        cur_g = g[cur_u, :]
        # construct first column and row
        cur_c, cur_r = np.zeros(_T), np.zeros(_T)
        cur_c[0] = 1
        cur_r[0] = 1
        cur_c[1:len(cur_g) + 1] = -cur_g
        # update G with toeplitz matrix
        G_ls.append(cvx.Constant(dia_matrix(toeplitz(cur_c, cur_r))))
        # update dc_vec
        cur_gr = np.roots(cur_c)
        dc_vec[cur_u, :] = np.max(cur_gr.real)**np.arange(_T)
    # get noise threshold
    thres_sn = sn * np.sqrt(_T)
    # construct variables
    b = cvx.Variable(_u)  # baseline fluorescence per unit
    c0 = cvx.Variable(_u)  # initial fluorescence per unit
    c = cvx.Variable((_u, _T))  # calcium trace per unit
    s = cvx.vstack(
        [G_ls[u] * c[u, :] for u in range(_u)])  # spike train per unit
    # residual noise per unit
    if A is not None:
        sig = cvx.vstack([
            (A * c)[px, :]
            + (A * b)[px, :]
            + (A * cvx.diag(c0) * dc_vec)[px, :] for px in range(_d)])
        noise = y - sig
    else:
        sig = cvx.vstack([c[u, :] + b[u] + c0[u] * dc_vec[u, :] for u in range(_u)])
        noise = y - sig
    noise = cvx.vstack(
        [cvx.norm(noise[i, :], 2) for i in range(noise.shape[0])])
    # construct constraints
    cons = []
    cons.append(b >= np.min(y, axis=-1))  # baseline larger than minimum
    cons.append(c0 >= 0)  # initial fluorescence larger than 0
    cons.append(s >= 0)  # spike train non-negativity
    # noise constraints
    cons_noise = [noise[i] <= thres_sn[i] for i in range(thres_sn.shape[0])]
    try:
        obj = cvx.Minimize(cvx.sum(cvx.norm(s, 1, axis=1)))
        prob = cvx.Problem(obj, cons + cons_noise)
        if use_cons:
            _ = prob.solve(solver='ECOS')
        if not (prob.status == 'optimal'
                or prob.status == 'optimal_inaccurate'):
            if use_cons:
                warnings.warn("constrained version of problem infeasible")
            raise ValueError
    except (ValueError, cvx.SolverError):
        # hacky correction for the near-linear relationship between sparsity
        # and the number of concurrently updated units
        lam = sn * sparse_penal / sn.shape[0]
        obj = cvx.Minimize(cvx.sum(cvx.sum(noise, axis=1) + lam * cvx.norm(s, 1, axis=1)))
        prob = cvx.Problem(obj, cons)
        try:
            _ = prob.solve(solver='ECOS', max_iters=max_iters)
            if prob.status in ["infeasible", "unbounded", None]:
                raise ValueError
        except (cvx.SolverError, ValueError):
            try:
                if scs:
                    _ = prob.solve(solver='SCS', max_iters=200)
                if prob.status in ["infeasible", "unbounded", None]:
                    raise ValueError
            except (cvx.SolverError, ValueError):
                warnings.warn(
                    "problem status is {}, returning null".format(prob.status),
                    RuntimeWarning)
                return np.full((5, c.shape[0], c.shape[1]), np.nan).squeeze()
    if prob.status != 'optimal':
        warnings.warn("problem solved sub-optimally", RuntimeWarning)
    try:
        return np.stack(
            np.broadcast_arrays(c.value, s.value, b.value.reshape((-1, 1)),
                                c0.value.reshape((-1, 1)), dc_vec)).squeeze()
    except AttributeError:
        # solver reported a status but returned no solution values
        warnings.warn("failed to extract solution, returning null",
                      RuntimeWarning)
        return np.full((5, c.shape[0], c.shape[1]), np.nan).squeeze()


def unit_merge(A, C, add_list=None, thres_corr=0.9):
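    """Merge units whose footprints overlap and whose traces correlate
    above ``thres_corr``: footprints are summed and traces averaged within
    each connected component of the merge graph. Extra unit-indexed
    variables can be passed in ``add_list`` to be averaged the same way.

    A minimal sketch (names illustrative)::

        A_mrg, C_mrg = unit_merge(A, C, thres_corr=0.9)
    """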
    print("computing spatial overlap")
    A_bl = ((A > 0).astype(np.float32)
            .chunk(dict(unit_id='auto', height=-1, width=-1)))
    A_ovlp = xr.apply_ufunc(
        da.array.tensordot,
        A_bl,
        A_bl.rename(unit_id='unit_id_cp'),
        input_core_dims=[['unit_id', 'height', 'width'],
                         ['height', 'width', 'unit_id_cp']],
        output_core_dims=[['unit_id', 'unit_id_cp']],
        dask='allowed',
        kwargs=dict(axes=([1, 2], [0, 1])),
        output_dtypes=[A_bl.dtype])
    A_ovlp = A_ovlp.persist()
    print("computing temporal correlation")
    uid_idx = C.coords['unit_id'].values
    corr = xr.apply_ufunc(
        np.corrcoef,
        C.compute(),
        input_core_dims=[['unit_id', 'frame']],
        output_core_dims=[['unit_id', 'unit_id_cp']],
        output_sizes=dict(unit_id_cp=len(uid_idx)))
    corr = corr.assign_coords(unit_id_cp=uid_idx)
    print("labeling units to be merged")
    adj = np.logical_and(A_ovlp > 0, corr > thres_corr)
    unit_labels = xr.apply_ufunc(
        label_connected,
        adj.compute(),
        input_core_dims=[['unit_id', 'unit_id_cp']],
        output_core_dims=[['unit_id']])
    print("merging units")
    A_merge = (A.assign_coords(unit_labels=unit_labels)
               .groupby('unit_labels').sum('unit_id')
               .persist().rename(unit_labels='unit_id'))
    C_merge = (C.assign_coords(unit_labels=unit_labels)
               .groupby('unit_labels').mean('unit_id')
               .persist().rename(unit_labels='unit_id'))
    A_merge = rechunk_like(A_merge, A)
    C_merge = rechunk_like(C_merge, C)
    if add_list:
        for ivar, var in enumerate(add_list):
            var_mrg = (var.assign_coords(unit_labels=unit_labels)
                       .groupby('unit_labels').mean('unit_id')
                       .persist().rename(unit_labels='unit_id'))
            add_list[ivar] = rechunk_like(var_mrg, var)
        return A_merge, C_merge, add_list
    else:
        return A_merge, C_merge


def label_connected(adj, only_connected=False):
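    """Label the connected components of a boolean adjacency matrix with
    networkx; with ``only_connected=True``, singleton components are
    labeled -1.
    """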
    np.fill_diagonal(adj, 0)
    adj = np.triu(adj)
    g = nx.from_numpy_array(adj)
    labels = np.zeros(adj.shape[0], dtype=int)
    for icomp, comp in enumerate(nx.connected_components(g)):
        comp = list(comp)
        if only_connected and len(comp) == 1:
            labels[comp] = -1
        else:
            labels[comp] = icomp
    return labels


def smooth_sig(sig, freq, btype='low'):
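    """Apply a second-order butterworth filter of type ``btype`` at cutoff
    ``freq`` along the ``frame`` dimension, lazily via dask.
    """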
    but_b, but_a = butter(2, freq, btype=btype, analog=False)
    sig_smth = xr.apply_ufunc(
            lambda x: lfilter(but_b, but_a, x),
            sig.chunk(dict(frame=-1)),
            input_core_dims=[['frame']],
            output_core_dims=[['frame']],
            vectorize=True,
            dask='parallelized',
            output_dtypes=[sig.dtype])
    return sig_smth
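

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the pipeline: build a small
    # synthetic movie and exercise the noise estimator and the smoothing
    # filter. All shapes, chunk sizes and the seed are arbitrary
    # illustrations. Because of the relative import above, run this as
    # ``python -m <package>.<module>`` from the package root.
    rng = np.random.default_rng(0)
    h, w, t = 16, 16, 256
    varr = xr.DataArray(
        rng.standard_normal((h, w, t)).astype(np.float32),
        dims=['height', 'width', 'frame'],
        coords=dict(
            height=np.arange(h),
            width=np.arange(w),
            frame=np.arange(t)),
    ).chunk(dict(height=8, width=8, frame=-1))
    # per-pixel noise level from the upper half of the spectrum
    sn = get_noise_fft(varr).compute()
    print("mean noise estimate:", float(sn.mean()))
    # low-pass filter the traces along the frame dimension
    varr_smth = smooth_sig(varr, freq=0.25).compute()
    print("smoothed movie shape:", varr_smth.shape)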