""" Spherical harmonics transforms/inverses and spherical convolution. """


import functools

import tensorflow as tf
import numpy as np
from scipy.special import sph_harm

from .util import safe_cast

from . import util
from . import tfnp_compatibility as tfnp
from .tfnp_compatibility import istf


# memoize harmonics; maxsize 2050 exceeds 32*64 = 2048
@functools.lru_cache(maxsize=2050, typed=False)
def sph_harm_lm(l, m, n):
    """ Spherical harmonic of degree l and order m sampled on the n-point grid.

    Thin wrapper around scipy.special.sph_harm. The two 1-D sample vectors
    returned by util.sph_sample(n) are expanded with np.meshgrid; the second
    vector is passed to scipy as its azimuthal angle and the first as its
    polar angle.

    Results are memoized with lru_cache, so the same array object is returned
    for repeated (l, m, n); callers must not modify it in place.
    """
    phi_1d, theta_1d = util.sph_sample(n)
    phi_grid, theta_grid = np.meshgrid(phi_1d, theta_1d)
    # scipy argument order: sph_harm(order, degree, azimuth, polar)
    return sph_harm(m, l, theta_grid, phi_grid)


def sph_harm_all(n, as_tfvar=False, real=False):
    """ Compute spherical harmonics for an n x n input (degrees 0 .. n//2 - 1).

    Args:
        n (int): input grid size; degrees l range over 0 .. n//2 - 1
        as_tfvar (bool): if True, return a single complex64 tensorflow tensor
            in the layout produced by sph_harm_to_shtools, i.e.
            (1 or 2, n//2, n//2, n, n). Despite the name, no tf.Variable is
            created (tf.constant + tf.cast).
        real (bool): if True, compute only orders m >= 0 (enough for real
            inputs); the harmonic values themselves are still complex.

    Returns:
        If as_tfvar is False: list of lists where harmonics[l][k] is the
        harmonic of degree l and order m = k + minl, with minl = 0 when real
        and -l otherwise. If as_tfvar is True: a complex64 tf.Tensor (see above).
    """
    harmonics = [[sph_harm_lm(l, m, n) for m in range(0 if real else -l, l + 1)]
                 for l in range(n // 2)]

    if as_tfvar:
        return tf.cast(tf.constant(sph_harm_to_shtools(harmonics)),
                       'complex64')

    return harmonics


def DHaj(n, mode='DH'):
    """ Quadrature (sampling) weights along the phi dimension.

    Driscoll and Healy weights
        a_j = 2*sqrt(2)/n * sin(t_j) * sum_{0 <= l < n/2} sin((2l+1) t_j) / (2l+1)
    evaluated at the grid angles t_j, j = 0 .. n-1. The weights depend on the
    chosen grid, which must agree with the one given by util.sph_sample:
        mode='DH':   t_j = pi*j/n
        mode='ours': t_j = pi*(2j+1)/(2n)

    Args:
        n (int): number of samples along phi.
        mode (str): 'DH' or 'ours'.

    Returns:
        list of n weights (numpy floats).

    Raises:
        NotImplementedError: for any other mode.
    """
    if mode == 'DH':
        def angle(j):
            return np.pi*j/n
    elif mode == 'ours':
        def angle(j):
            return np.pi*(2*j+1)/2/n
    else:
        raise NotImplementedError()

    # odd integers 2l+1 for l = 0 .. ceil(n/2)-1
    odd = 2*np.arange(0, n/2) + 1

    weights = []
    for j in range(n):
        t = angle(j)
        series = (1/odd * np.sin(odd*t)).sum()
        weights.append(2*np.sqrt(2)/n * np.sin(t) * series)
    return weights


def sph_harm_transform(f, mode='DH', harmonics=None):
    """ Project spherical function into the spherical harmonics basis.

    Args:
        f: square (n, n) numpy array or tf.Tensor; n must be a power of 2.
            The DHaj quadrature weights are applied along axis 1.
        mode (str): sampling-weight mode forwarded to DHaj ('DH' or 'ours').
        harmonics: optional nested list of harmonics as returned by
            sph_harm_all (complex form, or real form with m >= 0 only);
            the complex form sph_harm_all(n) is computed when None.

    Returns:
        list of lists; coeffs[l][k] is the coefficient of degree l and order
        m = k + minl, where minl = 0 for real-form harmonics and -l otherwise.
    """
    assert f.shape[0] == f.shape[1]

    if isinstance(f, tf.Tensor):
        sumfun = tf.reduce_sum
        conjfun = tf.conj
        n = f.shape[0].value
    else:
        sumfun = np.sum
        conjfun = np.conj
        n = f.shape[0]
    assert np.log2(n).is_integer()

    if harmonics is None:
        harmonics = sph_harm_all(n)

    aj = DHaj(n, mode)

    f = f*np.array(aj)[np.newaxis, :]

    real = is_real_sft(harmonics)

    # WARNING: results are off by the factor 2*sqrt(pi) when using the
    # driscoll1994computing formulas, so it is folded into the scale here.
    # The scale is loop-invariant; the left-to-right order matches the
    # original per-term expression, so results are unchanged.
    scale = 2*np.sqrt(np.pi) * np.sqrt(2*np.pi)/n

    coeffs = []
    for l in range(n // 2):
        minl = 0 if real else -l
        coeffs.append([sumfun(scale * f * conjfun(harmonics[l][m-minl]))
                       for m in range(minl, l+1)])

    return coeffs


def sph_harm_inverse(c, harmonics=None):
    """ Inverse spherical harmonics transform.

    Args:
        c: nested list of coefficients as produced by sph_harm_transform;
            c[l] holds l+1 entries (real form, m >= 0) or 2l+1 entries
            (complex form, m = -l..l). At least two degrees are needed,
            because the form is detected from c[1].
        harmonics: optional nested harmonics list in the same form as c;
            computed with sph_harm_all when None.

    Returns:
        (n, n) array/tensor with n = 2*len(c), accumulated into a zero array
        created as float32 for real-form c, else with the dtype of c[1][1].
    """
    n = 2*len(c)
    real = is_real_sft(c)

    if real:
        dtype = 'float32'
    else:
        dtype = c[1][1].dtype

    if harmonics is None:
        harmonics = sph_harm_all(n, real=real)

    make_zeros = tf.zeros if isinstance(c[0][0], tf.Tensor) else np.zeros
    f = make_zeros((n, n), dtype=dtype)

    for l in range(n // 2):
        if real:
            for m in range(l + 1):
                # each m > 0 term stands for both +m and -m (conjugate
                # symmetry), hence the doubling; m == 0 is counted once
                weight = 2 if m else 1
                f += weight*(tfnp.real(c[l][m]) * tfnp.real(harmonics[l][m]) -
                             tfnp.imag(c[l][m]) * tfnp.imag(harmonics[l][m]))
        else:
            for m in range(2*l + 1):
                f += c[l][m] * harmonics[l][m]

    return f


def sph_harm_transform_batch(f, method=None, *args, **kwargs):
    """ Batch spherical harmonics transform.

    `method` is accepted for API compatibility but currently ignored; all
    remaining arguments are forwarded to sph_harm_transform_batch_naive.
    """
    del method  # only the naive implementation exists
    return sph_harm_transform_batch_naive(f, *args, **kwargs)


def sph_harm_inverse_batch(c, method=None, *args, **kwargs):
    """ Batch inverse spherical harmonics transform.

    `method` is accepted for API compatibility but currently ignored; all
    remaining arguments are forwarded to sph_harm_inverse_batch_naive.
    """
    del method  # only the naive implementation exists
    return sph_harm_inverse_batch_naive(c, *args, **kwargs)


def sph_harm_transform_batch_naive(f, harmonics=None, m0_only=False):
    """ Spherical harmonics batch-transform.

    Args:
        f (n, l, l, c)-array : functions are on l x l grid
        harmonics (2, l/2, l/2, l, l)-array: harmonics in the layout of
            sph_harm_to_shtools (leading dim may also be 1); computed from
            sph_harm_all(l) when None
        m0_only (bool): return only coefficients with order 0;
                        only them are needed when computing convolutions

    Returns:
        coeffs ((n, 2, l/2, l/2, c)-array); with m0_only the shape is
        (n, 1, l/2, 1, c)
    """
    shapef = tfnp.shape(f)
    _, res = shapef[:2]
    assert shapef[2] == res
    if harmonics is None:
        harmonics = sph_harm_to_shtools(sph_harm_all(res))
    shapeh = tfnp.shape(harmonics)
    assert shapeh[1:] == (res//2, res//2, res, res)
    assert shapeh[0] in [1, 2]

    weights = np.array(DHaj(res))

    if m0_only:
        # keep only leading index 0 and order index 0 (m = 0); useful to
        # expand the kernel in spherical convolution
        harmonics = harmonics[0:1, :, 0:1, ...]

    # quadrature weights act on axis 2 of f; both grid axes of f are then
    # contracted against the last two axes of the harmonics
    weighted = f * weights[np.newaxis, np.newaxis, :, np.newaxis]
    projected = tfnp.dot(weighted, tfnp.conj(harmonics), [[1, 2], [3, 4]])
    # (n, c, dim1, L, M) -> (n, dim1, L, M, c)
    return tfnp.transpose(2*np.sqrt(2)*np.pi/res * projected, (0, 2, 3, 4, 1))


def sph_harm_inverse_batch_naive(c, harmonics=None):
    """ Spherical harmonics batch inverse transform.

    Args:
        c ((n, 2, l/2, l/2, c)-array): sph harm coefficients; max degree is l/2
        harmonics (2, l/2, l/2, l, l)-array: computed from sph_harm_all(l) when
            None. A leading dimension of 1 selects the real-input path, where
            only orders m >= 0 are stored.

    Returns:
        recons ((n, l, l, c)-array):
    """
    shapec = tfnp.shape(c)
    l = 2*shapec[2]
    assert shapec[3] == l//2
    if harmonics is None:
        harmonics = sph_harm_to_shtools(sph_harm_all(l))
    shapeh = tfnp.shape(harmonics)
    assert shapeh[1:] == (l//2, l//2, l, l)
    assert shapeh[0] in [1, 2]

    # contract (dim1, degree, order) of c with the first three harmonics axes
    axes = [[1, 2, 3], [0, 1, 2]]
    # (n, c, l, l) -> (n, l, l, c)
    to_grid = (0, 2, 3, 1)

    real = shapeh[0] == 1
    if not real:
        return tfnp.transpose(tfnp.dot(c, harmonics, axes), to_grid)

    # Real input, using the m, -m symmetry:
    #   c^{-m}Y^{-m} + c^mY^m = 2(Re(c^m)Re(Y^m) - Im(c^m)Im(Y^m))
    # The m = 0 term has no partner, so it is halved before the doubling.
    factor = np.ones(shapec[1:])[np.newaxis, ...]
    factor[..., 0, :] /= 2
    c = c * factor
    re_part = tfnp.dot(tfnp.real(c), tfnp.real(harmonics), axes)
    im_part = tfnp.dot(tfnp.imag(c), tfnp.imag(harmonics), axes)
    return tfnp.transpose(2*(re_part - im_part), to_grid)


def sph_conv(f, g, harmonics=None):
    """ Spherical convolution f * g.

    Both inputs are projected with sph_harm_transform. Each degree-l row of f's
    coefficients is multiplied by 2*pi*sqrt(4*pi/(2l+1)) times g's order-0
    coefficient of the same degree, and the result is mapped back with
    sph_harm_inverse.

    Args:
        f, g: square spherical functions (numpy arrays or tf.Tensors).
        harmonics: optional nested harmonics list forwarded to
            sph_harm_transform; may be complex form (m = -l..l) or real form
            (m >= 0 only).

    Returns:
        The convolved function on the same grid.
    """
    stackfun = tf.stack if isinstance(f, tf.Tensor) else np.array
    cf, cg = [sph_harm_transform(x, harmonics=harmonics) for x in [f, g]]

    cfg = []
    for l, (c1, c2) in enumerate(zip(cf, cg)):
        # Row l holds orders m = minl..l: minl = -l in complex form (2l+1
        # entries, m = 0 at index l) or minl = 0 in real form (l+1 entries,
        # m = 0 at index 0). Indexing with l unconditionally would pick order
        # m = l for real-form harmonics.
        m0 = l if len(c2) == 2*l + 1 else 0
        cfg.append(2*np.pi*np.sqrt(4*np.pi / (2*l+1)) * stackfun(c1) * c2[m0])

    return sph_harm_inverse(cfg)


def sph_conv_batch(f, g,
                   harmonics_or_legendre=None,
                   method=None,
                   spectral_pool=0,
                   harmonics_or_legendre_low=None):
    """ CNN-like batch spherical convolution.

    Args:
        f (n, l, l, c)-array: input feature map. n entries, c channels.
            A 5-D f is taken to be already in the spectral domain,
            (n, 2, l/2, l/2, c), and is not transformed again.
        g (c, l, l, d)-array: convolution kernels. A 5-D g is taken to be
            spectral kernel coefficients whose axes 0 and 2 match c and l/2
            (presumably the m0_only output of sph_harm_transform_batch).
        harmonics_or_legendre (): spherical harmonics or legendre polynomials to expand f and g
        method (str): see sph_harm_transform_batch
        spectral_pool (int): if > 0 run spectral pooling before ISHT
            (bandwidth is reduced by a factor of 2**spectral_pool)
        harmonics_or_legendre_low (): low frequency harmonics or legendre to
            be used when spectral_pool > 0

    Returns:
        fg (n, l, l, d)-array; with spectral pooling the grid is
        l / 2**spectral_pool
    """
    shapef, shapeg = [tfnp.shape(x) for x in [f, g]]
    spectral_filter = len(shapeg) == 5
    spectral_input = len(shapef) == 5
    n = shapef[2]
    if spectral_input:
        # spectral f stores l/2 degrees on axis 2; recover the grid size l
        n *= 2

    if spectral_input:
        cf = f
    else:
        cf = sph_harm_transform_batch(f, method, harmonics_or_legendre, m0_only=False)
    if spectral_filter:
        cg = g
    else:
        # only the order-0 kernel coefficients enter the convolution
        cg = sph_harm_transform_batch(g, method, harmonics_or_legendre, m0_only=True)

    shapecf, shapecg = [tfnp.shape(x) for x in [cf, cg]]
    assert shapecf[4] == shapecg[0]
    assert shapecf[2] == shapecg[2]

    na = np.newaxis
    # per degree factor 2*pi*sqrt(4*pi/(2l+1)), broadcast over
    # (batch, dim1, degree, order, c_in, c_out)
    factor = (2*np.pi*np.sqrt(4*np.pi / (2*np.arange(n/2)+1)))[na, na, :, na, na, na]
    # cg: (c, dim1, L, M, d) -> (1, dim1, L, M, c, d); cf: (..., c) -> (..., c, 1)
    cg = tfnp.transpose(cg, (1, 2, 3, 0, 4))[na, ...]
    cf = cf[..., na]
    if istf(cg) and istf(cf):
        cg, cf = safe_cast(cg, cf)
    cfg = factor * cf * cg

    if spectral_pool > 0:
        # keep only the lowest degrees/orders
        keep = n // 2**(spectral_pool + 1)
        cfg = cfg[:, :, :keep, :keep, ...]
        hol = harmonics_or_legendre_low
    else:
        hol = harmonics_or_legendre

    # sum over input channels
    cfg = tfnp.sum(cfg, axis=-2)

    return sph_harm_inverse_batch(cfg, method, hol)


def is_real_sft(h_or_c):
    """ Detect if list of lists of harmonics or coefficients assumes real inputs (m >= 0 only).

    The decision is made from the degree-1 entry: 2 orders (m = 0, 1) means
    real form, anything else (3 orders, m = -1, 0, 1, in complex form) means
    not real. At least two degrees must be present.
    """
    num_orders = tfnp.shape(h_or_c[1])[0] if istf(h_or_c) else len(h_or_c[1])
    return bool(num_orders == 2)


def sph_harm_to_shtools(c):
    """ Convert our list format for sph harm coefficients/harmonics to a pyshtools-style array.

    The output has shape (2, L, L, *s) for complex-form input (orders -l..l)
    or (1, L, L, *s) for real-form input (orders 0..l). Here L = len(c) and
    s is the shape of a single entry c[0][0].

    out[0, l, m] holds order +m. For complex input, out[1, l, m] holds
    order -m for m >= 1, and out[1, l, 0] stays zero. The array is complex128
    and is zero wherever m > l.
    """
    real = is_real_sft(c)
    nrows = len(c)
    out = np.zeros((1 if real else 2, nrows, nrows, *c[0][0].shape),
                   dtype=complex)

    for degree, row in enumerate(c):
        row = np.array(row)
        if real:
            out[0, degree, :degree+1, ...] = row
            continue
        # row is ordered m = -degree .. degree; reverse the negative part so
        # that index k in out[1] corresponds to order -k
        negative = row[:degree][::-1]
        if negative.size:
            out[1, degree, 1:degree+1, ...] = negative
        out[0, degree, :degree+1, ...] = row[degree:]

    return out