diff --git a/server/szurubooru/search/configs/base_search_config.py b/server/szurubooru/search/configs/base_search_config.py index 2dda98fa..dfbe4b0f 100644 --- a/server/szurubooru/search/configs/base_search_config.py +++ b/server/szurubooru/search/configs/base_search_config.py @@ -1,8 +1,7 @@ import sqlalchemy from szurubooru import db, errors from szurubooru.func import util -from szurubooru.search import criteria -from szurubooru.search import tokens +from szurubooru.search import criteria, tokens def wildcard_transformer(value): return value.replace('*', '%') @@ -11,6 +10,9 @@ class BaseSearchConfig(object): SORT_ASC = tokens.SortToken.SORT_ASC SORT_DESC = tokens.SortToken.SORT_DESC + def on_search_query_parsed(self, search_query): + pass + def create_filter_query(self): raise NotImplementedError() @@ -71,7 +73,8 @@ class BaseSearchConfig(object): return wrapper @staticmethod - def _apply_str_criterion_to_column(column, criterion, transformer): + def _apply_str_criterion_to_column( + column, criterion, transformer=wildcard_transformer): ''' Decorate SQLAlchemy filter on given column using supplied criterion. ''' diff --git a/server/szurubooru/search/configs/post_search_config.py b/server/szurubooru/search/configs/post_search_config.py index 75ab71a1..85c25e38 100644 --- a/server/szurubooru/search/configs/post_search_config.py +++ b/server/szurubooru/search/configs/post_search_config.py @@ -1,7 +1,8 @@ -from sqlalchemy.orm import subqueryload, lazyload, defer +from sqlalchemy.orm import subqueryload, lazyload, defer, aliased from sqlalchemy.sql.expression import func from szurubooru import db, errors from szurubooru.func import util +from szurubooru.search import criteria, tokens from szurubooru.search.configs.base_search_config import BaseSearchConfig def _enum_transformer(available_values, value): @@ -35,7 +36,47 @@ def _safety_transformer(value): } return _enum_transformer(available_values, value) +def _create_score_filter(score): + def wrapper(query, criterion, negated): + if not getattr(criterion, 'internal', False): + raise errors.SearchError( + 'Votes cannot be seen publicly. Did you mean %r?' \ + % 'special:liked') + user_alias = aliased(db.User) + score_alias = aliased(db.PostScore) + expr = score_alias.score == score + expr = expr & BaseSearchConfig._apply_str_criterion_to_column( + user_alias.name, criterion) + if negated: + expr = ~expr + ret = query \ + .join(score_alias, score_alias.post_id == db.Post.post_id) \ + .join(user_alias, user_alias.user_id == score_alias.user_id) \ + .filter(expr) + return ret + return wrapper + class PostSearchConfig(BaseSearchConfig): + def on_search_query_parsed(self, search_query): + new_special_tokens = [] + for token in search_query.special_tokens: + if token.value in ('fav', 'liked', 'disliked'): + assert self.user + if self.user.rank == 'anonymous': + raise errors.SearchError('Must be logged in to use this feature.') + criterion = criteria.PlainCriterion( + original_text=self.user.name, + value=self.user.name) + criterion.internal = True + search_query.named_tokens.append( + tokens.NamedToken( + name=token.value, + criterion=criterion, + negated=token.negated)) + else: + new_special_tokens.append(token) + search_query.special_tokens = new_special_tokens + def create_filter_query(self): return self.create_count_query() \ .options( @@ -101,6 +142,8 @@ class PostSearchConfig(BaseSearchConfig): db.User.name, self._create_str_filter, lambda subquery: subquery.join(db.User)), + 'liked': _create_score_filter(1), + 'disliked': _create_score_filter(-1), 'tag-count': self._create_num_filter(db.Post.tag_count), 'comment-count': self._create_num_filter(db.Post.comment_count), 'fav-count': self._create_num_filter(db.Post.favorite_count), @@ -158,50 +201,13 @@ class PostSearchConfig(BaseSearchConfig): @property def special_filters(self): return { - 'liked': self.own_liked_filter, - 'disliked': self.own_disliked_filter, - 'fav': self.own_fav_filter, + # handled by parsed + 'fav': None, + 'liked': None, + 'disliked': None, 'tumbleweed': self.tumbleweed_filter, } - def own_liked_filter(self, query, negated): - assert self.user - if self.user.rank == 'anonymous': - raise errors.SearchError('Must be logged in to use this feature.') - expr = db.Post.post_id.in_( - db.session \ - .query(db.PostScore.post_id) \ - .filter(db.PostScore.user_id == self.user.user_id) \ - .filter(db.PostScore.score == 1)) - if negated: - expr = ~expr - return query.filter(expr) - - def own_disliked_filter(self, query, negated): - assert self.user - if self.user.rank == 'anonymous': - raise errors.SearchError('Must be logged in to use this feature.') - expr = db.Post.post_id.in_( - db.session \ - .query(db.PostScore.post_id) \ - .filter(db.PostScore.user_id == self.user.user_id) \ - .filter(db.PostScore.score == -1)) - if negated: - expr = ~expr - return query.filter(expr) - - def own_fav_filter(self, query, negated): - assert self.user - if self.user.rank == 'anonymous': - raise errors.SearchError('Must be logged in to use this feature.') - expr = db.Post.post_id.in_( - db.session \ - .query(db.PostFavorite.post_id) \ - .filter(db.PostFavorite.user_id == self.user.user_id)) - if negated: - expr = ~expr - return query.filter(expr) - def tumbleweed_filter(self, query, negated): expr = \ (db.Post.comment_count == 0) \ diff --git a/server/szurubooru/search/executor.py b/server/szurubooru/search/executor.py index c50712c9..c024d94f 100644 --- a/server/szurubooru/search/executor.py +++ b/server/szurubooru/search/executor.py @@ -34,6 +34,7 @@ class Executor(object): ''' search_query = self.parser.parse(query_text) + self.config.on_search_query_parsed(search_query) key = (id(self.config), hash(search_query), page, page_size) if cache.has(key): diff --git a/server/szurubooru/tests/search/test_executor.py b/server/szurubooru/tests/search/test_executor.py index bb1cb515..99e085f2 100644 --- a/server/szurubooru/tests/search/test_executor.py +++ b/server/szurubooru/tests/search/test_executor.py @@ -1,4 +1,5 @@ import unittest.mock +import pytest from szurubooru import search from szurubooru.func import cache @@ -11,7 +12,7 @@ def test_retrieving_from_cache(user_factory): executor.execute('test:whatever', 1, 10) assert cache.get.called -def test_putting_equivalent_queries_into_cache(user_factory): +def test_putting_equivalent_queries_into_cache(): config = search.configs.PostSearchConfig() with unittest.mock.patch('szurubooru.func.cache.has'), \ unittest.mock.patch('szurubooru.func.cache.put'): @@ -30,7 +31,7 @@ def test_putting_equivalent_queries_into_cache(user_factory): assert len(hashes) == 6 assert len(set(hashes)) == 1 -def test_putting_non_equivalent_queries_into_cache(user_factory): +def test_putting_non_equivalent_queries_into_cache(): config = search.configs.PostSearchConfig() with unittest.mock.patch('szurubooru.func.cache.has'), \ unittest.mock.patch('szurubooru.func.cache.put'): @@ -82,3 +83,33 @@ def test_putting_non_equivalent_queries_into_cache(user_factory): executor.execute(*arg) assert len(hashes) == len(args) assert len(set(hashes)) == len(args) + +@pytest.mark.parametrize('input', [ + 'special:fav', + 'special:liked', + 'special:disliked', + '-special:fav', + '-special:liked', + '-special:disliked', +]) +def test_putting_auth_dependent_queries_into_cache(user_factory, input): + config = search.configs.PostSearchConfig() + with unittest.mock.patch('szurubooru.func.cache.has'), \ + unittest.mock.patch('szurubooru.func.cache.put'): + hashes = [] + def appender(key, value): + hashes.append(key) + cache.has.side_effect = lambda *args: False + cache.put.side_effect = appender + executor = search.Executor(config) + + executor.config.user = user_factory() + executor.execute(input, 1, 1) + assert len(set(hashes)) == 1 + + executor.config.user = user_factory() + executor.execute(input, 1, 1) + assert len(set(hashes)) == 2 + + executor.execute(input, 1, 1) + assert len(set(hashes)) == 2 diff --git a/server/szurubooru/tests/search/test_post_search_config.py b/server/szurubooru/tests/search/test_post_search_config.py index 59a05307..035e7708 100644 --- a/server/szurubooru/tests/search/test_post_search_config.py +++ b/server/szurubooru/tests/search/test_post_search_config.py @@ -559,6 +559,15 @@ def test_own_disliked( verify_unpaged('special:disliked', [1]) verify_unpaged('-special:disliked', [2, 3]) +@pytest.mark.parametrize('input', [ + 'liked:x', + 'disliked:x', +]) +def test_someones_score(executor, input): + with pytest.raises(errors.SearchError): + actual_count, actual_posts = executor.execute( + input, page=1, page_size=100) + def test_own_fav( auth_executor, post_factory, fav_factory, user_factory, verify_unpaged): auth_user = auth_executor()