diff --git a/server/szurubooru/search/configs/post_search_config.py b/server/szurubooru/search/configs/post_search_config.py index ddc003b7..8d4672d4 100644 --- a/server/szurubooru/search/configs/post_search_config.py +++ b/server/szurubooru/search/configs/post_search_config.py @@ -122,6 +122,34 @@ def _pool_filter( )(query, criterion, negated) +def _category_filter( + query: SaQuery, criterion: Optional[criteria.BaseCriterion], negated: bool +) -> SaQuery: + assert criterion + + # Step 1. find the id for the category + q1 = db.session.query(model.TagCategory.tag_category_id).filter( + model.TagCategory.name == criterion.value + ) + + # Step 2. find the tags with that category + q2 = db.session.query(model.Tag.tag_id).filter( + model.Tag.category_id.in_(q1) + ) + + # Step 3. find all posts that have at least one of those tags + q3 = db.session.query(model.PostTag.post_id).filter( + model.PostTag.tag_id.in_(q2) + ) + + # Step 4. profit + expr = model.Post.post_id.in_(q3) + if negated: + expr = ~expr + + return query.filter(expr) + + class PostSearchConfig(BaseSearchConfig): def __init__(self) -> None: self.user = None # type: Optional[model.User] @@ -349,6 +377,7 @@ class PostSearchConfig(BaseSearchConfig): ), ), (["pool"], _pool_filter), + (["category"], _category_filter), ] ) diff --git a/server/szurubooru/tests/search/configs/test_post_search_config.py b/server/szurubooru/tests/search/configs/test_post_search_config.py index 4fb8191a..b86fa273 100644 --- a/server/szurubooru/tests/search/configs/test_post_search_config.py +++ b/server/szurubooru/tests/search/configs/test_post_search_config.py @@ -863,3 +863,55 @@ def test_tumbleweed( db.session.flush() verify_unpaged("special:tumbleweed", [4]) verify_unpaged("-special:tumbleweed", [1, 2, 3]) + + +@pytest.mark.parametrize( + "input,expected_post_ids", + [ + ("category:cat1", [1, 2, 3]), + ("category:cat2", [3, 4]), + ], +) +def test_search_by_tag_category( + verify_unpaged, + post_factory, + tag_factory, + tag_category_factory, + input, + expected_post_ids, +): + cat1 = tag_category_factory(name="cat1") + cat2 = tag_category_factory(name="cat2") + tag1 = tag_factory(names=["t1"], category=cat1) + tag2 = tag_factory(names=["t2"], category=cat1) + tag3 = tag_factory(names=["t3"], category=cat2) + + post1 = post_factory(id=1) + post1.tags.append(tag1) + + post2 = post_factory(id=2) + post2.tags.append(tag2) + + post3 = post_factory(id=3) + post3.tags.append(tag1) + post3.tags.append(tag3) + + post4 = post_factory(id=4) + post4.tags.append(tag3) + + post5 = post_factory(id=5) + + db.session.add_all( + [ + tag1, + tag2, + tag3, + post1, + post2, + post3, + post4, + post5, + ] + ) + db.session.flush() + verify_unpaged(input, expected_post_ids)