#!/usr/bin/python3
#-*-coding:utf-8-*-
#$File: model.py
#$Date: Sat May  7 10:59:45 2016
#$Author: Like Ma <milkpku[at]gmail[dot]com>

from config import Config

import tensorflow as tf
import functools

from util.model import Model, conv2d

def get_model(name):
    name = functools.partial('{}-{}'.format, name)

    self_pos = tf.placeholder(Config.dtype, Config.data_shape, name='self_pos')
    self_ability = tf.placeholder(Config.dtype, Config.data_shape, name='self_ability')
    enemy_pos = tf.placeholder(Config.dtype, Config.data_shape, name='enemy_pos')
    input_label = tf.placeholder(Config.dtype, Config.label_shape, name='input_label')

    x = tf.concat(3, [self_pos, self_ability, enemy_pos], name=name('input_concat'))
    y = input_label

    nl = tf.nn.tanh

    def conv_pip(name, x, nl):
        name = functools.partial('{}_{}'.format, name)

        x = conv2d(name('0'), x, Config.data_shape[3]*2, kernel=3, stride=1, nl=nl)
        x = conv2d(name('1'), x, Config.data_shape[3], kernel=3, stride=1, nl=nl)
        return x

    for layer in range(5):
        x_branch = conv_pip(name('conv%d'%layer), x, nl)
        x = tf.concat(3, [x,x_branch], name=name('concate%d'%layer))

    x = conv_pip(name('conv5'), x, nl=None)
    pred = tf.sigmoid(x)

    # another formula of y*logy
    loss = -tf.log(tf.reduce_sum(tf.mul(x, y), reduction_indices=[1,2,3]))
    loss += - 0.1 * tf.log(tf.reduce_sum(tf.mul(x, self_ability), reduction_indices=[1,2,3]))
    pred = tf.mul(pred, self_ability)

    return Model([self_pos, self_ability, enemy_pos], input_label, loss, pred)

if __name__=='__main__':

    model = get_model('test')
    sess = tf.InteractiveSession()
    sess.run(tf.initialize_all_variables())

    import numpy as np
    x_data = np.random.randint(2, size=[3,100,9,10,16]).astype('float32')
    y_data = np.random.randint(2, size=[100,9,10,16]).astype('float32')

    input_dict = {}
    for var, data in zip(model.inputs, x_data):
        input_dict[var] = data
    input_dict[model.label] = y_data

    loss_val = model.loss.eval(feed_dict=input_dict)
    pred_val = model.pred.eval(feed_dict=input_dict)
    print(loss_val)
    # print(pred_val)

    pred_val = pred_val.reshape(pred_val.shape[0], -1)
    assert all(abs(pred_val.sum(axis=1)-1.0<1e-6))
    print('model test OK')