server/general: embrace most of PEP8
Ignored only the rules about continuing / hanging indentation. Also, added __init__.py to tests so that pylint discovers them. (I don't buy pytest's BS about installing your package.)
This commit is contained in:
parent
af62f8c45a
commit
9aea55e3d1
129 changed files with 2251 additions and 1077 deletions
|
@ -8,11 +8,26 @@ good-names=ex,_,logger
|
||||||
dummy-variables-rgx=_|dummy
|
dummy-variables-rgx=_|dummy
|
||||||
|
|
||||||
[format]
|
[format]
|
||||||
max-line-length=90
|
max-line-length=79
|
||||||
|
|
||||||
[messages control]
|
[messages control]
|
||||||
disable=missing-docstring,no-self-use,too-few-public-methods,multiple-statements
|
|
||||||
reports=no
|
reports=no
|
||||||
|
disable=
|
||||||
|
# we're not java
|
||||||
|
missing-docstring,
|
||||||
|
|
||||||
|
# covered better by pycodestyle
|
||||||
|
bad-continuation,
|
||||||
|
|
||||||
|
# we're adults
|
||||||
|
redefined-builtin,
|
||||||
|
duplicate-code,
|
||||||
|
too-many-return-statements,
|
||||||
|
too-many-arguments,
|
||||||
|
|
||||||
|
# plain stupid
|
||||||
|
no-self-use,
|
||||||
|
too-few-public-methods
|
||||||
|
|
||||||
[typecheck]
|
[typecheck]
|
||||||
generated-members=add|add_all
|
generated-members=add|add_all
|
||||||
|
|
|
@ -3,20 +3,24 @@ from szurubooru import search
|
||||||
from szurubooru.func import auth, comments, posts, scores, util
|
from szurubooru.func import auth, comments, posts, scores, util
|
||||||
from szurubooru.rest import routes
|
from szurubooru.rest import routes
|
||||||
|
|
||||||
|
|
||||||
_search_executor = search.Executor(search.configs.CommentSearchConfig())
|
_search_executor = search.Executor(search.configs.CommentSearchConfig())
|
||||||
|
|
||||||
|
|
||||||
def _serialize(ctx, comment, **kwargs):
|
def _serialize(ctx, comment, **kwargs):
|
||||||
return comments.serialize_comment(
|
return comments.serialize_comment(
|
||||||
comment,
|
comment,
|
||||||
ctx.user,
|
ctx.user,
|
||||||
options=util.get_serialization_options(ctx), **kwargs)
|
options=util.get_serialization_options(ctx), **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/comments/?')
|
@routes.get('/comments/?')
|
||||||
def get_comments(ctx, _params=None):
|
def get_comments(ctx, _params=None):
|
||||||
auth.verify_privilege(ctx.user, 'comments:list')
|
auth.verify_privilege(ctx.user, 'comments:list')
|
||||||
return _search_executor.execute_and_serialize(
|
return _search_executor.execute_and_serialize(
|
||||||
ctx, lambda comment: _serialize(ctx, comment))
|
ctx, lambda comment: _serialize(ctx, comment))
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/comments/?')
|
@routes.post('/comments/?')
|
||||||
def create_comment(ctx, _params=None):
|
def create_comment(ctx, _params=None):
|
||||||
auth.verify_privilege(ctx.user, 'comments:create')
|
auth.verify_privilege(ctx.user, 'comments:create')
|
||||||
|
@ -28,12 +32,14 @@ def create_comment(ctx, _params=None):
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize(ctx, comment)
|
return _serialize(ctx, comment)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/comment/(?P<comment_id>[^/]+)/?')
|
@routes.get('/comment/(?P<comment_id>[^/]+)/?')
|
||||||
def get_comment(ctx, params):
|
def get_comment(ctx, params):
|
||||||
auth.verify_privilege(ctx.user, 'comments:view')
|
auth.verify_privilege(ctx.user, 'comments:view')
|
||||||
comment = comments.get_comment_by_id(params['comment_id'])
|
comment = comments.get_comment_by_id(params['comment_id'])
|
||||||
return _serialize(ctx, comment)
|
return _serialize(ctx, comment)
|
||||||
|
|
||||||
|
|
||||||
@routes.put('/comment/(?P<comment_id>[^/]+)/?')
|
@routes.put('/comment/(?P<comment_id>[^/]+)/?')
|
||||||
def update_comment(ctx, params):
|
def update_comment(ctx, params):
|
||||||
comment = comments.get_comment_by_id(params['comment_id'])
|
comment = comments.get_comment_by_id(params['comment_id'])
|
||||||
|
@ -47,6 +53,7 @@ def update_comment(ctx, params):
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize(ctx, comment)
|
return _serialize(ctx, comment)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete('/comment/(?P<comment_id>[^/]+)/?')
|
@routes.delete('/comment/(?P<comment_id>[^/]+)/?')
|
||||||
def delete_comment(ctx, params):
|
def delete_comment(ctx, params):
|
||||||
comment = comments.get_comment_by_id(params['comment_id'])
|
comment = comments.get_comment_by_id(params['comment_id'])
|
||||||
|
@ -57,6 +64,7 @@ def delete_comment(ctx, params):
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@routes.put('/comment/(?P<comment_id>[^/]+)/score/?')
|
@routes.put('/comment/(?P<comment_id>[^/]+)/score/?')
|
||||||
def set_comment_score(ctx, params):
|
def set_comment_score(ctx, params):
|
||||||
auth.verify_privilege(ctx.user, 'comments:score')
|
auth.verify_privilege(ctx.user, 'comments:score')
|
||||||
|
@ -66,6 +74,7 @@ def set_comment_score(ctx, params):
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize(ctx, comment)
|
return _serialize(ctx, comment)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete('/comment/(?P<comment_id>[^/]+)/score/?')
|
@routes.delete('/comment/(?P<comment_id>[^/]+)/score/?')
|
||||||
def delete_comment_score(ctx, params):
|
def delete_comment_score(ctx, params):
|
||||||
auth.verify_privilege(ctx.user, 'comments:score')
|
auth.verify_privilege(ctx.user, 'comments:score')
|
||||||
|
|
|
@ -4,11 +4,13 @@ from szurubooru import config
|
||||||
from szurubooru.func import posts, users, util
|
from szurubooru.func import posts, users, util
|
||||||
from szurubooru.rest import routes
|
from szurubooru.rest import routes
|
||||||
|
|
||||||
|
|
||||||
_cache_time = None
|
_cache_time = None
|
||||||
_cache_result = None
|
_cache_result = None
|
||||||
|
|
||||||
|
|
||||||
def _get_disk_usage():
|
def _get_disk_usage():
|
||||||
global _cache_time, _cache_result # pylint: disable=global-statement
|
global _cache_time, _cache_result # pylint: disable=global-statement
|
||||||
threshold = datetime.timedelta(hours=1)
|
threshold = datetime.timedelta(hours=1)
|
||||||
now = datetime.datetime.utcnow()
|
now = datetime.datetime.utcnow()
|
||||||
if _cache_time and _cache_time > now - threshold:
|
if _cache_time and _cache_time > now - threshold:
|
||||||
|
@ -22,17 +24,20 @@ def _get_disk_usage():
|
||||||
_cache_result = total_size
|
_cache_result = total_size
|
||||||
return total_size
|
return total_size
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/info/?')
|
@routes.get('/info/?')
|
||||||
def get_info(ctx, _params=None):
|
def get_info(ctx, _params=None):
|
||||||
post_feature = posts.try_get_current_post_feature()
|
post_feature = posts.try_get_current_post_feature()
|
||||||
return {
|
return {
|
||||||
'postCount': posts.get_post_count(),
|
'postCount': posts.get_post_count(),
|
||||||
'diskUsage': _get_disk_usage(),
|
'diskUsage': _get_disk_usage(),
|
||||||
'featuredPost': posts.serialize_post(post_feature.post, ctx.user) \
|
'featuredPost':
|
||||||
if post_feature else None,
|
posts.serialize_post(post_feature.post, ctx.user)
|
||||||
|
if post_feature else None,
|
||||||
'featuringTime': post_feature.time if post_feature else None,
|
'featuringTime': post_feature.time if post_feature else None,
|
||||||
'featuringUser': users.serialize_user(post_feature.user, ctx.user) \
|
'featuringUser':
|
||||||
if post_feature else None,
|
users.serialize_user(post_feature.user, ctx.user)
|
||||||
|
if post_feature else None,
|
||||||
'serverTime': datetime.datetime.utcnow(),
|
'serverTime': datetime.datetime.utcnow(),
|
||||||
'config': {
|
'config': {
|
||||||
'userNameRegex': config.config['user_name_regex'],
|
'userNameRegex': config.config['user_name_regex'],
|
||||||
|
@ -40,7 +45,8 @@ def get_info(ctx, _params=None):
|
||||||
'tagNameRegex': config.config['tag_name_regex'],
|
'tagNameRegex': config.config['tag_name_regex'],
|
||||||
'tagCategoryNameRegex': config.config['tag_category_name_regex'],
|
'tagCategoryNameRegex': config.config['tag_category_name_regex'],
|
||||||
'defaultUserRank': config.config['default_rank'],
|
'defaultUserRank': config.config['default_rank'],
|
||||||
'privileges': util.snake_case_to_lower_camel_case_keys(
|
'privileges':
|
||||||
config.config['privileges']),
|
util.snake_case_to_lower_camel_case_keys(
|
||||||
|
config.config['privileges']),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,12 +2,14 @@ from szurubooru import config, errors
|
||||||
from szurubooru.func import auth, mailer, users, util
|
from szurubooru.func import auth, mailer, users, util
|
||||||
from szurubooru.rest import routes
|
from szurubooru.rest import routes
|
||||||
|
|
||||||
|
|
||||||
MAIL_SUBJECT = 'Password reset for {name}'
|
MAIL_SUBJECT = 'Password reset for {name}'
|
||||||
MAIL_BODY = \
|
MAIL_BODY = \
|
||||||
'You (or someone else) requested to reset your password on {name}.\n' \
|
'You (or someone else) requested to reset your password on {name}.\n' \
|
||||||
'If you wish to proceed, click this link: {url}\n' \
|
'If you wish to proceed, click this link: {url}\n' \
|
||||||
'Otherwise, please ignore this email.'
|
'Otherwise, please ignore this email.'
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/password-reset/(?P<user_name>[^/]+)/?')
|
@routes.get('/password-reset/(?P<user_name>[^/]+)/?')
|
||||||
def start_password_reset(_ctx, params):
|
def start_password_reset(_ctx, params):
|
||||||
''' Send a mail with secure token to the correlated user. '''
|
''' Send a mail with secure token to the correlated user. '''
|
||||||
|
@ -27,6 +29,7 @@ def start_password_reset(_ctx, params):
|
||||||
MAIL_BODY.format(name=config.config['name'], url=url))
|
MAIL_BODY.format(name=config.config['name'], url=url))
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/password-reset/(?P<user_name>[^/]+)/?')
|
@routes.post('/password-reset/(?P<user_name>[^/]+)/?')
|
||||||
def finish_password_reset(ctx, params):
|
def finish_password_reset(ctx, params):
|
||||||
''' Verify token from mail, generate a new password and return it. '''
|
''' Verify token from mail, generate a new password and return it. '''
|
||||||
|
|
|
@ -1,16 +1,20 @@
|
||||||
import datetime
|
import datetime
|
||||||
from szurubooru import search
|
from szurubooru import search
|
||||||
from szurubooru.func import auth, tags, posts, snapshots, favorites, scores, util
|
|
||||||
from szurubooru.rest import routes
|
from szurubooru.rest import routes
|
||||||
|
from szurubooru.func import (
|
||||||
|
auth, tags, posts, snapshots, favorites, scores, util)
|
||||||
|
|
||||||
|
|
||||||
_search_executor = search.Executor(search.configs.PostSearchConfig())
|
_search_executor = search.Executor(search.configs.PostSearchConfig())
|
||||||
|
|
||||||
|
|
||||||
def _serialize_post(ctx, post):
|
def _serialize_post(ctx, post):
|
||||||
return posts.serialize_post(
|
return posts.serialize_post(
|
||||||
post,
|
post,
|
||||||
ctx.user,
|
ctx.user,
|
||||||
options=util.get_serialization_options(ctx))
|
options=util.get_serialization_options(ctx))
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/posts/?')
|
@routes.get('/posts/?')
|
||||||
def get_posts(ctx, _params=None):
|
def get_posts(ctx, _params=None):
|
||||||
auth.verify_privilege(ctx.user, 'posts:list')
|
auth.verify_privilege(ctx.user, 'posts:list')
|
||||||
|
@ -18,6 +22,7 @@ def get_posts(ctx, _params=None):
|
||||||
return _search_executor.execute_and_serialize(
|
return _search_executor.execute_and_serialize(
|
||||||
ctx, lambda post: _serialize_post(ctx, post))
|
ctx, lambda post: _serialize_post(ctx, post))
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/posts/?')
|
@routes.post('/posts/?')
|
||||||
def create_post(ctx, _params=None):
|
def create_post(ctx, _params=None):
|
||||||
anonymous = ctx.get_param_as_bool('anonymous', default=False)
|
anonymous = ctx.get_param_as_bool('anonymous', default=False)
|
||||||
|
@ -52,12 +57,14 @@ def create_post(ctx, _params=None):
|
||||||
tags.export_to_json()
|
tags.export_to_json()
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/post/(?P<post_id>[^/]+)/?')
|
@routes.get('/post/(?P<post_id>[^/]+)/?')
|
||||||
def get_post(ctx, params):
|
def get_post(ctx, params):
|
||||||
auth.verify_privilege(ctx.user, 'posts:view')
|
auth.verify_privilege(ctx.user, 'posts:view')
|
||||||
post = posts.get_post_by_id(params['post_id'])
|
post = posts.get_post_by_id(params['post_id'])
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.put('/post/(?P<post_id>[^/]+)/?')
|
@routes.put('/post/(?P<post_id>[^/]+)/?')
|
||||||
def update_post(ctx, params):
|
def update_post(ctx, params):
|
||||||
post = posts.get_post_by_id(params['post_id'])
|
post = posts.get_post_by_id(params['post_id'])
|
||||||
|
@ -98,6 +105,7 @@ def update_post(ctx, params):
|
||||||
tags.export_to_json()
|
tags.export_to_json()
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete('/post/(?P<post_id>[^/]+)/?')
|
@routes.delete('/post/(?P<post_id>[^/]+)/?')
|
||||||
def delete_post(ctx, params):
|
def delete_post(ctx, params):
|
||||||
auth.verify_privilege(ctx.user, 'posts:delete')
|
auth.verify_privilege(ctx.user, 'posts:delete')
|
||||||
|
@ -109,11 +117,13 @@ def delete_post(ctx, params):
|
||||||
tags.export_to_json()
|
tags.export_to_json()
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/featured-post/?')
|
@routes.get('/featured-post/?')
|
||||||
def get_featured_post(ctx, _params=None):
|
def get_featured_post(ctx, _params=None):
|
||||||
post = posts.try_get_featured_post()
|
post = posts.try_get_featured_post()
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/featured-post/?')
|
@routes.post('/featured-post/?')
|
||||||
def set_featured_post(ctx, _params=None):
|
def set_featured_post(ctx, _params=None):
|
||||||
auth.verify_privilege(ctx.user, 'posts:feature')
|
auth.verify_privilege(ctx.user, 'posts:feature')
|
||||||
|
@ -130,6 +140,7 @@ def set_featured_post(ctx, _params=None):
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.put('/post/(?P<post_id>[^/]+)/score/?')
|
@routes.put('/post/(?P<post_id>[^/]+)/score/?')
|
||||||
def set_post_score(ctx, params):
|
def set_post_score(ctx, params):
|
||||||
auth.verify_privilege(ctx.user, 'posts:score')
|
auth.verify_privilege(ctx.user, 'posts:score')
|
||||||
|
@ -139,6 +150,7 @@ def set_post_score(ctx, params):
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete('/post/(?P<post_id>[^/]+)/score/?')
|
@routes.delete('/post/(?P<post_id>[^/]+)/score/?')
|
||||||
def delete_post_score(ctx, params):
|
def delete_post_score(ctx, params):
|
||||||
auth.verify_privilege(ctx.user, 'posts:score')
|
auth.verify_privilege(ctx.user, 'posts:score')
|
||||||
|
@ -147,6 +159,7 @@ def delete_post_score(ctx, params):
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/post/(?P<post_id>[^/]+)/favorite/?')
|
@routes.post('/post/(?P<post_id>[^/]+)/favorite/?')
|
||||||
def add_post_to_favorites(ctx, params):
|
def add_post_to_favorites(ctx, params):
|
||||||
auth.verify_privilege(ctx.user, 'posts:favorite')
|
auth.verify_privilege(ctx.user, 'posts:favorite')
|
||||||
|
@ -155,6 +168,7 @@ def add_post_to_favorites(ctx, params):
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete('/post/(?P<post_id>[^/]+)/favorite/?')
|
@routes.delete('/post/(?P<post_id>[^/]+)/favorite/?')
|
||||||
def delete_post_from_favorites(ctx, params):
|
def delete_post_from_favorites(ctx, params):
|
||||||
auth.verify_privilege(ctx.user, 'posts:favorite')
|
auth.verify_privilege(ctx.user, 'posts:favorite')
|
||||||
|
@ -163,6 +177,7 @@ def delete_post_from_favorites(ctx, params):
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/post/(?P<post_id>[^/]+)/around/?')
|
@routes.get('/post/(?P<post_id>[^/]+)/around/?')
|
||||||
def get_posts_around(ctx, params):
|
def get_posts_around(ctx, params):
|
||||||
auth.verify_privilege(ctx.user, 'posts:list')
|
auth.verify_privilege(ctx.user, 'posts:list')
|
||||||
|
|
|
@ -2,8 +2,9 @@ from szurubooru import search
|
||||||
from szurubooru.func import auth, snapshots
|
from szurubooru.func import auth, snapshots
|
||||||
from szurubooru.rest import routes
|
from szurubooru.rest import routes
|
||||||
|
|
||||||
_search_executor = search.Executor(
|
|
||||||
search.configs.SnapshotSearchConfig())
|
_search_executor = search.Executor(search.configs.SnapshotSearchConfig())
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/snapshots/?')
|
@routes.get('/snapshots/?')
|
||||||
def get_snapshots(ctx, _params=None):
|
def get_snapshots(ctx, _params=None):
|
||||||
|
|
|
@ -3,12 +3,15 @@ from szurubooru import db, search
|
||||||
from szurubooru.func import auth, tags, util, snapshots
|
from szurubooru.func import auth, tags, util, snapshots
|
||||||
from szurubooru.rest import routes
|
from szurubooru.rest import routes
|
||||||
|
|
||||||
|
|
||||||
_search_executor = search.Executor(search.configs.TagSearchConfig())
|
_search_executor = search.Executor(search.configs.TagSearchConfig())
|
||||||
|
|
||||||
|
|
||||||
def _serialize(ctx, tag):
|
def _serialize(ctx, tag):
|
||||||
return tags.serialize_tag(
|
return tags.serialize_tag(
|
||||||
tag, options=util.get_serialization_options(ctx))
|
tag, options=util.get_serialization_options(ctx))
|
||||||
|
|
||||||
|
|
||||||
def _create_if_needed(tag_names, user):
|
def _create_if_needed(tag_names, user):
|
||||||
if not tag_names:
|
if not tag_names:
|
||||||
return
|
return
|
||||||
|
@ -19,12 +22,14 @@ def _create_if_needed(tag_names, user):
|
||||||
for tag in new_tags:
|
for tag in new_tags:
|
||||||
snapshots.save_entity_creation(tag, user)
|
snapshots.save_entity_creation(tag, user)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/tags/?')
|
@routes.get('/tags/?')
|
||||||
def get_tags(ctx, _params=None):
|
def get_tags(ctx, _params=None):
|
||||||
auth.verify_privilege(ctx.user, 'tags:list')
|
auth.verify_privilege(ctx.user, 'tags:list')
|
||||||
return _search_executor.execute_and_serialize(
|
return _search_executor.execute_and_serialize(
|
||||||
ctx, lambda tag: _serialize(ctx, tag))
|
ctx, lambda tag: _serialize(ctx, tag))
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/tags/?')
|
@routes.post('/tags/?')
|
||||||
def create_tag(ctx, _params=None):
|
def create_tag(ctx, _params=None):
|
||||||
auth.verify_privilege(ctx.user, 'tags:create')
|
auth.verify_privilege(ctx.user, 'tags:create')
|
||||||
|
@ -50,12 +55,14 @@ def create_tag(ctx, _params=None):
|
||||||
tags.export_to_json()
|
tags.export_to_json()
|
||||||
return _serialize(ctx, tag)
|
return _serialize(ctx, tag)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/tag/(?P<tag_name>[^/]+)/?')
|
@routes.get('/tag/(?P<tag_name>[^/]+)/?')
|
||||||
def get_tag(ctx, params):
|
def get_tag(ctx, params):
|
||||||
auth.verify_privilege(ctx.user, 'tags:view')
|
auth.verify_privilege(ctx.user, 'tags:view')
|
||||||
tag = tags.get_tag_by_name(params['tag_name'])
|
tag = tags.get_tag_by_name(params['tag_name'])
|
||||||
return _serialize(ctx, tag)
|
return _serialize(ctx, tag)
|
||||||
|
|
||||||
|
|
||||||
@routes.put('/tag/(?P<tag_name>[^/]+)/?')
|
@routes.put('/tag/(?P<tag_name>[^/]+)/?')
|
||||||
def update_tag(ctx, params):
|
def update_tag(ctx, params):
|
||||||
tag = tags.get_tag_by_name(params['tag_name'])
|
tag = tags.get_tag_by_name(params['tag_name'])
|
||||||
|
@ -89,6 +96,7 @@ def update_tag(ctx, params):
|
||||||
tags.export_to_json()
|
tags.export_to_json()
|
||||||
return _serialize(ctx, tag)
|
return _serialize(ctx, tag)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete('/tag/(?P<tag_name>[^/]+)/?')
|
@routes.delete('/tag/(?P<tag_name>[^/]+)/?')
|
||||||
def delete_tag(ctx, params):
|
def delete_tag(ctx, params):
|
||||||
tag = tags.get_tag_by_name(params['tag_name'])
|
tag = tags.get_tag_by_name(params['tag_name'])
|
||||||
|
@ -100,6 +108,7 @@ def delete_tag(ctx, params):
|
||||||
tags.export_to_json()
|
tags.export_to_json()
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/tag-merge/?')
|
@routes.post('/tag-merge/?')
|
||||||
def merge_tags(ctx, _params=None):
|
def merge_tags(ctx, _params=None):
|
||||||
source_tag_name = ctx.get_param_as_string('remove', required=True) or ''
|
source_tag_name = ctx.get_param_as_string('remove', required=True) or ''
|
||||||
|
@ -116,6 +125,7 @@ def merge_tags(ctx, _params=None):
|
||||||
tags.export_to_json()
|
tags.export_to_json()
|
||||||
return _serialize(ctx, target_tag)
|
return _serialize(ctx, target_tag)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/tag-siblings/(?P<tag_name>[^/]+)/?')
|
@routes.get('/tag-siblings/(?P<tag_name>[^/]+)/?')
|
||||||
def get_tag_siblings(ctx, params):
|
def get_tag_siblings(ctx, params):
|
||||||
auth.verify_privilege(ctx.user, 'tags:view')
|
auth.verify_privilege(ctx.user, 'tags:view')
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
from szurubooru.rest import routes
|
from szurubooru.rest import routes
|
||||||
from szurubooru.func import auth, tags, tag_categories, util, snapshots
|
from szurubooru.func import auth, tags, tag_categories, util, snapshots
|
||||||
|
|
||||||
|
|
||||||
def _serialize(ctx, category):
|
def _serialize(ctx, category):
|
||||||
return tag_categories.serialize_category(
|
return tag_categories.serialize_category(
|
||||||
category, options=util.get_serialization_options(ctx))
|
category, options=util.get_serialization_options(ctx))
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/tag-categories/?')
|
@routes.get('/tag-categories/?')
|
||||||
def get_tag_categories(ctx, _params=None):
|
def get_tag_categories(ctx, _params=None):
|
||||||
auth.verify_privilege(ctx.user, 'tag_categories:list')
|
auth.verify_privilege(ctx.user, 'tag_categories:list')
|
||||||
|
@ -13,6 +15,7 @@ def get_tag_categories(ctx, _params=None):
|
||||||
'results': [_serialize(ctx, category) for category in categories],
|
'results': [_serialize(ctx, category) for category in categories],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/tag-categories/?')
|
@routes.post('/tag-categories/?')
|
||||||
def create_tag_category(ctx, _params=None):
|
def create_tag_category(ctx, _params=None):
|
||||||
auth.verify_privilege(ctx.user, 'tag_categories:create')
|
auth.verify_privilege(ctx.user, 'tag_categories:create')
|
||||||
|
@ -26,12 +29,14 @@ def create_tag_category(ctx, _params=None):
|
||||||
tags.export_to_json()
|
tags.export_to_json()
|
||||||
return _serialize(ctx, category)
|
return _serialize(ctx, category)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/tag-category/(?P<category_name>[^/]+)/?')
|
@routes.get('/tag-category/(?P<category_name>[^/]+)/?')
|
||||||
def get_tag_category(ctx, params):
|
def get_tag_category(ctx, params):
|
||||||
auth.verify_privilege(ctx.user, 'tag_categories:view')
|
auth.verify_privilege(ctx.user, 'tag_categories:view')
|
||||||
category = tag_categories.get_category_by_name(params['category_name'])
|
category = tag_categories.get_category_by_name(params['category_name'])
|
||||||
return _serialize(ctx, category)
|
return _serialize(ctx, category)
|
||||||
|
|
||||||
|
|
||||||
@routes.put('/tag-category/(?P<category_name>[^/]+)/?')
|
@routes.put('/tag-category/(?P<category_name>[^/]+)/?')
|
||||||
def update_tag_category(ctx, params):
|
def update_tag_category(ctx, params):
|
||||||
category = tag_categories.get_category_by_name(params['category_name'])
|
category = tag_categories.get_category_by_name(params['category_name'])
|
||||||
|
@ -51,6 +56,7 @@ def update_tag_category(ctx, params):
|
||||||
tags.export_to_json()
|
tags.export_to_json()
|
||||||
return _serialize(ctx, category)
|
return _serialize(ctx, category)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete('/tag-category/(?P<category_name>[^/]+)/?')
|
@routes.delete('/tag-category/(?P<category_name>[^/]+)/?')
|
||||||
def delete_tag_category(ctx, params):
|
def delete_tag_category(ctx, params):
|
||||||
category = tag_categories.get_category_by_name(params['category_name'])
|
category = tag_categories.get_category_by_name(params['category_name'])
|
||||||
|
@ -62,6 +68,7 @@ def delete_tag_category(ctx, params):
|
||||||
tags.export_to_json()
|
tags.export_to_json()
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@routes.put('/tag-category/(?P<category_name>[^/]+)/default/?')
|
@routes.put('/tag-category/(?P<category_name>[^/]+)/default/?')
|
||||||
def set_tag_category_as_default(ctx, params):
|
def set_tag_category_as_default(ctx, params):
|
||||||
auth.verify_privilege(ctx.user, 'tag_categories:set_default')
|
auth.verify_privilege(ctx.user, 'tag_categories:set_default')
|
||||||
|
|
|
@ -2,8 +2,10 @@ from szurubooru import search
|
||||||
from szurubooru.func import auth, users, util
|
from szurubooru.func import auth, users, util
|
||||||
from szurubooru.rest import routes
|
from szurubooru.rest import routes
|
||||||
|
|
||||||
|
|
||||||
_search_executor = search.Executor(search.configs.UserSearchConfig())
|
_search_executor = search.Executor(search.configs.UserSearchConfig())
|
||||||
|
|
||||||
|
|
||||||
def _serialize(ctx, user, **kwargs):
|
def _serialize(ctx, user, **kwargs):
|
||||||
return users.serialize_user(
|
return users.serialize_user(
|
||||||
user,
|
user,
|
||||||
|
@ -11,12 +13,14 @@ def _serialize(ctx, user, **kwargs):
|
||||||
options=util.get_serialization_options(ctx),
|
options=util.get_serialization_options(ctx),
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/users/?')
|
@routes.get('/users/?')
|
||||||
def get_users(ctx, _params=None):
|
def get_users(ctx, _params=None):
|
||||||
auth.verify_privilege(ctx.user, 'users:list')
|
auth.verify_privilege(ctx.user, 'users:list')
|
||||||
return _search_executor.execute_and_serialize(
|
return _search_executor.execute_and_serialize(
|
||||||
ctx, lambda user: _serialize(ctx, user))
|
ctx, lambda user: _serialize(ctx, user))
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/users/?')
|
@routes.post('/users/?')
|
||||||
def create_user(ctx, _params=None):
|
def create_user(ctx, _params=None):
|
||||||
auth.verify_privilege(ctx.user, 'users:create')
|
auth.verify_privilege(ctx.user, 'users:create')
|
||||||
|
@ -36,6 +40,7 @@ def create_user(ctx, _params=None):
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize(ctx, user, force_show_email=True)
|
return _serialize(ctx, user, force_show_email=True)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/user/(?P<user_name>[^/]+)/?')
|
@routes.get('/user/(?P<user_name>[^/]+)/?')
|
||||||
def get_user(ctx, params):
|
def get_user(ctx, params):
|
||||||
user = users.get_user_by_name(params['user_name'])
|
user = users.get_user_by_name(params['user_name'])
|
||||||
|
@ -43,6 +48,7 @@ def get_user(ctx, params):
|
||||||
auth.verify_privilege(ctx.user, 'users:view')
|
auth.verify_privilege(ctx.user, 'users:view')
|
||||||
return _serialize(ctx, user)
|
return _serialize(ctx, user)
|
||||||
|
|
||||||
|
|
||||||
@routes.put('/user/(?P<user_name>[^/]+)/?')
|
@routes.put('/user/(?P<user_name>[^/]+)/?')
|
||||||
def update_user(ctx, params):
|
def update_user(ctx, params):
|
||||||
user = users.get_user_by_name(params['user_name'])
|
user = users.get_user_by_name(params['user_name'])
|
||||||
|
@ -72,6 +78,7 @@ def update_user(ctx, params):
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize(ctx, user)
|
return _serialize(ctx, user)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete('/user/(?P<user_name>[^/]+)/?')
|
@routes.delete('/user/(?P<user_name>[^/]+)/?')
|
||||||
def delete_user(ctx, params):
|
def delete_user(ctx, params):
|
||||||
user = users.get_user_by_name(params['user_name'])
|
user = users.get_user_by_name(params['user_name'])
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
def merge(left, right):
|
def merge(left, right):
|
||||||
for key in right:
|
for key in right:
|
||||||
if key in left:
|
if key in left:
|
||||||
|
@ -12,6 +13,7 @@ def merge(left, right):
|
||||||
left[key] = right[key]
|
left[key] = right[key]
|
||||||
return left
|
return left
|
||||||
|
|
||||||
|
|
||||||
def read_config():
|
def read_config():
|
||||||
with open('../config.yaml.dist') as handle:
|
with open('../config.yaml.dist') as handle:
|
||||||
ret = yaml.load(handle.read())
|
ret = yaml.load(handle.read())
|
||||||
|
@ -20,4 +22,5 @@ def read_config():
|
||||||
ret = merge(ret, yaml.load(handle.read()))
|
ret = merge(ret, yaml.load(handle.read()))
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
config = read_config() # pylint: disable=invalid-name
|
|
||||||
|
config = read_config() # pylint: disable=invalid-name
|
||||||
|
|
|
@ -1,2 +1,4 @@
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
Base = declarative_base() # pylint: disable=invalid-name
|
|
||||||
|
|
||||||
|
Base = declarative_base() # pylint: disable=invalid-name
|
||||||
|
|
|
@ -3,13 +3,18 @@ from sqlalchemy.orm import relationship, backref
|
||||||
from sqlalchemy.sql.expression import func
|
from sqlalchemy.sql.expression import func
|
||||||
from szurubooru.db.base import Base
|
from szurubooru.db.base import Base
|
||||||
|
|
||||||
|
|
||||||
class CommentScore(Base):
|
class CommentScore(Base):
|
||||||
__tablename__ = 'comment_score'
|
__tablename__ = 'comment_score'
|
||||||
|
|
||||||
comment_id = Column(
|
comment_id = Column(
|
||||||
'comment_id', Integer, ForeignKey('comment.id'), primary_key=True)
|
'comment_id', Integer, ForeignKey('comment.id'), primary_key=True)
|
||||||
user_id = Column(
|
user_id = Column(
|
||||||
'user_id', Integer, ForeignKey('user.id'), primary_key=True, index=True)
|
'user_id',
|
||||||
|
Integer,
|
||||||
|
ForeignKey('user.id'),
|
||||||
|
primary_key=True,
|
||||||
|
index=True)
|
||||||
time = Column('time', DateTime, nullable=False)
|
time = Column('time', DateTime, nullable=False)
|
||||||
score = Column('score', Integer, nullable=False)
|
score = Column('score', Integer, nullable=False)
|
||||||
|
|
||||||
|
@ -18,14 +23,14 @@ class CommentScore(Base):
|
||||||
'User',
|
'User',
|
||||||
backref=backref('comment_scores', cascade='all, delete-orphan'))
|
backref=backref('comment_scores', cascade='all, delete-orphan'))
|
||||||
|
|
||||||
|
|
||||||
class Comment(Base):
|
class Comment(Base):
|
||||||
__tablename__ = 'comment'
|
__tablename__ = 'comment'
|
||||||
|
|
||||||
comment_id = Column('id', Integer, primary_key=True)
|
comment_id = Column('id', Integer, primary_key=True)
|
||||||
post_id = Column(
|
post_id = Column(
|
||||||
'post_id', Integer, ForeignKey('post.id'), index=True, nullable=False)
|
'post_id', Integer, ForeignKey('post.id'), index=True, nullable=False)
|
||||||
user_id = Column(
|
user_id = Column('user_id', Integer, ForeignKey('user.id'), index=True)
|
||||||
'user_id', Integer, ForeignKey('user.id'), index=True)
|
|
||||||
version = Column('version', Integer, default=1, nullable=False)
|
version = Column('version', Integer, default=1, nullable=False)
|
||||||
creation_time = Column('creation_time', DateTime, nullable=False)
|
creation_time = Column('creation_time', DateTime, nullable=False)
|
||||||
last_edit_time = Column('last_edit_time', DateTime)
|
last_edit_time = Column('last_edit_time', DateTime)
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
|
from sqlalchemy.sql.expression import func, select
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Column, Integer, DateTime, Unicode, UnicodeText, PickleType, ForeignKey)
|
Column, Integer, DateTime, Unicode, UnicodeText, PickleType, ForeignKey)
|
||||||
from sqlalchemy.orm import relationship, column_property, object_session, backref
|
from sqlalchemy.orm import (
|
||||||
from sqlalchemy.sql.expression import func, select
|
relationship, column_property, object_session, backref)
|
||||||
from szurubooru.db.base import Base
|
from szurubooru.db.base import Base
|
||||||
from szurubooru.db.comment import Comment
|
from szurubooru.db.comment import Comment
|
||||||
|
|
||||||
|
|
||||||
class PostFeature(Base):
|
class PostFeature(Base):
|
||||||
__tablename__ = 'post_feature'
|
__tablename__ = 'post_feature'
|
||||||
|
|
||||||
|
@ -20,13 +22,22 @@ class PostFeature(Base):
|
||||||
'User',
|
'User',
|
||||||
backref=backref('post_features', cascade='all, delete-orphan'))
|
backref=backref('post_features', cascade='all, delete-orphan'))
|
||||||
|
|
||||||
|
|
||||||
class PostScore(Base):
|
class PostScore(Base):
|
||||||
__tablename__ = 'post_score'
|
__tablename__ = 'post_score'
|
||||||
|
|
||||||
post_id = Column(
|
post_id = Column(
|
||||||
'post_id', Integer, ForeignKey('post.id'), primary_key=True, index=True)
|
'post_id',
|
||||||
|
Integer,
|
||||||
|
ForeignKey('post.id'),
|
||||||
|
primary_key=True,
|
||||||
|
index=True)
|
||||||
user_id = Column(
|
user_id = Column(
|
||||||
'user_id', Integer, ForeignKey('user.id'), primary_key=True, index=True)
|
'user_id',
|
||||||
|
Integer,
|
||||||
|
ForeignKey('user.id'),
|
||||||
|
primary_key=True,
|
||||||
|
index=True)
|
||||||
time = Column('time', DateTime, nullable=False)
|
time = Column('time', DateTime, nullable=False)
|
||||||
score = Column('score', Integer, nullable=False)
|
score = Column('score', Integer, nullable=False)
|
||||||
|
|
||||||
|
@ -35,13 +46,22 @@ class PostScore(Base):
|
||||||
'User',
|
'User',
|
||||||
backref=backref('post_scores', cascade='all, delete-orphan'))
|
backref=backref('post_scores', cascade='all, delete-orphan'))
|
||||||
|
|
||||||
|
|
||||||
class PostFavorite(Base):
|
class PostFavorite(Base):
|
||||||
__tablename__ = 'post_favorite'
|
__tablename__ = 'post_favorite'
|
||||||
|
|
||||||
post_id = Column(
|
post_id = Column(
|
||||||
'post_id', Integer, ForeignKey('post.id'), primary_key=True, index=True)
|
'post_id',
|
||||||
|
Integer,
|
||||||
|
ForeignKey('post.id'),
|
||||||
|
primary_key=True,
|
||||||
|
index=True)
|
||||||
user_id = Column(
|
user_id = Column(
|
||||||
'user_id', Integer, ForeignKey('user.id'), primary_key=True, index=True)
|
'user_id',
|
||||||
|
Integer,
|
||||||
|
ForeignKey('user.id'),
|
||||||
|
primary_key=True,
|
||||||
|
index=True)
|
||||||
time = Column('time', DateTime, nullable=False)
|
time = Column('time', DateTime, nullable=False)
|
||||||
|
|
||||||
post = relationship('Post')
|
post = relationship('Post')
|
||||||
|
@ -49,6 +69,7 @@ class PostFavorite(Base):
|
||||||
'User',
|
'User',
|
||||||
backref=backref('post_favorites', cascade='all, delete-orphan'))
|
backref=backref('post_favorites', cascade='all, delete-orphan'))
|
||||||
|
|
||||||
|
|
||||||
class PostNote(Base):
|
class PostNote(Base):
|
||||||
__tablename__ = 'post_note'
|
__tablename__ = 'post_note'
|
||||||
|
|
||||||
|
@ -60,23 +81,37 @@ class PostNote(Base):
|
||||||
|
|
||||||
post = relationship('Post')
|
post = relationship('Post')
|
||||||
|
|
||||||
|
|
||||||
class PostRelation(Base):
|
class PostRelation(Base):
|
||||||
__tablename__ = 'post_relation'
|
__tablename__ = 'post_relation'
|
||||||
|
|
||||||
parent_id = Column(
|
parent_id = Column(
|
||||||
'parent_id', Integer, ForeignKey('post.id'), primary_key=True, index=True)
|
'parent_id',
|
||||||
|
Integer,
|
||||||
|
ForeignKey('post.id'),
|
||||||
|
primary_key=True,
|
||||||
|
index=True)
|
||||||
child_id = Column(
|
child_id = Column(
|
||||||
'child_id', Integer, ForeignKey('post.id'), primary_key=True, index=True)
|
'child_id',
|
||||||
|
Integer,
|
||||||
|
ForeignKey('post.id'),
|
||||||
|
primary_key=True,
|
||||||
|
index=True)
|
||||||
|
|
||||||
def __init__(self, parent_id, child_id):
|
def __init__(self, parent_id, child_id):
|
||||||
self.parent_id = parent_id
|
self.parent_id = parent_id
|
||||||
self.child_id = child_id
|
self.child_id = child_id
|
||||||
|
|
||||||
|
|
||||||
class PostTag(Base):
|
class PostTag(Base):
|
||||||
__tablename__ = 'post_tag'
|
__tablename__ = 'post_tag'
|
||||||
|
|
||||||
post_id = Column(
|
post_id = Column(
|
||||||
'post_id', Integer, ForeignKey('post.id'), primary_key=True, index=True)
|
'post_id',
|
||||||
|
Integer,
|
||||||
|
ForeignKey('post.id'),
|
||||||
|
primary_key=True,
|
||||||
|
index=True)
|
||||||
tag_id = Column(
|
tag_id = Column(
|
||||||
'tag_id', Integer, ForeignKey('tag.id'), primary_key=True, index=True)
|
'tag_id', Integer, ForeignKey('tag.id'), primary_key=True, index=True)
|
||||||
|
|
||||||
|
@ -84,6 +119,7 @@ class PostTag(Base):
|
||||||
self.post_id = post_id
|
self.post_id = post_id
|
||||||
self.tag_id = tag_id
|
self.tag_id = tag_id
|
||||||
|
|
||||||
|
|
||||||
class Post(Base):
|
class Post(Base):
|
||||||
__tablename__ = 'post'
|
__tablename__ = 'post'
|
||||||
|
|
||||||
|
@ -136,8 +172,8 @@ class Post(Base):
|
||||||
|
|
||||||
# dynamic columns
|
# dynamic columns
|
||||||
tag_count = column_property(
|
tag_count = column_property(
|
||||||
select([func.count(PostTag.tag_id)]) \
|
select([func.count(PostTag.tag_id)])
|
||||||
.where(PostTag.post_id == post_id) \
|
.where(PostTag.post_id == post_id)
|
||||||
.correlate_except(PostTag))
|
.correlate_except(PostTag))
|
||||||
|
|
||||||
canvas_area = column_property(canvas_width * canvas_height)
|
canvas_area = column_property(canvas_width * canvas_height)
|
||||||
|
@ -151,53 +187,53 @@ class Post(Base):
|
||||||
return featured_post and featured_post.post_id == self.post_id
|
return featured_post and featured_post.post_id == self.post_id
|
||||||
|
|
||||||
score = column_property(
|
score = column_property(
|
||||||
select([func.coalesce(func.sum(PostScore.score), 0)]) \
|
select([func.coalesce(func.sum(PostScore.score), 0)])
|
||||||
.where(PostScore.post_id == post_id) \
|
.where(PostScore.post_id == post_id)
|
||||||
.correlate_except(PostScore))
|
.correlate_except(PostScore))
|
||||||
|
|
||||||
favorite_count = column_property(
|
favorite_count = column_property(
|
||||||
select([func.count(PostFavorite.post_id)]) \
|
select([func.count(PostFavorite.post_id)])
|
||||||
.where(PostFavorite.post_id == post_id) \
|
.where(PostFavorite.post_id == post_id)
|
||||||
.correlate_except(PostFavorite))
|
.correlate_except(PostFavorite))
|
||||||
|
|
||||||
last_favorite_time = column_property(
|
last_favorite_time = column_property(
|
||||||
select([func.max(PostFavorite.time)]) \
|
select([func.max(PostFavorite.time)])
|
||||||
.where(PostFavorite.post_id == post_id) \
|
.where(PostFavorite.post_id == post_id)
|
||||||
.correlate_except(PostFavorite))
|
.correlate_except(PostFavorite))
|
||||||
|
|
||||||
feature_count = column_property(
|
feature_count = column_property(
|
||||||
select([func.count(PostFeature.post_id)]) \
|
select([func.count(PostFeature.post_id)])
|
||||||
.where(PostFeature.post_id == post_id) \
|
.where(PostFeature.post_id == post_id)
|
||||||
.correlate_except(PostFeature))
|
.correlate_except(PostFeature))
|
||||||
|
|
||||||
last_feature_time = column_property(
|
last_feature_time = column_property(
|
||||||
select([func.max(PostFeature.time)]) \
|
select([func.max(PostFeature.time)])
|
||||||
.where(PostFeature.post_id == post_id) \
|
.where(PostFeature.post_id == post_id)
|
||||||
.correlate_except(PostFeature))
|
.correlate_except(PostFeature))
|
||||||
|
|
||||||
comment_count = column_property(
|
comment_count = column_property(
|
||||||
select([func.count(Comment.post_id)]) \
|
select([func.count(Comment.post_id)])
|
||||||
.where(Comment.post_id == post_id) \
|
.where(Comment.post_id == post_id)
|
||||||
.correlate_except(Comment))
|
.correlate_except(Comment))
|
||||||
|
|
||||||
last_comment_creation_time = column_property(
|
last_comment_creation_time = column_property(
|
||||||
select([func.max(Comment.creation_time)]) \
|
select([func.max(Comment.creation_time)])
|
||||||
.where(Comment.post_id == post_id) \
|
.where(Comment.post_id == post_id)
|
||||||
.correlate_except(Comment))
|
.correlate_except(Comment))
|
||||||
|
|
||||||
last_comment_edit_time = column_property(
|
last_comment_edit_time = column_property(
|
||||||
select([func.max(Comment.last_edit_time)]) \
|
select([func.max(Comment.last_edit_time)])
|
||||||
.where(Comment.post_id == post_id) \
|
.where(Comment.post_id == post_id)
|
||||||
.correlate_except(Comment))
|
.correlate_except(Comment))
|
||||||
|
|
||||||
note_count = column_property(
|
note_count = column_property(
|
||||||
select([func.count(PostNote.post_id)]) \
|
select([func.count(PostNote.post_id)])
|
||||||
.where(PostNote.post_id == post_id) \
|
.where(PostNote.post_id == post_id)
|
||||||
.correlate_except(PostNote))
|
.correlate_except(PostNote))
|
||||||
|
|
||||||
relation_count = column_property(
|
relation_count = column_property(
|
||||||
select([func.count(PostRelation.child_id)]) \
|
select([func.count(PostRelation.child_id)])
|
||||||
.where(
|
.where(
|
||||||
(PostRelation.parent_id == post_id) \
|
(PostRelation.parent_id == post_id)
|
||||||
| (PostRelation.child_id == post_id)) \
|
| (PostRelation.child_id == post_id))
|
||||||
.correlate_except(PostRelation))
|
.correlate_except(PostRelation))
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from szurubooru import config
|
from szurubooru import config
|
||||||
|
|
||||||
|
|
||||||
class QueryCounter(object):
|
class QueryCounter(object):
|
||||||
_query_count = 0
|
_query_count = 0
|
||||||
|
|
||||||
|
@ -16,6 +17,7 @@ class QueryCounter(object):
|
||||||
def get():
|
def get():
|
||||||
return QueryCounter._query_count
|
return QueryCounter._query_count
|
||||||
|
|
||||||
|
|
||||||
def create_session():
|
def create_session():
|
||||||
_engine = sqlalchemy.create_engine(
|
_engine = sqlalchemy.create_engine(
|
||||||
'{schema}://{user}:{password}@{host}:{port}/{name}'.format(
|
'{schema}://{user}:{password}@{host}:{port}/{name}'.format(
|
||||||
|
@ -30,6 +32,7 @@ def create_session():
|
||||||
_session_maker = sqlalchemy.orm.sessionmaker(bind=_engine)
|
_session_maker = sqlalchemy.orm.sessionmaker(bind=_engine)
|
||||||
return sqlalchemy.orm.scoped_session(_session_maker)
|
return sqlalchemy.orm.scoped_session(_session_maker)
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
session = create_session()
|
session = create_session()
|
||||||
reset_query_count = QueryCounter.reset
|
reset_query_count = QueryCounter.reset
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
from sqlalchemy import Column, Integer, DateTime, Unicode, PickleType, ForeignKey
|
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
|
from sqlalchemy import (
|
||||||
|
Column, Integer, DateTime, Unicode, PickleType, ForeignKey)
|
||||||
from szurubooru.db.base import Base
|
from szurubooru.db.base import Base
|
||||||
|
|
||||||
|
|
||||||
class Snapshot(Base):
|
class Snapshot(Base):
|
||||||
__tablename__ = 'snapshot'
|
__tablename__ = 'snapshot'
|
||||||
|
|
||||||
|
@ -11,7 +13,8 @@ class Snapshot(Base):
|
||||||
|
|
||||||
snapshot_id = Column('id', Integer, primary_key=True)
|
snapshot_id = Column('id', Integer, primary_key=True)
|
||||||
creation_time = Column('creation_time', DateTime, nullable=False)
|
creation_time = Column('creation_time', DateTime, nullable=False)
|
||||||
resource_type = Column('resource_type', Unicode(32), nullable=False, index=True)
|
resource_type = Column(
|
||||||
|
'resource_type', Unicode(32), nullable=False, index=True)
|
||||||
resource_id = Column('resource_id', Integer, nullable=False, index=True)
|
resource_id = Column('resource_id', Integer, nullable=False, index=True)
|
||||||
resource_repr = Column('resource_repr', Unicode(64), nullable=False)
|
resource_repr = Column('resource_repr', Unicode(64), nullable=False)
|
||||||
operation = Column('operation', Unicode(16), nullable=False)
|
operation = Column('operation', Unicode(16), nullable=False)
|
||||||
|
|
|
@ -1,49 +1,73 @@
|
||||||
from sqlalchemy import Column, Integer, DateTime, Unicode, UnicodeText, ForeignKey
|
from sqlalchemy import (
|
||||||
|
Column, Integer, DateTime, Unicode, UnicodeText, ForeignKey)
|
||||||
from sqlalchemy.orm import relationship, column_property
|
from sqlalchemy.orm import relationship, column_property
|
||||||
from sqlalchemy.sql.expression import func, select
|
from sqlalchemy.sql.expression import func, select
|
||||||
from szurubooru.db.base import Base
|
from szurubooru.db.base import Base
|
||||||
from szurubooru.db.post import PostTag
|
from szurubooru.db.post import PostTag
|
||||||
|
|
||||||
|
|
||||||
class TagSuggestion(Base):
|
class TagSuggestion(Base):
|
||||||
__tablename__ = 'tag_suggestion'
|
__tablename__ = 'tag_suggestion'
|
||||||
|
|
||||||
parent_id = Column(
|
parent_id = Column(
|
||||||
'parent_id', Integer, ForeignKey('tag.id'), primary_key=True, index=True)
|
'parent_id',
|
||||||
|
Integer,
|
||||||
|
ForeignKey('tag.id'),
|
||||||
|
primary_key=True, index=True)
|
||||||
child_id = Column(
|
child_id = Column(
|
||||||
'child_id', Integer, ForeignKey('tag.id'), primary_key=True, index=True)
|
'child_id',
|
||||||
|
Integer,
|
||||||
|
ForeignKey('tag.id'),
|
||||||
|
primary_key=True, index=True)
|
||||||
|
|
||||||
def __init__(self, parent_id, child_id):
|
def __init__(self, parent_id, child_id):
|
||||||
self.parent_id = parent_id
|
self.parent_id = parent_id
|
||||||
self.child_id = child_id
|
self.child_id = child_id
|
||||||
|
|
||||||
|
|
||||||
class TagImplication(Base):
|
class TagImplication(Base):
|
||||||
__tablename__ = 'tag_implication'
|
__tablename__ = 'tag_implication'
|
||||||
|
|
||||||
parent_id = Column(
|
parent_id = Column(
|
||||||
'parent_id', Integer, ForeignKey('tag.id'), primary_key=True, index=True)
|
'parent_id',
|
||||||
|
Integer,
|
||||||
|
ForeignKey('tag.id'),
|
||||||
|
primary_key=True,
|
||||||
|
index=True)
|
||||||
child_id = Column(
|
child_id = Column(
|
||||||
'child_id', Integer, ForeignKey('tag.id'), primary_key=True, index=True)
|
'child_id',
|
||||||
|
Integer,
|
||||||
|
ForeignKey('tag.id'),
|
||||||
|
primary_key=True,
|
||||||
|
index=True)
|
||||||
|
|
||||||
def __init__(self, parent_id, child_id):
|
def __init__(self, parent_id, child_id):
|
||||||
self.parent_id = parent_id
|
self.parent_id = parent_id
|
||||||
self.child_id = child_id
|
self.child_id = child_id
|
||||||
|
|
||||||
|
|
||||||
class TagName(Base):
|
class TagName(Base):
|
||||||
__tablename__ = 'tag_name'
|
__tablename__ = 'tag_name'
|
||||||
|
|
||||||
tag_name_id = Column('tag_name_id', Integer, primary_key=True)
|
tag_name_id = Column('tag_name_id', Integer, primary_key=True)
|
||||||
tag_id = Column('tag_id', Integer, ForeignKey('tag.id'), nullable=False, index=True)
|
tag_id = Column(
|
||||||
|
'tag_id', Integer, ForeignKey('tag.id'), nullable=False, index=True)
|
||||||
name = Column('name', Unicode(64), nullable=False, unique=True)
|
name = Column('name', Unicode(64), nullable=False, unique=True)
|
||||||
|
|
||||||
def __init__(self, name):
|
def __init__(self, name):
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
|
|
||||||
class Tag(Base):
|
class Tag(Base):
|
||||||
__tablename__ = 'tag'
|
__tablename__ = 'tag'
|
||||||
|
|
||||||
tag_id = Column('id', Integer, primary_key=True)
|
tag_id = Column('id', Integer, primary_key=True)
|
||||||
category_id = Column(
|
category_id = Column(
|
||||||
'category_id', Integer, ForeignKey('tag_category.id'), nullable=False, index=True)
|
'category_id',
|
||||||
|
Integer,
|
||||||
|
ForeignKey('tag_category.id'),
|
||||||
|
nullable=False,
|
||||||
|
index=True)
|
||||||
version = Column('version', Integer, default=1, nullable=False)
|
version = Column('version', Integer, default=1, nullable=False)
|
||||||
creation_time = Column('creation_time', DateTime, nullable=False)
|
creation_time = Column('creation_time', DateTime, nullable=False)
|
||||||
last_edit_time = Column('last_edit_time', DateTime)
|
last_edit_time = Column('last_edit_time', DateTime)
|
||||||
|
@ -69,25 +93,25 @@ class Tag(Base):
|
||||||
lazy='joined')
|
lazy='joined')
|
||||||
|
|
||||||
post_count = column_property(
|
post_count = column_property(
|
||||||
select([func.count(PostTag.post_id)]) \
|
select([func.count(PostTag.post_id)])
|
||||||
.where(PostTag.tag_id == tag_id) \
|
.where(PostTag.tag_id == tag_id)
|
||||||
.correlate_except(PostTag))
|
.correlate_except(PostTag))
|
||||||
|
|
||||||
first_name = column_property(
|
first_name = column_property(
|
||||||
select([TagName.name]) \
|
select([TagName.name])
|
||||||
.where(TagName.tag_id == tag_id) \
|
.where(TagName.tag_id == tag_id)
|
||||||
.limit(1) \
|
.limit(1)
|
||||||
.as_scalar(),
|
.as_scalar(),
|
||||||
deferred=True)
|
deferred=True)
|
||||||
|
|
||||||
suggestion_count = column_property(
|
suggestion_count = column_property(
|
||||||
select([func.count(TagSuggestion.child_id)]) \
|
select([func.count(TagSuggestion.child_id)])
|
||||||
.where(TagSuggestion.parent_id == tag_id) \
|
.where(TagSuggestion.parent_id == tag_id)
|
||||||
.as_scalar(),
|
.as_scalar(),
|
||||||
deferred=True)
|
deferred=True)
|
||||||
|
|
||||||
implication_count = column_property(
|
implication_count = column_property(
|
||||||
select([func.count(TagImplication.child_id)]) \
|
select([func.count(TagImplication.child_id)])
|
||||||
.where(TagImplication.parent_id == tag_id) \
|
.where(TagImplication.parent_id == tag_id)
|
||||||
.as_scalar(),
|
.as_scalar(),
|
||||||
deferred=True)
|
deferred=True)
|
||||||
|
|
|
@ -4,6 +4,7 @@ from sqlalchemy.sql.expression import func, select
|
||||||
from szurubooru.db.base import Base
|
from szurubooru.db.base import Base
|
||||||
from szurubooru.db.tag import Tag
|
from szurubooru.db.tag import Tag
|
||||||
|
|
||||||
|
|
||||||
class TagCategory(Base):
|
class TagCategory(Base):
|
||||||
__tablename__ = 'tag_category'
|
__tablename__ = 'tag_category'
|
||||||
|
|
||||||
|
@ -17,6 +18,6 @@ class TagCategory(Base):
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
tag_count = column_property(
|
tag_count = column_property(
|
||||||
select([func.count('Tag.tag_id')]) \
|
select([func.count('Tag.tag_id')])
|
||||||
.where(Tag.category_id == tag_category_id) \
|
.where(Tag.category_id == tag_category_id)
|
||||||
.correlate_except(table('Tag')))
|
.correlate_except(table('Tag')))
|
||||||
|
|
|
@ -5,6 +5,7 @@ from szurubooru.db.base import Base
|
||||||
from szurubooru.db.post import Post, PostScore, PostFavorite
|
from szurubooru.db.post import Post, PostScore, PostFavorite
|
||||||
from szurubooru.db.comment import Comment
|
from szurubooru.db.comment import Comment
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
__tablename__ = 'user'
|
__tablename__ = 'user'
|
||||||
|
|
||||||
|
@ -17,7 +18,7 @@ class User(Base):
|
||||||
RANK_POWER = 'power'
|
RANK_POWER = 'power'
|
||||||
RANK_MODERATOR = 'moderator'
|
RANK_MODERATOR = 'moderator'
|
||||||
RANK_ADMINISTRATOR = 'administrator'
|
RANK_ADMINISTRATOR = 'administrator'
|
||||||
RANK_NOBODY = 'nobody' # used for privileges: "nobody can be higher than admin"
|
RANK_NOBODY = 'nobody' # unattainable, used for privileges
|
||||||
|
|
||||||
user_id = Column('id', Integer, primary_key=True)
|
user_id = Column('id', Integer, primary_key=True)
|
||||||
creation_time = Column('creation_time', DateTime, nullable=False)
|
creation_time = Column('creation_time', DateTime, nullable=False)
|
||||||
|
@ -36,41 +37,41 @@ class User(Base):
|
||||||
@property
|
@property
|
||||||
def post_count(self):
|
def post_count(self):
|
||||||
from szurubooru.db import session
|
from szurubooru.db import session
|
||||||
return session \
|
return (session
|
||||||
.query(func.sum(1)) \
|
.query(func.sum(1))
|
||||||
.filter(Post.user_id == self.user_id) \
|
.filter(Post.user_id == self.user_id)
|
||||||
.one()[0] or 0
|
.one()[0] or 0)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def comment_count(self):
|
def comment_count(self):
|
||||||
from szurubooru.db import session
|
from szurubooru.db import session
|
||||||
return session \
|
return (session
|
||||||
.query(func.sum(1)) \
|
.query(func.sum(1))
|
||||||
.filter(Comment.user_id == self.user_id) \
|
.filter(Comment.user_id == self.user_id)
|
||||||
.one()[0] or 0
|
.one()[0] or 0)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def favorite_post_count(self):
|
def favorite_post_count(self):
|
||||||
from szurubooru.db import session
|
from szurubooru.db import session
|
||||||
return session \
|
return (session
|
||||||
.query(func.sum(1)) \
|
.query(func.sum(1))
|
||||||
.filter(PostFavorite.user_id == self.user_id) \
|
.filter(PostFavorite.user_id == self.user_id)
|
||||||
.one()[0] or 0
|
.one()[0] or 0)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def liked_post_count(self):
|
def liked_post_count(self):
|
||||||
from szurubooru.db import session
|
from szurubooru.db import session
|
||||||
return session \
|
return (session
|
||||||
.query(func.sum(1)) \
|
.query(func.sum(1))
|
||||||
.filter(PostScore.user_id == self.user_id) \
|
.filter(PostScore.user_id == self.user_id)
|
||||||
.filter(PostScore.score == 1) \
|
.filter(PostScore.score == 1)
|
||||||
.one()[0] or 0
|
.one()[0] or 0)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def disliked_post_count(self):
|
def disliked_post_count(self):
|
||||||
from szurubooru.db import session
|
from szurubooru.db import session
|
||||||
return session \
|
return (session
|
||||||
.query(func.sum(1)) \
|
.query(func.sum(1))
|
||||||
.filter(PostScore.user_id == self.user_id) \
|
.filter(PostScore.user_id == self.user_id)
|
||||||
.filter(PostScore.score == -1) \
|
.filter(PostScore.score == -1)
|
||||||
.one()[0] or 0
|
.one()[0] or 0)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from sqlalchemy.inspection import inspect
|
from sqlalchemy.inspection import inspect
|
||||||
|
|
||||||
|
|
||||||
def get_resource_info(entity):
|
def get_resource_info(entity):
|
||||||
serializers = {
|
serializers = {
|
||||||
'tag': lambda tag: tag.first_name,
|
'tag': lambda tag: tag.first_name,
|
||||||
|
@ -23,6 +24,7 @@ def get_resource_info(entity):
|
||||||
|
|
||||||
return (resource_type, resource_id, resource_repr)
|
return (resource_type, resource_id, resource_repr)
|
||||||
|
|
||||||
|
|
||||||
def get_aux_entity(session, get_table_info, entity, user):
|
def get_aux_entity(session, get_table_info, entity, user):
|
||||||
table, get_column = get_table_info(entity)
|
table, get_column = get_table_info(entity)
|
||||||
return session \
|
return session \
|
||||||
|
|
|
@ -1,11 +1,38 @@
|
||||||
class ConfigError(RuntimeError): pass
|
class ConfigError(RuntimeError):
|
||||||
class AuthError(RuntimeError): pass
|
pass
|
||||||
class IntegrityError(RuntimeError): pass
|
|
||||||
class ValidationError(RuntimeError): pass
|
|
||||||
class SearchError(RuntimeError): pass
|
|
||||||
class NotFoundError(RuntimeError): pass
|
|
||||||
class ProcessingError(RuntimeError): pass
|
|
||||||
|
|
||||||
class MissingRequiredFileError(ValidationError): pass
|
|
||||||
class MissingRequiredParameterError(ValidationError): pass
|
class AuthError(RuntimeError):
|
||||||
class InvalidParameterError(ValidationError): pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class IntegrityError(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationError(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SearchError(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NotFoundError(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessingError(RuntimeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MissingRequiredFileError(ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class MissingRequiredParameterError(ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidParameterError(ValidationError):
|
||||||
|
pass
|
||||||
|
|
|
@ -7,30 +7,37 @@ from szurubooru import config, errors, rest
|
||||||
# pylint: disable=unused-import
|
# pylint: disable=unused-import
|
||||||
from szurubooru import api, middleware
|
from szurubooru import api, middleware
|
||||||
|
|
||||||
|
|
||||||
def _on_auth_error(ex):
|
def _on_auth_error(ex):
|
||||||
raise rest.errors.HttpForbidden(
|
raise rest.errors.HttpForbidden(
|
||||||
title='Authentication error', description=str(ex))
|
title='Authentication error', description=str(ex))
|
||||||
|
|
||||||
|
|
||||||
def _on_validation_error(ex):
|
def _on_validation_error(ex):
|
||||||
raise rest.errors.HttpBadRequest(
|
raise rest.errors.HttpBadRequest(
|
||||||
title='Validation error', description=str(ex))
|
title='Validation error', description=str(ex))
|
||||||
|
|
||||||
|
|
||||||
def _on_search_error(ex):
|
def _on_search_error(ex):
|
||||||
raise rest.errors.HttpBadRequest(
|
raise rest.errors.HttpBadRequest(
|
||||||
title='Search error', description=str(ex))
|
title='Search error', description=str(ex))
|
||||||
|
|
||||||
|
|
||||||
def _on_integrity_error(ex):
|
def _on_integrity_error(ex):
|
||||||
raise rest.errors.HttpConflict(
|
raise rest.errors.HttpConflict(
|
||||||
title='Integrity violation', description=ex.args[0])
|
title='Integrity violation', description=ex.args[0])
|
||||||
|
|
||||||
|
|
||||||
def _on_not_found_error(ex):
|
def _on_not_found_error(ex):
|
||||||
raise rest.errors.HttpNotFound(
|
raise rest.errors.HttpNotFound(
|
||||||
title='Not found', description=str(ex))
|
title='Not found', description=str(ex))
|
||||||
|
|
||||||
|
|
||||||
def _on_processing_error(ex):
|
def _on_processing_error(ex):
|
||||||
raise rest.errors.HttpBadRequest(
|
raise rest.errors.HttpBadRequest(
|
||||||
title='Processing error', description=str(ex))
|
title='Processing error', description=str(ex))
|
||||||
|
|
||||||
|
|
||||||
def validate_config():
|
def validate_config():
|
||||||
'''
|
'''
|
||||||
Check whether config doesn't contain errors that might prove
|
Check whether config doesn't contain errors that might prove
|
||||||
|
@ -60,6 +67,7 @@ def validate_config():
|
||||||
raise errors.ConfigError(
|
raise errors.ConfigError(
|
||||||
'Database is not configured: %r is missing' % key)
|
'Database is not configured: %r is missing' % key)
|
||||||
|
|
||||||
|
|
||||||
def create_app():
|
def create_app():
|
||||||
''' Create a WSGI compatible App object. '''
|
''' Create a WSGI compatible App object. '''
|
||||||
validate_config()
|
validate_config()
|
||||||
|
|
|
@ -4,6 +4,7 @@ from collections import OrderedDict
|
||||||
from szurubooru import config, db, errors
|
from szurubooru import config, db, errors
|
||||||
from szurubooru.func import util
|
from szurubooru.func import util
|
||||||
|
|
||||||
|
|
||||||
RANK_MAP = OrderedDict([
|
RANK_MAP = OrderedDict([
|
||||||
(db.User.RANK_ANONYMOUS, 'anonymous'),
|
(db.User.RANK_ANONYMOUS, 'anonymous'),
|
||||||
(db.User.RANK_RESTRICTED, 'restricted'),
|
(db.User.RANK_RESTRICTED, 'restricted'),
|
||||||
|
@ -14,6 +15,7 @@ RANK_MAP = OrderedDict([
|
||||||
(db.User.RANK_NOBODY, 'nobody'),
|
(db.User.RANK_NOBODY, 'nobody'),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
def get_password_hash(salt, password):
|
def get_password_hash(salt, password):
|
||||||
''' Retrieve new-style password hash. '''
|
''' Retrieve new-style password hash. '''
|
||||||
digest = hashlib.sha256()
|
digest = hashlib.sha256()
|
||||||
|
@ -22,6 +24,7 @@ def get_password_hash(salt, password):
|
||||||
digest.update(password.encode('utf8'))
|
digest.update(password.encode('utf8'))
|
||||||
return digest.hexdigest()
|
return digest.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def get_legacy_password_hash(salt, password):
|
def get_legacy_password_hash(salt, password):
|
||||||
''' Retrieve old-style password hash. '''
|
''' Retrieve old-style password hash. '''
|
||||||
digest = hashlib.sha1()
|
digest = hashlib.sha1()
|
||||||
|
@ -30,6 +33,7 @@ def get_legacy_password_hash(salt, password):
|
||||||
digest.update(password.encode('utf8'))
|
digest.update(password.encode('utf8'))
|
||||||
return digest.hexdigest()
|
return digest.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def create_password():
|
def create_password():
|
||||||
alphabet = {
|
alphabet = {
|
||||||
'c': list('bcdfghijklmnpqrstvwxyz'),
|
'c': list('bcdfghijklmnpqrstvwxyz'),
|
||||||
|
@ -39,6 +43,7 @@ def create_password():
|
||||||
pattern = 'cvcvnncvcv'
|
pattern = 'cvcvnncvcv'
|
||||||
return ''.join(random.choice(alphabet[l]) for l in list(pattern))
|
return ''.join(random.choice(alphabet[l]) for l in list(pattern))
|
||||||
|
|
||||||
|
|
||||||
def is_valid_password(user, password):
|
def is_valid_password(user, password):
|
||||||
assert user
|
assert user
|
||||||
salt, valid_hash = user.password_salt, user.password_hash
|
salt, valid_hash = user.password_salt, user.password_hash
|
||||||
|
@ -48,6 +53,7 @@ def is_valid_password(user, password):
|
||||||
]
|
]
|
||||||
return valid_hash in possible_hashes
|
return valid_hash in possible_hashes
|
||||||
|
|
||||||
|
|
||||||
def has_privilege(user, privilege_name):
|
def has_privilege(user, privilege_name):
|
||||||
assert user
|
assert user
|
||||||
all_ranks = list(RANK_MAP.keys())
|
all_ranks = list(RANK_MAP.keys())
|
||||||
|
@ -58,11 +64,13 @@ def has_privilege(user, privilege_name):
|
||||||
good_ranks = all_ranks[all_ranks.index(minimal_rank):]
|
good_ranks = all_ranks[all_ranks.index(minimal_rank):]
|
||||||
return user.rank in good_ranks
|
return user.rank in good_ranks
|
||||||
|
|
||||||
|
|
||||||
def verify_privilege(user, privilege_name):
|
def verify_privilege(user, privilege_name):
|
||||||
assert user
|
assert user
|
||||||
if not has_privilege(user, privilege_name):
|
if not has_privilege(user, privilege_name):
|
||||||
raise errors.AuthError('Insufficient privileges to do this.')
|
raise errors.AuthError('Insufficient privileges to do this.')
|
||||||
|
|
||||||
|
|
||||||
def generate_authentication_token(user):
|
def generate_authentication_token(user):
|
||||||
''' Generate nonguessable challenge (e.g. links in password reminder). '''
|
''' Generate nonguessable challenge (e.g. links in password reminder). '''
|
||||||
assert user
|
assert user
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
class LruCacheItem(object):
|
class LruCacheItem(object):
|
||||||
def __init__(self, key, value):
|
def __init__(self, key, value):
|
||||||
self.key = key
|
self.key = key
|
||||||
self.value = value
|
self.value = value
|
||||||
self.timestamp = datetime.utcnow()
|
self.timestamp = datetime.utcnow()
|
||||||
|
|
||||||
|
|
||||||
class LruCache(object):
|
class LruCache(object):
|
||||||
def __init__(self, length, delta=None):
|
def __init__(self, length, delta=None):
|
||||||
self.length = length
|
self.length = length
|
||||||
|
@ -15,12 +17,13 @@ class LruCache(object):
|
||||||
|
|
||||||
def insert_item(self, item):
|
def insert_item(self, item):
|
||||||
if item.key in self.hash:
|
if item.key in self.hash:
|
||||||
item_index = next(i \
|
item_index = next(
|
||||||
for i, v in enumerate(self.item_list) \
|
i
|
||||||
|
for i, v in enumerate(self.item_list)
|
||||||
if v.key == item.key)
|
if v.key == item.key)
|
||||||
self.item_list[:] \
|
self.item_list[:] \
|
||||||
= self.item_list[:item_index] \
|
= self.item_list[:item_index] \
|
||||||
+ self.item_list[item_index+1:]
|
+ self.item_list[item_index + 1:]
|
||||||
self.item_list.insert(0, item)
|
self.item_list.insert(0, item)
|
||||||
else:
|
else:
|
||||||
if len(self.item_list) > self.length:
|
if len(self.item_list) > self.length:
|
||||||
|
@ -36,16 +39,21 @@ class LruCache(object):
|
||||||
del self.hash[item.key]
|
del self.hash[item.key]
|
||||||
del self.item_list[self.item_list.index(item)]
|
del self.item_list[self.item_list.index(item)]
|
||||||
|
|
||||||
|
|
||||||
_CACHE = LruCache(length=100)
|
_CACHE = LruCache(length=100)
|
||||||
|
|
||||||
|
|
||||||
def purge():
|
def purge():
|
||||||
_CACHE.remove_all()
|
_CACHE.remove_all()
|
||||||
|
|
||||||
|
|
||||||
def has(key):
|
def has(key):
|
||||||
return key in _CACHE.hash
|
return key in _CACHE.hash
|
||||||
|
|
||||||
|
|
||||||
def get(key):
|
def get(key):
|
||||||
return _CACHE.hash[key].value
|
return _CACHE.hash[key].value
|
||||||
|
|
||||||
|
|
||||||
def put(key, value):
|
def put(key, value):
|
||||||
_CACHE.insert_item(LruCacheItem(key, value))
|
_CACHE.insert_item(LruCacheItem(key, value))
|
||||||
|
|
|
@ -2,16 +2,26 @@ import datetime
|
||||||
from szurubooru import db, errors
|
from szurubooru import db, errors
|
||||||
from szurubooru.func import users, scores, util
|
from szurubooru.func import users, scores, util
|
||||||
|
|
||||||
class InvalidCommentIdError(errors.ValidationError): pass
|
|
||||||
class CommentNotFoundError(errors.NotFoundError): pass
|
class InvalidCommentIdError(errors.ValidationError):
|
||||||
class EmptyCommentTextError(errors.ValidationError): pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class CommentNotFoundError(errors.NotFoundError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class EmptyCommentTextError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def serialize_comment(comment, auth_user, options=None):
|
def serialize_comment(comment, auth_user, options=None):
|
||||||
return util.serialize_entity(
|
return util.serialize_entity(
|
||||||
comment,
|
comment,
|
||||||
{
|
{
|
||||||
'id': lambda: comment.comment_id,
|
'id': lambda: comment.comment_id,
|
||||||
'user': lambda: users.serialize_micro_user(comment.user, auth_user),
|
'user':
|
||||||
|
lambda: users.serialize_micro_user(comment.user, auth_user),
|
||||||
'postId': lambda: comment.post.post_id,
|
'postId': lambda: comment.post.post_id,
|
||||||
'version': lambda: comment.version,
|
'version': lambda: comment.version,
|
||||||
'text': lambda: comment.text,
|
'text': lambda: comment.text,
|
||||||
|
@ -22,6 +32,7 @@ def serialize_comment(comment, auth_user, options=None):
|
||||||
},
|
},
|
||||||
options)
|
options)
|
||||||
|
|
||||||
|
|
||||||
def try_get_comment_by_id(comment_id):
|
def try_get_comment_by_id(comment_id):
|
||||||
try:
|
try:
|
||||||
comment_id = int(comment_id)
|
comment_id = int(comment_id)
|
||||||
|
@ -32,12 +43,14 @@ def try_get_comment_by_id(comment_id):
|
||||||
.filter(db.Comment.comment_id == comment_id) \
|
.filter(db.Comment.comment_id == comment_id) \
|
||||||
.one_or_none()
|
.one_or_none()
|
||||||
|
|
||||||
|
|
||||||
def get_comment_by_id(comment_id):
|
def get_comment_by_id(comment_id):
|
||||||
comment = try_get_comment_by_id(comment_id)
|
comment = try_get_comment_by_id(comment_id)
|
||||||
if comment:
|
if comment:
|
||||||
return comment
|
return comment
|
||||||
raise CommentNotFoundError('Comment %r not found.' % comment_id)
|
raise CommentNotFoundError('Comment %r not found.' % comment_id)
|
||||||
|
|
||||||
|
|
||||||
def create_comment(user, post, text):
|
def create_comment(user, post, text):
|
||||||
comment = db.Comment()
|
comment = db.Comment()
|
||||||
comment.user = user
|
comment.user = user
|
||||||
|
@ -46,6 +59,7 @@ def create_comment(user, post, text):
|
||||||
comment.creation_time = datetime.datetime.utcnow()
|
comment.creation_time = datetime.datetime.utcnow()
|
||||||
return comment
|
return comment
|
||||||
|
|
||||||
|
|
||||||
def update_comment_text(comment, text):
|
def update_comment_text(comment, text):
|
||||||
assert comment
|
assert comment
|
||||||
if not text:
|
if not text:
|
||||||
|
|
|
@ -2,7 +2,10 @@ import datetime
|
||||||
from szurubooru import db, errors
|
from szurubooru import db, errors
|
||||||
from szurubooru.func import scores
|
from szurubooru.func import scores
|
||||||
|
|
||||||
class InvalidFavoriteTargetError(errors.ValidationError): pass
|
|
||||||
|
class InvalidFavoriteTargetError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _get_table_info(entity):
|
def _get_table_info(entity):
|
||||||
assert entity
|
assert entity
|
||||||
|
@ -11,16 +14,19 @@ def _get_table_info(entity):
|
||||||
return db.PostFavorite, lambda table: table.post_id
|
return db.PostFavorite, lambda table: table.post_id
|
||||||
raise InvalidFavoriteTargetError()
|
raise InvalidFavoriteTargetError()
|
||||||
|
|
||||||
|
|
||||||
def _get_fav_entity(entity, user):
|
def _get_fav_entity(entity, user):
|
||||||
assert entity
|
assert entity
|
||||||
assert user
|
assert user
|
||||||
return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
|
return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
|
||||||
|
|
||||||
|
|
||||||
def has_favorited(entity, user):
|
def has_favorited(entity, user):
|
||||||
assert entity
|
assert entity
|
||||||
assert user
|
assert user
|
||||||
return _get_fav_entity(entity, user) is not None
|
return _get_fav_entity(entity, user) is not None
|
||||||
|
|
||||||
|
|
||||||
def unset_favorite(entity, user):
|
def unset_favorite(entity, user):
|
||||||
assert entity
|
assert entity
|
||||||
assert user
|
assert user
|
||||||
|
@ -28,6 +34,7 @@ def unset_favorite(entity, user):
|
||||||
if fav_entity:
|
if fav_entity:
|
||||||
db.session.delete(fav_entity)
|
db.session.delete(fav_entity)
|
||||||
|
|
||||||
|
|
||||||
def set_favorite(entity, user):
|
def set_favorite(entity, user):
|
||||||
assert entity
|
assert entity
|
||||||
assert user
|
assert user
|
||||||
|
|
|
@ -1,20 +1,25 @@
|
||||||
import os
|
import os
|
||||||
from szurubooru import config
|
from szurubooru import config
|
||||||
|
|
||||||
|
|
||||||
def _get_full_path(path):
|
def _get_full_path(path):
|
||||||
return os.path.join(config.config['data_dir'], path)
|
return os.path.join(config.config['data_dir'], path)
|
||||||
|
|
||||||
|
|
||||||
def delete(path):
|
def delete(path):
|
||||||
full_path = _get_full_path(path)
|
full_path = _get_full_path(path)
|
||||||
if os.path.exists(full_path):
|
if os.path.exists(full_path):
|
||||||
os.unlink(full_path)
|
os.unlink(full_path)
|
||||||
|
|
||||||
|
|
||||||
def has(path):
|
def has(path):
|
||||||
return os.path.exists(_get_full_path(path))
|
return os.path.exists(_get_full_path(path))
|
||||||
|
|
||||||
|
|
||||||
def move(source_path, target_path):
|
def move(source_path, target_path):
|
||||||
return os.rename(_get_full_path(source_path), _get_full_path(target_path))
|
return os.rename(_get_full_path(source_path), _get_full_path(target_path))
|
||||||
|
|
||||||
|
|
||||||
def get(path):
|
def get(path):
|
||||||
full_path = _get_full_path(path)
|
full_path = _get_full_path(path)
|
||||||
if not os.path.exists(full_path):
|
if not os.path.exists(full_path):
|
||||||
|
@ -22,6 +27,7 @@ def get(path):
|
||||||
with open(full_path, 'rb') as handle:
|
with open(full_path, 'rb') as handle:
|
||||||
return handle.read()
|
return handle.read()
|
||||||
|
|
||||||
|
|
||||||
def save(path, content):
|
def save(path, content):
|
||||||
full_path = _get_full_path(path)
|
full_path = _get_full_path(path)
|
||||||
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
||||||
|
|
|
@ -6,11 +6,14 @@ import math
|
||||||
from szurubooru import errors
|
from szurubooru import errors
|
||||||
from szurubooru.func import mime, util
|
from szurubooru.func import mime, util
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
_SCALE_FIT_FMT = \
|
_SCALE_FIT_FMT = \
|
||||||
r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)'
|
r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)'
|
||||||
|
|
||||||
|
|
||||||
class Image(object):
|
class Image(object):
|
||||||
def __init__(self, content):
|
def __init__(self, content):
|
||||||
self.content = content
|
self.content = content
|
||||||
|
@ -38,12 +41,13 @@ class Image(object):
|
||||||
'-',
|
'-',
|
||||||
]
|
]
|
||||||
if 'duration' in self.info['format'] \
|
if 'duration' in self.info['format'] \
|
||||||
and float(self.info['format']['duration']) > 3 \
|
|
||||||
and self.info['format']['format_name'] != 'swf':
|
and self.info['format']['format_name'] != 'swf':
|
||||||
cli = [
|
duration = float(self.info['format']['duration'])
|
||||||
'-ss',
|
if duration > 3:
|
||||||
'%d' % math.floor(float(self.info['format']['duration']) * 0.3),
|
cli = [
|
||||||
] + cli
|
'-ss',
|
||||||
|
'%d' % math.floor(duration * 0.3),
|
||||||
|
] + cli
|
||||||
self.content = self._execute(cli)
|
self.content = self._execute(cli)
|
||||||
assert self.content
|
assert self.content
|
||||||
self._reload_info()
|
self._reload_info()
|
||||||
|
|
|
@ -2,6 +2,7 @@ import smtplib
|
||||||
import email.mime.text
|
import email.mime.text
|
||||||
from szurubooru import config
|
from szurubooru import config
|
||||||
|
|
||||||
|
|
||||||
def send_mail(sender, recipient, subject, body):
|
def send_mail(sender, recipient, subject, body):
|
||||||
msg = email.mime.text.MIMEText(body)
|
msg = email.mime.text.MIMEText(body)
|
||||||
msg['Subject'] = subject
|
msg['Subject'] = subject
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import re
|
import re
|
||||||
|
|
||||||
# pylint: disable=too-many-return-statements
|
|
||||||
def get_mime_type(content):
|
def get_mime_type(content):
|
||||||
if not content:
|
if not content:
|
||||||
return 'application/octet-stream'
|
return 'application/octet-stream'
|
||||||
|
@ -25,6 +25,7 @@ def get_mime_type(content):
|
||||||
|
|
||||||
return 'application/octet-stream'
|
return 'application/octet-stream'
|
||||||
|
|
||||||
|
|
||||||
def get_extension(mime_type):
|
def get_extension(mime_type):
|
||||||
extension_map = {
|
extension_map = {
|
||||||
'application/x-shockwave-flash': 'swf',
|
'application/x-shockwave-flash': 'swf',
|
||||||
|
@ -37,15 +38,19 @@ def get_extension(mime_type):
|
||||||
}
|
}
|
||||||
return extension_map.get((mime_type or '').strip().lower(), None)
|
return extension_map.get((mime_type or '').strip().lower(), None)
|
||||||
|
|
||||||
|
|
||||||
def is_flash(mime_type):
|
def is_flash(mime_type):
|
||||||
return mime_type.lower() == 'application/x-shockwave-flash'
|
return mime_type.lower() == 'application/x-shockwave-flash'
|
||||||
|
|
||||||
|
|
||||||
def is_video(mime_type):
|
def is_video(mime_type):
|
||||||
return mime_type.lower() in ('application/ogg', 'video/mp4', 'video/webm')
|
return mime_type.lower() in ('application/ogg', 'video/mp4', 'video/webm')
|
||||||
|
|
||||||
|
|
||||||
def is_image(mime_type):
|
def is_image(mime_type):
|
||||||
return mime_type.lower() in ('image/jpeg', 'image/png', 'image/gif')
|
return mime_type.lower() in ('image/jpeg', 'image/png', 'image/gif')
|
||||||
|
|
||||||
|
|
||||||
def is_animated_gif(content):
|
def is_animated_gif(content):
|
||||||
return get_mime_type(content) == 'image/gif' \
|
return get_mime_type(content) == 'image/gif' \
|
||||||
and len(re.findall(b'\x21\xF9\x04.{4}\x00[\x2C\x21]', content)) > 1
|
and len(re.findall(b'\x21\xF9\x04.{4}\x00[\x2C\x21]', content)) > 1
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
|
||||||
|
|
||||||
def download(url):
|
def download(url):
|
||||||
assert url
|
assert url
|
||||||
request = urllib.request.Request(url)
|
request = urllib.request.Request(url)
|
||||||
|
|
|
@ -4,37 +4,71 @@ from szurubooru import config, db, errors
|
||||||
from szurubooru.func import (
|
from szurubooru.func import (
|
||||||
users, snapshots, scores, comments, tags, util, mime, images, files)
|
users, snapshots, scores, comments, tags, util, mime, images, files)
|
||||||
|
|
||||||
|
|
||||||
EMPTY_PIXEL = \
|
EMPTY_PIXEL = \
|
||||||
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \
|
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \
|
||||||
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
|
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
|
||||||
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
|
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
|
||||||
|
|
||||||
class PostNotFoundError(errors.NotFoundError): pass
|
|
||||||
class PostAlreadyFeaturedError(errors.ValidationError): pass
|
class PostNotFoundError(errors.NotFoundError):
|
||||||
class PostAlreadyUploadedError(errors.ValidationError): pass
|
pass
|
||||||
class InvalidPostIdError(errors.ValidationError): pass
|
|
||||||
class InvalidPostSafetyError(errors.ValidationError): pass
|
|
||||||
class InvalidPostSourceError(errors.ValidationError): pass
|
class PostAlreadyFeaturedError(errors.ValidationError):
|
||||||
class InvalidPostContentError(errors.ValidationError): pass
|
pass
|
||||||
class InvalidPostRelationError(errors.ValidationError): pass
|
|
||||||
class InvalidPostNoteError(errors.ValidationError): pass
|
|
||||||
class InvalidPostFlagError(errors.ValidationError): pass
|
class PostAlreadyUploadedError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidPostIdError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidPostSafetyError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidPostSourceError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidPostContentError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidPostRelationError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidPostNoteError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidPostFlagError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
SAFETY_MAP = {
|
SAFETY_MAP = {
|
||||||
db.Post.SAFETY_SAFE: 'safe',
|
db.Post.SAFETY_SAFE: 'safe',
|
||||||
db.Post.SAFETY_SKETCHY: 'sketchy',
|
db.Post.SAFETY_SKETCHY: 'sketchy',
|
||||||
db.Post.SAFETY_UNSAFE: 'unsafe',
|
db.Post.SAFETY_UNSAFE: 'unsafe',
|
||||||
}
|
}
|
||||||
|
|
||||||
TYPE_MAP = {
|
TYPE_MAP = {
|
||||||
db.Post.TYPE_IMAGE: 'image',
|
db.Post.TYPE_IMAGE: 'image',
|
||||||
db.Post.TYPE_ANIMATION: 'animation',
|
db.Post.TYPE_ANIMATION: 'animation',
|
||||||
db.Post.TYPE_VIDEO: 'video',
|
db.Post.TYPE_VIDEO: 'video',
|
||||||
db.Post.TYPE_FLASH: 'flash',
|
db.Post.TYPE_FLASH: 'flash',
|
||||||
}
|
}
|
||||||
|
|
||||||
FLAG_MAP = {
|
FLAG_MAP = {
|
||||||
db.Post.FLAG_LOOP: 'loop',
|
db.Post.FLAG_LOOP: 'loop',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_post_content_url(post):
|
def get_post_content_url(post):
|
||||||
assert post
|
assert post
|
||||||
return '%s/posts/%d.%s' % (
|
return '%s/posts/%d.%s' % (
|
||||||
|
@ -42,25 +76,30 @@ def get_post_content_url(post):
|
||||||
post.post_id,
|
post.post_id,
|
||||||
mime.get_extension(post.mime_type) or 'dat')
|
mime.get_extension(post.mime_type) or 'dat')
|
||||||
|
|
||||||
|
|
||||||
def get_post_thumbnail_url(post):
|
def get_post_thumbnail_url(post):
|
||||||
assert post
|
assert post
|
||||||
return '%s/generated-thumbnails/%d.jpg' % (
|
return '%s/generated-thumbnails/%d.jpg' % (
|
||||||
config.config['data_url'].rstrip('/'),
|
config.config['data_url'].rstrip('/'),
|
||||||
post.post_id)
|
post.post_id)
|
||||||
|
|
||||||
|
|
||||||
def get_post_content_path(post):
|
def get_post_content_path(post):
|
||||||
assert post
|
assert post
|
||||||
return 'posts/%d.%s' % (
|
return 'posts/%d.%s' % (
|
||||||
post.post_id, mime.get_extension(post.mime_type) or 'dat')
|
post.post_id, mime.get_extension(post.mime_type) or 'dat')
|
||||||
|
|
||||||
|
|
||||||
def get_post_thumbnail_path(post):
|
def get_post_thumbnail_path(post):
|
||||||
assert post
|
assert post
|
||||||
return 'generated-thumbnails/%d.jpg' % (post.post_id)
|
return 'generated-thumbnails/%d.jpg' % (post.post_id)
|
||||||
|
|
||||||
|
|
||||||
def get_post_thumbnail_backup_path(post):
|
def get_post_thumbnail_backup_path(post):
|
||||||
assert post
|
assert post
|
||||||
return 'posts/custom-thumbnails/%d.dat' % (post.post_id)
|
return 'posts/custom-thumbnails/%d.dat' % (post.post_id)
|
||||||
|
|
||||||
|
|
||||||
def serialize_note(note):
|
def serialize_note(note):
|
||||||
assert note
|
assert note
|
||||||
return {
|
return {
|
||||||
|
@ -68,6 +107,7 @@ def serialize_note(note):
|
||||||
'text': note.text,
|
'text': note.text,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def serialize_post(post, auth_user, options=None):
|
def serialize_post(post, auth_user, options=None):
|
||||||
return util.serialize_entity(
|
return util.serialize_entity(
|
||||||
post,
|
post,
|
||||||
|
@ -93,17 +133,17 @@ def serialize_post(post, auth_user, options=None):
|
||||||
{
|
{
|
||||||
post['id']:
|
post['id']:
|
||||||
post for post in [
|
post for post in [
|
||||||
serialize_micro_post(rel, auth_user) \
|
serialize_micro_post(rel, auth_user)
|
||||||
for rel in post.relations
|
for rel in post.relations]
|
||||||
]
|
|
||||||
}.values(),
|
}.values(),
|
||||||
key=lambda post: post['id']),
|
key=lambda post: post['id']),
|
||||||
'user': lambda: users.serialize_micro_user(post.user, auth_user),
|
'user': lambda: users.serialize_micro_user(post.user, auth_user),
|
||||||
'score': lambda: post.score,
|
'score': lambda: post.score,
|
||||||
'ownScore': lambda: scores.get_score(post, auth_user),
|
'ownScore': lambda: scores.get_score(post, auth_user),
|
||||||
'ownFavorite': lambda: len(
|
'ownFavorite': lambda: len([
|
||||||
[user for user in post.favorited_by \
|
user for user in post.favorited_by
|
||||||
if user.user_id == auth_user.user_id]) > 0,
|
if user.user_id == auth_user.user_id]
|
||||||
|
) > 0,
|
||||||
'tagCount': lambda: post.tag_count,
|
'tagCount': lambda: post.tag_count,
|
||||||
'favoriteCount': lambda: post.favorite_count,
|
'favoriteCount': lambda: post.favorite_count,
|
||||||
'commentCount': lambda: post.comment_count,
|
'commentCount': lambda: post.comment_count,
|
||||||
|
@ -112,31 +152,35 @@ def serialize_post(post, auth_user, options=None):
|
||||||
'featureCount': lambda: post.feature_count,
|
'featureCount': lambda: post.feature_count,
|
||||||
'lastFeatureTime': lambda: post.last_feature_time,
|
'lastFeatureTime': lambda: post.last_feature_time,
|
||||||
'favoritedBy': lambda: [
|
'favoritedBy': lambda: [
|
||||||
users.serialize_micro_user(rel.user, auth_user) \
|
users.serialize_micro_user(rel.user, auth_user)
|
||||||
for rel in post.favorited_by],
|
for rel in post.favorited_by
|
||||||
|
],
|
||||||
'hasCustomThumbnail':
|
'hasCustomThumbnail':
|
||||||
lambda: files.has(get_post_thumbnail_backup_path(post)),
|
lambda: files.has(get_post_thumbnail_backup_path(post)),
|
||||||
'notes': lambda: sorted(
|
'notes': lambda: sorted(
|
||||||
[serialize_note(note) for note in post.notes],
|
[serialize_note(note) for note in post.notes],
|
||||||
key=lambda x: x['polygon']),
|
key=lambda x: x['polygon']),
|
||||||
'comments': lambda: [
|
'comments': lambda: [
|
||||||
comments.serialize_comment(comment, auth_user) \
|
comments.serialize_comment(comment, auth_user)
|
||||||
for comment in sorted(
|
for comment in sorted(
|
||||||
post.comments,
|
post.comments,
|
||||||
key=lambda comment: comment.creation_time)],
|
key=lambda comment: comment.creation_time)],
|
||||||
'snapshots': lambda: snapshots.get_serialized_history(post),
|
'snapshots': lambda: snapshots.get_serialized_history(post),
|
||||||
},
|
},
|
||||||
options)
|
options)
|
||||||
|
|
||||||
|
|
||||||
def serialize_micro_post(post, auth_user):
|
def serialize_micro_post(post, auth_user):
|
||||||
return serialize_post(
|
return serialize_post(
|
||||||
post,
|
post,
|
||||||
auth_user=auth_user,
|
auth_user=auth_user,
|
||||||
options=['id', 'thumbnailUrl'])
|
options=['id', 'thumbnailUrl'])
|
||||||
|
|
||||||
|
|
||||||
def get_post_count():
|
def get_post_count():
|
||||||
return db.session.query(sqlalchemy.func.count(db.Post.post_id)).one()[0]
|
return db.session.query(sqlalchemy.func.count(db.Post.post_id)).one()[0]
|
||||||
|
|
||||||
|
|
||||||
def try_get_post_by_id(post_id):
|
def try_get_post_by_id(post_id):
|
||||||
try:
|
try:
|
||||||
post_id = int(post_id)
|
post_id = int(post_id)
|
||||||
|
@ -147,22 +191,26 @@ def try_get_post_by_id(post_id):
|
||||||
.filter(db.Post.post_id == post_id) \
|
.filter(db.Post.post_id == post_id) \
|
||||||
.one_or_none()
|
.one_or_none()
|
||||||
|
|
||||||
|
|
||||||
def get_post_by_id(post_id):
|
def get_post_by_id(post_id):
|
||||||
post = try_get_post_by_id(post_id)
|
post = try_get_post_by_id(post_id)
|
||||||
if not post:
|
if not post:
|
||||||
raise PostNotFoundError('Post %r not found.' % post_id)
|
raise PostNotFoundError('Post %r not found.' % post_id)
|
||||||
return post
|
return post
|
||||||
|
|
||||||
|
|
||||||
def try_get_current_post_feature():
|
def try_get_current_post_feature():
|
||||||
return db.session \
|
return db.session \
|
||||||
.query(db.PostFeature) \
|
.query(db.PostFeature) \
|
||||||
.order_by(db.PostFeature.time.desc()) \
|
.order_by(db.PostFeature.time.desc()) \
|
||||||
.first()
|
.first()
|
||||||
|
|
||||||
|
|
||||||
def try_get_featured_post():
|
def try_get_featured_post():
|
||||||
post_feature = try_get_current_post_feature()
|
post_feature = try_get_current_post_feature()
|
||||||
return post_feature.post if post_feature else None
|
return post_feature.post if post_feature else None
|
||||||
|
|
||||||
|
|
||||||
def create_post(content, tag_names, user):
|
def create_post(content, tag_names, user):
|
||||||
post = db.Post()
|
post = db.Post()
|
||||||
post.safety = db.Post.SAFETY_SAFE
|
post.safety = db.Post.SAFETY_SAFE
|
||||||
|
@ -181,6 +229,7 @@ def create_post(content, tag_names, user):
|
||||||
new_tags = update_post_tags(post, tag_names)
|
new_tags = update_post_tags(post, tag_names)
|
||||||
return (post, new_tags)
|
return (post, new_tags)
|
||||||
|
|
||||||
|
|
||||||
def update_post_safety(post, safety):
|
def update_post_safety(post, safety):
|
||||||
assert post
|
assert post
|
||||||
safety = util.flip(SAFETY_MAP).get(safety, None)
|
safety = util.flip(SAFETY_MAP).get(safety, None)
|
||||||
|
@ -189,12 +238,14 @@ def update_post_safety(post, safety):
|
||||||
'Safety can be either of %r.' % list(SAFETY_MAP.values()))
|
'Safety can be either of %r.' % list(SAFETY_MAP.values()))
|
||||||
post.safety = safety
|
post.safety = safety
|
||||||
|
|
||||||
|
|
||||||
def update_post_source(post, source):
|
def update_post_source(post, source):
|
||||||
assert post
|
assert post
|
||||||
if util.value_exceeds_column_size(source, db.Post.source):
|
if util.value_exceeds_column_size(source, db.Post.source):
|
||||||
raise InvalidPostSourceError('Source is too long.')
|
raise InvalidPostSourceError('Source is too long.')
|
||||||
post.source = source
|
post.source = source
|
||||||
|
|
||||||
|
|
||||||
def update_post_content(post, content):
|
def update_post_content(post, content):
|
||||||
assert post
|
assert post
|
||||||
if not content:
|
if not content:
|
||||||
|
@ -210,7 +261,8 @@ def update_post_content(post, content):
|
||||||
elif mime.is_video(post.mime_type):
|
elif mime.is_video(post.mime_type):
|
||||||
post.type = db.Post.TYPE_VIDEO
|
post.type = db.Post.TYPE_VIDEO
|
||||||
else:
|
else:
|
||||||
raise InvalidPostContentError('Unhandled file type: %r' % post.mime_type)
|
raise InvalidPostContentError(
|
||||||
|
'Unhandled file type: %r' % post.mime_type)
|
||||||
|
|
||||||
post.checksum = util.get_md5(content)
|
post.checksum = util.get_md5(content)
|
||||||
other_post = db.session \
|
other_post = db.session \
|
||||||
|
@ -236,6 +288,7 @@ def update_post_content(post, content):
|
||||||
files.save(get_post_content_path(post), content)
|
files.save(get_post_content_path(post), content)
|
||||||
update_post_thumbnail(post, content=None, do_delete=False)
|
update_post_thumbnail(post, content=None, do_delete=False)
|
||||||
|
|
||||||
|
|
||||||
def update_post_thumbnail(post, content=None, do_delete=True):
|
def update_post_thumbnail(post, content=None, do_delete=True):
|
||||||
assert post
|
assert post
|
||||||
if not content:
|
if not content:
|
||||||
|
@ -246,6 +299,7 @@ def update_post_thumbnail(post, content=None, do_delete=True):
|
||||||
files.save(get_post_thumbnail_backup_path(post), content)
|
files.save(get_post_thumbnail_backup_path(post), content)
|
||||||
generate_post_thumbnail(post)
|
generate_post_thumbnail(post)
|
||||||
|
|
||||||
|
|
||||||
def generate_post_thumbnail(post):
|
def generate_post_thumbnail(post):
|
||||||
assert post
|
assert post
|
||||||
if files.has(get_post_thumbnail_backup_path(post)):
|
if files.has(get_post_thumbnail_backup_path(post)):
|
||||||
|
@ -261,12 +315,14 @@ def generate_post_thumbnail(post):
|
||||||
except errors.ProcessingError:
|
except errors.ProcessingError:
|
||||||
files.save(get_post_thumbnail_path(post), EMPTY_PIXEL)
|
files.save(get_post_thumbnail_path(post), EMPTY_PIXEL)
|
||||||
|
|
||||||
|
|
||||||
def update_post_tags(post, tag_names):
|
def update_post_tags(post, tag_names):
|
||||||
assert post
|
assert post
|
||||||
existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
|
existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
|
||||||
post.tags = existing_tags + new_tags
|
post.tags = existing_tags + new_tags
|
||||||
return new_tags
|
return new_tags
|
||||||
|
|
||||||
|
|
||||||
def update_post_relations(post, new_post_ids):
|
def update_post_relations(post, new_post_ids):
|
||||||
assert post
|
assert post
|
||||||
old_posts = post.relations
|
old_posts = post.relations
|
||||||
|
@ -287,6 +343,7 @@ def update_post_relations(post, new_post_ids):
|
||||||
post.relations.append(relation)
|
post.relations.append(relation)
|
||||||
relation.relations.append(post)
|
relation.relations.append(post)
|
||||||
|
|
||||||
|
|
||||||
def update_post_notes(post, notes):
|
def update_post_notes(post, notes):
|
||||||
assert post
|
assert post
|
||||||
post.notes = []
|
post.notes = []
|
||||||
|
@ -323,6 +380,7 @@ def update_post_notes(post, notes):
|
||||||
post.notes.append(
|
post.notes.append(
|
||||||
db.PostNote(polygon=note['polygon'], text=str(note['text'])))
|
db.PostNote(polygon=note['polygon'], text=str(note['text'])))
|
||||||
|
|
||||||
|
|
||||||
def update_post_flags(post, flags):
|
def update_post_flags(post, flags):
|
||||||
assert post
|
assert post
|
||||||
target_flags = []
|
target_flags = []
|
||||||
|
@ -334,6 +392,7 @@ def update_post_flags(post, flags):
|
||||||
target_flags.append(flag)
|
target_flags.append(flag)
|
||||||
post.flags = target_flags
|
post.flags = target_flags
|
||||||
|
|
||||||
|
|
||||||
def feature_post(post, user):
|
def feature_post(post, user):
|
||||||
assert post
|
assert post
|
||||||
post_feature = db.PostFeature()
|
post_feature = db.PostFeature()
|
||||||
|
@ -342,6 +401,7 @@ def feature_post(post, user):
|
||||||
post_feature.user = user
|
post_feature.user = user
|
||||||
db.session.add(post_feature)
|
db.session.add(post_feature)
|
||||||
|
|
||||||
|
|
||||||
def delete(post):
|
def delete(post):
|
||||||
assert post
|
assert post
|
||||||
db.session.delete(post)
|
db.session.delete(post)
|
||||||
|
|
|
@ -2,8 +2,14 @@ import datetime
|
||||||
from szurubooru import db, errors
|
from szurubooru import db, errors
|
||||||
from szurubooru.func import favorites
|
from szurubooru.func import favorites
|
||||||
|
|
||||||
class InvalidScoreTargetError(errors.ValidationError): pass
|
|
||||||
class InvalidScoreValueError(errors.ValidationError): pass
|
class InvalidScoreTargetError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidScoreValueError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _get_table_info(entity):
|
def _get_table_info(entity):
|
||||||
assert entity
|
assert entity
|
||||||
|
@ -14,10 +20,12 @@ def _get_table_info(entity):
|
||||||
return db.CommentScore, lambda table: table.comment_id
|
return db.CommentScore, lambda table: table.comment_id
|
||||||
raise InvalidScoreTargetError()
|
raise InvalidScoreTargetError()
|
||||||
|
|
||||||
|
|
||||||
def _get_score_entity(entity, user):
|
def _get_score_entity(entity, user):
|
||||||
assert user
|
assert user
|
||||||
return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
|
return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
|
||||||
|
|
||||||
|
|
||||||
def delete_score(entity, user):
|
def delete_score(entity, user):
|
||||||
assert entity
|
assert entity
|
||||||
assert user
|
assert user
|
||||||
|
@ -25,6 +33,7 @@ def delete_score(entity, user):
|
||||||
if score_entity:
|
if score_entity:
|
||||||
db.session.delete(score_entity)
|
db.session.delete(score_entity)
|
||||||
|
|
||||||
|
|
||||||
def get_score(entity, user):
|
def get_score(entity, user):
|
||||||
assert entity
|
assert entity
|
||||||
assert user
|
assert user
|
||||||
|
@ -36,6 +45,7 @@ def get_score(entity, user):
|
||||||
.one_or_none()
|
.one_or_none()
|
||||||
return row[0] if row else 0
|
return row[0] if row else 0
|
||||||
|
|
||||||
|
|
||||||
def set_score(entity, user, score):
|
def set_score(entity, user, score):
|
||||||
assert entity
|
assert entity
|
||||||
assert user
|
assert user
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import datetime
|
import datetime
|
||||||
from szurubooru import db
|
from szurubooru import db
|
||||||
|
|
||||||
|
|
||||||
def get_tag_snapshot(tag):
|
def get_tag_snapshot(tag):
|
||||||
return {
|
return {
|
||||||
'names': [tag_name.name for tag_name in tag.names],
|
'names': [tag_name.name for tag_name in tag.names],
|
||||||
|
@ -9,6 +10,7 @@ def get_tag_snapshot(tag):
|
||||||
'implications': sorted(rel.first_name for rel in tag.implications),
|
'implications': sorted(rel.first_name for rel in tag.implications),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_post_snapshot(post):
|
def get_post_snapshot(post):
|
||||||
return {
|
return {
|
||||||
'source': post.source,
|
'source': post.source,
|
||||||
|
@ -25,6 +27,7 @@ def get_post_snapshot(post):
|
||||||
'featured': post.is_featured,
|
'featured': post.is_featured,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_tag_category_snapshot(category):
|
def get_tag_category_snapshot(category):
|
||||||
return {
|
return {
|
||||||
'name': category.name,
|
'name': category.name,
|
||||||
|
@ -32,6 +35,7 @@ def get_tag_category_snapshot(category):
|
||||||
'default': True if category.default else False,
|
'default': True if category.default else False,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_previous_snapshot(snapshot):
|
def get_previous_snapshot(snapshot):
|
||||||
assert snapshot
|
assert snapshot
|
||||||
return db.session \
|
return db.session \
|
||||||
|
@ -43,6 +47,7 @@ def get_previous_snapshot(snapshot):
|
||||||
.limit(1) \
|
.limit(1) \
|
||||||
.first()
|
.first()
|
||||||
|
|
||||||
|
|
||||||
def get_snapshots(entity):
|
def get_snapshots(entity):
|
||||||
assert entity
|
assert entity
|
||||||
resource_type, resource_id, _ = db.util.get_resource_info(entity)
|
resource_type, resource_id, _ = db.util.get_resource_info(entity)
|
||||||
|
@ -53,6 +58,7 @@ def get_snapshots(entity):
|
||||||
.order_by(db.Snapshot.creation_time.desc()) \
|
.order_by(db.Snapshot.creation_time.desc()) \
|
||||||
.all()
|
.all()
|
||||||
|
|
||||||
|
|
||||||
def serialize_snapshot(snapshot, earlier_snapshot=()):
|
def serialize_snapshot(snapshot, earlier_snapshot=()):
|
||||||
assert snapshot
|
assert snapshot
|
||||||
if earlier_snapshot is ():
|
if earlier_snapshot is ():
|
||||||
|
@ -67,6 +73,7 @@ def serialize_snapshot(snapshot, earlier_snapshot=()):
|
||||||
'time': snapshot.creation_time,
|
'time': snapshot.creation_time,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_serialized_history(entity):
|
def get_serialized_history(entity):
|
||||||
if not entity:
|
if not entity:
|
||||||
return []
|
return []
|
||||||
|
@ -77,6 +84,7 @@ def get_serialized_history(entity):
|
||||||
earlier_snapshot = snapshot
|
earlier_snapshot = snapshot
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def _save(operation, entity, auth_user):
|
def _save(operation, entity, auth_user):
|
||||||
assert operation
|
assert operation
|
||||||
assert entity
|
assert entity
|
||||||
|
@ -86,7 +94,8 @@ def _save(operation, entity, auth_user):
|
||||||
'post': get_post_snapshot,
|
'post': get_post_snapshot,
|
||||||
}
|
}
|
||||||
|
|
||||||
resource_type, resource_id, resource_repr = db.util.get_resource_info(entity)
|
resource_type, resource_id, resource_repr = (
|
||||||
|
db.util.get_resource_info(entity))
|
||||||
now = datetime.datetime.utcnow()
|
now = datetime.datetime.utcnow()
|
||||||
|
|
||||||
snapshot = db.Snapshot()
|
snapshot = db.Snapshot()
|
||||||
|
@ -118,14 +127,17 @@ def _save(operation, entity, auth_user):
|
||||||
else:
|
else:
|
||||||
db.session.add(snapshot)
|
db.session.add(snapshot)
|
||||||
|
|
||||||
|
|
||||||
def save_entity_creation(entity, auth_user):
|
def save_entity_creation(entity, auth_user):
|
||||||
assert entity
|
assert entity
|
||||||
_save(db.Snapshot.OPERATION_CREATED, entity, auth_user)
|
_save(db.Snapshot.OPERATION_CREATED, entity, auth_user)
|
||||||
|
|
||||||
|
|
||||||
def save_entity_modification(entity, auth_user):
|
def save_entity_modification(entity, auth_user):
|
||||||
assert entity
|
assert entity
|
||||||
_save(db.Snapshot.OPERATION_MODIFIED, entity, auth_user)
|
_save(db.Snapshot.OPERATION_MODIFIED, entity, auth_user)
|
||||||
|
|
||||||
|
|
||||||
def save_entity_deletion(entity, auth_user):
|
def save_entity_deletion(entity, auth_user):
|
||||||
assert entity
|
assert entity
|
||||||
_save(db.Snapshot.OPERATION_DELETED, entity, auth_user)
|
_save(db.Snapshot.OPERATION_DELETED, entity, auth_user)
|
||||||
|
|
|
@ -3,11 +3,26 @@ import sqlalchemy
|
||||||
from szurubooru import config, db, errors
|
from szurubooru import config, db, errors
|
||||||
from szurubooru.func import util, snapshots, cache
|
from szurubooru.func import util, snapshots, cache
|
||||||
|
|
||||||
class TagCategoryNotFoundError(errors.NotFoundError): pass
|
|
||||||
class TagCategoryAlreadyExistsError(errors.ValidationError): pass
|
class TagCategoryNotFoundError(errors.NotFoundError):
|
||||||
class TagCategoryIsInUseError(errors.ValidationError): pass
|
pass
|
||||||
class InvalidTagCategoryNameError(errors.ValidationError): pass
|
|
||||||
class InvalidTagCategoryColorError(errors.ValidationError): pass
|
|
||||||
|
class TagCategoryAlreadyExistsError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TagCategoryIsInUseError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidTagCategoryNameError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidTagCategoryColorError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _verify_name_validity(name):
|
def _verify_name_validity(name):
|
||||||
name_regex = config.config['tag_category_name_regex']
|
name_regex = config.config['tag_category_name_regex']
|
||||||
|
@ -15,6 +30,7 @@ def _verify_name_validity(name):
|
||||||
raise InvalidTagCategoryNameError(
|
raise InvalidTagCategoryNameError(
|
||||||
'Name must satisfy regex %r.' % name_regex)
|
'Name must satisfy regex %r.' % name_regex)
|
||||||
|
|
||||||
|
|
||||||
def serialize_category(category, options=None):
|
def serialize_category(category, options=None):
|
||||||
return util.serialize_entity(
|
return util.serialize_entity(
|
||||||
category,
|
category,
|
||||||
|
@ -28,6 +44,7 @@ def serialize_category(category, options=None):
|
||||||
},
|
},
|
||||||
options)
|
options)
|
||||||
|
|
||||||
|
|
||||||
def create_category(name, color):
|
def create_category(name, color):
|
||||||
category = db.TagCategory()
|
category = db.TagCategory()
|
||||||
update_category_name(category, name)
|
update_category_name(category, name)
|
||||||
|
@ -36,13 +53,15 @@ def create_category(name, color):
|
||||||
category.default = True
|
category.default = True
|
||||||
return category
|
return category
|
||||||
|
|
||||||
|
|
||||||
def update_category_name(category, name):
|
def update_category_name(category, name):
|
||||||
assert category
|
assert category
|
||||||
if not name:
|
if not name:
|
||||||
raise InvalidTagCategoryNameError('Name cannot be empty.')
|
raise InvalidTagCategoryNameError('Name cannot be empty.')
|
||||||
expr = sqlalchemy.func.lower(db.TagCategory.name) == name.lower()
|
expr = sqlalchemy.func.lower(db.TagCategory.name) == name.lower()
|
||||||
if category.tag_category_id:
|
if category.tag_category_id:
|
||||||
expr = expr & (db.TagCategory.tag_category_id != category.tag_category_id)
|
expr = expr & (
|
||||||
|
db.TagCategory.tag_category_id != category.tag_category_id)
|
||||||
already_exists = db.session.query(db.TagCategory).filter(expr).count() > 0
|
already_exists = db.session.query(db.TagCategory).filter(expr).count() > 0
|
||||||
if already_exists:
|
if already_exists:
|
||||||
raise TagCategoryAlreadyExistsError(
|
raise TagCategoryAlreadyExistsError(
|
||||||
|
@ -52,6 +71,7 @@ def update_category_name(category, name):
|
||||||
_verify_name_validity(name)
|
_verify_name_validity(name)
|
||||||
category.name = name
|
category.name = name
|
||||||
|
|
||||||
|
|
||||||
def update_category_color(category, color):
|
def update_category_color(category, color):
|
||||||
assert category
|
assert category
|
||||||
if not color:
|
if not color:
|
||||||
|
@ -62,24 +82,29 @@ def update_category_color(category, color):
|
||||||
raise InvalidTagCategoryColorError('Color is too long.')
|
raise InvalidTagCategoryColorError('Color is too long.')
|
||||||
category.color = color
|
category.color = color
|
||||||
|
|
||||||
|
|
||||||
def try_get_category_by_name(name):
|
def try_get_category_by_name(name):
|
||||||
return db.session \
|
return db.session \
|
||||||
.query(db.TagCategory) \
|
.query(db.TagCategory) \
|
||||||
.filter(sqlalchemy.func.lower(db.TagCategory.name) == name.lower()) \
|
.filter(sqlalchemy.func.lower(db.TagCategory.name) == name.lower()) \
|
||||||
.one_or_none()
|
.one_or_none()
|
||||||
|
|
||||||
|
|
||||||
def get_category_by_name(name):
|
def get_category_by_name(name):
|
||||||
category = try_get_category_by_name(name)
|
category = try_get_category_by_name(name)
|
||||||
if not category:
|
if not category:
|
||||||
raise TagCategoryNotFoundError('Tag category %r not found.' % name)
|
raise TagCategoryNotFoundError('Tag category %r not found.' % name)
|
||||||
return category
|
return category
|
||||||
|
|
||||||
|
|
||||||
def get_all_category_names():
|
def get_all_category_names():
|
||||||
return [row[0] for row in db.session.query(db.TagCategory.name).all()]
|
return [row[0] for row in db.session.query(db.TagCategory.name).all()]
|
||||||
|
|
||||||
|
|
||||||
def get_all_categories():
|
def get_all_categories():
|
||||||
return db.session.query(db.TagCategory).all()
|
return db.session.query(db.TagCategory).all()
|
||||||
|
|
||||||
|
|
||||||
def try_get_default_category():
|
def try_get_default_category():
|
||||||
key = 'default-tag-category'
|
key = 'default-tag-category'
|
||||||
if cache.has(key):
|
if cache.has(key):
|
||||||
|
@ -98,12 +123,14 @@ def try_get_default_category():
|
||||||
cache.put(key, category)
|
cache.put(key, category)
|
||||||
return category
|
return category
|
||||||
|
|
||||||
|
|
||||||
def get_default_category():
|
def get_default_category():
|
||||||
category = try_get_default_category()
|
category = try_get_default_category()
|
||||||
if not category:
|
if not category:
|
||||||
raise TagCategoryNotFoundError('No tag category created yet.')
|
raise TagCategoryNotFoundError('No tag category created yet.')
|
||||||
return category
|
return category
|
||||||
|
|
||||||
|
|
||||||
def set_default_category(category):
|
def set_default_category(category):
|
||||||
assert category
|
assert category
|
||||||
old_category = try_get_default_category()
|
old_category = try_get_default_category()
|
||||||
|
@ -111,6 +138,7 @@ def set_default_category(category):
|
||||||
old_category.default = False
|
old_category.default = False
|
||||||
category.default = True
|
category.default = True
|
||||||
|
|
||||||
|
|
||||||
def delete_category(category):
|
def delete_category(category):
|
||||||
assert category
|
assert category
|
||||||
if len(get_all_category_names()) == 1:
|
if len(get_all_category_names()) == 1:
|
||||||
|
|
|
@ -6,32 +6,57 @@ import sqlalchemy
|
||||||
from szurubooru import config, db, errors
|
from szurubooru import config, db, errors
|
||||||
from szurubooru.func import util, tag_categories, snapshots
|
from szurubooru.func import util, tag_categories, snapshots
|
||||||
|
|
||||||
class TagNotFoundError(errors.NotFoundError): pass
|
|
||||||
class TagAlreadyExistsError(errors.ValidationError): pass
|
class TagNotFoundError(errors.NotFoundError):
|
||||||
class TagIsInUseError(errors.ValidationError): pass
|
pass
|
||||||
class InvalidTagNameError(errors.ValidationError): pass
|
|
||||||
class InvalidTagRelationError(errors.ValidationError): pass
|
|
||||||
class InvalidTagCategoryError(errors.ValidationError): pass
|
class TagAlreadyExistsError(errors.ValidationError):
|
||||||
class InvalidTagDescriptionError(errors.ValidationError): pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TagIsInUseError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidTagNameError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidTagRelationError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidTagCategoryError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidTagDescriptionError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _verify_name_validity(name):
|
def _verify_name_validity(name):
|
||||||
name_regex = config.config['tag_name_regex']
|
name_regex = config.config['tag_name_regex']
|
||||||
if not re.match(name_regex, name):
|
if not re.match(name_regex, name):
|
||||||
raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex)
|
raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex)
|
||||||
|
|
||||||
def _get_plain_names(tag):
|
|
||||||
|
def _get_names(tag):
|
||||||
assert tag
|
assert tag
|
||||||
return [tag_name.name for tag_name in tag.names]
|
return [tag_name.name for tag_name in tag.names]
|
||||||
|
|
||||||
|
|
||||||
def _lower_list(names):
|
def _lower_list(names):
|
||||||
return [name.lower() for name in names]
|
return [name.lower() for name in names]
|
||||||
|
|
||||||
def _check_name_intersection(names1, names2):
|
|
||||||
return len(set(_lower_list(names1)).intersection(_lower_list(names2))) > 0
|
|
||||||
|
|
||||||
def _check_name_intersection_case_sensitive(names1, names2):
|
def _check_name_intersection(names1, names2, case_sensitive):
|
||||||
|
if not case_sensitive:
|
||||||
|
names1 = _lower_list(names1)
|
||||||
|
names2 = _lower_list(names2)
|
||||||
return len(set(names1).intersection(names2)) > 0
|
return len(set(names1).intersection(names2)) > 0
|
||||||
|
|
||||||
|
|
||||||
def sort_tags(tags):
|
def sort_tags(tags):
|
||||||
default_category = tag_categories.try_get_default_category()
|
default_category = tag_categories.try_get_default_category()
|
||||||
default_category_name = default_category.name if default_category else None
|
default_category_name = default_category.name if default_category else None
|
||||||
|
@ -43,6 +68,7 @@ def sort_tags(tags):
|
||||||
tag.names[0].name)
|
tag.names[0].name)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def serialize_tag(tag, options=None):
|
def serialize_tag(tag, options=None):
|
||||||
return util.serialize_entity(
|
return util.serialize_entity(
|
||||||
tag,
|
tag,
|
||||||
|
@ -55,15 +81,16 @@ def serialize_tag(tag, options=None):
|
||||||
'lastEditTime': lambda: tag.last_edit_time,
|
'lastEditTime': lambda: tag.last_edit_time,
|
||||||
'usages': lambda: tag.post_count,
|
'usages': lambda: tag.post_count,
|
||||||
'suggestions': lambda: [
|
'suggestions': lambda: [
|
||||||
relation.names[0].name \
|
relation.names[0].name
|
||||||
for relation in sort_tags(tag.suggestions)],
|
for relation in sort_tags(tag.suggestions)],
|
||||||
'implications': lambda: [
|
'implications': lambda: [
|
||||||
relation.names[0].name \
|
relation.names[0].name
|
||||||
for relation in sort_tags(tag.implications)],
|
for relation in sort_tags(tag.implications)],
|
||||||
'snapshots': lambda: snapshots.get_serialized_history(tag),
|
'snapshots': lambda: snapshots.get_serialized_history(tag),
|
||||||
},
|
},
|
||||||
options)
|
options)
|
||||||
|
|
||||||
|
|
||||||
def export_to_json():
|
def export_to_json():
|
||||||
tags = {}
|
tags = {}
|
||||||
categories = {}
|
categories = {}
|
||||||
|
@ -82,19 +109,19 @@ def export_to_json():
|
||||||
tags[result[0]] = {'names': []}
|
tags[result[0]] = {'names': []}
|
||||||
tags[result[0]]['names'].append(result[1])
|
tags[result[0]]['names'].append(result[1])
|
||||||
|
|
||||||
for result in db.session \
|
for result in (db.session
|
||||||
.query(db.TagSuggestion.parent_id, db.TagName.name) \
|
.query(db.TagSuggestion.parent_id, db.TagName.name)
|
||||||
.join(db.TagName, db.TagName.tag_id == db.TagSuggestion.child_id) \
|
.join(db.TagName, db.TagName.tag_id == db.TagSuggestion.child_id)
|
||||||
.all():
|
.all()):
|
||||||
if not 'suggestions' in tags[result[0]]:
|
if 'suggestions' not in tags[result[0]]:
|
||||||
tags[result[0]]['suggestions'] = []
|
tags[result[0]]['suggestions'] = []
|
||||||
tags[result[0]]['suggestions'].append(result[1])
|
tags[result[0]]['suggestions'].append(result[1])
|
||||||
|
|
||||||
for result in db.session \
|
for result in (db.session
|
||||||
.query(db.TagImplication.parent_id, db.TagName.name) \
|
.query(db.TagImplication.parent_id, db.TagName.name)
|
||||||
.join(db.TagName, db.TagName.tag_id == db.TagImplication.child_id) \
|
.join(db.TagName, db.TagName.tag_id == db.TagImplication.child_id)
|
||||||
.all():
|
.all()):
|
||||||
if not 'implications' in tags[result[0]]:
|
if 'implications' not in tags[result[0]]:
|
||||||
tags[result[0]]['implications'] = []
|
tags[result[0]]['implications'] = []
|
||||||
tags[result[0]]['implications'].append(result[1])
|
tags[result[0]]['implications'].append(result[1])
|
||||||
|
|
||||||
|
@ -114,12 +141,14 @@ def export_to_json():
|
||||||
with open(export_path, 'w') as handle:
|
with open(export_path, 'w') as handle:
|
||||||
handle.write(json.dumps(output, separators=(',', ':')))
|
handle.write(json.dumps(output, separators=(',', ':')))
|
||||||
|
|
||||||
|
|
||||||
def try_get_tag_by_name(name):
|
def try_get_tag_by_name(name):
|
||||||
return db.session \
|
return (db.session
|
||||||
.query(db.Tag) \
|
.query(db.Tag)
|
||||||
.join(db.TagName) \
|
.join(db.TagName)
|
||||||
.filter(sqlalchemy.func.lower(db.TagName.name) == name.lower()) \
|
.filter(sqlalchemy.func.lower(db.TagName.name) == name.lower())
|
||||||
.one_or_none()
|
.one_or_none())
|
||||||
|
|
||||||
|
|
||||||
def get_tag_by_name(name):
|
def get_tag_by_name(name):
|
||||||
tag = try_get_tag_by_name(name)
|
tag = try_get_tag_by_name(name)
|
||||||
|
@ -127,6 +156,7 @@ def get_tag_by_name(name):
|
||||||
raise TagNotFoundError('Tag %r not found.' % name)
|
raise TagNotFoundError('Tag %r not found.' % name)
|
||||||
return tag
|
return tag
|
||||||
|
|
||||||
|
|
||||||
def get_tags_by_names(names):
|
def get_tags_by_names(names):
|
||||||
names = util.icase_unique(names)
|
names = util.icase_unique(names)
|
||||||
if len(names) == 0:
|
if len(names) == 0:
|
||||||
|
@ -136,6 +166,7 @@ def get_tags_by_names(names):
|
||||||
expr = expr | (sqlalchemy.func.lower(db.TagName.name) == name.lower())
|
expr = expr | (sqlalchemy.func.lower(db.TagName.name) == name.lower())
|
||||||
return db.session.query(db.Tag).join(db.TagName).filter(expr).all()
|
return db.session.query(db.Tag).join(db.TagName).filter(expr).all()
|
||||||
|
|
||||||
|
|
||||||
def get_or_create_tags_by_names(names):
|
def get_or_create_tags_by_names(names):
|
||||||
names = util.icase_unique(names)
|
names = util.icase_unique(names)
|
||||||
existing_tags = get_tags_by_names(names)
|
existing_tags = get_tags_by_names(names)
|
||||||
|
@ -144,7 +175,8 @@ def get_or_create_tags_by_names(names):
|
||||||
for name in names:
|
for name in names:
|
||||||
found = False
|
found = False
|
||||||
for existing_tag in existing_tags:
|
for existing_tag in existing_tags:
|
||||||
if _check_name_intersection(_get_plain_names(existing_tag), [name]):
|
if _check_name_intersection(
|
||||||
|
_get_names(existing_tag), [name], False):
|
||||||
found = True
|
found = True
|
||||||
break
|
break
|
||||||
if not found:
|
if not found:
|
||||||
|
@ -157,32 +189,35 @@ def get_or_create_tags_by_names(names):
|
||||||
new_tags.append(new_tag)
|
new_tags.append(new_tag)
|
||||||
return existing_tags, new_tags
|
return existing_tags, new_tags
|
||||||
|
|
||||||
|
|
||||||
def get_tag_siblings(tag):
|
def get_tag_siblings(tag):
|
||||||
assert tag
|
assert tag
|
||||||
tag_alias = sqlalchemy.orm.aliased(db.Tag)
|
tag_alias = sqlalchemy.orm.aliased(db.Tag)
|
||||||
pt_alias1 = sqlalchemy.orm.aliased(db.PostTag)
|
pt_alias1 = sqlalchemy.orm.aliased(db.PostTag)
|
||||||
pt_alias2 = sqlalchemy.orm.aliased(db.PostTag)
|
pt_alias2 = sqlalchemy.orm.aliased(db.PostTag)
|
||||||
result = db.session \
|
result = (db.session
|
||||||
.query(tag_alias, sqlalchemy.func.count(pt_alias2.post_id)) \
|
.query(tag_alias, sqlalchemy.func.count(pt_alias2.post_id))
|
||||||
.join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id) \
|
.join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id)
|
||||||
.join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id) \
|
.join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id)
|
||||||
.filter(pt_alias2.tag_id == tag.tag_id) \
|
.filter(pt_alias2.tag_id == tag.tag_id)
|
||||||
.filter(pt_alias1.tag_id != tag.tag_id) \
|
.filter(pt_alias1.tag_id != tag.tag_id)
|
||||||
.group_by(tag_alias.tag_id) \
|
.group_by(tag_alias.tag_id)
|
||||||
.order_by(sqlalchemy.func.count(pt_alias2.post_id).desc()) \
|
.order_by(sqlalchemy.func.count(pt_alias2.post_id).desc())
|
||||||
.limit(50)
|
.limit(50))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def delete(source_tag):
|
def delete(source_tag):
|
||||||
assert source_tag
|
assert source_tag
|
||||||
db.session.execute(
|
db.session.execute(
|
||||||
sqlalchemy.sql.expression.delete(db.TagSuggestion) \
|
sqlalchemy.sql.expression.delete(db.TagSuggestion)
|
||||||
.where(db.TagSuggestion.child_id == source_tag.tag_id))
|
.where(db.TagSuggestion.child_id == source_tag.tag_id))
|
||||||
db.session.execute(
|
db.session.execute(
|
||||||
sqlalchemy.sql.expression.delete(db.TagImplication) \
|
sqlalchemy.sql.expression.delete(db.TagImplication)
|
||||||
.where(db.TagImplication.child_id == source_tag.tag_id))
|
.where(db.TagImplication.child_id == source_tag.tag_id))
|
||||||
db.session.delete(source_tag)
|
db.session.delete(source_tag)
|
||||||
|
|
||||||
|
|
||||||
def merge_tags(source_tag, target_tag):
|
def merge_tags(source_tag, target_tag):
|
||||||
assert source_tag
|
assert source_tag
|
||||||
assert target_tag
|
assert target_tag
|
||||||
|
@ -191,15 +226,16 @@ def merge_tags(source_tag, target_tag):
|
||||||
pt1 = db.PostTag
|
pt1 = db.PostTag
|
||||||
pt2 = sqlalchemy.orm.util.aliased(db.PostTag)
|
pt2 = sqlalchemy.orm.util.aliased(db.PostTag)
|
||||||
|
|
||||||
update_stmt = sqlalchemy.sql.expression.update(pt1) \
|
update_stmt = (sqlalchemy.sql.expression.update(pt1)
|
||||||
.where(db.PostTag.tag_id == source_tag.tag_id) \
|
.where(db.PostTag.tag_id == source_tag.tag_id)
|
||||||
.where(~sqlalchemy.exists() \
|
.where(~sqlalchemy.exists()
|
||||||
.where(pt2.post_id == pt1.post_id) \
|
.where(pt2.post_id == pt1.post_id)
|
||||||
.where(pt2.tag_id == target_tag.tag_id)) \
|
.where(pt2.tag_id == target_tag.tag_id))
|
||||||
.values(tag_id=target_tag.tag_id)
|
.values(tag_id=target_tag.tag_id))
|
||||||
db.session.execute(update_stmt)
|
db.session.execute(update_stmt)
|
||||||
delete(source_tag)
|
delete(source_tag)
|
||||||
|
|
||||||
|
|
||||||
def create_tag(names, category_name, suggestions, implications):
|
def create_tag(names, category_name, suggestions, implications):
|
||||||
tag = db.Tag()
|
tag = db.Tag()
|
||||||
tag.creation_time = datetime.datetime.utcnow()
|
tag.creation_time = datetime.datetime.utcnow()
|
||||||
|
@ -209,10 +245,12 @@ def create_tag(names, category_name, suggestions, implications):
|
||||||
update_tag_implications(tag, implications)
|
update_tag_implications(tag, implications)
|
||||||
return tag
|
return tag
|
||||||
|
|
||||||
|
|
||||||
def update_tag_category_name(tag, category_name):
|
def update_tag_category_name(tag, category_name):
|
||||||
assert tag
|
assert tag
|
||||||
tag.category = tag_categories.get_category_by_name(category_name)
|
tag.category = tag_categories.get_category_by_name(category_name)
|
||||||
|
|
||||||
|
|
||||||
def update_tag_names(tag, names):
|
def update_tag_names(tag, names):
|
||||||
assert tag
|
assert tag
|
||||||
names = util.icase_unique([name for name in names if name])
|
names = util.icase_unique([name for name in names if name])
|
||||||
|
@ -232,26 +270,29 @@ def update_tag_names(tag, names):
|
||||||
raise TagAlreadyExistsError(
|
raise TagAlreadyExistsError(
|
||||||
'One of names is already used by another tag.')
|
'One of names is already used by another tag.')
|
||||||
for tag_name in tag.names[:]:
|
for tag_name in tag.names[:]:
|
||||||
if not _check_name_intersection_case_sensitive([tag_name.name], names):
|
if not _check_name_intersection([tag_name.name], names, True):
|
||||||
tag.names.remove(tag_name)
|
tag.names.remove(tag_name)
|
||||||
for name in names:
|
for name in names:
|
||||||
if not _check_name_intersection_case_sensitive(_get_plain_names(tag), [name]):
|
if not _check_name_intersection(_get_names(tag), [name], True):
|
||||||
tag.names.append(db.TagName(name))
|
tag.names.append(db.TagName(name))
|
||||||
|
|
||||||
|
|
||||||
# TODO: what to do with relations that do not yet exist?
|
# TODO: what to do with relations that do not yet exist?
|
||||||
def update_tag_implications(tag, relations):
|
def update_tag_implications(tag, relations):
|
||||||
assert tag
|
assert tag
|
||||||
if _check_name_intersection(_get_plain_names(tag), relations):
|
if _check_name_intersection(_get_names(tag), relations, False):
|
||||||
raise InvalidTagRelationError('Tag cannot imply itself.')
|
raise InvalidTagRelationError('Tag cannot imply itself.')
|
||||||
tag.implications = get_tags_by_names(relations)
|
tag.implications = get_tags_by_names(relations)
|
||||||
|
|
||||||
|
|
||||||
# TODO: what to do with relations that do not yet exist?
|
# TODO: what to do with relations that do not yet exist?
|
||||||
def update_tag_suggestions(tag, relations):
|
def update_tag_suggestions(tag, relations):
|
||||||
assert tag
|
assert tag
|
||||||
if _check_name_intersection(_get_plain_names(tag), relations):
|
if _check_name_intersection(_get_names(tag), relations, False):
|
||||||
raise InvalidTagRelationError('Tag cannot suggest itself.')
|
raise InvalidTagRelationError('Tag cannot suggest itself.')
|
||||||
tag.suggestions = get_tags_by_names(relations)
|
tag.suggestions = get_tags_by_names(relations)
|
||||||
|
|
||||||
|
|
||||||
def update_tag_description(tag, description):
|
def update_tag_description(tag, description):
|
||||||
assert tag
|
assert tag
|
||||||
if util.value_exceeds_column_size(description, db.Tag.description):
|
if util.value_exceeds_column_size(description, db.Tag.description):
|
||||||
|
|
|
@ -4,17 +4,39 @@ from sqlalchemy import func
|
||||||
from szurubooru import config, db, errors
|
from szurubooru import config, db, errors
|
||||||
from szurubooru.func import auth, util, files, images
|
from szurubooru.func import auth, util, files, images
|
||||||
|
|
||||||
class UserNotFoundError(errors.NotFoundError): pass
|
|
||||||
class UserAlreadyExistsError(errors.ValidationError): pass
|
class UserNotFoundError(errors.NotFoundError):
|
||||||
class InvalidUserNameError(errors.ValidationError): pass
|
pass
|
||||||
class InvalidEmailError(errors.ValidationError): pass
|
|
||||||
class InvalidPasswordError(errors.ValidationError): pass
|
|
||||||
class InvalidRankError(errors.ValidationError): pass
|
class UserAlreadyExistsError(errors.ValidationError):
|
||||||
class InvalidAvatarError(errors.ValidationError): pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidUserNameError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidEmailError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidPasswordError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidRankError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidAvatarError(errors.ValidationError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def get_avatar_path(user_name):
|
def get_avatar_path(user_name):
|
||||||
return 'avatars/' + user_name.lower() + '.png'
|
return 'avatars/' + user_name.lower() + '.png'
|
||||||
|
|
||||||
|
|
||||||
def get_avatar_url(user):
|
def get_avatar_url(user):
|
||||||
assert user
|
assert user
|
||||||
if user.avatar_style == user.AVATAR_GRAVATAR:
|
if user.avatar_style == user.AVATAR_GRAVATAR:
|
||||||
|
@ -27,6 +49,7 @@ def get_avatar_url(user):
|
||||||
return '%s/avatars/%s.png' % (
|
return '%s/avatars/%s.png' % (
|
||||||
config.config['data_url'].rstrip('/'), user.name.lower())
|
config.config['data_url'].rstrip('/'), user.name.lower())
|
||||||
|
|
||||||
|
|
||||||
def get_email(user, auth_user, force_show_email):
|
def get_email(user, auth_user, force_show_email):
|
||||||
assert user
|
assert user
|
||||||
assert auth_user
|
assert auth_user
|
||||||
|
@ -36,6 +59,7 @@ def get_email(user, auth_user, force_show_email):
|
||||||
return False
|
return False
|
||||||
return user.email
|
return user.email
|
||||||
|
|
||||||
|
|
||||||
def get_liked_post_count(user, auth_user):
|
def get_liked_post_count(user, auth_user):
|
||||||
assert user
|
assert user
|
||||||
assert auth_user
|
assert auth_user
|
||||||
|
@ -43,6 +67,7 @@ def get_liked_post_count(user, auth_user):
|
||||||
return False
|
return False
|
||||||
return user.liked_post_count
|
return user.liked_post_count
|
||||||
|
|
||||||
|
|
||||||
def get_disliked_post_count(user, auth_user):
|
def get_disliked_post_count(user, auth_user):
|
||||||
assert user
|
assert user
|
||||||
assert auth_user
|
assert auth_user
|
||||||
|
@ -50,6 +75,7 @@ def get_disliked_post_count(user, auth_user):
|
||||||
return False
|
return False
|
||||||
return user.disliked_post_count
|
return user.disliked_post_count
|
||||||
|
|
||||||
|
|
||||||
def serialize_user(user, auth_user, options=None, force_show_email=False):
|
def serialize_user(user, auth_user, options=None, force_show_email=False):
|
||||||
return util.serialize_entity(
|
return util.serialize_entity(
|
||||||
user,
|
user,
|
||||||
|
@ -73,34 +99,40 @@ def serialize_user(user, auth_user, options=None, force_show_email=False):
|
||||||
},
|
},
|
||||||
options)
|
options)
|
||||||
|
|
||||||
|
|
||||||
def serialize_micro_user(user, auth_user):
|
def serialize_micro_user(user, auth_user):
|
||||||
return serialize_user(
|
return serialize_user(
|
||||||
user,
|
user,
|
||||||
auth_user=auth_user,
|
auth_user=auth_user,
|
||||||
options=['name', 'avatarUrl'])
|
options=['name', 'avatarUrl'])
|
||||||
|
|
||||||
|
|
||||||
def get_user_count():
|
def get_user_count():
|
||||||
return db.session.query(db.User).count()
|
return db.session.query(db.User).count()
|
||||||
|
|
||||||
|
|
||||||
def try_get_user_by_name(name):
|
def try_get_user_by_name(name):
|
||||||
return db.session \
|
return db.session \
|
||||||
.query(db.User) \
|
.query(db.User) \
|
||||||
.filter(func.lower(db.User.name) == func.lower(name)) \
|
.filter(func.lower(db.User.name) == func.lower(name)) \
|
||||||
.one_or_none()
|
.one_or_none()
|
||||||
|
|
||||||
|
|
||||||
def get_user_by_name(name):
|
def get_user_by_name(name):
|
||||||
user = try_get_user_by_name(name)
|
user = try_get_user_by_name(name)
|
||||||
if not user:
|
if not user:
|
||||||
raise UserNotFoundError('User %r not found.' % name)
|
raise UserNotFoundError('User %r not found.' % name)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def try_get_user_by_name_or_email(name_or_email):
|
def try_get_user_by_name_or_email(name_or_email):
|
||||||
return db.session \
|
return (db.session
|
||||||
.query(db.User) \
|
.query(db.User)
|
||||||
.filter(
|
.filter(
|
||||||
(func.lower(db.User.name) == func.lower(name_or_email))
|
(func.lower(db.User.name) == func.lower(name_or_email)) |
|
||||||
| (func.lower(db.User.email) == func.lower(name_or_email))) \
|
(func.lower(db.User.email) == func.lower(name_or_email)))
|
||||||
.one_or_none()
|
.one_or_none())
|
||||||
|
|
||||||
|
|
||||||
def get_user_by_name_or_email(name_or_email):
|
def get_user_by_name_or_email(name_or_email):
|
||||||
user = try_get_user_by_name_or_email(name_or_email)
|
user = try_get_user_by_name_or_email(name_or_email)
|
||||||
|
@ -108,6 +140,7 @@ def get_user_by_name_or_email(name_or_email):
|
||||||
raise UserNotFoundError('User %r not found.' % name_or_email)
|
raise UserNotFoundError('User %r not found.' % name_or_email)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def create_user(name, password, email):
|
def create_user(name, password, email):
|
||||||
user = db.User()
|
user = db.User()
|
||||||
update_user_name(user, name)
|
update_user_name(user, name)
|
||||||
|
@ -121,6 +154,7 @@ def create_user(name, password, email):
|
||||||
user.avatar_style = db.User.AVATAR_GRAVATAR
|
user.avatar_style = db.User.AVATAR_GRAVATAR
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def update_user_name(user, name):
|
def update_user_name(user, name):
|
||||||
assert user
|
assert user
|
||||||
if not name:
|
if not name:
|
||||||
|
@ -139,6 +173,7 @@ def update_user_name(user, name):
|
||||||
files.move(get_avatar_path(user.name), get_avatar_path(name))
|
files.move(get_avatar_path(user.name), get_avatar_path(name))
|
||||||
user.name = name
|
user.name = name
|
||||||
|
|
||||||
|
|
||||||
def update_user_password(user, password):
|
def update_user_password(user, password):
|
||||||
assert user
|
assert user
|
||||||
if not password:
|
if not password:
|
||||||
|
@ -150,6 +185,7 @@ def update_user_password(user, password):
|
||||||
user.password_salt = auth.create_password()
|
user.password_salt = auth.create_password()
|
||||||
user.password_hash = auth.get_password_hash(user.password_salt, password)
|
user.password_hash = auth.get_password_hash(user.password_salt, password)
|
||||||
|
|
||||||
|
|
||||||
def update_user_email(user, email):
|
def update_user_email(user, email):
|
||||||
assert user
|
assert user
|
||||||
if email:
|
if email:
|
||||||
|
@ -162,6 +198,7 @@ def update_user_email(user, email):
|
||||||
raise InvalidEmailError('E-mail is invalid.')
|
raise InvalidEmailError('E-mail is invalid.')
|
||||||
user.email = email
|
user.email = email
|
||||||
|
|
||||||
|
|
||||||
def update_user_rank(user, rank, auth_user):
|
def update_user_rank(user, rank, auth_user):
|
||||||
assert user
|
assert user
|
||||||
if not rank:
|
if not rank:
|
||||||
|
@ -178,6 +215,7 @@ def update_user_rank(user, rank, auth_user):
|
||||||
raise errors.AuthError('Trying to set higher rank than your own.')
|
raise errors.AuthError('Trying to set higher rank than your own.')
|
||||||
user.rank = rank
|
user.rank = rank
|
||||||
|
|
||||||
|
|
||||||
def update_user_avatar(user, avatar_style, avatar_content=None):
|
def update_user_avatar(user, avatar_style, avatar_content=None):
|
||||||
assert user
|
assert user
|
||||||
if avatar_style == 'gravatar':
|
if avatar_style == 'gravatar':
|
||||||
|
@ -199,10 +237,12 @@ def update_user_avatar(user, avatar_style, avatar_content=None):
|
||||||
'Avatar style %r is invalid. Valid avatar styles: %r.' % (
|
'Avatar style %r is invalid. Valid avatar styles: %r.' % (
|
||||||
avatar_style, ['gravatar', 'manual']))
|
avatar_style, ['gravatar', 'manual']))
|
||||||
|
|
||||||
|
|
||||||
def bump_user_login_time(user):
|
def bump_user_login_time(user):
|
||||||
assert user
|
assert user
|
||||||
user.last_login_time = datetime.datetime.utcnow()
|
user.last_login_time = datetime.datetime.utcnow()
|
||||||
|
|
||||||
|
|
||||||
def reset_user_password(user):
|
def reset_user_password(user):
|
||||||
assert user
|
assert user
|
||||||
password = auth.create_password()
|
password = auth.create_password()
|
||||||
|
|
|
@ -1,18 +1,22 @@
|
||||||
import os
|
import os
|
||||||
import datetime
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from datetime import datetime, timedelta
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from szurubooru import errors
|
from szurubooru import errors
|
||||||
|
|
||||||
|
|
||||||
def snake_case_to_lower_camel_case(text):
|
def snake_case_to_lower_camel_case(text):
|
||||||
components = text.split('_')
|
components = text.split('_')
|
||||||
return components[0].lower() + \
|
return components[0].lower() + \
|
||||||
''.join(word[0].upper() + word[1:].lower() for word in components[1:])
|
''.join(word[0].upper() + word[1:].lower() for word in components[1:])
|
||||||
|
|
||||||
|
|
||||||
def snake_case_to_upper_train_case(text):
|
def snake_case_to_upper_train_case(text):
|
||||||
return '-'.join(word[0].upper() + word[1:].lower() for word in text.split('_'))
|
return '-'.join(
|
||||||
|
word[0].upper() + word[1:].lower() for word in text.split('_'))
|
||||||
|
|
||||||
|
|
||||||
def snake_case_to_lower_camel_case_keys(source):
|
def snake_case_to_lower_camel_case_keys(source):
|
||||||
target = {}
|
target = {}
|
||||||
|
@ -20,9 +24,11 @@ def snake_case_to_lower_camel_case_keys(source):
|
||||||
target[snake_case_to_lower_camel_case(key)] = value
|
target[snake_case_to_lower_camel_case(key)] = value
|
||||||
return target
|
return target
|
||||||
|
|
||||||
|
|
||||||
def get_serialization_options(ctx):
|
def get_serialization_options(ctx):
|
||||||
return ctx.get_param_as_list('fields', required=False, default=None)
|
return ctx.get_param_as_list('fields', required=False, default=None)
|
||||||
|
|
||||||
|
|
||||||
def serialize_entity(entity, field_factories, options):
|
def serialize_entity(entity, field_factories, options):
|
||||||
if not entity:
|
if not entity:
|
||||||
return None
|
return None
|
||||||
|
@ -30,13 +36,14 @@ def serialize_entity(entity, field_factories, options):
|
||||||
options = field_factories.keys()
|
options = field_factories.keys()
|
||||||
ret = {}
|
ret = {}
|
||||||
for key in options:
|
for key in options:
|
||||||
if not key in field_factories:
|
if key not in field_factories:
|
||||||
raise errors.ValidationError('Invalid key: %r. Valid keys: %r.' % (
|
raise errors.ValidationError('Invalid key: %r. Valid keys: %r.' % (
|
||||||
key, list(sorted(field_factories.keys()))))
|
key, list(sorted(field_factories.keys()))))
|
||||||
factory = field_factories[key]
|
factory = field_factories[key]
|
||||||
ret[key] = factory()
|
ret[key] = factory()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def create_temp_file(**kwargs):
|
def create_temp_file(**kwargs):
|
||||||
(handle, path) = tempfile.mkstemp(**kwargs)
|
(handle, path) = tempfile.mkstemp(**kwargs)
|
||||||
|
@ -47,6 +54,7 @@ def create_temp_file(**kwargs):
|
||||||
finally:
|
finally:
|
||||||
os.remove(path)
|
os.remove(path)
|
||||||
|
|
||||||
|
|
||||||
def unalias_dict(input_dict):
|
def unalias_dict(input_dict):
|
||||||
output_dict = {}
|
output_dict = {}
|
||||||
for key_list, value in input_dict.items():
|
for key_list, value in input_dict.items():
|
||||||
|
@ -56,6 +64,7 @@ def unalias_dict(input_dict):
|
||||||
output_dict[key] = value
|
output_dict[key] = value
|
||||||
return output_dict
|
return output_dict
|
||||||
|
|
||||||
|
|
||||||
def get_md5(source):
|
def get_md5(source):
|
||||||
if not isinstance(source, bytes):
|
if not isinstance(source, bytes):
|
||||||
source = source.encode('utf-8')
|
source = source.encode('utf-8')
|
||||||
|
@ -63,57 +72,58 @@ def get_md5(source):
|
||||||
md5.update(source)
|
md5.update(source)
|
||||||
return md5.hexdigest()
|
return md5.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def flip(source):
|
def flip(source):
|
||||||
return {v: k for k, v in source.items()}
|
return {v: k for k, v in source.items()}
|
||||||
|
|
||||||
|
|
||||||
def is_valid_email(email):
|
def is_valid_email(email):
|
||||||
''' Return whether given email address is valid or empty. '''
|
''' Return whether given email address is valid or empty. '''
|
||||||
return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email)
|
return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email)
|
||||||
|
|
||||||
class dotdict(dict): # pylint: disable=invalid-name
|
|
||||||
|
class dotdict(dict): # pylint: disable=invalid-name
|
||||||
''' dot.notation access to dictionary attributes. '''
|
''' dot.notation access to dictionary attributes. '''
|
||||||
def __getattr__(self, attr):
|
def __getattr__(self, attr):
|
||||||
return self.get(attr)
|
return self.get(attr)
|
||||||
__setattr__ = dict.__setitem__
|
__setattr__ = dict.__setitem__
|
||||||
__delattr__ = dict.__delitem__
|
__delattr__ = dict.__delitem__
|
||||||
|
|
||||||
|
|
||||||
def parse_time_range(value):
|
def parse_time_range(value):
|
||||||
''' Return tuple containing min/max time for given text representation. '''
|
''' Return tuple containing min/max time for given text representation. '''
|
||||||
one_day = datetime.timedelta(days=1)
|
one_day = timedelta(days=1)
|
||||||
one_second = datetime.timedelta(seconds=1)
|
one_second = timedelta(seconds=1)
|
||||||
|
|
||||||
value = value.lower()
|
value = value.lower()
|
||||||
if not value:
|
if not value:
|
||||||
raise errors.ValidationError('Empty date format.')
|
raise errors.ValidationError('Empty date format.')
|
||||||
|
|
||||||
if value == 'today':
|
if value == 'today':
|
||||||
now = datetime.datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
return (
|
return (
|
||||||
datetime.datetime(now.year, now.month, now.day, 0, 0, 0),
|
datetime(now.year, now.month, now.day, 0, 0, 0),
|
||||||
datetime.datetime(now.year, now.month, now.day, 0, 0, 0) \
|
datetime(now.year, now.month, now.day, 0, 0, 0)
|
||||||
+ one_day - one_second)
|
+ one_day - one_second)
|
||||||
|
|
||||||
if value == 'yesterday':
|
if value == 'yesterday':
|
||||||
now = datetime.datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
return (
|
return (
|
||||||
datetime.datetime(now.year, now.month, now.day, 0, 0, 0) - one_day,
|
datetime(now.year, now.month, now.day, 0, 0, 0) - one_day,
|
||||||
datetime.datetime(now.year, now.month, now.day, 0, 0, 0) \
|
datetime(now.year, now.month, now.day, 0, 0, 0) - one_second)
|
||||||
- one_second)
|
|
||||||
|
|
||||||
match = re.match(r'^(\d{4})$', value)
|
match = re.match(r'^(\d{4})$', value)
|
||||||
if match:
|
if match:
|
||||||
year = int(match.group(1))
|
year = int(match.group(1))
|
||||||
return (
|
return (datetime(year, 1, 1), datetime(year + 1, 1, 1) - one_second)
|
||||||
datetime.datetime(year, 1, 1),
|
|
||||||
datetime.datetime(year + 1, 1, 1) - one_second)
|
|
||||||
|
|
||||||
match = re.match(r'^(\d{4})-(\d{1,2})$', value)
|
match = re.match(r'^(\d{4})-(\d{1,2})$', value)
|
||||||
if match:
|
if match:
|
||||||
year = int(match.group(1))
|
year = int(match.group(1))
|
||||||
month = int(match.group(2))
|
month = int(match.group(2))
|
||||||
return (
|
return (
|
||||||
datetime.datetime(year, month, 1),
|
datetime(year, month, 1),
|
||||||
datetime.datetime(year, month + 1, 1) - one_second)
|
datetime(year, month + 1, 1) - one_second)
|
||||||
|
|
||||||
match = re.match(r'^(\d{4})-(\d{1,2})-(\d{1,2})$', value)
|
match = re.match(r'^(\d{4})-(\d{1,2})-(\d{1,2})$', value)
|
||||||
if match:
|
if match:
|
||||||
|
@ -121,11 +131,12 @@ def parse_time_range(value):
|
||||||
month = int(match.group(2))
|
month = int(match.group(2))
|
||||||
day = int(match.group(3))
|
day = int(match.group(3))
|
||||||
return (
|
return (
|
||||||
datetime.datetime(year, month, day),
|
datetime(year, month, day),
|
||||||
datetime.datetime(year, month, day + 1) - one_second)
|
datetime(year, month, day + 1) - one_second)
|
||||||
|
|
||||||
raise errors.ValidationError('Invalid date format: %r.' % value)
|
raise errors.ValidationError('Invalid date format: %r.' % value)
|
||||||
|
|
||||||
|
|
||||||
def icase_unique(source):
|
def icase_unique(source):
|
||||||
target = []
|
target = []
|
||||||
target_low = []
|
target_low = []
|
||||||
|
@ -135,6 +146,7 @@ def icase_unique(source):
|
||||||
target_low.append(source_item.lower())
|
target_low.append(source_item.lower())
|
||||||
return target
|
return target
|
||||||
|
|
||||||
|
|
||||||
def value_exceeds_column_size(value, column):
|
def value_exceeds_column_size(value, column):
|
||||||
if not value:
|
if not value:
|
||||||
return False
|
return False
|
||||||
|
@ -143,6 +155,7 @@ def value_exceeds_column_size(value, column):
|
||||||
return False
|
return False
|
||||||
return len(value) > max_length
|
return len(value) > max_length
|
||||||
|
|
||||||
|
|
||||||
def verify_version(entity, context, field_name='version'):
|
def verify_version(entity, context, field_name='version'):
|
||||||
actual_version = context.get_param_as_int(field_name, required=True)
|
actual_version = context.get_param_as_int(field_name, required=True)
|
||||||
expected_version = entity.version
|
expected_version = entity.version
|
||||||
|
@ -151,5 +164,6 @@ def verify_version(entity, context, field_name='version'):
|
||||||
'Someone else modified this in the meantime. ' +
|
'Someone else modified this in the meantime. ' +
|
||||||
'Please try again.')
|
'Please try again.')
|
||||||
|
|
||||||
|
|
||||||
def bump_version(entity):
|
def bump_version(entity):
|
||||||
entity.version += 1
|
entity.version += 1
|
||||||
|
|
|
@ -4,6 +4,7 @@ from szurubooru.func import auth, users
|
||||||
from szurubooru.rest import middleware
|
from szurubooru.rest import middleware
|
||||||
from szurubooru.rest.errors import HttpBadRequest
|
from szurubooru.rest.errors import HttpBadRequest
|
||||||
|
|
||||||
|
|
||||||
def _authenticate(username, password):
|
def _authenticate(username, password):
|
||||||
''' Try to authenticate user. Throw AuthError for invalid users. '''
|
''' Try to authenticate user. Throw AuthError for invalid users. '''
|
||||||
user = users.get_user_by_name(username)
|
user = users.get_user_by_name(username)
|
||||||
|
@ -11,23 +12,25 @@ def _authenticate(username, password):
|
||||||
raise errors.AuthError('Invalid password.')
|
raise errors.AuthError('Invalid password.')
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def _create_anonymous_user():
|
def _create_anonymous_user():
|
||||||
user = db.User()
|
user = db.User()
|
||||||
user.name = None
|
user.name = None
|
||||||
user.rank = 'anonymous'
|
user.rank = 'anonymous'
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def _get_user(ctx):
|
def _get_user(ctx):
|
||||||
if not ctx.has_header('Authorization'):
|
if not ctx.has_header('Authorization'):
|
||||||
return _create_anonymous_user()
|
return _create_anonymous_user()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
auth_type, user_and_password = ctx.get_header('Authorization').split(' ', 1)
|
auth_type, credentials = ctx.get_header('Authorization').split(' ', 1)
|
||||||
if auth_type.lower() != 'basic':
|
if auth_type.lower() != 'basic':
|
||||||
raise HttpBadRequest(
|
raise HttpBadRequest(
|
||||||
'Only basic HTTP authentication is supported.')
|
'Only basic HTTP authentication is supported.')
|
||||||
username, password = base64.decodebytes(
|
username, password = base64.decodebytes(
|
||||||
user_and_password.encode('ascii')).decode('utf8').split(':')
|
credentials.encode('ascii')).decode('utf8').split(':')
|
||||||
return _authenticate(username, password)
|
return _authenticate(username, password)
|
||||||
except ValueError as err:
|
except ValueError as err:
|
||||||
msg = 'Basic authentication header value are not properly formed. ' \
|
msg = 'Basic authentication header value are not properly formed. ' \
|
||||||
|
@ -35,6 +38,7 @@ def _get_user(ctx):
|
||||||
raise HttpBadRequest(
|
raise HttpBadRequest(
|
||||||
msg.format(ctx.get_header('Authorization'), str(err)))
|
msg.format(ctx.get_header('Authorization'), str(err)))
|
||||||
|
|
||||||
|
|
||||||
@middleware.pre_hook
|
@middleware.pre_hook
|
||||||
def process_request(ctx):
|
def process_request(ctx):
|
||||||
''' Bind the user to request. Update last login time if needed. '''
|
''' Bind the user to request. Update last login time if needed. '''
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from szurubooru.func import cache
|
from szurubooru.func import cache
|
||||||
from szurubooru.rest import middleware
|
from szurubooru.rest import middleware
|
||||||
|
|
||||||
|
|
||||||
@middleware.pre_hook
|
@middleware.pre_hook
|
||||||
def process_request(ctx):
|
def process_request(ctx):
|
||||||
if ctx.method != 'GET':
|
if ctx.method != 'GET':
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
from szurubooru import db
|
from szurubooru import db
|
||||||
from szurubooru.rest import middleware
|
from szurubooru.rest import middleware
|
||||||
|
|
||||||
|
|
||||||
@middleware.pre_hook
|
@middleware.pre_hook
|
||||||
def _process_request(ctx):
|
def _process_request(ctx):
|
||||||
ctx.session = db.session()
|
ctx.session = db.session()
|
||||||
db.reset_query_count()
|
db.reset_query_count()
|
||||||
|
|
||||||
|
|
||||||
@middleware.post_hook
|
@middleware.post_hook
|
||||||
def _process_response(_ctx):
|
def _process_response(_ctx):
|
||||||
db.session.remove()
|
db.session.remove()
|
||||||
|
|
|
@ -2,8 +2,10 @@ import logging
|
||||||
from szurubooru import db
|
from szurubooru import db
|
||||||
from szurubooru.rest import middleware
|
from szurubooru.rest import middleware
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@middleware.post_hook
|
@middleware.post_hook
|
||||||
def process_response(ctx):
|
def process_response(ctx):
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
@ -28,6 +28,7 @@ alembic_config.set_main_option(
|
||||||
|
|
||||||
target_metadata = szurubooru.db.Base.metadata
|
target_metadata = szurubooru.db.Base.metadata
|
||||||
|
|
||||||
|
|
||||||
def run_migrations_offline():
|
def run_migrations_offline():
|
||||||
'''
|
'''
|
||||||
Run migrations in 'offline' mode.
|
Run migrations in 'offline' mode.
|
||||||
|
|
|
@ -13,6 +13,7 @@ down_revision = 'e5c1216a8503'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.create_table(
|
op.create_table(
|
||||||
'tag_category',
|
'tag_category',
|
||||||
|
@ -55,6 +56,7 @@ def upgrade():
|
||||||
sa.ForeignKeyConstraint(['child_id'], ['tag.id']),
|
sa.ForeignKeyConstraint(['child_id'], ['tag.id']),
|
||||||
sa.PrimaryKeyConstraint('parent_id', 'child_id'))
|
sa.PrimaryKeyConstraint('parent_id', 'child_id'))
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_table('tag_suggestion')
|
op.drop_table('tag_suggestion')
|
||||||
op.drop_table('tag_implication')
|
op.drop_table('tag_implication')
|
||||||
|
|
|
@ -13,10 +13,16 @@ down_revision = '49ab4e1139ef'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.add_column('tag_category', sa.Column('default', sa.Boolean(), nullable=True))
|
op.add_column(
|
||||||
op.execute(sa.table('tag_category', sa.column('default')).update().values(default=False))
|
'tag_category', sa.Column('default', sa.Boolean(), nullable=True))
|
||||||
|
op.execute(
|
||||||
|
sa.table('tag_category', sa.column('default'))
|
||||||
|
.update()
|
||||||
|
.values(default=False))
|
||||||
op.alter_column('tag_category', 'default', nullable=False)
|
op.alter_column('tag_category', 'default', nullable=False)
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_column('tag_category', 'default')
|
op.drop_column('tag_category', 'default')
|
||||||
|
|
|
@ -13,8 +13,11 @@ down_revision = 'ed6dd16a30f3'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.add_column('post', sa.Column('mime-type', sa.Unicode(length=32), nullable=False))
|
op.add_column(
|
||||||
|
'post', sa.Column('mime-type', sa.Unicode(length=32), nullable=False))
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_column('post', 'mime-type')
|
op.drop_column('post', 'mime-type')
|
||||||
|
|
|
@ -13,6 +13,7 @@ down_revision = '00cb3a2734db'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.create_table(
|
op.create_table(
|
||||||
'post',
|
'post',
|
||||||
|
@ -56,6 +57,7 @@ def upgrade():
|
||||||
sa.ForeignKeyConstraint(['tag_id'], ['tag.id']),
|
sa.ForeignKeyConstraint(['tag_id'], ['tag.id']),
|
||||||
sa.PrimaryKeyConstraint('post_id', 'tag_id'))
|
sa.PrimaryKeyConstraint('post_id', 'tag_id'))
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_table('post_tag')
|
op.drop_table('post_tag')
|
||||||
op.drop_table('post_relation')
|
op.drop_table('post_relation')
|
||||||
|
|
|
@ -13,10 +13,12 @@ down_revision = '565e01e3cf6d'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.add_column(
|
op.add_column(
|
||||||
'snapshot',
|
'snapshot',
|
||||||
sa.Column('resource_repr', sa.Unicode(length=64), nullable=False))
|
sa.Column('resource_repr', sa.Unicode(length=64), nullable=False))
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_column('snapshot', 'resource_repr')
|
op.drop_column('snapshot', 'resource_repr')
|
||||||
|
|
|
@ -13,6 +13,7 @@ down_revision = '84bd402f15f0'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.create_table(
|
op.create_table(
|
||||||
'comment',
|
'comment',
|
||||||
|
@ -36,6 +37,7 @@ def upgrade():
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['user.id']),
|
sa.ForeignKeyConstraint(['user_id'], ['user.id']),
|
||||||
sa.PrimaryKeyConstraint('comment_id', 'user_id'))
|
sa.PrimaryKeyConstraint('comment_id', 'user_id'))
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_table('comment_score')
|
op.drop_table('comment_score')
|
||||||
op.drop_table('comment')
|
op.drop_table('comment')
|
||||||
|
|
|
@ -13,52 +13,59 @@ down_revision = '23abaf4a0a4b'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.create_index(op.f('ix_comment_post_id'), 'comment', ['post_id'], unique=False)
|
for index_name, table_name, column_name in [
|
||||||
op.create_index(op.f('ix_comment_user_id'), 'comment', ['user_id'], unique=False)
|
('ix_comment_post_id', 'comment', 'post_id'),
|
||||||
op.create_index(op.f('ix_comment_score_user_id'), 'comment_score', ['user_id'], unique=False)
|
('ix_comment_user_id', 'comment', 'user_id'),
|
||||||
op.create_index(op.f('ix_post_user_id'), 'post', ['user_id'], unique=False)
|
('ix_comment_score_user_id', 'comment_score', 'user_id'),
|
||||||
op.create_index(op.f('ix_post_favorite_post_id'), 'post_favorite', ['post_id'], unique=False)
|
('ix_post_user_id', 'post', 'user_id'),
|
||||||
op.create_index(op.f('ix_post_favorite_user_id'), 'post_favorite', ['user_id'], unique=False)
|
('ix_post_favorite_post_id', 'post_favorite', 'post_id'),
|
||||||
op.create_index(op.f('ix_post_feature_post_id'), 'post_feature', ['post_id'], unique=False)
|
('ix_post_favorite_user_id', 'post_favorite', 'user_id'),
|
||||||
op.create_index(op.f('ix_post_feature_user_id'), 'post_feature', ['user_id'], unique=False)
|
('ix_post_feature_post_id', 'post_feature', 'post_id'),
|
||||||
op.create_index(op.f('ix_post_note_post_id'), 'post_note', ['post_id'], unique=False)
|
('ix_post_feature_user_id', 'post_feature', 'user_id'),
|
||||||
op.create_index(op.f('ix_post_relation_child_id'), 'post_relation', ['child_id'], unique=False)
|
('ix_post_note_post_id', 'post_note', 'post_id'),
|
||||||
op.create_index(op.f('ix_post_relation_parent_id'), 'post_relation', ['parent_id'], unique=False)
|
('ix_post_relation_child_id', 'post_relation', 'child_id'),
|
||||||
op.create_index(op.f('ix_post_score_post_id'), 'post_score', ['post_id'], unique=False)
|
('ix_post_relation_parent_id', 'post_relation', 'parent_id'),
|
||||||
op.create_index(op.f('ix_post_score_user_id'), 'post_score', ['user_id'], unique=False)
|
('ix_post_score_post_id', 'post_score', 'post_id'),
|
||||||
op.create_index(op.f('ix_post_tag_post_id'), 'post_tag', ['post_id'], unique=False)
|
('ix_post_score_user_id', 'post_score', 'user_id'),
|
||||||
op.create_index(op.f('ix_post_tag_tag_id'), 'post_tag', ['tag_id'], unique=False)
|
('ix_post_tag_post_id', 'post_tag', 'post_id'),
|
||||||
op.create_index(op.f('ix_snapshot_resource_id'), 'snapshot', ['resource_id'], unique=False)
|
('ix_post_tag_tag_id', 'post_tag', 'tag_id'),
|
||||||
op.create_index(op.f('ix_snapshot_resource_type'), 'snapshot', ['resource_type'], unique=False)
|
('ix_snapshot_resource_id', 'snapshot', 'resource_id'),
|
||||||
op.create_index(op.f('ix_tag_category_id'), 'tag', ['category_id'], unique=False)
|
('ix_snapshot_resource_type', 'snapshot', 'resource_type'),
|
||||||
op.create_index(op.f('ix_tag_implication_child_id'), 'tag_implication', ['child_id'], unique=False)
|
('ix_tag_category_id', 'tag', 'category_id'),
|
||||||
op.create_index(op.f('ix_tag_implication_parent_id'), 'tag_implication', ['parent_id'], unique=False)
|
('ix_tag_implication_child_id', 'tag_implication', 'child_id'),
|
||||||
op.create_index(op.f('ix_tag_name_tag_id'), 'tag_name', ['tag_id'], unique=False)
|
('ix_tag_implication_parent_id', 'tag_implication', 'parent_id'),
|
||||||
op.create_index(op.f('ix_tag_suggestion_child_id'), 'tag_suggestion', ['child_id'], unique=False)
|
('ix_tag_name_tag_id', 'tag_name', 'tag_id'),
|
||||||
op.create_index(op.f('ix_tag_suggestion_parent_id'), 'tag_suggestion', ['parent_id'], unique=False)
|
('ix_tag_suggestion_child_id', 'tag_suggestion', 'child_id'),
|
||||||
|
('ix_tag_suggestion_parent_id', 'tag_suggestion', 'parent_id')]:
|
||||||
|
op.create_index(
|
||||||
|
op.f(index_name), table_name, [column_name], unique=False)
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_index(op.f('ix_tag_suggestion_parent_id'), table_name='tag_suggestion')
|
for index_name, table_name in [
|
||||||
op.drop_index(op.f('ix_tag_suggestion_child_id'), table_name='tag_suggestion')
|
('ix_tag_suggestion_parent_id', 'tag_suggestion'),
|
||||||
op.drop_index(op.f('ix_tag_name_tag_id'), table_name='tag_name')
|
('ix_tag_suggestion_child_id', 'tag_suggestion'),
|
||||||
op.drop_index(op.f('ix_tag_implication_parent_id'), table_name='tag_implication')
|
('ix_tag_name_tag_id', 'tag_name'),
|
||||||
op.drop_index(op.f('ix_tag_implication_child_id'), table_name='tag_implication')
|
('ix_tag_implication_parent_id', 'tag_implication'),
|
||||||
op.drop_index(op.f('ix_tag_category_id'), table_name='tag')
|
('ix_tag_implication_child_id', 'tag_implication'),
|
||||||
op.drop_index(op.f('ix_snapshot_resource_type'), table_name='snapshot')
|
('ix_tag_category_id', 'tag'),
|
||||||
op.drop_index(op.f('ix_snapshot_resource_id'), table_name='snapshot')
|
('ix_snapshot_resource_type', 'snapshot'),
|
||||||
op.drop_index(op.f('ix_post_tag_tag_id'), table_name='post_tag')
|
('ix_snapshot_resource_id', 'snapshot'),
|
||||||
op.drop_index(op.f('ix_post_tag_post_id'), table_name='post_tag')
|
('ix_post_tag_tag_id', 'post_tag'),
|
||||||
op.drop_index(op.f('ix_post_score_user_id'), table_name='post_score')
|
('ix_post_tag_post_id', 'post_tag'),
|
||||||
op.drop_index(op.f('ix_post_score_post_id'), table_name='post_score')
|
('ix_post_score_user_id', 'post_score'),
|
||||||
op.drop_index(op.f('ix_post_relation_parent_id'), table_name='post_relation')
|
('ix_post_score_post_id', 'post_score'),
|
||||||
op.drop_index(op.f('ix_post_relation_child_id'), table_name='post_relation')
|
('ix_post_relation_parent_id', 'post_relation'),
|
||||||
op.drop_index(op.f('ix_post_note_post_id'), table_name='post_note')
|
('ix_post_relation_child_id', 'post_relation'),
|
||||||
op.drop_index(op.f('ix_post_feature_user_id'), table_name='post_feature')
|
('ix_post_note_post_id', 'post_note'),
|
||||||
op.drop_index(op.f('ix_post_feature_post_id'), table_name='post_feature')
|
('ix_post_feature_user_id', 'post_feature'),
|
||||||
op.drop_index(op.f('ix_post_favorite_user_id'), table_name='post_favorite')
|
('ix_post_feature_post_id', 'post_feature'),
|
||||||
op.drop_index(op.f('ix_post_favorite_post_id'), table_name='post_favorite')
|
('ix_post_favorite_user_id', 'post_favorite'),
|
||||||
op.drop_index(op.f('ix_post_user_id'), table_name='post')
|
('ix_post_favorite_post_id', 'post_favorite'),
|
||||||
op.drop_index(op.f('ix_comment_score_user_id'), table_name='comment_score')
|
('ix_post_user_id', 'post'),
|
||||||
op.drop_index(op.f('ix_comment_user_id'), table_name='comment')
|
('ix_comment_score_user_id', 'comment_score'),
|
||||||
op.drop_index(op.f('ix_comment_post_id'), table_name='comment')
|
('ix_comment_user_id', 'comment'),
|
||||||
|
('ix_comment_post_id', 'comment')]:
|
||||||
|
op.drop_index(op.f(index_name), table_name=table_name)
|
||||||
|
|
|
@ -13,8 +13,11 @@ down_revision = '055d0e048fb3'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.add_column('tag', sa.Column('description', sa.UnicodeText(), nullable=True))
|
op.add_column(
|
||||||
|
'tag', sa.Column('description', sa.UnicodeText(), nullable=True))
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_column('tag', 'description')
|
op.drop_column('tag', 'description')
|
||||||
|
|
|
@ -13,6 +13,7 @@ down_revision = '336a76ec1338'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.create_table(
|
op.create_table(
|
||||||
'snapshot',
|
'snapshot',
|
||||||
|
@ -26,5 +27,6 @@ def upgrade():
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['user.id']),
|
sa.ForeignKeyConstraint(['user_id'], ['user.id']),
|
||||||
sa.PrimaryKeyConstraint('id'))
|
sa.PrimaryKeyConstraint('id'))
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_table('snapshot')
|
op.drop_table('snapshot')
|
||||||
|
|
|
@ -15,12 +15,17 @@ depends_on = None
|
||||||
|
|
||||||
tables = ['tag_category', 'tag', 'user', 'post', 'comment']
|
tables = ['tag_category', 'tag', 'user', 'post', 'comment']
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
for table in tables:
|
for table in tables:
|
||||||
op.add_column(table, sa.Column('version', sa.Integer(), nullable=True))
|
op.add_column(table, sa.Column('version', sa.Integer(), nullable=True))
|
||||||
op.execute(sa.table(table, sa.column('version')).update().values(version=1))
|
op.execute(
|
||||||
|
sa.table(table, sa.column('version'))
|
||||||
|
.update()
|
||||||
|
.values(version=1))
|
||||||
op.alter_column(table, 'version', nullable=False)
|
op.alter_column(table, 'version', nullable=False)
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
for table in tables:
|
for table in tables:
|
||||||
op.drop_column(table, 'version')
|
op.drop_column(table, 'version')
|
||||||
|
|
|
@ -13,10 +13,14 @@ down_revision = '9587de88a84b'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.drop_column('post', 'flags')
|
op.drop_column('post', 'flags')
|
||||||
op.add_column('post', sa.Column('flags', sa.PickleType(), nullable=True))
|
op.add_column('post', sa.Column('flags', sa.PickleType(), nullable=True))
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_column('post', 'flags')
|
op.drop_column('post', 'flags')
|
||||||
op.add_column('post', sa.Column('flags', sa.Integer(), autoincrement=False, nullable=False))
|
op.add_column(
|
||||||
|
'post',
|
||||||
|
sa.Column('flags', sa.Integer(), autoincrement=False, nullable=False))
|
||||||
|
|
|
@ -13,6 +13,7 @@ down_revision = '46cd5229839b'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.create_table(
|
op.create_table(
|
||||||
'post_favorite',
|
'post_favorite',
|
||||||
|
@ -52,6 +53,7 @@ def upgrade():
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['user.id']),
|
sa.ForeignKeyConstraint(['user_id'], ['user.id']),
|
||||||
sa.PrimaryKeyConstraint('post_id', 'user_id'))
|
sa.PrimaryKeyConstraint('post_id', 'user_id'))
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_table('post_score')
|
op.drop_table('post_score')
|
||||||
op.drop_table('post_note')
|
op.drop_table('post_note')
|
||||||
|
|
|
@ -13,6 +13,7 @@ down_revision = None
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.create_table(
|
op.create_table(
|
||||||
'user',
|
'user',
|
||||||
|
@ -28,5 +29,6 @@ def upgrade():
|
||||||
sa.PrimaryKeyConstraint('id'))
|
sa.PrimaryKeyConstraint('id'))
|
||||||
op.create_unique_constraint('uq_user_name', 'user', ['name'])
|
op.create_unique_constraint('uq_user_name', 'user', ['name'])
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.drop_table('user')
|
op.drop_table('user')
|
||||||
|
|
|
@ -13,24 +13,36 @@ down_revision = '46df355634dc'
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
def upgrade():
|
||||||
op.drop_column('post', 'auto_comment_edit_time')
|
for column_name in [
|
||||||
op.drop_column('post', 'auto_fav_count')
|
'auto_comment_edit_time'
|
||||||
op.drop_column('post', 'auto_comment_creation_time')
|
'auto_fav_count',
|
||||||
op.drop_column('post', 'auto_feature_count')
|
'auto_comment_creation_time',
|
||||||
op.drop_column('post', 'auto_comment_count')
|
'auto_feature_count',
|
||||||
op.drop_column('post', 'auto_score')
|
'auto_comment_count',
|
||||||
op.drop_column('post', 'auto_fav_time')
|
'auto_score',
|
||||||
op.drop_column('post', 'auto_feature_time')
|
'auto_fav_time',
|
||||||
op.drop_column('post', 'auto_note_count')
|
'auto_feature_time',
|
||||||
|
'auto_note_count']:
|
||||||
|
op.drop_column('post', column_name)
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
def downgrade():
|
||||||
op.add_column('post', sa.Column('auto_note_count', sa.INTEGER(), autoincrement=False, nullable=False))
|
for column_name in [
|
||||||
op.add_column('post', sa.Column('auto_feature_time', sa.INTEGER(), autoincrement=False, nullable=False))
|
'auto_note_count',
|
||||||
op.add_column('post', sa.Column('auto_fav_time', sa.INTEGER(), autoincrement=False, nullable=False))
|
'auto_feature_time',
|
||||||
op.add_column('post', sa.Column('auto_score', sa.INTEGER(), autoincrement=False, nullable=False))
|
'auto_fav_time',
|
||||||
op.add_column('post', sa.Column('auto_comment_count', sa.INTEGER(), autoincrement=False, nullable=False))
|
'auto_score',
|
||||||
op.add_column('post', sa.Column('auto_feature_count', sa.INTEGER(), autoincrement=False, nullable=False))
|
'auto_comment_count',
|
||||||
op.add_column('post', sa.Column('auto_comment_creation_time', sa.INTEGER(), autoincrement=False, nullable=False))
|
'auto_feature_count',
|
||||||
op.add_column('post', sa.Column('auto_fav_count', sa.INTEGER(), autoincrement=False, nullable=False))
|
'auto_comment_creation_time',
|
||||||
op.add_column('post', sa.Column('auto_comment_edit_time', sa.INTEGER(), autoincrement=False, nullable=False))
|
'auto_fav_count',
|
||||||
|
'auto_comment_edit_time']:
|
||||||
|
op.add_column(
|
||||||
|
'post',
|
||||||
|
sa.Column(
|
||||||
|
column_name,
|
||||||
|
sa.INTEGER(),
|
||||||
|
autoincrement=False,
|
||||||
|
nullable=False))
|
||||||
|
|
|
@ -6,6 +6,7 @@ from datetime import datetime
|
||||||
from szurubooru.func import util
|
from szurubooru.func import util
|
||||||
from szurubooru.rest import errors, middleware, routes, context
|
from szurubooru.rest import errors, middleware, routes, context
|
||||||
|
|
||||||
|
|
||||||
def _json_serializer(obj):
|
def _json_serializer(obj):
|
||||||
''' JSON serializer for objects not serializable by default JSON code '''
|
''' JSON serializer for objects not serializable by default JSON code '''
|
||||||
if isinstance(obj, datetime):
|
if isinstance(obj, datetime):
|
||||||
|
@ -13,14 +14,16 @@ def _json_serializer(obj):
|
||||||
return serial
|
return serial
|
||||||
raise TypeError('Type not serializable')
|
raise TypeError('Type not serializable')
|
||||||
|
|
||||||
|
|
||||||
def _dump_json(obj):
|
def _dump_json(obj):
|
||||||
return json.dumps(obj, default=_json_serializer, indent=2)
|
return json.dumps(obj, default=_json_serializer, indent=2)
|
||||||
|
|
||||||
|
|
||||||
def _read(env):
|
def _read(env):
|
||||||
length = int(env.get('CONTENT_LENGTH', 0))
|
length = int(env.get('CONTENT_LENGTH', 0))
|
||||||
output = io.BytesIO()
|
output = io.BytesIO()
|
||||||
while length > 0:
|
while length > 0:
|
||||||
part = env['wsgi.input'].read(min(length, 1024*200))
|
part = env['wsgi.input'].read(min(length, 1024 * 200))
|
||||||
if not part:
|
if not part:
|
||||||
break
|
break
|
||||||
output.write(part)
|
output.write(part)
|
||||||
|
@ -28,6 +31,7 @@ def _read(env):
|
||||||
output.seek(0)
|
output.seek(0)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def _get_headers(env):
|
def _get_headers(env):
|
||||||
headers = {}
|
headers = {}
|
||||||
for key, value in env.items():
|
for key, value in env.items():
|
||||||
|
@ -36,6 +40,7 @@ def _get_headers(env):
|
||||||
headers[key] = value
|
headers[key] = value
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
|
|
||||||
def _create_context(env):
|
def _create_context(env):
|
||||||
method = env['REQUEST_METHOD']
|
method = env['REQUEST_METHOD']
|
||||||
path = '/' + env['PATH_INFO'].lstrip('/')
|
path = '/' + env['PATH_INFO'].lstrip('/')
|
||||||
|
@ -56,7 +61,7 @@ def _create_context(env):
|
||||||
if isinstance(form[key], cgi.MiniFieldStorage):
|
if isinstance(form[key], cgi.MiniFieldStorage):
|
||||||
params[key] = form.getvalue(key)
|
params[key] = form.getvalue(key)
|
||||||
else:
|
else:
|
||||||
_original_file_name = getattr(form[key], 'filename', None)
|
# _user_file_name = getattr(form[key], 'filename', None)
|
||||||
files[key] = form.getvalue(key)
|
files[key] = form.getvalue(key)
|
||||||
if 'metadata' in form:
|
if 'metadata' in form:
|
||||||
body = form.getvalue('metadata')
|
body = form.getvalue('metadata')
|
||||||
|
@ -79,10 +84,11 @@ def _create_context(env):
|
||||||
|
|
||||||
return context.Context(method, path, headers, params, files)
|
return context.Context(method, path, headers, params, files)
|
||||||
|
|
||||||
|
|
||||||
def application(env, start_response):
|
def application(env, start_response):
|
||||||
try:
|
try:
|
||||||
ctx = _create_context(env)
|
ctx = _create_context(env)
|
||||||
if not 'application/json' in ctx.get_header('Accept'):
|
if 'application/json' not in ctx.get_header('Accept'):
|
||||||
raise errors.HttpNotAcceptable(
|
raise errors.HttpNotAcceptable(
|
||||||
'This API only supports JSON responses.')
|
'This API only supports JSON responses.')
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
from szurubooru import errors
|
from szurubooru import errors
|
||||||
from szurubooru.func import net
|
from szurubooru.func import net
|
||||||
|
|
||||||
|
|
||||||
def _lower_first(source):
|
def _lower_first(source):
|
||||||
return source[0].lower() + source[1:]
|
return source[0].lower() + source[1:]
|
||||||
|
|
||||||
|
|
||||||
def _param_wrapper(func):
|
def _param_wrapper(func):
|
||||||
def wrapper(self, name, required=False, default=None, **kwargs):
|
def wrapper(self, name, required=False, default=None, **kwargs):
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
|
@ -22,8 +24,8 @@ def _param_wrapper(func):
|
||||||
'Required parameter %r is missing.' % name)
|
'Required parameter %r is missing.' % name)
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class Context():
|
class Context():
|
||||||
# pylint: disable=too-many-arguments
|
|
||||||
def __init__(self, method, url, headers=None, params=None, files=None):
|
def __init__(self, method, url, headers=None, params=None, files=None):
|
||||||
self.method = method
|
self.method = method
|
||||||
self.url = url
|
self.url = url
|
||||||
|
@ -74,7 +76,6 @@ class Context():
|
||||||
raise errors.InvalidParameterError('Expected simple string.')
|
raise errors.InvalidParameterError('Expected simple string.')
|
||||||
return value
|
return value
|
||||||
|
|
||||||
# pylint: disable=redefined-builtin
|
|
||||||
@_param_wrapper
|
@_param_wrapper
|
||||||
def get_param_as_int(self, value, min=None, max=None):
|
def get_param_as_int(self, value, min=None, max=None):
|
||||||
try:
|
try:
|
||||||
|
@ -97,4 +98,5 @@ class Context():
|
||||||
return True
|
return True
|
||||||
if value in ['0', 'n', 'no', 'nope', 'f', 'false']:
|
if value in ['0', 'n', 'no', 'nope', 'f', 'false']:
|
||||||
return False
|
return False
|
||||||
raise errors.InvalidParameterError('The value must be a boolean value.')
|
raise errors.InvalidParameterError(
|
||||||
|
'The value must be a boolean value.')
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
error_handlers = {} # pylint: disable=invalid-name
|
error_handlers = {} # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class BaseHttpError(RuntimeError):
|
class BaseHttpError(RuntimeError):
|
||||||
code = None
|
code = None
|
||||||
|
@ -9,29 +10,36 @@ class BaseHttpError(RuntimeError):
|
||||||
self.description = description
|
self.description = description
|
||||||
self.title = title or self.reason
|
self.title = title or self.reason
|
||||||
|
|
||||||
|
|
||||||
class HttpBadRequest(BaseHttpError):
|
class HttpBadRequest(BaseHttpError):
|
||||||
code = 400
|
code = 400
|
||||||
reason = 'Bad Request'
|
reason = 'Bad Request'
|
||||||
|
|
||||||
|
|
||||||
class HttpForbidden(BaseHttpError):
|
class HttpForbidden(BaseHttpError):
|
||||||
code = 403
|
code = 403
|
||||||
reason = 'Forbidden'
|
reason = 'Forbidden'
|
||||||
|
|
||||||
|
|
||||||
class HttpNotFound(BaseHttpError):
|
class HttpNotFound(BaseHttpError):
|
||||||
code = 404
|
code = 404
|
||||||
reason = 'Not Found'
|
reason = 'Not Found'
|
||||||
|
|
||||||
|
|
||||||
class HttpNotAcceptable(BaseHttpError):
|
class HttpNotAcceptable(BaseHttpError):
|
||||||
code = 406
|
code = 406
|
||||||
reason = 'Not Acceptable'
|
reason = 'Not Acceptable'
|
||||||
|
|
||||||
|
|
||||||
class HttpConflict(BaseHttpError):
|
class HttpConflict(BaseHttpError):
|
||||||
code = 409
|
code = 409
|
||||||
reason = 'Conflict'
|
reason = 'Conflict'
|
||||||
|
|
||||||
|
|
||||||
class HttpMethodNotAllowed(BaseHttpError):
|
class HttpMethodNotAllowed(BaseHttpError):
|
||||||
code = 405
|
code = 405
|
||||||
reason = 'Method Not Allowed'
|
reason = 'Method Not Allowed'
|
||||||
|
|
||||||
|
|
||||||
def handle(exception_type, handler):
|
def handle(exception_type, handler):
|
||||||
error_handlers[exception_type] = handler
|
error_handlers[exception_type] = handler
|
||||||
|
|
|
@ -2,8 +2,10 @@
|
||||||
pre_hooks = []
|
pre_hooks = []
|
||||||
post_hooks = []
|
post_hooks = []
|
||||||
|
|
||||||
|
|
||||||
def pre_hook(handler):
|
def pre_hook(handler):
|
||||||
pre_hooks.append(handler)
|
pre_hooks.append(handler)
|
||||||
|
|
||||||
|
|
||||||
def post_hook(handler):
|
def post_hook(handler):
|
||||||
post_hooks.insert(0, handler)
|
post_hooks.insert(0, handler)
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
routes = defaultdict(dict) # pylint: disable=invalid-name
|
|
||||||
|
routes = defaultdict(dict) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
def get(url):
|
def get(url):
|
||||||
def wrapper(handler):
|
def wrapper(handler):
|
||||||
|
@ -8,18 +10,21 @@ def get(url):
|
||||||
return handler
|
return handler
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def put(url):
|
def put(url):
|
||||||
def wrapper(handler):
|
def wrapper(handler):
|
||||||
routes[url]['PUT'] = handler
|
routes[url]['PUT'] = handler
|
||||||
return handler
|
return handler
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def post(url):
|
def post(url):
|
||||||
def wrapper(handler):
|
def wrapper(handler):
|
||||||
routes[url]['POST'] = handler
|
routes[url]['POST'] = handler
|
||||||
return handler
|
return handler
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def delete(url):
|
def delete(url):
|
||||||
def wrapper(handler):
|
def wrapper(handler):
|
||||||
routes[url]['DELETE'] = handler
|
routes[url]['DELETE'] = handler
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from szurubooru.search.configs.user_search_config import UserSearchConfig
|
from .user_search_config import UserSearchConfig
|
||||||
from szurubooru.search.configs.snapshot_search_config import SnapshotSearchConfig
|
from .tag_search_config import TagSearchConfig
|
||||||
from szurubooru.search.configs.tag_search_config import TagSearchConfig
|
from .post_search_config import PostSearchConfig
|
||||||
from szurubooru.search.configs.comment_search_config import CommentSearchConfig
|
from .snapshot_search_config import SnapshotSearchConfig
|
||||||
from szurubooru.search.configs.post_search_config import PostSearchConfig
|
from .comment_search_config import CommentSearchConfig
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from szurubooru.search import tokens
|
from szurubooru.search import tokens
|
||||||
|
|
||||||
|
|
||||||
class BaseSearchConfig(object):
|
class BaseSearchConfig(object):
|
||||||
SORT_ASC = tokens.SortToken.SORT_ASC
|
SORT_ASC = tokens.SortToken.SORT_ASC
|
||||||
SORT_DESC = tokens.SortToken.SORT_DESC
|
SORT_DESC = tokens.SortToken.SORT_DESC
|
||||||
|
|
|
@ -3,6 +3,7 @@ from szurubooru import db
|
||||||
from szurubooru.search.configs import util as search_util
|
from szurubooru.search.configs import util as search_util
|
||||||
from szurubooru.search.configs.base_search_config import BaseSearchConfig
|
from szurubooru.search.configs.base_search_config import BaseSearchConfig
|
||||||
|
|
||||||
|
|
||||||
class CommentSearchConfig(BaseSearchConfig):
|
class CommentSearchConfig(BaseSearchConfig):
|
||||||
def create_filter_query(self):
|
def create_filter_query(self):
|
||||||
return db.session.query(db.Comment).join(db.User)
|
return db.session.query(db.Comment).join(db.User)
|
||||||
|
@ -22,12 +23,18 @@ class CommentSearchConfig(BaseSearchConfig):
|
||||||
'user': search_util.create_str_filter(db.User.name),
|
'user': search_util.create_str_filter(db.User.name),
|
||||||
'author': search_util.create_str_filter(db.User.name),
|
'author': search_util.create_str_filter(db.User.name),
|
||||||
'text': search_util.create_str_filter(db.Comment.text),
|
'text': search_util.create_str_filter(db.Comment.text),
|
||||||
'creation-date': search_util.create_date_filter(db.Comment.creation_time),
|
'creation-date':
|
||||||
'creation-time': search_util.create_date_filter(db.Comment.creation_time),
|
search_util.create_date_filter(db.Comment.creation_time),
|
||||||
'last-edit-date': search_util.create_date_filter(db.Comment.last_edit_time),
|
'creation-time':
|
||||||
'last-edit-time': search_util.create_date_filter(db.Comment.last_edit_time),
|
search_util.create_date_filter(db.Comment.creation_time),
|
||||||
'edit-date': search_util.create_date_filter(db.Comment.last_edit_time),
|
'last-edit-date':
|
||||||
'edit-time': search_util.create_date_filter(db.Comment.last_edit_time),
|
search_util.create_date_filter(db.Comment.last_edit_time),
|
||||||
|
'last-edit-time':
|
||||||
|
search_util.create_date_filter(db.Comment.last_edit_time),
|
||||||
|
'edit-date':
|
||||||
|
search_util.create_date_filter(db.Comment.last_edit_time),
|
||||||
|
'edit-time':
|
||||||
|
search_util.create_date_filter(db.Comment.last_edit_time),
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -6,6 +6,7 @@ from szurubooru.search import criteria, tokens
|
||||||
from szurubooru.search.configs import util as search_util
|
from szurubooru.search.configs import util as search_util
|
||||||
from szurubooru.search.configs.base_search_config import BaseSearchConfig
|
from szurubooru.search.configs.base_search_config import BaseSearchConfig
|
||||||
|
|
||||||
|
|
||||||
def _enum_transformer(available_values, value):
|
def _enum_transformer(available_values, value):
|
||||||
try:
|
try:
|
||||||
return available_values[value.lower()]
|
return available_values[value.lower()]
|
||||||
|
@ -14,6 +15,7 @@ def _enum_transformer(available_values, value):
|
||||||
'Invalid value: %r. Possible values: %r.' % (
|
'Invalid value: %r. Possible values: %r.' % (
|
||||||
value, list(sorted(available_values.keys()))))
|
value, list(sorted(available_values.keys()))))
|
||||||
|
|
||||||
|
|
||||||
def _type_transformer(value):
|
def _type_transformer(value):
|
||||||
available_values = {
|
available_values = {
|
||||||
'image': db.Post.TYPE_IMAGE,
|
'image': db.Post.TYPE_IMAGE,
|
||||||
|
@ -28,6 +30,7 @@ def _type_transformer(value):
|
||||||
}
|
}
|
||||||
return _enum_transformer(available_values, value)
|
return _enum_transformer(available_values, value)
|
||||||
|
|
||||||
|
|
||||||
def _safety_transformer(value):
|
def _safety_transformer(value):
|
||||||
available_values = {
|
available_values = {
|
||||||
'safe': db.Post.SAFETY_SAFE,
|
'safe': db.Post.SAFETY_SAFE,
|
||||||
|
@ -37,11 +40,12 @@ def _safety_transformer(value):
|
||||||
}
|
}
|
||||||
return _enum_transformer(available_values, value)
|
return _enum_transformer(available_values, value)
|
||||||
|
|
||||||
|
|
||||||
def _create_score_filter(score):
|
def _create_score_filter(score):
|
||||||
def wrapper(query, criterion, negated):
|
def wrapper(query, criterion, negated):
|
||||||
if not getattr(criterion, 'internal', False):
|
if not getattr(criterion, 'internal', False):
|
||||||
raise errors.SearchError(
|
raise errors.SearchError(
|
||||||
'Votes cannot be seen publicly. Did you mean %r?' \
|
'Votes cannot be seen publicly. Did you mean %r?'
|
||||||
% 'special:liked')
|
% 'special:liked')
|
||||||
user_alias = aliased(db.User)
|
user_alias = aliased(db.User)
|
||||||
score_alias = aliased(db.PostScore)
|
score_alias = aliased(db.PostScore)
|
||||||
|
@ -57,6 +61,7 @@ def _create_score_filter(score):
|
||||||
return ret
|
return ret
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class PostSearchConfig(BaseSearchConfig):
|
class PostSearchConfig(BaseSearchConfig):
|
||||||
def on_search_query_parsed(self, search_query):
|
def on_search_query_parsed(self, search_query):
|
||||||
new_special_tokens = []
|
new_special_tokens = []
|
||||||
|
@ -64,7 +69,8 @@ class PostSearchConfig(BaseSearchConfig):
|
||||||
if token.value in ('fav', 'liked', 'disliked'):
|
if token.value in ('fav', 'liked', 'disliked'):
|
||||||
assert self.user
|
assert self.user
|
||||||
if self.user.rank == 'anonymous':
|
if self.user.rank == 'anonymous':
|
||||||
raise errors.SearchError('Must be logged in to use this feature.')
|
raise errors.SearchError(
|
||||||
|
'Must be logged in to use this feature.')
|
||||||
criterion = criteria.PlainCriterion(
|
criterion = criteria.PlainCriterion(
|
||||||
original_text=self.user.name,
|
original_text=self.user.name,
|
||||||
value=self.user.name)
|
value=self.user.name)
|
||||||
|
@ -85,9 +91,9 @@ class PostSearchConfig(BaseSearchConfig):
|
||||||
return self.create_count_query() \
|
return self.create_count_query() \
|
||||||
.options(
|
.options(
|
||||||
# use config optimized for official client
|
# use config optimized for official client
|
||||||
#defer(db.Post.score),
|
# defer(db.Post.score),
|
||||||
#defer(db.Post.favorite_count),
|
# defer(db.Post.favorite_count),
|
||||||
#defer(db.Post.comment_count),
|
# defer(db.Post.comment_count),
|
||||||
defer(db.Post.last_favorite_time),
|
defer(db.Post.last_favorite_time),
|
||||||
defer(db.Post.feature_count),
|
defer(db.Post.feature_count),
|
||||||
defer(db.Post.last_feature_time),
|
defer(db.Post.last_feature_time),
|
||||||
|
@ -99,8 +105,7 @@ class PostSearchConfig(BaseSearchConfig):
|
||||||
lazyload(db.Post.user),
|
lazyload(db.Post.user),
|
||||||
lazyload(db.Post.relations),
|
lazyload(db.Post.relations),
|
||||||
lazyload(db.Post.notes),
|
lazyload(db.Post.notes),
|
||||||
lazyload(db.Post.favorited_by),
|
lazyload(db.Post.favorited_by))
|
||||||
)
|
|
||||||
|
|
||||||
def create_count_query(self):
|
def create_count_query(self):
|
||||||
return db.session.query(db.Post)
|
return db.session.query(db.Post)
|
||||||
|
@ -153,12 +158,18 @@ class PostSearchConfig(BaseSearchConfig):
|
||||||
'liked': _create_score_filter(1),
|
'liked': _create_score_filter(1),
|
||||||
'disliked': _create_score_filter(-1),
|
'disliked': _create_score_filter(-1),
|
||||||
'tag-count': search_util.create_num_filter(db.Post.tag_count),
|
'tag-count': search_util.create_num_filter(db.Post.tag_count),
|
||||||
'comment-count': search_util.create_num_filter(db.Post.comment_count),
|
'comment-count':
|
||||||
'fav-count': search_util.create_num_filter(db.Post.favorite_count),
|
search_util.create_num_filter(db.Post.comment_count),
|
||||||
|
'fav-count':
|
||||||
|
search_util.create_num_filter(db.Post.favorite_count),
|
||||||
'note-count': search_util.create_num_filter(db.Post.note_count),
|
'note-count': search_util.create_num_filter(db.Post.note_count),
|
||||||
'relation-count': search_util.create_num_filter(db.Post.relation_count),
|
'relation-count':
|
||||||
'feature-count': search_util.create_num_filter(db.Post.feature_count),
|
search_util.create_num_filter(db.Post.relation_count),
|
||||||
'type': search_util.create_str_filter(db.Post.type, _type_transformer),
|
'feature-count':
|
||||||
|
search_util.create_num_filter(db.Post.feature_count),
|
||||||
|
'type':
|
||||||
|
search_util.create_str_filter(
|
||||||
|
db.Post.type, _type_transformer),
|
||||||
'file-size': search_util.create_num_filter(db.Post.file_size),
|
'file-size': search_util.create_num_filter(db.Post.file_size),
|
||||||
('image-width', 'width'):
|
('image-width', 'width'):
|
||||||
search_util.create_num_filter(db.Post.canvas_width),
|
search_util.create_num_filter(db.Post.canvas_width),
|
||||||
|
@ -171,13 +182,15 @@ class PostSearchConfig(BaseSearchConfig):
|
||||||
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'):
|
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'):
|
||||||
search_util.create_date_filter(db.Post.last_edit_time),
|
search_util.create_date_filter(db.Post.last_edit_time),
|
||||||
('comment-date', 'comment-time'):
|
('comment-date', 'comment-time'):
|
||||||
search_util.create_date_filter(db.Post.last_comment_creation_time),
|
search_util.create_date_filter(
|
||||||
|
db.Post.last_comment_creation_time),
|
||||||
('fav-date', 'fav-time'):
|
('fav-date', 'fav-time'):
|
||||||
search_util.create_date_filter(db.Post.last_favorite_time),
|
search_util.create_date_filter(db.Post.last_favorite_time),
|
||||||
('feature-date', 'feature-time'):
|
('feature-date', 'feature-time'):
|
||||||
search_util.create_date_filter(db.Post.last_feature_time),
|
search_util.create_date_filter(db.Post.last_feature_time),
|
||||||
('safety', 'rating'):
|
('safety', 'rating'):
|
||||||
search_util.create_str_filter(db.Post.safety, _safety_transformer),
|
search_util.create_str_filter(
|
||||||
|
db.Post.safety, _safety_transformer),
|
||||||
})
|
})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -193,9 +206,12 @@ class PostSearchConfig(BaseSearchConfig):
|
||||||
'relation-count': (db.Post.relation_count, self.SORT_DESC),
|
'relation-count': (db.Post.relation_count, self.SORT_DESC),
|
||||||
'feature-count': (db.Post.feature_count, self.SORT_DESC),
|
'feature-count': (db.Post.feature_count, self.SORT_DESC),
|
||||||
'file-size': (db.Post.file_size, self.SORT_DESC),
|
'file-size': (db.Post.file_size, self.SORT_DESC),
|
||||||
('image-width', 'width'): (db.Post.canvas_width, self.SORT_DESC),
|
('image-width', 'width'):
|
||||||
('image-height', 'height'): (db.Post.canvas_height, self.SORT_DESC),
|
(db.Post.canvas_width, self.SORT_DESC),
|
||||||
('image-area', 'area'): (db.Post.canvas_area, self.SORT_DESC),
|
('image-height', 'height'):
|
||||||
|
(db.Post.canvas_height, self.SORT_DESC),
|
||||||
|
('image-area', 'area'):
|
||||||
|
(db.Post.canvas_area, self.SORT_DESC),
|
||||||
('creation-date', 'creation-time', 'date', 'time'):
|
('creation-date', 'creation-time', 'date', 'time'):
|
||||||
(db.Post.creation_time, self.SORT_DESC),
|
(db.Post.creation_time, self.SORT_DESC),
|
||||||
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'):
|
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'):
|
||||||
|
|
|
@ -2,6 +2,7 @@ from szurubooru import db
|
||||||
from szurubooru.search.configs import util as search_util
|
from szurubooru.search.configs import util as search_util
|
||||||
from szurubooru.search.configs.base_search_config import BaseSearchConfig
|
from szurubooru.search.configs.base_search_config import BaseSearchConfig
|
||||||
|
|
||||||
|
|
||||||
class SnapshotSearchConfig(BaseSearchConfig):
|
class SnapshotSearchConfig(BaseSearchConfig):
|
||||||
def create_filter_query(self):
|
def create_filter_query(self):
|
||||||
return db.session.query(db.Snapshot)
|
return db.session.query(db.Snapshot)
|
||||||
|
|
|
@ -5,6 +5,7 @@ from szurubooru.func import util
|
||||||
from szurubooru.search.configs import util as search_util
|
from szurubooru.search.configs import util as search_util
|
||||||
from szurubooru.search.configs.base_search_config import BaseSearchConfig
|
from szurubooru.search.configs.base_search_config import BaseSearchConfig
|
||||||
|
|
||||||
|
|
||||||
class TagSearchConfig(BaseSearchConfig):
|
class TagSearchConfig(BaseSearchConfig):
|
||||||
def create_filter_query(self):
|
def create_filter_query(self):
|
||||||
return self.create_count_query() \
|
return self.create_count_query() \
|
||||||
|
@ -13,8 +14,7 @@ class TagSearchConfig(BaseSearchConfig):
|
||||||
subqueryload(db.Tag.names),
|
subqueryload(db.Tag.names),
|
||||||
subqueryload(db.Tag.category),
|
subqueryload(db.Tag.category),
|
||||||
subqueryload(db.Tag.suggestions).joinedload(db.Tag.names),
|
subqueryload(db.Tag.suggestions).joinedload(db.Tag.names),
|
||||||
subqueryload(db.Tag.implications).joinedload(db.Tag.names)
|
subqueryload(db.Tag.implications).joinedload(db.Tag.names))
|
||||||
)
|
|
||||||
|
|
||||||
def create_count_query(self):
|
def create_count_query(self):
|
||||||
return db.session.query(db.Tag)
|
return db.session.query(db.Tag)
|
||||||
|
|
|
@ -3,6 +3,7 @@ from szurubooru import db
|
||||||
from szurubooru.search.configs import util as search_util
|
from szurubooru.search.configs import util as search_util
|
||||||
from szurubooru.search.configs.base_search_config import BaseSearchConfig
|
from szurubooru.search.configs.base_search_config import BaseSearchConfig
|
||||||
|
|
||||||
|
|
||||||
class UserSearchConfig(BaseSearchConfig):
|
class UserSearchConfig(BaseSearchConfig):
|
||||||
''' Executes searches related to the users. '''
|
''' Executes searches related to the users. '''
|
||||||
|
|
||||||
|
@ -20,12 +21,18 @@ class UserSearchConfig(BaseSearchConfig):
|
||||||
def named_filters(self):
|
def named_filters(self):
|
||||||
return {
|
return {
|
||||||
'name': search_util.create_str_filter(db.User.name),
|
'name': search_util.create_str_filter(db.User.name),
|
||||||
'creation-date': search_util.create_date_filter(db.User.creation_time),
|
'creation-date':
|
||||||
'creation-time': search_util.create_date_filter(db.User.creation_time),
|
search_util.create_date_filter(db.User.creation_time),
|
||||||
'last-login-date': search_util.create_date_filter(db.User.last_login_time),
|
'creation-time':
|
||||||
'last-login-time': search_util.create_date_filter(db.User.last_login_time),
|
search_util.create_date_filter(db.User.creation_time),
|
||||||
'login-date': search_util.create_date_filter(db.User.last_login_time),
|
'last-login-date':
|
||||||
'login-time': search_util.create_date_filter(db.User.last_login_time),
|
search_util.create_date_filter(db.User.last_login_time),
|
||||||
|
'last-login-time':
|
||||||
|
search_util.create_date_filter(db.User.last_login_time),
|
||||||
|
'login-date':
|
||||||
|
search_util.create_date_filter(db.User.last_login_time),
|
||||||
|
'login-time':
|
||||||
|
search_util.create_date_filter(db.User.last_login_time),
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -3,9 +3,11 @@ from szurubooru import db, errors
|
||||||
from szurubooru.func import util
|
from szurubooru.func import util
|
||||||
from szurubooru.search import criteria
|
from szurubooru.search import criteria
|
||||||
|
|
||||||
|
|
||||||
def wildcard_transformer(value):
|
def wildcard_transformer(value):
|
||||||
return value.replace('*', '%')
|
return value.replace('*', '%')
|
||||||
|
|
||||||
|
|
||||||
def apply_num_criterion_to_column(column, criterion):
|
def apply_num_criterion_to_column(column, criterion):
|
||||||
'''
|
'''
|
||||||
Decorate SQLAlchemy filter on given column using supplied criterion.
|
Decorate SQLAlchemy filter on given column using supplied criterion.
|
||||||
|
@ -32,6 +34,7 @@ def apply_num_criterion_to_column(column, criterion):
|
||||||
'Criterion value %r must be a number.' % (criterion,))
|
'Criterion value %r must be a number.' % (criterion,))
|
||||||
return expr
|
return expr
|
||||||
|
|
||||||
|
|
||||||
def create_num_filter(column):
|
def create_num_filter(column):
|
||||||
def wrapper(query, criterion, negated):
|
def wrapper(query, criterion, negated):
|
||||||
expr = apply_num_criterion_to_column(
|
expr = apply_num_criterion_to_column(
|
||||||
|
@ -41,6 +44,7 @@ def create_num_filter(column):
|
||||||
return query.filter(expr)
|
return query.filter(expr)
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def apply_str_criterion_to_column(
|
def apply_str_criterion_to_column(
|
||||||
column, criterion, transformer=wildcard_transformer):
|
column, criterion, transformer=wildcard_transformer):
|
||||||
'''
|
'''
|
||||||
|
@ -59,6 +63,7 @@ def apply_str_criterion_to_column(
|
||||||
assert False
|
assert False
|
||||||
return expr
|
return expr
|
||||||
|
|
||||||
|
|
||||||
def create_str_filter(column, transformer=wildcard_transformer):
|
def create_str_filter(column, transformer=wildcard_transformer):
|
||||||
def wrapper(query, criterion, negated):
|
def wrapper(query, criterion, negated):
|
||||||
expr = apply_str_criterion_to_column(
|
expr = apply_str_criterion_to_column(
|
||||||
|
@ -68,6 +73,7 @@ def create_str_filter(column, transformer=wildcard_transformer):
|
||||||
return query.filter(expr)
|
return query.filter(expr)
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def apply_date_criterion_to_column(column, criterion):
|
def apply_date_criterion_to_column(column, criterion):
|
||||||
'''
|
'''
|
||||||
Decorate SQLAlchemy filter on given column using supplied criterion.
|
Decorate SQLAlchemy filter on given column using supplied criterion.
|
||||||
|
@ -97,6 +103,7 @@ def apply_date_criterion_to_column(column, criterion):
|
||||||
assert False
|
assert False
|
||||||
return expr
|
return expr
|
||||||
|
|
||||||
|
|
||||||
def create_date_filter(column):
|
def create_date_filter(column):
|
||||||
def wrapper(query, criterion, negated):
|
def wrapper(query, criterion, negated):
|
||||||
expr = apply_date_criterion_to_column(
|
expr = apply_date_criterion_to_column(
|
||||||
|
@ -106,6 +113,7 @@ def create_date_filter(column):
|
||||||
return query.filter(expr)
|
return query.filter(expr)
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def create_subquery_filter(
|
def create_subquery_filter(
|
||||||
left_id_column,
|
left_id_column,
|
||||||
right_id_column,
|
right_id_column,
|
||||||
|
@ -113,6 +121,7 @@ def create_subquery_filter(
|
||||||
filter_factory,
|
filter_factory,
|
||||||
subquery_decorator=None):
|
subquery_decorator=None):
|
||||||
filter_func = filter_factory(filter_column)
|
filter_func = filter_factory(filter_column)
|
||||||
|
|
||||||
def wrapper(query, criterion, negated):
|
def wrapper(query, criterion, negated):
|
||||||
subquery = db.session.query(right_id_column.label('foreign_id'))
|
subquery = db.session.query(right_id_column.label('foreign_id'))
|
||||||
if subquery_decorator:
|
if subquery_decorator:
|
||||||
|
@ -121,4 +130,5 @@ def create_subquery_filter(
|
||||||
subquery = filter_func(subquery, criterion, negated)
|
subquery = filter_func(subquery, criterion, negated)
|
||||||
subquery = subquery.subquery('t')
|
subquery = subquery.subquery('t')
|
||||||
return query.filter(left_id_column.in_(subquery))
|
return query.filter(left_id_column.in_(subquery))
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
|
@ -5,6 +5,7 @@ class _BaseCriterion(object):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return self.original_text
|
return self.original_text
|
||||||
|
|
||||||
|
|
||||||
class RangedCriterion(_BaseCriterion):
|
class RangedCriterion(_BaseCriterion):
|
||||||
def __init__(self, original_text, min_value, max_value):
|
def __init__(self, original_text, min_value, max_value):
|
||||||
super().__init__(original_text)
|
super().__init__(original_text)
|
||||||
|
@ -14,6 +15,7 @@ class RangedCriterion(_BaseCriterion):
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(('range', self.min_value, self.max_value))
|
return hash(('range', self.min_value, self.max_value))
|
||||||
|
|
||||||
|
|
||||||
class PlainCriterion(_BaseCriterion):
|
class PlainCriterion(_BaseCriterion):
|
||||||
def __init__(self, original_text, value):
|
def __init__(self, original_text, value):
|
||||||
super().__init__(original_text)
|
super().__init__(original_text)
|
||||||
|
@ -22,6 +24,7 @@ class PlainCriterion(_BaseCriterion):
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash(self.value)
|
return hash(self.value)
|
||||||
|
|
||||||
|
|
||||||
class ArrayCriterion(_BaseCriterion):
|
class ArrayCriterion(_BaseCriterion):
|
||||||
def __init__(self, original_text, values):
|
def __init__(self, original_text, values):
|
||||||
super().__init__(original_text)
|
super().__init__(original_text)
|
||||||
|
|
|
@ -3,19 +3,22 @@ from szurubooru import db, errors
|
||||||
from szurubooru.func import cache
|
from szurubooru.func import cache
|
||||||
from szurubooru.search import tokens, parser
|
from szurubooru.search import tokens, parser
|
||||||
|
|
||||||
|
|
||||||
def _format_dict_keys(source):
|
def _format_dict_keys(source):
|
||||||
return list(sorted(source.keys()))
|
return list(sorted(source.keys()))
|
||||||
|
|
||||||
def _get_direction(direction, default_direction):
|
|
||||||
if direction == tokens.SortToken.SORT_DEFAULT:
|
def _get_order(order, default_order):
|
||||||
return default_direction
|
if order == tokens.SortToken.SORT_DEFAULT:
|
||||||
if direction == tokens.SortToken.SORT_NEGATED_DEFAULT:
|
return default_order
|
||||||
if default_direction == tokens.SortToken.SORT_ASC:
|
if order == tokens.SortToken.SORT_NEGATED_DEFAULT:
|
||||||
|
if default_order == tokens.SortToken.SORT_ASC:
|
||||||
return tokens.SortToken.SORT_DESC
|
return tokens.SortToken.SORT_DESC
|
||||||
elif default_direction == tokens.SortToken.SORT_DESC:
|
elif default_order == tokens.SortToken.SORT_DESC:
|
||||||
return tokens.SortToken.SORT_ASC
|
return tokens.SortToken.SORT_ASC
|
||||||
assert False
|
assert False
|
||||||
return direction
|
return order
|
||||||
|
|
||||||
|
|
||||||
class Executor(object):
|
class Executor(object):
|
||||||
'''
|
'''
|
||||||
|
@ -30,20 +33,26 @@ class Executor(object):
|
||||||
def get_around(self, query_text, entity_id):
|
def get_around(self, query_text, entity_id):
|
||||||
search_query = self.parser.parse(query_text)
|
search_query = self.parser.parse(query_text)
|
||||||
self.config.on_search_query_parsed(search_query)
|
self.config.on_search_query_parsed(search_query)
|
||||||
filter_query = self.config \
|
filter_query = (
|
||||||
.create_around_query() \
|
self.config
|
||||||
.options(sqlalchemy.orm.lazyload('*'))
|
.create_around_query()
|
||||||
filter_query = self._prepare_db_query(filter_query, search_query, False)
|
.options(sqlalchemy.orm.lazyload('*')))
|
||||||
prev_filter_query = filter_query \
|
filter_query = self._prepare_db_query(
|
||||||
.filter(self.config.id_column < entity_id) \
|
filter_query, search_query, False)
|
||||||
.order_by(None) \
|
prev_filter_query = (
|
||||||
.order_by(sqlalchemy.func.abs(self.config.id_column - entity_id).asc()) \
|
filter_query
|
||||||
.limit(1)
|
.filter(self.config.id_column < entity_id)
|
||||||
next_filter_query = filter_query \
|
.order_by(None)
|
||||||
.filter(self.config.id_column > entity_id) \
|
.order_by(sqlalchemy.func.abs(
|
||||||
.order_by(None) \
|
self.config.id_column - entity_id).asc())
|
||||||
.order_by(sqlalchemy.func.abs(self.config.id_column - entity_id).asc()) \
|
.limit(1))
|
||||||
.limit(1)
|
next_filter_query = (
|
||||||
|
filter_query
|
||||||
|
.filter(self.config.id_column > entity_id)
|
||||||
|
.order_by(None)
|
||||||
|
.order_by(sqlalchemy.func.abs(
|
||||||
|
self.config.id_column - entity_id).asc())
|
||||||
|
.limit(1))
|
||||||
return [
|
return [
|
||||||
next_filter_query.one_or_none(),
|
next_filter_query.one_or_none(),
|
||||||
prev_filter_query.one_or_none()]
|
prev_filter_query.one_or_none()]
|
||||||
|
@ -92,7 +101,8 @@ class Executor(object):
|
||||||
def execute_and_serialize(self, ctx, serializer):
|
def execute_and_serialize(self, ctx, serializer):
|
||||||
query = ctx.get_param_as_string('query')
|
query = ctx.get_param_as_string('query')
|
||||||
page = ctx.get_param_as_int('page', default=1, min=1)
|
page = ctx.get_param_as_int('page', default=1, min=1)
|
||||||
page_size = ctx.get_param_as_int('pageSize', default=100, min=1, max=100)
|
page_size = ctx.get_param_as_int(
|
||||||
|
'pageSize', default=100, min=1, max=100)
|
||||||
count, entities = self.execute(query, page, page_size)
|
count, entities = self.execute(query, page, page_size)
|
||||||
return {
|
return {
|
||||||
'query': query,
|
'query': query,
|
||||||
|
@ -124,7 +134,8 @@ class Executor(object):
|
||||||
for token in search_query.special_tokens:
|
for token in search_query.special_tokens:
|
||||||
if token.value not in self.config.special_filters:
|
if token.value not in self.config.special_filters:
|
||||||
raise errors.SearchError(
|
raise errors.SearchError(
|
||||||
'Unknown special token: %r. Available special tokens: %r.' % (
|
'Unknown special token: %r. '
|
||||||
|
'Available special tokens: %r.' % (
|
||||||
token.value,
|
token.value,
|
||||||
_format_dict_keys(self.config.special_filters)))
|
_format_dict_keys(self.config.special_filters)))
|
||||||
db_query = self.config.special_filters[token.value](
|
db_query = self.config.special_filters[token.value](
|
||||||
|
@ -134,14 +145,15 @@ class Executor(object):
|
||||||
for token in search_query.sort_tokens:
|
for token in search_query.sort_tokens:
|
||||||
if token.name not in self.config.sort_columns:
|
if token.name not in self.config.sort_columns:
|
||||||
raise errors.SearchError(
|
raise errors.SearchError(
|
||||||
'Unknown sort token: %r. Available sort tokens: %r.' % (
|
'Unknown sort token: %r. '
|
||||||
|
'Available sort tokens: %r.' % (
|
||||||
token.name,
|
token.name,
|
||||||
_format_dict_keys(self.config.sort_columns)))
|
_format_dict_keys(self.config.sort_columns)))
|
||||||
column, default_direction = self.config.sort_columns[token.name]
|
column, default_order = self.config.sort_columns[token.name]
|
||||||
direction = _get_direction(token.direction, default_direction)
|
order = _get_order(token.order, default_order)
|
||||||
if direction == token.SORT_ASC:
|
if order == token.SORT_ASC:
|
||||||
db_query = db_query.order_by(column.asc())
|
db_query = db_query.order_by(column.asc())
|
||||||
elif direction == token.SORT_DESC:
|
elif order == token.SORT_DESC:
|
||||||
db_query = db_query.order_by(column.desc())
|
db_query = db_query.order_by(column.desc())
|
||||||
|
|
||||||
db_query = self.config.finalize_query(db_query)
|
db_query = self.config.finalize_query(db_query)
|
||||||
|
|
|
@ -2,6 +2,7 @@ import re
|
||||||
from szurubooru import errors
|
from szurubooru import errors
|
||||||
from szurubooru.search import criteria, tokens
|
from szurubooru.search import criteria, tokens
|
||||||
|
|
||||||
|
|
||||||
def _create_criterion(original_value, value):
|
def _create_criterion(original_value, value):
|
||||||
if '..' in value:
|
if '..' in value:
|
||||||
low, high = value.split('..', 1)
|
low, high = value.split('..', 1)
|
||||||
|
@ -13,10 +14,12 @@ def _create_criterion(original_value, value):
|
||||||
original_value, value.split(','))
|
original_value, value.split(','))
|
||||||
return criteria.PlainCriterion(original_value, value)
|
return criteria.PlainCriterion(original_value, value)
|
||||||
|
|
||||||
|
|
||||||
def _parse_anonymous(value, negated):
|
def _parse_anonymous(value, negated):
|
||||||
criterion = _create_criterion(value, value)
|
criterion = _create_criterion(value, value)
|
||||||
return tokens.AnonymousToken(criterion, negated)
|
return tokens.AnonymousToken(criterion, negated)
|
||||||
|
|
||||||
|
|
||||||
def _parse_named(key, value, negated):
|
def _parse_named(key, value, negated):
|
||||||
original_value = value
|
original_value = value
|
||||||
if key.endswith('-min'):
|
if key.endswith('-min'):
|
||||||
|
@ -28,34 +31,41 @@ def _parse_named(key, value, negated):
|
||||||
criterion = _create_criterion(original_value, value)
|
criterion = _create_criterion(original_value, value)
|
||||||
return tokens.NamedToken(key, criterion, negated)
|
return tokens.NamedToken(key, criterion, negated)
|
||||||
|
|
||||||
|
|
||||||
def _parse_special(value, negated):
|
def _parse_special(value, negated):
|
||||||
return tokens.SpecialToken(value, negated)
|
return tokens.SpecialToken(value, negated)
|
||||||
|
|
||||||
|
|
||||||
def _parse_sort(value, negated):
|
def _parse_sort(value, negated):
|
||||||
if value.count(',') == 0:
|
if value.count(',') == 0:
|
||||||
direction_str = None
|
order_str = None
|
||||||
elif value.count(',') == 1:
|
elif value.count(',') == 1:
|
||||||
value, direction_str = value.split(',')
|
value, order_str = value.split(',')
|
||||||
else:
|
else:
|
||||||
raise errors.SearchError('Too many commas in sort style token.')
|
raise errors.SearchError('Too many commas in sort style token.')
|
||||||
try:
|
try:
|
||||||
direction = {
|
order = {
|
||||||
'asc': tokens.SortToken.SORT_ASC,
|
'asc': tokens.SortToken.SORT_ASC,
|
||||||
'desc': tokens.SortToken.SORT_DESC,
|
'desc': tokens.SortToken.SORT_DESC,
|
||||||
'': tokens.SortToken.SORT_DEFAULT,
|
'': tokens.SortToken.SORT_DEFAULT,
|
||||||
None: tokens.SortToken.SORT_DEFAULT,
|
None: tokens.SortToken.SORT_DEFAULT,
|
||||||
}[direction_str]
|
}[order_str]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise errors.SearchError(
|
raise errors.SearchError(
|
||||||
'Unknown search direction: %r.' % direction_str)
|
'Unknown search direction: %r.' % order_str)
|
||||||
if negated:
|
if negated:
|
||||||
direction = {
|
order = {
|
||||||
tokens.SortToken.SORT_ASC: tokens.SortToken.SORT_DESC,
|
tokens.SortToken.SORT_ASC:
|
||||||
tokens.SortToken.SORT_DESC: tokens.SortToken.SORT_ASC,
|
tokens.SortToken.SORT_DESC,
|
||||||
tokens.SortToken.SORT_DEFAULT: tokens.SortToken.SORT_NEGATED_DEFAULT,
|
tokens.SortToken.SORT_DESC:
|
||||||
tokens.SortToken.SORT_NEGATED_DEFAULT: tokens.SortToken.SORT_DEFAULT,
|
tokens.SortToken.SORT_ASC,
|
||||||
}[direction]
|
tokens.SortToken.SORT_DEFAULT:
|
||||||
return tokens.SortToken(value, direction)
|
tokens.SortToken.SORT_NEGATED_DEFAULT,
|
||||||
|
tokens.SortToken.SORT_NEGATED_DEFAULT:
|
||||||
|
tokens.SortToken.SORT_DEFAULT,
|
||||||
|
}[order]
|
||||||
|
return tokens.SortToken(value, order)
|
||||||
|
|
||||||
|
|
||||||
class SearchQuery():
|
class SearchQuery():
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -71,6 +81,7 @@ class SearchQuery():
|
||||||
tuple(self.special_tokens),
|
tuple(self.special_tokens),
|
||||||
tuple(self.sort_tokens)))
|
tuple(self.sort_tokens)))
|
||||||
|
|
||||||
|
|
||||||
class Parser(object):
|
class Parser(object):
|
||||||
def parse(self, query_text):
|
def parse(self, query_text):
|
||||||
query = SearchQuery()
|
query = SearchQuery()
|
||||||
|
@ -93,5 +104,6 @@ class Parser(object):
|
||||||
query.named_tokens.append(
|
query.named_tokens.append(
|
||||||
_parse_named(key, value, negated))
|
_parse_named(key, value, negated))
|
||||||
else:
|
else:
|
||||||
query.anonymous_tokens.append(_parse_anonymous(chunk, negated))
|
query.anonymous_tokens.append(
|
||||||
|
_parse_anonymous(chunk, negated))
|
||||||
return query
|
return query
|
||||||
|
|
|
@ -6,6 +6,7 @@ class AnonymousToken(object):
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash((self.criterion, self.negated))
|
return hash((self.criterion, self.negated))
|
||||||
|
|
||||||
|
|
||||||
class NamedToken(AnonymousToken):
|
class NamedToken(AnonymousToken):
|
||||||
def __init__(self, name, criterion, negated):
|
def __init__(self, name, criterion, negated):
|
||||||
super().__init__(criterion, negated)
|
super().__init__(criterion, negated)
|
||||||
|
@ -14,18 +15,20 @@ class NamedToken(AnonymousToken):
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash((self.name, self.criterion, self.negated))
|
return hash((self.name, self.criterion, self.negated))
|
||||||
|
|
||||||
|
|
||||||
class SortToken(object):
|
class SortToken(object):
|
||||||
SORT_DESC = 'desc'
|
SORT_DESC = 'desc'
|
||||||
SORT_ASC = 'asc'
|
SORT_ASC = 'asc'
|
||||||
SORT_DEFAULT = 'default'
|
SORT_DEFAULT = 'default'
|
||||||
SORT_NEGATED_DEFAULT = 'negated default'
|
SORT_NEGATED_DEFAULT = 'negated default'
|
||||||
|
|
||||||
def __init__(self, name, direction):
|
def __init__(self, name, order):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.direction = direction
|
self.order = order
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash((self.name, self.direction))
|
return hash((self.name, self.order))
|
||||||
|
|
||||||
|
|
||||||
class SpecialToken(object):
|
class SpecialToken(object):
|
||||||
def __init__(self, value, negated):
|
def __init__(self, value, negated):
|
||||||
|
|
0
server/szurubooru/tests/__init__.py
Normal file
0
server/szurubooru/tests/__init__.py
Normal file
0
server/szurubooru/tests/api/__init__.py
Normal file
0
server/szurubooru/tests/api/__init__.py
Normal file
|
@ -1,20 +1,22 @@
|
||||||
import pytest
|
|
||||||
import unittest.mock
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from unittest.mock import patch
|
||||||
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import comments, posts
|
from szurubooru.func import comments, posts
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'comments:create': db.User.RANK_REGULAR}})
|
config_injector({'privileges': {'comments:create': db.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_creating_comment(
|
def test_creating_comment(
|
||||||
user_factory, post_factory, context_factory, fake_datetime):
|
user_factory, post_factory, context_factory, fake_datetime):
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
db.session.add_all([post, user])
|
db.session.add_all([post, user])
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'), \
|
with patch('szurubooru.func.comments.serialize_comment'), \
|
||||||
fake_datetime('1997-01-01'):
|
fake_datetime('1997-01-01'):
|
||||||
comments.serialize_comment.return_value = 'serialized comment'
|
comments.serialize_comment.return_value = 'serialized comment'
|
||||||
result = api.comment_api.create_comment(
|
result = api.comment_api.create_comment(
|
||||||
|
@ -29,6 +31,7 @@ def test_creating_comment(
|
||||||
assert comment.user and comment.user.user_id == user.user_id
|
assert comment.user and comment.user.user_id == user.user_id
|
||||||
assert comment.post and comment.post.post_id == post.post_id
|
assert comment.post and comment.post.post_id == post.post_id
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('params', [
|
@pytest.mark.parametrize('params', [
|
||||||
{'text': None},
|
{'text': None},
|
||||||
{'text': ''},
|
{'text': ''},
|
||||||
|
@ -48,6 +51,7 @@ def test_trying_to_pass_invalid_params(
|
||||||
api.comment_api.create_comment(
|
api.comment_api.create_comment(
|
||||||
context_factory(params=real_params, user=user))
|
context_factory(params=real_params, user=user))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('field', ['text', 'postId'])
|
@pytest.mark.parametrize('field', ['text', 'postId'])
|
||||||
def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
|
def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
|
||||||
params = {
|
params = {
|
||||||
|
@ -61,6 +65,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
|
||||||
params={},
|
params={},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=db.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_comment_non_existing(user_factory, context_factory):
|
def test_trying_to_comment_non_existing(user_factory, context_factory):
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
db.session.add_all([user])
|
db.session.add_all([user])
|
||||||
|
@ -70,6 +75,7 @@ def test_trying_to_comment_non_existing(user_factory, context_factory):
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'text': 'bad', 'postId': 5}, user=user))
|
params={'text': 'bad', 'postId': 5}, user=user))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_create_without_privileges(user_factory, context_factory):
|
def test_trying_to_create_without_privileges(user_factory, context_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.comment_api.create_comment(
|
api.comment_api.create_comment(
|
||||||
|
|
|
@ -2,6 +2,7 @@ import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import comments
|
from szurubooru.func import comments
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
|
@ -11,6 +12,7 @@ def inject_config(config_injector):
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_deleting_own_comment(user_factory, comment_factory, context_factory):
|
def test_deleting_own_comment(user_factory, comment_factory, context_factory):
|
||||||
user = user_factory()
|
user = user_factory()
|
||||||
comment = comment_factory(user=user)
|
comment = comment_factory(user=user)
|
||||||
|
@ -22,6 +24,7 @@ def test_deleting_own_comment(user_factory, comment_factory, context_factory):
|
||||||
assert result == {}
|
assert result == {}
|
||||||
assert db.session.query(db.Comment).count() == 0
|
assert db.session.query(db.Comment).count() == 0
|
||||||
|
|
||||||
|
|
||||||
def test_deleting_someones_else_comment(
|
def test_deleting_someones_else_comment(
|
||||||
user_factory, comment_factory, context_factory):
|
user_factory, comment_factory, context_factory):
|
||||||
user1 = user_factory(rank=db.User.RANK_REGULAR)
|
user1 = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
|
@ -29,11 +32,12 @@ def test_deleting_someones_else_comment(
|
||||||
comment = comment_factory(user=user1)
|
comment = comment_factory(user=user1)
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
result = api.comment_api.delete_comment(
|
api.comment_api.delete_comment(
|
||||||
context_factory(params={'version': 1}, user=user2),
|
context_factory(params={'version': 1}, user=user2),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
assert db.session.query(db.Comment).count() == 0
|
assert db.session.query(db.Comment).count() == 0
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_delete_someones_else_comment_without_privileges(
|
def test_trying_to_delete_someones_else_comment_without_privileges(
|
||||||
user_factory, comment_factory, context_factory):
|
user_factory, comment_factory, context_factory):
|
||||||
user1 = user_factory(rank=db.User.RANK_REGULAR)
|
user1 = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
|
@ -47,6 +51,7 @@ def test_trying_to_delete_someones_else_comment_without_privileges(
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
assert db.session.query(db.Comment).count() == 1
|
assert db.session.query(db.Comment).count() == 1
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(comments.CommentNotFoundError):
|
with pytest.raises(comments.CommentNotFoundError):
|
||||||
api.comment_api.delete_comment(
|
api.comment_api.delete_comment(
|
||||||
|
|
|
@ -1,86 +1,92 @@
|
||||||
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import unittest.mock
|
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import comments, scores
|
from szurubooru.func import comments
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'comments:score': db.User.RANK_REGULAR}})
|
config_injector({'privileges': {'comments:score': db.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_simple_rating(
|
def test_simple_rating(
|
||||||
user_factory, comment_factory, context_factory, fake_datetime):
|
user_factory, comment_factory, context_factory, fake_datetime):
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
comment = comment_factory(user=user)
|
comment = comment_factory(user=user)
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
|
with patch('szurubooru.func.comments.serialize_comment'), \
|
||||||
|
fake_datetime('1997-12-01'):
|
||||||
comments.serialize_comment.return_value = 'serialized comment'
|
comments.serialize_comment.return_value = 'serialized comment'
|
||||||
with fake_datetime('1997-12-01'):
|
result = api.comment_api.set_comment_score(
|
||||||
result = api.comment_api.set_comment_score(
|
context_factory(params={'score': 1}, user=user),
|
||||||
context_factory(params={'score': 1}, user=user),
|
{'comment_id': comment.comment_id})
|
||||||
{'comment_id': comment.comment_id})
|
|
||||||
assert result == 'serialized comment'
|
assert result == 'serialized comment'
|
||||||
assert db.session.query(db.CommentScore).count() == 1
|
assert db.session.query(db.CommentScore).count() == 1
|
||||||
assert comment is not None
|
assert comment is not None
|
||||||
assert comment.score == 1
|
assert comment.score == 1
|
||||||
|
|
||||||
|
|
||||||
def test_updating_rating(
|
def test_updating_rating(
|
||||||
user_factory, comment_factory, context_factory, fake_datetime):
|
user_factory, comment_factory, context_factory, fake_datetime):
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
comment = comment_factory(user=user)
|
comment = comment_factory(user=user)
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
|
with patch('szurubooru.func.comments.serialize_comment'):
|
||||||
with fake_datetime('1997-12-01'):
|
with fake_datetime('1997-12-01'):
|
||||||
result = api.comment_api.set_comment_score(
|
api.comment_api.set_comment_score(
|
||||||
context_factory(params={'score': 1}, user=user),
|
context_factory(params={'score': 1}, user=user),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
with fake_datetime('1997-12-02'):
|
with fake_datetime('1997-12-02'):
|
||||||
result = api.comment_api.set_comment_score(
|
api.comment_api.set_comment_score(
|
||||||
context_factory(params={'score': -1}, user=user),
|
context_factory(params={'score': -1}, user=user),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
comment = db.session.query(db.Comment).one()
|
comment = db.session.query(db.Comment).one()
|
||||||
assert db.session.query(db.CommentScore).count() == 1
|
assert db.session.query(db.CommentScore).count() == 1
|
||||||
assert comment.score == -1
|
assert comment.score == -1
|
||||||
|
|
||||||
|
|
||||||
def test_updating_rating_to_zero(
|
def test_updating_rating_to_zero(
|
||||||
user_factory, comment_factory, context_factory, fake_datetime):
|
user_factory, comment_factory, context_factory, fake_datetime):
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
comment = comment_factory(user=user)
|
comment = comment_factory(user=user)
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
|
with patch('szurubooru.func.comments.serialize_comment'):
|
||||||
with fake_datetime('1997-12-01'):
|
with fake_datetime('1997-12-01'):
|
||||||
result = api.comment_api.set_comment_score(
|
api.comment_api.set_comment_score(
|
||||||
context_factory(params={'score': 1}, user=user),
|
context_factory(params={'score': 1}, user=user),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
with fake_datetime('1997-12-02'):
|
with fake_datetime('1997-12-02'):
|
||||||
result = api.comment_api.set_comment_score(
|
api.comment_api.set_comment_score(
|
||||||
context_factory(params={'score': 0}, user=user),
|
context_factory(params={'score': 0}, user=user),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
comment = db.session.query(db.Comment).one()
|
comment = db.session.query(db.Comment).one()
|
||||||
assert db.session.query(db.CommentScore).count() == 0
|
assert db.session.query(db.CommentScore).count() == 0
|
||||||
assert comment.score == 0
|
assert comment.score == 0
|
||||||
|
|
||||||
|
|
||||||
def test_deleting_rating(
|
def test_deleting_rating(
|
||||||
user_factory, comment_factory, context_factory, fake_datetime):
|
user_factory, comment_factory, context_factory, fake_datetime):
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
comment = comment_factory(user=user)
|
comment = comment_factory(user=user)
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
|
with patch('szurubooru.func.comments.serialize_comment'):
|
||||||
with fake_datetime('1997-12-01'):
|
with fake_datetime('1997-12-01'):
|
||||||
result = api.comment_api.set_comment_score(
|
api.comment_api.set_comment_score(
|
||||||
context_factory(params={'score': 1}, user=user),
|
context_factory(params={'score': 1}, user=user),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
with fake_datetime('1997-12-02'):
|
with fake_datetime('1997-12-02'):
|
||||||
result = api.comment_api.delete_comment_score(
|
api.comment_api.delete_comment_score(
|
||||||
context_factory(user=user),
|
context_factory(user=user),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
comment = db.session.query(db.Comment).one()
|
comment = db.session.query(db.Comment).one()
|
||||||
assert db.session.query(db.CommentScore).count() == 0
|
assert db.session.query(db.CommentScore).count() == 0
|
||||||
assert comment.score == 0
|
assert comment.score == 0
|
||||||
|
|
||||||
|
|
||||||
def test_ratings_from_multiple_users(
|
def test_ratings_from_multiple_users(
|
||||||
user_factory, comment_factory, context_factory, fake_datetime):
|
user_factory, comment_factory, context_factory, fake_datetime):
|
||||||
user1 = user_factory(rank=db.User.RANK_REGULAR)
|
user1 = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
|
@ -88,19 +94,20 @@ def test_ratings_from_multiple_users(
|
||||||
comment = comment_factory()
|
comment = comment_factory()
|
||||||
db.session.add_all([user1, user2, comment])
|
db.session.add_all([user1, user2, comment])
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
|
with patch('szurubooru.func.comments.serialize_comment'):
|
||||||
with fake_datetime('1997-12-01'):
|
with fake_datetime('1997-12-01'):
|
||||||
result = api.comment_api.set_comment_score(
|
api.comment_api.set_comment_score(
|
||||||
context_factory(params={'score': 1}, user=user1),
|
context_factory(params={'score': 1}, user=user1),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
with fake_datetime('1997-12-02'):
|
with fake_datetime('1997-12-02'):
|
||||||
result = api.comment_api.set_comment_score(
|
api.comment_api.set_comment_score(
|
||||||
context_factory(params={'score': -1}, user=user2),
|
context_factory(params={'score': -1}, user=user2),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
comment = db.session.query(db.Comment).one()
|
comment = db.session.query(db.Comment).one()
|
||||||
assert db.session.query(db.CommentScore).count() == 2
|
assert db.session.query(db.CommentScore).count() == 2
|
||||||
assert comment.score == 0
|
assert comment.score == 0
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_omit_mandatory_field(
|
def test_trying_to_omit_mandatory_field(
|
||||||
user_factory, comment_factory, context_factory):
|
user_factory, comment_factory, context_factory):
|
||||||
user = user_factory()
|
user = user_factory()
|
||||||
|
@ -112,8 +119,8 @@ def test_trying_to_omit_mandatory_field(
|
||||||
context_factory(params={}, user=user),
|
context_factory(params={}, user=user),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
|
|
||||||
def test_trying_to_update_non_existing(
|
|
||||||
user_factory, comment_factory, context_factory):
|
def test_trying_to_update_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(comments.CommentNotFoundError):
|
with pytest.raises(comments.CommentNotFoundError):
|
||||||
api.comment_api.set_comment_score(
|
api.comment_api.set_comment_score(
|
||||||
context_factory(
|
context_factory(
|
||||||
|
@ -121,6 +128,7 @@ def test_trying_to_update_non_existing(
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'comment_id': 5})
|
{'comment_id': 5})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_rate_without_privileges(
|
def test_trying_to_rate_without_privileges(
|
||||||
user_factory, comment_factory, context_factory):
|
user_factory, comment_factory, context_factory):
|
||||||
comment = comment_factory()
|
comment = comment_factory()
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import unittest.mock
|
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import comments
|
from szurubooru.func import comments
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
|
@ -12,11 +13,12 @@ def inject_config(config_injector):
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_retrieving_multiple(user_factory, comment_factory, context_factory):
|
def test_retrieving_multiple(user_factory, comment_factory, context_factory):
|
||||||
comment1 = comment_factory(text='text 1')
|
comment1 = comment_factory(text='text 1')
|
||||||
comment2 = comment_factory(text='text 2')
|
comment2 = comment_factory(text='text 2')
|
||||||
db.session.add_all([comment1, comment2])
|
db.session.add_all([comment1, comment2])
|
||||||
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
|
with patch('szurubooru.func.comments.serialize_comment'):
|
||||||
comments.serialize_comment.return_value = 'serialized comment'
|
comments.serialize_comment.return_value = 'serialized comment'
|
||||||
result = api.comment_api.get_comments(
|
result = api.comment_api.get_comments(
|
||||||
context_factory(
|
context_factory(
|
||||||
|
@ -30,6 +32,7 @@ def test_retrieving_multiple(user_factory, comment_factory, context_factory):
|
||||||
'results': ['serialized comment', 'serialized comment'],
|
'results': ['serialized comment', 'serialized comment'],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_retrieve_multiple_without_privileges(
|
def test_trying_to_retrieve_multiple_without_privileges(
|
||||||
user_factory, context_factory):
|
user_factory, context_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
|
@ -38,11 +41,12 @@ def test_trying_to_retrieve_multiple_without_privileges(
|
||||||
params={'query': '', 'page': 1},
|
params={'query': '', 'page': 1},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
||||||
|
|
||||||
|
|
||||||
def test_retrieving_single(user_factory, comment_factory, context_factory):
|
def test_retrieving_single(user_factory, comment_factory, context_factory):
|
||||||
comment = comment_factory(text='dummy text')
|
comment = comment_factory(text='dummy text')
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
|
with patch('szurubooru.func.comments.serialize_comment'):
|
||||||
comments.serialize_comment.return_value = 'serialized comment'
|
comments.serialize_comment.return_value = 'serialized comment'
|
||||||
result = api.comment_api.get_comment(
|
result = api.comment_api.get_comment(
|
||||||
context_factory(
|
context_factory(
|
||||||
|
@ -50,6 +54,7 @@ def test_retrieving_single(user_factory, comment_factory, context_factory):
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
assert result == 'serialized comment'
|
assert result == 'serialized comment'
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(comments.CommentNotFoundError):
|
with pytest.raises(comments.CommentNotFoundError):
|
||||||
api.comment_api.get_comment(
|
api.comment_api.get_comment(
|
||||||
|
@ -57,6 +62,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'comment_id': 5})
|
{'comment_id': 5})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_retrieve_single_without_privileges(
|
def test_trying_to_retrieve_single_without_privileges(
|
||||||
user_factory, context_factory):
|
user_factory, context_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
import pytest
|
|
||||||
import unittest.mock
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from unittest.mock import patch
|
||||||
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import comments
|
from szurubooru.func import comments
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
|
@ -13,13 +14,14 @@ def inject_config(config_injector):
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_simple_updating(
|
def test_simple_updating(
|
||||||
user_factory, comment_factory, context_factory, fake_datetime):
|
user_factory, comment_factory, context_factory, fake_datetime):
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
comment = comment_factory(user=user)
|
comment = comment_factory(user=user)
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'), \
|
with patch('szurubooru.func.comments.serialize_comment'), \
|
||||||
fake_datetime('1997-12-01'):
|
fake_datetime('1997-12-01'):
|
||||||
comments.serialize_comment.return_value = 'serialized comment'
|
comments.serialize_comment.return_value = 'serialized comment'
|
||||||
result = api.comment_api.update_comment(
|
result = api.comment_api.update_comment(
|
||||||
|
@ -29,6 +31,7 @@ def test_simple_updating(
|
||||||
assert result == 'serialized comment'
|
assert result == 'serialized comment'
|
||||||
assert comment.last_edit_time == datetime(1997, 12, 1)
|
assert comment.last_edit_time == datetime(1997, 12, 1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('params,expected_exception', [
|
@pytest.mark.parametrize('params,expected_exception', [
|
||||||
({'text': None}, comments.EmptyCommentTextError),
|
({'text': None}, comments.EmptyCommentTextError),
|
||||||
({'text': ''}, comments.EmptyCommentTextError),
|
({'text': ''}, comments.EmptyCommentTextError),
|
||||||
|
@ -37,7 +40,11 @@ def test_simple_updating(
|
||||||
({'text': ['']}, comments.EmptyCommentTextError),
|
({'text': ['']}, comments.EmptyCommentTextError),
|
||||||
])
|
])
|
||||||
def test_trying_to_pass_invalid_params(
|
def test_trying_to_pass_invalid_params(
|
||||||
user_factory, comment_factory, context_factory, params, expected_exception):
|
user_factory,
|
||||||
|
comment_factory,
|
||||||
|
context_factory,
|
||||||
|
params,
|
||||||
|
expected_exception):
|
||||||
user = user_factory()
|
user = user_factory()
|
||||||
comment = comment_factory(user=user)
|
comment = comment_factory(user=user)
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
|
@ -48,6 +55,7 @@ def test_trying_to_pass_invalid_params(
|
||||||
params={**params, **{'version': 1}}, user=user),
|
params={**params, **{'version': 1}}, user=user),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_omit_mandatory_field(
|
def test_trying_to_omit_mandatory_field(
|
||||||
user_factory, comment_factory, context_factory):
|
user_factory, comment_factory, context_factory):
|
||||||
user = user_factory()
|
user = user_factory()
|
||||||
|
@ -59,6 +67,7 @@ def test_trying_to_omit_mandatory_field(
|
||||||
context_factory(params={'version': 1}, user=user),
|
context_factory(params={'version': 1}, user=user),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_update_non_existing(user_factory, context_factory):
|
def test_trying_to_update_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(comments.CommentNotFoundError):
|
with pytest.raises(comments.CommentNotFoundError):
|
||||||
api.comment_api.update_comment(
|
api.comment_api.update_comment(
|
||||||
|
@ -67,6 +76,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'comment_id': 5})
|
{'comment_id': 5})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_update_someones_comment_without_privileges(
|
def test_trying_to_update_someones_comment_without_privileges(
|
||||||
user_factory, comment_factory, context_factory):
|
user_factory, comment_factory, context_factory):
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
|
@ -80,6 +90,7 @@ def test_trying_to_update_someones_comment_without_privileges(
|
||||||
params={'text': 'new text', 'version': 1}, user=user2),
|
params={'text': 'new text', 'version': 1}, user=user2),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
|
|
||||||
|
|
||||||
def test_updating_someones_comment_with_privileges(
|
def test_updating_someones_comment_with_privileges(
|
||||||
user_factory, comment_factory, context_factory):
|
user_factory, comment_factory, context_factory):
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
|
@ -87,7 +98,7 @@ def test_updating_someones_comment_with_privileges(
|
||||||
comment = comment_factory(user=user)
|
comment = comment_factory(user=user)
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
|
with patch('szurubooru.func.comments.serialize_comment'):
|
||||||
api.comment_api.update_comment(
|
api.comment_api.update_comment(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'text': 'new text', 'version': 1}, user=user2),
|
params={'text': 'new text', 'version': 1}, user=user2),
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from szurubooru import api, db
|
from szurubooru import api, db
|
||||||
|
|
||||||
|
|
||||||
def test_info_api(
|
def test_info_api(
|
||||||
tmpdir, config_injector, context_factory, post_factory, fake_datetime):
|
tmpdir, config_injector, context_factory, post_factory, fake_datetime):
|
||||||
directory = tmpdir.mkdir('data')
|
directory = tmpdir.mkdir('data')
|
||||||
|
@ -45,7 +46,7 @@ def test_info_api(
|
||||||
with fake_datetime('2016-01-01 13:59'):
|
with fake_datetime('2016-01-01 13:59'):
|
||||||
assert api.info_api.get_info(context_factory()) == {
|
assert api.info_api.get_info(context_factory()) == {
|
||||||
'postCount': 2,
|
'postCount': 2,
|
||||||
'diskUsage': 3, # still 3 - it's cached
|
'diskUsage': 3, # still 3 - it's cached
|
||||||
'featuredPost': None,
|
'featuredPost': None,
|
||||||
'featuringTime': None,
|
'featuringTime': None,
|
||||||
'featuringUser': None,
|
'featuringUser': None,
|
||||||
|
@ -55,7 +56,7 @@ def test_info_api(
|
||||||
with fake_datetime('2016-01-01 14:01'):
|
with fake_datetime('2016-01-01 14:01'):
|
||||||
assert api.info_api.get_info(context_factory()) == {
|
assert api.info_api.get_info(context_factory()) == {
|
||||||
'postCount': 2,
|
'postCount': 2,
|
||||||
'diskUsage': 6, # cache expired
|
'diskUsage': 6, # cache expired
|
||||||
'featuredPost': None,
|
'featuredPost': None,
|
||||||
'featuringTime': None,
|
'featuringTime': None,
|
||||||
'featuringUser': None,
|
'featuringUser': None,
|
||||||
|
|
|
@ -1,21 +1,23 @@
|
||||||
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import unittest.mock
|
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import auth, mailer
|
from szurubooru.func import auth, mailer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(tmpdir, config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'secret': 'x',
|
'secret': 'x',
|
||||||
'base_url': 'http://example.com/',
|
'base_url': 'http://example.com/',
|
||||||
'name': 'Test instance',
|
'name': 'Test instance',
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_reset_sending_email(context_factory, user_factory):
|
def test_reset_sending_email(context_factory, user_factory):
|
||||||
db.session.add(user_factory(
|
db.session.add(user_factory(
|
||||||
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
|
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
|
||||||
for initiating_user in ['u1', 'user@example.com']:
|
for initiating_user in ['u1', 'user@example.com']:
|
||||||
with unittest.mock.patch('szurubooru.func.mailer.send_mail'):
|
with patch('szurubooru.func.mailer.send_mail'):
|
||||||
assert api.password_reset_api.start_password_reset(
|
assert api.password_reset_api.start_password_reset(
|
||||||
context_factory(), {'user_name': initiating_user}) == {}
|
context_factory(), {'user_name': initiating_user}) == {}
|
||||||
mailer.send_mail.assert_called_once_with(
|
mailer.send_mail.assert_called_once_with(
|
||||||
|
@ -27,17 +29,21 @@ def test_reset_sending_email(context_factory, user_factory):
|
||||||
'ink: http://example.com/password-reset/u1:4ac0be176fb36' +
|
'ink: http://example.com/password-reset/u1:4ac0be176fb36' +
|
||||||
'4f13ee6b634c43220e2\nOtherwise, please ignore this email.')
|
'4f13ee6b634c43220e2\nOtherwise, please ignore this email.')
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_reset_non_existing(context_factory):
|
def test_trying_to_reset_non_existing(context_factory):
|
||||||
with pytest.raises(errors.NotFoundError):
|
with pytest.raises(errors.NotFoundError):
|
||||||
api.password_reset_api.start_password_reset(
|
api.password_reset_api.start_password_reset(
|
||||||
context_factory(), {'user_name': 'u1'})
|
context_factory(), {'user_name': 'u1'})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_reset_without_email(context_factory, user_factory):
|
def test_trying_to_reset_without_email(context_factory, user_factory):
|
||||||
db.session.add(user_factory(name='u1', rank=db.User.RANK_REGULAR, email=None))
|
db.session.add(
|
||||||
|
user_factory(name='u1', rank=db.User.RANK_REGULAR, email=None))
|
||||||
with pytest.raises(errors.ValidationError):
|
with pytest.raises(errors.ValidationError):
|
||||||
api.password_reset_api.start_password_reset(
|
api.password_reset_api.start_password_reset(
|
||||||
context_factory(), {'user_name': 'u1'})
|
context_factory(), {'user_name': 'u1'})
|
||||||
|
|
||||||
|
|
||||||
def test_confirming_with_good_token(context_factory, user_factory):
|
def test_confirming_with_good_token(context_factory, user_factory):
|
||||||
user = user_factory(
|
user = user_factory(
|
||||||
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')
|
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')
|
||||||
|
@ -50,11 +56,13 @@ def test_confirming_with_good_token(context_factory, user_factory):
|
||||||
assert user.password_hash != old_hash
|
assert user.password_hash != old_hash
|
||||||
assert auth.is_valid_password(user, result['password']) is True
|
assert auth.is_valid_password(user, result['password']) is True
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_confirm_non_existing(context_factory):
|
def test_trying_to_confirm_non_existing(context_factory):
|
||||||
with pytest.raises(errors.NotFoundError):
|
with pytest.raises(errors.NotFoundError):
|
||||||
api.password_reset_api.finish_password_reset(
|
api.password_reset_api.finish_password_reset(
|
||||||
context_factory(), {'user_name': 'u1'})
|
context_factory(), {'user_name': 'u1'})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_confirm_without_token(context_factory, user_factory):
|
def test_trying_to_confirm_without_token(context_factory, user_factory):
|
||||||
db.session.add(user_factory(
|
db.session.add(user_factory(
|
||||||
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
|
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
|
||||||
|
@ -62,6 +70,7 @@ def test_trying_to_confirm_without_token(context_factory, user_factory):
|
||||||
api.password_reset_api.finish_password_reset(
|
api.password_reset_api.finish_password_reset(
|
||||||
context_factory(params={}), {'user_name': 'u1'})
|
context_factory(params={}), {'user_name': 'u1'})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_confirm_with_bad_token(context_factory, user_factory):
|
def test_trying_to_confirm_with_bad_token(context_factory, user_factory):
|
||||||
db.session.add(user_factory(
|
db.session.add(user_factory(
|
||||||
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
|
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import unittest.mock
|
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import posts, tags, snapshots, net
|
from szurubooru.func import posts, tags, snapshots, net
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
|
@ -13,6 +14,7 @@ def inject_config(config_injector):
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_creating_minimal_posts(
|
def test_creating_minimal_posts(
|
||||||
context_factory, post_factory, user_factory):
|
context_factory, post_factory, user_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
|
@ -20,16 +22,16 @@ def test_creating_minimal_posts(
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
|
||||||
with unittest.mock.patch('szurubooru.func.posts.create_post'), \
|
with patch('szurubooru.func.posts.create_post'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_safety'), \
|
patch('szurubooru.func.posts.update_post_safety'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_source'), \
|
patch('szurubooru.func.posts.update_post_source'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_relations'), \
|
patch('szurubooru.func.posts.update_post_relations'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_notes'), \
|
patch('szurubooru.func.posts.update_post_notes'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \
|
patch('szurubooru.func.posts.update_post_flags'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_thumbnail'), \
|
patch('szurubooru.func.posts.update_post_thumbnail'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
|
patch('szurubooru.func.posts.serialize_post'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \
|
patch('szurubooru.func.tags.export_to_json'), \
|
||||||
unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'):
|
patch('szurubooru.func.snapshots.save_entity_creation'):
|
||||||
posts.create_post.return_value = (post, [])
|
posts.create_post.return_value = (post, [])
|
||||||
posts.serialize_post.return_value = 'serialized post'
|
posts.serialize_post.return_value = 'serialized post'
|
||||||
|
|
||||||
|
@ -48,32 +50,36 @@ def test_creating_minimal_posts(
|
||||||
assert result == 'serialized post'
|
assert result == 'serialized post'
|
||||||
posts.create_post.assert_called_once_with(
|
posts.create_post.assert_called_once_with(
|
||||||
'post-content', ['tag1', 'tag2'], auth_user)
|
'post-content', ['tag1', 'tag2'], auth_user)
|
||||||
posts.update_post_thumbnail.assert_called_once_with(post, 'post-thumbnail')
|
posts.update_post_thumbnail.assert_called_once_with(
|
||||||
|
post, 'post-thumbnail')
|
||||||
posts.update_post_safety.assert_called_once_with(post, 'safe')
|
posts.update_post_safety.assert_called_once_with(post, 'safe')
|
||||||
posts.update_post_source.assert_called_once_with(post, None)
|
posts.update_post_source.assert_called_once_with(post, None)
|
||||||
posts.update_post_relations.assert_called_once_with(post, [])
|
posts.update_post_relations.assert_called_once_with(post, [])
|
||||||
posts.update_post_notes.assert_called_once_with(post, [])
|
posts.update_post_notes.assert_called_once_with(post, [])
|
||||||
posts.update_post_flags.assert_called_once_with(post, [])
|
posts.update_post_flags.assert_called_once_with(post, [])
|
||||||
posts.update_post_thumbnail.assert_called_once_with(post, 'post-thumbnail')
|
posts.update_post_thumbnail.assert_called_once_with(
|
||||||
posts.serialize_post.assert_called_once_with(post, auth_user, options=None)
|
post, 'post-thumbnail')
|
||||||
|
posts.serialize_post.assert_called_once_with(
|
||||||
|
post, auth_user, options=None)
|
||||||
tags.export_to_json.assert_called_once_with()
|
tags.export_to_json.assert_called_once_with()
|
||||||
snapshots.save_entity_creation.assert_called_once_with(post, auth_user)
|
snapshots.save_entity_creation.assert_called_once_with(post, auth_user)
|
||||||
|
|
||||||
|
|
||||||
def test_creating_full_posts(context_factory, post_factory, user_factory):
|
def test_creating_full_posts(context_factory, post_factory, user_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
|
||||||
with unittest.mock.patch('szurubooru.func.posts.create_post'), \
|
with patch('szurubooru.func.posts.create_post'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_safety'), \
|
patch('szurubooru.func.posts.update_post_safety'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_source'), \
|
patch('szurubooru.func.posts.update_post_source'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_relations'), \
|
patch('szurubooru.func.posts.update_post_relations'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_notes'), \
|
patch('szurubooru.func.posts.update_post_notes'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \
|
patch('szurubooru.func.posts.update_post_flags'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
|
patch('szurubooru.func.posts.serialize_post'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \
|
patch('szurubooru.func.tags.export_to_json'), \
|
||||||
unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'):
|
patch('szurubooru.func.snapshots.save_entity_creation'):
|
||||||
posts.create_post.return_value = (post, [])
|
posts.create_post.return_value = (post, [])
|
||||||
posts.serialize_post.return_value = 'serialized post'
|
posts.serialize_post.return_value = 'serialized post'
|
||||||
|
|
||||||
|
@ -98,12 +104,16 @@ def test_creating_full_posts(context_factory, post_factory, user_factory):
|
||||||
posts.update_post_safety.assert_called_once_with(post, 'safe')
|
posts.update_post_safety.assert_called_once_with(post, 'safe')
|
||||||
posts.update_post_source.assert_called_once_with(post, 'source')
|
posts.update_post_source.assert_called_once_with(post, 'source')
|
||||||
posts.update_post_relations.assert_called_once_with(post, [1, 2])
|
posts.update_post_relations.assert_called_once_with(post, [1, 2])
|
||||||
posts.update_post_notes.assert_called_once_with(post, ['note1', 'note2'])
|
posts.update_post_notes.assert_called_once_with(
|
||||||
posts.update_post_flags.assert_called_once_with(post, ['flag1', 'flag2'])
|
post, ['note1', 'note2'])
|
||||||
posts.serialize_post.assert_called_once_with(post, auth_user, options=None)
|
posts.update_post_flags.assert_called_once_with(
|
||||||
|
post, ['flag1', 'flag2'])
|
||||||
|
posts.serialize_post.assert_called_once_with(
|
||||||
|
post, auth_user, options=None)
|
||||||
tags.export_to_json.assert_called_once_with()
|
tags.export_to_json.assert_called_once_with()
|
||||||
snapshots.save_entity_creation.assert_called_once_with(post, auth_user)
|
snapshots.save_entity_creation.assert_called_once_with(post, auth_user)
|
||||||
|
|
||||||
|
|
||||||
def test_anonymous_uploads(
|
def test_anonymous_uploads(
|
||||||
config_injector, context_factory, post_factory, user_factory):
|
config_injector, context_factory, post_factory, user_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
|
@ -111,11 +121,11 @@ def test_anonymous_uploads(
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
|
||||||
with unittest.mock.patch('szurubooru.func.tags.export_to_json'), \
|
with patch('szurubooru.func.tags.export_to_json'), \
|
||||||
unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'), \
|
patch('szurubooru.func.snapshots.save_entity_creation'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
|
patch('szurubooru.func.posts.serialize_post'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.create_post'), \
|
patch('szurubooru.func.posts.create_post'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_source'):
|
patch('szurubooru.func.posts.update_post_source'):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {'posts:create:anonymous': db.User.RANK_REGULAR},
|
'privileges': {'posts:create:anonymous': db.User.RANK_REGULAR},
|
||||||
})
|
})
|
||||||
|
@ -134,6 +144,7 @@ def test_anonymous_uploads(
|
||||||
posts.create_post.assert_called_once_with(
|
posts.create_post.assert_called_once_with(
|
||||||
'post-content', ['tag1', 'tag2'], None)
|
'post-content', ['tag1', 'tag2'], None)
|
||||||
|
|
||||||
|
|
||||||
def test_creating_from_url_saves_source(
|
def test_creating_from_url_saves_source(
|
||||||
config_injector, context_factory, post_factory, user_factory):
|
config_injector, context_factory, post_factory, user_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
|
@ -141,12 +152,12 @@ def test_creating_from_url_saves_source(
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
|
||||||
with unittest.mock.patch('szurubooru.func.net.download'), \
|
with patch('szurubooru.func.net.download'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \
|
patch('szurubooru.func.tags.export_to_json'), \
|
||||||
unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'), \
|
patch('szurubooru.func.snapshots.save_entity_creation'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
|
patch('szurubooru.func.posts.serialize_post'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.create_post'), \
|
patch('szurubooru.func.posts.create_post'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_source'):
|
patch('szurubooru.func.posts.update_post_source'):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {'posts:create:identified': db.User.RANK_REGULAR},
|
'privileges': {'posts:create:identified': db.User.RANK_REGULAR},
|
||||||
})
|
})
|
||||||
|
@ -165,6 +176,7 @@ def test_creating_from_url_saves_source(
|
||||||
b'content', ['tag1', 'tag2'], auth_user)
|
b'content', ['tag1', 'tag2'], auth_user)
|
||||||
posts.update_post_source.assert_called_once_with(post, 'example.com')
|
posts.update_post_source.assert_called_once_with(post, 'example.com')
|
||||||
|
|
||||||
|
|
||||||
def test_creating_from_url_with_source_specified(
|
def test_creating_from_url_with_source_specified(
|
||||||
config_injector, context_factory, post_factory, user_factory):
|
config_injector, context_factory, post_factory, user_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
|
@ -172,12 +184,12 @@ def test_creating_from_url_with_source_specified(
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
|
||||||
with unittest.mock.patch('szurubooru.func.net.download'), \
|
with patch('szurubooru.func.net.download'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \
|
patch('szurubooru.func.tags.export_to_json'), \
|
||||||
unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'), \
|
patch('szurubooru.func.snapshots.save_entity_creation'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
|
patch('szurubooru.func.posts.serialize_post'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.create_post'), \
|
patch('szurubooru.func.posts.create_post'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_source'):
|
patch('szurubooru.func.posts.update_post_source'):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {'posts:create:identified': db.User.RANK_REGULAR},
|
'privileges': {'posts:create:identified': db.User.RANK_REGULAR},
|
||||||
})
|
})
|
||||||
|
@ -197,6 +209,7 @@ def test_creating_from_url_with_source_specified(
|
||||||
b'content', ['tag1', 'tag2'], auth_user)
|
b'content', ['tag1', 'tag2'], auth_user)
|
||||||
posts.update_post_source.assert_called_once_with(post, 'example2.com')
|
posts.update_post_source.assert_called_once_with(post, 'example2.com')
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('field', ['tags', 'safety'])
|
@pytest.mark.parametrize('field', ['tags', 'safety'])
|
||||||
def test_trying_to_omit_mandatory_field(context_factory, user_factory, field):
|
def test_trying_to_omit_mandatory_field(context_factory, user_factory, field):
|
||||||
params = {
|
params = {
|
||||||
|
@ -211,6 +224,7 @@ def test_trying_to_omit_mandatory_field(context_factory, user_factory, field):
|
||||||
files={'content': '...'},
|
files={'content': '...'},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=db.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_omit_content(context_factory, user_factory):
|
def test_trying_to_omit_content(context_factory, user_factory):
|
||||||
with pytest.raises(errors.MissingRequiredFileError):
|
with pytest.raises(errors.MissingRequiredFileError):
|
||||||
api.post_api.create_post(
|
api.post_api.create_post(
|
||||||
|
@ -221,12 +235,15 @@ def test_trying_to_omit_content(context_factory, user_factory):
|
||||||
},
|
},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=db.User.RANK_REGULAR)))
|
||||||
|
|
||||||
def test_trying_to_create_post_without_privileges(context_factory, user_factory):
|
|
||||||
|
def test_trying_to_create_post_without_privileges(
|
||||||
|
context_factory, user_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.post_api.create_post(context_factory(
|
api.post_api.create_post(context_factory(
|
||||||
params='whatever',
|
params='whatever',
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_create_tags_without_privileges(
|
def test_trying_to_create_tags_without_privileges(
|
||||||
config_injector, context_factory, user_factory):
|
config_injector, context_factory, user_factory):
|
||||||
config_injector({
|
config_injector({
|
||||||
|
@ -237,8 +254,8 @@ def test_trying_to_create_tags_without_privileges(
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
with pytest.raises(errors.AuthError), \
|
with pytest.raises(errors.AuthError), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \
|
patch('szurubooru.func.posts.update_post_content'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_tags'):
|
patch('szurubooru.func.posts.update_post_tags'):
|
||||||
posts.update_post_tags.return_value = ['new-tag']
|
posts.update_post_tags.return_value = ['new-tag']
|
||||||
api.post_api.create_post(
|
api.post_api.create_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
|
|
|
@ -1,16 +1,18 @@
|
||||||
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import unittest.mock
|
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import posts, tags
|
from szurubooru.func import posts, tags
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'posts:delete': db.User.RANK_REGULAR}})
|
config_injector({'privileges': {'posts:delete': db.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_deleting(user_factory, post_factory, context_factory):
|
def test_deleting(user_factory, post_factory, context_factory):
|
||||||
db.session.add(post_factory(id=1))
|
db.session.add(post_factory(id=1))
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.tags.export_to_json'):
|
with patch('szurubooru.func.tags.export_to_json'):
|
||||||
result = api.post_api.delete_post(
|
result = api.post_api.delete_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'version': 1},
|
params={'version': 1},
|
||||||
|
@ -20,12 +22,14 @@ def test_deleting(user_factory, post_factory, context_factory):
|
||||||
assert db.session.query(db.Post).count() == 0
|
assert db.session.query(db.Post).count() == 0
|
||||||
tags.export_to_json.assert_called_once_with()
|
tags.export_to_json.assert_called_once_with()
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(posts.PostNotFoundError):
|
with pytest.raises(posts.PostNotFoundError):
|
||||||
api.post_api.delete_post(
|
api.post_api.delete_post(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'post_id': 999})
|
{'post_id': 999})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_delete_without_privileges(
|
def test_trying_to_delete_without_privileges(
|
||||||
user_factory, post_factory, context_factory):
|
user_factory, post_factory, context_factory):
|
||||||
db.session.add(post_factory(id=1))
|
db.session.add(post_factory(id=1))
|
||||||
|
|
|
@ -1,20 +1,22 @@
|
||||||
import pytest
|
|
||||||
import unittest.mock
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from unittest.mock import patch
|
||||||
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import posts
|
from szurubooru.func import posts
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'posts:favorite': db.User.RANK_REGULAR}})
|
config_injector({'privileges': {'posts:favorite': db.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_adding_to_favorites(
|
def test_adding_to_favorites(
|
||||||
user_factory, post_factory, context_factory, fake_datetime):
|
user_factory, post_factory, context_factory, fake_datetime):
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
assert post.score == 0
|
assert post.score == 0
|
||||||
with unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
|
with patch('szurubooru.func.posts.serialize_post'), \
|
||||||
fake_datetime('1997-12-01'):
|
fake_datetime('1997-12-01'):
|
||||||
posts.serialize_post.return_value = 'serialized post'
|
posts.serialize_post.return_value = 'serialized post'
|
||||||
result = api.post_api.add_post_to_favorites(
|
result = api.post_api.add_post_to_favorites(
|
||||||
|
@ -27,6 +29,7 @@ def test_adding_to_favorites(
|
||||||
assert post.favorite_count == 1
|
assert post.favorite_count == 1
|
||||||
assert post.score == 1
|
assert post.score == 1
|
||||||
|
|
||||||
|
|
||||||
def test_removing_from_favorites(
|
def test_removing_from_favorites(
|
||||||
user_factory, post_factory, context_factory, fake_datetime):
|
user_factory, post_factory, context_factory, fake_datetime):
|
||||||
user = user_factory()
|
user = user_factory()
|
||||||
|
@ -34,7 +37,7 @@ def test_removing_from_favorites(
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
assert post.score == 0
|
assert post.score == 0
|
||||||
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
|
with patch('szurubooru.func.posts.serialize_post'):
|
||||||
with fake_datetime('1997-12-01'):
|
with fake_datetime('1997-12-01'):
|
||||||
api.post_api.add_post_to_favorites(
|
api.post_api.add_post_to_favorites(
|
||||||
context_factory(user=user),
|
context_factory(user=user),
|
||||||
|
@ -49,13 +52,14 @@ def test_removing_from_favorites(
|
||||||
assert db.session.query(db.PostFavorite).count() == 0
|
assert db.session.query(db.PostFavorite).count() == 0
|
||||||
assert post.favorite_count == 0
|
assert post.favorite_count == 0
|
||||||
|
|
||||||
|
|
||||||
def test_favoriting_twice(
|
def test_favoriting_twice(
|
||||||
user_factory, post_factory, context_factory, fake_datetime):
|
user_factory, post_factory, context_factory, fake_datetime):
|
||||||
user = user_factory()
|
user = user_factory()
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
|
with patch('szurubooru.func.posts.serialize_post'):
|
||||||
with fake_datetime('1997-12-01'):
|
with fake_datetime('1997-12-01'):
|
||||||
api.post_api.add_post_to_favorites(
|
api.post_api.add_post_to_favorites(
|
||||||
context_factory(user=user),
|
context_factory(user=user),
|
||||||
|
@ -68,13 +72,14 @@ def test_favoriting_twice(
|
||||||
assert db.session.query(db.PostFavorite).count() == 1
|
assert db.session.query(db.PostFavorite).count() == 1
|
||||||
assert post.favorite_count == 1
|
assert post.favorite_count == 1
|
||||||
|
|
||||||
|
|
||||||
def test_removing_twice(
|
def test_removing_twice(
|
||||||
user_factory, post_factory, context_factory, fake_datetime):
|
user_factory, post_factory, context_factory, fake_datetime):
|
||||||
user = user_factory()
|
user = user_factory()
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
|
with patch('szurubooru.func.posts.serialize_post'):
|
||||||
with fake_datetime('1997-12-01'):
|
with fake_datetime('1997-12-01'):
|
||||||
api.post_api.add_post_to_favorites(
|
api.post_api.add_post_to_favorites(
|
||||||
context_factory(user=user),
|
context_factory(user=user),
|
||||||
|
@ -91,6 +96,7 @@ def test_removing_twice(
|
||||||
assert db.session.query(db.PostFavorite).count() == 0
|
assert db.session.query(db.PostFavorite).count() == 0
|
||||||
assert post.favorite_count == 0
|
assert post.favorite_count == 0
|
||||||
|
|
||||||
|
|
||||||
def test_favorites_from_multiple_users(
|
def test_favorites_from_multiple_users(
|
||||||
user_factory, post_factory, context_factory, fake_datetime):
|
user_factory, post_factory, context_factory, fake_datetime):
|
||||||
user1 = user_factory()
|
user1 = user_factory()
|
||||||
|
@ -98,7 +104,7 @@ def test_favorites_from_multiple_users(
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add_all([user1, user2, post])
|
db.session.add_all([user1, user2, post])
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
|
with patch('szurubooru.func.posts.serialize_post'):
|
||||||
with fake_datetime('1997-12-01'):
|
with fake_datetime('1997-12-01'):
|
||||||
api.post_api.add_post_to_favorites(
|
api.post_api.add_post_to_favorites(
|
||||||
context_factory(user=user1),
|
context_factory(user=user1),
|
||||||
|
@ -112,12 +118,14 @@ def test_favorites_from_multiple_users(
|
||||||
assert post.favorite_count == 2
|
assert post.favorite_count == 2
|
||||||
assert post.last_favorite_time == datetime(1997, 12, 2)
|
assert post.last_favorite_time == datetime(1997, 12, 2)
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_update_non_existing(user_factory, context_factory):
|
def test_trying_to_update_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(posts.PostNotFoundError):
|
with pytest.raises(posts.PostNotFoundError):
|
||||||
api.post_api.add_post_to_favorites(
|
api.post_api.add_post_to_favorites(
|
||||||
context_factory(user=user_factory()),
|
context_factory(user=user_factory()),
|
||||||
{'post_id': 5})
|
{'post_id': 5})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_rate_without_privileges(
|
def test_trying_to_rate_without_privileges(
|
||||||
user_factory, post_factory, context_factory):
|
user_factory, post_factory, context_factory):
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import unittest.mock
|
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import posts
|
from szurubooru.func import posts
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
|
@ -12,14 +13,12 @@ def inject_config(config_injector):
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
def test_no_featured_post(user_factory, post_factory, context_factory):
|
|
||||||
assert posts.try_get_featured_post() is None
|
|
||||||
|
|
||||||
def test_featuring(user_factory, post_factory, context_factory):
|
def test_featuring(user_factory, post_factory, context_factory):
|
||||||
db.session.add(post_factory(id=1))
|
db.session.add(post_factory(id=1))
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
assert not posts.get_post_by_id(1).is_featured
|
assert not posts.get_post_by_id(1).is_featured
|
||||||
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
|
with patch('szurubooru.func.posts.serialize_post'):
|
||||||
posts.serialize_post.return_value = 'serialized post'
|
posts.serialize_post.return_value = 'serialized post'
|
||||||
result = api.post_api.set_featured_post(
|
result = api.post_api.set_featured_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
|
@ -34,18 +33,19 @@ def test_featuring(user_factory, post_factory, context_factory):
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=db.User.RANK_REGULAR)))
|
||||||
assert result == 'serialized post'
|
assert result == 'serialized post'
|
||||||
|
|
||||||
def test_trying_to_omit_required_parameter(
|
|
||||||
user_factory, post_factory, context_factory):
|
def test_trying_to_omit_required_parameter(user_factory, context_factory):
|
||||||
with pytest.raises(errors.MissingRequiredParameterError):
|
with pytest.raises(errors.MissingRequiredParameterError):
|
||||||
api.post_api.set_featured_post(
|
api.post_api.set_featured_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=db.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_feature_the_same_post_twice(
|
def test_trying_to_feature_the_same_post_twice(
|
||||||
user_factory, post_factory, context_factory):
|
user_factory, post_factory, context_factory):
|
||||||
db.session.add(post_factory(id=1))
|
db.session.add(post_factory(id=1))
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
|
with patch('szurubooru.func.posts.serialize_post'):
|
||||||
api.post_api.set_featured_post(
|
api.post_api.set_featured_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'id': 1},
|
params={'id': 1},
|
||||||
|
@ -56,6 +56,7 @@ def test_trying_to_feature_the_same_post_twice(
|
||||||
params={'id': 1},
|
params={'id': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=db.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_featuring_one_post_after_another(
|
def test_featuring_one_post_after_another(
|
||||||
user_factory, post_factory, context_factory, fake_datetime):
|
user_factory, post_factory, context_factory, fake_datetime):
|
||||||
db.session.add(post_factory(id=1))
|
db.session.add(post_factory(id=1))
|
||||||
|
@ -64,14 +65,14 @@ def test_featuring_one_post_after_another(
|
||||||
assert posts.try_get_featured_post() is None
|
assert posts.try_get_featured_post() is None
|
||||||
assert not posts.get_post_by_id(1).is_featured
|
assert not posts.get_post_by_id(1).is_featured
|
||||||
assert not posts.get_post_by_id(2).is_featured
|
assert not posts.get_post_by_id(2).is_featured
|
||||||
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
|
with patch('szurubooru.func.posts.serialize_post'):
|
||||||
with fake_datetime('1997'):
|
with fake_datetime('1997'):
|
||||||
result = api.post_api.set_featured_post(
|
api.post_api.set_featured_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'id': 1},
|
params={'id': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=db.User.RANK_REGULAR)))
|
||||||
with fake_datetime('1998'):
|
with fake_datetime('1998'):
|
||||||
result = api.post_api.set_featured_post(
|
api.post_api.set_featured_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'id': 2},
|
params={'id': 2},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=db.User.RANK_REGULAR)))
|
||||||
|
@ -80,6 +81,7 @@ def test_featuring_one_post_after_another(
|
||||||
assert not posts.get_post_by_id(1).is_featured
|
assert not posts.get_post_by_id(1).is_featured
|
||||||
assert posts.get_post_by_id(2).is_featured
|
assert posts.get_post_by_id(2).is_featured
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_feature_non_existing(user_factory, context_factory):
|
def test_trying_to_feature_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(posts.PostNotFoundError):
|
with pytest.raises(posts.PostNotFoundError):
|
||||||
api.post_api.set_featured_post(
|
api.post_api.set_featured_post(
|
||||||
|
@ -87,6 +89,7 @@ def test_trying_to_feature_non_existing(user_factory, context_factory):
|
||||||
params={'id': 1},
|
params={'id': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=db.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_feature_without_privileges(user_factory, context_factory):
|
def test_trying_to_feature_without_privileges(user_factory, context_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.post_api.set_featured_post(
|
api.post_api.set_featured_post(
|
||||||
|
@ -94,6 +97,7 @@ def test_trying_to_feature_without_privileges(user_factory, context_factory):
|
||||||
params={'id': 1},
|
params={'id': 1},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
||||||
|
|
||||||
|
|
||||||
def test_getting_featured_post_without_privileges_to_view(
|
def test_getting_featured_post_without_privileges_to_view(
|
||||||
user_factory, context_factory):
|
user_factory, context_factory):
|
||||||
api.post_api.get_featured_post(
|
api.post_api.get_featured_post(
|
||||||
|
|
|
@ -1,87 +1,93 @@
|
||||||
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import unittest.mock
|
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import posts, scores
|
from szurubooru.func import posts
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'posts:score': db.User.RANK_REGULAR}})
|
config_injector({'privileges': {'posts:score': db.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_simple_rating(
|
def test_simple_rating(
|
||||||
user_factory, post_factory, context_factory, fake_datetime):
|
user_factory, post_factory, context_factory, fake_datetime):
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
|
with patch('szurubooru.func.posts.serialize_post'), \
|
||||||
|
fake_datetime('1997-12-01'):
|
||||||
posts.serialize_post.return_value = 'serialized post'
|
posts.serialize_post.return_value = 'serialized post'
|
||||||
with fake_datetime('1997-12-01'):
|
result = api.post_api.set_post_score(
|
||||||
result = api.post_api.set_post_score(
|
context_factory(
|
||||||
context_factory(
|
params={'score': 1}, user=user_factory()),
|
||||||
params={'score': 1}, user=user_factory()),
|
{'post_id': post.post_id})
|
||||||
{'post_id': post.post_id})
|
|
||||||
assert result == 'serialized post'
|
assert result == 'serialized post'
|
||||||
post = db.session.query(db.Post).one()
|
post = db.session.query(db.Post).one()
|
||||||
assert db.session.query(db.PostScore).count() == 1
|
assert db.session.query(db.PostScore).count() == 1
|
||||||
assert post is not None
|
assert post is not None
|
||||||
assert post.score == 1
|
assert post.score == 1
|
||||||
|
|
||||||
|
|
||||||
def test_updating_rating(
|
def test_updating_rating(
|
||||||
user_factory, post_factory, context_factory, fake_datetime):
|
user_factory, post_factory, context_factory, fake_datetime):
|
||||||
user = user_factory()
|
user = user_factory()
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
|
with patch('szurubooru.func.posts.serialize_post'):
|
||||||
with fake_datetime('1997-12-01'):
|
with fake_datetime('1997-12-01'):
|
||||||
result = api.post_api.set_post_score(
|
api.post_api.set_post_score(
|
||||||
context_factory(params={'score': 1}, user=user),
|
context_factory(params={'score': 1}, user=user),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
with fake_datetime('1997-12-02'):
|
with fake_datetime('1997-12-02'):
|
||||||
result = api.post_api.set_post_score(
|
api.post_api.set_post_score(
|
||||||
context_factory(params={'score': -1}, user=user),
|
context_factory(params={'score': -1}, user=user),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
post = db.session.query(db.Post).one()
|
post = db.session.query(db.Post).one()
|
||||||
assert db.session.query(db.PostScore).count() == 1
|
assert db.session.query(db.PostScore).count() == 1
|
||||||
assert post.score == -1
|
assert post.score == -1
|
||||||
|
|
||||||
|
|
||||||
def test_updating_rating_to_zero(
|
def test_updating_rating_to_zero(
|
||||||
user_factory, post_factory, context_factory, fake_datetime):
|
user_factory, post_factory, context_factory, fake_datetime):
|
||||||
user = user_factory()
|
user = user_factory()
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
|
with patch('szurubooru.func.posts.serialize_post'):
|
||||||
with fake_datetime('1997-12-01'):
|
with fake_datetime('1997-12-01'):
|
||||||
result = api.post_api.set_post_score(
|
api.post_api.set_post_score(
|
||||||
context_factory(params={'score': 1}, user=user),
|
context_factory(params={'score': 1}, user=user),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
with fake_datetime('1997-12-02'):
|
with fake_datetime('1997-12-02'):
|
||||||
result = api.post_api.set_post_score(
|
api.post_api.set_post_score(
|
||||||
context_factory(params={'score': 0}, user=user),
|
context_factory(params={'score': 0}, user=user),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
post = db.session.query(db.Post).one()
|
post = db.session.query(db.Post).one()
|
||||||
assert db.session.query(db.PostScore).count() == 0
|
assert db.session.query(db.PostScore).count() == 0
|
||||||
assert post.score == 0
|
assert post.score == 0
|
||||||
|
|
||||||
|
|
||||||
def test_deleting_rating(
|
def test_deleting_rating(
|
||||||
user_factory, post_factory, context_factory, fake_datetime):
|
user_factory, post_factory, context_factory, fake_datetime):
|
||||||
user = user_factory()
|
user = user_factory()
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
|
with patch('szurubooru.func.posts.serialize_post'):
|
||||||
with fake_datetime('1997-12-01'):
|
with fake_datetime('1997-12-01'):
|
||||||
result = api.post_api.set_post_score(
|
api.post_api.set_post_score(
|
||||||
context_factory(params={'score': 1}, user=user),
|
context_factory(params={'score': 1}, user=user),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
with fake_datetime('1997-12-02'):
|
with fake_datetime('1997-12-02'):
|
||||||
result = api.post_api.delete_post_score(
|
api.post_api.delete_post_score(
|
||||||
context_factory(user=user),
|
context_factory(user=user),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
post = db.session.query(db.Post).one()
|
post = db.session.query(db.Post).one()
|
||||||
assert db.session.query(db.PostScore).count() == 0
|
assert db.session.query(db.PostScore).count() == 0
|
||||||
assert post.score == 0
|
assert post.score == 0
|
||||||
|
|
||||||
|
|
||||||
def test_ratings_from_multiple_users(
|
def test_ratings_from_multiple_users(
|
||||||
user_factory, post_factory, context_factory, fake_datetime):
|
user_factory, post_factory, context_factory, fake_datetime):
|
||||||
user1 = user_factory()
|
user1 = user_factory()
|
||||||
|
@ -89,19 +95,20 @@ def test_ratings_from_multiple_users(
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add_all([user1, user2, post])
|
db.session.add_all([user1, user2, post])
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
|
with patch('szurubooru.func.posts.serialize_post'):
|
||||||
with fake_datetime('1997-12-01'):
|
with fake_datetime('1997-12-01'):
|
||||||
result = api.post_api.set_post_score(
|
api.post_api.set_post_score(
|
||||||
context_factory(params={'score': 1}, user=user1),
|
context_factory(params={'score': 1}, user=user1),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
with fake_datetime('1997-12-02'):
|
with fake_datetime('1997-12-02'):
|
||||||
result = api.post_api.set_post_score(
|
api.post_api.set_post_score(
|
||||||
context_factory(params={'score': -1}, user=user2),
|
context_factory(params={'score': -1}, user=user2),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
post = db.session.query(db.Post).one()
|
post = db.session.query(db.Post).one()
|
||||||
assert db.session.query(db.PostScore).count() == 2
|
assert db.session.query(db.PostScore).count() == 2
|
||||||
assert post.score == 0
|
assert post.score == 0
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_omit_mandatory_field(
|
def test_trying_to_omit_mandatory_field(
|
||||||
user_factory, post_factory, context_factory):
|
user_factory, post_factory, context_factory):
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
|
@ -112,13 +119,14 @@ def test_trying_to_omit_mandatory_field(
|
||||||
context_factory(params={}, user=user_factory()),
|
context_factory(params={}, user=user_factory()),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
|
|
||||||
def test_trying_to_update_non_existing(
|
|
||||||
user_factory, post_factory, context_factory):
|
def test_trying_to_update_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(posts.PostNotFoundError):
|
with pytest.raises(posts.PostNotFoundError):
|
||||||
api.post_api.set_post_score(
|
api.post_api.set_post_score(
|
||||||
context_factory(params={'score': 1}, user=user_factory()),
|
context_factory(params={'score': 1}, user=user_factory()),
|
||||||
{'post_id': 5})
|
{'post_id': 5})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_rate_without_privileges(
|
def test_trying_to_rate_without_privileges(
|
||||||
user_factory, post_factory, context_factory):
|
user_factory, post_factory, context_factory):
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
import pytest
|
|
||||||
import unittest.mock
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from unittest.mock import patch
|
||||||
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import posts
|
from szurubooru.func import posts
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(tmpdir, config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'posts:list': db.User.RANK_REGULAR,
|
'posts:list': db.User.RANK_REGULAR,
|
||||||
|
@ -13,11 +14,12 @@ def inject_config(tmpdir, config_injector):
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_retrieving_multiple(user_factory, post_factory, context_factory):
|
def test_retrieving_multiple(user_factory, post_factory, context_factory):
|
||||||
post1 = post_factory(id=1)
|
post1 = post_factory(id=1)
|
||||||
post2 = post_factory(id=2)
|
post2 = post_factory(id=2)
|
||||||
db.session.add_all([post1, post2])
|
db.session.add_all([post1, post2])
|
||||||
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
|
with patch('szurubooru.func.posts.serialize_post'):
|
||||||
posts.serialize_post.return_value = 'serialized post'
|
posts.serialize_post.return_value = 'serialized post'
|
||||||
result = api.post_api.get_posts(
|
result = api.post_api.get_posts(
|
||||||
context_factory(
|
context_factory(
|
||||||
|
@ -31,6 +33,7 @@ def test_retrieving_multiple(user_factory, post_factory, context_factory):
|
||||||
'results': ['serialized post', 'serialized post'],
|
'results': ['serialized post', 'serialized post'],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_using_special_tokens(user_factory, post_factory, context_factory):
|
def test_using_special_tokens(user_factory, post_factory, context_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
post1 = post_factory(id=1)
|
post1 = post_factory(id=1)
|
||||||
|
@ -39,7 +42,7 @@ def test_using_special_tokens(user_factory, post_factory, context_factory):
|
||||||
user=auth_user, time=datetime.utcnow())]
|
user=auth_user, time=datetime.utcnow())]
|
||||||
db.session.add_all([post1, post2, auth_user])
|
db.session.add_all([post1, post2, auth_user])
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
|
with patch('szurubooru.func.posts.serialize_post'):
|
||||||
posts.serialize_post.side_effect = \
|
posts.serialize_post.side_effect = \
|
||||||
lambda post, *_args, **_kwargs: \
|
lambda post, *_args, **_kwargs: \
|
||||||
'serialized post %d' % post.post_id
|
'serialized post %d' % post.post_id
|
||||||
|
@ -55,8 +58,9 @@ def test_using_special_tokens(user_factory, post_factory, context_factory):
|
||||||
'results': ['serialized post 1'],
|
'results': ['serialized post 1'],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_use_special_tokens_without_logging_in(
|
def test_trying_to_use_special_tokens_without_logging_in(
|
||||||
user_factory, post_factory, context_factory, config_injector):
|
user_factory, context_factory, config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {'posts:list': 'anonymous'},
|
'privileges': {'posts:list': 'anonymous'},
|
||||||
})
|
})
|
||||||
|
@ -66,6 +70,7 @@ def test_trying_to_use_special_tokens_without_logging_in(
|
||||||
params={'query': 'special:fav', 'page': 1},
|
params={'query': 'special:fav', 'page': 1},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_retrieve_multiple_without_privileges(
|
def test_trying_to_retrieve_multiple_without_privileges(
|
||||||
user_factory, context_factory):
|
user_factory, context_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
|
@ -74,21 +79,24 @@ def test_trying_to_retrieve_multiple_without_privileges(
|
||||||
params={'query': '', 'page': 1},
|
params={'query': '', 'page': 1},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
||||||
|
|
||||||
|
|
||||||
def test_retrieving_single(user_factory, post_factory, context_factory):
|
def test_retrieving_single(user_factory, post_factory, context_factory):
|
||||||
db.session.add(post_factory(id=1))
|
db.session.add(post_factory(id=1))
|
||||||
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
|
with patch('szurubooru.func.posts.serialize_post'):
|
||||||
posts.serialize_post.return_value = 'serialized post'
|
posts.serialize_post.return_value = 'serialized post'
|
||||||
result = api.post_api.get_post(
|
result = api.post_api.get_post(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'post_id': 1})
|
{'post_id': 1})
|
||||||
assert result == 'serialized post'
|
assert result == 'serialized post'
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(posts.PostNotFoundError):
|
with pytest.raises(posts.PostNotFoundError):
|
||||||
api.post_api.get_post(
|
api.post_api.get_post(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'post_id': 999})
|
{'post_id': 999})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_retrieve_single_without_privileges(
|
def test_trying_to_retrieve_single_without_privileges(
|
||||||
user_factory, context_factory):
|
user_factory, context_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
import pytest
|
|
||||||
import unittest.mock
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from unittest.mock import patch
|
||||||
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import posts, tags, snapshots, net
|
from szurubooru.func import posts, tags, snapshots, net
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(tmpdir, config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'posts:edit:tags': db.User.RANK_REGULAR,
|
'posts:edit:tags': db.User.RANK_REGULAR,
|
||||||
|
@ -20,6 +21,7 @@ def inject_config(tmpdir, config_injector):
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_post_updating(
|
def test_post_updating(
|
||||||
context_factory, post_factory, user_factory, fake_datetime):
|
context_factory, post_factory, user_factory, fake_datetime):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
|
@ -27,18 +29,18 @@ def test_post_updating(
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
|
||||||
with unittest.mock.patch('szurubooru.func.posts.create_post'), \
|
with patch('szurubooru.func.posts.create_post'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_tags'), \
|
patch('szurubooru.func.posts.update_post_tags'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \
|
patch('szurubooru.func.posts.update_post_content'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_thumbnail'), \
|
patch('szurubooru.func.posts.update_post_thumbnail'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_safety'), \
|
patch('szurubooru.func.posts.update_post_safety'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_source'), \
|
patch('szurubooru.func.posts.update_post_source'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_relations'), \
|
patch('szurubooru.func.posts.update_post_relations'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_notes'), \
|
patch('szurubooru.func.posts.update_post_notes'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \
|
patch('szurubooru.func.posts.update_post_flags'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
|
patch('szurubooru.func.posts.serialize_post'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \
|
patch('szurubooru.func.tags.export_to_json'), \
|
||||||
unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \
|
patch('szurubooru.func.snapshots.save_entity_modification'), \
|
||||||
fake_datetime('1997-01-01'):
|
fake_datetime('1997-01-01'):
|
||||||
posts.serialize_post.return_value = 'serialized post'
|
posts.serialize_post.return_value = 'serialized post'
|
||||||
|
|
||||||
|
@ -64,28 +66,34 @@ def test_post_updating(
|
||||||
posts.create_post.assert_not_called()
|
posts.create_post.assert_not_called()
|
||||||
posts.update_post_tags.assert_called_once_with(post, ['tag1', 'tag2'])
|
posts.update_post_tags.assert_called_once_with(post, ['tag1', 'tag2'])
|
||||||
posts.update_post_content.assert_called_once_with(post, 'post-content')
|
posts.update_post_content.assert_called_once_with(post, 'post-content')
|
||||||
posts.update_post_thumbnail.assert_called_once_with(post, 'post-thumbnail')
|
posts.update_post_thumbnail.assert_called_once_with(
|
||||||
|
post, 'post-thumbnail')
|
||||||
posts.update_post_safety.assert_called_once_with(post, 'safe')
|
posts.update_post_safety.assert_called_once_with(post, 'safe')
|
||||||
posts.update_post_source.assert_called_once_with(post, 'source')
|
posts.update_post_source.assert_called_once_with(post, 'source')
|
||||||
posts.update_post_relations.assert_called_once_with(post, [1, 2])
|
posts.update_post_relations.assert_called_once_with(post, [1, 2])
|
||||||
posts.update_post_notes.assert_called_once_with(post, ['note1', 'note2'])
|
posts.update_post_notes.assert_called_once_with(
|
||||||
posts.update_post_flags.assert_called_once_with(post, ['flag1', 'flag2'])
|
post, ['note1', 'note2'])
|
||||||
posts.serialize_post.assert_called_once_with(post, auth_user, options=None)
|
posts.update_post_flags.assert_called_once_with(
|
||||||
|
post, ['flag1', 'flag2'])
|
||||||
|
posts.serialize_post.assert_called_once_with(
|
||||||
|
post, auth_user, options=None)
|
||||||
tags.export_to_json.assert_called_once_with()
|
tags.export_to_json.assert_called_once_with()
|
||||||
snapshots.save_entity_modification.assert_called_once_with(post, auth_user)
|
snapshots.save_entity_modification.assert_called_once_with(
|
||||||
|
post, auth_user)
|
||||||
assert post.last_edit_time == datetime(1997, 1, 1)
|
assert post.last_edit_time == datetime(1997, 1, 1)
|
||||||
|
|
||||||
|
|
||||||
def test_uploading_from_url_saves_source(
|
def test_uploading_from_url_saves_source(
|
||||||
context_factory, post_factory, user_factory):
|
context_factory, post_factory, user_factory):
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
with unittest.mock.patch('szurubooru.func.net.download'), \
|
with patch('szurubooru.func.net.download'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \
|
patch('szurubooru.func.tags.export_to_json'), \
|
||||||
unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \
|
patch('szurubooru.func.snapshots.save_entity_modification'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
|
patch('szurubooru.func.posts.serialize_post'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \
|
patch('szurubooru.func.posts.update_post_content'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_source'):
|
patch('szurubooru.func.posts.update_post_source'):
|
||||||
net.download.return_value = b'content'
|
net.download.return_value = b'content'
|
||||||
api.post_api.update_post(
|
api.post_api.update_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
|
@ -96,17 +104,18 @@ def test_uploading_from_url_saves_source(
|
||||||
posts.update_post_content.assert_called_once_with(post, b'content')
|
posts.update_post_content.assert_called_once_with(post, b'content')
|
||||||
posts.update_post_source.assert_called_once_with(post, 'example.com')
|
posts.update_post_source.assert_called_once_with(post, 'example.com')
|
||||||
|
|
||||||
|
|
||||||
def test_uploading_from_url_with_source_specified(
|
def test_uploading_from_url_with_source_specified(
|
||||||
context_factory, post_factory, user_factory):
|
context_factory, post_factory, user_factory):
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
with unittest.mock.patch('szurubooru.func.net.download'), \
|
with patch('szurubooru.func.net.download'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \
|
patch('szurubooru.func.tags.export_to_json'), \
|
||||||
unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \
|
patch('szurubooru.func.snapshots.save_entity_modification'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
|
patch('szurubooru.func.posts.serialize_post'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \
|
patch('szurubooru.func.posts.update_post_content'), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_source'):
|
patch('szurubooru.func.posts.update_post_source'):
|
||||||
net.download.return_value = b'content'
|
net.download.return_value = b'content'
|
||||||
api.post_api.update_post(
|
api.post_api.update_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
|
@ -120,6 +129,7 @@ def test_uploading_from_url_with_source_specified(
|
||||||
posts.update_post_content.assert_called_once_with(post, b'content')
|
posts.update_post_content.assert_called_once_with(post, b'content')
|
||||||
posts.update_post_source.assert_called_once_with(post, 'example2.com')
|
posts.update_post_source.assert_called_once_with(post, 'example2.com')
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_update_non_existing(context_factory, user_factory):
|
def test_trying_to_update_non_existing(context_factory, user_factory):
|
||||||
with pytest.raises(posts.PostNotFoundError):
|
with pytest.raises(posts.PostNotFoundError):
|
||||||
api.post_api.update_post(
|
api.post_api.update_post(
|
||||||
|
@ -128,18 +138,19 @@ def test_trying_to_update_non_existing(context_factory, user_factory):
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'post_id': 1})
|
{'post_id': 1})
|
||||||
|
|
||||||
@pytest.mark.parametrize('privilege,files,params', [
|
|
||||||
('posts:edit:tags', {}, {'tags': '...'}),
|
@pytest.mark.parametrize('files,params', [
|
||||||
('posts:edit:safety', {}, {'safety': '...'}),
|
({}, {'tags': '...'}),
|
||||||
('posts:edit:source', {}, {'source': '...'}),
|
({}, {'safety': '...'}),
|
||||||
('posts:edit:relations', {}, {'relations': '...'}),
|
({}, {'source': '...'}),
|
||||||
('posts:edit:notes', {}, {'notes': '...'}),
|
({}, {'relations': '...'}),
|
||||||
('posts:edit:flags', {}, {'flags': '...'}),
|
({}, {'notes': '...'}),
|
||||||
('posts:edit:content', {'content': '...'}, {}),
|
({}, {'flags': '...'}),
|
||||||
('posts:edit:thumbnail', {'thumbnail': '...'}, {}),
|
({'content': '...'}, {}),
|
||||||
|
({'thumbnail': '...'}, {}),
|
||||||
])
|
])
|
||||||
def test_trying_to_update_field_without_privileges(
|
def test_trying_to_update_field_without_privileges(
|
||||||
context_factory, post_factory, user_factory, files, params, privilege):
|
context_factory, post_factory, user_factory, files, params):
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
@ -151,13 +162,14 @@ def test_trying_to_update_field_without_privileges(
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_create_tags_without_privileges(
|
def test_trying_to_create_tags_without_privileges(
|
||||||
context_factory, post_factory, user_factory):
|
context_factory, post_factory, user_factory):
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
with pytest.raises(errors.AuthError), \
|
with pytest.raises(errors.AuthError), \
|
||||||
unittest.mock.patch('szurubooru.func.posts.update_post_tags'):
|
patch('szurubooru.func.posts.update_post_tags'):
|
||||||
posts.update_post_tags.return_value = ['new-tag']
|
posts.update_post_tags.return_value = ['new-tag']
|
||||||
api.post_api.update_post(
|
api.post_api.update_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import pytest
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
|
|
||||||
|
|
||||||
def snapshot_factory():
|
def snapshot_factory():
|
||||||
snapshot = db.Snapshot()
|
snapshot = db.Snapshot()
|
||||||
snapshot.creation_time = datetime(1999, 1, 1)
|
snapshot.creation_time = datetime(1999, 1, 1)
|
||||||
|
@ -12,12 +13,14 @@ def snapshot_factory():
|
||||||
snapshot.data = '{}'
|
snapshot.data = '{}'
|
||||||
return snapshot
|
return snapshot
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {'snapshots:list': db.User.RANK_REGULAR},
|
'privileges': {'snapshots:list': db.User.RANK_REGULAR},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_retrieving_multiple(user_factory, context_factory):
|
def test_retrieving_multiple(user_factory, context_factory):
|
||||||
snapshot1 = snapshot_factory()
|
snapshot1 = snapshot_factory()
|
||||||
snapshot2 = snapshot_factory()
|
snapshot2 = snapshot_factory()
|
||||||
|
@ -32,6 +35,7 @@ def test_retrieving_multiple(user_factory, context_factory):
|
||||||
assert result['total'] == 2
|
assert result['total'] == 2
|
||||||
assert len(result['results']) == 2
|
assert len(result['results']) == 2
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_retrieve_multiple_without_privileges(
|
def test_trying_to_retrieve_multiple_without_privileges(
|
||||||
user_factory, context_factory):
|
user_factory, context_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
|
|
|
@ -1,21 +1,24 @@
|
||||||
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import unittest.mock
|
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import tag_categories, tags
|
from szurubooru.func import tag_categories, tags
|
||||||
|
|
||||||
|
|
||||||
def _update_category_name(category, name):
|
def _update_category_name(category, name):
|
||||||
category.name = name
|
category.name = name
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {'tag_categories:create': db.User.RANK_REGULAR},
|
'privileges': {'tag_categories:create': db.User.RANK_REGULAR},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_creating_category(user_factory, context_factory):
|
def test_creating_category(user_factory, context_factory):
|
||||||
with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \
|
with patch('szurubooru.func.tag_categories.serialize_category'), \
|
||||||
unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \
|
patch('szurubooru.func.tag_categories.update_category_name'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
|
patch('szurubooru.func.tags.export_to_json'):
|
||||||
tag_categories.update_category_name.side_effect = _update_category_name
|
tag_categories.update_category_name.side_effect = _update_category_name
|
||||||
tag_categories.serialize_category.return_value = 'serialized category'
|
tag_categories.serialize_category.return_value = 'serialized category'
|
||||||
result = api.tag_category_api.create_tag_category(
|
result = api.tag_category_api.create_tag_category(
|
||||||
|
@ -29,6 +32,7 @@ def test_creating_category(user_factory, context_factory):
|
||||||
assert category.tag_count == 0
|
assert category.tag_count == 0
|
||||||
tags.export_to_json.assert_called_once_with()
|
tags.export_to_json.assert_called_once_with()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('field', ['name', 'color'])
|
@pytest.mark.parametrize('field', ['name', 'color'])
|
||||||
def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
|
def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
|
||||||
params = {
|
params = {
|
||||||
|
@ -42,6 +46,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
|
||||||
params=params,
|
params=params,
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=db.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_create_without_privileges(user_factory, context_factory):
|
def test_trying_to_create_without_privileges(user_factory, context_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.tag_category_api.create_tag_category(
|
api.tag_category_api.create_tag_category(
|
||||||
|
|
|
@ -1,19 +1,21 @@
|
||||||
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import unittest.mock
|
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import tag_categories, tags
|
from szurubooru.func import tag_categories, tags
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {'tag_categories:delete': db.User.RANK_REGULAR},
|
'privileges': {'tag_categories:delete': db.User.RANK_REGULAR},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_deleting(user_factory, tag_category_factory, context_factory):
|
def test_deleting(user_factory, tag_category_factory, context_factory):
|
||||||
db.session.add(tag_category_factory(name='root'))
|
db.session.add(tag_category_factory(name='root'))
|
||||||
db.session.add(tag_category_factory(name='category'))
|
db.session.add(tag_category_factory(name='category'))
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.tags.export_to_json'):
|
with patch('szurubooru.func.tags.export_to_json'):
|
||||||
result = api.tag_category_api.delete_tag_category(
|
result = api.tag_category_api.delete_tag_category(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'version': 1},
|
params={'version': 1},
|
||||||
|
@ -24,6 +26,7 @@ def test_deleting(user_factory, tag_category_factory, context_factory):
|
||||||
assert db.session.query(db.TagCategory).one().name == 'root'
|
assert db.session.query(db.TagCategory).one().name == 'root'
|
||||||
tags.export_to_json.assert_called_once_with()
|
tags.export_to_json.assert_called_once_with()
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_delete_used(
|
def test_trying_to_delete_used(
|
||||||
user_factory, tag_category_factory, tag_factory, context_factory):
|
user_factory, tag_category_factory, tag_factory, context_factory):
|
||||||
category = tag_category_factory(name='category')
|
category = tag_category_factory(name='category')
|
||||||
|
@ -40,6 +43,7 @@ def test_trying_to_delete_used(
|
||||||
{'category_name': 'category'})
|
{'category_name': 'category'})
|
||||||
assert db.session.query(db.TagCategory).count() == 1
|
assert db.session.query(db.TagCategory).count() == 1
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_delete_last(
|
def test_trying_to_delete_last(
|
||||||
user_factory, tag_category_factory, context_factory):
|
user_factory, tag_category_factory, context_factory):
|
||||||
db.session.add(tag_category_factory(name='root'))
|
db.session.add(tag_category_factory(name='root'))
|
||||||
|
@ -51,12 +55,14 @@ def test_trying_to_delete_last(
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'category_name': 'root'})
|
{'category_name': 'root'})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(tag_categories.TagCategoryNotFoundError):
|
with pytest.raises(tag_categories.TagCategoryNotFoundError):
|
||||||
api.tag_category_api.delete_tag_category(
|
api.tag_category_api.delete_tag_category(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'category_name': 'bad'})
|
{'category_name': 'bad'})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_delete_without_privileges(
|
def test_trying_to_delete_without_privileges(
|
||||||
user_factory, tag_category_factory, context_factory):
|
user_factory, tag_category_factory, context_factory):
|
||||||
db.session.add(tag_category_factory(name='category'))
|
db.session.add(tag_category_factory(name='category'))
|
||||||
|
|
|
@ -2,6 +2,7 @@ import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import tag_categories
|
from szurubooru.func import tag_categories
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
|
@ -11,6 +12,7 @@ def inject_config(config_injector):
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_retrieving_multiple(
|
def test_retrieving_multiple(
|
||||||
user_factory, tag_category_factory, context_factory):
|
user_factory, tag_category_factory, context_factory):
|
||||||
db.session.add_all([
|
db.session.add_all([
|
||||||
|
@ -21,7 +23,9 @@ def test_retrieving_multiple(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)))
|
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)))
|
||||||
assert [cat['name'] for cat in result['results']] == ['c1', 'c2']
|
assert [cat['name'] for cat in result['results']] == ['c1', 'c2']
|
||||||
|
|
||||||
def test_retrieving_single(user_factory, tag_category_factory, context_factory):
|
|
||||||
|
def test_retrieving_single(
|
||||||
|
user_factory, tag_category_factory, context_factory):
|
||||||
db.session.add(tag_category_factory(name='cat'))
|
db.session.add(tag_category_factory(name='cat'))
|
||||||
result = api.tag_category_api.get_tag_category(
|
result = api.tag_category_api.get_tag_category(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
|
@ -35,12 +39,14 @@ def test_retrieving_single(user_factory, tag_category_factory, context_factory):
|
||||||
'version': 1,
|
'version': 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(tag_categories.TagCategoryNotFoundError):
|
with pytest.raises(tag_categories.TagCategoryNotFoundError):
|
||||||
api.tag_category_api.get_tag_category(
|
api.tag_category_api.get_tag_category(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'category_name': '-'})
|
{'category_name': '-'})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_retrieve_single_without_privileges(
|
def test_trying_to_retrieve_single_without_privileges(
|
||||||
user_factory, context_factory):
|
user_factory, context_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import unittest.mock
|
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import tag_categories, tags
|
from szurubooru.func import tag_categories, tags
|
||||||
|
|
||||||
|
|
||||||
def _update_category_name(category, name):
|
def _update_category_name(category, name):
|
||||||
category.name = name
|
category.name = name
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
|
@ -16,14 +18,15 @@ def inject_config(config_injector):
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_simple_updating(user_factory, tag_category_factory, context_factory):
|
def test_simple_updating(user_factory, tag_category_factory, context_factory):
|
||||||
category = tag_category_factory(name='name', color='black')
|
category = tag_category_factory(name='name', color='black')
|
||||||
db.session.add(category)
|
db.session.add(category)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \
|
with patch('szurubooru.func.tag_categories.serialize_category'), \
|
||||||
unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \
|
patch('szurubooru.func.tag_categories.update_category_name'), \
|
||||||
unittest.mock.patch('szurubooru.func.tag_categories.update_category_color'), \
|
patch('szurubooru.func.tag_categories.update_category_color'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
|
patch('szurubooru.func.tags.export_to_json'):
|
||||||
tag_categories.update_category_name.side_effect = _update_category_name
|
tag_categories.update_category_name.side_effect = _update_category_name
|
||||||
tag_categories.serialize_category.return_value = 'serialized category'
|
tag_categories.serialize_category.return_value = 'serialized category'
|
||||||
result = api.tag_category_api.update_tag_category(
|
result = api.tag_category_api.update_tag_category(
|
||||||
|
@ -36,10 +39,13 @@ def test_simple_updating(user_factory, tag_category_factory, context_factory):
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'category_name': 'name'})
|
{'category_name': 'name'})
|
||||||
assert result == 'serialized category'
|
assert result == 'serialized category'
|
||||||
tag_categories.update_category_name.assert_called_once_with(category, 'changed')
|
tag_categories.update_category_name.assert_called_once_with(
|
||||||
tag_categories.update_category_color.assert_called_once_with(category, 'white')
|
category, 'changed')
|
||||||
|
tag_categories.update_category_color.assert_called_once_with(
|
||||||
|
category, 'white')
|
||||||
tags.export_to_json.assert_called_once_with()
|
tags.export_to_json.assert_called_once_with()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('field', ['name', 'color'])
|
@pytest.mark.parametrize('field', ['name', 'color'])
|
||||||
def test_omitting_optional_field(
|
def test_omitting_optional_field(
|
||||||
user_factory, tag_category_factory, context_factory, field):
|
user_factory, tag_category_factory, context_factory, field):
|
||||||
|
@ -50,15 +56,16 @@ def test_omitting_optional_field(
|
||||||
'color': 'white',
|
'color': 'white',
|
||||||
}
|
}
|
||||||
del params[field]
|
del params[field]
|
||||||
with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \
|
with patch('szurubooru.func.tag_categories.serialize_category'), \
|
||||||
unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \
|
patch('szurubooru.func.tag_categories.update_category_name'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
|
patch('szurubooru.func.tags.export_to_json'):
|
||||||
api.tag_category_api.update_tag_category(
|
api.tag_category_api.update_tag_category(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={**params, **{'version': 1}},
|
params={**params, **{'version': 1}},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'category_name': 'name'})
|
{'category_name': 'name'})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_update_non_existing(user_factory, context_factory):
|
def test_trying_to_update_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(tag_categories.TagCategoryNotFoundError):
|
with pytest.raises(tag_categories.TagCategoryNotFoundError):
|
||||||
api.tag_category_api.update_tag_category(
|
api.tag_category_api.update_tag_category(
|
||||||
|
@ -67,6 +74,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'category_name': 'bad'})
|
{'category_name': 'bad'})
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('params', [
|
@pytest.mark.parametrize('params', [
|
||||||
{'name': 'whatever'},
|
{'name': 'whatever'},
|
||||||
{'color': 'whatever'},
|
{'color': 'whatever'},
|
||||||
|
@ -82,13 +90,14 @@ def test_trying_to_update_without_privileges(
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
||||||
{'category_name': 'dummy'})
|
{'category_name': 'dummy'})
|
||||||
|
|
||||||
|
|
||||||
def test_set_as_default(user_factory, tag_category_factory, context_factory):
|
def test_set_as_default(user_factory, tag_category_factory, context_factory):
|
||||||
category = tag_category_factory(name='name', color='black')
|
category = tag_category_factory(name='name', color='black')
|
||||||
db.session.add(category)
|
db.session.add(category)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \
|
with patch('szurubooru.func.tag_categories.serialize_category'), \
|
||||||
unittest.mock.patch('szurubooru.func.tag_categories.set_default_category'), \
|
patch('szurubooru.func.tag_categories.set_default_category'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
|
patch('szurubooru.func.tags.export_to_json'):
|
||||||
tag_categories.update_category_name.side_effect = _update_category_name
|
tag_categories.update_category_name.side_effect = _update_category_name
|
||||||
tag_categories.serialize_category.return_value = 'serialized category'
|
tag_categories.serialize_category.return_value = 'serialized category'
|
||||||
result = api.tag_category_api.set_tag_category_as_default(
|
result = api.tag_category_api.set_tag_category_as_default(
|
||||||
|
|
|
@ -1,17 +1,19 @@
|
||||||
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import unittest.mock
|
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import tags, tag_categories
|
from szurubooru.func import tags
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'tags:create': db.User.RANK_REGULAR}})
|
config_injector({'privileges': {'tags:create': db.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_creating_simple_tags(tag_factory, user_factory, context_factory):
|
def test_creating_simple_tags(tag_factory, user_factory, context_factory):
|
||||||
with unittest.mock.patch('szurubooru.func.tags.create_tag'), \
|
with patch('szurubooru.func.tags.create_tag'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'), \
|
patch('szurubooru.func.tags.get_or_create_tags_by_names'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
|
patch('szurubooru.func.tags.serialize_tag'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
|
patch('szurubooru.func.tags.export_to_json'):
|
||||||
tags.get_or_create_tags_by_names.return_value = ([], [])
|
tags.get_or_create_tags_by_names.return_value = ([], [])
|
||||||
tags.create_tag.return_value = tag_factory()
|
tags.create_tag.return_value = tag_factory()
|
||||||
tags.serialize_tag.return_value = 'serialized tag'
|
tags.serialize_tag.return_value = 'serialized tag'
|
||||||
|
@ -30,6 +32,7 @@ def test_creating_simple_tags(tag_factory, user_factory, context_factory):
|
||||||
['tag1', 'tag2'], 'meta', ['sug1', 'sug2'], ['imp1', 'imp2'])
|
['tag1', 'tag2'], 'meta', ['sug1', 'sug2'], ['imp1', 'imp2'])
|
||||||
tags.export_to_json.assert_called_once_with()
|
tags.export_to_json.assert_called_once_with()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('field', ['names', 'category'])
|
@pytest.mark.parametrize('field', ['names', 'category'])
|
||||||
def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
|
def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
|
||||||
params = {
|
params = {
|
||||||
|
@ -45,6 +48,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
|
||||||
params=params,
|
params=params,
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=db.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('field', ['implications', 'suggestions'])
|
@pytest.mark.parametrize('field', ['implications', 'suggestions'])
|
||||||
def test_omitting_optional_field(
|
def test_omitting_optional_field(
|
||||||
tag_factory, user_factory, context_factory, field):
|
tag_factory, user_factory, context_factory, field):
|
||||||
|
@ -55,16 +59,18 @@ def test_omitting_optional_field(
|
||||||
'implications': [],
|
'implications': [],
|
||||||
}
|
}
|
||||||
del params[field]
|
del params[field]
|
||||||
with unittest.mock.patch('szurubooru.func.tags.create_tag'), \
|
with patch('szurubooru.func.tags.create_tag'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
|
patch('szurubooru.func.tags.serialize_tag'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
|
patch('szurubooru.func.tags.export_to_json'):
|
||||||
tags.create_tag.return_value = tag_factory()
|
tags.create_tag.return_value = tag_factory()
|
||||||
api.tag_api.create_tag(
|
api.tag_api.create_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
params=params,
|
params=params,
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=db.User.RANK_REGULAR)))
|
||||||
|
|
||||||
def test_trying_to_create_tag_without_privileges(user_factory, context_factory):
|
|
||||||
|
def test_trying_to_create_tag_without_privileges(
|
||||||
|
user_factory, context_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.tag_api.create_tag(
|
api.tag_api.create_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
|
|
|
@ -1,16 +1,18 @@
|
||||||
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import unittest.mock
|
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import tags
|
from szurubooru.func import tags
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'tags:delete': db.User.RANK_REGULAR}})
|
config_injector({'privileges': {'tags:delete': db.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_deleting(user_factory, tag_factory, context_factory):
|
def test_deleting(user_factory, tag_factory, context_factory):
|
||||||
db.session.add(tag_factory(names=['tag']))
|
db.session.add(tag_factory(names=['tag']))
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.tags.export_to_json'):
|
with patch('szurubooru.func.tags.export_to_json'):
|
||||||
result = api.tag_api.delete_tag(
|
result = api.tag_api.delete_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'version': 1},
|
params={'version': 1},
|
||||||
|
@ -20,13 +22,15 @@ def test_deleting(user_factory, tag_factory, context_factory):
|
||||||
assert db.session.query(db.Tag).count() == 0
|
assert db.session.query(db.Tag).count() == 0
|
||||||
tags.export_to_json.assert_called_once_with()
|
tags.export_to_json.assert_called_once_with()
|
||||||
|
|
||||||
def test_deleting_used(user_factory, tag_factory, context_factory, post_factory):
|
|
||||||
|
def test_deleting_used(
|
||||||
|
user_factory, tag_factory, context_factory, post_factory):
|
||||||
tag = tag_factory(names=['tag'])
|
tag = tag_factory(names=['tag'])
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
post.tags.append(tag)
|
post.tags.append(tag)
|
||||||
db.session.add_all([tag, post])
|
db.session.add_all([tag, post])
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.tags.export_to_json'):
|
with patch('szurubooru.func.tags.export_to_json'):
|
||||||
api.tag_api.delete_tag(
|
api.tag_api.delete_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'version': 1},
|
params={'version': 1},
|
||||||
|
@ -36,12 +40,14 @@ def test_deleting_used(user_factory, tag_factory, context_factory, post_factory)
|
||||||
assert db.session.query(db.Tag).count() == 0
|
assert db.session.query(db.Tag).count() == 0
|
||||||
assert post.tags == []
|
assert post.tags == []
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(tags.TagNotFoundError):
|
with pytest.raises(tags.TagNotFoundError):
|
||||||
api.tag_api.delete_tag(
|
api.tag_api.delete_tag(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'tag_name': 'bad'})
|
{'tag_name': 'bad'})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_delete_without_privileges(
|
def test_trying_to_delete_without_privileges(
|
||||||
user_factory, tag_factory, context_factory):
|
user_factory, tag_factory, context_factory):
|
||||||
db.session.add(tag_factory(names=['tag']))
|
db.session.add(tag_factory(names=['tag']))
|
||||||
|
|
|
@ -1,12 +1,14 @@
|
||||||
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import unittest.mock
|
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import tags
|
from szurubooru.func import tags
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'tags:merge': db.User.RANK_REGULAR}})
|
config_injector({'privileges': {'tags:merge': db.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_merging(user_factory, tag_factory, context_factory, post_factory):
|
def test_merging(user_factory, tag_factory, context_factory, post_factory):
|
||||||
source_tag = tag_factory(names=['source'])
|
source_tag = tag_factory(names=['source'])
|
||||||
target_tag = tag_factory(names=['target'])
|
target_tag = tag_factory(names=['target'])
|
||||||
|
@ -20,10 +22,10 @@ def test_merging(user_factory, tag_factory, context_factory, post_factory):
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
assert source_tag.post_count == 1
|
assert source_tag.post_count == 1
|
||||||
assert target_tag.post_count == 0
|
assert target_tag.post_count == 0
|
||||||
with unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
|
with patch('szurubooru.func.tags.serialize_tag'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.merge_tags'), \
|
patch('szurubooru.func.tags.merge_tags'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
|
patch('szurubooru.func.tags.export_to_json'):
|
||||||
result = api.tag_api.merge_tags(
|
api.tag_api.merge_tags(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={
|
params={
|
||||||
'removeVersion': 1,
|
'removeVersion': 1,
|
||||||
|
@ -35,6 +37,7 @@ def test_merging(user_factory, tag_factory, context_factory, post_factory):
|
||||||
tags.merge_tags.called_once_with(source_tag, target_tag)
|
tags.merge_tags.called_once_with(source_tag, target_tag)
|
||||||
tags.export_to_json.assert_called_once_with()
|
tags.export_to_json.assert_called_once_with()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'field', ['remove', 'mergeTo', 'removeVersion', 'mergeToVersion'])
|
'field', ['remove', 'mergeTo', 'removeVersion', 'mergeToVersion'])
|
||||||
def test_trying_to_omit_mandatory_field(
|
def test_trying_to_omit_mandatory_field(
|
||||||
|
@ -57,6 +60,7 @@ def test_trying_to_omit_mandatory_field(
|
||||||
params=params,
|
params=params,
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=db.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_merge_non_existing(
|
def test_trying_to_merge_non_existing(
|
||||||
user_factory, tag_factory, context_factory):
|
user_factory, tag_factory, context_factory):
|
||||||
db.session.add(tag_factory(names=['good']))
|
db.session.add(tag_factory(names=['good']))
|
||||||
|
@ -72,14 +76,9 @@ def test_trying_to_merge_non_existing(
|
||||||
params={'remove': 'bad', 'mergeTo': 'good'},
|
params={'remove': 'bad', 'mergeTo': 'good'},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=db.User.RANK_REGULAR)))
|
||||||
|
|
||||||
@pytest.mark.parametrize('params', [
|
|
||||||
{'names': 'whatever'},
|
|
||||||
{'category': 'whatever'},
|
|
||||||
{'suggestions': ['whatever']},
|
|
||||||
{'implications': ['whatever']},
|
|
||||||
])
|
|
||||||
def test_trying_to_merge_without_privileges(
|
def test_trying_to_merge_without_privileges(
|
||||||
user_factory, tag_factory, context_factory, params):
|
user_factory, tag_factory, context_factory):
|
||||||
db.session.add_all([
|
db.session.add_all([
|
||||||
tag_factory(names=['source']),
|
tag_factory(names=['source']),
|
||||||
tag_factory(names=['target']),
|
tag_factory(names=['target']),
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import unittest.mock
|
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import tags
|
from szurubooru.func import tags
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
|
@ -12,11 +13,12 @@ def inject_config(config_injector):
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_retrieving_multiple(user_factory, tag_factory, context_factory):
|
def test_retrieving_multiple(user_factory, tag_factory, context_factory):
|
||||||
tag1 = tag_factory(names=['t1'])
|
tag1 = tag_factory(names=['t1'])
|
||||||
tag2 = tag_factory(names=['t2'])
|
tag2 = tag_factory(names=['t2'])
|
||||||
db.session.add_all([tag1, tag2])
|
db.session.add_all([tag1, tag2])
|
||||||
with unittest.mock.patch('szurubooru.func.tags.serialize_tag'):
|
with patch('szurubooru.func.tags.serialize_tag'):
|
||||||
tags.serialize_tag.return_value = 'serialized tag'
|
tags.serialize_tag.return_value = 'serialized tag'
|
||||||
result = api.tag_api.get_tags(
|
result = api.tag_api.get_tags(
|
||||||
context_factory(
|
context_factory(
|
||||||
|
@ -30,6 +32,7 @@ def test_retrieving_multiple(user_factory, tag_factory, context_factory):
|
||||||
'results': ['serialized tag', 'serialized tag'],
|
'results': ['serialized tag', 'serialized tag'],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_retrieve_multiple_without_privileges(
|
def test_trying_to_retrieve_multiple_without_privileges(
|
||||||
user_factory, context_factory):
|
user_factory, context_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
|
@ -38,9 +41,10 @@ def test_trying_to_retrieve_multiple_without_privileges(
|
||||||
params={'query': '', 'page': 1},
|
params={'query': '', 'page': 1},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
||||||
|
|
||||||
|
|
||||||
def test_retrieving_single(user_factory, tag_factory, context_factory):
|
def test_retrieving_single(user_factory, tag_factory, context_factory):
|
||||||
db.session.add(tag_factory(names=['tag']))
|
db.session.add(tag_factory(names=['tag']))
|
||||||
with unittest.mock.patch('szurubooru.func.tags.serialize_tag'):
|
with patch('szurubooru.func.tags.serialize_tag'):
|
||||||
tags.serialize_tag.return_value = 'serialized tag'
|
tags.serialize_tag.return_value = 'serialized tag'
|
||||||
result = api.tag_api.get_tag(
|
result = api.tag_api.get_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
|
@ -48,6 +52,7 @@ def test_retrieving_single(user_factory, tag_factory, context_factory):
|
||||||
{'tag_name': 'tag'})
|
{'tag_name': 'tag'})
|
||||||
assert result == 'serialized tag'
|
assert result == 'serialized tag'
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(tags.TagNotFoundError):
|
with pytest.raises(tags.TagNotFoundError):
|
||||||
api.tag_api.get_tag(
|
api.tag_api.get_tag(
|
||||||
|
@ -55,6 +60,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'tag_name': '-'})
|
{'tag_name': '-'})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_retrieve_single_without_privileges(
|
def test_trying_to_retrieve_single_without_privileges(
|
||||||
user_factory, context_factory):
|
user_factory, context_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
|
|
|
@ -1,16 +1,18 @@
|
||||||
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import unittest.mock
|
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import tags
|
from szurubooru.func import tags
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'tags:view': db.User.RANK_REGULAR}})
|
config_injector({'privileges': {'tags:view': db.User.RANK_REGULAR}})
|
||||||
|
|
||||||
def test_get_tag_siblings(user_factory, tag_factory, post_factory, context_factory):
|
|
||||||
|
def test_get_tag_siblings(user_factory, tag_factory, context_factory):
|
||||||
db.session.add(tag_factory(names=['tag']))
|
db.session.add(tag_factory(names=['tag']))
|
||||||
with unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
|
with patch('szurubooru.func.tags.serialize_tag'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.get_tag_siblings'):
|
patch('szurubooru.func.tags.get_tag_siblings'):
|
||||||
tags.serialize_tag.side_effect = \
|
tags.serialize_tag.side_effect = \
|
||||||
lambda tag, *args, **kwargs: \
|
lambda tag, *args, **kwargs: \
|
||||||
'serialized tag %s' % tag.names[0].name
|
'serialized tag %s' % tag.names[0].name
|
||||||
|
@ -34,12 +36,14 @@ def test_get_tag_siblings(user_factory, tag_factory, post_factory, context_facto
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_retrieve_non_existing(user_factory, context_factory):
|
def test_trying_to_retrieve_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(tags.TagNotFoundError):
|
with pytest.raises(tags.TagNotFoundError):
|
||||||
api.tag_api.get_tag_siblings(
|
api.tag_api.get_tag_siblings(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'tag_name': '-'})
|
{'tag_name': '-'})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_retrieve_without_privileges(user_factory, context_factory):
|
def test_trying_to_retrieve_without_privileges(user_factory, context_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.tag_api.get_tag_siblings(
|
api.tag_api.get_tag_siblings(
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
import unittest.mock
|
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, errors
|
||||||
from szurubooru.func import tags
|
from szurubooru.func import tags
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
|
@ -16,20 +17,21 @@ def inject_config(config_injector):
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
def test_simple_updating(user_factory, tag_factory, context_factory, fake_datetime):
|
|
||||||
|
def test_simple_updating(user_factory, tag_factory, context_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
||||||
tag = tag_factory(names=['tag1', 'tag2'])
|
tag = tag_factory(names=['tag1', 'tag2'])
|
||||||
db.session.add(tag)
|
db.session.add(tag)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with unittest.mock.patch('szurubooru.func.tags.create_tag'), \
|
with patch('szurubooru.func.tags.create_tag'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'), \
|
patch('szurubooru.func.tags.get_or_create_tags_by_names'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.update_tag_names'), \
|
patch('szurubooru.func.tags.update_tag_names'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.update_tag_category_name'), \
|
patch('szurubooru.func.tags.update_tag_category_name'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.update_tag_description'), \
|
patch('szurubooru.func.tags.update_tag_description'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.update_tag_suggestions'), \
|
patch('szurubooru.func.tags.update_tag_suggestions'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.update_tag_implications'), \
|
patch('szurubooru.func.tags.update_tag_implications'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
|
patch('szurubooru.func.tags.serialize_tag'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
|
patch('szurubooru.func.tags.export_to_json'):
|
||||||
tags.get_or_create_tags_by_names.return_value = ([], [])
|
tags.get_or_create_tags_by_names.return_value = ([], [])
|
||||||
tags.serialize_tag.return_value = 'serialized tag'
|
tags.serialize_tag.return_value = 'serialized tag'
|
||||||
result = api.tag_api.update_tag(
|
result = api.tag_api.update_tag(
|
||||||
|
@ -49,12 +51,22 @@ def test_simple_updating(user_factory, tag_factory, context_factory, fake_dateti
|
||||||
tags.update_tag_names.assert_called_once_with(tag, ['tag3'])
|
tags.update_tag_names.assert_called_once_with(tag, ['tag3'])
|
||||||
tags.update_tag_category_name.assert_called_once_with(tag, 'character')
|
tags.update_tag_category_name.assert_called_once_with(tag, 'character')
|
||||||
tags.update_tag_description.assert_called_once_with(tag, 'desc')
|
tags.update_tag_description.assert_called_once_with(tag, 'desc')
|
||||||
tags.update_tag_suggestions.assert_called_once_with(tag, ['sug1', 'sug2'])
|
tags.update_tag_suggestions.assert_called_once_with(
|
||||||
tags.update_tag_implications.assert_called_once_with(tag, ['imp1', 'imp2'])
|
tag, ['sug1', 'sug2'])
|
||||||
tags.serialize_tag.assert_called_once_with(tag, options=None)
|
tags.update_tag_implications.assert_called_once_with(
|
||||||
|
tag, ['imp1', 'imp2'])
|
||||||
|
tags.serialize_tag.assert_called_once_with(
|
||||||
|
tag, options=None)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'field', ['names', 'category', 'description', 'implications', 'suggestions'])
|
'field', [
|
||||||
|
'names',
|
||||||
|
'category',
|
||||||
|
'description',
|
||||||
|
'implications',
|
||||||
|
'suggestions',
|
||||||
|
])
|
||||||
def test_omitting_optional_field(
|
def test_omitting_optional_field(
|
||||||
user_factory, tag_factory, context_factory, field):
|
user_factory, tag_factory, context_factory, field):
|
||||||
db.session.add(tag_factory(names=['tag']))
|
db.session.add(tag_factory(names=['tag']))
|
||||||
|
@ -67,17 +79,18 @@ def test_omitting_optional_field(
|
||||||
'implications': [],
|
'implications': [],
|
||||||
}
|
}
|
||||||
del params[field]
|
del params[field]
|
||||||
with unittest.mock.patch('szurubooru.func.tags.create_tag'), \
|
with patch('szurubooru.func.tags.create_tag'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.update_tag_names'), \
|
patch('szurubooru.func.tags.update_tag_names'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.update_tag_category_name'), \
|
patch('szurubooru.func.tags.update_tag_category_name'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
|
patch('szurubooru.func.tags.serialize_tag'), \
|
||||||
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
|
patch('szurubooru.func.tags.export_to_json'):
|
||||||
api.tag_api.update_tag(
|
api.tag_api.update_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={**params, **{'version': 1}},
|
params={**params, **{'version': 1}},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'tag_name': 'tag'})
|
{'tag_name': 'tag'})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_update_non_existing(user_factory, context_factory):
|
def test_trying_to_update_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(tags.TagNotFoundError):
|
with pytest.raises(tags.TagNotFoundError):
|
||||||
api.tag_api.update_tag(
|
api.tag_api.update_tag(
|
||||||
|
@ -86,6 +99,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=db.User.RANK_REGULAR)),
|
||||||
{'tag_name': 'tag1'})
|
{'tag_name': 'tag1'})
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('params', [
|
@pytest.mark.parametrize('params', [
|
||||||
{'names': 'whatever'},
|
{'names': 'whatever'},
|
||||||
{'category': 'whatever'},
|
{'category': 'whatever'},
|
||||||
|
@ -103,6 +117,7 @@ def test_trying_to_update_without_privileges(
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
||||||
{'tag_name': 'tag'})
|
{'tag_name': 'tag'})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_create_tags_without_privileges(
|
def test_trying_to_create_tags_without_privileges(
|
||||||
config_injector, context_factory, tag_factory, user_factory):
|
config_injector, context_factory, tag_factory, user_factory):
|
||||||
tag = tag_factory(names=['tag'])
|
tag = tag_factory(names=['tag'])
|
||||||
|
@ -113,7 +128,7 @@ def test_trying_to_create_tags_without_privileges(
|
||||||
'tags:edit:suggestions': db.User.RANK_REGULAR,
|
'tags:edit:suggestions': db.User.RANK_REGULAR,
|
||||||
'tags:edit:implications': db.User.RANK_REGULAR,
|
'tags:edit:implications': db.User.RANK_REGULAR,
|
||||||
}})
|
}})
|
||||||
with unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'):
|
with patch('szurubooru.func.tags.get_or_create_tags_by_names'):
|
||||||
tags.get_or_create_tags_by_names.return_value = ([], ['new-tag'])
|
tags.get_or_create_tags_by_names.return_value = ([], ['new-tag'])
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.tag_api.update_tag(
|
api.tag_api.update_tag(
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue