# --------------------------------------------------------
# TRIPLET LOSS
# Copyright (c) 2015 Pinguo Tech.
# Written by David Lu
# --------------------------------------------------------

"""The data layer used during training to train the network.
   This is a example for online triplet selection
   Each minibatch contains a set of archor-positive pairs, random select negative exemplar
"""

import caffe
import numpy as np
from numpy import *
import yaml
from sampledata import sampledata
import random
import cv2
from blob import prep_im_for_blob, im_list_to_blob
import config


class DataLayer(caffe.Layer):
    """Caffe Python data layer that produces triplet minibatches.

    Each minibatch holds ``config.BATCH_SIZE`` images split into three equal
    thirds of ``self._triplet`` images: one anchor image repeated, distinct
    positive images of the anchor's person, and distinct images of other
    persons (negatives). Tops: ``data`` (top[0]) and ``labels`` (top[1]).
    """

    def _get_next_minibatch(self):
        """Assemble the next triplet minibatch.

        Anchors are taken in order from ``data_container._sample``, wrapping
        back to the first one after the last. The text before the first '@'
        of an image name is used as its person key.

        Returns:
            dict mapping 'data' to the image blob and 'labels' to the list of
            per-image labels.

        Raises:
            ValueError: if the anchor's person has too few other images to
                fill the positive third, or there is no other person to draw
                negatives from.
        """
        anchors = self.data_container._sample
        if self._index >= len(anchors):
            self._index = 0
        anchor = anchors[self._index]
        self._index += 1
        anchor_person = anchor.split('@')[0]

        # Layout: [anchor] * T, then T positives, then T negatives.
        sample = [anchor] * self._triplet
        sample.extend(self._sample_positives(anchor_person, sample))
        sample.extend(self._sample_negatives(anchor_person, sample))

        im_blob, labels_blob = self._get_image_blob(sample)
        return {'data': im_blob, 'labels': labels_blob}

    @staticmethod
    def _distinct_excluding(pics, taken):
        """Return the unique entries of ``pics`` not in ``taken``, in order."""
        seen = set(taken)
        pool = []
        for pic in pics:
            if pic not in seen:
                seen.add(pic)
                pool.append(pic)
        return pool

    def _sample_positives(self, anchor_person, taken):
        """Draw ``self._triplet`` distinct images of ``anchor_person`` not in ``taken``.

        Raises:
            ValueError: if not enough such images exist (the old rejection
                loop would spin forever in that case).
        """
        pool = self._distinct_excluding(
            self.data_container._sample_person[anchor_person], taken)
        if len(pool) < self._triplet:
            raise ValueError(
                'person %r has only %d usable positive images, need %d'
                % (anchor_person, len(pool), self._triplet))
        return random.sample(pool, self._triplet)

    def _sample_negatives(self, anchor_person, taken):
        """Draw ``self._triplet`` distinct images of persons other than the anchor's.

        A person is chosen uniformly among the other persons, then one of
        their images uniformly; images already taken are redrawn.

        Raises:
            ValueError: if there is no person other than ``anchor_person``.
        """
        by_person = self.data_container._sample_person
        # Choosing directly among the other persons replaces the old
        # "use a neighbouring key" fallback, which was biased and could index
        # past the end of the key list.
        others = [name for name in by_person if name != anchor_person]
        if not others:
            raise ValueError('negative sampling needs at least two persons')
        chosen = set(taken)
        negatives = []
        while len(negatives) < self._triplet:
            pic = random.choice(by_person[random.choice(others)])
            if pic not in chosen:
                chosen.add(pic)
                negatives.append(pic)
        return negatives

    def _get_image_blob(self, sample):
        """Load, preprocess and label the images named in ``sample``.

        Args:
            sample: image names, relative to ``config.IMAGEPATH``.

        Returns:
            (blob, labels): the blob built by ``im_list_to_blob`` and the list
            of labels looked up in ``data_container._sample_label`` by the
            person key of each image.

        Raises:
            IOError: if an image cannot be read.
        """
        images = []
        labels = []
        for name in sample:
            path = config.IMAGEPATH + name
            im = cv2.imread(path)
            # cv2.imread signals failure by returning None rather than raising.
            if im is None:
                raise IOError('cannot read image: %s' % path)
            labels.append(self.data_container._sample_label[name.split('@')[0]])
            images.append(prep_im_for_blob(im))
        return im_list_to_blob(images), labels

    def setup(self, bottom, top):
        """Configure the layer and give the top blobs their fixed shapes.

        Raises:
            ValueError: if ``config.BATCH_SIZE`` is not a multiple of 3.
        """
        # param_str_ is parsed only to check that it is valid YAML; no values
        # are read from it. safe_load does not construct arbitrary Python
        # objects and, unlike a bare yaml.load, works on PyYAML >= 6.
        yaml.safe_load(self.param_str_)
        self._batch_size = config.BATCH_SIZE
        if self._batch_size % 3 != 0:
            raise ValueError('BATCH_SIZE must be a multiple of 3, got %r'
                             % (self._batch_size,))
        # Images per third of the batch (anchors / positives / negatives).
        self._triplet = self._batch_size // 3
        self._name_to_top_map = {'data': 0, 'labels': 1}

        self.data_container = sampledata()
        self._index = 0

        # forward() copies into these shapes without reshaping. The image blob
        # therefore presumably has to be (BATCH_SIZE, 3, 224, 224); confirm
        # this against prep_im_for_blob / im_list_to_blob.
        top[0].reshape(self._batch_size, 3, 224, 224)
        top[1].reshape(self._batch_size)

    def forward(self, bottom, top):
        """Fetch the next minibatch and copy it into the top blobs."""
        blobs = self._get_next_minibatch()
        for blob_name, blob in blobs.items():
            # Shapes were fixed in setup(); the data is copied in place.
            top[self._name_to_top_map[blob_name]].data[...] = blob

    def backward(self, top, propagate_down, bottom):
        """This layer does not propagate gradients."""
        pass

    def reshape(self, bottom, top):
        """No-op: the top shapes are fixed once in setup()."""
        pass

class TestBlobFetcher(object):
    """Caffe-independent copy of DataLayer's triplet minibatch sampling.

    It does not subclass caffe.Layer, so it can be driven directly to inspect
    the minibatches the sampling produces.
    """

    def __init__(self, batch_size=30, data_container=None):
        """Create a fetcher.

        Args:
            batch_size: images per minibatch; must be a multiple of 3.
            data_container: object exposing ``_sample``, ``_sample_person`` and
                ``_sample_label``. A new ``sampledata()`` is built when omitted.

        Raises:
            ValueError: if ``batch_size`` is not a multiple of 3.
        """
        if batch_size % 3 != 0:
            raise ValueError('batch_size must be a multiple of 3, got %r'
                             % (batch_size,))
        self._batch_size = batch_size
        # Images per third of the batch (anchors / positives / negatives).
        self._triplet = batch_size // 3
        if data_container is None:
            data_container = sampledata()
        self.data_container = data_container
        self._index = 0

    def _get_next_minibatch(self):
        """Assemble the next triplet minibatch.

        Anchors are taken in order from ``data_container._sample``, wrapping
        back to the first one after the last. The text before the first '@'
        of an image name is used as its person key.

        Returns:
            dict mapping 'data' to the image blob and 'labels' to the list of
            per-image labels.

        Raises:
            ValueError: if the anchor's person has too few other images to
                fill the positive third, or there is no other person to draw
                negatives from.
        """
        anchors = self.data_container._sample
        if self._index >= len(anchors):
            self._index = 0
        anchor = anchors[self._index]
        self._index += 1
        anchor_person = anchor.split('@')[0]

        # Layout: [anchor] * T, then T positives, then T negatives.
        sample = [anchor] * self._triplet
        sample.extend(self._sample_positives(anchor_person, sample))
        sample.extend(self._sample_negatives(anchor_person, sample))

        im_blob, labels_blob = self._get_image_blob(sample)
        return {'data': im_blob, 'labels': labels_blob}

    @staticmethod
    def _distinct_excluding(pics, taken):
        """Return the unique entries of ``pics`` not in ``taken``, in order."""
        seen = set(taken)
        pool = []
        for pic in pics:
            if pic not in seen:
                seen.add(pic)
                pool.append(pic)
        return pool

    def _sample_positives(self, anchor_person, taken):
        """Draw ``self._triplet`` distinct images of ``anchor_person`` not in ``taken``.

        Raises:
            ValueError: if not enough such images exist (the old rejection
                loop would spin forever in that case).
        """
        pool = self._distinct_excluding(
            self.data_container._sample_person[anchor_person], taken)
        if len(pool) < self._triplet:
            raise ValueError(
                'person %r has only %d usable positive images, need %d'
                % (anchor_person, len(pool), self._triplet))
        return random.sample(pool, self._triplet)

    def _sample_negatives(self, anchor_person, taken):
        """Draw ``self._triplet`` distinct images of persons other than the anchor's.

        A person is chosen uniformly among the other persons, then one of
        their images uniformly; images already taken are redrawn.

        Raises:
            ValueError: if there is no person other than ``anchor_person``.
        """
        by_person = self.data_container._sample_person
        # Choosing directly among the other persons replaces the old
        # "use a neighbouring key" fallback, which was biased and could index
        # past the end of the key list.
        others = [name for name in by_person if name != anchor_person]
        if not others:
            raise ValueError('negative sampling needs at least two persons')
        chosen = set(taken)
        negatives = []
        while len(negatives) < self._triplet:
            pic = random.choice(by_person[random.choice(others)])
            if pic not in chosen:
                chosen.add(pic)
                negatives.append(pic)
        return negatives

    def _get_image_blob(self, sample):
        """Load, preprocess and label the images named in ``sample``.

        Args:
            sample: image names, relative to ``config.IMAGEPATH``.

        Returns:
            (blob, labels): the blob built by ``im_list_to_blob`` and the list
            of labels looked up in ``data_container._sample_label`` by the
            person key of each image.

        Raises:
            IOError: if an image cannot be read.
        """
        images = []
        labels = []
        for name in sample:
            path = config.IMAGEPATH + name
            im = cv2.imread(path)
            # cv2.imread signals failure by returning None rather than raising.
            if im is None:
                raise IOError('cannot read image: %s' % path)
            labels.append(self.data_container._sample_label[name.split('@')[0]])
            images.append(prep_im_for_blob(im))
        return im_list_to_blob(images), labels

if __name__ == '__main__':
    # Smoke run: draw minibatches and print their index, blob shape and labels.
    test = TestBlobFetcher()
    for i in range(500):
        blob = test._get_next_minibatch()
        # A single pre-formatted string prints identically under the Python 2
        # print statement and the Python 3 print function.
        print('%d %s %s' % (i, np.shape(blob['data']), blob['labels']))