#!/usr/bin/env python

import argparse
import cPickle
import traceback
import logging
import time
import sys

import os
import numpy
import codecs

from dialog_encdec import DialogEncoderDecoder
from numpy_compat import argpartition
from state import prototype_state
logger = logging.getLogger(__name__)

# Decorator shared by all samplers: it converts each context string into a list
# of word indices delimited by end-of-utterance tokens, invokes the wrapped
# sampling logic on each context, and converts the sampled indices back to text.
def sample_wrapper(sample_logic):
    def sample_apply(*args, **kwargs):
        sampler = args[0]
        contexts = args[1]

        verbose = kwargs.get('verbose', False)

        if verbose:
            logger.info("Starting {} : {} start sequences in total".format(sampler.name, len(contexts)))
         
        context_samples = []
        context_costs = []

        # Process each context in turn
        for context_id, context_utterances in enumerate(contexts):
            if verbose:
                logger.info("Searching for {}".format(context_utterances))

            # Convert the context into a list of word indices
            joined_context = []
            if len(context_utterances) == 0:
                joined_context = [sampler.model.eos_sym]
            else:
                utterance_ids = sampler.model.words_to_indices(context_utterances.split())
                # Make sure the utterance sequence is delimited by end-of-utterance tokens
                if len(utterance_ids) > 0:
                    if utterance_ids[0] != sampler.model.eos_sym:
                        utterance_ids = [sampler.model.eos_sym] + utterance_ids
                    if utterance_ids[-1] != sampler.model.eos_sym:
                        utterance_ids += [sampler.model.eos_sym]
                else:
                    utterance_ids = [sampler.model.eos_sym]

                joined_context += utterance_ids
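            # joined_context is now the whole context as a flat list of indices,
            # beginning and ending with the end-of-utterance symbol,
            # e.g. [1, 45, 82, 1] (indices purely illustrative)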

            samples, costs = sample_logic(sampler, joined_context, **kwargs) 
             
            # Convert the sampled indices back into lists of words
            converted_samples = map(lambda sample : sampler.model.indices_to_words(sample, exclude_end_sym=kwargs.get('n_turns', 1) == 1), samples)
            # Join the list of words
            converted_samples = map(' '.join, converted_samples)

            if verbose:
                for i in range(len(converted_samples)):
                    logger.info("Sample {}: {}".format(costs[i], converted_samples[i].encode('utf-8')))

            context_samples.append(converted_samples)
            context_costs.append(costs)

        return context_samples, context_costs
    return sample_apply

class Sampler(object):
    """
    An abstract sampler class 
    """
    def __init__(self, model):
        self.name = 'Sampler'
        self.model = model
        # Model functions are compiled lazily, on the first call to sample()
        self.compiled = False
        # Maximum sequence length used when (re-)encoding contexts
        self.max_len = 160

    def compile(self):
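        """
        Compile the model functions needed for sampling. Compilation is
        expensive, so it is deferred until the first call to sample().
        """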
        self.next_probs_predictor = self.model.build_next_probs_function()
        self.compute_encoding = self.model.build_encoder_function()

        if not self.model.reset_utterance_decoder_at_end_of_utterance:
            self.compute_decoder_encoding = self.model.build_decoder_encoding()

        self.compiled = True
    
    def select_next_words(self, next_costs, next_probs, step_num, how_many):
        """
        Choose the candidate words to extend the beams with.

        Must return ((beam_indx, word_indx), costs): for each selected
        candidate, the index of the beam it extends, the chosen word index,
        and its accumulated cost.
        """
        raise NotImplementedError

    def count_n_turns(self, utterance):
        """Number of completed turns, i.e. end-of-utterance tokens, in the utterance."""
        return len([w for w in utterance if w == self.model.eos_sym])

    @sample_wrapper
    def sample(self, *args, **kwargs):
        context = args[0]

        max_context_length = kwargs.get('max_context_length', 400)
        if len(context) > max_context_length:
            context = context[-max_context_length:]

        n_samples = kwargs.get('n_samples', 1)
        ignore_unk = kwargs.get('ignore_unk', True)
        min_length = kwargs.get('min_length', 1)
        max_length = kwargs.get('max_length', 30)
        beam_diversity = kwargs.get('beam_diversity', 1)
        normalize_by_length = kwargs.get('normalize_by_length', True)
        verbose = kwargs.get('verbose', False)
        n_turns = kwargs.get('n_turns', 1)

        if not self.compiled:
            self.compile()
        
        # Convert to matrix, each column is a context 
        # [[1,1,1],[4,4,4],[2,2,2]]
        context = numpy.repeat(numpy.array(context, dtype='int32')[:,None], 
                               n_samples, axis=1)

        if context[-1, 0] != self.model.eos_sym:
            raise Exception('Last token of context must be '
                            'the end-of-utterance symbol: %d' % self.model.eos_sym)

        # Reverse the utterances in the context (consumed by the backward utterance encoder)
        reversed_context = self.model.reverse_utterances(context)

        # Size of the encoder summary fed to the decoder: the dialogue-level state
        # alone, or concatenated with the utterance encoder state(s) when the
        # encoders are directly connected to the decoder
        if self.model.direct_connection_between_encoders_and_decoder:
            if self.model.bidirectional_utterance_encoder:
                dialog_enc_size = self.model.sdim+self.model.qdim_encoder*2
            else:
                dialog_enc_size = self.model.sdim+self.model.qdim_encoder
        else:
            dialog_enc_size = self.model.sdim

        prev_hs = numpy.zeros((n_samples, dialog_enc_size), dtype='float32')

        prev_hd = numpy.zeros((n_samples, self.model.utterance_decoder.complete_hidden_state_size), dtype='float32')

        # When the utterance decoder is not reset at the end of each utterance, its
        # hidden state must be warm-started by encoding the whole context through
        # the decoder (padded up to the training batch size self.model.bs)
        if not self.model.reset_utterance_decoder_at_end_of_utterance:
            assert self.model.bs >= context.shape[1]
            enlarged_context = numpy.zeros((context.shape[0], self.model.bs), dtype='int32')
            enlarged_context[:, 0:context.shape[1]] = context[:]
            enlarged_reversed_context = numpy.zeros((context.shape[0], self.model.bs), dtype='int32')
            enlarged_reversed_context[:, 0:context.shape[1]] = reversed_context[:]

            ran_gaussian_vector = self.model.rng.normal(size=(context.shape[0],n_samples,self.model.latent_gaussian_per_utterance_dim)).astype('float32')
            ran_uniform_vector = self.model.rng.uniform(low=0.0, high=1.0, size=(context.shape[0],n_samples,self.model.latent_piecewise_per_utterance_dim)).astype('float32')

            zero_mask = numpy.zeros((context.shape[0], self.model.bs), dtype='float32')
            zero_vector = numpy.zeros((self.model.bs), dtype='float32')
            ones_mask = numpy.ones((context.shape[0], self.model.bs), dtype='float32')

            # Computes new utterance decoder hidden states (including intermediate utterance encoder and dialogue encoder hidden states)
            new_hd = self.compute_decoder_encoding(enlarged_context, enlarged_reversed_context, self.max_len, zero_mask, zero_vector, ran_gaussian_vector, ran_uniform_vector, ones_mask)

            # Keep only the decoder hidden states of the real (non-padding) columns
            prev_hd[:] = new_hd[0][-1][0:context.shape[1], :]

        fin_gen = []    # finished hypotheses...
        fin_costs = []  # ...and their costs

        gen = [[] for i in range(n_samples)]    # partial hypotheses, one per beam
        costs = [0. for i in range(n_samples)]  # accumulated cost of each beam
        beam_empty = False

        # Sample the per-utterance latent variables (Gaussian and piecewise)
        # that condition the utterance decoder
        ran_gaussian_vectors = self.model.rng.normal(size=(n_samples,self.model.latent_gaussian_per_utterance_dim)).astype('float32')
        ran_uniform_vectors = self.model.rng.uniform(low=0.0, high=1.0, size=(n_samples,self.model.latent_piecewise_per_utterance_dim)).astype('float32')

        for k in range(max_length):
            if len(fin_gen) >= n_samples or beam_empty:
                break
             
            if verbose:
                logger.info("{} : sampling step {}, beams alive {}".format(self.name, k, len(gen)))
             
            # Append the last sampled word of every live beam to the context and
            # rebuild the reversed context; only done once something was sampled
            if k > 0:
                context = numpy.vstack([context, \
                                        numpy.array(map(lambda g: g[-1], gen))]).astype('int32')
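                # Rebuild the reversed context column by column: the words inside
                # each utterance are reversed in place while the end-of-utterance
                # delimiters keep their positions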
                reversed_context = numpy.copy(context)
                for idx in range(context.shape[1]):
                    eos_indices = numpy.where(context[:, idx] == self.model.eos_sym)[0]
                    prev_eos_index = -1
                    for eos_index in eos_indices:
                        reversed_context[(prev_eos_index+2):eos_index, idx] = (reversed_context[(prev_eos_index+2):eos_index, idx])[::-1]
                        prev_eos_index = eos_index

            prev_words = context[-1, :]
           
            # Recompute the encoder states hs and resample the latent variables,
            # but only for beams whose previous word is the end-of-utterance token
            indx_update_hs = [num for num, prev_word in enumerate(prev_words)
                                if prev_word == self.model.eos_sym]

            if len(indx_update_hs):
                encoder_states = self.compute_encoding(context[:, indx_update_hs], reversed_context[:, indx_update_hs], self.max_len)
                prev_hs[indx_update_hs] = encoder_states[1][-1]
                ran_gaussian_vectors[indx_update_hs,:] = self.model.rng.normal(size=(len(indx_update_hs),self.model.latent_gaussian_per_utterance_dim)).astype('float32')
                ran_uniform_vectors[indx_update_hs,:] = self.model.rng.uniform(low=0.0, high=1.0, size=(len(indx_update_hs),self.model.latent_piecewise_per_utterance_dim)).astype('float32')

            next_probs, new_hd = self.next_probs_predictor(prev_hs, prev_hd, prev_words, context, ran_gaussian_vectors, ran_uniform_vectors)

            # One row per live beam, one column per vocabulary entry
            assert next_probs.shape[1] == self.model.idim
            
            # Enforce the search restrictions: tokens whose probability is zeroed
            # here receive infinite cost after the -log below and are never picked
            if ignore_unk:
                next_probs[:, self.model.unk_sym] = 0
            if k <= min_length:
                next_probs[:, self.model.eos_sym] = 0
                next_probs[:, self.model.eod_sym] = 0
             
            # Accumulate costs: the cost of a hypothesis is its summed -log p(word)
            next_costs = numpy.array(costs)[:, None] - numpy.log(next_probs)

            # Select next words here
            (beam_indx, word_indx), costs = self.select_next_words(next_costs, next_probs, k, n_samples)
            
            # Update the stacks
            new_gen = [] 
            new_costs = []
            new_sources = []

            for num, (beam_ind, word_ind, cost) in enumerate(zip(beam_indx, word_indx, costs)):
                if len(new_gen) >= n_samples:
                    break

                hypothesis = gen[beam_ind] + [word_ind]
                 
                # Check whether the hypothesis has completed the requested number of turns
                n_turns_hypothesis = self.count_n_turns(hypothesis)
                if n_turns_hypothesis == n_turns:
                    if verbose:
                        logger.debug("adding utterance {} from beam {}".format(hypothesis, beam_ind))

                    # We finished sampling
                    fin_gen.append(hypothesis)
                    fin_costs.append(cost)
                elif self.model.eod_sym in hypothesis: # End of dialogue detected
                    # Truncate the hypothesis right after the end-of-dialogue token
                    hypothesis = hypothesis[:hypothesis.index(self.model.eod_sym) + 1]

                    if verbose:
                        logger.debug("adding utterance {} from beam {}".format(hypothesis, beam_ind))

                    # We finished sampling
                    fin_gen.append(hypothesis)
                    fin_costs.append(cost)
                else:
                    # Hypothesis recombination: drop a hypothesis whose last
                    # hyp_rec tokens match those of an already-kept one (for the
                    # beam sampler, candidates arrive in increasing cost order,
                    # so the kept hypothesis is the cheaper one)
                    has_similar = False
                    if self.hyp_rec > 0:
                        has_similar = len([g for g in new_gen if \
                            g[-self.hyp_rec:] == hypothesis[-self.hyp_rec:]]) != 0
                    
                    if not has_similar:
                        new_sources.append(beam_ind)
                        new_gen.append(hypothesis)
                        new_costs.append(cost)
            
            if verbose:
                for hyp in new_gen:
                    logger.debug("partial -> {}".format(' '.join(self.model.indices_to_words(hyp))))

            prev_hd = new_hd[new_sources]
            prev_hs = prev_hs[new_sources]
            ran_gaussian_vectors = ran_gaussian_vectors[new_sources,:]
            ran_uniform_vectors = ran_uniform_vectors[new_sources,:]
            context = context[:, new_sources]
            reversed_context = reversed_context[:, new_sources]
            gen = new_gen
            costs = new_costs
            beam_empty = len(gen) == 0

        # If no hypothesis finished within max_length,
        # fall back to the unfinished beams
        if len(fin_gen) == 0:
            fin_gen = gen 
            fin_costs = costs 
         
        # Normalize costs
        if normalize_by_length:
            fin_costs = [(fin_costs[num]/len(fin_gen[num])) \
                         for num in range(len(fin_gen))]

        fin_gen = numpy.array(fin_gen)[numpy.argsort(fin_costs)]
        fin_costs = numpy.array(sorted(fin_costs))
        return fin_gen[:n_samples], fin_costs[:n_samples] 

class RandomSampler(Sampler):
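    """
    Samples the next word of every beam from the model's predictive
    distribution; hyp_rec = 0 disables hypothesis recombination.
    """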
    def __init__(self, model):
        Sampler.__init__(self, model)
        self.name = 'RandomSampler'
        self.hyp_rec = 0

    def select_next_words(self, next_costs, next_probs, step_num, how_many):
        # rng.choice expects float64 probabilities that sum to one
        next_probs = next_probs.astype("float64")
        word_indx = numpy.array([self.model.rng.choice(self.model.idim, p = x/numpy.sum(x))
                                    for x in next_probs], dtype='int32')
        beam_indx = range(next_probs.shape[0])

        # Look up the accumulated cost of each sampled (beam, word) pair
        args = numpy.ravel_multi_index(numpy.array([beam_indx, word_indx]), next_costs.shape)
        return (beam_indx, word_indx), next_costs.flatten()[args]

class BeamSampler(Sampler):
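    """
    Beam search: at every step expands the lowest-cost candidates and
    recombines hypotheses that share their last hyp_rec = 3 tokens.
    """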
    def __init__(self, model):
        Sampler.__init__(self, model)
        self.name = 'BeamSampler'
        self.hyp_rec = 3

    def select_next_words(self, next_costs, next_probs, step_num, how_many):
        # At the first step all beams are identical, so consider only the first
        # row; this avoids expanding the same word from every duplicate beam
        # (e.g. duplicate <q> tokens)
        if step_num == 0:
            flat_next_costs = next_costs[:1, :].flatten()
        else:
            flat_next_costs = next_costs.flatten()

        # Partial sort to find the how_many lowest-cost candidates,
        # then order them by increasing cost
        args = argpartition(flat_next_costs, how_many)[:how_many]
        args = args[numpy.argsort(flat_next_costs[args])]
        
        return numpy.unravel_index(args, next_costs.shape), flat_next_costs[args]
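
# A minimal usage sketch (hedged: file names are hypothetical and the exact
# loading steps depend on how the model was trained and saved):
#
#   state = prototype_state()
#   with open('model_state.pkl', 'r') as handle:
#       state.update(cPickle.load(handle))
#   model = DialogEncoderDecoder(state)
#
#   sampler = BeamSampler(model)
#   samples, costs = sampler.sample(['how are you ? </s>'], n_samples=5, verbose=True)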