import os
import tempfile
from unittest import TestCase
import numpy as np
from keras_gpt_2.backend import keras
from keras_gpt_2 import get_model, get_custom_objects


class TestModel(TestCase):

    def test_save_load(self):
        model = get_model(
            n_vocab=50257,
            n_ctx=1024,
            n_embd=768,
            n_head=12,
            n_layer=12,
        )
        model_path = os.path.join(tempfile.gettempdir(), 'test_gpt_2_%f.h5' % np.random.random())
        model.save(model_path)
        model = keras.models.load_model(model_path, custom_objects=get_custom_objects())
        model.summary()

    def test_fixed_input_shape(self):
        model = get_model(
            n_vocab=50257,
            n_ctx=1024,
            n_embd=768,
            n_head=12,
            n_layer=12,
            fixed_input_shape=True,
        )
        model_path = os.path.join(tempfile.gettempdir(), 'test_gpt_2_%f.h5' % np.random.random())
        model.save(model_path)
        model = keras.models.load_model(model_path, custom_objects=get_custom_objects())
        model.summary()