# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Various helpers."""

import random
import copy
import pdb
import sys

import torch
import numpy as np

def is_selection(out):
    """Tell whether `out` is a lone dialogue-ending token.

    Returns True only when `out` has exactly one element and that element is
    either '<selection>' or '<no_agreement>'; False otherwise.
    """
    if len(out) != 1:
        return False
    return out[0] in ('<selection>', '<no_agreement>')

def backward_hook(grad):
    """Identity hook for the backward pass: hands `grad` back untouched."""
    unchanged = grad
    return unchanged

def save_model(model, file_name):
    """Write `model` to `file_name` using torch.save.

    An empty `file_name` means "do not save": the call returns without
    touching the filesystem.
    """
    if file_name == '':
        return
    with open(file_name, 'wb') as out_file:
        torch.save(model, out_file)

def load_model(file_name, map_location=None):
    """Read back and return the object stored in `file_name` via torch.load.

    `map_location` is forwarded to torch.load unchanged.
    """
    with open(file_name, 'rb') as in_file:
        loaded = torch.load(in_file, map_location=map_location)
    return loaded

def set_seed(seed):
    """Sets random seed everywhere.

    Seeds Python's `random` module, NumPy's global RNG and torch's default
    generator. When CUDA is available, the generators of all CUDA devices
    are seeded as well.

    Args:
        seed: integer seed applied to every RNG.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def use_cuda(enabled, device_id=0):
    """Verifies if CUDA is available and sets default device to be device_id.

    Args:
        enabled: when falsy, CUDA is not touched at all.
        device_id: index of the CUDA device to make current.

    Returns:
        None when `enabled` is falsy, otherwise `device_id`.

    Raises:
        AssertionError: if CUDA was requested but is not available.
    """
    if not enabled:
        return None
    # Raised explicitly instead of via `assert` so the check is not stripped
    # when Python runs with -O; the exception type stays the same.
    if not torch.cuda.is_available():
        raise AssertionError('CUDA is not available')
    # Make the requested device current, as the docstring promises.
    torch.cuda.set_device(device_id)
    return device_id

def prob_random():
    """Print one fresh draw from each of the Python, torch and NumPy RNGs.

    Each call consumes one sample from every generator, so it also advances
    their states.
    """
    # Draw in the same order the values are reported: python, torch, numpy.
    python_draw = random.random()
    torch_draw = torch.rand(1)[0]
    numpy_draw = np.random.rand()
    report = 'random state: python %.3f torch %.3f numpy %.3f' % (
        python_draw, torch_draw, numpy_draw)
    print(report)

class ContextGenerator(object):
    """Dialogue context generator. Generates contexts from the file.

    The file is read line by line. Each line is split on whitespace into one
    context (a list of string tokens), and every two consecutive lines are
    grouped into one context pair stored in `self.ctxs`.
    """

    def __init__(self, context_file):
        """Loads all context pairs from `context_file`.

        Args:
            context_file: path to a text file with one context per line.
        """
        self.ctxs = []
        with open(context_file, 'r') as f:
            ctx_pair = []
            for line in f:
                ctx_pair.append(line.strip().split())
                if len(ctx_pair) == 2:
                    # A pair is complete: store it and start the next one.
                    self.ctxs.append(ctx_pair)
                    ctx_pair = []
        # A trailing unpaired line (odd line count) is not stored.

    def sample(self):
        """Returns one context pair chosen uniformly at random."""
        return random.choice(self.ctxs)

    def iter(self, nepoch=1):
        """Yields every context pair in file order, `nepoch` times over."""
        for _ in range(nepoch):
            for ctx in self.ctxs:
                yield ctx

class ManualContextGenerator(object):
    """Dialogue context generator. Takes contexts from stdin.

    A context is a flat list of `2 * num_types` string tokens. Tokens at even
    positions must sum to `num_objects`; tokens at odd positions are scores
    and must not exceed `max_score`.
    """

    def __init__(self, num_types=3, num_objects=10, max_score=10):
        """Stores the constraints a manually entered context must satisfy."""
        self.num_types = num_types
        self.num_objects = num_objects
        self.max_score = max_score

    def _input_ctx(self):
        """Prompts until a valid context is entered and returns it.

        Invalid input (wrong token count, non-integer tokens, wrong object
        total, or a score above `max_score`) prints an explanation and asks
        again. KeyboardInterrupt is deliberately not caught, so Ctrl-C still
        aborts the prompt.

        Returns:
            The context as a list of string tokens.
        """
        while True:
            try:
                ctx = input('Input context: ').strip().split()
                if len(ctx) != 2 * self.num_types:
                    raise ValueError('wrong number of tokens')
                # int() raises ValueError on non-numeric tokens, which is
                # handled below like any other invalid context.
                if sum(int(x) for x in ctx[0::2]) != self.num_objects:
                    raise ValueError('wrong total number of objects')
                if max(int(x) for x in ctx[1::2]) > self.max_score:
                    raise ValueError('score exceeds max_score')
                return ctx
            except ValueError:
                print('The context is invalid! Try again.')
                print('Reason: num_types=%d, num_objects=%d, max_score=%s' % (
                    self.num_types, self.num_objects, self.max_score))

    def _update_scores(self, ctx):
        """Replaces every score (odd position) in `ctx` with a random one.

        Scores are drawn uniformly from [0, max_score] and stored as strings
        so the list keeps the same token type as the entered context.
        `ctx` is modified in place and also returned.
        """
        for i in range(1, len(ctx), 2):
            ctx[i] = str(np.random.randint(0, self.max_score + 1))
        return ctx

    def sample(self):
        """Returns a pair [entered context, copy with re-drawn scores]."""
        ctx1 = self._input_ctx()
        ctx2 = self._update_scores(copy.copy(ctx1))
        return [ctx1, ctx2]