diff --git a/server/.pylintrc b/server/.pylintrc index b637369f..1188f6df 100644 --- a/server/.pylintrc +++ b/server/.pylintrc @@ -1,6 +1,7 @@ [basic] function-rgx=^_?[a-z_][a-z0-9_]{2,}$|^test_ method-rgx=^[a-z_][a-z0-9_]{2,}$|^test_ +const-rgx=^[A-Z_]+$|^_[a-zA-Z_]*$ good-names=ex,_,logger [variables] diff --git a/server/host-waitress b/server/host-waitress index c1798582..75980818 100755 --- a/server/host-waitress +++ b/server/host-waitress @@ -10,7 +10,7 @@ import argparse import os.path import sys import waitress -from szurubooru.app import create_app +from szurubooru.facade import create_app def main(): parser = argparse.ArgumentParser('Starts szurubooru using waitress.') diff --git a/server/requirements.txt b/server/requirements.txt index b1079d78..fd68f896 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -1,6 +1,5 @@ alembic>=0.8.5 pyyaml>=3.11 -falcon>=0.3.0 psycopg2>=2.6.1 SQLAlchemy>=1.0.12 pytest>=2.9.1 diff --git a/server/szurubooru/api/__init__.py b/server/szurubooru/api/__init__.py index 59b8d6be..308b86bf 100644 --- a/server/szurubooru/api/__init__.py +++ b/server/szurubooru/api/__init__.py @@ -1,27 +1,8 @@ -''' Falcon-compatible API facades. ''' - -from szurubooru.api.password_reset_api import PasswordResetApi -from szurubooru.api.user_api import UserListApi, UserDetailApi -from szurubooru.api.tag_api import ( - TagListApi, - TagDetailApi, - TagMergeApi, - TagSiblingsApi) -from szurubooru.api.tag_category_api import ( - TagCategoryListApi, - TagCategoryDetailApi, - DefaultTagCategoryApi) -from szurubooru.api.comment_api import ( - CommentListApi, - CommentDetailApi, - CommentScoreApi) -from szurubooru.api.post_api import ( - PostListApi, - PostDetailApi, - PostFeatureApi, - PostScoreApi, - PostFavoriteApi, - PostsAroundApi) -from szurubooru.api.snapshot_api import SnapshotListApi -from szurubooru.api.info_api import InfoApi -from szurubooru.api.context import Context, Request +import szurubooru.api.info_api +import szurubooru.api.user_api +import szurubooru.api.post_api +import szurubooru.api.tag_api +import szurubooru.api.tag_category_api +import szurubooru.api.comment_api +import szurubooru.api.password_reset_api +import szurubooru.api.snapshot_api diff --git a/server/szurubooru/api/base_api.py b/server/szurubooru/api/base_api.py deleted file mode 100644 index 9fd2096c..00000000 --- a/server/szurubooru/api/base_api.py +++ /dev/null @@ -1,27 +0,0 @@ -import types - -def _bind_method(target, desired_method_name): - actual_method = getattr(target, desired_method_name) - def _wrapper_method(_self, request, _response, *args, **kwargs): - request.context.output = \ - actual_method(request.context, *args, **kwargs) - return types.MethodType(_wrapper_method, target) - -class BaseApi(object): - ''' - A wrapper around falcon's API interface that eases input and output - management. - ''' - - def __init__(self): - self._translate_routes() - - def _translate_routes(self): - for method_name in ['GET', 'PUT', 'POST', 'DELETE']: - desired_method_name = method_name.lower() - falcon_method_name = 'on_%s' % method_name.lower() - if hasattr(self, desired_method_name): - setattr( - self, - falcon_method_name, - _bind_method(self, desired_method_name)) diff --git a/server/szurubooru/api/comment_api.py b/server/szurubooru/api/comment_api.py index d5362047..7ac72bc2 100644 --- a/server/szurubooru/api/comment_api.py +++ b/server/szurubooru/api/comment_api.py @@ -1,7 +1,9 @@ import datetime from szurubooru import search -from szurubooru.api.base_api import BaseApi from szurubooru.func import auth, comments, posts, scores, util +from szurubooru.rest import routes + +_search_executor = search.Executor(search.configs.CommentSearchConfig()) def _serialize(ctx, comment, **kwargs): return comments.serialize_comment( @@ -9,67 +11,65 @@ def _serialize(ctx, comment, **kwargs): ctx.user, options=util.get_serialization_options(ctx), **kwargs) -class CommentListApi(BaseApi): - def __init__(self): - super().__init__() - self._search_executor = search.Executor( - search.configs.CommentSearchConfig()) +@routes.get('/comments/?') +def get_comments(ctx, _params=None): + auth.verify_privilege(ctx.user, 'comments:list') + return _search_executor.execute_and_serialize( + ctx, lambda comment: _serialize(ctx, comment)) - def get(self, ctx): - auth.verify_privilege(ctx.user, 'comments:list') - return self._search_executor.execute_and_serialize( - ctx, - lambda comment: _serialize(ctx, comment)) +@routes.post('/comments/?') +def create_comment(ctx, _params=None): + auth.verify_privilege(ctx.user, 'comments:create') + text = ctx.get_param_as_string('text', required=True) + post_id = ctx.get_param_as_int('postId', required=True) + post = posts.get_post_by_id(post_id) + comment = comments.create_comment(ctx.user, post, text) + ctx.session.add(comment) + ctx.session.commit() + return _serialize(ctx, comment) - def post(self, ctx): - auth.verify_privilege(ctx.user, 'comments:create') - text = ctx.get_param_as_string('text', required=True) - post_id = ctx.get_param_as_int('postId', required=True) - post = posts.get_post_by_id(post_id) - comment = comments.create_comment(ctx.user, post, text) - ctx.session.add(comment) - ctx.session.commit() - return _serialize(ctx, comment) +@routes.get('/comment/(?P[^/]+)/?') +def get_comment(ctx, params): + auth.verify_privilege(ctx.user, 'comments:view') + comment = comments.get_comment_by_id(params['comment_id']) + return _serialize(ctx, comment) -class CommentDetailApi(BaseApi): - def get(self, ctx, comment_id): - auth.verify_privilege(ctx.user, 'comments:view') - comment = comments.get_comment_by_id(comment_id) - return _serialize(ctx, comment) +@routes.put('/comment/(?P[^/]+)/?') +def update_comment(ctx, params): + comment = comments.get_comment_by_id(params['comment_id']) + util.verify_version(comment, ctx) + infix = 'own' if ctx.user.user_id == comment.user_id else 'any' + text = ctx.get_param_as_string('text', required=True) + auth.verify_privilege(ctx.user, 'comments:edit:%s' % infix) + comments.update_comment_text(comment, text) + util.bump_version(comment) + comment.last_edit_time = datetime.datetime.utcnow() + ctx.session.commit() + return _serialize(ctx, comment) - def put(self, ctx, comment_id): - comment = comments.get_comment_by_id(comment_id) - util.verify_version(comment, ctx) - infix = 'own' if ctx.user.user_id == comment.user_id else 'any' - text = ctx.get_param_as_string('text', required=True) - auth.verify_privilege(ctx.user, 'comments:edit:%s' % infix) - comments.update_comment_text(comment, text) - util.bump_version(comment) - comment.last_edit_time = datetime.datetime.utcnow() - ctx.session.commit() - return _serialize(ctx, comment) +@routes.delete('/comment/(?P[^/]+)/?') +def delete_comment(ctx, params): + comment = comments.get_comment_by_id(params['comment_id']) + util.verify_version(comment, ctx) + infix = 'own' if ctx.user.user_id == comment.user_id else 'any' + auth.verify_privilege(ctx.user, 'comments:delete:%s' % infix) + ctx.session.delete(comment) + ctx.session.commit() + return {} - def delete(self, ctx, comment_id): - comment = comments.get_comment_by_id(comment_id) - util.verify_version(comment, ctx) - infix = 'own' if ctx.user.user_id == comment.user_id else 'any' - auth.verify_privilege(ctx.user, 'comments:delete:%s' % infix) - ctx.session.delete(comment) - ctx.session.commit() - return {} +@routes.put('/comment/(?P[^/]+)/score/?') +def set_comment_score(ctx, params): + auth.verify_privilege(ctx.user, 'comments:score') + score = ctx.get_param_as_int('score', required=True) + comment = comments.get_comment_by_id(params['comment_id']) + scores.set_score(comment, ctx.user, score) + ctx.session.commit() + return _serialize(ctx, comment) -class CommentScoreApi(BaseApi): - def put(self, ctx, comment_id): - auth.verify_privilege(ctx.user, 'comments:score') - score = ctx.get_param_as_int('score', required=True) - comment = comments.get_comment_by_id(comment_id) - scores.set_score(comment, ctx.user, score) - ctx.session.commit() - return _serialize(ctx, comment) - - def delete(self, ctx, comment_id): - auth.verify_privilege(ctx.user, 'comments:score') - comment = comments.get_comment_by_id(comment_id) - scores.delete_score(comment, ctx.user) - ctx.session.commit() - return _serialize(ctx, comment) +@routes.delete('/comment/(?P[^/]+)/score/?') +def delete_comment_score(ctx, params): + auth.verify_privilege(ctx.user, 'comments:score') + comment = comments.get_comment_by_id(params['comment_id']) + scores.delete_score(comment, ctx.user) + ctx.session.commit() + return _serialize(ctx, comment) diff --git a/server/szurubooru/api/info_api.py b/server/szurubooru/api/info_api.py index f0e17512..16a4e384 100644 --- a/server/szurubooru/api/info_api.py +++ b/server/szurubooru/api/info_api.py @@ -1,47 +1,46 @@ import datetime import os from szurubooru import config -from szurubooru.api.base_api import BaseApi from szurubooru.func import posts, users, util +from szurubooru.rest import routes -class InfoApi(BaseApi): - def __init__(self): - super().__init__() - self._cache_time = None - self._cache_result = None +_cache_time = None +_cache_result = None - def get(self, ctx): - post_feature = posts.try_get_current_post_feature() - return { - 'postCount': posts.get_post_count(), - 'diskUsage': self._get_disk_usage(), - 'featuredPost': posts.serialize_post(post_feature.post, ctx.user) \ - if post_feature else None, - 'featuringTime': post_feature.time if post_feature else None, - 'featuringUser': users.serialize_user(post_feature.user, ctx.user) \ - if post_feature else None, - 'serverTime': datetime.datetime.utcnow(), - 'config': { - 'userNameRegex': config.config['user_name_regex'], - 'passwordRegex': config.config['password_regex'], - 'tagNameRegex': config.config['tag_name_regex'], - 'tagCategoryNameRegex': config.config['tag_category_name_regex'], - 'defaultUserRank': config.config['default_rank'], - 'privileges': util.snake_case_to_lower_camel_case_keys( - config.config['privileges']), - }, - } +def _get_disk_usage(): + global _cache_time, _cache_result # pylint: disable=global-statement + threshold = datetime.timedelta(hours=1) + now = datetime.datetime.utcnow() + if _cache_time and _cache_time > now - threshold: + return _cache_result + total_size = 0 + for dir_path, _, file_names in os.walk(config.config['data_dir']): + for file_name in file_names: + file_path = os.path.join(dir_path, file_name) + total_size += os.path.getsize(file_path) + _cache_time = now + _cache_result = total_size + return total_size - def _get_disk_usage(self): - threshold = datetime.timedelta(hours=1) - now = datetime.datetime.utcnow() - if self._cache_time and self._cache_time > now - threshold: - return self._cache_result - total_size = 0 - for dir_path, _, file_names in os.walk(config.config['data_dir']): - for file_name in file_names: - file_path = os.path.join(dir_path, file_name) - total_size += os.path.getsize(file_path) - self._cache_time = now - self._cache_result = total_size - return total_size +@routes.get('/info/?') +def get_info(ctx, _params=None): + post_feature = posts.try_get_current_post_feature() + return { + 'postCount': posts.get_post_count(), + 'diskUsage': _get_disk_usage(), + 'featuredPost': posts.serialize_post(post_feature.post, ctx.user) \ + if post_feature else None, + 'featuringTime': post_feature.time if post_feature else None, + 'featuringUser': users.serialize_user(post_feature.user, ctx.user) \ + if post_feature else None, + 'serverTime': datetime.datetime.utcnow(), + 'config': { + 'userNameRegex': config.config['user_name_regex'], + 'passwordRegex': config.config['password_regex'], + 'tagNameRegex': config.config['tag_name_regex'], + 'tagCategoryNameRegex': config.config['tag_category_name_regex'], + 'defaultUserRank': config.config['default_rank'], + 'privileges': util.snake_case_to_lower_camel_case_keys( + config.config['privileges']), + }, + } diff --git a/server/szurubooru/api/password_reset_api.py b/server/szurubooru/api/password_reset_api.py index 00a8d483..2040fb64 100644 --- a/server/szurubooru/api/password_reset_api.py +++ b/server/szurubooru/api/password_reset_api.py @@ -1,6 +1,6 @@ from szurubooru import config, errors -from szurubooru.api.base_api import BaseApi from szurubooru.func import auth, mailer, users, util +from szurubooru.rest import routes MAIL_SUBJECT = 'Password reset for {name}' MAIL_BODY = \ @@ -8,32 +8,35 @@ MAIL_BODY = \ 'If you wish to proceed, click this link: {url}\n' \ 'Otherwise, please ignore this email.' -class PasswordResetApi(BaseApi): - def get(self, _ctx, user_name): - ''' Send a mail with secure token to the correlated user. ''' - user = users.get_user_by_name_or_email(user_name) - if not user.email: - raise errors.ValidationError( - 'User %r hasn\'t supplied email. Cannot reset password.' % ( - user_name)) - token = auth.generate_authentication_token(user) - url = '%s/password-reset/%s:%s' % ( - config.config['base_url'].rstrip('/'), user.name, token) - mailer.send_mail( - 'noreply@%s' % config.config['name'], - user.email, - MAIL_SUBJECT.format(name=config.config['name']), - MAIL_BODY.format(name=config.config['name'], url=url)) - return {} +@routes.get('/password-reset/(?P[^/]+)/?') +def start_password_reset(_ctx, params): + ''' Send a mail with secure token to the correlated user. ''' + user_name = params['user_name'] + user = users.get_user_by_name_or_email(user_name) + if not user.email: + raise errors.ValidationError( + 'User %r hasn\'t supplied email. Cannot reset password.' % ( + user_name)) + token = auth.generate_authentication_token(user) + url = '%s/password-reset/%s:%s' % ( + config.config['base_url'].rstrip('/'), user.name, token) + mailer.send_mail( + 'noreply@%s' % config.config['name'], + user.email, + MAIL_SUBJECT.format(name=config.config['name']), + MAIL_BODY.format(name=config.config['name'], url=url)) + return {} - def post(self, ctx, user_name): - ''' Verify token from mail, generate a new password and return it. ''' - user = users.get_user_by_name_or_email(user_name) - good_token = auth.generate_authentication_token(user) - token = ctx.get_param_as_string('token', required=True) - if token != good_token: - raise errors.ValidationError('Invalid password reset token.') - new_password = users.reset_user_password(user) - util.bump_version(user) - ctx.session.commit() - return {'password': new_password} +@routes.post('/password-reset/(?P[^/]+)/?') +def finish_password_reset(ctx, params): + ''' Verify token from mail, generate a new password and return it. ''' + user_name = params['user_name'] + user = users.get_user_by_name_or_email(user_name) + good_token = auth.generate_authentication_token(user) + token = ctx.get_param_as_string('token', required=True) + if token != good_token: + raise errors.ValidationError('Invalid password reset token.') + new_password = users.reset_user_password(user) + util.bump_version(user) + ctx.session.commit() + return {'password': new_password} diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index 68a9835a..b75f40b7 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -1,7 +1,9 @@ import datetime from szurubooru import search -from szurubooru.api.base_api import BaseApi from szurubooru.func import auth, tags, posts, snapshots, favorites, scores, util +from szurubooru.rest import routes + +_search_executor = search.Executor(search.configs.PostSearchConfig()) def _serialize_post(ctx, post): return posts.serialize_post( @@ -9,165 +11,161 @@ def _serialize_post(ctx, post): ctx.user, options=util.get_serialization_options(ctx)) -class PostListApi(BaseApi): - def __init__(self): - super().__init__() - self._search_executor = search.Executor( - search.configs.PostSearchConfig()) +@routes.get('/posts/?') +def get_posts(ctx, _params=None): + auth.verify_privilege(ctx.user, 'posts:list') + _search_executor.config.user = ctx.user + return _search_executor.execute_and_serialize( + ctx, lambda post: _serialize_post(ctx, post)) - def get(self, ctx): - auth.verify_privilege(ctx.user, 'posts:list') - self._search_executor.config.user = ctx.user - return self._search_executor.execute_and_serialize( - ctx, lambda post: _serialize_post(ctx, post)) +@routes.post('/posts/?') +def create_post(ctx, _params=None): + anonymous = ctx.get_param_as_bool('anonymous', default=False) + if anonymous: + auth.verify_privilege(ctx.user, 'posts:create:anonymous') + else: + auth.verify_privilege(ctx.user, 'posts:create:identified') + content = ctx.get_file('content', required=True) + tag_names = ctx.get_param_as_list('tags', required=True) + safety = ctx.get_param_as_string('safety', required=True) + source = ctx.get_param_as_string('source', required=False, default=None) + if ctx.has_param('contentUrl') and not source: + source = ctx.get_param_as_string('contentUrl') + relations = ctx.get_param_as_list('relations', required=False) or [] + notes = ctx.get_param_as_list('notes', required=False) or [] + flags = ctx.get_param_as_list('flags', required=False) or [] - def post(self, ctx): - anonymous = ctx.get_param_as_bool('anonymous', default=False) - if anonymous: - auth.verify_privilege(ctx.user, 'posts:create:anonymous') - else: - auth.verify_privilege(ctx.user, 'posts:create:identified') - content = ctx.get_file('content', required=True) - tag_names = ctx.get_param_as_list('tags', required=True) - safety = ctx.get_param_as_string('safety', required=True) - source = ctx.get_param_as_string('source', required=False, default=None) - if ctx.has_param('contentUrl') and not source: - source = ctx.get_param_as_string('contentUrl') - relations = ctx.get_param_as_list('relations', required=False) or [] - notes = ctx.get_param_as_list('notes', required=False) or [] - flags = ctx.get_param_as_list('flags', required=False) or [] + post, new_tags = posts.create_post( + content, tag_names, None if anonymous else ctx.user) + if len(new_tags): + auth.verify_privilege(ctx.user, 'tags:create') + posts.update_post_safety(post, safety) + posts.update_post_source(post, source) + posts.update_post_relations(post, relations) + posts.update_post_notes(post, notes) + posts.update_post_flags(post, flags) + if ctx.has_file('thumbnail'): + posts.update_post_thumbnail(post, ctx.get_file('thumbnail')) + ctx.session.add(post) + snapshots.save_entity_creation(post, ctx.user) + ctx.session.commit() + tags.export_to_json() + return _serialize_post(ctx, post) - post, new_tags = posts.create_post( - content, tag_names, None if anonymous else ctx.user) +@routes.get('/post/(?P[^/]+)/?') +def get_post(ctx, params): + auth.verify_privilege(ctx.user, 'posts:view') + post = posts.get_post_by_id(params['post_id']) + return _serialize_post(ctx, post) + +@routes.put('/post/(?P[^/]+)/?') +def update_post(ctx, params): + post = posts.get_post_by_id(params['post_id']) + util.verify_version(post, ctx) + if ctx.has_file('content'): + auth.verify_privilege(ctx.user, 'posts:edit:content') + posts.update_post_content(post, ctx.get_file('content')) + if ctx.has_param('tags'): + auth.verify_privilege(ctx.user, 'posts:edit:tags') + new_tags = posts.update_post_tags(post, ctx.get_param_as_list('tags')) if len(new_tags): auth.verify_privilege(ctx.user, 'tags:create') - posts.update_post_safety(post, safety) - posts.update_post_source(post, source) - posts.update_post_relations(post, relations) - posts.update_post_notes(post, notes) - posts.update_post_flags(post, flags) - if ctx.has_file('thumbnail'): - posts.update_post_thumbnail(post, ctx.get_file('thumbnail')) - ctx.session.add(post) - snapshots.save_entity_creation(post, ctx.user) - ctx.session.commit() - tags.export_to_json() - return _serialize_post(ctx, post) + if ctx.has_param('safety'): + auth.verify_privilege(ctx.user, 'posts:edit:safety') + posts.update_post_safety(post, ctx.get_param_as_string('safety')) + if ctx.has_param('source'): + auth.verify_privilege(ctx.user, 'posts:edit:source') + posts.update_post_source(post, ctx.get_param_as_string('source')) + elif ctx.has_param('contentUrl'): + posts.update_post_source(post, ctx.get_param_as_string('contentUrl')) + if ctx.has_param('relations'): + auth.verify_privilege(ctx.user, 'posts:edit:relations') + posts.update_post_relations(post, ctx.get_param_as_list('relations')) + if ctx.has_param('notes'): + auth.verify_privilege(ctx.user, 'posts:edit:notes') + posts.update_post_notes(post, ctx.get_param_as_list('notes')) + if ctx.has_param('flags'): + auth.verify_privilege(ctx.user, 'posts:edit:flags') + posts.update_post_flags(post, ctx.get_param_as_list('flags')) + if ctx.has_file('thumbnail'): + auth.verify_privilege(ctx.user, 'posts:edit:thumbnail') + posts.update_post_thumbnail(post, ctx.get_file('thumbnail')) + util.bump_version(post) + post.last_edit_time = datetime.datetime.utcnow() + ctx.session.flush() + snapshots.save_entity_modification(post, ctx.user) + ctx.session.commit() + tags.export_to_json() + return _serialize_post(ctx, post) -class PostDetailApi(BaseApi): - def get(self, ctx, post_id): - auth.verify_privilege(ctx.user, 'posts:view') - post = posts.get_post_by_id(post_id) - return _serialize_post(ctx, post) +@routes.delete('/post/(?P[^/]+)/?') +def delete_post(ctx, params): + auth.verify_privilege(ctx.user, 'posts:delete') + post = posts.get_post_by_id(params['post_id']) + util.verify_version(post, ctx) + snapshots.save_entity_deletion(post, ctx.user) + posts.delete(post) + ctx.session.commit() + tags.export_to_json() + return {} - def put(self, ctx, post_id): - post = posts.get_post_by_id(post_id) - util.verify_version(post, ctx) - if ctx.has_file('content'): - auth.verify_privilege(ctx.user, 'posts:edit:content') - posts.update_post_content(post, ctx.get_file('content')) - if ctx.has_param('tags'): - auth.verify_privilege(ctx.user, 'posts:edit:tags') - new_tags = posts.update_post_tags(post, ctx.get_param_as_list('tags')) - if len(new_tags): - auth.verify_privilege(ctx.user, 'tags:create') - if ctx.has_param('safety'): - auth.verify_privilege(ctx.user, 'posts:edit:safety') - posts.update_post_safety(post, ctx.get_param_as_string('safety')) - if ctx.has_param('source'): - auth.verify_privilege(ctx.user, 'posts:edit:source') - posts.update_post_source(post, ctx.get_param_as_string('source')) - elif ctx.has_param('contentUrl'): - posts.update_post_source(post, ctx.get_param_as_string('contentUrl')) - if ctx.has_param('relations'): - auth.verify_privilege(ctx.user, 'posts:edit:relations') - posts.update_post_relations(post, ctx.get_param_as_list('relations')) - if ctx.has_param('notes'): - auth.verify_privilege(ctx.user, 'posts:edit:notes') - posts.update_post_notes(post, ctx.get_param_as_list('notes')) - if ctx.has_param('flags'): - auth.verify_privilege(ctx.user, 'posts:edit:flags') - posts.update_post_flags(post, ctx.get_param_as_list('flags')) - if ctx.has_file('thumbnail'): - auth.verify_privilege(ctx.user, 'posts:edit:thumbnail') - posts.update_post_thumbnail(post, ctx.get_file('thumbnail')) - util.bump_version(post) - post.last_edit_time = datetime.datetime.utcnow() - ctx.session.flush() - snapshots.save_entity_modification(post, ctx.user) - ctx.session.commit() - tags.export_to_json() - return _serialize_post(ctx, post) +@routes.get('/featured-post/?') +def get_featured_post(ctx, _params=None): + post = posts.try_get_featured_post() + return _serialize_post(ctx, post) - def delete(self, ctx, post_id): - auth.verify_privilege(ctx.user, 'posts:delete') - post = posts.get_post_by_id(post_id) - util.verify_version(post, ctx) - snapshots.save_entity_deletion(post, ctx.user) - posts.delete(post) - ctx.session.commit() - tags.export_to_json() - return {} +@routes.post('/featured-post/?') +def set_featured_post(ctx, _params=None): + auth.verify_privilege(ctx.user, 'posts:feature') + post_id = ctx.get_param_as_int('id', required=True) + post = posts.get_post_by_id(post_id) + featured_post = posts.try_get_featured_post() + if featured_post and featured_post.post_id == post.post_id: + raise posts.PostAlreadyFeaturedError( + 'Post %r is already featured.' % post_id) + posts.feature_post(post, ctx.user) + if featured_post: + snapshots.save_entity_modification(featured_post, ctx.user) + snapshots.save_entity_modification(post, ctx.user) + ctx.session.commit() + return _serialize_post(ctx, post) -class PostFeatureApi(BaseApi): - def post(self, ctx): - auth.verify_privilege(ctx.user, 'posts:feature') - post_id = ctx.get_param_as_int('id', required=True) - post = posts.get_post_by_id(post_id) - featured_post = posts.try_get_featured_post() - if featured_post and featured_post.post_id == post.post_id: - raise posts.PostAlreadyFeaturedError( - 'Post %r is already featured.' % post_id) - posts.feature_post(post, ctx.user) - if featured_post: - snapshots.save_entity_modification(featured_post, ctx.user) - snapshots.save_entity_modification(post, ctx.user) - ctx.session.commit() - return _serialize_post(ctx, post) +@routes.put('/post/(?P[^/]+)/score/?') +def set_post_score(ctx, params): + auth.verify_privilege(ctx.user, 'posts:score') + post = posts.get_post_by_id(params['post_id']) + score = ctx.get_param_as_int('score', required=True) + scores.set_score(post, ctx.user, score) + ctx.session.commit() + return _serialize_post(ctx, post) - def get(self, ctx): - post = posts.try_get_featured_post() - return _serialize_post(ctx, post) +@routes.delete('/post/(?P[^/]+)/score/?') +def delete_post_score(ctx, params): + auth.verify_privilege(ctx.user, 'posts:score') + post = posts.get_post_by_id(params['post_id']) + scores.delete_score(post, ctx.user) + ctx.session.commit() + return _serialize_post(ctx, post) -class PostScoreApi(BaseApi): - def put(self, ctx, post_id): - auth.verify_privilege(ctx.user, 'posts:score') - post = posts.get_post_by_id(post_id) - score = ctx.get_param_as_int('score', required=True) - scores.set_score(post, ctx.user, score) - ctx.session.commit() - return _serialize_post(ctx, post) +@routes.post('/post/(?P[^/]+)/favorite/?') +def add_post_to_favorites(ctx, params): + auth.verify_privilege(ctx.user, 'posts:favorite') + post = posts.get_post_by_id(params['post_id']) + favorites.set_favorite(post, ctx.user) + ctx.session.commit() + return _serialize_post(ctx, post) - def delete(self, ctx, post_id): - auth.verify_privilege(ctx.user, 'posts:score') - post = posts.get_post_by_id(post_id) - scores.delete_score(post, ctx.user) - ctx.session.commit() - return _serialize_post(ctx, post) +@routes.delete('/post/(?P[^/]+)/favorite/?') +def delete_post_from_favorites(ctx, params): + auth.verify_privilege(ctx.user, 'posts:favorite') + post = posts.get_post_by_id(params['post_id']) + favorites.unset_favorite(post, ctx.user) + ctx.session.commit() + return _serialize_post(ctx, post) -class PostFavoriteApi(BaseApi): - def post(self, ctx, post_id): - auth.verify_privilege(ctx.user, 'posts:favorite') - post = posts.get_post_by_id(post_id) - favorites.set_favorite(post, ctx.user) - ctx.session.commit() - return _serialize_post(ctx, post) - - def delete(self, ctx, post_id): - auth.verify_privilege(ctx.user, 'posts:favorite') - post = posts.get_post_by_id(post_id) - favorites.unset_favorite(post, ctx.user) - ctx.session.commit() - return _serialize_post(ctx, post) - -class PostsAroundApi(BaseApi): - def __init__(self): - super().__init__() - self._search_executor = search.Executor( - search.configs.PostSearchConfig()) - - def get(self, ctx, post_id): - auth.verify_privilege(ctx.user, 'posts:list') - self._search_executor.config.user = ctx.user - return self._search_executor.get_around_and_serialize( - ctx, post_id, lambda post: _serialize_post(ctx, post)) +@routes.get('/post/(?P[^/]+)/around/?') +def get_posts_around(ctx, params): + auth.verify_privilege(ctx.user, 'posts:list') + _search_executor.config.user = ctx.user + return _search_executor.get_around_and_serialize( + ctx, params['post_id'], lambda post: _serialize_post(ctx, post)) diff --git a/server/szurubooru/api/snapshot_api.py b/server/szurubooru/api/snapshot_api.py index 3f830f90..1fd6fc52 100644 --- a/server/szurubooru/api/snapshot_api.py +++ b/server/szurubooru/api/snapshot_api.py @@ -1,14 +1,12 @@ from szurubooru import search -from szurubooru.api.base_api import BaseApi from szurubooru.func import auth, snapshots +from szurubooru.rest import routes -class SnapshotListApi(BaseApi): - def __init__(self): - super().__init__() - self._search_executor = search.Executor( - search.configs.SnapshotSearchConfig()) +_search_executor = search.Executor( + search.configs.SnapshotSearchConfig()) - def get(self, ctx): - auth.verify_privilege(ctx.user, 'snapshots:list') - return self._search_executor.execute_and_serialize( - ctx, snapshots.serialize_snapshot) +@routes.get('/snapshots/?') +def get_snapshots(ctx, _params=None): + auth.verify_privilege(ctx.user, 'snapshots:list') + return _search_executor.execute_and_serialize( + ctx, snapshots.serialize_snapshot) diff --git a/server/szurubooru/api/tag_api.py b/server/szurubooru/api/tag_api.py index 2c8a2e2c..079e3dd5 100644 --- a/server/szurubooru/api/tag_api.py +++ b/server/szurubooru/api/tag_api.py @@ -1,7 +1,9 @@ import datetime from szurubooru import db, search -from szurubooru.api.base_api import BaseApi from szurubooru.func import auth, tags, util, snapshots +from szurubooru.rest import routes + +_search_executor = search.Executor(search.configs.TagSearchConfig()) def _serialize(ctx, tag): return tags.serialize_tag( @@ -17,116 +19,112 @@ def _create_if_needed(tag_names, user): for tag in new_tags: snapshots.save_entity_creation(tag, user) -class TagListApi(BaseApi): - def __init__(self): - super().__init__() - self._search_executor = search.Executor( - search.configs.TagSearchConfig()) +@routes.get('/tags/?') +def get_tags(ctx, _params=None): + auth.verify_privilege(ctx.user, 'tags:list') + return _search_executor.execute_and_serialize( + ctx, lambda tag: _serialize(ctx, tag)) - def get(self, ctx): - auth.verify_privilege(ctx.user, 'tags:list') - return self._search_executor.execute_and_serialize( - ctx, lambda tag: _serialize(ctx, tag)) +@routes.post('/tags/?') +def create_tag(ctx, _params=None): + auth.verify_privilege(ctx.user, 'tags:create') - def post(self, ctx): - auth.verify_privilege(ctx.user, 'tags:create') + names = ctx.get_param_as_list('names', required=True) + category = ctx.get_param_as_string('category', required=True) + description = ctx.get_param_as_string( + 'description', required=False, default=None) + suggestions = ctx.get_param_as_list( + 'suggestions', required=False, default=[]) + implications = ctx.get_param_as_list( + 'implications', required=False, default=[]) - names = ctx.get_param_as_list('names', required=True) - category = ctx.get_param_as_string('category', required=True) or '' - description = ctx.get_param_as_string( - 'description', required=False, default=None) - suggestions = ctx.get_param_as_list( - 'suggestions', required=False, default=[]) - implications = ctx.get_param_as_list( - 'implications', required=False, default=[]) + _create_if_needed(suggestions, ctx.user) + _create_if_needed(implications, ctx.user) + tag = tags.create_tag(names, category, suggestions, implications) + tags.update_tag_description(tag, description) + ctx.session.add(tag) + ctx.session.flush() + snapshots.save_entity_creation(tag, ctx.user) + ctx.session.commit() + tags.export_to_json() + return _serialize(ctx, tag) + +@routes.get('/tag/(?P[^/]+)/?') +def get_tag(ctx, params): + auth.verify_privilege(ctx.user, 'tags:view') + tag = tags.get_tag_by_name(params['tag_name']) + return _serialize(ctx, tag) + +@routes.put('/tag/(?P[^/]+)/?') +def update_tag(ctx, params): + tag = tags.get_tag_by_name(params['tag_name']) + util.verify_version(tag, ctx) + if ctx.has_param('names'): + auth.verify_privilege(ctx.user, 'tags:edit:names') + tags.update_tag_names(tag, ctx.get_param_as_list('names')) + if ctx.has_param('category'): + auth.verify_privilege(ctx.user, 'tags:edit:category') + tags.update_tag_category_name( + tag, ctx.get_param_as_string('category')) + if ctx.has_param('description'): + auth.verify_privilege(ctx.user, 'tags:edit:description') + tags.update_tag_description( + tag, ctx.get_param_as_string('description', default=None)) + if ctx.has_param('suggestions'): + auth.verify_privilege(ctx.user, 'tags:edit:suggestions') + suggestions = ctx.get_param_as_list('suggestions') _create_if_needed(suggestions, ctx.user) + tags.update_tag_suggestions(tag, suggestions) + if ctx.has_param('implications'): + auth.verify_privilege(ctx.user, 'tags:edit:implications') + implications = ctx.get_param_as_list('implications') _create_if_needed(implications, ctx.user) + tags.update_tag_implications(tag, implications) + util.bump_version(tag) + tag.last_edit_time = datetime.datetime.utcnow() + ctx.session.flush() + snapshots.save_entity_modification(tag, ctx.user) + ctx.session.commit() + tags.export_to_json() + return _serialize(ctx, tag) - tag = tags.create_tag(names, category, suggestions, implications) - tags.update_tag_description(tag, description) - ctx.session.add(tag) - ctx.session.flush() - snapshots.save_entity_creation(tag, ctx.user) - ctx.session.commit() - tags.export_to_json() - return _serialize(ctx, tag) +@routes.delete('/tag/(?P[^/]+)/?') +def delete_tag(ctx, params): + tag = tags.get_tag_by_name(params['tag_name']) + util.verify_version(tag, ctx) + auth.verify_privilege(ctx.user, 'tags:delete') + snapshots.save_entity_deletion(tag, ctx.user) + tags.delete(tag) + ctx.session.commit() + tags.export_to_json() + return {} -class TagDetailApi(BaseApi): - def get(self, ctx, tag_name): - auth.verify_privilege(ctx.user, 'tags:view') - tag = tags.get_tag_by_name(tag_name) - return _serialize(ctx, tag) +@routes.post('/tag-merge/?') +def merge_tags(ctx, _params=None): + source_tag_name = ctx.get_param_as_string('remove', required=True) or '' + target_tag_name = ctx.get_param_as_string('mergeTo', required=True) or '' + source_tag = tags.get_tag_by_name(source_tag_name) + target_tag = tags.get_tag_by_name(target_tag_name) + util.verify_version(source_tag, ctx, 'removeVersion') + util.verify_version(target_tag, ctx, 'mergeToVersion') + auth.verify_privilege(ctx.user, 'tags:merge') + tags.merge_tags(source_tag, target_tag) + snapshots.save_entity_deletion(source_tag, ctx.user) + util.bump_version(target_tag) + ctx.session.commit() + tags.export_to_json() + return _serialize(ctx, target_tag) - def put(self, ctx, tag_name): - tag = tags.get_tag_by_name(tag_name) - util.verify_version(tag, ctx) - if ctx.has_param('names'): - auth.verify_privilege(ctx.user, 'tags:edit:names') - tags.update_tag_names(tag, ctx.get_param_as_list('names')) - if ctx.has_param('category'): - auth.verify_privilege(ctx.user, 'tags:edit:category') - tags.update_tag_category_name( - tag, ctx.get_param_as_string('category') or '') - if ctx.has_param('description'): - auth.verify_privilege(ctx.user, 'tags:edit:description') - tags.update_tag_description( - tag, ctx.get_param_as_string('description', default=None)) - if ctx.has_param('suggestions'): - auth.verify_privilege(ctx.user, 'tags:edit:suggestions') - suggestions = ctx.get_param_as_list('suggestions') - _create_if_needed(suggestions, ctx.user) - tags.update_tag_suggestions(tag, suggestions) - if ctx.has_param('implications'): - auth.verify_privilege(ctx.user, 'tags:edit:implications') - implications = ctx.get_param_as_list('implications') - _create_if_needed(implications, ctx.user) - tags.update_tag_implications(tag, implications) - util.bump_version(tag) - tag.last_edit_time = datetime.datetime.utcnow() - ctx.session.flush() - snapshots.save_entity_modification(tag, ctx.user) - ctx.session.commit() - tags.export_to_json() - return _serialize(ctx, tag) - - def delete(self, ctx, tag_name): - tag = tags.get_tag_by_name(tag_name) - util.verify_version(tag, ctx) - auth.verify_privilege(ctx.user, 'tags:delete') - snapshots.save_entity_deletion(tag, ctx.user) - tags.delete(tag) - ctx.session.commit() - tags.export_to_json() - return {} - -class TagMergeApi(BaseApi): - def post(self, ctx): - source_tag_name = ctx.get_param_as_string('remove', required=True) or '' - target_tag_name = ctx.get_param_as_string('mergeTo', required=True) or '' - source_tag = tags.get_tag_by_name(source_tag_name) - target_tag = tags.get_tag_by_name(target_tag_name) - util.verify_version(source_tag, ctx, 'removeVersion') - util.verify_version(target_tag, ctx, 'mergeToVersion') - if source_tag.tag_id == target_tag.tag_id: - raise tags.InvalidTagRelationError('Cannot merge tag with itself.') - auth.verify_privilege(ctx.user, 'tags:merge') - snapshots.save_entity_deletion(source_tag, ctx.user) - tags.merge_tags(source_tag, target_tag) - util.bump_version(target_tag) - ctx.session.commit() - tags.export_to_json() - return _serialize(ctx, target_tag) - -class TagSiblingsApi(BaseApi): - def get(self, ctx, tag_name): - auth.verify_privilege(ctx.user, 'tags:view') - tag = tags.get_tag_by_name(tag_name) - result = tags.get_tag_siblings(tag) - serialized_siblings = [] - for sibling, occurrences in result: - serialized_siblings.append({ - 'tag': _serialize(ctx, sibling), - 'occurrences': occurrences - }) - return {'results': serialized_siblings} +@routes.get('/tag-siblings/(?P[^/]+)/?') +def get_tag_siblings(ctx, params): + auth.verify_privilege(ctx.user, 'tags:view') + tag = tags.get_tag_by_name(params['tag_name']) + result = tags.get_tag_siblings(tag) + serialized_siblings = [] + for sibling, occurrences in result: + serialized_siblings.append({ + 'tag': _serialize(ctx, sibling), + 'occurrences': occurrences + }) + return {'results': serialized_siblings} diff --git a/server/szurubooru/api/tag_category_api.py b/server/szurubooru/api/tag_category_api.py index dc092789..9ac9ffa5 100644 --- a/server/szurubooru/api/tag_category_api.py +++ b/server/szurubooru/api/tag_category_api.py @@ -1,70 +1,73 @@ -from szurubooru.api.base_api import BaseApi +from szurubooru.rest import routes from szurubooru.func import auth, tags, tag_categories, util, snapshots def _serialize(ctx, category): return tag_categories.serialize_category( category, options=util.get_serialization_options(ctx)) -class TagCategoryListApi(BaseApi): - def get(self, ctx): - auth.verify_privilege(ctx.user, 'tag_categories:list') - categories = tag_categories.get_all_categories() - return { - 'results': [_serialize(ctx, category) for category in categories], - } +@routes.get('/tag-categories/?') +def get_tag_categories(ctx, _params=None): + auth.verify_privilege(ctx.user, 'tag_categories:list') + categories = tag_categories.get_all_categories() + return { + 'results': [_serialize(ctx, category) for category in categories], + } - def post(self, ctx): - auth.verify_privilege(ctx.user, 'tag_categories:create') - name = ctx.get_param_as_string('name', required=True) - color = ctx.get_param_as_string('color', required=True) - category = tag_categories.create_category(name, color) - ctx.session.add(category) - ctx.session.flush() - snapshots.save_entity_creation(category, ctx.user) - ctx.session.commit() - tags.export_to_json() - return _serialize(ctx, category) +@routes.post('/tag-categories/?') +def create_tag_category(ctx, _params=None): + auth.verify_privilege(ctx.user, 'tag_categories:create') + name = ctx.get_param_as_string('name', required=True) + color = ctx.get_param_as_string('color', required=True) + category = tag_categories.create_category(name, color) + ctx.session.add(category) + ctx.session.flush() + snapshots.save_entity_creation(category, ctx.user) + ctx.session.commit() + tags.export_to_json() + return _serialize(ctx, category) -class TagCategoryDetailApi(BaseApi): - def get(self, ctx, category_name): - auth.verify_privilege(ctx.user, 'tag_categories:view') - category = tag_categories.get_category_by_name(category_name) - return _serialize(ctx, category) +@routes.get('/tag-category/(?P[^/]+)/?') +def get_tag_category(ctx, params): + auth.verify_privilege(ctx.user, 'tag_categories:view') + category = tag_categories.get_category_by_name(params['category_name']) + return _serialize(ctx, category) - def put(self, ctx, category_name): - category = tag_categories.get_category_by_name(category_name) - util.verify_version(category, ctx) - if ctx.has_param('name'): - auth.verify_privilege(ctx.user, 'tag_categories:edit:name') - tag_categories.update_category_name( - category, ctx.get_param_as_string('name')) - if ctx.has_param('color'): - auth.verify_privilege(ctx.user, 'tag_categories:edit:color') - tag_categories.update_category_color( - category, ctx.get_param_as_string('color')) - util.bump_version(category) - ctx.session.flush() - snapshots.save_entity_modification(category, ctx.user) - ctx.session.commit() - tags.export_to_json() - return _serialize(ctx, category) +@routes.put('/tag-category/(?P[^/]+)/?') +def update_tag_category(ctx, params): + category = tag_categories.get_category_by_name(params['category_name']) + util.verify_version(category, ctx) + if ctx.has_param('name'): + auth.verify_privilege(ctx.user, 'tag_categories:edit:name') + tag_categories.update_category_name( + category, ctx.get_param_as_string('name')) + if ctx.has_param('color'): + auth.verify_privilege(ctx.user, 'tag_categories:edit:color') + tag_categories.update_category_color( + category, ctx.get_param_as_string('color')) + util.bump_version(category) + ctx.session.flush() + snapshots.save_entity_modification(category, ctx.user) + ctx.session.commit() + tags.export_to_json() + return _serialize(ctx, category) - def delete(self, ctx, category_name): - category = tag_categories.get_category_by_name(category_name) - util.verify_version(category, ctx) - auth.verify_privilege(ctx.user, 'tag_categories:delete') - tag_categories.delete_category(category) - snapshots.save_entity_deletion(category, ctx.user) - ctx.session.commit() - tags.export_to_json() - return {} +@routes.delete('/tag-category/(?P[^/]+)/?') +def delete_tag_category(ctx, params): + category = tag_categories.get_category_by_name(params['category_name']) + util.verify_version(category, ctx) + auth.verify_privilege(ctx.user, 'tag_categories:delete') + tag_categories.delete_category(category) + snapshots.save_entity_deletion(category, ctx.user) + ctx.session.commit() + tags.export_to_json() + return {} -class DefaultTagCategoryApi(BaseApi): - def put(self, ctx, category_name): - auth.verify_privilege(ctx.user, 'tag_categories:set_default') - category = tag_categories.get_category_by_name(category_name) - tag_categories.set_default_category(category) - snapshots.save_entity_modification(category, ctx.user) - ctx.session.commit() - tags.export_to_json() - return _serialize(ctx, category) +@routes.put('/tag-category/(?P[^/]+)/default/?') +def set_tag_category_as_default(ctx, params): + auth.verify_privilege(ctx.user, 'tag_categories:set_default') + category = tag_categories.get_category_by_name(params['category_name']) + tag_categories.set_default_category(category) + snapshots.save_entity_modification(category, ctx.user) + ctx.session.commit() + tags.export_to_json() + return _serialize(ctx, category) diff --git a/server/szurubooru/api/user_api.py b/server/szurubooru/api/user_api.py index 706e5d74..94b82f3f 100644 --- a/server/szurubooru/api/user_api.py +++ b/server/szurubooru/api/user_api.py @@ -1,6 +1,8 @@ from szurubooru import search -from szurubooru.api.base_api import BaseApi from szurubooru.func import auth, users, util +from szurubooru.rest import routes + +_search_executor = search.Executor(search.configs.UserSearchConfig()) def _serialize(ctx, user, **kwargs): return users.serialize_user( @@ -9,75 +11,73 @@ def _serialize(ctx, user, **kwargs): options=util.get_serialization_options(ctx), **kwargs) -class UserListApi(BaseApi): - def __init__(self): - super().__init__() - self._search_executor = search.Executor( - search.configs.UserSearchConfig()) +@routes.get('/users/?') +def get_users(ctx, _params=None): + auth.verify_privilege(ctx.user, 'users:list') + return _search_executor.execute_and_serialize( + ctx, lambda user: _serialize(ctx, user)) - def get(self, ctx): - auth.verify_privilege(ctx.user, 'users:list') - return self._search_executor.execute_and_serialize( - ctx, lambda user: _serialize(ctx, user)) +@routes.post('/users/?') +def create_user(ctx, _params=None): + auth.verify_privilege(ctx.user, 'users:create') + name = ctx.get_param_as_string('name', required=True) + password = ctx.get_param_as_string('password', required=True) + email = ctx.get_param_as_string('email', required=False, default='') + user = users.create_user(name, password, email) + if ctx.has_param('rank'): + users.update_user_rank( + user, ctx.get_param_as_string('rank'), ctx.user) + if ctx.has_param('avatarStyle'): + users.update_user_avatar( + user, + ctx.get_param_as_string('avatarStyle'), + ctx.get_file('avatar')) + ctx.session.add(user) + ctx.session.commit() + return _serialize(ctx, user, force_show_email=True) - def post(self, ctx): - auth.verify_privilege(ctx.user, 'users:create') - name = ctx.get_param_as_string('name', required=True) - password = ctx.get_param_as_string('password', required=True) - email = ctx.get_param_as_string('email', required=False, default='') - user = users.create_user(name, password, email) - if ctx.has_param('rank'): - users.update_user_rank( - user, ctx.get_param_as_string('rank'), ctx.user) - if ctx.has_param('avatarStyle'): - users.update_user_avatar( - user, - ctx.get_param_as_string('avatarStyle'), - ctx.get_file('avatar')) - ctx.session.add(user) - ctx.session.commit() - return _serialize(ctx, user, force_show_email=True) +@routes.get('/user/(?P[^/]+)/?') +def get_user(ctx, params): + user = users.get_user_by_name(params['user_name']) + if ctx.user.user_id != user.user_id: + auth.verify_privilege(ctx.user, 'users:view') + return _serialize(ctx, user) -class UserDetailApi(BaseApi): - def get(self, ctx, user_name): - user = users.get_user_by_name(user_name) - if ctx.user.user_id != user.user_id: - auth.verify_privilege(ctx.user, 'users:view') - return _serialize(ctx, user) +@routes.put('/user/(?P[^/]+)/?') +def update_user(ctx, params): + user = users.get_user_by_name(params['user_name']) + util.verify_version(user, ctx) + infix = 'self' if ctx.user.user_id == user.user_id else 'any' + if ctx.has_param('name'): + auth.verify_privilege(ctx.user, 'users:edit:%s:name' % infix) + users.update_user_name(user, ctx.get_param_as_string('name')) + if ctx.has_param('password'): + auth.verify_privilege(ctx.user, 'users:edit:%s:pass' % infix) + users.update_user_password( + user, ctx.get_param_as_string('password')) + if ctx.has_param('email'): + auth.verify_privilege(ctx.user, 'users:edit:%s:email' % infix) + users.update_user_email(user, ctx.get_param_as_string('email')) + if ctx.has_param('rank'): + auth.verify_privilege(ctx.user, 'users:edit:%s:rank' % infix) + users.update_user_rank( + user, ctx.get_param_as_string('rank'), ctx.user) + if ctx.has_param('avatarStyle'): + auth.verify_privilege(ctx.user, 'users:edit:%s:avatar' % infix) + users.update_user_avatar( + user, + ctx.get_param_as_string('avatarStyle'), + ctx.get_file('avatar')) + util.bump_version(user) + ctx.session.commit() + return _serialize(ctx, user) - def put(self, ctx, user_name): - user = users.get_user_by_name(user_name) - util.verify_version(user, ctx) - infix = 'self' if ctx.user.user_id == user.user_id else 'any' - if ctx.has_param('name'): - auth.verify_privilege(ctx.user, 'users:edit:%s:name' % infix) - users.update_user_name(user, ctx.get_param_as_string('name')) - if ctx.has_param('password'): - auth.verify_privilege(ctx.user, 'users:edit:%s:pass' % infix) - users.update_user_password( - user, ctx.get_param_as_string('password')) - if ctx.has_param('email'): - auth.verify_privilege(ctx.user, 'users:edit:%s:email' % infix) - users.update_user_email(user, ctx.get_param_as_string('email')) - if ctx.has_param('rank'): - auth.verify_privilege(ctx.user, 'users:edit:%s:rank' % infix) - users.update_user_rank( - user, ctx.get_param_as_string('rank'), ctx.user) - if ctx.has_param('avatarStyle'): - auth.verify_privilege(ctx.user, 'users:edit:%s:avatar' % infix) - users.update_user_avatar( - user, - ctx.get_param_as_string('avatarStyle'), - ctx.get_file('avatar')) - util.bump_version(user) - ctx.session.commit() - return _serialize(ctx, user) - - def delete(self, ctx, user_name): - user = users.get_user_by_name(user_name) - util.verify_version(user, ctx) - infix = 'self' if ctx.user.user_id == user.user_id else 'any' - auth.verify_privilege(ctx.user, 'users:delete:%s' % infix) - ctx.session.delete(user) - ctx.session.commit() - return {} +@routes.delete('/user/(?P[^/]+)/?') +def delete_user(ctx, params): + user = users.get_user_by_name(params['user_name']) + util.verify_version(user, ctx) + infix = 'self' if ctx.user.user_id == user.user_id else 'any' + auth.verify_privilege(ctx.user, 'users:delete:%s' % infix) + ctx.session.delete(user) + ctx.session.commit() + return {} diff --git a/server/szurubooru/app.py b/server/szurubooru/app.py deleted file mode 100644 index 3d103796..00000000 --- a/server/szurubooru/app.py +++ /dev/null @@ -1,124 +0,0 @@ -''' Exports create_app. ''' - -import os -import logging -import coloredlogs -import falcon -from szurubooru import api, config, errors, middleware - -def _on_auth_error(ex, _request, _response, _params): - raise falcon.HTTPForbidden( - title='Authentication error', description=str(ex)) - -def _on_validation_error(ex, _request, _response, _params): - raise falcon.HTTPBadRequest(title='Validation error', description=str(ex)) - -def _on_search_error(ex, _request, _response, _params): - raise falcon.HTTPBadRequest(title='Search error', description=str(ex)) - -def _on_integrity_error(ex, _request, _response, _params): - raise falcon.HTTPConflict( - title='Integrity violation', description=ex.args[0]) - -def _on_not_found_error(ex, _request, _response, _params): - raise falcon.HTTPNotFound(title='Not found', description=str(ex)) - -def _on_processing_error(ex, _request, _response, _params): - raise falcon.HTTPBadRequest(title='Processing error', description=str(ex)) - -def create_method_not_allowed(allowed_methods): - allowed = ', '.join(allowed_methods) - def method_not_allowed(request, response, **_kwargs): - response.status = falcon.status_codes.HTTP_405 - response.set_header('Allow', allowed) - request.context.output = { - 'title': 'Method not allowed', - 'description': 'Allowed methods: %r' % allowed_methods, - } - return method_not_allowed - -def validate_config(): - ''' - Check whether config doesn't contain errors that might prove - lethal at runtime. - ''' - from szurubooru.func.auth import RANK_MAP - for privilege, rank in config.config['privileges'].items(): - if rank not in RANK_MAP.values(): - raise errors.ConfigError( - 'Rank %r for privilege %r is missing' % (rank, privilege)) - if config.config['default_rank'] not in RANK_MAP.values(): - raise errors.ConfigError( - 'Default rank %r is not on the list of known ranks' % ( - config.config['default_rank'])) - - for key in ['base_url', 'api_url', 'data_url', 'data_dir']: - if not config.config[key]: - raise errors.ConfigError( - 'Service is not configured: %r is missing' % key) - - if not os.path.isabs(config.config['data_dir']): - raise errors.ConfigError( - 'data_dir must be an absolute path') - - for key in ['schema', 'host', 'port', 'user', 'pass', 'name']: - if not config.config['database'][key]: - raise errors.ConfigError( - 'Database is not configured: %r is missing' % key) - -def create_app(): - ''' Create a WSGI compatible App object. ''' - validate_config() - falcon.responders.create_method_not_allowed = create_method_not_allowed - - coloredlogs.install(fmt='[%(asctime)-15s] %(name)s %(message)s') - if config.config['debug']: - logging.getLogger('szurubooru').setLevel(logging.INFO) - if config.config['show_sql']: - logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) - - app = falcon.API( - request_type=api.Request, - middleware=[ - middleware.RequireJson(), - middleware.CachePurger(), - middleware.ContextAdapter(), - middleware.DbSession(), - middleware.Authenticator(), - middleware.RequestLogger(), - ]) - - app.add_error_handler(errors.AuthError, _on_auth_error) - app.add_error_handler(errors.IntegrityError, _on_integrity_error) - app.add_error_handler(errors.ValidationError, _on_validation_error) - app.add_error_handler(errors.SearchError, _on_search_error) - app.add_error_handler(errors.NotFoundError, _on_not_found_error) - app.add_error_handler(errors.ProcessingError, _on_processing_error) - - app.add_route('/users/', api.UserListApi()) - app.add_route('/user/{user_name}', api.UserDetailApi()) - app.add_route('/password-reset/{user_name}', api.PasswordResetApi()) - - app.add_route('/tag-categories/', api.TagCategoryListApi()) - app.add_route('/tag-category/{category_name}', api.TagCategoryDetailApi()) - app.add_route('/tag-category/{category_name}/default', api.DefaultTagCategoryApi()) - app.add_route('/tags/', api.TagListApi()) - app.add_route('/tag/{tag_name}', api.TagDetailApi()) - app.add_route('/tag-merge/', api.TagMergeApi()) - app.add_route('/tag-siblings/{tag_name}', api.TagSiblingsApi()) - - app.add_route('/posts/', api.PostListApi()) - app.add_route('/post/{post_id}', api.PostDetailApi()) - app.add_route('/post/{post_id}/score', api.PostScoreApi()) - app.add_route('/post/{post_id}/favorite', api.PostFavoriteApi()) - app.add_route('/post/{post_id}/around', api.PostsAroundApi()) - - app.add_route('/comments/', api.CommentListApi()) - app.add_route('/comment/{comment_id}', api.CommentDetailApi()) - app.add_route('/comment/{comment_id}/score', api.CommentScoreApi()) - - app.add_route('/info/', api.InfoApi()) - app.add_route('/featured-post/', api.PostFeatureApi()) - app.add_route('/snapshots/', api.SnapshotListApi()) - - return app diff --git a/server/szurubooru/facade.py b/server/szurubooru/facade.py new file mode 100644 index 00000000..50652c45 --- /dev/null +++ b/server/szurubooru/facade.py @@ -0,0 +1,79 @@ +''' Exports create_app. ''' + +import os +import logging +import coloredlogs +from szurubooru import config, errors, rest +# pylint: disable=unused-import +from szurubooru import api, middleware + +def _on_auth_error(ex): + raise rest.errors.HttpForbidden( + title='Authentication error', description=str(ex)) + +def _on_validation_error(ex): + raise rest.errors.HttpBadRequest( + title='Validation error', description=str(ex)) + +def _on_search_error(ex): + raise rest.errors.HttpBadRequest( + title='Search error', description=str(ex)) + +def _on_integrity_error(ex): + raise rest.errors.HttpConflict( + title='Integrity violation', description=ex.args[0]) + +def _on_not_found_error(ex): + raise rest.errors.HttpNotFound( + title='Not found', description=str(ex)) + +def _on_processing_error(ex): + raise rest.errors.HttpBadRequest( + title='Processing error', description=str(ex)) + +def validate_config(): + ''' + Check whether config doesn't contain errors that might prove + lethal at runtime. + ''' + from szurubooru.func.auth import RANK_MAP + for privilege, rank in config.config['privileges'].items(): + if rank not in RANK_MAP.values(): + raise errors.ConfigError( + 'Rank %r for privilege %r is missing' % (rank, privilege)) + if config.config['default_rank'] not in RANK_MAP.values(): + raise errors.ConfigError( + 'Default rank %r is not on the list of known ranks' % ( + config.config['default_rank'])) + + for key in ['base_url', 'api_url', 'data_url', 'data_dir']: + if not config.config[key]: + raise errors.ConfigError( + 'Service is not configured: %r is missing' % key) + + if not os.path.isabs(config.config['data_dir']): + raise errors.ConfigError( + 'data_dir must be an absolute path') + + for key in ['schema', 'host', 'port', 'user', 'pass', 'name']: + if not config.config['database'][key]: + raise errors.ConfigError( + 'Database is not configured: %r is missing' % key) + +def create_app(): + ''' Create a WSGI compatible App object. ''' + validate_config() + coloredlogs.install(fmt='[%(asctime)-15s] %(name)s %(message)s') + if config.config['debug']: + logging.getLogger('szurubooru').setLevel(logging.INFO) + if config.config['show_sql']: + logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) + + rest.errors.handle(errors.AuthError, _on_auth_error) + rest.errors.handle(errors.ValidationError, _on_validation_error) + rest.errors.handle(errors.SearchError, _on_search_error) + rest.errors.handle(errors.IntegrityError, _on_integrity_error) + rest.errors.handle(errors.NotFoundError, _on_not_found_error) + rest.errors.handle(errors.ProcessingError, _on_processing_error) + + return rest.application diff --git a/server/szurubooru/func/snapshots.py b/server/szurubooru/func/snapshots.py index fdfa2b34..29da205b 100644 --- a/server/szurubooru/func/snapshots.py +++ b/server/szurubooru/func/snapshots.py @@ -32,13 +32,6 @@ def get_tag_category_snapshot(category): 'default': True if category.default else False, } -# pylint: disable=invalid-name -serializers = { - 'tag': get_tag_snapshot, - 'tag_category': get_tag_category_snapshot, - 'post': get_post_snapshot, -} - def get_previous_snapshot(snapshot): assert snapshot return db.session \ @@ -87,6 +80,12 @@ def get_serialized_history(entity): def _save(operation, entity, auth_user): assert operation assert entity + serializers = { + 'tag': get_tag_snapshot, + 'tag_category': get_tag_category_snapshot, + 'post': get_post_snapshot, + } + resource_type, resource_id, resource_repr = db.util.get_resource_info(entity) now = datetime.datetime.utcnow() diff --git a/server/szurubooru/func/util.py b/server/szurubooru/func/util.py index e0fc3ce2..a4bd9177 100644 --- a/server/szurubooru/func/util.py +++ b/server/szurubooru/func/util.py @@ -11,6 +11,9 @@ def snake_case_to_lower_camel_case(text): return components[0].lower() + \ ''.join(word[0].upper() + word[1:].lower() for word in components[1:]) +def snake_case_to_upper_train_case(text): + return '-'.join(word[0].upper() + word[1:].lower() for word in text.split('_')) + def snake_case_to_lower_camel_case_keys(source): target = {} for key, value in source.items(): diff --git a/server/szurubooru/middleware/__init__.py b/server/szurubooru/middleware/__init__.py index 1ffc2840..010af680 100644 --- a/server/szurubooru/middleware/__init__.py +++ b/server/szurubooru/middleware/__init__.py @@ -1,8 +1,6 @@ ''' Various hooks that get executed for each request. ''' -from szurubooru.middleware.authenticator import Authenticator -from szurubooru.middleware.context_adapter import ContextAdapter -from szurubooru.middleware.require_json import RequireJson -from szurubooru.middleware.db_session import DbSession -from szurubooru.middleware.cache_purger import CachePurger -from szurubooru.middleware.request_logger import RequestLogger +import szurubooru.middleware.db_session +import szurubooru.middleware.authenticator +import szurubooru.middleware.cache_purger +import szurubooru.middleware.request_logger diff --git a/server/szurubooru/middleware/authenticator.py b/server/szurubooru/middleware/authenticator.py index 02d48c18..b483494d 100644 --- a/server/szurubooru/middleware/authenticator.py +++ b/server/szurubooru/middleware/authenticator.py @@ -1,51 +1,44 @@ import base64 -import falcon from szurubooru import db, errors from szurubooru.func import auth, users +from szurubooru.rest import middleware +from szurubooru.rest.errors import HttpBadRequest -class Authenticator(object): - ''' - Authenticates every request and put information on active user in the - request context. - ''' +def _authenticate(username, password): + ''' Try to authenticate user. Throw AuthError for invalid users. ''' + user = users.get_user_by_name(username) + if not auth.is_valid_password(user, password): + raise errors.AuthError('Invalid password.') + return user - def process_request(self, request, _response): - ''' Bind the user to request. Update last login time if needed. ''' - request.context.user = self._get_user(request) - if request.get_param_as_bool('bump-login') \ - and request.context.user.user_id: - users.bump_user_login_time(request.context.user) - request.context.session.commit() +def _create_anonymous_user(): + user = db.User() + user.name = None + user.rank = 'anonymous' + return user - def _get_user(self, request): - if not request.auth: - return self._create_anonymous_user() +def _get_user(ctx): + if not ctx.has_header('Authorization'): + return _create_anonymous_user() - try: - auth_type, user_and_password = request.auth.split(' ', 1) - if auth_type.lower() != 'basic': - raise falcon.HTTPBadRequest( - 'Invalid authentication type', - 'Only basic authorization is supported.') - username, password = base64.decodebytes( - user_and_password.encode('ascii')).decode('utf8').split(':') - return self._authenticate(username, password) - except ValueError as err: - msg = 'Basic authentication header value not properly formed. ' \ - + 'Supplied header {0}. Got error: {1}' - raise falcon.HTTPBadRequest( - 'Malformed authentication request', - msg.format(request.auth, str(err))) + try: + auth_type, user_and_password = ctx.get_header('Authorization').split(' ', 1) + if auth_type.lower() != 'basic': + raise HttpBadRequest( + 'Only basic HTTP authentication is supported.') + username, password = base64.decodebytes( + user_and_password.encode('ascii')).decode('utf8').split(':') + return _authenticate(username, password) + except ValueError as err: + msg = 'Basic authentication header value are not properly formed. ' \ + + 'Supplied header {0}. Got error: {1}' + raise HttpBadRequest( + msg.format(ctx.get_header('Authorization'), str(err))) - def _authenticate(self, username, password): - ''' Try to authenticate user. Throw AuthError for invalid users. ''' - user = users.get_user_by_name(username) - if not auth.is_valid_password(user, password): - raise errors.AuthError('Invalid password.') - return user - - def _create_anonymous_user(self): - user = db.User() - user.name = None - user.rank = 'anonymous' - return user +@middleware.pre_hook +def process_request(ctx): + ''' Bind the user to request. Update last login time if needed. ''' + ctx.user = _get_user(ctx) + if ctx.get_param_as_bool('bump-login') and ctx.user.user_id: + users.bump_user_login_time(ctx.user) + ctx.session.commit() diff --git a/server/szurubooru/middleware/cache_purger.py b/server/szurubooru/middleware/cache_purger.py index b8daa880..bc9bd9bc 100644 --- a/server/szurubooru/middleware/cache_purger.py +++ b/server/szurubooru/middleware/cache_purger.py @@ -1,6 +1,7 @@ from szurubooru.func import cache +from szurubooru.rest import middleware -class CachePurger(object): - def process_request(self, request, _response): - if request.method != 'GET': - cache.purge() +@middleware.pre_hook +def process_request(ctx): + if ctx.method != 'GET': + cache.purge() diff --git a/server/szurubooru/middleware/context_adapter.py b/server/szurubooru/middleware/context_adapter.py deleted file mode 100644 index 2a1816a4..00000000 --- a/server/szurubooru/middleware/context_adapter.py +++ /dev/null @@ -1,65 +0,0 @@ -import cgi -import datetime -import json -import falcon - -def json_serializer(obj): - ''' JSON serializer for objects not serializable by default JSON code ''' - if isinstance(obj, datetime.datetime): - serial = obj.isoformat('T') + 'Z' - return serial - raise TypeError('Type not serializable') - -class ContextAdapter(object): - ''' - 1. Deserialize API requests into the context: - - Pass GET parameters - - Handle multipart/form-data file uploads - - Handle JSON requests - 2. Serialize API responses from the context as JSON. - ''' - def process_request(self, request, _response): - request.context.files = {} - request.context.input = {} - request.context.output = None - # pylint: disable=protected-access - for key, value in request._params.items(): - request.context.input[key] = value - - if request.content_length in (None, 0): - return - - if request.content_type and 'multipart/form-data' in request.content_type: - # obscure, claims to "avoid a bug in cgi.FieldStorage" - request.env.setdefault('QUERY_STRING', '') - - form = cgi.FieldStorage(fp=request.stream, environ=request.env) - for key in form: - if key != 'metadata': - _original_file_name = getattr(form[key], 'filename', None) - request.context.files[key] = form.getvalue(key) - body = form.getvalue('metadata') - else: - body = request.stream.read() - - if not body: - raise falcon.HTTPBadRequest( - 'Empty request body', - 'A valid JSON document is required.') - - try: - if isinstance(body, bytes): - body = body.decode('utf-8') - - for key, value in json.loads(body).items(): - request.context.input[key] = value - except (ValueError, UnicodeDecodeError): - raise falcon.HTTPBadRequest( - 'Malformed JSON', - 'Could not decode the request body. The ' - 'JSON was incorrect or not encoded as UTF-8.') - - def process_response(self, request, response, _resource): - if request.context.output: - response.body = json.dumps( - request.context.output, default=json_serializer, indent=2) diff --git a/server/szurubooru/middleware/db_session.py b/server/szurubooru/middleware/db_session.py index d2313def..8d2c1c9c 100644 --- a/server/szurubooru/middleware/db_session.py +++ b/server/szurubooru/middleware/db_session.py @@ -1,14 +1,11 @@ -import logging from szurubooru import db +from szurubooru.rest import middleware -logger = logging.getLogger(__name__) +@middleware.pre_hook +def _process_request(ctx): + ctx.session = db.session() + db.reset_query_count() -class DbSession(object): - ''' Attaches database session to the context of every request. ''' - - def process_request(self, request, _response): - request.context.session = db.session() - db.reset_query_count() - - def process_response(self, _request, _response, _resource): - db.session.remove() +@middleware.post_hook +def _process_response(_ctx): + db.session.remove() diff --git a/server/szurubooru/middleware/request_logger.py b/server/szurubooru/middleware/request_logger.py index fd616467..638df16b 100644 --- a/server/szurubooru/middleware/request_logger.py +++ b/server/szurubooru/middleware/request_logger.py @@ -1,16 +1,14 @@ import logging from szurubooru import db +from szurubooru.rest import middleware logger = logging.getLogger(__name__) -class RequestLogger(object): - def process_request(self, request, _response): - pass - - def process_response(self, request, _response, _resource): - logger.info( - '%s %s (user=%s, queries=%d)', - request.method, - request.url, - request.context.user.name, - db.get_query_count()) +@middleware.post_hook +def process_response(ctx): + logger.info( + '%s %s (user=%s, queries=%d)', + ctx.method, + ctx.url, + ctx.user.name, + db.get_query_count()) diff --git a/server/szurubooru/middleware/require_json.py b/server/szurubooru/middleware/require_json.py deleted file mode 100644 index a965a7bd..00000000 --- a/server/szurubooru/middleware/require_json.py +++ /dev/null @@ -1,9 +0,0 @@ -import falcon - -class RequireJson(object): - ''' Sanitizes requests so that only JSON is accepted. ''' - - def process_request(self, request, _response): - if not request.client_accepts_json: - raise falcon.HTTPNotAcceptable( - 'This API only supports responses encoded as JSON.') diff --git a/server/szurubooru/rest/__init__.py b/server/szurubooru/rest/__init__.py new file mode 100644 index 00000000..ac9958a5 --- /dev/null +++ b/server/szurubooru/rest/__init__.py @@ -0,0 +1,2 @@ +from szurubooru.rest.app import application +from szurubooru.rest.context import Context diff --git a/server/szurubooru/rest/app.py b/server/szurubooru/rest/app.py new file mode 100644 index 00000000..5a5b21b2 --- /dev/null +++ b/server/szurubooru/rest/app.py @@ -0,0 +1,124 @@ +import cgi +import io +import json +import re +from datetime import datetime +from szurubooru.func import util +from szurubooru.rest import errors, middleware, routes, context + +def _json_serializer(obj): + ''' JSON serializer for objects not serializable by default JSON code ''' + if isinstance(obj, datetime): + serial = obj.isoformat('T') + 'Z' + return serial + raise TypeError('Type not serializable') + +def _dump_json(obj): + return json.dumps(obj, default=_json_serializer, indent=2) + +def _read(env): + length = int(env.get('CONTENT_LENGTH', 0)) + output = io.BytesIO() + while length > 0: + part = env['wsgi.input'].read(min(length, 1024*200)) + if not part: + break + output.write(part) + length -= len(part) + output.seek(0) + return output + +def _get_headers(env): + headers = {} + for key, value in env.items(): + if key.startswith('HTTP_'): + key = util.snake_case_to_upper_train_case(key[5:]) + headers[key] = value + return headers + +def _create_context(env): + method = env['REQUEST_METHOD'] + path = '/' + env['PATH_INFO'].lstrip('/') + headers = _get_headers(env) + + # obscure, claims to "avoid a bug in cgi.FieldStorage" + env.setdefault('QUERY_STRING', '') + + files = {} + params = {} + + request_stream = _read(env) + form = cgi.FieldStorage(fp=request_stream, environ=env) + + if form.list: + for key in form: + if key != 'metadata': + if isinstance(form[key], cgi.MiniFieldStorage): + params[key] = form.getvalue(key) + else: + _original_file_name = getattr(form[key], 'filename', None) + files[key] = form.getvalue(key) + if 'metadata' in form: + body = form.getvalue('metadata') + else: + body = request_stream.read() + else: + body = None + + if body: + try: + if isinstance(body, bytes): + body = body.decode('utf-8') + + for key, value in json.loads(body).items(): + params[key] = value + except (ValueError, UnicodeDecodeError): + raise errors.HttpBadRequest( + 'Could not decode the request body. The JSON ' + 'was incorrect or was not encoded as UTF-8.') + + return context.Context(method, path, headers, params, files) + +def application(env, start_response): + try: + ctx = _create_context(env) + if not 'application/json' in ctx.get_header('Accept'): + raise errors.HttpNotAcceptable( + 'This API only supports JSON responses.') + + for url, allowed_methods in routes.routes.items(): + match = re.fullmatch(url, ctx.url) + if not match: + continue + if ctx.method not in allowed_methods: + raise errors.HttpMethodNotAllowed( + 'Allowed methods: %r' % allowed_methods) + + for hook in middleware.pre_hooks: + hook(ctx) + handler = allowed_methods[ctx.method] + try: + response = handler(ctx, match.groupdict()) + except Exception as ex: + for exception_type, handler in errors.error_handlers.items(): + if isinstance(ex, exception_type): + handler(ex) + raise + finally: + for hook in middleware.post_hooks: + hook(ctx) + + start_response('200', [('content-type', 'application/json')]) + return (_dump_json(response).encode('utf-8'),) + + raise errors.HttpNotFound( + 'Requested path ' + ctx.url + ' was not found.') + + except errors.BaseHttpError as ex: + start_response( + '%d %s' % (ex.code, ex.reason), + [('content-type', 'application/json')]) + return (_dump_json({ + 'title': ex.title, + 'description': ex.description, + }).encode('utf-8'),) diff --git a/server/szurubooru/api/context.py b/server/szurubooru/rest/context.py similarity index 70% rename from server/szurubooru/api/context.py rename to server/szurubooru/rest/context.py index a21cce02..a6454a4e 100644 --- a/server/szurubooru/api/context.py +++ b/server/szurubooru/rest/context.py @@ -1,4 +1,3 @@ -import falcon from szurubooru import errors from szurubooru.func import net @@ -7,8 +6,9 @@ def _lower_first(source): def _param_wrapper(func): def wrapper(self, name, required=False, default=None, **kwargs): - if name in self.input: - value = self.input[name] + # pylint: disable=protected-access + if name in self._params: + value = self._params[name] try: value = func(self, value, **kwargs) except errors.InvalidParameterError as ex: @@ -22,34 +22,46 @@ def _param_wrapper(func): 'Required parameter %r is missing.' % name) return wrapper -class Context(object): - def __init__(self): - self.session = None - self.user = None - self.files = {} - self.input = {} - self.output = None - self.settings = {} +class Context(): + # pylint: disable=too-many-arguments + def __init__(self, method, url, headers=None, params=None, files=None): + self.method = method + self.url = url + self._headers = headers or {} + self._params = params or {} + self._files = files or {} - def has_param(self, name): - return name in self.input + # provided by middleware + # self.session = None + # self.user = None + + def has_header(self, name): + return name in self._headers + + def get_header(self, name): + return self._headers.get(name, None) def has_file(self, name): - return name in self.files or name + 'Url' in self.input + return name in self._files or name + 'Url' in self._params def get_file(self, name, required=False): - if name in self.files: - return self.files[name] - if name + 'Url' in self.input: - return net.download(self.input[name + 'Url']) + if name in self._files: + return self._files[name] + if name + 'Url' in self._params: + return net.download(self._params[name + 'Url']) if not required: return None raise errors.MissingRequiredFileError( 'Required file %r is missing.' % name) + def has_param(self, name): + return name in self._params + @_param_wrapper def get_param_as_list(self, value): if not isinstance(value, list): + if ',' in value: + return value.split(',') return [value] return value @@ -86,6 +98,3 @@ class Context(object): if value in ['0', 'n', 'no', 'nope', 'f', 'false']: return False raise errors.InvalidParameterError('The value must be a boolean value.') - -class Request(falcon.Request): - context_type = Context diff --git a/server/szurubooru/rest/errors.py b/server/szurubooru/rest/errors.py new file mode 100644 index 00000000..9ada0235 --- /dev/null +++ b/server/szurubooru/rest/errors.py @@ -0,0 +1,37 @@ +error_handlers = {} # pylint: disable=invalid-name + +class BaseHttpError(RuntimeError): + code = None + reason = None + + def __init__(self, description, title=None): + super().__init__() + self.description = description + self.title = title or self.reason + +class HttpBadRequest(BaseHttpError): + code = 400 + reason = 'Bad Request' + +class HttpForbidden(BaseHttpError): + code = 403 + reason = 'Forbidden' + +class HttpNotFound(BaseHttpError): + code = 404 + reason = 'Not Found' + +class HttpNotAcceptable(BaseHttpError): + code = 406 + reason = 'Not Acceptable' + +class HttpConflict(BaseHttpError): + code = 409 + reason = 'Conflict' + +class HttpMethodNotAllowed(BaseHttpError): + code = 405 + reason = 'Method Not Allowed' + +def handle(exception_type, handler): + error_handlers[exception_type] = handler diff --git a/server/szurubooru/rest/middleware.py b/server/szurubooru/rest/middleware.py new file mode 100644 index 00000000..e569d692 --- /dev/null +++ b/server/szurubooru/rest/middleware.py @@ -0,0 +1,9 @@ +# pylint: disable=invalid-name +pre_hooks = [] +post_hooks = [] + +def pre_hook(handler): + pre_hooks.append(handler) + +def post_hook(handler): + post_hooks.insert(0, handler) diff --git a/server/szurubooru/rest/routes.py b/server/szurubooru/rest/routes.py new file mode 100644 index 00000000..f5567219 --- /dev/null +++ b/server/szurubooru/rest/routes.py @@ -0,0 +1,27 @@ +from collections import defaultdict + +routes = defaultdict(dict) # pylint: disable=invalid-name + +def get(url): + def wrapper(handler): + routes[url]['GET'] = handler + return handler + return wrapper + +def put(url): + def wrapper(handler): + routes[url]['PUT'] = handler + return handler + return wrapper + +def post(url): + def wrapper(handler): + routes[url]['POST'] = handler + return handler + return wrapper + +def delete(url): + def wrapper(handler): + routes[url]['DELETE'] = handler + return handler + return wrapper diff --git a/server/szurubooru/tests/api/test_comment_creating.py b/server/szurubooru/tests/api/test_comment_creating.py index 00a9de60..68614d75 100644 --- a/server/szurubooru/tests/api/test_comment_creating.py +++ b/server/szurubooru/tests/api/test_comment_creating.py @@ -1,89 +1,78 @@ -import datetime import pytest +import unittest.mock +from datetime import datetime from szurubooru import api, db, errors -from szurubooru.func import util, posts +from szurubooru.func import comments, posts -@pytest.fixture -def test_ctx( - tmpdir, config_injector, context_factory, post_factory, user_factory): - config_injector({ - 'data_dir': str(tmpdir), - 'data_url': 'http://example.com', - 'privileges': {'comments:create': db.User.RANK_REGULAR}, - 'thumbnails': {'avatar_width': 200}, - }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.post_factory = post_factory - ret.user_factory = user_factory - ret.api = api.CommentListApi() - return ret +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({'privileges': {'comments:create': db.User.RANK_REGULAR}}) -def test_creating_comment(test_ctx, fake_datetime): - post = test_ctx.post_factory() - user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) +def test_creating_comment( + user_factory, post_factory, context_factory, fake_datetime): + post = post_factory() + user = user_factory(rank=db.User.RANK_REGULAR) db.session.add_all([post, user]) db.session.flush() - with fake_datetime('1997-01-01'): - result = test_ctx.api.post( - test_ctx.context_factory( - input={'text': 'input', 'postId': post.post_id}, + with unittest.mock.patch('szurubooru.func.comments.serialize_comment'), \ + fake_datetime('1997-01-01'): + comments.serialize_comment.return_value = 'serialized comment' + result = api.comment_api.create_comment( + context_factory( + params={'text': 'input', 'postId': post.post_id}, user=user)) - assert result['text'] == 'input' - assert 'id' in result - assert 'user' in result - assert 'name' in result['user'] - assert 'postId' in result - comment = db.session.query(db.Comment).one() - assert comment.text == 'input' - assert comment.creation_time == datetime.datetime(1997, 1, 1) - assert comment.last_edit_time is None - assert comment.user and comment.user.user_id == user.user_id - assert comment.post and comment.post.post_id == post.post_id + assert result == 'serialized comment' + comment = db.session.query(db.Comment).one() + assert comment.text == 'input' + assert comment.creation_time == datetime(1997, 1, 1) + assert comment.last_edit_time is None + assert comment.user and comment.user.user_id == user.user_id + assert comment.post and comment.post.post_id == post.post_id -@pytest.mark.parametrize('input', [ +@pytest.mark.parametrize('params', [ {'text': None}, {'text': ''}, {'text': [None]}, {'text': ['']}, ]) -def test_trying_to_pass_invalid_input(test_ctx, input): - post = test_ctx.post_factory() - user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) +def test_trying_to_pass_invalid_params( + user_factory, post_factory, context_factory, params): + post = post_factory() + user = user_factory(rank=db.User.RANK_REGULAR) db.session.add_all([post, user]) db.session.flush() - real_input = {'text': 'input', 'postId': post.post_id} - for key, value in input.items(): - real_input[key] = value + real_params = {'text': 'input', 'postId': post.post_id} + for key, value in params.items(): + real_params[key] = value with pytest.raises(errors.ValidationError): - test_ctx.api.post( - test_ctx.context_factory(input=real_input, user=user)) + api.comment_api.create_comment( + context_factory(params=real_params, user=user)) @pytest.mark.parametrize('field', ['text', 'postId']) -def test_trying_to_omit_mandatory_field(test_ctx, field): - input = { +def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): + params = { 'text': 'input', 'postId': 1, } - del input[field] + del params[field] with pytest.raises(errors.ValidationError): - test_ctx.api.post( - test_ctx.context_factory( - input={}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) + api.comment_api.create_comment( + context_factory( + params={}, + user=user_factory(rank=db.User.RANK_REGULAR))) -def test_trying_to_comment_non_existing(test_ctx): - user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) +def test_trying_to_comment_non_existing(user_factory, context_factory): + user = user_factory(rank=db.User.RANK_REGULAR) db.session.add_all([user]) db.session.flush() with pytest.raises(posts.PostNotFoundError): - test_ctx.api.post( - test_ctx.context_factory( - input={'text': 'bad', 'postId': 5}, user=user)) + api.comment_api.create_comment( + context_factory( + params={'text': 'bad', 'postId': 5}, user=user)) -def test_trying_to_create_without_privileges(test_ctx): +def test_trying_to_create_without_privileges(user_factory, context_factory): with pytest.raises(errors.AuthError): - test_ctx.api.post( - test_ctx.context_factory( - input={}, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) + api.comment_api.create_comment( + context_factory( + params={}, + user=user_factory(rank=db.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_comment_deleting.py b/server/szurubooru/tests/api/test_comment_deleting.py index ecc200a6..99b875f9 100644 --- a/server/szurubooru/tests/api/test_comment_deleting.py +++ b/server/szurubooru/tests/api/test_comment_deleting.py @@ -1,61 +1,56 @@ import pytest -from datetime import datetime from szurubooru import api, db, errors -from szurubooru.func import util, comments +from szurubooru.func import comments -@pytest.fixture -def test_ctx(config_injector, context_factory, user_factory, comment_factory): +@pytest.fixture(autouse=True) +def inject_config(config_injector): config_injector({ 'privileges': { 'comments:delete:own': db.User.RANK_REGULAR, 'comments:delete:any': db.User.RANK_MODERATOR, }, }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.comment_factory = comment_factory - ret.api = api.CommentDetailApi() - return ret -def test_deleting_own_comment(test_ctx): - user = test_ctx.user_factory() - comment = test_ctx.comment_factory(user=user) +def test_deleting_own_comment(user_factory, comment_factory, context_factory): + user = user_factory() + comment = comment_factory(user=user) db.session.add(comment) db.session.commit() - result = test_ctx.api.delete( - test_ctx.context_factory(input={'version': 1}, user=user), - comment.comment_id) + result = api.comment_api.delete_comment( + context_factory(params={'version': 1}, user=user), + {'comment_id': comment.comment_id}) assert result == {} assert db.session.query(db.Comment).count() == 0 -def test_deleting_someones_else_comment(test_ctx): - user1 = test_ctx.user_factory(rank=db.User.RANK_REGULAR) - user2 = test_ctx.user_factory(rank=db.User.RANK_MODERATOR) - comment = test_ctx.comment_factory(user=user1) +def test_deleting_someones_else_comment( + user_factory, comment_factory, context_factory): + user1 = user_factory(rank=db.User.RANK_REGULAR) + user2 = user_factory(rank=db.User.RANK_MODERATOR) + comment = comment_factory(user=user1) db.session.add(comment) db.session.commit() - result = test_ctx.api.delete( - test_ctx.context_factory(input={'version': 1}, user=user2), - comment.comment_id) + result = api.comment_api.delete_comment( + context_factory(params={'version': 1}, user=user2), + {'comment_id': comment.comment_id}) assert db.session.query(db.Comment).count() == 0 -def test_trying_to_delete_someones_else_comment_without_privileges(test_ctx): - user1 = test_ctx.user_factory(rank=db.User.RANK_REGULAR) - user2 = test_ctx.user_factory(rank=db.User.RANK_REGULAR) - comment = test_ctx.comment_factory(user=user1) +def test_trying_to_delete_someones_else_comment_without_privileges( + user_factory, comment_factory, context_factory): + user1 = user_factory(rank=db.User.RANK_REGULAR) + user2 = user_factory(rank=db.User.RANK_REGULAR) + comment = comment_factory(user=user1) db.session.add(comment) db.session.commit() with pytest.raises(errors.AuthError): - test_ctx.api.delete( - test_ctx.context_factory(input={'version': 1}, user=user2), - comment.comment_id) + api.comment_api.delete_comment( + context_factory(params={'version': 1}, user=user2), + {'comment_id': comment.comment_id}) assert db.session.query(db.Comment).count() == 1 -def test_trying_to_delete_non_existing(test_ctx): +def test_trying_to_delete_non_existing(user_factory, context_factory): with pytest.raises(comments.CommentNotFoundError): - test_ctx.api.delete( - test_ctx.context_factory( - input={'version': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 1) + api.comment_api.delete_comment( + context_factory( + params={'version': 1}, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'comment_id': 1}) diff --git a/server/szurubooru/tests/api/test_comment_rating.py b/server/szurubooru/tests/api/test_comment_rating.py index ac0aa7ff..1191f2fc 100644 --- a/server/szurubooru/tests/api/test_comment_rating.py +++ b/server/szurubooru/tests/api/test_comment_rating.py @@ -1,152 +1,134 @@ -import datetime import pytest +import unittest.mock from szurubooru import api, db, errors -from szurubooru.func import util, comments, scores +from szurubooru.func import comments, scores -@pytest.fixture -def test_ctx( - tmpdir, config_injector, context_factory, user_factory, comment_factory): - config_injector({ - 'data_dir': str(tmpdir), - 'data_url': 'http://example.com', - 'privileges': { - 'comments:score': db.User.RANK_REGULAR, - 'users:edit:any:email': db.User.RANK_MODERATOR, - }, - 'thumbnails': {'avatar_width': 200}, - }) - db.session.flush() - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.comment_factory = comment_factory - ret.api = api.CommentScoreApi() - return ret +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({'privileges': {'comments:score': db.User.RANK_REGULAR}}) -def test_simple_rating(test_ctx, fake_datetime): - user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) - comment = test_ctx.comment_factory(user=user) +def test_simple_rating( + user_factory, comment_factory, context_factory, fake_datetime): + user = user_factory(rank=db.User.RANK_REGULAR) + comment = comment_factory(user=user) db.session.add(comment) db.session.commit() - with fake_datetime('1997-12-01'): - result = test_ctx.api.put( - test_ctx.context_factory(input={'score': 1}, user=user), - comment.comment_id) - assert 'text' in result - assert db.session.query(db.CommentScore).count() == 1 - assert comment is not None - assert comment.score == 1 + with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): + comments.serialize_comment.return_value = 'serialized comment' + with fake_datetime('1997-12-01'): + result = api.comment_api.set_comment_score( + context_factory(params={'score': 1}, user=user), + {'comment_id': comment.comment_id}) + assert result == 'serialized comment' + assert db.session.query(db.CommentScore).count() == 1 + assert comment is not None + assert comment.score == 1 -def test_updating_rating(test_ctx, fake_datetime): - user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) - comment = test_ctx.comment_factory(user=user) +def test_updating_rating( + user_factory, comment_factory, context_factory, fake_datetime): + user = user_factory(rank=db.User.RANK_REGULAR) + comment = comment_factory(user=user) db.session.add(comment) db.session.commit() - with fake_datetime('1997-12-01'): - result = test_ctx.api.put( - test_ctx.context_factory(input={'score': 1}, user=user), - comment.comment_id) - with fake_datetime('1997-12-02'): - result = test_ctx.api.put( - test_ctx.context_factory(input={'score': -1}, user=user), - comment.comment_id) - comment = db.session.query(db.Comment).one() - assert db.session.query(db.CommentScore).count() == 1 - assert comment.score == -1 + with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): + with fake_datetime('1997-12-01'): + result = api.comment_api.set_comment_score( + context_factory(params={'score': 1}, user=user), + {'comment_id': comment.comment_id}) + with fake_datetime('1997-12-02'): + result = api.comment_api.set_comment_score( + context_factory(params={'score': -1}, user=user), + {'comment_id': comment.comment_id}) + comment = db.session.query(db.Comment).one() + assert db.session.query(db.CommentScore).count() == 1 + assert comment.score == -1 -def test_updating_rating_to_zero(test_ctx, fake_datetime): - user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) - comment = test_ctx.comment_factory(user=user) +def test_updating_rating_to_zero( + user_factory, comment_factory, context_factory, fake_datetime): + user = user_factory(rank=db.User.RANK_REGULAR) + comment = comment_factory(user=user) db.session.add(comment) db.session.commit() - with fake_datetime('1997-12-01'): - result = test_ctx.api.put( - test_ctx.context_factory(input={'score': 1}, user=user), - comment.comment_id) - with fake_datetime('1997-12-02'): - result = test_ctx.api.put( - test_ctx.context_factory(input={'score': 0}, user=user), - comment.comment_id) - comment = db.session.query(db.Comment).one() - assert db.session.query(db.CommentScore).count() == 0 - assert comment.score == 0 + with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): + with fake_datetime('1997-12-01'): + result = api.comment_api.set_comment_score( + context_factory(params={'score': 1}, user=user), + {'comment_id': comment.comment_id}) + with fake_datetime('1997-12-02'): + result = api.comment_api.set_comment_score( + context_factory(params={'score': 0}, user=user), + {'comment_id': comment.comment_id}) + comment = db.session.query(db.Comment).one() + assert db.session.query(db.CommentScore).count() == 0 + assert comment.score == 0 -def test_deleting_rating(test_ctx, fake_datetime): - user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) - comment = test_ctx.comment_factory(user=user) +def test_deleting_rating( + user_factory, comment_factory, context_factory, fake_datetime): + user = user_factory(rank=db.User.RANK_REGULAR) + comment = comment_factory(user=user) db.session.add(comment) db.session.commit() - with fake_datetime('1997-12-01'): - result = test_ctx.api.put( - test_ctx.context_factory(input={'score': 1}, user=user), - comment.comment_id) - with fake_datetime('1997-12-02'): - result = test_ctx.api.delete( - test_ctx.context_factory(user=user), comment.comment_id) - comment = db.session.query(db.Comment).one() - assert db.session.query(db.CommentScore).count() == 0 - assert comment.score == 0 + with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): + with fake_datetime('1997-12-01'): + result = api.comment_api.set_comment_score( + context_factory(params={'score': 1}, user=user), + {'comment_id': comment.comment_id}) + with fake_datetime('1997-12-02'): + result = api.comment_api.delete_comment_score( + context_factory(user=user), + {'comment_id': comment.comment_id}) + comment = db.session.query(db.Comment).one() + assert db.session.query(db.CommentScore).count() == 0 + assert comment.score == 0 -def test_ratings_from_multiple_users(test_ctx, fake_datetime): - user1 = test_ctx.user_factory(rank=db.User.RANK_REGULAR) - user2 = test_ctx.user_factory(rank=db.User.RANK_REGULAR) - comment = test_ctx.comment_factory() +def test_ratings_from_multiple_users( + user_factory, comment_factory, context_factory, fake_datetime): + user1 = user_factory(rank=db.User.RANK_REGULAR) + user2 = user_factory(rank=db.User.RANK_REGULAR) + comment = comment_factory() db.session.add_all([user1, user2, comment]) db.session.commit() - with fake_datetime('1997-12-01'): - result = test_ctx.api.put( - test_ctx.context_factory(input={'score': 1}, user=user1), - comment.comment_id) - with fake_datetime('1997-12-02'): - result = test_ctx.api.put( - test_ctx.context_factory(input={'score': -1}, user=user2), - comment.comment_id) - comment = db.session.query(db.Comment).one() - assert db.session.query(db.CommentScore).count() == 2 - assert comment.score == 0 + with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): + with fake_datetime('1997-12-01'): + result = api.comment_api.set_comment_score( + context_factory(params={'score': 1}, user=user1), + {'comment_id': comment.comment_id}) + with fake_datetime('1997-12-02'): + result = api.comment_api.set_comment_score( + context_factory(params={'score': -1}, user=user2), + {'comment_id': comment.comment_id}) + comment = db.session.query(db.Comment).one() + assert db.session.query(db.CommentScore).count() == 2 + assert comment.score == 0 -@pytest.mark.parametrize('input,expected_exception', [ - ({'score': None}, errors.ValidationError), - ({'score': ''}, errors.ValidationError), - ({'score': -2}, scores.InvalidScoreValueError), - ({'score': 2}, scores.InvalidScoreValueError), - ({'score': [1]}, errors.ValidationError), -]) -def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception): - user = test_ctx.user_factory() - comment = test_ctx.comment_factory(user=user) - db.session.add(comment) - db.session.commit() - with pytest.raises(expected_exception): - test_ctx.api.put( - test_ctx.context_factory(input=input, user=user), - comment.comment_id) - -def test_trying_to_omit_mandatory_field(test_ctx): - user = test_ctx.user_factory() - comment = test_ctx.comment_factory(user=user) +def test_trying_to_omit_mandatory_field( + user_factory, comment_factory, context_factory): + user = user_factory() + comment = comment_factory(user=user) db.session.add(comment) db.session.commit() with pytest.raises(errors.ValidationError): - test_ctx.api.put( - test_ctx.context_factory(input={}, user=user), - comment.comment_id) + api.comment_api.set_comment_score( + context_factory(params={}, user=user), + {'comment_id': comment.comment_id}) -def test_trying_to_update_non_existing(test_ctx): +def test_trying_to_update_non_existing( + user_factory, comment_factory, context_factory): with pytest.raises(comments.CommentNotFoundError): - test_ctx.api.put( - test_ctx.context_factory( - input={'score': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 5) + api.comment_api.set_comment_score( + context_factory( + params={'score': 1}, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'comment_id': 5}) -def test_trying_to_rate_without_privileges(test_ctx): - comment = test_ctx.comment_factory() +def test_trying_to_rate_without_privileges( + user_factory, comment_factory, context_factory): + comment = comment_factory() db.session.add(comment) db.session.commit() with pytest.raises(errors.AuthError): - test_ctx.api.put( - test_ctx.context_factory( - input={'score': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), - comment.comment_id) + api.comment_api.set_comment_score( + context_factory( + params={'score': 1}, + user=user_factory(rank=db.User.RANK_ANONYMOUS)), + {'comment_id': comment.comment_id}) diff --git a/server/szurubooru/tests/api/test_comment_retrieving.py b/server/szurubooru/tests/api/test_comment_retrieving.py index 3338d1b4..5033f798 100644 --- a/server/szurubooru/tests/api/test_comment_retrieving.py +++ b/server/szurubooru/tests/api/test_comment_retrieving.py @@ -1,76 +1,65 @@ -import datetime import pytest +import unittest.mock from szurubooru import api, db, errors -from szurubooru.func import util, comments +from szurubooru.func import comments -@pytest.fixture -def test_ctx( - tmpdir, context_factory, config_injector, user_factory, comment_factory): +@pytest.fixture(autouse=True) +def inject_config(config_injector): config_injector({ - 'data_dir': str(tmpdir), - 'data_url': 'http://example.com', 'privileges': { 'comments:list': db.User.RANK_REGULAR, 'comments:view': db.User.RANK_REGULAR, - 'users:edit:any:email': db.User.RANK_MODERATOR, }, - 'thumbnails': {'avatar_width': 200}, }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.comment_factory = comment_factory - ret.list_api = api.CommentListApi() - ret.detail_api = api.CommentDetailApi() - return ret -def test_retrieving_multiple(test_ctx): - comment1 = test_ctx.comment_factory(text='text 1') - comment2 = test_ctx.comment_factory(text='text 2') +def test_retrieving_multiple(user_factory, comment_factory, context_factory): + comment1 = comment_factory(text='text 1') + comment2 = comment_factory(text='text 2') db.session.add_all([comment1, comment2]) - result = test_ctx.list_api.get( - test_ctx.context_factory( - input={'query': '', 'page': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - assert result['query'] == '' - assert result['page'] == 1 - assert result['pageSize'] == 100 - assert result['total'] == 2 - assert [c['text'] for c in result['results']] == ['text 1', 'text 2'] + with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): + comments.serialize_comment.return_value = 'serialized comment' + result = api.comment_api.get_comments( + context_factory( + params={'query': '', 'page': 1}, + user=user_factory(rank=db.User.RANK_REGULAR))) + assert result == { + 'query': '', + 'page': 1, + 'pageSize': 100, + 'total': 2, + 'results': ['serialized comment', 'serialized comment'], + } -def test_trying_to_retrieve_multiple_without_privileges(test_ctx): +def test_trying_to_retrieve_multiple_without_privileges( + user_factory, context_factory): with pytest.raises(errors.AuthError): - test_ctx.list_api.get( - test_ctx.context_factory( - input={'query': '', 'page': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) + api.comment_api.get_comments( + context_factory( + params={'query': '', 'page': 1}, + user=user_factory(rank=db.User.RANK_ANONYMOUS))) -def test_retrieving_single(test_ctx): - comment = test_ctx.comment_factory(text='dummy text') +def test_retrieving_single(user_factory, comment_factory, context_factory): + comment = comment_factory(text='dummy text') db.session.add(comment) db.session.flush() - result = test_ctx.detail_api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - comment.comment_id) - assert 'id' in result - assert 'lastEditTime' in result - assert 'creationTime' in result - assert 'text' in result - assert 'user' in result - assert 'name' in result['user'] - assert 'postId' in result + with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): + comments.serialize_comment.return_value = 'serialized comment' + result = api.comment_api.get_comment( + context_factory( + user=user_factory(rank=db.User.RANK_REGULAR)), + {'comment_id': comment.comment_id}) + assert result == 'serialized comment' -def test_trying_to_retrieve_single_non_existing(test_ctx): +def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): with pytest.raises(comments.CommentNotFoundError): - test_ctx.detail_api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 5) + api.comment_api.get_comment( + context_factory( + user=user_factory(rank=db.User.RANK_REGULAR)), + {'comment_id': 5}) -def test_trying_to_retrieve_single_without_privileges(test_ctx): +def test_trying_to_retrieve_single_without_privileges( + user_factory, context_factory): with pytest.raises(errors.AuthError): - test_ctx.detail_api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), - 5) + api.comment_api.get_comment( + context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + {'comment_id': 5}) diff --git a/server/szurubooru/tests/api/test_comment_updating.py b/server/szurubooru/tests/api/test_comment_updating.py index e679c008..9023cb55 100644 --- a/server/szurubooru/tests/api/test_comment_updating.py +++ b/server/szurubooru/tests/api/test_comment_updating.py @@ -1,103 +1,94 @@ -import datetime import pytest +import unittest.mock +from datetime import datetime from szurubooru import api, db, errors -from szurubooru.func import util, comments +from szurubooru.func import comments -@pytest.fixture -def test_ctx( - tmpdir, config_injector, context_factory, user_factory, comment_factory): +@pytest.fixture(autouse=True) +def inject_config(config_injector): config_injector({ - 'data_dir': str(tmpdir), - 'data_url': 'http://example.com', 'privileges': { 'comments:edit:own': db.User.RANK_REGULAR, 'comments:edit:any': db.User.RANK_MODERATOR, - 'users:edit:any:email': db.User.RANK_MODERATOR, }, - 'thumbnails': {'avatar_width': 200}, }) - db.session.flush() - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.comment_factory = comment_factory - ret.api = api.CommentDetailApi() - return ret -def test_simple_updating(test_ctx, fake_datetime): - user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) - comment = test_ctx.comment_factory(user=user) +def test_simple_updating( + user_factory, comment_factory, context_factory, fake_datetime): + user = user_factory(rank=db.User.RANK_REGULAR) + comment = comment_factory(user=user) db.session.add(comment) db.session.commit() - with fake_datetime('1997-12-01'): - result = test_ctx.api.put( - test_ctx.context_factory( - input={'text': 'new text', 'version': 1}, user=user), - comment.comment_id) - assert result['text'] == 'new text' - comment = db.session.query(db.Comment).one() - assert comment is not None - assert comment.text == 'new text' - assert comment.last_edit_time is not None + with unittest.mock.patch('szurubooru.func.comments.serialize_comment'), \ + fake_datetime('1997-12-01'): + comments.serialize_comment.return_value = 'serialized comment' + result = api.comment_api.update_comment( + context_factory( + params={'text': 'new text', 'version': 1}, user=user), + {'comment_id': comment.comment_id}) + assert result == 'serialized comment' + assert comment.last_edit_time == datetime(1997, 12, 1) -@pytest.mark.parametrize('input,expected_exception', [ +@pytest.mark.parametrize('params,expected_exception', [ ({'text': None}, comments.EmptyCommentTextError), ({'text': ''}, comments.EmptyCommentTextError), ({'text': []}, comments.EmptyCommentTextError), ({'text': [None]}, errors.ValidationError), ({'text': ['']}, comments.EmptyCommentTextError), ]) -def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception): - user = test_ctx.user_factory() - comment = test_ctx.comment_factory(user=user) +def test_trying_to_pass_invalid_params( + user_factory, comment_factory, context_factory, params, expected_exception): + user = user_factory() + comment = comment_factory(user=user) db.session.add(comment) db.session.commit() with pytest.raises(expected_exception): - test_ctx.api.put( - test_ctx.context_factory( - input={**input, **{'version': 1}}, user=user), - comment.comment_id) + api.comment_api.update_comment( + context_factory( + params={**params, **{'version': 1}}, user=user), + {'comment_id': comment.comment_id}) -def test_trying_to_omit_mandatory_field(test_ctx): - user = test_ctx.user_factory() - comment = test_ctx.comment_factory(user=user) +def test_trying_to_omit_mandatory_field( + user_factory, comment_factory, context_factory): + user = user_factory() + comment = comment_factory(user=user) db.session.add(comment) db.session.commit() with pytest.raises(errors.ValidationError): - test_ctx.api.put( - test_ctx.context_factory(input={'version': 1}, user=user), - comment.comment_id) + api.comment_api.update_comment( + context_factory(params={'version': 1}, user=user), + {'comment_id': comment.comment_id}) -def test_trying_to_update_non_existing(test_ctx): +def test_trying_to_update_non_existing(user_factory, context_factory): with pytest.raises(comments.CommentNotFoundError): - test_ctx.api.put( - test_ctx.context_factory( - input={'text': 'new text'}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 5) + api.comment_api.update_comment( + context_factory( + params={'text': 'new text'}, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'comment_id': 5}) -def test_trying_to_update_someones_comment_without_privileges(test_ctx): - user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) - user2 = test_ctx.user_factory(rank=db.User.RANK_REGULAR) - comment = test_ctx.comment_factory(user=user) +def test_trying_to_update_someones_comment_without_privileges( + user_factory, comment_factory, context_factory): + user = user_factory(rank=db.User.RANK_REGULAR) + user2 = user_factory(rank=db.User.RANK_REGULAR) + comment = comment_factory(user=user) db.session.add(comment) db.session.commit() with pytest.raises(errors.AuthError): - test_ctx.api.put( - test_ctx.context_factory( - input={'text': 'new text', 'version': 1}, user=user2), - comment.comment_id) + api.comment_api.update_comment( + context_factory( + params={'text': 'new text', 'version': 1}, user=user2), + {'comment_id': comment.comment_id}) -def test_updating_someones_comment_with_privileges(test_ctx): - user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) - user2 = test_ctx.user_factory(rank=db.User.RANK_MODERATOR) - comment = test_ctx.comment_factory(user=user) +def test_updating_someones_comment_with_privileges( + user_factory, comment_factory, context_factory): + user = user_factory(rank=db.User.RANK_REGULAR) + user2 = user_factory(rank=db.User.RANK_MODERATOR) + comment = comment_factory(user=user) db.session.add(comment) db.session.commit() - try: - test_ctx.api.put( - test_ctx.context_factory( - input={'text': 'new text', 'version': 1}, user=user2), - comment.comment_id) - except: - pytest.fail() + with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): + api.comment_api.update_comment( + context_factory( + params={'text': 'new text', 'version': 1}, user=user2), + {'comment_id': comment.comment_id}) diff --git a/server/szurubooru/tests/api/test_info.py b/server/szurubooru/tests/api/test_info.py index 91659cd6..c1499c1d 100644 --- a/server/szurubooru/tests/api/test_info.py +++ b/server/szurubooru/tests/api/test_info.py @@ -31,9 +31,8 @@ def test_info_api( }, } - info_api = api.InfoApi() with fake_datetime('2016-01-01 13:00'): - assert info_api.get(context_factory()) == { + assert api.info_api.get_info(context_factory()) == { 'postCount': 2, 'diskUsage': 3, 'featuredPost': None, @@ -44,7 +43,7 @@ def test_info_api( } directory.join('test2.txt').write('abc') with fake_datetime('2016-01-01 13:59'): - assert info_api.get(context_factory()) == { + assert api.info_api.get_info(context_factory()) == { 'postCount': 2, 'diskUsage': 3, # still 3 - it's cached 'featuredPost': None, @@ -54,7 +53,7 @@ def test_info_api( 'config': expected_config_key, } with fake_datetime('2016-01-01 14:01'): - assert info_api.get(context_factory()) == { + assert api.info_api.get_info(context_factory()) == { 'postCount': 2, 'diskUsage': 6, # cache expired 'featuredPost': None, diff --git a/server/szurubooru/tests/api/test_password_reset.py b/server/szurubooru/tests/api/test_password_reset.py index 5c3eb86b..d1a844ce 100644 --- a/server/szurubooru/tests/api/test_password_reset.py +++ b/server/szurubooru/tests/api/test_password_reset.py @@ -1,71 +1,70 @@ -from datetime import datetime -from unittest import mock import pytest +import unittest.mock from szurubooru import api, db, errors from szurubooru.func import auth, mailer -@pytest.fixture -def password_reset_api(config_injector): +@pytest.fixture(autouse=True) +def inject_config(tmpdir, config_injector): config_injector({ 'secret': 'x', 'base_url': 'http://example.com/', 'name': 'Test instance', }) - return api.PasswordResetApi() -def test_reset_sending_email( - password_reset_api, context_factory, user_factory): +def test_reset_sending_email(context_factory, user_factory): db.session.add(user_factory( name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) - for getter in ['u1', 'user@example.com']: - mailer.send_mail = mock.MagicMock() - assert password_reset_api.get(context_factory(), getter) == {} - mailer.send_mail.assert_called_once_with( - 'noreply@Test instance', - 'user@example.com', - 'Password reset for Test instance', - 'You (or someone else) requested to reset your password ' + - 'on Test instance.\nIf you wish to proceed, click this l' + - 'ink: http://example.com/password-reset/u1:4ac0be176fb36' + - '4f13ee6b634c43220e2\nOtherwise, please ignore this email.') + for initiating_user in ['u1', 'user@example.com']: + with unittest.mock.patch('szurubooru.func.mailer.send_mail'): + assert api.password_reset_api.start_password_reset( + context_factory(), {'user_name': initiating_user}) == {} + mailer.send_mail.assert_called_once_with( + 'noreply@Test instance', + 'user@example.com', + 'Password reset for Test instance', + 'You (or someone else) requested to reset your password ' + + 'on Test instance.\nIf you wish to proceed, click this l' + + 'ink: http://example.com/password-reset/u1:4ac0be176fb36' + + '4f13ee6b634c43220e2\nOtherwise, please ignore this email.') -def test_trying_to_reset_non_existing(password_reset_api, context_factory): +def test_trying_to_reset_non_existing(context_factory): with pytest.raises(errors.NotFoundError): - password_reset_api.get(context_factory(), 'u1') + api.password_reset_api.start_password_reset( + context_factory(), {'user_name': 'u1'}) -def test_trying_to_reset_without_email( - password_reset_api, context_factory, user_factory): +def test_trying_to_reset_without_email(context_factory, user_factory): db.session.add(user_factory(name='u1', rank=db.User.RANK_REGULAR, email=None)) with pytest.raises(errors.ValidationError): - password_reset_api.get(context_factory(), 'u1') + api.password_reset_api.start_password_reset( + context_factory(), {'user_name': 'u1'}) -def test_confirming_with_good_token( - password_reset_api, context_factory, user_factory): +def test_confirming_with_good_token(context_factory, user_factory): user = user_factory( name='u1', rank=db.User.RANK_REGULAR, email='user@example.com') old_hash = user.password_hash db.session.add(user) context = context_factory( - input={'token': '4ac0be176fb364f13ee6b634c43220e2'}) - result = password_reset_api.post(context, 'u1') + params={'token': '4ac0be176fb364f13ee6b634c43220e2'}) + result = api.password_reset_api.finish_password_reset( + context, {'user_name': 'u1'}) assert user.password_hash != old_hash assert auth.is_valid_password(user, result['password']) is True -def test_trying_to_confirm_non_existing(password_reset_api, context_factory): +def test_trying_to_confirm_non_existing(context_factory): with pytest.raises(errors.NotFoundError): - password_reset_api.post(context_factory(), 'u1') + api.password_reset_api.finish_password_reset( + context_factory(), {'user_name': 'u1'}) -def test_trying_to_confirm_without_token( - password_reset_api, context_factory, user_factory): +def test_trying_to_confirm_without_token(context_factory, user_factory): db.session.add(user_factory( name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) with pytest.raises(errors.ValidationError): - password_reset_api.post(context_factory(input={}), 'u1') + api.password_reset_api.finish_password_reset( + context_factory(params={}), {'user_name': 'u1'}) -def test_trying_to_confirm_with_bad_token( - password_reset_api, context_factory, user_factory): +def test_trying_to_confirm_with_bad_token(context_factory, user_factory): db.session.add(user_factory( name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) with pytest.raises(errors.ValidationError): - password_reset_api.post( - context_factory(input={'token': 'bad'}), 'u1') + api.password_reset_api.finish_password_reset( + context_factory(params={'token': 'bad'}), {'user_name': 'u1'}) diff --git a/server/szurubooru/tests/api/test_post_creating.py b/server/szurubooru/tests/api/test_post_creating.py index 9b15379c..bddd10ba 100644 --- a/server/szurubooru/tests/api/test_post_creating.py +++ b/server/szurubooru/tests/api/test_post_creating.py @@ -1,7 +1,5 @@ -import datetime -import os -import unittest.mock import pytest +import unittest.mock from szurubooru import api, db, errors from szurubooru.func import posts, tags, snapshots, net @@ -35,9 +33,9 @@ def test_creating_minimal_posts( posts.create_post.return_value = (post, []) posts.serialize_post.return_value = 'serialized post' - result = api.PostListApi().post( + result = api.post_api.create_post( context_factory( - input={ + params={ 'safety': 'safe', 'tags': ['tag1', 'tag2'], }, @@ -79,9 +77,9 @@ def test_creating_full_posts(context_factory, post_factory, user_factory): posts.create_post.return_value = (post, []) posts.serialize_post.return_value = 'serialized post' - result = api.PostListApi().post( + result = api.post_api.create_post( context_factory( - input={ + params={ 'safety': 'safe', 'tags': ['tag1', 'tag2'], 'relations': [1, 2], @@ -122,9 +120,9 @@ def test_anonymous_uploads( 'privileges': {'posts:create:anonymous': db.User.RANK_REGULAR}, }) posts.create_post.return_value = [post, []] - api.PostListApi().post( + api.post_api.create_post( context_factory( - input={ + params={ 'safety': 'safe', 'tags': ['tag1', 'tag2'], 'anonymous': 'True', @@ -154,9 +152,9 @@ def test_creating_from_url_saves_source( }) net.download.return_value = b'content' posts.create_post.return_value = [post, []] - api.PostListApi().post( + api.post_api.create_post( context_factory( - input={ + params={ 'safety': 'safe', 'tags': ['tag1', 'tag2'], 'contentUrl': 'example.com', @@ -185,9 +183,9 @@ def test_creating_from_url_with_source_specified( }) net.download.return_value = b'content' posts.create_post.return_value = [post, []] - api.PostListApi().post( + api.post_api.create_post( context_factory( - input={ + params={ 'safety': 'safe', 'tags': ['tag1', 'tag2'], 'contentUrl': 'example.com', @@ -201,23 +199,23 @@ def test_creating_from_url_with_source_specified( @pytest.mark.parametrize('field', ['tags', 'safety']) def test_trying_to_omit_mandatory_field(context_factory, user_factory, field): - input = { + params = { 'safety': 'safe', 'tags': ['tag1', 'tag2'], } - del input[field] + del params[field] with pytest.raises(errors.MissingRequiredParameterError): - api.PostListApi().post( + api.post_api.create_post( context_factory( - input=input, + params=params, files={'content': '...'}, user=user_factory(rank=db.User.RANK_REGULAR))) def test_trying_to_omit_content(context_factory, user_factory): with pytest.raises(errors.MissingRequiredFileError): - api.PostListApi().post( + api.post_api.create_post( context_factory( - input={ + params={ 'safety': 'safe', 'tags': ['tag1', 'tag2'], }, @@ -225,10 +223,9 @@ def test_trying_to_omit_content(context_factory, user_factory): def test_trying_to_create_post_without_privileges(context_factory, user_factory): with pytest.raises(errors.AuthError): - api.PostListApi().post( - context_factory( - input='whatever', - user=user_factory(rank=db.User.RANK_ANONYMOUS))) + api.post_api.create_post(context_factory( + params='whatever', + user=user_factory(rank=db.User.RANK_ANONYMOUS))) def test_trying_to_create_tags_without_privileges( config_injector, context_factory, user_factory): @@ -243,9 +240,9 @@ def test_trying_to_create_tags_without_privileges( unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ unittest.mock.patch('szurubooru.func.posts.update_post_tags'): posts.update_post_tags.return_value = ['new-tag'] - api.PostListApi().post( + api.post_api.create_post( context_factory( - input={ + params={ 'safety': 'safe', 'tags': ['tag1', 'tag2'], }, diff --git a/server/szurubooru/tests/api/test_post_deleting.py b/server/szurubooru/tests/api/test_post_deleting.py index 1b116399..f128d527 100644 --- a/server/szurubooru/tests/api/test_post_deleting.py +++ b/server/szurubooru/tests/api/test_post_deleting.py @@ -1,49 +1,37 @@ import pytest -import os -from datetime import datetime -from szurubooru import api, config, db, errors -from szurubooru.func import util, posts +import unittest.mock +from szurubooru import api, db, errors +from szurubooru.func import posts, tags -@pytest.fixture -def test_ctx( - tmpdir, config_injector, context_factory, post_factory, user_factory): - config_injector({ - 'data_dir': str(tmpdir), - 'privileges': { - 'posts:delete': db.User.RANK_REGULAR, - }, - }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.post_factory = post_factory - ret.api = api.PostDetailApi() - return ret +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({'privileges': {'posts:delete': db.User.RANK_REGULAR}}) -def test_deleting(test_ctx): - db.session.add(test_ctx.post_factory(id=1)) +def test_deleting(user_factory, post_factory, context_factory): + db.session.add(post_factory(id=1)) db.session.commit() - result = test_ctx.api.delete( - test_ctx.context_factory( - input={'version': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 1) - assert result == {} - assert db.session.query(db.Post).count() == 0 - assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json')) + with unittest.mock.patch('szurubooru.func.tags.export_to_json'): + result = api.post_api.delete_post( + context_factory( + params={'version': 1}, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'post_id': 1}) + assert result == {} + assert db.session.query(db.Post).count() == 0 + tags.export_to_json.assert_called_once_with() -def test_trying_to_delete_non_existing(test_ctx): +def test_trying_to_delete_non_existing(user_factory, context_factory): with pytest.raises(posts.PostNotFoundError): - test_ctx.api.delete( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), '999') + api.post_api.delete_post( + context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + {'post_id': 999}) -def test_trying_to_delete_without_privileges(test_ctx): - db.session.add(test_ctx.post_factory(id=1)) +def test_trying_to_delete_without_privileges( + user_factory, post_factory, context_factory): + db.session.add(post_factory(id=1)) db.session.commit() with pytest.raises(errors.AuthError): - test_ctx.api.delete( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), - 1) + api.post_api.delete_post( + context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + {'post_id': 1}) assert db.session.query(db.Post).count() == 1 diff --git a/server/szurubooru/tests/api/test_post_favoriting.py b/server/szurubooru/tests/api/test_post_favoriting.py index 0e24a01f..0788ec2f 100644 --- a/server/szurubooru/tests/api/test_post_favoriting.py +++ b/server/szurubooru/tests/api/test_post_favoriting.py @@ -1,132 +1,129 @@ -import datetime import pytest +import unittest.mock +from datetime import datetime from szurubooru import api, db, errors -from szurubooru.func import util, posts +from szurubooru.func import posts -@pytest.fixture -def test_ctx( - tmpdir, config_injector, context_factory, user_factory, post_factory): - config_injector({ - 'data_dir': str(tmpdir), - 'data_url': 'http://example.com', - 'privileges': { - 'posts:favorite': db.User.RANK_REGULAR, - 'users:edit:any:email': db.User.RANK_MODERATOR, - }, - 'thumbnails': {'avatar_width': 200}, - }) - db.session.flush() - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.post_factory = post_factory - ret.api = api.PostFavoriteApi() - return ret +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({'privileges': {'posts:favorite': db.User.RANK_REGULAR}}) -def test_adding_to_favorites(test_ctx, fake_datetime): - post = test_ctx.post_factory() +def test_adding_to_favorites( + user_factory, post_factory, context_factory, fake_datetime): + post = post_factory() db.session.add(post) db.session.commit() assert post.score == 0 - with fake_datetime('1997-12-01'): - result = test_ctx.api.post( - test_ctx.context_factory(user=test_ctx.user_factory()), - post.post_id) - assert 'id' in result - post = db.session.query(db.Post).one() - assert db.session.query(db.PostFavorite).count() == 1 - assert post is not None - assert post.favorite_count == 1 - assert post.score == 1 + with unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ + fake_datetime('1997-12-01'): + posts.serialize_post.return_value = 'serialized post' + result = api.post_api.add_post_to_favorites( + context_factory(user=user_factory()), + {'post_id': post.post_id}) + assert result == 'serialized post' + post = db.session.query(db.Post).one() + assert db.session.query(db.PostFavorite).count() == 1 + assert post is not None + assert post.favorite_count == 1 + assert post.score == 1 -def test_removing_from_favorites(test_ctx, fake_datetime): - user = test_ctx.user_factory() - post = test_ctx.post_factory() +def test_removing_from_favorites( + user_factory, post_factory, context_factory, fake_datetime): + user = user_factory() + post = post_factory() db.session.add(post) db.session.commit() assert post.score == 0 - with fake_datetime('1997-12-01'): - result = test_ctx.api.post( - test_ctx.context_factory(user=user), - post.post_id) - assert post.score == 1 - with fake_datetime('1997-12-02'): - result = test_ctx.api.delete( - test_ctx.context_factory(user=user), - post.post_id) - post = db.session.query(db.Post).one() - assert post.score == 1 - assert db.session.query(db.PostFavorite).count() == 0 - assert post.favorite_count == 0 + with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with fake_datetime('1997-12-01'): + api.post_api.add_post_to_favorites( + context_factory(user=user), + {'post_id': post.post_id}) + assert post.score == 1 + with fake_datetime('1997-12-02'): + api.post_api.delete_post_from_favorites( + context_factory(user=user), + {'post_id': post.post_id}) + post = db.session.query(db.Post).one() + assert post.score == 1 + assert db.session.query(db.PostFavorite).count() == 0 + assert post.favorite_count == 0 -def test_favoriting_twice(test_ctx, fake_datetime): - user = test_ctx.user_factory() - post = test_ctx.post_factory() +def test_favoriting_twice( + user_factory, post_factory, context_factory, fake_datetime): + user = user_factory() + post = post_factory() db.session.add(post) db.session.commit() - with fake_datetime('1997-12-01'): - result = test_ctx.api.post( - test_ctx.context_factory(user=user), - post.post_id) - with fake_datetime('1997-12-02'): - result = test_ctx.api.post( - test_ctx.context_factory(user=user), - post.post_id) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostFavorite).count() == 1 - assert post.favorite_count == 1 + with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with fake_datetime('1997-12-01'): + api.post_api.add_post_to_favorites( + context_factory(user=user), + {'post_id': post.post_id}) + with fake_datetime('1997-12-02'): + api.post_api.add_post_to_favorites( + context_factory(user=user), + {'post_id': post.post_id}) + post = db.session.query(db.Post).one() + assert db.session.query(db.PostFavorite).count() == 1 + assert post.favorite_count == 1 -def test_removing_twice(test_ctx, fake_datetime): - user = test_ctx.user_factory() - post = test_ctx.post_factory() +def test_removing_twice( + user_factory, post_factory, context_factory, fake_datetime): + user = user_factory() + post = post_factory() db.session.add(post) db.session.commit() - with fake_datetime('1997-12-01'): - result = test_ctx.api.post( - test_ctx.context_factory(user=user), - post.post_id) - with fake_datetime('1997-12-02'): - result = test_ctx.api.delete( - test_ctx.context_factory(user=user), - post.post_id) - with fake_datetime('1997-12-02'): - result = test_ctx.api.delete( - test_ctx.context_factory(user=user), - post.post_id) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostFavorite).count() == 0 - assert post.favorite_count == 0 + with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with fake_datetime('1997-12-01'): + api.post_api.add_post_to_favorites( + context_factory(user=user), + {'post_id': post.post_id}) + with fake_datetime('1997-12-02'): + api.post_api.delete_post_from_favorites( + context_factory(user=user), + {'post_id': post.post_id}) + with fake_datetime('1997-12-02'): + api.post_api.delete_post_from_favorites( + context_factory(user=user), + {'post_id': post.post_id}) + post = db.session.query(db.Post).one() + assert db.session.query(db.PostFavorite).count() == 0 + assert post.favorite_count == 0 -def test_favorites_from_multiple_users(test_ctx, fake_datetime): - user1 = test_ctx.user_factory() - user2 = test_ctx.user_factory() - post = test_ctx.post_factory() +def test_favorites_from_multiple_users( + user_factory, post_factory, context_factory, fake_datetime): + user1 = user_factory() + user2 = user_factory() + post = post_factory() db.session.add_all([user1, user2, post]) db.session.commit() - with fake_datetime('1997-12-01'): - result = test_ctx.api.post( - test_ctx.context_factory(user=user1), - post.post_id) - with fake_datetime('1997-12-02'): - result = test_ctx.api.post( - test_ctx.context_factory(user=user2), - post.post_id) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostFavorite).count() == 2 - assert post.favorite_count == 2 - assert post.last_favorite_time == datetime.datetime(1997, 12, 2) + with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with fake_datetime('1997-12-01'): + api.post_api.add_post_to_favorites( + context_factory(user=user1), + {'post_id': post.post_id}) + with fake_datetime('1997-12-02'): + api.post_api.add_post_to_favorites( + context_factory(user=user2), + {'post_id': post.post_id}) + post = db.session.query(db.Post).one() + assert db.session.query(db.PostFavorite).count() == 2 + assert post.favorite_count == 2 + assert post.last_favorite_time == datetime(1997, 12, 2) -def test_trying_to_update_non_existing(test_ctx): +def test_trying_to_update_non_existing(user_factory, context_factory): with pytest.raises(posts.PostNotFoundError): - test_ctx.api.post( - test_ctx.context_factory(user=test_ctx.user_factory()), 5) + api.post_api.add_post_to_favorites( + context_factory(user=user_factory()), + {'post_id': 5}) -def test_trying_to_rate_without_privileges(test_ctx): - post = test_ctx.post_factory() +def test_trying_to_rate_without_privileges( + user_factory, post_factory, context_factory): + post = post_factory() db.session.add(post) db.session.commit() with pytest.raises(errors.AuthError): - test_ctx.api.post( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), - post.post_id) + api.post_api.add_post_to_favorites( + context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + {'post_id': post.post_id}) diff --git a/server/szurubooru/tests/api/test_post_featuring.py b/server/szurubooru/tests/api/test_post_featuring.py index fbd6e45f..45cae474 100644 --- a/server/szurubooru/tests/api/test_post_featuring.py +++ b/server/szurubooru/tests/api/test_post_featuring.py @@ -1,107 +1,100 @@ -import datetime import pytest +import unittest.mock from szurubooru import api, db, errors -from szurubooru.func import util, posts +from szurubooru.func import posts -@pytest.fixture -def test_ctx( - tmpdir, context_factory, config_injector, user_factory, post_factory): +@pytest.fixture(autouse=True) +def inject_config(config_injector): config_injector({ - 'data_dir': str(tmpdir), - 'data_url': 'http://example.com', 'privileges': { 'posts:feature': db.User.RANK_REGULAR, 'posts:view': db.User.RANK_REGULAR, }, }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.post_factory = post_factory - ret.api = api.PostFeatureApi() - return ret -def test_no_featured_post(test_ctx): +def test_no_featured_post(user_factory, post_factory, context_factory): assert posts.try_get_featured_post() is None - result = test_ctx.api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - assert result is None -def test_featuring(test_ctx): - db.session.add(test_ctx.post_factory(id=1)) +def test_featuring(user_factory, post_factory, context_factory): + db.session.add(post_factory(id=1)) db.session.commit() assert not posts.get_post_by_id(1).is_featured - result = test_ctx.api.post( - test_ctx.context_factory( - input={'id': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - assert posts.try_get_featured_post() is not None - assert posts.try_get_featured_post().post_id == 1 - assert posts.get_post_by_id(1).is_featured - assert 'id' in result - assert 'snapshots' in result - assert 'comments' in result - result = test_ctx.api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - assert 'id' in result - assert 'snapshots' in result - assert 'comments' in result + with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + posts.serialize_post.return_value = 'serialized post' + result = api.post_api.set_featured_post( + context_factory( + params={'id': 1}, + user=user_factory(rank=db.User.RANK_REGULAR))) + assert result == 'serialized post' + assert posts.try_get_featured_post() is not None + assert posts.try_get_featured_post().post_id == 1 + assert posts.get_post_by_id(1).is_featured + result = api.post_api.get_featured_post( + context_factory( + user=user_factory(rank=db.User.RANK_REGULAR))) + assert result == 'serialized post' -def test_trying_to_feature_the_same_post_twice(test_ctx): - db.session.add(test_ctx.post_factory(id=1)) +def test_trying_to_omit_required_parameter( + user_factory, post_factory, context_factory): + with pytest.raises(errors.MissingRequiredParameterError): + api.post_api.set_featured_post( + context_factory( + user=user_factory(rank=db.User.RANK_REGULAR))) + +def test_trying_to_feature_the_same_post_twice( + user_factory, post_factory, context_factory): + db.session.add(post_factory(id=1)) db.session.commit() - test_ctx.api.post( - test_ctx.context_factory( - input={'id': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - with pytest.raises(posts.PostAlreadyFeaturedError): - test_ctx.api.post( - test_ctx.context_factory( - input={'id': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) + with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + api.post_api.set_featured_post( + context_factory( + params={'id': 1}, + user=user_factory(rank=db.User.RANK_REGULAR))) + with pytest.raises(posts.PostAlreadyFeaturedError): + api.post_api.set_featured_post( + context_factory( + params={'id': 1}, + user=user_factory(rank=db.User.RANK_REGULAR))) -def test_featuring_one_post_after_another(test_ctx, fake_datetime): - db.session.add(test_ctx.post_factory(id=1)) - db.session.add(test_ctx.post_factory(id=2)) +def test_featuring_one_post_after_another( + user_factory, post_factory, context_factory, fake_datetime): + db.session.add(post_factory(id=1)) + db.session.add(post_factory(id=2)) db.session.commit() assert posts.try_get_featured_post() is None assert not posts.get_post_by_id(1).is_featured assert not posts.get_post_by_id(2).is_featured - with fake_datetime('1997'): - result = test_ctx.api.post( - test_ctx.context_factory( - input={'id': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - with fake_datetime('1998'): - result = test_ctx.api.post( - test_ctx.context_factory( - input={'id': 2}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - assert posts.try_get_featured_post() is not None - assert posts.try_get_featured_post().post_id == 2 - assert not posts.get_post_by_id(1).is_featured - assert posts.get_post_by_id(2).is_featured + with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with fake_datetime('1997'): + result = api.post_api.set_featured_post( + context_factory( + params={'id': 1}, + user=user_factory(rank=db.User.RANK_REGULAR))) + with fake_datetime('1998'): + result = api.post_api.set_featured_post( + context_factory( + params={'id': 2}, + user=user_factory(rank=db.User.RANK_REGULAR))) + assert posts.try_get_featured_post() is not None + assert posts.try_get_featured_post().post_id == 2 + assert not posts.get_post_by_id(1).is_featured + assert posts.get_post_by_id(2).is_featured -def test_trying_to_feature_non_existing(test_ctx): +def test_trying_to_feature_non_existing(user_factory, context_factory): with pytest.raises(posts.PostNotFoundError): - test_ctx.api.post( - test_ctx.context_factory( - input={'id': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) + api.post_api.set_featured_post( + context_factory( + params={'id': 1}, + user=user_factory(rank=db.User.RANK_REGULAR))) -def test_trying_to_feature_without_privileges(test_ctx): +def test_trying_to_feature_without_privileges(user_factory, context_factory): with pytest.raises(errors.AuthError): - test_ctx.api.post( - test_ctx.context_factory( - input={'id': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) + api.post_api.set_featured_post( + context_factory( + params={'id': 1}, + user=user_factory(rank=db.User.RANK_ANONYMOUS))) -def test_getting_featured_post_without_privileges_to_view(test_ctx): - try: - test_ctx.api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) - except: - pytest.fail() +def test_getting_featured_post_without_privileges_to_view( + user_factory, context_factory): + api.post_api.get_featured_post( + context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_post_rating.py b/server/szurubooru/tests/api/test_post_rating.py index fd631158..ed646b3b 100644 --- a/server/szurubooru/tests/api/test_post_rating.py +++ b/server/szurubooru/tests/api/test_post_rating.py @@ -1,147 +1,132 @@ import pytest +import unittest.mock from szurubooru import api, db, errors -from szurubooru.func import util, posts, scores +from szurubooru.func import posts, scores -@pytest.fixture -def test_ctx( - tmpdir, config_injector, context_factory, user_factory, post_factory): - config_injector({ - 'data_dir': str(tmpdir), - 'data_url': 'http://example.com', - 'privileges': {'posts:score': db.User.RANK_REGULAR}, - 'thumbnails': {'avatar_width': 200}, - }) - db.session.flush() - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.post_factory = post_factory - ret.api = api.PostScoreApi() - return ret +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({'privileges': {'posts:score': db.User.RANK_REGULAR}}) -def test_simple_rating(test_ctx, fake_datetime): - post = test_ctx.post_factory() +def test_simple_rating( + user_factory, post_factory, context_factory, fake_datetime): + post = post_factory() db.session.add(post) db.session.commit() - with fake_datetime('1997-12-01'): - result = test_ctx.api.put( - test_ctx.context_factory( - input={'score': 1}, user=test_ctx.user_factory()), - post.post_id) - assert 'id' in result - post = db.session.query(db.Post).one() - assert db.session.query(db.PostScore).count() == 1 - assert post is not None - assert post.score == 1 + with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + posts.serialize_post.return_value = 'serialized post' + with fake_datetime('1997-12-01'): + result = api.post_api.set_post_score( + context_factory( + params={'score': 1}, user=user_factory()), + {'post_id': post.post_id}) + assert result == 'serialized post' + post = db.session.query(db.Post).one() + assert db.session.query(db.PostScore).count() == 1 + assert post is not None + assert post.score == 1 -def test_updating_rating(test_ctx, fake_datetime): - user = test_ctx.user_factory() - post = test_ctx.post_factory() +def test_updating_rating( + user_factory, post_factory, context_factory, fake_datetime): + user = user_factory() + post = post_factory() db.session.add(post) db.session.commit() - with fake_datetime('1997-12-01'): - result = test_ctx.api.put( - test_ctx.context_factory(input={'score': 1}, user=user), - post.post_id) - with fake_datetime('1997-12-02'): - result = test_ctx.api.put( - test_ctx.context_factory(input={'score': -1}, user=user), - post.post_id) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostScore).count() == 1 - assert post.score == -1 + with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with fake_datetime('1997-12-01'): + result = api.post_api.set_post_score( + context_factory(params={'score': 1}, user=user), + {'post_id': post.post_id}) + with fake_datetime('1997-12-02'): + result = api.post_api.set_post_score( + context_factory(params={'score': -1}, user=user), + {'post_id': post.post_id}) + post = db.session.query(db.Post).one() + assert db.session.query(db.PostScore).count() == 1 + assert post.score == -1 -def test_updating_rating_to_zero(test_ctx, fake_datetime): - user = test_ctx.user_factory() - post = test_ctx.post_factory() +def test_updating_rating_to_zero( + user_factory, post_factory, context_factory, fake_datetime): + user = user_factory() + post = post_factory() db.session.add(post) db.session.commit() - with fake_datetime('1997-12-01'): - result = test_ctx.api.put( - test_ctx.context_factory(input={'score': 1}, user=user), - post.post_id) - with fake_datetime('1997-12-02'): - result = test_ctx.api.put( - test_ctx.context_factory(input={'score': 0}, user=user), - post.post_id) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostScore).count() == 0 - assert post.score == 0 + with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with fake_datetime('1997-12-01'): + result = api.post_api.set_post_score( + context_factory(params={'score': 1}, user=user), + {'post_id': post.post_id}) + with fake_datetime('1997-12-02'): + result = api.post_api.set_post_score( + context_factory(params={'score': 0}, user=user), + {'post_id': post.post_id}) + post = db.session.query(db.Post).one() + assert db.session.query(db.PostScore).count() == 0 + assert post.score == 0 -def test_deleting_rating(test_ctx, fake_datetime): - user = test_ctx.user_factory() - post = test_ctx.post_factory() +def test_deleting_rating( + user_factory, post_factory, context_factory, fake_datetime): + user = user_factory() + post = post_factory() db.session.add(post) db.session.commit() - with fake_datetime('1997-12-01'): - result = test_ctx.api.put( - test_ctx.context_factory(input={'score': 1}, user=user), - post.post_id) - with fake_datetime('1997-12-02'): - result = test_ctx.api.delete( - test_ctx.context_factory(user=user), post.post_id) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostScore).count() == 0 - assert post.score == 0 + with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with fake_datetime('1997-12-01'): + result = api.post_api.set_post_score( + context_factory(params={'score': 1}, user=user), + {'post_id': post.post_id}) + with fake_datetime('1997-12-02'): + result = api.post_api.delete_post_score( + context_factory(user=user), + {'post_id': post.post_id}) + post = db.session.query(db.Post).one() + assert db.session.query(db.PostScore).count() == 0 + assert post.score == 0 -def test_ratings_from_multiple_users(test_ctx, fake_datetime): - user1 = test_ctx.user_factory() - user2 = test_ctx.user_factory() - post = test_ctx.post_factory() +def test_ratings_from_multiple_users( + user_factory, post_factory, context_factory, fake_datetime): + user1 = user_factory() + user2 = user_factory() + post = post_factory() db.session.add_all([user1, user2, post]) db.session.commit() - with fake_datetime('1997-12-01'): - result = test_ctx.api.put( - test_ctx.context_factory(input={'score': 1}, user=user1), - post.post_id) - with fake_datetime('1997-12-02'): - result = test_ctx.api.put( - test_ctx.context_factory(input={'score': -1}, user=user2), - post.post_id) - post = db.session.query(db.Post).one() - assert db.session.query(db.PostScore).count() == 2 - assert post.score == 0 + with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with fake_datetime('1997-12-01'): + result = api.post_api.set_post_score( + context_factory(params={'score': 1}, user=user1), + {'post_id': post.post_id}) + with fake_datetime('1997-12-02'): + result = api.post_api.set_post_score( + context_factory(params={'score': -1}, user=user2), + {'post_id': post.post_id}) + post = db.session.query(db.Post).one() + assert db.session.query(db.PostScore).count() == 2 + assert post.score == 0 -@pytest.mark.parametrize('input,expected_exception', [ - ({'score': None}, errors.ValidationError), - ({'score': ''}, errors.ValidationError), - ({'score': -2}, scores.InvalidScoreValueError), - ({'score': 2}, scores.InvalidScoreValueError), - ({'score': [1]}, errors.ValidationError), -]) -def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception): - post = test_ctx.post_factory() - db.session.add(post) - db.session.commit() - with pytest.raises(expected_exception): - test_ctx.api.put( - test_ctx.context_factory(input=input, user=test_ctx.user_factory()), - post.post_id) - -def test_trying_to_omit_mandatory_field(test_ctx): - post = test_ctx.post_factory() +def test_trying_to_omit_mandatory_field( + user_factory, post_factory, context_factory): + post = post_factory() db.session.add(post) db.session.commit() with pytest.raises(errors.ValidationError): - test_ctx.api.put( - test_ctx.context_factory(input={}, user=test_ctx.user_factory()), - post.post_id) + api.post_api.set_post_score( + context_factory(params={}, user=user_factory()), + {'post_id': post.post_id}) -def test_trying_to_update_non_existing(test_ctx): +def test_trying_to_update_non_existing( + user_factory, post_factory, context_factory): with pytest.raises(posts.PostNotFoundError): - test_ctx.api.put( - test_ctx.context_factory( - input={'score': 1}, - user=test_ctx.user_factory()), - 5) + api.post_api.set_post_score( + context_factory(params={'score': 1}, user=user_factory()), + {'post_id': 5}) -def test_trying_to_rate_without_privileges(test_ctx): - post = test_ctx.post_factory() +def test_trying_to_rate_without_privileges( + user_factory, post_factory, context_factory): + post = post_factory() db.session.add(post) db.session.commit() with pytest.raises(errors.AuthError): - test_ctx.api.put( - test_ctx.context_factory( - input={'score': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), - post.post_id) + api.post_api.set_post_score( + context_factory( + params={'score': 1}, + user=user_factory(rank=db.User.RANK_ANONYMOUS)), + {'post_id': post.post_id}) diff --git a/server/szurubooru/tests/api/test_post_retrieving.py b/server/szurubooru/tests/api/test_post_retrieving.py index d62d32f3..34583460 100644 --- a/server/szurubooru/tests/api/test_post_retrieving.py +++ b/server/szurubooru/tests/api/test_post_retrieving.py @@ -1,105 +1,97 @@ -import datetime import pytest +import unittest.mock +from datetime import datetime from szurubooru import api, db, errors -from szurubooru.func import util, posts +from szurubooru.func import posts -@pytest.fixture -def test_ctx( - tmpdir, context_factory, config_injector, user_factory, post_factory): +@pytest.fixture(autouse=True) +def inject_config(tmpdir, config_injector): config_injector({ - 'data_dir': str(tmpdir), - 'data_url': 'http://example.com', 'privileges': { 'posts:list': db.User.RANK_REGULAR, 'posts:view': db.User.RANK_REGULAR, }, - 'thumbnails': {'avatar_width': 200}, }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.post_factory = post_factory - ret.list_api = api.PostListApi() - ret.detail_api = api.PostDetailApi() - return ret -def test_retrieving_multiple(test_ctx): - post1 = test_ctx.post_factory(id=1) - post2 = test_ctx.post_factory(id=2) +def test_retrieving_multiple(user_factory, post_factory, context_factory): + post1 = post_factory(id=1) + post2 = post_factory(id=2) db.session.add_all([post1, post2]) - result = test_ctx.list_api.get( - test_ctx.context_factory( - input={'query': '', 'page': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - assert result['query'] == '' - assert result['page'] == 1 - assert result['pageSize'] == 100 - assert result['total'] == 2 - assert [t['id'] for t in result['results']] == [2, 1] + with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + posts.serialize_post.return_value = 'serialized post' + result = api.post_api.get_posts( + context_factory( + params={'query': '', 'page': 1}, + user=user_factory(rank=db.User.RANK_REGULAR))) + assert result == { + 'query': '', + 'page': 1, + 'pageSize': 100, + 'total': 2, + 'results': ['serialized post', 'serialized post'], + } -def test_using_special_tokens( - test_ctx, config_injector): - auth_user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) - post1 = test_ctx.post_factory(id=1) - post2 = test_ctx.post_factory(id=2) +def test_using_special_tokens(user_factory, post_factory, context_factory): + auth_user = user_factory(rank=db.User.RANK_REGULAR) + post1 = post_factory(id=1) + post2 = post_factory(id=2) post1.favorited_by = [db.PostFavorite( - user=auth_user, time=datetime.datetime.utcnow())] + user=auth_user, time=datetime.utcnow())] db.session.add_all([post1, post2, auth_user]) db.session.flush() - result = test_ctx.list_api.get( - test_ctx.context_factory( - input={'query': 'special:fav', 'page': 1}, - user=auth_user)) - assert result['query'] == 'special:fav' - assert result['page'] == 1 - assert result['pageSize'] == 100 - assert result['total'] == 1 - assert [t['id'] for t in result['results']] == [1] + with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + posts.serialize_post.side_effect = \ + lambda post, *_args, **_kwargs: \ + 'serialized post %d' % post.post_id + result = api.post_api.get_posts( + context_factory( + params={'query': 'special:fav', 'page': 1}, + user=auth_user)) + assert result == { + 'query': 'special:fav', + 'page': 1, + 'pageSize': 100, + 'total': 1, + 'results': ['serialized post 1'], + } def test_trying_to_use_special_tokens_without_logging_in( - test_ctx, config_injector): + user_factory, post_factory, context_factory, config_injector): config_injector({ 'privileges': {'posts:list': 'anonymous'}, }) with pytest.raises(errors.SearchError): - test_ctx.list_api.get( - test_ctx.context_factory( - input={'query': 'special:fav', 'page': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) + api.post_api.get_posts( + context_factory( + params={'query': 'special:fav', 'page': 1}, + user=user_factory(rank=db.User.RANK_ANONYMOUS))) -def test_trying_to_retrieve_multiple_without_privileges(test_ctx): +def test_trying_to_retrieve_multiple_without_privileges( + user_factory, context_factory): with pytest.raises(errors.AuthError): - test_ctx.list_api.get( - test_ctx.context_factory( - input={'query': '', 'page': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) + api.post_api.get_posts( + context_factory( + params={'query': '', 'page': 1}, + user=user_factory(rank=db.User.RANK_ANONYMOUS))) -def test_retrieving_single(test_ctx): - db.session.add(test_ctx.post_factory(id=1)) - result = test_ctx.detail_api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 1) - assert 'id' in result - assert 'snapshots' in result - assert 'comments' in result +def test_retrieving_single(user_factory, post_factory, context_factory): + db.session.add(post_factory(id=1)) + with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + posts.serialize_post.return_value = 'serialized post' + result = api.post_api.get_post( + context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + {'post_id': 1}) + assert result == 'serialized post' -def test_trying_to_retrieve_invalid_id(test_ctx): - with pytest.raises(posts.InvalidPostIdError): - test_ctx.detail_api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - '-') - -def test_trying_to_retrieve_single_non_existing(test_ctx): +def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): with pytest.raises(posts.PostNotFoundError): - test_ctx.detail_api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - '999') + api.post_api.get_post( + context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + {'post_id': 999}) -def test_trying_to_retrieve_single_without_privileges(test_ctx): +def test_trying_to_retrieve_single_without_privileges( + user_factory, context_factory): with pytest.raises(errors.AuthError): - test_ctx.detail_api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), - '999') + api.post_api.get_post( + context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + {'post_id': 999}) diff --git a/server/szurubooru/tests/api/test_post_updating.py b/server/szurubooru/tests/api/test_post_updating.py index 89e582d6..ce659f25 100644 --- a/server/szurubooru/tests/api/test_post_updating.py +++ b/server/szurubooru/tests/api/test_post_updating.py @@ -1,12 +1,11 @@ -import datetime -import os -import unittest.mock import pytest +import unittest.mock +from datetime import datetime from szurubooru import api, db, errors from szurubooru.func import posts, tags, snapshots, net -def test_post_updating( - config_injector, context_factory, post_factory, user_factory, fake_datetime): +@pytest.fixture(autouse=True) +def inject_config(tmpdir, config_injector): config_injector({ 'privileges': { 'posts:edit:tags': db.User.RANK_REGULAR, @@ -17,46 +16,49 @@ def test_post_updating( 'posts:edit:notes': db.User.RANK_REGULAR, 'posts:edit:flags': db.User.RANK_REGULAR, 'posts:edit:thumbnail': db.User.RANK_REGULAR, + 'tags:create': db.User.RANK_MODERATOR, }, }) + +def test_post_updating( + context_factory, post_factory, user_factory, fake_datetime): auth_user = user_factory(rank=db.User.RANK_REGULAR) post = post_factory() db.session.add(post) db.session.flush() with unittest.mock.patch('szurubooru.func.posts.create_post'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_tags'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_thumbnail'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_safety'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_source'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_relations'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_notes'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \ - unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ - unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'): - + unittest.mock.patch('szurubooru.func.posts.update_post_tags'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_thumbnail'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_safety'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_source'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_relations'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_notes'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \ + unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ + unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ + unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \ + fake_datetime('1997-01-01'): posts.serialize_post.return_value = 'serialized post' - with fake_datetime('1997-01-01'): - result = api.PostDetailApi().put( - context_factory( - input={ - 'version': 1, - 'safety': 'safe', - 'tags': ['tag1', 'tag2'], - 'relations': [1, 2], - 'source': 'source', - 'notes': ['note1', 'note2'], - 'flags': ['flag1', 'flag2'], - }, - files={ - 'content': 'post-content', - 'thumbnail': 'post-thumbnail', - }, - user=auth_user), - post.post_id) + result = api.post_api.update_post( + context_factory( + params={ + 'version': 1, + 'safety': 'safe', + 'tags': ['tag1', 'tag2'], + 'relations': [1, 2], + 'source': 'source', + 'notes': ['note1', 'note2'], + 'flags': ['flag1', 'flag2'], + }, + files={ + 'content': 'post-content', + 'thumbnail': 'post-thumbnail', + }, + user=auth_user), + {'post_id': post.post_id}) assert result == 'serialized post' posts.create_post.assert_not_called() @@ -71,71 +73,62 @@ def test_post_updating( posts.serialize_post.assert_called_once_with(post, auth_user, options=None) tags.export_to_json.assert_called_once_with() snapshots.save_entity_modification.assert_called_once_with(post, auth_user) - assert post.last_edit_time == datetime.datetime(1997, 1, 1) + assert post.last_edit_time == datetime(1997, 1, 1) def test_uploading_from_url_saves_source( - config_injector, context_factory, post_factory, user_factory): - config_injector({ - 'privileges': {'posts:edit:content': db.User.RANK_REGULAR}, - }) + context_factory, post_factory, user_factory): post = post_factory() db.session.add(post) db.session.flush() with unittest.mock.patch('szurubooru.func.net.download'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ - unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \ - unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_source'): + unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ + unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \ + unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_source'): net.download.return_value = b'content' - api.PostDetailApi().put( + api.post_api.update_post( context_factory( - input={'contentUrl': 'example.com', 'version': 1}, + params={'contentUrl': 'example.com', 'version': 1}, user=user_factory(rank=db.User.RANK_REGULAR)), - post.post_id) + {'post_id': post.post_id}) net.download.assert_called_once_with('example.com') posts.update_post_content.assert_called_once_with(post, b'content') posts.update_post_source.assert_called_once_with(post, 'example.com') def test_uploading_from_url_with_source_specified( - config_injector, context_factory, post_factory, user_factory): - config_injector({ - 'privileges': { - 'posts:edit:content': db.User.RANK_REGULAR, - 'posts:edit:source': db.User.RANK_REGULAR, - }, - }) + context_factory, post_factory, user_factory): post = post_factory() db.session.add(post) db.session.flush() with unittest.mock.patch('szurubooru.func.net.download'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ - unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \ - unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_source'): + unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ + unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \ + unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_source'): net.download.return_value = b'content' - api.PostDetailApi().put( + api.post_api.update_post( context_factory( - input={ + params={ 'contentUrl': 'example.com', 'source': 'example2.com', 'version': 1}, user=user_factory(rank=db.User.RANK_REGULAR)), - post.post_id) + {'post_id': post.post_id}) net.download.assert_called_once_with('example.com') posts.update_post_content.assert_called_once_with(post, b'content') posts.update_post_source.assert_called_once_with(post, 'example2.com') def test_trying_to_update_non_existing(context_factory, user_factory): with pytest.raises(posts.PostNotFoundError): - api.PostDetailApi().put( + api.post_api.update_post( context_factory( - input='whatever', + params='whatever', user=user_factory(rank=db.User.RANK_REGULAR)), - 1) + {'post_id': 1}) -@pytest.mark.parametrize('privilege,files,input', [ +@pytest.mark.parametrize('privilege,files,params', [ ('posts:edit:tags', {}, {'tags': '...'}), ('posts:edit:safety', {}, {'safety': '...'}), ('posts:edit:source', {}, {'source': '...'}), @@ -146,43 +139,28 @@ def test_trying_to_update_non_existing(context_factory, user_factory): ('posts:edit:thumbnail', {'thumbnail': '...'}, {}), ]) def test_trying_to_update_field_without_privileges( - config_injector, - context_factory, - post_factory, - user_factory, - files, - input, - privilege): - config_injector({ - 'privileges': {privilege: db.User.RANK_REGULAR}, - }) + context_factory, post_factory, user_factory, files, params, privilege): post = post_factory() db.session.add(post) db.session.flush() with pytest.raises(errors.AuthError): - api.PostDetailApi().put( + api.post_api.update_post( context_factory( - input={**input, **{'version': 1}}, + params={**params, **{'version': 1}}, files=files, user=user_factory(rank=db.User.RANK_ANONYMOUS)), - post.post_id) + {'post_id': post.post_id}) def test_trying_to_create_tags_without_privileges( - config_injector, context_factory, post_factory, user_factory): - config_injector({ - 'privileges': { - 'posts:edit:tags': db.User.RANK_REGULAR, - 'tags:create': db.User.RANK_ADMINISTRATOR, - }, - }) + context_factory, post_factory, user_factory): post = post_factory() db.session.add(post) db.session.flush() with pytest.raises(errors.AuthError), \ unittest.mock.patch('szurubooru.func.posts.update_post_tags'): posts.update_post_tags.return_value = ['new-tag'] - api.PostDetailApi().put( + api.post_api.update_post( context_factory( - input={'tags': ['tag1', 'tag2'], 'version': 1}, + params={'tags': ['tag1', 'tag2'], 'version': 1}, user=user_factory(rank=db.User.RANK_REGULAR)), - post.post_id) + {'post_id': post.post_id}) diff --git a/server/szurubooru/tests/api/test_snapshot_retrieving.py b/server/szurubooru/tests/api/test_snapshot_retrieving.py index aeeb8294..c7d917a9 100644 --- a/server/szurubooru/tests/api/test_snapshot_retrieving.py +++ b/server/szurubooru/tests/api/test_snapshot_retrieving.py @@ -1,11 +1,10 @@ -import datetime import pytest +from datetime import datetime from szurubooru import api, db, errors -from szurubooru.func import util, tags def snapshot_factory(): snapshot = db.Snapshot() - snapshot.creation_time = datetime.datetime(1999, 1, 1) + snapshot.creation_time = datetime(1999, 1, 1) snapshot.resource_type = 'dummy' snapshot.resource_id = 1 snapshot.resource_repr = 'dummy' @@ -13,37 +12,30 @@ def snapshot_factory(): snapshot.data = '{}' return snapshot -@pytest.fixture -def test_ctx(context_factory, config_injector, user_factory): +@pytest.fixture(autouse=True) +def inject_config(config_injector): config_injector({ - 'privileges': { - 'snapshots:list': db.User.RANK_REGULAR, - }, - 'thumbnails': {'avatar_width': 200}, + 'privileges': {'snapshots:list': db.User.RANK_REGULAR}, }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.api = api.SnapshotListApi() - return ret -def test_retrieving_multiple(test_ctx): +def test_retrieving_multiple(user_factory, context_factory): snapshot1 = snapshot_factory() snapshot2 = snapshot_factory() db.session.add_all([snapshot1, snapshot2]) - result = test_ctx.api.get( - test_ctx.context_factory( - input={'query': '', 'page': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) + result = api.snapshot_api.get_snapshots( + context_factory( + params={'query': '', 'page': 1}, + user=user_factory(rank=db.User.RANK_REGULAR))) assert result['query'] == '' assert result['page'] == 1 assert result['pageSize'] == 100 assert result['total'] == 2 assert len(result['results']) == 2 -def test_trying_to_retrieve_multiple_without_privileges(test_ctx): +def test_trying_to_retrieve_multiple_without_privileges( + user_factory, context_factory): with pytest.raises(errors.AuthError): - test_ctx.api.get( - test_ctx.context_factory( - input={'query': '', 'page': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) + api.snapshot_api.get_snapshots( + context_factory( + params={'query': '', 'page': 1}, + user=user_factory(rank=db.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_tag_category_creating.py b/server/szurubooru/tests/api/test_tag_category_creating.py index f2f61a56..576fe4f5 100644 --- a/server/szurubooru/tests/api/test_tag_category_creating.py +++ b/server/szurubooru/tests/api/test_tag_category_creating.py @@ -1,94 +1,50 @@ -import os import pytest -from szurubooru import api, config, db, errors -from szurubooru.func import util, tag_categories +import unittest.mock +from szurubooru import api, db, errors +from szurubooru.func import tag_categories, tags -@pytest.fixture -def test_ctx(tmpdir, config_injector, context_factory, user_factory): +def _update_category_name(category, name): + category.name = name + +@pytest.fixture(autouse=True) +def inject_config(config_injector): config_injector({ - 'data_dir': str(tmpdir), - 'tag_category_name_regex': '^[^!]+$', 'privileges': {'tag_categories:create': db.User.RANK_REGULAR}, }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.api = api.TagCategoryListApi() - return ret -def test_creating_category(test_ctx): - result = test_ctx.api.post( - test_ctx.context_factory( - input={'name': 'meta', 'color': 'black'}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - assert len(result['snapshots']) == 1 - del result['snapshots'] - assert result == { - 'name': 'meta', - 'color': 'black', - 'usages': 0, - 'default': True, - 'version': 1, - } - category = db.session.query(db.TagCategory).one() - assert category.name == 'meta' - assert category.color == 'black' - assert category.tag_count == 0 - assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json')) - -@pytest.mark.parametrize('input', [ - {'name': None}, - {'name': ''}, - {'name': '!bad'}, - {'color': None}, - {'color': ''}, - {'color': 'a' * 100}, -]) -def test_trying_to_pass_invalid_input(test_ctx, input): - real_input = { - 'name': 'okay', - 'color': 'okay', - } - for key, value in input.items(): - real_input[key] = value - with pytest.raises(errors.ValidationError): - test_ctx.api.post( - test_ctx.context_factory( - input=real_input, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) +def test_creating_category(user_factory, context_factory): + with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \ + unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \ + unittest.mock.patch('szurubooru.func.tags.export_to_json'): + tag_categories.update_category_name.side_effect = _update_category_name + tag_categories.serialize_category.return_value = 'serialized category' + result = api.tag_category_api.create_tag_category( + context_factory( + params={'name': 'meta', 'color': 'black'}, + user=user_factory(rank=db.User.RANK_REGULAR))) + assert result == 'serialized category' + category = db.session.query(db.TagCategory).one() + assert category.name == 'meta' + assert category.color == 'black' + assert category.tag_count == 0 + tags.export_to_json.assert_called_once_with() @pytest.mark.parametrize('field', ['name', 'color']) -def test_trying_to_omit_mandatory_field(test_ctx, field): - input = { +def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): + params = { 'name': 'meta', 'color': 'black', } - del input[field] + del params[field] with pytest.raises(errors.ValidationError): - test_ctx.api.post( - test_ctx.context_factory( - input=input, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) + api.tag_category_api.create_tag_category( + context_factory( + params=params, + user=user_factory(rank=db.User.RANK_REGULAR))) -def test_trying_to_use_existing_name(test_ctx): - result = test_ctx.api.post( - test_ctx.context_factory( - input={'name': 'meta', 'color': 'black'}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - with pytest.raises(tag_categories.TagCategoryAlreadyExistsError): - result = test_ctx.api.post( - test_ctx.context_factory( - input={'name': 'meta', 'color': 'black'}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - with pytest.raises(tag_categories.TagCategoryAlreadyExistsError): - result = test_ctx.api.post( - test_ctx.context_factory( - input={'name': 'META', 'color': 'black'}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - -def test_trying_to_create_without_privileges(test_ctx): +def test_trying_to_create_without_privileges(user_factory, context_factory): with pytest.raises(errors.AuthError): - test_ctx.api.post( - test_ctx.context_factory( - input={'name': 'meta', 'color': 'black'}, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) + api.tag_category_api.create_tag_category( + context_factory( + params={'name': 'meta', 'color': 'black'}, + user=user_factory(rank=db.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_tag_category_deleting.py b/server/szurubooru/tests/api/test_tag_category_deleting.py index 8986ec72..4cbd6437 100644 --- a/server/szurubooru/tests/api/test_tag_category_deleting.py +++ b/server/szurubooru/tests/api/test_tag_category_deleting.py @@ -1,84 +1,70 @@ import pytest -import os -from datetime import datetime -from szurubooru import api, config, db, errors -from szurubooru.func import util, tags, tag_categories +import unittest.mock +from szurubooru import api, db, errors +from szurubooru.func import tag_categories, tags -@pytest.fixture -def test_ctx( - tmpdir, - config_injector, - context_factory, - tag_factory, - tag_category_factory, - user_factory): +@pytest.fixture(autouse=True) +def inject_config(config_injector): config_injector({ - 'data_dir': str(tmpdir), - 'privileges': { - 'tag_categories:delete': db.User.RANK_REGULAR, - }, + 'privileges': {'tag_categories:delete': db.User.RANK_REGULAR}, }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.tag_factory = tag_factory - ret.tag_category_factory = tag_category_factory - ret.api = api.TagCategoryDetailApi() - return ret -def test_deleting(test_ctx): - db.session.add(test_ctx.tag_category_factory(name='root')) - db.session.add(test_ctx.tag_category_factory(name='category')) +def test_deleting(user_factory, tag_category_factory, context_factory): + db.session.add(tag_category_factory(name='root')) + db.session.add(tag_category_factory(name='category')) db.session.commit() - result = test_ctx.api.delete( - test_ctx.context_factory( - input={'version': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'category') - assert result == {} - assert db.session.query(db.TagCategory).count() == 1 - assert db.session.query(db.TagCategory).one().name == 'root' - assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json')) + with unittest.mock.patch('szurubooru.func.tags.export_to_json'): + result = api.tag_category_api.delete_tag_category( + context_factory( + params={'version': 1}, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'category_name': 'category'}) + assert result == {} + assert db.session.query(db.TagCategory).count() == 1 + assert db.session.query(db.TagCategory).one().name == 'root' + tags.export_to_json.assert_called_once_with() -def test_trying_to_delete_used(test_ctx, tag_factory): - category = test_ctx.tag_category_factory(name='category') +def test_trying_to_delete_used( + user_factory, tag_category_factory, tag_factory, context_factory): + category = tag_category_factory(name='category') db.session.add(category) db.session.flush() - tag = test_ctx.tag_factory(names=['tag'], category=category) + tag = tag_factory(names=['tag'], category=category) db.session.add(tag) db.session.commit() with pytest.raises(tag_categories.TagCategoryIsInUseError): - test_ctx.api.delete( - test_ctx.context_factory( - input={'version': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'category') + api.tag_category_api.delete_tag_category( + context_factory( + params={'version': 1}, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'category_name': 'category'}) assert db.session.query(db.TagCategory).count() == 1 -def test_trying_to_delete_last(test_ctx, tag_factory): - db.session.add(test_ctx.tag_category_factory(name='root')) +def test_trying_to_delete_last( + user_factory, tag_category_factory, context_factory): + db.session.add(tag_category_factory(name='root')) db.session.commit() with pytest.raises(tag_categories.TagCategoryIsInUseError): - result = test_ctx.api.delete( - test_ctx.context_factory( - input={'version': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'root') + api.tag_category_api.delete_tag_category( + context_factory( + params={'version': 1}, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'category_name': 'root'}) -def test_trying_to_delete_non_existing(test_ctx): +def test_trying_to_delete_non_existing(user_factory, context_factory): with pytest.raises(tag_categories.TagCategoryNotFoundError): - test_ctx.api.delete( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'bad') + api.tag_category_api.delete_tag_category( + context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + {'category_name': 'bad'}) -def test_trying_to_delete_without_privileges(test_ctx): - db.session.add(test_ctx.tag_category_factory(name='category')) +def test_trying_to_delete_without_privileges( + user_factory, tag_category_factory, context_factory): + db.session.add(tag_category_factory(name='category')) db.session.commit() with pytest.raises(errors.AuthError): - test_ctx.api.delete( - test_ctx.context_factory( - input={'version': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), - 'category') + api.tag_category_api.delete_tag_category( + context_factory( + params={'version': 1}, + user=user_factory(rank=db.User.RANK_ANONYMOUS)), + {'category_name': 'category'}) assert db.session.query(db.TagCategory).count() == 1 diff --git a/server/szurubooru/tests/api/test_tag_category_retrieving.py b/server/szurubooru/tests/api/test_tag_category_retrieving.py index 0ef5e751..3bfc115a 100644 --- a/server/szurubooru/tests/api/test_tag_category_retrieving.py +++ b/server/szurubooru/tests/api/test_tag_category_retrieving.py @@ -1,42 +1,31 @@ -import datetime import pytest from szurubooru import api, db, errors -from szurubooru.func import util, tag_categories +from szurubooru.func import tag_categories -@pytest.fixture -def test_ctx( - context_factory, config_injector, user_factory, tag_category_factory): +@pytest.fixture(autouse=True) +def inject_config(config_injector): config_injector({ 'privileges': { 'tag_categories:list': db.User.RANK_REGULAR, 'tag_categories:view': db.User.RANK_REGULAR, }, - 'thumbnails': {'avatar_width': 200}, }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.tag_category_factory = tag_category_factory - ret.list_api = api.TagCategoryListApi() - ret.detail_api = api.TagCategoryDetailApi() - return ret -def test_retrieving_multiple(test_ctx): +def test_retrieving_multiple( + user_factory, tag_category_factory, context_factory): db.session.add_all([ - test_ctx.tag_category_factory(name='c1'), - test_ctx.tag_category_factory(name='c2'), + tag_category_factory(name='c1'), + tag_category_factory(name='c2'), ]) - result = test_ctx.list_api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) + result = api.tag_category_api.get_tag_categories( + context_factory(user=user_factory(rank=db.User.RANK_REGULAR))) assert [cat['name'] for cat in result['results']] == ['c1', 'c2'] -def test_retrieving_single(test_ctx): - db.session.add(test_ctx.tag_category_factory(name='cat')) - result = test_ctx.detail_api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'cat') +def test_retrieving_single(user_factory, tag_category_factory, context_factory): + db.session.add(tag_category_factory(name='cat')) + result = api.tag_category_api.get_tag_category( + context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + {'category_name': 'cat'}) assert result == { 'name': 'cat', 'color': 'dummy', @@ -46,16 +35,15 @@ def test_retrieving_single(test_ctx): 'version': 1, } -def test_trying_to_retrieve_single_non_existing(test_ctx): +def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): with pytest.raises(tag_categories.TagCategoryNotFoundError): - test_ctx.detail_api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - '-') + api.tag_category_api.get_tag_category( + context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + {'category_name': '-'}) -def test_trying_to_retrieve_single_without_privileges(test_ctx): +def test_trying_to_retrieve_single_without_privileges( + user_factory, context_factory): with pytest.raises(errors.AuthError): - test_ctx.detail_api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), - '-') + api.tag_category_api.get_tag_category( + context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + {'category_name': '-'}) diff --git a/server/szurubooru/tests/api/test_tag_category_updating.py b/server/szurubooru/tests/api/test_tag_category_updating.py index 78b09961..af10fbde 100644 --- a/server/szurubooru/tests/api/test_tag_category_updating.py +++ b/server/szurubooru/tests/api/test_tag_category_updating.py @@ -1,137 +1,104 @@ -import os import pytest -from szurubooru import api, config, db, errors -from szurubooru.func import util, tag_categories +import unittest.mock +from szurubooru import api, db, errors +from szurubooru.func import tag_categories, tags -@pytest.fixture -def test_ctx( - tmpdir, - config_injector, - context_factory, - user_factory, - tag_category_factory): +def _update_category_name(category, name): + category.name = name + +@pytest.fixture(autouse=True) +def inject_config(config_injector): config_injector({ - 'data_dir': str(tmpdir), - 'tag_category_name_regex': '^[^!]*$', 'privileges': { 'tag_categories:edit:name': db.User.RANK_REGULAR, 'tag_categories:edit:color': db.User.RANK_REGULAR, + 'tag_categories:set_default': db.User.RANK_REGULAR, }, }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.tag_category_factory = tag_category_factory - ret.api = api.TagCategoryDetailApi() - return ret -def test_simple_updating(test_ctx): - category = test_ctx.tag_category_factory(name='name', color='black') +def test_simple_updating(user_factory, tag_category_factory, context_factory): + category = tag_category_factory(name='name', color='black') db.session.add(category) db.session.commit() - result = test_ctx.api.put( - test_ctx.context_factory( - input={ - 'name': 'changed', - 'color': 'white', - 'version': 1, - }, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'name') - assert len(result['snapshots']) == 1 - del result['snapshots'] - assert result == { - 'name': 'changed', - 'color': 'white', - 'usages': 0, - 'default': False, - 'version': 2, - } - assert tag_categories.try_get_category_by_name('name') is None - category = tag_categories.get_category_by_name('changed') - assert category is not None - assert category.name == 'changed' - assert category.color == 'white' - assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json')) - -@pytest.mark.parametrize('input,expected_exception', [ - ({'name': None}, tag_categories.InvalidTagCategoryNameError), - ({'name': ''}, tag_categories.InvalidTagCategoryNameError), - ({'name': '!bad'}, tag_categories.InvalidTagCategoryNameError), - ({'color': None}, tag_categories.InvalidTagCategoryColorError), - ({'color': ''}, tag_categories.InvalidTagCategoryColorError), - ({'color': '; float:left'}, tag_categories.InvalidTagCategoryColorError), -]) -def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception): - db.session.add(test_ctx.tag_category_factory(name='meta', color='black')) - db.session.commit() - with pytest.raises(expected_exception): - test_ctx.api.put( - test_ctx.context_factory( - input={**input, **{'version': 1}}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'meta') + with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \ + unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \ + unittest.mock.patch('szurubooru.func.tag_categories.update_category_color'), \ + unittest.mock.patch('szurubooru.func.tags.export_to_json'): + tag_categories.update_category_name.side_effect = _update_category_name + tag_categories.serialize_category.return_value = 'serialized category' + result = api.tag_category_api.update_tag_category( + context_factory( + params={ + 'name': 'changed', + 'color': 'white', + 'version': 1, + }, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'category_name': 'name'}) + assert result == 'serialized category' + tag_categories.update_category_name.assert_called_once_with(category, 'changed') + tag_categories.update_category_color.assert_called_once_with(category, 'white') + tags.export_to_json.assert_called_once_with() @pytest.mark.parametrize('field', ['name', 'color']) -def test_omitting_optional_field(test_ctx, field): - db.session.add(test_ctx.tag_category_factory(name='name', color='black')) +def test_omitting_optional_field( + user_factory, tag_category_factory, context_factory, field): + db.session.add(tag_category_factory(name='name', color='black')) db.session.commit() - input = { + params = { 'name': 'changed', 'color': 'white', } - del input[field] - result = test_ctx.api.put( - test_ctx.context_factory( - input={**input, **{'version': 1}}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'name') - assert result is not None + del params[field] + with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \ + unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \ + unittest.mock.patch('szurubooru.func.tags.export_to_json'): + api.tag_category_api.update_tag_category( + context_factory( + params={**params, **{'version': 1}}, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'category_name': 'name'}) -def test_trying_to_update_non_existing(test_ctx): +def test_trying_to_update_non_existing(user_factory, context_factory): with pytest.raises(tag_categories.TagCategoryNotFoundError): - test_ctx.api.put( - test_ctx.context_factory( - input={'name': ['dummy']}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'bad') + api.tag_category_api.update_tag_category( + context_factory( + params={'name': ['dummy']}, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'category_name': 'bad'}) -@pytest.mark.parametrize('new_name', ['cat', 'CAT']) -def test_reusing_own_name(test_ctx, new_name): - db.session.add(test_ctx.tag_category_factory(name='cat', color='black')) - db.session.commit() - result = test_ctx.api.put( - test_ctx.context_factory( - input={'name': new_name, 'version': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'cat') - assert result['name'] == new_name - category = tag_categories.get_category_by_name('cat') - assert category.name == new_name - -@pytest.mark.parametrize('dup_name', ['cat1', 'CAT1']) -def test_trying_to_use_existing_name(test_ctx, dup_name): - db.session.add_all([ - test_ctx.tag_category_factory(name='cat1', color='black'), - test_ctx.tag_category_factory(name='cat2', color='black')]) - db.session.commit() - with pytest.raises(tag_categories.TagCategoryAlreadyExistsError): - test_ctx.api.put( - test_ctx.context_factory( - input={'name': dup_name, 'version': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'cat2') - -@pytest.mark.parametrize('input', [ +@pytest.mark.parametrize('params', [ {'name': 'whatever'}, {'color': 'whatever'}, ]) -def test_trying_to_update_without_privileges(test_ctx, input): - db.session.add(test_ctx.tag_category_factory(name='dummy')) +def test_trying_to_update_without_privileges( + user_factory, tag_category_factory, context_factory, params): + db.session.add(tag_category_factory(name='dummy')) db.session.commit() with pytest.raises(errors.AuthError): - test_ctx.api.put( - test_ctx.context_factory( - input={**input, **{'version': 1}}, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), - 'dummy') + api.tag_category_api.update_tag_category( + context_factory( + params={**params, **{'version': 1}}, + user=user_factory(rank=db.User.RANK_ANONYMOUS)), + {'category_name': 'dummy'}) + +def test_set_as_default(user_factory, tag_category_factory, context_factory): + category = tag_category_factory(name='name', color='black') + db.session.add(category) + db.session.commit() + with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \ + unittest.mock.patch('szurubooru.func.tag_categories.set_default_category'), \ + unittest.mock.patch('szurubooru.func.tags.export_to_json'): + tag_categories.update_category_name.side_effect = _update_category_name + tag_categories.serialize_category.return_value = 'serialized category' + result = api.tag_category_api.set_tag_category_as_default( + context_factory( + params={ + 'name': 'changed', + 'color': 'white', + 'version': 1, + }, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'category_name': 'name'}) + assert result == 'serialized category' + tag_categories.set_default_category.assert_called_once_with(category) diff --git a/server/szurubooru/tests/api/test_tag_creating.py b/server/szurubooru/tests/api/test_tag_creating.py index 31120541..2076f5d4 100644 --- a/server/szurubooru/tests/api/test_tag_creating.py +++ b/server/szurubooru/tests/api/test_tag_creating.py @@ -1,187 +1,77 @@ -import datetime -import os import pytest -from szurubooru import api, config, db, errors -from szurubooru.func import util, tags, tag_categories, cache +import unittest.mock +from szurubooru import api, db, errors +from szurubooru.func import tags, tag_categories -def assert_relations(relations, expected_tag_names): - actual_names = sorted([rel.names[0].name for rel in relations]) - assert actual_names == sorted(expected_tag_names) +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({'privileges': {'tags:create': db.User.RANK_REGULAR}}) -@pytest.fixture -def test_ctx( - tmpdir, config_injector, context_factory, user_factory, tag_factory): - config_injector({ - 'data_dir': str(tmpdir), - 'tag_name_regex': '^[^!]*$', - 'tag_category_name_regex': '^[^!]*$', - 'privileges': {'tags:create': db.User.RANK_REGULAR}, - }) - db.session.add_all([ - db.TagCategory(name) for name in ['meta', 'character', 'copyright']]) - db.session.flush() - cache.purge() - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.tag_factory = tag_factory - ret.api = api.TagListApi() - return ret - -def test_creating_simple_tags(test_ctx, fake_datetime): - with fake_datetime('1997-12-01'): - result = test_ctx.api.post( - test_ctx.context_factory( - input={ +def test_creating_simple_tags(tag_factory, user_factory, context_factory): + with unittest.mock.patch('szurubooru.func.tags.create_tag'), \ + unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'), \ + unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ + unittest.mock.patch('szurubooru.func.tags.export_to_json'): + tags.get_or_create_tags_by_names.return_value = ([], []) + tags.create_tag.return_value = tag_factory() + tags.serialize_tag.return_value = 'serialized tag' + result = api.tag_api.create_tag( + context_factory( + params={ 'names': ['tag1', 'tag2'], 'category': 'meta', 'description': 'desc', - 'suggestions': [], - 'implications': [], + 'suggestions': ['sug1', 'sug2'], + 'implications': ['imp1', 'imp2'], }, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - assert len(result['snapshots']) == 1 - del result['snapshots'] - assert result == { - 'names': ['tag1', 'tag2'], - 'category': 'meta', - 'description': 'desc', - 'suggestions': [], - 'implications': [], - 'creationTime': datetime.datetime(1997, 12, 1), - 'lastEditTime': None, - 'usages': 0, - 'version': 1, - } - tag = tags.get_tag_by_name('tag1') - assert [tag_name.name for tag_name in tag.names] == ['tag1', 'tag2'] - assert tag.category.name == 'meta' - assert tag.last_edit_time is None - assert tag.post_count == 0 - assert_relations(tag.suggestions, []) - assert_relations(tag.implications, []) - assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json')) - -@pytest.mark.parametrize('input,expected_exception', [ - ({'names': None}, tags.InvalidTagNameError), - ({'names': []}, tags.InvalidTagNameError), - ({'names': [None]}, tags.InvalidTagNameError), - ({'names': ['']}, tags.InvalidTagNameError), - ({'names': ['!bad']}, tags.InvalidTagNameError), - ({'names': ['x' * 65]}, tags.InvalidTagNameError), - ({'category': None}, tag_categories.TagCategoryNotFoundError), - ({'category': ''}, tag_categories.TagCategoryNotFoundError), - ({'category': '!bad'}, tag_categories.TagCategoryNotFoundError), - ({'suggestions': ['good', '!bad']}, tags.InvalidTagNameError), - ({'implications': ['good', '!bad']}, tags.InvalidTagNameError), -]) -def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception): - real_input={ - 'names': ['tag1', 'tag2'], - 'category': 'meta', - 'suggestions': [], - 'implications': [], - } - for key, value in input.items(): - real_input[key] = value - with pytest.raises(expected_exception): - test_ctx.api.post( - test_ctx.context_factory( - input=real_input, - user=test_ctx.user_factory())) + user=user_factory(rank=db.User.RANK_REGULAR))) + assert result == 'serialized tag' + tags.create_tag.assert_called_once_with( + ['tag1', 'tag2'], 'meta', ['sug1', 'sug2'], ['imp1', 'imp2']) + tags.export_to_json.assert_called_once_with() @pytest.mark.parametrize('field', ['names', 'category']) -def test_trying_to_omit_mandatory_field(test_ctx, field): - input = { +def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): + params = { 'names': ['tag1', 'tag2'], 'category': 'meta', 'suggestions': [], 'implications': [], } - del input[field] + del params[field] with pytest.raises(errors.ValidationError): - test_ctx.api.post( - test_ctx.context_factory( - input=input, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) + api.tag_api.create_tag( + context_factory( + params=params, + user=user_factory(rank=db.User.RANK_REGULAR))) @pytest.mark.parametrize('field', ['implications', 'suggestions']) -def test_omitting_optional_field(test_ctx, field): - input = { +def test_omitting_optional_field( + tag_factory, user_factory, context_factory, field): + params = { 'names': ['tag1', 'tag2'], 'category': 'meta', 'suggestions': [], 'implications': [], } - del input[field] - result = test_ctx.api.post( - test_ctx.context_factory( - input=input, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - assert result is not None + del params[field] + with unittest.mock.patch('szurubooru.func.tags.create_tag'), \ + unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ + unittest.mock.patch('szurubooru.func.tags.export_to_json'): + tags.create_tag.return_value = tag_factory() + api.tag_api.create_tag( + context_factory( + params=params, + user=user_factory(rank=db.User.RANK_REGULAR))) -def test_creating_new_category(test_ctx): - with pytest.raises(tag_categories.TagCategoryNotFoundError): - test_ctx.api.post( - test_ctx.context_factory( - input={ - 'names': ['main'], - 'category': 'new', - 'suggestions': [], - 'implications': [], - }, user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - -@pytest.mark.parametrize('input,expected_suggestions,expected_implications', [ - # new relations - ({ - 'names': ['main'], - 'category': 'meta', - 'suggestions': ['sug1', 'sug2'], - 'implications': ['imp1', 'imp2'], - }, ['sug1', 'sug2'], ['imp1', 'imp2']), - # overlapping relations - ({ - 'names': ['main'], - 'category': 'meta', - 'suggestions': ['sug', 'shared'], - 'implications': ['shared', 'imp'], - }, ['shared', 'sug'], ['imp', 'shared']), - # duplicate relations - ({ - 'names': ['main'], - 'category': 'meta', - 'suggestions': ['sug', 'SUG'], - 'implications': ['imp', 'IMP'], - }, ['sug'], ['imp']), - # overlapping duplicate relations - ({ - 'names': ['main'], - 'category': 'meta', - 'suggestions': ['shared1', 'shared2'], - 'implications': ['SHARED1', 'SHARED2'], - }, ['shared1', 'shared2'], ['shared1', 'shared2']), -]) -def test_creating_new_suggestions_and_implications( - test_ctx, input, expected_suggestions, expected_implications): - result = test_ctx.api.post( - test_ctx.context_factory( - input=input, user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - assert result['suggestions'] == expected_suggestions - assert result['implications'] == expected_implications - tag = tags.get_tag_by_name('main') - assert_relations(tag.suggestions, expected_suggestions) - assert_relations(tag.implications, expected_implications) - for name in ['main'] + expected_suggestions + expected_implications: - assert tags.try_get_tag_by_name(name) is not None - -def test_trying_to_create_tag_without_privileges(test_ctx): +def test_trying_to_create_tag_without_privileges(user_factory, context_factory): with pytest.raises(errors.AuthError): - test_ctx.api.post( - test_ctx.context_factory( - input={ + api.tag_api.create_tag( + context_factory( + params={ 'names': ['tag'], 'category': 'meta', 'suggestions': ['tag'], 'implications': [], }, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=db.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_tag_deleting.py b/server/szurubooru/tests/api/test_tag_deleting.py index 350fc0a9..98ff9cf5 100644 --- a/server/szurubooru/tests/api/test_tag_deleting.py +++ b/server/szurubooru/tests/api/test_tag_deleting.py @@ -1,50 +1,55 @@ import pytest -import os -from datetime import datetime -from szurubooru import api, config, db, errors -from szurubooru.func import util, tags +import unittest.mock +from szurubooru import api, db, errors +from szurubooru.func import tags -@pytest.fixture -def test_ctx( - tmpdir, config_injector, context_factory, tag_factory, user_factory): - config_injector({ - 'data_dir': str(tmpdir), - 'privileges': { - 'tags:delete': db.User.RANK_REGULAR, - }, - }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.tag_factory = tag_factory - ret.api = api.TagDetailApi() - return ret +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({'privileges': {'tags:delete': db.User.RANK_REGULAR}}) -def test_deleting(test_ctx): - db.session.add(test_ctx.tag_factory(names=['tag'])) +def test_deleting(user_factory, tag_factory, context_factory): + db.session.add(tag_factory(names=['tag'])) db.session.commit() - result = test_ctx.api.delete( - test_ctx.context_factory( - input={'version': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'tag') - assert result == {} - assert db.session.query(db.Tag).count() == 0 - assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json')) + with unittest.mock.patch('szurubooru.func.tags.export_to_json'): + result = api.tag_api.delete_tag( + context_factory( + params={'version': 1}, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'tag_name': 'tag'}) + assert result == {} + assert db.session.query(db.Tag).count() == 0 + tags.export_to_json.assert_called_once_with() -def test_trying_to_delete_non_existing(test_ctx): +def test_deleting_used(user_factory, tag_factory, context_factory, post_factory): + tag = tag_factory(names=['tag']) + post = post_factory() + post.tags.append(tag) + db.session.add_all([tag, post]) + db.session.commit() + with unittest.mock.patch('szurubooru.func.tags.export_to_json'): + api.tag_api.delete_tag( + context_factory( + params={'version': 1}, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'tag_name': 'tag'}) + db.session.refresh(post) + assert db.session.query(db.Tag).count() == 0 + assert post.tags == [] + +def test_trying_to_delete_non_existing(user_factory, context_factory): with pytest.raises(tags.TagNotFoundError): - test_ctx.api.delete( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'bad') + api.tag_api.delete_tag( + context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + {'tag_name': 'bad'}) -def test_trying_to_delete_without_privileges(test_ctx): - db.session.add(test_ctx.tag_factory(names=['tag'])) +def test_trying_to_delete_without_privileges( + user_factory, tag_factory, context_factory): + db.session.add(tag_factory(names=['tag'])) db.session.commit() with pytest.raises(errors.AuthError): - test_ctx.api.delete( - test_ctx.context_factory( - input={'version': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), - 'tag') + api.tag_api.delete_tag( + context_factory( + params={'version': 1}, + user=user_factory(rank=db.User.RANK_ANONYMOUS)), + {'tag_name': 'tag'}) assert db.session.query(db.Tag).count() == 1 diff --git a/server/szurubooru/tests/api/test_tag_merging.py b/server/szurubooru/tests/api/test_tag_merging.py index 0a642406..90fa4d83 100644 --- a/server/szurubooru/tests/api/test_tag_merging.py +++ b/server/szurubooru/tests/api/test_tag_merging.py @@ -1,34 +1,15 @@ -import datetime -import os import pytest -from szurubooru import api, config, db, errors -from szurubooru.func import util, tags +import unittest.mock +from szurubooru import api, db, errors +from szurubooru.func import tags -@pytest.fixture -def test_ctx( - tmpdir, - config_injector, - context_factory, - user_factory, - tag_factory, - tag_category_factory): - config_injector({ - 'data_dir': str(tmpdir), - 'privileges': { - 'tags:merge': db.User.RANK_REGULAR, - }, - }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.tag_factory = tag_factory - ret.tag_category_factory = tag_category_factory - ret.api = api.TagMergeApi() - return ret +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({'privileges': {'tags:merge': db.User.RANK_REGULAR}}) -def test_merging_with_usages(test_ctx, fake_datetime, post_factory): - source_tag = test_ctx.tag_factory(names=['source']) - target_tag = test_ctx.tag_factory(names=['target']) +def test_merging(user_factory, tag_factory, context_factory, post_factory): + source_tag = tag_factory(names=['source']) + target_tag = tag_factory(names=['target']) db.session.add_all([source_tag, target_tag]) db.session.flush() assert source_tag.post_count == 0 @@ -39,73 +20,78 @@ def test_merging_with_usages(test_ctx, fake_datetime, post_factory): db.session.commit() assert source_tag.post_count == 1 assert target_tag.post_count == 0 - with fake_datetime('1997-12-01'): - result = test_ctx.api.post( - test_ctx.context_factory( - input={ + with unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ + unittest.mock.patch('szurubooru.func.tags.merge_tags'), \ + unittest.mock.patch('szurubooru.func.tags.export_to_json'): + result = api.tag_api.merge_tags( + context_factory( + params={ 'removeVersion': 1, 'mergeToVersion': 1, 'remove': 'source', 'mergeTo': 'target', }, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - assert tags.try_get_tag_by_name('source') is None - assert tags.get_tag_by_name('target').post_count == 1 + user=user_factory(rank=db.User.RANK_REGULAR))) + tags.merge_tags.called_once_with(source_tag, target_tag) + tags.export_to_json.assert_called_once_with() @pytest.mark.parametrize( 'field', ['remove', 'mergeTo', 'removeVersion', 'mergeToVersion']) -def test_trying_to_omit_mandatory_field(test_ctx, field): +def test_trying_to_omit_mandatory_field( + user_factory, tag_factory, context_factory, field): db.session.add_all([ - test_ctx.tag_factory(names=['source']), - test_ctx.tag_factory(names=['target']), + tag_factory(names=['source']), + tag_factory(names=['target']), ]) db.session.commit() - input = { + params = { 'removeVersion': 1, 'mergeToVersion': 1, 'remove': 'source', 'mergeTo': 'target', } - del input[field] + del params[field] with pytest.raises(errors.ValidationError): - test_ctx.api.post( - test_ctx.context_factory( - input=input, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) + api.tag_api.merge_tags( + context_factory( + params=params, + user=user_factory(rank=db.User.RANK_REGULAR))) -def test_trying_to_merge_non_existing(test_ctx): - db.session.add(test_ctx.tag_factory(names=['good'])) +def test_trying_to_merge_non_existing( + user_factory, tag_factory, context_factory): + db.session.add(tag_factory(names=['good'])) db.session.commit() with pytest.raises(tags.TagNotFoundError): - test_ctx.api.post( - test_ctx.context_factory( - input={'remove': 'good', 'mergeTo': 'bad'}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) + api.tag_api.merge_tags( + context_factory( + params={'remove': 'good', 'mergeTo': 'bad'}, + user=user_factory(rank=db.User.RANK_REGULAR))) with pytest.raises(tags.TagNotFoundError): - test_ctx.api.post( - test_ctx.context_factory( - input={'remove': 'bad', 'mergeTo': 'good'}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) + api.tag_api.merge_tags( + context_factory( + params={'remove': 'bad', 'mergeTo': 'good'}, + user=user_factory(rank=db.User.RANK_REGULAR))) -@pytest.mark.parametrize('input', [ +@pytest.mark.parametrize('params', [ {'names': 'whatever'}, {'category': 'whatever'}, {'suggestions': ['whatever']}, {'implications': ['whatever']}, ]) -def test_trying_to_merge_without_privileges(test_ctx, input): +def test_trying_to_merge_without_privileges( + user_factory, tag_factory, context_factory, params): db.session.add_all([ - test_ctx.tag_factory(names=['source']), - test_ctx.tag_factory(names=['target']), + tag_factory(names=['source']), + tag_factory(names=['target']), ]) db.session.commit() with pytest.raises(errors.AuthError): - test_ctx.api.post( - test_ctx.context_factory( - input={ + api.tag_api.merge_tags( + context_factory( + params={ 'removeVersion': 1, 'mergeToVersion': 1, 'remove': 'source', 'mergeTo': 'target', }, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) + user=user_factory(rank=db.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_tag_retrieving.py b/server/szurubooru/tests/api/test_tag_retrieving.py index 0351e25b..cff7ee05 100644 --- a/server/szurubooru/tests/api/test_tag_retrieving.py +++ b/server/szurubooru/tests/api/test_tag_retrieving.py @@ -1,82 +1,64 @@ -import datetime import pytest +import unittest.mock from szurubooru import api, db, errors -from szurubooru.func import util, tags +from szurubooru.func import tags -@pytest.fixture -def test_ctx( - context_factory, - config_injector, - user_factory, - tag_factory, - tag_category_factory): +@pytest.fixture(autouse=True) +def inject_config(config_injector): config_injector({ 'privileges': { 'tags:list': db.User.RANK_REGULAR, 'tags:view': db.User.RANK_REGULAR, }, - 'thumbnails': {'avatar_width': 200}, }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.tag_factory = tag_factory - ret.tag_category_factory = tag_category_factory - ret.list_api = api.TagListApi() - ret.detail_api = api.TagDetailApi() - return ret -def test_retrieving_multiple(test_ctx): - tag1 = test_ctx.tag_factory(names=['t1']) - tag2 = test_ctx.tag_factory(names=['t2']) +def test_retrieving_multiple(user_factory, tag_factory, context_factory): + tag1 = tag_factory(names=['t1']) + tag2 = tag_factory(names=['t2']) db.session.add_all([tag1, tag2]) - result = test_ctx.list_api.get( - test_ctx.context_factory( - input={'query': '', 'page': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - assert result['query'] == '' - assert result['page'] == 1 - assert result['pageSize'] == 100 - assert result['total'] == 2 - assert [t['names'] for t in result['results']] == [['t1'], ['t2']] + with unittest.mock.patch('szurubooru.func.tags.serialize_tag'): + tags.serialize_tag.return_value = 'serialized tag' + result = api.tag_api.get_tags( + context_factory( + params={'query': '', 'page': 1}, + user=user_factory(rank=db.User.RANK_REGULAR))) + assert result == { + 'query': '', + 'page': 1, + 'pageSize': 100, + 'total': 2, + 'results': ['serialized tag', 'serialized tag'], + } -def test_trying_to_retrieve_multiple_without_privileges(test_ctx): +def test_trying_to_retrieve_multiple_without_privileges( + user_factory, context_factory): with pytest.raises(errors.AuthError): - test_ctx.list_api.get( - test_ctx.context_factory( - input={'query': '', 'page': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) + api.tag_api.get_tags( + context_factory( + params={'query': '', 'page': 1}, + user=user_factory(rank=db.User.RANK_ANONYMOUS))) -def test_retrieving_single(test_ctx): - category = test_ctx.tag_category_factory(name='meta') - db.session.add(test_ctx.tag_factory(names=['tag'], category=category)) - result = test_ctx.detail_api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'tag') - assert result == { - 'names': ['tag'], - 'category': 'meta', - 'description': None, - 'creationTime': datetime.datetime(1996, 1, 1), - 'lastEditTime': None, - 'suggestions': [], - 'implications': [], - 'usages': 0, - 'snapshots': [], - 'version': 1, - } +def test_retrieving_single(user_factory, tag_factory, context_factory): + db.session.add(tag_factory(names=['tag'])) + with unittest.mock.patch('szurubooru.func.tags.serialize_tag'): + tags.serialize_tag.return_value = 'serialized tag' + result = api.tag_api.get_tag( + context_factory( + user=user_factory(rank=db.User.RANK_REGULAR)), + {'tag_name': 'tag'}) + assert result == 'serialized tag' -def test_trying_to_retrieve_single_non_existing(test_ctx): +def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): with pytest.raises(tags.TagNotFoundError): - test_ctx.detail_api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - '-') + api.tag_api.get_tag( + context_factory( + user=user_factory(rank=db.User.RANK_REGULAR)), + {'tag_name': '-'}) -def test_trying_to_retrieve_single_without_privileges(test_ctx): +def test_trying_to_retrieve_single_without_privileges( + user_factory, context_factory): with pytest.raises(errors.AuthError): - test_ctx.detail_api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), - '-') + api.tag_api.get_tag( + context_factory( + user=user_factory(rank=db.User.RANK_ANONYMOUS)), + {'tag_name': '-'}) diff --git a/server/szurubooru/tests/api/test_tag_siblings_retrieving.py b/server/szurubooru/tests/api/test_tag_siblings_retrieving.py index f3d8e2c9..f79bc340 100644 --- a/server/szurubooru/tests/api/test_tag_siblings_retrieving.py +++ b/server/szurubooru/tests/api/test_tag_siblings_retrieving.py @@ -1,56 +1,47 @@ -import datetime import pytest +import unittest.mock from szurubooru import api, db, errors -from szurubooru.func import util, tags +from szurubooru.func import tags -def assert_results(result, expected_tag_names_and_occurrences): - actual_tag_names_and_occurences = [] - for item in result['results']: - tag_name = item['tag']['names'][0] - occurrences = item['occurrences'] - actual_tag_names_and_occurences.append((tag_name, occurrences)) - assert actual_tag_names_and_occurences == expected_tag_names_and_occurrences +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({'privileges': {'tags:view': db.User.RANK_REGULAR}}) -@pytest.fixture -def test_ctx( - context_factory, config_injector, user_factory, tag_factory, post_factory): - config_injector({ - 'privileges': { - 'tags:view': db.User.RANK_REGULAR, - }, - 'thumbnails': {'avatar_width': 200}, - }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.tag_factory = tag_factory - ret.post_factory = post_factory - ret.api = api.TagSiblingsApi() - return ret +def test_get_tag_siblings(user_factory, tag_factory, post_factory, context_factory): + db.session.add(tag_factory(names=['tag'])) + with unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ + unittest.mock.patch('szurubooru.func.tags.get_tag_siblings'): + tags.serialize_tag.side_effect = \ + lambda tag, *args, **kwargs: \ + 'serialized tag %s' % tag.names[0].name + tags.get_tag_siblings.return_value = [ + (tag_factory(names=['sib1']), 1), + (tag_factory(names=['sib2']), 3), + ] + result = api.tag_api.get_tag_siblings( + context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + {'tag_name': 'tag'}) + assert result == { + 'results': [ + { + 'tag': 'serialized tag sib1', + 'occurrences': 1, + }, + { + 'tag': 'serialized tag sib2', + 'occurrences': 3, + }, + ], + } -def test_used_with_others(test_ctx): - tag1 = test_ctx.tag_factory(names=['tag1']) - tag2 = test_ctx.tag_factory(names=['tag2']) - post = test_ctx.post_factory() - post.tags = [tag1, tag2] - db.session.add_all([post, tag1, tag2]) - result = test_ctx.api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag1') - assert_results(result, [('tag2', 1)]) - result = test_ctx.api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag2') - assert_results(result, [('tag1', 1)]) - -def test_trying_to_retrieve_non_existing(test_ctx): +def test_trying_to_retrieve_non_existing(user_factory, context_factory): with pytest.raises(tags.TagNotFoundError): - test_ctx.api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), '-') + api.tag_api.get_tag_siblings( + context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), + {'tag_name': '-'}) -def test_trying_to_retrieve_without_privileges(test_ctx): +def test_trying_to_retrieve_without_privileges(user_factory, context_factory): with pytest.raises(errors.AuthError): - test_ctx.api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), '-') + api.tag_api.get_tag_siblings( + context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), + {'tag_name': '-'}) diff --git a/server/szurubooru/tests/api/test_tag_updating.py b/server/szurubooru/tests/api/test_tag_updating.py index 8ec965f4..4e9f8525 100644 --- a/server/szurubooru/tests/api/test_tag_updating.py +++ b/server/szurubooru/tests/api/test_tag_updating.py @@ -1,20 +1,11 @@ -import datetime -import os import pytest -from szurubooru import api, config, db, errors -from szurubooru.func import util, tags, tag_categories, cache +import unittest.mock +from szurubooru import api, db, errors +from szurubooru.func import tags -def assert_relations(relations, expected_tag_names): - actual_names = sorted([rel.names[0].name for rel in relations]) - assert actual_names == sorted(expected_tag_names) - -@pytest.fixture -def test_ctx( - tmpdir, config_injector, context_factory, user_factory, tag_factory): +@pytest.fixture(autouse=True) +def inject_config(config_injector): config_injector({ - 'data_dir': str(tmpdir), - 'tag_name_regex': '^[^!]*$', - 'tag_category_name_regex': '^[^!]*$', 'privileges': { 'tags:create': db.User.RANK_REGULAR, 'tags:edit:names': db.User.RANK_REGULAR, @@ -24,118 +15,115 @@ def test_ctx( 'tags:edit:implications': db.User.RANK_REGULAR, }, }) - db.session.add_all([ - db.TagCategory(name) for name in ['meta', 'character', 'copyright']]) - db.session.commit() - cache.purge() - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.tag_factory = tag_factory - ret.api = api.TagDetailApi() - return ret -def test_simple_updating(test_ctx, fake_datetime): - tag = test_ctx.tag_factory(names=['tag1', 'tag2']) +def test_simple_updating(user_factory, tag_factory, context_factory, fake_datetime): + auth_user = user_factory(rank=db.User.RANK_REGULAR) + tag = tag_factory(names=['tag1', 'tag2']) db.session.add(tag) db.session.commit() - with fake_datetime('1997-12-01'): - result = test_ctx.api.put( - test_ctx.context_factory( - input={ + with unittest.mock.patch('szurubooru.func.tags.create_tag'), \ + unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'), \ + unittest.mock.patch('szurubooru.func.tags.update_tag_names'), \ + unittest.mock.patch('szurubooru.func.tags.update_tag_category_name'), \ + unittest.mock.patch('szurubooru.func.tags.update_tag_description'), \ + unittest.mock.patch('szurubooru.func.tags.update_tag_suggestions'), \ + unittest.mock.patch('szurubooru.func.tags.update_tag_implications'), \ + unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ + unittest.mock.patch('szurubooru.func.tags.export_to_json'): + tags.get_or_create_tags_by_names.return_value = ([], []) + tags.serialize_tag.return_value = 'serialized tag' + result = api.tag_api.update_tag( + context_factory( + params={ 'version': 1, 'names': ['tag3'], 'category': 'character', 'description': 'desc', + 'suggestions': ['sug1', 'sug2'], + 'implications': ['imp1', 'imp2'], }, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'tag1') - assert len(result['snapshots']) == 1 - del result['snapshots'] - assert result == { - 'names': ['tag3'], - 'category': 'character', - 'description': 'desc', - 'suggestions': [], - 'implications': [], - 'creationTime': datetime.datetime(1996, 1, 1), - 'lastEditTime': datetime.datetime(1997, 12, 1), - 'usages': 0, - 'version': 2, - } - assert tags.try_get_tag_by_name('tag1') is None - assert tags.try_get_tag_by_name('tag2') is None - tag = tags.get_tag_by_name('tag3') - assert tag is not None - assert [tag_name.name for tag_name in tag.names] == ['tag3'] - assert tag.category.name == 'character' - assert tag.suggestions == [] - assert tag.implications == [] - assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json')) - -@pytest.mark.parametrize('input,expected_exception', [ - ({'names': None}, tags.InvalidTagNameError), - ({'names': []}, tags.InvalidTagNameError), - ({'names': [None]}, tags.InvalidTagNameError), - ({'names': ['']}, tags.InvalidTagNameError), - ({'names': ['!bad']}, tags.InvalidTagNameError), - ({'names': ['x' * 65]}, tags.InvalidTagNameError), - ({'category': None}, tag_categories.TagCategoryNotFoundError), - ({'category': ''}, tag_categories.TagCategoryNotFoundError), - ({'category': '!bad'}, tag_categories.TagCategoryNotFoundError), - ({'suggestions': ['good', '!bad']}, tags.InvalidTagNameError), - ({'implications': ['good', '!bad']}, tags.InvalidTagNameError), -]) -def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception): - db.session.add(test_ctx.tag_factory(names=['tag1'])) - db.session.commit() - with pytest.raises(expected_exception): - test_ctx.api.put( - test_ctx.context_factory( - input={**input, **{'version': 1}}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'tag1') + user=auth_user), + {'tag_name': 'tag1'}) + assert result == 'serialized tag' + tags.create_tag.assert_not_called() + tags.update_tag_names.assert_called_once_with(tag, ['tag3']) + tags.update_tag_category_name.assert_called_once_with(tag, 'character') + tags.update_tag_description.assert_called_once_with(tag, 'desc') + tags.update_tag_suggestions.assert_called_once_with(tag, ['sug1', 'sug2']) + tags.update_tag_implications.assert_called_once_with(tag, ['imp1', 'imp2']) + tags.serialize_tag.assert_called_once_with(tag, options=None) @pytest.mark.parametrize( 'field', ['names', 'category', 'description', 'implications', 'suggestions']) -def test_omitting_optional_field(test_ctx, field): - db.session.add(test_ctx.tag_factory(names=['tag'])) +def test_omitting_optional_field( + user_factory, tag_factory, context_factory, field): + db.session.add(tag_factory(names=['tag'])) db.session.commit() - input = { + params = { 'names': ['tag1', 'tag2'], 'category': 'meta', 'description': 'desc', 'suggestions': [], 'implications': [], } - del input[field] - result = test_ctx.api.put( - test_ctx.context_factory( - input={**input, **{'version': 1}}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'tag') - assert result is not None + del params[field] + with unittest.mock.patch('szurubooru.func.tags.create_tag'), \ + unittest.mock.patch('szurubooru.func.tags.update_tag_names'), \ + unittest.mock.patch('szurubooru.func.tags.update_tag_category_name'), \ + unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ + unittest.mock.patch('szurubooru.func.tags.export_to_json'): + api.tag_api.update_tag( + context_factory( + params={**params, **{'version': 1}}, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'tag_name': 'tag'}) -def test_trying_to_update_non_existing(test_ctx): +def test_trying_to_update_non_existing(user_factory, context_factory): with pytest.raises(tags.TagNotFoundError): - test_ctx.api.put( - test_ctx.context_factory( - input={'names': ['dummy']}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'tag1') + api.tag_api.update_tag( + context_factory( + params={'names': ['dummy']}, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'tag_name': 'tag1'}) -@pytest.mark.parametrize('input', [ +@pytest.mark.parametrize('params', [ {'names': 'whatever'}, {'category': 'whatever'}, {'suggestions': ['whatever']}, {'implications': ['whatever']}, ]) -def test_trying_to_update_without_privileges(test_ctx, input): - db.session.add(test_ctx.tag_factory(names=['tag'])) +def test_trying_to_update_without_privileges( + user_factory, tag_factory, context_factory, params): + db.session.add(tag_factory(names=['tag'])) db.session.commit() with pytest.raises(errors.AuthError): - test_ctx.api.put( - test_ctx.context_factory( - input={**input, **{'version': 1}}, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), - 'tag') + api.tag_api.update_tag( + context_factory( + params={**params, **{'version': 1}}, + user=user_factory(rank=db.User.RANK_ANONYMOUS)), + {'tag_name': 'tag'}) + +def test_trying_to_create_tags_without_privileges( + config_injector, context_factory, tag_factory, user_factory): + tag = tag_factory(names=['tag']) + db.session.add(tag) + db.session.commit() + config_injector({'privileges': { + 'tags:create': db.User.RANK_ADMINISTRATOR, + 'tags:edit:suggestions': db.User.RANK_REGULAR, + 'tags:edit:implications': db.User.RANK_REGULAR, + }}) + with unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'): + tags.get_or_create_tags_by_names.return_value = ([], ['new-tag']) + with pytest.raises(errors.AuthError): + api.tag_api.update_tag( + context_factory( + params={'suggestions': ['tag1', 'tag2'], 'version': 1}, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'tag_name': 'tag'}) + with pytest.raises(errors.AuthError): + api.tag_api.update_tag( + context_factory( + params={'implications': ['tag1', 'tag2'], 'version': 1}, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'tag_name': 'tag'}) diff --git a/server/szurubooru/tests/api/test_user_creating.py b/server/szurubooru/tests/api/test_user_creating.py index 0909af2b..3e607ec5 100644 --- a/server/szurubooru/tests/api/test_user_creating.py +++ b/server/szurubooru/tests/api/test_user_creating.py @@ -1,230 +1,79 @@ -import datetime import pytest -from szurubooru import api, config, db, errors -from szurubooru.func import auth, util, users +import unittest.mock +from szurubooru import api, db, errors +from szurubooru.func import users -EMPTY_PIXEL = \ - b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \ - b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \ - b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({'privileges': {'users:create': 'regular'}}) -@pytest.fixture -def test_ctx(tmpdir, config_injector, context_factory, user_factory): - config_injector({ - 'secret': '', - 'user_name_regex': '[^!]{3,}', - 'password_regex': '[^!]{3,}', - 'default_rank': db.User.RANK_REGULAR, - 'thumbnails': {'avatar_width': 200, 'avatar_height': 200}, - 'privileges': {'users:create': 'anonymous'}, - 'data_dir': str(tmpdir.mkdir('data')), - 'data_url': 'http://example.com/data/', - }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.api = api.UserListApi() - return ret - -def test_creating_user(test_ctx, fake_datetime): - with fake_datetime('1969-02-12'): - result = test_ctx.api.post( - test_ctx.context_factory( - input={ +def test_creating_user(user_factory, context_factory, fake_datetime): + user = user_factory() + with unittest.mock.patch('szurubooru.func.users.create_user'), \ + unittest.mock.patch('szurubooru.func.users.update_user_name'), \ + unittest.mock.patch('szurubooru.func.users.update_user_password'), \ + unittest.mock.patch('szurubooru.func.users.update_user_email'), \ + unittest.mock.patch('szurubooru.func.users.update_user_rank'), \ + unittest.mock.patch('szurubooru.func.users.update_user_avatar'), \ + unittest.mock.patch('szurubooru.func.users.serialize_user'), \ + fake_datetime('1969-02-12'): + users.serialize_user.return_value = 'serialized user' + users.create_user.return_value = user + result = api.user_api.create_user( + context_factory( + params={ 'name': 'chewie1', 'email': 'asd@asd.asd', 'password': 'oks', + 'rank': 'moderator', + 'avatarStyle': 'manual', }, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - assert result == { - 'avatarStyle': 'gravatar', - 'avatarUrl': 'https://gravatar.com/avatar/' + - '6f370c8c7109534c3d5c394123a477d7?d=retro&s=200', - 'creationTime': datetime.datetime(1969, 2, 12), - 'lastLoginTime': None, - 'name': 'chewie1', - 'rank': 'administrator', - 'email': 'asd@asd.asd', - 'commentCount': 0, - 'likedPostCount': 0, - 'dislikedPostCount': 0, - 'favoritePostCount': 0, - 'uploadedPostCount': 0, - 'version': 1, - } - user = users.get_user_by_name('chewie1') - assert user.name == 'chewie1' - assert user.email == 'asd@asd.asd' - assert user.rank == db.User.RANK_ADMINISTRATOR - assert auth.is_valid_password(user, 'oks') is True - assert auth.is_valid_password(user, 'invalid') is False - -def test_first_user_becomes_admin_others_not(test_ctx): - result1 = test_ctx.api.post( - test_ctx.context_factory( - input={ - 'name': 'chewie1', - 'email': 'asd@asd.asd', - 'password': 'oks', - }, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) - result2 = test_ctx.api.post( - test_ctx.context_factory( - input={ - 'name': 'chewie2', - 'email': 'asd@asd.asd', - 'password': 'sok', - }, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) - assert result1['rank'] == 'administrator' - assert result2['rank'] == 'regular' - first_user = users.get_user_by_name('chewie1') - other_user = users.get_user_by_name('chewie2') - assert first_user.rank == db.User.RANK_ADMINISTRATOR - assert other_user.rank == db.User.RANK_REGULAR - -def test_first_user_does_not_become_admin_if_they_dont_wish_so(test_ctx): - result = test_ctx.api.post( - test_ctx.context_factory( - input={ - 'name': 'chewie1', - 'email': 'asd@asd.asd', - 'password': 'oks', - 'rank': 'regular', - }, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) - assert result['rank'] == 'regular' - -def test_trying_to_become_someone_else(test_ctx): - test_ctx.api.post( - test_ctx.context_factory( - input={ - 'name': 'chewie', - 'email': 'asd@asd.asd', - 'password': 'oks', - }, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - with pytest.raises(users.UserAlreadyExistsError): - test_ctx.api.post( - test_ctx.context_factory( - input={ - 'name': 'chewie', - 'email': 'asd@asd.asd', - 'password': 'oks', - }, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - with pytest.raises(users.UserAlreadyExistsError): - test_ctx.api.post( - test_ctx.context_factory( - input={ - 'name': 'CHEWIE', - 'email': 'asd@asd.asd', - 'password': 'oks', - }, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - -@pytest.mark.parametrize('input,expected_exception', [ - ({'name': None}, users.InvalidUserNameError), - ({'name': ''}, users.InvalidUserNameError), - ({'name': '!bad'}, users.InvalidUserNameError), - ({'name': 'x' * 51}, users.InvalidUserNameError), - ({'password': None}, users.InvalidPasswordError), - ({'password': ''}, users.InvalidPasswordError), - ({'password': '!bad'}, users.InvalidPasswordError), - ({'rank': None}, users.InvalidRankError), - ({'rank': ''}, users.InvalidRankError), - ({'rank': 'bad'}, users.InvalidRankError), - ({'rank': 'anonymous'}, users.InvalidRankError), - ({'rank': 'nobody'}, users.InvalidRankError), - ({'email': 'bad'}, users.InvalidEmailError), - ({'email': 'x@' * 65 + '.com'}, users.InvalidEmailError), - ({'avatarStyle': None}, users.InvalidAvatarError), - ({'avatarStyle': ''}, users.InvalidAvatarError), - ({'avatarStyle': 'invalid'}, users.InvalidAvatarError), - ({'avatarStyle': 'manual'}, users.InvalidAvatarError), # missing file -]) -def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception): - real_input={ - 'name': 'chewie', - 'email': 'asd@asd.asd', - 'password': 'oks', - } - for key, value in input.items(): - real_input[key] = value - with pytest.raises(expected_exception): - test_ctx.api.post( - test_ctx.context_factory( - input=real_input, - user=test_ctx.user_factory( - name='u1', rank=db.User.RANK_ADMINISTRATOR))) + files={'avatar': b'...'}, + user=user_factory(rank=db.User.RANK_REGULAR))) + assert result == 'serialized user' + users.create_user.assert_called_once_with('chewie1', 'oks', 'asd@asd.asd') + assert not users.update_user_name.called + assert not users.update_user_password.called + assert not users.update_user_email.called + users.update_user_rank.called_once_with(user, 'moderator') + users.update_user_avatar.called_once_with(user, 'manual', b'...') @pytest.mark.parametrize('field', ['name', 'password']) -def test_trying_to_omit_mandatory_field(test_ctx, field): - input = { +def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): + params = { 'name': 'chewie', 'email': 'asd@asd.asd', 'password': 'oks', } - del input[field] - with pytest.raises(errors.ValidationError): - test_ctx.api.post( - test_ctx.context_factory( - input=input, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) + user = user_factory() + auth_user = user_factory(rank=db.User.RANK_REGULAR) + del params[field] + with unittest.mock.patch('szurubooru.func.users.create_user'), \ + pytest.raises(errors.MissingRequiredParameterError): + users.create_user.return_value = user + api.user_api.create_user(context_factory(params=params, user=auth_user)) @pytest.mark.parametrize('field', ['rank', 'email', 'avatarStyle']) -def test_omitting_optional_field(test_ctx, field): - input = { +def test_omitting_optional_field(user_factory, context_factory, field): + params = { 'name': 'chewie', 'email': 'asd@asd.asd', 'password': 'oks', 'rank': 'moderator', - 'avatarStyle': 'manual', + 'avatarStyle': 'gravatar', } - del input[field] - result = test_ctx.api.post( - test_ctx.context_factory( - input=input, - files={'avatar': EMPTY_PIXEL}, - user=test_ctx.user_factory(rank=db.User.RANK_MODERATOR))) - assert result is not None + del params[field] + user = user_factory() + auth_user = user_factory(rank=db.User.RANK_MODERATOR) + with unittest.mock.patch('szurubooru.func.users.create_user'), \ + unittest.mock.patch('szurubooru.func.users.update_user_avatar'), \ + unittest.mock.patch('szurubooru.func.users.serialize_user'): + users.create_user.return_value = user + api.user_api.create_user( + context_factory(params=params, user=auth_user)) -def test_mods_trying_to_become_admin(test_ctx): - user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_MODERATOR) - user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_MODERATOR) - db.session.add_all([user1, user2]) - context = test_ctx.context_factory(input={ - 'name': 'chewie', - 'email': 'asd@asd.asd', - 'password': 'oks', - 'rank': 'administrator', - }, user=user1) +def test_trying_to_create_user_without_privileges(context_factory, user_factory): with pytest.raises(errors.AuthError): - test_ctx.api.post(context) - -def test_admin_creating_mod_account(test_ctx): - user = test_ctx.user_factory(rank=db.User.RANK_ADMINISTRATOR) - db.session.add(user) - context = test_ctx.context_factory(input={ - 'name': 'chewie', - 'email': 'asd@asd.asd', - 'password': 'oks', - 'rank': 'moderator', - }, user=user) - result = test_ctx.api.post(context) - assert result['rank'] == 'moderator' - -def test_uploading_avatar(test_ctx): - response = test_ctx.api.post( - test_ctx.context_factory( - input={ - 'name': 'chewie', - 'email': 'asd@asd.asd', - 'password': 'oks', - 'avatarStyle': 'manual', - }, - files={'avatar': EMPTY_PIXEL}, - user=test_ctx.user_factory(rank=db.User.RANK_MODERATOR))) - user = users.get_user_by_name('chewie') - assert user.avatar_style == user.AVATAR_MANUAL - assert response['avatarUrl'] == 'http://example.com/data/avatars/chewie.png' + api.user_api.create_user(context_factory( + params='whatever', + user=user_factory(rank=db.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/api/test_user_deleting.py b/server/szurubooru/tests/api/test_user_deleting.py index 3e54b28c..ec6ca635 100644 --- a/server/szurubooru/tests/api/test_user_deleting.py +++ b/server/szurubooru/tests/api/test_user_deleting.py @@ -1,54 +1,52 @@ import pytest -from datetime import datetime from szurubooru import api, db, errors -from szurubooru.func import util, users +from szurubooru.func import users -@pytest.fixture -def test_ctx(config_injector, context_factory, user_factory): +@pytest.fixture(autouse=True) +def inject_config(config_injector): config_injector({ 'privileges': { 'users:delete:self': db.User.RANK_REGULAR, 'users:delete:any': db.User.RANK_MODERATOR, }, }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.api = api.UserDetailApi() - return ret -def test_deleting_oneself(test_ctx): - user = test_ctx.user_factory(name='u', rank=db.User.RANK_REGULAR) +def test_deleting_oneself(user_factory, context_factory): + user = user_factory(name='u', rank=db.User.RANK_REGULAR) db.session.add(user) db.session.commit() - result = test_ctx.api.delete( - test_ctx.context_factory(input={'version': 1}, user=user), 'u') + result = api.user_api.delete_user( + context_factory( + params={'version': 1}, user=user), {'user_name': 'u'}) assert result == {} assert db.session.query(db.User).count() == 0 -def test_deleting_someone_else(test_ctx): - user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_REGULAR) - user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_MODERATOR) +def test_deleting_someone_else(user_factory, context_factory): + user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR) + user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR) db.session.add_all([user1, user2]) db.session.commit() - test_ctx.api.delete( - test_ctx.context_factory(input={'version': 1}, user=user2), 'u1') + api.user_api.delete_user( + context_factory( + params={'version': 1}, user=user2), {'user_name': 'u1'}) assert db.session.query(db.User).count() == 1 -def test_trying_to_delete_someone_else_without_privileges(test_ctx): - user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_REGULAR) - user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_REGULAR) +def test_trying_to_delete_someone_else_without_privileges( + user_factory, context_factory): + user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR) + user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR) db.session.add_all([user1, user2]) db.session.commit() with pytest.raises(errors.AuthError): - test_ctx.api.delete( - test_ctx.context_factory(input={'version': 1}, user=user2), 'u1') + api.user_api.delete_user( + context_factory( + params={'version': 1}, user=user2), {'user_name': 'u1'}) assert db.session.query(db.User).count() == 2 -def test_trying_to_delete_non_existing(test_ctx): +def test_trying_to_delete_non_existing(user_factory, context_factory): with pytest.raises(users.UserNotFoundError): - test_ctx.api.delete( - test_ctx.context_factory( - input={'version': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'bad') + api.user_api.delete_user( + context_factory( + params={'version': 1}, + user=user_factory(rank=db.User.RANK_REGULAR)), + {'user_name': 'bad'}) diff --git a/server/szurubooru/tests/api/test_user_retrieving.py b/server/szurubooru/tests/api/test_user_retrieving.py index 2720a3c0..3e0dab0b 100644 --- a/server/szurubooru/tests/api/test_user_retrieving.py +++ b/server/szurubooru/tests/api/test_user_retrieving.py @@ -1,83 +1,64 @@ -import datetime +import unittest.mock import pytest from szurubooru import api, db, errors -from szurubooru.func import util, users +from szurubooru.func import users -@pytest.fixture -def test_ctx(context_factory, config_injector, user_factory): +@pytest.fixture(autouse=True) +def inject_config(config_injector): config_injector({ 'privileges': { 'users:list': db.User.RANK_REGULAR, 'users:view': db.User.RANK_REGULAR, 'users:edit:any:email': db.User.RANK_MODERATOR, }, - 'thumbnails': {'avatar_width': 200}, }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.list_api = api.UserListApi() - ret.detail_api = api.UserDetailApi() - return ret -def test_retrieving_multiple(test_ctx): - user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_MODERATOR) - user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_MODERATOR) +def test_retrieving_multiple(user_factory, context_factory): + user1 = user_factory(name='u1', rank=db.User.RANK_MODERATOR) + user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR) db.session.add_all([user1, user2]) - result = test_ctx.list_api.get( - test_ctx.context_factory( - input={'query': '', 'page': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) - assert result['query'] == '' - assert result['page'] == 1 - assert result['pageSize'] == 100 - assert result['total'] == 2 - assert [u['name'] for u in result['results']] == ['u1', 'u2'] + with unittest.mock.patch('szurubooru.func.users.serialize_user'): + users.serialize_user.return_value = 'serialized user' + result = api.user_api.get_users( + context_factory( + params={'query': '', 'page': 1}, + user=user_factory(rank=db.User.RANK_REGULAR))) + assert result == { + 'query': '', + 'page': 1, + 'pageSize': 100, + 'total': 2, + 'results': ['serialized user', 'serialized user'], + } -def test_trying_to_retrieve_multiple_without_privileges(test_ctx): +def test_trying_to_retrieve_multiple_without_privileges( + user_factory, context_factory): with pytest.raises(errors.AuthError): - test_ctx.list_api.get( - test_ctx.context_factory( - input={'query': '', 'page': 1}, - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) + api.user_api.get_users( + context_factory( + params={'query': '', 'page': 1}, + user=user_factory(rank=db.User.RANK_ANONYMOUS))) -def test_retrieving_single(test_ctx): - db.session.add(test_ctx.user_factory(name='u1', rank=db.User.RANK_REGULAR)) - result = test_ctx.detail_api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - 'u1') - assert result == { - 'name': 'u1', - 'rank': db.User.RANK_REGULAR, - 'creationTime': datetime.datetime(1997, 1, 1), - 'lastLoginTime': None, - 'avatarStyle': 'gravatar', - 'avatarUrl': 'https://gravatar.com/avatar/' + - '275876e34cf609db118f3d84b799a790?d=retro&s=200', - 'email': False, - 'commentCount': 0, - 'likedPostCount': False, - 'dislikedPostCount': False, - 'favoritePostCount': 0, - 'uploadedPostCount': 0, - 'version': 1, - } - assert result['email'] is False - assert result['likedPostCount'] is False - assert result['dislikedPostCount'] is False +def test_retrieving_single(user_factory, context_factory): + user = user_factory(name='u1', rank=db.User.RANK_REGULAR) + auth_user = user_factory(rank=db.User.RANK_REGULAR) + db.session.add(user) + with unittest.mock.patch('szurubooru.func.users.serialize_user'): + users.serialize_user.return_value = 'serialized user' + result = api.user_api.get_user( + context_factory(user=auth_user), {'user_name': 'u1'}) + assert result == 'serialized user' -def test_trying_to_retrieve_single_non_existing(test_ctx): +def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): + auth_user = user_factory(rank=db.User.RANK_REGULAR) with pytest.raises(users.UserNotFoundError): - test_ctx.detail_api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), - '-') + api.user_api.get_user( + context_factory(user=auth_user), {'user_name': '-'}) -def test_trying_to_retrieve_single_without_privileges(test_ctx): - db.session.add(test_ctx.user_factory(name='u1', rank=db.User.RANK_REGULAR)) +def test_trying_to_retrieve_single_without_privileges( + user_factory, context_factory): + auth_user = user_factory(rank=db.User.RANK_ANONYMOUS) + db.session.add(user_factory(name='u1', rank=db.User.RANK_REGULAR)) with pytest.raises(errors.AuthError): - test_ctx.detail_api.get( - test_ctx.context_factory( - user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), - 'u1') + api.user_api.get_user( + context_factory(user=auth_user), {'user_name': 'u1'}) diff --git a/server/szurubooru/tests/api/test_user_updating.py b/server/szurubooru/tests/api/test_user_updating.py index 7ea0b810..bc93295f 100644 --- a/server/szurubooru/tests/api/test_user_updating.py +++ b/server/szurubooru/tests/api/test_user_updating.py @@ -1,20 +1,12 @@ -import datetime import pytest -from szurubooru import api, config, db, errors -from szurubooru.func import auth, util, users +import unittest.mock +from datetime import datetime +from szurubooru import api, db, errors +from szurubooru.func import users -EMPTY_PIXEL = \ - b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \ - b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \ - b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' - -@pytest.fixture -def test_ctx(tmpdir, config_injector, context_factory, user_factory): +@pytest.fixture(autouse=True) +def inject_config(config_injector): config_injector({ - 'secret': '', - 'user_name_regex': '^[^!]{3,}$', - 'password_regex': '^[^!]{3,}$', - 'thumbnails': {'avatar_width': 200, 'avatar_height': 200}, 'privileges': { 'users:edit:self:name': db.User.RANK_REGULAR, 'users:edit:self:pass': db.User.RANK_REGULAR, @@ -27,203 +19,97 @@ def test_ctx(tmpdir, config_injector, context_factory, user_factory): 'users:edit:any:rank': db.User.RANK_ADMINISTRATOR, 'users:edit:any:avatar': db.User.RANK_ADMINISTRATOR, }, - 'data_dir': str(tmpdir.mkdir('data')), - 'data_url': 'http://example.com/data/', }) - ret = util.dotdict() - ret.context_factory = context_factory - ret.user_factory = user_factory - ret.api = api.UserDetailApi() - return ret -def test_updating_user(test_ctx): - user = test_ctx.user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) +def test_updating_user(context_factory, user_factory): + user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) + auth_user = user_factory(rank=db.User.RANK_ADMINISTRATOR) db.session.add(user) - result = test_ctx.api.put( - test_ctx.context_factory( - input={ - 'version': 1, - 'name': 'chewie', - 'email': 'asd@asd.asd', - 'password': 'oks', - 'rank': 'moderator', - 'avatarStyle': 'gravatar', - }, - user=user), - 'u1') - assert result == { - 'avatarStyle': 'gravatar', - 'avatarUrl': 'https://gravatar.com/avatar/' + - '6f370c8c7109534c3d5c394123a477d7?d=retro&s=200', - 'creationTime': datetime.datetime(1997, 1, 1), - 'lastLoginTime': None, - 'email': 'asd@asd.asd', - 'name': 'chewie', - 'rank': 'moderator', - 'commentCount': 0, - 'likedPostCount': 0, - 'dislikedPostCount': 0, - 'favoritePostCount': 0, - 'uploadedPostCount': 0, - 'version': 2, - } - user = users.get_user_by_name('chewie') - assert user.name == 'chewie' - assert user.email == 'asd@asd.asd' - assert user.rank == db.User.RANK_MODERATOR - assert user.avatar_style == user.AVATAR_GRAVATAR - assert auth.is_valid_password(user, 'oks') is True - assert auth.is_valid_password(user, 'invalid') is False + db.session.flush() -@pytest.mark.parametrize('input,expected_exception', [ - ({'name': None}, users.InvalidUserNameError), - ({'name': ''}, users.InvalidUserNameError), - ({'name': '!bad'}, users.InvalidUserNameError), - ({'name': 'x' * 51}, users.InvalidUserNameError), - ({'password': None}, users.InvalidPasswordError), - ({'password': ''}, users.InvalidPasswordError), - ({'password': '!bad'}, users.InvalidPasswordError), - ({'rank': None}, users.InvalidRankError), - ({'rank': ''}, users.InvalidRankError), - ({'rank': 'bad'}, users.InvalidRankError), - ({'rank': 'anonymous'}, users.InvalidRankError), - ({'rank': 'nobody'}, users.InvalidRankError), - ({'email': 'bad'}, users.InvalidEmailError), - ({'email': 'x@' * 65 + '.com'}, users.InvalidEmailError), - ({'avatarStyle': None}, users.InvalidAvatarError), - ({'avatarStyle': ''}, users.InvalidAvatarError), - ({'avatarStyle': 'invalid'}, users.InvalidAvatarError), - ({'avatarStyle': 'manual'}, users.InvalidAvatarError), # missing file -]) -def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception): - user = test_ctx.user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) - db.session.add(user) - with pytest.raises(expected_exception): - test_ctx.api.put( - test_ctx.context_factory( - input={**input, **{'version': 1}}, - user=user), - 'u1') + with unittest.mock.patch('szurubooru.func.users.create_user'), \ + unittest.mock.patch('szurubooru.func.users.update_user_name'), \ + unittest.mock.patch('szurubooru.func.users.update_user_password'), \ + unittest.mock.patch('szurubooru.func.users.update_user_email'), \ + unittest.mock.patch('szurubooru.func.users.update_user_rank'), \ + unittest.mock.patch('szurubooru.func.users.update_user_avatar'), \ + unittest.mock.patch('szurubooru.func.users.serialize_user'): + users.serialize_user.return_value = 'serialized user' + + result = api.user_api.update_user( + context_factory( + params={ + 'version': 1, + 'name': 'chewie', + 'email': 'asd@asd.asd', + 'password': 'oks', + 'rank': 'moderator', + 'avatarStyle': 'manual', + }, + files={ + 'avatar': b'...', + }, + user=auth_user), + {'user_name': 'u1'}) + + assert result == 'serialized user' + users.create_user.assert_not_called() + users.update_user_name.assert_called_once_with(user, 'chewie') + users.update_user_password.assert_called_once_with(user, 'oks') + users.update_user_email.assert_called_once_with(user, 'asd@asd.asd') + users.update_user_rank.assert_called_once_with(user, 'moderator', auth_user) + users.update_user_avatar.assert_called_once_with(user, 'manual', b'...') + users.serialize_user.assert_called_once_with(user, auth_user, options=None) @pytest.mark.parametrize( 'field', ['name', 'email', 'password', 'rank', 'avatarStyle']) -def test_omitting_optional_field(test_ctx, field): - user = test_ctx.user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) +def test_omitting_optional_field(user_factory, context_factory, field): + user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) db.session.add(user) - input = { + params = { 'name': 'chewie', 'email': 'asd@asd.asd', 'password': 'oks', 'rank': 'moderator', 'avatarStyle': 'gravatar', } - del input[field] - result = test_ctx.api.put( - test_ctx.context_factory( - input={**input, **{'version': 1}}, - files={'avatar': EMPTY_PIXEL}, - user=user), - 'u1') - assert result is not None + del params[field] + with unittest.mock.patch('szurubooru.func.users.create_user'), \ + unittest.mock.patch('szurubooru.func.users.update_user_name'), \ + unittest.mock.patch('szurubooru.func.users.update_user_password'), \ + unittest.mock.patch('szurubooru.func.users.update_user_email'), \ + unittest.mock.patch('szurubooru.func.users.update_user_rank'), \ + unittest.mock.patch('szurubooru.func.users.update_user_avatar'), \ + unittest.mock.patch('szurubooru.func.users.serialize_user'): + api.user_api.update_user( + context_factory( + params={**params, **{'version': 1}}, + files={'avatar': b'...'}, + user=user), + {'user_name': 'u1'}) -def test_trying_to_update_non_existing(test_ctx): - user = test_ctx.user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) +def test_trying_to_update_non_existing(user_factory, context_factory): + user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) db.session.add(user) with pytest.raises(users.UserNotFoundError): - test_ctx.api.put(test_ctx.context_factory(user=user), 'u2') + api.user_api.update_user( + context_factory(user=user), {'user_name': 'u2'}) -def test_removing_email(test_ctx): - user = test_ctx.user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) - db.session.add(user) - test_ctx.api.put( - test_ctx.context_factory( - input={'email': '', 'version': 1}, user=user), 'u1') - assert users.get_user_by_name('u1').email is None - -@pytest.mark.parametrize('input', [ +@pytest.mark.parametrize('params', [ {'name': 'whatever'}, {'email': 'whatever'}, {'rank': 'whatever'}, {'password': 'whatever'}, {'avatarStyle': 'whatever'}, ]) -def test_trying_to_update_someone_else(test_ctx, input): - user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_REGULAR) - user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_REGULAR) +def test_trying_to_update_field_without_privileges( + user_factory, context_factory, params): + user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR) + user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR) db.session.add_all([user1, user2]) with pytest.raises(errors.AuthError): - test_ctx.api.put( - test_ctx.context_factory( - input={**input, **{'version': 1}}, + api.user_api.update_user( + context_factory( + params={**params, **{'version': 1}}, user=user1), - user2.name) - -def test_trying_to_become_someone_else(test_ctx): - user1 = test_ctx.user_factory(name='me', rank=db.User.RANK_REGULAR) - user2 = test_ctx.user_factory(name='her', rank=db.User.RANK_REGULAR) - db.session.add_all([user1, user2]) - with pytest.raises(users.UserAlreadyExistsError): - test_ctx.api.put( - test_ctx.context_factory( - input={'name': 'her', 'version': 1}, user=user1), - 'me') - with pytest.raises(users.UserAlreadyExistsError): - test_ctx.api.put( - test_ctx.context_factory( - input={'name': 'HER', 'version': 1}, user=user1), - 'me') - -def test_trying_to_make_someone_into_someone_else(test_ctx): - user1 = test_ctx.user_factory(name='him', rank=db.User.RANK_REGULAR) - user2 = test_ctx.user_factory(name='her', rank=db.User.RANK_REGULAR) - user3 = test_ctx.user_factory(name='me', rank=db.User.RANK_MODERATOR) - db.session.add_all([user1, user2, user3]) - with pytest.raises(users.UserAlreadyExistsError): - test_ctx.api.put( - test_ctx.context_factory( - input={'name': 'her', 'version': 1}, user=user3), - 'him') - with pytest.raises(users.UserAlreadyExistsError): - test_ctx.api.put( - test_ctx.context_factory( - input={'name': 'HER', 'version': 1}, user=user3), - 'him') - -def test_renaming_someone_else(test_ctx): - user1 = test_ctx.user_factory(name='him', rank=db.User.RANK_REGULAR) - user2 = test_ctx.user_factory(name='me', rank=db.User.RANK_MODERATOR) - db.session.add_all([user1, user2]) - test_ctx.api.put( - test_ctx.context_factory( - input={'name': 'himself', 'version': 1}, user=user2), - 'him') - test_ctx.api.put( - test_ctx.context_factory( - input={'name': 'HIMSELF', 'version': 2}, user=user2), - 'himself') - -def test_mods_trying_to_become_admin(test_ctx): - user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_MODERATOR) - user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_MODERATOR) - db.session.add_all([user1, user2]) - context = test_ctx.context_factory( - input={'rank': 'administrator', 'version': 1}, - user=user1) - with pytest.raises(errors.AuthError): - test_ctx.api.put(context, user1.name) - with pytest.raises(errors.AuthError): - test_ctx.api.put(context, user2.name) - -def test_uploading_avatar(test_ctx): - user = test_ctx.user_factory(name='u1', rank=db.User.RANK_MODERATOR) - db.session.add(user) - response = test_ctx.api.put( - test_ctx.context_factory( - input={'avatarStyle': 'manual', 'version': 1}, - files={'avatar': EMPTY_PIXEL}, - user=user), - 'u1') - user = users.get_user_by_name('u1') - assert user.avatar_style == user.AVATAR_MANUAL - assert response['avatarUrl'] == \ - 'http://example.com/data/avatars/u1.png' + {'user_name': user2.name}) diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index 4fb3882f..11f8398a 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -5,7 +5,7 @@ import uuid import pytest import freezegun import sqlalchemy -from szurubooru import api, config, db +from szurubooru import api, config, db, rest from szurubooru.func import util class QueryCounter(object): @@ -74,12 +74,14 @@ def session(query_logger): @pytest.fixture def context_factory(session): - def factory(request=None, input=None, files=None, user=None): - ctx = api.Context() - ctx.input = input or {} + def factory(params=None, files=None, user=None): + ctx = rest.Context( + method=None, + url=None, + headers={}, + params=params or {}, + files=files or {}) ctx.session = session - ctx.request = request or {} - ctx.files = files or {} ctx.user = user or db.User() return ctx return factory diff --git a/server/szurubooru/tests/api/test_context.py b/server/szurubooru/tests/rest/test_context.py similarity index 80% rename from server/szurubooru/tests/api/test_context.py rename to server/szurubooru/tests/rest/test_context.py index 9f74c054..6fca391f 100644 --- a/server/szurubooru/tests/api/test_context.py +++ b/server/szurubooru/tests/rest/test_context.py @@ -1,32 +1,30 @@ import unittest.mock import pytest -from szurubooru import api, errors +from szurubooru import rest, errors from szurubooru.func import net def test_has_param(): - ctx = api.Context() - ctx.input = {'key': 'value'} + ctx = rest.Context(method=None, url=None, params={'key': 'value'}) assert ctx.has_param('key') assert not ctx.has_param('key2') def test_get_file(): - ctx = api.Context() - ctx.files = {'key': b'content'} + ctx = rest.Context(method=None, url=None, files={'key': b'content'}) assert ctx.get_file('key') == b'content' assert ctx.get_file('key2') is None def test_get_file_from_url(): with unittest.mock.patch('szurubooru.func.net.download'): net.download.return_value = b'content' - ctx = api.Context() - ctx.input = {'keyUrl': 'example.com'} + ctx = rest.Context( + method=None, url=None, params={'keyUrl': 'example.com'}) assert ctx.get_file('key') == b'content' assert ctx.get_file('key2') is None net.download.assert_called_once_with('example.com') def test_getting_list_parameter(): - ctx = api.Context() - ctx.input = {'key': 'value', 'list': ['1', '2', '3']} + ctx = rest.Context( + method=None, url=None, params={'key': 'value', 'list': ['1', '2', '3']}) assert ctx.get_param_as_list('key') == ['value'] assert ctx.get_param_as_list('key2') is None assert ctx.get_param_as_list('key2', default=['def']) == ['def'] @@ -35,8 +33,8 @@ def test_getting_list_parameter(): ctx.get_param_as_list('key2', required=True) def test_getting_string_parameter(): - ctx = api.Context() - ctx.input = {'key': 'value', 'list': ['1', '2', '3']} + ctx = rest.Context( + method=None, url=None, params={'key': 'value', 'list': ['1', '2', '3']}) assert ctx.get_param_as_string('key') == 'value' assert ctx.get_param_as_string('key2') is None assert ctx.get_param_as_string('key2', default='def') == 'def' @@ -45,8 +43,10 @@ def test_getting_string_parameter(): ctx.get_param_as_string('key2', required=True) def test_getting_int_parameter(): - ctx = api.Context() - ctx.input = {'key': '50', 'err': 'invalid', 'list': [1, 2, 3]} + ctx = rest.Context( + method=None, + url=None, + params={'key': '50', 'err': 'invalid', 'list': [1, 2, 3]}) assert ctx.get_param_as_int('key') == 50 assert ctx.get_param_as_int('key2') is None assert ctx.get_param_as_int('key2', default=5) == 5 @@ -65,8 +65,7 @@ def test_getting_int_parameter(): def test_getting_bool_parameter(): def test(value): - ctx = api.Context() - ctx.input = {'key': value} + ctx = rest.Context(method=None, url=None, params={'key': value}) return ctx.get_param_as_bool('key') assert test('1') is True @@ -94,7 +93,7 @@ def test_getting_bool_parameter(): with pytest.raises(errors.ValidationError): test(['1', '2']) - ctx = api.Context() + ctx = rest.Context(method=None, url=None) assert ctx.get_param_as_bool('non-existing') is None assert ctx.get_param_as_bool('non-existing', default=True) is True with pytest.raises(errors.ValidationError):