import itertools
from tensornetwork.tn_keras import mpo
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras import Input
import numpy as np
import pytest


@pytest.mark.parametrize('in_dim_base,dim1,dim2,num_nodes,bond_dim',
                         itertools.product([3, 4], [3, 4], [2, 5], [3, 4],
                                           [2, 3]))
def test_shape_sanity_check(in_dim_base, dim1, dim2, num_nodes, bond_dim):
  """Two stacked DenseMPO layers map (batch, in_dim) to (batch, out_dim).

  Every feature size is written as ``base**num_nodes``, matching the output
  sizes passed to ``mpo.DenseMPO`` (presumably so each dimension splits evenly
  across the ``num_nodes`` MPO nodes -- confirm against ``mpo.DenseMPO``).

  Args:
    in_dim_base: Base of the input feature size (``in_dim_base**num_nodes``).
    dim1: Base of the first layer's output size (``dim1**num_nodes``).
    dim2: Base of the second layer's output size (``dim2**num_nodes``).
    num_nodes: Number of MPO nodes passed to each ``DenseMPO`` layer.
    bond_dim: Bond dimension passed to each ``DenseMPO`` layer.
  """
  # A single source of truth for the batch size, so the input array and the
  # expected output shape cannot drift apart.
  batch_size = 32
  input_dim = in_dim_base**num_nodes

  model = Sequential([
      # Keras documents `shape` as a tuple that excludes the batch dimension.
      Input(shape=(input_dim,)),
      mpo.DenseMPO(dim1**num_nodes, num_nodes=num_nodes, bond_dim=bond_dim),
      mpo.DenseMPO(dim2**num_nodes, num_nodes=num_nodes, bond_dim=bond_dim),
  ])

  result = model.predict(np.ones((batch_size, input_dim)))
  # The batch dimension must pass through unchanged; the feature dimension is
  # the second (last) layer's output size.
  assert result.shape == (batch_size, dim2**num_nodes)