server/search: add search term escaping

This commit is contained in:
rr- 2017-04-24 21:51:49 +02:00
parent 9814b132c3
commit ba4df16499
7 changed files with 164 additions and 35 deletions

8
API.md
View file

@ -2258,6 +2258,9 @@ Date/time values can be of following form:
Some fields, such as user names, can take wildcards (`*`). Some fields, such as user names, can take wildcards (`*`).
You can escape special characters such as `:` and `-` by prepending them with a
backslash: `\\`.
**Example** **Example**
Searching for posts with following query: Searching for posts with following query:
@ -2266,3 +2269,8 @@ Searching for posts with following query:
will show flash files tagged as sea, that were liked by seven people at most, will show flash files tagged as sea, that were liked by seven people at most,
uploaded by user Pirate. uploaded by user Pirate.
Searching for posts with `re:zero` will show an error message about unknown
named token.
Searching for posts with `re\:zero` will show posts tagged with `re:zero`.

View file

@ -80,6 +80,9 @@ take following form:</p>
<code>,desc</code> to control the sort direction, which can be also controlled <code>,desc</code> to control the sort direction, which can be also controlled
by negating the whole token.</p> by negating the whole token.</p>
<p>You can escape special characters such as <code>:</code> and <code>-</code>
by prepending them with a backslash: <code>\\</code>.</p>
<h1>Example</h1> <h1>Example</h1>
<p>Searching for posts with following query:</p> <p>Searching for posts with following query:</p>
@ -89,3 +92,8 @@ by negating the whole token.</p>
<p>will show flash files tagged as sea, that were liked by seven people at <p>will show flash files tagged as sea, that were liked by seven people at
most, uploaded by user Pirate.</p> most, uploaded by user Pirate.</p>
<p>Searching for posts with <code>re:zero</code> will show an error message
about unknown named token.</p>
<p>Searching for posts with <code>re\:zero</code> will show posts tagged with
<code>re:zero</code>.</p>

View file

@ -10,15 +10,6 @@ from szurubooru.search.configs.base_search_config import (
BaseSearchConfig, Filter) BaseSearchConfig, Filter)
def _enum_transformer(available_values: Dict[str, Any], value: str) -> str:
try:
return available_values[value.lower()]
except KeyError:
raise errors.SearchError(
'Invalid value: %r. Possible values: %r.' % (
value, list(sorted(available_values.keys()))))
def _type_transformer(value: str) -> str: def _type_transformer(value: str) -> str:
available_values = { available_values = {
'image': model.Post.TYPE_IMAGE, 'image': model.Post.TYPE_IMAGE,
@ -31,7 +22,7 @@ def _type_transformer(value: str) -> str:
'flash': model.Post.TYPE_FLASH, 'flash': model.Post.TYPE_FLASH,
'swf': model.Post.TYPE_FLASH, 'swf': model.Post.TYPE_FLASH,
} }
return _enum_transformer(available_values, value) return search_util.enum_transformer(available_values, value)
def _safety_transformer(value: str) -> str: def _safety_transformer(value: str) -> str:
@ -41,7 +32,7 @@ def _safety_transformer(value: str) -> str:
'questionable': model.Post.SAFETY_SKETCHY, 'questionable': model.Post.SAFETY_SKETCHY,
'unsafe': model.Post.SAFETY_UNSAFE, 'unsafe': model.Post.SAFETY_UNSAFE,
} }
return _enum_transformer(available_values, value) return search_util.enum_transformer(available_values, value)
def _create_score_filter(score: int) -> Filter: def _create_score_filter(score: int) -> Filter:

View file

@ -1,4 +1,4 @@
from typing import Any, Optional, Union, Callable from typing import Any, Optional, Union, Dict, Callable
import sqlalchemy as sa import sqlalchemy as sa
from szurubooru import db, errors from szurubooru import db, errors
from szurubooru.func import util from szurubooru.func import util
@ -8,27 +8,62 @@ from szurubooru.search.configs.base_search_config import Filter
Number = Union[int, float] Number = Union[int, float]
WILDCARD = '(--wildcard--)' # something unlikely to be used by the users
def unescape(text: str, make_wildcards_special: bool = False) -> str:
output = ''
i = 0
while i < len(text):
if text[i] == '\\':
try:
char = text[i+1]
i += 1
except IndexError:
raise errors.SearchError(
'Unterminated escape sequence (did you forget to escape '
'the ending backslash?)')
if char not in '*\\:-.,':
raise errors.SearchError(
'Unknown escape sequence (did you forget to escape '
'the backslash?)')
elif text[i] == '*' and make_wildcards_special:
char = WILDCARD
else:
char = text[i]
output += char
i += 1
return output
def wildcard_transformer(value: str) -> str: def wildcard_transformer(value: str) -> str:
return ( return (
value unescape(value, make_wildcards_special=True)
.replace('\\', '\\\\') .replace('\\', '\\\\')
.replace('%', '\\%') .replace('%', '\\%')
.replace('_', '\\_') .replace('_', '\\_')
.replace('*', '%')) .replace(WILDCARD, '%'))
def enum_transformer(available_values: Dict[str, Any], value: str) -> str:
try:
return available_values[unescape(value.lower())]
except KeyError:
raise errors.SearchError(
'Invalid value: %r. Possible values: %r.' % (
value, list(sorted(available_values.keys()))))
def integer_transformer(value: str) -> int: def integer_transformer(value: str) -> int:
return int(value) return int(unescape(value))
def float_transformer(value: str) -> float: def float_transformer(value: str) -> float:
for sep in list('/:'): for sep in list('/:'):
if sep in value: if sep in value:
a, b = value.split(sep, 1) a, b = value.split(sep, 1)
return float(a) / float(b) return float(unescape(a)) / float(unescape(b))
return float(value) return float(unescape(value))
def apply_num_criterion_to_column( def apply_num_criterion_to_column(
@ -84,23 +119,23 @@ def apply_str_criterion_to_column(
for value in criterion.values: for value in criterion.values:
expr = expr | column.ilike(transformer(value)) expr = expr | column.ilike(transformer(value))
elif isinstance(criterion, criteria.RangedCriterion): elif isinstance(criterion, criteria.RangedCriterion):
expr = column.ilike(transformer(criterion.original_text)) raise errors.SearchError(
'Ranged criterion is invalid in this context. '
'Did you forget to escape the dots?')
else: else:
assert False assert False
return expr return expr
def create_str_filter( def create_str_filter(
column: SaColumn, column: SaColumn, transformer: Callable[[str], str]=wildcard_transformer
transformer: Callable[[str], str]=wildcard_transformer
) -> Filter: ) -> Filter:
def wrapper( def wrapper(
query: SaQuery, query: SaQuery,
criterion: Optional[criteria.BaseCriterion], criterion: Optional[criteria.BaseCriterion],
negated: bool) -> SaQuery: negated: bool) -> SaQuery:
assert criterion assert criterion
expr = apply_str_criterion_to_column( expr = apply_str_criterion_to_column(column, criterion, transformer)
column, criterion, transformer)
if negated: if negated:
expr = ~expr expr = ~expr
return query.filter(expr) return query.filter(expr)

View file

@ -1,17 +1,20 @@
import re import re
from typing import List from typing import Match, List
from szurubooru import errors from szurubooru import errors
from szurubooru.search import criteria, tokens from szurubooru.search import criteria, tokens
from szurubooru.search.query import SearchQuery from szurubooru.search.query import SearchQuery
from szurubooru.search.configs import util
def _create_criterion( def _create_criterion(
original_value: str, value: str) -> criteria.BaseCriterion: original_value: str, value: str) -> criteria.BaseCriterion:
if ',' in value: if re.search(r'(?<!\\),', value):
return criteria.ArrayCriterion( values = re.split(r'(?<!\\),', value)
original_value, value.split(',')) if any(not term.strip() for term in values):
if '..' in value: raise errors.SearchError('Empty compound value')
low, high = value.split('..', 1) return criteria.ArrayCriterion(original_value, values)
if re.search(r'(?<!\\)\.(?<!\\)\.', value):
low, high = re.split(r'(?<!\\)\.(?<!\\)\.', value, 1)
if not low and not high: if not low and not high:
raise errors.SearchError('Empty ranged value') raise errors.SearchError('Empty ranged value')
return criteria.RangedCriterion(original_value, low, high) return criteria.RangedCriterion(original_value, low, high)
@ -82,9 +85,10 @@ class Parser:
negated = True negated = True
if not chunk: if not chunk:
raise errors.SearchError('Empty negated token.') raise errors.SearchError('Empty negated token.')
match = re.match('([a-z_-]+):(.*)', chunk) match = re.match(r'^(.*?)(?<!\\):(.*)$', chunk)
if match: if match:
key, value = list(match.groups()) key, value = list(match.groups())
key = util.unescape(key)
if key == 'sort': if key == 'sort':
query.sort_tokens.append( query.sort_tokens.append(
_parse_sort(value, negated)) _parse_sort(value, negated))

View file

@ -35,10 +35,77 @@ def test_filter_anonymous(
verify_unpaged(input, expected_tag_names) verify_unpaged(input, expected_tag_names)
@pytest.mark.parametrize('db_driver,input,expected_tag_names', [
(None, ',', None),
(None, 't1,', None),
(None, 't1,t2', ['t1', 't2']),
(None, 't1\\,', []),
(None, 'asd..asd', None),
(None, 'asd\\..asd', []),
(None, 'asd.\\.asd', []),
(None, 'asd\\.\\.asd', []),
(None, '-', None),
(None, '\\-', ['-']),
(None, '--', [
't1', 't2', '*', '*asd*', ':', 'asd:asd', '\\', '\\asd', '-asd',
]),
(None, '\\--', []),
(None, '-\\-', [
't1', 't2', '*', '*asd*', ':', 'asd:asd', '\\', '\\asd', '-asd',
]),
(None, '-*', []),
(None, '\\-*', ['-', '-asd']),
(None, ':', None),
(None, '\\:', [':']),
(None, '\\:asd', []),
(None, '*\\:*', [':', 'asd:asd']),
(None, 'asd:asd', None),
(None, 'asd\\:asd', ['asd:asd']),
(None, '*', [
't1', 't2', '*', '*asd*', ':', 'asd:asd', '\\', '\\asd', '-', '-asd'
]),
(None, '\\*', ['*']),
(None, '\\', None),
(None, '\\asd', None),
('psycopg2', '\\\\', ['\\']),
('psycopg2', '\\\\asd', ['\\asd']),
])
def test_escaping(
executor, tag_factory, input, expected_tag_names, db_driver):
db.session.add_all([
tag_factory(names=['t1']),
tag_factory(names=['t2']),
tag_factory(names=['*']),
tag_factory(names=['*asd*']),
tag_factory(names=[':']),
tag_factory(names=['asd:asd']),
tag_factory(names=['\\']),
tag_factory(names=['\\asd']),
tag_factory(names=['-']),
tag_factory(names=['-asd'])
])
db.session.flush()
if db_driver:
if db.sessionmaker.kw['bind'].driver != db_driver:
pytest.xfail()
if expected_tag_names is None:
with pytest.raises(errors.SearchError):
executor.execute(input, offset=0, limit=100)
else:
actual_count, actual_tags = executor.execute(
input, offset=0, limit=100)
actual_tag_names = [u.names[0].name for u in actual_tags]
assert actual_count == len(expected_tag_names)
assert sorted(actual_tag_names) == sorted(expected_tag_names)
def test_filter_anonymous_starting_with_colon(verify_unpaged, tag_factory): def test_filter_anonymous_starting_with_colon(verify_unpaged, tag_factory):
db.session.add(tag_factory(names=[':t'])) db.session.add(tag_factory(names=[':t']))
db.session.flush() db.session.flush()
with pytest.raises(errors.SearchError):
verify_unpaged(':t', [':t']) verify_unpaged(':t', [':t'])
verify_unpaged('\\:t', [':t'])
@pytest.mark.parametrize('input,expected_tag_names', [ @pytest.mark.parametrize('input,expected_tag_names', [

View file

@ -86,12 +86,24 @@ def test_filter_by_name(
@pytest.mark.parametrize('input,expected_user_names', [ @pytest.mark.parametrize('input,expected_user_names', [
('name:u1', ['u1']), ('name:u1', ['u1']),
('name:u2..', ['u2..']),
('name:u2*', ['u2..']), ('name:u2*', ['u2..']),
('name:*..*', ['u2..', 'u3..x']),
('name:u3..x', ['u3..x']),
('name:*..x', ['u3..x']),
('name:u1,u3..x', ['u1', 'u3..x']), ('name:u1,u3..x', ['u1', 'u3..x']),
('name:u2..', None),
('name:*..*', None),
('name:u3..x', None),
('name:*..x', None),
('name:u2\\..', ['u2..']),
('name:*\\..*', ['u2..', 'u3..x']),
('name:u3\\..x', ['u3..x']),
('name:*\\..x', ['u3..x']),
('name:u2.\\.', ['u2..']),
('name:*.\\.*', ['u2..', 'u3..x']),
('name:u3.\\.x', ['u3..x']),
('name:*.\\.x', ['u3..x']),
('name:u2\\.\\.', ['u2..']),
('name:*\\.\\.*', ['u2..', 'u3..x']),
('name:u3\\.\\.x', ['u3..x']),
('name:*\\.\\.x', ['u3..x']),
]) ])
def test_filter_by_name_that_looks_like_range( def test_filter_by_name_that_looks_like_range(
verify_unpaged, input, expected_user_names, user_factory): verify_unpaged, input, expected_user_names, user_factory):
@ -99,6 +111,10 @@ def test_filter_by_name_that_looks_like_range(
db.session.add(user_factory(name='u2..')) db.session.add(user_factory(name='u2..'))
db.session.add(user_factory(name='u3..x')) db.session.add(user_factory(name='u3..x'))
db.session.flush() db.session.flush()
if not expected_user_names:
with pytest.raises(errors.SearchError):
verify_unpaged(input, expected_user_names)
else:
verify_unpaged(input, expected_user_names) verify_unpaged(input, expected_user_names)