diff --git a/server/szurubooru/search/base_search_config.py b/server/szurubooru/search/base_search_config.py index 38221993..a9b0e681 100644 --- a/server/szurubooru/search/base_search_config.py +++ b/server/szurubooru/search/base_search_config.py @@ -11,9 +11,12 @@ class BaseSearchConfig(object): SORT_DESC = -1 SORT_ASC = 1 - def create_query(self): + def create_filter_query(self): raise NotImplementedError() + def create_count_query(self): + return self.create_filter_query() + @property def anonymous_filter(self): return None diff --git a/server/szurubooru/search/comment_search_config.py b/server/szurubooru/search/comment_search_config.py index 52fa7790..03224265 100644 --- a/server/szurubooru/search/comment_search_config.py +++ b/server/szurubooru/search/comment_search_config.py @@ -3,7 +3,7 @@ from szurubooru import db from szurubooru.search.base_search_config import BaseSearchConfig class CommentSearchConfig(BaseSearchConfig): - def create_query(self): + def create_filter_query(self): return db.session.query(db.Comment).join(db.User) def finalize_query(self, query): diff --git a/server/szurubooru/search/post_search_config.py b/server/szurubooru/search/post_search_config.py index 79dc7bae..8c90f304 100644 --- a/server/szurubooru/search/post_search_config.py +++ b/server/szurubooru/search/post_search_config.py @@ -22,7 +22,7 @@ def _type_transformer(value): value, available_types)) class PostSearchConfig(BaseSearchConfig): - def create_query(self): + def create_filter_query(self): return db.session.query(db.Post) def finalize_query(self, query): diff --git a/server/szurubooru/search/search_executor.py b/server/szurubooru/search/search_executor.py index bade7a42..fc5d9c0a 100644 --- a/server/szurubooru/search/search_executor.py +++ b/server/szurubooru/search/search_executor.py @@ -1,6 +1,6 @@ import re import sqlalchemy -from szurubooru import errors +from szurubooru import db, errors from szurubooru.search import criteria class SearchExecutor(object): @@ -17,15 +17,22 @@ class SearchExecutor(object): Parse input and return tuple containing total record count and filtered entities. ''' - filter_query = self._prepare(query_text) + filter_query = self.config.create_filter_query() + filter_query = filter_query.options(sqlalchemy.orm.lazyload('*')) + filter_query = self._prepare(filter_query, query_text) entities = filter_query \ - .offset((page - 1) * page_size).limit(page_size).all() - count_query = filter_query.statement \ + .offset((page - 1) * page_size) \ + .limit(page_size) \ + .all() + + count_query = self.config.create_count_query() + count_query = count_query.options(sqlalchemy.orm.lazyload('*')) + count_query = self._prepare(count_query, query_text) + count_statement = count_query \ + .statement \ .with_only_columns([sqlalchemy.func.count()]) \ .order_by(None) - count = filter_query.session \ - .execute(count_query) \ - .scalar() + count = db.session.execute(count_statement).scalar() return (count, entities) def execute_and_serialize(self, ctx, serializer): @@ -41,10 +48,8 @@ class SearchExecutor(object): 'results': [serializer(entity) for entity in entities], } - def _prepare(self, query_text): + def _prepare(self, query, query_text): ''' Parse input and return SQLAlchemy query. ''' - query = self.config.create_query() \ - .options(sqlalchemy.orm.lazyload('*')) for token in re.split(r'\s+', (query_text or '').lower()): if not token: continue diff --git a/server/szurubooru/search/snapshot_search_config.py b/server/szurubooru/search/snapshot_search_config.py index daa783dc..971863f1 100644 --- a/server/szurubooru/search/snapshot_search_config.py +++ b/server/szurubooru/search/snapshot_search_config.py @@ -2,7 +2,7 @@ from szurubooru import db from szurubooru.search.base_search_config import BaseSearchConfig class SnapshotSearchConfig(BaseSearchConfig): - def create_query(self): + def create_filter_query(self): return db.session.query(db.Snapshot) def finalize_query(self, query): diff --git a/server/szurubooru/search/tag_search_config.py b/server/szurubooru/search/tag_search_config.py index 4be9c023..3daa0655 100644 --- a/server/szurubooru/search/tag_search_config.py +++ b/server/szurubooru/search/tag_search_config.py @@ -1,9 +1,18 @@ +from sqlalchemy.orm import joinedload from sqlalchemy.sql.expression import func from szurubooru import db from szurubooru.search.base_search_config import BaseSearchConfig class TagSearchConfig(BaseSearchConfig): - def create_query(self): + def create_filter_query(self): + return self.create_count_query().options( + joinedload(db.Tag.names), + joinedload(db.Tag.category), + joinedload(db.Tag.suggestions).joinedload(db.Tag.names), + joinedload(db.Tag.implications).joinedload(db.Tag.names) + ) + + def create_count_query(self): return db.session.query(db.Tag) def finalize_query(self, query): diff --git a/server/szurubooru/search/user_search_config.py b/server/szurubooru/search/user_search_config.py index 13adae9e..bc6927ef 100644 --- a/server/szurubooru/search/user_search_config.py +++ b/server/szurubooru/search/user_search_config.py @@ -5,7 +5,7 @@ from szurubooru.search.base_search_config import BaseSearchConfig class UserSearchConfig(BaseSearchConfig): ''' Executes searches related to the users. ''' - def create_query(self): + def create_filter_query(self): return db.session.query(db.User) def finalize_query(self, query): diff --git a/server/szurubooru/tests/search/test_post_search_config.py b/server/szurubooru/tests/search/test_post_search_config.py index 2a9e8bb1..810f79de 100644 --- a/server/szurubooru/tests/search/test_post_search_config.py +++ b/server/szurubooru/tests/search/test_post_search_config.py @@ -56,7 +56,6 @@ def verify_unpaged(executor): actual_count, actual_posts = executor.execute( input, page=1, page_size=100) actual_post_ids = list([p.post_id for p in actual_posts]) - print(actual_post_ids, expected_post_ids) assert actual_count == len(expected_post_ids) if not test_order: actual_post_ids = sorted(actual_post_ids)