#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Patch-wise CNN training script for aerial-image segmentation.

Reads paired ortho-photo / label patches from two LMDB databases, augments
them in parallel worker processes, and trains a Chainer model defined in a
user-supplied model file. Snapshots, logs and loss curves are written to a
per-run ``results/`` directory.
"""

import argparse
# NOTE(review): `imp` is deprecated (removed in Python 3.12); kept because the
# rest of the codebase targets the six/Chainer-v1 era — migrate to importlib
# together with the other scripts.
import imp
import logging
import os
import re
import shutil
import time
from multiprocessing import Process
from multiprocessing import Queue

import chainer
import numpy as np
import six
from chainer import Variable
from chainer import cuda
from chainer import optimizers
from chainer import serializers

import lmdb

from draw_loss import draw_loss
from utils.transformer import transform


def create_args():
    """Parse command-line options and return the argparse Namespace."""
    parser = argparse.ArgumentParser()
    # Training settings
    parser.add_argument('--model', type=str, default='models/MnihCNN_multi.py')
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--epoch', type=int, default=400)
    parser.add_argument('--batchsize', type=int, default=128)
    parser.add_argument('--dataset_size', type=float, default=1.0)
    parser.add_argument('--aug_threads', type=int, default=8)
    parser.add_argument('--snapshot', type=int, default=10)
    parser.add_argument('--resume_model', type=str, default=None)
    parser.add_argument('--resume_opt', type=str, default=None)
    parser.add_argument('--epoch_offset', type=int, default=0)

    # Dataset paths
    parser.add_argument('--train_ortho_db', type=str,
                        default='data/mass_merged/lmdb/train_sat')
    parser.add_argument('--train_label_db', type=str,
                        default='data/mass_merged/lmdb/train_map')
    parser.add_argument('--valid_ortho_db', type=str,
                        default='data/mass_merged/lmdb/valid_sat')
    parser.add_argument('--valid_label_db', type=str,
                        default='data/mass_merged/lmdb/valid_map')

    # Dataset info (patch side lengths before and after cropping)
    parser.add_argument('--ortho_original_side', type=int, default=92)
    parser.add_argument('--label_original_side', type=int, default=24)
    parser.add_argument('--ortho_side', type=int, default=64)
    parser.add_argument('--label_side', type=int, default=16)

    # Options for data augmentation
    parser.add_argument('--fliplr', type=int, default=1)
    parser.add_argument('--rotate', type=int, default=1)
    parser.add_argument('--angle', type=int, default=90)
    parser.add_argument('--norm', type=int, default=1)
    parser.add_argument('--crop', type=int, default=1)

    # Optimization settings
    parser.add_argument('--opt', type=str, default='MomentumSGD',
                        choices=['MomentumSGD', 'Adam', 'AdaGrad'])
    parser.add_argument('--weight_decay', type=float, default=0.0005)
    parser.add_argument('--alpha', type=float, default=0.001)
    parser.add_argument('--lr', type=float, default=0.0005)
    parser.add_argument('--lr_decay_freq', type=int, default=100)
    parser.add_argument('--lr_decay_ratio', type=float, default=0.1)
    parser.add_argument('--seed', type=int, default=1701)

    args = parser.parse_args()
    return args


def create_result_dir(args):
    """Create (or reuse, when resuming) the result directory and set up logging.

    Returns:
        (log_fn, result_dir): path to the log file and the result directory.
    """
    if args.resume_model is None:
        result_dir = 'results/{}_{}'.format(
            os.path.splitext(os.path.basename(args.model))[0],
            time.strftime('%Y-%m-%d_%H-%M-%S'))
        if os.path.exists(result_dir):
            # Disambiguate runs started within the same second.
            # (was time.clock(), which was removed in Python 3.8)
            result_dir += '_{}'.format(time.time())
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)
    else:
        # When resuming, keep logging into the original run's directory.
        result_dir = os.path.dirname(args.resume_model)

    log_fn = '%s/log.txt' % result_dir
    logging.basicConfig(
        format='%(asctime)s [%(levelname)s] %(message)s',
        filename=log_fn, level=logging.DEBUG)
    logging.info(args)

    return log_fn, result_dir


def get_model(args):
    """Load the model definition file, optionally restore weights, move to GPU.

    The model file must expose a module-level ``model`` object. The model
    source and this script are copied into the result directory for
    reproducibility.
    """
    model_fn = os.path.basename(args.model)
    model = imp.load_source(model_fn.split('.')[0], args.model).model

    if 'result_dir' in args:
        dst = '%s/%s' % (args.result_dir, model_fn)
        if not os.path.exists(dst):
            shutil.copy(args.model, dst)

        dst = '%s/%s' % (args.result_dir, os.path.basename(__file__))
        if not os.path.exists(dst):
            shutil.copy(__file__, dst)

    # load model
    if args.resume_model is not None:
        serializers.load_hdf5(args.resume_model, model)

    # prepare model
    if args.gpu >= 0:
        model.to_gpu()

    return model


def get_model_optimizer(args):
    """Return (model, optimizer); restore optimizer state when resuming.

    When resuming, ``args.epoch_offset`` is recovered from the
    ``epoch-<n>`` pattern in the state file name.
    """
    model = get_model(args)

    if 'opt' in args:
        # prepare optimizer
        if args.opt == 'MomentumSGD':
            optimizer = optimizers.MomentumSGD(lr=args.lr, momentum=0.9)
        elif args.opt == 'Adam':
            optimizer = optimizers.Adam(alpha=args.alpha)
        elif args.opt == 'AdaGrad':
            optimizer = optimizers.AdaGrad(lr=args.lr)
        else:
            raise ValueError('No optimizer is selected')

        optimizer.setup(model)
        if args.opt == 'MomentumSGD':
            optimizer.add_hook(
                chainer.optimizer.WeightDecay(args.weight_decay))

        if args.resume_opt is not None:
            serializers.load_hdf5(args.resume_opt, optimizer)
            args.epoch_offset = int(
                re.search('epoch-([0-9]+)', args.resume_opt).groups()[0])

        return model, optimizer
    else:
        print('No optimizer generated.')
        return model


def create_minibatch(args, o_cur, l_cur, batch_queue):
    """Producer process: read raw ortho/label patches and enqueue minibatches.

    Starts from a random cursor offset so different epochs see different
    batch boundaries. Emits ``args.aug_threads`` ``None`` sentinels at the
    end, one for each augmentation worker.
    """
    # Re-seed: the forked child would otherwise share the parent's RNG state.
    np.random.seed(int(time.time()))
    skip = np.random.randint(args.batchsize)
    for _ in six.moves.range(skip):
        o_cur.next()
        l_cur.next()
    logging.info('random skip:{}'.format(skip))

    x_minibatch = []
    y_minibatch = []
    i = 0
    while True:
        o_key, o_val = o_cur.item()
        l_key, l_val = l_cur.item()
        if o_key != l_key:
            raise ValueError(
                'Keys of ortho and label patches are different: '
                '{} != {}'.format(o_key, l_key))

        # prepare patch: decode raw bytes into HWC uint8 arrays
        # (np.frombuffer replaces the deprecated np.fromstring; the arrays
        # are copied below by np.asarray, so the read-only view is safe)
        o_side = args.ortho_original_side
        l_side = args.label_original_side
        o_patch = np.frombuffer(
            o_val, dtype=np.uint8).reshape((o_side, o_side, 3))
        l_patch = np.frombuffer(
            l_val, dtype=np.uint8).reshape((l_side, l_side, 1))

        # add patch
        x_minibatch.append(o_patch)
        y_minibatch.append(l_patch)

        o_ret = o_cur.next()
        l_ret = l_cur.next()

        # Flush a batch when full, or when both cursors are exhausted.
        if ((not o_ret) and (not l_ret)) \
                or len(x_minibatch) == args.batchsize:
            x_minibatch = np.asarray(x_minibatch, dtype=np.uint8)
            y_minibatch = np.asarray(y_minibatch, dtype=np.uint8)
            batch_queue.put((x_minibatch, y_minibatch))
            i += len(x_minibatch)
            x_minibatch = []
            y_minibatch = []

        # Stop after covering the requested fraction of the dataset.
        if i > args.N * args.dataset_size:
            break
        if (not o_ret) and (not l_ret):
            break

    # One sentinel per augmentation worker so each of them terminates.
    for _ in six.moves.range(args.aug_threads):
        batch_queue.put(None)


def apply_transform(args, batch_queue, aug_queue):
    """Worker process: augment raw minibatches and enqueue the results.

    Consumes from ``batch_queue`` until its ``None`` sentinel, then forwards
    one ``None`` sentinel to ``aug_queue``.
    """
    np.random.seed(int(time.time()))
    while True:
        augs = batch_queue.get()
        if augs is None:
            break
        x, y = augs
        o_aug, l_aug = transform(
            x, y, args.fliplr, args.rotate, args.norm,
            args.ortho_side, args.ortho_side, 3,
            args.label_side, args.label_side)
        aug_queue.put((o_aug, l_aug))
    aug_queue.put(None)


def get_cursor(db_fn):
    """Open an LMDB database and return (cursor, txn, number of entries).

    The cursor is advanced to the first record. The transaction is returned
    so it stays alive (and with it the cursor) for the caller's lifetime.
    """
    env = lmdb.open(db_fn)
    txn = env.begin(write=False, buffers=False)
    cur = txn.cursor()
    cur.next()
    return cur, txn, env.stat()['entries']


def one_epoch(args, model, optimizer, epoch, train):
    """Run one full pass over the train or validation set.

    Spawns one minibatch-producer process and ``args.aug_threads``
    augmentation workers, then consumes augmented batches until every
    worker has signalled completion.
    """
    model.train = train
    xp = cuda.cupy if args.gpu >= 0 else np

    # open datasets
    ortho_db = args.train_ortho_db if train else args.valid_ortho_db
    label_db = args.train_label_db if train else args.valid_label_db
    o_cur, o_txn, args.N = get_cursor(ortho_db)
    l_cur, l_txn, _ = get_cursor(label_db)

    # for parallel augmentation
    batch_queue = Queue()
    batch_worker = Process(target=create_minibatch,
                           args=(args, o_cur, l_cur, batch_queue))
    batch_worker.start()
    aug_queue = Queue()
    aug_workers = [Process(target=apply_transform,
                           args=(args, batch_queue, aug_queue))
                   for __ in range(args.aug_threads)]
    for w in aug_workers:
        w.start()

    n_iter = 0
    sum_loss = 0
    num = 0
    n_finished_workers = 0
    while True:
        minibatch = aug_queue.get()
        if minibatch is None:
            # Each augmentation worker emits one sentinel when done.
            # BUGFIX: wait for ALL sentinels instead of stopping at the
            # first one — otherwise batches still in flight from the other
            # workers were silently dropped and the workers terminated.
            n_finished_workers += 1
            if n_finished_workers == args.aug_threads:
                break
            continue
        x, t = minibatch

        volatile = 'off' if train else 'on'
        x = Variable(xp.asarray(x), volatile=volatile)
        t = Variable(xp.asarray(t), volatile=volatile)

        if train:
            optimizer.update(model, x, t)
        else:
            model(x, t)

        sum_loss += float(model.loss.data) * t.data.shape[0]
        num += t.data.shape[0]
        n_iter += 1

        del x, t

    # All queues are drained, so the workers can be joined cleanly
    # (they previously had to be terminate()d mid-flight).
    batch_worker.join()
    for w in aug_workers:
        w.join()

    if train and (epoch == 1 or epoch % args.snapshot == 0):
        model_fn = '{}/epoch-{}.model'.format(args.result_dir, epoch)
        opt_fn = '{}/epoch-{}.state'.format(args.result_dir, epoch)
        serializers.save_hdf5(model_fn, model)
        serializers.save_hdf5(opt_fn, optimizer)

    if train:
        logging.info(
            'epoch:{}\ttrain loss:{}'.format(epoch, sum_loss / num))
    else:
        logging.info(
            'epoch:{}\tvalidate loss:{}'.format(epoch, sum_loss / num))

    return model, optimizer


if __name__ == '__main__':
    args = create_args()

    if args.gpu >= 0:
        cuda.get_device(args.gpu).use()
    xp = cuda.cupy if args.gpu >= 0 else np
    xp.random.seed(args.seed)
    np.random.seed(args.seed)

    # create result dir
    log_fn, args.result_dir = create_result_dir(args)

    # create model and optimizer
    model, optimizer = get_model_optimizer(args)

    # start logging
    logging.info('start training...')
    for epoch in six.moves.range(args.epoch_offset + 1, args.epoch + 1):
        logging.info('learning rate:{}'.format(optimizer.lr))

        model, optimizer = one_epoch(args, model, optimizer, epoch, True)
        # Validate (and draw the loss curve) only on snapshot epochs.
        if epoch == 1 or epoch % args.snapshot == 0:
            one_epoch(args, model, optimizer, epoch, False)

            # draw curve
            draw_loss('{}/log.txt'.format(args.result_dir),
                      '{}/log.png'.format(args.result_dir))

        # learning rate reduction
        if args.opt == 'MomentumSGD' \
                and epoch % args.lr_decay_freq == 0:
            optimizer.lr *= args.lr_decay_ratio

        logging.info('-' * 20)