diff --git a/server/.pylintrc b/server/.pylintrc index 1188f6df..70f46dda 100644 --- a/server/.pylintrc +++ b/server/.pylintrc @@ -8,11 +8,26 @@ good-names=ex,_,logger dummy-variables-rgx=_|dummy [format] -max-line-length=90 +max-line-length=79 [messages control] -disable=missing-docstring,no-self-use,too-few-public-methods,multiple-statements reports=no +disable= + # we're not java + missing-docstring, + + # covered better by pycodestyle + bad-continuation, + + # we're adults + redefined-builtin, + duplicate-code, + too-many-return-statements, + too-many-arguments, + + # plain stupid + no-self-use, + too-few-public-methods [typecheck] generated-members=add|add_all diff --git a/server/szurubooru/api/comment_api.py b/server/szurubooru/api/comment_api.py index 7ac72bc2..08fd09c4 100644 --- a/server/szurubooru/api/comment_api.py +++ b/server/szurubooru/api/comment_api.py @@ -3,20 +3,24 @@ from szurubooru import search from szurubooru.func import auth, comments, posts, scores, util from szurubooru.rest import routes + _search_executor = search.Executor(search.configs.CommentSearchConfig()) + def _serialize(ctx, comment, **kwargs): return comments.serialize_comment( comment, ctx.user, options=util.get_serialization_options(ctx), **kwargs) + @routes.get('/comments/?') def get_comments(ctx, _params=None): auth.verify_privilege(ctx.user, 'comments:list') return _search_executor.execute_and_serialize( ctx, lambda comment: _serialize(ctx, comment)) + @routes.post('/comments/?') def create_comment(ctx, _params=None): auth.verify_privilege(ctx.user, 'comments:create') @@ -28,12 +32,14 @@ def create_comment(ctx, _params=None): ctx.session.commit() return _serialize(ctx, comment) + @routes.get('/comment/(?P[^/]+)/?') def get_comment(ctx, params): auth.verify_privilege(ctx.user, 'comments:view') comment = comments.get_comment_by_id(params['comment_id']) return _serialize(ctx, comment) + @routes.put('/comment/(?P[^/]+)/?') def update_comment(ctx, params): comment = comments.get_comment_by_id(params['comment_id']) @@ -47,6 +53,7 @@ def update_comment(ctx, params): ctx.session.commit() return _serialize(ctx, comment) + @routes.delete('/comment/(?P[^/]+)/?') def delete_comment(ctx, params): comment = comments.get_comment_by_id(params['comment_id']) @@ -57,6 +64,7 @@ def delete_comment(ctx, params): ctx.session.commit() return {} + @routes.put('/comment/(?P[^/]+)/score/?') def set_comment_score(ctx, params): auth.verify_privilege(ctx.user, 'comments:score') @@ -66,6 +74,7 @@ def set_comment_score(ctx, params): ctx.session.commit() return _serialize(ctx, comment) + @routes.delete('/comment/(?P[^/]+)/score/?') def delete_comment_score(ctx, params): auth.verify_privilege(ctx.user, 'comments:score') diff --git a/server/szurubooru/api/info_api.py b/server/szurubooru/api/info_api.py index 16a4e384..a5485546 100644 --- a/server/szurubooru/api/info_api.py +++ b/server/szurubooru/api/info_api.py @@ -4,11 +4,13 @@ from szurubooru import config from szurubooru.func import posts, users, util from szurubooru.rest import routes + _cache_time = None _cache_result = None + def _get_disk_usage(): - global _cache_time, _cache_result # pylint: disable=global-statement + global _cache_time, _cache_result # pylint: disable=global-statement threshold = datetime.timedelta(hours=1) now = datetime.datetime.utcnow() if _cache_time and _cache_time > now - threshold: @@ -22,17 +24,20 @@ def _get_disk_usage(): _cache_result = total_size return total_size + @routes.get('/info/?') def get_info(ctx, _params=None): post_feature = posts.try_get_current_post_feature() return { 'postCount': posts.get_post_count(), 'diskUsage': _get_disk_usage(), - 'featuredPost': posts.serialize_post(post_feature.post, ctx.user) \ - if post_feature else None, + 'featuredPost': + posts.serialize_post(post_feature.post, ctx.user) + if post_feature else None, 'featuringTime': post_feature.time if post_feature else None, - 'featuringUser': users.serialize_user(post_feature.user, ctx.user) \ - if post_feature else None, + 'featuringUser': + users.serialize_user(post_feature.user, ctx.user) + if post_feature else None, 'serverTime': datetime.datetime.utcnow(), 'config': { 'userNameRegex': config.config['user_name_regex'], @@ -40,7 +45,8 @@ def get_info(ctx, _params=None): 'tagNameRegex': config.config['tag_name_regex'], 'tagCategoryNameRegex': config.config['tag_category_name_regex'], 'defaultUserRank': config.config['default_rank'], - 'privileges': util.snake_case_to_lower_camel_case_keys( - config.config['privileges']), + 'privileges': + util.snake_case_to_lower_camel_case_keys( + config.config['privileges']), }, } diff --git a/server/szurubooru/api/password_reset_api.py b/server/szurubooru/api/password_reset_api.py index 2040fb64..fbc4ba4d 100644 --- a/server/szurubooru/api/password_reset_api.py +++ b/server/szurubooru/api/password_reset_api.py @@ -2,12 +2,14 @@ from szurubooru import config, errors from szurubooru.func import auth, mailer, users, util from szurubooru.rest import routes + MAIL_SUBJECT = 'Password reset for {name}' MAIL_BODY = \ 'You (or someone else) requested to reset your password on {name}.\n' \ 'If you wish to proceed, click this link: {url}\n' \ 'Otherwise, please ignore this email.' + @routes.get('/password-reset/(?P[^/]+)/?') def start_password_reset(_ctx, params): ''' Send a mail with secure token to the correlated user. ''' @@ -27,6 +29,7 @@ def start_password_reset(_ctx, params): MAIL_BODY.format(name=config.config['name'], url=url)) return {} + @routes.post('/password-reset/(?P[^/]+)/?') def finish_password_reset(ctx, params): ''' Verify token from mail, generate a new password and return it. ''' diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index b75f40b7..e800bd54 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -1,16 +1,20 @@ import datetime from szurubooru import search -from szurubooru.func import auth, tags, posts, snapshots, favorites, scores, util from szurubooru.rest import routes +from szurubooru.func import ( + auth, tags, posts, snapshots, favorites, scores, util) + _search_executor = search.Executor(search.configs.PostSearchConfig()) + def _serialize_post(ctx, post): return posts.serialize_post( post, ctx.user, options=util.get_serialization_options(ctx)) + @routes.get('/posts/?') def get_posts(ctx, _params=None): auth.verify_privilege(ctx.user, 'posts:list') @@ -18,6 +22,7 @@ def get_posts(ctx, _params=None): return _search_executor.execute_and_serialize( ctx, lambda post: _serialize_post(ctx, post)) + @routes.post('/posts/?') def create_post(ctx, _params=None): anonymous = ctx.get_param_as_bool('anonymous', default=False) @@ -52,12 +57,14 @@ def create_post(ctx, _params=None): tags.export_to_json() return _serialize_post(ctx, post) + @routes.get('/post/(?P[^/]+)/?') def get_post(ctx, params): auth.verify_privilege(ctx.user, 'posts:view') post = posts.get_post_by_id(params['post_id']) return _serialize_post(ctx, post) + @routes.put('/post/(?P[^/]+)/?') def update_post(ctx, params): post = posts.get_post_by_id(params['post_id']) @@ -98,6 +105,7 @@ def update_post(ctx, params): tags.export_to_json() return _serialize_post(ctx, post) + @routes.delete('/post/(?P[^/]+)/?') def delete_post(ctx, params): auth.verify_privilege(ctx.user, 'posts:delete') @@ -109,11 +117,13 @@ def delete_post(ctx, params): tags.export_to_json() return {} + @routes.get('/featured-post/?') def get_featured_post(ctx, _params=None): post = posts.try_get_featured_post() return _serialize_post(ctx, post) + @routes.post('/featured-post/?') def set_featured_post(ctx, _params=None): auth.verify_privilege(ctx.user, 'posts:feature') @@ -130,6 +140,7 @@ def set_featured_post(ctx, _params=None): ctx.session.commit() return _serialize_post(ctx, post) + @routes.put('/post/(?P[^/]+)/score/?') def set_post_score(ctx, params): auth.verify_privilege(ctx.user, 'posts:score') @@ -139,6 +150,7 @@ def set_post_score(ctx, params): ctx.session.commit() return _serialize_post(ctx, post) + @routes.delete('/post/(?P[^/]+)/score/?') def delete_post_score(ctx, params): auth.verify_privilege(ctx.user, 'posts:score') @@ -147,6 +159,7 @@ def delete_post_score(ctx, params): ctx.session.commit() return _serialize_post(ctx, post) + @routes.post('/post/(?P[^/]+)/favorite/?') def add_post_to_favorites(ctx, params): auth.verify_privilege(ctx.user, 'posts:favorite') @@ -155,6 +168,7 @@ def add_post_to_favorites(ctx, params): ctx.session.commit() return _serialize_post(ctx, post) + @routes.delete('/post/(?P[^/]+)/favorite/?') def delete_post_from_favorites(ctx, params): auth.verify_privilege(ctx.user, 'posts:favorite') @@ -163,6 +177,7 @@ def delete_post_from_favorites(ctx, params): ctx.session.commit() return _serialize_post(ctx, post) + @routes.get('/post/(?P[^/]+)/around/?') def get_posts_around(ctx, params): auth.verify_privilege(ctx.user, 'posts:list') diff --git a/server/szurubooru/api/snapshot_api.py b/server/szurubooru/api/snapshot_api.py index 1fd6fc52..bf6eba86 100644 --- a/server/szurubooru/api/snapshot_api.py +++ b/server/szurubooru/api/snapshot_api.py @@ -2,8 +2,9 @@ from szurubooru import search from szurubooru.func import auth, snapshots from szurubooru.rest import routes -_search_executor = search.Executor( - search.configs.SnapshotSearchConfig()) + +_search_executor = search.Executor(search.configs.SnapshotSearchConfig()) + @routes.get('/snapshots/?') def get_snapshots(ctx, _params=None): diff --git a/server/szurubooru/api/tag_api.py b/server/szurubooru/api/tag_api.py index 079e3dd5..630cf19b 100644 --- a/server/szurubooru/api/tag_api.py +++ b/server/szurubooru/api/tag_api.py @@ -3,12 +3,15 @@ from szurubooru import db, search from szurubooru.func import auth, tags, util, snapshots from szurubooru.rest import routes + _search_executor = search.Executor(search.configs.TagSearchConfig()) + def _serialize(ctx, tag): return tags.serialize_tag( tag, options=util.get_serialization_options(ctx)) + def _create_if_needed(tag_names, user): if not tag_names: return @@ -19,12 +22,14 @@ def _create_if_needed(tag_names, user): for tag in new_tags: snapshots.save_entity_creation(tag, user) + @routes.get('/tags/?') def get_tags(ctx, _params=None): auth.verify_privilege(ctx.user, 'tags:list') return _search_executor.execute_and_serialize( ctx, lambda tag: _serialize(ctx, tag)) + @routes.post('/tags/?') def create_tag(ctx, _params=None): auth.verify_privilege(ctx.user, 'tags:create') @@ -50,12 +55,14 @@ def create_tag(ctx, _params=None): tags.export_to_json() return _serialize(ctx, tag) + @routes.get('/tag/(?P[^/]+)/?') def get_tag(ctx, params): auth.verify_privilege(ctx.user, 'tags:view') tag = tags.get_tag_by_name(params['tag_name']) return _serialize(ctx, tag) + @routes.put('/tag/(?P[^/]+)/?') def update_tag(ctx, params): tag = tags.get_tag_by_name(params['tag_name']) @@ -89,6 +96,7 @@ def update_tag(ctx, params): tags.export_to_json() return _serialize(ctx, tag) + @routes.delete('/tag/(?P[^/]+)/?') def delete_tag(ctx, params): tag = tags.get_tag_by_name(params['tag_name']) @@ -100,6 +108,7 @@ def delete_tag(ctx, params): tags.export_to_json() return {} + @routes.post('/tag-merge/?') def merge_tags(ctx, _params=None): source_tag_name = ctx.get_param_as_string('remove', required=True) or '' @@ -116,6 +125,7 @@ def merge_tags(ctx, _params=None): tags.export_to_json() return _serialize(ctx, target_tag) + @routes.get('/tag-siblings/(?P[^/]+)/?') def get_tag_siblings(ctx, params): auth.verify_privilege(ctx.user, 'tags:view') diff --git a/server/szurubooru/api/tag_category_api.py b/server/szurubooru/api/tag_category_api.py index 9ac9ffa5..fbd25530 100644 --- a/server/szurubooru/api/tag_category_api.py +++ b/server/szurubooru/api/tag_category_api.py @@ -1,10 +1,12 @@ from szurubooru.rest import routes from szurubooru.func import auth, tags, tag_categories, util, snapshots + def _serialize(ctx, category): return tag_categories.serialize_category( category, options=util.get_serialization_options(ctx)) + @routes.get('/tag-categories/?') def get_tag_categories(ctx, _params=None): auth.verify_privilege(ctx.user, 'tag_categories:list') @@ -13,6 +15,7 @@ def get_tag_categories(ctx, _params=None): 'results': [_serialize(ctx, category) for category in categories], } + @routes.post('/tag-categories/?') def create_tag_category(ctx, _params=None): auth.verify_privilege(ctx.user, 'tag_categories:create') @@ -26,12 +29,14 @@ def create_tag_category(ctx, _params=None): tags.export_to_json() return _serialize(ctx, category) + @routes.get('/tag-category/(?P[^/]+)/?') def get_tag_category(ctx, params): auth.verify_privilege(ctx.user, 'tag_categories:view') category = tag_categories.get_category_by_name(params['category_name']) return _serialize(ctx, category) + @routes.put('/tag-category/(?P[^/]+)/?') def update_tag_category(ctx, params): category = tag_categories.get_category_by_name(params['category_name']) @@ -51,6 +56,7 @@ def update_tag_category(ctx, params): tags.export_to_json() return _serialize(ctx, category) + @routes.delete('/tag-category/(?P[^/]+)/?') def delete_tag_category(ctx, params): category = tag_categories.get_category_by_name(params['category_name']) @@ -62,6 +68,7 @@ def delete_tag_category(ctx, params): tags.export_to_json() return {} + @routes.put('/tag-category/(?P[^/]+)/default/?') def set_tag_category_as_default(ctx, params): auth.verify_privilege(ctx.user, 'tag_categories:set_default') diff --git a/server/szurubooru/api/user_api.py b/server/szurubooru/api/user_api.py index 94b82f3f..aa99c42f 100644 --- a/server/szurubooru/api/user_api.py +++ b/server/szurubooru/api/user_api.py @@ -2,8 +2,10 @@ from szurubooru import search from szurubooru.func import auth, users, util from szurubooru.rest import routes + _search_executor = search.Executor(search.configs.UserSearchConfig()) + def _serialize(ctx, user, **kwargs): return users.serialize_user( user, @@ -11,12 +13,14 @@ def _serialize(ctx, user, **kwargs): options=util.get_serialization_options(ctx), **kwargs) + @routes.get('/users/?') def get_users(ctx, _params=None): auth.verify_privilege(ctx.user, 'users:list') return _search_executor.execute_and_serialize( ctx, lambda user: _serialize(ctx, user)) + @routes.post('/users/?') def create_user(ctx, _params=None): auth.verify_privilege(ctx.user, 'users:create') @@ -36,6 +40,7 @@ def create_user(ctx, _params=None): ctx.session.commit() return _serialize(ctx, user, force_show_email=True) + @routes.get('/user/(?P[^/]+)/?') def get_user(ctx, params): user = users.get_user_by_name(params['user_name']) @@ -43,6 +48,7 @@ def get_user(ctx, params): auth.verify_privilege(ctx.user, 'users:view') return _serialize(ctx, user) + @routes.put('/user/(?P[^/]+)/?') def update_user(ctx, params): user = users.get_user_by_name(params['user_name']) @@ -72,6 +78,7 @@ def update_user(ctx, params): ctx.session.commit() return _serialize(ctx, user) + @routes.delete('/user/(?P[^/]+)/?') def delete_user(ctx, params): user = users.get_user_by_name(params['user_name']) diff --git a/server/szurubooru/config.py b/server/szurubooru/config.py index d1723ee9..e5693117 100644 --- a/server/szurubooru/config.py +++ b/server/szurubooru/config.py @@ -1,6 +1,7 @@ import os import yaml + def merge(left, right): for key in right: if key in left: @@ -12,6 +13,7 @@ def merge(left, right): left[key] = right[key] return left + def read_config(): with open('../config.yaml.dist') as handle: ret = yaml.load(handle.read()) @@ -20,4 +22,5 @@ def read_config(): ret = merge(ret, yaml.load(handle.read())) return ret -config = read_config() # pylint: disable=invalid-name + +config = read_config() # pylint: disable=invalid-name diff --git a/server/szurubooru/db/base.py b/server/szurubooru/db/base.py index f7e43148..e61d35a9 100644 --- a/server/szurubooru/db/base.py +++ b/server/szurubooru/db/base.py @@ -1,2 +1,4 @@ from sqlalchemy.ext.declarative import declarative_base -Base = declarative_base() # pylint: disable=invalid-name + + +Base = declarative_base() # pylint: disable=invalid-name diff --git a/server/szurubooru/db/comment.py b/server/szurubooru/db/comment.py index b099aefb..05298282 100644 --- a/server/szurubooru/db/comment.py +++ b/server/szurubooru/db/comment.py @@ -3,13 +3,18 @@ from sqlalchemy.orm import relationship, backref from sqlalchemy.sql.expression import func from szurubooru.db.base import Base + class CommentScore(Base): __tablename__ = 'comment_score' comment_id = Column( 'comment_id', Integer, ForeignKey('comment.id'), primary_key=True) user_id = Column( - 'user_id', Integer, ForeignKey('user.id'), primary_key=True, index=True) + 'user_id', + Integer, + ForeignKey('user.id'), + primary_key=True, + index=True) time = Column('time', DateTime, nullable=False) score = Column('score', Integer, nullable=False) @@ -18,14 +23,14 @@ class CommentScore(Base): 'User', backref=backref('comment_scores', cascade='all, delete-orphan')) + class Comment(Base): __tablename__ = 'comment' comment_id = Column('id', Integer, primary_key=True) post_id = Column( 'post_id', Integer, ForeignKey('post.id'), index=True, nullable=False) - user_id = Column( - 'user_id', Integer, ForeignKey('user.id'), index=True) + user_id = Column('user_id', Integer, ForeignKey('user.id'), index=True) version = Column('version', Integer, default=1, nullable=False) creation_time = Column('creation_time', DateTime, nullable=False) last_edit_time = Column('last_edit_time', DateTime) diff --git a/server/szurubooru/db/post.py b/server/szurubooru/db/post.py index 0b160eb0..5529e49d 100644 --- a/server/szurubooru/db/post.py +++ b/server/szurubooru/db/post.py @@ -1,10 +1,12 @@ +from sqlalchemy.sql.expression import func, select from sqlalchemy import ( Column, Integer, DateTime, Unicode, UnicodeText, PickleType, ForeignKey) -from sqlalchemy.orm import relationship, column_property, object_session, backref -from sqlalchemy.sql.expression import func, select +from sqlalchemy.orm import ( + relationship, column_property, object_session, backref) from szurubooru.db.base import Base from szurubooru.db.comment import Comment + class PostFeature(Base): __tablename__ = 'post_feature' @@ -20,13 +22,22 @@ class PostFeature(Base): 'User', backref=backref('post_features', cascade='all, delete-orphan')) + class PostScore(Base): __tablename__ = 'post_score' post_id = Column( - 'post_id', Integer, ForeignKey('post.id'), primary_key=True, index=True) + 'post_id', + Integer, + ForeignKey('post.id'), + primary_key=True, + index=True) user_id = Column( - 'user_id', Integer, ForeignKey('user.id'), primary_key=True, index=True) + 'user_id', + Integer, + ForeignKey('user.id'), + primary_key=True, + index=True) time = Column('time', DateTime, nullable=False) score = Column('score', Integer, nullable=False) @@ -35,13 +46,22 @@ class PostScore(Base): 'User', backref=backref('post_scores', cascade='all, delete-orphan')) + class PostFavorite(Base): __tablename__ = 'post_favorite' post_id = Column( - 'post_id', Integer, ForeignKey('post.id'), primary_key=True, index=True) + 'post_id', + Integer, + ForeignKey('post.id'), + primary_key=True, + index=True) user_id = Column( - 'user_id', Integer, ForeignKey('user.id'), primary_key=True, index=True) + 'user_id', + Integer, + ForeignKey('user.id'), + primary_key=True, + index=True) time = Column('time', DateTime, nullable=False) post = relationship('Post') @@ -49,6 +69,7 @@ class PostFavorite(Base): 'User', backref=backref('post_favorites', cascade='all, delete-orphan')) + class PostNote(Base): __tablename__ = 'post_note' @@ -60,23 +81,37 @@ class PostNote(Base): post = relationship('Post') + class PostRelation(Base): __tablename__ = 'post_relation' parent_id = Column( - 'parent_id', Integer, ForeignKey('post.id'), primary_key=True, index=True) + 'parent_id', + Integer, + ForeignKey('post.id'), + primary_key=True, + index=True) child_id = Column( - 'child_id', Integer, ForeignKey('post.id'), primary_key=True, index=True) + 'child_id', + Integer, + ForeignKey('post.id'), + primary_key=True, + index=True) def __init__(self, parent_id, child_id): self.parent_id = parent_id self.child_id = child_id + class PostTag(Base): __tablename__ = 'post_tag' post_id = Column( - 'post_id', Integer, ForeignKey('post.id'), primary_key=True, index=True) + 'post_id', + Integer, + ForeignKey('post.id'), + primary_key=True, + index=True) tag_id = Column( 'tag_id', Integer, ForeignKey('tag.id'), primary_key=True, index=True) @@ -84,6 +119,7 @@ class PostTag(Base): self.post_id = post_id self.tag_id = tag_id + class Post(Base): __tablename__ = 'post' @@ -136,8 +172,8 @@ class Post(Base): # dynamic columns tag_count = column_property( - select([func.count(PostTag.tag_id)]) \ - .where(PostTag.post_id == post_id) \ + select([func.count(PostTag.tag_id)]) + .where(PostTag.post_id == post_id) .correlate_except(PostTag)) canvas_area = column_property(canvas_width * canvas_height) @@ -151,53 +187,53 @@ class Post(Base): return featured_post and featured_post.post_id == self.post_id score = column_property( - select([func.coalesce(func.sum(PostScore.score), 0)]) \ - .where(PostScore.post_id == post_id) \ + select([func.coalesce(func.sum(PostScore.score), 0)]) + .where(PostScore.post_id == post_id) .correlate_except(PostScore)) favorite_count = column_property( - select([func.count(PostFavorite.post_id)]) \ - .where(PostFavorite.post_id == post_id) \ + select([func.count(PostFavorite.post_id)]) + .where(PostFavorite.post_id == post_id) .correlate_except(PostFavorite)) last_favorite_time = column_property( - select([func.max(PostFavorite.time)]) \ - .where(PostFavorite.post_id == post_id) \ + select([func.max(PostFavorite.time)]) + .where(PostFavorite.post_id == post_id) .correlate_except(PostFavorite)) feature_count = column_property( - select([func.count(PostFeature.post_id)]) \ - .where(PostFeature.post_id == post_id) \ + select([func.count(PostFeature.post_id)]) + .where(PostFeature.post_id == post_id) .correlate_except(PostFeature)) last_feature_time = column_property( - select([func.max(PostFeature.time)]) \ - .where(PostFeature.post_id == post_id) \ + select([func.max(PostFeature.time)]) + .where(PostFeature.post_id == post_id) .correlate_except(PostFeature)) comment_count = column_property( - select([func.count(Comment.post_id)]) \ - .where(Comment.post_id == post_id) \ + select([func.count(Comment.post_id)]) + .where(Comment.post_id == post_id) .correlate_except(Comment)) last_comment_creation_time = column_property( - select([func.max(Comment.creation_time)]) \ - .where(Comment.post_id == post_id) \ + select([func.max(Comment.creation_time)]) + .where(Comment.post_id == post_id) .correlate_except(Comment)) last_comment_edit_time = column_property( - select([func.max(Comment.last_edit_time)]) \ - .where(Comment.post_id == post_id) \ + select([func.max(Comment.last_edit_time)]) + .where(Comment.post_id == post_id) .correlate_except(Comment)) note_count = column_property( - select([func.count(PostNote.post_id)]) \ - .where(PostNote.post_id == post_id) \ + select([func.count(PostNote.post_id)]) + .where(PostNote.post_id == post_id) .correlate_except(PostNote)) relation_count = column_property( - select([func.count(PostRelation.child_id)]) \ + select([func.count(PostRelation.child_id)]) .where( - (PostRelation.parent_id == post_id) \ - | (PostRelation.child_id == post_id)) \ + (PostRelation.parent_id == post_id) + | (PostRelation.child_id == post_id)) .correlate_except(PostRelation)) diff --git a/server/szurubooru/db/session.py b/server/szurubooru/db/session.py index 36f25f13..b9dea7ca 100644 --- a/server/szurubooru/db/session.py +++ b/server/szurubooru/db/session.py @@ -1,6 +1,7 @@ import sqlalchemy from szurubooru import config + class QueryCounter(object): _query_count = 0 @@ -16,6 +17,7 @@ class QueryCounter(object): def get(): return QueryCounter._query_count + def create_session(): _engine = sqlalchemy.create_engine( '{schema}://{user}:{password}@{host}:{port}/{name}'.format( @@ -30,6 +32,7 @@ def create_session(): _session_maker = sqlalchemy.orm.sessionmaker(bind=_engine) return sqlalchemy.orm.scoped_session(_session_maker) + # pylint: disable=invalid-name session = create_session() reset_query_count = QueryCounter.reset diff --git a/server/szurubooru/db/snapshot.py b/server/szurubooru/db/snapshot.py index 8fd237e8..554e11ca 100644 --- a/server/szurubooru/db/snapshot.py +++ b/server/szurubooru/db/snapshot.py @@ -1,7 +1,9 @@ -from sqlalchemy import Column, Integer, DateTime, Unicode, PickleType, ForeignKey from sqlalchemy.orm import relationship +from sqlalchemy import ( + Column, Integer, DateTime, Unicode, PickleType, ForeignKey) from szurubooru.db.base import Base + class Snapshot(Base): __tablename__ = 'snapshot' @@ -11,7 +13,8 @@ class Snapshot(Base): snapshot_id = Column('id', Integer, primary_key=True) creation_time = Column('creation_time', DateTime, nullable=False) - resource_type = Column('resource_type', Unicode(32), nullable=False, index=True) + resource_type = Column( + 'resource_type', Unicode(32), nullable=False, index=True) resource_id = Column('resource_id', Integer, nullable=False, index=True) resource_repr = Column('resource_repr', Unicode(64), nullable=False) operation = Column('operation', Unicode(16), nullable=False) diff --git a/server/szurubooru/db/tag.py b/server/szurubooru/db/tag.py index a1f50579..faa6e68d 100644 --- a/server/szurubooru/db/tag.py +++ b/server/szurubooru/db/tag.py @@ -1,49 +1,73 @@ -from sqlalchemy import Column, Integer, DateTime, Unicode, UnicodeText, ForeignKey +from sqlalchemy import ( + Column, Integer, DateTime, Unicode, UnicodeText, ForeignKey) from sqlalchemy.orm import relationship, column_property from sqlalchemy.sql.expression import func, select from szurubooru.db.base import Base from szurubooru.db.post import PostTag + class TagSuggestion(Base): __tablename__ = 'tag_suggestion' parent_id = Column( - 'parent_id', Integer, ForeignKey('tag.id'), primary_key=True, index=True) + 'parent_id', + Integer, + ForeignKey('tag.id'), + primary_key=True, index=True) child_id = Column( - 'child_id', Integer, ForeignKey('tag.id'), primary_key=True, index=True) + 'child_id', + Integer, + ForeignKey('tag.id'), + primary_key=True, index=True) def __init__(self, parent_id, child_id): self.parent_id = parent_id self.child_id = child_id + class TagImplication(Base): __tablename__ = 'tag_implication' parent_id = Column( - 'parent_id', Integer, ForeignKey('tag.id'), primary_key=True, index=True) + 'parent_id', + Integer, + ForeignKey('tag.id'), + primary_key=True, + index=True) child_id = Column( - 'child_id', Integer, ForeignKey('tag.id'), primary_key=True, index=True) + 'child_id', + Integer, + ForeignKey('tag.id'), + primary_key=True, + index=True) def __init__(self, parent_id, child_id): self.parent_id = parent_id self.child_id = child_id + class TagName(Base): __tablename__ = 'tag_name' tag_name_id = Column('tag_name_id', Integer, primary_key=True) - tag_id = Column('tag_id', Integer, ForeignKey('tag.id'), nullable=False, index=True) + tag_id = Column( + 'tag_id', Integer, ForeignKey('tag.id'), nullable=False, index=True) name = Column('name', Unicode(64), nullable=False, unique=True) def __init__(self, name): self.name = name + class Tag(Base): __tablename__ = 'tag' tag_id = Column('id', Integer, primary_key=True) category_id = Column( - 'category_id', Integer, ForeignKey('tag_category.id'), nullable=False, index=True) + 'category_id', + Integer, + ForeignKey('tag_category.id'), + nullable=False, + index=True) version = Column('version', Integer, default=1, nullable=False) creation_time = Column('creation_time', DateTime, nullable=False) last_edit_time = Column('last_edit_time', DateTime) @@ -69,25 +93,25 @@ class Tag(Base): lazy='joined') post_count = column_property( - select([func.count(PostTag.post_id)]) \ - .where(PostTag.tag_id == tag_id) \ + select([func.count(PostTag.post_id)]) + .where(PostTag.tag_id == tag_id) .correlate_except(PostTag)) first_name = column_property( - select([TagName.name]) \ - .where(TagName.tag_id == tag_id) \ - .limit(1) \ + select([TagName.name]) + .where(TagName.tag_id == tag_id) + .limit(1) .as_scalar(), deferred=True) suggestion_count = column_property( - select([func.count(TagSuggestion.child_id)]) \ - .where(TagSuggestion.parent_id == tag_id) \ + select([func.count(TagSuggestion.child_id)]) + .where(TagSuggestion.parent_id == tag_id) .as_scalar(), deferred=True) implication_count = column_property( - select([func.count(TagImplication.child_id)]) \ - .where(TagImplication.parent_id == tag_id) \ + select([func.count(TagImplication.child_id)]) + .where(TagImplication.parent_id == tag_id) .as_scalar(), deferred=True) diff --git a/server/szurubooru/db/tag_category.py b/server/szurubooru/db/tag_category.py index d6610bd7..cb1d9328 100644 --- a/server/szurubooru/db/tag_category.py +++ b/server/szurubooru/db/tag_category.py @@ -4,6 +4,7 @@ from sqlalchemy.sql.expression import func, select from szurubooru.db.base import Base from szurubooru.db.tag import Tag + class TagCategory(Base): __tablename__ = 'tag_category' @@ -17,6 +18,6 @@ class TagCategory(Base): self.name = name tag_count = column_property( - select([func.count('Tag.tag_id')]) \ - .where(Tag.category_id == tag_category_id) \ + select([func.count('Tag.tag_id')]) + .where(Tag.category_id == tag_category_id) .correlate_except(table('Tag'))) diff --git a/server/szurubooru/db/user.py b/server/szurubooru/db/user.py index 1439aebe..f6b9ceb2 100644 --- a/server/szurubooru/db/user.py +++ b/server/szurubooru/db/user.py @@ -5,6 +5,7 @@ from szurubooru.db.base import Base from szurubooru.db.post import Post, PostScore, PostFavorite from szurubooru.db.comment import Comment + class User(Base): __tablename__ = 'user' @@ -17,7 +18,7 @@ class User(Base): RANK_POWER = 'power' RANK_MODERATOR = 'moderator' RANK_ADMINISTRATOR = 'administrator' - RANK_NOBODY = 'nobody' # used for privileges: "nobody can be higher than admin" + RANK_NOBODY = 'nobody' # unattainable, used for privileges user_id = Column('id', Integer, primary_key=True) creation_time = Column('creation_time', DateTime, nullable=False) @@ -36,41 +37,41 @@ class User(Base): @property def post_count(self): from szurubooru.db import session - return session \ - .query(func.sum(1)) \ - .filter(Post.user_id == self.user_id) \ - .one()[0] or 0 + return (session + .query(func.sum(1)) + .filter(Post.user_id == self.user_id) + .one()[0] or 0) @property def comment_count(self): from szurubooru.db import session - return session \ - .query(func.sum(1)) \ - .filter(Comment.user_id == self.user_id) \ - .one()[0] or 0 + return (session + .query(func.sum(1)) + .filter(Comment.user_id == self.user_id) + .one()[0] or 0) @property def favorite_post_count(self): from szurubooru.db import session - return session \ - .query(func.sum(1)) \ - .filter(PostFavorite.user_id == self.user_id) \ - .one()[0] or 0 + return (session + .query(func.sum(1)) + .filter(PostFavorite.user_id == self.user_id) + .one()[0] or 0) @property def liked_post_count(self): from szurubooru.db import session - return session \ - .query(func.sum(1)) \ - .filter(PostScore.user_id == self.user_id) \ - .filter(PostScore.score == 1) \ - .one()[0] or 0 + return (session + .query(func.sum(1)) + .filter(PostScore.user_id == self.user_id) + .filter(PostScore.score == 1) + .one()[0] or 0) @property def disliked_post_count(self): from szurubooru.db import session - return session \ - .query(func.sum(1)) \ - .filter(PostScore.user_id == self.user_id) \ - .filter(PostScore.score == -1) \ - .one()[0] or 0 + return (session + .query(func.sum(1)) + .filter(PostScore.user_id == self.user_id) + .filter(PostScore.score == -1) + .one()[0] or 0) diff --git a/server/szurubooru/db/util.py b/server/szurubooru/db/util.py index d9e5d678..88ba6425 100644 --- a/server/szurubooru/db/util.py +++ b/server/szurubooru/db/util.py @@ -1,5 +1,6 @@ from sqlalchemy.inspection import inspect + def get_resource_info(entity): serializers = { 'tag': lambda tag: tag.first_name, @@ -23,6 +24,7 @@ def get_resource_info(entity): return (resource_type, resource_id, resource_repr) + def get_aux_entity(session, get_table_info, entity, user): table, get_column = get_table_info(entity) return session \ diff --git a/server/szurubooru/errors.py b/server/szurubooru/errors.py index 88423828..bd8c6691 100644 --- a/server/szurubooru/errors.py +++ b/server/szurubooru/errors.py @@ -1,11 +1,38 @@ -class ConfigError(RuntimeError): pass -class AuthError(RuntimeError): pass -class IntegrityError(RuntimeError): pass -class ValidationError(RuntimeError): pass -class SearchError(RuntimeError): pass -class NotFoundError(RuntimeError): pass -class ProcessingError(RuntimeError): pass +class ConfigError(RuntimeError): + pass -class MissingRequiredFileError(ValidationError): pass -class MissingRequiredParameterError(ValidationError): pass -class InvalidParameterError(ValidationError): pass + +class AuthError(RuntimeError): + pass + + +class IntegrityError(RuntimeError): + pass + + +class ValidationError(RuntimeError): + pass + + +class SearchError(RuntimeError): + pass + + +class NotFoundError(RuntimeError): + pass + + +class ProcessingError(RuntimeError): + pass + + +class MissingRequiredFileError(ValidationError): + pass + + +class MissingRequiredParameterError(ValidationError): + pass + + +class InvalidParameterError(ValidationError): + pass diff --git a/server/szurubooru/facade.py b/server/szurubooru/facade.py index 50652c45..ad1516c6 100644 --- a/server/szurubooru/facade.py +++ b/server/szurubooru/facade.py @@ -7,30 +7,37 @@ from szurubooru import config, errors, rest # pylint: disable=unused-import from szurubooru import api, middleware + def _on_auth_error(ex): raise rest.errors.HttpForbidden( title='Authentication error', description=str(ex)) + def _on_validation_error(ex): raise rest.errors.HttpBadRequest( title='Validation error', description=str(ex)) + def _on_search_error(ex): raise rest.errors.HttpBadRequest( title='Search error', description=str(ex)) + def _on_integrity_error(ex): raise rest.errors.HttpConflict( title='Integrity violation', description=ex.args[0]) + def _on_not_found_error(ex): raise rest.errors.HttpNotFound( title='Not found', description=str(ex)) + def _on_processing_error(ex): raise rest.errors.HttpBadRequest( title='Processing error', description=str(ex)) + def validate_config(): ''' Check whether config doesn't contain errors that might prove @@ -60,6 +67,7 @@ def validate_config(): raise errors.ConfigError( 'Database is not configured: %r is missing' % key) + def create_app(): ''' Create a WSGI compatible App object. ''' validate_config() diff --git a/server/szurubooru/func/auth.py b/server/szurubooru/func/auth.py index e41c5ee1..d71c8f9d 100644 --- a/server/szurubooru/func/auth.py +++ b/server/szurubooru/func/auth.py @@ -4,6 +4,7 @@ from collections import OrderedDict from szurubooru import config, db, errors from szurubooru.func import util + RANK_MAP = OrderedDict([ (db.User.RANK_ANONYMOUS, 'anonymous'), (db.User.RANK_RESTRICTED, 'restricted'), @@ -14,6 +15,7 @@ RANK_MAP = OrderedDict([ (db.User.RANK_NOBODY, 'nobody'), ]) + def get_password_hash(salt, password): ''' Retrieve new-style password hash. ''' digest = hashlib.sha256() @@ -22,6 +24,7 @@ def get_password_hash(salt, password): digest.update(password.encode('utf8')) return digest.hexdigest() + def get_legacy_password_hash(salt, password): ''' Retrieve old-style password hash. ''' digest = hashlib.sha1() @@ -30,6 +33,7 @@ def get_legacy_password_hash(salt, password): digest.update(password.encode('utf8')) return digest.hexdigest() + def create_password(): alphabet = { 'c': list('bcdfghijklmnpqrstvwxyz'), @@ -39,6 +43,7 @@ def create_password(): pattern = 'cvcvnncvcv' return ''.join(random.choice(alphabet[l]) for l in list(pattern)) + def is_valid_password(user, password): assert user salt, valid_hash = user.password_salt, user.password_hash @@ -48,6 +53,7 @@ def is_valid_password(user, password): ] return valid_hash in possible_hashes + def has_privilege(user, privilege_name): assert user all_ranks = list(RANK_MAP.keys()) @@ -58,11 +64,13 @@ def has_privilege(user, privilege_name): good_ranks = all_ranks[all_ranks.index(minimal_rank):] return user.rank in good_ranks + def verify_privilege(user, privilege_name): assert user if not has_privilege(user, privilege_name): raise errors.AuthError('Insufficient privileges to do this.') + def generate_authentication_token(user): ''' Generate nonguessable challenge (e.g. links in password reminder). ''' assert user diff --git a/server/szurubooru/func/cache.py b/server/szurubooru/func/cache.py index d6f581b4..78fb871d 100644 --- a/server/szurubooru/func/cache.py +++ b/server/szurubooru/func/cache.py @@ -1,11 +1,13 @@ from datetime import datetime + class LruCacheItem(object): def __init__(self, key, value): self.key = key self.value = value self.timestamp = datetime.utcnow() + class LruCache(object): def __init__(self, length, delta=None): self.length = length @@ -15,12 +17,13 @@ class LruCache(object): def insert_item(self, item): if item.key in self.hash: - item_index = next(i \ - for i, v in enumerate(self.item_list) \ + item_index = next( + i + for i, v in enumerate(self.item_list) if v.key == item.key) self.item_list[:] \ = self.item_list[:item_index] \ - + self.item_list[item_index+1:] + + self.item_list[item_index + 1:] self.item_list.insert(0, item) else: if len(self.item_list) > self.length: @@ -36,16 +39,21 @@ class LruCache(object): del self.hash[item.key] del self.item_list[self.item_list.index(item)] + _CACHE = LruCache(length=100) + def purge(): _CACHE.remove_all() + def has(key): return key in _CACHE.hash + def get(key): return _CACHE.hash[key].value + def put(key, value): _CACHE.insert_item(LruCacheItem(key, value)) diff --git a/server/szurubooru/func/comments.py b/server/szurubooru/func/comments.py index ec6efa89..6b7def85 100644 --- a/server/szurubooru/func/comments.py +++ b/server/szurubooru/func/comments.py @@ -2,16 +2,26 @@ import datetime from szurubooru import db, errors from szurubooru.func import users, scores, util -class InvalidCommentIdError(errors.ValidationError): pass -class CommentNotFoundError(errors.NotFoundError): pass -class EmptyCommentTextError(errors.ValidationError): pass + +class InvalidCommentIdError(errors.ValidationError): + pass + + +class CommentNotFoundError(errors.NotFoundError): + pass + + +class EmptyCommentTextError(errors.ValidationError): + pass + def serialize_comment(comment, auth_user, options=None): return util.serialize_entity( comment, { 'id': lambda: comment.comment_id, - 'user': lambda: users.serialize_micro_user(comment.user, auth_user), + 'user': + lambda: users.serialize_micro_user(comment.user, auth_user), 'postId': lambda: comment.post.post_id, 'version': lambda: comment.version, 'text': lambda: comment.text, @@ -22,6 +32,7 @@ def serialize_comment(comment, auth_user, options=None): }, options) + def try_get_comment_by_id(comment_id): try: comment_id = int(comment_id) @@ -32,12 +43,14 @@ def try_get_comment_by_id(comment_id): .filter(db.Comment.comment_id == comment_id) \ .one_or_none() + def get_comment_by_id(comment_id): comment = try_get_comment_by_id(comment_id) if comment: return comment raise CommentNotFoundError('Comment %r not found.' % comment_id) + def create_comment(user, post, text): comment = db.Comment() comment.user = user @@ -46,6 +59,7 @@ def create_comment(user, post, text): comment.creation_time = datetime.datetime.utcnow() return comment + def update_comment_text(comment, text): assert comment if not text: diff --git a/server/szurubooru/func/favorites.py b/server/szurubooru/func/favorites.py index d9255765..e95582fe 100644 --- a/server/szurubooru/func/favorites.py +++ b/server/szurubooru/func/favorites.py @@ -2,7 +2,10 @@ import datetime from szurubooru import db, errors from szurubooru.func import scores -class InvalidFavoriteTargetError(errors.ValidationError): pass + +class InvalidFavoriteTargetError(errors.ValidationError): + pass + def _get_table_info(entity): assert entity @@ -11,16 +14,19 @@ def _get_table_info(entity): return db.PostFavorite, lambda table: table.post_id raise InvalidFavoriteTargetError() + def _get_fav_entity(entity, user): assert entity assert user return db.util.get_aux_entity(db.session, _get_table_info, entity, user) + def has_favorited(entity, user): assert entity assert user return _get_fav_entity(entity, user) is not None + def unset_favorite(entity, user): assert entity assert user @@ -28,6 +34,7 @@ def unset_favorite(entity, user): if fav_entity: db.session.delete(fav_entity) + def set_favorite(entity, user): assert entity assert user diff --git a/server/szurubooru/func/files.py b/server/szurubooru/func/files.py index a010187a..48340450 100644 --- a/server/szurubooru/func/files.py +++ b/server/szurubooru/func/files.py @@ -1,20 +1,25 @@ import os from szurubooru import config + def _get_full_path(path): return os.path.join(config.config['data_dir'], path) + def delete(path): full_path = _get_full_path(path) if os.path.exists(full_path): os.unlink(full_path) + def has(path): return os.path.exists(_get_full_path(path)) + def move(source_path, target_path): return os.rename(_get_full_path(source_path), _get_full_path(target_path)) + def get(path): full_path = _get_full_path(path) if not os.path.exists(full_path): @@ -22,6 +27,7 @@ def get(path): with open(full_path, 'rb') as handle: return handle.read() + def save(path, content): full_path = _get_full_path(path) os.makedirs(os.path.dirname(full_path), exist_ok=True) diff --git a/server/szurubooru/func/images.py b/server/szurubooru/func/images.py index 870334af..f4040b2b 100644 --- a/server/szurubooru/func/images.py +++ b/server/szurubooru/func/images.py @@ -6,11 +6,14 @@ import math from szurubooru import errors from szurubooru.func import mime, util + logger = logging.getLogger(__name__) + _SCALE_FIT_FMT = \ r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)' + class Image(object): def __init__(self, content): self.content = content @@ -38,12 +41,13 @@ class Image(object): '-', ] if 'duration' in self.info['format'] \ - and float(self.info['format']['duration']) > 3 \ and self.info['format']['format_name'] != 'swf': - cli = [ - '-ss', - '%d' % math.floor(float(self.info['format']['duration']) * 0.3), - ] + cli + duration = float(self.info['format']['duration']) + if duration > 3: + cli = [ + '-ss', + '%d' % math.floor(duration * 0.3), + ] + cli self.content = self._execute(cli) assert self.content self._reload_info() diff --git a/server/szurubooru/func/mailer.py b/server/szurubooru/func/mailer.py index 0f9828c1..94f9c506 100644 --- a/server/szurubooru/func/mailer.py +++ b/server/szurubooru/func/mailer.py @@ -2,6 +2,7 @@ import smtplib import email.mime.text from szurubooru import config + def send_mail(sender, recipient, subject, body): msg = email.mime.text.MIMEText(body) msg['Subject'] = subject diff --git a/server/szurubooru/func/mime.py b/server/szurubooru/func/mime.py index 54fe225e..0f712d1a 100644 --- a/server/szurubooru/func/mime.py +++ b/server/szurubooru/func/mime.py @@ -1,6 +1,6 @@ import re -# pylint: disable=too-many-return-statements + def get_mime_type(content): if not content: return 'application/octet-stream' @@ -25,6 +25,7 @@ def get_mime_type(content): return 'application/octet-stream' + def get_extension(mime_type): extension_map = { 'application/x-shockwave-flash': 'swf', @@ -37,15 +38,19 @@ def get_extension(mime_type): } return extension_map.get((mime_type or '').strip().lower(), None) + def is_flash(mime_type): return mime_type.lower() == 'application/x-shockwave-flash' + def is_video(mime_type): return mime_type.lower() in ('application/ogg', 'video/mp4', 'video/webm') + def is_image(mime_type): return mime_type.lower() in ('image/jpeg', 'image/png', 'image/gif') + def is_animated_gif(content): return get_mime_type(content) == 'image/gif' \ and len(re.findall(b'\x21\xF9\x04.{4}\x00[\x2C\x21]', content)) > 1 diff --git a/server/szurubooru/func/net.py b/server/szurubooru/func/net.py index c9d96089..65c01c00 100644 --- a/server/szurubooru/func/net.py +++ b/server/szurubooru/func/net.py @@ -1,5 +1,6 @@ import urllib.request + def download(url): assert url request = urllib.request.Request(url) diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 23e2d28e..027e7c4c 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -4,37 +4,71 @@ from szurubooru import config, db, errors from szurubooru.func import ( users, snapshots, scores, comments, tags, util, mime, images, files) + EMPTY_PIXEL = \ b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \ b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \ b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' -class PostNotFoundError(errors.NotFoundError): pass -class PostAlreadyFeaturedError(errors.ValidationError): pass -class PostAlreadyUploadedError(errors.ValidationError): pass -class InvalidPostIdError(errors.ValidationError): pass -class InvalidPostSafetyError(errors.ValidationError): pass -class InvalidPostSourceError(errors.ValidationError): pass -class InvalidPostContentError(errors.ValidationError): pass -class InvalidPostRelationError(errors.ValidationError): pass -class InvalidPostNoteError(errors.ValidationError): pass -class InvalidPostFlagError(errors.ValidationError): pass + +class PostNotFoundError(errors.NotFoundError): + pass + + +class PostAlreadyFeaturedError(errors.ValidationError): + pass + + +class PostAlreadyUploadedError(errors.ValidationError): + pass + + +class InvalidPostIdError(errors.ValidationError): + pass + + +class InvalidPostSafetyError(errors.ValidationError): + pass + + +class InvalidPostSourceError(errors.ValidationError): + pass + + +class InvalidPostContentError(errors.ValidationError): + pass + + +class InvalidPostRelationError(errors.ValidationError): + pass + + +class InvalidPostNoteError(errors.ValidationError): + pass + + +class InvalidPostFlagError(errors.ValidationError): + pass + SAFETY_MAP = { db.Post.SAFETY_SAFE: 'safe', db.Post.SAFETY_SKETCHY: 'sketchy', db.Post.SAFETY_UNSAFE: 'unsafe', } + TYPE_MAP = { db.Post.TYPE_IMAGE: 'image', db.Post.TYPE_ANIMATION: 'animation', db.Post.TYPE_VIDEO: 'video', db.Post.TYPE_FLASH: 'flash', } + FLAG_MAP = { db.Post.FLAG_LOOP: 'loop', } + def get_post_content_url(post): assert post return '%s/posts/%d.%s' % ( @@ -42,25 +76,30 @@ def get_post_content_url(post): post.post_id, mime.get_extension(post.mime_type) or 'dat') + def get_post_thumbnail_url(post): assert post return '%s/generated-thumbnails/%d.jpg' % ( config.config['data_url'].rstrip('/'), post.post_id) + def get_post_content_path(post): assert post return 'posts/%d.%s' % ( post.post_id, mime.get_extension(post.mime_type) or 'dat') + def get_post_thumbnail_path(post): assert post return 'generated-thumbnails/%d.jpg' % (post.post_id) + def get_post_thumbnail_backup_path(post): assert post return 'posts/custom-thumbnails/%d.dat' % (post.post_id) + def serialize_note(note): assert note return { @@ -68,6 +107,7 @@ def serialize_note(note): 'text': note.text, } + def serialize_post(post, auth_user, options=None): return util.serialize_entity( post, @@ -93,17 +133,17 @@ def serialize_post(post, auth_user, options=None): { post['id']: post for post in [ - serialize_micro_post(rel, auth_user) \ - for rel in post.relations - ] + serialize_micro_post(rel, auth_user) + for rel in post.relations] }.values(), key=lambda post: post['id']), 'user': lambda: users.serialize_micro_user(post.user, auth_user), 'score': lambda: post.score, 'ownScore': lambda: scores.get_score(post, auth_user), - 'ownFavorite': lambda: len( - [user for user in post.favorited_by \ - if user.user_id == auth_user.user_id]) > 0, + 'ownFavorite': lambda: len([ + user for user in post.favorited_by + if user.user_id == auth_user.user_id] + ) > 0, 'tagCount': lambda: post.tag_count, 'favoriteCount': lambda: post.favorite_count, 'commentCount': lambda: post.comment_count, @@ -112,31 +152,35 @@ def serialize_post(post, auth_user, options=None): 'featureCount': lambda: post.feature_count, 'lastFeatureTime': lambda: post.last_feature_time, 'favoritedBy': lambda: [ - users.serialize_micro_user(rel.user, auth_user) \ - for rel in post.favorited_by], + users.serialize_micro_user(rel.user, auth_user) + for rel in post.favorited_by + ], 'hasCustomThumbnail': lambda: files.has(get_post_thumbnail_backup_path(post)), 'notes': lambda: sorted( [serialize_note(note) for note in post.notes], key=lambda x: x['polygon']), 'comments': lambda: [ - comments.serialize_comment(comment, auth_user) \ - for comment in sorted( - post.comments, - key=lambda comment: comment.creation_time)], + comments.serialize_comment(comment, auth_user) + for comment in sorted( + post.comments, + key=lambda comment: comment.creation_time)], 'snapshots': lambda: snapshots.get_serialized_history(post), }, options) + def serialize_micro_post(post, auth_user): return serialize_post( post, auth_user=auth_user, options=['id', 'thumbnailUrl']) + def get_post_count(): return db.session.query(sqlalchemy.func.count(db.Post.post_id)).one()[0] + def try_get_post_by_id(post_id): try: post_id = int(post_id) @@ -147,22 +191,26 @@ def try_get_post_by_id(post_id): .filter(db.Post.post_id == post_id) \ .one_or_none() + def get_post_by_id(post_id): post = try_get_post_by_id(post_id) if not post: raise PostNotFoundError('Post %r not found.' % post_id) return post + def try_get_current_post_feature(): return db.session \ .query(db.PostFeature) \ .order_by(db.PostFeature.time.desc()) \ .first() + def try_get_featured_post(): post_feature = try_get_current_post_feature() return post_feature.post if post_feature else None + def create_post(content, tag_names, user): post = db.Post() post.safety = db.Post.SAFETY_SAFE @@ -181,6 +229,7 @@ def create_post(content, tag_names, user): new_tags = update_post_tags(post, tag_names) return (post, new_tags) + def update_post_safety(post, safety): assert post safety = util.flip(SAFETY_MAP).get(safety, None) @@ -189,12 +238,14 @@ def update_post_safety(post, safety): 'Safety can be either of %r.' % list(SAFETY_MAP.values())) post.safety = safety + def update_post_source(post, source): assert post if util.value_exceeds_column_size(source, db.Post.source): raise InvalidPostSourceError('Source is too long.') post.source = source + def update_post_content(post, content): assert post if not content: @@ -210,7 +261,8 @@ def update_post_content(post, content): elif mime.is_video(post.mime_type): post.type = db.Post.TYPE_VIDEO else: - raise InvalidPostContentError('Unhandled file type: %r' % post.mime_type) + raise InvalidPostContentError( + 'Unhandled file type: %r' % post.mime_type) post.checksum = util.get_md5(content) other_post = db.session \ @@ -236,6 +288,7 @@ def update_post_content(post, content): files.save(get_post_content_path(post), content) update_post_thumbnail(post, content=None, do_delete=False) + def update_post_thumbnail(post, content=None, do_delete=True): assert post if not content: @@ -246,6 +299,7 @@ def update_post_thumbnail(post, content=None, do_delete=True): files.save(get_post_thumbnail_backup_path(post), content) generate_post_thumbnail(post) + def generate_post_thumbnail(post): assert post if files.has(get_post_thumbnail_backup_path(post)): @@ -261,12 +315,14 @@ def generate_post_thumbnail(post): except errors.ProcessingError: files.save(get_post_thumbnail_path(post), EMPTY_PIXEL) + def update_post_tags(post, tag_names): assert post existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names) post.tags = existing_tags + new_tags return new_tags + def update_post_relations(post, new_post_ids): assert post old_posts = post.relations @@ -287,6 +343,7 @@ def update_post_relations(post, new_post_ids): post.relations.append(relation) relation.relations.append(post) + def update_post_notes(post, notes): assert post post.notes = [] @@ -323,6 +380,7 @@ def update_post_notes(post, notes): post.notes.append( db.PostNote(polygon=note['polygon'], text=str(note['text']))) + def update_post_flags(post, flags): assert post target_flags = [] @@ -334,6 +392,7 @@ def update_post_flags(post, flags): target_flags.append(flag) post.flags = target_flags + def feature_post(post, user): assert post post_feature = db.PostFeature() @@ -342,6 +401,7 @@ def feature_post(post, user): post_feature.user = user db.session.add(post_feature) + def delete(post): assert post db.session.delete(post) diff --git a/server/szurubooru/func/scores.py b/server/szurubooru/func/scores.py index 75c623ec..ed9c84db 100644 --- a/server/szurubooru/func/scores.py +++ b/server/szurubooru/func/scores.py @@ -2,8 +2,14 @@ import datetime from szurubooru import db, errors from szurubooru.func import favorites -class InvalidScoreTargetError(errors.ValidationError): pass -class InvalidScoreValueError(errors.ValidationError): pass + +class InvalidScoreTargetError(errors.ValidationError): + pass + + +class InvalidScoreValueError(errors.ValidationError): + pass + def _get_table_info(entity): assert entity @@ -14,10 +20,12 @@ def _get_table_info(entity): return db.CommentScore, lambda table: table.comment_id raise InvalidScoreTargetError() + def _get_score_entity(entity, user): assert user return db.util.get_aux_entity(db.session, _get_table_info, entity, user) + def delete_score(entity, user): assert entity assert user @@ -25,6 +33,7 @@ def delete_score(entity, user): if score_entity: db.session.delete(score_entity) + def get_score(entity, user): assert entity assert user @@ -36,6 +45,7 @@ def get_score(entity, user): .one_or_none() return row[0] if row else 0 + def set_score(entity, user, score): assert entity assert user diff --git a/server/szurubooru/func/snapshots.py b/server/szurubooru/func/snapshots.py index 29da205b..a7543166 100644 --- a/server/szurubooru/func/snapshots.py +++ b/server/szurubooru/func/snapshots.py @@ -1,6 +1,7 @@ import datetime from szurubooru import db + def get_tag_snapshot(tag): return { 'names': [tag_name.name for tag_name in tag.names], @@ -9,6 +10,7 @@ def get_tag_snapshot(tag): 'implications': sorted(rel.first_name for rel in tag.implications), } + def get_post_snapshot(post): return { 'source': post.source, @@ -25,6 +27,7 @@ def get_post_snapshot(post): 'featured': post.is_featured, } + def get_tag_category_snapshot(category): return { 'name': category.name, @@ -32,6 +35,7 @@ def get_tag_category_snapshot(category): 'default': True if category.default else False, } + def get_previous_snapshot(snapshot): assert snapshot return db.session \ @@ -43,6 +47,7 @@ def get_previous_snapshot(snapshot): .limit(1) \ .first() + def get_snapshots(entity): assert entity resource_type, resource_id, _ = db.util.get_resource_info(entity) @@ -53,6 +58,7 @@ def get_snapshots(entity): .order_by(db.Snapshot.creation_time.desc()) \ .all() + def serialize_snapshot(snapshot, earlier_snapshot=()): assert snapshot if earlier_snapshot is (): @@ -67,6 +73,7 @@ def serialize_snapshot(snapshot, earlier_snapshot=()): 'time': snapshot.creation_time, } + def get_serialized_history(entity): if not entity: return [] @@ -77,6 +84,7 @@ def get_serialized_history(entity): earlier_snapshot = snapshot return ret + def _save(operation, entity, auth_user): assert operation assert entity @@ -86,7 +94,8 @@ def _save(operation, entity, auth_user): 'post': get_post_snapshot, } - resource_type, resource_id, resource_repr = db.util.get_resource_info(entity) + resource_type, resource_id, resource_repr = ( + db.util.get_resource_info(entity)) now = datetime.datetime.utcnow() snapshot = db.Snapshot() @@ -118,14 +127,17 @@ def _save(operation, entity, auth_user): else: db.session.add(snapshot) + def save_entity_creation(entity, auth_user): assert entity _save(db.Snapshot.OPERATION_CREATED, entity, auth_user) + def save_entity_modification(entity, auth_user): assert entity _save(db.Snapshot.OPERATION_MODIFIED, entity, auth_user) + def save_entity_deletion(entity, auth_user): assert entity _save(db.Snapshot.OPERATION_DELETED, entity, auth_user) diff --git a/server/szurubooru/func/tag_categories.py b/server/szurubooru/func/tag_categories.py index 65391c99..f476f332 100644 --- a/server/szurubooru/func/tag_categories.py +++ b/server/szurubooru/func/tag_categories.py @@ -3,11 +3,26 @@ import sqlalchemy from szurubooru import config, db, errors from szurubooru.func import util, snapshots, cache -class TagCategoryNotFoundError(errors.NotFoundError): pass -class TagCategoryAlreadyExistsError(errors.ValidationError): pass -class TagCategoryIsInUseError(errors.ValidationError): pass -class InvalidTagCategoryNameError(errors.ValidationError): pass -class InvalidTagCategoryColorError(errors.ValidationError): pass + +class TagCategoryNotFoundError(errors.NotFoundError): + pass + + +class TagCategoryAlreadyExistsError(errors.ValidationError): + pass + + +class TagCategoryIsInUseError(errors.ValidationError): + pass + + +class InvalidTagCategoryNameError(errors.ValidationError): + pass + + +class InvalidTagCategoryColorError(errors.ValidationError): + pass + def _verify_name_validity(name): name_regex = config.config['tag_category_name_regex'] @@ -15,6 +30,7 @@ def _verify_name_validity(name): raise InvalidTagCategoryNameError( 'Name must satisfy regex %r.' % name_regex) + def serialize_category(category, options=None): return util.serialize_entity( category, @@ -28,6 +44,7 @@ def serialize_category(category, options=None): }, options) + def create_category(name, color): category = db.TagCategory() update_category_name(category, name) @@ -36,13 +53,15 @@ def create_category(name, color): category.default = True return category + def update_category_name(category, name): assert category if not name: raise InvalidTagCategoryNameError('Name cannot be empty.') expr = sqlalchemy.func.lower(db.TagCategory.name) == name.lower() if category.tag_category_id: - expr = expr & (db.TagCategory.tag_category_id != category.tag_category_id) + expr = expr & ( + db.TagCategory.tag_category_id != category.tag_category_id) already_exists = db.session.query(db.TagCategory).filter(expr).count() > 0 if already_exists: raise TagCategoryAlreadyExistsError( @@ -52,6 +71,7 @@ def update_category_name(category, name): _verify_name_validity(name) category.name = name + def update_category_color(category, color): assert category if not color: @@ -62,24 +82,29 @@ def update_category_color(category, color): raise InvalidTagCategoryColorError('Color is too long.') category.color = color + def try_get_category_by_name(name): return db.session \ .query(db.TagCategory) \ .filter(sqlalchemy.func.lower(db.TagCategory.name) == name.lower()) \ .one_or_none() + def get_category_by_name(name): category = try_get_category_by_name(name) if not category: raise TagCategoryNotFoundError('Tag category %r not found.' % name) return category + def get_all_category_names(): return [row[0] for row in db.session.query(db.TagCategory.name).all()] + def get_all_categories(): return db.session.query(db.TagCategory).all() + def try_get_default_category(): key = 'default-tag-category' if cache.has(key): @@ -98,12 +123,14 @@ def try_get_default_category(): cache.put(key, category) return category + def get_default_category(): category = try_get_default_category() if not category: raise TagCategoryNotFoundError('No tag category created yet.') return category + def set_default_category(category): assert category old_category = try_get_default_category() @@ -111,6 +138,7 @@ def set_default_category(category): old_category.default = False category.default = True + def delete_category(category): assert category if len(get_all_category_names()) == 1: diff --git a/server/szurubooru/func/tags.py b/server/szurubooru/func/tags.py index 433a7d9d..19f338a7 100644 --- a/server/szurubooru/func/tags.py +++ b/server/szurubooru/func/tags.py @@ -6,32 +6,57 @@ import sqlalchemy from szurubooru import config, db, errors from szurubooru.func import util, tag_categories, snapshots -class TagNotFoundError(errors.NotFoundError): pass -class TagAlreadyExistsError(errors.ValidationError): pass -class TagIsInUseError(errors.ValidationError): pass -class InvalidTagNameError(errors.ValidationError): pass -class InvalidTagRelationError(errors.ValidationError): pass -class InvalidTagCategoryError(errors.ValidationError): pass -class InvalidTagDescriptionError(errors.ValidationError): pass + +class TagNotFoundError(errors.NotFoundError): + pass + + +class TagAlreadyExistsError(errors.ValidationError): + pass + + +class TagIsInUseError(errors.ValidationError): + pass + + +class InvalidTagNameError(errors.ValidationError): + pass + + +class InvalidTagRelationError(errors.ValidationError): + pass + + +class InvalidTagCategoryError(errors.ValidationError): + pass + + +class InvalidTagDescriptionError(errors.ValidationError): + pass + def _verify_name_validity(name): name_regex = config.config['tag_name_regex'] if not re.match(name_regex, name): raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex) -def _get_plain_names(tag): + +def _get_names(tag): assert tag return [tag_name.name for tag_name in tag.names] + def _lower_list(names): return [name.lower() for name in names] -def _check_name_intersection(names1, names2): - return len(set(_lower_list(names1)).intersection(_lower_list(names2))) > 0 -def _check_name_intersection_case_sensitive(names1, names2): +def _check_name_intersection(names1, names2, case_sensitive): + if not case_sensitive: + names1 = _lower_list(names1) + names2 = _lower_list(names2) return len(set(names1).intersection(names2)) > 0 + def sort_tags(tags): default_category = tag_categories.try_get_default_category() default_category_name = default_category.name if default_category else None @@ -43,6 +68,7 @@ def sort_tags(tags): tag.names[0].name) ) + def serialize_tag(tag, options=None): return util.serialize_entity( tag, @@ -55,15 +81,16 @@ def serialize_tag(tag, options=None): 'lastEditTime': lambda: tag.last_edit_time, 'usages': lambda: tag.post_count, 'suggestions': lambda: [ - relation.names[0].name \ - for relation in sort_tags(tag.suggestions)], + relation.names[0].name + for relation in sort_tags(tag.suggestions)], 'implications': lambda: [ - relation.names[0].name \ - for relation in sort_tags(tag.implications)], + relation.names[0].name + for relation in sort_tags(tag.implications)], 'snapshots': lambda: snapshots.get_serialized_history(tag), }, options) + def export_to_json(): tags = {} categories = {} @@ -82,19 +109,19 @@ def export_to_json(): tags[result[0]] = {'names': []} tags[result[0]]['names'].append(result[1]) - for result in db.session \ - .query(db.TagSuggestion.parent_id, db.TagName.name) \ - .join(db.TagName, db.TagName.tag_id == db.TagSuggestion.child_id) \ - .all(): - if not 'suggestions' in tags[result[0]]: + for result in (db.session + .query(db.TagSuggestion.parent_id, db.TagName.name) + .join(db.TagName, db.TagName.tag_id == db.TagSuggestion.child_id) + .all()): + if 'suggestions' not in tags[result[0]]: tags[result[0]]['suggestions'] = [] tags[result[0]]['suggestions'].append(result[1]) - for result in db.session \ - .query(db.TagImplication.parent_id, db.TagName.name) \ - .join(db.TagName, db.TagName.tag_id == db.TagImplication.child_id) \ - .all(): - if not 'implications' in tags[result[0]]: + for result in (db.session + .query(db.TagImplication.parent_id, db.TagName.name) + .join(db.TagName, db.TagName.tag_id == db.TagImplication.child_id) + .all()): + if 'implications' not in tags[result[0]]: tags[result[0]]['implications'] = [] tags[result[0]]['implications'].append(result[1]) @@ -114,12 +141,14 @@ def export_to_json(): with open(export_path, 'w') as handle: handle.write(json.dumps(output, separators=(',', ':'))) + def try_get_tag_by_name(name): - return db.session \ - .query(db.Tag) \ - .join(db.TagName) \ - .filter(sqlalchemy.func.lower(db.TagName.name) == name.lower()) \ - .one_or_none() + return (db.session + .query(db.Tag) + .join(db.TagName) + .filter(sqlalchemy.func.lower(db.TagName.name) == name.lower()) + .one_or_none()) + def get_tag_by_name(name): tag = try_get_tag_by_name(name) @@ -127,6 +156,7 @@ def get_tag_by_name(name): raise TagNotFoundError('Tag %r not found.' % name) return tag + def get_tags_by_names(names): names = util.icase_unique(names) if len(names) == 0: @@ -136,6 +166,7 @@ def get_tags_by_names(names): expr = expr | (sqlalchemy.func.lower(db.TagName.name) == name.lower()) return db.session.query(db.Tag).join(db.TagName).filter(expr).all() + def get_or_create_tags_by_names(names): names = util.icase_unique(names) existing_tags = get_tags_by_names(names) @@ -144,7 +175,8 @@ def get_or_create_tags_by_names(names): for name in names: found = False for existing_tag in existing_tags: - if _check_name_intersection(_get_plain_names(existing_tag), [name]): + if _check_name_intersection( + _get_names(existing_tag), [name], False): found = True break if not found: @@ -157,32 +189,35 @@ def get_or_create_tags_by_names(names): new_tags.append(new_tag) return existing_tags, new_tags + def get_tag_siblings(tag): assert tag tag_alias = sqlalchemy.orm.aliased(db.Tag) pt_alias1 = sqlalchemy.orm.aliased(db.PostTag) pt_alias2 = sqlalchemy.orm.aliased(db.PostTag) - result = db.session \ - .query(tag_alias, sqlalchemy.func.count(pt_alias2.post_id)) \ - .join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id) \ - .join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id) \ - .filter(pt_alias2.tag_id == tag.tag_id) \ - .filter(pt_alias1.tag_id != tag.tag_id) \ - .group_by(tag_alias.tag_id) \ - .order_by(sqlalchemy.func.count(pt_alias2.post_id).desc()) \ - .limit(50) + result = (db.session + .query(tag_alias, sqlalchemy.func.count(pt_alias2.post_id)) + .join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id) + .join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id) + .filter(pt_alias2.tag_id == tag.tag_id) + .filter(pt_alias1.tag_id != tag.tag_id) + .group_by(tag_alias.tag_id) + .order_by(sqlalchemy.func.count(pt_alias2.post_id).desc()) + .limit(50)) return result + def delete(source_tag): assert source_tag db.session.execute( - sqlalchemy.sql.expression.delete(db.TagSuggestion) \ + sqlalchemy.sql.expression.delete(db.TagSuggestion) .where(db.TagSuggestion.child_id == source_tag.tag_id)) db.session.execute( - sqlalchemy.sql.expression.delete(db.TagImplication) \ + sqlalchemy.sql.expression.delete(db.TagImplication) .where(db.TagImplication.child_id == source_tag.tag_id)) db.session.delete(source_tag) + def merge_tags(source_tag, target_tag): assert source_tag assert target_tag @@ -191,15 +226,16 @@ def merge_tags(source_tag, target_tag): pt1 = db.PostTag pt2 = sqlalchemy.orm.util.aliased(db.PostTag) - update_stmt = sqlalchemy.sql.expression.update(pt1) \ - .where(db.PostTag.tag_id == source_tag.tag_id) \ - .where(~sqlalchemy.exists() \ - .where(pt2.post_id == pt1.post_id) \ - .where(pt2.tag_id == target_tag.tag_id)) \ - .values(tag_id=target_tag.tag_id) + update_stmt = (sqlalchemy.sql.expression.update(pt1) + .where(db.PostTag.tag_id == source_tag.tag_id) + .where(~sqlalchemy.exists() + .where(pt2.post_id == pt1.post_id) + .where(pt2.tag_id == target_tag.tag_id)) + .values(tag_id=target_tag.tag_id)) db.session.execute(update_stmt) delete(source_tag) + def create_tag(names, category_name, suggestions, implications): tag = db.Tag() tag.creation_time = datetime.datetime.utcnow() @@ -209,10 +245,12 @@ def create_tag(names, category_name, suggestions, implications): update_tag_implications(tag, implications) return tag + def update_tag_category_name(tag, category_name): assert tag tag.category = tag_categories.get_category_by_name(category_name) + def update_tag_names(tag, names): assert tag names = util.icase_unique([name for name in names if name]) @@ -232,26 +270,29 @@ def update_tag_names(tag, names): raise TagAlreadyExistsError( 'One of names is already used by another tag.') for tag_name in tag.names[:]: - if not _check_name_intersection_case_sensitive([tag_name.name], names): + if not _check_name_intersection([tag_name.name], names, True): tag.names.remove(tag_name) for name in names: - if not _check_name_intersection_case_sensitive(_get_plain_names(tag), [name]): + if not _check_name_intersection(_get_names(tag), [name], True): tag.names.append(db.TagName(name)) + # TODO: what to do with relations that do not yet exist? def update_tag_implications(tag, relations): assert tag - if _check_name_intersection(_get_plain_names(tag), relations): + if _check_name_intersection(_get_names(tag), relations, False): raise InvalidTagRelationError('Tag cannot imply itself.') tag.implications = get_tags_by_names(relations) + # TODO: what to do with relations that do not yet exist? def update_tag_suggestions(tag, relations): assert tag - if _check_name_intersection(_get_plain_names(tag), relations): + if _check_name_intersection(_get_names(tag), relations, False): raise InvalidTagRelationError('Tag cannot suggest itself.') tag.suggestions = get_tags_by_names(relations) + def update_tag_description(tag, description): assert tag if util.value_exceeds_column_size(description, db.Tag.description): diff --git a/server/szurubooru/func/users.py b/server/szurubooru/func/users.py index 969824cf..c644d4a2 100644 --- a/server/szurubooru/func/users.py +++ b/server/szurubooru/func/users.py @@ -4,17 +4,39 @@ from sqlalchemy import func from szurubooru import config, db, errors from szurubooru.func import auth, util, files, images -class UserNotFoundError(errors.NotFoundError): pass -class UserAlreadyExistsError(errors.ValidationError): pass -class InvalidUserNameError(errors.ValidationError): pass -class InvalidEmailError(errors.ValidationError): pass -class InvalidPasswordError(errors.ValidationError): pass -class InvalidRankError(errors.ValidationError): pass -class InvalidAvatarError(errors.ValidationError): pass + +class UserNotFoundError(errors.NotFoundError): + pass + + +class UserAlreadyExistsError(errors.ValidationError): + pass + + +class InvalidUserNameError(errors.ValidationError): + pass + + +class InvalidEmailError(errors.ValidationError): + pass + + +class InvalidPasswordError(errors.ValidationError): + pass + + +class InvalidRankError(errors.ValidationError): + pass + + +class InvalidAvatarError(errors.ValidationError): + pass + def get_avatar_path(user_name): return 'avatars/' + user_name.lower() + '.png' + def get_avatar_url(user): assert user if user.avatar_style == user.AVATAR_GRAVATAR: @@ -27,6 +49,7 @@ def get_avatar_url(user): return '%s/avatars/%s.png' % ( config.config['data_url'].rstrip('/'), user.name.lower()) + def get_email(user, auth_user, force_show_email): assert user assert auth_user @@ -36,6 +59,7 @@ def get_email(user, auth_user, force_show_email): return False return user.email + def get_liked_post_count(user, auth_user): assert user assert auth_user @@ -43,6 +67,7 @@ def get_liked_post_count(user, auth_user): return False return user.liked_post_count + def get_disliked_post_count(user, auth_user): assert user assert auth_user @@ -50,6 +75,7 @@ def get_disliked_post_count(user, auth_user): return False return user.disliked_post_count + def serialize_user(user, auth_user, options=None, force_show_email=False): return util.serialize_entity( user, @@ -73,34 +99,40 @@ def serialize_user(user, auth_user, options=None, force_show_email=False): }, options) + def serialize_micro_user(user, auth_user): return serialize_user( user, auth_user=auth_user, options=['name', 'avatarUrl']) + def get_user_count(): return db.session.query(db.User).count() + def try_get_user_by_name(name): return db.session \ .query(db.User) \ .filter(func.lower(db.User.name) == func.lower(name)) \ .one_or_none() + def get_user_by_name(name): user = try_get_user_by_name(name) if not user: raise UserNotFoundError('User %r not found.' % name) return user + def try_get_user_by_name_or_email(name_or_email): - return db.session \ - .query(db.User) \ + return (db.session + .query(db.User) .filter( - (func.lower(db.User.name) == func.lower(name_or_email)) - | (func.lower(db.User.email) == func.lower(name_or_email))) \ - .one_or_none() + (func.lower(db.User.name) == func.lower(name_or_email)) | + (func.lower(db.User.email) == func.lower(name_or_email))) + .one_or_none()) + def get_user_by_name_or_email(name_or_email): user = try_get_user_by_name_or_email(name_or_email) @@ -108,6 +140,7 @@ def get_user_by_name_or_email(name_or_email): raise UserNotFoundError('User %r not found.' % name_or_email) return user + def create_user(name, password, email): user = db.User() update_user_name(user, name) @@ -121,6 +154,7 @@ def create_user(name, password, email): user.avatar_style = db.User.AVATAR_GRAVATAR return user + def update_user_name(user, name): assert user if not name: @@ -139,6 +173,7 @@ def update_user_name(user, name): files.move(get_avatar_path(user.name), get_avatar_path(name)) user.name = name + def update_user_password(user, password): assert user if not password: @@ -150,6 +185,7 @@ def update_user_password(user, password): user.password_salt = auth.create_password() user.password_hash = auth.get_password_hash(user.password_salt, password) + def update_user_email(user, email): assert user if email: @@ -162,6 +198,7 @@ def update_user_email(user, email): raise InvalidEmailError('E-mail is invalid.') user.email = email + def update_user_rank(user, rank, auth_user): assert user if not rank: @@ -178,6 +215,7 @@ def update_user_rank(user, rank, auth_user): raise errors.AuthError('Trying to set higher rank than your own.') user.rank = rank + def update_user_avatar(user, avatar_style, avatar_content=None): assert user if avatar_style == 'gravatar': @@ -199,10 +237,12 @@ def update_user_avatar(user, avatar_style, avatar_content=None): 'Avatar style %r is invalid. Valid avatar styles: %r.' % ( avatar_style, ['gravatar', 'manual'])) + def bump_user_login_time(user): assert user user.last_login_time = datetime.datetime.utcnow() + def reset_user_password(user): assert user password = auth.create_password() diff --git a/server/szurubooru/func/util.py b/server/szurubooru/func/util.py index a4bd9177..8a1e6c97 100644 --- a/server/szurubooru/func/util.py +++ b/server/szurubooru/func/util.py @@ -1,18 +1,22 @@ import os -import datetime import hashlib import re import tempfile +from datetime import datetime, timedelta from contextlib import contextmanager from szurubooru import errors + def snake_case_to_lower_camel_case(text): components = text.split('_') return components[0].lower() + \ ''.join(word[0].upper() + word[1:].lower() for word in components[1:]) + def snake_case_to_upper_train_case(text): - return '-'.join(word[0].upper() + word[1:].lower() for word in text.split('_')) + return '-'.join( + word[0].upper() + word[1:].lower() for word in text.split('_')) + def snake_case_to_lower_camel_case_keys(source): target = {} @@ -20,9 +24,11 @@ def snake_case_to_lower_camel_case_keys(source): target[snake_case_to_lower_camel_case(key)] = value return target + def get_serialization_options(ctx): return ctx.get_param_as_list('fields', required=False, default=None) + def serialize_entity(entity, field_factories, options): if not entity: return None @@ -30,13 +36,14 @@ def serialize_entity(entity, field_factories, options): options = field_factories.keys() ret = {} for key in options: - if not key in field_factories: + if key not in field_factories: raise errors.ValidationError('Invalid key: %r. Valid keys: %r.' % ( key, list(sorted(field_factories.keys())))) factory = field_factories[key] ret[key] = factory() return ret + @contextmanager def create_temp_file(**kwargs): (handle, path) = tempfile.mkstemp(**kwargs) @@ -47,6 +54,7 @@ def create_temp_file(**kwargs): finally: os.remove(path) + def unalias_dict(input_dict): output_dict = {} for key_list, value in input_dict.items(): @@ -56,6 +64,7 @@ def unalias_dict(input_dict): output_dict[key] = value return output_dict + def get_md5(source): if not isinstance(source, bytes): source = source.encode('utf-8') @@ -63,57 +72,58 @@ def get_md5(source): md5.update(source) return md5.hexdigest() + def flip(source): return {v: k for k, v in source.items()} + def is_valid_email(email): ''' Return whether given email address is valid or empty. ''' return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email) -class dotdict(dict): # pylint: disable=invalid-name + +class dotdict(dict): # pylint: disable=invalid-name ''' dot.notation access to dictionary attributes. ''' def __getattr__(self, attr): return self.get(attr) __setattr__ = dict.__setitem__ __delattr__ = dict.__delitem__ + def parse_time_range(value): ''' Return tuple containing min/max time for given text representation. ''' - one_day = datetime.timedelta(days=1) - one_second = datetime.timedelta(seconds=1) + one_day = timedelta(days=1) + one_second = timedelta(seconds=1) value = value.lower() if not value: raise errors.ValidationError('Empty date format.') if value == 'today': - now = datetime.datetime.utcnow() + now = datetime.utcnow() return ( - datetime.datetime(now.year, now.month, now.day, 0, 0, 0), - datetime.datetime(now.year, now.month, now.day, 0, 0, 0) \ + datetime(now.year, now.month, now.day, 0, 0, 0), + datetime(now.year, now.month, now.day, 0, 0, 0) + one_day - one_second) if value == 'yesterday': - now = datetime.datetime.utcnow() + now = datetime.utcnow() return ( - datetime.datetime(now.year, now.month, now.day, 0, 0, 0) - one_day, - datetime.datetime(now.year, now.month, now.day, 0, 0, 0) \ - - one_second) + datetime(now.year, now.month, now.day, 0, 0, 0) - one_day, + datetime(now.year, now.month, now.day, 0, 0, 0) - one_second) match = re.match(r'^(\d{4})$', value) if match: year = int(match.group(1)) - return ( - datetime.datetime(year, 1, 1), - datetime.datetime(year + 1, 1, 1) - one_second) + return (datetime(year, 1, 1), datetime(year + 1, 1, 1) - one_second) match = re.match(r'^(\d{4})-(\d{1,2})$', value) if match: year = int(match.group(1)) month = int(match.group(2)) return ( - datetime.datetime(year, month, 1), - datetime.datetime(year, month + 1, 1) - one_second) + datetime(year, month, 1), + datetime(year, month + 1, 1) - one_second) match = re.match(r'^(\d{4})-(\d{1,2})-(\d{1,2})$', value) if match: @@ -121,11 +131,12 @@ def parse_time_range(value): month = int(match.group(2)) day = int(match.group(3)) return ( - datetime.datetime(year, month, day), - datetime.datetime(year, month, day + 1) - one_second) + datetime(year, month, day), + datetime(year, month, day + 1) - one_second) raise errors.ValidationError('Invalid date format: %r.' % value) + def icase_unique(source): target = [] target_low = [] @@ -135,6 +146,7 @@ def icase_unique(source): target_low.append(source_item.lower()) return target + def value_exceeds_column_size(value, column): if not value: return False @@ -143,6 +155,7 @@ def value_exceeds_column_size(value, column): return False return len(value) > max_length + def verify_version(entity, context, field_name='version'): actual_version = context.get_param_as_int(field_name, required=True) expected_version = entity.version @@ -151,5 +164,6 @@ def verify_version(entity, context, field_name='version'): 'Someone else modified this in the meantime. ' + 'Please try again.') + def bump_version(entity): entity.version += 1 diff --git a/server/szurubooru/middleware/authenticator.py b/server/szurubooru/middleware/authenticator.py index b483494d..b78b90ea 100644 --- a/server/szurubooru/middleware/authenticator.py +++ b/server/szurubooru/middleware/authenticator.py @@ -4,6 +4,7 @@ from szurubooru.func import auth, users from szurubooru.rest import middleware from szurubooru.rest.errors import HttpBadRequest + def _authenticate(username, password): ''' Try to authenticate user. Throw AuthError for invalid users. ''' user = users.get_user_by_name(username) @@ -11,23 +12,25 @@ def _authenticate(username, password): raise errors.AuthError('Invalid password.') return user + def _create_anonymous_user(): user = db.User() user.name = None user.rank = 'anonymous' return user + def _get_user(ctx): if not ctx.has_header('Authorization'): return _create_anonymous_user() try: - auth_type, user_and_password = ctx.get_header('Authorization').split(' ', 1) + auth_type, credentials = ctx.get_header('Authorization').split(' ', 1) if auth_type.lower() != 'basic': raise HttpBadRequest( 'Only basic HTTP authentication is supported.') username, password = base64.decodebytes( - user_and_password.encode('ascii')).decode('utf8').split(':') + credentials.encode('ascii')).decode('utf8').split(':') return _authenticate(username, password) except ValueError as err: msg = 'Basic authentication header value are not properly formed. ' \ @@ -35,6 +38,7 @@ def _get_user(ctx): raise HttpBadRequest( msg.format(ctx.get_header('Authorization'), str(err))) + @middleware.pre_hook def process_request(ctx): ''' Bind the user to request. Update last login time if needed. ''' diff --git a/server/szurubooru/middleware/cache_purger.py b/server/szurubooru/middleware/cache_purger.py index bc9bd9bc..e26b3bae 100644 --- a/server/szurubooru/middleware/cache_purger.py +++ b/server/szurubooru/middleware/cache_purger.py @@ -1,6 +1,7 @@ from szurubooru.func import cache from szurubooru.rest import middleware + @middleware.pre_hook def process_request(ctx): if ctx.method != 'GET': diff --git a/server/szurubooru/middleware/db_session.py b/server/szurubooru/middleware/db_session.py index 8d2c1c9c..8f8a1a55 100644 --- a/server/szurubooru/middleware/db_session.py +++ b/server/szurubooru/middleware/db_session.py @@ -1,11 +1,13 @@ from szurubooru import db from szurubooru.rest import middleware + @middleware.pre_hook def _process_request(ctx): ctx.session = db.session() db.reset_query_count() + @middleware.post_hook def _process_response(_ctx): db.session.remove() diff --git a/server/szurubooru/middleware/request_logger.py b/server/szurubooru/middleware/request_logger.py index 638df16b..c430ca4e 100644 --- a/server/szurubooru/middleware/request_logger.py +++ b/server/szurubooru/middleware/request_logger.py @@ -2,8 +2,10 @@ import logging from szurubooru import db from szurubooru.rest import middleware + logger = logging.getLogger(__name__) + @middleware.post_hook def process_response(ctx): logger.info( diff --git a/server/szurubooru/migrations/env.py b/server/szurubooru/migrations/env.py index c109d9a3..f1161509 100644 --- a/server/szurubooru/migrations/env.py +++ b/server/szurubooru/migrations/env.py @@ -28,6 +28,7 @@ alembic_config.set_main_option( target_metadata = szurubooru.db.Base.metadata + def run_migrations_offline(): ''' Run migrations in 'offline' mode. diff --git a/server/szurubooru/migrations/versions/00cb3a2734db_create_tag_tables.py b/server/szurubooru/migrations/versions/00cb3a2734db_create_tag_tables.py index ccdcb885..77d76414 100644 --- a/server/szurubooru/migrations/versions/00cb3a2734db_create_tag_tables.py +++ b/server/szurubooru/migrations/versions/00cb3a2734db_create_tag_tables.py @@ -13,6 +13,7 @@ down_revision = 'e5c1216a8503' branch_labels = None depends_on = None + def upgrade(): op.create_table( 'tag_category', @@ -55,6 +56,7 @@ def upgrade(): sa.ForeignKeyConstraint(['child_id'], ['tag.id']), sa.PrimaryKeyConstraint('parent_id', 'child_id')) + def downgrade(): op.drop_table('tag_suggestion') op.drop_table('tag_implication') diff --git a/server/szurubooru/migrations/versions/055d0e048fb3_add_default_column_to_tag_categories.py b/server/szurubooru/migrations/versions/055d0e048fb3_add_default_column_to_tag_categories.py index 9c4efc4d..e6a37e6d 100644 --- a/server/szurubooru/migrations/versions/055d0e048fb3_add_default_column_to_tag_categories.py +++ b/server/szurubooru/migrations/versions/055d0e048fb3_add_default_column_to_tag_categories.py @@ -13,10 +13,16 @@ down_revision = '49ab4e1139ef' branch_labels = None depends_on = None + def upgrade(): - op.add_column('tag_category', sa.Column('default', sa.Boolean(), nullable=True)) - op.execute(sa.table('tag_category', sa.column('default')).update().values(default=False)) + op.add_column( + 'tag_category', sa.Column('default', sa.Boolean(), nullable=True)) + op.execute( + sa.table('tag_category', sa.column('default')) + .update() + .values(default=False)) op.alter_column('tag_category', 'default', nullable=False) + def downgrade(): op.drop_column('tag_category', 'default') diff --git a/server/szurubooru/migrations/versions/23abaf4a0a4b_add_mime_type_to_posts.py b/server/szurubooru/migrations/versions/23abaf4a0a4b_add_mime_type_to_posts.py index e6d06f08..e18119b2 100644 --- a/server/szurubooru/migrations/versions/23abaf4a0a4b_add_mime_type_to_posts.py +++ b/server/szurubooru/migrations/versions/23abaf4a0a4b_add_mime_type_to_posts.py @@ -13,8 +13,11 @@ down_revision = 'ed6dd16a30f3' branch_labels = None depends_on = None + def upgrade(): - op.add_column('post', sa.Column('mime-type', sa.Unicode(length=32), nullable=False)) + op.add_column( + 'post', sa.Column('mime-type', sa.Unicode(length=32), nullable=False)) + def downgrade(): op.drop_column('post', 'mime-type') diff --git a/server/szurubooru/migrations/versions/336a76ec1338_create_post_tables.py b/server/szurubooru/migrations/versions/336a76ec1338_create_post_tables.py index 2fc0e163..b767c98f 100644 --- a/server/szurubooru/migrations/versions/336a76ec1338_create_post_tables.py +++ b/server/szurubooru/migrations/versions/336a76ec1338_create_post_tables.py @@ -13,6 +13,7 @@ down_revision = '00cb3a2734db' branch_labels = None depends_on = None + def upgrade(): op.create_table( 'post', @@ -56,6 +57,7 @@ def upgrade(): sa.ForeignKeyConstraint(['tag_id'], ['tag.id']), sa.PrimaryKeyConstraint('post_id', 'tag_id')) + def downgrade(): op.drop_table('post_tag') op.drop_table('post_relation') diff --git a/server/szurubooru/migrations/versions/46cd5229839b_add_snapshot_resource_repr.py b/server/szurubooru/migrations/versions/46cd5229839b_add_snapshot_resource_repr.py index a33273e5..0a46fbcc 100644 --- a/server/szurubooru/migrations/versions/46cd5229839b_add_snapshot_resource_repr.py +++ b/server/szurubooru/migrations/versions/46cd5229839b_add_snapshot_resource_repr.py @@ -13,10 +13,12 @@ down_revision = '565e01e3cf6d' branch_labels = None depends_on = None + def upgrade(): op.add_column( 'snapshot', sa.Column('resource_repr', sa.Unicode(length=64), nullable=False)) + def downgrade(): op.drop_column('snapshot', 'resource_repr') diff --git a/server/szurubooru/migrations/versions/46df355634dc_add_comment_tables.py b/server/szurubooru/migrations/versions/46df355634dc_add_comment_tables.py index 6e8f8052..49971fef 100644 --- a/server/szurubooru/migrations/versions/46df355634dc_add_comment_tables.py +++ b/server/szurubooru/migrations/versions/46df355634dc_add_comment_tables.py @@ -13,6 +13,7 @@ down_revision = '84bd402f15f0' branch_labels = None depends_on = None + def upgrade(): op.create_table( 'comment', @@ -36,6 +37,7 @@ def upgrade(): sa.ForeignKeyConstraint(['user_id'], ['user.id']), sa.PrimaryKeyConstraint('comment_id', 'user_id')) + def downgrade(): op.drop_table('comment_score') op.drop_table('comment') diff --git a/server/szurubooru/migrations/versions/49ab4e1139ef_create_indexes.py b/server/szurubooru/migrations/versions/49ab4e1139ef_create_indexes.py index e30ccb92..b18e4108 100644 --- a/server/szurubooru/migrations/versions/49ab4e1139ef_create_indexes.py +++ b/server/szurubooru/migrations/versions/49ab4e1139ef_create_indexes.py @@ -13,52 +13,59 @@ down_revision = '23abaf4a0a4b' branch_labels = None depends_on = None + def upgrade(): - op.create_index(op.f('ix_comment_post_id'), 'comment', ['post_id'], unique=False) - op.create_index(op.f('ix_comment_user_id'), 'comment', ['user_id'], unique=False) - op.create_index(op.f('ix_comment_score_user_id'), 'comment_score', ['user_id'], unique=False) - op.create_index(op.f('ix_post_user_id'), 'post', ['user_id'], unique=False) - op.create_index(op.f('ix_post_favorite_post_id'), 'post_favorite', ['post_id'], unique=False) - op.create_index(op.f('ix_post_favorite_user_id'), 'post_favorite', ['user_id'], unique=False) - op.create_index(op.f('ix_post_feature_post_id'), 'post_feature', ['post_id'], unique=False) - op.create_index(op.f('ix_post_feature_user_id'), 'post_feature', ['user_id'], unique=False) - op.create_index(op.f('ix_post_note_post_id'), 'post_note', ['post_id'], unique=False) - op.create_index(op.f('ix_post_relation_child_id'), 'post_relation', ['child_id'], unique=False) - op.create_index(op.f('ix_post_relation_parent_id'), 'post_relation', ['parent_id'], unique=False) - op.create_index(op.f('ix_post_score_post_id'), 'post_score', ['post_id'], unique=False) - op.create_index(op.f('ix_post_score_user_id'), 'post_score', ['user_id'], unique=False) - op.create_index(op.f('ix_post_tag_post_id'), 'post_tag', ['post_id'], unique=False) - op.create_index(op.f('ix_post_tag_tag_id'), 'post_tag', ['tag_id'], unique=False) - op.create_index(op.f('ix_snapshot_resource_id'), 'snapshot', ['resource_id'], unique=False) - op.create_index(op.f('ix_snapshot_resource_type'), 'snapshot', ['resource_type'], unique=False) - op.create_index(op.f('ix_tag_category_id'), 'tag', ['category_id'], unique=False) - op.create_index(op.f('ix_tag_implication_child_id'), 'tag_implication', ['child_id'], unique=False) - op.create_index(op.f('ix_tag_implication_parent_id'), 'tag_implication', ['parent_id'], unique=False) - op.create_index(op.f('ix_tag_name_tag_id'), 'tag_name', ['tag_id'], unique=False) - op.create_index(op.f('ix_tag_suggestion_child_id'), 'tag_suggestion', ['child_id'], unique=False) - op.create_index(op.f('ix_tag_suggestion_parent_id'), 'tag_suggestion', ['parent_id'], unique=False) + for index_name, table_name, column_name in [ + ('ix_comment_post_id', 'comment', 'post_id'), + ('ix_comment_user_id', 'comment', 'user_id'), + ('ix_comment_score_user_id', 'comment_score', 'user_id'), + ('ix_post_user_id', 'post', 'user_id'), + ('ix_post_favorite_post_id', 'post_favorite', 'post_id'), + ('ix_post_favorite_user_id', 'post_favorite', 'user_id'), + ('ix_post_feature_post_id', 'post_feature', 'post_id'), + ('ix_post_feature_user_id', 'post_feature', 'user_id'), + ('ix_post_note_post_id', 'post_note', 'post_id'), + ('ix_post_relation_child_id', 'post_relation', 'child_id'), + ('ix_post_relation_parent_id', 'post_relation', 'parent_id'), + ('ix_post_score_post_id', 'post_score', 'post_id'), + ('ix_post_score_user_id', 'post_score', 'user_id'), + ('ix_post_tag_post_id', 'post_tag', 'post_id'), + ('ix_post_tag_tag_id', 'post_tag', 'tag_id'), + ('ix_snapshot_resource_id', 'snapshot', 'resource_id'), + ('ix_snapshot_resource_type', 'snapshot', 'resource_type'), + ('ix_tag_category_id', 'tag', 'category_id'), + ('ix_tag_implication_child_id', 'tag_implication', 'child_id'), + ('ix_tag_implication_parent_id', 'tag_implication', 'parent_id'), + ('ix_tag_name_tag_id', 'tag_name', 'tag_id'), + ('ix_tag_suggestion_child_id', 'tag_suggestion', 'child_id'), + ('ix_tag_suggestion_parent_id', 'tag_suggestion', 'parent_id')]: + op.create_index( + op.f(index_name), table_name, [column_name], unique=False) + def downgrade(): - op.drop_index(op.f('ix_tag_suggestion_parent_id'), table_name='tag_suggestion') - op.drop_index(op.f('ix_tag_suggestion_child_id'), table_name='tag_suggestion') - op.drop_index(op.f('ix_tag_name_tag_id'), table_name='tag_name') - op.drop_index(op.f('ix_tag_implication_parent_id'), table_name='tag_implication') - op.drop_index(op.f('ix_tag_implication_child_id'), table_name='tag_implication') - op.drop_index(op.f('ix_tag_category_id'), table_name='tag') - op.drop_index(op.f('ix_snapshot_resource_type'), table_name='snapshot') - op.drop_index(op.f('ix_snapshot_resource_id'), table_name='snapshot') - op.drop_index(op.f('ix_post_tag_tag_id'), table_name='post_tag') - op.drop_index(op.f('ix_post_tag_post_id'), table_name='post_tag') - op.drop_index(op.f('ix_post_score_user_id'), table_name='post_score') - op.drop_index(op.f('ix_post_score_post_id'), table_name='post_score') - op.drop_index(op.f('ix_post_relation_parent_id'), table_name='post_relation') - op.drop_index(op.f('ix_post_relation_child_id'), table_name='post_relation') - op.drop_index(op.f('ix_post_note_post_id'), table_name='post_note') - op.drop_index(op.f('ix_post_feature_user_id'), table_name='post_feature') - op.drop_index(op.f('ix_post_feature_post_id'), table_name='post_feature') - op.drop_index(op.f('ix_post_favorite_user_id'), table_name='post_favorite') - op.drop_index(op.f('ix_post_favorite_post_id'), table_name='post_favorite') - op.drop_index(op.f('ix_post_user_id'), table_name='post') - op.drop_index(op.f('ix_comment_score_user_id'), table_name='comment_score') - op.drop_index(op.f('ix_comment_user_id'), table_name='comment') - op.drop_index(op.f('ix_comment_post_id'), table_name='comment') + for index_name, table_name in [ + ('ix_tag_suggestion_parent_id', 'tag_suggestion'), + ('ix_tag_suggestion_child_id', 'tag_suggestion'), + ('ix_tag_name_tag_id', 'tag_name'), + ('ix_tag_implication_parent_id', 'tag_implication'), + ('ix_tag_implication_child_id', 'tag_implication'), + ('ix_tag_category_id', 'tag'), + ('ix_snapshot_resource_type', 'snapshot'), + ('ix_snapshot_resource_id', 'snapshot'), + ('ix_post_tag_tag_id', 'post_tag'), + ('ix_post_tag_post_id', 'post_tag'), + ('ix_post_score_user_id', 'post_score'), + ('ix_post_score_post_id', 'post_score'), + ('ix_post_relation_parent_id', 'post_relation'), + ('ix_post_relation_child_id', 'post_relation'), + ('ix_post_note_post_id', 'post_note'), + ('ix_post_feature_user_id', 'post_feature'), + ('ix_post_feature_post_id', 'post_feature'), + ('ix_post_favorite_user_id', 'post_favorite'), + ('ix_post_favorite_post_id', 'post_favorite'), + ('ix_post_user_id', 'post'), + ('ix_comment_score_user_id', 'comment_score'), + ('ix_comment_user_id', 'comment'), + ('ix_comment_post_id', 'comment')]: + op.drop_index(op.f(index_name), table_name=table_name) diff --git a/server/szurubooru/migrations/versions/4c526f869323_add_description_to_tags.py b/server/szurubooru/migrations/versions/4c526f869323_add_description_to_tags.py index 3aff3ad6..f53866f1 100644 --- a/server/szurubooru/migrations/versions/4c526f869323_add_description_to_tags.py +++ b/server/szurubooru/migrations/versions/4c526f869323_add_description_to_tags.py @@ -13,8 +13,11 @@ down_revision = '055d0e048fb3' branch_labels = None depends_on = None + def upgrade(): - op.add_column('tag', sa.Column('description', sa.UnicodeText(), nullable=True)) + op.add_column( + 'tag', sa.Column('description', sa.UnicodeText(), nullable=True)) + def downgrade(): op.drop_column('tag', 'description') diff --git a/server/szurubooru/migrations/versions/565e01e3cf6d_create_snapshot_table.py b/server/szurubooru/migrations/versions/565e01e3cf6d_create_snapshot_table.py index e2eefb6a..475fe96b 100644 --- a/server/szurubooru/migrations/versions/565e01e3cf6d_create_snapshot_table.py +++ b/server/szurubooru/migrations/versions/565e01e3cf6d_create_snapshot_table.py @@ -13,6 +13,7 @@ down_revision = '336a76ec1338' branch_labels = None depends_on = None + def upgrade(): op.create_table( 'snapshot', @@ -26,5 +27,6 @@ def upgrade(): sa.ForeignKeyConstraint(['user_id'], ['user.id']), sa.PrimaryKeyConstraint('id')) + def downgrade(): op.drop_table('snapshot') diff --git a/server/szurubooru/migrations/versions/7f6baf38c27c_add_versions.py b/server/szurubooru/migrations/versions/7f6baf38c27c_add_versions.py index c74bc5a6..71e6bcb0 100644 --- a/server/szurubooru/migrations/versions/7f6baf38c27c_add_versions.py +++ b/server/szurubooru/migrations/versions/7f6baf38c27c_add_versions.py @@ -15,12 +15,17 @@ depends_on = None tables = ['tag_category', 'tag', 'user', 'post', 'comment'] + def upgrade(): for table in tables: op.add_column(table, sa.Column('version', sa.Integer(), nullable=True)) - op.execute(sa.table(table, sa.column('version')).update().values(version=1)) + op.execute( + sa.table(table, sa.column('version')) + .update() + .values(version=1)) op.alter_column(table, 'version', nullable=False) + def downgrade(): for table in tables: op.drop_column(table, 'version') diff --git a/server/szurubooru/migrations/versions/84bd402f15f0_change_flags_column_type.py b/server/szurubooru/migrations/versions/84bd402f15f0_change_flags_column_type.py index adc040eb..72366413 100644 --- a/server/szurubooru/migrations/versions/84bd402f15f0_change_flags_column_type.py +++ b/server/szurubooru/migrations/versions/84bd402f15f0_change_flags_column_type.py @@ -13,10 +13,14 @@ down_revision = '9587de88a84b' branch_labels = None depends_on = None + def upgrade(): op.drop_column('post', 'flags') op.add_column('post', sa.Column('flags', sa.PickleType(), nullable=True)) + def downgrade(): op.drop_column('post', 'flags') - op.add_column('post', sa.Column('flags', sa.Integer(), autoincrement=False, nullable=False)) + op.add_column( + 'post', + sa.Column('flags', sa.Integer(), autoincrement=False, nullable=False)) diff --git a/server/szurubooru/migrations/versions/9587de88a84b_create_aux_post_tables.py b/server/szurubooru/migrations/versions/9587de88a84b_create_aux_post_tables.py index 769e8736..eddec24f 100644 --- a/server/szurubooru/migrations/versions/9587de88a84b_create_aux_post_tables.py +++ b/server/szurubooru/migrations/versions/9587de88a84b_create_aux_post_tables.py @@ -13,6 +13,7 @@ down_revision = '46cd5229839b' branch_labels = None depends_on = None + def upgrade(): op.create_table( 'post_favorite', @@ -52,6 +53,7 @@ def upgrade(): sa.ForeignKeyConstraint(['user_id'], ['user.id']), sa.PrimaryKeyConstraint('post_id', 'user_id')) + def downgrade(): op.drop_table('post_score') op.drop_table('post_note') diff --git a/server/szurubooru/migrations/versions/e5c1216a8503_create_user_table.py b/server/szurubooru/migrations/versions/e5c1216a8503_create_user_table.py index a9c0fb16..a84e31ba 100644 --- a/server/szurubooru/migrations/versions/e5c1216a8503_create_user_table.py +++ b/server/szurubooru/migrations/versions/e5c1216a8503_create_user_table.py @@ -13,6 +13,7 @@ down_revision = None branch_labels = None depends_on = None + def upgrade(): op.create_table( 'user', @@ -28,5 +29,6 @@ def upgrade(): sa.PrimaryKeyConstraint('id')) op.create_unique_constraint('uq_user_name', 'user', ['name']) + def downgrade(): op.drop_table('user') diff --git a/server/szurubooru/migrations/versions/ed6dd16a30f3_delete_post_columns.py b/server/szurubooru/migrations/versions/ed6dd16a30f3_delete_post_columns.py index b40f6638..f51da444 100644 --- a/server/szurubooru/migrations/versions/ed6dd16a30f3_delete_post_columns.py +++ b/server/szurubooru/migrations/versions/ed6dd16a30f3_delete_post_columns.py @@ -13,24 +13,36 @@ down_revision = '46df355634dc' branch_labels = None depends_on = None + def upgrade(): - op.drop_column('post', 'auto_comment_edit_time') - op.drop_column('post', 'auto_fav_count') - op.drop_column('post', 'auto_comment_creation_time') - op.drop_column('post', 'auto_feature_count') - op.drop_column('post', 'auto_comment_count') - op.drop_column('post', 'auto_score') - op.drop_column('post', 'auto_fav_time') - op.drop_column('post', 'auto_feature_time') - op.drop_column('post', 'auto_note_count') + for column_name in [ + 'auto_comment_edit_time' + 'auto_fav_count', + 'auto_comment_creation_time', + 'auto_feature_count', + 'auto_comment_count', + 'auto_score', + 'auto_fav_time', + 'auto_feature_time', + 'auto_note_count']: + op.drop_column('post', column_name) + def downgrade(): - op.add_column('post', sa.Column('auto_note_count', sa.INTEGER(), autoincrement=False, nullable=False)) - op.add_column('post', sa.Column('auto_feature_time', sa.INTEGER(), autoincrement=False, nullable=False)) - op.add_column('post', sa.Column('auto_fav_time', sa.INTEGER(), autoincrement=False, nullable=False)) - op.add_column('post', sa.Column('auto_score', sa.INTEGER(), autoincrement=False, nullable=False)) - op.add_column('post', sa.Column('auto_comment_count', sa.INTEGER(), autoincrement=False, nullable=False)) - op.add_column('post', sa.Column('auto_feature_count', sa.INTEGER(), autoincrement=False, nullable=False)) - op.add_column('post', sa.Column('auto_comment_creation_time', sa.INTEGER(), autoincrement=False, nullable=False)) - op.add_column('post', sa.Column('auto_fav_count', sa.INTEGER(), autoincrement=False, nullable=False)) - op.add_column('post', sa.Column('auto_comment_edit_time', sa.INTEGER(), autoincrement=False, nullable=False)) + for column_name in [ + 'auto_note_count', + 'auto_feature_time', + 'auto_fav_time', + 'auto_score', + 'auto_comment_count', + 'auto_feature_count', + 'auto_comment_creation_time', + 'auto_fav_count', + 'auto_comment_edit_time']: + op.add_column( + 'post', + sa.Column( + column_name, + sa.INTEGER(), + autoincrement=False, + nullable=False)) diff --git a/server/szurubooru/rest/app.py b/server/szurubooru/rest/app.py index 5a5b21b2..150af668 100644 --- a/server/szurubooru/rest/app.py +++ b/server/szurubooru/rest/app.py @@ -6,6 +6,7 @@ from datetime import datetime from szurubooru.func import util from szurubooru.rest import errors, middleware, routes, context + def _json_serializer(obj): ''' JSON serializer for objects not serializable by default JSON code ''' if isinstance(obj, datetime): @@ -13,14 +14,16 @@ def _json_serializer(obj): return serial raise TypeError('Type not serializable') + def _dump_json(obj): return json.dumps(obj, default=_json_serializer, indent=2) + def _read(env): length = int(env.get('CONTENT_LENGTH', 0)) output = io.BytesIO() while length > 0: - part = env['wsgi.input'].read(min(length, 1024*200)) + part = env['wsgi.input'].read(min(length, 1024 * 200)) if not part: break output.write(part) @@ -28,6 +31,7 @@ def _read(env): output.seek(0) return output + def _get_headers(env): headers = {} for key, value in env.items(): @@ -36,6 +40,7 @@ def _get_headers(env): headers[key] = value return headers + def _create_context(env): method = env['REQUEST_METHOD'] path = '/' + env['PATH_INFO'].lstrip('/') @@ -56,7 +61,7 @@ def _create_context(env): if isinstance(form[key], cgi.MiniFieldStorage): params[key] = form.getvalue(key) else: - _original_file_name = getattr(form[key], 'filename', None) + # _user_file_name = getattr(form[key], 'filename', None) files[key] = form.getvalue(key) if 'metadata' in form: body = form.getvalue('metadata') @@ -79,10 +84,11 @@ def _create_context(env): return context.Context(method, path, headers, params, files) + def application(env, start_response): try: ctx = _create_context(env) - if not 'application/json' in ctx.get_header('Accept'): + if 'application/json' not in ctx.get_header('Accept'): raise errors.HttpNotAcceptable( 'This API only supports JSON responses.') diff --git a/server/szurubooru/rest/context.py b/server/szurubooru/rest/context.py index a6454a4e..3bc2705b 100644 --- a/server/szurubooru/rest/context.py +++ b/server/szurubooru/rest/context.py @@ -1,9 +1,11 @@ from szurubooru import errors from szurubooru.func import net + def _lower_first(source): return source[0].lower() + source[1:] + def _param_wrapper(func): def wrapper(self, name, required=False, default=None, **kwargs): # pylint: disable=protected-access @@ -22,8 +24,8 @@ def _param_wrapper(func): 'Required parameter %r is missing.' % name) return wrapper + class Context(): - # pylint: disable=too-many-arguments def __init__(self, method, url, headers=None, params=None, files=None): self.method = method self.url = url @@ -74,7 +76,6 @@ class Context(): raise errors.InvalidParameterError('Expected simple string.') return value - # pylint: disable=redefined-builtin @_param_wrapper def get_param_as_int(self, value, min=None, max=None): try: @@ -97,4 +98,5 @@ class Context(): return True if value in ['0', 'n', 'no', 'nope', 'f', 'false']: return False - raise errors.InvalidParameterError('The value must be a boolean value.') + raise errors.InvalidParameterError( + 'The value must be a boolean value.') diff --git a/server/szurubooru/rest/errors.py b/server/szurubooru/rest/errors.py index 9ada0235..f9755a74 100644 --- a/server/szurubooru/rest/errors.py +++ b/server/szurubooru/rest/errors.py @@ -1,4 +1,5 @@ -error_handlers = {} # pylint: disable=invalid-name +error_handlers = {} # pylint: disable=invalid-name + class BaseHttpError(RuntimeError): code = None @@ -9,29 +10,36 @@ class BaseHttpError(RuntimeError): self.description = description self.title = title or self.reason + class HttpBadRequest(BaseHttpError): code = 400 reason = 'Bad Request' + class HttpForbidden(BaseHttpError): code = 403 reason = 'Forbidden' + class HttpNotFound(BaseHttpError): code = 404 reason = 'Not Found' + class HttpNotAcceptable(BaseHttpError): code = 406 reason = 'Not Acceptable' + class HttpConflict(BaseHttpError): code = 409 reason = 'Conflict' + class HttpMethodNotAllowed(BaseHttpError): code = 405 reason = 'Method Not Allowed' + def handle(exception_type, handler): error_handlers[exception_type] = handler diff --git a/server/szurubooru/rest/middleware.py b/server/szurubooru/rest/middleware.py index e569d692..7cf07296 100644 --- a/server/szurubooru/rest/middleware.py +++ b/server/szurubooru/rest/middleware.py @@ -2,8 +2,10 @@ pre_hooks = [] post_hooks = [] + def pre_hook(handler): pre_hooks.append(handler) + def post_hook(handler): post_hooks.insert(0, handler) diff --git a/server/szurubooru/rest/routes.py b/server/szurubooru/rest/routes.py index f5567219..ffa95f56 100644 --- a/server/szurubooru/rest/routes.py +++ b/server/szurubooru/rest/routes.py @@ -1,6 +1,8 @@ from collections import defaultdict -routes = defaultdict(dict) # pylint: disable=invalid-name + +routes = defaultdict(dict) # pylint: disable=invalid-name + def get(url): def wrapper(handler): @@ -8,18 +10,21 @@ def get(url): return handler return wrapper + def put(url): def wrapper(handler): routes[url]['PUT'] = handler return handler return wrapper + def post(url): def wrapper(handler): routes[url]['POST'] = handler return handler return wrapper + def delete(url): def wrapper(handler): routes[url]['DELETE'] = handler diff --git a/server/szurubooru/search/configs/__init__.py b/server/szurubooru/search/configs/__init__.py index 9f48e14d..c7e3102f 100644 --- a/server/szurubooru/search/configs/__init__.py +++ b/server/szurubooru/search/configs/__init__.py @@ -1,5 +1,5 @@ -from szurubooru.search.configs.user_search_config import UserSearchConfig -from szurubooru.search.configs.snapshot_search_config import SnapshotSearchConfig -from szurubooru.search.configs.tag_search_config import TagSearchConfig -from szurubooru.search.configs.comment_search_config import CommentSearchConfig -from szurubooru.search.configs.post_search_config import PostSearchConfig +from .user_search_config import UserSearchConfig +from .tag_search_config import TagSearchConfig +from .post_search_config import PostSearchConfig +from .snapshot_search_config import SnapshotSearchConfig +from .comment_search_config import CommentSearchConfig diff --git a/server/szurubooru/search/configs/base_search_config.py b/server/szurubooru/search/configs/base_search_config.py index 350a54a8..4ee23b3e 100644 --- a/server/szurubooru/search/configs/base_search_config.py +++ b/server/szurubooru/search/configs/base_search_config.py @@ -1,5 +1,6 @@ from szurubooru.search import tokens + class BaseSearchConfig(object): SORT_ASC = tokens.SortToken.SORT_ASC SORT_DESC = tokens.SortToken.SORT_DESC diff --git a/server/szurubooru/search/configs/comment_search_config.py b/server/szurubooru/search/configs/comment_search_config.py index b921492f..693ccb80 100644 --- a/server/szurubooru/search/configs/comment_search_config.py +++ b/server/szurubooru/search/configs/comment_search_config.py @@ -3,6 +3,7 @@ from szurubooru import db from szurubooru.search.configs import util as search_util from szurubooru.search.configs.base_search_config import BaseSearchConfig + class CommentSearchConfig(BaseSearchConfig): def create_filter_query(self): return db.session.query(db.Comment).join(db.User) @@ -22,12 +23,18 @@ class CommentSearchConfig(BaseSearchConfig): 'user': search_util.create_str_filter(db.User.name), 'author': search_util.create_str_filter(db.User.name), 'text': search_util.create_str_filter(db.Comment.text), - 'creation-date': search_util.create_date_filter(db.Comment.creation_time), - 'creation-time': search_util.create_date_filter(db.Comment.creation_time), - 'last-edit-date': search_util.create_date_filter(db.Comment.last_edit_time), - 'last-edit-time': search_util.create_date_filter(db.Comment.last_edit_time), - 'edit-date': search_util.create_date_filter(db.Comment.last_edit_time), - 'edit-time': search_util.create_date_filter(db.Comment.last_edit_time), + 'creation-date': + search_util.create_date_filter(db.Comment.creation_time), + 'creation-time': + search_util.create_date_filter(db.Comment.creation_time), + 'last-edit-date': + search_util.create_date_filter(db.Comment.last_edit_time), + 'last-edit-time': + search_util.create_date_filter(db.Comment.last_edit_time), + 'edit-date': + search_util.create_date_filter(db.Comment.last_edit_time), + 'edit-time': + search_util.create_date_filter(db.Comment.last_edit_time), } @property diff --git a/server/szurubooru/search/configs/post_search_config.py b/server/szurubooru/search/configs/post_search_config.py index 4fca489f..3a578478 100644 --- a/server/szurubooru/search/configs/post_search_config.py +++ b/server/szurubooru/search/configs/post_search_config.py @@ -6,6 +6,7 @@ from szurubooru.search import criteria, tokens from szurubooru.search.configs import util as search_util from szurubooru.search.configs.base_search_config import BaseSearchConfig + def _enum_transformer(available_values, value): try: return available_values[value.lower()] @@ -14,6 +15,7 @@ def _enum_transformer(available_values, value): 'Invalid value: %r. Possible values: %r.' % ( value, list(sorted(available_values.keys())))) + def _type_transformer(value): available_values = { 'image': db.Post.TYPE_IMAGE, @@ -28,6 +30,7 @@ def _type_transformer(value): } return _enum_transformer(available_values, value) + def _safety_transformer(value): available_values = { 'safe': db.Post.SAFETY_SAFE, @@ -37,11 +40,12 @@ def _safety_transformer(value): } return _enum_transformer(available_values, value) + def _create_score_filter(score): def wrapper(query, criterion, negated): if not getattr(criterion, 'internal', False): raise errors.SearchError( - 'Votes cannot be seen publicly. Did you mean %r?' \ + 'Votes cannot be seen publicly. Did you mean %r?' % 'special:liked') user_alias = aliased(db.User) score_alias = aliased(db.PostScore) @@ -57,6 +61,7 @@ def _create_score_filter(score): return ret return wrapper + class PostSearchConfig(BaseSearchConfig): def on_search_query_parsed(self, search_query): new_special_tokens = [] @@ -64,7 +69,8 @@ class PostSearchConfig(BaseSearchConfig): if token.value in ('fav', 'liked', 'disliked'): assert self.user if self.user.rank == 'anonymous': - raise errors.SearchError('Must be logged in to use this feature.') + raise errors.SearchError( + 'Must be logged in to use this feature.') criterion = criteria.PlainCriterion( original_text=self.user.name, value=self.user.name) @@ -85,9 +91,9 @@ class PostSearchConfig(BaseSearchConfig): return self.create_count_query() \ .options( # use config optimized for official client - #defer(db.Post.score), - #defer(db.Post.favorite_count), - #defer(db.Post.comment_count), + # defer(db.Post.score), + # defer(db.Post.favorite_count), + # defer(db.Post.comment_count), defer(db.Post.last_favorite_time), defer(db.Post.feature_count), defer(db.Post.last_feature_time), @@ -99,8 +105,7 @@ class PostSearchConfig(BaseSearchConfig): lazyload(db.Post.user), lazyload(db.Post.relations), lazyload(db.Post.notes), - lazyload(db.Post.favorited_by), - ) + lazyload(db.Post.favorited_by)) def create_count_query(self): return db.session.query(db.Post) @@ -153,12 +158,18 @@ class PostSearchConfig(BaseSearchConfig): 'liked': _create_score_filter(1), 'disliked': _create_score_filter(-1), 'tag-count': search_util.create_num_filter(db.Post.tag_count), - 'comment-count': search_util.create_num_filter(db.Post.comment_count), - 'fav-count': search_util.create_num_filter(db.Post.favorite_count), + 'comment-count': + search_util.create_num_filter(db.Post.comment_count), + 'fav-count': + search_util.create_num_filter(db.Post.favorite_count), 'note-count': search_util.create_num_filter(db.Post.note_count), - 'relation-count': search_util.create_num_filter(db.Post.relation_count), - 'feature-count': search_util.create_num_filter(db.Post.feature_count), - 'type': search_util.create_str_filter(db.Post.type, _type_transformer), + 'relation-count': + search_util.create_num_filter(db.Post.relation_count), + 'feature-count': + search_util.create_num_filter(db.Post.feature_count), + 'type': + search_util.create_str_filter( + db.Post.type, _type_transformer), 'file-size': search_util.create_num_filter(db.Post.file_size), ('image-width', 'width'): search_util.create_num_filter(db.Post.canvas_width), @@ -171,13 +182,15 @@ class PostSearchConfig(BaseSearchConfig): ('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): search_util.create_date_filter(db.Post.last_edit_time), ('comment-date', 'comment-time'): - search_util.create_date_filter(db.Post.last_comment_creation_time), + search_util.create_date_filter( + db.Post.last_comment_creation_time), ('fav-date', 'fav-time'): search_util.create_date_filter(db.Post.last_favorite_time), ('feature-date', 'feature-time'): search_util.create_date_filter(db.Post.last_feature_time), ('safety', 'rating'): - search_util.create_str_filter(db.Post.safety, _safety_transformer), + search_util.create_str_filter( + db.Post.safety, _safety_transformer), }) @property @@ -193,9 +206,12 @@ class PostSearchConfig(BaseSearchConfig): 'relation-count': (db.Post.relation_count, self.SORT_DESC), 'feature-count': (db.Post.feature_count, self.SORT_DESC), 'file-size': (db.Post.file_size, self.SORT_DESC), - ('image-width', 'width'): (db.Post.canvas_width, self.SORT_DESC), - ('image-height', 'height'): (db.Post.canvas_height, self.SORT_DESC), - ('image-area', 'area'): (db.Post.canvas_area, self.SORT_DESC), + ('image-width', 'width'): + (db.Post.canvas_width, self.SORT_DESC), + ('image-height', 'height'): + (db.Post.canvas_height, self.SORT_DESC), + ('image-area', 'area'): + (db.Post.canvas_area, self.SORT_DESC), ('creation-date', 'creation-time', 'date', 'time'): (db.Post.creation_time, self.SORT_DESC), ('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): diff --git a/server/szurubooru/search/configs/snapshot_search_config.py b/server/szurubooru/search/configs/snapshot_search_config.py index a5e7f9de..a1a90d3e 100644 --- a/server/szurubooru/search/configs/snapshot_search_config.py +++ b/server/szurubooru/search/configs/snapshot_search_config.py @@ -2,6 +2,7 @@ from szurubooru import db from szurubooru.search.configs import util as search_util from szurubooru.search.configs.base_search_config import BaseSearchConfig + class SnapshotSearchConfig(BaseSearchConfig): def create_filter_query(self): return db.session.query(db.Snapshot) diff --git a/server/szurubooru/search/configs/tag_search_config.py b/server/szurubooru/search/configs/tag_search_config.py index 05bbbcd4..615c8d02 100644 --- a/server/szurubooru/search/configs/tag_search_config.py +++ b/server/szurubooru/search/configs/tag_search_config.py @@ -5,6 +5,7 @@ from szurubooru.func import util from szurubooru.search.configs import util as search_util from szurubooru.search.configs.base_search_config import BaseSearchConfig + class TagSearchConfig(BaseSearchConfig): def create_filter_query(self): return self.create_count_query() \ @@ -13,8 +14,7 @@ class TagSearchConfig(BaseSearchConfig): subqueryload(db.Tag.names), subqueryload(db.Tag.category), subqueryload(db.Tag.suggestions).joinedload(db.Tag.names), - subqueryload(db.Tag.implications).joinedload(db.Tag.names) - ) + subqueryload(db.Tag.implications).joinedload(db.Tag.names)) def create_count_query(self): return db.session.query(db.Tag) diff --git a/server/szurubooru/search/configs/user_search_config.py b/server/szurubooru/search/configs/user_search_config.py index 36468919..8fb31dad 100644 --- a/server/szurubooru/search/configs/user_search_config.py +++ b/server/szurubooru/search/configs/user_search_config.py @@ -3,6 +3,7 @@ from szurubooru import db from szurubooru.search.configs import util as search_util from szurubooru.search.configs.base_search_config import BaseSearchConfig + class UserSearchConfig(BaseSearchConfig): ''' Executes searches related to the users. ''' @@ -20,12 +21,18 @@ class UserSearchConfig(BaseSearchConfig): def named_filters(self): return { 'name': search_util.create_str_filter(db.User.name), - 'creation-date': search_util.create_date_filter(db.User.creation_time), - 'creation-time': search_util.create_date_filter(db.User.creation_time), - 'last-login-date': search_util.create_date_filter(db.User.last_login_time), - 'last-login-time': search_util.create_date_filter(db.User.last_login_time), - 'login-date': search_util.create_date_filter(db.User.last_login_time), - 'login-time': search_util.create_date_filter(db.User.last_login_time), + 'creation-date': + search_util.create_date_filter(db.User.creation_time), + 'creation-time': + search_util.create_date_filter(db.User.creation_time), + 'last-login-date': + search_util.create_date_filter(db.User.last_login_time), + 'last-login-time': + search_util.create_date_filter(db.User.last_login_time), + 'login-date': + search_util.create_date_filter(db.User.last_login_time), + 'login-time': + search_util.create_date_filter(db.User.last_login_time), } @property diff --git a/server/szurubooru/search/configs/util.py b/server/szurubooru/search/configs/util.py index 5e770ae3..40fdd156 100644 --- a/server/szurubooru/search/configs/util.py +++ b/server/szurubooru/search/configs/util.py @@ -3,9 +3,11 @@ from szurubooru import db, errors from szurubooru.func import util from szurubooru.search import criteria + def wildcard_transformer(value): return value.replace('*', '%') + def apply_num_criterion_to_column(column, criterion): ''' Decorate SQLAlchemy filter on given column using supplied criterion. @@ -32,6 +34,7 @@ def apply_num_criterion_to_column(column, criterion): 'Criterion value %r must be a number.' % (criterion,)) return expr + def create_num_filter(column): def wrapper(query, criterion, negated): expr = apply_num_criterion_to_column( @@ -41,6 +44,7 @@ def create_num_filter(column): return query.filter(expr) return wrapper + def apply_str_criterion_to_column( column, criterion, transformer=wildcard_transformer): ''' @@ -59,6 +63,7 @@ def apply_str_criterion_to_column( assert False return expr + def create_str_filter(column, transformer=wildcard_transformer): def wrapper(query, criterion, negated): expr = apply_str_criterion_to_column( @@ -68,6 +73,7 @@ def create_str_filter(column, transformer=wildcard_transformer): return query.filter(expr) return wrapper + def apply_date_criterion_to_column(column, criterion): ''' Decorate SQLAlchemy filter on given column using supplied criterion. @@ -97,6 +103,7 @@ def apply_date_criterion_to_column(column, criterion): assert False return expr + def create_date_filter(column): def wrapper(query, criterion, negated): expr = apply_date_criterion_to_column( @@ -106,6 +113,7 @@ def create_date_filter(column): return query.filter(expr) return wrapper + def create_subquery_filter( left_id_column, right_id_column, @@ -113,6 +121,7 @@ def create_subquery_filter( filter_factory, subquery_decorator=None): filter_func = filter_factory(filter_column) + def wrapper(query, criterion, negated): subquery = db.session.query(right_id_column.label('foreign_id')) if subquery_decorator: @@ -121,4 +130,5 @@ def create_subquery_filter( subquery = filter_func(subquery, criterion, negated) subquery = subquery.subquery('t') return query.filter(left_id_column.in_(subquery)) + return wrapper diff --git a/server/szurubooru/search/criteria.py b/server/szurubooru/search/criteria.py index 2558c8cf..7db2fe88 100644 --- a/server/szurubooru/search/criteria.py +++ b/server/szurubooru/search/criteria.py @@ -5,6 +5,7 @@ class _BaseCriterion(object): def __repr__(self): return self.original_text + class RangedCriterion(_BaseCriterion): def __init__(self, original_text, min_value, max_value): super().__init__(original_text) @@ -14,6 +15,7 @@ class RangedCriterion(_BaseCriterion): def __hash__(self): return hash(('range', self.min_value, self.max_value)) + class PlainCriterion(_BaseCriterion): def __init__(self, original_text, value): super().__init__(original_text) @@ -22,6 +24,7 @@ class PlainCriterion(_BaseCriterion): def __hash__(self): return hash(self.value) + class ArrayCriterion(_BaseCriterion): def __init__(self, original_text, values): super().__init__(original_text) diff --git a/server/szurubooru/search/executor.py b/server/szurubooru/search/executor.py index 13238632..7f55fcee 100644 --- a/server/szurubooru/search/executor.py +++ b/server/szurubooru/search/executor.py @@ -3,19 +3,22 @@ from szurubooru import db, errors from szurubooru.func import cache from szurubooru.search import tokens, parser + def _format_dict_keys(source): return list(sorted(source.keys())) -def _get_direction(direction, default_direction): - if direction == tokens.SortToken.SORT_DEFAULT: - return default_direction - if direction == tokens.SortToken.SORT_NEGATED_DEFAULT: - if default_direction == tokens.SortToken.SORT_ASC: + +def _get_order(order, default_order): + if order == tokens.SortToken.SORT_DEFAULT: + return default_order + if order == tokens.SortToken.SORT_NEGATED_DEFAULT: + if default_order == tokens.SortToken.SORT_ASC: return tokens.SortToken.SORT_DESC - elif default_direction == tokens.SortToken.SORT_DESC: + elif default_order == tokens.SortToken.SORT_DESC: return tokens.SortToken.SORT_ASC assert False - return direction + return order + class Executor(object): ''' @@ -30,20 +33,26 @@ class Executor(object): def get_around(self, query_text, entity_id): search_query = self.parser.parse(query_text) self.config.on_search_query_parsed(search_query) - filter_query = self.config \ - .create_around_query() \ - .options(sqlalchemy.orm.lazyload('*')) - filter_query = self._prepare_db_query(filter_query, search_query, False) - prev_filter_query = filter_query \ - .filter(self.config.id_column < entity_id) \ - .order_by(None) \ - .order_by(sqlalchemy.func.abs(self.config.id_column - entity_id).asc()) \ - .limit(1) - next_filter_query = filter_query \ - .filter(self.config.id_column > entity_id) \ - .order_by(None) \ - .order_by(sqlalchemy.func.abs(self.config.id_column - entity_id).asc()) \ - .limit(1) + filter_query = ( + self.config + .create_around_query() + .options(sqlalchemy.orm.lazyload('*'))) + filter_query = self._prepare_db_query( + filter_query, search_query, False) + prev_filter_query = ( + filter_query + .filter(self.config.id_column < entity_id) + .order_by(None) + .order_by(sqlalchemy.func.abs( + self.config.id_column - entity_id).asc()) + .limit(1)) + next_filter_query = ( + filter_query + .filter(self.config.id_column > entity_id) + .order_by(None) + .order_by(sqlalchemy.func.abs( + self.config.id_column - entity_id).asc()) + .limit(1)) return [ next_filter_query.one_or_none(), prev_filter_query.one_or_none()] @@ -92,7 +101,8 @@ class Executor(object): def execute_and_serialize(self, ctx, serializer): query = ctx.get_param_as_string('query') page = ctx.get_param_as_int('page', default=1, min=1) - page_size = ctx.get_param_as_int('pageSize', default=100, min=1, max=100) + page_size = ctx.get_param_as_int( + 'pageSize', default=100, min=1, max=100) count, entities = self.execute(query, page, page_size) return { 'query': query, @@ -124,7 +134,8 @@ class Executor(object): for token in search_query.special_tokens: if token.value not in self.config.special_filters: raise errors.SearchError( - 'Unknown special token: %r. Available special tokens: %r.' % ( + 'Unknown special token: %r. ' + 'Available special tokens: %r.' % ( token.value, _format_dict_keys(self.config.special_filters))) db_query = self.config.special_filters[token.value]( @@ -134,14 +145,15 @@ class Executor(object): for token in search_query.sort_tokens: if token.name not in self.config.sort_columns: raise errors.SearchError( - 'Unknown sort token: %r. Available sort tokens: %r.' % ( + 'Unknown sort token: %r. ' + 'Available sort tokens: %r.' % ( token.name, _format_dict_keys(self.config.sort_columns))) - column, default_direction = self.config.sort_columns[token.name] - direction = _get_direction(token.direction, default_direction) - if direction == token.SORT_ASC: + column, default_order = self.config.sort_columns[token.name] + order = _get_order(token.order, default_order) + if order == token.SORT_ASC: db_query = db_query.order_by(column.asc()) - elif direction == token.SORT_DESC: + elif order == token.SORT_DESC: db_query = db_query.order_by(column.desc()) db_query = self.config.finalize_query(db_query) diff --git a/server/szurubooru/search/parser.py b/server/szurubooru/search/parser.py index 261c0d42..ab5c042f 100644 --- a/server/szurubooru/search/parser.py +++ b/server/szurubooru/search/parser.py @@ -2,6 +2,7 @@ import re from szurubooru import errors from szurubooru.search import criteria, tokens + def _create_criterion(original_value, value): if '..' in value: low, high = value.split('..', 1) @@ -13,10 +14,12 @@ def _create_criterion(original_value, value): original_value, value.split(',')) return criteria.PlainCriterion(original_value, value) + def _parse_anonymous(value, negated): criterion = _create_criterion(value, value) return tokens.AnonymousToken(criterion, negated) + def _parse_named(key, value, negated): original_value = value if key.endswith('-min'): @@ -28,34 +31,41 @@ def _parse_named(key, value, negated): criterion = _create_criterion(original_value, value) return tokens.NamedToken(key, criterion, negated) + def _parse_special(value, negated): return tokens.SpecialToken(value, negated) + def _parse_sort(value, negated): if value.count(',') == 0: - direction_str = None + order_str = None elif value.count(',') == 1: - value, direction_str = value.split(',') + value, order_str = value.split(',') else: raise errors.SearchError('Too many commas in sort style token.') try: - direction = { + order = { 'asc': tokens.SortToken.SORT_ASC, 'desc': tokens.SortToken.SORT_DESC, '': tokens.SortToken.SORT_DEFAULT, None: tokens.SortToken.SORT_DEFAULT, - }[direction_str] + }[order_str] except KeyError: raise errors.SearchError( - 'Unknown search direction: %r.' % direction_str) + 'Unknown search direction: %r.' % order_str) if negated: - direction = { - tokens.SortToken.SORT_ASC: tokens.SortToken.SORT_DESC, - tokens.SortToken.SORT_DESC: tokens.SortToken.SORT_ASC, - tokens.SortToken.SORT_DEFAULT: tokens.SortToken.SORT_NEGATED_DEFAULT, - tokens.SortToken.SORT_NEGATED_DEFAULT: tokens.SortToken.SORT_DEFAULT, - }[direction] - return tokens.SortToken(value, direction) + order = { + tokens.SortToken.SORT_ASC: + tokens.SortToken.SORT_DESC, + tokens.SortToken.SORT_DESC: + tokens.SortToken.SORT_ASC, + tokens.SortToken.SORT_DEFAULT: + tokens.SortToken.SORT_NEGATED_DEFAULT, + tokens.SortToken.SORT_NEGATED_DEFAULT: + tokens.SortToken.SORT_DEFAULT, + }[order] + return tokens.SortToken(value, order) + class SearchQuery(): def __init__(self): @@ -71,6 +81,7 @@ class SearchQuery(): tuple(self.special_tokens), tuple(self.sort_tokens))) + class Parser(object): def parse(self, query_text): query = SearchQuery() @@ -93,5 +104,6 @@ class Parser(object): query.named_tokens.append( _parse_named(key, value, negated)) else: - query.anonymous_tokens.append(_parse_anonymous(chunk, negated)) + query.anonymous_tokens.append( + _parse_anonymous(chunk, negated)) return query diff --git a/server/szurubooru/search/tokens.py b/server/szurubooru/search/tokens.py index 4c4e46d2..e723fe52 100644 --- a/server/szurubooru/search/tokens.py +++ b/server/szurubooru/search/tokens.py @@ -6,6 +6,7 @@ class AnonymousToken(object): def __hash__(self): return hash((self.criterion, self.negated)) + class NamedToken(AnonymousToken): def __init__(self, name, criterion, negated): super().__init__(criterion, negated) @@ -14,18 +15,20 @@ class NamedToken(AnonymousToken): def __hash__(self): return hash((self.name, self.criterion, self.negated)) + class SortToken(object): SORT_DESC = 'desc' SORT_ASC = 'asc' SORT_DEFAULT = 'default' SORT_NEGATED_DEFAULT = 'negated default' - def __init__(self, name, direction): + def __init__(self, name, order): self.name = name - self.direction = direction + self.order = order def __hash__(self): - return hash((self.name, self.direction)) + return hash((self.name, self.order)) + class SpecialToken(object): def __init__(self, value, negated): diff --git a/server/szurubooru/tests/__init__.py b/server/szurubooru/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/szurubooru/tests/api/__init__.py b/server/szurubooru/tests/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/szurubooru/tests/api/test_comment_creating.py b/server/szurubooru/tests/api/test_comment_creating.py index 68614d75..c7d0b0f6 100644 --- a/server/szurubooru/tests/api/test_comment_creating.py +++ b/server/szurubooru/tests/api/test_comment_creating.py @@ -1,20 +1,22 @@ -import pytest -import unittest.mock from datetime import datetime +from unittest.mock import patch +import pytest from szurubooru import api, db, errors from szurubooru.func import comments, posts + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({'privileges': {'comments:create': db.User.RANK_REGULAR}}) + def test_creating_comment( user_factory, post_factory, context_factory, fake_datetime): post = post_factory() user = user_factory(rank=db.User.RANK_REGULAR) db.session.add_all([post, user]) db.session.flush() - with unittest.mock.patch('szurubooru.func.comments.serialize_comment'), \ + with patch('szurubooru.func.comments.serialize_comment'), \ fake_datetime('1997-01-01'): comments.serialize_comment.return_value = 'serialized comment' result = api.comment_api.create_comment( @@ -29,6 +31,7 @@ def test_creating_comment( assert comment.user and comment.user.user_id == user.user_id assert comment.post and comment.post.post_id == post.post_id + @pytest.mark.parametrize('params', [ {'text': None}, {'text': ''}, @@ -48,6 +51,7 @@ def test_trying_to_pass_invalid_params( api.comment_api.create_comment( context_factory(params=real_params, user=user)) + @pytest.mark.parametrize('field', ['text', 'postId']) def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): params = { @@ -61,6 +65,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): params={}, user=user_factory(rank=db.User.RANK_REGULAR))) + def test_trying_to_comment_non_existing(user_factory, context_factory): user = user_factory(rank=db.User.RANK_REGULAR) db.session.add_all([user]) @@ -70,6 +75,7 @@ def test_trying_to_comment_non_existing(user_factory, context_factory): context_factory( params={'text': 'bad', 'postId': 5}, user=user)) + def test_trying_to_create_without_privileges(user_factory, context_factory): with pytest.raises(errors.AuthError): api.comment_api.create_comment( diff --git a/server/szurubooru/tests/api/test_comment_deleting.py b/server/szurubooru/tests/api/test_comment_deleting.py index 99b875f9..efb432a6 100644 --- a/server/szurubooru/tests/api/test_comment_deleting.py +++ b/server/szurubooru/tests/api/test_comment_deleting.py @@ -2,6 +2,7 @@ import pytest from szurubooru import api, db, errors from szurubooru.func import comments + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ @@ -11,6 +12,7 @@ def inject_config(config_injector): }, }) + def test_deleting_own_comment(user_factory, comment_factory, context_factory): user = user_factory() comment = comment_factory(user=user) @@ -22,6 +24,7 @@ def test_deleting_own_comment(user_factory, comment_factory, context_factory): assert result == {} assert db.session.query(db.Comment).count() == 0 + def test_deleting_someones_else_comment( user_factory, comment_factory, context_factory): user1 = user_factory(rank=db.User.RANK_REGULAR) @@ -29,11 +32,12 @@ def test_deleting_someones_else_comment( comment = comment_factory(user=user1) db.session.add(comment) db.session.commit() - result = api.comment_api.delete_comment( + api.comment_api.delete_comment( context_factory(params={'version': 1}, user=user2), {'comment_id': comment.comment_id}) assert db.session.query(db.Comment).count() == 0 + def test_trying_to_delete_someones_else_comment_without_privileges( user_factory, comment_factory, context_factory): user1 = user_factory(rank=db.User.RANK_REGULAR) @@ -47,6 +51,7 @@ def test_trying_to_delete_someones_else_comment_without_privileges( {'comment_id': comment.comment_id}) assert db.session.query(db.Comment).count() == 1 + def test_trying_to_delete_non_existing(user_factory, context_factory): with pytest.raises(comments.CommentNotFoundError): api.comment_api.delete_comment( diff --git a/server/szurubooru/tests/api/test_comment_rating.py b/server/szurubooru/tests/api/test_comment_rating.py index 1191f2fc..981e0dd8 100644 --- a/server/szurubooru/tests/api/test_comment_rating.py +++ b/server/szurubooru/tests/api/test_comment_rating.py @@ -1,86 +1,92 @@ +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import api, db, errors -from szurubooru.func import comments, scores +from szurubooru.func import comments + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({'privileges': {'comments:score': db.User.RANK_REGULAR}}) + def test_simple_rating( user_factory, comment_factory, context_factory, fake_datetime): user = user_factory(rank=db.User.RANK_REGULAR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() - with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): + with patch('szurubooru.func.comments.serialize_comment'), \ + fake_datetime('1997-12-01'): comments.serialize_comment.return_value = 'serialized comment' - with fake_datetime('1997-12-01'): - result = api.comment_api.set_comment_score( - context_factory(params={'score': 1}, user=user), - {'comment_id': comment.comment_id}) + result = api.comment_api.set_comment_score( + context_factory(params={'score': 1}, user=user), + {'comment_id': comment.comment_id}) assert result == 'serialized comment' assert db.session.query(db.CommentScore).count() == 1 assert comment is not None assert comment.score == 1 + def test_updating_rating( user_factory, comment_factory, context_factory, fake_datetime): user = user_factory(rank=db.User.RANK_REGULAR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() - with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): + with patch('szurubooru.func.comments.serialize_comment'): with fake_datetime('1997-12-01'): - result = api.comment_api.set_comment_score( + api.comment_api.set_comment_score( context_factory(params={'score': 1}, user=user), {'comment_id': comment.comment_id}) with fake_datetime('1997-12-02'): - result = api.comment_api.set_comment_score( + api.comment_api.set_comment_score( context_factory(params={'score': -1}, user=user), {'comment_id': comment.comment_id}) comment = db.session.query(db.Comment).one() assert db.session.query(db.CommentScore).count() == 1 assert comment.score == -1 + def test_updating_rating_to_zero( user_factory, comment_factory, context_factory, fake_datetime): user = user_factory(rank=db.User.RANK_REGULAR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() - with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): + with patch('szurubooru.func.comments.serialize_comment'): with fake_datetime('1997-12-01'): - result = api.comment_api.set_comment_score( + api.comment_api.set_comment_score( context_factory(params={'score': 1}, user=user), {'comment_id': comment.comment_id}) with fake_datetime('1997-12-02'): - result = api.comment_api.set_comment_score( + api.comment_api.set_comment_score( context_factory(params={'score': 0}, user=user), {'comment_id': comment.comment_id}) comment = db.session.query(db.Comment).one() assert db.session.query(db.CommentScore).count() == 0 assert comment.score == 0 + def test_deleting_rating( user_factory, comment_factory, context_factory, fake_datetime): user = user_factory(rank=db.User.RANK_REGULAR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() - with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): + with patch('szurubooru.func.comments.serialize_comment'): with fake_datetime('1997-12-01'): - result = api.comment_api.set_comment_score( + api.comment_api.set_comment_score( context_factory(params={'score': 1}, user=user), {'comment_id': comment.comment_id}) with fake_datetime('1997-12-02'): - result = api.comment_api.delete_comment_score( + api.comment_api.delete_comment_score( context_factory(user=user), {'comment_id': comment.comment_id}) comment = db.session.query(db.Comment).one() assert db.session.query(db.CommentScore).count() == 0 assert comment.score == 0 + def test_ratings_from_multiple_users( user_factory, comment_factory, context_factory, fake_datetime): user1 = user_factory(rank=db.User.RANK_REGULAR) @@ -88,19 +94,20 @@ def test_ratings_from_multiple_users( comment = comment_factory() db.session.add_all([user1, user2, comment]) db.session.commit() - with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): + with patch('szurubooru.func.comments.serialize_comment'): with fake_datetime('1997-12-01'): - result = api.comment_api.set_comment_score( + api.comment_api.set_comment_score( context_factory(params={'score': 1}, user=user1), {'comment_id': comment.comment_id}) with fake_datetime('1997-12-02'): - result = api.comment_api.set_comment_score( + api.comment_api.set_comment_score( context_factory(params={'score': -1}, user=user2), {'comment_id': comment.comment_id}) comment = db.session.query(db.Comment).one() assert db.session.query(db.CommentScore).count() == 2 assert comment.score == 0 + def test_trying_to_omit_mandatory_field( user_factory, comment_factory, context_factory): user = user_factory() @@ -112,8 +119,8 @@ def test_trying_to_omit_mandatory_field( context_factory(params={}, user=user), {'comment_id': comment.comment_id}) -def test_trying_to_update_non_existing( - user_factory, comment_factory, context_factory): + +def test_trying_to_update_non_existing(user_factory, context_factory): with pytest.raises(comments.CommentNotFoundError): api.comment_api.set_comment_score( context_factory( @@ -121,6 +128,7 @@ def test_trying_to_update_non_existing( user=user_factory(rank=db.User.RANK_REGULAR)), {'comment_id': 5}) + def test_trying_to_rate_without_privileges( user_factory, comment_factory, context_factory): comment = comment_factory() diff --git a/server/szurubooru/tests/api/test_comment_retrieving.py b/server/szurubooru/tests/api/test_comment_retrieving.py index 5033f798..48bf3730 100644 --- a/server/szurubooru/tests/api/test_comment_retrieving.py +++ b/server/szurubooru/tests/api/test_comment_retrieving.py @@ -1,8 +1,9 @@ +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import api, db, errors from szurubooru.func import comments + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ @@ -12,11 +13,12 @@ def inject_config(config_injector): }, }) + def test_retrieving_multiple(user_factory, comment_factory, context_factory): comment1 = comment_factory(text='text 1') comment2 = comment_factory(text='text 2') db.session.add_all([comment1, comment2]) - with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): + with patch('szurubooru.func.comments.serialize_comment'): comments.serialize_comment.return_value = 'serialized comment' result = api.comment_api.get_comments( context_factory( @@ -30,6 +32,7 @@ def test_retrieving_multiple(user_factory, comment_factory, context_factory): 'results': ['serialized comment', 'serialized comment'], } + def test_trying_to_retrieve_multiple_without_privileges( user_factory, context_factory): with pytest.raises(errors.AuthError): @@ -38,11 +41,12 @@ def test_trying_to_retrieve_multiple_without_privileges( params={'query': '', 'page': 1}, user=user_factory(rank=db.User.RANK_ANONYMOUS))) + def test_retrieving_single(user_factory, comment_factory, context_factory): comment = comment_factory(text='dummy text') db.session.add(comment) db.session.flush() - with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): + with patch('szurubooru.func.comments.serialize_comment'): comments.serialize_comment.return_value = 'serialized comment' result = api.comment_api.get_comment( context_factory( @@ -50,6 +54,7 @@ def test_retrieving_single(user_factory, comment_factory, context_factory): {'comment_id': comment.comment_id}) assert result == 'serialized comment' + def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): with pytest.raises(comments.CommentNotFoundError): api.comment_api.get_comment( @@ -57,6 +62,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): user=user_factory(rank=db.User.RANK_REGULAR)), {'comment_id': 5}) + def test_trying_to_retrieve_single_without_privileges( user_factory, context_factory): with pytest.raises(errors.AuthError): diff --git a/server/szurubooru/tests/api/test_comment_updating.py b/server/szurubooru/tests/api/test_comment_updating.py index 9023cb55..5f3d12b0 100644 --- a/server/szurubooru/tests/api/test_comment_updating.py +++ b/server/szurubooru/tests/api/test_comment_updating.py @@ -1,9 +1,10 @@ -import pytest -import unittest.mock from datetime import datetime +from unittest.mock import patch +import pytest from szurubooru import api, db, errors from szurubooru.func import comments + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ @@ -13,13 +14,14 @@ def inject_config(config_injector): }, }) + def test_simple_updating( user_factory, comment_factory, context_factory, fake_datetime): user = user_factory(rank=db.User.RANK_REGULAR) comment = comment_factory(user=user) db.session.add(comment) db.session.commit() - with unittest.mock.patch('szurubooru.func.comments.serialize_comment'), \ + with patch('szurubooru.func.comments.serialize_comment'), \ fake_datetime('1997-12-01'): comments.serialize_comment.return_value = 'serialized comment' result = api.comment_api.update_comment( @@ -29,6 +31,7 @@ def test_simple_updating( assert result == 'serialized comment' assert comment.last_edit_time == datetime(1997, 12, 1) + @pytest.mark.parametrize('params,expected_exception', [ ({'text': None}, comments.EmptyCommentTextError), ({'text': ''}, comments.EmptyCommentTextError), @@ -37,7 +40,11 @@ def test_simple_updating( ({'text': ['']}, comments.EmptyCommentTextError), ]) def test_trying_to_pass_invalid_params( - user_factory, comment_factory, context_factory, params, expected_exception): + user_factory, + comment_factory, + context_factory, + params, + expected_exception): user = user_factory() comment = comment_factory(user=user) db.session.add(comment) @@ -48,6 +55,7 @@ def test_trying_to_pass_invalid_params( params={**params, **{'version': 1}}, user=user), {'comment_id': comment.comment_id}) + def test_trying_to_omit_mandatory_field( user_factory, comment_factory, context_factory): user = user_factory() @@ -59,6 +67,7 @@ def test_trying_to_omit_mandatory_field( context_factory(params={'version': 1}, user=user), {'comment_id': comment.comment_id}) + def test_trying_to_update_non_existing(user_factory, context_factory): with pytest.raises(comments.CommentNotFoundError): api.comment_api.update_comment( @@ -67,6 +76,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory): user=user_factory(rank=db.User.RANK_REGULAR)), {'comment_id': 5}) + def test_trying_to_update_someones_comment_without_privileges( user_factory, comment_factory, context_factory): user = user_factory(rank=db.User.RANK_REGULAR) @@ -80,6 +90,7 @@ def test_trying_to_update_someones_comment_without_privileges( params={'text': 'new text', 'version': 1}, user=user2), {'comment_id': comment.comment_id}) + def test_updating_someones_comment_with_privileges( user_factory, comment_factory, context_factory): user = user_factory(rank=db.User.RANK_REGULAR) @@ -87,7 +98,7 @@ def test_updating_someones_comment_with_privileges( comment = comment_factory(user=user) db.session.add(comment) db.session.commit() - with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): + with patch('szurubooru.func.comments.serialize_comment'): api.comment_api.update_comment( context_factory( params={'text': 'new text', 'version': 1}, user=user2), diff --git a/server/szurubooru/tests/api/test_info.py b/server/szurubooru/tests/api/test_info.py index c1499c1d..28fce63f 100644 --- a/server/szurubooru/tests/api/test_info.py +++ b/server/szurubooru/tests/api/test_info.py @@ -1,6 +1,7 @@ from datetime import datetime from szurubooru import api, db + def test_info_api( tmpdir, config_injector, context_factory, post_factory, fake_datetime): directory = tmpdir.mkdir('data') @@ -45,7 +46,7 @@ def test_info_api( with fake_datetime('2016-01-01 13:59'): assert api.info_api.get_info(context_factory()) == { 'postCount': 2, - 'diskUsage': 3, # still 3 - it's cached + 'diskUsage': 3, # still 3 - it's cached 'featuredPost': None, 'featuringTime': None, 'featuringUser': None, @@ -55,7 +56,7 @@ def test_info_api( with fake_datetime('2016-01-01 14:01'): assert api.info_api.get_info(context_factory()) == { 'postCount': 2, - 'diskUsage': 6, # cache expired + 'diskUsage': 6, # cache expired 'featuredPost': None, 'featuringTime': None, 'featuringUser': None, diff --git a/server/szurubooru/tests/api/test_password_reset.py b/server/szurubooru/tests/api/test_password_reset.py index d1a844ce..b84b7f79 100644 --- a/server/szurubooru/tests/api/test_password_reset.py +++ b/server/szurubooru/tests/api/test_password_reset.py @@ -1,21 +1,23 @@ +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import api, db, errors from szurubooru.func import auth, mailer + @pytest.fixture(autouse=True) -def inject_config(tmpdir, config_injector): +def inject_config(config_injector): config_injector({ 'secret': 'x', 'base_url': 'http://example.com/', 'name': 'Test instance', }) + def test_reset_sending_email(context_factory, user_factory): db.session.add(user_factory( name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) for initiating_user in ['u1', 'user@example.com']: - with unittest.mock.patch('szurubooru.func.mailer.send_mail'): + with patch('szurubooru.func.mailer.send_mail'): assert api.password_reset_api.start_password_reset( context_factory(), {'user_name': initiating_user}) == {} mailer.send_mail.assert_called_once_with( @@ -27,17 +29,21 @@ def test_reset_sending_email(context_factory, user_factory): 'ink: http://example.com/password-reset/u1:4ac0be176fb36' + '4f13ee6b634c43220e2\nOtherwise, please ignore this email.') + def test_trying_to_reset_non_existing(context_factory): with pytest.raises(errors.NotFoundError): api.password_reset_api.start_password_reset( context_factory(), {'user_name': 'u1'}) + def test_trying_to_reset_without_email(context_factory, user_factory): - db.session.add(user_factory(name='u1', rank=db.User.RANK_REGULAR, email=None)) + db.session.add( + user_factory(name='u1', rank=db.User.RANK_REGULAR, email=None)) with pytest.raises(errors.ValidationError): api.password_reset_api.start_password_reset( context_factory(), {'user_name': 'u1'}) + def test_confirming_with_good_token(context_factory, user_factory): user = user_factory( name='u1', rank=db.User.RANK_REGULAR, email='user@example.com') @@ -50,11 +56,13 @@ def test_confirming_with_good_token(context_factory, user_factory): assert user.password_hash != old_hash assert auth.is_valid_password(user, result['password']) is True + def test_trying_to_confirm_non_existing(context_factory): with pytest.raises(errors.NotFoundError): api.password_reset_api.finish_password_reset( context_factory(), {'user_name': 'u1'}) + def test_trying_to_confirm_without_token(context_factory, user_factory): db.session.add(user_factory( name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) @@ -62,6 +70,7 @@ def test_trying_to_confirm_without_token(context_factory, user_factory): api.password_reset_api.finish_password_reset( context_factory(params={}), {'user_name': 'u1'}) + def test_trying_to_confirm_with_bad_token(context_factory, user_factory): db.session.add(user_factory( name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) diff --git a/server/szurubooru/tests/api/test_post_creating.py b/server/szurubooru/tests/api/test_post_creating.py index bddd10ba..1a1f045f 100644 --- a/server/szurubooru/tests/api/test_post_creating.py +++ b/server/szurubooru/tests/api/test_post_creating.py @@ -1,8 +1,9 @@ +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import api, db, errors from szurubooru.func import posts, tags, snapshots, net + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ @@ -13,6 +14,7 @@ def inject_config(config_injector): }, }) + def test_creating_minimal_posts( context_factory, post_factory, user_factory): auth_user = user_factory(rank=db.User.RANK_REGULAR) @@ -20,16 +22,16 @@ def test_creating_minimal_posts( db.session.add(post) db.session.flush() - with unittest.mock.patch('szurubooru.func.posts.create_post'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_safety'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_source'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_relations'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_notes'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_thumbnail'), \ - unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ - unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'): + with patch('szurubooru.func.posts.create_post'), \ + patch('szurubooru.func.posts.update_post_safety'), \ + patch('szurubooru.func.posts.update_post_source'), \ + patch('szurubooru.func.posts.update_post_relations'), \ + patch('szurubooru.func.posts.update_post_notes'), \ + patch('szurubooru.func.posts.update_post_flags'), \ + patch('szurubooru.func.posts.update_post_thumbnail'), \ + patch('szurubooru.func.posts.serialize_post'), \ + patch('szurubooru.func.tags.export_to_json'), \ + patch('szurubooru.func.snapshots.save_entity_creation'): posts.create_post.return_value = (post, []) posts.serialize_post.return_value = 'serialized post' @@ -48,32 +50,36 @@ def test_creating_minimal_posts( assert result == 'serialized post' posts.create_post.assert_called_once_with( 'post-content', ['tag1', 'tag2'], auth_user) - posts.update_post_thumbnail.assert_called_once_with(post, 'post-thumbnail') + posts.update_post_thumbnail.assert_called_once_with( + post, 'post-thumbnail') posts.update_post_safety.assert_called_once_with(post, 'safe') posts.update_post_source.assert_called_once_with(post, None) posts.update_post_relations.assert_called_once_with(post, []) posts.update_post_notes.assert_called_once_with(post, []) posts.update_post_flags.assert_called_once_with(post, []) - posts.update_post_thumbnail.assert_called_once_with(post, 'post-thumbnail') - posts.serialize_post.assert_called_once_with(post, auth_user, options=None) + posts.update_post_thumbnail.assert_called_once_with( + post, 'post-thumbnail') + posts.serialize_post.assert_called_once_with( + post, auth_user, options=None) tags.export_to_json.assert_called_once_with() snapshots.save_entity_creation.assert_called_once_with(post, auth_user) + def test_creating_full_posts(context_factory, post_factory, user_factory): auth_user = user_factory(rank=db.User.RANK_REGULAR) post = post_factory() db.session.add(post) db.session.flush() - with unittest.mock.patch('szurubooru.func.posts.create_post'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_safety'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_source'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_relations'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_notes'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \ - unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ - unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'): + with patch('szurubooru.func.posts.create_post'), \ + patch('szurubooru.func.posts.update_post_safety'), \ + patch('szurubooru.func.posts.update_post_source'), \ + patch('szurubooru.func.posts.update_post_relations'), \ + patch('szurubooru.func.posts.update_post_notes'), \ + patch('szurubooru.func.posts.update_post_flags'), \ + patch('szurubooru.func.posts.serialize_post'), \ + patch('szurubooru.func.tags.export_to_json'), \ + patch('szurubooru.func.snapshots.save_entity_creation'): posts.create_post.return_value = (post, []) posts.serialize_post.return_value = 'serialized post' @@ -98,12 +104,16 @@ def test_creating_full_posts(context_factory, post_factory, user_factory): posts.update_post_safety.assert_called_once_with(post, 'safe') posts.update_post_source.assert_called_once_with(post, 'source') posts.update_post_relations.assert_called_once_with(post, [1, 2]) - posts.update_post_notes.assert_called_once_with(post, ['note1', 'note2']) - posts.update_post_flags.assert_called_once_with(post, ['flag1', 'flag2']) - posts.serialize_post.assert_called_once_with(post, auth_user, options=None) + posts.update_post_notes.assert_called_once_with( + post, ['note1', 'note2']) + posts.update_post_flags.assert_called_once_with( + post, ['flag1', 'flag2']) + posts.serialize_post.assert_called_once_with( + post, auth_user, options=None) tags.export_to_json.assert_called_once_with() snapshots.save_entity_creation.assert_called_once_with(post, auth_user) + def test_anonymous_uploads( config_injector, context_factory, post_factory, user_factory): auth_user = user_factory(rank=db.User.RANK_REGULAR) @@ -111,11 +121,11 @@ def test_anonymous_uploads( db.session.add(post) db.session.flush() - with unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ - unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'), \ - unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ - unittest.mock.patch('szurubooru.func.posts.create_post'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_source'): + with patch('szurubooru.func.tags.export_to_json'), \ + patch('szurubooru.func.snapshots.save_entity_creation'), \ + patch('szurubooru.func.posts.serialize_post'), \ + patch('szurubooru.func.posts.create_post'), \ + patch('szurubooru.func.posts.update_post_source'): config_injector({ 'privileges': {'posts:create:anonymous': db.User.RANK_REGULAR}, }) @@ -134,6 +144,7 @@ def test_anonymous_uploads( posts.create_post.assert_called_once_with( 'post-content', ['tag1', 'tag2'], None) + def test_creating_from_url_saves_source( config_injector, context_factory, post_factory, user_factory): auth_user = user_factory(rank=db.User.RANK_REGULAR) @@ -141,12 +152,12 @@ def test_creating_from_url_saves_source( db.session.add(post) db.session.flush() - with unittest.mock.patch('szurubooru.func.net.download'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ - unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'), \ - unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ - unittest.mock.patch('szurubooru.func.posts.create_post'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_source'): + with patch('szurubooru.func.net.download'), \ + patch('szurubooru.func.tags.export_to_json'), \ + patch('szurubooru.func.snapshots.save_entity_creation'), \ + patch('szurubooru.func.posts.serialize_post'), \ + patch('szurubooru.func.posts.create_post'), \ + patch('szurubooru.func.posts.update_post_source'): config_injector({ 'privileges': {'posts:create:identified': db.User.RANK_REGULAR}, }) @@ -165,6 +176,7 @@ def test_creating_from_url_saves_source( b'content', ['tag1', 'tag2'], auth_user) posts.update_post_source.assert_called_once_with(post, 'example.com') + def test_creating_from_url_with_source_specified( config_injector, context_factory, post_factory, user_factory): auth_user = user_factory(rank=db.User.RANK_REGULAR) @@ -172,12 +184,12 @@ def test_creating_from_url_with_source_specified( db.session.add(post) db.session.flush() - with unittest.mock.patch('szurubooru.func.net.download'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ - unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'), \ - unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ - unittest.mock.patch('szurubooru.func.posts.create_post'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_source'): + with patch('szurubooru.func.net.download'), \ + patch('szurubooru.func.tags.export_to_json'), \ + patch('szurubooru.func.snapshots.save_entity_creation'), \ + patch('szurubooru.func.posts.serialize_post'), \ + patch('szurubooru.func.posts.create_post'), \ + patch('szurubooru.func.posts.update_post_source'): config_injector({ 'privileges': {'posts:create:identified': db.User.RANK_REGULAR}, }) @@ -197,6 +209,7 @@ def test_creating_from_url_with_source_specified( b'content', ['tag1', 'tag2'], auth_user) posts.update_post_source.assert_called_once_with(post, 'example2.com') + @pytest.mark.parametrize('field', ['tags', 'safety']) def test_trying_to_omit_mandatory_field(context_factory, user_factory, field): params = { @@ -211,6 +224,7 @@ def test_trying_to_omit_mandatory_field(context_factory, user_factory, field): files={'content': '...'}, user=user_factory(rank=db.User.RANK_REGULAR))) + def test_trying_to_omit_content(context_factory, user_factory): with pytest.raises(errors.MissingRequiredFileError): api.post_api.create_post( @@ -221,12 +235,15 @@ def test_trying_to_omit_content(context_factory, user_factory): }, user=user_factory(rank=db.User.RANK_REGULAR))) -def test_trying_to_create_post_without_privileges(context_factory, user_factory): + +def test_trying_to_create_post_without_privileges( + context_factory, user_factory): with pytest.raises(errors.AuthError): api.post_api.create_post(context_factory( params='whatever', user=user_factory(rank=db.User.RANK_ANONYMOUS))) + def test_trying_to_create_tags_without_privileges( config_injector, context_factory, user_factory): config_injector({ @@ -237,8 +254,8 @@ def test_trying_to_create_tags_without_privileges( }, }) with pytest.raises(errors.AuthError), \ - unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_tags'): + patch('szurubooru.func.posts.update_post_content'), \ + patch('szurubooru.func.posts.update_post_tags'): posts.update_post_tags.return_value = ['new-tag'] api.post_api.create_post( context_factory( diff --git a/server/szurubooru/tests/api/test_post_deleting.py b/server/szurubooru/tests/api/test_post_deleting.py index f128d527..28175802 100644 --- a/server/szurubooru/tests/api/test_post_deleting.py +++ b/server/szurubooru/tests/api/test_post_deleting.py @@ -1,16 +1,18 @@ +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import api, db, errors from szurubooru.func import posts, tags + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({'privileges': {'posts:delete': db.User.RANK_REGULAR}}) + def test_deleting(user_factory, post_factory, context_factory): db.session.add(post_factory(id=1)) db.session.commit() - with unittest.mock.patch('szurubooru.func.tags.export_to_json'): + with patch('szurubooru.func.tags.export_to_json'): result = api.post_api.delete_post( context_factory( params={'version': 1}, @@ -20,12 +22,14 @@ def test_deleting(user_factory, post_factory, context_factory): assert db.session.query(db.Post).count() == 0 tags.export_to_json.assert_called_once_with() + def test_trying_to_delete_non_existing(user_factory, context_factory): with pytest.raises(posts.PostNotFoundError): api.post_api.delete_post( context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), {'post_id': 999}) + def test_trying_to_delete_without_privileges( user_factory, post_factory, context_factory): db.session.add(post_factory(id=1)) diff --git a/server/szurubooru/tests/api/test_post_favoriting.py b/server/szurubooru/tests/api/test_post_favoriting.py index 0788ec2f..d78d199e 100644 --- a/server/szurubooru/tests/api/test_post_favoriting.py +++ b/server/szurubooru/tests/api/test_post_favoriting.py @@ -1,20 +1,22 @@ -import pytest -import unittest.mock from datetime import datetime +from unittest.mock import patch +import pytest from szurubooru import api, db, errors from szurubooru.func import posts + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({'privileges': {'posts:favorite': db.User.RANK_REGULAR}}) + def test_adding_to_favorites( user_factory, post_factory, context_factory, fake_datetime): post = post_factory() db.session.add(post) db.session.commit() assert post.score == 0 - with unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ + with patch('szurubooru.func.posts.serialize_post'), \ fake_datetime('1997-12-01'): posts.serialize_post.return_value = 'serialized post' result = api.post_api.add_post_to_favorites( @@ -27,6 +29,7 @@ def test_adding_to_favorites( assert post.favorite_count == 1 assert post.score == 1 + def test_removing_from_favorites( user_factory, post_factory, context_factory, fake_datetime): user = user_factory() @@ -34,7 +37,7 @@ def test_removing_from_favorites( db.session.add(post) db.session.commit() assert post.score == 0 - with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with patch('szurubooru.func.posts.serialize_post'): with fake_datetime('1997-12-01'): api.post_api.add_post_to_favorites( context_factory(user=user), @@ -49,13 +52,14 @@ def test_removing_from_favorites( assert db.session.query(db.PostFavorite).count() == 0 assert post.favorite_count == 0 + def test_favoriting_twice( user_factory, post_factory, context_factory, fake_datetime): user = user_factory() post = post_factory() db.session.add(post) db.session.commit() - with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with patch('szurubooru.func.posts.serialize_post'): with fake_datetime('1997-12-01'): api.post_api.add_post_to_favorites( context_factory(user=user), @@ -68,13 +72,14 @@ def test_favoriting_twice( assert db.session.query(db.PostFavorite).count() == 1 assert post.favorite_count == 1 + def test_removing_twice( user_factory, post_factory, context_factory, fake_datetime): user = user_factory() post = post_factory() db.session.add(post) db.session.commit() - with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with patch('szurubooru.func.posts.serialize_post'): with fake_datetime('1997-12-01'): api.post_api.add_post_to_favorites( context_factory(user=user), @@ -91,6 +96,7 @@ def test_removing_twice( assert db.session.query(db.PostFavorite).count() == 0 assert post.favorite_count == 0 + def test_favorites_from_multiple_users( user_factory, post_factory, context_factory, fake_datetime): user1 = user_factory() @@ -98,7 +104,7 @@ def test_favorites_from_multiple_users( post = post_factory() db.session.add_all([user1, user2, post]) db.session.commit() - with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with patch('szurubooru.func.posts.serialize_post'): with fake_datetime('1997-12-01'): api.post_api.add_post_to_favorites( context_factory(user=user1), @@ -112,12 +118,14 @@ def test_favorites_from_multiple_users( assert post.favorite_count == 2 assert post.last_favorite_time == datetime(1997, 12, 2) + def test_trying_to_update_non_existing(user_factory, context_factory): with pytest.raises(posts.PostNotFoundError): api.post_api.add_post_to_favorites( context_factory(user=user_factory()), {'post_id': 5}) + def test_trying_to_rate_without_privileges( user_factory, post_factory, context_factory): post = post_factory() diff --git a/server/szurubooru/tests/api/test_post_featuring.py b/server/szurubooru/tests/api/test_post_featuring.py index 45cae474..9f3bd0d3 100644 --- a/server/szurubooru/tests/api/test_post_featuring.py +++ b/server/szurubooru/tests/api/test_post_featuring.py @@ -1,8 +1,9 @@ +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import api, db, errors from szurubooru.func import posts + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ @@ -12,14 +13,12 @@ def inject_config(config_injector): }, }) -def test_no_featured_post(user_factory, post_factory, context_factory): - assert posts.try_get_featured_post() is None def test_featuring(user_factory, post_factory, context_factory): db.session.add(post_factory(id=1)) db.session.commit() assert not posts.get_post_by_id(1).is_featured - with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with patch('szurubooru.func.posts.serialize_post'): posts.serialize_post.return_value = 'serialized post' result = api.post_api.set_featured_post( context_factory( @@ -34,18 +33,19 @@ def test_featuring(user_factory, post_factory, context_factory): user=user_factory(rank=db.User.RANK_REGULAR))) assert result == 'serialized post' -def test_trying_to_omit_required_parameter( - user_factory, post_factory, context_factory): + +def test_trying_to_omit_required_parameter(user_factory, context_factory): with pytest.raises(errors.MissingRequiredParameterError): api.post_api.set_featured_post( context_factory( user=user_factory(rank=db.User.RANK_REGULAR))) + def test_trying_to_feature_the_same_post_twice( user_factory, post_factory, context_factory): db.session.add(post_factory(id=1)) db.session.commit() - with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with patch('szurubooru.func.posts.serialize_post'): api.post_api.set_featured_post( context_factory( params={'id': 1}, @@ -56,6 +56,7 @@ def test_trying_to_feature_the_same_post_twice( params={'id': 1}, user=user_factory(rank=db.User.RANK_REGULAR))) + def test_featuring_one_post_after_another( user_factory, post_factory, context_factory, fake_datetime): db.session.add(post_factory(id=1)) @@ -64,14 +65,14 @@ def test_featuring_one_post_after_another( assert posts.try_get_featured_post() is None assert not posts.get_post_by_id(1).is_featured assert not posts.get_post_by_id(2).is_featured - with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with patch('szurubooru.func.posts.serialize_post'): with fake_datetime('1997'): - result = api.post_api.set_featured_post( + api.post_api.set_featured_post( context_factory( params={'id': 1}, user=user_factory(rank=db.User.RANK_REGULAR))) with fake_datetime('1998'): - result = api.post_api.set_featured_post( + api.post_api.set_featured_post( context_factory( params={'id': 2}, user=user_factory(rank=db.User.RANK_REGULAR))) @@ -80,6 +81,7 @@ def test_featuring_one_post_after_another( assert not posts.get_post_by_id(1).is_featured assert posts.get_post_by_id(2).is_featured + def test_trying_to_feature_non_existing(user_factory, context_factory): with pytest.raises(posts.PostNotFoundError): api.post_api.set_featured_post( @@ -87,6 +89,7 @@ def test_trying_to_feature_non_existing(user_factory, context_factory): params={'id': 1}, user=user_factory(rank=db.User.RANK_REGULAR))) + def test_trying_to_feature_without_privileges(user_factory, context_factory): with pytest.raises(errors.AuthError): api.post_api.set_featured_post( @@ -94,6 +97,7 @@ def test_trying_to_feature_without_privileges(user_factory, context_factory): params={'id': 1}, user=user_factory(rank=db.User.RANK_ANONYMOUS))) + def test_getting_featured_post_without_privileges_to_view( user_factory, context_factory): api.post_api.get_featured_post( diff --git a/server/szurubooru/tests/api/test_post_rating.py b/server/szurubooru/tests/api/test_post_rating.py index ed646b3b..18e823e7 100644 --- a/server/szurubooru/tests/api/test_post_rating.py +++ b/server/szurubooru/tests/api/test_post_rating.py @@ -1,87 +1,93 @@ +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import api, db, errors -from szurubooru.func import posts, scores +from szurubooru.func import posts + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({'privileges': {'posts:score': db.User.RANK_REGULAR}}) + def test_simple_rating( user_factory, post_factory, context_factory, fake_datetime): post = post_factory() db.session.add(post) db.session.commit() - with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with patch('szurubooru.func.posts.serialize_post'), \ + fake_datetime('1997-12-01'): posts.serialize_post.return_value = 'serialized post' - with fake_datetime('1997-12-01'): - result = api.post_api.set_post_score( - context_factory( - params={'score': 1}, user=user_factory()), - {'post_id': post.post_id}) + result = api.post_api.set_post_score( + context_factory( + params={'score': 1}, user=user_factory()), + {'post_id': post.post_id}) assert result == 'serialized post' post = db.session.query(db.Post).one() assert db.session.query(db.PostScore).count() == 1 assert post is not None assert post.score == 1 + def test_updating_rating( user_factory, post_factory, context_factory, fake_datetime): user = user_factory() post = post_factory() db.session.add(post) db.session.commit() - with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with patch('szurubooru.func.posts.serialize_post'): with fake_datetime('1997-12-01'): - result = api.post_api.set_post_score( + api.post_api.set_post_score( context_factory(params={'score': 1}, user=user), {'post_id': post.post_id}) with fake_datetime('1997-12-02'): - result = api.post_api.set_post_score( + api.post_api.set_post_score( context_factory(params={'score': -1}, user=user), {'post_id': post.post_id}) post = db.session.query(db.Post).one() assert db.session.query(db.PostScore).count() == 1 assert post.score == -1 + def test_updating_rating_to_zero( user_factory, post_factory, context_factory, fake_datetime): user = user_factory() post = post_factory() db.session.add(post) db.session.commit() - with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with patch('szurubooru.func.posts.serialize_post'): with fake_datetime('1997-12-01'): - result = api.post_api.set_post_score( + api.post_api.set_post_score( context_factory(params={'score': 1}, user=user), {'post_id': post.post_id}) with fake_datetime('1997-12-02'): - result = api.post_api.set_post_score( + api.post_api.set_post_score( context_factory(params={'score': 0}, user=user), {'post_id': post.post_id}) post = db.session.query(db.Post).one() assert db.session.query(db.PostScore).count() == 0 assert post.score == 0 + def test_deleting_rating( user_factory, post_factory, context_factory, fake_datetime): user = user_factory() post = post_factory() db.session.add(post) db.session.commit() - with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with patch('szurubooru.func.posts.serialize_post'): with fake_datetime('1997-12-01'): - result = api.post_api.set_post_score( + api.post_api.set_post_score( context_factory(params={'score': 1}, user=user), {'post_id': post.post_id}) with fake_datetime('1997-12-02'): - result = api.post_api.delete_post_score( + api.post_api.delete_post_score( context_factory(user=user), {'post_id': post.post_id}) post = db.session.query(db.Post).one() assert db.session.query(db.PostScore).count() == 0 assert post.score == 0 + def test_ratings_from_multiple_users( user_factory, post_factory, context_factory, fake_datetime): user1 = user_factory() @@ -89,19 +95,20 @@ def test_ratings_from_multiple_users( post = post_factory() db.session.add_all([user1, user2, post]) db.session.commit() - with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with patch('szurubooru.func.posts.serialize_post'): with fake_datetime('1997-12-01'): - result = api.post_api.set_post_score( + api.post_api.set_post_score( context_factory(params={'score': 1}, user=user1), {'post_id': post.post_id}) with fake_datetime('1997-12-02'): - result = api.post_api.set_post_score( + api.post_api.set_post_score( context_factory(params={'score': -1}, user=user2), {'post_id': post.post_id}) post = db.session.query(db.Post).one() assert db.session.query(db.PostScore).count() == 2 assert post.score == 0 + def test_trying_to_omit_mandatory_field( user_factory, post_factory, context_factory): post = post_factory() @@ -112,13 +119,14 @@ def test_trying_to_omit_mandatory_field( context_factory(params={}, user=user_factory()), {'post_id': post.post_id}) -def test_trying_to_update_non_existing( - user_factory, post_factory, context_factory): + +def test_trying_to_update_non_existing(user_factory, context_factory): with pytest.raises(posts.PostNotFoundError): api.post_api.set_post_score( context_factory(params={'score': 1}, user=user_factory()), {'post_id': 5}) + def test_trying_to_rate_without_privileges( user_factory, post_factory, context_factory): post = post_factory() diff --git a/server/szurubooru/tests/api/test_post_retrieving.py b/server/szurubooru/tests/api/test_post_retrieving.py index 34583460..e37e1ecf 100644 --- a/server/szurubooru/tests/api/test_post_retrieving.py +++ b/server/szurubooru/tests/api/test_post_retrieving.py @@ -1,11 +1,12 @@ -import pytest -import unittest.mock from datetime import datetime +from unittest.mock import patch +import pytest from szurubooru import api, db, errors from szurubooru.func import posts + @pytest.fixture(autouse=True) -def inject_config(tmpdir, config_injector): +def inject_config(config_injector): config_injector({ 'privileges': { 'posts:list': db.User.RANK_REGULAR, @@ -13,11 +14,12 @@ def inject_config(tmpdir, config_injector): }, }) + def test_retrieving_multiple(user_factory, post_factory, context_factory): post1 = post_factory(id=1) post2 = post_factory(id=2) db.session.add_all([post1, post2]) - with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with patch('szurubooru.func.posts.serialize_post'): posts.serialize_post.return_value = 'serialized post' result = api.post_api.get_posts( context_factory( @@ -31,6 +33,7 @@ def test_retrieving_multiple(user_factory, post_factory, context_factory): 'results': ['serialized post', 'serialized post'], } + def test_using_special_tokens(user_factory, post_factory, context_factory): auth_user = user_factory(rank=db.User.RANK_REGULAR) post1 = post_factory(id=1) @@ -39,7 +42,7 @@ def test_using_special_tokens(user_factory, post_factory, context_factory): user=auth_user, time=datetime.utcnow())] db.session.add_all([post1, post2, auth_user]) db.session.flush() - with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with patch('szurubooru.func.posts.serialize_post'): posts.serialize_post.side_effect = \ lambda post, *_args, **_kwargs: \ 'serialized post %d' % post.post_id @@ -55,8 +58,9 @@ def test_using_special_tokens(user_factory, post_factory, context_factory): 'results': ['serialized post 1'], } + def test_trying_to_use_special_tokens_without_logging_in( - user_factory, post_factory, context_factory, config_injector): + user_factory, context_factory, config_injector): config_injector({ 'privileges': {'posts:list': 'anonymous'}, }) @@ -66,6 +70,7 @@ def test_trying_to_use_special_tokens_without_logging_in( params={'query': 'special:fav', 'page': 1}, user=user_factory(rank=db.User.RANK_ANONYMOUS))) + def test_trying_to_retrieve_multiple_without_privileges( user_factory, context_factory): with pytest.raises(errors.AuthError): @@ -74,21 +79,24 @@ def test_trying_to_retrieve_multiple_without_privileges( params={'query': '', 'page': 1}, user=user_factory(rank=db.User.RANK_ANONYMOUS))) + def test_retrieving_single(user_factory, post_factory, context_factory): db.session.add(post_factory(id=1)) - with unittest.mock.patch('szurubooru.func.posts.serialize_post'): + with patch('szurubooru.func.posts.serialize_post'): posts.serialize_post.return_value = 'serialized post' result = api.post_api.get_post( context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), {'post_id': 1}) assert result == 'serialized post' + def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): with pytest.raises(posts.PostNotFoundError): api.post_api.get_post( context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), {'post_id': 999}) + def test_trying_to_retrieve_single_without_privileges( user_factory, context_factory): with pytest.raises(errors.AuthError): diff --git a/server/szurubooru/tests/api/test_post_updating.py b/server/szurubooru/tests/api/test_post_updating.py index ce659f25..b577c7ee 100644 --- a/server/szurubooru/tests/api/test_post_updating.py +++ b/server/szurubooru/tests/api/test_post_updating.py @@ -1,11 +1,12 @@ -import pytest -import unittest.mock from datetime import datetime +from unittest.mock import patch +import pytest from szurubooru import api, db, errors from szurubooru.func import posts, tags, snapshots, net + @pytest.fixture(autouse=True) -def inject_config(tmpdir, config_injector): +def inject_config(config_injector): config_injector({ 'privileges': { 'posts:edit:tags': db.User.RANK_REGULAR, @@ -20,6 +21,7 @@ def inject_config(tmpdir, config_injector): }, }) + def test_post_updating( context_factory, post_factory, user_factory, fake_datetime): auth_user = user_factory(rank=db.User.RANK_REGULAR) @@ -27,18 +29,18 @@ def test_post_updating( db.session.add(post) db.session.flush() - with unittest.mock.patch('szurubooru.func.posts.create_post'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_tags'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_thumbnail'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_safety'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_source'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_relations'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_notes'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \ - unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ - unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \ + with patch('szurubooru.func.posts.create_post'), \ + patch('szurubooru.func.posts.update_post_tags'), \ + patch('szurubooru.func.posts.update_post_content'), \ + patch('szurubooru.func.posts.update_post_thumbnail'), \ + patch('szurubooru.func.posts.update_post_safety'), \ + patch('szurubooru.func.posts.update_post_source'), \ + patch('szurubooru.func.posts.update_post_relations'), \ + patch('szurubooru.func.posts.update_post_notes'), \ + patch('szurubooru.func.posts.update_post_flags'), \ + patch('szurubooru.func.posts.serialize_post'), \ + patch('szurubooru.func.tags.export_to_json'), \ + patch('szurubooru.func.snapshots.save_entity_modification'), \ fake_datetime('1997-01-01'): posts.serialize_post.return_value = 'serialized post' @@ -64,28 +66,34 @@ def test_post_updating( posts.create_post.assert_not_called() posts.update_post_tags.assert_called_once_with(post, ['tag1', 'tag2']) posts.update_post_content.assert_called_once_with(post, 'post-content') - posts.update_post_thumbnail.assert_called_once_with(post, 'post-thumbnail') + posts.update_post_thumbnail.assert_called_once_with( + post, 'post-thumbnail') posts.update_post_safety.assert_called_once_with(post, 'safe') posts.update_post_source.assert_called_once_with(post, 'source') posts.update_post_relations.assert_called_once_with(post, [1, 2]) - posts.update_post_notes.assert_called_once_with(post, ['note1', 'note2']) - posts.update_post_flags.assert_called_once_with(post, ['flag1', 'flag2']) - posts.serialize_post.assert_called_once_with(post, auth_user, options=None) + posts.update_post_notes.assert_called_once_with( + post, ['note1', 'note2']) + posts.update_post_flags.assert_called_once_with( + post, ['flag1', 'flag2']) + posts.serialize_post.assert_called_once_with( + post, auth_user, options=None) tags.export_to_json.assert_called_once_with() - snapshots.save_entity_modification.assert_called_once_with(post, auth_user) + snapshots.save_entity_modification.assert_called_once_with( + post, auth_user) assert post.last_edit_time == datetime(1997, 1, 1) + def test_uploading_from_url_saves_source( context_factory, post_factory, user_factory): post = post_factory() db.session.add(post) db.session.flush() - with unittest.mock.patch('szurubooru.func.net.download'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ - unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \ - unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_source'): + with patch('szurubooru.func.net.download'), \ + patch('szurubooru.func.tags.export_to_json'), \ + patch('szurubooru.func.snapshots.save_entity_modification'), \ + patch('szurubooru.func.posts.serialize_post'), \ + patch('szurubooru.func.posts.update_post_content'), \ + patch('szurubooru.func.posts.update_post_source'): net.download.return_value = b'content' api.post_api.update_post( context_factory( @@ -96,17 +104,18 @@ def test_uploading_from_url_saves_source( posts.update_post_content.assert_called_once_with(post, b'content') posts.update_post_source.assert_called_once_with(post, 'example.com') + def test_uploading_from_url_with_source_specified( context_factory, post_factory, user_factory): post = post_factory() db.session.add(post) db.session.flush() - with unittest.mock.patch('szurubooru.func.net.download'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ - unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \ - unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_source'): + with patch('szurubooru.func.net.download'), \ + patch('szurubooru.func.tags.export_to_json'), \ + patch('szurubooru.func.snapshots.save_entity_modification'), \ + patch('szurubooru.func.posts.serialize_post'), \ + patch('szurubooru.func.posts.update_post_content'), \ + patch('szurubooru.func.posts.update_post_source'): net.download.return_value = b'content' api.post_api.update_post( context_factory( @@ -120,6 +129,7 @@ def test_uploading_from_url_with_source_specified( posts.update_post_content.assert_called_once_with(post, b'content') posts.update_post_source.assert_called_once_with(post, 'example2.com') + def test_trying_to_update_non_existing(context_factory, user_factory): with pytest.raises(posts.PostNotFoundError): api.post_api.update_post( @@ -128,18 +138,19 @@ def test_trying_to_update_non_existing(context_factory, user_factory): user=user_factory(rank=db.User.RANK_REGULAR)), {'post_id': 1}) -@pytest.mark.parametrize('privilege,files,params', [ - ('posts:edit:tags', {}, {'tags': '...'}), - ('posts:edit:safety', {}, {'safety': '...'}), - ('posts:edit:source', {}, {'source': '...'}), - ('posts:edit:relations', {}, {'relations': '...'}), - ('posts:edit:notes', {}, {'notes': '...'}), - ('posts:edit:flags', {}, {'flags': '...'}), - ('posts:edit:content', {'content': '...'}, {}), - ('posts:edit:thumbnail', {'thumbnail': '...'}, {}), + +@pytest.mark.parametrize('files,params', [ + ({}, {'tags': '...'}), + ({}, {'safety': '...'}), + ({}, {'source': '...'}), + ({}, {'relations': '...'}), + ({}, {'notes': '...'}), + ({}, {'flags': '...'}), + ({'content': '...'}, {}), + ({'thumbnail': '...'}, {}), ]) def test_trying_to_update_field_without_privileges( - context_factory, post_factory, user_factory, files, params, privilege): + context_factory, post_factory, user_factory, files, params): post = post_factory() db.session.add(post) db.session.flush() @@ -151,13 +162,14 @@ def test_trying_to_update_field_without_privileges( user=user_factory(rank=db.User.RANK_ANONYMOUS)), {'post_id': post.post_id}) + def test_trying_to_create_tags_without_privileges( context_factory, post_factory, user_factory): post = post_factory() db.session.add(post) db.session.flush() with pytest.raises(errors.AuthError), \ - unittest.mock.patch('szurubooru.func.posts.update_post_tags'): + patch('szurubooru.func.posts.update_post_tags'): posts.update_post_tags.return_value = ['new-tag'] api.post_api.update_post( context_factory( diff --git a/server/szurubooru/tests/api/test_snapshot_retrieving.py b/server/szurubooru/tests/api/test_snapshot_retrieving.py index c7d917a9..a04326b8 100644 --- a/server/szurubooru/tests/api/test_snapshot_retrieving.py +++ b/server/szurubooru/tests/api/test_snapshot_retrieving.py @@ -1,7 +1,8 @@ -import pytest from datetime import datetime +import pytest from szurubooru import api, db, errors + def snapshot_factory(): snapshot = db.Snapshot() snapshot.creation_time = datetime(1999, 1, 1) @@ -12,12 +13,14 @@ def snapshot_factory(): snapshot.data = '{}' return snapshot + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ 'privileges': {'snapshots:list': db.User.RANK_REGULAR}, }) + def test_retrieving_multiple(user_factory, context_factory): snapshot1 = snapshot_factory() snapshot2 = snapshot_factory() @@ -32,6 +35,7 @@ def test_retrieving_multiple(user_factory, context_factory): assert result['total'] == 2 assert len(result['results']) == 2 + def test_trying_to_retrieve_multiple_without_privileges( user_factory, context_factory): with pytest.raises(errors.AuthError): diff --git a/server/szurubooru/tests/api/test_tag_category_creating.py b/server/szurubooru/tests/api/test_tag_category_creating.py index 576fe4f5..8bf0792b 100644 --- a/server/szurubooru/tests/api/test_tag_category_creating.py +++ b/server/szurubooru/tests/api/test_tag_category_creating.py @@ -1,21 +1,24 @@ +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import api, db, errors from szurubooru.func import tag_categories, tags + def _update_category_name(category, name): category.name = name + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ 'privileges': {'tag_categories:create': db.User.RANK_REGULAR}, }) + def test_creating_category(user_factory, context_factory): - with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \ - unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'): + with patch('szurubooru.func.tag_categories.serialize_category'), \ + patch('szurubooru.func.tag_categories.update_category_name'), \ + patch('szurubooru.func.tags.export_to_json'): tag_categories.update_category_name.side_effect = _update_category_name tag_categories.serialize_category.return_value = 'serialized category' result = api.tag_category_api.create_tag_category( @@ -29,6 +32,7 @@ def test_creating_category(user_factory, context_factory): assert category.tag_count == 0 tags.export_to_json.assert_called_once_with() + @pytest.mark.parametrize('field', ['name', 'color']) def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): params = { @@ -42,6 +46,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): params=params, user=user_factory(rank=db.User.RANK_REGULAR))) + def test_trying_to_create_without_privileges(user_factory, context_factory): with pytest.raises(errors.AuthError): api.tag_category_api.create_tag_category( diff --git a/server/szurubooru/tests/api/test_tag_category_deleting.py b/server/szurubooru/tests/api/test_tag_category_deleting.py index 4cbd6437..bb590efb 100644 --- a/server/szurubooru/tests/api/test_tag_category_deleting.py +++ b/server/szurubooru/tests/api/test_tag_category_deleting.py @@ -1,19 +1,21 @@ +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import api, db, errors from szurubooru.func import tag_categories, tags + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ 'privileges': {'tag_categories:delete': db.User.RANK_REGULAR}, }) + def test_deleting(user_factory, tag_category_factory, context_factory): db.session.add(tag_category_factory(name='root')) db.session.add(tag_category_factory(name='category')) db.session.commit() - with unittest.mock.patch('szurubooru.func.tags.export_to_json'): + with patch('szurubooru.func.tags.export_to_json'): result = api.tag_category_api.delete_tag_category( context_factory( params={'version': 1}, @@ -24,6 +26,7 @@ def test_deleting(user_factory, tag_category_factory, context_factory): assert db.session.query(db.TagCategory).one().name == 'root' tags.export_to_json.assert_called_once_with() + def test_trying_to_delete_used( user_factory, tag_category_factory, tag_factory, context_factory): category = tag_category_factory(name='category') @@ -40,6 +43,7 @@ def test_trying_to_delete_used( {'category_name': 'category'}) assert db.session.query(db.TagCategory).count() == 1 + def test_trying_to_delete_last( user_factory, tag_category_factory, context_factory): db.session.add(tag_category_factory(name='root')) @@ -51,12 +55,14 @@ def test_trying_to_delete_last( user=user_factory(rank=db.User.RANK_REGULAR)), {'category_name': 'root'}) + def test_trying_to_delete_non_existing(user_factory, context_factory): with pytest.raises(tag_categories.TagCategoryNotFoundError): api.tag_category_api.delete_tag_category( context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), {'category_name': 'bad'}) + def test_trying_to_delete_without_privileges( user_factory, tag_category_factory, context_factory): db.session.add(tag_category_factory(name='category')) diff --git a/server/szurubooru/tests/api/test_tag_category_retrieving.py b/server/szurubooru/tests/api/test_tag_category_retrieving.py index 3bfc115a..e3715308 100644 --- a/server/szurubooru/tests/api/test_tag_category_retrieving.py +++ b/server/szurubooru/tests/api/test_tag_category_retrieving.py @@ -2,6 +2,7 @@ import pytest from szurubooru import api, db, errors from szurubooru.func import tag_categories + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ @@ -11,6 +12,7 @@ def inject_config(config_injector): }, }) + def test_retrieving_multiple( user_factory, tag_category_factory, context_factory): db.session.add_all([ @@ -21,7 +23,9 @@ def test_retrieving_multiple( context_factory(user=user_factory(rank=db.User.RANK_REGULAR))) assert [cat['name'] for cat in result['results']] == ['c1', 'c2'] -def test_retrieving_single(user_factory, tag_category_factory, context_factory): + +def test_retrieving_single( + user_factory, tag_category_factory, context_factory): db.session.add(tag_category_factory(name='cat')) result = api.tag_category_api.get_tag_category( context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), @@ -35,12 +39,14 @@ def test_retrieving_single(user_factory, tag_category_factory, context_factory): 'version': 1, } + def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): with pytest.raises(tag_categories.TagCategoryNotFoundError): api.tag_category_api.get_tag_category( context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), {'category_name': '-'}) + def test_trying_to_retrieve_single_without_privileges( user_factory, context_factory): with pytest.raises(errors.AuthError): diff --git a/server/szurubooru/tests/api/test_tag_category_updating.py b/server/szurubooru/tests/api/test_tag_category_updating.py index af10fbde..d5a863a4 100644 --- a/server/szurubooru/tests/api/test_tag_category_updating.py +++ b/server/szurubooru/tests/api/test_tag_category_updating.py @@ -1,11 +1,13 @@ +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import api, db, errors from szurubooru.func import tag_categories, tags + def _update_category_name(category, name): category.name = name + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ @@ -16,14 +18,15 @@ def inject_config(config_injector): }, }) + def test_simple_updating(user_factory, tag_category_factory, context_factory): category = tag_category_factory(name='name', color='black') db.session.add(category) db.session.commit() - with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \ - unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \ - unittest.mock.patch('szurubooru.func.tag_categories.update_category_color'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'): + with patch('szurubooru.func.tag_categories.serialize_category'), \ + patch('szurubooru.func.tag_categories.update_category_name'), \ + patch('szurubooru.func.tag_categories.update_category_color'), \ + patch('szurubooru.func.tags.export_to_json'): tag_categories.update_category_name.side_effect = _update_category_name tag_categories.serialize_category.return_value = 'serialized category' result = api.tag_category_api.update_tag_category( @@ -36,10 +39,13 @@ def test_simple_updating(user_factory, tag_category_factory, context_factory): user=user_factory(rank=db.User.RANK_REGULAR)), {'category_name': 'name'}) assert result == 'serialized category' - tag_categories.update_category_name.assert_called_once_with(category, 'changed') - tag_categories.update_category_color.assert_called_once_with(category, 'white') + tag_categories.update_category_name.assert_called_once_with( + category, 'changed') + tag_categories.update_category_color.assert_called_once_with( + category, 'white') tags.export_to_json.assert_called_once_with() + @pytest.mark.parametrize('field', ['name', 'color']) def test_omitting_optional_field( user_factory, tag_category_factory, context_factory, field): @@ -50,15 +56,16 @@ def test_omitting_optional_field( 'color': 'white', } del params[field] - with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \ - unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'): + with patch('szurubooru.func.tag_categories.serialize_category'), \ + patch('szurubooru.func.tag_categories.update_category_name'), \ + patch('szurubooru.func.tags.export_to_json'): api.tag_category_api.update_tag_category( context_factory( params={**params, **{'version': 1}}, user=user_factory(rank=db.User.RANK_REGULAR)), {'category_name': 'name'}) + def test_trying_to_update_non_existing(user_factory, context_factory): with pytest.raises(tag_categories.TagCategoryNotFoundError): api.tag_category_api.update_tag_category( @@ -67,6 +74,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory): user=user_factory(rank=db.User.RANK_REGULAR)), {'category_name': 'bad'}) + @pytest.mark.parametrize('params', [ {'name': 'whatever'}, {'color': 'whatever'}, @@ -82,13 +90,14 @@ def test_trying_to_update_without_privileges( user=user_factory(rank=db.User.RANK_ANONYMOUS)), {'category_name': 'dummy'}) + def test_set_as_default(user_factory, tag_category_factory, context_factory): category = tag_category_factory(name='name', color='black') db.session.add(category) db.session.commit() - with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \ - unittest.mock.patch('szurubooru.func.tag_categories.set_default_category'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'): + with patch('szurubooru.func.tag_categories.serialize_category'), \ + patch('szurubooru.func.tag_categories.set_default_category'), \ + patch('szurubooru.func.tags.export_to_json'): tag_categories.update_category_name.side_effect = _update_category_name tag_categories.serialize_category.return_value = 'serialized category' result = api.tag_category_api.set_tag_category_as_default( diff --git a/server/szurubooru/tests/api/test_tag_creating.py b/server/szurubooru/tests/api/test_tag_creating.py index 2076f5d4..2c551774 100644 --- a/server/szurubooru/tests/api/test_tag_creating.py +++ b/server/szurubooru/tests/api/test_tag_creating.py @@ -1,17 +1,19 @@ +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import api, db, errors -from szurubooru.func import tags, tag_categories +from szurubooru.func import tags + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({'privileges': {'tags:create': db.User.RANK_REGULAR}}) + def test_creating_simple_tags(tag_factory, user_factory, context_factory): - with unittest.mock.patch('szurubooru.func.tags.create_tag'), \ - unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'), \ - unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'): + with patch('szurubooru.func.tags.create_tag'), \ + patch('szurubooru.func.tags.get_or_create_tags_by_names'), \ + patch('szurubooru.func.tags.serialize_tag'), \ + patch('szurubooru.func.tags.export_to_json'): tags.get_or_create_tags_by_names.return_value = ([], []) tags.create_tag.return_value = tag_factory() tags.serialize_tag.return_value = 'serialized tag' @@ -30,6 +32,7 @@ def test_creating_simple_tags(tag_factory, user_factory, context_factory): ['tag1', 'tag2'], 'meta', ['sug1', 'sug2'], ['imp1', 'imp2']) tags.export_to_json.assert_called_once_with() + @pytest.mark.parametrize('field', ['names', 'category']) def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): params = { @@ -45,6 +48,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): params=params, user=user_factory(rank=db.User.RANK_REGULAR))) + @pytest.mark.parametrize('field', ['implications', 'suggestions']) def test_omitting_optional_field( tag_factory, user_factory, context_factory, field): @@ -55,16 +59,18 @@ def test_omitting_optional_field( 'implications': [], } del params[field] - with unittest.mock.patch('szurubooru.func.tags.create_tag'), \ - unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'): + with patch('szurubooru.func.tags.create_tag'), \ + patch('szurubooru.func.tags.serialize_tag'), \ + patch('szurubooru.func.tags.export_to_json'): tags.create_tag.return_value = tag_factory() api.tag_api.create_tag( context_factory( params=params, user=user_factory(rank=db.User.RANK_REGULAR))) -def test_trying_to_create_tag_without_privileges(user_factory, context_factory): + +def test_trying_to_create_tag_without_privileges( + user_factory, context_factory): with pytest.raises(errors.AuthError): api.tag_api.create_tag( context_factory( diff --git a/server/szurubooru/tests/api/test_tag_deleting.py b/server/szurubooru/tests/api/test_tag_deleting.py index 98ff9cf5..b5237d1c 100644 --- a/server/szurubooru/tests/api/test_tag_deleting.py +++ b/server/szurubooru/tests/api/test_tag_deleting.py @@ -1,16 +1,18 @@ +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import api, db, errors from szurubooru.func import tags + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({'privileges': {'tags:delete': db.User.RANK_REGULAR}}) + def test_deleting(user_factory, tag_factory, context_factory): db.session.add(tag_factory(names=['tag'])) db.session.commit() - with unittest.mock.patch('szurubooru.func.tags.export_to_json'): + with patch('szurubooru.func.tags.export_to_json'): result = api.tag_api.delete_tag( context_factory( params={'version': 1}, @@ -20,13 +22,15 @@ def test_deleting(user_factory, tag_factory, context_factory): assert db.session.query(db.Tag).count() == 0 tags.export_to_json.assert_called_once_with() -def test_deleting_used(user_factory, tag_factory, context_factory, post_factory): + +def test_deleting_used( + user_factory, tag_factory, context_factory, post_factory): tag = tag_factory(names=['tag']) post = post_factory() post.tags.append(tag) db.session.add_all([tag, post]) db.session.commit() - with unittest.mock.patch('szurubooru.func.tags.export_to_json'): + with patch('szurubooru.func.tags.export_to_json'): api.tag_api.delete_tag( context_factory( params={'version': 1}, @@ -36,12 +40,14 @@ def test_deleting_used(user_factory, tag_factory, context_factory, post_factory) assert db.session.query(db.Tag).count() == 0 assert post.tags == [] + def test_trying_to_delete_non_existing(user_factory, context_factory): with pytest.raises(tags.TagNotFoundError): api.tag_api.delete_tag( context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), {'tag_name': 'bad'}) + def test_trying_to_delete_without_privileges( user_factory, tag_factory, context_factory): db.session.add(tag_factory(names=['tag'])) diff --git a/server/szurubooru/tests/api/test_tag_merging.py b/server/szurubooru/tests/api/test_tag_merging.py index 90fa4d83..9786abcb 100644 --- a/server/szurubooru/tests/api/test_tag_merging.py +++ b/server/szurubooru/tests/api/test_tag_merging.py @@ -1,12 +1,14 @@ +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import api, db, errors from szurubooru.func import tags + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({'privileges': {'tags:merge': db.User.RANK_REGULAR}}) + def test_merging(user_factory, tag_factory, context_factory, post_factory): source_tag = tag_factory(names=['source']) target_tag = tag_factory(names=['target']) @@ -20,10 +22,10 @@ def test_merging(user_factory, tag_factory, context_factory, post_factory): db.session.commit() assert source_tag.post_count == 1 assert target_tag.post_count == 0 - with unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ - unittest.mock.patch('szurubooru.func.tags.merge_tags'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'): - result = api.tag_api.merge_tags( + with patch('szurubooru.func.tags.serialize_tag'), \ + patch('szurubooru.func.tags.merge_tags'), \ + patch('szurubooru.func.tags.export_to_json'): + api.tag_api.merge_tags( context_factory( params={ 'removeVersion': 1, @@ -35,6 +37,7 @@ def test_merging(user_factory, tag_factory, context_factory, post_factory): tags.merge_tags.called_once_with(source_tag, target_tag) tags.export_to_json.assert_called_once_with() + @pytest.mark.parametrize( 'field', ['remove', 'mergeTo', 'removeVersion', 'mergeToVersion']) def test_trying_to_omit_mandatory_field( @@ -57,6 +60,7 @@ def test_trying_to_omit_mandatory_field( params=params, user=user_factory(rank=db.User.RANK_REGULAR))) + def test_trying_to_merge_non_existing( user_factory, tag_factory, context_factory): db.session.add(tag_factory(names=['good'])) @@ -72,14 +76,9 @@ def test_trying_to_merge_non_existing( params={'remove': 'bad', 'mergeTo': 'good'}, user=user_factory(rank=db.User.RANK_REGULAR))) -@pytest.mark.parametrize('params', [ - {'names': 'whatever'}, - {'category': 'whatever'}, - {'suggestions': ['whatever']}, - {'implications': ['whatever']}, -]) + def test_trying_to_merge_without_privileges( - user_factory, tag_factory, context_factory, params): + user_factory, tag_factory, context_factory): db.session.add_all([ tag_factory(names=['source']), tag_factory(names=['target']), diff --git a/server/szurubooru/tests/api/test_tag_retrieving.py b/server/szurubooru/tests/api/test_tag_retrieving.py index cff7ee05..fc3f90b2 100644 --- a/server/szurubooru/tests/api/test_tag_retrieving.py +++ b/server/szurubooru/tests/api/test_tag_retrieving.py @@ -1,8 +1,9 @@ +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import api, db, errors from szurubooru.func import tags + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ @@ -12,11 +13,12 @@ def inject_config(config_injector): }, }) + def test_retrieving_multiple(user_factory, tag_factory, context_factory): tag1 = tag_factory(names=['t1']) tag2 = tag_factory(names=['t2']) db.session.add_all([tag1, tag2]) - with unittest.mock.patch('szurubooru.func.tags.serialize_tag'): + with patch('szurubooru.func.tags.serialize_tag'): tags.serialize_tag.return_value = 'serialized tag' result = api.tag_api.get_tags( context_factory( @@ -30,6 +32,7 @@ def test_retrieving_multiple(user_factory, tag_factory, context_factory): 'results': ['serialized tag', 'serialized tag'], } + def test_trying_to_retrieve_multiple_without_privileges( user_factory, context_factory): with pytest.raises(errors.AuthError): @@ -38,9 +41,10 @@ def test_trying_to_retrieve_multiple_without_privileges( params={'query': '', 'page': 1}, user=user_factory(rank=db.User.RANK_ANONYMOUS))) + def test_retrieving_single(user_factory, tag_factory, context_factory): db.session.add(tag_factory(names=['tag'])) - with unittest.mock.patch('szurubooru.func.tags.serialize_tag'): + with patch('szurubooru.func.tags.serialize_tag'): tags.serialize_tag.return_value = 'serialized tag' result = api.tag_api.get_tag( context_factory( @@ -48,6 +52,7 @@ def test_retrieving_single(user_factory, tag_factory, context_factory): {'tag_name': 'tag'}) assert result == 'serialized tag' + def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): with pytest.raises(tags.TagNotFoundError): api.tag_api.get_tag( @@ -55,6 +60,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): user=user_factory(rank=db.User.RANK_REGULAR)), {'tag_name': '-'}) + def test_trying_to_retrieve_single_without_privileges( user_factory, context_factory): with pytest.raises(errors.AuthError): diff --git a/server/szurubooru/tests/api/test_tag_siblings_retrieving.py b/server/szurubooru/tests/api/test_tag_siblings_retrieving.py index f79bc340..bc1b20d8 100644 --- a/server/szurubooru/tests/api/test_tag_siblings_retrieving.py +++ b/server/szurubooru/tests/api/test_tag_siblings_retrieving.py @@ -1,16 +1,18 @@ +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import api, db, errors from szurubooru.func import tags + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({'privileges': {'tags:view': db.User.RANK_REGULAR}}) -def test_get_tag_siblings(user_factory, tag_factory, post_factory, context_factory): + +def test_get_tag_siblings(user_factory, tag_factory, context_factory): db.session.add(tag_factory(names=['tag'])) - with unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ - unittest.mock.patch('szurubooru.func.tags.get_tag_siblings'): + with patch('szurubooru.func.tags.serialize_tag'), \ + patch('szurubooru.func.tags.get_tag_siblings'): tags.serialize_tag.side_effect = \ lambda tag, *args, **kwargs: \ 'serialized tag %s' % tag.names[0].name @@ -34,12 +36,14 @@ def test_get_tag_siblings(user_factory, tag_factory, post_factory, context_facto ], } + def test_trying_to_retrieve_non_existing(user_factory, context_factory): with pytest.raises(tags.TagNotFoundError): api.tag_api.get_tag_siblings( context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), {'tag_name': '-'}) + def test_trying_to_retrieve_without_privileges(user_factory, context_factory): with pytest.raises(errors.AuthError): api.tag_api.get_tag_siblings( diff --git a/server/szurubooru/tests/api/test_tag_updating.py b/server/szurubooru/tests/api/test_tag_updating.py index 4e9f8525..dbb8b803 100644 --- a/server/szurubooru/tests/api/test_tag_updating.py +++ b/server/szurubooru/tests/api/test_tag_updating.py @@ -1,8 +1,9 @@ +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import api, db, errors from szurubooru.func import tags + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ @@ -16,20 +17,21 @@ def inject_config(config_injector): }, }) -def test_simple_updating(user_factory, tag_factory, context_factory, fake_datetime): + +def test_simple_updating(user_factory, tag_factory, context_factory): auth_user = user_factory(rank=db.User.RANK_REGULAR) tag = tag_factory(names=['tag1', 'tag2']) db.session.add(tag) db.session.commit() - with unittest.mock.patch('szurubooru.func.tags.create_tag'), \ - unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'), \ - unittest.mock.patch('szurubooru.func.tags.update_tag_names'), \ - unittest.mock.patch('szurubooru.func.tags.update_tag_category_name'), \ - unittest.mock.patch('szurubooru.func.tags.update_tag_description'), \ - unittest.mock.patch('szurubooru.func.tags.update_tag_suggestions'), \ - unittest.mock.patch('szurubooru.func.tags.update_tag_implications'), \ - unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'): + with patch('szurubooru.func.tags.create_tag'), \ + patch('szurubooru.func.tags.get_or_create_tags_by_names'), \ + patch('szurubooru.func.tags.update_tag_names'), \ + patch('szurubooru.func.tags.update_tag_category_name'), \ + patch('szurubooru.func.tags.update_tag_description'), \ + patch('szurubooru.func.tags.update_tag_suggestions'), \ + patch('szurubooru.func.tags.update_tag_implications'), \ + patch('szurubooru.func.tags.serialize_tag'), \ + patch('szurubooru.func.tags.export_to_json'): tags.get_or_create_tags_by_names.return_value = ([], []) tags.serialize_tag.return_value = 'serialized tag' result = api.tag_api.update_tag( @@ -49,12 +51,22 @@ def test_simple_updating(user_factory, tag_factory, context_factory, fake_dateti tags.update_tag_names.assert_called_once_with(tag, ['tag3']) tags.update_tag_category_name.assert_called_once_with(tag, 'character') tags.update_tag_description.assert_called_once_with(tag, 'desc') - tags.update_tag_suggestions.assert_called_once_with(tag, ['sug1', 'sug2']) - tags.update_tag_implications.assert_called_once_with(tag, ['imp1', 'imp2']) - tags.serialize_tag.assert_called_once_with(tag, options=None) + tags.update_tag_suggestions.assert_called_once_with( + tag, ['sug1', 'sug2']) + tags.update_tag_implications.assert_called_once_with( + tag, ['imp1', 'imp2']) + tags.serialize_tag.assert_called_once_with( + tag, options=None) + @pytest.mark.parametrize( - 'field', ['names', 'category', 'description', 'implications', 'suggestions']) + 'field', [ + 'names', + 'category', + 'description', + 'implications', + 'suggestions', + ]) def test_omitting_optional_field( user_factory, tag_factory, context_factory, field): db.session.add(tag_factory(names=['tag'])) @@ -67,17 +79,18 @@ def test_omitting_optional_field( 'implications': [], } del params[field] - with unittest.mock.patch('szurubooru.func.tags.create_tag'), \ - unittest.mock.patch('szurubooru.func.tags.update_tag_names'), \ - unittest.mock.patch('szurubooru.func.tags.update_tag_category_name'), \ - unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ - unittest.mock.patch('szurubooru.func.tags.export_to_json'): + with patch('szurubooru.func.tags.create_tag'), \ + patch('szurubooru.func.tags.update_tag_names'), \ + patch('szurubooru.func.tags.update_tag_category_name'), \ + patch('szurubooru.func.tags.serialize_tag'), \ + patch('szurubooru.func.tags.export_to_json'): api.tag_api.update_tag( context_factory( params={**params, **{'version': 1}}, user=user_factory(rank=db.User.RANK_REGULAR)), {'tag_name': 'tag'}) + def test_trying_to_update_non_existing(user_factory, context_factory): with pytest.raises(tags.TagNotFoundError): api.tag_api.update_tag( @@ -86,6 +99,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory): user=user_factory(rank=db.User.RANK_REGULAR)), {'tag_name': 'tag1'}) + @pytest.mark.parametrize('params', [ {'names': 'whatever'}, {'category': 'whatever'}, @@ -103,6 +117,7 @@ def test_trying_to_update_without_privileges( user=user_factory(rank=db.User.RANK_ANONYMOUS)), {'tag_name': 'tag'}) + def test_trying_to_create_tags_without_privileges( config_injector, context_factory, tag_factory, user_factory): tag = tag_factory(names=['tag']) @@ -113,7 +128,7 @@ def test_trying_to_create_tags_without_privileges( 'tags:edit:suggestions': db.User.RANK_REGULAR, 'tags:edit:implications': db.User.RANK_REGULAR, }}) - with unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'): + with patch('szurubooru.func.tags.get_or_create_tags_by_names'): tags.get_or_create_tags_by_names.return_value = ([], ['new-tag']) with pytest.raises(errors.AuthError): api.tag_api.update_tag( diff --git a/server/szurubooru/tests/api/test_user_creating.py b/server/szurubooru/tests/api/test_user_creating.py index 3e607ec5..8b583b6e 100644 --- a/server/szurubooru/tests/api/test_user_creating.py +++ b/server/szurubooru/tests/api/test_user_creating.py @@ -1,21 +1,23 @@ +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import api, db, errors from szurubooru.func import users + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({'privileges': {'users:create': 'regular'}}) + def test_creating_user(user_factory, context_factory, fake_datetime): user = user_factory() - with unittest.mock.patch('szurubooru.func.users.create_user'), \ - unittest.mock.patch('szurubooru.func.users.update_user_name'), \ - unittest.mock.patch('szurubooru.func.users.update_user_password'), \ - unittest.mock.patch('szurubooru.func.users.update_user_email'), \ - unittest.mock.patch('szurubooru.func.users.update_user_rank'), \ - unittest.mock.patch('szurubooru.func.users.update_user_avatar'), \ - unittest.mock.patch('szurubooru.func.users.serialize_user'), \ + with patch('szurubooru.func.users.create_user'), \ + patch('szurubooru.func.users.update_user_name'), \ + patch('szurubooru.func.users.update_user_password'), \ + patch('szurubooru.func.users.update_user_email'), \ + patch('szurubooru.func.users.update_user_rank'), \ + patch('szurubooru.func.users.update_user_avatar'), \ + patch('szurubooru.func.users.serialize_user'), \ fake_datetime('1969-02-12'): users.serialize_user.return_value = 'serialized user' users.create_user.return_value = user @@ -31,13 +33,15 @@ def test_creating_user(user_factory, context_factory, fake_datetime): files={'avatar': b'...'}, user=user_factory(rank=db.User.RANK_REGULAR))) assert result == 'serialized user' - users.create_user.assert_called_once_with('chewie1', 'oks', 'asd@asd.asd') + users.create_user.assert_called_once_with( + 'chewie1', 'oks', 'asd@asd.asd') assert not users.update_user_name.called assert not users.update_user_password.called assert not users.update_user_email.called users.update_user_rank.called_once_with(user, 'moderator') users.update_user_avatar.called_once_with(user, 'manual', b'...') + @pytest.mark.parametrize('field', ['name', 'password']) def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): params = { @@ -48,10 +52,12 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): user = user_factory() auth_user = user_factory(rank=db.User.RANK_REGULAR) del params[field] - with unittest.mock.patch('szurubooru.func.users.create_user'), \ + with patch('szurubooru.func.users.create_user'), \ pytest.raises(errors.MissingRequiredParameterError): users.create_user.return_value = user - api.user_api.create_user(context_factory(params=params, user=auth_user)) + api.user_api.create_user( + context_factory(params=params, user=auth_user)) + @pytest.mark.parametrize('field', ['rank', 'email', 'avatarStyle']) def test_omitting_optional_field(user_factory, context_factory, field): @@ -65,14 +71,16 @@ def test_omitting_optional_field(user_factory, context_factory, field): del params[field] user = user_factory() auth_user = user_factory(rank=db.User.RANK_MODERATOR) - with unittest.mock.patch('szurubooru.func.users.create_user'), \ - unittest.mock.patch('szurubooru.func.users.update_user_avatar'), \ - unittest.mock.patch('szurubooru.func.users.serialize_user'): + with patch('szurubooru.func.users.create_user'), \ + patch('szurubooru.func.users.update_user_avatar'), \ + patch('szurubooru.func.users.serialize_user'): users.create_user.return_value = user api.user_api.create_user( context_factory(params=params, user=auth_user)) -def test_trying_to_create_user_without_privileges(context_factory, user_factory): + +def test_trying_to_create_user_without_privileges( + context_factory, user_factory): with pytest.raises(errors.AuthError): api.user_api.create_user(context_factory( params='whatever', diff --git a/server/szurubooru/tests/api/test_user_deleting.py b/server/szurubooru/tests/api/test_user_deleting.py index ec6ca635..9dd87764 100644 --- a/server/szurubooru/tests/api/test_user_deleting.py +++ b/server/szurubooru/tests/api/test_user_deleting.py @@ -2,6 +2,7 @@ import pytest from szurubooru import api, db, errors from szurubooru.func import users + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ @@ -11,6 +12,7 @@ def inject_config(config_injector): }, }) + def test_deleting_oneself(user_factory, context_factory): user = user_factory(name='u', rank=db.User.RANK_REGULAR) db.session.add(user) @@ -21,6 +23,7 @@ def test_deleting_oneself(user_factory, context_factory): assert result == {} assert db.session.query(db.User).count() == 0 + def test_deleting_someone_else(user_factory, context_factory): user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR) user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR) @@ -31,6 +34,7 @@ def test_deleting_someone_else(user_factory, context_factory): params={'version': 1}, user=user2), {'user_name': 'u1'}) assert db.session.query(db.User).count() == 1 + def test_trying_to_delete_someone_else_without_privileges( user_factory, context_factory): user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR) @@ -43,6 +47,7 @@ def test_trying_to_delete_someone_else_without_privileges( params={'version': 1}, user=user2), {'user_name': 'u1'}) assert db.session.query(db.User).count() == 2 + def test_trying_to_delete_non_existing(user_factory, context_factory): with pytest.raises(users.UserNotFoundError): api.user_api.delete_user( diff --git a/server/szurubooru/tests/api/test_user_retrieving.py b/server/szurubooru/tests/api/test_user_retrieving.py index 3e0dab0b..574fe7fc 100644 --- a/server/szurubooru/tests/api/test_user_retrieving.py +++ b/server/szurubooru/tests/api/test_user_retrieving.py @@ -1,8 +1,9 @@ -import unittest.mock +from unittest.mock import patch import pytest from szurubooru import api, db, errors from szurubooru.func import users + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ @@ -13,11 +14,12 @@ def inject_config(config_injector): }, }) + def test_retrieving_multiple(user_factory, context_factory): user1 = user_factory(name='u1', rank=db.User.RANK_MODERATOR) user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR) db.session.add_all([user1, user2]) - with unittest.mock.patch('szurubooru.func.users.serialize_user'): + with patch('szurubooru.func.users.serialize_user'): users.serialize_user.return_value = 'serialized user' result = api.user_api.get_users( context_factory( @@ -31,6 +33,7 @@ def test_retrieving_multiple(user_factory, context_factory): 'results': ['serialized user', 'serialized user'], } + def test_trying_to_retrieve_multiple_without_privileges( user_factory, context_factory): with pytest.raises(errors.AuthError): @@ -39,22 +42,25 @@ def test_trying_to_retrieve_multiple_without_privileges( params={'query': '', 'page': 1}, user=user_factory(rank=db.User.RANK_ANONYMOUS))) + def test_retrieving_single(user_factory, context_factory): user = user_factory(name='u1', rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=db.User.RANK_REGULAR) db.session.add(user) - with unittest.mock.patch('szurubooru.func.users.serialize_user'): + with patch('szurubooru.func.users.serialize_user'): users.serialize_user.return_value = 'serialized user' result = api.user_api.get_user( context_factory(user=auth_user), {'user_name': 'u1'}) assert result == 'serialized user' + def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): auth_user = user_factory(rank=db.User.RANK_REGULAR) with pytest.raises(users.UserNotFoundError): api.user_api.get_user( context_factory(user=auth_user), {'user_name': '-'}) + def test_trying_to_retrieve_single_without_privileges( user_factory, context_factory): auth_user = user_factory(rank=db.User.RANK_ANONYMOUS) diff --git a/server/szurubooru/tests/api/test_user_updating.py b/server/szurubooru/tests/api/test_user_updating.py index bc93295f..30ff56b6 100644 --- a/server/szurubooru/tests/api/test_user_updating.py +++ b/server/szurubooru/tests/api/test_user_updating.py @@ -1,9 +1,9 @@ +from unittest.mock import patch import pytest -import unittest.mock -from datetime import datetime from szurubooru import api, db, errors from szurubooru.func import users + @pytest.fixture(autouse=True) def inject_config(config_injector): config_injector({ @@ -21,19 +21,20 @@ def inject_config(config_injector): }, }) + def test_updating_user(context_factory, user_factory): user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) auth_user = user_factory(rank=db.User.RANK_ADMINISTRATOR) db.session.add(user) db.session.flush() - with unittest.mock.patch('szurubooru.func.users.create_user'), \ - unittest.mock.patch('szurubooru.func.users.update_user_name'), \ - unittest.mock.patch('szurubooru.func.users.update_user_password'), \ - unittest.mock.patch('szurubooru.func.users.update_user_email'), \ - unittest.mock.patch('szurubooru.func.users.update_user_rank'), \ - unittest.mock.patch('szurubooru.func.users.update_user_avatar'), \ - unittest.mock.patch('szurubooru.func.users.serialize_user'): + with patch('szurubooru.func.users.create_user'), \ + patch('szurubooru.func.users.update_user_name'), \ + patch('szurubooru.func.users.update_user_password'), \ + patch('szurubooru.func.users.update_user_email'), \ + patch('szurubooru.func.users.update_user_rank'), \ + patch('szurubooru.func.users.update_user_avatar'), \ + patch('szurubooru.func.users.serialize_user'): users.serialize_user.return_value = 'serialized user' result = api.user_api.update_user( @@ -57,9 +58,13 @@ def test_updating_user(context_factory, user_factory): users.update_user_name.assert_called_once_with(user, 'chewie') users.update_user_password.assert_called_once_with(user, 'oks') users.update_user_email.assert_called_once_with(user, 'asd@asd.asd') - users.update_user_rank.assert_called_once_with(user, 'moderator', auth_user) - users.update_user_avatar.assert_called_once_with(user, 'manual', b'...') - users.serialize_user.assert_called_once_with(user, auth_user, options=None) + users.update_user_rank.assert_called_once_with( + user, 'moderator', auth_user) + users.update_user_avatar.assert_called_once_with( + user, 'manual', b'...') + users.serialize_user.assert_called_once_with( + user, auth_user, options=None) + @pytest.mark.parametrize( 'field', ['name', 'email', 'password', 'rank', 'avatarStyle']) @@ -74,13 +79,13 @@ def test_omitting_optional_field(user_factory, context_factory, field): 'avatarStyle': 'gravatar', } del params[field] - with unittest.mock.patch('szurubooru.func.users.create_user'), \ - unittest.mock.patch('szurubooru.func.users.update_user_name'), \ - unittest.mock.patch('szurubooru.func.users.update_user_password'), \ - unittest.mock.patch('szurubooru.func.users.update_user_email'), \ - unittest.mock.patch('szurubooru.func.users.update_user_rank'), \ - unittest.mock.patch('szurubooru.func.users.update_user_avatar'), \ - unittest.mock.patch('szurubooru.func.users.serialize_user'): + with patch('szurubooru.func.users.create_user'), \ + patch('szurubooru.func.users.update_user_name'), \ + patch('szurubooru.func.users.update_user_password'), \ + patch('szurubooru.func.users.update_user_email'), \ + patch('szurubooru.func.users.update_user_rank'), \ + patch('szurubooru.func.users.update_user_avatar'), \ + patch('szurubooru.func.users.serialize_user'): api.user_api.update_user( context_factory( params={**params, **{'version': 1}}, @@ -88,6 +93,7 @@ def test_omitting_optional_field(user_factory, context_factory, field): user=user), {'user_name': 'u1'}) + def test_trying_to_update_non_existing(user_factory, context_factory): user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) db.session.add(user) @@ -95,6 +101,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory): api.user_api.update_user( context_factory(user=user), {'user_name': 'u2'}) + @pytest.mark.parametrize('params', [ {'name': 'whatever'}, {'email': 'whatever'}, diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index 11f8398a..99db52eb 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -1,3 +1,4 @@ +# pylint: disable=redefined-outer-name import contextlib import os import datetime @@ -5,30 +6,35 @@ import uuid import pytest import freezegun import sqlalchemy -from szurubooru import api, config, db, rest -from szurubooru.func import util +from szurubooru import config, db, rest + class QueryCounter(object): def __init__(self): self._statements = [] + def __enter__(self): self._statements = [] + def __exit__(self, *args, **kwargs): self._statements = [] + def create_before_cursor_execute(self): def before_cursor_execute( - _conn, _cursor, statement, _parameters, _context, _executemany): + _conn, _cursor, statement, _params, _context, _executemany): self._statements.append(statement) return before_cursor_execute + @property def statements(self): return self._statements + _query_counter = QueryCounter() -engine = sqlalchemy.create_engine('sqlite:///:memory:') -db.Base.metadata.create_all(bind=engine) +_engine = sqlalchemy.create_engine('sqlite:///:memory:') +db.Base.metadata.create_all(bind=_engine) sqlalchemy.event.listen( - engine, + _engine, 'before_cursor_execute', _query_counter.create_before_cursor_execute()) @@ -36,6 +42,7 @@ sqlalchemy.event.listen( def get_unique_name(): return str(uuid.uuid4()) + @pytest.fixture def fake_datetime(): @contextlib.contextmanager @@ -46,22 +53,26 @@ def fake_datetime(): freezer.stop() return injector + @pytest.fixture() def query_counter(): return _query_counter + @pytest.fixture def query_logger(): if pytest.config.option.verbose > 0: import logging import coloredlogs - coloredlogs.install(fmt='[%(asctime)-15s] %(name)s %(message)s', isatty=True) + coloredlogs.install( + fmt='[%(asctime)-15s] %(name)s %(message)s', isatty=True) logging.basicConfig() logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) + @pytest.yield_fixture(scope='function', autouse=True) -def session(query_logger): - session_maker = sqlalchemy.orm.sessionmaker(bind=engine) +def session(query_logger): # pylint: disable=unused-argument + session_maker = sqlalchemy.orm.sessionmaker(bind=_engine) session = sqlalchemy.orm.scoped_session(session_maker) db.session = session try: @@ -72,6 +83,7 @@ def session(query_logger): session.execute(table.delete()) session.commit() + @pytest.fixture def context_factory(session): def factory(params=None, files=None, user=None): @@ -86,12 +98,14 @@ def context_factory(session): return ctx return factory + @pytest.fixture def config_injector(): def injector(new_config_content): config.config = new_config_content return injector + @pytest.fixture def user_factory(): def factory(name=None, rank=db.User.RANK_REGULAR, email='dummy'): @@ -106,6 +120,7 @@ def user_factory(): return user return factory + @pytest.fixture def tag_category_factory(): def factory(name=None, color='dummy', default=False): @@ -116,6 +131,7 @@ def tag_category_factory(): return category return factory + @pytest.fixture def tag_factory(): def factory(names=None, category=None): @@ -123,14 +139,17 @@ def tag_factory(): category = db.TagCategory(get_unique_name()) db.session.add(category) tag = db.Tag() - tag.names = [db.TagName(name) for name in (names or [get_unique_name()])] + tag.names = [ + db.TagName(name) for name in names or [get_unique_name()]] tag.category = category tag.creation_time = datetime.datetime(1996, 1, 1) return tag return factory + @pytest.fixture def post_factory(): + # pylint: disable=invalid-name def factory( id=None, safety=db.Post.SAFETY_SAFE, @@ -147,6 +166,7 @@ def post_factory(): return post return factory + @pytest.fixture def comment_factory(user_factory, post_factory): def factory(user=None, post=None, text='dummy'): @@ -164,6 +184,7 @@ def comment_factory(user_factory, post_factory): return comment return factory + @pytest.fixture def read_asset(): def get(path): diff --git a/server/szurubooru/tests/db/__init__.py b/server/szurubooru/tests/db/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/szurubooru/tests/db/test_comment.py b/server/szurubooru/tests/db/test_comment.py index 5f3ea708..9a78f952 100644 --- a/server/szurubooru/tests/db/test_comment.py +++ b/server/szurubooru/tests/db/test_comment.py @@ -1,6 +1,7 @@ from datetime import datetime from szurubooru import db + def test_saving_comment(user_factory, post_factory): user = user_factory() post = post_factory() @@ -20,6 +21,7 @@ def test_saving_comment(user_factory, post_factory): assert comment.creation_time == datetime(1997, 1, 1) assert comment.last_edit_time == datetime(1998, 1, 1) + def test_cascade_deletions(comment_factory, user_factory, post_factory): user = user_factory() post = post_factory() diff --git a/server/szurubooru/tests/db/test_post.py b/server/szurubooru/tests/db/test_post.py index 79d10f5d..f0cf7a82 100644 --- a/server/szurubooru/tests/db/test_post.py +++ b/server/szurubooru/tests/db/test_post.py @@ -1,6 +1,7 @@ from datetime import datetime from szurubooru import db + def test_saving_post(post_factory, user_factory, tag_factory): user = user_factory() tag1 = tag_factory() @@ -38,7 +39,10 @@ def test_saving_post(post_factory, user_factory, tag_factory): assert len(related_post1.relations) == 0 assert len(related_post2.relations) == 0 -def test_cascade_deletions(post_factory, user_factory, tag_factory, comment_factory): + +# pylint: disable=too-many-statements +def test_cascade_deletions( + post_factory, user_factory, tag_factory, comment_factory): user = user_factory() tag1 = tag_factory() tag2 = tag_factory() @@ -46,7 +50,8 @@ def test_cascade_deletions(post_factory, user_factory, tag_factory, comment_fact related_post2 = post_factory() post = post_factory() comment = comment_factory(post=post, user=user) - db.session.add_all([user, tag1, tag2, post, related_post1, related_post2, comment]) + db.session.add_all([ + user, tag1, tag2, post, related_post1, related_post2, comment]) db.session.flush() score = db.PostScore() @@ -109,6 +114,7 @@ def test_cascade_deletions(post_factory, user_factory, tag_factory, comment_fact assert db.session.query(db.PostFavorite).count() == 0 assert db.session.query(db.Comment).count() == 0 + def test_tracking_tag_count(post_factory, tag_factory): post = post_factory() tag1 = tag_factory() diff --git a/server/szurubooru/tests/db/test_tag.py b/server/szurubooru/tests/db/test_tag.py index c6658239..1a98ab2e 100644 --- a/server/szurubooru/tests/db/test_tag.py +++ b/server/szurubooru/tests/db/test_tag.py @@ -1,6 +1,7 @@ from datetime import datetime from szurubooru import db + def test_saving_tag(tag_factory): sug1 = tag_factory(names=['sug1']) sug2 = tag_factory(names=['sug2']) @@ -30,7 +31,7 @@ def test_saving_tag(tag_factory): tag = db.session \ .query(db.Tag) \ .join(db.TagName) \ - .filter(db.TagName.name=='alias1') \ + .filter(db.TagName.name == 'alias1') \ .one() assert [tag_name.name for tag_name in tag.names] == ['alias1', 'alias2'] assert tag.category.name == 'category' @@ -41,6 +42,7 @@ def test_saving_tag(tag_factory): assert [relation.names[0].name for relation in tag.implications] \ == ['imp1', 'imp2'] + def test_cascade_deletions(tag_factory): sug1 = tag_factory(names=['sug1']) sug2 = tag_factory(names=['sug2']) @@ -75,6 +77,7 @@ def test_cascade_deletions(tag_factory): assert db.session.query(db.TagImplication).count() == 0 assert db.session.query(db.TagSuggestion).count() == 0 + def test_tracking_post_count(post_factory, tag_factory): tag = tag_factory() post1 = post_factory() diff --git a/server/szurubooru/tests/db/test_user.py b/server/szurubooru/tests/db/test_user.py index 8c4d2d42..c7c925d2 100644 --- a/server/szurubooru/tests/db/test_user.py +++ b/server/szurubooru/tests/db/test_user.py @@ -1,6 +1,7 @@ from datetime import datetime from szurubooru import db + def test_saving_user(): user = db.User() user.name = 'name' @@ -22,6 +23,7 @@ def test_saving_user(): assert user.creation_time == datetime(1997, 1, 1) assert user.avatar_style == db.User.AVATAR_GRAVATAR + def test_upload_count(user_factory, post_factory): user = user_factory() db.session.add(user) @@ -35,6 +37,7 @@ def test_upload_count(user_factory, post_factory): db.session.refresh(user) assert user.post_count == 1 + def test_comment_count(user_factory, comment_factory): user = user_factory() db.session.add(user) @@ -48,55 +51,63 @@ def test_comment_count(user_factory, comment_factory): db.session.refresh(user) assert user.comment_count == 1 + def test_favorite_count(user_factory, post_factory): - user = user_factory() - db.session.add(user) + user1 = user_factory() + user2 = user_factory() + db.session.add(user1) db.session.flush() - assert user.comment_count == 0 + assert user1.comment_count == 0 post1 = post_factory() post2 = post_factory() db.session.add_all([ - db.PostFavorite(post=post1, time=datetime.utcnow(), user=user), - db.PostFavorite(post=post2, time=datetime.utcnow(), user=user_factory()), + db.PostFavorite(post=post1, time=datetime.utcnow(), user=user1), + db.PostFavorite(post=post2, time=datetime.utcnow(), user=user2), ]) db.session.flush() - db.session.refresh(user) - assert user.favorite_post_count == 1 + db.session.refresh(user1) + assert user1.favorite_post_count == 1 + def test_liked_post_count(user_factory, post_factory): - user = user_factory() - db.session.add(user) + user1 = user_factory() + user2 = user_factory() + db.session.add_all([user1, user2]) db.session.flush() - assert user.liked_post_count == 0 - assert user.disliked_post_count == 0 + assert user1.liked_post_count == 0 + assert user1.disliked_post_count == 0 post1 = post_factory() post2 = post_factory() db.session.add_all([ - db.PostScore(post=post1, time=datetime.utcnow(), user=user, score=1), - db.PostScore(post=post2, time=datetime.utcnow(), user=user_factory(), score=1), + db.PostScore(post=post1, time=datetime.utcnow(), user=user1, score=1), + db.PostScore(post=post2, time=datetime.utcnow(), user=user2, score=1), ]) db.session.flush() - db.session.refresh(user) - assert user.liked_post_count == 1 - assert user.disliked_post_count == 0 + db.session.refresh(user1) + assert user1.liked_post_count == 1 + assert user1.disliked_post_count == 0 + def test_disliked_post_count(user_factory, post_factory): - user = user_factory() - db.session.add(user) + user1 = user_factory() + user2 = user_factory() + db.session.add_all([user1, user2]) db.session.flush() - assert user.liked_post_count == 0 - assert user.disliked_post_count == 0 + assert user1.liked_post_count == 0 + assert user1.disliked_post_count == 0 post1 = post_factory() post2 = post_factory() db.session.add_all([ - db.PostScore(post=post1, time=datetime.utcnow(), user=user, score=-1), - db.PostScore(post=post2, time=datetime.utcnow(), user=user_factory(), score=1), + db.PostScore(post=post1, time=datetime.utcnow(), user=user1, score=-1), + db.PostScore(post=post2, time=datetime.utcnow(), user=user2, score=1), ]) db.session.flush() - db.session.refresh(user) - assert user.liked_post_count == 0 - assert user.disliked_post_count == 1 + db.session.refresh(user1) + assert user1.liked_post_count == 0 + assert user1.disliked_post_count == 1 + +# pylint: disable=too-many-statements def test_cascade_deletions(post_factory, user_factory, comment_factory): user = user_factory() diff --git a/server/szurubooru/tests/func/__init__.py b/server/szurubooru/tests/func/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/szurubooru/tests/func/test_comments.py b/server/szurubooru/tests/func/test_comments.py index 78412f23..0cdeca72 100644 --- a/server/szurubooru/tests/func/test_comments.py +++ b/server/szurubooru/tests/func/test_comments.py @@ -1,23 +1,21 @@ -import unittest.mock -import pytest +from unittest.mock import patch from datetime import datetime +import pytest from szurubooru import db from szurubooru.func import comments, users + def test_serialize_user(user_factory, comment_factory): - with unittest.mock.patch('szurubooru.func.users.get_avatar_url'): + with patch('szurubooru.func.users.get_avatar_url'): users.get_avatar_url.return_value = 'https://example.com/avatar.png' comment = comment_factory(user=user_factory(name='dummy')) comment.comment_id = 77 comment.creation_time = datetime(1997, 1, 1) comment.last_edit_time = datetime(1998, 1, 1) comment.text = 'text' - db.session.add(comment) db.session.flush() - auth_user = user_factory() - assert comments.serialize_comment(comment, auth_user) == { 'id': comment.comment_id, 'postId': comment.post.post_id, @@ -33,6 +31,7 @@ def test_serialize_user(user_factory, comment_factory): 'version': 1, } + def test_try_get_comment(comment_factory): comment = comment_factory() db.session.add(comment) @@ -41,6 +40,7 @@ def test_try_get_comment(comment_factory): with pytest.raises(comments.InvalidCommentIdError): comments.try_get_comment_by_id('-') + def test_get_comment(comment_factory): comment = comment_factory() db.session.add(comment) @@ -50,11 +50,12 @@ def test_get_comment(comment_factory): with pytest.raises(comments.InvalidCommentIdError): comments.get_comment_by_id('-') + def test_create_comment(user_factory, post_factory, fake_datetime): user = user_factory() post = post_factory() db.session.add_all([user, post]) - with unittest.mock.patch('szurubooru.func.comments.update_comment_text'), \ + with patch('szurubooru.func.comments.update_comment_text'), \ fake_datetime('1997-01-01'): comment = comments.create_comment(user, post, 'text') assert comment.creation_time == datetime(1997, 1, 1) @@ -62,11 +63,13 @@ def test_create_comment(user_factory, post_factory, fake_datetime): assert comment.post == post comments.update_comment_text.assert_called_once_with(comment, 'text') + def test_update_comment_text_with_emptry_string(comment_factory): comment = comment_factory() with pytest.raises(comments.EmptyCommentTextError): comments.update_comment_text(comment, None) + def test_update_comment_text(comment_factory): comment = comment_factory() comments.update_comment_text(comment, 'text') diff --git a/server/szurubooru/tests/func/test_mime.py b/server/szurubooru/tests/func/test_mime.py index eaf69898..4b8dcfaf 100644 --- a/server/szurubooru/tests/func/test_mime.py +++ b/server/szurubooru/tests/func/test_mime.py @@ -1,7 +1,7 @@ -import os import pytest from szurubooru.func import mime + @pytest.mark.parametrize('input_path,expected_mime_type', [ ('mp4.mp4', 'video/mp4'), ('webm.webm', 'video/webm'), @@ -14,9 +14,11 @@ from szurubooru.func import mime def test_get_mime_type(read_asset, input_path, expected_mime_type): assert mime.get_mime_type(read_asset(input_path)) == expected_mime_type + def test_get_mime_type_for_empty_file(): assert mime.get_mime_type(b'') == 'application/octet-stream' + @pytest.mark.parametrize('mime_type,expected_extension', [ ('video/mp4', 'mp4'), ('video/webm', 'webm'), @@ -29,6 +31,7 @@ def test_get_mime_type_for_empty_file(): def test_get_extension(mime_type, expected_extension): assert mime.get_extension(mime_type) == expected_extension + @pytest.mark.parametrize('input_mime_type,expected_state', [ ('application/x-shockwave-flash', True), ('APPLICATION/X-SHOCKWAVE-FLASH', True), @@ -37,6 +40,7 @@ def test_get_extension(mime_type, expected_extension): def test_is_flash(input_mime_type, expected_state): assert mime.is_flash(input_mime_type) == expected_state + @pytest.mark.parametrize('input_mime_type,expected_state', [ ('video/webm', True), ('VIDEO/WEBM', True), @@ -49,6 +53,7 @@ def test_is_flash(input_mime_type, expected_state): def test_is_video(input_mime_type, expected_state): assert mime.is_video(input_mime_type) == expected_state + @pytest.mark.parametrize('input_mime_type,expected_state', [ ('image/gif', True), ('image/png', True), @@ -62,6 +67,7 @@ def test_is_video(input_mime_type, expected_state): def test_is_image(input_mime_type, expected_state): assert mime.is_image(input_mime_type) == expected_state + @pytest.mark.parametrize('input_path,expected_state', [ ('gif.gif', False), ('gif-animated.gif', True), diff --git a/server/szurubooru/tests/func/test_net.py b/server/szurubooru/tests/func/test_net.py index 4fd625d8..f749d384 100644 --- a/server/szurubooru/tests/func/test_net.py +++ b/server/szurubooru/tests/func/test_net.py @@ -1,5 +1,6 @@ from szurubooru.func import net + def test_download(): url = 'http://info.cern.ch/hypertext/WWW/TheProject.html' diff --git a/server/szurubooru/tests/func/test_posts.py b/server/szurubooru/tests/func/test_posts.py index 7afe64d3..accf75e7 100644 --- a/server/szurubooru/tests/func/test_posts.py +++ b/server/szurubooru/tests/func/test_posts.py @@ -1,9 +1,11 @@ import os -import unittest.mock -import pytest +from unittest.mock import patch from datetime import datetime +import pytest from szurubooru import db -from szurubooru.func import posts, users, comments, snapshots, tags, images +from szurubooru.func import ( + posts, users, comments, snapshots, tags, images, files, util) + @pytest.mark.parametrize('input_mime_type,expected_url', [ ('image/jpeg', 'http://example.com/posts/1.jpg'), @@ -17,6 +19,7 @@ def test_get_post_url(input_mime_type, expected_url, config_injector): post.mime_type = input_mime_type assert posts.get_post_content_url(post) == expected_url + @pytest.mark.parametrize('input_mime_type', ['image/jpeg', 'image/gif']) def test_get_post_thumbnail_url(input_mime_type, config_injector): config_injector({'data_url': 'http://example.com/'}) @@ -26,6 +29,7 @@ def test_get_post_thumbnail_url(input_mime_type, config_injector): assert posts.get_post_thumbnail_url(post) \ == 'http://example.com/generated-thumbnails/1.jpg' + @pytest.mark.parametrize('input_mime_type,expected_path', [ ('image/jpeg', 'posts/1.jpg'), ('image/gif', 'posts/1.gif'), @@ -37,6 +41,7 @@ def test_get_post_content_path(input_mime_type, expected_path): post.mime_type = input_mime_type assert posts.get_post_content_path(post) == expected_path + @pytest.mark.parametrize('input_mime_type', ['image/jpeg', 'image/gif']) def test_get_post_thumbnail_path(input_mime_type): post = db.Post() @@ -44,6 +49,7 @@ def test_get_post_thumbnail_path(input_mime_type): post.mime_type = input_mime_type assert posts.get_post_thumbnail_path(post) == 'generated-thumbnails/1.jpg' + @pytest.mark.parametrize('input_mime_type', ['image/jpeg', 'image/gif']) def test_get_post_thumbnail_backup_path(input_mime_type): post = db.Post() @@ -52,6 +58,7 @@ def test_get_post_thumbnail_backup_path(input_mime_type): assert posts.get_post_thumbnail_backup_path(post) \ == 'posts/custom-thumbnails/1.dat' + def test_serialize_note(): note = db.PostNote() note.polygon = [[0, 1], [1, 1], [1, 0], [0, 0]] @@ -61,16 +68,19 @@ def test_serialize_note(): 'text': '...' } + def test_serialize_post_when_empty(): assert posts.serialize_post(None, None) is None + def test_serialize_post( - post_factory, user_factory, comment_factory, tag_factory, config_injector): + user_factory, comment_factory, tag_factory, config_injector): config_injector({'data_url': 'http://example.com/'}) - with unittest.mock.patch('szurubooru.func.comments.serialize_comment'), \ - unittest.mock.patch('szurubooru.func.users.serialize_micro_user'), \ - unittest.mock.patch('szurubooru.func.posts.files.has', return_value=True), \ - unittest.mock.patch('szurubooru.func.snapshots.get_serialized_history'): + with patch('szurubooru.func.comments.serialize_comment'), \ + patch('szurubooru.func.users.serialize_micro_user'), \ + patch('szurubooru.func.posts.files.has'), \ + patch('szurubooru.func.snapshots.get_serialized_history'): + files.has.return_value = True users.serialize_micro_user.side_effect \ = lambda user, auth_user: user.name comments.serialize_comment.side_effect \ @@ -130,45 +140,47 @@ def test_serialize_post( result = posts.serialize_post(post, auth_user) result['tags'].sort() - assert result == { - 'id': 1, - 'version': 1, - 'creationTime': datetime(1997, 1, 1), - 'lastEditTime': datetime(1998, 1, 1), - 'safety': 'safe', - 'source': '4gag', - 'type': 'image', - 'checksum': 'deadbeef', - 'fileSize': 100, - 'canvasWidth': 200, - 'canvasHeight': 300, - 'contentUrl': 'http://example.com/posts/1.jpg', - 'thumbnailUrl': 'http://example.com/generated-thumbnails/1.jpg', - 'flags': ['loop'], - 'tags': ['tag1', 'tag3'], - 'relations': [], - 'notes': [], - 'user': 'post author', - 'score': 1, - 'ownFavorite': False, - 'ownScore': -1, - 'tagCount': 2, - 'favoriteCount': 1, - 'commentCount': 2, - 'noteCount': 0, - 'featureCount': 1, - 'relationCount': 0, - 'lastFeatureTime': datetime(1999, 1, 1), - 'favoritedBy': ['fav1'], - 'hasCustomThumbnail': True, - 'mimeType': 'image/jpeg', - 'snapshots': 'snapshot history', - 'comments': ['commenter1', 'commenter2'], - } + assert result == { + 'id': 1, + 'version': 1, + 'creationTime': datetime(1997, 1, 1), + 'lastEditTime': datetime(1998, 1, 1), + 'safety': 'safe', + 'source': '4gag', + 'type': 'image', + 'checksum': 'deadbeef', + 'fileSize': 100, + 'canvasWidth': 200, + 'canvasHeight': 300, + 'contentUrl': 'http://example.com/posts/1.jpg', + 'thumbnailUrl': 'http://example.com/generated-thumbnails/1.jpg', + 'flags': ['loop'], + 'tags': ['tag1', 'tag3'], + 'relations': [], + 'notes': [], + 'user': 'post author', + 'score': 1, + 'ownFavorite': False, + 'ownScore': -1, + 'tagCount': 2, + 'favoriteCount': 1, + 'commentCount': 2, + 'noteCount': 0, + 'featureCount': 1, + 'relationCount': 0, + 'lastFeatureTime': datetime(1999, 1, 1), + 'favoritedBy': ['fav1'], + 'hasCustomThumbnail': True, + 'mimeType': 'image/jpeg', + 'snapshots': 'snapshot history', + 'comments': ['commenter1', 'commenter2'], + } + def test_serialize_micro_post(post_factory, user_factory): - with unittest.mock.patch('szurubooru.func.posts.get_post_thumbnail_url'): - posts.get_post_thumbnail_url.return_value = 'https://example.com/thumb.png' + with patch('szurubooru.func.posts.get_post_thumbnail_url'): + posts.get_post_thumbnail_url.return_value \ + = 'https://example.com/thumb.png' auth_user = user_factory() post = post_factory() db.session.add(post) @@ -178,6 +190,7 @@ def test_serialize_micro_post(post_factory, user_factory): 'thumbnailUrl': 'https://example.com/thumb.png', } + def test_get_post_count(post_factory): previous_count = posts.get_post_count() db.session.add_all([post_factory(), post_factory()]) @@ -185,6 +198,7 @@ def test_get_post_count(post_factory): assert previous_count == 0 assert new_count == 2 + def test_try_get_post_by_id(post_factory): post = post_factory() db.session.add(post) @@ -194,6 +208,7 @@ def test_try_get_post_by_id(post_factory): with pytest.raises(posts.InvalidPostIdError): posts.get_post_by_id('-') + def test_get_post_by_id(post_factory): post = post_factory() db.session.add(post) @@ -204,17 +219,19 @@ def test_get_post_by_id(post_factory): with pytest.raises(posts.InvalidPostIdError): posts.get_post_by_id('-') + def test_create_post(user_factory, fake_datetime): - with unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ - unittest.mock.patch('szurubooru.func.posts.update_post_tags'), \ + with patch('szurubooru.func.posts.update_post_content'), \ + patch('szurubooru.func.posts.update_post_tags'), \ fake_datetime('1997-01-01'): auth_user = user_factory() - post, new_tags = posts.create_post('content', ['tag'], auth_user) + post, _new_tags = posts.create_post('content', ['tag'], auth_user) assert post.creation_time == datetime(1997, 1, 1) assert post.last_edit_time is None posts.update_post_tags.assert_called_once_with(post, ['tag']) posts.update_post_content.assert_called_once_with(post, 'content') + @pytest.mark.parametrize('input_safety,expected_safety', [ ('safe', db.Post.SAFETY_SAFE), ('sketchy', db.Post.SAFETY_SKETCHY), @@ -225,31 +242,41 @@ def test_update_post_safety(input_safety, expected_safety): posts.update_post_safety(post, input_safety) assert post.safety == expected_safety + def test_update_post_safety_with_invalid_string(): post = db.Post() with pytest.raises(posts.InvalidPostSafetyError): posts.update_post_safety(post, 'bad') + def test_update_post_source(): post = db.Post() posts.update_post_source(post, 'x') assert post.source == 'x' + def test_update_post_source_with_too_long_string(): post = db.Post() with pytest.raises(posts.InvalidPostSourceError): posts.update_post_source(post, 'x' * 1000) + @pytest.mark.parametrize( - 'input_file,expected_mime_type,expected_type,output_file_name', [ - ('png.png', 'image/png', db.Post.TYPE_IMAGE, '1.png'), - ('jpeg.jpg', 'image/jpeg', db.Post.TYPE_IMAGE, '1.jpg'), - ('gif.gif', 'image/gif', db.Post.TYPE_IMAGE, '1.gif'), - ('gif-animated.gif', 'image/gif', db.Post.TYPE_ANIMATION, '1.gif'), - ('webm.webm', 'video/webm', db.Post.TYPE_VIDEO, '1.webm'), - ('mp4.mp4', 'video/mp4', db.Post.TYPE_VIDEO, '1.mp4'), - ('flash.swf', 'application/x-shockwave-flash', db.Post.TYPE_FLASH, '1.swf'), -]) + 'input_file,expected_mime_type,expected_type,output_file_name', + [ + ('png.png', 'image/png', db.Post.TYPE_IMAGE, '1.png'), + ('jpeg.jpg', 'image/jpeg', db.Post.TYPE_IMAGE, '1.jpg'), + ('gif.gif', 'image/gif', db.Post.TYPE_IMAGE, '1.gif'), + ('gif-animated.gif', 'image/gif', db.Post.TYPE_ANIMATION, '1.gif'), + ('webm.webm', 'video/webm', db.Post.TYPE_VIDEO, '1.webm'), + ('mp4.mp4', 'video/mp4', db.Post.TYPE_VIDEO, '1.mp4'), + ( + 'flash.swf', + 'application/x-shockwave-flash', + db.Post.TYPE_FLASH, + '1.swf' + ), + ]) def test_update_post_content( tmpdir, config_injector, @@ -259,7 +286,8 @@ def test_update_post_content( expected_mime_type, expected_type, output_file_name): - with unittest.mock.patch('szurubooru.func.util.get_md5', return_value='crc'): + with patch('szurubooru.func.util.get_md5'): + util.get_md5.return_value = 'crc' config_injector({ 'data_dir': str(tmpdir.mkdir('data')), 'thumbnails': { @@ -271,10 +299,11 @@ def test_update_post_content( db.session.add(post) db.session.flush() posts.update_post_content(post, read_asset(input_file)) - assert post.mime_type == expected_mime_type - assert post.type == expected_type - assert post.checksum == 'crc' - assert os.path.exists(str(tmpdir) + '/data/posts/' + output_file_name) + assert post.mime_type == expected_mime_type + assert post.type == expected_type + assert post.checksum == 'crc' + assert os.path.exists(str(tmpdir) + '/data/posts/' + output_file_name) + def test_update_post_content_to_existing_content( tmpdir, config_injector, post_factory, read_asset): @@ -293,6 +322,7 @@ def test_update_post_content_to_existing_content( with pytest.raises(posts.PostAlreadyUploadedError): posts.update_post_content(another_post, read_asset('png.png')) + def test_update_post_content_with_broken_content( tmpdir, config_injector, post_factory, read_asset): # the rationale behind this behavior is to salvage user upload even if the @@ -313,12 +343,14 @@ def test_update_post_content_with_broken_content( assert post.canvas_width is None assert post.canvas_height is None + @pytest.mark.parametrize('input_content', [None, b'not a media file']) def test_update_post_content_with_invalid_content(input_content): post = db.Post() with pytest.raises(posts.InvalidPostContentError): posts.update_post_content(post, input_content) + def test_update_post_thumbnail_to_new_one( tmpdir, config_injector, read_asset, post_factory): config_injector({ @@ -333,11 +365,14 @@ def test_update_post_thumbnail_to_new_one( db.session.flush() posts.update_post_content(post, read_asset('png.png')) posts.update_post_thumbnail(post, read_asset('jpeg.jpg')) - assert os.path.exists(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat') - assert os.path.exists(str(tmpdir) + '/data/generated-thumbnails/1.jpg') - with open(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat', 'rb') as handle: + source_path = str(tmpdir) + '/data/posts/custom-thumbnails/1.dat' + generated_path = str(tmpdir) + '/data/generated-thumbnails/1.jpg' + assert os.path.exists(source_path) + assert os.path.exists(generated_path) + with open(source_path, 'rb') as handle: assert handle.read() == read_asset('jpeg.jpg') + def test_update_post_thumbnail_to_default( tmpdir, config_injector, read_asset, post_factory): config_injector({ @@ -353,8 +388,10 @@ def test_update_post_thumbnail_to_default( posts.update_post_content(post, read_asset('png.png')) posts.update_post_thumbnail(post, read_asset('jpeg.jpg')) posts.update_post_thumbnail(post, None) - assert not os.path.exists(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat') assert os.path.exists(str(tmpdir) + '/data/generated-thumbnails/1.jpg') + assert not os.path.exists( + str(tmpdir) + '/data/posts/custom-thumbnails/1.dat') + def test_update_post_thumbnail_with_broken_thumbnail( tmpdir, config_injector, read_asset, post_factory): @@ -370,15 +407,18 @@ def test_update_post_thumbnail_with_broken_thumbnail( db.session.flush() posts.update_post_content(post, read_asset('png.png')) posts.update_post_thumbnail(post, read_asset('png-broken.png')) - assert os.path.exists(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat') - assert os.path.exists(str(tmpdir) + '/data/generated-thumbnails/1.jpg') - with open(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat', 'rb') as handle: + source_path = str(tmpdir) + '/data/posts/custom-thumbnails/1.dat' + generated_path = str(tmpdir) + '/data/generated-thumbnails/1.jpg' + assert os.path.exists(source_path) + assert os.path.exists(generated_path) + with open(source_path, 'rb') as handle: assert handle.read() == read_asset('png-broken.png') - with open(str(tmpdir) + '/data/generated-thumbnails/1.jpg', 'rb') as handle: + with open(generated_path, 'rb') as handle: image = images.Image(handle.read()) assert image.width == 1 assert image.height == 1 + def test_update_post_content_leaving_custom_thumbnail( tmpdir, config_injector, read_asset, post_factory): config_injector({ @@ -397,9 +437,10 @@ def test_update_post_content_leaving_custom_thumbnail( assert os.path.exists(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat') assert os.path.exists(str(tmpdir) + '/data/generated-thumbnails/1.jpg') + def test_update_post_tags(tag_factory): post = db.Post() - with unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'): + with patch('szurubooru.func.tags.get_or_create_tags_by_names'): tags.get_or_create_tags_by_names.side_effect \ = lambda tag_names: \ ([tag_factory(names=[name]) for name in tag_names], []) @@ -408,6 +449,7 @@ def test_update_post_tags(tag_factory): assert post.tags[0].names[0].name == 'tag1' assert post.tags[1].names[0].name == 'tag2' + def test_update_post_relations(post_factory): relation1 = post_factory() relation2 = post_factory() @@ -419,6 +461,7 @@ def test_update_post_relations(post_factory): assert post.relations[0].post_id == relation1.post_id assert post.relations[1].post_id == relation2.post_id + def test_update_post_relations_bidirectionality(post_factory): relation1 = post_factory() relation2 = post_factory() @@ -430,11 +473,13 @@ def test_update_post_relations_bidirectionality(post_factory): assert len(post.relations) == 1 assert post.relations[0].post_id == relation2.post_id + def test_update_post_relations_with_nonexisting_posts(): post = db.Post() with pytest.raises(posts.InvalidPostRelationError): posts.update_post_relations(post, [100]) + def test_update_post_notes(): post = db.Post() posts.update_post_notes( @@ -449,6 +494,7 @@ def test_update_post_notes(): assert post.notes[1].polygon == [[0, 0], [0, 1], [1, 0], [0, 0]] assert post.notes[1].text == 'text2' + @pytest.mark.parametrize('input', [ [{'text': '...'}], [{'polygon': None, 'text': '...'}], @@ -473,16 +519,19 @@ def test_update_post_notes_with_invalid_content(input): with pytest.raises(posts.InvalidPostNoteError): posts.update_post_notes(post, input) + def test_update_post_flags(): post = db.Post() posts.update_post_flags(post, ['loop']) assert post.flags == ['loop'] + def test_update_post_flags_with_invalid_content(): post = db.Post() with pytest.raises(posts.InvalidPostFlagError): posts.update_post_flags(post, ['invalid']) + def test_feature_post(post_factory, user_factory): post = post_factory() user = user_factory() @@ -492,6 +541,7 @@ def test_feature_post(post_factory, user_factory): assert previous_featured_post is None assert new_featured_post == post + def test_delete(post_factory): post = post_factory() db.session.add(post) diff --git a/server/szurubooru/tests/func/test_snapshots.py b/server/szurubooru/tests/func/test_snapshots.py index c223f93e..0ad17bad 100644 --- a/server/szurubooru/tests/func/test_snapshots.py +++ b/server/szurubooru/tests/func/test_snapshots.py @@ -1,8 +1,9 @@ -import datetime +from datetime import datetime import pytest from szurubooru import db from szurubooru.func import snapshots + def test_serializing_post(post_factory, user_factory, tag_factory): user = user_factory(name='dummy-user') tag1 = tag_factory(names=['dummy-tag1']) @@ -16,16 +17,16 @@ def test_serializing_post(post_factory, user_factory, tag_factory): score = db.PostScore() score.post = post score.user = user - score.time = datetime.datetime(1997, 1, 1) + score.time = datetime(1997, 1, 1) score.score = 1 favorite = db.PostFavorite() favorite.post = post favorite.user = user - favorite.time = datetime.datetime(1997, 1, 1) + favorite.time = datetime(1997, 1, 1) feature = db.PostFeature() feature.post = post feature.user = user - feature.time = datetime.datetime(1997, 1, 1) + feature.time = datetime(1997, 1, 1) note = db.PostNote() note.post = post note.polygon = [(1, 1), (200, 1), (200, 200), (1, 200)] @@ -88,6 +89,7 @@ def test_serializing_tag(tag_factory, tag_category_factory): 'suggestions': ['sug1_main_name', 'sug2_main_name'], } + def test_serializing_tag_category(tag_category_factory): category = tag_category_factory(name='name', color='color') assert snapshots.get_tag_category_snapshot(category) == { @@ -102,6 +104,7 @@ def test_serializing_tag_category(tag_category_factory): 'default': True, } + def test_merging_modification_to_creation(tag_factory, user_factory): tag = tag_factory(names=['dummy']) user = user_factory() @@ -115,6 +118,7 @@ def test_merging_modification_to_creation(tag_factory, user_factory): assert results[0].operation == db.Snapshot.OPERATION_CREATED assert results[0].data['names'] == ['changed'] + def test_merging_modifications(fake_datetime, tag_factory, user_factory): tag = tag_factory(names=['dummy']) user = user_factory() @@ -135,6 +139,7 @@ def test_merging_modifications(fake_datetime, tag_factory, user_factory): assert results[0].data['names'] == ['dummy'] assert results[1].data['names'] == ['changed again'] + def test_not_adding_snapshot_if_data_doesnt_change( fake_datetime, tag_factory, user_factory): tag = tag_factory(names=['dummy']) @@ -150,6 +155,7 @@ def test_not_adding_snapshot_if_data_doesnt_change( assert results[0].operation == db.Snapshot.OPERATION_CREATED assert results[0].data['names'] == ['dummy'] + def test_not_merging_due_to_time_difference( fake_datetime, tag_factory, user_factory): tag = tag_factory(names=['dummy']) @@ -163,6 +169,7 @@ def test_not_merging_due_to_time_difference( snapshots.save_entity_modification(tag, user) assert db.session.query(db.Snapshot).count() == 2 + def test_not_merging_operations_by_different_users( fake_datetime, tag_factory, user_factory): tag = tag_factory(names=['dummy']) @@ -175,6 +182,7 @@ def test_not_merging_operations_by_different_users( snapshots.save_entity_modification(tag, user2) assert db.session.query(db.Snapshot).count() == 2 + def test_merging_resets_merging_time_window( fake_datetime, tag_factory, user_factory): tag = tag_factory(names=['dummy']) @@ -193,6 +201,7 @@ def test_merging_resets_merging_time_window( assert len(results) == 1 assert results[0].data['names'] == ['changed again'] + @pytest.mark.parametrize( 'initial_operation', [snapshots.save_entity_creation, snapshots.save_entity_modification]) @@ -228,11 +237,9 @@ def test_merging_deletion_to_modification_or_creation( 'implications': [], } -@pytest.mark.parametrize( - 'expected_operation', - [snapshots.save_entity_creation, snapshots.save_entity_modification]) + def test_merging_deletion_all_the_way_deletes_all_snapshots( - fake_datetime, tag_factory, user_factory, expected_operation): + fake_datetime, tag_factory, user_factory): tag = tag_factory(names=['dummy']) user = user_factory() db.session.add_all([tag, user]) @@ -247,6 +254,7 @@ def test_merging_deletion_all_the_way_deletes_all_snapshots( snapshots.save_entity_deletion(tag, user) assert db.session.query(db.Snapshot).count() == 0 + def test_get_serialized_history( fake_datetime, tag_factory, tag_category_factory, user_factory): category = tag_category_factory(name='dummy') @@ -263,7 +271,7 @@ def test_get_serialized_history( assert snapshots.get_serialized_history(tag) == [ { 'operation': 'modified', - 'time': datetime.datetime(2016, 4, 19, 13, 10, 1), + 'time': datetime(2016, 4, 19, 13, 10, 1), 'type': 'tag', 'id': 'changed', 'user': 'the-user', @@ -282,7 +290,7 @@ def test_get_serialized_history( }, { 'operation': 'created', - 'time': datetime.datetime(2016, 4, 19, 13, 0, 0), + 'time': datetime(2016, 4, 19, 13, 0, 0), 'type': 'tag', 'id': 'dummy', 'user': 'the-user', diff --git a/server/szurubooru/tests/func/test_tag_categories.py b/server/szurubooru/tests/func/test_tag_categories.py index 1580afec..5b67c32b 100644 --- a/server/szurubooru/tests/func/test_tag_categories.py +++ b/server/szurubooru/tests/func/test_tag_categories.py @@ -1,64 +1,68 @@ -import os -import json +from unittest.mock import patch import pytest -import unittest.mock from szurubooru import db -from szurubooru.func import tags, tag_categories, cache, snapshots +from szurubooru.func import tag_categories, cache, snapshots + @pytest.fixture(autouse=True) -def purge_cache(config_injector): +def purge_cache(): cache.purge() + def test_serialize_category_when_empty(): assert tag_categories.serialize_category(None, None) is None -def test_serialize_category(tag_category_factory, tag_factory): - with unittest.mock.patch('szurubooru.func.snapshots.get_serialized_history'): - snapshots.get_serialized_history.return_value = 'snapshot history' +def test_serialize_category(tag_category_factory, tag_factory): + with patch('szurubooru.func.snapshots.get_serialized_history'): + snapshots.get_serialized_history.return_value = 'snapshot history' category = tag_category_factory(name='name', color='color') category.category_id = 1 category.default = True - tag1 = tag_factory(category=category) tag2 = tag_factory(category=category) - db.session.add_all([category, tag1, tag2]) db.session.flush() - result = tag_categories.serialize_category(category) + assert result == { + 'name': 'name', + 'color': 'color', + 'default': True, + 'version': 1, + 'snapshots': 'snapshot history', + 'usages': 2, + } - assert result == { - 'name': 'name', - 'color': 'color', - 'default': True, - 'version': 1, - 'snapshots': 'snapshot history', - 'usages': 2, - } def test_create_category_when_first(): - with unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \ - unittest.mock.patch('szurubooru.func.tag_categories.update_category_color'): + with patch('szurubooru.func.tag_categories.update_category_name'), \ + patch('szurubooru.func.tag_categories.update_category_color'): category = tag_categories.create_category('name', 'color') assert category.default - tag_categories.update_category_name.assert_called_once_with(category, 'name') - tag_categories.update_category_color.assert_called_once_with(category, 'color') + tag_categories.update_category_name \ + .assert_called_once_with(category, 'name') + tag_categories.update_category_color \ + .assert_called_once_with(category, 'color') + def test_create_category_when_subsequent(tag_category_factory): db.session.add(tag_category_factory()) - with unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \ - unittest.mock.patch('szurubooru.func.tag_categories.update_category_color'): + with patch('szurubooru.func.tag_categories.update_category_name'), \ + patch('szurubooru.func.tag_categories.update_category_color'): category = tag_categories.create_category('name', 'color') assert not category.default - tag_categories.update_category_name.assert_called_once_with(category, 'name') - tag_categories.update_category_color.assert_called_once_with(category, 'color') + tag_categories.update_category_name \ + .assert_called_once_with(category, 'name') + tag_categories.update_category_color \ + .assert_called_once_with(category, 'color') + def test_update_category_name_with_empty_string(tag_category_factory): category = tag_category_factory() with pytest.raises(tag_categories.InvalidTagCategoryNameError): tag_categories.update_category_name(category, None) + def test_update_category_name_with_invalid_name( config_injector, tag_category_factory): config_injector({'tag_category_name_regex': '^[a-z]+$'}) @@ -66,6 +70,7 @@ def test_update_category_name_with_invalid_name( with pytest.raises(tag_categories.InvalidTagCategoryNameError): tag_categories.update_category_name(category, '0') + def test_update_category_name_with_too_long_string( config_injector, tag_category_factory): config_injector({'tag_category_name_regex': '^[a-z]+$'}) @@ -73,6 +78,7 @@ def test_update_category_name_with_too_long_string( with pytest.raises(tag_categories.InvalidTagCategoryNameError): tag_categories.update_category_name(category, 'a' * 3000) + def test_update_category_name_reusing_other_name( config_injector, tag_category_factory): config_injector({'tag_category_name_regex': '.*'}) @@ -83,6 +89,7 @@ def test_update_category_name_reusing_other_name( with pytest.raises(tag_categories.TagCategoryAlreadyExistsError): tag_categories.update_category_name(category, 'NAME') + def test_update_category_name_reusing_own_name( config_injector, tag_category_factory): config_injector({'tag_category_name_regex': '.*'}) @@ -94,27 +101,32 @@ def test_update_category_name_reusing_own_name( assert category.name == name db.session.rollback() + def test_update_category_color_with_empty_string(tag_category_factory): category = tag_category_factory() with pytest.raises(tag_categories.InvalidTagCategoryColorError): tag_categories.update_category_color(category, None) + def test_update_category_color_with_too_long_string(tag_category_factory): category = tag_category_factory() with pytest.raises(tag_categories.InvalidTagCategoryColorError): tag_categories.update_category_color(category, 'a' * 3000) + def test_update_category_color_with_invalid_string(tag_category_factory): category = tag_category_factory() with pytest.raises(tag_categories.InvalidTagCategoryColorError): tag_categories.update_category_color(category, 'NOPE') + @pytest.mark.parametrize('attempt', ['#aaaaaa', '#012345', '012345', 'red']) def test_update_category_color(attempt, tag_category_factory): category = tag_category_factory() tag_categories.update_category_color(category, attempt) assert category.color == attempt + def test_try_get_category_by_name(tag_category_factory): category = tag_category_factory(name='test') db.session.add(category) @@ -122,6 +134,7 @@ def test_try_get_category_by_name(tag_category_factory): assert tag_categories.try_get_category_by_name('TEST') == category assert tag_categories.try_get_category_by_name('-') is None + def test_get_category_by_name(tag_category_factory): category = tag_category_factory(name='test') db.session.add(category) @@ -130,18 +143,21 @@ def test_get_category_by_name(tag_category_factory): with pytest.raises(tag_categories.TagCategoryNotFoundError): tag_categories.get_category_by_name('-') + def test_get_all_category_names(tag_category_factory): category1 = tag_category_factory(name='cat1') category2 = tag_category_factory(name='cat2') db.session.add_all([category1, category2]) assert tag_categories.get_all_category_names() == ['cat1', 'cat2'] + def test_get_all_categories(tag_category_factory): category1 = tag_category_factory(name='cat1') category2 = tag_category_factory(name='cat2') db.session.add_all([category1, category2]) assert tag_categories.get_all_categories() == [category1, category2] + def test_try_get_default_category_when_no_default(tag_category_factory): category1 = tag_category_factory(default=False) category2 = tag_category_factory(default=False) @@ -150,6 +166,7 @@ def test_try_get_default_category_when_no_default(tag_category_factory): assert actual_default_category == category1 assert actual_default_category != category2 + def test_try_get_default_category_when_default(tag_category_factory): category1 = tag_category_factory(default=False) category2 = tag_category_factory(default=True) @@ -158,6 +175,7 @@ def test_try_get_default_category_when_default(tag_category_factory): assert actual_default_category == category2 assert actual_default_category != category1 + def test_try_get_default_category_from_cache(tag_category_factory): category1 = tag_category_factory() category2 = tag_category_factory() @@ -168,13 +186,15 @@ def test_try_get_default_category_from_cache(tag_category_factory): cache.purge() assert tag_categories.try_get_default_category() is None -def test_get_default_category(tag_category_factory): - with unittest.mock.patch('szurubooru.func.tag_categories.try_get_default_category'): + +def test_get_default_category(): + with patch('szurubooru.func.tag_categories.try_get_default_category'): tag_categories.try_get_default_category.return_value = None with pytest.raises(tag_categories.TagCategoryNotFoundError): tag_categories.get_default_category() - tag_categories.try_get_default_category.return_value = 'returned category' - assert tag_categories.get_default_category() == 'returned category' + tag_categories.try_get_default_category.return_value = 'mocked' + assert tag_categories.get_default_category() == 'mocked' + def test_set_default_category_with_previous_default(tag_category_factory): category1 = tag_category_factory(default=True) @@ -184,6 +204,7 @@ def test_set_default_category_with_previous_default(tag_category_factory): assert not category1.default assert category2.default + def test_set_default_category_without_previous_default(tag_category_factory): category1 = tag_category_factory() category2 = tag_category_factory() @@ -191,12 +212,14 @@ def test_set_default_category_without_previous_default(tag_category_factory): tag_categories.set_default_category(category2) assert category2.default + def test_delete_category_with_no_other_categories(tag_category_factory): category = tag_category_factory() db.session.add(category) with pytest.raises(tag_categories.TagCategoryIsInUseError): tag_categories.delete_category(category) + def test_delete_category_with_usages(tag_category_factory, tag_factory): db.session.add(tag_category_factory()) category = tag_category_factory() @@ -204,6 +227,7 @@ def test_delete_category_with_usages(tag_category_factory, tag_factory): with pytest.raises(tag_categories.TagCategoryIsInUseError): tag_categories.delete_category(category) + def test_delete_category(tag_category_factory): db.session.add(tag_category_factory()) category = tag_category_factory(name='target') diff --git a/server/szurubooru/tests/func/test_tags.py b/server/szurubooru/tests/func/test_tags.py index 25dbc434..07c6723e 100644 --- a/server/szurubooru/tests/func/test_tags.py +++ b/server/szurubooru/tests/func/test_tags.py @@ -1,20 +1,22 @@ import os import json -import pytest -import unittest.mock +from unittest.mock import patch from datetime import datetime +import pytest from szurubooru import db from szurubooru.func import tags, tag_categories, cache, snapshots + @pytest.fixture(autouse=True) -def purge_cache(config_injector): +def purge_cache(): cache.purge() -def _assert_tag_siblings(result, expected_tag_names_and_occurrences): - actual_tag_names_and_occurences = [] - for sibling, occurrences in result: - actual_tag_names_and_occurences.append((sibling.names[0].name, occurrences)) - assert actual_tag_names_and_occurences == expected_tag_names_and_occurrences + +def _assert_tag_siblings(result, expected_names_and_occurrences): + actual_names_and_occurences = [ + (tag.names[0].name, occurrences) for tag, occurrences in result] + assert actual_names_and_occurences == expected_names_and_occurrences + @pytest.mark.parametrize('input,expected_tag_names', [ ([('a', 'a', True), ('b', 'b', False), ('c', 'c', False)], list('abc')), @@ -22,7 +24,8 @@ def _assert_tag_siblings(result, expected_tag_names_and_occurrences): ([('a', 'c', True), ('b', 'b', False), ('c', 'a', False)], list('acb')), ([('a', 'c', False), ('b', 'b', False), ('c', 'a', True)], list('cba')), ]) -def test_sort_tags(input, expected_tag_names, tag_factory, tag_category_factory): +def test_sort_tags( + input, expected_tag_names, tag_factory, tag_category_factory): db_tags = [] for tag in input: tag_name, category_name, category_is_default = tag @@ -35,45 +38,46 @@ def test_sort_tags(input, expected_tag_names, tag_factory, tag_category_factory) actual_tag_names = [tag.names[0].name for tag in tags.sort_tags(db_tags)] assert actual_tag_names == expected_tag_names + def test_serialize_tag_when_empty(): assert tags.serialize_tag(None, None) is None -def test_serialize_tag(post_factory, tag_factory, tag_category_factory): - with unittest.mock.patch('szurubooru.func.snapshots.get_serialized_history'): - snapshots.get_serialized_history.return_value = 'snapshot history' +def test_serialize_tag(post_factory, tag_factory, tag_category_factory): + with patch('szurubooru.func.snapshots.get_serialized_history'): + snapshots.get_serialized_history.return_value = 'snapshot history' tag = tag_factory( names=['tag1', 'tag2'], category=tag_category_factory(name='cat')) tag.tag_id = 1 tag.description = 'description' - tag.suggestions = [tag_factory(names=['sug1']), tag_factory(names=['sug2'])] - tag.implications = [tag_factory(names=['impl1']), tag_factory(names=['impl2'])] + tag.suggestions = [ + tag_factory(names=['sug1']), tag_factory(names=['sug2'])] + tag.implications = [ + tag_factory(names=['impl1']), tag_factory(names=['impl2'])] tag.last_edit_time = datetime(1998, 1, 1) - post1 = post_factory() post2 = post_factory() post1.tags = [tag] post2.tags = [tag] db.session.add_all([tag, post1, post2]) db.session.flush() - result = tags.serialize_tag(tag) result['suggestions'].sort() result['implications'].sort() + assert result == { + 'names': ['tag1', 'tag2'], + 'version': 1, + 'category': 'cat', + 'creationTime': datetime(1996, 1, 1, 0, 0), + 'lastEditTime': datetime(1998, 1, 1, 0, 0), + 'description': 'description', + 'suggestions': ['sug1', 'sug2'], + 'implications': ['impl1', 'impl2'], + 'usages': 2, + 'snapshots': 'snapshot history', + } - assert result == { - 'names': ['tag1', 'tag2'], - 'version': 1, - 'category': 'cat', - 'creationTime': datetime(1996, 1, 1, 0, 0), - 'lastEditTime': datetime(1998, 1, 1, 0, 0), - 'description': 'description', - 'suggestions': ['sug1', 'sug2'], - 'implications': ['impl1', 'impl2'], - 'usages': 2, - 'snapshots': 'snapshot history', - } def test_export_to_json( tmpdir, @@ -85,24 +89,18 @@ def test_export_to_json( config_injector({'data_dir': str(tmpdir)}) cat1 = tag_category_factory(name='cat1', color='black') cat2 = tag_category_factory(name='cat2', color='white') - db.session.add_all([cat1, cat2]) - db.session.flush() - sug1 = tag_factory(names=['sug1'], category=cat1) - sug2 = tag_factory(names=['sug2'], category=cat1) - imp1 = tag_factory(names=['imp1'], category=cat1) - imp2 = tag_factory(names=['imp2'], category=cat1) tag = tag_factory(names=['alias1', 'alias2'], category=cat2) - db.session.add_all([tag, sug1, sug2, imp1, imp2, cat1, cat2]) + tag.suggestions = [ + tag_factory(names=['sug1'], category=cat1), + tag_factory(names=['sug2'], category=cat1), + ] + tag.implications = [ + tag_factory(names=['imp1'], category=cat1), + tag_factory(names=['imp2'], category=cat1), + ] post = post_factory() post.tags = [tag] - db.session.flush() - db.session.add_all([ - post, - db.TagSuggestion(tag.tag_id, sug1.tag_id), - db.TagSuggestion(tag.tag_id, sug2.tag_id), - db.TagImplication(tag.tag_id, imp1.tag_id), - db.TagImplication(tag.tag_id, imp2.tag_id), - ]) + db.session.add_all([post, tag]) db.session.flush() with query_counter: @@ -127,11 +125,12 @@ def test_export_to_json( {'names': ['imp2'], 'usages': 0, 'category': 'cat1'}, ], 'categories': [ - {'name': 'cat1', 'color': 'black'}, {'name': 'cat2', 'color': 'white'}, + {'name': 'cat1', 'color': 'black'}, ] } + @pytest.mark.parametrize('name_to_search,expected_to_find', [ ('name', True), ('NAME', True), @@ -147,6 +146,7 @@ def test_try_get_tag_by_name(name_to_search, expected_to_find, tag_factory): else: assert tags.try_get_tag_by_name(name_to_search) is None + @pytest.mark.parametrize('name_to_search,expected_to_find', [ ('name', True), ('NAME', True), @@ -163,7 +163,8 @@ def test_get_tag_by_name(name_to_search, expected_to_find, tag_factory): with pytest.raises(tags.TagNotFoundError): tags.get_tag_by_name(name_to_search) -@pytest.mark.parametrize('names_to_search,expected_ids', [ + +@pytest.mark.parametrize('names,expected_ids', [ ([], []), (['name1'], [1]), (['NAME1'], [1]), @@ -178,17 +179,18 @@ def test_get_tag_by_name(name_to_search, expected_to_find, tag_factory): (['name1', 'ALIAS2'], [1, 2]), (['name2', 'alias1'], [1, 2]), ]) -def test_get_tag_by_names(names_to_search, expected_ids, tag_factory): +def test_get_tag_by_names(names, expected_ids, tag_factory): tag1 = tag_factory(names=['name1', 'ALIAS1']) tag2 = tag_factory(names=['name2', 'ALIAS2']) tag1.tag_id = 1 tag2.tag_id = 2 db.session.add_all([tag1, tag2]) - actual_ids = [tag.tag_id for tag in tags.get_tags_by_names(names_to_search)] + actual_ids = [tag.tag_id for tag in tags.get_tags_by_names(names)] assert actual_ids == expected_ids + @pytest.mark.parametrize( - 'names_to_search,expected_ids,expected_created_names', [ + 'names,expected_ids,expected_created_names', [ ([], [], []), (['name1'], [1], []), (['NAME1'], [1], []), @@ -220,7 +222,7 @@ def test_get_tag_by_names(names_to_search, expected_ids, tag_factory): (['new', 'new2'], [], ['new', 'new2']), ]) def test_get_or_create_tags_by_names( - names_to_search, + names, expected_ids, expected_created_names, tag_factory, @@ -231,18 +233,20 @@ def test_get_or_create_tags_by_names( tag1 = tag_factory(names=['name1', 'ALIAS1'], category=category) tag2 = tag_factory(names=['name2', 'ALIAS2'], category=category) db.session.add_all([tag1, tag2]) - result = tags.get_or_create_tags_by_names(names_to_search) + result = tags.get_or_create_tags_by_names(names) actual_ids = [tag.tag_id for tag in result[0]] actual_created_names = [tag.names[0].name for tag in result[1]] assert actual_ids == expected_ids assert actual_created_names == expected_created_names -def test_get_tag_siblings_for_unused(tag_factory, post_factory): + +def test_get_tag_siblings_for_unused(tag_factory): tag = tag_factory(names=['tag']) db.session.add(tag) db.session.flush() _assert_tag_siblings(tags.get_tag_siblings(tag), []) + def test_get_tag_siblings_for_used_alone(tag_factory, post_factory): tag = tag_factory(names=['tag']) post = post_factory() @@ -251,6 +255,7 @@ def test_get_tag_siblings_for_used_alone(tag_factory, post_factory): db.session.flush() _assert_tag_siblings(tags.get_tag_siblings(tag), []) + def test_get_tag_siblings_for_used_with_others(tag_factory, post_factory): tag1 = tag_factory(names=['t1']) tag2 = tag_factory(names=['t2']) @@ -261,6 +266,7 @@ def test_get_tag_siblings_for_used_with_others(tag_factory, post_factory): _assert_tag_siblings(tags.get_tag_siblings(tag1), [('t2', 1)]) _assert_tag_siblings(tags.get_tag_siblings(tag2), [('t1', 1)]) + def test_get_tag_siblings_used_for_multiple_others(tag_factory, post_factory): tag1 = tag_factory(names=['t1']) tag2 = tag_factory(names=['t2']) @@ -280,6 +286,7 @@ def test_get_tag_siblings_used_for_multiple_others(tag_factory, post_factory): # even though tag2 is used more widely, tag1 is more relevant to tag3 _assert_tag_siblings(tags.get_tag_siblings(tag3), [('t1', 2), ('t2', 1)]) + def test_delete(tag_factory): tag = tag_factory(names=['tag']) tag.suggestions = [tag_factory(names=['sug'])] @@ -290,6 +297,7 @@ def test_delete(tag_factory): tags.delete(tag) assert db.session.query(db.Tag).count() == 2 + def test_merge_tags_without_usages(tag_factory): source_tag = tag_factory(names=['source']) target_tag = tag_factory(names=['target']) @@ -301,6 +309,7 @@ def test_merge_tags_without_usages(tag_factory): tag = tags.get_tag_by_name('target') assert tag is not None + def test_merge_tags_with_usages(tag_factory, post_factory): source_tag = tag_factory(names=['source']) target_tag = tag_factory(names=['target']) @@ -315,13 +324,15 @@ def test_merge_tags_with_usages(tag_factory, post_factory): assert tags.try_get_tag_by_name('source') is None assert tags.get_tag_by_name('target').post_count == 1 -def test_merge_tags_with_itself(tag_factory, post_factory): + +def test_merge_tags_with_itself(tag_factory): source_tag = tag_factory(names=['source']) db.session.add(source_tag) db.session.commit() with pytest.raises(tags.InvalidTagRelationError): tags.merge_tags(source_tag, source_tag) + def test_merge_tags_with_its_child_relation(tag_factory, post_factory): source_tag = tag_factory(names=['source']) target_tag = tag_factory(names=['target']) @@ -336,6 +347,7 @@ def test_merge_tags_with_its_child_relation(tag_factory, post_factory): assert tags.try_get_tag_by_name('source') is None assert tags.get_tag_by_name('target').post_count == 1 + def test_merge_tags_with_its_parent_relation(tag_factory, post_factory): source_tag = tag_factory(names=['source']) target_tag = tag_factory(names=['target']) @@ -350,6 +362,7 @@ def test_merge_tags_with_its_parent_relation(tag_factory, post_factory): assert tags.try_get_tag_by_name('source') is None assert tags.get_tag_by_name('target').post_count == 1 + def test_merge_tags_clears_relations(tag_factory): source_tag = tag_factory(names=['source']) target_tag = tag_factory(names=['target']) @@ -366,6 +379,7 @@ def test_merge_tags_clears_relations(tag_factory): assert tags.try_get_tag_by_name('parent').implications == [] assert tags.try_get_tag_by_name('parent').suggestions == [] + def test_merge_tags_when_target_exists(tag_factory, post_factory): source_tag = tag_factory(names=['source']) target_tag = tag_factory(names=['target']) @@ -380,11 +394,12 @@ def test_merge_tags_when_target_exists(tag_factory, post_factory): assert tags.try_get_tag_by_name('source') is None assert tags.get_tag_by_name('target').post_count == 1 + def test_create_tag(fake_datetime): - with unittest.mock.patch('szurubooru.func.tags.update_tag_names'), \ - unittest.mock.patch('szurubooru.func.tags.update_tag_category_name'), \ - unittest.mock.patch('szurubooru.func.tags.update_tag_suggestions'), \ - unittest.mock.patch('szurubooru.func.tags.update_tag_implications'), \ + with patch('szurubooru.func.tags.update_tag_names'), \ + patch('szurubooru.func.tags.update_tag_category_name'), \ + patch('szurubooru.func.tags.update_tag_suggestions'), \ + patch('szurubooru.func.tags.update_tag_implications'), \ fake_datetime('1997-01-01'): tag = tags.create_tag(['name'], 'cat', ['sug'], ['imp']) assert tag.creation_time == datetime(1997, 1, 1) @@ -394,38 +409,45 @@ def test_create_tag(fake_datetime): tags.update_tag_suggestions.assert_called_once_with(tag, ['sug']) tags.update_tag_implications.assert_called_once_with(tag, ['imp']) + def test_update_tag_category_name(tag_factory): - with unittest.mock.patch('szurubooru.func.tag_categories.get_category_by_name'): - tag_categories.get_category_by_name.return_value = 'returned category' + with patch('szurubooru.func.tag_categories.get_category_by_name'): + tag_categories.get_category_by_name.return_value = 'mocked' tag = tag_factory() tags.update_tag_category_name(tag, 'cat') assert tag_categories.get_category_by_name.called_once_with('cat') - assert tag.category == 'returned category' + assert tag.category == 'mocked' + def test_update_tag_names_to_empty(tag_factory): tag = tag_factory() with pytest.raises(tags.InvalidTagNameError): tags.update_tag_names(tag, []) + def test_update_tag_names_with_invalid_name(config_injector, tag_factory): config_injector({'tag_name_regex': '^[a-z]*$'}) tag = tag_factory() with pytest.raises(tags.InvalidTagNameError): tags.update_tag_names(tag, ['0']) + def test_update_tag_names_with_too_long_string(config_injector, tag_factory): config_injector({'tag_name_regex': '^[a-z]*$'}) tag = tag_factory() with pytest.raises(tags.InvalidTagNameError): tags.update_tag_names(tag, ['a' * 300]) + def test_update_tag_names_with_duplicate_names(config_injector, tag_factory): config_injector({'tag_name_regex': '^[a-z]*$'}) tag = tag_factory() tags.update_tag_names(tag, ['a', 'A']) assert [tag_name.name for tag_name in tag.names] == ['a'] -def test_update_tag_names_trying_to_use_taken_name(config_injector, tag_factory): + +def test_update_tag_names_trying_to_use_taken_name( + config_injector, tag_factory): config_injector({'tag_name_regex': '^[a-zA-Z]*$'}) existing_tag = tag_factory(names=['a']) db.session.add(existing_tag) @@ -436,6 +458,7 @@ def test_update_tag_names_trying_to_use_taken_name(config_injector, tag_factory) with pytest.raises(tags.TagAlreadyExistsError): tags.update_tag_names(tag, ['A']) + def test_update_tag_names_reusing_own_name(config_injector, tag_factory): config_injector({'tag_name_regex': '^[a-zA-Z]*$'}) for name in list('aA'): @@ -446,32 +469,37 @@ def test_update_tag_names_reusing_own_name(config_injector, tag_factory): assert [tag_name.name for tag_name in tag.names] == [name] db.session.rollback() + @pytest.mark.parametrize('attempt', ['name', 'NAME', 'alias', 'ALIAS']) def test_update_tag_suggestions_with_itself(attempt, tag_factory): tag = tag_factory(names=['name', 'ALIAS']) with pytest.raises(tags.InvalidTagRelationError): tags.update_tag_suggestions(tag, [attempt]) + def test_update_tag_suggestions(tag_factory): tag = tag_factory(names=['name', 'ALIAS']) - with unittest.mock.patch('szurubooru.func.tags.get_tags_by_names'): + with patch('szurubooru.func.tags.get_tags_by_names'): tags.get_tags_by_names.return_value = ['returned tags'] tags.update_tag_suggestions(tag, ['test']) assert tag.suggestions == ['returned tags'] + @pytest.mark.parametrize('attempt', ['name', 'NAME', 'alias', 'ALIAS']) def test_update_tag_implications_with_itself(attempt, tag_factory): tag = tag_factory(names=['name', 'ALIAS']) with pytest.raises(tags.InvalidTagRelationError): tags.update_tag_implications(tag, [attempt]) + def test_update_tag_implications(tag_factory): tag = tag_factory(names=['name', 'ALIAS']) - with unittest.mock.patch('szurubooru.func.tags.get_tags_by_names'): + with patch('szurubooru.func.tags.get_tags_by_names'): tags.get_tags_by_names.return_value = ['returned tags'] tags.update_tag_implications(tag, ['test']) assert tag.implications == ['returned tags'] + def test_update_tag_description(tag_factory): tag = tag_factory() tags.update_tag_description(tag, 'test') diff --git a/server/szurubooru/tests/func/test_users.py b/server/szurubooru/tests/func/test_users.py index ea494f1a..0e6f264a 100644 --- a/server/szurubooru/tests/func/test_users.py +++ b/server/szurubooru/tests/func/test_users.py @@ -1,40 +1,46 @@ -import unittest.mock -import pytest +from unittest.mock import patch from datetime import datetime +import pytest from szurubooru import db, errors from szurubooru.func import auth, users, files, util + EMPTY_PIXEL = \ b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \ b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \ b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' + @pytest.mark.parametrize('user_name', ['test', 'TEST']) def test_get_avatar_path(user_name): assert users.get_avatar_path(user_name) == 'avatars/test.png' + @pytest.mark.parametrize('user_name,user_email,avatar_style,expected_url', [ ( 'user', - None , + None, db.User.AVATAR_GRAVATAR, - 'https://gravatar.com/avatar/ee11cbb19052e40b07aac0ca060c23ee?d=retro&s=100', + 'https://gravatar.com/avatar/' + + 'ee11cbb19052e40b07aac0ca060c23ee?d=retro&s=100', ), ( None, - 'user@example.com' , + 'user@example.com', db.User.AVATAR_GRAVATAR, - 'https://gravatar.com/avatar/b58996c504c5638798eb6b511e6f49af?d=retro&s=100', + 'https://gravatar.com/avatar/' + + 'b58996c504c5638798eb6b511e6f49af?d=retro&s=100', ), ( 'user', - 'user@example.com' , + 'user@example.com', db.User.AVATAR_GRAVATAR, - 'https://gravatar.com/avatar/b58996c504c5638798eb6b511e6f49af?d=retro&s=100', + 'https://gravatar.com/avatar/' + + 'b58996c504c5638798eb6b511e6f49af?d=retro&s=100', ), ( 'user', - None , + None, db.User.AVATAR_MANUAL, 'http://example.com/avatars/user.png', ), @@ -51,16 +57,22 @@ def test_get_avatar_url( user.avatar_style = avatar_style assert users.get_avatar_url(user) == expected_url + @pytest.mark.parametrize( - 'same_user,can_edit_any_email,force_show,expected_email', [ - (False, False, False, False), - (True, False, False, 'test@example.com'), - (False, True, False, 'test@example.com'), - (False, False, True, 'test@example.com'), -]) + 'same_user,can_edit_any_email,force_show,expected_email', + [ + (False, False, False, False), + (True, False, False, 'test@example.com'), + (False, True, False, 'test@example.com'), + (False, False, True, 'test@example.com'), + ]) def test_get_email( - same_user, can_edit_any_email, force_show, expected_email, user_factory): - with unittest.mock.patch('szurubooru.func.auth.has_privilege'): + same_user, + can_edit_any_email, + force_show, + expected_email, + user_factory): + with patch('szurubooru.func.auth.has_privilege'): auth.has_privilege = lambda user, name: can_edit_any_email user = user_factory() user.email = 'test@example.com' @@ -69,13 +81,15 @@ def test_get_email( db.session.flush() assert users.get_email(user, auth_user, force_show) == expected_email + @pytest.mark.parametrize( - 'same_user,score,expected_liked_post_count,expected_disliked_post_count', [ - (False, 1, False, False), - (False, -1, False, False), - (True, 1, 1, 0), - (True, -1, 0, 1), -]) + 'same_user,score,expected_liked_post_count,expected_disliked_post_count', + [ + (False, 1, False, False), + (False, -1, False, False), + (True, 1, 1, 0), + (True, -1, 0, 1), + ]) def test_get_liked_post_count( same_user, score, @@ -86,25 +100,29 @@ def test_get_liked_post_count( user = user_factory() post = post_factory() auth_user = user if same_user else user_factory() - score = db.PostScore(post=post, user=user, score=score, time=datetime.now()) + score = db.PostScore( + post=post, user=user, score=score, time=datetime.now()) db.session.add_all([post, user, score]) db.session.flush() - assert users.get_liked_post_count(user, auth_user) == expected_liked_post_count - assert users.get_disliked_post_count(user, auth_user) == expected_disliked_post_count + actual_liked_post_count = users.get_liked_post_count(user, auth_user) + actual_disliked_post_count = users.get_disliked_post_count(user, auth_user) + assert actual_liked_post_count == expected_liked_post_count + assert actual_disliked_post_count == expected_disliked_post_count + def test_serialize_user_when_empty(): assert users.serialize_user(None, None) is None + def test_serialize_user(user_factory): - with unittest.mock.patch('szurubooru.func.users.get_email'), \ - unittest.mock.patch('szurubooru.func.users.get_avatar_url'), \ - unittest.mock.patch('szurubooru.func.users.get_liked_post_count'), \ - unittest.mock.patch('szurubooru.func.users.get_disliked_post_count'): + with patch('szurubooru.func.users.get_email'), \ + patch('szurubooru.func.users.get_avatar_url'), \ + patch('szurubooru.func.users.get_liked_post_count'), \ + patch('szurubooru.func.users.get_disliked_post_count'): users.get_email.return_value = 'test@example.com' users.get_avatar_url.return_value = 'https://example.com/avatar.png' users.get_liked_post_count.return_value = 66 users.get_disliked_post_count.return_value = 33 - auth_user = user_factory() user = user_factory(name='dummy user') user.creation_time = datetime(1997, 1, 1) @@ -113,7 +131,6 @@ def test_serialize_user(user_factory): user.rank = db.User.RANK_ADMINISTRATOR db.session.add(user) db.session.flush() - assert users.serialize_user(user, auth_user) == { 'version': 1, 'name': 'dummy user', @@ -130,8 +147,9 @@ def test_serialize_user(user_factory): 'uploadedPostCount': 0, } + def test_serialize_micro_user(user_factory): - with unittest.mock.patch('szurubooru.func.users.get_avatar_url'): + with patch('szurubooru.func.users.get_avatar_url'): users.get_avatar_url.return_value = 'https://example.com/avatar.png' auth_user = user_factory() user = user_factory(name='dummy user') @@ -142,6 +160,7 @@ def test_serialize_micro_user(user_factory): 'avatarUrl': 'https://example.com/avatar.png', } + @pytest.mark.parametrize('count', [0, 1, 2]) def test_get_user_count(count, user_factory): for _ in range(count): @@ -149,6 +168,7 @@ def test_get_user_count(count, user_factory): db.session.flush() assert users.get_user_count() == count + def test_try_get_user_by_name(user_factory): user = user_factory(name='name', email='email') db.session.add(user) @@ -157,6 +177,7 @@ def test_try_get_user_by_name(user_factory): assert users.try_get_user_by_name('name') is user assert users.try_get_user_by_name('NAME') is user + def test_get_user_by_name(user_factory): user = user_factory(name='name', email='email') db.session.add(user) @@ -167,6 +188,7 @@ def test_get_user_by_name(user_factory): assert users.get_user_by_name('name') is user assert users.get_user_by_name('NAME') is user + def test_try_get_user_by_name_or_email(user_factory): user = user_factory(name='name', email='email') db.session.add(user) @@ -176,6 +198,7 @@ def test_try_get_user_by_name_or_email(user_factory): assert users.try_get_user_by_name_or_email('name') is user assert users.try_get_user_by_name_or_email('NAME') is user + def test_get_user_by_name_or_email(user_factory): user = user_factory(name='name', email='email') db.session.add(user) @@ -186,10 +209,11 @@ def test_get_user_by_name_or_email(user_factory): assert users.get_user_by_name_or_email('name') is user assert users.get_user_by_name_or_email('NAME') is user + def test_create_user_for_first_user(fake_datetime): - with unittest.mock.patch('szurubooru.func.users.update_user_name'), \ - unittest.mock.patch('szurubooru.func.users.update_user_password'), \ - unittest.mock.patch('szurubooru.func.users.update_user_email'), \ + with patch('szurubooru.func.users.update_user_name'), \ + patch('szurubooru.func.users.update_user_password'), \ + patch('szurubooru.func.users.update_user_email'), \ fake_datetime('1997-01-01'): user = users.create_user('name', 'password', 'email') assert user.creation_time == datetime(1997, 1, 1) @@ -199,32 +223,37 @@ def test_create_user_for_first_user(fake_datetime): users.update_user_password.assert_called_once_with(user, 'password') users.update_user_email.assert_called_once_with(user, 'email') + def test_create_user_for_subsequent_users(user_factory, config_injector): config_injector({'default_rank': 'regular'}) db.session.add(user_factory()) db.session.flush() - with unittest.mock.patch('szurubooru.func.users.update_user_name'), \ - unittest.mock.patch('szurubooru.func.users.update_user_password'), \ - unittest.mock.patch('szurubooru.func.users.update_user_email'): + with patch('szurubooru.func.users.update_user_name'), \ + patch('szurubooru.func.users.update_user_email'), \ + patch('szurubooru.func.users.update_user_password'): user = users.create_user('name', 'password', 'email') assert user.rank == db.User.RANK_REGULAR + def test_update_user_name_with_empty_string(user_factory): user = user_factory() with pytest.raises(users.InvalidUserNameError): users.update_user_name(user, None) + def test_update_user_name_with_too_long_string(user_factory): user = user_factory() with pytest.raises(users.InvalidUserNameError): users.update_user_name(user, 'a' * 300) + def test_update_user_name_with_invalid_name(user_factory, config_injector): config_injector({'user_name_regex': '^[a-z]+$'}) user = user_factory() with pytest.raises(users.InvalidUserNameError): users.update_user_name(user, '0') + def test_update_user_name_with_duplicate_name(user_factory, config_injector): config_injector({'user_name_regex': '^[a-z]+$'}) user = user_factory() @@ -234,88 +263,102 @@ def test_update_user_name_with_duplicate_name(user_factory, config_injector): with pytest.raises(users.UserAlreadyExistsError): users.update_user_name(user, 'dummy') + def test_update_user_name_reusing_own_name(user_factory, config_injector): config_injector({'user_name_regex': '^[a-z]+$'}) user = user_factory(name='dummy') db.session.add(user) db.session.flush() - with unittest.mock.patch('szurubooru.func.files.has'): + with patch('szurubooru.func.files.has'): files.has.return_value = False users.update_user_name(user, 'dummy') db.session.flush() assert users.try_get_user_by_name('dummy') is user + def test_update_user_name_for_new_user(user_factory, config_injector): config_injector({'user_name_regex': '^[a-z]+$'}) user = user_factory() - with unittest.mock.patch('szurubooru.func.files.has'): + with patch('szurubooru.func.files.has'): files.has.return_value = False users.update_user_name(user, 'dummy') assert user.name == 'dummy' + def test_update_user_name_moves_avatar(user_factory, config_injector): config_injector({'user_name_regex': '^[a-z]+$'}) user = user_factory(name='old') - with unittest.mock.patch('szurubooru.func.files.has'), \ - unittest.mock.patch('szurubooru.func.files.move'): + with patch('szurubooru.func.files.has'), \ + patch('szurubooru.func.files.move'): files.has.return_value = True users.update_user_name(user, 'new') - files.move.assert_called_once_with('avatars/old.png', 'avatars/new.png') + files.move.assert_called_once_with( + 'avatars/old.png', 'avatars/new.png') + def test_update_user_password_with_empty_string(user_factory): user = user_factory() with pytest.raises(users.InvalidPasswordError): users.update_user_password(user, None) -def test_update_user_password_with_invalid_string(user_factory, config_injector): + +def test_update_user_password_with_invalid_string( + user_factory, config_injector): config_injector({'password_regex': '^[a-z]+$'}) user = user_factory() with pytest.raises(users.InvalidPasswordError): users.update_user_password(user, '0') + def test_update_user_password(user_factory, config_injector): config_injector({'password_regex': '^[a-z]+$'}) user = user_factory() - with unittest.mock.patch('szurubooru.func.auth.create_password'), \ - unittest.mock.patch('szurubooru.func.auth.get_password_hash'): + with patch('szurubooru.func.auth.create_password'), \ + patch('szurubooru.func.auth.get_password_hash'): auth.create_password.return_value = 'salt' auth.get_password_hash.return_value = 'hash' users.update_user_password(user, 'a') assert user.password_salt == 'salt' assert user.password_hash == 'hash' + def test_update_user_email_with_too_long_string(user_factory): user = user_factory() with pytest.raises(users.InvalidEmailError): users.update_user_email(user, 'a' * 300) + def test_update_user_email_with_invalid_email(user_factory): user = user_factory() - with unittest.mock.patch('szurubooru.func.util.is_valid_email'): + with patch('szurubooru.func.util.is_valid_email'): util.is_valid_email.return_value = False with pytest.raises(users.InvalidEmailError): users.update_user_email(user, 'a') + def test_update_user_email_with_empty_string(user_factory): user = user_factory() - with unittest.mock.patch('szurubooru.func.util.is_valid_email'): + with patch('szurubooru.func.util.is_valid_email'): util.is_valid_email.return_value = True users.update_user_email(user, '') assert user.email is None + def test_update_user_email(user_factory): user = user_factory() - with unittest.mock.patch('szurubooru.func.util.is_valid_email'): + with patch('szurubooru.func.util.is_valid_email'): util.is_valid_email.return_value = True users.update_user_email(user, 'a') assert user.email == 'a' + def test_update_user_rank_with_empty_string(user_factory): user = user_factory() auth_user = user_factory() with pytest.raises(users.InvalidRankError): users.update_user_rank(user, '', auth_user) + def test_update_user_rank_with_invalid_string(user_factory): user = user_factory() auth_user = user_factory() @@ -326,6 +369,7 @@ def test_update_user_rank_with_invalid_string(user_factory): with pytest.raises(users.InvalidRankError): users.update_user_rank(user, 'nobody', auth_user) + def test_update_user_rank_with_higher_rank_than_possible(user_factory): db.session.add(user_factory()) user = user_factory() @@ -336,6 +380,7 @@ def test_update_user_rank_with_higher_rank_than_possible(user_factory): with pytest.raises(errors.AuthError): users.update_user_rank(auth_user, 'regular', auth_user) + def test_update_user_rank(user_factory): db.session.add(user_factory()) user = user_factory() @@ -346,46 +391,54 @@ def test_update_user_rank(user_factory): assert user.rank == db.User.RANK_REGULAR assert auth_user.rank == db.User.RANK_REGULAR + def test_update_user_avatar_with_invalid_style(user_factory): user = user_factory() with pytest.raises(users.InvalidAvatarError): users.update_user_avatar(user, 'invalid', b'') + def test_update_user_avatar_to_gravatar(user_factory): user = user_factory() users.update_user_avatar(user, 'gravatar') assert user.avatar_style == db.User.AVATAR_GRAVATAR + def test_update_user_avatar_to_empty_manual(user_factory): user = user_factory() - with unittest.mock.patch('szurubooru.func.files.has'), \ + with patch('szurubooru.func.files.has'), \ pytest.raises(users.InvalidAvatarError): files.has.return_value = False users.update_user_avatar(user, 'manual', b'') + def test_update_user_avatar_to_previous_manual(user_factory): user = user_factory() - with unittest.mock.patch('szurubooru.func.files.has'): + with patch('szurubooru.func.files.has'): files.has.return_value = True users.update_user_avatar(user, 'manual', b'') + def test_update_user_avatar_to_new_manual(user_factory, config_injector): - config_injector({'thumbnails': {'avatar_width': 500, 'avatar_height': 500}}) + config_injector( + {'thumbnails': {'avatar_width': 500, 'avatar_height': 500}}) user = user_factory() - with unittest.mock.patch('szurubooru.func.files.save'): + with patch('szurubooru.func.files.save'): users.update_user_avatar(user, 'manual', EMPTY_PIXEL) assert user.avatar_style == db.User.AVATAR_MANUAL assert files.save.called + def test_bump_user_login_time(user_factory, fake_datetime): user = user_factory() with fake_datetime('1997-01-01'): users.bump_user_login_time(user) assert user.last_login_time == datetime(1997, 1, 1) + def test_reset_user_password(user_factory): - with unittest.mock.patch('szurubooru.func.auth.create_password'), \ - unittest.mock.patch('szurubooru.func.auth.get_password_hash'): + with patch('szurubooru.func.auth.create_password'), \ + patch('szurubooru.func.auth.get_password_hash'): user = user_factory() auth.create_password.return_value = 'salt' auth.get_password_hash.return_value = 'hash' diff --git a/server/szurubooru/tests/func/test_util.py b/server/szurubooru/tests/func/test_util.py index 69255ad2..24fe4e44 100644 --- a/server/szurubooru/tests/func/test_util.py +++ b/server/szurubooru/tests/func/test_util.py @@ -1,29 +1,33 @@ +from datetime import datetime import pytest from szurubooru import errors from szurubooru.func import util -from datetime import datetime -dt = datetime + +dt = datetime # pylint: disable=invalid-name + def test_parsing_empty_date_time(): with pytest.raises(errors.ValidationError): util.parse_time_range('') -@pytest.mark.parametrize('input,output', [ - ('today', (dt(1997, 1, 2, 0, 0, 0), dt(1997, 1, 2, 23, 59, 59))), - ('yesterday', (dt(1997, 1, 1, 0, 0, 0), dt(1997, 1, 1, 23, 59, 59))), - ('1999', (dt(1999, 1, 1, 0, 0, 0), dt(1999, 12, 31, 23, 59, 59))), - ('1999-2', (dt(1999, 2, 1, 0, 0, 0), dt(1999, 2, 28, 23, 59, 59))), - ('1999-02', (dt(1999, 2, 1, 0, 0, 0), dt(1999, 2, 28, 23, 59, 59))), - ('1999-2-6', (dt(1999, 2, 6, 0, 0, 0), dt(1999, 2, 6, 23, 59, 59))), - ('1999-02-6', (dt(1999, 2, 6, 0, 0, 0), dt(1999, 2, 6, 23, 59, 59))), - ('1999-2-06', (dt(1999, 2, 6, 0, 0, 0), dt(1999, 2, 6, 23, 59, 59))), - ('1999-02-06', (dt(1999, 2, 6, 0, 0, 0), dt(1999, 2, 6, 23, 59, 59))), + +@pytest.mark.parametrize('output,input', [ + ((dt(1997, 1, 2, 0, 0, 0), dt(1997, 1, 2, 23, 59, 59)), 'today'), + ((dt(1997, 1, 1, 0, 0, 0), dt(1997, 1, 1, 23, 59, 59)), 'yesterday'), + ((dt(1999, 1, 1, 0, 0, 0), dt(1999, 12, 31, 23, 59, 59)), '1999'), + ((dt(1999, 2, 1, 0, 0, 0), dt(1999, 2, 28, 23, 59, 59)), '1999-2'), + ((dt(1999, 2, 1, 0, 0, 0), dt(1999, 2, 28, 23, 59, 59)), '1999-02'), + ((dt(1999, 2, 6, 0, 0, 0), dt(1999, 2, 6, 23, 59, 59)), '1999-2-6'), + ((dt(1999, 2, 6, 0, 0, 0), dt(1999, 2, 6, 23, 59, 59)), '1999-02-6'), + ((dt(1999, 2, 6, 0, 0, 0), dt(1999, 2, 6, 23, 59, 59)), '1999-2-06'), + ((dt(1999, 2, 6, 0, 0, 0), dt(1999, 2, 6, 23, 59, 59)), '1999-02-06'), ]) def test_parsing_date_time(fake_datetime, input, output): with fake_datetime('1997-01-02 03:04:05'): assert util.parse_time_range(input) == output + @pytest.mark.parametrize('input,output', [ ([], []), (['a', 'b', 'c'], ['a', 'b', 'c']), diff --git a/server/szurubooru/tests/rest/__init__.py b/server/szurubooru/tests/rest/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/szurubooru/tests/rest/test_context.py b/server/szurubooru/tests/rest/test_context.py index 6fca391f..7380a855 100644 --- a/server/szurubooru/tests/rest/test_context.py +++ b/server/szurubooru/tests/rest/test_context.py @@ -1,18 +1,22 @@ +# pylint: disable=unexpected-keyword-arg import unittest.mock import pytest from szurubooru import rest, errors from szurubooru.func import net + def test_has_param(): ctx = rest.Context(method=None, url=None, params={'key': 'value'}) assert ctx.has_param('key') assert not ctx.has_param('key2') + def test_get_file(): ctx = rest.Context(method=None, url=None, files={'key': b'content'}) assert ctx.get_file('key') == b'content' assert ctx.get_file('key2') is None + def test_get_file_from_url(): with unittest.mock.patch('szurubooru.func.net.download'): net.download.return_value = b'content' @@ -22,9 +26,10 @@ def test_get_file_from_url(): assert ctx.get_file('key2') is None net.download.assert_called_once_with('example.com') + def test_getting_list_parameter(): ctx = rest.Context( - method=None, url=None, params={'key': 'value', 'list': ['1', '2', '3']}) + method=None, url=None, params={'key': 'value', 'list': list('123')}) assert ctx.get_param_as_list('key') == ['value'] assert ctx.get_param_as_list('key2') is None assert ctx.get_param_as_list('key2', default=['def']) == ['def'] @@ -32,16 +37,18 @@ def test_getting_list_parameter(): with pytest.raises(errors.ValidationError): ctx.get_param_as_list('key2', required=True) + def test_getting_string_parameter(): ctx = rest.Context( - method=None, url=None, params={'key': 'value', 'list': ['1', '2', '3']}) + method=None, url=None, params={'key': 'value', 'list': list('123')}) assert ctx.get_param_as_string('key') == 'value' assert ctx.get_param_as_string('key2') is None assert ctx.get_param_as_string('key2', default='def') == 'def' - assert ctx.get_param_as_string('list') == '1,2,3' # falcon issue #749 + assert ctx.get_param_as_string('list') == '1,2,3' with pytest.raises(errors.ValidationError): ctx.get_param_as_string('key2', required=True) + def test_getting_int_parameter(): ctx = rest.Context( method=None, @@ -63,6 +70,7 @@ def test_getting_int_parameter(): assert ctx.get_param_as_int('key', max=50) == 50 ctx.get_param_as_int('key', max=49) + def test_getting_bool_parameter(): def test(value): ctx = rest.Context(method=None, url=None, params={'key': value}) diff --git a/server/szurubooru/tests/search/__init__.py b/server/szurubooru/tests/search/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/szurubooru/tests/search/configs/__init__.py b/server/szurubooru/tests/search/configs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/szurubooru/tests/search/configs/test_comment_search_config.py b/server/szurubooru/tests/search/configs/test_comment_search_config.py index 5a88a71d..39c1ed3f 100644 --- a/server/szurubooru/tests/search/configs/test_comment_search_config.py +++ b/server/szurubooru/tests/search/configs/test_comment_search_config.py @@ -1,11 +1,14 @@ -import datetime +# pylint: disable=redefined-outer-name +from datetime import datetime import pytest -from szurubooru import db, errors, search +from szurubooru import db, search + @pytest.fixture def executor(): return search.Executor(search.configs.CommentSearchConfig()) + @pytest.fixture def verify_unpaged(executor): def verify(input, expected_comment_text): @@ -16,6 +19,7 @@ def verify_unpaged(executor): assert actual_comment_text == expected_comment_text return verify + @pytest.mark.parametrize('input,expected_comment_text', [ ('creation-time:2014', ['t2', 't1']), ('creation-date:2014', ['t2', 't1']), @@ -25,12 +29,13 @@ def test_filter_by_creation_time( comment1 = comment_factory(text='t1') comment2 = comment_factory(text='t2') comment3 = comment_factory(text='t3') - comment1.creation_time = datetime.datetime(2014, 1, 1) - comment2.creation_time = datetime.datetime(2014, 6, 1) - comment3.creation_time = datetime.datetime(2015, 1, 1) + comment1.creation_time = datetime(2014, 1, 1) + comment2.creation_time = datetime(2014, 6, 1) + comment3.creation_time = datetime(2015, 1, 1) db.session.add_all([comment1, comment2, comment3]) verify_unpaged(input, expected_comment_text) + @pytest.mark.parametrize('input,expected_comment_text', [ ('text:t1', ['t1']), ('text:t2', ['t2']), @@ -44,57 +49,80 @@ def test_filter_by_text( db.session.add_all([comment1, comment2]) verify_unpaged(input, expected_comment_text) + @pytest.mark.parametrize('input,expected_comment_text', [ ('user:u1', ['t1']), ('user:u2', ['t2']), ('user:u1,u2', ['t2', 't1']), ]) def test_filter_by_user( - verify_unpaged, comment_factory, user_factory, input, expected_comment_text): + verify_unpaged, + comment_factory, + user_factory, + input, + expected_comment_text): db.session.add(comment_factory(text='t2', user=user_factory(name='u2'))) db.session.add(comment_factory(text='t1', user=user_factory(name='u1'))) verify_unpaged(input, expected_comment_text) + @pytest.mark.parametrize('input,expected_comment_text', [ ('post:1', ['t1']), ('post:2', ['t2']), ('post:1,2', ['t1', 't2']), ]) def test_filter_by_post( - verify_unpaged, comment_factory, post_factory, input, expected_comment_text): + verify_unpaged, + comment_factory, + post_factory, + input, + expected_comment_text): db.session.add(comment_factory(text='t1', post=post_factory(id=1))) db.session.add(comment_factory(text='t2', post=post_factory(id=2))) verify_unpaged(input, expected_comment_text) + @pytest.mark.parametrize('input,expected_comment_text', [ ('', ['t1', 't2']), ('t1', ['t1']), ('t2', ['t2']), ('t1,t2', ['t1', 't2']), ]) -def test_anonymous(verify_unpaged, comment_factory, input, expected_comment_text): +def test_anonymous( + verify_unpaged, comment_factory, input, expected_comment_text): db.session.add(comment_factory(text='t1')) db.session.add(comment_factory(text='t2')) verify_unpaged(input, expected_comment_text) + @pytest.mark.parametrize('input,expected_comment_text', [ ('sort:user', ['t1', 't2']), ]) def test_sort_by_user( - verify_unpaged, comment_factory, user_factory, input, expected_comment_text): + verify_unpaged, + comment_factory, + user_factory, + input, + expected_comment_text): db.session.add(comment_factory(text='t2', user=user_factory(name='u2'))) db.session.add(comment_factory(text='t1', user=user_factory(name='u1'))) verify_unpaged(input, expected_comment_text) + @pytest.mark.parametrize('input,expected_comment_text', [ ('sort:post', ['t2', 't1']), ]) def test_sort_by_post( - verify_unpaged, comment_factory, post_factory, input, expected_comment_text): + verify_unpaged, + comment_factory, + post_factory, + input, + expected_comment_text): db.session.add(comment_factory(text='t1', post=post_factory(id=1))) db.session.add(comment_factory(text='t2', post=post_factory(id=2))) verify_unpaged(input, expected_comment_text) + @pytest.mark.parametrize('input,expected_comment_text', [ ('', ['t3', 't2', 't1']), ('sort:creation-date', ['t3', 't2', 't1']), @@ -105,12 +133,13 @@ def test_sort_by_creation_time( comment1 = comment_factory(text='t1') comment2 = comment_factory(text='t2') comment3 = comment_factory(text='t3') - comment1.creation_time = datetime.datetime(1991, 1, 1) - comment2.creation_time = datetime.datetime(1991, 1, 2) - comment3.creation_time = datetime.datetime(1991, 1, 3) + comment1.creation_time = datetime(1991, 1, 1) + comment2.creation_time = datetime(1991, 1, 2) + comment3.creation_time = datetime(1991, 1, 3) db.session.add_all([comment3, comment1, comment2]) verify_unpaged(input, expected_comment_text) + @pytest.mark.parametrize('input,expected_comment_text', [ ('sort:last-edit-date', ['t3', 't2', 't1']), ('sort:last-edit-time', ['t3', 't2', 't1']), @@ -122,8 +151,8 @@ def test_sort_by_last_edit_time( comment1 = comment_factory(text='t1') comment2 = comment_factory(text='t2') comment3 = comment_factory(text='t3') - comment1.last_edit_time = datetime.datetime(1991, 1, 1) - comment2.last_edit_time = datetime.datetime(1991, 1, 2) - comment3.last_edit_time = datetime.datetime(1991, 1, 3) + comment1.last_edit_time = datetime(1991, 1, 1) + comment2.last_edit_time = datetime(1991, 1, 2) + comment3.last_edit_time = datetime(1991, 1, 3) db.session.add_all([comment3, comment1, comment2]) verify_unpaged(input, expected_comment_text) diff --git a/server/szurubooru/tests/search/configs/test_post_search_config.py b/server/szurubooru/tests/search/configs/test_post_search_config.py index abd1a7f8..0c87f993 100644 --- a/server/szurubooru/tests/search/configs/test_post_search_config.py +++ b/server/szurubooru/tests/search/configs/test_post_search_config.py @@ -1,43 +1,55 @@ -import datetime +# pylint: disable=redefined-outer-name +from datetime import datetime import pytest from szurubooru import db, errors, search + @pytest.fixture def fav_factory(user_factory): def factory(post, user=None): return db.PostFavorite( - post=post, user=user or user_factory(), time=datetime.datetime.utcnow()) + post=post, + user=user or user_factory(), + time=datetime.utcnow()) return factory + @pytest.fixture def score_factory(user_factory): def factory(post, user=None, score=1): return db.PostScore( post=post, user=user or user_factory(), - time=datetime.datetime.utcnow(), + time=datetime.utcnow(), score=score) return factory + @pytest.fixture def note_factory(): - def factory(post=None): + def factory(): return db.PostNote(polygon='...', text='...') return factory + @pytest.fixture def feature_factory(user_factory): def factory(post=None): if post: return db.PostFeature( - time=datetime.datetime.utcnow(), user=user_factory(), post=post) - return db.PostFeature(time=datetime.datetime.utcnow(), user=user_factory()) + time=datetime.utcnow(), + user=user_factory(), + post=post) + return db.PostFeature( + time=datetime.utcnow(), user=user_factory()) return factory + @pytest.fixture -def executor(user_factory): +def executor(): return search.Executor(search.configs.PostSearchConfig()) + @pytest.fixture def auth_executor(executor, user_factory): def wrapper(): @@ -48,6 +60,7 @@ def auth_executor(executor, user_factory): return auth_user return wrapper + @pytest.fixture def verify_unpaged(executor): def verify(input, expected_post_ids, test_order=False): @@ -61,6 +74,7 @@ def verify_unpaged(executor): assert actual_count == len(expected_post_ids) return verify + @pytest.mark.parametrize('input,expected_post_ids', [ ('id:1', [1]), ('id:3', [3]), @@ -73,6 +87,7 @@ def test_filter_by_id(verify_unpaged, post_factory, input, expected_post_ids): db.session.add_all([post1, post2, post3]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('tag:t1', [1]), ('tag:t2', [2]), @@ -86,13 +101,14 @@ def test_filter_by_tag( post2 = post_factory(id=2) post3 = post_factory(id=3) post4 = post_factory(id=4) - post1.tags=[tag_factory(names=['t1'])] - post2.tags=[tag_factory(names=['t2'])] - post3.tags=[tag_factory(names=['t3'])] - post4.tags=[tag_factory(names=['t4a', 't4b'])] + post1.tags = [tag_factory(names=['t1'])] + post2.tags = [tag_factory(names=['t2'])] + post3.tags = [tag_factory(names=['t3'])] + post4.tags = [tag_factory(names=['t4a', 't4b'])] db.session.add_all([post1, post2, post3, post4]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('score:1', [1]), ('score:3', [3]), @@ -107,12 +123,13 @@ def test_filter_by_score( db.session.add( db.PostScore( score=post.post_id, - time=datetime.datetime.utcnow(), + time=datetime.utcnow(), post=post, user=user_factory())) db.session.add_all([post1, post2, post3]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('uploader:u1', [1]), ('uploader:u3', [3]), @@ -135,6 +152,7 @@ def test_filter_by_uploader( db.session.add_all([post1, post2, post3]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('comment:u1', [1]), ('comment:u3', [3]), @@ -158,6 +176,7 @@ def test_filter_by_commenter( ]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('fav:u1', [1]), ('fav:u3', [3]), @@ -180,6 +199,7 @@ def test_filter_by_favorite( post1, post2, post3]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('tag-count:1', [1]), ('tag-count:3', [3]), @@ -190,19 +210,24 @@ def test_filter_by_tag_count( post1 = post_factory(id=1) post2 = post_factory(id=2) post3 = post_factory(id=3) - post1.tags=[tag_factory()] - post2.tags=[tag_factory(), tag_factory()] - post3.tags=[tag_factory(), tag_factory(), tag_factory()] + post1.tags = [tag_factory()] + post2.tags = [tag_factory(), tag_factory()] + post3.tags = [tag_factory(), tag_factory(), tag_factory()] db.session.add_all([post1, post2, post3]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('comment-count:1', [1]), ('comment-count:3', [3]), ('comment-count:1,3', [1, 3]), ]) def test_filter_by_comment_count( - verify_unpaged, post_factory, comment_factory, input, expected_post_ids): + verify_unpaged, + post_factory, + comment_factory, + input, + expected_post_ids): post1 = post_factory(id=1) post2 = post_factory(id=2) post3 = post_factory(id=3) @@ -216,6 +241,7 @@ def test_filter_by_comment_count( post1, post2, post3]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('fav-count:1', [1]), ('fav-count:3', [3]), @@ -236,6 +262,7 @@ def test_filter_by_favorite_count( post1, post2, post3]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('note-count:1', [1]), ('note-count:3', [3]), @@ -246,28 +273,34 @@ def test_filter_by_note_count( post1 = post_factory(id=1) post2 = post_factory(id=2) post3 = post_factory(id=3) - post1.notes=[note_factory()] - post2.notes=[note_factory(), note_factory()] - post3.notes=[note_factory(), note_factory(), note_factory()] + post1.notes = [note_factory()] + post2.notes = [note_factory(), note_factory()] + post3.notes = [note_factory(), note_factory(), note_factory()] db.session.add_all([post1, post2, post3]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('feature-count:1', [1]), ('feature-count:3', [3]), ('feature-count:1,3', [1, 3]), ]) def test_filter_by_feature_count( - verify_unpaged, post_factory, feature_factory, input, expected_post_ids): + verify_unpaged, + post_factory, + feature_factory, + input, + expected_post_ids): post1 = post_factory(id=1) post2 = post_factory(id=2) post3 = post_factory(id=3) - post1.features=[feature_factory()] - post2.features=[feature_factory(), feature_factory()] - post3.features=[feature_factory(), feature_factory(), feature_factory()] + post1.features = [feature_factory()] + post2.features = [feature_factory(), feature_factory()] + post3.features = [feature_factory(), feature_factory(), feature_factory()] db.session.add_all([post1, post2, post3]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('type:image', [1]), ('type:anim', [2]), @@ -278,7 +311,8 @@ def test_filter_by_feature_count( ('type:flash', [4]), ('type:swf', [4]), ]) -def test_filter_by_type(verify_unpaged, post_factory, input, expected_post_ids): +def test_filter_by_type( + verify_unpaged, post_factory, input, expected_post_ids): post1 = post_factory(id=1) post2 = post_factory(id=2) post3 = post_factory(id=3) @@ -290,13 +324,15 @@ def test_filter_by_type(verify_unpaged, post_factory, input, expected_post_ids): db.session.add_all([post1, post2, post3, post4]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('safety:safe', [1]), ('safety:sketchy', [2]), ('safety:questionable', [2]), ('safety:unsafe', [3]), ]) -def test_filter_by_safety(verify_unpaged, post_factory, input, expected_post_ids): +def test_filter_by_safety( + verify_unpaged, post_factory, input, expected_post_ids): post1 = post_factory(id=1) post2 = post_factory(id=2) post3 = post_factory(id=3) @@ -306,10 +342,11 @@ def test_filter_by_safety(verify_unpaged, post_factory, input, expected_post_ids db.session.add_all([post1, post2, post3]) verify_unpaged(input, expected_post_ids) + def test_filter_by_invalid_type(executor): with pytest.raises(errors.SearchError): - actual_count, actual_posts = executor.execute( - 'type:invalid', page=1, page_size=100) + executor.execute('type:invalid', page=1, page_size=100) + @pytest.mark.parametrize('input,expected_post_ids', [ ('file-size:100', [1]), @@ -327,6 +364,7 @@ def test_filter_by_file_size( db.session.add_all([post1, post2, post3]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('image-width:100', [1]), ('image-width:102', [3]), @@ -352,6 +390,7 @@ def test_filter_by_image_size( db.session.add_all([post1, post2, post3]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('creation-date:2014', [1]), ('creation-date:2016', [3]), @@ -371,12 +410,13 @@ def test_filter_by_creation_time( post1 = post_factory(id=1) post2 = post_factory(id=2) post3 = post_factory(id=3) - post1.creation_time = datetime.datetime(2014, 1, 1) - post2.creation_time = datetime.datetime(2015, 1, 1) - post3.creation_time = datetime.datetime(2016, 1, 1) + post1.creation_time = datetime(2014, 1, 1) + post2.creation_time = datetime(2015, 1, 1) + post3.creation_time = datetime(2016, 1, 1) db.session.add_all([post1, post2, post3]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('last-edit-date:2014', [1]), ('last-edit-date:2016', [3]), @@ -396,12 +436,13 @@ def test_filter_by_last_edit_time( post1 = post_factory(id=1) post2 = post_factory(id=2) post3 = post_factory(id=3) - post1.last_edit_time = datetime.datetime(2014, 1, 1) - post2.last_edit_time = datetime.datetime(2015, 1, 1) - post3.last_edit_time = datetime.datetime(2016, 1, 1) + post1.last_edit_time = datetime(2014, 1, 1) + post2.last_edit_time = datetime(2015, 1, 1) + post3.last_edit_time = datetime(2016, 1, 1) db.session.add_all([post1, post2, post3]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('comment-date:2014', [1]), ('comment-date:2016', [3]), @@ -411,19 +452,24 @@ def test_filter_by_last_edit_time( ('comment-time:2014,2016', [1, 3]), ]) def test_filter_by_comment_date( - verify_unpaged, post_factory, comment_factory, input, expected_post_ids): + verify_unpaged, + post_factory, + comment_factory, + input, + expected_post_ids): post1 = post_factory(id=1) post2 = post_factory(id=2) post3 = post_factory(id=3) comment1 = comment_factory(post=post1) comment2 = comment_factory(post=post2) comment3 = comment_factory(post=post3) - comment1.creation_time = datetime.datetime(2014, 1, 1) - comment2.creation_time = datetime.datetime(2015, 1, 1) - comment3.creation_time = datetime.datetime(2016, 1, 1) + comment1.creation_time = datetime(2014, 1, 1) + comment2.creation_time = datetime(2015, 1, 1) + comment3.creation_time = datetime(2016, 1, 1) db.session.add_all([post1, post2, post3, comment1, comment2, comment3]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('fav-date:2014', [1]), ('fav-date:2016', [3]), @@ -440,12 +486,13 @@ def test_filter_by_fav_date( fav1 = fav_factory(post=post1) fav2 = fav_factory(post=post2) fav3 = fav_factory(post=post3) - fav1.time = datetime.datetime(2014, 1, 1) - fav2.time = datetime.datetime(2015, 1, 1) - fav3.time = datetime.datetime(2016, 1, 1) + fav1.time = datetime(2014, 1, 1) + fav2.time = datetime(2015, 1, 1) + fav3.time = datetime(2016, 1, 1) db.session.add_all([post1, post2, post3, fav1, fav2, fav3]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input,expected_post_ids', [ ('feature-date:2014', [1]), ('feature-date:2016', [3]), @@ -455,19 +502,24 @@ def test_filter_by_fav_date( ('feature-time:2014,2016', [1, 3]), ]) def test_filter_by_feature_date( - verify_unpaged, post_factory, feature_factory, input, expected_post_ids): + verify_unpaged, + post_factory, + feature_factory, + input, + expected_post_ids): post1 = post_factory(id=1) post2 = post_factory(id=2) post3 = post_factory(id=3) feature1 = feature_factory(post=post1) feature2 = feature_factory(post=post2) feature3 = feature_factory(post=post3) - feature1.time = datetime.datetime(2014, 1, 1) - feature2.time = datetime.datetime(2015, 1, 1) - feature3.time = datetime.datetime(2016, 1, 1) + feature1.time = datetime(2014, 1, 1) + feature2.time = datetime(2015, 1, 1) + feature3.time = datetime(2016, 1, 1) db.session.add_all([post1, post2, post3, feature1, feature2, feature3]) verify_unpaged(input, expected_post_ids) + @pytest.mark.parametrize('input', [ 'sort:random', 'sort:id', @@ -506,6 +558,7 @@ def test_sort_tokens(verify_unpaged, post_factory, input): db.session.add_all([post1, post2, post3]) verify_unpaged(input, [1, 2, 3]) + @pytest.mark.parametrize('input,expected_post_ids', [ ('', [1, 2, 3, 4]), ('t1', [1]), @@ -520,54 +573,69 @@ def test_anonymous( post2 = post_factory(id=2) post3 = post_factory(id=3) post4 = post_factory(id=4) - post1.tags=[tag_factory(names=['t1'])] - post2.tags=[tag_factory(names=['t2'])] - post3.tags=[tag_factory(names=['t3'])] - post4.tags=[tag_factory(names=['t4a', 't4b'])] + post1.tags = [tag_factory(names=['t1'])] + post2.tags = [tag_factory(names=['t2'])] + post3.tags = [tag_factory(names=['t3'])] + post4.tags = [tag_factory(names=['t4a', 't4b'])] db.session.add_all([post1, post2, post3, post4]) verify_unpaged(input, expected_post_ids) + def test_own_liked( - auth_executor, post_factory, score_factory, user_factory, verify_unpaged): + auth_executor, + post_factory, + score_factory, + user_factory, + verify_unpaged): auth_user = auth_executor() post1 = post_factory(id=1) post2 = post_factory(id=2) post3 = post_factory(id=3) db.session.add_all([ score_factory(post=post1, user=auth_user, score=1), - score_factory(post=post2, user=user_factory(name='unrelated'), score=1), + score_factory(post=post2, user=user_factory(name='dummy'), score=1), score_factory(post=post3, user=auth_user, score=-1), post1, post2, post3, ]) verify_unpaged('special:liked', [1]) verify_unpaged('-special:liked', [2, 3]) + def test_own_disliked( - auth_executor, post_factory, score_factory, user_factory, verify_unpaged): + auth_executor, + post_factory, + score_factory, + user_factory, + verify_unpaged): auth_user = auth_executor() post1 = post_factory(id=1) post2 = post_factory(id=2) post3 = post_factory(id=3) db.session.add_all([ score_factory(post=post1, user=auth_user, score=-1), - score_factory(post=post2, user=user_factory(name='unrelated'), score=-1), + score_factory(post=post2, user=user_factory(name='dummy'), score=-1), score_factory(post=post3, user=auth_user, score=1), post1, post2, post3, ]) verify_unpaged('special:disliked', [1]) verify_unpaged('-special:disliked', [2, 3]) + @pytest.mark.parametrize('input', [ 'liked:x', 'disliked:x', ]) def test_someones_score(executor, input): with pytest.raises(errors.SearchError): - actual_count, actual_posts = executor.execute( - input, page=1, page_size=100) + executor.execute(input, page=1, page_size=100) + def test_own_fav( - auth_executor, post_factory, fav_factory, user_factory, verify_unpaged): + auth_executor, + post_factory, + fav_factory, + user_factory, + verify_unpaged): auth_user = auth_executor() post1 = post_factory(id=1) post2 = post_factory(id=2) @@ -579,8 +647,8 @@ def test_own_fav( verify_unpaged('special:fav', [1]) verify_unpaged('-special:fav', [2]) + def test_tumbleweed( - executor, post_factory, fav_factory, comment_factory, diff --git a/server/szurubooru/tests/search/configs/test_tag_search_config.py b/server/szurubooru/tests/search/configs/test_tag_search_config.py index 5b3002f6..9ba807d2 100644 --- a/server/szurubooru/tests/search/configs/test_tag_search_config.py +++ b/server/szurubooru/tests/search/configs/test_tag_search_config.py @@ -1,11 +1,14 @@ -import datetime +# pylint: disable=redefined-outer-name +from datetime import datetime import pytest from szurubooru import db, errors, search + @pytest.fixture def executor(): return search.Executor(search.configs.TagSearchConfig()) + @pytest.fixture def verify_unpaged(executor): def verify(input, expected_tag_names): @@ -16,6 +19,7 @@ def verify_unpaged(executor): assert actual_tag_names == expected_tag_names return verify + @pytest.mark.parametrize('input,expected_tag_names', [ ('', ['t1', 't2']), ('t1', ['t1']), @@ -23,15 +27,18 @@ def verify_unpaged(executor): ('t1,t2', ['t1', 't2']), ('T1,T2', ['t1', 't2']), ]) -def test_filter_anonymous(verify_unpaged, tag_factory, input, expected_tag_names): +def test_filter_anonymous( + verify_unpaged, tag_factory, input, expected_tag_names): db.session.add(tag_factory(names=['t1'])) db.session.add(tag_factory(names=['t2'])) verify_unpaged(input, expected_tag_names) + def test_filter_anonymous_starting_with_colon(verify_unpaged, tag_factory): db.session.add(tag_factory(names=[':t'])) verify_unpaged(':t', [':t']) + @pytest.mark.parametrize('input,expected_tag_names', [ ('name:tag1', ['tag1']), ('name:tag2', ['tag2']), @@ -53,13 +60,15 @@ def test_filter_anonymous_starting_with_colon(verify_unpaged, tag_factory): ('name:tag5', ['tag4']), ('name:tag4,tag5', ['tag4']), ]) -def test_filter_by_name(verify_unpaged, tag_factory, input, expected_tag_names): +def test_filter_by_name( + verify_unpaged, tag_factory, input, expected_tag_names): db.session.add(tag_factory(names=['tag1'])) db.session.add(tag_factory(names=['tag2'])) db.session.add(tag_factory(names=['tag3'])) db.session.add(tag_factory(names=['tag4', 'tag5', 'tag6'])) verify_unpaged(input, expected_tag_names) + @pytest.mark.parametrize('input,expected_tag_names', [ ('category:cat1', ['t1', 't2']), ('category:cat2', ['t3']), @@ -79,6 +88,7 @@ def test_filter_by_category( db.session.add_all([tag1, tag2, tag3]) verify_unpaged(input, expected_tag_names) + @pytest.mark.parametrize('input,expected_tag_names', [ ('creation-time:2014', ['t1', 't2']), ('creation-date:2014', ['t1', 't2']), @@ -106,12 +116,13 @@ def test_filter_by_creation_time( tag1 = tag_factory(names=['t1']) tag2 = tag_factory(names=['t2']) tag3 = tag_factory(names=['t3']) - tag1.creation_time = datetime.datetime(2014, 1, 1) - tag2.creation_time = datetime.datetime(2014, 6, 1) - tag3.creation_time = datetime.datetime(2015, 1, 1) + tag1.creation_time = datetime(2014, 1, 1) + tag2.creation_time = datetime(2014, 6, 1) + tag3.creation_time = datetime(2015, 1, 1) db.session.add_all([tag1, tag2, tag3]) verify_unpaged(input, expected_tag_names) + @pytest.mark.parametrize('input,expected_tag_names', [ ('last-edit-date:2014', ['t1', 't3']), ('last-edit-time:2014', ['t1', 't3']), @@ -123,12 +134,13 @@ def test_filter_by_edit_time( tag1 = tag_factory(names=['t1']) tag2 = tag_factory(names=['t2']) tag3 = tag_factory(names=['t3']) - tag1.last_edit_time = datetime.datetime(2014, 1, 1) - tag2.last_edit_time = datetime.datetime(2015, 1, 1) - tag3.last_edit_time = datetime.datetime(2014, 1, 1) + tag1.last_edit_time = datetime(2014, 1, 1) + tag2.last_edit_time = datetime(2015, 1, 1) + tag3.last_edit_time = datetime(2014, 1, 1) db.session.add_all([tag1, tag2, tag3]) verify_unpaged(input, expected_tag_names) + @pytest.mark.parametrize('input,expected_tag_names', [ ('post-count:2', ['t1']), ('post-count:1', ['t2']), @@ -154,6 +166,7 @@ def test_filter_by_post_count( post2.tags.append(tag1) verify_unpaged(input, expected_tag_names) + @pytest.mark.parametrize('input', [ 'post-count:..', 'post-count:asd', @@ -164,8 +177,8 @@ def test_filter_by_post_count( ]) def test_filter_by_invalid_input(executor, input): with pytest.raises(errors.SearchError): - actual_count, actual_posts = executor.execute( - input, page=1, page_size=100) + executor.execute(input, page=1, page_size=100) + @pytest.mark.parametrize('input,expected_tag_names', [ ('suggestion-count:2', ['t1']), @@ -186,6 +199,7 @@ def test_filter_by_suggestion_count( tag2.suggestions.append(sug3) verify_unpaged(input, expected_tag_names) + @pytest.mark.parametrize('input,expected_tag_names', [ ('implication-count:2', ['t1']), ('implication-count:1', ['t2']), @@ -205,6 +219,7 @@ def test_filter_by_implication_count( tag2.implications.append(sug3) verify_unpaged(input, expected_tag_names) + @pytest.mark.parametrize('input,expected_tag_names', [ ('', ['t1', 't2']), ('sort:name', ['t1', 't2']), @@ -219,6 +234,7 @@ def test_sort_by_name(verify_unpaged, tag_factory, input, expected_tag_names): db.session.add(tag_factory(names=['t1'])) verify_unpaged(input, expected_tag_names) + @pytest.mark.parametrize('input,expected_tag_names', [ ('', ['t1', 't2', 't3']), ('sort:creation-date', ['t3', 't2', 't1']), @@ -229,12 +245,13 @@ def test_sort_by_creation_time( tag1 = tag_factory(names=['t1']) tag2 = tag_factory(names=['t2']) tag3 = tag_factory(names=['t3']) - tag1.creation_time = datetime.datetime(1991, 1, 1) - tag2.creation_time = datetime.datetime(1991, 1, 2) - tag3.creation_time = datetime.datetime(1991, 1, 3) + tag1.creation_time = datetime(1991, 1, 1) + tag2.creation_time = datetime(1991, 1, 2) + tag3.creation_time = datetime(1991, 1, 3) db.session.add_all([tag3, tag1, tag2]) verify_unpaged(input, expected_tag_names) + @pytest.mark.parametrize('input,expected_tag_names', [ ('', ['t1', 't2', 't3']), ('sort:last-edit-date', ['t3', 't2', 't1']), @@ -247,12 +264,13 @@ def test_sort_by_last_edit_time( tag1 = tag_factory(names=['t1']) tag2 = tag_factory(names=['t2']) tag3 = tag_factory(names=['t3']) - tag1.last_edit_time = datetime.datetime(1991, 1, 1) - tag2.last_edit_time = datetime.datetime(1991, 1, 2) - tag3.last_edit_time = datetime.datetime(1991, 1, 3) + tag1.last_edit_time = datetime(1991, 1, 1) + tag2.last_edit_time = datetime(1991, 1, 2) + tag3.last_edit_time = datetime(1991, 1, 3) db.session.add_all([tag3, tag1, tag2]) verify_unpaged(input, expected_tag_names) + @pytest.mark.parametrize('input,expected_tag_names', [ ('sort:post-count', ['t2', 't1']), ('sort:usage-count', ['t2', 't1']), @@ -271,6 +289,7 @@ def test_sort_by_post_count( post2.tags.append(tag2) verify_unpaged(input, expected_tag_names) + @pytest.mark.parametrize('input,expected_tag_names', [ ('sort:suggestion-count', ['t1', 't2', 'sug1', 'sug2', 'sug3']), ]) @@ -288,6 +307,7 @@ def test_sort_by_suggestion_count( tag2.suggestions.append(sug3) verify_unpaged(input, expected_tag_names) + @pytest.mark.parametrize('input,expected_tag_names', [ ('sort:implication-count', ['t1', 't2', 'sug1', 'sug2', 'sug3']), ]) @@ -305,6 +325,7 @@ def test_sort_by_implication_count( tag2.implications.append(sug3) verify_unpaged(input, expected_tag_names) + @pytest.mark.parametrize('input,expected_tag_names', [ ('sort:category', ['t3', 't1', 't2']), ]) diff --git a/server/szurubooru/tests/search/configs/test_user_search_config.py b/server/szurubooru/tests/search/configs/test_user_search_config.py index 7a2c4503..0c61402c 100644 --- a/server/szurubooru/tests/search/configs/test_user_search_config.py +++ b/server/szurubooru/tests/search/configs/test_user_search_config.py @@ -1,11 +1,14 @@ -import datetime +# pylint: disable=redefined-outer-name +from datetime import datetime import pytest from szurubooru import db, errors, search + @pytest.fixture def executor(): return search.Executor(search.configs.UserSearchConfig()) + @pytest.fixture def verify_unpaged(executor): def verify(input, expected_user_names): @@ -16,6 +19,7 @@ def verify_unpaged(executor): assert actual_user_names == expected_user_names return verify + @pytest.mark.parametrize('input,expected_user_names', [ ('creation-time:2014', ['u1', 'u2']), ('creation-date:2014', ['u1', 'u2']), @@ -45,12 +49,13 @@ def test_filter_by_creation_time( user1 = user_factory(name='u1') user2 = user_factory(name='u2') user3 = user_factory(name='u3') - user1.creation_time = datetime.datetime(2014, 1, 1) - user2.creation_time = datetime.datetime(2014, 6, 1) - user3.creation_time = datetime.datetime(2015, 1, 1) + user1.creation_time = datetime(2014, 1, 1) + user2.creation_time = datetime(2014, 6, 1) + user3.creation_time = datetime(2015, 1, 1) db.session.add_all([user1, user2, user3]) verify_unpaged(input, expected_user_names) + @pytest.mark.parametrize('input,expected_user_names', [ ('name:user1', ['user1']), ('name:user2', ['user2']), @@ -76,6 +81,7 @@ def test_filter_by_name( db.session.add(user_factory(name='user3')) verify_unpaged(input, expected_user_names) + @pytest.mark.parametrize('input,expected_user_names', [ ('', ['u1', 'u2']), ('u1', ['u1']), @@ -88,6 +94,7 @@ def test_anonymous( db.session.add(user_factory(name='u2')) verify_unpaged(input, expected_user_names) + @pytest.mark.parametrize('input,expected_user_names', [ ('creation-time:2014 u1', ['u1']), ('creation-time:2014 u2', ['u2']), @@ -98,12 +105,13 @@ def test_combining_tokens( user1 = user_factory(name='u1') user2 = user_factory(name='u2') user3 = user_factory(name='u3') - user1.creation_time = datetime.datetime(2014, 1, 1) - user2.creation_time = datetime.datetime(2014, 6, 1) - user3.creation_time = datetime.datetime(2015, 1, 1) + user1.creation_time = datetime(2014, 1, 1) + user2.creation_time = datetime(2014, 6, 1) + user3.creation_time = datetime(2015, 1, 1) db.session.add_all([user1, user2, user3]) verify_unpaged(input, expected_user_names) + @pytest.mark.parametrize( 'page,page_size,expected_total_count,expected_user_names', [ (1, 1, 2, ['u1']), @@ -123,6 +131,7 @@ def test_paging( assert actual_count == expected_total_count assert actual_user_names == expected_user_names + @pytest.mark.parametrize('input,expected_user_names', [ ('', ['u1', 'u2']), ('sort:name', ['u1', 'u2']), @@ -138,6 +147,7 @@ def test_sort_by_name( db.session.add(user_factory(name='u1')) verify_unpaged(input, expected_user_names) + @pytest.mark.parametrize('input,expected_user_names', [ ('', ['u1', 'u2', 'u3']), ('sort:creation-date', ['u3', 'u2', 'u1']), @@ -153,12 +163,13 @@ def test_sort_by_creation_time( user1 = user_factory(name='u1') user2 = user_factory(name='u2') user3 = user_factory(name='u3') - user1.creation_time = datetime.datetime(1991, 1, 1) - user2.creation_time = datetime.datetime(1991, 1, 2) - user3.creation_time = datetime.datetime(1991, 1, 3) + user1.creation_time = datetime(1991, 1, 1) + user2.creation_time = datetime(1991, 1, 2) + user3.creation_time = datetime(1991, 1, 3) db.session.add_all([user3, user1, user2]) verify_unpaged(input, expected_user_names) + @pytest.mark.parametrize('input,expected_user_names', [ ('', ['u1', 'u2', 'u3']), ('sort:last-login-date', ['u3', 'u2', 'u1']), @@ -171,12 +182,13 @@ def test_sort_by_last_login_time( user1 = user_factory(name='u1') user2 = user_factory(name='u2') user3 = user_factory(name='u3') - user1.last_login_time = datetime.datetime(1991, 1, 1) - user2.last_login_time = datetime.datetime(1991, 1, 2) - user3.last_login_time = datetime.datetime(1991, 1, 3) + user1.last_login_time = datetime(1991, 1, 1) + user2.last_login_time = datetime(1991, 1, 2) + user3.last_login_time = datetime(1991, 1, 3) db.session.add_all([user3, user1, user2]) verify_unpaged(input, expected_user_names) + def test_random_sort(executor, user_factory): user1 = user_factory(name='u1') user2 = user_factory(name='u2') @@ -191,6 +203,7 @@ def test_random_sort(executor, user_factory): assert 'u2' in actual_user_names assert 'u3' in actual_user_names + @pytest.mark.parametrize('input,expected_error', [ ('creation-date:..', errors.SearchError), ('creation-date-min:..', errors.ValidationError), diff --git a/server/szurubooru/tests/search/test_executor.py b/server/szurubooru/tests/search/test_executor.py index 99e085f2..e1b2dacb 100644 --- a/server/szurubooru/tests/search/test_executor.py +++ b/server/szurubooru/tests/search/test_executor.py @@ -3,7 +3,8 @@ import pytest from szurubooru import search from szurubooru.func import cache -def test_retrieving_from_cache(user_factory): + +def test_retrieving_from_cache(): config = unittest.mock.MagicMock() with unittest.mock.patch('szurubooru.func.cache.has'), \ unittest.mock.patch('szurubooru.func.cache.get'): @@ -12,13 +13,16 @@ def test_retrieving_from_cache(user_factory): executor.execute('test:whatever', 1, 10) assert cache.get.called + def test_putting_equivalent_queries_into_cache(): config = search.configs.PostSearchConfig() with unittest.mock.patch('szurubooru.func.cache.has'), \ unittest.mock.patch('szurubooru.func.cache.put'): hashes = [] - def appender(key, value): + + def appender(key, _value): hashes.append(key) + cache.has.side_effect = lambda *args: False cache.put.side_effect = appender executor = search.Executor(config) @@ -31,13 +35,16 @@ def test_putting_equivalent_queries_into_cache(): assert len(hashes) == 6 assert len(set(hashes)) == 1 + def test_putting_non_equivalent_queries_into_cache(): config = search.configs.PostSearchConfig() with unittest.mock.patch('szurubooru.func.cache.has'), \ unittest.mock.patch('szurubooru.func.cache.put'): hashes = [] - def appender(key, value): + + def appender(key, _value): hashes.append(key) + cache.has.side_effect = lambda *args: False cache.put.side_effect = appender executor = search.Executor(config) @@ -84,6 +91,7 @@ def test_putting_non_equivalent_queries_into_cache(): assert len(hashes) == len(args) assert len(set(hashes)) == len(args) + @pytest.mark.parametrize('input', [ 'special:fav', 'special:liked', @@ -97,8 +105,10 @@ def test_putting_auth_dependent_queries_into_cache(user_factory, input): with unittest.mock.patch('szurubooru.func.cache.has'), \ unittest.mock.patch('szurubooru.func.cache.put'): hashes = [] - def appender(key, value): + + def appender(key, _value): hashes.append(key) + cache.has.side_effect = lambda *args: False cache.put.side_effect = appender executor = search.Executor(config)