server/general: embrace most of PEP8

Ignored only the rules about continuing / hanging indentation.

Also, added __init__.py to tests so that pylint discovers them. (I don't
buy pytest's BS about installing your package.)
This commit is contained in:
rr- 2016-08-14 14:22:53 +02:00
parent af62f8c45a
commit 9aea55e3d1
129 changed files with 2251 additions and 1077 deletions

View file

@ -8,11 +8,26 @@ good-names=ex,_,logger
dummy-variables-rgx=_|dummy dummy-variables-rgx=_|dummy
[format] [format]
max-line-length=90 max-line-length=79
[messages control] [messages control]
disable=missing-docstring,no-self-use,too-few-public-methods,multiple-statements
reports=no reports=no
disable=
# we're not java
missing-docstring,
# covered better by pycodestyle
bad-continuation,
# we're adults
redefined-builtin,
duplicate-code,
too-many-return-statements,
too-many-arguments,
# plain stupid
no-self-use,
too-few-public-methods
[typecheck] [typecheck]
generated-members=add|add_all generated-members=add|add_all

View file

@ -3,20 +3,24 @@ from szurubooru import search
from szurubooru.func import auth, comments, posts, scores, util from szurubooru.func import auth, comments, posts, scores, util
from szurubooru.rest import routes from szurubooru.rest import routes
_search_executor = search.Executor(search.configs.CommentSearchConfig()) _search_executor = search.Executor(search.configs.CommentSearchConfig())
def _serialize(ctx, comment, **kwargs): def _serialize(ctx, comment, **kwargs):
return comments.serialize_comment( return comments.serialize_comment(
comment, comment,
ctx.user, ctx.user,
options=util.get_serialization_options(ctx), **kwargs) options=util.get_serialization_options(ctx), **kwargs)
@routes.get('/comments/?') @routes.get('/comments/?')
def get_comments(ctx, _params=None): def get_comments(ctx, _params=None):
auth.verify_privilege(ctx.user, 'comments:list') auth.verify_privilege(ctx.user, 'comments:list')
return _search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, lambda comment: _serialize(ctx, comment)) ctx, lambda comment: _serialize(ctx, comment))
@routes.post('/comments/?') @routes.post('/comments/?')
def create_comment(ctx, _params=None): def create_comment(ctx, _params=None):
auth.verify_privilege(ctx.user, 'comments:create') auth.verify_privilege(ctx.user, 'comments:create')
@ -28,12 +32,14 @@ def create_comment(ctx, _params=None):
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, comment) return _serialize(ctx, comment)
@routes.get('/comment/(?P<comment_id>[^/]+)/?') @routes.get('/comment/(?P<comment_id>[^/]+)/?')
def get_comment(ctx, params): def get_comment(ctx, params):
auth.verify_privilege(ctx.user, 'comments:view') auth.verify_privilege(ctx.user, 'comments:view')
comment = comments.get_comment_by_id(params['comment_id']) comment = comments.get_comment_by_id(params['comment_id'])
return _serialize(ctx, comment) return _serialize(ctx, comment)
@routes.put('/comment/(?P<comment_id>[^/]+)/?') @routes.put('/comment/(?P<comment_id>[^/]+)/?')
def update_comment(ctx, params): def update_comment(ctx, params):
comment = comments.get_comment_by_id(params['comment_id']) comment = comments.get_comment_by_id(params['comment_id'])
@ -47,6 +53,7 @@ def update_comment(ctx, params):
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, comment) return _serialize(ctx, comment)
@routes.delete('/comment/(?P<comment_id>[^/]+)/?') @routes.delete('/comment/(?P<comment_id>[^/]+)/?')
def delete_comment(ctx, params): def delete_comment(ctx, params):
comment = comments.get_comment_by_id(params['comment_id']) comment = comments.get_comment_by_id(params['comment_id'])
@ -57,6 +64,7 @@ def delete_comment(ctx, params):
ctx.session.commit() ctx.session.commit()
return {} return {}
@routes.put('/comment/(?P<comment_id>[^/]+)/score/?') @routes.put('/comment/(?P<comment_id>[^/]+)/score/?')
def set_comment_score(ctx, params): def set_comment_score(ctx, params):
auth.verify_privilege(ctx.user, 'comments:score') auth.verify_privilege(ctx.user, 'comments:score')
@ -66,6 +74,7 @@ def set_comment_score(ctx, params):
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, comment) return _serialize(ctx, comment)
@routes.delete('/comment/(?P<comment_id>[^/]+)/score/?') @routes.delete('/comment/(?P<comment_id>[^/]+)/score/?')
def delete_comment_score(ctx, params): def delete_comment_score(ctx, params):
auth.verify_privilege(ctx.user, 'comments:score') auth.verify_privilege(ctx.user, 'comments:score')

View file

@ -4,11 +4,13 @@ from szurubooru import config
from szurubooru.func import posts, users, util from szurubooru.func import posts, users, util
from szurubooru.rest import routes from szurubooru.rest import routes
_cache_time = None _cache_time = None
_cache_result = None _cache_result = None
def _get_disk_usage(): def _get_disk_usage():
global _cache_time, _cache_result # pylint: disable=global-statement global _cache_time, _cache_result # pylint: disable=global-statement
threshold = datetime.timedelta(hours=1) threshold = datetime.timedelta(hours=1)
now = datetime.datetime.utcnow() now = datetime.datetime.utcnow()
if _cache_time and _cache_time > now - threshold: if _cache_time and _cache_time > now - threshold:
@ -22,17 +24,20 @@ def _get_disk_usage():
_cache_result = total_size _cache_result = total_size
return total_size return total_size
@routes.get('/info/?') @routes.get('/info/?')
def get_info(ctx, _params=None): def get_info(ctx, _params=None):
post_feature = posts.try_get_current_post_feature() post_feature = posts.try_get_current_post_feature()
return { return {
'postCount': posts.get_post_count(), 'postCount': posts.get_post_count(),
'diskUsage': _get_disk_usage(), 'diskUsage': _get_disk_usage(),
'featuredPost': posts.serialize_post(post_feature.post, ctx.user) \ 'featuredPost':
if post_feature else None, posts.serialize_post(post_feature.post, ctx.user)
if post_feature else None,
'featuringTime': post_feature.time if post_feature else None, 'featuringTime': post_feature.time if post_feature else None,
'featuringUser': users.serialize_user(post_feature.user, ctx.user) \ 'featuringUser':
if post_feature else None, users.serialize_user(post_feature.user, ctx.user)
if post_feature else None,
'serverTime': datetime.datetime.utcnow(), 'serverTime': datetime.datetime.utcnow(),
'config': { 'config': {
'userNameRegex': config.config['user_name_regex'], 'userNameRegex': config.config['user_name_regex'],
@ -40,7 +45,8 @@ def get_info(ctx, _params=None):
'tagNameRegex': config.config['tag_name_regex'], 'tagNameRegex': config.config['tag_name_regex'],
'tagCategoryNameRegex': config.config['tag_category_name_regex'], 'tagCategoryNameRegex': config.config['tag_category_name_regex'],
'defaultUserRank': config.config['default_rank'], 'defaultUserRank': config.config['default_rank'],
'privileges': util.snake_case_to_lower_camel_case_keys( 'privileges':
config.config['privileges']), util.snake_case_to_lower_camel_case_keys(
config.config['privileges']),
}, },
} }

View file

@ -2,12 +2,14 @@ from szurubooru import config, errors
from szurubooru.func import auth, mailer, users, util from szurubooru.func import auth, mailer, users, util
from szurubooru.rest import routes from szurubooru.rest import routes
MAIL_SUBJECT = 'Password reset for {name}' MAIL_SUBJECT = 'Password reset for {name}'
MAIL_BODY = \ MAIL_BODY = \
'You (or someone else) requested to reset your password on {name}.\n' \ 'You (or someone else) requested to reset your password on {name}.\n' \
'If you wish to proceed, click this link: {url}\n' \ 'If you wish to proceed, click this link: {url}\n' \
'Otherwise, please ignore this email.' 'Otherwise, please ignore this email.'
@routes.get('/password-reset/(?P<user_name>[^/]+)/?') @routes.get('/password-reset/(?P<user_name>[^/]+)/?')
def start_password_reset(_ctx, params): def start_password_reset(_ctx, params):
''' Send a mail with secure token to the correlated user. ''' ''' Send a mail with secure token to the correlated user. '''
@ -27,6 +29,7 @@ def start_password_reset(_ctx, params):
MAIL_BODY.format(name=config.config['name'], url=url)) MAIL_BODY.format(name=config.config['name'], url=url))
return {} return {}
@routes.post('/password-reset/(?P<user_name>[^/]+)/?') @routes.post('/password-reset/(?P<user_name>[^/]+)/?')
def finish_password_reset(ctx, params): def finish_password_reset(ctx, params):
''' Verify token from mail, generate a new password and return it. ''' ''' Verify token from mail, generate a new password and return it. '''

View file

@ -1,16 +1,20 @@
import datetime import datetime
from szurubooru import search from szurubooru import search
from szurubooru.func import auth, tags, posts, snapshots, favorites, scores, util
from szurubooru.rest import routes from szurubooru.rest import routes
from szurubooru.func import (
auth, tags, posts, snapshots, favorites, scores, util)
_search_executor = search.Executor(search.configs.PostSearchConfig()) _search_executor = search.Executor(search.configs.PostSearchConfig())
def _serialize_post(ctx, post): def _serialize_post(ctx, post):
return posts.serialize_post( return posts.serialize_post(
post, post,
ctx.user, ctx.user,
options=util.get_serialization_options(ctx)) options=util.get_serialization_options(ctx))
@routes.get('/posts/?') @routes.get('/posts/?')
def get_posts(ctx, _params=None): def get_posts(ctx, _params=None):
auth.verify_privilege(ctx.user, 'posts:list') auth.verify_privilege(ctx.user, 'posts:list')
@ -18,6 +22,7 @@ def get_posts(ctx, _params=None):
return _search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, lambda post: _serialize_post(ctx, post)) ctx, lambda post: _serialize_post(ctx, post))
@routes.post('/posts/?') @routes.post('/posts/?')
def create_post(ctx, _params=None): def create_post(ctx, _params=None):
anonymous = ctx.get_param_as_bool('anonymous', default=False) anonymous = ctx.get_param_as_bool('anonymous', default=False)
@ -52,12 +57,14 @@ def create_post(ctx, _params=None):
tags.export_to_json() tags.export_to_json()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.get('/post/(?P<post_id>[^/]+)/?') @routes.get('/post/(?P<post_id>[^/]+)/?')
def get_post(ctx, params): def get_post(ctx, params):
auth.verify_privilege(ctx.user, 'posts:view') auth.verify_privilege(ctx.user, 'posts:view')
post = posts.get_post_by_id(params['post_id']) post = posts.get_post_by_id(params['post_id'])
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.put('/post/(?P<post_id>[^/]+)/?') @routes.put('/post/(?P<post_id>[^/]+)/?')
def update_post(ctx, params): def update_post(ctx, params):
post = posts.get_post_by_id(params['post_id']) post = posts.get_post_by_id(params['post_id'])
@ -98,6 +105,7 @@ def update_post(ctx, params):
tags.export_to_json() tags.export_to_json()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.delete('/post/(?P<post_id>[^/]+)/?') @routes.delete('/post/(?P<post_id>[^/]+)/?')
def delete_post(ctx, params): def delete_post(ctx, params):
auth.verify_privilege(ctx.user, 'posts:delete') auth.verify_privilege(ctx.user, 'posts:delete')
@ -109,11 +117,13 @@ def delete_post(ctx, params):
tags.export_to_json() tags.export_to_json()
return {} return {}
@routes.get('/featured-post/?') @routes.get('/featured-post/?')
def get_featured_post(ctx, _params=None): def get_featured_post(ctx, _params=None):
post = posts.try_get_featured_post() post = posts.try_get_featured_post()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.post('/featured-post/?') @routes.post('/featured-post/?')
def set_featured_post(ctx, _params=None): def set_featured_post(ctx, _params=None):
auth.verify_privilege(ctx.user, 'posts:feature') auth.verify_privilege(ctx.user, 'posts:feature')
@ -130,6 +140,7 @@ def set_featured_post(ctx, _params=None):
ctx.session.commit() ctx.session.commit()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.put('/post/(?P<post_id>[^/]+)/score/?') @routes.put('/post/(?P<post_id>[^/]+)/score/?')
def set_post_score(ctx, params): def set_post_score(ctx, params):
auth.verify_privilege(ctx.user, 'posts:score') auth.verify_privilege(ctx.user, 'posts:score')
@ -139,6 +150,7 @@ def set_post_score(ctx, params):
ctx.session.commit() ctx.session.commit()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.delete('/post/(?P<post_id>[^/]+)/score/?') @routes.delete('/post/(?P<post_id>[^/]+)/score/?')
def delete_post_score(ctx, params): def delete_post_score(ctx, params):
auth.verify_privilege(ctx.user, 'posts:score') auth.verify_privilege(ctx.user, 'posts:score')
@ -147,6 +159,7 @@ def delete_post_score(ctx, params):
ctx.session.commit() ctx.session.commit()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.post('/post/(?P<post_id>[^/]+)/favorite/?') @routes.post('/post/(?P<post_id>[^/]+)/favorite/?')
def add_post_to_favorites(ctx, params): def add_post_to_favorites(ctx, params):
auth.verify_privilege(ctx.user, 'posts:favorite') auth.verify_privilege(ctx.user, 'posts:favorite')
@ -155,6 +168,7 @@ def add_post_to_favorites(ctx, params):
ctx.session.commit() ctx.session.commit()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.delete('/post/(?P<post_id>[^/]+)/favorite/?') @routes.delete('/post/(?P<post_id>[^/]+)/favorite/?')
def delete_post_from_favorites(ctx, params): def delete_post_from_favorites(ctx, params):
auth.verify_privilege(ctx.user, 'posts:favorite') auth.verify_privilege(ctx.user, 'posts:favorite')
@ -163,6 +177,7 @@ def delete_post_from_favorites(ctx, params):
ctx.session.commit() ctx.session.commit()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.get('/post/(?P<post_id>[^/]+)/around/?') @routes.get('/post/(?P<post_id>[^/]+)/around/?')
def get_posts_around(ctx, params): def get_posts_around(ctx, params):
auth.verify_privilege(ctx.user, 'posts:list') auth.verify_privilege(ctx.user, 'posts:list')

View file

@ -2,8 +2,9 @@ from szurubooru import search
from szurubooru.func import auth, snapshots from szurubooru.func import auth, snapshots
from szurubooru.rest import routes from szurubooru.rest import routes
_search_executor = search.Executor(
search.configs.SnapshotSearchConfig()) _search_executor = search.Executor(search.configs.SnapshotSearchConfig())
@routes.get('/snapshots/?') @routes.get('/snapshots/?')
def get_snapshots(ctx, _params=None): def get_snapshots(ctx, _params=None):

View file

@ -3,12 +3,15 @@ from szurubooru import db, search
from szurubooru.func import auth, tags, util, snapshots from szurubooru.func import auth, tags, util, snapshots
from szurubooru.rest import routes from szurubooru.rest import routes
_search_executor = search.Executor(search.configs.TagSearchConfig()) _search_executor = search.Executor(search.configs.TagSearchConfig())
def _serialize(ctx, tag): def _serialize(ctx, tag):
return tags.serialize_tag( return tags.serialize_tag(
tag, options=util.get_serialization_options(ctx)) tag, options=util.get_serialization_options(ctx))
def _create_if_needed(tag_names, user): def _create_if_needed(tag_names, user):
if not tag_names: if not tag_names:
return return
@ -19,12 +22,14 @@ def _create_if_needed(tag_names, user):
for tag in new_tags: for tag in new_tags:
snapshots.save_entity_creation(tag, user) snapshots.save_entity_creation(tag, user)
@routes.get('/tags/?') @routes.get('/tags/?')
def get_tags(ctx, _params=None): def get_tags(ctx, _params=None):
auth.verify_privilege(ctx.user, 'tags:list') auth.verify_privilege(ctx.user, 'tags:list')
return _search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, lambda tag: _serialize(ctx, tag)) ctx, lambda tag: _serialize(ctx, tag))
@routes.post('/tags/?') @routes.post('/tags/?')
def create_tag(ctx, _params=None): def create_tag(ctx, _params=None):
auth.verify_privilege(ctx.user, 'tags:create') auth.verify_privilege(ctx.user, 'tags:create')
@ -50,12 +55,14 @@ def create_tag(ctx, _params=None):
tags.export_to_json() tags.export_to_json()
return _serialize(ctx, tag) return _serialize(ctx, tag)
@routes.get('/tag/(?P<tag_name>[^/]+)/?') @routes.get('/tag/(?P<tag_name>[^/]+)/?')
def get_tag(ctx, params): def get_tag(ctx, params):
auth.verify_privilege(ctx.user, 'tags:view') auth.verify_privilege(ctx.user, 'tags:view')
tag = tags.get_tag_by_name(params['tag_name']) tag = tags.get_tag_by_name(params['tag_name'])
return _serialize(ctx, tag) return _serialize(ctx, tag)
@routes.put('/tag/(?P<tag_name>[^/]+)/?') @routes.put('/tag/(?P<tag_name>[^/]+)/?')
def update_tag(ctx, params): def update_tag(ctx, params):
tag = tags.get_tag_by_name(params['tag_name']) tag = tags.get_tag_by_name(params['tag_name'])
@ -89,6 +96,7 @@ def update_tag(ctx, params):
tags.export_to_json() tags.export_to_json()
return _serialize(ctx, tag) return _serialize(ctx, tag)
@routes.delete('/tag/(?P<tag_name>[^/]+)/?') @routes.delete('/tag/(?P<tag_name>[^/]+)/?')
def delete_tag(ctx, params): def delete_tag(ctx, params):
tag = tags.get_tag_by_name(params['tag_name']) tag = tags.get_tag_by_name(params['tag_name'])
@ -100,6 +108,7 @@ def delete_tag(ctx, params):
tags.export_to_json() tags.export_to_json()
return {} return {}
@routes.post('/tag-merge/?') @routes.post('/tag-merge/?')
def merge_tags(ctx, _params=None): def merge_tags(ctx, _params=None):
source_tag_name = ctx.get_param_as_string('remove', required=True) or '' source_tag_name = ctx.get_param_as_string('remove', required=True) or ''
@ -116,6 +125,7 @@ def merge_tags(ctx, _params=None):
tags.export_to_json() tags.export_to_json()
return _serialize(ctx, target_tag) return _serialize(ctx, target_tag)
@routes.get('/tag-siblings/(?P<tag_name>[^/]+)/?') @routes.get('/tag-siblings/(?P<tag_name>[^/]+)/?')
def get_tag_siblings(ctx, params): def get_tag_siblings(ctx, params):
auth.verify_privilege(ctx.user, 'tags:view') auth.verify_privilege(ctx.user, 'tags:view')

View file

@ -1,10 +1,12 @@
from szurubooru.rest import routes from szurubooru.rest import routes
from szurubooru.func import auth, tags, tag_categories, util, snapshots from szurubooru.func import auth, tags, tag_categories, util, snapshots
def _serialize(ctx, category): def _serialize(ctx, category):
return tag_categories.serialize_category( return tag_categories.serialize_category(
category, options=util.get_serialization_options(ctx)) category, options=util.get_serialization_options(ctx))
@routes.get('/tag-categories/?') @routes.get('/tag-categories/?')
def get_tag_categories(ctx, _params=None): def get_tag_categories(ctx, _params=None):
auth.verify_privilege(ctx.user, 'tag_categories:list') auth.verify_privilege(ctx.user, 'tag_categories:list')
@ -13,6 +15,7 @@ def get_tag_categories(ctx, _params=None):
'results': [_serialize(ctx, category) for category in categories], 'results': [_serialize(ctx, category) for category in categories],
} }
@routes.post('/tag-categories/?') @routes.post('/tag-categories/?')
def create_tag_category(ctx, _params=None): def create_tag_category(ctx, _params=None):
auth.verify_privilege(ctx.user, 'tag_categories:create') auth.verify_privilege(ctx.user, 'tag_categories:create')
@ -26,12 +29,14 @@ def create_tag_category(ctx, _params=None):
tags.export_to_json() tags.export_to_json()
return _serialize(ctx, category) return _serialize(ctx, category)
@routes.get('/tag-category/(?P<category_name>[^/]+)/?') @routes.get('/tag-category/(?P<category_name>[^/]+)/?')
def get_tag_category(ctx, params): def get_tag_category(ctx, params):
auth.verify_privilege(ctx.user, 'tag_categories:view') auth.verify_privilege(ctx.user, 'tag_categories:view')
category = tag_categories.get_category_by_name(params['category_name']) category = tag_categories.get_category_by_name(params['category_name'])
return _serialize(ctx, category) return _serialize(ctx, category)
@routes.put('/tag-category/(?P<category_name>[^/]+)/?') @routes.put('/tag-category/(?P<category_name>[^/]+)/?')
def update_tag_category(ctx, params): def update_tag_category(ctx, params):
category = tag_categories.get_category_by_name(params['category_name']) category = tag_categories.get_category_by_name(params['category_name'])
@ -51,6 +56,7 @@ def update_tag_category(ctx, params):
tags.export_to_json() tags.export_to_json()
return _serialize(ctx, category) return _serialize(ctx, category)
@routes.delete('/tag-category/(?P<category_name>[^/]+)/?') @routes.delete('/tag-category/(?P<category_name>[^/]+)/?')
def delete_tag_category(ctx, params): def delete_tag_category(ctx, params):
category = tag_categories.get_category_by_name(params['category_name']) category = tag_categories.get_category_by_name(params['category_name'])
@ -62,6 +68,7 @@ def delete_tag_category(ctx, params):
tags.export_to_json() tags.export_to_json()
return {} return {}
@routes.put('/tag-category/(?P<category_name>[^/]+)/default/?') @routes.put('/tag-category/(?P<category_name>[^/]+)/default/?')
def set_tag_category_as_default(ctx, params): def set_tag_category_as_default(ctx, params):
auth.verify_privilege(ctx.user, 'tag_categories:set_default') auth.verify_privilege(ctx.user, 'tag_categories:set_default')

View file

@ -2,8 +2,10 @@ from szurubooru import search
from szurubooru.func import auth, users, util from szurubooru.func import auth, users, util
from szurubooru.rest import routes from szurubooru.rest import routes
_search_executor = search.Executor(search.configs.UserSearchConfig()) _search_executor = search.Executor(search.configs.UserSearchConfig())
def _serialize(ctx, user, **kwargs): def _serialize(ctx, user, **kwargs):
return users.serialize_user( return users.serialize_user(
user, user,
@ -11,12 +13,14 @@ def _serialize(ctx, user, **kwargs):
options=util.get_serialization_options(ctx), options=util.get_serialization_options(ctx),
**kwargs) **kwargs)
@routes.get('/users/?') @routes.get('/users/?')
def get_users(ctx, _params=None): def get_users(ctx, _params=None):
auth.verify_privilege(ctx.user, 'users:list') auth.verify_privilege(ctx.user, 'users:list')
return _search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, lambda user: _serialize(ctx, user)) ctx, lambda user: _serialize(ctx, user))
@routes.post('/users/?') @routes.post('/users/?')
def create_user(ctx, _params=None): def create_user(ctx, _params=None):
auth.verify_privilege(ctx.user, 'users:create') auth.verify_privilege(ctx.user, 'users:create')
@ -36,6 +40,7 @@ def create_user(ctx, _params=None):
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, user, force_show_email=True) return _serialize(ctx, user, force_show_email=True)
@routes.get('/user/(?P<user_name>[^/]+)/?') @routes.get('/user/(?P<user_name>[^/]+)/?')
def get_user(ctx, params): def get_user(ctx, params):
user = users.get_user_by_name(params['user_name']) user = users.get_user_by_name(params['user_name'])
@ -43,6 +48,7 @@ def get_user(ctx, params):
auth.verify_privilege(ctx.user, 'users:view') auth.verify_privilege(ctx.user, 'users:view')
return _serialize(ctx, user) return _serialize(ctx, user)
@routes.put('/user/(?P<user_name>[^/]+)/?') @routes.put('/user/(?P<user_name>[^/]+)/?')
def update_user(ctx, params): def update_user(ctx, params):
user = users.get_user_by_name(params['user_name']) user = users.get_user_by_name(params['user_name'])
@ -72,6 +78,7 @@ def update_user(ctx, params):
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, user) return _serialize(ctx, user)
@routes.delete('/user/(?P<user_name>[^/]+)/?') @routes.delete('/user/(?P<user_name>[^/]+)/?')
def delete_user(ctx, params): def delete_user(ctx, params):
user = users.get_user_by_name(params['user_name']) user = users.get_user_by_name(params['user_name'])

View file

@ -1,6 +1,7 @@
import os import os
import yaml import yaml
def merge(left, right): def merge(left, right):
for key in right: for key in right:
if key in left: if key in left:
@ -12,6 +13,7 @@ def merge(left, right):
left[key] = right[key] left[key] = right[key]
return left return left
def read_config(): def read_config():
with open('../config.yaml.dist') as handle: with open('../config.yaml.dist') as handle:
ret = yaml.load(handle.read()) ret = yaml.load(handle.read())
@ -20,4 +22,5 @@ def read_config():
ret = merge(ret, yaml.load(handle.read())) ret = merge(ret, yaml.load(handle.read()))
return ret return ret
config = read_config() # pylint: disable=invalid-name
config = read_config() # pylint: disable=invalid-name

View file

@ -1,2 +1,4 @@
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base() # pylint: disable=invalid-name
Base = declarative_base() # pylint: disable=invalid-name

View file

@ -3,13 +3,18 @@ from sqlalchemy.orm import relationship, backref
from sqlalchemy.sql.expression import func from sqlalchemy.sql.expression import func
from szurubooru.db.base import Base from szurubooru.db.base import Base
class CommentScore(Base): class CommentScore(Base):
__tablename__ = 'comment_score' __tablename__ = 'comment_score'
comment_id = Column( comment_id = Column(
'comment_id', Integer, ForeignKey('comment.id'), primary_key=True) 'comment_id', Integer, ForeignKey('comment.id'), primary_key=True)
user_id = Column( user_id = Column(
'user_id', Integer, ForeignKey('user.id'), primary_key=True, index=True) 'user_id',
Integer,
ForeignKey('user.id'),
primary_key=True,
index=True)
time = Column('time', DateTime, nullable=False) time = Column('time', DateTime, nullable=False)
score = Column('score', Integer, nullable=False) score = Column('score', Integer, nullable=False)
@ -18,14 +23,14 @@ class CommentScore(Base):
'User', 'User',
backref=backref('comment_scores', cascade='all, delete-orphan')) backref=backref('comment_scores', cascade='all, delete-orphan'))
class Comment(Base): class Comment(Base):
__tablename__ = 'comment' __tablename__ = 'comment'
comment_id = Column('id', Integer, primary_key=True) comment_id = Column('id', Integer, primary_key=True)
post_id = Column( post_id = Column(
'post_id', Integer, ForeignKey('post.id'), index=True, nullable=False) 'post_id', Integer, ForeignKey('post.id'), index=True, nullable=False)
user_id = Column( user_id = Column('user_id', Integer, ForeignKey('user.id'), index=True)
'user_id', Integer, ForeignKey('user.id'), index=True)
version = Column('version', Integer, default=1, nullable=False) version = Column('version', Integer, default=1, nullable=False)
creation_time = Column('creation_time', DateTime, nullable=False) creation_time = Column('creation_time', DateTime, nullable=False)
last_edit_time = Column('last_edit_time', DateTime) last_edit_time = Column('last_edit_time', DateTime)

View file

@ -1,10 +1,12 @@
from sqlalchemy.sql.expression import func, select
from sqlalchemy import ( from sqlalchemy import (
Column, Integer, DateTime, Unicode, UnicodeText, PickleType, ForeignKey) Column, Integer, DateTime, Unicode, UnicodeText, PickleType, ForeignKey)
from sqlalchemy.orm import relationship, column_property, object_session, backref from sqlalchemy.orm import (
from sqlalchemy.sql.expression import func, select relationship, column_property, object_session, backref)
from szurubooru.db.base import Base from szurubooru.db.base import Base
from szurubooru.db.comment import Comment from szurubooru.db.comment import Comment
class PostFeature(Base): class PostFeature(Base):
__tablename__ = 'post_feature' __tablename__ = 'post_feature'
@ -20,13 +22,22 @@ class PostFeature(Base):
'User', 'User',
backref=backref('post_features', cascade='all, delete-orphan')) backref=backref('post_features', cascade='all, delete-orphan'))
class PostScore(Base): class PostScore(Base):
__tablename__ = 'post_score' __tablename__ = 'post_score'
post_id = Column( post_id = Column(
'post_id', Integer, ForeignKey('post.id'), primary_key=True, index=True) 'post_id',
Integer,
ForeignKey('post.id'),
primary_key=True,
index=True)
user_id = Column( user_id = Column(
'user_id', Integer, ForeignKey('user.id'), primary_key=True, index=True) 'user_id',
Integer,
ForeignKey('user.id'),
primary_key=True,
index=True)
time = Column('time', DateTime, nullable=False) time = Column('time', DateTime, nullable=False)
score = Column('score', Integer, nullable=False) score = Column('score', Integer, nullable=False)
@ -35,13 +46,22 @@ class PostScore(Base):
'User', 'User',
backref=backref('post_scores', cascade='all, delete-orphan')) backref=backref('post_scores', cascade='all, delete-orphan'))
class PostFavorite(Base): class PostFavorite(Base):
__tablename__ = 'post_favorite' __tablename__ = 'post_favorite'
post_id = Column( post_id = Column(
'post_id', Integer, ForeignKey('post.id'), primary_key=True, index=True) 'post_id',
Integer,
ForeignKey('post.id'),
primary_key=True,
index=True)
user_id = Column( user_id = Column(
'user_id', Integer, ForeignKey('user.id'), primary_key=True, index=True) 'user_id',
Integer,
ForeignKey('user.id'),
primary_key=True,
index=True)
time = Column('time', DateTime, nullable=False) time = Column('time', DateTime, nullable=False)
post = relationship('Post') post = relationship('Post')
@ -49,6 +69,7 @@ class PostFavorite(Base):
'User', 'User',
backref=backref('post_favorites', cascade='all, delete-orphan')) backref=backref('post_favorites', cascade='all, delete-orphan'))
class PostNote(Base): class PostNote(Base):
__tablename__ = 'post_note' __tablename__ = 'post_note'
@ -60,23 +81,37 @@ class PostNote(Base):
post = relationship('Post') post = relationship('Post')
class PostRelation(Base): class PostRelation(Base):
__tablename__ = 'post_relation' __tablename__ = 'post_relation'
parent_id = Column( parent_id = Column(
'parent_id', Integer, ForeignKey('post.id'), primary_key=True, index=True) 'parent_id',
Integer,
ForeignKey('post.id'),
primary_key=True,
index=True)
child_id = Column( child_id = Column(
'child_id', Integer, ForeignKey('post.id'), primary_key=True, index=True) 'child_id',
Integer,
ForeignKey('post.id'),
primary_key=True,
index=True)
def __init__(self, parent_id, child_id): def __init__(self, parent_id, child_id):
self.parent_id = parent_id self.parent_id = parent_id
self.child_id = child_id self.child_id = child_id
class PostTag(Base): class PostTag(Base):
__tablename__ = 'post_tag' __tablename__ = 'post_tag'
post_id = Column( post_id = Column(
'post_id', Integer, ForeignKey('post.id'), primary_key=True, index=True) 'post_id',
Integer,
ForeignKey('post.id'),
primary_key=True,
index=True)
tag_id = Column( tag_id = Column(
'tag_id', Integer, ForeignKey('tag.id'), primary_key=True, index=True) 'tag_id', Integer, ForeignKey('tag.id'), primary_key=True, index=True)
@ -84,6 +119,7 @@ class PostTag(Base):
self.post_id = post_id self.post_id = post_id
self.tag_id = tag_id self.tag_id = tag_id
class Post(Base): class Post(Base):
__tablename__ = 'post' __tablename__ = 'post'
@ -136,8 +172,8 @@ class Post(Base):
# dynamic columns # dynamic columns
tag_count = column_property( tag_count = column_property(
select([func.count(PostTag.tag_id)]) \ select([func.count(PostTag.tag_id)])
.where(PostTag.post_id == post_id) \ .where(PostTag.post_id == post_id)
.correlate_except(PostTag)) .correlate_except(PostTag))
canvas_area = column_property(canvas_width * canvas_height) canvas_area = column_property(canvas_width * canvas_height)
@ -151,53 +187,53 @@ class Post(Base):
return featured_post and featured_post.post_id == self.post_id return featured_post and featured_post.post_id == self.post_id
score = column_property( score = column_property(
select([func.coalesce(func.sum(PostScore.score), 0)]) \ select([func.coalesce(func.sum(PostScore.score), 0)])
.where(PostScore.post_id == post_id) \ .where(PostScore.post_id == post_id)
.correlate_except(PostScore)) .correlate_except(PostScore))
favorite_count = column_property( favorite_count = column_property(
select([func.count(PostFavorite.post_id)]) \ select([func.count(PostFavorite.post_id)])
.where(PostFavorite.post_id == post_id) \ .where(PostFavorite.post_id == post_id)
.correlate_except(PostFavorite)) .correlate_except(PostFavorite))
last_favorite_time = column_property( last_favorite_time = column_property(
select([func.max(PostFavorite.time)]) \ select([func.max(PostFavorite.time)])
.where(PostFavorite.post_id == post_id) \ .where(PostFavorite.post_id == post_id)
.correlate_except(PostFavorite)) .correlate_except(PostFavorite))
feature_count = column_property( feature_count = column_property(
select([func.count(PostFeature.post_id)]) \ select([func.count(PostFeature.post_id)])
.where(PostFeature.post_id == post_id) \ .where(PostFeature.post_id == post_id)
.correlate_except(PostFeature)) .correlate_except(PostFeature))
last_feature_time = column_property( last_feature_time = column_property(
select([func.max(PostFeature.time)]) \ select([func.max(PostFeature.time)])
.where(PostFeature.post_id == post_id) \ .where(PostFeature.post_id == post_id)
.correlate_except(PostFeature)) .correlate_except(PostFeature))
comment_count = column_property( comment_count = column_property(
select([func.count(Comment.post_id)]) \ select([func.count(Comment.post_id)])
.where(Comment.post_id == post_id) \ .where(Comment.post_id == post_id)
.correlate_except(Comment)) .correlate_except(Comment))
last_comment_creation_time = column_property( last_comment_creation_time = column_property(
select([func.max(Comment.creation_time)]) \ select([func.max(Comment.creation_time)])
.where(Comment.post_id == post_id) \ .where(Comment.post_id == post_id)
.correlate_except(Comment)) .correlate_except(Comment))
last_comment_edit_time = column_property( last_comment_edit_time = column_property(
select([func.max(Comment.last_edit_time)]) \ select([func.max(Comment.last_edit_time)])
.where(Comment.post_id == post_id) \ .where(Comment.post_id == post_id)
.correlate_except(Comment)) .correlate_except(Comment))
note_count = column_property( note_count = column_property(
select([func.count(PostNote.post_id)]) \ select([func.count(PostNote.post_id)])
.where(PostNote.post_id == post_id) \ .where(PostNote.post_id == post_id)
.correlate_except(PostNote)) .correlate_except(PostNote))
relation_count = column_property( relation_count = column_property(
select([func.count(PostRelation.child_id)]) \ select([func.count(PostRelation.child_id)])
.where( .where(
(PostRelation.parent_id == post_id) \ (PostRelation.parent_id == post_id)
| (PostRelation.child_id == post_id)) \ | (PostRelation.child_id == post_id))
.correlate_except(PostRelation)) .correlate_except(PostRelation))

View file

@ -1,6 +1,7 @@
import sqlalchemy import sqlalchemy
from szurubooru import config from szurubooru import config
class QueryCounter(object): class QueryCounter(object):
_query_count = 0 _query_count = 0
@ -16,6 +17,7 @@ class QueryCounter(object):
def get(): def get():
return QueryCounter._query_count return QueryCounter._query_count
def create_session(): def create_session():
_engine = sqlalchemy.create_engine( _engine = sqlalchemy.create_engine(
'{schema}://{user}:{password}@{host}:{port}/{name}'.format( '{schema}://{user}:{password}@{host}:{port}/{name}'.format(
@ -30,6 +32,7 @@ def create_session():
_session_maker = sqlalchemy.orm.sessionmaker(bind=_engine) _session_maker = sqlalchemy.orm.sessionmaker(bind=_engine)
return sqlalchemy.orm.scoped_session(_session_maker) return sqlalchemy.orm.scoped_session(_session_maker)
# pylint: disable=invalid-name # pylint: disable=invalid-name
session = create_session() session = create_session()
reset_query_count = QueryCounter.reset reset_query_count = QueryCounter.reset

View file

@ -1,7 +1,9 @@
from sqlalchemy import Column, Integer, DateTime, Unicode, PickleType, ForeignKey
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from sqlalchemy import (
Column, Integer, DateTime, Unicode, PickleType, ForeignKey)
from szurubooru.db.base import Base from szurubooru.db.base import Base
class Snapshot(Base): class Snapshot(Base):
__tablename__ = 'snapshot' __tablename__ = 'snapshot'
@ -11,7 +13,8 @@ class Snapshot(Base):
snapshot_id = Column('id', Integer, primary_key=True) snapshot_id = Column('id', Integer, primary_key=True)
creation_time = Column('creation_time', DateTime, nullable=False) creation_time = Column('creation_time', DateTime, nullable=False)
resource_type = Column('resource_type', Unicode(32), nullable=False, index=True) resource_type = Column(
'resource_type', Unicode(32), nullable=False, index=True)
resource_id = Column('resource_id', Integer, nullable=False, index=True) resource_id = Column('resource_id', Integer, nullable=False, index=True)
resource_repr = Column('resource_repr', Unicode(64), nullable=False) resource_repr = Column('resource_repr', Unicode(64), nullable=False)
operation = Column('operation', Unicode(16), nullable=False) operation = Column('operation', Unicode(16), nullable=False)

View file

@ -1,49 +1,73 @@
from sqlalchemy import Column, Integer, DateTime, Unicode, UnicodeText, ForeignKey from sqlalchemy import (
Column, Integer, DateTime, Unicode, UnicodeText, ForeignKey)
from sqlalchemy.orm import relationship, column_property from sqlalchemy.orm import relationship, column_property
from sqlalchemy.sql.expression import func, select from sqlalchemy.sql.expression import func, select
from szurubooru.db.base import Base from szurubooru.db.base import Base
from szurubooru.db.post import PostTag from szurubooru.db.post import PostTag
class TagSuggestion(Base): class TagSuggestion(Base):
__tablename__ = 'tag_suggestion' __tablename__ = 'tag_suggestion'
parent_id = Column( parent_id = Column(
'parent_id', Integer, ForeignKey('tag.id'), primary_key=True, index=True) 'parent_id',
Integer,
ForeignKey('tag.id'),
primary_key=True, index=True)
child_id = Column( child_id = Column(
'child_id', Integer, ForeignKey('tag.id'), primary_key=True, index=True) 'child_id',
Integer,
ForeignKey('tag.id'),
primary_key=True, index=True)
def __init__(self, parent_id, child_id): def __init__(self, parent_id, child_id):
self.parent_id = parent_id self.parent_id = parent_id
self.child_id = child_id self.child_id = child_id
class TagImplication(Base): class TagImplication(Base):
__tablename__ = 'tag_implication' __tablename__ = 'tag_implication'
parent_id = Column( parent_id = Column(
'parent_id', Integer, ForeignKey('tag.id'), primary_key=True, index=True) 'parent_id',
Integer,
ForeignKey('tag.id'),
primary_key=True,
index=True)
child_id = Column( child_id = Column(
'child_id', Integer, ForeignKey('tag.id'), primary_key=True, index=True) 'child_id',
Integer,
ForeignKey('tag.id'),
primary_key=True,
index=True)
def __init__(self, parent_id, child_id): def __init__(self, parent_id, child_id):
self.parent_id = parent_id self.parent_id = parent_id
self.child_id = child_id self.child_id = child_id
class TagName(Base): class TagName(Base):
__tablename__ = 'tag_name' __tablename__ = 'tag_name'
tag_name_id = Column('tag_name_id', Integer, primary_key=True) tag_name_id = Column('tag_name_id', Integer, primary_key=True)
tag_id = Column('tag_id', Integer, ForeignKey('tag.id'), nullable=False, index=True) tag_id = Column(
'tag_id', Integer, ForeignKey('tag.id'), nullable=False, index=True)
name = Column('name', Unicode(64), nullable=False, unique=True) name = Column('name', Unicode(64), nullable=False, unique=True)
def __init__(self, name): def __init__(self, name):
self.name = name self.name = name
class Tag(Base): class Tag(Base):
__tablename__ = 'tag' __tablename__ = 'tag'
tag_id = Column('id', Integer, primary_key=True) tag_id = Column('id', Integer, primary_key=True)
category_id = Column( category_id = Column(
'category_id', Integer, ForeignKey('tag_category.id'), nullable=False, index=True) 'category_id',
Integer,
ForeignKey('tag_category.id'),
nullable=False,
index=True)
version = Column('version', Integer, default=1, nullable=False) version = Column('version', Integer, default=1, nullable=False)
creation_time = Column('creation_time', DateTime, nullable=False) creation_time = Column('creation_time', DateTime, nullable=False)
last_edit_time = Column('last_edit_time', DateTime) last_edit_time = Column('last_edit_time', DateTime)
@ -69,25 +93,25 @@ class Tag(Base):
lazy='joined') lazy='joined')
post_count = column_property( post_count = column_property(
select([func.count(PostTag.post_id)]) \ select([func.count(PostTag.post_id)])
.where(PostTag.tag_id == tag_id) \ .where(PostTag.tag_id == tag_id)
.correlate_except(PostTag)) .correlate_except(PostTag))
first_name = column_property( first_name = column_property(
select([TagName.name]) \ select([TagName.name])
.where(TagName.tag_id == tag_id) \ .where(TagName.tag_id == tag_id)
.limit(1) \ .limit(1)
.as_scalar(), .as_scalar(),
deferred=True) deferred=True)
suggestion_count = column_property( suggestion_count = column_property(
select([func.count(TagSuggestion.child_id)]) \ select([func.count(TagSuggestion.child_id)])
.where(TagSuggestion.parent_id == tag_id) \ .where(TagSuggestion.parent_id == tag_id)
.as_scalar(), .as_scalar(),
deferred=True) deferred=True)
implication_count = column_property( implication_count = column_property(
select([func.count(TagImplication.child_id)]) \ select([func.count(TagImplication.child_id)])
.where(TagImplication.parent_id == tag_id) \ .where(TagImplication.parent_id == tag_id)
.as_scalar(), .as_scalar(),
deferred=True) deferred=True)

View file

@ -4,6 +4,7 @@ from sqlalchemy.sql.expression import func, select
from szurubooru.db.base import Base from szurubooru.db.base import Base
from szurubooru.db.tag import Tag from szurubooru.db.tag import Tag
class TagCategory(Base): class TagCategory(Base):
__tablename__ = 'tag_category' __tablename__ = 'tag_category'
@ -17,6 +18,6 @@ class TagCategory(Base):
self.name = name self.name = name
tag_count = column_property( tag_count = column_property(
select([func.count('Tag.tag_id')]) \ select([func.count('Tag.tag_id')])
.where(Tag.category_id == tag_category_id) \ .where(Tag.category_id == tag_category_id)
.correlate_except(table('Tag'))) .correlate_except(table('Tag')))

View file

@ -5,6 +5,7 @@ from szurubooru.db.base import Base
from szurubooru.db.post import Post, PostScore, PostFavorite from szurubooru.db.post import Post, PostScore, PostFavorite
from szurubooru.db.comment import Comment from szurubooru.db.comment import Comment
class User(Base): class User(Base):
__tablename__ = 'user' __tablename__ = 'user'
@ -17,7 +18,7 @@ class User(Base):
RANK_POWER = 'power' RANK_POWER = 'power'
RANK_MODERATOR = 'moderator' RANK_MODERATOR = 'moderator'
RANK_ADMINISTRATOR = 'administrator' RANK_ADMINISTRATOR = 'administrator'
RANK_NOBODY = 'nobody' # used for privileges: "nobody can be higher than admin" RANK_NOBODY = 'nobody' # unattainable, used for privileges
user_id = Column('id', Integer, primary_key=True) user_id = Column('id', Integer, primary_key=True)
creation_time = Column('creation_time', DateTime, nullable=False) creation_time = Column('creation_time', DateTime, nullable=False)
@ -36,41 +37,41 @@ class User(Base):
@property @property
def post_count(self): def post_count(self):
from szurubooru.db import session from szurubooru.db import session
return session \ return (session
.query(func.sum(1)) \ .query(func.sum(1))
.filter(Post.user_id == self.user_id) \ .filter(Post.user_id == self.user_id)
.one()[0] or 0 .one()[0] or 0)
@property @property
def comment_count(self): def comment_count(self):
from szurubooru.db import session from szurubooru.db import session
return session \ return (session
.query(func.sum(1)) \ .query(func.sum(1))
.filter(Comment.user_id == self.user_id) \ .filter(Comment.user_id == self.user_id)
.one()[0] or 0 .one()[0] or 0)
@property @property
def favorite_post_count(self): def favorite_post_count(self):
from szurubooru.db import session from szurubooru.db import session
return session \ return (session
.query(func.sum(1)) \ .query(func.sum(1))
.filter(PostFavorite.user_id == self.user_id) \ .filter(PostFavorite.user_id == self.user_id)
.one()[0] or 0 .one()[0] or 0)
@property @property
def liked_post_count(self): def liked_post_count(self):
from szurubooru.db import session from szurubooru.db import session
return session \ return (session
.query(func.sum(1)) \ .query(func.sum(1))
.filter(PostScore.user_id == self.user_id) \ .filter(PostScore.user_id == self.user_id)
.filter(PostScore.score == 1) \ .filter(PostScore.score == 1)
.one()[0] or 0 .one()[0] or 0)
@property @property
def disliked_post_count(self): def disliked_post_count(self):
from szurubooru.db import session from szurubooru.db import session
return session \ return (session
.query(func.sum(1)) \ .query(func.sum(1))
.filter(PostScore.user_id == self.user_id) \ .filter(PostScore.user_id == self.user_id)
.filter(PostScore.score == -1) \ .filter(PostScore.score == -1)
.one()[0] or 0 .one()[0] or 0)

View file

@ -1,5 +1,6 @@
from sqlalchemy.inspection import inspect from sqlalchemy.inspection import inspect
def get_resource_info(entity): def get_resource_info(entity):
serializers = { serializers = {
'tag': lambda tag: tag.first_name, 'tag': lambda tag: tag.first_name,
@ -23,6 +24,7 @@ def get_resource_info(entity):
return (resource_type, resource_id, resource_repr) return (resource_type, resource_id, resource_repr)
def get_aux_entity(session, get_table_info, entity, user): def get_aux_entity(session, get_table_info, entity, user):
table, get_column = get_table_info(entity) table, get_column = get_table_info(entity)
return session \ return session \

View file

@ -1,11 +1,38 @@
class ConfigError(RuntimeError): pass class ConfigError(RuntimeError):
class AuthError(RuntimeError): pass pass
class IntegrityError(RuntimeError): pass
class ValidationError(RuntimeError): pass
class SearchError(RuntimeError): pass
class NotFoundError(RuntimeError): pass
class ProcessingError(RuntimeError): pass
class MissingRequiredFileError(ValidationError): pass
class MissingRequiredParameterError(ValidationError): pass class AuthError(RuntimeError):
class InvalidParameterError(ValidationError): pass pass
class IntegrityError(RuntimeError):
pass
class ValidationError(RuntimeError):
pass
class SearchError(RuntimeError):
pass
class NotFoundError(RuntimeError):
pass
class ProcessingError(RuntimeError):
pass
class MissingRequiredFileError(ValidationError):
pass
class MissingRequiredParameterError(ValidationError):
pass
class InvalidParameterError(ValidationError):
pass

View file

@ -7,30 +7,37 @@ from szurubooru import config, errors, rest
# pylint: disable=unused-import # pylint: disable=unused-import
from szurubooru import api, middleware from szurubooru import api, middleware
def _on_auth_error(ex): def _on_auth_error(ex):
raise rest.errors.HttpForbidden( raise rest.errors.HttpForbidden(
title='Authentication error', description=str(ex)) title='Authentication error', description=str(ex))
def _on_validation_error(ex): def _on_validation_error(ex):
raise rest.errors.HttpBadRequest( raise rest.errors.HttpBadRequest(
title='Validation error', description=str(ex)) title='Validation error', description=str(ex))
def _on_search_error(ex): def _on_search_error(ex):
raise rest.errors.HttpBadRequest( raise rest.errors.HttpBadRequest(
title='Search error', description=str(ex)) title='Search error', description=str(ex))
def _on_integrity_error(ex): def _on_integrity_error(ex):
raise rest.errors.HttpConflict( raise rest.errors.HttpConflict(
title='Integrity violation', description=ex.args[0]) title='Integrity violation', description=ex.args[0])
def _on_not_found_error(ex): def _on_not_found_error(ex):
raise rest.errors.HttpNotFound( raise rest.errors.HttpNotFound(
title='Not found', description=str(ex)) title='Not found', description=str(ex))
def _on_processing_error(ex): def _on_processing_error(ex):
raise rest.errors.HttpBadRequest( raise rest.errors.HttpBadRequest(
title='Processing error', description=str(ex)) title='Processing error', description=str(ex))
def validate_config(): def validate_config():
''' '''
Check whether config doesn't contain errors that might prove Check whether config doesn't contain errors that might prove
@ -60,6 +67,7 @@ def validate_config():
raise errors.ConfigError( raise errors.ConfigError(
'Database is not configured: %r is missing' % key) 'Database is not configured: %r is missing' % key)
def create_app(): def create_app():
''' Create a WSGI compatible App object. ''' ''' Create a WSGI compatible App object. '''
validate_config() validate_config()

View file

@ -4,6 +4,7 @@ from collections import OrderedDict
from szurubooru import config, db, errors from szurubooru import config, db, errors
from szurubooru.func import util from szurubooru.func import util
RANK_MAP = OrderedDict([ RANK_MAP = OrderedDict([
(db.User.RANK_ANONYMOUS, 'anonymous'), (db.User.RANK_ANONYMOUS, 'anonymous'),
(db.User.RANK_RESTRICTED, 'restricted'), (db.User.RANK_RESTRICTED, 'restricted'),
@ -14,6 +15,7 @@ RANK_MAP = OrderedDict([
(db.User.RANK_NOBODY, 'nobody'), (db.User.RANK_NOBODY, 'nobody'),
]) ])
def get_password_hash(salt, password): def get_password_hash(salt, password):
''' Retrieve new-style password hash. ''' ''' Retrieve new-style password hash. '''
digest = hashlib.sha256() digest = hashlib.sha256()
@ -22,6 +24,7 @@ def get_password_hash(salt, password):
digest.update(password.encode('utf8')) digest.update(password.encode('utf8'))
return digest.hexdigest() return digest.hexdigest()
def get_legacy_password_hash(salt, password): def get_legacy_password_hash(salt, password):
''' Retrieve old-style password hash. ''' ''' Retrieve old-style password hash. '''
digest = hashlib.sha1() digest = hashlib.sha1()
@ -30,6 +33,7 @@ def get_legacy_password_hash(salt, password):
digest.update(password.encode('utf8')) digest.update(password.encode('utf8'))
return digest.hexdigest() return digest.hexdigest()
def create_password(): def create_password():
alphabet = { alphabet = {
'c': list('bcdfghijklmnpqrstvwxyz'), 'c': list('bcdfghijklmnpqrstvwxyz'),
@ -39,6 +43,7 @@ def create_password():
pattern = 'cvcvnncvcv' pattern = 'cvcvnncvcv'
return ''.join(random.choice(alphabet[l]) for l in list(pattern)) return ''.join(random.choice(alphabet[l]) for l in list(pattern))
def is_valid_password(user, password): def is_valid_password(user, password):
assert user assert user
salt, valid_hash = user.password_salt, user.password_hash salt, valid_hash = user.password_salt, user.password_hash
@ -48,6 +53,7 @@ def is_valid_password(user, password):
] ]
return valid_hash in possible_hashes return valid_hash in possible_hashes
def has_privilege(user, privilege_name): def has_privilege(user, privilege_name):
assert user assert user
all_ranks = list(RANK_MAP.keys()) all_ranks = list(RANK_MAP.keys())
@ -58,11 +64,13 @@ def has_privilege(user, privilege_name):
good_ranks = all_ranks[all_ranks.index(minimal_rank):] good_ranks = all_ranks[all_ranks.index(minimal_rank):]
return user.rank in good_ranks return user.rank in good_ranks
def verify_privilege(user, privilege_name): def verify_privilege(user, privilege_name):
assert user assert user
if not has_privilege(user, privilege_name): if not has_privilege(user, privilege_name):
raise errors.AuthError('Insufficient privileges to do this.') raise errors.AuthError('Insufficient privileges to do this.')
def generate_authentication_token(user): def generate_authentication_token(user):
''' Generate nonguessable challenge (e.g. links in password reminder). ''' ''' Generate nonguessable challenge (e.g. links in password reminder). '''
assert user assert user

View file

@ -1,11 +1,13 @@
from datetime import datetime from datetime import datetime
class LruCacheItem(object): class LruCacheItem(object):
def __init__(self, key, value): def __init__(self, key, value):
self.key = key self.key = key
self.value = value self.value = value
self.timestamp = datetime.utcnow() self.timestamp = datetime.utcnow()
class LruCache(object): class LruCache(object):
def __init__(self, length, delta=None): def __init__(self, length, delta=None):
self.length = length self.length = length
@ -15,12 +17,13 @@ class LruCache(object):
def insert_item(self, item): def insert_item(self, item):
if item.key in self.hash: if item.key in self.hash:
item_index = next(i \ item_index = next(
for i, v in enumerate(self.item_list) \ i
for i, v in enumerate(self.item_list)
if v.key == item.key) if v.key == item.key)
self.item_list[:] \ self.item_list[:] \
= self.item_list[:item_index] \ = self.item_list[:item_index] \
+ self.item_list[item_index+1:] + self.item_list[item_index + 1:]
self.item_list.insert(0, item) self.item_list.insert(0, item)
else: else:
if len(self.item_list) > self.length: if len(self.item_list) > self.length:
@ -36,16 +39,21 @@ class LruCache(object):
del self.hash[item.key] del self.hash[item.key]
del self.item_list[self.item_list.index(item)] del self.item_list[self.item_list.index(item)]
_CACHE = LruCache(length=100) _CACHE = LruCache(length=100)
def purge(): def purge():
_CACHE.remove_all() _CACHE.remove_all()
def has(key): def has(key):
return key in _CACHE.hash return key in _CACHE.hash
def get(key): def get(key):
return _CACHE.hash[key].value return _CACHE.hash[key].value
def put(key, value): def put(key, value):
_CACHE.insert_item(LruCacheItem(key, value)) _CACHE.insert_item(LruCacheItem(key, value))

View file

@ -2,16 +2,26 @@ import datetime
from szurubooru import db, errors from szurubooru import db, errors
from szurubooru.func import users, scores, util from szurubooru.func import users, scores, util
class InvalidCommentIdError(errors.ValidationError): pass
class CommentNotFoundError(errors.NotFoundError): pass class InvalidCommentIdError(errors.ValidationError):
class EmptyCommentTextError(errors.ValidationError): pass pass
class CommentNotFoundError(errors.NotFoundError):
pass
class EmptyCommentTextError(errors.ValidationError):
pass
def serialize_comment(comment, auth_user, options=None): def serialize_comment(comment, auth_user, options=None):
return util.serialize_entity( return util.serialize_entity(
comment, comment,
{ {
'id': lambda: comment.comment_id, 'id': lambda: comment.comment_id,
'user': lambda: users.serialize_micro_user(comment.user, auth_user), 'user':
lambda: users.serialize_micro_user(comment.user, auth_user),
'postId': lambda: comment.post.post_id, 'postId': lambda: comment.post.post_id,
'version': lambda: comment.version, 'version': lambda: comment.version,
'text': lambda: comment.text, 'text': lambda: comment.text,
@ -22,6 +32,7 @@ def serialize_comment(comment, auth_user, options=None):
}, },
options) options)
def try_get_comment_by_id(comment_id): def try_get_comment_by_id(comment_id):
try: try:
comment_id = int(comment_id) comment_id = int(comment_id)
@ -32,12 +43,14 @@ def try_get_comment_by_id(comment_id):
.filter(db.Comment.comment_id == comment_id) \ .filter(db.Comment.comment_id == comment_id) \
.one_or_none() .one_or_none()
def get_comment_by_id(comment_id): def get_comment_by_id(comment_id):
comment = try_get_comment_by_id(comment_id) comment = try_get_comment_by_id(comment_id)
if comment: if comment:
return comment return comment
raise CommentNotFoundError('Comment %r not found.' % comment_id) raise CommentNotFoundError('Comment %r not found.' % comment_id)
def create_comment(user, post, text): def create_comment(user, post, text):
comment = db.Comment() comment = db.Comment()
comment.user = user comment.user = user
@ -46,6 +59,7 @@ def create_comment(user, post, text):
comment.creation_time = datetime.datetime.utcnow() comment.creation_time = datetime.datetime.utcnow()
return comment return comment
def update_comment_text(comment, text): def update_comment_text(comment, text):
assert comment assert comment
if not text: if not text:

View file

@ -2,7 +2,10 @@ import datetime
from szurubooru import db, errors from szurubooru import db, errors
from szurubooru.func import scores from szurubooru.func import scores
class InvalidFavoriteTargetError(errors.ValidationError): pass
class InvalidFavoriteTargetError(errors.ValidationError):
pass
def _get_table_info(entity): def _get_table_info(entity):
assert entity assert entity
@ -11,16 +14,19 @@ def _get_table_info(entity):
return db.PostFavorite, lambda table: table.post_id return db.PostFavorite, lambda table: table.post_id
raise InvalidFavoriteTargetError() raise InvalidFavoriteTargetError()
def _get_fav_entity(entity, user): def _get_fav_entity(entity, user):
assert entity assert entity
assert user assert user
return db.util.get_aux_entity(db.session, _get_table_info, entity, user) return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
def has_favorited(entity, user): def has_favorited(entity, user):
assert entity assert entity
assert user assert user
return _get_fav_entity(entity, user) is not None return _get_fav_entity(entity, user) is not None
def unset_favorite(entity, user): def unset_favorite(entity, user):
assert entity assert entity
assert user assert user
@ -28,6 +34,7 @@ def unset_favorite(entity, user):
if fav_entity: if fav_entity:
db.session.delete(fav_entity) db.session.delete(fav_entity)
def set_favorite(entity, user): def set_favorite(entity, user):
assert entity assert entity
assert user assert user

View file

@ -1,20 +1,25 @@
import os import os
from szurubooru import config from szurubooru import config
def _get_full_path(path): def _get_full_path(path):
return os.path.join(config.config['data_dir'], path) return os.path.join(config.config['data_dir'], path)
def delete(path): def delete(path):
full_path = _get_full_path(path) full_path = _get_full_path(path)
if os.path.exists(full_path): if os.path.exists(full_path):
os.unlink(full_path) os.unlink(full_path)
def has(path): def has(path):
return os.path.exists(_get_full_path(path)) return os.path.exists(_get_full_path(path))
def move(source_path, target_path): def move(source_path, target_path):
return os.rename(_get_full_path(source_path), _get_full_path(target_path)) return os.rename(_get_full_path(source_path), _get_full_path(target_path))
def get(path): def get(path):
full_path = _get_full_path(path) full_path = _get_full_path(path)
if not os.path.exists(full_path): if not os.path.exists(full_path):
@ -22,6 +27,7 @@ def get(path):
with open(full_path, 'rb') as handle: with open(full_path, 'rb') as handle:
return handle.read() return handle.read()
def save(path, content): def save(path, content):
full_path = _get_full_path(path) full_path = _get_full_path(path)
os.makedirs(os.path.dirname(full_path), exist_ok=True) os.makedirs(os.path.dirname(full_path), exist_ok=True)

View file

@ -6,11 +6,14 @@ import math
from szurubooru import errors from szurubooru import errors
from szurubooru.func import mime, util from szurubooru.func import mime, util
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_SCALE_FIT_FMT = \ _SCALE_FIT_FMT = \
r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)' r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)'
class Image(object): class Image(object):
def __init__(self, content): def __init__(self, content):
self.content = content self.content = content
@ -38,12 +41,13 @@ class Image(object):
'-', '-',
] ]
if 'duration' in self.info['format'] \ if 'duration' in self.info['format'] \
and float(self.info['format']['duration']) > 3 \
and self.info['format']['format_name'] != 'swf': and self.info['format']['format_name'] != 'swf':
cli = [ duration = float(self.info['format']['duration'])
'-ss', if duration > 3:
'%d' % math.floor(float(self.info['format']['duration']) * 0.3), cli = [
] + cli '-ss',
'%d' % math.floor(duration * 0.3),
] + cli
self.content = self._execute(cli) self.content = self._execute(cli)
assert self.content assert self.content
self._reload_info() self._reload_info()

View file

@ -2,6 +2,7 @@ import smtplib
import email.mime.text import email.mime.text
from szurubooru import config from szurubooru import config
def send_mail(sender, recipient, subject, body): def send_mail(sender, recipient, subject, body):
msg = email.mime.text.MIMEText(body) msg = email.mime.text.MIMEText(body)
msg['Subject'] = subject msg['Subject'] = subject

View file

@ -1,6 +1,6 @@
import re import re
# pylint: disable=too-many-return-statements
def get_mime_type(content): def get_mime_type(content):
if not content: if not content:
return 'application/octet-stream' return 'application/octet-stream'
@ -25,6 +25,7 @@ def get_mime_type(content):
return 'application/octet-stream' return 'application/octet-stream'
def get_extension(mime_type): def get_extension(mime_type):
extension_map = { extension_map = {
'application/x-shockwave-flash': 'swf', 'application/x-shockwave-flash': 'swf',
@ -37,15 +38,19 @@ def get_extension(mime_type):
} }
return extension_map.get((mime_type or '').strip().lower(), None) return extension_map.get((mime_type or '').strip().lower(), None)
def is_flash(mime_type): def is_flash(mime_type):
return mime_type.lower() == 'application/x-shockwave-flash' return mime_type.lower() == 'application/x-shockwave-flash'
def is_video(mime_type): def is_video(mime_type):
return mime_type.lower() in ('application/ogg', 'video/mp4', 'video/webm') return mime_type.lower() in ('application/ogg', 'video/mp4', 'video/webm')
def is_image(mime_type): def is_image(mime_type):
return mime_type.lower() in ('image/jpeg', 'image/png', 'image/gif') return mime_type.lower() in ('image/jpeg', 'image/png', 'image/gif')
def is_animated_gif(content): def is_animated_gif(content):
return get_mime_type(content) == 'image/gif' \ return get_mime_type(content) == 'image/gif' \
and len(re.findall(b'\x21\xF9\x04.{4}\x00[\x2C\x21]', content)) > 1 and len(re.findall(b'\x21\xF9\x04.{4}\x00[\x2C\x21]', content)) > 1

View file

@ -1,5 +1,6 @@
import urllib.request import urllib.request
def download(url): def download(url):
assert url assert url
request = urllib.request.Request(url) request = urllib.request.Request(url)

View file

@ -4,37 +4,71 @@ from szurubooru import config, db, errors
from szurubooru.func import ( from szurubooru.func import (
users, snapshots, scores, comments, tags, util, mime, images, files) users, snapshots, scores, comments, tags, util, mime, images, files)
EMPTY_PIXEL = \ EMPTY_PIXEL = \
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \ b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \ b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
class PostNotFoundError(errors.NotFoundError): pass
class PostAlreadyFeaturedError(errors.ValidationError): pass class PostNotFoundError(errors.NotFoundError):
class PostAlreadyUploadedError(errors.ValidationError): pass pass
class InvalidPostIdError(errors.ValidationError): pass
class InvalidPostSafetyError(errors.ValidationError): pass
class InvalidPostSourceError(errors.ValidationError): pass class PostAlreadyFeaturedError(errors.ValidationError):
class InvalidPostContentError(errors.ValidationError): pass pass
class InvalidPostRelationError(errors.ValidationError): pass
class InvalidPostNoteError(errors.ValidationError): pass
class InvalidPostFlagError(errors.ValidationError): pass class PostAlreadyUploadedError(errors.ValidationError):
pass
class InvalidPostIdError(errors.ValidationError):
pass
class InvalidPostSafetyError(errors.ValidationError):
pass
class InvalidPostSourceError(errors.ValidationError):
pass
class InvalidPostContentError(errors.ValidationError):
pass
class InvalidPostRelationError(errors.ValidationError):
pass
class InvalidPostNoteError(errors.ValidationError):
pass
class InvalidPostFlagError(errors.ValidationError):
pass
SAFETY_MAP = { SAFETY_MAP = {
db.Post.SAFETY_SAFE: 'safe', db.Post.SAFETY_SAFE: 'safe',
db.Post.SAFETY_SKETCHY: 'sketchy', db.Post.SAFETY_SKETCHY: 'sketchy',
db.Post.SAFETY_UNSAFE: 'unsafe', db.Post.SAFETY_UNSAFE: 'unsafe',
} }
TYPE_MAP = { TYPE_MAP = {
db.Post.TYPE_IMAGE: 'image', db.Post.TYPE_IMAGE: 'image',
db.Post.TYPE_ANIMATION: 'animation', db.Post.TYPE_ANIMATION: 'animation',
db.Post.TYPE_VIDEO: 'video', db.Post.TYPE_VIDEO: 'video',
db.Post.TYPE_FLASH: 'flash', db.Post.TYPE_FLASH: 'flash',
} }
FLAG_MAP = { FLAG_MAP = {
db.Post.FLAG_LOOP: 'loop', db.Post.FLAG_LOOP: 'loop',
} }
def get_post_content_url(post): def get_post_content_url(post):
assert post assert post
return '%s/posts/%d.%s' % ( return '%s/posts/%d.%s' % (
@ -42,25 +76,30 @@ def get_post_content_url(post):
post.post_id, post.post_id,
mime.get_extension(post.mime_type) or 'dat') mime.get_extension(post.mime_type) or 'dat')
def get_post_thumbnail_url(post): def get_post_thumbnail_url(post):
assert post assert post
return '%s/generated-thumbnails/%d.jpg' % ( return '%s/generated-thumbnails/%d.jpg' % (
config.config['data_url'].rstrip('/'), config.config['data_url'].rstrip('/'),
post.post_id) post.post_id)
def get_post_content_path(post): def get_post_content_path(post):
assert post assert post
return 'posts/%d.%s' % ( return 'posts/%d.%s' % (
post.post_id, mime.get_extension(post.mime_type) or 'dat') post.post_id, mime.get_extension(post.mime_type) or 'dat')
def get_post_thumbnail_path(post): def get_post_thumbnail_path(post):
assert post assert post
return 'generated-thumbnails/%d.jpg' % (post.post_id) return 'generated-thumbnails/%d.jpg' % (post.post_id)
def get_post_thumbnail_backup_path(post): def get_post_thumbnail_backup_path(post):
assert post assert post
return 'posts/custom-thumbnails/%d.dat' % (post.post_id) return 'posts/custom-thumbnails/%d.dat' % (post.post_id)
def serialize_note(note): def serialize_note(note):
assert note assert note
return { return {
@ -68,6 +107,7 @@ def serialize_note(note):
'text': note.text, 'text': note.text,
} }
def serialize_post(post, auth_user, options=None): def serialize_post(post, auth_user, options=None):
return util.serialize_entity( return util.serialize_entity(
post, post,
@ -93,17 +133,17 @@ def serialize_post(post, auth_user, options=None):
{ {
post['id']: post['id']:
post for post in [ post for post in [
serialize_micro_post(rel, auth_user) \ serialize_micro_post(rel, auth_user)
for rel in post.relations for rel in post.relations]
]
}.values(), }.values(),
key=lambda post: post['id']), key=lambda post: post['id']),
'user': lambda: users.serialize_micro_user(post.user, auth_user), 'user': lambda: users.serialize_micro_user(post.user, auth_user),
'score': lambda: post.score, 'score': lambda: post.score,
'ownScore': lambda: scores.get_score(post, auth_user), 'ownScore': lambda: scores.get_score(post, auth_user),
'ownFavorite': lambda: len( 'ownFavorite': lambda: len([
[user for user in post.favorited_by \ user for user in post.favorited_by
if user.user_id == auth_user.user_id]) > 0, if user.user_id == auth_user.user_id]
) > 0,
'tagCount': lambda: post.tag_count, 'tagCount': lambda: post.tag_count,
'favoriteCount': lambda: post.favorite_count, 'favoriteCount': lambda: post.favorite_count,
'commentCount': lambda: post.comment_count, 'commentCount': lambda: post.comment_count,
@ -112,31 +152,35 @@ def serialize_post(post, auth_user, options=None):
'featureCount': lambda: post.feature_count, 'featureCount': lambda: post.feature_count,
'lastFeatureTime': lambda: post.last_feature_time, 'lastFeatureTime': lambda: post.last_feature_time,
'favoritedBy': lambda: [ 'favoritedBy': lambda: [
users.serialize_micro_user(rel.user, auth_user) \ users.serialize_micro_user(rel.user, auth_user)
for rel in post.favorited_by], for rel in post.favorited_by
],
'hasCustomThumbnail': 'hasCustomThumbnail':
lambda: files.has(get_post_thumbnail_backup_path(post)), lambda: files.has(get_post_thumbnail_backup_path(post)),
'notes': lambda: sorted( 'notes': lambda: sorted(
[serialize_note(note) for note in post.notes], [serialize_note(note) for note in post.notes],
key=lambda x: x['polygon']), key=lambda x: x['polygon']),
'comments': lambda: [ 'comments': lambda: [
comments.serialize_comment(comment, auth_user) \ comments.serialize_comment(comment, auth_user)
for comment in sorted( for comment in sorted(
post.comments, post.comments,
key=lambda comment: comment.creation_time)], key=lambda comment: comment.creation_time)],
'snapshots': lambda: snapshots.get_serialized_history(post), 'snapshots': lambda: snapshots.get_serialized_history(post),
}, },
options) options)
def serialize_micro_post(post, auth_user): def serialize_micro_post(post, auth_user):
return serialize_post( return serialize_post(
post, post,
auth_user=auth_user, auth_user=auth_user,
options=['id', 'thumbnailUrl']) options=['id', 'thumbnailUrl'])
def get_post_count(): def get_post_count():
return db.session.query(sqlalchemy.func.count(db.Post.post_id)).one()[0] return db.session.query(sqlalchemy.func.count(db.Post.post_id)).one()[0]
def try_get_post_by_id(post_id): def try_get_post_by_id(post_id):
try: try:
post_id = int(post_id) post_id = int(post_id)
@ -147,22 +191,26 @@ def try_get_post_by_id(post_id):
.filter(db.Post.post_id == post_id) \ .filter(db.Post.post_id == post_id) \
.one_or_none() .one_or_none()
def get_post_by_id(post_id): def get_post_by_id(post_id):
post = try_get_post_by_id(post_id) post = try_get_post_by_id(post_id)
if not post: if not post:
raise PostNotFoundError('Post %r not found.' % post_id) raise PostNotFoundError('Post %r not found.' % post_id)
return post return post
def try_get_current_post_feature(): def try_get_current_post_feature():
return db.session \ return db.session \
.query(db.PostFeature) \ .query(db.PostFeature) \
.order_by(db.PostFeature.time.desc()) \ .order_by(db.PostFeature.time.desc()) \
.first() .first()
def try_get_featured_post(): def try_get_featured_post():
post_feature = try_get_current_post_feature() post_feature = try_get_current_post_feature()
return post_feature.post if post_feature else None return post_feature.post if post_feature else None
def create_post(content, tag_names, user): def create_post(content, tag_names, user):
post = db.Post() post = db.Post()
post.safety = db.Post.SAFETY_SAFE post.safety = db.Post.SAFETY_SAFE
@ -181,6 +229,7 @@ def create_post(content, tag_names, user):
new_tags = update_post_tags(post, tag_names) new_tags = update_post_tags(post, tag_names)
return (post, new_tags) return (post, new_tags)
def update_post_safety(post, safety): def update_post_safety(post, safety):
assert post assert post
safety = util.flip(SAFETY_MAP).get(safety, None) safety = util.flip(SAFETY_MAP).get(safety, None)
@ -189,12 +238,14 @@ def update_post_safety(post, safety):
'Safety can be either of %r.' % list(SAFETY_MAP.values())) 'Safety can be either of %r.' % list(SAFETY_MAP.values()))
post.safety = safety post.safety = safety
def update_post_source(post, source): def update_post_source(post, source):
assert post assert post
if util.value_exceeds_column_size(source, db.Post.source): if util.value_exceeds_column_size(source, db.Post.source):
raise InvalidPostSourceError('Source is too long.') raise InvalidPostSourceError('Source is too long.')
post.source = source post.source = source
def update_post_content(post, content): def update_post_content(post, content):
assert post assert post
if not content: if not content:
@ -210,7 +261,8 @@ def update_post_content(post, content):
elif mime.is_video(post.mime_type): elif mime.is_video(post.mime_type):
post.type = db.Post.TYPE_VIDEO post.type = db.Post.TYPE_VIDEO
else: else:
raise InvalidPostContentError('Unhandled file type: %r' % post.mime_type) raise InvalidPostContentError(
'Unhandled file type: %r' % post.mime_type)
post.checksum = util.get_md5(content) post.checksum = util.get_md5(content)
other_post = db.session \ other_post = db.session \
@ -236,6 +288,7 @@ def update_post_content(post, content):
files.save(get_post_content_path(post), content) files.save(get_post_content_path(post), content)
update_post_thumbnail(post, content=None, do_delete=False) update_post_thumbnail(post, content=None, do_delete=False)
def update_post_thumbnail(post, content=None, do_delete=True): def update_post_thumbnail(post, content=None, do_delete=True):
assert post assert post
if not content: if not content:
@ -246,6 +299,7 @@ def update_post_thumbnail(post, content=None, do_delete=True):
files.save(get_post_thumbnail_backup_path(post), content) files.save(get_post_thumbnail_backup_path(post), content)
generate_post_thumbnail(post) generate_post_thumbnail(post)
def generate_post_thumbnail(post): def generate_post_thumbnail(post):
assert post assert post
if files.has(get_post_thumbnail_backup_path(post)): if files.has(get_post_thumbnail_backup_path(post)):
@ -261,12 +315,14 @@ def generate_post_thumbnail(post):
except errors.ProcessingError: except errors.ProcessingError:
files.save(get_post_thumbnail_path(post), EMPTY_PIXEL) files.save(get_post_thumbnail_path(post), EMPTY_PIXEL)
def update_post_tags(post, tag_names): def update_post_tags(post, tag_names):
assert post assert post
existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names) existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
post.tags = existing_tags + new_tags post.tags = existing_tags + new_tags
return new_tags return new_tags
def update_post_relations(post, new_post_ids): def update_post_relations(post, new_post_ids):
assert post assert post
old_posts = post.relations old_posts = post.relations
@ -287,6 +343,7 @@ def update_post_relations(post, new_post_ids):
post.relations.append(relation) post.relations.append(relation)
relation.relations.append(post) relation.relations.append(post)
def update_post_notes(post, notes): def update_post_notes(post, notes):
assert post assert post
post.notes = [] post.notes = []
@ -323,6 +380,7 @@ def update_post_notes(post, notes):
post.notes.append( post.notes.append(
db.PostNote(polygon=note['polygon'], text=str(note['text']))) db.PostNote(polygon=note['polygon'], text=str(note['text'])))
def update_post_flags(post, flags): def update_post_flags(post, flags):
assert post assert post
target_flags = [] target_flags = []
@ -334,6 +392,7 @@ def update_post_flags(post, flags):
target_flags.append(flag) target_flags.append(flag)
post.flags = target_flags post.flags = target_flags
def feature_post(post, user): def feature_post(post, user):
assert post assert post
post_feature = db.PostFeature() post_feature = db.PostFeature()
@ -342,6 +401,7 @@ def feature_post(post, user):
post_feature.user = user post_feature.user = user
db.session.add(post_feature) db.session.add(post_feature)
def delete(post): def delete(post):
assert post assert post
db.session.delete(post) db.session.delete(post)

View file

@ -2,8 +2,14 @@ import datetime
from szurubooru import db, errors from szurubooru import db, errors
from szurubooru.func import favorites from szurubooru.func import favorites
class InvalidScoreTargetError(errors.ValidationError): pass
class InvalidScoreValueError(errors.ValidationError): pass class InvalidScoreTargetError(errors.ValidationError):
pass
class InvalidScoreValueError(errors.ValidationError):
pass
def _get_table_info(entity): def _get_table_info(entity):
assert entity assert entity
@ -14,10 +20,12 @@ def _get_table_info(entity):
return db.CommentScore, lambda table: table.comment_id return db.CommentScore, lambda table: table.comment_id
raise InvalidScoreTargetError() raise InvalidScoreTargetError()
def _get_score_entity(entity, user): def _get_score_entity(entity, user):
assert user assert user
return db.util.get_aux_entity(db.session, _get_table_info, entity, user) return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
def delete_score(entity, user): def delete_score(entity, user):
assert entity assert entity
assert user assert user
@ -25,6 +33,7 @@ def delete_score(entity, user):
if score_entity: if score_entity:
db.session.delete(score_entity) db.session.delete(score_entity)
def get_score(entity, user): def get_score(entity, user):
assert entity assert entity
assert user assert user
@ -36,6 +45,7 @@ def get_score(entity, user):
.one_or_none() .one_or_none()
return row[0] if row else 0 return row[0] if row else 0
def set_score(entity, user, score): def set_score(entity, user, score):
assert entity assert entity
assert user assert user

View file

@ -1,6 +1,7 @@
import datetime import datetime
from szurubooru import db from szurubooru import db
def get_tag_snapshot(tag): def get_tag_snapshot(tag):
return { return {
'names': [tag_name.name for tag_name in tag.names], 'names': [tag_name.name for tag_name in tag.names],
@ -9,6 +10,7 @@ def get_tag_snapshot(tag):
'implications': sorted(rel.first_name for rel in tag.implications), 'implications': sorted(rel.first_name for rel in tag.implications),
} }
def get_post_snapshot(post): def get_post_snapshot(post):
return { return {
'source': post.source, 'source': post.source,
@ -25,6 +27,7 @@ def get_post_snapshot(post):
'featured': post.is_featured, 'featured': post.is_featured,
} }
def get_tag_category_snapshot(category): def get_tag_category_snapshot(category):
return { return {
'name': category.name, 'name': category.name,
@ -32,6 +35,7 @@ def get_tag_category_snapshot(category):
'default': True if category.default else False, 'default': True if category.default else False,
} }
def get_previous_snapshot(snapshot): def get_previous_snapshot(snapshot):
assert snapshot assert snapshot
return db.session \ return db.session \
@ -43,6 +47,7 @@ def get_previous_snapshot(snapshot):
.limit(1) \ .limit(1) \
.first() .first()
def get_snapshots(entity): def get_snapshots(entity):
assert entity assert entity
resource_type, resource_id, _ = db.util.get_resource_info(entity) resource_type, resource_id, _ = db.util.get_resource_info(entity)
@ -53,6 +58,7 @@ def get_snapshots(entity):
.order_by(db.Snapshot.creation_time.desc()) \ .order_by(db.Snapshot.creation_time.desc()) \
.all() .all()
def serialize_snapshot(snapshot, earlier_snapshot=()): def serialize_snapshot(snapshot, earlier_snapshot=()):
assert snapshot assert snapshot
if earlier_snapshot is (): if earlier_snapshot is ():
@ -67,6 +73,7 @@ def serialize_snapshot(snapshot, earlier_snapshot=()):
'time': snapshot.creation_time, 'time': snapshot.creation_time,
} }
def get_serialized_history(entity): def get_serialized_history(entity):
if not entity: if not entity:
return [] return []
@ -77,6 +84,7 @@ def get_serialized_history(entity):
earlier_snapshot = snapshot earlier_snapshot = snapshot
return ret return ret
def _save(operation, entity, auth_user): def _save(operation, entity, auth_user):
assert operation assert operation
assert entity assert entity
@ -86,7 +94,8 @@ def _save(operation, entity, auth_user):
'post': get_post_snapshot, 'post': get_post_snapshot,
} }
resource_type, resource_id, resource_repr = db.util.get_resource_info(entity) resource_type, resource_id, resource_repr = (
db.util.get_resource_info(entity))
now = datetime.datetime.utcnow() now = datetime.datetime.utcnow()
snapshot = db.Snapshot() snapshot = db.Snapshot()
@ -118,14 +127,17 @@ def _save(operation, entity, auth_user):
else: else:
db.session.add(snapshot) db.session.add(snapshot)
def save_entity_creation(entity, auth_user): def save_entity_creation(entity, auth_user):
assert entity assert entity
_save(db.Snapshot.OPERATION_CREATED, entity, auth_user) _save(db.Snapshot.OPERATION_CREATED, entity, auth_user)
def save_entity_modification(entity, auth_user): def save_entity_modification(entity, auth_user):
assert entity assert entity
_save(db.Snapshot.OPERATION_MODIFIED, entity, auth_user) _save(db.Snapshot.OPERATION_MODIFIED, entity, auth_user)
def save_entity_deletion(entity, auth_user): def save_entity_deletion(entity, auth_user):
assert entity assert entity
_save(db.Snapshot.OPERATION_DELETED, entity, auth_user) _save(db.Snapshot.OPERATION_DELETED, entity, auth_user)

View file

@ -3,11 +3,26 @@ import sqlalchemy
from szurubooru import config, db, errors from szurubooru import config, db, errors
from szurubooru.func import util, snapshots, cache from szurubooru.func import util, snapshots, cache
class TagCategoryNotFoundError(errors.NotFoundError): pass
class TagCategoryAlreadyExistsError(errors.ValidationError): pass class TagCategoryNotFoundError(errors.NotFoundError):
class TagCategoryIsInUseError(errors.ValidationError): pass pass
class InvalidTagCategoryNameError(errors.ValidationError): pass
class InvalidTagCategoryColorError(errors.ValidationError): pass
class TagCategoryAlreadyExistsError(errors.ValidationError):
pass
class TagCategoryIsInUseError(errors.ValidationError):
pass
class InvalidTagCategoryNameError(errors.ValidationError):
pass
class InvalidTagCategoryColorError(errors.ValidationError):
pass
def _verify_name_validity(name): def _verify_name_validity(name):
name_regex = config.config['tag_category_name_regex'] name_regex = config.config['tag_category_name_regex']
@ -15,6 +30,7 @@ def _verify_name_validity(name):
raise InvalidTagCategoryNameError( raise InvalidTagCategoryNameError(
'Name must satisfy regex %r.' % name_regex) 'Name must satisfy regex %r.' % name_regex)
def serialize_category(category, options=None): def serialize_category(category, options=None):
return util.serialize_entity( return util.serialize_entity(
category, category,
@ -28,6 +44,7 @@ def serialize_category(category, options=None):
}, },
options) options)
def create_category(name, color): def create_category(name, color):
category = db.TagCategory() category = db.TagCategory()
update_category_name(category, name) update_category_name(category, name)
@ -36,13 +53,15 @@ def create_category(name, color):
category.default = True category.default = True
return category return category
def update_category_name(category, name): def update_category_name(category, name):
assert category assert category
if not name: if not name:
raise InvalidTagCategoryNameError('Name cannot be empty.') raise InvalidTagCategoryNameError('Name cannot be empty.')
expr = sqlalchemy.func.lower(db.TagCategory.name) == name.lower() expr = sqlalchemy.func.lower(db.TagCategory.name) == name.lower()
if category.tag_category_id: if category.tag_category_id:
expr = expr & (db.TagCategory.tag_category_id != category.tag_category_id) expr = expr & (
db.TagCategory.tag_category_id != category.tag_category_id)
already_exists = db.session.query(db.TagCategory).filter(expr).count() > 0 already_exists = db.session.query(db.TagCategory).filter(expr).count() > 0
if already_exists: if already_exists:
raise TagCategoryAlreadyExistsError( raise TagCategoryAlreadyExistsError(
@ -52,6 +71,7 @@ def update_category_name(category, name):
_verify_name_validity(name) _verify_name_validity(name)
category.name = name category.name = name
def update_category_color(category, color): def update_category_color(category, color):
assert category assert category
if not color: if not color:
@ -62,24 +82,29 @@ def update_category_color(category, color):
raise InvalidTagCategoryColorError('Color is too long.') raise InvalidTagCategoryColorError('Color is too long.')
category.color = color category.color = color
def try_get_category_by_name(name): def try_get_category_by_name(name):
return db.session \ return db.session \
.query(db.TagCategory) \ .query(db.TagCategory) \
.filter(sqlalchemy.func.lower(db.TagCategory.name) == name.lower()) \ .filter(sqlalchemy.func.lower(db.TagCategory.name) == name.lower()) \
.one_or_none() .one_or_none()
def get_category_by_name(name): def get_category_by_name(name):
category = try_get_category_by_name(name) category = try_get_category_by_name(name)
if not category: if not category:
raise TagCategoryNotFoundError('Tag category %r not found.' % name) raise TagCategoryNotFoundError('Tag category %r not found.' % name)
return category return category
def get_all_category_names(): def get_all_category_names():
return [row[0] for row in db.session.query(db.TagCategory.name).all()] return [row[0] for row in db.session.query(db.TagCategory.name).all()]
def get_all_categories(): def get_all_categories():
return db.session.query(db.TagCategory).all() return db.session.query(db.TagCategory).all()
def try_get_default_category(): def try_get_default_category():
key = 'default-tag-category' key = 'default-tag-category'
if cache.has(key): if cache.has(key):
@ -98,12 +123,14 @@ def try_get_default_category():
cache.put(key, category) cache.put(key, category)
return category return category
def get_default_category(): def get_default_category():
category = try_get_default_category() category = try_get_default_category()
if not category: if not category:
raise TagCategoryNotFoundError('No tag category created yet.') raise TagCategoryNotFoundError('No tag category created yet.')
return category return category
def set_default_category(category): def set_default_category(category):
assert category assert category
old_category = try_get_default_category() old_category = try_get_default_category()
@ -111,6 +138,7 @@ def set_default_category(category):
old_category.default = False old_category.default = False
category.default = True category.default = True
def delete_category(category): def delete_category(category):
assert category assert category
if len(get_all_category_names()) == 1: if len(get_all_category_names()) == 1:

View file

@ -6,32 +6,57 @@ import sqlalchemy
from szurubooru import config, db, errors from szurubooru import config, db, errors
from szurubooru.func import util, tag_categories, snapshots from szurubooru.func import util, tag_categories, snapshots
class TagNotFoundError(errors.NotFoundError): pass
class TagAlreadyExistsError(errors.ValidationError): pass class TagNotFoundError(errors.NotFoundError):
class TagIsInUseError(errors.ValidationError): pass pass
class InvalidTagNameError(errors.ValidationError): pass
class InvalidTagRelationError(errors.ValidationError): pass
class InvalidTagCategoryError(errors.ValidationError): pass class TagAlreadyExistsError(errors.ValidationError):
class InvalidTagDescriptionError(errors.ValidationError): pass pass
class TagIsInUseError(errors.ValidationError):
pass
class InvalidTagNameError(errors.ValidationError):
pass
class InvalidTagRelationError(errors.ValidationError):
pass
class InvalidTagCategoryError(errors.ValidationError):
pass
class InvalidTagDescriptionError(errors.ValidationError):
pass
def _verify_name_validity(name): def _verify_name_validity(name):
name_regex = config.config['tag_name_regex'] name_regex = config.config['tag_name_regex']
if not re.match(name_regex, name): if not re.match(name_regex, name):
raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex) raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex)
def _get_plain_names(tag):
def _get_names(tag):
assert tag assert tag
return [tag_name.name for tag_name in tag.names] return [tag_name.name for tag_name in tag.names]
def _lower_list(names): def _lower_list(names):
return [name.lower() for name in names] return [name.lower() for name in names]
def _check_name_intersection(names1, names2):
return len(set(_lower_list(names1)).intersection(_lower_list(names2))) > 0
def _check_name_intersection_case_sensitive(names1, names2): def _check_name_intersection(names1, names2, case_sensitive):
if not case_sensitive:
names1 = _lower_list(names1)
names2 = _lower_list(names2)
return len(set(names1).intersection(names2)) > 0 return len(set(names1).intersection(names2)) > 0
def sort_tags(tags): def sort_tags(tags):
default_category = tag_categories.try_get_default_category() default_category = tag_categories.try_get_default_category()
default_category_name = default_category.name if default_category else None default_category_name = default_category.name if default_category else None
@ -43,6 +68,7 @@ def sort_tags(tags):
tag.names[0].name) tag.names[0].name)
) )
def serialize_tag(tag, options=None): def serialize_tag(tag, options=None):
return util.serialize_entity( return util.serialize_entity(
tag, tag,
@ -55,15 +81,16 @@ def serialize_tag(tag, options=None):
'lastEditTime': lambda: tag.last_edit_time, 'lastEditTime': lambda: tag.last_edit_time,
'usages': lambda: tag.post_count, 'usages': lambda: tag.post_count,
'suggestions': lambda: [ 'suggestions': lambda: [
relation.names[0].name \ relation.names[0].name
for relation in sort_tags(tag.suggestions)], for relation in sort_tags(tag.suggestions)],
'implications': lambda: [ 'implications': lambda: [
relation.names[0].name \ relation.names[0].name
for relation in sort_tags(tag.implications)], for relation in sort_tags(tag.implications)],
'snapshots': lambda: snapshots.get_serialized_history(tag), 'snapshots': lambda: snapshots.get_serialized_history(tag),
}, },
options) options)
def export_to_json(): def export_to_json():
tags = {} tags = {}
categories = {} categories = {}
@ -82,19 +109,19 @@ def export_to_json():
tags[result[0]] = {'names': []} tags[result[0]] = {'names': []}
tags[result[0]]['names'].append(result[1]) tags[result[0]]['names'].append(result[1])
for result in db.session \ for result in (db.session
.query(db.TagSuggestion.parent_id, db.TagName.name) \ .query(db.TagSuggestion.parent_id, db.TagName.name)
.join(db.TagName, db.TagName.tag_id == db.TagSuggestion.child_id) \ .join(db.TagName, db.TagName.tag_id == db.TagSuggestion.child_id)
.all(): .all()):
if not 'suggestions' in tags[result[0]]: if 'suggestions' not in tags[result[0]]:
tags[result[0]]['suggestions'] = [] tags[result[0]]['suggestions'] = []
tags[result[0]]['suggestions'].append(result[1]) tags[result[0]]['suggestions'].append(result[1])
for result in db.session \ for result in (db.session
.query(db.TagImplication.parent_id, db.TagName.name) \ .query(db.TagImplication.parent_id, db.TagName.name)
.join(db.TagName, db.TagName.tag_id == db.TagImplication.child_id) \ .join(db.TagName, db.TagName.tag_id == db.TagImplication.child_id)
.all(): .all()):
if not 'implications' in tags[result[0]]: if 'implications' not in tags[result[0]]:
tags[result[0]]['implications'] = [] tags[result[0]]['implications'] = []
tags[result[0]]['implications'].append(result[1]) tags[result[0]]['implications'].append(result[1])
@ -114,12 +141,14 @@ def export_to_json():
with open(export_path, 'w') as handle: with open(export_path, 'w') as handle:
handle.write(json.dumps(output, separators=(',', ':'))) handle.write(json.dumps(output, separators=(',', ':')))
def try_get_tag_by_name(name): def try_get_tag_by_name(name):
return db.session \ return (db.session
.query(db.Tag) \ .query(db.Tag)
.join(db.TagName) \ .join(db.TagName)
.filter(sqlalchemy.func.lower(db.TagName.name) == name.lower()) \ .filter(sqlalchemy.func.lower(db.TagName.name) == name.lower())
.one_or_none() .one_or_none())
def get_tag_by_name(name): def get_tag_by_name(name):
tag = try_get_tag_by_name(name) tag = try_get_tag_by_name(name)
@ -127,6 +156,7 @@ def get_tag_by_name(name):
raise TagNotFoundError('Tag %r not found.' % name) raise TagNotFoundError('Tag %r not found.' % name)
return tag return tag
def get_tags_by_names(names): def get_tags_by_names(names):
names = util.icase_unique(names) names = util.icase_unique(names)
if len(names) == 0: if len(names) == 0:
@ -136,6 +166,7 @@ def get_tags_by_names(names):
expr = expr | (sqlalchemy.func.lower(db.TagName.name) == name.lower()) expr = expr | (sqlalchemy.func.lower(db.TagName.name) == name.lower())
return db.session.query(db.Tag).join(db.TagName).filter(expr).all() return db.session.query(db.Tag).join(db.TagName).filter(expr).all()
def get_or_create_tags_by_names(names): def get_or_create_tags_by_names(names):
names = util.icase_unique(names) names = util.icase_unique(names)
existing_tags = get_tags_by_names(names) existing_tags = get_tags_by_names(names)
@ -144,7 +175,8 @@ def get_or_create_tags_by_names(names):
for name in names: for name in names:
found = False found = False
for existing_tag in existing_tags: for existing_tag in existing_tags:
if _check_name_intersection(_get_plain_names(existing_tag), [name]): if _check_name_intersection(
_get_names(existing_tag), [name], False):
found = True found = True
break break
if not found: if not found:
@ -157,32 +189,35 @@ def get_or_create_tags_by_names(names):
new_tags.append(new_tag) new_tags.append(new_tag)
return existing_tags, new_tags return existing_tags, new_tags
def get_tag_siblings(tag): def get_tag_siblings(tag):
assert tag assert tag
tag_alias = sqlalchemy.orm.aliased(db.Tag) tag_alias = sqlalchemy.orm.aliased(db.Tag)
pt_alias1 = sqlalchemy.orm.aliased(db.PostTag) pt_alias1 = sqlalchemy.orm.aliased(db.PostTag)
pt_alias2 = sqlalchemy.orm.aliased(db.PostTag) pt_alias2 = sqlalchemy.orm.aliased(db.PostTag)
result = db.session \ result = (db.session
.query(tag_alias, sqlalchemy.func.count(pt_alias2.post_id)) \ .query(tag_alias, sqlalchemy.func.count(pt_alias2.post_id))
.join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id) \ .join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id)
.join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id) \ .join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id)
.filter(pt_alias2.tag_id == tag.tag_id) \ .filter(pt_alias2.tag_id == tag.tag_id)
.filter(pt_alias1.tag_id != tag.tag_id) \ .filter(pt_alias1.tag_id != tag.tag_id)
.group_by(tag_alias.tag_id) \ .group_by(tag_alias.tag_id)
.order_by(sqlalchemy.func.count(pt_alias2.post_id).desc()) \ .order_by(sqlalchemy.func.count(pt_alias2.post_id).desc())
.limit(50) .limit(50))
return result return result
def delete(source_tag): def delete(source_tag):
assert source_tag assert source_tag
db.session.execute( db.session.execute(
sqlalchemy.sql.expression.delete(db.TagSuggestion) \ sqlalchemy.sql.expression.delete(db.TagSuggestion)
.where(db.TagSuggestion.child_id == source_tag.tag_id)) .where(db.TagSuggestion.child_id == source_tag.tag_id))
db.session.execute( db.session.execute(
sqlalchemy.sql.expression.delete(db.TagImplication) \ sqlalchemy.sql.expression.delete(db.TagImplication)
.where(db.TagImplication.child_id == source_tag.tag_id)) .where(db.TagImplication.child_id == source_tag.tag_id))
db.session.delete(source_tag) db.session.delete(source_tag)
def merge_tags(source_tag, target_tag): def merge_tags(source_tag, target_tag):
assert source_tag assert source_tag
assert target_tag assert target_tag
@ -191,15 +226,16 @@ def merge_tags(source_tag, target_tag):
pt1 = db.PostTag pt1 = db.PostTag
pt2 = sqlalchemy.orm.util.aliased(db.PostTag) pt2 = sqlalchemy.orm.util.aliased(db.PostTag)
update_stmt = sqlalchemy.sql.expression.update(pt1) \ update_stmt = (sqlalchemy.sql.expression.update(pt1)
.where(db.PostTag.tag_id == source_tag.tag_id) \ .where(db.PostTag.tag_id == source_tag.tag_id)
.where(~sqlalchemy.exists() \ .where(~sqlalchemy.exists()
.where(pt2.post_id == pt1.post_id) \ .where(pt2.post_id == pt1.post_id)
.where(pt2.tag_id == target_tag.tag_id)) \ .where(pt2.tag_id == target_tag.tag_id))
.values(tag_id=target_tag.tag_id) .values(tag_id=target_tag.tag_id))
db.session.execute(update_stmt) db.session.execute(update_stmt)
delete(source_tag) delete(source_tag)
def create_tag(names, category_name, suggestions, implications): def create_tag(names, category_name, suggestions, implications):
tag = db.Tag() tag = db.Tag()
tag.creation_time = datetime.datetime.utcnow() tag.creation_time = datetime.datetime.utcnow()
@ -209,10 +245,12 @@ def create_tag(names, category_name, suggestions, implications):
update_tag_implications(tag, implications) update_tag_implications(tag, implications)
return tag return tag
def update_tag_category_name(tag, category_name): def update_tag_category_name(tag, category_name):
assert tag assert tag
tag.category = tag_categories.get_category_by_name(category_name) tag.category = tag_categories.get_category_by_name(category_name)
def update_tag_names(tag, names): def update_tag_names(tag, names):
assert tag assert tag
names = util.icase_unique([name for name in names if name]) names = util.icase_unique([name for name in names if name])
@ -232,26 +270,29 @@ def update_tag_names(tag, names):
raise TagAlreadyExistsError( raise TagAlreadyExistsError(
'One of names is already used by another tag.') 'One of names is already used by another tag.')
for tag_name in tag.names[:]: for tag_name in tag.names[:]:
if not _check_name_intersection_case_sensitive([tag_name.name], names): if not _check_name_intersection([tag_name.name], names, True):
tag.names.remove(tag_name) tag.names.remove(tag_name)
for name in names: for name in names:
if not _check_name_intersection_case_sensitive(_get_plain_names(tag), [name]): if not _check_name_intersection(_get_names(tag), [name], True):
tag.names.append(db.TagName(name)) tag.names.append(db.TagName(name))
# TODO: what to do with relations that do not yet exist? # TODO: what to do with relations that do not yet exist?
def update_tag_implications(tag, relations): def update_tag_implications(tag, relations):
assert tag assert tag
if _check_name_intersection(_get_plain_names(tag), relations): if _check_name_intersection(_get_names(tag), relations, False):
raise InvalidTagRelationError('Tag cannot imply itself.') raise InvalidTagRelationError('Tag cannot imply itself.')
tag.implications = get_tags_by_names(relations) tag.implications = get_tags_by_names(relations)
# TODO: what to do with relations that do not yet exist? # TODO: what to do with relations that do not yet exist?
def update_tag_suggestions(tag, relations): def update_tag_suggestions(tag, relations):
assert tag assert tag
if _check_name_intersection(_get_plain_names(tag), relations): if _check_name_intersection(_get_names(tag), relations, False):
raise InvalidTagRelationError('Tag cannot suggest itself.') raise InvalidTagRelationError('Tag cannot suggest itself.')
tag.suggestions = get_tags_by_names(relations) tag.suggestions = get_tags_by_names(relations)
def update_tag_description(tag, description): def update_tag_description(tag, description):
assert tag assert tag
if util.value_exceeds_column_size(description, db.Tag.description): if util.value_exceeds_column_size(description, db.Tag.description):

View file

@ -4,17 +4,39 @@ from sqlalchemy import func
from szurubooru import config, db, errors from szurubooru import config, db, errors
from szurubooru.func import auth, util, files, images from szurubooru.func import auth, util, files, images
class UserNotFoundError(errors.NotFoundError): pass
class UserAlreadyExistsError(errors.ValidationError): pass class UserNotFoundError(errors.NotFoundError):
class InvalidUserNameError(errors.ValidationError): pass pass
class InvalidEmailError(errors.ValidationError): pass
class InvalidPasswordError(errors.ValidationError): pass
class InvalidRankError(errors.ValidationError): pass class UserAlreadyExistsError(errors.ValidationError):
class InvalidAvatarError(errors.ValidationError): pass pass
class InvalidUserNameError(errors.ValidationError):
pass
class InvalidEmailError(errors.ValidationError):
pass
class InvalidPasswordError(errors.ValidationError):
pass
class InvalidRankError(errors.ValidationError):
pass
class InvalidAvatarError(errors.ValidationError):
pass
def get_avatar_path(user_name): def get_avatar_path(user_name):
return 'avatars/' + user_name.lower() + '.png' return 'avatars/' + user_name.lower() + '.png'
def get_avatar_url(user): def get_avatar_url(user):
assert user assert user
if user.avatar_style == user.AVATAR_GRAVATAR: if user.avatar_style == user.AVATAR_GRAVATAR:
@ -27,6 +49,7 @@ def get_avatar_url(user):
return '%s/avatars/%s.png' % ( return '%s/avatars/%s.png' % (
config.config['data_url'].rstrip('/'), user.name.lower()) config.config['data_url'].rstrip('/'), user.name.lower())
def get_email(user, auth_user, force_show_email): def get_email(user, auth_user, force_show_email):
assert user assert user
assert auth_user assert auth_user
@ -36,6 +59,7 @@ def get_email(user, auth_user, force_show_email):
return False return False
return user.email return user.email
def get_liked_post_count(user, auth_user): def get_liked_post_count(user, auth_user):
assert user assert user
assert auth_user assert auth_user
@ -43,6 +67,7 @@ def get_liked_post_count(user, auth_user):
return False return False
return user.liked_post_count return user.liked_post_count
def get_disliked_post_count(user, auth_user): def get_disliked_post_count(user, auth_user):
assert user assert user
assert auth_user assert auth_user
@ -50,6 +75,7 @@ def get_disliked_post_count(user, auth_user):
return False return False
return user.disliked_post_count return user.disliked_post_count
def serialize_user(user, auth_user, options=None, force_show_email=False): def serialize_user(user, auth_user, options=None, force_show_email=False):
return util.serialize_entity( return util.serialize_entity(
user, user,
@ -73,34 +99,40 @@ def serialize_user(user, auth_user, options=None, force_show_email=False):
}, },
options) options)
def serialize_micro_user(user, auth_user): def serialize_micro_user(user, auth_user):
return serialize_user( return serialize_user(
user, user,
auth_user=auth_user, auth_user=auth_user,
options=['name', 'avatarUrl']) options=['name', 'avatarUrl'])
def get_user_count(): def get_user_count():
return db.session.query(db.User).count() return db.session.query(db.User).count()
def try_get_user_by_name(name): def try_get_user_by_name(name):
return db.session \ return db.session \
.query(db.User) \ .query(db.User) \
.filter(func.lower(db.User.name) == func.lower(name)) \ .filter(func.lower(db.User.name) == func.lower(name)) \
.one_or_none() .one_or_none()
def get_user_by_name(name): def get_user_by_name(name):
user = try_get_user_by_name(name) user = try_get_user_by_name(name)
if not user: if not user:
raise UserNotFoundError('User %r not found.' % name) raise UserNotFoundError('User %r not found.' % name)
return user return user
def try_get_user_by_name_or_email(name_or_email): def try_get_user_by_name_or_email(name_or_email):
return db.session \ return (db.session
.query(db.User) \ .query(db.User)
.filter( .filter(
(func.lower(db.User.name) == func.lower(name_or_email)) (func.lower(db.User.name) == func.lower(name_or_email)) |
| (func.lower(db.User.email) == func.lower(name_or_email))) \ (func.lower(db.User.email) == func.lower(name_or_email)))
.one_or_none() .one_or_none())
def get_user_by_name_or_email(name_or_email): def get_user_by_name_or_email(name_or_email):
user = try_get_user_by_name_or_email(name_or_email) user = try_get_user_by_name_or_email(name_or_email)
@ -108,6 +140,7 @@ def get_user_by_name_or_email(name_or_email):
raise UserNotFoundError('User %r not found.' % name_or_email) raise UserNotFoundError('User %r not found.' % name_or_email)
return user return user
def create_user(name, password, email): def create_user(name, password, email):
user = db.User() user = db.User()
update_user_name(user, name) update_user_name(user, name)
@ -121,6 +154,7 @@ def create_user(name, password, email):
user.avatar_style = db.User.AVATAR_GRAVATAR user.avatar_style = db.User.AVATAR_GRAVATAR
return user return user
def update_user_name(user, name): def update_user_name(user, name):
assert user assert user
if not name: if not name:
@ -139,6 +173,7 @@ def update_user_name(user, name):
files.move(get_avatar_path(user.name), get_avatar_path(name)) files.move(get_avatar_path(user.name), get_avatar_path(name))
user.name = name user.name = name
def update_user_password(user, password): def update_user_password(user, password):
assert user assert user
if not password: if not password:
@ -150,6 +185,7 @@ def update_user_password(user, password):
user.password_salt = auth.create_password() user.password_salt = auth.create_password()
user.password_hash = auth.get_password_hash(user.password_salt, password) user.password_hash = auth.get_password_hash(user.password_salt, password)
def update_user_email(user, email): def update_user_email(user, email):
assert user assert user
if email: if email:
@ -162,6 +198,7 @@ def update_user_email(user, email):
raise InvalidEmailError('E-mail is invalid.') raise InvalidEmailError('E-mail is invalid.')
user.email = email user.email = email
def update_user_rank(user, rank, auth_user): def update_user_rank(user, rank, auth_user):
assert user assert user
if not rank: if not rank:
@ -178,6 +215,7 @@ def update_user_rank(user, rank, auth_user):
raise errors.AuthError('Trying to set higher rank than your own.') raise errors.AuthError('Trying to set higher rank than your own.')
user.rank = rank user.rank = rank
def update_user_avatar(user, avatar_style, avatar_content=None): def update_user_avatar(user, avatar_style, avatar_content=None):
assert user assert user
if avatar_style == 'gravatar': if avatar_style == 'gravatar':
@ -199,10 +237,12 @@ def update_user_avatar(user, avatar_style, avatar_content=None):
'Avatar style %r is invalid. Valid avatar styles: %r.' % ( 'Avatar style %r is invalid. Valid avatar styles: %r.' % (
avatar_style, ['gravatar', 'manual'])) avatar_style, ['gravatar', 'manual']))
def bump_user_login_time(user): def bump_user_login_time(user):
assert user assert user
user.last_login_time = datetime.datetime.utcnow() user.last_login_time = datetime.datetime.utcnow()
def reset_user_password(user): def reset_user_password(user):
assert user assert user
password = auth.create_password() password = auth.create_password()

View file

@ -1,18 +1,22 @@
import os import os
import datetime
import hashlib import hashlib
import re import re
import tempfile import tempfile
from datetime import datetime, timedelta
from contextlib import contextmanager from contextlib import contextmanager
from szurubooru import errors from szurubooru import errors
def snake_case_to_lower_camel_case(text): def snake_case_to_lower_camel_case(text):
components = text.split('_') components = text.split('_')
return components[0].lower() + \ return components[0].lower() + \
''.join(word[0].upper() + word[1:].lower() for word in components[1:]) ''.join(word[0].upper() + word[1:].lower() for word in components[1:])
def snake_case_to_upper_train_case(text): def snake_case_to_upper_train_case(text):
return '-'.join(word[0].upper() + word[1:].lower() for word in text.split('_')) return '-'.join(
word[0].upper() + word[1:].lower() for word in text.split('_'))
def snake_case_to_lower_camel_case_keys(source): def snake_case_to_lower_camel_case_keys(source):
target = {} target = {}
@ -20,9 +24,11 @@ def snake_case_to_lower_camel_case_keys(source):
target[snake_case_to_lower_camel_case(key)] = value target[snake_case_to_lower_camel_case(key)] = value
return target return target
def get_serialization_options(ctx): def get_serialization_options(ctx):
return ctx.get_param_as_list('fields', required=False, default=None) return ctx.get_param_as_list('fields', required=False, default=None)
def serialize_entity(entity, field_factories, options): def serialize_entity(entity, field_factories, options):
if not entity: if not entity:
return None return None
@ -30,13 +36,14 @@ def serialize_entity(entity, field_factories, options):
options = field_factories.keys() options = field_factories.keys()
ret = {} ret = {}
for key in options: for key in options:
if not key in field_factories: if key not in field_factories:
raise errors.ValidationError('Invalid key: %r. Valid keys: %r.' % ( raise errors.ValidationError('Invalid key: %r. Valid keys: %r.' % (
key, list(sorted(field_factories.keys())))) key, list(sorted(field_factories.keys()))))
factory = field_factories[key] factory = field_factories[key]
ret[key] = factory() ret[key] = factory()
return ret return ret
@contextmanager @contextmanager
def create_temp_file(**kwargs): def create_temp_file(**kwargs):
(handle, path) = tempfile.mkstemp(**kwargs) (handle, path) = tempfile.mkstemp(**kwargs)
@ -47,6 +54,7 @@ def create_temp_file(**kwargs):
finally: finally:
os.remove(path) os.remove(path)
def unalias_dict(input_dict): def unalias_dict(input_dict):
output_dict = {} output_dict = {}
for key_list, value in input_dict.items(): for key_list, value in input_dict.items():
@ -56,6 +64,7 @@ def unalias_dict(input_dict):
output_dict[key] = value output_dict[key] = value
return output_dict return output_dict
def get_md5(source): def get_md5(source):
if not isinstance(source, bytes): if not isinstance(source, bytes):
source = source.encode('utf-8') source = source.encode('utf-8')
@ -63,57 +72,58 @@ def get_md5(source):
md5.update(source) md5.update(source)
return md5.hexdigest() return md5.hexdigest()
def flip(source): def flip(source):
return {v: k for k, v in source.items()} return {v: k for k, v in source.items()}
def is_valid_email(email): def is_valid_email(email):
''' Return whether given email address is valid or empty. ''' ''' Return whether given email address is valid or empty. '''
return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email) return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email)
class dotdict(dict): # pylint: disable=invalid-name
class dotdict(dict): # pylint: disable=invalid-name
''' dot.notation access to dictionary attributes. ''' ''' dot.notation access to dictionary attributes. '''
def __getattr__(self, attr): def __getattr__(self, attr):
return self.get(attr) return self.get(attr)
__setattr__ = dict.__setitem__ __setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__ __delattr__ = dict.__delitem__
def parse_time_range(value): def parse_time_range(value):
''' Return tuple containing min/max time for given text representation. ''' ''' Return tuple containing min/max time for given text representation. '''
one_day = datetime.timedelta(days=1) one_day = timedelta(days=1)
one_second = datetime.timedelta(seconds=1) one_second = timedelta(seconds=1)
value = value.lower() value = value.lower()
if not value: if not value:
raise errors.ValidationError('Empty date format.') raise errors.ValidationError('Empty date format.')
if value == 'today': if value == 'today':
now = datetime.datetime.utcnow() now = datetime.utcnow()
return ( return (
datetime.datetime(now.year, now.month, now.day, 0, 0, 0), datetime(now.year, now.month, now.day, 0, 0, 0),
datetime.datetime(now.year, now.month, now.day, 0, 0, 0) \ datetime(now.year, now.month, now.day, 0, 0, 0)
+ one_day - one_second) + one_day - one_second)
if value == 'yesterday': if value == 'yesterday':
now = datetime.datetime.utcnow() now = datetime.utcnow()
return ( return (
datetime.datetime(now.year, now.month, now.day, 0, 0, 0) - one_day, datetime(now.year, now.month, now.day, 0, 0, 0) - one_day,
datetime.datetime(now.year, now.month, now.day, 0, 0, 0) \ datetime(now.year, now.month, now.day, 0, 0, 0) - one_second)
- one_second)
match = re.match(r'^(\d{4})$', value) match = re.match(r'^(\d{4})$', value)
if match: if match:
year = int(match.group(1)) year = int(match.group(1))
return ( return (datetime(year, 1, 1), datetime(year + 1, 1, 1) - one_second)
datetime.datetime(year, 1, 1),
datetime.datetime(year + 1, 1, 1) - one_second)
match = re.match(r'^(\d{4})-(\d{1,2})$', value) match = re.match(r'^(\d{4})-(\d{1,2})$', value)
if match: if match:
year = int(match.group(1)) year = int(match.group(1))
month = int(match.group(2)) month = int(match.group(2))
return ( return (
datetime.datetime(year, month, 1), datetime(year, month, 1),
datetime.datetime(year, month + 1, 1) - one_second) datetime(year, month + 1, 1) - one_second)
match = re.match(r'^(\d{4})-(\d{1,2})-(\d{1,2})$', value) match = re.match(r'^(\d{4})-(\d{1,2})-(\d{1,2})$', value)
if match: if match:
@ -121,11 +131,12 @@ def parse_time_range(value):
month = int(match.group(2)) month = int(match.group(2))
day = int(match.group(3)) day = int(match.group(3))
return ( return (
datetime.datetime(year, month, day), datetime(year, month, day),
datetime.datetime(year, month, day + 1) - one_second) datetime(year, month, day + 1) - one_second)
raise errors.ValidationError('Invalid date format: %r.' % value) raise errors.ValidationError('Invalid date format: %r.' % value)
def icase_unique(source): def icase_unique(source):
target = [] target = []
target_low = [] target_low = []
@ -135,6 +146,7 @@ def icase_unique(source):
target_low.append(source_item.lower()) target_low.append(source_item.lower())
return target return target
def value_exceeds_column_size(value, column): def value_exceeds_column_size(value, column):
if not value: if not value:
return False return False
@ -143,6 +155,7 @@ def value_exceeds_column_size(value, column):
return False return False
return len(value) > max_length return len(value) > max_length
def verify_version(entity, context, field_name='version'): def verify_version(entity, context, field_name='version'):
actual_version = context.get_param_as_int(field_name, required=True) actual_version = context.get_param_as_int(field_name, required=True)
expected_version = entity.version expected_version = entity.version
@ -151,5 +164,6 @@ def verify_version(entity, context, field_name='version'):
'Someone else modified this in the meantime. ' + 'Someone else modified this in the meantime. ' +
'Please try again.') 'Please try again.')
def bump_version(entity): def bump_version(entity):
entity.version += 1 entity.version += 1

View file

@ -4,6 +4,7 @@ from szurubooru.func import auth, users
from szurubooru.rest import middleware from szurubooru.rest import middleware
from szurubooru.rest.errors import HttpBadRequest from szurubooru.rest.errors import HttpBadRequest
def _authenticate(username, password): def _authenticate(username, password):
''' Try to authenticate user. Throw AuthError for invalid users. ''' ''' Try to authenticate user. Throw AuthError for invalid users. '''
user = users.get_user_by_name(username) user = users.get_user_by_name(username)
@ -11,23 +12,25 @@ def _authenticate(username, password):
raise errors.AuthError('Invalid password.') raise errors.AuthError('Invalid password.')
return user return user
def _create_anonymous_user(): def _create_anonymous_user():
user = db.User() user = db.User()
user.name = None user.name = None
user.rank = 'anonymous' user.rank = 'anonymous'
return user return user
def _get_user(ctx): def _get_user(ctx):
if not ctx.has_header('Authorization'): if not ctx.has_header('Authorization'):
return _create_anonymous_user() return _create_anonymous_user()
try: try:
auth_type, user_and_password = ctx.get_header('Authorization').split(' ', 1) auth_type, credentials = ctx.get_header('Authorization').split(' ', 1)
if auth_type.lower() != 'basic': if auth_type.lower() != 'basic':
raise HttpBadRequest( raise HttpBadRequest(
'Only basic HTTP authentication is supported.') 'Only basic HTTP authentication is supported.')
username, password = base64.decodebytes( username, password = base64.decodebytes(
user_and_password.encode('ascii')).decode('utf8').split(':') credentials.encode('ascii')).decode('utf8').split(':')
return _authenticate(username, password) return _authenticate(username, password)
except ValueError as err: except ValueError as err:
msg = 'Basic authentication header value are not properly formed. ' \ msg = 'Basic authentication header value are not properly formed. ' \
@ -35,6 +38,7 @@ def _get_user(ctx):
raise HttpBadRequest( raise HttpBadRequest(
msg.format(ctx.get_header('Authorization'), str(err))) msg.format(ctx.get_header('Authorization'), str(err)))
@middleware.pre_hook @middleware.pre_hook
def process_request(ctx): def process_request(ctx):
''' Bind the user to request. Update last login time if needed. ''' ''' Bind the user to request. Update last login time if needed. '''

View file

@ -1,6 +1,7 @@
from szurubooru.func import cache from szurubooru.func import cache
from szurubooru.rest import middleware from szurubooru.rest import middleware
@middleware.pre_hook @middleware.pre_hook
def process_request(ctx): def process_request(ctx):
if ctx.method != 'GET': if ctx.method != 'GET':

View file

@ -1,11 +1,13 @@
from szurubooru import db from szurubooru import db
from szurubooru.rest import middleware from szurubooru.rest import middleware
@middleware.pre_hook @middleware.pre_hook
def _process_request(ctx): def _process_request(ctx):
ctx.session = db.session() ctx.session = db.session()
db.reset_query_count() db.reset_query_count()
@middleware.post_hook @middleware.post_hook
def _process_response(_ctx): def _process_response(_ctx):
db.session.remove() db.session.remove()

View file

@ -2,8 +2,10 @@ import logging
from szurubooru import db from szurubooru import db
from szurubooru.rest import middleware from szurubooru.rest import middleware
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@middleware.post_hook @middleware.post_hook
def process_response(ctx): def process_response(ctx):
logger.info( logger.info(

View file

@ -28,6 +28,7 @@ alembic_config.set_main_option(
target_metadata = szurubooru.db.Base.metadata target_metadata = szurubooru.db.Base.metadata
def run_migrations_offline(): def run_migrations_offline():
''' '''
Run migrations in 'offline' mode. Run migrations in 'offline' mode.

View file

@ -13,6 +13,7 @@ down_revision = 'e5c1216a8503'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.create_table( op.create_table(
'tag_category', 'tag_category',
@ -55,6 +56,7 @@ def upgrade():
sa.ForeignKeyConstraint(['child_id'], ['tag.id']), sa.ForeignKeyConstraint(['child_id'], ['tag.id']),
sa.PrimaryKeyConstraint('parent_id', 'child_id')) sa.PrimaryKeyConstraint('parent_id', 'child_id'))
def downgrade(): def downgrade():
op.drop_table('tag_suggestion') op.drop_table('tag_suggestion')
op.drop_table('tag_implication') op.drop_table('tag_implication')

View file

@ -13,10 +13,16 @@ down_revision = '49ab4e1139ef'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.add_column('tag_category', sa.Column('default', sa.Boolean(), nullable=True)) op.add_column(
op.execute(sa.table('tag_category', sa.column('default')).update().values(default=False)) 'tag_category', sa.Column('default', sa.Boolean(), nullable=True))
op.execute(
sa.table('tag_category', sa.column('default'))
.update()
.values(default=False))
op.alter_column('tag_category', 'default', nullable=False) op.alter_column('tag_category', 'default', nullable=False)
def downgrade(): def downgrade():
op.drop_column('tag_category', 'default') op.drop_column('tag_category', 'default')

View file

@ -13,8 +13,11 @@ down_revision = 'ed6dd16a30f3'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.add_column('post', sa.Column('mime-type', sa.Unicode(length=32), nullable=False)) op.add_column(
'post', sa.Column('mime-type', sa.Unicode(length=32), nullable=False))
def downgrade(): def downgrade():
op.drop_column('post', 'mime-type') op.drop_column('post', 'mime-type')

View file

@ -13,6 +13,7 @@ down_revision = '00cb3a2734db'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.create_table( op.create_table(
'post', 'post',
@ -56,6 +57,7 @@ def upgrade():
sa.ForeignKeyConstraint(['tag_id'], ['tag.id']), sa.ForeignKeyConstraint(['tag_id'], ['tag.id']),
sa.PrimaryKeyConstraint('post_id', 'tag_id')) sa.PrimaryKeyConstraint('post_id', 'tag_id'))
def downgrade(): def downgrade():
op.drop_table('post_tag') op.drop_table('post_tag')
op.drop_table('post_relation') op.drop_table('post_relation')

View file

@ -13,10 +13,12 @@ down_revision = '565e01e3cf6d'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.add_column( op.add_column(
'snapshot', 'snapshot',
sa.Column('resource_repr', sa.Unicode(length=64), nullable=False)) sa.Column('resource_repr', sa.Unicode(length=64), nullable=False))
def downgrade(): def downgrade():
op.drop_column('snapshot', 'resource_repr') op.drop_column('snapshot', 'resource_repr')

View file

@ -13,6 +13,7 @@ down_revision = '84bd402f15f0'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.create_table( op.create_table(
'comment', 'comment',
@ -36,6 +37,7 @@ def upgrade():
sa.ForeignKeyConstraint(['user_id'], ['user.id']), sa.ForeignKeyConstraint(['user_id'], ['user.id']),
sa.PrimaryKeyConstraint('comment_id', 'user_id')) sa.PrimaryKeyConstraint('comment_id', 'user_id'))
def downgrade(): def downgrade():
op.drop_table('comment_score') op.drop_table('comment_score')
op.drop_table('comment') op.drop_table('comment')

View file

@ -13,52 +13,59 @@ down_revision = '23abaf4a0a4b'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.create_index(op.f('ix_comment_post_id'), 'comment', ['post_id'], unique=False) for index_name, table_name, column_name in [
op.create_index(op.f('ix_comment_user_id'), 'comment', ['user_id'], unique=False) ('ix_comment_post_id', 'comment', 'post_id'),
op.create_index(op.f('ix_comment_score_user_id'), 'comment_score', ['user_id'], unique=False) ('ix_comment_user_id', 'comment', 'user_id'),
op.create_index(op.f('ix_post_user_id'), 'post', ['user_id'], unique=False) ('ix_comment_score_user_id', 'comment_score', 'user_id'),
op.create_index(op.f('ix_post_favorite_post_id'), 'post_favorite', ['post_id'], unique=False) ('ix_post_user_id', 'post', 'user_id'),
op.create_index(op.f('ix_post_favorite_user_id'), 'post_favorite', ['user_id'], unique=False) ('ix_post_favorite_post_id', 'post_favorite', 'post_id'),
op.create_index(op.f('ix_post_feature_post_id'), 'post_feature', ['post_id'], unique=False) ('ix_post_favorite_user_id', 'post_favorite', 'user_id'),
op.create_index(op.f('ix_post_feature_user_id'), 'post_feature', ['user_id'], unique=False) ('ix_post_feature_post_id', 'post_feature', 'post_id'),
op.create_index(op.f('ix_post_note_post_id'), 'post_note', ['post_id'], unique=False) ('ix_post_feature_user_id', 'post_feature', 'user_id'),
op.create_index(op.f('ix_post_relation_child_id'), 'post_relation', ['child_id'], unique=False) ('ix_post_note_post_id', 'post_note', 'post_id'),
op.create_index(op.f('ix_post_relation_parent_id'), 'post_relation', ['parent_id'], unique=False) ('ix_post_relation_child_id', 'post_relation', 'child_id'),
op.create_index(op.f('ix_post_score_post_id'), 'post_score', ['post_id'], unique=False) ('ix_post_relation_parent_id', 'post_relation', 'parent_id'),
op.create_index(op.f('ix_post_score_user_id'), 'post_score', ['user_id'], unique=False) ('ix_post_score_post_id', 'post_score', 'post_id'),
op.create_index(op.f('ix_post_tag_post_id'), 'post_tag', ['post_id'], unique=False) ('ix_post_score_user_id', 'post_score', 'user_id'),
op.create_index(op.f('ix_post_tag_tag_id'), 'post_tag', ['tag_id'], unique=False) ('ix_post_tag_post_id', 'post_tag', 'post_id'),
op.create_index(op.f('ix_snapshot_resource_id'), 'snapshot', ['resource_id'], unique=False) ('ix_post_tag_tag_id', 'post_tag', 'tag_id'),
op.create_index(op.f('ix_snapshot_resource_type'), 'snapshot', ['resource_type'], unique=False) ('ix_snapshot_resource_id', 'snapshot', 'resource_id'),
op.create_index(op.f('ix_tag_category_id'), 'tag', ['category_id'], unique=False) ('ix_snapshot_resource_type', 'snapshot', 'resource_type'),
op.create_index(op.f('ix_tag_implication_child_id'), 'tag_implication', ['child_id'], unique=False) ('ix_tag_category_id', 'tag', 'category_id'),
op.create_index(op.f('ix_tag_implication_parent_id'), 'tag_implication', ['parent_id'], unique=False) ('ix_tag_implication_child_id', 'tag_implication', 'child_id'),
op.create_index(op.f('ix_tag_name_tag_id'), 'tag_name', ['tag_id'], unique=False) ('ix_tag_implication_parent_id', 'tag_implication', 'parent_id'),
op.create_index(op.f('ix_tag_suggestion_child_id'), 'tag_suggestion', ['child_id'], unique=False) ('ix_tag_name_tag_id', 'tag_name', 'tag_id'),
op.create_index(op.f('ix_tag_suggestion_parent_id'), 'tag_suggestion', ['parent_id'], unique=False) ('ix_tag_suggestion_child_id', 'tag_suggestion', 'child_id'),
('ix_tag_suggestion_parent_id', 'tag_suggestion', 'parent_id')]:
op.create_index(
op.f(index_name), table_name, [column_name], unique=False)
def downgrade(): def downgrade():
op.drop_index(op.f('ix_tag_suggestion_parent_id'), table_name='tag_suggestion') for index_name, table_name in [
op.drop_index(op.f('ix_tag_suggestion_child_id'), table_name='tag_suggestion') ('ix_tag_suggestion_parent_id', 'tag_suggestion'),
op.drop_index(op.f('ix_tag_name_tag_id'), table_name='tag_name') ('ix_tag_suggestion_child_id', 'tag_suggestion'),
op.drop_index(op.f('ix_tag_implication_parent_id'), table_name='tag_implication') ('ix_tag_name_tag_id', 'tag_name'),
op.drop_index(op.f('ix_tag_implication_child_id'), table_name='tag_implication') ('ix_tag_implication_parent_id', 'tag_implication'),
op.drop_index(op.f('ix_tag_category_id'), table_name='tag') ('ix_tag_implication_child_id', 'tag_implication'),
op.drop_index(op.f('ix_snapshot_resource_type'), table_name='snapshot') ('ix_tag_category_id', 'tag'),
op.drop_index(op.f('ix_snapshot_resource_id'), table_name='snapshot') ('ix_snapshot_resource_type', 'snapshot'),
op.drop_index(op.f('ix_post_tag_tag_id'), table_name='post_tag') ('ix_snapshot_resource_id', 'snapshot'),
op.drop_index(op.f('ix_post_tag_post_id'), table_name='post_tag') ('ix_post_tag_tag_id', 'post_tag'),
op.drop_index(op.f('ix_post_score_user_id'), table_name='post_score') ('ix_post_tag_post_id', 'post_tag'),
op.drop_index(op.f('ix_post_score_post_id'), table_name='post_score') ('ix_post_score_user_id', 'post_score'),
op.drop_index(op.f('ix_post_relation_parent_id'), table_name='post_relation') ('ix_post_score_post_id', 'post_score'),
op.drop_index(op.f('ix_post_relation_child_id'), table_name='post_relation') ('ix_post_relation_parent_id', 'post_relation'),
op.drop_index(op.f('ix_post_note_post_id'), table_name='post_note') ('ix_post_relation_child_id', 'post_relation'),
op.drop_index(op.f('ix_post_feature_user_id'), table_name='post_feature') ('ix_post_note_post_id', 'post_note'),
op.drop_index(op.f('ix_post_feature_post_id'), table_name='post_feature') ('ix_post_feature_user_id', 'post_feature'),
op.drop_index(op.f('ix_post_favorite_user_id'), table_name='post_favorite') ('ix_post_feature_post_id', 'post_feature'),
op.drop_index(op.f('ix_post_favorite_post_id'), table_name='post_favorite') ('ix_post_favorite_user_id', 'post_favorite'),
op.drop_index(op.f('ix_post_user_id'), table_name='post') ('ix_post_favorite_post_id', 'post_favorite'),
op.drop_index(op.f('ix_comment_score_user_id'), table_name='comment_score') ('ix_post_user_id', 'post'),
op.drop_index(op.f('ix_comment_user_id'), table_name='comment') ('ix_comment_score_user_id', 'comment_score'),
op.drop_index(op.f('ix_comment_post_id'), table_name='comment') ('ix_comment_user_id', 'comment'),
('ix_comment_post_id', 'comment')]:
op.drop_index(op.f(index_name), table_name=table_name)

View file

@ -13,8 +13,11 @@ down_revision = '055d0e048fb3'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.add_column('tag', sa.Column('description', sa.UnicodeText(), nullable=True)) op.add_column(
'tag', sa.Column('description', sa.UnicodeText(), nullable=True))
def downgrade(): def downgrade():
op.drop_column('tag', 'description') op.drop_column('tag', 'description')

View file

@ -13,6 +13,7 @@ down_revision = '336a76ec1338'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.create_table( op.create_table(
'snapshot', 'snapshot',
@ -26,5 +27,6 @@ def upgrade():
sa.ForeignKeyConstraint(['user_id'], ['user.id']), sa.ForeignKeyConstraint(['user_id'], ['user.id']),
sa.PrimaryKeyConstraint('id')) sa.PrimaryKeyConstraint('id'))
def downgrade(): def downgrade():
op.drop_table('snapshot') op.drop_table('snapshot')

View file

@ -15,12 +15,17 @@ depends_on = None
tables = ['tag_category', 'tag', 'user', 'post', 'comment'] tables = ['tag_category', 'tag', 'user', 'post', 'comment']
def upgrade(): def upgrade():
for table in tables: for table in tables:
op.add_column(table, sa.Column('version', sa.Integer(), nullable=True)) op.add_column(table, sa.Column('version', sa.Integer(), nullable=True))
op.execute(sa.table(table, sa.column('version')).update().values(version=1)) op.execute(
sa.table(table, sa.column('version'))
.update()
.values(version=1))
op.alter_column(table, 'version', nullable=False) op.alter_column(table, 'version', nullable=False)
def downgrade(): def downgrade():
for table in tables: for table in tables:
op.drop_column(table, 'version') op.drop_column(table, 'version')

View file

@ -13,10 +13,14 @@ down_revision = '9587de88a84b'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.drop_column('post', 'flags') op.drop_column('post', 'flags')
op.add_column('post', sa.Column('flags', sa.PickleType(), nullable=True)) op.add_column('post', sa.Column('flags', sa.PickleType(), nullable=True))
def downgrade(): def downgrade():
op.drop_column('post', 'flags') op.drop_column('post', 'flags')
op.add_column('post', sa.Column('flags', sa.Integer(), autoincrement=False, nullable=False)) op.add_column(
'post',
sa.Column('flags', sa.Integer(), autoincrement=False, nullable=False))

View file

@ -13,6 +13,7 @@ down_revision = '46cd5229839b'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.create_table( op.create_table(
'post_favorite', 'post_favorite',
@ -52,6 +53,7 @@ def upgrade():
sa.ForeignKeyConstraint(['user_id'], ['user.id']), sa.ForeignKeyConstraint(['user_id'], ['user.id']),
sa.PrimaryKeyConstraint('post_id', 'user_id')) sa.PrimaryKeyConstraint('post_id', 'user_id'))
def downgrade(): def downgrade():
op.drop_table('post_score') op.drop_table('post_score')
op.drop_table('post_note') op.drop_table('post_note')

View file

@ -13,6 +13,7 @@ down_revision = None
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.create_table( op.create_table(
'user', 'user',
@ -28,5 +29,6 @@ def upgrade():
sa.PrimaryKeyConstraint('id')) sa.PrimaryKeyConstraint('id'))
op.create_unique_constraint('uq_user_name', 'user', ['name']) op.create_unique_constraint('uq_user_name', 'user', ['name'])
def downgrade(): def downgrade():
op.drop_table('user') op.drop_table('user')

View file

@ -13,24 +13,36 @@ down_revision = '46df355634dc'
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade(): def upgrade():
op.drop_column('post', 'auto_comment_edit_time') for column_name in [
op.drop_column('post', 'auto_fav_count') 'auto_comment_edit_time'
op.drop_column('post', 'auto_comment_creation_time') 'auto_fav_count',
op.drop_column('post', 'auto_feature_count') 'auto_comment_creation_time',
op.drop_column('post', 'auto_comment_count') 'auto_feature_count',
op.drop_column('post', 'auto_score') 'auto_comment_count',
op.drop_column('post', 'auto_fav_time') 'auto_score',
op.drop_column('post', 'auto_feature_time') 'auto_fav_time',
op.drop_column('post', 'auto_note_count') 'auto_feature_time',
'auto_note_count']:
op.drop_column('post', column_name)
def downgrade(): def downgrade():
op.add_column('post', sa.Column('auto_note_count', sa.INTEGER(), autoincrement=False, nullable=False)) for column_name in [
op.add_column('post', sa.Column('auto_feature_time', sa.INTEGER(), autoincrement=False, nullable=False)) 'auto_note_count',
op.add_column('post', sa.Column('auto_fav_time', sa.INTEGER(), autoincrement=False, nullable=False)) 'auto_feature_time',
op.add_column('post', sa.Column('auto_score', sa.INTEGER(), autoincrement=False, nullable=False)) 'auto_fav_time',
op.add_column('post', sa.Column('auto_comment_count', sa.INTEGER(), autoincrement=False, nullable=False)) 'auto_score',
op.add_column('post', sa.Column('auto_feature_count', sa.INTEGER(), autoincrement=False, nullable=False)) 'auto_comment_count',
op.add_column('post', sa.Column('auto_comment_creation_time', sa.INTEGER(), autoincrement=False, nullable=False)) 'auto_feature_count',
op.add_column('post', sa.Column('auto_fav_count', sa.INTEGER(), autoincrement=False, nullable=False)) 'auto_comment_creation_time',
op.add_column('post', sa.Column('auto_comment_edit_time', sa.INTEGER(), autoincrement=False, nullable=False)) 'auto_fav_count',
'auto_comment_edit_time']:
op.add_column(
'post',
sa.Column(
column_name,
sa.INTEGER(),
autoincrement=False,
nullable=False))

View file

@ -6,6 +6,7 @@ from datetime import datetime
from szurubooru.func import util from szurubooru.func import util
from szurubooru.rest import errors, middleware, routes, context from szurubooru.rest import errors, middleware, routes, context
def _json_serializer(obj): def _json_serializer(obj):
''' JSON serializer for objects not serializable by default JSON code ''' ''' JSON serializer for objects not serializable by default JSON code '''
if isinstance(obj, datetime): if isinstance(obj, datetime):
@ -13,14 +14,16 @@ def _json_serializer(obj):
return serial return serial
raise TypeError('Type not serializable') raise TypeError('Type not serializable')
def _dump_json(obj): def _dump_json(obj):
return json.dumps(obj, default=_json_serializer, indent=2) return json.dumps(obj, default=_json_serializer, indent=2)
def _read(env): def _read(env):
length = int(env.get('CONTENT_LENGTH', 0)) length = int(env.get('CONTENT_LENGTH', 0))
output = io.BytesIO() output = io.BytesIO()
while length > 0: while length > 0:
part = env['wsgi.input'].read(min(length, 1024*200)) part = env['wsgi.input'].read(min(length, 1024 * 200))
if not part: if not part:
break break
output.write(part) output.write(part)
@ -28,6 +31,7 @@ def _read(env):
output.seek(0) output.seek(0)
return output return output
def _get_headers(env): def _get_headers(env):
headers = {} headers = {}
for key, value in env.items(): for key, value in env.items():
@ -36,6 +40,7 @@ def _get_headers(env):
headers[key] = value headers[key] = value
return headers return headers
def _create_context(env): def _create_context(env):
method = env['REQUEST_METHOD'] method = env['REQUEST_METHOD']
path = '/' + env['PATH_INFO'].lstrip('/') path = '/' + env['PATH_INFO'].lstrip('/')
@ -56,7 +61,7 @@ def _create_context(env):
if isinstance(form[key], cgi.MiniFieldStorage): if isinstance(form[key], cgi.MiniFieldStorage):
params[key] = form.getvalue(key) params[key] = form.getvalue(key)
else: else:
_original_file_name = getattr(form[key], 'filename', None) # _user_file_name = getattr(form[key], 'filename', None)
files[key] = form.getvalue(key) files[key] = form.getvalue(key)
if 'metadata' in form: if 'metadata' in form:
body = form.getvalue('metadata') body = form.getvalue('metadata')
@ -79,10 +84,11 @@ def _create_context(env):
return context.Context(method, path, headers, params, files) return context.Context(method, path, headers, params, files)
def application(env, start_response): def application(env, start_response):
try: try:
ctx = _create_context(env) ctx = _create_context(env)
if not 'application/json' in ctx.get_header('Accept'): if 'application/json' not in ctx.get_header('Accept'):
raise errors.HttpNotAcceptable( raise errors.HttpNotAcceptable(
'This API only supports JSON responses.') 'This API only supports JSON responses.')

View file

@ -1,9 +1,11 @@
from szurubooru import errors from szurubooru import errors
from szurubooru.func import net from szurubooru.func import net
def _lower_first(source): def _lower_first(source):
return source[0].lower() + source[1:] return source[0].lower() + source[1:]
def _param_wrapper(func): def _param_wrapper(func):
def wrapper(self, name, required=False, default=None, **kwargs): def wrapper(self, name, required=False, default=None, **kwargs):
# pylint: disable=protected-access # pylint: disable=protected-access
@ -22,8 +24,8 @@ def _param_wrapper(func):
'Required parameter %r is missing.' % name) 'Required parameter %r is missing.' % name)
return wrapper return wrapper
class Context(): class Context():
# pylint: disable=too-many-arguments
def __init__(self, method, url, headers=None, params=None, files=None): def __init__(self, method, url, headers=None, params=None, files=None):
self.method = method self.method = method
self.url = url self.url = url
@ -74,7 +76,6 @@ class Context():
raise errors.InvalidParameterError('Expected simple string.') raise errors.InvalidParameterError('Expected simple string.')
return value return value
# pylint: disable=redefined-builtin
@_param_wrapper @_param_wrapper
def get_param_as_int(self, value, min=None, max=None): def get_param_as_int(self, value, min=None, max=None):
try: try:
@ -97,4 +98,5 @@ class Context():
return True return True
if value in ['0', 'n', 'no', 'nope', 'f', 'false']: if value in ['0', 'n', 'no', 'nope', 'f', 'false']:
return False return False
raise errors.InvalidParameterError('The value must be a boolean value.') raise errors.InvalidParameterError(
'The value must be a boolean value.')

View file

@ -1,4 +1,5 @@
error_handlers = {} # pylint: disable=invalid-name error_handlers = {} # pylint: disable=invalid-name
class BaseHttpError(RuntimeError): class BaseHttpError(RuntimeError):
code = None code = None
@ -9,29 +10,36 @@ class BaseHttpError(RuntimeError):
self.description = description self.description = description
self.title = title or self.reason self.title = title or self.reason
class HttpBadRequest(BaseHttpError): class HttpBadRequest(BaseHttpError):
code = 400 code = 400
reason = 'Bad Request' reason = 'Bad Request'
class HttpForbidden(BaseHttpError): class HttpForbidden(BaseHttpError):
code = 403 code = 403
reason = 'Forbidden' reason = 'Forbidden'
class HttpNotFound(BaseHttpError): class HttpNotFound(BaseHttpError):
code = 404 code = 404
reason = 'Not Found' reason = 'Not Found'
class HttpNotAcceptable(BaseHttpError): class HttpNotAcceptable(BaseHttpError):
code = 406 code = 406
reason = 'Not Acceptable' reason = 'Not Acceptable'
class HttpConflict(BaseHttpError): class HttpConflict(BaseHttpError):
code = 409 code = 409
reason = 'Conflict' reason = 'Conflict'
class HttpMethodNotAllowed(BaseHttpError): class HttpMethodNotAllowed(BaseHttpError):
code = 405 code = 405
reason = 'Method Not Allowed' reason = 'Method Not Allowed'
def handle(exception_type, handler): def handle(exception_type, handler):
error_handlers[exception_type] = handler error_handlers[exception_type] = handler

View file

@ -2,8 +2,10 @@
pre_hooks = [] pre_hooks = []
post_hooks = [] post_hooks = []
def pre_hook(handler): def pre_hook(handler):
pre_hooks.append(handler) pre_hooks.append(handler)
def post_hook(handler): def post_hook(handler):
post_hooks.insert(0, handler) post_hooks.insert(0, handler)

View file

@ -1,6 +1,8 @@
from collections import defaultdict from collections import defaultdict
routes = defaultdict(dict) # pylint: disable=invalid-name
routes = defaultdict(dict) # pylint: disable=invalid-name
def get(url): def get(url):
def wrapper(handler): def wrapper(handler):
@ -8,18 +10,21 @@ def get(url):
return handler return handler
return wrapper return wrapper
def put(url): def put(url):
def wrapper(handler): def wrapper(handler):
routes[url]['PUT'] = handler routes[url]['PUT'] = handler
return handler return handler
return wrapper return wrapper
def post(url): def post(url):
def wrapper(handler): def wrapper(handler):
routes[url]['POST'] = handler routes[url]['POST'] = handler
return handler return handler
return wrapper return wrapper
def delete(url): def delete(url):
def wrapper(handler): def wrapper(handler):
routes[url]['DELETE'] = handler routes[url]['DELETE'] = handler

View file

@ -1,5 +1,5 @@
from szurubooru.search.configs.user_search_config import UserSearchConfig from .user_search_config import UserSearchConfig
from szurubooru.search.configs.snapshot_search_config import SnapshotSearchConfig from .tag_search_config import TagSearchConfig
from szurubooru.search.configs.tag_search_config import TagSearchConfig from .post_search_config import PostSearchConfig
from szurubooru.search.configs.comment_search_config import CommentSearchConfig from .snapshot_search_config import SnapshotSearchConfig
from szurubooru.search.configs.post_search_config import PostSearchConfig from .comment_search_config import CommentSearchConfig

View file

@ -1,5 +1,6 @@
from szurubooru.search import tokens from szurubooru.search import tokens
class BaseSearchConfig(object): class BaseSearchConfig(object):
SORT_ASC = tokens.SortToken.SORT_ASC SORT_ASC = tokens.SortToken.SORT_ASC
SORT_DESC = tokens.SortToken.SORT_DESC SORT_DESC = tokens.SortToken.SORT_DESC

View file

@ -3,6 +3,7 @@ from szurubooru import db
from szurubooru.search.configs import util as search_util from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import BaseSearchConfig from szurubooru.search.configs.base_search_config import BaseSearchConfig
class CommentSearchConfig(BaseSearchConfig): class CommentSearchConfig(BaseSearchConfig):
def create_filter_query(self): def create_filter_query(self):
return db.session.query(db.Comment).join(db.User) return db.session.query(db.Comment).join(db.User)
@ -22,12 +23,18 @@ class CommentSearchConfig(BaseSearchConfig):
'user': search_util.create_str_filter(db.User.name), 'user': search_util.create_str_filter(db.User.name),
'author': search_util.create_str_filter(db.User.name), 'author': search_util.create_str_filter(db.User.name),
'text': search_util.create_str_filter(db.Comment.text), 'text': search_util.create_str_filter(db.Comment.text),
'creation-date': search_util.create_date_filter(db.Comment.creation_time), 'creation-date':
'creation-time': search_util.create_date_filter(db.Comment.creation_time), search_util.create_date_filter(db.Comment.creation_time),
'last-edit-date': search_util.create_date_filter(db.Comment.last_edit_time), 'creation-time':
'last-edit-time': search_util.create_date_filter(db.Comment.last_edit_time), search_util.create_date_filter(db.Comment.creation_time),
'edit-date': search_util.create_date_filter(db.Comment.last_edit_time), 'last-edit-date':
'edit-time': search_util.create_date_filter(db.Comment.last_edit_time), search_util.create_date_filter(db.Comment.last_edit_time),
'last-edit-time':
search_util.create_date_filter(db.Comment.last_edit_time),
'edit-date':
search_util.create_date_filter(db.Comment.last_edit_time),
'edit-time':
search_util.create_date_filter(db.Comment.last_edit_time),
} }
@property @property

View file

@ -6,6 +6,7 @@ from szurubooru.search import criteria, tokens
from szurubooru.search.configs import util as search_util from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import BaseSearchConfig from szurubooru.search.configs.base_search_config import BaseSearchConfig
def _enum_transformer(available_values, value): def _enum_transformer(available_values, value):
try: try:
return available_values[value.lower()] return available_values[value.lower()]
@ -14,6 +15,7 @@ def _enum_transformer(available_values, value):
'Invalid value: %r. Possible values: %r.' % ( 'Invalid value: %r. Possible values: %r.' % (
value, list(sorted(available_values.keys())))) value, list(sorted(available_values.keys()))))
def _type_transformer(value): def _type_transformer(value):
available_values = { available_values = {
'image': db.Post.TYPE_IMAGE, 'image': db.Post.TYPE_IMAGE,
@ -28,6 +30,7 @@ def _type_transformer(value):
} }
return _enum_transformer(available_values, value) return _enum_transformer(available_values, value)
def _safety_transformer(value): def _safety_transformer(value):
available_values = { available_values = {
'safe': db.Post.SAFETY_SAFE, 'safe': db.Post.SAFETY_SAFE,
@ -37,11 +40,12 @@ def _safety_transformer(value):
} }
return _enum_transformer(available_values, value) return _enum_transformer(available_values, value)
def _create_score_filter(score): def _create_score_filter(score):
def wrapper(query, criterion, negated): def wrapper(query, criterion, negated):
if not getattr(criterion, 'internal', False): if not getattr(criterion, 'internal', False):
raise errors.SearchError( raise errors.SearchError(
'Votes cannot be seen publicly. Did you mean %r?' \ 'Votes cannot be seen publicly. Did you mean %r?'
% 'special:liked') % 'special:liked')
user_alias = aliased(db.User) user_alias = aliased(db.User)
score_alias = aliased(db.PostScore) score_alias = aliased(db.PostScore)
@ -57,6 +61,7 @@ def _create_score_filter(score):
return ret return ret
return wrapper return wrapper
class PostSearchConfig(BaseSearchConfig): class PostSearchConfig(BaseSearchConfig):
def on_search_query_parsed(self, search_query): def on_search_query_parsed(self, search_query):
new_special_tokens = [] new_special_tokens = []
@ -64,7 +69,8 @@ class PostSearchConfig(BaseSearchConfig):
if token.value in ('fav', 'liked', 'disliked'): if token.value in ('fav', 'liked', 'disliked'):
assert self.user assert self.user
if self.user.rank == 'anonymous': if self.user.rank == 'anonymous':
raise errors.SearchError('Must be logged in to use this feature.') raise errors.SearchError(
'Must be logged in to use this feature.')
criterion = criteria.PlainCriterion( criterion = criteria.PlainCriterion(
original_text=self.user.name, original_text=self.user.name,
value=self.user.name) value=self.user.name)
@ -85,9 +91,9 @@ class PostSearchConfig(BaseSearchConfig):
return self.create_count_query() \ return self.create_count_query() \
.options( .options(
# use config optimized for official client # use config optimized for official client
#defer(db.Post.score), # defer(db.Post.score),
#defer(db.Post.favorite_count), # defer(db.Post.favorite_count),
#defer(db.Post.comment_count), # defer(db.Post.comment_count),
defer(db.Post.last_favorite_time), defer(db.Post.last_favorite_time),
defer(db.Post.feature_count), defer(db.Post.feature_count),
defer(db.Post.last_feature_time), defer(db.Post.last_feature_time),
@ -99,8 +105,7 @@ class PostSearchConfig(BaseSearchConfig):
lazyload(db.Post.user), lazyload(db.Post.user),
lazyload(db.Post.relations), lazyload(db.Post.relations),
lazyload(db.Post.notes), lazyload(db.Post.notes),
lazyload(db.Post.favorited_by), lazyload(db.Post.favorited_by))
)
def create_count_query(self): def create_count_query(self):
return db.session.query(db.Post) return db.session.query(db.Post)
@ -153,12 +158,18 @@ class PostSearchConfig(BaseSearchConfig):
'liked': _create_score_filter(1), 'liked': _create_score_filter(1),
'disliked': _create_score_filter(-1), 'disliked': _create_score_filter(-1),
'tag-count': search_util.create_num_filter(db.Post.tag_count), 'tag-count': search_util.create_num_filter(db.Post.tag_count),
'comment-count': search_util.create_num_filter(db.Post.comment_count), 'comment-count':
'fav-count': search_util.create_num_filter(db.Post.favorite_count), search_util.create_num_filter(db.Post.comment_count),
'fav-count':
search_util.create_num_filter(db.Post.favorite_count),
'note-count': search_util.create_num_filter(db.Post.note_count), 'note-count': search_util.create_num_filter(db.Post.note_count),
'relation-count': search_util.create_num_filter(db.Post.relation_count), 'relation-count':
'feature-count': search_util.create_num_filter(db.Post.feature_count), search_util.create_num_filter(db.Post.relation_count),
'type': search_util.create_str_filter(db.Post.type, _type_transformer), 'feature-count':
search_util.create_num_filter(db.Post.feature_count),
'type':
search_util.create_str_filter(
db.Post.type, _type_transformer),
'file-size': search_util.create_num_filter(db.Post.file_size), 'file-size': search_util.create_num_filter(db.Post.file_size),
('image-width', 'width'): ('image-width', 'width'):
search_util.create_num_filter(db.Post.canvas_width), search_util.create_num_filter(db.Post.canvas_width),
@ -171,13 +182,15 @@ class PostSearchConfig(BaseSearchConfig):
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): ('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'):
search_util.create_date_filter(db.Post.last_edit_time), search_util.create_date_filter(db.Post.last_edit_time),
('comment-date', 'comment-time'): ('comment-date', 'comment-time'):
search_util.create_date_filter(db.Post.last_comment_creation_time), search_util.create_date_filter(
db.Post.last_comment_creation_time),
('fav-date', 'fav-time'): ('fav-date', 'fav-time'):
search_util.create_date_filter(db.Post.last_favorite_time), search_util.create_date_filter(db.Post.last_favorite_time),
('feature-date', 'feature-time'): ('feature-date', 'feature-time'):
search_util.create_date_filter(db.Post.last_feature_time), search_util.create_date_filter(db.Post.last_feature_time),
('safety', 'rating'): ('safety', 'rating'):
search_util.create_str_filter(db.Post.safety, _safety_transformer), search_util.create_str_filter(
db.Post.safety, _safety_transformer),
}) })
@property @property
@ -193,9 +206,12 @@ class PostSearchConfig(BaseSearchConfig):
'relation-count': (db.Post.relation_count, self.SORT_DESC), 'relation-count': (db.Post.relation_count, self.SORT_DESC),
'feature-count': (db.Post.feature_count, self.SORT_DESC), 'feature-count': (db.Post.feature_count, self.SORT_DESC),
'file-size': (db.Post.file_size, self.SORT_DESC), 'file-size': (db.Post.file_size, self.SORT_DESC),
('image-width', 'width'): (db.Post.canvas_width, self.SORT_DESC), ('image-width', 'width'):
('image-height', 'height'): (db.Post.canvas_height, self.SORT_DESC), (db.Post.canvas_width, self.SORT_DESC),
('image-area', 'area'): (db.Post.canvas_area, self.SORT_DESC), ('image-height', 'height'):
(db.Post.canvas_height, self.SORT_DESC),
('image-area', 'area'):
(db.Post.canvas_area, self.SORT_DESC),
('creation-date', 'creation-time', 'date', 'time'): ('creation-date', 'creation-time', 'date', 'time'):
(db.Post.creation_time, self.SORT_DESC), (db.Post.creation_time, self.SORT_DESC),
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): ('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'):

View file

@ -2,6 +2,7 @@ from szurubooru import db
from szurubooru.search.configs import util as search_util from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import BaseSearchConfig from szurubooru.search.configs.base_search_config import BaseSearchConfig
class SnapshotSearchConfig(BaseSearchConfig): class SnapshotSearchConfig(BaseSearchConfig):
def create_filter_query(self): def create_filter_query(self):
return db.session.query(db.Snapshot) return db.session.query(db.Snapshot)

View file

@ -5,6 +5,7 @@ from szurubooru.func import util
from szurubooru.search.configs import util as search_util from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import BaseSearchConfig from szurubooru.search.configs.base_search_config import BaseSearchConfig
class TagSearchConfig(BaseSearchConfig): class TagSearchConfig(BaseSearchConfig):
def create_filter_query(self): def create_filter_query(self):
return self.create_count_query() \ return self.create_count_query() \
@ -13,8 +14,7 @@ class TagSearchConfig(BaseSearchConfig):
subqueryload(db.Tag.names), subqueryload(db.Tag.names),
subqueryload(db.Tag.category), subqueryload(db.Tag.category),
subqueryload(db.Tag.suggestions).joinedload(db.Tag.names), subqueryload(db.Tag.suggestions).joinedload(db.Tag.names),
subqueryload(db.Tag.implications).joinedload(db.Tag.names) subqueryload(db.Tag.implications).joinedload(db.Tag.names))
)
def create_count_query(self): def create_count_query(self):
return db.session.query(db.Tag) return db.session.query(db.Tag)

View file

@ -3,6 +3,7 @@ from szurubooru import db
from szurubooru.search.configs import util as search_util from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import BaseSearchConfig from szurubooru.search.configs.base_search_config import BaseSearchConfig
class UserSearchConfig(BaseSearchConfig): class UserSearchConfig(BaseSearchConfig):
''' Executes searches related to the users. ''' ''' Executes searches related to the users. '''
@ -20,12 +21,18 @@ class UserSearchConfig(BaseSearchConfig):
def named_filters(self): def named_filters(self):
return { return {
'name': search_util.create_str_filter(db.User.name), 'name': search_util.create_str_filter(db.User.name),
'creation-date': search_util.create_date_filter(db.User.creation_time), 'creation-date':
'creation-time': search_util.create_date_filter(db.User.creation_time), search_util.create_date_filter(db.User.creation_time),
'last-login-date': search_util.create_date_filter(db.User.last_login_time), 'creation-time':
'last-login-time': search_util.create_date_filter(db.User.last_login_time), search_util.create_date_filter(db.User.creation_time),
'login-date': search_util.create_date_filter(db.User.last_login_time), 'last-login-date':
'login-time': search_util.create_date_filter(db.User.last_login_time), search_util.create_date_filter(db.User.last_login_time),
'last-login-time':
search_util.create_date_filter(db.User.last_login_time),
'login-date':
search_util.create_date_filter(db.User.last_login_time),
'login-time':
search_util.create_date_filter(db.User.last_login_time),
} }
@property @property

View file

@ -3,9 +3,11 @@ from szurubooru import db, errors
from szurubooru.func import util from szurubooru.func import util
from szurubooru.search import criteria from szurubooru.search import criteria
def wildcard_transformer(value): def wildcard_transformer(value):
return value.replace('*', '%') return value.replace('*', '%')
def apply_num_criterion_to_column(column, criterion): def apply_num_criterion_to_column(column, criterion):
''' '''
Decorate SQLAlchemy filter on given column using supplied criterion. Decorate SQLAlchemy filter on given column using supplied criterion.
@ -32,6 +34,7 @@ def apply_num_criterion_to_column(column, criterion):
'Criterion value %r must be a number.' % (criterion,)) 'Criterion value %r must be a number.' % (criterion,))
return expr return expr
def create_num_filter(column): def create_num_filter(column):
def wrapper(query, criterion, negated): def wrapper(query, criterion, negated):
expr = apply_num_criterion_to_column( expr = apply_num_criterion_to_column(
@ -41,6 +44,7 @@ def create_num_filter(column):
return query.filter(expr) return query.filter(expr)
return wrapper return wrapper
def apply_str_criterion_to_column( def apply_str_criterion_to_column(
column, criterion, transformer=wildcard_transformer): column, criterion, transformer=wildcard_transformer):
''' '''
@ -59,6 +63,7 @@ def apply_str_criterion_to_column(
assert False assert False
return expr return expr
def create_str_filter(column, transformer=wildcard_transformer): def create_str_filter(column, transformer=wildcard_transformer):
def wrapper(query, criterion, negated): def wrapper(query, criterion, negated):
expr = apply_str_criterion_to_column( expr = apply_str_criterion_to_column(
@ -68,6 +73,7 @@ def create_str_filter(column, transformer=wildcard_transformer):
return query.filter(expr) return query.filter(expr)
return wrapper return wrapper
def apply_date_criterion_to_column(column, criterion): def apply_date_criterion_to_column(column, criterion):
''' '''
Decorate SQLAlchemy filter on given column using supplied criterion. Decorate SQLAlchemy filter on given column using supplied criterion.
@ -97,6 +103,7 @@ def apply_date_criterion_to_column(column, criterion):
assert False assert False
return expr return expr
def create_date_filter(column): def create_date_filter(column):
def wrapper(query, criterion, negated): def wrapper(query, criterion, negated):
expr = apply_date_criterion_to_column( expr = apply_date_criterion_to_column(
@ -106,6 +113,7 @@ def create_date_filter(column):
return query.filter(expr) return query.filter(expr)
return wrapper return wrapper
def create_subquery_filter( def create_subquery_filter(
left_id_column, left_id_column,
right_id_column, right_id_column,
@ -113,6 +121,7 @@ def create_subquery_filter(
filter_factory, filter_factory,
subquery_decorator=None): subquery_decorator=None):
filter_func = filter_factory(filter_column) filter_func = filter_factory(filter_column)
def wrapper(query, criterion, negated): def wrapper(query, criterion, negated):
subquery = db.session.query(right_id_column.label('foreign_id')) subquery = db.session.query(right_id_column.label('foreign_id'))
if subquery_decorator: if subquery_decorator:
@ -121,4 +130,5 @@ def create_subquery_filter(
subquery = filter_func(subquery, criterion, negated) subquery = filter_func(subquery, criterion, negated)
subquery = subquery.subquery('t') subquery = subquery.subquery('t')
return query.filter(left_id_column.in_(subquery)) return query.filter(left_id_column.in_(subquery))
return wrapper return wrapper

View file

@ -5,6 +5,7 @@ class _BaseCriterion(object):
def __repr__(self): def __repr__(self):
return self.original_text return self.original_text
class RangedCriterion(_BaseCriterion): class RangedCriterion(_BaseCriterion):
def __init__(self, original_text, min_value, max_value): def __init__(self, original_text, min_value, max_value):
super().__init__(original_text) super().__init__(original_text)
@ -14,6 +15,7 @@ class RangedCriterion(_BaseCriterion):
def __hash__(self): def __hash__(self):
return hash(('range', self.min_value, self.max_value)) return hash(('range', self.min_value, self.max_value))
class PlainCriterion(_BaseCriterion): class PlainCriterion(_BaseCriterion):
def __init__(self, original_text, value): def __init__(self, original_text, value):
super().__init__(original_text) super().__init__(original_text)
@ -22,6 +24,7 @@ class PlainCriterion(_BaseCriterion):
def __hash__(self): def __hash__(self):
return hash(self.value) return hash(self.value)
class ArrayCriterion(_BaseCriterion): class ArrayCriterion(_BaseCriterion):
def __init__(self, original_text, values): def __init__(self, original_text, values):
super().__init__(original_text) super().__init__(original_text)

View file

@ -3,19 +3,22 @@ from szurubooru import db, errors
from szurubooru.func import cache from szurubooru.func import cache
from szurubooru.search import tokens, parser from szurubooru.search import tokens, parser
def _format_dict_keys(source): def _format_dict_keys(source):
return list(sorted(source.keys())) return list(sorted(source.keys()))
def _get_direction(direction, default_direction):
if direction == tokens.SortToken.SORT_DEFAULT: def _get_order(order, default_order):
return default_direction if order == tokens.SortToken.SORT_DEFAULT:
if direction == tokens.SortToken.SORT_NEGATED_DEFAULT: return default_order
if default_direction == tokens.SortToken.SORT_ASC: if order == tokens.SortToken.SORT_NEGATED_DEFAULT:
if default_order == tokens.SortToken.SORT_ASC:
return tokens.SortToken.SORT_DESC return tokens.SortToken.SORT_DESC
elif default_direction == tokens.SortToken.SORT_DESC: elif default_order == tokens.SortToken.SORT_DESC:
return tokens.SortToken.SORT_ASC return tokens.SortToken.SORT_ASC
assert False assert False
return direction return order
class Executor(object): class Executor(object):
''' '''
@ -30,20 +33,26 @@ class Executor(object):
def get_around(self, query_text, entity_id): def get_around(self, query_text, entity_id):
search_query = self.parser.parse(query_text) search_query = self.parser.parse(query_text)
self.config.on_search_query_parsed(search_query) self.config.on_search_query_parsed(search_query)
filter_query = self.config \ filter_query = (
.create_around_query() \ self.config
.options(sqlalchemy.orm.lazyload('*')) .create_around_query()
filter_query = self._prepare_db_query(filter_query, search_query, False) .options(sqlalchemy.orm.lazyload('*')))
prev_filter_query = filter_query \ filter_query = self._prepare_db_query(
.filter(self.config.id_column < entity_id) \ filter_query, search_query, False)
.order_by(None) \ prev_filter_query = (
.order_by(sqlalchemy.func.abs(self.config.id_column - entity_id).asc()) \ filter_query
.limit(1) .filter(self.config.id_column < entity_id)
next_filter_query = filter_query \ .order_by(None)
.filter(self.config.id_column > entity_id) \ .order_by(sqlalchemy.func.abs(
.order_by(None) \ self.config.id_column - entity_id).asc())
.order_by(sqlalchemy.func.abs(self.config.id_column - entity_id).asc()) \ .limit(1))
.limit(1) next_filter_query = (
filter_query
.filter(self.config.id_column > entity_id)
.order_by(None)
.order_by(sqlalchemy.func.abs(
self.config.id_column - entity_id).asc())
.limit(1))
return [ return [
next_filter_query.one_or_none(), next_filter_query.one_or_none(),
prev_filter_query.one_or_none()] prev_filter_query.one_or_none()]
@ -92,7 +101,8 @@ class Executor(object):
def execute_and_serialize(self, ctx, serializer): def execute_and_serialize(self, ctx, serializer):
query = ctx.get_param_as_string('query') query = ctx.get_param_as_string('query')
page = ctx.get_param_as_int('page', default=1, min=1) page = ctx.get_param_as_int('page', default=1, min=1)
page_size = ctx.get_param_as_int('pageSize', default=100, min=1, max=100) page_size = ctx.get_param_as_int(
'pageSize', default=100, min=1, max=100)
count, entities = self.execute(query, page, page_size) count, entities = self.execute(query, page, page_size)
return { return {
'query': query, 'query': query,
@ -124,7 +134,8 @@ class Executor(object):
for token in search_query.special_tokens: for token in search_query.special_tokens:
if token.value not in self.config.special_filters: if token.value not in self.config.special_filters:
raise errors.SearchError( raise errors.SearchError(
'Unknown special token: %r. Available special tokens: %r.' % ( 'Unknown special token: %r. '
'Available special tokens: %r.' % (
token.value, token.value,
_format_dict_keys(self.config.special_filters))) _format_dict_keys(self.config.special_filters)))
db_query = self.config.special_filters[token.value]( db_query = self.config.special_filters[token.value](
@ -134,14 +145,15 @@ class Executor(object):
for token in search_query.sort_tokens: for token in search_query.sort_tokens:
if token.name not in self.config.sort_columns: if token.name not in self.config.sort_columns:
raise errors.SearchError( raise errors.SearchError(
'Unknown sort token: %r. Available sort tokens: %r.' % ( 'Unknown sort token: %r. '
'Available sort tokens: %r.' % (
token.name, token.name,
_format_dict_keys(self.config.sort_columns))) _format_dict_keys(self.config.sort_columns)))
column, default_direction = self.config.sort_columns[token.name] column, default_order = self.config.sort_columns[token.name]
direction = _get_direction(token.direction, default_direction) order = _get_order(token.order, default_order)
if direction == token.SORT_ASC: if order == token.SORT_ASC:
db_query = db_query.order_by(column.asc()) db_query = db_query.order_by(column.asc())
elif direction == token.SORT_DESC: elif order == token.SORT_DESC:
db_query = db_query.order_by(column.desc()) db_query = db_query.order_by(column.desc())
db_query = self.config.finalize_query(db_query) db_query = self.config.finalize_query(db_query)

View file

@ -2,6 +2,7 @@ import re
from szurubooru import errors from szurubooru import errors
from szurubooru.search import criteria, tokens from szurubooru.search import criteria, tokens
def _create_criterion(original_value, value): def _create_criterion(original_value, value):
if '..' in value: if '..' in value:
low, high = value.split('..', 1) low, high = value.split('..', 1)
@ -13,10 +14,12 @@ def _create_criterion(original_value, value):
original_value, value.split(',')) original_value, value.split(','))
return criteria.PlainCriterion(original_value, value) return criteria.PlainCriterion(original_value, value)
def _parse_anonymous(value, negated): def _parse_anonymous(value, negated):
criterion = _create_criterion(value, value) criterion = _create_criterion(value, value)
return tokens.AnonymousToken(criterion, negated) return tokens.AnonymousToken(criterion, negated)
def _parse_named(key, value, negated): def _parse_named(key, value, negated):
original_value = value original_value = value
if key.endswith('-min'): if key.endswith('-min'):
@ -28,34 +31,41 @@ def _parse_named(key, value, negated):
criterion = _create_criterion(original_value, value) criterion = _create_criterion(original_value, value)
return tokens.NamedToken(key, criterion, negated) return tokens.NamedToken(key, criterion, negated)
def _parse_special(value, negated): def _parse_special(value, negated):
return tokens.SpecialToken(value, negated) return tokens.SpecialToken(value, negated)
def _parse_sort(value, negated): def _parse_sort(value, negated):
if value.count(',') == 0: if value.count(',') == 0:
direction_str = None order_str = None
elif value.count(',') == 1: elif value.count(',') == 1:
value, direction_str = value.split(',') value, order_str = value.split(',')
else: else:
raise errors.SearchError('Too many commas in sort style token.') raise errors.SearchError('Too many commas in sort style token.')
try: try:
direction = { order = {
'asc': tokens.SortToken.SORT_ASC, 'asc': tokens.SortToken.SORT_ASC,
'desc': tokens.SortToken.SORT_DESC, 'desc': tokens.SortToken.SORT_DESC,
'': tokens.SortToken.SORT_DEFAULT, '': tokens.SortToken.SORT_DEFAULT,
None: tokens.SortToken.SORT_DEFAULT, None: tokens.SortToken.SORT_DEFAULT,
}[direction_str] }[order_str]
except KeyError: except KeyError:
raise errors.SearchError( raise errors.SearchError(
'Unknown search direction: %r.' % direction_str) 'Unknown search direction: %r.' % order_str)
if negated: if negated:
direction = { order = {
tokens.SortToken.SORT_ASC: tokens.SortToken.SORT_DESC, tokens.SortToken.SORT_ASC:
tokens.SortToken.SORT_DESC: tokens.SortToken.SORT_ASC, tokens.SortToken.SORT_DESC,
tokens.SortToken.SORT_DEFAULT: tokens.SortToken.SORT_NEGATED_DEFAULT, tokens.SortToken.SORT_DESC:
tokens.SortToken.SORT_NEGATED_DEFAULT: tokens.SortToken.SORT_DEFAULT, tokens.SortToken.SORT_ASC,
}[direction] tokens.SortToken.SORT_DEFAULT:
return tokens.SortToken(value, direction) tokens.SortToken.SORT_NEGATED_DEFAULT,
tokens.SortToken.SORT_NEGATED_DEFAULT:
tokens.SortToken.SORT_DEFAULT,
}[order]
return tokens.SortToken(value, order)
class SearchQuery(): class SearchQuery():
def __init__(self): def __init__(self):
@ -71,6 +81,7 @@ class SearchQuery():
tuple(self.special_tokens), tuple(self.special_tokens),
tuple(self.sort_tokens))) tuple(self.sort_tokens)))
class Parser(object): class Parser(object):
def parse(self, query_text): def parse(self, query_text):
query = SearchQuery() query = SearchQuery()
@ -93,5 +104,6 @@ class Parser(object):
query.named_tokens.append( query.named_tokens.append(
_parse_named(key, value, negated)) _parse_named(key, value, negated))
else: else:
query.anonymous_tokens.append(_parse_anonymous(chunk, negated)) query.anonymous_tokens.append(
_parse_anonymous(chunk, negated))
return query return query

View file

@ -6,6 +6,7 @@ class AnonymousToken(object):
def __hash__(self): def __hash__(self):
return hash((self.criterion, self.negated)) return hash((self.criterion, self.negated))
class NamedToken(AnonymousToken): class NamedToken(AnonymousToken):
def __init__(self, name, criterion, negated): def __init__(self, name, criterion, negated):
super().__init__(criterion, negated) super().__init__(criterion, negated)
@ -14,18 +15,20 @@ class NamedToken(AnonymousToken):
def __hash__(self): def __hash__(self):
return hash((self.name, self.criterion, self.negated)) return hash((self.name, self.criterion, self.negated))
class SortToken(object): class SortToken(object):
SORT_DESC = 'desc' SORT_DESC = 'desc'
SORT_ASC = 'asc' SORT_ASC = 'asc'
SORT_DEFAULT = 'default' SORT_DEFAULT = 'default'
SORT_NEGATED_DEFAULT = 'negated default' SORT_NEGATED_DEFAULT = 'negated default'
def __init__(self, name, direction): def __init__(self, name, order):
self.name = name self.name = name
self.direction = direction self.order = order
def __hash__(self): def __hash__(self):
return hash((self.name, self.direction)) return hash((self.name, self.order))
class SpecialToken(object): class SpecialToken(object):
def __init__(self, value, negated): def __init__(self, value, negated):

View file

View file

View file

@ -1,20 +1,22 @@
import pytest
import unittest.mock
from datetime import datetime from datetime import datetime
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import comments, posts from szurubooru.func import comments, posts
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'comments:create': db.User.RANK_REGULAR}}) config_injector({'privileges': {'comments:create': db.User.RANK_REGULAR}})
def test_creating_comment( def test_creating_comment(
user_factory, post_factory, context_factory, fake_datetime): user_factory, post_factory, context_factory, fake_datetime):
post = post_factory() post = post_factory()
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=db.User.RANK_REGULAR)
db.session.add_all([post, user]) db.session.add_all([post, user])
db.session.flush() db.session.flush()
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'), \ with patch('szurubooru.func.comments.serialize_comment'), \
fake_datetime('1997-01-01'): fake_datetime('1997-01-01'):
comments.serialize_comment.return_value = 'serialized comment' comments.serialize_comment.return_value = 'serialized comment'
result = api.comment_api.create_comment( result = api.comment_api.create_comment(
@ -29,6 +31,7 @@ def test_creating_comment(
assert comment.user and comment.user.user_id == user.user_id assert comment.user and comment.user.user_id == user.user_id
assert comment.post and comment.post.post_id == post.post_id assert comment.post and comment.post.post_id == post.post_id
@pytest.mark.parametrize('params', [ @pytest.mark.parametrize('params', [
{'text': None}, {'text': None},
{'text': ''}, {'text': ''},
@ -48,6 +51,7 @@ def test_trying_to_pass_invalid_params(
api.comment_api.create_comment( api.comment_api.create_comment(
context_factory(params=real_params, user=user)) context_factory(params=real_params, user=user))
@pytest.mark.parametrize('field', ['text', 'postId']) @pytest.mark.parametrize('field', ['text', 'postId'])
def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
params = { params = {
@ -61,6 +65,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
params={}, params={},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_comment_non_existing(user_factory, context_factory): def test_trying_to_comment_non_existing(user_factory, context_factory):
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=db.User.RANK_REGULAR)
db.session.add_all([user]) db.session.add_all([user])
@ -70,6 +75,7 @@ def test_trying_to_comment_non_existing(user_factory, context_factory):
context_factory( context_factory(
params={'text': 'bad', 'postId': 5}, user=user)) params={'text': 'bad', 'postId': 5}, user=user))
def test_trying_to_create_without_privileges(user_factory, context_factory): def test_trying_to_create_without_privileges(user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.comment_api.create_comment( api.comment_api.create_comment(

View file

@ -2,6 +2,7 @@ import pytest
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import comments from szurubooru.func import comments
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
@ -11,6 +12,7 @@ def inject_config(config_injector):
}, },
}) })
def test_deleting_own_comment(user_factory, comment_factory, context_factory): def test_deleting_own_comment(user_factory, comment_factory, context_factory):
user = user_factory() user = user_factory()
comment = comment_factory(user=user) comment = comment_factory(user=user)
@ -22,6 +24,7 @@ def test_deleting_own_comment(user_factory, comment_factory, context_factory):
assert result == {} assert result == {}
assert db.session.query(db.Comment).count() == 0 assert db.session.query(db.Comment).count() == 0
def test_deleting_someones_else_comment( def test_deleting_someones_else_comment(
user_factory, comment_factory, context_factory): user_factory, comment_factory, context_factory):
user1 = user_factory(rank=db.User.RANK_REGULAR) user1 = user_factory(rank=db.User.RANK_REGULAR)
@ -29,11 +32,12 @@ def test_deleting_someones_else_comment(
comment = comment_factory(user=user1) comment = comment_factory(user=user1)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
result = api.comment_api.delete_comment( api.comment_api.delete_comment(
context_factory(params={'version': 1}, user=user2), context_factory(params={'version': 1}, user=user2),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
assert db.session.query(db.Comment).count() == 0 assert db.session.query(db.Comment).count() == 0
def test_trying_to_delete_someones_else_comment_without_privileges( def test_trying_to_delete_someones_else_comment_without_privileges(
user_factory, comment_factory, context_factory): user_factory, comment_factory, context_factory):
user1 = user_factory(rank=db.User.RANK_REGULAR) user1 = user_factory(rank=db.User.RANK_REGULAR)
@ -47,6 +51,7 @@ def test_trying_to_delete_someones_else_comment_without_privileges(
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
assert db.session.query(db.Comment).count() == 1 assert db.session.query(db.Comment).count() == 1
def test_trying_to_delete_non_existing(user_factory, context_factory): def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(comments.CommentNotFoundError): with pytest.raises(comments.CommentNotFoundError):
api.comment_api.delete_comment( api.comment_api.delete_comment(

View file

@ -1,86 +1,92 @@
from unittest.mock import patch
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import comments, scores from szurubooru.func import comments
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'comments:score': db.User.RANK_REGULAR}}) config_injector({'privileges': {'comments:score': db.User.RANK_REGULAR}})
def test_simple_rating( def test_simple_rating(
user_factory, comment_factory, context_factory, fake_datetime): user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user) comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): with patch('szurubooru.func.comments.serialize_comment'), \
fake_datetime('1997-12-01'):
comments.serialize_comment.return_value = 'serialized comment' comments.serialize_comment.return_value = 'serialized comment'
with fake_datetime('1997-12-01'): result = api.comment_api.set_comment_score(
result = api.comment_api.set_comment_score( context_factory(params={'score': 1}, user=user),
context_factory(params={'score': 1}, user=user), {'comment_id': comment.comment_id})
{'comment_id': comment.comment_id})
assert result == 'serialized comment' assert result == 'serialized comment'
assert db.session.query(db.CommentScore).count() == 1 assert db.session.query(db.CommentScore).count() == 1
assert comment is not None assert comment is not None
assert comment.score == 1 assert comment.score == 1
def test_updating_rating( def test_updating_rating(
user_factory, comment_factory, context_factory, fake_datetime): user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user) comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): with patch('szurubooru.func.comments.serialize_comment'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = api.comment_api.set_comment_score( api.comment_api.set_comment_score(
context_factory(params={'score': 1}, user=user), context_factory(params={'score': 1}, user=user),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = api.comment_api.set_comment_score( api.comment_api.set_comment_score(
context_factory(params={'score': -1}, user=user), context_factory(params={'score': -1}, user=user),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one() comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 1 assert db.session.query(db.CommentScore).count() == 1
assert comment.score == -1 assert comment.score == -1
def test_updating_rating_to_zero( def test_updating_rating_to_zero(
user_factory, comment_factory, context_factory, fake_datetime): user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user) comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): with patch('szurubooru.func.comments.serialize_comment'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = api.comment_api.set_comment_score( api.comment_api.set_comment_score(
context_factory(params={'score': 1}, user=user), context_factory(params={'score': 1}, user=user),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = api.comment_api.set_comment_score( api.comment_api.set_comment_score(
context_factory(params={'score': 0}, user=user), context_factory(params={'score': 0}, user=user),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one() comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 0 assert db.session.query(db.CommentScore).count() == 0
assert comment.score == 0 assert comment.score == 0
def test_deleting_rating( def test_deleting_rating(
user_factory, comment_factory, context_factory, fake_datetime): user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user) comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): with patch('szurubooru.func.comments.serialize_comment'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = api.comment_api.set_comment_score( api.comment_api.set_comment_score(
context_factory(params={'score': 1}, user=user), context_factory(params={'score': 1}, user=user),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = api.comment_api.delete_comment_score( api.comment_api.delete_comment_score(
context_factory(user=user), context_factory(user=user),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one() comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 0 assert db.session.query(db.CommentScore).count() == 0
assert comment.score == 0 assert comment.score == 0
def test_ratings_from_multiple_users( def test_ratings_from_multiple_users(
user_factory, comment_factory, context_factory, fake_datetime): user_factory, comment_factory, context_factory, fake_datetime):
user1 = user_factory(rank=db.User.RANK_REGULAR) user1 = user_factory(rank=db.User.RANK_REGULAR)
@ -88,19 +94,20 @@ def test_ratings_from_multiple_users(
comment = comment_factory() comment = comment_factory()
db.session.add_all([user1, user2, comment]) db.session.add_all([user1, user2, comment])
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): with patch('szurubooru.func.comments.serialize_comment'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = api.comment_api.set_comment_score( api.comment_api.set_comment_score(
context_factory(params={'score': 1}, user=user1), context_factory(params={'score': 1}, user=user1),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = api.comment_api.set_comment_score( api.comment_api.set_comment_score(
context_factory(params={'score': -1}, user=user2), context_factory(params={'score': -1}, user=user2),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one() comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 2 assert db.session.query(db.CommentScore).count() == 2
assert comment.score == 0 assert comment.score == 0
def test_trying_to_omit_mandatory_field( def test_trying_to_omit_mandatory_field(
user_factory, comment_factory, context_factory): user_factory, comment_factory, context_factory):
user = user_factory() user = user_factory()
@ -112,8 +119,8 @@ def test_trying_to_omit_mandatory_field(
context_factory(params={}, user=user), context_factory(params={}, user=user),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
def test_trying_to_update_non_existing(
user_factory, comment_factory, context_factory): def test_trying_to_update_non_existing(user_factory, context_factory):
with pytest.raises(comments.CommentNotFoundError): with pytest.raises(comments.CommentNotFoundError):
api.comment_api.set_comment_score( api.comment_api.set_comment_score(
context_factory( context_factory(
@ -121,6 +128,7 @@ def test_trying_to_update_non_existing(
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
{'comment_id': 5}) {'comment_id': 5})
def test_trying_to_rate_without_privileges( def test_trying_to_rate_without_privileges(
user_factory, comment_factory, context_factory): user_factory, comment_factory, context_factory):
comment = comment_factory() comment = comment_factory()

View file

@ -1,8 +1,9 @@
from unittest.mock import patch
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import comments from szurubooru.func import comments
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
@ -12,11 +13,12 @@ def inject_config(config_injector):
}, },
}) })
def test_retrieving_multiple(user_factory, comment_factory, context_factory): def test_retrieving_multiple(user_factory, comment_factory, context_factory):
comment1 = comment_factory(text='text 1') comment1 = comment_factory(text='text 1')
comment2 = comment_factory(text='text 2') comment2 = comment_factory(text='text 2')
db.session.add_all([comment1, comment2]) db.session.add_all([comment1, comment2])
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): with patch('szurubooru.func.comments.serialize_comment'):
comments.serialize_comment.return_value = 'serialized comment' comments.serialize_comment.return_value = 'serialized comment'
result = api.comment_api.get_comments( result = api.comment_api.get_comments(
context_factory( context_factory(
@ -30,6 +32,7 @@ def test_retrieving_multiple(user_factory, comment_factory, context_factory):
'results': ['serialized comment', 'serialized comment'], 'results': ['serialized comment', 'serialized comment'],
} }
def test_trying_to_retrieve_multiple_without_privileges( def test_trying_to_retrieve_multiple_without_privileges(
user_factory, context_factory): user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
@ -38,11 +41,12 @@ def test_trying_to_retrieve_multiple_without_privileges(
params={'query': '', 'page': 1}, params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_retrieving_single(user_factory, comment_factory, context_factory): def test_retrieving_single(user_factory, comment_factory, context_factory):
comment = comment_factory(text='dummy text') comment = comment_factory(text='dummy text')
db.session.add(comment) db.session.add(comment)
db.session.flush() db.session.flush()
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): with patch('szurubooru.func.comments.serialize_comment'):
comments.serialize_comment.return_value = 'serialized comment' comments.serialize_comment.return_value = 'serialized comment'
result = api.comment_api.get_comment( result = api.comment_api.get_comment(
context_factory( context_factory(
@ -50,6 +54,7 @@ def test_retrieving_single(user_factory, comment_factory, context_factory):
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
assert result == 'serialized comment' assert result == 'serialized comment'
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(comments.CommentNotFoundError): with pytest.raises(comments.CommentNotFoundError):
api.comment_api.get_comment( api.comment_api.get_comment(
@ -57,6 +62,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
{'comment_id': 5}) {'comment_id': 5})
def test_trying_to_retrieve_single_without_privileges( def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory): user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):

View file

@ -1,9 +1,10 @@
import pytest
import unittest.mock
from datetime import datetime from datetime import datetime
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import comments from szurubooru.func import comments
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
@ -13,13 +14,14 @@ def inject_config(config_injector):
}, },
}) })
def test_simple_updating( def test_simple_updating(
user_factory, comment_factory, context_factory, fake_datetime): user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user) comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'), \ with patch('szurubooru.func.comments.serialize_comment'), \
fake_datetime('1997-12-01'): fake_datetime('1997-12-01'):
comments.serialize_comment.return_value = 'serialized comment' comments.serialize_comment.return_value = 'serialized comment'
result = api.comment_api.update_comment( result = api.comment_api.update_comment(
@ -29,6 +31,7 @@ def test_simple_updating(
assert result == 'serialized comment' assert result == 'serialized comment'
assert comment.last_edit_time == datetime(1997, 12, 1) assert comment.last_edit_time == datetime(1997, 12, 1)
@pytest.mark.parametrize('params,expected_exception', [ @pytest.mark.parametrize('params,expected_exception', [
({'text': None}, comments.EmptyCommentTextError), ({'text': None}, comments.EmptyCommentTextError),
({'text': ''}, comments.EmptyCommentTextError), ({'text': ''}, comments.EmptyCommentTextError),
@ -37,7 +40,11 @@ def test_simple_updating(
({'text': ['']}, comments.EmptyCommentTextError), ({'text': ['']}, comments.EmptyCommentTextError),
]) ])
def test_trying_to_pass_invalid_params( def test_trying_to_pass_invalid_params(
user_factory, comment_factory, context_factory, params, expected_exception): user_factory,
comment_factory,
context_factory,
params,
expected_exception):
user = user_factory() user = user_factory()
comment = comment_factory(user=user) comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
@ -48,6 +55,7 @@ def test_trying_to_pass_invalid_params(
params={**params, **{'version': 1}}, user=user), params={**params, **{'version': 1}}, user=user),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
def test_trying_to_omit_mandatory_field( def test_trying_to_omit_mandatory_field(
user_factory, comment_factory, context_factory): user_factory, comment_factory, context_factory):
user = user_factory() user = user_factory()
@ -59,6 +67,7 @@ def test_trying_to_omit_mandatory_field(
context_factory(params={'version': 1}, user=user), context_factory(params={'version': 1}, user=user),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
def test_trying_to_update_non_existing(user_factory, context_factory): def test_trying_to_update_non_existing(user_factory, context_factory):
with pytest.raises(comments.CommentNotFoundError): with pytest.raises(comments.CommentNotFoundError):
api.comment_api.update_comment( api.comment_api.update_comment(
@ -67,6 +76,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
{'comment_id': 5}) {'comment_id': 5})
def test_trying_to_update_someones_comment_without_privileges( def test_trying_to_update_someones_comment_without_privileges(
user_factory, comment_factory, context_factory): user_factory, comment_factory, context_factory):
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=db.User.RANK_REGULAR)
@ -80,6 +90,7 @@ def test_trying_to_update_someones_comment_without_privileges(
params={'text': 'new text', 'version': 1}, user=user2), params={'text': 'new text', 'version': 1}, user=user2),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
def test_updating_someones_comment_with_privileges( def test_updating_someones_comment_with_privileges(
user_factory, comment_factory, context_factory): user_factory, comment_factory, context_factory):
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=db.User.RANK_REGULAR)
@ -87,7 +98,7 @@ def test_updating_someones_comment_with_privileges(
comment = comment_factory(user=user) comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'): with patch('szurubooru.func.comments.serialize_comment'):
api.comment_api.update_comment( api.comment_api.update_comment(
context_factory( context_factory(
params={'text': 'new text', 'version': 1}, user=user2), params={'text': 'new text', 'version': 1}, user=user2),

View file

@ -1,6 +1,7 @@
from datetime import datetime from datetime import datetime
from szurubooru import api, db from szurubooru import api, db
def test_info_api( def test_info_api(
tmpdir, config_injector, context_factory, post_factory, fake_datetime): tmpdir, config_injector, context_factory, post_factory, fake_datetime):
directory = tmpdir.mkdir('data') directory = tmpdir.mkdir('data')
@ -45,7 +46,7 @@ def test_info_api(
with fake_datetime('2016-01-01 13:59'): with fake_datetime('2016-01-01 13:59'):
assert api.info_api.get_info(context_factory()) == { assert api.info_api.get_info(context_factory()) == {
'postCount': 2, 'postCount': 2,
'diskUsage': 3, # still 3 - it's cached 'diskUsage': 3, # still 3 - it's cached
'featuredPost': None, 'featuredPost': None,
'featuringTime': None, 'featuringTime': None,
'featuringUser': None, 'featuringUser': None,
@ -55,7 +56,7 @@ def test_info_api(
with fake_datetime('2016-01-01 14:01'): with fake_datetime('2016-01-01 14:01'):
assert api.info_api.get_info(context_factory()) == { assert api.info_api.get_info(context_factory()) == {
'postCount': 2, 'postCount': 2,
'diskUsage': 6, # cache expired 'diskUsage': 6, # cache expired
'featuredPost': None, 'featuredPost': None,
'featuringTime': None, 'featuringTime': None,
'featuringUser': None, 'featuringUser': None,

View file

@ -1,21 +1,23 @@
from unittest.mock import patch
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import auth, mailer from szurubooru.func import auth, mailer
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(tmpdir, config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'secret': 'x', 'secret': 'x',
'base_url': 'http://example.com/', 'base_url': 'http://example.com/',
'name': 'Test instance', 'name': 'Test instance',
}) })
def test_reset_sending_email(context_factory, user_factory): def test_reset_sending_email(context_factory, user_factory):
db.session.add(user_factory( db.session.add(user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
for initiating_user in ['u1', 'user@example.com']: for initiating_user in ['u1', 'user@example.com']:
with unittest.mock.patch('szurubooru.func.mailer.send_mail'): with patch('szurubooru.func.mailer.send_mail'):
assert api.password_reset_api.start_password_reset( assert api.password_reset_api.start_password_reset(
context_factory(), {'user_name': initiating_user}) == {} context_factory(), {'user_name': initiating_user}) == {}
mailer.send_mail.assert_called_once_with( mailer.send_mail.assert_called_once_with(
@ -27,17 +29,21 @@ def test_reset_sending_email(context_factory, user_factory):
'ink: http://example.com/password-reset/u1:4ac0be176fb36' + 'ink: http://example.com/password-reset/u1:4ac0be176fb36' +
'4f13ee6b634c43220e2\nOtherwise, please ignore this email.') '4f13ee6b634c43220e2\nOtherwise, please ignore this email.')
def test_trying_to_reset_non_existing(context_factory): def test_trying_to_reset_non_existing(context_factory):
with pytest.raises(errors.NotFoundError): with pytest.raises(errors.NotFoundError):
api.password_reset_api.start_password_reset( api.password_reset_api.start_password_reset(
context_factory(), {'user_name': 'u1'}) context_factory(), {'user_name': 'u1'})
def test_trying_to_reset_without_email(context_factory, user_factory): def test_trying_to_reset_without_email(context_factory, user_factory):
db.session.add(user_factory(name='u1', rank=db.User.RANK_REGULAR, email=None)) db.session.add(
user_factory(name='u1', rank=db.User.RANK_REGULAR, email=None))
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
api.password_reset_api.start_password_reset( api.password_reset_api.start_password_reset(
context_factory(), {'user_name': 'u1'}) context_factory(), {'user_name': 'u1'})
def test_confirming_with_good_token(context_factory, user_factory): def test_confirming_with_good_token(context_factory, user_factory):
user = user_factory( user = user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com') name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')
@ -50,11 +56,13 @@ def test_confirming_with_good_token(context_factory, user_factory):
assert user.password_hash != old_hash assert user.password_hash != old_hash
assert auth.is_valid_password(user, result['password']) is True assert auth.is_valid_password(user, result['password']) is True
def test_trying_to_confirm_non_existing(context_factory): def test_trying_to_confirm_non_existing(context_factory):
with pytest.raises(errors.NotFoundError): with pytest.raises(errors.NotFoundError):
api.password_reset_api.finish_password_reset( api.password_reset_api.finish_password_reset(
context_factory(), {'user_name': 'u1'}) context_factory(), {'user_name': 'u1'})
def test_trying_to_confirm_without_token(context_factory, user_factory): def test_trying_to_confirm_without_token(context_factory, user_factory):
db.session.add(user_factory( db.session.add(user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
@ -62,6 +70,7 @@ def test_trying_to_confirm_without_token(context_factory, user_factory):
api.password_reset_api.finish_password_reset( api.password_reset_api.finish_password_reset(
context_factory(params={}), {'user_name': 'u1'}) context_factory(params={}), {'user_name': 'u1'})
def test_trying_to_confirm_with_bad_token(context_factory, user_factory): def test_trying_to_confirm_with_bad_token(context_factory, user_factory):
db.session.add(user_factory( db.session.add(user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))

View file

@ -1,8 +1,9 @@
from unittest.mock import patch
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import posts, tags, snapshots, net from szurubooru.func import posts, tags, snapshots, net
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
@ -13,6 +14,7 @@ def inject_config(config_injector):
}, },
}) })
def test_creating_minimal_posts( def test_creating_minimal_posts(
context_factory, post_factory, user_factory): context_factory, post_factory, user_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=db.User.RANK_REGULAR)
@ -20,16 +22,16 @@ def test_creating_minimal_posts(
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
with unittest.mock.patch('szurubooru.func.posts.create_post'), \ with patch('szurubooru.func.posts.create_post'), \
unittest.mock.patch('szurubooru.func.posts.update_post_safety'), \ patch('szurubooru.func.posts.update_post_safety'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'), \ patch('szurubooru.func.posts.update_post_source'), \
unittest.mock.patch('szurubooru.func.posts.update_post_relations'), \ patch('szurubooru.func.posts.update_post_relations'), \
unittest.mock.patch('szurubooru.func.posts.update_post_notes'), \ patch('szurubooru.func.posts.update_post_notes'), \
unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \ patch('szurubooru.func.posts.update_post_flags'), \
unittest.mock.patch('szurubooru.func.posts.update_post_thumbnail'), \ patch('szurubooru.func.posts.update_post_thumbnail'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ patch('szurubooru.func.posts.serialize_post'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ patch('szurubooru.func.tags.export_to_json'), \
unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'): patch('szurubooru.func.snapshots.save_entity_creation'):
posts.create_post.return_value = (post, []) posts.create_post.return_value = (post, [])
posts.serialize_post.return_value = 'serialized post' posts.serialize_post.return_value = 'serialized post'
@ -48,32 +50,36 @@ def test_creating_minimal_posts(
assert result == 'serialized post' assert result == 'serialized post'
posts.create_post.assert_called_once_with( posts.create_post.assert_called_once_with(
'post-content', ['tag1', 'tag2'], auth_user) 'post-content', ['tag1', 'tag2'], auth_user)
posts.update_post_thumbnail.assert_called_once_with(post, 'post-thumbnail') posts.update_post_thumbnail.assert_called_once_with(
post, 'post-thumbnail')
posts.update_post_safety.assert_called_once_with(post, 'safe') posts.update_post_safety.assert_called_once_with(post, 'safe')
posts.update_post_source.assert_called_once_with(post, None) posts.update_post_source.assert_called_once_with(post, None)
posts.update_post_relations.assert_called_once_with(post, []) posts.update_post_relations.assert_called_once_with(post, [])
posts.update_post_notes.assert_called_once_with(post, []) posts.update_post_notes.assert_called_once_with(post, [])
posts.update_post_flags.assert_called_once_with(post, []) posts.update_post_flags.assert_called_once_with(post, [])
posts.update_post_thumbnail.assert_called_once_with(post, 'post-thumbnail') posts.update_post_thumbnail.assert_called_once_with(
posts.serialize_post.assert_called_once_with(post, auth_user, options=None) post, 'post-thumbnail')
posts.serialize_post.assert_called_once_with(
post, auth_user, options=None)
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
snapshots.save_entity_creation.assert_called_once_with(post, auth_user) snapshots.save_entity_creation.assert_called_once_with(post, auth_user)
def test_creating_full_posts(context_factory, post_factory, user_factory): def test_creating_full_posts(context_factory, post_factory, user_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=db.User.RANK_REGULAR)
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
with unittest.mock.patch('szurubooru.func.posts.create_post'), \ with patch('szurubooru.func.posts.create_post'), \
unittest.mock.patch('szurubooru.func.posts.update_post_safety'), \ patch('szurubooru.func.posts.update_post_safety'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'), \ patch('szurubooru.func.posts.update_post_source'), \
unittest.mock.patch('szurubooru.func.posts.update_post_relations'), \ patch('szurubooru.func.posts.update_post_relations'), \
unittest.mock.patch('szurubooru.func.posts.update_post_notes'), \ patch('szurubooru.func.posts.update_post_notes'), \
unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \ patch('szurubooru.func.posts.update_post_flags'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ patch('szurubooru.func.posts.serialize_post'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ patch('szurubooru.func.tags.export_to_json'), \
unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'): patch('szurubooru.func.snapshots.save_entity_creation'):
posts.create_post.return_value = (post, []) posts.create_post.return_value = (post, [])
posts.serialize_post.return_value = 'serialized post' posts.serialize_post.return_value = 'serialized post'
@ -98,12 +104,16 @@ def test_creating_full_posts(context_factory, post_factory, user_factory):
posts.update_post_safety.assert_called_once_with(post, 'safe') posts.update_post_safety.assert_called_once_with(post, 'safe')
posts.update_post_source.assert_called_once_with(post, 'source') posts.update_post_source.assert_called_once_with(post, 'source')
posts.update_post_relations.assert_called_once_with(post, [1, 2]) posts.update_post_relations.assert_called_once_with(post, [1, 2])
posts.update_post_notes.assert_called_once_with(post, ['note1', 'note2']) posts.update_post_notes.assert_called_once_with(
posts.update_post_flags.assert_called_once_with(post, ['flag1', 'flag2']) post, ['note1', 'note2'])
posts.serialize_post.assert_called_once_with(post, auth_user, options=None) posts.update_post_flags.assert_called_once_with(
post, ['flag1', 'flag2'])
posts.serialize_post.assert_called_once_with(
post, auth_user, options=None)
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
snapshots.save_entity_creation.assert_called_once_with(post, auth_user) snapshots.save_entity_creation.assert_called_once_with(post, auth_user)
def test_anonymous_uploads( def test_anonymous_uploads(
config_injector, context_factory, post_factory, user_factory): config_injector, context_factory, post_factory, user_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=db.User.RANK_REGULAR)
@ -111,11 +121,11 @@ def test_anonymous_uploads(
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
with unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ with patch('szurubooru.func.tags.export_to_json'), \
unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'), \ patch('szurubooru.func.snapshots.save_entity_creation'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ patch('szurubooru.func.posts.serialize_post'), \
unittest.mock.patch('szurubooru.func.posts.create_post'), \ patch('szurubooru.func.posts.create_post'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'): patch('szurubooru.func.posts.update_post_source'):
config_injector({ config_injector({
'privileges': {'posts:create:anonymous': db.User.RANK_REGULAR}, 'privileges': {'posts:create:anonymous': db.User.RANK_REGULAR},
}) })
@ -134,6 +144,7 @@ def test_anonymous_uploads(
posts.create_post.assert_called_once_with( posts.create_post.assert_called_once_with(
'post-content', ['tag1', 'tag2'], None) 'post-content', ['tag1', 'tag2'], None)
def test_creating_from_url_saves_source( def test_creating_from_url_saves_source(
config_injector, context_factory, post_factory, user_factory): config_injector, context_factory, post_factory, user_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=db.User.RANK_REGULAR)
@ -141,12 +152,12 @@ def test_creating_from_url_saves_source(
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
with unittest.mock.patch('szurubooru.func.net.download'), \ with patch('szurubooru.func.net.download'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ patch('szurubooru.func.tags.export_to_json'), \
unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'), \ patch('szurubooru.func.snapshots.save_entity_creation'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ patch('szurubooru.func.posts.serialize_post'), \
unittest.mock.patch('szurubooru.func.posts.create_post'), \ patch('szurubooru.func.posts.create_post'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'): patch('szurubooru.func.posts.update_post_source'):
config_injector({ config_injector({
'privileges': {'posts:create:identified': db.User.RANK_REGULAR}, 'privileges': {'posts:create:identified': db.User.RANK_REGULAR},
}) })
@ -165,6 +176,7 @@ def test_creating_from_url_saves_source(
b'content', ['tag1', 'tag2'], auth_user) b'content', ['tag1', 'tag2'], auth_user)
posts.update_post_source.assert_called_once_with(post, 'example.com') posts.update_post_source.assert_called_once_with(post, 'example.com')
def test_creating_from_url_with_source_specified( def test_creating_from_url_with_source_specified(
config_injector, context_factory, post_factory, user_factory): config_injector, context_factory, post_factory, user_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=db.User.RANK_REGULAR)
@ -172,12 +184,12 @@ def test_creating_from_url_with_source_specified(
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
with unittest.mock.patch('szurubooru.func.net.download'), \ with patch('szurubooru.func.net.download'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ patch('szurubooru.func.tags.export_to_json'), \
unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'), \ patch('szurubooru.func.snapshots.save_entity_creation'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ patch('szurubooru.func.posts.serialize_post'), \
unittest.mock.patch('szurubooru.func.posts.create_post'), \ patch('szurubooru.func.posts.create_post'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'): patch('szurubooru.func.posts.update_post_source'):
config_injector({ config_injector({
'privileges': {'posts:create:identified': db.User.RANK_REGULAR}, 'privileges': {'posts:create:identified': db.User.RANK_REGULAR},
}) })
@ -197,6 +209,7 @@ def test_creating_from_url_with_source_specified(
b'content', ['tag1', 'tag2'], auth_user) b'content', ['tag1', 'tag2'], auth_user)
posts.update_post_source.assert_called_once_with(post, 'example2.com') posts.update_post_source.assert_called_once_with(post, 'example2.com')
@pytest.mark.parametrize('field', ['tags', 'safety']) @pytest.mark.parametrize('field', ['tags', 'safety'])
def test_trying_to_omit_mandatory_field(context_factory, user_factory, field): def test_trying_to_omit_mandatory_field(context_factory, user_factory, field):
params = { params = {
@ -211,6 +224,7 @@ def test_trying_to_omit_mandatory_field(context_factory, user_factory, field):
files={'content': '...'}, files={'content': '...'},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_omit_content(context_factory, user_factory): def test_trying_to_omit_content(context_factory, user_factory):
with pytest.raises(errors.MissingRequiredFileError): with pytest.raises(errors.MissingRequiredFileError):
api.post_api.create_post( api.post_api.create_post(
@ -221,12 +235,15 @@ def test_trying_to_omit_content(context_factory, user_factory):
}, },
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_create_post_without_privileges(context_factory, user_factory):
def test_trying_to_create_post_without_privileges(
context_factory, user_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.post_api.create_post(context_factory( api.post_api.create_post(context_factory(
params='whatever', params='whatever',
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_trying_to_create_tags_without_privileges( def test_trying_to_create_tags_without_privileges(
config_injector, context_factory, user_factory): config_injector, context_factory, user_factory):
config_injector({ config_injector({
@ -237,8 +254,8 @@ def test_trying_to_create_tags_without_privileges(
}, },
}) })
with pytest.raises(errors.AuthError), \ with pytest.raises(errors.AuthError), \
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ patch('szurubooru.func.posts.update_post_content'), \
unittest.mock.patch('szurubooru.func.posts.update_post_tags'): patch('szurubooru.func.posts.update_post_tags'):
posts.update_post_tags.return_value = ['new-tag'] posts.update_post_tags.return_value = ['new-tag']
api.post_api.create_post( api.post_api.create_post(
context_factory( context_factory(

View file

@ -1,16 +1,18 @@
from unittest.mock import patch
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import posts, tags from szurubooru.func import posts, tags
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'posts:delete': db.User.RANK_REGULAR}}) config_injector({'privileges': {'posts:delete': db.User.RANK_REGULAR}})
def test_deleting(user_factory, post_factory, context_factory): def test_deleting(user_factory, post_factory, context_factory):
db.session.add(post_factory(id=1)) db.session.add(post_factory(id=1))
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.tags.export_to_json'): with patch('szurubooru.func.tags.export_to_json'):
result = api.post_api.delete_post( result = api.post_api.delete_post(
context_factory( context_factory(
params={'version': 1}, params={'version': 1},
@ -20,12 +22,14 @@ def test_deleting(user_factory, post_factory, context_factory):
assert db.session.query(db.Post).count() == 0 assert db.session.query(db.Post).count() == 0
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
def test_trying_to_delete_non_existing(user_factory, context_factory): def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
api.post_api.delete_post( api.post_api.delete_post(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
{'post_id': 999}) {'post_id': 999})
def test_trying_to_delete_without_privileges( def test_trying_to_delete_without_privileges(
user_factory, post_factory, context_factory): user_factory, post_factory, context_factory):
db.session.add(post_factory(id=1)) db.session.add(post_factory(id=1))

View file

@ -1,20 +1,22 @@
import pytest
import unittest.mock
from datetime import datetime from datetime import datetime
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import posts from szurubooru.func import posts
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'posts:favorite': db.User.RANK_REGULAR}}) config_injector({'privileges': {'posts:favorite': db.User.RANK_REGULAR}})
def test_adding_to_favorites( def test_adding_to_favorites(
user_factory, post_factory, context_factory, fake_datetime): user_factory, post_factory, context_factory, fake_datetime):
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
assert post.score == 0 assert post.score == 0
with unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ with patch('szurubooru.func.posts.serialize_post'), \
fake_datetime('1997-12-01'): fake_datetime('1997-12-01'):
posts.serialize_post.return_value = 'serialized post' posts.serialize_post.return_value = 'serialized post'
result = api.post_api.add_post_to_favorites( result = api.post_api.add_post_to_favorites(
@ -27,6 +29,7 @@ def test_adding_to_favorites(
assert post.favorite_count == 1 assert post.favorite_count == 1
assert post.score == 1 assert post.score == 1
def test_removing_from_favorites( def test_removing_from_favorites(
user_factory, post_factory, context_factory, fake_datetime): user_factory, post_factory, context_factory, fake_datetime):
user = user_factory() user = user_factory()
@ -34,7 +37,7 @@ def test_removing_from_favorites(
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
assert post.score == 0 assert post.score == 0
with unittest.mock.patch('szurubooru.func.posts.serialize_post'): with patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
api.post_api.add_post_to_favorites( api.post_api.add_post_to_favorites(
context_factory(user=user), context_factory(user=user),
@ -49,13 +52,14 @@ def test_removing_from_favorites(
assert db.session.query(db.PostFavorite).count() == 0 assert db.session.query(db.PostFavorite).count() == 0
assert post.favorite_count == 0 assert post.favorite_count == 0
def test_favoriting_twice( def test_favoriting_twice(
user_factory, post_factory, context_factory, fake_datetime): user_factory, post_factory, context_factory, fake_datetime):
user = user_factory() user = user_factory()
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'): with patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
api.post_api.add_post_to_favorites( api.post_api.add_post_to_favorites(
context_factory(user=user), context_factory(user=user),
@ -68,13 +72,14 @@ def test_favoriting_twice(
assert db.session.query(db.PostFavorite).count() == 1 assert db.session.query(db.PostFavorite).count() == 1
assert post.favorite_count == 1 assert post.favorite_count == 1
def test_removing_twice( def test_removing_twice(
user_factory, post_factory, context_factory, fake_datetime): user_factory, post_factory, context_factory, fake_datetime):
user = user_factory() user = user_factory()
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'): with patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
api.post_api.add_post_to_favorites( api.post_api.add_post_to_favorites(
context_factory(user=user), context_factory(user=user),
@ -91,6 +96,7 @@ def test_removing_twice(
assert db.session.query(db.PostFavorite).count() == 0 assert db.session.query(db.PostFavorite).count() == 0
assert post.favorite_count == 0 assert post.favorite_count == 0
def test_favorites_from_multiple_users( def test_favorites_from_multiple_users(
user_factory, post_factory, context_factory, fake_datetime): user_factory, post_factory, context_factory, fake_datetime):
user1 = user_factory() user1 = user_factory()
@ -98,7 +104,7 @@ def test_favorites_from_multiple_users(
post = post_factory() post = post_factory()
db.session.add_all([user1, user2, post]) db.session.add_all([user1, user2, post])
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'): with patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
api.post_api.add_post_to_favorites( api.post_api.add_post_to_favorites(
context_factory(user=user1), context_factory(user=user1),
@ -112,12 +118,14 @@ def test_favorites_from_multiple_users(
assert post.favorite_count == 2 assert post.favorite_count == 2
assert post.last_favorite_time == datetime(1997, 12, 2) assert post.last_favorite_time == datetime(1997, 12, 2)
def test_trying_to_update_non_existing(user_factory, context_factory): def test_trying_to_update_non_existing(user_factory, context_factory):
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
api.post_api.add_post_to_favorites( api.post_api.add_post_to_favorites(
context_factory(user=user_factory()), context_factory(user=user_factory()),
{'post_id': 5}) {'post_id': 5})
def test_trying_to_rate_without_privileges( def test_trying_to_rate_without_privileges(
user_factory, post_factory, context_factory): user_factory, post_factory, context_factory):
post = post_factory() post = post_factory()

View file

@ -1,8 +1,9 @@
from unittest.mock import patch
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import posts from szurubooru.func import posts
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
@ -12,14 +13,12 @@ def inject_config(config_injector):
}, },
}) })
def test_no_featured_post(user_factory, post_factory, context_factory):
assert posts.try_get_featured_post() is None
def test_featuring(user_factory, post_factory, context_factory): def test_featuring(user_factory, post_factory, context_factory):
db.session.add(post_factory(id=1)) db.session.add(post_factory(id=1))
db.session.commit() db.session.commit()
assert not posts.get_post_by_id(1).is_featured assert not posts.get_post_by_id(1).is_featured
with unittest.mock.patch('szurubooru.func.posts.serialize_post'): with patch('szurubooru.func.posts.serialize_post'):
posts.serialize_post.return_value = 'serialized post' posts.serialize_post.return_value = 'serialized post'
result = api.post_api.set_featured_post( result = api.post_api.set_featured_post(
context_factory( context_factory(
@ -34,18 +33,19 @@ def test_featuring(user_factory, post_factory, context_factory):
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
assert result == 'serialized post' assert result == 'serialized post'
def test_trying_to_omit_required_parameter(
user_factory, post_factory, context_factory): def test_trying_to_omit_required_parameter(user_factory, context_factory):
with pytest.raises(errors.MissingRequiredParameterError): with pytest.raises(errors.MissingRequiredParameterError):
api.post_api.set_featured_post( api.post_api.set_featured_post(
context_factory( context_factory(
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_feature_the_same_post_twice( def test_trying_to_feature_the_same_post_twice(
user_factory, post_factory, context_factory): user_factory, post_factory, context_factory):
db.session.add(post_factory(id=1)) db.session.add(post_factory(id=1))
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'): with patch('szurubooru.func.posts.serialize_post'):
api.post_api.set_featured_post( api.post_api.set_featured_post(
context_factory( context_factory(
params={'id': 1}, params={'id': 1},
@ -56,6 +56,7 @@ def test_trying_to_feature_the_same_post_twice(
params={'id': 1}, params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
def test_featuring_one_post_after_another( def test_featuring_one_post_after_another(
user_factory, post_factory, context_factory, fake_datetime): user_factory, post_factory, context_factory, fake_datetime):
db.session.add(post_factory(id=1)) db.session.add(post_factory(id=1))
@ -64,14 +65,14 @@ def test_featuring_one_post_after_another(
assert posts.try_get_featured_post() is None assert posts.try_get_featured_post() is None
assert not posts.get_post_by_id(1).is_featured assert not posts.get_post_by_id(1).is_featured
assert not posts.get_post_by_id(2).is_featured assert not posts.get_post_by_id(2).is_featured
with unittest.mock.patch('szurubooru.func.posts.serialize_post'): with patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997'): with fake_datetime('1997'):
result = api.post_api.set_featured_post( api.post_api.set_featured_post(
context_factory( context_factory(
params={'id': 1}, params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
with fake_datetime('1998'): with fake_datetime('1998'):
result = api.post_api.set_featured_post( api.post_api.set_featured_post(
context_factory( context_factory(
params={'id': 2}, params={'id': 2},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
@ -80,6 +81,7 @@ def test_featuring_one_post_after_another(
assert not posts.get_post_by_id(1).is_featured assert not posts.get_post_by_id(1).is_featured
assert posts.get_post_by_id(2).is_featured assert posts.get_post_by_id(2).is_featured
def test_trying_to_feature_non_existing(user_factory, context_factory): def test_trying_to_feature_non_existing(user_factory, context_factory):
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
api.post_api.set_featured_post( api.post_api.set_featured_post(
@ -87,6 +89,7 @@ def test_trying_to_feature_non_existing(user_factory, context_factory):
params={'id': 1}, params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_feature_without_privileges(user_factory, context_factory): def test_trying_to_feature_without_privileges(user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.post_api.set_featured_post( api.post_api.set_featured_post(
@ -94,6 +97,7 @@ def test_trying_to_feature_without_privileges(user_factory, context_factory):
params={'id': 1}, params={'id': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_getting_featured_post_without_privileges_to_view( def test_getting_featured_post_without_privileges_to_view(
user_factory, context_factory): user_factory, context_factory):
api.post_api.get_featured_post( api.post_api.get_featured_post(

View file

@ -1,87 +1,93 @@
from unittest.mock import patch
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import posts, scores from szurubooru.func import posts
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'posts:score': db.User.RANK_REGULAR}}) config_injector({'privileges': {'posts:score': db.User.RANK_REGULAR}})
def test_simple_rating( def test_simple_rating(
user_factory, post_factory, context_factory, fake_datetime): user_factory, post_factory, context_factory, fake_datetime):
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'): with patch('szurubooru.func.posts.serialize_post'), \
fake_datetime('1997-12-01'):
posts.serialize_post.return_value = 'serialized post' posts.serialize_post.return_value = 'serialized post'
with fake_datetime('1997-12-01'): result = api.post_api.set_post_score(
result = api.post_api.set_post_score( context_factory(
context_factory( params={'score': 1}, user=user_factory()),
params={'score': 1}, user=user_factory()), {'post_id': post.post_id})
{'post_id': post.post_id})
assert result == 'serialized post' assert result == 'serialized post'
post = db.session.query(db.Post).one() post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 1 assert db.session.query(db.PostScore).count() == 1
assert post is not None assert post is not None
assert post.score == 1 assert post.score == 1
def test_updating_rating( def test_updating_rating(
user_factory, post_factory, context_factory, fake_datetime): user_factory, post_factory, context_factory, fake_datetime):
user = user_factory() user = user_factory()
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'): with patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = api.post_api.set_post_score( api.post_api.set_post_score(
context_factory(params={'score': 1}, user=user), context_factory(params={'score': 1}, user=user),
{'post_id': post.post_id}) {'post_id': post.post_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = api.post_api.set_post_score( api.post_api.set_post_score(
context_factory(params={'score': -1}, user=user), context_factory(params={'score': -1}, user=user),
{'post_id': post.post_id}) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 1 assert db.session.query(db.PostScore).count() == 1
assert post.score == -1 assert post.score == -1
def test_updating_rating_to_zero( def test_updating_rating_to_zero(
user_factory, post_factory, context_factory, fake_datetime): user_factory, post_factory, context_factory, fake_datetime):
user = user_factory() user = user_factory()
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'): with patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = api.post_api.set_post_score( api.post_api.set_post_score(
context_factory(params={'score': 1}, user=user), context_factory(params={'score': 1}, user=user),
{'post_id': post.post_id}) {'post_id': post.post_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = api.post_api.set_post_score( api.post_api.set_post_score(
context_factory(params={'score': 0}, user=user), context_factory(params={'score': 0}, user=user),
{'post_id': post.post_id}) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 0 assert db.session.query(db.PostScore).count() == 0
assert post.score == 0 assert post.score == 0
def test_deleting_rating( def test_deleting_rating(
user_factory, post_factory, context_factory, fake_datetime): user_factory, post_factory, context_factory, fake_datetime):
user = user_factory() user = user_factory()
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'): with patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = api.post_api.set_post_score( api.post_api.set_post_score(
context_factory(params={'score': 1}, user=user), context_factory(params={'score': 1}, user=user),
{'post_id': post.post_id}) {'post_id': post.post_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = api.post_api.delete_post_score( api.post_api.delete_post_score(
context_factory(user=user), context_factory(user=user),
{'post_id': post.post_id}) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 0 assert db.session.query(db.PostScore).count() == 0
assert post.score == 0 assert post.score == 0
def test_ratings_from_multiple_users( def test_ratings_from_multiple_users(
user_factory, post_factory, context_factory, fake_datetime): user_factory, post_factory, context_factory, fake_datetime):
user1 = user_factory() user1 = user_factory()
@ -89,19 +95,20 @@ def test_ratings_from_multiple_users(
post = post_factory() post = post_factory()
db.session.add_all([user1, user2, post]) db.session.add_all([user1, user2, post])
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'): with patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = api.post_api.set_post_score( api.post_api.set_post_score(
context_factory(params={'score': 1}, user=user1), context_factory(params={'score': 1}, user=user1),
{'post_id': post.post_id}) {'post_id': post.post_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = api.post_api.set_post_score( api.post_api.set_post_score(
context_factory(params={'score': -1}, user=user2), context_factory(params={'score': -1}, user=user2),
{'post_id': post.post_id}) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 2 assert db.session.query(db.PostScore).count() == 2
assert post.score == 0 assert post.score == 0
def test_trying_to_omit_mandatory_field( def test_trying_to_omit_mandatory_field(
user_factory, post_factory, context_factory): user_factory, post_factory, context_factory):
post = post_factory() post = post_factory()
@ -112,13 +119,14 @@ def test_trying_to_omit_mandatory_field(
context_factory(params={}, user=user_factory()), context_factory(params={}, user=user_factory()),
{'post_id': post.post_id}) {'post_id': post.post_id})
def test_trying_to_update_non_existing(
user_factory, post_factory, context_factory): def test_trying_to_update_non_existing(user_factory, context_factory):
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
api.post_api.set_post_score( api.post_api.set_post_score(
context_factory(params={'score': 1}, user=user_factory()), context_factory(params={'score': 1}, user=user_factory()),
{'post_id': 5}) {'post_id': 5})
def test_trying_to_rate_without_privileges( def test_trying_to_rate_without_privileges(
user_factory, post_factory, context_factory): user_factory, post_factory, context_factory):
post = post_factory() post = post_factory()

View file

@ -1,11 +1,12 @@
import pytest
import unittest.mock
from datetime import datetime from datetime import datetime
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import posts from szurubooru.func import posts
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(tmpdir, config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'posts:list': db.User.RANK_REGULAR, 'posts:list': db.User.RANK_REGULAR,
@ -13,11 +14,12 @@ def inject_config(tmpdir, config_injector):
}, },
}) })
def test_retrieving_multiple(user_factory, post_factory, context_factory): def test_retrieving_multiple(user_factory, post_factory, context_factory):
post1 = post_factory(id=1) post1 = post_factory(id=1)
post2 = post_factory(id=2) post2 = post_factory(id=2)
db.session.add_all([post1, post2]) db.session.add_all([post1, post2])
with unittest.mock.patch('szurubooru.func.posts.serialize_post'): with patch('szurubooru.func.posts.serialize_post'):
posts.serialize_post.return_value = 'serialized post' posts.serialize_post.return_value = 'serialized post'
result = api.post_api.get_posts( result = api.post_api.get_posts(
context_factory( context_factory(
@ -31,6 +33,7 @@ def test_retrieving_multiple(user_factory, post_factory, context_factory):
'results': ['serialized post', 'serialized post'], 'results': ['serialized post', 'serialized post'],
} }
def test_using_special_tokens(user_factory, post_factory, context_factory): def test_using_special_tokens(user_factory, post_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=db.User.RANK_REGULAR)
post1 = post_factory(id=1) post1 = post_factory(id=1)
@ -39,7 +42,7 @@ def test_using_special_tokens(user_factory, post_factory, context_factory):
user=auth_user, time=datetime.utcnow())] user=auth_user, time=datetime.utcnow())]
db.session.add_all([post1, post2, auth_user]) db.session.add_all([post1, post2, auth_user])
db.session.flush() db.session.flush()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'): with patch('szurubooru.func.posts.serialize_post'):
posts.serialize_post.side_effect = \ posts.serialize_post.side_effect = \
lambda post, *_args, **_kwargs: \ lambda post, *_args, **_kwargs: \
'serialized post %d' % post.post_id 'serialized post %d' % post.post_id
@ -55,8 +58,9 @@ def test_using_special_tokens(user_factory, post_factory, context_factory):
'results': ['serialized post 1'], 'results': ['serialized post 1'],
} }
def test_trying_to_use_special_tokens_without_logging_in( def test_trying_to_use_special_tokens_without_logging_in(
user_factory, post_factory, context_factory, config_injector): user_factory, context_factory, config_injector):
config_injector({ config_injector({
'privileges': {'posts:list': 'anonymous'}, 'privileges': {'posts:list': 'anonymous'},
}) })
@ -66,6 +70,7 @@ def test_trying_to_use_special_tokens_without_logging_in(
params={'query': 'special:fav', 'page': 1}, params={'query': 'special:fav', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_trying_to_retrieve_multiple_without_privileges( def test_trying_to_retrieve_multiple_without_privileges(
user_factory, context_factory): user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
@ -74,21 +79,24 @@ def test_trying_to_retrieve_multiple_without_privileges(
params={'query': '', 'page': 1}, params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_retrieving_single(user_factory, post_factory, context_factory): def test_retrieving_single(user_factory, post_factory, context_factory):
db.session.add(post_factory(id=1)) db.session.add(post_factory(id=1))
with unittest.mock.patch('szurubooru.func.posts.serialize_post'): with patch('szurubooru.func.posts.serialize_post'):
posts.serialize_post.return_value = 'serialized post' posts.serialize_post.return_value = 'serialized post'
result = api.post_api.get_post( result = api.post_api.get_post(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
{'post_id': 1}) {'post_id': 1})
assert result == 'serialized post' assert result == 'serialized post'
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
api.post_api.get_post( api.post_api.get_post(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
{'post_id': 999}) {'post_id': 999})
def test_trying_to_retrieve_single_without_privileges( def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory): user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):

View file

@ -1,11 +1,12 @@
import pytest
import unittest.mock
from datetime import datetime from datetime import datetime
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import posts, tags, snapshots, net from szurubooru.func import posts, tags, snapshots, net
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(tmpdir, config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'posts:edit:tags': db.User.RANK_REGULAR, 'posts:edit:tags': db.User.RANK_REGULAR,
@ -20,6 +21,7 @@ def inject_config(tmpdir, config_injector):
}, },
}) })
def test_post_updating( def test_post_updating(
context_factory, post_factory, user_factory, fake_datetime): context_factory, post_factory, user_factory, fake_datetime):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=db.User.RANK_REGULAR)
@ -27,18 +29,18 @@ def test_post_updating(
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
with unittest.mock.patch('szurubooru.func.posts.create_post'), \ with patch('szurubooru.func.posts.create_post'), \
unittest.mock.patch('szurubooru.func.posts.update_post_tags'), \ patch('szurubooru.func.posts.update_post_tags'), \
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ patch('szurubooru.func.posts.update_post_content'), \
unittest.mock.patch('szurubooru.func.posts.update_post_thumbnail'), \ patch('szurubooru.func.posts.update_post_thumbnail'), \
unittest.mock.patch('szurubooru.func.posts.update_post_safety'), \ patch('szurubooru.func.posts.update_post_safety'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'), \ patch('szurubooru.func.posts.update_post_source'), \
unittest.mock.patch('szurubooru.func.posts.update_post_relations'), \ patch('szurubooru.func.posts.update_post_relations'), \
unittest.mock.patch('szurubooru.func.posts.update_post_notes'), \ patch('szurubooru.func.posts.update_post_notes'), \
unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \ patch('szurubooru.func.posts.update_post_flags'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ patch('szurubooru.func.posts.serialize_post'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ patch('szurubooru.func.tags.export_to_json'), \
unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \ patch('szurubooru.func.snapshots.save_entity_modification'), \
fake_datetime('1997-01-01'): fake_datetime('1997-01-01'):
posts.serialize_post.return_value = 'serialized post' posts.serialize_post.return_value = 'serialized post'
@ -64,28 +66,34 @@ def test_post_updating(
posts.create_post.assert_not_called() posts.create_post.assert_not_called()
posts.update_post_tags.assert_called_once_with(post, ['tag1', 'tag2']) posts.update_post_tags.assert_called_once_with(post, ['tag1', 'tag2'])
posts.update_post_content.assert_called_once_with(post, 'post-content') posts.update_post_content.assert_called_once_with(post, 'post-content')
posts.update_post_thumbnail.assert_called_once_with(post, 'post-thumbnail') posts.update_post_thumbnail.assert_called_once_with(
post, 'post-thumbnail')
posts.update_post_safety.assert_called_once_with(post, 'safe') posts.update_post_safety.assert_called_once_with(post, 'safe')
posts.update_post_source.assert_called_once_with(post, 'source') posts.update_post_source.assert_called_once_with(post, 'source')
posts.update_post_relations.assert_called_once_with(post, [1, 2]) posts.update_post_relations.assert_called_once_with(post, [1, 2])
posts.update_post_notes.assert_called_once_with(post, ['note1', 'note2']) posts.update_post_notes.assert_called_once_with(
posts.update_post_flags.assert_called_once_with(post, ['flag1', 'flag2']) post, ['note1', 'note2'])
posts.serialize_post.assert_called_once_with(post, auth_user, options=None) posts.update_post_flags.assert_called_once_with(
post, ['flag1', 'flag2'])
posts.serialize_post.assert_called_once_with(
post, auth_user, options=None)
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
snapshots.save_entity_modification.assert_called_once_with(post, auth_user) snapshots.save_entity_modification.assert_called_once_with(
post, auth_user)
assert post.last_edit_time == datetime(1997, 1, 1) assert post.last_edit_time == datetime(1997, 1, 1)
def test_uploading_from_url_saves_source( def test_uploading_from_url_saves_source(
context_factory, post_factory, user_factory): context_factory, post_factory, user_factory):
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
with unittest.mock.patch('szurubooru.func.net.download'), \ with patch('szurubooru.func.net.download'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ patch('szurubooru.func.tags.export_to_json'), \
unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \ patch('szurubooru.func.snapshots.save_entity_modification'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ patch('szurubooru.func.posts.serialize_post'), \
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ patch('szurubooru.func.posts.update_post_content'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'): patch('szurubooru.func.posts.update_post_source'):
net.download.return_value = b'content' net.download.return_value = b'content'
api.post_api.update_post( api.post_api.update_post(
context_factory( context_factory(
@ -96,17 +104,18 @@ def test_uploading_from_url_saves_source(
posts.update_post_content.assert_called_once_with(post, b'content') posts.update_post_content.assert_called_once_with(post, b'content')
posts.update_post_source.assert_called_once_with(post, 'example.com') posts.update_post_source.assert_called_once_with(post, 'example.com')
def test_uploading_from_url_with_source_specified( def test_uploading_from_url_with_source_specified(
context_factory, post_factory, user_factory): context_factory, post_factory, user_factory):
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
with unittest.mock.patch('szurubooru.func.net.download'), \ with patch('szurubooru.func.net.download'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ patch('szurubooru.func.tags.export_to_json'), \
unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \ patch('szurubooru.func.snapshots.save_entity_modification'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ patch('szurubooru.func.posts.serialize_post'), \
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ patch('szurubooru.func.posts.update_post_content'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'): patch('szurubooru.func.posts.update_post_source'):
net.download.return_value = b'content' net.download.return_value = b'content'
api.post_api.update_post( api.post_api.update_post(
context_factory( context_factory(
@ -120,6 +129,7 @@ def test_uploading_from_url_with_source_specified(
posts.update_post_content.assert_called_once_with(post, b'content') posts.update_post_content.assert_called_once_with(post, b'content')
posts.update_post_source.assert_called_once_with(post, 'example2.com') posts.update_post_source.assert_called_once_with(post, 'example2.com')
def test_trying_to_update_non_existing(context_factory, user_factory): def test_trying_to_update_non_existing(context_factory, user_factory):
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
api.post_api.update_post( api.post_api.update_post(
@ -128,18 +138,19 @@ def test_trying_to_update_non_existing(context_factory, user_factory):
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
{'post_id': 1}) {'post_id': 1})
@pytest.mark.parametrize('privilege,files,params', [
('posts:edit:tags', {}, {'tags': '...'}), @pytest.mark.parametrize('files,params', [
('posts:edit:safety', {}, {'safety': '...'}), ({}, {'tags': '...'}),
('posts:edit:source', {}, {'source': '...'}), ({}, {'safety': '...'}),
('posts:edit:relations', {}, {'relations': '...'}), ({}, {'source': '...'}),
('posts:edit:notes', {}, {'notes': '...'}), ({}, {'relations': '...'}),
('posts:edit:flags', {}, {'flags': '...'}), ({}, {'notes': '...'}),
('posts:edit:content', {'content': '...'}, {}), ({}, {'flags': '...'}),
('posts:edit:thumbnail', {'thumbnail': '...'}, {}), ({'content': '...'}, {}),
({'thumbnail': '...'}, {}),
]) ])
def test_trying_to_update_field_without_privileges( def test_trying_to_update_field_without_privileges(
context_factory, post_factory, user_factory, files, params, privilege): context_factory, post_factory, user_factory, files, params):
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
@ -151,13 +162,14 @@ def test_trying_to_update_field_without_privileges(
user=user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=db.User.RANK_ANONYMOUS)),
{'post_id': post.post_id}) {'post_id': post.post_id})
def test_trying_to_create_tags_without_privileges( def test_trying_to_create_tags_without_privileges(
context_factory, post_factory, user_factory): context_factory, post_factory, user_factory):
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
with pytest.raises(errors.AuthError), \ with pytest.raises(errors.AuthError), \
unittest.mock.patch('szurubooru.func.posts.update_post_tags'): patch('szurubooru.func.posts.update_post_tags'):
posts.update_post_tags.return_value = ['new-tag'] posts.update_post_tags.return_value = ['new-tag']
api.post_api.update_post( api.post_api.update_post(
context_factory( context_factory(

View file

@ -1,7 +1,8 @@
import pytest
from datetime import datetime from datetime import datetime
import pytest
from szurubooru import api, db, errors from szurubooru import api, db, errors
def snapshot_factory(): def snapshot_factory():
snapshot = db.Snapshot() snapshot = db.Snapshot()
snapshot.creation_time = datetime(1999, 1, 1) snapshot.creation_time = datetime(1999, 1, 1)
@ -12,12 +13,14 @@ def snapshot_factory():
snapshot.data = '{}' snapshot.data = '{}'
return snapshot return snapshot
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': {'snapshots:list': db.User.RANK_REGULAR}, 'privileges': {'snapshots:list': db.User.RANK_REGULAR},
}) })
def test_retrieving_multiple(user_factory, context_factory): def test_retrieving_multiple(user_factory, context_factory):
snapshot1 = snapshot_factory() snapshot1 = snapshot_factory()
snapshot2 = snapshot_factory() snapshot2 = snapshot_factory()
@ -32,6 +35,7 @@ def test_retrieving_multiple(user_factory, context_factory):
assert result['total'] == 2 assert result['total'] == 2
assert len(result['results']) == 2 assert len(result['results']) == 2
def test_trying_to_retrieve_multiple_without_privileges( def test_trying_to_retrieve_multiple_without_privileges(
user_factory, context_factory): user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):

View file

@ -1,21 +1,24 @@
from unittest.mock import patch
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import tag_categories, tags from szurubooru.func import tag_categories, tags
def _update_category_name(category, name): def _update_category_name(category, name):
category.name = name category.name = name
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': {'tag_categories:create': db.User.RANK_REGULAR}, 'privileges': {'tag_categories:create': db.User.RANK_REGULAR},
}) })
def test_creating_category(user_factory, context_factory): def test_creating_category(user_factory, context_factory):
with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \ with patch('szurubooru.func.tag_categories.serialize_category'), \
unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \ patch('szurubooru.func.tag_categories.update_category_name'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'): patch('szurubooru.func.tags.export_to_json'):
tag_categories.update_category_name.side_effect = _update_category_name tag_categories.update_category_name.side_effect = _update_category_name
tag_categories.serialize_category.return_value = 'serialized category' tag_categories.serialize_category.return_value = 'serialized category'
result = api.tag_category_api.create_tag_category( result = api.tag_category_api.create_tag_category(
@ -29,6 +32,7 @@ def test_creating_category(user_factory, context_factory):
assert category.tag_count == 0 assert category.tag_count == 0
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
@pytest.mark.parametrize('field', ['name', 'color']) @pytest.mark.parametrize('field', ['name', 'color'])
def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
params = { params = {
@ -42,6 +46,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
params=params, params=params,
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_create_without_privileges(user_factory, context_factory): def test_trying_to_create_without_privileges(user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.tag_category_api.create_tag_category( api.tag_category_api.create_tag_category(

View file

@ -1,19 +1,21 @@
from unittest.mock import patch
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import tag_categories, tags from szurubooru.func import tag_categories, tags
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': {'tag_categories:delete': db.User.RANK_REGULAR}, 'privileges': {'tag_categories:delete': db.User.RANK_REGULAR},
}) })
def test_deleting(user_factory, tag_category_factory, context_factory): def test_deleting(user_factory, tag_category_factory, context_factory):
db.session.add(tag_category_factory(name='root')) db.session.add(tag_category_factory(name='root'))
db.session.add(tag_category_factory(name='category')) db.session.add(tag_category_factory(name='category'))
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.tags.export_to_json'): with patch('szurubooru.func.tags.export_to_json'):
result = api.tag_category_api.delete_tag_category( result = api.tag_category_api.delete_tag_category(
context_factory( context_factory(
params={'version': 1}, params={'version': 1},
@ -24,6 +26,7 @@ def test_deleting(user_factory, tag_category_factory, context_factory):
assert db.session.query(db.TagCategory).one().name == 'root' assert db.session.query(db.TagCategory).one().name == 'root'
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
def test_trying_to_delete_used( def test_trying_to_delete_used(
user_factory, tag_category_factory, tag_factory, context_factory): user_factory, tag_category_factory, tag_factory, context_factory):
category = tag_category_factory(name='category') category = tag_category_factory(name='category')
@ -40,6 +43,7 @@ def test_trying_to_delete_used(
{'category_name': 'category'}) {'category_name': 'category'})
assert db.session.query(db.TagCategory).count() == 1 assert db.session.query(db.TagCategory).count() == 1
def test_trying_to_delete_last( def test_trying_to_delete_last(
user_factory, tag_category_factory, context_factory): user_factory, tag_category_factory, context_factory):
db.session.add(tag_category_factory(name='root')) db.session.add(tag_category_factory(name='root'))
@ -51,12 +55,14 @@ def test_trying_to_delete_last(
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': 'root'}) {'category_name': 'root'})
def test_trying_to_delete_non_existing(user_factory, context_factory): def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(tag_categories.TagCategoryNotFoundError): with pytest.raises(tag_categories.TagCategoryNotFoundError):
api.tag_category_api.delete_tag_category( api.tag_category_api.delete_tag_category(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': 'bad'}) {'category_name': 'bad'})
def test_trying_to_delete_without_privileges( def test_trying_to_delete_without_privileges(
user_factory, tag_category_factory, context_factory): user_factory, tag_category_factory, context_factory):
db.session.add(tag_category_factory(name='category')) db.session.add(tag_category_factory(name='category'))

View file

@ -2,6 +2,7 @@ import pytest
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import tag_categories from szurubooru.func import tag_categories
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
@ -11,6 +12,7 @@ def inject_config(config_injector):
}, },
}) })
def test_retrieving_multiple( def test_retrieving_multiple(
user_factory, tag_category_factory, context_factory): user_factory, tag_category_factory, context_factory):
db.session.add_all([ db.session.add_all([
@ -21,7 +23,9 @@ def test_retrieving_multiple(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR))) context_factory(user=user_factory(rank=db.User.RANK_REGULAR)))
assert [cat['name'] for cat in result['results']] == ['c1', 'c2'] assert [cat['name'] for cat in result['results']] == ['c1', 'c2']
def test_retrieving_single(user_factory, tag_category_factory, context_factory):
def test_retrieving_single(
user_factory, tag_category_factory, context_factory):
db.session.add(tag_category_factory(name='cat')) db.session.add(tag_category_factory(name='cat'))
result = api.tag_category_api.get_tag_category( result = api.tag_category_api.get_tag_category(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
@ -35,12 +39,14 @@ def test_retrieving_single(user_factory, tag_category_factory, context_factory):
'version': 1, 'version': 1,
} }
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(tag_categories.TagCategoryNotFoundError): with pytest.raises(tag_categories.TagCategoryNotFoundError):
api.tag_category_api.get_tag_category( api.tag_category_api.get_tag_category(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': '-'}) {'category_name': '-'})
def test_trying_to_retrieve_single_without_privileges( def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory): user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):

View file

@ -1,11 +1,13 @@
from unittest.mock import patch
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import tag_categories, tags from szurubooru.func import tag_categories, tags
def _update_category_name(category, name): def _update_category_name(category, name):
category.name = name category.name = name
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
@ -16,14 +18,15 @@ def inject_config(config_injector):
}, },
}) })
def test_simple_updating(user_factory, tag_category_factory, context_factory): def test_simple_updating(user_factory, tag_category_factory, context_factory):
category = tag_category_factory(name='name', color='black') category = tag_category_factory(name='name', color='black')
db.session.add(category) db.session.add(category)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \ with patch('szurubooru.func.tag_categories.serialize_category'), \
unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \ patch('szurubooru.func.tag_categories.update_category_name'), \
unittest.mock.patch('szurubooru.func.tag_categories.update_category_color'), \ patch('szurubooru.func.tag_categories.update_category_color'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'): patch('szurubooru.func.tags.export_to_json'):
tag_categories.update_category_name.side_effect = _update_category_name tag_categories.update_category_name.side_effect = _update_category_name
tag_categories.serialize_category.return_value = 'serialized category' tag_categories.serialize_category.return_value = 'serialized category'
result = api.tag_category_api.update_tag_category( result = api.tag_category_api.update_tag_category(
@ -36,10 +39,13 @@ def test_simple_updating(user_factory, tag_category_factory, context_factory):
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': 'name'}) {'category_name': 'name'})
assert result == 'serialized category' assert result == 'serialized category'
tag_categories.update_category_name.assert_called_once_with(category, 'changed') tag_categories.update_category_name.assert_called_once_with(
tag_categories.update_category_color.assert_called_once_with(category, 'white') category, 'changed')
tag_categories.update_category_color.assert_called_once_with(
category, 'white')
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
@pytest.mark.parametrize('field', ['name', 'color']) @pytest.mark.parametrize('field', ['name', 'color'])
def test_omitting_optional_field( def test_omitting_optional_field(
user_factory, tag_category_factory, context_factory, field): user_factory, tag_category_factory, context_factory, field):
@ -50,15 +56,16 @@ def test_omitting_optional_field(
'color': 'white', 'color': 'white',
} }
del params[field] del params[field]
with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \ with patch('szurubooru.func.tag_categories.serialize_category'), \
unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \ patch('szurubooru.func.tag_categories.update_category_name'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'): patch('szurubooru.func.tags.export_to_json'):
api.tag_category_api.update_tag_category( api.tag_category_api.update_tag_category(
context_factory( context_factory(
params={**params, **{'version': 1}}, params={**params, **{'version': 1}},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': 'name'}) {'category_name': 'name'})
def test_trying_to_update_non_existing(user_factory, context_factory): def test_trying_to_update_non_existing(user_factory, context_factory):
with pytest.raises(tag_categories.TagCategoryNotFoundError): with pytest.raises(tag_categories.TagCategoryNotFoundError):
api.tag_category_api.update_tag_category( api.tag_category_api.update_tag_category(
@ -67,6 +74,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': 'bad'}) {'category_name': 'bad'})
@pytest.mark.parametrize('params', [ @pytest.mark.parametrize('params', [
{'name': 'whatever'}, {'name': 'whatever'},
{'color': 'whatever'}, {'color': 'whatever'},
@ -82,13 +90,14 @@ def test_trying_to_update_without_privileges(
user=user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=db.User.RANK_ANONYMOUS)),
{'category_name': 'dummy'}) {'category_name': 'dummy'})
def test_set_as_default(user_factory, tag_category_factory, context_factory): def test_set_as_default(user_factory, tag_category_factory, context_factory):
category = tag_category_factory(name='name', color='black') category = tag_category_factory(name='name', color='black')
db.session.add(category) db.session.add(category)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \ with patch('szurubooru.func.tag_categories.serialize_category'), \
unittest.mock.patch('szurubooru.func.tag_categories.set_default_category'), \ patch('szurubooru.func.tag_categories.set_default_category'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'): patch('szurubooru.func.tags.export_to_json'):
tag_categories.update_category_name.side_effect = _update_category_name tag_categories.update_category_name.side_effect = _update_category_name
tag_categories.serialize_category.return_value = 'serialized category' tag_categories.serialize_category.return_value = 'serialized category'
result = api.tag_category_api.set_tag_category_as_default( result = api.tag_category_api.set_tag_category_as_default(

View file

@ -1,17 +1,19 @@
from unittest.mock import patch
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import tags, tag_categories from szurubooru.func import tags
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'tags:create': db.User.RANK_REGULAR}}) config_injector({'privileges': {'tags:create': db.User.RANK_REGULAR}})
def test_creating_simple_tags(tag_factory, user_factory, context_factory): def test_creating_simple_tags(tag_factory, user_factory, context_factory):
with unittest.mock.patch('szurubooru.func.tags.create_tag'), \ with patch('szurubooru.func.tags.create_tag'), \
unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'), \ patch('szurubooru.func.tags.get_or_create_tags_by_names'), \
unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ patch('szurubooru.func.tags.serialize_tag'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'): patch('szurubooru.func.tags.export_to_json'):
tags.get_or_create_tags_by_names.return_value = ([], []) tags.get_or_create_tags_by_names.return_value = ([], [])
tags.create_tag.return_value = tag_factory() tags.create_tag.return_value = tag_factory()
tags.serialize_tag.return_value = 'serialized tag' tags.serialize_tag.return_value = 'serialized tag'
@ -30,6 +32,7 @@ def test_creating_simple_tags(tag_factory, user_factory, context_factory):
['tag1', 'tag2'], 'meta', ['sug1', 'sug2'], ['imp1', 'imp2']) ['tag1', 'tag2'], 'meta', ['sug1', 'sug2'], ['imp1', 'imp2'])
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
@pytest.mark.parametrize('field', ['names', 'category']) @pytest.mark.parametrize('field', ['names', 'category'])
def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
params = { params = {
@ -45,6 +48,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
params=params, params=params,
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
@pytest.mark.parametrize('field', ['implications', 'suggestions']) @pytest.mark.parametrize('field', ['implications', 'suggestions'])
def test_omitting_optional_field( def test_omitting_optional_field(
tag_factory, user_factory, context_factory, field): tag_factory, user_factory, context_factory, field):
@ -55,16 +59,18 @@ def test_omitting_optional_field(
'implications': [], 'implications': [],
} }
del params[field] del params[field]
with unittest.mock.patch('szurubooru.func.tags.create_tag'), \ with patch('szurubooru.func.tags.create_tag'), \
unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ patch('szurubooru.func.tags.serialize_tag'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'): patch('szurubooru.func.tags.export_to_json'):
tags.create_tag.return_value = tag_factory() tags.create_tag.return_value = tag_factory()
api.tag_api.create_tag( api.tag_api.create_tag(
context_factory( context_factory(
params=params, params=params,
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_create_tag_without_privileges(user_factory, context_factory):
def test_trying_to_create_tag_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.tag_api.create_tag( api.tag_api.create_tag(
context_factory( context_factory(

View file

@ -1,16 +1,18 @@
from unittest.mock import patch
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import tags from szurubooru.func import tags
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'tags:delete': db.User.RANK_REGULAR}}) config_injector({'privileges': {'tags:delete': db.User.RANK_REGULAR}})
def test_deleting(user_factory, tag_factory, context_factory): def test_deleting(user_factory, tag_factory, context_factory):
db.session.add(tag_factory(names=['tag'])) db.session.add(tag_factory(names=['tag']))
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.tags.export_to_json'): with patch('szurubooru.func.tags.export_to_json'):
result = api.tag_api.delete_tag( result = api.tag_api.delete_tag(
context_factory( context_factory(
params={'version': 1}, params={'version': 1},
@ -20,13 +22,15 @@ def test_deleting(user_factory, tag_factory, context_factory):
assert db.session.query(db.Tag).count() == 0 assert db.session.query(db.Tag).count() == 0
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
def test_deleting_used(user_factory, tag_factory, context_factory, post_factory):
def test_deleting_used(
user_factory, tag_factory, context_factory, post_factory):
tag = tag_factory(names=['tag']) tag = tag_factory(names=['tag'])
post = post_factory() post = post_factory()
post.tags.append(tag) post.tags.append(tag)
db.session.add_all([tag, post]) db.session.add_all([tag, post])
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.tags.export_to_json'): with patch('szurubooru.func.tags.export_to_json'):
api.tag_api.delete_tag( api.tag_api.delete_tag(
context_factory( context_factory(
params={'version': 1}, params={'version': 1},
@ -36,12 +40,14 @@ def test_deleting_used(user_factory, tag_factory, context_factory, post_factory)
assert db.session.query(db.Tag).count() == 0 assert db.session.query(db.Tag).count() == 0
assert post.tags == [] assert post.tags == []
def test_trying_to_delete_non_existing(user_factory, context_factory): def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(tags.TagNotFoundError): with pytest.raises(tags.TagNotFoundError):
api.tag_api.delete_tag( api.tag_api.delete_tag(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'bad'}) {'tag_name': 'bad'})
def test_trying_to_delete_without_privileges( def test_trying_to_delete_without_privileges(
user_factory, tag_factory, context_factory): user_factory, tag_factory, context_factory):
db.session.add(tag_factory(names=['tag'])) db.session.add(tag_factory(names=['tag']))

View file

@ -1,12 +1,14 @@
from unittest.mock import patch
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import tags from szurubooru.func import tags
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'tags:merge': db.User.RANK_REGULAR}}) config_injector({'privileges': {'tags:merge': db.User.RANK_REGULAR}})
def test_merging(user_factory, tag_factory, context_factory, post_factory): def test_merging(user_factory, tag_factory, context_factory, post_factory):
source_tag = tag_factory(names=['source']) source_tag = tag_factory(names=['source'])
target_tag = tag_factory(names=['target']) target_tag = tag_factory(names=['target'])
@ -20,10 +22,10 @@ def test_merging(user_factory, tag_factory, context_factory, post_factory):
db.session.commit() db.session.commit()
assert source_tag.post_count == 1 assert source_tag.post_count == 1
assert target_tag.post_count == 0 assert target_tag.post_count == 0
with unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ with patch('szurubooru.func.tags.serialize_tag'), \
unittest.mock.patch('szurubooru.func.tags.merge_tags'), \ patch('szurubooru.func.tags.merge_tags'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'): patch('szurubooru.func.tags.export_to_json'):
result = api.tag_api.merge_tags( api.tag_api.merge_tags(
context_factory( context_factory(
params={ params={
'removeVersion': 1, 'removeVersion': 1,
@ -35,6 +37,7 @@ def test_merging(user_factory, tag_factory, context_factory, post_factory):
tags.merge_tags.called_once_with(source_tag, target_tag) tags.merge_tags.called_once_with(source_tag, target_tag)
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
@pytest.mark.parametrize( @pytest.mark.parametrize(
'field', ['remove', 'mergeTo', 'removeVersion', 'mergeToVersion']) 'field', ['remove', 'mergeTo', 'removeVersion', 'mergeToVersion'])
def test_trying_to_omit_mandatory_field( def test_trying_to_omit_mandatory_field(
@ -57,6 +60,7 @@ def test_trying_to_omit_mandatory_field(
params=params, params=params,
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_merge_non_existing( def test_trying_to_merge_non_existing(
user_factory, tag_factory, context_factory): user_factory, tag_factory, context_factory):
db.session.add(tag_factory(names=['good'])) db.session.add(tag_factory(names=['good']))
@ -72,14 +76,9 @@ def test_trying_to_merge_non_existing(
params={'remove': 'bad', 'mergeTo': 'good'}, params={'remove': 'bad', 'mergeTo': 'good'},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
@pytest.mark.parametrize('params', [
{'names': 'whatever'},
{'category': 'whatever'},
{'suggestions': ['whatever']},
{'implications': ['whatever']},
])
def test_trying_to_merge_without_privileges( def test_trying_to_merge_without_privileges(
user_factory, tag_factory, context_factory, params): user_factory, tag_factory, context_factory):
db.session.add_all([ db.session.add_all([
tag_factory(names=['source']), tag_factory(names=['source']),
tag_factory(names=['target']), tag_factory(names=['target']),

View file

@ -1,8 +1,9 @@
from unittest.mock import patch
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import tags from szurubooru.func import tags
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
@ -12,11 +13,12 @@ def inject_config(config_injector):
}, },
}) })
def test_retrieving_multiple(user_factory, tag_factory, context_factory): def test_retrieving_multiple(user_factory, tag_factory, context_factory):
tag1 = tag_factory(names=['t1']) tag1 = tag_factory(names=['t1'])
tag2 = tag_factory(names=['t2']) tag2 = tag_factory(names=['t2'])
db.session.add_all([tag1, tag2]) db.session.add_all([tag1, tag2])
with unittest.mock.patch('szurubooru.func.tags.serialize_tag'): with patch('szurubooru.func.tags.serialize_tag'):
tags.serialize_tag.return_value = 'serialized tag' tags.serialize_tag.return_value = 'serialized tag'
result = api.tag_api.get_tags( result = api.tag_api.get_tags(
context_factory( context_factory(
@ -30,6 +32,7 @@ def test_retrieving_multiple(user_factory, tag_factory, context_factory):
'results': ['serialized tag', 'serialized tag'], 'results': ['serialized tag', 'serialized tag'],
} }
def test_trying_to_retrieve_multiple_without_privileges( def test_trying_to_retrieve_multiple_without_privileges(
user_factory, context_factory): user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
@ -38,9 +41,10 @@ def test_trying_to_retrieve_multiple_without_privileges(
params={'query': '', 'page': 1}, params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_retrieving_single(user_factory, tag_factory, context_factory): def test_retrieving_single(user_factory, tag_factory, context_factory):
db.session.add(tag_factory(names=['tag'])) db.session.add(tag_factory(names=['tag']))
with unittest.mock.patch('szurubooru.func.tags.serialize_tag'): with patch('szurubooru.func.tags.serialize_tag'):
tags.serialize_tag.return_value = 'serialized tag' tags.serialize_tag.return_value = 'serialized tag'
result = api.tag_api.get_tag( result = api.tag_api.get_tag(
context_factory( context_factory(
@ -48,6 +52,7 @@ def test_retrieving_single(user_factory, tag_factory, context_factory):
{'tag_name': 'tag'}) {'tag_name': 'tag'})
assert result == 'serialized tag' assert result == 'serialized tag'
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(tags.TagNotFoundError): with pytest.raises(tags.TagNotFoundError):
api.tag_api.get_tag( api.tag_api.get_tag(
@ -55,6 +60,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': '-'}) {'tag_name': '-'})
def test_trying_to_retrieve_single_without_privileges( def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory): user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):

View file

@ -1,16 +1,18 @@
from unittest.mock import patch
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import tags from szurubooru.func import tags
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'tags:view': db.User.RANK_REGULAR}}) config_injector({'privileges': {'tags:view': db.User.RANK_REGULAR}})
def test_get_tag_siblings(user_factory, tag_factory, post_factory, context_factory):
def test_get_tag_siblings(user_factory, tag_factory, context_factory):
db.session.add(tag_factory(names=['tag'])) db.session.add(tag_factory(names=['tag']))
with unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ with patch('szurubooru.func.tags.serialize_tag'), \
unittest.mock.patch('szurubooru.func.tags.get_tag_siblings'): patch('szurubooru.func.tags.get_tag_siblings'):
tags.serialize_tag.side_effect = \ tags.serialize_tag.side_effect = \
lambda tag, *args, **kwargs: \ lambda tag, *args, **kwargs: \
'serialized tag %s' % tag.names[0].name 'serialized tag %s' % tag.names[0].name
@ -34,12 +36,14 @@ def test_get_tag_siblings(user_factory, tag_factory, post_factory, context_facto
], ],
} }
def test_trying_to_retrieve_non_existing(user_factory, context_factory): def test_trying_to_retrieve_non_existing(user_factory, context_factory):
with pytest.raises(tags.TagNotFoundError): with pytest.raises(tags.TagNotFoundError):
api.tag_api.get_tag_siblings( api.tag_api.get_tag_siblings(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': '-'}) {'tag_name': '-'})
def test_trying_to_retrieve_without_privileges(user_factory, context_factory): def test_trying_to_retrieve_without_privileges(user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.tag_api.get_tag_siblings( api.tag_api.get_tag_siblings(

View file

@ -1,8 +1,9 @@
from unittest.mock import patch
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import tags from szurubooru.func import tags
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
@ -16,20 +17,21 @@ def inject_config(config_injector):
}, },
}) })
def test_simple_updating(user_factory, tag_factory, context_factory, fake_datetime):
def test_simple_updating(user_factory, tag_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=db.User.RANK_REGULAR)
tag = tag_factory(names=['tag1', 'tag2']) tag = tag_factory(names=['tag1', 'tag2'])
db.session.add(tag) db.session.add(tag)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.tags.create_tag'), \ with patch('szurubooru.func.tags.create_tag'), \
unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'), \ patch('szurubooru.func.tags.get_or_create_tags_by_names'), \
unittest.mock.patch('szurubooru.func.tags.update_tag_names'), \ patch('szurubooru.func.tags.update_tag_names'), \
unittest.mock.patch('szurubooru.func.tags.update_tag_category_name'), \ patch('szurubooru.func.tags.update_tag_category_name'), \
unittest.mock.patch('szurubooru.func.tags.update_tag_description'), \ patch('szurubooru.func.tags.update_tag_description'), \
unittest.mock.patch('szurubooru.func.tags.update_tag_suggestions'), \ patch('szurubooru.func.tags.update_tag_suggestions'), \
unittest.mock.patch('szurubooru.func.tags.update_tag_implications'), \ patch('szurubooru.func.tags.update_tag_implications'), \
unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ patch('szurubooru.func.tags.serialize_tag'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'): patch('szurubooru.func.tags.export_to_json'):
tags.get_or_create_tags_by_names.return_value = ([], []) tags.get_or_create_tags_by_names.return_value = ([], [])
tags.serialize_tag.return_value = 'serialized tag' tags.serialize_tag.return_value = 'serialized tag'
result = api.tag_api.update_tag( result = api.tag_api.update_tag(
@ -49,12 +51,22 @@ def test_simple_updating(user_factory, tag_factory, context_factory, fake_dateti
tags.update_tag_names.assert_called_once_with(tag, ['tag3']) tags.update_tag_names.assert_called_once_with(tag, ['tag3'])
tags.update_tag_category_name.assert_called_once_with(tag, 'character') tags.update_tag_category_name.assert_called_once_with(tag, 'character')
tags.update_tag_description.assert_called_once_with(tag, 'desc') tags.update_tag_description.assert_called_once_with(tag, 'desc')
tags.update_tag_suggestions.assert_called_once_with(tag, ['sug1', 'sug2']) tags.update_tag_suggestions.assert_called_once_with(
tags.update_tag_implications.assert_called_once_with(tag, ['imp1', 'imp2']) tag, ['sug1', 'sug2'])
tags.serialize_tag.assert_called_once_with(tag, options=None) tags.update_tag_implications.assert_called_once_with(
tag, ['imp1', 'imp2'])
tags.serialize_tag.assert_called_once_with(
tag, options=None)
@pytest.mark.parametrize( @pytest.mark.parametrize(
'field', ['names', 'category', 'description', 'implications', 'suggestions']) 'field', [
'names',
'category',
'description',
'implications',
'suggestions',
])
def test_omitting_optional_field( def test_omitting_optional_field(
user_factory, tag_factory, context_factory, field): user_factory, tag_factory, context_factory, field):
db.session.add(tag_factory(names=['tag'])) db.session.add(tag_factory(names=['tag']))
@ -67,17 +79,18 @@ def test_omitting_optional_field(
'implications': [], 'implications': [],
} }
del params[field] del params[field]
with unittest.mock.patch('szurubooru.func.tags.create_tag'), \ with patch('szurubooru.func.tags.create_tag'), \
unittest.mock.patch('szurubooru.func.tags.update_tag_names'), \ patch('szurubooru.func.tags.update_tag_names'), \
unittest.mock.patch('szurubooru.func.tags.update_tag_category_name'), \ patch('szurubooru.func.tags.update_tag_category_name'), \
unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \ patch('szurubooru.func.tags.serialize_tag'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'): patch('szurubooru.func.tags.export_to_json'):
api.tag_api.update_tag( api.tag_api.update_tag(
context_factory( context_factory(
params={**params, **{'version': 1}}, params={**params, **{'version': 1}},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'tag'}) {'tag_name': 'tag'})
def test_trying_to_update_non_existing(user_factory, context_factory): def test_trying_to_update_non_existing(user_factory, context_factory):
with pytest.raises(tags.TagNotFoundError): with pytest.raises(tags.TagNotFoundError):
api.tag_api.update_tag( api.tag_api.update_tag(
@ -86,6 +99,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'tag1'}) {'tag_name': 'tag1'})
@pytest.mark.parametrize('params', [ @pytest.mark.parametrize('params', [
{'names': 'whatever'}, {'names': 'whatever'},
{'category': 'whatever'}, {'category': 'whatever'},
@ -103,6 +117,7 @@ def test_trying_to_update_without_privileges(
user=user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=db.User.RANK_ANONYMOUS)),
{'tag_name': 'tag'}) {'tag_name': 'tag'})
def test_trying_to_create_tags_without_privileges( def test_trying_to_create_tags_without_privileges(
config_injector, context_factory, tag_factory, user_factory): config_injector, context_factory, tag_factory, user_factory):
tag = tag_factory(names=['tag']) tag = tag_factory(names=['tag'])
@ -113,7 +128,7 @@ def test_trying_to_create_tags_without_privileges(
'tags:edit:suggestions': db.User.RANK_REGULAR, 'tags:edit:suggestions': db.User.RANK_REGULAR,
'tags:edit:implications': db.User.RANK_REGULAR, 'tags:edit:implications': db.User.RANK_REGULAR,
}}) }})
with unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'): with patch('szurubooru.func.tags.get_or_create_tags_by_names'):
tags.get_or_create_tags_by_names.return_value = ([], ['new-tag']) tags.get_or_create_tags_by_names.return_value = ([], ['new-tag'])
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.tag_api.update_tag( api.tag_api.update_tag(

Some files were not shown because too many files have changed in this diff Show more