# coding=utf-8
# Copyright (C) 2019 Alibaba Group Holding Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import math
import random
import torch
import torch.nn.functional as f
from tqdm import tqdm
from .network import Network
from .utils.metrics import registry as metrics


class Model:
    """Training/evaluation wrapper around ``Network`` and its Adam optimizer.

    Bundles the network, the optimizer, the learning-rate schedule and the
    number of optimizer updates performed so far, and provides helpers to
    train on a batch, evaluate, predict and (de)serialize checkpoints.
    """

    # Prefix used for periodic checkpoint file names (see ``save``).
    prefix = 'checkpoint'
    # Conventional file name for the best checkpoint; not referenced inside
    # this class.
    best_model_name = 'best.pt'

    def __init__(self, args, state_dict=None):
        """Build the network and optimizer, optionally restoring saved state.

        Args:
            args: configuration object. This class reads ``cuda``, ``lr``,
                ``beta1``, ``beta2`` and ``weight_decay`` here, plus further
                attributes in the other methods.
            state_dict: optional mapping with keys ``'model'``, ``'opt'`` and
                ``'updates'`` as written by ``save``. NOTE: entries of
                ``state_dict['model']`` that the freshly built network does
                not know are deleted from the passed mapping in place.
        """
        self.args = args

        # network
        self.network = Network(args)
        # On CUDA this is the index of the current device; otherwise the CPU.
        self.device = torch.cuda.current_device() if args.cuda else torch.device('cpu')
        self.network.to(self.device)
        # optimizer (only parameters that require gradients are optimized)
        self.params = list(filter(lambda x: x.requires_grad, self.network.parameters()))
        self.opt = torch.optim.Adam(self.params, args.lr, betas=(args.beta1, args.beta2),
                                    weight_decay=args.weight_decay)
        # updates
        self.updates = state_dict['updates'] if state_dict else 0

        if state_dict:
            # Drop saved entries the current network does not have so that
            # the strict ``load_state_dict`` below does not reject them.
            new_state = set(self.network.state_dict().keys())
            for k in list(state_dict['model'].keys()):
                if k not in new_state:
                    del state_dict['model'][k]
            self.network.load_state_dict(state_dict['model'])
            self.opt.load_state_dict(state_dict['opt'])

    def _update_schedule(self):
        """Set the learning rate of the first param group from ``self.updates``.

        Does nothing when ``args.lr_decay_rate >= 1``. Otherwise the rate is
        ``args.lr * ratio`` where ``ratio`` rises linearly from
        ``args.min_lr / args.lr`` towards 1 during the first
        ``args.lr_warmup_steps`` updates, and afterwards is multiplied by
        ``args.lr_decay_rate`` once per ``args.lr_decay_steps`` updates,
        never dropping below ``args.min_lr / args.lr``.
        """
        if self.args.lr_decay_rate < 1.:
            args = self.args
            t = self.updates
            base_ratio = args.min_lr / args.lr
            if t < args.lr_warmup_steps:
                # max(1., ...) guards against division by zero.
                ratio = base_ratio + (1. - base_ratio) / max(1., args.lr_warmup_steps) * t
            else:
                ratio = max(base_ratio, args.lr_decay_rate ** math.floor((t - args.lr_warmup_steps) /
                                                                         args.lr_decay_steps))
            self.opt.param_groups[0]['lr'] = args.lr * ratio

    def update(self, batch):
        """Run one optimization step on ``batch``.

        Args:
            batch: mapping accepted by ``process_data``; must contain
                ``'target'``.

        Returns:
            dict with ``'updates'`` (count after this step), ``'loss'``
            (float), ``'lr'`` (learning rate of the first param group),
            ``'gnorm'`` (gradient norm reported by the clipping call, as a
            float) and ``'summary'`` (whatever ``network.get_summary()``
            returned).

        Raises:
            AssertionError: if the gradient norm is NaN.
        """
        self.network.train()
        self.opt.zero_grad()
        inputs, target = self.process_data(batch)
        output = self.network(inputs)
        summary = self.network.get_summary()
        loss = self.get_loss(output, target)
        loss.backward()
        # clip_grad_norm_ may return a 0-dim tensor; convert it so the stats
        # dict only holds plain Python numbers.
        grad_norm = float(torch.nn.utils.clip_grad_norm_(self.params, self.args.grad_clipping))
        # NaN compares false against everything, so this catches NaN norms.
        assert grad_norm >= 0, 'encounter nan in gradients.'
        self.opt.step()
        # The schedule is evaluated with the pre-increment update count, so
        # the rate set here is the one used by the next step.
        self._update_schedule()
        self.updates += 1
        stats = {
            'updates': self.updates,
            'loss': loss.item(),
            'lr': self.opt.param_groups[0]['lr'],
            'gnorm': grad_norm,
            'summary': summary,
        }
        return stats

    def evaluate(self, data):
        """Evaluate the network on ``data`` and compute the watched metrics.

        Args:
            data: sliceable sequence of batches; only the first
                ``args.eval_subset`` batches are used.

        Returns:
            ``(eval_score, stats)`` where ``eval_score`` is
            ``stats[args.metric]`` and ``stats`` additionally holds
            ``'updates'``, ``'loss'``, every value produced by the metric
            functions listed in ``args.watch_metrics``, and ``'score'``
            (equal to ``eval_score``).

        Raises:
            AssertionError: if a metric function emits a key named ``'score'``.
        """
        self.network.eval()
        targets = []
        probabilities = []
        predictions = []
        losses = []
        for batch in tqdm(data[:self.args.eval_subset], desc='evaluating', leave=False):
            inputs, target = self.process_data(batch)
            with torch.no_grad():
                output = self.network(inputs)
                loss = self.get_loss(output, target)
                pred = torch.argmax(output, dim=1)
                prob = f.softmax(output, dim=1)
                losses.append(loss.item())
                targets.extend(target.tolist())
                probabilities.extend(prob.tolist())
                predictions.extend(pred.tolist())
        outputs = {
            'target': targets,
            'prob': probabilities,
            'pred': predictions,
            'args': self.args,
        }
        stats = {
            'updates': self.updates,
            # Mean of the per-batch losses excluding the last batch when
            # there is more than one.
            # NOTE(review): presumably the last batch is skipped because it
            # may be smaller than the others -- confirm.
            'loss': sum(losses[:-1]) / (len(losses) - 1) if len(losses) > 1 else sum(losses),
        }
        for metric in self.args.watch_metrics:
            if metric not in stats:  # multiple metrics could be computed by the same function
                stats.update(metrics[metric](outputs))
        assert 'score' not in stats, 'metric name collides with "score"'
        eval_score = stats[self.args.metric]
        stats['score'] = eval_score
        return eval_score, stats  # first value is for early stopping

    def predict(self, batch):
        """Return softmax probabilities for ``batch`` as a nested list.

        Any ``'target'`` entry in ``batch`` is ignored.
        """
        self.network.eval()
        inputs, _ = self.process_data(batch)
        with torch.no_grad():
            output = self.network(inputs)
            output = f.softmax(output, dim=1)
        return output.tolist()

    def process_data(self, batch):
        """Convert a raw batch into network inputs on ``self.device``.

        Args:
            batch: mapping with ``'text1'`` and ``'text2'`` (integer ids that
                ``torch.LongTensor`` can consume) and optionally ``'target'``.
                The ``unsqueeze(2)`` below assumes the id tensors are 2-D
                (presumably batch x sequence -- verify against the caller).

        Returns:
            ``(inputs, target)``; ``inputs`` holds the two id tensors and
            their masks (true where the id differs from ``args.padding``,
            with an extra trailing dimension), ``target`` is a ``LongTensor``
            or ``None`` when the batch has no ``'target'``.
        """
        text1 = torch.LongTensor(batch['text1']).to(self.device)
        text2 = torch.LongTensor(batch['text2']).to(self.device)
        mask1 = torch.ne(text1, self.args.padding).unsqueeze(2)
        mask2 = torch.ne(text2, self.args.padding).unsqueeze(2)
        inputs = {
            'text1': text1,
            'text2': text2,
            'mask1': mask1,
            'mask2': mask2,
        }
        if 'target' in batch:
            target = torch.LongTensor(batch['target']).to(self.device)
            return inputs, target
        return inputs, None

    @staticmethod
    def get_loss(logits, target):
        """Return the cross-entropy loss between ``logits`` and ``target``."""
        return f.cross_entropy(logits, target)

    def save(self, states, name=None):
        """Write a checkpoint into ``args.summary_dir``.

        Args:
            states: mapping of extra entries merged into the checkpoint;
                because it is applied last it overrides same-named keys.
            name: file name to use; defaults to
                ``'<prefix>_<updates>.pt'``.
        """
        if name:
            filename = os.path.join(self.args.summary_dir, name)
        else:
            filename = os.path.join(self.args.summary_dir, f'{self.prefix}_{self.updates}.pt')
        params = {
            'state_dict': {
                'model': self.network.state_dict(),
                'opt': self.opt.state_dict(),
                'updates': self.updates,
            },
            'args': self.args,
            # RNG states are stored alongside the weights.
            'random_state': random.getstate(),
            'torch_state': torch.random.get_rng_state()
        }
        params.update(states)
        if self.args.cuda:
            params['torch_cuda_state'] = torch.cuda.get_rng_state()
        torch.save(params, filename)

    @classmethod
    def load(cls, file):
        """Rebuild a model from a checkpoint written by ``save``.

        Args:
            file: path of the checkpoint file.

        Returns:
            ``(model, checkpoint)`` where ``checkpoint`` is the full loaded
            dictionary.
        """
        # NOTE(security): the checkpoint stores arbitrary Python objects
        # (e.g. ``args``), so loading it unpickles data; only load
        # checkpoints from trusted sources.
        # All storages are restored onto the CPU regardless of where they
        # were saved from.
        checkpoint = torch.load(file, map_location=(
            lambda s, _: torch.serialization.default_restore_location(s, 'cpu')
        ))
        prev_args = checkpoint['args']
        # update args: derive the directories from where the file lives now
        # and only keep CUDA enabled if it is actually available.
        prev_args.output_dir = os.path.dirname(os.path.dirname(file))
        prev_args.summary_dir = os.path.join(prev_args.output_dir, prev_args.name)
        prev_args.cuda = prev_args.cuda and torch.cuda.is_available()
        return cls(prev_args, state_dict=checkpoint['state_dict']), checkpoint

    def num_parameters(self, exclude_embed=False):
        """Count the trainable parameters of the network.

        Args:
            exclude_embed: when true and ``args.fix_embeddings`` is false,
                subtract the size of the first parameter of
                ``network.embedding`` from the total.
        """
        num_params = sum(p.numel() for p in self.network.parameters() if p.requires_grad)
        if exclude_embed:
            num_params -= 0 if self.args.fix_embeddings else next(self.network.embedding.parameters()).numel()
        return num_params

    def set_embeddings(self, embeddings):
        """Forward ``embeddings`` to ``network.embedding.set_``."""
        self.network.embedding.set_(embeddings)