from TradzQAI.core import Local_Worker, Local_env
from TradzQAI.core.environnement.base import dataLoader
from TradzQAI.tools import Saver, Logger, red

import time, os
from threading import Thread


class Local_session(Thread):
    """Thread that wires a local environment, an agent and a worker together.

    Typical usage is: construct the session, optionally call ``setAgent``,
    call ``loadSession`` to build the environment/agent/worker, then
    ``start()`` the thread (which invokes ``run``).
    """

    def __init__(self, mode="train", contract_type="classic", config='config/', db=None, agent="PPO"):
        """Prepare the session from the settings found under ``config``.

        Args:
            mode: forwarded to the data loader and the environment.
            contract_type: forwarded to the environment.
            config: settings directory; must end with ``/``.
            db: stored on the instance as-is.
            agent: attribute name looked up on the ``TradzQAI`` package to
                obtain default agent specs when no settings files exist yet.

        Raises:
            ValueError: if ``config`` does not end with ``/``.
        """
        # The base constructor must run before anything else touches the
        # thread object.
        Thread.__init__(self)

        # endswith() also rejects an empty string, which previously blew up
        # with IndexError while indexing the last character.
        if not config.endswith("/"):
            raise ValueError("You forget \"/\" at the end, it should be {}/".format(config))

        self.db = db
        self.env = None
        self.mode = mode
        self.contract_type = contract_type
        self.config = config
        self.agent = None
        self.worker = None
        # Set by setAgent(); defined here so initAgent() can always read it.
        self.device = None

        self.saver = Saver()
        self.logger = None
        self.dl = None

        self.settings = dict()

        if self.saver.check_settings_files(config):
            self.settings['env'], self.settings['agent'], self.settings['network'] = self.saver.load_settings(config)
            self.logger = Logger()
            self.dl = dataLoader(directory=self.settings['env']['base']['data_directory'], mode=self.mode)
            #self.settings['env']['base'].pop('data_directory')
            self.saver._check(self._agent_type(), self.settings)
        else:
            # No settings on disk: build an environment with empty settings
            # only to obtain the defaults, then write them to ``config``.
            self.initEnv()
            default_env, default_network = self.env.get_default_settings()
            self.saver.save_settings(default_env,
                getattr(__import__('TradzQAI'), agent).get_specs(),
                default_network, config)

    def _agent_type(self):
        """Return the upper-cased prefix (before the first ``_``) of the configured agent type."""
        return self.settings['agent']['type'].split('_')[0].upper()

    def _warn_missing_data(self):
        """Print the warning shown when the environment reports it cannot start."""
        print (red("Warning : ")+"You cannot start the session without setting, "+\
            "any data directory in {}environnement".format(self.config))

    def stop(self):
        """Close the environment and stop the logger, skipping whichever is unset."""
        # Either may still be None (env before initEnv(), logger when the
        # settings files were missing at construction time).
        if self.env is not None:
            self.env.close()
        if self.logger is not None:
            self.logger.stop()

    def getWorker(self):
        """Return the worker, or None if the session has not been loaded."""
        return self.worker

    def getEnv(self):
        """Return the environment, or None if it has not been initialised."""
        return self.env

    def getAgent(self):
        """Return the agent: None, the agent class (after setAgent) or the built agent (after initAgent)."""
        return self.agent

    def setAgent(self, agent=None, device=None):
        """Select the agent class and remember the target device.

        Args:
            agent: optional model name stored on the environment
                (``env.model_name``); the environment must already exist
                when this is given.
            device: stored and later passed to the agent constructor.

        Raises:
            ValueError: if the configured agent type is not listed by
                ``src_agents``.
        """
        if agent:
            self.env.model_name = agent
        agent_type = self._agent_type()
        if agent_type not in self.src_agents():
            raise ValueError('could not import %s' % agent_type)
        import warnings
        with warnings.catch_warnings():
            # Silence FutureWarning noise emitted while importing the agent.
            warnings.filterwarnings("ignore",category=FutureWarning)
            from TradzQAI.agents.agent import Agent
            self.agent = Agent
            self.device = device

    def loadSession(self):
        """Build the environment (if needed), then the agent and the worker.

        When the environment's ``stop`` flag is set, nothing is built and a
        warning is printed instead.
        """
        if not self.env:
            self.initEnv()
        if self.env.stop:
            self._warn_missing_data()
            return
        self.initAgent()
        self.initWorker()

    def src_agents(self):
        """List entry names found in ``TradzQAI/agents`` (relative to the CWD).

        A trailing ``.py`` is removed from each name; ``Agent.py``,
        ``__init__.py`` and ``__pycache__`` are skipped.
        """
        ignore = ('Agent.py', '__init__.py', '__pycache__')
        valid = []
        for entry in os.listdir("TradzQAI/agents"):
            if entry in ignore:
                continue
            # Strip only a real ".py" suffix, not any ".py" substring.
            valid.append(entry[:-3] if entry.endswith(".py") else entry)
        return valid

    def initAgent(self):
        """Instantiate the selected agent class and keep what its ``_get()`` returns."""
        if not self.agent:
            self.setAgent()
        # The former MRO scan tested `("tensorforce" and name) in str(cls)`,
        # which reduces to `name in str(cls)`, and both of its outcomes ran
        # this very statement, so a single call is equivalent.
        self.agent = self.agent(env=self.env, device=self.device)._get()

    def initWorker(self):
        """Create the worker bound to the current environment and agent."""
        self.worker = Local_Worker(env=self.env, agent=self.agent)

    def initEnv(self):
        """Create the environment from the current settings, logger, saver and data loader."""
        self.env = Local_env(mode=self.mode,
            contract_type=self.contract_type, config=self.settings,
            logger=self.logger, saver=self.saver, dataloader=self.dl)

    def run(self):
        """Thread entry point: start the logger and run the worker in its own thread.

        Raises:
            ValueError: if no agent is set or the session was not loaded.
        """
        if not self.agent or self.env is None:
            raise ValueError("add an agent and load the session before running")
        if self.env.stop:
            self._warn_missing_data()
            self.stop()
            return
        if self.worker is None:
            # Agent selected but loadSession() never built the worker.
            raise ValueError("add an agent and load the session before running")
        self.logger.start()
        Thread(target=self.worker.run).start()