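'''
多GPU处理 Demo (multi-GPU processing demo).

Runs a super-resolution network over every image in a directory. The Y (luma)
channel is upscaled by the network, the Cb/Cr (chroma) channels come from a
bicubic upscale of the input, and the channels are merged into the output image.
'''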
import os
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.utils.data as data
from PIL import Image
from torchvision import transforms
# from skimage import transform,color
from torch.utils.data import DataLoader


class DatasetFromImage(data.Dataset):
    """Dataset over every file in a directory, prepared for super-resolution.

    Each item is ``(y, bicubic, name)``:

    * ``y`` -- the luma (Y of YCbCr) channel of the image as a float tensor
      produced by ``transforms.ToTensor``;
    * ``bicubic`` -- the whole image bicubic-upscaled by ``scale`` and
      converted with ``ToTensor``;
    * ``name`` -- the file name inside ``file_path``.

    The upscaled size is computed once from the first file, so every image
    in the directory is assumed to have the same size (NOTE(review): mixed
    sizes would give bicubic tensors of the wrong size).

    Raises:
        ValueError: if ``file_path`` contains no entries.
    """

    def __init__(self, file_path, scale=2):
        super(DatasetFromImage, self).__init__()
        # Sorted so iteration (and output) order is deterministic;
        # os.listdir returns entries in arbitrary order.
        self.ims = sorted(os.listdir(file_path))
        if not self.ims:
            raise ValueError("no images found in '{}'".format(file_path))
        self.file_path = file_path
        self.trans = transforms.Compose([
            transforms.ToTensor(),
        ])
        # Context manager closes the file handle once the size is read.
        with Image.open(os.path.join(file_path, self.ims[0])) as first:
            target_size = self._target_size(first.size, scale)
        self.trans_bic = transforms.Compose([
            transforms.Resize(target_size, Image.BICUBIC),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _target_size(image_size, scale):
        """Return the ``(height, width)`` that ``transforms.Resize`` expects.

        PIL reports ``Image.size`` as ``(width, height)`` while torchvision's
        ``Resize`` takes ``(height, width)``; swapping here keeps non-square
        images from being upscaled transposed.
        """
        width, height = image_size
        return height * scale, width * scale

    def __getitem__(self, index):
        name = self.ims[index]
        with Image.open(os.path.join(self.file_path, name)) as img:
            bic_im = self.trans_bic(img)
            # Only the luma channel is fed to the network; the chroma is
            # taken later from the bicubic image.
            luma, _, _ = img.convert("YCbCr").split()
        y = self.trans(luma)
        # The default collate_fn accepts tensors, numbers, dicts and lists,
        # so returning the file name alongside the tensors is fine.
        return y, bic_im, name

    def __len__(self):
        return len(self.ims)


def model_convert(path, scale, gpus=1):
    if gpus > 1:
        loadmultiGPU = True
        gids = [i for i in range(gpus)]
    else:
        loadmultiGPU = False

    if scale == 2:
        from models import Net2x as Net
    if scale == 3:
        from models import Net3x as Net
    elif scale == 4:
        from models import Net4x as Net
    model = Net()

    if loadmultiGPU and torch.cuda.is_available():
        model = nn.DataParallel(model, device_ids=gids).cuda()
    elif torch.cuda.is_available():
        model = model.cuda()
    else:
        model = model.cpu()
    # optionally resume from a checkpoint
    if os.path.isfile(path):
        print("=> loading checkpoint '{}'".format(path))
        weights = torch.load(path)
        # saved_state = weights.state_dict()
        model.load_state_dict(weights)
        # multi gpu loader之前存的模型好像去掉了这部分只有权重
        # if loadmultiGPU:
        #     from collections import OrderedDict
        #     new_state_dict = OrderedDict()
        #     for k, v in saved_state.items():
        #         namekey = 'module.'+k  # add `module.`
        #         new_state_dict[namekey] = v
        #         # load params
        #     model.load_state_dict(new_state_dict)
        # else:
        #     model.load_state_dict(saved_state)
    else:
        print("=> no checkpoint found at '{}'".format(path))
    return model


def multi_gpu_run(model, im_path, outpath, gpus):
    print('running with multi GPU')
    # 如果patch一样大,开这个会加速
    cudnn.benchmark = True
    # 输入2D Tensor=[Batch,Channel,H,W]
    dataset = DatasetFromImage(im_path)
    # 暂定一个GPU跑一张图
    loader = DataLoader(dataset=dataset, batch_size=gpus)
    trans = transforms.Compose([
        transforms.ToPILImage(),
    ])
    for iter, batch in enumerate(loader, 1):
        y = batch[0]
        bicubic = batch[1]
        names = batch[2]
        # print(y)
        if torch.cuda.is_available():
            y = y.cuda()
        else:
            y = y.cpu()
        im_h_y = model(y)
        print(im_h_y.shape)
        #得看下shape,此处应该有个循环,把batch分解成多个图再合成
        im_h_y = trans(im_h_y)
        bicubic = trans(bicubic).convert("YCbCr")
        y, cb, cr = bicubic.split()
        HR = Image.merge('YCbCr', (im_h_y, cb, cr))
        HR.save(outpath)


if __name__ == '__main__':
    # Demo entry point: load the 2x checkpoint and run it over the images in
    # ./temp/out/, using ./temp/vsr/ as the output location.
    print('running with multi GPU')
    num_gpus = 1
    src_dir = './temp/out/'
    dst_dir = './temp/vsr/'
    # Create the output directory on the first run.
    if not os.path.exists(dst_dir):
        os.makedirs(dst_dir)
    net = model_convert('./model/a2/model_new.pth', 2, num_gpus)
    print('load success')
    multi_gpu_run(net, src_dir, dst_dir, num_gpus)