Python pyspark.sql Examples

The following are 30 code examples of the pyspark.sql module, collected from open source projects. The source file and project for each snippet are noted above the example. You may also want to browse the other available functions and classes of the pyspark module.
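
Before the project excerpts, here is a minimal, self-contained sketch (not taken from any of the projects below; the data and application name are made up) of the typical pyspark.sql workflow: start a SparkSession, register a DataFrame as a temporary view, and query it with spark.sql().

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("pyspark-sql-demo").getOrCreate()

# Build a small DataFrame and expose it to SQL as a temporary view.
df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "label"])
df.createOrReplaceTempView("items")

# spark.sql() returns a new DataFrame.
result = spark.sql("SELECT id, label FROM items WHERE id > 1")
result.show()

spark.stop()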
Example #1
Source File: series.py    From koalas with Apache License 2.0
def _cumprod(self, skipna, part_cols=()):
        from pyspark.sql import functions as F
        from pyspark.sql.functions import pandas_udf

        def cumprod(scol):
            @pandas_udf(returnType=self.spark.data_type)
            def negative_check(s):
                assert len(s) == 0 or ((s > 0) | (s.isnull())).all(), (
                    "values should be bigger than 0: %s" % s
                )
                return s

            return F.sum(F.log(negative_check(scol)))

        kser = self._cum(cumprod, skipna, part_cols)
        return kser._with_new_scol(F.exp(kser.spark.column)).rename(self.name)

    # ----------------------------------------------------------------------
    # Accessor Methods
    # ---------------------------------------------------------------------- 
Example #2
Source File: test_ExtractCCLinks.py    From cccatalog with MIT License
def summarizeOutput(self):
        s   = SQLContext(self.sc)
        res = s.read.parquet(self.cclinks.output)

        totalLinks  = res.count()
        uniqueContentQuery = res.drop_duplicates(subset=['provider_domain', 'content_path', 'content_query_string']).count()
        uniqueContent = res.drop_duplicates(subset=['provider_domain', 'content_path']).count()


        res.registerTempTable('test_deeds')
        summary = s.sql('SELECT provider_domain, count(*) AS total, count(distinct content_path) AS unique_content_path, count(distinct content_query_string) AS unique_query_string FROM test_deeds GROUP BY provider_domain ORDER BY total DESC LIMIT 100')
        summary.write.mode('overwrite').format('csv').option('header', 'true').save(self.cclinks.output.replace('parquet', 'summary'))

        fh = open('{}/total'.format(self.cclinks.output.replace('parquet', 'summary')), 'w')
        fh.write('Total records: {}\r\n'.format(totalLinks))
        fh.write('Total unique content path: {}\r\n'.format(uniqueContent))
        fh.write('Total unique query strings: {}\r\n'.format(uniqueContentQuery))
        fh.close() 
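
The same per-domain summary could also be written with the DataFrame API instead of a SQL string; a rough sketch, assuming res is the DataFrame read from parquet above:

from pyspark.sql import functions as F

# Equivalent of the GROUP BY / count(distinct ...) query above.
summary_df = (
    res.groupBy('provider_domain')
       .agg(
           F.count('*').alias('total'),
           F.countDistinct('content_path').alias('unique_content_path'),
           F.countDistinct('content_query_string').alias('unique_query_string'))
       .orderBy(F.desc('total'))
       .limit(100)
)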
Example #3
Source File: unischema.py    From petastorm with Apache License 2.0
def as_spark_schema(self):
        """Returns an object derived from the unischema as spark schema.

        Example:

        >>> spark.createDataFrame(dataset_rows,
        >>>                       SomeSchema.as_spark_schema())
        """
        # Lazily load pyspark to avoid creating a pyspark dependency on the data reading code path
        # (currently works only with make_batch_reader)
        import pyspark.sql.types as sql_types

        schema_entries = []
        for field in self._fields.values():
            spark_type = _field_spark_dtype(field)
            schema_entries.append(sql_types.StructField(field.name, spark_type, field.nullable))

        return sql_types.StructType(schema_entries) 
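
For reference, a short hand-written sketch of what as_spark_schema() produces: a plain pyspark.sql.types.StructType built from StructFields. The field names and types here are only illustrative.

import pyspark.sql.types as sql_types

schema = sql_types.StructType([
    sql_types.StructField('id', sql_types.IntegerType(), False),
    sql_types.StructField('image', sql_types.BinaryType(), True),
])

# The resulting schema can be passed straight to createDataFrame, as in the
# docstring above: spark.createDataFrame(rows, schema)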
Example #4
Source File: test_spark_dataset_converter.py    From petastorm with Apache License 2.0
def test_atexit(spark_test_ctx):
    lines = """
    from petastorm.spark import SparkDatasetConverter, make_spark_converter
    from pyspark.sql import SparkSession
    import os
    spark = SparkSession.builder.getOrCreate()
    spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, '{temp_url}')
    df = spark.createDataFrame([(1, 2),(4, 5)], ["col1", "col2"])
    converter = make_spark_converter(df)
    f = open(os.path.join('{tempdir}', 'test_atexit.out'), "w")
    f.write(converter.cache_dir_url)
    f.close()
    """.format(tempdir=spark_test_ctx.tempdir, temp_url=spark_test_ctx.temp_url)
    code_str = "; ".join(
        line.strip() for line in lines.strip().splitlines())
    ret_code = subprocess.call([sys.executable, "-c", code_str])
    assert 0 == ret_code
    with open(os.path.join(spark_test_ctx.tempdir, 'test_atexit.out')) as f:
        cache_dir_url = f.read()

    fs = FilesystemResolver(cache_dir_url).filesystem()
    assert not fs.exists(urlparse(cache_dir_url).path) 
Example #5
Source File: spark.py    From mlflow with Apache License 2.0
def _load_pyfunc(path):
    """
    Load PyFunc implementation. Called by ``pyfunc.load_pyfunc``.

    :param path: Local filesystem path to the MLflow Model with the ``spark`` flavor.
    """
    # NOTE: The getOrCreate() call below may change settings of the active session which we do not
    # intend to do here. In particular, setting master to local[1] can break distributed clusters.
    # To avoid this problem, we explicitly check for an active session. This is not ideal but there
    # is no good workaround at the moment.
    import pyspark

    spark = pyspark.sql.SparkSession._instantiatedSession
    if spark is None:
        spark = pyspark.sql.SparkSession.builder.config("spark.python.worker.reuse", True) \
            .master("local[1]").getOrCreate()
    return _PyFuncModelWrapper(spark, _load_model(model_uri=path)) 
Example #6
Source File: groupby.py    From sparklingpandas with Apache License 2.0
def last(self):
        """Pull out the last from each group."""
        myargs = self._myargs
        mykwargs = self._mykwargs
        # If it's possible to use Spark SQL grouping, do it
        if self._can_use_new_school():
            self._prep_spark_sql_groupby()
            import pyspark.sql.functions as func
            return self._use_aggregation(func.last)

        def create_combiner(x):
            return x.groupby(*myargs, **mykwargs).last()

        def merge_value(x, y):
            return create_combiner(y)

        def merge_combiner(x, y):
            return y

        rddOfLast = self._sortIfNeeded(self._distributedRDD.combineByKey(
            create_combiner,
            merge_value,
            merge_combiner)).values()
        return DataFrame.fromDataFrameRDD(rddOfLast, self.sql_ctx) 
Example #7
Source File: testSmvfuncs.py    From SMV with Apache License 2.0
def test_distMetric(self):
        df = self.createDF("s1:String; s2:String",
            ",ads;" +\
            "asdfg,asdfg;" +\
            "asdfghj,asdfhgj"
        )

        trunc = lambda c: pyspark.sql.functions.round(c,2)
        res = df.select(
            df.s1, df.s2,
            trunc(nGram2(df.s1, df.s2)).alias("nGram2"),
            trunc(nGram3(df.s1, df.s2)).alias("nGram3"),
            trunc(diceSorensen(df.s1, df.s2)).alias("diceSorensen"),
            trunc(normlevenshtein(df.s1, df.s2)).alias("normlevenshtein"),
            trunc(jaroWinkler(df.s1, df.s2)).alias("jaroWinkler")
        )

        exp = self.createDF("s1: String;s2: String;nGram2: Float;nGram3: Float;diceSorensen: Float;normlevenshtein: Float;jaroWinkler: Float",
            ",ads,,,,,;" + \
            "asdfg,asdfg,1.0,1.0,1.0,1.0,1.0;" + \
            "asdfghj,asdfhgj,0.5,0.4,0.5,0.71,0.97")

        self.should_be_same(res, exp) 
Example #8
Source File: dataframe.py    From sparklingpandas with Apache License 2.0
def stats(self, columns):
        """Compute the stats for each column provided in columns.
        Parameters
        ----------
        columns : list of str, contains all columns to compute stats on.
        """
        assert (not isinstance(columns, basestring)), "columns should be a " \
                                                      "list of strs,  " \
                                                      "not a str!"
        assert isinstance(columns, list), "columns should be a list!"

        from pyspark.sql import functions as F
        functions = [F.min, F.max, F.avg, F.count]
        aggs = list(
            self._flatmap(lambda column: map(lambda f: f(column), functions),
                          columns))
        return PStats(self.from_schema_rdd(self._schema_rdd.agg(*aggs))) 
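
The same pattern works directly on a plain PySpark DataFrame; a minimal sketch (the DataFrame and column names are assumed, not taken from the project above):

from pyspark.sql import functions as F

def spark_stats(df, columns):
    """Apply min/max/avg/count to every column in `columns` in one pass."""
    functions = [F.min, F.max, F.avg, F.count]
    aggs = [f(column) for column in columns for f in functions]
    return df.agg(*aggs)

# e.g. spark_stats(df, ['price', 'quantity']) yields columns named
# min(price), max(price), avg(price), count(price), min(quantity), ...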
Example #9
Source File: groupby.py    From sparklingpandas with Apache License 2.0
def max(self):
        """Compute the max for each group."""
        if self._can_use_new_school():
            self._prep_spark_sql_groupby()
            import pyspark.sql.functions as func
            return self._use_aggregation(func.max)
        self._prep_pandas_groupby()
        myargs = self._myargs
        mykwargs = self._mykwargs

        def create_combiner(x):
            return x.groupby(*myargs, **mykwargs).max()

        def merge_value(x, y):
            return x.append(create_combiner(y)).max()

        def merge_combiner(x, y):
            return x.append(y).max(level=0)

        rddOfMax = self._sortIfNeeded(self._distributedRDD.combineByKey(
            create_combiner,
            merge_value,
            merge_combiner)).values()
        return DataFrame.fromDataFrameRDD(rddOfMax, self.sql_ctx) 
Example #10
Source File: groupby.py    From sparklingpandas with Apache License 2.0
def min(self):
        """Compute the min for each group."""
        if self._can_use_new_school():
            self._prep_spark_sql_groupby()
            import pyspark.sql.functions as func
            return self._use_aggregation(func.min)
        self._prep_pandas_groupby()
        myargs = self._myargs
        mykwargs = self._mykwargs

        def create_combiner(x):
            return x.groupby(*myargs, **mykwargs).min()

        def merge_value(x, y):
            return x.append(create_combiner(y)).min()

        def merge_combiner(x, y):
            return x.append(y).min(level=0)

        rddOfMin = self._sortIfNeeded(self._distributedRDD.combineByKey(
            create_combiner,
            merge_value,
            merge_combiner)).values()
        return DataFrame.fromDataFrameRDD(rddOfMin, self.sql_ctx) 
Example #11
Source File: groupby.py    From sparklingpandas with Apache License 2.0
def sum(self):
        """Compute the sum for each group."""
        if self._can_use_new_school():
            self._prep_spark_sql_groupby()
            import pyspark.sql.functions as func
            return self._use_aggregation(func.sum)
        self._prep_pandas_groupby()
        myargs = self._myargs
        mykwargs = self._mykwargs

        def create_combiner(x):
            return x.groupby(*myargs, **mykwargs).sum()

        def merge_value(x, y):
            return pd.concat([x, create_combiner(y)])

        def merge_combiner(x, y):
            return x + y

        rddOfSum = self._sortIfNeeded(self._distributedRDD.combineByKey(
            create_combiner,
            merge_value,
            merge_combiner)).values()
        return DataFrame.fromDataFrameRDD(rddOfSum, self.sql_ctx) 
Example #12
Source File: test_spark_session_fixture.py    From pytest-spark with MIT License
def test_spark_session_sql(spark_session):
    test_df = spark_session.createDataFrame([[1, 3], [2, 4]], "a: int, b: int")
    test_df.registerTempTable('test')

    test_filtered_df = spark_session.sql('SELECT a, b from test where a > 1')
    assert test_filtered_df.count() == 1 
Example #13
Source File: test_spark.py    From sk-dist with Apache License 2.0
def test_spark_session_sql(spark_session):
    test_df = spark_session.createDataFrame([[1, 3], [2, 4]], "a: int, b: int")
    test_df.createOrReplaceTempView('test')

    test_filtered_df = spark_session.sql('SELECT a, b from test where a > 1')
    assert test_filtered_df.count() == 1 
Example #14
Source File: groupby.py    From sparklingpandas with Apache License 2.0
def mean(self):
        """Compute mean of groups, excluding missing values.

        For multiple groupings, the result index will be a MultiIndex.
        """
        if self._can_use_new_school():
            self._prep_spark_sql_groupby()
            import pyspark.sql.functions as func
            return self._use_aggregation(func.mean)
        self._prep_pandas_groupby()
        return DataFrame.fromDataFrameRDD(
            self._regroup_mergedRDD().values().map(
                lambda x: x.mean()), self.sql_ctx) 
Example #15
Source File: groupby.py    From sparklingpandas with Apache License 2.0
def _use_aggregation(self, agg, columns=None):
        """Compute the result using the aggregation function provided.
        The aggregation name must also be provided so we can strip off the
        extra name that Spark SQL adds."""
        if not columns:
            columns = self._columns
        from pyspark.sql import functions as F
        aggs = map(lambda column: agg(column).alias(column), columns)
        aggRdd = self._grouped_spark_sql.agg(*aggs)
        df = DataFrame.from_schema_rdd(aggRdd, self._by)
        return df 
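
The .alias(column) call is what keeps the aggregated columns under their original names instead of Spark SQL's generated names such as min(value). A small standalone sketch of the same idea (the sample data is made up):

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [('a', 1, 10.0), ('a', 2, 20.0), ('b', 3, 30.0)],
    ['key', 'value', 'score'])
grouped = df.groupBy('key')

# Without aliases the result columns are named min(value), min(score).
grouped.agg(F.min('value'), F.min('score')).show()

# Aliasing each aggregation keeps the original column names.
grouped.agg(*[F.min(c).alias(c) for c in ['value', 'score']]).show()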
Example #16
Source File: groupby.py    From sparklingpandas with Apache License 2.0
def first(self):
        """
        Pull out the first from each group. Note: this is different than
        Spark's first.
        """
        # If it's possible to use Spark SQL grouping, do it
        if self._can_use_new_school():
            self._prep_spark_sql_groupby()
            import pyspark.sql.functions as func
            return self._use_aggregation(func.first)
        myargs = self._myargs
        mykwargs = self._mykwargs
        self._prep_pandas_groupby()

        def create_combiner(x):
            return x.groupby(*myargs, **mykwargs).first()

        def merge_value(x, y):
            return create_combiner(x)

        def merge_combiner(x, y):
            return x

        rddOfFirst = self._sortIfNeeded(self._distributedRDD.combineByKey(
            create_combiner,
            merge_value,
            merge_combiner)).values()
        return DataFrame.fromDataFrameRDD(rddOfFirst, self.sql_ctx) 
Example #17
Source File: test_spark.py    From sk-dist with Apache License 2.0
def test_spark_session_dataframe(spark_session):
    test_df = spark_session.createDataFrame([[1, 3], [2, 4]], "a: int, b: int")

    assert type(test_df) == pyspark.sql.dataframe.DataFrame
    assert test_df.count() == 2 
Example #18
Source File: test_spark_session_fixture.py    From pytest-spark with MIT License
def test_spark_session_dataframe(spark_session):
    test_df = spark_session.createDataFrame([[1, 3], [2, 4]], "a: int, b: int")

    assert type(test_df) == pyspark.sql.dataframe.DataFrame
    assert test_df.count() == 2 
Example #19
Source File: indexes.py    From koalas with Apache License 2.0
def value_counts(self, normalize=False, sort=True, ascending=False, bins=None, dropna=True):
        if (
            LooseVersion(pyspark.__version__) < LooseVersion("2.4")
            and default_session().conf.get("spark.sql.execution.arrow.enabled") == "true"
            and isinstance(self, MultiIndex)
        ):
            raise RuntimeError(
                "if you're using pyspark < 2.4, set conf "
                "'spark.sql.execution.arrow.enabled' to 'false' "
                "for using this function with MultiIndex"
            )
        return super(MultiIndex, self).value_counts(
            normalize=normalize, sort=sort, ascending=ascending, bins=bins, dropna=dropna
        ) 
Example #20
Source File: accessors.py    From koalas with Apache License 2.0
def schema(self, index_col=None):
        """
        Returns the underlying Spark schema.

        Returns
        -------
        pyspark.sql.types.StructType
            The underlying Spark schema.

        Parameters
        ----------
        index_col: str or list of str, optional, default: None
            Column names to be used in Spark to represent Koalas' index. The index name
            in Koalas is ignored. By default, the index is always lost.

        Examples
        --------
        >>> df = ks.DataFrame({'a': list('abc'),
        ...                    'b': list(range(1, 4)),
        ...                    'c': np.arange(3, 6).astype('i1'),
        ...                    'd': np.arange(4.0, 7.0, dtype='float64'),
        ...                    'e': [True, False, True],
        ...                    'f': pd.date_range('20130101', periods=3)},
        ...                   columns=['a', 'b', 'c', 'd', 'e', 'f'])
        >>> df.spark.schema().simpleString()
        'struct<a:string,b:bigint,c:tinyint,d:double,e:boolean,f:timestamp>'
        >>> df.spark.schema(index_col='index').simpleString()
        'struct<index:bigint,a:string,b:bigint,c:tinyint,d:double,e:boolean,f:timestamp>'
        """
        return self.frame(index_col).schema 
Example #21
Source File: utils.py    From koalas with Apache License 2.0
def default_session(conf=None):
    if conf is None:
        conf = dict()
    should_use_legacy_ipc = False
    if LooseVersion(pyarrow.__version__) >= LooseVersion("0.15") and LooseVersion(
        pyspark.__version__
    ) < LooseVersion("3.0"):
        conf["spark.executorEnv.ARROW_PRE_0_15_IPC_FORMAT"] = "1"
        conf["spark.yarn.appMasterEnv.ARROW_PRE_0_15_IPC_FORMAT"] = "1"
        conf["spark.mesos.driverEnv.ARROW_PRE_0_15_IPC_FORMAT"] = "1"
        conf["spark.kubernetes.driverEnv.ARROW_PRE_0_15_IPC_FORMAT"] = "1"
        should_use_legacy_ipc = True

    builder = spark.SparkSession.builder.appName("Koalas")
    for key, value in conf.items():
        builder = builder.config(key, value)
    # Koalas currently depends on this kind of self-join because of the
    # 'compute.ops_on_diff_frames' configuration. This setting is needed for Spark 3.0+.
    builder.config("spark.sql.analyzer.failAmbiguousSelfJoin", False)
    session = builder.getOrCreate()

    if not should_use_legacy_ipc:
        is_legacy_ipc_set = any(
            v == "1"
            for v in [
                session.conf.get("spark.executorEnv.ARROW_PRE_0_15_IPC_FORMAT", None),
                session.conf.get("spark.yarn.appMasterEnv.ARROW_PRE_0_15_IPC_FORMAT", None),
                session.conf.get("spark.mesos.driverEnv.ARROW_PRE_0_15_IPC_FORMAT", None),
                session.conf.get("spark.kubernetes.driverEnv.ARROW_PRE_0_15_IPC_FORMAT", None),
            ]
        )
        if is_legacy_ipc_set:
            raise RuntimeError(
                "Please explicitly unset 'ARROW_PRE_0_15_IPC_FORMAT' environment variable in "
                "both driver and executor sides. Check your spark.executorEnv.*, "
                "spark.yarn.appMasterEnv.*, spark.mesos.driverEnv.* and "
                "spark.kubernetes.driverEnv.* configurations. It is required to set this "
                "environment variable only when you use pyarrow>=0.15 and pyspark<3.0."
            )
    return session 
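
Because conf is just a dict applied through the session builder, callers can pass extra Spark settings directly; a brief usage sketch (the configuration key shown is only an example):

# Each key/value pair ends up in SparkSession.builder.config(...).
session = default_session({'spark.sql.shuffle.partitions': '8'})
session.range(10).count()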
Example #22
Source File: __init__.py    From koalas with Apache License 2.0
def _auto_patch_spark():
    import os
    import logging

    # Attach a usage logger.
    logger_module = os.getenv("KOALAS_USAGE_LOGGER", None)
    if logger_module is not None:
        try:
            from databricks.koalas import usage_logging

            usage_logging.attach(logger_module)
        except Exception as e:
            logger = logging.getLogger("databricks.koalas.usage_logger")
            logger.warning(
                "Tried to attach usage logger `{}`, but an exception was raised: {}".format(
                    logger_module, str(e)
                )
            )

    # Autopatching is on by default.
    x = os.getenv("SPARK_KOALAS_AUTOPATCH", "true")
    if x.lower() in ("true", "1", "enabled"):
        logger = logging.getLogger("spark")
        logger.info(
            "Patching spark automatically. You can disable it by setting "
            "SPARK_KOALAS_AUTOPATCH=false in your environment"
        )

        from pyspark.sql import dataframe as df

        df.DataFrame.to_koalas = DataFrame.to_koalas 
Example #23
Source File: test_types.py    From dagster with Apache License 2.0
def test_spark_data_frame_serialization_file_system_file_handle(spark_config):
    @solid
    def nonce(_):
        return LocalFileHandle(file_relative_path(__file__, 'data/test.csv'))

    @pipeline(mode_defs=[spark_local_fs_mode])
    def spark_df_test_pipeline():
        ingest_csv_file_handle_to_spark(nonce())

    instance = DagsterInstance.ephemeral()

    result = execute_pipeline(
        spark_df_test_pipeline,
        mode='spark',
        run_config={
            'storage': {'filesystem': {}},
            'resources': {'pyspark': {'config': {'spark_conf': spark_config}}},
        },
        instance=instance,
    )

    intermediate_store = build_fs_intermediate_store(
        instance.intermediates_directory, run_id=result.run_id
    )

    assert result.success
    result_dir = os.path.join(
        intermediate_store.root,
        'intermediates',
        'ingest_csv_file_handle_to_spark.compute',
        'result',
    )

    assert '_SUCCESS' in os.listdir(result_dir)

    spark = SparkSession.builder.getOrCreate()
    df = spark.read.parquet(result_dir)
    assert isinstance(df, pyspark.sql.dataframe.DataFrame)
    assert df.head()[0] == '1' 
Example #24
Source File: testSmvfuncs.py    From SMV with Apache License 2.0
def test_smvCollectSet(self):
        from pyspark.sql.types import StringType
        df = self.createDF("a:String", "a;b;a;c;a")
        res = df.agg(smvArrayCat("_", sort_array(smvCollectSet(df.a, StringType()))).alias("aa"))
        exp = self.createDF("aa: String", "a_b_c")
        self.should_be_same(res, exp) 
Example #25
Source File: testSmvfuncs.py    From SMV with Apache License 2.0
def test_smvCreateLookup(self):
        from pyspark.sql.types import StringType
        df = self.createDF("k:String;v:Integer", "a,1;b,2;,3;c,4")
        map_key = smvCreateLookUp({"a":"AA", "b":"BB"}, "__", StringType())
        res = df.withColumn("mapped", map_key(df.k))
        exp = self.createDF("k: String;v: Integer;mapped: String",
            "a,1,AA;" +
            "b,2,BB;" +
            ",3,__;" +
            "c,4,__")
        self.should_be_same(res, exp) 
Example #26
Source File: launcher.py    From spylon with BSD 3-Clause "New" or "Revised" License
def with_sql_context(application_name, conf=None):
    """Context manager for a spark context

    Returns
    -------
    sc : SparkContext
    sql_context: SQLContext

    Examples
    --------
    Used within a context manager
    >>> with with_sql_context("MyApplication") as (sc, sql_context):
    ...     import pyspark
    ...     # Do stuff
    ...     pass

    """
    if conf is None:
        conf = default_configuration
    assert isinstance(conf, SparkConfiguration)

    sc = conf.spark_context(application_name)
    import pyspark.sql
    try:
        yield sc, pyspark.sql.SQLContext(sc)
    finally:
        sc.stop() 
Example #27
Source File: launcher.py    From spylon with BSD 3-Clause "New" or "Revised" License
def spark_session(self, application_name):
        sc = self.spark_context(application_name)
        from pyspark.sql import SparkSession
        return SparkSession(sc) 
Example #28
Source File: launcher.py    From spylon with BSD 3-Clause "New" or "Revised" License
def spark_context(self, application_name):
        """Create a spark context given the parameters configured in this class.

        The caller is responsible for calling ``.stop`` on the resulting SparkContext

        Parameters
        ----------
        application_name : string

        Returns
        -------
        sc : SparkContext
        """

        # initialize the spark configuration
        self._init_spark()
        import pyspark
        import pyspark.sql

        # initialize conf
        spark_conf = pyspark.SparkConf()
        for k, v in self._spark_conf_helper._conf_dict.items():
            spark_conf.set(k, v)

        log.info("Starting SparkContext")
        return pyspark.SparkContext(appName=application_name, conf=spark_conf) 
Example #29
Source File: unischema.py    From petastorm with Apache License 2.0
def _numpy_to_spark_mapping():
    """Returns a mapping from numpy to pyspark.sql type. Caches the mapping dictionary inorder to avoid instantiation
    of multiple objects in each call."""

    # Refer to the function attribute used to cache the map via a variable holding its name, rather than
    # dotted attribute access, to avoid copy/paste/typo mistakes
    cache_attr_name = 'cached_numpy_to_pyspark_types_map'
    if not hasattr(_numpy_to_spark_mapping, cache_attr_name):
        import pyspark.sql.types as T

        setattr(_numpy_to_spark_mapping, cache_attr_name,
                {
                    np.int8: T.ByteType(),
                    np.uint8: T.ShortType(),
                    np.int16: T.ShortType(),
                    np.uint16: T.IntegerType(),
                    np.int32: T.IntegerType(),
                    np.int64: T.LongType(),
                    np.float32: T.FloatType(),
                    np.float64: T.DoubleType(),
                    np.string_: T.StringType(),
                    np.str_: T.StringType(),
                    np.unicode_: T.StringType(),
                    np.bool_: T.BooleanType(),
                })

    return getattr(_numpy_to_spark_mapping, cache_attr_name)


# TODO: Changing fields in this class or the UnischemaField will break reading due to the schema being pickled next to
# the dataset on disk 
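
A short usage sketch of _numpy_to_spark_mapping() above (illustrative only, assuming the function is in scope): the returned dict maps numpy dtypes to pyspark.sql type instances and is created once per process.

import numpy as np

mapping = _numpy_to_spark_mapping()
print(mapping[np.int32])    # a pyspark.sql.types.IntegerType instance
print(mapping[np.float64])  # a pyspark.sql.types.DoubleType instance

# Subsequent calls return the same cached dict object.
assert _numpy_to_spark_mapping() is mapping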
Example #30
Source File: mleap.py    From mlflow with Apache License 2.0
def add_to_model(mlflow_model, path, spark_model, sample_input):
    """
    Add the MLeap flavor to an existing MLflow model.

    :param mlflow_model: :py:mod:`mlflow.models.Model` to which this flavor is being added.
    :param path: Path of the model to which this flavor is being added.
    :param spark_model: Spark PipelineModel to be saved. This model must be MLeap-compatible and
                        cannot contain any custom transformers.
    :param sample_input: Sample PySpark DataFrame input that the model can evaluate. This is
                         required by MLeap for data schema inference.
    """
    from pyspark.ml.pipeline import PipelineModel
    from pyspark.sql import DataFrame
    import mleap.version
    from mleap.pyspark.spark_support import SimpleSparkSerializer  # pylint: disable=unused-variable
    from py4j.protocol import Py4JError

    if not isinstance(spark_model, PipelineModel):
        raise Exception("Not a PipelineModel."
                        " MLeap can save only PipelineModels.")
    if sample_input is None:
        raise Exception("A sample input must be specified in order to add the MLeap flavor.")
    if not isinstance(sample_input, DataFrame):
        raise Exception("The sample input must be a PySpark dataframe of type `{df_type}`".format(
            df_type=DataFrame.__module__))

    # MLeap's model serialization routine requires an absolute output path
    path = os.path.abspath(path)

    mleap_path_full = os.path.join(path, "mleap")
    mleap_datapath_sub = os.path.join("mleap", "model")
    mleap_datapath_full = os.path.join(path, mleap_datapath_sub)
    if os.path.exists(mleap_path_full):
        raise Exception("MLeap model data path already exists at: {path}".format(
            path=mleap_path_full))
    os.makedirs(mleap_path_full)

    dataset = spark_model.transform(sample_input)
    model_path = "file:{mp}".format(mp=mleap_datapath_full)
    try:
        spark_model.serializeToBundle(path=model_path,
                                      dataset=dataset)
    except Py4JError:
        _handle_py4j_error(
                MLeapSerializationException,
                "MLeap encountered an error while serializing the model. Ensure that the model is"
                " compatible with MLeap (i.e does not contain any custom transformers).")

    try:
        mleap_version = mleap.version.__version__
        _logger.warning(
            "Detected old mleap version %s. Support for logging models in mleap format with "
            "mleap versions 0.15.0 and below is deprecated and will be removed in a future "
            "MLflow release. Please upgrade to a newer mleap version.", mleap_version)
    except AttributeError:
        mleap_version = mleap.version
    mlflow_model.add_flavor(FLAVOR_NAME,
                            mleap_version=mleap_version,
                            model_data=mleap_datapath_sub)