import torch
import numba
import numpy as np
from scipy.linalg import expm
from torch_geometric.utils import add_self_loops, is_undirected, to_dense_adj
from torch_sparse import coalesce
from torch_scatter import scatter_add


def jit():
    def decorator(func):
        try:
            return numba.jit(cache=True)(func)
        except RuntimeError:
            return numba.jit(cache=False)(func)

    return decorator


class GDC(object):
    r"""Processes the graph via Graph Diffusion Convolution (GDC) from the
    `"Diffusion Improves Graph Learning" <https://www.kdd.in.tum.de/gdc>`_
    paper.

    .. note::

        The paper offers additional advice on how to choose the
        hyperparameters.
        For an example of using GCN with GDC, see `examples/gcn.py
        <https://github.com/rusty1s/pytorch_geometric/blob/master/examples/
        gcn.py>`_.

    Args:
        self_loop_weight (float, optional): Weight of the added self-loop.
            Set to :obj:`None` to add no self-loops. (default: :obj:`1`)
        normalization_in (str, optional): Normalization of the transition
            matrix on the original (input) graph. Possible values:
            :obj:`"sym"`, :obj:`"col"`, and :obj:`"row"`.
            See :func:`GDC.transition_matrix` for details.
            (default: :obj:`"sym"`)
        normalization_out (str, optional): Normalization of the transition
            matrix on the transformed GDC (output) graph. Possible values:
            :obj:`"sym"`, :obj:`"col"`, :obj:`"row"`, and :obj:`None`.
            See :func:`GDC.transition_matrix` for details.
            (default: :obj:`"col"`)
        diffusion_kwargs (dict, optional): Dictionary containing the parameters
            for diffusion.
            `method` specifies the diffusion method (:obj:`"ppr"`,
            :obj:`"heat"` or :obj:`"coeff"`).
            Each diffusion method requires different additional parameters.
            See :func:`GDC.diffusion_matrix_exact` or
            :func:`GDC.diffusion_matrix_approx` for details.
            (default: :obj:`dict(method='ppr', alpha=0.15)`)
        sparsification_kwargs (dict, optional): Dictionary containing the
            parameters for sparsification.
            `method` specifies the sparsification method (:obj:`"threshold"` or
            :obj:`"topk"`).
            Each sparsification method requires different additional
            parameters.
            See :func:`GDC.sparsify_dense` for details.
            (default: :obj:`dict(method='threshold', avg_degree=64)`)
        exact (bool, optional): Whether to exactly calculate the diffusion
            matrix.
            Note that the exact variants are not scalable.
            They densify the adjacency matrix and calculate either its inverse
            or its matrix exponential.
            However, the approximate variants do not support edge weights and
            currently only personalized PageRank and sparsification by
            threshold are implemented as fast, approximate versions.
            (default: :obj:`True`)

    :rtype: :class:`torch_geometric.data.Data`
    """
    def __init__(self, self_loop_weight=1, normalization_in='sym',
                 normalization_out='col',
                 diffusion_kwargs=dict(method='ppr', alpha=0.15),
                 sparsification_kwargs=dict(method='threshold',
                                            avg_degree=64), exact=True):
        self.self_loop_weight = self_loop_weight
        self.normalization_in = normalization_in
        self.normalization_out = normalization_out
        self.diffusion_kwargs = diffusion_kwargs
        self.sparsification_kwargs = sparsification_kwargs
        self.exact = exact

        if self_loop_weight:
            assert exact or self_loop_weight == 1

    @torch.no_grad()
    def __call__(self, data):
        N = data.num_nodes
        edge_index = data.edge_index
        if data.edge_attr is None:
            edge_weight = torch.ones(edge_index.size(1),
                                     device=edge_index.device)
        else:
            edge_weight = data.edge_attr
            assert self.exact
            assert edge_weight.dim() == 1

        if self.self_loop_weight:
            edge_index, edge_weight = add_self_loops(
                edge_index, edge_weight, fill_value=self.self_loop_weight,
                num_nodes=N)

        edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N)

        if self.exact:
            edge_index, edge_weight = self.transition_matrix(
                edge_index, edge_weight, N, self.normalization_in)
            diff_mat = self.diffusion_matrix_exact(edge_index, edge_weight, N,
                                                   **self.diffusion_kwargs)
            edge_index, edge_weight = self.sparsify_dense(
                diff_mat, **self.sparsification_kwargs)
        else:
            edge_index, edge_weight = self.diffusion_matrix_approx(
                edge_index, edge_weight, N, self.normalization_in,
                **self.diffusion_kwargs)
            edge_index, edge_weight = self.sparsify_sparse(
                edge_index, edge_weight, N, **self.sparsification_kwargs)

        edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N)
        edge_index, edge_weight = self.transition_matrix(
            edge_index, edge_weight, N, self.normalization_out)

        data.edge_index = edge_index
        data.edge_attr = edge_weight

        return data

    def transition_matrix(self, edge_index, edge_weight, num_nodes,
                          normalization):
        r"""Calculate the approximate, sparse diffusion on a given sparse
        matrix.

        Args:
            edge_index (LongTensor): The edge indices.
            edge_weight (Tensor): One-dimensional edge weights.
            num_nodes (int): Number of nodes.
            normalization (str): Normalization scheme:

                1. :obj:`"sym"`: Symmetric normalization
                   :math:`\mathbf{T} = \mathbf{D}^{-1/2} \mathbf{A}
                   \mathbf{D}^{-1/2}`.
                2. :obj:`"col"`: Column-wise normalization
                   :math:`\mathbf{T} = \mathbf{A} \mathbf{D}^{-1}`.
                3. :obj:`"row"`: Row-wise normalization
                   :math:`\mathbf{T} = \mathbf{D}^{-1} \mathbf{A}`.
                4. :obj:`None`: No normalization.

        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        if normalization == 'sym':
            row, col = edge_index
            deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
            deg_inv_sqrt = deg.pow(-0.5)
            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
            edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
        elif normalization == 'col':
            _, col = edge_index
            deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
            deg_inv = 1. / deg
            deg_inv[deg_inv == float('inf')] = 0
            edge_weight = edge_weight * deg_inv[col]
        elif normalization == 'row':
            row, _ = edge_index
            deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
            deg_inv = 1. / deg
            deg_inv[deg_inv == float('inf')] = 0
            edge_weight = edge_weight * deg_inv[row]
        elif normalization is None:
            pass
        else:
            raise ValueError(
                'Transition matrix normalization {} unknown.'.format(
                    normalization))

        return edge_index, edge_weight

    def diffusion_matrix_exact(self, edge_index, edge_weight, num_nodes,
                               method, **kwargs):
        r"""Calculate the (dense) diffusion on a given sparse graph.
        Note that these exact variants are not scalable. They densify the
        adjacency matrix and calculate either its inverse or its matrix
        exponential.

        Args:
            edge_index (LongTensor): The edge indices.
            edge_weight (Tensor): One-dimensional edge weights.
            num_nodes (int): Number of nodes.
            method (str): Diffusion method:

                1. :obj:`"ppr"`: Use personalized PageRank as diffusion.
                   Additionally expects the parameter:

                   - **alpha** (*float*) - Return probability in PPR.
                     Commonly lies in :obj:`[0.05, 0.2]`.

                2. :obj:`"heat"`: Use heat kernel diffusion.
                   Additionally expects the parameter:

                   - **t** (*float*) - Time of diffusion. Commonly lies in
                     :obj:`[2, 10]`.

                3. :obj:`"coeff"`: Freely choose diffusion coefficients.
                   Additionally expects the parameter:

                   - **coeffs** (*List[float]*) - List of coefficients
                     :obj:`theta_k` for each power of the transition matrix
                     (starting at :obj:`0`).

        :rtype: (:class:`Tensor`)
        """
        if method == 'ppr':
            # α (I_n + (α - 1) A)^-1
            edge_weight = (kwargs['alpha'] - 1) * edge_weight
            edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                     fill_value=1,
                                                     num_nodes=num_nodes)
            mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze()
            diff_matrix = kwargs['alpha'] * torch.inverse(mat)

        elif method == 'heat':
            # exp(t (A - I_n))
            edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                     fill_value=-1,
                                                     num_nodes=num_nodes)
            edge_weight = kwargs['t'] * edge_weight
            mat = to_dense_adj(edge_index, edge_attr=edge_weight).squeeze()
            undirected = is_undirected(edge_index, edge_weight, num_nodes)
            diff_matrix = self.__expm__(mat, undirected)

        elif method == 'coeff':
            adj_matrix = to_dense_adj(edge_index,
                                      edge_attr=edge_weight).squeeze()
            mat = torch.eye(num_nodes, device=edge_index.device)

            diff_matrix = kwargs['coeffs'][0] * mat
            for coeff in kwargs['coeffs'][1:]:
                mat = mat @ adj_matrix
                diff_matrix += coeff * mat
        else:
            raise ValueError('Exact GDC diffusion {} unknown.'.format(method))

        return diff_matrix

    def diffusion_matrix_approx(self, edge_index, edge_weight, num_nodes,
                                normalization, method, **kwargs):
        r"""Calculate the approximate, sparse diffusion on a given sparse
        graph.

        Args:
            edge_index (LongTensor): The edge indices.
            edge_weight (Tensor): One-dimensional edge weights.
            num_nodes (int): Number of nodes.
            normalization (str): Transition matrix normalization scheme
                (:obj:`"sym"`, :obj:`"row"`, or :obj:`"col"`).
                See :func:`GDC.transition_matrix` for details.
            method (str): Diffusion method:

                1. :obj:`"ppr"`: Use personalized PageRank as diffusion.
                   Additionally expects the parameters:

                   - **alpha** (*float*) - Return probability in PPR.
                     Commonly lies in :obj:`[0.05, 0.2]`.

                   - **eps** (*float*) - Threshold for PPR calculation stopping
                     criterion (:obj:`edge_weight >= eps * out_degree`).
                     Recommended default: :obj:`1e-4`.

        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        if method == 'ppr':
            if normalization == 'sym':
                # Calculate original degrees.
                _, col = edge_index
                deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)

            edge_index_np = edge_index.cpu().numpy()
            # Assumes coalesced edge_index.
            _, indptr, out_degree = np.unique(edge_index_np[0],
                                              return_index=True,
                                              return_counts=True)

            neighbors, neighbor_weights = GDC.__calc_ppr__(
                indptr, edge_index_np[1], out_degree, kwargs['alpha'],
                kwargs['eps'])
            ppr_normalization = 'col' if normalization == 'col' else 'row'
            edge_index, edge_weight = self.__neighbors_to_graph__(
                neighbors, neighbor_weights, ppr_normalization,
                device=edge_index.device)
            edge_index = edge_index.to(torch.long)

            if normalization == 'sym':
                # We can change the normalization from row-normalized to
                # symmetric by multiplying the resulting matrix with D^{1/2}
                # from the left and D^{-1/2} from the right.
                # Since we use the original degrees for this it will be like
                # we had used symmetric normalization from the beginning
                # (except for errors due to approximation).
                row, col = edge_index
                deg_inv = deg.sqrt()
                deg_inv_sqrt = deg.pow(-0.5)
                deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
                edge_weight = deg_inv[row] * edge_weight * deg_inv_sqrt[col]
            elif normalization in ['col', 'row']:
                pass
            else:
                raise ValueError(
                    ('Transition matrix normalization {} not implemented for '
                     'non-exact GDC computation.').format(normalization))

        elif method == 'heat':
            raise NotImplementedError(
                ('Currently no fast heat kernel is implemented. You are '
                 'welcome to create one yourself, e.g., based on '
                 '"Kloster and Gleich: Heat kernel based community detection '
                 '(KDD 2014)."'))
        else:
            raise ValueError(
                'Approximate GDC diffusion {} unknown.'.format(method))

        return edge_index, edge_weight

    def sparsify_dense(self, matrix, method, **kwargs):
        r"""Sparsifies the given dense matrix.

        Args:
            matrix (Tensor): Matrix to sparsify.
            num_nodes (int): Number of nodes.
            method (str): Method of sparsification. Options:

                1. :obj:`"threshold"`: Remove all edges with weights smaller
                   than :obj:`eps`.
                   Additionally expects one of these parameters:

                   - **eps** (*float*) - Threshold to bound edges at.

                   - **avg_degree** (*int*) - If :obj:`eps` is not given,
                     it can optionally be calculated by calculating the
                     :obj:`eps` required to achieve a given :obj:`avg_degree`.

                2. :obj:`"topk"`: Keep edges with top :obj:`k` edge weights per
                   node (column).
                   Additionally expects the following parameters:

                   - **k** (*int*) - Specifies the number of edges to keep.

                   - **dim** (*int*) - The axis along which to take the top
                     :obj:`k`.

        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        assert matrix.shape[0] == matrix.shape[1]
        N = matrix.shape[1]

        if method == 'threshold':
            if 'eps' not in kwargs.keys():
                kwargs['eps'] = self.__calculate_eps__(matrix, N,
                                                       kwargs['avg_degree'])

            edge_index = torch.nonzero(matrix >= kwargs['eps']).t()
            edge_index_flat = edge_index[0] * N + edge_index[1]
            edge_weight = matrix.flatten()[edge_index_flat]

        elif method == 'topk':
            assert kwargs['dim'] in [0, 1]
            sort_idx = torch.argsort(matrix, dim=kwargs['dim'],
                                     descending=True)
            if kwargs['dim'] == 0:
                top_idx = sort_idx[:kwargs['k']]
                edge_weight = torch.gather(matrix, dim=kwargs['dim'],
                                           index=top_idx).flatten()

                row_idx = torch.arange(0, N, device=matrix.device).repeat(
                    kwargs['k'])
                edge_index = torch.stack([top_idx.flatten(), row_idx], dim=0)
            else:
                top_idx = sort_idx[:, :kwargs['k']]
                edge_weight = torch.gather(matrix, dim=kwargs['dim'],
                                           index=top_idx).flatten()

                col_idx = torch.arange(
                    0, N, device=matrix.device).repeat_interleave(kwargs['k'])
                edge_index = torch.stack([col_idx, top_idx.flatten()], dim=0)
        else:
            raise ValueError('GDC sparsification {} unknown.'.format(method))

        return edge_index, edge_weight

    def sparsify_sparse(self, edge_index, edge_weight, num_nodes, method,
                        **kwargs):
        r"""Sparsifies a given sparse graph further.

        Args:
            edge_index (LongTensor): The edge indices.
            edge_weight (Tensor): One-dimensional edge weights.
            num_nodes (int): Number of nodes.
            method (str): Method of sparsification:

                1. :obj:`"threshold"`: Remove all edges with weights smaller
                   than :obj:`eps`.
                   Additionally expects one of these parameters:

                   - **eps** (*float*) - Threshold to bound edges at.

                   - **avg_degree** (*int*) - If :obj:`eps` is not given,
                     it can optionally be calculated by calculating the
                     :obj:`eps` required to achieve a given :obj:`avg_degree`.

        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        if method == 'threshold':
            if 'eps' not in kwargs.keys():
                kwargs['eps'] = self.__calculate_eps__(edge_weight, num_nodes,
                                                       kwargs['avg_degree'])

            remaining_edge_idx = torch.nonzero(
                edge_weight >= kwargs['eps']).flatten()
            edge_index = edge_index[:, remaining_edge_idx]
            edge_weight = edge_weight[remaining_edge_idx]
        elif method == 'topk':
            raise NotImplementedError(
                'Sparse topk sparsification not implemented.')
        else:
            raise ValueError('GDC sparsification {} unknown.'.format(method))

        return edge_index, edge_weight

    def __expm__(self, matrix, symmetric):
        r"""Calculates matrix exponential.

        Args:
            matrix (Tensor): Matrix to take exponential of.
            symmetric (bool): Specifies whether the matrix is symmetric.

        :rtype: (:class:`Tensor`)
        """
        if symmetric:
            e, V = torch.symeig(matrix, eigenvectors=True)
            diff_mat = V @ torch.diag(e.exp()) @ V.t()
        else:
            diff_mat_np = expm(matrix.cpu().numpy())
            diff_mat = torch.Tensor(diff_mat_np).to(matrix.device)
        return diff_mat

    def __calculate_eps__(self, matrix, num_nodes, avg_degree):
        r"""Calculates threshold necessary to achieve a given average degree.

        Args:
            matrix (Tensor): Adjacency matrix or edge weights.
            num_nodes (int): Number of nodes.
            avg_degree (int): Target average degree.

        :rtype: (:class:`float`)
        """
        sorted_edges = torch.sort(matrix.flatten(), descending=True).values
        if avg_degree * num_nodes > len(sorted_edges):
            return -np.inf

        left = sorted_edges[avg_degree * num_nodes - 1]
        right = sorted_edges[avg_degree * num_nodes]
        return (left + right) / 2.0

    def __neighbors_to_graph__(self, neighbors, neighbor_weights,
                               normalization='row', device='cpu'):
        r"""Combine a list of neighbors and neighbor weights to create a sparse
        graph.

        Args:
            neighbors (List[List[int]]): List of neighbors for each node.
            neighbor_weights (List[List[float]]): List of weights for the
                neighbors of each node.
            normalization (str): Normalization of resulting matrix
                (options: :obj:`"row"`, :obj:`"col"`). (default: :obj:`"row"`)
            device (torch.device): Device to create output tensors on.
                (default: :obj:`"cpu"`)

        :rtype: (:class:`LongTensor`, :class:`Tensor`)
        """
        edge_weight = torch.Tensor(np.concatenate(neighbor_weights)).to(device)
        i = np.repeat(np.arange(len(neighbors)),
                      np.fromiter(map(len, neighbors), dtype=np.int))
        j = np.concatenate(neighbors)
        if normalization == 'col':
            edge_index = torch.Tensor(np.vstack([j, i])).to(device)
            N = len(neighbors)
            edge_index, edge_weight = coalesce(edge_index, edge_weight, N, N)
        elif normalization == 'row':
            edge_index = torch.Tensor(np.vstack([i, j])).to(device)
        else:
            raise ValueError(
                f"PPR matrix normalization {normalization} unknown.")
        return edge_index, edge_weight

    @staticmethod
    @jit()
    def __calc_ppr__(indptr, indices, out_degree, alpha, eps):
        r"""Calculate the personalized PageRank vector for all nodes
        using a variant of the Andersen algorithm
        (see Andersen et al. :Local Graph Partitioning using PageRank Vectors.)

        Args:
            indptr (np.ndarray): Index pointer for the sparse matrix
                (CSR-format).
            indices (np.ndarray): Indices of the sparse matrix entries
                (CSR-format).
            out_degree (np.ndarray): Out-degree of each node.
            alpha (float): Alpha of the PageRank to calculate.
            eps (float): Threshold for PPR calculation stopping criterion
                (:obj:`edge_weight >= eps * out_degree`).

        :rtype: (:class:`List[List[int]]`, :class:`List[List[float]]`)
        """
        alpha_eps = alpha * eps
        js = []
        vals = []
        for inode in range(len(out_degree)):
            p = {inode: 0.0}
            r = {}
            r[inode] = alpha
            q = [inode]
            while len(q) > 0:
                unode = q.pop()

                res = r[unode] if unode in r else 0
                if unode in p:
                    p[unode] += res
                else:
                    p[unode] = res
                r[unode] = 0
                for vnode in indices[indptr[unode]:indptr[unode + 1]]:
                    _val = (1 - alpha) * res / out_degree[unode]
                    if vnode in r:
                        r[vnode] += _val
                    else:
                        r[vnode] = _val

                    res_vnode = r[vnode] if vnode in r else 0
                    if res_vnode >= alpha_eps * out_degree[vnode]:
                        if vnode not in q:
                            q.append(vnode)
            js.append(list(p.keys()))
            vals.append(list(p.values()))
        return js, vals

    def __repr__(self):
        return '{}()'.format(self.__class__.__name__)