"""
The MIT License

Copyright 2019 Derek Miller

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions
of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""

from sparktorch.util import handle_features, load_torch_model, DataObj, load_base_torch
from sparktorch.early_stopper import EarlyStopping
from pyspark.rdd import RDD
from typing import Dict, List, Union
from uuid import uuid4
import numpy as np
from pyspark.rdd import PipelinedRDD
import torch
from torch.multiprocessing import Process
import torch.distributed as dist
from socket import gethostbyname, gethostname

import os


def retrieve_url() -> str:
    return os.environ.get("SPARK_LOCAL_IP", gethostbyname(gethostname()))


def mapPartitionsWithIndex(rdd, f, preservesPartitioning=False):
    """
    Temporary helper for running mapPartitionsWithIndex in barrier execution mode.
    """
    return PipelinedRDD(rdd, f, preservesPartitioning, isFromBarrier=True)


def process_generic_model(params: List, iters: int, has_early_stop: bool = False):
    """
    Runs a mock training loop with zero gradients. This works around a bug where the
    connection gets reset when using custom new groups.

    :param params: The shapes of the model's parameters.
    :param iters: Number of training iterations to mirror.
    :param has_early_stop: Whether the training processes use early stopping, which adds extra all_reduce calls.
    """
    # Hopefully this function can go away in newer versions.
    for i in range(iters):
        for p in params:
            z = torch.zeros(p)
            dist.all_reduce(z, op=torch.distributed.ReduceOp.SUM)

        if has_early_stop:
            # Mirror the training processes' loss all_reduce and should_stop all_reduce
            # so that the collective calls stay in sync.
            dist.all_reduce(torch.tensor(0.0), op=torch.distributed.ReduceOp.SUM)
            zeros = torch.zeros(1)
            dist.all_reduce(zeros, op=torch.distributed.ReduceOp.SUM)
            if zeros.item() > 0:
                break


def handle_model(
    index: int,
    data: List[DataObj],
    torch_obj: Union[str, List],
    master_url: str = '127.0.0.1',
    iters: int = 1000,
    world_size: int = 2,
    early_stop_patience: int = -1,
    verbose: int = 1,
    mini_batch: int = -1,
    validation_pct: float = 0,
    device: str = 'cpu'
) -> List[Dict]:
    """
    Runs the training of a PyTorch model, utilizing the torch.distributed package.

    :param index: Partition index, used for registering with the process group. -1 indicates the driver.
    :param data: The data from the partition.
    :param torch_obj: The serialized torch object string (or the parameter shapes for the driver).
    :param master_url: The master url for the process group.
    :param iters: The number of training iterations.
    :param world_size: The number of processes: typically the number of partitions plus one for the driver.
    :param early_stop_patience: Patience for early stopping. -1 means early stopping is not used.
    :param verbose: Whether to log the loss or not.
    :param mini_batch: Mini-batch size for training. -1 uses the full partition.
    :param validation_pct: Percentage of the partition to hold out for validation.
    :param device: The pytorch device to use for training: cpu/cuda.

    :return: A list containing the model's state dictionary.
    """

    # If a process group has already been set up on this machine, tear it down.
    if dist.is_initialized():
        dist.destroy_process_group()

    # Set up the distributed server.
    os.environ['MASTER_ADDR'] = master_url
    os.environ['MASTER_PORT'] = '3333'

    dist.init_process_group('gloo', rank=index + 1, world_size=world_size)

    # The driver (index == -1) only mirrors the collective calls; workers load the model.
    if index == -1:
        process_generic_model(torch_obj, iters, early_stop_patience > 0)
        return []
    else:
        torch_obj = load_torch_model(torch_obj)

    # Unpack the loaded model, criterion, and optimizer.
    model = torch_obj.model.to(device)
    model.train()
    criterion = torch_obj.criterion
    optimizer = torch_obj.optimizer

    # Set up early stopping
    es = EarlyStopping(patience=early_stop_patience)
    should_stop = torch.zeros(1)
    has_early_stop = early_stop_patience > 0

    partition_id = str(uuid4())

    # Process the data. Converts to x_train, y_train, x_val, y_val
    data_obj = handle_features(data, validation_pct)

    # Check whether the partition is empty. We still need to participate in the collective calls.
    if data_obj is None or data_obj.x_train is None:
        process_generic_model([list(p.shape) for p in model.parameters()], iters, early_stop_patience > 0)
        return []

    # Move all of the data to the target device.
    x_train = data_obj.x_train.to(device)
    y_train = data_obj.y_train.to(device) if data_obj.y_train is not None else x_train
    x_val = data_obj.x_val.to(device) if data_obj.x_val is not None else None
    y_val = data_obj.y_val.to(device) if data_obj.y_val is not None else x_val

    for i in range(iters):

        optimizer.zero_grad()

        # Sample a random mini-batch if a mini-batch size was provided.
        if 0 < mini_batch < len(data_obj.x_train):
            idxs = np.random.choice(len(data_obj.x_train), mini_batch, replace=False).tolist()
            x_train = data_obj.x_train[idxs].to(device)
            y_train = data_obj.y_train[idxs].to(device) if data_obj.y_train is not None else x_train

        y_pred = model(x_train)

        try:
            loss = criterion(y_pred, y_train)
        except RuntimeError:
            # Some criteria (e.g. CrossEntropyLoss) require flattened long labels.
            y_train = torch.flatten(y_train.long())
            loss = criterion(y_pred, y_train)

        loss_v = loss.item()

        # Process validation loss
        val_loss = None
        val_loss_v = None
        if x_val is not None:
            pred_val = model(x_val)

            try:
                val_loss = criterion(pred_val, y_val)
            except RuntimeError:
                y_val = torch.flatten(y_val.long())
                val_loss = criterion(pred_val, y_val)

            val_loss_v = val_loss.item()

        # Calculate gradients
        loss.backward()

        # Distributed part of training: average the gradients across the training processes.
        for param in model.parameters():
            dist.all_reduce(param.grad.data, op=torch.distributed.ReduceOp.SUM)
            param.grad.data /= (world_size - 1)

        # Handle early stopping: average the loss across processes and check patience.
        loss_distributed = None
        if has_early_stop:
            loss_to_use = val_loss if val_loss is not None else loss

            dist.all_reduce(loss_to_use, op=torch.distributed.ReduceOp.SUM)
            loss_distributed = loss_to_use.item() / (world_size - 1)
            stop = es.step(loss_distributed)
            if stop:
                should_stop = should_stop + 1.0

            # If any process votes to stop, all processes stop.
            dist.all_reduce(should_stop, op=torch.distributed.ReduceOp.SUM)
            if should_stop.item() > 0:
                break

        optimizer.step()

        if verbose:
            print(f"Partition: {partition_id}. Iteration: {i}. Distributed Loss: {loss_distributed} "
                  f"Partition Training Loss: {loss_v}, "
                  f"Partition Validation Loss: {val_loss_v}")

    return [model.state_dict()]
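

# A standalone sketch (not used by the library) of the gradient averaging that handle_model
# performs inside its training loop: each training process sums its gradients with
# dist.all_reduce and divides by the number of training processes (world_size - 1, since
# rank 0 is the driver). The single-process gloo group and the helper name below are
# illustrative assumptions only.
def _gradient_allreduce_sketch():
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    model = torch.nn.Linear(4, 1)
    loss = model(torch.randn(8, 4)).sum()
    loss.backward()

    for param in model.parameters():
        # Sum the gradients across ranks, then average over the training ranks.
        dist.all_reduce(param.grad.data, op=torch.distributed.ReduceOp.SUM)
        param.grad.data /= max(dist.get_world_size() - 1, 1)

    dist.destroy_process_group()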


def train_distributed(
    rdd: RDD,
    torch_obj: str,
    iters: int = 10,
    partition_shuffles: int = 1,
    verbose: int = 1,
    mini_batch: int = -1,
    validation_pct: float = 0.0,
    world_size: int = 2,
    device: str = 'cpu',
    early_stop_patience: int = -1
) -> Dict:
    """
    Entry point for training the model in a distributed fashion.

    :param rdd: The rdd of data to run through the model.
    :param torch_obj: The torch object, as a string, that includes the model and parameter shapes.
    :param iters: Number of iterations for training.
    :param partition_shuffles: Number of times to reshuffle the partitions and rerun training.
    :param verbose: Verbosity of logs.
    :param mini_batch: Mini-batch size for each iteration of training.
    :param validation_pct: Percentage of the data to hold out for validation.
    :param world_size: Number of processes: the number of partitions plus one for the driver.
    :param device: The pytorch device to use: cpu/cuda.
    :param early_stop_patience: Patience for early stopping. -1 means early stopping is not used.

    :return: The trained model's state dict.
    """
    master_url = retrieve_url()

    torch_loaded, params = load_base_torch(torch_obj)

    # Start the driver process.
    p = Process(
        target=handle_model,
        args=(-1, None, params, master_url, iters, world_size, early_stop_patience)
    )
    p.start()

    try:
        state_dict = None
        for i in range(partition_shuffles):

            # Run the model in barrier execution mode.
            state_dict = mapPartitionsWithIndex(
                rdd, lambda i, x: handle_model(
                    i,
                    x,
                    torch_obj=torch_loaded,
                    master_url=master_url,
                    iters=iters,
                    verbose=verbose,
                    mini_batch=mini_batch,
                    validation_pct=validation_pct,
                    world_size=world_size,
                    device=device,
                    early_stop_patience=int(early_stop_patience)
                )
            ).collect()

            # Reshuffle the data between passes.
            if partition_shuffles - i > 1:
                num_partitions = rdd.getNumPartitions()
                rdd = rdd.repartition(num_partitions)

        return state_dict[0]

    finally:
        p.terminate()
        p.join()
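

# A minimal usage sketch (not part of the library). It assumes serialize_torch_obj is
# exported by the sparktorch package and that `data_rdd` already holds rows in the format
# handle_features expects (e.g. produced by the SparkTorch pipeline). The network,
# hyperparameters, and helper name below are illustrative placeholders only.
def _train_distributed_sketch(data_rdd: RDD) -> Dict:
    import torch.nn as nn
    from sparktorch import serialize_torch_obj

    network = nn.Sequential(
        nn.Linear(10, 32),
        nn.ReLU(),
        nn.Linear(32, 2)
    )

    # Serialize the model, criterion, and optimizer into the string form train_distributed expects.
    torch_obj = serialize_torch_obj(
        model=network,
        criterion=nn.CrossEntropyLoss(),
        optimizer=torch.optim.Adam,
        lr=0.001
    )

    # world_size is the number of data partitions plus one for the driver process.
    return train_distributed(
        rdd=data_rdd,
        torch_obj=torch_obj,
        iters=50,
        world_size=data_rdd.getNumPartitions() + 1,
        mini_batch=64,
        validation_pct=0.1,
        early_stop_patience=10
    )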