"""trt_googlenet.py

This is the 'async' version of trt_googlenet.py implementation.

Refer to trt_ssd_async.py for description about the design and
synchronization between the main and child threads.
"""


import sys
import time
import argparse
import threading

import numpy as np
import cv2
from utils.camera import add_camera_args, Camera
from utils.display import open_window, set_display, show_fps
from pytrt import PyTrtGooglenet


PIXEL_MEANS = np.array([[[104., 117., 123.]]], dtype=np.float32)  # BGR means
DEPLOY_ENGINE = 'googlenet/deploy.engine'
ENGINE_SHAPE0 = (3, 224, 224)   # CHW shape of the engine input binding
ENGINE_SHAPE1 = (1000, 1, 1)    # CHW shape of the engine output ('prob')
RESIZED_SHAPE = (224, 224)      # target size for cv2.resize()

WINDOW_NAME = 'TrtGooglenetDemo'

# 'shared' global variables
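# s_img: the latest BGR frame grabbed by the child thread
# s_probs, s_labels: top-3 scores and class names for that frame
# (written by the child thread and read by the main thread while holding
#  the shared threading.Condition)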
s_img, s_probs, s_labels = None, None, None


def parse_args():
    """Parse input arguments."""
    desc = ('Capture and display live camera video, while doing '
            'real-time image classification with TrtGooglenet '
            'on Jetson Nano')
    parser = argparse.ArgumentParser(description=desc)
    parser = add_camera_args(parser)
    parser.add_argument('--crop', dest='crop_center',
                        help='crop the center square of the image for '
                             'inference [False]',
                        action='store_true')
    args = parser.parse_args()
    return args


def classify(img, net, labels, do_cropping):
    """Classify 1 image (crop)."""
    crop = img
    if do_cropping:
        h, w, _ = img.shape
        if h < w:
            crop = img[:, ((w-h)//2):((w+h)//2), :]
        else:
            crop = img[((h-w)//2):((h+w)//2), :, :]

    # preprocess the image crop
    crop = cv2.resize(crop, RESIZED_SHAPE)
    crop = crop.astype(np.float32) - PIXEL_MEANS
    crop = crop.transpose((2, 0, 1))  # HWC -> CHW

    # run inference on the (cropped) image
    out = net.forward(crop[None])  # add a batch dimension to 'crop'

    # return the top-3 predicted scores and class labels
    out_prob = np.squeeze(out['prob'][0])
    top_inds = out_prob.argsort()[::-1][:3]
    return (out_prob[top_inds], labels[top_inds])
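
# Example usage of classify() (sketch): assuming 'frame' is a BGR image from
# the camera, 'net' an already created PyTrtGooglenet instance and 'labels'
# the synset array loaded as in main():
#
#   probs, names = classify(frame, net, labels, do_cropping=True)
#   print('top-1: %s (%.4f)' % (names[0], probs[0]))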


class TrtGooglenetThread(threading.Thread):
    """Child thread which grabs camera frames and runs GoogLeNet inference."""
    def __init__(self, condition, cam, labels, do_cropping):
        """__init__

        # Arguments
            condition: the condition variable used to notify main
                       thread about new frame and detection result
            cam: the camera object for reading input image frames
            labels: a numpy array of class labels
            do_cropping: whether to do center-cropping of input image
        """
        threading.Thread.__init__(self)
        self.condition = condition
        self.cam = cam
        self.labels = labels
        self.do_cropping = do_cropping
        self.running = False

    def run(self):
        """Run until 'running' flag is set to False by main thread."""
        global s_img, s_probs, s_labels

        print('TrtGooglenetThread: loading the TRT Googlenet engine...')
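        # NOTE: the engine is created here, inside run(), so that it is
        # instantiated in the child thread -- the same thread that calls
        # inference on it.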
        self.net = PyTrtGooglenet(DEPLOY_ENGINE, ENGINE_SHAPE0, ENGINE_SHAPE1)
        print('TrtGooglenetThread: start running...')
        self.running = True
        while self.running:
            img = self.cam.read()
            top_probs, top_labels = classify(
                img, self.net, self.labels, self.do_cropping)
            with self.condition:
                s_img, s_probs, s_labels = img, top_probs, top_labels
                self.condition.notify()
        del self.net
        print('TrtGooglenetThread: stopped...')

    def stop(self):
        self.running = False
        self.join()


def show_top_preds(img, top_probs, top_labels):
    """Show top predicted classes and softmax scores."""
    x = 10
    y = 40
    for prob, label in zip(top_probs, top_labels):
        pred = '{:.4f} {:20s}'.format(prob, label)
        #cv2.putText(img, pred, (x+1, y), cv2.FONT_HERSHEY_PLAIN, 1.0,
        #            (32, 32, 32), 4, cv2.LINE_AA)
        cv2.putText(img, pred, (x, y), cv2.FONT_HERSHEY_PLAIN, 1.0,
                    (0, 0, 240), 1, cv2.LINE_AA)
        y += 20


def loop_and_display(condition):
    """Continuously display the latest frame and classification results
    produced by the child thread."""
    global s_img, s_probs, s_labels

    full_scrn = False
    fps = 0.0
    tic = time.time()
    while True:
        if cv2.getWindowProperty(WINDOW_NAME, 0) < 0:
            break
        with condition:
            condition.wait()
            img, top_probs, top_labels = s_img, s_probs, s_labels
        show_top_preds(img, top_probs, top_labels)
        img = show_fps(img, fps)
        cv2.imshow(WINDOW_NAME, img)
        toc = time.time()
        curr_fps = 1.0 / (toc - tic)
        # calculate an exponentially decaying average of fps number
        fps = curr_fps if fps == 0.0 else (fps*0.95 + curr_fps*0.05)
        tic = toc
        key = cv2.waitKey(1)
        if key == 27:  # ESC key: quit program
            break
        elif key == ord('F') or key == ord('f'):  # Toggle fullscreen
            full_scrn = not full_scrn
            set_display(WINDOW_NAME, full_scrn)


def main():
    args = parse_args()
    labels = np.loadtxt('googlenet/synset_words.txt', str, delimiter='\t')
    cam = Camera(args)
    cam.open()
    if not cam.is_opened:
        sys.exit('Failed to open camera!')

    cam.start()  # let camera start grabbing frames
    open_window(WINDOW_NAME, args.image_width, args.image_height,
                'Camera TensorRT GoogLeNet Demo for Jetson Nano')
    condition = threading.Condition()
    trt_thread = TrtGooglenetThread(condition, cam, labels, args.crop_center)
    trt_thread.start()  # start the child thread
    loop_and_display(condition)
    trt_thread.stop()   # stop the child thread

    cam.stop()
    cam.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()