server/tests: update func.posts tests

This commit is contained in:
rr- 2016-08-14 11:48:19 +02:00
parent 264f9ee70b
commit d102c9bdba
2 changed files with 49 additions and 30 deletions

View file

@ -1,7 +1,7 @@
import os import os
import datetime
import unittest.mock import unittest.mock
import pytest import pytest
from datetime import datetime
from szurubooru import db from szurubooru import db
from szurubooru.func import posts, users, comments, snapshots, tags, images from szurubooru.func import posts, users, comments, snapshots, tags, images
@ -61,7 +61,7 @@ def test_serialize_note():
'text': '...' 'text': '...'
} }
def test_serialize_empty_post(): def test_serialize_post_when_empty():
assert posts.serialize_post(None, None) is None assert posts.serialize_post(None, None) is None
def test_serialize_post( def test_serialize_post(
@ -80,8 +80,8 @@ def test_serialize_post(
auth_user = user_factory(name='auth user') auth_user = user_factory(name='auth user')
post = db.Post() post = db.Post()
post.post_id = 1 post.post_id = 1
post.creation_time = datetime.datetime(1997, 1, 1) post.creation_time = datetime(1997, 1, 1)
post.last_edit_time = datetime.datetime(1998, 1, 1) post.last_edit_time = datetime(1998, 1, 1)
post.tags = [ post.tags = [
tag_factory(names=['tag1', 'tag2']), tag_factory(names=['tag1', 'tag2']),
tag_factory(names=['tag3']) tag_factory(names=['tag3'])
@ -105,26 +105,26 @@ def test_serialize_post(
db.PostFavorite( db.PostFavorite(
post=post, post=post,
user=user_factory(name='fav1'), user=user_factory(name='fav1'),
time=datetime.datetime(1800, 1, 1)), time=datetime(1800, 1, 1)),
db.PostFeature( db.PostFeature(
post=post, post=post,
user=user_factory(), user=user_factory(),
time=datetime.datetime(1999, 1, 1)), time=datetime(1999, 1, 1)),
db.PostScore( db.PostScore(
post=post, post=post,
user=auth_user, user=auth_user,
score=-1, score=-1,
time=datetime.datetime(1800, 1, 1)), time=datetime(1800, 1, 1)),
db.PostScore( db.PostScore(
post=post, post=post,
user=user_factory(), user=user_factory(),
score=1, score=1,
time=datetime.datetime(1800, 1, 1)), time=datetime(1800, 1, 1)),
db.PostScore( db.PostScore(
post=post, post=post,
user=user_factory(), user=user_factory(),
score=1, score=1,
time=datetime.datetime(1800, 1, 1))]) time=datetime(1800, 1, 1))])
db.session.flush() db.session.flush()
result = posts.serialize_post(post, auth_user) result = posts.serialize_post(post, auth_user)
@ -133,8 +133,8 @@ def test_serialize_post(
assert result == { assert result == {
'id': 1, 'id': 1,
'version': 1, 'version': 1,
'creationTime': datetime.datetime(1997, 1, 1), 'creationTime': datetime(1997, 1, 1),
'lastEditTime': datetime.datetime(1998, 1, 1), 'lastEditTime': datetime(1998, 1, 1),
'safety': 'safe', 'safety': 'safe',
'source': '4gag', 'source': '4gag',
'type': 'image', 'type': 'image',
@ -158,7 +158,7 @@ def test_serialize_post(
'noteCount': 0, 'noteCount': 0,
'featureCount': 1, 'featureCount': 1,
'relationCount': 0, 'relationCount': 0,
'lastFeatureTime': datetime.datetime(1999, 1, 1), 'lastFeatureTime': datetime(1999, 1, 1),
'favoritedBy': ['fav1'], 'favoritedBy': ['fav1'],
'hasCustomThumbnail': True, 'hasCustomThumbnail': True,
'mimeType': 'image/jpeg', 'mimeType': 'image/jpeg',
@ -166,6 +166,18 @@ def test_serialize_post(
'comments': ['commenter1', 'commenter2'], 'comments': ['commenter1', 'commenter2'],
} }
def test_serialize_micro_post(post_factory, user_factory):
with unittest.mock.patch('szurubooru.func.posts.get_post_thumbnail_url'):
posts.get_post_thumbnail_url.return_value = 'https://example.com/thumb.png'
auth_user = user_factory()
post = post_factory()
db.session.add(post)
db.session.flush()
assert posts.serialize_micro_post(post, auth_user) == {
'id': post.post_id,
'thumbnailUrl': 'https://example.com/thumb.png',
}
def test_get_post_count(post_factory): def test_get_post_count(post_factory):
previous_count = posts.get_post_count() previous_count = posts.get_post_count()
db.session.add_all([post_factory(), post_factory()]) db.session.add_all([post_factory(), post_factory()])
@ -179,6 +191,8 @@ def test_try_get_post_by_id(post_factory):
db.session.flush() db.session.flush()
assert posts.try_get_post_by_id(post.post_id) == post assert posts.try_get_post_by_id(post.post_id) == post
assert posts.try_get_post_by_id(post.post_id + 1) is None assert posts.try_get_post_by_id(post.post_id + 1) is None
with pytest.raises(posts.InvalidPostIdError):
posts.get_post_by_id('-')
def test_get_post_by_id(post_factory): def test_get_post_by_id(post_factory):
post = post_factory() post = post_factory()
@ -187,6 +201,8 @@ def test_get_post_by_id(post_factory):
assert posts.get_post_by_id(post.post_id) == post assert posts.get_post_by_id(post.post_id) == post
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
posts.get_post_by_id(post.post_id + 1) posts.get_post_by_id(post.post_id + 1)
with pytest.raises(posts.InvalidPostIdError):
posts.get_post_by_id('-')
def test_create_post(user_factory, fake_datetime): def test_create_post(user_factory, fake_datetime):
with unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ with unittest.mock.patch('szurubooru.func.posts.update_post_content'), \
@ -194,7 +210,7 @@ def test_create_post(user_factory, fake_datetime):
fake_datetime('1997-01-01'): fake_datetime('1997-01-01'):
auth_user = user_factory() auth_user = user_factory()
post, new_tags = posts.create_post('content', ['tag'], auth_user) post, new_tags = posts.create_post('content', ['tag'], auth_user)
assert post.creation_time == datetime.datetime(1997, 1, 1) assert post.creation_time == datetime(1997, 1, 1)
assert post.last_edit_time is None assert post.last_edit_time is None
posts.update_post_tags.assert_called_once_with(post, ['tag']) posts.update_post_tags.assert_called_once_with(post, ['tag'])
posts.update_post_content.assert_called_once_with(post, 'content') posts.update_post_content.assert_called_once_with(post, 'content')
@ -209,7 +225,7 @@ def test_update_post_safety(input_safety, expected_safety):
posts.update_post_safety(post, input_safety) posts.update_post_safety(post, input_safety)
assert post.safety == expected_safety assert post.safety == expected_safety
def test_update_post_invalid_safety(): def test_update_post_safety_with_invalid_string():
post = db.Post() post = db.Post()
with pytest.raises(posts.InvalidPostSafetyError): with pytest.raises(posts.InvalidPostSafetyError):
posts.update_post_safety(post, 'bad') posts.update_post_safety(post, 'bad')
@ -219,7 +235,7 @@ def test_update_post_source():
posts.update_post_source(post, 'x') posts.update_post_source(post, 'x')
assert post.source == 'x' assert post.source == 'x'
def test_update_post_invalid_source(): def test_update_post_source_with_too_long_string():
post = db.Post() post = db.Post()
with pytest.raises(posts.InvalidPostSourceError): with pytest.raises(posts.InvalidPostSourceError):
posts.update_post_source(post, 'x' * 1000) posts.update_post_source(post, 'x' * 1000)
@ -277,7 +293,7 @@ def test_update_post_content_to_existing_content(
with pytest.raises(posts.PostAlreadyUploadedError): with pytest.raises(posts.PostAlreadyUploadedError):
posts.update_post_content(another_post, read_asset('png.png')) posts.update_post_content(another_post, read_asset('png.png'))
def test_update_post_content_broken_content( def test_update_post_content_with_broken_content(
tmpdir, config_injector, post_factory, read_asset): tmpdir, config_injector, post_factory, read_asset):
# the rationale behind this behavior is to salvage user upload even if the # the rationale behind this behavior is to salvage user upload even if the
# server software thinks it's broken. chances are the server is wrong, # server software thinks it's broken. chances are the server is wrong,
@ -298,7 +314,7 @@ def test_update_post_content_broken_content(
assert post.canvas_height is None assert post.canvas_height is None
@pytest.mark.parametrize('input_content', [None, b'not a media file']) @pytest.mark.parametrize('input_content', [None, b'not a media file'])
def test_update_post_invalid_content(input_content): def test_update_post_content_with_invalid_content(input_content):
post = db.Post() post = db.Post()
with pytest.raises(posts.InvalidPostContentError): with pytest.raises(posts.InvalidPostContentError):
posts.update_post_content(post, input_content) posts.update_post_content(post, input_content)
@ -340,7 +356,7 @@ def test_update_post_thumbnail_to_default(
assert not os.path.exists(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat') assert not os.path.exists(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat')
assert os.path.exists(str(tmpdir) + '/data/generated-thumbnails/1.jpg') assert os.path.exists(str(tmpdir) + '/data/generated-thumbnails/1.jpg')
def test_update_post_thumbnail_broken_thumbnail( def test_update_post_thumbnail_with_broken_thumbnail(
tmpdir, config_injector, read_asset, post_factory): tmpdir, config_injector, read_asset, post_factory):
config_injector({ config_injector({
'data_dir': str(tmpdir.mkdir('data')), 'data_dir': str(tmpdir.mkdir('data')),
@ -363,7 +379,7 @@ def test_update_post_thumbnail_broken_thumbnail(
assert image.width == 1 assert image.width == 1
assert image.height == 1 assert image.height == 1
def test_update_post_content_leaves_custom_thumbnail( def test_update_post_content_leaving_custom_thumbnail(
tmpdir, config_injector, read_asset, post_factory): tmpdir, config_injector, read_asset, post_factory):
config_injector({ config_injector({
'data_dir': str(tmpdir.mkdir('data')), 'data_dir': str(tmpdir.mkdir('data')),
@ -403,7 +419,7 @@ def test_update_post_relations(post_factory):
assert post.relations[0].post_id == relation1.post_id assert post.relations[0].post_id == relation1.post_id
assert post.relations[1].post_id == relation2.post_id assert post.relations[1].post_id == relation2.post_id
def test_relation_bidirectionality(post_factory): def test_update_post_relations_bidirectionality(post_factory):
relation1 = post_factory() relation1 = post_factory()
relation2 = post_factory() relation2 = post_factory()
db.session.add_all([relation1, relation2]) db.session.add_all([relation1, relation2])
@ -414,7 +430,7 @@ def test_relation_bidirectionality(post_factory):
assert len(post.relations) == 1 assert len(post.relations) == 1
assert post.relations[0].post_id == relation2.post_id assert post.relations[0].post_id == relation2.post_id
def test_update_post_non_existing_relations(): def test_update_post_relations_with_nonexisting_posts():
post = db.Post() post = db.Post()
with pytest.raises(posts.InvalidPostRelationError): with pytest.raises(posts.InvalidPostRelationError):
posts.update_post_relations(post, [100]) posts.update_post_relations(post, [100])
@ -452,7 +468,7 @@ def test_update_post_notes():
[{'polygon': [[0, 0], [0, 0], [0, 1]], 'text': None}], [{'polygon': [[0, 0], [0, 0], [0, 1]], 'text': None}],
[{'polygon': [[0, 0], [0, 0], [0, 1]]}], [{'polygon': [[0, 0], [0, 0], [0, 1]]}],
]) ])
def test_update_post_invalid_notes(input): def test_update_post_notes_with_invalid_content(input):
post = db.Post() post = db.Post()
with pytest.raises(posts.InvalidPostNoteError): with pytest.raises(posts.InvalidPostNoteError):
posts.update_post_notes(post, input) posts.update_post_notes(post, input)
@ -462,18 +478,23 @@ def test_update_post_flags():
posts.update_post_flags(post, ['loop']) posts.update_post_flags(post, ['loop'])
assert post.flags == ['loop'] assert post.flags == ['loop']
def test_update_post_invalid_flags(): def test_update_post_flags_with_invalid_content():
post = db.Post() post = db.Post()
with pytest.raises(posts.InvalidPostFlagError): with pytest.raises(posts.InvalidPostFlagError):
posts.update_post_flags(post, ['invalid']) posts.update_post_flags(post, ['invalid'])
def test_featuring_post(post_factory, user_factory): def test_feature_post(post_factory, user_factory):
post = post_factory() post = post_factory()
user = user_factory() user = user_factory()
previous_featured_post = posts.try_get_featured_post() previous_featured_post = posts.try_get_featured_post()
posts.feature_post(post, user) posts.feature_post(post, user)
new_featured_post = posts.try_get_featured_post() new_featured_post = posts.try_get_featured_post()
assert previous_featured_post is None assert previous_featured_post is None
assert new_featured_post == post assert new_featured_post == post
def test_delete(post_factory):
post = post_factory()
db.session.add(post)
assert posts.get_post_count() == 1
posts.delete(post)
assert posts.get_post_count() == 0

View file

@ -22,8 +22,6 @@ def score_factory(user_factory):
@pytest.fixture @pytest.fixture
def note_factory(): def note_factory():
def factory(post=None): def factory(post=None):
if post:
return db.PostNote(polygon='...', text='...', post=post)
return db.PostNote(polygon='...', text='...') return db.PostNote(polygon='...', text='...')
return factory return factory