#!/usr/bin/env python

# Copyrigh 2018 houjingyong@gmail.com

# MIT Licence

from __future__ import print_function

import os, sys, argparse, datetime, shutil
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from bbox_transform import get_out_utt_boxes
from config import cfg
from config import cfg_from_file
from streaming_special_torch_dataset import * 
from kaldi_io import *
from RNNs import GRU
from RPN import RPN
from RPN_KWS import RPN_KWS
from utils import AverageMeter, count_parameters
from loss import loss_frame_fn_ce, acc_frame

def get_args():
    """Get arguments from stdin."""
    parser = argparse.ArgumentParser(description='Pytorch acoustic model.')
    parser.add_argument('--encoder', type=str, default='gru',
                        help='encoder type {default: gru}')
    parser.add_argument('--num-anchor', type=int, default=10, metavar='HF',
                        help='Num anchors per frame {default: 10.0}')
    parser.add_argument('--lambda-factor', type=float, default=5.0, metavar='HF',
                        help='Balance factor between classification and regression loss (default: 5.0).')
    parser.add_argument('--input-dim', type=int, default=40, metavar='N',
                        help='Input feature dimension without context (default: 40).')
    parser.add_argument('--kernel-size', type=int, default=3, metavar='N',
                        help='Kernel size of Wavenet or CNN (default:3).')
    parser.add_argument('--hidden-dim', type=int, default=128, metavar='N',
                        help='Hidden dimension of feature extractor (default: 128).')
    parser.add_argument('--num-layers', type=int, default=2, metavar='N',
                        help='Numbers of hidden layers of feature extractor (default: 2).')
    parser.add_argument('--output-dim', type=int, default=2000, metavar='N',
                        help='Output dimension, number of classes (default: 2000).')
    parser.add_argument('--dropout', type=float, default=0.0001, metavar='DR',
                        help='dropout of feature extractor (default: 0.0001).')
    parser.add_argument('--left-context', type=int, default=5, metavar='N',
                        help='Left context length for splicing feature (default: 5).')
    parser.add_argument('--right-context', type=int, default=5, metavar='N',
                        help='Right context length for splicing feature (default: 5).')
    parser.add_argument('--max-epochs', type=int, default=20, metavar='N',
                        help='Maximum epochs to train (default: 20).')
    parser.add_argument('--min-epochs', type=int, default=0, metavar='N',
                        help='Minimum epochs to train (default: 0).')
    parser.add_argument('--batch-size', type=int, default=8, metavar='N',
                        help='Batch size for training (default: 8).')
    parser.add_argument('--learning-rate', type=float, default=0.001, metavar='LR',
                        help='Initial learning rate (default: 0.001).')
    parser.add_argument('--halving-factor', type=float, default=0.5, metavar='HF',
                        help='Half factor for learning rate (default: 0.5).')
    parser.add_argument('--start-halving-impr', type=float, default=0.01, metavar='S',
                        help='Improvement threshold to half the learning rate (default: 0.01).')
    parser.add_argument('--end-halving-impr', type=float, default=0.001, metavar='E',
                        help='Improvement threshold to stop half learning rate (default: 0.001).')
    parser.add_argument('--init-weight-decay', type=float, default=1e-5, metavar='E',
                                    help='Weight decay of L2 normalization (default: 1e-5).')
    parser.add_argument('--seed', type=int, default=1234, metavar='S',
                        help='Random seed (default: 1234).')
    parser.add_argument('--use-cuda', type=int, default=1, metavar='C',
                        help='Use cuda (1) or cpu(0).')
    parser.add_argument('--multi-gpu', type=int, default=0, metavar='G',
                        help='Use multi gpu (1) or not (0).')
    parser.add_argument('--train', type=int, default=1,
                        help='Executing mode, train (1) or test (0).')
    parser.add_argument('--train-scp', type=str, default='',
                        help='Training data file.')
    parser.add_argument('--dev-scp', type=str, default='',
                        help='Development data file.')
    parser.add_argument('--save-dir', type=str, default='',
                        help='Directory to output the model.')
    parser.add_argument('--load-model', type=str, default='',
                        help='Previous model to load.')
    parser.add_argument('--test', type=int, default=0,
                        help='Executing mode, 1 for test, 0 no test')
    parser.add_argument('--test-scp', type=str, default='',
                        help='Test data file.')
    parser.add_argument('--output-file', type=str, default='',
                        help='Test output file')
    parser.add_argument('--region-output-file', type=str, default='',
                        help='Region output file')
    parser.add_argument('--log-interval', type=int, default=1000, metavar='N',
                        help='How many batches to wait before logging training status.')
    parser.add_argument('--num-workers', type=int, default=1, metavar='N',
                        help='How many workers used to load data')
    parser.add_argument('--config-file', type=str, default='',
                        help='config file in yaml format')
    args = parser.parse_args()

    if args.config_file != '':
        cfg_from_file(args.config_file)
    
    return args

def get_new_target(device, target, num_p, num_n):
    new_target=[]
    for i in range(target.size(0)):
        if target[i][0] == 0:
            new_target += ([target[i][0]] * num_n)
        else:
            new_target += ([target[i][0]] * num_p)
    return torch.LongTensor(new_target).to(device)

def adjust_learning_rate(args, optimizer):
    """Half the learning rate when relative improvement is too low.
    Args:
        args: Arguments for training.
        optimizer: Optimizer for training.
    """
    args.learning_rate *= args.halving_factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = args.learning_rate

def train(args, model, device, train_loader, optimizer, epoch):
    """Train one epoch."""
    tr_rpn_loss_bbox = AverageMeter()
    tr_rpn_loss_cls = AverageMeter()
    tr_loss = AverageMeter()
    tr_rpn_acc = AverageMeter()
    model.train()
    total_step = len(train_loader)
    balance_weight=args.lambda_factor
    for batch_idx, (utt_id, act_lens, data, target) in enumerate(train_loader):
        act_lens, data, target = act_lens.to(device), data.to(device), target.to(device)
        target = target.reshape(target.size(0), 1, target.size(1)).float()
        # Forward pass
        batch_size = data.shape[0]
        outputs = model(epoch, data, act_lens, target, 100)
        rois, rpn_cls_score, rpn_label, rpn_loss_cls, rpn_loss_bbox = outputs
        rpn_acc = acc_frame(rpn_cls_score, rpn_label)
        
        # Backward and optimize
        loss = rpn_loss_cls + balance_weight * rpn_loss_bbox
        optimizer.zero_grad()
        loss.backward()
        #name, param=list(model.named_parameters())[1]
        #print('Epoch:[{}/{}], param name:{},\n param:'.format(epoch+1, args.max_epochs, name, param))
        optimizer.step()

        tr_rpn_acc.update(rpn_acc, 1)
        tr_loss.update(loss, 1)
        tr_rpn_loss_cls.update(rpn_loss_cls, 1)
        tr_rpn_loss_bbox.update(rpn_loss_bbox, 1)

        if batch_idx % args.log_interval == 0:
            print('Epoch: [{}/{}], Step [{}/{}], Train Loss: {:.4f}, Train RPN Acc: {:.4f}%'
                  .format(epoch+1, args.max_epochs, batch_idx+1, total_step, tr_loss.cur,  tr_rpn_acc.cur))
            print('Epoch: [{}/{}], Step [{}/{}], Train RPN cls Loss: {:.4f}, Train RPN bbox Loss: {:.4f} '
                  .format(epoch+1, args.max_epochs, batch_idx+1, total_step, tr_rpn_loss_cls.cur, tr_rpn_loss_bbox.cur))

    print('Epoch: [{}/{}], Average Train Loss: {:.4f}, Average Train RPN cls Loss: {:.4f}, Average Train RPN bbox Loss: {:.4f}, AverageAverage Train RPN Acc: {:.4f}%'
         .format(epoch+1, args.max_epochs, tr_loss.avg, tr_rpn_loss_cls.avg, tr_rpn_loss_bbox.avg, tr_rpn_acc.avg))
    return float("{:.4f}".format(tr_loss.avg))


def validate(args, model, device, dev_loader, epoch):
    """Cross validate the model."""
    meter_rpn_loss_bbox = AverageMeter()
    balance_weight = args.lambda_factor
    meter_rpn_loss_cls = AverageMeter()
    meter_loss = AverageMeter()
    meter_rpn_acc = AverageMeter()
    balance_weight = args.lambda_factor
    with torch.no_grad():
        total_step = len(dev_loader)
        for batch_idx, (utt_id, act_lens, data, target) in enumerate(dev_loader):
            act_lens, data, target = act_lens.to(device), data.to(device), target.to(device)
            target = target.reshape(target.size(0), 1, target.size(1)).float()
            # Forward pass
            batch_size = data.shape[0]
            outputs = model(epoch, data, act_lens, target, 100)
            rois, rpn_cls_score, rpn_label, rpn_loss_cls, rpn_loss_bbox = outputs
            rpn_acc = acc_frame(rpn_cls_score, rpn_label)
            # Backward and optimize
            loss = rpn_loss_cls + balance_weight * rpn_loss_bbox 
            meter_rpn_acc.update(rpn_acc, 1)
            meter_loss.update(loss, 1)
            meter_rpn_loss_cls.update(rpn_loss_cls, 1)
            meter_rpn_loss_bbox.update(rpn_loss_bbox, 1)

            if batch_idx % args.log_interval == 0:
                print('Epoch: [{}/{}], Step [{}/{}], Val Loss: {:.4f}, Val RPN Acc: {:.4f}% '
                      .format(epoch+1, args.max_epochs, batch_idx+1, total_step, meter_loss.cur, meter_rpn_acc.cur))
                print('Epoch: [{}/{}], Step [{}/{}], Val RPN cls Loss: {:.4f}, Val RPN bbox Loss: {:.4f} '
                      .format(epoch+1, args.max_epochs, batch_idx+1, total_step, meter_rpn_loss_cls.cur, meter_rpn_loss_bbox.cur))

        print('Epoch: [{}/{}], Average Val Loss: {:.4f}, Average Val RPN cls Loss: {:.4f}, Average Val RPN bbox Loss: {:.4f}, Average Val RPN Acc: {:.4f}%'
             .format(epoch+1, args.max_epochs, meter_loss.avg, meter_rpn_loss_cls.avg, meter_rpn_loss_bbox.avg, meter_rpn_acc.avg))
    return float("{:.4f}".format(meter_loss.avg))

def test(args, model, device, test_loader, output_file, region_output_file):
    """Test the model"""
    write_post = open_or_fd(output_file, "wb")                                  
    fid = open(region_output_file, "w")
    model.eval()
    with torch.no_grad():
        total_step = len(test_loader)
        for batch_idx, (utt_ids, act_lens, data, target) in enumerate(test_loader):
            act_lens, data, target = act_lens.to(device), data.to(device), target.to(device)
            target = target.reshape(target.size(0), 1, target.size(1)).float()
            # Forward pass
            batch_size = data.shape[0]
            max_lens = data.shape[1]
            num_anchors_per_frame = args.num_anchor
            num_classes = args.output_dim
            outputs = model(0, data, act_lens, target, 100)
            rois, rpn_cls_score, anchors_per_utt = outputs
            rpn_cls_prob = F.softmax(rpn_cls_score, dim=2)
            disable_indexes = get_out_utt_boxes(anchors_per_utt, act_lens, batch_size)
            rpn_cls_prob[disable_indexes] = 0
            rpn_cls_prob = rpn_cls_prob.view(batch_size, max_lens, num_anchors_per_frame, num_classes)
            rois = rois.view(batch_size, max_lens, num_anchors_per_frame, 2)
            anchors_per_utt = anchors_per_utt.view(max_lens, num_anchors_per_frame, 2)
            rpn_cls_prob, arg_max_anchor = torch.max(rpn_cls_prob, dim=2)
            max_score, arg_max_score = torch.max(rpn_cls_prob, dim=1) # get the index of each utterance
            data_write = rpn_cls_prob.cpu().numpy()
            for i in range (len(utt_ids)):
                utt_id = utt_ids[i]
                act_len = act_lens[i]
                write_mat(write_post, data_write[i,0:act_len,:], utt_id)
                fid.writelines(utt_id)
                label = target[i][0].cpu().numpy()
                fid.writelines(", %f %f %f"%(label[0],label[1],label[2])) 
                for j in range(num_classes-1):
                    best_score1 = max_score[i][1+j]
                    best_frame1 = arg_max_score[i][1+j]
                    best_anchor1 = arg_max_anchor[i][best_frame1][1+j]
                    roi1 = rois[i][best_frame1][best_anchor1] # anchor of keyword 1
                    anchor1 = anchors_per_utt[best_frame1][best_anchor1]
                    roi1=roi1.cpu().numpy()
                    anchor1 = anchor1.cpu().numpy()
                    fid.writelines(", %f %f %f, %f %f %f"%(best_score1, anchor1[0], anchor1[1], best_score1, roi1[0], roi1[1]))
                fid.writelines("\n")
    write_post.close()
    fid.close()
def main():
    args = get_args()

    device = torch.device('cuda' if args.use_cuda else 'cpu')
    torch.manual_seed(args.seed)
    if args.encoder=='gru':
        feature_extractor = GRU(input_size=args.input_dim, 
                output_size=args.hidden_dim, 
                hidden_size=args.hidden_dim, 
                num_layers=args.num_layers, 
                bias=True, batch_first=True, 
                dropout=args.dropout, 
                bidirectional=False, 
                output_layer=False)
    else:
        print("unsupported feature extractor: %s"%args.encoder)
        exit(1)

    rpn = RPN(128, args.num_anchor, args.output_dim)

    model = RPN_KWS(feature_extractor, rpn, args.output_dim).to(device)

    params = count_parameters(model)                                            
    print("Num parameters: %d, Num Flops: %d\n"%(params,0))
    
    if args.multi_gpu:
        model = nn.DataParallel(model)

    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=args.init_weight_decay)

    print("Global Config:\n {}".format(cfg))
    print("Training Arguments:\n {}".format(args))
    print("Training Model:\n {}".format(model))
    print("Training Optimizer:\n {}".format(optimizer))

    # Load previous trained model
    if args.load_model != '':
        print("=> Loading previous checkpoint to train: {}".format(args.load_model))
        checkpoint = torch.load(args.load_model)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        prev_val_loss = checkpoint['prev_val_loss']
    elif not args.train:
        sys.exit("Option --load-model should not be empty for testing.")
    else:
        print("=> No checkpoint found.")
        prev_val_loss = float('inf')

    # For training
    if args.train:
        if args.train_scp == '' or args.dev_scp == '':
            sys.exit("Options --train-scp and --dev-scp are required for training.")

        if args.save_dir == '':
            sys.exit("Option --save-dir is required to save model.")

        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)

        halving = 0
        best_model = args.load_model
        kwargs = {'num_workers': 3, 'pin_memory': True} if args.use_cuda else {}

        # Training data loader
        train_set = StreamingTorchDataset(args.train_scp, ["kaldi_reader", "raw_list_reader"], args.left_context, args.right_context)
        train_loader = torch.utils.data.DataLoader(
            dataset=train_set,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            collate_fn=collate_fn)

        # Dev data loader
        dev_set = StreamingTorchDataset(args.dev_scp,["kaldi_reader", "raw_list_reader"], args.left_context, args.right_context)
        dev_loader = torch.utils.data.DataLoader(
            dataset=dev_set,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            collate_fn=collate_fn)

        for epoch in range(args.max_epochs):
            cur_tr_loss = train(args, model, device, train_loader,optimizer, epoch)
            cur_val_loss = validate(args, model, device, dev_loader, epoch)
            rel_impr = (prev_val_loss - cur_val_loss) / prev_val_loss

            model_name = 'nnet_epoch' + str(epoch+1) + '_lr' \
                        + str(args.learning_rate) + '_tr' + str(cur_tr_loss) \
                        + '_cv' + str(cur_val_loss) + '.ckpt'
            model_path = args.save_dir + '/' + model_name

            if cur_val_loss < prev_val_loss:

                prev_val_loss = cur_val_loss
                torch.save({
                    'prev_val_loss': prev_val_loss,
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, model_path)
                best_model = model_path

                print("Model {} accepted. Time: {}".format(model_name,
                                                           datetime.datetime.now()))

            else:
                print ("Model {} rejected. Time: {}".format(model_name,
                                                            datetime.datetime.now()))
                if best_model != '':
                    print("=> Loading best checkpoint: {}".format(best_model))
                    checkpoint = torch.load(best_model)
                    model.load_state_dict(checkpoint['model'])
                    optimizer.load_state_dict(checkpoint['optimizer'])
                    prev_val_loss = checkpoint['prev_val_loss']
                else:
                    sys.exit("Error training neural network.")

            # Stopping training criterion
            if halving and rel_impr < args.end_halving_impr:
                if epoch < args.min_epochs:
                    print("We were supposed to finish, but we continue as min_epochs"
                          .format(args.min_epochs))
                    continue
                else:
                    print("Finished, too small relative improvement {}".format(rel_impr))
                    break

            # Start halving when improvement is low
            if rel_impr < args.start_halving_impr:
                halving = 1

            if halving:
                adjust_learning_rate(args, optimizer)
                print("Halving learning rate to {}".format(args.learning_rate))

        if best_model != args.load_model:
            final_model = args.save_dir + "/final.mdl"
            shutil.copyfile(best_model, final_model)
            print("Succeeded training the neural network: {}/final.mdl"
                  .format(args.save_dir))
        else:
            sys.exit("Error training neural network.")
    # For testing
    if args.test:
        # Test data loader
        if args.test_scp == '' or args.output_file == '':
            sys.exit("Options --test-scp and --output-file are required for testing")
        test_set = StreamingTorchDataset(args.test_scp,["kaldi_reader", "raw_list_reader"], args.left_context, args.right_context)
        test_loader = torch.utils.data.DataLoader(
            dataset=test_set,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.num_workers,
            collate_fn=collate_fn)
        test(args, model, device, test_loader, args.output_file, args.region_output_file) 


if __name__ == '__main__':
    main()