import typing

import tensorflow as tf

# Names of the feature columns. get_dataset() pairs these, in this order,
# with the leading columns of each parsed CSV row, so the order matters.
FEATURES = [
    "fixed_acidity",
    "volatile_acidity",
    "citric_acid",
    "residual_sugar",
    "chlorides",
    "free_sulfur_dioxide",
    "total_sulfur_dioxide",
    "density",
    "pH",
    "sulphates",
    "alcohol",
]

# Name of the label column (the last column of each row).
LABEL = "quality"


def get_dataset(
    path: str,
    train_fraction: float = 0.7,
    split: str = "train"
) -> tf.data.Dataset:
    """Build the train or test split of a ``;``-delimited CSV as a dataset.

    Every row is assigned to one of 1000 buckets by hashing the concatenated
    string form of all its columns (label included). Rows whose bucket id is
    below ``train_fraction * 1000`` make up the training split; all other
    rows make up the test split. Membership depends only on a row's
    contents, so identical rows always end up in the same split.

    Args:
        path: Location of the CSV file. The file must have a header row and
            use ``;`` as the field delimiter, with ``len(FEATURES)`` float
            columns followed by one integer label column.
        train_fraction: Approximate fraction of rows routed to the training
            split. The actual fraction depends on how the row hashes are
            distributed over the buckets.
        split: Which split to return, either ``"train"`` or ``"test"``.

    Returns:
        A dataset of ``(features, label)`` pairs, where ``features`` is a
        dict mapping each name in ``FEATURES`` to a ``tf.float32`` scalar
        and ``label`` is a ``tf.int32`` scalar.

    Raises:
        ValueError: If ``split`` is neither ``"train"`` nor ``"test"``.
    """
    # Validate up front so a bad argument fails before any dataset is built.
    if split not in ("train", "test"):
        raise ValueError("Unknown option split, must be 'train' or 'test'")

    num_buckets = 1000
    # Round instead of truncating: a binary floating-point product can land
    # just below the intended integer, and int() would then drop a bucket.
    train_buckets = int(round(train_fraction * num_buckets))

    def split_label(*row):
        # The last column is the label; the preceding ones are the features.
        return dict(zip(FEATURES, row[:-1])), row[-1]

    def in_training_set(*row):
        # Hash the textual form of the whole row into a bucket id.
        key = tf.strings.join([tf.as_string(column) for column in row])
        bucket_id = tf.strings.to_hash_bucket_fast(key, num_buckets)
        return bucket_id < train_buckets

    def in_test_set(*row):
        # `~` on a boolean tensor is an element-wise logical NOT.
        return ~in_training_set(*row)

    data = tf.data.experimental.CsvDataset(
        path,
        [tf.float32] * len(FEATURES) + [tf.int32],
        header=True,
        field_delim=";")

    keep_row = in_training_set if split == "train" else in_test_set
    return data.filter(keep_row).map(split_label)


def get_feature_columns():
    """Return one numeric feature column per name in ``FEATURES``.

    The returned list follows the order of ``FEATURES``.
    """
    make_column = tf.feature_column.numeric_column
    return list(map(make_column, FEATURES))


def get_n_classes():
    """Return the fixed number of label classes (10).

    NOTE(review): presumably this covers integer quality scores 0-9;
    confirm against the data.
    """
    n_classes = 10
    return n_classes