Support sorting post search results by pool post order

This commit is contained in:
Ruin0x11 2021-05-08 02:38:40 -07:00
parent ca77149597
commit 81645864ec
5 changed files with 81 additions and 15 deletions

View file

@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple, Callable, Union
import sqlalchemy as sa import sqlalchemy as sa
@ -114,17 +114,30 @@ def _pool_filter(
query: SaQuery, criterion: Optional[criteria.BaseCriterion], negated: bool query: SaQuery, criterion: Optional[criteria.BaseCriterion], negated: bool
) -> SaQuery: ) -> SaQuery:
assert criterion assert criterion
return search_util.create_subquery_filter( from szurubooru.search.configs import util as search_util
model.Post.post_id, subquery = db.session.query(model.PoolPost.post_id.label("foreign_id"))
model.PoolPost.post_id, subquery = subquery.options(sa.orm.lazyload("*"))
model.PoolPost.pool_id, subquery = search_util.create_num_filter(model.PoolPost.pool_id)(subquery, criterion, False)
search_util.create_num_filter, subquery = subquery.subquery("t")
)(query, criterion, negated) expression = model.Post.post_id.in_(subquery)
if negated:
expression = ~expression
return query.filter(expression)
def _pool_sort(
query: SaQuery, pool_id: Optional[int]
) -> SaQuery:
if pool_id is None:
return query
return query.join(model.PoolPost, sa.and_(model.PoolPost.post_id == model.Post.post_id, model.PoolPost.pool_id == pool_id)) \
.order_by(model.PoolPost.order.desc())
class PostSearchConfig(BaseSearchConfig): class PostSearchConfig(BaseSearchConfig):
def __init__(self) -> None: def __init__(self) -> None:
self.user = None # type: Optional[model.User] self.user = None # type: Optional[model.User]
self.pool_id = None # type: Optional[int]
def on_search_query_parsed(self, search_query: SearchQuery) -> SaQuery: def on_search_query_parsed(self, search_query: SearchQuery) -> SaQuery:
new_special_tokens = [] new_special_tokens = []
@ -149,6 +162,10 @@ class PostSearchConfig(BaseSearchConfig):
else: else:
new_special_tokens.append(token) new_special_tokens.append(token)
search_query.special_tokens = new_special_tokens search_query.special_tokens = new_special_tokens
self.pool_id = None
for token in search_query.named_tokens:
if token.name == "pool" and isinstance(token.criterion, criteria.PlainCriterion):
self.pool_id = token.criterion.value
def create_around_query(self) -> SaQuery: def create_around_query(self) -> SaQuery:
return db.session.query(model.Post).options(sa.orm.lazyload("*")) return db.session.query(model.Post).options(sa.orm.lazyload("*"))
@ -353,7 +370,7 @@ class PostSearchConfig(BaseSearchConfig):
) )
@property @property
def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]: def sort_columns(self) -> Dict[str, Union[Tuple[SaColumn, str], Callable[[SaQuery], None]]]:
return util.unalias_dict( return util.unalias_dict(
[ [
( (
@ -415,6 +432,10 @@ class PostSearchConfig(BaseSearchConfig):
["feature-date", "feature-time"], ["feature-date", "feature-time"],
(model.Post.last_feature_time, self.SORT_DESC), (model.Post.last_feature_time, self.SORT_DESC),
), ),
(
["pool"],
lambda subquery: _pool_sort(subquery, self.pool_id)
)
] ]
) )

View file

@ -205,6 +205,7 @@ def create_subquery_filter(
filter_column: SaColumn, filter_column: SaColumn,
filter_factory: SaColumn, filter_factory: SaColumn,
subquery_decorator: Callable[[SaQuery], None] = None, subquery_decorator: Callable[[SaQuery], None] = None,
order: SaQuery = None,
) -> Filter: ) -> Filter:
filter_func = filter_factory(filter_column) filter_func = filter_factory(filter_column)

View file

@ -181,9 +181,13 @@ class Executor:
_format_dict_keys(self.config.sort_columns), _format_dict_keys(self.config.sort_columns),
) )
) )
column, default_order = self.config.sort_columns[ entry = self.config.sort_columns[
sort_token.name sort_token.name
] ]
if callable(entry):
db_query = entry(db_query)
else:
column, default_order = entry
order = _get_order(sort_token.order, default_order) order = _get_order(sort_token.order, default_order)
if order == sort_token.SORT_ASC: if order == sort_token.SORT_ASC:
db_query = db_query.order_by(column.asc()) db_query = db_query.order_by(column.asc())

View file

@ -1,4 +1,4 @@
from typing import Any, Callable from typing import Any, Callable, Union
SaColumn = Any SaColumn = Any
SaQuery = Any SaQuery = Any

View file

@ -725,6 +725,7 @@ def test_filter_by_feature_date(
"sort:fav-time", "sort:fav-time",
"sort:feature-date", "sort:feature-date",
"sort:feature-time", "sort:feature-time",
"sort:pool",
], ],
) )
def test_sort_tokens(verify_unpaged, post_factory, input): def test_sort_tokens(verify_unpaged, post_factory, input):
@ -863,3 +864,42 @@ def test_tumbleweed(
db.session.flush() db.session.flush()
verify_unpaged("special:tumbleweed", [4]) verify_unpaged("special:tumbleweed", [4])
verify_unpaged("-special:tumbleweed", [1, 2, 3]) verify_unpaged("-special:tumbleweed", [1, 2, 3])
def test_sort_pool(
post_factory, pool_factory, pool_category_factory, verify_unpaged
):
post1 = post_factory(id=1)
post2 = post_factory(id=2)
post3 = post_factory(id=3)
post4 = post_factory(id=4)
pool1 = pool_factory(
id=1,
names=["pool1"],
description="desc",
category=pool_category_factory("test-cat1"),
)
pool1.posts = [post1, post4, post3]
pool2 = pool_factory(
id=2,
names=["pool2"],
description="desc",
category=pool_category_factory("test-cat2"),
)
pool2.posts = [post3, post4, post2]
db.session.add_all(
[
post1,
post2,
post3,
post4,
pool1,
pool2
]
)
db.session.flush()
verify_unpaged("pool:1 sort:pool", [1, 4, 3])
verify_unpaged("pool:2 sort:pool", [3, 4, 2])
verify_unpaged("pool:1 pool:2 sort:pool", [4, 3])
verify_unpaged("pool:2 pool:1 sort:pool", [3, 4])
verify_unpaged("sort:pool", [1, 2, 3, 4])