"""Training / validation script for a cross-modal image <-> recipe model.

The script builds an image embedding model, a recipe (text) embedding model,
an image generator with its discriminator, a multi-label head and a
cross-modal discriminator, trains them jointly, and after every epoch
evaluates retrieval quality (median rank and recall@{1,5,10}) in both
directions.  All models and optimizers are created at import time on CUDA.
"""
# NOTE: the import order below is kept exactly as in the original file because
# the star-imports (`models`, `triplet_loss`) may rebind names imported earlier.
import os
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import torch.backends.cudnn as cudnn
from data_loader import get_loader
from args import get_parser
from models import *
from torch.optim import lr_scheduler
from tqdm import tqdm
import pdb
import torch.nn.functional as F
from triplet_loss import *
import pickle
from build_vocab import Vocabulary
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence
import torchvision.utils as vutils

# =============================================================================
parser = get_parser()
opts = parser.parse_args()
device = [0]

# NOTE(review): pickle.load can execute arbitrary code; opts.vocab_path must
# point at a trusted vocabulary file.
with open(opts.vocab_path, 'rb') as f:
    vocab = pickle.load(f)
# =============================================================================

# ---- models -----------------------------------------------------------------
image_model = torch.nn.DataParallel(ImageEmbedding().cuda(), device_ids=device)
recipe_model = torch.nn.DataParallel(TextEmbedding().cuda(), device_ids=device)
netG = torch.nn.DataParallel(G_NET().cuda(), device_ids=device)
multi_label_net = torch.nn.DataParallel(MultiLabelNet().cuda(), device_ids=device)
cm_discriminator = torch.nn.DataParallel(cross_modal_discriminator().cuda(), device_ids=device)
text_discriminator = torch.nn.DataParallel(text_emb_discriminator().cuda(), device_ids=device)
netsD = torch.nn.DataParallel(D_NET128().cuda(), device_ids=device)

# ---- loss functions ---------------------------------------------------------
triplet_loss = TripletLoss(device, margin=0.3)
img2text_criterion = nn.MultiLabelMarginLoss().cuda()
# Class index 0 gets weight 0, i.e. it does not contribute to the class loss.
weights_class = torch.Tensor(opts.numClasses).fill_(1)
weights_class[0] = 0
class_criterion = nn.CrossEntropyLoss(weight=weights_class).cuda()
GAN_criterion = nn.BCELoss().cuda()

# ---- fixed-size buffers -----------------------------------------------------
# NOTE(review): `noise`, `real_labels`, `fake_labels` and `label` are all sized
# by opts.batch_size, so the training loader is assumed to yield only full
# batches -- TODO confirm against get_loader (e.g. drop_last).
nz = opts.Z_DIM
noise = Variable(torch.FloatTensor(opts.batch_size, nz)).cuda()  # refilled every step
fixed_noise = Variable(torch.FloatTensor(opts.batch_size, nz).normal_(0, 1)).cuda()
real_labels = Variable(torch.FloatTensor(opts.batch_size).fill_(1)).cuda()
fake_labels = Variable(torch.FloatTensor(opts.batch_size).fill_(0)).cuda()

# Shared projection applied to both modalities before the triplet loss.
fc_sia = nn.Sequential(
    nn.Linear(opts.embDim, opts.embDim),
    nn.BatchNorm1d(opts.embDim),
    nn.Tanh(),
).cuda()

# Order matters: checkpoints are saved as ..._v1.pkl, ..._v2.pkl, ... in this order.
model_list = [image_model, recipe_model, netG, multi_label_net,
              cm_discriminator, text_discriminator, netsD, fc_sia]

# NOTE(review): `fc_sia` and `text_discriminator` are saved in checkpoints but
# are not registered with any optimizer below, so their weights are never
# updated by an optimizer step -- confirm this is intended.
optimizer = torch.optim.Adam([
    {'params': image_model.parameters()},
    {'params': recipe_model.parameters()},
    {'params': netG.parameters()},
    {'params': multi_label_net.parameters()}
], lr=opts.lr, betas=(0.5, 0.999))
optimizers_imgD = torch.optim.Adam(netsD.parameters(), lr=opts.lr, betas=(0.5, 0.999))
optimizer_cmD = torch.optim.Adam(cm_discriminator.parameters(), lr=opts.lr, betas=(0.5, 0.999))

# Identity labels for the triplet loss: [0..B-1, 0..B-1], matching
# torch.cat((image_features, recipe_features)) in train().
label = torch.arange(opts.batch_size).repeat(2).cuda().long()

method = 'acme'
save_folder = method
os.makedirs(save_folder, exist_ok=True)

# Per-epoch trace file; (re)created with a header every time the script starts.
epoch_trace_f_dir = os.path.join(save_folder, "trace_" + method + ".csv")
with open(epoch_trace_f_dir, "w") as f:
    f.write("epoch,lr,I2R,R@1,R@5,R@10,R2I,R@1,R@5,R@10\n")


def main():
    """Build the data loaders and run the train / validate / checkpoint loop.

    After each epoch the validation metrics are appended to the trace CSV and,
    if any image-to-text recall@K improved over the best seen so far, every
    model in ``model_list`` is checkpointed.  After epoch index 30 all learning
    rates are set to 1e-5.
    """
    # NOTE(review): no ToTensor/Normalize is applied here -- presumably
    # get_loader converts and normalises images itself; confirm.
    train_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip()])
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224)])

    cudnn.benchmark = True

    # preparing the training loader
    train_loader = get_loader(opts.img_path, train_transform, vocab, opts.data_path,
                              partition='train', batch_size=opts.batch_size,
                              shuffle=True, num_workers=opts.workers, pin_memory=True)
    print('Training loader prepared.')

    # preparing the validation loader (uses the 'test' partition)
    val_loader = get_loader(opts.img_path, val_transform, vocab, opts.data_path,
                            partition='test', batch_size=opts.batch_size,
                            shuffle=False, num_workers=opts.workers, pin_memory=True)
    print('Validation loader prepared.')

    best_val_i2t = {1: 0.0, 5: 0.0, 10: 0.0}
    # Initialised up front so the "best" report below cannot hit an unbound
    # name when the first epoch does not improve on the 0.0 baseline.
    best_epoch = 0

    for epoch in range(0, opts.epochs):
        train(train_loader, epoch, val_loader)
        recall_i2t, recall_t2i, medR_i2t, medR_t2i = validate(val_loader)

        with open(epoch_trace_f_dir, "a") as f:
            lr = optimizer.param_groups[1]['lr']
            f.write("{},{},{},{},{},{},{},{},{},{}\n".format(
                epoch, lr, medR_i2t, recall_i2t[1], recall_i2t[5], recall_i2t[10],
                medR_t2i, recall_t2i[1], recall_t2i[5], recall_t2i[10]))

        # Checkpoint as soon as any recall@K beats the best recorded value.
        if any(recall_i2t[k] > best_val_i2t[k] for k in best_val_i2t):
            best_val_i2t = recall_i2t
            best_epoch = epoch + 1
            for model_num, model_n in enumerate(model_list, start=1):
                filename = save_folder + '/model_e%03d_v%d.pkl' % (epoch + 1, model_num)
                torch.save(model_n.state_dict(), filename)

        print("best: ", best_epoch, best_val_i2t)
        print('params lr: %f' % optimizer.param_groups[1]['lr'])

        # One-off learning-rate drop for every optimizer / parameter group.
        if epoch == 30:
            for opt in (optimizer, optimizers_imgD, optimizer_cmD):
                for group in opt.param_groups:
                    group['lr'] = 0.00001


def train_Dnet(idx, real_imgs, fake_imgs, mu, label_class):
    """Return the image-discriminator loss for one batch.

    Args:
        idx: index selecting which entry of ``real_imgs`` / ``fake_imgs`` to use.
        real_imgs: indexable collection of real image batches.
        fake_imgs: indexable collection of generated image batches.
        mu: conditioning tensor passed to the discriminator (detached here).
        label_class: unused; kept for interface compatibility.

    Returns:
        BCE loss on real images (target 1) plus BCE loss on fake images
        (target 0), both computed from the first discriminator output.
    """
    netD = netsD
    real_batch = real_imgs[idx]
    fake_batch = fake_imgs[idx]

    # Detach so the discriminator update does not reach the generator/encoders.
    real_logits = netD(real_batch, mu.detach())
    fake_logits = netD(fake_batch.detach(), mu.detach())

    lossD_real = GAN_criterion(real_logits[0], real_labels)
    lossD_fake = GAN_criterion(fake_logits[0], fake_labels)
    return lossD_real + lossD_fake


def KL_loss(mu, logvar):
    """Return ``-0.5 * mean(1 + logvar - mu^2 - exp(logvar))``."""
    # The in-place ops only modify the fresh tensor produced by mu.pow(2).
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.mean(KLD_element).mul_(-0.5)
    return KLD


def train_Gnet(idx, real_imgs, fake_imgs, mu, logvar, label_class):
    """Return the generator loss for one batch.

    The loss is the sum of: the adversarial BCE term (fake images scored
    against the "real" target), class cross-entropy on the discriminator's
    second output for both real and fake images, and ``2 * KL_loss(mu, logvar)``.

    Args:
        idx: index selecting which entry of ``real_imgs`` / ``fake_imgs`` to use.
        real_imgs: indexable collection of real image batches.
        fake_imgs: indexable collection of generated image batches.
        mu, logvar: conditioning statistics returned by the generator.
        label_class: class targets for the cross-entropy terms.
    """
    netD = netsD
    real_batch = real_imgs[idx]
    fake_batch = fake_imgs[idx]

    real_logits = netD(real_batch, mu)
    fake_logits = netD(fake_batch, mu)

    lossG_fake = GAN_criterion(fake_logits[0], real_labels)
    lossG_real_cond = class_criterion(real_logits[1], label_class)
    lossG_fake_cond = class_criterion(fake_logits[1], label_class)
    lossG_cond = lossG_real_cond + lossG_fake_cond

    kl_loss = KL_loss(mu, logvar) * 2
    return kl_loss + (lossG_fake + lossG_cond)


def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP.

    Args:
        D: critic; called with a single tensor argument.
        real_samples: 2-D tensor of "real" samples (first dim is the batch).
        fake_samples: tensor of "fake" samples, same shape as ``real_samples``.

    Returns:
        Scalar tensor ``mean((||grad_x D(x)||_2 - 1)^2)`` evaluated at random
        interpolations ``x`` between real and fake samples.
    """
    # One random interpolation weight per sample (broadcast over features).
    alpha = torch.tensor(np.random.random((real_samples.size(0), 1)),
                         dtype=torch.float32, device=real_samples.device)
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)

    # Gradient of the critic output w.r.t. the interpolated inputs.
    # create_graph=True keeps the penalty itself differentiable.
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty


def train(train_loader, epoch, val_loader):
    """Run one training epoch and print a summary of the running losses.

    Per batch: (1) update the cross-modal discriminator, (2) update the image
    discriminator, (3) update the embedding models, generator and multi-label
    head with ``tri_loss + 0.005 * loss_cmG + 0.002 * (loss_img + loss_text)``.

    Args:
        train_loader: iterable of batches; ``data[0][0..4]`` feed the two
            embedding models, ``data[1][7]`` is the class label,
            ``data[1][8]`` the real image and ``data[1][9]`` the multi-label
            target.
        epoch: epoch index, used only for logging.
        val_loader: unused; kept for interface compatibility.
    """
    tri_losses = AverageMeter()
    img_losses = AverageMeter()
    text_losses = AverageMeter()
    cmG_losses = AverageMeter()

    image_model.train()
    recipe_model.train()
    # validate() puts fc_sia in eval mode; restore training mode here.
    fc_sia.train()

    for data in tqdm(train_loader):
        batch_n = data[0][0].size(0)

        img_emd_modal = image_model(data[0][0].cuda())
        recipe_emb_modal = recipe_model(data[0][1].cuda(), data[0][2].cuda(),
                                        data[0][3].cuda(), data[0][4].cuda())

        ################################################################
        # modal-level fusion
        ################################################################
        # Critic step: image embeddings are "real", recipe embeddings "fake".
        real_validity = cm_discriminator(img_emd_modal.detach())
        fake_validity = cm_discriminator(recipe_emb_modal.detach())
        gradient_penalty = compute_gradient_penalty(
            cm_discriminator, img_emd_modal.detach(), recipe_emb_modal.detach())
        loss_cmD = -torch.mean(real_validity) + torch.mean(fake_validity) + 10 * gradient_penalty
        optimizer_cmD.zero_grad()
        loss_cmD.backward()
        optimizer_cmD.step()

        # "Generator" side of the cross-modal game (not detached).
        g_fake_validity = cm_discriminator(recipe_emb_modal)
        loss_cmG = -torch.mean(g_fake_validity)

        ################################################################
        # cross-modal retrieval
        ################################################################
        img_id_fea = norm(fc_sia(img_emd_modal))
        rec_id_fea = norm(fc_sia(recipe_emb_modal))
        tri_loss = global_loss(triplet_loss, torch.cat((img_id_fea, rec_id_fea)), label)[0]

        ################################################################
        # translation consistency
        ################################################################
        label_class = data[1][7].cuda()
        real_imgs = [data[1][8].cuda()]
        one_hot_cap = data[1][9].cuda().long()

        ################################################################
        # img2text
        ################################################################
        recipe_out = multi_label_net(img_id_fea)
        loss_i2t = img2text_criterion(recipe_out[0], one_hot_cap)
        loss_t_class = class_criterion(recipe_out[1], label_class)
        loss_text = loss_i2t + loss_t_class

        ###############################################################
        # text2img
        ###############################################################
        noise.data.normal_(0, 1)
        fake_imgs, mu, logvar = netG(noise, rec_id_fea)

        lossD = train_Dnet(0, real_imgs, fake_imgs, mu, label_class)
        optimizers_imgD.zero_grad()
        lossD.backward()
        optimizers_imgD.step()

        lossG = train_Gnet(0, real_imgs, fake_imgs, mu, logvar, label_class)
        loss_img = lossG

        # Rescale the larger of the two losses down to the magnitude of the
        # smaller one (the scale factor is a constant, not part of the graph).
        if loss_text.item() < loss_img.item():
            loss_img = (loss_text.item() / loss_img.item()) * loss_img
        else:
            loss_text = (loss_img.item() / loss_text.item()) * loss_text
        loss_g = loss_img + loss_text

        ###############################################################
        # back-propagate
        ###############################################################
        loss = tri_loss + 0.005 * loss_cmG + 0.002 * loss_g

        tri_losses.update(tri_loss.item(), batch_n)
        img_losses.update(loss_img.item(), batch_n)
        text_losses.update(loss_text.item(), batch_n)
        cmG_losses.update(loss_cmG.item(), batch_n)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(epoch)
    print('Epoch: {0} '
          'tri loss {tri_loss.val:.4f} ({tri_loss.avg:.4f}), '
          'cm loss {loss_cmG.val:.4f} ({loss_cmG.avg:.4f}), '
          'img loss {img_losses.val:.4f} ({img_losses.avg:.4f}), '
          'text loss {loss_text.val:.4f} ({loss_text.avg:.4f})'
          .format(
              epoch, tri_loss=tri_losses, loss_cmG=cmG_losses,
              img_losses=img_losses, loss_text=text_losses))


def validate(val_loader):
    """Embed the whole validation set and report retrieval metrics.

    Args:
        val_loader: iterable of batches; ``data[0][0..4]`` feed the two
            embedding models.

    Returns:
        ``(recall_i2t, recall_t2i, medR_i2t, medR_t2i)`` where the recalls are
        dicts keyed by 1, 5 and 10.
    """
    # switch to evaluate mode; fc_sia contains BatchNorm1d, so it must be in
    # eval mode too, otherwise it normalises with per-batch statistics.
    image_model.eval()
    recipe_model.eval()
    fc_sia.eval()

    img_chunks = []
    rec_chunks = []
    with torch.no_grad():
        for data in tqdm(val_loader):
            img_emd_modal = image_model(data[0][0].cuda())
            recipe_emb_modal = recipe_model(data[0][1].cuda(), data[0][2].cuda(),
                                            data[0][3].cuda(), data[0][4].cuda())
            img_emd_modal = norm(fc_sia(img_emd_modal))
            recipe_emb_modal = norm(fc_sia(recipe_emb_modal))
            img_chunks.append(img_emd_modal.data.cpu().numpy())
            rec_chunks.append(recipe_emb_modal.data.cpu().numpy())

    # Concatenate once instead of re-copying the growing array every batch.
    data0 = np.concatenate(img_chunks, axis=0)
    data1 = np.concatenate(rec_chunks, axis=0)

    medR_i2t, recall_i2t = rank_i2t(opts, data0, data1)
    print('I2T Val medR {medR:.4f}\t'
          'Recall {recall}'.format(medR=medR_i2t, recall=recall_i2t))

    medR_t2i, recall_t2i = rank_t2i(opts, data0, data1)
    print('T2I Val medR {medR:.4f}\t'
          'Recall {recall}'.format(medR=medR_t2i, recall=recall_t2i))

    return recall_i2t, recall_t2i, medR_i2t, medR_t2i


def _rank_subsets(opts, queries, gallery, population, n_sample=1000, n_folds=10):
    """Median rank and recall@{1,5,10} averaged over random subsets.

    For each of ``n_folds`` subsets of ``n_sample`` row indices (drawn with
    ``random.sample`` after seeding with ``opts.seed``), every query row is
    compared by Euclidean distance against all gallery rows of the subset;
    the rank of its same-index gallery row is 1 + the number of gallery rows
    that are strictly closer.

    Raises:
        ValueError: if ``population`` is smaller than ``n_sample``.
    """
    random.seed(opts.seed)
    cutoffs = (1, 5, 10)
    glob_rank = []
    glob_recall = {k: 0.0 for k in cutoffs}

    for _ in range(n_folds):
        ids = random.sample(range(0, population), n_sample)
        query_sub = queries[ids, :]
        gallery_sub = gallery[ids, :]

        ranks = np.empty(n_sample, dtype=np.int64)
        for row in range(n_sample):
            dists = np.linalg.norm(gallery_sub - query_sub[row], axis=1)
            # Counting strictly smaller distances gives the first position of
            # the true match in a sorted list, ties included.
            ranks[row] = np.count_nonzero(dists < dists[row]) + 1

        for k in cutoffs:
            glob_recall[k] += np.count_nonzero(ranks <= k) / n_sample
        glob_rank.append(np.median(ranks))

    for k in cutoffs:
        glob_recall[k] = glob_recall[k] / n_folds

    return np.average(glob_rank), glob_recall


def rank_i2t(opts, img_embeds, rec_embeds):
    """Image-to-text retrieval: images are queries, recipes the gallery.

    Args:
        opts: object with a ``seed`` attribute used to seed ``random``.
        img_embeds: 2-D numpy array of image embeddings.
        rec_embeds: 2-D numpy array of recipe embeddings, row-aligned with
            ``img_embeds``.

    Returns:
        ``(average median rank, {1: recall@1, 5: recall@5, 10: recall@10})``
        over 10 random subsets of 1000 pairs.
    """
    return _rank_subsets(opts, img_embeds, rec_embeds, len(img_embeds))


def rank_t2i(opts, img_embeds, rec_embeds):
    """Text-to-image retrieval: recipes are queries, images the gallery.

    Arguments and return value are as for ``rank_i2t``; the same subsets are
    drawn because both functions seed ``random`` with ``opts.seed``.
    """
    return _rank_subsets(opts, rec_embeds, img_embeds, len(img_embeds))


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


if __name__ == '__main__':
    main()