from TradzQAI.core import Live_Worker, Live_env
from TradzQAI.API import Api
from TradzQAI.core.environnement.base import dataLoader
from TradzQAI.tools import Saver, Logger, red

from threading import Thread
import datetime, os

class Live_session(Thread):
    """Thread that assembles and runs a live trading session.

    On construction the session loads its settings from the ``config``
    directory through ``Saver``. If no settings files exist yet, it writes
    default ones instead. ``loadSession`` then builds a ``Live_env``, an agent,
    an ``Api`` connection and a ``Live_Worker``. ``run`` (invoked via
    ``start``) starts the environment's logger and launches ``worker.run`` on
    its own thread.
    """

    def __init__(self, mode="train", contract_type="classic",
            config='config/', db=None, agent="PPO"):
        """Create the session and load (or create) its settings files.

        Args:
            mode: forwarded to ``Live_env`` and ``dataLoader``.
            contract_type: forwarded to ``Live_env``.
            config: settings directory path; must end with "/".
            db: stored on ``self.db``; not used elsewhere in this class.
            agent: name of a ``TradzQAI`` attribute whose ``get_specs()``
                supplies the default agent settings when no settings files
                exist yet.

        Raises:
            ValueError: if ``config`` does not end with "/" (including "").
        """
        # threading docs: a subclass must call Thread.__init__ before doing
        # anything else to the thread.
        super().__init__()
        if not config.endswith("/"):
            raise ValueError("You forget \"/\" at the end, it should be {}/".format(config))
        self.db = db
        self.env = None
        self.mode = mode
        self.contract_type = contract_type
        self.config = config
        self.agent = None
        self.device = None
        self.worker = None
        # Must exist before initEnv()/loadSession() read it.
        self.api = None
        self.api_name = None

        self.saver = Saver()
        self.logger = None
        self.dl = None

        self.settings = dict()

        if self.saver.check_settings_files(config):
            self.settings['env'], self.settings['agent'], self.settings['network'] = \
                self.saver.load_settings(config)
            self.logger = Logger()
        else:
            # No settings yet: build an env only to obtain its defaults, then
            # persist them. NOTE(review): self.settings stays empty here, so
            # this env is built from an empty config dict — confirm Live_env
            # tolerates that.
            self.initEnv()
            default_env, default_network = self.env.get_default_settings()
            self.saver.save_settings(default_env,
                getattr(__import__('TradzQAI'), agent).get_specs(),
                default_network, config)

    def stop(self):
        """Shut down every component that has been created so far.

        Order: worker.close(), logger.stop(), api.close(), env.close().
        Components that were never created (still None) are skipped.
        """
        if self.worker is not None:
            self.worker.close()
        if self.logger is not None:
            self.logger.stop()
        if self.api is not None:
            self.api.close()
        if self.env is not None:
            self.env.close()

    def getWorker(self):
        """Return the ``Live_Worker`` (None before ``loadSession``)."""
        return self.worker

    def getEnv(self):
        """Return the current ``Live_env`` (may be None)."""
        return self.env

    def getAgent(self):
        """Return the agent class, or the built agent after ``initAgent``."""
        return self.agent

    def getApi(self):
        """Return the ``Api`` connection (None before ``initApi``)."""
        return self.api

    def setAgent(self, agent=None, device=None):
        """Select the agent class to build and the device to build it on.

        Args:
            agent: optional model name stored on ``self.env.model_name``
                (``self.env`` must already exist when this is given).
            device: stored on ``self.device``; ``initAgent`` passes it to
                the agent constructor.

        Raises:
            ValueError: if the part of ``settings['agent']['type']`` before
                the first "_", upper-cased, is not listed by ``src_agents()``.
        """
        if agent:
            self.env.model_name = agent
        agent_name = self.settings['agent']['type'].split('_')[0].upper()
        if agent_name not in self.src_agents():
            raise ValueError('could not import %s' % agent_name)
        import warnings
        with warnings.catch_warnings():
            # Ignore FutureWarning while the agent module is imported.
            warnings.filterwarnings("ignore", category=FutureWarning)
            from TradzQAI.agents.agent import Agent
        self.agent = Agent
        self.device = device

    def setApi(self, api_name="cbpro"):
        """Record which API backend ``initApi`` should connect to."""
        self.api_name = api_name

    def _warn_missing_data(self, action):
        """Print the warning shown when the env reports ``stop``."""
        print(red("Warning : ") + "You cannot %s the session without setting "
              "any data directory in %senvironnement" % (action, self.config))

    def loadSession(self):
        """Build the env, agent, API connection and worker.

        If the environment reports ``stop``, only a warning is printed and
        nothing else is built.
        """
        if not self.env:
            self.initEnv()
        if self.env.stop:
            self._warn_missing_data("load")
            return
        self.initAgent()
        if not self.api:
            # NOTE(review): initApi() rebuilds self.env, so the agent built
            # above received the previous env instance while the worker gets
            # the new one — confirm this is intended.
            self.initApi()
        self.initWorker()

    def src_agents(self):
        """Return module names found in ``TradzQAI/agents`` (relative to CWD).

        'Agent.py', '__init__.py' and '__pycache__' are skipped, and ".py"
        is stripped from the remaining names. NOTE(review): ``setAgent``
        compares an upper-cased name against this list, so only files whose
        names are upper case can match.
        """
        ignore = {'Agent.py', '__init__.py', '__pycache__'}
        return [entry.replace(".py", "")
                for entry in os.listdir("TradzQAI/agents")
                if entry not in ignore]

    def initAgent(self):
        """Instantiate the selected agent class with the current env/device.

        Calls ``setAgent()`` first if no agent class is selected. The stored
        agent is ``instance._get()`` when the class name occurs in the string
        form of any class in its MRO. That always holds for an ordinary class,
        since its own repr contains its name. Otherwise the plain instance is
        stored.
        """
        if not self.agent:
            self.setAgent()
        agent_cls = self.agent
        # Only the class name is checked; "tensorforce" is not required.
        # NOTE(review): a "tensorforce" test may have been intended — confirm
        # before tightening, since that would change which agents get _get().
        unwrap = any(agent_cls.__name__ in str(cls) for cls in agent_cls.__mro__)
        instance = agent_cls(env=self.env, device=self.device)
        self.agent = instance._get() if unwrap else instance

    def initWorker(self):
        """Create the ``Live_Worker`` from the current env and agent."""
        self.worker = Live_Worker(env=self.env, agent=self.agent)

    def initApi(self, key=None, b64=None, passphrase=None, url=None,
            product_id=None, mode="maker", auto_cancel=True):
        """Connect to the API and rebuild the env on an API-backed loader.

        Args:
            key, b64, passphrase, url, mode, auto_cancel: forwarded to ``Api``.
            product_id: forwarded to ``Api``; defaults to ``['BTC-EUR']``
                (a fresh list per call).
        """
        if product_id is None:
            product_id = ['BTC-EUR']
        if not self.api_name:
            self.setApi()
        self.api = Api(api_name=self.api_name, key=key, b64=b64,
            passphrase=passphrase, url=url, product_id=product_id, mode=mode,
            auto_cancel=auto_cancel)
        self.dl = dataLoader(mode=self.mode, api=self.api)
        # The loader is API-backed, so the on-disk data_directory setting is
        # removed; pop() keeps a repeated initApi() from raising KeyError.
        self.settings['env']['base'].pop('data_directory', None)
        self.saver._check(self.settings['agent']['type'].split('_')[0].upper(), self.settings)
        self.initEnv()

    def initEnv(self):
        """(Re)create ``self.env`` from the current settings, logger,
        saver, data loader and API connection."""
        self.env = Live_env(mode=self.mode,
            contract_type=self.contract_type, config=self.settings,
            logger=self.logger, saver=self.saver, dataloader=self.dl, api=self.api)

    def run(self):
        """Thread entry point: start the env logger and the worker thread.

        Prints a warning instead when the environment reports ``stop``.

        Raises:
            ValueError: if no agent has been set.
        """
        if not self.agent:
            raise ValueError("add an agent and load the session before running")
        if self.env.stop:
            self._warn_missing_data("start")
            return
        self.env.logger.start()
        Thread(target=self.worker.run).start()