"""
matrix.py:  Linear transforms based on a matrix
"""
from __future__ import division

import numpy as np
from vampyre.trans.base import BaseLinTrans
from vampyre.common.utils import repeat_axes
from vampyre.common.utils import VpException

class MatrixLT(BaseLinTrans):
    """
    Linear transform defined by a matrix

    The class defines a linear transform :math:`z_1 = Az_0` where :math:`A`
    is represented as a :class:`numpy.ndarray`.

    The SVD of the matrix is computed lazily, the first time one of the
    SVD-related methods is called, and is cached afterwards.

    :note:  The current code assumes that ``A`` is either a 1D or 2D
        array.  For higher dimensions, it may be good to develop an alternate
        tensor class using the :func:`numpy.tensordot` function.

    :param A:  matrix
    :param shape0:  input shape (The output shape is computed from this).
        A scalar is interpreted as a 1D shape.
    :raises VpException:  if the size of ``shape0`` along the axis on which
        the matrix acts does not equal the last dimension of ``A``
    """
    def __init__(self, A, shape0):

        # Save the matrix and promote a scalar input shape to a 1-tuple
        self.A = A
        if np.isscalar(shape0):
            shape0 = (shape0,)

        # Find the axis of the input on which the matrix acts.
        # Note that A.dot(x) operates on the second to last axis of x,
        # or on the only axis when x is 1D.
        Ashape = A.shape
        if len(shape0) == 1:
            self.aaxis = 0
        else:
            self.aaxis = len(shape0)-2

        # Check that input shape matches
        if shape0[self.aaxis] != Ashape[-1]:
            raise VpException("Input shape %s does not match matrix shape %s"
                              % (str(shape0), str(Ashape)))

        # Compute the output shape:  same as the input shape, except that
        # the axis the matrix acts on takes the number of rows of A.
        # A plain list is used so the entries are not turned into
        # numpy scalars.
        shape1 = list(shape0)
        shape1[self.aaxis] = Ashape[0]
        shape1 = tuple(shape1)

        # Get data types.  Input and output both use the matrix data type.
        dtype0 = A.dtype
        dtype1 = A.dtype

        # Superclass constructor
        BaseLinTrans.__init__(self, shape0, shape1, dtype0, dtype1,
                              svd_avail=True, name='MatrixLT')

        # Set SVD terms to not computed
        self.svd_computed = False

    def dot(self, z0):
        """
        Compute matrix multiply :math:`A(z0)`

        :param z0:  input to the multiplication
        :returns:  ``A.dot(z0)``
        """
        return self.A.dot(z0)

    def dotH(self, z1):
        """
        Compute conjugate transpose multiplication :math:`A^*(z1)`

        :param z1:  input to the multiplication
        :returns:  ``A.conj().T.dot(z1)``
        """
        return self.A.conj().T.dot(z1)

    def _comp_svd(self):
        """
        Compute the SVD terms, if necessary

        If the SVD is already computed, simply return.  Otherwise the
        method sets the attributes ``U``, ``s``, ``V``, ``sshape`` and
        ``srep_axes`` and marks the SVD as computed.
        """
        # Return if SVD is already computed
        if self.svd_computed:
            return

        # Compute the economy SVD, A = U*diag(s)*Vh.  Note that linalg.svd
        # returns the conjugate transpose Vh as its third output, so it is
        # conjugate transposed again to recover V.
        U, s, Vh = np.linalg.svd(self.A, full_matrices=False)
        self.U = U
        self.s = s
        self.V = Vh.conj().T

        # Compute the shape of the transformed space:  the input shape
        # with the axis the matrix acts on replaced by the number of
        # singular values.
        sshape = list(self.shape0)
        sshape[self.aaxis] = len(s)
        self.sshape = tuple(sshape)

        # Compute the axes on which the diagonal multiplication
        # is to be repeated.  This is all but axis 0
        ndim = len(self.sshape)
        self.srep_axes = tuple(range(1, ndim))

        # Mark as computed only once every SVD attribute has been set
        self.svd_computed = True

    def Usvd(self, q1):
        """
        Multiplication by SVD term :math:`U`

        :param q1:  input to the multiplication
        :returns:  ``U.dot(q1)``
        """
        self._comp_svd()
        return self.U.dot(q1)

    def UsvdH(self, z1):
        """
        Multiplication by SVD term :math:`U^*`

        :param z1:  input to the multiplication
        :returns:  ``U.conj().T.dot(z1)``
        """
        self._comp_svd()
        return self.U.conj().T.dot(z1)

    def Vsvd(self, q0):
        """
        Multiplication by SVD term :math:`V`

        :param q0:  input to the multiplication
        :returns:  ``V.dot(q0)``
        """
        self._comp_svd()
        return self.V.dot(q0)

    def VsvdH(self, z0):
        """
        Multiplication by SVD term :math:`V^*`

        :param z0:  input to the multiplication
        :returns:  ``V.conj().T.dot(z0)``
        """
        self._comp_svd()
        return self.V.conj().T.dot(z0)

    def get_svd_diag(self):
        """
        Gets parameters of the SVD diagonal multiplication.

        See :func:`vampyre.trans.base.LinTrans.get_svd_diag()` for
        details.

        :returns: :code:`s,sshape,srep_axes`, the diagonal parameters
            :code:`s`, the shape in the transformed domain :code:`sshape`,
            and the axes on which the diagonal parameters are to be
            repeated, :code:`srep_axes`
        """
        self._comp_svd()
        return self.s, self.sshape, self.srep_axes

    def svd_dot(self, s1, q0):
        """
        Performs diagonal matrix multiplication.

        Implements :math:`q_1 = \\mathrm{diag}(s_1) q_0`.

        :param s1: diagonal parameters
        :param q0: input to the diagonal multiplication
        :returns: :code:`q1` diagonal multiplication output
        """
        # sshape and srep_axes only exist once the SVD has been computed
        self._comp_svd()
        srep = repeat_axes(s1, self.sshape, self.srep_axes, rep=False)
        q1 = srep*q0
        return q1

    def svd_dotH(self, s1, q1):
        """
        Performs diagonal matrix multiplication conjugate

        Implements :math:`q_0 = \\mathrm{diag}(s_1)^* q_1`.

        :param s1: diagonal parameters
        :param q1: input to the diagonal multiplication
        :returns: :code:`q0` diagonal multiplication output
        """
        # sshape and srep_axes only exist once the SVD has been computed
        self._comp_svd()
        srep = repeat_axes(np.conj(s1), self.sshape, self.srep_axes,
                           rep=False)
        q0 = srep*q1
        return q0

#def __str__(self):
#    string = str(self.name) + '\n'\
#              + 'Input: ' + str(self.shape0) + ',' + str(self.dtype0)
#    return string