# coding: utf-8
from __future__ import print_function

import numpy as np
from keras.models import Sequential
from keras.layers.recurrent import GRU


def understand_return_sequence():
    """用来帮助理解 recurrent layer 中的 return_sequences 参数"""
    model_1 = Sequential()
    model_1.add(GRU(input_dim=256, output_dim=256, return_sequences=True))
    model_1.compile(loss='mean_squared_error', optimizer='sgd')
    train_x = np.random.randn(100, 78, 256)
    train_y = np.random.randn(100, 78, 256)
    model_1.fit(train_x, train_y, verbose=0)

    model_2 = Sequential()
    model_2.add(GRU(input_dim=256, output_dim=256, return_sequences=False))
    model_2.compile(loss='mean_squared_error', optimizer='sgd')
    train_x = np.random.randn(100, 78, 256)
    train_y = np.random.randn(100, 256)
    model_2.fit(train_x, train_y, verbose=0)

    inz = np.random.randn(100, 78, 256)
    rez_1 = model_1.predict(inz, verbose=0)
    rez_2 = model_2.predict(inz, verbose=0)

    print()
    print('=========== understand return_sequences ================')
    print('Input shape is: {}'.format(inz.shape))
    print('Output shape of model with `return_sequences=True`: {}'.format(rez_1.shape))
    print('Output shape of model with `return_sequences=False`: {}'.format(rez_2.shape))
    print('====================== end =============================')
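

# [Added sketch, not part of the original demo; assumes the same Keras 1.x API
# used above] `return_sequences=False` returns only the hidden state of the
# last timestep, while `return_sequences=True` returns it for every timestep.
# One way to see the relation: give both models the same weights (the GRU
# weight shapes do not depend on `return_sequences`) and compare the outputs.
def check_return_sequence_relation():
    model_a = Sequential()
    model_a.add(GRU(input_dim=256, output_dim=256, return_sequences=True))
    model_a.compile(loss='mean_squared_error', optimizer='sgd')

    model_b = Sequential()
    model_b.add(GRU(input_dim=256, output_dim=256, return_sequences=False))
    model_b.compile(loss='mean_squared_error', optimizer='sgd')

    # Share the (randomly initialized) weights between the two models.
    model_b.set_weights(model_a.get_weights())

    inz = np.random.randn(10, 78, 256)
    full = model_a.predict(inz, verbose=0)  # shape (10, 78, 256): one vector per timestep
    last = model_b.predict(inz, verbose=0)  # shape (10, 256): only the last timestep
    print('`return_sequences=False` output equals last timestep of `=True`: {}'.format(
        np.allclose(full[:, -1, :], last, atol=1e-5)))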


def understand_variable_length_handle():
    """用来帮助理解如何用 recurrent layer 处理变长序列"""
    model = Sequential()
    model.add(GRU(input_dim=256, output_dim=256, return_sequences=True))
    model.compile(loss='mean_squared_error', optimizer='sgd')
    train_x = np.random.randn(100, 78, 256)
    train_y = np.random.randn(100, 78, 256)
    model.fit(train_x, train_y, verbose=0)

    inz_1 = np.random.randn(1, 78, 256)
    rez_1 = model.predict(inz_1, verbose=0)

    inz_2 = np.random.randn(1, 87, 256)
    rez_2 = model.predict(inz_2, verbose=0)

    print()
    print('=========== understand variable length =================')
    print('With `return_sequences=True`')
    print('Input shape is: {}, output shape is {}'.format(inz_1.shape, rez_1.shape))
    print('Input shape is: {}, output shape is {}'.format(inz_2.shape, rez_2.shape))
    print('====================== end =============================')
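

# [Added sketch, not part of the original demo] The reason one trained model can
# handle both 78-step and 87-step inputs is that a recurrent layer's weights only
# depend on the input and output dimensions, never on the number of timesteps:
# the same cell is applied step by step along whatever sequence it is given.
def show_recurrent_weight_shapes():
    model = Sequential()
    model.add(GRU(input_dim=256, output_dim=256, return_sequences=True))
    # None of the weight shapes involve the sequence length.
    print('GRU weight shapes: {}'.format([w.shape for w in model.get_weights()]))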


def try_variable_length_train():
    """变长序列训练实验

    实验失败,这样得到的 train_x 和 train_y 的 dtype 是 object 类型,
    取其 shape 得到的是 (100,) ,这将导致训练出错
    """
    model = Sequential()
    model.add(GRU(input_dim=256, output_dim=256, return_sequences=True))
    model.compile(loss='mean_squared_error', optimizer='sgd')

    train_x = []
    train_y = []
    for i in range(100):
        seq_length = np.random.randint(78, 87 + 1)
        sequence = []
        for _ in range(seq_length):
            sequence.append([np.random.randn() for _ in range(256)])

        train_x.append(np.array(sequence))
        train_y.append(np.array(sequence))

    train_x = np.array(train_x)
    train_y = np.array(train_y)

    model.fit(train_x, train_y)
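

# [Added sketch, not part of the original experiments] A common fix for the failure
# above is to pad every sequence to one common length and mask the padded steps, so
# that a single dense ndarray can be built. This assumes
# `keras.preprocessing.sequence.pad_sequences` and `keras.layers.Masking` behave as
# in Keras 1.x; treat it as an illustration, not the author's solution.
def try_variable_length_train_with_padding():
    from keras.layers import Masking
    from keras.preprocessing.sequence import pad_sequences

    maxlen = 87
    train_x = [np.random.randn(np.random.randint(78, 87 + 1), 256)
               for _ in range(100)]

    # Zero-pad at the end so every sample becomes (87, 256).
    padded_x = pad_sequences(train_x, maxlen=maxlen, dtype='float32',
                             padding='post', value=0.)
    padded_y = padded_x.copy()  # same autoencoder-style target as above

    model = Sequential()
    # Masking makes the GRU (and the loss) skip the all-zero padded timesteps.
    model.add(Masking(mask_value=0., input_shape=(maxlen, 256)))
    model.add(GRU(output_dim=256, return_sequences=True))
    model.compile(loss='mean_squared_error', optimizer='sgd')
    model.fit(padded_x, padded_y, verbose=0)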


def try_variable_length_train_in_batch():
    """变长序列训练实验(2)"""
    model = Sequential()
    model.add(GRU(input_dim=256, output_dim=256, return_sequences=True))
    model.compile(loss='mean_squared_error', optimizer='sgd')

    # Split the data into two batches; the sequence length differs between batches
    seq_lens = [78, 87]
    for i in range(2):
        train_x = np.random.randn(20, seq_lens[i], 256)
        train_y = np.random.randn(20, seq_lens[i], 256)
        model.train_on_batch(train_x, train_y)
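

# [Added sketch, not part of the original experiments] With real data the sequences
# rarely arrive pre-grouped by length. A simple extension of the batch trick above is
# to bucket the samples by length first and then feed each equal-length bucket with
# `train_on_batch`. `sequences` and `targets` are hypothetical lists of
# (timesteps, 256) arrays, not variables defined elsewhere in this file.
def train_in_length_buckets(model, sequences, targets):
    from collections import defaultdict

    buckets = defaultdict(list)
    for x, y in zip(sequences, targets):
        buckets[len(x)].append((x, y))

    for seq_len, pairs in buckets.items():
        batch_x = np.stack([x for x, _ in pairs])  # (n_samples, seq_len, 256)
        batch_y = np.stack([y for _, y in pairs])
        model.train_on_batch(batch_x, batch_y)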


if __name__ == '__main__':
    understand_return_sequence()
    understand_variable_length_handle()