From ad842ee8a54c57463b8e28b52970f173ea1d64ea Mon Sep 17 00:00:00 2001 From: rr- Date: Sat, 4 Feb 2017 01:08:12 +0100 Subject: [PATCH] server: refactor + add type hinting - Added type hinting (for now, 3.5-compatible) - Split `db` namespace into `db` module and `model` namespace - Changed elastic search to be created lazily for each operation - Changed to class based approach in entity serialization to allow stronger typing - Removed `required` argument from `context.get_*` family of functions; now it's implied if `default` argument is omitted - Changed `unalias_dict` implementation to use less magic inputs --- server/migrate-v1 | 6 +- server/mypy.ini | 14 + server/szurubooru/api/comment_api.py | 75 +-- server/szurubooru/api/info_api.py | 24 +- server/szurubooru/api/password_reset_api.py | 18 +- server/szurubooru/api/post_api.py | 145 ++--- server/szurubooru/api/snapshot_api.py | 9 +- server/szurubooru/api/tag_api.py | 77 +-- server/szurubooru/api/tag_category_api.py | 44 +- server/szurubooru/api/upload_api.py | 10 +- server/szurubooru/api/user_api.py | 45 +- server/szurubooru/config.py | 5 +- server/szurubooru/db.py | 36 ++ server/szurubooru/db/__init__.py | 17 - server/szurubooru/db/session.py | 27 - server/szurubooru/db/util.py | 34 -- server/szurubooru/errors.py | 8 +- server/szurubooru/facade.py | 32 +- server/szurubooru/func/auth.py | 30 +- server/szurubooru/func/cache.py | 26 +- server/szurubooru/func/comments.py | 98 ++-- server/szurubooru/func/diff.py | 17 +- server/szurubooru/func/favorites.py | 24 +- server/szurubooru/func/file_uploads.py | 17 +- server/szurubooru/func/files.py | 19 +- server/szurubooru/func/image_hash.py | 106 ++-- server/szurubooru/func/images.py | 19 +- server/szurubooru/func/mailer.py | 2 +- server/szurubooru/func/mime.py | 13 +- server/szurubooru/func/net.py | 2 +- server/szurubooru/func/posts.py | 452 ++++++++++------ server/szurubooru/func/scores.py | 22 +- server/szurubooru/func/serialization.py | 27 + server/szurubooru/func/snapshots.py | 59 +- server/szurubooru/func/tag_categories.py | 109 ++-- server/szurubooru/func/tags.py | 239 +++++---- server/szurubooru/func/users.py | 194 ++++--- server/szurubooru/func/util.py | 74 ++- server/szurubooru/func/versions.py | 11 +- server/szurubooru/middleware/authenticator.py | 27 +- server/szurubooru/middleware/cache_purger.py | 3 +- .../szurubooru/middleware/request_logger.py | 6 +- server/szurubooru/migrations/env.py | 4 +- server/szurubooru/model/__init__.py | 15 + server/szurubooru/{db => model}/base.py | 0 server/szurubooru/{db => model}/comment.py | 15 +- server/szurubooru/{db => model}/post.py | 15 +- server/szurubooru/{db => model}/snapshot.py | 2 +- server/szurubooru/{db => model}/tag.py | 10 +- .../szurubooru/{db => model}/tag_category.py | 7 +- server/szurubooru/{db => model}/user.py | 55 +- server/szurubooru/model/util.py | 42 ++ server/szurubooru/rest/__init__.py | 2 +- server/szurubooru/rest/app.py | 19 +- server/szurubooru/rest/context.py | 187 ++++--- server/szurubooru/rest/errors.py | 18 +- server/szurubooru/rest/middleware.py | 12 +- server/szurubooru/rest/routes.py | 22 +- .../search/configs/base_search_config.py | 29 +- .../search/configs/comment_search_config.py | 71 +-- .../search/configs/post_search_config.py | 504 ++++++++++++------ .../search/configs/snapshot_search_config.py | 41 +- .../search/configs/tag_search_config.py | 177 +++--- .../search/configs/user_search_config.py | 62 ++- server/szurubooru/search/configs/util.py | 91 ++-- server/szurubooru/search/criteria.py | 32 +- server/szurubooru/search/executor.py | 114 ++-- server/szurubooru/search/parser.py | 30 +- server/szurubooru/search/query.py | 16 + server/szurubooru/search/tokens.py | 21 +- server/szurubooru/search/typing.py | 6 + .../tests/api/test_comment_creating.py | 17 +- .../tests/api/test_comment_deleting.py | 22 +- .../tests/api/test_comment_rating.py | 39 +- .../tests/api/test_comment_retrieving.py | 16 +- .../tests/api/test_comment_updating.py | 18 +- .../tests/api/test_password_reset.py | 12 +- .../tests/api/test_post_creating.py | 50 +- .../tests/api/test_post_deleting.py | 14 +- .../tests/api/test_post_favoriting.py | 27 +- .../tests/api/test_post_featuring.py | 26 +- .../szurubooru/tests/api/test_post_merging.py | 17 +- .../szurubooru/tests/api/test_post_rating.py | 26 +- .../tests/api/test_post_retrieving.py | 22 +- .../tests/api/test_post_updating.py | 34 +- .../tests/api/test_snapshot_retrieving.py | 10 +- .../tests/api/test_tag_category_creating.py | 10 +- .../tests/api/test_tag_category_deleting.py | 22 +- .../tests/api/test_tag_category_retrieving.py | 14 +- .../tests/api/test_tag_category_updating.py | 18 +- .../szurubooru/tests/api/test_tag_creating.py | 12 +- .../szurubooru/tests/api/test_tag_deleting.py | 18 +- .../szurubooru/tests/api/test_tag_merging.py | 14 +- .../tests/api/test_tag_retrieving.py | 16 +- .../tests/api/test_tag_siblings_retrieving.py | 10 +- .../szurubooru/tests/api/test_tag_updating.py | 35 +- .../tests/api/test_user_creating.py | 10 +- .../tests/api/test_user_deleting.py | 24 +- .../tests/api/test_user_retrieving.py | 26 +- .../tests/api/test_user_updating.py | 36 +- server/szurubooru/tests/conftest.py | 46 +- server/szurubooru/tests/func/test_comments.py | 4 - .../szurubooru/tests/func/test_image_hash.py | 8 +- server/szurubooru/tests/func/test_posts.py | 80 ++- .../szurubooru/tests/func/test_snapshots.py | 32 +- .../tests/func/test_tag_categories.py | 4 +- server/szurubooru/tests/func/test_tags.py | 6 +- server/szurubooru/tests/func/test_users.py | 34 +- .../tests/{db => model}/__init__.py | 0 .../tests/{db => model}/test_comment.py | 18 +- .../tests/{db => model}/test_post.py | 52 +- .../tests/{db => model}/test_tag.py | 28 +- .../tests/{db => model}/test_user.py | 72 +-- server/szurubooru/tests/rest/test_context.py | 36 +- .../search/configs/test_post_search_config.py | 28 +- server/test | 1 + 116 files changed, 2868 insertions(+), 2037 deletions(-) create mode 100644 server/mypy.ini create mode 100644 server/szurubooru/db.py delete mode 100644 server/szurubooru/db/__init__.py delete mode 100644 server/szurubooru/db/session.py delete mode 100644 server/szurubooru/db/util.py create mode 100644 server/szurubooru/func/serialization.py create mode 100644 server/szurubooru/model/__init__.py rename server/szurubooru/{db => model}/base.py (100%) rename server/szurubooru/{db => model}/comment.py (84%) rename server/szurubooru/{db => model}/post.py (95%) rename server/szurubooru/{db => model}/snapshot.py (96%) rename server/szurubooru/{db => model}/tag.py (93%) rename server/szurubooru/{db => model}/tag_category.py (84%) rename server/szurubooru/{db => model}/user.py (50%) create mode 100644 server/szurubooru/model/util.py create mode 100644 server/szurubooru/search/query.py create mode 100644 server/szurubooru/search/typing.py rename server/szurubooru/tests/{db => model}/__init__.py (100%) rename server/szurubooru/tests/{db => model}/test_comment.py (74%) rename server/szurubooru/tests/{db => model}/test_post.py (71%) rename server/szurubooru/tests/{db => model}/test_tag.py (80%) rename server/szurubooru/tests/{db => model}/test_user.py (66%) diff --git a/server/migrate-v1 b/server/migrate-v1 index 0fdf9e4f..d3ec0dda 100755 --- a/server/migrate-v1 +++ b/server/migrate-v1 @@ -8,7 +8,7 @@ import zlib import concurrent.futures import logging import coloredlogs -import sqlalchemy +import sqlalchemy as sa from szurubooru import config, db from szurubooru.func import files, images, posts, comments @@ -42,8 +42,8 @@ def get_v1_session(args): port=args.port, name=args.name) logger.info('Connecting to %r...', dsn) - engine = sqlalchemy.create_engine(dsn) - session_maker = sqlalchemy.orm.sessionmaker(bind=engine) + engine = sa.create_engine(dsn) + session_maker = sa.orm.sessionmaker(bind=engine) return session_maker() def parse_args(): diff --git a/server/mypy.ini b/server/mypy.ini new file mode 100644 index 00000000..a0300b7a --- /dev/null +++ b/server/mypy.ini @@ -0,0 +1,14 @@ +[mypy] +ignore_missing_imports = True +follow_imports = skip +disallow_untyped_calls = True +disallow_untyped_defs = True +check_untyped_defs = True +disallow_subclassing_any = False +warn_redundant_casts = True +warn_unused_ignores = True +strict_optional = True +strict_boolean = False + +[mypy-szurubooru.tests.*] +ignore_errors=True diff --git a/server/szurubooru/api/comment_api.py b/server/szurubooru/api/comment_api.py index 51058160..1cde5385 100644 --- a/server/szurubooru/api/comment_api.py +++ b/server/szurubooru/api/comment_api.py @@ -1,31 +1,44 @@ -import datetime -from szurubooru import search -from szurubooru.rest import routes -from szurubooru.func import auth, comments, posts, scores, util, versions +from typing import Dict +from datetime import datetime +from szurubooru import search, rest, model +from szurubooru.func import ( + auth, comments, posts, scores, versions, serialization) _search_executor = search.Executor(search.configs.CommentSearchConfig()) -def _serialize(ctx, comment, **kwargs): +def _get_comment(params: Dict[str, str]) -> model.Comment: + try: + comment_id = int(params['comment_id']) + except TypeError: + raise comments.InvalidCommentIdError( + 'Invalid comment ID: %r.' % params['comment_id']) + return comments.get_comment_by_id(comment_id) + + +def _serialize( + ctx: rest.Context, comment: model.Comment) -> rest.Response: return comments.serialize_comment( comment, ctx.user, - options=util.get_serialization_options(ctx), **kwargs) + options=serialization.get_serialization_options(ctx)) -@routes.get('/comments/?') -def get_comments(ctx, _params=None): +@rest.routes.get('/comments/?') +def get_comments( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'comments:list') return _search_executor.execute_and_serialize( ctx, lambda comment: _serialize(ctx, comment)) -@routes.post('/comments/?') -def create_comment(ctx, _params=None): +@rest.routes.post('/comments/?') +def create_comment( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'comments:create') - text = ctx.get_param_as_string('text', required=True) - post_id = ctx.get_param_as_int('postId', required=True) + text = ctx.get_param_as_string('text') + post_id = ctx.get_param_as_int('postId') post = posts.get_post_by_id(post_id) comment = comments.create_comment(ctx.user, post, text) ctx.session.add(comment) @@ -33,30 +46,30 @@ def create_comment(ctx, _params=None): return _serialize(ctx, comment) -@routes.get('/comment/(?P[^/]+)/?') -def get_comment(ctx, params): +@rest.routes.get('/comment/(?P[^/]+)/?') +def get_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'comments:view') - comment = comments.get_comment_by_id(params['comment_id']) + comment = _get_comment(params) return _serialize(ctx, comment) -@routes.put('/comment/(?P[^/]+)/?') -def update_comment(ctx, params): - comment = comments.get_comment_by_id(params['comment_id']) +@rest.routes.put('/comment/(?P[^/]+)/?') +def update_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: + comment = _get_comment(params) versions.verify_version(comment, ctx) versions.bump_version(comment) infix = 'own' if ctx.user.user_id == comment.user_id else 'any' - text = ctx.get_param_as_string('text', required=True) + text = ctx.get_param_as_string('text') auth.verify_privilege(ctx.user, 'comments:edit:%s' % infix) comments.update_comment_text(comment, text) - comment.last_edit_time = datetime.datetime.utcnow() + comment.last_edit_time = datetime.utcnow() ctx.session.commit() return _serialize(ctx, comment) -@routes.delete('/comment/(?P[^/]+)/?') -def delete_comment(ctx, params): - comment = comments.get_comment_by_id(params['comment_id']) +@rest.routes.delete('/comment/(?P[^/]+)/?') +def delete_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: + comment = _get_comment(params) versions.verify_version(comment, ctx) infix = 'own' if ctx.user.user_id == comment.user_id else 'any' auth.verify_privilege(ctx.user, 'comments:delete:%s' % infix) @@ -65,20 +78,22 @@ def delete_comment(ctx, params): return {} -@routes.put('/comment/(?P[^/]+)/score/?') -def set_comment_score(ctx, params): +@rest.routes.put('/comment/(?P[^/]+)/score/?') +def set_comment_score( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'comments:score') - score = ctx.get_param_as_int('score', required=True) - comment = comments.get_comment_by_id(params['comment_id']) + score = ctx.get_param_as_int('score') + comment = _get_comment(params) scores.set_score(comment, ctx.user, score) ctx.session.commit() return _serialize(ctx, comment) -@routes.delete('/comment/(?P[^/]+)/score/?') -def delete_comment_score(ctx, params): +@rest.routes.delete('/comment/(?P[^/]+)/score/?') +def delete_comment_score( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'comments:score') - comment = comments.get_comment_by_id(params['comment_id']) + comment = _get_comment(params) scores.delete_score(comment, ctx.user) ctx.session.commit() return _serialize(ctx, comment) diff --git a/server/szurubooru/api/info_api.py b/server/szurubooru/api/info_api.py index c0d2a955..e0fafedd 100644 --- a/server/szurubooru/api/info_api.py +++ b/server/szurubooru/api/info_api.py @@ -1,19 +1,20 @@ -import datetime import os -from szurubooru import config -from szurubooru.rest import routes +from typing import Optional, Dict +from datetime import datetime, timedelta +from szurubooru import config, rest from szurubooru.func import posts, users, util -_cache_time = None -_cache_result = None +_cache_time = None # type: Optional[datetime] +_cache_result = None # type: Optional[int] -def _get_disk_usage(): +def _get_disk_usage() -> int: global _cache_time, _cache_result # pylint: disable=global-statement - threshold = datetime.timedelta(hours=48) - now = datetime.datetime.utcnow() + threshold = timedelta(hours=48) + now = datetime.utcnow() if _cache_time and _cache_time > now - threshold: + assert _cache_result return _cache_result total_size = 0 for dir_path, _, file_names in os.walk(config.config['data_dir']): @@ -25,8 +26,9 @@ def _get_disk_usage(): return total_size -@routes.get('/info/?') -def get_info(ctx, _params=None): +@rest.routes.get('/info/?') +def get_info( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: post_feature = posts.try_get_current_post_feature() return { 'postCount': posts.get_post_count(), @@ -38,7 +40,7 @@ def get_info(ctx, _params=None): 'featuringUser': users.serialize_user(post_feature.user, ctx.user) if post_feature else None, - 'serverTime': datetime.datetime.utcnow(), + 'serverTime': datetime.utcnow(), 'config': { 'userNameRegex': config.config['user_name_regex'], 'passwordRegex': config.config['password_regex'], diff --git a/server/szurubooru/api/password_reset_api.py b/server/szurubooru/api/password_reset_api.py index 7e5864c9..f49080a9 100644 --- a/server/szurubooru/api/password_reset_api.py +++ b/server/szurubooru/api/password_reset_api.py @@ -1,5 +1,5 @@ -from szurubooru import config, errors -from szurubooru.rest import routes +from typing import Dict +from szurubooru import config, errors, rest from szurubooru.func import auth, mailer, users, versions @@ -10,9 +10,9 @@ MAIL_BODY = \ 'Otherwise, please ignore this email.' -@routes.get('/password-reset/(?P[^/]+)/?') -def start_password_reset(_ctx, params): - ''' Send a mail with secure token to the correlated user. ''' +@rest.routes.get('/password-reset/(?P[^/]+)/?') +def start_password_reset( + _ctx: rest.Context, params: Dict[str, str]) -> rest.Response: user_name = params['user_name'] user = users.get_user_by_name_or_email(user_name) if not user.email: @@ -30,13 +30,13 @@ def start_password_reset(_ctx, params): return {} -@routes.post('/password-reset/(?P[^/]+)/?') -def finish_password_reset(ctx, params): - ''' Verify token from mail, generate a new password and return it. ''' +@rest.routes.post('/password-reset/(?P[^/]+)/?') +def finish_password_reset( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: user_name = params['user_name'] user = users.get_user_by_name_or_email(user_name) good_token = auth.generate_authentication_token(user) - token = ctx.get_param_as_string('token', required=True) + token = ctx.get_param_as_string('token') if token != good_token: raise errors.ValidationError('Invalid password reset token.') new_password = users.reset_user_password(user) diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index cbf8f27e..6c76688f 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -1,44 +1,60 @@ -import datetime -from szurubooru import search, db, errors -from szurubooru.rest import routes +from typing import Optional, Dict +from datetime import datetime +from szurubooru import db, model, errors, rest, search from szurubooru.func import ( - auth, tags, posts, snapshots, favorites, scores, util, versions) + auth, tags, posts, snapshots, favorites, scores, serialization, versions) -_search_executor = search.Executor(search.configs.PostSearchConfig()) +_search_executor_config = search.configs.PostSearchConfig() +_search_executor = search.Executor(_search_executor_config) -def _serialize_post(ctx, post): +def _get_post_id(params: Dict[str, str]) -> int: + try: + return int(params['post_id']) + except TypeError: + raise posts.InvalidPostIdError( + 'Invalid post ID: %r.' % params['post_id']) + + +def _get_post(params: Dict[str, str]) -> model.Post: + return posts.get_post_by_id(_get_post_id(params)) + + +def _serialize_post( + ctx: rest.Context, post: Optional[model.Post]) -> rest.Response: return posts.serialize_post( post, ctx.user, - options=util.get_serialization_options(ctx)) + options=serialization.get_serialization_options(ctx)) -@routes.get('/posts/?') -def get_posts(ctx, _params=None): +@rest.routes.get('/posts/?') +def get_posts( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:list') - _search_executor.config.user = ctx.user + _search_executor_config.user = ctx.user return _search_executor.execute_and_serialize( ctx, lambda post: _serialize_post(ctx, post)) -@routes.post('/posts/?') -def create_post(ctx, _params=None): +@rest.routes.post('/posts/?') +def create_post( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: anonymous = ctx.get_param_as_bool('anonymous', default=False) if anonymous: auth.verify_privilege(ctx.user, 'posts:create:anonymous') else: auth.verify_privilege(ctx.user, 'posts:create:identified') - content = ctx.get_file('content', required=True) - tag_names = ctx.get_param_as_list('tags', required=False, default=[]) - safety = ctx.get_param_as_string('safety', required=True) - source = ctx.get_param_as_string('source', required=False, default=None) + content = ctx.get_file('content') + tag_names = ctx.get_param_as_list('tags', default=[]) + safety = ctx.get_param_as_string('safety') + source = ctx.get_param_as_string('source', default='') if ctx.has_param('contentUrl') and not source: - source = ctx.get_param_as_string('contentUrl') - relations = ctx.get_param_as_list('relations', required=False) or [] - notes = ctx.get_param_as_list('notes', required=False) or [] - flags = ctx.get_param_as_list('flags', required=False) or [] + source = ctx.get_param_as_string('contentUrl', default='') + relations = ctx.get_param_as_list('relations', default=[]) + notes = ctx.get_param_as_list('notes', default=[]) + flags = ctx.get_param_as_list('flags', default=[]) post, new_tags = posts.create_post( content, tag_names, None if anonymous else ctx.user) @@ -61,16 +77,16 @@ def create_post(ctx, _params=None): return _serialize_post(ctx, post) -@routes.get('/post/(?P[^/]+)/?') -def get_post(ctx, params): +@rest.routes.get('/post/(?P[^/]+)/?') +def get_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:view') - post = posts.get_post_by_id(params['post_id']) + post = _get_post(params) return _serialize_post(ctx, post) -@routes.put('/post/(?P[^/]+)/?') -def update_post(ctx, params): - post = posts.get_post_by_id(params['post_id']) +@rest.routes.put('/post/(?P[^/]+)/?') +def update_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: + post = _get_post(params) versions.verify_version(post, ctx) versions.bump_version(post) if ctx.has_file('content'): @@ -104,7 +120,7 @@ def update_post(ctx, params): if ctx.has_file('thumbnail'): auth.verify_privilege(ctx.user, 'posts:edit:thumbnail') posts.update_post_thumbnail(post, ctx.get_file('thumbnail')) - post.last_edit_time = datetime.datetime.utcnow() + post.last_edit_time = datetime.utcnow() ctx.session.flush() snapshots.modify(post, ctx.user) ctx.session.commit() @@ -112,10 +128,10 @@ def update_post(ctx, params): return _serialize_post(ctx, post) -@routes.delete('/post/(?P[^/]+)/?') -def delete_post(ctx, params): +@rest.routes.delete('/post/(?P[^/]+)/?') +def delete_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:delete') - post = posts.get_post_by_id(params['post_id']) + post = _get_post(params) versions.verify_version(post, ctx) snapshots.delete(post, ctx.user) posts.delete(post) @@ -124,13 +140,14 @@ def delete_post(ctx, params): return {} -@routes.post('/post-merge/?') -def merge_posts(ctx, _params=None): - source_post_id = ctx.get_param_as_string('remove', required=True) or '' - target_post_id = ctx.get_param_as_string('mergeTo', required=True) or '' - replace_content = ctx.get_param_as_bool('replaceContent') +@rest.routes.post('/post-merge/?') +def merge_posts( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: + source_post_id = ctx.get_param_as_int('remove') + target_post_id = ctx.get_param_as_int('mergeTo') source_post = posts.get_post_by_id(source_post_id) target_post = posts.get_post_by_id(target_post_id) + replace_content = ctx.get_param_as_bool('replaceContent') versions.verify_version(source_post, ctx, 'removeVersion') versions.verify_version(target_post, ctx, 'mergeToVersion') versions.bump_version(target_post) @@ -141,16 +158,18 @@ def merge_posts(ctx, _params=None): return _serialize_post(ctx, target_post) -@routes.get('/featured-post/?') -def get_featured_post(ctx, _params=None): +@rest.routes.get('/featured-post/?') +def get_featured_post( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: post = posts.try_get_featured_post() return _serialize_post(ctx, post) -@routes.post('/featured-post/?') -def set_featured_post(ctx, _params=None): +@rest.routes.post('/featured-post/?') +def set_featured_post( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:feature') - post_id = ctx.get_param_as_int('id', required=True) + post_id = ctx.get_param_as_int('id') post = posts.get_post_by_id(post_id) featured_post = posts.try_get_featured_post() if featured_post and featured_post.post_id == post.post_id: @@ -162,55 +181,61 @@ def set_featured_post(ctx, _params=None): return _serialize_post(ctx, post) -@routes.put('/post/(?P[^/]+)/score/?') -def set_post_score(ctx, params): +@rest.routes.put('/post/(?P[^/]+)/score/?') +def set_post_score(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:score') - post = posts.get_post_by_id(params['post_id']) - score = ctx.get_param_as_int('score', required=True) + post = _get_post(params) + score = ctx.get_param_as_int('score') scores.set_score(post, ctx.user, score) ctx.session.commit() return _serialize_post(ctx, post) -@routes.delete('/post/(?P[^/]+)/score/?') -def delete_post_score(ctx, params): +@rest.routes.delete('/post/(?P[^/]+)/score/?') +def delete_post_score( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:score') - post = posts.get_post_by_id(params['post_id']) + post = _get_post(params) scores.delete_score(post, ctx.user) ctx.session.commit() return _serialize_post(ctx, post) -@routes.post('/post/(?P[^/]+)/favorite/?') -def add_post_to_favorites(ctx, params): +@rest.routes.post('/post/(?P[^/]+)/favorite/?') +def add_post_to_favorites( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:favorite') - post = posts.get_post_by_id(params['post_id']) + post = _get_post(params) favorites.set_favorite(post, ctx.user) ctx.session.commit() return _serialize_post(ctx, post) -@routes.delete('/post/(?P[^/]+)/favorite/?') -def delete_post_from_favorites(ctx, params): +@rest.routes.delete('/post/(?P[^/]+)/favorite/?') +def delete_post_from_favorites( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:favorite') - post = posts.get_post_by_id(params['post_id']) + post = _get_post(params) favorites.unset_favorite(post, ctx.user) ctx.session.commit() return _serialize_post(ctx, post) -@routes.get('/post/(?P[^/]+)/around/?') -def get_posts_around(ctx, params): +@rest.routes.get('/post/(?P[^/]+)/around/?') +def get_posts_around( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:list') - _search_executor.config.user = ctx.user + _search_executor_config.user = ctx.user + post_id = _get_post_id(params) return _search_executor.get_around_and_serialize( - ctx, params['post_id'], lambda post: _serialize_post(ctx, post)) + ctx, post_id, lambda post: _serialize_post(ctx, post)) -@routes.post('/posts/reverse-search/?') -def get_posts_by_image(ctx, _params=None): +@rest.routes.post('/posts/reverse-search/?') +def get_posts_by_image( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'posts:reverse_search') - content = ctx.get_file('content', required=True) + content = ctx.get_file('content') try: lookalikes = posts.search_by_image(content) diff --git a/server/szurubooru/api/snapshot_api.py b/server/szurubooru/api/snapshot_api.py index 009d8c97..cdcee74a 100644 --- a/server/szurubooru/api/snapshot_api.py +++ b/server/szurubooru/api/snapshot_api.py @@ -1,13 +1,14 @@ -from szurubooru import search -from szurubooru.rest import routes +from typing import Dict +from szurubooru import search, rest from szurubooru.func import auth, snapshots _search_executor = search.Executor(search.configs.SnapshotSearchConfig()) -@routes.get('/snapshots/?') -def get_snapshots(ctx, _params=None): +@rest.routes.get('/snapshots/?') +def get_snapshots( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'snapshots:list') return _search_executor.execute_and_serialize( ctx, lambda snapshot: snapshots.serialize_snapshot(snapshot, ctx.user)) diff --git a/server/szurubooru/api/tag_api.py b/server/szurubooru/api/tag_api.py index a69673ad..7a379b3c 100644 --- a/server/szurubooru/api/tag_api.py +++ b/server/szurubooru/api/tag_api.py @@ -1,18 +1,22 @@ -import datetime -from szurubooru import db, search -from szurubooru.rest import routes -from szurubooru.func import auth, tags, snapshots, util, versions +from typing import Optional, List, Dict +from datetime import datetime +from szurubooru import db, model, search, rest +from szurubooru.func import auth, tags, snapshots, serialization, versions _search_executor = search.Executor(search.configs.TagSearchConfig()) -def _serialize(ctx, tag): +def _serialize(ctx: rest.Context, tag: model.Tag) -> rest.Response: return tags.serialize_tag( - tag, options=util.get_serialization_options(ctx)) + tag, options=serialization.get_serialization_options(ctx)) -def _create_if_needed(tag_names, user): +def _get_tag(params: Dict[str, str]) -> model.Tag: + return tags.get_tag_by_name(params['tag_name']) + + +def _create_if_needed(tag_names: List[str], user: model.User) -> None: if not tag_names: return _existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names) @@ -23,25 +27,22 @@ def _create_if_needed(tag_names, user): snapshots.create(tag, user) -@routes.get('/tags/?') -def get_tags(ctx, _params=None): +@rest.routes.get('/tags/?') +def get_tags(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'tags:list') return _search_executor.execute_and_serialize( ctx, lambda tag: _serialize(ctx, tag)) -@routes.post('/tags/?') -def create_tag(ctx, _params=None): +@rest.routes.post('/tags/?') +def create_tag(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'tags:create') - names = ctx.get_param_as_list('names', required=True) - category = ctx.get_param_as_string('category', required=True) - description = ctx.get_param_as_string( - 'description', required=False, default=None) - suggestions = ctx.get_param_as_list( - 'suggestions', required=False, default=[]) - implications = ctx.get_param_as_list( - 'implications', required=False, default=[]) + names = ctx.get_param_as_list('names') + category = ctx.get_param_as_string('category') + description = ctx.get_param_as_string('description', default='') + suggestions = ctx.get_param_as_list('suggestions', default=[]) + implications = ctx.get_param_as_list('implications', default=[]) _create_if_needed(suggestions, ctx.user) _create_if_needed(implications, ctx.user) @@ -56,16 +57,16 @@ def create_tag(ctx, _params=None): return _serialize(ctx, tag) -@routes.get('/tag/(?P.+)') -def get_tag(ctx, params): +@rest.routes.get('/tag/(?P.+)') +def get_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'tags:view') - tag = tags.get_tag_by_name(params['tag_name']) + tag = _get_tag(params) return _serialize(ctx, tag) -@routes.put('/tag/(?P.+)') -def update_tag(ctx, params): - tag = tags.get_tag_by_name(params['tag_name']) +@rest.routes.put('/tag/(?P.+)') +def update_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: + tag = _get_tag(params) versions.verify_version(tag, ctx) versions.bump_version(tag) if ctx.has_param('names'): @@ -78,7 +79,7 @@ def update_tag(ctx, params): if ctx.has_param('description'): auth.verify_privilege(ctx.user, 'tags:edit:description') tags.update_tag_description( - tag, ctx.get_param_as_string('description', default=None)) + tag, ctx.get_param_as_string('description')) if ctx.has_param('suggestions'): auth.verify_privilege(ctx.user, 'tags:edit:suggestions') suggestions = ctx.get_param_as_list('suggestions') @@ -89,7 +90,7 @@ def update_tag(ctx, params): implications = ctx.get_param_as_list('implications') _create_if_needed(implications, ctx.user) tags.update_tag_implications(tag, implications) - tag.last_edit_time = datetime.datetime.utcnow() + tag.last_edit_time = datetime.utcnow() ctx.session.flush() snapshots.modify(tag, ctx.user) ctx.session.commit() @@ -97,9 +98,9 @@ def update_tag(ctx, params): return _serialize(ctx, tag) -@routes.delete('/tag/(?P.+)') -def delete_tag(ctx, params): - tag = tags.get_tag_by_name(params['tag_name']) +@rest.routes.delete('/tag/(?P.+)') +def delete_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: + tag = _get_tag(params) versions.verify_version(tag, ctx) auth.verify_privilege(ctx.user, 'tags:delete') snapshots.delete(tag, ctx.user) @@ -109,10 +110,11 @@ def delete_tag(ctx, params): return {} -@routes.post('/tag-merge/?') -def merge_tags(ctx, _params=None): - source_tag_name = ctx.get_param_as_string('remove', required=True) or '' - target_tag_name = ctx.get_param_as_string('mergeTo', required=True) or '' +@rest.routes.post('/tag-merge/?') +def merge_tags( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: + source_tag_name = ctx.get_param_as_string('remove') + target_tag_name = ctx.get_param_as_string('mergeTo') source_tag = tags.get_tag_by_name(source_tag_name) target_tag = tags.get_tag_by_name(target_tag_name) versions.verify_version(source_tag, ctx, 'removeVersion') @@ -126,10 +128,11 @@ def merge_tags(ctx, _params=None): return _serialize(ctx, target_tag) -@routes.get('/tag-siblings/(?P.+)') -def get_tag_siblings(ctx, params): +@rest.routes.get('/tag-siblings/(?P.+)') +def get_tag_siblings( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'tags:view') - tag = tags.get_tag_by_name(params['tag_name']) + tag = _get_tag(params) result = tags.get_tag_siblings(tag) serialized_siblings = [] for sibling, occurrences in result: diff --git a/server/szurubooru/api/tag_category_api.py b/server/szurubooru/api/tag_category_api.py index 139da1d8..c7aaca89 100644 --- a/server/szurubooru/api/tag_category_api.py +++ b/server/szurubooru/api/tag_category_api.py @@ -1,15 +1,18 @@ -from szurubooru.rest import routes +from typing import Dict +from szurubooru import model, rest from szurubooru.func import ( - auth, tags, tag_categories, snapshots, util, versions) + auth, tags, tag_categories, snapshots, serialization, versions) -def _serialize(ctx, category): +def _serialize( + ctx: rest.Context, category: model.TagCategory) -> rest.Response: return tag_categories.serialize_category( - category, options=util.get_serialization_options(ctx)) + category, options=serialization.get_serialization_options(ctx)) -@routes.get('/tag-categories/?') -def get_tag_categories(ctx, _params=None): +@rest.routes.get('/tag-categories/?') +def get_tag_categories( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'tag_categories:list') categories = tag_categories.get_all_categories() return { @@ -17,11 +20,12 @@ def get_tag_categories(ctx, _params=None): } -@routes.post('/tag-categories/?') -def create_tag_category(ctx, _params=None): +@rest.routes.post('/tag-categories/?') +def create_tag_category( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'tag_categories:create') - name = ctx.get_param_as_string('name', required=True) - color = ctx.get_param_as_string('color', required=True) + name = ctx.get_param_as_string('name') + color = ctx.get_param_as_string('color') category = tag_categories.create_category(name, color) ctx.session.add(category) ctx.session.flush() @@ -31,15 +35,17 @@ def create_tag_category(ctx, _params=None): return _serialize(ctx, category) -@routes.get('/tag-category/(?P[^/]+)/?') -def get_tag_category(ctx, params): +@rest.routes.get('/tag-category/(?P[^/]+)/?') +def get_tag_category( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'tag_categories:view') category = tag_categories.get_category_by_name(params['category_name']) return _serialize(ctx, category) -@routes.put('/tag-category/(?P[^/]+)/?') -def update_tag_category(ctx, params): +@rest.routes.put('/tag-category/(?P[^/]+)/?') +def update_tag_category( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: category = tag_categories.get_category_by_name( params['category_name'], lock=True) versions.verify_version(category, ctx) @@ -59,8 +65,9 @@ def update_tag_category(ctx, params): return _serialize(ctx, category) -@routes.delete('/tag-category/(?P[^/]+)/?') -def delete_tag_category(ctx, params): +@rest.routes.delete('/tag-category/(?P[^/]+)/?') +def delete_tag_category( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: category = tag_categories.get_category_by_name( params['category_name'], lock=True) versions.verify_version(category, ctx) @@ -72,8 +79,9 @@ def delete_tag_category(ctx, params): return {} -@routes.put('/tag-category/(?P[^/]+)/default/?') -def set_tag_category_as_default(ctx, params): +@rest.routes.put('/tag-category/(?P[^/]+)/default/?') +def set_tag_category_as_default( + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: auth.verify_privilege(ctx.user, 'tag_categories:set_default') category = tag_categories.get_category_by_name( params['category_name'], lock=True) diff --git a/server/szurubooru/api/upload_api.py b/server/szurubooru/api/upload_api.py index eaf2880b..9200eaa0 100644 --- a/server/szurubooru/api/upload_api.py +++ b/server/szurubooru/api/upload_api.py @@ -1,10 +1,12 @@ -from szurubooru.rest import routes +from typing import Dict +from szurubooru import rest from szurubooru.func import auth, file_uploads -@routes.post('/uploads/?') -def create_temporary_file(ctx, _params=None): +@rest.routes.post('/uploads/?') +def create_temporary_file( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'uploads:create') - content = ctx.get_file('content', required=True, allow_tokens=False) + content = ctx.get_file('content', allow_tokens=False) token = file_uploads.save(content) return {'token': token} diff --git a/server/szurubooru/api/user_api.py b/server/szurubooru/api/user_api.py index 187e4686..910f2a42 100644 --- a/server/szurubooru/api/user_api.py +++ b/server/szurubooru/api/user_api.py @@ -1,56 +1,57 @@ -from szurubooru import search -from szurubooru.rest import routes -from szurubooru.func import auth, users, util, versions +from typing import Any, Dict +from szurubooru import model, search, rest +from szurubooru.func import auth, users, serialization, versions _search_executor = search.Executor(search.configs.UserSearchConfig()) -def _serialize(ctx, user, **kwargs): +def _serialize( + ctx: rest.Context, user: model.User, **kwargs: Any) -> rest.Response: return users.serialize_user( user, ctx.user, - options=util.get_serialization_options(ctx), + options=serialization.get_serialization_options(ctx), **kwargs) -@routes.get('/users/?') -def get_users(ctx, _params=None): +@rest.routes.get('/users/?') +def get_users(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'users:list') return _search_executor.execute_and_serialize( ctx, lambda user: _serialize(ctx, user)) -@routes.post('/users/?') -def create_user(ctx, _params=None): +@rest.routes.post('/users/?') +def create_user( + ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: auth.verify_privilege(ctx.user, 'users:create') - name = ctx.get_param_as_string('name', required=True) - password = ctx.get_param_as_string('password', required=True) - email = ctx.get_param_as_string('email', required=False, default='') + name = ctx.get_param_as_string('name') + password = ctx.get_param_as_string('password') + email = ctx.get_param_as_string('email', default='') user = users.create_user(name, password, email) if ctx.has_param('rank'): - users.update_user_rank( - user, ctx.get_param_as_string('rank'), ctx.user) + users.update_user_rank(user, ctx.get_param_as_string('rank'), ctx.user) if ctx.has_param('avatarStyle'): users.update_user_avatar( user, ctx.get_param_as_string('avatarStyle'), - ctx.get_file('avatar')) + ctx.get_file('avatar', default=b'')) ctx.session.add(user) ctx.session.commit() return _serialize(ctx, user, force_show_email=True) -@routes.get('/user/(?P[^/]+)/?') -def get_user(ctx, params): +@rest.routes.get('/user/(?P[^/]+)/?') +def get_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: user = users.get_user_by_name(params['user_name']) if ctx.user.user_id != user.user_id: auth.verify_privilege(ctx.user, 'users:view') return _serialize(ctx, user) -@routes.put('/user/(?P[^/]+)/?') -def update_user(ctx, params): +@rest.routes.put('/user/(?P[^/]+)/?') +def update_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: user = users.get_user_by_name(params['user_name']) versions.verify_version(user, ctx) versions.bump_version(user) @@ -74,13 +75,13 @@ def update_user(ctx, params): users.update_user_avatar( user, ctx.get_param_as_string('avatarStyle'), - ctx.get_file('avatar')) + ctx.get_file('avatar', default=b'')) ctx.session.commit() return _serialize(ctx, user) -@routes.delete('/user/(?P[^/]+)/?') -def delete_user(ctx, params): +@rest.routes.delete('/user/(?P[^/]+)/?') +def delete_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: user = users.get_user_by_name(params['user_name']) versions.verify_version(user, ctx) infix = 'self' if ctx.user.user_id == user.user_id else 'any' diff --git a/server/szurubooru/config.py b/server/szurubooru/config.py index e5693117..567b4574 100644 --- a/server/szurubooru/config.py +++ b/server/szurubooru/config.py @@ -1,8 +1,9 @@ +from typing import Dict import os import yaml -def merge(left, right): +def merge(left: Dict, right: Dict) -> Dict: for key in right: if key in left: if isinstance(left[key], dict) and isinstance(right[key], dict): @@ -14,7 +15,7 @@ def merge(left, right): return left -def read_config(): +def read_config() -> Dict: with open('../config.yaml.dist') as handle: ret = yaml.load(handle.read()) if os.path.exists('../config.yaml'): diff --git a/server/szurubooru/db.py b/server/szurubooru/db.py new file mode 100644 index 00000000..f90bfaf9 --- /dev/null +++ b/server/szurubooru/db.py @@ -0,0 +1,36 @@ +from typing import Any +import threading +import sqlalchemy as sa +import sqlalchemy.orm +from szurubooru import config + +# pylint: disable=invalid-name +_data = threading.local() +_engine = sa.create_engine(config.config['database']) # type: Any +sessionmaker = sa.orm.sessionmaker(bind=_engine, autoflush=False) # type: Any +session = sa.orm.scoped_session(sessionmaker) # type: Any + + +def get_session() -> Any: + global session + return session + + +def set_sesssion(new_session: Any) -> None: + global session + session = new_session + + +def reset_query_count() -> None: + _data.query_count = 0 + + +def get_query_count() -> int: + return _data.query_count + + +def _bump_query_count() -> None: + _data.query_count = getattr(_data, 'query_count', 0) + 1 + + +sa.event.listen(_engine, 'after_execute', lambda *args: _bump_query_count()) diff --git a/server/szurubooru/db/__init__.py b/server/szurubooru/db/__init__.py deleted file mode 100644 index 3eb18833..00000000 --- a/server/szurubooru/db/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from szurubooru.db.base import Base -from szurubooru.db.user import User -from szurubooru.db.tag_category import TagCategory -from szurubooru.db.tag import (Tag, TagName, TagSuggestion, TagImplication) -from szurubooru.db.post import ( - Post, - PostTag, - PostRelation, - PostFavorite, - PostScore, - PostNote, - PostFeature) -from szurubooru.db.comment import (Comment, CommentScore) -from szurubooru.db.snapshot import Snapshot -from szurubooru.db.session import ( - session, sessionmaker, reset_query_count, get_query_count) -import szurubooru.db.util diff --git a/server/szurubooru/db/session.py b/server/szurubooru/db/session.py deleted file mode 100644 index fd77b4c2..00000000 --- a/server/szurubooru/db/session.py +++ /dev/null @@ -1,27 +0,0 @@ -import threading -import sqlalchemy -from szurubooru import config - - -# pylint: disable=invalid-name -_engine = sqlalchemy.create_engine(config.config['database']) -sessionmaker = sqlalchemy.orm.sessionmaker(bind=_engine, autoflush=False) -session = sqlalchemy.orm.scoped_session(sessionmaker) - -_data = threading.local() - - -def reset_query_count(): - _data.query_count = 0 - - -def get_query_count(): - return _data.query_count - - -def _bump_query_count(): - _data.query_count = getattr(_data, 'query_count', 0) + 1 - - -sqlalchemy.event.listen( - _engine, 'after_execute', lambda *args: _bump_query_count()) diff --git a/server/szurubooru/db/util.py b/server/szurubooru/db/util.py deleted file mode 100644 index d6edf188..00000000 --- a/server/szurubooru/db/util.py +++ /dev/null @@ -1,34 +0,0 @@ -from sqlalchemy.inspection import inspect - - -def get_resource_info(entity): - serializers = { - 'tag': lambda tag: tag.first_name, - 'tag_category': lambda category: category.name, - 'comment': lambda comment: comment.comment_id, - 'post': lambda post: post.post_id, - } - - resource_type = entity.__table__.name - assert resource_type in serializers - - primary_key = inspect(entity).identity - assert primary_key is not None - assert len(primary_key) == 1 - - resource_name = serializers[resource_type](entity) - assert resource_name - - resource_pkey = primary_key[0] - assert resource_pkey - - return (resource_type, resource_pkey, resource_name) - - -def get_aux_entity(session, get_table_info, entity, user): - table, get_column = get_table_info(entity) - return session \ - .query(table) \ - .filter(get_column(table) == get_column(entity)) \ - .filter(table.user_id == user.user_id) \ - .one_or_none() diff --git a/server/szurubooru/errors.py b/server/szurubooru/errors.py index 4fbb67b6..b5f1cc3b 100644 --- a/server/szurubooru/errors.py +++ b/server/szurubooru/errors.py @@ -1,5 +1,11 @@ +from typing import Dict + + class BaseError(RuntimeError): - def __init__(self, message='Unknown error', extra_fields=None): + def __init__( + self, + message: str='Unknown error', + extra_fields: Dict[str, str]=None) -> None: super().__init__(message) self.extra_fields = extra_fields diff --git a/server/szurubooru/facade.py b/server/szurubooru/facade.py index 48957a1f..f39fcf92 100644 --- a/server/szurubooru/facade.py +++ b/server/szurubooru/facade.py @@ -2,7 +2,10 @@ import os import time import logging import threading +from typing import Callable, Any, Type + import coloredlogs +import sqlalchemy as sa import sqlalchemy.orm.exc from szurubooru import config, db, errors, rest from szurubooru.func import posts, file_uploads @@ -10,7 +13,10 @@ from szurubooru.func import posts, file_uploads from szurubooru import api, middleware -def _map_error(ex, target_class, title): +def _map_error( + ex: Exception, + target_class: Type[rest.errors.BaseHttpError], + title: str) -> rest.errors.BaseHttpError: return target_class( name=type(ex).__name__, title=title, @@ -18,38 +24,38 @@ def _map_error(ex, target_class, title): extra_fields=getattr(ex, 'extra_fields', {})) -def _on_auth_error(ex): +def _on_auth_error(ex: Exception) -> None: raise _map_error(ex, rest.errors.HttpForbidden, 'Authentication error') -def _on_validation_error(ex): +def _on_validation_error(ex: Exception) -> None: raise _map_error(ex, rest.errors.HttpBadRequest, 'Validation error') -def _on_search_error(ex): +def _on_search_error(ex: Exception) -> None: raise _map_error(ex, rest.errors.HttpBadRequest, 'Search error') -def _on_integrity_error(ex): +def _on_integrity_error(ex: Exception) -> None: raise _map_error(ex, rest.errors.HttpConflict, 'Integrity violation') -def _on_not_found_error(ex): +def _on_not_found_error(ex: Exception) -> None: raise _map_error(ex, rest.errors.HttpNotFound, 'Not found') -def _on_processing_error(ex): +def _on_processing_error(ex: Exception) -> None: raise _map_error(ex, rest.errors.HttpBadRequest, 'Processing error') -def _on_third_party_error(ex): +def _on_third_party_error(ex: Exception) -> None: raise _map_error( ex, rest.errors.HttpInternalServerError, 'Server configuration error') -def _on_stale_data_error(_ex): +def _on_stale_data_error(_ex: Exception) -> None: raise rest.errors.HttpConflict( name='IntegrityError', title='Integrity violation', @@ -58,7 +64,7 @@ def _on_stale_data_error(_ex): 'Please try again.')) -def validate_config(): +def validate_config() -> None: ''' Check whether config doesn't contain errors that might prove lethal at runtime. @@ -86,7 +92,7 @@ def validate_config(): raise errors.ConfigError('Database is not configured') -def purge_old_uploads(): +def purge_old_uploads() -> None: while True: try: file_uploads.purge_old_uploads() @@ -95,7 +101,7 @@ def purge_old_uploads(): time.sleep(60 * 5) -def create_app(): +def create_app() -> Callable[[Any, Any], Any]: ''' Create a WSGI compatible App object. ''' validate_config() coloredlogs.install(fmt='[%(asctime)-15s] %(name)s %(message)s') @@ -122,7 +128,7 @@ def create_app(): rest.errors.handle(errors.NotFoundError, _on_not_found_error) rest.errors.handle(errors.ProcessingError, _on_processing_error) rest.errors.handle(errors.ThirdPartyError, _on_third_party_error) - rest.errors.handle(sqlalchemy.orm.exc.StaleDataError, _on_stale_data_error) + rest.errors.handle(sa.orm.exc.StaleDataError, _on_stale_data_error) return rest.application diff --git a/server/szurubooru/func/auth.py b/server/szurubooru/func/auth.py index d71c8f9d..25c991c4 100644 --- a/server/szurubooru/func/auth.py +++ b/server/szurubooru/func/auth.py @@ -1,22 +1,22 @@ import hashlib import random from collections import OrderedDict -from szurubooru import config, db, errors +from szurubooru import config, model, errors from szurubooru.func import util RANK_MAP = OrderedDict([ - (db.User.RANK_ANONYMOUS, 'anonymous'), - (db.User.RANK_RESTRICTED, 'restricted'), - (db.User.RANK_REGULAR, 'regular'), - (db.User.RANK_POWER, 'power'), - (db.User.RANK_MODERATOR, 'moderator'), - (db.User.RANK_ADMINISTRATOR, 'administrator'), - (db.User.RANK_NOBODY, 'nobody'), + (model.User.RANK_ANONYMOUS, 'anonymous'), + (model.User.RANK_RESTRICTED, 'restricted'), + (model.User.RANK_REGULAR, 'regular'), + (model.User.RANK_POWER, 'power'), + (model.User.RANK_MODERATOR, 'moderator'), + (model.User.RANK_ADMINISTRATOR, 'administrator'), + (model.User.RANK_NOBODY, 'nobody'), ]) -def get_password_hash(salt, password): +def get_password_hash(salt: str, password: str) -> str: ''' Retrieve new-style password hash. ''' digest = hashlib.sha256() digest.update(config.config['secret'].encode('utf8')) @@ -25,7 +25,7 @@ def get_password_hash(salt, password): return digest.hexdigest() -def get_legacy_password_hash(salt, password): +def get_legacy_password_hash(salt: str, password: str) -> str: ''' Retrieve old-style password hash. ''' digest = hashlib.sha1() digest.update(b'1A2/$_4xVa') @@ -34,7 +34,7 @@ def get_legacy_password_hash(salt, password): return digest.hexdigest() -def create_password(): +def create_password() -> str: alphabet = { 'c': list('bcdfghijklmnpqrstvwxyz'), 'v': list('aeiou'), @@ -44,7 +44,7 @@ def create_password(): return ''.join(random.choice(alphabet[l]) for l in list(pattern)) -def is_valid_password(user, password): +def is_valid_password(user: model.User, password: str) -> bool: assert user salt, valid_hash = user.password_salt, user.password_hash possible_hashes = [ @@ -54,7 +54,7 @@ def is_valid_password(user, password): return valid_hash in possible_hashes -def has_privilege(user, privilege_name): +def has_privilege(user: model.User, privilege_name: str) -> bool: assert user all_ranks = list(RANK_MAP.keys()) assert privilege_name in config.config['privileges'] @@ -65,13 +65,13 @@ def has_privilege(user, privilege_name): return user.rank in good_ranks -def verify_privilege(user, privilege_name): +def verify_privilege(user: model.User, privilege_name: str) -> None: assert user if not has_privilege(user, privilege_name): raise errors.AuthError('Insufficient privileges to do this.') -def generate_authentication_token(user): +def generate_authentication_token(user: model.User) -> str: ''' Generate nonguessable challenge (e.g. links in password reminder). ''' assert user digest = hashlib.md5() diff --git a/server/szurubooru/func/cache.py b/server/szurubooru/func/cache.py index 4b775548..345835c2 100644 --- a/server/szurubooru/func/cache.py +++ b/server/szurubooru/func/cache.py @@ -1,21 +1,21 @@ +from typing import Any, List, Dict from datetime import datetime class LruCacheItem: - def __init__(self, key, value): + def __init__(self, key: object, value: Any) -> None: self.key = key self.value = value self.timestamp = datetime.utcnow() class LruCache: - def __init__(self, length, delta=None): + def __init__(self, length: int) -> None: self.length = length - self.delta = delta - self.hash = {} - self.item_list = [] + self.hash = {} # type: Dict[object, LruCacheItem] + self.item_list = [] # type: List[LruCacheItem] - def insert_item(self, item): + def insert_item(self, item: LruCacheItem) -> None: if item.key in self.hash: item_index = next( i @@ -31,11 +31,11 @@ class LruCache: self.hash[item.key] = item self.item_list.insert(0, item) - def remove_all(self): + def remove_all(self) -> None: self.hash = {} self.item_list = [] - def remove_item(self, item): + def remove_item(self, item: LruCacheItem) -> None: del self.hash[item.key] del self.item_list[self.item_list.index(item)] @@ -43,22 +43,22 @@ class LruCache: _CACHE = LruCache(length=100) -def purge(): +def purge() -> None: _CACHE.remove_all() -def has(key): +def has(key: object) -> bool: return key in _CACHE.hash -def get(key): +def get(key: object) -> Any: return _CACHE.hash[key].value -def remove(key): +def remove(key: object) -> None: if has(key): del _CACHE.hash[key] -def put(key, value): +def put(key: object, value: Any) -> None: _CACHE.insert_item(LruCacheItem(key, value)) diff --git a/server/szurubooru/func/comments.py b/server/szurubooru/func/comments.py index 6b7def85..fe15d8b6 100644 --- a/server/szurubooru/func/comments.py +++ b/server/szurubooru/func/comments.py @@ -1,6 +1,7 @@ -import datetime -from szurubooru import db, errors -from szurubooru.func import users, scores, util +from datetime import datetime +from typing import Any, Optional, List, Dict, Callable +from szurubooru import db, model, errors, rest +from szurubooru.func import users, scores, util, serialization class InvalidCommentIdError(errors.ValidationError): @@ -15,52 +16,87 @@ class EmptyCommentTextError(errors.ValidationError): pass -def serialize_comment(comment, auth_user, options=None): - return util.serialize_entity( - comment, - { - 'id': lambda: comment.comment_id, - 'user': - lambda: users.serialize_micro_user(comment.user, auth_user), - 'postId': lambda: comment.post.post_id, - 'version': lambda: comment.version, - 'text': lambda: comment.text, - 'creationTime': lambda: comment.creation_time, - 'lastEditTime': lambda: comment.last_edit_time, - 'score': lambda: comment.score, - 'ownScore': lambda: scores.get_score(comment, auth_user), - }, - options) +class CommentSerializer(serialization.BaseSerializer): + def __init__(self, comment: model.Comment, auth_user: model.User) -> None: + self.comment = comment + self.auth_user = auth_user + + def _serializers(self) -> Dict[str, Callable[[], Any]]: + return { + 'id': self.serialize_id, + 'user': self.serialize_user, + 'postId': self.serialize_post_id, + 'version': self.serialize_version, + 'text': self.serialize_text, + 'creationTime': self.serialize_creation_time, + 'lastEditTime': self.serialize_last_edit_time, + 'score': self.serialize_score, + 'ownScore': self.serialize_own_score, + } + + def serialize_id(self) -> Any: + return self.comment.comment_id + + def serialize_user(self) -> Any: + return users.serialize_micro_user(self.comment.user, self.auth_user) + + def serialize_post_id(self) -> Any: + return self.comment.post.post_id + + def serialize_version(self) -> Any: + return self.comment.version + + def serialize_text(self) -> Any: + return self.comment.text + + def serialize_creation_time(self) -> Any: + return self.comment.creation_time + + def serialize_last_edit_time(self) -> Any: + return self.comment.last_edit_time + + def serialize_score(self) -> Any: + return self.comment.score + + def serialize_own_score(self) -> Any: + return scores.get_score(self.comment, self.auth_user) -def try_get_comment_by_id(comment_id): - try: - comment_id = int(comment_id) - except ValueError: - raise InvalidCommentIdError('Invalid comment ID: %r.' % comment_id) +def serialize_comment( + comment: model.Comment, + auth_user: model.User, + options: List[str]=[]) -> rest.Response: + if comment is None: + return None + return CommentSerializer(comment, auth_user).serialize(options) + + +def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]: + comment_id = int(comment_id) return db.session \ - .query(db.Comment) \ - .filter(db.Comment.comment_id == comment_id) \ + .query(model.Comment) \ + .filter(model.Comment.comment_id == comment_id) \ .one_or_none() -def get_comment_by_id(comment_id): +def get_comment_by_id(comment_id: int) -> model.Comment: comment = try_get_comment_by_id(comment_id) if comment: return comment raise CommentNotFoundError('Comment %r not found.' % comment_id) -def create_comment(user, post, text): - comment = db.Comment() +def create_comment( + user: model.User, post: model.Post, text: str) -> model.Comment: + comment = model.Comment() comment.user = user comment.post = post update_comment_text(comment, text) - comment.creation_time = datetime.datetime.utcnow() + comment.creation_time = datetime.utcnow() return comment -def update_comment_text(comment, text): +def update_comment_text(comment: model.Comment, text: str) -> None: assert comment if not text: raise EmptyCommentTextError('Comment text cannot be empty.') diff --git a/server/szurubooru/func/diff.py b/server/szurubooru/func/diff.py index 0950f0f0..90014f7e 100644 --- a/server/szurubooru/func/diff.py +++ b/server/szurubooru/func/diff.py @@ -1,21 +1,26 @@ -def get_list_diff(old, new): - value = {'type': 'list change', 'added': [], 'removed': []} +from typing import List, Dict, Any + + +def get_list_diff(old: List[Any], new: List[Any]) -> Any: equal = True + removed = [] # type: List[Any] + added = [] # type: List[Any] for item in old: if item not in new: equal = False - value['removed'].append(item) + removed.append(item) for item in new: if item not in old: equal = False - value['added'].append(item) + added.append(item) - return None if equal else value + return None if equal else { + 'type': 'list change', 'added': added, 'removed': removed} -def get_dict_diff(old, new): +def get_dict_diff(old: Dict[str, Any], new: Dict[str, Any]) -> Any: value = {} equal = True diff --git a/server/szurubooru/func/favorites.py b/server/szurubooru/func/favorites.py index 00952de7..f567bfad 100644 --- a/server/szurubooru/func/favorites.py +++ b/server/szurubooru/func/favorites.py @@ -1,32 +1,34 @@ -import datetime -from szurubooru import db, errors +from typing import Any, Optional, Callable, Tuple +from datetime import datetime +from szurubooru import db, model, errors class InvalidFavoriteTargetError(errors.ValidationError): pass -def _get_table_info(entity): +def _get_table_info( + entity: model.Base) -> Tuple[model.Base, Callable[[model.Base], Any]]: assert entity - resource_type, _, _ = db.util.get_resource_info(entity) + resource_type, _, _ = model.util.get_resource_info(entity) if resource_type == 'post': - return db.PostFavorite, lambda table: table.post_id + return model.PostFavorite, lambda table: table.post_id raise InvalidFavoriteTargetError() -def _get_fav_entity(entity, user): +def _get_fav_entity(entity: model.Base, user: model.User) -> model.Base: assert entity assert user - return db.util.get_aux_entity(db.session, _get_table_info, entity, user) + return model.util.get_aux_entity(db.session, _get_table_info, entity, user) -def has_favorited(entity, user): +def has_favorited(entity: model.Base, user: model.User) -> bool: assert entity assert user return _get_fav_entity(entity, user) is not None -def unset_favorite(entity, user): +def unset_favorite(entity: model.Base, user: Optional[model.User]) -> None: assert entity assert user fav_entity = _get_fav_entity(entity, user) @@ -34,7 +36,7 @@ def unset_favorite(entity, user): db.session.delete(fav_entity) -def set_favorite(entity, user): +def set_favorite(entity: model.Base, user: Optional[model.User]) -> None: from szurubooru.func import scores assert entity assert user @@ -48,5 +50,5 @@ def set_favorite(entity, user): fav_entity = table() setattr(fav_entity, get_column(table).name, get_column(entity)) fav_entity.user = user - fav_entity.time = datetime.datetime.utcnow() + fav_entity.time = datetime.utcnow() db.session.add(fav_entity) diff --git a/server/szurubooru/func/file_uploads.py b/server/szurubooru/func/file_uploads.py index 95698e36..e7f93d83 100644 --- a/server/szurubooru/func/file_uploads.py +++ b/server/szurubooru/func/file_uploads.py @@ -1,27 +1,28 @@ -import datetime +from typing import Optional +from datetime import datetime, timedelta from szurubooru.func import files, util MAX_MINUTES = 60 -def _get_path(checksum): +def _get_path(checksum: str) -> str: return 'temporary-uploads/%s.dat' % checksum -def purge_old_uploads(): - now = datetime.datetime.now() +def purge_old_uploads() -> None: + now = datetime.now() for file in files.scan('temporary-uploads'): - file_time = datetime.datetime.fromtimestamp(file.stat().st_ctime) - if now - file_time > datetime.timedelta(minutes=MAX_MINUTES): + file_time = datetime.fromtimestamp(file.stat().st_ctime) + if now - file_time > timedelta(minutes=MAX_MINUTES): files.delete('temporary-uploads/%s' % file.name) -def get(checksum): +def get(checksum: str) -> Optional[bytes]: return files.get('temporary-uploads/%s.dat' % checksum) -def save(content): +def save(content: bytes) -> str: checksum = util.get_sha1(content) path = _get_path(checksum) if not files.has(path): diff --git a/server/szurubooru/func/files.py b/server/szurubooru/func/files.py index 3ca87776..0a992ee4 100644 --- a/server/szurubooru/func/files.py +++ b/server/szurubooru/func/files.py @@ -1,32 +1,33 @@ +from typing import Any, Optional, List import os from szurubooru import config -def _get_full_path(path): +def _get_full_path(path: str) -> str: return os.path.join(config.config['data_dir'], path) -def delete(path): +def delete(path: str) -> None: full_path = _get_full_path(path) if os.path.exists(full_path): os.unlink(full_path) -def has(path): +def has(path: str) -> bool: return os.path.exists(_get_full_path(path)) -def scan(path): +def scan(path: str) -> List[os.DirEntry]: if has(path): - return os.scandir(_get_full_path(path)) + return list(os.scandir(_get_full_path(path))) return [] -def move(source_path, target_path): - return os.rename(_get_full_path(source_path), _get_full_path(target_path)) +def move(source_path: str, target_path: str) -> None: + os.rename(_get_full_path(source_path), _get_full_path(target_path)) -def get(path): +def get(path: str) -> Optional[bytes]: full_path = _get_full_path(path) if not os.path.exists(full_path): return None @@ -34,7 +35,7 @@ def get(path): return handle.read() -def save(path, content): +def save(path: str, content: bytes) -> None: full_path = _get_full_path(path) os.makedirs(os.path.dirname(full_path), exist_ok=True) with open(full_path, 'wb') as handle: diff --git a/server/szurubooru/func/image_hash.py b/server/szurubooru/func/image_hash.py index c89a2ec1..dc998e83 100644 --- a/server/szurubooru/func/image_hash.py +++ b/server/szurubooru/func/image_hash.py @@ -1,6 +1,7 @@ import logging from io import BytesIO from datetime import datetime +from typing import Any, Optional, Tuple, Set, List, Callable import elasticsearch import elasticsearch_dsl import numpy as np @@ -10,13 +11,8 @@ from szurubooru import config, errors # pylint: disable=invalid-name logger = logging.getLogger(__name__) -es = elasticsearch.Elasticsearch([{ - 'host': config.config['elasticsearch']['host'], - 'port': config.config['elasticsearch']['port'], -}]) - -# Math based on paper from H. Chi Wong, Marshall Bern and David Goldber +# Math based on paper from H. Chi Wong, Marshall Bern and David Goldberg # Math code taken from https://github.com/ascribe/image-match # (which is licensed under Apache 2 license) @@ -32,14 +28,27 @@ MAX_WORDS = 63 ES_DOC_TYPE = 'image' ES_MAX_RESULTS = 100 +Window = Tuple[Tuple[float, float], Tuple[float, float]] +NpMatrix = Any -def _preprocess_image(image_or_path): - img = Image.open(BytesIO(image_or_path)) + +def _get_session() -> elasticsearch.Elasticsearch: + return elasticsearch.Elasticsearch([{ + 'host': config.config['elasticsearch']['host'], + 'port': config.config['elasticsearch']['port'], + }]) + + +def _preprocess_image(content: bytes) -> NpMatrix: + img = Image.open(BytesIO(content)) img = img.convert('RGB') return rgb2gray(np.asarray(img, dtype=np.uint8)) -def _crop_image(image, lower_percentile, upper_percentile): +def _crop_image( + image: NpMatrix, + lower_percentile: float, + upper_percentile: float) -> Window: rw = np.cumsum(np.sum(np.abs(np.diff(image, axis=1)), axis=1)) cw = np.cumsum(np.sum(np.abs(np.diff(image, axis=0)), axis=0)) upper_column_limit = np.searchsorted( @@ -56,16 +65,19 @@ def _crop_image(image, lower_percentile, upper_percentile): if lower_column_limit > upper_column_limit: lower_column_limit = int(lower_percentile / 100. * image.shape[1]) upper_column_limit = int(upper_percentile / 100. * image.shape[1]) - return [ + return ( (lower_row_limit, upper_row_limit), - (lower_column_limit, upper_column_limit)] + (lower_column_limit, upper_column_limit)) -def _normalize_and_threshold(diff_array, identical_tolerance, n_levels): +def _normalize_and_threshold( + diff_array: NpMatrix, + identical_tolerance: float, + n_levels: int) -> None: mask = np.abs(diff_array) < identical_tolerance diff_array[mask] = 0. if np.all(mask): - return None + return positive_cutoffs = np.percentile( diff_array[diff_array > 0.], np.linspace(0, 100, n_levels + 1)) negative_cutoffs = np.percentile( @@ -82,18 +94,24 @@ def _normalize_and_threshold(diff_array, identical_tolerance, n_levels): diff_array[ (diff_array <= interval[0]) & (diff_array >= interval[1])] = \ -(level + 1) - return None -def _compute_grid_points(image, n, window=None): +def _compute_grid_points( + image: NpMatrix, + n: float, + window: Window=None) -> Tuple[NpMatrix, NpMatrix]: if window is None: - window = [(0, image.shape[0]), (0, image.shape[1])] + window = ((0, image.shape[0]), (0, image.shape[1])) x_coords = np.linspace(window[0][0], window[0][1], n + 2, dtype=int)[1:-1] y_coords = np.linspace(window[1][0], window[1][1], n + 2, dtype=int)[1:-1] return x_coords, y_coords -def _compute_mean_level(image, x_coords, y_coords, p): +def _compute_mean_level( + image: NpMatrix, + x_coords: NpMatrix, + y_coords: NpMatrix, + p: Optional[float]) -> NpMatrix: if p is None: p = max([2.0, int(0.5 + min(image.shape) / 20.)]) avg_grey = np.zeros((x_coords.shape[0], y_coords.shape[0])) @@ -108,7 +126,7 @@ def _compute_mean_level(image, x_coords, y_coords, p): return avg_grey -def _compute_differentials(grey_level_matrix): +def _compute_differentials(grey_level_matrix: NpMatrix) -> NpMatrix: flipped = np.fliplr(grey_level_matrix) right_neighbors = -np.concatenate( ( @@ -152,8 +170,8 @@ def _compute_differentials(grey_level_matrix): lower_right_neighbors])) -def _generate_signature(path_or_image): - im_array = _preprocess_image(path_or_image) +def _generate_signature(content: bytes) -> NpMatrix: + im_array = _preprocess_image(content) image_limits = _crop_image( im_array, lower_percentile=LOWER_PERCENTILE, @@ -169,7 +187,7 @@ def _generate_signature(path_or_image): return np.ravel(diff_matrix).astype('int8') -def _get_words(array, k, n): +def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix: word_positions = np.linspace( 0, array.shape[0], n, endpoint=False).astype('int') assert k <= array.shape[0] @@ -187,21 +205,23 @@ def _get_words(array, k, n): return words -def _words_to_int(word_array): +def _words_to_int(word_array: NpMatrix) -> NpMatrix: width = word_array.shape[1] coding_vector = 3**np.arange(width) return np.dot(word_array + 1, coding_vector) -def _max_contrast(array): +def _max_contrast(array: NpMatrix) -> None: array[array > 0] = 1 array[array < 0] = -1 - return None -def _normalized_distance(_target_array, _vec, nan_value=1.0): - target_array = _target_array.astype(int) - vec = _vec.astype(int) +def _normalized_distance( + target_array: NpMatrix, + vec: NpMatrix, + nan_value: float=1.0) -> List[float]: + target_array = target_array.astype(int) + vec = vec.astype(int) topvec = np.linalg.norm(vec - target_array, axis=1) norm1 = np.linalg.norm(vec, axis=0) norm2 = np.linalg.norm(target_array, axis=1) @@ -210,9 +230,9 @@ def _normalized_distance(_target_array, _vec, nan_value=1.0): return finvec -def _safety_blanket(default_param_factory): - def wrapper_outer(target_function): - def wrapper_inner(*args, **kwargs): +def _safety_blanket(default_param_factory: Callable[[], Any]) -> Callable: + def wrapper_outer(target_function: Callable) -> Callable: + def wrapper_inner(*args: Any, **kwargs: Any) -> Any: try: return target_function(*args, **kwargs) except elasticsearch.exceptions.NotFoundError: @@ -226,20 +246,20 @@ def _safety_blanket(default_param_factory): except IOError: raise errors.ProcessingError('Not an image.') except Exception as ex: - raise errors.ThirdPartyError('Unknown error (%s).', ex) + raise errors.ThirdPartyError('Unknown error (%s).' % ex) return wrapper_inner return wrapper_outer class Lookalike: - def __init__(self, score, distance, path): + def __init__(self, score: int, distance: float, path: Any) -> None: self.score = score self.distance = distance self.path = path @_safety_blanket(lambda: None) -def add_image(path, image_content): +def add_image(path: str, image_content: bytes) -> None: assert path assert image_content signature = _generate_signature(image_content) @@ -253,7 +273,7 @@ def add_image(path, image_content): for i in range(MAX_WORDS): record['simple_word_' + str(i)] = words[i].tolist() - es.index( + _get_session().index( index=config.config['elasticsearch']['index'], doc_type=ES_DOC_TYPE, body=record, @@ -261,20 +281,20 @@ def add_image(path, image_content): @_safety_blanket(lambda: None) -def delete_image(path): +def delete_image(path: str) -> None: assert path - es.delete_by_query( + _get_session().delete_by_query( index=config.config['elasticsearch']['index'], doc_type=ES_DOC_TYPE, body={'query': {'term': {'path': path}}}) @_safety_blanket(lambda: []) -def search_by_image(image_content): +def search_by_image(image_content: bytes) -> List[Lookalike]: signature = _generate_signature(image_content) words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS) - res = es.search( + res = _get_session().search( index=config.config['elasticsearch']['index'], doc_type=ES_DOC_TYPE, body={ @@ -299,7 +319,7 @@ def search_by_image(image_content): sigs = np.array([x['_source']['signature'] for x in res]) dists = _normalized_distance(sigs, np.array(signature)) - ids = set() + ids = set() # type: Set[int] ret = [] for item, dist in zip(res, dists): id = item['_id'] @@ -314,8 +334,8 @@ def search_by_image(image_content): @_safety_blanket(lambda: None) -def purge(): - es.delete_by_query( +def purge() -> None: + _get_session().delete_by_query( index=config.config['elasticsearch']['index'], doc_type=ES_DOC_TYPE, body={'query': {'match_all': {}}}, @@ -323,10 +343,10 @@ def purge(): @_safety_blanket(lambda: set()) -def get_all_paths(): +def get_all_paths() -> Set[str]: search = ( elasticsearch_dsl.Search( - using=es, + using=_get_session(), index=config.config['elasticsearch']['index'], doc_type=ES_DOC_TYPE) .source(['path'])) diff --git a/server/szurubooru/func/images.py b/server/szurubooru/func/images.py index fdab793b..103a6ff8 100644 --- a/server/szurubooru/func/images.py +++ b/server/szurubooru/func/images.py @@ -1,3 +1,4 @@ +from typing import List import logging import json import shlex @@ -15,23 +16,23 @@ _SCALE_FIT_FMT = \ class Image: - def __init__(self, content): + def __init__(self, content: bytes) -> None: self.content = content self._reload_info() @property - def width(self): + def width(self) -> int: return self.info['streams'][0]['width'] @property - def height(self): + def height(self) -> int: return self.info['streams'][0]['height'] @property - def frames(self): + def frames(self) -> int: return self.info['streams'][0]['nb_read_frames'] - def resize_fill(self, width, height): + def resize_fill(self, width: int, height: int) -> None: cli = [ '-i', '{path}', '-f', 'image2', @@ -53,7 +54,7 @@ class Image: assert self.content self._reload_info() - def to_png(self): + def to_png(self) -> bytes: return self._execute([ '-i', '{path}', '-f', 'image2', @@ -63,7 +64,7 @@ class Image: '-', ]) - def to_jpeg(self): + def to_jpeg(self) -> bytes: return self._execute([ '-f', 'lavfi', '-i', 'color=white:s=%dx%d' % (self.width, self.height), @@ -76,7 +77,7 @@ class Image: '-', ]) - def _execute(self, cli, program='ffmpeg'): + def _execute(self, cli: List[str], program: str='ffmpeg') -> bytes: extension = mime.get_extension(mime.get_mime_type(self.content)) assert extension with util.create_temp_file(suffix='.' + extension) as handle: @@ -99,7 +100,7 @@ class Image: 'Error while processing image.\n' + err.decode('utf-8')) return out - def _reload_info(self): + def _reload_info(self) -> None: self.info = json.loads(self._execute([ '-i', '{path}', '-of', 'json', diff --git a/server/szurubooru/func/mailer.py b/server/szurubooru/func/mailer.py index 94f9c506..76682f11 100644 --- a/server/szurubooru/func/mailer.py +++ b/server/szurubooru/func/mailer.py @@ -3,7 +3,7 @@ import email.mime.text from szurubooru import config -def send_mail(sender, recipient, subject, body): +def send_mail(sender: str, recipient: str, subject: str, body: str) -> None: msg = email.mime.text.MIMEText(body) msg['Subject'] = subject msg['From'] = sender diff --git a/server/szurubooru/func/mime.py b/server/szurubooru/func/mime.py index 2277ed64..c83f744e 100644 --- a/server/szurubooru/func/mime.py +++ b/server/szurubooru/func/mime.py @@ -1,7 +1,8 @@ import re +from typing import Optional -def get_mime_type(content): +def get_mime_type(content: bytes) -> str: if not content: return 'application/octet-stream' @@ -26,7 +27,7 @@ def get_mime_type(content): return 'application/octet-stream' -def get_extension(mime_type): +def get_extension(mime_type: str) -> Optional[str]: extension_map = { 'application/x-shockwave-flash': 'swf', 'image/gif': 'gif', @@ -39,19 +40,19 @@ def get_extension(mime_type): return extension_map.get((mime_type or '').strip().lower(), None) -def is_flash(mime_type): +def is_flash(mime_type: str) -> bool: return mime_type.lower() == 'application/x-shockwave-flash' -def is_video(mime_type): +def is_video(mime_type: str) -> bool: return mime_type.lower() in ('application/ogg', 'video/mp4', 'video/webm') -def is_image(mime_type): +def is_image(mime_type: str) -> bool: return mime_type.lower() in ('image/jpeg', 'image/png', 'image/gif') -def is_animated_gif(content): +def is_animated_gif(content: bytes) -> bool: pattern = b'\x21\xF9\x04[\x00-\xFF]{4}\x00[\x2C\x21]' return get_mime_type(content) == 'image/gif' \ and len(re.findall(pattern, content)) > 1 diff --git a/server/szurubooru/func/net.py b/server/szurubooru/func/net.py index fb0c427a..a6e18214 100644 --- a/server/szurubooru/func/net.py +++ b/server/szurubooru/func/net.py @@ -2,7 +2,7 @@ import urllib.request from szurubooru import errors -def download(url): +def download(url: str) -> bytes: assert url request = urllib.request.Request(url) request.add_header('Referer', url) diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index c942e799..aa4e137f 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -1,8 +1,10 @@ -import datetime -import sqlalchemy -from szurubooru import config, db, errors +from typing import Any, Optional, Tuple, List, Dict, Callable +from datetime import datetime +import sqlalchemy as sa +from szurubooru import config, db, model, errors, rest from szurubooru.func import ( - users, scores, comments, tags, util, mime, images, files, image_hash) + users, scores, comments, tags, util, + mime, images, files, image_hash, serialization) EMPTY_PIXEL = \ @@ -20,7 +22,7 @@ class PostAlreadyFeaturedError(errors.ValidationError): class PostAlreadyUploadedError(errors.ValidationError): - def __init__(self, other_post): + def __init__(self, other_post: model.Post) -> None: super().__init__( 'Post already uploaded (%d)' % other_post.post_id, { @@ -58,30 +60,30 @@ class InvalidPostFlagError(errors.ValidationError): class PostLookalike(image_hash.Lookalike): - def __init__(self, score, distance, post): + def __init__(self, score: int, distance: float, post: model.Post) -> None: super().__init__(score, distance, post.post_id) self.post = post SAFETY_MAP = { - db.Post.SAFETY_SAFE: 'safe', - db.Post.SAFETY_SKETCHY: 'sketchy', - db.Post.SAFETY_UNSAFE: 'unsafe', + model.Post.SAFETY_SAFE: 'safe', + model.Post.SAFETY_SKETCHY: 'sketchy', + model.Post.SAFETY_UNSAFE: 'unsafe', } TYPE_MAP = { - db.Post.TYPE_IMAGE: 'image', - db.Post.TYPE_ANIMATION: 'animation', - db.Post.TYPE_VIDEO: 'video', - db.Post.TYPE_FLASH: 'flash', + model.Post.TYPE_IMAGE: 'image', + model.Post.TYPE_ANIMATION: 'animation', + model.Post.TYPE_VIDEO: 'video', + model.Post.TYPE_FLASH: 'flash', } FLAG_MAP = { - db.Post.FLAG_LOOP: 'loop', + model.Post.FLAG_LOOP: 'loop', } -def get_post_content_url(post): +def get_post_content_url(post: model.Post) -> str: assert post return '%s/posts/%d.%s' % ( config.config['data_url'].rstrip('/'), @@ -89,31 +91,31 @@ def get_post_content_url(post): mime.get_extension(post.mime_type) or 'dat') -def get_post_thumbnail_url(post): +def get_post_thumbnail_url(post: model.Post) -> str: assert post return '%s/generated-thumbnails/%d.jpg' % ( config.config['data_url'].rstrip('/'), post.post_id) -def get_post_content_path(post): +def get_post_content_path(post: model.Post) -> str: assert post assert post.post_id return 'posts/%d.%s' % ( post.post_id, mime.get_extension(post.mime_type) or 'dat') -def get_post_thumbnail_path(post): +def get_post_thumbnail_path(post: model.Post) -> str: assert post return 'generated-thumbnails/%d.jpg' % (post.post_id) -def get_post_thumbnail_backup_path(post): +def get_post_thumbnail_backup_path(post: model.Post) -> str: assert post return 'posts/custom-thumbnails/%d.dat' % (post.post_id) -def serialize_note(note): +def serialize_note(note: model.PostNote) -> rest.Response: assert note return { 'polygon': note.polygon, @@ -121,113 +123,216 @@ def serialize_note(note): } -def serialize_post(post, auth_user, options=None): - return util.serialize_entity( - post, - { - 'id': lambda: post.post_id, - 'version': lambda: post.version, - 'creationTime': lambda: post.creation_time, - 'lastEditTime': lambda: post.last_edit_time, - 'safety': lambda: SAFETY_MAP[post.safety], - 'source': lambda: post.source, - 'type': lambda: TYPE_MAP[post.type], - 'mimeType': lambda: post.mime_type, - 'checksum': lambda: post.checksum, - 'fileSize': lambda: post.file_size, - 'canvasWidth': lambda: post.canvas_width, - 'canvasHeight': lambda: post.canvas_height, - 'contentUrl': lambda: get_post_content_url(post), - 'thumbnailUrl': lambda: get_post_thumbnail_url(post), - 'flags': lambda: post.flags, - 'tags': lambda: [ - tag.names[0].name for tag in tags.sort_tags(post.tags)], - 'relations': lambda: sorted( - { - post['id']: - post for post in [ - serialize_micro_post(rel, auth_user) - for rel in post.relations] - }.values(), - key=lambda post: post['id']), - 'user': lambda: users.serialize_micro_user(post.user, auth_user), - 'score': lambda: post.score, - 'ownScore': lambda: scores.get_score(post, auth_user), - 'ownFavorite': lambda: len([ - user for user in post.favorited_by - if user.user_id == auth_user.user_id] - ) > 0, - 'tagCount': lambda: post.tag_count, - 'favoriteCount': lambda: post.favorite_count, - 'commentCount': lambda: post.comment_count, - 'noteCount': lambda: post.note_count, - 'relationCount': lambda: post.relation_count, - 'featureCount': lambda: post.feature_count, - 'lastFeatureTime': lambda: post.last_feature_time, - 'favoritedBy': lambda: [ - users.serialize_micro_user(rel.user, auth_user) - for rel in post.favorited_by - ], - 'hasCustomThumbnail': - lambda: files.has(get_post_thumbnail_backup_path(post)), - 'notes': lambda: sorted( - [serialize_note(note) for note in post.notes], - key=lambda x: x['polygon']), - 'comments': lambda: [ - comments.serialize_comment(comment, auth_user) - for comment in sorted( - post.comments, - key=lambda comment: comment.creation_time)], - }, - options) +class PostSerializer(serialization.BaseSerializer): + def __init__(self, post: model.Post, auth_user: model.User) -> None: + self.post = post + self.auth_user = auth_user + + def _serializers(self) -> Dict[str, Callable[[], Any]]: + return { + 'id': self.serialize_id, + 'version': self.serialize_version, + 'creationTime': self.serialize_creation_time, + 'lastEditTime': self.serialize_last_edit_time, + 'safety': self.serialize_safety, + 'source': self.serialize_source, + 'type': self.serialize_type, + 'mimeType': self.serialize_mime, + 'checksum': self.serialize_checksum, + 'fileSize': self.serialize_file_size, + 'canvasWidth': self.serialize_canvas_width, + 'canvasHeight': self.serialize_canvas_height, + 'contentUrl': self.serialize_content_url, + 'thumbnailUrl': self.serialize_thumbnail_url, + 'flags': self.serialize_flags, + 'tags': self.serialize_tags, + 'relations': self.serialize_relations, + 'user': self.serialize_user, + 'score': self.serialize_score, + 'ownScore': self.serialize_own_score, + 'ownFavorite': self.serialize_own_favorite, + 'tagCount': self.serialize_tag_count, + 'favoriteCount': self.serialize_favorite_count, + 'commentCount': self.serialize_comment_count, + 'noteCount': self.serialize_note_count, + 'relationCount': self.serialize_relation_count, + 'featureCount': self.serialize_feature_count, + 'lastFeatureTime': self.serialize_last_feature_time, + 'favoritedBy': self.serialize_favorited_by, + 'hasCustomThumbnail': self.serialize_has_custom_thumbnail, + 'notes': self.serialize_notes, + 'comments': self.serialize_comments, + } + + def serialize_id(self) -> Any: + return self.post.post_id + + def serialize_version(self) -> Any: + return self.post.version + + def serialize_creation_time(self) -> Any: + return self.post.creation_time + + def serialize_last_edit_time(self) -> Any: + return self.post.last_edit_time + + def serialize_safety(self) -> Any: + return SAFETY_MAP[self.post.safety] + + def serialize_source(self) -> Any: + return self.post.source + + def serialize_type(self) -> Any: + return TYPE_MAP[self.post.type] + + def serialize_mime(self) -> Any: + return self.post.mime_type + + def serialize_checksum(self) -> Any: + return self.post.checksum + + def serialize_file_size(self) -> Any: + return self.post.file_size + + def serialize_canvas_width(self) -> Any: + return self.post.canvas_width + + def serialize_canvas_height(self) -> Any: + return self.post.canvas_height + + def serialize_content_url(self) -> Any: + return get_post_content_url(self.post) + + def serialize_thumbnail_url(self) -> Any: + return get_post_thumbnail_url(self.post) + + def serialize_flags(self) -> Any: + return self.post.flags + + def serialize_tags(self) -> Any: + return [tag.names[0].name for tag in tags.sort_tags(self.post.tags)] + + def serialize_relations(self) -> Any: + return sorted( + { + post['id']: post + for post in [ + serialize_micro_post(rel, self.auth_user) + for rel in self.post.relations] + }.values(), + key=lambda post: post['id']) + + def serialize_user(self) -> Any: + return users.serialize_micro_user(self.post.user, self.auth_user) + + def serialize_score(self) -> Any: + return self.post.score + + def serialize_own_score(self) -> Any: + return scores.get_score(self.post, self.auth_user) + + def serialize_own_favorite(self) -> Any: + return len([ + user for user in self.post.favorited_by + if user.user_id == self.auth_user.user_id] + ) > 0 + + def serialize_tag_count(self) -> Any: + return self.post.tag_count + + def serialize_favorite_count(self) -> Any: + return self.post.favorite_count + + def serialize_comment_count(self) -> Any: + return self.post.comment_count + + def serialize_note_count(self) -> Any: + return self.post.note_count + + def serialize_relation_count(self) -> Any: + return self.post.relation_count + + def serialize_feature_count(self) -> Any: + return self.post.feature_count + + def serialize_last_feature_time(self) -> Any: + return self.post.last_feature_time + + def serialize_favorited_by(self) -> Any: + return [ + users.serialize_micro_user(rel.user, self.auth_user) + for rel in self.post.favorited_by + ] + + def serialize_has_custom_thumbnail(self) -> Any: + return files.has(get_post_thumbnail_backup_path(self.post)) + + def serialize_notes(self) -> Any: + return sorted( + [serialize_note(note) for note in self.post.notes], + key=lambda x: x['polygon']) + + def serialize_comments(self) -> Any: + return [ + comments.serialize_comment(comment, self.auth_user) + for comment in sorted( + self.post.comments, + key=lambda comment: comment.creation_time)] -def serialize_micro_post(post, auth_user): +def serialize_post( + post: Optional[model.Post], + auth_user: model.User, + options: List[str]=[]) -> Optional[rest.Response]: + if not post: + return None + return PostSerializer(post, auth_user).serialize(options) + + +def serialize_micro_post( + post: model.Post, auth_user: model.User) -> Optional[rest.Response]: return serialize_post( - post, - auth_user=auth_user, - options=['id', 'thumbnailUrl']) + post, auth_user=auth_user, options=['id', 'thumbnailUrl']) -def get_post_count(): - return db.session.query(sqlalchemy.func.count(db.Post.post_id)).one()[0] +def get_post_count() -> int: + return db.session.query(sa.func.count(model.Post.post_id)).one()[0] -def try_get_post_by_id(post_id): - try: - post_id = int(post_id) - except ValueError: - raise InvalidPostIdError('Invalid post ID: %r.' % post_id) +def try_get_post_by_id(post_id: int) -> Optional[model.Post]: return db.session \ - .query(db.Post) \ - .filter(db.Post.post_id == post_id) \ + .query(model.Post) \ + .filter(model.Post.post_id == post_id) \ .one_or_none() -def get_post_by_id(post_id): +def get_post_by_id(post_id: int) -> model.Post: post = try_get_post_by_id(post_id) if not post: raise PostNotFoundError('Post %r not found.' % post_id) return post -def try_get_current_post_feature(): +def try_get_current_post_feature() -> Optional[model.PostFeature]: return db.session \ - .query(db.PostFeature) \ - .order_by(db.PostFeature.time.desc()) \ + .query(model.PostFeature) \ + .order_by(model.PostFeature.time.desc()) \ .first() -def try_get_featured_post(): +def try_get_featured_post() -> Optional[model.Post]: post_feature = try_get_current_post_feature() return post_feature.post if post_feature else None -def create_post(content, tag_names, user): - post = db.Post() - post.safety = db.Post.SAFETY_SAFE +def create_post( + content: bytes, + tag_names: List[str], + user: Optional[model.User]) -> Tuple[model.Post, List[model.Tag]]: + post = model.Post() + post.safety = model.Post.SAFETY_SAFE post.user = user - post.creation_time = datetime.datetime.utcnow() + post.creation_time = datetime.utcnow() post.flags = [] post.type = '' @@ -240,7 +345,7 @@ def create_post(content, tag_names, user): return (post, new_tags) -def update_post_safety(post, safety): +def update_post_safety(post: model.Post, safety: str) -> None: assert post safety = util.flip(SAFETY_MAP).get(safety, None) if not safety: @@ -249,30 +354,33 @@ def update_post_safety(post, safety): post.safety = safety -def update_post_source(post, source): +def update_post_source(post: model.Post, source: Optional[str]) -> None: assert post - if util.value_exceeds_column_size(source, db.Post.source): + if util.value_exceeds_column_size(source, model.Post.source): raise InvalidPostSourceError('Source is too long.') - post.source = source + post.source = source or None -@sqlalchemy.events.event.listens_for(db.Post, 'after_insert') -def _after_post_insert(_mapper, _connection, post): +@sa.events.event.listens_for(model.Post, 'after_insert') +def _after_post_insert( + _mapper: Any, _connection: Any, post: model.Post) -> None: _sync_post_content(post) -@sqlalchemy.events.event.listens_for(db.Post, 'after_update') -def _after_post_update(_mapper, _connection, post): +@sa.events.event.listens_for(model.Post, 'after_update') +def _after_post_update( + _mapper: Any, _connection: Any, post: model.Post) -> None: _sync_post_content(post) -@sqlalchemy.events.event.listens_for(db.Post, 'before_delete') -def _before_post_delete(_mapper, _connection, post): +@sa.events.event.listens_for(model.Post, 'before_delete') +def _before_post_delete( + _mapper: Any, _connection: Any, post: model.Post) -> None: if post.post_id: image_hash.delete_image(post.post_id) -def _sync_post_content(post): +def _sync_post_content(post: model.Post) -> None: regenerate_thumb = False if hasattr(post, '__content'): @@ -281,7 +389,7 @@ def _sync_post_content(post): delattr(post, '__content') regenerate_thumb = True if post.post_id and post.type in ( - db.Post.TYPE_IMAGE, db.Post.TYPE_ANIMATION): + model.Post.TYPE_IMAGE, model.Post.TYPE_ANIMATION): image_hash.delete_image(post.post_id) image_hash.add_image(post.post_id, content) @@ -299,29 +407,29 @@ def _sync_post_content(post): generate_post_thumbnail(post) -def update_post_content(post, content): +def update_post_content(post: model.Post, content: Optional[bytes]) -> None: assert post if not content: raise InvalidPostContentError('Post content missing.') post.mime_type = mime.get_mime_type(content) if mime.is_flash(post.mime_type): - post.type = db.Post.TYPE_FLASH + post.type = model.Post.TYPE_FLASH elif mime.is_image(post.mime_type): if mime.is_animated_gif(content): - post.type = db.Post.TYPE_ANIMATION + post.type = model.Post.TYPE_ANIMATION else: - post.type = db.Post.TYPE_IMAGE + post.type = model.Post.TYPE_IMAGE elif mime.is_video(post.mime_type): - post.type = db.Post.TYPE_VIDEO + post.type = model.Post.TYPE_VIDEO else: raise InvalidPostContentError( 'Unhandled file type: %r' % post.mime_type) post.checksum = util.get_sha1(content) other_post = db.session \ - .query(db.Post) \ - .filter(db.Post.checksum == post.checksum) \ - .filter(db.Post.post_id != post.post_id) \ + .query(model.Post) \ + .filter(model.Post.checksum == post.checksum) \ + .filter(model.Post.post_id != post.post_id) \ .one_or_none() if other_post \ and other_post.post_id \ @@ -343,18 +451,20 @@ def update_post_content(post, content): setattr(post, '__content', content) -def update_post_thumbnail(post, content=None): +def update_post_thumbnail( + post: model.Post, content: Optional[bytes]=None) -> None: assert post setattr(post, '__thumbnail', content) -def generate_post_thumbnail(post): +def generate_post_thumbnail(post: model.Post) -> None: assert post if files.has(get_post_thumbnail_backup_path(post)): content = files.get(get_post_thumbnail_backup_path(post)) else: content = files.get(get_post_content_path(post)) try: + assert content image = images.Image(content) image.resize_fill( int(config.config['thumbnails']['post_width']), @@ -364,14 +474,15 @@ def generate_post_thumbnail(post): files.save(get_post_thumbnail_path(post), EMPTY_PIXEL) -def update_post_tags(post, tag_names): +def update_post_tags( + post: model.Post, tag_names: List[str]) -> List[model.Tag]: assert post existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names) post.tags = existing_tags + new_tags return new_tags -def update_post_relations(post, new_post_ids): +def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None: assert post try: new_post_ids = [int(id) for id in new_post_ids] @@ -382,8 +493,8 @@ def update_post_relations(post, new_post_ids): old_post_ids = [int(p.post_id) for p in old_posts] if new_post_ids: new_posts = db.session \ - .query(db.Post) \ - .filter(db.Post.post_id.in_(new_post_ids)) \ + .query(model.Post) \ + .filter(model.Post.post_id.in_(new_post_ids)) \ .all() else: new_posts = [] @@ -402,7 +513,7 @@ def update_post_relations(post, new_post_ids): relation.relations.append(post) -def update_post_notes(post, notes): +def update_post_notes(post: model.Post, notes: Any) -> None: assert post post.notes = [] for note in notes: @@ -433,13 +544,13 @@ def update_post_notes(post, notes): except ValueError: raise InvalidPostNoteError( 'A point in note\'s polygon must be numeric.') - if util.value_exceeds_column_size(note['text'], db.PostNote.text): + if util.value_exceeds_column_size(note['text'], model.PostNote.text): raise InvalidPostNoteError('Note text is too long.') post.notes.append( - db.PostNote(polygon=note['polygon'], text=str(note['text']))) + model.PostNote(polygon=note['polygon'], text=str(note['text']))) -def update_post_flags(post, flags): +def update_post_flags(post: model.Post, flags: List[str]) -> None: assert post target_flags = [] for flag in flags: @@ -451,88 +562,95 @@ def update_post_flags(post, flags): post.flags = target_flags -def feature_post(post, user): +def feature_post(post: model.Post, user: Optional[model.User]) -> None: assert post - post_feature = db.PostFeature() - post_feature.time = datetime.datetime.utcnow() + post_feature = model.PostFeature() + post_feature.time = datetime.utcnow() post_feature.post = post post_feature.user = user db.session.add(post_feature) -def delete(post): +def delete(post: model.Post) -> None: assert post db.session.delete(post) -def merge_posts(source_post, target_post, replace_content): +def merge_posts( + source_post: model.Post, + target_post: model.Post, + replace_content: bool) -> None: assert source_post assert target_post if source_post.post_id == target_post.post_id: raise InvalidPostRelationError('Cannot merge post with itself.') - def merge_tables(table, anti_dup_func, source_post_id, target_post_id): + def merge_tables( + table: model.Base, + anti_dup_func: Optional[Callable[[model.Base, model.Base], bool]], + source_post_id: int, + target_post_id: int) -> None: alias1 = table - alias2 = sqlalchemy.orm.util.aliased(table) + alias2 = sa.orm.util.aliased(table) update_stmt = ( - sqlalchemy.sql.expression.update(alias1) + sa.sql.expression.update(alias1) .where(alias1.post_id == source_post_id)) if anti_dup_func is not None: update_stmt = ( update_stmt .where( - ~sqlalchemy.exists() + ~sa.exists() .where(anti_dup_func(alias1, alias2)) .where(alias2.post_id == target_post_id))) update_stmt = update_stmt.values(post_id=target_post_id) db.session.execute(update_stmt) - def merge_tags(source_post_id, target_post_id): + def merge_tags(source_post_id: int, target_post_id: int) -> None: merge_tables( - db.PostTag, + model.PostTag, lambda alias1, alias2: alias1.tag_id == alias2.tag_id, source_post_id, target_post_id) - def merge_scores(source_post_id, target_post_id): + def merge_scores(source_post_id: int, target_post_id: int) -> None: merge_tables( - db.PostScore, + model.PostScore, lambda alias1, alias2: alias1.user_id == alias2.user_id, source_post_id, target_post_id) - def merge_favorites(source_post_id, target_post_id): + def merge_favorites(source_post_id: int, target_post_id: int) -> None: merge_tables( - db.PostFavorite, + model.PostFavorite, lambda alias1, alias2: alias1.user_id == alias2.user_id, source_post_id, target_post_id) - def merge_comments(source_post_id, target_post_id): - merge_tables(db.Comment, None, source_post_id, target_post_id) + def merge_comments(source_post_id: int, target_post_id: int) -> None: + merge_tables(model.Comment, None, source_post_id, target_post_id) - def merge_relations(source_post_id, target_post_id): - alias1 = db.PostRelation - alias2 = sqlalchemy.orm.util.aliased(db.PostRelation) + def merge_relations(source_post_id: int, target_post_id: int) -> None: + alias1 = model.PostRelation + alias2 = sa.orm.util.aliased(model.PostRelation) update_stmt = ( - sqlalchemy.sql.expression.update(alias1) + sa.sql.expression.update(alias1) .where(alias1.parent_id == source_post_id) .where(alias1.child_id != target_post_id) .where( - ~sqlalchemy.exists() + ~sa.exists() .where(alias2.child_id == alias1.child_id) .where(alias2.parent_id == target_post_id)) .values(parent_id=target_post_id)) db.session.execute(update_stmt) update_stmt = ( - sqlalchemy.sql.expression.update(alias1) + sa.sql.expression.update(alias1) .where(alias1.child_id == source_post_id) .where(alias1.parent_id != target_post_id) .where( - ~sqlalchemy.exists() + ~sa.exists() .where(alias2.parent_id == alias1.parent_id) .where(alias2.child_id == target_post_id)) .values(child_id=target_post_id)) @@ -553,15 +671,15 @@ def merge_posts(source_post, target_post, replace_content): update_post_content(target_post, content) -def search_by_image_exact(image_content): +def search_by_image_exact(image_content: bytes) -> Optional[model.Post]: checksum = util.get_sha1(image_content) return db.session \ - .query(db.Post) \ - .filter(db.Post.checksum == checksum) \ + .query(model.Post) \ + .filter(model.Post.checksum == checksum) \ .one_or_none() -def search_by_image(image_content): +def search_by_image(image_content: bytes) -> List[PostLookalike]: ret = [] for result in image_hash.search_by_image(image_content): ret.append(PostLookalike( @@ -571,24 +689,24 @@ def search_by_image(image_content): return ret -def populate_reverse_search(): +def populate_reverse_search() -> None: excluded_post_ids = image_hash.get_all_paths() post_ids_to_hash = ( db.session - .query(db.Post.post_id) + .query(model.Post.post_id) .filter( - (db.Post.type == db.Post.TYPE_IMAGE) | - (db.Post.type == db.Post.TYPE_ANIMATION)) - .filter(~db.Post.post_id.in_(excluded_post_ids)) - .order_by(db.Post.post_id.asc()) + (model.Post.type == model.Post.TYPE_IMAGE) | + (model.Post.type == model.Post.TYPE_ANIMATION)) + .filter(~model.Post.post_id.in_(excluded_post_ids)) + .order_by(model.Post.post_id.asc()) .all()) for post_ids_chunk in util.chunks(post_ids_to_hash, 100): posts_chunk = ( db.session - .query(db.Post) - .filter(db.Post.post_id.in_(post_ids_chunk)) + .query(model.Post) + .filter(model.Post.post_id.in_(post_ids_chunk)) .all()) for post in posts_chunk: content_path = get_post_content_path(post) diff --git a/server/szurubooru/func/scores.py b/server/szurubooru/func/scores.py index a42961f2..fde279eb 100644 --- a/server/szurubooru/func/scores.py +++ b/server/szurubooru/func/scores.py @@ -1,5 +1,6 @@ import datetime -from szurubooru import db, errors +from typing import Any, Tuple, Callable +from szurubooru import db, model, errors class InvalidScoreTargetError(errors.ValidationError): @@ -10,22 +11,23 @@ class InvalidScoreValueError(errors.ValidationError): pass -def _get_table_info(entity): +def _get_table_info( + entity: model.Base) -> Tuple[model.Base, Callable[[model.Base], Any]]: assert entity - resource_type, _, _ = db.util.get_resource_info(entity) + resource_type, _, _ = model.util.get_resource_info(entity) if resource_type == 'post': - return db.PostScore, lambda table: table.post_id + return model.PostScore, lambda table: table.post_id elif resource_type == 'comment': - return db.CommentScore, lambda table: table.comment_id + return model.CommentScore, lambda table: table.comment_id raise InvalidScoreTargetError() -def _get_score_entity(entity, user): +def _get_score_entity(entity: model.Base, user: model.User) -> model.Base: assert user - return db.util.get_aux_entity(db.session, _get_table_info, entity, user) + return model.util.get_aux_entity(db.session, _get_table_info, entity, user) -def delete_score(entity, user): +def delete_score(entity: model.Base, user: model.User) -> None: assert entity assert user score_entity = _get_score_entity(entity, user) @@ -33,7 +35,7 @@ def delete_score(entity, user): db.session.delete(score_entity) -def get_score(entity, user): +def get_score(entity: model.Base, user: model.User) -> int: assert entity assert user table, get_column = _get_table_info(entity) @@ -45,7 +47,7 @@ def get_score(entity, user): return row[0] if row else 0 -def set_score(entity, user, score): +def set_score(entity: model.Base, user: model.User, score: int) -> None: from szurubooru.func import favorites assert entity assert user diff --git a/server/szurubooru/func/serialization.py b/server/szurubooru/func/serialization.py new file mode 100644 index 00000000..df78959f --- /dev/null +++ b/server/szurubooru/func/serialization.py @@ -0,0 +1,27 @@ +from typing import Any, Optional, List, Dict, Callable +from szurubooru import db, model, rest, errors + + +def get_serialization_options(ctx: rest.Context) -> List[str]: + return ctx.get_param_as_list('fields', default=[]) + + +class BaseSerializer: + _fields = {} # type: Dict[str, Callable[[model.Base], Any]] + + def serialize(self, options: List[str]) -> Any: + field_factories = self._serializers() + if not options: + options = list(field_factories.keys()) + ret = {} + for key in options: + if key not in field_factories: + raise errors.ValidationError( + 'Invalid key: %r. Valid keys: %r.' % ( + key, list(sorted(field_factories.keys())))) + factory = field_factories[key] + ret[key] = factory() + return ret + + def _serializers(self) -> Dict[str, Callable[[], Any]]: + raise NotImplementedError() diff --git a/server/szurubooru/func/snapshots.py b/server/szurubooru/func/snapshots.py index f7efda9e..240c3bce 100644 --- a/server/szurubooru/func/snapshots.py +++ b/server/szurubooru/func/snapshots.py @@ -1,9 +1,10 @@ +from typing import Any, Optional, Dict, Callable from datetime import datetime -from szurubooru import db +from szurubooru import db, model from szurubooru.func import diff, users -def get_tag_category_snapshot(category): +def get_tag_category_snapshot(category: model.TagCategory) -> Dict[str, Any]: assert category return { 'name': category.name, @@ -12,7 +13,7 @@ def get_tag_category_snapshot(category): } -def get_tag_snapshot(tag): +def get_tag_snapshot(tag: model.Tag) -> Dict[str, Any]: assert tag return { 'names': [tag_name.name for tag_name in tag.names], @@ -22,7 +23,7 @@ def get_tag_snapshot(tag): } -def get_post_snapshot(post): +def get_post_snapshot(post: model.Post) -> Dict[str, Any]: assert post return { 'source': post.source, @@ -45,10 +46,11 @@ _snapshot_factories = { 'tag_category': lambda entity: get_tag_category_snapshot(entity), 'tag': lambda entity: get_tag_snapshot(entity), 'post': lambda entity: get_post_snapshot(entity), -} +} # type: Dict[model.Base, Callable[[model.Base], Dict[str ,Any]]] -def serialize_snapshot(snapshot, auth_user): +def serialize_snapshot( + snapshot: model.Snapshot, auth_user: model.User) -> Dict[str, Any]: assert snapshot return { 'operation': snapshot.operation, @@ -60,11 +62,14 @@ def serialize_snapshot(snapshot, auth_user): } -def _create(operation, entity, auth_user): +def _create( + operation: str, + entity: model.Base, + auth_user: Optional[model.User]) -> model.Snapshot: resource_type, resource_pkey, resource_name = ( - db.util.get_resource_info(entity)) + model.util.get_resource_info(entity)) - snapshot = db.Snapshot() + snapshot = model.Snapshot() snapshot.creation_time = datetime.utcnow() snapshot.operation = operation snapshot.resource_type = resource_type @@ -74,33 +79,33 @@ def _create(operation, entity, auth_user): return snapshot -def create(entity, auth_user): +def create(entity: model.Base, auth_user: Optional[model.User]) -> None: assert entity - snapshot = _create(db.Snapshot.OPERATION_CREATED, entity, auth_user) + snapshot = _create(model.Snapshot.OPERATION_CREATED, entity, auth_user) snapshot_factory = _snapshot_factories[snapshot.resource_type] snapshot.data = snapshot_factory(entity) db.session.add(snapshot) # pylint: disable=protected-access -def modify(entity, auth_user): +def modify(entity: model.Base, auth_user: Optional[model.User]) -> None: assert entity - model = next( + table = next( ( - model - for model in db.Base._decl_class_registry.values() - if hasattr(model, '__table__') - and model.__table__.fullname == entity.__table__.fullname + cls + for cls in model.Base._decl_class_registry.values() + if hasattr(cls, '__table__') + and cls.__table__.fullname == entity.__table__.fullname ), None) - assert model + assert table - snapshot = _create(db.Snapshot.OPERATION_MODIFIED, entity, auth_user) + snapshot = _create(model.Snapshot.OPERATION_MODIFIED, entity, auth_user) snapshot_factory = _snapshot_factories[snapshot.resource_type] detached_session = db.sessionmaker() - detached_entity = detached_session.query(model).get(snapshot.resource_pkey) + detached_entity = detached_session.query(table).get(snapshot.resource_pkey) assert detached_entity, 'Entity not found in DB, have you committed it?' detached_snapshot = snapshot_factory(detached_entity) detached_session.close() @@ -113,19 +118,23 @@ def modify(entity, auth_user): db.session.add(snapshot) -def delete(entity, auth_user): +def delete(entity: model.Base, auth_user: Optional[model.User]) -> None: assert entity - snapshot = _create(db.Snapshot.OPERATION_DELETED, entity, auth_user) + snapshot = _create(model.Snapshot.OPERATION_DELETED, entity, auth_user) snapshot_factory = _snapshot_factories[snapshot.resource_type] snapshot.data = snapshot_factory(entity) db.session.add(snapshot) -def merge(source_entity, target_entity, auth_user): +def merge( + source_entity: model.Base, + target_entity: model.Base, + auth_user: Optional[model.User]) -> None: assert source_entity assert target_entity - snapshot = _create(db.Snapshot.OPERATION_MERGED, source_entity, auth_user) + snapshot = _create( + model.Snapshot.OPERATION_MERGED, source_entity, auth_user) resource_type, _resource_pkey, resource_name = ( - db.util.get_resource_info(target_entity)) + model.util.get_resource_info(target_entity)) snapshot.data = [resource_type, resource_name] db.session.add(snapshot) diff --git a/server/szurubooru/func/tag_categories.py b/server/szurubooru/func/tag_categories.py index a9169dec..41c9c928 100644 --- a/server/szurubooru/func/tag_categories.py +++ b/server/szurubooru/func/tag_categories.py @@ -1,7 +1,8 @@ import re -import sqlalchemy -from szurubooru import config, db, errors -from szurubooru.func import util, cache +from typing import Any, Optional, Dict, List, Callable +import sqlalchemy as sa +from szurubooru import config, db, model, errors, rest +from szurubooru.func import util, serialization, cache DEFAULT_CATEGORY_NAME_CACHE_KEY = 'default-tag-category' @@ -27,28 +28,52 @@ class InvalidTagCategoryColorError(errors.ValidationError): pass -def _verify_name_validity(name): +def _verify_name_validity(name: str) -> None: name_regex = config.config['tag_category_name_regex'] if not re.match(name_regex, name): raise InvalidTagCategoryNameError( 'Name must satisfy regex %r.' % name_regex) -def serialize_category(category, options=None): - return util.serialize_entity( - category, - { - 'name': lambda: category.name, - 'version': lambda: category.version, - 'color': lambda: category.color, - 'usages': lambda: category.tag_count, - 'default': lambda: category.default, - }, - options) +class TagCategorySerializer(serialization.BaseSerializer): + def __init__(self, category: model.TagCategory) -> None: + self.category = category + + def _serializers(self) -> Dict[str, Callable[[], Any]]: + return { + 'name': self.serialize_name, + 'version': self.serialize_version, + 'color': self.serialize_color, + 'usages': self.serialize_usages, + 'default': self.serialize_default, + } + + def serialize_name(self) -> Any: + return self.category.name + + def serialize_version(self) -> Any: + return self.category.version + + def serialize_color(self) -> Any: + return self.category.color + + def serialize_usages(self) -> Any: + return self.category.tag_count + + def serialize_default(self) -> Any: + return self.category.default -def create_category(name, color): - category = db.TagCategory() +def serialize_category( + category: Optional[model.TagCategory], + options: List[str]=[]) -> Optional[rest.Response]: + if not category: + return None + return TagCategorySerializer(category).serialize(options) + + +def create_category(name: str, color: str) -> model.TagCategory: + category = model.TagCategory() update_category_name(category, name) update_category_color(category, color) if not get_all_categories(): @@ -56,64 +81,66 @@ def create_category(name, color): return category -def update_category_name(category, name): +def update_category_name(category: model.TagCategory, name: str) -> None: assert category if not name: raise InvalidTagCategoryNameError('Name cannot be empty.') - expr = sqlalchemy.func.lower(db.TagCategory.name) == name.lower() + expr = sa.func.lower(model.TagCategory.name) == name.lower() if category.tag_category_id: expr = expr & ( - db.TagCategory.tag_category_id != category.tag_category_id) - already_exists = db.session.query(db.TagCategory).filter(expr).count() > 0 + model.TagCategory.tag_category_id != category.tag_category_id) + already_exists = ( + db.session.query(model.TagCategory).filter(expr).count() > 0) if already_exists: raise TagCategoryAlreadyExistsError( 'A category with this name already exists.') - if util.value_exceeds_column_size(name, db.TagCategory.name): + if util.value_exceeds_column_size(name, model.TagCategory.name): raise InvalidTagCategoryNameError('Name is too long.') _verify_name_validity(name) category.name = name cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY) -def update_category_color(category, color): +def update_category_color(category: model.TagCategory, color: str) -> None: assert category if not color: raise InvalidTagCategoryColorError('Color cannot be empty.') if not re.match(r'^#?[0-9a-z]+$', color): raise InvalidTagCategoryColorError('Invalid color.') - if util.value_exceeds_column_size(color, db.TagCategory.color): + if util.value_exceeds_column_size(color, model.TagCategory.color): raise InvalidTagCategoryColorError('Color is too long.') category.color = color -def try_get_category_by_name(name, lock=False): +def try_get_category_by_name( + name: str, lock: bool=False) -> Optional[model.TagCategory]: query = db.session \ - .query(db.TagCategory) \ - .filter(sqlalchemy.func.lower(db.TagCategory.name) == name.lower()) + .query(model.TagCategory) \ + .filter(sa.func.lower(model.TagCategory.name) == name.lower()) if lock: query = query.with_lockmode('update') return query.one_or_none() -def get_category_by_name(name, lock=False): +def get_category_by_name(name: str, lock: bool=False) -> model.TagCategory: category = try_get_category_by_name(name, lock) if not category: raise TagCategoryNotFoundError('Tag category %r not found.' % name) return category -def get_all_category_names(): - return [row[0] for row in db.session.query(db.TagCategory.name).all()] +def get_all_category_names() -> List[str]: + return [row[0] for row in db.session.query(model.TagCategory.name).all()] -def get_all_categories(): - return db.session.query(db.TagCategory).all() +def get_all_categories() -> List[model.TagCategory]: + return db.session.query(model.TagCategory).all() -def try_get_default_category(lock=False): +def try_get_default_category(lock: bool=False) -> Optional[model.TagCategory]: query = db.session \ - .query(db.TagCategory) \ - .filter(db.TagCategory.default) + .query(model.TagCategory) \ + .filter(model.TagCategory.default) if lock: query = query.with_lockmode('update') category = query.first() @@ -121,22 +148,22 @@ def try_get_default_category(lock=False): # category, get the first record available. if not category: query = db.session \ - .query(db.TagCategory) \ - .order_by(db.TagCategory.tag_category_id.asc()) + .query(model.TagCategory) \ + .order_by(model.TagCategory.tag_category_id.asc()) if lock: query = query.with_lockmode('update') category = query.first() return category -def get_default_category(lock=False): +def get_default_category(lock: bool=False) -> model.TagCategory: category = try_get_default_category(lock) if not category: raise TagCategoryNotFoundError('No tag category created yet.') return category -def get_default_category_name(): +def get_default_category_name() -> str: if cache.has(DEFAULT_CATEGORY_NAME_CACHE_KEY): return cache.get(DEFAULT_CATEGORY_NAME_CACHE_KEY) default_category = get_default_category() @@ -145,7 +172,7 @@ def get_default_category_name(): return default_category_name -def set_default_category(category): +def set_default_category(category: model.TagCategory) -> None: assert category old_category = try_get_default_category(lock=True) if old_category: @@ -156,7 +183,7 @@ def set_default_category(category): cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY) -def delete_category(category): +def delete_category(category: model.TagCategory) -> None: assert category if len(get_all_category_names()) == 1: raise TagCategoryIsInUseError('Cannot delete the last category.') diff --git a/server/szurubooru/func/tags.py b/server/szurubooru/func/tags.py index 1665282b..fb043245 100644 --- a/server/szurubooru/func/tags.py +++ b/server/szurubooru/func/tags.py @@ -1,10 +1,11 @@ -import datetime import json import os import re -import sqlalchemy -from szurubooru import config, db, errors -from szurubooru.func import util, tag_categories +from typing import Any, Optional, Tuple, List, Dict, Callable +from datetime import datetime +import sqlalchemy as sa +from szurubooru import config, db, model, errors, rest +from szurubooru.func import util, tag_categories, serialization class TagNotFoundError(errors.NotFoundError): @@ -35,31 +36,32 @@ class InvalidTagDescriptionError(errors.ValidationError): pass -def _verify_name_validity(name): - if util.value_exceeds_column_size(name, db.TagName.name): +def _verify_name_validity(name: str) -> None: + if util.value_exceeds_column_size(name, model.TagName.name): raise InvalidTagNameError('Name is too long.') name_regex = config.config['tag_name_regex'] if not re.match(name_regex, name): raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex) -def _get_names(tag): +def _get_names(tag: model.Tag) -> List[str]: assert tag return [tag_name.name for tag_name in tag.names] -def _lower_list(names): +def _lower_list(names: List[str]) -> List[str]: return [name.lower() for name in names] -def _check_name_intersection(names1, names2, case_sensitive): +def _check_name_intersection( + names1: List[str], names2: List[str], case_sensitive: bool) -> bool: if not case_sensitive: names1 = _lower_list(names1) names2 = _lower_list(names2) return len(set(names1).intersection(names2)) > 0 -def sort_tags(tags): +def sort_tags(tags: List[model.Tag]) -> List[model.Tag]: default_category_name = tag_categories.get_default_category_name() return sorted( tags, @@ -70,35 +72,70 @@ def sort_tags(tags): ) -def serialize_tag(tag, options=None): - return util.serialize_entity( - tag, - { - 'names': lambda: [tag_name.name for tag_name in tag.names], - 'category': lambda: tag.category.name, - 'version': lambda: tag.version, - 'description': lambda: tag.description, - 'creationTime': lambda: tag.creation_time, - 'lastEditTime': lambda: tag.last_edit_time, - 'usages': lambda: tag.post_count, - 'suggestions': lambda: [ - relation.names[0].name - for relation in sort_tags(tag.suggestions)], - 'implications': lambda: [ - relation.names[0].name - for relation in sort_tags(tag.implications)], - }, - options) +class TagSerializer(serialization.BaseSerializer): + def __init__(self, tag: model.Tag) -> None: + self.tag = tag + + def _serializers(self) -> Dict[str, Callable[[], Any]]: + return { + 'names': self.serialize_names, + 'category': self.serialize_category, + 'version': self.serialize_version, + 'description': self.serialize_description, + 'creationTime': self.serialize_creation_time, + 'lastEditTime': self.serialize_last_edit_time, + 'usages': self.serialize_usages, + 'suggestions': self.serialize_suggestions, + 'implications': self.serialize_implications, + } + + def serialize_names(self) -> Any: + return [tag_name.name for tag_name in self.tag.names] + + def serialize_category(self) -> Any: + return self.tag.category.name + + def serialize_version(self) -> Any: + return self.tag.version + + def serialize_description(self) -> Any: + return self.tag.description + + def serialize_creation_time(self) -> Any: + return self.tag.creation_time + + def serialize_last_edit_time(self) -> Any: + return self.tag.last_edit_time + + def serialize_usages(self) -> Any: + return self.tag.post_count + + def serialize_suggestions(self) -> Any: + return [ + relation.names[0].name + for relation in sort_tags(self.tag.suggestions)] + + def serialize_implications(self) -> Any: + return [ + relation.names[0].name + for relation in sort_tags(self.tag.implications)] -def export_to_json(): - tags = {} - categories = {} +def serialize_tag( + tag: model.Tag, options: List[str]=[]) -> Optional[rest.Response]: + if not tag: + return None + return TagSerializer(tag).serialize(options) + + +def export_to_json() -> None: + tags = {} # type: Dict[int, Any] + categories = {} # type: Dict[int, Any] for result in db.session.query( - db.TagCategory.tag_category_id, - db.TagCategory.name, - db.TagCategory.color).all(): + model.TagCategory.tag_category_id, + model.TagCategory.name, + model.TagCategory.color).all(): categories[result[0]] = { 'name': result[1], 'color': result[2], @@ -106,8 +143,8 @@ def export_to_json(): for result in ( db.session - .query(db.TagName.tag_id, db.TagName.name) - .order_by(db.TagName.order) + .query(model.TagName.tag_id, model.TagName.name) + .order_by(model.TagName.order) .all()): if not result[0] in tags: tags[result[0]] = {'names': []} @@ -115,8 +152,10 @@ def export_to_json(): for result in ( db.session - .query(db.TagSuggestion.parent_id, db.TagName.name) - .join(db.TagName, db.TagName.tag_id == db.TagSuggestion.child_id) + .query(model.TagSuggestion.parent_id, model.TagName.name) + .join( + model.TagName, + model.TagName.tag_id == model.TagSuggestion.child_id) .all()): if 'suggestions' not in tags[result[0]]: tags[result[0]]['suggestions'] = [] @@ -124,17 +163,19 @@ def export_to_json(): for result in ( db.session - .query(db.TagImplication.parent_id, db.TagName.name) - .join(db.TagName, db.TagName.tag_id == db.TagImplication.child_id) + .query(model.TagImplication.parent_id, model.TagName.name) + .join( + model.TagName, + model.TagName.tag_id == model.TagImplication.child_id) .all()): if 'implications' not in tags[result[0]]: tags[result[0]]['implications'] = [] tags[result[0]]['implications'].append(result[1]) for result in db.session.query( - db.Tag.tag_id, - db.Tag.category_id, - db.Tag.post_count).all(): + model.Tag.tag_id, + model.Tag.category_id, + model.Tag.post_count).all(): tags[result[0]]['category'] = categories[result[1]]['name'] tags[result[0]]['usages'] = result[2] @@ -148,33 +189,34 @@ def export_to_json(): handle.write(json.dumps(output, separators=(',', ':'))) -def try_get_tag_by_name(name): +def try_get_tag_by_name(name: str) -> Optional[model.Tag]: return ( db.session - .query(db.Tag) - .join(db.TagName) - .filter(sqlalchemy.func.lower(db.TagName.name) == name.lower()) + .query(model.Tag) + .join(model.TagName) + .filter(sa.func.lower(model.TagName.name) == name.lower()) .one_or_none()) -def get_tag_by_name(name): +def get_tag_by_name(name: str) -> model.Tag: tag = try_get_tag_by_name(name) if not tag: raise TagNotFoundError('Tag %r not found.' % name) return tag -def get_tags_by_names(names): +def get_tags_by_names(names: List[str]) -> List[model.Tag]: names = util.icase_unique(names) if len(names) == 0: return [] - expr = sqlalchemy.sql.false() + expr = sa.sql.false() for name in names: - expr = expr | (sqlalchemy.func.lower(db.TagName.name) == name.lower()) - return db.session.query(db.Tag).join(db.TagName).filter(expr).all() + expr = expr | (sa.func.lower(model.TagName.name) == name.lower()) + return db.session.query(model.Tag).join(model.TagName).filter(expr).all() -def get_or_create_tags_by_names(names): +def get_or_create_tags_by_names( + names: List[str]) -> Tuple[List[model.Tag], List[model.Tag]]: names = util.icase_unique(names) existing_tags = get_tags_by_names(names) new_tags = [] @@ -197,86 +239,87 @@ def get_or_create_tags_by_names(names): return existing_tags, new_tags -def get_tag_siblings(tag): +def get_tag_siblings(tag: model.Tag) -> List[model.Tag]: assert tag - tag_alias = sqlalchemy.orm.aliased(db.Tag) - pt_alias1 = sqlalchemy.orm.aliased(db.PostTag) - pt_alias2 = sqlalchemy.orm.aliased(db.PostTag) + tag_alias = sa.orm.aliased(model.Tag) + pt_alias1 = sa.orm.aliased(model.PostTag) + pt_alias2 = sa.orm.aliased(model.PostTag) result = ( db.session - .query(tag_alias, sqlalchemy.func.count(pt_alias2.post_id)) + .query(tag_alias, sa.func.count(pt_alias2.post_id)) .join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id) .join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id) .filter(pt_alias2.tag_id == tag.tag_id) .filter(pt_alias1.tag_id != tag.tag_id) .group_by(tag_alias.tag_id) - .order_by(sqlalchemy.func.count(pt_alias2.post_id).desc()) + .order_by(sa.func.count(pt_alias2.post_id).desc()) .limit(50)) return result -def delete(source_tag): +def delete(source_tag: model.Tag) -> None: assert source_tag db.session.execute( - sqlalchemy.sql.expression.delete(db.TagSuggestion) - .where(db.TagSuggestion.child_id == source_tag.tag_id)) + sa.sql.expression.delete(model.TagSuggestion) + .where(model.TagSuggestion.child_id == source_tag.tag_id)) db.session.execute( - sqlalchemy.sql.expression.delete(db.TagImplication) - .where(db.TagImplication.child_id == source_tag.tag_id)) + sa.sql.expression.delete(model.TagImplication) + .where(model.TagImplication.child_id == source_tag.tag_id)) db.session.delete(source_tag) -def merge_tags(source_tag, target_tag): +def merge_tags(source_tag: model.Tag, target_tag: model.Tag) -> None: assert source_tag assert target_tag if source_tag.tag_id == target_tag.tag_id: raise InvalidTagRelationError('Cannot merge tag with itself.') - def merge_posts(source_tag_id, target_tag_id): - alias1 = db.PostTag - alias2 = sqlalchemy.orm.util.aliased(db.PostTag) + def merge_posts(source_tag_id: int, target_tag_id: int) -> None: + alias1 = model.PostTag + alias2 = sa.orm.util.aliased(model.PostTag) update_stmt = ( - sqlalchemy.sql.expression.update(alias1) + sa.sql.expression.update(alias1) .where(alias1.tag_id == source_tag_id)) update_stmt = ( update_stmt .where( - ~sqlalchemy.exists() + ~sa.exists() .where(alias1.post_id == alias2.post_id) .where(alias2.tag_id == target_tag_id))) update_stmt = update_stmt.values(tag_id=target_tag_id) db.session.execute(update_stmt) - def merge_relations(table, source_tag_id, target_tag_id): + def merge_relations( + table: model.Base, source_tag_id: int, target_tag_id: int) -> None: alias1 = table - alias2 = sqlalchemy.orm.util.aliased(table) + alias2 = sa.orm.util.aliased(table) update_stmt = ( - sqlalchemy.sql.expression.update(alias1) + sa.sql.expression.update(alias1) .where(alias1.parent_id == source_tag_id) .where(alias1.child_id != target_tag_id) .where( - ~sqlalchemy.exists() + ~sa.exists() .where(alias2.child_id == alias1.child_id) .where(alias2.parent_id == target_tag_id)) .values(parent_id=target_tag_id)) db.session.execute(update_stmt) update_stmt = ( - sqlalchemy.sql.expression.update(alias1) + sa.sql.expression.update(alias1) .where(alias1.child_id == source_tag_id) .where(alias1.parent_id != target_tag_id) .where( - ~sqlalchemy.exists() + ~sa.exists() .where(alias2.parent_id == alias1.parent_id) .where(alias2.child_id == target_tag_id)) .values(child_id=target_tag_id)) db.session.execute(update_stmt) - def merge_suggestions(source_tag_id, target_tag_id): - merge_relations(db.TagSuggestion, source_tag_id, target_tag_id) + def merge_suggestions(source_tag_id: int, target_tag_id: int) -> None: + merge_relations(model.TagSuggestion, source_tag_id, target_tag_id) - def merge_implications(source_tag_id, target_tag_id): - merge_relations(db.TagImplication, source_tag_id, target_tag_id) + def merge_implications(source_tag_id: int, target_tag_id: int) -> None: + merge_relations(model.TagImplication, source_tag_id, target_tag_id) merge_posts(source_tag.tag_id, target_tag.tag_id) merge_suggestions(source_tag.tag_id, target_tag.tag_id) @@ -284,9 +327,13 @@ def merge_tags(source_tag, target_tag): delete(source_tag) -def create_tag(names, category_name, suggestions, implications): - tag = db.Tag() - tag.creation_time = datetime.datetime.utcnow() +def create_tag( + names: List[str], + category_name: str, + suggestions: List[str], + implications: List[str]) -> model.Tag: + tag = model.Tag() + tag.creation_time = datetime.utcnow() update_tag_names(tag, names) update_tag_category_name(tag, category_name) update_tag_suggestions(tag, suggestions) @@ -294,12 +341,12 @@ def create_tag(names, category_name, suggestions, implications): return tag -def update_tag_category_name(tag, category_name): +def update_tag_category_name(tag: model.Tag, category_name: str) -> None: assert tag tag.category = tag_categories.get_category_by_name(category_name) -def update_tag_names(tag, names): +def update_tag_names(tag: model.Tag, names: List[str]) -> None: # sanitize assert tag names = util.icase_unique([name for name in names if name]) @@ -309,12 +356,12 @@ def update_tag_names(tag, names): _verify_name_validity(name) # check for existing tags - expr = sqlalchemy.sql.false() + expr = sa.sql.false() for name in names: - expr = expr | (sqlalchemy.func.lower(db.TagName.name) == name.lower()) + expr = expr | (sa.func.lower(model.TagName.name) == name.lower()) if tag.tag_id: - expr = expr & (db.TagName.tag_id != tag.tag_id) - existing_tags = db.session.query(db.TagName).filter(expr).all() + expr = expr & (model.TagName.tag_id != tag.tag_id) + existing_tags = db.session.query(model.TagName).filter(expr).all() if len(existing_tags): raise TagAlreadyExistsError( 'One of names is already used by another tag.') @@ -326,7 +373,7 @@ def update_tag_names(tag, names): # add wanted items for name in names: if not _check_name_intersection(_get_names(tag), [name], True): - tag.names.append(db.TagName(name, None)) + tag.names.append(model.TagName(name, -1)) # set alias order to match the request for i, name in enumerate(names): @@ -336,7 +383,7 @@ def update_tag_names(tag, names): # TODO: what to do with relations that do not yet exist? -def update_tag_implications(tag, relations): +def update_tag_implications(tag: model.Tag, relations: List[str]) -> None: assert tag if _check_name_intersection(_get_names(tag), relations, False): raise InvalidTagRelationError('Tag cannot imply itself.') @@ -344,15 +391,15 @@ def update_tag_implications(tag, relations): # TODO: what to do with relations that do not yet exist? -def update_tag_suggestions(tag, relations): +def update_tag_suggestions(tag: model.Tag, relations: List[str]) -> None: assert tag if _check_name_intersection(_get_names(tag), relations, False): raise InvalidTagRelationError('Tag cannot suggest itself.') tag.suggestions = get_tags_by_names(relations) -def update_tag_description(tag, description): +def update_tag_description(tag: model.Tag, description: str) -> None: assert tag - if util.value_exceeds_column_size(description, db.Tag.description): + if util.value_exceeds_column_size(description, model.Tag.description): raise InvalidTagDescriptionError('Description is too long.') - tag.description = description + tag.description = description or None diff --git a/server/szurubooru/func/users.py b/server/szurubooru/func/users.py index 5547bbae..fd0c6240 100644 --- a/server/szurubooru/func/users.py +++ b/server/szurubooru/func/users.py @@ -1,8 +1,9 @@ -import datetime import re -from sqlalchemy import func -from szurubooru import config, db, errors -from szurubooru.func import auth, util, files, images +from typing import Any, Optional, Union, List, Dict, Callable +from datetime import datetime +import sqlalchemy as sa +from szurubooru import config, db, model, errors, rest +from szurubooru.func import auth, util, serialization, files, images class UserNotFoundError(errors.NotFoundError): @@ -33,11 +34,11 @@ class InvalidAvatarError(errors.ValidationError): pass -def get_avatar_path(user_name): +def get_avatar_path(user_name: str) -> str: return 'avatars/' + user_name.lower() + '.png' -def get_avatar_url(user): +def get_avatar_url(user: model.User) -> str: assert user if user.avatar_style == user.AVATAR_GRAVATAR: assert user.email or user.name @@ -49,7 +50,10 @@ def get_avatar_url(user): config.config['data_url'].rstrip('/'), user.name.lower()) -def get_email(user, auth_user, force_show_email): +def get_email( + user: model.User, + auth_user: model.User, + force_show_email: bool) -> Union[bool, str]: assert user assert auth_user if not force_show_email \ @@ -59,7 +63,8 @@ def get_email(user, auth_user, force_show_email): return user.email -def get_liked_post_count(user, auth_user): +def get_liked_post_count( + user: model.User, auth_user: model.User) -> Union[bool, int]: assert user assert auth_user if auth_user.user_id != user.user_id: @@ -67,7 +72,8 @@ def get_liked_post_count(user, auth_user): return user.liked_post_count -def get_disliked_post_count(user, auth_user): +def get_disliked_post_count( + user: model.User, auth_user: model.User) -> Union[bool, int]: assert user assert auth_user if auth_user.user_id != user.user_id: @@ -75,91 +81,144 @@ def get_disliked_post_count(user, auth_user): return user.disliked_post_count -def serialize_user(user, auth_user, options=None, force_show_email=False): - return util.serialize_entity( - user, - { - 'name': lambda: user.name, - 'creationTime': lambda: user.creation_time, - 'lastLoginTime': lambda: user.last_login_time, - 'version': lambda: user.version, - 'rank': lambda: user.rank, - 'avatarStyle': lambda: user.avatar_style, - 'avatarUrl': lambda: get_avatar_url(user), - 'commentCount': lambda: user.comment_count, - 'uploadedPostCount': lambda: user.post_count, - 'favoritePostCount': lambda: user.favorite_post_count, - 'likedPostCount': - lambda: get_liked_post_count(user, auth_user), - 'dislikedPostCount': - lambda: get_disliked_post_count(user, auth_user), - 'email': - lambda: get_email(user, auth_user, force_show_email), - }, - options) +class UserSerializer(serialization.BaseSerializer): + def __init__( + self, + user: model.User, + auth_user: model.User, + force_show_email: bool=False) -> None: + self.user = user + self.auth_user = auth_user + self.force_show_email = force_show_email + + def _serializers(self) -> Dict[str, Callable[[], Any]]: + return { + 'name': self.serialize_name, + 'creationTime': self.serialize_creation_time, + 'lastLoginTime': self.serialize_last_login_time, + 'version': self.serialize_version, + 'rank': self.serialize_rank, + 'avatarStyle': self.serialize_avatar_style, + 'avatarUrl': self.serialize_avatar_url, + 'commentCount': self.serialize_comment_count, + 'uploadedPostCount': self.serialize_uploaded_post_count, + 'favoritePostCount': self.serialize_favorite_post_count, + 'likedPostCount': self.serialize_liked_post_count, + 'dislikedPostCount': self.serialize_disliked_post_count, + 'email': self.serialize_email, + } + + def serialize_name(self) -> Any: + return self.user.name + + def serialize_creation_time(self) -> Any: + return self.user.creation_time + + def serialize_last_login_time(self) -> Any: + return self.user.last_login_time + + def serialize_version(self) -> Any: + return self.user.version + + def serialize_rank(self) -> Any: + return self.user.rank + + def serialize_avatar_style(self) -> Any: + return self.user.avatar_style + + def serialize_avatar_url(self) -> Any: + return get_avatar_url(self.user) + + def serialize_comment_count(self) -> Any: + return self.user.comment_count + + def serialize_uploaded_post_count(self) -> Any: + return self.user.post_count + + def serialize_favorite_post_count(self) -> Any: + return self.user.favorite_post_count + + def serialize_liked_post_count(self) -> Any: + return get_liked_post_count(self.user, self.auth_user) + + def serialize_disliked_post_count(self) -> Any: + return get_disliked_post_count(self.user, self.auth_user) + + def serialize_email(self) -> Any: + return get_email(self.user, self.auth_user, self.force_show_email) -def serialize_micro_user(user, auth_user): +def serialize_user( + user: Optional[model.User], + auth_user: model.User, + options: List[str]=[], + force_show_email: bool=False) -> Optional[rest.Response]: + if not user: + return None + return UserSerializer(user, auth_user, force_show_email).serialize(options) + + +def serialize_micro_user( + user: Optional[model.User], + auth_user: model.User) -> Optional[rest.Response]: return serialize_user( - user, - auth_user=auth_user, - options=['name', 'avatarUrl']) + user, auth_user=auth_user, options=['name', 'avatarUrl']) -def get_user_count(): - return db.session.query(db.User).count() +def get_user_count() -> int: + return db.session.query(model.User).count() -def try_get_user_by_name(name): +def try_get_user_by_name(name: str) -> Optional[model.User]: return db.session \ - .query(db.User) \ - .filter(func.lower(db.User.name) == func.lower(name)) \ + .query(model.User) \ + .filter(sa.func.lower(model.User.name) == sa.func.lower(name)) \ .one_or_none() -def get_user_by_name(name): +def get_user_by_name(name: str) -> model.User: user = try_get_user_by_name(name) if not user: raise UserNotFoundError('User %r not found.' % name) return user -def try_get_user_by_name_or_email(name_or_email): +def try_get_user_by_name_or_email(name_or_email: str) -> Optional[model.User]: return ( db.session - .query(db.User) + .query(model.User) .filter( - (func.lower(db.User.name) == func.lower(name_or_email)) | - (func.lower(db.User.email) == func.lower(name_or_email))) + (sa.func.lower(model.User.name) == sa.func.lower(name_or_email)) | + (sa.func.lower(model.User.email) == sa.func.lower(name_or_email))) .one_or_none()) -def get_user_by_name_or_email(name_or_email): +def get_user_by_name_or_email(name_or_email: str) -> model.User: user = try_get_user_by_name_or_email(name_or_email) if not user: raise UserNotFoundError('User %r not found.' % name_or_email) return user -def create_user(name, password, email): - user = db.User() +def create_user(name: str, password: str, email: str) -> model.User: + user = model.User() update_user_name(user, name) update_user_password(user, password) update_user_email(user, email) if get_user_count() > 0: user.rank = util.flip(auth.RANK_MAP)[config.config['default_rank']] else: - user.rank = db.User.RANK_ADMINISTRATOR - user.creation_time = datetime.datetime.utcnow() - user.avatar_style = db.User.AVATAR_GRAVATAR + user.rank = model.User.RANK_ADMINISTRATOR + user.creation_time = datetime.utcnow() + user.avatar_style = model.User.AVATAR_GRAVATAR return user -def update_user_name(user, name): +def update_user_name(user: model.User, name: str) -> None: assert user if not name: raise InvalidUserNameError('Name cannot be empty.') - if util.value_exceeds_column_size(name, db.User.name): + if util.value_exceeds_column_size(name, model.User.name): raise InvalidUserNameError('User name is too long.') name = name.strip() name_regex = config.config['user_name_regex'] @@ -174,7 +233,7 @@ def update_user_name(user, name): user.name = name -def update_user_password(user, password): +def update_user_password(user: model.User, password: str) -> None: assert user if not password: raise InvalidPasswordError('Password cannot be empty.') @@ -186,20 +245,18 @@ def update_user_password(user, password): user.password_hash = auth.get_password_hash(user.password_salt, password) -def update_user_email(user, email): +def update_user_email(user: model.User, email: str) -> None: assert user - if email: - email = email.strip() - if not email: - email = None - if email and util.value_exceeds_column_size(email, db.User.email): + email = email.strip() + if util.value_exceeds_column_size(email, model.User.email): raise InvalidEmailError('Email is too long.') if not util.is_valid_email(email): raise InvalidEmailError('E-mail is invalid.') - user.email = email + user.email = email or None -def update_user_rank(user, rank, auth_user): +def update_user_rank( + user: model.User, rank: str, auth_user: model.User) -> None: assert user if not rank: raise InvalidRankError('Rank cannot be empty.') @@ -208,7 +265,7 @@ def update_user_rank(user, rank, auth_user): if not rank: raise InvalidRankError( 'Rank can be either of %r.' % all_ranks) - if rank in (db.User.RANK_ANONYMOUS, db.User.RANK_NOBODY): + if rank in (model.User.RANK_ANONYMOUS, model.User.RANK_NOBODY): raise InvalidRankError('Rank %r cannot be used.' % auth.RANK_MAP[rank]) if all_ranks.index(auth_user.rank) \ < all_ranks.index(rank) and get_user_count() > 0: @@ -216,7 +273,10 @@ def update_user_rank(user, rank, auth_user): user.rank = rank -def update_user_avatar(user, avatar_style, avatar_content=None): +def update_user_avatar( + user: model.User, + avatar_style: str, + avatar_content: Optional[bytes]=None) -> None: assert user if avatar_style == 'gravatar': user.avatar_style = user.AVATAR_GRAVATAR @@ -238,12 +298,12 @@ def update_user_avatar(user, avatar_style, avatar_content=None): avatar_style, ['gravatar', 'manual'])) -def bump_user_login_time(user): +def bump_user_login_time(user: model.User) -> None: assert user - user.last_login_time = datetime.datetime.utcnow() + user.last_login_time = datetime.utcnow() -def reset_user_password(user): +def reset_user_password(user: model.User) -> str: assert user password = auth.create_password() user.password_salt = auth.create_password() diff --git a/server/szurubooru/func/util.py b/server/szurubooru/func/util.py index 11caedd2..40d19d39 100644 --- a/server/szurubooru/func/util.py +++ b/server/szurubooru/func/util.py @@ -2,52 +2,39 @@ import os import hashlib import re import tempfile +from typing import ( + Any, Optional, Union, Tuple, List, Dict, Generator, Callable, TypeVar) from datetime import datetime, timedelta from contextlib import contextmanager from szurubooru import errors -def snake_case_to_lower_camel_case(text): +T = TypeVar('T') + + +def snake_case_to_lower_camel_case(text: str) -> str: components = text.split('_') return components[0].lower() + \ ''.join(word[0].upper() + word[1:].lower() for word in components[1:]) -def snake_case_to_upper_train_case(text): +def snake_case_to_upper_train_case(text: str) -> str: return '-'.join( word[0].upper() + word[1:].lower() for word in text.split('_')) -def snake_case_to_lower_camel_case_keys(source): +def snake_case_to_lower_camel_case_keys( + source: Dict[str, Any]) -> Dict[str, Any]: target = {} for key, value in source.items(): target[snake_case_to_lower_camel_case(key)] = value return target -def get_serialization_options(ctx): - return ctx.get_param_as_list('fields', required=False, default=None) - - -def serialize_entity(entity, field_factories, options): - if not entity: - return None - if not options or len(options) == 0: - options = field_factories.keys() - ret = {} - for key in options: - if key not in field_factories: - raise errors.ValidationError('Invalid key: %r. Valid keys: %r.' % ( - key, list(sorted(field_factories.keys())))) - factory = field_factories[key] - ret[key] = factory() - return ret - - @contextmanager -def create_temp_file(**kwargs): - (handle, path) = tempfile.mkstemp(**kwargs) - os.close(handle) +def create_temp_file(**kwargs: Any) -> Generator: + (descriptor, path) = tempfile.mkstemp(**kwargs) + os.close(descriptor) try: with open(path, 'r+b') as handle: yield handle @@ -55,17 +42,15 @@ def create_temp_file(**kwargs): os.remove(path) -def unalias_dict(input_dict): - output_dict = {} - for key_list, value in input_dict.items(): - if isinstance(key_list, str): - key_list = [key_list] - for key in key_list: - output_dict[key] = value +def unalias_dict(source: List[Tuple[List[str], T]]) -> Dict[str, T]: + output_dict = {} # type: Dict[str, T] + for aliases, value in source: + for alias in aliases: + output_dict[alias] = value return output_dict -def get_md5(source): +def get_md5(source: Union[str, bytes]) -> str: if not isinstance(source, bytes): source = source.encode('utf-8') md5 = hashlib.md5() @@ -73,7 +58,7 @@ def get_md5(source): return md5.hexdigest() -def get_sha1(source): +def get_sha1(source: Union[str, bytes]) -> str: if not isinstance(source, bytes): source = source.encode('utf-8') sha1 = hashlib.sha1() @@ -81,24 +66,25 @@ def get_sha1(source): return sha1.hexdigest() -def flip(source): +def flip(source: Dict[Any, Any]) -> Dict[Any, Any]: return {v: k for k, v in source.items()} -def is_valid_email(email): +def is_valid_email(email: Optional[str]) -> bool: ''' Return whether given email address is valid or empty. ''' - return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email) + return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email) is not None class dotdict(dict): # pylint: disable=invalid-name ''' dot.notation access to dictionary attributes. ''' - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> Any: return self.get(attr) + __setattr__ = dict.__setitem__ __delattr__ = dict.__delitem__ -def parse_time_range(value): +def parse_time_range(value: str) -> Tuple[datetime, datetime]: ''' Return tuple containing min/max time for given text representation. ''' one_day = timedelta(days=1) one_second = timedelta(seconds=1) @@ -146,9 +132,9 @@ def parse_time_range(value): raise errors.ValidationError('Invalid date format: %r.' % value) -def icase_unique(source): - target = [] - target_low = [] +def icase_unique(source: List[str]) -> List[str]: + target = [] # type: List[str] + target_low = [] # type: List[str] for source_item in source: if source_item.lower() not in target_low: target.append(source_item) @@ -156,7 +142,7 @@ def icase_unique(source): return target -def value_exceeds_column_size(value, column): +def value_exceeds_column_size(value: Optional[str], column: Any) -> bool: if not value: return False max_length = column.property.columns[0].type.length @@ -165,6 +151,6 @@ def value_exceeds_column_size(value, column): return len(value) > max_length -def chunks(source_list, part_size): +def chunks(source_list: List[Any], part_size: int) -> Generator: for i in range(0, len(source_list), part_size): yield source_list[i:i + part_size] diff --git a/server/szurubooru/func/versions.py b/server/szurubooru/func/versions.py index ee84407b..459b0256 100644 --- a/server/szurubooru/func/versions.py +++ b/server/szurubooru/func/versions.py @@ -1,8 +1,11 @@ -from szurubooru import errors +from szurubooru import errors, rest, model -def verify_version(entity, context, field_name='version'): - actual_version = context.get_param_as_int(field_name, required=True) +def verify_version( + entity: model.Base, + context: rest.Context, + field_name: str='version') -> None: + actual_version = context.get_param_as_int(field_name) expected_version = entity.version if actual_version != expected_version: raise errors.IntegrityError( @@ -10,5 +13,5 @@ def verify_version(entity, context, field_name='version'): 'Please try again.') -def bump_version(entity): +def bump_version(entity: model.Base) -> None: entity.version = entity.version + 1 diff --git a/server/szurubooru/middleware/authenticator.py b/server/szurubooru/middleware/authenticator.py index f6b7853f..2c5ac087 100644 --- a/server/szurubooru/middleware/authenticator.py +++ b/server/szurubooru/middleware/authenticator.py @@ -1,11 +1,11 @@ import base64 -from szurubooru import db, errors +from typing import Optional +from szurubooru import db, model, errors, rest from szurubooru.func import auth, users -from szurubooru.rest import middleware from szurubooru.rest.errors import HttpBadRequest -def _authenticate(username, password): +def _authenticate(username: str, password: str) -> model.User: ''' Try to authenticate user. Throw AuthError for invalid users. ''' user = users.get_user_by_name(username) if not auth.is_valid_password(user, password): @@ -13,16 +13,9 @@ def _authenticate(username, password): return user -def _create_anonymous_user(): - user = db.User() - user.name = None - user.rank = 'anonymous' - return user - - -def _get_user(ctx): +def _get_user(ctx: rest.Context) -> Optional[model.User]: if not ctx.has_header('Authorization'): - return _create_anonymous_user() + return None try: auth_type, credentials = ctx.get_header('Authorization').split(' ', 1) @@ -41,10 +34,12 @@ def _get_user(ctx): msg.format(ctx.get_header('Authorization'), str(err))) -@middleware.pre_hook -def process_request(ctx): +@rest.middleware.pre_hook +def process_request(ctx: rest.Context) -> None: ''' Bind the user to request. Update last login time if needed. ''' - ctx.user = _get_user(ctx) - if ctx.get_param_as_bool('bump-login') and ctx.user.user_id: + auth_user = _get_user(ctx) + if auth_user: + ctx.user = auth_user + if ctx.get_param_as_bool('bump-login', default=False) and ctx.user.user_id: users.bump_user_login_time(ctx.user) ctx.session.commit() diff --git a/server/szurubooru/middleware/cache_purger.py b/server/szurubooru/middleware/cache_purger.py index e26b3bae..d83fb845 100644 --- a/server/szurubooru/middleware/cache_purger.py +++ b/server/szurubooru/middleware/cache_purger.py @@ -1,8 +1,9 @@ +from szurubooru import rest from szurubooru.func import cache from szurubooru.rest import middleware @middleware.pre_hook -def process_request(ctx): +def process_request(ctx: rest.Context) -> None: if ctx.method != 'GET': cache.purge() diff --git a/server/szurubooru/middleware/request_logger.py b/server/szurubooru/middleware/request_logger.py index 47b43ab5..54e40e4a 100644 --- a/server/szurubooru/middleware/request_logger.py +++ b/server/szurubooru/middleware/request_logger.py @@ -1,5 +1,5 @@ import logging -from szurubooru import db +from szurubooru import db, rest from szurubooru.rest import middleware @@ -7,12 +7,12 @@ logger = logging.getLogger(__name__) @middleware.pre_hook -def process_request(_ctx): +def process_request(_ctx: rest.Context) -> None: db.reset_query_count() @middleware.post_hook -def process_response(ctx): +def process_response(ctx: rest.Context) -> None: logger.info( '%s %s (user=%s, queries=%d)', ctx.method, diff --git a/server/szurubooru/migrations/env.py b/server/szurubooru/migrations/env.py index 1359ab8a..a4257d48 100644 --- a/server/szurubooru/migrations/env.py +++ b/server/szurubooru/migrations/env.py @@ -2,7 +2,7 @@ import os import sys import alembic -import sqlalchemy +import sqlalchemy as sa import logging.config # make szurubooru module importable @@ -48,7 +48,7 @@ def run_migrations_online(): In this scenario we need to create an Engine and associate a connection with the context. ''' - connectable = sqlalchemy.engine_from_config( + connectable = sa.engine_from_config( alembic_config.get_section(alembic_config.config_ini_section), prefix='sqlalchemy.', poolclass=sqlalchemy.pool.NullPool) diff --git a/server/szurubooru/model/__init__.py b/server/szurubooru/model/__init__.py new file mode 100644 index 00000000..ad2231c2 --- /dev/null +++ b/server/szurubooru/model/__init__.py @@ -0,0 +1,15 @@ +from szurubooru.model.base import Base +from szurubooru.model.user import User +from szurubooru.model.tag_category import TagCategory +from szurubooru.model.tag import Tag, TagName, TagSuggestion, TagImplication +from szurubooru.model.post import ( + Post, + PostTag, + PostRelation, + PostFavorite, + PostScore, + PostNote, + PostFeature) +from szurubooru.model.comment import Comment, CommentScore +from szurubooru.model.snapshot import Snapshot +import szurubooru.model.util diff --git a/server/szurubooru/db/base.py b/server/szurubooru/model/base.py similarity index 100% rename from server/szurubooru/db/base.py rename to server/szurubooru/model/base.py diff --git a/server/szurubooru/db/comment.py b/server/szurubooru/model/comment.py similarity index 84% rename from server/szurubooru/db/comment.py rename to server/szurubooru/model/comment.py index bf325859..55c1596b 100644 --- a/server/szurubooru/db/comment.py +++ b/server/szurubooru/model/comment.py @@ -1,7 +1,8 @@ from sqlalchemy import Column, Integer, DateTime, UnicodeText, ForeignKey from sqlalchemy.orm import relationship, backref from sqlalchemy.sql.expression import func -from szurubooru.db.base import Base +from szurubooru.db import get_session +from szurubooru.model.base import Base class CommentScore(Base): @@ -48,12 +49,12 @@ class Comment(Base): 'CommentScore', cascade='all, delete-orphan', lazy='joined') @property - def score(self): - from szurubooru.db import session - return session \ - .query(func.sum(CommentScore.score)) \ - .filter(CommentScore.comment_id == self.comment_id) \ - .one()[0] or 0 + def score(self) -> int: + return ( + get_session() + .query(func.sum(CommentScore.score)) + .filter(CommentScore.comment_id == self.comment_id) + .one()[0] or 0) __mapper_args__ = { 'version_id_col': version, diff --git a/server/szurubooru/db/post.py b/server/szurubooru/model/post.py similarity index 95% rename from server/szurubooru/db/post.py rename to server/szurubooru/model/post.py index f0c9f91f..23f52b57 100644 --- a/server/szurubooru/db/post.py +++ b/server/szurubooru/model/post.py @@ -3,8 +3,8 @@ from sqlalchemy import ( Column, Integer, DateTime, Unicode, UnicodeText, PickleType, ForeignKey) from sqlalchemy.orm import ( relationship, column_property, object_session, backref) -from szurubooru.db.base import Base -from szurubooru.db.comment import Comment +from szurubooru.model.base import Base +from szurubooru.model.comment import Comment class PostFeature(Base): @@ -17,10 +17,9 @@ class PostFeature(Base): 'user_id', Integer, ForeignKey('user.id'), nullable=False, index=True) time = Column('time', DateTime, nullable=False) - post = relationship('Post') + post = relationship('Post') # type: Post user = relationship( - 'User', - backref=backref('post_features', cascade='all, delete-orphan')) + 'User', backref=backref('post_features', cascade='all, delete-orphan')) class PostScore(Base): @@ -104,7 +103,7 @@ class PostRelation(Base): nullable=False, index=True) - def __init__(self, parent_id, child_id): + def __init__(self, parent_id: int, child_id: int) -> None: self.parent_id = parent_id self.child_id = child_id @@ -127,7 +126,7 @@ class PostTag(Base): nullable=False, index=True) - def __init__(self, post_id, tag_id): + def __init__(self, post_id: int, tag_id: int) -> None: self.post_id = post_id self.tag_id = tag_id @@ -197,7 +196,7 @@ class Post(Base): canvas_area = column_property(canvas_width * canvas_height) @property - def is_featured(self): + def is_featured(self) -> bool: featured_post = object_session(self) \ .query(PostFeature) \ .order_by(PostFeature.time.desc()) \ diff --git a/server/szurubooru/db/snapshot.py b/server/szurubooru/model/snapshot.py similarity index 96% rename from server/szurubooru/db/snapshot.py rename to server/szurubooru/model/snapshot.py index 4b211f61..beb3bb25 100644 --- a/server/szurubooru/db/snapshot.py +++ b/server/szurubooru/model/snapshot.py @@ -1,7 +1,7 @@ from sqlalchemy.orm import relationship from sqlalchemy import ( Column, Integer, DateTime, Unicode, PickleType, ForeignKey) -from szurubooru.db.base import Base +from szurubooru.model.base import Base class Snapshot(Base): diff --git a/server/szurubooru/db/tag.py b/server/szurubooru/model/tag.py similarity index 93% rename from server/szurubooru/db/tag.py rename to server/szurubooru/model/tag.py index 10813eb9..1bce3ffa 100644 --- a/server/szurubooru/db/tag.py +++ b/server/szurubooru/model/tag.py @@ -2,8 +2,8 @@ from sqlalchemy import ( Column, Integer, DateTime, Unicode, UnicodeText, ForeignKey) from sqlalchemy.orm import relationship, column_property from sqlalchemy.sql.expression import func, select -from szurubooru.db.base import Base -from szurubooru.db.post import PostTag +from szurubooru.model.base import Base +from szurubooru.model.post import PostTag class TagSuggestion(Base): @@ -24,7 +24,7 @@ class TagSuggestion(Base): primary_key=True, index=True) - def __init__(self, parent_id, child_id): + def __init__(self, parent_id: int, child_id: int) -> None: self.parent_id = parent_id self.child_id = child_id @@ -47,7 +47,7 @@ class TagImplication(Base): primary_key=True, index=True) - def __init__(self, parent_id, child_id): + def __init__(self, parent_id: int, child_id: int) -> None: self.parent_id = parent_id self.child_id = child_id @@ -61,7 +61,7 @@ class TagName(Base): name = Column('name', Unicode(64), nullable=False, unique=True) order = Column('ord', Integer, nullable=False, index=True) - def __init__(self, name, order): + def __init__(self, name: str, order: int) -> None: self.name = name self.order = order diff --git a/server/szurubooru/db/tag_category.py b/server/szurubooru/model/tag_category.py similarity index 84% rename from server/szurubooru/db/tag_category.py rename to server/szurubooru/model/tag_category.py index 907910ba..001f9653 100644 --- a/server/szurubooru/db/tag_category.py +++ b/server/szurubooru/model/tag_category.py @@ -1,8 +1,9 @@ +from typing import Optional from sqlalchemy import Column, Integer, Unicode, Boolean, table from sqlalchemy.orm import column_property from sqlalchemy.sql.expression import func, select -from szurubooru.db.base import Base -from szurubooru.db.tag import Tag +from szurubooru.model.base import Base +from szurubooru.model.tag import Tag class TagCategory(Base): @@ -14,7 +15,7 @@ class TagCategory(Base): color = Column('color', Unicode(32), nullable=False, default='#000000') default = Column('default', Boolean, nullable=False, default=False) - def __init__(self, name=None): + def __init__(self, name: Optional[str]=None) -> None: self.name = name tag_count = column_property( diff --git a/server/szurubooru/db/user.py b/server/szurubooru/model/user.py similarity index 50% rename from server/szurubooru/db/user.py rename to server/szurubooru/model/user.py index 4f4f9961..dd7c0629 100644 --- a/server/szurubooru/db/user.py +++ b/server/szurubooru/model/user.py @@ -1,9 +1,7 @@ -from sqlalchemy import Column, Integer, Unicode, DateTime -from sqlalchemy.orm import relationship -from sqlalchemy.sql.expression import func -from szurubooru.db.base import Base -from szurubooru.db.post import Post, PostScore, PostFavorite -from szurubooru.db.comment import Comment +import sqlalchemy as sa +from szurubooru.model.base import Base +from szurubooru.model.post import Post, PostScore, PostFavorite +from szurubooru.model.comment import Comment class User(Base): @@ -20,63 +18,64 @@ class User(Base): RANK_ADMINISTRATOR = 'administrator' RANK_NOBODY = 'nobody' # unattainable, used for privileges - user_id = Column('id', Integer, primary_key=True) - creation_time = Column('creation_time', DateTime, nullable=False) - last_login_time = Column('last_login_time', DateTime) - version = Column('version', Integer, default=1, nullable=False) - name = Column('name', Unicode(50), nullable=False, unique=True) - password_hash = Column('password_hash', Unicode(64), nullable=False) - password_salt = Column('password_salt', Unicode(32)) - email = Column('email', Unicode(64), nullable=True) - rank = Column('rank', Unicode(32), nullable=False) - avatar_style = Column( - 'avatar_style', Unicode(32), nullable=False, default=AVATAR_GRAVATAR) + user_id = sa.Column('id', sa.Integer, primary_key=True) + creation_time = sa.Column('creation_time', sa.DateTime, nullable=False) + last_login_time = sa.Column('last_login_time', sa.DateTime) + version = sa.Column('version', sa.Integer, default=1, nullable=False) + name = sa.Column('name', sa.Unicode(50), nullable=False, unique=True) + password_hash = sa.Column('password_hash', sa.Unicode(64), nullable=False) + password_salt = sa.Column('password_salt', sa.Unicode(32)) + email = sa.Column('email', sa.Unicode(64), nullable=True) + rank = sa.Column('rank', sa.Unicode(32), nullable=False) + avatar_style = sa.Column( + 'avatar_style', sa.Unicode(32), nullable=False, + default=AVATAR_GRAVATAR) - comments = relationship('Comment') + comments = sa.orm.relationship('Comment') @property - def post_count(self): + def post_count(self) -> int: from szurubooru.db import session return ( session - .query(func.sum(1)) + .query(sa.sql.expression.func.sum(1)) .filter(Post.user_id == self.user_id) .one()[0] or 0) @property - def comment_count(self): + def comment_count(self) -> int: from szurubooru.db import session return ( session - .query(func.sum(1)) + .query(sa.sql.expression.func.sum(1)) .filter(Comment.user_id == self.user_id) .one()[0] or 0) @property - def favorite_post_count(self): + def favorite_post_count(self) -> int: from szurubooru.db import session return ( session - .query(func.sum(1)) + .query(sa.sql.expression.func.sum(1)) .filter(PostFavorite.user_id == self.user_id) .one()[0] or 0) @property - def liked_post_count(self): + def liked_post_count(self) -> int: from szurubooru.db import session return ( session - .query(func.sum(1)) + .query(sa.sql.expression.func.sum(1)) .filter(PostScore.user_id == self.user_id) .filter(PostScore.score == 1) .one()[0] or 0) @property - def disliked_post_count(self): + def disliked_post_count(self) -> int: from szurubooru.db import session return ( session - .query(func.sum(1)) + .query(sa.sql.expression.func.sum(1)) .filter(PostScore.user_id == self.user_id) .filter(PostScore.score == -1) .one()[0] or 0) diff --git a/server/szurubooru/model/util.py b/server/szurubooru/model/util.py new file mode 100644 index 00000000..e82539f1 --- /dev/null +++ b/server/szurubooru/model/util.py @@ -0,0 +1,42 @@ +from typing import Tuple, Any, Dict, Callable, Union, Optional +import sqlalchemy as sa +from szurubooru.model.base import Base +from szurubooru.model.user import User + + +def get_resource_info(entity: Base) -> Tuple[Any, Any, Union[str, int]]: + serializers = { + 'tag': lambda tag: tag.first_name, + 'tag_category': lambda category: category.name, + 'comment': lambda comment: comment.comment_id, + 'post': lambda post: post.post_id, + } # type: Dict[str, Callable[[Base], Any]] + + resource_type = entity.__table__.name + assert resource_type in serializers + + primary_key = sa.inspection.inspect(entity).identity # type: Any + assert primary_key is not None + assert len(primary_key) == 1 + + resource_name = serializers[resource_type](entity) # type: Union[str, int] + assert resource_name + + resource_pkey = primary_key[0] # type: Any + assert resource_pkey + + return (resource_type, resource_pkey, resource_name) + + +def get_aux_entity( + session: Any, + get_table_info: Callable[[Base], Tuple[Base, Callable[[Base], Any]]], + entity: Base, + user: User) -> Optional[Base]: + table, get_column = get_table_info(entity) + return ( + session + .query(table) + .filter(get_column(table) == get_column(entity)) + .filter(table.user_id == user.user_id) + .one_or_none()) diff --git a/server/szurubooru/rest/__init__.py b/server/szurubooru/rest/__init__.py index ac9958a5..14a3e305 100644 --- a/server/szurubooru/rest/__init__.py +++ b/server/szurubooru/rest/__init__.py @@ -1,2 +1,2 @@ from szurubooru.rest.app import application -from szurubooru.rest.context import Context +from szurubooru.rest.context import Context, Response diff --git a/server/szurubooru/rest/app.py b/server/szurubooru/rest/app.py index 1bbf8dce..b29110e7 100644 --- a/server/szurubooru/rest/app.py +++ b/server/szurubooru/rest/app.py @@ -2,13 +2,14 @@ import urllib.parse import cgi import json import re +from typing import Dict, Any, Callable, Tuple from datetime import datetime from szurubooru import db from szurubooru.func import util from szurubooru.rest import errors, middleware, routes, context -def _json_serializer(obj): +def _json_serializer(obj: Any) -> str: ''' JSON serializer for objects not serializable by default JSON code ''' if isinstance(obj, datetime): serial = obj.isoformat('T') + 'Z' @@ -16,12 +17,12 @@ def _json_serializer(obj): raise TypeError('Type not serializable') -def _dump_json(obj): +def _dump_json(obj: Any) -> str: return json.dumps(obj, default=_json_serializer, indent=2) -def _get_headers(env): - headers = {} +def _get_headers(env: Dict[str, Any]) -> Dict[str, str]: + headers = {} # type: Dict[str, str] for key, value in env.items(): if key.startswith('HTTP_'): key = util.snake_case_to_upper_train_case(key[5:]) @@ -29,7 +30,7 @@ def _get_headers(env): return headers -def _create_context(env): +def _create_context(env: Dict[str, Any]) -> context.Context: method = env['REQUEST_METHOD'] path = '/' + env['PATH_INFO'].lstrip('/') headers = _get_headers(env) @@ -64,7 +65,9 @@ def _create_context(env): return context.Context(method, path, headers, params, files) -def application(env, start_response): +def application( + env: Dict[str, Any], + start_response: Callable[[str, Any], Any]) -> Tuple[bytes]: try: ctx = _create_context(env) if 'application/json' not in ctx.get_header('Accept'): @@ -106,9 +109,9 @@ def application(env, start_response): return (_dump_json(response).encode('utf-8'),) except Exception as ex: - for exception_type, handler in errors.error_handlers.items(): + for exception_type, ex_handler in errors.error_handlers.items(): if isinstance(ex, exception_type): - handler(ex) + ex_handler(ex) raise except errors.BaseHttpError as ex: diff --git a/server/szurubooru/rest/context.py b/server/szurubooru/rest/context.py index ae26f38b..bb33bfab 100644 --- a/server/szurubooru/rest/context.py +++ b/server/szurubooru/rest/context.py @@ -1,111 +1,158 @@ -from szurubooru import errors +from typing import Any, Union, List, Dict, Optional, cast +from szurubooru import model, errors from szurubooru.func import net, file_uploads -def _lower_first(source): - return source[0].lower() + source[1:] - - -def _param_wrapper(func): - def wrapper(self, name, required=False, default=None, **kwargs): - # pylint: disable=protected-access - if name in self._params: - value = self._params[name] - try: - value = func(self, value, **kwargs) - except errors.InvalidParameterError as ex: - raise errors.InvalidParameterError( - 'Parameter %r is invalid: %s' % ( - name, _lower_first(str(ex)))) - return value - if not required: - return default - raise errors.MissingRequiredParameterError( - 'Required parameter %r is missing.' % name) - return wrapper +MISSING = object() +Request = Dict[str, Any] +Response = Optional[Dict[str, Any]] class Context: - def __init__(self, method, url, headers=None, params=None, files=None): + def __init__( + self, + method: str, + url: str, + headers: Dict[str, str]=None, + params: Request=None, + files: Dict[str, bytes]=None) -> None: self.method = method self.url = url self._headers = headers or {} self._params = params or {} self._files = files or {} - # provided by middleware - # self.session = None - # self.user = None + self.user = model.User() + self.user.name = None + self.user.rank = 'anonymous' - def has_header(self, name): + self.session = None # type: Any + + def has_header(self, name: str) -> bool: return name in self._headers - def get_header(self, name): - return self._headers.get(name, None) + def get_header(self, name: str) -> str: + return self._headers.get(name, '') - def has_file(self, name, allow_tokens=True): + def has_file(self, name: str, allow_tokens: bool=True) -> bool: return ( name in self._files or name + 'Url' in self._params or (allow_tokens and name + 'Token' in self._params)) - def get_file(self, name, required=False, allow_tokens=True): - ret = None - if name in self._files: - ret = self._files[name] - elif name + 'Url' in self._params: - ret = net.download(self._params[name + 'Url']) - elif allow_tokens and name + 'Token' in self._params: + def get_file( + self, + name: str, + default: Union[object, bytes]=MISSING, + allow_tokens: bool=True) -> bytes: + if name in self._files and self._files[name]: + return self._files[name] + + if name + 'Url' in self._params: + return net.download(self._params[name + 'Url']) + + if allow_tokens and name + 'Token' in self._params: ret = file_uploads.get(self._params[name + 'Token']) - if required and not ret: + if ret: + return ret + elif default is not MISSING: raise errors.MissingOrExpiredRequiredFileError( 'Required file %r is missing or has expired.' % name) - if required and not ret: - raise errors.MissingRequiredFileError( - 'Required file %r is missing.' % name) - return ret - def has_param(self, name): + if default is not MISSING: + return cast(bytes, default) + raise errors.MissingRequiredFileError( + 'Required file %r is missing.' % name) + + def has_param(self, name: str) -> bool: return name in self._params - @_param_wrapper - def get_param_as_list(self, value): - if not isinstance(value, list): + def get_param_as_list( + self, + name: str, + default: Union[object, List[Any]]=MISSING) -> List[Any]: + if name not in self._params: + if default is not MISSING: + return cast(List[Any], default) + raise errors.MissingRequiredParameterError( + 'Required parameter %r is missing.' % name) + value = self._params[name] + if type(value) is str: if ',' in value: return value.split(',') return [value] - return value + if type(value) is list: + return value + raise errors.InvalidParameterError( + 'Parameter %r must be a list.' % name) - @_param_wrapper - def get_param_as_string(self, value): - if isinstance(value, list): - try: - value = ','.join(value) - except TypeError: - raise errors.InvalidParameterError('Expected simple string.') - return value + def get_param_as_string( + self, + name: str, + default: Union[object, str]=MISSING) -> str: + if name not in self._params: + if default is not MISSING: + return cast(str, default) + raise errors.MissingRequiredParameterError( + 'Required parameter %r is missing.' % name) + value = self._params[name] + try: + if value is None: + return '' + if type(value) is list: + return ','.join(value) + if type(value) is int or type(value) is float: + return str(value) + if type(value) is str: + return value + except TypeError: + pass + raise errors.InvalidParameterError( + 'Parameter %r must be a string value.' % name) - @_param_wrapper - def get_param_as_int(self, value, min=None, max=None): + def get_param_as_int( + self, + name: str, + default: Union[object, int]=MISSING, + min: Optional[int]=None, + max: Optional[int]=None) -> int: + if name not in self._params: + if default is not MISSING: + return cast(int, default) + raise errors.MissingRequiredParameterError( + 'Required parameter %r is missing.' % name) + value = self._params[name] try: value = int(value) + if min is not None and value < min: + raise errors.InvalidParameterError( + 'Parameter %r must be at least %r.' % (name, min)) + if max is not None and value > max: + raise errors.InvalidParameterError( + 'Parameter %r may not exceed %r.' % (name, max)) + return value except (ValueError, TypeError): - raise errors.InvalidParameterError( - 'The value must be an integer.') - if min is not None and value < min: - raise errors.InvalidParameterError( - 'The value must be at least %r.' % min) - if max is not None and value > max: - raise errors.InvalidParameterError( - 'The value may not exceed %r.' % max) - return value + pass + raise errors.InvalidParameterError( + 'Parameter %r must be an integer value.' % name) - @_param_wrapper - def get_param_as_bool(self, value): - value = str(value).lower() + def get_param_as_bool( + self, + name: str, + default: Union[object, bool]=MISSING) -> bool: + if name not in self._params: + if default is not MISSING: + return cast(bool, default) + raise errors.MissingRequiredParameterError( + 'Required parameter %r is missing.' % name) + value = self._params[name] + try: + value = str(value).lower() + except TypeError: + pass if value in ['1', 'y', 'yes', 'yeah', 'yep', 'yup', 't', 'true']: return True if value in ['0', 'n', 'no', 'nope', 'f', 'false']: return False raise errors.InvalidParameterError( - 'The value must be a boolean value.') + 'Parameter %r must be a boolean value.' % name) diff --git a/server/szurubooru/rest/errors.py b/server/szurubooru/rest/errors.py index b0f5b882..6854e7d3 100644 --- a/server/szurubooru/rest/errors.py +++ b/server/szurubooru/rest/errors.py @@ -1,11 +1,19 @@ +from typing import Callable, Type, Dict + + error_handlers = {} # pylint: disable=invalid-name class BaseHttpError(RuntimeError): - code = None - reason = None + code = -1 + reason = '' - def __init__(self, name, description, title=None, extra_fields=None): + def __init__( + self, + name: str, + description: str, + title: str=None, + extra_fields: Dict[str, str]=None) -> None: super().__init__() # error name for programmers self.name = name @@ -52,5 +60,7 @@ class HttpInternalServerError(BaseHttpError): reason = 'Internal Server Error' -def handle(exception_type, handler): +def handle( + exception_type: Type[Exception], + handler: Callable[[Exception], None]) -> None: error_handlers[exception_type] = handler diff --git a/server/szurubooru/rest/middleware.py b/server/szurubooru/rest/middleware.py index 7cf07296..05d9495e 100644 --- a/server/szurubooru/rest/middleware.py +++ b/server/szurubooru/rest/middleware.py @@ -1,11 +1,15 @@ +from typing import Callable +from szurubooru.rest.context import Context + + # pylint: disable=invalid-name -pre_hooks = [] -post_hooks = [] +pre_hooks = [] # type: List[Callable[[Context], None]] +post_hooks = [] # type: List[Callable[[Context], None]] -def pre_hook(handler): +def pre_hook(handler: Callable) -> None: pre_hooks.append(handler) -def post_hook(handler): +def post_hook(handler: Callable) -> None: post_hooks.insert(0, handler) diff --git a/server/szurubooru/rest/routes.py b/server/szurubooru/rest/routes.py index ffa95f56..c0b6bea3 100644 --- a/server/szurubooru/rest/routes.py +++ b/server/szurubooru/rest/routes.py @@ -1,32 +1,36 @@ +from typing import Callable, Dict, Any from collections import defaultdict +from szurubooru.rest.context import Context, Response -routes = defaultdict(dict) # pylint: disable=invalid-name +# pylint: disable=invalid-name +RouteHandler = Callable[[Context, Dict[str, str]], Response] +routes = defaultdict(dict) # type: Dict[str, Dict[str, RouteHandler]] -def get(url): - def wrapper(handler): +def get(url: str) -> Callable[[RouteHandler], RouteHandler]: + def wrapper(handler: RouteHandler) -> RouteHandler: routes[url]['GET'] = handler return handler return wrapper -def put(url): - def wrapper(handler): +def put(url: str) -> Callable[[RouteHandler], RouteHandler]: + def wrapper(handler: RouteHandler) -> RouteHandler: routes[url]['PUT'] = handler return handler return wrapper -def post(url): - def wrapper(handler): +def post(url: str) -> Callable[[RouteHandler], RouteHandler]: + def wrapper(handler: RouteHandler) -> RouteHandler: routes[url]['POST'] = handler return handler return wrapper -def delete(url): - def wrapper(handler): +def delete(url: str) -> Callable[[RouteHandler], RouteHandler]: + def wrapper(handler: RouteHandler) -> RouteHandler: routes[url]['DELETE'] = handler return handler return wrapper diff --git a/server/szurubooru/search/configs/base_search_config.py b/server/szurubooru/search/configs/base_search_config.py index adc50d30..0cb814d4 100644 --- a/server/szurubooru/search/configs/base_search_config.py +++ b/server/szurubooru/search/configs/base_search_config.py @@ -1,38 +1,47 @@ -from szurubooru.search import tokens +from typing import Optional, Tuple, Dict, Callable +from szurubooru.search import tokens, criteria +from szurubooru.search.query import SearchQuery +from szurubooru.search.typing import SaColumn, SaQuery + +Filter = Callable[[SaQuery, Optional[criteria.BaseCriterion], bool], SaQuery] class BaseSearchConfig: + SORT_NONE = tokens.SortToken.SORT_NONE SORT_ASC = tokens.SortToken.SORT_ASC SORT_DESC = tokens.SortToken.SORT_DESC - def on_search_query_parsed(self, search_query): + def on_search_query_parsed(self, search_query: SearchQuery) -> None: pass - def create_filter_query(self, _disable_eager_loads): + def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery: raise NotImplementedError() - def create_count_query(self, disable_eager_loads): + def create_count_query(self, disable_eager_loads: bool) -> SaQuery: raise NotImplementedError() - def create_around_query(self): + def create_around_query(self) -> SaQuery: raise NotImplementedError() + def finalize_query(self, query: SaQuery) -> SaQuery: + return query + @property - def id_column(self): + def id_column(self) -> SaColumn: return None @property - def anonymous_filter(self): + def anonymous_filter(self) -> Optional[Filter]: return None @property - def special_filters(self): + def special_filters(self) -> Dict[str, Filter]: return {} @property - def named_filters(self): + def named_filters(self) -> Dict[str, Filter]: return {} @property - def sort_columns(self): + def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]: return {} diff --git a/server/szurubooru/search/configs/comment_search_config.py b/server/szurubooru/search/configs/comment_search_config.py index 9b2515e8..8b154460 100644 --- a/server/szurubooru/search/configs/comment_search_config.py +++ b/server/szurubooru/search/configs/comment_search_config.py @@ -1,59 +1,62 @@ -from sqlalchemy.sql.expression import func -from szurubooru import db +from typing import Tuple, Dict +import sqlalchemy as sa +from szurubooru import db, model +from szurubooru.search.typing import SaColumn, SaQuery from szurubooru.search.configs import util as search_util -from szurubooru.search.configs.base_search_config import BaseSearchConfig +from szurubooru.search.configs.base_search_config import ( + BaseSearchConfig, Filter) class CommentSearchConfig(BaseSearchConfig): - def create_filter_query(self, _disable_eager_loads): - return db.session.query(db.Comment).join(db.User) + def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery: + return db.session.query(model.Comment).join(model.User) - def create_count_query(self, disable_eager_loads): + def create_count_query(self, disable_eager_loads: bool) -> SaQuery: return self.create_filter_query(disable_eager_loads) - def create_around_query(self): + def create_around_query(self) -> SaQuery: raise NotImplementedError() - def finalize_query(self, query): - return query.order_by(db.Comment.creation_time.desc()) + def finalize_query(self, query: SaQuery) -> SaQuery: + return query.order_by(model.Comment.creation_time.desc()) @property - def anonymous_filter(self): - return search_util.create_str_filter(db.Comment.text) + def anonymous_filter(self) -> SaQuery: + return search_util.create_str_filter(model.Comment.text) @property - def named_filters(self): + def named_filters(self) -> Dict[str, Filter]: return { - 'id': search_util.create_num_filter(db.Comment.comment_id), - 'post': search_util.create_num_filter(db.Comment.post_id), - 'user': search_util.create_str_filter(db.User.name), - 'author': search_util.create_str_filter(db.User.name), - 'text': search_util.create_str_filter(db.Comment.text), + 'id': search_util.create_num_filter(model.Comment.comment_id), + 'post': search_util.create_num_filter(model.Comment.post_id), + 'user': search_util.create_str_filter(model.User.name), + 'author': search_util.create_str_filter(model.User.name), + 'text': search_util.create_str_filter(model.Comment.text), 'creation-date': - search_util.create_date_filter(db.Comment.creation_time), + search_util.create_date_filter(model.Comment.creation_time), 'creation-time': - search_util.create_date_filter(db.Comment.creation_time), + search_util.create_date_filter(model.Comment.creation_time), 'last-edit-date': - search_util.create_date_filter(db.Comment.last_edit_time), + search_util.create_date_filter(model.Comment.last_edit_time), 'last-edit-time': - search_util.create_date_filter(db.Comment.last_edit_time), + search_util.create_date_filter(model.Comment.last_edit_time), 'edit-date': - search_util.create_date_filter(db.Comment.last_edit_time), + search_util.create_date_filter(model.Comment.last_edit_time), 'edit-time': - search_util.create_date_filter(db.Comment.last_edit_time), + search_util.create_date_filter(model.Comment.last_edit_time), } @property - def sort_columns(self): + def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]: return { - 'random': (func.random(), None), - 'user': (db.User.name, self.SORT_ASC), - 'author': (db.User.name, self.SORT_ASC), - 'post': (db.Comment.post_id, self.SORT_DESC), - 'creation-date': (db.Comment.creation_time, self.SORT_DESC), - 'creation-time': (db.Comment.creation_time, self.SORT_DESC), - 'last-edit-date': (db.Comment.last_edit_time, self.SORT_DESC), - 'last-edit-time': (db.Comment.last_edit_time, self.SORT_DESC), - 'edit-date': (db.Comment.last_edit_time, self.SORT_DESC), - 'edit-time': (db.Comment.last_edit_time, self.SORT_DESC), + 'random': (sa.sql.expression.func.random(), self.SORT_NONE), + 'user': (model.User.name, self.SORT_ASC), + 'author': (model.User.name, self.SORT_ASC), + 'post': (model.Comment.post_id, self.SORT_DESC), + 'creation-date': (model.Comment.creation_time, self.SORT_DESC), + 'creation-time': (model.Comment.creation_time, self.SORT_DESC), + 'last-edit-date': (model.Comment.last_edit_time, self.SORT_DESC), + 'last-edit-time': (model.Comment.last_edit_time, self.SORT_DESC), + 'edit-date': (model.Comment.last_edit_time, self.SORT_DESC), + 'edit-time': (model.Comment.last_edit_time, self.SORT_DESC), } diff --git a/server/szurubooru/search/configs/post_search_config.py b/server/szurubooru/search/configs/post_search_config.py index 7005cd7c..cda1b1ac 100644 --- a/server/szurubooru/search/configs/post_search_config.py +++ b/server/szurubooru/search/configs/post_search_config.py @@ -1,13 +1,16 @@ -from sqlalchemy.orm import subqueryload, lazyload, defer, aliased -from sqlalchemy.sql.expression import func -from szurubooru import db, errors +from typing import Any, Optional, Tuple, Dict +import sqlalchemy as sa +from szurubooru import db, model, errors from szurubooru.func import util from szurubooru.search import criteria, tokens +from szurubooru.search.typing import SaColumn, SaQuery +from szurubooru.search.query import SearchQuery from szurubooru.search.configs import util as search_util -from szurubooru.search.configs.base_search_config import BaseSearchConfig +from szurubooru.search.configs.base_search_config import ( + BaseSearchConfig, Filter) -def _enum_transformer(available_values, value): +def _enum_transformer(available_values: Dict[str, Any], value: str) -> str: try: return available_values[value.lower()] except KeyError: @@ -16,71 +19,82 @@ def _enum_transformer(available_values, value): value, list(sorted(available_values.keys())))) -def _type_transformer(value): +def _type_transformer(value: str) -> str: available_values = { - 'image': db.Post.TYPE_IMAGE, - 'animation': db.Post.TYPE_ANIMATION, - 'animated': db.Post.TYPE_ANIMATION, - 'anim': db.Post.TYPE_ANIMATION, - 'gif': db.Post.TYPE_ANIMATION, - 'video': db.Post.TYPE_VIDEO, - 'webm': db.Post.TYPE_VIDEO, - 'flash': db.Post.TYPE_FLASH, - 'swf': db.Post.TYPE_FLASH, + 'image': model.Post.TYPE_IMAGE, + 'animation': model.Post.TYPE_ANIMATION, + 'animated': model.Post.TYPE_ANIMATION, + 'anim': model.Post.TYPE_ANIMATION, + 'gif': model.Post.TYPE_ANIMATION, + 'video': model.Post.TYPE_VIDEO, + 'webm': model.Post.TYPE_VIDEO, + 'flash': model.Post.TYPE_FLASH, + 'swf': model.Post.TYPE_FLASH, } return _enum_transformer(available_values, value) -def _safety_transformer(value): +def _safety_transformer(value: str) -> str: available_values = { - 'safe': db.Post.SAFETY_SAFE, - 'sketchy': db.Post.SAFETY_SKETCHY, - 'questionable': db.Post.SAFETY_SKETCHY, - 'unsafe': db.Post.SAFETY_UNSAFE, + 'safe': model.Post.SAFETY_SAFE, + 'sketchy': model.Post.SAFETY_SKETCHY, + 'questionable': model.Post.SAFETY_SKETCHY, + 'unsafe': model.Post.SAFETY_UNSAFE, } return _enum_transformer(available_values, value) -def _create_score_filter(score): - def wrapper(query, criterion, negated): +def _create_score_filter(score: int) -> Filter: + def wrapper( + query: SaQuery, + criterion: Optional[criteria.BaseCriterion], + negated: bool) -> SaQuery: + assert criterion if not getattr(criterion, 'internal', False): raise errors.SearchError( 'Votes cannot be seen publicly. Did you mean %r?' % 'special:liked') - user_alias = aliased(db.User) - score_alias = aliased(db.PostScore) + user_alias = sa.orm.aliased(model.User) + score_alias = sa.orm.aliased(model.PostScore) expr = score_alias.score == score expr = expr & search_util.apply_str_criterion_to_column( user_alias.name, criterion) if negated: expr = ~expr ret = query \ - .join(score_alias, score_alias.post_id == db.Post.post_id) \ + .join(score_alias, score_alias.post_id == model.Post.post_id) \ .join(user_alias, user_alias.user_id == score_alias.user_id) \ .filter(expr) return ret return wrapper -def _create_user_filter(): - def wrapper(query, criterion, negated): +def _create_user_filter() -> Filter: + def wrapper( + query: SaQuery, + criterion: Optional[criteria.BaseCriterion], + negated: bool) -> SaQuery: + assert criterion if isinstance(criterion, criteria.PlainCriterion) \ and not criterion.value: # pylint: disable=singleton-comparison - expr = db.Post.user_id == None + expr = model.Post.user_id == None if negated: expr = ~expr return query.filter(expr) return search_util.create_subquery_filter( - db.Post.user_id, - db.User.user_id, - db.User.name, + model.Post.user_id, + model.User.user_id, + model.User.name, search_util.create_str_filter)(query, criterion, negated) return wrapper class PostSearchConfig(BaseSearchConfig): - def on_search_query_parsed(self, search_query): + def __init__(self) -> None: + self.user = None # type: Optional[model.User] + + def on_search_query_parsed(self, search_query: SearchQuery) -> SaQuery: new_special_tokens = [] for token in search_query.special_tokens: if token.value in ('fav', 'liked', 'disliked'): @@ -91,7 +105,7 @@ class PostSearchConfig(BaseSearchConfig): criterion = criteria.PlainCriterion( original_text=self.user.name, value=self.user.name) - criterion.internal = True + setattr(criterion, 'internal', True) search_query.named_tokens.append( tokens.NamedToken( name=token.value, @@ -101,160 +115,324 @@ class PostSearchConfig(BaseSearchConfig): new_special_tokens.append(token) search_query.special_tokens = new_special_tokens - def create_around_query(self): - return db.session.query(db.Post).options(lazyload('*')) + def create_around_query(self) -> SaQuery: + return db.session.query(model.Post).options(sa.orm.lazyload('*')) - def create_filter_query(self, disable_eager_loads): - strategy = lazyload if disable_eager_loads else subqueryload - return db.session.query(db.Post) \ + def create_filter_query(self, disable_eager_loads: bool) -> SaQuery: + strategy = ( + sa.orm.lazyload + if disable_eager_loads + else sa.orm.subqueryload) + return db.session.query(model.Post) \ .options( - lazyload('*'), + sa.orm.lazyload('*'), # use config optimized for official client - # defer(db.Post.score), - # defer(db.Post.favorite_count), - # defer(db.Post.comment_count), - defer(db.Post.last_favorite_time), - defer(db.Post.feature_count), - defer(db.Post.last_feature_time), - defer(db.Post.last_comment_creation_time), - defer(db.Post.last_comment_edit_time), - defer(db.Post.note_count), - defer(db.Post.tag_count), - strategy(db.Post.tags).subqueryload(db.Tag.names), - strategy(db.Post.tags).defer(db.Tag.post_count), - strategy(db.Post.tags).lazyload(db.Tag.implications), - strategy(db.Post.tags).lazyload(db.Tag.suggestions)) + # sa.orm.defer(model.Post.score), + # sa.orm.defer(model.Post.favorite_count), + # sa.orm.defer(model.Post.comment_count), + sa.orm.defer(model.Post.last_favorite_time), + sa.orm.defer(model.Post.feature_count), + sa.orm.defer(model.Post.last_feature_time), + sa.orm.defer(model.Post.last_comment_creation_time), + sa.orm.defer(model.Post.last_comment_edit_time), + sa.orm.defer(model.Post.note_count), + sa.orm.defer(model.Post.tag_count), + strategy(model.Post.tags).subqueryload(model.Tag.names), + strategy(model.Post.tags).defer(model.Tag.post_count), + strategy(model.Post.tags).lazyload(model.Tag.implications), + strategy(model.Post.tags).lazyload(model.Tag.suggestions)) - def create_count_query(self, _disable_eager_loads): - return db.session.query(db.Post) + def create_count_query(self, _disable_eager_loads: bool) -> SaQuery: + return db.session.query(model.Post) - def finalize_query(self, query): - return query.order_by(db.Post.post_id.desc()) + def finalize_query(self, query: SaQuery) -> SaQuery: + return query.order_by(model.Post.post_id.desc()) @property - def id_column(self): - return db.Post.post_id + def id_column(self) -> SaColumn: + return model.Post.post_id @property - def anonymous_filter(self): + def anonymous_filter(self) -> Optional[Filter]: return search_util.create_subquery_filter( - db.Post.post_id, - db.PostTag.post_id, - db.TagName.name, + model.Post.post_id, + model.PostTag.post_id, + model.TagName.name, search_util.create_str_filter, - lambda subquery: subquery.join(db.Tag).join(db.TagName)) + lambda subquery: subquery.join(model.Tag).join(model.TagName)) @property - def named_filters(self): - return util.unalias_dict({ - 'id': search_util.create_num_filter(db.Post.post_id), - 'tag': search_util.create_subquery_filter( - db.Post.post_id, - db.PostTag.post_id, - db.TagName.name, - search_util.create_str_filter, - lambda subquery: subquery.join(db.Tag).join(db.TagName)), - 'score': search_util.create_num_filter(db.Post.score), - ('uploader', 'upload', 'submit'): - _create_user_filter(), - 'comment': search_util.create_subquery_filter( - db.Post.post_id, - db.Comment.post_id, - db.User.name, - search_util.create_str_filter, - lambda subquery: subquery.join(db.User)), - 'fav': search_util.create_subquery_filter( - db.Post.post_id, - db.PostFavorite.post_id, - db.User.name, - search_util.create_str_filter, - lambda subquery: subquery.join(db.User)), - 'liked': _create_score_filter(1), - 'disliked': _create_score_filter(-1), - 'tag-count': search_util.create_num_filter(db.Post.tag_count), - 'comment-count': - search_util.create_num_filter(db.Post.comment_count), - 'fav-count': - search_util.create_num_filter(db.Post.favorite_count), - 'note-count': search_util.create_num_filter(db.Post.note_count), - 'relation-count': - search_util.create_num_filter(db.Post.relation_count), - 'feature-count': - search_util.create_num_filter(db.Post.feature_count), - 'type': + def named_filters(self) -> Dict[str, Filter]: + return util.unalias_dict([ + ( + ['id'], + search_util.create_num_filter(model.Post.post_id) + ), + + ( + ['tag'], + search_util.create_subquery_filter( + model.Post.post_id, + model.PostTag.post_id, + model.TagName.name, + search_util.create_str_filter, + lambda subquery: + subquery.join(model.Tag).join(model.TagName)) + ), + + ( + ['score'], + search_util.create_num_filter(model.Post.score) + ), + + ( + ['uploader', 'upload', 'submit'], + _create_user_filter() + ), + + ( + ['comment'], + search_util.create_subquery_filter( + model.Post.post_id, + model.Comment.post_id, + model.User.name, + search_util.create_str_filter, + lambda subquery: subquery.join(model.User)) + ), + + ( + ['fav'], + search_util.create_subquery_filter( + model.Post.post_id, + model.PostFavorite.post_id, + model.User.name, + search_util.create_str_filter, + lambda subquery: subquery.join(model.User)) + ), + + ( + ['liked'], + _create_score_filter(1) + ), + ( + ['disliked'], + _create_score_filter(-1) + ), + + ( + ['tag-count'], + search_util.create_num_filter(model.Post.tag_count) + ), + + ( + ['comment-count'], + search_util.create_num_filter(model.Post.comment_count) + ), + + ( + ['fav-count'], + search_util.create_num_filter(model.Post.favorite_count) + ), + + ( + ['note-count'], + search_util.create_num_filter(model.Post.note_count) + ), + + ( + ['relation-count'], + search_util.create_num_filter(model.Post.relation_count) + ), + + ( + ['feature-count'], + search_util.create_num_filter(model.Post.feature_count) + ), + + ( + ['type'], search_util.create_str_filter( - db.Post.type, _type_transformer), - 'content-checksum': search_util.create_str_filter( - db.Post.checksum), - 'file-size': search_util.create_num_filter(db.Post.file_size), - ('image-width', 'width'): - search_util.create_num_filter(db.Post.canvas_width), - ('image-height', 'height'): - search_util.create_num_filter(db.Post.canvas_height), - ('image-area', 'area'): - search_util.create_num_filter(db.Post.canvas_area), - ('creation-date', 'creation-time', 'date', 'time'): - search_util.create_date_filter(db.Post.creation_time), - ('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): - search_util.create_date_filter(db.Post.last_edit_time), - ('comment-date', 'comment-time'): + model.Post.type, _type_transformer) + ), + + ( + ['content-checksum'], + search_util.create_str_filter(model.Post.checksum) + ), + + ( + ['file-size'], + search_util.create_num_filter(model.Post.file_size) + ), + + ( + ['image-width', 'width'], + search_util.create_num_filter(model.Post.canvas_width) + ), + + ( + ['image-height', 'height'], + search_util.create_num_filter(model.Post.canvas_height) + ), + + ( + ['image-area', 'area'], + search_util.create_num_filter(model.Post.canvas_area) + ), + + ( + ['creation-date', 'creation-time', 'date', 'time'], + search_util.create_date_filter(model.Post.creation_time) + ), + + ( + ['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'], + search_util.create_date_filter(model.Post.last_edit_time) + ), + + ( + ['comment-date', 'comment-time'], search_util.create_date_filter( - db.Post.last_comment_creation_time), - ('fav-date', 'fav-time'): - search_util.create_date_filter(db.Post.last_favorite_time), - ('feature-date', 'feature-time'): - search_util.create_date_filter(db.Post.last_feature_time), - ('safety', 'rating'): + model.Post.last_comment_creation_time) + ), + + ( + ['fav-date', 'fav-time'], + search_util.create_date_filter(model.Post.last_favorite_time) + ), + + ( + ['feature-date', 'feature-time'], + search_util.create_date_filter(model.Post.last_feature_time) + ), + + ( + ['safety', 'rating'], search_util.create_str_filter( - db.Post.safety, _safety_transformer), - }) + model.Post.safety, _safety_transformer) + ), + ]) @property - def sort_columns(self): - return util.unalias_dict({ - 'random': (func.random(), None), - 'id': (db.Post.post_id, self.SORT_DESC), - 'score': (db.Post.score, self.SORT_DESC), - 'tag-count': (db.Post.tag_count, self.SORT_DESC), - 'comment-count': (db.Post.comment_count, self.SORT_DESC), - 'fav-count': (db.Post.favorite_count, self.SORT_DESC), - 'note-count': (db.Post.note_count, self.SORT_DESC), - 'relation-count': (db.Post.relation_count, self.SORT_DESC), - 'feature-count': (db.Post.feature_count, self.SORT_DESC), - 'file-size': (db.Post.file_size, self.SORT_DESC), - ('image-width', 'width'): - (db.Post.canvas_width, self.SORT_DESC), - ('image-height', 'height'): - (db.Post.canvas_height, self.SORT_DESC), - ('image-area', 'area'): - (db.Post.canvas_area, self.SORT_DESC), - ('creation-date', 'creation-time', 'date', 'time'): - (db.Post.creation_time, self.SORT_DESC), - ('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): - (db.Post.last_edit_time, self.SORT_DESC), - ('comment-date', 'comment-time'): - (db.Post.last_comment_creation_time, self.SORT_DESC), - ('fav-date', 'fav-time'): - (db.Post.last_favorite_time, self.SORT_DESC), - ('feature-date', 'feature-time'): - (db.Post.last_feature_time, self.SORT_DESC), - }) + def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]: + return util.unalias_dict([ + ( + ['random'], + (sa.sql.expression.func.random(), self.SORT_NONE) + ), + + ( + ['id'], + (model.Post.post_id, self.SORT_DESC) + ), + + ( + ['score'], + (model.Post.score, self.SORT_DESC) + ), + + ( + ['tag-count'], + (model.Post.tag_count, self.SORT_DESC) + ), + + ( + ['comment-count'], + (model.Post.comment_count, self.SORT_DESC) + ), + + ( + ['fav-count'], + (model.Post.favorite_count, self.SORT_DESC) + ), + + ( + ['note-count'], + (model.Post.note_count, self.SORT_DESC) + ), + + ( + ['relation-count'], + (model.Post.relation_count, self.SORT_DESC) + ), + + ( + ['feature-count'], + (model.Post.feature_count, self.SORT_DESC) + ), + + ( + ['file-size'], + (model.Post.file_size, self.SORT_DESC) + ), + + ( + ['image-width', 'width'], + (model.Post.canvas_width, self.SORT_DESC) + ), + + ( + ['image-height', 'height'], + (model.Post.canvas_height, self.SORT_DESC) + ), + + ( + ['image-area', 'area'], + (model.Post.canvas_area, self.SORT_DESC) + ), + + ( + ['creation-date', 'creation-time', 'date', 'time'], + (model.Post.creation_time, self.SORT_DESC) + ), + + ( + ['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'], + (model.Post.last_edit_time, self.SORT_DESC) + ), + + ( + ['comment-date', 'comment-time'], + (model.Post.last_comment_creation_time, self.SORT_DESC) + ), + + ( + ['fav-date', 'fav-time'], + (model.Post.last_favorite_time, self.SORT_DESC) + ), + + ( + ['feature-date', 'feature-time'], + (model.Post.last_feature_time, self.SORT_DESC) + ), + ]) @property - def special_filters(self): + def special_filters(self) -> Dict[str, Filter]: return { - # handled by parsed - 'fav': None, - 'liked': None, - 'disliked': None, + # handled by parser + 'fav': self.noop_filter, + 'liked': self.noop_filter, + 'disliked': self.noop_filter, 'tumbleweed': self.tumbleweed_filter, } - def tumbleweed_filter(self, query, negated): - expr = \ - (db.Post.comment_count == 0) \ - & (db.Post.favorite_count == 0) \ - & (db.Post.score == 0) + def noop_filter( + self, + query: SaQuery, + _criterion: Optional[criteria.BaseCriterion], + _negated: bool) -> SaQuery: + return query + + def tumbleweed_filter( + self, + query: SaQuery, + _criterion: Optional[criteria.BaseCriterion], + negated: bool) -> SaQuery: + expr = ( + (model.Post.comment_count == 0) + & (model.Post.favorite_count == 0) + & (model.Post.score == 0)) if negated: expr = ~expr return query.filter(expr) diff --git a/server/szurubooru/search/configs/snapshot_search_config.py b/server/szurubooru/search/configs/snapshot_search_config.py index 4ea7280a..0fdb69d0 100644 --- a/server/szurubooru/search/configs/snapshot_search_config.py +++ b/server/szurubooru/search/configs/snapshot_search_config.py @@ -1,28 +1,37 @@ -from szurubooru import db +from typing import Dict +from szurubooru import db, model +from szurubooru.search.typing import SaQuery from szurubooru.search.configs import util as search_util -from szurubooru.search.configs.base_search_config import BaseSearchConfig +from szurubooru.search.configs.base_search_config import ( + BaseSearchConfig, Filter) class SnapshotSearchConfig(BaseSearchConfig): - def create_filter_query(self, _disable_eager_loads): - return db.session.query(db.Snapshot) + def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery: + return db.session.query(model.Snapshot) - def create_count_query(self, _disable_eager_loads): - return db.session.query(db.Snapshot) + def create_count_query(self, _disable_eager_loads: bool) -> SaQuery: + return db.session.query(model.Snapshot) - def create_around_query(self): + def create_around_query(self) -> SaQuery: raise NotImplementedError() - def finalize_query(self, query): - return query.order_by(db.Snapshot.creation_time.desc()) + def finalize_query(self, query: SaQuery) -> SaQuery: + return query.order_by(model.Snapshot.creation_time.desc()) @property - def named_filters(self): + def named_filters(self) -> Dict[str, Filter]: return { - 'type': search_util.create_str_filter(db.Snapshot.resource_type), - 'id': search_util.create_str_filter(db.Snapshot.resource_name), - 'date': search_util.create_date_filter(db.Snapshot.creation_time), - 'time': search_util.create_date_filter(db.Snapshot.creation_time), - 'operation': search_util.create_str_filter(db.Snapshot.operation), - 'user': search_util.create_str_filter(db.User.name), + 'type': + search_util.create_str_filter(model.Snapshot.resource_type), + 'id': + search_util.create_str_filter(model.Snapshot.resource_name), + 'date': + search_util.create_date_filter(model.Snapshot.creation_time), + 'time': + search_util.create_date_filter(model.Snapshot.creation_time), + 'operation': + search_util.create_str_filter(model.Snapshot.operation), + 'user': + search_util.create_str_filter(model.User.name), } diff --git a/server/szurubooru/search/configs/tag_search_config.py b/server/szurubooru/search/configs/tag_search_config.py index 4595d82f..6dba5b02 100644 --- a/server/szurubooru/search/configs/tag_search_config.py +++ b/server/szurubooru/search/configs/tag_search_config.py @@ -1,79 +1,134 @@ -from sqlalchemy.orm import subqueryload, lazyload, defer -from sqlalchemy.sql.expression import func -from szurubooru import db +from typing import Tuple, Dict +import sqlalchemy as sa +from szurubooru import db, model from szurubooru.func import util +from szurubooru.search.typing import SaColumn, SaQuery from szurubooru.search.configs import util as search_util -from szurubooru.search.configs.base_search_config import BaseSearchConfig +from szurubooru.search.configs.base_search_config import ( + BaseSearchConfig, Filter) class TagSearchConfig(BaseSearchConfig): - def create_filter_query(self, _disable_eager_loads): - strategy = lazyload if _disable_eager_loads else subqueryload - return db.session.query(db.Tag) \ - .join(db.TagCategory) \ + def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery: + strategy = ( + sa.orm.lazyload + if _disable_eager_loads + else sa.orm.subqueryload) + return db.session.query(model.Tag) \ + .join(model.TagCategory) \ .options( - defer(db.Tag.first_name), - defer(db.Tag.suggestion_count), - defer(db.Tag.implication_count), - defer(db.Tag.post_count), - strategy(db.Tag.names), - strategy(db.Tag.suggestions).joinedload(db.Tag.names), - strategy(db.Tag.implications).joinedload(db.Tag.names)) + sa.orm.defer(model.Tag.first_name), + sa.orm.defer(model.Tag.suggestion_count), + sa.orm.defer(model.Tag.implication_count), + sa.orm.defer(model.Tag.post_count), + strategy(model.Tag.names), + strategy(model.Tag.suggestions).joinedload(model.Tag.names), + strategy(model.Tag.implications).joinedload(model.Tag.names)) - def create_count_query(self, _disable_eager_loads): - return db.session.query(db.Tag) + def create_count_query(self, _disable_eager_loads: bool) -> SaQuery: + return db.session.query(model.Tag) - def create_around_query(self): + def create_around_query(self) -> SaQuery: raise NotImplementedError() - def finalize_query(self, query): - return query.order_by(db.Tag.first_name.asc()) + def finalize_query(self, query: SaQuery) -> SaQuery: + return query.order_by(model.Tag.first_name.asc()) @property - def anonymous_filter(self): + def anonymous_filter(self) -> Filter: return search_util.create_subquery_filter( - db.Tag.tag_id, - db.TagName.tag_id, - db.TagName.name, + model.Tag.tag_id, + model.TagName.tag_id, + model.TagName.name, search_util.create_str_filter) @property - def named_filters(self): - return util.unalias_dict({ - 'name': search_util.create_subquery_filter( - db.Tag.tag_id, - db.TagName.tag_id, - db.TagName.name, - search_util.create_str_filter), - 'category': search_util.create_subquery_filter( - db.Tag.category_id, - db.TagCategory.tag_category_id, - db.TagCategory.name, - search_util.create_str_filter), - ('creation-date', 'creation-time'): - search_util.create_date_filter(db.Tag.creation_time), - ('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): - search_util.create_date_filter(db.Tag.last_edit_time), - ('usage-count', 'post-count', 'usages'): - search_util.create_num_filter(db.Tag.post_count), - 'suggestion-count': - search_util.create_num_filter(db.Tag.suggestion_count), - 'implication-count': - search_util.create_num_filter(db.Tag.implication_count), - }) + def named_filters(self) -> Dict[str, Filter]: + return util.unalias_dict([ + ( + ['name'], + search_util.create_subquery_filter( + model.Tag.tag_id, + model.TagName.tag_id, + model.TagName.name, + search_util.create_str_filter) + ), + + ( + ['category'], + search_util.create_subquery_filter( + model.Tag.category_id, + model.TagCategory.tag_category_id, + model.TagCategory.name, + search_util.create_str_filter) + ), + + ( + ['creation-date', 'creation-time'], + search_util.create_date_filter(model.Tag.creation_time) + ), + + ( + ['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'], + search_util.create_date_filter(model.Tag.last_edit_time) + ), + + ( + ['usage-count', 'post-count', 'usages'], + search_util.create_num_filter(model.Tag.post_count) + ), + + ( + ['suggestion-count'], + search_util.create_num_filter(model.Tag.suggestion_count) + ), + + ( + ['implication-count'], + search_util.create_num_filter(model.Tag.implication_count) + ), + ]) @property - def sort_columns(self): - return util.unalias_dict({ - 'random': (func.random(), None), - 'name': (db.Tag.first_name, self.SORT_ASC), - 'category': (db.TagCategory.name, self.SORT_ASC), - ('creation-date', 'creation-time'): - (db.Tag.creation_time, self.SORT_DESC), - ('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): - (db.Tag.last_edit_time, self.SORT_DESC), - ('usage-count', 'post-count', 'usages'): - (db.Tag.post_count, self.SORT_DESC), - 'suggestion-count': (db.Tag.suggestion_count, self.SORT_DESC), - 'implication-count': (db.Tag.implication_count, self.SORT_DESC), - }) + def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]: + return util.unalias_dict([ + ( + ['random'], + (sa.sql.expression.func.random(), self.SORT_NONE) + ), + + ( + ['name'], + (model.Tag.first_name, self.SORT_ASC) + ), + + ( + ['category'], + (model.TagCategory.name, self.SORT_ASC) + ), + + ( + ['creation-date', 'creation-time'], + (model.Tag.creation_time, self.SORT_DESC) + ), + + ( + ['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'], + (model.Tag.last_edit_time, self.SORT_DESC) + ), + + ( + ['usage-count', 'post-count', 'usages'], + (model.Tag.post_count, self.SORT_DESC) + ), + + ( + ['suggestion-count'], + (model.Tag.suggestion_count, self.SORT_DESC) + ), + + ( + ['implication-count'], + (model.Tag.implication_count, self.SORT_DESC) + ), + ]) diff --git a/server/szurubooru/search/configs/user_search_config.py b/server/szurubooru/search/configs/user_search_config.py index c7e727e6..64534009 100644 --- a/server/szurubooru/search/configs/user_search_config.py +++ b/server/szurubooru/search/configs/user_search_config.py @@ -1,53 +1,57 @@ -from sqlalchemy.sql.expression import func -from szurubooru import db +from typing import Tuple, Dict +import sqlalchemy as sa +from szurubooru import db, model +from szurubooru.search.typing import SaColumn, SaQuery from szurubooru.search.configs import util as search_util -from szurubooru.search.configs.base_search_config import BaseSearchConfig +from szurubooru.search.configs.base_search_config import ( + BaseSearchConfig, Filter) class UserSearchConfig(BaseSearchConfig): - def create_filter_query(self, _disable_eager_loads): - return db.session.query(db.User) + def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery: + return db.session.query(model.User) - def create_count_query(self, _disable_eager_loads): - return db.session.query(db.User) + def create_count_query(self, _disable_eager_loads: bool) -> SaQuery: + return db.session.query(model.User) - def create_around_query(self): + def create_around_query(self) -> SaQuery: raise NotImplementedError() - def finalize_query(self, query): - return query.order_by(db.User.name.asc()) + def finalize_query(self, query: SaQuery) -> SaQuery: + return query.order_by(model.User.name.asc()) @property - def anonymous_filter(self): - return search_util.create_str_filter(db.User.name) + def anonymous_filter(self) -> Filter: + return search_util.create_str_filter(model.User.name) @property - def named_filters(self): + def named_filters(self) -> Dict[str, Filter]: return { - 'name': search_util.create_str_filter(db.User.name), + 'name': + search_util.create_str_filter(model.User.name), 'creation-date': - search_util.create_date_filter(db.User.creation_time), + search_util.create_date_filter(model.User.creation_time), 'creation-time': - search_util.create_date_filter(db.User.creation_time), + search_util.create_date_filter(model.User.creation_time), 'last-login-date': - search_util.create_date_filter(db.User.last_login_time), + search_util.create_date_filter(model.User.last_login_time), 'last-login-time': - search_util.create_date_filter(db.User.last_login_time), + search_util.create_date_filter(model.User.last_login_time), 'login-date': - search_util.create_date_filter(db.User.last_login_time), + search_util.create_date_filter(model.User.last_login_time), 'login-time': - search_util.create_date_filter(db.User.last_login_time), + search_util.create_date_filter(model.User.last_login_time), } @property - def sort_columns(self): + def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]: return { - 'random': (func.random(), None), - 'name': (db.User.name, self.SORT_ASC), - 'creation-date': (db.User.creation_time, self.SORT_DESC), - 'creation-time': (db.User.creation_time, self.SORT_DESC), - 'last-login-date': (db.User.last_login_time, self.SORT_DESC), - 'last-login-time': (db.User.last_login_time, self.SORT_DESC), - 'login-date': (db.User.last_login_time, self.SORT_DESC), - 'login-time': (db.User.last_login_time, self.SORT_DESC), + 'random': (sa.sql.expression.func.random(), self.SORT_NONE), + 'name': (model.User.name, self.SORT_ASC), + 'creation-date': (model.User.creation_time, self.SORT_DESC), + 'creation-time': (model.User.creation_time, self.SORT_DESC), + 'last-login-date': (model.User.last_login_time, self.SORT_DESC), + 'last-login-time': (model.User.last_login_time, self.SORT_DESC), + 'login-date': (model.User.last_login_time, self.SORT_DESC), + 'login-time': (model.User.last_login_time, self.SORT_DESC), } diff --git a/server/szurubooru/search/configs/util.py b/server/szurubooru/search/configs/util.py index 2eaaf8d7..086f3921 100644 --- a/server/szurubooru/search/configs/util.py +++ b/server/szurubooru/search/configs/util.py @@ -1,10 +1,13 @@ -import sqlalchemy +from typing import Any, Optional, Callable +import sqlalchemy as sa from szurubooru import db, errors from szurubooru.func import util from szurubooru.search import criteria +from szurubooru.search.typing import SaColumn, SaQuery +from szurubooru.search.configs.base_search_config import Filter -def wildcard_transformer(value): +def wildcard_transformer(value: str) -> str: return ( value .replace('\\', '\\\\') @@ -13,24 +16,21 @@ def wildcard_transformer(value): .replace('*', '%')) -def apply_num_criterion_to_column(column, criterion): - ''' - Decorate SQLAlchemy filter on given column using supplied criterion. - ''' +def apply_num_criterion_to_column( + column: Any, criterion: criteria.BaseCriterion) -> Any: try: if isinstance(criterion, criteria.PlainCriterion): expr = column == int(criterion.value) elif isinstance(criterion, criteria.ArrayCriterion): expr = column.in_(int(value) for value in criterion.values) elif isinstance(criterion, criteria.RangedCriterion): - assert criterion.min_value != '' \ - or criterion.max_value != '' - if criterion.min_value != '' and criterion.max_value != '': + assert criterion.min_value or criterion.max_value + if criterion.min_value and criterion.max_value: expr = column.between( int(criterion.min_value), int(criterion.max_value)) - elif criterion.min_value != '': + elif criterion.min_value: expr = column >= int(criterion.min_value) - elif criterion.max_value != '': + elif criterion.max_value: expr = column <= int(criterion.max_value) else: assert False @@ -40,10 +40,13 @@ def apply_num_criterion_to_column(column, criterion): return expr -def create_num_filter(column): - def wrapper(query, criterion, negated): - expr = apply_num_criterion_to_column( - column, criterion) +def create_num_filter(column: Any) -> Filter: + def wrapper( + query: SaQuery, + criterion: Optional[criteria.BaseCriterion], + negated: bool) -> SaQuery: + assert criterion + expr = apply_num_criterion_to_column(column, criterion) if negated: expr = ~expr return query.filter(expr) @@ -51,14 +54,13 @@ def create_num_filter(column): def apply_str_criterion_to_column( - column, criterion, transformer=wildcard_transformer): - ''' - Decorate SQLAlchemy filter on given column using supplied criterion. - ''' + column: SaColumn, + criterion: criteria.BaseCriterion, + transformer: Callable[[str], str]=wildcard_transformer) -> SaQuery: if isinstance(criterion, criteria.PlainCriterion): expr = column.ilike(transformer(criterion.value)) elif isinstance(criterion, criteria.ArrayCriterion): - expr = sqlalchemy.sql.false() + expr = sa.sql.false() for value in criterion.values: expr = expr | column.ilike(transformer(value)) elif isinstance(criterion, criteria.RangedCriterion): @@ -68,8 +70,15 @@ def apply_str_criterion_to_column( return expr -def create_str_filter(column, transformer=wildcard_transformer): - def wrapper(query, criterion, negated): +def create_str_filter( + column: SaColumn, + transformer: Callable[[str], str]=wildcard_transformer +) -> Filter: + def wrapper( + query: SaQuery, + criterion: Optional[criteria.BaseCriterion], + negated: bool) -> SaQuery: + assert criterion expr = apply_str_criterion_to_column( column, criterion, transformer) if negated: @@ -78,16 +87,13 @@ def create_str_filter(column, transformer=wildcard_transformer): return wrapper -def apply_date_criterion_to_column(column, criterion): - ''' - Decorate SQLAlchemy filter on given column using supplied criterion. - Parse the datetime inside the criterion. - ''' +def apply_date_criterion_to_column( + column: SaQuery, criterion: criteria.BaseCriterion) -> SaQuery: if isinstance(criterion, criteria.PlainCriterion): min_date, max_date = util.parse_time_range(criterion.value) expr = column.between(min_date, max_date) elif isinstance(criterion, criteria.ArrayCriterion): - expr = sqlalchemy.sql.false() + expr = sa.sql.false() for value in criterion.values: min_date, max_date = util.parse_time_range(value) expr = expr | column.between(min_date, max_date) @@ -108,10 +114,13 @@ def apply_date_criterion_to_column(column, criterion): return expr -def create_date_filter(column): - def wrapper(query, criterion, negated): - expr = apply_date_criterion_to_column( - column, criterion) +def create_date_filter(column: SaColumn) -> Filter: + def wrapper( + query: SaQuery, + criterion: Optional[criteria.BaseCriterion], + negated: bool) -> SaQuery: + assert criterion + expr = apply_date_criterion_to_column(column, criterion) if negated: expr = ~expr return query.filter(expr) @@ -119,18 +128,22 @@ def create_date_filter(column): def create_subquery_filter( - left_id_column, - right_id_column, - filter_column, - filter_factory, - subquery_decorator=None): + left_id_column: SaColumn, + right_id_column: SaColumn, + filter_column: SaColumn, + filter_factory: SaColumn, + subquery_decorator: Callable[[SaQuery], None]=None) -> Filter: filter_func = filter_factory(filter_column) - def wrapper(query, criterion, negated): + def wrapper( + query: SaQuery, + criterion: Optional[criteria.BaseCriterion], + negated: bool) -> SaQuery: + assert criterion subquery = db.session.query(right_id_column.label('foreign_id')) if subquery_decorator: subquery = subquery_decorator(subquery) - subquery = subquery.options(sqlalchemy.orm.lazyload('*')) + subquery = subquery.options(sa.orm.lazyload('*')) subquery = filter_func(subquery, criterion, False) subquery = subquery.subquery('t') expression = left_id_column.in_(subquery) diff --git a/server/szurubooru/search/criteria.py b/server/szurubooru/search/criteria.py index 9d4dc664..7b1fee31 100644 --- a/server/szurubooru/search/criteria.py +++ b/server/szurubooru/search/criteria.py @@ -1,34 +1,42 @@ -class _BaseCriterion: - def __init__(self, original_text): +from typing import Optional, List, Callable +from szurubooru.search.typing import SaQuery + + +class BaseCriterion: + def __init__(self, original_text: str) -> None: self.original_text = original_text - def __repr__(self): + def __repr__(self) -> str: return self.original_text -class RangedCriterion(_BaseCriterion): - def __init__(self, original_text, min_value, max_value): +class RangedCriterion(BaseCriterion): + def __init__( + self, + original_text: str, + min_value: Optional[str], + max_value: Optional[str]) -> None: super().__init__(original_text) self.min_value = min_value self.max_value = max_value - def __hash__(self): + def __hash__(self) -> int: return hash(('range', self.min_value, self.max_value)) -class PlainCriterion(_BaseCriterion): - def __init__(self, original_text, value): +class PlainCriterion(BaseCriterion): + def __init__(self, original_text: str, value: str) -> None: super().__init__(original_text) self.value = value - def __hash__(self): + def __hash__(self) -> int: return hash(self.value) -class ArrayCriterion(_BaseCriterion): - def __init__(self, original_text, values): +class ArrayCriterion(BaseCriterion): + def __init__(self, original_text: str, values: List[str]) -> None: super().__init__(original_text) self.values = values - def __hash__(self): + def __hash__(self) -> int: return hash(tuple(['array'] + self.values)) diff --git a/server/szurubooru/search/executor.py b/server/szurubooru/search/executor.py index d9adc940..2819593e 100644 --- a/server/szurubooru/search/executor.py +++ b/server/szurubooru/search/executor.py @@ -1,14 +1,18 @@ -import sqlalchemy -from szurubooru import db, errors +from typing import Union, Tuple, List, Dict, Callable +import sqlalchemy as sa +from szurubooru import db, model, errors, rest from szurubooru.func import cache from szurubooru.search import tokens, parser +from szurubooru.search.typing import SaQuery +from szurubooru.search.query import SearchQuery +from szurubooru.search.configs.base_search_config import BaseSearchConfig -def _format_dict_keys(source): +def _format_dict_keys(source: Dict) -> List[str]: return list(sorted(source.keys())) -def _get_order(order, default_order): +def _get_order(order: str, default_order: str) -> Union[bool, str]: if order == tokens.SortToken.SORT_DEFAULT: return default_order or tokens.SortToken.SORT_ASC if order == tokens.SortToken.SORT_NEGATED_DEFAULT: @@ -26,50 +30,57 @@ class Executor: delegates sqlalchemy filter decoration to SearchConfig instances. ''' - def __init__(self, search_config): + def __init__(self, search_config: BaseSearchConfig) -> None: self.config = search_config self.parser = parser.Parser() - def get_around(self, query_text, entity_id): + def get_around( + self, + query_text: str, + entity_id: int) -> Tuple[model.Base, model.Base]: search_query = self.parser.parse(query_text) self.config.on_search_query_parsed(search_query) filter_query = ( self.config .create_around_query() - .options(sqlalchemy.orm.lazyload('*'))) + .options(sa.orm.lazyload('*'))) filter_query = self._prepare_db_query( filter_query, search_query, False) prev_filter_query = ( filter_query .filter(self.config.id_column > entity_id) .order_by(None) - .order_by(sqlalchemy.func.abs( - self.config.id_column - entity_id).asc()) + .order_by(sa.func.abs(self.config.id_column - entity_id).asc()) .limit(1)) next_filter_query = ( filter_query .filter(self.config.id_column < entity_id) .order_by(None) - .order_by(sqlalchemy.func.abs( - self.config.id_column - entity_id).asc()) + .order_by(sa.func.abs(self.config.id_column - entity_id).asc()) .limit(1)) - return [ + return ( prev_filter_query.one_or_none(), - next_filter_query.one_or_none()] + next_filter_query.one_or_none()) - def get_around_and_serialize(self, ctx, entity_id, serializer): - entities = self.get_around(ctx.get_param_as_string('query'), entity_id) + def get_around_and_serialize( + self, + ctx: rest.Context, + entity_id: int, + serializer: Callable[[model.Base], rest.Response] + ) -> rest.Response: + entities = self.get_around( + ctx.get_param_as_string('query', default=''), entity_id) return { 'prev': serializer(entities[0]), 'next': serializer(entities[1]), } - def execute(self, query_text, page, page_size): - ''' - Parse input and return tuple containing total record count and filtered - entities. - ''' - + def execute( + self, + query_text: str, + page: int, + page_size: int + ) -> Tuple[int, List[model.Base]]: search_query = self.parser.parse(query_text) self.config.on_search_query_parsed(search_query) @@ -83,7 +94,7 @@ class Executor: return cache.get(key) filter_query = self.config.create_filter_query(disable_eager_loads) - filter_query = filter_query.options(sqlalchemy.orm.lazyload('*')) + filter_query = filter_query.options(sa.orm.lazyload('*')) filter_query = self._prepare_db_query(filter_query, search_query, True) entities = filter_query \ .offset(max(page - 1, 0) * page_size) \ @@ -91,11 +102,11 @@ class Executor: .all() count_query = self.config.create_count_query(disable_eager_loads) - count_query = count_query.options(sqlalchemy.orm.lazyload('*')) + count_query = count_query.options(sa.orm.lazyload('*')) count_query = self._prepare_db_query(count_query, search_query, False) count_statement = count_query \ .statement \ - .with_only_columns([sqlalchemy.func.count()]) \ + .with_only_columns([sa.func.count()]) \ .order_by(None) count = db.session.execute(count_statement).scalar() @@ -103,8 +114,12 @@ class Executor: cache.put(key, ret) return ret - def execute_and_serialize(self, ctx, serializer): - query = ctx.get_param_as_string('query') + def execute_and_serialize( + self, + ctx: rest.Context, + serializer: Callable[[model.Base], rest.Response] + ) -> rest.Response: + query = ctx.get_param_as_string('query', default='') page = ctx.get_param_as_int('page', default=1, min=1) page_size = ctx.get_param_as_int( 'pageSize', default=100, min=1, max=100) @@ -117,48 +132,51 @@ class Executor: 'results': [serializer(entity) for entity in entities], } - def _prepare_db_query(self, db_query, search_query, use_sort): - ''' Parse input and return SQLAlchemy query. ''' - - for token in search_query.anonymous_tokens: + def _prepare_db_query( + self, + db_query: SaQuery, + search_query: SearchQuery, + use_sort: bool) -> SaQuery: + for anon_token in search_query.anonymous_tokens: if not self.config.anonymous_filter: raise errors.SearchError( 'Anonymous tokens are not valid in this context.') db_query = self.config.anonymous_filter( - db_query, token.criterion, token.negated) + db_query, anon_token.criterion, anon_token.negated) - for token in search_query.named_tokens: - if token.name not in self.config.named_filters: + for named_token in search_query.named_tokens: + if named_token.name not in self.config.named_filters: raise errors.SearchError( 'Unknown named token: %r. Available named tokens: %r.' % ( - token.name, + named_token.name, _format_dict_keys(self.config.named_filters))) - db_query = self.config.named_filters[token.name]( - db_query, token.criterion, token.negated) + db_query = self.config.named_filters[named_token.name]( + db_query, named_token.criterion, named_token.negated) - for token in search_query.special_tokens: - if token.value not in self.config.special_filters: + for sp_token in search_query.special_tokens: + if sp_token.value not in self.config.special_filters: raise errors.SearchError( 'Unknown special token: %r. ' 'Available special tokens: %r.' % ( - token.value, + sp_token.value, _format_dict_keys(self.config.special_filters))) - db_query = self.config.special_filters[token.value]( - db_query, token.negated) + db_query = self.config.special_filters[sp_token.value]( + db_query, None, sp_token.negated) if use_sort: - for token in search_query.sort_tokens: - if token.name not in self.config.sort_columns: + for sort_token in search_query.sort_tokens: + if sort_token.name not in self.config.sort_columns: raise errors.SearchError( 'Unknown sort token: %r. ' 'Available sort tokens: %r.' % ( - token.name, + sort_token.name, _format_dict_keys(self.config.sort_columns))) - column, default_order = self.config.sort_columns[token.name] - order = _get_order(token.order, default_order) - if order == token.SORT_ASC: + column, default_order = ( + self.config.sort_columns[sort_token.name]) + order = _get_order(sort_token.order, default_order) + if order == sort_token.SORT_ASC: db_query = db_query.order_by(column.asc()) - elif order == token.SORT_DESC: + elif order == sort_token.SORT_DESC: db_query = db_query.order_by(column.desc()) db_query = self.config.finalize_query(db_query) diff --git a/server/szurubooru/search/parser.py b/server/szurubooru/search/parser.py index 33b41173..93affe26 100644 --- a/server/szurubooru/search/parser.py +++ b/server/szurubooru/search/parser.py @@ -1,9 +1,12 @@ import re +from typing import List from szurubooru import errors from szurubooru.search import criteria, tokens +from szurubooru.search.query import SearchQuery -def _create_criterion(original_value, value): +def _create_criterion( + original_value: str, value: str) -> criteria.BaseCriterion: if ',' in value: return criteria.ArrayCriterion( original_value, value.split(',')) @@ -15,12 +18,12 @@ def _create_criterion(original_value, value): return criteria.PlainCriterion(original_value, value) -def _parse_anonymous(value, negated): +def _parse_anonymous(value: str, negated: bool) -> tokens.AnonymousToken: criterion = _create_criterion(value, value) return tokens.AnonymousToken(criterion, negated) -def _parse_named(key, value, negated): +def _parse_named(key: str, value: str, negated: bool) -> tokens.NamedToken: original_value = value if key.endswith('-min'): key = key[:-4] @@ -32,11 +35,11 @@ def _parse_named(key, value, negated): return tokens.NamedToken(key, criterion, negated) -def _parse_special(value, negated): +def _parse_special(value: str, negated: bool) -> tokens.SpecialToken: return tokens.SpecialToken(value, negated) -def _parse_sort(value, negated): +def _parse_sort(value: str, negated: bool) -> tokens.SortToken: if value.count(',') == 0: order_str = None elif value.count(',') == 1: @@ -67,23 +70,8 @@ def _parse_sort(value, negated): return tokens.SortToken(value, order) -class SearchQuery: - def __init__(self): - self.anonymous_tokens = [] - self.named_tokens = [] - self.special_tokens = [] - self.sort_tokens = [] - - def __hash__(self): - return hash(( - tuple(self.anonymous_tokens), - tuple(self.named_tokens), - tuple(self.special_tokens), - tuple(self.sort_tokens))) - - class Parser: - def parse(self, query_text): + def parse(self, query_text: str) -> SearchQuery: query = SearchQuery() for chunk in re.split(r'\s+', (query_text or '').lower()): if not chunk: diff --git a/server/szurubooru/search/query.py b/server/szurubooru/search/query.py new file mode 100644 index 00000000..7d29dbd3 --- /dev/null +++ b/server/szurubooru/search/query.py @@ -0,0 +1,16 @@ +from szurubooru.search import tokens + + +class SearchQuery: + def __init__(self) -> None: + self.anonymous_tokens = [] # type: List[tokens.AnonymousToken] + self.named_tokens = [] # type: List[tokens.NamedToken] + self.special_tokens = [] # type: List[tokens.SpecialToken] + self.sort_tokens = [] # type: List[tokens.SortToken] + + def __hash__(self) -> int: + return hash(( + tuple(self.anonymous_tokens), + tuple(self.named_tokens), + tuple(self.special_tokens), + tuple(self.sort_tokens))) diff --git a/server/szurubooru/search/tokens.py b/server/szurubooru/search/tokens.py index cff7dc5f..0cd7fd7d 100644 --- a/server/szurubooru/search/tokens.py +++ b/server/szurubooru/search/tokens.py @@ -1,39 +1,44 @@ +from szurubooru.search.criteria import BaseCriterion + + class AnonymousToken: - def __init__(self, criterion, negated): + def __init__(self, criterion: BaseCriterion, negated: bool) -> None: self.criterion = criterion self.negated = negated - def __hash__(self): + def __hash__(self) -> int: return hash((self.criterion, self.negated)) class NamedToken(AnonymousToken): - def __init__(self, name, criterion, negated): + def __init__( + self, name: str, criterion: BaseCriterion, negated: bool) -> None: super().__init__(criterion, negated) self.name = name - def __hash__(self): + def __hash__(self) -> int: return hash((self.name, self.criterion, self.negated)) class SortToken: SORT_DESC = 'desc' SORT_ASC = 'asc' + SORT_NONE = '' SORT_DEFAULT = 'default' SORT_NEGATED_DEFAULT = 'negated default' - def __init__(self, name, order): + def __init__(self, name: str, order: str) -> None: self.name = name self.order = order - def __hash__(self): + def __hash__(self) -> int: return hash((self.name, self.order)) class SpecialToken: - def __init__(self, value, negated): + def __init__(self, value: str, negated: bool) -> None: self.value = value self.negated = negated - def __hash__(self): + def __hash__(self) -> int: return hash((self.value, self.negated)) diff --git a/server/szurubooru/search/typing.py b/server/szurubooru/search/typing.py new file mode 100644 index 00000000..ebb1b30d --- /dev/null +++ b/server/szurubooru/search/typing.py @@ -0,0 +1,6 @@ +from typing import Any, Callable + + +SaColumn = Any +SaQuery = Any +SaQueryFactory = Callable[[], SaQuery] diff --git a/server/szurubooru/tests/api/test_comment_creating.py b/server/szurubooru/tests/api/test_comment_creating.py index c7d0b0f6..ad243661 100644 --- a/server/szurubooru/tests/api/test_comment_creating.py +++ b/server/szurubooru/tests/api/test_comment_creating.py @@ -1,19 +1,20 @@ from datetime import datetime from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import comments, posts @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'comments:create': db.User.RANK_REGULAR}}) + config_injector( + {'privileges': {'comments:create': model.User.RANK_REGULAR}}) def test_creating_comment( user_factory, post_factory, context_factory, fake_datetime): post = post_factory() - user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) db.session.add_all([post, user]) db.session.flush() with patch('szurubooru.func.comments.serialize_comment'), \ @@ -24,7 +25,7 @@ def test_creating_comment( params={'text': 'input', 'postId': post.post_id}, user=user)) assert result == 'serialized comment' - comment = db.session.query(db.Comment).one() + comment = db.session.query(model.Comment).one() assert comment.text == 'input' assert comment.creation_time == datetime(1997, 1, 1) assert comment.last_edit_time is None @@ -41,7 +42,7 @@ def test_creating_comment( def test_trying_to_pass_invalid_params( user_factory, post_factory, context_factory, params): post = post_factory() - user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) db.session.add_all([post, user]) db.session.flush() real_params = {'text': 'input', 'postId': post.post_id} @@ -63,11 +64,11 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): api.comment_api.create_comment( context_factory( params={}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_comment_non_existing(user_factory, context_factory): - user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) db.session.add_all([user]) db.session.flush() with pytest.raises(posts.PostNotFoundError): @@ -81,4 +82,4 @@ def test_trying_to_create_without_privileges(user_factory, context_factory): api.comment_api.create_comment( context_factory( params={}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_comment_deleting.py b/server/szurubooru/tests/api/test_comment_deleting.py index efb432a6..e1d1baa0 100644 --- a/server/szurubooru/tests/api/test_comment_deleting.py +++ b/server/szurubooru/tests/api/test_comment_deleting.py @@ -1,5 +1,5 @@ import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import comments @@ -7,8 +7,8 @@ from szurubooru.func import comments def inject_config(config_injector): config_injector({ 'privileges': { - 'comments:delete:own': db.User.RANK_REGULAR, - 'comments:delete:any': db.User.RANK_MODERATOR, + 'comments:delete:own': model.User.RANK_REGULAR, + 'comments:delete:any': model.User.RANK_MODERATOR, }, }) @@ -22,26 +22,26 @@ def test_deleting_own_comment(user_factory, comment_factory, context_factory): context_factory(params={'version': 1}, user=user), {'comment_id': comment.comment_id}) assert result == {} - assert db.session.query(db.Comment).count() == 0 + assert db.session.query(model.Comment).count() == 0 def test_deleting_someones_else_comment( user_factory, comment_factory, context_factory): - user1 = user_factory(rank=db.User.RANK_REGULAR) - user2 = user_factory(rank=db.User.RANK_MODERATOR) + user1 = user_factory(rank=model.User.RANK_REGULAR) + user2 = user_factory(rank=model.User.RANK_MODERATOR) comment = comment_factory(user=user1) db.session.add(comment) db.session.commit() api.comment_api.delete_comment( context_factory(params={'version': 1}, user=user2), {'comment_id': comment.comment_id}) - assert db.session.query(db.Comment).count() == 0 + assert db.session.query(model.Comment).count() == 0 def test_trying_to_delete_someones_else_comment_without_privileges( user_factory, comment_factory, context_factory): - user1 = user_factory(rank=db.User.RANK_REGULAR) - user2 = user_factory(rank=db.User.RANK_REGULAR) + user1 = user_factory(rank=model.User.RANK_REGULAR) + user2 = user_factory(rank=model.User.RANK_REGULAR) comment = comment_factory(user=user1) db.session.add(comment) db.session.commit() @@ -49,7 +49,7 @@ def test_trying_to_delete_someones_else_comment_without_privileges( api.comment_api.delete_comment( context_factory(params={'version': 1}, user=user2), {'comment_id': comment.comment_id}) - assert db.session.query(db.Comment).count() == 1 + assert db.session.query(model.Comment).count() == 1 def test_trying_to_delete_non_existing(user_factory, context_factory): @@ -57,5 +57,5 @@ def test_trying_to_delete_non_existing(user_factory, context_factory): api.comment_api.delete_comment( context_factory( params={'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'comment_id': 1}) diff --git a/server/szurubooru/tests/api/test_comment_rating.py b/server/szurubooru/tests/api/test_comment_rating.py index 981e0dd8..aae5e241 100644 --- a/server/szurubooru/tests/api/test_comment_rating.py +++ b/server/szurubooru/tests/api/test_comment_rating.py @@ -1,17 +1,18 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import comments @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'comments:score': db.User.RANK_REGULAR}}) + config_injector( + {'privileges': {'comments:score': model.User.RANK_REGULAR}}) def test_simple_rating( user_factory, comment_factory, context_factory, fake_datetime): - user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() @@ -22,14 +23,14 @@ def test_simple_rating( context_factory(params={'score': 1}, user=user), {'comment_id': comment.comment_id}) assert result == 'serialized comment' - assert db.session.query(db.CommentScore).count() == 1 + assert db.session.query(model.CommentScore).count() == 1 assert comment is not None assert comment.score == 1 def test_updating_rating( user_factory, comment_factory, context_factory, fake_datetime): - user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() @@ -42,14 +43,14 @@ def test_updating_rating( api.comment_api.set_comment_score( context_factory(params={'score': -1}, user=user), {'comment_id': comment.comment_id}) - comment = db.session.query(db.Comment).one() - assert db.session.query(db.CommentScore).count() == 1 + comment = db.session.query(model.Comment).one() + assert db.session.query(model.CommentScore).count() == 1 assert comment.score == -1 def test_updating_rating_to_zero( user_factory, comment_factory, context_factory, fake_datetime): - user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() @@ -62,14 +63,14 @@ def test_updating_rating_to_zero( api.comment_api.set_comment_score( context_factory(params={'score': 0}, user=user), {'comment_id': comment.comment_id}) - comment = db.session.query(db.Comment).one() - assert db.session.query(db.CommentScore).count() == 0 + comment = db.session.query(model.Comment).one() + assert db.session.query(model.CommentScore).count() == 0 assert comment.score == 0 def test_deleting_rating( user_factory, comment_factory, context_factory, fake_datetime): - user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() @@ -82,15 +83,15 @@ def test_deleting_rating( api.comment_api.delete_comment_score( context_factory(user=user), {'comment_id': comment.comment_id}) - comment = db.session.query(db.Comment).one() - assert db.session.query(db.CommentScore).count() == 0 + comment = db.session.query(model.Comment).one() + assert db.session.query(model.CommentScore).count() == 0 assert comment.score == 0 def test_ratings_from_multiple_users( user_factory, comment_factory, context_factory, fake_datetime): - user1 = user_factory(rank=db.User.RANK_REGULAR) - user2 = user_factory(rank=db.User.RANK_REGULAR) + user1 = user_factory(rank=model.User.RANK_REGULAR) + user2 = user_factory(rank=model.User.RANK_REGULAR) comment = comment_factory() db.session.add_all([user1, user2, comment]) db.session.commit() @@ -103,8 +104,8 @@ def test_ratings_from_multiple_users( api.comment_api.set_comment_score( context_factory(params={'score': -1}, user=user2), {'comment_id': comment.comment_id}) - comment = db.session.query(db.Comment).one() - assert db.session.query(db.CommentScore).count() == 2 + comment = db.session.query(model.Comment).one() + assert db.session.query(model.CommentScore).count() == 2 assert comment.score == 0 @@ -125,7 +126,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory): api.comment_api.set_comment_score( context_factory( params={'score': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'comment_id': 5}) @@ -138,5 +139,5 @@ def test_trying_to_rate_without_privileges( api.comment_api.set_comment_score( context_factory( params={'score': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS)), + user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'comment_id': comment.comment_id}) diff --git a/server/szurubooru/tests/api/test_comment_retrieving.py b/server/szurubooru/tests/api/test_comment_retrieving.py index 908e9eb8..e0378fa2 100644 --- a/server/szurubooru/tests/api/test_comment_retrieving.py +++ b/server/szurubooru/tests/api/test_comment_retrieving.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import comments @@ -8,8 +8,8 @@ from szurubooru.func import comments def inject_config(config_injector): config_injector({ 'privileges': { - 'comments:list': db.User.RANK_REGULAR, - 'comments:view': db.User.RANK_REGULAR, + 'comments:list': model.User.RANK_REGULAR, + 'comments:view': model.User.RANK_REGULAR, }, }) @@ -24,7 +24,7 @@ def test_retrieving_multiple(user_factory, comment_factory, context_factory): result = api.comment_api.get_comments( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) assert result == { 'query': '', 'page': 1, @@ -40,7 +40,7 @@ def test_trying_to_retrieve_multiple_without_privileges( api.comment_api.get_comments( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) def test_retrieving_single(user_factory, comment_factory, context_factory): @@ -51,7 +51,7 @@ def test_retrieving_single(user_factory, comment_factory, context_factory): comments.serialize_comment.return_value = 'serialized comment' result = api.comment_api.get_comment( context_factory( - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'comment_id': comment.comment_id}) assert result == 'serialized comment' @@ -60,7 +60,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): with pytest.raises(comments.CommentNotFoundError): api.comment_api.get_comment( context_factory( - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'comment_id': 5}) @@ -68,5 +68,5 @@ def test_trying_to_retrieve_single_without_privileges( user_factory, context_factory): with pytest.raises(errors.AuthError): api.comment_api.get_comment( - context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'comment_id': 5}) diff --git a/server/szurubooru/tests/api/test_comment_updating.py b/server/szurubooru/tests/api/test_comment_updating.py index 5f3d12b0..761b1ce0 100644 --- a/server/szurubooru/tests/api/test_comment_updating.py +++ b/server/szurubooru/tests/api/test_comment_updating.py @@ -1,7 +1,7 @@ from datetime import datetime from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import comments @@ -9,15 +9,15 @@ from szurubooru.func import comments def inject_config(config_injector): config_injector({ 'privileges': { - 'comments:edit:own': db.User.RANK_REGULAR, - 'comments:edit:any': db.User.RANK_MODERATOR, + 'comments:edit:own': model.User.RANK_REGULAR, + 'comments:edit:any': model.User.RANK_MODERATOR, }, }) def test_simple_updating( user_factory, comment_factory, context_factory, fake_datetime): - user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() @@ -73,14 +73,14 @@ def test_trying_to_update_non_existing(user_factory, context_factory): api.comment_api.update_comment( context_factory( params={'text': 'new text'}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'comment_id': 5}) def test_trying_to_update_someones_comment_without_privileges( user_factory, comment_factory, context_factory): - user = user_factory(rank=db.User.RANK_REGULAR) - user2 = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(rank=model.User.RANK_REGULAR) + user2 = user_factory(rank=model.User.RANK_REGULAR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() @@ -93,8 +93,8 @@ def test_trying_to_update_someones_comment_without_privileges( def test_updating_someones_comment_with_privileges( user_factory, comment_factory, context_factory): - user = user_factory(rank=db.User.RANK_REGULAR) - user2 = user_factory(rank=db.User.RANK_MODERATOR) + user = user_factory(rank=model.User.RANK_REGULAR) + user2 = user_factory(rank=model.User.RANK_MODERATOR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() diff --git a/server/szurubooru/tests/api/test_password_reset.py b/server/szurubooru/tests/api/test_password_reset.py index 52b568da..e46dbbec 100644 --- a/server/szurubooru/tests/api/test_password_reset.py +++ b/server/szurubooru/tests/api/test_password_reset.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import auth, mailer @@ -15,7 +15,7 @@ def inject_config(config_injector): def test_reset_sending_email(context_factory, user_factory): db.session.add(user_factory( - name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) + name='u1', rank=model.User.RANK_REGULAR, email='user@example.com')) db.session.flush() for initiating_user in ['u1', 'user@example.com']: with patch('szurubooru.func.mailer.send_mail'): @@ -39,7 +39,7 @@ def test_trying_to_reset_non_existing(context_factory): def test_trying_to_reset_without_email(context_factory, user_factory): db.session.add( - user_factory(name='u1', rank=db.User.RANK_REGULAR, email=None)) + user_factory(name='u1', rank=model.User.RANK_REGULAR, email=None)) db.session.flush() with pytest.raises(errors.ValidationError): api.password_reset_api.start_password_reset( @@ -48,7 +48,7 @@ def test_trying_to_reset_without_email(context_factory, user_factory): def test_confirming_with_good_token(context_factory, user_factory): user = user_factory( - name='u1', rank=db.User.RANK_REGULAR, email='user@example.com') + name='u1', rank=model.User.RANK_REGULAR, email='user@example.com') old_hash = user.password_hash db.session.add(user) db.session.flush() @@ -68,7 +68,7 @@ def test_trying_to_confirm_non_existing(context_factory): def test_trying_to_confirm_without_token(context_factory, user_factory): db.session.add(user_factory( - name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) + name='u1', rank=model.User.RANK_REGULAR, email='user@example.com')) db.session.flush() with pytest.raises(errors.ValidationError): api.password_reset_api.finish_password_reset( @@ -77,7 +77,7 @@ def test_trying_to_confirm_without_token(context_factory, user_factory): def test_trying_to_confirm_with_bad_token(context_factory, user_factory): db.session.add(user_factory( - name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) + name='u1', rank=model.User.RANK_REGULAR, email='user@example.com')) db.session.flush() with pytest.raises(errors.ValidationError): api.password_reset_api.finish_password_reset( diff --git a/server/szurubooru/tests/api/test_post_creating.py b/server/szurubooru/tests/api/test_post_creating.py index 9737a73b..a653b3bf 100644 --- a/server/szurubooru/tests/api/test_post_creating.py +++ b/server/szurubooru/tests/api/test_post_creating.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import posts, tags, snapshots, net @@ -8,16 +8,16 @@ from szurubooru.func import posts, tags, snapshots, net def inject_config(config_injector): config_injector({ 'privileges': { - 'posts:create:anonymous': db.User.RANK_REGULAR, - 'posts:create:identified': db.User.RANK_REGULAR, - 'tags:create': db.User.RANK_REGULAR, + 'posts:create:anonymous': model.User.RANK_REGULAR, + 'posts:create:identified': model.User.RANK_REGULAR, + 'tags:create': model.User.RANK_REGULAR, }, }) def test_creating_minimal_posts( context_factory, post_factory, user_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory() db.session.add(post) db.session.flush() @@ -53,20 +53,20 @@ def test_creating_minimal_posts( posts.update_post_thumbnail.assert_called_once_with( post, 'post-thumbnail') posts.update_post_safety.assert_called_once_with(post, 'safe') - posts.update_post_source.assert_called_once_with(post, None) + posts.update_post_source.assert_called_once_with(post, '') posts.update_post_relations.assert_called_once_with(post, []) posts.update_post_notes.assert_called_once_with(post, []) posts.update_post_flags.assert_called_once_with(post, []) posts.update_post_thumbnail.assert_called_once_with( post, 'post-thumbnail') posts.serialize_post.assert_called_once_with( - post, auth_user, options=None) + post, auth_user, options=[]) snapshots.create.assert_called_once_with(post, auth_user) tags.export_to_json.assert_called_once_with() def test_creating_full_posts(context_factory, post_factory, user_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory() db.session.add(post) db.session.flush() @@ -109,14 +109,14 @@ def test_creating_full_posts(context_factory, post_factory, user_factory): posts.update_post_flags.assert_called_once_with( post, ['flag1', 'flag2']) posts.serialize_post.assert_called_once_with( - post, auth_user, options=None) + post, auth_user, options=[]) snapshots.create.assert_called_once_with(post, auth_user) tags.export_to_json.assert_called_once_with() def test_anonymous_uploads( config_injector, context_factory, post_factory, user_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory() db.session.add(post) db.session.flush() @@ -126,7 +126,7 @@ def test_anonymous_uploads( patch('szurubooru.func.posts.create_post'), \ patch('szurubooru.func.posts.update_post_source'): config_injector({ - 'privileges': {'posts:create:anonymous': db.User.RANK_REGULAR}, + 'privileges': {'posts:create:anonymous': model.User.RANK_REGULAR}, }) posts.create_post.return_value = [post, []] api.post_api.create_post( @@ -146,7 +146,7 @@ def test_anonymous_uploads( def test_creating_from_url_saves_source( config_injector, context_factory, post_factory, user_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory() db.session.add(post) db.session.flush() @@ -157,7 +157,7 @@ def test_creating_from_url_saves_source( patch('szurubooru.func.posts.create_post'), \ patch('szurubooru.func.posts.update_post_source'): config_injector({ - 'privileges': {'posts:create:identified': db.User.RANK_REGULAR}, + 'privileges': {'posts:create:identified': model.User.RANK_REGULAR}, }) net.download.return_value = b'content' posts.create_post.return_value = [post, []] @@ -177,7 +177,7 @@ def test_creating_from_url_saves_source( def test_creating_from_url_with_source_specified( config_injector, context_factory, post_factory, user_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory() db.session.add(post) db.session.flush() @@ -188,7 +188,7 @@ def test_creating_from_url_with_source_specified( patch('szurubooru.func.posts.create_post'), \ patch('szurubooru.func.posts.update_post_source'): config_injector({ - 'privileges': {'posts:create:identified': db.User.RANK_REGULAR}, + 'privileges': {'posts:create:identified': model.User.RANK_REGULAR}, }) net.download.return_value = b'content' posts.create_post.return_value = [post, []] @@ -218,14 +218,14 @@ def test_trying_to_omit_mandatory_field(context_factory, user_factory, field): context_factory( params=params, files={'content': '...'}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) @pytest.mark.parametrize( 'field', ['tags', 'relations', 'source', 'notes', 'flags']) def test_omitting_optional_field( field, context_factory, post_factory, user_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory() db.session.add(post) db.session.flush() @@ -268,10 +268,10 @@ def test_errors_not_spending_ids( 'post_height': 300, }, 'privileges': { - 'posts:create:identified': db.User.RANK_REGULAR, + 'posts:create:identified': model.User.RANK_REGULAR, }, }) - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) # successful request with patch('szurubooru.func.posts.serialize_post'), \ @@ -316,7 +316,7 @@ def test_trying_to_omit_content(context_factory, user_factory): 'safety': 'safe', 'tags': ['tag1', 'tag2'], }, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_create_post_without_privileges( @@ -324,16 +324,16 @@ def test_trying_to_create_post_without_privileges( with pytest.raises(errors.AuthError): api.post_api.create_post(context_factory( params='whatever', - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) def test_trying_to_create_tags_without_privileges( config_injector, context_factory, user_factory): config_injector({ 'privileges': { - 'posts:create:anonymous': db.User.RANK_REGULAR, - 'posts:create:identified': db.User.RANK_REGULAR, - 'tags:create': db.User.RANK_ADMINISTRATOR, + 'posts:create:anonymous': model.User.RANK_REGULAR, + 'posts:create:identified': model.User.RANK_REGULAR, + 'tags:create': model.User.RANK_ADMINISTRATOR, }, }) with pytest.raises(errors.AuthError), \ @@ -349,4 +349,4 @@ def test_trying_to_create_tags_without_privileges( files={ 'content': posts.EMPTY_PIXEL, }, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) diff --git a/server/szurubooru/tests/api/test_post_deleting.py b/server/szurubooru/tests/api/test_post_deleting.py index c4187ed4..643b952c 100644 --- a/server/szurubooru/tests/api/test_post_deleting.py +++ b/server/szurubooru/tests/api/test_post_deleting.py @@ -1,16 +1,16 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import posts, tags, snapshots @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'posts:delete': db.User.RANK_REGULAR}}) + config_injector({'privileges': {'posts:delete': model.User.RANK_REGULAR}}) def test_deleting(user_factory, post_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory(id=1) db.session.add(post) db.session.flush() @@ -20,7 +20,7 @@ def test_deleting(user_factory, post_factory, context_factory): context_factory(params={'version': 1}, user=auth_user), {'post_id': 1}) assert result == {} - assert db.session.query(db.Post).count() == 0 + assert db.session.query(model.Post).count() == 0 snapshots.delete.assert_called_once_with(post, auth_user) tags.export_to_json.assert_called_once_with() @@ -28,7 +28,7 @@ def test_deleting(user_factory, post_factory, context_factory): def test_trying_to_delete_non_existing(user_factory, context_factory): with pytest.raises(posts.PostNotFoundError): api.post_api.delete_post( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'post_id': 999}) @@ -38,6 +38,6 @@ def test_trying_to_delete_without_privileges( db.session.commit() with pytest.raises(errors.AuthError): api.post_api.delete_post( - context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'post_id': 1}) - assert db.session.query(db.Post).count() == 1 + assert db.session.query(model.Post).count() == 1 diff --git a/server/szurubooru/tests/api/test_post_favoriting.py b/server/szurubooru/tests/api/test_post_favoriting.py index d78d199e..ce91a028 100644 --- a/server/szurubooru/tests/api/test_post_favoriting.py +++ b/server/szurubooru/tests/api/test_post_favoriting.py @@ -1,13 +1,14 @@ from datetime import datetime from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import posts @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'posts:favorite': db.User.RANK_REGULAR}}) + config_injector( + {'privileges': {'posts:favorite': model.User.RANK_REGULAR}}) def test_adding_to_favorites( @@ -23,8 +24,8 @@ def test_adding_to_favorites( context_factory(user=user_factory()), {'post_id': post.post_id}) assert result == 'serialized post' - post = db.session.query(db.Post).one() - assert db.session.query(db.PostFavorite).count() == 1 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostFavorite).count() == 1 assert post is not None assert post.favorite_count == 1 assert post.score == 1 @@ -47,9 +48,9 @@ def test_removing_from_favorites( api.post_api.delete_post_from_favorites( context_factory(user=user), {'post_id': post.post_id}) - post = db.session.query(db.Post).one() + post = db.session.query(model.Post).one() assert post.score == 1 - assert db.session.query(db.PostFavorite).count() == 0 + assert db.session.query(model.PostFavorite).count() == 0 assert post.favorite_count == 0 @@ -68,8 +69,8 @@ def test_favoriting_twice( api.post_api.add_post_to_favorites( context_factory(user=user), {'post_id': post.post_id}) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostFavorite).count() == 1 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostFavorite).count() == 1 assert post.favorite_count == 1 @@ -92,8 +93,8 @@ def test_removing_twice( api.post_api.delete_post_from_favorites( context_factory(user=user), {'post_id': post.post_id}) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostFavorite).count() == 0 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostFavorite).count() == 0 assert post.favorite_count == 0 @@ -113,8 +114,8 @@ def test_favorites_from_multiple_users( api.post_api.add_post_to_favorites( context_factory(user=user2), {'post_id': post.post_id}) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostFavorite).count() == 2 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostFavorite).count() == 2 assert post.favorite_count == 2 assert post.last_favorite_time == datetime(1997, 12, 2) @@ -133,5 +134,5 @@ def test_trying_to_rate_without_privileges( db.session.commit() with pytest.raises(errors.AuthError): api.post_api.add_post_to_favorites( - context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'post_id': post.post_id}) diff --git a/server/szurubooru/tests/api/test_post_featuring.py b/server/szurubooru/tests/api/test_post_featuring.py index a0a82c75..88e4e001 100644 --- a/server/szurubooru/tests/api/test_post_featuring.py +++ b/server/szurubooru/tests/api/test_post_featuring.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import posts, snapshots @@ -8,14 +8,14 @@ from szurubooru.func import posts, snapshots def inject_config(config_injector): config_injector({ 'privileges': { - 'posts:feature': db.User.RANK_REGULAR, - 'posts:view': db.User.RANK_REGULAR, + 'posts:feature': model.User.RANK_REGULAR, + 'posts:view': model.User.RANK_REGULAR, }, }) def test_featuring(user_factory, post_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory(id=1) db.session.add(post) db.session.flush() @@ -31,7 +31,7 @@ def test_featuring(user_factory, post_factory, context_factory): assert posts.get_post_by_id(1).is_featured result = api.post_api.get_featured_post( context_factory( - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) assert result == 'serialized post' snapshots.modify.assert_called_once_with(post, auth_user) @@ -40,7 +40,7 @@ def test_trying_to_omit_required_parameter(user_factory, context_factory): with pytest.raises(errors.MissingRequiredParameterError): api.post_api.set_featured_post( context_factory( - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_feature_the_same_post_twice( @@ -51,12 +51,12 @@ def test_trying_to_feature_the_same_post_twice( api.post_api.set_featured_post( context_factory( params={'id': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) with pytest.raises(posts.PostAlreadyFeaturedError): api.post_api.set_featured_post( context_factory( params={'id': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_featuring_one_post_after_another( @@ -72,12 +72,12 @@ def test_featuring_one_post_after_another( api.post_api.set_featured_post( context_factory( params={'id': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) with fake_datetime('1998'): api.post_api.set_featured_post( context_factory( params={'id': 2}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) assert posts.try_get_featured_post() is not None assert posts.try_get_featured_post().post_id == 2 assert not posts.get_post_by_id(1).is_featured @@ -89,7 +89,7 @@ def test_trying_to_feature_non_existing(user_factory, context_factory): api.post_api.set_featured_post( context_factory( params={'id': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_feature_without_privileges(user_factory, context_factory): @@ -97,10 +97,10 @@ def test_trying_to_feature_without_privileges(user_factory, context_factory): api.post_api.set_featured_post( context_factory( params={'id': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) def test_getting_featured_post_without_privileges_to_view( user_factory, context_factory): api.post_api.get_featured_post( - context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS))) + context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_post_merging.py b/server/szurubooru/tests/api/test_post_merging.py index e6540904..eb8464f8 100644 --- a/server/szurubooru/tests/api/test_post_merging.py +++ b/server/szurubooru/tests/api/test_post_merging.py @@ -1,16 +1,16 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import posts, snapshots @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'posts:merge': db.User.RANK_REGULAR}}) + config_injector({'privileges': {'posts:merge': model.User.RANK_REGULAR}}) def test_merging(user_factory, context_factory, post_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) source_post = post_factory() target_post = post_factory() db.session.add_all([source_post, target_post]) @@ -25,6 +25,7 @@ def test_merging(user_factory, context_factory, post_factory): 'mergeToVersion': 1, 'remove': source_post.post_id, 'mergeTo': target_post.post_id, + 'replaceContent': False, }, user=auth_user)) posts.merge_posts.called_once_with(source_post, target_post) @@ -45,13 +46,14 @@ def test_trying_to_omit_mandatory_field( 'mergeToVersion': 1, 'remove': source_post.post_id, 'mergeTo': target_post.post_id, + 'replaceContent': False, } del params[field] with pytest.raises(errors.ValidationError): api.post_api.merge_posts( context_factory( params=params, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_merge_non_existing( @@ -63,12 +65,12 @@ def test_trying_to_merge_non_existing( api.post_api.merge_posts( context_factory( params={'remove': post.post_id, 'mergeTo': 999}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) with pytest.raises(posts.PostNotFoundError): api.post_api.merge_posts( context_factory( params={'remove': 999, 'mergeTo': post.post_id}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_merge_without_privileges( @@ -85,5 +87,6 @@ def test_trying_to_merge_without_privileges( 'mergeToVersion': 1, 'remove': source_post.post_id, 'mergeTo': target_post.post_id, + 'replaceContent': False, }, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_post_rating.py b/server/szurubooru/tests/api/test_post_rating.py index 18e823e7..0fca2f56 100644 --- a/server/szurubooru/tests/api/test_post_rating.py +++ b/server/szurubooru/tests/api/test_post_rating.py @@ -1,12 +1,12 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import posts @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'posts:score': db.User.RANK_REGULAR}}) + config_injector({'privileges': {'posts:score': model.User.RANK_REGULAR}}) def test_simple_rating( @@ -22,8 +22,8 @@ def test_simple_rating( params={'score': 1}, user=user_factory()), {'post_id': post.post_id}) assert result == 'serialized post' - post = db.session.query(db.Post).one() - assert db.session.query(db.PostScore).count() == 1 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostScore).count() == 1 assert post is not None assert post.score == 1 @@ -43,8 +43,8 @@ def test_updating_rating( api.post_api.set_post_score( context_factory(params={'score': -1}, user=user), {'post_id': post.post_id}) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostScore).count() == 1 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostScore).count() == 1 assert post.score == -1 @@ -63,8 +63,8 @@ def test_updating_rating_to_zero( api.post_api.set_post_score( context_factory(params={'score': 0}, user=user), {'post_id': post.post_id}) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostScore).count() == 0 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostScore).count() == 0 assert post.score == 0 @@ -83,8 +83,8 @@ def test_deleting_rating( api.post_api.delete_post_score( context_factory(user=user), {'post_id': post.post_id}) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostScore).count() == 0 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostScore).count() == 0 assert post.score == 0 @@ -104,8 +104,8 @@ def test_ratings_from_multiple_users( api.post_api.set_post_score( context_factory(params={'score': -1}, user=user2), {'post_id': post.post_id}) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostScore).count() == 2 + post = db.session.query(model.Post).one() + assert db.session.query(model.PostScore).count() == 2 assert post.score == 0 @@ -136,5 +136,5 @@ def test_trying_to_rate_without_privileges( api.post_api.set_post_score( context_factory( params={'score': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS)), + user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'post_id': post.post_id}) diff --git a/server/szurubooru/tests/api/test_post_retrieving.py b/server/szurubooru/tests/api/test_post_retrieving.py index a02c7bc1..9d9db72a 100644 --- a/server/szurubooru/tests/api/test_post_retrieving.py +++ b/server/szurubooru/tests/api/test_post_retrieving.py @@ -1,7 +1,7 @@ from datetime import datetime from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import posts @@ -9,8 +9,8 @@ from szurubooru.func import posts def inject_config(config_injector): config_injector({ 'privileges': { - 'posts:list': db.User.RANK_REGULAR, - 'posts:view': db.User.RANK_REGULAR, + 'posts:list': model.User.RANK_REGULAR, + 'posts:view': model.User.RANK_REGULAR, }, }) @@ -25,7 +25,7 @@ def test_retrieving_multiple(user_factory, post_factory, context_factory): result = api.post_api.get_posts( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) assert result == { 'query': '', 'page': 1, @@ -36,10 +36,10 @@ def test_retrieving_multiple(user_factory, post_factory, context_factory): def test_using_special_tokens(user_factory, post_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post1 = post_factory(id=1) post2 = post_factory(id=2) - post1.favorited_by = [db.PostFavorite( + post1.favorited_by = [model.PostFavorite( user=auth_user, time=datetime.utcnow())] db.session.add_all([post1, post2, auth_user]) db.session.flush() @@ -68,7 +68,7 @@ def test_trying_to_use_special_tokens_without_logging_in( api.post_api.get_posts( context_factory( params={'query': 'special:fav', 'page': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) def test_trying_to_retrieve_multiple_without_privileges( @@ -77,7 +77,7 @@ def test_trying_to_retrieve_multiple_without_privileges( api.post_api.get_posts( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) def test_retrieving_single(user_factory, post_factory, context_factory): @@ -86,7 +86,7 @@ def test_retrieving_single(user_factory, post_factory, context_factory): with patch('szurubooru.func.posts.serialize_post'): posts.serialize_post.return_value = 'serialized post' result = api.post_api.get_post( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'post_id': 1}) assert result == 'serialized post' @@ -94,7 +94,7 @@ def test_retrieving_single(user_factory, post_factory, context_factory): def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): with pytest.raises(posts.PostNotFoundError): api.post_api.get_post( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'post_id': 999}) @@ -102,5 +102,5 @@ def test_trying_to_retrieve_single_without_privileges( user_factory, context_factory): with pytest.raises(errors.AuthError): api.post_api.get_post( - context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'post_id': 999}) diff --git a/server/szurubooru/tests/api/test_post_updating.py b/server/szurubooru/tests/api/test_post_updating.py index 790e835e..d3649307 100644 --- a/server/szurubooru/tests/api/test_post_updating.py +++ b/server/szurubooru/tests/api/test_post_updating.py @@ -1,7 +1,7 @@ from datetime import datetime from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import posts, tags, snapshots, net @@ -9,22 +9,22 @@ from szurubooru.func import posts, tags, snapshots, net def inject_config(config_injector): config_injector({ 'privileges': { - 'posts:edit:tags': db.User.RANK_REGULAR, - 'posts:edit:content': db.User.RANK_REGULAR, - 'posts:edit:safety': db.User.RANK_REGULAR, - 'posts:edit:source': db.User.RANK_REGULAR, - 'posts:edit:relations': db.User.RANK_REGULAR, - 'posts:edit:notes': db.User.RANK_REGULAR, - 'posts:edit:flags': db.User.RANK_REGULAR, - 'posts:edit:thumbnail': db.User.RANK_REGULAR, - 'tags:create': db.User.RANK_MODERATOR, + 'posts:edit:tags': model.User.RANK_REGULAR, + 'posts:edit:content': model.User.RANK_REGULAR, + 'posts:edit:safety': model.User.RANK_REGULAR, + 'posts:edit:source': model.User.RANK_REGULAR, + 'posts:edit:relations': model.User.RANK_REGULAR, + 'posts:edit:notes': model.User.RANK_REGULAR, + 'posts:edit:flags': model.User.RANK_REGULAR, + 'posts:edit:thumbnail': model.User.RANK_REGULAR, + 'tags:create': model.User.RANK_MODERATOR, }, }) def test_post_updating( context_factory, post_factory, user_factory, fake_datetime): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) post = post_factory() db.session.add(post) db.session.flush() @@ -76,7 +76,7 @@ def test_post_updating( posts.update_post_flags.assert_called_once_with( post, ['flag1', 'flag2']) posts.serialize_post.assert_called_once_with( - post, auth_user, options=None) + post, auth_user, options=[]) snapshots.modify.assert_called_once_with(post, auth_user) tags.export_to_json.assert_called_once_with() assert post.last_edit_time == datetime(1997, 1, 1) @@ -97,7 +97,7 @@ def test_uploading_from_url_saves_source( api.post_api.update_post( context_factory( params={'contentUrl': 'example.com', 'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'post_id': post.post_id}) net.download.assert_called_once_with('example.com') posts.update_post_content.assert_called_once_with(post, b'content') @@ -122,7 +122,7 @@ def test_uploading_from_url_with_source_specified( 'contentUrl': 'example.com', 'source': 'example2.com', 'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'post_id': post.post_id}) net.download.assert_called_once_with('example.com') posts.update_post_content.assert_called_once_with(post, b'content') @@ -134,7 +134,7 @@ def test_trying_to_update_non_existing(context_factory, user_factory): api.post_api.update_post( context_factory( params='whatever', - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'post_id': 1}) @@ -158,7 +158,7 @@ def test_trying_to_update_field_without_privileges( context_factory( params={**params, **{'version': 1}}, files=files, - user=user_factory(rank=db.User.RANK_ANONYMOUS)), + user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'post_id': post.post_id}) @@ -173,5 +173,5 @@ def test_trying_to_create_tags_without_privileges( api.post_api.update_post( context_factory( params={'tags': ['tag1', 'tag2'], 'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'post_id': post.post_id}) diff --git a/server/szurubooru/tests/api/test_snapshot_retrieving.py b/server/szurubooru/tests/api/test_snapshot_retrieving.py index 73b6f060..facbcd8a 100644 --- a/server/szurubooru/tests/api/test_snapshot_retrieving.py +++ b/server/szurubooru/tests/api/test_snapshot_retrieving.py @@ -1,10 +1,10 @@ from datetime import datetime import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors def snapshot_factory(): - snapshot = db.Snapshot() + snapshot = model.Snapshot() snapshot.creation_time = datetime(1999, 1, 1) snapshot.resource_type = 'dummy' snapshot.resource_pkey = 1 @@ -17,7 +17,7 @@ def snapshot_factory(): @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ - 'privileges': {'snapshots:list': db.User.RANK_REGULAR}, + 'privileges': {'snapshots:list': model.User.RANK_REGULAR}, }) @@ -29,7 +29,7 @@ def test_retrieving_multiple(user_factory, context_factory): result = api.snapshot_api.get_snapshots( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) assert result['query'] == '' assert result['page'] == 1 assert result['pageSize'] == 100 @@ -43,4 +43,4 @@ def test_trying_to_retrieve_multiple_without_privileges( api.snapshot_api.get_snapshots( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_tag_category_creating.py b/server/szurubooru/tests/api/test_tag_category_creating.py index 96afc390..fbd8b1bc 100644 --- a/server/szurubooru/tests/api/test_tag_category_creating.py +++ b/server/szurubooru/tests/api/test_tag_category_creating.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tag_categories, tags, snapshots @@ -11,13 +11,13 @@ def _update_category_name(category, name): @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ - 'privileges': {'tag_categories:create': db.User.RANK_REGULAR}, + 'privileges': {'tag_categories:create': model.User.RANK_REGULAR}, }) def test_creating_category( tag_category_factory, user_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) category = tag_category_factory(name='meta') db.session.add(category) @@ -49,7 +49,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): api.tag_category_api.create_tag_category( context_factory( params=params, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_create_without_privileges(user_factory, context_factory): @@ -57,4 +57,4 @@ def test_trying_to_create_without_privileges(user_factory, context_factory): api.tag_category_api.create_tag_category( context_factory( params={'name': 'meta', 'color': 'black'}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_tag_category_deleting.py b/server/szurubooru/tests/api/test_tag_category_deleting.py index 1f1cde4c..1fc86431 100644 --- a/server/szurubooru/tests/api/test_tag_category_deleting.py +++ b/server/szurubooru/tests/api/test_tag_category_deleting.py @@ -1,18 +1,18 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tag_categories, tags, snapshots @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ - 'privileges': {'tag_categories:delete': db.User.RANK_REGULAR}, + 'privileges': {'tag_categories:delete': model.User.RANK_REGULAR}, }) def test_deleting(user_factory, tag_category_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) category = tag_category_factory(name='category') db.session.add(tag_category_factory(name='root')) db.session.add(category) @@ -23,8 +23,8 @@ def test_deleting(user_factory, tag_category_factory, context_factory): context_factory(params={'version': 1}, user=auth_user), {'category_name': 'category'}) assert result == {} - assert db.session.query(db.TagCategory).count() == 1 - assert db.session.query(db.TagCategory).one().name == 'root' + assert db.session.query(model.TagCategory).count() == 1 + assert db.session.query(model.TagCategory).one().name == 'root' snapshots.delete.assert_called_once_with(category, auth_user) tags.export_to_json.assert_called_once_with() @@ -41,9 +41,9 @@ def test_trying_to_delete_used( api.tag_category_api.delete_tag_category( context_factory( params={'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'category_name': 'category'}) - assert db.session.query(db.TagCategory).count() == 1 + assert db.session.query(model.TagCategory).count() == 1 def test_trying_to_delete_last( @@ -54,14 +54,14 @@ def test_trying_to_delete_last( api.tag_category_api.delete_tag_category( context_factory( params={'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'category_name': 'root'}) def test_trying_to_delete_non_existing(user_factory, context_factory): with pytest.raises(tag_categories.TagCategoryNotFoundError): api.tag_category_api.delete_tag_category( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'category_name': 'bad'}) @@ -73,6 +73,6 @@ def test_trying_to_delete_without_privileges( api.tag_category_api.delete_tag_category( context_factory( params={'version': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS)), + user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'category_name': 'category'}) - assert db.session.query(db.TagCategory).count() == 1 + assert db.session.query(model.TagCategory).count() == 1 diff --git a/server/szurubooru/tests/api/test_tag_category_retrieving.py b/server/szurubooru/tests/api/test_tag_category_retrieving.py index 4f6610b3..0b98d743 100644 --- a/server/szurubooru/tests/api/test_tag_category_retrieving.py +++ b/server/szurubooru/tests/api/test_tag_category_retrieving.py @@ -1,5 +1,5 @@ import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tag_categories @@ -7,8 +7,8 @@ from szurubooru.func import tag_categories def inject_config(config_injector): config_injector({ 'privileges': { - 'tag_categories:list': db.User.RANK_REGULAR, - 'tag_categories:view': db.User.RANK_REGULAR, + 'tag_categories:list': model.User.RANK_REGULAR, + 'tag_categories:view': model.User.RANK_REGULAR, }, }) @@ -21,7 +21,7 @@ def test_retrieving_multiple( ]) db.session.flush() result = api.tag_category_api.get_tag_categories( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR))) + context_factory(user=user_factory(rank=model.User.RANK_REGULAR))) assert [cat['name'] for cat in result['results']] == ['c1', 'c2'] @@ -30,7 +30,7 @@ def test_retrieving_single( db.session.add(tag_category_factory(name='cat')) db.session.flush() result = api.tag_category_api.get_tag_category( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'category_name': 'cat'}) assert result == { 'name': 'cat', @@ -44,7 +44,7 @@ def test_retrieving_single( def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): with pytest.raises(tag_categories.TagCategoryNotFoundError): api.tag_category_api.get_tag_category( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'category_name': '-'}) @@ -52,5 +52,5 @@ def test_trying_to_retrieve_single_without_privileges( user_factory, context_factory): with pytest.raises(errors.AuthError): api.tag_category_api.get_tag_category( - context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'category_name': '-'}) diff --git a/server/szurubooru/tests/api/test_tag_category_updating.py b/server/szurubooru/tests/api/test_tag_category_updating.py index 9dd0f6bb..d406dd1f 100644 --- a/server/szurubooru/tests/api/test_tag_category_updating.py +++ b/server/szurubooru/tests/api/test_tag_category_updating.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tag_categories, tags, snapshots @@ -12,15 +12,15 @@ def _update_category_name(category, name): def inject_config(config_injector): config_injector({ 'privileges': { - 'tag_categories:edit:name': db.User.RANK_REGULAR, - 'tag_categories:edit:color': db.User.RANK_REGULAR, - 'tag_categories:set_default': db.User.RANK_REGULAR, + 'tag_categories:edit:name': model.User.RANK_REGULAR, + 'tag_categories:edit:color': model.User.RANK_REGULAR, + 'tag_categories:set_default': model.User.RANK_REGULAR, }, }) def test_simple_updating(user_factory, tag_category_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) category = tag_category_factory(name='name', color='black') db.session.add(category) db.session.flush() @@ -61,7 +61,7 @@ def test_omitting_optional_field( api.tag_category_api.update_tag_category( context_factory( params={**params, **{'version': 1}}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'category_name': 'name'}) @@ -70,7 +70,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory): api.tag_category_api.update_tag_category( context_factory( params={'name': ['dummy']}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'category_name': 'bad'}) @@ -86,7 +86,7 @@ def test_trying_to_update_without_privileges( api.tag_category_api.update_tag_category( context_factory( params={**params, **{'version': 1}}, - user=user_factory(rank=db.User.RANK_ANONYMOUS)), + user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'category_name': 'dummy'}) @@ -106,7 +106,7 @@ def test_set_as_default(user_factory, tag_category_factory, context_factory): 'color': 'white', 'version': 1, }, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'category_name': 'name'}) assert result == 'serialized category' tag_categories.set_default_category.assert_called_once_with(category) diff --git a/server/szurubooru/tests/api/test_tag_creating.py b/server/szurubooru/tests/api/test_tag_creating.py index dc056280..771b9f61 100644 --- a/server/szurubooru/tests/api/test_tag_creating.py +++ b/server/szurubooru/tests/api/test_tag_creating.py @@ -1,16 +1,16 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tags, snapshots @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'tags:create': db.User.RANK_REGULAR}}) + config_injector({'privileges': {'tags:create': model.User.RANK_REGULAR}}) def test_creating_simple_tags(tag_factory, user_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) tag = tag_factory() with patch('szurubooru.func.tags.create_tag'), \ patch('szurubooru.func.tags.get_or_create_tags_by_names'), \ @@ -50,7 +50,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): api.tag_api.create_tag( context_factory( params=params, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) @pytest.mark.parametrize('field', ['implications', 'suggestions']) @@ -70,7 +70,7 @@ def test_omitting_optional_field( api.tag_api.create_tag( context_factory( params=params, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_create_tag_without_privileges( @@ -84,4 +84,4 @@ def test_trying_to_create_tag_without_privileges( 'suggestions': ['tag'], 'implications': [], }, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_tag_deleting.py b/server/szurubooru/tests/api/test_tag_deleting.py index a657b02e..fbd35e12 100644 --- a/server/szurubooru/tests/api/test_tag_deleting.py +++ b/server/szurubooru/tests/api/test_tag_deleting.py @@ -1,16 +1,16 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tags, snapshots @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'tags:delete': db.User.RANK_REGULAR}}) + config_injector({'privileges': {'tags:delete': model.User.RANK_REGULAR}}) def test_deleting(user_factory, tag_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) tag = tag_factory(names=['tag']) db.session.add(tag) db.session.commit() @@ -20,7 +20,7 @@ def test_deleting(user_factory, tag_factory, context_factory): context_factory(params={'version': 1}, user=auth_user), {'tag_name': 'tag'}) assert result == {} - assert db.session.query(db.Tag).count() == 0 + assert db.session.query(model.Tag).count() == 0 snapshots.delete.assert_called_once_with(tag, auth_user) tags.export_to_json.assert_called_once_with() @@ -36,17 +36,17 @@ def test_deleting_used( api.tag_api.delete_tag( context_factory( params={'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': 'tag'}) db.session.refresh(post) - assert db.session.query(db.Tag).count() == 0 + assert db.session.query(model.Tag).count() == 0 assert post.tags == [] def test_trying_to_delete_non_existing(user_factory, context_factory): with pytest.raises(tags.TagNotFoundError): api.tag_api.delete_tag( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': 'bad'}) @@ -58,6 +58,6 @@ def test_trying_to_delete_without_privileges( api.tag_api.delete_tag( context_factory( params={'version': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS)), + user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'tag_name': 'tag'}) - assert db.session.query(db.Tag).count() == 1 + assert db.session.query(model.Tag).count() == 1 diff --git a/server/szurubooru/tests/api/test_tag_merging.py b/server/szurubooru/tests/api/test_tag_merging.py index a448c9c4..484fbfa6 100644 --- a/server/szurubooru/tests/api/test_tag_merging.py +++ b/server/szurubooru/tests/api/test_tag_merging.py @@ -1,16 +1,16 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tags, snapshots @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'tags:merge': db.User.RANK_REGULAR}}) + config_injector({'privileges': {'tags:merge': model.User.RANK_REGULAR}}) def test_merging(user_factory, tag_factory, context_factory, post_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) source_tag = tag_factory(names=['source']) target_tag = tag_factory(names=['target']) db.session.add_all([source_tag, target_tag]) @@ -62,7 +62,7 @@ def test_trying_to_omit_mandatory_field( api.tag_api.merge_tags( context_factory( params=params, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_merge_non_existing( @@ -73,12 +73,12 @@ def test_trying_to_merge_non_existing( api.tag_api.merge_tags( context_factory( params={'remove': 'good', 'mergeTo': 'bad'}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) with pytest.raises(tags.TagNotFoundError): api.tag_api.merge_tags( context_factory( params={'remove': 'bad', 'mergeTo': 'good'}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) def test_trying_to_merge_without_privileges( @@ -97,4 +97,4 @@ def test_trying_to_merge_without_privileges( 'remove': 'source', 'mergeTo': 'target', }, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_tag_retrieving.py b/server/szurubooru/tests/api/test_tag_retrieving.py index 86837f97..fd2b2cb5 100644 --- a/server/szurubooru/tests/api/test_tag_retrieving.py +++ b/server/szurubooru/tests/api/test_tag_retrieving.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tags @@ -8,8 +8,8 @@ from szurubooru.func import tags def inject_config(config_injector): config_injector({ 'privileges': { - 'tags:list': db.User.RANK_REGULAR, - 'tags:view': db.User.RANK_REGULAR, + 'tags:list': model.User.RANK_REGULAR, + 'tags:view': model.User.RANK_REGULAR, }, }) @@ -24,7 +24,7 @@ def test_retrieving_multiple(user_factory, tag_factory, context_factory): result = api.tag_api.get_tags( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) assert result == { 'query': '', 'page': 1, @@ -40,7 +40,7 @@ def test_trying_to_retrieve_multiple_without_privileges( api.tag_api.get_tags( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) def test_retrieving_single(user_factory, tag_factory, context_factory): @@ -50,7 +50,7 @@ def test_retrieving_single(user_factory, tag_factory, context_factory): tags.serialize_tag.return_value = 'serialized tag' result = api.tag_api.get_tag( context_factory( - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': 'tag'}) assert result == 'serialized tag' @@ -59,7 +59,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): with pytest.raises(tags.TagNotFoundError): api.tag_api.get_tag( context_factory( - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': '-'}) @@ -68,5 +68,5 @@ def test_trying_to_retrieve_single_without_privileges( with pytest.raises(errors.AuthError): api.tag_api.get_tag( context_factory( - user=user_factory(rank=db.User.RANK_ANONYMOUS)), + user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'tag_name': '-'}) diff --git a/server/szurubooru/tests/api/test_tag_siblings_retrieving.py b/server/szurubooru/tests/api/test_tag_siblings_retrieving.py index 6de25fcc..fc2f5aaa 100644 --- a/server/szurubooru/tests/api/test_tag_siblings_retrieving.py +++ b/server/szurubooru/tests/api/test_tag_siblings_retrieving.py @@ -1,12 +1,12 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tags @pytest.fixture(autouse=True) def inject_config(config_injector): - config_injector({'privileges': {'tags:view': db.User.RANK_REGULAR}}) + config_injector({'privileges': {'tags:view': model.User.RANK_REGULAR}}) def test_get_tag_siblings(user_factory, tag_factory, context_factory): @@ -21,7 +21,7 @@ def test_get_tag_siblings(user_factory, tag_factory, context_factory): (tag_factory(names=['sib2']), 3), ] result = api.tag_api.get_tag_siblings( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': 'tag'}) assert result == { 'results': [ @@ -40,12 +40,12 @@ def test_get_tag_siblings(user_factory, tag_factory, context_factory): def test_trying_to_retrieve_non_existing(user_factory, context_factory): with pytest.raises(tags.TagNotFoundError): api.tag_api.get_tag_siblings( - context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': '-'}) def test_trying_to_retrieve_without_privileges(user_factory, context_factory): with pytest.raises(errors.AuthError): api.tag_api.get_tag_siblings( - context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'tag_name': '-'}) diff --git a/server/szurubooru/tests/api/test_tag_updating.py b/server/szurubooru/tests/api/test_tag_updating.py index 3fe69bd8..fb63e353 100644 --- a/server/szurubooru/tests/api/test_tag_updating.py +++ b/server/szurubooru/tests/api/test_tag_updating.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import tags, snapshots @@ -8,18 +8,18 @@ from szurubooru.func import tags, snapshots def inject_config(config_injector): config_injector({ 'privileges': { - 'tags:create': db.User.RANK_REGULAR, - 'tags:edit:names': db.User.RANK_REGULAR, - 'tags:edit:category': db.User.RANK_REGULAR, - 'tags:edit:description': db.User.RANK_REGULAR, - 'tags:edit:suggestions': db.User.RANK_REGULAR, - 'tags:edit:implications': db.User.RANK_REGULAR, + 'tags:create': model.User.RANK_REGULAR, + 'tags:edit:names': model.User.RANK_REGULAR, + 'tags:edit:category': model.User.RANK_REGULAR, + 'tags:edit:description': model.User.RANK_REGULAR, + 'tags:edit:suggestions': model.User.RANK_REGULAR, + 'tags:edit:implications': model.User.RANK_REGULAR, }, }) def test_simple_updating(user_factory, tag_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) tag = tag_factory(names=['tag1', 'tag2']) db.session.add(tag) db.session.commit() @@ -56,8 +56,7 @@ def test_simple_updating(user_factory, tag_factory, context_factory): tag, ['sug1', 'sug2']) tags.update_tag_implications.assert_called_once_with( tag, ['imp1', 'imp2']) - tags.serialize_tag.assert_called_once_with( - tag, options=None) + tags.serialize_tag.assert_called_once_with(tag, options=[]) snapshots.modify.assert_called_once_with(tag, auth_user) tags.export_to_json.assert_called_once_with() @@ -90,7 +89,7 @@ def test_omitting_optional_field( api.tag_api.update_tag( context_factory( params={**params, **{'version': 1}}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': 'tag'}) @@ -99,7 +98,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory): api.tag_api.update_tag( context_factory( params={'names': ['dummy']}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': 'tag1'}) @@ -117,7 +116,7 @@ def test_trying_to_update_without_privileges( api.tag_api.update_tag( context_factory( params={**params, **{'version': 1}}, - user=user_factory(rank=db.User.RANK_ANONYMOUS)), + user=user_factory(rank=model.User.RANK_ANONYMOUS)), {'tag_name': 'tag'}) @@ -127,9 +126,9 @@ def test_trying_to_create_tags_without_privileges( db.session.add(tag) db.session.commit() config_injector({'privileges': { - 'tags:create': db.User.RANK_ADMINISTRATOR, - 'tags:edit:suggestions': db.User.RANK_REGULAR, - 'tags:edit:implications': db.User.RANK_REGULAR, + 'tags:create': model.User.RANK_ADMINISTRATOR, + 'tags:edit:suggestions': model.User.RANK_REGULAR, + 'tags:edit:implications': model.User.RANK_REGULAR, }}) with patch('szurubooru.func.tags.get_or_create_tags_by_names'): tags.get_or_create_tags_by_names.return_value = ([], ['new-tag']) @@ -137,12 +136,12 @@ def test_trying_to_create_tags_without_privileges( api.tag_api.update_tag( context_factory( params={'suggestions': ['tag1', 'tag2'], 'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': 'tag'}) db.session.rollback() with pytest.raises(errors.AuthError): api.tag_api.update_tag( context_factory( params={'implications': ['tag1', 'tag2'], 'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'tag_name': 'tag'}) diff --git a/server/szurubooru/tests/api/test_user_creating.py b/server/szurubooru/tests/api/test_user_creating.py index 8b583b6e..df2e80bb 100644 --- a/server/szurubooru/tests/api/test_user_creating.py +++ b/server/szurubooru/tests/api/test_user_creating.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import users @@ -31,7 +31,7 @@ def test_creating_user(user_factory, context_factory, fake_datetime): 'avatarStyle': 'manual', }, files={'avatar': b'...'}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) assert result == 'serialized user' users.create_user.assert_called_once_with( 'chewie1', 'oks', 'asd@asd.asd') @@ -50,7 +50,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): 'password': 'oks', } user = user_factory() - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) del params[field] with patch('szurubooru.func.users.create_user'), \ pytest.raises(errors.MissingRequiredParameterError): @@ -70,7 +70,7 @@ def test_omitting_optional_field(user_factory, context_factory, field): } del params[field] user = user_factory() - auth_user = user_factory(rank=db.User.RANK_MODERATOR) + auth_user = user_factory(rank=model.User.RANK_MODERATOR) with patch('szurubooru.func.users.create_user'), \ patch('szurubooru.func.users.update_user_avatar'), \ patch('szurubooru.func.users.serialize_user'): @@ -84,4 +84,4 @@ def test_trying_to_create_user_without_privileges( with pytest.raises(errors.AuthError): api.user_api.create_user(context_factory( params='whatever', - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_user_deleting.py b/server/szurubooru/tests/api/test_user_deleting.py index 9dd87764..2bd53e2b 100644 --- a/server/szurubooru/tests/api/test_user_deleting.py +++ b/server/szurubooru/tests/api/test_user_deleting.py @@ -1,5 +1,5 @@ import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import users @@ -7,45 +7,45 @@ from szurubooru.func import users def inject_config(config_injector): config_injector({ 'privileges': { - 'users:delete:self': db.User.RANK_REGULAR, - 'users:delete:any': db.User.RANK_MODERATOR, + 'users:delete:self': model.User.RANK_REGULAR, + 'users:delete:any': model.User.RANK_MODERATOR, }, }) def test_deleting_oneself(user_factory, context_factory): - user = user_factory(name='u', rank=db.User.RANK_REGULAR) + user = user_factory(name='u', rank=model.User.RANK_REGULAR) db.session.add(user) db.session.commit() result = api.user_api.delete_user( context_factory( params={'version': 1}, user=user), {'user_name': 'u'}) assert result == {} - assert db.session.query(db.User).count() == 0 + assert db.session.query(model.User).count() == 0 def test_deleting_someone_else(user_factory, context_factory): - user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR) - user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR) + user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR) + user2 = user_factory(name='u2', rank=model.User.RANK_MODERATOR) db.session.add_all([user1, user2]) db.session.commit() api.user_api.delete_user( context_factory( params={'version': 1}, user=user2), {'user_name': 'u1'}) - assert db.session.query(db.User).count() == 1 + assert db.session.query(model.User).count() == 1 def test_trying_to_delete_someone_else_without_privileges( user_factory, context_factory): - user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR) - user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR) + user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR) + user2 = user_factory(name='u2', rank=model.User.RANK_REGULAR) db.session.add_all([user1, user2]) db.session.commit() with pytest.raises(errors.AuthError): api.user_api.delete_user( context_factory( params={'version': 1}, user=user2), {'user_name': 'u1'}) - assert db.session.query(db.User).count() == 2 + assert db.session.query(model.User).count() == 2 def test_trying_to_delete_non_existing(user_factory, context_factory): @@ -53,5 +53,5 @@ def test_trying_to_delete_non_existing(user_factory, context_factory): api.user_api.delete_user( context_factory( params={'version': 1}, - user=user_factory(rank=db.User.RANK_REGULAR)), + user=user_factory(rank=model.User.RANK_REGULAR)), {'user_name': 'bad'}) diff --git a/server/szurubooru/tests/api/test_user_retrieving.py b/server/szurubooru/tests/api/test_user_retrieving.py index 6400e0d4..9be26200 100644 --- a/server/szurubooru/tests/api/test_user_retrieving.py +++ b/server/szurubooru/tests/api/test_user_retrieving.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import users @@ -8,16 +8,16 @@ from szurubooru.func import users def inject_config(config_injector): config_injector({ 'privileges': { - 'users:list': db.User.RANK_REGULAR, - 'users:view': db.User.RANK_REGULAR, - 'users:edit:any:email': db.User.RANK_MODERATOR, + 'users:list': model.User.RANK_REGULAR, + 'users:view': model.User.RANK_REGULAR, + 'users:edit:any:email': model.User.RANK_MODERATOR, }, }) def test_retrieving_multiple(user_factory, context_factory): - user1 = user_factory(name='u1', rank=db.User.RANK_MODERATOR) - user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR) + user1 = user_factory(name='u1', rank=model.User.RANK_MODERATOR) + user2 = user_factory(name='u2', rank=model.User.RANK_MODERATOR) db.session.add_all([user1, user2]) db.session.flush() with patch('szurubooru.func.users.serialize_user'): @@ -25,7 +25,7 @@ def test_retrieving_multiple(user_factory, context_factory): result = api.user_api.get_users( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_REGULAR))) + user=user_factory(rank=model.User.RANK_REGULAR))) assert result == { 'query': '', 'page': 1, @@ -41,12 +41,12 @@ def test_trying_to_retrieve_multiple_without_privileges( api.user_api.get_users( context_factory( params={'query': '', 'page': 1}, - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=model.User.RANK_ANONYMOUS))) def test_retrieving_single(user_factory, context_factory): - user = user_factory(name='u1', rank=db.User.RANK_REGULAR) - auth_user = user_factory(rank=db.User.RANK_REGULAR) + user = user_factory(name='u1', rank=model.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) db.session.add(user) db.session.flush() with patch('szurubooru.func.users.serialize_user'): @@ -57,7 +57,7 @@ def test_retrieving_single(user_factory, context_factory): def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=model.User.RANK_REGULAR) with pytest.raises(users.UserNotFoundError): api.user_api.get_user( context_factory(user=auth_user), {'user_name': '-'}) @@ -65,8 +65,8 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): def test_trying_to_retrieve_single_without_privileges( user_factory, context_factory): - auth_user = user_factory(rank=db.User.RANK_ANONYMOUS) - db.session.add(user_factory(name='u1', rank=db.User.RANK_REGULAR)) + auth_user = user_factory(rank=model.User.RANK_ANONYMOUS) + db.session.add(user_factory(name='u1', rank=model.User.RANK_REGULAR)) db.session.flush() with pytest.raises(errors.AuthError): api.user_api.get_user( diff --git a/server/szurubooru/tests/api/test_user_updating.py b/server/szurubooru/tests/api/test_user_updating.py index 921b2697..af750493 100644 --- a/server/szurubooru/tests/api/test_user_updating.py +++ b/server/szurubooru/tests/api/test_user_updating.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import api, db, errors +from szurubooru import api, db, model, errors from szurubooru.func import users @@ -8,23 +8,23 @@ from szurubooru.func import users def inject_config(config_injector): config_injector({ 'privileges': { - 'users:edit:self:name': db.User.RANK_REGULAR, - 'users:edit:self:pass': db.User.RANK_REGULAR, - 'users:edit:self:email': db.User.RANK_REGULAR, - 'users:edit:self:rank': db.User.RANK_MODERATOR, - 'users:edit:self:avatar': db.User.RANK_MODERATOR, - 'users:edit:any:name': db.User.RANK_MODERATOR, - 'users:edit:any:pass': db.User.RANK_MODERATOR, - 'users:edit:any:email': db.User.RANK_MODERATOR, - 'users:edit:any:rank': db.User.RANK_ADMINISTRATOR, - 'users:edit:any:avatar': db.User.RANK_ADMINISTRATOR, + 'users:edit:self:name': model.User.RANK_REGULAR, + 'users:edit:self:pass': model.User.RANK_REGULAR, + 'users:edit:self:email': model.User.RANK_REGULAR, + 'users:edit:self:rank': model.User.RANK_MODERATOR, + 'users:edit:self:avatar': model.User.RANK_MODERATOR, + 'users:edit:any:name': model.User.RANK_MODERATOR, + 'users:edit:any:pass': model.User.RANK_MODERATOR, + 'users:edit:any:email': model.User.RANK_MODERATOR, + 'users:edit:any:rank': model.User.RANK_ADMINISTRATOR, + 'users:edit:any:avatar': model.User.RANK_ADMINISTRATOR, }, }) def test_updating_user(context_factory, user_factory): - user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) - auth_user = user_factory(rank=db.User.RANK_ADMINISTRATOR) + user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR) + auth_user = user_factory(rank=model.User.RANK_ADMINISTRATOR) db.session.add(user) db.session.flush() @@ -63,13 +63,13 @@ def test_updating_user(context_factory, user_factory): users.update_user_avatar.assert_called_once_with( user, 'manual', b'...') users.serialize_user.assert_called_once_with( - user, auth_user, options=None) + user, auth_user, options=[]) @pytest.mark.parametrize( 'field', ['name', 'email', 'password', 'rank', 'avatarStyle']) def test_omitting_optional_field(user_factory, context_factory, field): - user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) + user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR) db.session.add(user) db.session.flush() params = { @@ -96,7 +96,7 @@ def test_omitting_optional_field(user_factory, context_factory, field): def test_trying_to_update_non_existing(user_factory, context_factory): - user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) + user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR) db.session.add(user) db.session.flush() with pytest.raises(users.UserNotFoundError): @@ -113,8 +113,8 @@ def test_trying_to_update_non_existing(user_factory, context_factory): ]) def test_trying_to_update_field_without_privileges( user_factory, context_factory, params): - user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR) - user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR) + user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR) + user2 = user_factory(name='u2', rank=model.User.RANK_REGULAR) db.session.add_all([user1, user2]) db.session.flush() with pytest.raises(errors.AuthError): diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index db34ee02..e71f9609 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -7,8 +7,8 @@ from unittest.mock import patch from datetime import datetime import pytest import freezegun -import sqlalchemy -from szurubooru import config, db, rest +import sqlalchemy as sa +from szurubooru import config, db, model, rest class QueryCounter: @@ -36,10 +36,10 @@ if not config.config['test_database']: raise RuntimeError('Test database not configured.') _query_counter = QueryCounter() -_engine = sqlalchemy.create_engine(config.config['test_database']) -db.Base.metadata.drop_all(bind=_engine) -db.Base.metadata.create_all(bind=_engine) -sqlalchemy.event.listen( +_engine = sa.create_engine(config.config['test_database']) +model.Base.metadata.drop_all(bind=_engine) +model.Base.metadata.create_all(bind=_engine) +sa.event.listen( _engine, 'before_cursor_execute', _query_counter.create_before_cursor_execute()) @@ -79,14 +79,14 @@ def query_logger(): @pytest.yield_fixture(scope='function', autouse=True) def session(query_logger): # pylint: disable=unused-argument - db.sessionmaker = sqlalchemy.orm.sessionmaker( + db.sessionmaker = sa.orm.sessionmaker( bind=_engine, autoflush=False) - db.session = sqlalchemy.orm.scoped_session(db.sessionmaker) + db.session = sa.orm.scoped_session(db.sessionmaker) try: yield db.session finally: db.session.remove() - for table in reversed(db.Base.metadata.sorted_tables): + for table in reversed(model.Base.metadata.sorted_tables): db.session.execute(table.delete()) db.session.commit() @@ -101,7 +101,7 @@ def context_factory(session): params=params or {}, files=files or {}) ctx.session = session - ctx.user = user or db.User() + ctx.user = user or model.User() return ctx return factory @@ -115,15 +115,15 @@ def config_injector(): @pytest.fixture def user_factory(): - def factory(name=None, rank=db.User.RANK_REGULAR, email='dummy'): - user = db.User() + def factory(name=None, rank=model.User.RANK_REGULAR, email='dummy'): + user = model.User() user.name = name or get_unique_name() user.password_salt = 'dummy' user.password_hash = 'dummy' user.email = email user.rank = rank user.creation_time = datetime(1997, 1, 1) - user.avatar_style = db.User.AVATAR_GRAVATAR + user.avatar_style = model.User.AVATAR_GRAVATAR return user return factory @@ -131,7 +131,7 @@ def user_factory(): @pytest.fixture def tag_category_factory(): def factory(name=None, color='dummy', default=False): - category = db.TagCategory() + category = model.TagCategory() category.name = name or get_unique_name() category.color = color category.default = default @@ -143,12 +143,12 @@ def tag_category_factory(): def tag_factory(): def factory(names=None, category=None): if not category: - category = db.TagCategory(get_unique_name()) + category = model.TagCategory(get_unique_name()) db.session.add(category) - tag = db.Tag() + tag = model.Tag() tag.names = [] for i, name in enumerate(names or [get_unique_name()]): - tag.names.append(db.TagName(name, i)) + tag.names.append(model.TagName(name, i)) tag.category = category tag.creation_time = datetime(1996, 1, 1) return tag @@ -167,10 +167,10 @@ def post_factory(skip_post_hashing): # pylint: disable=invalid-name def factory( id=None, - safety=db.Post.SAFETY_SAFE, - type=db.Post.TYPE_IMAGE, + safety=model.Post.SAFETY_SAFE, + type=model.Post.TYPE_IMAGE, checksum='...'): - post = db.Post() + post = model.Post() post.post_id = id post.safety = safety post.type = type @@ -191,7 +191,7 @@ def comment_factory(user_factory, post_factory): if not post: post = post_factory() db.session.add(post) - comment = db.Comment() + comment = model.Comment() comment.user = user comment.post = post comment.text = text @@ -207,7 +207,7 @@ def post_score_factory(user_factory, post_factory): user = user_factory() if post is None: post = post_factory() - return db.PostScore( + return model.PostScore( post=post, user=user, score=score, time=datetime(1999, 1, 1)) return factory @@ -219,7 +219,7 @@ def post_favorite_factory(user_factory, post_factory): user = user_factory() if post is None: post = post_factory() - return db.PostFavorite( + return model.PostFavorite( post=post, user=user, time=datetime(1999, 1, 1)) return factory diff --git a/server/szurubooru/tests/func/test_comments.py b/server/szurubooru/tests/func/test_comments.py index c3c2fde1..f1e5d0f1 100644 --- a/server/szurubooru/tests/func/test_comments.py +++ b/server/szurubooru/tests/func/test_comments.py @@ -38,8 +38,6 @@ def test_try_get_comment(comment_factory): db.session.flush() assert comments.try_get_comment_by_id(comment.comment_id + 1) is None assert comments.try_get_comment_by_id(comment.comment_id) is comment - with pytest.raises(comments.InvalidCommentIdError): - comments.try_get_comment_by_id('-') def test_get_comment(comment_factory): @@ -49,8 +47,6 @@ def test_get_comment(comment_factory): with pytest.raises(comments.CommentNotFoundError): comments.get_comment_by_id(comment.comment_id + 1) assert comments.get_comment_by_id(comment.comment_id) is comment - with pytest.raises(comments.InvalidCommentIdError): - comments.get_comment_by_id('-') def test_create_comment(user_factory, post_factory, fake_datetime): diff --git a/server/szurubooru/tests/func/test_image_hash.py b/server/szurubooru/tests/func/test_image_hash.py index becba906..1b6efd21 100644 --- a/server/szurubooru/tests/func/test_image_hash.py +++ b/server/szurubooru/tests/func/test_image_hash.py @@ -2,7 +2,13 @@ from szurubooru.func import image_hash def test_hashing(read_asset, config_injector): - config_injector({'elasticsearch': {'index': 'szurubooru_test'}}) + config_injector({ + 'elasticsearch': { + 'host': 'localhost', + 'port': 9200, + 'index': 'szurubooru_test', + }, + }) image_hash.purge() image_hash.add_image('test', read_asset('jpeg.jpg')) diff --git a/server/szurubooru/tests/func/test_posts.py b/server/szurubooru/tests/func/test_posts.py index 682a1ccc..76064699 100644 --- a/server/szurubooru/tests/func/test_posts.py +++ b/server/szurubooru/tests/func/test_posts.py @@ -2,7 +2,7 @@ import os from unittest.mock import patch from datetime import datetime import pytest -from szurubooru import db +from szurubooru import db, model from szurubooru.func import ( posts, users, comments, tags, images, files, util, image_hash) @@ -14,7 +14,7 @@ from szurubooru.func import ( ]) def test_get_post_url(input_mime_type, expected_url, config_injector): config_injector({'data_url': 'http://example.com/'}) - post = db.Post() + post = model.Post() post.post_id = 1 post.mime_type = input_mime_type assert posts.get_post_content_url(post) == expected_url @@ -23,7 +23,7 @@ def test_get_post_url(input_mime_type, expected_url, config_injector): @pytest.mark.parametrize('input_mime_type', ['image/jpeg', 'image/gif']) def test_get_post_thumbnail_url(input_mime_type, config_injector): config_injector({'data_url': 'http://example.com/'}) - post = db.Post() + post = model.Post() post.post_id = 1 post.mime_type = input_mime_type assert posts.get_post_thumbnail_url(post) \ @@ -36,7 +36,7 @@ def test_get_post_thumbnail_url(input_mime_type, config_injector): ('totally/unknown', 'posts/1.dat'), ]) def test_get_post_content_path(input_mime_type, expected_path): - post = db.Post() + post = model.Post() post.post_id = 1 post.mime_type = input_mime_type assert posts.get_post_content_path(post) == expected_path @@ -44,7 +44,7 @@ def test_get_post_content_path(input_mime_type, expected_path): @pytest.mark.parametrize('input_mime_type', ['image/jpeg', 'image/gif']) def test_get_post_thumbnail_path(input_mime_type): - post = db.Post() + post = model.Post() post.post_id = 1 post.mime_type = input_mime_type assert posts.get_post_thumbnail_path(post) == 'generated-thumbnails/1.jpg' @@ -52,7 +52,7 @@ def test_get_post_thumbnail_path(input_mime_type): @pytest.mark.parametrize('input_mime_type', ['image/jpeg', 'image/gif']) def test_get_post_thumbnail_backup_path(input_mime_type): - post = db.Post() + post = model.Post() post.post_id = 1 post.mime_type = input_mime_type assert posts.get_post_thumbnail_backup_path(post) \ @@ -60,7 +60,7 @@ def test_get_post_thumbnail_backup_path(input_mime_type): def test_serialize_note(): - note = db.PostNote() + note = model.PostNote() note.polygon = [[0, 1], [1, 1], [1, 0], [0, 0]] note.text = '...' assert posts.serialize_note(note) == { @@ -86,7 +86,7 @@ def test_serialize_post( = lambda comment, auth_user: comment.user.name auth_user = user_factory(name='auth user') - post = db.Post() + post = model.Post() post.post_id = 1 post.creation_time = datetime(1997, 1, 1) post.last_edit_time = datetime(1998, 1, 1) @@ -94,9 +94,9 @@ def test_serialize_post( tag_factory(names=['tag1', 'tag2']), tag_factory(names=['tag3']) ] - post.safety = db.Post.SAFETY_SAFE + post.safety = model.Post.SAFETY_SAFE post.source = '4gag' - post.type = db.Post.TYPE_IMAGE + post.type = model.Post.TYPE_IMAGE post.checksum = 'deadbeef' post.mime_type = 'image/jpeg' post.file_size = 100 @@ -116,25 +116,25 @@ def test_serialize_post( user=user_factory(name='commenter2'), post=post, time=datetime(1999, 1, 2)), - db.PostFavorite( + model.PostFavorite( post=post, user=user_factory(name='fav1'), time=datetime(1800, 1, 1)), - db.PostFeature( + model.PostFeature( post=post, user=user_factory(), time=datetime(1999, 1, 1)), - db.PostScore( + model.PostScore( post=post, user=auth_user, score=-1, time=datetime(1800, 1, 1)), - db.PostScore( + model.PostScore( post=post, user=user_factory(), score=1, time=datetime(1800, 1, 1)), - db.PostScore( + model.PostScore( post=post, user=user_factory(), score=1, @@ -209,8 +209,6 @@ def test_try_get_post_by_id(post_factory): db.session.flush() assert posts.try_get_post_by_id(post.post_id) == post assert posts.try_get_post_by_id(post.post_id + 1) is None - with pytest.raises(posts.InvalidPostIdError): - posts.get_post_by_id('-') def test_get_post_by_id(post_factory): @@ -220,8 +218,6 @@ def test_get_post_by_id(post_factory): assert posts.get_post_by_id(post.post_id) == post with pytest.raises(posts.PostNotFoundError): posts.get_post_by_id(post.post_id + 1) - with pytest.raises(posts.InvalidPostIdError): - posts.get_post_by_id('-') def test_create_post(user_factory, fake_datetime): @@ -237,30 +233,30 @@ def test_create_post(user_factory, fake_datetime): @pytest.mark.parametrize('input_safety,expected_safety', [ - ('safe', db.Post.SAFETY_SAFE), - ('sketchy', db.Post.SAFETY_SKETCHY), - ('unsafe', db.Post.SAFETY_UNSAFE), + ('safe', model.Post.SAFETY_SAFE), + ('sketchy', model.Post.SAFETY_SKETCHY), + ('unsafe', model.Post.SAFETY_UNSAFE), ]) def test_update_post_safety(input_safety, expected_safety): - post = db.Post() + post = model.Post() posts.update_post_safety(post, input_safety) assert post.safety == expected_safety def test_update_post_safety_with_invalid_string(): - post = db.Post() + post = model.Post() with pytest.raises(posts.InvalidPostSafetyError): posts.update_post_safety(post, 'bad') def test_update_post_source(): - post = db.Post() + post = model.Post() posts.update_post_source(post, 'x') assert post.source == 'x' def test_update_post_source_with_too_long_string(): - post = db.Post() + post = model.Post() with pytest.raises(posts.InvalidPostSourceError): posts.update_post_source(post, 'x' * 1000) @@ -268,24 +264,24 @@ def test_update_post_source_with_too_long_string(): @pytest.mark.parametrize( 'is_existing,input_file,expected_mime_type,expected_type,output_file_name', [ - (True, 'png.png', 'image/png', db.Post.TYPE_IMAGE, '1.png'), - (False, 'png.png', 'image/png', db.Post.TYPE_IMAGE, '1.png'), - (False, 'jpeg.jpg', 'image/jpeg', db.Post.TYPE_IMAGE, '1.jpg'), - (False, 'gif.gif', 'image/gif', db.Post.TYPE_IMAGE, '1.gif'), + (True, 'png.png', 'image/png', model.Post.TYPE_IMAGE, '1.png'), + (False, 'png.png', 'image/png', model.Post.TYPE_IMAGE, '1.png'), + (False, 'jpeg.jpg', 'image/jpeg', model.Post.TYPE_IMAGE, '1.jpg'), + (False, 'gif.gif', 'image/gif', model.Post.TYPE_IMAGE, '1.gif'), ( False, 'gif-animated.gif', 'image/gif', - db.Post.TYPE_ANIMATION, + model.Post.TYPE_ANIMATION, '1.gif', ), - (False, 'webm.webm', 'video/webm', db.Post.TYPE_VIDEO, '1.webm'), - (False, 'mp4.mp4', 'video/mp4', db.Post.TYPE_VIDEO, '1.mp4'), + (False, 'webm.webm', 'video/webm', model.Post.TYPE_VIDEO, '1.webm'), + (False, 'mp4.mp4', 'video/mp4', model.Post.TYPE_VIDEO, '1.mp4'), ( False, 'flash.swf', 'application/x-shockwave-flash', - db.Post.TYPE_FLASH, + model.Post.TYPE_FLASH, '1.swf' ), ]) @@ -318,7 +314,7 @@ def test_update_post_content_for_new_post( assert post.type == expected_type assert post.checksum == 'crc' assert os.path.exists(output_file_path) - if post.type in (db.Post.TYPE_IMAGE, db.Post.TYPE_ANIMATION): + if post.type in (model.Post.TYPE_IMAGE, model.Post.TYPE_ANIMATION): image_hash.delete_image.assert_called_once_with(post.post_id) image_hash.add_image.assert_called_once_with(post.post_id, content) else: @@ -368,7 +364,7 @@ def test_update_post_content_with_broken_content( @pytest.mark.parametrize('input_content', [None, b'not a media file']) def test_update_post_content_with_invalid_content(input_content): - post = db.Post() + post = model.Post() with pytest.raises(posts.InvalidPostContentError): posts.update_post_content(post, input_content) @@ -492,7 +488,7 @@ def test_update_post_content_leaving_custom_thumbnail( def test_update_post_tags(tag_factory): - post = db.Post() + post = model.Post() with patch('szurubooru.func.tags.get_or_create_tags_by_names'): tags.get_or_create_tags_by_names.side_effect = lambda tag_names: \ ([tag_factory(names=[name]) for name in tag_names], []) @@ -528,7 +524,7 @@ def test_update_post_relations_bidirectionality(post_factory): def test_update_post_relations_with_nonexisting_posts(): - post = db.Post() + post = model.Post() with pytest.raises(posts.InvalidPostRelationError): posts.update_post_relations(post, [100]) @@ -542,7 +538,7 @@ def test_update_post_relations_with_itself(post_factory): def test_update_post_notes(): - post = db.Post() + post = model.Post() posts.update_post_notes( post, [ @@ -576,19 +572,19 @@ def test_update_post_notes(): [{'polygon': [[0, 0], [0, 0], [0, 1]]}], ]) def test_update_post_notes_with_invalid_content(input): - post = db.Post() + post = model.Post() with pytest.raises(posts.InvalidPostNoteError): posts.update_post_notes(post, input) def test_update_post_flags(): - post = db.Post() + post = model.Post() posts.update_post_flags(post, ['loop']) assert post.flags == ['loop'] def test_update_post_flags_with_invalid_content(): - post = db.Post() + post = model.Post() with pytest.raises(posts.InvalidPostFlagError): posts.update_post_flags(post, ['invalid']) diff --git a/server/szurubooru/tests/func/test_snapshots.py b/server/szurubooru/tests/func/test_snapshots.py index d4c6754a..09491990 100644 --- a/server/szurubooru/tests/func/test_snapshots.py +++ b/server/szurubooru/tests/func/test_snapshots.py @@ -1,7 +1,7 @@ from unittest.mock import patch from datetime import datetime import pytest -from szurubooru import db +from szurubooru import db, model from szurubooru.func import snapshots, users @@ -56,20 +56,20 @@ def test_get_post_snapshot(post_factory, user_factory, tag_factory): db.session.add_all([user, tag1, tag2, post, related_post1, related_post2]) db.session.flush() - score = db.PostScore() + score = model.PostScore() score.post = post score.user = user score.time = datetime(1997, 1, 1) score.score = 1 - favorite = db.PostFavorite() + favorite = model.PostFavorite() favorite.post = post favorite.user = user favorite.time = datetime(1997, 1, 1) - feature = db.PostFeature() + feature = model.PostFeature() feature.post = post feature.user = user feature.time = datetime(1997, 1, 1) - note = db.PostNote() + note = model.PostNote() note.post = post note.polygon = [(1, 1), (200, 1), (200, 200), (1, 200)] note.text = 'some text' @@ -105,7 +105,7 @@ def test_get_post_snapshot(post_factory, user_factory, tag_factory): def test_serialize_snapshot(user_factory): auth_user = user_factory() - snapshot = db.Snapshot() + snapshot = model.Snapshot() snapshot.operation = snapshot.OPERATION_CREATED snapshot.resource_type = 'type' snapshot.resource_name = 'id' @@ -132,9 +132,9 @@ def test_create(tag_factory, user_factory): snapshots.get_tag_snapshot.return_value = 'mocked' snapshots.create(tag, user_factory()) db.session.flush() - results = db.session.query(db.Snapshot).all() + results = db.session.query(model.Snapshot).all() assert len(results) == 1 - assert results[0].operation == db.Snapshot.OPERATION_CREATED + assert results[0].operation == model.Snapshot.OPERATION_CREATED assert results[0].data == 'mocked' @@ -144,16 +144,16 @@ def test_modify_saves_non_empty_diffs(post_factory, user_factory): 'SQLite doesn\'t support transaction isolation, ' 'which is required to retrieve original entity') post = post_factory() - post.notes = [db.PostNote(polygon=[(0, 0), (0, 1), (1, 1)], text='old')] + post.notes = [model.PostNote(polygon=[(0, 0), (0, 1), (1, 1)], text='old')] user = user_factory() db.session.add_all([post, user]) db.session.commit() post.source = 'new source' - post.notes = [db.PostNote(polygon=[(0, 0), (0, 1), (1, 1)], text='new')] + post.notes = [model.PostNote(polygon=[(0, 0), (0, 1), (1, 1)], text='new')] db.session.flush() snapshots.modify(post, user) db.session.flush() - results = db.session.query(db.Snapshot).all() + results = db.session.query(model.Snapshot).all() assert len(results) == 1 assert results[0].data == { 'type': 'object change', @@ -181,7 +181,7 @@ def test_modify_doesnt_save_empty_diffs(tag_factory, user_factory): db.session.commit() snapshots.modify(tag, user) db.session.flush() - assert db.session.query(db.Snapshot).count() == 0 + assert db.session.query(model.Snapshot).count() == 0 def test_delete(tag_factory, user_factory): @@ -192,9 +192,9 @@ def test_delete(tag_factory, user_factory): snapshots.get_tag_snapshot.return_value = 'mocked' snapshots.delete(tag, user_factory()) db.session.flush() - results = db.session.query(db.Snapshot).all() + results = db.session.query(model.Snapshot).all() assert len(results) == 1 - assert results[0].operation == db.Snapshot.OPERATION_DELETED + assert results[0].operation == model.Snapshot.OPERATION_DELETED assert results[0].data == 'mocked' @@ -205,6 +205,6 @@ def test_merge(tag_factory, user_factory): db.session.flush() snapshots.merge(source_tag, target_tag, user_factory()) db.session.flush() - result = db.session.query(db.Snapshot).one() - assert result.operation == db.Snapshot.OPERATION_MERGED + result = db.session.query(model.Snapshot).one() + assert result.operation == model.Snapshot.OPERATION_MERGED assert result.data == ['tag', 'target'] diff --git a/server/szurubooru/tests/func/test_tag_categories.py b/server/szurubooru/tests/func/test_tag_categories.py index cf74c2a5..d1e55709 100644 --- a/server/szurubooru/tests/func/test_tag_categories.py +++ b/server/szurubooru/tests/func/test_tag_categories.py @@ -1,6 +1,6 @@ from unittest.mock import patch import pytest -from szurubooru import db +from szurubooru import db, model from szurubooru.func import tag_categories, cache @@ -191,7 +191,7 @@ def test_get_default_category_name(tag_category_factory): db.session.flush() cache.purge() assert tag_categories.get_default_category_name() == category1.name - db.session.query(db.TagCategory).delete() + db.session.query(model.TagCategory).delete() cache.purge() with pytest.raises(tag_categories.TagCategoryNotFoundError): tag_categories.get_default_category_name() diff --git a/server/szurubooru/tests/func/test_tags.py b/server/szurubooru/tests/func/test_tags.py index d4674998..712c8e38 100644 --- a/server/szurubooru/tests/func/test_tags.py +++ b/server/szurubooru/tests/func/test_tags.py @@ -3,7 +3,7 @@ import json from unittest.mock import patch from datetime import datetime import pytest -from szurubooru import db +from szurubooru import db, model from szurubooru.func import tags, tag_categories, cache @@ -304,10 +304,10 @@ def test_delete(tag_factory): tag.implications = [tag_factory(names=['imp'])] db.session.add(tag) db.session.flush() - assert db.session.query(db.Tag).count() == 3 + assert db.session.query(model.Tag).count() == 3 tags.delete(tag) db.session.flush() - assert db.session.query(db.Tag).count() == 2 + assert db.session.query(model.Tag).count() == 2 def test_merge_tags_deletes_source_tag(tag_factory): diff --git a/server/szurubooru/tests/func/test_users.py b/server/szurubooru/tests/func/test_users.py index 73150bb2..53d47de6 100644 --- a/server/szurubooru/tests/func/test_users.py +++ b/server/szurubooru/tests/func/test_users.py @@ -1,7 +1,7 @@ from unittest.mock import patch from datetime import datetime import pytest -from szurubooru import db, errors +from szurubooru import db, model, errors from szurubooru.func import auth, users, files, util @@ -20,28 +20,28 @@ def test_get_avatar_path(user_name): ( 'user', None, - db.User.AVATAR_GRAVATAR, + model.User.AVATAR_GRAVATAR, ('https://gravatar.com/avatar/' + 'ee11cbb19052e40b07aac0ca060c23ee?d=retro&s=100'), ), ( None, 'user@example.com', - db.User.AVATAR_GRAVATAR, + model.User.AVATAR_GRAVATAR, ('https://gravatar.com/avatar/' + 'b58996c504c5638798eb6b511e6f49af?d=retro&s=100'), ), ( 'user', 'user@example.com', - db.User.AVATAR_GRAVATAR, + model.User.AVATAR_GRAVATAR, ('https://gravatar.com/avatar/' + 'b58996c504c5638798eb6b511e6f49af?d=retro&s=100'), ), ( 'user', None, - db.User.AVATAR_MANUAL, + model.User.AVATAR_MANUAL, 'http://example.com/avatars/user.png', ), ]) @@ -51,7 +51,7 @@ def test_get_avatar_url( 'data_url': 'http://example.com/', 'thumbnails': {'avatar_width': 100}, }) - user = db.User() + user = model.User() user.name = user_name user.email = user_email user.avatar_style = avatar_style @@ -100,7 +100,7 @@ def test_get_liked_post_count( user = user_factory() post = post_factory() auth_user = user if same_user else user_factory() - score = db.PostScore( + score = model.PostScore( post=post, user=user, score=score, time=datetime.now()) db.session.add_all([post, user, score]) db.session.flush() @@ -127,8 +127,8 @@ def test_serialize_user(user_factory): user = user_factory(name='dummy user') user.creation_time = datetime(1997, 1, 1) user.last_edit_time = datetime(1998, 1, 1) - user.avatar_style = db.User.AVATAR_MANUAL - user.rank = db.User.RANK_ADMINISTRATOR + user.avatar_style = model.User.AVATAR_MANUAL + user.rank = model.User.RANK_ADMINISTRATOR db.session.add(user) db.session.flush() assert users.serialize_user(user, auth_user) == { @@ -222,7 +222,7 @@ def test_create_user_for_first_user(fake_datetime): user = users.create_user('name', 'password', 'email') assert user.creation_time == datetime(1997, 1, 1) assert user.last_login_time is None - assert user.rank == db.User.RANK_ADMINISTRATOR + assert user.rank == model.User.RANK_ADMINISTRATOR users.update_user_name.assert_called_once_with(user, 'name') users.update_user_password.assert_called_once_with(user, 'password') users.update_user_email.assert_called_once_with(user, 'email') @@ -236,7 +236,7 @@ def test_create_user_for_subsequent_users(user_factory, config_injector): patch('szurubooru.func.users.update_user_email'), \ patch('szurubooru.func.users.update_user_password'): user = users.create_user('name', 'password', 'email') - assert user.rank == db.User.RANK_REGULAR + assert user.rank == model.User.RANK_REGULAR def test_update_user_name_with_empty_string(user_factory): @@ -379,7 +379,7 @@ def test_update_user_rank_with_higher_rank_than_possible(user_factory): db.session.flush() user = user_factory() auth_user = user_factory() - auth_user.rank = db.User.RANK_ANONYMOUS + auth_user.rank = model.User.RANK_ANONYMOUS with pytest.raises(errors.AuthError): users.update_user_rank(user, 'regular', auth_user) with pytest.raises(errors.AuthError): @@ -391,11 +391,11 @@ def test_update_user_rank(user_factory): db.session.flush() user = user_factory() auth_user = user_factory() - auth_user.rank = db.User.RANK_ADMINISTRATOR + auth_user.rank = model.User.RANK_ADMINISTRATOR users.update_user_rank(user, 'regular', auth_user) users.update_user_rank(auth_user, 'regular', auth_user) - assert user.rank == db.User.RANK_REGULAR - assert auth_user.rank == db.User.RANK_REGULAR + assert user.rank == model.User.RANK_REGULAR + assert auth_user.rank == model.User.RANK_REGULAR def test_update_user_avatar_with_invalid_style(user_factory): @@ -407,7 +407,7 @@ def test_update_user_avatar_with_invalid_style(user_factory): def test_update_user_avatar_to_gravatar(user_factory): user = user_factory() users.update_user_avatar(user, 'gravatar') - assert user.avatar_style == db.User.AVATAR_GRAVATAR + assert user.avatar_style == model.User.AVATAR_GRAVATAR def test_update_user_avatar_to_empty_manual(user_factory): @@ -431,7 +431,7 @@ def test_update_user_avatar_to_new_manual(user_factory, config_injector): user = user_factory() with patch('szurubooru.func.files.save'): users.update_user_avatar(user, 'manual', EMPTY_PIXEL) - assert user.avatar_style == db.User.AVATAR_MANUAL + assert user.avatar_style == model.User.AVATAR_MANUAL assert files.save.called diff --git a/server/szurubooru/tests/db/__init__.py b/server/szurubooru/tests/model/__init__.py similarity index 100% rename from server/szurubooru/tests/db/__init__.py rename to server/szurubooru/tests/model/__init__.py diff --git a/server/szurubooru/tests/db/test_comment.py b/server/szurubooru/tests/model/test_comment.py similarity index 74% rename from server/szurubooru/tests/db/test_comment.py rename to server/szurubooru/tests/model/test_comment.py index 9a78f952..ffd51893 100644 --- a/server/szurubooru/tests/db/test_comment.py +++ b/server/szurubooru/tests/model/test_comment.py @@ -1,11 +1,11 @@ from datetime import datetime -from szurubooru import db +from szurubooru import db, model def test_saving_comment(user_factory, post_factory): user = user_factory() post = post_factory() - comment = db.Comment() + comment = model.Comment() comment.text = 'long text' * 1000 comment.user = user comment.post = post @@ -29,7 +29,7 @@ def test_cascade_deletions(comment_factory, user_factory, post_factory): db.session.add_all([user, comment]) db.session.flush() - score = db.CommentScore() + score = model.CommentScore() score.comment = comment score.user = user score.time = datetime(1997, 1, 1) @@ -39,14 +39,14 @@ def test_cascade_deletions(comment_factory, user_factory, post_factory): assert not db.session.dirty assert comment.user is not None and comment.user.user_id is not None - assert db.session.query(db.User).count() == 1 - assert db.session.query(db.Comment).count() == 1 - assert db.session.query(db.CommentScore).count() == 1 + assert db.session.query(model.User).count() == 1 + assert db.session.query(model.Comment).count() == 1 + assert db.session.query(model.CommentScore).count() == 1 db.session.delete(comment) db.session.commit() assert not db.session.dirty - assert db.session.query(db.User).count() == 1 - assert db.session.query(db.Comment).count() == 0 - assert db.session.query(db.CommentScore).count() == 0 + assert db.session.query(model.User).count() == 1 + assert db.session.query(model.Comment).count() == 0 + assert db.session.query(model.CommentScore).count() == 0 diff --git a/server/szurubooru/tests/db/test_post.py b/server/szurubooru/tests/model/test_post.py similarity index 71% rename from server/szurubooru/tests/db/test_post.py rename to server/szurubooru/tests/model/test_post.py index c0213535..f35e2751 100644 --- a/server/szurubooru/tests/db/test_post.py +++ b/server/szurubooru/tests/model/test_post.py @@ -1,5 +1,5 @@ from datetime import datetime -from szurubooru import db +from szurubooru import db, model def test_saving_post(post_factory, user_factory, tag_factory): @@ -8,7 +8,7 @@ def test_saving_post(post_factory, user_factory, tag_factory): tag2 = tag_factory() related_post1 = post_factory() related_post2 = post_factory() - post = db.Post() + post = model.Post() post.safety = 'safety' post.type = 'type' post.checksum = 'deadbeef' @@ -54,20 +54,20 @@ def test_cascade_deletions( user, tag1, tag2, post, related_post1, related_post2, comment]) db.session.flush() - score = db.PostScore() + score = model.PostScore() score.post = post score.user = user score.time = datetime(1997, 1, 1) score.score = 1 - favorite = db.PostFavorite() + favorite = model.PostFavorite() favorite.post = post favorite.user = user favorite.time = datetime(1997, 1, 1) - feature = db.PostFeature() + feature = model.PostFeature() feature.post = post feature.user = user feature.time = datetime(1997, 1, 1) - note = db.PostNote() + note = model.PostNote() note.post = post note.polygon = '' note.text = '' @@ -88,31 +88,31 @@ def test_cascade_deletions( assert not db.session.dirty assert post.user is not None and post.user.user_id is not None assert len(post.relations) == 1 - assert db.session.query(db.User).count() == 1 - assert db.session.query(db.Tag).count() == 2 - assert db.session.query(db.Post).count() == 3 - assert db.session.query(db.PostTag).count() == 2 - assert db.session.query(db.PostRelation).count() == 2 - assert db.session.query(db.PostScore).count() == 1 - assert db.session.query(db.PostNote).count() == 1 - assert db.session.query(db.PostFeature).count() == 1 - assert db.session.query(db.PostFavorite).count() == 1 - assert db.session.query(db.Comment).count() == 1 + assert db.session.query(model.User).count() == 1 + assert db.session.query(model.Tag).count() == 2 + assert db.session.query(model.Post).count() == 3 + assert db.session.query(model.PostTag).count() == 2 + assert db.session.query(model.PostRelation).count() == 2 + assert db.session.query(model.PostScore).count() == 1 + assert db.session.query(model.PostNote).count() == 1 + assert db.session.query(model.PostFeature).count() == 1 + assert db.session.query(model.PostFavorite).count() == 1 + assert db.session.query(model.Comment).count() == 1 db.session.delete(post) db.session.commit() assert not db.session.dirty - assert db.session.query(db.User).count() == 1 - assert db.session.query(db.Tag).count() == 2 - assert db.session.query(db.Post).count() == 2 - assert db.session.query(db.PostTag).count() == 0 - assert db.session.query(db.PostRelation).count() == 0 - assert db.session.query(db.PostScore).count() == 0 - assert db.session.query(db.PostNote).count() == 0 - assert db.session.query(db.PostFeature).count() == 0 - assert db.session.query(db.PostFavorite).count() == 0 - assert db.session.query(db.Comment).count() == 0 + assert db.session.query(model.User).count() == 1 + assert db.session.query(model.Tag).count() == 2 + assert db.session.query(model.Post).count() == 2 + assert db.session.query(model.PostTag).count() == 0 + assert db.session.query(model.PostRelation).count() == 0 + assert db.session.query(model.PostScore).count() == 0 + assert db.session.query(model.PostNote).count() == 0 + assert db.session.query(model.PostFeature).count() == 0 + assert db.session.query(model.PostFavorite).count() == 0 + assert db.session.query(model.Comment).count() == 0 def test_tracking_tag_count(post_factory, tag_factory): diff --git a/server/szurubooru/tests/db/test_tag.py b/server/szurubooru/tests/model/test_tag.py similarity index 80% rename from server/szurubooru/tests/db/test_tag.py rename to server/szurubooru/tests/model/test_tag.py index 02134d69..7d3d8d2f 100644 --- a/server/szurubooru/tests/db/test_tag.py +++ b/server/szurubooru/tests/model/test_tag.py @@ -1,5 +1,5 @@ from datetime import datetime -from szurubooru import db +from szurubooru import db, model def test_saving_tag(tag_factory): @@ -7,11 +7,11 @@ def test_saving_tag(tag_factory): sug2 = tag_factory(names=['sug2']) imp1 = tag_factory(names=['imp1']) imp2 = tag_factory(names=['imp2']) - tag = db.Tag() - tag.names = [db.TagName('alias1', 0), db.TagName('alias2', 1)] + tag = model.Tag() + tag.names = [model.TagName('alias1', 0), model.TagName('alias2', 1)] tag.suggestions = [] tag.implications = [] - tag.category = db.TagCategory('category') + tag.category = model.TagCategory('category') tag.creation_time = datetime(1997, 1, 1) tag.last_edit_time = datetime(1998, 1, 1) db.session.add_all([tag, sug1, sug2, imp1, imp2]) @@ -29,9 +29,9 @@ def test_saving_tag(tag_factory): db.session.commit() tag = db.session \ - .query(db.Tag) \ - .join(db.TagName) \ - .filter(db.TagName.name == 'alias1') \ + .query(model.Tag) \ + .join(model.TagName) \ + .filter(model.TagName.name == 'alias1') \ .one() assert [tag_name.name for tag_name in tag.names] == ['alias1', 'alias2'] assert tag.category.name == 'category' @@ -48,11 +48,11 @@ def test_cascade_deletions(tag_factory): sug2 = tag_factory(names=['sug2']) imp1 = tag_factory(names=['imp1']) imp2 = tag_factory(names=['imp2']) - tag = db.Tag() - tag.names = [db.TagName('alias1', 0), db.TagName('alias2', 1)] + tag = model.Tag() + tag.names = [model.TagName('alias1', 0), model.TagName('alias2', 1)] tag.suggestions = [] tag.implications = [] - tag.category = db.TagCategory('category') + tag.category = model.TagCategory('category') tag.creation_time = datetime(1997, 1, 1) tag.last_edit_time = datetime(1998, 1, 1) tag.post_count = 1 @@ -72,10 +72,10 @@ def test_cascade_deletions(tag_factory): db.session.delete(tag) db.session.commit() - assert db.session.query(db.Tag).count() == 4 - assert db.session.query(db.TagName).count() == 4 - assert db.session.query(db.TagImplication).count() == 0 - assert db.session.query(db.TagSuggestion).count() == 0 + assert db.session.query(model.Tag).count() == 4 + assert db.session.query(model.TagName).count() == 4 + assert db.session.query(model.TagImplication).count() == 0 + assert db.session.query(model.TagSuggestion).count() == 0 def test_tracking_post_count(post_factory, tag_factory): diff --git a/server/szurubooru/tests/db/test_user.py b/server/szurubooru/tests/model/test_user.py similarity index 66% rename from server/szurubooru/tests/db/test_user.py rename to server/szurubooru/tests/model/test_user.py index 59933e36..08875fa2 100644 --- a/server/szurubooru/tests/db/test_user.py +++ b/server/szurubooru/tests/model/test_user.py @@ -1,16 +1,16 @@ from datetime import datetime -from szurubooru import db +from szurubooru import db, model def test_saving_user(): - user = db.User() + user = model.User() user.name = 'name' user.password_salt = 'salt' user.password_hash = 'hash' user.email = 'email' user.rank = 'rank' user.creation_time = datetime(1997, 1, 1) - user.avatar_style = db.User.AVATAR_GRAVATAR + user.avatar_style = model.User.AVATAR_GRAVATAR db.session.add(user) db.session.flush() db.session.refresh(user) @@ -21,7 +21,7 @@ def test_saving_user(): assert user.email == 'email' assert user.rank == 'rank' assert user.creation_time == datetime(1997, 1, 1) - assert user.avatar_style == db.User.AVATAR_GRAVATAR + assert user.avatar_style == model.User.AVATAR_GRAVATAR def test_upload_count(user_factory, post_factory): @@ -61,8 +61,8 @@ def test_favorite_count(user_factory, post_factory): post1 = post_factory() post2 = post_factory() db.session.add_all([ - db.PostFavorite(post=post1, time=datetime.utcnow(), user=user1), - db.PostFavorite(post=post2, time=datetime.utcnow(), user=user2), + model.PostFavorite(post=post1, time=datetime.utcnow(), user=user1), + model.PostFavorite(post=post2, time=datetime.utcnow(), user=user2), ]) db.session.flush() db.session.refresh(user1) @@ -79,8 +79,10 @@ def test_liked_post_count(user_factory, post_factory): post1 = post_factory() post2 = post_factory() db.session.add_all([ - db.PostScore(post=post1, time=datetime.utcnow(), user=user1, score=1), - db.PostScore(post=post2, time=datetime.utcnow(), user=user2, score=1), + model.PostScore( + post=post1, time=datetime.utcnow(), user=user1, score=1), + model.PostScore( + post=post2, time=datetime.utcnow(), user=user2, score=1), ]) db.session.flush() db.session.refresh(user1) @@ -98,8 +100,10 @@ def test_disliked_post_count(user_factory, post_factory): post1 = post_factory() post2 = post_factory() db.session.add_all([ - db.PostScore(post=post1, time=datetime.utcnow(), user=user1, score=-1), - db.PostScore(post=post2, time=datetime.utcnow(), user=user2, score=1), + model.PostScore( + post=post1, time=datetime.utcnow(), user=user1, score=-1), + model.PostScore( + post=post2, time=datetime.utcnow(), user=user2, score=1), ]) db.session.flush() db.session.refresh(user1) @@ -114,34 +118,34 @@ def test_cascade_deletions(post_factory, user_factory, comment_factory): post = post_factory() post.user = user - post_score = db.PostScore() + post_score = model.PostScore() post_score.post = post post_score.user = user post_score.time = datetime(1997, 1, 1) post_score.score = 1 post.scores.append(post_score) - post_favorite = db.PostFavorite() + post_favorite = model.PostFavorite() post_favorite.post = post post_favorite.user = user post_favorite.time = datetime(1997, 1, 1) post.favorited_by.append(post_favorite) - post_feature = db.PostFeature() + post_feature = model.PostFeature() post_feature.post = post post_feature.user = user post_feature.time = datetime(1997, 1, 1) post.features.append(post_feature) comment = comment_factory(post=post, user=user) - comment_score = db.CommentScore() + comment_score = model.CommentScore() comment_score.comment = comment comment_score.user = user comment_score.time = datetime(1997, 1, 1) comment_score.score = 1 comment.scores.append(comment_score) - snapshot = db.Snapshot() + snapshot = model.Snapshot() snapshot.user = user snapshot.creation_time = datetime(1997, 1, 1) snapshot.resource_type = '-' @@ -154,27 +158,27 @@ def test_cascade_deletions(post_factory, user_factory, comment_factory): assert not db.session.dirty assert post.user is not None and post.user.user_id is not None - assert db.session.query(db.User).count() == 1 - assert db.session.query(db.Post).count() == 1 - assert db.session.query(db.PostScore).count() == 1 - assert db.session.query(db.PostFeature).count() == 1 - assert db.session.query(db.PostFavorite).count() == 1 - assert db.session.query(db.Comment).count() == 1 - assert db.session.query(db.CommentScore).count() == 1 - assert db.session.query(db.Snapshot).count() == 1 + assert db.session.query(model.User).count() == 1 + assert db.session.query(model.Post).count() == 1 + assert db.session.query(model.PostScore).count() == 1 + assert db.session.query(model.PostFeature).count() == 1 + assert db.session.query(model.PostFavorite).count() == 1 + assert db.session.query(model.Comment).count() == 1 + assert db.session.query(model.CommentScore).count() == 1 + assert db.session.query(model.Snapshot).count() == 1 db.session.delete(user) db.session.commit() assert not db.session.dirty - assert db.session.query(db.User).count() == 0 - assert db.session.query(db.Post).count() == 1 - assert db.session.query(db.Post)[0].user is None - assert db.session.query(db.PostScore).count() == 0 - assert db.session.query(db.PostFeature).count() == 0 - assert db.session.query(db.PostFavorite).count() == 0 - assert db.session.query(db.Comment).count() == 1 - assert db.session.query(db.Comment)[0].user is None - assert db.session.query(db.CommentScore).count() == 0 - assert db.session.query(db.Snapshot).count() == 1 - assert db.session.query(db.Snapshot)[0].user is None + assert db.session.query(model.User).count() == 0 + assert db.session.query(model.Post).count() == 1 + assert db.session.query(model.Post)[0].user is None + assert db.session.query(model.PostScore).count() == 0 + assert db.session.query(model.PostFeature).count() == 0 + assert db.session.query(model.PostFavorite).count() == 0 + assert db.session.query(model.Comment).count() == 1 + assert db.session.query(model.Comment)[0].user is None + assert db.session.query(model.CommentScore).count() == 0 + assert db.session.query(model.Snapshot).count() == 1 + assert db.session.query(model.Snapshot)[0].user is None diff --git a/server/szurubooru/tests/rest/test_context.py b/server/szurubooru/tests/rest/test_context.py index 7380a855..e112ebbe 100644 --- a/server/szurubooru/tests/rest/test_context.py +++ b/server/szurubooru/tests/rest/test_context.py @@ -8,13 +8,14 @@ from szurubooru.func import net def test_has_param(): ctx = rest.Context(method=None, url=None, params={'key': 'value'}) assert ctx.has_param('key') - assert not ctx.has_param('key2') + assert not ctx.has_param('non-existing') def test_get_file(): ctx = rest.Context(method=None, url=None, files={'key': b'content'}) assert ctx.get_file('key') == b'content' - assert ctx.get_file('key2') is None + with pytest.raises(errors.ValidationError): + ctx.get_file('non-existing') def test_get_file_from_url(): @@ -23,30 +24,33 @@ def test_get_file_from_url(): ctx = rest.Context( method=None, url=None, params={'keyUrl': 'example.com'}) assert ctx.get_file('key') == b'content' - assert ctx.get_file('key2') is None net.download.assert_called_once_with('example.com') + with pytest.raises(errors.ValidationError): + assert ctx.get_file('non-existing') def test_getting_list_parameter(): ctx = rest.Context( - method=None, url=None, params={'key': 'value', 'list': list('123')}) + method=None, + url=None, + params={'key': 'value', 'list': ['1', '2', '3']}) assert ctx.get_param_as_list('key') == ['value'] - assert ctx.get_param_as_list('key2') is None - assert ctx.get_param_as_list('key2', default=['def']) == ['def'] assert ctx.get_param_as_list('list') == ['1', '2', '3'] with pytest.raises(errors.ValidationError): - ctx.get_param_as_list('key2', required=True) + ctx.get_param_as_list('non-existing') + assert ctx.get_param_as_list('non-existing', default=['def']) == ['def'] def test_getting_string_parameter(): ctx = rest.Context( - method=None, url=None, params={'key': 'value', 'list': list('123')}) + method=None, + url=None, + params={'key': 'value', 'list': ['1', '2', '3']}) assert ctx.get_param_as_string('key') == 'value' - assert ctx.get_param_as_string('key2') is None - assert ctx.get_param_as_string('key2', default='def') == 'def' assert ctx.get_param_as_string('list') == '1,2,3' with pytest.raises(errors.ValidationError): - ctx.get_param_as_string('key2', required=True) + ctx.get_param_as_string('non-existing') + assert ctx.get_param_as_string('non-existing', default='x') == 'x' def test_getting_int_parameter(): @@ -55,12 +59,11 @@ def test_getting_int_parameter(): url=None, params={'key': '50', 'err': 'invalid', 'list': [1, 2, 3]}) assert ctx.get_param_as_int('key') == 50 - assert ctx.get_param_as_int('key2') is None - assert ctx.get_param_as_int('key2', default=5) == 5 with pytest.raises(errors.ValidationError): ctx.get_param_as_int('list') with pytest.raises(errors.ValidationError): - ctx.get_param_as_int('key2', required=True) + ctx.get_param_as_int('non-existing') + assert ctx.get_param_as_int('non-existing', default=5) == 5 with pytest.raises(errors.ValidationError): ctx.get_param_as_int('err') with pytest.raises(errors.ValidationError): @@ -102,7 +105,6 @@ def test_getting_bool_parameter(): test(['1', '2']) ctx = rest.Context(method=None, url=None) - assert ctx.get_param_as_bool('non-existing') is None - assert ctx.get_param_as_bool('non-existing', default=True) is True with pytest.raises(errors.ValidationError): - assert ctx.get_param_as_bool('non-existing', required=True) is None + ctx.get_param_as_bool('non-existing') + assert ctx.get_param_as_bool('non-existing', default=True) is True diff --git a/server/szurubooru/tests/search/configs/test_post_search_config.py b/server/szurubooru/tests/search/configs/test_post_search_config.py index d5796779..945a5e4f 100644 --- a/server/szurubooru/tests/search/configs/test_post_search_config.py +++ b/server/szurubooru/tests/search/configs/test_post_search_config.py @@ -1,13 +1,13 @@ # pylint: disable=redefined-outer-name from datetime import datetime import pytest -from szurubooru import db, errors, search +from szurubooru import db, model, errors, search @pytest.fixture def fav_factory(user_factory): def factory(post, user=None): - return db.PostFavorite( + return model.PostFavorite( post=post, user=user or user_factory(), time=datetime.utcnow()) @@ -17,7 +17,7 @@ def fav_factory(user_factory): @pytest.fixture def score_factory(user_factory): def factory(post, user=None, score=1): - return db.PostScore( + return model.PostScore( post=post, user=user or user_factory(), time=datetime.utcnow(), @@ -28,7 +28,7 @@ def score_factory(user_factory): @pytest.fixture def note_factory(): def factory(): - return db.PostNote(polygon='...', text='...') + return model.PostNote(polygon='...', text='...') return factory @@ -36,11 +36,11 @@ def note_factory(): def feature_factory(user_factory): def factory(post=None): if post: - return db.PostFeature( + return model.PostFeature( time=datetime.utcnow(), user=user_factory(), post=post) - return db.PostFeature( + return model.PostFeature( time=datetime.utcnow(), user=user_factory()) return factory @@ -123,7 +123,7 @@ def test_filter_by_score( post3 = post_factory(id=3) for post in [post1, post2, post3]: db.session.add( - db.PostScore( + model.PostScore( score=post.post_id, time=datetime.utcnow(), post=post, @@ -332,10 +332,10 @@ def test_filter_by_type( post2 = post_factory(id=2) post3 = post_factory(id=3) post4 = post_factory(id=4) - post1.type = db.Post.TYPE_IMAGE - post2.type = db.Post.TYPE_ANIMATION - post3.type = db.Post.TYPE_VIDEO - post4.type = db.Post.TYPE_FLASH + post1.type = model.Post.TYPE_IMAGE + post2.type = model.Post.TYPE_ANIMATION + post3.type = model.Post.TYPE_VIDEO + post4.type = model.Post.TYPE_FLASH db.session.add_all([post1, post2, post3, post4]) db.session.flush() verify_unpaged(input, expected_post_ids) @@ -352,9 +352,9 @@ def test_filter_by_safety( post1 = post_factory(id=1) post2 = post_factory(id=2) post3 = post_factory(id=3) - post1.safety = db.Post.SAFETY_SAFE - post2.safety = db.Post.SAFETY_SKETCHY - post3.safety = db.Post.SAFETY_UNSAFE + post1.safety = model.Post.SAFETY_SAFE + post2.safety = model.Post.SAFETY_SKETCHY + post3.safety = model.Post.SAFETY_UNSAFE db.session.add_all([post1, post2, post3]) db.session.flush() verify_unpaged(input, expected_post_ids) diff --git a/server/test b/server/test index 6d7bb6de..69cfe542 100755 --- a/server/test +++ b/server/test @@ -4,4 +4,5 @@ import sys pytest.main([ '--cov-report=term-missing', '--cov=szurubooru', + '--tb=short', ] + (sys.argv[1:] or ['szurubooru']))