server/general: consistently use db.session

This commit is contained in:
rr- 2016-04-19 17:39:16 +02:00
parent fe56e376f6
commit 2e57a0746f
28 changed files with 351 additions and 380 deletions

View file

@ -1,6 +1,5 @@
''' Exports create_app. ''' ''' Exports create_app. '''
import json
import falcon import falcon
from szurubooru import api, errors, middleware from szurubooru import api, errors, middleware
@ -26,15 +25,13 @@ def _on_processing_error(ex, _request, _response, _params):
def create_method_not_allowed(allowed_methods): def create_method_not_allowed(allowed_methods):
allowed = ', '.join(allowed_methods) allowed = ', '.join(allowed_methods)
def method_not_allowed(request, response, **_kwargs):
def method_not_allowed(request, response, **kwargs):
response.status = falcon.status_codes.HTTP_405 response.status = falcon.status_codes.HTTP_405
response.set_header('Allow', allowed) response.set_header('Allow', allowed)
request.context.output = { request.context.output = {
'title': 'Method not allowed', 'title': 'Method not allowed',
'description': 'Allowed methods: %r' % allowed_methods, 'description': 'Allowed methods: %r' % allowed_methods,
} }
return method_not_allowed return method_not_allowed
def create_app(): def create_app():

View file

@ -23,8 +23,8 @@ class SearchExecutor(object):
count_query = filter_query.statement \ count_query = filter_query.statement \
.with_only_columns([sqlalchemy.func.count()]) \ .with_only_columns([sqlalchemy.func.count()]) \
.order_by(None) .order_by(None)
count = filter_query \ count = filter_query.session \
.session.execute(count_query) \ .execute(count_query) \
.scalar() .scalar()
return (count, entities) return (count, entities)

View file

@ -4,7 +4,7 @@ from szurubooru.search.base_search_config import BaseSearchConfig
class TagSearchConfig(BaseSearchConfig): class TagSearchConfig(BaseSearchConfig):
def create_query(self): def create_query(self):
return db.session().query(db.Tag) return db.session.query(db.Tag)
def finalize_query(self, query): def finalize_query(self, query):
return query.order_by(db.Tag.first_name.asc()) return query.order_by(db.Tag.first_name.asc())

View file

@ -6,7 +6,7 @@ class UserSearchConfig(BaseSearchConfig):
''' Executes searches related to the users. ''' ''' Executes searches related to the users. '''
def create_query(self): def create_query(self):
return db.session().query(db.User) return db.session.query(db.User)
def finalize_query(self, query): def finalize_query(self, query):
return query.order_by(db.User.name.asc()) return query.order_by(db.User.name.asc())

View file

@ -14,8 +14,8 @@ def password_reset_api(config_injector):
return api.PasswordResetApi() return api.PasswordResetApi()
def test_reset_sending_email( def test_reset_sending_email(
password_reset_api, session, context_factory, user_factory): password_reset_api, context_factory, user_factory):
session.add(user_factory( db.session.add(user_factory(
name='u1', rank='regular_user', email='user@example.com')) name='u1', rank='regular_user', email='user@example.com'))
for getter in ['u1', 'user@example.com']: for getter in ['u1', 'user@example.com']:
mailer.send_mail = mock.MagicMock() mailer.send_mail = mock.MagicMock()
@ -34,17 +34,17 @@ def test_trying_to_reset_non_existing(password_reset_api, context_factory):
password_reset_api.get(context_factory(), 'u1') password_reset_api.get(context_factory(), 'u1')
def test_trying_to_reset_without_email( def test_trying_to_reset_without_email(
password_reset_api, session, context_factory, user_factory): password_reset_api, context_factory, user_factory):
session.add(user_factory(name='u1', rank='regular_user', email=None)) db.session.add(user_factory(name='u1', rank='regular_user', email=None))
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
password_reset_api.get(context_factory(), 'u1') password_reset_api.get(context_factory(), 'u1')
def test_confirming_with_good_token( def test_confirming_with_good_token(
password_reset_api, context_factory, session, user_factory): password_reset_api, context_factory, user_factory):
user = user_factory( user = user_factory(
name='u1', rank='regular_user', email='user@example.com') name='u1', rank='regular_user', email='user@example.com')
old_hash = user.password_hash old_hash = user.password_hash
session.add(user) db.session.add(user)
context = context_factory( context = context_factory(
input={'token': '4ac0be176fb364f13ee6b634c43220e2'}) input={'token': '4ac0be176fb364f13ee6b634c43220e2'})
result = password_reset_api.post(context, 'u1') result = password_reset_api.post(context, 'u1')
@ -56,15 +56,15 @@ def test_trying_to_confirm_non_existing(password_reset_api, context_factory):
password_reset_api.post(context_factory(), 'u1') password_reset_api.post(context_factory(), 'u1')
def test_trying_to_confirm_without_token( def test_trying_to_confirm_without_token(
password_reset_api, context_factory, session, user_factory): password_reset_api, context_factory, user_factory):
session.add(user_factory( db.session.add(user_factory(
name='u1', rank='regular_user', email='user@example.com')) name='u1', rank='regular_user', email='user@example.com'))
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
password_reset_api.post(context_factory(input={}), 'u1') password_reset_api.post(context_factory(input={}), 'u1')
def test_trying_to_confirm_with_bad_token( def test_trying_to_confirm_with_bad_token(
password_reset_api, context_factory, session, user_factory): password_reset_api, context_factory, user_factory):
session.add(user_factory( db.session.add(user_factory(
name='u1', rank='regular_user', email='user@example.com')) name='u1', rank='regular_user', email='user@example.com'))
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
password_reset_api.post( password_reset_api.post(

View file

@ -28,35 +28,35 @@ def test_ctx(
return ret return ret
def test_deleting(test_ctx): def test_deleting(test_ctx):
db.session().add(test_ctx.tag_category_factory(name='root')) db.session.add(test_ctx.tag_category_factory(name='root'))
db.session().add(test_ctx.tag_category_factory(name='category')) db.session.add(test_ctx.tag_category_factory(name='category'))
db.session().commit() db.session.commit()
result = test_ctx.api.delete( result = test_ctx.api.delete(
test_ctx.context_factory( test_ctx.context_factory(
user=test_ctx.user_factory(rank='regular_user')), user=test_ctx.user_factory(rank='regular_user')),
'category') 'category')
assert result == {} assert result == {}
assert db.session().query(db.TagCategory).count() == 1 assert db.session.query(db.TagCategory).count() == 1
assert db.session().query(db.TagCategory).one().name == 'root' assert db.session.query(db.TagCategory).one().name == 'root'
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json')) assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json'))
def test_trying_to_delete_used(test_ctx, tag_factory): def test_trying_to_delete_used(test_ctx, tag_factory):
category = test_ctx.tag_category_factory(name='category') category = test_ctx.tag_category_factory(name='category')
db.session().add(category) db.session.add(category)
db.session().flush() db.session.flush()
tag = test_ctx.tag_factory(names=['tag'], category=category) tag = test_ctx.tag_factory(names=['tag'], category=category)
db.session().add(tag) db.session.add(tag)
db.session().commit() db.session.commit()
with pytest.raises(tag_categories.TagCategoryIsInUseError): with pytest.raises(tag_categories.TagCategoryIsInUseError):
test_ctx.api.delete( test_ctx.api.delete(
test_ctx.context_factory( test_ctx.context_factory(
user=test_ctx.user_factory(rank='regular_user')), user=test_ctx.user_factory(rank='regular_user')),
'category') 'category')
assert db.session().query(db.TagCategory).count() == 1 assert db.session.query(db.TagCategory).count() == 1
def test_trying_to_delete_last(test_ctx, tag_factory): def test_trying_to_delete_last(test_ctx, tag_factory):
db.session().add(test_ctx.tag_category_factory(name='root')) db.session.add(test_ctx.tag_category_factory(name='root'))
db.session().commit() db.session.commit()
with pytest.raises(tag_categories.TagCategoryIsInUseError): with pytest.raises(tag_categories.TagCategoryIsInUseError):
result = test_ctx.api.delete( result = test_ctx.api.delete(
test_ctx.context_factory( test_ctx.context_factory(
@ -70,11 +70,11 @@ def test_trying_to_delete_non_existing(test_ctx):
user=test_ctx.user_factory(rank='regular_user')), 'bad') user=test_ctx.user_factory(rank='regular_user')), 'bad')
def test_trying_to_delete_without_privileges(test_ctx): def test_trying_to_delete_without_privileges(test_ctx):
db.session().add(test_ctx.tag_category_factory(name='category')) db.session.add(test_ctx.tag_category_factory(name='category'))
db.session().commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.delete( test_ctx.api.delete(
test_ctx.context_factory( test_ctx.context_factory(
user=test_ctx.user_factory(rank='anonymous')), user=test_ctx.user_factory(rank='anonymous')),
'category') 'category')
assert db.session().query(db.TagCategory).count() == 1 assert db.session.query(db.TagCategory).count() == 1

View file

@ -24,7 +24,7 @@ def test_ctx(
return ret return ret
def test_retrieving_multiple(test_ctx): def test_retrieving_multiple(test_ctx):
db.session().add_all([ db.session.add_all([
test_ctx.tag_category_factory(name='c1'), test_ctx.tag_category_factory(name='c1'),
test_ctx.tag_category_factory(name='c2'), test_ctx.tag_category_factory(name='c2'),
]) ])
@ -34,7 +34,7 @@ def test_retrieving_multiple(test_ctx):
assert [cat['name'] for cat in result['tagCategories']] == ['c1', 'c2'] assert [cat['name'] for cat in result['tagCategories']] == ['c1', 'c2']
def test_retrieving_single(test_ctx): def test_retrieving_single(test_ctx):
db.session().add(test_ctx.tag_category_factory(name='cat')) db.session.add(test_ctx.tag_category_factory(name='cat'))
result = test_ctx.detail_api.get( result = test_ctx.detail_api.get(
test_ctx.context_factory( test_ctx.context_factory(
user=test_ctx.user_factory(rank='regular_user')), user=test_ctx.user_factory(rank='regular_user')),

View file

@ -28,8 +28,8 @@ def test_ctx(
def test_simple_updating(test_ctx): def test_simple_updating(test_ctx):
category = test_ctx.tag_category_factory(name='name', color='black') category = test_ctx.tag_category_factory(name='name', color='black')
db.session().add(category) db.session.add(category)
db.session().commit() db.session.commit()
result = test_ctx.api.put( result = test_ctx.api.put(
test_ctx.context_factory( test_ctx.context_factory(
input={ input={
@ -59,8 +59,8 @@ def test_simple_updating(test_ctx):
{'color': ''}, {'color': ''},
]) ])
def test_trying_to_pass_invalid_input(test_ctx, input): def test_trying_to_pass_invalid_input(test_ctx, input):
db.session().add(test_ctx.tag_category_factory(name='meta', color='black')) db.session.add(test_ctx.tag_category_factory(name='meta', color='black'))
db.session().commit() db.session.commit()
with pytest.raises(tag_categories.InvalidTagCategoryNameError): with pytest.raises(tag_categories.InvalidTagCategoryNameError):
test_ctx.api.put( test_ctx.api.put(
test_ctx.context_factory( test_ctx.context_factory(
@ -70,8 +70,8 @@ def test_trying_to_pass_invalid_input(test_ctx, input):
@pytest.mark.parametrize('field', ['name', 'color']) @pytest.mark.parametrize('field', ['name', 'color'])
def test_omitting_optional_field(test_ctx, tmpdir, field): def test_omitting_optional_field(test_ctx, tmpdir, field):
db.session().add(test_ctx.tag_category_factory(name='name', color='black')) db.session.add(test_ctx.tag_category_factory(name='name', color='black'))
db.session().commit() db.session.commit()
input = { input = {
'name': 'changed', 'name': 'changed',
'color': 'white', 'color': 'white',
@ -94,8 +94,8 @@ def test_trying_to_update_non_existing(test_ctx):
@pytest.mark.parametrize('new_name', ['cat', 'CAT']) @pytest.mark.parametrize('new_name', ['cat', 'CAT'])
def test_reusing_own_name(test_ctx, new_name): def test_reusing_own_name(test_ctx, new_name):
db.session().add(test_ctx.tag_category_factory(name='cat', color='black')) db.session.add(test_ctx.tag_category_factory(name='cat', color='black'))
db.session().commit() db.session.commit()
result = test_ctx.api.put( result = test_ctx.api.put(
test_ctx.context_factory( test_ctx.context_factory(
input={'name': new_name}, input={'name': new_name},
@ -107,10 +107,10 @@ def test_reusing_own_name(test_ctx, new_name):
@pytest.mark.parametrize('dup_name', ['cat1', 'CAT1']) @pytest.mark.parametrize('dup_name', ['cat1', 'CAT1'])
def test_trying_to_use_existing_name(test_ctx, dup_name): def test_trying_to_use_existing_name(test_ctx, dup_name):
db.session().add_all([ db.session.add_all([
test_ctx.tag_category_factory(name='cat1', color='black'), test_ctx.tag_category_factory(name='cat1', color='black'),
test_ctx.tag_category_factory(name='cat2', color='black')]) test_ctx.tag_category_factory(name='cat2', color='black')])
db.session().commit() db.session.commit()
with pytest.raises(tag_categories.TagCategoryAlreadyExistsError): with pytest.raises(tag_categories.TagCategoryAlreadyExistsError):
test_ctx.api.put( test_ctx.api.put(
test_ctx.context_factory( test_ctx.context_factory(
@ -123,8 +123,8 @@ def test_trying_to_use_existing_name(test_ctx, dup_name):
{'color': 'whatever'}, {'color': 'whatever'},
]) ])
def test_trying_to_update_without_privileges(test_ctx, input): def test_trying_to_update_without_privileges(test_ctx, input):
db.session().add(test_ctx.tag_category_factory(name='dummy')) db.session.add(test_ctx.tag_category_factory(name='dummy'))
db.session().commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.put( test_ctx.api.put(
test_ctx.context_factory( test_ctx.context_factory(

View file

@ -4,8 +4,9 @@ import pytest
from szurubooru import api, config, db, errors from szurubooru import api, config, db, errors
from szurubooru.util import misc, tags from szurubooru.util import misc, tags
def get_tag(session, name): def get_tag(name):
return session.query(db.Tag) \ return db.session \
.query(db.Tag) \
.join(db.TagName) \ .join(db.TagName) \
.filter(db.TagName.name==name) \ .filter(db.TagName.name==name) \
.first() .first()
@ -16,24 +17,17 @@ def assert_relations(relations, expected_tag_names):
@pytest.fixture @pytest.fixture
def test_ctx( def test_ctx(
tmpdir, tmpdir, config_injector, context_factory, user_factory, tag_factory):
session,
config_injector,
context_factory,
user_factory,
tag_factory):
config_injector({ config_injector({
'data_dir': str(tmpdir), 'data_dir': str(tmpdir),
'tag_name_regex': '^[^!]*$', 'tag_name_regex': '^[^!]*$',
'ranks': ['anonymous', 'regular_user'], 'ranks': ['anonymous', 'regular_user'],
'privileges': {'tags:create': 'regular_user'}, 'privileges': {'tags:create': 'regular_user'},
}) })
session.add_all([ db.session.add_all([
db.TagCategory(name) for name in [ db.TagCategory(name) for name in ['meta', 'character', 'copyright']])
'meta', 'character', 'copyright']]) db.session.flush()
session.flush()
ret = misc.dotdict() ret = misc.dotdict()
ret.session = session
ret.context_factory = context_factory ret.context_factory = context_factory
ret.user_factory = user_factory ret.user_factory = user_factory
ret.tag_factory = tag_factory ret.tag_factory = tag_factory
@ -61,7 +55,7 @@ def test_creating_simple_tags(test_ctx, fake_datetime):
'lastEditTime': None, 'lastEditTime': None,
} }
} }
tag = get_tag(test_ctx.session, 'tag1') tag = get_tag('tag1')
assert [tag_name.name for tag_name in tag.names] == ['tag1', 'tag2'] assert [tag_name.name for tag_name in tag.names] == ['tag1', 'tag2']
assert tag.category.name == 'meta' assert tag.category.name == 'meta'
assert tag.last_edit_time is None assert tag.last_edit_time is None
@ -140,15 +134,15 @@ def test_duplicating_names(test_ctx):
user=test_ctx.user_factory(rank='regular_user'))) user=test_ctx.user_factory(rank='regular_user')))
assert result['tag']['names'] == ['tag1'] assert result['tag']['names'] == ['tag1']
assert result['tag']['category'] == 'meta' assert result['tag']['category'] == 'meta'
tag = get_tag(test_ctx.session, 'tag1') tag = get_tag('tag1')
assert [tag_name.name for tag_name in tag.names] == ['tag1'] assert [tag_name.name for tag_name in tag.names] == ['tag1']
def test_trying_to_use_existing_name(test_ctx): def test_trying_to_use_existing_name(test_ctx):
test_ctx.session.add_all([ db.session.add_all([
test_ctx.tag_factory(names=['used1'], category_name='meta'), test_ctx.tag_factory(names=['used1'], category_name='meta'),
test_ctx.tag_factory(names=['used2'], category_name='meta'), test_ctx.tag_factory(names=['used2'], category_name='meta'),
]) ])
test_ctx.session.commit() db.session.commit()
with pytest.raises(tags.TagAlreadyExistsError): with pytest.raises(tags.TagAlreadyExistsError):
test_ctx.api.post( test_ctx.api.post(
test_ctx.context_factory( test_ctx.context_factory(
@ -169,7 +163,7 @@ def test_trying_to_use_existing_name(test_ctx):
'implications': [], 'implications': [],
}, },
user=test_ctx.user_factory(rank='regular_user'))) user=test_ctx.user_factory(rank='regular_user')))
assert get_tag(test_ctx.session, 'unused') is None assert get_tag('unused') is None
@pytest.mark.parametrize('input,expected_suggestions,expected_implications', [ @pytest.mark.parametrize('input,expected_suggestions,expected_implications', [
# new relations # new relations
@ -208,18 +202,18 @@ def test_creating_new_suggestions_and_implications(
input=input, user=test_ctx.user_factory(rank='regular_user'))) input=input, user=test_ctx.user_factory(rank='regular_user')))
assert result['tag']['suggestions'] == expected_suggestions assert result['tag']['suggestions'] == expected_suggestions
assert result['tag']['implications'] == expected_implications assert result['tag']['implications'] == expected_implications
tag = get_tag(test_ctx.session, 'main') tag = get_tag('main')
assert_relations(tag.suggestions, expected_suggestions) assert_relations(tag.suggestions, expected_suggestions)
assert_relations(tag.implications, expected_implications) assert_relations(tag.implications, expected_implications)
for name in ['main'] + expected_suggestions + expected_implications: for name in ['main'] + expected_suggestions + expected_implications:
assert get_tag(test_ctx.session, name) is not None assert get_tag(name) is not None
def test_reusing_suggestions_and_implications(test_ctx): def test_reusing_suggestions_and_implications(test_ctx):
test_ctx.session.add_all([ db.session.add_all([
test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta'), test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta'),
test_ctx.tag_factory(names=['tag3'], category_name='meta'), test_ctx.tag_factory(names=['tag3'], category_name='meta'),
]) ])
test_ctx.session.commit() db.session.commit()
result = test_ctx.api.post( result = test_ctx.api.post(
test_ctx.context_factory( test_ctx.context_factory(
input={ input={
@ -232,7 +226,7 @@ def test_reusing_suggestions_and_implications(test_ctx):
# NOTE: it should export only the first name # NOTE: it should export only the first name
assert result['tag']['suggestions'] == ['tag1'] assert result['tag']['suggestions'] == ['tag1']
assert result['tag']['implications'] == ['tag1'] assert result['tag']['implications'] == ['tag1']
tag = get_tag(test_ctx.session, 'new') tag = get_tag('new')
assert_relations(tag.suggestions, ['tag1']) assert_relations(tag.suggestions, ['tag1'])
assert_relations(tag.implications, ['tag1']) assert_relations(tag.implications, ['tag1'])
@ -256,7 +250,7 @@ def test_tag_trying_to_relate_to_itself(test_ctx, input):
test_ctx.context_factory( test_ctx.context_factory(
input=input, input=input,
user=test_ctx.user_factory(rank='regular_user'))) user=test_ctx.user_factory(rank='regular_user')))
assert get_tag(test_ctx.session, 'tag') is None assert get_tag('tag') is None
def test_trying_to_create_tag_without_privileges(test_ctx): def test_trying_to_create_tag_without_privileges(test_ctx):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):

View file

@ -6,12 +6,7 @@ from szurubooru.util import misc, tags
@pytest.fixture @pytest.fixture
def test_ctx( def test_ctx(
tmpdir, tmpdir, config_injector, context_factory, tag_factory, user_factory):
session,
config_injector,
context_factory,
tag_factory,
user_factory):
config_injector({ config_injector({
'data_dir': str(tmpdir), 'data_dir': str(tmpdir),
'privileges': { 'privileges': {
@ -20,7 +15,6 @@ def test_ctx(
'ranks': ['anonymous', 'regular_user'], 'ranks': ['anonymous', 'regular_user'],
}) })
ret = misc.dotdict() ret = misc.dotdict()
ret.session = session
ret.context_factory = context_factory ret.context_factory = context_factory
ret.user_factory = user_factory ret.user_factory = user_factory
ret.tag_factory = tag_factory ret.tag_factory = tag_factory
@ -28,28 +22,28 @@ def test_ctx(
return ret return ret
def test_deleting(test_ctx): def test_deleting(test_ctx):
test_ctx.session.add(test_ctx.tag_factory(names=['tag'])) db.session.add(test_ctx.tag_factory(names=['tag']))
test_ctx.session.commit() db.session.commit()
result = test_ctx.api.delete( result = test_ctx.api.delete(
test_ctx.context_factory( test_ctx.context_factory(
user=test_ctx.user_factory(rank='regular_user')), user=test_ctx.user_factory(rank='regular_user')),
'tag') 'tag')
assert result == {} assert result == {}
assert test_ctx.session.query(db.Tag).count() == 0 assert db.session.query(db.Tag).count() == 0
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json')) assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json'))
def test_trying_to_delete_used(test_ctx, post_factory): def test_trying_to_delete_used(test_ctx, post_factory):
tag = test_ctx.tag_factory(names=['tag']) tag = test_ctx.tag_factory(names=['tag'])
post = post_factory() post = post_factory()
post.tags.append(tag) post.tags.append(tag)
test_ctx.session.add_all([tag, post]) db.session.add_all([tag, post])
test_ctx.session.commit() db.session.commit()
with pytest.raises(tags.TagIsInUseError): with pytest.raises(tags.TagIsInUseError):
test_ctx.api.delete( test_ctx.api.delete(
test_ctx.context_factory( test_ctx.context_factory(
user=test_ctx.user_factory(rank='regular_user')), user=test_ctx.user_factory(rank='regular_user')),
'tag') 'tag')
assert test_ctx.session.query(db.Tag).count() == 1 assert db.session.query(db.Tag).count() == 1
def test_trying_to_delete_non_existing(test_ctx): def test_trying_to_delete_non_existing(test_ctx):
with pytest.raises(tags.TagNotFoundError): with pytest.raises(tags.TagNotFoundError):
@ -58,11 +52,11 @@ def test_trying_to_delete_non_existing(test_ctx):
user=test_ctx.user_factory(rank='regular_user')), 'bad') user=test_ctx.user_factory(rank='regular_user')), 'bad')
def test_trying_to_delete_without_privileges(test_ctx): def test_trying_to_delete_without_privileges(test_ctx):
test_ctx.session.add(test_ctx.tag_factory(names=['tag'])) db.session.add(test_ctx.tag_factory(names=['tag']))
test_ctx.session.commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.delete( test_ctx.api.delete(
test_ctx.context_factory( test_ctx.context_factory(
user=test_ctx.user_factory(rank='anonymous')), user=test_ctx.user_factory(rank='anonymous')),
'tag') 'tag')
assert test_ctx.session.query(db.Tag).count() == 1 assert db.session.query(db.Tag).count() == 1

View file

@ -7,7 +7,6 @@ from szurubooru.util import tags
def test_export( def test_export(
tmpdir, tmpdir,
query_counter, query_counter,
session,
config_injector, config_injector,
tag_factory, tag_factory,
tag_category_factory): tag_category_factory):
@ -16,23 +15,23 @@ def test_export(
}) })
cat1 = tag_category_factory(name='cat1', color='black') cat1 = tag_category_factory(name='cat1', color='black')
cat2 = tag_category_factory(name='cat2', color='white') cat2 = tag_category_factory(name='cat2', color='white')
session.add_all([cat1, cat2]) db.session.add_all([cat1, cat2])
session.flush() db.session.flush()
sug1 = tag_factory(names=['sug1'], category=cat1) sug1 = tag_factory(names=['sug1'], category=cat1)
sug2 = tag_factory(names=['sug2'], category=cat1) sug2 = tag_factory(names=['sug2'], category=cat1)
imp1 = tag_factory(names=['imp1'], category=cat1) imp1 = tag_factory(names=['imp1'], category=cat1)
imp2 = tag_factory(names=['imp2'], category=cat1) imp2 = tag_factory(names=['imp2'], category=cat1)
tag = tag_factory(names=['alias1', 'alias2'], category=cat2) tag = tag_factory(names=['alias1', 'alias2'], category=cat2)
tag.post_count = 1 tag.post_count = 1
session.add_all([tag, sug1, sug2, imp1, imp2, cat1, cat2]) db.session.add_all([tag, sug1, sug2, imp1, imp2, cat1, cat2])
session.flush() db.session.flush()
session.add_all([ db.session.add_all([
db.TagSuggestion(tag.tag_id, sug1.tag_id), db.TagSuggestion(tag.tag_id, sug1.tag_id),
db.TagSuggestion(tag.tag_id, sug2.tag_id), db.TagSuggestion(tag.tag_id, sug2.tag_id),
db.TagImplication(tag.tag_id, imp1.tag_id), db.TagImplication(tag.tag_id, imp1.tag_id),
db.TagImplication(tag.tag_id, imp2.tag_id), db.TagImplication(tag.tag_id, imp2.tag_id),
]) ])
session.flush() db.session.flush()
with query_counter: with query_counter:
tags.export_to_json() tags.export_to_json()

View file

@ -4,8 +4,7 @@ from szurubooru import api, db, errors
from szurubooru.util import misc, tags from szurubooru.util import misc, tags
@pytest.fixture @pytest.fixture
def test_ctx( def test_ctx(context_factory, config_injector, user_factory, tag_factory):
session, context_factory, config_injector, user_factory, tag_factory):
config_injector({ config_injector({
'privileges': { 'privileges': {
'tags:list': 'regular_user', 'tags:list': 'regular_user',
@ -16,7 +15,6 @@ def test_ctx(
'rank_names': {'regular_user': 'Peasant'}, 'rank_names': {'regular_user': 'Peasant'},
}) })
ret = misc.dotdict() ret = misc.dotdict()
ret.session = session
ret.context_factory = context_factory ret.context_factory = context_factory
ret.user_factory = user_factory ret.user_factory = user_factory
ret.tag_factory = tag_factory ret.tag_factory = tag_factory
@ -27,7 +25,7 @@ def test_ctx(
def test_retrieving_multiple(test_ctx): def test_retrieving_multiple(test_ctx):
tag1 = test_ctx.tag_factory(names=['t1']) tag1 = test_ctx.tag_factory(names=['t1'])
tag2 = test_ctx.tag_factory(names=['t2']) tag2 = test_ctx.tag_factory(names=['t2'])
test_ctx.session.add_all([tag1, tag2]) db.session.add_all([tag1, tag2])
result = test_ctx.list_api.get( result = test_ctx.list_api.get(
test_ctx.context_factory( test_ctx.context_factory(
input={'query': '', 'page': 1}, input={'query': '', 'page': 1},
@ -46,7 +44,7 @@ def test_trying_to_retrieve_multiple_without_privileges(test_ctx):
user=test_ctx.user_factory(rank='anonymous'))) user=test_ctx.user_factory(rank='anonymous')))
def test_retrieving_single(test_ctx): def test_retrieving_single(test_ctx):
test_ctx.session.add(test_ctx.tag_factory(names=['tag'])) db.session.add(test_ctx.tag_factory(names=['tag']))
result = test_ctx.detail_api.get( result = test_ctx.detail_api.get(
test_ctx.context_factory( test_ctx.context_factory(
input={'query': '', 'page': 1}, input={'query': '', 'page': 1},

View file

@ -4,8 +4,9 @@ import pytest
from szurubooru import api, config, db, errors from szurubooru import api, config, db, errors
from szurubooru.util import misc, tags from szurubooru.util import misc, tags
def get_tag(session, name): def get_tag(name):
return session.query(db.Tag) \ return db.session \
.query(db.Tag) \
.join(db.TagName) \ .join(db.TagName) \
.filter(db.TagName.name==name) \ .filter(db.TagName.name==name) \
.first() .first()
@ -16,12 +17,7 @@ def assert_relations(relations, expected_tag_names):
@pytest.fixture @pytest.fixture
def test_ctx( def test_ctx(
tmpdir, tmpdir, config_injector, context_factory, user_factory, tag_factory):
session,
config_injector,
context_factory,
user_factory,
tag_factory):
config_injector({ config_injector({
'data_dir': str(tmpdir), 'data_dir': str(tmpdir),
'tag_name_regex': '^[^!]*$', 'tag_name_regex': '^[^!]*$',
@ -33,12 +29,10 @@ def test_ctx(
'tags:edit:implications': 'regular_user', 'tags:edit:implications': 'regular_user',
}, },
}) })
session.add_all([ db.session.add_all([
db.TagCategory(name) for name in [ db.TagCategory(name) for name in ['meta', 'character', 'copyright']])
'meta', 'character', 'copyright']]) db.session.flush()
session.flush()
ret = misc.dotdict() ret = misc.dotdict()
ret.session = session
ret.context_factory = context_factory ret.context_factory = context_factory
ret.user_factory = user_factory ret.user_factory = user_factory
ret.tag_factory = tag_factory ret.tag_factory = tag_factory
@ -47,8 +41,8 @@ def test_ctx(
def test_simple_updating(test_ctx, fake_datetime): def test_simple_updating(test_ctx, fake_datetime):
tag = test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta') tag = test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta')
test_ctx.session.add(tag) db.session.add(tag)
test_ctx.session.commit() db.session.commit()
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = test_ctx.api.put( result = test_ctx.api.put(
test_ctx.context_factory( test_ctx.context_factory(
@ -68,9 +62,9 @@ def test_simple_updating(test_ctx, fake_datetime):
'lastEditTime': datetime.datetime(1997, 12, 1), 'lastEditTime': datetime.datetime(1997, 12, 1),
} }
} }
assert get_tag(test_ctx.session, 'tag1') is None assert get_tag('tag1') is None
assert get_tag(test_ctx.session, 'tag2') is None assert get_tag('tag2') is None
tag = get_tag(test_ctx.session, 'tag3') tag = get_tag('tag3')
assert tag is not None assert tag is not None
assert [tag_name.name for tag_name in tag.names] == ['tag3'] assert [tag_name.name for tag_name in tag.names] == ['tag3']
assert tag.category.name == 'character' assert tag.category.name == 'character'
@ -92,9 +86,8 @@ def test_simple_updating(test_ctx, fake_datetime):
({'implications': ['good', '!bad']}, tags.InvalidTagNameError), ({'implications': ['good', '!bad']}, tags.InvalidTagNameError),
]) ])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception): def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
test_ctx.session.add( db.session.add(test_ctx.tag_factory(names=['tag1'], category_name='meta'))
test_ctx.tag_factory(names=['tag1'], category_name='meta')) db.session.commit()
test_ctx.session.commit()
with pytest.raises(expected_exception): with pytest.raises(expected_exception):
test_ctx.api.put( test_ctx.api.put(
test_ctx.context_factory( test_ctx.context_factory(
@ -105,9 +98,8 @@ def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
@pytest.mark.parametrize( @pytest.mark.parametrize(
'field', ['names', 'category', 'implications', 'suggestions']) 'field', ['names', 'category', 'implications', 'suggestions'])
def test_omitting_optional_field(test_ctx, tmpdir, field): def test_omitting_optional_field(test_ctx, tmpdir, field):
test_ctx.session.add( db.session.add(test_ctx.tag_factory(names=['tag'], category_name='meta'))
test_ctx.tag_factory(names=['tag'], category_name='meta')) db.session.commit()
test_ctx.session.commit()
input = { input = {
'names': ['tag1', 'tag2'], 'names': ['tag1', 'tag2'],
'category': 'meta', 'category': 'meta',
@ -132,23 +124,23 @@ def test_trying_to_update_non_existing(test_ctx):
@pytest.mark.parametrize('dup_name', ['tag1', 'TAG1']) @pytest.mark.parametrize('dup_name', ['tag1', 'TAG1'])
def test_reusing_own_name(test_ctx, dup_name): def test_reusing_own_name(test_ctx, dup_name):
test_ctx.session.add( db.session.add(
test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta')) test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta'))
test_ctx.session.commit() db.session.commit()
result = test_ctx.api.put( result = test_ctx.api.put(
test_ctx.context_factory( test_ctx.context_factory(
input={'names': [dup_name, 'tag3']}, input={'names': [dup_name, 'tag3']},
user=test_ctx.user_factory(rank='regular_user')), user=test_ctx.user_factory(rank='regular_user')),
'tag1') 'tag1')
assert result['tag']['names'] == ['tag1', 'tag3'] assert result['tag']['names'] == ['tag1', 'tag3']
assert get_tag(test_ctx.session, 'tag2') is None assert get_tag('tag2') is None
tag1 = get_tag(test_ctx.session, 'tag1') tag1 = get_tag('tag1')
tag2 = get_tag(test_ctx.session, 'tag3') tag2 = get_tag('tag3')
assert tag1.tag_id == tag2.tag_id assert tag1.tag_id == tag2.tag_id
assert [name.name for name in tag1.names] == ['tag1', 'tag3'] assert [name.name for name in tag1.names] == ['tag1', 'tag3']
def test_duplicating_names(test_ctx): def test_duplicating_names(test_ctx):
test_ctx.session.add( db.session.add(
test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta')) test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta'))
result = test_ctx.api.put( result = test_ctx.api.put(
test_ctx.context_factory( test_ctx.context_factory(
@ -156,18 +148,18 @@ def test_duplicating_names(test_ctx):
user=test_ctx.user_factory(rank='regular_user')), user=test_ctx.user_factory(rank='regular_user')),
'tag1') 'tag1')
assert result['tag']['names'] == ['tag3'] assert result['tag']['names'] == ['tag3']
assert get_tag(test_ctx.session, 'tag1') is None assert get_tag('tag1') is None
assert get_tag(test_ctx.session, 'tag2') is None assert get_tag('tag2') is None
tag = get_tag(test_ctx.session, 'tag3') tag = get_tag('tag3')
assert tag is not None assert tag is not None
assert [tag_name.name for tag_name in tag.names] == ['tag3'] assert [tag_name.name for tag_name in tag.names] == ['tag3']
@pytest.mark.parametrize('dup_name', ['tag1', 'TAG1', 'tag2', 'TAG2']) @pytest.mark.parametrize('dup_name', ['tag1', 'TAG1', 'tag2', 'TAG2'])
def test_trying_to_use_existing_name(test_ctx, dup_name): def test_trying_to_use_existing_name(test_ctx, dup_name):
test_ctx.session.add_all([ db.session.add_all([
test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta'), test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta'),
test_ctx.tag_factory(names=['tag3', 'tag4'], category_name='meta')]) test_ctx.tag_factory(names=['tag3', 'tag4'], category_name='meta')])
test_ctx.session.commit() db.session.commit()
with pytest.raises(tags.TagAlreadyExistsError): with pytest.raises(tags.TagAlreadyExistsError):
test_ctx.api.put( test_ctx.api.put(
test_ctx.context_factory( test_ctx.context_factory(
@ -199,28 +191,28 @@ def test_trying_to_use_existing_name(test_ctx, dup_name):
]) ])
def test_updating_new_suggestions_and_implications( def test_updating_new_suggestions_and_implications(
test_ctx, input, expected_suggestions, expected_implications): test_ctx, input, expected_suggestions, expected_implications):
test_ctx.session.add( db.session.add(
test_ctx.tag_factory(names=['main'], category_name='meta')) test_ctx.tag_factory(names=['main'], category_name='meta'))
test_ctx.session.commit() db.session.commit()
result = test_ctx.api.put( result = test_ctx.api.put(
test_ctx.context_factory( test_ctx.context_factory(
input=input, user=test_ctx.user_factory(rank='regular_user')), input=input, user=test_ctx.user_factory(rank='regular_user')),
'main') 'main')
assert result['tag']['suggestions'] == expected_suggestions assert result['tag']['suggestions'] == expected_suggestions
assert result['tag']['implications'] == expected_implications assert result['tag']['implications'] == expected_implications
tag = get_tag(test_ctx.session, 'main') tag = get_tag('main')
assert_relations(tag.suggestions, expected_suggestions) assert_relations(tag.suggestions, expected_suggestions)
assert_relations(tag.implications, expected_implications) assert_relations(tag.implications, expected_implications)
for name in ['main'] + expected_suggestions + expected_implications: for name in ['main'] + expected_suggestions + expected_implications:
assert get_tag(test_ctx.session, name) is not None assert get_tag(name) is not None
def test_reusing_suggestions_and_implications(test_ctx): def test_reusing_suggestions_and_implications(test_ctx):
test_ctx.session.add_all([ db.session.add_all([
test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta'), test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta'),
test_ctx.tag_factory(names=['tag3'], category_name='meta'), test_ctx.tag_factory(names=['tag3'], category_name='meta'),
test_ctx.tag_factory(names=['tag4'], category_name='meta'), test_ctx.tag_factory(names=['tag4'], category_name='meta'),
]) ])
test_ctx.session.commit() db.session.commit()
result = test_ctx.api.put( result = test_ctx.api.put(
test_ctx.context_factory( test_ctx.context_factory(
input={ input={
@ -234,7 +226,7 @@ def test_reusing_suggestions_and_implications(test_ctx):
# NOTE: it should export only the first name # NOTE: it should export only the first name
assert result['tag']['suggestions'] == ['tag1'] assert result['tag']['suggestions'] == ['tag1']
assert result['tag']['implications'] == ['tag1'] assert result['tag']['implications'] == ['tag1']
tag = get_tag(test_ctx.session, 'new') tag = get_tag('new')
assert_relations(tag.suggestions, ['tag1']) assert_relations(tag.suggestions, ['tag1'])
assert_relations(tag.implications, ['tag1']) assert_relations(tag.implications, ['tag1'])
@ -253,9 +245,8 @@ def test_reusing_suggestions_and_implications(test_ctx):
} }
]) ])
def test_trying_to_relate_tag_to_itself(test_ctx, input): def test_trying_to_relate_tag_to_itself(test_ctx, input):
test_ctx.session.add( db.session.add(test_ctx.tag_factory(names=['tag1'], category_name='meta'))
test_ctx.tag_factory(names=['tag1'], category_name='meta')) db.session.commit()
test_ctx.session.commit()
with pytest.raises(tags.InvalidTagRelationError): with pytest.raises(tags.InvalidTagRelationError):
test_ctx.api.put( test_ctx.api.put(
test_ctx.context_factory( test_ctx.context_factory(
@ -269,9 +260,8 @@ def test_trying_to_relate_tag_to_itself(test_ctx, input):
{'implications': ['whatever']}, {'implications': ['whatever']},
]) ])
def test_trying_to_update_without_privileges(test_ctx, input): def test_trying_to_update_without_privileges(test_ctx, input):
test_ctx.session.add( db.session.add(test_ctx.tag_factory(names=['tag'], category_name='meta'))
test_ctx.tag_factory(names=['tag'], category_name='meta')) db.session.commit()
test_ctx.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.put( test_ctx.api.put(
test_ctx.context_factory( test_ctx.context_factory(

View file

@ -8,12 +8,11 @@ EMPTY_PIXEL = \
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \ b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
def get_user(session, name): def get_user(name):
return session.query(db.User).filter_by(name=name).first() return db.session.query(db.User).filter_by(name=name).first()
@pytest.fixture @pytest.fixture
def test_ctx( def test_ctx(config_injector, context_factory, user_factory):
session, config_injector, context_factory, user_factory):
config_injector({ config_injector({
'secret': '', 'secret': '',
'user_name_regex': '[^!]{3,}', 'user_name_regex': '[^!]{3,}',
@ -25,7 +24,6 @@ def test_ctx(
'privileges': {'users:create': 'anonymous'}, 'privileges': {'users:create': 'anonymous'},
}) })
ret = misc.dotdict() ret = misc.dotdict()
ret.session = session
ret.context_factory = context_factory ret.context_factory = context_factory
ret.user_factory = user_factory ret.user_factory = user_factory
ret.api = api.UserListApi() ret.api = api.UserListApi()
@ -53,7 +51,7 @@ def test_creating_user(test_ctx, fake_datetime):
'rankName': 'Unknown', 'rankName': 'Unknown',
} }
} }
user = get_user(test_ctx.session, 'chewie1') user = get_user('chewie1')
assert user.name == 'chewie1' assert user.name == 'chewie1'
assert user.email == 'asd@asd.asd' assert user.email == 'asd@asd.asd'
assert user.rank == 'admin' assert user.rank == 'admin'
@ -79,8 +77,8 @@ def test_first_user_becomes_admin_others_not(test_ctx):
user=test_ctx.user_factory(rank='anonymous'))) user=test_ctx.user_factory(rank='anonymous')))
assert result1['user']['rank'] == 'admin' assert result1['user']['rank'] == 'admin'
assert result2['user']['rank'] == 'regular_user' assert result2['user']['rank'] == 'regular_user'
first_user = get_user(test_ctx.session, 'chewie1') first_user = get_user('chewie1')
other_user = get_user(test_ctx.session, 'chewie2') other_user = get_user('chewie2')
assert first_user.rank == 'admin' assert first_user.rank == 'admin'
assert other_user.rank == 'regular_user' assert other_user.rank == 'regular_user'
@ -192,7 +190,7 @@ def test_omitting_optional_field(test_ctx, tmpdir, field):
def test_mods_trying_to_become_admin(test_ctx): def test_mods_trying_to_become_admin(test_ctx):
user1 = test_ctx.user_factory(name='u1', rank='mod') user1 = test_ctx.user_factory(name='u1', rank='mod')
user2 = test_ctx.user_factory(name='u2', rank='mod') user2 = test_ctx.user_factory(name='u2', rank='mod')
test_ctx.session.add_all([user1, user2]) db.session.add_all([user1, user2])
context = test_ctx.context_factory(input={ context = test_ctx.context_factory(input={
'name': 'chewie', 'name': 'chewie',
'email': 'asd@asd.asd', 'email': 'asd@asd.asd',
@ -204,7 +202,7 @@ def test_mods_trying_to_become_admin(test_ctx):
def test_admin_creating_mod_account(test_ctx): def test_admin_creating_mod_account(test_ctx):
user = test_ctx.user_factory(rank='admin') user = test_ctx.user_factory(rank='admin')
test_ctx.session.add(user) db.session.add(user)
context = test_ctx.context_factory(input={ context = test_ctx.context_factory(input={
'name': 'chewie', 'name': 'chewie',
'email': 'asd@asd.asd', 'email': 'asd@asd.asd',
@ -227,7 +225,7 @@ def test_uploading_avatar(test_ctx, tmpdir):
}, },
files={'avatar': EMPTY_PIXEL}, files={'avatar': EMPTY_PIXEL},
user=test_ctx.user_factory(rank='mod'))) user=test_ctx.user_factory(rank='mod')))
user = get_user(test_ctx.session, 'chewie') user = get_user('chewie')
assert user.avatar_style == user.AVATAR_MANUAL assert user.avatar_style == user.AVATAR_MANUAL
assert response['user']['avatarUrl'] == \ assert response['user']['avatarUrl'] == \
'http://example.com/data/avatars/chewie.jpg' 'http://example.com/data/avatars/chewie.jpg'

View file

@ -4,7 +4,7 @@ from szurubooru import api, db, errors
from szurubooru.util import misc, users from szurubooru.util import misc, users
@pytest.fixture @pytest.fixture
def test_ctx(session, config_injector, context_factory, user_factory): def test_ctx(config_injector, context_factory, user_factory):
config_injector({ config_injector({
'privileges': { 'privileges': {
'users:delete:self': 'regular_user', 'users:delete:self': 'regular_user',
@ -13,7 +13,6 @@ def test_ctx(session, config_injector, context_factory, user_factory):
'ranks': ['anonymous', 'regular_user', 'mod', 'admin'], 'ranks': ['anonymous', 'regular_user', 'mod', 'admin'],
}) })
ret = misc.dotdict() ret = misc.dotdict()
ret.session = session
ret.context_factory = context_factory ret.context_factory = context_factory
ret.user_factory = user_factory ret.user_factory = user_factory
ret.api = api.UserDetailApi() ret.api = api.UserDetailApi()
@ -21,28 +20,28 @@ def test_ctx(session, config_injector, context_factory, user_factory):
def test_deleting_oneself(test_ctx): def test_deleting_oneself(test_ctx):
user = test_ctx.user_factory(name='u', rank='regular_user') user = test_ctx.user_factory(name='u', rank='regular_user')
test_ctx.session.add(user) db.session.add(user)
test_ctx.session.commit() db.session.commit()
result = test_ctx.api.delete(test_ctx.context_factory(user=user), 'u') result = test_ctx.api.delete(test_ctx.context_factory(user=user), 'u')
assert result == {} assert result == {}
assert test_ctx.session.query(db.User).count() == 0 assert db.session.query(db.User).count() == 0
def test_deleting_someone_else(test_ctx): def test_deleting_someone_else(test_ctx):
user1 = test_ctx.user_factory(name='u1', rank='regular_user') user1 = test_ctx.user_factory(name='u1', rank='regular_user')
user2 = test_ctx.user_factory(name='u2', rank='mod') user2 = test_ctx.user_factory(name='u2', rank='mod')
test_ctx.session.add_all([user1, user2]) db.session.add_all([user1, user2])
test_ctx.session.commit() db.session.commit()
test_ctx.api.delete(test_ctx.context_factory(user=user2), 'u1') test_ctx.api.delete(test_ctx.context_factory(user=user2), 'u1')
assert test_ctx.session.query(db.User).count() == 1 assert db.session.query(db.User).count() == 1
def test_trying_to_delete_someone_else_without_privileges(test_ctx): def test_trying_to_delete_someone_else_without_privileges(test_ctx):
user1 = test_ctx.user_factory(name='u1', rank='regular_user') user1 = test_ctx.user_factory(name='u1', rank='regular_user')
user2 = test_ctx.user_factory(name='u2', rank='regular_user') user2 = test_ctx.user_factory(name='u2', rank='regular_user')
test_ctx.session.add_all([user1, user2]) db.session.add_all([user1, user2])
test_ctx.session.commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.delete(test_ctx.context_factory(user=user2), 'u1') test_ctx.api.delete(test_ctx.context_factory(user=user2), 'u1')
assert test_ctx.session.query(db.User).count() == 2 assert db.session.query(db.User).count() == 2
def test_trying_to_delete_non_existing(test_ctx): def test_trying_to_delete_non_existing(test_ctx):
with pytest.raises(users.UserNotFoundError): with pytest.raises(users.UserNotFoundError):

View file

@ -4,7 +4,7 @@ from szurubooru import api, db, errors
from szurubooru.util import misc, users from szurubooru.util import misc, users
@pytest.fixture @pytest.fixture
def test_ctx(session, context_factory, config_injector, user_factory): def test_ctx(context_factory, config_injector, user_factory):
config_injector({ config_injector({
'privileges': { 'privileges': {
'users:list': 'regular_user', 'users:list': 'regular_user',
@ -15,7 +15,6 @@ def test_ctx(session, context_factory, config_injector, user_factory):
'rank_names': {'regular_user': 'Peasant'}, 'rank_names': {'regular_user': 'Peasant'},
}) })
ret = misc.dotdict() ret = misc.dotdict()
ret.session = session
ret.context_factory = context_factory ret.context_factory = context_factory
ret.user_factory = user_factory ret.user_factory = user_factory
ret.list_api = api.UserListApi() ret.list_api = api.UserListApi()
@ -25,7 +24,7 @@ def test_ctx(session, context_factory, config_injector, user_factory):
def test_retrieving_multiple(test_ctx): def test_retrieving_multiple(test_ctx):
user1 = test_ctx.user_factory(name='u1', rank='mod') user1 = test_ctx.user_factory(name='u1', rank='mod')
user2 = test_ctx.user_factory(name='u2', rank='mod') user2 = test_ctx.user_factory(name='u2', rank='mod')
test_ctx.session.add_all([user1, user2]) db.session.add_all([user1, user2])
result = test_ctx.list_api.get( result = test_ctx.list_api.get(
test_ctx.context_factory( test_ctx.context_factory(
input={'query': '', 'page': 1}, input={'query': '', 'page': 1},
@ -44,7 +43,7 @@ def test_trying_to_retrieve_multiple_without_privileges(test_ctx):
user=test_ctx.user_factory(rank='anonymous'))) user=test_ctx.user_factory(rank='anonymous')))
def test_retrieving_single(test_ctx): def test_retrieving_single(test_ctx):
test_ctx.session.add(test_ctx.user_factory(name='u1', rank='regular_user')) db.session.add(test_ctx.user_factory(name='u1', rank='regular_user'))
result = test_ctx.detail_api.get( result = test_ctx.detail_api.get(
test_ctx.context_factory( test_ctx.context_factory(
input={'query': '', 'page': 1}, input={'query': '', 'page': 1},

View file

@ -8,12 +8,11 @@ EMPTY_PIXEL = \
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \ b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
def get_user(session, name): def get_user(name):
return session.query(db.User).filter_by(name=name).first() return db.session.query(db.User).filter_by(name=name).first()
@pytest.fixture @pytest.fixture
def test_ctx( def test_ctx(config_injector, context_factory, user_factory):
session, config_injector, context_factory, user_factory):
config_injector({ config_injector({
'secret': '', 'secret': '',
'user_name_regex': '^[^!]{3,}$', 'user_name_regex': '^[^!]{3,}$',
@ -35,7 +34,6 @@ def test_ctx(
}, },
}) })
ret = misc.dotdict() ret = misc.dotdict()
ret.session = session
ret.context_factory = context_factory ret.context_factory = context_factory
ret.user_factory = user_factory ret.user_factory = user_factory
ret.api = api.UserDetailApi() ret.api = api.UserDetailApi()
@ -43,7 +41,7 @@ def test_ctx(
def test_updating_user(test_ctx): def test_updating_user(test_ctx):
user = test_ctx.user_factory(name='u1', rank='admin') user = test_ctx.user_factory(name='u1', rank='admin')
test_ctx.session.add(user) db.session.add(user)
result = test_ctx.api.put( result = test_ctx.api.put(
test_ctx.context_factory( test_ctx.context_factory(
input={ input={
@ -68,7 +66,7 @@ def test_updating_user(test_ctx):
'rankName': 'Unknown', 'rankName': 'Unknown',
} }
} }
user = get_user(test_ctx.session, 'chewie') user = get_user('chewie')
assert user.name == 'chewie' assert user.name == 'chewie'
assert user.email == 'asd@asd.asd' assert user.email == 'asd@asd.asd'
assert user.rank == 'mod' assert user.rank == 'mod'
@ -96,7 +94,7 @@ def test_updating_user(test_ctx):
]) ])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception): def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
user = test_ctx.user_factory(name='u1', rank='admin') user = test_ctx.user_factory(name='u1', rank='admin')
test_ctx.session.add(user) db.session.add(user)
with pytest.raises(expected_exception): with pytest.raises(expected_exception):
test_ctx.api.put( test_ctx.api.put(
test_ctx.context_factory(input=input, user=user), 'u1') test_ctx.context_factory(input=input, user=user), 'u1')
@ -107,7 +105,7 @@ def test_omitting_optional_field(test_ctx, tmpdir, field):
config.config['data_dir'] = str(tmpdir.mkdir('data')) config.config['data_dir'] = str(tmpdir.mkdir('data'))
config.config['data_url'] = 'http://example.com/data/' config.config['data_url'] = 'http://example.com/data/'
user = test_ctx.user_factory(name='u1', rank='admin') user = test_ctx.user_factory(name='u1', rank='admin')
test_ctx.session.add(user) db.session.add(user)
input = { input = {
'name': 'chewie', 'name': 'chewie',
'email': 'asd@asd.asd', 'email': 'asd@asd.asd',
@ -126,16 +124,16 @@ def test_omitting_optional_field(test_ctx, tmpdir, field):
def test_trying_to_update_non_existing(test_ctx): def test_trying_to_update_non_existing(test_ctx):
user = test_ctx.user_factory(name='u1', rank='admin') user = test_ctx.user_factory(name='u1', rank='admin')
test_ctx.session.add(user) db.session.add(user)
with pytest.raises(users.UserNotFoundError): with pytest.raises(users.UserNotFoundError):
test_ctx.api.put(test_ctx.context_factory(user=user), 'u2') test_ctx.api.put(test_ctx.context_factory(user=user), 'u2')
def test_removing_email(test_ctx): def test_removing_email(test_ctx):
user = test_ctx.user_factory(name='u1', rank='admin') user = test_ctx.user_factory(name='u1', rank='admin')
test_ctx.session.add(user) db.session.add(user)
test_ctx.api.put( test_ctx.api.put(
test_ctx.context_factory(input={'email': ''}, user=user), 'u1') test_ctx.context_factory(input={'email': ''}, user=user), 'u1')
assert get_user(test_ctx.session, 'u1').email is None assert get_user('u1').email is None
@pytest.mark.parametrize('input', [ @pytest.mark.parametrize('input', [
{'name': 'whatever'}, {'name': 'whatever'},
@ -147,7 +145,7 @@ def test_removing_email(test_ctx):
def test_trying_to_update_someone_else(test_ctx, input): def test_trying_to_update_someone_else(test_ctx, input):
user1 = test_ctx.user_factory(name='u1', rank='regular_user') user1 = test_ctx.user_factory(name='u1', rank='regular_user')
user2 = test_ctx.user_factory(name='u2', rank='regular_user') user2 = test_ctx.user_factory(name='u2', rank='regular_user')
test_ctx.session.add_all([user1, user2]) db.session.add_all([user1, user2])
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.put( test_ctx.api.put(
test_ctx.context_factory(input=input, user=user1), user2.name) test_ctx.context_factory(input=input, user=user1), user2.name)
@ -155,7 +153,7 @@ def test_trying_to_update_someone_else(test_ctx, input):
def test_trying_to_become_someone_else(test_ctx): def test_trying_to_become_someone_else(test_ctx):
user1 = test_ctx.user_factory(name='me', rank='regular_user') user1 = test_ctx.user_factory(name='me', rank='regular_user')
user2 = test_ctx.user_factory(name='her', rank='regular_user') user2 = test_ctx.user_factory(name='her', rank='regular_user')
test_ctx.session.add_all([user1, user2]) db.session.add_all([user1, user2])
with pytest.raises(users.UserAlreadyExistsError): with pytest.raises(users.UserAlreadyExistsError):
test_ctx.api.put( test_ctx.api.put(
test_ctx.context_factory(input={'name': 'her'}, user=user1), test_ctx.context_factory(input={'name': 'her'}, user=user1),
@ -167,7 +165,7 @@ def test_trying_to_become_someone_else(test_ctx):
def test_mods_trying_to_become_admin(test_ctx): def test_mods_trying_to_become_admin(test_ctx):
user1 = test_ctx.user_factory(name='u1', rank='mod') user1 = test_ctx.user_factory(name='u1', rank='mod')
user2 = test_ctx.user_factory(name='u2', rank='mod') user2 = test_ctx.user_factory(name='u2', rank='mod')
test_ctx.session.add_all([user1, user2]) db.session.add_all([user1, user2])
context = test_ctx.context_factory(input={'rank': 'admin'}, user=user1) context = test_ctx.context_factory(input={'rank': 'admin'}, user=user1)
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.put(context, user1.name) test_ctx.api.put(context, user1.name)
@ -178,14 +176,14 @@ def test_uploading_avatar(test_ctx, tmpdir):
config.config['data_dir'] = str(tmpdir.mkdir('data')) config.config['data_dir'] = str(tmpdir.mkdir('data'))
config.config['data_url'] = 'http://example.com/data/' config.config['data_url'] = 'http://example.com/data/'
user = test_ctx.user_factory(name='u1', rank='mod') user = test_ctx.user_factory(name='u1', rank='mod')
test_ctx.session.add(user) db.session.add(user)
response = test_ctx.api.put( response = test_ctx.api.put(
test_ctx.context_factory( test_ctx.context_factory(
input={'avatarStyle': 'manual'}, input={'avatarStyle': 'manual'},
files={'avatar': EMPTY_PIXEL}, files={'avatar': EMPTY_PIXEL},
user=user), user=user),
'u1') 'u1')
user = get_user(test_ctx.session, 'u1') user = get_user('u1')
assert user.avatar_style == user.AVATAR_MANUAL assert user.avatar_style == user.AVATAR_MANUAL
assert response['user']['avatarUrl'] == \ assert response['user']['avatarUrl'] == \
'http://example.com/data/avatars/u1.jpg' 'http://example.com/data/avatars/u1.jpg'

View file

@ -46,8 +46,8 @@ def fake_datetime():
freezer.stop() freezer.stop()
return injector return injector
@pytest.yield_fixture @pytest.yield_fixture(autouse=True)
def session(query_counter, autoload=True): def session(query_counter):
import logging import logging
logging.basicConfig() logging.basicConfig()
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)

View file

@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
from szurubooru import db from szurubooru import db
def test_saving_post(session, post_factory, user_factory, tag_factory): def test_saving_post(post_factory, user_factory, tag_factory):
user = user_factory() user = user_factory()
tag1 = tag_factory() tag1 = tag_factory()
tag2 = tag_factory() tag2 = tag_factory()
@ -13,17 +13,17 @@ def test_saving_post(session, post_factory, user_factory, tag_factory):
post.checksum = 'deadbeef' post.checksum = 'deadbeef'
post.creation_time = datetime(1997, 1, 1) post.creation_time = datetime(1997, 1, 1)
post.last_edit_time = datetime(1998, 1, 1) post.last_edit_time = datetime(1998, 1, 1)
session.add_all([user, tag1, tag2, related_post1, related_post2, post]) db.session.add_all([user, tag1, tag2, related_post1, related_post2, post])
post.user = user post.user = user
post.tags.append(tag1) post.tags.append(tag1)
post.tags.append(tag2) post.tags.append(tag2)
post.relations.append(related_post1) post.relations.append(related_post1)
post.relations.append(related_post2) post.relations.append(related_post2)
session.commit() db.session.commit()
post = session.query(db.Post).filter(db.Post.post_id == post.post_id).one() db.session.refresh(post)
assert not session.dirty assert not db.session.dirty
assert post.user.user_id is not None assert post.user.user_id is not None
assert post.safety == 'safety' assert post.safety == 'safety'
assert post.type == 'type' assert post.type == 'type'
@ -32,60 +32,60 @@ def test_saving_post(session, post_factory, user_factory, tag_factory):
assert post.last_edit_time == datetime(1998, 1, 1) assert post.last_edit_time == datetime(1998, 1, 1)
assert len(post.relations) == 2 assert len(post.relations) == 2
def test_cascade_deletions(session, post_factory, user_factory, tag_factory): def test_cascade_deletions(post_factory, user_factory, tag_factory):
user = user_factory() user = user_factory()
tag1 = tag_factory() tag1 = tag_factory()
tag2 = tag_factory() tag2 = tag_factory()
related_post1 = post_factory() related_post1 = post_factory()
related_post2 = post_factory() related_post2 = post_factory()
post = post_factory() post = post_factory()
session.add_all([user, tag1, tag2, post, related_post1, related_post2]) db.session.add_all([user, tag1, tag2, post, related_post1, related_post2])
session.flush() db.session.flush()
post.user = user post.user = user
post.tags.append(tag1) post.tags.append(tag1)
post.tags.append(tag2) post.tags.append(tag2)
post.relations.append(related_post1) post.relations.append(related_post1)
post.relations.append(related_post2) post.relations.append(related_post2)
session.flush() db.session.flush()
assert not session.dirty assert not db.session.dirty
assert post.user.user_id is not None assert post.user.user_id is not None
assert len(post.relations) == 2 assert len(post.relations) == 2
assert session.query(db.User).count() == 1 assert db.session.query(db.User).count() == 1
assert session.query(db.Tag).count() == 2 assert db.session.query(db.Tag).count() == 2
assert session.query(db.Post).count() == 3 assert db.session.query(db.Post).count() == 3
assert session.query(db.PostTag).count() == 2 assert db.session.query(db.PostTag).count() == 2
assert session.query(db.PostRelation).count() == 2 assert db.session.query(db.PostRelation).count() == 2
session.delete(post) db.session.delete(post)
session.commit() db.session.commit()
assert not session.dirty assert not db.session.dirty
assert session.query(db.User).count() == 1 assert db.session.query(db.User).count() == 1
assert session.query(db.Tag).count() == 2 assert db.session.query(db.Tag).count() == 2
assert session.query(db.Post).count() == 2 assert db.session.query(db.Post).count() == 2
assert session.query(db.PostTag).count() == 0 assert db.session.query(db.PostTag).count() == 0
assert session.query(db.PostRelation).count() == 0 assert db.session.query(db.PostRelation).count() == 0
def test_tracking_tag_count(session, post_factory, tag_factory): def test_tracking_tag_count(post_factory, tag_factory):
post = post_factory() post = post_factory()
tag1 = tag_factory() tag1 = tag_factory()
tag2 = tag_factory() tag2 = tag_factory()
session.add_all([tag1, tag2, post]) db.session.add_all([tag1, tag2, post])
session.flush() db.session.flush()
post.tags.append(tag1) post.tags.append(tag1)
post.tags.append(tag2) post.tags.append(tag2)
session.commit() db.session.commit()
assert len(post.tags) == 2 assert len(post.tags) == 2
assert post.tag_count == 2 assert post.tag_count == 2
session.delete(tag1) db.session.delete(tag1)
session.commit() db.session.commit()
session.refresh(post) db.session.refresh(post)
assert len(post.tags) == 1 assert len(post.tags) == 1
assert post.tag_count == 1 assert post.tag_count == 1
session.delete(tag2) db.session.delete(tag2)
session.commit() db.session.commit()
session.refresh(post) db.session.refresh(post)
assert len(post.tags) == 0 assert len(post.tags) == 0
assert post.tag_count == 0 assert post.tag_count == 0

View file

@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
from szurubooru import db from szurubooru import db
def test_saving_tag(session, tag_factory): def test_saving_tag(tag_factory):
sug1 = tag_factory(names=['sug1']) sug1 = tag_factory(names=['sug1'])
sug2 = tag_factory(names=['sug2']) sug2 = tag_factory(names=['sug2'])
imp1 = tag_factory(names=['imp1']) imp1 = tag_factory(names=['imp1'])
@ -13,8 +13,8 @@ def test_saving_tag(session, tag_factory):
tag.category = db.TagCategory('category') tag.category = db.TagCategory('category')
tag.creation_time = datetime(1997, 1, 1) tag.creation_time = datetime(1997, 1, 1)
tag.last_edit_time = datetime(1998, 1, 1) tag.last_edit_time = datetime(1998, 1, 1)
session.add_all([tag, sug1, sug2, imp1, imp2]) db.session.add_all([tag, sug1, sug2, imp1, imp2])
session.commit() db.session.commit()
assert tag.tag_id is not None assert tag.tag_id is not None
assert sug1.tag_id is not None assert sug1.tag_id is not None
@ -25,9 +25,10 @@ def test_saving_tag(session, tag_factory):
tag.suggestions.append(sug2) tag.suggestions.append(sug2)
tag.implications.append(imp1) tag.implications.append(imp1)
tag.implications.append(imp2) tag.implications.append(imp2)
session.commit() db.session.commit()
tag = session.query(db.Tag) \ tag = db.session \
.query(db.Tag) \
.join(db.TagName) \ .join(db.TagName) \
.filter(db.TagName.name=='alias1') \ .filter(db.TagName.name=='alias1') \
.one() .one()
@ -40,7 +41,7 @@ def test_saving_tag(session, tag_factory):
assert [relation.names[0].name for relation in tag.implications] \ assert [relation.names[0].name for relation in tag.implications] \
== ['imp1', 'imp2'] == ['imp1', 'imp2']
def test_cascade_deletions(session, tag_factory): def test_cascade_deletions(tag_factory):
sug1 = tag_factory(names=['sug1']) sug1 = tag_factory(names=['sug1'])
sug2 = tag_factory(names=['sug2']) sug2 = tag_factory(names=['sug2'])
imp1 = tag_factory(names=['imp1']) imp1 = tag_factory(names=['imp1'])
@ -53,8 +54,8 @@ def test_cascade_deletions(session, tag_factory):
tag.creation_time = datetime(1997, 1, 1) tag.creation_time = datetime(1997, 1, 1)
tag.last_edit_time = datetime(1998, 1, 1) tag.last_edit_time = datetime(1998, 1, 1)
tag.post_count = 1 tag.post_count = 1
session.add_all([tag, sug1, sug2, imp1, imp2]) db.session.add_all([tag, sug1, sug2, imp1, imp2])
session.commit() db.session.commit()
assert tag.tag_id is not None assert tag.tag_id is not None
assert sug1.tag_id is not None assert sug1.tag_id is not None
@ -65,32 +66,32 @@ def test_cascade_deletions(session, tag_factory):
tag.suggestions.append(sug2) tag.suggestions.append(sug2)
tag.implications.append(imp1) tag.implications.append(imp1)
tag.implications.append(imp2) tag.implications.append(imp2)
session.commit() db.session.commit()
session.delete(tag) db.session.delete(tag)
session.commit() db.session.commit()
assert session.query(db.Tag).count() == 4 assert db.session.query(db.Tag).count() == 4
assert session.query(db.TagName).count() == 4 assert db.session.query(db.TagName).count() == 4
assert session.query(db.TagImplication).count() == 0 assert db.session.query(db.TagImplication).count() == 0
assert session.query(db.TagSuggestion).count() == 0 assert db.session.query(db.TagSuggestion).count() == 0
def test_tracking_post_count(session, post_factory, tag_factory): def test_tracking_post_count(post_factory, tag_factory):
tag = tag_factory() tag = tag_factory()
post1 = post_factory() post1 = post_factory()
post2 = post_factory() post2 = post_factory()
session.add_all([tag, post1, post2]) db.session.add_all([tag, post1, post2])
session.flush() db.session.flush()
post1.tags.append(tag) post1.tags.append(tag)
post2.tags.append(tag) post2.tags.append(tag)
session.commit() db.session.commit()
assert len(post1.tags) == 1 assert len(post1.tags) == 1
assert len(post2.tags) == 1 assert len(post2.tags) == 1
assert tag.post_count == 2 assert tag.post_count == 2
session.delete(post1) db.session.delete(post1)
session.commit() db.session.commit()
session.refresh(tag) db.session.refresh(tag)
assert tag.post_count == 1 assert tag.post_count == 1
session.delete(post2) db.session.delete(post2)
session.commit() db.session.commit()
session.refresh(tag) db.session.refresh(tag)
assert tag.post_count == 0 assert tag.post_count == 0

View file

@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
from szurubooru import db from szurubooru import db
def test_saving_user(session): def test_saving_user():
user = db.User() user = db.User()
user.name = 'name' user.name = 'name'
user.password_salt = 'salt' user.password_salt = 'salt'
@ -10,8 +10,10 @@ def test_saving_user(session):
user.rank = 'rank' user.rank = 'rank'
user.creation_time = datetime(1997, 1, 1) user.creation_time = datetime(1997, 1, 1)
user.avatar_style = db.User.AVATAR_GRAVATAR user.avatar_style = db.User.AVATAR_GRAVATAR
session.add(user) db.session.add(user)
user = session.query(db.User).one() db.session.flush()
db.session.refresh(user)
assert not db.session.dirty
assert user.name == 'name' assert user.name == 'name'
assert user.password_salt == 'salt' assert user.password_salt == 'salt'
assert user.password_hash == 'hash' assert user.password_hash == 'hash'

View file

@ -40,14 +40,14 @@ def verify_unpaged(executor):
('-creation-date:2014-01,2015', ['t2']), ('-creation-date:2014-01,2015', ['t2']),
]) ])
def test_filter_by_creation_time( def test_filter_by_creation_time(
verify_unpaged, session, tag_factory, input, expected_tag_names): verify_unpaged, tag_factory, input, expected_tag_names):
tag1 = tag_factory(names=['t1']) tag1 = tag_factory(names=['t1'])
tag2 = tag_factory(names=['t2']) tag2 = tag_factory(names=['t2'])
tag3 = tag_factory(names=['t3']) tag3 = tag_factory(names=['t3'])
tag1.creation_time = datetime.datetime(2014, 1, 1) tag1.creation_time = datetime.datetime(2014, 1, 1)
tag2.creation_time = datetime.datetime(2014, 6, 1) tag2.creation_time = datetime.datetime(2014, 6, 1)
tag3.creation_time = datetime.datetime(2015, 1, 1) tag3.creation_time = datetime.datetime(2015, 1, 1)
session.add_all([tag1, tag2, tag3]) db.session.add_all([tag1, tag2, tag3])
verify_unpaged(input, expected_tag_names) verify_unpaged(input, expected_tag_names)
@pytest.mark.parametrize('input,expected_tag_names', [ @pytest.mark.parametrize('input,expected_tag_names', [
@ -70,12 +70,11 @@ def test_filter_by_creation_time(
('name:tag4', ['tag4']), ('name:tag4', ['tag4']),
('name:tag4,tag5', ['tag4']), ('name:tag4,tag5', ['tag4']),
]) ])
def test_filter_by_name( def test_filter_by_name(verify_unpaged, tag_factory, input, expected_tag_names):
session, verify_unpaged, tag_factory, input, expected_tag_names): db.session.add(tag_factory(names=['tag1']))
session.add(tag_factory(names=['tag1'])) db.session.add(tag_factory(names=['tag2']))
session.add(tag_factory(names=['tag2'])) db.session.add(tag_factory(names=['tag3']))
session.add(tag_factory(names=['tag3'])) db.session.add(tag_factory(names=['tag4', 'tag5', 'tag6']))
session.add(tag_factory(names=['tag4', 'tag5', 'tag6']))
verify_unpaged(input, expected_tag_names) verify_unpaged(input, expected_tag_names)
@pytest.mark.parametrize('input,expected_tag_names', [ @pytest.mark.parametrize('input,expected_tag_names', [
@ -84,10 +83,9 @@ def test_filter_by_name(
('t2', ['t2']), ('t2', ['t2']),
('t1,t2', ['t1', 't2']), ('t1,t2', ['t1', 't2']),
]) ])
def test_anonymous( def test_anonymous(verify_unpaged, tag_factory, input, expected_tag_names):
session, verify_unpaged, tag_factory, input, expected_tag_names): db.session.add(tag_factory(names=['t1']))
session.add(tag_factory(names=['t1'])) db.session.add(tag_factory(names=['t2']))
session.add(tag_factory(names=['t2']))
verify_unpaged(input, expected_tag_names) verify_unpaged(input, expected_tag_names)
@pytest.mark.parametrize('input,expected_tag_names', [ @pytest.mark.parametrize('input,expected_tag_names', [
@ -99,10 +97,9 @@ def test_anonymous(
('-order:name,asc', ['t2', 't1']), ('-order:name,asc', ['t2', 't1']),
('-order:name,desc', ['t1', 't2']), ('-order:name,desc', ['t1', 't2']),
]) ])
def test_order_by_name( def test_order_by_name(verify_unpaged, tag_factory, input, expected_tag_names):
session, verify_unpaged, tag_factory, input, expected_tag_names): db.session.add(tag_factory(names=['t2']))
session.add(tag_factory(names=['t2'])) db.session.add(tag_factory(names=['t1']))
session.add(tag_factory(names=['t1']))
verify_unpaged(input, expected_tag_names) verify_unpaged(input, expected_tag_names)
@pytest.mark.parametrize('input,expected_user_names', [ @pytest.mark.parametrize('input,expected_user_names', [
@ -111,28 +108,28 @@ def test_order_by_name(
('order:creation-time', ['t3', 't2', 't1']), ('order:creation-time', ['t3', 't2', 't1']),
]) ])
def test_order_by_creation_time( def test_order_by_creation_time(
session, verify_unpaged, tag_factory, input, expected_user_names): verify_unpaged, tag_factory, input, expected_user_names):
tag1 = tag_factory(names=['t1']) tag1 = tag_factory(names=['t1'])
tag2 = tag_factory(names=['t2']) tag2 = tag_factory(names=['t2'])
tag3 = tag_factory(names=['t3']) tag3 = tag_factory(names=['t3'])
tag1.creation_time = datetime.datetime(1991, 1, 1) tag1.creation_time = datetime.datetime(1991, 1, 1)
tag2.creation_time = datetime.datetime(1991, 1, 2) tag2.creation_time = datetime.datetime(1991, 1, 2)
tag3.creation_time = datetime.datetime(1991, 1, 3) tag3.creation_time = datetime.datetime(1991, 1, 3)
session.add_all([tag3, tag1, tag2]) db.session.add_all([tag3, tag1, tag2])
verify_unpaged(input, expected_user_names) verify_unpaged(input, expected_user_names)
@pytest.mark.parametrize('input,expected_tag_names', [ @pytest.mark.parametrize('input,expected_tag_names', [
('order:suggestion-count', ['t1', 't2', 'sug1', 'sug2', 'sug3']), ('order:suggestion-count', ['t1', 't2', 'sug1', 'sug2', 'sug3']),
]) ])
def test_order_by_suggestion_count( def test_order_by_suggestion_count(
session, verify_unpaged, tag_factory, input, expected_tag_names): verify_unpaged, tag_factory, input, expected_tag_names):
sug1 = tag_factory(names=['sug1']) sug1 = tag_factory(names=['sug1'])
sug2 = tag_factory(names=['sug2']) sug2 = tag_factory(names=['sug2'])
sug3 = tag_factory(names=['sug3']) sug3 = tag_factory(names=['sug3'])
tag1 = tag_factory(names=['t1']) tag1 = tag_factory(names=['t1'])
tag2 = tag_factory(names=['t2']) tag2 = tag_factory(names=['t2'])
session.add_all([sug1, sug3, tag2, sug2, tag1]) db.session.add_all([sug1, sug3, tag2, sug2, tag1])
session.commit() db.session.commit()
tag1.suggestions.append(sug1) tag1.suggestions.append(sug1)
tag1.suggestions.append(sug2) tag1.suggestions.append(sug2)
tag2.suggestions.append(sug3) tag2.suggestions.append(sug3)
@ -142,32 +139,32 @@ def test_order_by_suggestion_count(
('order:implication-count', ['t1', 't2', 'sug1', 'sug2', 'sug3']), ('order:implication-count', ['t1', 't2', 'sug1', 'sug2', 'sug3']),
]) ])
def test_order_by_implication_count( def test_order_by_implication_count(
session, verify_unpaged, tag_factory, input, expected_tag_names): verify_unpaged, tag_factory, input, expected_tag_names):
sug1 = tag_factory(names=['sug1']) sug1 = tag_factory(names=['sug1'])
sug2 = tag_factory(names=['sug2']) sug2 = tag_factory(names=['sug2'])
sug3 = tag_factory(names=['sug3']) sug3 = tag_factory(names=['sug3'])
tag1 = tag_factory(names=['t1']) tag1 = tag_factory(names=['t1'])
tag2 = tag_factory(names=['t2']) tag2 = tag_factory(names=['t2'])
session.add_all([sug1, sug3, tag2, sug2, tag1]) db.session.add_all([sug1, sug3, tag2, sug2, tag1])
session.commit() db.session.commit()
tag1.implications.append(sug1) tag1.implications.append(sug1)
tag1.implications.append(sug2) tag1.implications.append(sug2)
tag2.implications.append(sug3) tag2.implications.append(sug3)
verify_unpaged(input, expected_tag_names) verify_unpaged(input, expected_tag_names)
def test_filter_by_relation_count(session, verify_unpaged, tag_factory): def test_filter_by_relation_count(verify_unpaged, tag_factory):
sug1 = tag_factory(names=['sug1']) sug1 = tag_factory(names=['sug1'])
sug2 = tag_factory(names=['sug2']) sug2 = tag_factory(names=['sug2'])
imp1 = tag_factory(names=['imp1']) imp1 = tag_factory(names=['imp1'])
tag1 = tag_factory(names=['t1']) tag1 = tag_factory(names=['t1'])
tag2 = tag_factory(names=['t2']) tag2 = tag_factory(names=['t2'])
session.add_all([sug1, tag1, sug2, imp1, tag2]) db.session.add_all([sug1, tag1, sug2, imp1, tag2])
session.commit() db.session.commit()
session.add_all([ db.session.add_all([
db.TagSuggestion(tag1.tag_id, sug1.tag_id), db.TagSuggestion(tag1.tag_id, sug1.tag_id),
db.TagSuggestion(tag1.tag_id, sug2.tag_id), db.TagSuggestion(tag1.tag_id, sug2.tag_id),
db.TagImplication(tag2.tag_id, imp1.tag_id)]) db.TagImplication(tag2.tag_id, imp1.tag_id)])
session.commit() db.session.commit()
verify_unpaged('suggestion-count:0', ['imp1', 'sug1', 'sug2', 't2']) verify_unpaged('suggestion-count:0', ['imp1', 'sug1', 'sug2', 't2'])
verify_unpaged('suggestion-count:1', []) verify_unpaged('suggestion-count:1', [])
verify_unpaged('suggestion-count:2', ['t1']) verify_unpaged('suggestion-count:2', ['t1'])

View file

@ -42,14 +42,14 @@ def verify_unpaged(executor):
('-creation-date:2014-01,2015', ['u2']), ('-creation-date:2014-01,2015', ['u2']),
]) ])
def test_filter_by_creation_time( def test_filter_by_creation_time(
verify_unpaged, session, input, expected_user_names, user_factory): verify_unpaged, input, expected_user_names, user_factory):
user1 = user_factory(name='u1') user1 = user_factory(name='u1')
user2 = user_factory(name='u2') user2 = user_factory(name='u2')
user3 = user_factory(name='u3') user3 = user_factory(name='u3')
user1.creation_time = datetime.datetime(2014, 1, 1) user1.creation_time = datetime.datetime(2014, 1, 1)
user2.creation_time = datetime.datetime(2014, 6, 1) user2.creation_time = datetime.datetime(2014, 6, 1)
user3.creation_time = datetime.datetime(2015, 1, 1) user3.creation_time = datetime.datetime(2015, 1, 1)
session.add_all([user1, user2, user3]) db.session.add_all([user1, user2, user3])
verify_unpaged(input, expected_user_names) verify_unpaged(input, expected_user_names)
@pytest.mark.parametrize('input,expected_user_names', [ @pytest.mark.parametrize('input,expected_user_names', [
@ -71,10 +71,10 @@ def test_filter_by_creation_time(
('-name:user1,user3', ['user2']), ('-name:user1,user3', ['user2']),
]) ])
def test_filter_by_name( def test_filter_by_name(
session, verify_unpaged, input, expected_user_names, user_factory): verify_unpaged, input, expected_user_names, user_factory):
session.add(user_factory(name='user1')) db.session.add(user_factory(name='user1'))
session.add(user_factory(name='user2')) db.session.add(user_factory(name='user2'))
session.add(user_factory(name='user3')) db.session.add(user_factory(name='user3'))
verify_unpaged(input, expected_user_names) verify_unpaged(input, expected_user_names)
@pytest.mark.parametrize('input,expected_user_names', [ @pytest.mark.parametrize('input,expected_user_names', [
@ -84,9 +84,9 @@ def test_filter_by_name(
('u1,u2', ['u1', 'u2']), ('u1,u2', ['u1', 'u2']),
]) ])
def test_anonymous( def test_anonymous(
session, verify_unpaged, input, expected_user_names, user_factory): verify_unpaged, input, expected_user_names, user_factory):
session.add(user_factory(name='u1')) db.session.add(user_factory(name='u1'))
session.add(user_factory(name='u2')) db.session.add(user_factory(name='u2'))
verify_unpaged(input, expected_user_names) verify_unpaged(input, expected_user_names)
@pytest.mark.parametrize('input,expected_user_names', [ @pytest.mark.parametrize('input,expected_user_names', [
@ -95,14 +95,14 @@ def test_anonymous(
('creation-time:2016 u2', []), ('creation-time:2016 u2', []),
]) ])
def test_combining_tokens( def test_combining_tokens(
session, verify_unpaged, input, expected_user_names, user_factory): verify_unpaged, input, expected_user_names, user_factory):
user1 = user_factory(name='u1') user1 = user_factory(name='u1')
user2 = user_factory(name='u2') user2 = user_factory(name='u2')
user3 = user_factory(name='u3') user3 = user_factory(name='u3')
user1.creation_time = datetime.datetime(2014, 1, 1) user1.creation_time = datetime.datetime(2014, 1, 1)
user2.creation_time = datetime.datetime(2014, 6, 1) user2.creation_time = datetime.datetime(2014, 6, 1)
user3.creation_time = datetime.datetime(2015, 1, 1) user3.creation_time = datetime.datetime(2015, 1, 1)
session.add_all([user1, user2, user3]) db.session.add_all([user1, user2, user3])
verify_unpaged(input, expected_user_names) verify_unpaged(input, expected_user_names)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -114,10 +114,10 @@ def test_combining_tokens(
(0, 0, 2, []), (0, 0, 2, []),
]) ])
def test_paging( def test_paging(
session, executor, user_factory, page, page_size, executor, user_factory, page, page_size,
expected_total_count, expected_user_names): expected_total_count, expected_user_names):
session.add(user_factory(name='u1')) db.session.add(user_factory(name='u1'))
session.add(user_factory(name='u2')) db.session.add(user_factory(name='u2'))
actual_count, actual_users = executor.execute( actual_count, actual_users = executor.execute(
'', page=page, page_size=page_size) '', page=page, page_size=page_size)
actual_user_names = [u.name for u in actual_users] actual_user_names = [u.name for u in actual_users]
@ -134,9 +134,9 @@ def test_paging(
('-order:name,desc', ['u1', 'u2']), ('-order:name,desc', ['u1', 'u2']),
]) ])
def test_order_by_name( def test_order_by_name(
session, verify_unpaged, input, expected_user_names, user_factory): verify_unpaged, input, expected_user_names, user_factory):
session.add(user_factory(name='u2')) db.session.add(user_factory(name='u2'))
session.add(user_factory(name='u1')) db.session.add(user_factory(name='u1'))
verify_unpaged(input, expected_user_names) verify_unpaged(input, expected_user_names)
@pytest.mark.parametrize('input,expected_user_names', [ @pytest.mark.parametrize('input,expected_user_names', [
@ -150,14 +150,14 @@ def test_order_by_name(
('-order:creation-date,desc', ['u1', 'u2', 'u3']), ('-order:creation-date,desc', ['u1', 'u2', 'u3']),
]) ])
def test_order_by_creation_time( def test_order_by_creation_time(
session, verify_unpaged, input, expected_user_names, user_factory): verify_unpaged, input, expected_user_names, user_factory):
user1 = user_factory(name='u1') user1 = user_factory(name='u1')
user2 = user_factory(name='u2') user2 = user_factory(name='u2')
user3 = user_factory(name='u3') user3 = user_factory(name='u3')
user1.creation_time = datetime.datetime(1991, 1, 1) user1.creation_time = datetime.datetime(1991, 1, 1)
user2.creation_time = datetime.datetime(1991, 1, 2) user2.creation_time = datetime.datetime(1991, 1, 2)
user3.creation_time = datetime.datetime(1991, 1, 3) user3.creation_time = datetime.datetime(1991, 1, 3)
session.add_all([user3, user1, user2]) db.session.add_all([user3, user1, user2])
verify_unpaged(input, expected_user_names) verify_unpaged(input, expected_user_names)
@pytest.mark.parametrize('input,expected_user_names', [ @pytest.mark.parametrize('input,expected_user_names', [
@ -168,21 +168,21 @@ def test_order_by_creation_time(
('order:login-time', ['u3', 'u2', 'u1']), ('order:login-time', ['u3', 'u2', 'u1']),
]) ])
def test_order_by_last_login_time( def test_order_by_last_login_time(
session, verify_unpaged, input, expected_user_names, user_factory): verify_unpaged, input, expected_user_names, user_factory):
user1 = user_factory(name='u1') user1 = user_factory(name='u1')
user2 = user_factory(name='u2') user2 = user_factory(name='u2')
user3 = user_factory(name='u3') user3 = user_factory(name='u3')
user1.last_login_time = datetime.datetime(1991, 1, 1) user1.last_login_time = datetime.datetime(1991, 1, 1)
user2.last_login_time = datetime.datetime(1991, 1, 2) user2.last_login_time = datetime.datetime(1991, 1, 2)
user3.last_login_time = datetime.datetime(1991, 1, 3) user3.last_login_time = datetime.datetime(1991, 1, 3)
session.add_all([user3, user1, user2]) db.session.add_all([user3, user1, user2])
verify_unpaged(input, expected_user_names) verify_unpaged(input, expected_user_names)
def test_random_order(session, executor, user_factory): def test_random_order(executor, user_factory):
user1 = user_factory(name='u1') user1 = user_factory(name='u1')
user2 = user_factory(name='u2') user2 = user_factory(name='u2')
user3 = user_factory(name='u3') user3 = user_factory(name='u3')
session.add_all([user3, user1, user2]) db.session.add_all([user3, user1, user2])
actual_count, actual_users = executor.execute( actual_count, actual_users = executor.execute(
'order:random', page=1, page_size=100) 'order:random', page=1, page_size=100)
actual_user_names = [u.name for u in actual_users] actual_user_names = [u.name for u in actual_users]

View file

@ -3,7 +3,7 @@ import pytest
from szurubooru import db from szurubooru import db
from szurubooru.util import snapshots from szurubooru.util import snapshots
def test_serializing_tag(session, tag_factory): def test_serializing_tag(tag_factory):
tag = tag_factory(names=['main_name', 'alias'], category_name='dummy') tag = tag_factory(names=['main_name', 'alias'], category_name='dummy')
assert snapshots.get_tag_snapshot(tag) == { assert snapshots.get_tag_snapshot(tag) == {
'names': ['main_name', 'alias'], 'names': ['main_name', 'alias'],
@ -15,10 +15,10 @@ def test_serializing_tag(session, tag_factory):
imp2 = tag_factory(names=['imp2_main_name', 'imp2_alias']) imp2 = tag_factory(names=['imp2_main_name', 'imp2_alias'])
sug1 = tag_factory(names=['sug1_main_name', 'sug1_alias']) sug1 = tag_factory(names=['sug1_main_name', 'sug1_alias'])
sug2 = tag_factory(names=['sug2_main_name', 'sug2_alias']) sug2 = tag_factory(names=['sug2_main_name', 'sug2_alias'])
session.add_all([imp1, imp2, sug1, sug2]) db.session.add_all([imp1, imp2, sug1, sug2])
tag.implications = [imp1, imp2] tag.implications = [imp1, imp2]
tag.suggestions = [sug1, sug2] tag.suggestions = [sug1, sug2]
session.flush() db.session.flush()
assert snapshots.get_tag_snapshot(tag) == { assert snapshots.get_tag_snapshot(tag) == {
'names': ['main_name', 'alias'], 'names': ['main_name', 'alias'],
'category': 'dummy', 'category': 'dummy',
@ -26,39 +26,38 @@ def test_serializing_tag(session, tag_factory):
'suggestions': ['sug1_main_name', 'sug2_main_name'], 'suggestions': ['sug1_main_name', 'sug2_main_name'],
} }
def test_merging_modification_to_creation(session, tag_factory, user_factory): def test_merging_modification_to_creation(tag_factory, user_factory):
tag = tag_factory(names=['dummy']) tag = tag_factory(names=['dummy'])
user = user_factory() user = user_factory()
session.add_all([tag, user]) db.session.add_all([tag, user])
session.flush() db.session.flush()
snapshots.create(tag, user) snapshots.create(tag, user)
session.flush() db.session.flush()
tag.names = [db.TagName('changed')] tag.names = [db.TagName('changed')]
snapshots.modify(tag, user) snapshots.modify(tag, user)
session.flush() db.session.flush()
results = session.query(db.Snapshot).all() results = db.session.query(db.Snapshot).all()
assert len(results) == 1 assert len(results) == 1
assert results[0].operation == db.Snapshot.OPERATION_CREATED assert results[0].operation == db.Snapshot.OPERATION_CREATED
assert results[0].data['names'] == ['changed'] assert results[0].data['names'] == ['changed']
def test_merging_modifications( def test_merging_modifications(fake_datetime, tag_factory, user_factory):
fake_datetime, session, tag_factory, user_factory):
tag = tag_factory(names=['dummy']) tag = tag_factory(names=['dummy'])
user = user_factory() user = user_factory()
session.add_all([tag, user]) db.session.add_all([tag, user])
session.flush() db.session.flush()
with fake_datetime('13:00:00'): with fake_datetime('13:00:00'):
snapshots.create(tag, user) snapshots.create(tag, user)
session.flush() db.session.flush()
tag.names = [db.TagName('changed')] tag.names = [db.TagName('changed')]
with fake_datetime('14:00:00'): with fake_datetime('14:00:00'):
snapshots.modify(tag, user) snapshots.modify(tag, user)
session.flush() db.session.flush()
tag.names = [db.TagName('changed again')] tag.names = [db.TagName('changed again')]
with fake_datetime('14:00:01'): with fake_datetime('14:00:01'):
snapshots.modify(tag, user) snapshots.modify(tag, user)
session.flush() db.session.flush()
results = session.query(db.Snapshot).all() results = db.session.query(db.Snapshot).all()
assert len(results) == 2 assert len(results) == 2
assert results[0].operation == db.Snapshot.OPERATION_CREATED assert results[0].operation == db.Snapshot.OPERATION_CREATED
assert results[1].operation == db.Snapshot.OPERATION_MODIFIED assert results[1].operation == db.Snapshot.OPERATION_MODIFIED
@ -66,93 +65,94 @@ def test_merging_modifications(
assert results[1].data['names'] == ['changed again'] assert results[1].data['names'] == ['changed again']
def test_not_adding_snapshot_if_data_doesnt_change( def test_not_adding_snapshot_if_data_doesnt_change(
fake_datetime, session, tag_factory, user_factory): fake_datetime, tag_factory, user_factory):
tag = tag_factory(names=['dummy']) tag = tag_factory(names=['dummy'])
user = user_factory() user = user_factory()
session.add_all([tag, user]) db.session.add_all([tag, user])
session.flush() db.session.flush()
with fake_datetime('13:00:00'): with fake_datetime('13:00:00'):
snapshots.create(tag, user) snapshots.create(tag, user)
session.flush() db.session.flush()
with fake_datetime('14:00:00'): with fake_datetime('14:00:00'):
snapshots.modify(tag, user) snapshots.modify(tag, user)
session.flush() db.session.flush()
results = session.query(db.Snapshot).all() results = db.session.query(db.Snapshot).all()
assert len(results) == 1 assert len(results) == 1
assert results[0].operation == db.Snapshot.OPERATION_CREATED assert results[0].operation == db.Snapshot.OPERATION_CREATED
assert results[0].data['names'] == ['dummy'] assert results[0].data['names'] == ['dummy']
def test_not_merging_due_to_time_difference( def test_not_merging_due_to_time_difference(
fake_datetime, session, tag_factory, user_factory): fake_datetime, tag_factory, user_factory):
tag = tag_factory(names=['dummy']) tag = tag_factory(names=['dummy'])
user = user_factory() user = user_factory()
session.add_all([tag, user]) db.session.add_all([tag, user])
session.flush() db.session.flush()
with fake_datetime('13:00:00'): with fake_datetime('13:00:00'):
snapshots.create(tag, user) snapshots.create(tag, user)
session.flush() db.session.flush()
tag.names = [db.TagName('changed')] tag.names = [db.TagName('changed')]
with fake_datetime('13:10:01'): with fake_datetime('13:10:01'):
snapshots.modify(tag, user) snapshots.modify(tag, user)
session.flush() db.session.flush()
assert session.query(db.Snapshot).count() == 2 assert db.session.query(db.Snapshot).count() == 2
def test_not_merging_operations_by_different_users( def test_not_merging_operations_by_different_users(
fake_datetime, session, tag_factory, user_factory): fake_datetime, tag_factory, user_factory):
tag = tag_factory(names=['dummy']) tag = tag_factory(names=['dummy'])
user1, user2 = [user_factory(), user_factory()] user1, user2 = [user_factory(), user_factory()]
session.add_all([tag, user1, user2]) db.session.add_all([tag, user1, user2])
session.flush() db.session.flush()
with fake_datetime('13:00:00'): with fake_datetime('13:00:00'):
snapshots.create(tag, user1) snapshots.create(tag, user1)
session.flush() db.session.flush()
tag.names = [db.TagName('changed')] tag.names = [db.TagName('changed')]
snapshots.modify(tag, user2) snapshots.modify(tag, user2)
session.flush() db.session.flush()
assert session.query(db.Snapshot).count() == 2 assert db.session.query(db.Snapshot).count() == 2
def test_merging_resets_merging_time_window( def test_merging_resets_merging_time_window(
fake_datetime, session, tag_factory, user_factory): fake_datetime, tag_factory, user_factory):
tag = tag_factory(names=['dummy']) tag = tag_factory(names=['dummy'])
user = user_factory() user = user_factory()
session.add_all([tag, user]) db.session.add_all([tag, user])
session.flush() db.session.flush()
with fake_datetime('13:00:00'): with fake_datetime('13:00:00'):
snapshots.create(tag, user) snapshots.create(tag, user)
session.flush() db.session.flush()
tag.names = [db.TagName('changed')] tag.names = [db.TagName('changed')]
with fake_datetime('13:09:59'): with fake_datetime('13:09:59'):
snapshots.modify(tag, user) snapshots.modify(tag, user)
session.flush() db.session.flush()
tag.names = [db.TagName('changed again')] tag.names = [db.TagName('changed again')]
with fake_datetime('13:19:59'): with fake_datetime('13:19:59'):
snapshots.modify(tag, user) snapshots.modify(tag, user)
session.flush() db.session.flush()
results = session.query(db.Snapshot).all() results = db.session.query(db.Snapshot).all()
assert len(results) == 1 assert len(results) == 1
assert results[0].data['names'] == ['changed again'] assert results[0].data['names'] == ['changed again']
@pytest.mark.parametrize( @pytest.mark.parametrize(
'initial_operation', [snapshots.create, snapshots.modify]) 'initial_operation', [snapshots.create, snapshots.modify])
def test_merging_deletion_to_modification_or_creation( def test_merging_deletion_to_modification_or_creation(
fake_datetime, session, tag_factory, user_factory, initial_operation): fake_datetime, tag_factory, user_factory, initial_operation):
tag = tag_factory(names=['dummy'], category_name='dummy') tag = tag_factory(names=['dummy'], category_name='dummy')
user = user_factory() user = user_factory()
session.add_all([tag, user]) db.session.add_all([tag, user])
session.flush() db.session.flush()
with fake_datetime('13:00:00'): with fake_datetime('13:00:00'):
initial_operation(tag, user) initial_operation(tag, user)
session.flush() db.session.flush()
tag.names = [db.TagName('changed')] tag.names = [db.TagName('changed')]
with fake_datetime('14:00:00'): with fake_datetime('14:00:00'):
snapshots.modify(tag, user) snapshots.modify(tag, user)
session.flush() db.session.flush()
tag.names = [db.TagName('changed again')] tag.names = [db.TagName('changed again')]
with fake_datetime('14:00:01'): with fake_datetime('14:00:01'):
snapshots.delete(tag, user) snapshots.delete(tag, user)
session.flush() db.session.flush()
assert session.query(db.Snapshot).count() == 2 assert db.session.query(db.Snapshot).count() == 2
results = session.query(db.Snapshot) \ results = db.session \
.query(db.Snapshot) \
.order_by(db.Snapshot.snapshot_id.asc()) \ .order_by(db.Snapshot.snapshot_id.asc()) \
.all() .all()
assert results[1].operation == db.Snapshot.OPERATION_DELETED assert results[1].operation == db.Snapshot.OPERATION_DELETED
@ -161,20 +161,20 @@ def test_merging_deletion_to_modification_or_creation(
@pytest.mark.parametrize( @pytest.mark.parametrize(
'expected_operation', [snapshots.create, snapshots.modify]) 'expected_operation', [snapshots.create, snapshots.modify])
def test_merging_deletion_all_the_way_deletes_all_snapshots( def test_merging_deletion_all_the_way_deletes_all_snapshots(
fake_datetime, session, tag_factory, user_factory, expected_operation): fake_datetime, tag_factory, user_factory, expected_operation):
tag = tag_factory(names=['dummy']) tag = tag_factory(names=['dummy'])
user = user_factory() user = user_factory()
session.add_all([tag, user]) db.session.add_all([tag, user])
session.flush() db.session.flush()
with fake_datetime('13:00:00'): with fake_datetime('13:00:00'):
snapshots.create(tag, user) snapshots.create(tag, user)
session.flush() db.session.flush()
tag.names = [db.TagName('changed')] tag.names = [db.TagName('changed')]
with fake_datetime('13:00:01'): with fake_datetime('13:00:01'):
snapshots.modify(tag, user) snapshots.modify(tag, user)
session.flush() db.session.flush()
tag.names = [db.TagName('changed again')] tag.names = [db.TagName('changed again')]
with fake_datetime('13:00:02'): with fake_datetime('13:00:02'):
snapshots.delete(tag, user) snapshots.delete(tag, user)
session.flush() db.session.flush()
assert session.query(db.Snapshot).count() == 0 assert db.session.query(db.Snapshot).count() == 0

View file

@ -34,8 +34,8 @@ def save(operation, entity, auth_user):
snapshot.data = serializers[table_name](entity) snapshot.data = serializers[table_name](entity)
snapshot.user = auth_user snapshot.user = auth_user
session = db.session() earlier_snapshots = db.session \
earlier_snapshots = session.query(db.Snapshot) \ .query(db.Snapshot) \
.filter(db.Snapshot.resource_type == table_name) \ .filter(db.Snapshot.resource_type == table_name) \
.filter(db.Snapshot.resource_id == primary_key) \ .filter(db.Snapshot.resource_id == primary_key) \
.order_by(db.Snapshot.creation_time.desc()) \ .order_by(db.Snapshot.creation_time.desc()) \
@ -49,7 +49,7 @@ def save(operation, entity, auth_user):
if snapshot.data != last_snapshot.data: if snapshot.data != last_snapshot.data:
if not is_fresh or last_snapshot.user != auth_user: if not is_fresh or last_snapshot.user != auth_user:
break break
session.delete(last_snapshot) db.session.delete(last_snapshot)
if snapshot.operation != db.Snapshot.OPERATION_DELETED: if snapshot.operation != db.Snapshot.OPERATION_DELETED:
snapshot.operation = last_snapshot.operation snapshot.operation = last_snapshot.operation
snapshots_left -= 1 snapshots_left -= 1
@ -57,7 +57,7 @@ def save(operation, entity, auth_user):
if not snapshots_left and operation == db.Snapshot.OPERATION_DELETED: if not snapshots_left and operation == db.Snapshot.OPERATION_DELETED:
pass pass
else: else:
session.add(snapshot) db.session.add(snapshot)
def create(entity, auth_user): def create(entity, auth_user):
save(db.Snapshot.OPERATION_CREATED, entity, auth_user) save(db.Snapshot.OPERATION_CREATED, entity, auth_user)

View file

@ -26,7 +26,7 @@ def update_name(category, name):
expr = db.TagCategory.name.ilike(name) expr = db.TagCategory.name.ilike(name)
if category.tag_category_id: if category.tag_category_id:
expr = expr & (db.TagCategory.tag_category_id != category.tag_category_id) expr = expr & (db.TagCategory.tag_category_id != category.tag_category_id)
already_exists = db.session().query(db.TagCategory).filter(expr).count() > 0 already_exists = db.session.query(db.TagCategory).filter(expr).count() > 0
if already_exists: if already_exists:
raise TagCategoryAlreadyExistsError( raise TagCategoryAlreadyExistsError(
'A category with this name already exists.') 'A category with this name already exists.')
@ -43,7 +43,8 @@ def update_color(category, color):
category.color = color category.color = color
def get_category_by_name(name): def get_category_by_name(name):
return db.session.query(db.TagCategory) \ return db.session \
.query(db.TagCategory) \
.filter(db.TagCategory.name.ilike(name)) \ .filter(db.TagCategory.name.ilike(name)) \
.first() .first()
@ -54,7 +55,8 @@ def get_all_categories():
return db.session.query(db.TagCategory).all() return db.session.query(db.TagCategory).all()
def get_default_category(): def get_default_category():
return db.session().query(db.TagCategory) \ return db.session \
.query(db.TagCategory) \
.order_by(db.TagCategory.tag_category_id.asc()) \ .order_by(db.TagCategory.tag_category_id.asc()) \
.limit(1) \ .limit(1) \
.one() .one()

View file

@ -32,7 +32,7 @@ def export_to_json():
'tags': [], 'tags': [],
'categories': [], 'categories': [],
} }
all_tags = db.session() \ all_tags = db.session \
.query(db.Tag) \ .query(db.Tag) \
.options( .options(
sqlalchemy.orm.joinedload('suggestions'), sqlalchemy.orm.joinedload('suggestions'),
@ -61,7 +61,8 @@ def export_to_json():
handle.write(json.dumps(output, separators=(',', ':'))) handle.write(json.dumps(output, separators=(',', ':')))
def get_tag_by_name(name): def get_tag_by_name(name):
return db.session().query(db.Tag) \ return db.session \
.query(db.Tag) \
.join(db.TagName) \ .join(db.TagName) \
.filter(db.TagName.name.ilike(name)) \ .filter(db.TagName.name.ilike(name)) \
.first() .first()
@ -73,7 +74,7 @@ def get_tags_by_names(names):
expr = sqlalchemy.sql.false() expr = sqlalchemy.sql.false()
for name in names: for name in names:
expr = expr | db.TagName.name.ilike(name) expr = expr | db.TagName.name.ilike(name)
return db.session().query(db.Tag).join(db.TagName).filter(expr).all() return db.session.query(db.Tag).join(db.TagName).filter(expr).all()
def get_or_create_tags_by_names(names): def get_or_create_tags_by_names(names):
names = misc.icase_unique(names) names = misc.icase_unique(names)
@ -93,7 +94,7 @@ def get_or_create_tags_by_names(names):
category_name=tag_categories.get_default_category().name, category_name=tag_categories.get_default_category().name,
suggestions=[], suggestions=[],
implications=[]) implications=[])
db.session().add(new_tag) db.session.add(new_tag)
new_tags.append(new_tag) new_tags.append(new_tag)
return related_tags, new_tags return related_tags, new_tags
@ -107,8 +108,8 @@ def create_tag(names, category_name, suggestions, implications):
return tag return tag
def update_category_name(tag, category_name): def update_category_name(tag, category_name):
session = db.session() category = db.session \
category = session.query(db.TagCategory) \ .query(db.TagCategory) \
.filter(db.TagCategory.name == category_name) \ .filter(db.TagCategory.name == category_name) \
.first() .first()
if not category: if not category:
@ -131,7 +132,7 @@ def update_names(tag, names):
expr = expr | db.TagName.name.ilike(name) expr = expr | db.TagName.name.ilike(name)
if tag.tag_id: if tag.tag_id:
expr = expr & (db.TagName.tag_id != tag.tag_id) expr = expr & (db.TagName.tag_id != tag.tag_id)
existing_tags = db.session().query(db.TagName).filter(expr).all() existing_tags = db.session.query(db.TagName).filter(expr).all()
if len(existing_tags): if len(existing_tags):
raise TagAlreadyExistsError( raise TagAlreadyExistsError(
'One of names is already used by another tag.') 'One of names is already used by another tag.')
@ -149,12 +150,12 @@ def update_implications(tag, relations):
if _check_name_intersection(_get_plain_names(tag), relations): if _check_name_intersection(_get_plain_names(tag), relations):
raise InvalidTagRelationError('Tag cannot imply itself.') raise InvalidTagRelationError('Tag cannot imply itself.')
related_tags, new_tags = get_or_create_tags_by_names(relations) related_tags, new_tags = get_or_create_tags_by_names(relations)
db.session().flush() db.session.flush()
tag.implications = related_tags + new_tags tag.implications = related_tags + new_tags
def update_suggestions(tag, relations): def update_suggestions(tag, relations):
if _check_name_intersection(_get_plain_names(tag), relations): if _check_name_intersection(_get_plain_names(tag), relations):
raise InvalidTagRelationError('Tag cannot suggest itself.') raise InvalidTagRelationError('Tag cannot suggest itself.')
related_tags, new_tags = get_or_create_tags_by_names(relations) related_tags, new_tags = get_or_create_tags_by_names(relations)
db.session().flush() db.session.flush()
tag.suggestions = related_tags + new_tags tag.suggestions = related_tags + new_tags

View file

@ -16,12 +16,14 @@ def get_user_count():
return db.session.query(db.User).count() return db.session.query(db.User).count()
def get_user_by_name(name): def get_user_by_name(name):
return db.session.query(db.User) \ return db.session \
.query(db.User) \
.filter(func.lower(db.User.name) == func.lower(name)) \ .filter(func.lower(db.User.name) == func.lower(name)) \
.first() .first()
def get_user_by_name_or_email(name_or_email): def get_user_by_name_or_email(name_or_email):
return db.session.query(db.User) \ return db.session \
.query(db.User) \
.filter( .filter(
(func.lower(db.User.name) == func.lower(name_or_email)) (func.lower(db.User.name) == func.lower(name_or_email))
| (func.lower(db.User.email) == func.lower(name_or_email))) \ | (func.lower(db.User.email) == func.lower(name_or_email))) \