Retrieve surrounding pool posts in pool search query

This commit is contained in:
Ruin0x11 2021-05-08 16:42:30 -07:00
parent eee9b70b0e
commit 676a5ff97c
3 changed files with 48 additions and 12 deletions

View file

@ -1,5 +1,7 @@
from typing import Callable, Dict, Optional, Tuple from typing import Callable, Dict, Optional, Tuple
import sqlalchemy as sa
from szurubooru.search import criteria, tokens from szurubooru.search import criteria, tokens
from szurubooru.search.query import SearchQuery from szurubooru.search.query import SearchQuery
from szurubooru.search.typing import SaColumn, SaQuery from szurubooru.search.typing import SaColumn, SaQuery
@ -24,6 +26,21 @@ class BaseSearchConfig:
def create_around_query(self) -> SaQuery: def create_around_query(self) -> SaQuery:
raise NotImplementedError() raise NotImplementedError()
def create_around_filter_queries(self, filter_query: SaQuery, entity_id: int) -> Tuple[SaQuery, SaQuery]:
prev_filter_query = (
filter_query.filter(self.id_column > entity_id)
.order_by(None)
.order_by(sa.func.abs(self.id_column - entity_id).asc())
.limit(1)
)
next_filter_query = (
filter_query.filter(self.id_column < entity_id)
.order_by(None)
.order_by(sa.func.abs(self.id_column - entity_id).asc())
.limit(1)
)
return (prev_filter_query, next_filter_query)
def finalize_query(self, query: SaQuery) -> SaQuery: def finalize_query(self, query: SaQuery) -> SaQuery:
return query return query

View file

@ -134,6 +134,31 @@ def _pool_sort(
.order_by(model.PoolPost.order.desc()) .order_by(model.PoolPost.order.desc())
def _posts_around_pool(filter_query: SaQuery, post_id: int, pool_id: int) -> Tuple[SaQuery, SaQuery]:
this_order = db.session.query(model.PoolPost) \
.filter(model.PoolPost.post_id == post_id) \
.filter(model.PoolPost.pool_id == pool_id) \
.one().order
filter_query = db.session.query(model.Post) \
.join(model.PoolPost, model.PoolPost.pool_id == pool_id) \
.filter(model.PoolPost.post_id == model.Post.post_id)
prev_filter_query = (
filter_query.filter(model.PoolPost.order > this_order)
.order_by(None)
.order_by(sa.func.abs(model.PoolPost.order - this_order).asc())
.limit(1)
)
next_filter_query = (
filter_query.filter(model.PoolPost.order < this_order)
.order_by(None)
.order_by(sa.func.abs(model.PoolPost.order - this_order).asc())
.limit(1)
)
return (prev_filter_query, next_filter_query)
class PostSearchConfig(BaseSearchConfig): class PostSearchConfig(BaseSearchConfig):
def __init__(self) -> None: def __init__(self) -> None:
self.user = None # type: Optional[model.User] self.user = None # type: Optional[model.User]
@ -170,6 +195,11 @@ class PostSearchConfig(BaseSearchConfig):
def create_around_query(self) -> SaQuery: def create_around_query(self) -> SaQuery:
return db.session.query(model.Post).options(sa.orm.lazyload("*")) return db.session.query(model.Post).options(sa.orm.lazyload("*"))
def create_around_filter_queries(self, filter_query: SaQuery, entity_id: int) -> Tuple[SaQuery, SaQuery]:
if self.pool_id is not None:
return _posts_around_pool(filter_query, entity_id, self.pool_id)
return super(PostSearchConfig, self).create_around_filter_queries(filter_query, entity_id)
def create_filter_query(self, disable_eager_loads: bool) -> SaQuery: def create_filter_query(self, disable_eager_loads: bool) -> SaQuery:
strategy = ( strategy = (
sa.orm.lazyload if disable_eager_loads else sa.orm.subqueryload sa.orm.lazyload if disable_eager_loads else sa.orm.subqueryload

View file

@ -47,18 +47,7 @@ class Executor:
filter_query = self._prepare_db_query( filter_query = self._prepare_db_query(
filter_query, search_query, False filter_query, search_query, False
) )
prev_filter_query = ( prev_filter_query, next_filter_query = self.config.create_around_filter_queries(filter_query, entity_id)
filter_query.filter(self.config.id_column > entity_id)
.order_by(None)
.order_by(sa.func.abs(self.config.id_column - entity_id).asc())
.limit(1)
)
next_filter_query = (
filter_query.filter(self.config.id_column < entity_id)
.order_by(None)
.order_by(sa.func.abs(self.config.id_column - entity_id).asc())
.limit(1)
)
return ( return (
prev_filter_query.one_or_none(), prev_filter_query.one_or_none(),
next_filter_query.one_or_none(), next_filter_query.one_or_none(),