diff --git a/server/szurubooru/search/configs/post_search_config.py b/server/szurubooru/search/configs/post_search_config.py index ddc003b7..8d4672d4 100644 --- a/server/szurubooru/search/configs/post_search_config.py +++ b/server/szurubooru/search/configs/post_search_config.py @@ -122,6 +122,34 @@ def _pool_filter( )(query, criterion, negated) +def _category_filter( + query: SaQuery, criterion: Optional[criteria.BaseCriterion], negated: bool +) -> SaQuery: + assert criterion + + # Step 1. find the id for the category + q1 = db.session.query(model.TagCategory.tag_category_id).filter( + model.TagCategory.name == criterion.value + ) + + # Step 2. find the tags with that category + q2 = db.session.query(model.Tag.tag_id).filter( + model.Tag.category_id.in_(q1) + ) + + # Step 3. find all posts that have at least one of those tags + q3 = db.session.query(model.PostTag.post_id).filter( + model.PostTag.tag_id.in_(q2) + ) + + # Step 4. profit + expr = model.Post.post_id.in_(q3) + if negated: + expr = ~expr + + return query.filter(expr) + + class PostSearchConfig(BaseSearchConfig): def __init__(self) -> None: self.user = None # type: Optional[model.User] @@ -349,6 +377,7 @@ class PostSearchConfig(BaseSearchConfig): ), ), (["pool"], _pool_filter), + (["category"], _category_filter), ] )