# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Common data and utilities for tf_metadata tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tensorflow_transform.tf_metadata import dataset_schema as sch
from tensorflow_transform.tf_metadata import schema_utils


# Feature spec shared by the tf_metadata tests: one entry per combination of
# feature kind (fixed-length / variable-length) and dtype.
test_feature_spec = {
    # Fixed-length features: one scalar (shape []) and three length-5 vectors.
    'fixed_categorical_int_with_range': tf.io.FixedLenFeature([], tf.int64),
    'fixed_int': tf.io.FixedLenFeature([5], tf.int64),
    'fixed_float': tf.io.FixedLenFeature([5], tf.float32),
    'fixed_string': tf.io.FixedLenFeature([5], tf.string),

    # Variable-length features, one per dtype.
    'var_int': tf.io.VarLenFeature(tf.int64),
    'var_float': tf.io.VarLenFeature(tf.float32),
    'var_string': tf.io.VarLenFeature(tf.string),
}


def get_test_schema():
  """Return the schema derived from the module-level `test_feature_spec`.

  The conversion is delegated to `schema_utils.schema_from_feature_spec`; a
  new call is made on every invocation.
  """
  derived_schema = schema_utils.schema_from_feature_spec(test_feature_spec)
  return derived_schema


# Hand-written column schemas, keyed by the same feature names used in
# `test_feature_spec`.
# NOTE(review): 'fixed_categorical_int_with_range' carries an IntDomain here
# but has no domain information in `test_feature_spec` -- confirm whether the
# two schemas are expected to compare equal.
_COLUMN_SCHEMAS = {
    # Columns with a fixed representation: one scalar (axes []) and three
    # length-5 vectors.
    'fixed_categorical_int_with_range':
        sch.ColumnSchema(
            # Presumably: int64 values in [-5, 10], flagged as categorical --
            # TODO confirm against the IntDomain signature.
            sch.IntDomain(tf.int64, -5, 10, True), [],
            sch.FixedColumnRepresentation()),
    'fixed_int':
        sch.ColumnSchema(tf.int64, [5], sch.FixedColumnRepresentation()),
    'fixed_float':
        sch.ColumnSchema(tf.float32, [5], sch.FixedColumnRepresentation()),
    'fixed_string':
        sch.ColumnSchema(tf.string, [5], sch.FixedColumnRepresentation()),

    # Columns with a list representation; axes are left as None.
    'var_int':
        sch.ColumnSchema(tf.int64, None, sch.ListColumnRepresentation()),
    'var_float':
        sch.ColumnSchema(tf.float32, None, sch.ListColumnRepresentation()),
    'var_string':
        sch.ColumnSchema(tf.string, None, sch.ListColumnRepresentation()),
}


def get_manually_created_schema():
  """Return a `sch.Schema` wrapping the hand-written `_COLUMN_SCHEMAS`.

  Unlike a schema derived from a feature spec, this one is assembled directly
  from the Schema classes. A new `sch.Schema` is constructed on every call,
  always from the same module-level `_COLUMN_SCHEMAS` dict.
  """
  manual_schema = sch.Schema(_COLUMN_SCHEMAS)
  return manual_schema