server/general: consistently use db.session

This commit is contained in:
rr- 2016-04-19 17:39:16 +02:00
parent fe56e376f6
commit 2e57a0746f
28 changed files with 351 additions and 380 deletions

View file

@ -1,6 +1,5 @@
''' Exports create_app. '''
import json
import falcon
from szurubooru import api, errors, middleware
@ -26,15 +25,13 @@ def _on_processing_error(ex, _request, _response, _params):
def create_method_not_allowed(allowed_methods):
allowed = ', '.join(allowed_methods)
def method_not_allowed(request, response, **kwargs):
def method_not_allowed(request, response, **_kwargs):
response.status = falcon.status_codes.HTTP_405
response.set_header('Allow', allowed)
request.context.output = {
'title': 'Method not allowed',
'description': 'Allowed methods: %r' % allowed_methods,
}
return method_not_allowed
def create_app():

View file

@ -23,8 +23,8 @@ class SearchExecutor(object):
count_query = filter_query.statement \
.with_only_columns([sqlalchemy.func.count()]) \
.order_by(None)
count = filter_query \
.session.execute(count_query) \
count = filter_query.session \
.execute(count_query) \
.scalar()
return (count, entities)

View file

@ -4,7 +4,7 @@ from szurubooru.search.base_search_config import BaseSearchConfig
class TagSearchConfig(BaseSearchConfig):
def create_query(self):
return db.session().query(db.Tag)
return db.session.query(db.Tag)
def finalize_query(self, query):
return query.order_by(db.Tag.first_name.asc())

View file

@ -6,7 +6,7 @@ class UserSearchConfig(BaseSearchConfig):
''' Executes searches related to the users. '''
def create_query(self):
return db.session().query(db.User)
return db.session.query(db.User)
def finalize_query(self, query):
return query.order_by(db.User.name.asc())

View file

@ -14,8 +14,8 @@ def password_reset_api(config_injector):
return api.PasswordResetApi()
def test_reset_sending_email(
password_reset_api, session, context_factory, user_factory):
session.add(user_factory(
password_reset_api, context_factory, user_factory):
db.session.add(user_factory(
name='u1', rank='regular_user', email='user@example.com'))
for getter in ['u1', 'user@example.com']:
mailer.send_mail = mock.MagicMock()
@ -34,17 +34,17 @@ def test_trying_to_reset_non_existing(password_reset_api, context_factory):
password_reset_api.get(context_factory(), 'u1')
def test_trying_to_reset_without_email(
password_reset_api, session, context_factory, user_factory):
session.add(user_factory(name='u1', rank='regular_user', email=None))
password_reset_api, context_factory, user_factory):
db.session.add(user_factory(name='u1', rank='regular_user', email=None))
with pytest.raises(errors.ValidationError):
password_reset_api.get(context_factory(), 'u1')
def test_confirming_with_good_token(
password_reset_api, context_factory, session, user_factory):
password_reset_api, context_factory, user_factory):
user = user_factory(
name='u1', rank='regular_user', email='user@example.com')
old_hash = user.password_hash
session.add(user)
db.session.add(user)
context = context_factory(
input={'token': '4ac0be176fb364f13ee6b634c43220e2'})
result = password_reset_api.post(context, 'u1')
@ -56,15 +56,15 @@ def test_trying_to_confirm_non_existing(password_reset_api, context_factory):
password_reset_api.post(context_factory(), 'u1')
def test_trying_to_confirm_without_token(
password_reset_api, context_factory, session, user_factory):
session.add(user_factory(
password_reset_api, context_factory, user_factory):
db.session.add(user_factory(
name='u1', rank='regular_user', email='user@example.com'))
with pytest.raises(errors.ValidationError):
password_reset_api.post(context_factory(input={}), 'u1')
def test_trying_to_confirm_with_bad_token(
password_reset_api, context_factory, session, user_factory):
session.add(user_factory(
password_reset_api, context_factory, user_factory):
db.session.add(user_factory(
name='u1', rank='regular_user', email='user@example.com'))
with pytest.raises(errors.ValidationError):
password_reset_api.post(

View file

@ -28,35 +28,35 @@ def test_ctx(
return ret
def test_deleting(test_ctx):
db.session().add(test_ctx.tag_category_factory(name='root'))
db.session().add(test_ctx.tag_category_factory(name='category'))
db.session().commit()
db.session.add(test_ctx.tag_category_factory(name='root'))
db.session.add(test_ctx.tag_category_factory(name='category'))
db.session.commit()
result = test_ctx.api.delete(
test_ctx.context_factory(
user=test_ctx.user_factory(rank='regular_user')),
'category')
assert result == {}
assert db.session().query(db.TagCategory).count() == 1
assert db.session().query(db.TagCategory).one().name == 'root'
assert db.session.query(db.TagCategory).count() == 1
assert db.session.query(db.TagCategory).one().name == 'root'
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json'))
def test_trying_to_delete_used(test_ctx, tag_factory):
category = test_ctx.tag_category_factory(name='category')
db.session().add(category)
db.session().flush()
db.session.add(category)
db.session.flush()
tag = test_ctx.tag_factory(names=['tag'], category=category)
db.session().add(tag)
db.session().commit()
db.session.add(tag)
db.session.commit()
with pytest.raises(tag_categories.TagCategoryIsInUseError):
test_ctx.api.delete(
test_ctx.context_factory(
user=test_ctx.user_factory(rank='regular_user')),
'category')
assert db.session().query(db.TagCategory).count() == 1
assert db.session.query(db.TagCategory).count() == 1
def test_trying_to_delete_last(test_ctx, tag_factory):
db.session().add(test_ctx.tag_category_factory(name='root'))
db.session().commit()
db.session.add(test_ctx.tag_category_factory(name='root'))
db.session.commit()
with pytest.raises(tag_categories.TagCategoryIsInUseError):
result = test_ctx.api.delete(
test_ctx.context_factory(
@ -70,11 +70,11 @@ def test_trying_to_delete_non_existing(test_ctx):
user=test_ctx.user_factory(rank='regular_user')), 'bad')
def test_trying_to_delete_without_privileges(test_ctx):
db.session().add(test_ctx.tag_category_factory(name='category'))
db.session().commit()
db.session.add(test_ctx.tag_category_factory(name='category'))
db.session.commit()
with pytest.raises(errors.AuthError):
test_ctx.api.delete(
test_ctx.context_factory(
user=test_ctx.user_factory(rank='anonymous')),
'category')
assert db.session().query(db.TagCategory).count() == 1
assert db.session.query(db.TagCategory).count() == 1

View file

@ -24,7 +24,7 @@ def test_ctx(
return ret
def test_retrieving_multiple(test_ctx):
db.session().add_all([
db.session.add_all([
test_ctx.tag_category_factory(name='c1'),
test_ctx.tag_category_factory(name='c2'),
])
@ -34,7 +34,7 @@ def test_retrieving_multiple(test_ctx):
assert [cat['name'] for cat in result['tagCategories']] == ['c1', 'c2']
def test_retrieving_single(test_ctx):
db.session().add(test_ctx.tag_category_factory(name='cat'))
db.session.add(test_ctx.tag_category_factory(name='cat'))
result = test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank='regular_user')),

View file

@ -28,8 +28,8 @@ def test_ctx(
def test_simple_updating(test_ctx):
category = test_ctx.tag_category_factory(name='name', color='black')
db.session().add(category)
db.session().commit()
db.session.add(category)
db.session.commit()
result = test_ctx.api.put(
test_ctx.context_factory(
input={
@ -59,8 +59,8 @@ def test_simple_updating(test_ctx):
{'color': ''},
])
def test_trying_to_pass_invalid_input(test_ctx, input):
db.session().add(test_ctx.tag_category_factory(name='meta', color='black'))
db.session().commit()
db.session.add(test_ctx.tag_category_factory(name='meta', color='black'))
db.session.commit()
with pytest.raises(tag_categories.InvalidTagCategoryNameError):
test_ctx.api.put(
test_ctx.context_factory(
@ -70,8 +70,8 @@ def test_trying_to_pass_invalid_input(test_ctx, input):
@pytest.mark.parametrize('field', ['name', 'color'])
def test_omitting_optional_field(test_ctx, tmpdir, field):
db.session().add(test_ctx.tag_category_factory(name='name', color='black'))
db.session().commit()
db.session.add(test_ctx.tag_category_factory(name='name', color='black'))
db.session.commit()
input = {
'name': 'changed',
'color': 'white',
@ -94,8 +94,8 @@ def test_trying_to_update_non_existing(test_ctx):
@pytest.mark.parametrize('new_name', ['cat', 'CAT'])
def test_reusing_own_name(test_ctx, new_name):
db.session().add(test_ctx.tag_category_factory(name='cat', color='black'))
db.session().commit()
db.session.add(test_ctx.tag_category_factory(name='cat', color='black'))
db.session.commit()
result = test_ctx.api.put(
test_ctx.context_factory(
input={'name': new_name},
@ -107,10 +107,10 @@ def test_reusing_own_name(test_ctx, new_name):
@pytest.mark.parametrize('dup_name', ['cat1', 'CAT1'])
def test_trying_to_use_existing_name(test_ctx, dup_name):
db.session().add_all([
db.session.add_all([
test_ctx.tag_category_factory(name='cat1', color='black'),
test_ctx.tag_category_factory(name='cat2', color='black')])
db.session().commit()
db.session.commit()
with pytest.raises(tag_categories.TagCategoryAlreadyExistsError):
test_ctx.api.put(
test_ctx.context_factory(
@ -123,8 +123,8 @@ def test_trying_to_use_existing_name(test_ctx, dup_name):
{'color': 'whatever'},
])
def test_trying_to_update_without_privileges(test_ctx, input):
db.session().add(test_ctx.tag_category_factory(name='dummy'))
db.session().commit()
db.session.add(test_ctx.tag_category_factory(name='dummy'))
db.session.commit()
with pytest.raises(errors.AuthError):
test_ctx.api.put(
test_ctx.context_factory(

View file

@ -4,8 +4,9 @@ import pytest
from szurubooru import api, config, db, errors
from szurubooru.util import misc, tags
def get_tag(session, name):
return session.query(db.Tag) \
def get_tag(name):
return db.session \
.query(db.Tag) \
.join(db.TagName) \
.filter(db.TagName.name==name) \
.first()
@ -16,24 +17,17 @@ def assert_relations(relations, expected_tag_names):
@pytest.fixture
def test_ctx(
tmpdir,
session,
config_injector,
context_factory,
user_factory,
tag_factory):
tmpdir, config_injector, context_factory, user_factory, tag_factory):
config_injector({
'data_dir': str(tmpdir),
'tag_name_regex': '^[^!]*$',
'ranks': ['anonymous', 'regular_user'],
'privileges': {'tags:create': 'regular_user'},
})
session.add_all([
db.TagCategory(name) for name in [
'meta', 'character', 'copyright']])
session.flush()
db.session.add_all([
db.TagCategory(name) for name in ['meta', 'character', 'copyright']])
db.session.flush()
ret = misc.dotdict()
ret.session = session
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_factory = tag_factory
@ -61,7 +55,7 @@ def test_creating_simple_tags(test_ctx, fake_datetime):
'lastEditTime': None,
}
}
tag = get_tag(test_ctx.session, 'tag1')
tag = get_tag('tag1')
assert [tag_name.name for tag_name in tag.names] == ['tag1', 'tag2']
assert tag.category.name == 'meta'
assert tag.last_edit_time is None
@ -140,15 +134,15 @@ def test_duplicating_names(test_ctx):
user=test_ctx.user_factory(rank='regular_user')))
assert result['tag']['names'] == ['tag1']
assert result['tag']['category'] == 'meta'
tag = get_tag(test_ctx.session, 'tag1')
tag = get_tag('tag1')
assert [tag_name.name for tag_name in tag.names] == ['tag1']
def test_trying_to_use_existing_name(test_ctx):
test_ctx.session.add_all([
db.session.add_all([
test_ctx.tag_factory(names=['used1'], category_name='meta'),
test_ctx.tag_factory(names=['used2'], category_name='meta'),
])
test_ctx.session.commit()
db.session.commit()
with pytest.raises(tags.TagAlreadyExistsError):
test_ctx.api.post(
test_ctx.context_factory(
@ -169,7 +163,7 @@ def test_trying_to_use_existing_name(test_ctx):
'implications': [],
},
user=test_ctx.user_factory(rank='regular_user')))
assert get_tag(test_ctx.session, 'unused') is None
assert get_tag('unused') is None
@pytest.mark.parametrize('input,expected_suggestions,expected_implications', [
# new relations
@ -208,18 +202,18 @@ def test_creating_new_suggestions_and_implications(
input=input, user=test_ctx.user_factory(rank='regular_user')))
assert result['tag']['suggestions'] == expected_suggestions
assert result['tag']['implications'] == expected_implications
tag = get_tag(test_ctx.session, 'main')
tag = get_tag('main')
assert_relations(tag.suggestions, expected_suggestions)
assert_relations(tag.implications, expected_implications)
for name in ['main'] + expected_suggestions + expected_implications:
assert get_tag(test_ctx.session, name) is not None
assert get_tag(name) is not None
def test_reusing_suggestions_and_implications(test_ctx):
test_ctx.session.add_all([
db.session.add_all([
test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta'),
test_ctx.tag_factory(names=['tag3'], category_name='meta'),
])
test_ctx.session.commit()
db.session.commit()
result = test_ctx.api.post(
test_ctx.context_factory(
input={
@ -232,7 +226,7 @@ def test_reusing_suggestions_and_implications(test_ctx):
# NOTE: it should export only the first name
assert result['tag']['suggestions'] == ['tag1']
assert result['tag']['implications'] == ['tag1']
tag = get_tag(test_ctx.session, 'new')
tag = get_tag('new')
assert_relations(tag.suggestions, ['tag1'])
assert_relations(tag.implications, ['tag1'])
@ -256,7 +250,7 @@ def test_tag_trying_to_relate_to_itself(test_ctx, input):
test_ctx.context_factory(
input=input,
user=test_ctx.user_factory(rank='regular_user')))
assert get_tag(test_ctx.session, 'tag') is None
assert get_tag('tag') is None
def test_trying_to_create_tag_without_privileges(test_ctx):
with pytest.raises(errors.AuthError):

View file

@ -6,12 +6,7 @@ from szurubooru.util import misc, tags
@pytest.fixture
def test_ctx(
tmpdir,
session,
config_injector,
context_factory,
tag_factory,
user_factory):
tmpdir, config_injector, context_factory, tag_factory, user_factory):
config_injector({
'data_dir': str(tmpdir),
'privileges': {
@ -20,7 +15,6 @@ def test_ctx(
'ranks': ['anonymous', 'regular_user'],
})
ret = misc.dotdict()
ret.session = session
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_factory = tag_factory
@ -28,28 +22,28 @@ def test_ctx(
return ret
def test_deleting(test_ctx):
test_ctx.session.add(test_ctx.tag_factory(names=['tag']))
test_ctx.session.commit()
db.session.add(test_ctx.tag_factory(names=['tag']))
db.session.commit()
result = test_ctx.api.delete(
test_ctx.context_factory(
user=test_ctx.user_factory(rank='regular_user')),
'tag')
assert result == {}
assert test_ctx.session.query(db.Tag).count() == 0
assert db.session.query(db.Tag).count() == 0
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json'))
def test_trying_to_delete_used(test_ctx, post_factory):
tag = test_ctx.tag_factory(names=['tag'])
post = post_factory()
post.tags.append(tag)
test_ctx.session.add_all([tag, post])
test_ctx.session.commit()
db.session.add_all([tag, post])
db.session.commit()
with pytest.raises(tags.TagIsInUseError):
test_ctx.api.delete(
test_ctx.context_factory(
user=test_ctx.user_factory(rank='regular_user')),
'tag')
assert test_ctx.session.query(db.Tag).count() == 1
assert db.session.query(db.Tag).count() == 1
def test_trying_to_delete_non_existing(test_ctx):
with pytest.raises(tags.TagNotFoundError):
@ -58,11 +52,11 @@ def test_trying_to_delete_non_existing(test_ctx):
user=test_ctx.user_factory(rank='regular_user')), 'bad')
def test_trying_to_delete_without_privileges(test_ctx):
test_ctx.session.add(test_ctx.tag_factory(names=['tag']))
test_ctx.session.commit()
db.session.add(test_ctx.tag_factory(names=['tag']))
db.session.commit()
with pytest.raises(errors.AuthError):
test_ctx.api.delete(
test_ctx.context_factory(
user=test_ctx.user_factory(rank='anonymous')),
'tag')
assert test_ctx.session.query(db.Tag).count() == 1
assert db.session.query(db.Tag).count() == 1

View file

@ -7,7 +7,6 @@ from szurubooru.util import tags
def test_export(
tmpdir,
query_counter,
session,
config_injector,
tag_factory,
tag_category_factory):
@ -16,23 +15,23 @@ def test_export(
})
cat1 = tag_category_factory(name='cat1', color='black')
cat2 = tag_category_factory(name='cat2', color='white')
session.add_all([cat1, cat2])
session.flush()
db.session.add_all([cat1, cat2])
db.session.flush()
sug1 = tag_factory(names=['sug1'], category=cat1)
sug2 = tag_factory(names=['sug2'], category=cat1)
imp1 = tag_factory(names=['imp1'], category=cat1)
imp2 = tag_factory(names=['imp2'], category=cat1)
tag = tag_factory(names=['alias1', 'alias2'], category=cat2)
tag.post_count = 1
session.add_all([tag, sug1, sug2, imp1, imp2, cat1, cat2])
session.flush()
session.add_all([
db.session.add_all([tag, sug1, sug2, imp1, imp2, cat1, cat2])
db.session.flush()
db.session.add_all([
db.TagSuggestion(tag.tag_id, sug1.tag_id),
db.TagSuggestion(tag.tag_id, sug2.tag_id),
db.TagImplication(tag.tag_id, imp1.tag_id),
db.TagImplication(tag.tag_id, imp2.tag_id),
])
session.flush()
db.session.flush()
with query_counter:
tags.export_to_json()

View file

@ -4,8 +4,7 @@ from szurubooru import api, db, errors
from szurubooru.util import misc, tags
@pytest.fixture
def test_ctx(
session, context_factory, config_injector, user_factory, tag_factory):
def test_ctx(context_factory, config_injector, user_factory, tag_factory):
config_injector({
'privileges': {
'tags:list': 'regular_user',
@ -16,7 +15,6 @@ def test_ctx(
'rank_names': {'regular_user': 'Peasant'},
})
ret = misc.dotdict()
ret.session = session
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_factory = tag_factory
@ -27,7 +25,7 @@ def test_ctx(
def test_retrieving_multiple(test_ctx):
tag1 = test_ctx.tag_factory(names=['t1'])
tag2 = test_ctx.tag_factory(names=['t2'])
test_ctx.session.add_all([tag1, tag2])
db.session.add_all([tag1, tag2])
result = test_ctx.list_api.get(
test_ctx.context_factory(
input={'query': '', 'page': 1},
@ -46,7 +44,7 @@ def test_trying_to_retrieve_multiple_without_privileges(test_ctx):
user=test_ctx.user_factory(rank='anonymous')))
def test_retrieving_single(test_ctx):
test_ctx.session.add(test_ctx.tag_factory(names=['tag']))
db.session.add(test_ctx.tag_factory(names=['tag']))
result = test_ctx.detail_api.get(
test_ctx.context_factory(
input={'query': '', 'page': 1},

View file

@ -4,8 +4,9 @@ import pytest
from szurubooru import api, config, db, errors
from szurubooru.util import misc, tags
def get_tag(session, name):
return session.query(db.Tag) \
def get_tag(name):
return db.session \
.query(db.Tag) \
.join(db.TagName) \
.filter(db.TagName.name==name) \
.first()
@ -16,12 +17,7 @@ def assert_relations(relations, expected_tag_names):
@pytest.fixture
def test_ctx(
tmpdir,
session,
config_injector,
context_factory,
user_factory,
tag_factory):
tmpdir, config_injector, context_factory, user_factory, tag_factory):
config_injector({
'data_dir': str(tmpdir),
'tag_name_regex': '^[^!]*$',
@ -33,12 +29,10 @@ def test_ctx(
'tags:edit:implications': 'regular_user',
},
})
session.add_all([
db.TagCategory(name) for name in [
'meta', 'character', 'copyright']])
session.flush()
db.session.add_all([
db.TagCategory(name) for name in ['meta', 'character', 'copyright']])
db.session.flush()
ret = misc.dotdict()
ret.session = session
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_factory = tag_factory
@ -47,8 +41,8 @@ def test_ctx(
def test_simple_updating(test_ctx, fake_datetime):
tag = test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta')
test_ctx.session.add(tag)
test_ctx.session.commit()
db.session.add(tag)
db.session.commit()
with fake_datetime('1997-12-01'):
result = test_ctx.api.put(
test_ctx.context_factory(
@ -68,9 +62,9 @@ def test_simple_updating(test_ctx, fake_datetime):
'lastEditTime': datetime.datetime(1997, 12, 1),
}
}
assert get_tag(test_ctx.session, 'tag1') is None
assert get_tag(test_ctx.session, 'tag2') is None
tag = get_tag(test_ctx.session, 'tag3')
assert get_tag('tag1') is None
assert get_tag('tag2') is None
tag = get_tag('tag3')
assert tag is not None
assert [tag_name.name for tag_name in tag.names] == ['tag3']
assert tag.category.name == 'character'
@ -92,9 +86,8 @@ def test_simple_updating(test_ctx, fake_datetime):
({'implications': ['good', '!bad']}, tags.InvalidTagNameError),
])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
test_ctx.session.add(
test_ctx.tag_factory(names=['tag1'], category_name='meta'))
test_ctx.session.commit()
db.session.add(test_ctx.tag_factory(names=['tag1'], category_name='meta'))
db.session.commit()
with pytest.raises(expected_exception):
test_ctx.api.put(
test_ctx.context_factory(
@ -105,9 +98,8 @@ def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
@pytest.mark.parametrize(
'field', ['names', 'category', 'implications', 'suggestions'])
def test_omitting_optional_field(test_ctx, tmpdir, field):
test_ctx.session.add(
test_ctx.tag_factory(names=['tag'], category_name='meta'))
test_ctx.session.commit()
db.session.add(test_ctx.tag_factory(names=['tag'], category_name='meta'))
db.session.commit()
input = {
'names': ['tag1', 'tag2'],
'category': 'meta',
@ -132,23 +124,23 @@ def test_trying_to_update_non_existing(test_ctx):
@pytest.mark.parametrize('dup_name', ['tag1', 'TAG1'])
def test_reusing_own_name(test_ctx, dup_name):
test_ctx.session.add(
db.session.add(
test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta'))
test_ctx.session.commit()
db.session.commit()
result = test_ctx.api.put(
test_ctx.context_factory(
input={'names': [dup_name, 'tag3']},
user=test_ctx.user_factory(rank='regular_user')),
'tag1')
assert result['tag']['names'] == ['tag1', 'tag3']
assert get_tag(test_ctx.session, 'tag2') is None
tag1 = get_tag(test_ctx.session, 'tag1')
tag2 = get_tag(test_ctx.session, 'tag3')
assert get_tag('tag2') is None
tag1 = get_tag('tag1')
tag2 = get_tag('tag3')
assert tag1.tag_id == tag2.tag_id
assert [name.name for name in tag1.names] == ['tag1', 'tag3']
def test_duplicating_names(test_ctx):
test_ctx.session.add(
db.session.add(
test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta'))
result = test_ctx.api.put(
test_ctx.context_factory(
@ -156,18 +148,18 @@ def test_duplicating_names(test_ctx):
user=test_ctx.user_factory(rank='regular_user')),
'tag1')
assert result['tag']['names'] == ['tag3']
assert get_tag(test_ctx.session, 'tag1') is None
assert get_tag(test_ctx.session, 'tag2') is None
tag = get_tag(test_ctx.session, 'tag3')
assert get_tag('tag1') is None
assert get_tag('tag2') is None
tag = get_tag('tag3')
assert tag is not None
assert [tag_name.name for tag_name in tag.names] == ['tag3']
@pytest.mark.parametrize('dup_name', ['tag1', 'TAG1', 'tag2', 'TAG2'])
def test_trying_to_use_existing_name(test_ctx, dup_name):
test_ctx.session.add_all([
db.session.add_all([
test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta'),
test_ctx.tag_factory(names=['tag3', 'tag4'], category_name='meta')])
test_ctx.session.commit()
db.session.commit()
with pytest.raises(tags.TagAlreadyExistsError):
test_ctx.api.put(
test_ctx.context_factory(
@ -199,28 +191,28 @@ def test_trying_to_use_existing_name(test_ctx, dup_name):
])
def test_updating_new_suggestions_and_implications(
test_ctx, input, expected_suggestions, expected_implications):
test_ctx.session.add(
db.session.add(
test_ctx.tag_factory(names=['main'], category_name='meta'))
test_ctx.session.commit()
db.session.commit()
result = test_ctx.api.put(
test_ctx.context_factory(
input=input, user=test_ctx.user_factory(rank='regular_user')),
'main')
assert result['tag']['suggestions'] == expected_suggestions
assert result['tag']['implications'] == expected_implications
tag = get_tag(test_ctx.session, 'main')
tag = get_tag('main')
assert_relations(tag.suggestions, expected_suggestions)
assert_relations(tag.implications, expected_implications)
for name in ['main'] + expected_suggestions + expected_implications:
assert get_tag(test_ctx.session, name) is not None
assert get_tag(name) is not None
def test_reusing_suggestions_and_implications(test_ctx):
test_ctx.session.add_all([
db.session.add_all([
test_ctx.tag_factory(names=['tag1', 'tag2'], category_name='meta'),
test_ctx.tag_factory(names=['tag3'], category_name='meta'),
test_ctx.tag_factory(names=['tag4'], category_name='meta'),
])
test_ctx.session.commit()
db.session.commit()
result = test_ctx.api.put(
test_ctx.context_factory(
input={
@ -234,7 +226,7 @@ def test_reusing_suggestions_and_implications(test_ctx):
# NOTE: it should export only the first name
assert result['tag']['suggestions'] == ['tag1']
assert result['tag']['implications'] == ['tag1']
tag = get_tag(test_ctx.session, 'new')
tag = get_tag('new')
assert_relations(tag.suggestions, ['tag1'])
assert_relations(tag.implications, ['tag1'])
@ -253,9 +245,8 @@ def test_reusing_suggestions_and_implications(test_ctx):
}
])
def test_trying_to_relate_tag_to_itself(test_ctx, input):
test_ctx.session.add(
test_ctx.tag_factory(names=['tag1'], category_name='meta'))
test_ctx.session.commit()
db.session.add(test_ctx.tag_factory(names=['tag1'], category_name='meta'))
db.session.commit()
with pytest.raises(tags.InvalidTagRelationError):
test_ctx.api.put(
test_ctx.context_factory(
@ -269,9 +260,8 @@ def test_trying_to_relate_tag_to_itself(test_ctx, input):
{'implications': ['whatever']},
])
def test_trying_to_update_without_privileges(test_ctx, input):
test_ctx.session.add(
test_ctx.tag_factory(names=['tag'], category_name='meta'))
test_ctx.session.commit()
db.session.add(test_ctx.tag_factory(names=['tag'], category_name='meta'))
db.session.commit()
with pytest.raises(errors.AuthError):
test_ctx.api.put(
test_ctx.context_factory(

View file

@ -8,12 +8,11 @@ EMPTY_PIXEL = \
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
def get_user(session, name):
return session.query(db.User).filter_by(name=name).first()
def get_user(name):
return db.session.query(db.User).filter_by(name=name).first()
@pytest.fixture
def test_ctx(
session, config_injector, context_factory, user_factory):
def test_ctx(config_injector, context_factory, user_factory):
config_injector({
'secret': '',
'user_name_regex': '[^!]{3,}',
@ -25,7 +24,6 @@ def test_ctx(
'privileges': {'users:create': 'anonymous'},
})
ret = misc.dotdict()
ret.session = session
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.api = api.UserListApi()
@ -53,7 +51,7 @@ def test_creating_user(test_ctx, fake_datetime):
'rankName': 'Unknown',
}
}
user = get_user(test_ctx.session, 'chewie1')
user = get_user('chewie1')
assert user.name == 'chewie1'
assert user.email == 'asd@asd.asd'
assert user.rank == 'admin'
@ -79,8 +77,8 @@ def test_first_user_becomes_admin_others_not(test_ctx):
user=test_ctx.user_factory(rank='anonymous')))
assert result1['user']['rank'] == 'admin'
assert result2['user']['rank'] == 'regular_user'
first_user = get_user(test_ctx.session, 'chewie1')
other_user = get_user(test_ctx.session, 'chewie2')
first_user = get_user('chewie1')
other_user = get_user('chewie2')
assert first_user.rank == 'admin'
assert other_user.rank == 'regular_user'
@ -192,7 +190,7 @@ def test_omitting_optional_field(test_ctx, tmpdir, field):
def test_mods_trying_to_become_admin(test_ctx):
user1 = test_ctx.user_factory(name='u1', rank='mod')
user2 = test_ctx.user_factory(name='u2', rank='mod')
test_ctx.session.add_all([user1, user2])
db.session.add_all([user1, user2])
context = test_ctx.context_factory(input={
'name': 'chewie',
'email': 'asd@asd.asd',
@ -204,7 +202,7 @@ def test_mods_trying_to_become_admin(test_ctx):
def test_admin_creating_mod_account(test_ctx):
user = test_ctx.user_factory(rank='admin')
test_ctx.session.add(user)
db.session.add(user)
context = test_ctx.context_factory(input={
'name': 'chewie',
'email': 'asd@asd.asd',
@ -227,7 +225,7 @@ def test_uploading_avatar(test_ctx, tmpdir):
},
files={'avatar': EMPTY_PIXEL},
user=test_ctx.user_factory(rank='mod')))
user = get_user(test_ctx.session, 'chewie')
user = get_user('chewie')
assert user.avatar_style == user.AVATAR_MANUAL
assert response['user']['avatarUrl'] == \
'http://example.com/data/avatars/chewie.jpg'

View file

@ -4,7 +4,7 @@ from szurubooru import api, db, errors
from szurubooru.util import misc, users
@pytest.fixture
def test_ctx(session, config_injector, context_factory, user_factory):
def test_ctx(config_injector, context_factory, user_factory):
config_injector({
'privileges': {
'users:delete:self': 'regular_user',
@ -13,7 +13,6 @@ def test_ctx(session, config_injector, context_factory, user_factory):
'ranks': ['anonymous', 'regular_user', 'mod', 'admin'],
})
ret = misc.dotdict()
ret.session = session
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.api = api.UserDetailApi()
@ -21,28 +20,28 @@ def test_ctx(session, config_injector, context_factory, user_factory):
def test_deleting_oneself(test_ctx):
user = test_ctx.user_factory(name='u', rank='regular_user')
test_ctx.session.add(user)
test_ctx.session.commit()
db.session.add(user)
db.session.commit()
result = test_ctx.api.delete(test_ctx.context_factory(user=user), 'u')
assert result == {}
assert test_ctx.session.query(db.User).count() == 0
assert db.session.query(db.User).count() == 0
def test_deleting_someone_else(test_ctx):
user1 = test_ctx.user_factory(name='u1', rank='regular_user')
user2 = test_ctx.user_factory(name='u2', rank='mod')
test_ctx.session.add_all([user1, user2])
test_ctx.session.commit()
db.session.add_all([user1, user2])
db.session.commit()
test_ctx.api.delete(test_ctx.context_factory(user=user2), 'u1')
assert test_ctx.session.query(db.User).count() == 1
assert db.session.query(db.User).count() == 1
def test_trying_to_delete_someone_else_without_privileges(test_ctx):
user1 = test_ctx.user_factory(name='u1', rank='regular_user')
user2 = test_ctx.user_factory(name='u2', rank='regular_user')
test_ctx.session.add_all([user1, user2])
test_ctx.session.commit()
db.session.add_all([user1, user2])
db.session.commit()
with pytest.raises(errors.AuthError):
test_ctx.api.delete(test_ctx.context_factory(user=user2), 'u1')
assert test_ctx.session.query(db.User).count() == 2
assert db.session.query(db.User).count() == 2
def test_trying_to_delete_non_existing(test_ctx):
with pytest.raises(users.UserNotFoundError):

View file

@ -4,7 +4,7 @@ from szurubooru import api, db, errors
from szurubooru.util import misc, users
@pytest.fixture
def test_ctx(session, context_factory, config_injector, user_factory):
def test_ctx(context_factory, config_injector, user_factory):
config_injector({
'privileges': {
'users:list': 'regular_user',
@ -15,7 +15,6 @@ def test_ctx(session, context_factory, config_injector, user_factory):
'rank_names': {'regular_user': 'Peasant'},
})
ret = misc.dotdict()
ret.session = session
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.list_api = api.UserListApi()
@ -25,7 +24,7 @@ def test_ctx(session, context_factory, config_injector, user_factory):
def test_retrieving_multiple(test_ctx):
user1 = test_ctx.user_factory(name='u1', rank='mod')
user2 = test_ctx.user_factory(name='u2', rank='mod')
test_ctx.session.add_all([user1, user2])
db.session.add_all([user1, user2])
result = test_ctx.list_api.get(
test_ctx.context_factory(
input={'query': '', 'page': 1},
@ -44,7 +43,7 @@ def test_trying_to_retrieve_multiple_without_privileges(test_ctx):
user=test_ctx.user_factory(rank='anonymous')))
def test_retrieving_single(test_ctx):
test_ctx.session.add(test_ctx.user_factory(name='u1', rank='regular_user'))
db.session.add(test_ctx.user_factory(name='u1', rank='regular_user'))
result = test_ctx.detail_api.get(
test_ctx.context_factory(
input={'query': '', 'page': 1},

View file

@ -8,12 +8,11 @@ EMPTY_PIXEL = \
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
def get_user(session, name):
return session.query(db.User).filter_by(name=name).first()
def get_user(name):
return db.session.query(db.User).filter_by(name=name).first()
@pytest.fixture
def test_ctx(
session, config_injector, context_factory, user_factory):
def test_ctx(config_injector, context_factory, user_factory):
config_injector({
'secret': '',
'user_name_regex': '^[^!]{3,}$',
@ -35,7 +34,6 @@ def test_ctx(
},
})
ret = misc.dotdict()
ret.session = session
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.api = api.UserDetailApi()
@ -43,7 +41,7 @@ def test_ctx(
def test_updating_user(test_ctx):
user = test_ctx.user_factory(name='u1', rank='admin')
test_ctx.session.add(user)
db.session.add(user)
result = test_ctx.api.put(
test_ctx.context_factory(
input={
@ -68,7 +66,7 @@ def test_updating_user(test_ctx):
'rankName': 'Unknown',
}
}
user = get_user(test_ctx.session, 'chewie')
user = get_user('chewie')
assert user.name == 'chewie'
assert user.email == 'asd@asd.asd'
assert user.rank == 'mod'
@ -96,7 +94,7 @@ def test_updating_user(test_ctx):
])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
user = test_ctx.user_factory(name='u1', rank='admin')
test_ctx.session.add(user)
db.session.add(user)
with pytest.raises(expected_exception):
test_ctx.api.put(
test_ctx.context_factory(input=input, user=user), 'u1')
@ -107,7 +105,7 @@ def test_omitting_optional_field(test_ctx, tmpdir, field):
config.config['data_dir'] = str(tmpdir.mkdir('data'))
config.config['data_url'] = 'http://example.com/data/'
user = test_ctx.user_factory(name='u1', rank='admin')
test_ctx.session.add(user)
db.session.add(user)
input = {
'name': 'chewie',
'email': 'asd@asd.asd',
@ -126,16 +124,16 @@ def test_omitting_optional_field(test_ctx, tmpdir, field):
def test_trying_to_update_non_existing(test_ctx):
user = test_ctx.user_factory(name='u1', rank='admin')
test_ctx.session.add(user)
db.session.add(user)
with pytest.raises(users.UserNotFoundError):
test_ctx.api.put(test_ctx.context_factory(user=user), 'u2')
def test_removing_email(test_ctx):
user = test_ctx.user_factory(name='u1', rank='admin')
test_ctx.session.add(user)
db.session.add(user)
test_ctx.api.put(
test_ctx.context_factory(input={'email': ''}, user=user), 'u1')
assert get_user(test_ctx.session, 'u1').email is None
assert get_user('u1').email is None
@pytest.mark.parametrize('input', [
{'name': 'whatever'},
@ -147,7 +145,7 @@ def test_removing_email(test_ctx):
def test_trying_to_update_someone_else(test_ctx, input):
user1 = test_ctx.user_factory(name='u1', rank='regular_user')
user2 = test_ctx.user_factory(name='u2', rank='regular_user')
test_ctx.session.add_all([user1, user2])
db.session.add_all([user1, user2])
with pytest.raises(errors.AuthError):
test_ctx.api.put(
test_ctx.context_factory(input=input, user=user1), user2.name)
@ -155,7 +153,7 @@ def test_trying_to_update_someone_else(test_ctx, input):
def test_trying_to_become_someone_else(test_ctx):
user1 = test_ctx.user_factory(name='me', rank='regular_user')
user2 = test_ctx.user_factory(name='her', rank='regular_user')
test_ctx.session.add_all([user1, user2])
db.session.add_all([user1, user2])
with pytest.raises(users.UserAlreadyExistsError):
test_ctx.api.put(
test_ctx.context_factory(input={'name': 'her'}, user=user1),
@ -167,7 +165,7 @@ def test_trying_to_become_someone_else(test_ctx):
def test_mods_trying_to_become_admin(test_ctx):
user1 = test_ctx.user_factory(name='u1', rank='mod')
user2 = test_ctx.user_factory(name='u2', rank='mod')
test_ctx.session.add_all([user1, user2])
db.session.add_all([user1, user2])
context = test_ctx.context_factory(input={'rank': 'admin'}, user=user1)
with pytest.raises(errors.AuthError):
test_ctx.api.put(context, user1.name)
@ -178,14 +176,14 @@ def test_uploading_avatar(test_ctx, tmpdir):
config.config['data_dir'] = str(tmpdir.mkdir('data'))
config.config['data_url'] = 'http://example.com/data/'
user = test_ctx.user_factory(name='u1', rank='mod')
test_ctx.session.add(user)
db.session.add(user)
response = test_ctx.api.put(
test_ctx.context_factory(
input={'avatarStyle': 'manual'},
files={'avatar': EMPTY_PIXEL},
user=user),
'u1')
user = get_user(test_ctx.session, 'u1')
user = get_user('u1')
assert user.avatar_style == user.AVATAR_MANUAL
assert response['user']['avatarUrl'] == \
'http://example.com/data/avatars/u1.jpg'

View file

@ -46,8 +46,8 @@ def fake_datetime():
freezer.stop()
return injector
@pytest.yield_fixture
def session(query_counter, autoload=True):
@pytest.yield_fixture(autouse=True)
def session(query_counter):
import logging
logging.basicConfig()
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)

View file

@ -1,7 +1,7 @@
from datetime import datetime
from szurubooru import db
def test_saving_post(session, post_factory, user_factory, tag_factory):
def test_saving_post(post_factory, user_factory, tag_factory):
user = user_factory()
tag1 = tag_factory()
tag2 = tag_factory()
@ -13,17 +13,17 @@ def test_saving_post(session, post_factory, user_factory, tag_factory):
post.checksum = 'deadbeef'
post.creation_time = datetime(1997, 1, 1)
post.last_edit_time = datetime(1998, 1, 1)
session.add_all([user, tag1, tag2, related_post1, related_post2, post])
db.session.add_all([user, tag1, tag2, related_post1, related_post2, post])
post.user = user
post.tags.append(tag1)
post.tags.append(tag2)
post.relations.append(related_post1)
post.relations.append(related_post2)
session.commit()
db.session.commit()
post = session.query(db.Post).filter(db.Post.post_id == post.post_id).one()
assert not session.dirty
db.session.refresh(post)
assert not db.session.dirty
assert post.user.user_id is not None
assert post.safety == 'safety'
assert post.type == 'type'
@ -32,60 +32,60 @@ def test_saving_post(session, post_factory, user_factory, tag_factory):
assert post.last_edit_time == datetime(1998, 1, 1)
assert len(post.relations) == 2
def test_cascade_deletions(session, post_factory, user_factory, tag_factory):
def test_cascade_deletions(post_factory, user_factory, tag_factory):
user = user_factory()
tag1 = tag_factory()
tag2 = tag_factory()
related_post1 = post_factory()
related_post2 = post_factory()
post = post_factory()
session.add_all([user, tag1, tag2, post, related_post1, related_post2])
session.flush()
db.session.add_all([user, tag1, tag2, post, related_post1, related_post2])
db.session.flush()
post.user = user
post.tags.append(tag1)
post.tags.append(tag2)
post.relations.append(related_post1)
post.relations.append(related_post2)
session.flush()
db.session.flush()
assert not session.dirty
assert not db.session.dirty
assert post.user.user_id is not None
assert len(post.relations) == 2
assert session.query(db.User).count() == 1
assert session.query(db.Tag).count() == 2
assert session.query(db.Post).count() == 3
assert session.query(db.PostTag).count() == 2
assert session.query(db.PostRelation).count() == 2
assert db.session.query(db.User).count() == 1
assert db.session.query(db.Tag).count() == 2
assert db.session.query(db.Post).count() == 3
assert db.session.query(db.PostTag).count() == 2
assert db.session.query(db.PostRelation).count() == 2
session.delete(post)
session.commit()
db.session.delete(post)
db.session.commit()
assert not session.dirty
assert session.query(db.User).count() == 1
assert session.query(db.Tag).count() == 2
assert session.query(db.Post).count() == 2
assert session.query(db.PostTag).count() == 0
assert session.query(db.PostRelation).count() == 0
assert not db.session.dirty
assert db.session.query(db.User).count() == 1
assert db.session.query(db.Tag).count() == 2
assert db.session.query(db.Post).count() == 2
assert db.session.query(db.PostTag).count() == 0
assert db.session.query(db.PostRelation).count() == 0
def test_tracking_tag_count(session, post_factory, tag_factory):
def test_tracking_tag_count(post_factory, tag_factory):
post = post_factory()
tag1 = tag_factory()
tag2 = tag_factory()
session.add_all([tag1, tag2, post])
session.flush()
db.session.add_all([tag1, tag2, post])
db.session.flush()
post.tags.append(tag1)
post.tags.append(tag2)
session.commit()
db.session.commit()
assert len(post.tags) == 2
assert post.tag_count == 2
session.delete(tag1)
session.commit()
session.refresh(post)
db.session.delete(tag1)
db.session.commit()
db.session.refresh(post)
assert len(post.tags) == 1
assert post.tag_count == 1
session.delete(tag2)
session.commit()
session.refresh(post)
db.session.delete(tag2)
db.session.commit()
db.session.refresh(post)
assert len(post.tags) == 0
assert post.tag_count == 0

View file

@ -1,7 +1,7 @@
from datetime import datetime
from szurubooru import db
def test_saving_tag(session, tag_factory):
def test_saving_tag(tag_factory):
sug1 = tag_factory(names=['sug1'])
sug2 = tag_factory(names=['sug2'])
imp1 = tag_factory(names=['imp1'])
@ -13,8 +13,8 @@ def test_saving_tag(session, tag_factory):
tag.category = db.TagCategory('category')
tag.creation_time = datetime(1997, 1, 1)
tag.last_edit_time = datetime(1998, 1, 1)
session.add_all([tag, sug1, sug2, imp1, imp2])
session.commit()
db.session.add_all([tag, sug1, sug2, imp1, imp2])
db.session.commit()
assert tag.tag_id is not None
assert sug1.tag_id is not None
@ -25,9 +25,10 @@ def test_saving_tag(session, tag_factory):
tag.suggestions.append(sug2)
tag.implications.append(imp1)
tag.implications.append(imp2)
session.commit()
db.session.commit()
tag = session.query(db.Tag) \
tag = db.session \
.query(db.Tag) \
.join(db.TagName) \
.filter(db.TagName.name=='alias1') \
.one()
@ -40,7 +41,7 @@ def test_saving_tag(session, tag_factory):
assert [relation.names[0].name for relation in tag.implications] \
== ['imp1', 'imp2']
def test_cascade_deletions(session, tag_factory):
def test_cascade_deletions(tag_factory):
sug1 = tag_factory(names=['sug1'])
sug2 = tag_factory(names=['sug2'])
imp1 = tag_factory(names=['imp1'])
@ -53,8 +54,8 @@ def test_cascade_deletions(session, tag_factory):
tag.creation_time = datetime(1997, 1, 1)
tag.last_edit_time = datetime(1998, 1, 1)
tag.post_count = 1
session.add_all([tag, sug1, sug2, imp1, imp2])
session.commit()
db.session.add_all([tag, sug1, sug2, imp1, imp2])
db.session.commit()
assert tag.tag_id is not None
assert sug1.tag_id is not None
@ -65,32 +66,32 @@ def test_cascade_deletions(session, tag_factory):
tag.suggestions.append(sug2)
tag.implications.append(imp1)
tag.implications.append(imp2)
session.commit()
db.session.commit()
session.delete(tag)
session.commit()
assert session.query(db.Tag).count() == 4
assert session.query(db.TagName).count() == 4
assert session.query(db.TagImplication).count() == 0
assert session.query(db.TagSuggestion).count() == 0
db.session.delete(tag)
db.session.commit()
assert db.session.query(db.Tag).count() == 4
assert db.session.query(db.TagName).count() == 4
assert db.session.query(db.TagImplication).count() == 0
assert db.session.query(db.TagSuggestion).count() == 0
def test_tracking_post_count(session, post_factory, tag_factory):
def test_tracking_post_count(post_factory, tag_factory):
tag = tag_factory()
post1 = post_factory()
post2 = post_factory()
session.add_all([tag, post1, post2])
session.flush()
db.session.add_all([tag, post1, post2])
db.session.flush()
post1.tags.append(tag)
post2.tags.append(tag)
session.commit()
db.session.commit()
assert len(post1.tags) == 1
assert len(post2.tags) == 1
assert tag.post_count == 2
session.delete(post1)
session.commit()
session.refresh(tag)
db.session.delete(post1)
db.session.commit()
db.session.refresh(tag)
assert tag.post_count == 1
session.delete(post2)
session.commit()
session.refresh(tag)
db.session.delete(post2)
db.session.commit()
db.session.refresh(tag)
assert tag.post_count == 0

View file

@ -1,7 +1,7 @@
from datetime import datetime
from szurubooru import db
def test_saving_user(session):
def test_saving_user():
user = db.User()
user.name = 'name'
user.password_salt = 'salt'
@ -10,8 +10,10 @@ def test_saving_user(session):
user.rank = 'rank'
user.creation_time = datetime(1997, 1, 1)
user.avatar_style = db.User.AVATAR_GRAVATAR
session.add(user)
user = session.query(db.User).one()
db.session.add(user)
db.session.flush()
db.session.refresh(user)
assert not db.session.dirty
assert user.name == 'name'
assert user.password_salt == 'salt'
assert user.password_hash == 'hash'

View file

@ -40,14 +40,14 @@ def verify_unpaged(executor):
('-creation-date:2014-01,2015', ['t2']),
])
def test_filter_by_creation_time(
verify_unpaged, session, tag_factory, input, expected_tag_names):
verify_unpaged, tag_factory, input, expected_tag_names):
tag1 = tag_factory(names=['t1'])
tag2 = tag_factory(names=['t2'])
tag3 = tag_factory(names=['t3'])
tag1.creation_time = datetime.datetime(2014, 1, 1)
tag2.creation_time = datetime.datetime(2014, 6, 1)
tag3.creation_time = datetime.datetime(2015, 1, 1)
session.add_all([tag1, tag2, tag3])
db.session.add_all([tag1, tag2, tag3])
verify_unpaged(input, expected_tag_names)
@pytest.mark.parametrize('input,expected_tag_names', [
@ -70,12 +70,11 @@ def test_filter_by_creation_time(
('name:tag4', ['tag4']),
('name:tag4,tag5', ['tag4']),
])
def test_filter_by_name(
session, verify_unpaged, tag_factory, input, expected_tag_names):
session.add(tag_factory(names=['tag1']))
session.add(tag_factory(names=['tag2']))
session.add(tag_factory(names=['tag3']))
session.add(tag_factory(names=['tag4', 'tag5', 'tag6']))
def test_filter_by_name(verify_unpaged, tag_factory, input, expected_tag_names):
db.session.add(tag_factory(names=['tag1']))
db.session.add(tag_factory(names=['tag2']))
db.session.add(tag_factory(names=['tag3']))
db.session.add(tag_factory(names=['tag4', 'tag5', 'tag6']))
verify_unpaged(input, expected_tag_names)
@pytest.mark.parametrize('input,expected_tag_names', [
@ -84,10 +83,9 @@ def test_filter_by_name(
('t2', ['t2']),
('t1,t2', ['t1', 't2']),
])
def test_anonymous(
session, verify_unpaged, tag_factory, input, expected_tag_names):
session.add(tag_factory(names=['t1']))
session.add(tag_factory(names=['t2']))
def test_anonymous(verify_unpaged, tag_factory, input, expected_tag_names):
db.session.add(tag_factory(names=['t1']))
db.session.add(tag_factory(names=['t2']))
verify_unpaged(input, expected_tag_names)
@pytest.mark.parametrize('input,expected_tag_names', [
@ -99,10 +97,9 @@ def test_anonymous(
('-order:name,asc', ['t2', 't1']),
('-order:name,desc', ['t1', 't2']),
])
def test_order_by_name(
session, verify_unpaged, tag_factory, input, expected_tag_names):
session.add(tag_factory(names=['t2']))
session.add(tag_factory(names=['t1']))
def test_order_by_name(verify_unpaged, tag_factory, input, expected_tag_names):
db.session.add(tag_factory(names=['t2']))
db.session.add(tag_factory(names=['t1']))
verify_unpaged(input, expected_tag_names)
@pytest.mark.parametrize('input,expected_user_names', [
@ -111,28 +108,28 @@ def test_order_by_name(
('order:creation-time', ['t3', 't2', 't1']),
])
def test_order_by_creation_time(
session, verify_unpaged, tag_factory, input, expected_user_names):
verify_unpaged, tag_factory, input, expected_user_names):
tag1 = tag_factory(names=['t1'])
tag2 = tag_factory(names=['t2'])
tag3 = tag_factory(names=['t3'])
tag1.creation_time = datetime.datetime(1991, 1, 1)
tag2.creation_time = datetime.datetime(1991, 1, 2)
tag3.creation_time = datetime.datetime(1991, 1, 3)
session.add_all([tag3, tag1, tag2])
db.session.add_all([tag3, tag1, tag2])
verify_unpaged(input, expected_user_names)
@pytest.mark.parametrize('input,expected_tag_names', [
('order:suggestion-count', ['t1', 't2', 'sug1', 'sug2', 'sug3']),
])
def test_order_by_suggestion_count(
session, verify_unpaged, tag_factory, input, expected_tag_names):
verify_unpaged, tag_factory, input, expected_tag_names):
sug1 = tag_factory(names=['sug1'])
sug2 = tag_factory(names=['sug2'])
sug3 = tag_factory(names=['sug3'])
tag1 = tag_factory(names=['t1'])
tag2 = tag_factory(names=['t2'])
session.add_all([sug1, sug3, tag2, sug2, tag1])
session.commit()
db.session.add_all([sug1, sug3, tag2, sug2, tag1])
db.session.commit()
tag1.suggestions.append(sug1)
tag1.suggestions.append(sug2)
tag2.suggestions.append(sug3)
@ -142,32 +139,32 @@ def test_order_by_suggestion_count(
('order:implication-count', ['t1', 't2', 'sug1', 'sug2', 'sug3']),
])
def test_order_by_implication_count(
session, verify_unpaged, tag_factory, input, expected_tag_names):
verify_unpaged, tag_factory, input, expected_tag_names):
sug1 = tag_factory(names=['sug1'])
sug2 = tag_factory(names=['sug2'])
sug3 = tag_factory(names=['sug3'])
tag1 = tag_factory(names=['t1'])
tag2 = tag_factory(names=['t2'])
session.add_all([sug1, sug3, tag2, sug2, tag1])
session.commit()
db.session.add_all([sug1, sug3, tag2, sug2, tag1])
db.session.commit()
tag1.implications.append(sug1)
tag1.implications.append(sug2)
tag2.implications.append(sug3)
verify_unpaged(input, expected_tag_names)
def test_filter_by_relation_count(session, verify_unpaged, tag_factory):
def test_filter_by_relation_count(verify_unpaged, tag_factory):
sug1 = tag_factory(names=['sug1'])
sug2 = tag_factory(names=['sug2'])
imp1 = tag_factory(names=['imp1'])
tag1 = tag_factory(names=['t1'])
tag2 = tag_factory(names=['t2'])
session.add_all([sug1, tag1, sug2, imp1, tag2])
session.commit()
session.add_all([
db.session.add_all([sug1, tag1, sug2, imp1, tag2])
db.session.commit()
db.session.add_all([
db.TagSuggestion(tag1.tag_id, sug1.tag_id),
db.TagSuggestion(tag1.tag_id, sug2.tag_id),
db.TagImplication(tag2.tag_id, imp1.tag_id)])
session.commit()
db.session.commit()
verify_unpaged('suggestion-count:0', ['imp1', 'sug1', 'sug2', 't2'])
verify_unpaged('suggestion-count:1', [])
verify_unpaged('suggestion-count:2', ['t1'])

View file

@ -42,14 +42,14 @@ def verify_unpaged(executor):
('-creation-date:2014-01,2015', ['u2']),
])
def test_filter_by_creation_time(
verify_unpaged, session, input, expected_user_names, user_factory):
verify_unpaged, input, expected_user_names, user_factory):
user1 = user_factory(name='u1')
user2 = user_factory(name='u2')
user3 = user_factory(name='u3')
user1.creation_time = datetime.datetime(2014, 1, 1)
user2.creation_time = datetime.datetime(2014, 6, 1)
user3.creation_time = datetime.datetime(2015, 1, 1)
session.add_all([user1, user2, user3])
db.session.add_all([user1, user2, user3])
verify_unpaged(input, expected_user_names)
@pytest.mark.parametrize('input,expected_user_names', [
@ -71,10 +71,10 @@ def test_filter_by_creation_time(
('-name:user1,user3', ['user2']),
])
def test_filter_by_name(
session, verify_unpaged, input, expected_user_names, user_factory):
session.add(user_factory(name='user1'))
session.add(user_factory(name='user2'))
session.add(user_factory(name='user3'))
verify_unpaged, input, expected_user_names, user_factory):
db.session.add(user_factory(name='user1'))
db.session.add(user_factory(name='user2'))
db.session.add(user_factory(name='user3'))
verify_unpaged(input, expected_user_names)
@pytest.mark.parametrize('input,expected_user_names', [
@ -84,9 +84,9 @@ def test_filter_by_name(
('u1,u2', ['u1', 'u2']),
])
def test_anonymous(
session, verify_unpaged, input, expected_user_names, user_factory):
session.add(user_factory(name='u1'))
session.add(user_factory(name='u2'))
verify_unpaged, input, expected_user_names, user_factory):
db.session.add(user_factory(name='u1'))
db.session.add(user_factory(name='u2'))
verify_unpaged(input, expected_user_names)
@pytest.mark.parametrize('input,expected_user_names', [
@ -95,14 +95,14 @@ def test_anonymous(
('creation-time:2016 u2', []),
])
def test_combining_tokens(
session, verify_unpaged, input, expected_user_names, user_factory):
verify_unpaged, input, expected_user_names, user_factory):
user1 = user_factory(name='u1')
user2 = user_factory(name='u2')
user3 = user_factory(name='u3')
user1.creation_time = datetime.datetime(2014, 1, 1)
user2.creation_time = datetime.datetime(2014, 6, 1)
user3.creation_time = datetime.datetime(2015, 1, 1)
session.add_all([user1, user2, user3])
db.session.add_all([user1, user2, user3])
verify_unpaged(input, expected_user_names)
@pytest.mark.parametrize(
@ -114,10 +114,10 @@ def test_combining_tokens(
(0, 0, 2, []),
])
def test_paging(
session, executor, user_factory, page, page_size,
executor, user_factory, page, page_size,
expected_total_count, expected_user_names):
session.add(user_factory(name='u1'))
session.add(user_factory(name='u2'))
db.session.add(user_factory(name='u1'))
db.session.add(user_factory(name='u2'))
actual_count, actual_users = executor.execute(
'', page=page, page_size=page_size)
actual_user_names = [u.name for u in actual_users]
@ -134,9 +134,9 @@ def test_paging(
('-order:name,desc', ['u1', 'u2']),
])
def test_order_by_name(
session, verify_unpaged, input, expected_user_names, user_factory):
session.add(user_factory(name='u2'))
session.add(user_factory(name='u1'))
verify_unpaged, input, expected_user_names, user_factory):
db.session.add(user_factory(name='u2'))
db.session.add(user_factory(name='u1'))
verify_unpaged(input, expected_user_names)
@pytest.mark.parametrize('input,expected_user_names', [
@ -150,14 +150,14 @@ def test_order_by_name(
('-order:creation-date,desc', ['u1', 'u2', 'u3']),
])
def test_order_by_creation_time(
session, verify_unpaged, input, expected_user_names, user_factory):
verify_unpaged, input, expected_user_names, user_factory):
user1 = user_factory(name='u1')
user2 = user_factory(name='u2')
user3 = user_factory(name='u3')
user1.creation_time = datetime.datetime(1991, 1, 1)
user2.creation_time = datetime.datetime(1991, 1, 2)
user3.creation_time = datetime.datetime(1991, 1, 3)
session.add_all([user3, user1, user2])
db.session.add_all([user3, user1, user2])
verify_unpaged(input, expected_user_names)
@pytest.mark.parametrize('input,expected_user_names', [
@ -168,21 +168,21 @@ def test_order_by_creation_time(
('order:login-time', ['u3', 'u2', 'u1']),
])
def test_order_by_last_login_time(
session, verify_unpaged, input, expected_user_names, user_factory):
verify_unpaged, input, expected_user_names, user_factory):
user1 = user_factory(name='u1')
user2 = user_factory(name='u2')
user3 = user_factory(name='u3')
user1.last_login_time = datetime.datetime(1991, 1, 1)
user2.last_login_time = datetime.datetime(1991, 1, 2)
user3.last_login_time = datetime.datetime(1991, 1, 3)
session.add_all([user3, user1, user2])
db.session.add_all([user3, user1, user2])
verify_unpaged(input, expected_user_names)
def test_random_order(session, executor, user_factory):
def test_random_order(executor, user_factory):
user1 = user_factory(name='u1')
user2 = user_factory(name='u2')
user3 = user_factory(name='u3')
session.add_all([user3, user1, user2])
db.session.add_all([user3, user1, user2])
actual_count, actual_users = executor.execute(
'order:random', page=1, page_size=100)
actual_user_names = [u.name for u in actual_users]

View file

@ -3,7 +3,7 @@ import pytest
from szurubooru import db
from szurubooru.util import snapshots
def test_serializing_tag(session, tag_factory):
def test_serializing_tag(tag_factory):
tag = tag_factory(names=['main_name', 'alias'], category_name='dummy')
assert snapshots.get_tag_snapshot(tag) == {
'names': ['main_name', 'alias'],
@ -15,10 +15,10 @@ def test_serializing_tag(session, tag_factory):
imp2 = tag_factory(names=['imp2_main_name', 'imp2_alias'])
sug1 = tag_factory(names=['sug1_main_name', 'sug1_alias'])
sug2 = tag_factory(names=['sug2_main_name', 'sug2_alias'])
session.add_all([imp1, imp2, sug1, sug2])
db.session.add_all([imp1, imp2, sug1, sug2])
tag.implications = [imp1, imp2]
tag.suggestions = [sug1, sug2]
session.flush()
db.session.flush()
assert snapshots.get_tag_snapshot(tag) == {
'names': ['main_name', 'alias'],
'category': 'dummy',
@ -26,39 +26,38 @@ def test_serializing_tag(session, tag_factory):
'suggestions': ['sug1_main_name', 'sug2_main_name'],
}
def test_merging_modification_to_creation(session, tag_factory, user_factory):
def test_merging_modification_to_creation(tag_factory, user_factory):
tag = tag_factory(names=['dummy'])
user = user_factory()
session.add_all([tag, user])
session.flush()
db.session.add_all([tag, user])
db.session.flush()
snapshots.create(tag, user)
session.flush()
db.session.flush()
tag.names = [db.TagName('changed')]
snapshots.modify(tag, user)
session.flush()
results = session.query(db.Snapshot).all()
db.session.flush()
results = db.session.query(db.Snapshot).all()
assert len(results) == 1
assert results[0].operation == db.Snapshot.OPERATION_CREATED
assert results[0].data['names'] == ['changed']
def test_merging_modifications(
fake_datetime, session, tag_factory, user_factory):
def test_merging_modifications(fake_datetime, tag_factory, user_factory):
tag = tag_factory(names=['dummy'])
user = user_factory()
session.add_all([tag, user])
session.flush()
db.session.add_all([tag, user])
db.session.flush()
with fake_datetime('13:00:00'):
snapshots.create(tag, user)
session.flush()
db.session.flush()
tag.names = [db.TagName('changed')]
with fake_datetime('14:00:00'):
snapshots.modify(tag, user)
session.flush()
db.session.flush()
tag.names = [db.TagName('changed again')]
with fake_datetime('14:00:01'):
snapshots.modify(tag, user)
session.flush()
results = session.query(db.Snapshot).all()
db.session.flush()
results = db.session.query(db.Snapshot).all()
assert len(results) == 2
assert results[0].operation == db.Snapshot.OPERATION_CREATED
assert results[1].operation == db.Snapshot.OPERATION_MODIFIED
@ -66,93 +65,94 @@ def test_merging_modifications(
assert results[1].data['names'] == ['changed again']
def test_not_adding_snapshot_if_data_doesnt_change(
fake_datetime, session, tag_factory, user_factory):
fake_datetime, tag_factory, user_factory):
tag = tag_factory(names=['dummy'])
user = user_factory()
session.add_all([tag, user])
session.flush()
db.session.add_all([tag, user])
db.session.flush()
with fake_datetime('13:00:00'):
snapshots.create(tag, user)
session.flush()
db.session.flush()
with fake_datetime('14:00:00'):
snapshots.modify(tag, user)
session.flush()
results = session.query(db.Snapshot).all()
db.session.flush()
results = db.session.query(db.Snapshot).all()
assert len(results) == 1
assert results[0].operation == db.Snapshot.OPERATION_CREATED
assert results[0].data['names'] == ['dummy']
def test_not_merging_due_to_time_difference(
fake_datetime, session, tag_factory, user_factory):
fake_datetime, tag_factory, user_factory):
tag = tag_factory(names=['dummy'])
user = user_factory()
session.add_all([tag, user])
session.flush()
db.session.add_all([tag, user])
db.session.flush()
with fake_datetime('13:00:00'):
snapshots.create(tag, user)
session.flush()
db.session.flush()
tag.names = [db.TagName('changed')]
with fake_datetime('13:10:01'):
snapshots.modify(tag, user)
session.flush()
assert session.query(db.Snapshot).count() == 2
db.session.flush()
assert db.session.query(db.Snapshot).count() == 2
def test_not_merging_operations_by_different_users(
fake_datetime, session, tag_factory, user_factory):
fake_datetime, tag_factory, user_factory):
tag = tag_factory(names=['dummy'])
user1, user2 = [user_factory(), user_factory()]
session.add_all([tag, user1, user2])
session.flush()
db.session.add_all([tag, user1, user2])
db.session.flush()
with fake_datetime('13:00:00'):
snapshots.create(tag, user1)
session.flush()
db.session.flush()
tag.names = [db.TagName('changed')]
snapshots.modify(tag, user2)
session.flush()
assert session.query(db.Snapshot).count() == 2
db.session.flush()
assert db.session.query(db.Snapshot).count() == 2
def test_merging_resets_merging_time_window(
fake_datetime, session, tag_factory, user_factory):
fake_datetime, tag_factory, user_factory):
tag = tag_factory(names=['dummy'])
user = user_factory()
session.add_all([tag, user])
session.flush()
db.session.add_all([tag, user])
db.session.flush()
with fake_datetime('13:00:00'):
snapshots.create(tag, user)
session.flush()
db.session.flush()
tag.names = [db.TagName('changed')]
with fake_datetime('13:09:59'):
snapshots.modify(tag, user)
session.flush()
db.session.flush()
tag.names = [db.TagName('changed again')]
with fake_datetime('13:19:59'):
snapshots.modify(tag, user)
session.flush()
results = session.query(db.Snapshot).all()
db.session.flush()
results = db.session.query(db.Snapshot).all()
assert len(results) == 1
assert results[0].data['names'] == ['changed again']
@pytest.mark.parametrize(
'initial_operation', [snapshots.create, snapshots.modify])
def test_merging_deletion_to_modification_or_creation(
fake_datetime, session, tag_factory, user_factory, initial_operation):
fake_datetime, tag_factory, user_factory, initial_operation):
tag = tag_factory(names=['dummy'], category_name='dummy')
user = user_factory()
session.add_all([tag, user])
session.flush()
db.session.add_all([tag, user])
db.session.flush()
with fake_datetime('13:00:00'):
initial_operation(tag, user)
session.flush()
db.session.flush()
tag.names = [db.TagName('changed')]
with fake_datetime('14:00:00'):
snapshots.modify(tag, user)
session.flush()
db.session.flush()
tag.names = [db.TagName('changed again')]
with fake_datetime('14:00:01'):
snapshots.delete(tag, user)
session.flush()
assert session.query(db.Snapshot).count() == 2
results = session.query(db.Snapshot) \
db.session.flush()
assert db.session.query(db.Snapshot).count() == 2
results = db.session \
.query(db.Snapshot) \
.order_by(db.Snapshot.snapshot_id.asc()) \
.all()
assert results[1].operation == db.Snapshot.OPERATION_DELETED
@ -161,20 +161,20 @@ def test_merging_deletion_to_modification_or_creation(
@pytest.mark.parametrize(
'expected_operation', [snapshots.create, snapshots.modify])
def test_merging_deletion_all_the_way_deletes_all_snapshots(
fake_datetime, session, tag_factory, user_factory, expected_operation):
fake_datetime, tag_factory, user_factory, expected_operation):
tag = tag_factory(names=['dummy'])
user = user_factory()
session.add_all([tag, user])
session.flush()
db.session.add_all([tag, user])
db.session.flush()
with fake_datetime('13:00:00'):
snapshots.create(tag, user)
session.flush()
db.session.flush()
tag.names = [db.TagName('changed')]
with fake_datetime('13:00:01'):
snapshots.modify(tag, user)
session.flush()
db.session.flush()
tag.names = [db.TagName('changed again')]
with fake_datetime('13:00:02'):
snapshots.delete(tag, user)
session.flush()
assert session.query(db.Snapshot).count() == 0
db.session.flush()
assert db.session.query(db.Snapshot).count() == 0

View file

@ -34,8 +34,8 @@ def save(operation, entity, auth_user):
snapshot.data = serializers[table_name](entity)
snapshot.user = auth_user
session = db.session()
earlier_snapshots = session.query(db.Snapshot) \
earlier_snapshots = db.session \
.query(db.Snapshot) \
.filter(db.Snapshot.resource_type == table_name) \
.filter(db.Snapshot.resource_id == primary_key) \
.order_by(db.Snapshot.creation_time.desc()) \
@ -49,7 +49,7 @@ def save(operation, entity, auth_user):
if snapshot.data != last_snapshot.data:
if not is_fresh or last_snapshot.user != auth_user:
break
session.delete(last_snapshot)
db.session.delete(last_snapshot)
if snapshot.operation != db.Snapshot.OPERATION_DELETED:
snapshot.operation = last_snapshot.operation
snapshots_left -= 1
@ -57,7 +57,7 @@ def save(operation, entity, auth_user):
if not snapshots_left and operation == db.Snapshot.OPERATION_DELETED:
pass
else:
session.add(snapshot)
db.session.add(snapshot)
def create(entity, auth_user):
save(db.Snapshot.OPERATION_CREATED, entity, auth_user)

View file

@ -26,7 +26,7 @@ def update_name(category, name):
expr = db.TagCategory.name.ilike(name)
if category.tag_category_id:
expr = expr & (db.TagCategory.tag_category_id != category.tag_category_id)
already_exists = db.session().query(db.TagCategory).filter(expr).count() > 0
already_exists = db.session.query(db.TagCategory).filter(expr).count() > 0
if already_exists:
raise TagCategoryAlreadyExistsError(
'A category with this name already exists.')
@ -43,7 +43,8 @@ def update_color(category, color):
category.color = color
def get_category_by_name(name):
return db.session.query(db.TagCategory) \
return db.session \
.query(db.TagCategory) \
.filter(db.TagCategory.name.ilike(name)) \
.first()
@ -54,7 +55,8 @@ def get_all_categories():
return db.session.query(db.TagCategory).all()
def get_default_category():
return db.session().query(db.TagCategory) \
return db.session \
.query(db.TagCategory) \
.order_by(db.TagCategory.tag_category_id.asc()) \
.limit(1) \
.one()

View file

@ -32,7 +32,7 @@ def export_to_json():
'tags': [],
'categories': [],
}
all_tags = db.session() \
all_tags = db.session \
.query(db.Tag) \
.options(
sqlalchemy.orm.joinedload('suggestions'),
@ -61,7 +61,8 @@ def export_to_json():
handle.write(json.dumps(output, separators=(',', ':')))
def get_tag_by_name(name):
return db.session().query(db.Tag) \
return db.session \
.query(db.Tag) \
.join(db.TagName) \
.filter(db.TagName.name.ilike(name)) \
.first()
@ -73,7 +74,7 @@ def get_tags_by_names(names):
expr = sqlalchemy.sql.false()
for name in names:
expr = expr | db.TagName.name.ilike(name)
return db.session().query(db.Tag).join(db.TagName).filter(expr).all()
return db.session.query(db.Tag).join(db.TagName).filter(expr).all()
def get_or_create_tags_by_names(names):
names = misc.icase_unique(names)
@ -93,7 +94,7 @@ def get_or_create_tags_by_names(names):
category_name=tag_categories.get_default_category().name,
suggestions=[],
implications=[])
db.session().add(new_tag)
db.session.add(new_tag)
new_tags.append(new_tag)
return related_tags, new_tags
@ -107,8 +108,8 @@ def create_tag(names, category_name, suggestions, implications):
return tag
def update_category_name(tag, category_name):
session = db.session()
category = session.query(db.TagCategory) \
category = db.session \
.query(db.TagCategory) \
.filter(db.TagCategory.name == category_name) \
.first()
if not category:
@ -131,7 +132,7 @@ def update_names(tag, names):
expr = expr | db.TagName.name.ilike(name)
if tag.tag_id:
expr = expr & (db.TagName.tag_id != tag.tag_id)
existing_tags = db.session().query(db.TagName).filter(expr).all()
existing_tags = db.session.query(db.TagName).filter(expr).all()
if len(existing_tags):
raise TagAlreadyExistsError(
'One of names is already used by another tag.')
@ -149,12 +150,12 @@ def update_implications(tag, relations):
if _check_name_intersection(_get_plain_names(tag), relations):
raise InvalidTagRelationError('Tag cannot imply itself.')
related_tags, new_tags = get_or_create_tags_by_names(relations)
db.session().flush()
db.session.flush()
tag.implications = related_tags + new_tags
def update_suggestions(tag, relations):
if _check_name_intersection(_get_plain_names(tag), relations):
raise InvalidTagRelationError('Tag cannot suggest itself.')
related_tags, new_tags = get_or_create_tags_by_names(relations)
db.session().flush()
db.session.flush()
tag.suggestions = related_tags + new_tags

View file

@ -16,12 +16,14 @@ def get_user_count():
return db.session.query(db.User).count()
def get_user_by_name(name):
return db.session.query(db.User) \
return db.session \
.query(db.User) \
.filter(func.lower(db.User.name) == func.lower(name)) \
.first()
def get_user_by_name_or_email(name_or_email):
return db.session.query(db.User) \
return db.session \
.query(db.User) \
.filter(
(func.lower(db.User.name) == func.lower(name_or_email))
| (func.lower(db.User.email) == func.lower(name_or_email))) \