import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Input

from spektral.layers import GraphConv, ChebConv, EdgeConditionedConv, GraphAttention, \
    GraphConvSkip, ARMAConv, APPNP, GraphSageConv, GINConv, DiffusionConv, \
    GatedGraphConv, AGNNConv, TAGConv, CrystalConv, MessagePassing, EdgeConv
from spektral.layers.ops import sp_matrix_to_sp_tensor

# Run all tests in float64.
tf.keras.backend.set_floatx('float64')
SINGLE, BATCH, MIXED = 1, 2, 3  # Single, batch, mixed
LAYER_K_, MODES_K_, KWARGS_K_ = 'layer', 'modes', 'kwargs'  # keys of each TESTS entry
batch_size = 32  # number of graphs per batch (batch and mixed modes)
N = 11  # number of nodes
F = 7  # size of the node features
S = 3  # size of the edge features
A = np.ones((N, N))  # dense adjacency matrix (fully connected graph)
X = np.random.normal(size=(N, F))  # node features
E = np.random.normal(size=(N, N, S))  # edge features, dense (N, N, S) layout
E_single = np.random.normal(size=(N * N, S))  # edge features, flattened layout for single mode


"""
Each entry in TESTS represents a test to be run for a particular Layer.
Each config dictionary has the form: 
{
    LAYER_K_: LayerClass,
    MODES_K_: [...],
    KWARGS_K_: {...},
},

LAYER_K_ is the class of the layer to be tested.
 
MODES_K_ is a list containing the data modes supported by the model, and should 
be at least one of: SINGLE, MIXED, BATCH. 

KWARGS_K_ is a dictionary containing: 
    - all keywords to be passed to the layer (including mandatory ones);
    - an optional entry 'edges': True if the layer supports edge attributes; 
    - an optional entry 'sparse': [...], indicating whether the layer supports 
    sparse or dense inputs as a bool (e.g., 'sparse': [False, True] will 
    test the layer on both dense and sparse adjacency matrix; 'sparse': [True] 
    will only test for sparse). By default, each layer is tested only on dense
    inputs. Batch mode only tests for dense inputs. 

The testing loop will create a simple 1-layer model and run it in single, mixed, 
and batch mode according to what is specified in MODES_K_ in the testing config. 
The loop will check: 
    - that the model does not crash; 
    - that the output shape is pre-computed correctly; 
    - that the real output shape is correct; 
    - that the get_config() method works correctly (i.e., it is possible to 
    re-instantiate a layer using LayerClass(**layer_instance.get_config())).
"""

# One config dict per layer under test; see the docstring above for the
# meaning of the LAYER_K_, MODES_K_ and KWARGS_K_ entries.
TESTS = [
    {
        LAYER_K_: GraphConv,
        MODES_K_: [SINGLE, BATCH, MIXED],
        KWARGS_K_: {'channels': 8, 'activation': 'relu', 'sparse': [False, True]},
    },
    {
        LAYER_K_: ChebConv,
        MODES_K_: [SINGLE, BATCH, MIXED],
        KWARGS_K_: {'K': 3, 'channels': 8, 'activation': 'relu', 'sparse': [False, True]}
    },
    {
        LAYER_K_: GraphSageConv,
        MODES_K_: [SINGLE],
        KWARGS_K_: {'channels': 8, 'activation': 'relu', 'sparse': [False, True]}
    },
    {
        LAYER_K_: EdgeConditionedConv,
        MODES_K_: [SINGLE, BATCH],
        KWARGS_K_: {'kernel_network': [8], 'channels': 8, 'activation': 'relu',
                    'edges': True, 'sparse': [False, True]}
    },
    {
        LAYER_K_: GraphAttention,
        MODES_K_: [SINGLE, BATCH, MIXED],
        KWARGS_K_: {'channels': 8, 'attn_heads': 2, 'concat_heads': False,
                    'activation': 'relu', 'sparse': [False, True]}
    },
    {
        LAYER_K_: GraphConvSkip,
        MODES_K_: [SINGLE, BATCH, MIXED],
        KWARGS_K_: {'channels': 8, 'activation': 'relu', 'sparse': [False, True]}
    },
    {
        LAYER_K_: ARMAConv,
        MODES_K_: [SINGLE, BATCH, MIXED],
        KWARGS_K_: {'channels': 8, 'activation': 'relu', 'order': 2, 'iterations': 2,
                    'share_weights': True, 'sparse': [False, True]}
    },
    {
        LAYER_K_: APPNP,
        MODES_K_: [SINGLE, BATCH, MIXED],
        KWARGS_K_: {'channels': 8, 'activation': 'relu', 'mlp_hidden': [16],
                    'sparse': [False, True]}
    },
    {
        LAYER_K_: GINConv,
        MODES_K_: [SINGLE],
        KWARGS_K_: {'channels': 8, 'activation': 'relu', 'mlp_hidden': [16],
                    'sparse': [True]}
    },
    {
        LAYER_K_: DiffusionConv,
        MODES_K_: [SINGLE, BATCH, MIXED],
        KWARGS_K_: {'channels': 8, 'activation': 'tanh', 'num_diffusion_steps': 5,
                    'sparse': [False]}
    },
    {
        LAYER_K_: GatedGraphConv,
        MODES_K_: [SINGLE],
        KWARGS_K_: {'channels': 10, 'n_layers': 3, 'sparse': [True]}
    },
    {
        LAYER_K_: AGNNConv,
        MODES_K_: [SINGLE],
        KWARGS_K_: {'channels': F, 'trainable': True, 'sparse': [True]}
    },
    {
        LAYER_K_: TAGConv,
        MODES_K_: [SINGLE],
        KWARGS_K_: {'channels': F, 'K': 3, 'sparse': [True]}
    },
    {
        LAYER_K_: CrystalConv,
        MODES_K_: [SINGLE],
        KWARGS_K_: {'channels': F, 'edges': True, 'sparse': [True]}
    },
    {
        LAYER_K_: EdgeConv,
        MODES_K_: [SINGLE],
        KWARGS_K_: {'channels': 8, 'activation': 'relu', 'mlp_hidden': [16],
                    'sparse': [True]}
    },
    {
        LAYER_K_: MessagePassing,
        MODES_K_: [SINGLE],
        KWARGS_K_: {'channels': F, 'sparse': [True]}
    },
]


def _test_single_mode(layer, **kwargs):
    """Build a one-layer model in single mode and check its output shape.

    Harness-only kwargs: 'sparse' (bool, use a sparse adjacency input) and
    'edges' (truthy if the layer takes edge attributes); both are stripped
    before the remaining kwargs are forwarded to the layer constructor.
    """
    use_sparse = kwargs.pop('sparse', False)
    adj_in = Input(shape=(None,), sparse=use_sparse)
    feat_in = Input(shape=(F,))
    model_inputs = [feat_in, adj_in]
    adj_data = sp_matrix_to_sp_tensor(A) if use_sparse else A
    feed = [X, adj_data]

    if kwargs.pop('edges', None):
        # Single mode uses the flattened (N * N, S) edge-feature layout.
        model_inputs.append(Input(shape=(S, )))
        feed.append(E_single)

    layer_instance = layer(**kwargs)
    model = Model(model_inputs, layer_instance(model_inputs))

    result = model(feed)

    assert result.shape == (N, kwargs['channels'])


def _test_batch_mode(layer, **kwargs):
    """Build a one-layer model in batch mode and check its output shape.

    The single-graph fixtures are replicated batch_size times; the 'edges'
    harness kwarg is stripped and, if truthy, a dense (N, N, S) edge-feature
    input is added.
    """
    batch_A = np.stack([A] * batch_size)
    batch_X = np.stack([X] * batch_size)

    adj_in = Input(shape=(N, N))
    feat_in = Input(shape=(N, F))
    model_inputs = [feat_in, adj_in]
    feed = [batch_X, batch_A]

    if kwargs.pop('edges', None):
        model_inputs.append(Input(shape=(N, N, S)))
        feed.append(np.stack([E] * batch_size))

    layer_instance = layer(**kwargs)
    model = Model(model_inputs, layer_instance(model_inputs))

    result = model(feed)

    assert result.shape == (batch_size, N, kwargs['channels'])


def _test_mixed_mode(layer, **kwargs):
    """Build a one-layer model in mixed mode (one shared adjacency matrix,
    a batch of node features) and check its output shape.

    The 'sparse' harness kwarg selects a sparse adjacency input and is
    stripped before the remaining kwargs reach the layer constructor.
    """
    use_sparse = kwargs.pop('sparse', False)
    batch_X = np.stack([X] * batch_size)
    adj_in = Input(shape=(N,), sparse=use_sparse)
    feat_in = Input(shape=(N, F))
    model_inputs = [feat_in, adj_in]
    adj_data = sp_matrix_to_sp_tensor(A) if use_sparse else A
    feed = [batch_X, adj_data]

    layer_instance = layer(**kwargs)
    model = Model(model_inputs, layer_instance(model_inputs))

    result = model(feed)

    assert result.shape == (batch_size, N, kwargs['channels'])


def _test_get_config(layer, **kwargs):
    if kwargs.get('edges'):
        kwargs.pop('edges')
    layer_instance = layer(**kwargs)
    config = layer_instance.get_config()
    assert layer(**config)


def test_layers():
    """Run every config in TESTS in each of its declared data modes.

    The per-test kwargs are copied before each call so that popping the
    harness-only 'sparse' key never mutates the shared TESTS entry. The
    original in-place pop removed 'sparse' during SINGLE mode, which made
    the MIXED branch's sparse loop dead code: mixed-mode sparse variants
    were silently skipped for every layer that also listed SINGLE.
    """
    for test in TESTS:
        for mode in test[MODES_K_]:
            kwargs = dict(test[KWARGS_K_])  # fresh copy: never mutate TESTS
            sparse_options = kwargs.pop('sparse', [False])
            if mode == SINGLE:
                for s in sparse_options:
                    _test_single_mode(test[LAYER_K_], sparse=s, **kwargs)
            elif mode == BATCH:
                # Batch mode only supports dense inputs; 'sparse' is dropped.
                _test_batch_mode(test[LAYER_K_], **kwargs)
            elif mode == MIXED:
                for s in sparse_options:
                    _test_mixed_mode(test[LAYER_K_], sparse=s, **kwargs)
        # Strip harness-only keys before the get_config round-trip check,
        # since they are not constructor kwargs of the layers.
        config_kwargs = {k: v for k, v in test[KWARGS_K_].items()
                         if k not in ('sparse', 'edges')}
        _test_get_config(test[LAYER_K_], **config_kwargs)