diff --git a/server/dev-requirements.txt b/server/dev-requirements.txt index c9dc2348..04788a31 100644 --- a/server/dev-requirements.txt +++ b/server/dev-requirements.txt @@ -1,4 +1,5 @@ pytest>=2.9.1 pytest-cov>=2.2.1 +pytest-pgsql>=1.1.1 freezegun>=0.3.6 pycodestyle>=2.0.0 diff --git a/server/hooks/test b/server/hooks/test index 6c4c0435..5f642f1a 100755 --- a/server/hooks/test +++ b/server/hooks/test @@ -13,10 +13,12 @@ USER root RUN apk --no-cache add \ py3-pytest \ py3-pytest-cov \ + postgresql \ && \ pip3 install \ --no-cache-dir \ --disable-pip-version-check \ + pytest-pgsql \ freezegun USER app ENV POSTGRES_HOST=x \ diff --git a/server/szurubooru/db.py b/server/szurubooru/db.py index f90bfaf9..561b7484 100644 --- a/server/szurubooru/db.py +++ b/server/szurubooru/db.py @@ -7,8 +7,8 @@ from szurubooru import config # pylint: disable=invalid-name _data = threading.local() _engine = sa.create_engine(config.config['database']) # type: Any -sessionmaker = sa.orm.sessionmaker(bind=_engine, autoflush=False) # type: Any -session = sa.orm.scoped_session(sessionmaker) # type: Any +_sessionmaker = sa.orm.sessionmaker(bind=_engine, autoflush=False) # type: Any +session = sa.orm.scoped_session(_sessionmaker) # type: Any def get_session() -> Any: diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index ee58633b..73407f66 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -362,10 +362,11 @@ def create_post( post.type = '' post.checksum = '' post.mime_type = '' - db.session.add(post) update_post_content(post, content) new_tags = update_post_tags(post, tag_names) + + db.session.add(post) return post, new_tags diff --git a/server/szurubooru/func/snapshots.py b/server/szurubooru/func/snapshots.py index 240c3bce..eade3a19 100644 --- a/server/szurubooru/func/snapshots.py +++ b/server/szurubooru/func/snapshots.py @@ -1,5 +1,6 @@ from typing import Any, Optional, Dict, Callable from datetime import datetime +import sqlalchemy as sa from szurubooru import db, model from szurubooru.func import diff, users @@ -104,7 +105,7 @@ def modify(entity: model.Base, auth_user: Optional[model.User]) -> None: snapshot = _create(model.Snapshot.OPERATION_MODIFIED, entity, auth_user) snapshot_factory = _snapshot_factories[snapshot.resource_type] - detached_session = db.sessionmaker() + detached_session = sa.orm.sessionmaker(bind=db.session.get_bind())() detached_entity = detached_session.query(table).get(snapshot.resource_pkey) assert detached_entity, 'Entity not found in DB, have you committed it?' detached_snapshot = snapshot_factory(detached_entity) diff --git a/server/szurubooru/tests/api/test_post_creating.py b/server/szurubooru/tests/api/test_post_creating.py index 6edcdd36..c7088978 100644 --- a/server/szurubooru/tests/api/test_post_creating.py +++ b/server/szurubooru/tests/api/test_post_creating.py @@ -275,7 +275,6 @@ def test_errors_not_spending_ids( params={'safety': 'safe', 'tags': []}, files={'content': read_asset('png.png')}, user=auth_user)) - db.session.commit() # erroreous request (duplicate post) with pytest.raises(posts.PostAlreadyUploadedError): @@ -284,7 +283,6 @@ def test_errors_not_spending_ids( params={'safety': 'safe', 'tags': []}, files={'content': read_asset('png.png')}, user=auth_user)) - db.session.rollback() # successful request with patch('szurubooru.func.posts.serialize_post'), \ diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index 45ce88ec..593771d8 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -11,37 +11,6 @@ import sqlalchemy as sa from szurubooru import config, db, model, rest -class QueryCounter: - def __init__(self): - self._statements = [] - - def __enter__(self): - self._statements = [] - - def __exit__(self, *args, **kwargs): - self._statements = [] - - def create_before_cursor_execute(self): - def before_cursor_execute( - _conn, _cursor, statement, _params, _context, _executemany): - self._statements.append(statement) - return before_cursor_execute - - @property - def statements(self): - return self._statements - - -_query_counter = QueryCounter() -_engine = sa.create_engine('sqlite:///:memory:') -model.Base.metadata.drop_all(bind=_engine) -model.Base.metadata.create_all(bind=_engine) -sa.event.listen( - _engine, - 'before_cursor_execute', - _query_counter.create_before_cursor_execute()) - - def get_unique_name(): alphabet = string.ascii_letters + string.digits return ''.join(random.choice(alphabet) for _ in range(8)) @@ -58,11 +27,6 @@ def fake_datetime(): return injector -@pytest.fixture() -def query_counter(): - return _query_counter - - @pytest.fixture(scope='session') def query_logger(pytestconfig): if pytestconfig.option.verbose > 0: @@ -75,17 +39,13 @@ def query_logger(pytestconfig): @pytest.yield_fixture(scope='function', autouse=True) -def session(query_logger): # pylint: disable=unused-argument - db.sessionmaker = sa.orm.sessionmaker( - bind=_engine, autoflush=False) - db.session = sa.orm.scoped_session(db.sessionmaker) +def session(query_logger, postgresql_db): # pylint: disable=unused-argument + db.session = postgresql_db.session + postgresql_db.create_table(*model.Base.metadata.sorted_tables) try: - yield db.session + yield postgresql_db.session finally: - db.session.remove() - for table in reversed(model.Base.metadata.sorted_tables): - db.session.execute(table.delete()) - db.session.commit() + postgresql_db.reset_db() @pytest.fixture diff --git a/server/szurubooru/tests/func/test_snapshots.py b/server/szurubooru/tests/func/test_snapshots.py index 09491990..053cf982 100644 --- a/server/szurubooru/tests/func/test_snapshots.py +++ b/server/szurubooru/tests/func/test_snapshots.py @@ -139,7 +139,7 @@ def test_create(tag_factory, user_factory): def test_modify_saves_non_empty_diffs(post_factory, user_factory): - if 'sqlite' in db.sessionmaker.kw['bind'].driver: + if 'sqlite' in db.session.get_bind().driver: pytest.xfail( 'SQLite doesn\'t support transaction isolation, ' 'which is required to retrieve original entity') diff --git a/server/szurubooru/tests/func/test_tags.py b/server/szurubooru/tests/func/test_tags.py index 2f888ef5..673e37f8 100644 --- a/server/szurubooru/tests/func/test_tags.py +++ b/server/szurubooru/tests/func/test_tags.py @@ -47,7 +47,7 @@ def test_serialize_tag_when_empty(): def test_serialize_tag(post_factory, tag_factory, tag_category_factory): cat = tag_category_factory(name='cat') tag = tag_factory(names=['tag1', 'tag2'], category=cat) - tag.tag_id = 1 + # tag.tag_id = 1 tag.description = 'description' tag.suggestions = [ tag_factory(names=['sug1'], category=cat), @@ -58,12 +58,14 @@ def test_serialize_tag(post_factory, tag_factory, tag_category_factory): tag_factory(names=['impl2'], category=cat), ] tag.last_edit_time = datetime(1998, 1, 1) + post1 = post_factory() - post2 = post_factory() post1.tags = [tag] + post2 = post_factory() post2.tags = [tag] db.session.add_all([tag, post1, post2]) db.session.flush() + result = tags.serialize_tag(tag) result['suggestions'].sort(key=lambda relation: relation['names'][0]) result['implications'].sort(key=lambda relation: relation['names'][0]) diff --git a/server/szurubooru/tests/search/configs/test_tag_search_config.py b/server/szurubooru/tests/search/configs/test_tag_search_config.py index 8ea107f1..d3d2de3f 100644 --- a/server/szurubooru/tests/search/configs/test_tag_search_config.py +++ b/server/szurubooru/tests/search/configs/test_tag_search_config.py @@ -86,9 +86,8 @@ def test_escaping( ]) db.session.flush() - if db_driver: - if db.sessionmaker.kw['bind'].driver != db_driver: - pytest.xfail() + if db_driver and db.session.get_bind().driver != db_driver: + pytest.xfail() if expected_tag_names is None: with pytest.raises(errors.SearchError): executor.execute(input, offset=0, limit=100)