"""
Copyright 2019, ETH Zurich

This file is part of L3C-PyTorch.

L3C-PyTorch is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
any later version.

L3C-PyTorch is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with L3C-PyTorch.  If not, see <https://www.gnu.org/licenses/>.
"""
import os
import time
import re
import shutil

import pytorch_ext as pe
from os.path import basename
import torch
from torch.optim import optimizer
from fjcommon.no_op import NoOp
from fjcommon import timer
from fjcommon.assertions import assert_exc


class _CheckpointTracker(object):
    """ out_dir is usally set via set_out_dir """
    def __init__(self, out_dir=None, ckpt_name_fmt='ckpt_{:010d}.pt', tmp_postfix='.tmp'):
        assert len(tmp_postfix)
        assert '.' in tmp_postfix
        m = re.search(r'{:0(\d+?)d}', ckpt_name_fmt)
        assert m, 'Expected ckpt_name_fmt to have an int specifier such as or {:09d} or {:010d}.'
        max_itr = 10 ** int(m.group(1)) - 1
        if max_itr < 10000000:  # ten million, should be enough
            print(f'Maximum iteration supported: {max_itr}')
        assert os.sep not in ckpt_name_fmt
        self.ckpt_name_fmt = ckpt_name_fmt
        self.ckpt_prefix = ckpt_name_fmt.split('{')[0]
        assert len(self.ckpt_prefix), 'Expected ckpt_name_fmt to start with a prefix before the format part!'
        self.tmp_postfix = tmp_postfix

        self._out_dir = None
        if out_dir is not None:
            self.set_out_dir(out_dir)

    def set_out_dir(self, out_dir):
        assert self._out_dir is None
        os.makedirs(out_dir, exist_ok=True)
        self._out_dir = out_dir

    def get_all_ckpts(self):
        """
        :return: All checkpoints in `self._out_dir`, sorted ascendingly by global_step.
        """
        return [os.path.join(self._out_dir, f)
                for f in sorted(os.listdir(self._out_dir))
                if f.startswith(self.ckpt_prefix)]

    def itr_ckpt(self):
        for ckpt_p in self.get_all_ckpts():
            yield self.get_itr_from_ckpt_p(ckpt_p), ckpt_p

    def get_ckpt_for_itr(self, itr):
        """
        Gets ckpt_itrc where itrc <= itr, i.e., the latest ckpt before `itr`.
        Special values: itr == -1 -> newest ckpt
        """
        ckpts = list(self.itr_ckpt())
        assert_exc(len(ckpts) > 0, 'No ckpts found in {}'.format(self._out_dir))
        if itr == -1:
            return ckpts[-1]
        first_itrc, _ = ckpts[0]
        assert_exc(first_itrc <= itr, 'Earliest ckpt {} is after {}'.format(first_itrc, itr))
        for itrc, ckpt_p in reversed(ckpts):
            if itrc <= itr:
                return itrc, ckpt_p
        raise ValueError('Unexpected, {}, {}'.format(itr, ckpts))

    def get_latest_ckpt(self):
        """
        :return: Most recent checkpoint. May be a temporary checkpoint.
        """
        return self.get_all_ckpts()[-1]

    def get_lastest_persistent_ckpt(self):
        """
        :return: Most recent persistent checkpoint. May be a temporary checkpoint.
        """
        candidates = [p for p in self.get_all_ckpts() if not p.endswith(self.tmp_postfix)]
        if len(candidates) == 0:
            raise ValueError('No persistent checkpoints')
        return candidates[-1]

    def _get_out_p(self, global_step, is_tmp):
        postfix = self.tmp_postfix if is_tmp else ''
        return os.path.join(self._out_dir, self.ckpt_name_fmt.format(global_step) + postfix)

    def get_itr_from_ckpt_p(self, ckpt_p):
        file_name = os.path.splitext(os.path.basename(ckpt_p))[0]
        assert self.ckpt_prefix in file_name
        itr_part = file_name.replace(self.ckpt_prefix, '')
        itr_part_digits_only = int(''.join(c for c in itr_part if c.isdigit()))
        return itr_part_digits_only



class Saver(_CheckpointTracker):
    """
    Saves ckpts:
    - ckpt_XXXXXXXX.pt.tmp
    If keep_tmp_last=None:
        Every `keep_every`-th ckpt is renamed to
        - ckpt_XXXXXXXX.pt
        and kept, the intermediate ones are removed. We call this a persistent checkpoint.
    else:
        Let C be the most recent persistent checkpoint.
        In addition to C being kept, the most recent `keep_tmp_last` temporary checkpoints are also kept.
        This means that always `keep_tmp_last` more checkpoints are kept than if keep_tmp_last=None
    """
    def __init__(self,
                 keep_tmp_itr: int, keep_every=10, keep_tmp_last=None,
                 out_dir=None, ckpt_name_fmt='ckpt_{:010d}.pt', tmp_postfix='.tmp',
                 verbose=False):
        """
        :param keep_every: keep every `keep_every`-th checkpoint, making it a persistent checkpoint
        :param keep_tmp_itr: keep checkpoint every `keep_tmp_itr` iterations.
        :param keep_tmp_last: Also keep the last `keep_tmp_last` temporary checkpoints when a persistent checkpoint
            is created.
        :param out_dir: output directory, may also be set later via `set_out_dir`.
        :param ckpt_name_fmt: filename, must include a format spec and some prefix before the format
        :param tmp_postfix: non-empty string to append to temporary checkpoints
        :param verbose: if True, print rename and remove info.
        """
        self.keep_every = keep_every
        self.keep_tmp_last = keep_tmp_last
        self.keep_tmp_itr = keep_tmp_itr
        # Number of saves since the last promotion to a persistent checkpoint.
        self.ckpts_since_last_permanent = 0
        self.print = print if verbose else NoOp
        self.save_time_acc = timer.TimeAccumulator()
        super(Saver, self).__init__(out_dir, ckpt_name_fmt, tmp_postfix)

    def save(self, modules, global_step, force=False):
        """
        Save iff (force given or global_step % keep_tmp_itr == 0)
        :param modules: dictionary name -> object with a `state_dict()` method (e.g. nn.Module)
        :param global_step: current step
        :param force: save even if global_step is not a multiple of keep_tmp_itr
        :return: bool, True iff this save was made persistent and previous temporary checkpoints were removed
        """
        if not (force or (global_step % self.keep_tmp_itr == 0)):
            return False
        assert self._out_dir is not None
        current_ckpt_p = self._save(modules, global_step)
        self.ckpts_since_last_permanent += 1
        if self.ckpts_since_last_permanent == self.keep_every:
            self._remove_previous(current_ckpt_p)
            self.ckpts_since_last_permanent = 0
            return True
        return False

    def _save(self, modules, global_step):
        """ Write the state dicts of `modules` to a temporary checkpoint (timed); return its path. """
        out_p = self._get_out_p(global_step, is_tmp=True)
        with self.save_time_acc.execute():
            torch.save({key: m.state_dict() for key, m in modules.items()}, out_p)
        return out_p

    def _remove_previous(self, current_ckpt_p):
        """
        Promote the temporary checkpoint `current_ckpt_p` to a persistent one by stripping `tmp_postfix`, then
        delete all temporary checkpoints except the most recent `keep_tmp_last` ones (if set).
        """
        assert current_ckpt_p.endswith(self.tmp_postfix)
        # Strip only the trailing postfix. str.replace would also rewrite any occurrence of the postfix in the
        # directory part of the path.
        current_ckpt_p_non_tmp = current_ckpt_p[:-len(self.tmp_postfix)]
        self.print('{} -> {}'.format(basename(current_ckpt_p), basename(current_ckpt_p_non_tmp)))
        os.rename(current_ckpt_p, current_ckpt_p_non_tmp)
        # Classify by suffix (as in get_lastest_persistent_ckpt) so persistent checkpoints are never deleted, and
        # select the kept ones among temporary checkpoints only.
        tmp_ckpts = [p for p in self.get_all_ckpts() if p.endswith(self.tmp_postfix)]
        keep_tmp_last = set(tmp_ckpts[-self.keep_tmp_last:]) if self.keep_tmp_last else set()
        for p in tmp_ckpts:
            if p not in keep_tmp_last:
                self.print('Removing {}...'.format(basename(p)))
                os.remove(p)
        self.print('Average save time: {:.3f}s'.format(self.save_time_acc.mean_time_spent()))


class Restorer(_CheckpointTracker):
    """ Loads checkpoints written by `Saver` back into modules and optimizers. """
    def restore_latest_persistent(self, net):
        """
        Restore from the most recent persistent (non-temporary) checkpoint.

        :param net: dictionary name -> nn.Module or Optimizer, as accepted by `restore`.
        :return: the global step of the restored checkpoint.
        """
        return self.restore(net, self.get_lastest_persistent_ckpt())

    def restore(self, modules, ckpt_p, strict=True, restore_restart=False):
        """
        Load the state dicts stored in `ckpt_p` into `modules`.

        :param modules: dictionary name -> nn.Module or torch Optimizer; each name must be a key in the checkpoint.
        :param ckpt_p: path of the checkpoint file.
        :param strict: forwarded to `load_state_dict` of non-optimizer modules (optimizers do not accept it).
        :param restore_restart: if True, optimizers are not restored.
        :return: the global step parsed from `ckpt_p`.
        :raises ValueError: if an optimizer fails to load its state dict.
        :raises RuntimeError: if a module fails to load its state dict; its submodules are printed first.
        """
        print('Restoring {}... (strict={})'.format(ckpt_p, strict))
        # Map tensors to the CPU when CUDA is unavailable.
        map_location = None if pe.CUDA_AVAILABLE else 'cpu'
        # NOTE: torch.load unpickles the file; only restore checkpoints from trusted sources.
        state_dicts = torch.load(ckpt_p, map_location=map_location)
        # ---
        for key, m in modules.items():
            # optim implements its own load_state_dict which does not have the `strict` keyword...
            if isinstance(m, optimizer.Optimizer):
                if restore_restart:
                    print('Not restoring optimizer, --restore_restart given...')
                    continue
                try:
                    m.load_state_dict(state_dicts[key])
                except ValueError as e:
                    # Format the message properly (passing two args produced a tuple message) and keep the cause.
                    raise ValueError('Error while restoring Optimizer: {}'.format(e)) from e
            else:
                try:
                    m.load_state_dict(state_dicts[key], strict=strict)
                except RuntimeError:  # loading error
                    for n, module in sorted(m.named_modules()):
                        print(n, module)
                    raise
        return self.get_itr_from_ckpt_p(ckpt_p)