# Copyright 2016 James Hensman
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import tensorflow as tf
from gpflow import settings
from functools import reduce
# Floating-point dtype taken from the gpflow settings; it is passed to the
# tf.zeros / tf.eye calls in the classes below.
float_type = settings.dtypes.float_type
import numpy as np


class BlockDiagMat_many:
    """
    A block-diagonal matrix built from a list of diagonal blocks.

    Every element of ``mats`` is expected to be a structured-matrix object
    offering the attributes and methods that are forwarded to it here
    (``shape``, ``sqrt_dims``, ``get``, ``logdet``, ``matmul``, ``solve``,
    ``inv``, ``trace_KiX``, ``get_diag``, ``inv_diag``, ``matmul_sqrt`` and
    ``matmul_sqrt_transpose``).
    """

    def __init__(self, mats):
        self.mats = mats

    @property
    def shape(self):
        """(rows, cols) of the full matrix: the block shapes summed per axis."""
        return (sum([m.shape[0] for m in self.mats]), sum([m.shape[1] for m in self.mats]))

    @property
    def sqrt_dims(self):
        """Total number of rows of the stacked (non-square) block square roots."""
        return sum([m.sqrt_dims for m in self.mats])

    @staticmethod
    def _split_rows(X, heights):
        """
        Cut X into consecutive row-slabs, the i-th slab having heights[i] rows
        and all the columns of X.
        """
        pieces = []
        start = 0
        for h in heights:
            pieces.append(tf.slice(X, begin=tf.stack([start, 0]), size=tf.stack([h, -1])))
            # The heights may be tensors, so the offset is accumulated with
            # `+` rather than stored in a NumPy array.
            start = start + h
        return pieces

    def _get_rhs_slices(self, X):
        """
        Split X row-wise so that the i-th slab has as many rows as the i-th
        block has columns.
        """
        return self._split_rows(X, [m.shape[1] for m in self.mats])

    def _get_rhs_blocks(self, X):
        """
        X is a solid matrix, same size as this one. Get the blocks of X that
        correspond to the structure of this matrix
        """
        ret = []
        start1 = 0
        start2 = 0
        for m in self.mats:
            ret.append(tf.slice(X, begin=tf.stack([start1, start2]), size=m.shape))
            start1 = start1 + m.shape[0]
            start2 = start2 + m.shape[1]
        return ret

    def get(self):
        """Build the dense matrix by padding each block with zeros in turn."""
        ret = self.mats[0].get()
        for m in self.mats[1:]:
            # zeros to the right of what has been built so far ...
            tr_shape = tf.stack([tf.shape(ret)[0], m.shape[1]])
            # ... and zeros to the left of the new block
            bl_shape = tf.stack([m.shape[0], tf.shape(ret)[1]])
            top = tf.concat([ret, tf.zeros(tr_shape, float_type)], axis=1)
            bottom = tf.concat([tf.zeros(bl_shape, float_type), m.get()], axis=1)
            ret = tf.concat([top, bottom], axis=0)
        return ret

    def logdet(self):
        """Sum of the blocks' log-determinants."""
        return reduce(tf.add, [m.logdet() for m in self.mats])

    def matmul(self, X):
        """Multiply this matrix by X, block by block, and stack the results."""
        return tf.concat([m.matmul(Xi) for m, Xi in zip(self.mats, self._get_rhs_slices(X))], axis=0)

    def solve(self, X):
        """Solve against X, block by block, and stack the results."""
        return tf.concat([m.solve(Xi) for m, Xi in zip(self.mats, self._get_rhs_slices(X))], axis=0)

    def inv(self):
        """Block-diagonal matrix whose blocks are the inverses of these blocks."""
        return BlockDiagMat_many([mat.inv() for mat in self.mats])

    def trace_KiX(self, X):
        """
        X is a square matrix of the same size as this one.
        if self is K, compute tr(K^{-1} X)
        """
        return reduce(tf.add, [m.trace_KiX(Xi) for m, Xi in zip(self.mats, self._get_rhs_blocks(X))])

    def get_diag(self):
        """Concatenation of the blocks' diagonals."""
        return tf.concat([m.get_diag() for m in self.mats], axis=0)

    def inv_diag(self):
        """Concatenation of the diagonals of the blocks' inverses."""
        return tf.concat([m.inv_diag() for m in self.mats], axis=0)

    def matmul_sqrt(self, X):
        """Apply each block's ``matmul_sqrt`` to its slab of X and stack the results."""
        return tf.concat([m.matmul_sqrt(Xi) for m, Xi in zip(self.mats, self._get_rhs_slices(X))], axis=0)

    def matmul_sqrt_transpose(self, X):
        """
        Apply each block's ``matmul_sqrt_transpose`` to its slab of X and
        stack the results. The i-th slab has ``mats[i].sqrt_dims`` rows.
        """
        slabs = self._split_rows(X, [m.sqrt_dims for m in self.mats])
        return tf.concat([m.matmul_sqrt_transpose(Xi) for m, Xi in zip(self.mats, slabs)], axis=0)


class BlockDiagMat:
    """
    A block-diagonal matrix with exactly two diagonal blocks: A in the
    top-left corner and B in the bottom-right corner.

    A and B are structured-matrix objects; every method here forwards to the
    method of the same name on each block.
    """

    def __init__(self, A, B):
        self.A = A
        self.B = B

    @property
    def shape(self):
        """(rows, cols) of the full matrix: the two block shapes added per axis."""
        blocks = [self.A, self.B]
        n_rows = sum([blk.shape[0] for blk in blocks])
        n_cols = sum([blk.shape[1] for blk in blocks])
        return (n_rows, n_cols)

    @property
    def sqrt_dims(self):
        """Sum of the two blocks' ``sqrt_dims``."""
        return sum([blk.sqrt_dims for blk in [self.A, self.B]])

    def _get_rhs_slices(self, X):
        """
        Split X row-wise into the part facing A (the first A.shape[1] rows)
        and the part facing B (all remaining rows).
        """
        split = self.A.shape[1]
        upper = tf.slice(X, begin=[0, 0], size=tf.stack([split, -1]))
        lower = tf.slice(X, begin=tf.stack([split, 0]), size=[-1, -1])
        return upper, lower

    def get(self):
        """Build the dense matrix [[A, 0], [0, B]]."""
        # zero padding: right of A, and left of B
        upper_right = tf.zeros(tf.stack([self.A.shape[0], self.B.shape[1]]), float_type)
        lower_left = tf.zeros(tf.stack([self.B.shape[0], self.A.shape[1]]), float_type)
        upper = tf.concat([self.A.get(), upper_right], axis=1)
        lower = tf.concat([lower_left, self.B.get()], axis=1)
        return tf.concat([upper, lower], axis=0)

    def logdet(self):
        """Sum of the two blocks' log-determinants."""
        logdet_A = self.A.logdet()
        logdet_B = self.B.logdet()
        return logdet_A + logdet_B

    def matmul(self, X):
        """Multiply this matrix by X, block by block, and stack the results."""
        upper, lower = self._get_rhs_slices(X)
        return tf.concat([self.A.matmul(upper), self.B.matmul(lower)], axis=0)

    def solve(self, X):
        """Solve against X, block by block, and stack the results."""
        upper, lower = self._get_rhs_slices(X)
        return tf.concat([self.A.solve(upper), self.B.solve(lower)], axis=0)

    def inv(self):
        """Block-diagonal matrix made of the two blocks' inverses."""
        A_inv = self.A.inv()
        B_inv = self.B.inv()
        return BlockDiagMat(A_inv, B_inv)

    def trace_KiX(self, X):
        """
        X is a square matrix of the same size as this one.
        if self is K, compute tr(K^{-1} X)

        Only the two diagonal blocks of X contribute, so they are cut out
        and passed on to A and B.
        """
        X_a = tf.slice(X, [0, 0], self.A.shape)
        X_b = tf.slice(X, self.A.shape, [-1, -1])
        return self.A.trace_KiX(X_a) + self.B.trace_KiX(X_b)

    def get_diag(self):
        """Diagonal of A followed by the diagonal of B."""
        diagonals = [self.A.get_diag(), self.B.get_diag()]
        return tf.concat(diagonals, axis=0)

    def inv_diag(self):
        """Diagonal of A's inverse followed by the diagonal of B's inverse."""
        diagonals = [self.A.inv_diag(), self.B.inv_diag()]
        return tf.concat(diagonals, axis=0)

    def matmul_sqrt(self, X):
        """Apply each block's ``matmul_sqrt`` to its slab of X and stack the results."""
        upper, lower = self._get_rhs_slices(X)
        return tf.concat([self.A.matmul_sqrt(upper), self.B.matmul_sqrt(lower)], axis=0)

    def matmul_sqrt_transpose(self, X):
        """
        Apply each block's ``matmul_sqrt_transpose`` to its slab of X and
        stack the results. The first A.sqrt_dims rows of X go to A, the
        remaining rows to B.
        """
        split = self.A.sqrt_dims
        upper = tf.slice(X, begin=[0, 0], size=tf.stack([split, -1]))
        lower = tf.slice(X, begin=tf.stack([split, 0]), size=[-1, -1])
        pieces = [self.A.matmul_sqrt_transpose(upper), self.B.matmul_sqrt_transpose(lower)]
        return tf.concat(pieces, axis=0)


class LowRankMat:
    def __init__(self, d, W):
        """
        A matrix of the form

            diag(d) + W W^T

        d holds the diagonal and W is the low-rank factor. The methods below
        take log/sqrt/reciprocal of d and a Cholesky factor of
        I + W^T D^{-1} W, so d is presumably strictly positive -- TODO
        confirm with callers.
        """
        self.d = d
        self.W = W

    @property
    def shape(self):
        """(n, n) where n is the number of elements of d."""
        return (tf.size(self.d), tf.size(self.d))

    @property
    def sqrt_dims(self):
        """
        Number of rows of the non-square square root [D^{1/2}; W^T]:
        the size of d plus the number of columns of W.
        """
        return tf.size(self.d) + tf.shape(self.W)[1]

    def get(self):
        """Dense representation, diag(d) + W W^T."""
        return tf.diag(self.d) + tf.matmul(self.W, tf.transpose(self.W))

    def logdet(self):
        """
        Log-determinant via the matrix determinant lemma:
        log|D| + log|I + W^T D^{-1} W|.
        """
        part1 = tf.reduce_sum(tf.log(self.d))
        I = tf.eye(tf.shape(self.W)[1], float_type)
        # W^T / d scales the columns of W^T, i.e. computes W^T D^{-1}
        M = I + tf.matmul(tf.transpose(self.W) / self.d, self.W)
        # log|M| from its Cholesky factor
        part2 = 2*tf.reduce_sum(tf.log(tf.diag_part(tf.cholesky(M))))
        return part1 + part2

    def matmul(self, B):
        """Compute (diag(d) + W W^T) B without forming the dense matrix."""
        WTB = tf.matmul(tf.transpose(self.W), B)
        WWTB = tf.matmul(self.W, WTB)
        DB = tf.reshape(self.d, [-1, 1]) * B
        return DB + WWTB

    def get_diag(self):
        """Diagonal of the matrix: d plus the squared row norms of W."""
        return self.d + tf.reduce_sum(tf.square(self.W), 1)

    def solve(self, B):
        """
        Compute (diag(d) + W W^T)^{-1} B using the Woodbury identity:
        D^{-1} B - D^{-1} W M^{-1} W^T D^{-1} B, with M = I + W^T D^{-1} W.
        """
        d_col = tf.expand_dims(self.d, 1)
        DiB = B / d_col
        DiW = self.W / d_col
        WTDiB = tf.matmul(tf.transpose(DiW), B)
        M = tf.eye(tf.shape(self.W)[1], float_type) + tf.matmul(tf.transpose(DiW), self.W)
        L = tf.cholesky(M)
        # two triangular solves apply M^{-1} = L^{-T} L^{-1}
        tmp1 = tf.matrix_triangular_solve(L, WTDiB, lower=True)
        tmp2 = tf.matrix_triangular_solve(tf.transpose(L), tmp1, lower=False)
        return DiB - tf.matmul(DiW, tmp2)

    def inv(self):
        """
        Return the inverse as a LowRankMatNeg, diag(1/d) - V V^T, where
        V = D^{-1} W L^{-T} and L is the Cholesky factor of
        I + W^T D^{-1} W (Woodbury identity).
        """
        di = tf.reciprocal(self.d)
        d_col = tf.expand_dims(self.d, 1)
        DiW = self.W / d_col
        M = tf.eye(tf.shape(self.W)[1], float_type) + tf.matmul(tf.transpose(DiW), self.W)
        L = tf.cholesky(M)
        # V^T = L^{-1} (D^{-1} W)^T, so that V V^T = D^{-1} W M^{-1} W^T D^{-1}
        V = tf.transpose(tf.matrix_triangular_solve(L, tf.transpose(DiW), lower=True))
        return LowRankMatNeg(di, V)

    def trace_KiX(self, X):
        """
        X is a square matrix of the same size as this one.
        if self is K, compute tr(K^{-1} X)
        """
        d_col = tf.expand_dims(self.d, 1)
        R = self.W / d_col
        RTX = tf.matmul(tf.transpose(R), X)
        RTXR = tf.matmul(RTX, R)
        M = tf.eye(tf.shape(self.W)[1], float_type) + tf.matmul(tf.transpose(R), self.W)
        Mi = tf.matrix_inverse(M)
        # tr(M^{-1} R^T X R) is taken as an elementwise-product sum
        return tf.reduce_sum(tf.diag_part(X) * 1./self.d) - tf.reduce_sum(RTXR * Mi)

    def inv_diag(self):
        """Diagonal of the inverse: 1/d minus the column-wise squared norms of L^{-1} W^T D^{-1}."""
        d_col = tf.expand_dims(self.d, 1)
        WTDi = tf.transpose(self.W / d_col)
        M = tf.eye(tf.shape(self.W)[1], float_type) + tf.matmul(WTDi, self.W)
        L = tf.cholesky(M)
        tmp1 = tf.matrix_triangular_solve(L, WTDi, lower=True)
        return 1./self.d - tf.reduce_sum(tf.square(tmp1), 0)

    def matmul_sqrt(self, B):
        """
        There's a non-square sqrt of this matrix given by
          [ D^{1/2}]
          [   W^T  ]

        This method right-multiplies the sqrt by the matrix B
        """

        DB = tf.expand_dims(tf.sqrt(self.d), 1) * B
        VTB = tf.matmul(tf.transpose(self.W), B)
        return tf.concat([DB, VTB], axis=0)

    def matmul_sqrt_transpose(self, B):
        """
        There's a non-square sqrt of this matrix given by
          [ D^{1/2}]
          [   W^T  ]

        This method right-multiplies the transposed-sqrt by the matrix B
        """
        # the first size(d) rows of B meet D^{1/2}, the remaining rows meet W
        B1 = tf.slice(B, tf.zeros((2,), tf.int32), tf.stack([tf.size(self.d), -1]))
        B2 = tf.slice(B, tf.stack([tf.size(self.d), 0]), -tf.ones((2,), tf.int32))
        return tf.expand_dims(tf.sqrt(self.d), 1) * B1 + tf.matmul(self.W, B2)


class LowRankMatNeg:
    def __init__(self, d, W):
        """
        Represents the matrix

            diag(d) - W W^T

        i.e. a diagonal matrix with a low-rank term subtracted (note the
        minus sign).
        """
        self.d, self.W = d, W

    @property
    def shape(self):
        """(n, n) where n is the number of elements of d."""
        n = tf.size(self.d)
        return (n, n)

    def get(self):
        """Dense representation, diag(d) - W W^T."""
        low_rank_term = tf.matmul(self.W, tf.transpose(self.W))
        return tf.diag(self.d) - low_rank_term




class Rank1Mat:
    def __init__(self, d, v):
        """
        Represents the matrix

            diag(d) + v v^T

        where d holds the diagonal and v is the rank-one direction.
        """
        self.d, self.v = d, v

    @property
    def shape(self):
        """(n, n) where n is the number of elements of d."""
        n = tf.size(self.d)
        return (n, n)

    @property
    def sqrt_dims(self):
        """Rows of the non-square square root [D^{1/2}; v^T]: size of d plus one."""
        return tf.size(self.d) + 1

    def get(self):
        """Dense representation, diag(d) + v v^T."""
        col = tf.expand_dims(self.v, 1)
        outer = tf.matmul(col, tf.transpose(col))
        return tf.diag(self.d) + outer

    def logdet(self):
        """Log-determinant: sum(log d) + log(1 + sum(v^2 / d))."""
        log_det_diag = tf.reduce_sum(tf.log(self.d))
        rank1_term = tf.log(1. + tf.reduce_sum(tf.square(self.v) / self.d))
        return log_det_diag + rank1_term

    def matmul(self, B):
        """Compute (diag(d) + v v^T) B without forming the dense matrix."""
        col = tf.expand_dims(self.v, 1)
        diag_term = tf.expand_dims(self.d, 1) * B
        rank1_term = tf.matmul(col, tf.matmul(tf.transpose(col), B))
        return diag_term + rank1_term

    def solve(self, B):
        """
        Compute (diag(d) + v v^T)^{-1} B as
        D^{-1} B - (D^{-1} v)(D^{-1} v)^T B / (1 + v^T D^{-1} v).
        """
        scaled_v = self.v / self.d
        denom = 1. + tf.reduce_sum(scaled_v * self.v)
        scaled_col = tf.expand_dims(scaled_v, 1)
        diag_term = B / tf.expand_dims(self.d, 1)
        rank1_term = tf.matmul(scaled_col/denom, tf.matmul(tf.transpose(scaled_col), B))
        return diag_term - rank1_term

    def inv(self):
        """
        Return the inverse as a Rank1MatNeg, diag(1/d) - u u^T, with
        u = (v / d) / sqrt(1 + sum(v^2 / d)).
        """
        d_inv = tf.reciprocal(self.d)
        scaled_v = self.v * d_inv
        denom = 1. + tf.reduce_sum(scaled_v * self.v)
        return Rank1MatNeg(d_inv, scaled_v / tf.sqrt(denom))

    def trace_KiX(self, X):
        """
        X is a square matrix of the same size as this one.
        if self is K, compute tr(K^{-1} X)

        NOTE(review): the correction term is a matrix product of a row and a
        column built from v, so the result appears to carry that product's
        shape rather than being a plain scalar -- confirm callers expect this.
        """
        scaled_col = tf.expand_dims(self.v / self.d, 1)
        row_times_X = tf.matmul(tf.transpose(scaled_col), X)
        quad_form = tf.matmul(row_times_X, scaled_col)
        denom = 1 + tf.reduce_sum(tf.square(self.v) / self.d)
        diag_term = tf.reduce_sum(tf.diag_part(X) / self.d)
        return diag_term - quad_form / denom

    def get_diag(self):
        """Diagonal of the matrix: d + v^2 (elementwise)."""
        v_squared = tf.square(self.v)
        return self.d + v_squared

    def inv_diag(self):
        """Diagonal of the inverse: 1/d - (v/d)^2 / (1 + sum(v^2 / d))."""
        scaled_v = self.v / self.d
        denom = 1. + tf.reduce_sum(scaled_v * self.v)
        return 1./self.d - tf.square(scaled_v) / denom

    def matmul_sqrt(self, B):
        """
        This matrix has a non-square square root
          [ D^{1/2}]
          [   v^T  ]

        Multiply that square root by the matrix B (B on the right).
        """
        top = tf.expand_dims(tf.sqrt(self.d), 1) * B
        bottom = tf.matmul(tf.expand_dims(self.v, 0), B)
        return tf.concat([top, bottom], axis=0)

    def matmul_sqrt_transpose(self, B):
        """
        This matrix has a non-square square root
          [ D^{1/2}]
          [   v^T  ]

        Multiply the transpose of that square root by the matrix B (B on the
        right). The first size(d) rows of B meet D^{1/2}; the remaining rows
        meet v.
        """
        n = tf.size(self.d)
        B_top = tf.slice(B, tf.zeros((2,), tf.int32), tf.stack([n, -1]))
        B_rest = tf.slice(B, tf.stack([n, 0]), -tf.ones((2,), tf.int32))
        diag_term = tf.expand_dims(tf.sqrt(self.d), 1) * B_top
        rank1_term = tf.matmul(tf.expand_dims(self.v, 1), B_rest)
        return diag_term + rank1_term


class Rank1MatNeg:
    def __init__(self, d, v):
        """
        Represents the matrix

            diag(d) - v v^T

        i.e. a diagonal matrix with a rank-one term subtracted (note the
        minus sign).
        """
        self.d, self.v = d, v

    @property
    def shape(self):
        """(n, n) where n is the number of elements of d."""
        n = tf.size(self.d)
        return (n, n)

    def get(self):
        """Dense representation, diag(d) - v v^T."""
        col = tf.expand_dims(self.v, 1)
        outer = tf.matmul(col, tf.transpose(col))
        return tf.diag(self.d) - outer


class DiagMat:
    """A diagonal matrix, diag(d), stored as its diagonal d."""

    def __init__(self, d):
        self.d = d

    @property
    def shape(self):
        """(n, n) where n is the number of elements of d."""
        n = tf.size(self.d)
        return (n, n)

    @property
    def sqrt_dims(self):
        """Rows of the square root D^{1/2}: the number of elements of d."""
        return tf.size(self.d)

    def get(self):
        """Dense representation, diag(d)."""
        return tf.diag(self.d)

    def logdet(self):
        """Log-determinant: the sum of log(d)."""
        log_d = tf.log(self.d)
        return tf.reduce_sum(log_d)

    def matmul(self, B):
        """Compute diag(d) B by scaling the rows of B."""
        d_col = tf.expand_dims(self.d, 1)
        return d_col * B

    def solve(self, B):
        """Compute diag(d)^{-1} B by dividing the rows of B by d."""
        d_col = tf.expand_dims(self.d, 1)
        return B / d_col

    def inv(self):
        """The inverse, again a DiagMat, with diagonal 1/d."""
        d_inv = tf.reciprocal(self.d)
        return DiagMat(d_inv)

    def trace_KiX(self, X):
        """
        X is a square matrix of the same size as this one.
        if self is K, compute tr(K^{-1} X)
        """
        scaled_diag = tf.diag_part(X) / self.d
        return tf.reduce_sum(scaled_diag)

    def get_diag(self):
        """The diagonal d itself."""
        return self.d

    def inv_diag(self):
        """Diagonal of the inverse, 1/d."""
        return 1. / self.d

    def matmul_sqrt(self, B):
        """Compute D^{1/2} B by scaling the rows of B by sqrt(d)."""
        sqrt_col = tf.expand_dims(tf.sqrt(self.d), 1)
        return sqrt_col * B

    def matmul_sqrt_transpose(self, B):
        """Compute (D^{1/2})^T B; the same row scaling by sqrt(d) as matmul_sqrt."""
        sqrt_col = tf.expand_dims(tf.sqrt(self.d), 1)
        return sqrt_col * B