from math import ceil
from typing import Any
from typing import Callable
from typing import Generic
from typing import List
from typing import Optional
from typing import Type
from typing import TypeVar

from fastapi.encoders import jsonable_encoder
from pydantic import parse_obj_as
from sqlalchemy import desc
from sqlalchemy import inspect
from sqlalchemy.orm import RelationshipProperty
from sqlalchemy.orm import Session

from app.db.base import BaseModel
from app.schemas.base import BaseSchema

ModelType = TypeVar("ModelType", bound=BaseModel)
# Create/update payloads are schema objects (they are passed through
# jsonable_encoder / ``.dict()`` below), not database models, so these type
# variables are bound to the schema base rather than the model base.
CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseSchema)
UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseSchema)


class Pagination:
    """One page of query results, as returned by :meth:`CrudBase.paginate`.

    It can also be constructed by hand from any SQLAlchemy query object.
    ``None`` may be passed as the query, in which case :meth:`prev` and
    :meth:`next` can no longer be used.
    """

    def __init__(self, query, page, per_page, total, items):
        #: the unlimited query object that was used to create this
        #: pagination object.
        self.query = query
        #: the current page number (1 indexed); ``None``, 0 and negative
        #: values are normalised to 1.
        self.page = max(int(page or 1), 1)
        #: the number of items to be displayed on a page; ``None`` (e.g. a
        #: ``query_all`` result) and negative values are stored as 0.
        self.per_page = max(int(per_page or 0), 0)
        #: the total number of items matching the query, or ``None`` when
        #: counting was skipped.
        self.total = total
        #: the items for the current page
        self.items = items

    @property
    def pages(self):
        """The total number of pages (0 if ``total`` is unknown or ``per_page`` is 0)."""
        if self.per_page == 0 or self.total is None:
            return 0
        return ceil(self.total / self.per_page)

    def _sibling(self, page, error_out):
        """Fetch page number ``page`` of the same query with the same page size."""
        assert self.query is not None, ('a query object is required '
                                        'for this method to work')
        query_paginate = getattr(self.query, 'paginate', None)
        if query_paginate is not None:
            # Query classes that paginate themselves keep the original call.
            return query_paginate(page, self.per_page, error_out)
        # Plain SQLAlchemy queries (such as those built by CrudBase.paginate)
        # have no ``paginate`` method, so run LIMIT/OFFSET directly.
        # ``error_out`` has no meaning on this path and is ignored.
        return _fetch_page(
            self.query, page, self.per_page, count=self.total is not None
        )

    def prev(self, error_out=False):
        """Returns a :class:`Pagination` object for the previous page."""
        return self._sibling(self.page - 1, error_out)

    @property
    def prev_num(self):
        """Number of the previous page, or ``None`` on the first page."""
        if not self.has_prev:
            return None
        return self.page - 1

    @property
    def has_prev(self):
        """True if a previous page exists."""
        return self.page > 1

    def next(self, error_out=False):
        """Returns a :class:`Pagination` object for the next page."""
        return self._sibling(self.page + 1, error_out)

    @property
    def has_next(self):
        """True if a next page exists (always False when ``total`` is unknown)."""
        return self.page < self.pages

    @property
    def next_num(self):
        """Number of the next page, or ``None`` on the last page."""
        if not self.has_next:
            return None
        return self.page + 1

    def to_dict(self):
        """Return the page metadata and items as a plain dict."""
        return {
            'page': self.page,
            'prev_num': self.prev_num,
            'has_prev': self.has_prev,
            'has_next': self.has_next,
            'next_num': self.next_num,
            'total': self.total,
            'items': self.items,
        }


def _fetch_page(query, page, per_page, *, count=True, query_all=False):
    """Run ``query`` for one page and wrap the result in a :class:`Pagination`.

    ``page`` below 1 (or ``None``) is treated as 1 so the OFFSET is never
    negative. With ``query_all`` every row is loaded and ``per_page`` is only
    recorded. ``count=False`` skips the COUNT query and leaves ``total`` as
    ``None``.
    """
    page = max(int(page or 1), 1)
    if query_all:
        items = query.all()
    else:
        per_page = max(int(per_page or 0), 0)
        items = query.limit(per_page).offset((page - 1) * per_page).all()
    # order_by(None) strips the ORDER BY clause, which a COUNT does not need.
    total = query.order_by(None).count() if count else None
    return Pagination(query, page, per_page, total, items)


class CrudBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType]):
    #: page size used by :meth:`paginate` when ``per_page`` is not given.
    default_per_page = 20

    def __init__(
        self,
        model: Type[ModelType],
        create_schema: Optional[Type[CreateSchemaType]] = None,
        update_schema: Optional[Type[UpdateSchemaType]] = None,
    ):
        """
        CRUD object with default methods to Create, Read, Update, Delete (CRUD).

        **Parameters**

        * `model`: A SQLAlchemy model class
        * `create_schema`: optional schema class for create payloads
          (stored as ``self.createSchema``)
        * `update_schema`: optional schema class for update payloads
          (stored as ``self.updateSchema``)
        """
        self.model = model
        self.createSchema = create_schema
        self.updateSchema = update_schema

    def filter_by(self, db_session: Session, **filter_kwargs):
        """Return a query over ``self.model`` filtered by column equality kwargs."""
        return db_session.query(self.model).filter_by(**filter_kwargs)

    def exists(self, db_session: Session, **filter_kwargs) -> bool:
        """Return whether any row matches ``filter_kwargs`` (one EXISTS query)."""
        filter_query = self.filter_by(db_session=db_session, **filter_kwargs)
        return db_session.query(filter_query.exists()).scalar()

    def get(self, db_session: Session, id: int) -> Optional[ModelType]:
        """Return the first row whose ``id`` equals ``id``, or ``None``."""
        return db_session.query(self.model).filter(self.model.id == id).first()

    def get_multi(self, db_session: Session, *, skip=0, limit=100) -> List[ModelType]:
        """Return up to ``limit`` rows after skipping ``skip`` rows.

        No ORDER BY is applied, so row order is whatever the database returns.
        """
        return db_session.query(self.model).offset(skip).limit(limit).all()

    def create(
        self,
        db_session: Session,
        *,
        obj_in: CreateSchemaType,
        serializer: Callable = jsonable_encoder
    ) -> ModelType:
        """Insert a new row built from ``obj_in``, commit, and return it refreshed.

        ``obj_in`` is converted with ``serializer`` (``jsonable_encoder`` by
        default). If ``serializer`` is falsy, ``obj_in.dict()`` is used when
        available, otherwise ``obj_in`` is used as-is (e.g. a plain dict).
        """
        if serializer:
            obj_in_data = serializer(obj_in)
        else:
            obj_in_data = obj_in.dict() if hasattr(obj_in, 'dict') else obj_in
        db_obj = self.model(**obj_in_data)
        db_session.add(db_obj)
        db_session.commit()
        db_session.refresh(db_obj)
        return db_obj

    def update(
        self,
        db_session: Session,
        *,
        obj_id: int,
        obj_in: UpdateSchemaType
    ) -> Optional[ModelType]:
        """Apply ``obj_in`` to the row with id ``obj_id``, commit, and return it.

        For schema objects only explicitly-set fields are applied, and
        ``id``/``created_on``/``updated_on`` are never copied. A plain mapping
        is applied as-is (no fields excluded). Keys that are not attributes of
        the model are ignored.

        Returns ``None`` (without committing) when no row has id ``obj_id``.
        """
        db_obj = self.get(db_session, obj_id)
        if db_obj is None:
            # Previously this fell through to refresh(None) and raised.
            return None
        if hasattr(obj_in, 'dict'):
            # ``skip_defaults`` was a deprecated alias of ``exclude_unset``;
            # passing only ``exclude_unset`` gives the same result without
            # the deprecation warning.
            update_data = obj_in.dict(
                exclude_unset=True,
                exclude={'updated_on', 'created_on', 'id'}
            )
        else:
            # NOTE(review): mapping input bypasses the protected-field
            # exclusion above; confirm callers rely on that before changing.
            update_data = obj_in
        for field, value in update_data.items():
            if hasattr(db_obj, field):
                setattr(db_obj, field, value)
        db_session.commit()
        db_session.refresh(db_obj)
        return db_obj

    def remove(self, db_session: Session, *, id: int) -> bool:
        """Delete rows whose ``id`` equals ``id`` and commit.

        Uses a bulk ``Query.delete()`` (no instance is loaded). Returns
        ``True`` if at least one row was deleted.
        """
        deleted_count = db_session.query(self.model).filter_by(id=id).delete()
        db_session.commit()
        return deleted_count > 0

    def paginate(
        self,
        db_session: Session,
        query=None,
        page=None,
        per_page=None,
        count=True,
        query_all=False,
        **kwargs
    ) -> Pagination:
        """Return one page of results as a :class:`Pagination`.

        :param query: query to paginate; when ``None`` a query over
            ``self.model`` filtered by ``kwargs`` and ordered by ``id``
            descending is built. ``kwargs`` are ignored if ``query`` is given.
        :param page: 1-based page number; ``None`` or values below 1 mean 1.
        :param per_page: page size; ``None`` means ``self.default_per_page``.
        :param count: when False, skip the COUNT query (``total`` is ``None``).
        :param query_all: when True, load every row instead of one page.
        """
        if query is None:
            query = (
                db_session.query(self.model)
                .filter_by(**kwargs)
                .order_by(desc(self.model.id))
            )
        if per_page is None and not query_all:
            per_page = self.default_per_page
        return _fetch_page(
            query, page, per_page, count=count, query_all=query_all
        )

    def check_relation_data_exists(
        self,
        db_session: Session,
        id: int,
        relation_key_list: Optional[List[str]] = None
    ) -> bool:
        """Return whether the row with id ``id`` has data in any relationship.

        :param relation_key_list: relationship attribute names to check; when
            empty or ``None`` all relationships of ``self.model`` are checked.
        :returns: False if the row does not exist.
        """
        if not relation_key_list:
            relation_key_list = [
                model_relation.key
                for model_relation in inspect(self.model).mapper.relationships
            ]
        obj = self.get(db_session, id)
        if not obj:
            return False
        # Generator, so any() stops at the first populated relationship and
        # later relationships are never accessed (each access may lazy-load).
        return any(
            getattr(obj, model_relation_key, False)
            for model_relation_key in relation_key_list
        )

    @staticmethod
    def serialize_list_obj(
        serialize_schema: Type[BaseSchema],
        obj_list: List[ModelType]
    ) -> List[Any]:
        """Validate ``obj_list`` as ``List[serialize_schema]`` and JSON-encode each item."""
        return [
            jsonable_encoder(item)
            for item in parse_obj_as(List[serialize_schema], obj_list)
        ]