from collections import namedtuple
import time

import numpy as np
from tentacle.board import Board
from tentacle.game import Game
from tentacle.tree_node import TreeNode


class MCTS1(object):
    """Monte Carlo tree search guided by policy, value and rollout functions.

    Each playout descends from the root for at most ``self._L`` plies,
    expanding leaf nodes with priors from ``policy_fn`` and following
    ``TreeNode.select``. The reached position is scored as
    ``(1 - lmbda) * value_fn(state) + lmbda * rollout_result`` and the score is
    backed up the tree with ``TreeNode.update_recursive``.
    """
    # NOTE(review): not used within this class; kept for external callers.
    StatItem = namedtuple('StatItem', 'P, Nv, Nr, Wv, Wr, Q')

    def __init__(self, value_fn, policy_fn, rollout_policy_fn):
        """Create a search with a fresh root node.

        Args:
            value_fn: callable(state) whose result is mixed linearly with the
                rollout outcome to form the leaf value.
            policy_fn: callable(state) returning action priors for
                ``TreeNode.expand``; an empty result stops the descent.
            rollout_policy_fn: callable(state, legal_moves) returning a 2-D
                array of move scores, of which row 0 is used
                (presumably shape (1, n_cells) -- TODO confirm).
        """
        self._lmbda = 0.5  # leaf mix: 0 -> value_fn only, 1 -> rollout only
        self._c_puct = 5  # exploration constant passed to update_recursive
        self.n_thr = 40  # not used within this class
        self.n_vl = 3  # not used within this class
        self._rollout_limit = 80  # max plies played per rollout
        self._L = 5  # max tree-descent depth per playout
        self._n_playout = 50  # playouts per get_move call

        self._root = TreeNode(None, 1.0)
        self._value = value_fn
        self._policy = policy_fn
        self._rollout = rollout_policy_fn

    def _playout(self, state, leaf_depth):
        """Run one simulation from the root and back up its leaf value.

        Args:
            state: position corresponding to the root node.
            leaf_depth: maximum number of plies to descend in the tree.

        The descent rebinds a local ``state`` to successor states taken from
        ``Game.possible_moves``; the caller's reference is not reassigned.
        """
        node = self._root

        for depth in range(leaf_depth):
            legal_states, _, legal_moves = Game.possible_moves(state)
            if len(legal_states) == 0:
                break  # no successors: evaluate this position as the leaf
            if node.is_leaf():
                action_probs = self._policy(state)
                if len(action_probs) == 0:
                    break
                node.expand(action_probs)

            best_move, node = node.select()
            # Map the tree's move label back to its successor state.
            idx = np.where(legal_moves == best_move)[0]
            if idx.size == 0:
                # The tree selected a move that is not legal here; dump the
                # sibling values to help debug before the assert fires.
                print('depth:', depth, idx)
                print('best move:', best_move)
                for move, sibling in node._parent._children.items():
                    print('  ', move, sibling.get_value())
            assert idx.size == 1
            state = legal_states[idx[0]]

        v = self._value(state) if self._lmbda < 1 else 0
        z = (self._evaluate_rollout(state, self._rollout_limit)
             if self._lmbda > 0 else 0)
        leaf_value = (1 - self._lmbda) * v + self._lmbda * z

        node.update_recursive(leaf_value, self._c_puct)

    def _evaluate_rollout(self, state, limit):
        """Play greedily with the rollout policy and report the outcome.

        Args:
            state: position to start the rollout from.
            limit: maximum number of plies to play.

        Returns:
            1 if the game ends with ``winner`` equal to the player reported
            by ``Game.possible_moves`` for the starting state, -1 if it ends
            with a different non-zero winner, and 0 otherwise (no winner,
            no legal moves, or the ply limit was reached).
        """
        winner = 0
        player = None
        for _ in range(limit):
            legal_states, mover, legal_moves = Game.possible_moves(state)
            if player is None:
                player = mover  # perspective used for the returned result
            if len(legal_states) == 0:
                break

            probs = self._rollout(state, legal_moves)
            # Give illegal moves a score of -0.01 so argmax picks a legal one
            # (assumes rollout scores are non-negative probabilities).
            masked = np.full_like(probs, -0.01)
            masked[:, legal_moves] = probs[:, legal_moves]
            best_move = np.argmax(masked, 1)[0]

            idx = np.where(legal_moves == best_move)[0]
            assert idx.size == 1
            next_state = legal_states[idx[0]]

            over, final_winner, _ = next_state.is_over(state)
            if over:
                # Only a finished game determines the winner.
                winner = final_winner
                break

            state = next_state
        else:
            # If no break from the loop, issue a warning.
            print("WARNING: rollout reached move limit")

        if winner == 0:
            return 0
        return 1 if winner == player else -1

    def get_move(self, state):
        """Search from ``state`` and return the most-visited root move.

        Runs ``self._n_playout`` playouts, each descending at most
        ``self._L`` plies.

        Raises:
            ValueError: if the root has no children after the search.
        """
        for _ in range(self._n_playout):
            self._playout(state, self._L)

        children = self._root._children
        if not children:
            raise ValueError('get_move: root has no expanded moves')
        return max(children.items(), key=lambda item: item[1]._n_visits)[0]

    def update_with_move(self, last_move):
        """Advance the root to ``last_move``, reusing its subtree if known.

        If ``last_move`` is not a child of the root, the tree is discarded and
        a fresh root is created.
        """
        if last_move in self._root._children:
            self._root = self._root._children[last_move]
            self._root._parent = None  # detach so the old tree can be freed
        else:
            self._root = TreeNode(None, 1.0)

    def pack_state(self, state):
        """Encode a board array as bytes made of three bit-packed planes.

        The black, white and empty masks are each packed with
        ``np.packbits`` (which flattens its input and pads to a whole byte)
        and concatenated in that order.
        """
        black = np.packbits(state == Board.STONE_BLACK)
        white = np.packbits(state == Board.STONE_WHITE)
        empty = np.packbits(state == Board.STONE_EMPTY)
        image = np.concatenate((black, white, empty))
        return bytes(image)

    def unpack_state(self, s, shape):
        """Decode bytes produced by ``pack_state`` back into a flat board.

        Args:
            s: bytes from ``pack_state``.
            shape: ``(n_planes, n_cells)``; the first three planes are read
                as black, white and empty, e.g. ``(3, 81)``.

        Returns:
            1-D int array of length ``n_cells`` holding Board stone values.
        """
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        a = np.frombuffer(s, dtype=np.uint8)
        a = np.unpackbits(a)
        a = a.reshape(shape[0], -1)
        a = a[:, :shape[1]]  # drop packbits' per-plane byte padding
        # builtin int replaces np.int, which was removed in NumPy 1.24.
        b = np.zeros_like(a[0], dtype=int)
        b[a[0] == 1] = Board.STONE_BLACK
        b[a[1] == 1] = Board.STONE_WHITE
        b[a[2] == 1] = Board.STONE_EMPTY
        return b

    def test_pack_unpack(self):
        """Round-trip random 81-cell boards through pack/unpack.

        NOTE(review): assumes Board's stone constants are exactly {0, 1, 2}.
        """
        for _ in range(1000):
            a = np.random.choice([0, 1, 2], 81)
            compact = self.pack_state(a)
            b = self.unpack_state(compact, (3, 81))
            assert np.all(a == b)

# if __name__ == '__main__':
#     mcts = MCTS1()
#     mcts.test_pack_unpack()