# coding=utf-8 import operator from collections import OrderedDict import coreschema import uritemplate from coreapi import Link, Document, Field from coreapi.compat import force_text from django.db import models from django.utils.functional import Promise from pkg_resources import parse_version from rest_framework import serializers from rest_framework.fields import IntegerField, URLField from rest_framework.pagination import PageNumberPagination, LimitOffsetPagination, CursorPagination from rest_framework.schemas import SchemaGenerator from rest_framework.schemas.generators import insert_into, distribute_links, LinkNode from rest_framework.schemas.inspectors import get_pk_description, field_to_schema from drf_openapi.codec import _get_parameters class VersionedSerializers: """Adapted from https://github.com/avanov/Rhetoric/ :) """ OPERATORS = { '>': operator.gt, '<': operator.lt, '==': operator.eq, '>=': operator.ge, '<=': operator.le } """ A map of version and serializer definition May be represented in the following form 1. ``VERSION`` 2. ``==VERSION`` (the same as above) 3. ``>VERSION`` 4. ``<VERSION`` 5. ``>=VERSION`` 6. ``<=Version`` 7. Comma-separated list of 1-7 evaluated as AND Must override in subclass, for example VERSION_MAP = ( ('>1.3, <=1.6', MeSerializer16) ('>1.6', MeSerializer) ) """ VERSION_MAP = () @classmethod def get(cls, request_version): for allowed_version, schema in cls.VERSION_MAP: distinct_versions = [version.strip() for version in allowed_version.split(',')] matched = True for distinct_version in distinct_versions: operation = cls.OPERATORS.get(distinct_version[:2]) if operation: # prepare cases #2, #5, #6 compare_with = distinct_version[2:] else: operation = cls.OPERATORS.get(distinct_version[0]) if operation: # prepare cases #3, #4 compare_with = distinct_version[1:] else: # prepare case #1 compare_with = distinct_version operation = cls.OPERATORS['=='] matched = operation(parse_version(request_version), parse_version(compare_with)) if not matched: matched = False if matched: return schema raise ValueError('Invalid request version {}'.format(request_version)) get.__func__.__annotations__ = {'request_version': str} class OpenApiSchemaGenerator(SchemaGenerator): def __init__(self, version, title=None, url=None, description=None, patterns=None, urlconf=None): self.version = version super(OpenApiSchemaGenerator, self).__init__(title, url, description, patterns, urlconf) def get_schema(self, request=None, public=False): if self.endpoints is None: inspector = self.endpoint_inspector_cls(self.patterns, self.urlconf) self.endpoints = inspector.get_api_endpoints() links = self.get_links(None if public else request) if not links: return None url = self.url if not url and request is not None: url = request.build_absolute_uri() distribute_links(links) return OpenApiDocument( version=self.version, title=self.title, description=self.description, url=url, content=links ) def get_links(self, request=None): """ Return a dictionary containing all the links that should be included in the API schema. """ links = LinkNode() # Generate (path, method, view) given (path, method, callback). paths = [] view_endpoints = [] for path, method, callback in self.endpoints: view = self.create_view(callback, method, request) if getattr(view, 'exclude_from_schema', False): continue path = self.coerce_path(path, method, view) paths.append(path) view_endpoints.append((path, method, view)) # Only generate the path prefix for paths that will be included if not paths: return None prefix = self.determine_path_prefix(paths) for path, method, view in view_endpoints: if not self.has_view_permissions(path, method, view): continue link = self.get_link(path, method, view, version=getattr(request, 'version', None)) subpath = path[len(prefix):] keys = self.get_keys(subpath, method, view) try: insert_into(links, keys, link) except Exception: continue return links def get_serializer_doc(self, serializer): if serializer.__doc__ is None: return '' doc = [] for line in serializer.__doc__.splitlines(): doc.append(line.strip()) return '\n'.join(doc) def get_link(self, path, method, view, version=None): method_name = getattr(view, 'action', method.lower()) method_func = getattr(view, method_name, None) fields = self.get_path_fields(path, method, view) fields += self.get_serializer_fields(path, method, view, version=version, method_func=method_func) fields += view.schema.get_pagination_fields(path, method) fields += view.schema.get_filter_fields(path, method) if fields and any([field.location in ('form', 'body') for field in fields]): encoding = view.schema.get_encoding(path, method) else: encoding = None description = view.schema.get_description(path, method) request_serializer_class = getattr(method_func, 'request_serializer', None) if request_serializer_class and issubclass(request_serializer_class, VersionedSerializers): request_doc = self.get_serializer_doc(request_serializer_class) if request_doc: description = description + '\n\n**Request Description:**\n' + request_doc response_serializer_class = getattr(method_func, 'response_serializer', None) if response_serializer_class and issubclass(response_serializer_class, VersionedSerializers): res_doc = self.get_serializer_doc(response_serializer_class) if res_doc: description = description + '\n\n**Response Description:**\n' + res_doc response_serializer_class = response_serializer_class.get(version) if not response_serializer_class and method_name in ('list', 'retrieve'): if hasattr(view, 'get_serializer_class'): response_serializer_class = view.get_serializer_class() elif hasattr(view, 'serializer_class'): response_serializer_class = view.serializer_class if response_serializer_class and method_name == 'list': response_serializer_class = self.get_paginator_serializer( view, response_serializer_class) response_schema, error_status_codes = self.get_response_object( response_serializer_class, method_func.__doc__) if response_serializer_class else ({}, {}) return OpenApiLink( response_schema=response_schema, error_status_codes=error_status_codes, url=path.replace('{version}', self.version), # can't use format because there may be other param action=method.lower(), encoding=encoding, fields=fields, description=description ) def get_paginator_serializer(self, view, child_serializer_class): class BaseFakeListSerializer(serializers.Serializer): results = child_serializer_class(many=True) class FakePrevNextListSerializer(BaseFakeListSerializer): next = URLField() previous = URLField() # Validate if the view has a pagination_class if not (hasattr(view, 'pagination_class')) or view.pagination_class is None: return BaseFakeListSerializer pager = view.pagination_class if hasattr(pager, 'default_pager'): # Must be a ProxyPagination pager = pager.default_pager if issubclass(pager, (PageNumberPagination, LimitOffsetPagination)): class FakeListSerializer(FakePrevNextListSerializer): count = IntegerField() return FakeListSerializer elif issubclass(pager, CursorPagination): return FakePrevNextListSerializer return BaseFakeListSerializer def get_path_fields(self, path, method, view): """ Return a list of `coreapi.Field` instances corresponding to any templated path variables. """ model = getattr(getattr(view, 'queryset', None), 'model', None) fields = [] for variable in uritemplate.variables(path): if variable == 'version': continue title = '' description = '' schema_cls = coreschema.String kwargs = {} if model is not None: # Attempt to infer a field description if possible. try: model_field = model._meta.get_field(variable) except: model_field = None if model_field is not None and model_field.verbose_name: title = force_text(model_field.verbose_name) if model_field is not None and model_field.help_text: description = force_text(model_field.help_text) elif model_field is not None and model_field.primary_key: description = get_pk_description(model, model_field) if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable: kwargs['pattern'] = view.lookup_value_regex elif isinstance(model_field, models.AutoField): schema_cls = coreschema.Integer field = Field( name=variable, location='path', required=True, schema=schema_cls(title=title, description=description, **kwargs) ) fields.append(field) return fields def get_serializer_class(self, view, method_func): """ Try to get the serializer class from view method. If view method don't have request serializer, fallback to serializer_class on view class """ if hasattr(method_func, 'request_serializer'): return getattr(method_func, 'request_serializer') if hasattr(view, 'serializer_class'): return getattr(view, 'serializer_class') if hasattr(view, 'get_serializer_class'): return getattr(view, 'get_serializer_class')() return None def fallback_schema_from_field(self, field): """ Fallback schema for field that isn't inspected properly by DRF and probably won't land in upstream canon due to its hacky nature only for doc purposes """ title = force_text(field.label) if field.label else '' description = force_text(field.help_text) if field.help_text else '' # since we can't really inspect dictfield and jsonfield, at least display object as type # instead of string if isinstance(field, (serializers.DictField, serializers.JSONField)): return coreschema.Object( properties={}, title=title, description=description ) def get_serializer_fields(self, path, method, view, version=None, method_func=None): """ Return a list of `coreapi.Field` instances corresponding to any request body input, as determined by the serializer class. """ if method in ('PUT', 'PATCH', 'POST'): location = 'form' else: location = 'query' serializer_class = self.get_serializer_class(view, method_func) if not serializer_class: return [] serializer = serializer_class() if isinstance(serializer, serializers.ListSerializer): return [ Field( name='data', location=location, required=True, schema=coreschema.Array() ) ] if not isinstance(serializer, serializers.Serializer): return [] fields = [] for field in serializer.fields.values(): if field.read_only or isinstance(field, serializers.HiddenField): continue required = field.required and method != 'PATCH' # if the attribute ('help_text') of this field is a lazy translation object, force it to generate a string description = str(field.help_text) if isinstance(field.help_text, Promise) else field.help_text fallback_schema = self.fallback_schema_from_field(field) field = Field( name=field.field_name, location=location, required=required, schema=fallback_schema if fallback_schema else field_to_schema(field), description=description, ) fields.append(field) return fields def get_response_object(self, response_serializer_class, description): fields = [] serializer = response_serializer_class() nested_obj = {} for field in serializer.fields.values(): # If field is a serializer, attempt to get its schema. if isinstance(field, serializers.Serializer): subfield_schema = self.get_response_object(field.__class__, None)[0].get('schema') # If the schema exists, use it as the nested_obj if subfield_schema is not None: nested_obj[field.field_name] = subfield_schema nested_obj[field.field_name]['description'] = field.help_text continue # Otherwise, carry-on and use the field's schema. fallback_schema = self.fallback_schema_from_field(field) fields.append(Field( name=field.field_name, location='form', required=field.required, schema=fallback_schema if fallback_schema else field_to_schema(field), )) res = _get_parameters(Link(fields=fields), None) if not res: if nested_obj: return { 'description': description, 'schema': { 'type': 'object', 'properties': nested_obj } }, {} else: return {}, {} schema = res[0]['schema'] schema['properties'].update(nested_obj) response_schema = { 'description': description, 'schema': schema } error_status_codes = {} response_meta = getattr(response_serializer_class, 'Meta', None) for status_code, description in getattr(response_meta, 'error_status_codes', {}).items(): error_status_codes[status_code] = {'description': description} return response_schema, error_status_codes class OpenApiDocument(Document): """OpenAPI-compliant document provides: - Versioning information """ def __init__(self, version, url=None, title=None, description=None, media_type=None, content=None): super(OpenApiDocument, self).__init__( url=url, title=title, description=description, media_type=media_type, content=content ) self._version = version @property def version(self): return self._version class OpenApiLink(Link): """OpenAPI-compliant Link provides: - Schema to the response """ def __init__(self, response_schema, error_status_codes, url=None, action=None, encoding=None, transform=None, title=None, description=None, fields=None): super(OpenApiLink, self).__init__( url=url, action=action, encoding=encoding, transform=transform, title=title, description=description, fields=fields ) self._response_schema = response_schema self._error_status_codes = error_status_codes @property def response_schema(self): return self._response_schema @property def error_status_codes(self): return self._error_status_codes