"""
@author: Viet Nguyen <nhviet1009@gmail.com>
"""

import cv2
import numpy as np
import subprocess as sp
from MAMEToolkit.sf_environment import Environment


class Monitor:
    """Stream raw RGB frames to an ``ffmpeg`` subprocess that encodes a video.

    The subprocess reads ``rgb24`` raw video of size ``width`` x ``height`` at
    60 frames per second from its stdin and writes an ``mpeg4`` stream without
    audio to ``saved_path`` (overwriting an existing file because of ``-y``).

    If the ``ffmpeg`` executable cannot be found, the monitor is still
    constructed but ``pipe`` is ``None`` and ``record`` does nothing.
    """

    def __init__(self, width, height, saved_path):
        """Start the ffmpeg encoder.

        Args:
            width: Frame width in pixels, as passed to ffmpeg's ``-s`` option.
            height: Frame height in pixels, as passed to ffmpeg's ``-s`` option.
            saved_path: Output video path handed to ffmpeg.
        """
        self.command = ["ffmpeg", "-y", "-f", "rawvideo", "-vcodec", "rawvideo", "-s", "{}X{}".format(width, height),
                        "-pix_fmt", "rgb24", "-r", "60", "-i", "-", "-an", "-vcodec", "mpeg4", saved_path]
        # Always define the attribute so a missing ffmpeg binary does not turn
        # into an AttributeError on the first call to ``record``.
        self.pipe = None
        try:
            # ffmpeg logs continuously to stderr; nobody reads it here, so it is
            # discarded instead of piped (an undrained pipe eventually fills up
            # and blocks the child process).
            self.pipe = sp.Popen(self.command, stdin=sp.PIPE, stderr=sp.DEVNULL)
        except FileNotFoundError:
            # Recording is best-effort: keep running without video output, but
            # tell the user why no file will be produced.
            import warnings
            warnings.warn("ffmpeg executable not found; frames will not be recorded")

    def record(self, image_array):
        """Send one frame to the encoder.

        Args:
            image_array: NumPy array holding the frame; its raw bytes are
                written as-is, so it should already match the ``rgb24``
                layout and the size given at construction.
        """
        if self.pipe is None:
            return
        # ``tobytes`` replaces the deprecated ``tostring`` alias, which newer
        # NumPy releases no longer provide.
        self.pipe.stdin.write(image_array.tobytes())

    def close(self):
        """Close the encoder's stdin and wait for ffmpeg to finish the file."""
        if self.pipe is None:
            return
        if self.pipe.stdin is not None and not self.pipe.stdin.closed:
            self.pipe.stdin.close()
        self.pipe.wait()
        self.pipe = None


def process_frame(frame):
    """Convert one emulator frame into a ``(1, 168, 168)`` network input plane.

    Args:
        frame: Image array interpreted as RGB by the colour conversion, or
            ``None`` when no frame is available.

    Returns:
        The frame converted to grayscale, resized to 168x168, given a leading
        axis of length 1 and divided by 255 (values land in [0, 1] assuming
        8-bit input -- TODO confirm against the emulator output). For ``None``
        an all-zero array of the same shape is returned.
    """
    # Guard clause: a missing frame is represented by a blank plane.
    if frame is None:
        return np.zeros((1, 168, 168))

    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    resized = cv2.resize(gray, (168, 168))
    return resized[None, :, :] / 255.


class StreetFighterEnv(object):
    def __init__(self, index, monitor = None):
        roms_path = "roms/"
        self.env = Environment("env{}".format(index), roms_path)
        if monitor:
            self.monitor = monitor
        else:
            self.monitor = None
        self.env.start()

    def step(self, action):
        move_action = action//10
        attack_action = action%10
        frames, reward, round_done, stage_done, game_done = self.env.step(move_action, attack_action)
        if self.monitor:
            for frame in frames:
                self.monitor.record(frame)
        if not (round_done or stage_done or game_done):
            frames = np.concatenate([process_frame(frame) for frame in frames], 0)[None, :, :, :].astype(np.float32)
        else:
            frames = np.zeros((1, 3, 168, 168), dtype=np.float32)
        reward = reward["P1"]
        if stage_done:
            reward = 25
        elif game_done:
            reward = -50
        reward *= (1+(self.env.stage-1)/10)
        reward /= 10
        return frames, reward, round_done, stage_done, game_done

    def reset(self, round_done, stage_done, game_done):
        if game_done:
            self.env.new_game()
        elif stage_done:
            self.env.next_stage()
        elif round_done:
            self.env.next_round()
        return np.zeros((1, 3, 168, 168), dtype=np.float32)

def create_train_env(index, output_path=None):
    """Build a training environment and report its input/action sizes.

    Args:
        index: Passed to ``StreetFighterEnv`` to identify the environment.
        output_path: If truthy, a 384x224 ``Monitor`` writing to this path is
            attached to the environment; otherwise no monitor is used.

    Returns:
        Tuple ``(env, num_inputs, num_actions)`` where ``num_inputs`` is 3 and
        ``num_actions`` is 90.
    """
    recorder = Monitor(384, 224, output_path) if output_path else None
    environment = StreetFighterEnv(index, recorder)
    return environment, 3, 90