server: refactor + add type hinting

- Added type hinting (for now, 3.5-compatible)
- Split `db` namespace into `db` module and `model` namespace
- Changed elastic search to be created lazily for each operation
- Changed to class based approach in entity serialization to allow
  stronger typing
- Removed `required` argument from `context.get_*` family of functions;
  now it's implied if `default` argument is omitted
- Changed `unalias_dict` implementation to use less magic inputs
This commit is contained in:
rr- 2017-02-04 01:08:12 +01:00
parent abf1fc2b2d
commit ad842ee8a5
116 changed files with 2868 additions and 2037 deletions

View file

@ -8,7 +8,7 @@ import zlib
import concurrent.futures
import logging
import coloredlogs
import sqlalchemy
import sqlalchemy as sa
from szurubooru import config, db
from szurubooru.func import files, images, posts, comments
@ -42,8 +42,8 @@ def get_v1_session(args):
port=args.port,
name=args.name)
logger.info('Connecting to %r...', dsn)
engine = sqlalchemy.create_engine(dsn)
session_maker = sqlalchemy.orm.sessionmaker(bind=engine)
engine = sa.create_engine(dsn)
session_maker = sa.orm.sessionmaker(bind=engine)
return session_maker()
def parse_args():

14
server/mypy.ini Normal file
View file

@ -0,0 +1,14 @@
[mypy]
ignore_missing_imports = True
follow_imports = skip
disallow_untyped_calls = True
disallow_untyped_defs = True
check_untyped_defs = True
disallow_subclassing_any = False
warn_redundant_casts = True
warn_unused_ignores = True
strict_optional = True
strict_boolean = False
[mypy-szurubooru.tests.*]
ignore_errors=True

View file

@ -1,31 +1,44 @@
import datetime
from szurubooru import search
from szurubooru.rest import routes
from szurubooru.func import auth, comments, posts, scores, util, versions
from typing import Dict
from datetime import datetime
from szurubooru import search, rest, model
from szurubooru.func import (
auth, comments, posts, scores, versions, serialization)
_search_executor = search.Executor(search.configs.CommentSearchConfig())
def _serialize(ctx, comment, **kwargs):
def _get_comment(params: Dict[str, str]) -> model.Comment:
try:
comment_id = int(params['comment_id'])
except TypeError:
raise comments.InvalidCommentIdError(
'Invalid comment ID: %r.' % params['comment_id'])
return comments.get_comment_by_id(comment_id)
def _serialize(
ctx: rest.Context, comment: model.Comment) -> rest.Response:
return comments.serialize_comment(
comment,
ctx.user,
options=util.get_serialization_options(ctx), **kwargs)
options=serialization.get_serialization_options(ctx))
@routes.get('/comments/?')
def get_comments(ctx, _params=None):
@rest.routes.get('/comments/?')
def get_comments(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:list')
return _search_executor.execute_and_serialize(
ctx, lambda comment: _serialize(ctx, comment))
@routes.post('/comments/?')
def create_comment(ctx, _params=None):
@rest.routes.post('/comments/?')
def create_comment(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:create')
text = ctx.get_param_as_string('text', required=True)
post_id = ctx.get_param_as_int('postId', required=True)
text = ctx.get_param_as_string('text')
post_id = ctx.get_param_as_int('postId')
post = posts.get_post_by_id(post_id)
comment = comments.create_comment(ctx.user, post, text)
ctx.session.add(comment)
@ -33,30 +46,30 @@ def create_comment(ctx, _params=None):
return _serialize(ctx, comment)
@routes.get('/comment/(?P<comment_id>[^/]+)/?')
def get_comment(ctx, params):
@rest.routes.get('/comment/(?P<comment_id>[^/]+)/?')
def get_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:view')
comment = comments.get_comment_by_id(params['comment_id'])
comment = _get_comment(params)
return _serialize(ctx, comment)
@routes.put('/comment/(?P<comment_id>[^/]+)/?')
def update_comment(ctx, params):
comment = comments.get_comment_by_id(params['comment_id'])
@rest.routes.put('/comment/(?P<comment_id>[^/]+)/?')
def update_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
comment = _get_comment(params)
versions.verify_version(comment, ctx)
versions.bump_version(comment)
infix = 'own' if ctx.user.user_id == comment.user_id else 'any'
text = ctx.get_param_as_string('text', required=True)
text = ctx.get_param_as_string('text')
auth.verify_privilege(ctx.user, 'comments:edit:%s' % infix)
comments.update_comment_text(comment, text)
comment.last_edit_time = datetime.datetime.utcnow()
comment.last_edit_time = datetime.utcnow()
ctx.session.commit()
return _serialize(ctx, comment)
@routes.delete('/comment/(?P<comment_id>[^/]+)/?')
def delete_comment(ctx, params):
comment = comments.get_comment_by_id(params['comment_id'])
@rest.routes.delete('/comment/(?P<comment_id>[^/]+)/?')
def delete_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
comment = _get_comment(params)
versions.verify_version(comment, ctx)
infix = 'own' if ctx.user.user_id == comment.user_id else 'any'
auth.verify_privilege(ctx.user, 'comments:delete:%s' % infix)
@ -65,20 +78,22 @@ def delete_comment(ctx, params):
return {}
@routes.put('/comment/(?P<comment_id>[^/]+)/score/?')
def set_comment_score(ctx, params):
@rest.routes.put('/comment/(?P<comment_id>[^/]+)/score/?')
def set_comment_score(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:score')
score = ctx.get_param_as_int('score', required=True)
comment = comments.get_comment_by_id(params['comment_id'])
score = ctx.get_param_as_int('score')
comment = _get_comment(params)
scores.set_score(comment, ctx.user, score)
ctx.session.commit()
return _serialize(ctx, comment)
@routes.delete('/comment/(?P<comment_id>[^/]+)/score/?')
def delete_comment_score(ctx, params):
@rest.routes.delete('/comment/(?P<comment_id>[^/]+)/score/?')
def delete_comment_score(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:score')
comment = comments.get_comment_by_id(params['comment_id'])
comment = _get_comment(params)
scores.delete_score(comment, ctx.user)
ctx.session.commit()
return _serialize(ctx, comment)

View file

@ -1,19 +1,20 @@
import datetime
import os
from szurubooru import config
from szurubooru.rest import routes
from typing import Optional, Dict
from datetime import datetime, timedelta
from szurubooru import config, rest
from szurubooru.func import posts, users, util
_cache_time = None
_cache_result = None
_cache_time = None # type: Optional[datetime]
_cache_result = None # type: Optional[int]
def _get_disk_usage():
def _get_disk_usage() -> int:
global _cache_time, _cache_result # pylint: disable=global-statement
threshold = datetime.timedelta(hours=48)
now = datetime.datetime.utcnow()
threshold = timedelta(hours=48)
now = datetime.utcnow()
if _cache_time and _cache_time > now - threshold:
assert _cache_result
return _cache_result
total_size = 0
for dir_path, _, file_names in os.walk(config.config['data_dir']):
@ -25,8 +26,9 @@ def _get_disk_usage():
return total_size
@routes.get('/info/?')
def get_info(ctx, _params=None):
@rest.routes.get('/info/?')
def get_info(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
post_feature = posts.try_get_current_post_feature()
return {
'postCount': posts.get_post_count(),
@ -38,7 +40,7 @@ def get_info(ctx, _params=None):
'featuringUser':
users.serialize_user(post_feature.user, ctx.user)
if post_feature else None,
'serverTime': datetime.datetime.utcnow(),
'serverTime': datetime.utcnow(),
'config': {
'userNameRegex': config.config['user_name_regex'],
'passwordRegex': config.config['password_regex'],

View file

@ -1,5 +1,5 @@
from szurubooru import config, errors
from szurubooru.rest import routes
from typing import Dict
from szurubooru import config, errors, rest
from szurubooru.func import auth, mailer, users, versions
@ -10,9 +10,9 @@ MAIL_BODY = \
'Otherwise, please ignore this email.'
@routes.get('/password-reset/(?P<user_name>[^/]+)/?')
def start_password_reset(_ctx, params):
''' Send a mail with secure token to the correlated user. '''
@rest.routes.get('/password-reset/(?P<user_name>[^/]+)/?')
def start_password_reset(
_ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
user_name = params['user_name']
user = users.get_user_by_name_or_email(user_name)
if not user.email:
@ -30,13 +30,13 @@ def start_password_reset(_ctx, params):
return {}
@routes.post('/password-reset/(?P<user_name>[^/]+)/?')
def finish_password_reset(ctx, params):
''' Verify token from mail, generate a new password and return it. '''
@rest.routes.post('/password-reset/(?P<user_name>[^/]+)/?')
def finish_password_reset(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
user_name = params['user_name']
user = users.get_user_by_name_or_email(user_name)
good_token = auth.generate_authentication_token(user)
token = ctx.get_param_as_string('token', required=True)
token = ctx.get_param_as_string('token')
if token != good_token:
raise errors.ValidationError('Invalid password reset token.')
new_password = users.reset_user_password(user)

View file

@ -1,44 +1,60 @@
import datetime
from szurubooru import search, db, errors
from szurubooru.rest import routes
from typing import Optional, Dict
from datetime import datetime
from szurubooru import db, model, errors, rest, search
from szurubooru.func import (
auth, tags, posts, snapshots, favorites, scores, util, versions)
auth, tags, posts, snapshots, favorites, scores, serialization, versions)
_search_executor = search.Executor(search.configs.PostSearchConfig())
_search_executor_config = search.configs.PostSearchConfig()
_search_executor = search.Executor(_search_executor_config)
def _serialize_post(ctx, post):
def _get_post_id(params: Dict[str, str]) -> int:
try:
return int(params['post_id'])
except TypeError:
raise posts.InvalidPostIdError(
'Invalid post ID: %r.' % params['post_id'])
def _get_post(params: Dict[str, str]) -> model.Post:
return posts.get_post_by_id(_get_post_id(params))
def _serialize_post(
ctx: rest.Context, post: Optional[model.Post]) -> rest.Response:
return posts.serialize_post(
post,
ctx.user,
options=util.get_serialization_options(ctx))
options=serialization.get_serialization_options(ctx))
@routes.get('/posts/?')
def get_posts(ctx, _params=None):
@rest.routes.get('/posts/?')
def get_posts(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:list')
_search_executor.config.user = ctx.user
_search_executor_config.user = ctx.user
return _search_executor.execute_and_serialize(
ctx, lambda post: _serialize_post(ctx, post))
@routes.post('/posts/?')
def create_post(ctx, _params=None):
@rest.routes.post('/posts/?')
def create_post(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
anonymous = ctx.get_param_as_bool('anonymous', default=False)
if anonymous:
auth.verify_privilege(ctx.user, 'posts:create:anonymous')
else:
auth.verify_privilege(ctx.user, 'posts:create:identified')
content = ctx.get_file('content', required=True)
tag_names = ctx.get_param_as_list('tags', required=False, default=[])
safety = ctx.get_param_as_string('safety', required=True)
source = ctx.get_param_as_string('source', required=False, default=None)
content = ctx.get_file('content')
tag_names = ctx.get_param_as_list('tags', default=[])
safety = ctx.get_param_as_string('safety')
source = ctx.get_param_as_string('source', default='')
if ctx.has_param('contentUrl') and not source:
source = ctx.get_param_as_string('contentUrl')
relations = ctx.get_param_as_list('relations', required=False) or []
notes = ctx.get_param_as_list('notes', required=False) or []
flags = ctx.get_param_as_list('flags', required=False) or []
source = ctx.get_param_as_string('contentUrl', default='')
relations = ctx.get_param_as_list('relations', default=[])
notes = ctx.get_param_as_list('notes', default=[])
flags = ctx.get_param_as_list('flags', default=[])
post, new_tags = posts.create_post(
content, tag_names, None if anonymous else ctx.user)
@ -61,16 +77,16 @@ def create_post(ctx, _params=None):
return _serialize_post(ctx, post)
@routes.get('/post/(?P<post_id>[^/]+)/?')
def get_post(ctx, params):
@rest.routes.get('/post/(?P<post_id>[^/]+)/?')
def get_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:view')
post = posts.get_post_by_id(params['post_id'])
post = _get_post(params)
return _serialize_post(ctx, post)
@routes.put('/post/(?P<post_id>[^/]+)/?')
def update_post(ctx, params):
post = posts.get_post_by_id(params['post_id'])
@rest.routes.put('/post/(?P<post_id>[^/]+)/?')
def update_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
post = _get_post(params)
versions.verify_version(post, ctx)
versions.bump_version(post)
if ctx.has_file('content'):
@ -104,7 +120,7 @@ def update_post(ctx, params):
if ctx.has_file('thumbnail'):
auth.verify_privilege(ctx.user, 'posts:edit:thumbnail')
posts.update_post_thumbnail(post, ctx.get_file('thumbnail'))
post.last_edit_time = datetime.datetime.utcnow()
post.last_edit_time = datetime.utcnow()
ctx.session.flush()
snapshots.modify(post, ctx.user)
ctx.session.commit()
@ -112,10 +128,10 @@ def update_post(ctx, params):
return _serialize_post(ctx, post)
@routes.delete('/post/(?P<post_id>[^/]+)/?')
def delete_post(ctx, params):
@rest.routes.delete('/post/(?P<post_id>[^/]+)/?')
def delete_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:delete')
post = posts.get_post_by_id(params['post_id'])
post = _get_post(params)
versions.verify_version(post, ctx)
snapshots.delete(post, ctx.user)
posts.delete(post)
@ -124,13 +140,14 @@ def delete_post(ctx, params):
return {}
@routes.post('/post-merge/?')
def merge_posts(ctx, _params=None):
source_post_id = ctx.get_param_as_string('remove', required=True) or ''
target_post_id = ctx.get_param_as_string('mergeTo', required=True) or ''
replace_content = ctx.get_param_as_bool('replaceContent')
@rest.routes.post('/post-merge/?')
def merge_posts(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
source_post_id = ctx.get_param_as_int('remove')
target_post_id = ctx.get_param_as_int('mergeTo')
source_post = posts.get_post_by_id(source_post_id)
target_post = posts.get_post_by_id(target_post_id)
replace_content = ctx.get_param_as_bool('replaceContent')
versions.verify_version(source_post, ctx, 'removeVersion')
versions.verify_version(target_post, ctx, 'mergeToVersion')
versions.bump_version(target_post)
@ -141,16 +158,18 @@ def merge_posts(ctx, _params=None):
return _serialize_post(ctx, target_post)
@routes.get('/featured-post/?')
def get_featured_post(ctx, _params=None):
@rest.routes.get('/featured-post/?')
def get_featured_post(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
post = posts.try_get_featured_post()
return _serialize_post(ctx, post)
@routes.post('/featured-post/?')
def set_featured_post(ctx, _params=None):
@rest.routes.post('/featured-post/?')
def set_featured_post(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:feature')
post_id = ctx.get_param_as_int('id', required=True)
post_id = ctx.get_param_as_int('id')
post = posts.get_post_by_id(post_id)
featured_post = posts.try_get_featured_post()
if featured_post and featured_post.post_id == post.post_id:
@ -162,55 +181,61 @@ def set_featured_post(ctx, _params=None):
return _serialize_post(ctx, post)
@routes.put('/post/(?P<post_id>[^/]+)/score/?')
def set_post_score(ctx, params):
@rest.routes.put('/post/(?P<post_id>[^/]+)/score/?')
def set_post_score(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:score')
post = posts.get_post_by_id(params['post_id'])
score = ctx.get_param_as_int('score', required=True)
post = _get_post(params)
score = ctx.get_param_as_int('score')
scores.set_score(post, ctx.user, score)
ctx.session.commit()
return _serialize_post(ctx, post)
@routes.delete('/post/(?P<post_id>[^/]+)/score/?')
def delete_post_score(ctx, params):
@rest.routes.delete('/post/(?P<post_id>[^/]+)/score/?')
def delete_post_score(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:score')
post = posts.get_post_by_id(params['post_id'])
post = _get_post(params)
scores.delete_score(post, ctx.user)
ctx.session.commit()
return _serialize_post(ctx, post)
@routes.post('/post/(?P<post_id>[^/]+)/favorite/?')
def add_post_to_favorites(ctx, params):
@rest.routes.post('/post/(?P<post_id>[^/]+)/favorite/?')
def add_post_to_favorites(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:favorite')
post = posts.get_post_by_id(params['post_id'])
post = _get_post(params)
favorites.set_favorite(post, ctx.user)
ctx.session.commit()
return _serialize_post(ctx, post)
@routes.delete('/post/(?P<post_id>[^/]+)/favorite/?')
def delete_post_from_favorites(ctx, params):
@rest.routes.delete('/post/(?P<post_id>[^/]+)/favorite/?')
def delete_post_from_favorites(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:favorite')
post = posts.get_post_by_id(params['post_id'])
post = _get_post(params)
favorites.unset_favorite(post, ctx.user)
ctx.session.commit()
return _serialize_post(ctx, post)
@routes.get('/post/(?P<post_id>[^/]+)/around/?')
def get_posts_around(ctx, params):
@rest.routes.get('/post/(?P<post_id>[^/]+)/around/?')
def get_posts_around(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:list')
_search_executor.config.user = ctx.user
_search_executor_config.user = ctx.user
post_id = _get_post_id(params)
return _search_executor.get_around_and_serialize(
ctx, params['post_id'], lambda post: _serialize_post(ctx, post))
ctx, post_id, lambda post: _serialize_post(ctx, post))
@routes.post('/posts/reverse-search/?')
def get_posts_by_image(ctx, _params=None):
@rest.routes.post('/posts/reverse-search/?')
def get_posts_by_image(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:reverse_search')
content = ctx.get_file('content', required=True)
content = ctx.get_file('content')
try:
lookalikes = posts.search_by_image(content)

View file

@ -1,13 +1,14 @@
from szurubooru import search
from szurubooru.rest import routes
from typing import Dict
from szurubooru import search, rest
from szurubooru.func import auth, snapshots
_search_executor = search.Executor(search.configs.SnapshotSearchConfig())
@routes.get('/snapshots/?')
def get_snapshots(ctx, _params=None):
@rest.routes.get('/snapshots/?')
def get_snapshots(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'snapshots:list')
return _search_executor.execute_and_serialize(
ctx, lambda snapshot: snapshots.serialize_snapshot(snapshot, ctx.user))

View file

@ -1,18 +1,22 @@
import datetime
from szurubooru import db, search
from szurubooru.rest import routes
from szurubooru.func import auth, tags, snapshots, util, versions
from typing import Optional, List, Dict
from datetime import datetime
from szurubooru import db, model, search, rest
from szurubooru.func import auth, tags, snapshots, serialization, versions
_search_executor = search.Executor(search.configs.TagSearchConfig())
def _serialize(ctx, tag):
def _serialize(ctx: rest.Context, tag: model.Tag) -> rest.Response:
return tags.serialize_tag(
tag, options=util.get_serialization_options(ctx))
tag, options=serialization.get_serialization_options(ctx))
def _create_if_needed(tag_names, user):
def _get_tag(params: Dict[str, str]) -> model.Tag:
return tags.get_tag_by_name(params['tag_name'])
def _create_if_needed(tag_names: List[str], user: model.User) -> None:
if not tag_names:
return
_existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
@ -23,25 +27,22 @@ def _create_if_needed(tag_names, user):
snapshots.create(tag, user)
@routes.get('/tags/?')
def get_tags(ctx, _params=None):
@rest.routes.get('/tags/?')
def get_tags(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tags:list')
return _search_executor.execute_and_serialize(
ctx, lambda tag: _serialize(ctx, tag))
@routes.post('/tags/?')
def create_tag(ctx, _params=None):
@rest.routes.post('/tags/?')
def create_tag(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tags:create')
names = ctx.get_param_as_list('names', required=True)
category = ctx.get_param_as_string('category', required=True)
description = ctx.get_param_as_string(
'description', required=False, default=None)
suggestions = ctx.get_param_as_list(
'suggestions', required=False, default=[])
implications = ctx.get_param_as_list(
'implications', required=False, default=[])
names = ctx.get_param_as_list('names')
category = ctx.get_param_as_string('category')
description = ctx.get_param_as_string('description', default='')
suggestions = ctx.get_param_as_list('suggestions', default=[])
implications = ctx.get_param_as_list('implications', default=[])
_create_if_needed(suggestions, ctx.user)
_create_if_needed(implications, ctx.user)
@ -56,16 +57,16 @@ def create_tag(ctx, _params=None):
return _serialize(ctx, tag)
@routes.get('/tag/(?P<tag_name>.+)')
def get_tag(ctx, params):
@rest.routes.get('/tag/(?P<tag_name>.+)')
def get_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'tags:view')
tag = tags.get_tag_by_name(params['tag_name'])
tag = _get_tag(params)
return _serialize(ctx, tag)
@routes.put('/tag/(?P<tag_name>.+)')
def update_tag(ctx, params):
tag = tags.get_tag_by_name(params['tag_name'])
@rest.routes.put('/tag/(?P<tag_name>.+)')
def update_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
tag = _get_tag(params)
versions.verify_version(tag, ctx)
versions.bump_version(tag)
if ctx.has_param('names'):
@ -78,7 +79,7 @@ def update_tag(ctx, params):
if ctx.has_param('description'):
auth.verify_privilege(ctx.user, 'tags:edit:description')
tags.update_tag_description(
tag, ctx.get_param_as_string('description', default=None))
tag, ctx.get_param_as_string('description'))
if ctx.has_param('suggestions'):
auth.verify_privilege(ctx.user, 'tags:edit:suggestions')
suggestions = ctx.get_param_as_list('suggestions')
@ -89,7 +90,7 @@ def update_tag(ctx, params):
implications = ctx.get_param_as_list('implications')
_create_if_needed(implications, ctx.user)
tags.update_tag_implications(tag, implications)
tag.last_edit_time = datetime.datetime.utcnow()
tag.last_edit_time = datetime.utcnow()
ctx.session.flush()
snapshots.modify(tag, ctx.user)
ctx.session.commit()
@ -97,9 +98,9 @@ def update_tag(ctx, params):
return _serialize(ctx, tag)
@routes.delete('/tag/(?P<tag_name>.+)')
def delete_tag(ctx, params):
tag = tags.get_tag_by_name(params['tag_name'])
@rest.routes.delete('/tag/(?P<tag_name>.+)')
def delete_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
tag = _get_tag(params)
versions.verify_version(tag, ctx)
auth.verify_privilege(ctx.user, 'tags:delete')
snapshots.delete(tag, ctx.user)
@ -109,10 +110,11 @@ def delete_tag(ctx, params):
return {}
@routes.post('/tag-merge/?')
def merge_tags(ctx, _params=None):
source_tag_name = ctx.get_param_as_string('remove', required=True) or ''
target_tag_name = ctx.get_param_as_string('mergeTo', required=True) or ''
@rest.routes.post('/tag-merge/?')
def merge_tags(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
source_tag_name = ctx.get_param_as_string('remove')
target_tag_name = ctx.get_param_as_string('mergeTo')
source_tag = tags.get_tag_by_name(source_tag_name)
target_tag = tags.get_tag_by_name(target_tag_name)
versions.verify_version(source_tag, ctx, 'removeVersion')
@ -126,10 +128,11 @@ def merge_tags(ctx, _params=None):
return _serialize(ctx, target_tag)
@routes.get('/tag-siblings/(?P<tag_name>.+)')
def get_tag_siblings(ctx, params):
@rest.routes.get('/tag-siblings/(?P<tag_name>.+)')
def get_tag_siblings(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'tags:view')
tag = tags.get_tag_by_name(params['tag_name'])
tag = _get_tag(params)
result = tags.get_tag_siblings(tag)
serialized_siblings = []
for sibling, occurrences in result:

View file

@ -1,15 +1,18 @@
from szurubooru.rest import routes
from typing import Dict
from szurubooru import model, rest
from szurubooru.func import (
auth, tags, tag_categories, snapshots, util, versions)
auth, tags, tag_categories, snapshots, serialization, versions)
def _serialize(ctx, category):
def _serialize(
ctx: rest.Context, category: model.TagCategory) -> rest.Response:
return tag_categories.serialize_category(
category, options=util.get_serialization_options(ctx))
category, options=serialization.get_serialization_options(ctx))
@routes.get('/tag-categories/?')
def get_tag_categories(ctx, _params=None):
@rest.routes.get('/tag-categories/?')
def get_tag_categories(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tag_categories:list')
categories = tag_categories.get_all_categories()
return {
@ -17,11 +20,12 @@ def get_tag_categories(ctx, _params=None):
}
@routes.post('/tag-categories/?')
def create_tag_category(ctx, _params=None):
@rest.routes.post('/tag-categories/?')
def create_tag_category(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tag_categories:create')
name = ctx.get_param_as_string('name', required=True)
color = ctx.get_param_as_string('color', required=True)
name = ctx.get_param_as_string('name')
color = ctx.get_param_as_string('color')
category = tag_categories.create_category(name, color)
ctx.session.add(category)
ctx.session.flush()
@ -31,15 +35,17 @@ def create_tag_category(ctx, _params=None):
return _serialize(ctx, category)
@routes.get('/tag-category/(?P<category_name>[^/]+)/?')
def get_tag_category(ctx, params):
@rest.routes.get('/tag-category/(?P<category_name>[^/]+)/?')
def get_tag_category(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'tag_categories:view')
category = tag_categories.get_category_by_name(params['category_name'])
return _serialize(ctx, category)
@routes.put('/tag-category/(?P<category_name>[^/]+)/?')
def update_tag_category(ctx, params):
@rest.routes.put('/tag-category/(?P<category_name>[^/]+)/?')
def update_tag_category(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
category = tag_categories.get_category_by_name(
params['category_name'], lock=True)
versions.verify_version(category, ctx)
@ -59,8 +65,9 @@ def update_tag_category(ctx, params):
return _serialize(ctx, category)
@routes.delete('/tag-category/(?P<category_name>[^/]+)/?')
def delete_tag_category(ctx, params):
@rest.routes.delete('/tag-category/(?P<category_name>[^/]+)/?')
def delete_tag_category(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
category = tag_categories.get_category_by_name(
params['category_name'], lock=True)
versions.verify_version(category, ctx)
@ -72,8 +79,9 @@ def delete_tag_category(ctx, params):
return {}
@routes.put('/tag-category/(?P<category_name>[^/]+)/default/?')
def set_tag_category_as_default(ctx, params):
@rest.routes.put('/tag-category/(?P<category_name>[^/]+)/default/?')
def set_tag_category_as_default(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'tag_categories:set_default')
category = tag_categories.get_category_by_name(
params['category_name'], lock=True)

View file

@ -1,10 +1,12 @@
from szurubooru.rest import routes
from typing import Dict
from szurubooru import rest
from szurubooru.func import auth, file_uploads
@routes.post('/uploads/?')
def create_temporary_file(ctx, _params=None):
@rest.routes.post('/uploads/?')
def create_temporary_file(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'uploads:create')
content = ctx.get_file('content', required=True, allow_tokens=False)
content = ctx.get_file('content', allow_tokens=False)
token = file_uploads.save(content)
return {'token': token}

View file

@ -1,56 +1,57 @@
from szurubooru import search
from szurubooru.rest import routes
from szurubooru.func import auth, users, util, versions
from typing import Any, Dict
from szurubooru import model, search, rest
from szurubooru.func import auth, users, serialization, versions
_search_executor = search.Executor(search.configs.UserSearchConfig())
def _serialize(ctx, user, **kwargs):
def _serialize(
ctx: rest.Context, user: model.User, **kwargs: Any) -> rest.Response:
return users.serialize_user(
user,
ctx.user,
options=util.get_serialization_options(ctx),
options=serialization.get_serialization_options(ctx),
**kwargs)
@routes.get('/users/?')
def get_users(ctx, _params=None):
@rest.routes.get('/users/?')
def get_users(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'users:list')
return _search_executor.execute_and_serialize(
ctx, lambda user: _serialize(ctx, user))
@routes.post('/users/?')
def create_user(ctx, _params=None):
@rest.routes.post('/users/?')
def create_user(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'users:create')
name = ctx.get_param_as_string('name', required=True)
password = ctx.get_param_as_string('password', required=True)
email = ctx.get_param_as_string('email', required=False, default='')
name = ctx.get_param_as_string('name')
password = ctx.get_param_as_string('password')
email = ctx.get_param_as_string('email', default='')
user = users.create_user(name, password, email)
if ctx.has_param('rank'):
users.update_user_rank(
user, ctx.get_param_as_string('rank'), ctx.user)
users.update_user_rank(user, ctx.get_param_as_string('rank'), ctx.user)
if ctx.has_param('avatarStyle'):
users.update_user_avatar(
user,
ctx.get_param_as_string('avatarStyle'),
ctx.get_file('avatar'))
ctx.get_file('avatar', default=b''))
ctx.session.add(user)
ctx.session.commit()
return _serialize(ctx, user, force_show_email=True)
@routes.get('/user/(?P<user_name>[^/]+)/?')
def get_user(ctx, params):
@rest.routes.get('/user/(?P<user_name>[^/]+)/?')
def get_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
user = users.get_user_by_name(params['user_name'])
if ctx.user.user_id != user.user_id:
auth.verify_privilege(ctx.user, 'users:view')
return _serialize(ctx, user)
@routes.put('/user/(?P<user_name>[^/]+)/?')
def update_user(ctx, params):
@rest.routes.put('/user/(?P<user_name>[^/]+)/?')
def update_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
user = users.get_user_by_name(params['user_name'])
versions.verify_version(user, ctx)
versions.bump_version(user)
@ -74,13 +75,13 @@ def update_user(ctx, params):
users.update_user_avatar(
user,
ctx.get_param_as_string('avatarStyle'),
ctx.get_file('avatar'))
ctx.get_file('avatar', default=b''))
ctx.session.commit()
return _serialize(ctx, user)
@routes.delete('/user/(?P<user_name>[^/]+)/?')
def delete_user(ctx, params):
@rest.routes.delete('/user/(?P<user_name>[^/]+)/?')
def delete_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
user = users.get_user_by_name(params['user_name'])
versions.verify_version(user, ctx)
infix = 'self' if ctx.user.user_id == user.user_id else 'any'

View file

@ -1,8 +1,9 @@
from typing import Dict
import os
import yaml
def merge(left, right):
def merge(left: Dict, right: Dict) -> Dict:
for key in right:
if key in left:
if isinstance(left[key], dict) and isinstance(right[key], dict):
@ -14,7 +15,7 @@ def merge(left, right):
return left
def read_config():
def read_config() -> Dict:
with open('../config.yaml.dist') as handle:
ret = yaml.load(handle.read())
if os.path.exists('../config.yaml'):

36
server/szurubooru/db.py Normal file
View file

@ -0,0 +1,36 @@
from typing import Any
import threading
import sqlalchemy as sa
import sqlalchemy.orm
from szurubooru import config
# pylint: disable=invalid-name
_data = threading.local()
_engine = sa.create_engine(config.config['database']) # type: Any
sessionmaker = sa.orm.sessionmaker(bind=_engine, autoflush=False) # type: Any
session = sa.orm.scoped_session(sessionmaker) # type: Any
def get_session() -> Any:
global session
return session
def set_sesssion(new_session: Any) -> None:
global session
session = new_session
def reset_query_count() -> None:
_data.query_count = 0
def get_query_count() -> int:
return _data.query_count
def _bump_query_count() -> None:
_data.query_count = getattr(_data, 'query_count', 0) + 1
sa.event.listen(_engine, 'after_execute', lambda *args: _bump_query_count())

View file

@ -1,17 +0,0 @@
from szurubooru.db.base import Base
from szurubooru.db.user import User
from szurubooru.db.tag_category import TagCategory
from szurubooru.db.tag import (Tag, TagName, TagSuggestion, TagImplication)
from szurubooru.db.post import (
Post,
PostTag,
PostRelation,
PostFavorite,
PostScore,
PostNote,
PostFeature)
from szurubooru.db.comment import (Comment, CommentScore)
from szurubooru.db.snapshot import Snapshot
from szurubooru.db.session import (
session, sessionmaker, reset_query_count, get_query_count)
import szurubooru.db.util

View file

@ -1,27 +0,0 @@
import threading
import sqlalchemy
from szurubooru import config
# pylint: disable=invalid-name
_engine = sqlalchemy.create_engine(config.config['database'])
sessionmaker = sqlalchemy.orm.sessionmaker(bind=_engine, autoflush=False)
session = sqlalchemy.orm.scoped_session(sessionmaker)
_data = threading.local()
def reset_query_count():
_data.query_count = 0
def get_query_count():
return _data.query_count
def _bump_query_count():
_data.query_count = getattr(_data, 'query_count', 0) + 1
sqlalchemy.event.listen(
_engine, 'after_execute', lambda *args: _bump_query_count())

View file

@ -1,34 +0,0 @@
from sqlalchemy.inspection import inspect
def get_resource_info(entity):
serializers = {
'tag': lambda tag: tag.first_name,
'tag_category': lambda category: category.name,
'comment': lambda comment: comment.comment_id,
'post': lambda post: post.post_id,
}
resource_type = entity.__table__.name
assert resource_type in serializers
primary_key = inspect(entity).identity
assert primary_key is not None
assert len(primary_key) == 1
resource_name = serializers[resource_type](entity)
assert resource_name
resource_pkey = primary_key[0]
assert resource_pkey
return (resource_type, resource_pkey, resource_name)
def get_aux_entity(session, get_table_info, entity, user):
table, get_column = get_table_info(entity)
return session \
.query(table) \
.filter(get_column(table) == get_column(entity)) \
.filter(table.user_id == user.user_id) \
.one_or_none()

View file

@ -1,5 +1,11 @@
from typing import Dict
class BaseError(RuntimeError):
def __init__(self, message='Unknown error', extra_fields=None):
def __init__(
self,
message: str='Unknown error',
extra_fields: Dict[str, str]=None) -> None:
super().__init__(message)
self.extra_fields = extra_fields

View file

@ -2,7 +2,10 @@ import os
import time
import logging
import threading
from typing import Callable, Any, Type
import coloredlogs
import sqlalchemy as sa
import sqlalchemy.orm.exc
from szurubooru import config, db, errors, rest
from szurubooru.func import posts, file_uploads
@ -10,7 +13,10 @@ from szurubooru.func import posts, file_uploads
from szurubooru import api, middleware
def _map_error(ex, target_class, title):
def _map_error(
ex: Exception,
target_class: Type[rest.errors.BaseHttpError],
title: str) -> rest.errors.BaseHttpError:
return target_class(
name=type(ex).__name__,
title=title,
@ -18,38 +24,38 @@ def _map_error(ex, target_class, title):
extra_fields=getattr(ex, 'extra_fields', {}))
def _on_auth_error(ex):
def _on_auth_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpForbidden, 'Authentication error')
def _on_validation_error(ex):
def _on_validation_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpBadRequest, 'Validation error')
def _on_search_error(ex):
def _on_search_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpBadRequest, 'Search error')
def _on_integrity_error(ex):
def _on_integrity_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpConflict, 'Integrity violation')
def _on_not_found_error(ex):
def _on_not_found_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpNotFound, 'Not found')
def _on_processing_error(ex):
def _on_processing_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpBadRequest, 'Processing error')
def _on_third_party_error(ex):
def _on_third_party_error(ex: Exception) -> None:
raise _map_error(
ex,
rest.errors.HttpInternalServerError,
'Server configuration error')
def _on_stale_data_error(_ex):
def _on_stale_data_error(_ex: Exception) -> None:
raise rest.errors.HttpConflict(
name='IntegrityError',
title='Integrity violation',
@ -58,7 +64,7 @@ def _on_stale_data_error(_ex):
'Please try again.'))
def validate_config():
def validate_config() -> None:
'''
Check whether config doesn't contain errors that might prove
lethal at runtime.
@ -86,7 +92,7 @@ def validate_config():
raise errors.ConfigError('Database is not configured')
def purge_old_uploads():
def purge_old_uploads() -> None:
while True:
try:
file_uploads.purge_old_uploads()
@ -95,7 +101,7 @@ def purge_old_uploads():
time.sleep(60 * 5)
def create_app():
def create_app() -> Callable[[Any, Any], Any]:
''' Create a WSGI compatible App object. '''
validate_config()
coloredlogs.install(fmt='[%(asctime)-15s] %(name)s %(message)s')
@ -122,7 +128,7 @@ def create_app():
rest.errors.handle(errors.NotFoundError, _on_not_found_error)
rest.errors.handle(errors.ProcessingError, _on_processing_error)
rest.errors.handle(errors.ThirdPartyError, _on_third_party_error)
rest.errors.handle(sqlalchemy.orm.exc.StaleDataError, _on_stale_data_error)
rest.errors.handle(sa.orm.exc.StaleDataError, _on_stale_data_error)
return rest.application

View file

@ -1,22 +1,22 @@
import hashlib
import random
from collections import OrderedDict
from szurubooru import config, db, errors
from szurubooru import config, model, errors
from szurubooru.func import util
RANK_MAP = OrderedDict([
(db.User.RANK_ANONYMOUS, 'anonymous'),
(db.User.RANK_RESTRICTED, 'restricted'),
(db.User.RANK_REGULAR, 'regular'),
(db.User.RANK_POWER, 'power'),
(db.User.RANK_MODERATOR, 'moderator'),
(db.User.RANK_ADMINISTRATOR, 'administrator'),
(db.User.RANK_NOBODY, 'nobody'),
(model.User.RANK_ANONYMOUS, 'anonymous'),
(model.User.RANK_RESTRICTED, 'restricted'),
(model.User.RANK_REGULAR, 'regular'),
(model.User.RANK_POWER, 'power'),
(model.User.RANK_MODERATOR, 'moderator'),
(model.User.RANK_ADMINISTRATOR, 'administrator'),
(model.User.RANK_NOBODY, 'nobody'),
])
def get_password_hash(salt, password):
def get_password_hash(salt: str, password: str) -> str:
''' Retrieve new-style password hash. '''
digest = hashlib.sha256()
digest.update(config.config['secret'].encode('utf8'))
@ -25,7 +25,7 @@ def get_password_hash(salt, password):
return digest.hexdigest()
def get_legacy_password_hash(salt, password):
def get_legacy_password_hash(salt: str, password: str) -> str:
''' Retrieve old-style password hash. '''
digest = hashlib.sha1()
digest.update(b'1A2/$_4xVa')
@ -34,7 +34,7 @@ def get_legacy_password_hash(salt, password):
return digest.hexdigest()
def create_password():
def create_password() -> str:
alphabet = {
'c': list('bcdfghijklmnpqrstvwxyz'),
'v': list('aeiou'),
@ -44,7 +44,7 @@ def create_password():
return ''.join(random.choice(alphabet[l]) for l in list(pattern))
def is_valid_password(user, password):
def is_valid_password(user: model.User, password: str) -> bool:
assert user
salt, valid_hash = user.password_salt, user.password_hash
possible_hashes = [
@ -54,7 +54,7 @@ def is_valid_password(user, password):
return valid_hash in possible_hashes
def has_privilege(user, privilege_name):
def has_privilege(user: model.User, privilege_name: str) -> bool:
assert user
all_ranks = list(RANK_MAP.keys())
assert privilege_name in config.config['privileges']
@ -65,13 +65,13 @@ def has_privilege(user, privilege_name):
return user.rank in good_ranks
def verify_privilege(user, privilege_name):
def verify_privilege(user: model.User, privilege_name: str) -> None:
assert user
if not has_privilege(user, privilege_name):
raise errors.AuthError('Insufficient privileges to do this.')
def generate_authentication_token(user):
def generate_authentication_token(user: model.User) -> str:
''' Generate nonguessable challenge (e.g. links in password reminder). '''
assert user
digest = hashlib.md5()

View file

@ -1,21 +1,21 @@
from typing import Any, List, Dict
from datetime import datetime
class LruCacheItem:
def __init__(self, key, value):
def __init__(self, key: object, value: Any) -> None:
self.key = key
self.value = value
self.timestamp = datetime.utcnow()
class LruCache:
def __init__(self, length, delta=None):
def __init__(self, length: int) -> None:
self.length = length
self.delta = delta
self.hash = {}
self.item_list = []
self.hash = {} # type: Dict[object, LruCacheItem]
self.item_list = [] # type: List[LruCacheItem]
def insert_item(self, item):
def insert_item(self, item: LruCacheItem) -> None:
if item.key in self.hash:
item_index = next(
i
@ -31,11 +31,11 @@ class LruCache:
self.hash[item.key] = item
self.item_list.insert(0, item)
def remove_all(self):
def remove_all(self) -> None:
self.hash = {}
self.item_list = []
def remove_item(self, item):
def remove_item(self, item: LruCacheItem) -> None:
del self.hash[item.key]
del self.item_list[self.item_list.index(item)]
@ -43,22 +43,22 @@ class LruCache:
_CACHE = LruCache(length=100)
def purge():
def purge() -> None:
_CACHE.remove_all()
def has(key):
def has(key: object) -> bool:
return key in _CACHE.hash
def get(key):
def get(key: object) -> Any:
return _CACHE.hash[key].value
def remove(key):
def remove(key: object) -> None:
if has(key):
del _CACHE.hash[key]
def put(key, value):
def put(key: object, value: Any) -> None:
_CACHE.insert_item(LruCacheItem(key, value))

View file

@ -1,6 +1,7 @@
import datetime
from szurubooru import db, errors
from szurubooru.func import users, scores, util
from datetime import datetime
from typing import Any, Optional, List, Dict, Callable
from szurubooru import db, model, errors, rest
from szurubooru.func import users, scores, util, serialization
class InvalidCommentIdError(errors.ValidationError):
@ -15,52 +16,87 @@ class EmptyCommentTextError(errors.ValidationError):
pass
def serialize_comment(comment, auth_user, options=None):
return util.serialize_entity(
comment,
{
'id': lambda: comment.comment_id,
'user':
lambda: users.serialize_micro_user(comment.user, auth_user),
'postId': lambda: comment.post.post_id,
'version': lambda: comment.version,
'text': lambda: comment.text,
'creationTime': lambda: comment.creation_time,
'lastEditTime': lambda: comment.last_edit_time,
'score': lambda: comment.score,
'ownScore': lambda: scores.get_score(comment, auth_user),
},
options)
class CommentSerializer(serialization.BaseSerializer):
def __init__(self, comment: model.Comment, auth_user: model.User) -> None:
self.comment = comment
self.auth_user = auth_user
def _serializers(self) -> Dict[str, Callable[[], Any]]:
return {
'id': self.serialize_id,
'user': self.serialize_user,
'postId': self.serialize_post_id,
'version': self.serialize_version,
'text': self.serialize_text,
'creationTime': self.serialize_creation_time,
'lastEditTime': self.serialize_last_edit_time,
'score': self.serialize_score,
'ownScore': self.serialize_own_score,
}
def serialize_id(self) -> Any:
return self.comment.comment_id
def serialize_user(self) -> Any:
return users.serialize_micro_user(self.comment.user, self.auth_user)
def serialize_post_id(self) -> Any:
return self.comment.post.post_id
def serialize_version(self) -> Any:
return self.comment.version
def serialize_text(self) -> Any:
return self.comment.text
def serialize_creation_time(self) -> Any:
return self.comment.creation_time
def serialize_last_edit_time(self) -> Any:
return self.comment.last_edit_time
def serialize_score(self) -> Any:
return self.comment.score
def serialize_own_score(self) -> Any:
return scores.get_score(self.comment, self.auth_user)
def try_get_comment_by_id(comment_id):
try:
def serialize_comment(
comment: model.Comment,
auth_user: model.User,
options: List[str]=[]) -> rest.Response:
if comment is None:
return None
return CommentSerializer(comment, auth_user).serialize(options)
def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]:
comment_id = int(comment_id)
except ValueError:
raise InvalidCommentIdError('Invalid comment ID: %r.' % comment_id)
return db.session \
.query(db.Comment) \
.filter(db.Comment.comment_id == comment_id) \
.query(model.Comment) \
.filter(model.Comment.comment_id == comment_id) \
.one_or_none()
def get_comment_by_id(comment_id):
def get_comment_by_id(comment_id: int) -> model.Comment:
comment = try_get_comment_by_id(comment_id)
if comment:
return comment
raise CommentNotFoundError('Comment %r not found.' % comment_id)
def create_comment(user, post, text):
comment = db.Comment()
def create_comment(
user: model.User, post: model.Post, text: str) -> model.Comment:
comment = model.Comment()
comment.user = user
comment.post = post
update_comment_text(comment, text)
comment.creation_time = datetime.datetime.utcnow()
comment.creation_time = datetime.utcnow()
return comment
def update_comment_text(comment, text):
def update_comment_text(comment: model.Comment, text: str) -> None:
assert comment
if not text:
raise EmptyCommentTextError('Comment text cannot be empty.')

View file

@ -1,21 +1,26 @@
def get_list_diff(old, new):
value = {'type': 'list change', 'added': [], 'removed': []}
from typing import List, Dict, Any
def get_list_diff(old: List[Any], new: List[Any]) -> Any:
equal = True
removed = [] # type: List[Any]
added = [] # type: List[Any]
for item in old:
if item not in new:
equal = False
value['removed'].append(item)
removed.append(item)
for item in new:
if item not in old:
equal = False
value['added'].append(item)
added.append(item)
return None if equal else value
return None if equal else {
'type': 'list change', 'added': added, 'removed': removed}
def get_dict_diff(old, new):
def get_dict_diff(old: Dict[str, Any], new: Dict[str, Any]) -> Any:
value = {}
equal = True

View file

@ -1,32 +1,34 @@
import datetime
from szurubooru import db, errors
from typing import Any, Optional, Callable, Tuple
from datetime import datetime
from szurubooru import db, model, errors
class InvalidFavoriteTargetError(errors.ValidationError):
pass
def _get_table_info(entity):
def _get_table_info(
entity: model.Base) -> Tuple[model.Base, Callable[[model.Base], Any]]:
assert entity
resource_type, _, _ = db.util.get_resource_info(entity)
resource_type, _, _ = model.util.get_resource_info(entity)
if resource_type == 'post':
return db.PostFavorite, lambda table: table.post_id
return model.PostFavorite, lambda table: table.post_id
raise InvalidFavoriteTargetError()
def _get_fav_entity(entity, user):
def _get_fav_entity(entity: model.Base, user: model.User) -> model.Base:
assert entity
assert user
return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
return model.util.get_aux_entity(db.session, _get_table_info, entity, user)
def has_favorited(entity, user):
def has_favorited(entity: model.Base, user: model.User) -> bool:
assert entity
assert user
return _get_fav_entity(entity, user) is not None
def unset_favorite(entity, user):
def unset_favorite(entity: model.Base, user: Optional[model.User]) -> None:
assert entity
assert user
fav_entity = _get_fav_entity(entity, user)
@ -34,7 +36,7 @@ def unset_favorite(entity, user):
db.session.delete(fav_entity)
def set_favorite(entity, user):
def set_favorite(entity: model.Base, user: Optional[model.User]) -> None:
from szurubooru.func import scores
assert entity
assert user
@ -48,5 +50,5 @@ def set_favorite(entity, user):
fav_entity = table()
setattr(fav_entity, get_column(table).name, get_column(entity))
fav_entity.user = user
fav_entity.time = datetime.datetime.utcnow()
fav_entity.time = datetime.utcnow()
db.session.add(fav_entity)

View file

@ -1,27 +1,28 @@
import datetime
from typing import Optional
from datetime import datetime, timedelta
from szurubooru.func import files, util
MAX_MINUTES = 60
def _get_path(checksum):
def _get_path(checksum: str) -> str:
return 'temporary-uploads/%s.dat' % checksum
def purge_old_uploads():
now = datetime.datetime.now()
def purge_old_uploads() -> None:
now = datetime.now()
for file in files.scan('temporary-uploads'):
file_time = datetime.datetime.fromtimestamp(file.stat().st_ctime)
if now - file_time > datetime.timedelta(minutes=MAX_MINUTES):
file_time = datetime.fromtimestamp(file.stat().st_ctime)
if now - file_time > timedelta(minutes=MAX_MINUTES):
files.delete('temporary-uploads/%s' % file.name)
def get(checksum):
def get(checksum: str) -> Optional[bytes]:
return files.get('temporary-uploads/%s.dat' % checksum)
def save(content):
def save(content: bytes) -> str:
checksum = util.get_sha1(content)
path = _get_path(checksum)
if not files.has(path):

View file

@ -1,32 +1,33 @@
from typing import Any, Optional, List
import os
from szurubooru import config
def _get_full_path(path):
def _get_full_path(path: str) -> str:
return os.path.join(config.config['data_dir'], path)
def delete(path):
def delete(path: str) -> None:
full_path = _get_full_path(path)
if os.path.exists(full_path):
os.unlink(full_path)
def has(path):
def has(path: str) -> bool:
return os.path.exists(_get_full_path(path))
def scan(path):
def scan(path: str) -> List[os.DirEntry]:
if has(path):
return os.scandir(_get_full_path(path))
return list(os.scandir(_get_full_path(path)))
return []
def move(source_path, target_path):
return os.rename(_get_full_path(source_path), _get_full_path(target_path))
def move(source_path: str, target_path: str) -> None:
os.rename(_get_full_path(source_path), _get_full_path(target_path))
def get(path):
def get(path: str) -> Optional[bytes]:
full_path = _get_full_path(path)
if not os.path.exists(full_path):
return None
@ -34,7 +35,7 @@ def get(path):
return handle.read()
def save(path, content):
def save(path: str, content: bytes) -> None:
full_path = _get_full_path(path)
os.makedirs(os.path.dirname(full_path), exist_ok=True)
with open(full_path, 'wb') as handle:

View file

@ -1,6 +1,7 @@
import logging
from io import BytesIO
from datetime import datetime
from typing import Any, Optional, Tuple, Set, List, Callable
import elasticsearch
import elasticsearch_dsl
import numpy as np
@ -10,13 +11,8 @@ from szurubooru import config, errors
# pylint: disable=invalid-name
logger = logging.getLogger(__name__)
es = elasticsearch.Elasticsearch([{
'host': config.config['elasticsearch']['host'],
'port': config.config['elasticsearch']['port'],
}])
# Math based on paper from H. Chi Wong, Marshall Bern and David Goldber
# Math based on paper from H. Chi Wong, Marshall Bern and David Goldberg
# Math code taken from https://github.com/ascribe/image-match
# (which is licensed under Apache 2 license)
@ -32,14 +28,27 @@ MAX_WORDS = 63
ES_DOC_TYPE = 'image'
ES_MAX_RESULTS = 100
Window = Tuple[Tuple[float, float], Tuple[float, float]]
NpMatrix = Any
def _preprocess_image(image_or_path):
img = Image.open(BytesIO(image_or_path))
def _get_session() -> elasticsearch.Elasticsearch:
return elasticsearch.Elasticsearch([{
'host': config.config['elasticsearch']['host'],
'port': config.config['elasticsearch']['port'],
}])
def _preprocess_image(content: bytes) -> NpMatrix:
img = Image.open(BytesIO(content))
img = img.convert('RGB')
return rgb2gray(np.asarray(img, dtype=np.uint8))
def _crop_image(image, lower_percentile, upper_percentile):
def _crop_image(
image: NpMatrix,
lower_percentile: float,
upper_percentile: float) -> Window:
rw = np.cumsum(np.sum(np.abs(np.diff(image, axis=1)), axis=1))
cw = np.cumsum(np.sum(np.abs(np.diff(image, axis=0)), axis=0))
upper_column_limit = np.searchsorted(
@ -56,16 +65,19 @@ def _crop_image(image, lower_percentile, upper_percentile):
if lower_column_limit > upper_column_limit:
lower_column_limit = int(lower_percentile / 100. * image.shape[1])
upper_column_limit = int(upper_percentile / 100. * image.shape[1])
return [
return (
(lower_row_limit, upper_row_limit),
(lower_column_limit, upper_column_limit)]
(lower_column_limit, upper_column_limit))
def _normalize_and_threshold(diff_array, identical_tolerance, n_levels):
def _normalize_and_threshold(
diff_array: NpMatrix,
identical_tolerance: float,
n_levels: int) -> None:
mask = np.abs(diff_array) < identical_tolerance
diff_array[mask] = 0.
if np.all(mask):
return None
return
positive_cutoffs = np.percentile(
diff_array[diff_array > 0.], np.linspace(0, 100, n_levels + 1))
negative_cutoffs = np.percentile(
@ -82,18 +94,24 @@ def _normalize_and_threshold(diff_array, identical_tolerance, n_levels):
diff_array[
(diff_array <= interval[0]) & (diff_array >= interval[1])] = \
-(level + 1)
return None
def _compute_grid_points(image, n, window=None):
def _compute_grid_points(
image: NpMatrix,
n: float,
window: Window=None) -> Tuple[NpMatrix, NpMatrix]:
if window is None:
window = [(0, image.shape[0]), (0, image.shape[1])]
window = ((0, image.shape[0]), (0, image.shape[1]))
x_coords = np.linspace(window[0][0], window[0][1], n + 2, dtype=int)[1:-1]
y_coords = np.linspace(window[1][0], window[1][1], n + 2, dtype=int)[1:-1]
return x_coords, y_coords
def _compute_mean_level(image, x_coords, y_coords, p):
def _compute_mean_level(
image: NpMatrix,
x_coords: NpMatrix,
y_coords: NpMatrix,
p: Optional[float]) -> NpMatrix:
if p is None:
p = max([2.0, int(0.5 + min(image.shape) / 20.)])
avg_grey = np.zeros((x_coords.shape[0], y_coords.shape[0]))
@ -108,7 +126,7 @@ def _compute_mean_level(image, x_coords, y_coords, p):
return avg_grey
def _compute_differentials(grey_level_matrix):
def _compute_differentials(grey_level_matrix: NpMatrix) -> NpMatrix:
flipped = np.fliplr(grey_level_matrix)
right_neighbors = -np.concatenate(
(
@ -152,8 +170,8 @@ def _compute_differentials(grey_level_matrix):
lower_right_neighbors]))
def _generate_signature(path_or_image):
im_array = _preprocess_image(path_or_image)
def _generate_signature(content: bytes) -> NpMatrix:
im_array = _preprocess_image(content)
image_limits = _crop_image(
im_array,
lower_percentile=LOWER_PERCENTILE,
@ -169,7 +187,7 @@ def _generate_signature(path_or_image):
return np.ravel(diff_matrix).astype('int8')
def _get_words(array, k, n):
def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix:
word_positions = np.linspace(
0, array.shape[0], n, endpoint=False).astype('int')
assert k <= array.shape[0]
@ -187,21 +205,23 @@ def _get_words(array, k, n):
return words
def _words_to_int(word_array):
def _words_to_int(word_array: NpMatrix) -> NpMatrix:
width = word_array.shape[1]
coding_vector = 3**np.arange(width)
return np.dot(word_array + 1, coding_vector)
def _max_contrast(array):
def _max_contrast(array: NpMatrix) -> None:
array[array > 0] = 1
array[array < 0] = -1
return None
def _normalized_distance(_target_array, _vec, nan_value=1.0):
target_array = _target_array.astype(int)
vec = _vec.astype(int)
def _normalized_distance(
target_array: NpMatrix,
vec: NpMatrix,
nan_value: float=1.0) -> List[float]:
target_array = target_array.astype(int)
vec = vec.astype(int)
topvec = np.linalg.norm(vec - target_array, axis=1)
norm1 = np.linalg.norm(vec, axis=0)
norm2 = np.linalg.norm(target_array, axis=1)
@ -210,9 +230,9 @@ def _normalized_distance(_target_array, _vec, nan_value=1.0):
return finvec
def _safety_blanket(default_param_factory):
def wrapper_outer(target_function):
def wrapper_inner(*args, **kwargs):
def _safety_blanket(default_param_factory: Callable[[], Any]) -> Callable:
def wrapper_outer(target_function: Callable) -> Callable:
def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
try:
return target_function(*args, **kwargs)
except elasticsearch.exceptions.NotFoundError:
@ -226,20 +246,20 @@ def _safety_blanket(default_param_factory):
except IOError:
raise errors.ProcessingError('Not an image.')
except Exception as ex:
raise errors.ThirdPartyError('Unknown error (%s).', ex)
raise errors.ThirdPartyError('Unknown error (%s).' % ex)
return wrapper_inner
return wrapper_outer
class Lookalike:
def __init__(self, score, distance, path):
def __init__(self, score: int, distance: float, path: Any) -> None:
self.score = score
self.distance = distance
self.path = path
@_safety_blanket(lambda: None)
def add_image(path, image_content):
def add_image(path: str, image_content: bytes) -> None:
assert path
assert image_content
signature = _generate_signature(image_content)
@ -253,7 +273,7 @@ def add_image(path, image_content):
for i in range(MAX_WORDS):
record['simple_word_' + str(i)] = words[i].tolist()
es.index(
_get_session().index(
index=config.config['elasticsearch']['index'],
doc_type=ES_DOC_TYPE,
body=record,
@ -261,20 +281,20 @@ def add_image(path, image_content):
@_safety_blanket(lambda: None)
def delete_image(path):
def delete_image(path: str) -> None:
assert path
es.delete_by_query(
_get_session().delete_by_query(
index=config.config['elasticsearch']['index'],
doc_type=ES_DOC_TYPE,
body={'query': {'term': {'path': path}}})
@_safety_blanket(lambda: [])
def search_by_image(image_content):
def search_by_image(image_content: bytes) -> List[Lookalike]:
signature = _generate_signature(image_content)
words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS)
res = es.search(
res = _get_session().search(
index=config.config['elasticsearch']['index'],
doc_type=ES_DOC_TYPE,
body={
@ -299,7 +319,7 @@ def search_by_image(image_content):
sigs = np.array([x['_source']['signature'] for x in res])
dists = _normalized_distance(sigs, np.array(signature))
ids = set()
ids = set() # type: Set[int]
ret = []
for item, dist in zip(res, dists):
id = item['_id']
@ -314,8 +334,8 @@ def search_by_image(image_content):
@_safety_blanket(lambda: None)
def purge():
es.delete_by_query(
def purge() -> None:
_get_session().delete_by_query(
index=config.config['elasticsearch']['index'],
doc_type=ES_DOC_TYPE,
body={'query': {'match_all': {}}},
@ -323,10 +343,10 @@ def purge():
@_safety_blanket(lambda: set())
def get_all_paths():
def get_all_paths() -> Set[str]:
search = (
elasticsearch_dsl.Search(
using=es,
using=_get_session(),
index=config.config['elasticsearch']['index'],
doc_type=ES_DOC_TYPE)
.source(['path']))

View file

@ -1,3 +1,4 @@
from typing import List
import logging
import json
import shlex
@ -15,23 +16,23 @@ _SCALE_FIT_FMT = \
class Image:
def __init__(self, content):
def __init__(self, content: bytes) -> None:
self.content = content
self._reload_info()
@property
def width(self):
def width(self) -> int:
return self.info['streams'][0]['width']
@property
def height(self):
def height(self) -> int:
return self.info['streams'][0]['height']
@property
def frames(self):
def frames(self) -> int:
return self.info['streams'][0]['nb_read_frames']
def resize_fill(self, width, height):
def resize_fill(self, width: int, height: int) -> None:
cli = [
'-i', '{path}',
'-f', 'image2',
@ -53,7 +54,7 @@ class Image:
assert self.content
self._reload_info()
def to_png(self):
def to_png(self) -> bytes:
return self._execute([
'-i', '{path}',
'-f', 'image2',
@ -63,7 +64,7 @@ class Image:
'-',
])
def to_jpeg(self):
def to_jpeg(self) -> bytes:
return self._execute([
'-f', 'lavfi',
'-i', 'color=white:s=%dx%d' % (self.width, self.height),
@ -76,7 +77,7 @@ class Image:
'-',
])
def _execute(self, cli, program='ffmpeg'):
def _execute(self, cli: List[str], program: str='ffmpeg') -> bytes:
extension = mime.get_extension(mime.get_mime_type(self.content))
assert extension
with util.create_temp_file(suffix='.' + extension) as handle:
@ -99,7 +100,7 @@ class Image:
'Error while processing image.\n' + err.decode('utf-8'))
return out
def _reload_info(self):
def _reload_info(self) -> None:
self.info = json.loads(self._execute([
'-i', '{path}',
'-of', 'json',

View file

@ -3,7 +3,7 @@ import email.mime.text
from szurubooru import config
def send_mail(sender, recipient, subject, body):
def send_mail(sender: str, recipient: str, subject: str, body: str) -> None:
msg = email.mime.text.MIMEText(body)
msg['Subject'] = subject
msg['From'] = sender

View file

@ -1,7 +1,8 @@
import re
from typing import Optional
def get_mime_type(content):
def get_mime_type(content: bytes) -> str:
if not content:
return 'application/octet-stream'
@ -26,7 +27,7 @@ def get_mime_type(content):
return 'application/octet-stream'
def get_extension(mime_type):
def get_extension(mime_type: str) -> Optional[str]:
extension_map = {
'application/x-shockwave-flash': 'swf',
'image/gif': 'gif',
@ -39,19 +40,19 @@ def get_extension(mime_type):
return extension_map.get((mime_type or '').strip().lower(), None)
def is_flash(mime_type):
def is_flash(mime_type: str) -> bool:
return mime_type.lower() == 'application/x-shockwave-flash'
def is_video(mime_type):
def is_video(mime_type: str) -> bool:
return mime_type.lower() in ('application/ogg', 'video/mp4', 'video/webm')
def is_image(mime_type):
def is_image(mime_type: str) -> bool:
return mime_type.lower() in ('image/jpeg', 'image/png', 'image/gif')
def is_animated_gif(content):
def is_animated_gif(content: bytes) -> bool:
pattern = b'\x21\xF9\x04[\x00-\xFF]{4}\x00[\x2C\x21]'
return get_mime_type(content) == 'image/gif' \
and len(re.findall(pattern, content)) > 1

View file

@ -2,7 +2,7 @@ import urllib.request
from szurubooru import errors
def download(url):
def download(url: str) -> bytes:
assert url
request = urllib.request.Request(url)
request.add_header('Referer', url)

View file

@ -1,8 +1,10 @@
import datetime
import sqlalchemy
from szurubooru import config, db, errors
from typing import Any, Optional, Tuple, List, Dict, Callable
from datetime import datetime
import sqlalchemy as sa
from szurubooru import config, db, model, errors, rest
from szurubooru.func import (
users, scores, comments, tags, util, mime, images, files, image_hash)
users, scores, comments, tags, util,
mime, images, files, image_hash, serialization)
EMPTY_PIXEL = \
@ -20,7 +22,7 @@ class PostAlreadyFeaturedError(errors.ValidationError):
class PostAlreadyUploadedError(errors.ValidationError):
def __init__(self, other_post):
def __init__(self, other_post: model.Post) -> None:
super().__init__(
'Post already uploaded (%d)' % other_post.post_id,
{
@ -58,30 +60,30 @@ class InvalidPostFlagError(errors.ValidationError):
class PostLookalike(image_hash.Lookalike):
def __init__(self, score, distance, post):
def __init__(self, score: int, distance: float, post: model.Post) -> None:
super().__init__(score, distance, post.post_id)
self.post = post
SAFETY_MAP = {
db.Post.SAFETY_SAFE: 'safe',
db.Post.SAFETY_SKETCHY: 'sketchy',
db.Post.SAFETY_UNSAFE: 'unsafe',
model.Post.SAFETY_SAFE: 'safe',
model.Post.SAFETY_SKETCHY: 'sketchy',
model.Post.SAFETY_UNSAFE: 'unsafe',
}
TYPE_MAP = {
db.Post.TYPE_IMAGE: 'image',
db.Post.TYPE_ANIMATION: 'animation',
db.Post.TYPE_VIDEO: 'video',
db.Post.TYPE_FLASH: 'flash',
model.Post.TYPE_IMAGE: 'image',
model.Post.TYPE_ANIMATION: 'animation',
model.Post.TYPE_VIDEO: 'video',
model.Post.TYPE_FLASH: 'flash',
}
FLAG_MAP = {
db.Post.FLAG_LOOP: 'loop',
model.Post.FLAG_LOOP: 'loop',
}
def get_post_content_url(post):
def get_post_content_url(post: model.Post) -> str:
assert post
return '%s/posts/%d.%s' % (
config.config['data_url'].rstrip('/'),
@ -89,31 +91,31 @@ def get_post_content_url(post):
mime.get_extension(post.mime_type) or 'dat')
def get_post_thumbnail_url(post):
def get_post_thumbnail_url(post: model.Post) -> str:
assert post
return '%s/generated-thumbnails/%d.jpg' % (
config.config['data_url'].rstrip('/'),
post.post_id)
def get_post_content_path(post):
def get_post_content_path(post: model.Post) -> str:
assert post
assert post.post_id
return 'posts/%d.%s' % (
post.post_id, mime.get_extension(post.mime_type) or 'dat')
def get_post_thumbnail_path(post):
def get_post_thumbnail_path(post: model.Post) -> str:
assert post
return 'generated-thumbnails/%d.jpg' % (post.post_id)
def get_post_thumbnail_backup_path(post):
def get_post_thumbnail_backup_path(post: model.Post) -> str:
assert post
return 'posts/custom-thumbnails/%d.dat' % (post.post_id)
def serialize_note(note):
def serialize_note(note: model.PostNote) -> rest.Response:
assert note
return {
'polygon': note.polygon,
@ -121,113 +123,216 @@ def serialize_note(note):
}
def serialize_post(post, auth_user, options=None):
return util.serialize_entity(
post,
class PostSerializer(serialization.BaseSerializer):
def __init__(self, post: model.Post, auth_user: model.User) -> None:
self.post = post
self.auth_user = auth_user
def _serializers(self) -> Dict[str, Callable[[], Any]]:
return {
'id': self.serialize_id,
'version': self.serialize_version,
'creationTime': self.serialize_creation_time,
'lastEditTime': self.serialize_last_edit_time,
'safety': self.serialize_safety,
'source': self.serialize_source,
'type': self.serialize_type,
'mimeType': self.serialize_mime,
'checksum': self.serialize_checksum,
'fileSize': self.serialize_file_size,
'canvasWidth': self.serialize_canvas_width,
'canvasHeight': self.serialize_canvas_height,
'contentUrl': self.serialize_content_url,
'thumbnailUrl': self.serialize_thumbnail_url,
'flags': self.serialize_flags,
'tags': self.serialize_tags,
'relations': self.serialize_relations,
'user': self.serialize_user,
'score': self.serialize_score,
'ownScore': self.serialize_own_score,
'ownFavorite': self.serialize_own_favorite,
'tagCount': self.serialize_tag_count,
'favoriteCount': self.serialize_favorite_count,
'commentCount': self.serialize_comment_count,
'noteCount': self.serialize_note_count,
'relationCount': self.serialize_relation_count,
'featureCount': self.serialize_feature_count,
'lastFeatureTime': self.serialize_last_feature_time,
'favoritedBy': self.serialize_favorited_by,
'hasCustomThumbnail': self.serialize_has_custom_thumbnail,
'notes': self.serialize_notes,
'comments': self.serialize_comments,
}
def serialize_id(self) -> Any:
return self.post.post_id
def serialize_version(self) -> Any:
return self.post.version
def serialize_creation_time(self) -> Any:
return self.post.creation_time
def serialize_last_edit_time(self) -> Any:
return self.post.last_edit_time
def serialize_safety(self) -> Any:
return SAFETY_MAP[self.post.safety]
def serialize_source(self) -> Any:
return self.post.source
def serialize_type(self) -> Any:
return TYPE_MAP[self.post.type]
def serialize_mime(self) -> Any:
return self.post.mime_type
def serialize_checksum(self) -> Any:
return self.post.checksum
def serialize_file_size(self) -> Any:
return self.post.file_size
def serialize_canvas_width(self) -> Any:
return self.post.canvas_width
def serialize_canvas_height(self) -> Any:
return self.post.canvas_height
def serialize_content_url(self) -> Any:
return get_post_content_url(self.post)
def serialize_thumbnail_url(self) -> Any:
return get_post_thumbnail_url(self.post)
def serialize_flags(self) -> Any:
return self.post.flags
def serialize_tags(self) -> Any:
return [tag.names[0].name for tag in tags.sort_tags(self.post.tags)]
def serialize_relations(self) -> Any:
return sorted(
{
'id': lambda: post.post_id,
'version': lambda: post.version,
'creationTime': lambda: post.creation_time,
'lastEditTime': lambda: post.last_edit_time,
'safety': lambda: SAFETY_MAP[post.safety],
'source': lambda: post.source,
'type': lambda: TYPE_MAP[post.type],
'mimeType': lambda: post.mime_type,
'checksum': lambda: post.checksum,
'fileSize': lambda: post.file_size,
'canvasWidth': lambda: post.canvas_width,
'canvasHeight': lambda: post.canvas_height,
'contentUrl': lambda: get_post_content_url(post),
'thumbnailUrl': lambda: get_post_thumbnail_url(post),
'flags': lambda: post.flags,
'tags': lambda: [
tag.names[0].name for tag in tags.sort_tags(post.tags)],
'relations': lambda: sorted(
{
post['id']:
post for post in [
serialize_micro_post(rel, auth_user)
for rel in post.relations]
post['id']: post
for post in [
serialize_micro_post(rel, self.auth_user)
for rel in self.post.relations]
}.values(),
key=lambda post: post['id']),
'user': lambda: users.serialize_micro_user(post.user, auth_user),
'score': lambda: post.score,
'ownScore': lambda: scores.get_score(post, auth_user),
'ownFavorite': lambda: len([
user for user in post.favorited_by
if user.user_id == auth_user.user_id]
) > 0,
'tagCount': lambda: post.tag_count,
'favoriteCount': lambda: post.favorite_count,
'commentCount': lambda: post.comment_count,
'noteCount': lambda: post.note_count,
'relationCount': lambda: post.relation_count,
'featureCount': lambda: post.feature_count,
'lastFeatureTime': lambda: post.last_feature_time,
'favoritedBy': lambda: [
users.serialize_micro_user(rel.user, auth_user)
for rel in post.favorited_by
],
'hasCustomThumbnail':
lambda: files.has(get_post_thumbnail_backup_path(post)),
'notes': lambda: sorted(
[serialize_note(note) for note in post.notes],
key=lambda x: x['polygon']),
'comments': lambda: [
comments.serialize_comment(comment, auth_user)
key=lambda post: post['id'])
def serialize_user(self) -> Any:
return users.serialize_micro_user(self.post.user, self.auth_user)
def serialize_score(self) -> Any:
return self.post.score
def serialize_own_score(self) -> Any:
return scores.get_score(self.post, self.auth_user)
def serialize_own_favorite(self) -> Any:
return len([
user for user in self.post.favorited_by
if user.user_id == self.auth_user.user_id]
) > 0
def serialize_tag_count(self) -> Any:
return self.post.tag_count
def serialize_favorite_count(self) -> Any:
return self.post.favorite_count
def serialize_comment_count(self) -> Any:
return self.post.comment_count
def serialize_note_count(self) -> Any:
return self.post.note_count
def serialize_relation_count(self) -> Any:
return self.post.relation_count
def serialize_feature_count(self) -> Any:
return self.post.feature_count
def serialize_last_feature_time(self) -> Any:
return self.post.last_feature_time
def serialize_favorited_by(self) -> Any:
return [
users.serialize_micro_user(rel.user, self.auth_user)
for rel in self.post.favorited_by
]
def serialize_has_custom_thumbnail(self) -> Any:
return files.has(get_post_thumbnail_backup_path(self.post))
def serialize_notes(self) -> Any:
return sorted(
[serialize_note(note) for note in self.post.notes],
key=lambda x: x['polygon'])
def serialize_comments(self) -> Any:
return [
comments.serialize_comment(comment, self.auth_user)
for comment in sorted(
post.comments,
key=lambda comment: comment.creation_time)],
},
options)
self.post.comments,
key=lambda comment: comment.creation_time)]
def serialize_micro_post(post, auth_user):
def serialize_post(
post: Optional[model.Post],
auth_user: model.User,
options: List[str]=[]) -> Optional[rest.Response]:
if not post:
return None
return PostSerializer(post, auth_user).serialize(options)
def serialize_micro_post(
post: model.Post, auth_user: model.User) -> Optional[rest.Response]:
return serialize_post(
post,
auth_user=auth_user,
options=['id', 'thumbnailUrl'])
post, auth_user=auth_user, options=['id', 'thumbnailUrl'])
def get_post_count():
return db.session.query(sqlalchemy.func.count(db.Post.post_id)).one()[0]
def get_post_count() -> int:
return db.session.query(sa.func.count(model.Post.post_id)).one()[0]
def try_get_post_by_id(post_id):
try:
post_id = int(post_id)
except ValueError:
raise InvalidPostIdError('Invalid post ID: %r.' % post_id)
def try_get_post_by_id(post_id: int) -> Optional[model.Post]:
return db.session \
.query(db.Post) \
.filter(db.Post.post_id == post_id) \
.query(model.Post) \
.filter(model.Post.post_id == post_id) \
.one_or_none()
def get_post_by_id(post_id):
def get_post_by_id(post_id: int) -> model.Post:
post = try_get_post_by_id(post_id)
if not post:
raise PostNotFoundError('Post %r not found.' % post_id)
return post
def try_get_current_post_feature():
def try_get_current_post_feature() -> Optional[model.PostFeature]:
return db.session \
.query(db.PostFeature) \
.order_by(db.PostFeature.time.desc()) \
.query(model.PostFeature) \
.order_by(model.PostFeature.time.desc()) \
.first()
def try_get_featured_post():
def try_get_featured_post() -> Optional[model.Post]:
post_feature = try_get_current_post_feature()
return post_feature.post if post_feature else None
def create_post(content, tag_names, user):
post = db.Post()
post.safety = db.Post.SAFETY_SAFE
def create_post(
content: bytes,
tag_names: List[str],
user: Optional[model.User]) -> Tuple[model.Post, List[model.Tag]]:
post = model.Post()
post.safety = model.Post.SAFETY_SAFE
post.user = user
post.creation_time = datetime.datetime.utcnow()
post.creation_time = datetime.utcnow()
post.flags = []
post.type = ''
@ -240,7 +345,7 @@ def create_post(content, tag_names, user):
return (post, new_tags)
def update_post_safety(post, safety):
def update_post_safety(post: model.Post, safety: str) -> None:
assert post
safety = util.flip(SAFETY_MAP).get(safety, None)
if not safety:
@ -249,30 +354,33 @@ def update_post_safety(post, safety):
post.safety = safety
def update_post_source(post, source):
def update_post_source(post: model.Post, source: Optional[str]) -> None:
assert post
if util.value_exceeds_column_size(source, db.Post.source):
if util.value_exceeds_column_size(source, model.Post.source):
raise InvalidPostSourceError('Source is too long.')
post.source = source
post.source = source or None
@sqlalchemy.events.event.listens_for(db.Post, 'after_insert')
def _after_post_insert(_mapper, _connection, post):
@sa.events.event.listens_for(model.Post, 'after_insert')
def _after_post_insert(
_mapper: Any, _connection: Any, post: model.Post) -> None:
_sync_post_content(post)
@sqlalchemy.events.event.listens_for(db.Post, 'after_update')
def _after_post_update(_mapper, _connection, post):
@sa.events.event.listens_for(model.Post, 'after_update')
def _after_post_update(
_mapper: Any, _connection: Any, post: model.Post) -> None:
_sync_post_content(post)
@sqlalchemy.events.event.listens_for(db.Post, 'before_delete')
def _before_post_delete(_mapper, _connection, post):
@sa.events.event.listens_for(model.Post, 'before_delete')
def _before_post_delete(
_mapper: Any, _connection: Any, post: model.Post) -> None:
if post.post_id:
image_hash.delete_image(post.post_id)
def _sync_post_content(post):
def _sync_post_content(post: model.Post) -> None:
regenerate_thumb = False
if hasattr(post, '__content'):
@ -281,7 +389,7 @@ def _sync_post_content(post):
delattr(post, '__content')
regenerate_thumb = True
if post.post_id and post.type in (
db.Post.TYPE_IMAGE, db.Post.TYPE_ANIMATION):
model.Post.TYPE_IMAGE, model.Post.TYPE_ANIMATION):
image_hash.delete_image(post.post_id)
image_hash.add_image(post.post_id, content)
@ -299,29 +407,29 @@ def _sync_post_content(post):
generate_post_thumbnail(post)
def update_post_content(post, content):
def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
assert post
if not content:
raise InvalidPostContentError('Post content missing.')
post.mime_type = mime.get_mime_type(content)
if mime.is_flash(post.mime_type):
post.type = db.Post.TYPE_FLASH
post.type = model.Post.TYPE_FLASH
elif mime.is_image(post.mime_type):
if mime.is_animated_gif(content):
post.type = db.Post.TYPE_ANIMATION
post.type = model.Post.TYPE_ANIMATION
else:
post.type = db.Post.TYPE_IMAGE
post.type = model.Post.TYPE_IMAGE
elif mime.is_video(post.mime_type):
post.type = db.Post.TYPE_VIDEO
post.type = model.Post.TYPE_VIDEO
else:
raise InvalidPostContentError(
'Unhandled file type: %r' % post.mime_type)
post.checksum = util.get_sha1(content)
other_post = db.session \
.query(db.Post) \
.filter(db.Post.checksum == post.checksum) \
.filter(db.Post.post_id != post.post_id) \
.query(model.Post) \
.filter(model.Post.checksum == post.checksum) \
.filter(model.Post.post_id != post.post_id) \
.one_or_none()
if other_post \
and other_post.post_id \
@ -343,18 +451,20 @@ def update_post_content(post, content):
setattr(post, '__content', content)
def update_post_thumbnail(post, content=None):
def update_post_thumbnail(
post: model.Post, content: Optional[bytes]=None) -> None:
assert post
setattr(post, '__thumbnail', content)
def generate_post_thumbnail(post):
def generate_post_thumbnail(post: model.Post) -> None:
assert post
if files.has(get_post_thumbnail_backup_path(post)):
content = files.get(get_post_thumbnail_backup_path(post))
else:
content = files.get(get_post_content_path(post))
try:
assert content
image = images.Image(content)
image.resize_fill(
int(config.config['thumbnails']['post_width']),
@ -364,14 +474,15 @@ def generate_post_thumbnail(post):
files.save(get_post_thumbnail_path(post), EMPTY_PIXEL)
def update_post_tags(post, tag_names):
def update_post_tags(
post: model.Post, tag_names: List[str]) -> List[model.Tag]:
assert post
existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
post.tags = existing_tags + new_tags
return new_tags
def update_post_relations(post, new_post_ids):
def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None:
assert post
try:
new_post_ids = [int(id) for id in new_post_ids]
@ -382,8 +493,8 @@ def update_post_relations(post, new_post_ids):
old_post_ids = [int(p.post_id) for p in old_posts]
if new_post_ids:
new_posts = db.session \
.query(db.Post) \
.filter(db.Post.post_id.in_(new_post_ids)) \
.query(model.Post) \
.filter(model.Post.post_id.in_(new_post_ids)) \
.all()
else:
new_posts = []
@ -402,7 +513,7 @@ def update_post_relations(post, new_post_ids):
relation.relations.append(post)
def update_post_notes(post, notes):
def update_post_notes(post: model.Post, notes: Any) -> None:
assert post
post.notes = []
for note in notes:
@ -433,13 +544,13 @@ def update_post_notes(post, notes):
except ValueError:
raise InvalidPostNoteError(
'A point in note\'s polygon must be numeric.')
if util.value_exceeds_column_size(note['text'], db.PostNote.text):
if util.value_exceeds_column_size(note['text'], model.PostNote.text):
raise InvalidPostNoteError('Note text is too long.')
post.notes.append(
db.PostNote(polygon=note['polygon'], text=str(note['text'])))
model.PostNote(polygon=note['polygon'], text=str(note['text'])))
def update_post_flags(post, flags):
def update_post_flags(post: model.Post, flags: List[str]) -> None:
assert post
target_flags = []
for flag in flags:
@ -451,88 +562,95 @@ def update_post_flags(post, flags):
post.flags = target_flags
def feature_post(post, user):
def feature_post(post: model.Post, user: Optional[model.User]) -> None:
assert post
post_feature = db.PostFeature()
post_feature.time = datetime.datetime.utcnow()
post_feature = model.PostFeature()
post_feature.time = datetime.utcnow()
post_feature.post = post
post_feature.user = user
db.session.add(post_feature)
def delete(post):
def delete(post: model.Post) -> None:
assert post
db.session.delete(post)
def merge_posts(source_post, target_post, replace_content):
def merge_posts(
source_post: model.Post,
target_post: model.Post,
replace_content: bool) -> None:
assert source_post
assert target_post
if source_post.post_id == target_post.post_id:
raise InvalidPostRelationError('Cannot merge post with itself.')
def merge_tables(table, anti_dup_func, source_post_id, target_post_id):
def merge_tables(
table: model.Base,
anti_dup_func: Optional[Callable[[model.Base, model.Base], bool]],
source_post_id: int,
target_post_id: int) -> None:
alias1 = table
alias2 = sqlalchemy.orm.util.aliased(table)
alias2 = sa.orm.util.aliased(table)
update_stmt = (
sqlalchemy.sql.expression.update(alias1)
sa.sql.expression.update(alias1)
.where(alias1.post_id == source_post_id))
if anti_dup_func is not None:
update_stmt = (
update_stmt
.where(
~sqlalchemy.exists()
~sa.exists()
.where(anti_dup_func(alias1, alias2))
.where(alias2.post_id == target_post_id)))
update_stmt = update_stmt.values(post_id=target_post_id)
db.session.execute(update_stmt)
def merge_tags(source_post_id, target_post_id):
def merge_tags(source_post_id: int, target_post_id: int) -> None:
merge_tables(
db.PostTag,
model.PostTag,
lambda alias1, alias2: alias1.tag_id == alias2.tag_id,
source_post_id,
target_post_id)
def merge_scores(source_post_id, target_post_id):
def merge_scores(source_post_id: int, target_post_id: int) -> None:
merge_tables(
db.PostScore,
model.PostScore,
lambda alias1, alias2: alias1.user_id == alias2.user_id,
source_post_id,
target_post_id)
def merge_favorites(source_post_id, target_post_id):
def merge_favorites(source_post_id: int, target_post_id: int) -> None:
merge_tables(
db.PostFavorite,
model.PostFavorite,
lambda alias1, alias2: alias1.user_id == alias2.user_id,
source_post_id,
target_post_id)
def merge_comments(source_post_id, target_post_id):
merge_tables(db.Comment, None, source_post_id, target_post_id)
def merge_comments(source_post_id: int, target_post_id: int) -> None:
merge_tables(model.Comment, None, source_post_id, target_post_id)
def merge_relations(source_post_id, target_post_id):
alias1 = db.PostRelation
alias2 = sqlalchemy.orm.util.aliased(db.PostRelation)
def merge_relations(source_post_id: int, target_post_id: int) -> None:
alias1 = model.PostRelation
alias2 = sa.orm.util.aliased(model.PostRelation)
update_stmt = (
sqlalchemy.sql.expression.update(alias1)
sa.sql.expression.update(alias1)
.where(alias1.parent_id == source_post_id)
.where(alias1.child_id != target_post_id)
.where(
~sqlalchemy.exists()
~sa.exists()
.where(alias2.child_id == alias1.child_id)
.where(alias2.parent_id == target_post_id))
.values(parent_id=target_post_id))
db.session.execute(update_stmt)
update_stmt = (
sqlalchemy.sql.expression.update(alias1)
sa.sql.expression.update(alias1)
.where(alias1.child_id == source_post_id)
.where(alias1.parent_id != target_post_id)
.where(
~sqlalchemy.exists()
~sa.exists()
.where(alias2.parent_id == alias1.parent_id)
.where(alias2.child_id == target_post_id))
.values(child_id=target_post_id))
@ -553,15 +671,15 @@ def merge_posts(source_post, target_post, replace_content):
update_post_content(target_post, content)
def search_by_image_exact(image_content):
def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
checksum = util.get_sha1(image_content)
return db.session \
.query(db.Post) \
.filter(db.Post.checksum == checksum) \
.query(model.Post) \
.filter(model.Post.checksum == checksum) \
.one_or_none()
def search_by_image(image_content):
def search_by_image(image_content: bytes) -> List[PostLookalike]:
ret = []
for result in image_hash.search_by_image(image_content):
ret.append(PostLookalike(
@ -571,24 +689,24 @@ def search_by_image(image_content):
return ret
def populate_reverse_search():
def populate_reverse_search() -> None:
excluded_post_ids = image_hash.get_all_paths()
post_ids_to_hash = (
db.session
.query(db.Post.post_id)
.query(model.Post.post_id)
.filter(
(db.Post.type == db.Post.TYPE_IMAGE) |
(db.Post.type == db.Post.TYPE_ANIMATION))
.filter(~db.Post.post_id.in_(excluded_post_ids))
.order_by(db.Post.post_id.asc())
(model.Post.type == model.Post.TYPE_IMAGE) |
(model.Post.type == model.Post.TYPE_ANIMATION))
.filter(~model.Post.post_id.in_(excluded_post_ids))
.order_by(model.Post.post_id.asc())
.all())
for post_ids_chunk in util.chunks(post_ids_to_hash, 100):
posts_chunk = (
db.session
.query(db.Post)
.filter(db.Post.post_id.in_(post_ids_chunk))
.query(model.Post)
.filter(model.Post.post_id.in_(post_ids_chunk))
.all())
for post in posts_chunk:
content_path = get_post_content_path(post)

View file

@ -1,5 +1,6 @@
import datetime
from szurubooru import db, errors
from typing import Any, Tuple, Callable
from szurubooru import db, model, errors
class InvalidScoreTargetError(errors.ValidationError):
@ -10,22 +11,23 @@ class InvalidScoreValueError(errors.ValidationError):
pass
def _get_table_info(entity):
def _get_table_info(
entity: model.Base) -> Tuple[model.Base, Callable[[model.Base], Any]]:
assert entity
resource_type, _, _ = db.util.get_resource_info(entity)
resource_type, _, _ = model.util.get_resource_info(entity)
if resource_type == 'post':
return db.PostScore, lambda table: table.post_id
return model.PostScore, lambda table: table.post_id
elif resource_type == 'comment':
return db.CommentScore, lambda table: table.comment_id
return model.CommentScore, lambda table: table.comment_id
raise InvalidScoreTargetError()
def _get_score_entity(entity, user):
def _get_score_entity(entity: model.Base, user: model.User) -> model.Base:
assert user
return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
return model.util.get_aux_entity(db.session, _get_table_info, entity, user)
def delete_score(entity, user):
def delete_score(entity: model.Base, user: model.User) -> None:
assert entity
assert user
score_entity = _get_score_entity(entity, user)
@ -33,7 +35,7 @@ def delete_score(entity, user):
db.session.delete(score_entity)
def get_score(entity, user):
def get_score(entity: model.Base, user: model.User) -> int:
assert entity
assert user
table, get_column = _get_table_info(entity)
@ -45,7 +47,7 @@ def get_score(entity, user):
return row[0] if row else 0
def set_score(entity, user, score):
def set_score(entity: model.Base, user: model.User, score: int) -> None:
from szurubooru.func import favorites
assert entity
assert user

View file

@ -0,0 +1,27 @@
from typing import Any, Optional, List, Dict, Callable
from szurubooru import db, model, rest, errors
def get_serialization_options(ctx: rest.Context) -> List[str]:
return ctx.get_param_as_list('fields', default=[])
class BaseSerializer:
_fields = {} # type: Dict[str, Callable[[model.Base], Any]]
def serialize(self, options: List[str]) -> Any:
field_factories = self._serializers()
if not options:
options = list(field_factories.keys())
ret = {}
for key in options:
if key not in field_factories:
raise errors.ValidationError(
'Invalid key: %r. Valid keys: %r.' % (
key, list(sorted(field_factories.keys()))))
factory = field_factories[key]
ret[key] = factory()
return ret
def _serializers(self) -> Dict[str, Callable[[], Any]]:
raise NotImplementedError()

View file

@ -1,9 +1,10 @@
from typing import Any, Optional, Dict, Callable
from datetime import datetime
from szurubooru import db
from szurubooru import db, model
from szurubooru.func import diff, users
def get_tag_category_snapshot(category):
def get_tag_category_snapshot(category: model.TagCategory) -> Dict[str, Any]:
assert category
return {
'name': category.name,
@ -12,7 +13,7 @@ def get_tag_category_snapshot(category):
}
def get_tag_snapshot(tag):
def get_tag_snapshot(tag: model.Tag) -> Dict[str, Any]:
assert tag
return {
'names': [tag_name.name for tag_name in tag.names],
@ -22,7 +23,7 @@ def get_tag_snapshot(tag):
}
def get_post_snapshot(post):
def get_post_snapshot(post: model.Post) -> Dict[str, Any]:
assert post
return {
'source': post.source,
@ -45,10 +46,11 @@ _snapshot_factories = {
'tag_category': lambda entity: get_tag_category_snapshot(entity),
'tag': lambda entity: get_tag_snapshot(entity),
'post': lambda entity: get_post_snapshot(entity),
}
} # type: Dict[model.Base, Callable[[model.Base], Dict[str ,Any]]]
def serialize_snapshot(snapshot, auth_user):
def serialize_snapshot(
snapshot: model.Snapshot, auth_user: model.User) -> Dict[str, Any]:
assert snapshot
return {
'operation': snapshot.operation,
@ -60,11 +62,14 @@ def serialize_snapshot(snapshot, auth_user):
}
def _create(operation, entity, auth_user):
def _create(
operation: str,
entity: model.Base,
auth_user: Optional[model.User]) -> model.Snapshot:
resource_type, resource_pkey, resource_name = (
db.util.get_resource_info(entity))
model.util.get_resource_info(entity))
snapshot = db.Snapshot()
snapshot = model.Snapshot()
snapshot.creation_time = datetime.utcnow()
snapshot.operation = operation
snapshot.resource_type = resource_type
@ -74,33 +79,33 @@ def _create(operation, entity, auth_user):
return snapshot
def create(entity, auth_user):
def create(entity: model.Base, auth_user: Optional[model.User]) -> None:
assert entity
snapshot = _create(db.Snapshot.OPERATION_CREATED, entity, auth_user)
snapshot = _create(model.Snapshot.OPERATION_CREATED, entity, auth_user)
snapshot_factory = _snapshot_factories[snapshot.resource_type]
snapshot.data = snapshot_factory(entity)
db.session.add(snapshot)
# pylint: disable=protected-access
def modify(entity, auth_user):
def modify(entity: model.Base, auth_user: Optional[model.User]) -> None:
assert entity
model = next(
table = next(
(
model
for model in db.Base._decl_class_registry.values()
if hasattr(model, '__table__')
and model.__table__.fullname == entity.__table__.fullname
cls
for cls in model.Base._decl_class_registry.values()
if hasattr(cls, '__table__')
and cls.__table__.fullname == entity.__table__.fullname
),
None)
assert model
assert table
snapshot = _create(db.Snapshot.OPERATION_MODIFIED, entity, auth_user)
snapshot = _create(model.Snapshot.OPERATION_MODIFIED, entity, auth_user)
snapshot_factory = _snapshot_factories[snapshot.resource_type]
detached_session = db.sessionmaker()
detached_entity = detached_session.query(model).get(snapshot.resource_pkey)
detached_entity = detached_session.query(table).get(snapshot.resource_pkey)
assert detached_entity, 'Entity not found in DB, have you committed it?'
detached_snapshot = snapshot_factory(detached_entity)
detached_session.close()
@ -113,19 +118,23 @@ def modify(entity, auth_user):
db.session.add(snapshot)
def delete(entity, auth_user):
def delete(entity: model.Base, auth_user: Optional[model.User]) -> None:
assert entity
snapshot = _create(db.Snapshot.OPERATION_DELETED, entity, auth_user)
snapshot = _create(model.Snapshot.OPERATION_DELETED, entity, auth_user)
snapshot_factory = _snapshot_factories[snapshot.resource_type]
snapshot.data = snapshot_factory(entity)
db.session.add(snapshot)
def merge(source_entity, target_entity, auth_user):
def merge(
source_entity: model.Base,
target_entity: model.Base,
auth_user: Optional[model.User]) -> None:
assert source_entity
assert target_entity
snapshot = _create(db.Snapshot.OPERATION_MERGED, source_entity, auth_user)
snapshot = _create(
model.Snapshot.OPERATION_MERGED, source_entity, auth_user)
resource_type, _resource_pkey, resource_name = (
db.util.get_resource_info(target_entity))
model.util.get_resource_info(target_entity))
snapshot.data = [resource_type, resource_name]
db.session.add(snapshot)

View file

@ -1,7 +1,8 @@
import re
import sqlalchemy
from szurubooru import config, db, errors
from szurubooru.func import util, cache
from typing import Any, Optional, Dict, List, Callable
import sqlalchemy as sa
from szurubooru import config, db, model, errors, rest
from szurubooru.func import util, serialization, cache
DEFAULT_CATEGORY_NAME_CACHE_KEY = 'default-tag-category'
@ -27,28 +28,52 @@ class InvalidTagCategoryColorError(errors.ValidationError):
pass
def _verify_name_validity(name):
def _verify_name_validity(name: str) -> None:
name_regex = config.config['tag_category_name_regex']
if not re.match(name_regex, name):
raise InvalidTagCategoryNameError(
'Name must satisfy regex %r.' % name_regex)
def serialize_category(category, options=None):
return util.serialize_entity(
category,
{
'name': lambda: category.name,
'version': lambda: category.version,
'color': lambda: category.color,
'usages': lambda: category.tag_count,
'default': lambda: category.default,
},
options)
class TagCategorySerializer(serialization.BaseSerializer):
def __init__(self, category: model.TagCategory) -> None:
self.category = category
def _serializers(self) -> Dict[str, Callable[[], Any]]:
return {
'name': self.serialize_name,
'version': self.serialize_version,
'color': self.serialize_color,
'usages': self.serialize_usages,
'default': self.serialize_default,
}
def serialize_name(self) -> Any:
return self.category.name
def serialize_version(self) -> Any:
return self.category.version
def serialize_color(self) -> Any:
return self.category.color
def serialize_usages(self) -> Any:
return self.category.tag_count
def serialize_default(self) -> Any:
return self.category.default
def create_category(name, color):
category = db.TagCategory()
def serialize_category(
category: Optional[model.TagCategory],
options: List[str]=[]) -> Optional[rest.Response]:
if not category:
return None
return TagCategorySerializer(category).serialize(options)
def create_category(name: str, color: str) -> model.TagCategory:
category = model.TagCategory()
update_category_name(category, name)
update_category_color(category, color)
if not get_all_categories():
@ -56,64 +81,66 @@ def create_category(name, color):
return category
def update_category_name(category, name):
def update_category_name(category: model.TagCategory, name: str) -> None:
assert category
if not name:
raise InvalidTagCategoryNameError('Name cannot be empty.')
expr = sqlalchemy.func.lower(db.TagCategory.name) == name.lower()
expr = sa.func.lower(model.TagCategory.name) == name.lower()
if category.tag_category_id:
expr = expr & (
db.TagCategory.tag_category_id != category.tag_category_id)
already_exists = db.session.query(db.TagCategory).filter(expr).count() > 0
model.TagCategory.tag_category_id != category.tag_category_id)
already_exists = (
db.session.query(model.TagCategory).filter(expr).count() > 0)
if already_exists:
raise TagCategoryAlreadyExistsError(
'A category with this name already exists.')
if util.value_exceeds_column_size(name, db.TagCategory.name):
if util.value_exceeds_column_size(name, model.TagCategory.name):
raise InvalidTagCategoryNameError('Name is too long.')
_verify_name_validity(name)
category.name = name
cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY)
def update_category_color(category, color):
def update_category_color(category: model.TagCategory, color: str) -> None:
assert category
if not color:
raise InvalidTagCategoryColorError('Color cannot be empty.')
if not re.match(r'^#?[0-9a-z]+$', color):
raise InvalidTagCategoryColorError('Invalid color.')
if util.value_exceeds_column_size(color, db.TagCategory.color):
if util.value_exceeds_column_size(color, model.TagCategory.color):
raise InvalidTagCategoryColorError('Color is too long.')
category.color = color
def try_get_category_by_name(name, lock=False):
def try_get_category_by_name(
name: str, lock: bool=False) -> Optional[model.TagCategory]:
query = db.session \
.query(db.TagCategory) \
.filter(sqlalchemy.func.lower(db.TagCategory.name) == name.lower())
.query(model.TagCategory) \
.filter(sa.func.lower(model.TagCategory.name) == name.lower())
if lock:
query = query.with_lockmode('update')
return query.one_or_none()
def get_category_by_name(name, lock=False):
def get_category_by_name(name: str, lock: bool=False) -> model.TagCategory:
category = try_get_category_by_name(name, lock)
if not category:
raise TagCategoryNotFoundError('Tag category %r not found.' % name)
return category
def get_all_category_names():
return [row[0] for row in db.session.query(db.TagCategory.name).all()]
def get_all_category_names() -> List[str]:
return [row[0] for row in db.session.query(model.TagCategory.name).all()]
def get_all_categories():
return db.session.query(db.TagCategory).all()
def get_all_categories() -> List[model.TagCategory]:
return db.session.query(model.TagCategory).all()
def try_get_default_category(lock=False):
def try_get_default_category(lock: bool=False) -> Optional[model.TagCategory]:
query = db.session \
.query(db.TagCategory) \
.filter(db.TagCategory.default)
.query(model.TagCategory) \
.filter(model.TagCategory.default)
if lock:
query = query.with_lockmode('update')
category = query.first()
@ -121,22 +148,22 @@ def try_get_default_category(lock=False):
# category, get the first record available.
if not category:
query = db.session \
.query(db.TagCategory) \
.order_by(db.TagCategory.tag_category_id.asc())
.query(model.TagCategory) \
.order_by(model.TagCategory.tag_category_id.asc())
if lock:
query = query.with_lockmode('update')
category = query.first()
return category
def get_default_category(lock=False):
def get_default_category(lock: bool=False) -> model.TagCategory:
category = try_get_default_category(lock)
if not category:
raise TagCategoryNotFoundError('No tag category created yet.')
return category
def get_default_category_name():
def get_default_category_name() -> str:
if cache.has(DEFAULT_CATEGORY_NAME_CACHE_KEY):
return cache.get(DEFAULT_CATEGORY_NAME_CACHE_KEY)
default_category = get_default_category()
@ -145,7 +172,7 @@ def get_default_category_name():
return default_category_name
def set_default_category(category):
def set_default_category(category: model.TagCategory) -> None:
assert category
old_category = try_get_default_category(lock=True)
if old_category:
@ -156,7 +183,7 @@ def set_default_category(category):
cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY)
def delete_category(category):
def delete_category(category: model.TagCategory) -> None:
assert category
if len(get_all_category_names()) == 1:
raise TagCategoryIsInUseError('Cannot delete the last category.')

View file

@ -1,10 +1,11 @@
import datetime
import json
import os
import re
import sqlalchemy
from szurubooru import config, db, errors
from szurubooru.func import util, tag_categories
from typing import Any, Optional, Tuple, List, Dict, Callable
from datetime import datetime
import sqlalchemy as sa
from szurubooru import config, db, model, errors, rest
from szurubooru.func import util, tag_categories, serialization
class TagNotFoundError(errors.NotFoundError):
@ -35,31 +36,32 @@ class InvalidTagDescriptionError(errors.ValidationError):
pass
def _verify_name_validity(name):
if util.value_exceeds_column_size(name, db.TagName.name):
def _verify_name_validity(name: str) -> None:
if util.value_exceeds_column_size(name, model.TagName.name):
raise InvalidTagNameError('Name is too long.')
name_regex = config.config['tag_name_regex']
if not re.match(name_regex, name):
raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex)
def _get_names(tag):
def _get_names(tag: model.Tag) -> List[str]:
assert tag
return [tag_name.name for tag_name in tag.names]
def _lower_list(names):
def _lower_list(names: List[str]) -> List[str]:
return [name.lower() for name in names]
def _check_name_intersection(names1, names2, case_sensitive):
def _check_name_intersection(
names1: List[str], names2: List[str], case_sensitive: bool) -> bool:
if not case_sensitive:
names1 = _lower_list(names1)
names2 = _lower_list(names2)
return len(set(names1).intersection(names2)) > 0
def sort_tags(tags):
def sort_tags(tags: List[model.Tag]) -> List[model.Tag]:
default_category_name = tag_categories.get_default_category_name()
return sorted(
tags,
@ -70,35 +72,70 @@ def sort_tags(tags):
)
def serialize_tag(tag, options=None):
return util.serialize_entity(
tag,
{
'names': lambda: [tag_name.name for tag_name in tag.names],
'category': lambda: tag.category.name,
'version': lambda: tag.version,
'description': lambda: tag.description,
'creationTime': lambda: tag.creation_time,
'lastEditTime': lambda: tag.last_edit_time,
'usages': lambda: tag.post_count,
'suggestions': lambda: [
class TagSerializer(serialization.BaseSerializer):
def __init__(self, tag: model.Tag) -> None:
self.tag = tag
def _serializers(self) -> Dict[str, Callable[[], Any]]:
return {
'names': self.serialize_names,
'category': self.serialize_category,
'version': self.serialize_version,
'description': self.serialize_description,
'creationTime': self.serialize_creation_time,
'lastEditTime': self.serialize_last_edit_time,
'usages': self.serialize_usages,
'suggestions': self.serialize_suggestions,
'implications': self.serialize_implications,
}
def serialize_names(self) -> Any:
return [tag_name.name for tag_name in self.tag.names]
def serialize_category(self) -> Any:
return self.tag.category.name
def serialize_version(self) -> Any:
return self.tag.version
def serialize_description(self) -> Any:
return self.tag.description
def serialize_creation_time(self) -> Any:
return self.tag.creation_time
def serialize_last_edit_time(self) -> Any:
return self.tag.last_edit_time
def serialize_usages(self) -> Any:
return self.tag.post_count
def serialize_suggestions(self) -> Any:
return [
relation.names[0].name
for relation in sort_tags(tag.suggestions)],
'implications': lambda: [
for relation in sort_tags(self.tag.suggestions)]
def serialize_implications(self) -> Any:
return [
relation.names[0].name
for relation in sort_tags(tag.implications)],
},
options)
for relation in sort_tags(self.tag.implications)]
def export_to_json():
tags = {}
categories = {}
def serialize_tag(
tag: model.Tag, options: List[str]=[]) -> Optional[rest.Response]:
if not tag:
return None
return TagSerializer(tag).serialize(options)
def export_to_json() -> None:
tags = {} # type: Dict[int, Any]
categories = {} # type: Dict[int, Any]
for result in db.session.query(
db.TagCategory.tag_category_id,
db.TagCategory.name,
db.TagCategory.color).all():
model.TagCategory.tag_category_id,
model.TagCategory.name,
model.TagCategory.color).all():
categories[result[0]] = {
'name': result[1],
'color': result[2],
@ -106,8 +143,8 @@ def export_to_json():
for result in (
db.session
.query(db.TagName.tag_id, db.TagName.name)
.order_by(db.TagName.order)
.query(model.TagName.tag_id, model.TagName.name)
.order_by(model.TagName.order)
.all()):
if not result[0] in tags:
tags[result[0]] = {'names': []}
@ -115,8 +152,10 @@ def export_to_json():
for result in (
db.session
.query(db.TagSuggestion.parent_id, db.TagName.name)
.join(db.TagName, db.TagName.tag_id == db.TagSuggestion.child_id)
.query(model.TagSuggestion.parent_id, model.TagName.name)
.join(
model.TagName,
model.TagName.tag_id == model.TagSuggestion.child_id)
.all()):
if 'suggestions' not in tags[result[0]]:
tags[result[0]]['suggestions'] = []
@ -124,17 +163,19 @@ def export_to_json():
for result in (
db.session
.query(db.TagImplication.parent_id, db.TagName.name)
.join(db.TagName, db.TagName.tag_id == db.TagImplication.child_id)
.query(model.TagImplication.parent_id, model.TagName.name)
.join(
model.TagName,
model.TagName.tag_id == model.TagImplication.child_id)
.all()):
if 'implications' not in tags[result[0]]:
tags[result[0]]['implications'] = []
tags[result[0]]['implications'].append(result[1])
for result in db.session.query(
db.Tag.tag_id,
db.Tag.category_id,
db.Tag.post_count).all():
model.Tag.tag_id,
model.Tag.category_id,
model.Tag.post_count).all():
tags[result[0]]['category'] = categories[result[1]]['name']
tags[result[0]]['usages'] = result[2]
@ -148,33 +189,34 @@ def export_to_json():
handle.write(json.dumps(output, separators=(',', ':')))
def try_get_tag_by_name(name):
def try_get_tag_by_name(name: str) -> Optional[model.Tag]:
return (
db.session
.query(db.Tag)
.join(db.TagName)
.filter(sqlalchemy.func.lower(db.TagName.name) == name.lower())
.query(model.Tag)
.join(model.TagName)
.filter(sa.func.lower(model.TagName.name) == name.lower())
.one_or_none())
def get_tag_by_name(name):
def get_tag_by_name(name: str) -> model.Tag:
tag = try_get_tag_by_name(name)
if not tag:
raise TagNotFoundError('Tag %r not found.' % name)
return tag
def get_tags_by_names(names):
def get_tags_by_names(names: List[str]) -> List[model.Tag]:
names = util.icase_unique(names)
if len(names) == 0:
return []
expr = sqlalchemy.sql.false()
expr = sa.sql.false()
for name in names:
expr = expr | (sqlalchemy.func.lower(db.TagName.name) == name.lower())
return db.session.query(db.Tag).join(db.TagName).filter(expr).all()
expr = expr | (sa.func.lower(model.TagName.name) == name.lower())
return db.session.query(model.Tag).join(model.TagName).filter(expr).all()
def get_or_create_tags_by_names(names):
def get_or_create_tags_by_names(
names: List[str]) -> Tuple[List[model.Tag], List[model.Tag]]:
names = util.icase_unique(names)
existing_tags = get_tags_by_names(names)
new_tags = []
@ -197,86 +239,87 @@ def get_or_create_tags_by_names(names):
return existing_tags, new_tags
def get_tag_siblings(tag):
def get_tag_siblings(tag: model.Tag) -> List[model.Tag]:
assert tag
tag_alias = sqlalchemy.orm.aliased(db.Tag)
pt_alias1 = sqlalchemy.orm.aliased(db.PostTag)
pt_alias2 = sqlalchemy.orm.aliased(db.PostTag)
tag_alias = sa.orm.aliased(model.Tag)
pt_alias1 = sa.orm.aliased(model.PostTag)
pt_alias2 = sa.orm.aliased(model.PostTag)
result = (
db.session
.query(tag_alias, sqlalchemy.func.count(pt_alias2.post_id))
.query(tag_alias, sa.func.count(pt_alias2.post_id))
.join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id)
.join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id)
.filter(pt_alias2.tag_id == tag.tag_id)
.filter(pt_alias1.tag_id != tag.tag_id)
.group_by(tag_alias.tag_id)
.order_by(sqlalchemy.func.count(pt_alias2.post_id).desc())
.order_by(sa.func.count(pt_alias2.post_id).desc())
.limit(50))
return result
def delete(source_tag):
def delete(source_tag: model.Tag) -> None:
assert source_tag
db.session.execute(
sqlalchemy.sql.expression.delete(db.TagSuggestion)
.where(db.TagSuggestion.child_id == source_tag.tag_id))
sa.sql.expression.delete(model.TagSuggestion)
.where(model.TagSuggestion.child_id == source_tag.tag_id))
db.session.execute(
sqlalchemy.sql.expression.delete(db.TagImplication)
.where(db.TagImplication.child_id == source_tag.tag_id))
sa.sql.expression.delete(model.TagImplication)
.where(model.TagImplication.child_id == source_tag.tag_id))
db.session.delete(source_tag)
def merge_tags(source_tag, target_tag):
def merge_tags(source_tag: model.Tag, target_tag: model.Tag) -> None:
assert source_tag
assert target_tag
if source_tag.tag_id == target_tag.tag_id:
raise InvalidTagRelationError('Cannot merge tag with itself.')
def merge_posts(source_tag_id, target_tag_id):
alias1 = db.PostTag
alias2 = sqlalchemy.orm.util.aliased(db.PostTag)
def merge_posts(source_tag_id: int, target_tag_id: int) -> None:
alias1 = model.PostTag
alias2 = sa.orm.util.aliased(model.PostTag)
update_stmt = (
sqlalchemy.sql.expression.update(alias1)
sa.sql.expression.update(alias1)
.where(alias1.tag_id == source_tag_id))
update_stmt = (
update_stmt
.where(
~sqlalchemy.exists()
~sa.exists()
.where(alias1.post_id == alias2.post_id)
.where(alias2.tag_id == target_tag_id)))
update_stmt = update_stmt.values(tag_id=target_tag_id)
db.session.execute(update_stmt)
def merge_relations(table, source_tag_id, target_tag_id):
def merge_relations(
table: model.Base, source_tag_id: int, target_tag_id: int) -> None:
alias1 = table
alias2 = sqlalchemy.orm.util.aliased(table)
alias2 = sa.orm.util.aliased(table)
update_stmt = (
sqlalchemy.sql.expression.update(alias1)
sa.sql.expression.update(alias1)
.where(alias1.parent_id == source_tag_id)
.where(alias1.child_id != target_tag_id)
.where(
~sqlalchemy.exists()
~sa.exists()
.where(alias2.child_id == alias1.child_id)
.where(alias2.parent_id == target_tag_id))
.values(parent_id=target_tag_id))
db.session.execute(update_stmt)
update_stmt = (
sqlalchemy.sql.expression.update(alias1)
sa.sql.expression.update(alias1)
.where(alias1.child_id == source_tag_id)
.where(alias1.parent_id != target_tag_id)
.where(
~sqlalchemy.exists()
~sa.exists()
.where(alias2.parent_id == alias1.parent_id)
.where(alias2.child_id == target_tag_id))
.values(child_id=target_tag_id))
db.session.execute(update_stmt)
def merge_suggestions(source_tag_id, target_tag_id):
merge_relations(db.TagSuggestion, source_tag_id, target_tag_id)
def merge_suggestions(source_tag_id: int, target_tag_id: int) -> None:
merge_relations(model.TagSuggestion, source_tag_id, target_tag_id)
def merge_implications(source_tag_id, target_tag_id):
merge_relations(db.TagImplication, source_tag_id, target_tag_id)
def merge_implications(source_tag_id: int, target_tag_id: int) -> None:
merge_relations(model.TagImplication, source_tag_id, target_tag_id)
merge_posts(source_tag.tag_id, target_tag.tag_id)
merge_suggestions(source_tag.tag_id, target_tag.tag_id)
@ -284,9 +327,13 @@ def merge_tags(source_tag, target_tag):
delete(source_tag)
def create_tag(names, category_name, suggestions, implications):
tag = db.Tag()
tag.creation_time = datetime.datetime.utcnow()
def create_tag(
names: List[str],
category_name: str,
suggestions: List[str],
implications: List[str]) -> model.Tag:
tag = model.Tag()
tag.creation_time = datetime.utcnow()
update_tag_names(tag, names)
update_tag_category_name(tag, category_name)
update_tag_suggestions(tag, suggestions)
@ -294,12 +341,12 @@ def create_tag(names, category_name, suggestions, implications):
return tag
def update_tag_category_name(tag, category_name):
def update_tag_category_name(tag: model.Tag, category_name: str) -> None:
assert tag
tag.category = tag_categories.get_category_by_name(category_name)
def update_tag_names(tag, names):
def update_tag_names(tag: model.Tag, names: List[str]) -> None:
# sanitize
assert tag
names = util.icase_unique([name for name in names if name])
@ -309,12 +356,12 @@ def update_tag_names(tag, names):
_verify_name_validity(name)
# check for existing tags
expr = sqlalchemy.sql.false()
expr = sa.sql.false()
for name in names:
expr = expr | (sqlalchemy.func.lower(db.TagName.name) == name.lower())
expr = expr | (sa.func.lower(model.TagName.name) == name.lower())
if tag.tag_id:
expr = expr & (db.TagName.tag_id != tag.tag_id)
existing_tags = db.session.query(db.TagName).filter(expr).all()
expr = expr & (model.TagName.tag_id != tag.tag_id)
existing_tags = db.session.query(model.TagName).filter(expr).all()
if len(existing_tags):
raise TagAlreadyExistsError(
'One of names is already used by another tag.')
@ -326,7 +373,7 @@ def update_tag_names(tag, names):
# add wanted items
for name in names:
if not _check_name_intersection(_get_names(tag), [name], True):
tag.names.append(db.TagName(name, None))
tag.names.append(model.TagName(name, -1))
# set alias order to match the request
for i, name in enumerate(names):
@ -336,7 +383,7 @@ def update_tag_names(tag, names):
# TODO: what to do with relations that do not yet exist?
def update_tag_implications(tag, relations):
def update_tag_implications(tag: model.Tag, relations: List[str]) -> None:
assert tag
if _check_name_intersection(_get_names(tag), relations, False):
raise InvalidTagRelationError('Tag cannot imply itself.')
@ -344,15 +391,15 @@ def update_tag_implications(tag, relations):
# TODO: what to do with relations that do not yet exist?
def update_tag_suggestions(tag, relations):
def update_tag_suggestions(tag: model.Tag, relations: List[str]) -> None:
assert tag
if _check_name_intersection(_get_names(tag), relations, False):
raise InvalidTagRelationError('Tag cannot suggest itself.')
tag.suggestions = get_tags_by_names(relations)
def update_tag_description(tag, description):
def update_tag_description(tag: model.Tag, description: str) -> None:
assert tag
if util.value_exceeds_column_size(description, db.Tag.description):
if util.value_exceeds_column_size(description, model.Tag.description):
raise InvalidTagDescriptionError('Description is too long.')
tag.description = description
tag.description = description or None

View file

@ -1,8 +1,9 @@
import datetime
import re
from sqlalchemy import func
from szurubooru import config, db, errors
from szurubooru.func import auth, util, files, images
from typing import Any, Optional, Union, List, Dict, Callable
from datetime import datetime
import sqlalchemy as sa
from szurubooru import config, db, model, errors, rest
from szurubooru.func import auth, util, serialization, files, images
class UserNotFoundError(errors.NotFoundError):
@ -33,11 +34,11 @@ class InvalidAvatarError(errors.ValidationError):
pass
def get_avatar_path(user_name):
def get_avatar_path(user_name: str) -> str:
return 'avatars/' + user_name.lower() + '.png'
def get_avatar_url(user):
def get_avatar_url(user: model.User) -> str:
assert user
if user.avatar_style == user.AVATAR_GRAVATAR:
assert user.email or user.name
@ -49,7 +50,10 @@ def get_avatar_url(user):
config.config['data_url'].rstrip('/'), user.name.lower())
def get_email(user, auth_user, force_show_email):
def get_email(
user: model.User,
auth_user: model.User,
force_show_email: bool) -> Union[bool, str]:
assert user
assert auth_user
if not force_show_email \
@ -59,7 +63,8 @@ def get_email(user, auth_user, force_show_email):
return user.email
def get_liked_post_count(user, auth_user):
def get_liked_post_count(
user: model.User, auth_user: model.User) -> Union[bool, int]:
assert user
assert auth_user
if auth_user.user_id != user.user_id:
@ -67,7 +72,8 @@ def get_liked_post_count(user, auth_user):
return user.liked_post_count
def get_disliked_post_count(user, auth_user):
def get_disliked_post_count(
user: model.User, auth_user: model.User) -> Union[bool, int]:
assert user
assert auth_user
if auth_user.user_id != user.user_id:
@ -75,91 +81,144 @@ def get_disliked_post_count(user, auth_user):
return user.disliked_post_count
def serialize_user(user, auth_user, options=None, force_show_email=False):
return util.serialize_entity(
user,
{
'name': lambda: user.name,
'creationTime': lambda: user.creation_time,
'lastLoginTime': lambda: user.last_login_time,
'version': lambda: user.version,
'rank': lambda: user.rank,
'avatarStyle': lambda: user.avatar_style,
'avatarUrl': lambda: get_avatar_url(user),
'commentCount': lambda: user.comment_count,
'uploadedPostCount': lambda: user.post_count,
'favoritePostCount': lambda: user.favorite_post_count,
'likedPostCount':
lambda: get_liked_post_count(user, auth_user),
'dislikedPostCount':
lambda: get_disliked_post_count(user, auth_user),
'email':
lambda: get_email(user, auth_user, force_show_email),
},
options)
class UserSerializer(serialization.BaseSerializer):
def __init__(
self,
user: model.User,
auth_user: model.User,
force_show_email: bool=False) -> None:
self.user = user
self.auth_user = auth_user
self.force_show_email = force_show_email
def _serializers(self) -> Dict[str, Callable[[], Any]]:
return {
'name': self.serialize_name,
'creationTime': self.serialize_creation_time,
'lastLoginTime': self.serialize_last_login_time,
'version': self.serialize_version,
'rank': self.serialize_rank,
'avatarStyle': self.serialize_avatar_style,
'avatarUrl': self.serialize_avatar_url,
'commentCount': self.serialize_comment_count,
'uploadedPostCount': self.serialize_uploaded_post_count,
'favoritePostCount': self.serialize_favorite_post_count,
'likedPostCount': self.serialize_liked_post_count,
'dislikedPostCount': self.serialize_disliked_post_count,
'email': self.serialize_email,
}
def serialize_name(self) -> Any:
return self.user.name
def serialize_creation_time(self) -> Any:
return self.user.creation_time
def serialize_last_login_time(self) -> Any:
return self.user.last_login_time
def serialize_version(self) -> Any:
return self.user.version
def serialize_rank(self) -> Any:
return self.user.rank
def serialize_avatar_style(self) -> Any:
return self.user.avatar_style
def serialize_avatar_url(self) -> Any:
return get_avatar_url(self.user)
def serialize_comment_count(self) -> Any:
return self.user.comment_count
def serialize_uploaded_post_count(self) -> Any:
return self.user.post_count
def serialize_favorite_post_count(self) -> Any:
return self.user.favorite_post_count
def serialize_liked_post_count(self) -> Any:
return get_liked_post_count(self.user, self.auth_user)
def serialize_disliked_post_count(self) -> Any:
return get_disliked_post_count(self.user, self.auth_user)
def serialize_email(self) -> Any:
return get_email(self.user, self.auth_user, self.force_show_email)
def serialize_micro_user(user, auth_user):
def serialize_user(
user: Optional[model.User],
auth_user: model.User,
options: List[str]=[],
force_show_email: bool=False) -> Optional[rest.Response]:
if not user:
return None
return UserSerializer(user, auth_user, force_show_email).serialize(options)
def serialize_micro_user(
user: Optional[model.User],
auth_user: model.User) -> Optional[rest.Response]:
return serialize_user(
user,
auth_user=auth_user,
options=['name', 'avatarUrl'])
user, auth_user=auth_user, options=['name', 'avatarUrl'])
def get_user_count():
return db.session.query(db.User).count()
def get_user_count() -> int:
return db.session.query(model.User).count()
def try_get_user_by_name(name):
def try_get_user_by_name(name: str) -> Optional[model.User]:
return db.session \
.query(db.User) \
.filter(func.lower(db.User.name) == func.lower(name)) \
.query(model.User) \
.filter(sa.func.lower(model.User.name) == sa.func.lower(name)) \
.one_or_none()
def get_user_by_name(name):
def get_user_by_name(name: str) -> model.User:
user = try_get_user_by_name(name)
if not user:
raise UserNotFoundError('User %r not found.' % name)
return user
def try_get_user_by_name_or_email(name_or_email):
def try_get_user_by_name_or_email(name_or_email: str) -> Optional[model.User]:
return (
db.session
.query(db.User)
.query(model.User)
.filter(
(func.lower(db.User.name) == func.lower(name_or_email)) |
(func.lower(db.User.email) == func.lower(name_or_email)))
(sa.func.lower(model.User.name) == sa.func.lower(name_or_email)) |
(sa.func.lower(model.User.email) == sa.func.lower(name_or_email)))
.one_or_none())
def get_user_by_name_or_email(name_or_email):
def get_user_by_name_or_email(name_or_email: str) -> model.User:
user = try_get_user_by_name_or_email(name_or_email)
if not user:
raise UserNotFoundError('User %r not found.' % name_or_email)
return user
def create_user(name, password, email):
user = db.User()
def create_user(name: str, password: str, email: str) -> model.User:
user = model.User()
update_user_name(user, name)
update_user_password(user, password)
update_user_email(user, email)
if get_user_count() > 0:
user.rank = util.flip(auth.RANK_MAP)[config.config['default_rank']]
else:
user.rank = db.User.RANK_ADMINISTRATOR
user.creation_time = datetime.datetime.utcnow()
user.avatar_style = db.User.AVATAR_GRAVATAR
user.rank = model.User.RANK_ADMINISTRATOR
user.creation_time = datetime.utcnow()
user.avatar_style = model.User.AVATAR_GRAVATAR
return user
def update_user_name(user, name):
def update_user_name(user: model.User, name: str) -> None:
assert user
if not name:
raise InvalidUserNameError('Name cannot be empty.')
if util.value_exceeds_column_size(name, db.User.name):
if util.value_exceeds_column_size(name, model.User.name):
raise InvalidUserNameError('User name is too long.')
name = name.strip()
name_regex = config.config['user_name_regex']
@ -174,7 +233,7 @@ def update_user_name(user, name):
user.name = name
def update_user_password(user, password):
def update_user_password(user: model.User, password: str) -> None:
assert user
if not password:
raise InvalidPasswordError('Password cannot be empty.')
@ -186,20 +245,18 @@ def update_user_password(user, password):
user.password_hash = auth.get_password_hash(user.password_salt, password)
def update_user_email(user, email):
def update_user_email(user: model.User, email: str) -> None:
assert user
if email:
email = email.strip()
if not email:
email = None
if email and util.value_exceeds_column_size(email, db.User.email):
if util.value_exceeds_column_size(email, model.User.email):
raise InvalidEmailError('Email is too long.')
if not util.is_valid_email(email):
raise InvalidEmailError('E-mail is invalid.')
user.email = email
user.email = email or None
def update_user_rank(user, rank, auth_user):
def update_user_rank(
user: model.User, rank: str, auth_user: model.User) -> None:
assert user
if not rank:
raise InvalidRankError('Rank cannot be empty.')
@ -208,7 +265,7 @@ def update_user_rank(user, rank, auth_user):
if not rank:
raise InvalidRankError(
'Rank can be either of %r.' % all_ranks)
if rank in (db.User.RANK_ANONYMOUS, db.User.RANK_NOBODY):
if rank in (model.User.RANK_ANONYMOUS, model.User.RANK_NOBODY):
raise InvalidRankError('Rank %r cannot be used.' % auth.RANK_MAP[rank])
if all_ranks.index(auth_user.rank) \
< all_ranks.index(rank) and get_user_count() > 0:
@ -216,7 +273,10 @@ def update_user_rank(user, rank, auth_user):
user.rank = rank
def update_user_avatar(user, avatar_style, avatar_content=None):
def update_user_avatar(
user: model.User,
avatar_style: str,
avatar_content: Optional[bytes]=None) -> None:
assert user
if avatar_style == 'gravatar':
user.avatar_style = user.AVATAR_GRAVATAR
@ -238,12 +298,12 @@ def update_user_avatar(user, avatar_style, avatar_content=None):
avatar_style, ['gravatar', 'manual']))
def bump_user_login_time(user):
def bump_user_login_time(user: model.User) -> None:
assert user
user.last_login_time = datetime.datetime.utcnow()
user.last_login_time = datetime.utcnow()
def reset_user_password(user):
def reset_user_password(user: model.User) -> str:
assert user
password = auth.create_password()
user.password_salt = auth.create_password()

View file

@ -2,52 +2,39 @@ import os
import hashlib
import re
import tempfile
from typing import (
Any, Optional, Union, Tuple, List, Dict, Generator, Callable, TypeVar)
from datetime import datetime, timedelta
from contextlib import contextmanager
from szurubooru import errors
def snake_case_to_lower_camel_case(text):
T = TypeVar('T')
def snake_case_to_lower_camel_case(text: str) -> str:
components = text.split('_')
return components[0].lower() + \
''.join(word[0].upper() + word[1:].lower() for word in components[1:])
def snake_case_to_upper_train_case(text):
def snake_case_to_upper_train_case(text: str) -> str:
return '-'.join(
word[0].upper() + word[1:].lower() for word in text.split('_'))
def snake_case_to_lower_camel_case_keys(source):
def snake_case_to_lower_camel_case_keys(
source: Dict[str, Any]) -> Dict[str, Any]:
target = {}
for key, value in source.items():
target[snake_case_to_lower_camel_case(key)] = value
return target
def get_serialization_options(ctx):
return ctx.get_param_as_list('fields', required=False, default=None)
def serialize_entity(entity, field_factories, options):
if not entity:
return None
if not options or len(options) == 0:
options = field_factories.keys()
ret = {}
for key in options:
if key not in field_factories:
raise errors.ValidationError('Invalid key: %r. Valid keys: %r.' % (
key, list(sorted(field_factories.keys()))))
factory = field_factories[key]
ret[key] = factory()
return ret
@contextmanager
def create_temp_file(**kwargs):
(handle, path) = tempfile.mkstemp(**kwargs)
os.close(handle)
def create_temp_file(**kwargs: Any) -> Generator:
(descriptor, path) = tempfile.mkstemp(**kwargs)
os.close(descriptor)
try:
with open(path, 'r+b') as handle:
yield handle
@ -55,17 +42,15 @@ def create_temp_file(**kwargs):
os.remove(path)
def unalias_dict(input_dict):
output_dict = {}
for key_list, value in input_dict.items():
if isinstance(key_list, str):
key_list = [key_list]
for key in key_list:
output_dict[key] = value
def unalias_dict(source: List[Tuple[List[str], T]]) -> Dict[str, T]:
output_dict = {} # type: Dict[str, T]
for aliases, value in source:
for alias in aliases:
output_dict[alias] = value
return output_dict
def get_md5(source):
def get_md5(source: Union[str, bytes]) -> str:
if not isinstance(source, bytes):
source = source.encode('utf-8')
md5 = hashlib.md5()
@ -73,7 +58,7 @@ def get_md5(source):
return md5.hexdigest()
def get_sha1(source):
def get_sha1(source: Union[str, bytes]) -> str:
if not isinstance(source, bytes):
source = source.encode('utf-8')
sha1 = hashlib.sha1()
@ -81,24 +66,25 @@ def get_sha1(source):
return sha1.hexdigest()
def flip(source):
def flip(source: Dict[Any, Any]) -> Dict[Any, Any]:
return {v: k for k, v in source.items()}
def is_valid_email(email):
def is_valid_email(email: Optional[str]) -> bool:
''' Return whether given email address is valid or empty. '''
return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email)
return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email) is not None
class dotdict(dict): # pylint: disable=invalid-name
''' dot.notation access to dictionary attributes. '''
def __getattr__(self, attr):
def __getattr__(self, attr: str) -> Any:
return self.get(attr)
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__
def parse_time_range(value):
def parse_time_range(value: str) -> Tuple[datetime, datetime]:
''' Return tuple containing min/max time for given text representation. '''
one_day = timedelta(days=1)
one_second = timedelta(seconds=1)
@ -146,9 +132,9 @@ def parse_time_range(value):
raise errors.ValidationError('Invalid date format: %r.' % value)
def icase_unique(source):
target = []
target_low = []
def icase_unique(source: List[str]) -> List[str]:
target = [] # type: List[str]
target_low = [] # type: List[str]
for source_item in source:
if source_item.lower() not in target_low:
target.append(source_item)
@ -156,7 +142,7 @@ def icase_unique(source):
return target
def value_exceeds_column_size(value, column):
def value_exceeds_column_size(value: Optional[str], column: Any) -> bool:
if not value:
return False
max_length = column.property.columns[0].type.length
@ -165,6 +151,6 @@ def value_exceeds_column_size(value, column):
return len(value) > max_length
def chunks(source_list, part_size):
def chunks(source_list: List[Any], part_size: int) -> Generator:
for i in range(0, len(source_list), part_size):
yield source_list[i:i + part_size]

View file

@ -1,8 +1,11 @@
from szurubooru import errors
from szurubooru import errors, rest, model
def verify_version(entity, context, field_name='version'):
actual_version = context.get_param_as_int(field_name, required=True)
def verify_version(
entity: model.Base,
context: rest.Context,
field_name: str='version') -> None:
actual_version = context.get_param_as_int(field_name)
expected_version = entity.version
if actual_version != expected_version:
raise errors.IntegrityError(
@ -10,5 +13,5 @@ def verify_version(entity, context, field_name='version'):
'Please try again.')
def bump_version(entity):
def bump_version(entity: model.Base) -> None:
entity.version = entity.version + 1

View file

@ -1,11 +1,11 @@
import base64
from szurubooru import db, errors
from typing import Optional
from szurubooru import db, model, errors, rest
from szurubooru.func import auth, users
from szurubooru.rest import middleware
from szurubooru.rest.errors import HttpBadRequest
def _authenticate(username, password):
def _authenticate(username: str, password: str) -> model.User:
''' Try to authenticate user. Throw AuthError for invalid users. '''
user = users.get_user_by_name(username)
if not auth.is_valid_password(user, password):
@ -13,16 +13,9 @@ def _authenticate(username, password):
return user
def _create_anonymous_user():
user = db.User()
user.name = None
user.rank = 'anonymous'
return user
def _get_user(ctx):
def _get_user(ctx: rest.Context) -> Optional[model.User]:
if not ctx.has_header('Authorization'):
return _create_anonymous_user()
return None
try:
auth_type, credentials = ctx.get_header('Authorization').split(' ', 1)
@ -41,10 +34,12 @@ def _get_user(ctx):
msg.format(ctx.get_header('Authorization'), str(err)))
@middleware.pre_hook
def process_request(ctx):
@rest.middleware.pre_hook
def process_request(ctx: rest.Context) -> None:
''' Bind the user to request. Update last login time if needed. '''
ctx.user = _get_user(ctx)
if ctx.get_param_as_bool('bump-login') and ctx.user.user_id:
auth_user = _get_user(ctx)
if auth_user:
ctx.user = auth_user
if ctx.get_param_as_bool('bump-login', default=False) and ctx.user.user_id:
users.bump_user_login_time(ctx.user)
ctx.session.commit()

View file

@ -1,8 +1,9 @@
from szurubooru import rest
from szurubooru.func import cache
from szurubooru.rest import middleware
@middleware.pre_hook
def process_request(ctx):
def process_request(ctx: rest.Context) -> None:
if ctx.method != 'GET':
cache.purge()

View file

@ -1,5 +1,5 @@
import logging
from szurubooru import db
from szurubooru import db, rest
from szurubooru.rest import middleware
@ -7,12 +7,12 @@ logger = logging.getLogger(__name__)
@middleware.pre_hook
def process_request(_ctx):
def process_request(_ctx: rest.Context) -> None:
db.reset_query_count()
@middleware.post_hook
def process_response(ctx):
def process_response(ctx: rest.Context) -> None:
logger.info(
'%s %s (user=%s, queries=%d)',
ctx.method,

View file

@ -2,7 +2,7 @@ import os
import sys
import alembic
import sqlalchemy
import sqlalchemy as sa
import logging.config
# make szurubooru module importable
@ -48,7 +48,7 @@ def run_migrations_online():
In this scenario we need to create an Engine
and associate a connection with the context.
'''
connectable = sqlalchemy.engine_from_config(
connectable = sa.engine_from_config(
alembic_config.get_section(alembic_config.config_ini_section),
prefix='sqlalchemy.',
poolclass=sqlalchemy.pool.NullPool)

View file

@ -0,0 +1,15 @@
from szurubooru.model.base import Base
from szurubooru.model.user import User
from szurubooru.model.tag_category import TagCategory
from szurubooru.model.tag import Tag, TagName, TagSuggestion, TagImplication
from szurubooru.model.post import (
Post,
PostTag,
PostRelation,
PostFavorite,
PostScore,
PostNote,
PostFeature)
from szurubooru.model.comment import Comment, CommentScore
from szurubooru.model.snapshot import Snapshot
import szurubooru.model.util

View file

@ -1,7 +1,8 @@
from sqlalchemy import Column, Integer, DateTime, UnicodeText, ForeignKey
from sqlalchemy.orm import relationship, backref
from sqlalchemy.sql.expression import func
from szurubooru.db.base import Base
from szurubooru.db import get_session
from szurubooru.model.base import Base
class CommentScore(Base):
@ -48,12 +49,12 @@ class Comment(Base):
'CommentScore', cascade='all, delete-orphan', lazy='joined')
@property
def score(self):
from szurubooru.db import session
return session \
.query(func.sum(CommentScore.score)) \
.filter(CommentScore.comment_id == self.comment_id) \
.one()[0] or 0
def score(self) -> int:
return (
get_session()
.query(func.sum(CommentScore.score))
.filter(CommentScore.comment_id == self.comment_id)
.one()[0] or 0)
__mapper_args__ = {
'version_id_col': version,

View file

@ -3,8 +3,8 @@ from sqlalchemy import (
Column, Integer, DateTime, Unicode, UnicodeText, PickleType, ForeignKey)
from sqlalchemy.orm import (
relationship, column_property, object_session, backref)
from szurubooru.db.base import Base
from szurubooru.db.comment import Comment
from szurubooru.model.base import Base
from szurubooru.model.comment import Comment
class PostFeature(Base):
@ -17,10 +17,9 @@ class PostFeature(Base):
'user_id', Integer, ForeignKey('user.id'), nullable=False, index=True)
time = Column('time', DateTime, nullable=False)
post = relationship('Post')
post = relationship('Post') # type: Post
user = relationship(
'User',
backref=backref('post_features', cascade='all, delete-orphan'))
'User', backref=backref('post_features', cascade='all, delete-orphan'))
class PostScore(Base):
@ -104,7 +103,7 @@ class PostRelation(Base):
nullable=False,
index=True)
def __init__(self, parent_id, child_id):
def __init__(self, parent_id: int, child_id: int) -> None:
self.parent_id = parent_id
self.child_id = child_id
@ -127,7 +126,7 @@ class PostTag(Base):
nullable=False,
index=True)
def __init__(self, post_id, tag_id):
def __init__(self, post_id: int, tag_id: int) -> None:
self.post_id = post_id
self.tag_id = tag_id
@ -197,7 +196,7 @@ class Post(Base):
canvas_area = column_property(canvas_width * canvas_height)
@property
def is_featured(self):
def is_featured(self) -> bool:
featured_post = object_session(self) \
.query(PostFeature) \
.order_by(PostFeature.time.desc()) \

View file

@ -1,7 +1,7 @@
from sqlalchemy.orm import relationship
from sqlalchemy import (
Column, Integer, DateTime, Unicode, PickleType, ForeignKey)
from szurubooru.db.base import Base
from szurubooru.model.base import Base
class Snapshot(Base):

View file

@ -2,8 +2,8 @@ from sqlalchemy import (
Column, Integer, DateTime, Unicode, UnicodeText, ForeignKey)
from sqlalchemy.orm import relationship, column_property
from sqlalchemy.sql.expression import func, select
from szurubooru.db.base import Base
from szurubooru.db.post import PostTag
from szurubooru.model.base import Base
from szurubooru.model.post import PostTag
class TagSuggestion(Base):
@ -24,7 +24,7 @@ class TagSuggestion(Base):
primary_key=True,
index=True)
def __init__(self, parent_id, child_id):
def __init__(self, parent_id: int, child_id: int) -> None:
self.parent_id = parent_id
self.child_id = child_id
@ -47,7 +47,7 @@ class TagImplication(Base):
primary_key=True,
index=True)
def __init__(self, parent_id, child_id):
def __init__(self, parent_id: int, child_id: int) -> None:
self.parent_id = parent_id
self.child_id = child_id
@ -61,7 +61,7 @@ class TagName(Base):
name = Column('name', Unicode(64), nullable=False, unique=True)
order = Column('ord', Integer, nullable=False, index=True)
def __init__(self, name, order):
def __init__(self, name: str, order: int) -> None:
self.name = name
self.order = order

View file

@ -1,8 +1,9 @@
from typing import Optional
from sqlalchemy import Column, Integer, Unicode, Boolean, table
from sqlalchemy.orm import column_property
from sqlalchemy.sql.expression import func, select
from szurubooru.db.base import Base
from szurubooru.db.tag import Tag
from szurubooru.model.base import Base
from szurubooru.model.tag import Tag
class TagCategory(Base):
@ -14,7 +15,7 @@ class TagCategory(Base):
color = Column('color', Unicode(32), nullable=False, default='#000000')
default = Column('default', Boolean, nullable=False, default=False)
def __init__(self, name=None):
def __init__(self, name: Optional[str]=None) -> None:
self.name = name
tag_count = column_property(

View file

@ -1,9 +1,7 @@
from sqlalchemy import Column, Integer, Unicode, DateTime
from sqlalchemy.orm import relationship
from sqlalchemy.sql.expression import func
from szurubooru.db.base import Base
from szurubooru.db.post import Post, PostScore, PostFavorite
from szurubooru.db.comment import Comment
import sqlalchemy as sa
from szurubooru.model.base import Base
from szurubooru.model.post import Post, PostScore, PostFavorite
from szurubooru.model.comment import Comment
class User(Base):
@ -20,63 +18,64 @@ class User(Base):
RANK_ADMINISTRATOR = 'administrator'
RANK_NOBODY = 'nobody' # unattainable, used for privileges
user_id = Column('id', Integer, primary_key=True)
creation_time = Column('creation_time', DateTime, nullable=False)
last_login_time = Column('last_login_time', DateTime)
version = Column('version', Integer, default=1, nullable=False)
name = Column('name', Unicode(50), nullable=False, unique=True)
password_hash = Column('password_hash', Unicode(64), nullable=False)
password_salt = Column('password_salt', Unicode(32))
email = Column('email', Unicode(64), nullable=True)
rank = Column('rank', Unicode(32), nullable=False)
avatar_style = Column(
'avatar_style', Unicode(32), nullable=False, default=AVATAR_GRAVATAR)
user_id = sa.Column('id', sa.Integer, primary_key=True)
creation_time = sa.Column('creation_time', sa.DateTime, nullable=False)
last_login_time = sa.Column('last_login_time', sa.DateTime)
version = sa.Column('version', sa.Integer, default=1, nullable=False)
name = sa.Column('name', sa.Unicode(50), nullable=False, unique=True)
password_hash = sa.Column('password_hash', sa.Unicode(64), nullable=False)
password_salt = sa.Column('password_salt', sa.Unicode(32))
email = sa.Column('email', sa.Unicode(64), nullable=True)
rank = sa.Column('rank', sa.Unicode(32), nullable=False)
avatar_style = sa.Column(
'avatar_style', sa.Unicode(32), nullable=False,
default=AVATAR_GRAVATAR)
comments = relationship('Comment')
comments = sa.orm.relationship('Comment')
@property
def post_count(self):
def post_count(self) -> int:
from szurubooru.db import session
return (
session
.query(func.sum(1))
.query(sa.sql.expression.func.sum(1))
.filter(Post.user_id == self.user_id)
.one()[0] or 0)
@property
def comment_count(self):
def comment_count(self) -> int:
from szurubooru.db import session
return (
session
.query(func.sum(1))
.query(sa.sql.expression.func.sum(1))
.filter(Comment.user_id == self.user_id)
.one()[0] or 0)
@property
def favorite_post_count(self):
def favorite_post_count(self) -> int:
from szurubooru.db import session
return (
session
.query(func.sum(1))
.query(sa.sql.expression.func.sum(1))
.filter(PostFavorite.user_id == self.user_id)
.one()[0] or 0)
@property
def liked_post_count(self):
def liked_post_count(self) -> int:
from szurubooru.db import session
return (
session
.query(func.sum(1))
.query(sa.sql.expression.func.sum(1))
.filter(PostScore.user_id == self.user_id)
.filter(PostScore.score == 1)
.one()[0] or 0)
@property
def disliked_post_count(self):
def disliked_post_count(self) -> int:
from szurubooru.db import session
return (
session
.query(func.sum(1))
.query(sa.sql.expression.func.sum(1))
.filter(PostScore.user_id == self.user_id)
.filter(PostScore.score == -1)
.one()[0] or 0)

View file

@ -0,0 +1,42 @@
from typing import Tuple, Any, Dict, Callable, Union, Optional
import sqlalchemy as sa
from szurubooru.model.base import Base
from szurubooru.model.user import User
def get_resource_info(entity: Base) -> Tuple[Any, Any, Union[str, int]]:
serializers = {
'tag': lambda tag: tag.first_name,
'tag_category': lambda category: category.name,
'comment': lambda comment: comment.comment_id,
'post': lambda post: post.post_id,
} # type: Dict[str, Callable[[Base], Any]]
resource_type = entity.__table__.name
assert resource_type in serializers
primary_key = sa.inspection.inspect(entity).identity # type: Any
assert primary_key is not None
assert len(primary_key) == 1
resource_name = serializers[resource_type](entity) # type: Union[str, int]
assert resource_name
resource_pkey = primary_key[0] # type: Any
assert resource_pkey
return (resource_type, resource_pkey, resource_name)
def get_aux_entity(
session: Any,
get_table_info: Callable[[Base], Tuple[Base, Callable[[Base], Any]]],
entity: Base,
user: User) -> Optional[Base]:
table, get_column = get_table_info(entity)
return (
session
.query(table)
.filter(get_column(table) == get_column(entity))
.filter(table.user_id == user.user_id)
.one_or_none())

View file

@ -1,2 +1,2 @@
from szurubooru.rest.app import application
from szurubooru.rest.context import Context
from szurubooru.rest.context import Context, Response

View file

@ -2,13 +2,14 @@ import urllib.parse
import cgi
import json
import re
from typing import Dict, Any, Callable, Tuple
from datetime import datetime
from szurubooru import db
from szurubooru.func import util
from szurubooru.rest import errors, middleware, routes, context
def _json_serializer(obj):
def _json_serializer(obj: Any) -> str:
''' JSON serializer for objects not serializable by default JSON code '''
if isinstance(obj, datetime):
serial = obj.isoformat('T') + 'Z'
@ -16,12 +17,12 @@ def _json_serializer(obj):
raise TypeError('Type not serializable')
def _dump_json(obj):
def _dump_json(obj: Any) -> str:
return json.dumps(obj, default=_json_serializer, indent=2)
def _get_headers(env):
headers = {}
def _get_headers(env: Dict[str, Any]) -> Dict[str, str]:
headers = {} # type: Dict[str, str]
for key, value in env.items():
if key.startswith('HTTP_'):
key = util.snake_case_to_upper_train_case(key[5:])
@ -29,7 +30,7 @@ def _get_headers(env):
return headers
def _create_context(env):
def _create_context(env: Dict[str, Any]) -> context.Context:
method = env['REQUEST_METHOD']
path = '/' + env['PATH_INFO'].lstrip('/')
headers = _get_headers(env)
@ -64,7 +65,9 @@ def _create_context(env):
return context.Context(method, path, headers, params, files)
def application(env, start_response):
def application(
env: Dict[str, Any],
start_response: Callable[[str, Any], Any]) -> Tuple[bytes]:
try:
ctx = _create_context(env)
if 'application/json' not in ctx.get_header('Accept'):
@ -106,9 +109,9 @@ def application(env, start_response):
return (_dump_json(response).encode('utf-8'),)
except Exception as ex:
for exception_type, handler in errors.error_handlers.items():
for exception_type, ex_handler in errors.error_handlers.items():
if isinstance(ex, exception_type):
handler(ex)
ex_handler(ex)
raise
except errors.BaseHttpError as ex:

View file

@ -1,111 +1,158 @@
from szurubooru import errors
from typing import Any, Union, List, Dict, Optional, cast
from szurubooru import model, errors
from szurubooru.func import net, file_uploads
def _lower_first(source):
return source[0].lower() + source[1:]
def _param_wrapper(func):
def wrapper(self, name, required=False, default=None, **kwargs):
# pylint: disable=protected-access
if name in self._params:
value = self._params[name]
try:
value = func(self, value, **kwargs)
except errors.InvalidParameterError as ex:
raise errors.InvalidParameterError(
'Parameter %r is invalid: %s' % (
name, _lower_first(str(ex))))
return value
if not required:
return default
raise errors.MissingRequiredParameterError(
'Required parameter %r is missing.' % name)
return wrapper
MISSING = object()
Request = Dict[str, Any]
Response = Optional[Dict[str, Any]]
class Context:
def __init__(self, method, url, headers=None, params=None, files=None):
def __init__(
self,
method: str,
url: str,
headers: Dict[str, str]=None,
params: Request=None,
files: Dict[str, bytes]=None) -> None:
self.method = method
self.url = url
self._headers = headers or {}
self._params = params or {}
self._files = files or {}
# provided by middleware
# self.session = None
# self.user = None
self.user = model.User()
self.user.name = None
self.user.rank = 'anonymous'
def has_header(self, name):
self.session = None # type: Any
def has_header(self, name: str) -> bool:
return name in self._headers
def get_header(self, name):
return self._headers.get(name, None)
def get_header(self, name: str) -> str:
return self._headers.get(name, '')
def has_file(self, name, allow_tokens=True):
def has_file(self, name: str, allow_tokens: bool=True) -> bool:
return (
name in self._files or
name + 'Url' in self._params or
(allow_tokens and name + 'Token' in self._params))
def get_file(self, name, required=False, allow_tokens=True):
ret = None
if name in self._files:
ret = self._files[name]
elif name + 'Url' in self._params:
ret = net.download(self._params[name + 'Url'])
elif allow_tokens and name + 'Token' in self._params:
def get_file(
self,
name: str,
default: Union[object, bytes]=MISSING,
allow_tokens: bool=True) -> bytes:
if name in self._files and self._files[name]:
return self._files[name]
if name + 'Url' in self._params:
return net.download(self._params[name + 'Url'])
if allow_tokens and name + 'Token' in self._params:
ret = file_uploads.get(self._params[name + 'Token'])
if required and not ret:
if ret:
return ret
elif default is not MISSING:
raise errors.MissingOrExpiredRequiredFileError(
'Required file %r is missing or has expired.' % name)
if required and not ret:
if default is not MISSING:
return cast(bytes, default)
raise errors.MissingRequiredFileError(
'Required file %r is missing.' % name)
return ret
def has_param(self, name):
def has_param(self, name: str) -> bool:
return name in self._params
@_param_wrapper
def get_param_as_list(self, value):
if not isinstance(value, list):
def get_param_as_list(
self,
name: str,
default: Union[object, List[Any]]=MISSING) -> List[Any]:
if name not in self._params:
if default is not MISSING:
return cast(List[Any], default)
raise errors.MissingRequiredParameterError(
'Required parameter %r is missing.' % name)
value = self._params[name]
if type(value) is str:
if ',' in value:
return value.split(',')
return [value]
if type(value) is list:
return value
raise errors.InvalidParameterError(
'Parameter %r must be a list.' % name)
@_param_wrapper
def get_param_as_string(self, value):
if isinstance(value, list):
def get_param_as_string(
self,
name: str,
default: Union[object, str]=MISSING) -> str:
if name not in self._params:
if default is not MISSING:
return cast(str, default)
raise errors.MissingRequiredParameterError(
'Required parameter %r is missing.' % name)
value = self._params[name]
try:
value = ','.join(value)
except TypeError:
raise errors.InvalidParameterError('Expected simple string.')
if value is None:
return ''
if type(value) is list:
return ','.join(value)
if type(value) is int or type(value) is float:
return str(value)
if type(value) is str:
return value
except TypeError:
pass
raise errors.InvalidParameterError(
'Parameter %r must be a string value.' % name)
@_param_wrapper
def get_param_as_int(self, value, min=None, max=None):
def get_param_as_int(
self,
name: str,
default: Union[object, int]=MISSING,
min: Optional[int]=None,
max: Optional[int]=None) -> int:
if name not in self._params:
if default is not MISSING:
return cast(int, default)
raise errors.MissingRequiredParameterError(
'Required parameter %r is missing.' % name)
value = self._params[name]
try:
value = int(value)
except (ValueError, TypeError):
raise errors.InvalidParameterError(
'The value must be an integer.')
if min is not None and value < min:
raise errors.InvalidParameterError(
'The value must be at least %r.' % min)
'Parameter %r must be at least %r.' % (name, min))
if max is not None and value > max:
raise errors.InvalidParameterError(
'The value may not exceed %r.' % max)
'Parameter %r may not exceed %r.' % (name, max))
return value
except (ValueError, TypeError):
pass
raise errors.InvalidParameterError(
'Parameter %r must be an integer value.' % name)
@_param_wrapper
def get_param_as_bool(self, value):
def get_param_as_bool(
self,
name: str,
default: Union[object, bool]=MISSING) -> bool:
if name not in self._params:
if default is not MISSING:
return cast(bool, default)
raise errors.MissingRequiredParameterError(
'Required parameter %r is missing.' % name)
value = self._params[name]
try:
value = str(value).lower()
except TypeError:
pass
if value in ['1', 'y', 'yes', 'yeah', 'yep', 'yup', 't', 'true']:
return True
if value in ['0', 'n', 'no', 'nope', 'f', 'false']:
return False
raise errors.InvalidParameterError(
'The value must be a boolean value.')
'Parameter %r must be a boolean value.' % name)

View file

@ -1,11 +1,19 @@
from typing import Callable, Type, Dict
error_handlers = {} # pylint: disable=invalid-name
class BaseHttpError(RuntimeError):
code = None
reason = None
code = -1
reason = ''
def __init__(self, name, description, title=None, extra_fields=None):
def __init__(
self,
name: str,
description: str,
title: str=None,
extra_fields: Dict[str, str]=None) -> None:
super().__init__()
# error name for programmers
self.name = name
@ -52,5 +60,7 @@ class HttpInternalServerError(BaseHttpError):
reason = 'Internal Server Error'
def handle(exception_type, handler):
def handle(
exception_type: Type[Exception],
handler: Callable[[Exception], None]) -> None:
error_handlers[exception_type] = handler

View file

@ -1,11 +1,15 @@
from typing import Callable
from szurubooru.rest.context import Context
# pylint: disable=invalid-name
pre_hooks = []
post_hooks = []
pre_hooks = [] # type: List[Callable[[Context], None]]
post_hooks = [] # type: List[Callable[[Context], None]]
def pre_hook(handler):
def pre_hook(handler: Callable) -> None:
pre_hooks.append(handler)
def post_hook(handler):
def post_hook(handler: Callable) -> None:
post_hooks.insert(0, handler)

View file

@ -1,32 +1,36 @@
from typing import Callable, Dict, Any
from collections import defaultdict
from szurubooru.rest.context import Context, Response
routes = defaultdict(dict) # pylint: disable=invalid-name
# pylint: disable=invalid-name
RouteHandler = Callable[[Context, Dict[str, str]], Response]
routes = defaultdict(dict) # type: Dict[str, Dict[str, RouteHandler]]
def get(url):
def wrapper(handler):
def get(url: str) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler: RouteHandler) -> RouteHandler:
routes[url]['GET'] = handler
return handler
return wrapper
def put(url):
def wrapper(handler):
def put(url: str) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler: RouteHandler) -> RouteHandler:
routes[url]['PUT'] = handler
return handler
return wrapper
def post(url):
def wrapper(handler):
def post(url: str) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler: RouteHandler) -> RouteHandler:
routes[url]['POST'] = handler
return handler
return wrapper
def delete(url):
def wrapper(handler):
def delete(url: str) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler: RouteHandler) -> RouteHandler:
routes[url]['DELETE'] = handler
return handler
return wrapper

View file

@ -1,38 +1,47 @@
from szurubooru.search import tokens
from typing import Optional, Tuple, Dict, Callable
from szurubooru.search import tokens, criteria
from szurubooru.search.query import SearchQuery
from szurubooru.search.typing import SaColumn, SaQuery
Filter = Callable[[SaQuery, Optional[criteria.BaseCriterion], bool], SaQuery]
class BaseSearchConfig:
SORT_NONE = tokens.SortToken.SORT_NONE
SORT_ASC = tokens.SortToken.SORT_ASC
SORT_DESC = tokens.SortToken.SORT_DESC
def on_search_query_parsed(self, search_query):
def on_search_query_parsed(self, search_query: SearchQuery) -> None:
pass
def create_filter_query(self, _disable_eager_loads):
def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery:
raise NotImplementedError()
def create_count_query(self, disable_eager_loads):
def create_count_query(self, disable_eager_loads: bool) -> SaQuery:
raise NotImplementedError()
def create_around_query(self):
def create_around_query(self) -> SaQuery:
raise NotImplementedError()
def finalize_query(self, query: SaQuery) -> SaQuery:
return query
@property
def id_column(self):
def id_column(self) -> SaColumn:
return None
@property
def anonymous_filter(self):
def anonymous_filter(self) -> Optional[Filter]:
return None
@property
def special_filters(self):
def special_filters(self) -> Dict[str, Filter]:
return {}
@property
def named_filters(self):
def named_filters(self) -> Dict[str, Filter]:
return {}
@property
def sort_columns(self):
def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
return {}

View file

@ -1,59 +1,62 @@
from sqlalchemy.sql.expression import func
from szurubooru import db
from typing import Tuple, Dict
import sqlalchemy as sa
from szurubooru import db, model
from szurubooru.search.typing import SaColumn, SaQuery
from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import BaseSearchConfig
from szurubooru.search.configs.base_search_config import (
BaseSearchConfig, Filter)
class CommentSearchConfig(BaseSearchConfig):
def create_filter_query(self, _disable_eager_loads):
return db.session.query(db.Comment).join(db.User)
def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(model.Comment).join(model.User)
def create_count_query(self, disable_eager_loads):
def create_count_query(self, disable_eager_loads: bool) -> SaQuery:
return self.create_filter_query(disable_eager_loads)
def create_around_query(self):
def create_around_query(self) -> SaQuery:
raise NotImplementedError()
def finalize_query(self, query):
return query.order_by(db.Comment.creation_time.desc())
def finalize_query(self, query: SaQuery) -> SaQuery:
return query.order_by(model.Comment.creation_time.desc())
@property
def anonymous_filter(self):
return search_util.create_str_filter(db.Comment.text)
def anonymous_filter(self) -> SaQuery:
return search_util.create_str_filter(model.Comment.text)
@property
def named_filters(self):
def named_filters(self) -> Dict[str, Filter]:
return {
'id': search_util.create_num_filter(db.Comment.comment_id),
'post': search_util.create_num_filter(db.Comment.post_id),
'user': search_util.create_str_filter(db.User.name),
'author': search_util.create_str_filter(db.User.name),
'text': search_util.create_str_filter(db.Comment.text),
'id': search_util.create_num_filter(model.Comment.comment_id),
'post': search_util.create_num_filter(model.Comment.post_id),
'user': search_util.create_str_filter(model.User.name),
'author': search_util.create_str_filter(model.User.name),
'text': search_util.create_str_filter(model.Comment.text),
'creation-date':
search_util.create_date_filter(db.Comment.creation_time),
search_util.create_date_filter(model.Comment.creation_time),
'creation-time':
search_util.create_date_filter(db.Comment.creation_time),
search_util.create_date_filter(model.Comment.creation_time),
'last-edit-date':
search_util.create_date_filter(db.Comment.last_edit_time),
search_util.create_date_filter(model.Comment.last_edit_time),
'last-edit-time':
search_util.create_date_filter(db.Comment.last_edit_time),
search_util.create_date_filter(model.Comment.last_edit_time),
'edit-date':
search_util.create_date_filter(db.Comment.last_edit_time),
search_util.create_date_filter(model.Comment.last_edit_time),
'edit-time':
search_util.create_date_filter(db.Comment.last_edit_time),
search_util.create_date_filter(model.Comment.last_edit_time),
}
@property
def sort_columns(self):
def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
return {
'random': (func.random(), None),
'user': (db.User.name, self.SORT_ASC),
'author': (db.User.name, self.SORT_ASC),
'post': (db.Comment.post_id, self.SORT_DESC),
'creation-date': (db.Comment.creation_time, self.SORT_DESC),
'creation-time': (db.Comment.creation_time, self.SORT_DESC),
'last-edit-date': (db.Comment.last_edit_time, self.SORT_DESC),
'last-edit-time': (db.Comment.last_edit_time, self.SORT_DESC),
'edit-date': (db.Comment.last_edit_time, self.SORT_DESC),
'edit-time': (db.Comment.last_edit_time, self.SORT_DESC),
'random': (sa.sql.expression.func.random(), self.SORT_NONE),
'user': (model.User.name, self.SORT_ASC),
'author': (model.User.name, self.SORT_ASC),
'post': (model.Comment.post_id, self.SORT_DESC),
'creation-date': (model.Comment.creation_time, self.SORT_DESC),
'creation-time': (model.Comment.creation_time, self.SORT_DESC),
'last-edit-date': (model.Comment.last_edit_time, self.SORT_DESC),
'last-edit-time': (model.Comment.last_edit_time, self.SORT_DESC),
'edit-date': (model.Comment.last_edit_time, self.SORT_DESC),
'edit-time': (model.Comment.last_edit_time, self.SORT_DESC),
}

View file

@ -1,13 +1,16 @@
from sqlalchemy.orm import subqueryload, lazyload, defer, aliased
from sqlalchemy.sql.expression import func
from szurubooru import db, errors
from typing import Any, Optional, Tuple, Dict
import sqlalchemy as sa
from szurubooru import db, model, errors
from szurubooru.func import util
from szurubooru.search import criteria, tokens
from szurubooru.search.typing import SaColumn, SaQuery
from szurubooru.search.query import SearchQuery
from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import BaseSearchConfig
from szurubooru.search.configs.base_search_config import (
BaseSearchConfig, Filter)
def _enum_transformer(available_values, value):
def _enum_transformer(available_values: Dict[str, Any], value: str) -> str:
try:
return available_values[value.lower()]
except KeyError:
@ -16,71 +19,82 @@ def _enum_transformer(available_values, value):
value, list(sorted(available_values.keys()))))
def _type_transformer(value):
def _type_transformer(value: str) -> str:
available_values = {
'image': db.Post.TYPE_IMAGE,
'animation': db.Post.TYPE_ANIMATION,
'animated': db.Post.TYPE_ANIMATION,
'anim': db.Post.TYPE_ANIMATION,
'gif': db.Post.TYPE_ANIMATION,
'video': db.Post.TYPE_VIDEO,
'webm': db.Post.TYPE_VIDEO,
'flash': db.Post.TYPE_FLASH,
'swf': db.Post.TYPE_FLASH,
'image': model.Post.TYPE_IMAGE,
'animation': model.Post.TYPE_ANIMATION,
'animated': model.Post.TYPE_ANIMATION,
'anim': model.Post.TYPE_ANIMATION,
'gif': model.Post.TYPE_ANIMATION,
'video': model.Post.TYPE_VIDEO,
'webm': model.Post.TYPE_VIDEO,
'flash': model.Post.TYPE_FLASH,
'swf': model.Post.TYPE_FLASH,
}
return _enum_transformer(available_values, value)
def _safety_transformer(value):
def _safety_transformer(value: str) -> str:
available_values = {
'safe': db.Post.SAFETY_SAFE,
'sketchy': db.Post.SAFETY_SKETCHY,
'questionable': db.Post.SAFETY_SKETCHY,
'unsafe': db.Post.SAFETY_UNSAFE,
'safe': model.Post.SAFETY_SAFE,
'sketchy': model.Post.SAFETY_SKETCHY,
'questionable': model.Post.SAFETY_SKETCHY,
'unsafe': model.Post.SAFETY_UNSAFE,
}
return _enum_transformer(available_values, value)
def _create_score_filter(score):
def wrapper(query, criterion, negated):
def _create_score_filter(score: int) -> Filter:
def wrapper(
query: SaQuery,
criterion: Optional[criteria.BaseCriterion],
negated: bool) -> SaQuery:
assert criterion
if not getattr(criterion, 'internal', False):
raise errors.SearchError(
'Votes cannot be seen publicly. Did you mean %r?'
% 'special:liked')
user_alias = aliased(db.User)
score_alias = aliased(db.PostScore)
user_alias = sa.orm.aliased(model.User)
score_alias = sa.orm.aliased(model.PostScore)
expr = score_alias.score == score
expr = expr & search_util.apply_str_criterion_to_column(
user_alias.name, criterion)
if negated:
expr = ~expr
ret = query \
.join(score_alias, score_alias.post_id == db.Post.post_id) \
.join(score_alias, score_alias.post_id == model.Post.post_id) \
.join(user_alias, user_alias.user_id == score_alias.user_id) \
.filter(expr)
return ret
return wrapper
def _create_user_filter():
def wrapper(query, criterion, negated):
def _create_user_filter() -> Filter:
def wrapper(
query: SaQuery,
criterion: Optional[criteria.BaseCriterion],
negated: bool) -> SaQuery:
assert criterion
if isinstance(criterion, criteria.PlainCriterion) \
and not criterion.value:
# pylint: disable=singleton-comparison
expr = db.Post.user_id == None
expr = model.Post.user_id == None
if negated:
expr = ~expr
return query.filter(expr)
return search_util.create_subquery_filter(
db.Post.user_id,
db.User.user_id,
db.User.name,
model.Post.user_id,
model.User.user_id,
model.User.name,
search_util.create_str_filter)(query, criterion, negated)
return wrapper
class PostSearchConfig(BaseSearchConfig):
def on_search_query_parsed(self, search_query):
def __init__(self) -> None:
self.user = None # type: Optional[model.User]
def on_search_query_parsed(self, search_query: SearchQuery) -> SaQuery:
new_special_tokens = []
for token in search_query.special_tokens:
if token.value in ('fav', 'liked', 'disliked'):
@ -91,7 +105,7 @@ class PostSearchConfig(BaseSearchConfig):
criterion = criteria.PlainCriterion(
original_text=self.user.name,
value=self.user.name)
criterion.internal = True
setattr(criterion, 'internal', True)
search_query.named_tokens.append(
tokens.NamedToken(
name=token.value,
@ -101,160 +115,324 @@ class PostSearchConfig(BaseSearchConfig):
new_special_tokens.append(token)
search_query.special_tokens = new_special_tokens
def create_around_query(self):
return db.session.query(db.Post).options(lazyload('*'))
def create_around_query(self) -> SaQuery:
return db.session.query(model.Post).options(sa.orm.lazyload('*'))
def create_filter_query(self, disable_eager_loads):
strategy = lazyload if disable_eager_loads else subqueryload
return db.session.query(db.Post) \
def create_filter_query(self, disable_eager_loads: bool) -> SaQuery:
strategy = (
sa.orm.lazyload
if disable_eager_loads
else sa.orm.subqueryload)
return db.session.query(model.Post) \
.options(
lazyload('*'),
sa.orm.lazyload('*'),
# use config optimized for official client
# defer(db.Post.score),
# defer(db.Post.favorite_count),
# defer(db.Post.comment_count),
defer(db.Post.last_favorite_time),
defer(db.Post.feature_count),
defer(db.Post.last_feature_time),
defer(db.Post.last_comment_creation_time),
defer(db.Post.last_comment_edit_time),
defer(db.Post.note_count),
defer(db.Post.tag_count),
strategy(db.Post.tags).subqueryload(db.Tag.names),
strategy(db.Post.tags).defer(db.Tag.post_count),
strategy(db.Post.tags).lazyload(db.Tag.implications),
strategy(db.Post.tags).lazyload(db.Tag.suggestions))
# sa.orm.defer(model.Post.score),
# sa.orm.defer(model.Post.favorite_count),
# sa.orm.defer(model.Post.comment_count),
sa.orm.defer(model.Post.last_favorite_time),
sa.orm.defer(model.Post.feature_count),
sa.orm.defer(model.Post.last_feature_time),
sa.orm.defer(model.Post.last_comment_creation_time),
sa.orm.defer(model.Post.last_comment_edit_time),
sa.orm.defer(model.Post.note_count),
sa.orm.defer(model.Post.tag_count),
strategy(model.Post.tags).subqueryload(model.Tag.names),
strategy(model.Post.tags).defer(model.Tag.post_count),
strategy(model.Post.tags).lazyload(model.Tag.implications),
strategy(model.Post.tags).lazyload(model.Tag.suggestions))
def create_count_query(self, _disable_eager_loads):
return db.session.query(db.Post)
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(model.Post)
def finalize_query(self, query):
return query.order_by(db.Post.post_id.desc())
def finalize_query(self, query: SaQuery) -> SaQuery:
return query.order_by(model.Post.post_id.desc())
@property
def id_column(self):
return db.Post.post_id
def id_column(self) -> SaColumn:
return model.Post.post_id
@property
def anonymous_filter(self):
def anonymous_filter(self) -> Optional[Filter]:
return search_util.create_subquery_filter(
db.Post.post_id,
db.PostTag.post_id,
db.TagName.name,
model.Post.post_id,
model.PostTag.post_id,
model.TagName.name,
search_util.create_str_filter,
lambda subquery: subquery.join(db.Tag).join(db.TagName))
lambda subquery: subquery.join(model.Tag).join(model.TagName))
@property
def named_filters(self):
return util.unalias_dict({
'id': search_util.create_num_filter(db.Post.post_id),
'tag': search_util.create_subquery_filter(
db.Post.post_id,
db.PostTag.post_id,
db.TagName.name,
def named_filters(self) -> Dict[str, Filter]:
return util.unalias_dict([
(
['id'],
search_util.create_num_filter(model.Post.post_id)
),
(
['tag'],
search_util.create_subquery_filter(
model.Post.post_id,
model.PostTag.post_id,
model.TagName.name,
search_util.create_str_filter,
lambda subquery: subquery.join(db.Tag).join(db.TagName)),
'score': search_util.create_num_filter(db.Post.score),
('uploader', 'upload', 'submit'):
_create_user_filter(),
'comment': search_util.create_subquery_filter(
db.Post.post_id,
db.Comment.post_id,
db.User.name,
lambda subquery:
subquery.join(model.Tag).join(model.TagName))
),
(
['score'],
search_util.create_num_filter(model.Post.score)
),
(
['uploader', 'upload', 'submit'],
_create_user_filter()
),
(
['comment'],
search_util.create_subquery_filter(
model.Post.post_id,
model.Comment.post_id,
model.User.name,
search_util.create_str_filter,
lambda subquery: subquery.join(db.User)),
'fav': search_util.create_subquery_filter(
db.Post.post_id,
db.PostFavorite.post_id,
db.User.name,
lambda subquery: subquery.join(model.User))
),
(
['fav'],
search_util.create_subquery_filter(
model.Post.post_id,
model.PostFavorite.post_id,
model.User.name,
search_util.create_str_filter,
lambda subquery: subquery.join(db.User)),
'liked': _create_score_filter(1),
'disliked': _create_score_filter(-1),
'tag-count': search_util.create_num_filter(db.Post.tag_count),
'comment-count':
search_util.create_num_filter(db.Post.comment_count),
'fav-count':
search_util.create_num_filter(db.Post.favorite_count),
'note-count': search_util.create_num_filter(db.Post.note_count),
'relation-count':
search_util.create_num_filter(db.Post.relation_count),
'feature-count':
search_util.create_num_filter(db.Post.feature_count),
'type':
lambda subquery: subquery.join(model.User))
),
(
['liked'],
_create_score_filter(1)
),
(
['disliked'],
_create_score_filter(-1)
),
(
['tag-count'],
search_util.create_num_filter(model.Post.tag_count)
),
(
['comment-count'],
search_util.create_num_filter(model.Post.comment_count)
),
(
['fav-count'],
search_util.create_num_filter(model.Post.favorite_count)
),
(
['note-count'],
search_util.create_num_filter(model.Post.note_count)
),
(
['relation-count'],
search_util.create_num_filter(model.Post.relation_count)
),
(
['feature-count'],
search_util.create_num_filter(model.Post.feature_count)
),
(
['type'],
search_util.create_str_filter(
db.Post.type, _type_transformer),
'content-checksum': search_util.create_str_filter(
db.Post.checksum),
'file-size': search_util.create_num_filter(db.Post.file_size),
('image-width', 'width'):
search_util.create_num_filter(db.Post.canvas_width),
('image-height', 'height'):
search_util.create_num_filter(db.Post.canvas_height),
('image-area', 'area'):
search_util.create_num_filter(db.Post.canvas_area),
('creation-date', 'creation-time', 'date', 'time'):
search_util.create_date_filter(db.Post.creation_time),
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'):
search_util.create_date_filter(db.Post.last_edit_time),
('comment-date', 'comment-time'):
model.Post.type, _type_transformer)
),
(
['content-checksum'],
search_util.create_str_filter(model.Post.checksum)
),
(
['file-size'],
search_util.create_num_filter(model.Post.file_size)
),
(
['image-width', 'width'],
search_util.create_num_filter(model.Post.canvas_width)
),
(
['image-height', 'height'],
search_util.create_num_filter(model.Post.canvas_height)
),
(
['image-area', 'area'],
search_util.create_num_filter(model.Post.canvas_area)
),
(
['creation-date', 'creation-time', 'date', 'time'],
search_util.create_date_filter(model.Post.creation_time)
),
(
['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'],
search_util.create_date_filter(model.Post.last_edit_time)
),
(
['comment-date', 'comment-time'],
search_util.create_date_filter(
db.Post.last_comment_creation_time),
('fav-date', 'fav-time'):
search_util.create_date_filter(db.Post.last_favorite_time),
('feature-date', 'feature-time'):
search_util.create_date_filter(db.Post.last_feature_time),
('safety', 'rating'):
model.Post.last_comment_creation_time)
),
(
['fav-date', 'fav-time'],
search_util.create_date_filter(model.Post.last_favorite_time)
),
(
['feature-date', 'feature-time'],
search_util.create_date_filter(model.Post.last_feature_time)
),
(
['safety', 'rating'],
search_util.create_str_filter(
db.Post.safety, _safety_transformer),
})
model.Post.safety, _safety_transformer)
),
])
@property
def sort_columns(self):
return util.unalias_dict({
'random': (func.random(), None),
'id': (db.Post.post_id, self.SORT_DESC),
'score': (db.Post.score, self.SORT_DESC),
'tag-count': (db.Post.tag_count, self.SORT_DESC),
'comment-count': (db.Post.comment_count, self.SORT_DESC),
'fav-count': (db.Post.favorite_count, self.SORT_DESC),
'note-count': (db.Post.note_count, self.SORT_DESC),
'relation-count': (db.Post.relation_count, self.SORT_DESC),
'feature-count': (db.Post.feature_count, self.SORT_DESC),
'file-size': (db.Post.file_size, self.SORT_DESC),
('image-width', 'width'):
(db.Post.canvas_width, self.SORT_DESC),
('image-height', 'height'):
(db.Post.canvas_height, self.SORT_DESC),
('image-area', 'area'):
(db.Post.canvas_area, self.SORT_DESC),
('creation-date', 'creation-time', 'date', 'time'):
(db.Post.creation_time, self.SORT_DESC),
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'):
(db.Post.last_edit_time, self.SORT_DESC),
('comment-date', 'comment-time'):
(db.Post.last_comment_creation_time, self.SORT_DESC),
('fav-date', 'fav-time'):
(db.Post.last_favorite_time, self.SORT_DESC),
('feature-date', 'feature-time'):
(db.Post.last_feature_time, self.SORT_DESC),
})
def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
return util.unalias_dict([
(
['random'],
(sa.sql.expression.func.random(), self.SORT_NONE)
),
(
['id'],
(model.Post.post_id, self.SORT_DESC)
),
(
['score'],
(model.Post.score, self.SORT_DESC)
),
(
['tag-count'],
(model.Post.tag_count, self.SORT_DESC)
),
(
['comment-count'],
(model.Post.comment_count, self.SORT_DESC)
),
(
['fav-count'],
(model.Post.favorite_count, self.SORT_DESC)
),
(
['note-count'],
(model.Post.note_count, self.SORT_DESC)
),
(
['relation-count'],
(model.Post.relation_count, self.SORT_DESC)
),
(
['feature-count'],
(model.Post.feature_count, self.SORT_DESC)
),
(
['file-size'],
(model.Post.file_size, self.SORT_DESC)
),
(
['image-width', 'width'],
(model.Post.canvas_width, self.SORT_DESC)
),
(
['image-height', 'height'],
(model.Post.canvas_height, self.SORT_DESC)
),
(
['image-area', 'area'],
(model.Post.canvas_area, self.SORT_DESC)
),
(
['creation-date', 'creation-time', 'date', 'time'],
(model.Post.creation_time, self.SORT_DESC)
),
(
['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'],
(model.Post.last_edit_time, self.SORT_DESC)
),
(
['comment-date', 'comment-time'],
(model.Post.last_comment_creation_time, self.SORT_DESC)
),
(
['fav-date', 'fav-time'],
(model.Post.last_favorite_time, self.SORT_DESC)
),
(
['feature-date', 'feature-time'],
(model.Post.last_feature_time, self.SORT_DESC)
),
])
@property
def special_filters(self):
def special_filters(self) -> Dict[str, Filter]:
return {
# handled by parsed
'fav': None,
'liked': None,
'disliked': None,
# handled by parser
'fav': self.noop_filter,
'liked': self.noop_filter,
'disliked': self.noop_filter,
'tumbleweed': self.tumbleweed_filter,
}
def tumbleweed_filter(self, query, negated):
expr = \
(db.Post.comment_count == 0) \
& (db.Post.favorite_count == 0) \
& (db.Post.score == 0)
def noop_filter(
self,
query: SaQuery,
_criterion: Optional[criteria.BaseCriterion],
_negated: bool) -> SaQuery:
return query
def tumbleweed_filter(
self,
query: SaQuery,
_criterion: Optional[criteria.BaseCriterion],
negated: bool) -> SaQuery:
expr = (
(model.Post.comment_count == 0)
& (model.Post.favorite_count == 0)
& (model.Post.score == 0))
if negated:
expr = ~expr
return query.filter(expr)

View file

@ -1,28 +1,37 @@
from szurubooru import db
from typing import Dict
from szurubooru import db, model
from szurubooru.search.typing import SaQuery
from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import BaseSearchConfig
from szurubooru.search.configs.base_search_config import (
BaseSearchConfig, Filter)
class SnapshotSearchConfig(BaseSearchConfig):
def create_filter_query(self, _disable_eager_loads):
return db.session.query(db.Snapshot)
def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(model.Snapshot)
def create_count_query(self, _disable_eager_loads):
return db.session.query(db.Snapshot)
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(model.Snapshot)
def create_around_query(self):
def create_around_query(self) -> SaQuery:
raise NotImplementedError()
def finalize_query(self, query):
return query.order_by(db.Snapshot.creation_time.desc())
def finalize_query(self, query: SaQuery) -> SaQuery:
return query.order_by(model.Snapshot.creation_time.desc())
@property
def named_filters(self):
def named_filters(self) -> Dict[str, Filter]:
return {
'type': search_util.create_str_filter(db.Snapshot.resource_type),
'id': search_util.create_str_filter(db.Snapshot.resource_name),
'date': search_util.create_date_filter(db.Snapshot.creation_time),
'time': search_util.create_date_filter(db.Snapshot.creation_time),
'operation': search_util.create_str_filter(db.Snapshot.operation),
'user': search_util.create_str_filter(db.User.name),
'type':
search_util.create_str_filter(model.Snapshot.resource_type),
'id':
search_util.create_str_filter(model.Snapshot.resource_name),
'date':
search_util.create_date_filter(model.Snapshot.creation_time),
'time':
search_util.create_date_filter(model.Snapshot.creation_time),
'operation':
search_util.create_str_filter(model.Snapshot.operation),
'user':
search_util.create_str_filter(model.User.name),
}

View file

@ -1,79 +1,134 @@
from sqlalchemy.orm import subqueryload, lazyload, defer
from sqlalchemy.sql.expression import func
from szurubooru import db
from typing import Tuple, Dict
import sqlalchemy as sa
from szurubooru import db, model
from szurubooru.func import util
from szurubooru.search.typing import SaColumn, SaQuery
from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import BaseSearchConfig
from szurubooru.search.configs.base_search_config import (
BaseSearchConfig, Filter)
class TagSearchConfig(BaseSearchConfig):
def create_filter_query(self, _disable_eager_loads):
strategy = lazyload if _disable_eager_loads else subqueryload
return db.session.query(db.Tag) \
.join(db.TagCategory) \
def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery:
strategy = (
sa.orm.lazyload
if _disable_eager_loads
else sa.orm.subqueryload)
return db.session.query(model.Tag) \
.join(model.TagCategory) \
.options(
defer(db.Tag.first_name),
defer(db.Tag.suggestion_count),
defer(db.Tag.implication_count),
defer(db.Tag.post_count),
strategy(db.Tag.names),
strategy(db.Tag.suggestions).joinedload(db.Tag.names),
strategy(db.Tag.implications).joinedload(db.Tag.names))
sa.orm.defer(model.Tag.first_name),
sa.orm.defer(model.Tag.suggestion_count),
sa.orm.defer(model.Tag.implication_count),
sa.orm.defer(model.Tag.post_count),
strategy(model.Tag.names),
strategy(model.Tag.suggestions).joinedload(model.Tag.names),
strategy(model.Tag.implications).joinedload(model.Tag.names))
def create_count_query(self, _disable_eager_loads):
return db.session.query(db.Tag)
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(model.Tag)
def create_around_query(self):
def create_around_query(self) -> SaQuery:
raise NotImplementedError()
def finalize_query(self, query):
return query.order_by(db.Tag.first_name.asc())
def finalize_query(self, query: SaQuery) -> SaQuery:
return query.order_by(model.Tag.first_name.asc())
@property
def anonymous_filter(self):
def anonymous_filter(self) -> Filter:
return search_util.create_subquery_filter(
db.Tag.tag_id,
db.TagName.tag_id,
db.TagName.name,
model.Tag.tag_id,
model.TagName.tag_id,
model.TagName.name,
search_util.create_str_filter)
@property
def named_filters(self):
return util.unalias_dict({
'name': search_util.create_subquery_filter(
db.Tag.tag_id,
db.TagName.tag_id,
db.TagName.name,
search_util.create_str_filter),
'category': search_util.create_subquery_filter(
db.Tag.category_id,
db.TagCategory.tag_category_id,
db.TagCategory.name,
search_util.create_str_filter),
('creation-date', 'creation-time'):
search_util.create_date_filter(db.Tag.creation_time),
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'):
search_util.create_date_filter(db.Tag.last_edit_time),
('usage-count', 'post-count', 'usages'):
search_util.create_num_filter(db.Tag.post_count),
'suggestion-count':
search_util.create_num_filter(db.Tag.suggestion_count),
'implication-count':
search_util.create_num_filter(db.Tag.implication_count),
})
def named_filters(self) -> Dict[str, Filter]:
return util.unalias_dict([
(
['name'],
search_util.create_subquery_filter(
model.Tag.tag_id,
model.TagName.tag_id,
model.TagName.name,
search_util.create_str_filter)
),
(
['category'],
search_util.create_subquery_filter(
model.Tag.category_id,
model.TagCategory.tag_category_id,
model.TagCategory.name,
search_util.create_str_filter)
),
(
['creation-date', 'creation-time'],
search_util.create_date_filter(model.Tag.creation_time)
),
(
['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'],
search_util.create_date_filter(model.Tag.last_edit_time)
),
(
['usage-count', 'post-count', 'usages'],
search_util.create_num_filter(model.Tag.post_count)
),
(
['suggestion-count'],
search_util.create_num_filter(model.Tag.suggestion_count)
),
(
['implication-count'],
search_util.create_num_filter(model.Tag.implication_count)
),
])
@property
def sort_columns(self):
return util.unalias_dict({
'random': (func.random(), None),
'name': (db.Tag.first_name, self.SORT_ASC),
'category': (db.TagCategory.name, self.SORT_ASC),
('creation-date', 'creation-time'):
(db.Tag.creation_time, self.SORT_DESC),
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'):
(db.Tag.last_edit_time, self.SORT_DESC),
('usage-count', 'post-count', 'usages'):
(db.Tag.post_count, self.SORT_DESC),
'suggestion-count': (db.Tag.suggestion_count, self.SORT_DESC),
'implication-count': (db.Tag.implication_count, self.SORT_DESC),
})
def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
return util.unalias_dict([
(
['random'],
(sa.sql.expression.func.random(), self.SORT_NONE)
),
(
['name'],
(model.Tag.first_name, self.SORT_ASC)
),
(
['category'],
(model.TagCategory.name, self.SORT_ASC)
),
(
['creation-date', 'creation-time'],
(model.Tag.creation_time, self.SORT_DESC)
),
(
['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'],
(model.Tag.last_edit_time, self.SORT_DESC)
),
(
['usage-count', 'post-count', 'usages'],
(model.Tag.post_count, self.SORT_DESC)
),
(
['suggestion-count'],
(model.Tag.suggestion_count, self.SORT_DESC)
),
(
['implication-count'],
(model.Tag.implication_count, self.SORT_DESC)
),
])

View file

@ -1,53 +1,57 @@
from sqlalchemy.sql.expression import func
from szurubooru import db
from typing import Tuple, Dict
import sqlalchemy as sa
from szurubooru import db, model
from szurubooru.search.typing import SaColumn, SaQuery
from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import BaseSearchConfig
from szurubooru.search.configs.base_search_config import (
BaseSearchConfig, Filter)
class UserSearchConfig(BaseSearchConfig):
def create_filter_query(self, _disable_eager_loads):
return db.session.query(db.User)
def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(model.User)
def create_count_query(self, _disable_eager_loads):
return db.session.query(db.User)
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(model.User)
def create_around_query(self):
def create_around_query(self) -> SaQuery:
raise NotImplementedError()
def finalize_query(self, query):
return query.order_by(db.User.name.asc())
def finalize_query(self, query: SaQuery) -> SaQuery:
return query.order_by(model.User.name.asc())
@property
def anonymous_filter(self):
return search_util.create_str_filter(db.User.name)
def anonymous_filter(self) -> Filter:
return search_util.create_str_filter(model.User.name)
@property
def named_filters(self):
def named_filters(self) -> Dict[str, Filter]:
return {
'name': search_util.create_str_filter(db.User.name),
'name':
search_util.create_str_filter(model.User.name),
'creation-date':
search_util.create_date_filter(db.User.creation_time),
search_util.create_date_filter(model.User.creation_time),
'creation-time':
search_util.create_date_filter(db.User.creation_time),
search_util.create_date_filter(model.User.creation_time),
'last-login-date':
search_util.create_date_filter(db.User.last_login_time),
search_util.create_date_filter(model.User.last_login_time),
'last-login-time':
search_util.create_date_filter(db.User.last_login_time),
search_util.create_date_filter(model.User.last_login_time),
'login-date':
search_util.create_date_filter(db.User.last_login_time),
search_util.create_date_filter(model.User.last_login_time),
'login-time':
search_util.create_date_filter(db.User.last_login_time),
search_util.create_date_filter(model.User.last_login_time),
}
@property
def sort_columns(self):
def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
return {
'random': (func.random(), None),
'name': (db.User.name, self.SORT_ASC),
'creation-date': (db.User.creation_time, self.SORT_DESC),
'creation-time': (db.User.creation_time, self.SORT_DESC),
'last-login-date': (db.User.last_login_time, self.SORT_DESC),
'last-login-time': (db.User.last_login_time, self.SORT_DESC),
'login-date': (db.User.last_login_time, self.SORT_DESC),
'login-time': (db.User.last_login_time, self.SORT_DESC),
'random': (sa.sql.expression.func.random(), self.SORT_NONE),
'name': (model.User.name, self.SORT_ASC),
'creation-date': (model.User.creation_time, self.SORT_DESC),
'creation-time': (model.User.creation_time, self.SORT_DESC),
'last-login-date': (model.User.last_login_time, self.SORT_DESC),
'last-login-time': (model.User.last_login_time, self.SORT_DESC),
'login-date': (model.User.last_login_time, self.SORT_DESC),
'login-time': (model.User.last_login_time, self.SORT_DESC),
}

View file

@ -1,10 +1,13 @@
import sqlalchemy
from typing import Any, Optional, Callable
import sqlalchemy as sa
from szurubooru import db, errors
from szurubooru.func import util
from szurubooru.search import criteria
from szurubooru.search.typing import SaColumn, SaQuery
from szurubooru.search.configs.base_search_config import Filter
def wildcard_transformer(value):
def wildcard_transformer(value: str) -> str:
return (
value
.replace('\\', '\\\\')
@ -13,24 +16,21 @@ def wildcard_transformer(value):
.replace('*', '%'))
def apply_num_criterion_to_column(column, criterion):
'''
Decorate SQLAlchemy filter on given column using supplied criterion.
'''
def apply_num_criterion_to_column(
column: Any, criterion: criteria.BaseCriterion) -> Any:
try:
if isinstance(criterion, criteria.PlainCriterion):
expr = column == int(criterion.value)
elif isinstance(criterion, criteria.ArrayCriterion):
expr = column.in_(int(value) for value in criterion.values)
elif isinstance(criterion, criteria.RangedCriterion):
assert criterion.min_value != '' \
or criterion.max_value != ''
if criterion.min_value != '' and criterion.max_value != '':
assert criterion.min_value or criterion.max_value
if criterion.min_value and criterion.max_value:
expr = column.between(
int(criterion.min_value), int(criterion.max_value))
elif criterion.min_value != '':
elif criterion.min_value:
expr = column >= int(criterion.min_value)
elif criterion.max_value != '':
elif criterion.max_value:
expr = column <= int(criterion.max_value)
else:
assert False
@ -40,10 +40,13 @@ def apply_num_criterion_to_column(column, criterion):
return expr
def create_num_filter(column):
def wrapper(query, criterion, negated):
expr = apply_num_criterion_to_column(
column, criterion)
def create_num_filter(column: Any) -> Filter:
def wrapper(
query: SaQuery,
criterion: Optional[criteria.BaseCriterion],
negated: bool) -> SaQuery:
assert criterion
expr = apply_num_criterion_to_column(column, criterion)
if negated:
expr = ~expr
return query.filter(expr)
@ -51,14 +54,13 @@ def create_num_filter(column):
def apply_str_criterion_to_column(
column, criterion, transformer=wildcard_transformer):
'''
Decorate SQLAlchemy filter on given column using supplied criterion.
'''
column: SaColumn,
criterion: criteria.BaseCriterion,
transformer: Callable[[str], str]=wildcard_transformer) -> SaQuery:
if isinstance(criterion, criteria.PlainCriterion):
expr = column.ilike(transformer(criterion.value))
elif isinstance(criterion, criteria.ArrayCriterion):
expr = sqlalchemy.sql.false()
expr = sa.sql.false()
for value in criterion.values:
expr = expr | column.ilike(transformer(value))
elif isinstance(criterion, criteria.RangedCriterion):
@ -68,8 +70,15 @@ def apply_str_criterion_to_column(
return expr
def create_str_filter(column, transformer=wildcard_transformer):
def wrapper(query, criterion, negated):
def create_str_filter(
column: SaColumn,
transformer: Callable[[str], str]=wildcard_transformer
) -> Filter:
def wrapper(
query: SaQuery,
criterion: Optional[criteria.BaseCriterion],
negated: bool) -> SaQuery:
assert criterion
expr = apply_str_criterion_to_column(
column, criterion, transformer)
if negated:
@ -78,16 +87,13 @@ def create_str_filter(column, transformer=wildcard_transformer):
return wrapper
def apply_date_criterion_to_column(column, criterion):
'''
Decorate SQLAlchemy filter on given column using supplied criterion.
Parse the datetime inside the criterion.
'''
def apply_date_criterion_to_column(
column: SaQuery, criterion: criteria.BaseCriterion) -> SaQuery:
if isinstance(criterion, criteria.PlainCriterion):
min_date, max_date = util.parse_time_range(criterion.value)
expr = column.between(min_date, max_date)
elif isinstance(criterion, criteria.ArrayCriterion):
expr = sqlalchemy.sql.false()
expr = sa.sql.false()
for value in criterion.values:
min_date, max_date = util.parse_time_range(value)
expr = expr | column.between(min_date, max_date)
@ -108,10 +114,13 @@ def apply_date_criterion_to_column(column, criterion):
return expr
def create_date_filter(column):
def wrapper(query, criterion, negated):
expr = apply_date_criterion_to_column(
column, criterion)
def create_date_filter(column: SaColumn) -> Filter:
def wrapper(
query: SaQuery,
criterion: Optional[criteria.BaseCriterion],
negated: bool) -> SaQuery:
assert criterion
expr = apply_date_criterion_to_column(column, criterion)
if negated:
expr = ~expr
return query.filter(expr)
@ -119,18 +128,22 @@ def create_date_filter(column):
def create_subquery_filter(
left_id_column,
right_id_column,
filter_column,
filter_factory,
subquery_decorator=None):
left_id_column: SaColumn,
right_id_column: SaColumn,
filter_column: SaColumn,
filter_factory: SaColumn,
subquery_decorator: Callable[[SaQuery], None]=None) -> Filter:
filter_func = filter_factory(filter_column)
def wrapper(query, criterion, negated):
def wrapper(
query: SaQuery,
criterion: Optional[criteria.BaseCriterion],
negated: bool) -> SaQuery:
assert criterion
subquery = db.session.query(right_id_column.label('foreign_id'))
if subquery_decorator:
subquery = subquery_decorator(subquery)
subquery = subquery.options(sqlalchemy.orm.lazyload('*'))
subquery = subquery.options(sa.orm.lazyload('*'))
subquery = filter_func(subquery, criterion, False)
subquery = subquery.subquery('t')
expression = left_id_column.in_(subquery)

View file

@ -1,34 +1,42 @@
class _BaseCriterion:
def __init__(self, original_text):
from typing import Optional, List, Callable
from szurubooru.search.typing import SaQuery
class BaseCriterion:
def __init__(self, original_text: str) -> None:
self.original_text = original_text
def __repr__(self):
def __repr__(self) -> str:
return self.original_text
class RangedCriterion(_BaseCriterion):
def __init__(self, original_text, min_value, max_value):
class RangedCriterion(BaseCriterion):
def __init__(
self,
original_text: str,
min_value: Optional[str],
max_value: Optional[str]) -> None:
super().__init__(original_text)
self.min_value = min_value
self.max_value = max_value
def __hash__(self):
def __hash__(self) -> int:
return hash(('range', self.min_value, self.max_value))
class PlainCriterion(_BaseCriterion):
def __init__(self, original_text, value):
class PlainCriterion(BaseCriterion):
def __init__(self, original_text: str, value: str) -> None:
super().__init__(original_text)
self.value = value
def __hash__(self):
def __hash__(self) -> int:
return hash(self.value)
class ArrayCriterion(_BaseCriterion):
def __init__(self, original_text, values):
class ArrayCriterion(BaseCriterion):
def __init__(self, original_text: str, values: List[str]) -> None:
super().__init__(original_text)
self.values = values
def __hash__(self):
def __hash__(self) -> int:
return hash(tuple(['array'] + self.values))

View file

@ -1,14 +1,18 @@
import sqlalchemy
from szurubooru import db, errors
from typing import Union, Tuple, List, Dict, Callable
import sqlalchemy as sa
from szurubooru import db, model, errors, rest
from szurubooru.func import cache
from szurubooru.search import tokens, parser
from szurubooru.search.typing import SaQuery
from szurubooru.search.query import SearchQuery
from szurubooru.search.configs.base_search_config import BaseSearchConfig
def _format_dict_keys(source):
def _format_dict_keys(source: Dict) -> List[str]:
return list(sorted(source.keys()))
def _get_order(order, default_order):
def _get_order(order: str, default_order: str) -> Union[bool, str]:
if order == tokens.SortToken.SORT_DEFAULT:
return default_order or tokens.SortToken.SORT_ASC
if order == tokens.SortToken.SORT_NEGATED_DEFAULT:
@ -26,50 +30,57 @@ class Executor:
delegates sqlalchemy filter decoration to SearchConfig instances.
'''
def __init__(self, search_config):
def __init__(self, search_config: BaseSearchConfig) -> None:
self.config = search_config
self.parser = parser.Parser()
def get_around(self, query_text, entity_id):
def get_around(
self,
query_text: str,
entity_id: int) -> Tuple[model.Base, model.Base]:
search_query = self.parser.parse(query_text)
self.config.on_search_query_parsed(search_query)
filter_query = (
self.config
.create_around_query()
.options(sqlalchemy.orm.lazyload('*')))
.options(sa.orm.lazyload('*')))
filter_query = self._prepare_db_query(
filter_query, search_query, False)
prev_filter_query = (
filter_query
.filter(self.config.id_column > entity_id)
.order_by(None)
.order_by(sqlalchemy.func.abs(
self.config.id_column - entity_id).asc())
.order_by(sa.func.abs(self.config.id_column - entity_id).asc())
.limit(1))
next_filter_query = (
filter_query
.filter(self.config.id_column < entity_id)
.order_by(None)
.order_by(sqlalchemy.func.abs(
self.config.id_column - entity_id).asc())
.order_by(sa.func.abs(self.config.id_column - entity_id).asc())
.limit(1))
return [
return (
prev_filter_query.one_or_none(),
next_filter_query.one_or_none()]
next_filter_query.one_or_none())
def get_around_and_serialize(self, ctx, entity_id, serializer):
entities = self.get_around(ctx.get_param_as_string('query'), entity_id)
def get_around_and_serialize(
self,
ctx: rest.Context,
entity_id: int,
serializer: Callable[[model.Base], rest.Response]
) -> rest.Response:
entities = self.get_around(
ctx.get_param_as_string('query', default=''), entity_id)
return {
'prev': serializer(entities[0]),
'next': serializer(entities[1]),
}
def execute(self, query_text, page, page_size):
'''
Parse input and return tuple containing total record count and filtered
entities.
'''
def execute(
self,
query_text: str,
page: int,
page_size: int
) -> Tuple[int, List[model.Base]]:
search_query = self.parser.parse(query_text)
self.config.on_search_query_parsed(search_query)
@ -83,7 +94,7 @@ class Executor:
return cache.get(key)
filter_query = self.config.create_filter_query(disable_eager_loads)
filter_query = filter_query.options(sqlalchemy.orm.lazyload('*'))
filter_query = filter_query.options(sa.orm.lazyload('*'))
filter_query = self._prepare_db_query(filter_query, search_query, True)
entities = filter_query \
.offset(max(page - 1, 0) * page_size) \
@ -91,11 +102,11 @@ class Executor:
.all()
count_query = self.config.create_count_query(disable_eager_loads)
count_query = count_query.options(sqlalchemy.orm.lazyload('*'))
count_query = count_query.options(sa.orm.lazyload('*'))
count_query = self._prepare_db_query(count_query, search_query, False)
count_statement = count_query \
.statement \
.with_only_columns([sqlalchemy.func.count()]) \
.with_only_columns([sa.func.count()]) \
.order_by(None)
count = db.session.execute(count_statement).scalar()
@ -103,8 +114,12 @@ class Executor:
cache.put(key, ret)
return ret
def execute_and_serialize(self, ctx, serializer):
query = ctx.get_param_as_string('query')
def execute_and_serialize(
self,
ctx: rest.Context,
serializer: Callable[[model.Base], rest.Response]
) -> rest.Response:
query = ctx.get_param_as_string('query', default='')
page = ctx.get_param_as_int('page', default=1, min=1)
page_size = ctx.get_param_as_int(
'pageSize', default=100, min=1, max=100)
@ -117,48 +132,51 @@ class Executor:
'results': [serializer(entity) for entity in entities],
}
def _prepare_db_query(self, db_query, search_query, use_sort):
''' Parse input and return SQLAlchemy query. '''
for token in search_query.anonymous_tokens:
def _prepare_db_query(
self,
db_query: SaQuery,
search_query: SearchQuery,
use_sort: bool) -> SaQuery:
for anon_token in search_query.anonymous_tokens:
if not self.config.anonymous_filter:
raise errors.SearchError(
'Anonymous tokens are not valid in this context.')
db_query = self.config.anonymous_filter(
db_query, token.criterion, token.negated)
db_query, anon_token.criterion, anon_token.negated)
for token in search_query.named_tokens:
if token.name not in self.config.named_filters:
for named_token in search_query.named_tokens:
if named_token.name not in self.config.named_filters:
raise errors.SearchError(
'Unknown named token: %r. Available named tokens: %r.' % (
token.name,
named_token.name,
_format_dict_keys(self.config.named_filters)))
db_query = self.config.named_filters[token.name](
db_query, token.criterion, token.negated)
db_query = self.config.named_filters[named_token.name](
db_query, named_token.criterion, named_token.negated)
for token in search_query.special_tokens:
if token.value not in self.config.special_filters:
for sp_token in search_query.special_tokens:
if sp_token.value not in self.config.special_filters:
raise errors.SearchError(
'Unknown special token: %r. '
'Available special tokens: %r.' % (
token.value,
sp_token.value,
_format_dict_keys(self.config.special_filters)))
db_query = self.config.special_filters[token.value](
db_query, token.negated)
db_query = self.config.special_filters[sp_token.value](
db_query, None, sp_token.negated)
if use_sort:
for token in search_query.sort_tokens:
if token.name not in self.config.sort_columns:
for sort_token in search_query.sort_tokens:
if sort_token.name not in self.config.sort_columns:
raise errors.SearchError(
'Unknown sort token: %r. '
'Available sort tokens: %r.' % (
token.name,
sort_token.name,
_format_dict_keys(self.config.sort_columns)))
column, default_order = self.config.sort_columns[token.name]
order = _get_order(token.order, default_order)
if order == token.SORT_ASC:
column, default_order = (
self.config.sort_columns[sort_token.name])
order = _get_order(sort_token.order, default_order)
if order == sort_token.SORT_ASC:
db_query = db_query.order_by(column.asc())
elif order == token.SORT_DESC:
elif order == sort_token.SORT_DESC:
db_query = db_query.order_by(column.desc())
db_query = self.config.finalize_query(db_query)

View file

@ -1,9 +1,12 @@
import re
from typing import List
from szurubooru import errors
from szurubooru.search import criteria, tokens
from szurubooru.search.query import SearchQuery
def _create_criterion(original_value, value):
def _create_criterion(
original_value: str, value: str) -> criteria.BaseCriterion:
if ',' in value:
return criteria.ArrayCriterion(
original_value, value.split(','))
@ -15,12 +18,12 @@ def _create_criterion(original_value, value):
return criteria.PlainCriterion(original_value, value)
def _parse_anonymous(value, negated):
def _parse_anonymous(value: str, negated: bool) -> tokens.AnonymousToken:
criterion = _create_criterion(value, value)
return tokens.AnonymousToken(criterion, negated)
def _parse_named(key, value, negated):
def _parse_named(key: str, value: str, negated: bool) -> tokens.NamedToken:
original_value = value
if key.endswith('-min'):
key = key[:-4]
@ -32,11 +35,11 @@ def _parse_named(key, value, negated):
return tokens.NamedToken(key, criterion, negated)
def _parse_special(value, negated):
def _parse_special(value: str, negated: bool) -> tokens.SpecialToken:
return tokens.SpecialToken(value, negated)
def _parse_sort(value, negated):
def _parse_sort(value: str, negated: bool) -> tokens.SortToken:
if value.count(',') == 0:
order_str = None
elif value.count(',') == 1:
@ -67,23 +70,8 @@ def _parse_sort(value, negated):
return tokens.SortToken(value, order)
class SearchQuery:
def __init__(self):
self.anonymous_tokens = []
self.named_tokens = []
self.special_tokens = []
self.sort_tokens = []
def __hash__(self):
return hash((
tuple(self.anonymous_tokens),
tuple(self.named_tokens),
tuple(self.special_tokens),
tuple(self.sort_tokens)))
class Parser:
def parse(self, query_text):
def parse(self, query_text: str) -> SearchQuery:
query = SearchQuery()
for chunk in re.split(r'\s+', (query_text or '').lower()):
if not chunk:

View file

@ -0,0 +1,16 @@
from szurubooru.search import tokens
class SearchQuery:
def __init__(self) -> None:
self.anonymous_tokens = [] # type: List[tokens.AnonymousToken]
self.named_tokens = [] # type: List[tokens.NamedToken]
self.special_tokens = [] # type: List[tokens.SpecialToken]
self.sort_tokens = [] # type: List[tokens.SortToken]
def __hash__(self) -> int:
return hash((
tuple(self.anonymous_tokens),
tuple(self.named_tokens),
tuple(self.special_tokens),
tuple(self.sort_tokens)))

View file

@ -1,39 +1,44 @@
from szurubooru.search.criteria import BaseCriterion
class AnonymousToken:
def __init__(self, criterion, negated):
def __init__(self, criterion: BaseCriterion, negated: bool) -> None:
self.criterion = criterion
self.negated = negated
def __hash__(self):
def __hash__(self) -> int:
return hash((self.criterion, self.negated))
class NamedToken(AnonymousToken):
def __init__(self, name, criterion, negated):
def __init__(
self, name: str, criterion: BaseCriterion, negated: bool) -> None:
super().__init__(criterion, negated)
self.name = name
def __hash__(self):
def __hash__(self) -> int:
return hash((self.name, self.criterion, self.negated))
class SortToken:
SORT_DESC = 'desc'
SORT_ASC = 'asc'
SORT_NONE = ''
SORT_DEFAULT = 'default'
SORT_NEGATED_DEFAULT = 'negated default'
def __init__(self, name, order):
def __init__(self, name: str, order: str) -> None:
self.name = name
self.order = order
def __hash__(self):
def __hash__(self) -> int:
return hash((self.name, self.order))
class SpecialToken:
def __init__(self, value, negated):
def __init__(self, value: str, negated: bool) -> None:
self.value = value
self.negated = negated
def __hash__(self):
def __hash__(self) -> int:
return hash((self.value, self.negated))

View file

@ -0,0 +1,6 @@
from typing import Any, Callable
SaColumn = Any
SaQuery = Any
SaQueryFactory = Callable[[], SaQuery]

View file

@ -1,19 +1,20 @@
from datetime import datetime
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import comments, posts
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'comments:create': db.User.RANK_REGULAR}})
config_injector(
{'privileges': {'comments:create': model.User.RANK_REGULAR}})
def test_creating_comment(
user_factory, post_factory, context_factory, fake_datetime):
post = post_factory()
user = user_factory(rank=db.User.RANK_REGULAR)
user = user_factory(rank=model.User.RANK_REGULAR)
db.session.add_all([post, user])
db.session.flush()
with patch('szurubooru.func.comments.serialize_comment'), \
@ -24,7 +25,7 @@ def test_creating_comment(
params={'text': 'input', 'postId': post.post_id},
user=user))
assert result == 'serialized comment'
comment = db.session.query(db.Comment).one()
comment = db.session.query(model.Comment).one()
assert comment.text == 'input'
assert comment.creation_time == datetime(1997, 1, 1)
assert comment.last_edit_time is None
@ -41,7 +42,7 @@ def test_creating_comment(
def test_trying_to_pass_invalid_params(
user_factory, post_factory, context_factory, params):
post = post_factory()
user = user_factory(rank=db.User.RANK_REGULAR)
user = user_factory(rank=model.User.RANK_REGULAR)
db.session.add_all([post, user])
db.session.flush()
real_params = {'text': 'input', 'postId': post.post_id}
@ -63,11 +64,11 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
api.comment_api.create_comment(
context_factory(
params={},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_comment_non_existing(user_factory, context_factory):
user = user_factory(rank=db.User.RANK_REGULAR)
user = user_factory(rank=model.User.RANK_REGULAR)
db.session.add_all([user])
db.session.flush()
with pytest.raises(posts.PostNotFoundError):
@ -81,4 +82,4 @@ def test_trying_to_create_without_privileges(user_factory, context_factory):
api.comment_api.create_comment(
context_factory(
params={},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
user=user_factory(rank=model.User.RANK_ANONYMOUS)))

View file

@ -1,5 +1,5 @@
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import comments
@ -7,8 +7,8 @@ from szurubooru.func import comments
def inject_config(config_injector):
config_injector({
'privileges': {
'comments:delete:own': db.User.RANK_REGULAR,
'comments:delete:any': db.User.RANK_MODERATOR,
'comments:delete:own': model.User.RANK_REGULAR,
'comments:delete:any': model.User.RANK_MODERATOR,
},
})
@ -22,26 +22,26 @@ def test_deleting_own_comment(user_factory, comment_factory, context_factory):
context_factory(params={'version': 1}, user=user),
{'comment_id': comment.comment_id})
assert result == {}
assert db.session.query(db.Comment).count() == 0
assert db.session.query(model.Comment).count() == 0
def test_deleting_someones_else_comment(
user_factory, comment_factory, context_factory):
user1 = user_factory(rank=db.User.RANK_REGULAR)
user2 = user_factory(rank=db.User.RANK_MODERATOR)
user1 = user_factory(rank=model.User.RANK_REGULAR)
user2 = user_factory(rank=model.User.RANK_MODERATOR)
comment = comment_factory(user=user1)
db.session.add(comment)
db.session.commit()
api.comment_api.delete_comment(
context_factory(params={'version': 1}, user=user2),
{'comment_id': comment.comment_id})
assert db.session.query(db.Comment).count() == 0
assert db.session.query(model.Comment).count() == 0
def test_trying_to_delete_someones_else_comment_without_privileges(
user_factory, comment_factory, context_factory):
user1 = user_factory(rank=db.User.RANK_REGULAR)
user2 = user_factory(rank=db.User.RANK_REGULAR)
user1 = user_factory(rank=model.User.RANK_REGULAR)
user2 = user_factory(rank=model.User.RANK_REGULAR)
comment = comment_factory(user=user1)
db.session.add(comment)
db.session.commit()
@ -49,7 +49,7 @@ def test_trying_to_delete_someones_else_comment_without_privileges(
api.comment_api.delete_comment(
context_factory(params={'version': 1}, user=user2),
{'comment_id': comment.comment_id})
assert db.session.query(db.Comment).count() == 1
assert db.session.query(model.Comment).count() == 1
def test_trying_to_delete_non_existing(user_factory, context_factory):
@ -57,5 +57,5 @@ def test_trying_to_delete_non_existing(user_factory, context_factory):
api.comment_api.delete_comment(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'comment_id': 1})

View file

@ -1,17 +1,18 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import comments
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'comments:score': db.User.RANK_REGULAR}})
config_injector(
{'privileges': {'comments:score': model.User.RANK_REGULAR}})
def test_simple_rating(
user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR)
user = user_factory(rank=model.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()
@ -22,14 +23,14 @@ def test_simple_rating(
context_factory(params={'score': 1}, user=user),
{'comment_id': comment.comment_id})
assert result == 'serialized comment'
assert db.session.query(db.CommentScore).count() == 1
assert db.session.query(model.CommentScore).count() == 1
assert comment is not None
assert comment.score == 1
def test_updating_rating(
user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR)
user = user_factory(rank=model.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()
@ -42,14 +43,14 @@ def test_updating_rating(
api.comment_api.set_comment_score(
context_factory(params={'score': -1}, user=user),
{'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 1
comment = db.session.query(model.Comment).one()
assert db.session.query(model.CommentScore).count() == 1
assert comment.score == -1
def test_updating_rating_to_zero(
user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR)
user = user_factory(rank=model.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()
@ -62,14 +63,14 @@ def test_updating_rating_to_zero(
api.comment_api.set_comment_score(
context_factory(params={'score': 0}, user=user),
{'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 0
comment = db.session.query(model.Comment).one()
assert db.session.query(model.CommentScore).count() == 0
assert comment.score == 0
def test_deleting_rating(
user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR)
user = user_factory(rank=model.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()
@ -82,15 +83,15 @@ def test_deleting_rating(
api.comment_api.delete_comment_score(
context_factory(user=user),
{'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 0
comment = db.session.query(model.Comment).one()
assert db.session.query(model.CommentScore).count() == 0
assert comment.score == 0
def test_ratings_from_multiple_users(
user_factory, comment_factory, context_factory, fake_datetime):
user1 = user_factory(rank=db.User.RANK_REGULAR)
user2 = user_factory(rank=db.User.RANK_REGULAR)
user1 = user_factory(rank=model.User.RANK_REGULAR)
user2 = user_factory(rank=model.User.RANK_REGULAR)
comment = comment_factory()
db.session.add_all([user1, user2, comment])
db.session.commit()
@ -103,8 +104,8 @@ def test_ratings_from_multiple_users(
api.comment_api.set_comment_score(
context_factory(params={'score': -1}, user=user2),
{'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 2
comment = db.session.query(model.Comment).one()
assert db.session.query(model.CommentScore).count() == 2
assert comment.score == 0
@ -125,7 +126,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
api.comment_api.set_comment_score(
context_factory(
params={'score': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'comment_id': 5})
@ -138,5 +139,5 @@ def test_trying_to_rate_without_privileges(
api.comment_api.set_comment_score(
context_factory(
params={'score': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'comment_id': comment.comment_id})

View file

@ -1,6 +1,6 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import comments
@ -8,8 +8,8 @@ from szurubooru.func import comments
def inject_config(config_injector):
config_injector({
'privileges': {
'comments:list': db.User.RANK_REGULAR,
'comments:view': db.User.RANK_REGULAR,
'comments:list': model.User.RANK_REGULAR,
'comments:view': model.User.RANK_REGULAR,
},
})
@ -24,7 +24,7 @@ def test_retrieving_multiple(user_factory, comment_factory, context_factory):
result = api.comment_api.get_comments(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
assert result == {
'query': '',
'page': 1,
@ -40,7 +40,7 @@ def test_trying_to_retrieve_multiple_without_privileges(
api.comment_api.get_comments(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
def test_retrieving_single(user_factory, comment_factory, context_factory):
@ -51,7 +51,7 @@ def test_retrieving_single(user_factory, comment_factory, context_factory):
comments.serialize_comment.return_value = 'serialized comment'
result = api.comment_api.get_comment(
context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'comment_id': comment.comment_id})
assert result == 'serialized comment'
@ -60,7 +60,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(comments.CommentNotFoundError):
api.comment_api.get_comment(
context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'comment_id': 5})
@ -68,5 +68,5 @@ def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError):
api.comment_api.get_comment(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'comment_id': 5})

View file

@ -1,7 +1,7 @@
from datetime import datetime
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import comments
@ -9,15 +9,15 @@ from szurubooru.func import comments
def inject_config(config_injector):
config_injector({
'privileges': {
'comments:edit:own': db.User.RANK_REGULAR,
'comments:edit:any': db.User.RANK_MODERATOR,
'comments:edit:own': model.User.RANK_REGULAR,
'comments:edit:any': model.User.RANK_MODERATOR,
},
})
def test_simple_updating(
user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR)
user = user_factory(rank=model.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()
@ -73,14 +73,14 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
api.comment_api.update_comment(
context_factory(
params={'text': 'new text'},
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'comment_id': 5})
def test_trying_to_update_someones_comment_without_privileges(
user_factory, comment_factory, context_factory):
user = user_factory(rank=db.User.RANK_REGULAR)
user2 = user_factory(rank=db.User.RANK_REGULAR)
user = user_factory(rank=model.User.RANK_REGULAR)
user2 = user_factory(rank=model.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()
@ -93,8 +93,8 @@ def test_trying_to_update_someones_comment_without_privileges(
def test_updating_someones_comment_with_privileges(
user_factory, comment_factory, context_factory):
user = user_factory(rank=db.User.RANK_REGULAR)
user2 = user_factory(rank=db.User.RANK_MODERATOR)
user = user_factory(rank=model.User.RANK_REGULAR)
user2 = user_factory(rank=model.User.RANK_MODERATOR)
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()

View file

@ -1,6 +1,6 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import auth, mailer
@ -15,7 +15,7 @@ def inject_config(config_injector):
def test_reset_sending_email(context_factory, user_factory):
db.session.add(user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
name='u1', rank=model.User.RANK_REGULAR, email='user@example.com'))
db.session.flush()
for initiating_user in ['u1', 'user@example.com']:
with patch('szurubooru.func.mailer.send_mail'):
@ -39,7 +39,7 @@ def test_trying_to_reset_non_existing(context_factory):
def test_trying_to_reset_without_email(context_factory, user_factory):
db.session.add(
user_factory(name='u1', rank=db.User.RANK_REGULAR, email=None))
user_factory(name='u1', rank=model.User.RANK_REGULAR, email=None))
db.session.flush()
with pytest.raises(errors.ValidationError):
api.password_reset_api.start_password_reset(
@ -48,7 +48,7 @@ def test_trying_to_reset_without_email(context_factory, user_factory):
def test_confirming_with_good_token(context_factory, user_factory):
user = user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')
name='u1', rank=model.User.RANK_REGULAR, email='user@example.com')
old_hash = user.password_hash
db.session.add(user)
db.session.flush()
@ -68,7 +68,7 @@ def test_trying_to_confirm_non_existing(context_factory):
def test_trying_to_confirm_without_token(context_factory, user_factory):
db.session.add(user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
name='u1', rank=model.User.RANK_REGULAR, email='user@example.com'))
db.session.flush()
with pytest.raises(errors.ValidationError):
api.password_reset_api.finish_password_reset(
@ -77,7 +77,7 @@ def test_trying_to_confirm_without_token(context_factory, user_factory):
def test_trying_to_confirm_with_bad_token(context_factory, user_factory):
db.session.add(user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
name='u1', rank=model.User.RANK_REGULAR, email='user@example.com'))
db.session.flush()
with pytest.raises(errors.ValidationError):
api.password_reset_api.finish_password_reset(

View file

@ -1,6 +1,6 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import posts, tags, snapshots, net
@ -8,16 +8,16 @@ from szurubooru.func import posts, tags, snapshots, net
def inject_config(config_injector):
config_injector({
'privileges': {
'posts:create:anonymous': db.User.RANK_REGULAR,
'posts:create:identified': db.User.RANK_REGULAR,
'tags:create': db.User.RANK_REGULAR,
'posts:create:anonymous': model.User.RANK_REGULAR,
'posts:create:identified': model.User.RANK_REGULAR,
'tags:create': model.User.RANK_REGULAR,
},
})
def test_creating_minimal_posts(
context_factory, post_factory, user_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory()
db.session.add(post)
db.session.flush()
@ -53,20 +53,20 @@ def test_creating_minimal_posts(
posts.update_post_thumbnail.assert_called_once_with(
post, 'post-thumbnail')
posts.update_post_safety.assert_called_once_with(post, 'safe')
posts.update_post_source.assert_called_once_with(post, None)
posts.update_post_source.assert_called_once_with(post, '')
posts.update_post_relations.assert_called_once_with(post, [])
posts.update_post_notes.assert_called_once_with(post, [])
posts.update_post_flags.assert_called_once_with(post, [])
posts.update_post_thumbnail.assert_called_once_with(
post, 'post-thumbnail')
posts.serialize_post.assert_called_once_with(
post, auth_user, options=None)
post, auth_user, options=[])
snapshots.create.assert_called_once_with(post, auth_user)
tags.export_to_json.assert_called_once_with()
def test_creating_full_posts(context_factory, post_factory, user_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory()
db.session.add(post)
db.session.flush()
@ -109,14 +109,14 @@ def test_creating_full_posts(context_factory, post_factory, user_factory):
posts.update_post_flags.assert_called_once_with(
post, ['flag1', 'flag2'])
posts.serialize_post.assert_called_once_with(
post, auth_user, options=None)
post, auth_user, options=[])
snapshots.create.assert_called_once_with(post, auth_user)
tags.export_to_json.assert_called_once_with()
def test_anonymous_uploads(
config_injector, context_factory, post_factory, user_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory()
db.session.add(post)
db.session.flush()
@ -126,7 +126,7 @@ def test_anonymous_uploads(
patch('szurubooru.func.posts.create_post'), \
patch('szurubooru.func.posts.update_post_source'):
config_injector({
'privileges': {'posts:create:anonymous': db.User.RANK_REGULAR},
'privileges': {'posts:create:anonymous': model.User.RANK_REGULAR},
})
posts.create_post.return_value = [post, []]
api.post_api.create_post(
@ -146,7 +146,7 @@ def test_anonymous_uploads(
def test_creating_from_url_saves_source(
config_injector, context_factory, post_factory, user_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory()
db.session.add(post)
db.session.flush()
@ -157,7 +157,7 @@ def test_creating_from_url_saves_source(
patch('szurubooru.func.posts.create_post'), \
patch('szurubooru.func.posts.update_post_source'):
config_injector({
'privileges': {'posts:create:identified': db.User.RANK_REGULAR},
'privileges': {'posts:create:identified': model.User.RANK_REGULAR},
})
net.download.return_value = b'content'
posts.create_post.return_value = [post, []]
@ -177,7 +177,7 @@ def test_creating_from_url_saves_source(
def test_creating_from_url_with_source_specified(
config_injector, context_factory, post_factory, user_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory()
db.session.add(post)
db.session.flush()
@ -188,7 +188,7 @@ def test_creating_from_url_with_source_specified(
patch('szurubooru.func.posts.create_post'), \
patch('szurubooru.func.posts.update_post_source'):
config_injector({
'privileges': {'posts:create:identified': db.User.RANK_REGULAR},
'privileges': {'posts:create:identified': model.User.RANK_REGULAR},
})
net.download.return_value = b'content'
posts.create_post.return_value = [post, []]
@ -218,14 +218,14 @@ def test_trying_to_omit_mandatory_field(context_factory, user_factory, field):
context_factory(
params=params,
files={'content': '...'},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
@pytest.mark.parametrize(
'field', ['tags', 'relations', 'source', 'notes', 'flags'])
def test_omitting_optional_field(
field, context_factory, post_factory, user_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory()
db.session.add(post)
db.session.flush()
@ -268,10 +268,10 @@ def test_errors_not_spending_ids(
'post_height': 300,
},
'privileges': {
'posts:create:identified': db.User.RANK_REGULAR,
'posts:create:identified': model.User.RANK_REGULAR,
},
})
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
# successful request
with patch('szurubooru.func.posts.serialize_post'), \
@ -316,7 +316,7 @@ def test_trying_to_omit_content(context_factory, user_factory):
'safety': 'safe',
'tags': ['tag1', 'tag2'],
},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_create_post_without_privileges(
@ -324,16 +324,16 @@ def test_trying_to_create_post_without_privileges(
with pytest.raises(errors.AuthError):
api.post_api.create_post(context_factory(
params='whatever',
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
def test_trying_to_create_tags_without_privileges(
config_injector, context_factory, user_factory):
config_injector({
'privileges': {
'posts:create:anonymous': db.User.RANK_REGULAR,
'posts:create:identified': db.User.RANK_REGULAR,
'tags:create': db.User.RANK_ADMINISTRATOR,
'posts:create:anonymous': model.User.RANK_REGULAR,
'posts:create:identified': model.User.RANK_REGULAR,
'tags:create': model.User.RANK_ADMINISTRATOR,
},
})
with pytest.raises(errors.AuthError), \
@ -349,4 +349,4 @@ def test_trying_to_create_tags_without_privileges(
files={
'content': posts.EMPTY_PIXEL,
},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))

View file

@ -1,16 +1,16 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import posts, tags, snapshots
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'posts:delete': db.User.RANK_REGULAR}})
config_injector({'privileges': {'posts:delete': model.User.RANK_REGULAR}})
def test_deleting(user_factory, post_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory(id=1)
db.session.add(post)
db.session.flush()
@ -20,7 +20,7 @@ def test_deleting(user_factory, post_factory, context_factory):
context_factory(params={'version': 1}, user=auth_user),
{'post_id': 1})
assert result == {}
assert db.session.query(db.Post).count() == 0
assert db.session.query(model.Post).count() == 0
snapshots.delete.assert_called_once_with(post, auth_user)
tags.export_to_json.assert_called_once_with()
@ -28,7 +28,7 @@ def test_deleting(user_factory, post_factory, context_factory):
def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(posts.PostNotFoundError):
api.post_api.delete_post(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'post_id': 999})
@ -38,6 +38,6 @@ def test_trying_to_delete_without_privileges(
db.session.commit()
with pytest.raises(errors.AuthError):
api.post_api.delete_post(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'post_id': 1})
assert db.session.query(db.Post).count() == 1
assert db.session.query(model.Post).count() == 1

View file

@ -1,13 +1,14 @@
from datetime import datetime
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import posts
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'posts:favorite': db.User.RANK_REGULAR}})
config_injector(
{'privileges': {'posts:favorite': model.User.RANK_REGULAR}})
def test_adding_to_favorites(
@ -23,8 +24,8 @@ def test_adding_to_favorites(
context_factory(user=user_factory()),
{'post_id': post.post_id})
assert result == 'serialized post'
post = db.session.query(db.Post).one()
assert db.session.query(db.PostFavorite).count() == 1
post = db.session.query(model.Post).one()
assert db.session.query(model.PostFavorite).count() == 1
assert post is not None
assert post.favorite_count == 1
assert post.score == 1
@ -47,9 +48,9 @@ def test_removing_from_favorites(
api.post_api.delete_post_from_favorites(
context_factory(user=user),
{'post_id': post.post_id})
post = db.session.query(db.Post).one()
post = db.session.query(model.Post).one()
assert post.score == 1
assert db.session.query(db.PostFavorite).count() == 0
assert db.session.query(model.PostFavorite).count() == 0
assert post.favorite_count == 0
@ -68,8 +69,8 @@ def test_favoriting_twice(
api.post_api.add_post_to_favorites(
context_factory(user=user),
{'post_id': post.post_id})
post = db.session.query(db.Post).one()
assert db.session.query(db.PostFavorite).count() == 1
post = db.session.query(model.Post).one()
assert db.session.query(model.PostFavorite).count() == 1
assert post.favorite_count == 1
@ -92,8 +93,8 @@ def test_removing_twice(
api.post_api.delete_post_from_favorites(
context_factory(user=user),
{'post_id': post.post_id})
post = db.session.query(db.Post).one()
assert db.session.query(db.PostFavorite).count() == 0
post = db.session.query(model.Post).one()
assert db.session.query(model.PostFavorite).count() == 0
assert post.favorite_count == 0
@ -113,8 +114,8 @@ def test_favorites_from_multiple_users(
api.post_api.add_post_to_favorites(
context_factory(user=user2),
{'post_id': post.post_id})
post = db.session.query(db.Post).one()
assert db.session.query(db.PostFavorite).count() == 2
post = db.session.query(model.Post).one()
assert db.session.query(model.PostFavorite).count() == 2
assert post.favorite_count == 2
assert post.last_favorite_time == datetime(1997, 12, 2)
@ -133,5 +134,5 @@ def test_trying_to_rate_without_privileges(
db.session.commit()
with pytest.raises(errors.AuthError):
api.post_api.add_post_to_favorites(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'post_id': post.post_id})

View file

@ -1,6 +1,6 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import posts, snapshots
@ -8,14 +8,14 @@ from szurubooru.func import posts, snapshots
def inject_config(config_injector):
config_injector({
'privileges': {
'posts:feature': db.User.RANK_REGULAR,
'posts:view': db.User.RANK_REGULAR,
'posts:feature': model.User.RANK_REGULAR,
'posts:view': model.User.RANK_REGULAR,
},
})
def test_featuring(user_factory, post_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory(id=1)
db.session.add(post)
db.session.flush()
@ -31,7 +31,7 @@ def test_featuring(user_factory, post_factory, context_factory):
assert posts.get_post_by_id(1).is_featured
result = api.post_api.get_featured_post(
context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
assert result == 'serialized post'
snapshots.modify.assert_called_once_with(post, auth_user)
@ -40,7 +40,7 @@ def test_trying_to_omit_required_parameter(user_factory, context_factory):
with pytest.raises(errors.MissingRequiredParameterError):
api.post_api.set_featured_post(
context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_feature_the_same_post_twice(
@ -51,12 +51,12 @@ def test_trying_to_feature_the_same_post_twice(
api.post_api.set_featured_post(
context_factory(
params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
with pytest.raises(posts.PostAlreadyFeaturedError):
api.post_api.set_featured_post(
context_factory(
params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
def test_featuring_one_post_after_another(
@ -72,12 +72,12 @@ def test_featuring_one_post_after_another(
api.post_api.set_featured_post(
context_factory(
params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
with fake_datetime('1998'):
api.post_api.set_featured_post(
context_factory(
params={'id': 2},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
assert posts.try_get_featured_post() is not None
assert posts.try_get_featured_post().post_id == 2
assert not posts.get_post_by_id(1).is_featured
@ -89,7 +89,7 @@ def test_trying_to_feature_non_existing(user_factory, context_factory):
api.post_api.set_featured_post(
context_factory(
params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_feature_without_privileges(user_factory, context_factory):
@ -97,10 +97,10 @@ def test_trying_to_feature_without_privileges(user_factory, context_factory):
api.post_api.set_featured_post(
context_factory(
params={'id': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
def test_getting_featured_post_without_privileges_to_view(
user_factory, context_factory):
api.post_api.get_featured_post(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)))
context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)))

View file

@ -1,16 +1,16 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import posts, snapshots
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'posts:merge': db.User.RANK_REGULAR}})
config_injector({'privileges': {'posts:merge': model.User.RANK_REGULAR}})
def test_merging(user_factory, context_factory, post_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
source_post = post_factory()
target_post = post_factory()
db.session.add_all([source_post, target_post])
@ -25,6 +25,7 @@ def test_merging(user_factory, context_factory, post_factory):
'mergeToVersion': 1,
'remove': source_post.post_id,
'mergeTo': target_post.post_id,
'replaceContent': False,
},
user=auth_user))
posts.merge_posts.called_once_with(source_post, target_post)
@ -45,13 +46,14 @@ def test_trying_to_omit_mandatory_field(
'mergeToVersion': 1,
'remove': source_post.post_id,
'mergeTo': target_post.post_id,
'replaceContent': False,
}
del params[field]
with pytest.raises(errors.ValidationError):
api.post_api.merge_posts(
context_factory(
params=params,
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_merge_non_existing(
@ -63,12 +65,12 @@ def test_trying_to_merge_non_existing(
api.post_api.merge_posts(
context_factory(
params={'remove': post.post_id, 'mergeTo': 999},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
with pytest.raises(posts.PostNotFoundError):
api.post_api.merge_posts(
context_factory(
params={'remove': 999, 'mergeTo': post.post_id},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_merge_without_privileges(
@ -85,5 +87,6 @@ def test_trying_to_merge_without_privileges(
'mergeToVersion': 1,
'remove': source_post.post_id,
'mergeTo': target_post.post_id,
'replaceContent': False,
},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
user=user_factory(rank=model.User.RANK_ANONYMOUS)))

View file

@ -1,12 +1,12 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import posts
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'posts:score': db.User.RANK_REGULAR}})
config_injector({'privileges': {'posts:score': model.User.RANK_REGULAR}})
def test_simple_rating(
@ -22,8 +22,8 @@ def test_simple_rating(
params={'score': 1}, user=user_factory()),
{'post_id': post.post_id})
assert result == 'serialized post'
post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 1
post = db.session.query(model.Post).one()
assert db.session.query(model.PostScore).count() == 1
assert post is not None
assert post.score == 1
@ -43,8 +43,8 @@ def test_updating_rating(
api.post_api.set_post_score(
context_factory(params={'score': -1}, user=user),
{'post_id': post.post_id})
post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 1
post = db.session.query(model.Post).one()
assert db.session.query(model.PostScore).count() == 1
assert post.score == -1
@ -63,8 +63,8 @@ def test_updating_rating_to_zero(
api.post_api.set_post_score(
context_factory(params={'score': 0}, user=user),
{'post_id': post.post_id})
post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 0
post = db.session.query(model.Post).one()
assert db.session.query(model.PostScore).count() == 0
assert post.score == 0
@ -83,8 +83,8 @@ def test_deleting_rating(
api.post_api.delete_post_score(
context_factory(user=user),
{'post_id': post.post_id})
post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 0
post = db.session.query(model.Post).one()
assert db.session.query(model.PostScore).count() == 0
assert post.score == 0
@ -104,8 +104,8 @@ def test_ratings_from_multiple_users(
api.post_api.set_post_score(
context_factory(params={'score': -1}, user=user2),
{'post_id': post.post_id})
post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 2
post = db.session.query(model.Post).one()
assert db.session.query(model.PostScore).count() == 2
assert post.score == 0
@ -136,5 +136,5 @@ def test_trying_to_rate_without_privileges(
api.post_api.set_post_score(
context_factory(
params={'score': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'post_id': post.post_id})

View file

@ -1,7 +1,7 @@
from datetime import datetime
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import posts
@ -9,8 +9,8 @@ from szurubooru.func import posts
def inject_config(config_injector):
config_injector({
'privileges': {
'posts:list': db.User.RANK_REGULAR,
'posts:view': db.User.RANK_REGULAR,
'posts:list': model.User.RANK_REGULAR,
'posts:view': model.User.RANK_REGULAR,
},
})
@ -25,7 +25,7 @@ def test_retrieving_multiple(user_factory, post_factory, context_factory):
result = api.post_api.get_posts(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
assert result == {
'query': '',
'page': 1,
@ -36,10 +36,10 @@ def test_retrieving_multiple(user_factory, post_factory, context_factory):
def test_using_special_tokens(user_factory, post_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
post1 = post_factory(id=1)
post2 = post_factory(id=2)
post1.favorited_by = [db.PostFavorite(
post1.favorited_by = [model.PostFavorite(
user=auth_user, time=datetime.utcnow())]
db.session.add_all([post1, post2, auth_user])
db.session.flush()
@ -68,7 +68,7 @@ def test_trying_to_use_special_tokens_without_logging_in(
api.post_api.get_posts(
context_factory(
params={'query': 'special:fav', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
def test_trying_to_retrieve_multiple_without_privileges(
@ -77,7 +77,7 @@ def test_trying_to_retrieve_multiple_without_privileges(
api.post_api.get_posts(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
def test_retrieving_single(user_factory, post_factory, context_factory):
@ -86,7 +86,7 @@ def test_retrieving_single(user_factory, post_factory, context_factory):
with patch('szurubooru.func.posts.serialize_post'):
posts.serialize_post.return_value = 'serialized post'
result = api.post_api.get_post(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'post_id': 1})
assert result == 'serialized post'
@ -94,7 +94,7 @@ def test_retrieving_single(user_factory, post_factory, context_factory):
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(posts.PostNotFoundError):
api.post_api.get_post(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'post_id': 999})
@ -102,5 +102,5 @@ def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError):
api.post_api.get_post(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'post_id': 999})

View file

@ -1,7 +1,7 @@
from datetime import datetime
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import posts, tags, snapshots, net
@ -9,22 +9,22 @@ from szurubooru.func import posts, tags, snapshots, net
def inject_config(config_injector):
config_injector({
'privileges': {
'posts:edit:tags': db.User.RANK_REGULAR,
'posts:edit:content': db.User.RANK_REGULAR,
'posts:edit:safety': db.User.RANK_REGULAR,
'posts:edit:source': db.User.RANK_REGULAR,
'posts:edit:relations': db.User.RANK_REGULAR,
'posts:edit:notes': db.User.RANK_REGULAR,
'posts:edit:flags': db.User.RANK_REGULAR,
'posts:edit:thumbnail': db.User.RANK_REGULAR,
'tags:create': db.User.RANK_MODERATOR,
'posts:edit:tags': model.User.RANK_REGULAR,
'posts:edit:content': model.User.RANK_REGULAR,
'posts:edit:safety': model.User.RANK_REGULAR,
'posts:edit:source': model.User.RANK_REGULAR,
'posts:edit:relations': model.User.RANK_REGULAR,
'posts:edit:notes': model.User.RANK_REGULAR,
'posts:edit:flags': model.User.RANK_REGULAR,
'posts:edit:thumbnail': model.User.RANK_REGULAR,
'tags:create': model.User.RANK_MODERATOR,
},
})
def test_post_updating(
context_factory, post_factory, user_factory, fake_datetime):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory()
db.session.add(post)
db.session.flush()
@ -76,7 +76,7 @@ def test_post_updating(
posts.update_post_flags.assert_called_once_with(
post, ['flag1', 'flag2'])
posts.serialize_post.assert_called_once_with(
post, auth_user, options=None)
post, auth_user, options=[])
snapshots.modify.assert_called_once_with(post, auth_user)
tags.export_to_json.assert_called_once_with()
assert post.last_edit_time == datetime(1997, 1, 1)
@ -97,7 +97,7 @@ def test_uploading_from_url_saves_source(
api.post_api.update_post(
context_factory(
params={'contentUrl': 'example.com', 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'post_id': post.post_id})
net.download.assert_called_once_with('example.com')
posts.update_post_content.assert_called_once_with(post, b'content')
@ -122,7 +122,7 @@ def test_uploading_from_url_with_source_specified(
'contentUrl': 'example.com',
'source': 'example2.com',
'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'post_id': post.post_id})
net.download.assert_called_once_with('example.com')
posts.update_post_content.assert_called_once_with(post, b'content')
@ -134,7 +134,7 @@ def test_trying_to_update_non_existing(context_factory, user_factory):
api.post_api.update_post(
context_factory(
params='whatever',
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'post_id': 1})
@ -158,7 +158,7 @@ def test_trying_to_update_field_without_privileges(
context_factory(
params={**params, **{'version': 1}},
files=files,
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'post_id': post.post_id})
@ -173,5 +173,5 @@ def test_trying_to_create_tags_without_privileges(
api.post_api.update_post(
context_factory(
params={'tags': ['tag1', 'tag2'], 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'post_id': post.post_id})

View file

@ -1,10 +1,10 @@
from datetime import datetime
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
def snapshot_factory():
snapshot = db.Snapshot()
snapshot = model.Snapshot()
snapshot.creation_time = datetime(1999, 1, 1)
snapshot.resource_type = 'dummy'
snapshot.resource_pkey = 1
@ -17,7 +17,7 @@ def snapshot_factory():
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'privileges': {'snapshots:list': db.User.RANK_REGULAR},
'privileges': {'snapshots:list': model.User.RANK_REGULAR},
})
@ -29,7 +29,7 @@ def test_retrieving_multiple(user_factory, context_factory):
result = api.snapshot_api.get_snapshots(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
assert result['query'] == ''
assert result['page'] == 1
assert result['pageSize'] == 100
@ -43,4 +43,4 @@ def test_trying_to_retrieve_multiple_without_privileges(
api.snapshot_api.get_snapshots(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
user=user_factory(rank=model.User.RANK_ANONYMOUS)))

View file

@ -1,6 +1,6 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import tag_categories, tags, snapshots
@ -11,13 +11,13 @@ def _update_category_name(category, name):
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'privileges': {'tag_categories:create': db.User.RANK_REGULAR},
'privileges': {'tag_categories:create': model.User.RANK_REGULAR},
})
def test_creating_category(
tag_category_factory, user_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
category = tag_category_factory(name='meta')
db.session.add(category)
@ -49,7 +49,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
api.tag_category_api.create_tag_category(
context_factory(
params=params,
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_create_without_privileges(user_factory, context_factory):
@ -57,4 +57,4 @@ def test_trying_to_create_without_privileges(user_factory, context_factory):
api.tag_category_api.create_tag_category(
context_factory(
params={'name': 'meta', 'color': 'black'},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
user=user_factory(rank=model.User.RANK_ANONYMOUS)))

View file

@ -1,18 +1,18 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import tag_categories, tags, snapshots
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'privileges': {'tag_categories:delete': db.User.RANK_REGULAR},
'privileges': {'tag_categories:delete': model.User.RANK_REGULAR},
})
def test_deleting(user_factory, tag_category_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
category = tag_category_factory(name='category')
db.session.add(tag_category_factory(name='root'))
db.session.add(category)
@ -23,8 +23,8 @@ def test_deleting(user_factory, tag_category_factory, context_factory):
context_factory(params={'version': 1}, user=auth_user),
{'category_name': 'category'})
assert result == {}
assert db.session.query(db.TagCategory).count() == 1
assert db.session.query(db.TagCategory).one().name == 'root'
assert db.session.query(model.TagCategory).count() == 1
assert db.session.query(model.TagCategory).one().name == 'root'
snapshots.delete.assert_called_once_with(category, auth_user)
tags.export_to_json.assert_called_once_with()
@ -41,9 +41,9 @@ def test_trying_to_delete_used(
api.tag_category_api.delete_tag_category(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'category_name': 'category'})
assert db.session.query(db.TagCategory).count() == 1
assert db.session.query(model.TagCategory).count() == 1
def test_trying_to_delete_last(
@ -54,14 +54,14 @@ def test_trying_to_delete_last(
api.tag_category_api.delete_tag_category(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'category_name': 'root'})
def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(tag_categories.TagCategoryNotFoundError):
api.tag_category_api.delete_tag_category(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'category_name': 'bad'})
@ -73,6 +73,6 @@ def test_trying_to_delete_without_privileges(
api.tag_category_api.delete_tag_category(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'category_name': 'category'})
assert db.session.query(db.TagCategory).count() == 1
assert db.session.query(model.TagCategory).count() == 1

View file

@ -1,5 +1,5 @@
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import tag_categories
@ -7,8 +7,8 @@ from szurubooru.func import tag_categories
def inject_config(config_injector):
config_injector({
'privileges': {
'tag_categories:list': db.User.RANK_REGULAR,
'tag_categories:view': db.User.RANK_REGULAR,
'tag_categories:list': model.User.RANK_REGULAR,
'tag_categories:view': model.User.RANK_REGULAR,
},
})
@ -21,7 +21,7 @@ def test_retrieving_multiple(
])
db.session.flush()
result = api.tag_category_api.get_tag_categories(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)))
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)))
assert [cat['name'] for cat in result['results']] == ['c1', 'c2']
@ -30,7 +30,7 @@ def test_retrieving_single(
db.session.add(tag_category_factory(name='cat'))
db.session.flush()
result = api.tag_category_api.get_tag_category(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'category_name': 'cat'})
assert result == {
'name': 'cat',
@ -44,7 +44,7 @@ def test_retrieving_single(
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(tag_categories.TagCategoryNotFoundError):
api.tag_category_api.get_tag_category(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'category_name': '-'})
@ -52,5 +52,5 @@ def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError):
api.tag_category_api.get_tag_category(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'category_name': '-'})

View file

@ -1,6 +1,6 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import tag_categories, tags, snapshots
@ -12,15 +12,15 @@ def _update_category_name(category, name):
def inject_config(config_injector):
config_injector({
'privileges': {
'tag_categories:edit:name': db.User.RANK_REGULAR,
'tag_categories:edit:color': db.User.RANK_REGULAR,
'tag_categories:set_default': db.User.RANK_REGULAR,
'tag_categories:edit:name': model.User.RANK_REGULAR,
'tag_categories:edit:color': model.User.RANK_REGULAR,
'tag_categories:set_default': model.User.RANK_REGULAR,
},
})
def test_simple_updating(user_factory, tag_category_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
category = tag_category_factory(name='name', color='black')
db.session.add(category)
db.session.flush()
@ -61,7 +61,7 @@ def test_omitting_optional_field(
api.tag_category_api.update_tag_category(
context_factory(
params={**params, **{'version': 1}},
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'category_name': 'name'})
@ -70,7 +70,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
api.tag_category_api.update_tag_category(
context_factory(
params={'name': ['dummy']},
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'category_name': 'bad'})
@ -86,7 +86,7 @@ def test_trying_to_update_without_privileges(
api.tag_category_api.update_tag_category(
context_factory(
params={**params, **{'version': 1}},
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'category_name': 'dummy'})
@ -106,7 +106,7 @@ def test_set_as_default(user_factory, tag_category_factory, context_factory):
'color': 'white',
'version': 1,
},
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'category_name': 'name'})
assert result == 'serialized category'
tag_categories.set_default_category.assert_called_once_with(category)

View file

@ -1,16 +1,16 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import tags, snapshots
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'tags:create': db.User.RANK_REGULAR}})
config_injector({'privileges': {'tags:create': model.User.RANK_REGULAR}})
def test_creating_simple_tags(tag_factory, user_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
tag = tag_factory()
with patch('szurubooru.func.tags.create_tag'), \
patch('szurubooru.func.tags.get_or_create_tags_by_names'), \
@ -50,7 +50,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
api.tag_api.create_tag(
context_factory(
params=params,
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
@pytest.mark.parametrize('field', ['implications', 'suggestions'])
@ -70,7 +70,7 @@ def test_omitting_optional_field(
api.tag_api.create_tag(
context_factory(
params=params,
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_create_tag_without_privileges(
@ -84,4 +84,4 @@ def test_trying_to_create_tag_without_privileges(
'suggestions': ['tag'],
'implications': [],
},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
user=user_factory(rank=model.User.RANK_ANONYMOUS)))

View file

@ -1,16 +1,16 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import tags, snapshots
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'tags:delete': db.User.RANK_REGULAR}})
config_injector({'privileges': {'tags:delete': model.User.RANK_REGULAR}})
def test_deleting(user_factory, tag_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
tag = tag_factory(names=['tag'])
db.session.add(tag)
db.session.commit()
@ -20,7 +20,7 @@ def test_deleting(user_factory, tag_factory, context_factory):
context_factory(params={'version': 1}, user=auth_user),
{'tag_name': 'tag'})
assert result == {}
assert db.session.query(db.Tag).count() == 0
assert db.session.query(model.Tag).count() == 0
snapshots.delete.assert_called_once_with(tag, auth_user)
tags.export_to_json.assert_called_once_with()
@ -36,17 +36,17 @@ def test_deleting_used(
api.tag_api.delete_tag(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': 'tag'})
db.session.refresh(post)
assert db.session.query(db.Tag).count() == 0
assert db.session.query(model.Tag).count() == 0
assert post.tags == []
def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(tags.TagNotFoundError):
api.tag_api.delete_tag(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': 'bad'})
@ -58,6 +58,6 @@ def test_trying_to_delete_without_privileges(
api.tag_api.delete_tag(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'tag_name': 'tag'})
assert db.session.query(db.Tag).count() == 1
assert db.session.query(model.Tag).count() == 1

View file

@ -1,16 +1,16 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import tags, snapshots
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'tags:merge': db.User.RANK_REGULAR}})
config_injector({'privileges': {'tags:merge': model.User.RANK_REGULAR}})
def test_merging(user_factory, tag_factory, context_factory, post_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
source_tag = tag_factory(names=['source'])
target_tag = tag_factory(names=['target'])
db.session.add_all([source_tag, target_tag])
@ -62,7 +62,7 @@ def test_trying_to_omit_mandatory_field(
api.tag_api.merge_tags(
context_factory(
params=params,
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_merge_non_existing(
@ -73,12 +73,12 @@ def test_trying_to_merge_non_existing(
api.tag_api.merge_tags(
context_factory(
params={'remove': 'good', 'mergeTo': 'bad'},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
with pytest.raises(tags.TagNotFoundError):
api.tag_api.merge_tags(
context_factory(
params={'remove': 'bad', 'mergeTo': 'good'},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_merge_without_privileges(
@ -97,4 +97,4 @@ def test_trying_to_merge_without_privileges(
'remove': 'source',
'mergeTo': 'target',
},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
user=user_factory(rank=model.User.RANK_ANONYMOUS)))

View file

@ -1,6 +1,6 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import tags
@ -8,8 +8,8 @@ from szurubooru.func import tags
def inject_config(config_injector):
config_injector({
'privileges': {
'tags:list': db.User.RANK_REGULAR,
'tags:view': db.User.RANK_REGULAR,
'tags:list': model.User.RANK_REGULAR,
'tags:view': model.User.RANK_REGULAR,
},
})
@ -24,7 +24,7 @@ def test_retrieving_multiple(user_factory, tag_factory, context_factory):
result = api.tag_api.get_tags(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
assert result == {
'query': '',
'page': 1,
@ -40,7 +40,7 @@ def test_trying_to_retrieve_multiple_without_privileges(
api.tag_api.get_tags(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
def test_retrieving_single(user_factory, tag_factory, context_factory):
@ -50,7 +50,7 @@ def test_retrieving_single(user_factory, tag_factory, context_factory):
tags.serialize_tag.return_value = 'serialized tag'
result = api.tag_api.get_tag(
context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': 'tag'})
assert result == 'serialized tag'
@ -59,7 +59,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(tags.TagNotFoundError):
api.tag_api.get_tag(
context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': '-'})
@ -68,5 +68,5 @@ def test_trying_to_retrieve_single_without_privileges(
with pytest.raises(errors.AuthError):
api.tag_api.get_tag(
context_factory(
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'tag_name': '-'})

View file

@ -1,12 +1,12 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import tags
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'tags:view': db.User.RANK_REGULAR}})
config_injector({'privileges': {'tags:view': model.User.RANK_REGULAR}})
def test_get_tag_siblings(user_factory, tag_factory, context_factory):
@ -21,7 +21,7 @@ def test_get_tag_siblings(user_factory, tag_factory, context_factory):
(tag_factory(names=['sib2']), 3),
]
result = api.tag_api.get_tag_siblings(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': 'tag'})
assert result == {
'results': [
@ -40,12 +40,12 @@ def test_get_tag_siblings(user_factory, tag_factory, context_factory):
def test_trying_to_retrieve_non_existing(user_factory, context_factory):
with pytest.raises(tags.TagNotFoundError):
api.tag_api.get_tag_siblings(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': '-'})
def test_trying_to_retrieve_without_privileges(user_factory, context_factory):
with pytest.raises(errors.AuthError):
api.tag_api.get_tag_siblings(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'tag_name': '-'})

View file

@ -1,6 +1,6 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import tags, snapshots
@ -8,18 +8,18 @@ from szurubooru.func import tags, snapshots
def inject_config(config_injector):
config_injector({
'privileges': {
'tags:create': db.User.RANK_REGULAR,
'tags:edit:names': db.User.RANK_REGULAR,
'tags:edit:category': db.User.RANK_REGULAR,
'tags:edit:description': db.User.RANK_REGULAR,
'tags:edit:suggestions': db.User.RANK_REGULAR,
'tags:edit:implications': db.User.RANK_REGULAR,
'tags:create': model.User.RANK_REGULAR,
'tags:edit:names': model.User.RANK_REGULAR,
'tags:edit:category': model.User.RANK_REGULAR,
'tags:edit:description': model.User.RANK_REGULAR,
'tags:edit:suggestions': model.User.RANK_REGULAR,
'tags:edit:implications': model.User.RANK_REGULAR,
},
})
def test_simple_updating(user_factory, tag_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
tag = tag_factory(names=['tag1', 'tag2'])
db.session.add(tag)
db.session.commit()
@ -56,8 +56,7 @@ def test_simple_updating(user_factory, tag_factory, context_factory):
tag, ['sug1', 'sug2'])
tags.update_tag_implications.assert_called_once_with(
tag, ['imp1', 'imp2'])
tags.serialize_tag.assert_called_once_with(
tag, options=None)
tags.serialize_tag.assert_called_once_with(tag, options=[])
snapshots.modify.assert_called_once_with(tag, auth_user)
tags.export_to_json.assert_called_once_with()
@ -90,7 +89,7 @@ def test_omitting_optional_field(
api.tag_api.update_tag(
context_factory(
params={**params, **{'version': 1}},
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': 'tag'})
@ -99,7 +98,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
api.tag_api.update_tag(
context_factory(
params={'names': ['dummy']},
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': 'tag1'})
@ -117,7 +116,7 @@ def test_trying_to_update_without_privileges(
api.tag_api.update_tag(
context_factory(
params={**params, **{'version': 1}},
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'tag_name': 'tag'})
@ -127,9 +126,9 @@ def test_trying_to_create_tags_without_privileges(
db.session.add(tag)
db.session.commit()
config_injector({'privileges': {
'tags:create': db.User.RANK_ADMINISTRATOR,
'tags:edit:suggestions': db.User.RANK_REGULAR,
'tags:edit:implications': db.User.RANK_REGULAR,
'tags:create': model.User.RANK_ADMINISTRATOR,
'tags:edit:suggestions': model.User.RANK_REGULAR,
'tags:edit:implications': model.User.RANK_REGULAR,
}})
with patch('szurubooru.func.tags.get_or_create_tags_by_names'):
tags.get_or_create_tags_by_names.return_value = ([], ['new-tag'])
@ -137,12 +136,12 @@ def test_trying_to_create_tags_without_privileges(
api.tag_api.update_tag(
context_factory(
params={'suggestions': ['tag1', 'tag2'], 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': 'tag'})
db.session.rollback()
with pytest.raises(errors.AuthError):
api.tag_api.update_tag(
context_factory(
params={'implications': ['tag1', 'tag2'], 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': 'tag'})

View file

@ -1,6 +1,6 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import users
@ -31,7 +31,7 @@ def test_creating_user(user_factory, context_factory, fake_datetime):
'avatarStyle': 'manual',
},
files={'avatar': b'...'},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
assert result == 'serialized user'
users.create_user.assert_called_once_with(
'chewie1', 'oks', 'asd@asd.asd')
@ -50,7 +50,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
'password': 'oks',
}
user = user_factory()
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
del params[field]
with patch('szurubooru.func.users.create_user'), \
pytest.raises(errors.MissingRequiredParameterError):
@ -70,7 +70,7 @@ def test_omitting_optional_field(user_factory, context_factory, field):
}
del params[field]
user = user_factory()
auth_user = user_factory(rank=db.User.RANK_MODERATOR)
auth_user = user_factory(rank=model.User.RANK_MODERATOR)
with patch('szurubooru.func.users.create_user'), \
patch('szurubooru.func.users.update_user_avatar'), \
patch('szurubooru.func.users.serialize_user'):
@ -84,4 +84,4 @@ def test_trying_to_create_user_without_privileges(
with pytest.raises(errors.AuthError):
api.user_api.create_user(context_factory(
params='whatever',
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
user=user_factory(rank=model.User.RANK_ANONYMOUS)))

View file

@ -1,5 +1,5 @@
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import users
@ -7,45 +7,45 @@ from szurubooru.func import users
def inject_config(config_injector):
config_injector({
'privileges': {
'users:delete:self': db.User.RANK_REGULAR,
'users:delete:any': db.User.RANK_MODERATOR,
'users:delete:self': model.User.RANK_REGULAR,
'users:delete:any': model.User.RANK_MODERATOR,
},
})
def test_deleting_oneself(user_factory, context_factory):
user = user_factory(name='u', rank=db.User.RANK_REGULAR)
user = user_factory(name='u', rank=model.User.RANK_REGULAR)
db.session.add(user)
db.session.commit()
result = api.user_api.delete_user(
context_factory(
params={'version': 1}, user=user), {'user_name': 'u'})
assert result == {}
assert db.session.query(db.User).count() == 0
assert db.session.query(model.User).count() == 0
def test_deleting_someone_else(user_factory, context_factory):
user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR)
user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR)
user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR)
user2 = user_factory(name='u2', rank=model.User.RANK_MODERATOR)
db.session.add_all([user1, user2])
db.session.commit()
api.user_api.delete_user(
context_factory(
params={'version': 1}, user=user2), {'user_name': 'u1'})
assert db.session.query(db.User).count() == 1
assert db.session.query(model.User).count() == 1
def test_trying_to_delete_someone_else_without_privileges(
user_factory, context_factory):
user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR)
user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR)
user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR)
user2 = user_factory(name='u2', rank=model.User.RANK_REGULAR)
db.session.add_all([user1, user2])
db.session.commit()
with pytest.raises(errors.AuthError):
api.user_api.delete_user(
context_factory(
params={'version': 1}, user=user2), {'user_name': 'u1'})
assert db.session.query(db.User).count() == 2
assert db.session.query(model.User).count() == 2
def test_trying_to_delete_non_existing(user_factory, context_factory):
@ -53,5 +53,5 @@ def test_trying_to_delete_non_existing(user_factory, context_factory):
api.user_api.delete_user(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
user=user_factory(rank=model.User.RANK_REGULAR)),
{'user_name': 'bad'})

View file

@ -1,6 +1,6 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import users
@ -8,16 +8,16 @@ from szurubooru.func import users
def inject_config(config_injector):
config_injector({
'privileges': {
'users:list': db.User.RANK_REGULAR,
'users:view': db.User.RANK_REGULAR,
'users:edit:any:email': db.User.RANK_MODERATOR,
'users:list': model.User.RANK_REGULAR,
'users:view': model.User.RANK_REGULAR,
'users:edit:any:email': model.User.RANK_MODERATOR,
},
})
def test_retrieving_multiple(user_factory, context_factory):
user1 = user_factory(name='u1', rank=db.User.RANK_MODERATOR)
user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR)
user1 = user_factory(name='u1', rank=model.User.RANK_MODERATOR)
user2 = user_factory(name='u2', rank=model.User.RANK_MODERATOR)
db.session.add_all([user1, user2])
db.session.flush()
with patch('szurubooru.func.users.serialize_user'):
@ -25,7 +25,7 @@ def test_retrieving_multiple(user_factory, context_factory):
result = api.user_api.get_users(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
user=user_factory(rank=model.User.RANK_REGULAR)))
assert result == {
'query': '',
'page': 1,
@ -41,12 +41,12 @@ def test_trying_to_retrieve_multiple_without_privileges(
api.user_api.get_users(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
def test_retrieving_single(user_factory, context_factory):
user = user_factory(name='u1', rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=db.User.RANK_REGULAR)
user = user_factory(name='u1', rank=model.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
db.session.add(user)
db.session.flush()
with patch('szurubooru.func.users.serialize_user'):
@ -57,7 +57,7 @@ def test_retrieving_single(user_factory, context_factory):
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=model.User.RANK_REGULAR)
with pytest.raises(users.UserNotFoundError):
api.user_api.get_user(
context_factory(user=auth_user), {'user_name': '-'})
@ -65,8 +65,8 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_ANONYMOUS)
db.session.add(user_factory(name='u1', rank=db.User.RANK_REGULAR))
auth_user = user_factory(rank=model.User.RANK_ANONYMOUS)
db.session.add(user_factory(name='u1', rank=model.User.RANK_REGULAR))
db.session.flush()
with pytest.raises(errors.AuthError):
api.user_api.get_user(

View file

@ -1,6 +1,6 @@
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors
from szurubooru import api, db, model, errors
from szurubooru.func import users
@ -8,23 +8,23 @@ from szurubooru.func import users
def inject_config(config_injector):
config_injector({
'privileges': {
'users:edit:self:name': db.User.RANK_REGULAR,
'users:edit:self:pass': db.User.RANK_REGULAR,
'users:edit:self:email': db.User.RANK_REGULAR,
'users:edit:self:rank': db.User.RANK_MODERATOR,
'users:edit:self:avatar': db.User.RANK_MODERATOR,
'users:edit:any:name': db.User.RANK_MODERATOR,
'users:edit:any:pass': db.User.RANK_MODERATOR,
'users:edit:any:email': db.User.RANK_MODERATOR,
'users:edit:any:rank': db.User.RANK_ADMINISTRATOR,
'users:edit:any:avatar': db.User.RANK_ADMINISTRATOR,
'users:edit:self:name': model.User.RANK_REGULAR,
'users:edit:self:pass': model.User.RANK_REGULAR,
'users:edit:self:email': model.User.RANK_REGULAR,
'users:edit:self:rank': model.User.RANK_MODERATOR,
'users:edit:self:avatar': model.User.RANK_MODERATOR,
'users:edit:any:name': model.User.RANK_MODERATOR,
'users:edit:any:pass': model.User.RANK_MODERATOR,
'users:edit:any:email': model.User.RANK_MODERATOR,
'users:edit:any:rank': model.User.RANK_ADMINISTRATOR,
'users:edit:any:avatar': model.User.RANK_ADMINISTRATOR,
},
})
def test_updating_user(context_factory, user_factory):
user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
auth_user = user_factory(rank=db.User.RANK_ADMINISTRATOR)
user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR)
auth_user = user_factory(rank=model.User.RANK_ADMINISTRATOR)
db.session.add(user)
db.session.flush()
@ -63,13 +63,13 @@ def test_updating_user(context_factory, user_factory):
users.update_user_avatar.assert_called_once_with(
user, 'manual', b'...')
users.serialize_user.assert_called_once_with(
user, auth_user, options=None)
user, auth_user, options=[])
@pytest.mark.parametrize(
'field', ['name', 'email', 'password', 'rank', 'avatarStyle'])
def test_omitting_optional_field(user_factory, context_factory, field):
user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR)
db.session.add(user)
db.session.flush()
params = {
@ -96,7 +96,7 @@ def test_omitting_optional_field(user_factory, context_factory, field):
def test_trying_to_update_non_existing(user_factory, context_factory):
user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR)
db.session.add(user)
db.session.flush()
with pytest.raises(users.UserNotFoundError):
@ -113,8 +113,8 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
])
def test_trying_to_update_field_without_privileges(
user_factory, context_factory, params):
user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR)
user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR)
user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR)
user2 = user_factory(name='u2', rank=model.User.RANK_REGULAR)
db.session.add_all([user1, user2])
db.session.flush()
with pytest.raises(errors.AuthError):

Some files were not shown because too many files have changed in this diff Show more