server/posts: add sketch of post table
This commit is contained in:
parent
9ac70dbed4
commit
bc15fb6675
9 changed files with 226 additions and 14 deletions
|
@ -1,3 +1,4 @@
|
|||
from szurubooru.db.base import Base
|
||||
from szurubooru.db.user import User
|
||||
from szurubooru.db.tag import Tag, TagName, TagSuggestion, TagImplication
|
||||
from szurubooru.db.post import Post, PostTag, PostRelation
|
||||
|
|
82
server/szurubooru/db/post.py
Normal file
82
server/szurubooru/db/post.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
from sqlalchemy import Column, Integer, DateTime, String, ForeignKey
|
||||
from sqlalchemy.orm import relationship, column_property
|
||||
from sqlalchemy.sql.expression import func, select
|
||||
from szurubooru.db.base import Base
|
||||
|
||||
class PostRelation(Base):
|
||||
__tablename__ = 'post_relation'
|
||||
|
||||
parent_id = Column('parent_id', Integer, ForeignKey('post.id'), primary_key=True)
|
||||
child_id = Column('child_id', Integer, ForeignKey('post.id'), primary_key=True)
|
||||
|
||||
def __init__(self, parent_id, child_id):
|
||||
self.parent_id = parent_id
|
||||
self.child_id = child_id
|
||||
|
||||
class PostTag(Base):
|
||||
__tablename__ = 'post_tag'
|
||||
|
||||
post_id = Column('post_id', Integer, ForeignKey('post.id'), primary_key=True)
|
||||
tag_id = Column('tag_id', Integer, ForeignKey('tag.id'), primary_key=True)
|
||||
|
||||
def __init__(self, tag_id, post_id):
|
||||
self.tag_id = tag_id
|
||||
self.post_id = post_id
|
||||
|
||||
class Post(Base):
|
||||
__tablename__ = 'post'
|
||||
|
||||
SAFETY_SAFE = 'safe'
|
||||
SAFETY_SKETCHY = 'sketchy'
|
||||
SAFETY_UNSAFE = 'unsafe'
|
||||
TYPE_IMAGE = 'anim'
|
||||
TYPE_ANIMATION = 'anim'
|
||||
TYPE_FLASH = 'flash'
|
||||
TYPE_VIDEO = 'video'
|
||||
TYPE_YOUTUBE = 'youtube'
|
||||
FLAG_LOOP_VIDEO = 1
|
||||
|
||||
post_id = Column('id', Integer, primary_key=True)
|
||||
user_id = Column('user_id', Integer, ForeignKey('user.id'))
|
||||
creation_time = Column('creation_time', DateTime, nullable=False)
|
||||
last_edit_time = Column('last_edit_time', DateTime)
|
||||
safety = Column('safety', String(32), nullable=False)
|
||||
type = Column('type', String(32), nullable=False)
|
||||
checksum = Column('checksum', String(64), nullable=False)
|
||||
source = Column('source', String(200))
|
||||
file_size = Column('file_size', Integer)
|
||||
image_width = Column('image_width', Integer)
|
||||
image_height = Column('image_height', Integer)
|
||||
flags = Column('flags', Integer, nullable=False, default=0)
|
||||
|
||||
user = relationship('User')
|
||||
tags = relationship('Tag', backref='posts', secondary='post_tag')
|
||||
relations = relationship(
|
||||
'Post',
|
||||
secondary='post_relation',
|
||||
primaryjoin=post_id == PostRelation.parent_id,
|
||||
secondaryjoin=post_id == PostRelation.child_id)
|
||||
|
||||
tag_count = column_property(
|
||||
select(
|
||||
[func.count('1')],
|
||||
PostTag.post_id == post_id
|
||||
) \
|
||||
.correlate('Post') \
|
||||
.label('tag_count')
|
||||
)
|
||||
|
||||
# TODO: wire these
|
||||
fav_count = Column('auto_fav_count', Integer, nullable=False, default=0)
|
||||
score = Column('auto_score', Integer, nullable=False, default=0)
|
||||
feature_count = Column('auto_feature_count', Integer, nullable=False, default=0)
|
||||
comment_count = Column('auto_comment_count', Integer, nullable=False, default=0)
|
||||
note_count = Column('auto_note_count', Integer, nullable=False, default=0)
|
||||
last_fav_time = Column(
|
||||
'auto_fav_time', Integer, nullable=False, default=0)
|
||||
last_feature_time = Column(
|
||||
'auto_feature_time', Integer, nullable=False, default=0)
|
||||
last_comment_edit_time = Column(
|
||||
'auto_comment_creation_time', Integer, nullable=False, default=0)
|
||||
last_comment_creation_time = Column(
|
||||
'auto_comment_edit_time', Integer, nullable=False, default=0)
|
|
@ -1,6 +1,8 @@
|
|||
from sqlalchemy import Column, Integer, DateTime, String, ForeignKey
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm import relationship, column_property
|
||||
from sqlalchemy.sql.expression import func, select
|
||||
from szurubooru.db.base import Base
|
||||
from szurubooru.db.post import PostTag
|
||||
|
||||
class TagSuggestion(Base):
|
||||
__tablename__ = 'tag_suggestion'
|
||||
|
@ -52,5 +54,11 @@ class Tag(Base):
|
|||
primaryjoin=tag_id == TagImplication.parent_id,
|
||||
secondaryjoin=tag_id == TagImplication.child_id)
|
||||
|
||||
# TODO: wire this
|
||||
post_count = Column('auto_post_count', Integer, nullable=False, default=0)
|
||||
post_count = column_property(
|
||||
select(
|
||||
[func.count('Post.post_id')],
|
||||
PostTag.tag_id == tag_id
|
||||
) \
|
||||
.correlate('Tag') \
|
||||
.label('post_count')
|
||||
)
|
||||
|
|
|
@ -20,7 +20,6 @@ def upgrade():
|
|||
sa.Column('category', sa.String(length=32), nullable=False),
|
||||
sa.Column('creation_time', sa.DateTime(), nullable=False),
|
||||
sa.Column('last_edit_time', sa.DateTime(), nullable=True),
|
||||
sa.Column('auto_post_count', sa.Integer(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id'))
|
||||
|
||||
op.create_table(
|
||||
|
|
|
@ -4,11 +4,7 @@ from szurubooru import db
|
|||
from szurubooru.search.base_search_config import BaseSearchConfig
|
||||
|
||||
class TagSearchConfig(BaseSearchConfig):
|
||||
def __init__(self):
|
||||
self._session = None
|
||||
|
||||
def create_query(self, session):
|
||||
self._session = session
|
||||
return session.query(db.Tag)
|
||||
|
||||
def finalize_query(self, query):
|
||||
|
@ -65,7 +61,7 @@ class TagSearchConfig(BaseSearchConfig):
|
|||
str_filter = self._create_str_filter(db.TagName.name)
|
||||
return query.filter(
|
||||
db.Tag.tag_id.in_(
|
||||
str_filter(self._session.query(db.TagName.tag_id), criterion)))
|
||||
str_filter(query.session.query(db.TagName.tag_id), criterion)))
|
||||
|
||||
def _suggestion_count_filter(self, query, criterion):
|
||||
return query.filter(
|
||||
|
|
|
@ -48,10 +48,11 @@ def test_removing_tags_without_privileges(test_ctx):
|
|||
'tag')
|
||||
assert test_ctx.session.query(db.Tag).count() == 1
|
||||
|
||||
def test_removing_tags_with_usages(test_ctx):
|
||||
def test_removing_tags_with_usages(test_ctx, post_factory):
|
||||
tag = test_ctx.tag_factory(names=['tag'])
|
||||
tag.post_count = 5
|
||||
test_ctx.session.add(tag)
|
||||
post = post_factory()
|
||||
post.tags.append(tag)
|
||||
test_ctx.session.add_all([tag, post])
|
||||
test_ctx.session.commit()
|
||||
with pytest.raises(tags.TagIsInUseError):
|
||||
test_ctx.api.delete(
|
||||
|
|
|
@ -71,3 +71,18 @@ def tag_factory():
|
|||
tag.creation_time = datetime.datetime(1996, 1, 1)
|
||||
return tag
|
||||
return factory
|
||||
|
||||
@pytest.fixture
|
||||
def post_factory():
|
||||
def factory(
|
||||
safety=db.Post.SAFETY_SAFE,
|
||||
type=db.Post.TYPE_IMAGE,
|
||||
checksum='...'):
|
||||
post = db.Post()
|
||||
post.safety = safety
|
||||
post.type = type
|
||||
post.checksum = checksum
|
||||
post.flags = 0
|
||||
post.creation_time = datetime.datetime(1996, 1, 1)
|
||||
return post
|
||||
return factory
|
||||
|
|
91
server/szurubooru/tests/db/test_post.py
Normal file
91
server/szurubooru/tests/db/test_post.py
Normal file
|
@ -0,0 +1,91 @@
|
|||
from datetime import datetime
|
||||
from szurubooru import db
|
||||
|
||||
def test_saving_post(session, post_factory, user_factory, tag_factory):
|
||||
user = user_factory()
|
||||
tag1 = tag_factory()
|
||||
tag2 = tag_factory()
|
||||
related_post1 = post_factory()
|
||||
related_post2 = post_factory()
|
||||
post = db.Post()
|
||||
post.safety = 'safety'
|
||||
post.type = 'type'
|
||||
post.checksum = 'deadbeef'
|
||||
post.creation_time = datetime(1997, 1, 1)
|
||||
post.last_edit_time = datetime(1998, 1, 1)
|
||||
session.add_all([user, tag1, tag2, related_post1, related_post2, post])
|
||||
|
||||
post.user = user
|
||||
post.tags.append(tag1)
|
||||
post.tags.append(tag2)
|
||||
post.relations.append(related_post1)
|
||||
post.relations.append(related_post2)
|
||||
session.commit()
|
||||
|
||||
post = session.query(db.Post).filter(db.Post.post_id == post.post_id).one()
|
||||
assert not session.dirty
|
||||
assert post.user.user_id is not None
|
||||
assert post.safety == 'safety'
|
||||
assert post.type == 'type'
|
||||
assert post.checksum == 'deadbeef'
|
||||
assert post.creation_time == datetime(1997, 1, 1)
|
||||
assert post.last_edit_time == datetime(1998, 1, 1)
|
||||
assert len(post.relations) == 2
|
||||
|
||||
def test_cascade_deletions(session, post_factory, user_factory, tag_factory):
|
||||
user = user_factory()
|
||||
tag1 = tag_factory()
|
||||
tag2 = tag_factory()
|
||||
related_post1 = post_factory()
|
||||
related_post2 = post_factory()
|
||||
post = post_factory()
|
||||
session.add_all([user, tag1, tag2, post, related_post1, related_post2])
|
||||
session.flush()
|
||||
|
||||
post.user = user
|
||||
post.tags.append(tag1)
|
||||
post.tags.append(tag2)
|
||||
post.relations.append(related_post1)
|
||||
post.relations.append(related_post2)
|
||||
session.flush()
|
||||
|
||||
assert not session.dirty
|
||||
assert post.user.user_id is not None
|
||||
assert len(post.relations) == 2
|
||||
assert session.query(db.User).count() == 1
|
||||
assert session.query(db.Tag).count() == 2
|
||||
assert session.query(db.Post).count() == 3
|
||||
assert session.query(db.PostTag).count() == 2
|
||||
assert session.query(db.PostRelation).count() == 2
|
||||
|
||||
session.delete(post)
|
||||
session.commit()
|
||||
|
||||
assert not session.dirty
|
||||
assert session.query(db.User).count() == 1
|
||||
assert session.query(db.Tag).count() == 2
|
||||
assert session.query(db.Post).count() == 2
|
||||
assert session.query(db.PostTag).count() == 0
|
||||
assert session.query(db.PostRelation).count() == 0
|
||||
|
||||
def test_tracking_tag_count(session, post_factory, tag_factory):
|
||||
post = post_factory()
|
||||
tag1 = tag_factory()
|
||||
tag2 = tag_factory()
|
||||
session.add_all([tag1, tag2, post])
|
||||
session.flush()
|
||||
post.tags.append(tag1)
|
||||
post.tags.append(tag2)
|
||||
session.commit()
|
||||
assert len(post.tags) == 2
|
||||
assert post.tag_count == 2
|
||||
session.delete(tag1)
|
||||
session.commit()
|
||||
session.refresh(post)
|
||||
assert len(post.tags) == 1
|
||||
assert post.tag_count == 1
|
||||
session.delete(tag2)
|
||||
session.commit()
|
||||
session.refresh(post)
|
||||
assert len(post.tags) == 0
|
||||
assert post.tag_count == 0
|
|
@ -13,7 +13,6 @@ def test_saving_tag(session, tag_factory):
|
|||
tag.category = 'category'
|
||||
tag.creation_time = datetime(1997, 1, 1)
|
||||
tag.last_edit_time = datetime(1998, 1, 1)
|
||||
tag.post_count = 1
|
||||
session.add_all([
|
||||
tag, suggested_tag1, suggested_tag2, implied_tag1, implied_tag2])
|
||||
session.commit()
|
||||
|
@ -37,7 +36,6 @@ def test_saving_tag(session, tag_factory):
|
|||
assert tag.category == 'category'
|
||||
assert tag.creation_time == datetime(1997, 1, 1)
|
||||
assert tag.last_edit_time == datetime(1998, 1, 1)
|
||||
assert tag.post_count == 1
|
||||
assert [relation.names[0].name for relation in tag.suggestions] \
|
||||
== ['suggested1', 'suggested2']
|
||||
assert [relation.names[0].name for relation in tag.implications] \
|
||||
|
@ -77,3 +75,24 @@ def test_cascade_deletions(session, tag_factory):
|
|||
assert session.query(db.TagName).count() == 4
|
||||
assert session.query(db.TagImplication).count() == 0
|
||||
assert session.query(db.TagSuggestion).count() == 0
|
||||
|
||||
def test_tracking_post_count(session, post_factory, tag_factory):
|
||||
tag = tag_factory()
|
||||
post1 = post_factory()
|
||||
post2 = post_factory()
|
||||
session.add_all([tag, post1, post2])
|
||||
session.flush()
|
||||
post1.tags.append(tag)
|
||||
post2.tags.append(tag)
|
||||
session.commit()
|
||||
assert len(post1.tags) == 1
|
||||
assert len(post2.tags) == 1
|
||||
assert tag.post_count == 2
|
||||
session.delete(post1)
|
||||
session.commit()
|
||||
session.refresh(tag)
|
||||
assert tag.post_count == 1
|
||||
session.delete(post2)
|
||||
session.commit()
|
||||
session.refresh(tag)
|
||||
assert tag.post_count == 0
|
||||
|
|
Loading…
Reference in a new issue