diff --git a/server/szurubooru/db/post.py b/server/szurubooru/db/post.py index 4a12899e..73314b28 100644 --- a/server/szurubooru/db/post.py +++ b/server/szurubooru/db/post.py @@ -112,11 +112,16 @@ class Post(Base): # foreign tables user = relationship('User') tags = relationship('Tag', backref='posts', secondary='post_tag') - relations = relationship( + relating_to = relationship( 'Post', secondary='post_relation', primaryjoin=post_id == PostRelation.parent_id, secondaryjoin=post_id == PostRelation.child_id, lazy='joined') + related_by = relationship( + 'Post', + secondary='post_relation', + primaryjoin=post_id == PostRelation.child_id, + secondaryjoin=post_id == PostRelation.parent_id, lazy='joined') features = relationship( 'PostFeature', cascade='all, delete-orphan', lazy='joined') scores = relationship( @@ -190,5 +195,7 @@ class Post(Base): relation_count = column_property( select([func.count(PostRelation.child_id)]) \ - .where(PostRelation.parent_id == post_id) \ + .where( + (PostRelation.parent_id == post_id) \ + | (PostRelation.child_id == post_id)) \ .correlate_except(PostRelation)) diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 1f1429aa..5d12cb4b 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -264,7 +264,7 @@ def update_post_relations(post, post_ids): .all() if len(relations) != len(post_ids): raise InvalidPostRelationError('One of relations does not exist.') - post.relations = relations + post.relating_to = relations def update_post_notes(post, notes): post.notes = [] diff --git a/server/szurubooru/func/snapshots.py b/server/szurubooru/func/snapshots.py index 832ac164..49dea66c 100644 --- a/server/szurubooru/func/snapshots.py +++ b/server/szurubooru/func/snapshots.py @@ -15,7 +15,8 @@ def get_post_snapshot(post): 'safety': post.safety, 'checksum': post.checksum, 'tags': sorted([tag.first_name for tag in post.tags]), - 'relations': sorted([rel.post_id for rel in post.relations]), + 'relations': sorted([ + rel.post_id for rel in post.relating_to + post.related_by]), 'notes': sorted([{ 'polygon': note.polygon, 'text': note.text, diff --git a/server/szurubooru/search/configs/post_search_config.py b/server/szurubooru/search/configs/post_search_config.py index 4fca489f..b83002c3 100644 --- a/server/szurubooru/search/configs/post_search_config.py +++ b/server/szurubooru/search/configs/post_search_config.py @@ -97,7 +97,8 @@ class PostSearchConfig(BaseSearchConfig): defer(db.Post.tag_count), subqueryload(db.Post.tags).subqueryload(db.Tag.names), lazyload(db.Post.user), - lazyload(db.Post.relations), + lazyload(db.Post.relating_to), + lazyload(db.Post.related_by), lazyload(db.Post.notes), lazyload(db.Post.favorited_by), ) diff --git a/server/szurubooru/tests/db/test_post.py b/server/szurubooru/tests/db/test_post.py index 041f58a4..3a57e909 100644 --- a/server/szurubooru/tests/db/test_post.py +++ b/server/szurubooru/tests/db/test_post.py @@ -19,11 +19,13 @@ def test_saving_post(post_factory, user_factory, tag_factory): post.user = user post.tags.append(tag1) post.tags.append(tag2) - post.relations.append(related_post1) - post.relations.append(related_post2) + post.relating_to.append(related_post1) + post.relating_to.append(related_post2) db.session.commit() db.session.refresh(post) + db.session.refresh(related_post1) + db.session.refresh(related_post2) assert not db.session.dirty assert post.user.user_id is not None assert post.safety == 'safety' @@ -31,7 +33,12 @@ def test_saving_post(post_factory, user_factory, tag_factory): assert post.checksum == 'deadbeef' assert post.creation_time == datetime(1997, 1, 1) assert post.last_edit_time == datetime(1998, 1, 1) - assert len(post.relations) == 2 + assert len(post.relating_to) == 2 + assert len(related_post1.relating_to) == 0 + assert len(related_post1.relating_to) == 0 + assert len(post.related_by) == 0 + assert len(related_post1.related_by) == 1 + assert len(related_post1.related_by) == 1 def test_cascade_deletions(post_factory, user_factory, tag_factory): user = user_factory() @@ -66,8 +73,8 @@ def test_cascade_deletions(post_factory, user_factory, tag_factory): post.user = user post.tags.append(tag1) post.tags.append(tag2) - post.relations.append(related_post1) - post.relations.append(related_post2) + post.relating_to.append(related_post1) + post.relating_to.append(related_post2) post.scores.append(score) post.favorited_by.append(favorite) post.features.append(feature) @@ -76,7 +83,7 @@ def test_cascade_deletions(post_factory, user_factory, tag_factory): assert not db.session.dirty assert post.user is not None and post.user.user_id is not None - assert len(post.relations) == 2 + assert len(post.relating_to) == 2 assert db.session.query(db.User).count() == 1 assert db.session.query(db.Tag).count() == 2 assert db.session.query(db.Post).count() == 3 diff --git a/server/szurubooru/tests/func/test_posts.py b/server/szurubooru/tests/func/test_posts.py index 594fcec1..899ee9e4 100644 --- a/server/szurubooru/tests/func/test_posts.py +++ b/server/szurubooru/tests/func/test_posts.py @@ -395,9 +395,9 @@ def test_update_post_relations(post_factory): db.session.flush() post = db.Post() posts.update_post_relations(post, [relation1.post_id, relation2.post_id]) - assert len(post.relations) == 2 - assert post.relations[0].post_id == relation1.post_id - assert post.relations[1].post_id == relation2.post_id + assert len(post.relating_to) == 2 + assert post.relating_to[0].post_id == relation1.post_id + assert post.relating_to[1].post_id == relation2.post_id def test_update_post_non_existing_relations(): post = db.Post() diff --git a/server/szurubooru/tests/func/test_snapshots.py b/server/szurubooru/tests/func/test_snapshots.py index c82b82f0..abb182db 100644 --- a/server/szurubooru/tests/func/test_snapshots.py +++ b/server/szurubooru/tests/func/test_snapshots.py @@ -38,8 +38,8 @@ def test_serializing_post(post_factory, user_factory, tag_factory): post.source = 'example.com' post.tags.append(tag1) post.tags.append(tag2) - post.relations.append(related_post1) - post.relations.append(related_post2) + post.relating_to.append(related_post1) + post.relating_to.append(related_post2) post.scores.append(score) post.favorited_by.append(favorite) post.features.append(feature)