# -*- coding: utf-8 -*-
"""Bayesian CycleGAN model: generators, encoders, discriminators and training steps.

@author : Haoran You
"""
import itertools
import os
import random
import sys
from collections import OrderedDict

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
# from torch.optim import lr_scheduler

import util.util as util
from util.image_pool import ImagePool
from . import networks

# Converts a PIL image to a tensor and applies Normalize with mean 0.5 / std 0.5
# on each of the three channels.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


def _to_scalar(value):
    """Return the single element of a one-element loss tensor as a Python number.

    ``.item()`` is used when the tensor provides it (PyTorch >= 0.4, where
    ``data[0]`` fails on 0-dim tensors); otherwise the legacy ``data[0]``
    access is used.
    """
    data = value.data
    if hasattr(data, 'item'):
        return data.item()
    return data[0]


class CycleGAN():
    """Bayesian CycleGAN with two generators, two encoders and two discriminators.

    Call ``initialize(opt)`` before any other method; it builds the networks
    and reads the sample directories used by ``forward`` and ``inference``.
    """

    def name(self):
        """Return the human-readable model name."""
        return 'Bayesian CycleGAN Model'

    def initialize(self, opt):
        """Build networks, losses and optimizers from the option object ``opt``.

        Networks are restored from ``opt.checkpoints_dir/opt.name`` when not
        training or when ``opt.continue_train`` is set. Discriminators, losses
        and optimizers are only created when ``opt.isTrain`` is true.
        """
        self.opt = opt
        self.isTrain = opt.isTrain
        if torch.cuda.is_available():
            print('cuda is available, we will use gpu!')
            self.Tensor = torch.cuda.FloatTensor
            torch.cuda.manual_seed_all(100)
        else:
            self.Tensor = torch.FloatTensor
            torch.manual_seed(100)
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

        # get ratio for network initialization
        ratio = 256 * 256 / opt.loadSize / (opt.loadSize / opt.ratio)

        # load network
        # generators take one extra input channel: the encoder feature map
        netG_input_nc = opt.input_nc + 1
        netG_output_nc = opt.output_nc + 1
        self.netG_A = networks.define_G(
            netG_input_nc, opt.output_nc, opt.ngf, opt.netG_A,
            opt.n_downsample_global, opt.n_blocks_global, opt.norm).type(self.Tensor)
        self.netG_B = networks.define_G(
            netG_output_nc, opt.input_nc, opt.ngf, opt.netG_B,
            opt.n_downsample_global, opt.n_blocks_global, opt.norm).type(self.Tensor)

        self.netE_A = networks.define_G(
            opt.input_nc, 1, 64, 'encoder', opt.n_downsample_global,
            norm=opt.norm, ratio=ratio).type(self.Tensor)
        self.netE_B = networks.define_G(
            opt.output_nc, 1, 64, 'encoder', opt.n_downsample_global,
            norm=opt.norm, ratio=ratio).type(self.Tensor)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(
                opt.output_nc, opt.ndf, opt.n_layers_D, opt.norm,
                use_sigmoid, opt.num_D_A).type(self.Tensor)
            self.netD_B = networks.define_D(
                opt.input_nc, opt.ndf, opt.n_layers_D, opt.norm,
                use_sigmoid, opt.num_D_B).type(self.Tensor)

        if not self.isTrain or opt.continue_train:
            self.load_network(self.netG_A, 'G_A', opt.which_epoch, self.save_dir)
            self.load_network(self.netG_B, 'G_B', opt.which_epoch, self.save_dir)
            self.load_network(self.netE_A, 'E_A', opt.which_epoch, self.save_dir)
            self.load_network(self.netE_B, 'E_B', opt.which_epoch, self.save_dir)
            if self.isTrain:
                self.load_network(self.netD_A, 'D_A', opt.which_epoch, self.save_dir)
                self.load_network(self.netD_B, 'D_B', opt.which_epoch, self.save_dir)

        # set loss functions and optimizers
        if self.isTrain:
            self.old_lr = opt.lr
            # define loss function
            self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionL1 = torch.nn.L1Loss()
            # initialize optimizers
            adam_betas = (opt.beta1, 0.999)
            self.optimizer_G = torch.optim.Adam(
                itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()),
                lr=opt.lr, betas=adam_betas)
            self.optimizer_E_A = torch.optim.Adam(self.netE_A.parameters(), lr=opt.lr, betas=adam_betas)
            self.optimizer_E_B = torch.optim.Adam(self.netE_B.parameters(), lr=opt.lr, betas=adam_betas)
            self.optimizer_D_A = torch.optim.Adam(self.netD_A.parameters(), lr=opt.lr, betas=adam_betas)
            self.optimizer_D_B = torch.optim.Adam(self.netD_B.parameters(), lr=opt.lr, betas=adam_betas)

        print('----------Network initialized!-----------')
        self.print_network(self.netG_A)
        self.print_network(self.netG_B)
        self.print_network(self.netE_A)
        self.print_network(self.netE_B)
        if self.isTrain:
            self.print_network(self.netD_A)
            self.print_network(self.netD_B)
        print('-----------------------------------------')

        # dataset path and name list
        # origin_path is kept as an attribute; this class no longer changes directory
        self.origin_path = os.getcwd()
        self.path_A = self.opt.dataroot + '/trainA'
        self.path_B = self.opt.dataroot + '/trainB'
        if not self.opt.isTrain:
            if self.opt.use_feat:
                self.path_A = self.opt.dataroot + '/feat'
                self.path_B = self.opt.dataroot + '/feat'
        self.list_A = os.listdir(self.path_A)
        self.list_B = os.listdir(self.path_B)

    def set_input(self, input):
        """Store the current batch, swapping A and B unless the direction is 'AtoB'.

        ``input`` is indexed with the keys 'A', 'B' and 'A_paths' / 'B_paths'.
        """
        AtoB = self.opt.which_direction == 'AtoB'
        self.input_A = input['A' if AtoB else 'B']
        self.input_B = input['B' if AtoB else 'A']
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def _load_sample(self, directory, filename, to_gray):
        """Load one image file as a batch-of-one input tensor.

        The image is read from ``directory/filename`` by explicit path (the
        working directory is left untouched), resized to ``opt.loadSize``
        width, normalised with ``transform`` and, when ``to_gray`` is true,
        reduced to one channel with 0.299/0.587/0.114 channel weights.
        """
        image = Image.open(os.path.join(directory, filename)).convert('RGB')
        image = self.img_resize(image, self.opt.loadSize)
        sample = transform(image)
        if to_gray:
            # RGB to gray
            sample = sample[0, ...] * 0.299 + sample[1, ...] * 0.587 + sample[2, ...] * 0.114
            sample = sample.unsqueeze(0)
        sample = Variable(sample).type(self.Tensor)
        return torch.unsqueeze(sample, 0)

    @staticmethod
    def _attach_feat(batch, feat_map):
        """Concatenate ``feat_map`` onto every element of ``batch`` along dim 1.

        Iterates over the actual batch size, so a final batch smaller than
        ``opt.batchSize`` is handled.
        """
        return torch.cat([
            torch.cat((batch[i:i + 1], feat_map), dim=1)
            for i in range(batch.size(0))
        ])

    def _kl_loss(self, mu, logvar):
        """Return ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar)) * opt.lambda_kl``."""
        kl_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
        return torch.sum(kl_element).mul_(-0.5) * self.opt.lambda_kl

    def forward(self):
        """Prepare generator inputs for one training step.

        Draws ``opt.mc_x`` files from ``path_A`` and ``opt.mc_y`` files from
        ``path_B``, encodes each with ``netE_A`` / ``netE_B`` and concatenates
        the resulting feature map onto the real batch of the other domain.
        Results are stored in ``real_B_zx``, ``real_A_zy``, ``mu_*`` and
        ``logvar_*``; ``feat_map_zx`` / ``feat_map_zy`` hold the feature map
        of the last sample drawn.
        """
        self.real_A = Variable(self.input_A).type(self.Tensor)
        self.real_B = Variable(self.input_B).type(self.Tensor)

        # feature map
        mc_sample_x = random.sample(self.list_A, self.opt.mc_x)
        mc_sample_y = random.sample(self.list_B, self.opt.mc_y)
        self.real_B_zx = []
        self.real_A_zy = []
        self.mu_x = []
        self.mu_y = []
        self.logvar_x = []
        self.logvar_y = []

        for sample_x in mc_sample_x:
            z_x = self._load_sample(self.path_A, sample_x, self.opt.input_nc == 1)
            mu_x, logvar_x, feat_map = self.netE_A.forward(z_x)
            self.mu_x.append(mu_x)
            self.logvar_x.append(logvar_x)
            self.feat_map_zx = feat_map
            self.real_B_zx.append(self._attach_feat(self.real_B, feat_map))
        self.mu_x = torch.cat(self.mu_x)
        self.logvar_x = torch.cat(self.logvar_x)

        for sample_y in mc_sample_y:
            z_y = self._load_sample(self.path_B, sample_y, self.opt.output_nc == 1)
            mu_y, logvar_y, feat_map = self.netE_B.forward(z_y)
            self.mu_y.append(mu_y)
            self.logvar_y.append(logvar_y)
            self.feat_map_zy = feat_map
            self.real_A_zy.append(self._attach_feat(self.real_A, feat_map))
        self.mu_y = torch.cat(self.mu_y)
        self.logvar_y = torch.cat(self.logvar_y)

    def inference(self):
        """Translate the current batch in both directions with one random sample per domain.

        With ``opt.use_feat`` the loaded (gray) image is used directly as the
        feature map; otherwise it is produced by the matching encoder. Sets
        ``fake_A``, ``fake_B``, ``rec_A`` and ``rec_B``.
        """
        real_A = Variable(self.input_A).type(self.Tensor)
        real_B = Variable(self.input_B).type(self.Tensor)

        # feature map
        # the gray conversion is applied once even if both conditions hold;
        # applying it twice would index channel 1 of a one-channel tensor
        mc_sample_x = random.sample(self.list_A, 1)
        z_x = self._load_sample(
            self.path_A, mc_sample_x[0],
            self.opt.input_nc == 1 or self.opt.use_feat)
        if not self.opt.use_feat:
            mu_x, logvar_x, feat_map_zx = self.netE_A.forward(z_x)
        else:
            feat_map_zx = z_x

        mc_sample_y = random.sample(self.list_B, 1)
        z_y = self._load_sample(
            self.path_B, mc_sample_y[0],
            self.opt.output_nc == 1 or self.opt.use_feat)
        if not self.opt.use_feat:
            mu_y, logvar_y, feat_map_zy = self.netE_B.forward(z_y)
        else:
            feat_map_zy = z_y

        # combine input image with random feature map
        real_B_zx = self._attach_feat(real_B, feat_map_zx)
        real_A_zy = self._attach_feat(real_A, feat_map_zy)

        # inference
        # NOTE(review): the second cat below pairs the whole fake batch with a
        # single feature map - assumes a batch of one at inference; confirm.
        fake_B = self.netG_A(real_A_zy)
        fake_B_next = torch.cat((fake_B, feat_map_zx), dim=1)
        self.rec_A = self.netG_B(fake_B_next).data
        self.fake_B = fake_B.data

        fake_A = self.netG_B(real_B_zx)
        fake_A_next = torch.cat((fake_A, feat_map_zy), dim=1)
        self.rec_B = self.netG_A(fake_A_next).data
        self.fake_A = fake_A.data

    def get_image_paths(self):
        """Return the image paths stored by ``set_input``."""
        return self.image_paths

    def img_resize(self, img, target_width):
        """Resize ``img`` to ``target_width`` keeping the aspect ratio (bicubic).

        The image is returned unchanged when its width already matches.
        """
        ow, oh = img.size
        if ow == target_width:
            return img
        w = target_width
        h = int(target_width * oh / ow)
        return img.resize((w, h), Image.BICUBIC)

    def get_z_random(self, batchSize, nz, random_type='gauss'):
        """Return a ``(batchSize, nz)`` Variable of noise.

        ``random_type`` 'uni' draws from U(-1, 1) and 'gauss' from N(0, 1);
        any other value leaves the freshly allocated tensor unfilled.
        """
        z = self.Tensor(batchSize, nz)
        if random_type == 'uni':
            z.copy_(torch.rand(batchSize, nz) * 2.0 - 1.0)
        elif random_type == 'gauss':
            z.copy_(torch.randn(batchSize, nz))
        z = Variable(z)
        return z

    def backward_G(self):
        """Compute the unpaired generator/encoder loss and back-propagate it.

        The loss sums GAN, cycle-consistency, cycle-GAN (weighted by
        ``opt.gamma``), weight-prior and KL terms. Also stores the generated
        images and scalar losses for logging.
        """
        # GAN loss D_A(G_A(A))
        fake_B = torch.cat([self.netG_A(real_A) for real_A in self.real_A_zy])
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B))
        fake_A = torch.cat([self.netG_B(real_B) for real_B in self.real_B_zx])
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)

        # cycle loss
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B

        # Forward cycle loss
        fake_B_next = self._attach_feat(fake_B, self.feat_map_zx)
        rec_A = self.netG_B(fake_B_next)
        batch_A = self.real_A.size(0)
        loss_cycle_A = 0
        # rec_A stacks one reconstruction of the real batch per sample in real_A_zy
        for i in range(0, self.opt.mc_y):
            loss_cycle_A += self.criterionCycle(
                rec_A[i * batch_A:(i + 1) * batch_A], self.real_A) * lambda_A
        pred_cycle_G_A = self.netD_B(rec_A)
        loss_cycle_G_A = self.criterionGAN(pred_cycle_G_A, True)

        # Backward cycle loss
        fake_A_next = self._attach_feat(fake_A, self.feat_map_zy)
        rec_B = self.netG_A(fake_A_next)
        batch_B = self.real_B.size(0)
        loss_cycle_B = 0
        for i in range(0, self.opt.mc_x):
            loss_cycle_B += self.criterionCycle(
                rec_B[i * batch_B:(i + 1) * batch_B], self.real_B) * lambda_B
        pred_cycle_G_B = self.netD_A(rec_B)
        loss_cycle_G_B = self.criterionGAN(pred_cycle_G_B, True)

        # prior loss
        prior_loss_G_A = self.get_prior(self.netG_A.parameters(), self.opt.batchSize)
        prior_loss_G_B = self.get_prior(self.netG_B.parameters(), self.opt.batchSize)

        # KL loss
        loss_kl_EA = self._kl_loss(self.mu_x, self.logvar_x)
        loss_kl_EB = self._kl_loss(self.mu_y, self.logvar_y)

        # total loss
        loss_G = (loss_G_A + loss_G_B
                  + (prior_loss_G_A + prior_loss_G_B)
                  + (loss_cycle_G_A + loss_cycle_G_B) * self.opt.gamma
                  + (loss_cycle_A + loss_cycle_B)
                  + (loss_kl_EA + loss_kl_EB))
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        self.rec_A = rec_A.data
        self.rec_B = rec_B.data

        self.loss_G_A = (_to_scalar(loss_G_A)
                         + _to_scalar(loss_cycle_G_A) * self.opt.gamma
                         + _to_scalar(prior_loss_G_A))
        # reports G_B's own prior (the G_A prior was used here by mistake)
        self.loss_G_B = (_to_scalar(loss_G_B)
                         + _to_scalar(loss_cycle_G_B) * self.opt.gamma
                         + _to_scalar(prior_loss_G_B))
        self.loss_cycle_A = _to_scalar(loss_cycle_A)
        self.loss_cycle_B = _to_scalar(loss_cycle_B)
        self.loss_kl_EA = _to_scalar(loss_kl_EA)
        self.loss_kl_EB = _to_scalar(loss_kl_EB)

    def backward_D_A(self):
        """Compute and back-propagate the unpaired loss of discriminator D_A."""
        fake_B = Variable(self.fake_B).type(self.Tensor)
        rec_B = Variable(self.rec_B).type(self.Tensor)
        # how well it classifies fake images
        pred_fake = self.netD_A(fake_B.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        pred_cycle_fake = self.netD_A(rec_B.detach())
        loss_D_cycle_fake = self.criterionGAN(pred_cycle_fake, False)

        # how well it classifies real images
        pred_real = self.netD_A(self.real_B)
        loss_D_real = self.criterionGAN(pred_real, True) * self.opt.mc_y

        # prior loss
        prior_loss_D_A = self.get_prior(self.netD_A.parameters(), self.opt.batchSize)

        # total loss
        loss_D_A = ((loss_D_real + loss_D_fake) * 0.5
                    + (loss_D_real + loss_D_cycle_fake) * 0.5 * self.opt.gamma
                    + prior_loss_D_A)
        loss_D_A.backward()
        self.loss_D_A = _to_scalar(loss_D_A)

    def backward_D_B(self):
        """Compute and back-propagate the unpaired loss of discriminator D_B."""
        fake_A = Variable(self.fake_A).type(self.Tensor)
        rec_A = Variable(self.rec_A).type(self.Tensor)
        # how well it classifies fake images
        pred_fake = self.netD_B(fake_A.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        pred_cycle_fake = self.netD_B(rec_A.detach())
        loss_D_cycle_fake = self.criterionGAN(pred_cycle_fake, False)

        # how well it classifies real images
        pred_real = self.netD_B(self.real_A)
        loss_D_real = self.criterionGAN(pred_real, True) * self.opt.mc_x

        # prior loss
        prior_loss_D_B = self.get_prior(self.netD_B.parameters(), self.opt.batchSize)

        # total loss
        loss_D_B = ((loss_D_real + loss_D_fake) * 0.5
                    + (loss_D_real + loss_D_cycle_fake) * 0.5 * self.opt.gamma
                    + prior_loss_D_B)
        loss_D_B.backward()
        self.loss_D_B = _to_scalar(loss_D_B)

    def backward_G_pair(self):
        """Compute the paired generator/encoder loss and back-propagate it.

        Replaces the cycle terms of ``backward_G`` with an L1 loss between
        each generated batch and the real batch of the target domain. Stores
        the generated images and the two L1 losses read by ``get_stye_loss``.
        """
        # GAN loss D_A(G_A(A)) and L1 loss
        fake_B = []
        loss_G_A_L1 = 0
        for real_A in self.real_A_zy:
            _fake = self.netG_A(real_A)
            loss_G_A_L1 += self.criterionL1(_fake, self.real_B) * self.opt.lambda_A
            fake_B.append(_fake)
        fake_B = torch.cat(fake_B)
        pred_fake = self.netD_A(fake_B)
        loss_G_A = self.criterionGAN(pred_fake, True)

        # GAN loss D_B(G_B(B)) and L1 loss
        fake_A = []
        loss_G_B_L1 = 0
        for real_B in self.real_B_zx:
            _fake = self.netG_B(real_B)
            loss_G_B_L1 += self.criterionL1(_fake, self.real_A) * self.opt.lambda_B
            fake_A.append(_fake)
        fake_A = torch.cat(fake_A)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(pred_fake, True)

        # prior loss
        prior_loss_G_A = self.get_prior(self.netG_A.parameters(), self.opt.batchSize)
        prior_loss_G_B = self.get_prior(self.netG_B.parameters(), self.opt.batchSize)

        # KL loss
        loss_kl_EA = self._kl_loss(self.mu_x, self.logvar_x)
        loss_kl_EB = self._kl_loss(self.mu_y, self.logvar_y)

        # total loss
        loss_G = (loss_G_A + loss_G_B
                  + (prior_loss_G_A + prior_loss_G_B)
                  + (loss_kl_EA + loss_kl_EB)
                  + (loss_G_A_L1 + loss_G_B_L1))
        loss_G.backward()

        self.fake_B = fake_B.data
        self.fake_A = fake_A.data
        # stored so that get_stye_loss() has values to report
        self.loss_G_A_L1 = _to_scalar(loss_G_A_L1)
        self.loss_G_B_L1 = _to_scalar(loss_G_B_L1)

    def backward_D_A_pair(self):
        """Compute and back-propagate the paired loss of discriminator D_A."""
        fake_B = Variable(self.fake_B).type(self.Tensor)
        # how well it classifies fake images
        pred_fake = self.netD_A(fake_B.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)

        # how well it classifies real images
        pred_real = self.netD_A(self.real_B)
        loss_D_real = self.criterionGAN(pred_real, True) * self.opt.mc_y

        # prior loss
        prior_loss_D_A = self.get_prior(self.netD_A.parameters(), self.opt.batchSize)

        # total loss
        loss_D_A = (loss_D_real + loss_D_fake) * 0.5 + prior_loss_D_A
        loss_D_A.backward()

    def backward_D_B_pair(self):
        """Compute and back-propagate the paired loss of discriminator D_B."""
        fake_A = Variable(self.fake_A).type(self.Tensor)
        # how well it classifies fake images
        pred_fake = self.netD_B(fake_A.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)

        # how well it classifies real images
        pred_real = self.netD_B(self.real_A)
        loss_D_real = self.criterionGAN(pred_real, True) * self.opt.mc_x

        # prior loss
        prior_loss_D_B = self.get_prior(self.netD_B.parameters(), self.opt.batchSize)

        # total loss
        loss_D_B = (loss_D_real + loss_D_fake) * 0.5 + prior_loss_D_B
        loss_D_B.backward()

    def optimize(self, pair=False):
        """Run one training step: generators and encoders, then D_A, then D_B.

        ``pair`` selects the paired (L1) backward passes over the unpaired
        (cycle) ones.
        """
        # forward
        self.forward()

        # G_A and G_B
        # E_A and E_B
        self.optimizer_G.zero_grad()
        self.optimizer_E_A.zero_grad()
        self.optimizer_E_B.zero_grad()
        if pair:
            self.backward_G_pair()
        else:
            self.backward_G()
        self.optimizer_G.step()
        self.optimizer_E_A.step()
        self.optimizer_E_B.step()

        # D_A
        self.optimizer_D_A.zero_grad()
        if pair:
            self.backward_D_A_pair()
        else:
            self.backward_D_A()
        self.optimizer_D_A.step()

        # D_B
        self.optimizer_D_B.zero_grad()
        if pair:
            self.backward_D_B_pair()
        else:
            self.backward_D_B()
        self.optimizer_D_B.step()

    def get_current_loss(self):
        """Return an OrderedDict of the scalar losses of the last unpaired step.

        NOTE(review): the paired backward passes do not store these
        attributes, so this is only usable after ``optimize(pair=False)``.
        """
        loss = OrderedDict([
            ('D_A', self.loss_D_A),
            ('D_B', self.loss_D_B),
            ('G_A', self.loss_G_A),
            ('G_B', self.loss_G_B)
        ])
        if self.opt.gamma == 0:
            loss['cyc_A'] = self.loss_cycle_A
            loss['cyc_B'] = self.loss_cycle_B
        elif self.opt.gamma > 0:
            # NOTE(review): these keys carry the same cycle losses as above
            loss['cyc_G_A'] = self.loss_cycle_A
            loss['cyc_G_B'] = self.loss_cycle_B
        if self.opt.lambda_kl > 0:
            loss['kl_EA'] = self.loss_kl_EA
            loss['kl_EB'] = self.loss_kl_EB
        return loss

    def get_stye_loss(self):
        """Return the L1 losses stored by the last ``backward_G_pair`` call."""
        loss = OrderedDict([
            ('L1_A', self.loss_G_A_L1),
            ('L1_B', self.loss_G_B_L1)
        ])
        return loss

    def get_current_visuals(self):
        """Return an OrderedDict of real, fake and reconstructed images.

        Each stored tensor is converted with ``util.tensor2im``.
        """
        real_A = util.tensor2im(self.input_A)
        fake_B = util.tensor2im(self.fake_B)
        rec_A = util.tensor2im(self.rec_A)
        real_B = util.tensor2im(self.input_B)
        fake_A = util.tensor2im(self.fake_A)
        rec_B = util.tensor2im(self.rec_B)
        visuals = OrderedDict([
            ('real_A', real_A), ('fake_B', fake_B), ('rec_A', rec_A),
            ('real_B', real_B), ('fake_A', fake_A), ('rec_B', rec_B)
        ])
        return visuals

    def get_prior(self, parameters, dataset_size):
        """Return the sum over ``parameters`` of mean(param^2), divided by ``dataset_size``.

        The accumulator uses ``self.Tensor`` so the method works on both the
        CPU and GPU configurations chosen in ``initialize``.
        """
        prior_loss = Variable(torch.zeros((1))).type(self.Tensor)
        for param in parameters:
            prior_loss += torch.mean(param*param)
        return prior_loss / dataset_size

    # def get_noise(self, parameters, alpha, dataset_size):
    #     noise_loss = Variable(torch.zeros((1))).cuda()
    #     noise_std = np.sqrt(2 * alpha)
    #     for param in parameters:
    #         noise = Variable(torch.normal(std=torch.ones(param.size()))).cuda()
    #         noise_loss += torch.sum(param*noise*noise_std)
    #     return noise_loss / dataset_size

    def save_model(self, label):
        """Save all six networks under the epoch label ``label``."""
        self.save_network(self.netG_A, 'G_A', label)
        self.save_network(self.netG_B, 'G_B', label)
        self.save_network(self.netE_A, 'E_A', label)
        self.save_network(self.netE_B, 'E_B', label)
        self.save_network(self.netD_A, 'D_A', label)
        self.save_network(self.netD_B, 'D_B', label)

    def load_network(self, network, network_label, epoch_label, save_dir=''):
        """Load ``<epoch_label>_net_<network_label>.pth`` into ``network``.

        The file is read from ``save_dir``, or from ``self.save_dir`` when
        ``save_dir`` is empty. If the checkpoint does not load as-is, first
        only the keys known to the network are loaded; if that fails too,
        size-matching tensors are copied over and the top-level names of the
        layers left uninitialized are printed. A missing or unreadable
        checkpoint file raises from ``torch.load``.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(save_dir or self.save_dir, save_filename)
        pretrained_dict = torch.load(save_path)
        try:
            network.load_state_dict(pretrained_dict)
            return
        except Exception:
            # deliberate best-effort: fall through to partial loading below
            pass

        model_dict = network.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        try:
            network.load_state_dict(pretrained_dict)
            print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
        except Exception:
            print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
            not_initialized = set()
            for k, v in pretrained_dict.items():
                if v.size() == model_dict[k].size():
                    model_dict[k] = v
            for k, v in model_dict.items():
                if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
                    not_initialized.add(k.split('.')[0])
            print(sorted(not_initialized))
            network.load_state_dict(model_dict)

    def save_network(self, network, network_label, epoch_label):
        """Save ``network`` weights to ``<epoch_label>_net_<network_label>.pth`` in ``self.save_dir``.

        The network is moved to the CPU for saving and moved back to the GPU
        afterwards when CUDA is available.
        """
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if torch.cuda.is_available():
            network.cuda()

    def print_network(self, net):
        """Print ``net`` and its total number of parameters."""
        num_params = sum(param.numel() for param in net.parameters())
        print(net)
        print('Total number of parameters: %d' % num_params)

    # update learning rate (called once every iter)
    def update_learning_rate(self, epoch, epoch_iter, dataset_size):
        """Decay the learning rate once ``epoch`` exceeds ``opt.niter``.

        The rate becomes ``opt.lr * exp(-min(1, epoch_iter / dataset_size))``
        and is applied to the D_A, D_B and G optimizers.
        """
        # lrd = self.opt.lr / self.opt.niter_decay
        if epoch <= self.opt.niter:
            return
        lr = self.opt.lr * np.exp(-1.0 * min(1.0, epoch_iter/float(dataset_size)))
        # NOTE(review): the encoder optimizers (E_A, E_B) are not updated here,
        # matching the original behaviour - confirm that this is intended.
        for optimizer in (self.optimizer_D_A, self.optimizer_D_B, self.optimizer_G):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        print('update learning rate: %f -> %f' % (self.old_lr, lr))
        self.old_lr = lr