# Copyright (c) 2017-2018 Uber Technologies, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import division from decimal import Decimal import numpy as np import pyarrow as pa import pyarrow.parquet as pq import pytest from pyspark import Row from pyspark.sql.types import StringType, IntegerType, DecimalType, ShortType, LongType from petastorm.codecs import ScalarCodec, NdarrayCodec from petastorm.unischema import Unischema, UnischemaField, dict_to_spark_row, \ insert_explicit_nulls, match_unischema_fields, _new_gt_255_compatible_namedtuple, _fullmatch try: from unittest import mock except ImportError: from mock import mock def _mock_parquet_dataset(partitions, arrow_schema): """Creates a pyarrow.ParquetDataset mock capable of returning: parquet_dataset.pieces[0].get_metadata(parquet_dataset.fs.open).schema.to_arrow_schema() == schema parquet_dataset.partitions = partitions :param partitions: expected to be a list of pa.parquet.PartitionSet :param arrow_schema: an instance of pa.arrow_schema to be assumed by the mock parquet dataset object. :return: """ piece_mock = mock.Mock() piece_mock.get_metadata().schema.to_arrow_schema.return_value = arrow_schema dataset_mock = mock.Mock() type(dataset_mock).pieces = mock.PropertyMock(return_value=[piece_mock]) type(dataset_mock).partitions = partitions return dataset_mock def test_fields(): """Try using 'fields' getter""" TestSchema = Unischema('TestSchema', [ UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False), UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False), ]) assert len(TestSchema.fields) == 2 assert TestSchema.fields['int_field'].name == 'int_field' assert TestSchema.fields['string_field'].name == 'string_field' def test_as_spark_schema(): """Try using 'as_spark_schema' function""" TestSchema = Unischema('TestSchema', [ UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False), UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False), UnischemaField('string_field_implicit', np.string_, ()), ]) spark_schema = TestSchema.as_spark_schema() assert spark_schema.fields[0].name == 'int_field' assert spark_schema.fields[1].name == 'string_field' assert spark_schema.fields[1].dataType == StringType() assert spark_schema.fields[2].name == 'string_field_implicit' assert spark_schema.fields[2].dataType == StringType() assert TestSchema.fields['int_field'].name == 'int_field' assert TestSchema.fields['string_field'].name == 'string_field' def test_as_spark_schema_unspecified_codec_type_for_non_scalars_raises(): """Do not currently support choosing spark type automatically for non-scalar types.""" TestSchema = Unischema('TestSchema', [ UnischemaField('int_vector_unspecified_codec', np.int8, (1,)), ]) with pytest.raises(ValueError, match='has codec set to None'): TestSchema.as_spark_schema() def test_as_spark_schema_unspecified_codec_type_unknown_scalar_type_raises(): """We have a limited list of scalar types we can automatically map from numpy (+Decimal) types to spark types. Make sure that a ValueError is raised if an unknown type is used.""" TestSchema = Unischema('TestSchema', [ UnischemaField('int_vector_unspecified_codec', object, ()), ]) with pytest.raises(ValueError, match='Was not able to map type'): TestSchema.as_spark_schema() def test_dict_to_spark_row_field_validation_scalar_types(): """Test various validations done on data types when converting a dictionary to a spark row""" TestSchema = Unischema('TestSchema', [ UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False), ]) assert isinstance(dict_to_spark_row(TestSchema, {'string_field': 'abc'}), Row) # Not a nullable field with pytest.raises(ValueError): isinstance(dict_to_spark_row(TestSchema, {'string_field': None}), Row) # Wrong field type with pytest.raises(TypeError): isinstance(dict_to_spark_row(TestSchema, {'string_field': []}), Row) def test_dict_to_spark_row_field_validation_scalar_nullable(): """Test various validations done on data types when converting a dictionary to a spark row""" TestSchema = Unischema('TestSchema', [ UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), True), UnischemaField('nullable_implicitly_set', np.string_, (), ScalarCodec(StringType()), True), ]) assert isinstance(dict_to_spark_row(TestSchema, {'string_field': None}), Row) def test_dict_to_spark_row_field_validation_ndarrays(): """Test various validations done on data types when converting a dictionary to a spark row""" TestSchema = Unischema('TestSchema', [ UnischemaField('tensor3d', np.float32, (10, 20, 30), NdarrayCodec(), False), ]) assert isinstance(dict_to_spark_row(TestSchema, {'tensor3d': np.zeros((10, 20, 30), dtype=np.float32)}), Row) # Null value into not nullable field with pytest.raises(ValueError): isinstance(dict_to_spark_row(TestSchema, {'string_field': None}), Row) # Wrong dimensions with pytest.raises(ValueError): isinstance(dict_to_spark_row(TestSchema, {'string_field': np.zeros((1, 2, 3), dtype=np.float32)}), Row) def test_dict_to_spark_row_order(): TestSchema = Unischema('TestSchema', [ UnischemaField('float_col', np.float64, ()), UnischemaField('int_col', np.int64, ()), ]) row_dict = { TestSchema.int_col.name: 3, TestSchema.float_col.name: 2.0, } spark_row = dict_to_spark_row(TestSchema, row_dict) schema_field_names = list(TestSchema.fields) assert spark_row[0] == row_dict[schema_field_names[0]] assert spark_row[1] == row_dict[schema_field_names[1]] def test_make_named_tuple(): TestSchema = Unischema('TestSchema', [ UnischemaField('string_scalar', np.string_, (), ScalarCodec(StringType()), True), UnischemaField('int32_scalar', np.int32, (), ScalarCodec(ShortType()), False), UnischemaField('uint8_scalar', np.uint8, (), ScalarCodec(ShortType()), False), UnischemaField('int32_matrix', np.float32, (10, 20, 3), NdarrayCodec(), True), UnischemaField('decimal_scalar', Decimal, (10, 20, 3), ScalarCodec(DecimalType(10, 9)), False), ]) TestSchema.make_namedtuple(string_scalar='abc', int32_scalar=10, uint8_scalar=20, int32_matrix=np.int32((10, 20, 3)), decimal_scalar=Decimal(123) / Decimal(10)) TestSchema.make_namedtuple(string_scalar=None, int32_scalar=10, uint8_scalar=20, int32_matrix=None, decimal_scalar=Decimal(123) / Decimal(10)) def test_insert_explicit_nulls(): TestSchema = Unischema('TestSchema', [ UnischemaField('nullable', np.int32, (), ScalarCodec(StringType()), True), UnischemaField('not_nullable', np.int32, (), ScalarCodec(ShortType()), False), ]) # Insert_explicit_nulls to leave the dictionary as is. row_dict = {'nullable': 0, 'not_nullable': 1} insert_explicit_nulls(TestSchema, row_dict) assert len(row_dict) == 2 assert row_dict['nullable'] == 0 assert row_dict['not_nullable'] == 1 # Insert_explicit_nulls to leave the dictionary as is. row_dict = {'nullable': None, 'not_nullable': 1} insert_explicit_nulls(TestSchema, row_dict) assert len(row_dict) == 2 assert row_dict['nullable'] is None assert row_dict['not_nullable'] == 1 # We are missing a nullable field here. insert_explicit_nulls should add a None entry. row_dict = {'not_nullable': 1} insert_explicit_nulls(TestSchema, row_dict) assert len(row_dict) == 2 assert row_dict['nullable'] is None assert row_dict['not_nullable'] == 1 # We are missing a not_nullable field here. Should raise an ValueError. row_dict = {'nullable': 0} with pytest.raises(ValueError): insert_explicit_nulls(TestSchema, row_dict) def test_create_schema_view_fails_validate(): """ Exercises code paths unischema.create_schema_view ValueError, and unischema.__str__.""" TestSchema = Unischema('TestSchema', [ UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False), UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False), ]) with pytest.raises(ValueError, match='does not belong to the schema'): TestSchema.create_schema_view([UnischemaField('id', np.int64, (), ScalarCodec(LongType()), False)]) def test_create_schema_view_using_invalid_type(): """ Exercises code paths unischema.create_schema_view ValueError, and unischema.__str__.""" TestSchema = Unischema('TestSchema', [ UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False), UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False), ]) with pytest.raises(ValueError, match='must be either a string'): TestSchema.create_schema_view([42]) def test_create_schema_view_using_unischema_fields(): TestSchema = Unischema('TestSchema', [ UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False), UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False), ]) view = TestSchema.create_schema_view([TestSchema.int_field]) assert set(view.fields.keys()) == {'int_field'} def test_create_schema_view_using_regex(): TestSchema = Unischema('TestSchema', [ UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False), UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False), ]) view = TestSchema.create_schema_view(['int.*$']) assert set(view.fields.keys()) == {'int_field'} view = TestSchema.create_schema_view([u'int.*$']) assert set(view.fields.keys()) == {'int_field'} def test_create_schema_view_using_regex_and_unischema_fields(): TestSchema = Unischema('TestSchema', [ UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False), UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False), UnischemaField('other_string_field', np.string_, (), ScalarCodec(StringType()), False), ]) view = TestSchema.create_schema_view(['int.*$', TestSchema.string_field]) assert set(view.fields.keys()) == {'int_field', 'string_field'} def test_create_schema_view_using_regex_and_unischema_fields_with_duplicates(): TestSchema = Unischema('TestSchema', [ UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False), UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False), UnischemaField('other_string_field', np.string_, (), ScalarCodec(StringType()), False), ]) view = TestSchema.create_schema_view(['int.*$', TestSchema.int_field]) assert set(view.fields.keys()) == {'int_field'} def test_create_schema_view_no_field_matches_regex(): TestSchema = Unischema('TestSchema', [ UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False), UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False), ]) view = TestSchema.create_schema_view(['bogus']) assert not view.fields def test_name_property(): TestSchema = Unischema('TestSchema', [ UnischemaField('nullable', np.int32, (), ScalarCodec(StringType()), True), ]) assert 'TestSchema' == TestSchema._name def test_field_name_conflict_with_unischema_attribute(): # fields is an existing attribute of Unischema with pytest.warns(UserWarning, match='Can not create dynamic property'): Unischema('TestSchema', [UnischemaField('fields', np.int32, (), ScalarCodec(StringType()), True)]) def test_match_unischema_fields(): TestSchema = Unischema('TestSchema', [ UnischemaField('int32', np.int32, (), None, False), UnischemaField('uint8', np.uint8, (), None, False), UnischemaField('uint16', np.uint16, (), None, False), ]) assert match_unischema_fields(TestSchema, ['.*nt.*6']) == [TestSchema.uint16] assert match_unischema_fields(TestSchema, ['nomatch']) == [] assert set(match_unischema_fields(TestSchema, ['.*'])) == set(TestSchema.fields.values()) assert set(match_unischema_fields(TestSchema, ['int32', 'uint8'])) == {TestSchema.int32, TestSchema.uint8} def test_match_unischema_fields_legacy_warning(): TestSchema = Unischema('TestSchema', [ UnischemaField('int32', np.int32, (), None, False), UnischemaField('uint8', np.uint8, (), None, False), UnischemaField('uint16', np.uint16, (), None, False), ]) # Check that no warnings are shown if the legacy and the new way of filtering produce the same results. with pytest.warns(None) as unexpected_warnings: match_unischema_fields(TestSchema, ['uint8']) assert not unexpected_warnings # uint8 and uint16 would have been matched using the old method, but not the new one with pytest.warns(UserWarning, match=r'schema_fields behavior has changed.*uint16, uint8'): assert match_unischema_fields(TestSchema, ['uint']) == [] # Now, all fields will be matched, but in different order (legacy vs current). Make sure we don't issue a warning. with pytest.warns(None) as unexpected_warnings: match_unischema_fields(TestSchema, ['int', 'uint8', 'uint16', 'int32']) assert not unexpected_warnings def test_arrow_schema_convertion(): fields = [ pa.field('string', pa.string()), pa.field('int8', pa.int8()), pa.field('int16', pa.int16()), pa.field('int32', pa.int32()), pa.field('int64', pa.int64()), pa.field('float', pa.float32()), pa.field('double', pa.float64()), pa.field('bool', pa.bool_(), False), pa.field('fixed_size_binary', pa.binary(10)), pa.field('variable_size_binary', pa.binary()), pa.field('decimal', pa.decimal128(3, 4)), pa.field('timestamp_s', pa.timestamp('s')), pa.field('timestamp_ns', pa.timestamp('ns')), pa.field('date_32', pa.date32()), pa.field('date_64', pa.date64()) ] arrow_schema = pa.schema(fields) mock_dataset = _mock_parquet_dataset([], arrow_schema) unischema = Unischema.from_arrow_schema(mock_dataset) for name in arrow_schema.names: assert getattr(unischema, name).name == name assert getattr(unischema, name).codec is None if name == 'bool': assert not getattr(unischema, name).nullable else: assert getattr(unischema, name).nullable # Test schema preserve fields order field_name_list = [f.name for f in fields] assert list(unischema.fields.keys()) == field_name_list def test_arrow_schema_convertion_with_string_partitions(): arrow_schema = pa.schema([ pa.field('int8', pa.int8()), ]) mock_dataset = _mock_parquet_dataset([pq.PartitionSet('part_name', ['a', 'b'])], arrow_schema) unischema = Unischema.from_arrow_schema(mock_dataset) assert unischema.part_name.numpy_dtype == np.str_ def test_arrow_schema_convertion_with_int_partitions(): arrow_schema = pa.schema([ pa.field('int8', pa.int8()), ]) mock_dataset = _mock_parquet_dataset([pq.PartitionSet('part_name', ['0', '1', '2'])], arrow_schema) unischema = Unischema.from_arrow_schema(mock_dataset) assert unischema.part_name.numpy_dtype == np.int64 def test_arrow_schema_convertion_fail(): arrow_schema = pa.schema([ pa.field('list_of_int', pa.float16()), ]) mock_dataset = _mock_parquet_dataset([], arrow_schema) with pytest.raises(ValueError, match='Cannot auto-create unischema due to unsupported column type'): Unischema.from_arrow_schema(mock_dataset) def test_arrow_schema_arrow_1644_list_of_struct(): arrow_schema = pa.schema([ pa.field('id', pa.string()), pa.field('list_of_struct', pa.list_(pa.struct([pa.field('a', pa.string()), pa.field('b', pa.int32())]))) ]) mock_dataset = _mock_parquet_dataset([], arrow_schema) unischema = Unischema.from_arrow_schema(mock_dataset) assert getattr(unischema, 'id').name == 'id' assert not hasattr(unischema, 'list_of_struct') def test_arrow_schema_arrow_1644_list_of_list(): arrow_schema = pa.schema([ pa.field('id', pa.string()), pa.field('list_of_list', pa.list_(pa.list_(pa.struct([pa.field('a', pa.string()), pa.field('b', pa.int32())])))) ]) mock_dataset = _mock_parquet_dataset([], arrow_schema) unischema = Unischema.from_arrow_schema(mock_dataset) assert getattr(unischema, 'id').name == 'id' assert not hasattr(unischema, 'list_of_list') def test_arrow_schema_convertion_ignore(): arrow_schema = pa.schema([ pa.field('list_of_int', pa.float16()), pa.field('struct', pa.struct([pa.field('a', pa.string()), pa.field('b', pa.int32())])), ]) mock_dataset = _mock_parquet_dataset([], arrow_schema) unischema = Unischema.from_arrow_schema(mock_dataset, omit_unsupported_fields=True) assert not hasattr(unischema, 'list_of_int') @pytest.fixture() def equality_fields(): class Fixture(object): string1 = UnischemaField('random', np.string_, (), ScalarCodec(StringType()), False) string2 = UnischemaField('random', np.string_, (), ScalarCodec(StringType()), False) string_implicit = UnischemaField('random', np.string_, ()) string_nullable = UnischemaField('random', np.string_, (), nullable=True) other_string = UnischemaField('Random', np.string_, (), ScalarCodec(StringType()), False) int1 = UnischemaField('id', np.int32, (), ScalarCodec(ShortType()), False) int2 = UnischemaField('id', np.int32, (), ScalarCodec(ShortType()), False) other_int = UnischemaField('ID', np.int32, (), ScalarCodec(ShortType()), False) return Fixture() def test_equality(equality_fields): # Use assertTrue instead of assertEqual/assertNotEqual so we don't depend on which operator (__eq__ or __ne__) # actual implementation of assert uses assert equality_fields.string1 == equality_fields.string2 assert equality_fields.string1 == equality_fields.string_implicit assert equality_fields.int1 == equality_fields.int2 assert equality_fields.string1 != equality_fields.other_string assert equality_fields.other_string != equality_fields.string_implicit assert equality_fields.int1 != equality_fields.other_int assert equality_fields.string_nullable != equality_fields.string_implicit def test_hash(equality_fields): assert hash(equality_fields.string1) == hash(equality_fields.string2) assert hash(equality_fields.int1) == hash(equality_fields.int2) assert hash(equality_fields.string1) != hash(equality_fields.other_string) assert hash(equality_fields.int1) != hash(equality_fields.other_int) def test_new_gt_255_compatible_namedtuple(): fields_count = 1000 field_names = ['f{}'.format(i) for i in range(fields_count)] values = list(range(1000)) huge_tuple = _new_gt_255_compatible_namedtuple('HUGE_TUPLE', field_names) huge_tuple_instance = huge_tuple(**dict(zip(field_names, values))) assert len(huge_tuple_instance) == fields_count assert huge_tuple_instance.f764 == 764 def test_fullmatch(): assert _fullmatch('abc', 'abc') assert _fullmatch('^abc', 'abc') assert _fullmatch('abc$', 'abc') assert _fullmatch('a.c', 'abc') assert _fullmatch('.*abcdef', 'abcdef') assert _fullmatch('abc.*', 'abcdef') assert _fullmatch('.*c.*', 'abcdef') assert _fullmatch('', '') assert not _fullmatch('abc', 'xyz') assert not _fullmatch('abc', 'abcx') assert not _fullmatch('abc', 'xabc')