'''
A short script to train and evaluate s2vt model on Flickr30k and
MSR-VTT datasets.

Usage:
            python run_s2vt.py --dataset [MSR-VTT|Flickr30k] --train
            python run_s2vt.py --dataset [MSR-VTT|Flickr30k] --test --checkpoint model_num 
'''

from cfg import *
import os

import pandas as pd
import numpy as np
import pickle as pkl
import json
from tqdm import tqdm
import argparse

from pandas.io.json import json_normalize
from cocoeval import COCOScorer, suppress_stdout_stderr
import sys

from s2vt_model import *

cfg = None


def get_flickr30k_data(cfg):
    #using the provided splits
    train_split = set(map(lambda x: x.split(".")[0], open(cfg.train_file).read().splitlines()))
    val_split = set(map(lambda x: x.split(".")[0], open(cfg.val_file).read().splitlines()))
    test_split = set(map(lambda x: x.split(".")[0], open(cfg.test_file).read().splitlines()))
    
    data = [{"video_id": item.split(".")[0], "sentence_id": item.split("#")[1].split("\t")[0], "caption":item.split("\t")[1]}
            for item in open(cfg.annotations_path).read().splitlines()]
    
    sentences = json_normalize(data)
    sentences['video_path'] = sentences['video_id'].map(lambda x: os.path.join(cfg.path_to_descriptors, x + cfg.descriptor_suffix + ".npy"))
    
    train_imgs = sentences.loc[sentences["video_id"].isin(train_split)]
    train_imgs.reset_index()
    
    val_imgs = sentences.loc[sentences["video_id"].isin(val_split)]
    val_imgs.reset_index()
    
    test_imgs = sentences.loc[sentences["video_id"].isin(test_split)]
    test_imgs.reset_index()
    
    return train_imgs, val_imgs, test_imgs

def get_msr_vtt_data(cfg):
    #trainval data
    with open(cfg.trainval_annotations) as data_file:    
        data = json.load(data_file)
    
    sentences = json_normalize(data['sentences'])
    videos = json_normalize(data['videos'])
    train_vids = sentences.loc[sentences["video_id"].isin(videos[videos['split'] == "train"]["video_id"])]
    val_vids = sentences.loc[sentences["video_id"].isin(videos[videos['split'] == "validate"]["video_id"])]
    train_vids['video_path'] = train_vids['video_id'].map(lambda x: os.path.join(cfg.path_to_trainval_descriptors, x + "_incp_v3.npy"))  
    val_vids['video_path'] = val_vids['video_id'].map(lambda x: os.path.join(cfg.path_to_trainval_descriptors, x + "_incp_v3.npy"))
    
    #test data
    with open(cfg.test_annotations) as data_file:    
        data = json.load(data_file)
    sentences = json_normalize(data['sentences'])
    videos = json_normalize(data['videos'])
    test_vids = sentences.loc[sentences["video_id"].isin(videos[videos['split'] == "test"]["video_id"])]
    test_vids['video_path'] = test_vids['video_id'].map(lambda x: os.path.join(cfg.path_to_test_descriptors, x + "_incp_v3.npy"))
    
    return train_vids, val_vids, test_vids
    

def preProBuildWordVocab(sentence_iterator, word_count_threshold=5): # borrowed this function from NeuralTalk
    print 'preprocessing word counts and creating vocab based on word count threshold %d' % (word_count_threshold, )
    word_counts = {}
    nsents = 0
    for sent in sentence_iterator:
        nsents += 1
        for w in sent.lower().split(' '):
            word_counts[w] = word_counts.get(w, 0) + 1

    vocab = [w for w in word_counts if word_counts[w] >= word_count_threshold]
    print 'filtered words from %d to %d' % (len(word_counts), len(vocab))

    ixtoword = {}
    ixtoword[0] = '.'  # period at the end of the sentence. make first dimension be end token
    wordtoix = {}
    wordtoix['#START#'] = 0 # make first vector be the start token
    ix = 1
    for w in vocab:
        wordtoix[w] = ix
        ixtoword[ix] = w
        ix += 1

    word_counts['.'] = nsents
    bias_init_vector = np.array([1.0*word_counts[ixtoword[i]] for i in ixtoword])
    bias_init_vector /= np.sum(bias_init_vector) # normalize to frequencies
    bias_init_vector = np.log(bias_init_vector)
    bias_init_vector -= np.max(bias_init_vector) # shift to nice numeric range
    return wordtoix, ixtoword, bias_init_vector

def output_progress(current, total, loss):
    bar_length = 20
    progress = current/float(total)
    sys.stdout.write('\r')
    sys.stdout.write(("[%-" + str(bar_length) + "s] %d/%d") % ('='* int(bar_length * progress) + ">", current, total) + ", avg_loss=" + str(loss))
    sys.stdout.flush()

#populate feature dictionary
#unroll features for LSTM encoding
feature_dict = {} 
def load_flickr30k_features(vid):
    if vid in feature_dict:
        return feature_dict[vid]
    else:
        temp_array = np.load(vid)
        temp_array[1::2][:] = temp_array[1::2][:, ::-1][:]
        if cfg.use_hard_cache:
            feature_dict[vid] = temp_array.reshape(cfg.n_frame_step, -1)
            return feature_dict[vid]
        else:
            return temp_array.reshape(cfg.n_frame_step, -1)

def load_msr_vtt_features(vid):
    return np.load(vid)
            

def get_validation_loss(sess, current_val_data, wordtoix, tf_loss, tf_video, tf_caption, tf_caption_mask):
    val_data = current_val_data
    val_captions = val_data['caption'].values
    val_captions = map(lambda x: x.replace('.', ''), val_captions)
    val_captions = map(lambda x: x.replace(',', ''), val_captions)
    
    combine_features = load_flickr30k_features if cfg.id == "Flickr30k" else load_msr_vtt_features 
    
    loss_on_validation = []
    for start,end in zip(
                range(0, len(val_data), cfg.batch_size),
                range(cfg.batch_size, len(val_data)+1, cfg.batch_size)): #during every epoch we are discarding incomplete batch in the end
            
        current_batch = val_data[start:end]
        current_videos = current_batch['video_path'].values

        current_feats = np.zeros((cfg.batch_size, cfg.n_frame_step, cfg.dim_image))
        current_feats_vals = map(lambda vid: combine_features(vid), current_videos) 
            
        for ind,feat in enumerate(current_feats_vals):
            current_feats[ind][:len(current_feats_vals[ind])] = feat 
            

        current_captions = current_batch['caption'].values
        current_caption_ind = map(lambda cap: [wordtoix[word] for word in cap.lower().split(' ')[:cfg.n_lstm_step - 1]
                                                   if word in wordtoix],
                                      current_captions)
        
        current_caption_matrix = np.zeros((cfg.batch_size, cfg.n_lstm_step))
        current_caption_masks = np.zeros((cfg.batch_size, cfg.n_lstm_step))
        for ind, row in enumerate(current_caption_masks):
            valid_length = len(current_caption_ind[ind])
            row[:valid_length] = 1
            current_caption_matrix[ind, :valid_length] = current_caption_ind[ind]
        
        loss_val = sess.run(tf_loss,
                feed_dict={
                    tf_video: current_feats,
                    tf_caption: current_caption_matrix,
                    tf_caption_mask: current_caption_masks
                    })
        loss_on_validation.append(loss_val)
    return np.mean(loss_on_validation)
            
        
def train():
    if not os.path.exists(cfg.model_path):
        os.makedirs(cfg.model_path)
    
    print cfg.model_path
    f = open(cfg.model_path + "loss", "a", 1)
    f.write("Checkpoint\tTrain loss\tValidation loss\n")
    
    if cfg.id == "Flickr30k":
        train_data, val_data, _ = get_flickr30k_data(cfg)
    elif cfg.id == "MSR-VTT":
        train_data, val_data, _ = get_msr_vtt_data(cfg)
    
    #FIXME add validation data vocabulary
    captions = train_data['caption'].values
    captions = map(lambda x: x.replace('.', ''), captions)
    captions = map(lambda x: x.replace(',', ''), captions)
    wordtoix, ixtoword, bias_init_vector = preProBuildWordVocab(captions, word_count_threshold=cfg.word_count_threshold)
    
    combine_features = load_flickr30k_features if cfg.id == "Flickr30k" else load_msr_vtt_features

    np.save(cfg.vocab_path + 'ixtoword', ixtoword)
    with open(cfg.vocab_path + 'wordtoix.pkl', 'wb') as outfile:
        pkl.dump(wordtoix, outfile)
    
    sess = tf.InteractiveSession(config=tf.ConfigProto(gpu_options=gpu_options))
    
    with tf.variable_scope(tf.get_variable_scope()):
        model_train = s2vt(dim_image=cfg.dim_image,
                           n_words=len(ixtoword),
                           dim_hidden=cfg.dim_hidden,
                           batch_size=cfg.batch_size,
                           n_frame_steps=cfg.n_frame_step,
                           n_lstm_steps=cfg.n_lstm_step,
                           dim_word_emb = cfg.dim_word_emb,
                           cell_clip = cfg.cell_clip,
                           forget_bias = cfg.forget_bias,
                           input_keep_prob = cfg.input_keep_prob,
                           output_keep_prob = cfg.output_keep_prob,
                           bias_init_vector=bias_init_vector)
    
        tf_loss, tf_video, tf_caption, tf_caption_mask, _ = model_train.build_model("training")

    with tf.variable_scope(tf.get_variable_scope(), reuse=False):
        train_op = tf.train.AdamOptimizer(cfg.learning_rate).minimize(tf_loss)
    
    saver = tf.train.Saver(max_to_keep=cfg.max_to_keep)
    sess.run(tf.global_variables_initializer())
    
    model_counter = 0
    val_loss = None
    for epoch in range(cfg.n_epochs):
        index = list(train_data.index)
        np.random.shuffle(index)
        train_data = train_data.ix[index]
        
        current_train_data = train_data
        
        total_loss = 0
        saving_schedule = []
        loss_accumulator = []
        
        step_size =  (int(len(current_train_data) * cfg.save_every_n_epoch) // cfg.batch_size ) * cfg.batch_size
        saving_schedule = range(0, len(current_train_data) - step_size, step_size)
        print saving_schedule
        
        for start,end in zip(
                range(0, len(current_train_data), cfg.batch_size),
                range(cfg.batch_size, len(current_train_data)+1, cfg.batch_size)):   
            current_batch = current_train_data[start:end]                           
            current_videos = current_batch['video_path'].values

            current_feats = np.zeros((cfg.batch_size, cfg.n_frame_step, cfg.dim_image))
            current_feats_vals = map(lambda vid: combine_features(vid), current_videos)
            
            
            for ind,feat in enumerate(current_feats_vals):
                current_feats[ind][:len(current_feats_vals[ind])] = feat 
                
            current_captions = current_batch['caption'].values
            current_caption_ind = map(lambda cap: [wordtoix[word] for word in cap.lower().split(' ')[:cfg.n_lstm_step - 1]
                                                   if word in wordtoix],
                                      current_captions)
            
            current_caption_matrix = np.zeros((cfg.batch_size, cfg.n_lstm_step))
            current_caption_masks = np.zeros((cfg.batch_size, cfg.n_lstm_step))
            for ind, row in enumerate(current_caption_masks):
                valid_length = len(current_caption_ind[ind])
                row[:valid_length+1] = 1                                                #forces to predict <EOS> = 0
                current_caption_matrix[ind, :valid_length] = current_caption_ind[ind]
 
            _, train_loss = sess.run(
                    [train_op, tf_loss],
                    feed_dict={
                        tf_video: current_feats,
                        tf_caption: current_caption_matrix,
                        tf_caption_mask: current_caption_masks
                        })
            total_loss += train_loss
            loss_accumulator.append(train_loss)
                
            output_progress(end, len(current_train_data), train_loss)
            
            if start in saving_schedule:
                print start
                train_loss = np.mean(loss_accumulator[-5:])
                val_loss = get_validation_loss(sess, 
                                               val_data.groupby('video_id').apply(lambda x: x.iloc[np.random.choice(len(x))]),
                                               wordtoix, tf_loss, tf_video,
                                               tf_caption, tf_caption_mask)
                f.write(str(model_counter) + "\t" +  str(train_loss) +"\t" + str(val_loss) + "\n")
                sys.stdout.flush()
                saver.save(sess, os.path.join(cfg.model_path, 'model'), global_step=model_counter)
                model_counter+=1
                
        output_progress(end, len(current_train_data), np.mean(loss_accumulator[-5:]))
        print " Done. Validation loss = " + str(val_loss) 
        
        

def convert_data_to_coco_scorer_format(data_frame):
    gts = {}
    non_ascii_count = 0
    for row in zip(data_frame["caption"], data_frame["video_id"]):
        try:
            row[0].encode('ascii', 'ignore').decode('ascii')
        except UnicodeDecodeError:
            non_ascii_count+=1
            continue
        if row[1] in gts:
            gts[row[1]].append({u'image_id': row[1], u'cap_id': len(gts[row[1]]), u'caption':row[0].encode('ascii', 'ignore').decode('ascii')})
        else:
            gts[row[1]] = []
            gts[row[1]].append({u'image_id': row[1], u'cap_id': len(gts[row[1]]), u'caption':row[0].encode('ascii', 'ignore').decode('ascii')})
    if non_ascii_count:
        print "=" * 20 + "\n" + "non-ascii: " + str(non_ascii_count) + "\n" + "=" * 20
    return gts


def test(saved_model=''):
    scorer = COCOScorer()
    ixtoword = pd.Series(np.load(cfg.vocab_path + 'ixtoword.npy').tolist())
    combine_features = load_flickr30k_features if cfg.id == "Flickr30k" else load_msr_vtt_features
    
    model = s2vt(dim_image=cfg.dim_image,
                 n_words=len(ixtoword),
                 dim_hidden=cfg.dim_hidden,
                 batch_size=cfg.batch_size,
                 n_frame_steps=cfg.n_frame_step,
                 n_lstm_steps=cfg.n_lstm_step,
                 dim_word_emb = cfg.dim_word_emb,
                 cell_clip = cfg.cell_clip,
                 forget_bias = cfg.forget_bias,
                 input_keep_prob = cfg.input_keep_prob,
                 output_keep_prob = cfg.output_keep_prob,
                 bias_init_vector=None)
    
    _, video_tf, caption_tf, _, _ = model.build_model("inference")
    session = tf.InteractiveSession(config=tf.ConfigProto(gpu_options=gpu_options))
    saver = tf.train.Saver()
    saver.restore(session, saved_model)
        
    if cfg.id == "Flickr30k":
        _, _, test_data = get_flickr30k_data(cfg)
    elif cfg.id == "MSR-VTT":
        _, _, test_data = get_msr_vtt_data(cfg)
    
    splits = []
    
    splits.append((test_data['video_path'].unique(), test_data))
    results = []
    for split, gt_dataframe in splits:
        gts = convert_data_to_coco_scorer_format(gt_dataframe)
        samples = {}
        for start,end in zip(
                    range(0, len(split), cfg.batch_size),
                    range(cfg.batch_size, len(split) + cfg.batch_size, cfg.batch_size)):
        
            current_batch = split[start:end]
            current_feats = np.zeros((cfg.batch_size, cfg.n_frame_step, cfg.dim_image))
            current_feats_vals = [combine_features(vid) for vid in current_batch] 
            
            for ind,feat in enumerate(current_feats_vals):
                current_feats[ind][:len(current_feats_vals[ind])] = feat
            
            generated_word_index = session.run(caption_tf, feed_dict={video_tf:current_feats})
            generated_word_index = np.asarray(generated_word_index).transpose()
            periods = np.argmax(generated_word_index == 0, axis=1) + 1
            periods[periods == 0] = cfg.n_lstm_step      #take the whole sequence if a period was not produced
            for i in range(len(current_batch)):
                generated_sentence = ' '.join(ixtoword[generated_word_index[i, :periods[i]-1]])
                video_id = current_batch[i].split("/")[-1].split("_")[0] #+ ".jpg"
                samples[video_id] = [{u'image_id': video_id, u'caption': generated_sentence}]
                
        with suppress_stdout_stderr():
            valid_score = scorer.score(gts, samples, samples.keys())
        results.append(valid_score)
        print valid_score
    
    print len(samples)            
    if not os.path.exists(cfg.results_path):
        os.makedirs(cfg.results_path)
    
    with open(cfg.results_path + "scores.txt", 'a') as scores_table:
            scores_table.write(json.dumps(results[0]) + "\n") 
    with open(cfg.results_path + saved_model.split("/")[-1] + ".json", 'w') as prediction_results:
        json.dump({"predictions": samples, "scores": valid_score}, prediction_results)


def main(args):
    global cfg
    if args.dataset == "Flickr30k":
        cfg = flickr_cfg()
    elif args.dataset == "MSR-VTT":
        cfg = msr_vtt_cfg()
    else:
        print "Unknown dataset"
        exit(1)
    
    if args.train_stage:
        train()
    else:
        test(saved_model=cfg.model_path + 'model-' + str(args.checkpoint))
    
    
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Script to train a model for movie description')
    
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument('--train', dest='train_stage', action='store_true',  
                        help='Training')
    group.add_argument('--test', dest='train_stage', action='store_false', 
                        help='Testing')
    parser.add_argument('--checkpoint', dest='checkpoint', type = int, default = -1,
                        help='Provide a number of the saved model to run testing only on one snapshot')
    parser.add_argument("--dataset", dest='dataset', type=str,
                        help='Specify one from {Flickr30k, MSR-VTT}')
    parser.add_argument("--gpu", dest='gpu', type=str, required=False,
                        help='Set CUDA_VISIBLE_DEVICES environment variable, optional')
    
    args = parser.parse_args()
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    
    if not args.dataset:
        parser.print_help()
        exit(1)

    if not args.train_stage:
        if args.checkpoint is None:
            parser.print_help()
            exit(1)
    
    main(args)