package org.nd4j.linalg.api.blas.impl;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.linalg.api.blas.Lapack;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JArraySizeException;
import org.nd4j.linalg.factory.Nd4j;

/**
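 * Base LAPACK implementation: validates arguments and dispatches each routine to its float
 * ({@code s*}) or double ({@code d*}) version based on the array's data type.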
 *
 * @author Adam Gibson
 * @author rcorbish
 */
@Slf4j
public abstract class BaseLapack implements Lapack {

    /**
     * Allocates a 1 x {@code length} int array in the layout the native LAPACK hooks expect for
     * INFO (length 1) and IPIV (length min(m, n)) outputs.
     */
    private static INDArray createIntRowVector(int length) {
        return Nd4j.createArrayFromShapeBuffer(Nd4j.getDataBufferFactory().createInt(length),
                        Nd4j.getShapeInfoProvider().createShapeInformation(new int[] {1, length}).getFirst());
    }

    /**
     * Throws if LAPACK reported an illegal argument. By LAPACK convention INFO = -i means that
     * argument number i was invalid, so the positive index is reported.
     *
     * @param info    the INFO value written by the native routine
     * @param routine the routine name used in the error message
     */
    private static void checkArguments(int info, String routine) {
        if (info < 0) {
            throw new Error("Parameter #" + (-info) + " to " + routine + "() was not valid");
        }
    }

    /**
     * Computes the LU factorization of {@code A} with partial pivoting. It dispatches to
     * {@link #dgetrf} or {@link #sgetrf} based on the data type of {@code A}.
     *
     * @param A the matrix to factorize. It is handed to the native routine, which (per LAPACK getrf
     *          semantics) overwrites it with the L and U factors.
     * @return a 1 x min(m, n) int array of pivot indices as reported by LAPACK (1-based)
     * @throws ND4JArraySizeException if a dimension of {@code A} exceeds {@code Integer.MAX_VALUE}
     * @throws UnsupportedOperationException if {@code A} is neither DOUBLE nor FLOAT
     * @throws Error if LAPACK reports an invalid argument
     */
    @Override
    public INDArray getrf(INDArray A) {
        // FIXME: int cast
        if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }

        int m = (int) A.rows();
        int n = (int) A.columns();

        INDArray INFO = createIntRowVector(1);
        INDArray IPIV = createIntRowVector(Math.min(m, n));

        if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
            dgetrf(m, n, A, IPIV, INFO);
        } else if (A.data().dataType() == DataBuffer.Type.FLOAT) {
            sgetrf(m, n, A, IPIV, INFO);
        } else {
            throw new UnsupportedOperationException();
        }

        int info = INFO.getInt(0);
        checkArguments(info, "getrf");
        if (info > 0) {
            // LAPACK: INFO = i > 0 means U(i,i) is exactly zero, i.e. U is singular.
            log.warn("The matrix is singular - cannot be used for inverse op. Check U matrix at row {}", info);
        }

        return IPIV;
    }



    /**
    * Float/Double versions of LU decomp.
    * This is the official LAPACK interface (in case you want to call this directly)
    * See getrf for full details on LU Decomp
    *
    * @param M  the number of rows in the matrix A
    * @param N  the number of cols in the matrix A
    * @param A  the matrix to factorize - data must be in column order ( create with 'f' ordering )
    * @param IPIV an output array for the permutations ( must be int based storage )
    * @param INFO error details 1 int array, a positive number (i) implies row i cannot be factored, a negative value implies parameter i is invalid
    */
    public abstract void sgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO);

    public abstract void dgetrf(int M, int N, INDArray A, INDArray IPIV, INDArray INFO);



    /**
     * Computes the Cholesky factorization of the positive definite matrix {@code A}. It dispatches
     * to {@link #dpotrf} or {@link #spotrf} based on the data type of {@code A}.
     *
     * @param A     the matrix to factorize (order n = number of columns)
     * @param lower {@code true} to request the lower factor ('L'), {@code false} for the upper ('U')
     * @throws ND4JArraySizeException if the column count exceeds {@code Integer.MAX_VALUE}
     * @throws UnsupportedOperationException if {@code A} is neither DOUBLE nor FLOAT
     * @throws Error if LAPACK reports an invalid argument or the matrix is not positive definite
     */
    @Override
    public void potrf(INDArray A, boolean lower) {
        // FIXME: int cast
        if (A.columns() > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }

        byte uplo = (byte) (lower ? 'L' : 'U'); // upper or lower part of the factor desired ?
        int n = (int) A.columns();

        INDArray INFO = createIntRowVector(1);

        if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
            dpotrf(uplo, n, A, INFO);
        } else if (A.data().dataType() == DataBuffer.Type.FLOAT) {
            spotrf(uplo, n, A, INFO);
        } else {
            throw new UnsupportedOperationException();
        }

        int info = INFO.getInt(0);
        checkArguments(info, "potrf");
        if (info > 0) {
            throw new Error("The matrix is not positive definite! (potrf fails @ order " + info + ")");
        }
    }



    /**
    * Float/Double versions of cholesky decomp for positive definite matrices
    *
    *   A = LL*
    *
    * @param uplo which factor to return L or U
    * @param N  the order of the matrix A
    * @param A  the matrix to factorize - data must be in column order ( create with 'f' ordering )
    * @param INFO error details 1 int array, a positive number (i) implies row i cannot be factored, a negative value implies parameter i is invalid
    */
    public abstract void spotrf(byte uplo, int N, INDArray A, INDArray INFO);

    public abstract void dpotrf(byte uplo, int N, INDArray A, INDArray INFO);



    /**
     * Computes the QR factorization of {@code A}. It dispatches to {@link #dgeqrf} or
     * {@link #sgeqrf} based on the data type of {@code A}.
     *
     * @param A the matrix to factorize (m x n)
     * @param R the output array for the R factor; must be n x n
     * @throws ND4JArraySizeException if a dimension of {@code A} exceeds {@code Integer.MAX_VALUE}
     * @throws Error if {@code R} has the wrong shape or LAPACK reports an invalid argument
     * @throws UnsupportedOperationException if {@code A} is neither DOUBLE nor FLOAT
     */
    @Override
    public void geqrf(INDArray A, INDArray R) {
        // FIXME: int cast
        if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }

        int m = (int) A.rows();
        int n = (int) A.columns();

        // Validate the output shape before allocating anything or calling native code.
        if (R.rows() != A.columns() || R.columns() != A.columns()) {
            throw new Error("geqrf: R must be N x N (n = columns in A)");
        }

        INDArray INFO = createIntRowVector(1);

        if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
            dgeqrf(m, n, A, R, INFO);
        } else if (A.data().dataType() == DataBuffer.Type.FLOAT) {
            sgeqrf(m, n, A, R, INFO);
        } else {
            throw new UnsupportedOperationException();
        }

        checkArguments(INFO.getInt(0), "geqrf");
    }


    /**
    * Float/Double versions of QR decomp.
    * This is the official LAPACK interface (in case you want to call this directly)
    * See geqrf for full details on QR Decomp
    *
    * @param M  the number of rows in the matrix A
    * @param N  the number of cols in the matrix A
    * @param A  the matrix to factorize - data must be in column order ( create with 'f' ordering )
    * @param R  an output array for other part of factorization
    * @param INFO error details 1 int array, a positive number (i) implies row i cannot be factored, a negative value implies parameter i is invalid
    */
    public abstract void sgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO);

    public abstract void dgeqrf(int M, int N, INDArray A, INDArray R, INDArray INFO);



    /**
     * Computes eigenvalues (and optionally eigenvectors) of the symmetric matrix {@code A}. It
     * dispatches to {@link #dsyev} or {@link #ssyev} based on the data type of {@code A}.
     *
     * @param jobz 'N' for eigenvalues only, 'V' to also compute eigenvectors
     * @param uplo which triangle of {@code A} to use
     * @param A    the square matrix
     * @param V    the output array for eigenvalues; its length must equal the dimension of {@code A}
     * @return the status code returned by the native hook
     * @throws Error if {@code A} is not square or {@code V} has the wrong length
     * @throws ND4JArraySizeException if a dimension of {@code A} exceeds {@code Integer.MAX_VALUE}
     * @throws UnsupportedOperationException if {@code A} is neither DOUBLE nor FLOAT
     */
    @Override
    public int syev(char jobz, char uplo, INDArray A, INDArray V) {
        if (A.rows() != A.columns()) {
            throw new Error("syev: A must be square.");
        }
        if (A.rows() != V.length()) {
            throw new Error("syev: V must be the length of the matrix dimension.");
        }

        // FIXME: int cast
        if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }

        int n = (int) A.rows();
        if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
            return dsyev(jobz, uplo, n, A, V);
        } else if (A.data().dataType() == DataBuffer.Type.FLOAT) {
            return ssyev(jobz, uplo, n, A, V);
        } else {
            throw new UnsupportedOperationException();
        }
    }


    /**
    * Float/Double versions of eigen value/vector calc.
    *
    * @param jobz 'N' - no eigen vectors, 'V' - return eigenvectors
    * @param uplo upper or lower part of symmetric matrix to use
    * @param N  the number of rows & cols in the matrix A
    * @param A  the matrix to calculate eigenvectors
    * @param R  an output array for eigenvalues ( may be null )
    * @return the status code of the native call
    */
    public abstract int ssyev(char jobz, char uplo, int N, INDArray A, INDArray R);

    public abstract int dsyev(char jobz, char uplo, int N, INDArray A, INDArray R);



    /**
     * Computes the singular value decomposition of {@code A}. It dispatches to {@link #dgesvd} or
     * {@link #sgesvd} based on the data type of {@code A}.
     *
     * @param A  the matrix to decompose (m x n)
     * @param S  the output array for singular values
     * @param U  the output array for left singular vectors, or {@code null} to skip computing them
     * @param VT the output array for right singular vectors (transposed), or {@code null} to skip
     * @throws ND4JArraySizeException if a dimension of {@code A} exceeds {@code Integer.MAX_VALUE}
     * @throws UnsupportedOperationException if {@code A} is neither DOUBLE nor FLOAT
     * @throws Error if LAPACK reports an invalid argument
     */
    @Override
    public void gesvd(INDArray A, INDArray S, INDArray U, INDArray VT) {
        // FIXME: int cast
        if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }

        int m = (int) A.rows();
        int n = (int) A.columns();

        // 'A' = compute all columns of U / rows of VT, 'N' = compute none.
        byte jobu = (byte) (U == null ? 'N' : 'A');
        byte jobvt = (byte) (VT == null ? 'N' : 'A');

        INDArray INFO = createIntRowVector(1);

        if (A.data().dataType() == DataBuffer.Type.DOUBLE) {
            dgesvd(jobu, jobvt, m, n, A, S, U, VT, INFO);
        } else if (A.data().dataType() == DataBuffer.Type.FLOAT) {
            sgesvd(jobu, jobvt, m, n, A, S, U, VT, INFO);
        } else {
            throw new UnsupportedOperationException();
        }

        int info = INFO.getInt(0);
        checkArguments(info, "gesvd");
        if (info > 0) {
            // LAPACK: INFO > 0 means the bidiagonal QR iteration did not converge; INFO counts the
            // superdiagonals of the intermediate bidiagonal form that did not converge to zero.
            log.warn("gesvd() did not converge: {} superdiagonal(s) of the intermediate bidiagonal form did not converge to zero",
                            info);
        }
    }

    /**
     * Float/Double versions of singular value decomposition (official LAPACK interface).
     *
     * @param jobu  'A' to compute all columns of U, 'N' to compute none
     * @param jobvt 'A' to compute all rows of VT, 'N' to compute none
     * @param M     the number of rows in the matrix A
     * @param N     the number of cols in the matrix A
     * @param A     the matrix to decompose
     * @param S     output array for singular values
     * @param U     output array for left singular vectors (may be null when jobu is 'N')
     * @param VT    output array for right singular vectors (may be null when jobvt is 'N')
     * @param INFO  error details 1 int array, negative value implies parameter i is invalid
     */
    public abstract void sgesvd(byte jobu, byte jobvt, int M, int N, INDArray A, INDArray S, INDArray U, INDArray VT,
                    INDArray INFO);

    public abstract void dgesvd(byte jobu, byte jobvt, int M, int N, INDArray A, INDArray S, INDArray U, INDArray VT,
                    INDArray INFO);



    /**
     * Builds the M x M permutation matrix P from the 1-based pivot indices produced by getrf.
     * Columns of an identity matrix are swapped so that the result is the transpose of the
     * row-interchange matrix.
     *
     * @param M    the size of the (square) permutation matrix
     * @param ipiv the pivot indices returned by getrf
     * @return the permutation matrix; it contains a single 1 in every row and column
     */
    @Override
    public INDArray getPFactor(int M, INDArray ipiv) {
        // The simplest permutation is the identity matrix
        INDArray P = Nd4j.eye(M); // result is a square matrix with given size
        for (int i = 0; i < ipiv.length(); i++) {
            int pivot = ipiv.getInt(i) - 1; // Did we swap row #i with anything?
            if (pivot > i) { // don't reswap when we get lower down in the vector
                INDArray v1 = P.getColumn(i).dup(); // because of row vs col major order we'll ...
                INDArray v2 = P.getColumn(pivot); // ... make a transposed matrix immediately
                P.putColumn(i, v2);
                P.putColumn(pivot, v1); // note dup() above is required - getColumn() is a 'view'
            }
        }
        return P; // the permutation matrix - contains a single 1 in any row and column
    }


    /**
     * Extracts the unit lower-triangular L factor from a getrf-factorized matrix. L has the same
     * shape as the input, with 1s on the diagonal and 0s above it. Elements are copied in double
     * precision, so DOUBLE inputs are not truncated to float.
     *
     * TODO: consider doing this in place to save memory. This implies U is taken out first.
     *
     * @param A the factorized matrix
     * @return a new m x n matrix holding L
     */
    @Override
    public INDArray getLFactor(INDArray A) {
        // FIXME: int cast
        if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }

        int m = (int) A.rows();
        int n = (int) A.columns();

        INDArray L = Nd4j.create(m, n);
        for (int r = 0; r < m; r++) {
            for (int c = 0; c < n; c++) {
                if (r > c) {
                    L.putScalar(r, c, A.getDouble(r, c));
                } else if (r < c) {
                    L.putScalar(r, c, 0.0);
                } else {
                    L.putScalar(r, c, 1.0);
                }
            }
        }
        return L;
    }


    /**
     * Extracts the upper-triangular U factor from a getrf-factorized matrix as an n x n matrix.
     * Rows beyond the input's row count are zero. Elements are copied in double precision, so
     * DOUBLE inputs are not truncated to float.
     *
     * @param A the factorized matrix
     * @return a new n x n matrix holding U
     */
    @Override
    public INDArray getUFactor(INDArray A) {
        // FIXME: int cast
        if (A.rows() > Integer.MAX_VALUE || A.columns() > Integer.MAX_VALUE) {
            throw new ND4JArraySizeException();
        }

        int m = (int) A.rows();
        int n = (int) A.columns();

        INDArray U = Nd4j.create(n, n);
        for (int r = 0; r < n; r++) {
            for (int c = 0; c < n; c++) {
                // r < m guards against reading past A when A has fewer rows than columns
                if (r <= c && r < m) {
                    U.putScalar(r, c, A.getDouble(r, c));
                } else {
                    U.putScalar(r, c, 0.0);
                }
            }
        }
        return U;
    }

}