diff --git a/server/szurubooru/api/comment_api.py b/server/szurubooru/api/comment_api.py index e25ce606..6f5ee00a 100644 --- a/server/szurubooru/api/comment_api.py +++ b/server/szurubooru/api/comment_api.py @@ -12,8 +12,8 @@ def _serialize(ctx, comment, **kwargs): class CommentListApi(BaseApi): def __init__(self): super().__init__() - self._search_executor = search.SearchExecutor( - search.CommentSearchConfig()) + self._search_executor = search.Executor( + search.configs.CommentSearchConfig()) def get(self, ctx): auth.verify_privilege(ctx.user, 'comments:list') diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index 1496952d..c6a2713c 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -12,7 +12,8 @@ def _serialize_post(ctx, post): class PostListApi(BaseApi): def __init__(self): super().__init__() - self._search_executor = search.SearchExecutor(search.PostSearchConfig()) + self._search_executor = search.Executor( + search.configs.PostSearchConfig()) def get(self, ctx): auth.verify_privilege(ctx.user, 'posts:list') diff --git a/server/szurubooru/api/snapshot_api.py b/server/szurubooru/api/snapshot_api.py index 9f16e489..3f830f90 100644 --- a/server/szurubooru/api/snapshot_api.py +++ b/server/szurubooru/api/snapshot_api.py @@ -5,7 +5,8 @@ from szurubooru.func import auth, snapshots class SnapshotListApi(BaseApi): def __init__(self): super().__init__() - self._search_executor = search.SearchExecutor(search.SnapshotSearchConfig()) + self._search_executor = search.Executor( + search.configs.SnapshotSearchConfig()) def get(self, ctx): auth.verify_privilege(ctx.user, 'snapshots:list') diff --git a/server/szurubooru/api/tag_api.py b/server/szurubooru/api/tag_api.py index 372db1a9..c49ec1bb 100644 --- a/server/szurubooru/api/tag_api.py +++ b/server/szurubooru/api/tag_api.py @@ -19,7 +19,8 @@ def _create_if_needed(tag_names, user): class TagListApi(BaseApi): def __init__(self): super().__init__() - self._search_executor = search.SearchExecutor(search.TagSearchConfig()) + self._search_executor = search.Executor( + search.configs.TagSearchConfig()) def get(self, ctx): auth.verify_privilege(ctx.user, 'tags:list') diff --git a/server/szurubooru/api/user_api.py b/server/szurubooru/api/user_api.py index b7d5bd00..6125fc55 100644 --- a/server/szurubooru/api/user_api.py +++ b/server/szurubooru/api/user_api.py @@ -12,7 +12,8 @@ def _serialize(ctx, user, **kwargs): class UserListApi(BaseApi): def __init__(self): super().__init__() - self._search_executor = search.SearchExecutor(search.UserSearchConfig()) + self._search_executor = search.Executor( + search.configs.UserSearchConfig()) def get(self, ctx): auth.verify_privilege(ctx.user, 'users:list') diff --git a/server/szurubooru/search/__init__.py b/server/szurubooru/search/__init__.py index d16e9879..919475fe 100644 --- a/server/szurubooru/search/__init__.py +++ b/server/szurubooru/search/__init__.py @@ -1,8 +1,2 @@ -''' Search parsers and services. ''' - -from szurubooru.search.search_executor import SearchExecutor -from szurubooru.search.user_search_config import UserSearchConfig -from szurubooru.search.snapshot_search_config import SnapshotSearchConfig -from szurubooru.search.tag_search_config import TagSearchConfig -from szurubooru.search.comment_search_config import CommentSearchConfig -from szurubooru.search.post_search_config import PostSearchConfig +from szurubooru.search.executor import Executor +import szurubooru.search.configs diff --git a/server/szurubooru/search/configs/__init__.py b/server/szurubooru/search/configs/__init__.py new file mode 100644 index 00000000..9f48e14d --- /dev/null +++ b/server/szurubooru/search/configs/__init__.py @@ -0,0 +1,5 @@ +from szurubooru.search.configs.user_search_config import UserSearchConfig +from szurubooru.search.configs.snapshot_search_config import SnapshotSearchConfig +from szurubooru.search.configs.tag_search_config import TagSearchConfig +from szurubooru.search.configs.comment_search_config import CommentSearchConfig +from szurubooru.search.configs.post_search_config import PostSearchConfig diff --git a/server/szurubooru/search/base_search_config.py b/server/szurubooru/search/configs/base_search_config.py similarity index 72% rename from server/szurubooru/search/base_search_config.py rename to server/szurubooru/search/configs/base_search_config.py index 3830768e..2dda98fa 100644 --- a/server/szurubooru/search/base_search_config.py +++ b/server/szurubooru/search/configs/base_search_config.py @@ -2,13 +2,14 @@ import sqlalchemy from szurubooru import db, errors from szurubooru.func import util from szurubooru.search import criteria +from szurubooru.search import tokens def wildcard_transformer(value): return value.replace('*', '%') class BaseSearchConfig(object): - SORT_DESC = -1 - SORT_ASC = 1 + SORT_ASC = tokens.SortToken.SORT_ASC + SORT_DESC = tokens.SortToken.SORT_DESC def create_filter_query(self): raise NotImplementedError() @@ -38,11 +39,11 @@ class BaseSearchConfig(object): Decorate SQLAlchemy filter on given column using supplied criterion. ''' try: - if isinstance(criterion, criteria.PlainSearchCriterion): + if isinstance(criterion, criteria.PlainCriterion): expr = column == int(criterion.value) - elif isinstance(criterion, criteria.ArraySearchCriterion): + elif isinstance(criterion, criteria.ArrayCriterion): expr = column.in_(int(value) for value in criterion.values) - elif isinstance(criterion, criteria.RangedSearchCriterion): + elif isinstance(criterion, criteria.RangedCriterion): assert criterion.min_value != '' \ or criterion.max_value != '' if criterion.min_value != '' and criterion.max_value != '': @@ -57,40 +58,45 @@ class BaseSearchConfig(object): except ValueError: raise errors.SearchError( 'Criterion value %r must be a number.' % (criterion,)) - if criterion.negated: - expr = ~expr return expr @staticmethod def _create_num_filter(column): - return lambda query, criterion: query.filter( - BaseSearchConfig._apply_num_criterion_to_column(column, criterion)) + def wrapper(query, criterion, negated): + expr = BaseSearchConfig._apply_num_criterion_to_column( + column, criterion) + if negated: + expr = ~expr + return query.filter(expr) + return wrapper @staticmethod def _apply_str_criterion_to_column(column, criterion, transformer): ''' Decorate SQLAlchemy filter on given column using supplied criterion. ''' - if isinstance(criterion, criteria.PlainSearchCriterion): + if isinstance(criterion, criteria.PlainCriterion): expr = column.ilike(transformer(criterion.value)) - elif isinstance(criterion, criteria.ArraySearchCriterion): + elif isinstance(criterion, criteria.ArrayCriterion): expr = sqlalchemy.sql.false() for value in criterion.values: expr = expr | column.ilike(transformer(value)) - elif isinstance(criterion, criteria.RangedSearchCriterion): + elif isinstance(criterion, criteria.RangedCriterion): raise errors.SearchError( 'Composite token %r is invalid in this context.' % (criterion,)) else: assert False - if criterion.negated: - expr = ~expr return expr @staticmethod def _create_str_filter(column, transformer=wildcard_transformer): - return lambda query, criterion: query.filter( - BaseSearchConfig._apply_str_criterion_to_column( - column, criterion, transformer)) + def wrapper(query, criterion, negated): + expr = BaseSearchConfig._apply_str_criterion_to_column( + column, criterion, transformer) + if negated: + expr = ~expr + return query.filter(expr) + return wrapper @staticmethod def _apply_date_criterion_to_column(column, criterion): @@ -98,15 +104,15 @@ class BaseSearchConfig(object): Decorate SQLAlchemy filter on given column using supplied criterion. Parse the datetime inside the criterion. ''' - if isinstance(criterion, criteria.PlainSearchCriterion): + if isinstance(criterion, criteria.PlainCriterion): min_date, max_date = util.parse_time_range(criterion.value) expr = column.between(min_date, max_date) - elif isinstance(criterion, criteria.ArraySearchCriterion): + elif isinstance(criterion, criteria.ArrayCriterion): expr = sqlalchemy.sql.false() for value in criterion.values: min_date, max_date = util.parse_time_range(value) expr = expr | column.between(min_date, max_date) - elif isinstance(criterion, criteria.RangedSearchCriterion): + elif isinstance(criterion, criteria.RangedCriterion): assert criterion.min_value or criterion.max_value if criterion.min_value and criterion.max_value: min_date = util.parse_time_range(criterion.min_value)[0] @@ -120,14 +126,17 @@ class BaseSearchConfig(object): expr = column <= max_date else: assert False - if criterion.negated: - expr = ~expr return expr @staticmethod def _create_date_filter(column): - return lambda query, criterion: query.filter( - BaseSearchConfig._apply_date_criterion_to_column(column, criterion)) + def wrapper(query, criterion, negated): + expr = BaseSearchConfig._apply_date_criterion_to_column( + column, criterion) + if negated: + expr = ~expr + return query.filter(expr) + return wrapper @staticmethod def _create_subquery_filter( @@ -137,12 +146,12 @@ class BaseSearchConfig(object): filter_factory, subquery_decorator=None): filter_func = filter_factory(filter_column) - def func(query, criterion): + def wrapper(query, criterion, negated): subquery = db.session.query(right_id_column.label('foreign_id')) if subquery_decorator: subquery = subquery_decorator(subquery) subquery = subquery.options(sqlalchemy.orm.lazyload('*')) - subquery = filter_func(subquery, criterion) + subquery = filter_func(subquery, criterion, negated) subquery = subquery.subquery('t') return query.filter(left_id_column.in_(subquery)) - return func + return wrapper diff --git a/server/szurubooru/search/comment_search_config.py b/server/szurubooru/search/configs/comment_search_config.py similarity index 96% rename from server/szurubooru/search/comment_search_config.py rename to server/szurubooru/search/configs/comment_search_config.py index 03224265..bc6eea70 100644 --- a/server/szurubooru/search/comment_search_config.py +++ b/server/szurubooru/search/configs/comment_search_config.py @@ -1,6 +1,6 @@ from sqlalchemy.sql.expression import func from szurubooru import db -from szurubooru.search.base_search_config import BaseSearchConfig +from szurubooru.search.configs.base_search_config import BaseSearchConfig class CommentSearchConfig(BaseSearchConfig): def create_filter_query(self): diff --git a/server/szurubooru/search/post_search_config.py b/server/szurubooru/search/configs/post_search_config.py similarity index 97% rename from server/szurubooru/search/post_search_config.py rename to server/szurubooru/search/configs/post_search_config.py index 0797491c..75ab71a1 100644 --- a/server/szurubooru/search/post_search_config.py +++ b/server/szurubooru/search/configs/post_search_config.py @@ -2,14 +2,15 @@ from sqlalchemy.orm import subqueryload, lazyload, defer from sqlalchemy.sql.expression import func from szurubooru import db, errors from szurubooru.func import util -from szurubooru.search.base_search_config import BaseSearchConfig +from szurubooru.search.configs.base_search_config import BaseSearchConfig def _enum_transformer(available_values, value): try: return available_values[value.lower()] except KeyError: - raise errors.SearchError('Invalid value: %r. Possible values: %r.' % ( - value, list(sorted(available_values.keys())))) + raise errors.SearchError( + 'Invalid value: %r. Possible values: %r.' % ( + value, list(sorted(available_values.keys())))) def _type_transformer(value): available_values = { diff --git a/server/szurubooru/search/snapshot_search_config.py b/server/szurubooru/search/configs/snapshot_search_config.py similarity index 90% rename from server/szurubooru/search/snapshot_search_config.py rename to server/szurubooru/search/configs/snapshot_search_config.py index 971863f1..a99aee73 100644 --- a/server/szurubooru/search/snapshot_search_config.py +++ b/server/szurubooru/search/configs/snapshot_search_config.py @@ -1,5 +1,5 @@ from szurubooru import db -from szurubooru.search.base_search_config import BaseSearchConfig +from szurubooru.search.configs.base_search_config import BaseSearchConfig class SnapshotSearchConfig(BaseSearchConfig): def create_filter_query(self): diff --git a/server/szurubooru/search/tag_search_config.py b/server/szurubooru/search/configs/tag_search_config.py similarity index 97% rename from server/szurubooru/search/tag_search_config.py rename to server/szurubooru/search/configs/tag_search_config.py index d9ed898f..4422d875 100644 --- a/server/szurubooru/search/tag_search_config.py +++ b/server/szurubooru/search/configs/tag_search_config.py @@ -2,7 +2,7 @@ from sqlalchemy.orm import subqueryload from sqlalchemy.sql.expression import func from szurubooru import db from szurubooru.func import util -from szurubooru.search.base_search_config import BaseSearchConfig +from szurubooru.search.configs.base_search_config import BaseSearchConfig class TagSearchConfig(BaseSearchConfig): def create_filter_query(self): diff --git a/server/szurubooru/search/user_search_config.py b/server/szurubooru/search/configs/user_search_config.py similarity index 95% rename from server/szurubooru/search/user_search_config.py rename to server/szurubooru/search/configs/user_search_config.py index bc6927ef..5d50be10 100644 --- a/server/szurubooru/search/user_search_config.py +++ b/server/szurubooru/search/configs/user_search_config.py @@ -1,6 +1,6 @@ from sqlalchemy.sql.expression import func from szurubooru import db -from szurubooru.search.base_search_config import BaseSearchConfig +from szurubooru.search.configs.base_search_config import BaseSearchConfig class UserSearchConfig(BaseSearchConfig): ''' Executes searches related to the users. ''' diff --git a/server/szurubooru/search/criteria.py b/server/szurubooru/search/criteria.py index 422966b4..ac0e95ce 100644 --- a/server/szurubooru/search/criteria.py +++ b/server/szurubooru/search/criteria.py @@ -1,23 +1,22 @@ -class _BaseSearchCriterion(object): - def __init__(self, original_text, negated): +class _BaseCriterion(object): + def __init__(self, original_text): self.original_text = original_text - self.negated = negated def __repr__(self): return self.original_text -class RangedSearchCriterion(_BaseSearchCriterion): - def __init__(self, original_text, negated, min_value, max_value): - super().__init__(original_text, negated) +class RangedCriterion(_BaseCriterion): + def __init__(self, original_text, min_value, max_value): + super().__init__(original_text) self.min_value = min_value self.max_value = max_value -class PlainSearchCriterion(_BaseSearchCriterion): - def __init__(self, original_text, negated, value): - super().__init__(original_text, negated) +class PlainCriterion(_BaseCriterion): + def __init__(self, original_text, value): + super().__init__(original_text) self.value = value -class ArraySearchCriterion(_BaseSearchCriterion): - def __init__(self, original_text, negated, values): - super().__init__(original_text, negated) +class ArrayCriterion(_BaseCriterion): + def __init__(self, original_text, values): + super().__init__(original_text) self.values = values diff --git a/server/szurubooru/search/executor.py b/server/szurubooru/search/executor.py new file mode 100644 index 00000000..d4113bed --- /dev/null +++ b/server/szurubooru/search/executor.py @@ -0,0 +1,118 @@ +import sqlalchemy +from szurubooru import db, errors +from szurubooru.func import cache +from szurubooru.search import tokens, parser + +def _format_dict_keys(source): + return list(sorted(source.keys())) + +def _get_direction(direction, default_direction): + if direction == tokens.SortToken.SORT_DEFAULT: + return default_direction + if direction == tokens.SortToken.SORT_NEGATED_DEFAULT: + if default_direction == tokens.SortToken.SORT_ASC: + return tokens.SortToken.SORT_DESC + elif default_direction == tokens.SortToken.SORT_DESC: + return tokens.SortToken.SORT_ASC + assert False + return direction + +class Executor(object): + ''' + Class for search parsing and execution. Handles plaintext parsing and + delegates sqlalchemy filter decoration to SearchConfig instances. + ''' + + def __init__(self, search_config): + self.config = search_config + self.parser = parser.Parser() + + def execute(self, query_text, page, page_size): + ''' + Parse input and return tuple containing total record count and filtered + entities. + ''' + key = (id(self.config), query_text, page, page_size) + if cache.has(key): + return cache.get(key) + + search_query = self.parser.parse(query_text) + + filter_query = self.config.create_filter_query() + filter_query = filter_query.options(sqlalchemy.orm.lazyload('*')) + filter_query = self._prepare_db_query(filter_query, search_query, True) + entities = filter_query \ + .offset((page - 1) * page_size) \ + .limit(page_size) \ + .all() + + count_query = self.config.create_count_query() + count_query = count_query.options(sqlalchemy.orm.lazyload('*')) + count_query = self._prepare_db_query(count_query, search_query, False) + count_statement = count_query \ + .statement \ + .with_only_columns([sqlalchemy.func.count()]) \ + .order_by(None) + count = db.session.execute(count_statement).scalar() + + ret = (count, entities) + cache.put(key, ret) + return ret + + def execute_and_serialize(self, ctx, serializer): + query = ctx.get_param_as_string('query') + page = ctx.get_param_as_int('page', default=1, min=1) + page_size = ctx.get_param_as_int('pageSize', default=100, min=1, max=100) + count, entities = self.execute(query, page, page_size) + return { + 'query': query, + 'page': page, + 'pageSize': page_size, + 'total': count, + 'results': [serializer(entity) for entity in entities], + } + + def _prepare_db_query(self, db_query, search_query, use_sort): + ''' Parse input and return SQLAlchemy query. ''' + + for token in search_query.anonymous_tokens: + if not self.config.anonymous_filter: + raise errors.SearchError( + 'Anonymous tokens are not valid in this context.') + db_query = self.config.anonymous_filter( + db_query, token.criterion, token.negated) + + for token in search_query.named_tokens: + if token.name not in self.config.named_filters: + raise errors.SearchError( + 'Unknown named token: %r. Available named tokens: %r.' % ( + token.name, + _format_dict_keys(self.config.named_filters))) + db_query = self.config.named_filters[token.name]( + db_query, token.criterion, token.negated) + + for token in search_query.special_tokens: + if token.value not in self.config.special_filters: + raise errors.SearchError( + 'Unknown special token: %r. Available special tokens: %r.' % ( + token.value, + _format_dict_keys(self.config.special_filters))) + db_query = self.config.special_filters[token.value]( + db_query, token.negated) + + if use_sort: + for token in search_query.sort_tokens: + if token.name not in self.config.sort_columns: + raise errors.SearchError( + 'Unknown sort token: %r. Available sort tokens: %r.' % ( + token.name, + _format_dict_keys(self.config.sort_columns))) + column, default_direction = self.config.sort_columns[token.name] + direction = _get_direction(token.direction, default_direction) + if direction == token.SORT_ASC: + db_query = db_query.order_by(column.asc()) + elif direction == token.SORT_DESC: + db_query = db_query.order_by(column.desc()) + + db_query = self.config.finalize_query(db_query) + return db_query diff --git a/server/szurubooru/search/parser.py b/server/szurubooru/search/parser.py new file mode 100644 index 00000000..432f05cd --- /dev/null +++ b/server/szurubooru/search/parser.py @@ -0,0 +1,90 @@ +import re +from szurubooru import errors +from szurubooru.search import criteria, tokens + +def _create_criterion(original_value, value): + if '..' in value: + low, high = value.split('..', 1) + if not low and not high: + raise errors.SearchError('Empty ranged value') + return criteria.RangedCriterion(original_value, low, high) + if ',' in value: + return criteria.ArrayCriterion( + original_value, value.split(',')) + return criteria.PlainCriterion(original_value, value) + +def _parse_anonymous(value, negated): + criterion = _create_criterion(value, value) + return tokens.AnonymousToken(criterion, negated) + +def _parse_named(key, value, negated): + original_value = value + if key.endswith('-min'): + key = key[:-4] + value += '..' + elif key.endswith('-max'): + key = key[:-4] + value = '..' + value + criterion = _create_criterion(original_value, value) + return tokens.NamedToken(key, criterion, negated) + +def _parse_special(value, negated): + return tokens.SpecialToken(value, negated) + +def _parse_sort(value, negated): + if value.count(',') == 0: + direction_str = None + elif value.count(',') == 1: + value, direction_str = value.split(',') + else: + raise errors.SearchError('Too many commas in sort style token.') + try: + direction = { + 'asc': tokens.SortToken.SORT_ASC, + 'desc': tokens.SortToken.SORT_DESC, + '': tokens.SortToken.SORT_DEFAULT, + None: tokens.SortToken.SORT_DEFAULT, + }[direction_str] + except KeyError: + raise errors.SearchError( + 'Unknown search direction: %r.' % direction_str) + if negated: + direction = { + tokens.SortToken.SORT_ASC: tokens.SortToken.SORT_DESC, + tokens.SortToken.SORT_DESC: tokens.SortToken.SORT_ASC, + tokens.SortToken.SORT_DEFAULT: tokens.SortToken.SORT_NEGATED_DEFAULT, + tokens.SortToken.SORT_NEGATED_DEFAULT: tokens.SortToken.SORT_DEFAULT, + }[direction] + return tokens.SortToken(value, direction) + +class SearchQuery(): + def __init__(self): + self.anonymous_tokens = [] + self.named_tokens = [] + self.special_tokens = [] + self.sort_tokens = [] + +class Parser(object): + def parse(self, query_text): + query = SearchQuery() + for chunk in re.split(r'\s+', (query_text or '').lower()): + if not chunk: + continue + negated = False + while chunk[0] == '-': + chunk = chunk[1:] + negated = not negated + if ':' in chunk and chunk[0] != ':': + key, value = chunk.split(':', 2) + if key == 'sort': + query.sort_tokens.append( + _parse_sort(value, negated)) + elif key == 'special': + query.special_tokens.append( + _parse_special(value, negated)) + else: + query.named_tokens.append( + _parse_named(key, value, negated)) + else: + query.anonymous_tokens.append(_parse_anonymous(chunk, negated)) + return query diff --git a/server/szurubooru/search/search_executor.py b/server/szurubooru/search/search_executor.py deleted file mode 100644 index d23b2888..00000000 --- a/server/szurubooru/search/search_executor.py +++ /dev/null @@ -1,161 +0,0 @@ -import re -import sqlalchemy -from szurubooru import db, errors -from szurubooru.func import cache -from szurubooru.search import criteria - -class SearchExecutor(object): - ''' - Class for search parsing and execution. Handles plaintext parsing and - delegates sqlalchemy filter decoration to SearchConfig instances. - ''' - - def __init__(self, search_config): - self.config = search_config - - def execute(self, query_text, page, page_size): - ''' - Parse input and return tuple containing total record count and filtered - entities. - ''' - key = (id(self.config), query_text, page, page_size) - if cache.has(key): - return cache.get(key) - filter_query = self.config.create_filter_query() - filter_query = filter_query.options(sqlalchemy.orm.lazyload('*')) - filter_query = self._prepare(filter_query, query_text) - entities = filter_query \ - .offset((page - 1) * page_size) \ - .limit(page_size) \ - .all() - - count_query = self.config.create_count_query() - count_query = count_query.options(sqlalchemy.orm.lazyload('*')) - count_query = self._prepare(count_query, query_text) - count_statement = count_query \ - .statement \ - .with_only_columns([sqlalchemy.func.count()]) \ - .order_by(None) - count = db.session.execute(count_statement).scalar() - ret = (count, entities) - cache.put(key, ret) - return ret - - def execute_and_serialize(self, ctx, serializer): - query = ctx.get_param_as_string('query') - page = ctx.get_param_as_int('page', default=1, min=1) - page_size = ctx.get_param_as_int('pageSize', default=100, min=1, max=100) - count, entities = self.execute(query, page, page_size) - return { - 'query': query, - 'page': page, - 'pageSize': page_size, - 'total': count, - 'results': [serializer(entity) for entity in entities], - } - - def _prepare(self, query, query_text): - ''' Parse input and return SQLAlchemy query. ''' - for token in re.split(r'\s+', (query_text or '').lower()): - if not token: - continue - negated = False - while token[0] == '-': - token = token[1:] - negated = not negated - - if ':' in token and token[0] != ':': - key, value = token.split(':', 2) - query = self._handle_key_value(query, key, value, negated) - else: - query = self._handle_anonymous( - query, self._create_criterion(token, negated)) - - query = self.config.finalize_query(query) - return query - - def _handle_key_value(self, query, key, value, negated): - if key == 'sort': - return self._handle_sort(query, value, negated) - elif key == 'special': - return self._handle_special(query, value, negated) - else: - return self._handle_named(query, key, value, negated) - - def _handle_anonymous(self, query, criterion): - if not self.config.anonymous_filter: - raise errors.SearchError( - 'Anonymous tokens are not valid in this context.') - return self.config.anonymous_filter(query, criterion) - - def _handle_named(self, query, key, value, negated): - if key.endswith('-min'): - key = key[:-4] - value += '..' - elif key.endswith('-max'): - key = key[:-4] - value = '..' + value - criterion = self._create_criterion(value, negated) - if key in self.config.named_filters: - return self.config.named_filters[key](query, criterion) - raise errors.SearchError( - 'Unknown named token: %r. Available named tokens: %r.' % ( - key, list(self.config.named_filters.keys()))) - - def _handle_special(self, query, value, negated): - if value in self.config.special_filters: - return self.config.special_filters[value](query, negated) - raise errors.SearchError( - 'Unknown special token: %r. Available special tokens: %r.' % ( - value, list(self.config.special_filters.keys()))) - - def _handle_sort(self, query, value, negated): - if value.count(',') == 0: - dir_str = None - elif value.count(',') == 1: - value, dir_str = value.split(',') - else: - raise errors.SearchError('Too many commas in sort style token.') - - try: - column, default_sort = self.config.sort_columns[value] - except KeyError: - raise errors.SearchError( - 'Unknown sort style: %r. Available sort styles: %r.' % ( - value, list(self.config.sort_columns.keys()))) - - sort_asc = self.config.SORT_ASC - sort_desc = self.config.SORT_DESC - - try: - sort_map = { - 'asc': sort_asc, - 'desc': sort_desc, - '': default_sort, - None: default_sort, - } - sort = sort_map[dir_str] - except KeyError: - raise errors.SearchError('Unknown search direction: %r.' % dir_str) - - if negated and sort: - sort = -sort - - transform_map = { - sort_asc: lambda input: input.asc(), - sort_desc: lambda input: input.desc(), - None: lambda input: input, - } - transform = transform_map[sort] - return query.order_by(transform(column)) - - def _create_criterion(self, value, negated): - if '..' in value: - low, high = value.split('..', 1) - if not low and not high: - raise errors.SearchError('Empty ranged value') - return criteria.RangedSearchCriterion(value, negated, low, high) - if ',' in value: - return criteria.ArraySearchCriterion( - value, negated, value.split(',')) - return criteria.PlainSearchCriterion(value, negated, value) diff --git a/server/szurubooru/search/tokens.py b/server/szurubooru/search/tokens.py new file mode 100644 index 00000000..1e1a8038 --- /dev/null +++ b/server/szurubooru/search/tokens.py @@ -0,0 +1,24 @@ +class AnonymousToken(object): + def __init__(self, criterion, negated): + self.criterion = criterion + self.negated = negated + +class NamedToken(AnonymousToken): + def __init__(self, name, criterion, negated): + super().__init__(criterion, negated) + self.name = name + +class SortToken(object): + SORT_DESC = 'desc' + SORT_ASC = 'asc' + SORT_DEFAULT = 'default' + SORT_NEGATED_DEFAULT = 'negated default' + + def __init__(self, name, direction): + self.name = name + self.direction = direction + +class SpecialToken(object): + def __init__(self, value, negated): + self.value = value + self.negated = negated diff --git a/server/szurubooru/tests/search/test_comment_search_config.py b/server/szurubooru/tests/search/test_comment_search_config.py index 235fcea4..5a88a71d 100644 --- a/server/szurubooru/tests/search/test_comment_search_config.py +++ b/server/szurubooru/tests/search/test_comment_search_config.py @@ -4,8 +4,7 @@ from szurubooru import db, errors, search @pytest.fixture def executor(): - search_config = search.CommentSearchConfig() - return search.SearchExecutor(search_config) + return search.Executor(search.configs.CommentSearchConfig()) @pytest.fixture def verify_unpaged(executor): diff --git a/server/szurubooru/tests/search/test_post_search_config.py b/server/szurubooru/tests/search/test_post_search_config.py index a3d1749f..59a05307 100644 --- a/server/szurubooru/tests/search/test_post_search_config.py +++ b/server/szurubooru/tests/search/test_post_search_config.py @@ -38,7 +38,7 @@ def feature_factory(user_factory): @pytest.fixture def executor(user_factory): - return search.SearchExecutor(search.PostSearchConfig()) + return search.Executor(search.configs.PostSearchConfig()) @pytest.fixture def auth_executor(executor, user_factory): diff --git a/server/szurubooru/tests/search/test_tag_search_config.py b/server/szurubooru/tests/search/test_tag_search_config.py index 7582a57f..5b3002f6 100644 --- a/server/szurubooru/tests/search/test_tag_search_config.py +++ b/server/szurubooru/tests/search/test_tag_search_config.py @@ -4,8 +4,7 @@ from szurubooru import db, errors, search @pytest.fixture def executor(): - search_config = search.TagSearchConfig() - return search.SearchExecutor(search_config) + return search.Executor(search.configs.TagSearchConfig()) @pytest.fixture def verify_unpaged(executor): @@ -321,19 +320,4 @@ def test_sort_by_category( tag2 = tag_factory(names=['t2'], category=cat2) tag3 = tag_factory(names=['t3'], category=cat1) db.session.add_all([tag1, tag2, tag3]) - import sqlalchemy - from sqlalchemy.orm import joinedload - print('test', [tag.first_name for tag in db.session.query(db.Tag) - .join(db.TagCategory).options( - joinedload(db.Tag.names), - joinedload(db.Tag.category), - joinedload(db.Tag.suggestions).joinedload(db.Tag.names), - joinedload(db.Tag.implications).joinedload(db.Tag.names) - ) - .options(sqlalchemy.orm.lazyload('*')) - .order_by(db.TagCategory.name.asc()) - .order_by(db.Tag.first_name.asc()) - .offset(0) - .limit(100) - .all()]) verify_unpaged(input, expected_tag_names) diff --git a/server/szurubooru/tests/search/test_user_search_config.py b/server/szurubooru/tests/search/test_user_search_config.py index dfaecac4..7a2c4503 100644 --- a/server/szurubooru/tests/search/test_user_search_config.py +++ b/server/szurubooru/tests/search/test_user_search_config.py @@ -4,8 +4,7 @@ from szurubooru import db, errors, search @pytest.fixture def executor(): - search_config = search.UserSearchConfig() - return search.SearchExecutor(search_config) + return search.Executor(search.configs.UserSearchConfig()) @pytest.fixture def verify_unpaged(executor):