From d08308440713063502979a40c04c641ddef7dde7 Mon Sep 17 00:00:00 2001 From: Shyam Sunder Date: Thu, 23 Sep 2021 12:24:56 -0400 Subject: [PATCH] server/tests: use transactional db for faster unit tests * `test_modify_saves_non_empty_diffs` needs non-transactional db, so moved to seperate file * Replaced incompatable usage of `db.session.rollback()` with parametrerized function calls * xfail conditionals for search removed, as we can no longer get current driver with binds * Also remove usage of deprecated `pytest.yield_fixture` --- .../szurubooru/tests/api/test_tag_updating.py | 14 +---- server/szurubooru/tests/conftest.py | 16 ++++- .../szurubooru/tests/func/test_snapshots.py | 42 +------------ .../test_snapshots_transactional_isolation.py | 59 +++++++++++++++++++ .../tests/func/test_tag_categories.py | 15 +++-- server/szurubooru/tests/func/test_tags.py | 16 +++-- .../search/configs/test_pool_search_config.py | 2 - .../search/configs/test_tag_search_config.py | 2 - 8 files changed, 91 insertions(+), 75 deletions(-) create mode 100644 server/szurubooru/tests/func/test_snapshots_transactional_isolation.py diff --git a/server/szurubooru/tests/api/test_tag_updating.py b/server/szurubooru/tests/api/test_tag_updating.py index 729734d9..be5f4858 100644 --- a/server/szurubooru/tests/api/test_tag_updating.py +++ b/server/szurubooru/tests/api/test_tag_updating.py @@ -145,8 +145,9 @@ def test_trying_to_update_without_privileges( ) +@pytest.mark.parametrize("type", ["suggestions", "implications"]) def test_trying_to_create_tags_without_privileges( - config_injector, context_factory, tag_factory, user_factory + config_injector, context_factory, tag_factory, user_factory, type ): tag = tag_factory(names=["tag"]) db.session.add(tag) @@ -165,16 +166,7 @@ def test_trying_to_create_tags_without_privileges( with pytest.raises(errors.AuthError): api.tag_api.update_tag( context_factory( - params={"suggestions": ["tag1", "tag2"], "version": 1}, - user=user_factory(rank=model.User.RANK_REGULAR), - ), - {"tag_name": "tag"}, - ) - db.session.rollback() - with pytest.raises(errors.AuthError): - api.tag_api.update_tag( - context_factory( - params={"implications": ["tag1", "tag2"], "version": 1}, + params={type: ["tag1", "tag2"], "version": 1}, user=user_factory(rank=model.User.RANK_REGULAR), ), {"tag_name": "tag"}, diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index e7811fe1..280987ca 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -43,14 +43,26 @@ def query_logger(pytestconfig): logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO) -@pytest.yield_fixture(scope="function", autouse=True) -def session(query_logger, postgresql_db): +@pytest.fixture(scope="function", autouse=True) +def session(query_logger, transacted_postgresql_db): + db.session = transacted_postgresql_db.session + transacted_postgresql_db.create_table(*model.Base.metadata.sorted_tables) + try: + yield transacted_postgresql_db.session + finally: + transacted_postgresql_db.reset_db() + + +@pytest.fixture(scope="function") +def nontransacted_session(query_logger, postgresql_db): + old_db_session = db.session db.session = postgresql_db.session postgresql_db.create_table(*model.Base.metadata.sorted_tables) try: yield postgresql_db.session finally: postgresql_db.reset_db() + db.session = old_db_session @pytest.fixture diff --git a/server/szurubooru/tests/func/test_snapshots.py b/server/szurubooru/tests/func/test_snapshots.py index da935307..dc68ff05 100644 --- a/server/szurubooru/tests/func/test_snapshots.py +++ b/server/szurubooru/tests/func/test_snapshots.py @@ -1,7 +1,7 @@ from datetime import datetime from unittest.mock import patch -import pytest +import pytest # noqa: F401 from szurubooru import db, model from szurubooru.func import snapshots, users @@ -144,46 +144,6 @@ def test_create(tag_factory, user_factory): assert results[0].data == "mocked" -def test_modify_saves_non_empty_diffs(post_factory, user_factory): - if "sqlite" in db.session.get_bind().driver: - pytest.xfail( - "SQLite doesn't support transaction isolation, " - "which is required to retrieve original entity" - ) - post = post_factory() - post.notes = [model.PostNote(polygon=[(0, 0), (0, 1), (1, 1)], text="old")] - user = user_factory() - db.session.add_all([post, user]) - db.session.commit() - post.source = "new source" - post.notes = [model.PostNote(polygon=[(0, 0), (0, 1), (1, 1)], text="new")] - db.session.flush() - with patch("szurubooru.func.snapshots._post_to_webhooks"): - snapshots.modify(post, user) - db.session.flush() - results = db.session.query(model.Snapshot).all() - assert len(results) == 1 - assert results[0].data == { - "type": "object change", - "value": { - "source": { - "type": "primitive change", - "old-value": None, - "new-value": "new source", - }, - "notes": { - "type": "list change", - "removed": [ - {"polygon": [[0, 0], [0, 1], [1, 1]], "text": "old"} - ], - "added": [ - {"polygon": [[0, 0], [0, 1], [1, 1]], "text": "new"} - ], - }, - }, - } - - def test_modify_doesnt_save_empty_diffs(tag_factory, user_factory): tag = tag_factory(names=["dummy"]) user = user_factory() diff --git a/server/szurubooru/tests/func/test_snapshots_transactional_isolation.py b/server/szurubooru/tests/func/test_snapshots_transactional_isolation.py new file mode 100644 index 00000000..b98cea7a --- /dev/null +++ b/server/szurubooru/tests/func/test_snapshots_transactional_isolation.py @@ -0,0 +1,59 @@ +from unittest.mock import patch + +import pytest + +from szurubooru import db, model +from szurubooru.func import snapshots + + +@pytest.fixture(autouse=True) +def session(query_logger, postgresql_db): + """ + Override db session for this specific test section only + """ + db.session = postgresql_db.session + postgresql_db.create_table(*model.Base.metadata.sorted_tables) + try: + yield postgresql_db.session + finally: + postgresql_db.reset_db() + + +def test_modify_saves_non_empty_diffs(post_factory, user_factory): + if "sqlite" in db.session.get_bind().driver: + pytest.xfail( + "SQLite doesn't support transaction isolation, " + "which is required to retrieve original entity" + ) + post = post_factory() + post.notes = [model.PostNote(polygon=[(0, 0), (0, 1), (1, 1)], text="old")] + user = user_factory() + db.session.add_all([post, user]) + db.session.commit() + post.source = "new source" + post.notes = [model.PostNote(polygon=[(0, 0), (0, 1), (1, 1)], text="new")] + db.session.flush() + with patch("szurubooru.func.snapshots._post_to_webhooks"): + snapshots.modify(post, user) + db.session.flush() + results = db.session.query(model.Snapshot).all() + assert len(results) == 1 + assert results[0].data == { + "type": "object change", + "value": { + "source": { + "type": "primitive change", + "old-value": None, + "new-value": "new source", + }, + "notes": { + "type": "list change", + "removed": [ + {"polygon": [[0, 0], [0, 1], [1, 1]], "text": "old"} + ], + "added": [ + {"polygon": [[0, 0], [0, 1], [1, 1]], "text": "new"} + ], + }, + }, + } diff --git a/server/szurubooru/tests/func/test_tag_categories.py b/server/szurubooru/tests/func/test_tag_categories.py index 11300cf4..9e649a34 100644 --- a/server/szurubooru/tests/func/test_tag_categories.py +++ b/server/szurubooru/tests/func/test_tag_categories.py @@ -107,17 +107,16 @@ def test_update_category_name_reusing_other_name( tag_categories.update_category_name(category, "NAME") +@pytest.mark.parametrize("name", ["name", "NAME"]) def test_update_category_name_reusing_own_name( - config_injector, tag_category_factory + config_injector, tag_category_factory, name ): config_injector({"tag_category_name_regex": ".*"}) - for name in ["name", "NAME"]: - category = tag_category_factory(name="name") - db.session.add(category) - db.session.flush() - tag_categories.update_category_name(category, name) - assert category.name == name - db.session.rollback() + category = tag_category_factory(name="name") + db.session.add(category) + db.session.flush() + tag_categories.update_category_name(category, name) + assert category.name == name def test_update_category_color_with_empty_string(tag_category_factory): diff --git a/server/szurubooru/tests/func/test_tags.py b/server/szurubooru/tests/func/test_tags.py index ac8963c7..60df1220 100644 --- a/server/szurubooru/tests/func/test_tags.py +++ b/server/szurubooru/tests/func/test_tags.py @@ -513,15 +513,14 @@ def test_update_tag_names_trying_to_use_taken_name( tags.update_tag_names(tag, ["A"]) -def test_update_tag_names_reusing_own_name(config_injector, tag_factory): +@pytest.mark.parametrize("name", list("aA")) +def test_update_tag_names_reusing_own_name(config_injector, tag_factory, name): config_injector({"tag_name_regex": "^[a-zA-Z]*$"}) - for name in list("aA"): - tag = tag_factory(names=["a"]) - db.session.add(tag) - db.session.flush() - tags.update_tag_names(tag, [name]) - assert [tag_name.name for tag_name in tag.names] == [name] - db.session.rollback() + tag = tag_factory(names=["a"]) + db.session.add(tag) + db.session.flush() + tags.update_tag_names(tag, [name]) + assert [tag_name.name for tag_name in tag.names] == [name] def test_update_tag_names_changing_primary_name(config_injector, tag_factory): @@ -533,7 +532,6 @@ def test_update_tag_names_changing_primary_name(config_injector, tag_factory): db.session.flush() db.session.refresh(tag) assert [tag_name.name for tag_name in tag.names] == ["b", "a"] - db.session.rollback() @pytest.mark.parametrize("attempt", ["name", "NAME", "alias", "ALIAS"]) diff --git a/server/szurubooru/tests/search/configs/test_pool_search_config.py b/server/szurubooru/tests/search/configs/test_pool_search_config.py index 202635c6..1103ec40 100644 --- a/server/szurubooru/tests/search/configs/test_pool_search_config.py +++ b/server/szurubooru/tests/search/configs/test_pool_search_config.py @@ -136,8 +136,6 @@ def test_escaping( ) db.session.flush() - if db_driver and db.session.get_bind().driver != db_driver: - pytest.xfail() if expected_pool_names is None: with pytest.raises(errors.SearchError): executor.execute(input, offset=0, limit=100) diff --git a/server/szurubooru/tests/search/configs/test_tag_search_config.py b/server/szurubooru/tests/search/configs/test_tag_search_config.py index 8175b73c..9fe9a80e 100644 --- a/server/szurubooru/tests/search/configs/test_tag_search_config.py +++ b/server/szurubooru/tests/search/configs/test_tag_search_config.py @@ -134,8 +134,6 @@ def test_escaping(executor, tag_factory, input, expected_tag_names, db_driver): ) db.session.flush() - if db_driver and db.session.get_bind().driver != db_driver: - pytest.xfail() if expected_tag_names is None: with pytest.raises(errors.SearchError): executor.execute(input, offset=0, limit=100)