# -*- coding: utf-8 -*-

import pymongo
import pandas as pd

class TradeResults:
    """Prepare trade-result records for model training and evaluation.

    The constructor takes a ``pandas.DataFrame`` of trade records, removes
    the ``_id``, ``entered_at`` and ``exited_at`` columns, encodes
    ``sell_or_buy`` as 0/1 and z-score normalizes every remaining column.
    The unnormalized ``profit_or_loss`` column is kept and its normalized
    values are stored in ``normalized_profit_or_loss``.

    Rows are split by ``index % 3``: rows whose index is a multiple of three
    form the test set (1/3), all others form the training set (2/3).  This
    assumes an integer index; with the default 0..n-1 index the rows are
    taken evenly across the (time-sorted) input, which spreads out the
    influence of market conditions at trade time.

    Attributes:
        raw: untouched copy of the DataFrame passed to the constructor.
        data: cleaned and normalized DataFrame.
    """

    def __init__(self, data):
        """Build the cleaned/normalized data set from ``data``.

        Args:
            data: DataFrame containing at least the ``_id``, ``entered_at``,
                ``exited_at``, ``sell_or_buy`` and ``profit_or_loss``
                columns.  It is not modified.
        """
        self.raw = data.copy()
        # clean() and normalize() modify the frame they receive, so give
        # them a private copy instead of the caller's object.
        self.data = TradeResults.normalize(TradeResults.clean(data.copy()))

    def all_data(self):
        """Return all rows without the profit/loss columns."""
        return self.__drop_profit_or_loss(self.data)

    def train_data(self):
        """Return the training rows without the profit/loss columns."""
        return self.__drop_profit_or_loss(self.__train_data())

    def test_data(self):
        """Return the test rows without the profit/loss columns."""
        return self.__drop_profit_or_loss(self.__test_data())

    def train_profit_or_loss(self):
        """Return the normalized profit/loss of the training rows."""
        return self.__train_data()["normalized_profit_or_loss"]

    def test_profit_or_loss(self):
        """Return the normalized profit/loss of the test rows."""
        return self.__test_data()["normalized_profit_or_loss"]

    def train_up_down(self):
        """Return up/down indicator columns for the training rows."""
        return self.__up_down(self.__train_data()["profit_or_loss"])

    def test_up_down(self):
        """Return up/down indicator columns for the test rows."""
        return self.__up_down(self.__test_data()["profit_or_loss"])

    def __train_data(self):
        # Use 2/3 of all rows as training data: every row whose index is
        # not a multiple of three.
        return self.data.loc[self.data.index % 3 != 0, :]

    def __test_data(self):
        # Use the remaining 1/3 of the rows (index multiple of three) as
        # test data.
        return self.data.loc[self.data.index % 3 == 0, :]

    def __drop_profit_or_loss(self, data):
        # Returns a new frame; ``data`` itself is left unchanged.
        return data.drop(
            ["profit_or_loss", "normalized_profit_or_loss"], axis=1)

    def __up_down(self, profit_or_loss):
        # Indicator frame with the same index as the input:
        #   up   = 1 when the value is  > 0, else 0
        #   down = 1 when the value is <= 0, else 0
        # Built with vectorized comparisons so that a DataFrame with both
        # columns is returned even when the input is empty.
        return pd.DataFrame({
            'up': (profit_or_loss > 0).astype('int64'),
            'down': (profit_or_loss <= 0).astype('int64'),
        }, columns=['up', 'down'])

    @staticmethod
    def clean(data):
        """Drop bookkeeping columns and encode ``sell_or_buy`` numerically.

        ``data`` is modified in place and also returned.  ``"sell"`` becomes
        0 and every other value becomes 1.

        Raises:
            KeyError: if one of the expected columns is missing.
        """
        for column in ('_id', 'entered_at', 'exited_at'):
            del data[column]
        data['sell_or_buy'] = (data['sell_or_buy'] != "sell").astype('int64')
        return data

    @staticmethod
    def normalize(data):
        """Z-score normalize every column (population std, ``ddof=0``).

        Columns are overwritten in ``data`` itself, except
        ``profit_or_loss`` whose normalized values go into a new
        ``normalized_profit_or_loss`` column.  The returned frame is the
        result of ``fillna(0)``, so NaN values (for example from a column
        whose standard deviation is zero) are replaced by 0.
        """
        # Iterate over a snapshot: the loop adds a column to ``data``.
        for col in list(data.columns):
            key = 'normalized_' + col if col == 'profit_or_loss' else col
            values = data[col]
            data[key] = (values - values.mean()) / values.std(ddof=0)
        return data.fillna(0)



class TradeResultsLoader:
    """Load trade records from MongoDB into a pandas DataFrame.

    The connection target is defined by the class attributes below; they
    are read through ``self`` so a subclass or instance may override them.
    """

    DB_HOST = 'mongodb'
    DB_PORT = 27017

    DB = 'jiji'
    COLLECTION = 'tensorflow_example_trade_and_signals'

    def retrieve_trade_data(self):
        """Fetch every document of the collection, sorted by ``entered_at``.

        Returns:
            A ``pandas.DataFrame`` built from the list of fetched documents.
        """
        client = pymongo.MongoClient(self.DB_HOST, self.DB_PORT)
        try:
            collection = client[self.DB][self.COLLECTION]
            # Materialize the cursor while the connection is still open.
            records = list(collection.find().sort("entered_at"))
        finally:
            # Always release the connection, even if the query fails.
            client.close()
        return pd.DataFrame(records)