from collections import OrderedDict
from typing import Type, Union

from flask import jsonify
from werkzeug.wrappers import Response
from werkzeug.exceptions import BadRequest, InternalServerError
from marshmallow import Schema, EXCLUDE, RAISE
from marshmallow.fields import List
from marshmallow.exceptions import ValidationError

from flask_restx.model import Model
from flask_restx import fields, reqparse, inputs
from flask_accepts.utils import for_swagger, get_default_model_name, is_list_field, ma_field_to_reqparse_argument


def _merge_schema_error(error, message: str, schema_error):
    """Attach Marshmallow validation messages to the error that will be raised.

    Args:
        error: The error collected so far, or None if nothing has failed yet.
        message (str): Description used for the ``BadRequest`` that is created
            when ``error`` is None.
        schema_error: The ``ValidationError.messages`` payload to record.

    Returns:
        The error (existing or newly created) carrying the schema messages
        under the ``schema_errors`` key of its ``data`` attribute.
    """
    if error is None:
        error = BadRequest(message)

    data = getattr(error, "data", None)
    if not isinstance(data, dict):
        # Fresh error without a payload yet: start one.
        error.data = {"schema_errors": schema_error}
        return error

    # Errors that already carry an "errors" mapping keep schema messages nested
    # inside it; otherwise they live at the top level of ``data``.
    # NOTE(review): flask-restx's parser abort is expected to attach an
    # "errors" mapping to ``data`` -- confirm against the installed version.
    nested = data.get("errors")
    target = nested if isinstance(nested, dict) else data

    existing = target.get("schema_errors")
    if isinstance(existing, dict) and isinstance(schema_error, dict):
        # Combine messages from several schemas (body / query / headers).
        # A field name present in more than one schema keeps the latest messages.
        existing.update(schema_error)
    else:
        target["schema_errors"] = schema_error
    return error


def accepts(
    *args,
    model_name: str = None,
    schema: Union[Schema, Type[Schema], None] = None,
    query_params_schema: Union[Schema, Type[Schema], None] = None,
    headers_schema: Union[Schema, Type[Schema], None] = None,
    many: bool = False,
    api=None,
    use_swagger: bool = True,
    partial: bool = False
):
    """
    Wrap a Flask route with input validation using a combination of reqparse from
    Flask-restx and/or Marshmallow schemas

    Args:
        *args: any number of dictionaries containing parameters to pass to
            reqparse.RequestParser().add_argument(). A single string parameter may
            also be provided that is used as the model name. By default these
            parameters will be parsed using the default logic however, if a schema
            is provided then the JSON body is assumed to correspond to it and will
            not be parsed for query params.
        model_name (str): the name to pass to api.Model, can optionally be provided
            as a str argument to *args
        schema (Marshmallow.Schema, optional): A Marshmallow Schema that will be used
            to parse JSON data from the request body and store in request.parsed_obj.
            Defaults to None.
        query_params_schema (Marshmallow.Schema, optional): A Marshmallow Schema that
            will be used to parse data from the request query params and store in
            request.parsed_query_params. Its fields are also registered on the
            reqparse parser with location `values`. Defaults to None.
        headers_schema (Marshmallow.Schema, optional): A Marshmallow Schema that will
            be used to parse data from the request header and store in
            request.parsed_headers. Its fields are also registered on the reqparse
            parser with location `headers`. Defaults to None.
        many (bool, optional): (DEPRECATED) The Marshmallow schema `many` parameter,
            which will return a list of the corresponding schema objects when set to
            True. This flag corresponds only to the request body schema, and not the
            `query_params_schema` or `headers_schema` arguments.
        api (optional): When provided, its parser is used (so Swagger is aware of
            the arguments) and the route is documented on it.
        use_swagger (bool): Whether to attach Swagger documentation to the route.
            Only applies when `api` is given and the wrapped function is a method.
        partial (bool): The partial argument for marshmallow schema loading.

    Returns:
        A decorator that wraps the route with the validation described above.

    Raises:
        At request time the wrapped route raises the reqparse error, or a
        BadRequest, carrying the collected validation messages.
    """
    _check_deprecate_many(many)

    # If an api was passed in, we need to use its parser so Swagger is aware
    if api:
        _parser = api.parser()
    else:
        _parser = reqparse.RequestParser(bundle_errors=True)

    query_params = [arg for arg in args if isinstance(arg, dict)]

    for arg in args:  # check for positional string-arg, which is the model name
        if isinstance(arg, str):
            model_name = arg
            break

    # Handles query params passed in as positional arguments.
    for qp in query_params:
        params = {**qp, "location": qp.get("location") or "values"}
        # "type" is optional for reqparse arguments, so do not assume it is set.
        if qp.get("type") is bool:
            # mapping native bool is necessary so that string "false" is not truthy
            # https://flask-restx.readthedocs.io/en/stable/parsing.html#advanced-types-handling
            params["type"] = inputs.boolean
        _parser.add_argument(**params)

    # Handles request body schema.
    if schema:
        schema = _get_or_create_schema(schema, many=many)

    # Handles query params schema.
    if query_params_schema:
        query_params_schema = _get_or_create_schema(query_params_schema, unknown=EXCLUDE)

        for name, field in query_params_schema.fields.items():
            params = {**ma_field_to_reqparse_argument(field), "location": "values"}
            _parser.add_argument(name, **params)

    # Handles headers schema.
    if headers_schema:
        headers_schema = _get_or_create_schema(headers_schema, unknown=EXCLUDE)

        for name, field in headers_schema.fields.items():
            params = {**ma_field_to_reqparse_argument(field), "location": "headers"}
            _parser.add_argument(name, **params)

    def decorator(func):
        from functools import wraps

        # Check if we are decorating a class method
        _IS_METHOD = _is_method(func)

        @wraps(func)
        def inner(*args, **kwargs):
            from flask import request

            error = None

            # Handle arguments. The failure is deliberately held back (not raised
            # immediately) so schema errors can be reported together with it.
            try:
                request.parsed_args = _parser.parse_args()
            except Exception as e:
                error = e

            # Each section below records only its own validation failure; a
            # failure in one section must not leak into the next one.

            # Handle Marshmallow schema for request body
            if schema:
                try:
                    request.parsed_obj = schema.load(
                        request.get_json(), partial=partial
                    )
                except ValidationError as ex:
                    error = _merge_schema_error(
                        error,
                        f"Error parsing request body: {ex.messages}",
                        ex.messages,
                    )

            # Handle Marshmallow schema for query params
            if query_params_schema:
                request_args = _convert_multidict_values_to_schema(
                    request.args, query_params_schema
                )
                try:
                    request.parsed_query_params = query_params_schema.load(
                        request_args
                    )
                except ValidationError as ex:
                    error = _merge_schema_error(
                        error,
                        f"Error parsing query params: {ex.messages}",
                        ex.messages,
                    )

            # Handle Marshmallow schema for headers
            if headers_schema:
                request_headers = _convert_multidict_values_to_schema(
                    request.headers, headers_schema
                )
                try:
                    request.parsed_headers = headers_schema.load(request_headers)
                except ValidationError as ex:
                    error = _merge_schema_error(
                        error,
                        f"Error parsing headers: {ex.messages}",
                        ex.messages,
                    )

            # If any parsing produced an error, combine them and re-raise
            if error:
                raise error

            return func(*args, **kwargs)

        # Add Swagger
        if api and use_swagger and _IS_METHOD:
            if schema:
                body = for_swagger(
                    schema=schema,
                    model_name=model_name or get_default_model_name(schema),
                    api=api,
                    operation="load",
                )
                inner = api.doc(expect=[body, _parser])(inner)
            elif _parser:
                inner = api.expect(_parser)(inner)

        return inner

    return decorator


def responds(
    *args,
    model_name: str = None,
    schema=None,
    many: bool = False,
    api=None,
    envelope=None,
    status_code: int = 200,
    validate: bool = False,
    description: str = None,
    use_swagger: bool = True,
):
    """
    Serialize the output of a function using the Marshmallow schema to dump the
    results. Either a schema type or a schema instance may be given; a type is
    instantiated internally. If the outputted value is already of type
    flask.Response, it will be passed along without further modification.

    Args:
        *args: any number of dictionaries containing parameters to pass to
            reqparse.RequestParser().add_argument(); used to build the response
            model when no schema is given. A single string parameter may also be
            provided that is used as the model name.
        model_name (str): name of the documented model, can optionally be provided
            as a str argument to *args. Defaults to a name derived from the schema
            when a schema is given.
        schema (Marshmallow.Schema, optional): Marshmallow schema with which to
            serialize the output of the wrapped function.
        many (bool, optional): (DEPRECATED) The Marshmallow schema `many` parameter,
            which will return a list of the corresponding schema objects when set to
            True.
        api (optional): When provided, its parser is used and the response is
            documented on it.
        envelope (optional): If set, the serialized output is wrapped in a mapping
            under this key.
        status_code (int): Status code returned alongside the serialized output and
            used in the documentation.
        validate (bool): If True, the serialized output is validated against the
            schema before being returned.
        description (str): Description of the response used in the documentation.
        use_swagger (bool): Whether to attach Swagger documentation to the route.
            Only applies when `api` is given and the wrapped function is a method.

    Returns:
        A decorator; the wrapped function returns the serialized output together
        with `status_code`.

    Raises:
        InternalServerError: at request time, when `validate` is True and the
            serialized output does not pass schema validation.
    """
    from functools import wraps

    from flask_restx import reqparse

    _check_deprecate_many(many)

    # If an api was passed in, we need to use its parser so Swagger is aware
    if api:
        _parser = api.parser()
    else:
        _parser = reqparse.RequestParser(bundle_errors=True)

    query_params = [arg for arg in args if isinstance(arg, dict)]

    for arg in args:  # check for positional string-arg, which is the model name
        if isinstance(arg, str):
            model_name = arg
            break

    for qp in query_params:
        # Honour an explicit "location" (as `accepts` does) instead of passing the
        # keyword twice, which would raise a TypeError.
        _parser.add_argument(**{**qp, "location": qp.get("location") or "values"})

    ordered = None
    if schema:
        schema = _get_or_create_schema(schema, many=many)
        ordered = schema.ordered
        model_name = model_name or get_default_model_name(schema)

    model_from_parser = _model_from_parser(model_name=model_name, parser=_parser)

    def decorator(func):
        # Check if we are decorating a class method
        _IS_METHOD = _is_method(func)

        @wraps(func)
        def inner(*args, **kwargs):
            rv = func(*args, **kwargs)

            # If a Flask response has been made already, it is passed through unchanged
            if isinstance(rv, Response):
                return rv

            if schema:
                serialized = schema.dump(rv)

                # Validate data if asked to (throws)
                if validate:
                    errs = schema.validate(serialized)
                    if errs:
                        raise InternalServerError(
                            description="Server attempted to return invalid data"
                        )

                # Apply the flask-restx mask after validation
                serialized = _apply_restx_mask(serialized)
            else:
                from flask_restx import marshal

                serialized = marshal(rv, model_from_parser)

            if envelope:
                # Keep an ordered mapping when the schema itself is ordered.
                if ordered:
                    serialized = OrderedDict([(envelope, serialized)])
                else:
                    serialized = {envelope: serialized}

            if not _IS_METHOD:
                # Regular route, need to manually create Response
                return jsonify(serialized), status_code
            return serialized, status_code

        # Add Swagger
        if api and use_swagger and _IS_METHOD:
            if schema:
                api_model = for_swagger(
                    schema=schema, model_name=model_name, api=api, operation="dump"
                )
                if schema.many is True:
                    api_model = [api_model]

                inner = _document_like_marshal_with(
                    api_model,
                    status_code=status_code,
                    description=description,
                )(inner)
            elif _parser:
                api.add_model(model_name, model_from_parser)
                inner = _document_like_marshal_with(
                    model_from_parser, status_code=status_code
                )(inner)

        return inner

    return decorator


def _apply_restx_mask(serialized):
    """Apply the flask-restx field mask from the request headers, if one was sent.

    The header name is read from the app config key ``RESTX_MASK_HEADER`` and
    falls back to ``X-Fields``. Without that header the data is returned as is.
    """
    from flask import current_app, request
    from flask_restx.mask import apply as apply_mask

    mask_header = current_app.config.get("RESTX_MASK_HEADER", "X-Fields")
    mask = request.headers.get(mask_header)
    return apply_mask(serialized, mask) if mask else serialized


def _check_deprecate_many(many: bool = False):
    """Emit a DeprecationWarning when the deprecated `many` flag is used."""
    if many:
        import warnings

        # Make sure the deprecation is always shown rather than filtered out.
        warnings.simplefilter("always", DeprecationWarning)
        warnings.warn(
            "The 'many' parameter is deprecated in favor of passing these "
            "arguments to an actual instance of Marshmallow schema (i.e. "
            "prefer @responds(schema=MySchema(many=True)) instead of "
            "@responds(schema=MySchema, many=True))",
            DeprecationWarning,
            stacklevel=3,
        )


def _get_or_create_schema(
    schema: Union[Schema, Type[Schema]], many: bool = False, unknown: str = RAISE
) -> Schema:
    """Return `schema` if it is already an instance, otherwise instantiate it.

    `many` and `unknown` are only used when a schema type has to be instantiated;
    an existing instance is returned untouched.
    """
    if isinstance(schema, Schema):
        return schema
    return schema(many=many, unknown=unknown)


def _model_from_parser(model_name: str, parser: reqparse.RequestParser) -> Model:
    """Build a flask-restx Model from the arguments registered on `parser`.

    Each entry of ``parser.__schema__`` is mapped to a flask-restx field by its
    ``type``; ``array`` entries become a ``fields.List`` of their item type.

    Raises:
        KeyError: if an argument has a type that has no mapping here.
    """
    from flask_restx import fields

    base_type_map = {
        "integer": fields.Integer,
        "string": fields.String,
        "number": fields.Float,
        # bool arguments would otherwise fail with a KeyError.
        "boolean": fields.Boolean,
    }

    def to_field(arg):
        # Arrays wrap the field of their item type; everything else maps directly.
        if arg["type"] == "array":
            return fields.List(base_type_map[arg["items"]["type"]])
        return base_type_map[arg["type"]]

    return Model(
        model_name,
        {arg["name"]: to_field(arg) for arg in parser.__schema__},
    )


def merge(first: dict, second: dict) -> dict:
    """Return a new dict with the items of both; `second` wins on shared keys."""
    return {**first, **second}


def _document_like_marshal_with(
    values, status_code: int = 200, description: str = None
):
    """Return a decorator that records a response entry in ``func.__apidoc__``.

    Args:
        values: The model (or list containing a model) documented as the response.
        status_code (int): Status code the response is documented under.
        description (str): Response description; defaults to "Success".
    """
    description = description or "Success"

    def inner(func):
        doc = {"responses": {status_code: (description, values)}, "__mask__": True}
        # Merge into any documentation already attached to the function.
        func.__apidoc__ = merge(getattr(func, "__apidoc__", {}), doc)
        return func

    return inner


def _is_method(func):
    """
    Check if function is defined inside a class.
    ASSUMES YOU ARE USING THE CONVENTION THAT FIRST ARG IS 'self'
    """
    import inspect

    sig = inspect.signature(func)
    return "self" in sig.parameters


def _convert_multidict_values_to_schema(multidict, schema):
    """Helper function that converts values in the given multidict into either
    single or list values based on the schema definition.

    This function is necessary for parsing multidict mappings like querystrings
    where it's ambiguous whether the value is single or list value. Take the
    following query string as an example:

        ?foo=bar

    In this case, `foo` could map to a single string `'bar'`, or a list with one
    string element `['bar']`. This function looks at the given `schema` and
    converts the values in the given `multidict` appropriately to be parsed and
    loaded by `marshmallow` later on.
    """
    schema_fields = schema.fields
    result = {}

    for key, value in multidict.items():
        if key in schema_fields and is_list_field(schema_fields[key]):
            # If the corresponding field is a list, then make sure to return the
            # value as a list.
            result[key] = multidict.getlist(key)
        else:
            # Single-valued fields, and keys that are not defined in the schema,
            # are inserted as is; marshmallow decides how to treat unknown keys.
            result[key] = value

    return result