# -*- coding: utf-8 -*-
# Copyright 2016 Yelp Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import
from __future__ import unicode_literals

import copy

from avro import schema


class AvroSchemaBuilder(object):
    """
    AvroSchemaBuilder creates json-formatted Avro schemas.

    It has a `create_*` function for each primitive type to create primitive
    types. To create a complex type, start with the corresponding `begin_*`
    function and finish it with the `end` function.

    It defers the schema validation until the end of schema building. When
    the schema building ends, it constructs the corresponding schema object
    which will validate the syntax of the Avro json object.

    **Examples**:

        build a record schema::

            ab = AvroSchemaBuilder()
            record = ab.begin_record(
                'user',
                namespace='yelp'
            ).add_field(
                'id',
                typ=ab.create_int()
            ).add_field(
                'fav_color',
                typ=ab.begin_enum('color_enum', ['red', 'blue']).end()
            ).end()

        build an enum schema::

            ab = AvroSchemaBuilder()
            enum_schema = ab.begin_enum('color_enum', ['red', 'blue']).end()

        build an array schema::

            ab = AvroSchemaBuilder()
            array_schema = ab.begin_array(ab.create_string()).end()

        build a record field::

            ab = AvroSchemaBuilder()
            new_field = AvroSchemaBuilder.create_field(
                'col_id',
                typ=ab.create_int(),
                has_default=False,
                default_value=None
            )
    """

    def __init__(self):
        self._schema_json = None  # current avro schema in build
        # Enclosing schemas that are suspended while a nested one is built.
        self._schema_tracker = []

    @classmethod
    def create_null(cls):
        """Return the json representation of the `null` primitive type."""
        return 'null'

    @classmethod
    def create_boolean(cls):
        """Return the json representation of the `boolean` primitive type."""
        return 'boolean'

    @classmethod
    def create_int(cls):
        """Return the json representation of the `int` primitive type."""
        return 'int'

    @classmethod
    def create_long(cls):
        """Return the json representation of the `long` primitive type."""
        return 'long'

    @classmethod
    def create_float(cls):
        """Return the json representation of the `float` primitive type."""
        return 'float'

    @classmethod
    def create_double(cls):
        """Return the json representation of the `double` primitive type."""
        return 'double'

    @classmethod
    def create_bytes(cls):
        """Return the json representation of the `bytes` primitive type."""
        return 'bytes'

    @classmethod
    def create_string(cls):
        """Return the json representation of the `string` primitive type."""
        return 'string'

    def _add_metadata_to_schema(self, schema, **metadata):
        """Merge the extra `metadata` key/value pairs into the schema dict.

        Metadata keys overwrite keys that are already present in `schema`.
        """
        schema.update(metadata)

    def _save_and_set_current_schema(self, schema):
        """Suspend the schema in build (if any) and start building `schema`."""
        self._save_current_schema()
        self._set_current_schema(schema)

    def begin_date(self, **metadata):
        """Begin an `int` schema annotated with the `date` logical type."""
        date_schema = {'type': 'int', 'logicalType': 'date'}
        self._add_metadata_to_schema(date_schema, **metadata)
        self._save_and_set_current_schema(date_schema)
        return self

    def begin_time_millis(self, **metadata):
        """Begin an `int` schema annotated with the `time-millis` logical
        type.
        """
        time_millis_schema = {'type': 'int', 'logicalType': 'time-millis'}
        self._add_metadata_to_schema(time_millis_schema, **metadata)
        self._save_and_set_current_schema(time_millis_schema)
        return self

    def begin_time_micros(self, **metadata):
        """Begin a `long` schema annotated with the `time-micros` logical
        type.
        """
        time_micros_schema = {'type': 'long', 'logicalType': 'time-micros'}
        self._add_metadata_to_schema(time_micros_schema, **metadata)
        self._save_and_set_current_schema(time_micros_schema)
        return self

    def begin_timestamp_millis(self, **metadata):
        """Begin a `long` schema annotated with the `timestamp-millis`
        logical type.
        """
        timestamp_millis_schema = {
            'type': 'long',
            'logicalType': 'timestamp-millis'
        }
        self._add_metadata_to_schema(timestamp_millis_schema, **metadata)
        self._save_and_set_current_schema(timestamp_millis_schema)
        return self

    def begin_timestamp_micros(self, **metadata):
        """Begin a `long` schema annotated with the `timestamp-micros`
        logical type.
        """
        timestamp_micros_schema = {
            'type': 'long',
            'logicalType': 'timestamp-micros'
        }
        self._add_metadata_to_schema(timestamp_micros_schema, **metadata)
        self._save_and_set_current_schema(timestamp_micros_schema)
        return self

    def begin_enum(self, name, symbols, namespace=None, aliases=None,
                   doc=None, **metadata):
        """Begin an enum schema.

        Args:
            name (str): name of the enum.
            symbols (list of str): symbols of the enum.
            namespace (Optional[str]): namespace; omitted when falsy.
            aliases (Optional[list of str]): aliases; omitted when falsy.
            doc (Optional[str]): documentation; omitted when falsy.
            **metadata: extra attributes merged into the schema json.
        """
        enum_schema = {'type': 'enum', 'name': name, 'symbols': symbols}
        if namespace:
            self._set_namespace(enum_schema, namespace)
        if aliases:
            self._set_aliases(enum_schema, aliases)
        if doc:
            self._set_doc(enum_schema, doc)
        self._add_metadata_to_schema(enum_schema, **metadata)
        self._save_and_set_current_schema(enum_schema)
        return self

    def begin_fixed(self, name, size, namespace=None, aliases=None,
                    **metadata):
        """Begin a fixed schema of the given `name` and `size`.

        `namespace` and `aliases` are only written to the schema when they
        are truthy; `metadata` is merged into the schema json.
        """
        fixed_schema = {'type': 'fixed', 'name': name, 'size': size}
        if namespace:
            self._set_namespace(fixed_schema, namespace)
        if aliases:
            self._set_aliases(fixed_schema, aliases)
        self._add_metadata_to_schema(fixed_schema, **metadata)
        self._save_and_set_current_schema(fixed_schema)
        return self

    def begin_decimal_fixed(self, precision, scale, size, name,
                            namespace=None, **metadata):
        """Begin a `fixed` schema annotated with the `decimal` logical type.

        `namespace` is only written to the schema when it is truthy;
        `metadata` is merged into the schema json.
        """
        fixed_decimal_schema = {
            'type': 'fixed',
            'logicalType': 'decimal',
            'name': name,
            'precision': precision,
            'scale': scale,
            'size': size
        }
        if namespace:
            self._set_namespace(fixed_decimal_schema, namespace)
        self._add_metadata_to_schema(fixed_decimal_schema, **metadata)
        self._save_and_set_current_schema(fixed_decimal_schema)
        return self

    def begin_decimal_bytes(self, precision, scale, **metadata):
        """Begin a `bytes` schema annotated with the `decimal` logical
        type.
        """
        bytes_decimal_schema = {
            'type': 'bytes',
            'logicalType': 'decimal',
            'precision': precision,
            'scale': scale
        }
        self._add_metadata_to_schema(bytes_decimal_schema, **metadata)
        self._save_and_set_current_schema(bytes_decimal_schema)
        return self

    def begin_array(self, items_schema, **metadata):
        """Begin an array schema whose items are of `items_schema`."""
        array_schema = {'type': 'array', 'items': items_schema}
        self._add_metadata_to_schema(array_schema, **metadata)
        self._save_and_set_current_schema(array_schema)
        return self

    def begin_map(self, values_schema, **metadata):
        """Begin a map schema whose values are of `values_schema`."""
        map_schema = {'type': 'map', 'values': values_schema}
        self._add_metadata_to_schema(map_schema, **metadata)
        self._save_and_set_current_schema(map_schema)
        return self

    def begin_record(self, name, namespace=None, aliases=None, doc=None,
                     **metadata):
        """Begin a record schema with an empty field list.

        Args:
            name (str): name of the record.
            namespace (Optional[str]): namespace; written to the schema
                whenever it is not None (an empty string is kept).
            aliases (Optional[list of str]): aliases; omitted when falsy.
            doc (Optional[str]): documentation; omitted when falsy.
            **metadata: extra attributes merged into the schema json.
        """
        record_schema = {'type': 'record', 'name': name, 'fields': []}
        if namespace is not None:
            self._set_namespace(record_schema, namespace)
        if aliases:
            self._set_aliases(record_schema, aliases)
        if doc:
            self._set_doc(record_schema, doc)
        self._add_metadata_to_schema(record_schema, **metadata)
        self._save_and_set_current_schema(record_schema)
        return self

    def add_field(self, name, typ, has_default=False, default_value=None,
                  sort_order=None, aliases=None, doc=None, **metadata):
        """Append a new field to the field list of the current schema.

        The field json is built by :func:`create_field` with the same
        arguments.
        """
        field = self.create_field(
            name,
            typ,
            has_default=has_default,
            default_value=default_value,
            sort_order=sort_order,
            aliases=aliases,
            doc=doc,
            **metadata
        )
        self._schema_json['fields'].append(field)
        return self

    @classmethod
    def create_field(cls, name, typ, has_default=False, default_value=None,
                     sort_order=None, aliases=None, doc=None, **metadata):
        """Create the json dict of a record field.

        Args:
            name (str): name of the field.
            typ: json representation of the field type.
            has_default (bool): whether the field has a default value; the
                `default` key is only written when this is truthy.
            default_value: default value, used only if `has_default` is set.
            sort_order (Optional[str]): sort order; omitted when falsy.
            aliases (Optional[list of str]): aliases; omitted when falsy.
            doc (Optional[str]): documentation; omitted when falsy.
            **metadata: extra attributes merged into the field json.
        """
        return AvroField.from_attributes(
            name,
            typ,
            has_default=has_default,
            default_value=default_value,
            sort_order=sort_order,
            aliases=aliases,
            doc=doc,
            **metadata
        ).field_json

    def begin_union(self, *avro_schemas):
        """Begin a union schema made of the given schemas, in order."""
        union_schema = list(avro_schemas)
        self._save_and_set_current_schema(union_schema)
        return self

    def end(self):
        """Finish the schema in build and return its json representation.

        For the top level schema (no suspended enclosing schema) the json is
        turned into an avro schema object, which is where the deferred
        validation happens, and the json of that object is returned. For a
        nested schema the json is returned as is and the enclosing schema
        becomes the schema in build again.
        """
        if not self._schema_tracker:
            # this is the top level schema; do the schema validation
            schema_obj = schema.make_avsc_object(self._schema_json)
            self._schema_json = None
            return schema_obj.to_json()

        current_schema_json = self._schema_json
        self._restore_current_schema()
        return current_schema_json

    def _save_current_schema(self):
        """Push the schema in build, if there is one, onto the tracker."""
        # None is the "nothing in build" sentinel (see __init__, end and
        # clear); compare against it explicitly so that an empty-but-real
        # schema such as an empty union is not silently dropped.
        if self._schema_json is not None:
            self._schema_tracker.append(self._schema_json)

    def _set_current_schema(self, avro_schema):
        """Make `avro_schema` the schema in build."""
        self._schema_json = avro_schema

    def _restore_current_schema(self):
        """Resume building the most recently suspended schema."""
        self._schema_json = self._schema_tracker.pop()

    @classmethod
    def _set_namespace(cls, avro_schema, namespace):
        avro_schema['namespace'] = namespace

    @classmethod
    def _set_aliases(cls, avro_schema, aliases):
        avro_schema['aliases'] = aliases

    @classmethod
    def _set_doc(cls, avro_schema, doc):
        avro_schema['doc'] = doc

    def begin_nullable_type(self, schema_type, default_value=None):
        """Create an Avro schema that represents the nullable `schema_type`.

        The nullable type is a union schema type with `null` primitive type.
        The given default value is used to determine whether the `null` type
        should be the first item in the union type: `null` goes first when
        `default_value` is None and last otherwise. If `schema_type` is
        already nullable, a deep copy of it is used unchanged.
        """
        null_type = self.create_null()
        # Work on a copy so the caller's schema object is never mutated.
        src_type = copy.deepcopy(schema_type)
        if self.is_nullable_type(schema_type):
            nullable_schema = src_type
        else:
            typ = src_type if isinstance(src_type, list) else [src_type]
            if default_value is None:
                typ.insert(0, null_type)
            else:
                typ.append(null_type)
            nullable_schema = self.begin_union(*typ).end()
        self._save_and_set_current_schema(nullable_schema)
        return self

    @classmethod
    def is_nullable_type(cls, schema_type):
        """Whether the given type is a nullable type, either it is `null` or
        a union Avro schema type which contains `null`.
        """
        null_type = cls.create_null()
        return (
            schema_type is not None and (
                schema_type == null_type or
                (isinstance(schema_type, list) and
                 any(typ == null_type for typ in schema_type))
            )
        )

    def begin_with_schema_json(self, schema_json):
        """Begin building the given schema json object.

        Similar to other `begin_*` functions, it doesn't validate the input
        schema json until the end of schema. The builder works on a deep
        copy of `schema_json`.
        """
        self._save_and_set_current_schema(copy.deepcopy(schema_json))
        return self

    def remove_field(self, field_name):
        """Remove the specified field from the fields in the current schema.

        Raises:
            ValueError: This exception is thrown if given field cannot be
                found.
        """
        index = self.get_field_index(field_name)
        del self._schema_json['fields'][index]
        return self

    def insert_field(self, field, index):
        """Insert the given field at specified field list index.

        Args:
            field (dict): Python json representation of an Avro field.
            index (int): position index of the field to be inserted.
        """
        self._schema_json['fields'].insert(index, field)
        return self

    def insert_fields(self, fields, index):
        """Insert the given field list at specified field list index.

        Args:
            fields (list of dict): List of Python json representation of
                Avro fields. Any iterable of fields is accepted.
            index (int): start position index to insert the given fields.
        """
        record_fields = self._schema_json['fields']
        # list(...) lets callers pass tuples/generators, not only lists.
        self._schema_json['fields'] = (
            record_fields[:index] + list(fields) + record_fields[index:]
        )
        return self

    def get_field_index(self, field_name):
        """Get the field list index of given field name.

        Args:
            field_name (str): name of the field in interest

        Raises:
            ValueError: This exception is thrown if given field cannot be
                found.
        """
        index, _ = self._get_index_and_field(field_name)
        return index

    def get_field(self, field_name):
        """Get the field json dict of given field name.

        Args:
            field_name (str): name of the field in interest

        Raises:
            ValueError: This exception is thrown if given field cannot be
                found.
        """
        _, field = self._get_index_and_field(field_name)
        return field

    def _get_index_and_field(self, field_name):
        """Return the `(index, field)` pair of the first field with the
        given name.

        Raises:
            ValueError: This exception is thrown if given field cannot be
                found.
        """
        for i, field in enumerate(self._get_fields()):
            if field['name'] == field_name:
                return i, field
        raise ValueError("Cannot find field named {0}".format(field_name))

    def _get_fields(self):
        """Return the field list of the current schema, or an empty list if
        it has no `fields` key.
        """
        return self._schema_json.get('fields', [])

    def replace_field(self, old_field_name, new_fields):
        """Replace an existing field with 0 or more new fields.

        Args:
            old_field_name (str): The name of the field to replace.
            new_fields (list of dict): A list of new fields to replace the
                old field. Each field is represented as a dict. If this list
                is empty the effect will be the same as calling
                :func:`remove_field`.

        Returns:
            AvroSchemaBuilder: this builder, so that calls can be chained
            like the other field-manipulating functions.

        Raises:
            ValueError: This exception is thrown if given field cannot be
                found.

        Notes:
            It is recommended to use :func:`create_field` function to
            construct a new record field instead of hand-crafting the dict.
            For example::

                ab = AvroSchemaBuilder()
                new_field = AvroSchemaBuilder.create_field(
                    'col_id',
                    typ=ab.create_int(),
                    has_default=False,
                    default_value=None
                )
        """
        index = self.get_field_index(old_field_name)
        self._schema_json['fields'][index:index + 1] = new_fields
        return self

    def clear(self):
        """Clear the schemas that are built so far."""
        self._schema_json = None
        self._schema_tracker = []


class AvroField(object):
    """This class is used to hold an Avro field schema and provide an easy
    way to manipulate the field without directly dealing with Python json
    dict.

    The given dict is wrapped, not copied: every setter as well as
    `set_metadata` and `clear_metadata` modify that dict in place.
    """

    # Keys that are field attributes; any other key is treated as metadata.
    _reserved_keys = {'name', 'type', 'default', 'order', 'aliases', 'doc'}

    def __init__(self, field_json):
        self._field_json = field_json

    @classmethod
    def from_attributes(cls, name, typ, has_default=False,
                        default_value=None, sort_order=None, aliases=None,
                        doc=None, **metadata):
        """Create a field from individual attributes.

        `default_value` is only stored when `has_default` is truthy;
        `sort_order`, `aliases` and `doc` are only stored when truthy.
        `metadata` is merged into the field json last.
        """
        # Use cls so that subclasses get an instance of their own type.
        avro_field = cls({'name': name})
        avro_field.field_type = typ
        if has_default:
            avro_field.default_value = default_value
        if sort_order:
            avro_field.sort_order = sort_order
        if aliases:
            avro_field.aliases = aliases
        if doc:
            avro_field.doc = doc
        avro_field.set_metadata(**metadata)
        return avro_field

    @property
    def field_json(self):
        """The wrapped field json dict."""
        return self._field_json

    @property
    def name(self):
        return self._field_json['name']

    @property
    def field_type(self):
        return self._field_json['type']

    @field_type.setter
    def field_type(self, new_type):
        self._field_json['type'] = new_type

    @property
    def has_default(self):
        """Whether the field json has a `default` key."""
        return 'default' in self._field_json

    @property
    def default_value(self):
        """The default value; raises KeyError if the field has none."""
        return self._field_json['default']

    @default_value.setter
    def default_value(self, new_default_value):
        self._field_json['default'] = new_default_value

    @property
    def sort_order(self):
        return self._field_json.get('order')

    @sort_order.setter
    def sort_order(self, new_sort_order):
        self._field_json['order'] = new_sort_order

    @property
    def aliases(self):
        return self._field_json.get('aliases')

    @aliases.setter
    def aliases(self, new_aliases):
        self._field_json['aliases'] = new_aliases

    @property
    def doc(self):
        return self._field_json.get('doc')

    @doc.setter
    def doc(self, new_doc):
        self._field_json['doc'] = new_doc

    @property
    def metadata(self):
        """A new dict holding every non-reserved key of the field json."""
        return {
            k: v for k, v in self._field_json.items()
            if k not in self._reserved_keys
        }

    def clear_metadata(self):
        """Remove every non-reserved key from the field json.

        The wrapped dict is modified in place, like all the setters do, so
        the field stays attached to the dict it was created from.
        """
        # Collect the keys first; a dict cannot change size while iterated.
        metadata_keys = [
            k for k in self._field_json if k not in self._reserved_keys
        ]
        for key in metadata_keys:
            del self._field_json[key]

    def set_metadata(self, **metadata):
        """Merge the given key/value pairs into the field json."""
        self._field_json.update(metadata)