"""
Test Apocalipse environment
"""
import os
import shutil
import pytest
import mock
from hypothesis import given, example, settings, strategies as st
from hypothesis.extra.numpy import arrays, array_shapes
from cryptotrader.envs.trading import TradingEnvironment, PaperTradingEnvironment, BacktestDataFeed
from cryptotrader.utils import convert_to, array_normalize, array_softmax, floor_datetime
from cryptotrader.spaces import Box, Tuple
import numpy as np
import pandas as pd
from decimal import Decimal
from cryptotrader.exchange_api.poloniex import Poloniex
from datetime import datetime, timezone

from .mocks import *

# Fixtures
@pytest.fixture
def fresh_env():
    """Provide a brand-new TradingEnvironment and remove ./logs once the test is done."""
    env = TradingEnvironment(period=5, obs_steps=30, tapi=tapi, fiat="USDT", name='env_test')
    yield env
    # Teardown: delete the logs directory under the current working directory
    # (presumably created by the environment) so tests do not leak state.
    logs_dir = os.path.join(os.path.abspath(os.path.curdir), 'logs')
    shutil.rmtree(logs_dir)

@pytest.fixture
def ready_env():
    """Yield a PaperTradingEnvironment whose clock is frozen at a random mock timestamp.

    ``cryptotrader.envs.trading.datetime`` stays patched for the whole life of the
    fixture. ``datetime.now()`` inside the environment returns one of the mock
    ``indexes`` (picked from position 10 onward), converted to UTC. The env starts
    with its balance taken from ``get_balance()`` and zero BTC/ETH holdings.
    ``./logs`` is removed on teardown.
    """
    with mock.patch('cryptotrader.envs.trading.datetime') as fake_datetime:
        frozen_ts = np.choose(np.random.randint(low=10, high=len(indexes)), indexes)
        fake_datetime.now.return_value = datetime.fromtimestamp(frozen_ts).astimezone(timezone.utc)
        # Forward the real constructors through the mock so only now() is frozen.
        fake_datetime.fromtimestamp = lambda *args, **kw: datetime.fromtimestamp(*args, **kw)
        fake_datetime.side_effect = lambda *args, **kw: datetime(*args, **kw)

        env = PaperTradingEnvironment(period=5, obs_steps=10, tapi=tapi, fiat="USDT", name='env_test')
        env.balance = env.get_balance()
        env.crypto = {"BTC": Decimal('0.00000000'), 'ETH': Decimal('0.00000000')}
        yield env
        shutil.rmtree(os.path.join(os.path.abspath(os.path.curdir), 'logs'))

@pytest.fixture
def data_feed():
    """Provide a BacktestDataFeed over USDT_BTC and USDT_ETH with a fixed starting balance."""
    starting_balance = {
        "BTC": '1.00000000',
        "ETH": '0.50000000',
        "USDT": '100.00000000',
    }
    feed = BacktestDataFeed(tapi, period=5, pairs=["USDT_BTC", "USDT_ETH"], balance=starting_balance)
    yield feed

# DATA FEED TESTS
def test_returnBalances(data_feed):
    """returnBalances hands back a dict."""
    # TODO: REWRITE THIS TEST - only the return type is checked. Assertions on the
    # stored balance values and on rejecting non-dict balances are still missing.
    assert isinstance(data_feed.returnBalances(), dict)

def test_returnFeeInfo(data_feed):
    """returnFeeInfo returns a dict whose 'makerFee' entry is the string '0.00150000'."""
    fee_info = data_feed.returnFeeInfo()
    assert isinstance(fee_info, dict)
    assert fee_info['makerFee'] == '0.00150000'

# BACKTEST AND PAPERTRADING ENVIRONMENT TESTS
def test_env_name(fresh_env):
    """The ``name`` given to the constructor is exposed unchanged."""
    expected_name = 'env_test'
    assert fresh_env.name == expected_name

class Test_env_setup(object):
    """Setter and validation tests that share one TradingEnvironment for the whole class."""

    @classmethod
    def setup_class(cls):
        cls.env = TradingEnvironment(period=5, obs_steps=10, tapi=tapi, fiat="USDT", name='env_test')

    @classmethod
    def teardown_class(cls):
        logs_dir = os.path.join(os.path.abspath(os.path.curdir), 'logs')
        shutil.rmtree(logs_dir)

    def test_reset_status(self):
        """reset_status sets every known status flag to False."""
        expected = dict.fromkeys(('OOD', 'Error', 'ValueError', 'ActionError', "NotEnoughFiat"), False)
        self.env.reset_status()
        assert self.env.status == expected

    def test_add_pairs(self):
        """A pair passed to add_pairs shows up in ``pairs``."""
        self.env.add_pairs("USDT_BTC")
        assert "USDT_BTC" in self.env.pairs

    @given(value=st.integers(max_value=1000))
    def test_obs_steps(self, value):
        """obs_steps accepts values >= 3 and rejects smaller ones with AssertionError."""
        if value < 3:
            with pytest.raises(AssertionError):
                self.env.obs_steps = value
            return
        self.env.obs_steps = value
        assert self.env.obs_steps == value

    @given(value=st.integers(max_value=1000))
    def test_period(self, value):
        """period accepts values >= 1 and rejects smaller ones with AssertionError."""
        if value < 1:
            with pytest.raises(AssertionError):
                self.env.period = value
            return
        self.env.period = value
        assert self.env.period == value

def test_get_ohlc(ready_env):
    """get_ohlc returns an obs_steps-long OHLCV frame at the env's period for each pair.

    ``data['date']`` is a Unix timestamp in seconds; elsewhere in this module it is
    passed to ``datetime.fromtimestamp``. It must therefore be converted with
    ``unit='s'``. Passing the raw integer as ``pd.date_range(end=...)`` makes pandas
    read it as nanoseconds, which anchors every window in January 1970.
    """
    env = ready_env
    freq = "%dT" % env.period
    for data in tapi.returnChartData()[:-env.obs_steps]:
        # The window is the same for every pair, so build it once per chart entry.
        window_end = pd.Timestamp(data['date'], unit='s')
        window = pd.date_range(end=window_end, freq=freq, periods=env.obs_steps)
        for pair in env.pairs:
            df = env.get_ohlc(pair, index=window)
            assert isinstance(df, pd.DataFrame)
            assert df.shape[0] == env.obs_steps
            assert list(df.columns) == ['open', 'high', 'low', 'close', 'volume']
            assert df.index.freqstr == freq

def _check_history_frame(env, df):
    """Assert that ``df`` has the layout ``env.get_history`` is expected to produce."""
    assert isinstance(df, pd.DataFrame)
    assert df.shape[0] == env.obs_steps
    assert set(df.columns.levels[0]) == set(env.pairs)
    assert list(df.columns.levels[1]) == ['open', 'high', 'low', 'close', 'volume']
    assert df.index.freqstr == '%dT' % env.period
    # Reduces the cell values with numpy's all() and checks the result is a Decimal;
    # this is only a spot check of the cell type, not a check of every cell.
    assert type(df.values.all()) == Decimal


def test_get_history(ready_env):
    """get_history yields well-formed frames for the default end and for recent chart dates."""
    env = ready_env
    _check_history_frame(env, env.get_history())

    # NOTE(review): datetime.fromtimestamp yields a naive *local* time while the
    # environment's clock is UTC-aware - confirm get_history interprets it as intended.
    for data in tapi.returnChartData()[-env.obs_steps:]:
        _check_history_frame(env, env.get_history(end=datetime.fromtimestamp(data['date'])))

def test_get_balance(ready_env):
    """Every currency reported by get_balance belongs to one of the env's symbols."""
    env = ready_env
    balance = env.get_balance()
    # Split each symbol on '_' so pair-style names contribute both of their currencies.
    known_currencies = {part for name in env.symbols for part in name.split('_')}
    assert set(balance.keys()).issubset(known_currencies)

def test_fiat(fresh_env):
    """Exercise the ``fiat`` setter with a symbol, with numbers, and with a timestamped dict.

    As used here the setter is overloaded:
    - a string selects the fiat symbol;
    - a number sets the fiat amount, which reads back as a Decimal;
    - a ``{symbol: amount, 'timestamp': ts}`` dict writes into ``portfolio_df``.
    """
    env = fresh_env
    env.fiat = "USDT"

    env.fiat = 0
    assert env.fiat == Decimal('0.00000000')

    for i in range(10):
        env.fiat = i
        assert env.fiat == Decimal(i)

        value = np.random.rand() * i
        env.fiat = value
        assert env.fiat == convert_to.decimal(value)

    timestamp = env.timestamp
    env.fiat = {"USDT": 10, 'timestamp': timestamp}
    # DataFrame.get_value was deprecated in pandas 0.21 and removed in 1.0; .at is
    # its label-based scalar replacement.
    assert env.portfolio_df.at[timestamp, "USDT"] == 10

def test_crypto(fresh_env):
    """The ``crypto`` setter rejects non-dict input, stores Decimals and never includes fiat."""
    env = fresh_env
    env.add_pairs("USDT_BTC")
    env.fiat = "USDT"

    # Each invalid value needs its own raises-block. Inside a single block the first
    # AssertionError ends the block, so the later assignments were never executed.
    for invalid in ([], 0, '0'):
        with pytest.raises(AssertionError):
            env.crypto = invalid

    balance = env.get_balance()
    env.crypto = balance
    assert env._fiat not in env.crypto
    for symbol, value in env.crypto.items():
        assert value == convert_to.decimal(balance[symbol])
        assert symbol in env.symbols

    timestamp = env.timestamp
    env.crypto = {"BTC": 10, 'timestamp': timestamp}
    # .at replaces DataFrame.get_value, which was removed in pandas 1.0.
    assert env.portfolio_df.at[timestamp, "BTC"] == Decimal('10')

def test_balance(fresh_env):
    """The ``balance`` setter rejects non-dict input and stores what get_balance reports."""
    env = fresh_env
    env.add_pairs("USDT_BTC", "USDT_ETH")
    env.fiat = "USDT"

    # One raises-block per invalid value. In a shared block only the first
    # assignment would ever run.
    for invalid in ([], 0, '0'):
        with pytest.raises(AssertionError):
            env.balance = invalid

    balance = env.get_balance()
    env.balance = balance
    stored = env.balance
    # Compare against the source dict. The previous check compared env.balance with
    # itself, so it could never fail.
    for symbol, raw in balance.items():
        assert stored[symbol] == convert_to.decimal(raw)
    for symbol in stored:
        assert symbol in env.symbols
    assert env._fiat not in env.crypto.keys()

def test_get_close_price(ready_env):
    """get_open_price returns the Decimal 'open' of USDT_BTC, by default from the last row.

    NOTE: despite its name this test exercises ``get_open_price``.
    """
    env = ready_env
    env.obs_df = env.get_history()

    latest = env.get_open_price("BTC")
    assert isinstance(latest, Decimal)
    assert latest == env.obs_df["USDT_BTC"].open.iloc[-1]

    for ts in env.obs_df.index:
        at_ts = env.get_open_price("BTC", ts)
        assert isinstance(at_ts, Decimal)
        assert at_ts == env.obs_df["USDT_BTC"].open.loc[ts]

def test_get_fee(ready_env):
    """get_fee returns Decimal fees and rejects an unknown fee type with AssertionError.

    Without a fee type it returns 0.0025, presumably the taker fee; with
    'makerFee' it returns 0.0015.
    """
    env = ready_env
    cases = (
        (("BTC",), Decimal('0.00250000')),
        (("BTC", "makerFee"), Decimal('0.00150000')),
    )
    for args, expected in cases:
        fee = env.get_fee(*args)
        assert isinstance(fee, Decimal)
        assert fee == expected

    with pytest.raises(AssertionError):
        env.get_fee("BTC", 'wrong_str')

def test_calc_total_portval(ready_env):
    """Once history is loaded, the total portfolio value is a non-negative Decimal."""
    env = ready_env
    env.obs_df = env.get_history()
    total_value = env.calc_total_portval()
    assert isinstance(total_value, Decimal)
    assert total_value >= Decimal('0.00000000')

def test_calc_posit(ready_env):
    """Each symbol's position is a Decimal in [0, 1], and all positions sum to 1."""
    env = ready_env
    env.obs_df = env.get_history()
    portval = env.calc_total_portval()
    total_posit = Decimal('0E-8')
    for symbol in env.symbols:
        posit = env.calc_posit(symbol, portval)
        assert isinstance(posit, Decimal)
        assert Decimal('0.00000000') <= posit <= Decimal('1.00000000')
        total_posit += posit
    # Two-sided check: the old `total - 1 <= eps` accepted any total below 1.
    # The tolerance leaves a little room for per-position rounding.
    assert abs(total_posit - Decimal('1.00000000')) <= Decimal('1E-6')

def test_get_previous_portval(ready_env):
    """get_last_portval raises KeyError until a portval is recorded, then returns it as Decimal."""
    env = ready_env
    env.obs_df = env.get_history()

    with pytest.raises(KeyError):
        env.get_last_portval()

    env.portval = 10
    assert env.get_last_portval() == Decimal('10')

def test_get_sampled_portfolio(ready_env):
    """After reset, get_sampled_portfolio has shape (1, 4)."""
    # TODO: WRITE TEST - only the shape is checked. The 4 columns are presumably
    # the env's symbols plus 'portval'; confirm and assert on the contents.
    env = ready_env
    env.reset()
    sampled = env.get_sampled_portfolio()
    assert sampled.shape == (1, 4)

def test_get_reward(ready_env):
    """get_reward returns a Decimal after reset, and (as currently asserted) ignores the fiat amount.

    NOTE(review): invariance of the reward to the fiat amount looks like a
    placeholder expectation - see the TODO below.
    """
    # TODO FIX REWARD TEST
    env = ready_env
    env.reset()
    r = env.get_reward()
    assert isinstance(r, Decimal)

    env.fiat = Decimal(1)
    # One-hot action on the last slot (presumably the fiat position).
    action = np.zeros(len(env.pairs) + 1)
    action[-1] = 1
    env.step(action)

    n_tests = 100
    for i, j in zip(np.random.random(n_tests), np.random.random(n_tests)):
        env.fiat = Decimal(i)
        r = env.get_reward()
        env.fiat = Decimal(j)
        r2 = env.get_reward()
        # np.allclose cannot compare Decimal objects: they become object-dtype arrays
        # and np.isfinite raises TypeError on those. Compare as floats instead.
        assert np.isclose(float(r), float(r2))
        # assert r - Decimal(j / i) < Decimal("1e-4"), r - Decimal(j / i)


# Shared frozen clock for the class-based tests below: one of the mock ``indexes``
# picked at random from position 10 onward. It is passed to datetime.fromtimestamp,
# so it is presumably a Unix timestamp in seconds. It is evaluated once at import
# time, so every setup_class and mock.patch.object decorator in this module uses
# the same instant.
index = np.choose(np.random.randint(low=10, high=len(indexes)), indexes)
class Test_env_reset(object):
    """reset() on a PaperTradingEnvironment whose clock is frozen at the module-level ``index``."""

    @classmethod
    def setup_class(cls):
        with mock.patch('cryptotrader.envs.trading.datetime') as fake_dt:
            fake_dt.now.return_value = datetime.fromtimestamp(index).astimezone(timezone.utc)
            # Forward the real constructors through the mock so only now() is frozen.
            fake_dt.fromtimestamp = lambda *args, **kw: datetime.fromtimestamp(*args, **kw)
            fake_dt.side_effect = lambda *args, **kw: datetime(*args, **kw)

            cls.env = PaperTradingEnvironment(period=5, obs_steps=10, tapi=tapi, fiat="USDT", name='env_test')

    @classmethod
    def teardown_class(cls):
        logs_dir = os.path.join(os.path.abspath(os.path.curdir), 'logs')
        shutil.rmtree(logs_dir)

    @mock.patch.object(PaperTradingEnvironment, 'timestamp',
                       floor_datetime(datetime.fromtimestamp(index).astimezone(timezone.utc), 5))
    def test_reset(self):
        """reset() rebuilds observations, taxes, the portfolio/action logs and the balance."""
        env = self.env
        obs = env.reset()

        # Observation: both the stored frame and the returned one span obs_steps rows.
        for frame in (env.obs_df, obs):
            assert isinstance(frame, pd.DataFrame) and frame.shape[0] == env.obs_steps

        # Taxes are keyed by symbol, in symbol order.
        assert tuple(env.tax.keys()) == env.symbols

        symbols = list(env.symbols)
        # Portfolio log: a single row with one column per symbol plus 'portval'.
        assert isinstance(env.portfolio_df, pd.DataFrame) and env.portfolio_df.shape[0] == 1
        assert list(env.portfolio_df.columns) == symbols + ['portval']
        # Action log: a single row with one column per symbol plus 'online'.
        assert isinstance(env.action_df, pd.DataFrame) and env.action_df.shape[0] == 1
        assert list(env.action_df.columns) == symbols + ['online']

        # Balance: one Decimal entry per symbol, in symbol order.
        assert list(env.balance.keys()) == symbols
        for amount in env.balance.values():
            assert isinstance(amount, Decimal)
            assert isinstance(self.env.balance[symbol], Decimal)

@pytest.mark.incremental
class Test_env_step(object):
    """simulate_trade() and step() on a PaperTradingEnvironment frozen at the module-level ``index``.

    One environment is shared by the whole class and mutated by every hypothesis
    example, so results depend on execution order.
    """
    # TODO: CHECK THIS TEST
    @classmethod
    def setup_class(cls):
        with mock.patch('cryptotrader.envs.trading.datetime') as mock_datetime:
            mock_datetime.now.return_value = datetime.fromtimestamp(index).astimezone(timezone.utc)
            # Forward the real constructors through the mock so only now() is frozen.
            mock_datetime.fromtimestamp = lambda *args, **kw: datetime.fromtimestamp(*args, **kw)
            mock_datetime.side_effect = lambda *args, **kw: datetime(*args, **kw)

            cls.env = PaperTradingEnvironment(period=5, obs_steps=5, tapi=tapi, fiat="USDT", name='env_test')
            cls.env.setup()
            cls.env.reset()
            cls.env.fiat = 100
            cls.env.reset_status()

    @classmethod
    def teardown_class(cls):
        shutil.rmtree(os.path.join(os.path.abspath(os.path.curdir), 'logs'))

    @mock.patch.object(PaperTradingEnvironment, 'timestamp',
                       floor_datetime(datetime.fromtimestamp(index).astimezone(timezone.utc), 5))
    @given(arrays(dtype=np.float32,
                  shape=(3,),
                  elements=st.floats(allow_nan=False, allow_infinity=False, max_value=1e8, min_value=0)))
    @settings(max_examples=50)
    def test_simulate_trade(self, action):
        """simulate_trade logs the requested allocation and the implied crypto amounts."""
        # Normalize action vector
        action = array_normalize(action, False)

        assert action.sum() - Decimal('1.00000000') < Decimal('1E-8'), action.sum() - Decimal('1.00000000')

        timestamp = self.env.obs_df.index[-1]
        self.env.simulate_trade(action, timestamp)

        # NOTE(review): the tolerance checks below are one-sided (upper bound only),
        # possibly to leave room for fees - confirm before tightening with abs().
        # DataFrame.get_value was removed in pandas 1.0; .at is its replacement.

        # Assert position
        for i, symbol in enumerate(self.env.symbols):
            assert self.env.action_df.at[timestamp, symbol] - convert_to.decimal(action[i]) <= Decimal('1E-3')

        # Assert amount: compare each non-fiat holding with position * portval / open price.
        # Use `!=` because `not in self.env._fiat` was a substring test on the fiat name.
        for symbol in self.env.symbols:
            if symbol != self.env._fiat:
                held = self.env.portfolio_df.at[self.env.portfolio_df[symbol].last_valid_index(), symbol]
                implied = self.env.action_df.at[timestamp, symbol] * self.env.calc_total_portval(timestamp) / \
                    self.env.get_open_price(symbol, timestamp)
                assert held - implied <= convert_to.decimal('1E-4')

    @mock.patch.object(PaperTradingEnvironment, 'timestamp',
                       floor_datetime(datetime.fromtimestamp(index).astimezone(timezone.utc), 5))
    @given(arrays(dtype=np.float32,
                  shape=(3,),
                  elements=st.floats(allow_nan=False, allow_infinity=False, max_value=1e8, min_value=0)))
    @settings(max_examples=50)
    def test_step(self, action):
        """step() returns (obs, reward, done, status) with a finite reward and all flags cleared."""
        # TODO: FIX STEP TEST
        action = array_softmax(action)
        obs, reward, done, status = self.env.step(action)

        # Assert returned obs
        assert isinstance(obs, pd.DataFrame)
        assert obs.shape[0] == self.env.obs_steps
        assert set(obs.columns.levels[0]) == set(list(self.env.pairs) + [self.env._fiat])

        # Assert reward. `reward not in (np.nan, np.inf)` never caught NaN: tuple
        # membership uses identity or ==, and a fresh NaN matches neither. It also
        # let -inf through.
        assert isinstance(reward, np.float64)
        assert np.isfinite(reward)

        # Assert done
        assert isinstance(done, bool)

        # Assert status
        assert status == self.env.status
        for key in status:
            assert status[key] == False, key

# LIVETRADING ENVIRONMENT TESTS


if __name__ == '__main__':
    import sys

    # Run this module (plus any extra CLI arguments). Propagate pytest's exit status
    # to the shell instead of always exiting 0.
    sys.exit(pytest.main([__file__] + sys.argv[1:]))