#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Train the baseline model.

Note: this script targets the pre-0.4 PyTorch API
(Variable, volatile, loss.data[0], clip_grad_norm).
"""
from __future__ import print_function, absolute_import
import argparse
import os.path as osp
import sys
import time

import numpy as np
import torch
from torch import nn
from torch.autograd import Variable
from torch.backends import cudnn
from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler
from torch.nn.parameter import Parameter
from torchvision import transforms as T

from config import *
import models
from data_pre import Lighting, Preprocessor
from utils import Logger, AverageMeter
from utils import load_checkpoint, save_checkpoint
from evaluators import accuracy


def get_params(pretrained_model):
    """Dump every parameter of a pre-trained checkpoint to a .npy file."""
    pretrained_checkpoint = load_checkpoint(pretrained_model)
    for name, param in pretrained_checkpoint.items():
    # for name, param in pretrained_checkpoint['state_dict'].items():
        print('pretrained_model params name and size: ', name, param.size())
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        try:
            np.save(name + '.npy', param.cpu().numpy())
            print('############# saved params name: ', name)
        except Exception:
            raise RuntimeError('Failed to save the parameter named {} '
                               'whose dimensions in the checkpoint are {}.'
                               .format(name, param.size()))


def get_data(split_id, data_dir, img_size, scale_size, batch_size, workers,
             train_list, val_list):
    root = data_dir
    normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])  # RGB ImageNet statistics

    # with data augmentation
    train_transformer = T.Compose([
        T.RandomResizedCrop(img_size),
        T.RandomHorizontalFlip(),
        T.ToTensor(),  # [0, 255] to [0.0, 1.0]
        normalizer,    # normalize each channel of the input
    ])

    test_transformer = T.Compose([
        T.Resize(scale_size),
        T.CenterCrop(img_size),
        T.ToTensor(),
        normalizer,
    ])

    train_loader = DataLoader(
        Preprocessor(train_list, root=root, transform=train_transformer),
        batch_size=batch_size, num_workers=workers,
        sampler=RandomSampler(train_list),
        pin_memory=True, drop_last=False)

    val_loader = DataLoader(
        Preprocessor(val_list, root=root, transform=test_transformer),
        batch_size=batch_size, num_workers=workers,
        shuffle=False, pin_memory=True)

    return train_loader, val_loader
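# A minimal usage sketch of get_data (the file names and labels below are
# illustrative assumptions, not shipped with this repo). Each list entry is an
# (image_name, int_label) pair, matching the "name label" lines parsed by
# readlist() inside main():
#
#   train_list = [('n01440764/img_0001.JPEG', 0), ('n01443537/img_0002.JPEG', 1)]
#   val_list = [('n01440764/img_0003.JPEG', 0)]
#   train_loader, val_loader = get_data(0, 'data/imagenet', 224, 256,
#                                       256, 4, train_list, val_list)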
def main(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    cudnn.benchmark = True

    data_dir = osp.join(args.data_dir, args.dataset)

    # Redirect print to both console and log file
    if not args.evaluate:
        sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
    else:
        sys.stdout = Logger(osp.join(args.logs_dir, 'evaluate-log.txt'))

    print('\n################## setting ###################')
    print(args)
    print('################## setting ###################\n')

    # Create data loaders
    def readlist(fpath):
        """Parse a list file whose lines are '<image_name> <int_label>'."""
        lines = []
        with open(fpath, 'r') as f:
            for line in f.readlines():
                name, label = line.split()
                lines.append((name, int(label)))
        return lines

    # Load data lists
    train_path = osp.join(data_dir, 'train.txt')
    if osp.exists(train_path):
        train_list = readlist(train_path)
    else:
        raise RuntimeError("The training list -- {} doesn't exist".format(train_path))
    val_path = osp.join(data_dir, 'val.txt')
    if osp.exists(val_path):
        val_list = readlist(val_path)
    else:
        raise RuntimeError("The val list -- {} doesn't exist".format(val_path))

    if args.scale_size is None:
        args.scale_size = 256
    if args.img_size is None:
        args.img_size = 224

    train_loader, val_loader = \
        get_data(args.split, data_dir, args.img_size, args.scale_size,
                 args.batch_size, args.workers, train_list, val_list)

    # Create model (num_classes=1000 for ImageNet)
    model = models.create(args.arch, False, num_classes=1000)

    if args.adam:
        print('The optimizer is Adam !!!')
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                                     weight_decay=args.weight_decay)
    else:
        print('The optimizer is SGD !!!')
        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # Load model from checkpoint
    start_epoch = best_top1 = 0
    if args.pretrained:
        print('=> Start load params from pre-trained model...')
        checkpoint = load_checkpoint(args.pretrained)
        if 'alexnet' in args.arch or 'resnet' in args.arch:
            model.load_state_dict(checkpoint)
            # model.load_state_dict(checkpoint['state_dict'])
            # torch.save(model.state_dict(),
            #            osp.join('./pre-models', 'resnet18-relu6-703.pth'))
        else:
            raise RuntimeError('The arch is ERROR!!!')
        # Dump the pre-trained parameters to .npy files
        get_params(args.pretrained)

    if args.resume:
        checkpoint = load_checkpoint(args.resume)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        start_epoch = args.resume_epoch
        print("=> Finetune Start epoch {}".format(start_epoch))

    model = nn.DataParallel(model).cuda()

    # Criterion
    criterion = nn.CrossEntropyLoss().cuda()

    evaluator = Evaluator(model, criterion)
    if args.evaluate:
        print('Test model: \n')
        evaluator.evaluate(val_loader)
        return

    # Trainer
    trainer = Trainer(model, criterion)

    # Schedule learning rate: keep args.lr for the first step_size epochs,
    # then decay by 10x every decay_step epochs (see the worked example
    # after this function).
    def adjust_lr(epoch):
        step_size = args.step_size
        decay_step = args.decay_step
        lr = args.lr if epoch < step_size else \
            args.lr * (0.1 ** ((epoch - step_size) // decay_step + 1))
        for g in optimizer.param_groups:
            g['lr'] = lr * g.get('lr_mult', 1)

    # Start training
    trainer.show_info(with_arch=True, with_grad=False)
    for epoch in range(start_epoch, args.epochs):
        adjust_lr(epoch)
        trainer.train(epoch, train_loader, optimizer,
                      print_freq=args.print_freq, print_info=args.print_info)

        if epoch < args.start_save:
            continue

        top1 = evaluator.evaluate(val_loader)
        is_best = top1 > best_top1
        best_top1 = max(top1, best_top1)
        save_checkpoint({
            'state_dict': model.module.state_dict(),
            'optimizer': optimizer.state_dict()},
            is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))

        print('\n * Finished epoch {:3d}  top1: {:5.2%}  model_best: {:5.2%}\n'
              .format(epoch, top1, best_top1))

        # Keep a snapshot every 5 epochs
        if (epoch + 1) % 5 == 0:
            model_name = 'epoch_' + str(epoch) + '.pth.tar'
            torch.save({'state_dict': model.module.state_dict(),
                        'optimizer': optimizer.state_dict()},
                       osp.join(args.logs_dir, model_name))
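# Worked example of the step-decay schedule implemented by adjust_lr above,
# assuming the argparse defaults (lr=0.001, step_size=25, decay_step=25):
#
#   epochs  0-24: lr = 1e-3                   (epoch < step_size)
#   epochs 25-49: lr = 1e-3 * 0.1 ** 1 = 1e-4
#   epochs 50-74: lr = 1e-3 * 0.1 ** 2 = 1e-5
#   ...
#
# Per-group multipliers are honored via g.get('lr_mult', 1), so any parameter
# group created with an 'lr_mult' key is scaled relative to the base rate.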
class Trainer(object):
    def __init__(self, model, criterion):
        super(Trainer, self).__init__()
        self.model = model
        self.criterion = criterion

    def train(self, epoch, data_loader, optimizer, print_freq=1, print_info=10):
        self.model.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        end = time.time()
        for i, inputs in enumerate(data_loader):
            data_time.update(time.time() - end)

            inputs_var, targets_var = self._parse_data(inputs)
            loss, prec1, prec5 = self._forward(inputs_var, targets_var)

            losses.update(loss.data[0], targets_var.size(0))
            top1.update(prec1, targets_var.size(0))
            top5.update(prec5, targets_var.size(0))

            optimizer.zero_grad()
            loss.backward()
            # Clip the global gradient norm to stabilize training
            torch.nn.utils.clip_grad_norm(self.model.parameters(), 5.0)
            optimizer.step()

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'Loss {:.3f} ({:.3f})\t'
                      'Prec@1 {:.2%} ({:.2%})\t'
                      'Prec@5 {:.2%} ({:.2%})\t'
                      .format(epoch, i + 1, len(data_loader),
                              batch_time.val, batch_time.avg,
                              data_time.val, data_time.avg,
                              losses.val, losses.avg,
                              top1.val, top1.avg,
                              top5.val, top5.avg))

        if (epoch + 1) % print_info == 0:
            self.show_info()

    def show_info(self, with_arch=False, with_grad=True):
        if with_arch:
            print('\n\n################# model modules ###################')
            for name, m in self.model.named_modules():
                print('{}: {}'.format(name, m))
            print('################# model modules ###################\n\n')
        if with_grad:
            print('################# model params diff ###################')
            for name, param in self.model.named_parameters():
                mean_value = torch.abs(param.data).mean()
                mean_grad = torch.abs(param.grad).mean().data[0] + 1e-8
                print('{}: size {}, data_abs_avg: {}, grad_abs_avg: {}, data/grad: {}'
                      .format(name, param.size(), mean_value, mean_grad,
                              mean_value / mean_grad))
            print('################# model params diff ###################\n\n')
        else:
            print('################# model params ###################')
            for name, param in self.model.named_parameters():
                print('{}: size {}, abs_avg: {}'.format(
                    name, param.size(), torch.abs(param.data.cpu()).mean()))
            print('################# model params ###################\n\n')

    def _parse_data(self, inputs):
        imgs, _, labels = inputs
        inputs_var = [Variable(imgs)]
        targets_var = Variable(labels.cuda())
        return inputs_var, targets_var

    def _forward(self, inputs, targets):
        outputs = self.model(*inputs)
        if isinstance(self.criterion, torch.nn.CrossEntropyLoss):
            loss = self.criterion(outputs, targets)
            # top-1/top-5 precision; see the reference sketch after this class
            prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
            prec1 = prec1[0]
            prec5 = prec5[0]
        else:
            raise ValueError('Unsupported loss: {}'.format(self.criterion))
        return loss, prec1, prec5
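# For reference, a sketch of a top-k precision helper compatible with the
# accuracy() calls in Trainer._forward and Evaluator._forward. The actual
# implementation is imported from evaluators; this name (_accuracy_sketch)
# is hypothetical and unused by the script.
def _accuracy_sketch(output, target, topk=(1,)):
    """Return one precision value per k in topk, each in [0, 1]."""
    maxk = max(topk)
    batch_size = target.size(0)
    # indices of the maxk highest-scoring classes per sample: (batch, maxk)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # a prediction counts if the target appears among the top k classes
        correct_k = correct[:k].contiguous().view(-1).float().sum(0)
        res.append(correct_k.mul_(1.0 / batch_size))
    return res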
class Evaluator(object):
    def __init__(self, model, criterion):
        super(Evaluator, self).__init__()
        self.model = model
        self.criterion = criterion

    def evaluate(self, data_loader, print_freq=1):
        batch_time = AverageMeter()
        losses = AverageMeter()
        top1 = AverageMeter()
        top5 = AverageMeter()

        self.model.eval()

        end = time.time()
        for i, inputs in enumerate(data_loader):
            inputs_var, targets_var = self._parse_data(inputs)
            loss, prec1, prec5 = self._forward(inputs_var, targets_var)

            losses.update(loss.data[0], targets_var.size(0))
            top1.update(prec1, targets_var.size(0))
            top5.update(prec5, targets_var.size(0))

            batch_time.update(time.time() - end)
            end = time.time()

            if (i + 1) % print_freq == 0:
                print('Test: [{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Loss {:.4f} ({:.4f})\t'
                      'Prec@1 {:.2%} ({:.2%})\t'
                      'Prec@5 {:.2%} ({:.2%})\t'
                      .format(i + 1, len(data_loader),
                              batch_time.val, batch_time.avg,
                              losses.val, losses.avg,
                              top1.val, top1.avg,
                              top5.val, top5.avg))

        print(' * Prec@1 {:.2%}  Prec@5 {:.2%}'.format(top1.avg, top5.avg))
        return top1.avg

    def _parse_data(self, inputs):
        imgs, _, labels = inputs
        # volatile=True disables autograd history (pre-0.4 inference mode)
        inputs_var = [Variable(imgs, volatile=True)]
        targets_var = Variable(labels.cuda(), volatile=True)
        return inputs_var, targets_var

    def _forward(self, inputs, targets):
        outputs = self.model(*inputs)
        if isinstance(self.criterion, torch.nn.CrossEntropyLoss):
            loss = self.criterion(outputs, targets)
            prec1, prec5 = accuracy(outputs.data, targets.data, topk=(1, 5))
            prec1 = prec1[0]
            prec5 = prec5[0]
        else:
            raise ValueError('Unsupported loss: {}'.format(self.criterion))
        return loss, prec1, prec5


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Softmax loss classification")
    # data
    parser.add_argument('-d', '--dataset', type=str, default='imagenet')
    parser.add_argument('-b', '--batch-size', type=int, default=256)
    parser.add_argument('-j', '--workers', type=int, default=4)
    parser.add_argument('--split', type=int, default=0)
    parser.add_argument('--scale_size', type=int, default=256,
                        help="val resize image size, default: 256 for ImageNet")
    parser.add_argument('--img_size', type=int, default=224,
                        help="input image size, default: 224 for ImageNet")
    # model
    parser.add_argument('-a', '--arch', type=str, default='resnet50',
                        choices=models.names())
    # optimizer
    parser.add_argument('--lr', type=float, default=0.001,
                        help="learning rate of new parameters; for pretrained "
                             "parameters it is 10 times smaller than this")
    parser.add_argument('--momentum', type=float, default=0.9)
    parser.add_argument('--weight-decay', type=float, default=1e-5)
    parser.add_argument('--step_size', type=int, default=25)
    parser.add_argument('--decay_step', type=int, default=25)
    # training configs
    parser.add_argument('--pretrained', type=str, default='', metavar='PATH')
    parser.add_argument('--resume', type=str, default='', metavar='PATH')
    parser.add_argument('--resume_epoch', type=int, default=0)
    parser.add_argument('--evaluate', action='store_true',
                        help="evaluation only")
    parser.add_argument('--adam', action='store_true',
                        help="use Adam optimizer instead of SGD")
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--start_save', type=int, default=0,
                        help="start saving checkpoints after specific epoch")
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--print-freq', type=int, default=1)
    parser.add_argument('--print-info', type=int, default=10)
    # misc
    working_dir = osp.dirname(osp.abspath(__file__))
    parser.add_argument('--data-dir', type=str, metavar='PATH',
                        default=osp.join(working_dir, 'data'))
    parser.add_argument('--logs-dir', type=str, metavar='PATH',
                        default=osp.join(working_dir, 'logs'))
    main(parser.parse_args())
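# Example invocations (a sketch; the script name and dataset layout are
# assumptions). The data directory is expected to contain
# <data-dir>/<dataset>/{train.txt,val.txt} plus the images they reference:
#
#   # train ResNet-50 with SGD and the default step-decay schedule
#   python train_baseline.py -d imagenet -a resnet50 -b 256 -j 4 --lr 0.001
#
#   # resume from a saved checkpoint at epoch 50
#   python train_baseline.py --resume logs/checkpoint.pth.tar --resume_epoch 50
#
#   # evaluation only
#   python train_baseline.py --evaluate --resume logs/checkpoint.pth.tar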