Support sorting post search results by pool post order
This commit is contained in:
parent
ca77149597
commit
81645864ec
5 changed files with 81 additions and 15 deletions
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple, Callable, Union
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
@ -114,17 +114,30 @@ def _pool_filter(
|
||||||
query: SaQuery, criterion: Optional[criteria.BaseCriterion], negated: bool
|
query: SaQuery, criterion: Optional[criteria.BaseCriterion], negated: bool
|
||||||
) -> SaQuery:
|
) -> SaQuery:
|
||||||
assert criterion
|
assert criterion
|
||||||
return search_util.create_subquery_filter(
|
from szurubooru.search.configs import util as search_util
|
||||||
model.Post.post_id,
|
subquery = db.session.query(model.PoolPost.post_id.label("foreign_id"))
|
||||||
model.PoolPost.post_id,
|
subquery = subquery.options(sa.orm.lazyload("*"))
|
||||||
model.PoolPost.pool_id,
|
subquery = search_util.create_num_filter(model.PoolPost.pool_id)(subquery, criterion, False)
|
||||||
search_util.create_num_filter,
|
subquery = subquery.subquery("t")
|
||||||
)(query, criterion, negated)
|
expression = model.Post.post_id.in_(subquery)
|
||||||
|
if negated:
|
||||||
|
expression = ~expression
|
||||||
|
return query.filter(expression)
|
||||||
|
|
||||||
|
|
||||||
|
def _pool_sort(
|
||||||
|
query: SaQuery, pool_id: Optional[int]
|
||||||
|
) -> SaQuery:
|
||||||
|
if pool_id is None:
|
||||||
|
return query
|
||||||
|
return query.join(model.PoolPost, sa.and_(model.PoolPost.post_id == model.Post.post_id, model.PoolPost.pool_id == pool_id)) \
|
||||||
|
.order_by(model.PoolPost.order.desc())
|
||||||
|
|
||||||
|
|
||||||
class PostSearchConfig(BaseSearchConfig):
|
class PostSearchConfig(BaseSearchConfig):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.user = None # type: Optional[model.User]
|
self.user = None # type: Optional[model.User]
|
||||||
|
self.pool_id = None # type: Optional[int]
|
||||||
|
|
||||||
def on_search_query_parsed(self, search_query: SearchQuery) -> SaQuery:
|
def on_search_query_parsed(self, search_query: SearchQuery) -> SaQuery:
|
||||||
new_special_tokens = []
|
new_special_tokens = []
|
||||||
|
@ -149,6 +162,10 @@ class PostSearchConfig(BaseSearchConfig):
|
||||||
else:
|
else:
|
||||||
new_special_tokens.append(token)
|
new_special_tokens.append(token)
|
||||||
search_query.special_tokens = new_special_tokens
|
search_query.special_tokens = new_special_tokens
|
||||||
|
self.pool_id = None
|
||||||
|
for token in search_query.named_tokens:
|
||||||
|
if token.name == "pool" and isinstance(token.criterion, criteria.PlainCriterion):
|
||||||
|
self.pool_id = token.criterion.value
|
||||||
|
|
||||||
def create_around_query(self) -> SaQuery:
|
def create_around_query(self) -> SaQuery:
|
||||||
return db.session.query(model.Post).options(sa.orm.lazyload("*"))
|
return db.session.query(model.Post).options(sa.orm.lazyload("*"))
|
||||||
|
@ -353,7 +370,7 @@ class PostSearchConfig(BaseSearchConfig):
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
|
def sort_columns(self) -> Dict[str, Union[Tuple[SaColumn, str], Callable[[SaQuery], None]]]:
|
||||||
return util.unalias_dict(
|
return util.unalias_dict(
|
||||||
[
|
[
|
||||||
(
|
(
|
||||||
|
@ -415,6 +432,10 @@ class PostSearchConfig(BaseSearchConfig):
|
||||||
["feature-date", "feature-time"],
|
["feature-date", "feature-time"],
|
||||||
(model.Post.last_feature_time, self.SORT_DESC),
|
(model.Post.last_feature_time, self.SORT_DESC),
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
["pool"],
|
||||||
|
lambda subquery: _pool_sort(subquery, self.pool_id)
|
||||||
|
)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -205,6 +205,7 @@ def create_subquery_filter(
|
||||||
filter_column: SaColumn,
|
filter_column: SaColumn,
|
||||||
filter_factory: SaColumn,
|
filter_factory: SaColumn,
|
||||||
subquery_decorator: Callable[[SaQuery], None] = None,
|
subquery_decorator: Callable[[SaQuery], None] = None,
|
||||||
|
order: SaQuery = None,
|
||||||
) -> Filter:
|
) -> Filter:
|
||||||
filter_func = filter_factory(filter_column)
|
filter_func = filter_factory(filter_column)
|
||||||
|
|
||||||
|
|
|
@ -181,9 +181,13 @@ class Executor:
|
||||||
_format_dict_keys(self.config.sort_columns),
|
_format_dict_keys(self.config.sort_columns),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
column, default_order = self.config.sort_columns[
|
entry = self.config.sort_columns[
|
||||||
sort_token.name
|
sort_token.name
|
||||||
]
|
]
|
||||||
|
if callable(entry):
|
||||||
|
db_query = entry(db_query)
|
||||||
|
else:
|
||||||
|
column, default_order = entry
|
||||||
order = _get_order(sort_token.order, default_order)
|
order = _get_order(sort_token.order, default_order)
|
||||||
if order == sort_token.SORT_ASC:
|
if order == sort_token.SORT_ASC:
|
||||||
db_query = db_query.order_by(column.asc())
|
db_query = db_query.order_by(column.asc())
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable, Union
|
||||||
|
|
||||||
SaColumn = Any
|
SaColumn = Any
|
||||||
SaQuery = Any
|
SaQuery = Any
|
||||||
|
|
|
@ -725,6 +725,7 @@ def test_filter_by_feature_date(
|
||||||
"sort:fav-time",
|
"sort:fav-time",
|
||||||
"sort:feature-date",
|
"sort:feature-date",
|
||||||
"sort:feature-time",
|
"sort:feature-time",
|
||||||
|
"sort:pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_sort_tokens(verify_unpaged, post_factory, input):
|
def test_sort_tokens(verify_unpaged, post_factory, input):
|
||||||
|
@ -863,3 +864,42 @@ def test_tumbleweed(
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
verify_unpaged("special:tumbleweed", [4])
|
verify_unpaged("special:tumbleweed", [4])
|
||||||
verify_unpaged("-special:tumbleweed", [1, 2, 3])
|
verify_unpaged("-special:tumbleweed", [1, 2, 3])
|
||||||
|
|
||||||
|
|
||||||
|
def test_sort_pool(
|
||||||
|
post_factory, pool_factory, pool_category_factory, verify_unpaged
|
||||||
|
):
|
||||||
|
post1 = post_factory(id=1)
|
||||||
|
post2 = post_factory(id=2)
|
||||||
|
post3 = post_factory(id=3)
|
||||||
|
post4 = post_factory(id=4)
|
||||||
|
pool1 = pool_factory(
|
||||||
|
id=1,
|
||||||
|
names=["pool1"],
|
||||||
|
description="desc",
|
||||||
|
category=pool_category_factory("test-cat1"),
|
||||||
|
)
|
||||||
|
pool1.posts = [post1, post4, post3]
|
||||||
|
pool2 = pool_factory(
|
||||||
|
id=2,
|
||||||
|
names=["pool2"],
|
||||||
|
description="desc",
|
||||||
|
category=pool_category_factory("test-cat2"),
|
||||||
|
)
|
||||||
|
pool2.posts = [post3, post4, post2]
|
||||||
|
db.session.add_all(
|
||||||
|
[
|
||||||
|
post1,
|
||||||
|
post2,
|
||||||
|
post3,
|
||||||
|
post4,
|
||||||
|
pool1,
|
||||||
|
pool2
|
||||||
|
]
|
||||||
|
)
|
||||||
|
db.session.flush()
|
||||||
|
verify_unpaged("pool:1 sort:pool", [1, 4, 3])
|
||||||
|
verify_unpaged("pool:2 sort:pool", [3, 4, 2])
|
||||||
|
verify_unpaged("pool:1 pool:2 sort:pool", [4, 3])
|
||||||
|
verify_unpaged("pool:2 pool:1 sort:pool", [3, 4])
|
||||||
|
verify_unpaged("sort:pool", [1, 2, 3, 4])
|
||||||
|
|
Reference in a new issue