server: lint
This commit is contained in:
parent
fea9a94945
commit
4bc58a3c95
42 changed files with 192 additions and 169 deletions
|
@ -5,10 +5,10 @@ from hashlib import md5
|
||||||
|
|
||||||
|
|
||||||
MAIL_SUBJECT = 'Password reset for {name}'
|
MAIL_SUBJECT = 'Password reset for {name}'
|
||||||
MAIL_BODY = \
|
MAIL_BODY = (
|
||||||
'You (or someone else) requested to reset your password on {name}.\n' \
|
'You (or someone else) requested to reset your password on {name}.\n'
|
||||||
'If you wish to proceed, click this link: {url}\n' \
|
'If you wish to proceed, click this link: {url}\n'
|
||||||
'Otherwise, please ignore this email.'
|
'Otherwise, please ignore this email.')
|
||||||
|
|
||||||
|
|
||||||
@rest.routes.get('/password-reset/(?P<user_name>[^/]+)/?')
|
@rest.routes.get('/password-reset/(?P<user_name>[^/]+)/?')
|
||||||
|
|
|
@ -35,7 +35,8 @@ def get_tags(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
|
|
||||||
|
|
||||||
@rest.routes.post('/tags/?')
|
@rest.routes.post('/tags/?')
|
||||||
def create_tag(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
def create_tag(
|
||||||
|
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'tags:create')
|
auth.verify_privilege(ctx.user, 'tags:create')
|
||||||
|
|
||||||
names = ctx.get_param_as_string_list('names')
|
names = ctx.get_param_as_string_list('names')
|
||||||
|
|
|
@ -16,7 +16,8 @@ def _serialize(
|
||||||
|
|
||||||
|
|
||||||
@rest.routes.get('/users/?')
|
@rest.routes.get('/users/?')
|
||||||
def get_users(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
def get_users(
|
||||||
|
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'users:list')
|
auth.verify_privilege(ctx.user, 'users:list')
|
||||||
return _search_executor.execute_and_serialize(
|
return _search_executor.execute_and_serialize(
|
||||||
ctx, lambda user: _serialize(ctx, user))
|
ctx, lambda user: _serialize(ctx, user))
|
||||||
|
|
|
@ -21,9 +21,9 @@ class LruCache:
|
||||||
i
|
i
|
||||||
for i, v in enumerate(self.item_list)
|
for i, v in enumerate(self.item_list)
|
||||||
if v.key == item.key)
|
if v.key == item.key)
|
||||||
self.item_list[:] \
|
self.item_list[:] = (
|
||||||
= self.item_list[:item_index] \
|
self.item_list[:item_index] +
|
||||||
+ self.item_list[item_index + 1:]
|
self.item_list[item_index + 1:])
|
||||||
self.item_list.insert(0, item)
|
self.item_list.insert(0, item)
|
||||||
else:
|
else:
|
||||||
if len(self.item_list) > self.length:
|
if len(self.item_list) > self.length:
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional, List, Dict, Callable
|
from typing import Any, Optional, List, Dict, Callable
|
||||||
from szurubooru import db, model, errors, rest
|
from szurubooru import db, model, errors, rest
|
||||||
from szurubooru.func import users, scores, util, serialization
|
from szurubooru.func import users, scores, serialization
|
||||||
|
|
||||||
|
|
||||||
class InvalidCommentIdError(errors.ValidationError):
|
class InvalidCommentIdError(errors.ValidationError):
|
||||||
|
@ -73,10 +73,11 @@ def serialize_comment(
|
||||||
|
|
||||||
def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]:
|
def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]:
|
||||||
comment_id = int(comment_id)
|
comment_id = int(comment_id)
|
||||||
return db.session \
|
return (
|
||||||
.query(model.Comment) \
|
db.session
|
||||||
.filter(model.Comment.comment_id == comment_id) \
|
.query(model.Comment)
|
||||||
.one_or_none()
|
.filter(model.Comment.comment_id == comment_id)
|
||||||
|
.one_or_none())
|
||||||
|
|
||||||
|
|
||||||
def get_comment_by_id(comment_id: int) -> model.Comment:
|
def get_comment_by_id(comment_id: int) -> model.Comment:
|
||||||
|
|
|
@ -11,8 +11,8 @@ from szurubooru.func import mime, util
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
_SCALE_FIT_FMT = \
|
_SCALE_FIT_FMT = (
|
||||||
r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)'
|
r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)')
|
||||||
|
|
||||||
|
|
||||||
class Image:
|
class Image:
|
||||||
|
|
|
@ -7,10 +7,10 @@ from szurubooru.func import (
|
||||||
mime, images, files, image_hash, serialization)
|
mime, images, files, image_hash, serialization)
|
||||||
|
|
||||||
|
|
||||||
EMPTY_PIXEL = \
|
EMPTY_PIXEL = (
|
||||||
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \
|
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
|
||||||
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
|
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
|
||||||
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
|
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b')
|
||||||
|
|
||||||
|
|
||||||
class PostNotFoundError(errors.NotFoundError):
|
class PostNotFoundError(errors.NotFoundError):
|
||||||
|
@ -300,10 +300,11 @@ def get_post_count() -> int:
|
||||||
|
|
||||||
|
|
||||||
def try_get_post_by_id(post_id: int) -> Optional[model.Post]:
|
def try_get_post_by_id(post_id: int) -> Optional[model.Post]:
|
||||||
return db.session \
|
return (
|
||||||
.query(model.Post) \
|
db.session
|
||||||
.filter(model.Post.post_id == post_id) \
|
.query(model.Post)
|
||||||
.one_or_none()
|
.filter(model.Post.post_id == post_id)
|
||||||
|
.one_or_none())
|
||||||
|
|
||||||
|
|
||||||
def get_post_by_id(post_id: int) -> model.Post:
|
def get_post_by_id(post_id: int) -> model.Post:
|
||||||
|
@ -314,10 +315,11 @@ def get_post_by_id(post_id: int) -> model.Post:
|
||||||
|
|
||||||
|
|
||||||
def try_get_current_post_feature() -> Optional[model.PostFeature]:
|
def try_get_current_post_feature() -> Optional[model.PostFeature]:
|
||||||
return db.session \
|
return (
|
||||||
.query(model.PostFeature) \
|
db.session
|
||||||
.order_by(model.PostFeature.time.desc()) \
|
.query(model.PostFeature)
|
||||||
.first()
|
.order_by(model.PostFeature.time.desc())
|
||||||
|
.first())
|
||||||
|
|
||||||
|
|
||||||
def try_get_featured_post() -> Optional[model.Post]:
|
def try_get_featured_post() -> Optional[model.Post]:
|
||||||
|
@ -426,11 +428,12 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
|
||||||
'Unhandled file type: %r' % post.mime_type)
|
'Unhandled file type: %r' % post.mime_type)
|
||||||
|
|
||||||
post.checksum = util.get_sha1(content)
|
post.checksum = util.get_sha1(content)
|
||||||
other_post = db.session \
|
other_post = (
|
||||||
.query(model.Post) \
|
db.session
|
||||||
.filter(model.Post.checksum == post.checksum) \
|
.query(model.Post)
|
||||||
.filter(model.Post.post_id != post.post_id) \
|
.filter(model.Post.checksum == post.checksum)
|
||||||
.one_or_none()
|
.filter(model.Post.post_id != post.post_id)
|
||||||
|
.one_or_none())
|
||||||
if other_post \
|
if other_post \
|
||||||
and other_post.post_id \
|
and other_post.post_id \
|
||||||
and other_post.post_id != post.post_id:
|
and other_post.post_id != post.post_id:
|
||||||
|
@ -492,10 +495,11 @@ def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None:
|
||||||
old_posts = post.relations
|
old_posts = post.relations
|
||||||
old_post_ids = [int(p.post_id) for p in old_posts]
|
old_post_ids = [int(p.post_id) for p in old_posts]
|
||||||
if new_post_ids:
|
if new_post_ids:
|
||||||
new_posts = db.session \
|
new_posts = (
|
||||||
.query(model.Post) \
|
db.session
|
||||||
.filter(model.Post.post_id.in_(new_post_ids)) \
|
.query(model.Post)
|
||||||
.all()
|
.filter(model.Post.post_id.in_(new_post_ids))
|
||||||
|
.all())
|
||||||
else:
|
else:
|
||||||
new_posts = []
|
new_posts = []
|
||||||
if len(new_posts) != len(new_post_ids):
|
if len(new_posts) != len(new_post_ids):
|
||||||
|
@ -673,10 +677,11 @@ def merge_posts(
|
||||||
|
|
||||||
def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
|
def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
|
||||||
checksum = util.get_sha1(image_content)
|
checksum = util.get_sha1(image_content)
|
||||||
return db.session \
|
return (
|
||||||
.query(model.Post) \
|
db.session
|
||||||
.filter(model.Post.checksum == checksum) \
|
.query(model.Post)
|
||||||
.one_or_none()
|
.filter(model.Post.checksum == checksum)
|
||||||
|
.one_or_none())
|
||||||
|
|
||||||
|
|
||||||
def search_by_image(image_content: bytes) -> List[PostLookalike]:
|
def search_by_image(image_content: bytes) -> List[PostLookalike]:
|
||||||
|
|
|
@ -39,11 +39,12 @@ def get_score(entity: model.Base, user: model.User) -> int:
|
||||||
assert entity
|
assert entity
|
||||||
assert user
|
assert user
|
||||||
table, get_column = _get_table_info(entity)
|
table, get_column = _get_table_info(entity)
|
||||||
row = db.session \
|
row = (
|
||||||
.query(table.score) \
|
db.session
|
||||||
.filter(get_column(table) == get_column(entity)) \
|
.query(table.score)
|
||||||
.filter(table.user_id == user.user_id) \
|
.filter(get_column(table) == get_column(entity))
|
||||||
.one_or_none()
|
.filter(table.user_id == user.user_id)
|
||||||
|
.one_or_none())
|
||||||
return row[0] if row else 0
|
return row[0] if row else 0
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from typing import Any, Optional, List, Dict, Callable
|
from typing import Any, List, Dict, Callable
|
||||||
from szurubooru import db, model, rest, errors
|
from szurubooru import model, rest, errors
|
||||||
|
|
||||||
|
|
||||||
def get_serialization_options(ctx: rest.Context) -> List[str]:
|
def get_serialization_options(ctx: rest.Context) -> List[str]:
|
||||||
|
|
|
@ -114,9 +114,10 @@ def update_category_color(category: model.TagCategory, color: str) -> None:
|
||||||
|
|
||||||
def try_get_category_by_name(
|
def try_get_category_by_name(
|
||||||
name: str, lock: bool = False) -> Optional[model.TagCategory]:
|
name: str, lock: bool = False) -> Optional[model.TagCategory]:
|
||||||
query = db.session \
|
query = (
|
||||||
.query(model.TagCategory) \
|
db.session
|
||||||
.filter(sa.func.lower(model.TagCategory.name) == name.lower())
|
.query(model.TagCategory)
|
||||||
|
.filter(sa.func.lower(model.TagCategory.name) == name.lower()))
|
||||||
if lock:
|
if lock:
|
||||||
query = query.with_lockmode('update')
|
query = query.with_lockmode('update')
|
||||||
return query.one_or_none()
|
return query.one_or_none()
|
||||||
|
@ -137,19 +138,22 @@ def get_all_categories() -> List[model.TagCategory]:
|
||||||
return db.session.query(model.TagCategory).all()
|
return db.session.query(model.TagCategory).all()
|
||||||
|
|
||||||
|
|
||||||
def try_get_default_category(lock: bool=False) -> Optional[model.TagCategory]:
|
def try_get_default_category(
|
||||||
query = db.session \
|
lock: bool = False) -> Optional[model.TagCategory]:
|
||||||
.query(model.TagCategory) \
|
query = (
|
||||||
.filter(model.TagCategory.default)
|
db.session
|
||||||
|
.query(model.TagCategory)
|
||||||
|
.filter(model.TagCategory.default))
|
||||||
if lock:
|
if lock:
|
||||||
query = query.with_lockmode('update')
|
query = query.with_lockmode('update')
|
||||||
category = query.first()
|
category = query.first()
|
||||||
# if for some reason (e.g. as a result of migration) there's no default
|
# if for some reason (e.g. as a result of migration) there's no default
|
||||||
# category, get the first record available.
|
# category, get the first record available.
|
||||||
if not category:
|
if not category:
|
||||||
query = db.session \
|
query = (
|
||||||
.query(model.TagCategory) \
|
db.session
|
||||||
.order_by(model.TagCategory.tag_category_id.asc())
|
.query(model.TagCategory)
|
||||||
|
.order_by(model.TagCategory.tag_category_id.asc()))
|
||||||
if lock:
|
if lock:
|
||||||
query = query.with_lockmode('update')
|
query = query.with_lockmode('update')
|
||||||
category = query.first()
|
category = query.first()
|
||||||
|
|
|
@ -209,7 +209,8 @@ def get_tags_by_names(names: List[str]) -> List[model.Tag]:
|
||||||
names = util.icase_unique(names)
|
names = util.icase_unique(names)
|
||||||
if len(names) == 0:
|
if len(names) == 0:
|
||||||
return []
|
return []
|
||||||
return (db.session.query(model.Tag)
|
return (
|
||||||
|
db.session.query(model.Tag)
|
||||||
.join(model.TagName)
|
.join(model.TagName)
|
||||||
.filter(
|
.filter(
|
||||||
sa.sql.or_(
|
sa.sql.or_(
|
||||||
|
|
|
@ -170,10 +170,11 @@ def get_user_count() -> int:
|
||||||
|
|
||||||
|
|
||||||
def try_get_user_by_name(name: str) -> Optional[model.User]:
|
def try_get_user_by_name(name: str) -> Optional[model.User]:
|
||||||
return db.session \
|
return (
|
||||||
.query(model.User) \
|
db.session
|
||||||
.filter(sa.func.lower(model.User.name) == sa.func.lower(name)) \
|
.query(model.User)
|
||||||
.one_or_none()
|
.filter(sa.func.lower(model.User.name) == sa.func.lower(name))
|
||||||
|
.one_or_none())
|
||||||
|
|
||||||
|
|
||||||
def get_user_by_name(name: str) -> model.User:
|
def get_user_by_name(name: str) -> model.User:
|
||||||
|
|
|
@ -2,8 +2,7 @@ import os
|
||||||
import hashlib
|
import hashlib
|
||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import (
|
from typing import Any, Optional, Union, Tuple, List, Dict, Generator, TypeVar
|
||||||
Any, Optional, Union, Tuple, List, Dict, Generator, Callable, TypeVar)
|
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from szurubooru import errors
|
from szurubooru import errors
|
||||||
|
|
|
@ -27,8 +27,9 @@ def _get_user(ctx: rest.Context) -> Optional[model.User]:
|
||||||
credentials.encode('ascii')).decode('utf8').split(':')
|
credentials.encode('ascii')).decode('utf8').split(':')
|
||||||
return _authenticate(username, password)
|
return _authenticate(username, password)
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
msg = 'Basic authentication header value are not properly formed. ' \
|
msg = (
|
||||||
+ 'Supplied header {0}. Got error: {1}'
|
'Basic authentication header value are not properly formed. '
|
||||||
|
'Supplied header {0}. Got error: {1}')
|
||||||
raise HttpBadRequest(
|
raise HttpBadRequest(
|
||||||
'ValidationError',
|
'ValidationError',
|
||||||
msg.format(ctx.get_header('Authorization'), str(err)))
|
msg.format(ctx.get_header('Authorization'), str(err)))
|
||||||
|
|
|
@ -50,13 +50,14 @@ def upgrade():
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
session = sa.orm.session.Session(bind=op.get_bind())
|
session = sa.orm.session.Session(bind=op.get_bind())
|
||||||
default_category = session \
|
default_category = (
|
||||||
.query(TagCategory) \
|
session
|
||||||
.filter(TagCategory.name == 'default') \
|
.query(TagCategory)
|
||||||
.filter(TagCategory.color == 'default') \
|
.filter(TagCategory.name == 'default')
|
||||||
.filter(TagCategory.version == 1) \
|
.filter(TagCategory.color == 'default')
|
||||||
.filter(TagCategory.default == True) \
|
.filter(TagCategory.version == 1)
|
||||||
.one_or_none()
|
.filter(TagCategory.default == 1)
|
||||||
|
.one_or_none())
|
||||||
if default_category:
|
if default_category:
|
||||||
session.delete(default_category)
|
session.delete(default_category)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
|
@ -211,10 +211,11 @@ class Post(Base):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_featured(self) -> bool:
|
def is_featured(self) -> bool:
|
||||||
featured_post = sa.orm.object_session(self) \
|
featured_post = (
|
||||||
.query(PostFeature) \
|
sa.orm.object_session(self)
|
||||||
.order_by(PostFeature.time.desc()) \
|
.query(PostFeature)
|
||||||
.first()
|
.order_by(PostFeature.time.desc())
|
||||||
|
.first())
|
||||||
return featured_post and featured_post.post_id == self.post_id
|
return featured_post and featured_post.post_id == self.post_id
|
||||||
|
|
||||||
score = sa.orm.column_property(
|
score = sa.orm.column_property(
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Callable, Type, Dict
|
from typing import Optional, Callable, Type, Dict
|
||||||
|
|
||||||
|
|
||||||
error_handlers = {} # pylint: disable=invalid-name
|
error_handlers = {} # pylint: disable=invalid-name
|
||||||
|
@ -12,8 +12,8 @@ class BaseHttpError(RuntimeError):
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
description: str,
|
description: str,
|
||||||
title: str=None,
|
title: Optional[str] = None,
|
||||||
extra_fields: Dict[str, str]=None) -> None:
|
extra_fields: Optional[Dict[str, str]] = None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# error name for programmers
|
# error name for programmers
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Callable
|
from typing import List, Callable
|
||||||
from szurubooru.rest.context import Context
|
from szurubooru.rest.context import Context
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Callable, Dict, Any
|
from typing import Callable, Dict
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from szurubooru.rest.context import Context, Response
|
from szurubooru.rest.context import Context, Response
|
||||||
|
|
||||||
|
|
|
@ -52,10 +52,11 @@ def _create_score_filter(score: int) -> Filter:
|
||||||
user_alias.name, criterion)
|
user_alias.name, criterion)
|
||||||
if negated:
|
if negated:
|
||||||
expr = ~expr
|
expr = ~expr
|
||||||
ret = query \
|
ret = (
|
||||||
.join(score_alias, score_alias.post_id == model.Post.post_id) \
|
query
|
||||||
.join(user_alias, user_alias.user_id == score_alias.user_id) \
|
.join(score_alias, score_alias.post_id == model.Post.post_id)
|
||||||
.filter(expr)
|
.join(user_alias, user_alias.user_id == score_alias.user_id)
|
||||||
|
.filter(expr))
|
||||||
return ret
|
return ret
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
@ -124,7 +125,8 @@ class PostSearchConfig(BaseSearchConfig):
|
||||||
sa.orm.lazyload
|
sa.orm.lazyload
|
||||||
if disable_eager_loads
|
if disable_eager_loads
|
||||||
else sa.orm.subqueryload)
|
else sa.orm.subqueryload)
|
||||||
return db.session.query(model.Post) \
|
return (
|
||||||
|
db.session.query(model.Post)
|
||||||
.options(
|
.options(
|
||||||
sa.orm.lazyload('*'),
|
sa.orm.lazyload('*'),
|
||||||
# use config optimized for official client
|
# use config optimized for official client
|
||||||
|
@ -141,7 +143,7 @@ class PostSearchConfig(BaseSearchConfig):
|
||||||
strategy(model.Post.tags).subqueryload(model.Tag.names),
|
strategy(model.Post.tags).subqueryload(model.Tag.names),
|
||||||
strategy(model.Post.tags).defer(model.Tag.post_count),
|
strategy(model.Post.tags).defer(model.Tag.post_count),
|
||||||
strategy(model.Post.tags).lazyload(model.Tag.implications),
|
strategy(model.Post.tags).lazyload(model.Tag.implications),
|
||||||
strategy(model.Post.tags).lazyload(model.Tag.suggestions))
|
strategy(model.Post.tags).lazyload(model.Tag.suggestions)))
|
||||||
|
|
||||||
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
|
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
|
||||||
return db.session.query(model.Post)
|
return db.session.query(model.Post)
|
||||||
|
|
|
@ -14,8 +14,9 @@ class TagSearchConfig(BaseSearchConfig):
|
||||||
sa.orm.lazyload
|
sa.orm.lazyload
|
||||||
if _disable_eager_loads
|
if _disable_eager_loads
|
||||||
else sa.orm.subqueryload)
|
else sa.orm.subqueryload)
|
||||||
return db.session.query(model.Tag) \
|
return (
|
||||||
.join(model.TagCategory) \
|
db.session.query(model.Tag)
|
||||||
|
.join(model.TagCategory)
|
||||||
.options(
|
.options(
|
||||||
sa.orm.defer(model.Tag.first_name),
|
sa.orm.defer(model.Tag.first_name),
|
||||||
sa.orm.defer(model.Tag.suggestion_count),
|
sa.orm.defer(model.Tag.suggestion_count),
|
||||||
|
@ -23,7 +24,7 @@ class TagSearchConfig(BaseSearchConfig):
|
||||||
sa.orm.defer(model.Tag.post_count),
|
sa.orm.defer(model.Tag.post_count),
|
||||||
strategy(model.Tag.names),
|
strategy(model.Tag.names),
|
||||||
strategy(model.Tag.suggestions).joinedload(model.Tag.names),
|
strategy(model.Tag.suggestions).joinedload(model.Tag.names),
|
||||||
strategy(model.Tag.implications).joinedload(model.Tag.names))
|
strategy(model.Tag.implications).joinedload(model.Tag.names)))
|
||||||
|
|
||||||
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
|
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
|
||||||
return db.session.query(model.Tag)
|
return db.session.query(model.Tag)
|
||||||
|
|
|
@ -128,8 +128,8 @@ def apply_str_criterion_to_column(
|
||||||
|
|
||||||
|
|
||||||
def create_str_filter(
|
def create_str_filter(
|
||||||
column: SaColumn, transformer: Callable[[str], str]=wildcard_transformer
|
column: SaColumn,
|
||||||
) -> Filter:
|
transformer: Callable[[str], str] = wildcard_transformer) -> Filter:
|
||||||
def wrapper(
|
def wrapper(
|
||||||
query: SaQuery,
|
query: SaQuery,
|
||||||
criterion: Optional[criteria.BaseCriterion],
|
criterion: Optional[criteria.BaseCriterion],
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Optional, List, Callable
|
from typing import Optional, List
|
||||||
from szurubooru.search.typing import SaQuery
|
from szurubooru.search.typing import SaQuery
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -100,18 +100,20 @@ class Executor:
|
||||||
filter_query = self.config.create_filter_query(disable_eager_loads)
|
filter_query = self.config.create_filter_query(disable_eager_loads)
|
||||||
filter_query = filter_query.options(sa.orm.lazyload('*'))
|
filter_query = filter_query.options(sa.orm.lazyload('*'))
|
||||||
filter_query = self._prepare_db_query(filter_query, search_query, True)
|
filter_query = self._prepare_db_query(filter_query, search_query, True)
|
||||||
entities = filter_query \
|
entities = (
|
||||||
.offset(offset) \
|
filter_query
|
||||||
.limit(limit) \
|
.offset(offset)
|
||||||
.all()
|
.limit(limit)
|
||||||
|
.all())
|
||||||
|
|
||||||
count_query = self.config.create_count_query(disable_eager_loads)
|
count_query = self.config.create_count_query(disable_eager_loads)
|
||||||
count_query = count_query.options(sa.orm.lazyload('*'))
|
count_query = count_query.options(sa.orm.lazyload('*'))
|
||||||
count_query = self._prepare_db_query(count_query, search_query, False)
|
count_query = self._prepare_db_query(count_query, search_query, False)
|
||||||
count_statement = count_query \
|
count_statement = (
|
||||||
.statement \
|
count_query
|
||||||
.with_only_columns([sa.func.count()]) \
|
.statement
|
||||||
.order_by(None)
|
.with_only_columns([sa.func.count()])
|
||||||
|
.order_by(None))
|
||||||
count = db.session.execute(count_statement).scalar()
|
count = db.session.execute(count_statement).scalar()
|
||||||
|
|
||||||
ret = (count, entities)
|
ret = (count, entities)
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import re
|
import re
|
||||||
from typing import Match, List
|
|
||||||
from szurubooru import errors
|
from szurubooru import errors
|
||||||
from szurubooru.search import criteria, tokens
|
from szurubooru.search import criteria, tokens
|
||||||
from szurubooru.search.query import SearchQuery
|
from szurubooru.search.query import SearchQuery
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from szurubooru.search import tokens
|
from szurubooru.search import tokens
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
class SearchQuery:
|
class SearchQuery:
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, model, errors
|
from szurubooru import api, model, errors
|
||||||
from szurubooru.func import tags, snapshots
|
from szurubooru.func import tags, snapshots
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, model, errors
|
from szurubooru import api, model, errors
|
||||||
from szurubooru.func import users
|
from szurubooru.func import users
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -5,10 +5,10 @@ from szurubooru import db, model, errors
|
||||||
from szurubooru.func import auth, users, files, util
|
from szurubooru.func import auth, users, files, util
|
||||||
|
|
||||||
|
|
||||||
EMPTY_PIXEL = \
|
EMPTY_PIXEL = (
|
||||||
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \
|
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
|
||||||
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
|
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
|
||||||
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
|
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b')
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('user_name', ['test', 'TEST'])
|
@pytest.mark.parametrize('user_name', ['test', 'TEST'])
|
||||||
|
|
|
@ -28,11 +28,12 @@ def test_saving_tag(tag_factory):
|
||||||
tag.implications.append(imp2)
|
tag.implications.append(imp2)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
tag = db.session \
|
tag = (
|
||||||
.query(model.Tag) \
|
db.session
|
||||||
.join(model.TagName) \
|
.query(model.Tag)
|
||||||
.filter(model.TagName.name == 'alias1') \
|
.join(model.TagName)
|
||||||
.one()
|
.filter(model.TagName.name == 'alias1')
|
||||||
|
.one())
|
||||||
assert [tag_name.name for tag_name in tag.names] == ['alias1', 'alias2']
|
assert [tag_name.name for tag_name in tag.names] == ['alias1', 'alias2']
|
||||||
assert tag.category.name == 'category'
|
assert tag.category.name == 'category'
|
||||||
assert tag.creation_time == datetime(1997, 1, 1)
|
assert tag.creation_time == datetime(1997, 1, 1)
|
||||||
|
|
|
@ -300,7 +300,7 @@ def test_filter_by_note_count(
|
||||||
('note-text:text3*', [3]),
|
('note-text:text3*', [3]),
|
||||||
('note-text:text3a,text2', [2, 3]),
|
('note-text:text3a,text2', [2, 3]),
|
||||||
])
|
])
|
||||||
def test_filter_by_note_count(
|
def test_filter_by_note_text(
|
||||||
verify_unpaged, post_factory, note_factory, input, expected_post_ids):
|
verify_unpaged, post_factory, note_factory, input, expected_post_ids):
|
||||||
post1 = post_factory(id=1)
|
post1 = post_factory(id=1)
|
||||||
post2 = post_factory(id=2)
|
post2 = post_factory(id=2)
|
||||||
|
|
Loading…
Reference in a new issue