"""Convert PGN chess games into training rows for move/eval prediction.

Every (position, move) pair in every game is turned into a binary feature
vector plus an outcome score and an integer-encoded move; rows are batched
into scipy CSR matrices and saved as a Spark pickle file.
"""

import itertools
import numpy as np
import scipy.sparse as sp
import chess, chess.pgn
import random
from pathlib import Path
import argparse
import concurrent.futures
import functools
import pyspark
from pystreams.pystreams import Stream

# TODO: Consider input features for
# - castling and ep-rights
# - past move
# - occupied squares (makes it easier to play legally)
# - whether the king is in check
# - attackers/defenders for each square


def binary_encode(board):
    """Return the board as a flat binary vector, for eval prediction purposes.

    Layout: 12 x 64 piece-occupancy planes (6 piece types x 2 colors, from
    ``pieces_mask``), one 64-entry en-passant plane, then 6 scalar flags
    (side to move, the four corner castling rights, check).
    """
    rows = []
    for color in [chess.WHITE, chess.BLACK]:
        for ptype in range(chess.PAWN, chess.KING + 1):
            mask = board.pieces_mask(ptype, color)
            # 64-bit occupancy mask -> list of 64 ints (0/1)
            rows.append(list(map(int, bin(mask)[2:].zfill(64))))
    ep = [0] * 64
    # ep_square can never be square 0 (A1), so truthiness is safe here
    if board.ep_square:
        ep[board.ep_square] = 1
    rows.append(ep)
    rows.append([
        int(board.turn),
        int(bool(board.castling_rights & chess.BB_A1)),
        int(bool(board.castling_rights & chess.BB_H1)),
        int(bool(board.castling_rights & chess.BB_A8)),
        int(bool(board.castling_rights & chess.BB_H8)),
        int(board.is_check()),
    ])
    return np.concatenate(rows)


def encode_move(move):
    """Encode a move as a single integer in ``[0, 64**2 + 4)``.

    NOTE(review): promotions map to one index per promoted piece type,
    discarding the from/to squares — confirm this collapse is intended.
    """
    if move.promotion:
        return 64**2 + move.promotion - 1
    return move.from_square + move.to_square * 64


def process(node):
    """Turn one game-tree node into a training row.

    Returns ``binary_encode(position before the move)`` concatenated with
    ``[score, encoded_move]``, where score is 1 if the side that moved
    eventually won, 0 if it lost, and 1/2 for a draw.
    """
    board = binary_encode(node.parent.board())
    res = node.root().headers['Result']
    score = 0
    if res == '1-0':
        score = int(not node.board().turn)  # turn has already been changed
    elif res == '0-1':
        score = int(node.board().turn)
    elif res == '1/2-1/2':
        score = 1 / 2
    move = encode_move(node.move)
    return np.concatenate((board, [score, move]))


def get_games(path, max_size=None):
    """Yield games parsed from the PGN file at ``path``, at most ``max_size``.

    Bug fix: the file is opened exactly once and closed on exit.  The
    previous version opened a fresh handle for every ``read_game()`` call,
    which re-read the first game forever and leaked file descriptors.
    """
    with open(path) as pgn_file:
        games = iter(lambda: chess.pgn.read_game(pgn_file), None)
        if max_size is not None:
            games = itertools.islice(games, max_size)
        yield from games


def merge(it, n=1000):
    """Batch rows from iterator ``it`` into CSR matrices of up to ``n`` rows."""
    while True:
        chunk = list(itertools.islice(it, n))
        if not chunk:
            return
        yield sp.csr_matrix(chunk)


def work_spark(args):
    """Encode every game in ``args.files`` with local Spark and pickle the batches."""
    conf = (pyspark.SparkConf()
            .setAppName("temp1")
            .setMaster("local[*]")
            .set("spark.driver.host", "localhost")
            .set('spark.executor.memory', '6g'))
    with pyspark.SparkContext("local[*]", "PySparkWordCount", conf=conf) as sc:
        (sc.parallelize(args.files)
           .flatMap(get_games)
           .flatMap(lambda game: game.mainline())
           .map(process)
           .mapPartitions(merge)
           .saveAsPickleFile('pikle.out'))


def work_streams(args):
    """Alternative pipeline using pystreams; prints rows instead of saving."""
    Stream(args.files) \
        .peek(print) \
        .flatmap(get_games) \
        .peek(lambda _: print('g')) \
        .flatmap(lambda game: game.mainline()) \
        .map(process) \
        .foreach(print)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('files', help='glob for pgn files, e.g. **/*.pgn', nargs="+")
    parser.add_argument('-test', help='test out')
    parser.add_argument('-train', help='train out')
    # type=float so a CLI-supplied value is numeric, matching the default
    parser.add_argument('-ttsplit', default=.8, type=float, help='test train split')
    parser.add_argument('-eval', action='store_true', help='predict eval rather than moves')
    args = parser.parse_args()
    work_spark(args)
    print('Done!')


if __name__ == '__main__':
    main()