server/tests: use real database

I'm experimenting with snapshots and found following limitation of
SQLite: https://www.sqlite.org/isolation.html
This commit is contained in:
rr- 2016-08-15 17:53:01 +02:00
parent 0320a0b55b
commit 87b1ee4564
6 changed files with 117 additions and 78 deletions

View file

@ -25,6 +25,15 @@ database:
pass: # example: dog pass: # example: dog
name: # example: szuru name: # example: szuru
# required for runing the test suite
test_database:
schema: postgres
host: # example: localhost
port: # example: 5432
user: # example: szuru
pass: # example: dog
name: # example: szuru_test
# used to send password reminders # used to send password reminders
smtp: smtp:
host: # example: localhost host: # example: localhost

View file

@ -1,8 +1,9 @@
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
import contextlib import contextlib
import os import os
import datetime import random
import uuid import string
from datetime import datetime
import pytest import pytest
import freezegun import freezegun
import sqlalchemy import sqlalchemy
@ -30,8 +31,19 @@ class QueryCounter(object):
return self._statements return self._statements
if not config.config['test_database']['host']:
raise RuntimeError('Test database not configured.')
_query_counter = QueryCounter() _query_counter = QueryCounter()
_engine = sqlalchemy.create_engine('sqlite:///:memory:') _engine = sqlalchemy.create_engine(
'{schema}://{user}:{password}@{host}:{port}/{name}'.format(
schema=config.config['test_database']['schema'],
user=config.config['test_database']['user'],
password=config.config['test_database']['pass'],
host=config.config['test_database']['host'],
port=config.config['test_database']['port'],
name=config.config['test_database']['name']))
db.Base.metadata.drop_all(bind=_engine)
db.Base.metadata.create_all(bind=_engine) db.Base.metadata.create_all(bind=_engine)
sqlalchemy.event.listen( sqlalchemy.event.listen(
_engine, _engine,
@ -40,7 +52,8 @@ sqlalchemy.event.listen(
def get_unique_name(): def get_unique_name():
return str(uuid.uuid4()) alphabet = string.ascii_letters + string.digits
return ''.join(random.choice(alphabet) for _ in range(8))
@pytest.fixture @pytest.fixture
@ -72,16 +85,15 @@ def query_logger():
@pytest.yield_fixture(scope='function', autouse=True) @pytest.yield_fixture(scope='function', autouse=True)
def session(query_logger): # pylint: disable=unused-argument def session(query_logger): # pylint: disable=unused-argument
session_maker = sqlalchemy.orm.sessionmaker(bind=_engine) db.sessionmaker = sqlalchemy.orm.sessionmaker(bind=_engine)
session = sqlalchemy.orm.scoped_session(session_maker) db.session = sqlalchemy.orm.scoped_session(db.sessionmaker)
db.session = session
try: try:
yield session yield db.session
finally: finally:
session.remove() db.session.remove()
for table in reversed(db.Base.metadata.sorted_tables): for table in reversed(db.Base.metadata.sorted_tables):
session.execute(table.delete()) db.session.execute(table.delete())
session.commit() db.session.commit()
@pytest.fixture @pytest.fixture
@ -115,7 +127,7 @@ def user_factory():
user.password_hash = 'dummy' user.password_hash = 'dummy'
user.email = email user.email = email
user.rank = rank user.rank = rank
user.creation_time = datetime.datetime(1997, 1, 1) user.creation_time = datetime(1997, 1, 1)
user.avatar_style = db.User.AVATAR_GRAVATAR user.avatar_style = db.User.AVATAR_GRAVATAR
return user return user
return factory return factory
@ -142,7 +154,7 @@ def tag_factory():
tag.names = [ tag.names = [
db.TagName(name) for name in names or [get_unique_name()]] db.TagName(name) for name in names or [get_unique_name()]]
tag.category = category tag.category = category
tag.creation_time = datetime.datetime(1996, 1, 1) tag.creation_time = datetime(1996, 1, 1)
return tag return tag
return factory return factory
@ -162,14 +174,14 @@ def post_factory():
post.checksum = checksum post.checksum = checksum
post.flags = [] post.flags = []
post.mime_type = 'application/octet-stream' post.mime_type = 'application/octet-stream'
post.creation_time = datetime.datetime(1996, 1, 1) post.creation_time = datetime(1996, 1, 1)
return post return post
return factory return factory
@pytest.fixture @pytest.fixture
def comment_factory(user_factory, post_factory): def comment_factory(user_factory, post_factory):
def factory(user=None, post=None, text='dummy'): def factory(user=None, post=None, text='dummy', time=None):
if not user: if not user:
user = user_factory() user = user_factory()
db.session.add(user) db.session.add(user)
@ -180,7 +192,7 @@ def comment_factory(user_factory, post_factory):
comment.user = user comment.user = user
comment.post = post comment.post = post
comment.text = text comment.text = text
comment.creation_time = datetime.datetime(1996, 1, 1) comment.creation_time = time or datetime(1996, 1, 1)
return comment return comment
return factory return factory

View file

@ -145,12 +145,12 @@ def test_cascade_deletions(post_factory, user_factory, comment_factory):
snapshot.user = user snapshot.user = user
snapshot.creation_time = datetime(1997, 1, 1) snapshot.creation_time = datetime(1997, 1, 1)
snapshot.resource_type = '-' snapshot.resource_type = '-'
snapshot.resource_id = '-' snapshot.resource_id = 1
snapshot.resource_repr = '-' snapshot.resource_repr = '-'
snapshot.operation = '-' snapshot.operation = '-'
db.session.add_all([user, post, comment, snapshot]) db.session.add_all([user, post, comment, snapshot])
db.session.flush() db.session.commit()
assert not db.session.dirty assert not db.session.dirty
assert post.user is not None and post.user.user_id is not None assert post.user is not None and post.user.user_id is not None

View file

@ -35,8 +35,9 @@ def test_serialize_user(user_factory, comment_factory):
def test_try_get_comment(comment_factory): def test_try_get_comment(comment_factory):
comment = comment_factory() comment = comment_factory()
db.session.add(comment) db.session.add(comment)
assert comments.try_get_comment_by_id(999) is None db.session.flush()
assert comments.try_get_comment_by_id(1) is comment assert comments.try_get_comment_by_id(comment.comment_id + 1) is None
assert comments.try_get_comment_by_id(comment.comment_id) is comment
with pytest.raises(comments.InvalidCommentIdError): with pytest.raises(comments.InvalidCommentIdError):
comments.try_get_comment_by_id('-') comments.try_get_comment_by_id('-')
@ -44,9 +45,10 @@ def test_try_get_comment(comment_factory):
def test_get_comment(comment_factory): def test_get_comment(comment_factory):
comment = comment_factory() comment = comment_factory()
db.session.add(comment) db.session.add(comment)
db.session.flush()
with pytest.raises(comments.CommentNotFoundError): with pytest.raises(comments.CommentNotFoundError):
comments.get_comment_by_id(999) comments.get_comment_by_id(comment.comment_id + 1)
assert comments.get_comment_by_id(1) is comment assert comments.get_comment_by_id(comment.comment_id) is comment
with pytest.raises(comments.InvalidCommentIdError): with pytest.raises(comments.InvalidCommentIdError):
comments.get_comment_by_id('-') comments.get_comment_by_id('-')

View file

@ -110,8 +110,14 @@ def test_serialize_post(
db.session.flush() db.session.flush()
db.session.add_all([ db.session.add_all([
comment_factory(user=user_factory(name='commenter1'), post=post), comment_factory(
comment_factory(user=user_factory(name='commenter2'), post=post), user=user_factory(name='commenter1'),
post=post,
time=datetime(1999, 1, 1)),
comment_factory(
user=user_factory(name='commenter2'),
post=post,
time=datetime(1999, 1, 2)),
db.PostFavorite( db.PostFavorite(
post=post, post=post,
user=user_factory(name='fav1'), user=user_factory(name='fav1'),
@ -458,8 +464,8 @@ def test_update_post_relations(post_factory):
post = post_factory() post = post_factory()
posts.update_post_relations(post, [relation1.post_id, relation2.post_id]) posts.update_post_relations(post, [relation1.post_id, relation2.post_id])
assert len(post.relations) == 2 assert len(post.relations) == 2
assert post.relations[0].post_id == relation1.post_id assert sorted(r.post_id for r in post.relations) == [
assert post.relations[1].post_id == relation2.post_id relation1.post_id, relation2.post_id]
def test_update_post_relations_bidirectionality(post_factory): def test_update_post_relations_bidirectionality(post_factory):

View file

@ -110,7 +110,12 @@ def test_export_to_json(
export_path = os.path.join(str(tmpdir), 'tags.json') export_path = os.path.join(str(tmpdir), 'tags.json')
assert os.path.exists(export_path) assert os.path.exists(export_path)
with open(export_path, 'r') as handle: with open(export_path, 'r') as handle:
assert json.loads(handle.read()) == { actual_json = json.loads(handle.read())
assert actual_json['tags']
assert actual_json['categories']
actual_json['tags'].sort(key=lambda tag: tag['names'][0])
actual_json['categories'].sort(key=lambda category: category['name'])
assert actual_json == {
'tags': [ 'tags': [
{ {
'names': ['alias1', 'alias2'], 'names': ['alias1', 'alias2'],
@ -119,14 +124,14 @@ def test_export_to_json(
'suggestions': ['sug1', 'sug2'], 'suggestions': ['sug1', 'sug2'],
'implications': ['imp1', 'imp2'], 'implications': ['imp1', 'imp2'],
}, },
{'names': ['sug1'], 'usages': 0, 'category': 'cat1'},
{'names': ['sug2'], 'usages': 0, 'category': 'cat1'},
{'names': ['imp1'], 'usages': 0, 'category': 'cat1'}, {'names': ['imp1'], 'usages': 0, 'category': 'cat1'},
{'names': ['imp2'], 'usages': 0, 'category': 'cat1'}, {'names': ['imp2'], 'usages': 0, 'category': 'cat1'},
{'names': ['sug1'], 'usages': 0, 'category': 'cat1'},
{'names': ['sug2'], 'usages': 0, 'category': 'cat1'},
], ],
'categories': [ 'categories': [
{'name': 'cat2', 'color': 'white'},
{'name': 'cat1', 'color': 'black'}, {'name': 'cat1', 'color': 'black'},
{'name': 'cat2', 'color': 'white'},
] ]
} }
@ -164,76 +169,81 @@ def test_get_tag_by_name(name_to_search, expected_to_find, tag_factory):
tags.get_tag_by_name(name_to_search) tags.get_tag_by_name(name_to_search)
@pytest.mark.parametrize('names,expected_ids', [ @pytest.mark.parametrize('names,expected_indexes', [
([], []), ([], []),
(['name1'], [1]), (['name1'], [0]),
(['NAME1'], [1]), (['NAME1'], [0]),
(['alias1'], [1]), (['alias1'], [0]),
(['ALIAS1'], [1]), (['ALIAS1'], [0]),
(['name2'], [2]), (['name2'], [1]),
(['name1', 'name1'], [1]), (['name1', 'name1'], [0]),
(['name1', 'NAME1'], [1]), (['name1', 'NAME1'], [0]),
(['name1', 'alias1'], [1]), (['name1', 'alias1'], [0]),
(['name1', 'alias2'], [1, 2]), (['name1', 'alias2'], [0, 1]),
(['NAME1', 'alias2'], [1, 2]), (['NAME1', 'alias2'], [0, 1]),
(['name1', 'ALIAS2'], [1, 2]), (['name1', 'ALIAS2'], [0, 1]),
(['name2', 'alias1'], [1, 2]), (['name2', 'alias1'], [0, 1]),
]) ])
def test_get_tag_by_names(names, expected_ids, tag_factory): def test_get_tag_by_names(names, expected_indexes, tag_factory):
tag1 = tag_factory(names=['name1', 'ALIAS1']) input_tags = [
tag2 = tag_factory(names=['name2', 'ALIAS2']) tag_factory(names=['name1', 'ALIAS1']),
tag1.tag_id = 1 tag_factory(names=['name2', 'ALIAS2']),
tag2.tag_id = 2 ]
db.session.add_all([tag1, tag2]) db.session.add_all(input_tags)
db.session.flush()
expected_ids = [input_tags[i].tag_id for i in expected_indexes]
actual_ids = [tag.tag_id for tag in tags.get_tags_by_names(names)] actual_ids = [tag.tag_id for tag in tags.get_tags_by_names(names)]
assert actual_ids == expected_ids assert actual_ids == expected_ids
@pytest.mark.parametrize( @pytest.mark.parametrize(
'names,expected_ids,expected_created_names', [ 'names,expected_indexes,expected_created_names', [
([], [], []), ([], [], []),
(['name1'], [1], []), (['name1'], [0], []),
(['NAME1'], [1], []), (['NAME1'], [0], []),
(['alias1'], [1], []), (['alias1'], [0], []),
(['ALIAS1'], [1], []), (['ALIAS1'], [0], []),
(['name2'], [2], []), (['name2'], [1], []),
(['name1', 'name1'], [1], []), (['name1', 'name1'], [0], []),
(['name1', 'NAME1'], [1], []), (['name1', 'NAME1'], [0], []),
(['name1', 'alias1'], [1], []), (['name1', 'alias1'], [0], []),
(['name1', 'alias2'], [1, 2], []), (['name1', 'alias2'], [0, 1], []),
(['NAME1', 'alias2'], [1, 2], []), (['NAME1', 'alias2'], [0, 1], []),
(['name1', 'ALIAS2'], [1, 2], []), (['name1', 'ALIAS2'], [0, 1], []),
(['name2', 'alias1'], [1, 2], []), (['name2', 'alias1'], [0, 1], []),
(['new'], [], ['new']), (['new'], [], ['new']),
(['new', 'name1'], [1], ['new']), (['new', 'name1'], [0], ['new']),
(['new', 'NAME1'], [1], ['new']), (['new', 'NAME1'], [0], ['new']),
(['new', 'alias1'], [1], ['new']), (['new', 'alias1'], [0], ['new']),
(['new', 'ALIAS1'], [1], ['new']), (['new', 'ALIAS1'], [0], ['new']),
(['new', 'name2'], [2], ['new']), (['new', 'name2'], [1], ['new']),
(['new', 'name1', 'name1'], [1], ['new']), (['new', 'name1', 'name1'], [0], ['new']),
(['new', 'name1', 'NAME1'], [1], ['new']), (['new', 'name1', 'NAME1'], [0], ['new']),
(['new', 'name1', 'alias1'], [1], ['new']), (['new', 'name1', 'alias1'], [0], ['new']),
(['new', 'name1', 'alias2'], [1, 2], ['new']), (['new', 'name1', 'alias2'], [0, 1], ['new']),
(['new', 'NAME1', 'alias2'], [1, 2], ['new']), (['new', 'NAME1', 'alias2'], [0, 1], ['new']),
(['new', 'name1', 'ALIAS2'], [1, 2], ['new']), (['new', 'name1', 'ALIAS2'], [0, 1], ['new']),
(['new', 'name2', 'alias1'], [1, 2], ['new']), (['new', 'name2', 'alias1'], [0, 1], ['new']),
(['new', 'new'], [], ['new']), (['new', 'new'], [], ['new']),
(['new', 'NEW'], [], ['new']), (['new', 'NEW'], [], ['new']),
(['new', 'new2'], [], ['new', 'new2']), (['new', 'new2'], [], ['new', 'new2']),
]) ])
def test_get_or_create_tags_by_names( def test_get_or_create_tags_by_names(
names, names,
expected_ids, expected_indexes,
expected_created_names, expected_created_names,
tag_factory, tag_factory,
tag_category_factory, tag_category_factory,
config_injector): config_injector):
config_injector({'tag_name_regex': '.*'}) config_injector({'tag_name_regex': '.*'})
category = tag_category_factory() category = tag_category_factory()
tag1 = tag_factory(names=['name1', 'ALIAS1'], category=category) input_tags = [
tag2 = tag_factory(names=['name2', 'ALIAS2'], category=category) tag_factory(names=['name1', 'ALIAS1'], category=category),
db.session.add_all([tag1, tag2]) tag_factory(names=['name2', 'ALIAS2'], category=category),
]
db.session.add_all(input_tags)
result = tags.get_or_create_tags_by_names(names) result = tags.get_or_create_tags_by_names(names)
expected_ids = [input_tags[i].tag_id for i in expected_indexes]
actual_ids = [tag.tag_id for tag in result[0]] actual_ids = [tag.tag_id for tag in result[0]]
actual_created_names = [tag.names[0].name for tag in result[1]] actual_created_names = [tag.names[0].name for tag in result[1]]
assert actual_ids == expected_ids assert actual_ids == expected_ids