__author__ = 'yawli'
import math
import torch.nn as nn
import torch
from model import common
# from model import ops

def make_model(args, parent=False):
    """Build the network defined in this module.

    Args:
        args: Configuration object forwarded unchanged to ``SRResNet``.
        parent: Accepted for call-site compatibility; not used here.

    Returns:
        A new ``SRResNet`` instance.
    """
    model = SRResNet(args)
    return model

def norm(norm_type, channel, group):
    """Create the normalization layer selected by ``norm_type``.

    Args:
        norm_type: One of ``'batchnorm'``, ``'groupnorm'``, ``'instancenorm'``,
            ``'instancenorm_affine'`` or ``'layernorm'``. Any other value
            selects no normalization.
        channel: Channel count (or normalized shape for ``'layernorm'``)
            handed to the layer constructor.
        group: Number of groups; only read for ``'groupnorm'``.

    Returns:
        The freshly constructed layer, or ``None`` when ``norm_type`` is not
        one of the recognised names.
    """
    if norm_type == 'batchnorm':
        return nn.BatchNorm2d(channel)
    if norm_type == 'groupnorm':
        return nn.GroupNorm(group, channel)
    if norm_type == 'instancenorm':
        return nn.InstanceNorm2d(channel)
    if norm_type == 'instancenorm_affine':
        return nn.InstanceNorm2d(channel, affine=True)
    if norm_type == 'layernorm':
        # NOTE(review): nn.LayerNorm(channel) normalizes over the LAST input
        # dimension, which must have size `channel` - confirm this is what
        # callers feeding (N, C, H, W) tensors intend.
        return nn.LayerNorm(channel)
    # Unrecognised names mean "no normalization".
    return None

class VarBlockSimple(nn.Module):
    """Residual block whose input is gated by a learned per-channel mask.

    The mask branch is an optional normalization layer, a depthwise
    convolution (``groups=n_feats``) and ``reg_act``. Its output multiplies
    the input element-wise before the body convolution; the body output is
    scaled by ``rescale`` and added back to the input.

    Args:
        conv: Factory called as ``conv(n_feats, n_feats, kernel_size)`` to
            build the body convolution.
        n_feats: Number of input and output channels.
        kernel_size: Kernel size for both convolutions; the mask convolution
            pads by ``kernel_size // 2``.
        reg_act: Activation module appended to the mask branch. The default
            instance is created once at definition time and therefore shared
            by every block that relies on the default.
        rescale: Multiplier applied to the residual branch.
        norm_f: Optional module placed in front of the mask convolution.
    """

    def __init__(self, conv=common.default_conv, n_feats=64, kernel_size=3, reg_act=nn.Softplus(), rescale=1, norm_f=None):
        super(VarBlockSimple, self).__init__()
        # Depthwise convolution: one filter per channel.
        depthwise = nn.Conv2d(
            n_feats, n_feats,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=n_feats,
        )
        gate_layers = [depthwise, reg_act]
        if norm_f is not None:
            # Normalization, when supplied, runs before the depthwise conv.
            gate_layers.insert(0, norm_f)

        body_layers = [conv(n_feats, n_feats, kernel_size), nn.PReLU()]

        self.rescale = rescale
        self.conv_mask = nn.Sequential(*gate_layers)
        self.conv_body = nn.Sequential(*body_layers)

    def forward(self, x):
        """Return ``conv_body(conv_mask(x) * x) * rescale + x``."""
        gate = self.conv_mask(x)
        residual = self.conv_body(gate * x)
        return residual.mul(self.rescale) + x

class JointAttention(nn.Module):
    """Residual block that re-weights its input with a learned attention mask.

    The mask is produced by a stride-4 convolution down to 16 channels, a
    PReLU, a stride-4 transposed convolution back to ``n_feats`` channels at
    the input's spatial size, and ``nn.Softmax2d`` (softmax across channels
    at every spatial position). The masked input goes through a convolution
    and PReLU and is added back to the input.

    Args:
        conv: Factory called as ``conv(n_feats, n_feats, kernel_size)`` to
            build the body convolution.
        n_feats: Number of input and output channels.
        kernel_size: Kernel size shared by all convolutions in the block.
        reg_act: Accepted for signature compatibility; not used by this block.
        rescale: Accepted for signature compatibility; not used by this block.
        norm_f: Accepted for signature compatibility; not used by this block.
    """

    def __init__(self, conv=common.default_conv, n_feats=64, kernel_size=3, reg_act=nn.Softplus(), rescale=1, norm_f=None):
        super(JointAttention, self).__init__()
        # Both mask convolutions must use the same padding. With
        # padding = kernel_size // 2 on each side, the range of sizes the
        # transposed convolution can emit (via output_padding in
        # [0, stride - 1]) always contains the original input size. A fixed
        # padding of 1 only satisfied that for kernel_size == 3.
        pad = kernel_size // 2
        mask_conv = [nn.Conv2d(n_feats, 16, kernel_size=kernel_size, stride=4, padding=pad), nn.PReLU()]
        mask_deconv = nn.ConvTranspose2d(16, n_feats, kernel_size=kernel_size, stride=4, padding=pad)
        mask_deconv_act = nn.Softmax2d()
        conv_body = [conv(n_feats, n_feats, kernel_size), nn.PReLU()]
        self.mask_conv = nn.Sequential(*mask_conv)
        self.mask_deconv = mask_deconv
        self.mask_deconv_act = mask_deconv_act
        self.conv_body = nn.Sequential(*conv_body)

    def forward(self, x):
        """Return ``conv_body(mask(x) * x) + x``."""
        coarse = self.mask_conv(x)
        # output_size makes the transposed conv pick the output_padding that
        # restores exactly the spatial size of x.
        mask = self.mask_deconv_act(self.mask_deconv(coarse, output_size=x.size()))
        res = self.conv_body(mask * x)
        return res + x

class UpsampleBlock(nn.Module):
    """Upsampling tail that serves either one fixed scale or scales 2, 3 and 4.

    Args:
        n_channels: Channel count forwarded to each ``_UpsampleBlock``.
        scale: Upsampling factor for the single-scale upsampler; ignored
            when ``multi_scale`` is true.
        multi_scale: When true, build separate x2, x3 and x4 upsamplers and
            select one per call; otherwise build a single upsampler.
        group: ``groups`` argument forwarded to each ``_UpsampleBlock``.
    """

    def __init__(self,
                 n_channels, scale, multi_scale,
                 group=1):
        super(UpsampleBlock, self).__init__()

        if multi_scale:
            self.up2 = _UpsampleBlock(n_channels, scale=2, group=group)
            self.up3 = _UpsampleBlock(n_channels, scale=3, group=group)
            self.up4 = _UpsampleBlock(n_channels, scale=4, group=group)
        else:
            self.up = _UpsampleBlock(n_channels, scale=scale, group=group)

        self.multi_scale = multi_scale

    def forward(self, x, scale):
        """Upsample ``x``.

        Args:
            x: Input passed to the selected upsampler.
            scale: Which upsampler to use in multi-scale mode (2, 3 or 4).
                Ignored in single-scale mode.

        Returns:
            The output of the selected upsampler.

        Raises:
            ValueError: In multi-scale mode, if ``scale`` is not 2, 3 or 4.
        """
        if not self.multi_scale:
            return self.up(x)
        if scale == 2:
            return self.up2(x)
        if scale == 3:
            return self.up3(x)
        if scale == 4:
            return self.up4(x)
        # Fail loudly instead of implicitly returning None, which would only
        # surface later as an unrelated error in the caller.
        raise ValueError(
            'unsupported scale {!r} for multi-scale upsampling; '
            'expected 2, 3 or 4'.format(scale)
        )


class _UpsampleBlock(nn.Module):
    """Sub-pixel upsampler built from conv + PReLU + PixelShuffle stages.

    A scale of 3 uses one x3 stage. A power-of-two scale uses one x2 stage
    per factor of two (so scale 1 yields an empty, pass-through body).

    Args:
        n_channels: Number of input and output channels of every stage.
        scale: Total upsampling factor: 3 or a power of two.
        group: ``groups`` argument for each stage's convolution.

    Raises:
        ValueError: If ``scale`` is neither 3 nor a power of two. Such scales
            cannot be realised by the stages above and would otherwise
            silently produce a block that does not upsample.
    """

    def __init__(self,
                 n_channels, scale,
                 group=1):
        super(_UpsampleBlock, self).__init__()

        if scale == 3:
            stage_factors = [3]
        else:
            # Count factors of two by exact division rather than a
            # floating-point logarithm.
            stage_factors = []
            remaining = scale
            while remaining > 1 and remaining % 2 == 0:
                remaining //= 2
                stage_factors.append(2)
            if remaining != 1:
                raise ValueError(
                    'unsupported upsampling scale {!r}; '
                    'expected 3 or a power of two'.format(scale)
                )

        modules = []
        for factor in stage_factors:
            # The conv expands channels by factor**2 so that PixelShuffle can
            # fold them back into a factor-times larger spatial grid.
            modules += [
                nn.Conv2d(n_channels, factor * factor * n_channels, 3, 1, 1, groups=group),
                nn.PReLU(),
                nn.PixelShuffle(factor),
            ]

        self.body = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through all upsampling stages and return the result."""
        return self.body(x)


class SRResNet(nn.Module):
    """Super-resolution network: head conv, attention body, upsampling tail.

    The forward pass is ``add_mean(tail_conv(tail(h + body_conv(body_r(h)))))``
    with ``h = head(sub_mean(x))``.

    Args:
        args: Configuration object. The constructor reads ``n_resblocks``,
            ``n_feats``, ``scale`` (a sequence), ``norm_type``, ``n_groups``,
            ``res_act``, ``n_colors``, ``res_scale`` and ``rgb_range``.
        conv: Factory called as ``conv(in_channels, out_channels, kernel_size)``
            for the head, body and output convolutions.
    """

    def __init__(self, args, conv=common.default_conv):
        super(SRResNet, self).__init__()

        depth = args.n_resblocks
        width = args.n_feats
        ksize = 3
        head_act = nn.PReLU()

        # More than one configured scale switches the tail to multi-scale mode.
        use_multi_scale = len(args.scale) > 1
        self.scale_idx = 0
        first_scale = args.scale[self.scale_idx]
        n_groups = 1
        self.scale = args.scale

        # Per-channel mean/std handed to common.MeanShift
        # (presumably dataset RGB statistics - TODO confirm).
        mean_rgb = (0.4488, 0.4371, 0.4040)
        std_rgb = (1.0, 1.0, 1.0)

        shared_norm = norm(args.norm_type, args.n_feats, args.n_groups)
        mask_act = common.act_vconv(args.res_act)

        head_layers = [conv(args.n_colors, width, ksize), head_act]

        # The same norm and activation objects are handed to every block.
        attention_blocks = []
        for _ in range(depth):
            attention_blocks.append(
                JointAttention(conv, width, ksize,
                               reg_act=mask_act,
                               norm_f=shared_norm,
                               rescale=args.res_scale)
            )

        body_tail_layers = [conv(width, width, ksize)]

        upsampler = UpsampleBlock(width,
                                  scale=first_scale,
                                  multi_scale=use_multi_scale,
                                  group=n_groups)
        output_layers = [conv(width, args.n_colors, ksize)]

        # Sub-modules are created and registered in a fixed order so that
        # seeded initialisation and state_dict key order stay stable.
        self.sub_mean = common.MeanShift(args.rgb_range, mean_rgb, std_rgb)
        self.head = nn.Sequential(*head_layers)
        self.body_r = nn.Sequential(*attention_blocks)
        self.body_conv = nn.Sequential(*body_tail_layers)
        self.tail = upsampler
        self.tail_conv = nn.Sequential(*output_layers)
        self.add_mean = common.MeanShift(args.rgb_range, mean_rgb, std_rgb, 1)

    def forward(self, x):
        """Run the network on ``x`` at the currently selected scale."""
        shallow = self.head(self.sub_mean(x))
        deep = self.body_conv(self.body_r(shallow))
        active_scale = self.scale[self.scale_idx]
        # Global skip connection: head features are added to the body output
        # before upsampling.
        upsampled = self.tail(shallow + deep, active_scale)
        return self.add_mean(self.tail_conv(upsampled))

    def set_scale(self, scale_idx):
        """Select which entry of ``self.scale`` later forward passes use."""
        self.scale_idx = scale_idx