Python pyspark.sql.types.ArrayType() Examples

The following are 26 code examples of pyspark.sql.types.ArrayType(). You can vote up the ones you like or vote down the ones you don't like, and go to the original project or source file by following the links above each example. You may also want to check out all available functions/classes of the module pyspark.sql.types, or try the search function.
Example #1
Source File: named_image.py    From spark-deep-learning with Apache License 2.0 6 votes vote down vote up
def _decodeOutputAsPredictions(self, df):
    """Swap the intermediate output column for decoded top-K predictions.

    Every value of the intermediate output column is run through
    ``decode_predictions`` and written to the output column as an array of
    ``(class, description, probability)`` structs. The intermediate column
    is dropped from the returned DataFrame.
    """
    # NOTE: should weights other than imagenet ever be supported, this logic
    # has to move to the individual model building in NamedImageTransformer.
    # Doing the computation in the main computation graph or in a scala UDF
    # could also give better performance.
    top_k = self.getOrDefault(self.topK)

    def decode(predictions):
        batch = np.expand_dims(np.array(predictions), axis=0)
        # .item() converts the numpy dtype to a python native type
        return [
            (entry[0], entry[1], entry[2].item())
            for entry in decode_predictions(batch, top=top_k)[0]
        ]

    prediction_struct = StructType([
        StructField("class", StringType(), False),
        StructField("description", StringType(), False),
        StructField("probability", FloatType(), False),
    ])
    decode_udf = udf(decode, ArrayType(prediction_struct))
    interim_col = self._getIntermediateOutputCol()
    decoded_df = df.withColumn(self.getOutputCol(), decode_udf(df[interim_col]))
    return decoded_df.drop(interim_col)
Example #2
Source File: transform.py    From search-MjoLniR with MIT License 6 votes vote down vote up
def _simplify_data_type(data_type: T.DataType) -> Tuple:
    """Reduce a data type to a nested tuple of the details compared for equality.

    Nullability is intentionally left out because hive is not able to
    represent not null in its schemas.
    """
    # UDTs are normalized into their sql form so that schemas coming from
    # hive and from spark can be compared; types without sqlType() are
    # used unchanged.
    try:
        resolved = data_type.sqlType()  # type: ignore
    except AttributeError:
        resolved = data_type

    if isinstance(resolved, T.StructType):
        simplified_fields = [
            (field.name, _simplify_data_type(field.dataType))
            for field in resolved
        ]
        return ('StructType', simplified_fields)
    if isinstance(resolved, T.ArrayType):
        return ('ArrayType', _simplify_data_type(resolved.elementType))
    return (type(resolved).__name__,)
Example #3
Source File: spark_dataset_converter.py    From petastorm with Apache License 2.0 6 votes vote down vote up
def _convert_precision(df, dtype):
    if dtype is None:
        return df

    if dtype != "float32" and dtype != "float64":
        raise ValueError("dtype {} is not supported. \
            Use 'float32' or float64".format(dtype))

    source_type, target_type = (DoubleType, FloatType) \
        if dtype == "float32" else (FloatType, DoubleType)

    logger.warning("Converting floating-point columns to %s", dtype)

    for field in df.schema:
        col_name = field.name
        if isinstance(field.dataType, source_type):
            df = df.withColumn(col_name, df[col_name].cast(target_type()))
        elif isinstance(field.dataType, ArrayType) and \
                isinstance(field.dataType.elementType, source_type):
            df = df.withColumn(col_name, df[col_name].cast(ArrayType(target_type())))
    return df 
Example #4
Source File: dataset_utils.py    From mmtf-pyspark with Apache License 2.0 6 votes vote down vote up
def flatten_dataset(dataset: DataFrame):
    """Explode every array-typed column of ``dataset``.

    Each column whose data type is ``ArrayType`` is replaced, under the same
    name, by the result of ``explode`` on that column. The name and type of
    every exploded column are printed.

    :param dataset: DataFrame to flatten.
    :return: the DataFrame with all array columns exploded.
    """
    tmp = dataset
    for field in tmp.schema.fields:
        if isinstance(field.dataType, ArrayType):
            print(field.name, field.dataType)
            # Look the column up by the field's name; ``tmp.field`` would
            # refer to a column literally called "field".
            tmp = tmp.withColumn(field.name, explode(tmp[field.name]))

    return tmp
Example #5
Source File: df_naive.py    From example_dataproc_twitter with MIT License 6 votes vote down vote up
def register_udfs(self, sess, sc):
    """Register the UDFs that the SQL queries rely on.

    :type sess: `pyspark.sql.SparkSession`
    :param sess: Session on which ``SQUARED`` and ``INTERSECTIONS`` are
                 registered.

    :type sc: `pyspark.SparkContext`
    :param sc: Spark Context to run Spark jobs; not referenced by this
               method.
    """
    squared_type = stypes.ArrayType(stypes.StructType(fields=[
        stypes.StructField('sku0', stypes.StringType()),
        stypes.StructField('norm', stypes.FloatType()),
    ]))
    sess.udf.register("SQUARED", self.squared, returnType=squared_type)

    intersections_type = stypes.ArrayType(stypes.StructType(fields=[
        stypes.StructField('sku0', stypes.StringType()),
        stypes.StructField('sku1', stypes.StringType()),
        stypes.StructField('cor', stypes.FloatType()),
    ]))
    sess.udf.register(
        'INTERSECTIONS', self.process_intersections,
        returnType=intersections_type)
Example #6
Source File: df_naive.py    From example_dataproc_twitter with MIT License 6 votes vote down vote up
def register_udfs(self, sess, sc):
    """Register the UDFs that the SQL queries rely on.

    :type sess: `pyspark.sql.SparkSession`
    :param sess: Session on which ``SQUARED`` and ``INTERSECTIONS`` are
                 registered.

    :type sc: `pyspark.SparkContext`
    :param sc: Spark Context to run Spark jobs; not referenced by this
               method.
    """
    squared_type = stypes.ArrayType(stypes.StructType(fields=[
        stypes.StructField('sku0', stypes.StringType()),
        stypes.StructField('norm', stypes.FloatType()),
    ]))
    sess.udf.register("SQUARED", self.squared, returnType=squared_type)

    intersections_type = stypes.ArrayType(stypes.StructType(fields=[
        stypes.StructField('sku0', stypes.StringType()),
        stypes.StructField('sku1', stypes.StringType()),
        stypes.StructField('cor', stypes.FloatType()),
    ]))
    sess.udf.register(
        'INTERSECTIONS', self.process_intersections,
        returnType=intersections_type)
Example #7
Source File: dfutil.py    From TensorFlowOnSpark with Apache License 2.0 5 votes vote down vote up
def infer_schema(example, binary_features=[]):
  """Infer the Spark DataFrame schema (StructFields) of a tf.train.Example.

  Note: TensorFlow stores both strings and binary data as tf.train.BytesList, so the caller has to
  provide a "hint" through ``binary_features`` to tell the Spark types StringType and BinaryType apart.

  Args:
    :example: a tf.train.Example
    :binary_features: a list of tf.train.Example features which are expected to be binary/bytearrays.

  Returns:
    A DataFrame StructType schema with one nullable field per feature, ordered by feature name.
  """
  def _infer_sql_type(k, v):
    # features hinted as binary are always BinaryType
    if k in binary_features:
      return BinaryType()

    if v.int64_list.value:
      values, base_type = v.int64_list.value, LongType()
    elif v.float_list.value:
      values, base_type = v.float_list.value, DoubleType()
    else:
      values, base_type = v.bytes_list.value, StringType()

    # multi-item tensors become an ArrayType() of the base type; everything else
    # stays a base type (which makes empty tensors StringType())
    return ArrayType(base_type) if len(values) > 1 else base_type

  fields = [
    StructField(name, _infer_sql_type(name, feature), True)
    for name, feature in sorted(example.features.feature.items())
  ]
  return StructType(fields)
Example #8
Source File: test_spark.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_spark_udf_autofills_column_names_with_schema(spark):
    """Check how ``spark_udf`` maps positional columns onto the model signature.

    The model echoes the column names it receives. Passing fewer columns
    than the three declared in the signature must raise ``Py4JJavaError``;
    passing exactly three or one extra must present the model with
    ``a``, ``b`` and ``c``.
    """
    class TestModel(PythonModel):
        def predict(self, context, model_input):
            return [model_input.columns] * len(model_input)

    input_schema = Schema([
        ColSpec("long", "a"),
        ColSpec("long", "b"),
        ColSpec("long", "c"),
    ])
    signature = ModelSignature(inputs=input_schema, outputs=Schema([ColSpec("integer")]))

    with mlflow.start_run() as run:
        mlflow.pyfunc.log_model("model", python_model=TestModel(), signature=signature)
        model_uri = "runs:/{}/model".format(run.info.run_id)
        udf = mlflow.pyfunc.spark_udf(spark, model_uri, result_type=ArrayType(StringType()))

        pandas_input = pd.DataFrame(
            columns=["a", "b", "c", "d"],
            data={"a": [1], "b": [2], "c": [3], "d": [4]},
        )
        data = spark.createDataFrame(pandas_input)

        # fewer columns than the signature declares
        with pytest.raises(Py4JJavaError):
            data.withColumn("res1", udf("a", "b")).select("res1").toPandas()

        # exactly the declared columns
        exact = data.withColumn("res2", udf("a", "b", "c")).select("res2").toPandas()
        assert exact["res2"][0] == ["a", "b", "c"]

        # one column more than declared
        extra = data.withColumn("res4", udf("a", "b", "c", "d")).select("res4").toPandas()
        assert extra["res4"][0] == ["a", "b", "c"]
Example #9
Source File: test_spark.py    From mlflow with Apache License 2.0 5 votes vote down vote up
def test_spark_udf(spark, model_path):
    """Compare ``spark_udf`` results with direct pyfunc predictions.

    A pyfunc model is saved to ``model_path`` and reloaded. For every entry
    of ``type_map``, both as a scalar and as an array result type, the
    ``prediction`` column produced by the UDF must equal the expected values
    derived from ``predict`` on the pandas frame. For scalar types the UDF
    is additionally built from the type's string name.
    """
    mlflow.pyfunc.save_model(
        path=model_path,
        loader_module=__name__,
        code_path=[os.path.dirname(tests.__file__)],
    )
    reloaded_pyfunc_model = mlflow.pyfunc.load_pyfunc(model_path)

    pandas_df = pd.DataFrame(data=np.ones((10, 10)), columns=[str(i) for i in range(10)])
    spark_df = spark.createDataFrame(pandas_df)

    # Test all supported return types.
    # "long" uses the builtin ``int``: ``np.int`` was only an alias of it and
    # no longer exists in NumPy >= 1.24.
    type_map = {"float": (FloatType(), np.number),
                "int": (IntegerType(), np.int32),
                "double": (DoubleType(), np.number),
                "long": (LongType(), int),
                "string": (StringType(), None)}

    for tname, tdef in type_map.items():
        spark_type, np_type = tdef
        prediction_df = reloaded_pyfunc_model.predict(pandas_df)
        for is_array in [True, False]:
            t = ArrayType(spark_type) if is_array else spark_type
            if tname == "string":
                expected = prediction_df.applymap(str)
            else:
                expected = prediction_df.select_dtypes(np_type)
                if tname == "float":
                    expected = expected.astype(np.float32)

            # array results keep the whole row, scalar results its first value
            expected = [list(row[1]) if is_array else row[1][0] for row in expected.iterrows()]
            pyfunc_udf = spark_udf(spark, model_path, result_type=t)
            new_df = spark_df.withColumn("prediction", pyfunc_udf(*pandas_df.columns))
            actual = list(new_df.select("prediction").toPandas()['prediction'])
            assert expected == actual
            if not is_array:
                # scalar types are also checked via their string name
                pyfunc_udf = spark_udf(spark, model_path, result_type=tname)
                new_df = spark_df.withColumn("prediction", pyfunc_udf(*pandas_df.columns))
                actual = list(new_df.select("prediction").toPandas()['prediction'])
                assert expected == actual
Example #10
Source File: taar_ensemble.py    From python_mozetl with MIT License 5 votes vote down vote up
def get_addons_per_client(users_df, minimum_addons_count):
    """Build a DataFrame with one row per client and its active add-on GUIDs.

    Only add-ons accepted by the validity filter are listed, ordered by
    install day. Clients are kept only when they have more than
    ``minimum_addons_count`` such add-ons.
    """

    def is_valid_addon(addon):
        return (
            not addon.is_system
            and not addon.app_disabled
            and addon.type == "extension"
            and not addon.user_disabled
            and not addon.foreign_install
            and addon.install_day is not None
        )

    # may need additional whitelisting to remove shield addons

    def get_valid_addon_ids(addons):
        dated_ids = [(a.addon_id, a.install_day) for a in addons if is_valid_addon(a)]
        dated_ids.sort(key=lambda pair: pair[1])
        return [pair[0] for pair in dated_ids]

    get_valid_addon_ids_udf = udf(get_valid_addon_ids, ArrayType(StringType()))

    # Un-nest the add-on map of each user into a list of add-on GUIDs,
    # filtering undesired add-ons along the way.
    addon_ids_col = get_valid_addon_ids_udf("active_addons").alias("addon_ids")
    per_client = users_df.select("client_id", addon_ids_col)
    return per_client.filter(size("addon_ids") > minimum_addons_count)
Example #11
Source File: test_base.py    From example_dataproc_twitter with MIT License 5 votes vote down vote up
def test_load_neighbor_schema(self):
    """The neighbor schema pairs an item with an array of (item, similarity)."""
    klass = self.get_target_klass()()
    result = klass.load_neighbor_schema()

    neighbor = stypes.StructType(fields=[
        stypes.StructField("item", stypes.StringType()),
        stypes.StructField("similarity", stypes.FloatType()),
    ])
    expected = stypes.StructType(fields=[
        stypes.StructField("item", stypes.StringType()),
        stypes.StructField("similarity_items", stypes.ArrayType(neighbor)),
    ])
    self.assertEqual(expected, result)
Example #12
Source File: test_base.py    From example_dataproc_twitter with MIT License 5 votes vote down vote up
def test_load_users_schema(self):
    """The users schema pairs a user with an array of (item, score)."""
    klass = self.get_target_klass()()

    interaction = stypes.StructType(fields=[
        stypes.StructField('item', stypes.StringType()),
        stypes.StructField('score', stypes.FloatType()),
    ])
    expected = stypes.StructType(fields=[
        stypes.StructField("user", stypes.StringType()),
        stypes.StructField('interactions', stypes.ArrayType(interaction)),
    ])

    result = klass.load_users_schema()
    self.assertEqual(result, expected)
Example #13
Source File: base.py    From example_dataproc_twitter with MIT License 5 votes vote down vote up
def load_users_schema():
    """Build the schema for data shaped [user, [(item, score), (item, score)]].

    :rtype: `pyspark.sql.types.StructType`
    :returns: schema specification with a string ``user`` field and an
              ``interactions`` array of (``item``, ``score``) structs.
    """
    interaction = stypes.StructType(fields=[
        stypes.StructField('item', stypes.StringType()),
        stypes.StructField('score', stypes.FloatType()),
    ])
    return stypes.StructType(fields=[
        stypes.StructField("user", stypes.StringType()),
        stypes.StructField('interactions', stypes.ArrayType(interaction)),
    ])
Example #14
Source File: base.py    From example_dataproc_twitter with MIT License 5 votes vote down vote up
def load_neighbor_schema(self):
    """Build the neighborhood schema used for the similarity matrix.

    :rtype: `pyspark.sql.types.StructType`
    :returns: schema with a string ``item`` field and a ``similarity_items``
              array of (``item``, ``similarity``) structs.
    """
    neighbor = stypes.StructType(fields=[
        stypes.StructField("item", stypes.StringType()),
        stypes.StructField("similarity", stypes.FloatType()),
    ])
    return stypes.StructType(fields=[
        stypes.StructField("item", stypes.StringType()),
        stypes.StructField("similarity_items", stypes.ArrayType(neighbor)),
    ])
Example #15
Source File: base.py    From example_dataproc_twitter with MIT License 5 votes vote down vote up
def load_users_schema():
    """Build the schema for data shaped [user, [(item, score), (item, score)]].

    :rtype: `pyspark.sql.types.StructType`
    :returns: schema specification with a string ``user`` field and an
              ``interactions`` array of (``item``, ``score``) structs.
    """
    interaction = stypes.StructType(fields=[
        stypes.StructField('item', stypes.StringType()),
        stypes.StructField('score', stypes.FloatType()),
    ])
    return stypes.StructType(fields=[
        stypes.StructField("user", stypes.StringType()),
        stypes.StructField('interactions', stypes.ArrayType(interaction)),
    ])
Example #16
Source File: taar_ensemble.py    From telemetry-airflow with Mozilla Public License 2.0 5 votes vote down vote up
def get_addons_per_client(users_df, minimum_addons_count):
    """Build a DataFrame with one row per client and its active add-on GUIDs.

    Only add-ons accepted by the validity filter are listed, ordered by
    install day. Clients are kept only when they have more than
    ``minimum_addons_count`` such add-ons.
    """

    def is_valid_addon(addon):
        return (
            not addon.is_system
            and not addon.app_disabled
            and addon.type == "extension"
            and not addon.user_disabled
            and not addon.foreign_install
            and addon.install_day is not None
        )

    # may need additional whitelisting to remove shield addons

    def get_valid_addon_ids(addons):
        dated_ids = [(a.addon_id, a.install_day) for a in addons if is_valid_addon(a)]
        dated_ids.sort(key=lambda pair: pair[1])
        return [pair[0] for pair in dated_ids]

    get_valid_addon_ids_udf = udf(get_valid_addon_ids, ArrayType(StringType()))

    # Un-nest the add-on map of each user into a list of add-on GUIDs,
    # filtering undesired add-ons along the way.
    addon_ids_col = get_valid_addon_ids_udf("active_addons").alias("addon_ids")
    per_client = users_df.select("client_id", addon_ids_col)
    return per_client.filter(size("addon_ids") > minimum_addons_count)
Example #17
Source File: strings.py    From koalas with Apache License 2.0 5 votes vote down vote up
def len(self) -> "ks.Series":
    """
    Compute the length of each element in the Series.

    The element may be a sequence (such as a string, tuple or list).
    Array and map typed data is measured with ``F.size``, anything else
    with ``F.length``; the result is cast to long.

    Returns
    -------
    Series of int
        A Series of integer values indicating the length of each element in
        the Series.

    Examples
    --------
    Returns the length (number of characters) in a string. Returns the
    number of entries for lists or tuples.

    >>> s1 = ks.Series(['dog', 'monkey'])
    >>> s1.str.len()
    0    3
    1    6
    Name: 0, dtype: int64

    >>> s2 = ks.Series([["a", "b", "c"], []])
    >>> s2.str.len()
    0    3
    1    0
    Name: 0, dtype: int64
    """
    if isinstance(self._data.spark.data_type, (ArrayType, MapType)):
        measure = F.size
    else:
        measure = F.length
    lengths = column_op(lambda c: measure(c).cast(LongType()))(self._data)
    return lengths.alias(self._data.name)
Example #18
Source File: strings.py    From koalas with Apache License 2.0 5 votes vote down vote up
def __init__(self, series: "ks.Series"):
    """Wrap ``series``, rejecting data that is not string, binary or array typed."""
    supported_types = (StringType, BinaryType, ArrayType)
    if not isinstance(series.spark.data_type, supported_types):
        message = "Cannot call StringMethods on type {}".format(series.spark.data_type)
        raise ValueError(message)
    self._data = series
    self.name = self._data.name

    # Methods 
Example #19
Source File: typehints.py    From koalas with Apache License 2.0 5 votes vote down vote up
def as_spark_type(tpe) -> types.DataType:
    """
    Given a python type, returns the equivalent spark type.
    Accepts:
    - the built-in types in python
    - the built-in types in numpy
    - list of pairs of (field_name, type)
    - dictionaries of field_name -> type
    - python3's typing system

    Raises ``TypeError`` when ``tpe`` matches none of the branches below.
    """
    # NOTE: ``np.int``, ``np.float`` and ``np.bool`` are deliberately not
    # listed. They were plain aliases of the builtins ``int``, ``float`` and
    # ``bool`` (already present in the tuples) and referencing them raises
    # AttributeError on NumPy >= 1.24.
    if tpe in (str, "str", "string"):
        return types.StringType()
    elif tpe in (bytes,):
        return types.BinaryType()
    elif tpe in (np.int8, "int8", "byte"):
        return types.ByteType()
    elif tpe in (np.int16, "int16", "short"):
        return types.ShortType()
    elif tpe in (int, "int", np.int32):
        return types.IntegerType()
    elif tpe in (np.int64, "int64", "long", "bigint"):
        return types.LongType()
    elif tpe in (float, "float"):
        return types.FloatType()
    elif tpe in (np.float64, "float64", "double"):
        return types.DoubleType()
    elif tpe in (datetime.datetime, np.datetime64):
        return types.TimestampType()
    elif tpe in (datetime.date,):
        return types.DateType()
    elif tpe in (bool, "boolean", "bool"):
        return types.BooleanType()
    elif tpe in (np.ndarray,):
        # TODO: support other child types
        return types.ArrayType(types.StringType())
    else:
        raise TypeError("Type %s was not understood." % tpe)
Example #20
Source File: datatypes.py    From ibis with Apache License 2.0 5 votes vote down vote up
def ibis_array_dtype_to_spark_dtype(ibis_dtype_obj):
    """Convert an ibis array dtype into a Spark ``ArrayType``.

    The element type is the Spark equivalent of the ibis value type, and the
    array may contain nulls exactly when that value type is nullable.
    """
    value_type = ibis_dtype_obj.value_type
    return pt.ArrayType(spark_dtype(value_type), value_type.nullable)
Example #21
Source File: es_hits.py    From search-MjoLniR with MIT License 5 votes vote down vote up
def transform(df, url_list=None, brokers=None, **kwargs):
    """Attach a ``hit_page_ids`` column sourced from kafka or elasticsearch.

    When ``brokers`` is given the rows come from ``transform_from_kafka``,
    otherwise from ``transform_from_elasticsearch`` using ``url_list``.
    Extra keyword arguments are forwarded to whichever one is used.

    Returns a DataFrame with the ``wikiid``, ``query`` and ``norm_query``
    fields of ``df`` plus a non-nullable integer array ``hit_page_ids``.
    Raises ``ValueError`` when both ``brokers`` and ``url_list`` are given.
    """
    if brokers and url_list:
        raise ValueError('cannot specify brokers and url_list')

    rdd = (
        transform_from_kafka(df, brokers, **kwargs)
        if brokers
        else transform_from_elasticsearch(df, url_list, **kwargs)
    )

    hit_page_ids = T.StructField(
        'hit_page_ids', T.ArrayType(T.IntegerType()), nullable=False)
    schema = T.StructType([
        df.schema['wikiid'],
        df.schema['query'],
        df.schema['norm_query'],
        hit_page_ids,
    ])
    return df.sql_ctx.createDataFrame(rdd, schema)
Example #22
Source File: sequenceNgrammer.py    From mmtf-pyspark with Apache License 2.0 4 votes vote down vote up
def ngram(data, n, outputCol):
    '''Appends a column of overlapping n-grams built from the "sequence"
    column (e.g., a one-letter protein sequence).

    The n-grams are produced by a UDF registered as "ngrammer" and the
    dataset is queried through a temporary view named "table".

    Examples
    --------
    2-gram: IDCGH ... => [ID, DC, CG, GH, ...]

    Parameters
    ----------
    data : dataset
       input dataset with column "sequence"
    n : int
       size of the n-gram
    outputCol : str
       name of the output column

    Returns
    -------
    dataset
        output dataset with appended ngram column
    '''

    session = SparkSession.builder.getOrCreate()

    # Encoder function to be passed as User Defined Function (UDF)
    def _ngrammer(s):
        if len(s) < 1:
            return []
        # one window per start position; empty when s is shorter than n
        return [s[start: start + n] for start in range(len(s) - n + 1)]

    session.udf.register("ngrammer", _ngrammer, types.ArrayType(types.StringType()))

    data.createOrReplaceTempView("table")
    return session.sql(f"SELECT *, ngrammer(sequence) AS {outputCol} from table")
Example #23
Source File: classification.py    From LearningApacheSpark with MIT License 4 votes vote down vote up
def _transform(self, dataset):
        """Predict with every model in ``self.models`` and keep the best index.

        An accumulator array column is added to ``dataset`` and, for each
        model in turn, extended with element ``[1]`` of that model's raw
        prediction. The prediction column is then set to the index (as a
        float) of the largest accumulated value, and the accumulator column
        is dropped from the returned dataset.
        """
        # determine the input columns: these need to be passed through
        origCols = dataset.columns

        # add an accumulator column to store predictions of all the models
        # (the uuid suffix keeps the name from clashing with input columns)
        accColName = "mbc$acc" + str(uuid.uuid4())
        initUDF = udf(lambda _: [], ArrayType(DoubleType()))
        newDataset = dataset.withColumn(accColName, initUDF(dataset[origCols[0]]))

        # persist if underlying dataset is not persistent.
        handlePersistence = dataset.storageLevel == StorageLevel(False, False, False, False)
        if handlePersistence:
            newDataset.persist(StorageLevel.MEMORY_AND_DISK)

        # update the accumulator column with the result of prediction of models
        aggregatedDataset = newDataset
        for index, model in enumerate(self.models):
            rawPredictionCol = model._call_java("getRawPredictionCol")
            columns = origCols + [rawPredictionCol, accColName]

            # add temporary column to store intermediate scores and update
            tmpColName = "mbc$tmp" + str(uuid.uuid4())
            # appends element [1] of this model's raw prediction to the scores so far
            updateUDF = udf(
                lambda predictions, prediction: predictions + [prediction.tolist()[1]],
                ArrayType(DoubleType()))
            transformedDataset = model.transform(aggregatedDataset).select(*columns)
            updatedDataset = transformedDataset.withColumn(
                tmpColName,
                updateUDF(transformedDataset[accColName], transformedDataset[rawPredictionCol]))
            newColumns = origCols + [tmpColName]

            # switch out the intermediate column with the accumulator column
            aggregatedDataset = updatedDataset\
                .select(*newColumns).withColumnRenamed(tmpColName, accColName)

        # only release what this method persisted itself
        if handlePersistence:
            newDataset.unpersist()

        # output the index of the classifier with highest confidence as prediction
        labelUDF = udf(
            lambda predictions: float(max(enumerate(predictions), key=operator.itemgetter(1))[0]),
            DoubleType())

        # output label and label metadata as prediction
        return aggregatedDataset.withColumn(
            self.getPredictionCol(), labelUDF(aggregatedDataset[accColName])).drop(accColName) 
Example #24
Source File: test_sync_bookmark.py    From python_mozetl with MIT License 4 votes vote down vote up
def sync_summary_schema():
    """Generate a schema for sync_summary.

    This subset contains enough structure for testing bookmark validation.
    The schema is derived from [`telemetry-batch-view`][1].

    [1]: https://git.io/vdQ5A
    """

    def required(name, data_type):
        # field that may not be null
        return StructField(name, data_type, False)

    def optional(name, data_type):
        # field that may be null
        return StructField(name, data_type, True)

    failure_type = StructType([required("name", StringType())])
    status_type = StructType([optional("sync", StringType())])

    problem_type = StructType([
        required("name", StringType()),
        required("count", LongType()),
    ])
    validation_type = StructType([
        required("version", LongType()),
        required("checked", LongType()),
        required("took", LongType()),
        optional("problems", ArrayType(problem_type, False)),
    ])
    engine_type = StructType([
        required("name", StringType()),
        required("status", StringType()),
        optional("failure_reason", failure_type),
        optional("validation", validation_type),
    ])

    return StructType([
        optional("app_build_id", StringType()),
        optional("app_version", StringType()),
        optional("app_display_version", StringType()),
        optional("app_name", StringType()),
        optional("app_channel", StringType()),
        required("uid", StringType()),
        optional("device_id", StringType()),
        required("when", LongType()),
        optional("failure_reason", failure_type),
        required("status", status_type),
        optional("engines", ArrayType(engine_type, False)),
        required("submission_date_s3", StringType()),
    ])
Example #25
Source File: sample_scaffolds.py    From reinvent-scaffold-decorator with MIT License 4 votes vote down vote up
def run(self, initial_scaffolds):
        """Iteratively randomize, sample and merge scaffolds until none remain.

        Starting from the results initialized from ``initial_scaffolds``,
        each iteration randomizes the current scaffolds, samples them to
        disk, joins the sampled results, and carries forward only the joined
        rows whose ``smiles`` is not yet in the results and still contains
        ``*``. The loop ends when no scaffolds are left for the next
        iteration.

        :param initial_scaffolds: passed to ``self._initialize_results`` to
            build the initial results DataFrame.
        :return: the results DataFrame, grouped by ``smiles`` with the first
            ``scaffold`` and ``decorations`` and the summed ``count``.
        """
        randomized_scaffold_udf = psf.udf(self._generate_func, pst.ArrayType(pst.StringType()))
        get_attachment_points_udf = psf.udf(usc.get_attachment_points, pst.ArrayType(pst.IntegerType()))
        remove_attachment_point_numbers_udf = psf.udf(usc.remove_attachment_point_numbers, pst.StringType())

        results_df = self._initialize_results(initial_scaffolds)
        scaffolds_df = results_df.select("smiles", "scaffold", "decorations")
        i = 0
        while scaffolds_df.count() > 0:
            # generate randomized SMILES
            self._log("info", "Starting iteration #%d.", i)
            # one row per randomized scaffold, each tagged with a generated id
            scaffolds_df = scaffolds_df.withColumn("randomized_scaffold", randomized_scaffold_udf("smiles"))\
                .select(
                    "smiles", "scaffold", "decorations",
                    psf.explode("randomized_scaffold").alias("randomized_scaffold"))\
                .withColumn("attachment_points", get_attachment_points_udf("randomized_scaffold"))\
                .withColumn("randomized_scaffold", remove_attachment_point_numbers_udf("randomized_scaffold"))\
                .withColumn("id", psf.monotonically_increasing_id())\
                .persist()
            self._log("info", "Generated %d randomized SMILES from %d scaffolds.",
                      scaffolds_df.count(), scaffolds_df.select("smiles").distinct().count())

            # sample each randomized scaffold N times
            scaffolds = scaffolds_df.select("id", "randomized_scaffold")\
                .rdd.map(lambda row: (row["id"], row["randomized_scaffold"])).toLocalIterator()
            self._sample_and_write_scaffolds_to_disk(scaffolds, scaffolds_df.count())
            self._log("info", "Sampled %d scaffolds.", scaffolds_df.count())

            # merge decorated molecules
            joined_df = self._join_results(scaffolds_df).persist()

            if joined_df.count() > 0:
                self._log("info", "Joined %d -> %d (valid) -> %d unique sampled scaffolds",
                          scaffolds_df.count(), joined_df.agg(psf.sum("count")).head()[0], joined_df.count())

            # next round: joined rows not already in the results whose smiles still contains '*'
            scaffolds_df = joined_df.join(results_df, on="smiles", how="left_anti")\
                .select("smiles", "scaffold", "decorations")\
                .where("smiles LIKE '%*%'")
            self._log("info", "Obtained %d scaffolds for next iteration.", scaffolds_df.count())

            # fold this round into the results, one row per smiles
            results_df = results_df.union(joined_df)\
                .groupBy("smiles")\
                .agg(
                    psf.first("scaffold").alias("scaffold"),
                    psf.first("decorations").alias("decorations"),
                    psf.sum("count").alias("count"))\
                .persist()
            i += 1

        return results_df 
Example #26
Source File: norm_query_clustering.py    From search-MjoLniR with MIT License 4 votes vote down vote up
def cluster_within_norm_query_groups(df: DataFrame) -> DataFrame:
    """Assign each query a group id within its (wikiid, norm_query) bucket.

    The ``query`` / ``hit_page_ids`` pairs of every bucket are collected and
    handed to ``_make_query_groups`` as a UDF; its output is exploded into
    one row per query with columns ``wikiid``, ``norm_query``, ``query`` and
    ``norm_query_group_id``.
    """
    group_type = T.StructType([
        T.StructField('query', T.StringType(), nullable=False),
        T.StructField('norm_query_group_id', T.IntegerType(), nullable=False),
    ])
    make_groups = F.udf(_make_query_groups, T.ArrayType(group_type))

    collected = df.groupBy('wikiid', 'norm_query').agg(
        F.collect_list(F.struct('query', 'hit_page_ids')).alias('source'))
    exploded = collected.select(
        'wikiid', 'norm_query',
        F.explode(make_groups('source')).alias('group'))
    return exploded.select(
        'wikiid', 'norm_query', 'group.query', 'group.norm_query_group_id')