Python pyspark.sql.types.StructType() Examples

The following code examples show how to use pyspark.sql.types.StructType(). They are taken from open source Python projects and listed by community vote count.
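As a quick orientation before the examples, here is a minimal sketch of the pattern most of them follow: build a StructType from StructField entries and pass it as an explicit schema to createDataFrame. The session setup and column names below are illustrative only and are not taken from any of the projects that follow.

from pyspark.sql import SparkSession
from pyspark.sql import types as T

spark = SparkSession.builder.master('local[1]').appName('structtype-demo').getOrCreate()

# Each StructField is (name, data type, nullable flag).
schema = T.StructType([
    T.StructField('id', T.IntegerType(), nullable=False),
    T.StructField('name', T.StringType(), nullable=True),
])

df = spark.createDataFrame([(1, 'alice'), (2, None)], schema=schema)
df.printSchema()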

Example 1
Project: pb2df   Author: bridgewell   File: conftest.py    BSD 3-Clause "New" or "Revised" License 9 votes
def basic_msg_schema():
    schema = types.StructType([
        types.StructField('double_field', types.DoubleType()),
        types.StructField('float_field', types.FloatType()),
        types.StructField('int32_field', types.IntegerType()),
        types.StructField('int64_field', types.LongType()),
        types.StructField('uint32_field', types.IntegerType()),
        types.StructField('uint64_field', types.LongType()),
        types.StructField('sint32_field', types.IntegerType()),
        types.StructField('sint64_field', types.LongType()),
        types.StructField('fixed32_field', types.IntegerType()),
        types.StructField('fixed64_field', types.LongType()),
        types.StructField('sfixed32_field', types.IntegerType()),
        types.StructField('sfixed64_field', types.LongType()),
        types.StructField('bool_field', types.BooleanType()),
        types.StructField('string_field', types.StringType()),
        types.StructField('bytes_field', types.BinaryType()),
        types.StructField('enum_field', types.IntegerType()),
    ])
    return schema 
Example 2
Project: spark-deep-learning   Author: databricks   File: imageIO.py    Apache License 2.0 8 votes
def filesToDF(sc, path, numPartitions=None):
    """
    Read files from a directory to a DataFrame.

    :param sc: SparkContext.
    :param path: str, path to files.
    :param numPartitions: int, number of partitions to use for reading files.
    :return: DataFrame, with columns: (filePath: str, fileData: BinaryType)
    """
    numPartitions = numPartitions or sc.defaultParallelism
    schema = StructType([StructField("filePath", StringType(), False),
                         StructField("fileData", BinaryType(), False)])
    rdd = sc.binaryFiles(
        path, minPartitions=numPartitions).repartition(numPartitions)
    rdd = rdd.map(lambda x: (x[0], bytearray(x[1])))
    return rdd.toDF(schema) 
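A hypothetical call to filesToDF, assuming an active SparkContext named sc; the directory path is illustrative:

binary_df = filesToDF(sc, '/data/images', numPartitions=4)
binary_df.printSchema()
# root
#  |-- filePath: string (nullable = false)
#  |-- fileData: binary (nullable = false)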
Example 3
Project: search-MjoLniR   Author: wikimedia   File: transform.py    MIT License 7 votes
def _merge_schemas(*schemas: T.StructType):
    """Merge one or more spark schemas into a new schema"""
    fields = cast(Dict[str, T.StructField], {})
    errors = []
    for schema in schemas:
        for field in schema:
            if field.name not in fields:
                fields[field.name] = field
            elif field != fields[field.name]:
                errors.append('Incompatible fields: {} != {}'.format(field, fields[field.name]))
    if errors:
        raise Exception('\n'.join(errors))
    return T.StructType(list(fields.values()))


# Primary input schema from which most everything else is derived 
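A hypothetical use of _merge_schemas above, with illustrative field names; fields shared between schemas must be identical or an exception is raised:

from pyspark.sql import types as T

left = T.StructType([
    T.StructField('id', T.IntegerType()),
    T.StructField('name', T.StringType()),
])
right = T.StructType([
    T.StructField('id', T.IntegerType()),
    T.StructField('score', T.DoubleType()),
])

merged = _merge_schemas(left, right)
# StructType with fields: id, name, score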
Example 4
Project: sparkly   Author: tubular   File: test_writer.py    Apache License 2.0 7 votes
def test_db(self):
        df = self.spark.createDataFrame(
            data=[
                ('k1', 'k14', [1, 14, 141]),
            ],
            schema=T.StructType([
                T.StructField('key_1', T.StringType()),
                T.StructField('key_2', T.StringType()),
                T.StructField('aux_data', T.ArrayType(T.IntegerType())),
            ])
        )

        df.write_ext.by_url('redis://redis.docker/1?keyBy=key_2&maxPipelineSize=3')

        redis_client = redis.StrictRedis('redis.docker', db=1)

        self.assertEqual(redis_client.keys(), [b'k14'])

        written_data = json.loads(redis_client.get('k14'))
        expected = {'key_1': 'k1', 'key_2': 'k14', 'aux_data': [1, 14, 141]}
        self.assertEqual(written_data, expected) 
Example 5
Project: sparkly   Author: tubular   File: test_functions.py    Apache License 2.0 7 votes
def test_nullable_values(self):
        df = self.spark.createDataFrame(
            data=[
                ('1', 'test1', None, 3),
                ('1', None, 2, 4),
                ('2', 'test2', 3, 1),
                ('2', 'test3', 4, 2),
            ],
            schema=T.StructType([
                T.StructField('id', T.StringType(), nullable=True),
                T.StructField('value1', T.StringType(), nullable=True),
                T.StructField('value2', T.IntegerType(), nullable=True),
                T.StructField('target', T.IntegerType(), nullable=True),
            ]),
        )

        df = (
            df
            .groupBy('id')
            .agg(
                F.max('target').alias('target'),
                *[
                    SF.argmax(col, 'target').alias(col)
                    for col in df.columns
                    if col not in ['id', 'target']
                ]
            )
        )

        self.assertDataFrameEqual(
            df,
            [
                {'id': '1', 'target': 4, 'value1': None, 'value2': 2},
                {'id': '2', 'target': 2, 'value1': 'test3', 'value2': 4},
            ],
        ) 
Example 6
Project: search-MjoLniR   Author: wikimedia   File: transform.py    MIT License 6 votes
def _simplify_data_type(data_type: T.DataType) -> Tuple:
    """Simplify datatype into a tuple of equality information we care about

    Most notably this ignores nullability concerns due to hive not
    being able to represent not null in its schemas.
    """
    try:
        # Normalize UDT into its sql form. Allows comparison of schemas
        # from hive and spark.
        sql_type = data_type.sqlType()  # type: ignore
    except AttributeError:
        sql_type = data_type

    if isinstance(sql_type, T.StructType):
        return ('StructType', [(field.name, _simplify_data_type(field.dataType)) for field in sql_type])
    elif isinstance(sql_type, T.ArrayType):
        return ('ArrayType', _simplify_data_type(sql_type.elementType))
    else:
        return (type(sql_type).__name__,) 
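For reference, a hypothetical call on a small schema reduces it to nested tuples of type names, with nullability discarded (assuming _simplify_data_type above is in scope):

from pyspark.sql import types as T

schema = T.StructType([
    T.StructField('id', T.IntegerType(), nullable=False),
    T.StructField('tags', T.ArrayType(T.StringType())),
])

_simplify_data_type(schema)
# ('StructType', [('id', ('IntegerType',)), ('tags', ('ArrayType', ('StringType',)))])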
Example 7
Project: search-MjoLniR   Author: wikimedia   File: transform.py    MIT License 6 votes
def _verify_schema_compatability(expect: T.StructType, have: T.StructType) -> List[str]:
    """Verify all expected fields and types are present

    Allows additional columns in the `have` schema. Additionally
    allows relaxing nullability """
    errors = []
    for expect_field in expect:
        try:
            have_field = have[expect_field.name]
        except KeyError:
            errors.append('Field {} missing. Have: {}'.format(expect_field.name, ','.join(have.names)))
            continue
        expect_type = _simplify_data_type(expect_field.dataType)
        have_type = _simplify_data_type(have_field.dataType)
        if expect_type != have_type:
            errors.append('Field {} has incompatible data types: expect {} != have {}'.format(
                          expect_field.name, expect_type, have_type))
    return errors 
Example 8
Project: petastorm   Author: uber   File: unischema.py    Apache License 2.0 6 votes
def as_spark_schema(self):
        """Returns an object derived from the unischema as spark schema.

        Example:

        >>> spark.createDataFrame(dataset_rows,
        >>>                       SomeSchema.as_spark_schema())
        """
        # Lazy loading pyspark to avoid creating pyspark dependency on data reading code path
        # (currently works only with make_batch_reader)
        import pyspark.sql.types as sql_types

        schema_entries = []
        for field in self._fields.values():
            spark_type = _field_spark_dtype(field)
            schema_entries.append(sql_types.StructField(field.name, spark_type, field.nullable))

        return sql_types.StructType(schema_entries) 
Example 9
Project: sparkly   Author: tubular   File: utils.py    Apache License 2.0 6 votes
def _init_struct(*args):
    struct = T.StructType()
    for item in args:
        field_name, field_type = item.split(':', 1)
        field_type = parse_schema(field_type)
        struct.add(field_name, field_type)

    return struct 
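Assuming the parse_schema helper from the same module (shown in Example 39 below) is available, a hypothetical call builds a schema field by field via StructType.add:

struct = _init_struct('id:int', 'name:string')
# -> StructType with fields id (int) and name (string), both nullable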
Example 10
Project: sparkly   Author: tubular   File: test_functions.py    Apache License 2.0 6 votes
def test_inner_join(self):
        first_df = self.spark.createDataFrame(
            data=[(1, ), (2, ), (3, )],
            schema=T.StructType([T.StructField('id', T.IntegerType())]),
        )
        second_df = self.spark.createDataFrame(
            data=[(2, ), (3, ), (4, )],
            schema=T.StructType([T.StructField('id', T.IntegerType())]),
        )
        third_df = self.spark.createDataFrame(
            data=[(3, ), (4, ), (5, )],
            schema=T.StructType([T.StructField('id', T.IntegerType())]),
        )

        joined_df = SF.multijoin([first_df, second_df, third_df], on='id', how='inner')

        self.assertDataFrameEqual(joined_df, [{'id': 3}]) 
Example 11
Project: sparkly   Author: tubular   File: test_functions.py    Apache License 2.0 6 votes
def test_outer_join(self):
        first_df = self.spark.createDataFrame(
            data=[(1, ), (2, ), (3, )],
            schema=T.StructType([T.StructField('id', T.IntegerType())]),
        )
        second_df = self.spark.createDataFrame(
            data=[(2, ), (3, ), (4, )],
            schema=T.StructType([T.StructField('id', T.IntegerType())]),
        )
        third_df = self.spark.createDataFrame(
            data=[(3, ), (4, ), (5, )],
            schema=T.StructType([T.StructField('id', T.IntegerType())]),
        )

        joined_df = SF.multijoin([first_df, second_df, third_df], on='id', how='outer')

        self.assertDataFrameEqual(joined_df, [{'id': i} for i in [1, 2, 3, 4, 5]]) 
Example 12
Project: sparkly   Author: tubular   File: test_functions.py    Apache License 2.0 6 votes
def test_coalescing_light_type_mismatch(self):
        first_df = self.spark.createDataFrame(
            data=[(1, None), (2, 'hi'), (3, None), (4, 'may')],
            schema=T.StructType([
                T.StructField('id', T.IntegerType()),
                T.StructField('value', T.StringType()),
            ]),
        )
        second_df = self.spark.createDataFrame(
            data=[(2, 2), (3, 3), (4, None)],
            schema=T.StructType([
                T.StructField('id', T.IntegerType()),
                T.StructField('value', T.IntegerType()),
            ]),
        )

        joined_df = SF.multijoin([first_df, second_df], on='id', how='inner', coalesce=['value'])

        self.assertDataFrameEqual(
            joined_df,
            [{'id': 2, 'value': 'hi'}, {'id': 3, 'value': '3'}, {'id': 4, 'value': 'may'}],
        ) 
Example 13
Project: sparkly   Author: tubular   File: test_functions.py    Apache License 2.0 6 votes
def test_coalescing_heavy_type_mismatch(self):
        first_df = self.spark.createDataFrame(
            data=[(1, None), (2, 'hi'), (3, None), (4, 'may')],
            schema=T.StructType([
                T.StructField('id', T.IntegerType()),
                T.StructField('value', T.StringType()),
            ]),
        )
        second_df = self.spark.createDataFrame(
            data=[(2, [2, ]), (3, [3, ]), (4, None)],
            schema=T.StructType([
                T.StructField('id', T.IntegerType()),
                T.StructField('value', T.ArrayType(T.IntegerType())),
            ]),
        )

        with self.assertRaises(U.AnalysisException):
            SF.multijoin([first_df, second_df], on='id', how='inner', coalesce=['value']) 
Example 14
Project: sparkly   Author: tubular   File: test_functions.py    Apache License 2.0 6 votes
def test_no_cases(self):
        df = self.spark.createDataFrame(
            data=[('one', ), ('two', ), ('three', ), ('hi', )],
            schema=T.StructType([T.StructField('name', T.StringType())]),
        )

        df = df.withColumn('value', SF.switch_case('name'))

        self.assertDataFrameEqual(
            df,
            [
                {'name': 'one', 'value': None},
                {'name': 'two', 'value': None},
                {'name': 'three', 'value': None},
                {'name': 'hi', 'value': None},
            ],
        ) 
Example 15
Project: sparkly   Author: tubular   File: test_functions.py    Apache License 2.0 6 votes
def test_default_as_a_lit(self):
        df = self.spark.createDataFrame(
            data=[('one', ), ('two', ), ('three', ), ('hi', )],
            schema=T.StructType([T.StructField('name', T.StringType())]),
        )

        df = df.withColumn('value', SF.switch_case('name', default=0))

        self.assertDataFrameEqual(
            df,
            [
                {'name': 'one', 'value': 0},
                {'name': 'two', 'value': 0},
                {'name': 'three', 'value': 0},
                {'name': 'hi', 'value': 0},
            ],
        ) 
Example 16
Project: sparkly   Author: tubular   File: test_functions.py    Apache License 2.0 6 votes
def test_default_as_a_column(self):
        df = self.spark.createDataFrame(
            data=[('one', ), ('two', ), ('three', ), ('hi', )],
            schema=T.StructType([T.StructField('name', T.StringType())]),
        )

        df = df.withColumn('value', SF.switch_case('name', default=F.col('name')))

        self.assertDataFrameEqual(
            df,
            [
                {'name': 'one', 'value': 'one'},
                {'name': 'two', 'value': 'two'},
                {'name': 'three', 'value': 'three'},
                {'name': 'hi', 'value': 'hi'},
            ],
        ) 
Example 17
Project: sparkly   Author: tubular   File: test_functions.py    Apache License 2.0 6 votes
def test_switch_as_a_column_cases_as_kwargs(self):
        df = self.spark.createDataFrame(
            data=[('one', ), ('two', ), ('three', ), ('hi', )],
            schema=T.StructType([T.StructField('name', T.StringType())]),
        )

        df = df.withColumn(
            'value',
            SF.switch_case(F.col('name'), one=1, two=2, three=3, default=0),
        )

        self.assertDataFrameEqual(
            df,
            [
                {'name': 'one', 'value': 1},
                {'name': 'two', 'value': 2},
                {'name': 'three', 'value': 3},
                {'name': 'hi', 'value': 0},
            ],
        ) 
Example 18
Project: sparkly   Author: tubular   File: test_functions.py    Apache License 2.0 6 votes
def test_dict_cases_override_kwarg_cases(self):
        df = self.spark.createDataFrame(
            data=[('one', ), ('two', ), ('three', ), ('hi', )],
            schema=T.StructType([T.StructField('name', T.StringType())]),
        )

        df = df.withColumn(
            'value',
            SF.switch_case('name', {'one': 11, 'three': 33}, one=1, two=2, three=3, default=0),
        )

        self.assertDataFrameEqual(
            df,
            [
                {'name': 'one', 'value': 11},
                {'name': 'two', 'value': 2},
                {'name': 'three', 'value': 33},
                {'name': 'hi', 'value': 0},
            ],
        ) 
Example 19
Project: sparkly   Author: tubular   File: test_functions.py    Apache License 2.0 6 votes
def test_cases_condition_constant_as_an_arbitrary_value(self):
        df = self.spark.createDataFrame(
            data=[(1, ), (2, ), (3, ), (0, )],
            schema=T.StructType([T.StructField('value', T.IntegerType())]),
        )

        df = df.withColumn(
            'name',
            SF.switch_case('value', {1: 'one', 2: 'two', 3: 'three'}, default='hi'),
        )

        self.assertDataFrameEqual(
            df,
            [
                {'name': 'one', 'value': 1},
                {'name': 'two', 'value': 2},
                {'name': 'three', 'value': 3},
                {'name': 'hi', 'value': 0},
            ],
        ) 
Example 20
Project: sparkly   Author: tubular   File: test_utils.py    Apache License 2.0 6 votes
def test_structs_nested_subset(self):
        schema_has(
            T.StructType([
                T.StructField(
                    'f1',
                    T.ArrayType(T.StructType([
                        T.StructField('f11', T.IntegerType()),
                        T.StructField('f12', T.StringType()),
                    ])),
                ),
            ]),
            T.StructType([
                T.StructField(
                    'f1',
                    T.ArrayType(T.StructType([T.StructField('f11', T.IntegerType())])),
                ),
            ]),
        ) 
Example 21
Project: LearningApacheSpark   Author: runawayhorse001   File: tests.py    MIT License 6 votes
def test_union_with_udt(self):
        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
        row1 = (1.0, ExamplePoint(1.0, 2.0))
        row2 = (2.0, ExamplePoint(3.0, 4.0))
        schema = StructType([StructField("label", DoubleType(), False),
                             StructField("point", ExamplePointUDT(), False)])
        df1 = self.spark.createDataFrame([row1], schema)
        df2 = self.spark.createDataFrame([row2], schema)

        result = df1.union(df2).orderBy("label").collect()
        self.assertEqual(
            result,
            [
                Row(label=1.0, point=ExamplePoint(1.0, 2.0)),
                Row(label=2.0, point=ExamplePoint(3.0, 4.0))
            ]
        ) 
Example 22
Project: incubator-spot   Author: apache   File: streaming.py    Apache License 2.0 5 votes
def dstream(self):
        '''
            Return the transformed DStream: extract each Kafka message
        payload, flatten it and apply the analyzer to every record.
        '''
        return self.__dstream\
            .map(lambda x: x[1])\
            .flatMap(lambda x: x)\
            .map(lambda x: _analyzer(x)) 
Example 23
Project: abalon   Author: Tagar   File: sparkutils.py    Apache License 2.0 5 votes
def dfZipWithIndex(df, offset=1, colName="rowId"):
    '''
        Enumerates dataframe rows in native order, like rdd.zipWithIndex(), but on a dataframe,
        while preserving the schema

        :param df: source dataframe
        :param offset: adjustment to zipWithIndex()'s index
        :param colName: name of the index column
    '''

    sparkutils_init()

    new_schema = StructType(
        [StructField(colName, LongType(), True)]  # new added field in front
        + df.schema.fields  # previous schema
    )

    zipped_rdd = df.rdd.zipWithIndex()

    # py2:
    # new_rdd = zipped_rdd.map(lambda (row, rowId): ([rowId + offset] + list(row)))

    # py3: rec is a (row, rowId) tuple
    new_rdd = zipped_rdd.map(lambda rec: ([rec[1] + offset] + list(rec[0])))

    return spark.createDataFrame(new_rdd, new_schema)


### 
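A hypothetical use of dfZipWithIndex, assuming the module's sparkutils_init() has been configured and an active SparkSession named spark exists (the function itself relies on it too):

df = spark.createDataFrame([('a',), ('b',), ('c',)], ['letter'])
indexed = dfZipWithIndex(df, offset=1, colName='rowId')
indexed.show()
# +-----+------+
# |rowId|letter|
# +-----+------+
# |    1|     a|
# |    2|     b|
# |    3|     c|
# +-----+------+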
Example 24
Project: pb2df   Author: bridgewell   File: conftest.py    BSD 3-Clause "New" or "Revised" License 5 votes
def simple_msg_schema():
    schema = types.StructType([
        types.StructField('field', types.IntegerType()),
    ])
    return schema 
Example 25
Project: pb2df   Author: bridgewell   File: conftest.py    BSD 3-Clause "New" or "Revised" License 5 votes
def labeled_msg_schema():
    schema = types.StructType([
        types.StructField('optional_field', types.BooleanType()),
        types.StructField('required_field', types.DoubleType(),
                          nullable=False),
        types.StructField('repeated_field', types.ArrayType(
            types.IntegerType(), containsNull=False)),
        types.StructField('default_field', types.StringType()),
    ])
    return schema 
Example 26
Project: pb2df   Author: bridgewell   File: conftest.py    BSD 3-Clause "New" or "Revised" License 5 votes
def nested_msg_schema(simple_msg_schema):
    schema = types.StructType([
        types.StructField('optional_nested_field', simple_msg_schema),
        types.StructField('required_nested_field', simple_msg_schema,
                          nullable=False),
        types.StructField('repeated_nested_field', types.ArrayType(
            simple_msg_schema, containsNull=False)),
    ])
    return schema 
Example 27
Project: pb2df   Author: bridgewell   File: conftest.py    BSD 3-Clause "New" or "Revised" License 5 votes
def custom_field_msg_schema():
    schema = types.StructType([
        types.StructField('timestamp_field', types.TimestampType(),
                          nullable=False),
    ])
    return schema 
Example 28
Project: search-MjoLniR   Author: wikimedia   File: es_hits.py    MIT License 5 votes
def transform(df, url_list=None, brokers=None, **kwargs):
    if brokers and url_list:
        raise ValueError('cannot specify brokers and url_list')
    if brokers:
        rdd = transform_from_kafka(df, brokers, **kwargs)
    else:
        rdd = transform_from_elasticsearch(df, url_list, **kwargs)
    return df.sql_ctx.createDataFrame(rdd, T.StructType([
        df.schema['wikiid'],
        df.schema['query'],
        df.schema['norm_query'],
        T.StructField('hit_page_ids', T.ArrayType(T.IntegerType()), nullable=False),
    ])) 
Example 29
Project: search-MjoLniR   Author: wikimedia   File: transform.py    MIT License 5 votes
def _verify_schema_equality(expect: T.StructType, have: T.StructType) -> List[str]:
    """Verify the dataframe and table have equal schemas"""
    def resolve(schema, field_name) -> Optional[Tuple]:
        try:
            field = schema[field_name]
        except KeyError:
            return None
        return _simplify_data_type(field.dataType)

    errors = []
    for field_name in set(expect.names).union(have.names):
        expect_type = resolve(expect, field_name)
        if expect_type is None:
            errors.append('Extra field in provided schema: {}'.format(field_name))
            continue

        have_type = resolve(have, field_name)
        if have_type is None:
            errors.append('Missing field in provided schema: {}'.format(field_name))
            continue

        if expect_type != have_type:
            fmt = 'Column {} of type {} does not match expected {}'
            errors.append(fmt.format(field_name, have_type, expect_type))
            continue
        # TODO: Test nullability? But hive doesn't track nullability, everything is nullable.
    return errors 
Example 30
Project: search-MjoLniR   Author: wikimedia   File: transform.py    MIT License 5 votes
def read_partition(
    spark: SparkSession,
    table: str,
    partition_spec: Mapping[str, str],
    schema: Optional[T.StructType] = None,
    direct_parquet_read: bool = False
) -> DataFrame:
    """Read a single partition from a hive table.

    Verifies the partition specification describes a complete partition,
    that the partition exists, and optionally that the table is compatible
    with an expected schema. The partition could still be empty.
    """
    # We don't need to do anything with the result, our goal is to
    # trigger AnalysisException when the arguments are invalid.
    spark.sql(_describe_partition_ql(table, partition_spec)).collect()

    partition_cond = F.lit(True)
    for k, v in partition_spec.items():
        partition_cond &= F.col(k) == v
    df = spark.read.table(table).where(partition_cond)
    # The df we have now has types defined by the hive table, but this downgrades
    # non-standard types like VectorUDT() to its sql equivalent. Use the first
    # df to find the files, then read them directly.
    if direct_parquet_read:
        input_files = list(df._jdf.inputFiles())  # type: ignore
        input_dirs = set(os.path.dirname(path) for path in input_files)
        if len(input_dirs) != 1:
            raise Exception('Expected single directory containing partition data: [{}]'.format(
                '],['.join(input_files)))
        df = spark.read.parquet(list(input_dirs)[0])
    if schema is not None:
        # TODO: This only allows extra top level columns, anything
        # nested must be exactly the same. Fine for now.
        _verify_schema_compatability(schema, df.schema)
        df = df.select(*(field.name for field in schema))
    # Drop partitioning columns. These are not part of the mjolnir transformations, and
    # are only an implementation detail of putting them on disk and tracking history.
    return df.drop(*partition_spec.keys()) 
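A hypothetical call to read_partition; the table name, partition columns and expected schema are all illustrative, and an active SparkSession named spark plus the T alias for pyspark.sql.types are assumed:

expected = T.StructType([
    T.StructField('wikiid', T.StringType()),
    T.StructField('query', T.StringType()),
])

df = read_partition(
    spark,
    'my_db.query_clicks_hourly',
    {'year': '2020', 'month': '1', 'day': '1', 'hour': '0'},
    schema=expected,
)
# df contains only wikiid and query, with the partition columns dropped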
Example 31
Project: search-MjoLniR   Author: wikimedia   File: transform.py    MIT License 5 votes
def typed_transformer(
    schema_in: Optional[T.StructType] = None,
    schema_out: Optional[T.StructType] = None,
    context: Optional[str] = None
) -> Callable[[Callable[..., Transformer]], Callable[..., Transformer]]:
    """Decorates a transformer factory with schema validation

    An idiom in transform is calling a function to return a Transformer. This
    decorator can be applied to those factory functions to return transformers
    that apply runtime schema validation.
    """
    def decorate(fn: Callable[..., Transformer]) -> Callable[..., Transformer]:
        def error_context(kind: str) -> str:
            return 'While checking {} {}:'.format(fn.__name__ if context is None else context, kind)

        @functools.wraps(fn)
        def factory(*args, **kwargs) -> Transformer:
            transformer = fn(*args, **kwargs)

            @functools.wraps(transformer)
            def transform(df_in: DataFrame) -> DataFrame:
                if schema_in is not None:
                    check_schema(df_in, schema_in, error_context('schema_in'))
                    df_in = df_in.select(*schema_in.names)
                df_out = transformer(df_in)
                if schema_out is not None:
                    check_schema(df_out, schema_out, error_context('schema_out'))
                    df_out = df_out.select(*schema_out.names)
                return df_out
            return transform
        return factory
    return decorate


# Shared schemas between the primary mjolnir transformations. Transformations
# may require a schema with slightly more columns than they require to keep
# the total number of schemas low. 
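A hypothetical factory decorated with typed_transformer; the schemas, column names and transformation are illustrative, and the names T, F, DataFrame, Transformer and check_schema are assumed to come from the module's own imports:

input_schema = T.StructType([T.StructField('query', T.StringType())])
output_schema = T.StructType([
    T.StructField('query', T.StringType()),
    T.StructField('norm_query', T.StringType()),
])

@typed_transformer(schema_in=input_schema, schema_out=output_schema, context='normalize_query')
def normalize_query() -> Transformer:
    def transform(df: DataFrame) -> DataFrame:
        # Add the normalized form of the query column.
        return df.withColumn('norm_query', F.lower(F.trim(F.col('query'))))
    return transform

transformer = normalize_query()
# transformer(df) validates df against input_schema, applies the column
# transformation, then validates the result against output_schema.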
Example 32
Project: search-MjoLniR   Author: wikimedia   File: test_transform.py    MIT License 5 votes
def test_schema_comparison(expect: T.StructType, have: T.StructType, compatible: bool, equal: bool) -> None:
    if equal and not compatible:
        raise Exception('Invalid constraint, can not be equal but not compatible')
    # The functions return a list of errors; not bool(errors) is True when everything is ok
    assert compatible is not bool(mt._verify_schema_compatability(expect, have))
    assert equal is not bool(mt._verify_schema_equality(expect, have)) 
Example 33
Project: hops-util-py   Author: logicalclocks   File: test_featurestore.py    Apache License 2.0 5 votes
def _sample_spark_dataframe(self, spark):
        """ Creates a sample dataframe for testing"""
        sqlContext = SQLContext(spark.sparkContext)
        schema = StructType([StructField("equipo_id", IntegerType(), True),
                             StructField("equipo_presupuesto", FloatType(), True),
                             StructField("equipo_posicion", IntegerType(), True)
                             ])
        sample_df = sqlContext.createDataFrame([(999, 41251.52, 1), (998, 1319.4, 8), (997, 21219.1, 2)], schema)
        return sample_df 
Example 34
Project: decorators4DS   Author: urigoren   File: pyspark_udf.py    MIT License 5 votes
def _rec_build_types(t):
    if type(t) == list:
        return T.ArrayType(_rec_build_types(t[0]))
    elif type(t) == dict:
        k, v = list(t.items())[0]
        return T.MapType(_rec_build_types(k), _rec_build_types(v), True)
    elif type(t) == tuple:
        return T.StructType([T.StructField("v_" + str(i), _rec_build_types(f), True) for i, f in enumerate(t)])
    elif t in T._type_mappings:
        return T._type_mappings[t]()
    else:
        raise TypeError(repr(t) + " is not supported") 
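As a hypothetical illustration of _rec_build_types, Python type descriptions map onto Spark types roughly as follows (the exact integer width depends on pyspark's internal _type_mappings for the running Python version):

_rec_build_types(str)          # -> StringType
_rec_build_types([str])        # -> ArrayType(StringType)
_rec_build_types({str: str})   # -> MapType(StringType, StringType)
_rec_build_types((str, str))   # -> StructType with fields v_0 and v_1, both StringType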
Example 35
Project: monasca-transform   Author: openstack   File: transform_utils.py    Apache License 2.0 5 votes
def _get_instance_usage_schema():
        """get instance usage schema."""

        # Initialize columns for all string fields
        columns = ["tenant_id", "user_id", "resource_uuid",
                   "geolocation", "region", "zone", "host", "project_id",
                   "aggregated_metric_name", "firstrecord_timestamp_string",
                   "lastrecord_timestamp_string",
                   "usage_date", "usage_hour", "usage_minute",
                   "aggregation_period"]

        columns_struct_fields = [StructField(field_name, StringType(), True)
                                 for field_name in columns]

        # Add columns for non-string fields
        columns_struct_fields.append(StructField("firstrecord_timestamp_unix",
                                                 DoubleType(), True))
        columns_struct_fields.append(StructField("lastrecord_timestamp_unix",
                                                 DoubleType(), True))
        columns_struct_fields.append(StructField("quantity",
                                                 DoubleType(), True))
        columns_struct_fields.append(StructField("record_count",
                                                 DoubleType(), True))

        columns_struct_fields.append(StructField("processing_meta",
                                                 MapType(StringType(),
                                                         StringType(),
                                                         True),
                                                 True))

        columns_struct_fields.append(StructField("extra_data_map",
                                                 MapType(StringType(),
                                                         StringType(),
                                                         True),
                                                 True))
        schema = StructType(columns_struct_fields)

        return schema 
Example 36
Project: monasca-transform   Author: openstack   File: transform_utils.py    Apache License 2.0 5 votes
def _get_mon_metric_json_schema():
        """get the schema of the incoming monasca metric."""

        metric_struct_field = StructField(
            "metric",
            StructType([StructField("dimensions",
                                    MapType(StringType(),
                                            StringType(),
                                            True),
                                    True),
                        StructField("value_meta",
                                    MapType(StringType(),
                                            StringType(),
                                            True),
                                    True),
                        StructField("name", StringType(), True),
                        StructField("timestamp", StringType(), True),
                        StructField("value", StringType(), True)]), True)

        meta_struct_field = StructField("meta",
                                        MapType(StringType(),
                                                StringType(),
                                                True),
                                        True)

        creation_time_struct_field = StructField("creation_time",
                                                 StringType(), True)

        schema = StructType([creation_time_struct_field,
                             meta_struct_field, metric_struct_field])
        return schema 
Example 37
Project: monasca-transform   Author: openstack   File: transform_utils.py    Apache License 2.0 5 votes
def _get_pre_transform_specs_df_schema():
        """get pre_transform_specs df schema."""

        # FIXME: change when pre_transform_specs df is finalized

        event_type = StructField("event_type", StringType(), True)

        metric_id_list = StructField("metric_id_list",
                                     ArrayType(StringType(),
                                               containsNull=False),
                                     True)
        required_raw_fields_list = StructField("required_raw_fields_list",
                                               ArrayType(StringType(),
                                                         containsNull=False),
                                               True)

        event_processing_params = \
            StructField("event_processing_params",
                        StructType([StructField("set_default_zone_to",
                                                StringType(), True),
                                    StructField("set_default_geolocation_to",
                                                StringType(), True),
                                    StructField("set_default_region_to",
                                                StringType(), True),
                                    ]), True)

        schema = StructType([event_processing_params, event_type,
                             metric_id_list, required_raw_fields_list])

        return schema 
Example 38
Project: monasca-transform   Author: openstack   File: transform_utils.py    Apache License 2.0 5 votes
def _get_grouping_results_df_schema(group_by_column_list):
        """get grouping results schema."""

        group_by_field_list = [StructField(field_name, StringType(), True)
                               for field_name in group_by_column_list]

        # Initialize columns for string fields
        columns = ["firstrecord_timestamp_string",
                   "lastrecord_timestamp_string"]

        columns_struct_fields = [StructField(field_name, StringType(), True)
                                 for field_name in columns]

        # Add columns for non-string fields
        columns_struct_fields.append(StructField("firstrecord_timestamp_unix",
                                                 DoubleType(), True))
        columns_struct_fields.append(StructField("lastrecord_timestamp_unix",
                                                 DoubleType(), True))
        columns_struct_fields.append(StructField("firstrecord_quantity",
                                                 DoubleType(), True))
        columns_struct_fields.append(StructField("lastrecord_quantity",
                                                 DoubleType(), True))
        columns_struct_fields.append(StructField("record_count",
                                                 DoubleType(), True))

        instance_usage_schema_part = StructType(columns_struct_fields)

        grouping_results = \
            StructType([StructField("grouping_key",
                                    StringType(), True),
                        StructField("results",
                                    instance_usage_schema_part,
                                    True),
                        StructField("grouping_key_dict",
                                    StructType(group_by_field_list))])

        # schema = \
        #     StructType([StructField("GroupingResults", grouping_results)])
        return grouping_results 
Example 39
Project: sparkly   Author: tubular   File: utils.py    Apache License 2.0 5 votes
def parse_schema(schema):
    """Generate schema by its string definition.

    It is essentially the inverse of the `DataType.simpleString` method.
    Supports all atomic types (like string, int, float...) and complex types (array, map, struct)
    except DecimalType.

    Usages:
        >>> parse_schema('string')
        StringType
        >>> parse_schema('int')
        IntegerType
        >>> parse_schema('array<int>')
        ArrayType(IntegerType,true)
        >>> parse_schema('map<string,int>')
        MapType(StringType,IntegerType,true)
        >>> parse_schema('struct<a:int,b:string>')
        StructType(List(StructField(a,IntegerType,true),StructField(b,StringType,true)))
        >>> parse_schema('unsupported')
        Traceback (most recent call last):
        ...
        sparkly.exceptions.UnsupportedDataType: Cannot parse type from string: "unsupported"
    """
    field_type, args_string = re.match(r'(\w+)<?(.*)>?$', schema).groups()
    args = _parse_args(args_string) if args_string else []

    if field_type in ATOMIC_TYPES:
        return ATOMIC_TYPES[field_type]()
    elif field_type in COMPLEX_TYPES:
        return COMPLEX_TYPES[field_type](*args)
    else:
        message = 'Cannot parse type from string: "{}"'.format(field_type)
        raise UnsupportedDataType(message) 
Example 40
Project: sparkly   Author: tubular   File: testing.py    Apache License 2.0 5 votes
def __init__(
        self,
        spark,
        df_schema,
        key_deserializer,
        value_deserializer,
        host,
        topic,
        port=9092,
    ):
        """Initialize context manager

        Parameters `key_deserializer` and `value_deserializer` are callables
        which get bytes as input and should return python structures as output.

        Args:
            spark (SparklySession): currently active SparklySession
            df_schema (pyspark.sql.types.StructType): schema of dataframe to be generated
            key_deserializer (function): function used to deserialize the key
            value_deserializer (function): function used to deserialize the value
            host (basestring): host or ip address of the kafka server to connect to
            topic (basestring): Kafka topic to monitor
            port (int): port number of the Kafka server to connect to
        """
        self.spark = spark
        self.topic = topic
        self.df_schema = df_schema
        self.key_deser, self.val_deser = key_deserializer, value_deserializer
        self.host, self.port = host, port
        self._df = None
        self.count = 0

        kafka_client = SimpleClient(host)
        kafka_client.ensure_topic_exists(topic) 
Example 41
Project: sparkly   Author: tubular   File: test_writer.py    Apache License 2.0 5 votes
def test_simple_key_uncompressed(self):
        df = self.spark.createDataFrame(
            data=[
                ('k1', 'k14', [1, 14, 141]),
                ('k1', 'k12', [1, 12, 121]),
                ('k1', 'k11', [1, 11, 111]),
                ('k1', 'k13', [1, 13, 131]),
            ],
            schema=T.StructType([
                T.StructField('key_1', T.StringType()),
                T.StructField('key_2', T.StringType()),
                T.StructField('aux_data', T.ArrayType(T.IntegerType())),
            ])
        )

        df.write_ext.redis(
            key_by=['key_2'],
            max_pipeline_size=3,
            host='redis.docker',
        )

        redis_client = redis.StrictRedis('redis.docker')

        self.assertRowsEqual(
            redis_client.keys(),
            [b'k11', b'k12', b'k13', b'k14'],
            ignore_order=True,
        )

        written_data = [
            json.loads(redis_client.get(key)) for key in ['k11', 'k12', 'k13', 'k14']
        ]

        expected = [
            {'key_1': 'k1', 'key_2': 'k11', 'aux_data': [1, 11, 111]},
            {'key_1': 'k1', 'key_2': 'k12', 'aux_data': [1, 12, 121]},
            {'key_1': 'k1', 'key_2': 'k13', 'aux_data': [1, 13, 131]},
            {'key_1': 'k1', 'key_2': 'k14', 'aux_data': [1, 14, 141]},
        ]

        self.assertEqual(written_data, expected) 
Example 42
Project: sparkly   Author: tubular   File: test_writer.py    Apache License 2.0 5 votes
def test_group_by(self):
        df = self.spark.createDataFrame(
            data=[
                ('k4', 'k14', [1, 14, 141]),
                ('k1', 'k12', [1, 12, 121]),
                ('k1', 'k11', [1, 11, 111]),
                ('k1', 'k13', [1, 13, 131]),
            ],
            schema=T.StructType([
                T.StructField('key_1', T.StringType()),
                T.StructField('key_2', T.StringType()),
                T.StructField('aux_data', T.ArrayType(T.IntegerType())),
            ])
        )

        df.write_ext.redis(
            key_by=['key_1'],
            group_by_key=True,
            max_pipeline_size=2,
            host='redis.docker',
        )

        redis_client = redis.StrictRedis('redis.docker')

        self.assertRowsEqual(redis_client.keys(), [b'k1', b'k4'], ignore_order=True)

        written_data = [json.loads(redis_client.get(key)) for key in [b'k1', b'k4']]

        expected = [
            [
                {'key_1': 'k1', 'key_2': 'k11', 'aux_data': [1, 11, 111]},
                {'key_1': 'k1', 'key_2': 'k12', 'aux_data': [1, 12, 121]},
                {'key_1': 'k1', 'key_2': 'k13', 'aux_data': [1, 13, 131]},
            ],
            [{'key_1': 'k4', 'key_2': 'k14', 'aux_data': [1, 14, 141]}],
        ]

        self.assertRowsEqual(written_data, expected, ignore_order=True) 
Example 43
Project: sparkly   Author: tubular   File: test_writer.py    Apache License 2.0 5 votes
def test_db(self):
        df = self.spark.createDataFrame(
            data=[
                ('k1', 'k14', [1, 14, 141]),
            ],
            schema=T.StructType([
                T.StructField('key_1', T.StringType()),
                T.StructField('key_2', T.StringType()),
                T.StructField('aux_data', T.ArrayType(T.IntegerType())),
            ])
        )

        df.write_ext.redis(
            key_by=['key_2'],
            max_pipeline_size=3,
            host='redis.docker',
            db=1,
        )

        redis_client = redis.StrictRedis('redis.docker', db=1)

        self.assertEqual(redis_client.keys(), [b'k14'])

        written_data = json.loads(redis_client.get('k14'))
        expected = {'key_1': 'k1', 'key_2': 'k14', 'aux_data': [1, 14, 141]}
        self.assertEqual(written_data, expected) 
Example 44
Project: sparkly   Author: tubular   File: test_writer.py    Apache License 2.0 5 votes
def test_simple_key_uncompressed(self):
        df = self.spark.createDataFrame(
            data=[
                ('k1', 'k14', [1, 14, 141]),
                ('k1', 'k12', [1, 12, 121]),
                ('k1', 'k11', [1, 11, 111]),
                ('k1', 'k13', [1, 13, 131]),
            ],
            schema=T.StructType([
                T.StructField('key_1', T.StringType()),
                T.StructField('key_2', T.StringType()),
                T.StructField('aux_data', T.ArrayType(T.IntegerType())),
            ])
        )

        df.write_ext.by_url('redis://redis.docker?keyBy=key_2&maxPipelineSize=3')

        redis_client = redis.StrictRedis('redis.docker')

        self.assertRowsEqual(
            redis_client.keys(),
            [b'k11', b'k12', b'k13', b'k14'],
            ignore_order=True,
        )

        written_data = [
            json.loads(redis_client.get(key)) for key in ['k11', 'k12', 'k13', 'k14']
        ]

        expected = [
            {'key_1': 'k1', 'key_2': 'k11', 'aux_data': [1, 11, 111]},
            {'key_1': 'k1', 'key_2': 'k12', 'aux_data': [1, 12, 121]},
            {'key_1': 'k1', 'key_2': 'k13', 'aux_data': [1, 13, 131]},
            {'key_1': 'k1', 'key_2': 'k14', 'aux_data': [1, 14, 141]},
        ]

        self.assertEqual(written_data, expected) 
Example 45
Project: sparkly   Author: tubular   File: test_writer.py    Apache License 2.0 5 votes
def test_composite_key_compressed(self):
        df = self.spark.createDataFrame(
            data=[
                ('k1', 'k14', [1, 14, 141]),
                ('k1', 'k12', [1, 12, 121]),
                ('k1', 'k11', [1, 11, 111]),
                ('k1', 'k13', [1, 13, 131]),
            ],
            schema=T.StructType([
                T.StructField('key_1', T.StringType()),
                T.StructField('key_2', T.StringType()),
                T.StructField('aux_data', T.ArrayType(T.IntegerType())),
            ])
        )

        df.write_ext.by_url(
            'redis://redis.docker?keyBy=key_1,key_2&compression=gzip&maxPipelineSize=3'
        )

        redis_client = redis.StrictRedis('redis.docker')

        self.assertRowsEqual(
            redis_client.keys(),
            [b'k1.k11', b'k1.k12', b'k1.k13', b'k1.k14'],
            ignore_order=True,
        )

        written_data = [
            json.loads(self._gzip_decompress(redis_client.get(key)))
            for key in ['k1.k11', 'k1.k12', 'k1.k13', 'k1.k14']
        ]

        expected = [
            {'key_1': 'k1', 'key_2': 'k11', 'aux_data': [1, 11, 111]},
            {'key_1': 'k1', 'key_2': 'k12', 'aux_data': [1, 12, 121]},
            {'key_1': 'k1', 'key_2': 'k13', 'aux_data': [1, 13, 131]},
            {'key_1': 'k1', 'key_2': 'k14', 'aux_data': [1, 14, 141]},
        ]

        self.assertEqual(written_data, expected) 
Example 46
Project: sparkly   Author: tubular   File: test_writer.py    Apache License 2.0 5 votes
def test_group_by(self):
        df = self.spark.createDataFrame(
            data=[
                ('k4', 'k14', [1, 14, 141]),
                ('k1', 'k12', [1, 12, 121]),
                ('k1', 'k11', [1, 11, 111]),
                ('k1', 'k13', [1, 13, 131]),
            ],
            schema=T.StructType([
                T.StructField('key_1', T.StringType()),
                T.StructField('key_2', T.StringType()),
                T.StructField('aux_data', T.ArrayType(T.IntegerType())),
            ])
        )

        df.write_ext.by_url(
            'redis://redis.docker?keyBy=key_1&groupByKey=true&maxPipelineSize=2'
        )

        redis_client = redis.StrictRedis('redis.docker')

        self.assertRowsEqual(redis_client.keys(), [b'k1', b'k4'], ignore_order=True)

        written_data = [json.loads(redis_client.get(key)) for key in [b'k1', b'k4']]

        expected = [
            [
                {'key_1': 'k1', 'key_2': 'k11', 'aux_data': [1, 11, 111]},
                {'key_1': 'k1', 'key_2': 'k12', 'aux_data': [1, 12, 121]},
                {'key_1': 'k1', 'key_2': 'k13', 'aux_data': [1, 13, 131]},
            ],
            [{'key_1': 'k4', 'key_2': 'k14', 'aux_data': [1, 14, 141]}],
        ]

        self.assertRowsEqual(written_data, expected, ignore_order=True) 
Example 47
Project: sparkly   Author: tubular   File: test_writer.py    Apache License 2.0 5 votes
def test_exclude_null_fields(self):
        df = self.spark.createDataFrame(
            data=[
                ('k1', 'k14', [1, 14, 141]),
                (None, 'k12', [1, 12, 121]),
                ('k1', 'k11', None),
                ('k1', 'k13', [1, 13, 131]),
            ],
            schema=T.StructType([
                T.StructField('key_1', T.StringType()),
                T.StructField('key_2', T.StringType()),
                T.StructField('aux_data', T.ArrayType(T.IntegerType())),
            ])
        )

        redis_client = redis.StrictRedis('redis.docker')

        df.write_ext.by_url(
            'redis://redis.docker?keyBy=key_2&keyPrefix=hello&excludeNullFields=true'
        )

        self.assertRowsEqual(
            redis_client.keys(),
            [b'hello.k11', b'hello.k12', b'hello.k13', b'hello.k14'],
            ignore_order=True,
        )

        written_data = [
            json.loads(redis_client.get(key))
            for key in ['hello.k11', 'hello.k12', 'hello.k13', 'hello.k14']
        ]

        expected = [
            {'key_1': 'k1', 'key_2': 'k11'},
            {'key_2': 'k12', 'aux_data': [1, 12, 121]},
            {'key_1': 'k1', 'key_2': 'k13', 'aux_data': [1, 13, 131]},
            {'key_1': 'k1', 'key_2': 'k14', 'aux_data': [1, 14, 141]},
        ]

        self.assertEqual(written_data, expected) 
Example 48
Project: sparkly   Author: tubular   File: test_functions.py    Apache License 2.0 5 votes
def test_cases_values_as_a_column(self):
        df = self.spark.createDataFrame(
            data=[(1, ), (2, ), (3, ), (0, )],
            schema=T.StructType([T.StructField('value', T.IntegerType())]),
        )

        df = df.withColumn(
            'value_2',
            SF.switch_case(
                'value',
                {
                    1: 11 * F.col('value'),
                    2: F.col('value') * F.col('value'),
                    'hi': 5,
                },
                default=F.col('value'),
            ),
        )

        self.assertDataFrameEqual(
            df,
            [
                {'value': 1, 'value_2': 11},
                {'value': 2, 'value_2': 4},
                {'value': 3, 'value_2': 3},
                {'value': 0, 'value_2': 0},
            ],
        ) 
Example 49
Project: sparkly   Author: tubular   File: test_functions.py    Apache License 2.0 5 votes
def test_switch_case_with_custom_operand_between(self):
        df = self.spark.createDataFrame(
            data=[(1, ), (2, ), (3, ), (0, )],
            schema=T.StructType([T.StructField('value', T.IntegerType())]),
        )

        df = df.withColumn(
            'value_2',
            SF.switch_case(
                'value',
                {
                    (1, 1): 'aloha',
                    (2, 3): 'hi',
                },
                operand=lambda c, v: c.between(*v),
            ),
        )

        self.assertDataFrameEqual(
            df,
            [
                {'value': 1, 'value_2': 'aloha'},
                {'value': 2, 'value_2': 'hi'},
                {'value': 3, 'value_2': 'hi'},
                {'value': 0, 'value_2': None},
            ],
        ) 
Example 50
Project: sparkly   Author: tubular   File: test_functions.py    Apache License 2.0 5 votes
def test_non_nullable_values(self):
        df = self.spark.createDataFrame(
            data=[
                ('1', 'test1', None, 3),
                ('1', None, 2, 4),
                ('2', 'test2', 3, 1),
                ('2', 'test3', 4, 2),
            ],
            schema=T.StructType([
                T.StructField('id', T.StringType(), nullable=True),
                T.StructField('value1', T.StringType(), nullable=True),
                T.StructField('value2', T.IntegerType(), nullable=True),
                T.StructField('target', T.IntegerType(), nullable=True),
            ]),
        )

        df = (
            df
            .groupBy('id')
            .agg(
                F.max('target').alias('target'),
                *[
                    SF.argmax(col, 'target', condition=F.col(col).isNotNull()).alias(col)
                    for col in df.columns
                    if col not in ['id', 'target']
                ]
            )
        )

        self.assertDataFrameEqual(
            df,
            [
                {'id': '1', 'target': 4, 'value1': 'test1', 'value2': 2},
                {'id': '2', 'target': 2, 'value1': 'test3', 'value2': 4},
            ],
        ) 
Example 51
Project: sparkly   Author: tubular   File: test_functions.py    Apache License 2.0 5 votes
def test_break_ties(self):
        df = self.spark.createDataFrame(
            data=[
                ('1', 'test1', 1, 4),
                ('1', 'test2', 1, 3),
                ('2', 'test3', 1, 4),
                ('2', 'test4', 2, 3),
            ],
            schema=T.StructType([
                T.StructField('id', T.StringType(), nullable=True),
                T.StructField('value', T.StringType(), nullable=True),
                T.StructField('target1', T.IntegerType(), nullable=True),
                T.StructField('target2', T.IntegerType(), nullable=True),
            ]),
        )

        df = (
            df
            .groupBy('id')
            .agg(
                SF.argmax('value', ['target1', 'target2']).alias('value')
            )
        )

        self.assertDataFrameEqual(
            df,
            [
                {'id': '1', 'value': 'test1'},
                {'id': '2', 'value': 'test4'},
            ],
        ) 
Example 52
Project: sparkly   Author: tubular   File: test_functions.py    Apache License 2.0 5 votes
def test_with_column_expressions(self):
        df = self.spark.createDataFrame(
            data=[
                ('1', None, 'test1', 1, 4),
                ('1', 'test2', 'test2_1', 1, 3),
                ('2', 'test3', None, 1, 4),
                ('2', 'test4', 'test5', 2, 6),
            ],
            schema=T.StructType([
                T.StructField('id', T.StringType(), nullable=True),
                T.StructField('value1', T.StringType(), nullable=True),
                T.StructField('value2', T.StringType(), nullable=True),
                T.StructField('target1', T.IntegerType(), nullable=True),
                T.StructField('target2', T.IntegerType(), nullable=True),
            ]),
        )

        df = (
            df
            .groupBy('id')
            .agg(
                SF.argmax(
                    F.coalesce(F.col('value1'), F.col('value2')),
                    F.col('target1') + F.col('target2'),
                ).alias('value'),
            )
        )

        self.assertDataFrameEqual(
            df,
            [
                {'id': '1', 'value': 'test1'},
                {'id': '2', 'value': 'test4'},
            ],
        ) 
Example 53
Project: sparkly   Author: tubular   File: test_utils.py    Apache License 2.0 5 votes
def test_structs_equal(self):
        schema_has(
            T.StructType([
                T.StructField('f1', T.IntegerType()),
                T.StructField('f2', T.FloatType()),
                T.StructField('f3', T.StringType()),
            ]),
            T.StructType([
                T.StructField('f3', T.StringType()),
                T.StructField('f2', T.FloatType()),
                T.StructField('f1', T.IntegerType()),
            ]),
        ) 
Example 54
Project: sparkly   Author: tubular   File: test_utils.py    Apache License 2.0 5 votes
def test_structs_equal_with_dict(self):
        schema_has(
            T.StructType([
                T.StructField('f1', T.IntegerType()),
                T.StructField('f2', T.FloatType()),
                T.StructField('f3', T.StringType()),
            ]),
            {
                'f1': T.IntegerType(),
                'f2': T.FloatType(),
                'f3': T.StringType(),
            },
        ) 
Example 55
Project: sparkly   Author: tubular   File: test_utils.py    Apache License 2.0 5 votes
def test_structs_subset(self):
        schema_has(
            T.StructType([
                T.StructField('f1', T.IntegerType()),
                T.StructField('f2', T.FloatType()),
                T.StructField('f3', T.StringType()),
            ]),
            T.StructType([
                T.StructField('f2', T.FloatType()),
            ]),
        ) 
Example 56
Project: sparkly   Author: tubular   File: test_utils.py    Apache License 2.0 5 votes
def test_arrays_nested_subset(self):
        schema_has(
            T.ArrayType(T.ArrayType(T.StructType([
                T.StructField('f1', T.ArrayType(T.LongType())),
                T.StructField('f2', T.ArrayType(T.StringType())),
            ]))),
            T.ArrayType(T.ArrayType(T.StructType([
                T.StructField('f1', T.ArrayType(T.LongType()))
            ]))),
        ) 
Example 57
Project: sparkly   Author: tubular   File: test_utils.py    Apache License 2.0 5 votes
def test_undefined_field(self):
        with six.assertRaisesRegex(self, KeyError, 'f2'):
            schema_has(
                T.StructType([T.StructField('f1', T.IntegerType())]),
                T.StructType([T.StructField('f2', T.LongType())]),
            )

        with six.assertRaisesRegex(self, KeyError, r'f1\.element\.s2'):
            schema_has(
                T.StructType([
                    T.StructField(
                        'f1',
                        T.ArrayType(T.StructType([T.StructField('s1', T.IntegerType())])),
                    ),
                ]),
                T.StructType([
                    T.StructField(
                        'f1',
                        T.ArrayType(T.StructType([T.StructField('s2', T.LongType())])),
                    ),
                ]),
            )

        with six.assertRaisesRegex(self, TypeError, 'element is IntegerType, expected LongType'):
            schema_has(
                T.ArrayType(T.IntegerType()),
                T.ArrayType(T.LongType()),
            ) 
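Taken together, the schema_has checks above accept either a StructType or a plain dict as the expected schema, ignore field ordering, allow the expected schema to cover only a subset of the actual fields, and on mismatch raise KeyError or TypeError with a dotted path to the offending field (e.g. f1.element.s2). A minimal sketch of using it against a live DataFrame's schema; the sparkly.utils import path and the column names are assumptions, not taken from the tests above.

from pyspark.sql import types as T
from sparkly.utils import schema_has  # assumed import path

def assert_has_user_columns(df):
    # Passes if df exposes at least these fields with matching types,
    # regardless of field order or any extra columns.
    schema_has(
        df.schema,
        {
            'user_id': T.StringType(),
            'score': T.FloatType(),
        },
    )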
Example 58
Project: LearningApacheSpark   Author: runawayhorse001   File: base.py    MIT License 5 votes vote down vote up
def transformSchema(self, schema):
        inputType = schema[self.getInputCol()].dataType
        self.validateInputType(inputType)
        if self.getOutputCol() in schema.names:
            raise ValueError("Output column %s already exists." % self.getOutputCol())
        outputFields = copy.copy(schema.fields)
        outputFields.append(StructField(self.getOutputCol(),
                                        self.outputDataType(),
                                        nullable=False))
        return StructType(outputFields) 
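The method above follows a common pattern for ML transformer schemas: copy the input schema's fields and append a StructField for the output column, leaving the input schema untouched. A standalone sketch of the same pattern with plain pyspark types; the column names are placeholders.

import copy
from pyspark.sql.types import StructType, StructField, StringType, DoubleType

input_schema = StructType([StructField('text', StringType(), True)])

# Copy the existing fields and append the new, non-nullable output column.
output_fields = copy.copy(input_schema.fields)
output_fields.append(StructField('prediction', DoubleType(), nullable=False))
output_schema = StructType(output_fields)

print(output_schema.names)  # ['text', 'prediction']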
Example 59
Project: LearningApacheSpark   Author: runawayhorse001   File: evaluation.py    MIT License 5 votes vote down vote up
def __init__(self, scoreAndLabels):
        sc = scoreAndLabels.ctx
        sql_ctx = SQLContext.getOrCreate(sc)
        df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([
            StructField("score", DoubleType(), nullable=False),
            StructField("label", DoubleType(), nullable=False)]))
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
        java_model = java_class(df._jdf)
        super(BinaryClassificationMetrics, self).__init__(java_model) 
Example 60
Project: LearningApacheSpark   Author: runawayhorse001   File: evaluation.py    MIT License 5 votes vote down vote up
def __init__(self, predictionAndObservations):
        sc = predictionAndObservations.ctx
        sql_ctx = SQLContext.getOrCreate(sc)
        df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([
            StructField("prediction", DoubleType(), nullable=False),
            StructField("observation", DoubleType(), nullable=False)]))
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.RegressionMetrics
        java_model = java_class(df._jdf)
        super(RegressionMetrics, self).__init__(java_model) 
Example 61
Project: LearningApacheSpark   Author: runawayhorse001   File: evaluation.py    MIT License 5 votes vote down vote up
def __init__(self, predictionAndLabels):
        sc = predictionAndLabels.ctx
        sql_ctx = SQLContext.getOrCreate(sc)
        df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
            StructField("prediction", DoubleType(), nullable=False),
            StructField("label", DoubleType(), nullable=False)]))
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics
        java_model = java_class(df._jdf)
        super(MulticlassMetrics, self).__init__(java_model) 
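The three evaluation wrappers above share one shape: take an RDD of numeric pairs, wrap it in a DataFrame with an explicit non-nullable two-field StructType, and pass it to the matching JVM metrics class. A minimal usage sketch for the multiclass case, assuming a local SparkSession and made-up predictions.

from pyspark.sql import SparkSession
from pyspark.mllib.evaluation import MulticlassMetrics

spark = SparkSession.builder.master('local[1]').getOrCreate()
sc = spark.sparkContext

# RDD of (prediction, label) pairs, matching the schema built above.
prediction_and_labels = sc.parallelize([
    (0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (0.0, 0.0),
])

metrics = MulticlassMetrics(prediction_and_labels)
print(metrics.accuracy)  # 0.75 for the pairs above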
Example 62
Project: LearningApacheSpark   Author: runawayhorse001   File: tests.py    MIT License 5 votes vote down vote up
def test_apply_schema_to_dict_and_rows(self):
        schema = StructType().add("b", StringType()).add("a", IntegerType())
        input = [{"a": 1}, {"b": "coffee"}]
        rdd = self.sc.parallelize(input)
        for verify in [False, True]:
            df = self.spark.createDataFrame(input, schema, verifySchema=verify)
            df2 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
            self.assertEqual(df.schema, df2.schema)

            rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None))
            df3 = self.spark.createDataFrame(rdd, schema, verifySchema=verify)
            self.assertEqual(10, df3.count())
            input = [Row(a=x, b=str(x)) for x in range(10)]
            df4 = self.spark.createDataFrame(input, schema, verifySchema=verify)
            self.assertEqual(10, df4.count()) 
Example 63
Project: LearningApacheSpark   Author: runawayhorse001   File: tests.py    MIT License 5 votes vote down vote up
def test_apply_schema(self):
        from datetime import date, datetime
        rdd = self.sc.parallelize([(127, -128, -32768, 32767, 2147483647, 1.0,
                                    date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
                                    {"a": 1}, (2,), [1, 2, 3], None)])
        schema = StructType([
            StructField("byte1", ByteType(), False),
            StructField("byte2", ByteType(), False),
            StructField("short1", ShortType(), False),
            StructField("short2", ShortType(), False),
            StructField("int1", IntegerType(), False),
            StructField("float1", FloatType(), False),
            StructField("date1", DateType(), False),
            StructField("time1", TimestampType(), False),
            StructField("map1", MapType(StringType(), IntegerType(), False), False),
            StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
            StructField("list1", ArrayType(ByteType(), False), False),
            StructField("null1", DoubleType(), True)])
        df = self.spark.createDataFrame(rdd, schema)
        results = df.rdd.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1,
                             x.date1, x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
        r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
             datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
        self.assertEqual(r, results.first())

        df.createOrReplaceTempView("table2")
        r = self.spark.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
                           "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
                           "float1 + 1.5 as float1 FROM table2").first()

        self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r)) 
Example 64
Project: LearningApacheSpark   Author: runawayhorse001   File: tests.py    MIT License 5 votes vote down vote up
def test_udt(self):
        from pyspark.sql.types import _parse_datatype_json_string, _infer_type, _make_type_verifier
        from pyspark.sql.tests import ExamplePointUDT, ExamplePoint

        def check_datatype(datatype):
            pickled = pickle.loads(pickle.dumps(datatype))
            assert datatype == pickled
            scala_datatype = self.spark._jsparkSession.parseDataType(datatype.json())
            python_datatype = _parse_datatype_json_string(scala_datatype.json())
            assert datatype == python_datatype

        check_datatype(ExamplePointUDT())
        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
                                          StructField("point", ExamplePointUDT(), False)])
        check_datatype(structtype_with_udt)
        p = ExamplePoint(1.0, 2.0)
        self.assertEqual(_infer_type(p), ExamplePointUDT())
        _make_type_verifier(ExamplePointUDT())(ExamplePoint(1.0, 2.0))
        self.assertRaises(ValueError, lambda: _make_type_verifier(ExamplePointUDT())([1.0, 2.0]))

        check_datatype(PythonOnlyUDT())
        structtype_with_udt = StructType([StructField("label", DoubleType(), False),
                                          StructField("point", PythonOnlyUDT(), False)])
        check_datatype(structtype_with_udt)
        p = PythonOnlyPoint(1.0, 2.0)
        self.assertEqual(_infer_type(p), PythonOnlyUDT())
        _make_type_verifier(PythonOnlyUDT())(PythonOnlyPoint(1.0, 2.0))
        self.assertRaises(
            ValueError,
            lambda: _make_type_verifier(PythonOnlyUDT())([1.0, 2.0])) 
Example 65
Project: LearningApacheSpark   Author: runawayhorse001   File: tests.py    MIT License 5 votes vote down vote up
def test_simple_udt_in_df(self):
        schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
        df = self.spark.createDataFrame(
            [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
            schema=schema)
        df.collect() 
Example 66
Project: LearningApacheSpark   Author: runawayhorse001   File: tests.py    MIT License 5 votes vote down vote up
def test_nested_udt_in_df(self):
        schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
        df = self.spark.createDataFrame(
            [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
            schema=schema)
        df.collect()

        schema = StructType().add("key", LongType()).add("val",
                                                         MapType(LongType(), PythonOnlyUDT()))
        df = self.spark.createDataFrame(
            [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
            schema=schema)
        df.collect() 
Example 67
Project: LearningApacheSpark   Author: runawayhorse001   File: tests.py    MIT License 5 votes vote down vote up
def test_complex_nested_udt_in_df(self):
        from pyspark.sql.functions import udf

        schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
        df = self.spark.createDataFrame(
            [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
            schema=schema)
        df.collect()

        gd = df.groupby("key").agg({"val": "collect_list"})
        gd.collect()
        udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
        gd.select(udf(*gd)).collect() 
Example 68
Project: LearningApacheSpark   Author: runawayhorse001   File: tests.py    MIT License 5 votes vote down vote up
def test_cast_to_string_with_udt(self):
        from pyspark.sql.tests import ExamplePointUDT, ExamplePoint
        from pyspark.sql.functions import col
        row = (ExamplePoint(1.0, 2.0), PythonOnlyPoint(3.0, 4.0))
        schema = StructType([StructField("point", ExamplePointUDT(), False),
                             StructField("pypoint", PythonOnlyUDT(), False)])
        df = self.spark.createDataFrame([row], schema)

        result = df.select(col('point').cast('string'), col('pypoint').cast('string')).head()
        self.assertEqual(result, Row(point=u'(1.0, 2.0)', pypoint=u'[3.0, 4.0]')) 
Example 69
Project: incubator-spot   Author: apache   File: streaming.py    Apache License 2.0 4 votes vote down vote up
def schema(self):
        '''
            Return the data type that represents a row from the received data list.
        '''
        from pyspark.sql.types import IntegerType, LongType, StringType, StructField, StructType

        return StructType(
            [
                StructField('p_date', StringType(), True),
                StructField('p_time', StringType(), True),
                StructField('clientip', StringType(), True),
                StructField('host', StringType(), True),
                StructField('reqmethod', StringType(), True),
                StructField('useragent', StringType(), True),
                StructField('resconttype', StringType(), True),
                StructField('duration', LongType(), True),
                StructField('username', StringType(), True),
                StructField('authgroup', StringType(), True),
                StructField('exceptionid', StringType(), True),
                StructField('filterresult', StringType(), True),
                StructField('webcat', StringType(), True),
                StructField('referer', StringType(), True),
                StructField('respcode', StringType(), True),
                StructField('action', StringType(), True),
                StructField('urischeme', StringType(), True),
                StructField('uriport', StringType(), True),
                StructField('uripath', StringType(), True),
                StructField('uriquery', StringType(), True),
                StructField('uriextension', StringType(), True),
                StructField('serverip', StringType(), True),
                StructField('scbytes', IntegerType(), True),
                StructField('csbytes', IntegerType(), True),
                StructField('virusid', StringType(), True),
                StructField('bcappname', StringType(), True),
                StructField('bcappoper', StringType(), True),
                StructField('fulluri', StringType(), True),
                StructField('y', StringType(), True),
                StructField('m', StringType(), True),
                StructField('d', StringType(), True),
                StructField('h', StringType(), True)
            ]
        ) 
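A minimal sketch of materializing parsed records against a schema like the one above; the `listener` object (an instance of the class defining schema()), the SparkSession and the sample values are assumptions.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master('local[1]').getOrCreate()

# Dict rows may omit fields; columns that are not supplied come out as null.
rows = [{'p_date': '2019-01-01', 'clientip': '10.0.0.1', 'duration': 120}]
df = spark.createDataFrame(rows, schema=listener.schema())
df.select('p_date', 'clientip', 'duration', 'fulluri').show()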
Example 70
Project: search-MjoLniR   Author: wikimedia   File: features.py    MIT License 4 votes vote down vote up
def collect_from_ltr_plugin_and_kafka(df, brokers, model, feature_names_accu, indices=None):
    """Collect feature vectors from elasticsearch via kafka

    Pushes queries into a kafka topic and retrieves results from a second kafka topic.
    A daemon must be running on relforge to collect the queries and produce results.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        Source dataframe containing wikiid, query and hit_page_id fields
        to collect feature vectors for.
    brokers : list of str
        List of kafka brokers used to bootstrap access into the kafka cluster.
    model : string
        definition of the model/featureset: "featureset:name", "model:name" or "featureset:name@storeName"
    feature_names_accu : Accumulator
        used to collect feature names
    indices : dict, optional
        map from wikiid to the elasticsearch index to query. If a wikiid is
        not present in the map, the wikiid itself is used as the index name.
        (Default: None)
    """
    mjolnir.spark.assert_columns(df, ['wikiid', 'query', 'hit_page_id'])
    if indices is None:
        indices = {}
    eltType, name, store = mjolnir.utils.explode_ltr_model_definition(model)
    log_query = LtrLoggingQuery(eltType, name, store)

    def kafka_handle_response(record):
        assert record['status_code'] == 200
        parsed = json.loads(record['text'])
        response = parsed['responses'][0]
        meta = record['meta']

        for hit_page_id, features in extract_ltr_log_feature_values(response, feature_names_accu):
            yield [meta['wikiid'], meta['query'], hit_page_id, features]

    rdd = mjolnir.kafka.client.msearch(
        df.groupBy('wikiid', 'query').agg(F.collect_set('hit_page_id').alias('hit_page_ids')),
        client_config=brokers,
        meta_keys=['wikiid', 'query'],
        create_es_query=lambda row: log_query.make_msearch(row, indices),
        handle_response=kafka_handle_response)

    return df.sql_ctx.createDataFrame(rdd, T.StructType([
        df.schema['wikiid'], df.schema['query'], df.schema['hit_page_id'],
        T.StructField('features', VectorUDT(), nullable=False)
        # We could have gotten duplicate data from kafka. Clean them up.
    ])).drop_duplicates(['wikiid', 'query', 'hit_page_id']) 
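A hypothetical invocation of the function above, following its docstring; the module path, broker address, model name, index mapping, and the pre-built `df` and `feature_names_accu` accumulator are placeholders rather than values taken from the project.

import mjolnir.features  # assumed module path for the function above

features_df = mjolnir.features.collect_from_ltr_plugin_and_kafka(
    df,                                     # must provide wikiid, query, hit_page_id
    brokers=['kafka1001.example:9092'],     # placeholder broker list
    model='featureset:enwiki_features@_DEFAULT_',  # "featureset:name@storeName" form
    feature_names_accu=feature_names_accu,  # Spark accumulator collecting feature names
    indices={'enwiki': 'enwiki_content'},   # optional wikiid -> index mapping
)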
Example 71
Project: monasca-transform   Author: openstack   File: transform_utils.py    Apache License 2.0 4 votes vote down vote up
def _get_record_store_df_schema():
        """get instance usage schema."""

        columns = ["event_timestamp_string",
                   "event_type", "event_quantity_name",
                   "event_status", "event_version",
                   "record_type", "resource_uuid", "tenant_id",
                   "user_id", "region", "zone",
                   "host", "project_id",
                   "event_date", "event_hour", "event_minute",
                   "event_second", "metric_group", "metric_id"]

        columns_struct_fields = [StructField(field_name, StringType(), True)
                                 for field_name in columns]

        # Add columns for the non-string fields
        columns_struct_fields.insert(0,
                                     StructField("event_timestamp_unix",
                                                 DoubleType(), True))
        columns_struct_fields.insert(0,
                                     StructField("event_quantity",
                                                 DoubleType(), True))

        # map to metric meta
        columns_struct_fields.append(StructField("meta",
                                                 MapType(StringType(),
                                                         StringType(),
                                                         True),
                                                 True))
        # map to dimensions
        columns_struct_fields.append(StructField("dimensions",
                                                 MapType(StringType(),
                                                         StringType(),
                                                         True),
                                                 True))
        # map to value_meta
        columns_struct_fields.append(StructField("value_meta",
                                                 MapType(StringType(),
                                                         StringType(),
                                                         True),
                                                 True))

        schema = StructType(columns_struct_fields)

        return schema 
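Note that the two insert(0, ...) calls place the numeric columns ahead of the string columns, so event_quantity ends up first and event_timestamp_unix second. A minimal check of that ordering, assuming a local SparkSession; RecordStoreUtils is a guessed name for the class holding the static method above.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master('local[1]').getOrCreate()

schema = RecordStoreUtils._get_record_store_df_schema()  # class name is an assumption
empty_df = spark.createDataFrame([], schema=schema)
print(empty_df.columns[:2])  # ['event_quantity', 'event_timestamp_unix']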
Example 72
Project: sparkly   Author: tubular   File: test_writer.py    Apache License 2.0 4 votes vote down vote up
def test_composite_key_compressed(self):
        df = self.spark.createDataFrame(
            data=[
                ('k1', 'k14', [1, 14, 141]),
                ('k1', 'k12', [1, 12, 121]),
                ('k1', 'k11', [1, 11, 111]),
                ('k1', 'k13', [1, 13, 131]),
            ],
            schema=T.StructType([
                T.StructField('key_1', T.StringType()),
                T.StructField('key_2', T.StringType()),
                T.StructField('aux_data', T.ArrayType(T.IntegerType())),
            ])
        )

        df.write_ext.redis(
            key_by=['key_1', 'key_2'],
            compression='gzip',
            max_pipeline_size=3,
            host='redis.docker',
        )

        redis_client = redis.StrictRedis('redis.docker')

        self.assertRowsEqual(
            redis_client.keys(),
            [b'k1.k11', b'k1.k12', b'k1.k13', b'k1.k14'],
            ignore_order=True,
        )

        written_data = [
            json.loads(self._gzip_decompress(redis_client.get(key)))
            for key in ['k1.k11', 'k1.k12', 'k1.k13', 'k1.k14']
        ]

        expected = [
            {'key_1': 'k1', 'key_2': 'k11', 'aux_data': [1, 11, 111]},
            {'key_1': 'k1', 'key_2': 'k12', 'aux_data': [1, 12, 121]},
            {'key_1': 'k1', 'key_2': 'k13', 'aux_data': [1, 13, 131]},
            {'key_1': 'k1', 'key_2': 'k14', 'aux_data': [1, 14, 141]},
        ]

        self.assertEqual(written_data, expected) 
Example 73
Project: sparkly   Author: tubular   File: test_writer.py    Apache License 2.0 4 votes vote down vote up
def test_exclude_null_fields(self):
        df = self.spark.createDataFrame(
            data=[
                ('k1', 'k14', [1, 14, 141]),
                (None, 'k12', [1, 12, 121]),
                ('k1', 'k11', None),
                ('k1', 'k13', [1, 13, 131]),
            ],
            schema=T.StructType([
                T.StructField('key_1', T.StringType()),
                T.StructField('key_2', T.StringType()),
                T.StructField('aux_data', T.ArrayType(T.IntegerType())),
            ])
        )

        redis_client = redis.StrictRedis('redis.docker')

        df.write_ext.redis(
            key_by=['key_2'],
            key_prefix='hello',
            exclude_null_fields=True,
            host='redis.docker',
        )

        self.assertRowsEqual(
            redis_client.keys(),
            [b'hello.k11', b'hello.k12', b'hello.k13', b'hello.k14'],
            ignore_order=True,
        )

        written_data = [
            json.loads(redis_client.get(key))
            for key in ['hello.k11', 'hello.k12', 'hello.k13', 'hello.k14']
        ]

        expected = [
            {'key_1': 'k1', 'key_2': 'k11'},
            {'key_2': 'k12', 'aux_data': [1, 12, 121]},
            {'key_1': 'k1', 'key_2': 'k13', 'aux_data': [1, 13, 131]},
            {'key_1': 'k1', 'key_2': 'k14', 'aux_data': [1, 14, 141]},
        ]

        self.assertEqual(written_data, expected) 
Example 75
Project: sparkly   Author: tubular   File: test_writer.py    Apache License 2.0 4 votes vote down vote up
def test_expiration(self):
        df = self.spark.createDataFrame(
            data=[
                ('k1', 'k14', [1, 14, 141]),
                ('k1', 'k12', [1, 12, 121]),
                ('k1', 'k11', [1, 11, 111]),
                ('k1', 'k13', [1, 13, 131]),
            ],
            schema=T.StructType([
                T.StructField('key_1', T.StringType()),
                T.StructField('key_2', T.StringType()),
                T.StructField('aux_data', T.ArrayType(T.IntegerType())),
            ])
        )

        df.write_ext.redis(
            key_by=['key_2'],
            key_prefix='hello',
            expire=2,
            host='redis.docker',
        )

        redis_client = redis.StrictRedis('redis.docker')

        self.assertRowsEqual(
            redis_client.keys(),
            [b'hello.k11', b'hello.k12', b'hello.k13', b'hello.k14'],
            ignore_order=True,
        )

        written_data = [
            json.loads(redis_client.get(key))
            for key in ['hello.k11', 'hello.k12', 'hello.k13', 'hello.k14']
        ]

        expected = [
            {'key_1': 'k1', 'key_2': 'k11', 'aux_data': [1, 11, 111]},
            {'key_1': 'k1', 'key_2': 'k12', 'aux_data': [1, 12, 121]},
            {'key_1': 'k1', 'key_2': 'k13', 'aux_data': [1, 13, 131]},
            {'key_1': 'k1', 'key_2': 'k14', 'aux_data': [1, 14, 141]},
        ]

        self.assertEqual(written_data, expected)

        sleep(3)

        self.assertEqual(redis_client.keys(), []) 
Example 76
Project: sparkly   Author: tubular   File: test_writer.py    Apache License 2.0 4 votes vote down vote up
def test_invalid_input(self):
        df = self.spark.createDataFrame(
            data=[],
            schema=T.StructType([T.StructField('key_1', T.StringType())]),
        )

        with six.assertRaisesRegex(
                self,
                ValueError,
                'redis: expire must be positive',
        ):
            df.write_ext.redis(
                key_by=['key_1'],
                expire=0,
                host='redis.docker',
            )

        with six.assertRaisesRegex(
                self,
                ValueError,
                'redis: bzip2, gzip and zlib are the only supported compression codecs',
        ):
            df.write_ext.redis(
                key_by=['key_1'],
                compression='snappy',
                host='redis.docker',
            )

        with six.assertRaisesRegex(
                self,
                ValueError,
                'redis: max pipeline size must be positive',
        ):
            df.write_ext.redis(
                key_by=['key_1'],
                max_pipeline_size=0,
                host='redis.docker',
            )

        with six.assertRaisesRegex(
                self,
                ValueError,
                r'redis: only append \(default\), ignore and overwrite modes are supported',
        ):
            df.write_ext.redis(
                key_by=['key_1'],
                mode='error',
                host='redis.docker',
            )

        with six.assertRaisesRegex(
                self,
                AssertionError,
                'redis: At least one of host or redis_client_init must be provided',
        ):
            df.write_ext.redis(
                key_by=['key_1'],
            ) 
Example 77
Project: sparkly   Author: tubular   File: test_writer.py    Apache License 2.0 4 votes vote down vote up
def test_composite_key_with_prefix_compressed(self):
        df = self.spark.createDataFrame(
            data=[
                ('k1', 'k14', [1, 14, 141]),
                ('k1', 'k12', [1, 12, 121]),
                ('k1', 'k11', [1, 11, 111]),
                ('k1', 'k13', [1, 13, 131]),
            ],
            schema=T.StructType([
                T.StructField('key_1', T.StringType()),
                T.StructField('key_2', T.StringType()),
                T.StructField('aux_data', T.ArrayType(T.IntegerType())),
            ])
        )

        df.write_ext.by_url(
            'redis://redis.docker?'
            'keyBy=key_1,key_2&'
            'keyPrefix=hello&'
            'compression=gzip&'
            'maxPipelineSize=3'
        )

        redis_client = redis.StrictRedis('redis.docker')

        self.assertRowsEqual(
            redis_client.keys(),
            [b'hello.k1.k11', b'hello.k1.k12', b'hello.k1.k13', b'hello.k1.k14'],
            ignore_order=True,
        )

        written_data = [
            json.loads(self._gzip_decompress(redis_client.get(key)))
            for key in ['hello.k1.k11', 'hello.k1.k12', 'hello.k1.k13', 'hello.k1.k14']
        ]

        expected = [
            {'key_1': 'k1', 'key_2': 'k11', 'aux_data': [1, 11, 111]},
            {'key_1': 'k1', 'key_2': 'k12', 'aux_data': [1, 12, 121]},
            {'key_1': 'k1', 'key_2': 'k13', 'aux_data': [1, 13, 131]},
            {'key_1': 'k1', 'key_2': 'k14', 'aux_data': [1, 14, 141]},
        ]

        self.assertEqual(written_data, expected) 
Example 78
Project: sparkly   Author: tubular   File: test_writer.py    Apache License 2.0 4 votes vote down vote up
def test_expiration(self):
        df = self.spark.createDataFrame(
            data=[
                ('k1', 'k14', [1, 14, 141]),
                ('k1', 'k12', [1, 12, 121]),
                ('k1', 'k11', [1, 11, 111]),
                ('k1', 'k13', [1, 13, 131]),
            ],
            schema=T.StructType([
                T.StructField('key_1', T.StringType()),
                T.StructField('key_2', T.StringType()),
                T.StructField('aux_data', T.ArrayType(T.IntegerType())),
            ])
        )

        df.write_ext.by_url('redis://redis.docker?keyBy=key_2&keyPrefix=hello&expire=2')

        redis_client = redis.StrictRedis('redis.docker')

        self.assertRowsEqual(
            redis_client.keys(),
            [b'hello.k11', b'hello.k12', b'hello.k13', b'hello.k14'],
            ignore_order=True,
        )

        written_data = [
            json.loads(redis_client.get(key))
            for key in ['hello.k11', 'hello.k12', 'hello.k13', 'hello.k14']
        ]

        expected = [
            {'key_1': 'k1', 'key_2': 'k11', 'aux_data': [1, 11, 111]},
            {'key_1': 'k1', 'key_2': 'k12', 'aux_data': [1, 12, 121]},
            {'key_1': 'k1', 'key_2': 'k13', 'aux_data': [1, 13, 131]},
            {'key_1': 'k1', 'key_2': 'k14', 'aux_data': [1, 14, 141]},
        ]

        self.assertEqual(written_data, expected)

        sleep(3)

        self.assertEqual(redis_client.keys(), []) 
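The by_url tests above mirror the keyword arguments of write_ext.redis as query-string options: keyBy, keyPrefix, compression, maxPipelineSize and expire. A small sketch combining the options seen in this section on the same four-row DataFrame used above; whether every option can be combined in a single URL is an assumption.

df.write_ext.by_url(
    'redis://redis.docker?'
    'keyBy=key_1,key_2&'    # composite key, parts joined with '.'
    'keyPrefix=hello&'      # prefix prepended to every generated key
    'compression=gzip&'     # values are stored gzip-compressed
    'maxPipelineSize=3&'    # cap on commands per redis pipeline
    'expire=2'              # TTL in seconds for the written keys
)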
Example 79
Project: sparkly   Author: tubular   File: test_functions.py    Apache License 2.0 4 votes vote down vote up
def test_with_conditions(self):
        df = self.spark.createDataFrame(
            data=[
                ('1', 'test1', 2),
                ('1', 'test2', 1),
                ('2', 'test3', 1),
                ('2', 'test4', 2),
            ],
            schema=T.StructType([
                T.StructField('id', T.StringType(), nullable=True),
                T.StructField('value', T.StringType(), nullable=True),
                T.StructField('target1', T.IntegerType(), nullable=True),
            ]),
        )

        df = (
            df
            .groupBy('id')
            .agg(
                SF.argmax(
                    'value',
                    'target1',
                    condition=F.col('value') != 'test1',
                ).alias('value'),
            )
        )

        self.assertDataFrameEqual(
            df,
            [
                {'id': '1', 'value': 'test2'},
                {'id': '2', 'value': 'test4'},
            ],
        ) 
Example 80
Project: sparkly   Author: tubular   File: test_utils.py    Apache License 2.0 4 votes vote down vote up
def test_type_mismatch(self):
        with six.assertRaisesRegex(self, AssertionError, 'Cannot compare heterogeneous types'):
            schema_has(
                T.StructType([T.StructField('f1', T.IntegerType())]),
                T.ArrayType(T.IntegerType()),
            )

        with six.assertRaisesRegex(self, AssertionError, 'Cannot compare heterogeneous types'):
            schema_has(
                T.ArrayType(T.IntegerType()),
                {'f1': T.IntegerType()},
            )

        with six.assertRaisesRegex(self, TypeError, 'f1 is IntegerType, expected LongType'):
            schema_has(
                T.StructType([T.StructField('f1', T.IntegerType())]),
                T.StructType([T.StructField('f1', T.LongType())]),
            )

        with six.assertRaisesRegex(
                self,
                TypeError,
                r'f1\.element\.s1 is IntegerType, expected LongType',
        ):
            schema_has(
                T.StructType([
                    T.StructField(
                        'f1',
                        T.ArrayType(T.StructType([T.StructField('s1', T.IntegerType())])),
                    ),
                ]),
                T.StructType([
                    T.StructField(
                        'f1',
                        T.ArrayType(T.StructType([T.StructField('s1', T.LongType())])),
                    ),
                ]),
            )

        with six.assertRaisesRegex(self, TypeError, 'element is IntegerType, expected LongType'):
            schema_has(
                T.ArrayType(T.IntegerType()),
                T.ArrayType(T.LongType()),
            )

        with six.assertRaisesRegex(self, TypeError, 'key is StringType, expected LongType'):
            schema_has(
                T.MapType(T.StringType(), T.IntegerType()),
                T.MapType(T.LongType(), T.IntegerType()),
            )

        with six.assertRaisesRegex(self, TypeError, 'value is IntegerType, expected LongType'):
            schema_has(
                T.MapType(T.StringType(), T.IntegerType()),
                T.MapType(T.StringType(), T.LongType()),
            )