""" This model only disentangles pid from wid inside the IdentityEncoder, which is the most crucial part """ from __future__ import print_function, division from collections import OrderedDict import torch from torch.autograd import Variable import numpy as np import Options import random import embedding_utils import loss_functions import network.FAN_feature_extractor as FAN_feature_extractor import network.IdentityEncoder as IdentityEncoder # import network.VGGM as VGGM import network.Decoder_networks as Decoder_network import network.mfcc_networks as mfcc_networks import network.networks as networks import util.util as util from network import Discriminator_networks as Discriminator_networks # opt = Options.Config() class GenModel(): def __init__(self, opt): self.opt = opt self.Tensor = torch.cuda.FloatTensor if opt.cuda_on else torch.Tensor # define tensors self.input_A = self.Tensor(opt.batchSize, opt.image_channel_size, opt.image_size, opt.image_size) self.input_B = self.Tensor(opt.batchSize, opt.pred_length, opt.image_channel_size, opt.image_size, opt.image_size) self.input_video = self.Tensor(opt.batchSize, opt.sequence_length + 1, opt.image_channel_size, opt.image_size, opt.image_size) self.input_audio = self.Tensor(opt.batchSize, opt.sequence_length + 1, 1, opt.mfcc_length, opt.mfcc_width) self.B_audio = self.Tensor(opt.batchSize, opt.pred_length, 1, opt.mfcc_length, opt.mfcc_width) self.input_video_dis = self.Tensor(opt.batchSize, opt.disfc_length , opt.image_channel_size, opt.image_size, opt.image_size) self.video_pred_data = self.Tensor(opt.batchSize, opt.pred_length, opt.image_channel_size, opt.image_size, opt.image_size) self.audio_pred_data = self.Tensor(opt.batchSize, opt.pred_length, 1, opt.image_size, opt.image_size) self.ID_encoder = IdentityEncoder.IdentityEncoder() self.Decoder = Decoder_network.Decoder(opt) # audio wid feature encoder self.mfcc_encoder = mfcc_networks.mfcc_encoder_two(opt) # visual wid feature encoder self.lip_feature_encoder = FAN_feature_extractor.FanFusion(opt) # discriminator to disentangle wid from pid self.ID_lip_discriminator = Discriminator_networks.ID_dis32(feature_length=64, config=opt) # Classifier from wid to class label self.model_fusion = networks.ModelFusion(opt) # discriminator for adv in embedding wid self.discriminator_audio = networks.discriminator_audio() use_sigmoid = opt.no_lsgan self.netD = Discriminator_networks.Discriminator(input_nc=3, use_sigmoid=use_sigmoid) self.netD_mul = Discriminator_networks.Discriminator(input_nc=3 * opt.sequence_length, use_sigmoid=use_sigmoid) self.netD_mul.apply(networks.weights_init) self.netD.apply(networks.weights_init) # self.Decoder.apply(networks.weights_init) self.ID_lip_discriminator.apply(networks.weights_init) self.old_lr = opt.lr # define loss functions self.criterionGAN = loss_functions.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor, softlabel=False) self.criterionGAN_soft = loss_functions.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor, softlabel=True) self.criterionL1 = torch.nn.L1Loss() self.criterionSmoothL1 = torch.nn.SmoothL1Loss() self.criterionL2 = torch.nn.MSELoss() self.L2Contrastive = loss_functions.L2ContrastiveLoss(margin=opt.L2margin) self.criterionCE = torch.nn.CrossEntropyLoss() self.inv_dis_loss = loss_functions.L2SoftmaxLoss() # initialize optimizers self.optimizer_G = torch.optim.Adam(list(self.Decoder.parameters()) + list(self.ID_encoder.parameters()) + list(self.model_fusion.parameters()) + list(self.mfcc_encoder.parameters()) + 
                                            list(self.lip_feature_encoder.parameters()),
                                            lr=opt.lr, betas=(opt.beta1, 0.999))
        self.optimizer_D = torch.optim.Adam(list(self.netD.parameters()) +
                                            list(self.netD_mul.parameters()) +
                                            list(self.discriminator_audio.parameters()) +
                                            list(self.ID_lip_discriminator.parameters()),
                                            lr=opt.lr, betas=(opt.beta1, 0.999))

        if torch.cuda.is_available():
            if opt.cuda_on:
                if opt.mul_gpu:
                    self.ID_encoder = torch.nn.DataParallel(self.ID_encoder)
                    self.Decoder = torch.nn.DataParallel(self.Decoder)
                    self.mfcc_encoder = torch.nn.DataParallel(self.mfcc_encoder)
                    self.netD_mul = torch.nn.DataParallel(self.netD_mul)
                    self.netD = torch.nn.DataParallel(self.netD)
                    self.lip_feature_encoder = torch.nn.DataParallel(self.lip_feature_encoder)
                    self.ID_lip_discriminator = torch.nn.DataParallel(self.ID_lip_discriminator)
                    self.model_fusion = torch.nn.DataParallel(self.model_fusion)
                    self.discriminator_audio = torch.nn.DataParallel(self.discriminator_audio)
                self.ID_encoder.cuda()
                self.Decoder.cuda()
                self.mfcc_encoder.cuda()
                self.lip_feature_encoder.cuda()
                self.ID_lip_discriminator.cuda()
                self.netD_mul.cuda()
                self.netD.cuda()
                self.criterionL1.cuda()
                self.criterionGAN.cuda()
                self.criterionGAN_soft.cuda()
                self.criterionL2.cuda()
                self.criterionCE.cuda()
                self.inv_dis_loss.cuda()
                self.model_fusion.cuda()
                self.discriminator_audio.cuda()

        print('---------- Networks initialized -------------')

    def name(self):
        return 'GenModel'

    def set_input(self, input, input_label):
        input_video = input['video']
        input_audio = input['mfcc20']
        self.input_label = input_label.cuda()
        dis_select_start = random.randint(0, 25 - self.opt.disfc_length - 1)
        A_select = random.randint(0, 28)
        pred_start = random.randint(0, 1)
        input_A = input_video[:, A_select, :, :, :].contiguous()
        input_video_dis = input_video[:, dis_select_start:dis_select_start + self.opt.disfc_length, :, :, :]
        video_pred_data = input_video[:, pred_start:pred_start + self.opt.pred_length * 2:2, :, :, :]
        audio_pred_data = input_audio[:, pred_start:pred_start + self.opt.pred_length * 2:2, :, :, :]
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.input_video_dis.resize_(input_video_dis.size()).copy_(input_video_dis)
        self.video_pred_data.resize_(video_pred_data.size()).copy_(video_pred_data)
        self.audio_pred_data.resize_(audio_pred_data.size()).copy_(audio_pred_data)
        self.image_paths = input['A_path']

    def forward(self):
        self.input_label = Variable(self.input_label)
        self.real_A = Variable(self.input_A)
        B_start = random.randint(0, self.opt.pred_length - self.opt.sequence_length)
        self.audios_dis = Variable(self.audio_pred_data)
        self.video_dis = Variable(self.video_pred_data)
        # real_videos are the frames used for training the generation
        self.real_videos = Variable(self.video_pred_data[:, B_start:B_start + self.opt.sequence_length, :, :, :].contiguous())
        self.audios = Variable(self.audio_pred_data[:, B_start:B_start + self.opt.sequence_length, :, :, :].contiguous())
        self.video_send_to_disfc = Variable(self.input_video_dis)
        # mask covers the rectangular region (rows 170:234, cols 64:192) that
        # gets extra weight in the L1 loss; mask_ones covers its complement
        self.mask = Variable(self.Tensor(self.opt.batchSize,
                                         self.opt.sequence_length * self.opt.image_channel_size,
                                         self.opt.image_size, self.opt.image_size).fill_(0))
        self.mask[:, :, 170:234, 64:192] = 1
        self.mask_ones = Variable(self.Tensor(self.opt.batchSize, self.opt.image_channel_size,
                                              self.opt.image_size, self.opt.image_size).fill_(1))
        self.mask_ones[:, :, 170:234, 64:192] = 0
        self.mfcc_encoder.train()
        self.lip_feature_encoder.train()
        # compute the ID embeddings
        self.real_A_id_embedding = self.ID_encoder.forward(self.real_A)
        # compute the sequence ID embeddings
        if self.opt.disfc_length == 12:
            self.sequence_id_embedding = self.ID_encoder.forward(self.video_dis)
        else:
            self.sequence_id_embedding = self.ID_encoder.forward(self.video_send_to_disfc)
        self.sequence_id_embedding = self.sequence_id_embedding[4].view(-1, self.opt.disfc_length * 64, 64, 64)
        # extract the lip features
        # self.audio_embedding = self.mfcc_encoder.forward(self.audio_A)
        self.audio_embeddings_dis = self.mfcc_encoder.forward(self.audios_dis)
        self.lip_embeddings_dis = self.lip_feature_encoder.forward(self.video_dis)
        self.audio_embeddings = self.audio_embeddings_dis[:, B_start:B_start + self.opt.sequence_length].contiguous()
        self.lip_embeddings = self.lip_embeddings_dis[:, B_start:B_start + self.opt.sequence_length].contiguous()
        # loss between the audio and lip embeddings
        self.lip_embedding_norm = embedding_utils.l2_norm(self.lip_embeddings_dis.view(-1, 256 * self.opt.pred_length))
        self.audio_embedding_norm = embedding_utils.l2_norm(self.audio_embeddings_dis.view(-1, 256 * self.opt.pred_length))
        self.lip_embeddings_buffer = Variable(self.lip_embedding_norm.data)
        self.EmbeddingL2 = self.L2Contrastive.forward(self.lip_embeddings_buffer, self.audio_embedding_norm)
        # generate fake images
        self.sequence_generation()
        # single-frame views
        self.fakes = torch.cat((self.audio_gen_fakes_batch, self.image_gen_fakes_batch), 0)
        self.real_one = self.real_videos.view(-1, self.opt.image_channel_size,
                                              self.opt.image_size, self.opt.image_size)
        self.reals = torch.cat((self.real_one, self.real_one), 0)
        self.audio_reals = torch.cat((self.audios.view(-1, 1, self.opt.mfcc_length, self.opt.mfcc_width),
                                      self.audios.view(-1, 1, self.opt.mfcc_length, self.opt.mfcc_width)), 0)
        # sequence views
        self.fakes_sequence = self.fakes.view(-1, self.opt.image_channel_size * self.opt.sequence_length,
                                              self.opt.image_size, self.opt.image_size)
        self.real_one_sequence = self.real_videos.view(-1, self.opt.image_channel_size * self.opt.sequence_length,
                                                       self.opt.image_size, self.opt.image_size)
        self.reals_sequence = self.reals.view(-1, self.opt.image_channel_size * self.opt.sequence_length,
                                              self.opt.image_size, self.opt.image_size)
        self.audio_reals_sequence = self.audio_reals.view(-1, self.opt.sequence_length,
                                                          self.opt.mfcc_length, self.opt.mfcc_width)

    def sequence_generation(self):
        self.lip_embeddings = self.lip_embeddings.view(-1, self.opt.sequence_length, self.opt.feature_length)
        image_gen_fakes = []
        self.audio_embeddings = self.audio_embeddings.view(-1, self.opt.sequence_length, self.opt.feature_length)
        audio_gen_fakes = []
        self.last_frame = Variable(self.real_A.data)
        self.G_x_loss = 0
        for i in range(self.opt.sequence_length):
            # decode one frame from the identity embedding plus each wid embedding
            image_gen_fakes_buffer = self.Decoder(self.real_A_id_embedding, self.lip_embeddings[:, i, :])
            image_gen_fakes.append(image_gen_fakes_buffer.view(-1, 1, self.opt.image_channel_size,
                                                               self.opt.image_size, self.opt.image_size))
            audio_gen_fakes_buffer = self.Decoder(self.real_A_id_embedding, self.audio_embeddings[:, i, :])
            audio_gen_fakes.append(audio_gen_fakes_buffer.view(-1, 1, self.opt.image_channel_size,
                                                               self.opt.image_size, self.opt.image_size))
            # temporal consistency between consecutive frames outside the masked region
            self.G_x_loss = self.G_x_loss + self.criterionL1(audio_gen_fakes_buffer * self.mask_ones,
                                                             self.last_frame * self.mask_ones)
            self.last_frame = Variable(audio_gen_fakes_buffer.data)
        self.image_gen_fakes = torch.cat(image_gen_fakes, 1)
        self.image_gen_fakes_batch = self.image_gen_fakes.view(-1, self.opt.image_channel_size,
                                                               self.opt.image_size, self.opt.image_size)
        self.image_gen_fakes = self.image_gen_fakes.view(-1, self.opt.image_channel_size * self.opt.sequence_length,
                                                         self.opt.image_size, self.opt.image_size)
        self.audio_gen_fakes = torch.cat(audio_gen_fakes, 1)
        self.audio_gen_fakes_batch = self.audio_gen_fakes.view(-1, self.opt.image_channel_size,
                                                               self.opt.image_size, self.opt.image_size)
        self.audio_gen_fakes = self.audio_gen_fakes.view(-1, self.opt.image_channel_size * self.opt.sequence_length,
                                                         self.opt.image_size, self.opt.image_size)

    def backward_dis(self):
        # the audio embeddings are treated as "real", the lip embeddings as "fake"
        self.audio_D_real = self.discriminator_audio(self.audio_embeddings_dis.detach())
        self.audio_D_fake = self.discriminator_audio(self.lip_embeddings_dis.detach())
        self.image_loss_D_real = self.criterionGAN(self.audio_D_fake, False)
        self.audio_loss_D_real = self.criterionGAN(self.audio_D_real, True)
        self.dis_R_loss = (self.image_loss_D_real + self.audio_loss_D_real) * 0.5
        self.dis_R_loss.backward()

    def backward_D(self):
        # train the ID discriminator fc
        self.lip_pred = self.ID_lip_discriminator(self.sequence_id_embedding.detach())
        self.CE_loss = self.criterionCE(self.lip_pred, self.input_label) * self.opt.lambda_CE
        if self.opt.require_single_GAN:
            # GAN single fake
            self.pred_fake_single, self.pred_fake_single_combine = self.netD.forward(self.fakes.detach(), self.audio_reals)
            self.loss_D_single_fake = self.criterionGAN_soft(self.pred_fake_single, False)
            self.loss_D_single_combine_fake = self.criterionGAN_soft(self.pred_fake_single_combine, False)
            # GAN single real
            self.pred_real, self.pred_real_combine = self.netD.forward(self.reals, self.audio_reals)
            self.loss_D_single_real = self.criterionGAN_soft(self.pred_real, True)
            self.loss_D_single_combine_real = self.criterionGAN_soft(self.pred_real_combine, True)
            self.loss_D_single = (self.loss_D_single_fake + self.loss_D_single_real) * 0.5
            self.loss_D_single_combine = (self.loss_D_single_combine_fake + self.loss_D_single_combine_real) * 0.5
        else:
            self.loss_D_single_combine = 0
            self.loss_D_single = 0
        if self.opt.require_sequence_GAN:
            # GAN sequence fake
            self.pred_fake_sequence, self.pred_fake_sequence_combine = self.netD_mul.forward(self.fakes_sequence.detach(), self.audio_reals_sequence)
            self.loss_D_sequence_fake = self.criterionGAN_soft(self.pred_fake_sequence, False)
            self.loss_D_sequence_combine_fake = self.criterionGAN_soft(self.pred_fake_sequence_combine, False)
            # GAN sequence real
            self.pred_real_sequence, self.pred_real_sequence_combine = self.netD_mul.forward(self.reals_sequence, self.audio_reals_sequence)
            self.loss_D_sequence_real = self.criterionGAN_soft(self.pred_real_sequence, True)
            self.loss_D_sequence_combine_real = self.criterionGAN_soft(self.pred_real_sequence_combine, True)
            self.loss_D_sequence = (self.loss_D_sequence_fake + self.loss_D_sequence_real) * 0.5
            self.loss_D_sequence_combine = (self.loss_D_sequence_combine_fake + self.loss_D_sequence_combine_real) * 0.5
        else:
            self.loss_D_sequence_combine = 0
            self.loss_D_sequence = 0
        # combined loss
        self.loss_D = (self.loss_D_sequence_combine + self.loss_D_sequence) + \
                      (self.loss_D_single_combine + self.loss_D_single) + \
                      self.CE_loss
        self.loss_D.backward()

    def backward_G(self):
        self.audio_D_real = self.discriminator_audio(self.audio_embeddings_dis)
        self.audio_loss_D_inv = self.criterionGAN(self.audio_D_real, False)
        self.audio_D_fake = self.discriminator_audio(self.lip_embeddings_dis)
        self.image_loss_D_inv = self.criterionGAN(self.audio_D_fake, True)
        # classification
        self.audio_pred = self.model_fusion.forward(self.audio_embeddings_dis)
        self.audio_CE_loss = self.criterionCE(self.audio_pred, self.input_label)
        self.audio_acc = self.compute_acc(self.audio_pred)
        self.image_pred = self.model_fusion.forward(self.lip_embeddings_dis)
        self.image_CE_loss = self.criterionCE(self.image_pred, self.input_label)
        self.image_acc = self.compute_acc(self.image_pred)
        # id discriminator
        self.lip_pred = self.ID_lip_discriminator(self.sequence_id_embedding)
        self.softmax_loss = self.inv_dis_loss.forward(self.lip_pred) * self.opt.lambda_CE_inv
        self.lip_acc = self.compute_acc(self.lip_pred)
        # single
        if self.opt.require_single_GAN:
            pred_fake, pred_combine_fake = self.netD.forward(self.fakes, self.audio_reals)
            self.loss_G_GAN_single = self.criterionGAN(pred_fake, True)
            self.loss_G_GAN_single_combine = self.criterionGAN(pred_combine_fake, True)
        else:
            self.loss_G_GAN_single = 0
            self.loss_G_GAN_single_combine = 0
        # sequence
        if self.opt.require_sequence_GAN:
            pred_fake, pred_combine_fake = self.netD_mul.forward(self.fakes_sequence, self.audio_reals_sequence)
            self.loss_G_GAN_sequence = self.criterionGAN(pred_fake, True)
            self.loss_G_GAN_sequence_combine = self.criterionGAN(pred_combine_fake, True)
        else:
            self.loss_G_GAN_sequence = 0
            self.loss_G_GAN_sequence_combine = 0
        self.loss_G_L1_audio = self.criterionL1(self.audio_gen_fakes * 255, self.real_one_sequence * 255) * self.opt.lambda_A + \
            self.criterionL1(self.audio_gen_fakes * self.mask * 255, self.real_one_sequence * self.mask * 255) * self.opt.lambda_B
        self.loss_G_L1_image = self.criterionL1(self.image_gen_fakes * 255, self.real_one_sequence * 255) * self.opt.lambda_A + \
            self.criterionL1(self.image_gen_fakes * self.mask * 255, self.real_one_sequence * self.mask * 255) * self.opt.lambda_B
        self.loss_G = (self.loss_G_GAN_single + self.loss_G_GAN_single_combine) + \
            (self.loss_G_GAN_sequence + self.loss_G_GAN_sequence_combine) + \
            self.loss_G_L1_audio + self.loss_G_L1_image + self.G_x_loss * 5 + \
            self.EmbeddingL2 + \
            self.softmax_loss + self.audio_CE_loss + self.image_CE_loss + \
            (self.audio_loss_D_inv + self.image_loss_D_inv) * 5
        self.loss_G.backward()

    def set_test_input(self, input, input_label):
        input_video = input['video']
        input_audio = input['mfcc20']
        self.input_label = input_label.cuda()
        dis_select_start = random.randint(0, 25 - self.opt.disfc_length - 1)
        pred_start = random.randint(0, 1)
        input_video_dis = input_video[:, dis_select_start:dis_select_start + self.opt.disfc_length, :, :, :]
        video_pred_data = input_video[:, pred_start:pred_start + self.opt.pred_length * 2:2, :, :, :]
        audio_pred_data = input_audio[:, pred_start:pred_start + self.opt.pred_length * 2:2, :, :, :]
        self.input_video_dis.resize_(input_video_dis.size()).copy_(input_video_dis)
        self.video_pred_data.resize_(video_pred_data.size()).copy_(video_pred_data)
        self.audio_pred_data.resize_(audio_pred_data.size()).copy_(audio_pred_data)
        self.image_paths = input['A_path']

    def test(self):
        self.mfcc_encoder.eval()
        self.lip_feature_encoder.eval()
        self.input_label = Variable(self.input_label, volatile=True)
        self.audios_dis = Variable(self.audio_pred_data, volatile=True)
        self.video_dis = Variable(self.video_pred_data, volatile=True)
        # compute the sequence ID embeddings
        self.audio_embeddings_dis = self.mfcc_encoder.forward(self.audios_dis).view(-1, 256 * self.opt.pred_length)
        self.lip_embeddings_dis = self.lip_feature_encoder.forward(self.video_dis).view(-1, 256 * self.opt.pred_length)
        # loss between the audio and lip embeddings
        self.lip_embedding_norm = embedding_utils.l2_norm(self.lip_embeddings_dis)
        self.audio_embedding_norm = embedding_utils.l2_norm(self.audio_embeddings_dis)
        self.lip_embeddings_buffer = Variable(self.lip_embedding_norm.data)
        self.EmbeddingL2 = self.L2Contrastive.forward(self.audio_embedding_norm, self.lip_embeddings_buffer)
        # classification
        self.audio_pred = self.model_fusion.forward(self.audio_embeddings_dis)
        self.audio_acc = self.compute_acc(self.audio_pred)
        self.image_pred = self.model_fusion.forward(self.lip_embeddings_dis)
        self.image_acc = self.compute_acc(self.image_pred)
        self.output = self.audio_pred + self.image_pred
        self.final_acc = self.compute_acc(self.output)

    def forward_no_generation(self):
        # used when training without generation
        self.mfcc_encoder.train()
        self.lip_feature_encoder.train()
        self.model_fusion.train()
        self.input_audio_data = Variable(self.audio_pred_data)
        self.input_image_data = Variable(self.video_pred_data)
        self.input_label = Variable(self.input_label)
        self.audio_embeddings_dis = self.mfcc_encoder.forward(self.input_audio_data)
        self.lip_embeddings_dis = self.lip_feature_encoder.forward(self.input_image_data)
        self.audio_fusion = self.audio_embeddings_dis.view(-1, int(256 * self.opt.pred_length / 3))
        self.image_fusion = self.lip_embeddings_dis.view(-1, int(256 * self.opt.pred_length / 3))
        self.audio_fusion_buffer = Variable(self.audio_fusion.data)
        self.image_fusion_buffer = Variable(self.image_fusion.data)

    def backward_no_generation(self):
        # used when training without generation
        self.audio_D_real = self.discriminator_audio(self.audio_embeddings_dis)
        self.audio_loss_D_inv = self.criterionGAN(self.audio_D_real, False)
        self.audio_output = self.model_fusion.forward(self.audio_embeddings_dis)
        self.audio_CE_loss = self.criterionCE(self.audio_output, self.input_label)
        self.audio_L2_loss = self.L2Contrastive.forward(self.image_fusion.detach(), self.audio_fusion)
        self.loss_audio = self.audio_loss_D_inv + self.audio_CE_loss + self.audio_L2_loss
        self.loss_audio.backward()
        audio_D_fake = self.discriminator_audio(self.lip_embeddings_dis)
        self.image_loss_D_inv = self.criterionGAN(audio_D_fake, True)
        self.image_output = self.model_fusion.forward(self.lip_embeddings_dis)
        self.image_acc = self.compute_acc(self.image_output)
        self.output = self.audio_output + self.image_output
        self.acc = self.compute_acc(self.output)
        self.image_CE_loss = self.criterionCE(self.image_output, self.input_label)
        self.image_L2_loss = self.L2Contrastive.forward(self.audio_fusion.detach(), self.image_fusion)
        self.loss_image = self.image_L2_loss + self.image_CE_loss + self.image_loss_D_inv
        self.loss_image.backward()

    def save_feature(self):
        self.ID_encoder.eval()
        self.ID_lip_discriminator.eval()
        self.video_send_to_disfc = Variable(self.input_video_dis, volatile=True)
        # compute the sequence ID embeddings
        # self.audio_embedding = self.mfcc_encoder.forward(self.audio_A)
        self.sequence_id_embedding = self.ID_encoder.forward(self.video_send_to_disfc)
        # self.sequence_id_embedding = self.sequence_id_embedding[0].view(-1, opt.disfc_length, opt.feature_length)
        self.lip_pred_feature = self.sequence_id_embedding[0].view(-1, self.opt.disfc_length * 256)

    def optimize_parameters(self):
        self.forward()
        self.optimizer_D.zero_grad()
        self.backward_dis()
        self.backward_D()
        self.optimizer_D.step()
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

    def optimize_parameters_no_generation(self):
        self.forward_no_generation()
        self.optimizer_D.zero_grad()
        self.backward_dis()
        self.optimizer_D.step()
        self.optimizer_G.zero_grad()
        self.backward_no_generation()
        self.optimizer_G.step()

    def get_current_errors(self):
        if self.opt.require_single_GAN:
            return OrderedDict([('G_GAN_single', self.loss_G_GAN_single.data[0]),
                                ('G_GAN_single_combine', self.loss_G_GAN_single_combine.data[0]),
                                ('G_L1_audio', self.loss_G_L1_audio.data[0]),
                                ('G_L1_image', self.loss_G_L1_image.data[0]),
                                ('D_real_single', self.loss_D_single_real.data[0]),
                                ('D_fake_single', self.loss_D_single_fake.data[0]),
                                ('D_combine_real_single', self.loss_D_single_combine_real.data[0]),
                                ('D_combine_fake_single', self.loss_D_single_combine_fake.data[0]),
                                ('CE_loss', self.CE_loss.data[0]),
                                ('loss_softmax', self.softmax_loss.data[0]),
                                ('audio_acc', self.audio_acc),
                                ('image_acc', self.image_acc),
                                ('EmbeddingL2', self.EmbeddingL2.data[0]),
                                ('dis_R_loss', self.dis_R_loss.data[0])])
        else:
            return OrderedDict([('G_GAN_sequence', self.loss_G_GAN_sequence.data[0]),
                                ('G_GAN_sequence_combine', self.loss_G_GAN_sequence_combine.data[0]),
                                ('G_L1_audio', self.loss_G_L1_audio.data[0]),
                                ('G_L1_image', self.loss_G_L1_image.data[0]),
                                ('D_real_sequence', self.loss_D_sequence_real.data[0]),
                                ('D_fake_sequence', self.loss_D_sequence_fake.data[0]),
                                ('D_combine_real_sequence', self.loss_D_sequence_combine_real.data[0]),
                                ('D_combine_fake_sequence', self.loss_D_sequence_combine_fake.data[0]),
                                ('CE_loss', self.CE_loss.data[0]),
                                ('loss_softmax', self.softmax_loss.data[0]),
                                ('audio_acc', self.audio_acc),
                                ('image_acc', self.image_acc),
                                ('EmbeddingL2', self.EmbeddingL2.data[0]),
                                ('dis_R_loss', self.dis_R_loss.data[0])])

    def get_current_visuals(self):
        fake_B_audio = self.audio_gen_fakes.view(-1, self.opt.sequence_length, self.opt.image_channel_size,
                                                 self.opt.image_size, self.opt.image_size)
        fake_B_image = self.image_gen_fakes.view(-1, self.opt.sequence_length, self.opt.image_channel_size,
                                                 self.opt.image_size, self.opt.image_size)
        real_A = util.tensor2im(self.real_A.data)
        visuals = OrderedDict([('real_A', real_A)])
        fake_audio_B = {}
        fake_image_B = {}
        real_B = {}
        for i in range(self.opt.sequence_length):
            fake_audio_B[i] = util.tensor2im(fake_B_audio[:, i, :, :, :].data)
            fake_image_B[i] = util.tensor2im(fake_B_image[:, i, :, :, :].data)
            real_B[i] = util.tensor2im(self.real_videos[:, i, :, :, :].data)
            visuals['real_B_' + str(i)] = real_B[i]
            visuals['fake_audio_B_' + str(i)] = fake_audio_B[i]
            visuals['fake_image_B_' + str(i)] = fake_image_B[i]
        return visuals

    def get_visual_path(self):
        print(self.image_paths[0])

    def update_learning_rate(self):
        lrd = self.opt.lr / self.opt.niter_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_D.param_groups:
            param_group['lr'] = lr
        for param_group in self.optimizer_G.param_groups:
            param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr

    def compute_acc(self, out):
        _, pred = out.topk(1, 1)
        pred0 = pred.squeeze().data
        acc = 100 * torch.sum(pred0 == self.input_label.data) / self.input_label.size(0)
        return acc

    def TfWriter(self, writer, total_steps):
        # write losses to tensorboard
        writer.add_scalar('train_image_L2_loss', embedding_utils.to_np(self.EmbeddingL2), total_steps)
        writer.add_scalar('image_loss_D_inv', embedding_utils.to_np(self.image_loss_D_inv), total_steps)
        writer.add_scalar('train_audio_acc', self.audio_acc, total_steps)
        writer.add_scalar('train_image_acc', self.image_acc, total_steps)
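
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original training code).
# It assumes Options.Config exposes the option fields used above, that a
# `niter` field gives the number of epochs before decay starts (niter_decay
# is already used by update_learning_rate), and that `train_loader` yields
# (data, label) pairs where `data` is a dict with 'video', 'mfcc20' and
# 'A_path' keys, matching what set_input() expects.
#
# opt = Options.Config()
# model = GenModel(opt)
# for epoch in range(opt.niter + opt.niter_decay):
#     for data, label in train_loader:
#         model.set_input(data, label)
#         model.optimize_parameters()
#     print(model.get_current_errors())
#     if epoch >= opt.niter:
#         model.update_learning_rate()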