import json
import traceback

import cv2

__author__ = 'ananya'

import caffe
import sys

sys.path.append("..")
from scipy.misc import imresize
import numpy as np
import os
from datetime import datetime


class FeatureExtractor(object):
    """Run images through a Caffe network and read back a layer's activations.

    Images are loaded with OpenCV, resized to ``(height, width)`` with
    bilinear interpolation, transposed from HWC to CHW and written into the
    network's input blob. No mean subtraction or scaling is applied here.
    """

    def __init__(self, path_to_deploy_file, path_to_model_file, input_layer_name="data_q", gpu_mode=True, device_id=1,
                 height=None, width=None):
        """Load the network and select the compute device.

        :param path_to_deploy_file: Caffe deploy prototxt path.
        :param path_to_model_file: trained weights path.
        :param input_layer_name: name of the blob images are written into.
        :param gpu_mode: if true, run on GPU ``device_id``; otherwise on CPU.
        :param device_id: GPU index, only used when ``gpu_mode`` is true.
        :param height: resize target height; when falsy, axis 2 of the input
            blob's shape is used.
        :param width: resize target width; when falsy, axis 3 of the input
            blob's shape is used.
        """
        self.path_to_deploy_file = path_to_deploy_file
        self.path_to_model_file = path_to_model_file
        if gpu_mode:
            caffe.set_mode_gpu()
            caffe.set_device(device_id)
        else:
            caffe.set_mode_cpu()
        self.net = caffe.Net(path_to_deploy_file, path_to_model_file, caffe.TEST)
        self.input_layer_name = input_layer_name
        self.height = height or self.net.blobs[self.input_layer_name].data.shape[2]
        self.width = width or self.net.blobs[self.input_layer_name].data.shape[3]

    def extract_one(self, img_path, layer):
        """Return the flattened activations of ``layer`` for a single image.

        The input blob is reshaped to a batch of one on every call, so this
        keeps working after ``extract_batch`` has resized the blob to a
        different batch size.

        :param img_path: path to an image file.
        :param layer: name of the blob whose activations are returned.
        :returns: 1-D numpy array copied from the first row of ``layer``.
        :raises IOError: if the image cannot be read.
        """
        self._forward([self._load_and_preprocess(img_path)])
        return self.net.blobs[layer].data[0].flatten()

    def extract_batch(self, img_paths, layer):
        """Run all readable images in one forward pass and collect features.

        Images that fail to load or resize are reported (message on stdout,
        traceback via ``traceback.print_exc``) and left out of both the batch
        and the result, so every returned feature belongs to the key it is
        stored under.

        :param img_paths: iterable of image file paths.
        :param layer: name of the blob whose activations are returned.
        :returns: dict mapping each successfully processed file's basename
            without extension to its flattened activations. Paths sharing a
            basename overwrite one another (last one wins). An empty dict is
            returned when no image could be processed.
        """
        fv_dict = {}
        start_time = datetime.now()

        loaded_paths = []
        chw_images = []
        for path in img_paths:
            try:
                chw = self._load_and_preprocess(path)
            except Exception:
                # Deliberate best-effort: one bad file must not sink the batch.
                print("Exception for image %s" % (path,))
                traceback.print_exc()
                continue
            loaded_paths.append(path)
            chw_images.append(chw)

        if chw_images:
            self._forward(chw_images)
            features = self.net.blobs[layer].data
            # Rows of `features` line up with `loaded_paths`, not `img_paths`,
            # because failed images were never put in the batch.
            for row, path in enumerate(loaded_paths):
                fv_key = os.path.splitext(os.path.basename(path))[0]
                fv_dict[fv_key] = features[row].flatten()

        delta = datetime.now() - start_time
        print("Batch took " + str(delta.total_seconds() * 1000))
        return fv_dict

    def _load_and_preprocess(self, path):
        """Load ``path`` and return it resized to (height, width) in CHW order.

        :raises IOError: if OpenCV cannot read the file.
        """
        img = self.getImageFromPath(path)
        # cv2.imread reports failure by returning None rather than raising.
        if img is None:
            raise IOError("Could not read image: %s" % (path,))
        resized = imresize(img, (self.height, self.width), 'bilinear')
        # HWC -> CHW: channels first, the order written into the input blob.
        return np.transpose(resized, (2, 0, 1))

    def _forward(self, chw_images):
        """Write a non-empty list of same-shaped CHW arrays into the input blob and run the net."""
        input_blob = self.net.blobs[self.input_layer_name]
        # Size the blob to exactly this batch so no stale rows remain and
        # numpy never broadcasts one image across several rows.
        input_blob.reshape(len(chw_images), *chw_images[0].shape)
        input_blob.data[...] = np.asarray(chw_images)
        self.net.forward()

    def getImageFromPath(self, path):
        """Read ``path`` as a 3-channel colour image with OpenCV.

        Returns whatever ``cv2.imread`` returns, which is None when the file
        cannot be read.
        """
        return cv2.imread(path, cv2.IMREAD_COLOR)