import numpy as np
import scipy.sparse.linalg as sla
import scipy.sparse as sp

try:
    import torch
    from torch.autograd import Variable
except:
    pass


import re
from abc import ABCMeta, abstractmethod


def block(rows, dtype=None, arrtype=None):
    """Assemble a single matrix from a 2-D grid of blocks.

    Each element of ``rows`` may be:
      * an array/tensor/operator already in the selected backend's format,
      * a Python ``int``/``float`` (expanded to a constant block),
      * the string ``'I'`` or ``'-I'`` (a square identity / negated identity),
      * any other object the backend's ``convert`` method accepts.

    Block-row heights and block-column widths are inferred from the elements
    that carry a shape. A block-row or block-column whose size stays 0 (no
    element in it has a shape) is dropped from the result.

    Args:
        rows: non-empty list/tuple of equal-length lists/tuples of blocks.
        dtype: optional element type forwarded to the backend.
        arrtype: optional result type used to select the backend.

    Returns:
        Whatever the selected backend's ``build`` produces.

    Raises:
        RuntimeError: if ``rows`` is not a non-empty list of lists, the rows
            have different lengths, an identity block is not square, or a
            string block is neither ``'I'`` nor ``'-I'``.
    """
    if (not _is_list_or_tup(rows)) or len(rows) == 0 or \
       np.any([not _is_list_or_tup(row) for row in rows]):
        raise RuntimeError('''
Unexpected input: Expected a non-empty list of lists.
If you are interested in helping expand the functionality
for your use case please send in an issue or PR at
http://github.com/bamos/block''')

    rowLens = [len(row) for row in rows]
    if len(np.unique(rowLens)) > 1:
        raise RuntimeError('''
Unexpected input: Rows are not the same length.
Row lengths: {}'''.format(rowLens))

    nRows = len(rows)
    nCols = rowLens[0]
    rowSizes = np.zeros(nRows, dtype=int)
    colSizes = np.zeros(nCols, dtype=int)

    backend = _get_backend(rows, dtype, arrtype)

    # Scalars and 'I'/'-I' strings carry no shape; they take their size from
    # the other shaped blocks in the same block-row / block-column.
    for i, row in enumerate(rows):
        for j, elem in enumerate(row):
            shape = _block_shape(backend, elem)
            if shape is not None:
                rowSz, colSz = shape
                rowSizes[i] = rowSz
                colSizes[j] = colSz

    cRows = []
    for row, rowSz in zip(rows, rowSizes):
        rowSz = int(rowSz)
        if rowSz == 0:
            continue
        cCol = []
        for elem, colSz in zip(row, colSizes):
            colSz = int(colSz)
            if colSz == 0:
                continue
            cCol.append(_build_block(backend, elem, rowSz, colSz))
        cRows.append(cCol)

    return backend.build(cRows)


def _block_shape(backend, elem):
    """Return the ``(rows, cols)`` shape of ``elem``, or None if it has none."""
    if backend.is_complete(elem):
        return backend.extract_shape(elem)
    if hasattr(elem, 'shape'):
        return elem.shape
    if hasattr(elem, 'size'):
        # torch-style objects expose their shape through a size() method.
        return elem.size()
    return None


def _build_block(backend, elem, rowSz, colSz):
    """Materialize ``elem`` as a ``rowSz x colSz`` block in ``backend``'s format."""
    if backend.is_complete(elem):
        return elem
    if isinstance(elem, (float, int)):
        return backend.build_full((rowSz, colSz), elem)
    if isinstance(elem, str):
        if elem not in ('I', '-I'):
            raise RuntimeError(
                "Unexpected string block {!r}: only 'I' and '-I' are "
                "supported.".format(elem))
        # Explicit raise rather than assert so the check survives `python -O`.
        if rowSz != colSz:
            raise RuntimeError(
                "Identity block {!r} must be square, but its position "
                "requires shape ({}, {}).".format(elem, rowSz, colSz))
        eye = backend.build_eye(rowSz)
        return eye if elem == 'I' else -eye
    return backend.convert(elem)


def block_diag(elems, dtype=None, arrtype=None):
    """Build a block-diagonal matrix with ``elems`` along the diagonal.

    Every off-diagonal position is the scalar 0, which `block` expands into a
    zero block sized to match its block-row and block-column.

    Args:
        elems: sequence of diagonal blocks (anything `block` accepts).
        dtype: forwarded to `block`.
        arrtype: forwarded to `block`.
    """
    size = len(elems)
    layout = []
    for idx, diag in enumerate(elems):
        layout.append([diag if col == idx else 0 for col in range(size)])
    return block(layout, dtype=dtype, arrtype=arrtype)


def block_tridiag(main, upper, lower, dtype=None, arrtype=None):
    """Build a block-tridiagonal matrix.

    Args:
        main: the ``n`` diagonal blocks; ``main[i]`` goes at block (i, i).
        upper: the ``n - 1`` super-diagonal blocks; ``upper[i]`` goes at
            block (i, i + 1).
        lower: the ``n - 1`` sub-diagonal blocks; ``lower[i]`` goes at
            block (i + 1, i).
        dtype: forwarded to `block`.
        arrtype: forwarded to `block`.

    All remaining positions are the scalar 0 (a zero block).

    Raises:
        RuntimeError: if ``upper`` or ``lower`` does not have exactly
            ``len(main) - 1`` entries.
    """
    n = len(main)
    # Explicit raise rather than assert so the check survives `python -O`.
    if len(upper) != n - 1 or len(lower) != n - 1:
        raise RuntimeError(
            'block_tridiag expects len(upper) == len(lower) == len(main) - 1; '
            'got len(main)={}, len(upper)={}, len(lower)={}'.format(
                n, len(upper), len(lower)))

    rows = []
    for i in range(n):
        row = [0] * n
        row[i] = main[i]
        if i + 1 < n:
            # Previously upper[-i], which scrambles the blocks for n >= 4.
            row[i + 1] = upper[i]
        if i > 0:
            row[i - 1] = lower[i - 1]
        rows.append(row)
    return block(rows, dtype=dtype, arrtype=arrtype)

def _is_list_or_tup(x):
    return isinstance(x, list) or isinstance(x, tuple)


def _get_backend(rows, dtype, arrtype):
    """Choose the Backend used to materialize and assemble ``rows``.

    An explicit ``arrtype`` selects the backend directly: ``np.ndarray`` (only
    when ``dtype`` is also given), ``sla.LinearOperator``, a torch Tensor type,
    or a torch Variable/Parameter type. Otherwise ``rows`` is scanned row by
    row, and the first element that is a non-empty numpy array, a 2-D torch
    tensor, a LinearOperator, or a torch Variable/Parameter decides.

    Raises:
        RuntimeError: if no backend can be determined.
    """
    if arrtype == np.ndarray and dtype is not None:
        # NumpyBackend's signature is (dtype, arrtype); the arguments were
        # previously passed in swapped order.
        return NumpyBackend(dtype, arrtype)
    if arrtype == sla.LinearOperator:
        return LinearOperatorBackend(dtype)
    if arrtype is not None and re.search(r'torch\..*Tensor', repr(arrtype)):
        return TorchBackend(dtype)
    if arrtype is not None and \
            re.search(r'torch\..*(Variable|Parameter)', repr(arrtype)):
        return TorchVariableBackend(dtype)

    # Probe instances used only for their is_complete() type checks.
    npb = NumpyBackend()
    tb = TorchBackend()
    lob = LinearOperatorBackend()
    tvb = TorchVariableBackend()
    for row in rows:
        for elem in row:
            if npb.is_complete(elem) and elem.size > 0:
                if dtype is None:
                    dtype = type(elem[0, 0])
                if arrtype is None:
                    arrtype = type(elem)
                return NumpyBackend(dtype, arrtype)
            elif tb.is_complete(elem):
                return TorchBackend(type(elem))
            elif lob.is_complete(elem):
                return LinearOperatorBackend(elem.dtype)
            elif tvb.is_complete(elem):
                return TorchVariableBackend(type(elem.data))

    # Explicit raise: the previous `assert(False)` silently returned None
    # under `python -O`.
    raise RuntimeError('''
Unexpected input: Could not infer the array type from the blocks.
Include at least one non-empty numpy array, 2-D torch tensor,
torch Variable, or scipy LinearOperator, or pass arrtype
(and dtype, for numpy) explicitly.''')


class Backend(metaclass=ABCMeta):
    """Interface for the array libraries `block` can assemble into.

    A backend measures blocks it already owns, synthesizes identity and
    constant blocks, converts foreign blocks, and concatenates a grid of
    blocks into a single result.

    ``metaclass=ABCMeta`` replaces the Python 2 ``__metaclass__`` attribute,
    which Python 3 ignores; without it the abstract methods below were never
    enforced.
    """

    @abstractmethod
    def extract_shape(self, x):
        """Return the ``(rows, cols)`` shape of block ``x``."""

    @abstractmethod
    def build_eye(self, n):
        """Return an ``n x n`` identity block."""

    @abstractmethod
    def build_full(self, shape, fill_val):
        """Return a block of the given ``shape`` filled with ``fill_val``."""

    @abstractmethod
    def convert(self, x):
        """Convert a foreign block ``x`` into this backend's format."""

    @abstractmethod
    def build(self, rows):
        """Concatenate a list of block-rows into a single result."""

    @abstractmethod
    def is_complete(self, x):
        """Return whether ``x`` is already a block in this backend's format."""


class NumpyBackend(Backend):
    """Backend that assembles numpy arrays with ``np.bmat``."""

    def __init__(self, dtype=None, arrtype=None):
        # dtype for synthesized (identity / constant) blocks; None lets numpy
        # choose its default.
        self.dtype = dtype
        # Stored for callers; not read by any method of this class.
        self.arrtype = arrtype

    def extract_shape(self, x):
        return x.shape

    def build_eye(self, n):
        # Honor the requested dtype the same way build_full does; np.eye
        # treats dtype=None as float64, so the default is unchanged.
        return np.eye(n, dtype=self.dtype)

    def build_full(self, shape, fill_val):
        return np.full(shape, fill_val, self.dtype)

    def convert(self, x):
        # Explicit raise: `assert(False)` returned None under `python -O`.
        raise RuntimeError(
            'NumpyBackend cannot convert a block of type {}'.format(type(x)))

    def build(self, rows):
        # np.bmat stitches the 2-D grid of blocks into one np.matrix.
        return np.bmat(rows)

    def is_complete(self, x):
        return isinstance(x, np.ndarray)


class TorchBackend(Backend):
    """Backend that assembles 2-D torch tensors with ``torch.cat``."""

    def __init__(self, dtype=None):
        # Tensor type handed to ``Tensor.type`` for synthesized blocks; None
        # keeps whatever type torch creates by default.
        self.dtype = dtype

    def _cast(self, t):
        """Cast tensor ``t`` to ``self.dtype`` when one was requested."""
        # ``Tensor.type(None)`` returns the type's *name* (a str) rather than
        # a tensor, so it must not be called when no dtype was given.
        if self.dtype is None:
            return t
        return t.type(self.dtype)

    def extract_shape(self, x):
        return x.size()

    def build_eye(self, n):
        return self._cast(torch.eye(n))

    def build_full(self, shape, fill_val):
        return fill_val * self._cast(torch.ones(*shape))

    def convert(self, x):
        # Explicit raise: `assert(False)` returned None under `python -O`.
        raise RuntimeError(
            'TorchBackend cannot convert a block of type {}'.format(type(x)))

    def build(self, rows):
        # Concatenate each block-row along columns, then stack the rows.
        return torch.cat([torch.cat(row, 1) for row in rows])

    def is_complete(self, x):
        return (re.search(r'torch\..*Tensor', str(x.__class__)) is not None) \
            and x.ndimension() == 2


class TorchVariableBackend(TorchBackend):
    """TorchBackend variant whose synthesized blocks are wrapped in ``Variable``."""

    def build_eye(self, n):
        return Variable(super().build_eye(n))

    def build_full(self, shape, fill_val):
        return Variable(super().build_full(shape, fill_val))

    def convert(self, x):
        # Plain 2-D tensors can be wrapped; anything else is unsupported.
        if TorchBackend.is_complete(self, x):
            return Variable(x)
        # Explicit raise: `assert(False)` returned None under `python -O`.
        raise RuntimeError(
            'TorchVariableBackend cannot convert a block of type {}'.format(
                type(x)))

    def is_complete(self, x):
        # Return a real bool rather than the re.Match object.
        return re.search(r'torch\..*(Variable|Parameter)',
                         str(x.__class__)) is not None


class LinearOperatorBackend(Backend):
    """Backend that assembles scipy ``LinearOperator`` blocks matrix-free.

    Zero blocks are represented by their ``(rows, cols)`` shape tuple instead
    of an operator. ``build`` uses the tuple only for sizing and skips it when
    the assembled operator is applied.
    """

    def __init__(self, dtype=None):
        # Passed to every LinearOperator created here; None lets scipy infer it.
        self.dtype = dtype

    def extract_shape(self, x):
        return x.shape

    def build_eye(self, n):
        def identity(v): return v
        return sla.LinearOperator(shape=(n, n),
                                  matvec=identity,
                                  rmatvec=identity,
                                  matmat=identity,
                                  dtype=self.dtype)

    def build_full(self, shape, fill_val):
        m, n = shape
        if fill_val == 0:
            # Placeholder for a zero block; see build().
            return shape

        def matvec(v):
            # (c * 1_m 1_n^T) v = c * sum(v) * 1_m
            return v.sum() * fill_val * np.ones(m)

        def rmatvec(v):
            # rmatvec is the adjoint, so the fill value is conjugated
            # (a no-op for real values).
            return v.sum() * np.conj(fill_val) * np.ones(n)

        def matmat(M):
            # Each output column is c * (column sum of M) * 1_m.
            return M.sum(axis=0) * fill_val * np.ones((m, M.shape[1]))

        return sla.LinearOperator(shape=shape,
                                  matvec=matvec,
                                  rmatvec=rmatvec,
                                  matmat=matmat,
                                  dtype=self.dtype)

    def convert(self, x):
        # sp.issparse also accepts the newer scipy sparse *array* classes,
        # which are not sp.spmatrix subclasses.
        if isinstance(x, np.ndarray) or sp.issparse(x):
            return sla.aslinearoperator(x)
        # Explicit raise: `assert(False)` returned None under `python -O`.
        raise RuntimeError(
            'LinearOperatorBackend cannot convert a block of type {}'.format(
                type(x)))

    def build(self, rows):
        def height(lo):
            return lo.shape[0] if self.is_complete(lo) else lo[0]

        def width(lo):
            return lo.shape[1] if self.is_complete(lo) else lo[1]

        col_sizes = [width(lo) for lo in rows[0]]
        col_idxs = np.cumsum([0] + col_sizes)
        row_sizes = [height(row[0]) for row in rows]
        row_idxs = np.cumsum([0] + row_sizes)
        m, n = sum(row_sizes), sum(col_sizes)

        # The transposed grid. zip() is a one-shot iterator in Python 3, so
        # it must be materialized; otherwise every rmatvec call after the
        # first would iterate over nothing and return zeros.
        cols = list(zip(*rows))

        def matvec(v):
            # LinearOperator.matvec accepts (N,) or (N, 1); work on 1-D.
            v = np.asarray(v).ravel()
            out = np.zeros(m)
            for row, i, j in zip(rows, row_idxs[:-1], row_idxs[1:]):
                out[i:j] = sum(lo.matvec(v[k:l]) for lo, k, l in
                               zip(row, col_idxs[:-1], col_idxs[1:])
                               if self.is_complete(lo))
            return out

        def rmatvec(v):
            v = np.asarray(v).ravel()
            out = np.zeros(n)
            for col, i, j in zip(cols, col_idxs[:-1], col_idxs[1:]):
                out[i:j] = sum(lo.rmatvec(v[k:l]) for lo, k, l in
                               zip(col, row_idxs[:-1], row_idxs[1:])
                               if self.is_complete(lo))
            return out

        def matmat(M):
            out = np.zeros((m, M.shape[1]))
            for row, i, j in zip(rows, row_idxs[:-1], row_idxs[1:]):
                out[i:j] = sum(lo.matmat(M[k:l]) for lo, k, l in
                               zip(row, col_idxs[:-1], col_idxs[1:])
                               if self.is_complete(lo))
            return out

        return sla.LinearOperator(shape=(m, n),
                                  matvec=matvec,
                                  rmatvec=rmatvec,
                                  matmat=matmat,
                                  dtype=self.dtype)

    def is_complete(self, x):
        return isinstance(x, sla.LinearOperator)