diff --git a/server/szurubooru/db/__init__.py b/server/szurubooru/db/__init__.py index 0c65170e..20073dfd 100644 --- a/server/szurubooru/db/__init__.py +++ b/server/szurubooru/db/__init__.py @@ -1,3 +1,4 @@ from szurubooru.db.base import Base from szurubooru.db.user import User from szurubooru.db.tag import Tag, TagName, TagSuggestion, TagImplication +from szurubooru.db.post import Post, PostTag, PostRelation diff --git a/server/szurubooru/db/post.py b/server/szurubooru/db/post.py new file mode 100644 index 00000000..da443381 --- /dev/null +++ b/server/szurubooru/db/post.py @@ -0,0 +1,82 @@ +from sqlalchemy import Column, Integer, DateTime, String, ForeignKey +from sqlalchemy.orm import relationship, column_property +from sqlalchemy.sql.expression import func, select +from szurubooru.db.base import Base + +class PostRelation(Base): + __tablename__ = 'post_relation' + + parent_id = Column('parent_id', Integer, ForeignKey('post.id'), primary_key=True) + child_id = Column('child_id', Integer, ForeignKey('post.id'), primary_key=True) + + def __init__(self, parent_id, child_id): + self.parent_id = parent_id + self.child_id = child_id + +class PostTag(Base): + __tablename__ = 'post_tag' + + post_id = Column('post_id', Integer, ForeignKey('post.id'), primary_key=True) + tag_id = Column('tag_id', Integer, ForeignKey('tag.id'), primary_key=True) + + def __init__(self, tag_id, post_id): + self.tag_id = tag_id + self.post_id = post_id + +class Post(Base): + __tablename__ = 'post' + + SAFETY_SAFE = 'safe' + SAFETY_SKETCHY = 'sketchy' + SAFETY_UNSAFE = 'unsafe' + TYPE_IMAGE = 'anim' + TYPE_ANIMATION = 'anim' + TYPE_FLASH = 'flash' + TYPE_VIDEO = 'video' + TYPE_YOUTUBE = 'youtube' + FLAG_LOOP_VIDEO = 1 + + post_id = Column('id', Integer, primary_key=True) + user_id = Column('user_id', Integer, ForeignKey('user.id')) + creation_time = Column('creation_time', DateTime, nullable=False) + last_edit_time = Column('last_edit_time', DateTime) + safety = Column('safety', String(32), nullable=False) + type = Column('type', String(32), nullable=False) + checksum = Column('checksum', String(64), nullable=False) + source = Column('source', String(200)) + file_size = Column('file_size', Integer) + image_width = Column('image_width', Integer) + image_height = Column('image_height', Integer) + flags = Column('flags', Integer, nullable=False, default=0) + + user = relationship('User') + tags = relationship('Tag', backref='posts', secondary='post_tag') + relations = relationship( + 'Post', + secondary='post_relation', + primaryjoin=post_id == PostRelation.parent_id, + secondaryjoin=post_id == PostRelation.child_id) + + tag_count = column_property( + select( + [func.count('1')], + PostTag.post_id == post_id + ) \ + .correlate('Post') \ + .label('tag_count') + ) + + # TODO: wire these + fav_count = Column('auto_fav_count', Integer, nullable=False, default=0) + score = Column('auto_score', Integer, nullable=False, default=0) + feature_count = Column('auto_feature_count', Integer, nullable=False, default=0) + comment_count = Column('auto_comment_count', Integer, nullable=False, default=0) + note_count = Column('auto_note_count', Integer, nullable=False, default=0) + last_fav_time = Column( + 'auto_fav_time', Integer, nullable=False, default=0) + last_feature_time = Column( + 'auto_feature_time', Integer, nullable=False, default=0) + last_comment_edit_time = Column( + 'auto_comment_creation_time', Integer, nullable=False, default=0) + last_comment_creation_time = Column( + 'auto_comment_edit_time', Integer, nullable=False, default=0) diff --git a/server/szurubooru/db/tag.py b/server/szurubooru/db/tag.py index dbdfed90..3254ba41 100644 --- a/server/szurubooru/db/tag.py +++ b/server/szurubooru/db/tag.py @@ -1,6 +1,8 @@ from sqlalchemy import Column, Integer, DateTime, String, ForeignKey -from sqlalchemy.orm import relationship +from sqlalchemy.orm import relationship, column_property +from sqlalchemy.sql.expression import func, select from szurubooru.db.base import Base +from szurubooru.db.post import PostTag class TagSuggestion(Base): __tablename__ = 'tag_suggestion' @@ -52,5 +54,11 @@ class Tag(Base): primaryjoin=tag_id == TagImplication.parent_id, secondaryjoin=tag_id == TagImplication.child_id) - # TODO: wire this - post_count = Column('auto_post_count', Integer, nullable=False, default=0) + post_count = column_property( + select( + [func.count('Post.post_id')], + PostTag.tag_id == tag_id + ) \ + .correlate('Tag') \ + .label('post_count') + ) diff --git a/server/szurubooru/migrations/versions/00cb3a2734db_create_tags_tables.py b/server/szurubooru/migrations/versions/00cb3a2734db_create_tags_tables.py index fe85ade7..ba20230c 100644 --- a/server/szurubooru/migrations/versions/00cb3a2734db_create_tags_tables.py +++ b/server/szurubooru/migrations/versions/00cb3a2734db_create_tags_tables.py @@ -20,7 +20,6 @@ def upgrade(): sa.Column('category', sa.String(length=32), nullable=False), sa.Column('creation_time', sa.DateTime(), nullable=False), sa.Column('last_edit_time', sa.DateTime(), nullable=True), - sa.Column('auto_post_count', sa.Integer(), nullable=False), sa.PrimaryKeyConstraint('id')) op.create_table( diff --git a/server/szurubooru/search/tag_search_config.py b/server/szurubooru/search/tag_search_config.py index 080538fa..434bf146 100644 --- a/server/szurubooru/search/tag_search_config.py +++ b/server/szurubooru/search/tag_search_config.py @@ -4,11 +4,7 @@ from szurubooru import db from szurubooru.search.base_search_config import BaseSearchConfig class TagSearchConfig(BaseSearchConfig): - def __init__(self): - self._session = None - def create_query(self, session): - self._session = session return session.query(db.Tag) def finalize_query(self, query): @@ -65,7 +61,7 @@ class TagSearchConfig(BaseSearchConfig): str_filter = self._create_str_filter(db.TagName.name) return query.filter( db.Tag.tag_id.in_( - str_filter(self._session.query(db.TagName.tag_id), criterion))) + str_filter(query.session.query(db.TagName.tag_id), criterion))) def _suggestion_count_filter(self, query, criterion): return query.filter( diff --git a/server/szurubooru/tests/api/test_tag_deleting.py b/server/szurubooru/tests/api/test_tag_deleting.py index 852cce4d..a895bec5 100644 --- a/server/szurubooru/tests/api/test_tag_deleting.py +++ b/server/szurubooru/tests/api/test_tag_deleting.py @@ -48,10 +48,11 @@ def test_removing_tags_without_privileges(test_ctx): 'tag') assert test_ctx.session.query(db.Tag).count() == 1 -def test_removing_tags_with_usages(test_ctx): +def test_removing_tags_with_usages(test_ctx, post_factory): tag = test_ctx.tag_factory(names=['tag']) - tag.post_count = 5 - test_ctx.session.add(tag) + post = post_factory() + post.tags.append(tag) + test_ctx.session.add_all([tag, post]) test_ctx.session.commit() with pytest.raises(tags.TagIsInUseError): test_ctx.api.delete( diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index e0d6785e..239f5b69 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -71,3 +71,18 @@ def tag_factory(): tag.creation_time = datetime.datetime(1996, 1, 1) return tag return factory + +@pytest.fixture +def post_factory(): + def factory( + safety=db.Post.SAFETY_SAFE, + type=db.Post.TYPE_IMAGE, + checksum='...'): + post = db.Post() + post.safety = safety + post.type = type + post.checksum = checksum + post.flags = 0 + post.creation_time = datetime.datetime(1996, 1, 1) + return post + return factory diff --git a/server/szurubooru/tests/db/test_post.py b/server/szurubooru/tests/db/test_post.py new file mode 100644 index 00000000..1415f2b9 --- /dev/null +++ b/server/szurubooru/tests/db/test_post.py @@ -0,0 +1,91 @@ +from datetime import datetime +from szurubooru import db + +def test_saving_post(session, post_factory, user_factory, tag_factory): + user = user_factory() + tag1 = tag_factory() + tag2 = tag_factory() + related_post1 = post_factory() + related_post2 = post_factory() + post = db.Post() + post.safety = 'safety' + post.type = 'type' + post.checksum = 'deadbeef' + post.creation_time = datetime(1997, 1, 1) + post.last_edit_time = datetime(1998, 1, 1) + session.add_all([user, tag1, tag2, related_post1, related_post2, post]) + + post.user = user + post.tags.append(tag1) + post.tags.append(tag2) + post.relations.append(related_post1) + post.relations.append(related_post2) + session.commit() + + post = session.query(db.Post).filter(db.Post.post_id == post.post_id).one() + assert not session.dirty + assert post.user.user_id is not None + assert post.safety == 'safety' + assert post.type == 'type' + assert post.checksum == 'deadbeef' + assert post.creation_time == datetime(1997, 1, 1) + assert post.last_edit_time == datetime(1998, 1, 1) + assert len(post.relations) == 2 + +def test_cascade_deletions(session, post_factory, user_factory, tag_factory): + user = user_factory() + tag1 = tag_factory() + tag2 = tag_factory() + related_post1 = post_factory() + related_post2 = post_factory() + post = post_factory() + session.add_all([user, tag1, tag2, post, related_post1, related_post2]) + session.flush() + + post.user = user + post.tags.append(tag1) + post.tags.append(tag2) + post.relations.append(related_post1) + post.relations.append(related_post2) + session.flush() + + assert not session.dirty + assert post.user.user_id is not None + assert len(post.relations) == 2 + assert session.query(db.User).count() == 1 + assert session.query(db.Tag).count() == 2 + assert session.query(db.Post).count() == 3 + assert session.query(db.PostTag).count() == 2 + assert session.query(db.PostRelation).count() == 2 + + session.delete(post) + session.commit() + + assert not session.dirty + assert session.query(db.User).count() == 1 + assert session.query(db.Tag).count() == 2 + assert session.query(db.Post).count() == 2 + assert session.query(db.PostTag).count() == 0 + assert session.query(db.PostRelation).count() == 0 + +def test_tracking_tag_count(session, post_factory, tag_factory): + post = post_factory() + tag1 = tag_factory() + tag2 = tag_factory() + session.add_all([tag1, tag2, post]) + session.flush() + post.tags.append(tag1) + post.tags.append(tag2) + session.commit() + assert len(post.tags) == 2 + assert post.tag_count == 2 + session.delete(tag1) + session.commit() + session.refresh(post) + assert len(post.tags) == 1 + assert post.tag_count == 1 + session.delete(tag2) + session.commit() + session.refresh(post) + assert len(post.tags) == 0 + assert post.tag_count == 0 diff --git a/server/szurubooru/tests/db/test_tag.py b/server/szurubooru/tests/db/test_tag.py index 3c40b880..dfbafbac 100644 --- a/server/szurubooru/tests/db/test_tag.py +++ b/server/szurubooru/tests/db/test_tag.py @@ -13,7 +13,6 @@ def test_saving_tag(session, tag_factory): tag.category = 'category' tag.creation_time = datetime(1997, 1, 1) tag.last_edit_time = datetime(1998, 1, 1) - tag.post_count = 1 session.add_all([ tag, suggested_tag1, suggested_tag2, implied_tag1, implied_tag2]) session.commit() @@ -37,7 +36,6 @@ def test_saving_tag(session, tag_factory): assert tag.category == 'category' assert tag.creation_time == datetime(1997, 1, 1) assert tag.last_edit_time == datetime(1998, 1, 1) - assert tag.post_count == 1 assert [relation.names[0].name for relation in tag.suggestions] \ == ['suggested1', 'suggested2'] assert [relation.names[0].name for relation in tag.implications] \ @@ -77,3 +75,24 @@ def test_cascade_deletions(session, tag_factory): assert session.query(db.TagName).count() == 4 assert session.query(db.TagImplication).count() == 0 assert session.query(db.TagSuggestion).count() == 0 + +def test_tracking_post_count(session, post_factory, tag_factory): + tag = tag_factory() + post1 = post_factory() + post2 = post_factory() + session.add_all([tag, post1, post2]) + session.flush() + post1.tags.append(tag) + post2.tags.append(tag) + session.commit() + assert len(post1.tags) == 1 + assert len(post2.tags) == 1 + assert tag.post_count == 2 + session.delete(post1) + session.commit() + session.refresh(tag) + assert tag.post_count == 1 + session.delete(post2) + session.commit() + session.refresh(tag) + assert tag.post_count == 0