import itertools
import gzip
import numpy as np
import os
import struct
import sys

from features import bulk_extract_features
import go
from sgf_wrapper import replay_sgf
import utils

# Number of data points to store in a chunk on disk
CHUNK_SIZE = 4096
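# Chunk header layout: data_size (int), board_size (int), input_planes (int), is_test (bool)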
CHUNK_HEADER_FORMAT = "iii?"
CHUNK_HEADER_SIZE = struct.calcsize(CHUNK_HEADER_FORMAT)

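# Encode each move coordinate as a one-hot row vector over the N * N flattened board.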
def make_onehot(coords):
    num_positions = len(coords)
    output = np.zeros([num_positions, go.N ** 2], dtype=np.uint8)
    for i, coord in enumerate(coords):
        output[i, utils.flatten_coords(coord)] = 1
    return output

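# Yield the paths of all .sgf files sitting directly inside each dataset directory.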
def find_sgf_files(*dataset_dirs):
    for dataset_dir in dataset_dirs:
        full_dir = os.path.join(os.getcwd(), dataset_dir)
        dataset_files = [os.path.join(full_dir, name) for name in os.listdir(full_dir)]
        for f in dataset_files:
            if os.path.isfile(f) and f.endswith(".sgf"):
                yield f

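# Replay a single SGF file and yield every usable position along with its context
# (the move played next and the game result).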
def get_positions_from_sgf(file):
    with open(file) as f:
        for position_w_context in replay_sgf(f.read()):
            if position_w_context.is_usable():
                yield position_w_context

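# Split a stream of positions into a test chunk and an iterable of training chunks.
# Small datasets hold out a third of the positions for testing and keep the rest as a
# single training chunk; larger ones shuffle the stream and take a fixed-size test set.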
def split_test_training(positions_w_context, est_num_positions):
    print("Estimated number of chunks: %s" % (est_num_positions // CHUNK_SIZE), file=sys.stderr)
    desired_test_size = 10**5
    if est_num_positions < 2 * desired_test_size:
        positions_w_context = list(positions_w_context)
        test_size = len(positions_w_context) // 3
        return positions_w_context[:test_size], [positions_w_context[test_size:]]
    else:
        shuffled_positions = utils.shuffler(positions_w_context)
        test_chunk = utils.take_n(desired_test_size, shuffled_positions)
        training_chunks = utils.iter_chunks(CHUNK_SIZE, shuffled_positions)
        return test_chunk, training_chunks


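# In-memory dataset: bulk-extracted position features, one-hot encoded next moves,
# and game results, with support for shuffled minibatching and gzipped (de)serialization.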
class DataSet(object):
    def __init__(self, pos_features, next_moves, results, is_test=False):
        self.pos_features = pos_features
        self.next_moves = next_moves
        self.results = results
        self.is_test = is_test
        assert pos_features.shape[0] == next_moves.shape[0], "Didn't pass in same number of pos_features and next_moves."
        self.data_size = pos_features.shape[0]
        self.board_size = pos_features.shape[1]
        self.input_planes = pos_features.shape[-1]
        self._index_within_epoch = 0

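    # Apply the same random permutation to features and labels and restart the epoch.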
    def shuffle(self):
        perm = np.arange(self.data_size)
        np.random.shuffle(perm)
        self.pos_features = self.pos_features[perm]
        self.next_moves = self.next_moves[perm]
        self._index_within_epoch = 0

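    # Return the next (features, next_moves) minibatch, reshuffling once the remaining
    # data can no longer fill a full batch.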
    def get_batch(self, batch_size):
        assert batch_size < self.data_size
        if self._index_within_epoch + batch_size > self.data_size:
            self.shuffle()
        start = self._index_within_epoch
        end = start + batch_size
        self._index_within_epoch += batch_size
        return self.pos_features[start:end], self.next_moves[start:end]

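    # Build a DataSet from (position, next_move, result) tuples by extracting input
    # planes in bulk and one-hot encoding the moves.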
    @staticmethod
    def from_positions_w_context(positions_w_context, is_test=False):
        positions, next_moves, results = zip(*positions_w_context)
        extracted_features = bulk_extract_features(positions)
        encoded_moves = make_onehot(next_moves)
        return DataSet(extracted_features, encoded_moves, results, is_test=is_test)

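    # Serialize to a gzipped chunk: a packed header followed by the bit-packed features
    # and one-hot moves. Game results are not written to disk.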
    def write(self, filename):
        header_bytes = struct.pack(CHUNK_HEADER_FORMAT, self.data_size, self.board_size, self.input_planes, self.is_test)
        position_bytes = np.packbits(self.pos_features).tobytes()
        next_move_bytes = np.packbits(self.next_moves).tobytes()
        with gzip.open(filename, "wb", compresslevel=6) as f:
            f.write(header_bytes)
            f.write(position_bytes)
            f.write(next_move_bytes)

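    # Inverse of write(): unpack the header, then the bit-packed features and moves.
    # Results are not stored on disk, so the returned DataSet has an empty results list.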
    @staticmethod
    def read(filename):
        with gzip.open(filename, "rb") as f:
            header_bytes = f.read(CHUNK_HEADER_SIZE)
            data_size, board_size, input_planes, is_test = struct.unpack(CHUNK_HEADER_FORMAT, header_bytes)

            position_dims = data_size * board_size * board_size * input_planes
            next_move_dims = data_size * board_size * board_size

            # (dims + 7) // 8 rounds up to whole bytes, since np.packbits pads the final byte
            packed_position_bytes = f.read((position_dims + 7) // 8)
            packed_next_move_bytes = f.read((next_move_dims + 7) // 8)
            # should have cleanly finished reading all bytes from file!
            assert len(f.read()) == 0

            flat_position = np.unpackbits(np.frombuffer(packed_position_bytes, dtype=np.uint8))[:position_dims]
            flat_nextmoves = np.unpackbits(np.frombuffer(packed_next_move_bytes, dtype=np.uint8))[:next_move_dims]

            pos_features = flat_position.reshape(data_size, board_size, board_size, input_planes)
            next_moves = flat_nextmoves.reshape(data_size, board_size * board_size)

        return DataSet(pos_features, next_moves, [], is_test=is_test)

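# Find all SGF files under the given directories, replay them, and split the resulting
# position stream into a test chunk and an iterable of training chunks.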
def parse_data_sets(*data_sets):
    sgf_files = list(find_sgf_files(*data_sets))
    print("%s sgfs found." % len(sgf_files), file=sys.stderr)
    est_num_positions = len(sgf_files) * 200 # about 200 moves per game
    positions_w_context = itertools.chain(*map(get_positions_from_sgf, sgf_files))

    test_chunk, training_chunks = split_test_training(positions_w_context, est_num_positions)
    return test_chunk, training_chunks
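
# A minimal usage sketch (illustrative only; "data/kgs" and "chunks/" are hypothetical
# input and output directories, not part of this module):
#
#   test_chunk, training_chunks = parse_data_sets("data/kgs")
#   DataSet.from_positions_w_context(test_chunk, is_test=True).write("chunks/test.chunk.gz")
#   for i, chunk in enumerate(training_chunks):
#       DataSet.from_positions_w_context(chunk).write("chunks/train%s.chunk.gz" % i)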