Python pyspark.sql.types.StringType() Examples

The following are 30 code examples showing how to use pyspark.sql.types.StringType(). They are extracted from open source projects; the originating project, author, file, and license are noted above each example.


You may also want to check out all available functions and classes of the pyspark.sql.types module.
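For orientation, here is a minimal, self-contained sketch of the most common pattern in the examples below: using StringType() inside a StructType schema when building a DataFrame. The column names and data are made up for illustration.

from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

spark = SparkSession.builder.getOrCreate()

# Second argument of StructField is the data type, third is nullability.
schema = StructType([
    StructField('name', StringType(), True),
    StructField('age', IntegerType(), True),
])

df = spark.createDataFrame([('alice', 30), ('bob', 25)], schema)
df.printSchema()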

Example 1
Project: Hanhan-Spark-Python   Author: hanhanwu   File: temp_range_sql.py    License: MIT License
def main():
    temp_schema = StructType([
    StructField('StationID', StringType(), False),
    StructField('DateTime', StringType(), False),
    StructField('Observation', StringType(), False),
    StructField('DataValue', DoubleType(), False),
    StructField('MFlag', StringType(), True),
    StructField('QFlag', StringType(), True),
    StructField('SFlag', StringType(), True),
    StructField('OBSTime', StringType(), True),
    ])

    df = sqlContext.read.format('com.databricks.spark.csv').options(header='false').load(inputs1, schema=temp_schema)
    df = df.filter(df.QFlag == '')

    dfrange = get_range(df)
    result = dfrange.rdd.map(lambda r: str(r.DateTime)+' '+str(r.StationID)+' '+str(r.MaxRange))
    outdata = result.sortBy(lambda r: r[0]).coalesce(1)
    outdata.saveAsTextFile(output) 
Example 2
Project: spark-deep-learning   Author: databricks   File: imageIO.py    License: Apache License 2.0
def filesToDF(sc, path, numPartitions=None):
    """
    Read files from a directory to a DataFrame.

    :param sc: SparkContext.
    :param path: str, path to files.
    :param numPartitions: int, number of partitions to use for reading files.
    :return: DataFrame, with columns: (filePath: str, fileData: BinaryType)
    """
    numPartitions = numPartitions or sc.defaultParallelism
    schema = StructType([StructField("filePath", StringType(), False),
                         StructField("fileData", BinaryType(), False)])
    rdd = sc.binaryFiles(
        path, minPartitions=numPartitions).repartition(numPartitions)
    rdd = rdd.map(lambda x: (x[0], bytearray(x[1])))
    return rdd.toDF(schema) 
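A brief usage sketch of the function above (the SparkSession is obtained here for self-containment; the input path and partition count are hypothetical):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext
df = filesToDF(sc, "/tmp/binary-files", numPartitions=4)  # hypothetical path
df.printSchema()  # filePath: string, fileData: binary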
Example 3
Project: spark-deep-learning   Author: databricks   File: named_image.py    License: Apache License 2.0
def _decodeOutputAsPredictions(self, df):
        # If we start having different weights than imagenet, we'll need to
        # move this logic to individual model building in NamedImageTransformer.
        # Also, we could put the computation directly in the main computation
        # graph or use a scala UDF for potentially better performance.
        topK = self.getOrDefault(self.topK)

        def decode(predictions):
            pred_arr = np.expand_dims(np.array(predictions), axis=0)
            decoded = decode_predictions(pred_arr, top=topK)[0]
            # convert numpy dtypes to python native types
            return [(t[0], t[1], t[2].item()) for t in decoded]

        decodedSchema = ArrayType(
            StructType([
                StructField("class", StringType(), False),
                StructField("description", StringType(), False),
                StructField("probability", FloatType(), False)
            ]))
        decodeUDF = udf(decode, decodedSchema)
        interim_output = self._getIntermediateOutputCol()
        return df \
            .withColumn(self.getOutputCol(), decodeUDF(df[interim_output])) \
            .drop(interim_output) 
Example 4
Project: petastorm   Author: uber   File: test_unischema.py    License: Apache License 2.0
def test_as_spark_schema():
    """Try using 'as_spark_schema' function"""
    TestSchema = Unischema('TestSchema', [
        UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False),
        UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
        UnischemaField('string_field_implicit', np.string_, ()),
    ])

    spark_schema = TestSchema.as_spark_schema()
    assert spark_schema.fields[0].name == 'int_field'

    assert spark_schema.fields[1].name == 'string_field'
    assert spark_schema.fields[1].dataType == StringType()

    assert spark_schema.fields[2].name == 'string_field_implicit'
    assert spark_schema.fields[2].dataType == StringType()

    assert TestSchema.fields['int_field'].name == 'int_field'
    assert TestSchema.fields['string_field'].name == 'string_field' 
Example 5
Project: Hanhan-Spark-Python   Author: hanhanwu   File: reddit_average_sql.py    License: MIT License
def main():
    schema = StructType([
    StructField('subreddit', StringType(), False),
    StructField('score', IntegerType(), False),
    ])
    inputs = sqlContext.read.json(inputs1, schema=schema)

    # Uncomment this when schema is not added
    # inputs = sqlContext.read.json(inputs1)

    # Uncomment these when there are 2 input dirs
    # comments_input1 = sqlContext.read.json(inputs1, schema=schema)
    # comments_input2 = sqlContext.read.json(inputs2, schema=schema)
    # inputs = comments_input1.unionAll(comments_input2)

    df = get_avg(inputs)
    df.write.save(output, format='json', mode='overwrite') 
Example 6
Project: LearningApacheSpark   Author: runawayhorse001   File: functions.py    License: MIT License
def to_date(col, format=None):
    """Converts a :class:`Column` of :class:`pyspark.sql.types.StringType` or
    :class:`pyspark.sql.types.TimestampType` into :class:`pyspark.sql.types.DateType`
    using the optionally specified format. Specify formats according to
    `SimpleDateFormats <http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html>`_.
    By default, it follows casting rules to :class:`pyspark.sql.types.DateType` if the format
    is omitted (equivalent to ``col.cast("date")``).

    >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
    >>> df.select(to_date(df.t).alias('date')).collect()
    [Row(date=datetime.date(1997, 2, 28))]

    >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
    >>> df.select(to_date(df.t, 'yyyy-MM-dd HH:mm:ss').alias('date')).collect()
    [Row(date=datetime.date(1997, 2, 28))]
    """
    sc = SparkContext._active_spark_context
    if format is None:
        jc = sc._jvm.functions.to_date(_to_java_column(col))
    else:
        jc = sc._jvm.functions.to_date(_to_java_column(col), format)
    return Column(jc) 
Example 7
Project: LearningApacheSpark   Author: runawayhorse001   File: functions.py    License: MIT License
def to_timestamp(col, format=None):
    """Converts a :class:`Column` of :class:`pyspark.sql.types.StringType` or
    :class:`pyspark.sql.types.TimestampType` into :class:`pyspark.sql.types.TimestampType`
    using the optionally specified format. Specify formats according to
    `SimpleDateFormats <http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html>`_.
    By default, it follows casting rules to :class:`pyspark.sql.types.TimestampType` if the format
    is omitted (equivalent to ``col.cast("timestamp")``).

    >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
    >>> df.select(to_timestamp(df.t).alias('dt')).collect()
    [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]

    >>> df = spark.createDataFrame([('1997-02-28 10:30:00',)], ['t'])
    >>> df.select(to_timestamp(df.t, 'yyyy-MM-dd HH:mm:ss').alias('dt')).collect()
    [Row(dt=datetime.datetime(1997, 2, 28, 10, 30))]
    """
    sc = SparkContext._active_spark_context
    if format is None:
        jc = sc._jvm.functions.to_timestamp(_to_java_column(col))
    else:
        jc = sc._jvm.functions.to_timestamp(_to_java_column(col), format)
    return Column(jc) 
Example 8
Project: LearningApacheSpark   Author: runawayhorse001   File: functions.py    License: MIT License
def locate(substr, str, pos=1):
    """
    Locate the position of the first occurrence of substr in a string column, after position pos.

    .. note:: The position is not zero based, but 1 based index. Returns 0 if substr
        could not be found in str.

    :param substr: a string
    :param str: a Column of :class:`pyspark.sql.types.StringType`
    :param pos: start position (1 based)

    >>> df = spark.createDataFrame([('abcd',)], ['s',])
    >>> df.select(locate('b', df.s, 1).alias('s')).collect()
    [Row(s=2)]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.locate(substr, _to_java_column(str), pos)) 
Example 9
Project: SMV   Author: TresAmigosSD   File: testColumnHelper.py    License: Apache License 2.0
def test_smvArrayFlatten(self):
        df = self.createDF('a:String;b:String;c:String', ',,;1,2,;2,3,4')
        df1 = df.select(F.array(
            F.array(F.lit(None), F.col('a')),
            F.array(F.col('a'), F.col('b'), F.col('c'))
        ).alias('aa'))

        res1 = df1.select(F.col('aa').smvArrayFlatten(StringType()).alias('a'))\
            .select(SF.smvArrayCat('|', F.col('a')).alias('k'))

        exp = self.createDF("k: String",
        """||||;
            |1|1|2|;
            |2|2|3|4""")

        res2 = df1.select(F.col('aa').smvArrayFlatten(df1).alias('a'))\
            .select(SF.smvArrayCat('|', F.col('a')).alias('k'))

        self.should_be_same(res1, exp)
        self.should_be_same(res2, exp) 
Example 10
Project: example_dataproc_twitter   Author: WillianFuks   File: df_naive.py    License: MIT License
def register_udfs(self, sess, sc):
        """Register UDFs to be used in SQL queries.

        :type sess: `pyspark.sql.SparkSession`
        :param sess: Session used in Spark for SQL queries.

        :type sc: `pyspark.SparkContext`
        :param sc: Spark Context to run Spark jobs.
        """ 
        sess.udf.register("SQUARED", self.squared, returnType=(
            stypes.ArrayType(stypes.StructType(
            fields=[stypes.StructField('sku0', stypes.StringType()),
            stypes.StructField('norm', stypes.FloatType())]))))

        sess.udf.register('INTERSECTIONS',self.process_intersections,
            returnType=stypes.ArrayType(stypes.StructType(fields=[
            stypes.StructField('sku0', stypes.StringType()),
            stypes.StructField('sku1', stypes.StringType()),
            stypes.StructField('cor', stypes.FloatType())]))) 
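Once registered this way, the UDFs can be called by name from Spark SQL. A minimal, self-contained sketch of the same registration pattern follows; the UDF, column, and view names here are made up for illustration.

import pyspark.sql.types as stypes
from pyspark.sql import SparkSession

sess = SparkSession.builder.getOrCreate()

# A toy UDF returning the same array-of-struct shape used above.
sess.udf.register("TO_PAIR", lambda sku: [(sku, 1.0)],
    returnType=stypes.ArrayType(stypes.StructType(fields=[
        stypes.StructField('sku0', stypes.StringType()),
        stypes.StructField('norm', stypes.FloatType())])))

sess.createDataFrame([('a',), ('b',)], ['sku']).createOrReplaceTempView('skus')
sess.sql("SELECT TO_PAIR(sku) AS pairs FROM skus").show(truncate=False)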
Example 11
Project: example_dataproc_twitter   Author: WillianFuks   File: df_naive.py    License: MIT License
def register_udfs(self, sess, sc):
        """Register UDFs to be used in SQL queries.

        :type sess: `pyspark.sql.SparkSession`
        :param sess: Session used in Spark for SQL queries.

        :type sc: `pyspark.SparkContext`
        :param sc: Spark Context to run Spark jobs.
        """ 
        sess.udf.register("SQUARED", self.squared, returnType=(
            stypes.ArrayType(stypes.StructType(
            fields=[stypes.StructField('sku0', stypes.StringType()),
            stypes.StructField('norm', stypes.FloatType())]))))

        sess.udf.register('INTERSECTIONS',self.process_intersections,
            returnType=stypes.ArrayType(stypes.StructType(fields=[
            stypes.StructField('sku0', stypes.StringType()),
            stypes.StructField('sku1', stypes.StringType()),
            stypes.StructField('cor', stypes.FloatType())]))) 
Example 12
Project: HoloClean-Legacy-deprecated   Author: HoloClean   File: accuracy.py    License: Apache License 2.0
def read_groundtruth(self):

        """
        Create a dataframe from the ground truth csv file

        Takes as argument the full path name of the csv file
        and the spark_session
        """
        filereader = Reader(self.spark_session)

        groundtruth_schema = StructType([
            StructField("tid", IntegerType(), False),
            StructField("attr_name", StringType(), False),
            StructField("attr_val", StringType(), False)])

        self.ground_truth_flat = filereader.read(self.path_to_grand_truth, 0,
                                                 groundtruth_schema).\
            drop(GlobalVariables.index_name)

        self.dataengine.add_db_table(
            'Groundtruth', self.ground_truth_flat, self.dataset) 
Example 13
Project: reinvent-scaffold-decorator   Author: undeadpixel   File: sample_scaffolds.py    License: MIT License
def _join_results_multi(self, scaffolds_df, sampled_df):
        def _join_scaffold(scaff, dec):
            mol = usc.join(scaff, dec)
            if mol:
                return usc.to_smiles(mol)

        def _format_attachment_point(smi, num):
            smi = usc.add_first_attachment_point_number(smi, num)
            return usc.to_smiles(uc.to_mol(smi))  # canonicalize

        join_scaffold_udf = psf.udf(_join_scaffold, pst.StringType())
        format_attachment_point_udf = psf.udf(_format_attachment_point, pst.StringType())

        return sampled_df.join(scaffolds_df, on="id")\
            .withColumn("decoration", format_attachment_point_udf("decoration_smi", psf.col("attachment_points")[0]))\
            .select(
                join_scaffold_udf("smiles", "decoration").alias("smiles"),
                psf.map_concat(
                    psf.create_map(psf.col("attachment_points")[0],
                                   SampleScaffolds.cleanup_decoration_udf("decoration")),
                    "decorations",
                ).alias("decorations"),
                "scaffold") 
Example 14
Project: reinvent-scaffold-decorator   Author: undeadpixel   File: sample_scaffolds.py    License: MIT License
def _join_results_single(self, scaffolds_df, sampled_df):
        def _join_scaffold(scaff, decs):
            mol = usc.join_joined_attachments(scaff, decs)
            if mol:
                return usc.to_smiles(mol)
        join_scaffold_udf = psf.udf(_join_scaffold, pst.StringType())

        def _create_decorations_map(decorations_smi, attachment_points):
            decorations = decorations_smi.split(usc.ATTACHMENT_SEPARATOR_TOKEN)
            return {idx: _cleanup_decoration(dec) for dec, idx in zip(decorations, attachment_points)}
        create_decorations_map_udf = psf.udf(_create_decorations_map, pst.MapType(pst.IntegerType(), pst.StringType()))

        return sampled_df.join(scaffolds_df, on="id")\
            .select(
                join_scaffold_udf("randomized_scaffold", "decoration_smi").alias("smiles"),
                create_decorations_map_udf("decoration_smi", "attachment_points").alias("decorations"),
                "scaffold") 
Example 15
Project: mlflow   Author: mlflow   File: utils.py    License: Apache License 2.0
def format_to_file_path(spark_session):
    rows = [
        Row(8, 32, "bat"),
        Row(64, 40, "mouse"),
        Row(-27, 55, "horse")
    ]
    schema = StructType([
        StructField("number2", IntegerType()),
        StructField("number1", IntegerType()),
        StructField("word", StringType())
    ])
    rdd = spark_session.sparkContext.parallelize(rows)
    df = spark_session.createDataFrame(rdd, schema)
    res = {}
    tempdir = tempfile.mkdtemp()
    for data_format in ["csv", "parquet", "json"]:
        res[data_format] = os.path.join(tempdir, "test-data-%s" % data_format)

    for data_format, file_path in res.items():
        df.write.option("header", "true").format(data_format).save(file_path)
    yield res
    shutil.rmtree(tempdir) 
Example 16
Project: spark-deep-learning   Author: databricks   File: image_utils.py    License: Apache License 2.0
def getSampleImagePathsDF(sqlContext, colName):
    files = getSampleImagePaths()
    return sqlContext.createDataFrame(files, StringType()).toDF(colName)

# Methods for making comparisons between the outputs of different frameworks.
# For ImageNet. 
Example 17
Project: spark-deep-learning   Author: databricks   File: test_imageIO.py    License: Apache License 2.0
def test_filesTODF(self):
        df = imageIO.filesToDF(self.binaryFilesMock, "path", 217)
        self.assertEqual(df.rdd.getNumPartitions(), 217)
        self.assertEqual(df.schema.fields[0].dataType, StringType())
        self.assertEqual(df.schema.fields[1].dataType, BinaryType())
        first = df.first()
        self.assertTrue(hasattr(first, "filePath"))
        self.assertEqual(type(first.fileData), bytearray)


# TODO: make unit tests for arrayToImageRow on arrays of varying shapes, channels, dtypes. 
Example 18
Project: search-MjoLniR   Author: wikimedia   File: norm_query_clustering.py    License: MIT License
def cluster_within_norm_query_groups(df: DataFrame) -> DataFrame:
    make_groups = F.udf(_make_query_groups, T.ArrayType(T.StructType([
        T.StructField('query', T.StringType(), nullable=False),
        T.StructField('norm_query_group_id', T.IntegerType(), nullable=False),
    ])))
    return (
        df
        .groupBy('wikiid', 'norm_query')
        .agg(F.collect_list(F.struct('query', 'hit_page_ids')).alias('source'))
        .select(
            'wikiid', 'norm_query',
            F.explode(make_groups('source')).alias('group'))
        .select('wikiid', 'norm_query', 'group.query', 'group.norm_query_group_id')) 
Example 19
Project: python_moztelemetry   Author: mozilla   File: test_dataset.py    License: Mozilla Public License 2.0
def test_dataframe_bad_schema(dataset, spark):
    spark.catalog.dropTempView('bar')
    schema = StructType([StructField("name", StringType(), True)])
    df = dataset.dataframe(spark, decode=decode, schema=schema, table_name='bar')

    assert type(df) == DataFrame
    assert df.collect() == [Row(name=None), Row(name=None)] 
Example 20
Project: eva   Author: georgia-tech-db   File: schema_utils.py    License: Apache License 2.0
def get_petastorm_column(df_column):

        column_type = df_column.type
        column_name = df_column.name
        column_is_nullable = df_column.is_nullable
        column_array_dimensions = df_column.array_dimensions

        # Reference:
        # https://github.com/uber/petastorm/blob/master/petastorm/
        # tests/test_common.py

        petastorm_column = None
        if column_type == ColumnType.INTEGER:
            petastorm_column = UnischemaField(column_name,
                                              np.int32,
                                              (),
                                              ScalarCodec(IntegerType()),
                                              column_is_nullable)
        elif column_type == ColumnType.FLOAT:
            petastorm_column = UnischemaField(column_name,
                                              np.float64,
                                              (),
                                              ScalarCodec(FloatType()),
                                              column_is_nullable)
        elif column_type == ColumnType.TEXT:
            petastorm_column = UnischemaField(column_name,
                                              np.string_,
                                              (),
                                              ScalarCodec(StringType()),
                                              column_is_nullable)
        elif column_type == ColumnType.NDARRAY:
            petastorm_column = UnischemaField(column_name,
                                              np.uint8,
                                              column_array_dimensions,
                                              NdarrayCodec(),
                                              column_is_nullable)
        else:
            LoggingManager().log("Invalid column type: " + str(column_type),
                                 LoggingLevel.ERROR)

        return petastorm_column 
Example 21
Project: elephas   Author: maxpumperla   File: ml_model.py    License: MIT License
def _transform(self, df):
        """Private transform method of a Transformer. This serves as batch-prediction method for our purposes.
        """
        output_col = self.getOutputCol()
        label_col = self.getLabelCol()
        new_schema = copy.deepcopy(df.schema)
        new_schema.add(StructField(output_col, StringType(), True))

        rdd = df.rdd.coalesce(1)
        features = np.asarray(
            rdd.map(lambda x: from_vector(x.features)).collect())
        # Note that we collect, since executing this on the rdd would require model serialization once again
        model = model_from_yaml(self.get_keras_model_config())
        model.set_weights(self.weights.value)
        predictions = rdd.ctx.parallelize(
            model.predict_classes(features)).coalesce(1)
        predictions = predictions.map(lambda x: tuple(str(x)))

        results_rdd = rdd.zip(predictions).map(lambda x: x[0] + x[1])
        results_df = df.sql_ctx.createDataFrame(results_rdd, new_schema)
        results_df = results_df.withColumn(
            output_col, results_df[output_col].cast(DoubleType()))
        results_df = results_df.withColumn(
            label_col, results_df[label_col].cast(DoubleType()))

        return results_df 
Example 22
Project: petastorm   Author: uber   File: unischema.py    License: Apache License 2.0
def _numpy_to_spark_mapping():
    """Returns a mapping from numpy to pyspark.sql type. Caches the mapping dictionary inorder to avoid instantiation
    of multiple objects in each call."""

    # Refer to the attribute of the function we use to cache the map using a name in the variable instead of a 'dot'
    # notation to avoid copy/paste/typo mistakes
    cache_attr_name = 'cached_numpy_to_pyspark_types_map'
    if not hasattr(_numpy_to_spark_mapping, cache_attr_name):
        import pyspark.sql.types as T

        setattr(_numpy_to_spark_mapping, cache_attr_name,
                {
                    np.int8: T.ByteType(),
                    np.uint8: T.ShortType(),
                    np.int16: T.ShortType(),
                    np.uint16: T.IntegerType(),
                    np.int32: T.IntegerType(),
                    np.int64: T.LongType(),
                    np.float32: T.FloatType(),
                    np.float64: T.DoubleType(),
                    np.string_: T.StringType(),
                    np.str_: T.StringType(),
                    np.unicode_: T.StringType(),
                    np.bool_: T.BooleanType(),
                })

    return getattr(_numpy_to_spark_mapping, cache_attr_name)


# TODO: Changing fields in this class or the UnischemaField will break reading due to the schema being pickled next to
# the dataset on disk 
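A quick check of the cached mapping above (a hedged sketch; _numpy_to_spark_mapping is the private helper defined in this snippet, so it would normally be exercised from inside petastorm itself):

import numpy as np

mapping = _numpy_to_spark_mapping()
print(mapping[np.string_])   # a StringType() instance
print(mapping[np.float64])   # a DoubleType() instance
assert mapping is _numpy_to_spark_mapping()  # same cached dict on repeated calls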
Example 23
Project: petastorm   Author: uber   File: test_end_to_end.py    License: Apache License 2.0
def test_invalid_schema_field(synthetic_dataset, reader_factory):
    # Let's assume we are selecting columns using a schema which is different from the one
    # stored in the dataset. Would expect to get a reasonable error message
    BogusSchema = Unischema('BogusSchema', [
        UnischemaField('partition_key', np.string_, (), ScalarCodec(StringType()), False),
        UnischemaField('id', np.int64, (), ScalarCodec(LongType()), False),
        UnischemaField('bogus_key', np.int32, (), ScalarCodec(ShortType()), False)])

    expected_values = {'bogus_key': 11, 'id': 1}
    with pytest.raises(ValueError, match='bogus_key'):
        reader_factory(synthetic_dataset.url, schema_fields=BogusSchema.fields.values(),
                       shuffle_row_groups=False,
                       predicate=EqualPredicate(expected_values)) 
Example 24
Project: petastorm   Author: uber   File: test_unischema.py    License: Apache License 2.0
def test_fields():
    """Try using 'fields' getter"""
    TestSchema = Unischema('TestSchema', [
        UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False),
        UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
    ])

    assert len(TestSchema.fields) == 2
    assert TestSchema.fields['int_field'].name == 'int_field'
    assert TestSchema.fields['string_field'].name == 'string_field' 
Example 25
Project: petastorm   Author: uber   File: test_unischema.py    License: Apache License 2.0
def test_dict_to_spark_row_field_validation_scalar_types():
    """Test various validations done on data types when converting a dictionary to a spark row"""
    TestSchema = Unischema('TestSchema', [
        UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
    ])

    assert isinstance(dict_to_spark_row(TestSchema, {'string_field': 'abc'}), Row)

    # Not a nullable field
    with pytest.raises(ValueError):
        isinstance(dict_to_spark_row(TestSchema, {'string_field': None}), Row)

    # Wrong field type
    with pytest.raises(TypeError):
        isinstance(dict_to_spark_row(TestSchema, {'string_field': []}), Row) 
Example 26
Project: petastorm   Author: uber   File: test_unischema.py    License: Apache License 2.0
def test_make_named_tuple():
    TestSchema = Unischema('TestSchema', [
        UnischemaField('string_scalar', np.string_, (), ScalarCodec(StringType()), True),
        UnischemaField('int32_scalar', np.int32, (), ScalarCodec(ShortType()), False),
        UnischemaField('uint8_scalar', np.uint8, (), ScalarCodec(ShortType()), False),
        UnischemaField('int32_matrix', np.float32, (10, 20, 3), NdarrayCodec(), True),
        UnischemaField('decimal_scalar', Decimal, (10, 20, 3), ScalarCodec(DecimalType(10, 9)), False),
    ])

    TestSchema.make_namedtuple(string_scalar='abc', int32_scalar=10, uint8_scalar=20,
                               int32_matrix=np.int32((10, 20, 3)), decimal_scalar=Decimal(123) / Decimal(10))

    TestSchema.make_namedtuple(string_scalar=None, int32_scalar=10, uint8_scalar=20,
                               int32_matrix=None, decimal_scalar=Decimal(123) / Decimal(10)) 
Example 27
Project: petastorm   Author: uber   File: test_unischema.py    License: Apache License 2.0
def test_insert_explicit_nulls():
    TestSchema = Unischema('TestSchema', [
        UnischemaField('nullable', np.int32, (), ScalarCodec(StringType()), True),
        UnischemaField('not_nullable', np.int32, (), ScalarCodec(ShortType()), False),
    ])

    # Insert_explicit_nulls to leave the dictionary as is.
    row_dict = {'nullable': 0, 'not_nullable': 1}
    insert_explicit_nulls(TestSchema, row_dict)
    assert len(row_dict) == 2
    assert row_dict['nullable'] == 0
    assert row_dict['not_nullable'] == 1

    # Insert_explicit_nulls to leave the dictionary as is.
    row_dict = {'nullable': None, 'not_nullable': 1}
    insert_explicit_nulls(TestSchema, row_dict)
    assert len(row_dict) == 2
    assert row_dict['nullable'] is None
    assert row_dict['not_nullable'] == 1

    # We are missing a nullable field here. insert_explicit_nulls should add a None entry.
    row_dict = {'not_nullable': 1}
    insert_explicit_nulls(TestSchema, row_dict)
    assert len(row_dict) == 2
    assert row_dict['nullable'] is None
    assert row_dict['not_nullable'] == 1

    # We are missing a not_nullable field here. Should raise an ValueError.
    row_dict = {'nullable': 0}
    with pytest.raises(ValueError):
        insert_explicit_nulls(TestSchema, row_dict) 
Example 28
Project: petastorm   Author: uber   File: test_unischema.py    License: Apache License 2.0
def test_create_schema_view_fails_validate():
    """ Exercises code paths unischema.create_schema_view ValueError, and unischema.__str__."""
    TestSchema = Unischema('TestSchema', [
        UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False),
        UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
    ])
    with pytest.raises(ValueError, match='does not belong to the schema'):
        TestSchema.create_schema_view([UnischemaField('id', np.int64, (), ScalarCodec(LongType()), False)]) 
Example 29
Project: petastorm   Author: uber   File: test_unischema.py    License: Apache License 2.0
def test_create_schema_view_using_invalid_type():
    """ Exercises code paths unischema.create_schema_view ValueError, and unischema.__str__."""
    TestSchema = Unischema('TestSchema', [
        UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False),
        UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
    ])
    with pytest.raises(ValueError, match='must be either a string'):
        TestSchema.create_schema_view([42]) 
Example 30
Project: petastorm   Author: uber   File: test_unischema.py    License: Apache License 2.0
def test_create_schema_view_using_unischema_fields():
    TestSchema = Unischema('TestSchema', [
        UnischemaField('int_field', np.int8, (), ScalarCodec(IntegerType()), False),
        UnischemaField('string_field', np.string_, (), ScalarCodec(StringType()), False),
    ])
    view = TestSchema.create_schema_view([TestSchema.int_field])
    assert set(view.fields.keys()) == {'int_field'}