diff --git a/API.md b/API.md index 53fec511..1d48ea23 100644 --- a/API.md +++ b/API.md @@ -923,7 +923,8 @@ data. "removeVersion": , "remove": , "mergeToVersion": , - "mergeTo": + "mergeTo": , + "replaceContent": } ``` @@ -941,9 +942,11 @@ data. - **Description** Removes source post and merges all of its tags, relations, scores, - favorites and comments to the target post. Source post properties such as - its content, safety, source, whether to loop the video and other scalar - values do not get transferred and are discarded. + favorites and comments to the target post. If `replaceContent` is set to + true, content of the target post is replaced using the content of the + source post; otherwise it remains unchanged. Source post properties such as + its safety, source, whether to loop the video and other scalar values do + not get transferred and are discarded. ## Rating post - **Request** diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index 9cb1237a..142cca87 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -128,13 +128,14 @@ def delete_post(ctx, params): def merge_posts(ctx, _params=None): source_post_id = ctx.get_param_as_string('remove', required=True) or '' target_post_id = ctx.get_param_as_string('mergeTo', required=True) or '' + replace_content = ctx.get_param_as_bool('replaceContent') source_post = posts.get_post_by_id(source_post_id) target_post = posts.get_post_by_id(target_post_id) versions.verify_version(source_post, ctx, 'removeVersion') versions.verify_version(target_post, ctx, 'mergeToVersion') versions.bump_version(target_post) auth.verify_privilege(ctx.user, 'posts:merge') - posts.merge_posts(source_post, target_post) + posts.merge_posts(source_post, target_post, replace_content) snapshots.merge(source_post, target_post, ctx.user) ctx.session.commit() return _serialize_post(ctx, target_post) diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index f3725e5d..1ca70fc5 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -442,7 +442,7 @@ def delete(post): db.session.delete(post) -def merge_posts(source_post, target_post): +def merge_posts(source_post, target_post, replace_content): assert source_post assert target_post if source_post.post_id == target_post.post_id: @@ -515,3 +515,9 @@ def merge_posts(source_post, target_post): merge_relations(source_post.post_id, target_post.post_id) delete(source_post) + + db.session.flush() + + if replace_content: + content = files.get(get_post_content_path(source_post)) + update_post_content(target_post, content) diff --git a/server/szurubooru/tests/func/test_posts.py b/server/szurubooru/tests/func/test_posts.py index 033e0152..5b462941 100644 --- a/server/szurubooru/tests/func/test_posts.py +++ b/server/szurubooru/tests/func/test_posts.py @@ -612,7 +612,7 @@ def test_merge_posts_deletes_source_post(post_factory): target_post = post_factory() db.session.add_all([source_post, target_post]) db.session.flush() - posts.merge_posts(source_post, target_post) + posts.merge_posts(source_post, target_post, False) db.session.flush() assert posts.try_get_post_by_id(source_post.post_id) is None post = posts.get_post_by_id(target_post.post_id) @@ -624,7 +624,7 @@ def test_merge_posts_with_itself(post_factory): db.session.add(source_post) db.session.flush() with pytest.raises(posts.InvalidPostRelationError): - posts.merge_posts(source_post, source_post) + posts.merge_posts(source_post, source_post, False) def test_merge_posts_moves_tags(post_factory, tag_factory): @@ -636,7 +636,7 @@ def test_merge_posts_moves_tags(post_factory, tag_factory): db.session.commit() assert source_post.tag_count == 1 assert target_post.tag_count == 0 - posts.merge_posts(source_post, target_post) + posts.merge_posts(source_post, target_post, False) db.session.commit() assert posts.try_get_post_by_id(source_post.post_id) is None assert posts.get_post_by_id(target_post.post_id).tag_count == 1 @@ -651,7 +651,7 @@ def test_merge_posts_doesnt_duplicate_tags(post_factory, tag_factory): db.session.commit() assert source_post.tag_count == 1 assert target_post.tag_count == 1 - posts.merge_posts(source_post, target_post) + posts.merge_posts(source_post, target_post, False) db.session.commit() assert posts.try_get_post_by_id(source_post.post_id) is None assert posts.get_post_by_id(target_post.post_id).tag_count == 1 @@ -665,7 +665,7 @@ def test_merge_posts_moves_comments(post_factory, comment_factory): db.session.commit() assert source_post.comment_count == 1 assert target_post.comment_count == 0 - posts.merge_posts(source_post, target_post) + posts.merge_posts(source_post, target_post, False) db.session.commit() assert posts.try_get_post_by_id(source_post.post_id) is None assert posts.get_post_by_id(target_post.post_id).comment_count == 1 @@ -679,7 +679,7 @@ def test_merge_posts_moves_scores(post_factory, post_score_factory): db.session.commit() assert source_post.score == 1 assert target_post.score == 0 - posts.merge_posts(source_post, target_post) + posts.merge_posts(source_post, target_post, False) db.session.commit() assert posts.try_get_post_by_id(source_post.post_id) is None assert posts.get_post_by_id(target_post.post_id).score == 1 @@ -696,7 +696,7 @@ def test_merge_posts_doesnt_duplicate_scores( db.session.commit() assert source_post.score == 1 assert target_post.score == 1 - posts.merge_posts(source_post, target_post) + posts.merge_posts(source_post, target_post, False) db.session.commit() assert posts.try_get_post_by_id(source_post.post_id) is None assert posts.get_post_by_id(target_post.post_id).score == 1 @@ -710,7 +710,7 @@ def test_merge_posts_moves_favorites(post_factory, post_favorite_factory): db.session.commit() assert source_post.favorite_count == 1 assert target_post.favorite_count == 0 - posts.merge_posts(source_post, target_post) + posts.merge_posts(source_post, target_post, False) db.session.commit() assert posts.try_get_post_by_id(source_post.post_id) is None assert posts.get_post_by_id(target_post.post_id).favorite_count == 1 @@ -727,7 +727,7 @@ def test_merge_posts_doesnt_duplicate_favorites( db.session.commit() assert source_post.favorite_count == 1 assert target_post.favorite_count == 1 - posts.merge_posts(source_post, target_post) + posts.merge_posts(source_post, target_post, False) db.session.commit() assert posts.try_get_post_by_id(source_post.post_id) is None assert posts.get_post_by_id(target_post.post_id).favorite_count == 1 @@ -742,7 +742,7 @@ def test_merge_posts_moves_child_relations(post_factory): db.session.commit() assert source_post.relation_count == 1 assert target_post.relation_count == 0 - posts.merge_posts(source_post, target_post) + posts.merge_posts(source_post, target_post, False) db.session.commit() assert posts.try_get_post_by_id(source_post.post_id) is None assert posts.get_post_by_id(target_post.post_id).relation_count == 1 @@ -758,7 +758,7 @@ def test_merge_posts_doesnt_duplicate_child_relations(post_factory): db.session.commit() assert source_post.relation_count == 1 assert target_post.relation_count == 1 - posts.merge_posts(source_post, target_post) + posts.merge_posts(source_post, target_post, False) db.session.commit() assert posts.try_get_post_by_id(source_post.post_id) is None assert posts.get_post_by_id(target_post.post_id).relation_count == 1 @@ -774,7 +774,7 @@ def test_merge_posts_moves_parent_relations(post_factory): assert source_post.relation_count == 1 assert target_post.relation_count == 0 assert related_post.relation_count == 1 - posts.merge_posts(source_post, target_post) + posts.merge_posts(source_post, target_post, False) db.session.commit() assert posts.try_get_post_by_id(source_post.post_id) is None assert posts.get_post_by_id(target_post.post_id).relation_count == 1 @@ -791,7 +791,7 @@ def test_merge_posts_doesnt_duplicate_parent_relations(post_factory): assert source_post.relation_count == 1 assert target_post.relation_count == 1 assert related_post.relation_count == 2 - posts.merge_posts(source_post, target_post) + posts.merge_posts(source_post, target_post, False) db.session.commit() assert posts.try_get_post_by_id(source_post.post_id) is None assert posts.get_post_by_id(target_post.post_id).relation_count == 1 @@ -806,7 +806,7 @@ def test_merge_posts_doesnt_create_relation_loop_for_children(post_factory): db.session.commit() assert source_post.relation_count == 1 assert target_post.relation_count == 1 - posts.merge_posts(source_post, target_post) + posts.merge_posts(source_post, target_post, False) db.session.commit() assert posts.try_get_post_by_id(source_post.post_id) is None assert posts.get_post_by_id(target_post.post_id).relation_count == 0 @@ -820,7 +820,36 @@ def test_merge_posts_doesnt_create_relation_loop_for_parents(post_factory): db.session.commit() assert source_post.relation_count == 1 assert target_post.relation_count == 1 - posts.merge_posts(source_post, target_post) + posts.merge_posts(source_post, target_post, False) db.session.commit() assert posts.try_get_post_by_id(source_post.post_id) is None assert posts.get_post_by_id(target_post.post_id).relation_count == 0 + + +def test_merge_posts_replaces_content( + post_factory, config_injector, tmpdir, read_asset): + config_injector({ + 'data_dir': str(tmpdir.mkdir('data')), + 'data_url': 'example.com', + 'thumbnails': { + 'post_width': 300, + 'post_height': 300, + }, + }) + source_post = post_factory() + target_post = post_factory() + content = read_asset('png.png') + db.session.add_all([source_post, target_post]) + db.session.commit() + posts.update_post_content(source_post, content) + db.session.flush() + assert os.path.exists(os.path.join(str(tmpdir), 'data/posts/1.png')) + assert not os.path.exists(os.path.join(str(tmpdir), 'data/posts/2.dat')) + assert not os.path.exists(os.path.join(str(tmpdir), 'data/posts/2.png')) + posts.merge_posts(source_post, target_post, True) + db.session.flush() + assert posts.try_get_post_by_id(source_post.post_id) is None + post = posts.get_post_by_id(target_post.post_id) + assert post is not None + assert os.path.exists(os.path.join(str(tmpdir), 'data/posts/1.png')) + assert os.path.exists(os.path.join(str(tmpdir), 'data/posts/2.png'))