Support sorting post search results by pool post order
This commit is contained in:
parent
ca77149597
commit
81645864ec
5 changed files with 81 additions and 15 deletions
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple, Callable, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
@ -114,17 +114,30 @@ def _pool_filter(
|
|||
query: SaQuery, criterion: Optional[criteria.BaseCriterion], negated: bool
|
||||
) -> SaQuery:
|
||||
assert criterion
|
||||
return search_util.create_subquery_filter(
|
||||
model.Post.post_id,
|
||||
model.PoolPost.post_id,
|
||||
model.PoolPost.pool_id,
|
||||
search_util.create_num_filter,
|
||||
)(query, criterion, negated)
|
||||
from szurubooru.search.configs import util as search_util
|
||||
subquery = db.session.query(model.PoolPost.post_id.label("foreign_id"))
|
||||
subquery = subquery.options(sa.orm.lazyload("*"))
|
||||
subquery = search_util.create_num_filter(model.PoolPost.pool_id)(subquery, criterion, False)
|
||||
subquery = subquery.subquery("t")
|
||||
expression = model.Post.post_id.in_(subquery)
|
||||
if negated:
|
||||
expression = ~expression
|
||||
return query.filter(expression)
|
||||
|
||||
|
||||
def _pool_sort(
|
||||
query: SaQuery, pool_id: Optional[int]
|
||||
) -> SaQuery:
|
||||
if pool_id is None:
|
||||
return query
|
||||
return query.join(model.PoolPost, sa.and_(model.PoolPost.post_id == model.Post.post_id, model.PoolPost.pool_id == pool_id)) \
|
||||
.order_by(model.PoolPost.order.desc())
|
||||
|
||||
|
||||
class PostSearchConfig(BaseSearchConfig):
|
||||
def __init__(self) -> None:
|
||||
self.user = None # type: Optional[model.User]
|
||||
self.pool_id = None # type: Optional[int]
|
||||
|
||||
def on_search_query_parsed(self, search_query: SearchQuery) -> SaQuery:
|
||||
new_special_tokens = []
|
||||
|
@ -149,6 +162,10 @@ class PostSearchConfig(BaseSearchConfig):
|
|||
else:
|
||||
new_special_tokens.append(token)
|
||||
search_query.special_tokens = new_special_tokens
|
||||
self.pool_id = None
|
||||
for token in search_query.named_tokens:
|
||||
if token.name == "pool" and isinstance(token.criterion, criteria.PlainCriterion):
|
||||
self.pool_id = token.criterion.value
|
||||
|
||||
def create_around_query(self) -> SaQuery:
|
||||
return db.session.query(model.Post).options(sa.orm.lazyload("*"))
|
||||
|
@ -353,7 +370,7 @@ class PostSearchConfig(BaseSearchConfig):
|
|||
)
|
||||
|
||||
@property
|
||||
def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
|
||||
def sort_columns(self) -> Dict[str, Union[Tuple[SaColumn, str], Callable[[SaQuery], None]]]:
|
||||
return util.unalias_dict(
|
||||
[
|
||||
(
|
||||
|
@ -415,6 +432,10 @@ class PostSearchConfig(BaseSearchConfig):
|
|||
["feature-date", "feature-time"],
|
||||
(model.Post.last_feature_time, self.SORT_DESC),
|
||||
),
|
||||
(
|
||||
["pool"],
|
||||
lambda subquery: _pool_sort(subquery, self.pool_id)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
|
|
@ -205,6 +205,7 @@ def create_subquery_filter(
|
|||
filter_column: SaColumn,
|
||||
filter_factory: SaColumn,
|
||||
subquery_decorator: Callable[[SaQuery], None] = None,
|
||||
order: SaQuery = None,
|
||||
) -> Filter:
|
||||
filter_func = filter_factory(filter_column)
|
||||
|
||||
|
|
|
@ -181,14 +181,18 @@ class Executor:
|
|||
_format_dict_keys(self.config.sort_columns),
|
||||
)
|
||||
)
|
||||
column, default_order = self.config.sort_columns[
|
||||
entry = self.config.sort_columns[
|
||||
sort_token.name
|
||||
]
|
||||
order = _get_order(sort_token.order, default_order)
|
||||
if order == sort_token.SORT_ASC:
|
||||
db_query = db_query.order_by(column.asc())
|
||||
elif order == sort_token.SORT_DESC:
|
||||
db_query = db_query.order_by(column.desc())
|
||||
if callable(entry):
|
||||
db_query = entry(db_query)
|
||||
else:
|
||||
column, default_order = entry
|
||||
order = _get_order(sort_token.order, default_order)
|
||||
if order == sort_token.SORT_ASC:
|
||||
db_query = db_query.order_by(column.asc())
|
||||
elif order == sort_token.SORT_DESC:
|
||||
db_query = db_query.order_by(column.desc())
|
||||
|
||||
db_query = self.config.finalize_query(db_query)
|
||||
return db_query
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Callable
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
SaColumn = Any
|
||||
SaQuery = Any
|
||||
|
|
|
@ -725,6 +725,7 @@ def test_filter_by_feature_date(
|
|||
"sort:fav-time",
|
||||
"sort:feature-date",
|
||||
"sort:feature-time",
|
||||
"sort:pool",
|
||||
],
|
||||
)
|
||||
def test_sort_tokens(verify_unpaged, post_factory, input):
|
||||
|
@ -863,3 +864,42 @@ def test_tumbleweed(
|
|||
db.session.flush()
|
||||
verify_unpaged("special:tumbleweed", [4])
|
||||
verify_unpaged("-special:tumbleweed", [1, 2, 3])
|
||||
|
||||
|
||||
def test_sort_pool(
|
||||
post_factory, pool_factory, pool_category_factory, verify_unpaged
|
||||
):
|
||||
post1 = post_factory(id=1)
|
||||
post2 = post_factory(id=2)
|
||||
post3 = post_factory(id=3)
|
||||
post4 = post_factory(id=4)
|
||||
pool1 = pool_factory(
|
||||
id=1,
|
||||
names=["pool1"],
|
||||
description="desc",
|
||||
category=pool_category_factory("test-cat1"),
|
||||
)
|
||||
pool1.posts = [post1, post4, post3]
|
||||
pool2 = pool_factory(
|
||||
id=2,
|
||||
names=["pool2"],
|
||||
description="desc",
|
||||
category=pool_category_factory("test-cat2"),
|
||||
)
|
||||
pool2.posts = [post3, post4, post2]
|
||||
db.session.add_all(
|
||||
[
|
||||
post1,
|
||||
post2,
|
||||
post3,
|
||||
post4,
|
||||
pool1,
|
||||
pool2
|
||||
]
|
||||
)
|
||||
db.session.flush()
|
||||
verify_unpaged("pool:1 sort:pool", [1, 4, 3])
|
||||
verify_unpaged("pool:2 sort:pool", [3, 4, 2])
|
||||
verify_unpaged("pool:1 pool:2 sort:pool", [4, 3])
|
||||
verify_unpaged("pool:2 pool:1 sort:pool", [3, 4])
|
||||
verify_unpaged("sort:pool", [1, 2, 3, 4])
|
||||
|
|
Reference in a new issue