import yaml
import tensorflow as tf
import random
from tqdm import tqdm
from collections import Counter
from contextlib import ExitStack

import logging
logger = logging.getLogger(__name__)

# --------------------------------------------------------------------------
# Misc
# --------------------------------------------------------------------------

def min_none(a, b):
	"""min() that treats None as "no bound": returns the other value."""
	if a is None:
		return b
	if b is None:
		return a
	return min(a, b)
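
# Sketch of the semantics: None acts as "no bound", so combining two
# optional limits keeps the tighter one.
#   min_none(None, 5) -> 5
#   min_none(3, None) -> 3
#   min_none(3, 5)    -> 3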


# --------------------------------------------------------------------------
# TFRecord functions
# --------------------------------------------------------------------------

# Why writing a record is so awkward, I do not know

def write_int64_feature(value):
	return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def write_int64_array_feature(value):
	return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

def write_boolean_array_feature(value):
	return write_int64_array_feature(value)

def write_string_feature(value):
	return tf.train.Feature(bytes_list=tf.train.BytesList(value=[tf.compat.as_bytes(value)]))
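
# A minimal sketch of how the helpers above compose into a serialized record.
# The feature keys ("src", "label", "type_string") are illustrative, not
# necessarily the ones this pipeline uses.
def write_example_sketch(path, tokens, label, type_string):
	features = tf.train.Features(feature={
		"src": write_int64_array_feature(tokens),
		"label": write_int64_feature(label),
		"type_string": write_string_feature(type_string),
	})
	example = tf.train.Example(features=features)
	with tf.python_io.TFRecordWriter(path) as writer:
		writer.write(example.SerializeToString())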


# TODO: Better naming / structure

def parse_feature_int_array():
	return tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True)

def parse_feature_boolean_array():
	return parse_feature_int_array()

def parse_feature_string():
	return tf.FixedLenSequenceFeature([], tf.string, allow_missing=True)

def parse_feature_int():
	return tf.FixedLenFeature([], tf.int64)
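
# Counterpart sketch to write_example_sketch above: the parse_* helpers plug
# straight into tf.parse_single_example. The feature keys are again
# illustrative.
def parse_example_sketch(serialized):
	return tf.parse_single_example(serialized, features={
		"src": parse_feature_int_array(),
		"label": parse_feature_int(),
		"type_string": parse_feature_string(),
	})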


# --------------------------------------------------------------------------
# TF helpers
# --------------------------------------------------------------------------

def tf_startswith(tensor, prefix, axis=None):
	"""Prefix test for string tensors, reduced with reduce_all over `axis`."""
	return tf.reduce_all(tf.equal(tf.substr(tensor, 0, len(prefix)), prefix), axis=axis)
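
# Sketch (hypothetical values): tf.substr takes the first len(prefix) bytes
# of every element, and reduce_all folds the comparison. With the default
# axis=None this answers "do *all* strings in the tensor share this prefix?":
#   t = tf.constant(["gqa_count", "gqa_exist"])
#   tf_startswith(t, "gqa")  # => True once evaluated in a session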



# --------------------------------------------------------------------------
# File readers and writers
# --------------------------------------------------------------------------

def read_gqa(args, limit=None):
	"""Yield GQA question docs from the YAML files in args["gqa_paths"],
	optionally filtered by question type prefix and capped at `limit`."""

	if limit is None:
		limit = args["limit"]

	with ExitStack() as stack:
		in_files = [
			stack.enter_context(tf.gfile.GFile(i, 'r'))
			for i in args["gqa_paths"]
		]

		yamls = [
			yaml.safe_load_all(i)
			for i in in_files
		]

		ctr = 0

		# Interleave docs across the input files, round-robin
		for row in zip(*yamls):
			for i in row:
				if i is None:
					logger.debug("Skipping None yaml doc")
					continue

				prefix = args["filter_type_prefix"]
				if prefix is not None and not i["question"]["type_string"].startswith(prefix):
					logger.debug(f"{i['question']['type_string']} does not match prefix {prefix}")
					continue

				yield i
				ctr += 1

				if limit is not None and ctr >= limit:
					logger.debug("Hit limit, stop")
					return

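# Sketch of driving read_gqa; the path and arg values below are illustrative
# only, not the pipeline's real configuration.
#   args = {"gqa_paths": ["data/train.yaml"], "limit": 10, "filter_type_prefix": None}
#   for doc in read_gqa(args):
#       logger.info(doc["question"]["type_string"])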



# --------------------------------------------------------------------------
# Dataset helpers
# --------------------------------------------------------------------------

def StringDataset(s):
	"""Wrap a single string in a one-element tf.data.Dataset."""

	def generator():
		yield s

	return tf.data.Dataset.from_generator(generator, tf.string, tf.TensorShape([]))
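
# Sketch: StringDataset is handy for zipping a constant string (e.g. a split
# name) alongside another dataset. Hypothetical usage:
#   tagged = tf.data.Dataset.zip((records, StringDataset("train").repeat()))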