Python tensorflow_transform.tf_metadata.dataset_schema.Schema() Examples

The following are 15 code examples of tensorflow_transform.tf_metadata.dataset_schema.Schema(), drawn from open-source projects; the source file and originating project are noted above each example. You may also want to check out all available functions/classes of the module tensorflow_transform.tf_metadata.dataset_schema.
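All of the snippets below use the deprecated pre-0.14 tf_metadata API. As a minimal orientation sketch (the feature names are made up, and as_feature_spec is assumed to be available in your TFT 0.x release):

import tensorflow as tf
from tensorflow_transform.tf_metadata import dataset_schema

# A Schema is built from a dict mapping column names to ColumnSchemas.
schema = dataset_schema.Schema({
    'age': dataset_schema.ColumnSchema(
        tf.float32, [], dataset_schema.FixedColumnRepresentation()),
    'occupation': dataset_schema.ColumnSchema(
        tf.string, [], dataset_schema.FixedColumnRepresentation()),
})

# The schema round-trips to a tf.parse_example-style feature spec.
print(schema.as_feature_spec())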
Example #1
Source File: input_metadata.py    From cloudml-samples with Apache License 2.0
def _create_raw_metadata():
  """Create a DatasetMetadata for the raw data."""
  column_schemas = {
      key: dataset_schema.ColumnSchema(
          tf.string, [], dataset_schema.FixedColumnRepresentation())
      for key in CATEGORICAL_FEATURE_KEYS
  }
  column_schemas.update({
      key: dataset_schema.ColumnSchema(
          tf.float32, [], dataset_schema.FixedColumnRepresentation())
      for key in NUMERIC_FEATURE_KEYS
  })
  column_schemas[LABEL_KEY] = dataset_schema.ColumnSchema(
      tf.string, [], dataset_schema.FixedColumnRepresentation())

  raw_data_metadata = dataset_metadata.DatasetMetadata(dataset_schema.Schema(
      column_schemas))
  return raw_data_metadata 
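A hedged usage sketch for the function above; the three feature-key constants are hypothetical stand-ins for the ones defined elsewhere in input_metadata.py:

CATEGORICAL_FEATURE_KEYS = ['workclass', 'education']  # hypothetical
NUMERIC_FEATURE_KEYS = ['age', 'hours_per_week']  # hypothetical
LABEL_KEY = 'income_bracket'  # hypothetical

# Each categorical key and the label become scalar tf.string columns;
# each numeric key becomes a scalar tf.float32 column.
raw_data_metadata = _create_raw_metadata()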
Example #2
Source File: executor.py    From tfx with Apache License 2.0
def _GetSchemaProto(
    metadata: dataset_metadata.DatasetMetadata) -> schema_pb2.Schema:
  """Gets the schema proto associated with a DatasetMetadata.

  This is needed because tensorflow_transform 0.13 and tensorflow_transform
  0.14 have different APIs for DatasetMetadata.

  Args:
    metadata: A dataset_metadata.DatasetMetadata.

  Returns:
    A schema_pb2.Schema.
  """
  # `schema` is either a Schema proto or dataset_schema.Schema.
  schema = metadata.schema
  # In the case where it's a dataset_schema.Schema, fetch the schema proto.
  return getattr(schema, '_schema_proto', schema) 
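As the inline comments note, in 0.13 metadata.schema is a dataset_schema.Schema that keeps the proto in a private _schema_proto attribute, while in 0.14+ it is the schema_pb2.Schema itself, so the getattr call covers both cases. A sketch of a call site, assuming the imports from executor.py:

metadata = dataset_metadata.DatasetMetadata(dataset_schema.Schema({}))
schema_proto = _GetSchemaProto(metadata)  # always a schema_pb2.Schema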
Example #3
Source File: executor.py    From tfx with Apache License 2.0
def _ReadMetadata(self, data_format: Text,
                    schema_path: Text) -> dataset_metadata.DatasetMetadata:
    """Returns a dataset_metadata.DatasetMetadata for the input data.

    Args:
      data_format: name of the input data format.
      schema_path: path to schema file.

    Returns:
      A dataset_metadata.DatasetMetadata representing the provided set of
          columns.
    """

    if self._ShouldDecodeAsRawExample(data_format):
      return dataset_metadata.DatasetMetadata(_RAW_EXAMPLE_SCHEMA)
    schema_proto = self._GetSchema(schema_path)
    # For compatibility with tensorflow_transform 0.13 and 0.14, we create and
    # then update a DatasetMetadata.
    result = dataset_metadata.DatasetMetadata(dataset_schema.Schema({}))
    _GetSchemaProto(result).CopyFrom(schema_proto)
    return result 
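Note the version-compatibility trick in the last three lines: rather than constructing a DatasetMetadata from the proto directly, the code creates an empty metadata object and overwrites its underlying proto in place with CopyFrom, which works under both the 0.13 and 0.14 APIs.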
Example #4
Source File: executor.py    From tfx with Apache License 2.0
def _GetDecodeFunction(self, data_format: Union[Text, int],
                         schema: dataset_schema.Schema) -> Any:
    """Returns the decode function for `data_format`.

    Args:
      data_format: name of the data format.
      schema: a dataset_schema.Schema for the data.

    Returns:
      Function for decoding examples.
    """
    if self._ShouldDecodeAsRawExample(data_format):
      if self._IsDataFormatSequenceExample(data_format):
        absl.logging.warning(
            'TFX Transform doesn\'t officially support tf.SequenceExample, '
            'follow b/38235367 to track official support progress. We do not '
            'guarantee not to break your pipeline if you use Transform with a '
            'tf.SequenceExample data type. Use at your own risk.')
      return lambda x: {RAW_EXAMPLE_KEY: x}
    else:
      return tft.coders.ExampleProtoCoder(schema, serialized=True).decode 
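For the non-raw path, the returned callable decodes a serialized tf.Example into a dict of feature values. A hedged standalone sketch, assuming schema is a dataset_schema.Schema and serialized_example holds tf.Example bytes:

import tensorflow_transform as tft

coder = tft.coders.ExampleProtoCoder(schema, serialized=True)
instance_dict = coder.decode(serialized_example)  # feature name -> value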
Example #5
Source File: dataset_metadata.py    From transform with Apache License 2.0
def __init__(self, schema):
    if isinstance(schema, dict):
      schema = dataset_schema.Schema(schema)
    self._schema = schema 
Example #6
Source File: dataset_schema_test.py    From transform with Apache License 2.0
def test_feature_spec_unsupported_dtype(self):
    with self.assertRaisesRegexp(ValueError, 'invalid dtype'):
      sch.Schema({
          'fixed_float': sch.ColumnSchema(
              tf.float64, [], sch.FixedColumnRepresentation())
      }) 
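This API accepts only a narrow set of dtypes; the examples on this page use just tf.string, tf.int64, and tf.float32, and this test confirms that tf.float64 is rejected with a ValueError.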
Example #7
Source File: dataset_schema_test.py    From transform with Apache License 2.0
def test_schema_equality(self):
    schema1 = sch.Schema(column_schemas={
        'fixed_int': sch.ColumnSchema(
            tf.int64, [2], sch.FixedColumnRepresentation()),
        'var_float': sch.ColumnSchema(
            tf.float32, None, sch.ListColumnRepresentation())
    })
    schema2 = sch.Schema(column_schemas={
        'fixed_int': sch.ColumnSchema(
            tf.int64, [2], sch.FixedColumnRepresentation()),
        'var_float': sch.ColumnSchema(
            tf.float32, None, sch.ListColumnRepresentation())
    })
    schema3 = sch.Schema(column_schemas={
        'fixed_int': sch.ColumnSchema(
            tf.int64, [2], sch.FixedColumnRepresentation()),
        'var_float': sch.ColumnSchema(
            tf.string, None, sch.ListColumnRepresentation())
    })
    schema4 = sch.Schema(column_schemas={
        'fixed_int': sch.ColumnSchema(
            tf.int64, [2], sch.FixedColumnRepresentation())
    })

    self.assertEqual(schema1, schema2)
    self.assertNotEqual(schema1, schema3)
    self.assertNotEqual(schema1, schema4) 
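Equality is structural: schema1 and schema2 match column-for-column, schema3 differs only in the dtype of var_float (tf.string instead of tf.float32), and schema4 omits var_float entirely.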
Example #8
Source File: test_common.py    From transform with Apache License 2.0
def get_manually_created_schema():
  """Provide a test schema built from scratch using the Schema classes."""
  return sch.Schema(_COLUMN_SCHEMAS) 
Example #9
Source File: pipeline.py    From realtime-embeddings-matching with Apache License 2.0
def get_metadata():
  from tensorflow_transform.tf_metadata import dataset_schema
  from tensorflow_transform.tf_metadata import dataset_metadata

  metadata = dataset_metadata.DatasetMetadata(dataset_schema.Schema({
    'id': dataset_schema.ColumnSchema(
      tf.string, [], dataset_schema.FixedColumnRepresentation()),
    'text': dataset_schema.ColumnSchema(
      tf.string, [], dataset_schema.FixedColumnRepresentation())
  }))
  return metadata 
Example #10
Source File: task.py    From pipelines with Apache License 2.0
def make_tft_input_metadata(schema):
  """Create tf-transform metadata from given schema."""
  tft_schema = {}

  for col_schema in schema:
    col_type = col_schema['type']
    col_name = col_schema['name']
    if col_type == 'NUMBER':
      tft_schema[col_name] = dataset_schema.ColumnSchema(
          tf.float32, [], dataset_schema.FixedColumnRepresentation(default_value=0.0))
    elif col_type in ['CATEGORY', 'TEXT', 'IMAGE_URL', 'KEY']:
      tft_schema[col_name] = dataset_schema.ColumnSchema(
          tf.string, [], dataset_schema.FixedColumnRepresentation(default_value=''))
  return dataset_metadata.DatasetMetadata(dataset_schema.Schema(tft_schema)) 
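A hedged usage sketch; the expected input format (a list of dicts with 'name' and 'type' keys) is inferred from the loop above:

schema = [
    {'name': 'age', 'type': 'NUMBER'},
    {'name': 'occupation', 'type': 'CATEGORY'},
]
metadata = make_tft_input_metadata(schema)
# 'age' becomes tf.float32 with default 0.0; 'occupation' becomes
# tf.string with default ''. Column types outside the two handled sets
# are silently skipped.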
Example #11
Source File: executor.py    From tfx with Apache License 2.0
def _GetSchema(self, schema_path: Text) -> schema_pb2.Schema:
    """Gets a tf.metadata schema.

    Args:
      schema_path: Path to schema file.

    Returns:
      A tf.metadata schema.
    """
    schema_reader = io_utils.SchemaReader()
    return schema_reader.read(schema_path) 
Example #12
Source File: executor.py    From tfx with Apache License 2.0
def _GenerateStats(
      pcoll: beam.pvalue.PCollection,
      stats_output_path: Text,
      schema: schema_pb2.Schema,
      stats_options: tfdv.StatsOptions,
  ) -> beam.pvalue.PDone:
    """Generates statistics.

    Args:
      pcoll: PCollection of examples.
      stats_output_path: Path where statistics are written.
      schema: A schema_pb2.Schema for the input examples.
      stats_options: An instance of `tfdv.StatsOptions` used when computing
        statistics.

    Returns:
      beam.pvalue.PDone.
    """
    def _FilterInternalColumn(record_batch):
      filtered_column_names = []
      filtered_columns = []
      for i, column_name in enumerate(record_batch.schema.names):
        if column_name != _TRANSFORM_INTERNAL_FEATURE_FOR_KEY:
          filtered_column_names.append(column_name)
          filtered_columns.append(record_batch.column(i))
      return pa.RecordBatch.from_arrays(filtered_columns, filtered_column_names)

    pcoll |= 'FilterInternalColumn' >> beam.Map(_FilterInternalColumn)
    stats_options.schema = schema
    # pylint: disable=no-value-for-parameter
    return (
        pcoll
        | 'GenerateStatistics' >> tfdv.GenerateStatistics(stats_options)
        | 'WriteStats' >> Executor._WriteStats(stats_output_path))

  # TODO(b/150456345): Obviate this once TFXIO-in-Transform rollout is
  # completed. 
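The nested _FilterInternalColumn helper drops the synthetic key column from each Arrow RecordBatch before statistics are computed. A self-contained sketch of the same idea, with a hypothetical column name standing in for _TRANSFORM_INTERNAL_FEATURE_FOR_KEY:

import pyarrow as pa

_KEY = '__internal_key__'  # hypothetical stand-in for the real constant
batch = pa.RecordBatch.from_arrays(
    [pa.array([1.0, 2.0]), pa.array(['k1', 'k2'])], ['feature', _KEY])
kept = [(name, batch.column(i))
        for i, name in enumerate(batch.schema.names) if name != _KEY]
filtered = pa.RecordBatch.from_arrays(
    [col for _, col in kept], [name for name, _ in kept])
print(filtered.schema.names)  # ['feature']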
Example #13
Source File: executor.py    From tfx with Apache License 2.0
def __init__(self, schema: Optional[schema_pb2.Schema]):
      self._serialized_schema = schema.SerializeToString() if schema else None 
Example #14
Source File: executor.py    From tfx with Apache License 2.0
def process(self, element: Dict[Text, Any], schema: schema_pb2.Schema
               ) -> Generator[Tuple[Any, Any], None, None]:
      if self._coder is None:
        self._coder = tft.coders.ExampleProtoCoder(schema, serialized=True)

      # Make sure that the synthetic key feature doesn't get encoded.
      key = element.get(_TRANSFORM_INTERNAL_FEATURE_FOR_KEY, None)
      if key is not None:
        element = element.copy()
        del element[_TRANSFORM_INTERNAL_FEATURE_FOR_KEY]
      yield (key, self._coder.encode(element)) 
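Examples #13 and #14 appear to be methods of the same Beam DoFn: the constructor keeps the schema in serialized form so the DoFn stays cheaply picklable, and process builds the ExampleProtoCoder lazily on first use, a common Beam pattern for per-worker state that is expensive or awkward to serialize.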
Example #15
Source File: executor.py    From tfx with Apache License 2.0
def _CreateTFXIO(self, dataset: _Dataset,
                   schema: schema_pb2.Schema) -> tfxio.TFXIO:
    """Creates a TFXIO instance for `dataset`."""
    if self._ShouldDecodeAsRawExample(dataset.data_format):
      return raw_tf_record.RawTfRecordTFXIO(
          file_pattern=dataset.file_pattern,
          raw_record_column_name=RAW_EXAMPLE_KEY,
          telemetry_descriptors=[_TRANSFORM_COMPONENT_DESCRIPTOR])
    else:
      return tf_example_record.TFExampleRecord(
          file_pattern=dataset.file_pattern,
          validate=False,
          telemetry_descriptors=[_TRANSFORM_COMPONENT_DESCRIPTOR],
          schema=schema)
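The two branches mirror Example #4: for raw formats each record is passed through unparsed as a bytes column named by RAW_EXAMPLE_KEY, while the schema-aware TFExampleRecord path parses serialized tf.Examples into Arrow RecordBatches according to the given schema_pb2.Schema.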