import enum from logbook import Logger from sqlalchemy import Column, Integer, DateTime, String, Boolean, func, LargeBinary, Enum, ForeignKey, Index, Float from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker, relationship log = Logger('database') Base = declarative_base() class PostType(enum.Enum): IMAGE = 1 ANIMATED = 2 VIDEO = 3 @staticmethod def fromImage(image): ending = image[image.rfind('.') + 1:] if ending in ['jpg', 'jpeg', 'png']: return PostType.IMAGE elif ending in ['gif']: return PostType.ANIMATED elif ending in ['mp4']: return PostType.VIDEO class PostTypeTable(Base): __tablename__ = 'posttype' name = Column(Enum(PostType), primary_key=True, index=True) class PostStatus(enum.Enum): NOT_INDEXED = 1 INDEXED = 2 BROKEN = 3 class Flag(enum.Enum): SFW = 1 NSFW = 2 NSFL = 3 NSFP = 4 class Post(Base): __tablename__ = 'post' id = Column(Integer, primary_key=True, index=True) created = Column(DateTime(), nullable=False) image = Column(String(256), nullable=False, index=True) thumb = Column(String(256), nullable=False) fullsize = Column(String(256)) width = Column(Integer(), nullable=False) height = Column(Integer(), nullable=False) audio = Column(Boolean(), nullable=False) source = Column(String(512)) flags = Column(Integer(), nullable=False) user = Column(String(32), nullable=False) type = Column(Enum(PostType), nullable=False, index=True) status = Column(Enum(PostStatus), nullable=False, index=True, default=PostStatus.NOT_INDEXED) __table_args__ = (Index('post_status_type_index', "status", "type"),) def __json__(self): return { 'id': self.id, 'user': self.user, 'created': self.created.isoformat(), 'is_sfw': self.is_sfw(), 'is_nsfw': self.is_nsfw(), 'is_nsfl': self.is_nsfl(), 'image': self.image, 'thumb': self.thumb, } def is_sfw(self): return self.flags & 1 != 0 def is_nsfw(self): return self.flags & 2 != 0 def is_nsfl(self): return self.flags & 4 != 0 def is_nsfp(self): return self.flags & 8 != 0 def get_flags(self): flags = [] if self.is_sfw(): flags.append(Flag.SFW) if self.is_nsfw(): flags.append(Flag.NSFW) if self.is_nsfl(): flags.append(Flag.NSFL) if self.is_nsfp(): flags.append(Flag.NSFP) return flags def get_flag_by_importance(self): if self.is_nsfl(): return Flag.NSFL if self.is_nsfw(): return Flag.NSFW if self.is_nsfp(): return Flag.NSFP return Flag.SFW def __str__(self): return "Post(id=" + str(self.id) + ")" def __repr__(self): return "Post(id=" + str(self.id) + ")" class FeatureType(enum.Enum): FEATURE_VECTOR = 1 AHASH = 2 PHASH = 3 DHASH = 4 WHASH = 5 class FeatureTypeTable(Base): __tablename__ = 'featuretype' name = Column(Enum(FeatureType), primary_key=True, index=True) class Feature(Base): __tablename__ = 'feature' post_id = Column(Integer, ForeignKey('post.id'), primary_key=True, index=True) post = relationship(Post) type = Column(Enum(FeatureType), primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True) data = Column(LargeBinary) def __str__(self): return "Feature(post=%s, type=%s, data=%s)" % (self.post, self.type, self.data) def __repr__(self): return "Feature(id=" + str(self.id) + ")" @staticmethod def from_analyzeresult(post, type, data): feature = Feature() feature.post_id = post.id feature.type = type feature.id = 1 feature.data = data return feature class Tag(Base): __tablename__ = 'tag' id = Column(Integer, primary_key=True, index=True) post_id = Column(Integer, nullable=False, index=True) tag = Column(String(256), nullable=False, index=True) up = Column(Integer, nullable=False) down = Column(Integer, nullable=False) confidence = Column(Float, nullable=False, index=True) def __str__(self): return "Tag(id=%d, post_id=%d, tag=%s, up=%d, down=%d, confidence=%f)" % ( self.id, self.post_id, self.tag, self.up, self.down, self.confidence) def __repr__(self): return self.__str__() class Database(): def __init__(self, engine): self.engine = engine log.info("connecting to database {}", engine) Base.metadata.create_all(self.engine) Base.metadata.bind = self.engine self.DBSession = sessionmaker(bind=self.engine, expire_on_commit=False) log.info("connected to database {}", engine) def latest_post_id(self): session = self.DBSession() res = session.query(func.max(Post.id).label('latest_post_id')).scalar() session.close() return res def latest_tag_id(self): session = self.DBSession() res = session.query(func.max(Tag.id).label('latest_tag_id')).scalar() session.close() return res def get_engine(self): return self.engine def get_posts(self, type=None): session = self.DBSession() res = None if type is not None: res = session.query(Post).filter(Post.type == type) else: res = session.query(Post) session.close() return res def get_posts_missing_features(self): session = self.DBSession() res = session.query(Post).filter((Post.type == PostType.IMAGE) & (Post.status == PostStatus.NOT_INDEXED)) session.close() return res def post_count(self): session = self.DBSession() res = session.query(func.count(Post.id)).filter(Post.type == PostType.IMAGE).scalar() session.close() return res def get_post_by_id(self, id): session = self.DBSession() res = session.query(Post).filter_by(id=id).scalar() session.close() return res def close(self): log.debug("closing database connection {}", self.engine) self.engine.dispose() log.debug("closed database connection {}", self.engine)