""" Pix2pixHD model with additional image data as input It basically takes additional inputs of cropped image, which is then used as additional 3-channel input. The additional image input has "hole" in the regions inside object bounding box, which is subsequently filled by the generator network. """ import numpy as np import torch import os from torch.autograd import Variable from util.image_pool import ImagePool from .base_model import BaseModel from layer_util import * import util.util as util from collections import OrderedDict NULLVAL = 0.0 # TODO(sh): change the forward_wrapper things to enable multi-gpu training # TODO(sh): add additional input of "mask_in" for the binary mask of object region # TODO(sh): enlarge the context marign class Pix2PixHDModel_condImg(BaseModel): def __init__(self, opt): super(Pix2PixHDModel_condImg, self).__init__(opt) if opt.resize_or_crop != 'none': # when training at full res this causes OOM torch.backends.cudnn.benchmark = True self.isTrain = opt.isTrain self.netG_type = opt.netG self.use_features = opt.instance_feat or opt.label_feat self.gen_features = self.use_features and not self.opt.load_features # NOTE(sh): 3-channels for adddional rgb-image input_nc = opt.label_nc if opt.label_nc != 0 else 3 ##### define networks # Generator network netG_input_nc = input_nc if not opt.no_instance: netG_input_nc += 1 if self.use_features: netG_input_nc += opt.feat_num from .Pix2Pix_NET import GlobalGenerator, GlobalTwoStreamGenerator if opt.netG=='global': netG_input_nc += 3 self.netG = GlobalGenerator(netG_input_nc, opt.output_nc, opt.ngf, opt.n_downsample_global, opt.n_blocks_global, opt.norm, 'reflect', opt.use_output_gate) elif opt.netG=='global_twostream': self.netG = GlobalTwoStreamGenerator(netG_input_nc, opt.output_nc, opt.ngf, opt.n_downsample_global, opt.n_blocks_global, opt.norm, 'reflect', opt.use_skip, opt.which_encoder, opt.use_output_gate, opt.feat_fusion) else: raise NameError('global generator name is not defined properly: %s' % opt.netG) print(self.netG) if len(opt.gpu_ids) > 0: assert(torch.cuda.is_available()) self.netG.cuda(opt.gpu_ids[0]) self.netG.apply(weights_init) # Discriminator network if self.isTrain: self.no_imgCond = opt.no_imgCond self.mask_gan_input = opt.mask_gan_input self.use_soft_mask = opt.use_soft_mask use_sigmoid = opt.no_lsgan if self.no_imgCond: netD_input_nc = input_nc + opt.output_nc else: netD_input_nc = input_nc + 3 + opt.output_nc if not opt.no_instance: netD_input_nc += 1 if opt.netG=='global_twostream' and self.opt.which_encoder=='ctx': netD_input_nc = 3 from .Discriminator_NET import MultiscaleDiscriminator self.netD = MultiscaleDiscriminator(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, opt.num_D, not opt.no_ganFeat_loss) print(self.netD) if len(opt.gpu_ids) > 0: assert(torch.cuda.is_available()) self.netD.cuda(opt.gpu_ids[0]) self.netD.apply(weights_init) ### Encoder network if self.gen_features: from .Pix2Pix_NET import Encoder self.netE = Encoder(opt.output_nc, opt.feat_num, opt.nef, opt.n_downsample_E, opt.norm) print(self.netE) if len(opt.gpu_ids) > 0: assert(torch.cuda.is_available()) self.netE.cuda(opt.gpu_ids[0]) self.netE.apply(weights_init) print('---------- Networks initialized -------------') # load networks if not self.isTrain or opt.continue_train or opt.load_pretrain: pretrained_path = '' if not self.isTrain else opt.load_pretrain self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) if self.isTrain: self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path) if self.gen_features: self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path) # set loss functions and optimizers if self.isTrain: if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: raise NotImplementedError("Fake Pool Not Implemented for MultiGPU") self.fake_pool = ImagePool(opt.pool_size) self.old_lr = opt.lr # define loss functions from .losses import GANLoss, VGGLoss self.criterionGAN = GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) self.criterionFeat = torch.nn.L1Loss() if not opt.no_vgg_loss: self.criterionVGG = VGGLoss(self.gpu_ids) # Names so we can breakout loss self.loss_names = ['G_GAN', 'G_GAN_Feat', 'G_VGG', 'D_real', 'D_fake'] # initialize optimizers # optimizer G if opt.niter_fix_global > 0: print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global) params_dict = dict(self.netG.named_parameters()) params = [] for key, value in params_dict.items(): if key.startswith('model' + str(opt.n_local_enhancers)): params += [{'params':[value],'lr':opt.lr}] else: params += [{'params':[value],'lr':0.0}] else: params = list(self.netG.parameters()) if self.gen_features: params += list(self.netE.parameters()) self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) # optimizer D params = list(self.netD.parameters()) self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) def name(self): return 'Pix2PixHDModel_condImg' def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, mask_in=None, infer=False): if self.opt.label_nc == 0: input_label = label_map.data.cuda() else: # create one-hot vector for label map size = label_map.size() oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) # get edges from instance map if not self.opt.no_instance: inst_map = inst_map.data.cuda() edge_map = self.get_edges(inst_map) input_label = torch.cat((input_label, edge_map), dim=1) input_label = Variable(input_label, volatile=infer) # real images for training assert(real_image is not None) assert(mask_in is not None) real_image = Variable(real_image.data.cuda()) mask_object_box = mask_in.repeat(1,3,1,1).cuda() cond_image = (1 - mask_object_box) * real_image + mask_object_box * NULLVAL # TODO(sh): define null_img # instance map for feature encoding if self.use_features: # get precomputed feature maps if self.opt.load_features: feat_map = Variable(feat_map.data.cuda()) return input_label, inst_map, real_image, feat_map, cond_image def discriminate(self, input_label, test_image, mask, use_pool=False): input_concat = torch.cat((input_label, test_image.detach()), dim=1) if self.opt.netG=='global_twostream' and self.opt.which_encoder=='ctx': input_concat = test_image.detach() if self.mask_gan_input: input_concat = input_concat * mask.repeat(1, input_concat.size(1), 1, 1) if use_pool: fake_query = self.fake_pool.query(input_concat) return self.netD.forward(fake_query) else: return self.netD.forward(input_concat) def forward_wrapper(self, data, infer=False): label = Variable(data['label']) inst = Variable(data['inst']) image = Variable(data['image']) mask_in = Variable(data['mask_in']) mask_out = Variable(data['mask_out']) feat = None losses, generated = self.forward(label, inst, image, feat, mask_in, mask_out, infer) return losses, generated def forward(self, label, inst, image, feat, mask_in, mask_out, infer=False): # Encode Inputs input_label, inst_map, real_image, feat_map, cond_image = self.encode_input(label, inst, image, feat, mask_in=mask_in) # NOTE(sh): modified with additional image input input_mask = input_label.clone() input_label = torch.cat((input_label, cond_image), 1) # Fake Generation input_concat = input_label if self.netG_type == 'global': fake_image = self.netG.forward(input_concat, mask_in) elif self.netG_type == 'global_twostream': fake_image = self.netG.forward(cond_image, input_mask, mask_in) # Fake Detection and Loss if self.no_imgCond: netD_cond = input_mask else: netD_cond = input_label mask_cond = mask_in if not self.use_soft_mask else mask_out pred_fake_pool = self.discriminate(netD_cond, fake_image, mask_cond, True) loss_D_fake = self.criterionGAN(pred_fake_pool, False) # Real Detection and Loss pred_real = self.discriminate(netD_cond, real_image, mask_cond, False) loss_D_real = self.criterionGAN(pred_real, True) # GAN loss (Fake Passability Loss) netD_in = torch.cat((netD_cond, fake_image), dim=1) if self.opt.netG=='global_twostream' and self.opt.which_encoder=='ctx': netD_in = fake_image if self.mask_gan_input: netD_in = netD_in * mask_cond.repeat(1, netD_in.size(1), 1, 1) pred_fake = self.netD.forward(netD_in) loss_G_GAN = self.criterionGAN(pred_fake, True) # GAN feature matching loss loss_G_GAN_Feat = Variable(self.Tensor([0])) if not self.opt.no_ganFeat_loss: feat_weights = 4.0 / (self.opt.n_layers_D + 1) D_weights = 1.0 / self.opt.num_D for i in range(self.opt.num_D): for j in range(len(pred_fake[i])-1): loss_G_GAN_Feat += D_weights * feat_weights * \ self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat # VGG feature matching loss loss_G_VGG = Variable(self.Tensor([0])) if not self.opt.no_vgg_loss: loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat # color matching loss if self.opt.lambda_rec > 0: # TOOD(sh): this part is bit hacky but let's leave it for now loss_G_GAN_Feat += self.criterionFeat(fake_image, real_image.detach()) * self.opt.lambda_rec self.fake_image = fake_image.cpu().data[0] self.real_image = real_image.cpu().data[0] self.input_label = input_mask.cpu().data[0] self.input_image = cond_image.cpu().data[0] # Only return the fake_B image if necessary to save BW return [ [ loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ], None if not infer else fake_image ] def inference(self, label, inst, image, mask_in, mask_out): # Encode Inputs input_label, inst_map, real_image, _, cond_image = self.encode_input(label, inst, image, mask_in=mask_in, infer=True) mask_in = mask_in.cuda() # NOTE(sh): modified with additional image input input_mask = input_label.clone() input_label = torch.cat((input_label, cond_image), 1) # Fake Generation input_concat = input_label if self.netG_type == 'global': fake_image = self.netG.forward(input_concat, mask_in) elif self.netG_type == 'global_twostream': mask_in = mask_in.cuda() fake_image = self.netG.forward(cond_image, input_mask, mask_in) self.fake_image = fake_image.cpu().data[0] self.real_image = real_image.cpu().data[0] self.input_label = input_mask.cpu().data[0] self.input_image = cond_image.cpu().data[0] return fake_image def get_edges(self, t): edge = torch.cuda.ByteTensor(t.size()).zero_() edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1]) edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1]) edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) return edge.float() def get_current_visuals(self): return OrderedDict([ ('input_label', util.tensor2label(self.input_label, self.opt.label_nc)), ('input_image', util.tensor2im(self.input_image)), ('real_image', util.tensor2im(self.real_image)), ('synthesized_image', util.tensor2im(self.fake_image)) ]) def save(self, which_epoch): self.save_network(self.netG, 'G', which_epoch, self.gpu_ids) self.save_network(self.netD, 'D', which_epoch, self.gpu_ids) if self.gen_features: self.save_network(self.netE, 'E', which_epoch, self.gpu_ids) def delete_model(self, which_epoch): self.delete_network('G', which_epoch, self.gpu_ids) self.delete_network('D', which_epoch, self.gpu_ids) def update_fixed_params(self): # after fixing the global generator for a number of iterations, also start finetuning it params = list(self.netG.parameters()) if self.gen_features: params += list(self.netE.parameters()) self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) print('------------ Now also finetuning global generator -----------') def update_learning_rate(self): lrd = self.opt.lr / self.opt.niter_decay lr = self.old_lr - lrd for param_group in self.optimizer_D.param_groups: param_group['lr'] = lr for param_group in self.optimizer_G.param_groups: param_group['lr'] = lr print('update learning rate: %f -> %f' % (self.old_lr, lr)) self.old_lr = lr