import argparse import os,sys import time import shutil import torch import torchvision import torch.nn.parallel import torch.backends.cudnn as cudnn import torch.optim from torch.nn.utils import clip_grad_norm from log import log from dataset import ClipDataset #from p3d_model import * #from p3d_model import get_optim_policies #from I3D_Pytorch import * from S3DG_Pytorch import * from transforms import * from IPython import embed from dataloader_pkl import KineticsPKL best_prec1=0 channel_dim=1 def get_args(): parser = argparse.ArgumentParser(description="TAL") parser.add_argument('dataset', type=str, choices=['kinetics']) parser.add_argument('modality', type=str, choices=['RGB', 'Flow']) parser.add_argument('train_list', type=str) parser.add_argument('val_list', type=str) # ========================= Model Configs ========================== parser.add_argument('--arch', type=str, default="S3DG",choices=['S3DG']) parser.add_argument('--dropout', '--do', default=0.5, type=float, metavar='DO', help='dropout ratio (default: 0.5)') parser.add_argument('-d','--data_workers',default=8,type=int) # ========================= Learning Configs ========================== parser.add_argument('--epochs', default=45, type=int, metavar='N',help='number of total epochs to run') parser.add_argument('-b', '--batch-size', default=64, type=int,metavar='N', help='mini-batch size (default: 256)') parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,metavar='LR', help='initial learning rate') parser.add_argument('--lr_steps', default=[20, 40], type=float, nargs="+",metavar='LRSteps', help='epochs to decay learning rate by 10') parser.add_argument('--momentum', default=0.9, type=float, metavar='M',help='momentum') parser.add_argument('--weight-decay', '--wd', default=5e-4, type=float,metavar='W', help='weight decay (default: 5e-4)') parser.add_argument('--clip-gradient', '--gd', default=None, type=float,metavar='W', help='gradient norm clipping (default: disabled)') parser.add_argument('--no_partialbn', '--npb', default=False, action="store_true") # ========================= Monitor Configs ========================== parser.add_argument('-p','--print-freq', default=20, type=int,metavar='N', help='print frequency (default: 10)') parser.add_argument('-ef','--eval-freq', default=5, type=int,metavar='N', help='evaluation frequency (default: 5)') # ========================= Runtime Configs ========================== parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',help='number of data loading workers (default: 4)') parser.add_argument('--resume', default='', type=str, metavar='PATH',help='path to latest checkpoint (default: none)') parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',help='evaluate model on validation set') parser.add_argument('--snapshot_pref', type=str, default="Qijie") parser.add_argument('--start-epoch', default=0, type=int, metavar='N',help='manual epoch number (useful on restarts)') parser.add_argument('--gpus', nargs='+', type=int, default=None) parser.add_argument('--flow_prefix', default="", type=str) # ========================= Return the final total Args================== Args=parser.parse_args() return Args def main(): global args,best_prec1 args=get_args() log.l.info('Input command:\npython '+ ' '.join(sys.argv)+' ===========>') if args.dataset == 'kinetics': num_class = 400 else: raise ValueError('Unknown dataset '+args.dataset) log.l.info('============= prepare the model and model\'s parameters =============') if args.arch=='S3DG': input_channel=3 if args.modality=='RGB' else 2 model=S3DG(num_classes=num_class,input_channel=input_channel,dropout_keep_prob=0.5) model.load_state_dict('modelweights/{}_imagenet.pkl'.format(args.modality)) else: raise ValueError('Unknown model'+ args.arch) #model=transfer_model(model,num_classes=num_class) crop_size = 224# model.crop_size #scale_size = model.scale_size input_mean = [0.485,0.456,0.406]#model.input_mean input_std = [0.229,0.224,0.225]#model.input_std temporal_length = 64#model.temporal_length #model = torch.nn.DataParallel(model, device_ids=args.gpus).cuda() model=torch.nn.DataParallel(model).cuda() if args.resume: log.l.info('============== train from checkpoint (finetune mode) =================') if os.path.isfile(args.resume): log.l.info(("=> loading checkpoint '{}'".format(args.resume))) checkpoint = torch.load(args.resume) args.start_epoch = checkpoint['epoch'] best_prec1 = checkpoint['best_prec1'] model.load_state_dict(checkpoint['state_dict']) log.l.info(("=> loaded checkpoint '{}' (epoch {})" .format(args.evaluate, checkpoint['epoch']))) else: log.l.info(("=> no checkpoint found at '{}'".format(args.resume))) log.l.info('============== Now, loading data ... ==============\n') # Data loading code normalize = GroupNormalize(input_mean, input_std) if args.modality == 'RGB': data_channel = 3 elif args.modality == 'Flow': data_channel = 2 is_gray=False if args.modality=='RGB' else True Kinetics_train=KineticsPKL(args.train_list,seglen=64,is_train=True,cropsize=crop_size, transform=torchvision.transforms.Compose([ GroupScale((256,256)), #train_augmentation, #GroupRandomCrop(crop_size), GroupMultiScaleCrop(224,[1,0.875,0.75,0.66]), GroupRandomHorizontalFlip(), Stack(), ToTorchFormatTensor(div=args.arch != 'BNInception'), normalize, ])) train_loader = torch.utils.data.DataLoader( Kinetics_train, batch_size=args.batch_size, shuffle=True, num_workers=args.data_workers, pin_memory=True,drop_last=True) Kinetics_val=KineticsPKL(args.val_list,seglen=64,is_train=False,cropsize=crop_size, transform=torchvision.transforms.Compose([ GroupScale((256,256)), GroupCenterCrop(224), Stack(), ToTorchFormatTensor(div=args.arch != 'BNInception'), normalize, ])) val_loader = torch.utils.data.DataLoader( Kinetics_val, batch_size=args.batch_size, shuffle=False, num_workers=args.data_workers, pin_memory=True,drop_last=True) log.l.info('================= Now, define loss function and optimizer ==============') criterion = torch.nn.CrossEntropyLoss(ignore_index=-1).cuda() #for group in policies: # log.l.info(('group: {} has {} params, lr_mult: {}, decay_mult: {}'.format( # group['name'], len(group['params']), group['lr_mult'], group['decay_mult']))) optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay) if args.evaluate: log.l.info('Need val the data first...') validate(val_loader, model, criterion, 0) log.l.info('\n\n===================> TRAIN and VAL begins <===================\n') for epoch in range(args.start_epoch, args.epochs): adjust_learning_rate(optimizer, epoch, args.lr_steps) # train for one epoch train(train_loader, model, criterion, optimizer, epoch) # evaluate on validation set if (epoch + 1) % args.eval_freq == 0 or epoch == args.epochs - 1: prec1 = validate(val_loader, model, criterion, (epoch + 1) * len(train_loader)) # remember best prec@1 and save checkpoint is_best = prec1 > best_prec1 best_prec1 = max(prec1, best_prec1) save_checkpoint({ 'epoch': epoch + 1, 'arch': args.arch, 'state_dict': model.state_dict(), 'best_prec1': best_prec1, }, is_best) def train(train_loader, model, criterion, optimizer, epoch): batch_time = AverageMeter() data_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() # switch to train mode model.train() end = time.time() for i, (input, target,vid) in enumerate(train_loader): # measure data loading time data_time.update(time.time() - end) target = target.cuda(async=True) input_var = torch.autograd.Variable(input) target_var = torch.autograd.Variable(target) # compute output output = model(input_var)[0] loss = criterion(output, target_var) # measure accuracy and record loss prec1, prec5 = accuracy(output.data, target, topk=(1,5)) losses.update(loss.data[0], input.size(0)) top1.update(prec1[0], input.size(0)) top5.update(prec5[0], input.size(0)) # compute gradient and do SGD step optimizer.zero_grad() loss.backward() if args.clip_gradient is not None: total_norm = clip_grad_norm(model.parameters(), args.clip_gradient) if total_norm > args.clip_gradient: log.l.info("clipping gradient: {} with coef {}".format(total_norm, args.clip_gradient / total_norm)) optimizer.step() # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: log.l.info(('Epoch: [{0}][{1}/{2}], lr: {lr:.5f}\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( epoch, i, len(train_loader), batch_time=batch_time, data_time=data_time, loss=losses, top1=top1, top5=top5, lr=optimizer.param_groups[-1]['lr']))) def validate(val_loader, model, criterion, iter, logger=None): batch_time = AverageMeter() losses = AverageMeter() top1 = AverageMeter() top5 = AverageMeter() # switch to evaluate mode model.eval() end = time.time() for i, (input, target,is_posi) in enumerate(val_loader): target = target.cuda(async=True) input_var = torch.autograd.Variable(input, volatile=True) target_var = torch.autograd.Variable(target, volatile=True) # compute output output = model(input_var)[0] loss = criterion(output, target_var) # measure accuracy and record loss prec1, prec5 = accuracy(output.data, target, topk=(1,5)) losses.update(loss.data[0], input.size(0)) top1.update(prec1[0], input.size(0)) top5.update(prec5[0], input.size(0)) # measure elapsed time batch_time.update(time.time() - end) end = time.time() if i % args.print_freq == 0: log.l.info(('Test: [{0}/{1}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( i, len(val_loader), batch_time=batch_time, loss=losses, top1=top1, top5=top5))) log.l.info(('Testing Results: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f} Loss {loss.avg:.5f}' .format(top1=top1, top5=top5, loss=losses))) return top1.avg def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'): filename = '_'.join((args.snapshot_pref, args.modality.lower(), filename)) torch.save(state, filename) if is_best: best_name = '_'.join((args.snapshot_pref, args.modality.lower(), 'model_best.pth.tar')) shutil.copyfile(filename, best_name) class AverageMeter(object): """Computes and stores the average and current value""" def __init__(self): self.reset() def reset(self): self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): self.val = val self.sum += val * n self.count += n self.avg = self.sum / self.count def adjust_learning_rate(optimizer, epoch, lr_steps): """Sets the learning rate to the initial LR decayed by 10 every 30 epochs""" decay = 0.1 ** (sum(epoch >= np.array(lr_steps))) lr = args.lr * decay decay = args.weight_decay for param_group in optimizer.param_groups: param_group['lr'] = lr #* param_group['lr_mult'] param_group['weight_decay'] = decay #* param_group['decay_mult'] def accuracy(output, target, topk=(1,)): """Computes the precision@k for the specified values of k""" maxk = max(topk) batch_size = target.size(0) _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in topk: correct_k = correct[:k].view(-1).float().sum(0) res.append(correct_k.mul_(100.0 / batch_size)) return res if __name__=='__main__': main()