#!/usr/bin/env python3
import chess

class State(object):
  """Wrap a board object and expose a key, an array encoding and its legal moves."""

  def __init__(self, board=None):
    """Store ``board``; when it is None a new ``chess.Board()`` is created."""
    self.board = chess.Board() if board is None else board

  def key(self):
    """Return the tuple (board FEN, turn, castling rights, en passant square)."""
    position = self.board
    return (
      position.board_fen(),
      position.turn,
      position.castling_rights,
      position.ep_square,
    )

  def serialize(self):
    """Encode the position as a ``(5, 8, 8)`` uint8 array.

    Every square first gets a 4-bit code:
      0      empty
      1-6    upper-case symbols P, N, B, R, Q, K
      7      corner rook (code 4) whose side still has that castling right
      8      the en passant square
      9-14   lower-case symbols p, n, b, r, q, k
      15     corner rook (code 12) whose side still has that castling right

    Planes 0-3 of the result hold bits 3, 2, 1, 0 of those codes; plane 4
    is filled with the numeric value of ``board.turn``.

    Raises AssertionError if the board is not valid or if a castling/en
    passant invariant does not hold.
    """
    import numpy as np

    position = self.board
    assert position.is_valid()

    # Built once instead of once per occupied square.
    symbol_codes = {
      "P": 1, "N": 2, "B": 3, "R": 4, "Q": 5, "K": 6,
      "p": 9, "n": 10, "b": 11, "r": 12, "q": 13, "k": 14,
    }

    codes = np.zeros(64, np.uint8)
    for square in range(64):
      piece = position.piece_at(square)
      if piece is None:
        continue
      codes[square] = symbol_codes[piece.symbol()]

    # Re-tag corner rooks that can still castle: (colour, first square of
    # the back rank, code offset added for that colour).
    for colour, corner, offset in ((chess.WHITE, 0, 0), (chess.BLACK, 56, 8)):
      rook, castling_rook = offset + 4, offset + 7
      if position.has_queenside_castling_rights(colour):
        assert codes[corner] == rook
        codes[corner] = castling_rook
      if position.has_kingside_castling_rights(colour):
        assert codes[corner + 7] == rook
        codes[corner + 7] = castling_rook

    ep = position.ep_square
    if ep is not None:
      # The en passant square must be empty before it is marked.
      assert codes[ep] == 0
      codes[ep] = 8

    grid = codes.reshape(8, 8)

    # Planes 0-3: most significant bit of the 4-bit code first.
    planes = np.zeros((5, 8, 8), np.uint8)
    for index, shift in enumerate((3, 2, 1, 0)):
      planes[index] = (grid >> shift) & 1

    # Plane 4: side to move, broadcast over the whole plane.
    planes[4] = (position.turn * 1.0)

    return planes

  def edges(self):
    """Return the board's legal moves as a list."""
    return [*self.board.legal_moves]

if __name__ == "__main__":
  # Script entry point: build a State with no board argument, so State
  # creates a fresh chess.Board() itself.
  s = State()