'''
Basic functions.
'''
import os
from collections import OrderedDict

import torch
import pyro


def init_weights(m):
    class_name = m.__class__.__name__
    try:
        if class_name.find('Conv') != -1:
            m.weight.data.normal_(0.0, 0.02)
            if m.bias is not None:
                m.bias.data.fill_(0)
        elif class_name.find('Linear') != -1:
            m.weight.data.normal_(0.0, 0.02)
            m.bias.data.fill_(0)
        elif class_name.find('BatchNorm2d') != -1:
            m.weight.data.normal_(1.0, 0.02)
            m.bias.data.fill_(0)
    except AttributeError:
        # A layer may lack a weight or bias attribute (e.g. bias=False).
        print('Exception in init_weights:', class_name)


class BaseModel:
    '''
    Base model that implements basic functions such as saving and loading
    checkpoints, saving results, updating hyperparameters, etc.

    Subclasses are expected to populate self.nets and self.optimizers, and to
    set self.lr_init and self.lr_decay before update_hyperparameters is called.
    '''
    def __init__(self):
        self.nets, self.optimizers, self.schedulers = {}, {}, []
        self.video_dict = {}  # For visualization

    def initialize_weights(self):
        for _, net in self.nets.items():
            net.apply(init_weights)

    def setup(self, is_train):
        for _, net in self.nets.items():
            if is_train:
                net.train()
            else:
                net.eval()

    def load(self, ckpt_path, epoch, load_optimizer=False):
        '''
        Load checkpoint.
        '''
        for name, net in self.nets.items():
            path = os.path.join(ckpt_path, 'net_{}_{}.pth'.format(name, epoch))
            if not os.path.exists(path):
                print('{} does not exist, ignore.'.format(path))
                continue
            ckpt = torch.load(path)
            # Unwrap DataParallel so state dict keys match the bare module.
            if isinstance(net, torch.nn.DataParallel):
                module = net.module
            else:
                module = net
            try:
                module.load_state_dict(ckpt)
            except RuntimeError:
                print('net_{} and checkpoint have different parameter names'.format(name))
                # Fall back to remapping keys by position. This assumes both
                # state dicts hold the same parameters in the same order; the
                # assert checks that at least the leaf names agree.
                new_ckpt = OrderedDict()
                for ckpt_key, module_key in zip(ckpt.keys(), module.state_dict().keys()):
                    assert ckpt_key.split('.')[-1] == module_key.split('.')[-1]
                    new_ckpt[module_key] = ckpt[ckpt_key]
                module.load_state_dict(new_ckpt)

        if load_optimizer:
            for name, optimizer in self.optimizers.items():
                path = os.path.join(ckpt_path, 'optimizer_{}_{}.pth'.format(name, epoch))
                if not os.path.exists(path):
                    print('{} does not exist, ignore.'.format(path))
                    continue
                ckpt = torch.load(path)
                optimizer.load_state_dict(ckpt)

    def save(self, ckpt_path, epoch):
        '''
        Save checkpoint.
        '''
        for name, net in self.nets.items():
            # Save the unwrapped module so checkpoints stay loadable without
            # DataParallel.
            if isinstance(net, torch.nn.DataParallel):
                module = net.module
            else:
                module = net
            path = os.path.join(ckpt_path, 'net_{}_{}.pth'.format(name, epoch))
            torch.save(module.state_dict(), path)
        for name, optimizer in self.optimizers.items():
            path = os.path.join(ckpt_path, 'optimizer_{}_{}.pth'.format(name, epoch))
            torch.save(optimizer.state_dict(), path)

    def pyro_sample(self, name, fn, mu, sigma, sample=True):
        '''
        Sample with pyro.sample. fn should be a distribution constructor such
        as dist.Normal. If sample is False, return the mean instead.
        '''
        if sample:
            return pyro.sample(name, fn(mu, sigma))
        else:
            return mu.contiguous()

    def save_visuals(self, gt, output, components, latent):
        '''
        Save data for visualization. Take the first result in the batch.
        '''
        # 'latent' is accepted for interface compatibility but not visualized.
        videos = [gt.data[0].cpu()]
        for i in range(components.size(2)):
            images = components.data[0, :, i, ...].cpu()
            videos.append(images)
        videos.append(output.data[0].cpu())
        # Stack ground truth, components, and output vertically per frame.
        videos = torch.cat(videos, dim=2).clamp(0, 1)
        videos = videos * 2 - 1  # map to [-1, 1]
        self.video_dict.update({'results': videos})

    def get_visuals(self):
        return self.video_dict

    def update_hyperparameters(self, epoch, n_epochs):
        '''
        Update learning rate. Multiply the learning rate by 0.1 halfway
        through training.
        '''
        lr = self.lr_init
        if self.lr_decay:
            if epoch >= n_epochs // 2:
                lr = self.lr_init * 0.1
        # Apply the new rate to every registered optimizer (self.optimizers is
        # the dict defined in __init__; there is no single self.optimizer).
        for optimizer in self.optimizers.values():
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
        return {'lr': lr}
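

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: it illustrates
    # how a subclass is expected to populate self.nets / self.optimizers and
    # set the lr_init / lr_decay attributes that update_hyperparameters reads.
    # 'ToyModel' and its single conv net are hypothetical, chosen only to
    # exercise the BaseModel API above.
    import tempfile

    class ToyModel(BaseModel):
        def __init__(self):
            super().__init__()
            self.lr_init = 1e-3
            self.lr_decay = True
            self.nets['conv'] = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
            self.optimizers['conv'] = torch.optim.Adam(
                self.nets['conv'].parameters(), lr=self.lr_init)

    model = ToyModel()
    model.initialize_weights()       # applies init_weights to every net
    model.setup(is_train=True)       # puts every net in train mode

    # Round-trip a checkpoint through a temporary directory.
    with tempfile.TemporaryDirectory() as ckpt_dir:
        model.save(ckpt_dir, epoch=0)
        model.load(ckpt_dir, epoch=0, load_optimizer=True)

    # Halfway through training the learning rate drops to 0.1 * lr_init.
    print(model.update_hyperparameters(epoch=5, n_epochs=10))  # {'lr': 0.0001}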