Retrieve surrounding pool posts in pool search query

This commit is contained in:
Ruin0x11 2021-05-08 16:42:30 -07:00
parent eee9b70b0e
commit 676a5ff97c
3 changed files with 48 additions and 12 deletions

View file

@ -1,5 +1,7 @@
from typing import Callable, Dict, Optional, Tuple
import sqlalchemy as sa
from szurubooru.search import criteria, tokens
from szurubooru.search.query import SearchQuery
from szurubooru.search.typing import SaColumn, SaQuery
@ -24,6 +26,21 @@ class BaseSearchConfig:
def create_around_query(self) -> SaQuery:
raise NotImplementedError()
def create_around_filter_queries(self, filter_query: SaQuery, entity_id: int) -> Tuple[SaQuery, SaQuery]:
prev_filter_query = (
filter_query.filter(self.id_column > entity_id)
.order_by(None)
.order_by(sa.func.abs(self.id_column - entity_id).asc())
.limit(1)
)
next_filter_query = (
filter_query.filter(self.id_column < entity_id)
.order_by(None)
.order_by(sa.func.abs(self.id_column - entity_id).asc())
.limit(1)
)
return (prev_filter_query, next_filter_query)
def finalize_query(self, query: SaQuery) -> SaQuery:
return query

View file

@ -134,6 +134,31 @@ def _pool_sort(
.order_by(model.PoolPost.order.desc())
def _posts_around_pool(filter_query: SaQuery, post_id: int, pool_id: int) -> Tuple[SaQuery, SaQuery]:
this_order = db.session.query(model.PoolPost) \
.filter(model.PoolPost.post_id == post_id) \
.filter(model.PoolPost.pool_id == pool_id) \
.one().order
filter_query = db.session.query(model.Post) \
.join(model.PoolPost, model.PoolPost.pool_id == pool_id) \
.filter(model.PoolPost.post_id == model.Post.post_id)
prev_filter_query = (
filter_query.filter(model.PoolPost.order > this_order)
.order_by(None)
.order_by(sa.func.abs(model.PoolPost.order - this_order).asc())
.limit(1)
)
next_filter_query = (
filter_query.filter(model.PoolPost.order < this_order)
.order_by(None)
.order_by(sa.func.abs(model.PoolPost.order - this_order).asc())
.limit(1)
)
return (prev_filter_query, next_filter_query)
class PostSearchConfig(BaseSearchConfig):
def __init__(self) -> None:
self.user = None # type: Optional[model.User]
@ -170,6 +195,11 @@ class PostSearchConfig(BaseSearchConfig):
def create_around_query(self) -> SaQuery:
return db.session.query(model.Post).options(sa.orm.lazyload("*"))
def create_around_filter_queries(self, filter_query: SaQuery, entity_id: int) -> Tuple[SaQuery, SaQuery]:
if self.pool_id is not None:
return _posts_around_pool(filter_query, entity_id, self.pool_id)
return super(PostSearchConfig, self).create_around_filter_queries(filter_query, entity_id)
def create_filter_query(self, disable_eager_loads: bool) -> SaQuery:
strategy = (
sa.orm.lazyload if disable_eager_loads else sa.orm.subqueryload

View file

@ -47,18 +47,7 @@ class Executor:
filter_query = self._prepare_db_query(
filter_query, search_query, False
)
prev_filter_query = (
filter_query.filter(self.config.id_column > entity_id)
.order_by(None)
.order_by(sa.func.abs(self.config.id_column - entity_id).asc())
.limit(1)
)
next_filter_query = (
filter_query.filter(self.config.id_column < entity_id)
.order_by(None)
.order_by(sa.func.abs(self.config.id_column - entity_id).asc())
.limit(1)
)
prev_filter_query, next_filter_query = self.config.create_around_filter_queries(filter_query, entity_id)
return (
prev_filter_query.one_or_none(),
next_filter_query.one_or_none(),