server/tests: use postgresql test database
This commit is contained in:
parent
e19d7041d1
commit
0e6427d8bc
10 changed files with 21 additions and 57 deletions
|
@ -1,4 +1,5 @@
|
||||||
pytest>=2.9.1
|
pytest>=2.9.1
|
||||||
pytest-cov>=2.2.1
|
pytest-cov>=2.2.1
|
||||||
|
pytest-pgsql>=1.1.1
|
||||||
freezegun>=0.3.6
|
freezegun>=0.3.6
|
||||||
pycodestyle>=2.0.0
|
pycodestyle>=2.0.0
|
||||||
|
|
|
@ -13,10 +13,12 @@ USER root
|
||||||
RUN apk --no-cache add \
|
RUN apk --no-cache add \
|
||||||
py3-pytest \
|
py3-pytest \
|
||||||
py3-pytest-cov \
|
py3-pytest-cov \
|
||||||
|
postgresql \
|
||||||
&& \
|
&& \
|
||||||
pip3 install \
|
pip3 install \
|
||||||
--no-cache-dir \
|
--no-cache-dir \
|
||||||
--disable-pip-version-check \
|
--disable-pip-version-check \
|
||||||
|
pytest-pgsql \
|
||||||
freezegun
|
freezegun
|
||||||
USER app
|
USER app
|
||||||
ENV POSTGRES_HOST=x \
|
ENV POSTGRES_HOST=x \
|
||||||
|
|
|
@ -7,8 +7,8 @@ from szurubooru import config
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
_data = threading.local()
|
_data = threading.local()
|
||||||
_engine = sa.create_engine(config.config['database']) # type: Any
|
_engine = sa.create_engine(config.config['database']) # type: Any
|
||||||
sessionmaker = sa.orm.sessionmaker(bind=_engine, autoflush=False) # type: Any
|
_sessionmaker = sa.orm.sessionmaker(bind=_engine, autoflush=False) # type: Any
|
||||||
session = sa.orm.scoped_session(sessionmaker) # type: Any
|
session = sa.orm.scoped_session(_sessionmaker) # type: Any
|
||||||
|
|
||||||
|
|
||||||
def get_session() -> Any:
|
def get_session() -> Any:
|
||||||
|
|
|
@ -362,10 +362,11 @@ def create_post(
|
||||||
post.type = ''
|
post.type = ''
|
||||||
post.checksum = ''
|
post.checksum = ''
|
||||||
post.mime_type = ''
|
post.mime_type = ''
|
||||||
db.session.add(post)
|
|
||||||
|
|
||||||
update_post_content(post, content)
|
update_post_content(post, content)
|
||||||
new_tags = update_post_tags(post, tag_names)
|
new_tags = update_post_tags(post, tag_names)
|
||||||
|
|
||||||
|
db.session.add(post)
|
||||||
return post, new_tags
|
return post, new_tags
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from typing import Any, Optional, Dict, Callable
|
from typing import Any, Optional, Dict, Callable
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
import sqlalchemy as sa
|
||||||
from szurubooru import db, model
|
from szurubooru import db, model
|
||||||
from szurubooru.func import diff, users
|
from szurubooru.func import diff, users
|
||||||
|
|
||||||
|
@ -104,7 +105,7 @@ def modify(entity: model.Base, auth_user: Optional[model.User]) -> None:
|
||||||
snapshot = _create(model.Snapshot.OPERATION_MODIFIED, entity, auth_user)
|
snapshot = _create(model.Snapshot.OPERATION_MODIFIED, entity, auth_user)
|
||||||
snapshot_factory = _snapshot_factories[snapshot.resource_type]
|
snapshot_factory = _snapshot_factories[snapshot.resource_type]
|
||||||
|
|
||||||
detached_session = db.sessionmaker()
|
detached_session = sa.orm.sessionmaker(bind=db.session.get_bind())()
|
||||||
detached_entity = detached_session.query(table).get(snapshot.resource_pkey)
|
detached_entity = detached_session.query(table).get(snapshot.resource_pkey)
|
||||||
assert detached_entity, 'Entity not found in DB, have you committed it?'
|
assert detached_entity, 'Entity not found in DB, have you committed it?'
|
||||||
detached_snapshot = snapshot_factory(detached_entity)
|
detached_snapshot = snapshot_factory(detached_entity)
|
||||||
|
|
|
@ -275,7 +275,6 @@ def test_errors_not_spending_ids(
|
||||||
params={'safety': 'safe', 'tags': []},
|
params={'safety': 'safe', 'tags': []},
|
||||||
files={'content': read_asset('png.png')},
|
files={'content': read_asset('png.png')},
|
||||||
user=auth_user))
|
user=auth_user))
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
# erroreous request (duplicate post)
|
# erroreous request (duplicate post)
|
||||||
with pytest.raises(posts.PostAlreadyUploadedError):
|
with pytest.raises(posts.PostAlreadyUploadedError):
|
||||||
|
@ -284,7 +283,6 @@ def test_errors_not_spending_ids(
|
||||||
params={'safety': 'safe', 'tags': []},
|
params={'safety': 'safe', 'tags': []},
|
||||||
files={'content': read_asset('png.png')},
|
files={'content': read_asset('png.png')},
|
||||||
user=auth_user))
|
user=auth_user))
|
||||||
db.session.rollback()
|
|
||||||
|
|
||||||
# successful request
|
# successful request
|
||||||
with patch('szurubooru.func.posts.serialize_post'), \
|
with patch('szurubooru.func.posts.serialize_post'), \
|
||||||
|
|
|
@ -11,37 +11,6 @@ import sqlalchemy as sa
|
||||||
from szurubooru import config, db, model, rest
|
from szurubooru import config, db, model, rest
|
||||||
|
|
||||||
|
|
||||||
class QueryCounter:
|
|
||||||
def __init__(self):
|
|
||||||
self._statements = []
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
self._statements = []
|
|
||||||
|
|
||||||
def __exit__(self, *args, **kwargs):
|
|
||||||
self._statements = []
|
|
||||||
|
|
||||||
def create_before_cursor_execute(self):
|
|
||||||
def before_cursor_execute(
|
|
||||||
_conn, _cursor, statement, _params, _context, _executemany):
|
|
||||||
self._statements.append(statement)
|
|
||||||
return before_cursor_execute
|
|
||||||
|
|
||||||
@property
|
|
||||||
def statements(self):
|
|
||||||
return self._statements
|
|
||||||
|
|
||||||
|
|
||||||
_query_counter = QueryCounter()
|
|
||||||
_engine = sa.create_engine('sqlite:///:memory:')
|
|
||||||
model.Base.metadata.drop_all(bind=_engine)
|
|
||||||
model.Base.metadata.create_all(bind=_engine)
|
|
||||||
sa.event.listen(
|
|
||||||
_engine,
|
|
||||||
'before_cursor_execute',
|
|
||||||
_query_counter.create_before_cursor_execute())
|
|
||||||
|
|
||||||
|
|
||||||
def get_unique_name():
|
def get_unique_name():
|
||||||
alphabet = string.ascii_letters + string.digits
|
alphabet = string.ascii_letters + string.digits
|
||||||
return ''.join(random.choice(alphabet) for _ in range(8))
|
return ''.join(random.choice(alphabet) for _ in range(8))
|
||||||
|
@ -58,11 +27,6 @@ def fake_datetime():
|
||||||
return injector
|
return injector
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def query_counter():
|
|
||||||
return _query_counter
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope='session')
|
@pytest.fixture(scope='session')
|
||||||
def query_logger(pytestconfig):
|
def query_logger(pytestconfig):
|
||||||
if pytestconfig.option.verbose > 0:
|
if pytestconfig.option.verbose > 0:
|
||||||
|
@ -75,17 +39,13 @@ def query_logger(pytestconfig):
|
||||||
|
|
||||||
|
|
||||||
@pytest.yield_fixture(scope='function', autouse=True)
|
@pytest.yield_fixture(scope='function', autouse=True)
|
||||||
def session(query_logger): # pylint: disable=unused-argument
|
def session(query_logger, postgresql_db): # pylint: disable=unused-argument
|
||||||
db.sessionmaker = sa.orm.sessionmaker(
|
db.session = postgresql_db.session
|
||||||
bind=_engine, autoflush=False)
|
postgresql_db.create_table(*model.Base.metadata.sorted_tables)
|
||||||
db.session = sa.orm.scoped_session(db.sessionmaker)
|
|
||||||
try:
|
try:
|
||||||
yield db.session
|
yield postgresql_db.session
|
||||||
finally:
|
finally:
|
||||||
db.session.remove()
|
postgresql_db.reset_db()
|
||||||
for table in reversed(model.Base.metadata.sorted_tables):
|
|
||||||
db.session.execute(table.delete())
|
|
||||||
db.session.commit()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
@ -139,7 +139,7 @@ def test_create(tag_factory, user_factory):
|
||||||
|
|
||||||
|
|
||||||
def test_modify_saves_non_empty_diffs(post_factory, user_factory):
|
def test_modify_saves_non_empty_diffs(post_factory, user_factory):
|
||||||
if 'sqlite' in db.sessionmaker.kw['bind'].driver:
|
if 'sqlite' in db.session.get_bind().driver:
|
||||||
pytest.xfail(
|
pytest.xfail(
|
||||||
'SQLite doesn\'t support transaction isolation, '
|
'SQLite doesn\'t support transaction isolation, '
|
||||||
'which is required to retrieve original entity')
|
'which is required to retrieve original entity')
|
||||||
|
|
|
@ -47,7 +47,7 @@ def test_serialize_tag_when_empty():
|
||||||
def test_serialize_tag(post_factory, tag_factory, tag_category_factory):
|
def test_serialize_tag(post_factory, tag_factory, tag_category_factory):
|
||||||
cat = tag_category_factory(name='cat')
|
cat = tag_category_factory(name='cat')
|
||||||
tag = tag_factory(names=['tag1', 'tag2'], category=cat)
|
tag = tag_factory(names=['tag1', 'tag2'], category=cat)
|
||||||
tag.tag_id = 1
|
# tag.tag_id = 1
|
||||||
tag.description = 'description'
|
tag.description = 'description'
|
||||||
tag.suggestions = [
|
tag.suggestions = [
|
||||||
tag_factory(names=['sug1'], category=cat),
|
tag_factory(names=['sug1'], category=cat),
|
||||||
|
@ -58,12 +58,14 @@ def test_serialize_tag(post_factory, tag_factory, tag_category_factory):
|
||||||
tag_factory(names=['impl2'], category=cat),
|
tag_factory(names=['impl2'], category=cat),
|
||||||
]
|
]
|
||||||
tag.last_edit_time = datetime(1998, 1, 1)
|
tag.last_edit_time = datetime(1998, 1, 1)
|
||||||
|
|
||||||
post1 = post_factory()
|
post1 = post_factory()
|
||||||
post2 = post_factory()
|
|
||||||
post1.tags = [tag]
|
post1.tags = [tag]
|
||||||
|
post2 = post_factory()
|
||||||
post2.tags = [tag]
|
post2.tags = [tag]
|
||||||
db.session.add_all([tag, post1, post2])
|
db.session.add_all([tag, post1, post2])
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
|
||||||
result = tags.serialize_tag(tag)
|
result = tags.serialize_tag(tag)
|
||||||
result['suggestions'].sort(key=lambda relation: relation['names'][0])
|
result['suggestions'].sort(key=lambda relation: relation['names'][0])
|
||||||
result['implications'].sort(key=lambda relation: relation['names'][0])
|
result['implications'].sort(key=lambda relation: relation['names'][0])
|
||||||
|
|
|
@ -86,9 +86,8 @@ def test_escaping(
|
||||||
])
|
])
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
|
||||||
if db_driver:
|
if db_driver and db.session.get_bind().driver != db_driver:
|
||||||
if db.sessionmaker.kw['bind'].driver != db_driver:
|
pytest.xfail()
|
||||||
pytest.xfail()
|
|
||||||
if expected_tag_names is None:
|
if expected_tag_names is None:
|
||||||
with pytest.raises(errors.SearchError):
|
with pytest.raises(errors.SearchError):
|
||||||
executor.execute(input, offset=0, limit=100)
|
executor.execute(input, offset=0, limit=100)
|
||||||
|
|
Loading…
Reference in a new issue