diff --git a/server/szurubooru/api/comment_api.py b/server/szurubooru/api/comment_api.py index d1201ff2..7f675d18 100644 --- a/server/szurubooru/api/comment_api.py +++ b/server/szurubooru/api/comment_api.py @@ -18,12 +18,9 @@ class CommentListApi(BaseApi): def post(self, ctx): auth.verify_privilege(ctx.user, 'comments:create') - text = ctx.get_param_as_string('text', required=True) post_id = ctx.get_param_as_int('postId', required=True) post = posts.get_post_by_id(post_id) - if not post: - raise posts.PostNotFoundError('Post %r not found.' % post_id) comment = comments.create_comment(ctx.user, post, text) ctx.session.add(comment) ctx.session.commit() @@ -33,41 +30,21 @@ class CommentDetailApi(BaseApi): def get(self, ctx, comment_id): auth.verify_privilege(ctx.user, 'comments:view') comment = comments.get_comment_by_id(comment_id) - if not comment: - raise comments.CommentNotFoundError( - 'Comment %r not found.' % comment_id) return {'comment': comments.serialize_comment(comment, ctx.user)} def put(self, ctx, comment_id): comment = comments.get_comment_by_id(comment_id) - if not comment: - raise comments.CommentNotFoundError( - 'Comment %r not found.' % comment_id) - - if ctx.user.user_id == comment.user_id: - infix = 'self' - else: - infix = 'any' - - comment.last_edit_time = datetime.datetime.now() - auth.verify_privilege(ctx.user, 'comments:edit:%s' % infix) + infix = 'self' if ctx.user.user_id == comment.user_id else 'any' text = ctx.get_param_as_string('text', required=True) + auth.verify_privilege(ctx.user, 'comments:edit:%s' % infix) + comment.last_edit_time = datetime.datetime.now() comments.update_comment_text(comment, text) - ctx.session.commit() return {'comment': comments.serialize_comment(comment, ctx.user)} def delete(self, ctx, comment_id): comment = comments.get_comment_by_id(comment_id) - if not comment: - raise comments.CommentNotFoundError( - 'Comment %r not found.' % comment_id) - - if ctx.user.user_id == comment.user_id: - infix = 'self' - else: - infix = 'any' - + infix = 'self' if ctx.user.user_id == comment.user_id else 'any' auth.verify_privilege(ctx.user, 'comments:delete:%s' % infix) ctx.session.delete(comment) ctx.session.commit() diff --git a/server/szurubooru/api/context.py b/server/szurubooru/api/context.py index 3b725fc1..58890e04 100644 --- a/server/szurubooru/api/context.py +++ b/server/szurubooru/api/context.py @@ -50,19 +50,15 @@ class Context(object): raise errors.ValidationError( 'Parameter %r is invalid: the value must be an integer.' % name) - if min is not None and val < min: raise errors.ValidationError( 'Parameter %r is invalid: the value must be at least %r.' % (name, min)) - if max is not None and val > max: raise errors.ValidationError( 'Parameter %r is invalid: the value may not exceed %r.' % (name, max)) - return val - if not required: return default raise errors.ValidationError( diff --git a/server/szurubooru/api/info_api.py b/server/szurubooru/api/info_api.py index d6c32e8f..29043b75 100644 --- a/server/szurubooru/api/info_api.py +++ b/server/szurubooru/api/info_api.py @@ -11,7 +11,7 @@ class InfoApi(BaseApi): self._cache_result = None def get(self, ctx): - featured_post = posts.get_featured_post() + featured_post = posts.try_get_featured_post() return { 'postCount': posts.get_post_count(), 'diskUsage': self._get_disk_usage(), diff --git a/server/szurubooru/api/password_reset_api.py b/server/szurubooru/api/password_reset_api.py index 93466705..cca9576e 100644 --- a/server/szurubooru/api/password_reset_api.py +++ b/server/szurubooru/api/password_reset_api.py @@ -12,8 +12,6 @@ class PasswordResetApi(BaseApi): def get(self, _ctx, user_name): ''' Send a mail with secure token to the correlated user. ''' user = users.get_user_by_name_or_email(user_name) - if not user: - raise errors.NotFoundError('User %r not found.' % user_name) if not user.email: raise errors.ValidationError( 'User %r hasn\'t supplied email. Cannot reset password.' % user_name) @@ -30,8 +28,6 @@ class PasswordResetApi(BaseApi): def post(self, ctx, user_name): ''' Verify token from mail, generate a new password and return it. ''' user = users.get_user_by_name_or_email(user_name) - if not user: - raise errors.NotFoundError('User %r not found.' % user_name) good_token = auth.generate_authentication_token(user) token = ctx.get_param_as_string('token', required=True) if token != good_token: diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index 23df5c7e..7565b92f 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -6,9 +6,7 @@ class PostFeatureApi(BaseApi): auth.verify_privilege(ctx.user, 'posts:feature') post_id = ctx.get_param_as_int('id', required=True) post = posts.get_post_by_id(post_id) - if not post: - raise posts.PostNotFoundError('Post %r not found.' % post_id) - featured_post = posts.get_featured_post() + featured_post = posts.try_get_featured_post() if featured_post and featured_post.post_id == post.post_id: raise posts.PostAlreadyFeaturedError( 'Post %r is already featured.' % post_id) @@ -20,5 +18,5 @@ class PostFeatureApi(BaseApi): return posts.serialize_post_with_details(post, ctx.user) def get(self, ctx): - post = posts.get_featured_post() + post = posts.try_get_featured_post() return posts.serialize_post_with_details(post, ctx.user) diff --git a/server/szurubooru/api/tag_api.py b/server/szurubooru/api/tag_api.py index e13b880e..99078214 100644 --- a/server/szurubooru/api/tag_api.py +++ b/server/szurubooru/api/tag_api.py @@ -35,31 +35,22 @@ class TagDetailApi(BaseApi): def get(self, ctx, tag_name): auth.verify_privilege(ctx.user, 'tags:view') tag = tags.get_tag_by_name(tag_name) - if not tag: - raise tags.TagNotFoundError('Tag %r not found.' % tag_name) return tags.serialize_tag_with_details(tag) def put(self, ctx, tag_name): tag = tags.get_tag_by_name(tag_name) - if not tag: - raise tags.TagNotFoundError('Tag %r not found.' % tag_name) - if ctx.has_param('names'): auth.verify_privilege(ctx.user, 'tags:edit:names') tags.update_names(tag, ctx.get_param_as_list('names')) - if ctx.has_param('category'): auth.verify_privilege(ctx.user, 'tags:edit:category') tags.update_category_name(tag, ctx.get_param_as_string('category')) - if ctx.has_param('suggestions'): auth.verify_privilege(ctx.user, 'tags:edit:suggestions') tags.update_suggestions(tag, ctx.get_param_as_list('suggestions')) - if ctx.has_param('implications'): auth.verify_privilege(ctx.user, 'tags:edit:implications') tags.update_implications(tag, ctx.get_param_as_list('implications')) - tag.last_edit_time = datetime.datetime.now() ctx.session.flush() snapshots.modify(tag, ctx.user) @@ -69,13 +60,10 @@ class TagDetailApi(BaseApi): def delete(self, ctx, tag_name): tag = tags.get_tag_by_name(tag_name) - if not tag: - raise tags.TagNotFoundError('Tag %r not found.' % tag_name) if tag.post_count > 0: raise tags.TagIsInUseError( 'Tag has some usages and cannot be deleted. ' + 'Please untag relevant posts first.') - auth.verify_privilege(ctx.user, 'tags:delete') snapshots.delete(tag, ctx.user) ctx.session.delete(tag) @@ -89,15 +77,8 @@ class TagMergeApi(BaseApi): target_tag_name = ctx.get_param_as_string('merge-to', required=True) or '' source_tag = tags.get_tag_by_name(source_tag_name) target_tag = tags.get_tag_by_name(target_tag_name) - if not source_tag: - raise tags.TagNotFoundError( - 'Source tag %r not found.' % source_tag_name) - if not target_tag: - raise tags.TagNotFoundError( - 'Source tag %r not found.' % target_tag_name) if source_tag.tag_id == target_tag.tag_id: - raise tags.InvalidTagRelationError( - 'Cannot merge tag with itself.') + raise tags.InvalidTagRelationError('Cannot merge tag with itself.') auth.verify_privilege(ctx.user, 'tags:merge') snapshots.delete(source_tag, ctx.user) tags.merge_tags(source_tag, target_tag) @@ -109,8 +90,6 @@ class TagSiblingsApi(BaseApi): def get(self, ctx, tag_name): auth.verify_privilege(ctx.user, 'tags:view') tag = tags.get_tag_by_name(tag_name) - if not tag: - raise tags.TagNotFoundError('Tag %r not found.' % tag_name) result = tags.get_siblings(tag) serialized_siblings = [] for sibling, occurrences in result: diff --git a/server/szurubooru/api/tag_category_api.py b/server/szurubooru/api/tag_category_api.py index bc6cbf3f..798fa1fc 100644 --- a/server/szurubooru/api/tag_category_api.py +++ b/server/szurubooru/api/tag_category_api.py @@ -7,7 +7,8 @@ class TagCategoryListApi(BaseApi): categories = tag_categories.get_all_categories() return { 'tagCategories': [ - tags.serialize_category(category) for category in categories], + tag_categories.serialize_category(category) \ + for category in categories], } def post(self, ctx): @@ -20,22 +21,16 @@ class TagCategoryListApi(BaseApi): snapshots.create(category, ctx.user) ctx.session.commit() tags.export_to_json() - return tags.serialize_category_with_details(category) + return tag_categories.serialize_category_with_details(category) class TagCategoryDetailApi(BaseApi): def get(self, ctx, category_name): auth.verify_privilege(ctx.user, 'tag_categories:view') category = tag_categories.get_category_by_name(category_name) - if not category: - raise tag_categories.TagCategoryNotFoundError( - 'Tag category %r not found.' % category_name) - return tags.serialize_category_with_details(category) + return tag_categories.serialize_category_with_details(category) def put(self, ctx, category_name): category = tag_categories.get_category_by_name(category_name) - if not category: - raise tag_categories.TagCategoryNotFoundError( - 'Tag category %r not found.' % category_name) if ctx.has_param('name'): auth.verify_privilege(ctx.user, 'tag_categories:edit:name') tag_categories.update_name( @@ -48,13 +43,10 @@ class TagCategoryDetailApi(BaseApi): snapshots.modify(category, ctx.user) ctx.session.commit() tags.export_to_json() - return tags.serialize_category_with_details(category) + return tag_categories.serialize_category_with_details(category) def delete(self, ctx, category_name): category = tag_categories.get_category_by_name(category_name) - if not category: - raise tag_categories.TagCategoryNotFoundError( - 'Tag category %r not found.' % category_name) auth.verify_privilege(ctx.user, 'tag_categories:delete') if len(tag_categories.get_all_category_names()) == 1: raise tag_categories.TagCategoryIsInUseError( diff --git a/server/szurubooru/api/user_api.py b/server/szurubooru/api/user_api.py index 4ba2d7ad..4c91fee0 100644 --- a/server/szurubooru/api/user_api.py +++ b/server/szurubooru/api/user_api.py @@ -14,22 +14,17 @@ class UserListApi(BaseApi): def post(self, ctx): auth.verify_privilege(ctx.user, 'users:create') - name = ctx.get_param_as_string('name', required=True) password = ctx.get_param_as_string('password', required=True) email = ctx.get_param_as_string('email', required=False, default='') - user = users.create_user(name, password, email, ctx.user) - if ctx.has_param('rank'): users.update_rank(user, ctx.get_param_as_string('rank'), ctx.user) - if ctx.has_param('avatarStyle'): users.update_avatar( user, ctx.get_param_as_string('avatarStyle'), ctx.get_file('avatar')) - ctx.session.add(user) ctx.session.commit() return {'user': users.serialize_user(user, ctx.user)} @@ -38,56 +33,35 @@ class UserDetailApi(BaseApi): def get(self, ctx, user_name): auth.verify_privilege(ctx.user, 'users:view') user = users.get_user_by_name(user_name) - if not user: - raise users.UserNotFoundError('User %r not found.' % user_name) return {'user': users.serialize_user(user, ctx.user)} def put(self, ctx, user_name): user = users.get_user_by_name(user_name) - if not user: - raise users.UserNotFoundError('User %r not found.' % user_name) - - if ctx.user.user_id == user.user_id: - infix = 'self' - else: - infix = 'any' - + infix = 'self' if ctx.user.user_id == user.user_id else 'any' if ctx.has_param('name'): auth.verify_privilege(ctx.user, 'users:edit:%s:name' % infix) users.update_name(user, ctx.get_param_as_string('name'), ctx.user) - if ctx.has_param('password'): auth.verify_privilege(ctx.user, 'users:edit:%s:pass' % infix) users.update_password(user, ctx.get_param_as_string('password')) - if ctx.has_param('email'): auth.verify_privilege(ctx.user, 'users:edit:%s:email' % infix) users.update_email(user, ctx.get_param_as_string('email')) - if ctx.has_param('rank'): auth.verify_privilege(ctx.user, 'users:edit:%s:rank' % infix) users.update_rank(user, ctx.get_param_as_string('rank'), ctx.user) - if ctx.has_param('avatarStyle'): auth.verify_privilege(ctx.user, 'users:edit:%s:avatar' % infix) users.update_avatar( user, ctx.get_param_as_string('avatarStyle'), ctx.get_file('avatar')) - ctx.session.commit() return {'user': users.serialize_user(user, ctx.user)} def delete(self, ctx, user_name): user = users.get_user_by_name(user_name) - if not user: - raise users.UserNotFoundError('User %r not found.' % user_name) - - if ctx.user.user_id == user.user_id: - infix = 'self' - else: - infix = 'any' - + infix = 'self' if ctx.user.user_id == user.user_id else 'any' auth.verify_privilege(ctx.user, 'users:delete:%s' % infix) ctx.session.delete(user) ctx.session.commit() diff --git a/server/szurubooru/func/comments.py b/server/szurubooru/func/comments.py index 873bc7dc..576cca23 100644 --- a/server/szurubooru/func/comments.py +++ b/server/szurubooru/func/comments.py @@ -15,12 +15,18 @@ def serialize_comment(comment, authenticated_user): 'lastEditTime': comment.last_edit_time, } -def get_comment_by_id(comment_id): +def try_get_comment_by_id(comment_id): return db.session \ .query(db.Comment) \ .filter(db.Comment.comment_id == comment_id) \ .one_or_none() +def get_comment_by_id(comment_id): + comment = try_get_comment_by_id(comment_id) + if comment: + return comment + raise CommentNotFoundError('Comment %r not found.' % comment_id) + def create_comment(user, post, text): comment = db.Comment() comment.user = user diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 8811caa2..5d947b20 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -49,13 +49,19 @@ def serialize_post_with_details(post, authenticated_user): def get_post_count(): return db.session.query(sqlalchemy.func.count(db.Post.post_id)).one()[0] -def get_post_by_id(post_id): +def try_get_post_by_id(post_id): return db.session \ .query(db.Post) \ .filter(db.Post.post_id == post_id) \ .one_or_none() -def get_featured_post(): +def get_post_by_id(post_id): + post = try_get_post_by_id(post_id) + if not post: + raise PostNotFoundError('Post %r not found.' % post_id) + return post + +def try_get_featured_post(): post_feature = db.session \ .query(db.PostFeature) \ .order_by(db.PostFeature.time.desc()) \ diff --git a/server/szurubooru/func/tag_categories.py b/server/szurubooru/func/tag_categories.py index c4a55cd7..90bb3876 100644 --- a/server/szurubooru/func/tag_categories.py +++ b/server/szurubooru/func/tag_categories.py @@ -1,6 +1,6 @@ import re from szurubooru import config, db, errors -from szurubooru.func import util +from szurubooru.func import util, snapshots class TagCategoryNotFoundError(errors.NotFoundError): pass class TagCategoryAlreadyExistsError(errors.ValidationError): pass @@ -14,6 +14,18 @@ def _verify_name_validity(name): raise InvalidTagCategoryNameError( 'Name must satisfy regex %r.' % name_regex) +def serialize_category(category): + return { + 'name': category.name, + 'color': category.color, + } + +def serialize_category_with_details(category): + return { + 'tagCategory': serialize_category(category), + 'snapshots': snapshots.get_serialized_history(category), + } + def create_category(name, color): category = db.TagCategory() update_name(category, name) @@ -42,11 +54,17 @@ def update_color(category, color): raise InvalidTagCategoryColorError('Color is too long.') category.color = color -def get_category_by_name(name): +def try_get_category_by_name(name): return db.session \ .query(db.TagCategory) \ .filter(db.TagCategory.name.ilike(name)) \ - .first() + .one_or_none() + +def get_category_by_name(name): + category = try_get_category_by_name(name) + if not category: + raise TagCategoryNotFoundError('Tag category %r not found.' % name) + return category def get_all_category_names(): return [row[0] for row in db.session.query(db.TagCategory.name).all()] @@ -54,9 +72,15 @@ def get_all_category_names(): def get_all_categories(): return db.session.query(db.TagCategory).all() -def get_default_category(): +def try_get_default_category(): return db.session \ .query(db.TagCategory) \ .order_by(db.TagCategory.tag_category_id.asc()) \ .limit(1) \ .one() + +def get_default_category(): + category = try_get_default_category() + if not category: + raise TagCategoryNotFoundError('No tag category created yet.') + return category diff --git a/server/szurubooru/func/tags.py b/server/szurubooru/func/tags.py index 7b5c70fb..14619a02 100644 --- a/server/szurubooru/func/tags.py +++ b/server/szurubooru/func/tags.py @@ -13,6 +13,20 @@ class InvalidTagNameError(errors.ValidationError): pass class InvalidTagCategoryError(errors.ValidationError): pass class InvalidTagRelationError(errors.ValidationError): pass +def _verify_name_validity(name): + name_regex = config.config['tag_name_regex'] + if not re.match(name_regex, name): + raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex) + +def _get_plain_names(tag): + return [tag_name.name for tag_name in tag.names] + +def _lower_list(names): + return [name.lower() for name in names] + +def _check_name_intersection(names1, names2): + return len(set(_lower_list(names1)).intersection(_lower_list(names2))) > 0 + def serialize_tag(tag): return { 'names': [tag_name.name for tag_name in tag.names], @@ -31,32 +45,6 @@ def serialize_tag_with_details(tag): 'snapshots': snapshots.get_serialized_history(tag), } -def serialize_category(category): - return { - 'name': category.name, - 'color': category.color, - } - -def serialize_category_with_details(category): - return { - 'tagCategory': serialize_category(category), - 'snapshots': snapshots.get_serialized_history(category), - } - -def _verify_name_validity(name): - name_regex = config.config['tag_name_regex'] - if not re.match(name_regex, name): - raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex) - -def _get_plain_names(tag): - return [tag_name.name for tag_name in tag.names] - -def _lower_list(names): - return [name.lower() for name in names] - -def _check_name_intersection(names1, names2): - return len(set(_lower_list(names1)).intersection(_lower_list(names2))) > 0 - def export_to_json(): output = { 'tags': [], @@ -90,12 +78,18 @@ def export_to_json(): with open(export_path, 'w') as handle: handle.write(json.dumps(output, separators=(',', ':'))) -def get_tag_by_name(name): +def try_get_tag_by_name(name): return db.session \ .query(db.Tag) \ .join(db.TagName) \ .filter(db.TagName.name.ilike(name)) \ - .first() + .one_or_none() + +def get_tag_by_name(name): + tag = try_get_tag_by_name(name) + if not tag: + raise TagNotFoundError('Tag %r not found.' % name) + return tag def get_tags_by_names(names): names = util.icase_unique(names) diff --git a/server/szurubooru/func/users.py b/server/szurubooru/func/users.py index 85272fff..5a418813 100644 --- a/server/szurubooru/func/users.py +++ b/server/szurubooru/func/users.py @@ -44,19 +44,31 @@ def serialize_user(user, authenticated_user): def get_user_count(): return db.session.query(db.User).count() -def get_user_by_name(name): +def try_get_user_by_name(name): return db.session \ .query(db.User) \ .filter(func.lower(db.User.name) == func.lower(name)) \ - .first() + .one_or_none() -def get_user_by_name_or_email(name_or_email): +def get_user_by_name(name): + user = try_get_user_by_name(name) + if not user: + raise UserNotFoundError('User %r not found.' % name) + return user + +def try_get_user_by_name_or_email(name_or_email): return db.session \ .query(db.User) \ .filter( (func.lower(db.User.name) == func.lower(name_or_email)) | (func.lower(db.User.email) == func.lower(name_or_email))) \ - .first() + .one_or_none() + +def get_user_by_name_or_email(name_or_email): + user = try_get_user_by_name_or_email(name_or_email) + if not user: + raise UserNotFoundError('User %r not found.' % name_or_email) + return user def create_user(name, password, email, auth_user): user = db.User() @@ -76,7 +88,7 @@ def update_name(user, name, auth_user): raise InvalidUserNameError('Name cannot be empty.') if util.value_exceeds_column_size(name, db.User.name): raise InvalidUserNameError('User name is too long.') - other_user = get_user_by_name(name) + other_user = try_get_user_by_name(name) if other_user and other_user.user_id != auth_user.user_id: raise UserAlreadyExistsError('User %r already exists.' % name) name = name.strip() diff --git a/server/szurubooru/tests/api/test_post_featuring.py b/server/szurubooru/tests/api/test_post_featuring.py index 9bb751ba..a7cc24ae 100644 --- a/server/szurubooru/tests/api/test_post_featuring.py +++ b/server/szurubooru/tests/api/test_post_featuring.py @@ -20,7 +20,7 @@ def test_ctx(context_factory, config_injector, user_factory, post_factory): return ret def test_no_featured_post(test_ctx): - assert posts.get_featured_post() is None + assert posts.try_get_featured_post() is None result = test_ctx.api.get( test_ctx.context_factory( user=test_ctx.user_factory(rank='regular_user'))) @@ -34,8 +34,8 @@ def test_featuring(test_ctx): test_ctx.context_factory( input={'id': 1}, user=test_ctx.user_factory(rank='regular_user'))) - assert posts.get_featured_post() is not None - assert posts.get_featured_post().post_id == 1 + assert posts.try_get_featured_post() is not None + assert posts.try_get_featured_post().post_id == 1 assert posts.get_post_by_id(1).is_featured assert 'post' in result assert 'snapshots' in result @@ -63,7 +63,7 @@ def test_featuring_one_post_after_another(test_ctx, fake_datetime): db.session.add(test_ctx.post_factory(id=1)) db.session.add(test_ctx.post_factory(id=2)) db.session.commit() - assert posts.get_featured_post() is None + assert posts.try_get_featured_post() is None assert not posts.get_post_by_id(1).is_featured assert not posts.get_post_by_id(2).is_featured with fake_datetime('1997'): @@ -76,8 +76,8 @@ def test_featuring_one_post_after_another(test_ctx, fake_datetime): test_ctx.context_factory( input={'id': 2}, user=test_ctx.user_factory(rank='regular_user'))) - assert posts.get_featured_post() is not None - assert posts.get_featured_post().post_id == 2 + assert posts.try_get_featured_post() is not None + assert posts.try_get_featured_post().post_id == 2 assert not posts.get_post_by_id(1).is_featured assert posts.get_post_by_id(2).is_featured diff --git a/server/szurubooru/tests/api/test_tag_category_updating.py b/server/szurubooru/tests/api/test_tag_category_updating.py index 3924c7f7..592cdd4e 100644 --- a/server/szurubooru/tests/api/test_tag_category_updating.py +++ b/server/szurubooru/tests/api/test_tag_category_updating.py @@ -43,7 +43,7 @@ def test_simple_updating(test_ctx): 'color': 'white', } assert len(result['snapshots']) == 1 - assert tag_categories.get_category_by_name('name') is None + assert tag_categories.try_get_category_by_name('name') is None category = tag_categories.get_category_by_name('changed') assert category is not None assert category.name == 'changed' diff --git a/server/szurubooru/tests/api/test_tag_creating.py b/server/szurubooru/tests/api/test_tag_creating.py index 5061d079..5f4b6fee 100644 --- a/server/szurubooru/tests/api/test_tag_creating.py +++ b/server/szurubooru/tests/api/test_tag_creating.py @@ -4,13 +4,6 @@ import pytest from szurubooru import api, config, db, errors from szurubooru.func import util, tags -def get_tag(name): - return db.session \ - .query(db.Tag) \ - .join(db.TagName) \ - .filter(db.TagName.name==name) \ - .first() - def assert_relations(relations, expected_tag_names): actual_names = [rel.names[0].name for rel in relations] assert actual_names == expected_tag_names @@ -54,7 +47,7 @@ def test_creating_simple_tags(test_ctx, fake_datetime): 'lastEditTime': None, } assert len(result['snapshots']) == 1 - tag = get_tag('tag1') + tag = tags.get_tag_by_name('tag1') assert [tag_name.name for tag_name in tag.names] == ['tag1', 'tag2'] assert tag.category.name == 'meta' assert tag.last_edit_time is None @@ -133,7 +126,7 @@ def test_duplicating_names(test_ctx): user=test_ctx.user_factory(rank='regular_user'))) assert result['tag']['names'] == ['tag1'] assert result['tag']['category'] == 'meta' - tag = get_tag('tag1') + tag = tags.get_tag_by_name('tag1') assert [tag_name.name for tag_name in tag.names] == ['tag1'] def test_trying_to_use_existing_name(test_ctx): @@ -162,7 +155,7 @@ def test_trying_to_use_existing_name(test_ctx): 'implications': [], }, user=test_ctx.user_factory(rank='regular_user'))) - assert get_tag('unused') is None + assert tags.try_get_tag_by_name('unused') is None @pytest.mark.parametrize('input,expected_suggestions,expected_implications', [ # new relations @@ -201,11 +194,11 @@ def test_creating_new_suggestions_and_implications( input=input, user=test_ctx.user_factory(rank='regular_user'))) assert result['tag']['suggestions'] == expected_suggestions assert result['tag']['implications'] == expected_implications - tag = get_tag('main') + tag = tags.get_tag_by_name('main') assert_relations(tag.suggestions, expected_suggestions) assert_relations(tag.implications, expected_implications) for name in ['main'] + expected_suggestions + expected_implications: - assert get_tag(name) is not None + assert tags.try_get_tag_by_name(name) is not None def test_reusing_suggestions_and_implications(test_ctx): db.session.add_all([ @@ -225,7 +218,7 @@ def test_reusing_suggestions_and_implications(test_ctx): # NOTE: it should export only the first name assert result['tag']['suggestions'] == ['tag1'] assert result['tag']['implications'] == ['tag1'] - tag = get_tag('new') + tag = tags.get_tag_by_name('new') assert_relations(tag.suggestions, ['tag1']) assert_relations(tag.implications, ['tag1']) @@ -249,7 +242,7 @@ def test_tag_trying_to_relate_to_itself(test_ctx, input): test_ctx.context_factory( input=input, user=test_ctx.user_factory(rank='regular_user'))) - assert get_tag('tag') is None + assert tags.try_get_tag_by_name('tag') is None def test_trying_to_create_tag_without_privileges(test_ctx): with pytest.raises(errors.AuthError): diff --git a/server/szurubooru/tests/api/test_tag_merging.py b/server/szurubooru/tests/api/test_tag_merging.py index e6dc0984..2d7a2b69 100644 --- a/server/szurubooru/tests/api/test_tag_merging.py +++ b/server/szurubooru/tests/api/test_tag_merging.py @@ -4,13 +4,6 @@ import pytest from szurubooru import api, config, db, errors from szurubooru.func import util, tags -def get_tag(name): - return db.session \ - .query(db.Tag) \ - .join(db.TagName) \ - .filter(db.TagName.name==name) \ - .first() - @pytest.fixture def test_ctx( tmpdir, config_injector, context_factory, user_factory, tag_factory): @@ -50,8 +43,8 @@ def test_merging_without_usages(test_ctx, fake_datetime): 'lastEditTime': None, } assert 'snapshots' in result - assert get_tag('source') is None - tag = get_tag('target') + assert tags.try_get_tag_by_name('source') is None + tag = tags.get_tag_by_name('target') assert tag is not None assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json')) @@ -76,8 +69,8 @@ def test_merging_with_usages(test_ctx, fake_datetime, post_factory): 'merge-to': 'target', }, user=test_ctx.user_factory(rank='regular_user'))) - assert get_tag('source') is None - assert get_tag('target').post_count == 1 + assert tags.try_get_tag_by_name('source') is None + assert tags.get_tag_by_name('target').post_count == 1 @pytest.mark.parametrize('input,expected_exception', [ ({'remove': None}, tags.TagNotFoundError), diff --git a/server/szurubooru/tests/api/test_tag_updating.py b/server/szurubooru/tests/api/test_tag_updating.py index 282062ba..f8333993 100644 --- a/server/szurubooru/tests/api/test_tag_updating.py +++ b/server/szurubooru/tests/api/test_tag_updating.py @@ -4,13 +4,6 @@ import pytest from szurubooru import api, config, db, errors from szurubooru.func import util, tags -def get_tag(name): - return db.session \ - .query(db.Tag) \ - .join(db.TagName) \ - .filter(db.TagName.name==name) \ - .first() - def assert_relations(relations, expected_tag_names): actual_names = [rel.names[0].name for rel in relations] assert actual_names == expected_tag_names @@ -61,9 +54,9 @@ def test_simple_updating(test_ctx, fake_datetime): 'lastEditTime': datetime.datetime(1997, 12, 1), } assert len(result['snapshots']) == 1 - assert get_tag('tag1') is None - assert get_tag('tag2') is None - tag = get_tag('tag3') + assert tags.try_get_tag_by_name('tag1') is None + assert tags.try_get_tag_by_name('tag2') is None + tag = tags.get_tag_by_name('tag3') assert tag is not None assert [tag_name.name for tag_name in tag.names] == ['tag3'] assert tag.category.name == 'character' @@ -132,9 +125,9 @@ def test_reusing_own_name(test_ctx, dup_name): user=test_ctx.user_factory(rank='regular_user')), 'tag1') assert result['tag']['names'] == ['tag1', 'tag3'] - assert get_tag('tag2') is None - tag1 = get_tag('tag1') - tag2 = get_tag('tag3') + assert tags.try_get_tag_by_name('tag2') is None + tag1 = tags.get_tag_by_name('tag1') + tag2 = tags.get_tag_by_name('tag3') assert tag1.tag_id == tag2.tag_id assert [name.name for name in tag1.names] == ['tag1', 'tag3'] @@ -147,9 +140,9 @@ def test_duplicating_names(test_ctx): user=test_ctx.user_factory(rank='regular_user')), 'tag1') assert result['tag']['names'] == ['tag3'] - assert get_tag('tag1') is None - assert get_tag('tag2') is None - tag = get_tag('tag3') + assert tags.try_get_tag_by_name('tag1') is None + assert tags.try_get_tag_by_name('tag2') is None + tag = tags.get_tag_by_name('tag3') assert tag is not None assert [tag_name.name for tag_name in tag.names] == ['tag3'] @@ -199,11 +192,11 @@ def test_updating_new_suggestions_and_implications( 'main') assert result['tag']['suggestions'] == expected_suggestions assert result['tag']['implications'] == expected_implications - tag = get_tag('main') + tag = tags.get_tag_by_name('main') assert_relations(tag.suggestions, expected_suggestions) assert_relations(tag.implications, expected_implications) for name in ['main'] + expected_suggestions + expected_implications: - assert get_tag(name) is not None + assert tags.try_get_tag_by_name(name) is not None def test_reusing_suggestions_and_implications(test_ctx): db.session.add_all([ @@ -225,7 +218,7 @@ def test_reusing_suggestions_and_implications(test_ctx): # NOTE: it should export only the first name assert result['tag']['suggestions'] == ['tag1'] assert result['tag']['implications'] == ['tag1'] - tag = get_tag('new') + tag = tags.get_tag_by_name('new') assert_relations(tag.suggestions, ['tag1']) assert_relations(tag.implications, ['tag1']) diff --git a/server/szurubooru/tests/api/test_user_creating.py b/server/szurubooru/tests/api/test_user_creating.py index 274f0f40..3bb18ff4 100644 --- a/server/szurubooru/tests/api/test_user_creating.py +++ b/server/szurubooru/tests/api/test_user_creating.py @@ -8,9 +8,6 @@ EMPTY_PIXEL = \ b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \ b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' -def get_user(name): - return db.session.query(db.User).filter_by(name=name).first() - @pytest.fixture def test_ctx(config_injector, context_factory, user_factory): config_injector({ @@ -51,7 +48,7 @@ def test_creating_user(test_ctx, fake_datetime): 'rankName': 'Unknown', } } - user = get_user('chewie1') + user = users.get_user_by_name('chewie1') assert user.name == 'chewie1' assert user.email == 'asd@asd.asd' assert user.rank == 'admin' @@ -77,8 +74,8 @@ def test_first_user_becomes_admin_others_not(test_ctx): user=test_ctx.user_factory(rank='anonymous'))) assert result1['user']['rank'] == 'admin' assert result2['user']['rank'] == 'regular_user' - first_user = get_user('chewie1') - other_user = get_user('chewie2') + first_user = users.get_user_by_name('chewie1') + other_user = users.get_user_by_name('chewie2') assert first_user.rank == 'admin' assert other_user.rank == 'regular_user' @@ -225,7 +222,7 @@ def test_uploading_avatar(test_ctx, tmpdir): }, files={'avatar': EMPTY_PIXEL}, user=test_ctx.user_factory(rank='mod'))) - user = get_user('chewie') + user = users.get_user_by_name('chewie') assert user.avatar_style == user.AVATAR_MANUAL assert response['user']['avatarUrl'] == \ 'http://example.com/data/avatars/chewie.jpg' diff --git a/server/szurubooru/tests/api/test_user_updating.py b/server/szurubooru/tests/api/test_user_updating.py index e069c24e..72723f4e 100644 --- a/server/szurubooru/tests/api/test_user_updating.py +++ b/server/szurubooru/tests/api/test_user_updating.py @@ -8,9 +8,6 @@ EMPTY_PIXEL = \ b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \ b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' -def get_user(name): - return db.session.query(db.User).filter_by(name=name).first() - @pytest.fixture def test_ctx(config_injector, context_factory, user_factory): config_injector({ @@ -66,7 +63,7 @@ def test_updating_user(test_ctx): 'rankName': 'Unknown', } } - user = get_user('chewie') + user = users.get_user_by_name('chewie') assert user.name == 'chewie' assert user.email == 'asd@asd.asd' assert user.rank == 'mod' @@ -133,7 +130,7 @@ def test_removing_email(test_ctx): db.session.add(user) test_ctx.api.put( test_ctx.context_factory(input={'email': ''}, user=user), 'u1') - assert get_user('u1').email is None + assert users.get_user_by_name('u1').email is None @pytest.mark.parametrize('input', [ {'name': 'whatever'}, @@ -183,7 +180,7 @@ def test_uploading_avatar(test_ctx, tmpdir): files={'avatar': EMPTY_PIXEL}, user=user), 'u1') - user = get_user('u1') + user = users.get_user_by_name('u1') assert user.avatar_style == user.AVATAR_MANUAL assert response['user']['avatarUrl'] == \ 'http://example.com/data/avatars/u1.jpg'