diff --git a/server/szurubooru/func/tags.py b/server/szurubooru/func/tags.py index e4eb2a12..a3944dd3 100644 --- a/server/szurubooru/func/tags.py +++ b/server/szurubooru/func/tags.py @@ -157,10 +157,19 @@ def delete(source_tag): db.session.delete(source_tag) def merge_tags(source_tag, target_tag): - db.session.execute( - sqlalchemy.sql.expression.update(db.PostTag) \ - .where(db.PostTag.tag_id == source_tag.tag_id) \ - .values(tag_id=target_tag.tag_id)) + pt1 = sqlalchemy.orm.util.aliased(db.PostTag) + pt2 = sqlalchemy.orm.util.aliased(db.PostTag) + ids_to_be_tagged = sqlalchemy.sql.expression \ + .select([pt1.tag_id]) \ + .where(pt1.tag_id == source_tag.tag_id) \ + .where(~sqlalchemy.exists() \ + .where(pt2.post_id == pt1.post_id) \ + .where(pt2.tag_id == target_tag.tag_id)) + + update_stmt = sqlalchemy.sql.expression.update(db.PostTag) \ + .where(db.PostTag.tag_id.in_(ids_to_be_tagged)) \ + .values(tag_id=target_tag.tag_id) + db.session.execute(update_stmt) delete(source_tag) def create_tag(names, category_name, suggestions, implications): diff --git a/server/szurubooru/tests/api/test_tag_merging.py b/server/szurubooru/tests/api/test_tag_merging.py index e2138826..7b03f110 100644 --- a/server/szurubooru/tests/api/test_tag_merging.py +++ b/server/szurubooru/tests/api/test_tag_merging.py @@ -98,6 +98,28 @@ def test_merging_when_related(test_ctx, fake_datetime): assert tags.try_get_tag_by_name('parent').implications == [] assert tags.try_get_tag_by_name('parent').suggestions == [] +def test_merging_when_target_exists(test_ctx, fake_datetime, post_factory): + source_tag = test_ctx.tag_factory(names=['source'], category_name='meta') + target_tag = test_ctx.tag_factory(names=['target'], category_name='meta') + db.session.add_all([source_tag, target_tag]) + db.session.flush() + post1 = post_factory() + post1.tags = [source_tag, target_tag] + db.session.add_all([post1]) + db.session.commit() + assert source_tag.post_count == 1 + assert target_tag.post_count == 1 + with fake_datetime('1997-12-01'): + result = test_ctx.api.post( + test_ctx.context_factory( + input={ + 'remove': 'source', + 'mergeTo': 'target', + }, + user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) + assert tags.try_get_tag_by_name('source') is None + assert tags.get_tag_by_name('target').post_count == 1 + @pytest.mark.parametrize('input,expected_exception', [ ({'remove': None}, tags.TagNotFoundError), ({'remove': ''}, tags.TagNotFoundError),