diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index 34e007cc..e4d49750 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -44,9 +44,6 @@ def create_post(ctx, _params=None): content, tag_names, None if anonymous else ctx.user) if len(new_tags): auth.verify_privilege(ctx.user, 'tags:create') - db.session.flush() - for tag in new_tags: - snapshots.create(tag, None if anonymous else ctx.user) posts.update_post_safety(post, safety) posts.update_post_source(post, source) posts.update_post_relations(post, relations) @@ -55,7 +52,10 @@ def create_post(ctx, _params=None): if ctx.has_file('thumbnail'): posts.update_post_thumbnail(post, ctx.get_file('thumbnail')) ctx.session.add(post) + ctx.session.flush() snapshots.create(post, None if anonymous else ctx.user) + for tag in new_tags: + snapshots.create(tag, None if anonymous else ctx.user) ctx.session.commit() tags.export_to_json() return _serialize_post(ctx, post) diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 31ab8d10..b7ffe2bd 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -86,6 +86,7 @@ def get_post_thumbnail_url(post): def get_post_content_path(post): assert post + assert post.post_id return 'posts/%d.%s' % ( post.post_id, mime.get_extension(post.mime_type) or 'dat') @@ -217,12 +218,10 @@ def create_post(content, tag_names, user): post.creation_time = datetime.datetime.utcnow() post.flags = [] - # we'll need post ID post.type = '' post.checksum = '' post.mime_type = '' db.session.add(post) - db.session.flush() update_post_content(post, content) new_tags = update_post_tags(post, tag_names) @@ -245,6 +244,38 @@ def update_post_source(post, source): post.source = source +@sqlalchemy.events.event.listens_for(db.Post, 'after_insert') +def _after_post_insert(_mapper, _connection, post): + _sync_post_content(post) + + +@sqlalchemy.events.event.listens_for(db.Post, 'after_update') +def _after_post_update(_mapper, _connection, post): + _sync_post_content(post) + + +def _sync_post_content(post): + regenerate_thumb = False + + if hasattr(post, '__content'): + files.save(get_post_content_path(post), getattr(post, '__content')) + delattr(post, '__content') + regenerate_thumb = True + + if hasattr(post, '__thumbnail'): + if getattr(post, '__thumbnail'): + files.save( + get_post_thumbnail_backup_path(post), + getattr(post, '__thumbnail')) + else: + files.delete(get_post_thumbnail_backup_path(post)) + delattr(post, '__thumbnail') + regenerate_thumb = True + + if regenerate_thumb: + generate_post_thumbnail(post) + + def update_post_content(post, content): assert post if not content: @@ -269,7 +300,9 @@ def update_post_content(post, content): .filter(db.Post.checksum == post.checksum) \ .filter(db.Post.post_id != post.post_id) \ .one_or_none() - if other_post: + if other_post \ + and other_post.post_id \ + and other_post.post_id != post.post_id: raise PostAlreadyUploadedError( 'Post already uploaded (%d)' % other_post.post_id) @@ -284,19 +317,12 @@ def update_post_content(post, content): if post.canvas_width <= 0 or post.canvas_height <= 0: post.canvas_width = None post.canvas_height = None - files.save(get_post_content_path(post), content) - update_post_thumbnail(post, content=None, do_delete=False) + setattr(post, '__content', content) -def update_post_thumbnail(post, content=None, do_delete=True): +def update_post_thumbnail(post, content=None): assert post - if not content: - content = files.get(get_post_content_path(post)) - if do_delete: - files.delete(get_post_thumbnail_backup_path(post)) - else: - files.save(get_post_thumbnail_backup_path(post), content) - generate_post_thumbnail(post) + setattr(post, '__thumbnail', content) def generate_post_thumbnail(post): diff --git a/server/szurubooru/tests/api/test_post_creating.py b/server/szurubooru/tests/api/test_post_creating.py index 2ea119da..45f0e708 100644 --- a/server/szurubooru/tests/api/test_post_creating.py +++ b/server/szurubooru/tests/api/test_post_creating.py @@ -257,6 +257,55 @@ def test_omitting_optional_field( assert result == 'serialized post' +def test_errors_not_spending_ids( + config_injector, tmpdir, context_factory, read_asset, user_factory): + config_injector({ + 'data_dir': str(tmpdir.mkdir('data')), + 'thumbnails': { + 'post_width': 300, + 'post_height': 300, + }, + 'privileges': { + 'posts:create:identified': db.User.RANK_REGULAR, + }, + }) + auth_user = user_factory(rank=db.User.RANK_REGULAR) + + # successful request + with patch('szurubooru.func.posts.serialize_post'), \ + patch('szurubooru.func.posts.update_post_tags'): + posts.serialize_post.side_effect = lambda post, *_, **__: post.post_id + post1_id = api.post_api.create_post( + context_factory( + params={'safety': 'safe', 'tags': []}, + files={'content': read_asset('png.png')}, + user=auth_user)) + db.session.commit() + + # erroreous request (duplicate post) + with pytest.raises(posts.PostAlreadyUploadedError): + api.post_api.create_post( + context_factory( + params={'safety': 'safe', 'tags': []}, + files={'content': read_asset('png.png')}, + user=auth_user)) + db.session.rollback() + + # successful request + with patch('szurubooru.func.posts.serialize_post'), \ + patch('szurubooru.func.posts.update_post_tags'): + posts.serialize_post.side_effect = lambda post, *_, **__: post.post_id + post2_id = api.post_api.create_post( + context_factory( + params={'safety': 'safe', 'tags': []}, + files={'content': read_asset('jpeg.jpg')}, + user=auth_user)) + + assert post1_id > 0 + assert post2_id > 0 + assert post2_id == post1_id + 1 + + def test_trying_to_omit_content(context_factory, user_factory): with pytest.raises(errors.MissingRequiredFileError): api.post_api.create_post( diff --git a/server/szurubooru/tests/func/test_posts.py b/server/szurubooru/tests/func/test_posts.py index fe955323..dc67c008 100644 --- a/server/szurubooru/tests/func/test_posts.py +++ b/server/szurubooru/tests/func/test_posts.py @@ -265,26 +265,35 @@ def test_update_post_source_with_too_long_string(): @pytest.mark.parametrize( - 'input_file,expected_mime_type,expected_type,output_file_name', + 'is_existing,input_file,expected_mime_type,expected_type,output_file_name', [ - ('png.png', 'image/png', db.Post.TYPE_IMAGE, '1.png'), - ('jpeg.jpg', 'image/jpeg', db.Post.TYPE_IMAGE, '1.jpg'), - ('gif.gif', 'image/gif', db.Post.TYPE_IMAGE, '1.gif'), - ('gif-animated.gif', 'image/gif', db.Post.TYPE_ANIMATION, '1.gif'), - ('webm.webm', 'video/webm', db.Post.TYPE_VIDEO, '1.webm'), - ('mp4.mp4', 'video/mp4', db.Post.TYPE_VIDEO, '1.mp4'), + (True, 'png.png', 'image/png', db.Post.TYPE_IMAGE, '1.png'), + (False, 'png.png', 'image/png', db.Post.TYPE_IMAGE, '1.png'), + (False, 'jpeg.jpg', 'image/jpeg', db.Post.TYPE_IMAGE, '1.jpg'), + (False, 'gif.gif', 'image/gif', db.Post.TYPE_IMAGE, '1.gif'), ( + False, + 'gif-animated.gif', + 'image/gif', + db.Post.TYPE_ANIMATION, + '1.gif', + ), + (False, 'webm.webm', 'video/webm', db.Post.TYPE_VIDEO, '1.webm'), + (False, 'mp4.mp4', 'video/mp4', db.Post.TYPE_VIDEO, '1.mp4'), + ( + False, 'flash.swf', 'application/x-shockwave-flash', db.Post.TYPE_FLASH, '1.swf' ), ]) -def test_update_post_content( +def test_update_post_content_for_new_post( tmpdir, config_injector, post_factory, read_asset, + is_existing, input_file, expected_mime_type, expected_type, @@ -298,14 +307,22 @@ def test_update_post_content( 'post_height': 300, }, }) - post = post_factory(id=1) + output_file_path = str(tmpdir) + '/data/posts/' + output_file_name + post = post_factory() db.session.add(post) + if is_existing: + db.session.flush() + assert post.post_id + else: + assert not post.post_id + assert not os.path.exists(output_file_path) posts.update_post_content(post, read_asset(input_file)) + assert not os.path.exists(output_file_path) db.session.flush() assert post.mime_type == expected_mime_type assert post.type == expected_type assert post.checksum == 'crc' - assert os.path.exists(str(tmpdir) + '/data/posts/' + output_file_name) + assert os.path.exists(output_file_path) def test_update_post_content_to_existing_content( @@ -320,7 +337,6 @@ def test_update_post_content_to_existing_content( post = post_factory() another_post = post_factory() db.session.add_all([post, another_post]) - db.session.flush() posts.update_post_content(post, read_asset('png.png')) db.session.flush() with pytest.raises(posts.PostAlreadyUploadedError): @@ -342,8 +358,8 @@ def test_update_post_content_with_broken_content( post = post_factory() another_post = post_factory() db.session.add_all([post, another_post]) - db.session.flush() posts.update_post_content(post, read_asset('png-broken.png')) + db.session.flush() assert post.canvas_width is None assert post.canvas_height is None @@ -355,8 +371,9 @@ def test_update_post_content_with_invalid_content(input_content): posts.update_post_content(post, input_content) +@pytest.mark.parametrize('is_existing', (True, False)) def test_update_post_thumbnail_to_new_one( - tmpdir, config_injector, read_asset, post_factory): + tmpdir, config_injector, read_asset, post_factory, is_existing): config_injector({ 'data_dir': str(tmpdir.mkdir('data')), 'thumbnails': { @@ -364,21 +381,31 @@ def test_update_post_thumbnail_to_new_one( 'post_height': 300, }, }) - post = post_factory(id=1) + post = post_factory() db.session.add(post) - db.session.flush() + if is_existing: + db.session.flush() + assert post.post_id + else: + assert not post.post_id + generated_path = str(tmpdir) + '/data/generated-thumbnails/1.jpg' + source_path = str(tmpdir) + '/data/posts/custom-thumbnails/1.dat' + assert not os.path.exists(generated_path) + assert not os.path.exists(source_path) posts.update_post_content(post, read_asset('png.png')) posts.update_post_thumbnail(post, read_asset('jpeg.jpg')) - source_path = str(tmpdir) + '/data/posts/custom-thumbnails/1.dat' - generated_path = str(tmpdir) + '/data/generated-thumbnails/1.jpg' - assert os.path.exists(source_path) + assert not os.path.exists(generated_path) + assert not os.path.exists(source_path) + db.session.flush() assert os.path.exists(generated_path) + assert os.path.exists(source_path) with open(source_path, 'rb') as handle: assert handle.read() == read_asset('jpeg.jpg') +@pytest.mark.parametrize('is_existing', (True, False)) def test_update_post_thumbnail_to_default( - tmpdir, config_injector, read_asset, post_factory): + tmpdir, config_injector, read_asset, post_factory, is_existing): config_injector({ 'data_dir': str(tmpdir.mkdir('data')), 'thumbnails': { @@ -386,19 +413,30 @@ def test_update_post_thumbnail_to_default( 'post_height': 300, }, }) - post = post_factory(id=1) + post = post_factory() db.session.add(post) - db.session.flush() + if is_existing: + db.session.flush() + assert post.post_id + else: + assert not post.post_id + generated_path = str(tmpdir) + '/data/generated-thumbnails/1.jpg' + source_path = str(tmpdir) + '/data/posts/custom-thumbnails/1.dat' + assert not os.path.exists(generated_path) + assert not os.path.exists(source_path) posts.update_post_content(post, read_asset('png.png')) posts.update_post_thumbnail(post, read_asset('jpeg.jpg')) posts.update_post_thumbnail(post, None) - assert os.path.exists(str(tmpdir) + '/data/generated-thumbnails/1.jpg') - assert not os.path.exists( - str(tmpdir) + '/data/posts/custom-thumbnails/1.dat') + assert not os.path.exists(generated_path) + assert not os.path.exists(source_path) + db.session.flush() + assert os.path.exists(generated_path) + assert not os.path.exists(source_path) +@pytest.mark.parametrize('is_existing', (True, False)) def test_update_post_thumbnail_with_broken_thumbnail( - tmpdir, config_injector, read_asset, post_factory): + tmpdir, config_injector, read_asset, post_factory, is_existing): config_injector({ 'data_dir': str(tmpdir.mkdir('data')), 'thumbnails': { @@ -406,15 +444,24 @@ def test_update_post_thumbnail_with_broken_thumbnail( 'post_height': 300, }, }) - post = post_factory(id=1) + post = post_factory() db.session.add(post) - db.session.flush() + if is_existing: + db.session.flush() + assert post.post_id + else: + assert not post.post_id + generated_path = str(tmpdir) + '/data/generated-thumbnails/1.jpg' + source_path = str(tmpdir) + '/data/posts/custom-thumbnails/1.dat' + assert not os.path.exists(generated_path) + assert not os.path.exists(source_path) posts.update_post_content(post, read_asset('png.png')) posts.update_post_thumbnail(post, read_asset('png-broken.png')) - source_path = str(tmpdir) + '/data/posts/custom-thumbnails/1.dat' - generated_path = str(tmpdir) + '/data/generated-thumbnails/1.jpg' - assert os.path.exists(source_path) + assert not os.path.exists(generated_path) + assert not os.path.exists(source_path) + db.session.flush() assert os.path.exists(generated_path) + assert os.path.exists(source_path) with open(source_path, 'rb') as handle: assert handle.read() == read_asset('png-broken.png') with open(generated_path, 'rb') as handle: @@ -434,10 +481,10 @@ def test_update_post_content_leaving_custom_thumbnail( }) post = post_factory(id=1) db.session.add(post) - db.session.flush() posts.update_post_content(post, read_asset('png.png')) posts.update_post_thumbnail(post, read_asset('jpeg.jpg')) posts.update_post_content(post, read_asset('png.png')) + db.session.flush() assert os.path.exists(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat') assert os.path.exists(str(tmpdir) + '/data/generated-thumbnails/1.jpg')