import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
from toolkit.pytorch_transformers.models import Model

from common_blocks.architectures.classification import Densenet
from common_blocks.utils.misc import get_list_of_image_predictions, sigmoid, softmax
from .architectures import encoders, unet, large_kernel_matters, pspnet
from . import callbacks as cbk
from .lovasz_losses import lovasz_hinge

ENCODERS = {
    'ResNet': {'model': encoders.ResNetEncoders,
               'model_config': {'encoder_depth': 34, 'pretrained': True, 'pool0': True}},
    'SeResNet': {'model': encoders.SeResNetEncoders,
                 'model_config': {'encoder_depth': 50, 'pretrained': 'imagenet', 'pool0': True}},
    'SeResNetXt': {'model': encoders.SeResNetXtEncoders,
                   'model_config': {'encoder_depth': 101, 'pretrained': 'imagenet', 'pool0': True}},
    'DenseNet': {'model': encoders.DenseNetEncoders,
                 'model_config': {'encoder_depth': 201, 'pretrained': 'imagenet', 'pool0': True}},
}

ARCHITECTURES = {
    'UNet': {'model': unet.UNet,
             'model_config': {'use_hypercolumn': False, 'dropout_2d': 0.0, 'pool0': True}},
    'LargeKernelMatters': {'model': large_kernel_matters.LargeKernelMatters,
                           'model_config': {'kernel_size': 9, 'internal_channels': 21,
                                            'dropout_2d': 0.0, 'use_relu': False, 'pool0': True,
                                            'use_channel_se': True, 'use_spatial_se': True,
                                            'reduction_se': 4}},
    'PSPNet': {'model': pspnet.PSPNet,
               'model_config': {'use_hypercolumn': False, 'pool0': True}},
}

SNS_ARCHITECTURES = {
    'Densenet': {'model': Densenet,
                 'model_config': {'pretrained': 'imagenet'}},
}
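
# Illustrative shape of `architecture_config`, inferred from the keys read in this
# module (values are placeholders, not the experiment defaults):
#
# architecture_config = {
#     'model_params': {'architecture': 'UNet', 'encoder': 'ResNet',
#                      'out_channels': 2, 'activation': 'sigmoid'},
#     'optimizer_params': {'lr': 1e-2, 'momentum': 0.9},
#     'regularizer_params': {'regularize': True, 'weight_decay_conv2d': 1e-4},
# }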


class SegmentationModel(Model):
    def __init__(self, architecture_config, training_config, callbacks_config):
        super().__init__(architecture_config, training_config, callbacks_config)
        self.activation_func = self.architecture_config['model_params']['activation']
        self.set_model()
        self.set_loss()
        self.weight_regularization = weight_regularization
        self.optimizer = optim.SGD(self.weight_regularization(self.model, **architecture_config['regularizer_params']),
                                   **architecture_config['optimizer_params'])
        self.callbacks = callbacks_network(self.callbacks_config)

    def fit(self, datagen, validation_datagen=None, meta_valid=None):
        self._initialize_model_weights()

        if not isinstance(self.model, nn.DataParallel):
            self.model = nn.DataParallel(self.model)

        if torch.cuda.is_available():
            self.model = self.model.cuda()

        self.callbacks.set_params(self, validation_datagen=validation_datagen, meta_valid=meta_valid)
        self.callbacks.on_train_begin()

        batch_gen, steps = datagen
        for epoch_id in range(self.training_config['epochs']):
            self.callbacks.on_epoch_begin()
            for batch_id, data in enumerate(batch_gen):
                self.callbacks.on_batch_begin()
                self.freeze_weights()
                metrics = self._fit_loop(data)
                self.callbacks.on_batch_end(metrics=metrics)
                if batch_id == steps:
                    break
            self.callbacks.on_epoch_end()
            if self.callbacks.training_break():
                break
        self.callbacks.on_train_end()
        return self
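
    # Illustrative training call (both datagen arguments are `(generator, steps)` pairs):
    # >>> model.fit((train_generator, train_steps),
    # ...           validation_datagen=(valid_generator, valid_steps))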

    def _fit_loop(self, data):
        X = data[0]
        targets_tensors = data[1:]

        # Variable is deprecated; tensors carry autograd state directly.
        if torch.cuda.is_available():
            X = X.cuda()
            targets_var = [target_tensor.cuda() for target_tensor in targets_tensors]
        else:
            targets_var = list(targets_tensors)

        self.optimizer.zero_grad()
        outputs_batch = self.model(X)
        partial_batch_losses = {}

        if len(self.output_names) == 1:
            for (name, loss_function, weight), target in zip(self.loss_function, targets_var):
                partial_batch_losses[name] = loss_function(outputs_batch, target) * weight
        else:
            for (name, loss_function, weight), output, target in zip(self.loss_function, outputs_batch, targets_var):
                partial_batch_losses[name] = loss_function(output, target) * weight
        batch_loss = sum(partial_batch_losses.values())
        partial_batch_losses['sum'] = batch_loss

        batch_loss.backward()
        self.optimizer.step()

        return partial_batch_losses

    def transform(self, datagen, validation_datagen=None, *args, **kwargs):
        outputs = self._transform(datagen, validation_datagen)
        for name, prediction in outputs.items():
            if self.activation_func == 'softmax':
                outputs[name] = [softmax(single_prediction, axis=0) for single_prediction in prediction]
            elif self.activation_func == 'sigmoid':
                outputs[name] = [sigmoid(np.squeeze(mask)) for mask in prediction]
            else:
                raise ValueError('Only "softmax" and "sigmoid" activations are supported')
        return outputs
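
    # Illustrative call pattern (assumes `datagen` is a `(batch_generator, steps)` pair,
    # as unpacked in `_transform`, and a single output head named 'mask'):
    # >>> outputs = model.transform((test_generator, test_steps))
    # >>> masks = outputs['mask_prediction']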

    def _transform(self, datagen, validation_datagen=None, **kwargs):
        self.model.eval()

        batch_gen, steps = datagen
        outputs = {}
        for batch_id, data in enumerate(batch_gen):
            if isinstance(data, (list, tuple)):
                X = data[0]
            else:
                X = data

            if torch.cuda.is_available():
                X = X.cuda()
            # volatile=True is deprecated; no_grad() disables autograd during inference.
            with torch.no_grad():
                outputs_batch = self.model(X)

            if len(self.output_names) == 1:
                outputs.setdefault(self.output_names[0], []).append(outputs_batch.data.cpu().numpy())
            else:
                for name, output in zip(self.output_names, outputs_batch):
                    output_ = output.data.cpu().numpy()
                    outputs.setdefault(name, []).append(output_)
            if batch_id == steps:
                break
        self.model.train()
        outputs = {'{}_prediction'.format(name): get_list_of_image_predictions(outputs_) for name, outputs_ in
                   outputs.items()}
        return outputs

    def set_model(self):
        architecture_name = self.architecture_config['model_params']['architecture']
        encoder_name = self.architecture_config['model_params']['encoder']
        encoder = ENCODERS[encoder_name]
        architecture = ARCHITECTURES[architecture_name]

        self.model = architecture['model'](encoder=encoder['model'](**encoder['model_config']),
                                           num_classes=self.architecture_config['model_params']['out_channels'],
                                           **architecture['model_config'])
        self._initialize_model_weights = lambda: None

    def set_loss(self):
        if self.activation_func == 'softmax':
            raise NotImplementedError('No softmax loss defined')
        elif self.activation_func == 'sigmoid':
            loss_function = focal_lovasz
            # Alternative losses kept for experimentation:
            # loss_function = weighted_sum_loss
            # loss_function = nn.BCEWithLogitsLoss()
            # loss_function = DiceWithLogitsLoss()
            # loss_function = lovasz_loss
            # loss_function = FocalWithLogitsLoss()
        else:
            raise ValueError('Only "softmax" and "sigmoid" activations are supported')
        self.loss_function = [('mask', loss_function, 1.0)]

    def freeze_weights(self):
        """No-op by default; the disabled code below freezes the encoder and batchnorm layers."""
        # # freeze encoder
        # if isinstance(self.model, nn.DataParallel):
        #     encoder_params = self.model.module.encoder.parameters()
        # else:
        #     encoder_params = self.model.encoder.parameters()
        #
        # for parameter in encoder_params:
        #     parameter.requires_grad = False
        #
        # # freeze batchnorm
        # for m in self.model.modules():
        #     if isinstance(m, nn.BatchNorm2d):
        #         m.eval()
        #         m.weight.requires_grad = False
        #         m.bias.requires_grad = False
        pass

    def load(self, filepath):
        self.model.eval()

        # Checkpoints are written from a DataParallel-wrapped model (see fit), so wrap
        # before loading to match the 'module.' prefix in the state dict keys.
        if not isinstance(self.model, nn.DataParallel):
            self.model = nn.DataParallel(self.model)

        if torch.cuda.is_available():
            self.model.cpu()
            self.model.load_state_dict(torch.load(filepath))
            self.model = self.model.cuda()
        else:
            self.model.load_state_dict(torch.load(filepath, map_location=lambda storage, loc: storage))
        return self


class BinaryModel(SegmentationModel):
    def __init__(self, architecture_config, training_config, callbacks_config, **kwargs):
        super().__init__(architecture_config, training_config, callbacks_config)
        self.weight_regularization = weight_regularization
        self.set_model()
        # Replace the parent's SGD optimizer with Adam for the ship/no-ship classifier.
        self.optimizer = optim.Adam(self.weight_regularization(self.model, **architecture_config['regularizer_params']),
                                    **architecture_config['optimizer_params'])

        self.epochs = 10
        self.callbacks_config = callbacks_config
        self.callbacks = callbacks_ship_no_ship(self.callbacks_config)
        self.activation_func = 'sigmoid'
        self.validation_loss = {}

    def set_model(self):
        architecture = self.architecture_config['model_params']['architecture']
        config = SNS_ARCHITECTURES[architecture]
        self.model = config['model'](**config['model_config'])
        self._initialize_model_weights = lambda: None

    def set_loss(self):
        self.loss_function = [('ship_no_ship', nn.CrossEntropyLoss(), 1.0)]

    def freeze_weights(self):
        pass


class FocalWithLogitsLoss(nn.Module):
    def __init__(self, alpha=1.0, gamma=1.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, output, target):
        if not (target.size() == output.size()):
            raise ValueError(
                "Target size ({}) must be the same as input size ({})".format(target.size(), output.size()))

        # Numerically stable BCE with logits: x - x*t + log(1 + exp(-x)), with the
        # log-sum-exp shifted by max(-x, 0) to avoid overflow for large |x|.
        max_val = (-output).clamp(min=0)
        logpt = output - output * target + max_val + ((-max_val).exp() + (-output - max_val).exp()).log()
        # pt is the probability assigned to the true class; at is the alpha class weight.
        pt = torch.exp(-logpt)
        at = self.alpha * target + (1 - target)
        loss = at * ((1 - pt).pow(self.gamma)) * logpt

        if self.reduction == 'none':
            return loss
        elif self.reduction == 'mean':
            return loss.mean()
        else:
            return loss.sum()
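
# Minimal usage sketch (illustrative shapes; raw logits in, a float mask of the same size):
# >>> criterion = FocalWithLogitsLoss(alpha=1.0, gamma=2.0)
# >>> loss = criterion(torch.randn(2, 1, 64, 64), torch.ones(2, 1, 64, 64))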


class DiceWithLogitsLoss(nn.Module):
    def __init__(self, smooth=0, eps=1e-7):
        super().__init__()
        self.smooth = smooth
        self.eps = eps

    def forward(self, output, target):
        # Soft Dice: 1 - (2 * sum(p*t) + smooth) / (sum(p) + sum(t) + smooth + eps)
        output = torch.sigmoid(output)
        return 1 - (2 * torch.sum(output * target) + self.smooth) / (
                torch.sum(output) + torch.sum(target) + self.smooth + self.eps)


def weighted_sum_loss(output, target):
    bce = nn.BCEWithLogitsLoss()(output, target)
    dice = DiceWithLogitsLoss()(output, target)
    return bce + 0.25 * dice


def focal_lovasz(output, target):
    focal = FocalWithLogitsLoss(alpha=1.0, gamma=2.0)(output, target)
    lovasz = lovasz_loss(output, target)
    return focal + lovasz


def lovasz_loss(output, target):
    # lovasz_hinge expects integer {0, 1} labels, so cast the float mask first.
    target = target.long()
    return lovasz_hinge(output, target)


def weight_regularization(model, regularize, weight_decay_conv2d):
    # torch optimizers accept a flat iterable of tensors or a list of parameter-group
    # dicts; wrap the filter in a dict so both branches return valid param groups.
    if regularize:
        parameter_list = [
            {'params': filter(lambda p: p.requires_grad, model.parameters()),
             'weight_decay': weight_decay_conv2d},
        ]
    else:
        parameter_list = [{'params': filter(lambda p: p.requires_grad, model.parameters())}]
    return parameter_list
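
# Illustrative wiring into an optimizer (lr/momentum values are placeholders):
# >>> params = weight_regularization(model, regularize=True, weight_decay_conv2d=1e-4)
# >>> optimizer = optim.SGD(params, lr=1e-2, momentum=0.9)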


def callbacks_network(callbacks_config):
    experiment_timing = cbk.ExperimentTiming(**callbacks_config['experiment_timing'])
    model_checkpoints = cbk.ModelCheckpoint(**callbacks_config['model_checkpoint'])
    training_monitor = cbk.TrainingMonitor(**callbacks_config['training_monitor'])
    validation_monitor = cbk.ValidationMonitor(**callbacks_config['validation_monitor'])
    neptune_monitor = cbk.NeptuneMonitor(**callbacks_config['neptune_monitor'])
    one_cycle_callback = cbk.OneCycleCallback(**callbacks_config['one_cycle_scheduler'])
    # Disabled callbacks, kept for experimentation:
    # lr_scheduler = cbk.ReduceLROnPlateauScheduler(**callbacks_config['reduce_lr_on_plateau_scheduler'])
    # early_stopping = cbk.EarlyStopping(**callbacks_config['early_stopping'])
    # init_lr_finder = cbk.InitialLearningRateFinder()

    return cbk.CallbackList(
        callbacks=[experiment_timing, training_monitor, validation_monitor,
                   model_checkpoints, neptune_monitor, one_cycle_callback,
                   # early_stopping, lr_scheduler, init_lr_finder,
                   ])


def callbacks_ship_no_ship(callbacks_config):
    training_monitor = cbk.TrainingMonitor(**callbacks_config['training_monitor'])
    validation_monitor = cbk.SNS_ValidationMonitor()
    model_checkpoints = cbk.ModelCheckpoint(**callbacks_config['model_checkpoint'])
    # one_cycle_callback = cbk.OneCycleCallback(**callbacks_config['one_cycle_scheduler'])

    return cbk.CallbackList([training_monitor, validation_monitor, model_checkpoints,
                             # one_cycle_callback,
                             ])