# Copyright (C) 2019 Titus Cieslewski, RPG, University of Zurich, Switzerland # You can contact the author at <titus at ifi dot uzh dot ch> # Copyright (C) 2019 Konstantinos G. Derpanis, # Dept. of Computer Science, Ryerson University, Toronto, Canada # Copyright (C) 2019 Davide Scaramuzza, RPG, University of Zurich, Switzerland # # This file is part of sips2_open. # # sips2_open is free software: you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # # sips2_open is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. # # You should have received a copy of the GNU General Public License # along with sips2_open. If not, see <http:#www.gnu.org/licenses/>. from absl import flags import os import tensorflow as tf from rpg_datasets_py.hpatches import HPatches import rpg_datasets_py.euroc import rpg_datasets_py.kitti import rpg_datasets_py.robotcar import rpg_datasets_py.tum_mono import baselines import graph import multiscale import sequences import system FLAGS = flags.FLAGS def shortString(): ret = 'd=%d_tds=%s_nms=%d' % (FLAGS.d, FLAGS.tds, FLAGS.nms) if FLAGS.num_scales > 1: ret = ret + '_ms%d_%.02f' % (FLAGS.num_scales, FLAGS.scale_factor) if FLAGS.pbs: ret = ret + '_pbs' if FLAGS.augment: ret = ret + '_aug' if FLAGS.w != 128: ret = ret + '_w=%d' % FLAGS.w if FLAGS.scale_aug_range != 1.: ret = ret + '_sar=%.01f' % FLAGS.scale_aug_range if FLAGS.klti: ret = ret + '_klti' if FLAGS.lk: ret = ret + '_lk' if FLAGS.ol != 0.5: ret = ret + '_ol=%.2f' % FLAGS.ol return ret def announce(what): print(os.linesep + os.linesep + what + os.linesep + os.linesep) def announceTraining(): announce('Training %s...' % shortString()) def getTrainDataGen(): if FLAGS.klti: pair_class = sequences.PairWithIntermediates else: pair_class = sequences.PairWithStereo assert FLAGS.pck == 'tr' if FLAGS.tds == 'kt': ks = [rpg_datasets_py.kitti.KittiSeq(i) for i in rpg_datasets_py.kitti.split('training')] return sequences.MultiSeqTrackingPairPicker(ks, FLAGS.ol, pair_class) elif FLAGS.tds == 'rc': rs = rpg_datasets_py.robotcar.getSplitSequences('training') return sequences.MultiSeqTrackingPairPicker(rs, FLAGS.ol, pair_class) elif FLAGS.tds == 'tm': tms = [rpg_datasets_py.tum_mono.Sequence(i) for i in ['01', '02', '03', '48', '49', '50']] return sequences.MultiSeqTrackingPairPicker( tms, FLAGS.ol, sequences.PairWithIntermediates) elif FLAGS.tds == 'tmb': tms = [rpg_datasets_py.tum_mono.Sequence('%02d' % i) for i in range(1, 51)] return sequences.MultiSeqTrackingPairPicker( tms, FLAGS.ol, sequences.PairWithIntermediates) elif FLAGS.tds == 'tmbrc': tms = [rpg_datasets_py.tum_mono.Sequence('%02d' % i) for i in range(1, 51)] + \ [rpg_datasets_py.robotcar.CroppedGraySequence( '2014-07-14-14-49-50')] return sequences.MultiSeqTrackingPairPicker( tms, FLAGS.ol, sequences.PairWithIntermediates) elif FLAGS.tds == 'en': ks = [rpg_datasets_py.kitti.KittiSeq(i) for i in rpg_datasets_py.kitti.split('training')] rs = rpg_datasets_py.robotcar.getSplitSequences('training') tms = [rpg_datasets_py.tum_mono.Sequence(i) for i in ['01', '02', '03', '48', '49', '50']] pair_classes = [pair_class] * len(ks) + \ [sequences.PairWithIntermediates] * (len(rs) + len(tms)) return sequences.MultiSeqTrackingPairPicker( ks + rs + tms, FLAGS.ol, pair_classes) elif FLAGS.tds == 'hp': data_gen = HPatches('training', use_min=True) else: assert False return data_gen def getEvalDataGen(): if not FLAGS.testing: if FLAGS.ds == 'kt': val_seqs = rpg_datasets_py.kitti.split('validation') assert len(val_seqs) == 1 k = rpg_datasets_py.kitti.KittiSeq(val_seqs[0]) rpick = sequences.TrackingPairPicker( k, 0.5, pair_class=sequences.PairWithStereo) return sequences.FixPairs(rpick, 100) else: assert FLAGS.ds == 'hp' ret = HPatches('validation', use_min=True) print(ret.folder_names) return ret else: if FLAGS.ds == 'hp': ret = HPatches('testing', use_min=True) print(ret.folder_names) return ret elif FLAGS.ds == 'kt': val_seqs = rpg_datasets_py.kitti.split('testing') assert len(val_seqs) == 1 k = rpg_datasets_py.kitti.KittiSeq(val_seqs[0]) rpick = sequences.TrackingPairPicker( k, 0.5, pair_class=sequences.PairWithStereo) return sequences.FixPairs(rpick, 100) elif FLAGS.ds == 'eu': seq = rpg_datasets_py.euroc.EurocSeq('V1_01_easy') rpick = sequences.TrackingPairPicker( seq, 0.5, pair_class=sequences.PairWithStereo) return sequences.FixPairs(rpick, 100) elif FLAGS.ds == 'eumh': seq = rpg_datasets_py.euroc.EurocSeq('MH_01_easy') rpick = sequences.TrackingPairPicker( seq, 0.5, pair_class=sequences.PairWithStereo) return sequences.FixPairs(rpick, 100) else: raise NotImplementedError def getEvalSequences(): if FLAGS.ds == 'kt': if FLAGS.testing: seq_names = rpg_datasets_py.kitti.split('testing') else: seq_names = rpg_datasets_py.kitti.split('validation') return [rpg_datasets_py.kitti.KittiSeq(i) for i in seq_names] elif FLAGS.ds == 'eumh': assert FLAGS.testing seq_names = ['MH_01_easy'] return [rpg_datasets_py.euroc.EurocSeq(i) for i in seq_names] else: assert FLAGS.ds == 'eu' assert FLAGS.testing seq_names = ['V1_01_easy'] return [rpg_datasets_py.euroc.EurocSeq(i) for i in seq_names] def getForwardPasser(graph=None, sess=None): if FLAGS.baseline == '': assert graph is not None assert sess is not None fp = system.ForwardPasser( graph, sess, FLAGS.num_test_pts, FLAGS.nms) if FLAGS.num_scales > 1: return multiscale.ForwardPasser( fp, FLAGS.scale_factor, FLAGS.num_scales) else: return fp elif FLAGS.baseline in ['surf', 'sift']: return baselines.OpenCVForwardPasser(FLAGS.baseline) else: raise Exception('Baseline %s unknown' % FLAGS.baseline) def modelFromCheckpoint(): if FLAGS.baseline is not '': return None, None tf.reset_default_graph() g = graph.Graph() saver = tf.train.Saver() sess = tf.Session() saver.restore(sess, checkpointPath()) return g, sess def checkpointRoot(): return os.path.join(os.path.dirname(__file__), 'checkpoints') def checkpointPath(): if FLAGS.val_best: return os.path.join(checkpointRoot(), shortString() + '_best') else: return os.path.join(checkpointRoot(), shortString()) def trainStatsPath(): return '%s_stats' % checkpointPath() def methodString(): if FLAGS.baseline == '': return shortString() else: return FLAGS.baseline def evalString(k=True): result = '%s_%d' % (FLAGS.ds, FLAGS.num_test_pts) if k: result += '_%d' % FLAGS.k if FLAGS.testing: result += '_TESTING' return result def vt(): if FLAGS.testing: return 'testing' else: return 'validation' def methodEvalString(k=True): return '%s_%s' % (methodString(), evalString(k=k)) def announceEval(): announce('Evaluating %s...' % methodEvalString()) if FLAGS.testing: announce('THIS IS A TESTING RUN!!!') def resultPath(k=True): return os.path.join('results', methodEvalString(k=k)) def wasInlierFilePath(): return '%s_wasinl' % resultPath() def cachedForwardPath(): return '%s_cached_fp.hkl' % resultPath(k=False) def evalPath(): return '%s_n_r_t.hkl' % resultPath() def choutRootPath(): return '%s_chouts' % resultPath() def seqFpsPath(): here = os.path.dirname(__file__) return os.path.join(here, 'sequence_fps', methodEvalString(k=False)) def label(): if FLAGS.baseline != '': return FLAGS.baseline else: if FLAGS.num_scales > 1: return 'ours, multi-scale' else: return 'ours, single scale'