import math
import multiprocessing as mp
import os
from pathlib import Path
import pickle
import random

import numpy as np
from numpy.polynomial import legendre

import torch
from torch import nn
from torch import optim

from sacred import Experiment
from sacred.observers import FileStorageObserver, SlackObserver

import ray
from ray.tune import Experiment as RayExperiment, sample_from, run_experiments
from ray.tune.schedulers import AsyncHyperBandScheduler

from hstack_diag import HstackDiagProduct
from utils import PytorchTrainable, bitreversal_permutation


N_LBFGS_STEPS = 300  # L-BFGS iterations used to polish each candidate solution
N_TRIALS_TO_POLISH = 20  # How many of the best trials to polish after tuning
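
# Overall strategy: random-search over learning rate and seed with Ray Tune,
# training with Adam; then polish the best trials with L-BFGS (see polish_ops).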


class TrainableOps(PytorchTrainable):

    def _setup(self, config):
        torch.manual_seed(config['seed'])
        self.model = HstackDiagProduct(size=config['size'])
        self.optimizer = optim.Adam(self.model.parameters(), lr=config['lr'])
        self.n_steps_per_epoch = config['n_steps_per_epoch']
        size = config['size']
        # Target: Legendre polynomials. Row i of P holds the coefficients of
        # the degree-i Legendre polynomial in the monomial basis; leg2poly
        # trims trailing zeros, so row i has i + 1 entries.
        P = np.zeros((size, size), dtype=np.float64)
        for i, coef in enumerate(np.eye(size)):
            P[i, :i + 1] = legendre.leg2poly(coef)
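        # e.g. for size = 4, the rows of P are [1, 0, 0, 0], [0, 1, 0, 0],
        # [-1/2, 0, 3/2, 0], [0, -3/2, 0, 5/2]: the coefficients of P_0..P_3.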
        self.target_matrix = torch.tensor(P)
        self.br_perm = bitreversal_permutation(size)
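        # Identity input in the model's (size, size, 2, 2, 1) layout: feeding
        # the identity makes the model's output the matrix its factors
        # represent. The bit-reversal permutation reorders the second axis,
        # presumably the order in which HstackDiagProduct consumes its input.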
        self.input = (torch.eye(size)[:, :, None, None] * torch.eye(2)).unsqueeze(-1)
        self.input_permuted = self.input[:, self.br_perm]

    def _train(self):
        for _ in range(self.n_steps_per_epoch):
            self.optimizer.zero_grad()
            y = self.model(self.input_permuted)
            loss = nn.functional.mse_loss(y.double(), self.target_matrix)
            loss.backward()
            self.optimizer.step()
        return {'negative_loss': -loss.item()}  # Tune maximizes this reward attribute


def polish_ops(trial):
    """Load model from checkpoint, and re-optimize using L-BFGS to find
    the nearest local optimum.
    """
    trainable = globals()[trial.trainable_name](trial.config)  # e.g. TrainableOps
    trainable.restore(str(Path(trial.logdir) / trial._checkpoint.value))
    model = trainable.model
    config = trial.config
    polished_model = HstackDiagProduct(size=config['size'])
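    # Start L-BFGS from the Adam solution by reusing the trained parameters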
    polished_model.factors = model.factors
    polished_model.P_init = model.P_init
    optimizer = optim.LBFGS(polished_model.parameters())
    def closure():
        optimizer.zero_grad()
        eye = torch.eye(polished_model.size)
        x = (eye[:, :, None, None] * torch.eye(2)).unsqueeze(-1)
        y = polished_model(x[:, trainable.br_perm])
        loss = nn.functional.mse_loss(y.double(), trainable.target_matrix)  # cast to double, matching _train
        loss.backward()
        return loss
    for i in range(N_LBFGS_STEPS):
        optimizer.step(closure)
    checkpoint_dir = (Path(trial.logdir) / trial._checkpoint.value).parent
    torch.save(polished_model.state_dict(), str(checkpoint_dir / 'polished_model.pth'))
    # Recompute the final loss for reporting
    eye = torch.eye(polished_model.size)
    x = (eye[:, :, None, None] * torch.eye(2)).unsqueeze(-1)
    y = polished_model(x[:, trainable.br_perm])
    loss = nn.functional.mse_loss(y.double(), trainable.target_matrix)
    return loss.item()


ex = Experiment('Ops_factorization')
ex.observers.append(FileStorageObserver.create('logs'))
slack_config_path = Path('config/slack.json')  # Put a webhook_url there to enable Slack notifications
if slack_config_path.exists():
    ex.observers.append(SlackObserver.from_config(str(slack_config_path)))


@ex.config
def fixed_order_config():
    size = 8  # Size of matrix to factor, must be power of 2
    ntrials = 20  # Number of trials for hyperparameter tuning
    nsteps = 400  # Number of steps per epoch
    nmaxepochs = 200  # Maximum number of epochs
    result_dir = 'results'  # Directory to store results
    nthreads = 1  # Number of CPU threads per job
    smoke_test = False  # Finish quickly for testing
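    # Any value above can be overridden on the command line via Sacred,
    # e.g.: python <this script> with size=16 ntrials=50 smoke_test=True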


@ex.capture
def ops_experiment(size, ntrials, nsteps, result_dir, nthreads, smoke_test):
    config = {
        'size': size,
        # Learning rate sampled log-uniformly from [1e-4, 5e-1]
        'lr': sample_from(lambda spec: math.exp(random.uniform(math.log(1e-4), math.log(5e-1)))),
        'seed': sample_from(lambda spec: random.randint(0, 1 << 16)),
        'n_steps_per_epoch': nsteps,
    }
    experiment = RayExperiment(
        name=f'Ops_factorization_{size}',
        run=TrainableOps,
        local_dir=result_dir,
        num_samples=ntrials,
        checkpoint_at_end=True,
        resources_per_trial={'cpu': nthreads, 'gpu': 0},
        stop={
            'training_iteration': 1 if smoke_test else 99999,
            'negative_loss': -1e-8,  # i.e. stop a trial once its loss drops below 1e-8
        },
        config=config,
    )
    return experiment


@ex.automain
def run(result_dir, nmaxepochs, nthreads):
    experiment = ops_experiment()
    # Limit each trial to nthreads CPU threads; Ray runs many trials
    # concurrently, so unbounded MKL/OpenMP threading would thrash the cores
    os.environ['MKL_NUM_THREADS'] = str(nthreads)
    os.environ['OMP_NUM_THREADS'] = str(nthreads)
    ray.init()
    # Asynchronous HyperBand early-stops trials whose negative_loss lags behind
    ahb = AsyncHyperBandScheduler(reward_attr='negative_loss', max_t=nmaxepochs)
    trials = run_experiments(experiment, scheduler=ahb, raise_on_failed_trial=False)
    losses = [-trial.last_result['negative_loss'] for trial in trials]

    # Polish solutions with L-BFGS
    with mp.Pool() as pool:
        # Best trials first (smallest loss, i.e. largest negative_loss)
        sorted_trials = sorted(trials, key=lambda trial: -trial.last_result['negative_loss'])
        polished_losses = pool.map(polish_ops, sorted_trials[:N_TRIALS_TO_POLISH])
    for trial, polished_loss in zip(sorted_trials[:N_TRIALS_TO_POLISH], polished_losses):
        trial.last_result['polished_negative_loss'] = -polished_loss
    print(np.array(losses))
    print(np.sort(losses))
    print(np.sort(polished_losses))

    checkpoint_path = Path(result_dir) / experiment.name
    checkpoint_path.mkdir(parents=True, exist_ok=True)
    checkpoint_path /= 'trial.pkl'
    with checkpoint_path.open('wb') as f:
        pickle.dump(trials, f)

    ex.add_artifact(str(checkpoint_path))
    return min(losses + polished_losses)

# TODO: there may be a memory leak; the commented-out sweep below prints every
# live tensor (found via gc) to help track it down.
# import gc
# import operator as op
# from functools import reduce

# for obj in gc.get_objects():
#     try:
#         if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
#             print(reduce(op.mul, obj.size()) if len(obj.size()) > 0 else 0, type(obj), obj.size())
#     except Exception:
#         pass