Python pyspark.sql.types.IntegerType() Examples

The following are 27 code examples of pyspark.sql.types.IntegerType(), drawn from open-source projects. Each example notes its source file and the project and license it comes from. You may also want to check out the other available types and functions in the pyspark.sql.types module.
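Before the project examples, here is a minimal standalone sketch of the two uses of IntegerType() that recur throughout this page: declaring an integer column in a StructType schema and declaring the return type of a UDF. The session, column names, and UDF below are illustrative, not taken from any of the projects listed.

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

spark = SparkSession.builder.master("local[2]").getOrCreate()

# Declare an integer column explicitly in a schema.
schema = StructType([
    StructField("name", StringType(), True),
    StructField("age", IntegerType(), True),
])
df = spark.createDataFrame([("Alice", 34), ("Bob", 45)], schema)

# Use IntegerType() as the declared return type of a UDF.
name_length = udf(lambda s: len(s), IntegerType())
df.withColumn("name_length", name_length(df["name"])).show()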
Example #1
Source File: named_image_test.py    From spark-deep-learning with Apache License 2.0
def test_featurizer_in_pipeline(self):
        """
        Tests that featurizer fits into an MLlib Pipeline.
        Does not test how good the featurization is for generalization.
        """
        featurizer = DeepImageFeaturizer(inputCol="image", outputCol="features",
                                         modelName=self.name)
        lr = LogisticRegression(maxIter=20, regParam=0.05, elasticNetParam=0.3, labelCol="label")
        pipeline = Pipeline(stages=[featurizer, lr])

        # add arbitrary labels to run logistic regression
        # TODO: it's weird that the test fails on some combinations of labels. check why.
        label_udf = udf(lambda x: abs(hash(x)) % 2, IntegerType())
        train_df = self.imageDF.withColumn("label", label_udf(self.imageDF["image"]["origin"]))

        lrModel = pipeline.fit(train_df)
        # see if we at least get the training examples right.
        # with 5 examples and e.g. 131k features (for InceptionV3), it ought to.
        pred_df_collected = lrModel.transform(train_df).collect()
        for row in pred_df_collected:
            self.assertEqual(int(row.prediction), row.label) 
Example #2
Source File: reddit_average_sql.py    From Hanhan-Spark-Python with MIT License
def main():
    schema = StructType([
    StructField('subreddit', StringType(), False),
    StructField('score', IntegerType(), False),
    ])
    inputs = sqlContext.read.json(inputs1, schema=schema)

    # Uncomment this when the schema is not provided
    # inputs = sqlContext.read.json(inputs1)

    # Uncomment these when there are 2 input dirs
    # comments_input1 = sqlContext.read.json(inputs1, schema=schema)
    # comments_input2 = sqlContext.read.json(inputs2, schema=schema)
    # inputs = comments_input1.unionAll(comments_input2)

    df = get_avg(inputs)
    df.write.save(output, format='json', mode='overwrite') 
Example #3
Source File: test_unischema.py    From petastorm with Apache License 2.0
def test_as_spark_schema():
    """Try using 'as_spark_schema' function"""
    TestSchema = Unischema('TestSchema', [
        UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False),
        UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
        UnischemaField('string_field_implicit', np.string_, ()),
    ])

    spark_schema = TestSchema.as_spark_schema()
    assert spark_schema.fields[0].name == 'int_field'

    assert spark_schema.fields[1].name == 'string_field'
    assert spark_schema.fields[1].dataType == StringType()

    assert spark_schema.fields[2].name == 'string_field_implicit'
    assert spark_schema.fields[2].dataType == StringType()

    assert TestSchema.fields['int_field'].name == 'int_field'
    assert TestSchema.fields['string_field'].name == 'string_field' 
Example #4
Source File: test_dataset_metadata.py    From petastorm with Apache License 2.0
def test_serialize_filesystem_factory(tmpdir):
    SimpleSchema = Unischema('SimpleSchema', [
        UnischemaField('id', np.int32, (), ScalarCodec(IntegerType()), False),
        UnischemaField('foo', np.int32, (), ScalarCodec(IntegerType()), False),
    ])

    class BogusFS(pyarrow.LocalFileSystem):
        def __getstate__(self):
            raise RuntimeError("can not serialize")

    rows_count = 10
    output_url = "file://{0}/fs_factory_test".format(tmpdir)
    rowgroup_size_mb = 256
    spark = SparkSession.builder.config('spark.driver.memory', '2g').master('local[2]').getOrCreate()
    sc = spark.sparkContext
    with materialize_dataset(spark, output_url, SimpleSchema, rowgroup_size_mb, filesystem_factory=BogusFS):
        rows_rdd = sc.parallelize(range(rows_count))\
            .map(lambda x: {'id': x, 'foo': x})\
            .map(lambda x: dict_to_spark_row(SimpleSchema, x))

        spark.createDataFrame(rows_rdd, SimpleSchema.as_spark_schema()) \
            .write \
            .parquet(output_url) 
Example #5
Source File: accuracy.py    From HoloClean-Legacy-deprecated with Apache License 2.0
def read_groundtruth(self):

        """
        Create a dataframe from the ground truth csv file

        Takes as argument the full path name of the csv file
        and the spark_session
        """
        filereader = Reader(self.spark_session)

        groundtruth_schema = StructType([
            StructField("tid", IntegerType(), False),
            StructField("attr_name", StringType(), False),
            StructField("attr_val", StringType(), False)])

        self.ground_truth_flat = filereader.read(self.path_to_grand_truth, 0,
                                                 groundtruth_schema).\
            drop(GlobalVariables.index_name)

        self.dataengine.add_db_table(
            'Groundtruth', self.ground_truth_flat, self.dataset) 
Example #6
Source File: sample_scaffolds.py    From reinvent-scaffold-decorator with MIT License
def _join_results_single(self, scaffolds_df, sampled_df):
        def _join_scaffold(scaff, decs):
            mol = usc.join_joined_attachments(scaff, decs)
            if mol:
                return usc.to_smiles(mol)
        join_scaffold_udf = psf.udf(_join_scaffold, pst.StringType())

        def _create_decorations_map(decorations_smi, attachment_points):
            decorations = decorations_smi.split(usc.ATTACHMENT_SEPARATOR_TOKEN)
            return {idx: _cleanup_decoration(dec) for dec, idx in zip(decorations, attachment_points)}
        create_decorations_map_udf = psf.udf(_create_decorations_map, pst.MapType(pst.IntegerType(), pst.StringType()))

        return sampled_df.join(scaffolds_df, on="id")\
            .select(
                join_scaffold_udf("randomized_scaffold", "decoration_smi").alias("smiles"),
                create_decorations_map_udf("decoration_smi", "attachment_points").alias("decorations"),
                "scaffold") 
Example #7
Source File: utils.py    From mlflow with Apache License 2.0
def format_to_file_path(spark_session):
    rows = [
        Row(8, 32, "bat"),
        Row(64, 40, "mouse"),
        Row(-27, 55, "horse")
    ]
    schema = StructType([
        StructField("number2", IntegerType()),
        StructField("number1", IntegerType()),
        StructField("word", StringType())
    ])
    rdd = spark_session.sparkContext.parallelize(rows)
    df = spark_session.createDataFrame(rdd, schema)
    res = {}
    tempdir = tempfile.mkdtemp()
    for data_format in ["csv", "parquet", "json"]:
        res[data_format] = os.path.join(tempdir, "test-data-%s" % data_format)

    for data_format, file_path in res.items():
        df.write.option("header", "true").format(data_format).save(file_path)
    yield res
    shutil.rmtree(tempdir) 
Example #8
Source File: sample_scaffolds.py    From reinvent-scaffold-decorator with MIT License
def _initialize_results(self, scaffolds):
        data = [ps.Row(smiles=scaffold, scaffold=scaffold,
                       decorations={}, count=1) for scaffold in scaffolds]
        data_schema = pst.StructType([
            pst.StructField("smiles", pst.StringType()),
            pst.StructField("scaffold", pst.StringType()),
            pst.StructField("decorations", pst.MapType(pst.IntegerType(), pst.StringType())),
            pst.StructField("count", pst.IntegerType())
        ])
        return SPARK.createDataFrame(data, schema=data_schema) 
Example #9
Source File: criteo.py    From azure-python-labs with MIT License
def get_spark_schema(header=DEFAULT_HEADER):
    ## create schema
    schema = StructType()
    ## do label + ints
    n_ints = 14
    for i in range(n_ints):
        schema.add(StructField(header[i], IntegerType()))
    ## do categoricals
    for i in range(26):
        schema.add(StructField(header[i + n_ints], StringType()))
    return schema 
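A hedged usage sketch for get_spark_schema above, assuming the function is in scope. DEFAULT_HEADER is not shown in the snippet, so a 40-column header (a label plus 13 integer features and 26 categorical features, which is the layout implied by the two loops) is built inline as a stand-in.

from pyspark.sql.types import IntegerType, StringType

# Illustrative stand-in for DEFAULT_HEADER, which is defined elsewhere in criteo.py.
header = (["label"]
          + ["int{:02d}".format(i) for i in range(13)]
          + ["cat{:02d}".format(i) for i in range(26)])

schema = get_spark_schema(header)
assert len(schema.fields) == 40
assert isinstance(schema["label"].dataType, IntegerType)
assert isinstance(schema["cat00"].dataType, StringType)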
Example #10
Source File: typehints.py    From koalas with Apache License 2.0
def as_spark_type(tpe) -> types.DataType:
    """
    Given a python type, returns the equivalent spark type.
    Accepts:
    - the built-in types in python
    - the built-in types in numpy
    - list of pairs of (field_name, type)
    - dictionaries of field_name -> type
    - python3's typing system
    """
    if tpe in (str, "str", "string"):
        return types.StringType()
    elif tpe in (bytes,):
        return types.BinaryType()
    elif tpe in (np.int8, "int8", "byte"):
        return types.ByteType()
    elif tpe in (np.int16, "int16", "short"):
        return types.ShortType()
    elif tpe in (int, "int", np.int, np.int32):
        return types.IntegerType()
    elif tpe in (np.int64, "int64", "long", "bigint"):
        return types.LongType()
    elif tpe in (float, "float", np.float):
        return types.FloatType()
    elif tpe in (np.float64, "float64", "double"):
        return types.DoubleType()
    elif tpe in (datetime.datetime, np.datetime64):
        return types.TimestampType()
    elif tpe in (datetime.date,):
        return types.DateType()
    elif tpe in (bool, "boolean", "bool", np.bool):
        return types.BooleanType()
    elif tpe in (np.ndarray,):
        # TODO: support other child types
        return types.ArrayType(types.StringType())
    else:
        raise TypeError("Type %s was not understood." % tpe) 
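A short usage sketch for as_spark_type, assuming the function above is in scope (in koalas it lives in the typehints module shown in the header). The assertions simply restate branches from the code above.

import numpy as np
from pyspark.sql import types

assert as_spark_type(int) == types.IntegerType()
assert as_spark_type("int64") == types.LongType()
assert as_spark_type(np.float64) == types.DoubleType()
assert as_spark_type(bool) == types.BooleanType()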
Example #11
Source File: __init__.py    From listenbrainz-server with GNU General Public License v2.0
def upload_test_playcounts(cls):
        schema = StructType(
            [
                StructField("user_id", IntegerType()),
                StructField("recording_id", IntegerType()),
                StructField("count", IntegerType())
            ]
        )
        test_playcounts = []
        for i in range(1, PLAYCOUNTS_COUNT // 2 + 1):
            test_playcounts.append([1, 1, 1])
        for i in range(PLAYCOUNTS_COUNT // 2 + 1, PLAYCOUNTS_COUNT + 1):
            test_playcounts.append([2, 2, 1])
        test_playcounts_df = listenbrainz_spark.session.createDataFrame(test_playcounts, schema=schema)
        utils.save_parquet(test_playcounts_df, TEST_PLAYCOUNTS_PATH) 
Example #12
Source File: test_pyspark.py    From dagster with Apache License 2.0
def make_df_solid(context):
    schema = StructType([StructField('name', StringType()), StructField('age', IntegerType())])
    rows = [Row(name='John', age=19), Row(name='Jennifer', age=29), Row(name='Henry', age=50)]
    return context.resources.pyspark.spark_session.createDataFrame(rows, schema) 
Example #13
Source File: repo.py    From dagster with Apache License 2.0
def make_people(context) -> DataFrame:
    schema = StructType([StructField('name', StringType()), StructField('age', IntegerType())])
    rows = [Row(name='Thom', age=51), Row(name='Jonny', age=48), Row(name='Nigel', age=49)]
    return context.resources.pyspark.spark_session.createDataFrame(rows, schema) 
Example #14
Source File: tests.py    From LearningApacheSpark with MIT License
def test_unary_transformer_validate_input_type(self):
        shiftVal = 3
        transformer = MockUnaryTransformer(shiftVal=shiftVal)\
            .setInputCol("input").setOutputCol("output")

        # should not raise any errors
        transformer.validateInputType(DoubleType())

        with self.assertRaises(TypeError):
            # passing the wrong input type should raise an error
            transformer.validateInputType(IntegerType()) 
Example #15
Source File: test_spark.py    From mlflow with Apache License 2.0
def test_spark_udf(spark, model_path):
    mlflow.pyfunc.save_model(
        path=model_path,
        loader_module=__name__,
        code_path=[os.path.dirname(tests.__file__)],
    )
    reloaded_pyfunc_model = mlflow.pyfunc.load_pyfunc(model_path)

    pandas_df = pd.DataFrame(data=np.ones((10, 10)), columns=[str(i) for i in range(10)])
    spark_df = spark.createDataFrame(pandas_df)

    # Test all supported return types
    type_map = {"float": (FloatType(), np.number),
                "int": (IntegerType(), np.int32),
                "double": (DoubleType(), np.number),
                "long": (LongType(), np.int),
                "string": (StringType(), None)}

    for tname, tdef in type_map.items():
        spark_type, np_type = tdef
        prediction_df = reloaded_pyfunc_model.predict(pandas_df)
        for is_array in [True, False]:
            t = ArrayType(spark_type) if is_array else spark_type
            if tname == "string":
                expected = prediction_df.applymap(str)
            else:
                expected = prediction_df.select_dtypes(np_type)
                if tname == "float":
                    expected = expected.astype(np.float32)

            expected = [list(row[1]) if is_array else row[1][0] for row in expected.iterrows()]
            pyfunc_udf = spark_udf(spark, model_path, result_type=t)
            new_df = spark_df.withColumn("prediction", pyfunc_udf(*pandas_df.columns))
            actual = list(new_df.select("prediction").toPandas()['prediction'])
            assert expected == actual
            if not is_array:
                pyfunc_udf = spark_udf(spark, model_path, result_type=tname)
                new_df = spark_df.withColumn("prediction", pyfunc_udf(*pandas_df.columns))
                actual = list(new_df.select("prediction").toPandas()['prediction'])
                assert expected == actual 
Example #16
Source File: codecs.py    From petastorm with Apache License 2.0
def encode(self, unischema_field, value):
        # Lazy loading pyspark to avoid creating pyspark dependency on data reading code path
        # (currently works only with make_batch_reader). We should move all pyspark related code into a separate module
        import pyspark.sql.types as sql_types

        # We treat ndarrays with shape=() as scalars
        unsized_numpy_array = isinstance(value, np.ndarray) and value.shape == ()
        # Validate the input to be a scalar (or an unsized numpy array)
        if not unsized_numpy_array and hasattr(value, '__len__') and (not isinstance(value, str)):
            raise TypeError('Expected a scalar as a value for field \'{}\'. '
                            'Got a non-numpy type \'{}\''.format(unischema_field.name, type(value)))

        if unischema_field.shape:
            raise ValueError('The shape field of unischema_field \'%s\' must be an empty tuple (i.e. \'()\') '
                             'to indicate a scalar. However, the actual shape is %s'
                             % (unischema_field.name, unischema_field.shape))
        if isinstance(self._spark_type, (sql_types.ByteType, sql_types.ShortType, sql_types.IntegerType,
                                         sql_types.LongType)):
            return int(value)
        if isinstance(self._spark_type, (sql_types.FloatType, sql_types.DoubleType)):
            return float(value)
        if isinstance(self._spark_type, sql_types.BooleanType):
            return bool(value)
        if isinstance(self._spark_type, sql_types.StringType):
            if not isinstance(value, str):
                raise ValueError(
                    'Expected a string value for field {}. Got type {}'.format(unischema_field.name, type(value)))
            return str(value)

        return value 
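A brief sketch of how this encode path is typically reached through ScalarCodec, using the same petastorm imports that appear in the Unischema examples on this page; the field definition itself is illustrative.

import numpy as np
from pyspark.sql.types import IntegerType
from petastorm.codecs import ScalarCodec
from petastorm.unischema import UnischemaField

id_field = UnischemaField('id', np.int32, (), ScalarCodec(IntegerType()), False)

# Integer-backed Spark types are coerced to a plain Python int.
assert id_field.codec.encode(id_field, np.int32(7)) == 7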
Example #17
Source File: test_unischema.py    From petastorm with Apache License 2.0
def test_create_schema_view_using_regex_and_unischema_fields_with_duplicates():
    TestSchema = Unischema('TestSchema', [
        UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False),
        UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
        UnischemaField('other_string_field', np.string_, (), ScalarCodec(StringType()), False),
    ])
    view = TestSchema.create_schema_view(['int.*$', TestSchema.int_field])
    assert set(view.fields.keys()) == {'int_field'} 
Example #18
Source File: test_unischema.py    From petastorm with Apache License 2.0
def test_create_schema_view_using_regex_and_unischema_fields():
    TestSchema = Unischema('TestSchema', [
        UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False),
        UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
        UnischemaField('other_string_field', np.string_, (), ScalarCodec(StringType()), False),
    ])
    view = TestSchema.create_schema_view(['int.*$', TestSchema.string_field])
    assert set(view.fields.keys()) == {'int_field', 'string_field'} 
Example #19
Source File: test_unischema.py    From petastorm with Apache License 2.0
def test_create_schema_view_using_regex():
    TestSchema = Unischema('TestSchema', [
        UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False),
        UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
    ])
    view = TestSchema.create_schema_view(['int.*$'])
    assert set(view.fields.keys()) == {'int_field'}

    view = TestSchema.create_schema_view([u'int.*$'])
    assert set(view.fields.keys()) == {'int_field'} 
Example #20
Source File: test_unischema.py    From petastorm with Apache License 2.0
def test_create_schema_view_using_unischema_fields():
    TestSchema = Unischema('TestSchema', [
        UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False),
        UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
    ])
    view = TestSchema.create_schema_view([TestSchema.int_field])
    assert set(view.fields.keys()) == {'int_field'} 
Example #21
Source File: test_unischema.py    From petastorm with Apache License 2.0
def test_create_schema_view_using_invalid_type():
    """ Exercises code paths unischema.create_schema_view ValueError, and unischema.__str__."""
    TestSchema = Unischema('TestSchema', [
        UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False),
        UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
    ])
    with pytest.raises(ValueError, match='must be either a string'):
        TestSchema.create_schema_view([42]) 
Example #22
Source File: test_unischema.py    From petastorm with Apache License 2.0
def test_fields():
    """Try using 'fields' getter"""
    TestSchema = Unischema('TestSchema', [
        UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False),
        UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
    ])

    assert len(TestSchema.fields) == 2
    assert TestSchema.fields['int_field'].name == 'int_field'
    assert TestSchema.fields['string_field'].name == 'string_field' 
Example #23
Source File: test_predicates.py    From petastorm with Apache License 2.0
def test_predicate_on_partitioned_dataset(tmpdir):
    """
    Generates a partitioned dataset and ensures that readers evaluate the type of the partition
    column according to the type given in the Unischema.
    """
    TestSchema = Unischema('TestSchema', [
        UnischemaField('id', np.int32, (), ScalarCodec(IntegerType()), False),
        UnischemaField('test_field', np.int32, (), ScalarCodec(IntegerType()), False),
    ])

    def test_row_generator(x):
        """Returns a single entry in the generated dataset."""
        return {'id': x,
                'test_field': x*x}

    rowgroup_size_mb = 256
    dataset_url = "file://{0}/partitioned_test_dataset".format(tmpdir)

    spark = SparkSession.builder.config('spark.driver.memory', '2g').master('local[2]').getOrCreate()
    sc = spark.sparkContext

    rows_count = 10
    with materialize_dataset(spark, dataset_url, TestSchema, rowgroup_size_mb):

        rows_rdd = sc.parallelize(range(rows_count))\
            .map(test_row_generator)\
            .map(lambda x: dict_to_spark_row(TestSchema, x))

        spark.createDataFrame(rows_rdd, TestSchema.as_spark_schema()) \
            .write \
            .partitionBy('id') \
            .parquet(dataset_url)

    with make_reader(dataset_url, predicate=in_lambda(['id'], lambda x: x == 3)) as reader:
        assert next(reader).id == 3
    with make_reader(dataset_url, predicate=in_lambda(['id'], lambda x: x == '3')) as reader:
        with pytest.raises(StopIteration):
            # Predicate should have selected none, so a StopIteration should be raised.
            next(reader) 
Example #24
Source File: unischema.py    From petastorm with Apache License 2.0
def _numpy_to_spark_mapping():
    """Returns a mapping from numpy to pyspark.sql type. Caches the mapping dictionary inorder to avoid instantiation
    of multiple objects in each call."""

    # Refer to the function attribute used to cache the map through a variable holding its name, rather than 'dot'
    # notation, to avoid copy/paste/typo mistakes
    cache_attr_name = 'cached_numpy_to_pyspark_types_map'
    if not hasattr(_numpy_to_spark_mapping, cache_attr_name):
        import pyspark.sql.types as T

        setattr(_numpy_to_spark_mapping, cache_attr_name,
                {
                    np.int8: T.ByteType(),
                    np.uint8: T.ShortType(),
                    np.int16: T.ShortType(),
                    np.uint16: T.IntegerType(),
                    np.int32: T.IntegerType(),
                    np.int64: T.LongType(),
                    np.float32: T.FloatType(),
                    np.float64: T.DoubleType(),
                    np.string_: T.StringType(),
                    np.str_: T.StringType(),
                    np.unicode_: T.StringType(),
                    np.bool_: T.BooleanType(),
                })

    return getattr(_numpy_to_spark_mapping, cache_attr_name)


# TODO: Changing fields in this class or the UnischemaField will break reading due to the schema being pickled next to
# the dataset on disk 
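A small illustrative check of the cached mapping above, assuming the private helper _numpy_to_spark_mapping is in scope; it is internal to petastorm, so this is for exposition only.

import numpy as np
from pyspark.sql.types import IntegerType, LongType

mapping = _numpy_to_spark_mapping()
assert mapping[np.int32] == IntegerType()
assert mapping[np.int64] == LongType()

# Repeated calls return the same cached dictionary.
assert _numpy_to_spark_mapping() is mapping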
Example #25
Source File: schema_utils.py    From eva with Apache License 2.0
def get_petastorm_column(df_column):

        column_type = df_column.type
        column_name = df_column.name
        column_is_nullable = df_column.is_nullable
        column_array_dimensions = df_column.array_dimensions

        # Reference:
        # https://github.com/uber/petastorm/blob/master/petastorm/
        # tests/test_common.py

        petastorm_column = None
        if column_type == ColumnType.INTEGER:
            petastorm_column = UnischemaField(column_name,
                                              np.int32,
                                              (),
                                              ScalarCodec(IntegerType()),
                                              column_is_nullable)
        elif column_type == ColumnType.FLOAT:
            petastorm_column = UnischemaField(column_name,
                                              np.float64,
                                              (),
                                              ScalarCodec(FloatType()),
                                              column_is_nullable)
        elif column_type == ColumnType.TEXT:
            petastorm_column = UnischemaField(column_name,
                                              np.string_,
                                              (),
                                              ScalarCodec(StringType()),
                                              column_is_nullable)
        elif column_type == ColumnType.NDARRAY:
            petastorm_column = UnischemaField(column_name,
                                              np.uint8,
                                              column_array_dimensions,
                                              NdarrayCodec(),
                                              column_is_nullable)
        else:
            LoggingManager().log("Invalid column type: " + str(column_type),
                                 LoggingLevel.ERROR)

        return petastorm_column 
Example #26
Source File: test_dataset.py    From python_moztelemetry with Mozilla Public License 2.0
def test_dataframe_with_schema(dataset, spark):
    schema = StructType([StructField("foo", IntegerType(), True)])
    df = dataset.dataframe(spark, decode=decode, schema=schema, table_name='bar')

    assert type(df) == DataFrame
    assert df.columns == ['foo']
    assert df.orderBy(["foo"]).collect() == [Row(foo=1), Row(foo=2)] 
Example #27
Source File: dbn.py    From search-MjoLniR with MIT License
def train(df, dbn_config):
    """Generate relevance labels for the provided dataframe.

    Process the provided data frame to generate relevance scores for
    all provided pairs of (wikiid, norm_query_id, hit_page_id). The input
    DataFrame must have a row per hit_page_id that was seen by a session.

    Parameters
    ----------
    df : pyspark.sql.DataFrame
        User click logs with columns wikiid, norm_query_id, session_id,
        hit_page_id, hit_position, clicked.
    dbn_config : dict
        Configuration needed by the DBN. See scala implementation docs
        for more information.

    Returns
    -------
    pyspark.sql.DataFrame
        DataFrame with columns wikiid, norm_query_id, hit_page_id, relevance.
    """

    df = (
        df
        .withColumn('hit_page_id', F.col('hit_page_id').cast(T.IntegerType()))
        .withColumn('hit_position', F.col('hit_position').cast(T.IntegerType())))
    jvm = df._sc._jvm
    # jvm side expects Map[String, String]
    j_config = jvm.PythonUtils.toScalaMap({str(k): str(v) for k, v in dbn_config.items()})
    assert j_config.size() == len(dbn_config)
    j_df = jvm.org.wikimedia.search.mjolnir.DBN.train(df._jdf, j_config)
    return pyspark.sql.DataFrame(j_df, df.sql_ctx)
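A hypothetical invocation sketch for train above. It assumes a SparkSession whose classpath includes the org.wikimedia.search.mjolnir JVM artifact (required by the DBN.train call) and uses placeholder configuration keys; only the input column layout is taken from the docstring.

from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.getOrCreate()  # mjolnir jar assumed to be on the classpath

click_log = spark.createDataFrame([
    Row(wikiid='enwiki', norm_query_id=1, session_id='s1',
        hit_page_id=101, hit_position=0, clicked=True),
    Row(wikiid='enwiki', norm_query_id=1, session_id='s1',
        hit_page_id=102, hit_position=1, clicked=False),
])

# Placeholder keys; the real options are documented with the Scala DBN implementation.
dbn_config = {'PLACEHOLDER_OPTION': 'value'}

relevance = train(click_log, dbn_config)
# relevance has columns: wikiid, norm_query_id, hit_page_id, relevance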