"""Input feature columns and input_fn for models. Handles both training, evaluation and inference. """ import tensorflow as tf def BuildTextExample(text, ngrams=None, label=None): record = tf.train.Example() text = [tf.compat.as_bytes(x) for x in text] record.features.feature["text"].bytes_list.value.extend(text) if label is not None: label = tf.compat.as_bytes(label) record.features.feature["label"].bytes_list.value.append(label) if ngrams is not None: ngrams = [tf.compat.as_bytes(x) for x in ngrams] record.features.feature["ngrams"].bytes_list.value.extend(ngrams) return record def ParseSpec(use_ngrams, include_target): parse_spec = {"text": tf.VarLenFeature(dtype=tf.string)} if use_ngrams: parse_spec["ngrams"] = tf.VarLenFeature(dtype=tf.string) if include_target: parse_spec["label"] = tf.FixedLenFeature(shape=(), dtype=tf.string, default_value=None) return parse_spec def InputFn(mode, use_ngrams, input_file, vocab_file, vocab_size, embedding_dimension, num_oov_vocab_buckets, label_file, label_size, ngram_embedding_dimension, num_ngram_hash_buckets, batch_size, num_epochs=None, num_threads=1): if num_epochs <= 0: num_epochs=None def input_fn(): include_target = mode != tf.estimator.ModeKeys.PREDICT parse_spec = ParseSpec(use_ngrams, include_target) print("ParseSpec", parse_spec) print("Input file:", input_file) features = tf.contrib.learn.read_batch_features( input_file, batch_size, parse_spec, tf.TFRecordReader, num_epochs=num_epochs, reader_num_threads=num_threads) label = None if include_target: label = features.pop("label") return features, label return input_fn def ServingInputFn(use_ngrams): parse_spec = ParseSpec(use_ngrams, include_target=False) return tf.estimator.export.build_parsing_serving_input_receiver_fn( parse_spec)