From 2e57a0746f8d6ee962d85b08e92e5ec179a5863a Mon Sep 17 00:00:00 2001 From: rr- Date: Tue, 19 Apr 2016 17:39:16 +0200 Subject: [PATCH] server/general: consistently use db.session --- server/szurubooru/app.py | 5 +- server/szurubooru/search/search_executor.py | 4 +- server/szurubooru/search/tag_search_config.py | 2 +- .../szurubooru/search/user_search_config.py | 2 +- .../tests/api/test_password_reset.py | 20 +-- .../tests/api/test_tag_category_deleting.py | 30 ++--- .../tests/api/test_tag_category_retrieving.py | 4 +- .../tests/api/test_tag_category_updating.py | 24 ++-- .../szurubooru/tests/api/test_tag_creating.py | 42 +++---- .../szurubooru/tests/api/test_tag_deleting.py | 26 ++-- .../szurubooru/tests/api/test_tag_export.py | 13 +- .../tests/api/test_tag_retrieving.py | 8 +- .../szurubooru/tests/api/test_tag_updating.py | 86 ++++++------- .../tests/api/test_user_creating.py | 20 ++- .../tests/api/test_user_deleting.py | 21 ++-- .../tests/api/test_user_retrieving.py | 7 +- .../tests/api/test_user_updating.py | 32 +++-- server/szurubooru/tests/conftest.py | 4 +- server/szurubooru/tests/db/test_post.py | 66 +++++----- server/szurubooru/tests/db/test_tag.py | 51 ++++---- server/szurubooru/tests/db/test_user.py | 8 +- .../tests/search/test_tag_search_config.py | 55 ++++----- .../tests/search/test_user_search_config.py | 46 +++---- .../szurubooru/tests/util/test_snapshots.py | 114 +++++++++--------- server/szurubooru/util/snapshots.py | 8 +- server/szurubooru/util/tag_categories.py | 8 +- server/szurubooru/util/tags.py | 19 +-- server/szurubooru/util/users.py | 6 +- 28 files changed, 351 insertions(+), 380 deletions(-) diff --git a/server/szurubooru/app.py b/server/szurubooru/app.py index 74d9af34..1643b93e 100644 --- a/server/szurubooru/app.py +++ b/server/szurubooru/app.py @@ -1,6 +1,5 @@ ''' Exports create_app. ''' -import json import falcon from szurubooru import api, errors, middleware @@ -26,15 +25,13 @@ def _on_processing_error(ex, _request, _response, _params): def create_method_not_allowed(allowed_methods): allowed = ', '.join(allowed_methods) - - def method_not_allowed(request, response, **kwargs): + def method_not_allowed(request, response, **_kwargs): response.status = falcon.status_codes.HTTP_405 response.set_header('Allow', allowed) request.context.output = { 'title': 'Method not allowed', 'description': 'Allowed methods: %r' % allowed_methods, } - return method_not_allowed def create_app(): diff --git a/server/szurubooru/search/search_executor.py b/server/szurubooru/search/search_executor.py index d441cd28..73d32af9 100644 --- a/server/szurubooru/search/search_executor.py +++ b/server/szurubooru/search/search_executor.py @@ -23,8 +23,8 @@ class SearchExecutor(object): count_query = filter_query.statement \ .with_only_columns([sqlalchemy.func.count()]) \ .order_by(None) - count = filter_query \ - .session.execute(count_query) \ + count = filter_query.session \ + .execute(count_query) \ .scalar() return (count, entities) diff --git a/server/szurubooru/search/tag_search_config.py b/server/szurubooru/search/tag_search_config.py index e41fab0d..c72ffba3 100644 --- a/server/szurubooru/search/tag_search_config.py +++ b/server/szurubooru/search/tag_search_config.py @@ -4,7 +4,7 @@ from szurubooru.search.base_search_config import BaseSearchConfig class TagSearchConfig(BaseSearchConfig): def create_query(self): - return db.session().query(db.Tag) + return db.session.query(db.Tag) def finalize_query(self, query): return query.order_by(db.Tag.first_name.asc()) diff --git a/server/szurubooru/search/user_search_config.py b/server/szurubooru/search/user_search_config.py index 515bb459..c7945fd7 100644 --- a/server/szurubooru/search/user_search_config.py +++ b/server/szurubooru/search/user_search_config.py @@ -6,7 +6,7 @@ class UserSearchConfig(BaseSearchConfig): ''' Executes searches related to the users. ''' def create_query(self): - return db.session().query(db.User) + return db.session.query(db.User) def finalize_query(self, query): return query.order_by(db.User.name.asc()) diff --git a/server/szurubooru/tests/api/test_password_reset.py b/server/szurubooru/tests/api/test_password_reset.py index c2970b6f..34f6f023 100644 --- a/server/szurubooru/tests/api/test_password_reset.py +++ b/server/szurubooru/tests/api/test_password_reset.py @@ -14,8 +14,8 @@ def password_reset_api(config_injector): return api.PasswordResetApi() def test_reset_sending_email( - password_reset_api, session, context_factory, user_factory): - session.add(user_factory( + password_reset_api, context_factory, user_factory): + db.session.add(user_factory( name='u1', rank='regular_user', email='user@example.com')) for getter in ['u1', 'user@example.com']: mailer.send_mail = mock.MagicMock() @@ -34,17 +34,17 @@ def test_trying_to_reset_non_existing(password_reset_api, context_factory): password_reset_api.get(context_factory(), 'u1') def test_trying_to_reset_without_email( - password_reset_api, session, context_factory, user_factory): - session.add(user_factory(name='u1', rank='regular_user', email=None)) + password_reset_api, context_factory, user_factory): + db.session.add(user_factory(name='u1', rank='regular_user', email=None)) with pytest.raises(errors.ValidationError): password_reset_api.get(context_factory(), 'u1') def test_confirming_with_good_token( - password_reset_api, context_factory, session, user_factory): + password_reset_api, context_factory, user_factory): user = user_factory( name='u1', rank='regular_user', email='user@example.com') old_hash = user.password_hash - session.add(user) + db.session.add(user) context = context_factory( input={'token': '4ac0be176fb364f13ee6b634c43220e2'}) result = password_reset_api.post(context, 'u1') @@ -56,15 +56,15 @@ def test_trying_to_confirm_non_existing(password_reset_api, context_factory): password_reset_api.post(context_factory(), 'u1') def test_trying_to_confirm_without_token( - password_reset_api, context_factory, session, user_factory): - session.add(user_factory( + password_reset_api, context_factory, user_factory): + db.session.add(user_factory( name='u1', rank='regular_user', email='user@example.com')) with pytest.raises(errors.ValidationError): password_reset_api.post(context_factory(input={}), 'u1') def test_trying_to_confirm_with_bad_token( - password_reset_api, context_factory, session, user_factory): - session.add(user_factory( + password_reset_api, context_factory, user_factory): + db.session.add(user_factory( name='u1', rank='regular_user', email='user@example.com')) with pytest.raises(errors.ValidationError): password_reset_api.post( diff --git a/server/szurubooru/tests/api/test_tag_category_deleting.py b/server/szurubooru/tests/api/test_tag_category_deleting.py index 7d7422ed..115c9060 100644 --- a/server/szurubooru/tests/api/test_tag_category_deleting.py +++ b/server/szurubooru/tests/api/test_tag_category_deleting.py @@ -28,35 +28,35 @@ def test_ctx( return ret def test_deleting(test_ctx): - db.session().add(test_ctx.tag_category_factory(name='root')) - db.session().add(test_ctx.tag_category_factory(name='category')) - db.session().commit() + db.session.add(test_ctx.tag_category_factory(name='root')) + db.session.add(test_ctx.tag_category_factory(name='category')) + db.session.commit() result = test_ctx.api.delete( test_ctx.context_factory( user=test_ctx.user_factory(rank='regular_user')), 'category') assert result == {} - assert db.session().query(db.TagCategory).count() == 1 - assert db.session().query(db.TagCategory).one().name == 'root' + assert db.session.query(db.TagCategory).count() == 1 + assert db.session.query(db.TagCategory).one().name == 'root' assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json')) def test_trying_to_delete_used(test_ctx, tag_factory): category = test_ctx.tag_category_factory(name='category') - db.session().add(category) - db.session().flush() + db.session.add(category) + db.session.flush() tag = test_ctx.tag_factory(names=['tag'], category=category) - db.session().add(tag) - db.session().commit() + db.session.add(tag) + db.session.commit() with pytest.raises(tag_categories.TagCategoryIsInUseError): test_ctx.api.delete( test_ctx.context_factory( user=test_ctx.user_factory(rank='regular_user')), 'category') - assert db.session().query(db.TagCategory).count() == 1 + assert db.session.query(db.TagCategory).count() == 1 def test_trying_to_delete_last(test_ctx, tag_factory): - db.session().add(test_ctx.tag_category_factory(name='root')) - db.session().commit() + db.session.add(test_ctx.tag_category_factory(name='root')) + db.session.commit() with pytest.raises(tag_categories.TagCategoryIsInUseError): result = test_ctx.api.delete( test_ctx.context_factory( @@ -70,11 +70,11 @@ def test_trying_to_delete_non_existing(test_ctx): user=test_ctx.user_factory(rank='regular_user')), 'bad') def test_trying_to_delete_without_privileges(test_ctx): - db.session().add(test_ctx.tag_category_factory(name='category')) - db.session().commit() + db.session.add(test_ctx.tag_category_factory(name='category')) + db.session.commit() with pytest.raises(errors.AuthError): test_ctx.api.delete( test_ctx.context_factory( user=test_ctx.user_factory(rank='anonymous')), 'category') - assert db.session().query(db.TagCategory).count() == 1 + assert db.session.query(db.TagCategory).count() == 1 diff --git a/server/szurubooru/tests/api/test_tag_category_retrieving.py b/server/szurubooru/tests/api/test_tag_category_retrieving.py index 42fd68c8..c581b7fd 100644 --- a/server/szurubooru/tests/api/test_tag_category_retrieving.py +++ b/server/szurubooru/tests/api/test_tag_category_retrieving.py @@ -24,7 +24,7 @@ def test_ctx( return ret def test_retrieving_multiple(test_ctx): - db.session().add_all([ + db.session.add_all([ test_ctx.tag_category_factory(name='c1'), test_ctx.tag_category_factory(name='c2'), ]) @@ -34,7 +34,7 @@ def test_retrieving_multiple(test_ctx): assert [cat['name'] for cat in result['tagCategories']] == ['c1', 'c2'] def test_retrieving_single(test_ctx): - db.session().add(test_ctx.tag_category_factory(name='cat')) + db.session.add(test_ctx.tag_category_factory(name='cat')) result = test_ctx.detail_api.get( test_ctx.context_factory( user=test_ctx.user_factory(rank='regular_user')), diff --git a/server/szurubooru/tests/api/test_tag_category_updating.py b/server/szurubooru/tests/api/test_tag_category_updating.py index 0a8d9223..be576d01 100644 --- a/server/szurubooru/tests/api/test_tag_category_updating.py +++ b/server/szurubooru/tests/api/test_tag_category_updating.py @@ -28,8 +28,8 @@ def test_ctx( def test_simple_updating(test_ctx): category = test_ctx.tag_category_factory(name='name', color='black') - db.session().add(category) - db.session().commit() + db.session.add(category) + db.session.commit() result = test_ctx.api.put( test_ctx.context_factory( input={ @@ -59,8 +59,8 @@ def test_simple_updating(test_ctx): {'color': ''}, ]) def test_trying_to_pass_invalid_input(test_ctx, input): - db.session().add(test_ctx.tag_category_factory(name='meta', color='black')) - db.session().commit() + db.session.add(test_ctx.tag_category_factory(name='meta', color='black')) + db.session.commit() with pytest.raises(tag_categories.InvalidTagCategoryNameError): test_ctx.api.put( test_ctx.context_factory( @@ -70,8 +70,8 @@ def test_trying_to_pass_invalid_input(test_ctx, input): @pytest.mark.parametrize('field', ['name', 'color']) def test_omitting_optional_field(test_ctx, tmpdir, field): - db.session().add(test_ctx.tag_category_factory(name='name', color='black')) - db.session().commit() + db.session.add(test_ctx.tag_category_factory(name='name', color='black')) + db.session.commit() input = { 'name': 'changed', 'color': 'white', @@ -94,8 +94,8 @@ def test_trying_to_update_non_existing(test_ctx): @pytest.mark.parametrize('new_name', ['cat', 'CAT']) def test_reusing_own_name(test_ctx, new_name): - db.session().add(test_ctx.tag_category_factory(name='cat', color='black')) - db.session().commit() + db.session.add(test_ctx.tag_category_factory(name='cat', color='black')) + db.session.commit() result = test_ctx.api.put( test_ctx.context_factory( input={'name': new_name}, @@ -107,10 +107,10 @@ def test_reusing_own_name(test_ctx, new_name): @pytest.mark.parametrize('dup_name', ['cat1', 'CAT1']) def test_trying_to_use_existing_name(test_ctx, dup_name): - db.session().add_all([ + db.session.add_all([ test_ctx.tag_category_factory(name='cat1', color='black'), test_ctx.tag_category_factory(name='cat2', color='black')]) - db.session().commit() + db.session.commit() with pytest.raises(tag_categories.TagCategoryAlreadyExistsError): test_ctx.api.put( test_ctx.context_factory( @@ -123,8 +123,8 @@ def test_trying_to_use_existing_name(test_ctx, dup_name): {'color': 'whatever'}, ]) def test_trying_to_update_without_privileges(test_ctx, input): - db.session().add(test_ctx.tag_category_factory(name='dummy')) - db.session().commit() + db.session.add(test_ctx.tag_category_factory(name='dummy')) + db.session.commit() with pytest.raises(errors.AuthError): test_ctx.api.put( test_ctx.context_factory( diff --git a/server/szurubooru/tests/api/test_tag_creating.py b/server/szurubooru/tests/api/test_tag_creating.py index f62131e9..c9ffe9a2 100644 --- a/server/szurubooru/tests/api/test_tag_creating.py +++ b/server/szurubooru/tests/api/test_tag_creating.py @@ -4,8 +4,9 @@ import pytest from szurubooru import api, config, db, errors from szurubooru.util import misc, tags -def get_tag(session, name): - return session.query(db.Tag) \ +def get_tag(name): + return db.session \ + .query(db.Tag) \ .join(db.TagName) \ .filter(db.TagName.name==name) \ .first() @@ -16,24 +17,17 @@ def assert_relations(relations, expected_tag_names): @pytest.fixture def test_ctx( - tmpdir, - session, - config_injector, - context_factory, - user_factory, - tag_factory): + tmpdir, config_injector, context_factory, user_factory, tag_factory): config_injector({ 'data_dir': str(tmpdir), 'tag_name_regex': '^[^!]*$', 'ranks': ['anonymous', 'regular_user'], 'privileges': {'tags:create': 'regular_user'}, }) - session.add_all([ - db.TagCategory(name) for name in [ - 'meta', 'character', 'copyright']]) - session.flush() + db.session.add_all([ + db.TagCategory(name) for name in ['meta', 'character', 'copyright']]) + db.session.flush() ret = misc.dotdict() - ret.session = session ret.context_factory = context_factory ret.user_factory = user_factory ret.tag_factory = tag_factory @@ -61,7 +55,7 @@ def test_creating_simple_tags(test_ctx, fake_datetime): 'lastEditTime': None, } } - tag = get_tag(test_ctx.session, 'tag1') + tag = get_tag('tag1') assert [tag_name.name for tag_name in tag.names] == ['tag1', 'tag2'] assert tag.category.name == 'meta' assert tag.last_edit_time is None @@ -140,15 +134,15 @@ def test_duplicating_names(test_ctx): user=test_ctx.user_factory(rank='regular_user'))) assert result['tag']['names'] == ['tag1'] assert result['tag']['category'] == 'meta' - tag = get_tag(test_ctx.session, 'tag1') + tag = get_tag('tag1') assert [tag_name.name for tag_name in tag.names] == ['tag1'] def test_trying_to_use_existing_name(test_ctx): - test_ctx.session.add_all([ + db.session.add_all([ test_ctx.tag_factory(names=['used1'], category_name='meta'), test_ctx.tag_factory(names=['used2'], category_name='meta'), ]) - test_ctx.session.commit() + db.session.commit() with pytest.raises(tags.TagAlreadyExistsError): test_ctx.api.post( test_ctx.context_factory( @@ -169,7 +163,7 @@ def test_trying_to_use_existing_name(test_ctx): 'implications': [], }, user=test_ctx.user_factory(rank='regular_user'))) - assert get_tag(test_ctx.session, 'unused') is None + assert get_tag('unused') is None @pytest.mark.parametrize('input,expected_suggestions,expected_implications', [ # new relations @@ -208,18 +202,18 @@ def test_creating_new_suggestions_and_implications( input=input, user=test_ctx.user_factory(rank='regular_user'))) assert result['tag']['suggestions'] == expected_suggestions assert result['tag']['implications'] == expected_implications - tag = get_tag(test_ctx.session, 'main') + tag = get_tag('main') assert_relations(tag.suggestions, expected_suggestions) assert_relations(tag.implications, expected_implications) for name in ['main'] + expected_suggestions + expected_implications: - assert get_tag(test_ctx.session, name) is not None + assert get_tag(name) is not None def test_reusing_suggestions_and_implications(test_ctx): - test_ctx.session.add_all([ + db.session.add_all([ test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta'), test_ctx.tag_factory(names=['tag3'], category_name='meta'), ]) - test_ctx.session.commit() + db.session.commit() result = test_ctx.api.post( test_ctx.context_factory( input={ @@ -232,7 +226,7 @@ def test_reusing_suggestions_and_implications(test_ctx): # NOTE: it should export only the first name assert result['tag']['suggestions'] == ['tag1'] assert result['tag']['implications'] == ['tag1'] - tag = get_tag(test_ctx.session, 'new') + tag = get_tag('new') assert_relations(tag.suggestions, ['tag1']) assert_relations(tag.implications, ['tag1']) @@ -256,7 +250,7 @@ def test_tag_trying_to_relate_to_itself(test_ctx, input): test_ctx.context_factory( input=input, user=test_ctx.user_factory(rank='regular_user'))) - assert get_tag(test_ctx.session, 'tag') is None + assert get_tag('tag') is None def test_trying_to_create_tag_without_privileges(test_ctx): with pytest.raises(errors.AuthError): diff --git a/server/szurubooru/tests/api/test_tag_deleting.py b/server/szurubooru/tests/api/test_tag_deleting.py index f41fffda..67fa1113 100644 --- a/server/szurubooru/tests/api/test_tag_deleting.py +++ b/server/szurubooru/tests/api/test_tag_deleting.py @@ -6,12 +6,7 @@ from szurubooru.util import misc, tags @pytest.fixture def test_ctx( - tmpdir, - session, - config_injector, - context_factory, - tag_factory, - user_factory): + tmpdir, config_injector, context_factory, tag_factory, user_factory): config_injector({ 'data_dir': str(tmpdir), 'privileges': { @@ -20,7 +15,6 @@ def test_ctx( 'ranks': ['anonymous', 'regular_user'], }) ret = misc.dotdict() - ret.session = session ret.context_factory = context_factory ret.user_factory = user_factory ret.tag_factory = tag_factory @@ -28,28 +22,28 @@ def test_ctx( return ret def test_deleting(test_ctx): - test_ctx.session.add(test_ctx.tag_factory(names=['tag'])) - test_ctx.session.commit() + db.session.add(test_ctx.tag_factory(names=['tag'])) + db.session.commit() result = test_ctx.api.delete( test_ctx.context_factory( user=test_ctx.user_factory(rank='regular_user')), 'tag') assert result == {} - assert test_ctx.session.query(db.Tag).count() == 0 + assert db.session.query(db.Tag).count() == 0 assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json')) def test_trying_to_delete_used(test_ctx, post_factory): tag = test_ctx.tag_factory(names=['tag']) post = post_factory() post.tags.append(tag) - test_ctx.session.add_all([tag, post]) - test_ctx.session.commit() + db.session.add_all([tag, post]) + db.session.commit() with pytest.raises(tags.TagIsInUseError): test_ctx.api.delete( test_ctx.context_factory( user=test_ctx.user_factory(rank='regular_user')), 'tag') - assert test_ctx.session.query(db.Tag).count() == 1 + assert db.session.query(db.Tag).count() == 1 def test_trying_to_delete_non_existing(test_ctx): with pytest.raises(tags.TagNotFoundError): @@ -58,11 +52,11 @@ def test_trying_to_delete_non_existing(test_ctx): user=test_ctx.user_factory(rank='regular_user')), 'bad') def test_trying_to_delete_without_privileges(test_ctx): - test_ctx.session.add(test_ctx.tag_factory(names=['tag'])) - test_ctx.session.commit() + db.session.add(test_ctx.tag_factory(names=['tag'])) + db.session.commit() with pytest.raises(errors.AuthError): test_ctx.api.delete( test_ctx.context_factory( user=test_ctx.user_factory(rank='anonymous')), 'tag') - assert test_ctx.session.query(db.Tag).count() == 1 + assert db.session.query(db.Tag).count() == 1 diff --git a/server/szurubooru/tests/api/test_tag_export.py b/server/szurubooru/tests/api/test_tag_export.py index 3e0eb1ca..34592dc0 100644 --- a/server/szurubooru/tests/api/test_tag_export.py +++ b/server/szurubooru/tests/api/test_tag_export.py @@ -7,7 +7,6 @@ from szurubooru.util import tags def test_export( tmpdir, query_counter, - session, config_injector, tag_factory, tag_category_factory): @@ -16,23 +15,23 @@ def test_export( }) cat1 = tag_category_factory(name='cat1', color='black') cat2 = tag_category_factory(name='cat2', color='white') - session.add_all([cat1, cat2]) - session.flush() + db.session.add_all([cat1, cat2]) + db.session.flush() sug1 = tag_factory(names=['sug1'], category=cat1) sug2 = tag_factory(names=['sug2'], category=cat1) imp1 = tag_factory(names=['imp1'], category=cat1) imp2 = tag_factory(names=['imp2'], category=cat1) tag = tag_factory(names=['alias1', 'alias2'], category=cat2) tag.post_count = 1 - session.add_all([tag, sug1, sug2, imp1, imp2, cat1, cat2]) - session.flush() - session.add_all([ + db.session.add_all([tag, sug1, sug2, imp1, imp2, cat1, cat2]) + db.session.flush() + db.session.add_all([ db.TagSuggestion(tag.tag_id, sug1.tag_id), db.TagSuggestion(tag.tag_id, sug2.tag_id), db.TagImplication(tag.tag_id, imp1.tag_id), db.TagImplication(tag.tag_id, imp2.tag_id), ]) - session.flush() + db.session.flush() with query_counter: tags.export_to_json() diff --git a/server/szurubooru/tests/api/test_tag_retrieving.py b/server/szurubooru/tests/api/test_tag_retrieving.py index 4650de35..a401c9ae 100644 --- a/server/szurubooru/tests/api/test_tag_retrieving.py +++ b/server/szurubooru/tests/api/test_tag_retrieving.py @@ -4,8 +4,7 @@ from szurubooru import api, db, errors from szurubooru.util import misc, tags @pytest.fixture -def test_ctx( - session, context_factory, config_injector, user_factory, tag_factory): +def test_ctx(context_factory, config_injector, user_factory, tag_factory): config_injector({ 'privileges': { 'tags:list': 'regular_user', @@ -16,7 +15,6 @@ def test_ctx( 'rank_names': {'regular_user': 'Peasant'}, }) ret = misc.dotdict() - ret.session = session ret.context_factory = context_factory ret.user_factory = user_factory ret.tag_factory = tag_factory @@ -27,7 +25,7 @@ def test_ctx( def test_retrieving_multiple(test_ctx): tag1 = test_ctx.tag_factory(names=['t1']) tag2 = test_ctx.tag_factory(names=['t2']) - test_ctx.session.add_all([tag1, tag2]) + db.session.add_all([tag1, tag2]) result = test_ctx.list_api.get( test_ctx.context_factory( input={'query': '', 'page': 1}, @@ -46,7 +44,7 @@ def test_trying_to_retrieve_multiple_without_privileges(test_ctx): user=test_ctx.user_factory(rank='anonymous'))) def test_retrieving_single(test_ctx): - test_ctx.session.add(test_ctx.tag_factory(names=['tag'])) + db.session.add(test_ctx.tag_factory(names=['tag'])) result = test_ctx.detail_api.get( test_ctx.context_factory( input={'query': '', 'page': 1}, diff --git a/server/szurubooru/tests/api/test_tag_updating.py b/server/szurubooru/tests/api/test_tag_updating.py index 7fca1df6..381b6ae8 100644 --- a/server/szurubooru/tests/api/test_tag_updating.py +++ b/server/szurubooru/tests/api/test_tag_updating.py @@ -4,8 +4,9 @@ import pytest from szurubooru import api, config, db, errors from szurubooru.util import misc, tags -def get_tag(session, name): - return session.query(db.Tag) \ +def get_tag(name): + return db.session \ + .query(db.Tag) \ .join(db.TagName) \ .filter(db.TagName.name==name) \ .first() @@ -16,12 +17,7 @@ def assert_relations(relations, expected_tag_names): @pytest.fixture def test_ctx( - tmpdir, - session, - config_injector, - context_factory, - user_factory, - tag_factory): + tmpdir, config_injector, context_factory, user_factory, tag_factory): config_injector({ 'data_dir': str(tmpdir), 'tag_name_regex': '^[^!]*$', @@ -33,12 +29,10 @@ def test_ctx( 'tags:edit:implications': 'regular_user', }, }) - session.add_all([ - db.TagCategory(name) for name in [ - 'meta', 'character', 'copyright']]) - session.flush() + db.session.add_all([ + db.TagCategory(name) for name in ['meta', 'character', 'copyright']]) + db.session.flush() ret = misc.dotdict() - ret.session = session ret.context_factory = context_factory ret.user_factory = user_factory ret.tag_factory = tag_factory @@ -47,8 +41,8 @@ def test_ctx( def test_simple_updating(test_ctx, fake_datetime): tag = test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta') - test_ctx.session.add(tag) - test_ctx.session.commit() + db.session.add(tag) + db.session.commit() with fake_datetime('1997-12-01'): result = test_ctx.api.put( test_ctx.context_factory( @@ -68,9 +62,9 @@ def test_simple_updating(test_ctx, fake_datetime): 'lastEditTime': datetime.datetime(1997, 12, 1), } } - assert get_tag(test_ctx.session, 'tag1') is None - assert get_tag(test_ctx.session, 'tag2') is None - tag = get_tag(test_ctx.session, 'tag3') + assert get_tag('tag1') is None + assert get_tag('tag2') is None + tag = get_tag('tag3') assert tag is not None assert [tag_name.name for tag_name in tag.names] == ['tag3'] assert tag.category.name == 'character' @@ -92,9 +86,8 @@ def test_simple_updating(test_ctx, fake_datetime): ({'implications': ['good', '!bad']}, tags.InvalidTagNameError), ]) def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception): - test_ctx.session.add( - test_ctx.tag_factory(names=['tag1'], category_name='meta')) - test_ctx.session.commit() + db.session.add(test_ctx.tag_factory(names=['tag1'], category_name='meta')) + db.session.commit() with pytest.raises(expected_exception): test_ctx.api.put( test_ctx.context_factory( @@ -105,9 +98,8 @@ def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception): @pytest.mark.parametrize( 'field', ['names', 'category', 'implications', 'suggestions']) def test_omitting_optional_field(test_ctx, tmpdir, field): - test_ctx.session.add( - test_ctx.tag_factory(names=['tag'], category_name='meta')) - test_ctx.session.commit() + db.session.add(test_ctx.tag_factory(names=['tag'], category_name='meta')) + db.session.commit() input = { 'names': ['tag1', 'tag2'], 'category': 'meta', @@ -132,23 +124,23 @@ def test_trying_to_update_non_existing(test_ctx): @pytest.mark.parametrize('dup_name', ['tag1', 'TAG1']) def test_reusing_own_name(test_ctx, dup_name): - test_ctx.session.add( + db.session.add( test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta')) - test_ctx.session.commit() + db.session.commit() result = test_ctx.api.put( test_ctx.context_factory( input={'names': [dup_name, 'tag3']}, user=test_ctx.user_factory(rank='regular_user')), 'tag1') assert result['tag']['names'] == ['tag1', 'tag3'] - assert get_tag(test_ctx.session, 'tag2') is None - tag1 = get_tag(test_ctx.session, 'tag1') - tag2 = get_tag(test_ctx.session, 'tag3') + assert get_tag('tag2') is None + tag1 = get_tag('tag1') + tag2 = get_tag('tag3') assert tag1.tag_id == tag2.tag_id assert [name.name for name in tag1.names] == ['tag1', 'tag3'] def test_duplicating_names(test_ctx): - test_ctx.session.add( + db.session.add( test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta')) result = test_ctx.api.put( test_ctx.context_factory( @@ -156,18 +148,18 @@ def test_duplicating_names(test_ctx): user=test_ctx.user_factory(rank='regular_user')), 'tag1') assert result['tag']['names'] == ['tag3'] - assert get_tag(test_ctx.session, 'tag1') is None - assert get_tag(test_ctx.session, 'tag2') is None - tag = get_tag(test_ctx.session, 'tag3') + assert get_tag('tag1') is None + assert get_tag('tag2') is None + tag = get_tag('tag3') assert tag is not None assert [tag_name.name for tag_name in tag.names] == ['tag3'] @pytest.mark.parametrize('dup_name', ['tag1', 'TAG1', 'tag2', 'TAG2']) def test_trying_to_use_existing_name(test_ctx, dup_name): - test_ctx.session.add_all([ + db.session.add_all([ test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta'), test_ctx.tag_factory(names=['tag3', 'tag4'], category_name='meta')]) - test_ctx.session.commit() + db.session.commit() with pytest.raises(tags.TagAlreadyExistsError): test_ctx.api.put( test_ctx.context_factory( @@ -199,28 +191,28 @@ def test_trying_to_use_existing_name(test_ctx, dup_name): ]) def test_updating_new_suggestions_and_implications( test_ctx, input, expected_suggestions, expected_implications): - test_ctx.session.add( + db.session.add( test_ctx.tag_factory(names=['main'], category_name='meta')) - test_ctx.session.commit() + db.session.commit() result = test_ctx.api.put( test_ctx.context_factory( input=input, user=test_ctx.user_factory(rank='regular_user')), 'main') assert result['tag']['suggestions'] == expected_suggestions assert result['tag']['implications'] == expected_implications - tag = get_tag(test_ctx.session, 'main') + tag = get_tag('main') assert_relations(tag.suggestions, expected_suggestions) assert_relations(tag.implications, expected_implications) for name in ['main'] + expected_suggestions + expected_implications: - assert get_tag(test_ctx.session, name) is not None + assert get_tag(name) is not None def test_reusing_suggestions_and_implications(test_ctx): - test_ctx.session.add_all([ + db.session.add_all([ test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta'), test_ctx.tag_factory(names=['tag3'], category_name='meta'), test_ctx.tag_factory(names=['tag4'], category_name='meta'), ]) - test_ctx.session.commit() + db.session.commit() result = test_ctx.api.put( test_ctx.context_factory( input={ @@ -234,7 +226,7 @@ def test_reusing_suggestions_and_implications(test_ctx): # NOTE: it should export only the first name assert result['tag']['suggestions'] == ['tag1'] assert result['tag']['implications'] == ['tag1'] - tag = get_tag(test_ctx.session, 'new') + tag = get_tag('new') assert_relations(tag.suggestions, ['tag1']) assert_relations(tag.implications, ['tag1']) @@ -253,9 +245,8 @@ def test_reusing_suggestions_and_implications(test_ctx): } ]) def test_trying_to_relate_tag_to_itself(test_ctx, input): - test_ctx.session.add( - test_ctx.tag_factory(names=['tag1'], category_name='meta')) - test_ctx.session.commit() + db.session.add(test_ctx.tag_factory(names=['tag1'], category_name='meta')) + db.session.commit() with pytest.raises(tags.InvalidTagRelationError): test_ctx.api.put( test_ctx.context_factory( @@ -269,9 +260,8 @@ def test_trying_to_relate_tag_to_itself(test_ctx, input): {'implications': ['whatever']}, ]) def test_trying_to_update_without_privileges(test_ctx, input): - test_ctx.session.add( - test_ctx.tag_factory(names=['tag'], category_name='meta')) - test_ctx.session.commit() + db.session.add(test_ctx.tag_factory(names=['tag'], category_name='meta')) + db.session.commit() with pytest.raises(errors.AuthError): test_ctx.api.put( test_ctx.context_factory( diff --git a/server/szurubooru/tests/api/test_user_creating.py b/server/szurubooru/tests/api/test_user_creating.py index 64475461..8da53b49 100644 --- a/server/szurubooru/tests/api/test_user_creating.py +++ b/server/szurubooru/tests/api/test_user_creating.py @@ -8,12 +8,11 @@ EMPTY_PIXEL = \ b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \ b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' -def get_user(session, name): - return session.query(db.User).filter_by(name=name).first() +def get_user(name): + return db.session.query(db.User).filter_by(name=name).first() @pytest.fixture -def test_ctx( - session, config_injector, context_factory, user_factory): +def test_ctx(config_injector, context_factory, user_factory): config_injector({ 'secret': '', 'user_name_regex': '[^!]{3,}', @@ -25,7 +24,6 @@ def test_ctx( 'privileges': {'users:create': 'anonymous'}, }) ret = misc.dotdict() - ret.session = session ret.context_factory = context_factory ret.user_factory = user_factory ret.api = api.UserListApi() @@ -53,7 +51,7 @@ def test_creating_user(test_ctx, fake_datetime): 'rankName': 'Unknown', } } - user = get_user(test_ctx.session, 'chewie1') + user = get_user('chewie1') assert user.name == 'chewie1' assert user.email == 'asd@asd.asd' assert user.rank == 'admin' @@ -79,8 +77,8 @@ def test_first_user_becomes_admin_others_not(test_ctx): user=test_ctx.user_factory(rank='anonymous'))) assert result1['user']['rank'] == 'admin' assert result2['user']['rank'] == 'regular_user' - first_user = get_user(test_ctx.session, 'chewie1') - other_user = get_user(test_ctx.session, 'chewie2') + first_user = get_user('chewie1') + other_user = get_user('chewie2') assert first_user.rank == 'admin' assert other_user.rank == 'regular_user' @@ -192,7 +190,7 @@ def test_omitting_optional_field(test_ctx, tmpdir, field): def test_mods_trying_to_become_admin(test_ctx): user1 = test_ctx.user_factory(name='u1', rank='mod') user2 = test_ctx.user_factory(name='u2', rank='mod') - test_ctx.session.add_all([user1, user2]) + db.session.add_all([user1, user2]) context = test_ctx.context_factory(input={ 'name': 'chewie', 'email': 'asd@asd.asd', @@ -204,7 +202,7 @@ def test_mods_trying_to_become_admin(test_ctx): def test_admin_creating_mod_account(test_ctx): user = test_ctx.user_factory(rank='admin') - test_ctx.session.add(user) + db.session.add(user) context = test_ctx.context_factory(input={ 'name': 'chewie', 'email': 'asd@asd.asd', @@ -227,7 +225,7 @@ def test_uploading_avatar(test_ctx, tmpdir): }, files={'avatar': EMPTY_PIXEL}, user=test_ctx.user_factory(rank='mod'))) - user = get_user(test_ctx.session, 'chewie') + user = get_user('chewie') assert user.avatar_style == user.AVATAR_MANUAL assert response['user']['avatarUrl'] == \ 'http://example.com/data/avatars/chewie.jpg' diff --git a/server/szurubooru/tests/api/test_user_deleting.py b/server/szurubooru/tests/api/test_user_deleting.py index cd5d8441..dcbcde53 100644 --- a/server/szurubooru/tests/api/test_user_deleting.py +++ b/server/szurubooru/tests/api/test_user_deleting.py @@ -4,7 +4,7 @@ from szurubooru import api, db, errors from szurubooru.util import misc, users @pytest.fixture -def test_ctx(session, config_injector, context_factory, user_factory): +def test_ctx(config_injector, context_factory, user_factory): config_injector({ 'privileges': { 'users:delete:self': 'regular_user', @@ -13,7 +13,6 @@ def test_ctx(session, config_injector, context_factory, user_factory): 'ranks': ['anonymous', 'regular_user', 'mod', 'admin'], }) ret = misc.dotdict() - ret.session = session ret.context_factory = context_factory ret.user_factory = user_factory ret.api = api.UserDetailApi() @@ -21,28 +20,28 @@ def test_ctx(session, config_injector, context_factory, user_factory): def test_deleting_oneself(test_ctx): user = test_ctx.user_factory(name='u', rank='regular_user') - test_ctx.session.add(user) - test_ctx.session.commit() + db.session.add(user) + db.session.commit() result = test_ctx.api.delete(test_ctx.context_factory(user=user), 'u') assert result == {} - assert test_ctx.session.query(db.User).count() == 0 + assert db.session.query(db.User).count() == 0 def test_deleting_someone_else(test_ctx): user1 = test_ctx.user_factory(name='u1', rank='regular_user') user2 = test_ctx.user_factory(name='u2', rank='mod') - test_ctx.session.add_all([user1, user2]) - test_ctx.session.commit() + db.session.add_all([user1, user2]) + db.session.commit() test_ctx.api.delete(test_ctx.context_factory(user=user2), 'u1') - assert test_ctx.session.query(db.User).count() == 1 + assert db.session.query(db.User).count() == 1 def test_trying_to_delete_someone_else_without_privileges(test_ctx): user1 = test_ctx.user_factory(name='u1', rank='regular_user') user2 = test_ctx.user_factory(name='u2', rank='regular_user') - test_ctx.session.add_all([user1, user2]) - test_ctx.session.commit() + db.session.add_all([user1, user2]) + db.session.commit() with pytest.raises(errors.AuthError): test_ctx.api.delete(test_ctx.context_factory(user=user2), 'u1') - assert test_ctx.session.query(db.User).count() == 2 + assert db.session.query(db.User).count() == 2 def test_trying_to_delete_non_existing(test_ctx): with pytest.raises(users.UserNotFoundError): diff --git a/server/szurubooru/tests/api/test_user_retrieving.py b/server/szurubooru/tests/api/test_user_retrieving.py index f27cacf3..68910572 100644 --- a/server/szurubooru/tests/api/test_user_retrieving.py +++ b/server/szurubooru/tests/api/test_user_retrieving.py @@ -4,7 +4,7 @@ from szurubooru import api, db, errors from szurubooru.util import misc, users @pytest.fixture -def test_ctx(session, context_factory, config_injector, user_factory): +def test_ctx(context_factory, config_injector, user_factory): config_injector({ 'privileges': { 'users:list': 'regular_user', @@ -15,7 +15,6 @@ def test_ctx(session, context_factory, config_injector, user_factory): 'rank_names': {'regular_user': 'Peasant'}, }) ret = misc.dotdict() - ret.session = session ret.context_factory = context_factory ret.user_factory = user_factory ret.list_api = api.UserListApi() @@ -25,7 +24,7 @@ def test_ctx(session, context_factory, config_injector, user_factory): def test_retrieving_multiple(test_ctx): user1 = test_ctx.user_factory(name='u1', rank='mod') user2 = test_ctx.user_factory(name='u2', rank='mod') - test_ctx.session.add_all([user1, user2]) + db.session.add_all([user1, user2]) result = test_ctx.list_api.get( test_ctx.context_factory( input={'query': '', 'page': 1}, @@ -44,7 +43,7 @@ def test_trying_to_retrieve_multiple_without_privileges(test_ctx): user=test_ctx.user_factory(rank='anonymous'))) def test_retrieving_single(test_ctx): - test_ctx.session.add(test_ctx.user_factory(name='u1', rank='regular_user')) + db.session.add(test_ctx.user_factory(name='u1', rank='regular_user')) result = test_ctx.detail_api.get( test_ctx.context_factory( input={'query': '', 'page': 1}, diff --git a/server/szurubooru/tests/api/test_user_updating.py b/server/szurubooru/tests/api/test_user_updating.py index cff44fce..58c7e1c8 100644 --- a/server/szurubooru/tests/api/test_user_updating.py +++ b/server/szurubooru/tests/api/test_user_updating.py @@ -8,12 +8,11 @@ EMPTY_PIXEL = \ b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \ b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' -def get_user(session, name): - return session.query(db.User).filter_by(name=name).first() +def get_user(name): + return db.session.query(db.User).filter_by(name=name).first() @pytest.fixture -def test_ctx( - session, config_injector, context_factory, user_factory): +def test_ctx(config_injector, context_factory, user_factory): config_injector({ 'secret': '', 'user_name_regex': '^[^!]{3,}$', @@ -35,7 +34,6 @@ def test_ctx( }, }) ret = misc.dotdict() - ret.session = session ret.context_factory = context_factory ret.user_factory = user_factory ret.api = api.UserDetailApi() @@ -43,7 +41,7 @@ def test_ctx( def test_updating_user(test_ctx): user = test_ctx.user_factory(name='u1', rank='admin') - test_ctx.session.add(user) + db.session.add(user) result = test_ctx.api.put( test_ctx.context_factory( input={ @@ -68,7 +66,7 @@ def test_updating_user(test_ctx): 'rankName': 'Unknown', } } - user = get_user(test_ctx.session, 'chewie') + user = get_user('chewie') assert user.name == 'chewie' assert user.email == 'asd@asd.asd' assert user.rank == 'mod' @@ -96,7 +94,7 @@ def test_updating_user(test_ctx): ]) def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception): user = test_ctx.user_factory(name='u1', rank='admin') - test_ctx.session.add(user) + db.session.add(user) with pytest.raises(expected_exception): test_ctx.api.put( test_ctx.context_factory(input=input, user=user), 'u1') @@ -107,7 +105,7 @@ def test_omitting_optional_field(test_ctx, tmpdir, field): config.config['data_dir'] = str(tmpdir.mkdir('data')) config.config['data_url'] = 'http://example.com/data/' user = test_ctx.user_factory(name='u1', rank='admin') - test_ctx.session.add(user) + db.session.add(user) input = { 'name': 'chewie', 'email': 'asd@asd.asd', @@ -126,16 +124,16 @@ def test_omitting_optional_field(test_ctx, tmpdir, field): def test_trying_to_update_non_existing(test_ctx): user = test_ctx.user_factory(name='u1', rank='admin') - test_ctx.session.add(user) + db.session.add(user) with pytest.raises(users.UserNotFoundError): test_ctx.api.put(test_ctx.context_factory(user=user), 'u2') def test_removing_email(test_ctx): user = test_ctx.user_factory(name='u1', rank='admin') - test_ctx.session.add(user) + db.session.add(user) test_ctx.api.put( test_ctx.context_factory(input={'email': ''}, user=user), 'u1') - assert get_user(test_ctx.session, 'u1').email is None + assert get_user('u1').email is None @pytest.mark.parametrize('input', [ {'name': 'whatever'}, @@ -147,7 +145,7 @@ def test_removing_email(test_ctx): def test_trying_to_update_someone_else(test_ctx, input): user1 = test_ctx.user_factory(name='u1', rank='regular_user') user2 = test_ctx.user_factory(name='u2', rank='regular_user') - test_ctx.session.add_all([user1, user2]) + db.session.add_all([user1, user2]) with pytest.raises(errors.AuthError): test_ctx.api.put( test_ctx.context_factory(input=input, user=user1), user2.name) @@ -155,7 +153,7 @@ def test_trying_to_update_someone_else(test_ctx, input): def test_trying_to_become_someone_else(test_ctx): user1 = test_ctx.user_factory(name='me', rank='regular_user') user2 = test_ctx.user_factory(name='her', rank='regular_user') - test_ctx.session.add_all([user1, user2]) + db.session.add_all([user1, user2]) with pytest.raises(users.UserAlreadyExistsError): test_ctx.api.put( test_ctx.context_factory(input={'name': 'her'}, user=user1), @@ -167,7 +165,7 @@ def test_trying_to_become_someone_else(test_ctx): def test_mods_trying_to_become_admin(test_ctx): user1 = test_ctx.user_factory(name='u1', rank='mod') user2 = test_ctx.user_factory(name='u2', rank='mod') - test_ctx.session.add_all([user1, user2]) + db.session.add_all([user1, user2]) context = test_ctx.context_factory(input={'rank': 'admin'}, user=user1) with pytest.raises(errors.AuthError): test_ctx.api.put(context, user1.name) @@ -178,14 +176,14 @@ def test_uploading_avatar(test_ctx, tmpdir): config.config['data_dir'] = str(tmpdir.mkdir('data')) config.config['data_url'] = 'http://example.com/data/' user = test_ctx.user_factory(name='u1', rank='mod') - test_ctx.session.add(user) + db.session.add(user) response = test_ctx.api.put( test_ctx.context_factory( input={'avatarStyle': 'manual'}, files={'avatar': EMPTY_PIXEL}, user=user), 'u1') - user = get_user(test_ctx.session, 'u1') + user = get_user('u1') assert user.avatar_style == user.AVATAR_MANUAL assert response['user']['avatarUrl'] == \ 'http://example.com/data/avatars/u1.jpg' diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index 6b70fa75..fa0609e9 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -46,8 +46,8 @@ def fake_datetime(): freezer.stop() return injector -@pytest.yield_fixture -def session(query_counter, autoload=True): +@pytest.yield_fixture(autouse=True) +def session(query_counter): import logging logging.basicConfig() logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) diff --git a/server/szurubooru/tests/db/test_post.py b/server/szurubooru/tests/db/test_post.py index 1415f2b9..9362d561 100644 --- a/server/szurubooru/tests/db/test_post.py +++ b/server/szurubooru/tests/db/test_post.py @@ -1,7 +1,7 @@ from datetime import datetime from szurubooru import db -def test_saving_post(session, post_factory, user_factory, tag_factory): +def test_saving_post(post_factory, user_factory, tag_factory): user = user_factory() tag1 = tag_factory() tag2 = tag_factory() @@ -13,17 +13,17 @@ def test_saving_post(session, post_factory, user_factory, tag_factory): post.checksum = 'deadbeef' post.creation_time = datetime(1997, 1, 1) post.last_edit_time = datetime(1998, 1, 1) - session.add_all([user, tag1, tag2, related_post1, related_post2, post]) + db.session.add_all([user, tag1, tag2, related_post1, related_post2, post]) post.user = user post.tags.append(tag1) post.tags.append(tag2) post.relations.append(related_post1) post.relations.append(related_post2) - session.commit() + db.session.commit() - post = session.query(db.Post).filter(db.Post.post_id == post.post_id).one() - assert not session.dirty + db.session.refresh(post) + assert not db.session.dirty assert post.user.user_id is not None assert post.safety == 'safety' assert post.type == 'type' @@ -32,60 +32,60 @@ def test_saving_post(session, post_factory, user_factory, tag_factory): assert post.last_edit_time == datetime(1998, 1, 1) assert len(post.relations) == 2 -def test_cascade_deletions(session, post_factory, user_factory, tag_factory): +def test_cascade_deletions(post_factory, user_factory, tag_factory): user = user_factory() tag1 = tag_factory() tag2 = tag_factory() related_post1 = post_factory() related_post2 = post_factory() post = post_factory() - session.add_all([user, tag1, tag2, post, related_post1, related_post2]) - session.flush() + db.session.add_all([user, tag1, tag2, post, related_post1, related_post2]) + db.session.flush() post.user = user post.tags.append(tag1) post.tags.append(tag2) post.relations.append(related_post1) post.relations.append(related_post2) - session.flush() + db.session.flush() - assert not session.dirty + assert not db.session.dirty assert post.user.user_id is not None assert len(post.relations) == 2 - assert session.query(db.User).count() == 1 - assert session.query(db.Tag).count() == 2 - assert session.query(db.Post).count() == 3 - assert session.query(db.PostTag).count() == 2 - assert session.query(db.PostRelation).count() == 2 + assert db.session.query(db.User).count() == 1 + assert db.session.query(db.Tag).count() == 2 + assert db.session.query(db.Post).count() == 3 + assert db.session.query(db.PostTag).count() == 2 + assert db.session.query(db.PostRelation).count() == 2 - session.delete(post) - session.commit() + db.session.delete(post) + db.session.commit() - assert not session.dirty - assert session.query(db.User).count() == 1 - assert session.query(db.Tag).count() == 2 - assert session.query(db.Post).count() == 2 - assert session.query(db.PostTag).count() == 0 - assert session.query(db.PostRelation).count() == 0 + assert not db.session.dirty + assert db.session.query(db.User).count() == 1 + assert db.session.query(db.Tag).count() == 2 + assert db.session.query(db.Post).count() == 2 + assert db.session.query(db.PostTag).count() == 0 + assert db.session.query(db.PostRelation).count() == 0 -def test_tracking_tag_count(session, post_factory, tag_factory): +def test_tracking_tag_count(post_factory, tag_factory): post = post_factory() tag1 = tag_factory() tag2 = tag_factory() - session.add_all([tag1, tag2, post]) - session.flush() + db.session.add_all([tag1, tag2, post]) + db.session.flush() post.tags.append(tag1) post.tags.append(tag2) - session.commit() + db.session.commit() assert len(post.tags) == 2 assert post.tag_count == 2 - session.delete(tag1) - session.commit() - session.refresh(post) + db.session.delete(tag1) + db.session.commit() + db.session.refresh(post) assert len(post.tags) == 1 assert post.tag_count == 1 - session.delete(tag2) - session.commit() - session.refresh(post) + db.session.delete(tag2) + db.session.commit() + db.session.refresh(post) assert len(post.tags) == 0 assert post.tag_count == 0 diff --git a/server/szurubooru/tests/db/test_tag.py b/server/szurubooru/tests/db/test_tag.py index dd67f653..c6658239 100644 --- a/server/szurubooru/tests/db/test_tag.py +++ b/server/szurubooru/tests/db/test_tag.py @@ -1,7 +1,7 @@ from datetime import datetime from szurubooru import db -def test_saving_tag(session, tag_factory): +def test_saving_tag(tag_factory): sug1 = tag_factory(names=['sug1']) sug2 = tag_factory(names=['sug2']) imp1 = tag_factory(names=['imp1']) @@ -13,8 +13,8 @@ def test_saving_tag(session, tag_factory): tag.category = db.TagCategory('category') tag.creation_time = datetime(1997, 1, 1) tag.last_edit_time = datetime(1998, 1, 1) - session.add_all([tag, sug1, sug2, imp1, imp2]) - session.commit() + db.session.add_all([tag, sug1, sug2, imp1, imp2]) + db.session.commit() assert tag.tag_id is not None assert sug1.tag_id is not None @@ -25,9 +25,10 @@ def test_saving_tag(session, tag_factory): tag.suggestions.append(sug2) tag.implications.append(imp1) tag.implications.append(imp2) - session.commit() + db.session.commit() - tag = session.query(db.Tag) \ + tag = db.session \ + .query(db.Tag) \ .join(db.TagName) \ .filter(db.TagName.name=='alias1') \ .one() @@ -40,7 +41,7 @@ def test_saving_tag(session, tag_factory): assert [relation.names[0].name for relation in tag.implications] \ == ['imp1', 'imp2'] -def test_cascade_deletions(session, tag_factory): +def test_cascade_deletions(tag_factory): sug1 = tag_factory(names=['sug1']) sug2 = tag_factory(names=['sug2']) imp1 = tag_factory(names=['imp1']) @@ -53,8 +54,8 @@ def test_cascade_deletions(session, tag_factory): tag.creation_time = datetime(1997, 1, 1) tag.last_edit_time = datetime(1998, 1, 1) tag.post_count = 1 - session.add_all([tag, sug1, sug2, imp1, imp2]) - session.commit() + db.session.add_all([tag, sug1, sug2, imp1, imp2]) + db.session.commit() assert tag.tag_id is not None assert sug1.tag_id is not None @@ -65,32 +66,32 @@ def test_cascade_deletions(session, tag_factory): tag.suggestions.append(sug2) tag.implications.append(imp1) tag.implications.append(imp2) - session.commit() + db.session.commit() - session.delete(tag) - session.commit() - assert session.query(db.Tag).count() == 4 - assert session.query(db.TagName).count() == 4 - assert session.query(db.TagImplication).count() == 0 - assert session.query(db.TagSuggestion).count() == 0 + db.session.delete(tag) + db.session.commit() + assert db.session.query(db.Tag).count() == 4 + assert db.session.query(db.TagName).count() == 4 + assert db.session.query(db.TagImplication).count() == 0 + assert db.session.query(db.TagSuggestion).count() == 0 -def test_tracking_post_count(session, post_factory, tag_factory): +def test_tracking_post_count(post_factory, tag_factory): tag = tag_factory() post1 = post_factory() post2 = post_factory() - session.add_all([tag, post1, post2]) - session.flush() + db.session.add_all([tag, post1, post2]) + db.session.flush() post1.tags.append(tag) post2.tags.append(tag) - session.commit() + db.session.commit() assert len(post1.tags) == 1 assert len(post2.tags) == 1 assert tag.post_count == 2 - session.delete(post1) - session.commit() - session.refresh(tag) + db.session.delete(post1) + db.session.commit() + db.session.refresh(tag) assert tag.post_count == 1 - session.delete(post2) - session.commit() - session.refresh(tag) + db.session.delete(post2) + db.session.commit() + db.session.refresh(tag) assert tag.post_count == 0 diff --git a/server/szurubooru/tests/db/test_user.py b/server/szurubooru/tests/db/test_user.py index 1cabc731..1a38cd01 100644 --- a/server/szurubooru/tests/db/test_user.py +++ b/server/szurubooru/tests/db/test_user.py @@ -1,7 +1,7 @@ from datetime import datetime from szurubooru import db -def test_saving_user(session): +def test_saving_user(): user = db.User() user.name = 'name' user.password_salt = 'salt' @@ -10,8 +10,10 @@ def test_saving_user(session): user.rank = 'rank' user.creation_time = datetime(1997, 1, 1) user.avatar_style = db.User.AVATAR_GRAVATAR - session.add(user) - user = session.query(db.User).one() + db.session.add(user) + db.session.flush() + db.session.refresh(user) + assert not db.session.dirty assert user.name == 'name' assert user.password_salt == 'salt' assert user.password_hash == 'hash' diff --git a/server/szurubooru/tests/search/test_tag_search_config.py b/server/szurubooru/tests/search/test_tag_search_config.py index 91d00132..b579abe3 100644 --- a/server/szurubooru/tests/search/test_tag_search_config.py +++ b/server/szurubooru/tests/search/test_tag_search_config.py @@ -40,14 +40,14 @@ def verify_unpaged(executor): ('-creation-date:2014-01,2015', ['t2']), ]) def test_filter_by_creation_time( - verify_unpaged, session, tag_factory, input, expected_tag_names): + verify_unpaged, tag_factory, input, expected_tag_names): tag1 = tag_factory(names=['t1']) tag2 = tag_factory(names=['t2']) tag3 = tag_factory(names=['t3']) tag1.creation_time = datetime.datetime(2014, 1, 1) tag2.creation_time = datetime.datetime(2014, 6, 1) tag3.creation_time = datetime.datetime(2015, 1, 1) - session.add_all([tag1, tag2, tag3]) + db.session.add_all([tag1, tag2, tag3]) verify_unpaged(input, expected_tag_names) @pytest.mark.parametrize('input,expected_tag_names', [ @@ -70,12 +70,11 @@ def test_filter_by_creation_time( ('name:tag4', ['tag4']), ('name:tag4,tag5', ['tag4']), ]) -def test_filter_by_name( - session, verify_unpaged, tag_factory, input, expected_tag_names): - session.add(tag_factory(names=['tag1'])) - session.add(tag_factory(names=['tag2'])) - session.add(tag_factory(names=['tag3'])) - session.add(tag_factory(names=['tag4', 'tag5', 'tag6'])) +def test_filter_by_name(verify_unpaged, tag_factory, input, expected_tag_names): + db.session.add(tag_factory(names=['tag1'])) + db.session.add(tag_factory(names=['tag2'])) + db.session.add(tag_factory(names=['tag3'])) + db.session.add(tag_factory(names=['tag4', 'tag5', 'tag6'])) verify_unpaged(input, expected_tag_names) @pytest.mark.parametrize('input,expected_tag_names', [ @@ -84,10 +83,9 @@ def test_filter_by_name( ('t2', ['t2']), ('t1,t2', ['t1', 't2']), ]) -def test_anonymous( - session, verify_unpaged, tag_factory, input, expected_tag_names): - session.add(tag_factory(names=['t1'])) - session.add(tag_factory(names=['t2'])) +def test_anonymous(verify_unpaged, tag_factory, input, expected_tag_names): + db.session.add(tag_factory(names=['t1'])) + db.session.add(tag_factory(names=['t2'])) verify_unpaged(input, expected_tag_names) @pytest.mark.parametrize('input,expected_tag_names', [ @@ -99,10 +97,9 @@ def test_anonymous( ('-order:name,asc', ['t2', 't1']), ('-order:name,desc', ['t1', 't2']), ]) -def test_order_by_name( - session, verify_unpaged, tag_factory, input, expected_tag_names): - session.add(tag_factory(names=['t2'])) - session.add(tag_factory(names=['t1'])) +def test_order_by_name(verify_unpaged, tag_factory, input, expected_tag_names): + db.session.add(tag_factory(names=['t2'])) + db.session.add(tag_factory(names=['t1'])) verify_unpaged(input, expected_tag_names) @pytest.mark.parametrize('input,expected_user_names', [ @@ -111,28 +108,28 @@ def test_order_by_name( ('order:creation-time', ['t3', 't2', 't1']), ]) def test_order_by_creation_time( - session, verify_unpaged, tag_factory, input, expected_user_names): + verify_unpaged, tag_factory, input, expected_user_names): tag1 = tag_factory(names=['t1']) tag2 = tag_factory(names=['t2']) tag3 = tag_factory(names=['t3']) tag1.creation_time = datetime.datetime(1991, 1, 1) tag2.creation_time = datetime.datetime(1991, 1, 2) tag3.creation_time = datetime.datetime(1991, 1, 3) - session.add_all([tag3, tag1, tag2]) + db.session.add_all([tag3, tag1, tag2]) verify_unpaged(input, expected_user_names) @pytest.mark.parametrize('input,expected_tag_names', [ ('order:suggestion-count', ['t1', 't2', 'sug1', 'sug2', 'sug3']), ]) def test_order_by_suggestion_count( - session, verify_unpaged, tag_factory, input, expected_tag_names): + verify_unpaged, tag_factory, input, expected_tag_names): sug1 = tag_factory(names=['sug1']) sug2 = tag_factory(names=['sug2']) sug3 = tag_factory(names=['sug3']) tag1 = tag_factory(names=['t1']) tag2 = tag_factory(names=['t2']) - session.add_all([sug1, sug3, tag2, sug2, tag1]) - session.commit() + db.session.add_all([sug1, sug3, tag2, sug2, tag1]) + db.session.commit() tag1.suggestions.append(sug1) tag1.suggestions.append(sug2) tag2.suggestions.append(sug3) @@ -142,32 +139,32 @@ def test_order_by_suggestion_count( ('order:implication-count', ['t1', 't2', 'sug1', 'sug2', 'sug3']), ]) def test_order_by_implication_count( - session, verify_unpaged, tag_factory, input, expected_tag_names): + verify_unpaged, tag_factory, input, expected_tag_names): sug1 = tag_factory(names=['sug1']) sug2 = tag_factory(names=['sug2']) sug3 = tag_factory(names=['sug3']) tag1 = tag_factory(names=['t1']) tag2 = tag_factory(names=['t2']) - session.add_all([sug1, sug3, tag2, sug2, tag1]) - session.commit() + db.session.add_all([sug1, sug3, tag2, sug2, tag1]) + db.session.commit() tag1.implications.append(sug1) tag1.implications.append(sug2) tag2.implications.append(sug3) verify_unpaged(input, expected_tag_names) -def test_filter_by_relation_count(session, verify_unpaged, tag_factory): +def test_filter_by_relation_count(verify_unpaged, tag_factory): sug1 = tag_factory(names=['sug1']) sug2 = tag_factory(names=['sug2']) imp1 = tag_factory(names=['imp1']) tag1 = tag_factory(names=['t1']) tag2 = tag_factory(names=['t2']) - session.add_all([sug1, tag1, sug2, imp1, tag2]) - session.commit() - session.add_all([ + db.session.add_all([sug1, tag1, sug2, imp1, tag2]) + db.session.commit() + db.session.add_all([ db.TagSuggestion(tag1.tag_id, sug1.tag_id), db.TagSuggestion(tag1.tag_id, sug2.tag_id), db.TagImplication(tag2.tag_id, imp1.tag_id)]) - session.commit() + db.session.commit() verify_unpaged('suggestion-count:0', ['imp1', 'sug1', 'sug2', 't2']) verify_unpaged('suggestion-count:1', []) verify_unpaged('suggestion-count:2', ['t1']) diff --git a/server/szurubooru/tests/search/test_user_search_config.py b/server/szurubooru/tests/search/test_user_search_config.py index 42e7dd00..63928daf 100644 --- a/server/szurubooru/tests/search/test_user_search_config.py +++ b/server/szurubooru/tests/search/test_user_search_config.py @@ -42,14 +42,14 @@ def verify_unpaged(executor): ('-creation-date:2014-01,2015', ['u2']), ]) def test_filter_by_creation_time( - verify_unpaged, session, input, expected_user_names, user_factory): + verify_unpaged, input, expected_user_names, user_factory): user1 = user_factory(name='u1') user2 = user_factory(name='u2') user3 = user_factory(name='u3') user1.creation_time = datetime.datetime(2014, 1, 1) user2.creation_time = datetime.datetime(2014, 6, 1) user3.creation_time = datetime.datetime(2015, 1, 1) - session.add_all([user1, user2, user3]) + db.session.add_all([user1, user2, user3]) verify_unpaged(input, expected_user_names) @pytest.mark.parametrize('input,expected_user_names', [ @@ -71,10 +71,10 @@ def test_filter_by_creation_time( ('-name:user1,user3', ['user2']), ]) def test_filter_by_name( - session, verify_unpaged, input, expected_user_names, user_factory): - session.add(user_factory(name='user1')) - session.add(user_factory(name='user2')) - session.add(user_factory(name='user3')) + verify_unpaged, input, expected_user_names, user_factory): + db.session.add(user_factory(name='user1')) + db.session.add(user_factory(name='user2')) + db.session.add(user_factory(name='user3')) verify_unpaged(input, expected_user_names) @pytest.mark.parametrize('input,expected_user_names', [ @@ -84,9 +84,9 @@ def test_filter_by_name( ('u1,u2', ['u1', 'u2']), ]) def test_anonymous( - session, verify_unpaged, input, expected_user_names, user_factory): - session.add(user_factory(name='u1')) - session.add(user_factory(name='u2')) + verify_unpaged, input, expected_user_names, user_factory): + db.session.add(user_factory(name='u1')) + db.session.add(user_factory(name='u2')) verify_unpaged(input, expected_user_names) @pytest.mark.parametrize('input,expected_user_names', [ @@ -95,14 +95,14 @@ def test_anonymous( ('creation-time:2016 u2', []), ]) def test_combining_tokens( - session, verify_unpaged, input, expected_user_names, user_factory): + verify_unpaged, input, expected_user_names, user_factory): user1 = user_factory(name='u1') user2 = user_factory(name='u2') user3 = user_factory(name='u3') user1.creation_time = datetime.datetime(2014, 1, 1) user2.creation_time = datetime.datetime(2014, 6, 1) user3.creation_time = datetime.datetime(2015, 1, 1) - session.add_all([user1, user2, user3]) + db.session.add_all([user1, user2, user3]) verify_unpaged(input, expected_user_names) @pytest.mark.parametrize( @@ -114,10 +114,10 @@ def test_combining_tokens( (0, 0, 2, []), ]) def test_paging( - session, executor, user_factory, page, page_size, + executor, user_factory, page, page_size, expected_total_count, expected_user_names): - session.add(user_factory(name='u1')) - session.add(user_factory(name='u2')) + db.session.add(user_factory(name='u1')) + db.session.add(user_factory(name='u2')) actual_count, actual_users = executor.execute( '', page=page, page_size=page_size) actual_user_names = [u.name for u in actual_users] @@ -134,9 +134,9 @@ def test_paging( ('-order:name,desc', ['u1', 'u2']), ]) def test_order_by_name( - session, verify_unpaged, input, expected_user_names, user_factory): - session.add(user_factory(name='u2')) - session.add(user_factory(name='u1')) + verify_unpaged, input, expected_user_names, user_factory): + db.session.add(user_factory(name='u2')) + db.session.add(user_factory(name='u1')) verify_unpaged(input, expected_user_names) @pytest.mark.parametrize('input,expected_user_names', [ @@ -150,14 +150,14 @@ def test_order_by_name( ('-order:creation-date,desc', ['u1', 'u2', 'u3']), ]) def test_order_by_creation_time( - session, verify_unpaged, input, expected_user_names, user_factory): + verify_unpaged, input, expected_user_names, user_factory): user1 = user_factory(name='u1') user2 = user_factory(name='u2') user3 = user_factory(name='u3') user1.creation_time = datetime.datetime(1991, 1, 1) user2.creation_time = datetime.datetime(1991, 1, 2) user3.creation_time = datetime.datetime(1991, 1, 3) - session.add_all([user3, user1, user2]) + db.session.add_all([user3, user1, user2]) verify_unpaged(input, expected_user_names) @pytest.mark.parametrize('input,expected_user_names', [ @@ -168,21 +168,21 @@ def test_order_by_creation_time( ('order:login-time', ['u3', 'u2', 'u1']), ]) def test_order_by_last_login_time( - session, verify_unpaged, input, expected_user_names, user_factory): + verify_unpaged, input, expected_user_names, user_factory): user1 = user_factory(name='u1') user2 = user_factory(name='u2') user3 = user_factory(name='u3') user1.last_login_time = datetime.datetime(1991, 1, 1) user2.last_login_time = datetime.datetime(1991, 1, 2) user3.last_login_time = datetime.datetime(1991, 1, 3) - session.add_all([user3, user1, user2]) + db.session.add_all([user3, user1, user2]) verify_unpaged(input, expected_user_names) -def test_random_order(session, executor, user_factory): +def test_random_order(executor, user_factory): user1 = user_factory(name='u1') user2 = user_factory(name='u2') user3 = user_factory(name='u3') - session.add_all([user3, user1, user2]) + db.session.add_all([user3, user1, user2]) actual_count, actual_users = executor.execute( 'order:random', page=1, page_size=100) actual_user_names = [u.name for u in actual_users] diff --git a/server/szurubooru/tests/util/test_snapshots.py b/server/szurubooru/tests/util/test_snapshots.py index 9d6f3c6f..e4062ae7 100644 --- a/server/szurubooru/tests/util/test_snapshots.py +++ b/server/szurubooru/tests/util/test_snapshots.py @@ -3,7 +3,7 @@ import pytest from szurubooru import db from szurubooru.util import snapshots -def test_serializing_tag(session, tag_factory): +def test_serializing_tag(tag_factory): tag = tag_factory(names=['main_name', 'alias'], category_name='dummy') assert snapshots.get_tag_snapshot(tag) == { 'names': ['main_name', 'alias'], @@ -15,10 +15,10 @@ def test_serializing_tag(session, tag_factory): imp2 = tag_factory(names=['imp2_main_name', 'imp2_alias']) sug1 = tag_factory(names=['sug1_main_name', 'sug1_alias']) sug2 = tag_factory(names=['sug2_main_name', 'sug2_alias']) - session.add_all([imp1, imp2, sug1, sug2]) + db.session.add_all([imp1, imp2, sug1, sug2]) tag.implications = [imp1, imp2] tag.suggestions = [sug1, sug2] - session.flush() + db.session.flush() assert snapshots.get_tag_snapshot(tag) == { 'names': ['main_name', 'alias'], 'category': 'dummy', @@ -26,39 +26,38 @@ def test_serializing_tag(session, tag_factory): 'suggestions': ['sug1_main_name', 'sug2_main_name'], } -def test_merging_modification_to_creation(session, tag_factory, user_factory): +def test_merging_modification_to_creation(tag_factory, user_factory): tag = tag_factory(names=['dummy']) user = user_factory() - session.add_all([tag, user]) - session.flush() + db.session.add_all([tag, user]) + db.session.flush() snapshots.create(tag, user) - session.flush() + db.session.flush() tag.names = [db.TagName('changed')] snapshots.modify(tag, user) - session.flush() - results = session.query(db.Snapshot).all() + db.session.flush() + results = db.session.query(db.Snapshot).all() assert len(results) == 1 assert results[0].operation == db.Snapshot.OPERATION_CREATED assert results[0].data['names'] == ['changed'] -def test_merging_modifications( - fake_datetime, session, tag_factory, user_factory): +def test_merging_modifications(fake_datetime, tag_factory, user_factory): tag = tag_factory(names=['dummy']) user = user_factory() - session.add_all([tag, user]) - session.flush() + db.session.add_all([tag, user]) + db.session.flush() with fake_datetime('13:00:00'): snapshots.create(tag, user) - session.flush() + db.session.flush() tag.names = [db.TagName('changed')] with fake_datetime('14:00:00'): snapshots.modify(tag, user) - session.flush() + db.session.flush() tag.names = [db.TagName('changed again')] with fake_datetime('14:00:01'): snapshots.modify(tag, user) - session.flush() - results = session.query(db.Snapshot).all() + db.session.flush() + results = db.session.query(db.Snapshot).all() assert len(results) == 2 assert results[0].operation == db.Snapshot.OPERATION_CREATED assert results[1].operation == db.Snapshot.OPERATION_MODIFIED @@ -66,93 +65,94 @@ def test_merging_modifications( assert results[1].data['names'] == ['changed again'] def test_not_adding_snapshot_if_data_doesnt_change( - fake_datetime, session, tag_factory, user_factory): + fake_datetime, tag_factory, user_factory): tag = tag_factory(names=['dummy']) user = user_factory() - session.add_all([tag, user]) - session.flush() + db.session.add_all([tag, user]) + db.session.flush() with fake_datetime('13:00:00'): snapshots.create(tag, user) - session.flush() + db.session.flush() with fake_datetime('14:00:00'): snapshots.modify(tag, user) - session.flush() - results = session.query(db.Snapshot).all() + db.session.flush() + results = db.session.query(db.Snapshot).all() assert len(results) == 1 assert results[0].operation == db.Snapshot.OPERATION_CREATED assert results[0].data['names'] == ['dummy'] def test_not_merging_due_to_time_difference( - fake_datetime, session, tag_factory, user_factory): + fake_datetime, tag_factory, user_factory): tag = tag_factory(names=['dummy']) user = user_factory() - session.add_all([tag, user]) - session.flush() + db.session.add_all([tag, user]) + db.session.flush() with fake_datetime('13:00:00'): snapshots.create(tag, user) - session.flush() + db.session.flush() tag.names = [db.TagName('changed')] with fake_datetime('13:10:01'): snapshots.modify(tag, user) - session.flush() - assert session.query(db.Snapshot).count() == 2 + db.session.flush() + assert db.session.query(db.Snapshot).count() == 2 def test_not_merging_operations_by_different_users( - fake_datetime, session, tag_factory, user_factory): + fake_datetime, tag_factory, user_factory): tag = tag_factory(names=['dummy']) user1, user2 = [user_factory(), user_factory()] - session.add_all([tag, user1, user2]) - session.flush() + db.session.add_all([tag, user1, user2]) + db.session.flush() with fake_datetime('13:00:00'): snapshots.create(tag, user1) - session.flush() + db.session.flush() tag.names = [db.TagName('changed')] snapshots.modify(tag, user2) - session.flush() - assert session.query(db.Snapshot).count() == 2 + db.session.flush() + assert db.session.query(db.Snapshot).count() == 2 def test_merging_resets_merging_time_window( - fake_datetime, session, tag_factory, user_factory): + fake_datetime, tag_factory, user_factory): tag = tag_factory(names=['dummy']) user = user_factory() - session.add_all([tag, user]) - session.flush() + db.session.add_all([tag, user]) + db.session.flush() with fake_datetime('13:00:00'): snapshots.create(tag, user) - session.flush() + db.session.flush() tag.names = [db.TagName('changed')] with fake_datetime('13:09:59'): snapshots.modify(tag, user) - session.flush() + db.session.flush() tag.names = [db.TagName('changed again')] with fake_datetime('13:19:59'): snapshots.modify(tag, user) - session.flush() - results = session.query(db.Snapshot).all() + db.session.flush() + results = db.session.query(db.Snapshot).all() assert len(results) == 1 assert results[0].data['names'] == ['changed again'] @pytest.mark.parametrize( 'initial_operation', [snapshots.create, snapshots.modify]) def test_merging_deletion_to_modification_or_creation( - fake_datetime, session, tag_factory, user_factory, initial_operation): + fake_datetime, tag_factory, user_factory, initial_operation): tag = tag_factory(names=['dummy'], category_name='dummy') user = user_factory() - session.add_all([tag, user]) - session.flush() + db.session.add_all([tag, user]) + db.session.flush() with fake_datetime('13:00:00'): initial_operation(tag, user) - session.flush() + db.session.flush() tag.names = [db.TagName('changed')] with fake_datetime('14:00:00'): snapshots.modify(tag, user) - session.flush() + db.session.flush() tag.names = [db.TagName('changed again')] with fake_datetime('14:00:01'): snapshots.delete(tag, user) - session.flush() - assert session.query(db.Snapshot).count() == 2 - results = session.query(db.Snapshot) \ + db.session.flush() + assert db.session.query(db.Snapshot).count() == 2 + results = db.session \ + .query(db.Snapshot) \ .order_by(db.Snapshot.snapshot_id.asc()) \ .all() assert results[1].operation == db.Snapshot.OPERATION_DELETED @@ -161,20 +161,20 @@ def test_merging_deletion_to_modification_or_creation( @pytest.mark.parametrize( 'expected_operation', [snapshots.create, snapshots.modify]) def test_merging_deletion_all_the_way_deletes_all_snapshots( - fake_datetime, session, tag_factory, user_factory, expected_operation): + fake_datetime, tag_factory, user_factory, expected_operation): tag = tag_factory(names=['dummy']) user = user_factory() - session.add_all([tag, user]) - session.flush() + db.session.add_all([tag, user]) + db.session.flush() with fake_datetime('13:00:00'): snapshots.create(tag, user) - session.flush() + db.session.flush() tag.names = [db.TagName('changed')] with fake_datetime('13:00:01'): snapshots.modify(tag, user) - session.flush() + db.session.flush() tag.names = [db.TagName('changed again')] with fake_datetime('13:00:02'): snapshots.delete(tag, user) - session.flush() - assert session.query(db.Snapshot).count() == 0 + db.session.flush() + assert db.session.query(db.Snapshot).count() == 0 diff --git a/server/szurubooru/util/snapshots.py b/server/szurubooru/util/snapshots.py index c8822bfb..f1cc5148 100644 --- a/server/szurubooru/util/snapshots.py +++ b/server/szurubooru/util/snapshots.py @@ -34,8 +34,8 @@ def save(operation, entity, auth_user): snapshot.data = serializers[table_name](entity) snapshot.user = auth_user - session = db.session() - earlier_snapshots = session.query(db.Snapshot) \ + earlier_snapshots = db.session \ + .query(db.Snapshot) \ .filter(db.Snapshot.resource_type == table_name) \ .filter(db.Snapshot.resource_id == primary_key) \ .order_by(db.Snapshot.creation_time.desc()) \ @@ -49,7 +49,7 @@ def save(operation, entity, auth_user): if snapshot.data != last_snapshot.data: if not is_fresh or last_snapshot.user != auth_user: break - session.delete(last_snapshot) + db.session.delete(last_snapshot) if snapshot.operation != db.Snapshot.OPERATION_DELETED: snapshot.operation = last_snapshot.operation snapshots_left -= 1 @@ -57,7 +57,7 @@ def save(operation, entity, auth_user): if not snapshots_left and operation == db.Snapshot.OPERATION_DELETED: pass else: - session.add(snapshot) + db.session.add(snapshot) def create(entity, auth_user): save(db.Snapshot.OPERATION_CREATED, entity, auth_user) diff --git a/server/szurubooru/util/tag_categories.py b/server/szurubooru/util/tag_categories.py index 33dd0bac..15e86894 100644 --- a/server/szurubooru/util/tag_categories.py +++ b/server/szurubooru/util/tag_categories.py @@ -26,7 +26,7 @@ def update_name(category, name): expr = db.TagCategory.name.ilike(name) if category.tag_category_id: expr = expr & (db.TagCategory.tag_category_id != category.tag_category_id) - already_exists = db.session().query(db.TagCategory).filter(expr).count() > 0 + already_exists = db.session.query(db.TagCategory).filter(expr).count() > 0 if already_exists: raise TagCategoryAlreadyExistsError( 'A category with this name already exists.') @@ -43,7 +43,8 @@ def update_color(category, color): category.color = color def get_category_by_name(name): - return db.session.query(db.TagCategory) \ + return db.session \ + .query(db.TagCategory) \ .filter(db.TagCategory.name.ilike(name)) \ .first() @@ -54,7 +55,8 @@ def get_all_categories(): return db.session.query(db.TagCategory).all() def get_default_category(): - return db.session().query(db.TagCategory) \ + return db.session \ + .query(db.TagCategory) \ .order_by(db.TagCategory.tag_category_id.asc()) \ .limit(1) \ .one() diff --git a/server/szurubooru/util/tags.py b/server/szurubooru/util/tags.py index 69cc95e9..ff3c6ec9 100644 --- a/server/szurubooru/util/tags.py +++ b/server/szurubooru/util/tags.py @@ -32,7 +32,7 @@ def export_to_json(): 'tags': [], 'categories': [], } - all_tags = db.session() \ + all_tags = db.session \ .query(db.Tag) \ .options( sqlalchemy.orm.joinedload('suggestions'), @@ -61,7 +61,8 @@ def export_to_json(): handle.write(json.dumps(output, separators=(',', ':'))) def get_tag_by_name(name): - return db.session().query(db.Tag) \ + return db.session \ + .query(db.Tag) \ .join(db.TagName) \ .filter(db.TagName.name.ilike(name)) \ .first() @@ -73,7 +74,7 @@ def get_tags_by_names(names): expr = sqlalchemy.sql.false() for name in names: expr = expr | db.TagName.name.ilike(name) - return db.session().query(db.Tag).join(db.TagName).filter(expr).all() + return db.session.query(db.Tag).join(db.TagName).filter(expr).all() def get_or_create_tags_by_names(names): names = misc.icase_unique(names) @@ -93,7 +94,7 @@ def get_or_create_tags_by_names(names): category_name=tag_categories.get_default_category().name, suggestions=[], implications=[]) - db.session().add(new_tag) + db.session.add(new_tag) new_tags.append(new_tag) return related_tags, new_tags @@ -107,8 +108,8 @@ def create_tag(names, category_name, suggestions, implications): return tag def update_category_name(tag, category_name): - session = db.session() - category = session.query(db.TagCategory) \ + category = db.session \ + .query(db.TagCategory) \ .filter(db.TagCategory.name == category_name) \ .first() if not category: @@ -131,7 +132,7 @@ def update_names(tag, names): expr = expr | db.TagName.name.ilike(name) if tag.tag_id: expr = expr & (db.TagName.tag_id != tag.tag_id) - existing_tags = db.session().query(db.TagName).filter(expr).all() + existing_tags = db.session.query(db.TagName).filter(expr).all() if len(existing_tags): raise TagAlreadyExistsError( 'One of names is already used by another tag.') @@ -149,12 +150,12 @@ def update_implications(tag, relations): if _check_name_intersection(_get_plain_names(tag), relations): raise InvalidTagRelationError('Tag cannot imply itself.') related_tags, new_tags = get_or_create_tags_by_names(relations) - db.session().flush() + db.session.flush() tag.implications = related_tags + new_tags def update_suggestions(tag, relations): if _check_name_intersection(_get_plain_names(tag), relations): raise InvalidTagRelationError('Tag cannot suggest itself.') related_tags, new_tags = get_or_create_tags_by_names(relations) - db.session().flush() + db.session.flush() tag.suggestions = related_tags + new_tags diff --git a/server/szurubooru/util/users.py b/server/szurubooru/util/users.py index c3f2850c..07700c4a 100644 --- a/server/szurubooru/util/users.py +++ b/server/szurubooru/util/users.py @@ -16,12 +16,14 @@ def get_user_count(): return db.session.query(db.User).count() def get_user_by_name(name): - return db.session.query(db.User) \ + return db.session \ + .query(db.User) \ .filter(func.lower(db.User.name) == func.lower(name)) \ .first() def get_user_by_name_or_email(name_or_email): - return db.session.query(db.User) \ + return db.session \ + .query(db.User) \ .filter( (func.lower(db.User.name) == func.lower(name_or_email)) | (func.lower(db.User.email) == func.lower(name_or_email))) \