diff --git a/server/szurubooru/api/comment_api.py b/server/szurubooru/api/comment_api.py index 08fd09c4..51058160 100644 --- a/server/szurubooru/api/comment_api.py +++ b/server/szurubooru/api/comment_api.py @@ -1,7 +1,7 @@ import datetime from szurubooru import search -from szurubooru.func import auth, comments, posts, scores, util from szurubooru.rest import routes +from szurubooru.func import auth, comments, posts, scores, util, versions _search_executor = search.Executor(search.configs.CommentSearchConfig()) @@ -43,12 +43,12 @@ def get_comment(ctx, params): @routes.put('/comment/(?P[^/]+)/?') def update_comment(ctx, params): comment = comments.get_comment_by_id(params['comment_id']) - util.verify_version(comment, ctx) + versions.verify_version(comment, ctx) + versions.bump_version(comment) infix = 'own' if ctx.user.user_id == comment.user_id else 'any' text = ctx.get_param_as_string('text', required=True) auth.verify_privilege(ctx.user, 'comments:edit:%s' % infix) comments.update_comment_text(comment, text) - util.bump_version(comment) comment.last_edit_time = datetime.datetime.utcnow() ctx.session.commit() return _serialize(ctx, comment) @@ -57,7 +57,7 @@ def update_comment(ctx, params): @routes.delete('/comment/(?P[^/]+)/?') def delete_comment(ctx, params): comment = comments.get_comment_by_id(params['comment_id']) - util.verify_version(comment, ctx) + versions.verify_version(comment, ctx) infix = 'own' if ctx.user.user_id == comment.user_id else 'any' auth.verify_privilege(ctx.user, 'comments:delete:%s' % infix) ctx.session.delete(comment) diff --git a/server/szurubooru/api/info_api.py b/server/szurubooru/api/info_api.py index a5485546..8e74cb72 100644 --- a/server/szurubooru/api/info_api.py +++ b/server/szurubooru/api/info_api.py @@ -1,8 +1,8 @@ import datetime import os from szurubooru import config -from szurubooru.func import posts, users, util from szurubooru.rest import routes +from szurubooru.func import posts, users, util _cache_time = None diff --git a/server/szurubooru/api/password_reset_api.py b/server/szurubooru/api/password_reset_api.py index fbc4ba4d..7e5864c9 100644 --- a/server/szurubooru/api/password_reset_api.py +++ b/server/szurubooru/api/password_reset_api.py @@ -1,6 +1,6 @@ from szurubooru import config, errors -from szurubooru.func import auth, mailer, users, util from szurubooru.rest import routes +from szurubooru.func import auth, mailer, users, versions MAIL_SUBJECT = 'Password reset for {name}' @@ -40,6 +40,6 @@ def finish_password_reset(ctx, params): if token != good_token: raise errors.ValidationError('Invalid password reset token.') new_password = users.reset_user_password(user) - util.bump_version(user) + versions.bump_version(user) ctx.session.commit() return {'password': new_password} diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index e800bd54..53bbcd26 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -2,7 +2,7 @@ import datetime from szurubooru import search from szurubooru.rest import routes from szurubooru.func import ( - auth, tags, posts, snapshots, favorites, scores, util) + auth, tags, posts, snapshots, favorites, scores, util, versions) _search_executor = search.Executor(search.configs.PostSearchConfig()) @@ -68,7 +68,8 @@ def get_post(ctx, params): @routes.put('/post/(?P[^/]+)/?') def update_post(ctx, params): post = posts.get_post_by_id(params['post_id']) - util.verify_version(post, ctx) + versions.verify_version(post, ctx) + versions.bump_version(post) if ctx.has_file('content'): auth.verify_privilege(ctx.user, 'posts:edit:content') posts.update_post_content(post, ctx.get_file('content')) @@ -97,7 +98,6 @@ def update_post(ctx, params): if ctx.has_file('thumbnail'): auth.verify_privilege(ctx.user, 'posts:edit:thumbnail') posts.update_post_thumbnail(post, ctx.get_file('thumbnail')) - util.bump_version(post) post.last_edit_time = datetime.datetime.utcnow() ctx.session.flush() snapshots.save_entity_modification(post, ctx.user) @@ -110,7 +110,7 @@ def update_post(ctx, params): def delete_post(ctx, params): auth.verify_privilege(ctx.user, 'posts:delete') post = posts.get_post_by_id(params['post_id']) - util.verify_version(post, ctx) + versions.verify_version(post, ctx) snapshots.save_entity_deletion(post, ctx.user) posts.delete(post) ctx.session.commit() diff --git a/server/szurubooru/api/snapshot_api.py b/server/szurubooru/api/snapshot_api.py index bf6eba86..1a5f5896 100644 --- a/server/szurubooru/api/snapshot_api.py +++ b/server/szurubooru/api/snapshot_api.py @@ -1,6 +1,6 @@ from szurubooru import search -from szurubooru.func import auth, snapshots from szurubooru.rest import routes +from szurubooru.func import auth, snapshots _search_executor = search.Executor(search.configs.SnapshotSearchConfig()) diff --git a/server/szurubooru/api/tag_api.py b/server/szurubooru/api/tag_api.py index 630cf19b..4daf7e5a 100644 --- a/server/szurubooru/api/tag_api.py +++ b/server/szurubooru/api/tag_api.py @@ -1,7 +1,7 @@ import datetime from szurubooru import db, search -from szurubooru.func import auth, tags, util, snapshots from szurubooru.rest import routes +from szurubooru.func import auth, tags, snapshots, util, versions _search_executor = search.Executor(search.configs.TagSearchConfig()) @@ -66,7 +66,8 @@ def get_tag(ctx, params): @routes.put('/tag/(?P[^/]+)/?') def update_tag(ctx, params): tag = tags.get_tag_by_name(params['tag_name']) - util.verify_version(tag, ctx) + versions.verify_version(tag, ctx) + versions.bump_version(tag) if ctx.has_param('names'): auth.verify_privilege(ctx.user, 'tags:edit:names') tags.update_tag_names(tag, ctx.get_param_as_list('names')) @@ -88,7 +89,6 @@ def update_tag(ctx, params): implications = ctx.get_param_as_list('implications') _create_if_needed(implications, ctx.user) tags.update_tag_implications(tag, implications) - util.bump_version(tag) tag.last_edit_time = datetime.datetime.utcnow() ctx.session.flush() snapshots.save_entity_modification(tag, ctx.user) @@ -100,7 +100,7 @@ def update_tag(ctx, params): @routes.delete('/tag/(?P[^/]+)/?') def delete_tag(ctx, params): tag = tags.get_tag_by_name(params['tag_name']) - util.verify_version(tag, ctx) + versions.verify_version(tag, ctx) auth.verify_privilege(ctx.user, 'tags:delete') snapshots.save_entity_deletion(tag, ctx.user) tags.delete(tag) @@ -115,12 +115,12 @@ def merge_tags(ctx, _params=None): target_tag_name = ctx.get_param_as_string('mergeTo', required=True) or '' source_tag = tags.get_tag_by_name(source_tag_name) target_tag = tags.get_tag_by_name(target_tag_name) - util.verify_version(source_tag, ctx, 'removeVersion') - util.verify_version(target_tag, ctx, 'mergeToVersion') + versions.verify_version(source_tag, ctx, 'removeVersion') + versions.verify_version(target_tag, ctx, 'mergeToVersion') + versions.bump_version(target_tag) auth.verify_privilege(ctx.user, 'tags:merge') tags.merge_tags(source_tag, target_tag) snapshots.save_entity_deletion(source_tag, ctx.user) - util.bump_version(target_tag) ctx.session.commit() tags.export_to_json() return _serialize(ctx, target_tag) diff --git a/server/szurubooru/api/tag_category_api.py b/server/szurubooru/api/tag_category_api.py index fbd25530..27efb1c9 100644 --- a/server/szurubooru/api/tag_category_api.py +++ b/server/szurubooru/api/tag_category_api.py @@ -1,5 +1,6 @@ from szurubooru.rest import routes -from szurubooru.func import auth, tags, tag_categories, util, snapshots +from szurubooru.func import ( + auth, tags, tag_categories, snapshots, util, versions) def _serialize(ctx, category): @@ -40,7 +41,8 @@ def get_tag_category(ctx, params): @routes.put('/tag-category/(?P[^/]+)/?') def update_tag_category(ctx, params): category = tag_categories.get_category_by_name(params['category_name']) - util.verify_version(category, ctx) + versions.verify_version(category, ctx) + versions.bump_version(category) if ctx.has_param('name'): auth.verify_privilege(ctx.user, 'tag_categories:edit:name') tag_categories.update_category_name( @@ -49,7 +51,6 @@ def update_tag_category(ctx, params): auth.verify_privilege(ctx.user, 'tag_categories:edit:color') tag_categories.update_category_color( category, ctx.get_param_as_string('color')) - util.bump_version(category) ctx.session.flush() snapshots.save_entity_modification(category, ctx.user) ctx.session.commit() @@ -60,7 +61,7 @@ def update_tag_category(ctx, params): @routes.delete('/tag-category/(?P[^/]+)/?') def delete_tag_category(ctx, params): category = tag_categories.get_category_by_name(params['category_name']) - util.verify_version(category, ctx) + versions.verify_version(category, ctx) auth.verify_privilege(ctx.user, 'tag_categories:delete') tag_categories.delete_category(category) snapshots.save_entity_deletion(category, ctx.user) diff --git a/server/szurubooru/api/user_api.py b/server/szurubooru/api/user_api.py index aa99c42f..187e4686 100644 --- a/server/szurubooru/api/user_api.py +++ b/server/szurubooru/api/user_api.py @@ -1,6 +1,6 @@ from szurubooru import search -from szurubooru.func import auth, users, util from szurubooru.rest import routes +from szurubooru.func import auth, users, util, versions _search_executor = search.Executor(search.configs.UserSearchConfig()) @@ -52,7 +52,8 @@ def get_user(ctx, params): @routes.put('/user/(?P[^/]+)/?') def update_user(ctx, params): user = users.get_user_by_name(params['user_name']) - util.verify_version(user, ctx) + versions.verify_version(user, ctx) + versions.bump_version(user) infix = 'self' if ctx.user.user_id == user.user_id else 'any' if ctx.has_param('name'): auth.verify_privilege(ctx.user, 'users:edit:%s:name' % infix) @@ -74,7 +75,6 @@ def update_user(ctx, params): user, ctx.get_param_as_string('avatarStyle'), ctx.get_file('avatar')) - util.bump_version(user) ctx.session.commit() return _serialize(ctx, user) @@ -82,7 +82,7 @@ def update_user(ctx, params): @routes.delete('/user/(?P[^/]+)/?') def delete_user(ctx, params): user = users.get_user_by_name(params['user_name']) - util.verify_version(user, ctx) + versions.verify_version(user, ctx) infix = 'self' if ctx.user.user_id == user.user_id else 'any' auth.verify_privilege(ctx.user, 'users:delete:%s' % infix) ctx.session.delete(user) diff --git a/server/szurubooru/db/comment.py b/server/szurubooru/db/comment.py index 05298282..bbb8b189 100644 --- a/server/szurubooru/db/comment.py +++ b/server/szurubooru/db/comment.py @@ -48,3 +48,8 @@ class Comment(Base): .query(func.sum(CommentScore.score)) \ .filter(CommentScore.comment_id == self.comment_id) \ .one()[0] or 0 + + __mapper_args__ = { + 'version_id_col': version, + 'version_id_generator': False, + } diff --git a/server/szurubooru/db/post.py b/server/szurubooru/db/post.py index 5529e49d..741e581c 100644 --- a/server/szurubooru/db/post.py +++ b/server/szurubooru/db/post.py @@ -237,3 +237,8 @@ class Post(Base): (PostRelation.parent_id == post_id) | (PostRelation.child_id == post_id)) .correlate_except(PostRelation)) + + __mapper_args__ = { + 'version_id_col': version, + 'version_id_generator': False, + } diff --git a/server/szurubooru/db/tag.py b/server/szurubooru/db/tag.py index faa6e68d..2bdac536 100644 --- a/server/szurubooru/db/tag.py +++ b/server/szurubooru/db/tag.py @@ -115,3 +115,8 @@ class Tag(Base): .where(TagImplication.parent_id == tag_id) .as_scalar(), deferred=True) + + __mapper_args__ = { + 'version_id_col': version, + 'version_id_generator': False, + } diff --git a/server/szurubooru/db/tag_category.py b/server/szurubooru/db/tag_category.py index cb1d9328..907910ba 100644 --- a/server/szurubooru/db/tag_category.py +++ b/server/szurubooru/db/tag_category.py @@ -21,3 +21,8 @@ class TagCategory(Base): select([func.count('Tag.tag_id')]) .where(Tag.category_id == tag_category_id) .correlate_except(table('Tag'))) + + __mapper_args__ = { + 'version_id_col': version, + 'version_id_generator': False, + } diff --git a/server/szurubooru/db/user.py b/server/szurubooru/db/user.py index f6b9ceb2..082adcff 100644 --- a/server/szurubooru/db/user.py +++ b/server/szurubooru/db/user.py @@ -75,3 +75,8 @@ class User(Base): .filter(PostScore.user_id == self.user_id) .filter(PostScore.score == -1) .one()[0] or 0) + + __mapper_args__ = { + 'version_id_col': version, + 'version_id_generator': False, + } diff --git a/server/szurubooru/facade.py b/server/szurubooru/facade.py index ad1516c6..69bc6b99 100644 --- a/server/szurubooru/facade.py +++ b/server/szurubooru/facade.py @@ -3,6 +3,7 @@ import os import logging import coloredlogs +import sqlalchemy.orm.exc from szurubooru import config, errors, rest # pylint: disable=unused-import from szurubooru import api, middleware @@ -38,6 +39,11 @@ def _on_processing_error(ex): title='Processing error', description=str(ex)) +def _on_stale_data_error(_ex): + raise rest.errors.HttpConflict( + 'Someone else modified this in the meantime. Please try again.') + + def validate_config(): ''' Check whether config doesn't contain errors that might prove @@ -83,5 +89,6 @@ def create_app(): rest.errors.handle(errors.IntegrityError, _on_integrity_error) rest.errors.handle(errors.NotFoundError, _on_not_found_error) rest.errors.handle(errors.ProcessingError, _on_processing_error) + rest.errors.handle(sqlalchemy.orm.exc.StaleDataError, _on_stale_data_error) return rest.application diff --git a/server/szurubooru/func/util.py b/server/szurubooru/func/util.py index 8a1e6c97..32e36a66 100644 --- a/server/szurubooru/func/util.py +++ b/server/szurubooru/func/util.py @@ -154,16 +154,3 @@ def value_exceeds_column_size(value, column): if max_length is None: return False return len(value) > max_length - - -def verify_version(entity, context, field_name='version'): - actual_version = context.get_param_as_int(field_name, required=True) - expected_version = entity.version - if actual_version != expected_version: - raise errors.InvalidParameterError( - 'Someone else modified this in the meantime. ' + - 'Please try again.') - - -def bump_version(entity): - entity.version += 1 diff --git a/server/szurubooru/func/versions.py b/server/szurubooru/func/versions.py new file mode 100644 index 00000000..d38130ac --- /dev/null +++ b/server/szurubooru/func/versions.py @@ -0,0 +1,14 @@ +from szurubooru import errors + + +def verify_version(entity, context, field_name='version'): + actual_version = context.get_param_as_int(field_name, required=True) + expected_version = entity.version + if actual_version != expected_version: + raise errors.InvalidParameterError( + 'Someone else modified this in the meantime. ' + + 'Please try again.') + + +def bump_version(entity): + entity.version = entity.version + 1 diff --git a/server/szurubooru/tests/api/test_tag_updating.py b/server/szurubooru/tests/api/test_tag_updating.py index dbb8b803..81f195b9 100644 --- a/server/szurubooru/tests/api/test_tag_updating.py +++ b/server/szurubooru/tests/api/test_tag_updating.py @@ -136,6 +136,7 @@ def test_trying_to_create_tags_without_privileges( params={'suggestions': ['tag1', 'tag2'], 'version': 1}, user=user_factory(rank=db.User.RANK_REGULAR)), {'tag_name': 'tag'}) + db.session.rollback() with pytest.raises(errors.AuthError): api.tag_api.update_tag( context_factory(