"""
wavelet.py:  Linear transforms based on the 2D discrete wavelet transform
"""
from __future__ import division

import pywt
import numpy as np

# Import other subpackages in vampyre
import vampyre.common as common

# Import individual classes from other modules in the same package
from vampyre.trans.base import BaseLinTrans


class Wavelet2DLT(BaseLinTrans):
    """
    Linear transform class based on a 2D wavelet

    The analysis (image -> coefficients) and the reconstruction
    (coefficients -> image) are computed with `pywt.wavedec2` and
    `pywt.waverec2`, using the `periodization` signal extension mode.
    The coefficients are packed in a single 2D array with
    `pywt.coeffs_to_array`.

    :param nrow:  number of rows in the image
    :param ncol:  number of columns in the image
    :param wavelet:  wavelet type, given either as a wavelet name or as a
       `pywt.Wavelet` object (see `pywt` package for a full description)
    :param level:  number of wavelet levels
    :param fwd_mode:  `recon` indicates that `dot()` operation is the
       reconstruction and `dotH()` is the analysis.  `analysis` is the reverse.
    :param dtype:  data type passed to the base class for both sides of
       the transform
    :param name:  name passed to the base class

    :raises common.VpException:  if `fwd_mode` is neither `recon` nor
       `analysis`
    """
    def __init__(self,nrow=256,ncol=256,wavelet='db4',level=3,fwd_mode='recon',
        dtype=np.float64,name=None):

        # Confirm that fwd_mode is valid before doing any other work, so
        # that a bad argument fails fast and no decomposition is computed
        if fwd_mode not in ('recon', 'analysis'):
            raise common.VpException('fwd_mode must be recon or analysis')

        # Save parameters
        self.wavelet = wavelet
        self.level = level
        self.fwd_mode = fwd_mode

        # Both sides of the transform are declared with the image shape.
        # NOTE(review): this presumably requires nrow and ncol to be
        # divisible by 2**level so that the packed coefficient array has
        # the same shape as the image -- confirm for other sizes.
        shape0 = (nrow,ncol)
        shape1 = (nrow,ncol)
        dtype0 = dtype
        dtype1 = dtype

        # `pywt.Wavelet()` only takes a wavelet name, so an existing wavelet
        # object is used directly when testing for orthogonality
        if isinstance(wavelet, pywt.Wavelet):
            wv = wavelet
        else:
            wv = pywt.Wavelet(wavelet)

        # SVD calculation assumes an orthogonal wavelet
        svd_avail = bool(wv.orthogonal)
        BaseLinTrans.__init__(self, shape0, shape1, dtype0, dtype1,
           svd_avail=svd_avail,name=name)

        # Set the mode to periodic to make the wavelet orthogonal
        self.mode = 'periodization'

        # Send a zero image to get the coefficient slices.  The slices are
        # needed in `recon()` to unpack a coefficient array.
        im = np.zeros((nrow,ncol))
        coeffs = pywt.wavedec2(im, wavelet=self.wavelet, level=self.level,
            mode=self.mode)
        _, self.coeff_slices = pywt.coeffs_to_array(coeffs)

    def dot(self,z0):
        """
        Forward multiplication

        :param z0:  input array.  It is treated as a coefficient array
           when `fwd_mode` is `recon` and as an image otherwise.
        :returns:  :code:`z1`, the reconstruction of :code:`z0` when
           `fwd_mode` is `recon` and the analysis of :code:`z0` otherwise
        """
        if self.fwd_mode == 'recon':
            return self.recon(z0)
        return self.analysis(z0)

    def dotH(self,z1):
        """
        Reverse / adjoint multiplication

        :param z1:  input array.  It is treated as an image when
           `fwd_mode` is `recon` and as a coefficient array otherwise.
        :returns:  :code:`z0`, the analysis of :code:`z1` when
           `fwd_mode` is `recon` and the reconstruction of :code:`z1`
           otherwise
        """
        if self.fwd_mode == 'recon':
            return self.analysis(z1)
        return self.recon(z1)

    def analysis(self,z0):
        """
        Analysis:  image -> coefficients

        :param z0:  image to decompose
        :returns:  :code:`z1`, the wavelet coefficients packed in a
           single array by `pywt.coeffs_to_array`
        """
        coeffs = pywt.wavedec2(z0, wavelet=self.wavelet, level=self.level,
            mode=self.mode)
        z1, _ = pywt.coeffs_to_array(coeffs)
        return z1

    def recon(self,z1):
        """
        Wavelet reconstruction:  coefficients -> image

        :param z1:  packed coefficient array.  It is unpacked with the
           slices computed in the constructor.
        :returns:  :code:`z0`, the reconstructed image
        """
        coeffs = pywt.array_to_coeffs(z1, self.coeff_slices,
            output_format='wavedec2')
        z0 = pywt.waverec2(coeffs, wavelet=self.wavelet, mode=self.mode)
        return z0

    def Usvd(self,q1):
        """
        Multiplication by SVD term :math:`U`

        Implemented as the forward multiplication `dot()`.
        """
        return self.dot(q1)

    def UsvdH(self,z1):
        """
        Multiplication by SVD term :math:`U^*`

        Implemented as the adjoint multiplication `dotH()`.
        """
        return self.dotH(z1)

    def Vsvd(self,q0):
        """
        Multiplication by SVD term :math:`V`

        The input is returned unchanged.
        """
        return q0

    def VsvdH(self,z0):
        """
        Multiplication by SVD term :math:`V^*`

        The input is returned unchanged.
        """
        return z0

    def get_svd_diag(self):
        """
        Gets parameters of the SVD diagonal multiplication.

        See :func:`vampyre.trans.base.LinTrans.get_svd_diag()` for
        more information.

        :returns: :code:`s,sshape,srep_axes`, the diagonal parameters
            :code:`s`, the shape in the transformed domain :code:`sshape`,
            and the axes on which the diagonal parameters are to be
            repeated, :code:`srep_axes`
        """
        # A single unit value, repeated over both axes of the input shape
        s = 1
        sshape = self.shape0
        srep_axes = (0,1)
        return s, sshape, srep_axes

    def svd_dot(self,s1,q0):
        """
        Performs diagonal matrix multiplication.

        Implements :math:`q_1 = \\mathrm{diag}(s_1) q_0`.

        :param s1: diagonal parameters
        :param q0: input to the diagonal multiplication
        :returns: :code:`q1` diagonal multiplication output
        """
        # Same shape and repeat axes as returned by `get_svd_diag()`
        srep = common.repeat_axes(s1,self.shape0,(0,1),rep=False)
        q1 = srep*q0
        return q1

    def svd_dotH(self,s1,q1):
        """
        Performs diagonal matrix multiplication conjugate

        Implements :math:`q_0 = \\mathrm{diag}(s_1)^* q_1`.

        :param s1: diagonal parameters
        :param q1: input to the diagonal multiplication
        :returns: :code:`q0` diagonal multiplication output
        """
        # Same as `svd_dot()`, but with the conjugate of the parameters
        srep = common.repeat_axes(np.conj(s1),self.shape0,(0,1),rep=False)
        q0 = srep*q1
        return q0