import os
import pickle
from pathlib import Path
import numpy as np

import multiprocessing as mp

import torch
from torch import nn
from torch import optim

import ray

from butterfly import ButterflyProduct
from learning_hadamard import TrainableHadamardFactorFixedOrder, TrainableHadamardFactorSoftmax, TrainableHadamardFactorSparsemax
from learning_fft import TrainableFftFactorFixedOrder, TrainableFftFactorSoftmax, TrainableFftFactorSparsemax


# Number of `optimizer.step(closure)` calls made on the L-BFGS optimizer while
# polishing a single trial.
N_LBFGS_STEPS = 300
# Number of lowest-loss trials of each experiment that get polished and have
# their polished loss recorded.
N_TRIALS_TO_POLISH = 20

# We'll use multiple processes so disable MKL multithreading
# NOTE(review): this assignment runs after numpy and torch have already been
# imported above; confirm the setting still takes effect in the worker
# processes (it may need to be set before those imports).
os.environ['MKL_NUM_THREADS'] = '1'

# @ray.remote
def polish_hadamard(trial):
    """Polish a trained Hadamard factorization with L-BFGS and return its loss.

    The trainable named by ``trial.trainable_name`` is rebuilt from
    ``trial.config`` and restored from the trial's checkpoint. A fixed-order
    ``ButterflyProduct`` is assembled from the restored model's butterfly
    factors and refined with ``N_LBFGS_STEPS`` calls to ``optimizer.step``,
    minimizing the MSE between ``polished_model.matrix()`` and
    ``trainable.target_matrix``. The polished ``state_dict`` is saved as
    ``polished_model.pth`` in the parent directory of the checkpoint path.

    Args:
        trial: Trial record exposing ``trainable_name``, ``config``,
            ``logdir`` and ``_checkpoint.value``.

    Returns:
        The final MSE loss as a Python number (``loss.item()``).
    """
    # NOTE(security): eval() executes whatever string the (unpickled) trial
    # carries in `trainable_name`; only run this on trial files you trust.
    trainable = eval(trial.trainable_name)(trial.config)
    checkpoint_path = Path(trial.logdir) / trial._checkpoint.value
    trainable.restore(str(checkpoint_path))
    model = trainable.model
    config = trial.config
    polished_model = ButterflyProduct(size=config['size'], complex=model.complex, fixed_order=True)
    if not model.fixed_order:
        # Learned ordering: keep, for every position, the factor with the
        # highest probability. No confidence threshold is applied to the
        # selection.
        prob = model.softmax_fn(model.logit)
        _, argmaxes = torch.max(prob, dim=-1)
        polished_model.butterflies = nn.ModuleList([model.butterflies[argmax] for argmax in argmaxes])
    else:
        polished_model.butterflies = model.butterflies
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        # L-BFGS may evaluate the objective several times per step, so the
        # loss and its gradients are recomputed on every call.
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(polished_model.matrix(), trainable.target_matrix)
        loss.backward()
        return loss

    for _ in range(N_LBFGS_STEPS):
        optimizer.step(closure)
    torch.save(polished_model.state_dict(), str(checkpoint_path.parent / 'polished_model.pth'))
    # Only the value is needed here, so skip building an autograd graph.
    with torch.no_grad():
        loss = nn.functional.mse_loss(polished_model.matrix(), trainable.target_matrix)
    return loss.item()


def polish_fft(trial):
    """Polish a trained FFT factorization with L-BFGS and return its loss.

    The trainable named by ``trial.trainable_name`` is rebuilt from
    ``trial.config`` and restored from the trial's checkpoint. A fixed-order
    ``ButterflyProduct`` is assembled from the restored model's butterfly
    factors and refined with ``N_LBFGS_STEPS`` calls to ``optimizer.step``,
    minimizing the MSE between ``polished_model.matrix()`` with its columns
    indexed by ``trainable.br_perm`` and ``trainable.target_matrix``. The
    polished ``state_dict`` is saved as ``polished_model.pth`` in the parent
    directory of the checkpoint path.

    Args:
        trial: Trial record exposing ``trainable_name``, ``config``,
            ``logdir`` and ``_checkpoint.value``.

    Returns:
        The final MSE loss as a Python number (``loss.item()``).
    """
    # NOTE(security): eval() executes whatever string the (unpickled) trial
    # carries in `trainable_name`; only run this on trial files you trust.
    trainable = eval(trial.trainable_name)(trial.config)
    checkpoint_path = Path(trial.logdir) / trial._checkpoint.value
    trainable.restore(str(checkpoint_path))
    model = trainable.model
    config = trial.config
    polished_model = ButterflyProduct(size=config['size'], complex=model.complex, fixed_order=True)
    if not model.fixed_order:
        # Learned ordering: keep, for every position, the factor with the
        # highest probability. No confidence threshold is applied to the
        # selection.
        prob = model.softmax_fn(model.logit)
        _, argmaxes = torch.max(prob, dim=-1)
        polished_model.butterflies = nn.ModuleList([model.butterflies[argmax] for argmax in argmaxes])
    else:
        polished_model.butterflies = model.butterflies
    optimizer = optim.LBFGS(polished_model.parameters())

    def closure():
        # L-BFGS may evaluate the objective several times per step, so the
        # loss and its gradients are recomputed on every call.
        optimizer.zero_grad()
        # Columns are reordered by `trainable.br_perm` before comparing
        # against the target matrix.
        loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.br_perm], trainable.target_matrix)
        loss.backward()
        return loss

    for _ in range(N_LBFGS_STEPS):
        optimizer.step(closure)
    torch.save(polished_model.state_dict(), str(checkpoint_path.parent / 'polished_model.pth'))
    # Only the value is needed here, so skip building an autograd graph.
    with torch.no_grad():
        loss = nn.functional.mse_loss(polished_model.matrix()[:, trainable.br_perm], trainable.target_matrix)
    return loss.item()



if __name__ == '__main__':
    # Script entry point: for every experiment, load its pickled trials,
    # polish the N_TRIALS_TO_POLISH lowest-loss ones in parallel, store the
    # result under 'polished_negative_loss' and write the trials back.
    # ray.init()
    result_dir = 'results'
    # Experiments are grouped by target/parameterization; the size lists are
    # not identical across groups (only two Hadamard groups include 256).
    experiment_names = [[f'Hadamard_factorization_fixed_order_{size}' for size in [8, 16, 32, 64, 128, 256]]]
    experiment_names += [[f'Hadamard_factorization_softmax_{size}' for size in [8, 16, 32, 64, 128, 256]]]
    experiment_names += [[f'Hadamard_factorization_sparsemax_{size}' for size in [8, 16, 32, 64, 128]]]
    experiment_names += [[f'Fft_factorization_fixed_order_{size}' for size in [8, 16, 32, 64, 128]]]
    experiment_names += [[f'Fft_factorization_softmax_{size}' for size in [8, 16, 32, 64, 128]]]
    experiment_names += [[f'Fft_factorization_sparsemax_{size}' for size in [8, 16, 32, 64, 128]]]

    # The context manager shuts the worker pool down when the script is done
    # (or fails), instead of leaving the processes to interpreter teardown.
    with mp.Pool() as pool:
        for experiment_names_ in experiment_names:
            for experiment_name in experiment_names_:
                print(experiment_name)
                checkpoint_path = Path(result_dir) / experiment_name / 'trial.pkl'
                # NOTE(security): pickle.load executes code embedded in the
                # file; only point this at result files you produced yourself.
                with checkpoint_path.open('rb') as f:
                    trials = pickle.load(f)
                # Ascending loss, i.e. best trials first.
                sorted_trials = sorted(trials, key=lambda trial: -trial.last_result['negative_loss'])
                losses = [-trial.last_result['negative_loss'] for trial in sorted_trials]
                # polished_losses = ray.get([polish.remote(trial) for trial in sorted_trials[:N_TRIALS_TO_POLISH]])
                if experiment_name.startswith('Hadamard'):
                    polish = polish_hadamard
                elif experiment_name.startswith('Fft'):
                    polish = polish_fft
                else:
                    # A real exception (not `assert`) so the check survives
                    # `python -O` and can never fall through with stale data.
                    raise ValueError('Unknown experiment')
                # One slice drives polishing, recording and reporting, so
                # the counts cannot drift apart and short trial lists work.
                top_trials = sorted_trials[:N_TRIALS_TO_POLISH]
                polished_losses = pool.map(polish, top_trials)
                print(np.sort(losses)[:N_TRIALS_TO_POLISH])
                # `top_trials` holds the same objects as `trials`, so these
                # updates are included in the dump below.
                for trial, polished_loss in zip(top_trials, polished_losses):
                    trial.last_result['polished_negative_loss'] = -polished_loss
                print(np.array([trial.last_result['polished_negative_loss'] for trial in top_trials]))
                with checkpoint_path.open('wb') as f:
                    pickle.dump(trials, f)