#!/usr/bin/env python

# --------------------------------------------------------------------
# This file is part of
# Weakly-supervised Pedestrian Attribute Localization Network.
#
# Weakly-supervised Pedestrian Attribute Localization Network
# is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Weakly-supervised Pedestrian Attribute Localization Network
# is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Weakly-supervised Pedestrian Attribute Localization Network.
# If not, see <http://www.gnu.org/licenses/>.
# --------------------------------------------------------------------

import os

import caffe
import google.protobuf as pb2
from caffe.proto import caffe_pb2
from utils.timer import Timer

from config import cfg
from test import test_net


class SolverWrapper(object):
    """A simple wrapper around Caffe's solver.

    This wrapper gives us control over the snapshotting process: weights are
    saved by :meth:`snapshot` into ``output_dir`` every
    ``cfg.TRAIN.SNAPSHOT_ITERS`` iterations, and ``test_net`` is invoked every
    ``cfg.TRAIN.TEST_ITERS`` iterations during :meth:`train_model`.
    """

    def __init__(self, solver_prototxt, db, output_dir, do_flip,
                 snapshot_path=None):
        """Initialize the SolverWrapper.

        Args:
            solver_prototxt: Path to the Caffe solver prototxt file. It is
                used to build a ``caffe.SGDSolver`` and is also parsed into a
                ``SolverParameter`` so its fields are readable from Python.
            db: Dataset object passed to the first layer's ``set_db`` and to
                ``test_net``.
            output_dir: Directory into which snapshots are written.
            do_flip: Flag forwarded to the first layer's ``set_db``.
            snapshot_path: Optional path to a ``.caffemodel`` whose weights are
                copied into the training net before training starts.
        """
        # Import the submodule explicitly: ``import google.protobuf`` alone
        # does not guarantee that ``google.protobuf.text_format`` is loaded.
        from google.protobuf import text_format

        self._output_dir = output_dir
        self._solver = caffe.SGDSolver(solver_prototxt)

        # Parse the prototxt again so solver settings (snapshot_prefix,
        # display, ...) are accessible from Python.
        self._solver_param = caffe_pb2.SolverParameter()
        with open(solver_prototxt, 'rt') as f:
            text_format.Merge(f.read(), self._solver_param)

        infix = ('_' + cfg.TRAIN.SNAPSHOT_INFIX
                 if cfg.TRAIN.SNAPSHOT_INFIX != '' else '')
        self._snapshot_prefix = (self._solver_param.snapshot_prefix + infix
                                 + '_iter_')

        if snapshot_path is not None:
            print('Loading snapshot weights from {:s}'.format(snapshot_path))
            self._solver.net.copy_from(snapshot_path)

            # Compare bare file names: snapshot_prefix may carry a directory
            # component, which would otherwise prevent any match.
            loaded_name = os.path.basename(snapshot_path)
            if loaded_name.startswith(os.path.basename(self._snapshot_prefix)):
                print('Warning! Existing snapshots may be overriden by new snapshots!')

        self._db = db
        self._solver.net.layers[0].set_db(self._db, do_flip)

    def snapshot(self):
        """Save the training net's current weights.

        The file is written to
        ``output_dir/<snapshot_prefix>[_<SNAPSHOT_INFIX>]_iter_<iter>.caffemodel``.
        ``output_dir`` is created if it does not exist.

        Returns:
            The path of the written ``.caffemodel`` file.
        """
        net = self._solver.net

        filename = self._snapshot_prefix + ('{:d}'.format(self._solver.iter) + '.caffemodel')
        filepath = os.path.join(self._output_dir, filename)

        print('Attempting to save snapshot to "{}"'.format(filepath))
        if not os.path.exists(self._output_dir):
            os.makedirs(self._output_dir)
        # str() because net.save is called with a plain string here
        # (protobuf string fields may come back as unicode under Python 2).
        net.save(str(filepath))
        print('Wrote snapshot to: {:s}'.format(filepath))

        return filepath

    def train_model(self, max_iters):
        """Run single SGD steps until the solver reaches ``max_iters``.

        Every ``cfg.TRAIN.SNAPSHOT_ITERS`` iterations a snapshot is taken, and
        every ``cfg.TRAIN.TEST_ITERS`` iterations ``test_net`` is run. A
        final snapshot is written after the loop unless one was just taken
        at the current iteration.

        Args:
            max_iters: Iteration count at which training stops.

        Returns:
            List of snapshot file paths, in the order they were written.
        """
        last_snapshot_iter = -1
        timer = Timer()
        model_paths = []
        # Caffe's SolverParameter.display defaults to 0; skip the speed
        # report in that case instead of raising ZeroDivisionError.
        speed_interval = 10 * self._solver_param.display
        while self._solver.iter < max_iters:
            # Make one SGD update
            timer.tic()
            self._solver.step(1)
            timer.toc()
            cur_iter = self._solver.iter
            if speed_interval > 0 and cur_iter % speed_interval == 0:
                print('speed: {:.3f}s / iter'.format(timer.average_time))
            if cur_iter % 10 == 0:
                print('Python: iter {}'.format(cur_iter))
            if cur_iter % cfg.TRAIN.SNAPSHOT_ITERS == 0:
                last_snapshot_iter = cur_iter
                model_paths.append(self.snapshot())

            if cur_iter % cfg.TRAIN.TEST_ITERS == 0:
                # NOTE(review): upstream pycaffe exposes ``test_nets`` (a
                # list); this assumes the Caffe build in use provides a
                # ``test_net`` attribute -- confirm.
                test_net(self._solver.test_net, self._db, self._output_dir)

        if last_snapshot_iter != self._solver.iter:
            model_paths.append(self.snapshot())
        return model_paths


def train_net(solver_prototxt, db, output_dir,
              snapshot_path=None, max_iters=40000):
    """Train a WMA network.

    Wraps the Caffe solver described by ``solver_prototxt`` in a
    ``SolverWrapper`` (flipping controlled by ``cfg.TRAIN.DO_FLIP``),
    optionally initialized from ``snapshot_path``, and trains until
    ``max_iters``.

    Returns:
        The list of snapshot paths produced by ``SolverWrapper.train_model``.
    """
    wrapper = SolverWrapper(solver_prototxt, db, output_dir,
                            do_flip=cfg.TRAIN.DO_FLIP,
                            snapshot_path=snapshot_path)

    print('Solving...')
    snapshots = wrapper.train_model(max_iters)
    print('done solving')
    return snapshots