server: lint

This commit is contained in:
rr- 2017-04-24 23:30:53 +02:00
parent fea9a94945
commit 4bc58a3c95
42 changed files with 192 additions and 169 deletions

View file

@ -5,10 +5,10 @@ from hashlib import md5
MAIL_SUBJECT = 'Password reset for {name}' MAIL_SUBJECT = 'Password reset for {name}'
MAIL_BODY = \ MAIL_BODY = (
'You (or someone else) requested to reset your password on {name}.\n' \ 'You (or someone else) requested to reset your password on {name}.\n'
'If you wish to proceed, click this link: {url}\n' \ 'If you wish to proceed, click this link: {url}\n'
'Otherwise, please ignore this email.' 'Otherwise, please ignore this email.')
@rest.routes.get('/password-reset/(?P<user_name>[^/]+)/?') @rest.routes.get('/password-reset/(?P<user_name>[^/]+)/?')

View file

@ -35,7 +35,8 @@ def get_tags(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
@rest.routes.post('/tags/?') @rest.routes.post('/tags/?')
def create_tag(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: def create_tag(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tags:create') auth.verify_privilege(ctx.user, 'tags:create')
names = ctx.get_param_as_string_list('names') names = ctx.get_param_as_string_list('names')

View file

@ -16,7 +16,8 @@ def _serialize(
@rest.routes.get('/users/?') @rest.routes.get('/users/?')
def get_users(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: def get_users(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'users:list') auth.verify_privilege(ctx.user, 'users:list')
return _search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, lambda user: _serialize(ctx, user)) ctx, lambda user: _serialize(ctx, user))

View file

@ -21,9 +21,9 @@ class LruCache:
i i
for i, v in enumerate(self.item_list) for i, v in enumerate(self.item_list)
if v.key == item.key) if v.key == item.key)
self.item_list[:] \ self.item_list[:] = (
= self.item_list[:item_index] \ self.item_list[:item_index] +
+ self.item_list[item_index + 1:] self.item_list[item_index + 1:])
self.item_list.insert(0, item) self.item_list.insert(0, item)
else: else:
if len(self.item_list) > self.length: if len(self.item_list) > self.length:

View file

@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
from typing import Any, Optional, List, Dict, Callable from typing import Any, Optional, List, Dict, Callable
from szurubooru import db, model, errors, rest from szurubooru import db, model, errors, rest
from szurubooru.func import users, scores, util, serialization from szurubooru.func import users, scores, serialization
class InvalidCommentIdError(errors.ValidationError): class InvalidCommentIdError(errors.ValidationError):
@ -73,10 +73,11 @@ def serialize_comment(
def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]: def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]:
comment_id = int(comment_id) comment_id = int(comment_id)
return db.session \ return (
.query(model.Comment) \ db.session
.filter(model.Comment.comment_id == comment_id) \ .query(model.Comment)
.one_or_none() .filter(model.Comment.comment_id == comment_id)
.one_or_none())
def get_comment_by_id(comment_id: int) -> model.Comment: def get_comment_by_id(comment_id: int) -> model.Comment:

View file

@ -11,8 +11,8 @@ from szurubooru.func import mime, util
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_SCALE_FIT_FMT = \ _SCALE_FIT_FMT = (
r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)' r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)')
class Image: class Image:

View file

@ -7,10 +7,10 @@ from szurubooru.func import (
mime, images, files, image_hash, serialization) mime, images, files, image_hash, serialization)
EMPTY_PIXEL = \ EMPTY_PIXEL = (
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \ b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \ b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b')
class PostNotFoundError(errors.NotFoundError): class PostNotFoundError(errors.NotFoundError):
@ -300,10 +300,11 @@ def get_post_count() -> int:
def try_get_post_by_id(post_id: int) -> Optional[model.Post]: def try_get_post_by_id(post_id: int) -> Optional[model.Post]:
return db.session \ return (
.query(model.Post) \ db.session
.filter(model.Post.post_id == post_id) \ .query(model.Post)
.one_or_none() .filter(model.Post.post_id == post_id)
.one_or_none())
def get_post_by_id(post_id: int) -> model.Post: def get_post_by_id(post_id: int) -> model.Post:
@ -314,10 +315,11 @@ def get_post_by_id(post_id: int) -> model.Post:
def try_get_current_post_feature() -> Optional[model.PostFeature]: def try_get_current_post_feature() -> Optional[model.PostFeature]:
return db.session \ return (
.query(model.PostFeature) \ db.session
.order_by(model.PostFeature.time.desc()) \ .query(model.PostFeature)
.first() .order_by(model.PostFeature.time.desc())
.first())
def try_get_featured_post() -> Optional[model.Post]: def try_get_featured_post() -> Optional[model.Post]:
@ -426,11 +428,12 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
'Unhandled file type: %r' % post.mime_type) 'Unhandled file type: %r' % post.mime_type)
post.checksum = util.get_sha1(content) post.checksum = util.get_sha1(content)
other_post = db.session \ other_post = (
.query(model.Post) \ db.session
.filter(model.Post.checksum == post.checksum) \ .query(model.Post)
.filter(model.Post.post_id != post.post_id) \ .filter(model.Post.checksum == post.checksum)
.one_or_none() .filter(model.Post.post_id != post.post_id)
.one_or_none())
if other_post \ if other_post \
and other_post.post_id \ and other_post.post_id \
and other_post.post_id != post.post_id: and other_post.post_id != post.post_id:
@ -492,10 +495,11 @@ def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None:
old_posts = post.relations old_posts = post.relations
old_post_ids = [int(p.post_id) for p in old_posts] old_post_ids = [int(p.post_id) for p in old_posts]
if new_post_ids: if new_post_ids:
new_posts = db.session \ new_posts = (
.query(model.Post) \ db.session
.filter(model.Post.post_id.in_(new_post_ids)) \ .query(model.Post)
.all() .filter(model.Post.post_id.in_(new_post_ids))
.all())
else: else:
new_posts = [] new_posts = []
if len(new_posts) != len(new_post_ids): if len(new_posts) != len(new_post_ids):
@ -673,10 +677,11 @@ def merge_posts(
def search_by_image_exact(image_content: bytes) -> Optional[model.Post]: def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
checksum = util.get_sha1(image_content) checksum = util.get_sha1(image_content)
return db.session \ return (
.query(model.Post) \ db.session
.filter(model.Post.checksum == checksum) \ .query(model.Post)
.one_or_none() .filter(model.Post.checksum == checksum)
.one_or_none())
def search_by_image(image_content: bytes) -> List[PostLookalike]: def search_by_image(image_content: bytes) -> List[PostLookalike]:

View file

@ -39,11 +39,12 @@ def get_score(entity: model.Base, user: model.User) -> int:
assert entity assert entity
assert user assert user
table, get_column = _get_table_info(entity) table, get_column = _get_table_info(entity)
row = db.session \ row = (
.query(table.score) \ db.session
.filter(get_column(table) == get_column(entity)) \ .query(table.score)
.filter(table.user_id == user.user_id) \ .filter(get_column(table) == get_column(entity))
.one_or_none() .filter(table.user_id == user.user_id)
.one_or_none())
return row[0] if row else 0 return row[0] if row else 0

View file

@ -1,5 +1,5 @@
from typing import Any, Optional, List, Dict, Callable from typing import Any, List, Dict, Callable
from szurubooru import db, model, rest, errors from szurubooru import model, rest, errors
def get_serialization_options(ctx: rest.Context) -> List[str]: def get_serialization_options(ctx: rest.Context) -> List[str]:

View file

@ -114,9 +114,10 @@ def update_category_color(category: model.TagCategory, color: str) -> None:
def try_get_category_by_name( def try_get_category_by_name(
name: str, lock: bool = False) -> Optional[model.TagCategory]: name: str, lock: bool = False) -> Optional[model.TagCategory]:
query = db.session \ query = (
.query(model.TagCategory) \ db.session
.filter(sa.func.lower(model.TagCategory.name) == name.lower()) .query(model.TagCategory)
.filter(sa.func.lower(model.TagCategory.name) == name.lower()))
if lock: if lock:
query = query.with_lockmode('update') query = query.with_lockmode('update')
return query.one_or_none() return query.one_or_none()
@ -137,19 +138,22 @@ def get_all_categories() -> List[model.TagCategory]:
return db.session.query(model.TagCategory).all() return db.session.query(model.TagCategory).all()
def try_get_default_category(lock: bool=False) -> Optional[model.TagCategory]: def try_get_default_category(
query = db.session \ lock: bool = False) -> Optional[model.TagCategory]:
.query(model.TagCategory) \ query = (
.filter(model.TagCategory.default) db.session
.query(model.TagCategory)
.filter(model.TagCategory.default))
if lock: if lock:
query = query.with_lockmode('update') query = query.with_lockmode('update')
category = query.first() category = query.first()
# if for some reason (e.g. as a result of migration) there's no default # if for some reason (e.g. as a result of migration) there's no default
# category, get the first record available. # category, get the first record available.
if not category: if not category:
query = db.session \ query = (
.query(model.TagCategory) \ db.session
.order_by(model.TagCategory.tag_category_id.asc()) .query(model.TagCategory)
.order_by(model.TagCategory.tag_category_id.asc()))
if lock: if lock:
query = query.with_lockmode('update') query = query.with_lockmode('update')
category = query.first() category = query.first()

View file

@ -209,7 +209,8 @@ def get_tags_by_names(names: List[str]) -> List[model.Tag]:
names = util.icase_unique(names) names = util.icase_unique(names)
if len(names) == 0: if len(names) == 0:
return [] return []
return (db.session.query(model.Tag) return (
db.session.query(model.Tag)
.join(model.TagName) .join(model.TagName)
.filter( .filter(
sa.sql.or_( sa.sql.or_(

View file

@ -170,10 +170,11 @@ def get_user_count() -> int:
def try_get_user_by_name(name: str) -> Optional[model.User]: def try_get_user_by_name(name: str) -> Optional[model.User]:
return db.session \ return (
.query(model.User) \ db.session
.filter(sa.func.lower(model.User.name) == sa.func.lower(name)) \ .query(model.User)
.one_or_none() .filter(sa.func.lower(model.User.name) == sa.func.lower(name))
.one_or_none())
def get_user_by_name(name: str) -> model.User: def get_user_by_name(name: str) -> model.User:

View file

@ -2,8 +2,7 @@ import os
import hashlib import hashlib
import re import re
import tempfile import tempfile
from typing import ( from typing import Any, Optional, Union, Tuple, List, Dict, Generator, TypeVar
Any, Optional, Union, Tuple, List, Dict, Generator, Callable, TypeVar)
from datetime import datetime, timedelta from datetime import datetime, timedelta
from contextlib import contextmanager from contextlib import contextmanager
from szurubooru import errors from szurubooru import errors

View file

@ -27,8 +27,9 @@ def _get_user(ctx: rest.Context) -> Optional[model.User]:
credentials.encode('ascii')).decode('utf8').split(':') credentials.encode('ascii')).decode('utf8').split(':')
return _authenticate(username, password) return _authenticate(username, password)
except ValueError as err: except ValueError as err:
msg = 'Basic authentication header value are not properly formed. ' \ msg = (
+ 'Supplied header {0}. Got error: {1}' 'Basic authentication header value are not properly formed. '
'Supplied header {0}. Got error: {1}')
raise HttpBadRequest( raise HttpBadRequest(
'ValidationError', 'ValidationError',
msg.format(ctx.get_header('Authorization'), str(err))) msg.format(ctx.get_header('Authorization'), str(err)))

View file

@ -50,13 +50,14 @@ def upgrade():
def downgrade(): def downgrade():
session = sa.orm.session.Session(bind=op.get_bind()) session = sa.orm.session.Session(bind=op.get_bind())
default_category = session \ default_category = (
.query(TagCategory) \ session
.filter(TagCategory.name == 'default') \ .query(TagCategory)
.filter(TagCategory.color == 'default') \ .filter(TagCategory.name == 'default')
.filter(TagCategory.version == 1) \ .filter(TagCategory.color == 'default')
.filter(TagCategory.default == True) \ .filter(TagCategory.version == 1)
.one_or_none() .filter(TagCategory.default == 1)
.one_or_none())
if default_category: if default_category:
session.delete(default_category) session.delete(default_category)
session.commit() session.commit()

View file

@ -211,10 +211,11 @@ class Post(Base):
@property @property
def is_featured(self) -> bool: def is_featured(self) -> bool:
featured_post = sa.orm.object_session(self) \ featured_post = (
.query(PostFeature) \ sa.orm.object_session(self)
.order_by(PostFeature.time.desc()) \ .query(PostFeature)
.first() .order_by(PostFeature.time.desc())
.first())
return featured_post and featured_post.post_id == self.post_id return featured_post and featured_post.post_id == self.post_id
score = sa.orm.column_property( score = sa.orm.column_property(

View file

@ -1,4 +1,4 @@
from typing import Callable, Type, Dict from typing import Optional, Callable, Type, Dict
error_handlers = {} # pylint: disable=invalid-name error_handlers = {} # pylint: disable=invalid-name
@ -12,8 +12,8 @@ class BaseHttpError(RuntimeError):
self, self,
name: str, name: str,
description: str, description: str,
title: str=None, title: Optional[str] = None,
extra_fields: Dict[str, str]=None) -> None: extra_fields: Optional[Dict[str, str]] = None) -> None:
super().__init__() super().__init__()
# error name for programmers # error name for programmers
self.name = name self.name = name

View file

@ -1,4 +1,4 @@
from typing import Callable from typing import List, Callable
from szurubooru.rest.context import Context from szurubooru.rest.context import Context

View file

@ -1,4 +1,4 @@
from typing import Callable, Dict, Any from typing import Callable, Dict
from collections import defaultdict from collections import defaultdict
from szurubooru.rest.context import Context, Response from szurubooru.rest.context import Context, Response

View file

@ -52,10 +52,11 @@ def _create_score_filter(score: int) -> Filter:
user_alias.name, criterion) user_alias.name, criterion)
if negated: if negated:
expr = ~expr expr = ~expr
ret = query \ ret = (
.join(score_alias, score_alias.post_id == model.Post.post_id) \ query
.join(user_alias, user_alias.user_id == score_alias.user_id) \ .join(score_alias, score_alias.post_id == model.Post.post_id)
.filter(expr) .join(user_alias, user_alias.user_id == score_alias.user_id)
.filter(expr))
return ret return ret
return wrapper return wrapper
@ -124,7 +125,8 @@ class PostSearchConfig(BaseSearchConfig):
sa.orm.lazyload sa.orm.lazyload
if disable_eager_loads if disable_eager_loads
else sa.orm.subqueryload) else sa.orm.subqueryload)
return db.session.query(model.Post) \ return (
db.session.query(model.Post)
.options( .options(
sa.orm.lazyload('*'), sa.orm.lazyload('*'),
# use config optimized for official client # use config optimized for official client
@ -141,7 +143,7 @@ class PostSearchConfig(BaseSearchConfig):
strategy(model.Post.tags).subqueryload(model.Tag.names), strategy(model.Post.tags).subqueryload(model.Tag.names),
strategy(model.Post.tags).defer(model.Tag.post_count), strategy(model.Post.tags).defer(model.Tag.post_count),
strategy(model.Post.tags).lazyload(model.Tag.implications), strategy(model.Post.tags).lazyload(model.Tag.implications),
strategy(model.Post.tags).lazyload(model.Tag.suggestions)) strategy(model.Post.tags).lazyload(model.Tag.suggestions)))
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery: def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(model.Post) return db.session.query(model.Post)

View file

@ -14,8 +14,9 @@ class TagSearchConfig(BaseSearchConfig):
sa.orm.lazyload sa.orm.lazyload
if _disable_eager_loads if _disable_eager_loads
else sa.orm.subqueryload) else sa.orm.subqueryload)
return db.session.query(model.Tag) \ return (
.join(model.TagCategory) \ db.session.query(model.Tag)
.join(model.TagCategory)
.options( .options(
sa.orm.defer(model.Tag.first_name), sa.orm.defer(model.Tag.first_name),
sa.orm.defer(model.Tag.suggestion_count), sa.orm.defer(model.Tag.suggestion_count),
@ -23,7 +24,7 @@ class TagSearchConfig(BaseSearchConfig):
sa.orm.defer(model.Tag.post_count), sa.orm.defer(model.Tag.post_count),
strategy(model.Tag.names), strategy(model.Tag.names),
strategy(model.Tag.suggestions).joinedload(model.Tag.names), strategy(model.Tag.suggestions).joinedload(model.Tag.names),
strategy(model.Tag.implications).joinedload(model.Tag.names)) strategy(model.Tag.implications).joinedload(model.Tag.names)))
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery: def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(model.Tag) return db.session.query(model.Tag)

View file

@ -128,8 +128,8 @@ def apply_str_criterion_to_column(
def create_str_filter( def create_str_filter(
column: SaColumn, transformer: Callable[[str], str]=wildcard_transformer column: SaColumn,
) -> Filter: transformer: Callable[[str], str] = wildcard_transformer) -> Filter:
def wrapper( def wrapper(
query: SaQuery, query: SaQuery,
criterion: Optional[criteria.BaseCriterion], criterion: Optional[criteria.BaseCriterion],

View file

@ -1,4 +1,4 @@
from typing import Optional, List, Callable from typing import Optional, List
from szurubooru.search.typing import SaQuery from szurubooru.search.typing import SaQuery

View file

@ -100,18 +100,20 @@ class Executor:
filter_query = self.config.create_filter_query(disable_eager_loads) filter_query = self.config.create_filter_query(disable_eager_loads)
filter_query = filter_query.options(sa.orm.lazyload('*')) filter_query = filter_query.options(sa.orm.lazyload('*'))
filter_query = self._prepare_db_query(filter_query, search_query, True) filter_query = self._prepare_db_query(filter_query, search_query, True)
entities = filter_query \ entities = (
.offset(offset) \ filter_query
.limit(limit) \ .offset(offset)
.all() .limit(limit)
.all())
count_query = self.config.create_count_query(disable_eager_loads) count_query = self.config.create_count_query(disable_eager_loads)
count_query = count_query.options(sa.orm.lazyload('*')) count_query = count_query.options(sa.orm.lazyload('*'))
count_query = self._prepare_db_query(count_query, search_query, False) count_query = self._prepare_db_query(count_query, search_query, False)
count_statement = count_query \ count_statement = (
.statement \ count_query
.with_only_columns([sa.func.count()]) \ .statement
.order_by(None) .with_only_columns([sa.func.count()])
.order_by(None))
count = db.session.execute(count_statement).scalar() count = db.session.execute(count_statement).scalar()
ret = (count, entities) ret = (count, entities)

View file

@ -1,5 +1,4 @@
import re import re
from typing import Match, List
from szurubooru import errors from szurubooru import errors
from szurubooru.search import criteria, tokens from szurubooru.search import criteria, tokens
from szurubooru.search.query import SearchQuery from szurubooru.search.query import SearchQuery

View file

@ -1,4 +1,5 @@
from szurubooru.search import tokens from szurubooru.search import tokens
from typing import List
class SearchQuery: class SearchQuery:

View file

@ -1,6 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, model, errors from szurubooru import api, model, errors
from szurubooru.func import tags, snapshots from szurubooru.func import tags, snapshots

View file

@ -1,6 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, model, errors from szurubooru import api, model, errors
from szurubooru.func import users from szurubooru.func import users

View file

@ -5,10 +5,10 @@ from szurubooru import db, model, errors
from szurubooru.func import auth, users, files, util from szurubooru.func import auth, users, files, util
EMPTY_PIXEL = \ EMPTY_PIXEL = (
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \ b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \ b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b')
@pytest.mark.parametrize('user_name', ['test', 'TEST']) @pytest.mark.parametrize('user_name', ['test', 'TEST'])

View file

@ -28,11 +28,12 @@ def test_saving_tag(tag_factory):
tag.implications.append(imp2) tag.implications.append(imp2)
db.session.commit() db.session.commit()
tag = db.session \ tag = (
.query(model.Tag) \ db.session
.join(model.TagName) \ .query(model.Tag)
.filter(model.TagName.name == 'alias1') \ .join(model.TagName)
.one() .filter(model.TagName.name == 'alias1')
.one())
assert [tag_name.name for tag_name in tag.names] == ['alias1', 'alias2'] assert [tag_name.name for tag_name in tag.names] == ['alias1', 'alias2']
assert tag.category.name == 'category' assert tag.category.name == 'category'
assert tag.creation_time == datetime(1997, 1, 1) assert tag.creation_time == datetime(1997, 1, 1)

View file

@ -300,7 +300,7 @@ def test_filter_by_note_count(
('note-text:text3*', [3]), ('note-text:text3*', [3]),
('note-text:text3a,text2', [2, 3]), ('note-text:text3a,text2', [2, 3]),
]) ])
def test_filter_by_note_count( def test_filter_by_note_text(
verify_unpaged, post_factory, note_factory, input, expected_post_ids): verify_unpaged, post_factory, note_factory, input, expected_post_ids):
post1 = post_factory(id=1) post1 = post_factory(id=1)
post2 = post_factory(id=2) post2 = post_factory(id=2)