#!/usr/bin/env python
# coding: utf-8
# @Author: lapis-hong
# @Date  : 2018/8/13
"""This module contains efficient data read and transform using tf.data API.

Data iterator for triplets (h, t, r) 
and corrupt sampling (with either the head or tail replaced by a random entity).


Input format:
    Train data: data file
        each line contains (h, t, r) triples separated by '\t'
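
    Example line (hypothetical, for illustration only):
        Beijing\tChina\tcapital_of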
"""
import collections

import tensorflow as tf


class BatchedInput(
    collections.namedtuple(
        "BatchedInput", ("initializer", "h", "t", "r", "h_neg", "t_neg"))):
    pass


def _parse(line):
    """Parse one line of train data into (h, t, r) string tensors."""
    cols_types = [[''], [''], ['']]  # three string columns
    return tf.decode_csv(line, record_defaults=cols_types, field_delim='\t')


def get_iterator(data_file, entity, entity_table, relation_table, batch_size,
                 shuffle_buffer_size=None):
    """Iterator for train and eval.

    Args:
        data_file: data file, each line contains one (h, t, r) triple.
        entity: list or tuple of all entities.
        entity_table: entity tf look-up table (entity string -> id).
        relation_table: relation tf look-up table (relation string -> id).
        batch_size: number of triples per batch.
        shuffle_buffer_size: buffer size for shuffle, defaults to
            batch_size * 1000.

    Returns:
        BatchedInput instance
    """
    shuffle_buffer_size = shuffle_buffer_size or batch_size * 1000

    dataset = tf.data.TextLineDataset(data_file)
    dataset = dataset.map(_parse, num_parallel_calls=4)
    dataset = dataset.shuffle(shuffle_buffer_size)

    # Corrupt sampling: replace either the head or the tail with a random
    # entity. TF ops are required so that a fresh choice is drawn for every
    # example; Python's `random` runs only once, when the map function is
    # traced, which would bake a single constant corruption into the graph.
    entity_tensor = tf.constant(entity)
    num_entities = tf.size(entity_tensor)

    def corrupt(h, t, r):
        rand_entity = entity_tensor[
            tf.random_uniform([], maxval=num_entities, dtype=tf.int32)]
        return tf.cond(
            tf.random_uniform([]) < 0.5,
            lambda: (h, t, r, rand_entity, t),  # corrupt the head
            lambda: (h, t, r, h, rand_entity))  # corrupt the tail

    dataset = dataset.map(corrupt, num_parallel_calls=4)

    # Map string symbols to integer ids via the lookup tables.
    dataset = dataset.map(
        lambda h, t, r, h_neg, t_neg: (
            tf.cast(entity_table.lookup(h), tf.int32),
            tf.cast(entity_table.lookup(t), tf.int32),
            tf.cast(relation_table.lookup(r), tf.int32),
            tf.cast(entity_table.lookup(h_neg), tf.int32),
            tf.cast(entity_table.lookup(t_neg), tf.int32)
        ),
        num_parallel_calls=4)

    # All five fields are scalar ids, so plain batching suffices; there is
    # nothing to pad. Note that `prefetch` counts whole batches, not examples.
    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)

    batched_iter = dataset.make_initializable_iterator()
    h, t, r, h_neg, t_neg = batched_iter.get_next()

    return BatchedInput(
        initializer=batched_iter.initializer,
        h=h, t=t, r=r, h_neg=h_neg, t_neg=t_neg)
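

if __name__ == "__main__":
    # Minimal usage sketch, not part of the pipeline itself. The file names
    # ("train.txt", "entities.txt", "relations.txt") and the vocabulary
    # layout (one symbol per line) are assumptions for illustration only.
    with open("entities.txt") as f:
        entities = [line.strip() for line in f]
    entity_table = tf.contrib.lookup.index_table_from_file("entities.txt")
    relation_table = tf.contrib.lookup.index_table_from_file("relations.txt")
    batched = get_iterator(
        "train.txt", entities, entity_table, relation_table, batch_size=128)
    with tf.Session() as sess:
        sess.run(tf.tables_initializer())  # initialize the lookup tables
        sess.run(batched.initializer)      # initialize the dataset iterator
        print(sess.run([batched.h, batched.t, batched.r]))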