# --------------------------------------------------------
# TRIPLET LOSS
# Copyright (c) 2015 Pinguo Tech.
# Written by David Lu
# --------------------------------------------------------

"""Train the network."""

import caffe
from timer import Timer
import numpy as np
import os

from caffe.proto import caffe_pb2
import google.protobuf as pb2
# `import google.protobuf` alone does not load the `text_format` submodule;
# import it explicitly so parsing the solver prototxt cannot fail with
# AttributeError depending on what other modules happened to import.
from google.protobuf import text_format
import sys
import config
# NOTE(review): imported for its side effects only; it is imported after
# caffe here, so it presumably is not what makes caffe importable -- confirm.
import _init_paths


class SolverWrapper(object):
    """A simple wrapper around Caffe's SGD solver.

    Holds the solver, the parsed solver parameters and the directory in
    which snapshots written by `snapshot` are stored.
    """

    def __init__(self, solver_prototxt, output_dir, pretrained_model=None):
        """Initialize the SolverWrapper.

        Args:
            solver_prototxt: path to the solver definition (text protobuf).
            output_dir: directory where `snapshot` writes model files.
            pretrained_model: optional path to a .caffemodel whose weights
                are copied into the network before training; None skips
                the copy.
        """
        self.output_dir = output_dir

        # Training always runs on GPU 0.
        caffe.set_mode_gpu()
        caffe.set_device(0)
        self.solver = caffe.SGDSolver(solver_prototxt)
        if pretrained_model is not None:
            # Single-argument print(...) behaves identically on Python 2
            # and Python 3.
            print(('Loading pretrained model '
                   'weights from {:s}').format(pretrained_model))
            self.solver.net.copy_from(pretrained_model)

        # Parse the solver prototxt a second time so that its fields
        # (snapshot_prefix, display) can be read from Python.
        self.solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            text_format.Merge(f.read(), self.solver_param)

    def snapshot(self):
        """Save the current network weights to a .caffemodel file.

        The file is named '<snapshot_prefix>_iter_<iteration>.caffemodel'
        and is written below `self.output_dir`, which is created if it
        does not exist yet.
        """
        net = self.solver.net

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)

        filename = (self.solver_param.snapshot_prefix +
                    '_iter_{:d}'.format(self.solver.iter) + '.caffemodel')
        filename = os.path.join(self.output_dir, filename)

        # str() because the prefix comes from a protobuf string field.
        net.save(str(filename))
        print('Wrote snapshot to: {:s}'.format(filename))

    def train_model(self, max_iters):
        """Run the training loop until the solver reaches `max_iters`.

        After every solver step the largest weight in the first row of
        layer 'fc9_1' is printed as a quick health indicator, and every
        `10 * display` iterations the average time per iteration is
        reported. This loop does not call `snapshot` itself.

        Args:
            max_iters: iteration count at which training stops.
        """
        timer = Timer()
        # `display` may be 0 in the solver prototxt; guard against a
        # modulo-by-zero instead of crashing in the middle of training.
        report_every = 10 * self.solver_param.display
        while self.solver.iter < max_iters:
            timer.tic()
            self.solver.step(1)
            # .max() replaces sorting the whole row just to read its last
            # element, and propagates NaN so divergence stays visible.
            # Other layers (fc6, fc7, fc9, conv4_3, conv5_*) can be
            # monitored the same way when debugging.
            fc9_1_row = self.solver.net.params['fc9_1'][0].data[0]
            print('fc9_1: %s' % (fc9_1_row.max(),))
            timer.toc()
            if report_every > 0 and self.solver.iter % report_every == 0:
                print('speed: {:.3f}s / iter'.format(timer.average_time))


def main():
    """Train the network with the paths hard-coded for this project."""
    solver_prototxt = '../solver.prototxt'
    output_dir = '../models/vgg_face_tripletloss/'
    pretrained_model = '../models/_iter_40000.caffemodel'
    # Alternatives used previously:
    # pretrained_model = None
    # pretrained_model = '/home/seal/dataset/fast-rcnn/data/vgg_face_caffe/VGG_FACE.caffemodel'
    max_iters = config.MAX_ITERS
    sw = SolverWrapper(solver_prototxt, output_dir, pretrained_model)

    print('Solving...')
    sw.train_model(max_iters)
    print('done solving')


if __name__ == '__main__':
    main()