from __future__ import absolute_import, division, print_function, unicode_literals from datetime import datetime from io import BytesIO from unittest import TestCase import pyarrow as pa import pyarrow.parquet as pq import sqlalchemy as sa from spectrify.utils.parquet import Writer class UncloseableBytesIO(BytesIO): """ pyarrow tries to close the BytesIO instance, which frees memory. This will trick it into thinking it has been closed... """ def __exit__(self, exc_type, exc_val, exc_tb): self.really_close() def close(self, *args, **kwargs): pass def really_close(self, *args, **kwargs): super(UncloseableBytesIO, self).close(*args, **kwargs) class TestParquetWriter(TestCase): def setUp(self): self.sa_meta = sa.MetaData() self.data = [ [17.124, 1.12, 3.14, 13.37], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [True, None, False, True], ['string 1', 'string 2', None, 'string 3'], [datetime(2007, 7, 13, 1, 23, 34, 123456), None, datetime(2006, 1, 13, 12, 34, 56, 432539), datetime(2010, 8, 13, 5, 46, 57, 437699), ], ["Test Text", "Some#More#Test# Text", "!@#$%%^&*&", None], ] self.table = sa.Table( 'unit_test_table', self.sa_meta, sa.Column('real_col', sa.REAL), sa.Column('bigint_col', sa.BIGINT), sa.Column('int_col', sa.INTEGER), sa.Column('smallint_col', sa.SMALLINT), sa.Column('bool_col', sa.BOOLEAN), sa.Column('str_col', sa.VARCHAR), sa.Column('timestamp_col', sa.TIMESTAMP), sa.Column('plaintext_col', sa.TEXT), ) self.expected_datatypes = [ pa.float32(), pa.int64(), pa.int32(), pa.int16(), pa.bool_(), pa.string(), pa.timestamp('ns'), pa.string(), ] def test_write(self): # Write out test file with UncloseableBytesIO() as write_buffer: with Writer(write_buffer, self.table) as writer: writer.write_row_group(self.data) file_bytes = write_buffer.getvalue() # Read in test file read_buffer = BytesIO(file_bytes) with pa.PythonFile(read_buffer, mode='r') as infile: # Verify data parq_table = pq.read_table(infile) written_data = list(parq_table.to_pydict().values()) tuples_by_data_type = zip(self.data, written_data) for i in tuples_by_data_type: tuples_by_order = zip(i[0], i[1]) for j in tuples_by_order: self.assertAlmostEquals(j[0], j[1], places=5) # Verify parquet file schema for i, field in enumerate(parq_table.schema): self.assertEqual(field.type.id, self.expected_datatypes[i].id) # Ensure timestamp column was written with int96; right now # there is no way to see except to check that the unit on # the timestamp type is 'ns' ts_col = parq_table.schema.field_by_name('timestamp_col') self.assertEqual(ts_col.type.unit, 'ns')