server/search: fix caching special tokens

special:liked was being cached and reused between users. The query is
now rewritten internally so that the cache sees the equivalent of
liked:USER, making the cached result specific to the current user.
rr- 2016-06-03 15:51:50 +02:00
parent f0d3589344
commit 8a5c6f0b31
5 changed files with 97 additions and 47 deletions
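
For context, the idea behind the fix in one self-contained sketch (the
SearchQuery/NamedToken classes and the rewrite function below are
illustrative stand-ins, not szurubooru's actual parser API): a
user-dependent special token such as special:liked is rewritten into the
equivalent of a liked:USER named token before the cache key is derived
from the query's hash, so two different users can no longer hit the same
cached entry.

# Illustrative stand-ins only -- not the real szurubooru classes.
from collections import namedtuple

NamedToken = namedtuple('NamedToken', ['name', 'value'])

class SearchQuery:
    def __init__(self, special_tokens=None, named_tokens=None):
        self.special_tokens = list(special_tokens or [])
        self.named_tokens = list(named_tokens or [])

    def __hash__(self):
        # The executor builds its cache key from this hash, so anything
        # that varies per user must be reflected here.
        return hash((tuple(self.special_tokens), tuple(self.named_tokens)))

def rewrite_user_dependent_tokens(query, user_name):
    # Replace e.g. the 'liked' special token with a 'liked:USER' named token.
    kept = []
    for token in query.special_tokens:
        if token in ('fav', 'liked', 'disliked'):
            query.named_tokens.append(NamedToken(token, user_name))
        else:
            kept.append(token)
    query.special_tokens = kept

alice_query = SearchQuery(special_tokens=['liked'])
bob_query = SearchQuery(special_tokens=['liked'])
rewrite_user_dependent_tokens(alice_query, 'alice')
rewrite_user_dependent_tokens(bob_query, 'bob')
assert hash(alice_query) != hash(bob_query)  # distinct cache keys per user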


@@ -1,8 +1,7 @@
 import sqlalchemy
 from szurubooru import db, errors
 from szurubooru.func import util
-from szurubooru.search import criteria
-from szurubooru.search import tokens
+from szurubooru.search import criteria, tokens
 
 def wildcard_transformer(value):
     return value.replace('*', '%')
@@ -11,6 +10,9 @@ class BaseSearchConfig(object):
     SORT_ASC = tokens.SortToken.SORT_ASC
     SORT_DESC = tokens.SortToken.SORT_DESC
 
+    def on_search_query_parsed(self, search_query):
+        pass
+
     def create_filter_query(self):
         raise NotImplementedError()
 
@@ -71,7 +73,8 @@ class BaseSearchConfig(object):
         return wrapper
 
     @staticmethod
-    def _apply_str_criterion_to_column(column, criterion, transformer):
+    def _apply_str_criterion_to_column(
+            column, criterion, transformer=wildcard_transformer):
         '''
         Decorate SQLAlchemy filter on given column using supplied criterion.
         '''


@@ -1,7 +1,8 @@
-from sqlalchemy.orm import subqueryload, lazyload, defer
+from sqlalchemy.orm import subqueryload, lazyload, defer, aliased
 from sqlalchemy.sql.expression import func
 from szurubooru import db, errors
 from szurubooru.func import util
+from szurubooru.search import criteria, tokens
 from szurubooru.search.configs.base_search_config import BaseSearchConfig
 
 def _enum_transformer(available_values, value):
@@ -35,7 +36,47 @@ def _safety_transformer(value):
     }
     return _enum_transformer(available_values, value)
 
+def _create_score_filter(score):
+    def wrapper(query, criterion, negated):
+        if not getattr(criterion, 'internal', False):
+            raise errors.SearchError(
+                'Votes cannot be seen publicly. Did you mean %r?' \
+                % 'special:liked')
+        user_alias = aliased(db.User)
+        score_alias = aliased(db.PostScore)
+        expr = score_alias.score == score
+        expr = expr & BaseSearchConfig._apply_str_criterion_to_column(
+            user_alias.name, criterion)
+        if negated:
+            expr = ~expr
+        ret = query \
+            .join(score_alias, score_alias.post_id == db.Post.post_id) \
+            .join(user_alias, user_alias.user_id == score_alias.user_id) \
+            .filter(expr)
+        return ret
+    return wrapper
+
 class PostSearchConfig(BaseSearchConfig):
+    def on_search_query_parsed(self, search_query):
+        new_special_tokens = []
+        for token in search_query.special_tokens:
+            if token.value in ('fav', 'liked', 'disliked'):
+                assert self.user
+                if self.user.rank == 'anonymous':
+                    raise errors.SearchError('Must be logged in to use this feature.')
+                criterion = criteria.PlainCriterion(
+                    original_text=self.user.name,
+                    value=self.user.name)
+                criterion.internal = True
+                search_query.named_tokens.append(
+                    tokens.NamedToken(
+                        name=token.value,
+                        criterion=criterion,
+                        negated=token.negated))
+            else:
+                new_special_tokens.append(token)
+        search_query.special_tokens = new_special_tokens
+
     def create_filter_query(self):
         return self.create_count_query() \
             .options(
@@ -101,6 +142,8 @@ class PostSearchConfig(BaseSearchConfig):
                 db.User.name,
                 self._create_str_filter,
                 lambda subquery: subquery.join(db.User)),
+            'liked': _create_score_filter(1),
+            'disliked': _create_score_filter(-1),
             'tag-count': self._create_num_filter(db.Post.tag_count),
             'comment-count': self._create_num_filter(db.Post.comment_count),
             'fav-count': self._create_num_filter(db.Post.favorite_count),
@@ -158,50 +201,13 @@ class PostSearchConfig(BaseSearchConfig):
     @property
     def special_filters(self):
         return {
-            'liked': self.own_liked_filter,
-            'disliked': self.own_disliked_filter,
-            'fav': self.own_fav_filter,
+            # handled by parsed
+            'fav': None,
+            'liked': None,
+            'disliked': None,
             'tumbleweed': self.tumbleweed_filter,
         }
-
-    def own_liked_filter(self, query, negated):
-        assert self.user
-        if self.user.rank == 'anonymous':
-            raise errors.SearchError('Must be logged in to use this feature.')
-        expr = db.Post.post_id.in_(
-            db.session \
-                .query(db.PostScore.post_id) \
-                .filter(db.PostScore.user_id == self.user.user_id) \
-                .filter(db.PostScore.score == 1))
-        if negated:
-            expr = ~expr
-        return query.filter(expr)
-
-    def own_disliked_filter(self, query, negated):
-        assert self.user
-        if self.user.rank == 'anonymous':
-            raise errors.SearchError('Must be logged in to use this feature.')
-        expr = db.Post.post_id.in_(
-            db.session \
-                .query(db.PostScore.post_id) \
-                .filter(db.PostScore.user_id == self.user.user_id) \
-                .filter(db.PostScore.score == -1))
-        if negated:
-            expr = ~expr
-        return query.filter(expr)
-
-    def own_fav_filter(self, query, negated):
-        assert self.user
-        if self.user.rank == 'anonymous':
-            raise errors.SearchError('Must be logged in to use this feature.')
-        expr = db.Post.post_id.in_(
-            db.session \
-                .query(db.PostFavorite.post_id) \
-                .filter(db.PostFavorite.user_id == self.user.user_id))
-        if negated:
-            expr = ~expr
-        return query.filter(expr)
 
     def tumbleweed_filter(self, query, negated):
         expr = \
            (db.Post.comment_count == 0) \
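
One behavioral note worth calling out, sketched below with stand-in
SearchError/Criterion classes rather than the real ones: the new liked:/
disliked: named filters only accept criteria carrying the internal flag
that on_search_query_parsed sets, so a hand-typed liked:someone query
keeps failing instead of exposing other users' votes.

# Illustrative guard mirroring _create_score_filter above; SearchError and
# Criterion are stand-ins, not the real szurubooru classes.
class SearchError(Exception):
    pass

class Criterion:
    def __init__(self, value, internal=False):
        self.value = value
        self.internal = internal

def apply_score_filter(criterion):
    # A hand-typed 'liked:alice' arrives with internal=False and is rejected;
    # the criterion injected for 'special:liked' carries internal=True.
    if not getattr(criterion, 'internal', False):
        raise SearchError(
            'Votes cannot be seen publicly. Did you mean %r?' % 'special:liked')
    return 'posts scored by %s' % criterion.value

try:
    apply_score_filter(Criterion('alice'))  # public query -> error
except SearchError:
    pass
print(apply_score_filter(Criterion('alice', internal=True)))  # internal -> ok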


@@ -34,6 +34,7 @@ class Executor(object):
         '''
         search_query = self.parser.parse(query_text)
+        self.config.on_search_query_parsed(search_query)
         key = (id(self.config), hash(search_query), page, page_size)
         if cache.has(key):


@@ -1,4 +1,5 @@
 import unittest.mock
+import pytest
 from szurubooru import search
 from szurubooru.func import cache
 
@@ -11,7 +12,7 @@ def test_retrieving_from_cache(user_factory):
         executor.execute('test:whatever', 1, 10)
         assert cache.get.called
 
-def test_putting_equivalent_queries_into_cache(user_factory):
+def test_putting_equivalent_queries_into_cache():
     config = search.configs.PostSearchConfig()
     with unittest.mock.patch('szurubooru.func.cache.has'), \
             unittest.mock.patch('szurubooru.func.cache.put'):
@@ -30,7 +31,7 @@ def test_putting_equivalent_queries_into_cache(user_factory):
         assert len(hashes) == 6
         assert len(set(hashes)) == 1
 
-def test_putting_non_equivalent_queries_into_cache(user_factory):
+def test_putting_non_equivalent_queries_into_cache():
     config = search.configs.PostSearchConfig()
     with unittest.mock.patch('szurubooru.func.cache.has'), \
             unittest.mock.patch('szurubooru.func.cache.put'):
@@ -82,3 +83,33 @@ def test_putting_non_equivalent_queries_into_cache(user_factory):
             executor.execute(*arg)
         assert len(hashes) == len(args)
         assert len(set(hashes)) == len(args)
+
+@pytest.mark.parametrize('input', [
+    'special:fav',
+    'special:liked',
+    'special:disliked',
+    '-special:fav',
+    '-special:liked',
+    '-special:disliked',
+])
+def test_putting_auth_dependent_queries_into_cache(user_factory, input):
+    config = search.configs.PostSearchConfig()
+    with unittest.mock.patch('szurubooru.func.cache.has'), \
+            unittest.mock.patch('szurubooru.func.cache.put'):
+        hashes = []
+        def appender(key, value):
+            hashes.append(key)
+        cache.has.side_effect = lambda *args: False
+        cache.put.side_effect = appender
+        executor = search.Executor(config)
+
+        executor.config.user = user_factory()
+        executor.execute(input, 1, 1)
+        assert len(set(hashes)) == 1
+
+        executor.config.user = user_factory()
+        executor.execute(input, 1, 1)
+        assert len(set(hashes)) == 2
+
+        executor.execute(input, 1, 1)
+        assert len(set(hashes)) == 2


@@ -559,6 +559,15 @@ def test_own_disliked(
     verify_unpaged('special:disliked', [1])
     verify_unpaged('-special:disliked', [2, 3])
 
+@pytest.mark.parametrize('input', [
+    'liked:x',
+    'disliked:x',
+])
+def test_someones_score(executor, input):
+    with pytest.raises(errors.SearchError):
+        actual_count, actual_posts = executor.execute(
+            input, page=1, page_size=100)
+
 def test_own_fav(
         auth_executor, post_factory, fav_factory, user_factory, verify_unpaged):
     auth_user = auth_executor()