import argparse
import json
import logging

import faiss
from feature import DirectoryFeature
from file_util import make_model

INDEX_PATH = './model/train.index'
ID_PATH = 'data.json'


class Similarity(object):
    def __init__(self):
        self.feature = None

    def set_feature(self, feature):
        self.feature = feature

    def calculate(self, images):
        """Calculate similarity
        Subclasses should override for any actions to run.
        :arg
            images: numpy array, shape(N, image_size, image_size, 3)
        :return:
            names: list[string], predicted person names
        """


class PeopleSimilarity(Similarity):
    def __init__(self, feature, index_path, id_path):
        super(PeopleSimilarity, self).__init__()
        self.feature = feature
        self.index_path = index_path
        self.id_path = id_path

    def calculate(self, images):
        predicted = []
        index = faiss.read_index(self.index_path)
        with open(self.id_path) as f:
            id_json = json.load(f)
        logging.info('database load')
        imgs = self.feature.get_feature(images)
        D, I = index.search(imgs, k=1)
        for p in I:
            predicted.append(id_json[str(p[0])])
        return predicted


def predict(feature, imgs):
    count_frame = []
    array_image = []
    # frame별로 수정해야댐...
    # for frame in imgs:
    #     count_frame.append(len(frame))
    #     for face in frame:
    #         array_image.append(face)
    # image_list = None
    # for img in array_image:
    #     x = image.img_to_array(img)
    #     x = np.expand_dims(x, axis=0)
    #     x = preprocess_input(x)
    #     image_list = np.concatenate((image_list, x)) if image_list is not None else x

    similarity = PeopleSimilarity(feature, INDEX_PATH, ID_PATH)
    preds = similarity.calculate(imgs)
    pred_list = []
    cnt = 0
    for count in count_frame:
        pred_frame = []
        for i in range(count):
            pred_frame.append(preds[cnt])
            cnt += 1
        pred_list.append(pred_frame)

    return pred_list


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", help="model path", default="model/train.index")
    parser.add_argument("--width", help="image width size", default=224)
    parser.add_argument("--height", help="image height size", default=224)
    parser.add_argument("-i", "--image", help="image", default="./image/jo/val/0jo2.jpg")
    args = parser.parse_args()
    IMAGE_SIZE = (args.width, args.height)
    model = make_model('vgg16', IMAGE_SIZE)
    print(predict(DirectoryFeature(model), 'image/jo'))