server/general: ditch falcon for in-house WSGI app

For quite some time, I hated Falcon's class maps approach that caused
more chaos than good for Szurubooru. I've taken a look at the other
frameworks (hug, flask, etc) again, but they all looked too
bloated/over-engineered. I decided to just talk to WSGI myself.

Regex-based routing may not be the fastest in the world, but I'm fine
with response time of 10 ms for cached /posts.
This commit is contained in:
rr- 2016-08-14 12:35:14 +02:00
parent d102c9bdba
commit af62f8c45a
61 changed files with 2447 additions and 3096 deletions

View file

@ -1,6 +1,7 @@
[basic] [basic]
function-rgx=^_?[a-z_][a-z0-9_]{2,}$|^test_ function-rgx=^_?[a-z_][a-z0-9_]{2,}$|^test_
method-rgx=^[a-z_][a-z0-9_]{2,}$|^test_ method-rgx=^[a-z_][a-z0-9_]{2,}$|^test_
const-rgx=^[A-Z_]+$|^_[a-zA-Z_]*$
good-names=ex,_,logger good-names=ex,_,logger
[variables] [variables]

View file

@ -10,7 +10,7 @@ import argparse
import os.path import os.path
import sys import sys
import waitress import waitress
from szurubooru.app import create_app from szurubooru.facade import create_app
def main(): def main():
parser = argparse.ArgumentParser('Starts szurubooru using waitress.') parser = argparse.ArgumentParser('Starts szurubooru using waitress.')

View file

@ -1,6 +1,5 @@
alembic>=0.8.5 alembic>=0.8.5
pyyaml>=3.11 pyyaml>=3.11
falcon>=0.3.0
psycopg2>=2.6.1 psycopg2>=2.6.1
SQLAlchemy>=1.0.12 SQLAlchemy>=1.0.12
pytest>=2.9.1 pytest>=2.9.1

View file

@ -1,27 +1,8 @@
''' Falcon-compatible API facades. ''' import szurubooru.api.info_api
import szurubooru.api.user_api
from szurubooru.api.password_reset_api import PasswordResetApi import szurubooru.api.post_api
from szurubooru.api.user_api import UserListApi, UserDetailApi import szurubooru.api.tag_api
from szurubooru.api.tag_api import ( import szurubooru.api.tag_category_api
TagListApi, import szurubooru.api.comment_api
TagDetailApi, import szurubooru.api.password_reset_api
TagMergeApi, import szurubooru.api.snapshot_api
TagSiblingsApi)
from szurubooru.api.tag_category_api import (
TagCategoryListApi,
TagCategoryDetailApi,
DefaultTagCategoryApi)
from szurubooru.api.comment_api import (
CommentListApi,
CommentDetailApi,
CommentScoreApi)
from szurubooru.api.post_api import (
PostListApi,
PostDetailApi,
PostFeatureApi,
PostScoreApi,
PostFavoriteApi,
PostsAroundApi)
from szurubooru.api.snapshot_api import SnapshotListApi
from szurubooru.api.info_api import InfoApi
from szurubooru.api.context import Context, Request

View file

@ -1,27 +0,0 @@
import types
def _bind_method(target, desired_method_name):
actual_method = getattr(target, desired_method_name)
def _wrapper_method(_self, request, _response, *args, **kwargs):
request.context.output = \
actual_method(request.context, *args, **kwargs)
return types.MethodType(_wrapper_method, target)
class BaseApi(object):
'''
A wrapper around falcon's API interface that eases input and output
management.
'''
def __init__(self):
self._translate_routes()
def _translate_routes(self):
for method_name in ['GET', 'PUT', 'POST', 'DELETE']:
desired_method_name = method_name.lower()
falcon_method_name = 'on_%s' % method_name.lower()
if hasattr(self, desired_method_name):
setattr(
self,
falcon_method_name,
_bind_method(self, desired_method_name))

View file

@ -1,7 +1,9 @@
import datetime import datetime
from szurubooru import search from szurubooru import search
from szurubooru.api.base_api import BaseApi
from szurubooru.func import auth, comments, posts, scores, util from szurubooru.func import auth, comments, posts, scores, util
from szurubooru.rest import routes
_search_executor = search.Executor(search.configs.CommentSearchConfig())
def _serialize(ctx, comment, **kwargs): def _serialize(ctx, comment, **kwargs):
return comments.serialize_comment( return comments.serialize_comment(
@ -9,19 +11,14 @@ def _serialize(ctx, comment, **kwargs):
ctx.user, ctx.user,
options=util.get_serialization_options(ctx), **kwargs) options=util.get_serialization_options(ctx), **kwargs)
class CommentListApi(BaseApi): @routes.get('/comments/?')
def __init__(self): def get_comments(ctx, _params=None):
super().__init__()
self._search_executor = search.Executor(
search.configs.CommentSearchConfig())
def get(self, ctx):
auth.verify_privilege(ctx.user, 'comments:list') auth.verify_privilege(ctx.user, 'comments:list')
return self._search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, ctx, lambda comment: _serialize(ctx, comment))
lambda comment: _serialize(ctx, comment))
def post(self, ctx): @routes.post('/comments/?')
def create_comment(ctx, _params=None):
auth.verify_privilege(ctx.user, 'comments:create') auth.verify_privilege(ctx.user, 'comments:create')
text = ctx.get_param_as_string('text', required=True) text = ctx.get_param_as_string('text', required=True)
post_id = ctx.get_param_as_int('postId', required=True) post_id = ctx.get_param_as_int('postId', required=True)
@ -31,14 +28,15 @@ class CommentListApi(BaseApi):
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, comment) return _serialize(ctx, comment)
class CommentDetailApi(BaseApi): @routes.get('/comment/(?P<comment_id>[^/]+)/?')
def get(self, ctx, comment_id): def get_comment(ctx, params):
auth.verify_privilege(ctx.user, 'comments:view') auth.verify_privilege(ctx.user, 'comments:view')
comment = comments.get_comment_by_id(comment_id) comment = comments.get_comment_by_id(params['comment_id'])
return _serialize(ctx, comment) return _serialize(ctx, comment)
def put(self, ctx, comment_id): @routes.put('/comment/(?P<comment_id>[^/]+)/?')
comment = comments.get_comment_by_id(comment_id) def update_comment(ctx, params):
comment = comments.get_comment_by_id(params['comment_id'])
util.verify_version(comment, ctx) util.verify_version(comment, ctx)
infix = 'own' if ctx.user.user_id == comment.user_id else 'any' infix = 'own' if ctx.user.user_id == comment.user_id else 'any'
text = ctx.get_param_as_string('text', required=True) text = ctx.get_param_as_string('text', required=True)
@ -49,8 +47,9 @@ class CommentDetailApi(BaseApi):
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, comment) return _serialize(ctx, comment)
def delete(self, ctx, comment_id): @routes.delete('/comment/(?P<comment_id>[^/]+)/?')
comment = comments.get_comment_by_id(comment_id) def delete_comment(ctx, params):
comment = comments.get_comment_by_id(params['comment_id'])
util.verify_version(comment, ctx) util.verify_version(comment, ctx)
infix = 'own' if ctx.user.user_id == comment.user_id else 'any' infix = 'own' if ctx.user.user_id == comment.user_id else 'any'
auth.verify_privilege(ctx.user, 'comments:delete:%s' % infix) auth.verify_privilege(ctx.user, 'comments:delete:%s' % infix)
@ -58,18 +57,19 @@ class CommentDetailApi(BaseApi):
ctx.session.commit() ctx.session.commit()
return {} return {}
class CommentScoreApi(BaseApi): @routes.put('/comment/(?P<comment_id>[^/]+)/score/?')
def put(self, ctx, comment_id): def set_comment_score(ctx, params):
auth.verify_privilege(ctx.user, 'comments:score') auth.verify_privilege(ctx.user, 'comments:score')
score = ctx.get_param_as_int('score', required=True) score = ctx.get_param_as_int('score', required=True)
comment = comments.get_comment_by_id(comment_id) comment = comments.get_comment_by_id(params['comment_id'])
scores.set_score(comment, ctx.user, score) scores.set_score(comment, ctx.user, score)
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, comment) return _serialize(ctx, comment)
def delete(self, ctx, comment_id): @routes.delete('/comment/(?P<comment_id>[^/]+)/score/?')
def delete_comment_score(ctx, params):
auth.verify_privilege(ctx.user, 'comments:score') auth.verify_privilege(ctx.user, 'comments:score')
comment = comments.get_comment_by_id(comment_id) comment = comments.get_comment_by_id(params['comment_id'])
scores.delete_score(comment, ctx.user) scores.delete_score(comment, ctx.user)
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, comment) return _serialize(ctx, comment)

View file

@ -1,20 +1,33 @@
import datetime import datetime
import os import os
from szurubooru import config from szurubooru import config
from szurubooru.api.base_api import BaseApi
from szurubooru.func import posts, users, util from szurubooru.func import posts, users, util
from szurubooru.rest import routes
class InfoApi(BaseApi): _cache_time = None
def __init__(self): _cache_result = None
super().__init__()
self._cache_time = None
self._cache_result = None
def get(self, ctx): def _get_disk_usage():
global _cache_time, _cache_result # pylint: disable=global-statement
threshold = datetime.timedelta(hours=1)
now = datetime.datetime.utcnow()
if _cache_time and _cache_time > now - threshold:
return _cache_result
total_size = 0
for dir_path, _, file_names in os.walk(config.config['data_dir']):
for file_name in file_names:
file_path = os.path.join(dir_path, file_name)
total_size += os.path.getsize(file_path)
_cache_time = now
_cache_result = total_size
return total_size
@routes.get('/info/?')
def get_info(ctx, _params=None):
post_feature = posts.try_get_current_post_feature() post_feature = posts.try_get_current_post_feature()
return { return {
'postCount': posts.get_post_count(), 'postCount': posts.get_post_count(),
'diskUsage': self._get_disk_usage(), 'diskUsage': _get_disk_usage(),
'featuredPost': posts.serialize_post(post_feature.post, ctx.user) \ 'featuredPost': posts.serialize_post(post_feature.post, ctx.user) \
if post_feature else None, if post_feature else None,
'featuringTime': post_feature.time if post_feature else None, 'featuringTime': post_feature.time if post_feature else None,
@ -31,17 +44,3 @@ class InfoApi(BaseApi):
config.config['privileges']), config.config['privileges']),
}, },
} }
def _get_disk_usage(self):
threshold = datetime.timedelta(hours=1)
now = datetime.datetime.utcnow()
if self._cache_time and self._cache_time > now - threshold:
return self._cache_result
total_size = 0
for dir_path, _, file_names in os.walk(config.config['data_dir']):
for file_name in file_names:
file_path = os.path.join(dir_path, file_name)
total_size += os.path.getsize(file_path)
self._cache_time = now
self._cache_result = total_size
return total_size

View file

@ -1,6 +1,6 @@
from szurubooru import config, errors from szurubooru import config, errors
from szurubooru.api.base_api import BaseApi
from szurubooru.func import auth, mailer, users, util from szurubooru.func import auth, mailer, users, util
from szurubooru.rest import routes
MAIL_SUBJECT = 'Password reset for {name}' MAIL_SUBJECT = 'Password reset for {name}'
MAIL_BODY = \ MAIL_BODY = \
@ -8,9 +8,10 @@ MAIL_BODY = \
'If you wish to proceed, click this link: {url}\n' \ 'If you wish to proceed, click this link: {url}\n' \
'Otherwise, please ignore this email.' 'Otherwise, please ignore this email.'
class PasswordResetApi(BaseApi): @routes.get('/password-reset/(?P<user_name>[^/]+)/?')
def get(self, _ctx, user_name): def start_password_reset(_ctx, params):
''' Send a mail with secure token to the correlated user. ''' ''' Send a mail with secure token to the correlated user. '''
user_name = params['user_name']
user = users.get_user_by_name_or_email(user_name) user = users.get_user_by_name_or_email(user_name)
if not user.email: if not user.email:
raise errors.ValidationError( raise errors.ValidationError(
@ -26,8 +27,10 @@ class PasswordResetApi(BaseApi):
MAIL_BODY.format(name=config.config['name'], url=url)) MAIL_BODY.format(name=config.config['name'], url=url))
return {} return {}
def post(self, ctx, user_name): @routes.post('/password-reset/(?P<user_name>[^/]+)/?')
def finish_password_reset(ctx, params):
''' Verify token from mail, generate a new password and return it. ''' ''' Verify token from mail, generate a new password and return it. '''
user_name = params['user_name']
user = users.get_user_by_name_or_email(user_name) user = users.get_user_by_name_or_email(user_name)
good_token = auth.generate_authentication_token(user) good_token = auth.generate_authentication_token(user)
token = ctx.get_param_as_string('token', required=True) token = ctx.get_param_as_string('token', required=True)

View file

@ -1,7 +1,9 @@
import datetime import datetime
from szurubooru import search from szurubooru import search
from szurubooru.api.base_api import BaseApi
from szurubooru.func import auth, tags, posts, snapshots, favorites, scores, util from szurubooru.func import auth, tags, posts, snapshots, favorites, scores, util
from szurubooru.rest import routes
_search_executor = search.Executor(search.configs.PostSearchConfig())
def _serialize_post(ctx, post): def _serialize_post(ctx, post):
return posts.serialize_post( return posts.serialize_post(
@ -9,19 +11,15 @@ def _serialize_post(ctx, post):
ctx.user, ctx.user,
options=util.get_serialization_options(ctx)) options=util.get_serialization_options(ctx))
class PostListApi(BaseApi): @routes.get('/posts/?')
def __init__(self): def get_posts(ctx, _params=None):
super().__init__()
self._search_executor = search.Executor(
search.configs.PostSearchConfig())
def get(self, ctx):
auth.verify_privilege(ctx.user, 'posts:list') auth.verify_privilege(ctx.user, 'posts:list')
self._search_executor.config.user = ctx.user _search_executor.config.user = ctx.user
return self._search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, lambda post: _serialize_post(ctx, post)) ctx, lambda post: _serialize_post(ctx, post))
def post(self, ctx): @routes.post('/posts/?')
def create_post(ctx, _params=None):
anonymous = ctx.get_param_as_bool('anonymous', default=False) anonymous = ctx.get_param_as_bool('anonymous', default=False)
if anonymous: if anonymous:
auth.verify_privilege(ctx.user, 'posts:create:anonymous') auth.verify_privilege(ctx.user, 'posts:create:anonymous')
@ -54,14 +52,15 @@ class PostListApi(BaseApi):
tags.export_to_json() tags.export_to_json()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
class PostDetailApi(BaseApi): @routes.get('/post/(?P<post_id>[^/]+)/?')
def get(self, ctx, post_id): def get_post(ctx, params):
auth.verify_privilege(ctx.user, 'posts:view') auth.verify_privilege(ctx.user, 'posts:view')
post = posts.get_post_by_id(post_id) post = posts.get_post_by_id(params['post_id'])
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
def put(self, ctx, post_id): @routes.put('/post/(?P<post_id>[^/]+)/?')
post = posts.get_post_by_id(post_id) def update_post(ctx, params):
post = posts.get_post_by_id(params['post_id'])
util.verify_version(post, ctx) util.verify_version(post, ctx)
if ctx.has_file('content'): if ctx.has_file('content'):
auth.verify_privilege(ctx.user, 'posts:edit:content') auth.verify_privilege(ctx.user, 'posts:edit:content')
@ -99,9 +98,10 @@ class PostDetailApi(BaseApi):
tags.export_to_json() tags.export_to_json()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
def delete(self, ctx, post_id): @routes.delete('/post/(?P<post_id>[^/]+)/?')
def delete_post(ctx, params):
auth.verify_privilege(ctx.user, 'posts:delete') auth.verify_privilege(ctx.user, 'posts:delete')
post = posts.get_post_by_id(post_id) post = posts.get_post_by_id(params['post_id'])
util.verify_version(post, ctx) util.verify_version(post, ctx)
snapshots.save_entity_deletion(post, ctx.user) snapshots.save_entity_deletion(post, ctx.user)
posts.delete(post) posts.delete(post)
@ -109,8 +109,13 @@ class PostDetailApi(BaseApi):
tags.export_to_json() tags.export_to_json()
return {} return {}
class PostFeatureApi(BaseApi): @routes.get('/featured-post/?')
def post(self, ctx): def get_featured_post(ctx, _params=None):
post = posts.try_get_featured_post()
return _serialize_post(ctx, post)
@routes.post('/featured-post/?')
def set_featured_post(ctx, _params=None):
auth.verify_privilege(ctx.user, 'posts:feature') auth.verify_privilege(ctx.user, 'posts:feature')
post_id = ctx.get_param_as_int('id', required=True) post_id = ctx.get_param_as_int('id', required=True)
post = posts.get_post_by_id(post_id) post = posts.get_post_by_id(post_id)
@ -125,49 +130,42 @@ class PostFeatureApi(BaseApi):
ctx.session.commit() ctx.session.commit()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
def get(self, ctx): @routes.put('/post/(?P<post_id>[^/]+)/score/?')
post = posts.try_get_featured_post() def set_post_score(ctx, params):
return _serialize_post(ctx, post)
class PostScoreApi(BaseApi):
def put(self, ctx, post_id):
auth.verify_privilege(ctx.user, 'posts:score') auth.verify_privilege(ctx.user, 'posts:score')
post = posts.get_post_by_id(post_id) post = posts.get_post_by_id(params['post_id'])
score = ctx.get_param_as_int('score', required=True) score = ctx.get_param_as_int('score', required=True)
scores.set_score(post, ctx.user, score) scores.set_score(post, ctx.user, score)
ctx.session.commit() ctx.session.commit()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
def delete(self, ctx, post_id): @routes.delete('/post/(?P<post_id>[^/]+)/score/?')
def delete_post_score(ctx, params):
auth.verify_privilege(ctx.user, 'posts:score') auth.verify_privilege(ctx.user, 'posts:score')
post = posts.get_post_by_id(post_id) post = posts.get_post_by_id(params['post_id'])
scores.delete_score(post, ctx.user) scores.delete_score(post, ctx.user)
ctx.session.commit() ctx.session.commit()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
class PostFavoriteApi(BaseApi): @routes.post('/post/(?P<post_id>[^/]+)/favorite/?')
def post(self, ctx, post_id): def add_post_to_favorites(ctx, params):
auth.verify_privilege(ctx.user, 'posts:favorite') auth.verify_privilege(ctx.user, 'posts:favorite')
post = posts.get_post_by_id(post_id) post = posts.get_post_by_id(params['post_id'])
favorites.set_favorite(post, ctx.user) favorites.set_favorite(post, ctx.user)
ctx.session.commit() ctx.session.commit()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
def delete(self, ctx, post_id): @routes.delete('/post/(?P<post_id>[^/]+)/favorite/?')
def delete_post_from_favorites(ctx, params):
auth.verify_privilege(ctx.user, 'posts:favorite') auth.verify_privilege(ctx.user, 'posts:favorite')
post = posts.get_post_by_id(post_id) post = posts.get_post_by_id(params['post_id'])
favorites.unset_favorite(post, ctx.user) favorites.unset_favorite(post, ctx.user)
ctx.session.commit() ctx.session.commit()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
class PostsAroundApi(BaseApi): @routes.get('/post/(?P<post_id>[^/]+)/around/?')
def __init__(self): def get_posts_around(ctx, params):
super().__init__()
self._search_executor = search.Executor(
search.configs.PostSearchConfig())
def get(self, ctx, post_id):
auth.verify_privilege(ctx.user, 'posts:list') auth.verify_privilege(ctx.user, 'posts:list')
self._search_executor.config.user = ctx.user _search_executor.config.user = ctx.user
return self._search_executor.get_around_and_serialize( return _search_executor.get_around_and_serialize(
ctx, post_id, lambda post: _serialize_post(ctx, post)) ctx, params['post_id'], lambda post: _serialize_post(ctx, post))

View file

@ -1,14 +1,12 @@
from szurubooru import search from szurubooru import search
from szurubooru.api.base_api import BaseApi
from szurubooru.func import auth, snapshots from szurubooru.func import auth, snapshots
from szurubooru.rest import routes
class SnapshotListApi(BaseApi): _search_executor = search.Executor(
def __init__(self):
super().__init__()
self._search_executor = search.Executor(
search.configs.SnapshotSearchConfig()) search.configs.SnapshotSearchConfig())
def get(self, ctx): @routes.get('/snapshots/?')
def get_snapshots(ctx, _params=None):
auth.verify_privilege(ctx.user, 'snapshots:list') auth.verify_privilege(ctx.user, 'snapshots:list')
return self._search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, snapshots.serialize_snapshot) ctx, snapshots.serialize_snapshot)

View file

@ -1,7 +1,9 @@
import datetime import datetime
from szurubooru import db, search from szurubooru import db, search
from szurubooru.api.base_api import BaseApi
from szurubooru.func import auth, tags, util, snapshots from szurubooru.func import auth, tags, util, snapshots
from szurubooru.rest import routes
_search_executor = search.Executor(search.configs.TagSearchConfig())
def _serialize(ctx, tag): def _serialize(ctx, tag):
return tags.serialize_tag( return tags.serialize_tag(
@ -17,22 +19,18 @@ def _create_if_needed(tag_names, user):
for tag in new_tags: for tag in new_tags:
snapshots.save_entity_creation(tag, user) snapshots.save_entity_creation(tag, user)
class TagListApi(BaseApi): @routes.get('/tags/?')
def __init__(self): def get_tags(ctx, _params=None):
super().__init__()
self._search_executor = search.Executor(
search.configs.TagSearchConfig())
def get(self, ctx):
auth.verify_privilege(ctx.user, 'tags:list') auth.verify_privilege(ctx.user, 'tags:list')
return self._search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, lambda tag: _serialize(ctx, tag)) ctx, lambda tag: _serialize(ctx, tag))
def post(self, ctx): @routes.post('/tags/?')
def create_tag(ctx, _params=None):
auth.verify_privilege(ctx.user, 'tags:create') auth.verify_privilege(ctx.user, 'tags:create')
names = ctx.get_param_as_list('names', required=True) names = ctx.get_param_as_list('names', required=True)
category = ctx.get_param_as_string('category', required=True) or '' category = ctx.get_param_as_string('category', required=True)
description = ctx.get_param_as_string( description = ctx.get_param_as_string(
'description', required=False, default=None) 'description', required=False, default=None)
suggestions = ctx.get_param_as_list( suggestions = ctx.get_param_as_list(
@ -52,14 +50,15 @@ class TagListApi(BaseApi):
tags.export_to_json() tags.export_to_json()
return _serialize(ctx, tag) return _serialize(ctx, tag)
class TagDetailApi(BaseApi): @routes.get('/tag/(?P<tag_name>[^/]+)/?')
def get(self, ctx, tag_name): def get_tag(ctx, params):
auth.verify_privilege(ctx.user, 'tags:view') auth.verify_privilege(ctx.user, 'tags:view')
tag = tags.get_tag_by_name(tag_name) tag = tags.get_tag_by_name(params['tag_name'])
return _serialize(ctx, tag) return _serialize(ctx, tag)
def put(self, ctx, tag_name): @routes.put('/tag/(?P<tag_name>[^/]+)/?')
tag = tags.get_tag_by_name(tag_name) def update_tag(ctx, params):
tag = tags.get_tag_by_name(params['tag_name'])
util.verify_version(tag, ctx) util.verify_version(tag, ctx)
if ctx.has_param('names'): if ctx.has_param('names'):
auth.verify_privilege(ctx.user, 'tags:edit:names') auth.verify_privilege(ctx.user, 'tags:edit:names')
@ -67,7 +66,7 @@ class TagDetailApi(BaseApi):
if ctx.has_param('category'): if ctx.has_param('category'):
auth.verify_privilege(ctx.user, 'tags:edit:category') auth.verify_privilege(ctx.user, 'tags:edit:category')
tags.update_tag_category_name( tags.update_tag_category_name(
tag, ctx.get_param_as_string('category') or '') tag, ctx.get_param_as_string('category'))
if ctx.has_param('description'): if ctx.has_param('description'):
auth.verify_privilege(ctx.user, 'tags:edit:description') auth.verify_privilege(ctx.user, 'tags:edit:description')
tags.update_tag_description( tags.update_tag_description(
@ -90,8 +89,9 @@ class TagDetailApi(BaseApi):
tags.export_to_json() tags.export_to_json()
return _serialize(ctx, tag) return _serialize(ctx, tag)
def delete(self, ctx, tag_name): @routes.delete('/tag/(?P<tag_name>[^/]+)/?')
tag = tags.get_tag_by_name(tag_name) def delete_tag(ctx, params):
tag = tags.get_tag_by_name(params['tag_name'])
util.verify_version(tag, ctx) util.verify_version(tag, ctx)
auth.verify_privilege(ctx.user, 'tags:delete') auth.verify_privilege(ctx.user, 'tags:delete')
snapshots.save_entity_deletion(tag, ctx.user) snapshots.save_entity_deletion(tag, ctx.user)
@ -100,28 +100,26 @@ class TagDetailApi(BaseApi):
tags.export_to_json() tags.export_to_json()
return {} return {}
class TagMergeApi(BaseApi): @routes.post('/tag-merge/?')
def post(self, ctx): def merge_tags(ctx, _params=None):
source_tag_name = ctx.get_param_as_string('remove', required=True) or '' source_tag_name = ctx.get_param_as_string('remove', required=True) or ''
target_tag_name = ctx.get_param_as_string('mergeTo', required=True) or '' target_tag_name = ctx.get_param_as_string('mergeTo', required=True) or ''
source_tag = tags.get_tag_by_name(source_tag_name) source_tag = tags.get_tag_by_name(source_tag_name)
target_tag = tags.get_tag_by_name(target_tag_name) target_tag = tags.get_tag_by_name(target_tag_name)
util.verify_version(source_tag, ctx, 'removeVersion') util.verify_version(source_tag, ctx, 'removeVersion')
util.verify_version(target_tag, ctx, 'mergeToVersion') util.verify_version(target_tag, ctx, 'mergeToVersion')
if source_tag.tag_id == target_tag.tag_id:
raise tags.InvalidTagRelationError('Cannot merge tag with itself.')
auth.verify_privilege(ctx.user, 'tags:merge') auth.verify_privilege(ctx.user, 'tags:merge')
snapshots.save_entity_deletion(source_tag, ctx.user)
tags.merge_tags(source_tag, target_tag) tags.merge_tags(source_tag, target_tag)
snapshots.save_entity_deletion(source_tag, ctx.user)
util.bump_version(target_tag) util.bump_version(target_tag)
ctx.session.commit() ctx.session.commit()
tags.export_to_json() tags.export_to_json()
return _serialize(ctx, target_tag) return _serialize(ctx, target_tag)
class TagSiblingsApi(BaseApi): @routes.get('/tag-siblings/(?P<tag_name>[^/]+)/?')
def get(self, ctx, tag_name): def get_tag_siblings(ctx, params):
auth.verify_privilege(ctx.user, 'tags:view') auth.verify_privilege(ctx.user, 'tags:view')
tag = tags.get_tag_by_name(tag_name) tag = tags.get_tag_by_name(params['tag_name'])
result = tags.get_tag_siblings(tag) result = tags.get_tag_siblings(tag)
serialized_siblings = [] serialized_siblings = []
for sibling, occurrences in result: for sibling, occurrences in result:

View file

@ -1,19 +1,20 @@
from szurubooru.api.base_api import BaseApi from szurubooru.rest import routes
from szurubooru.func import auth, tags, tag_categories, util, snapshots from szurubooru.func import auth, tags, tag_categories, util, snapshots
def _serialize(ctx, category): def _serialize(ctx, category):
return tag_categories.serialize_category( return tag_categories.serialize_category(
category, options=util.get_serialization_options(ctx)) category, options=util.get_serialization_options(ctx))
class TagCategoryListApi(BaseApi): @routes.get('/tag-categories/?')
def get(self, ctx): def get_tag_categories(ctx, _params=None):
auth.verify_privilege(ctx.user, 'tag_categories:list') auth.verify_privilege(ctx.user, 'tag_categories:list')
categories = tag_categories.get_all_categories() categories = tag_categories.get_all_categories()
return { return {
'results': [_serialize(ctx, category) for category in categories], 'results': [_serialize(ctx, category) for category in categories],
} }
def post(self, ctx): @routes.post('/tag-categories/?')
def create_tag_category(ctx, _params=None):
auth.verify_privilege(ctx.user, 'tag_categories:create') auth.verify_privilege(ctx.user, 'tag_categories:create')
name = ctx.get_param_as_string('name', required=True) name = ctx.get_param_as_string('name', required=True)
color = ctx.get_param_as_string('color', required=True) color = ctx.get_param_as_string('color', required=True)
@ -25,14 +26,15 @@ class TagCategoryListApi(BaseApi):
tags.export_to_json() tags.export_to_json()
return _serialize(ctx, category) return _serialize(ctx, category)
class TagCategoryDetailApi(BaseApi): @routes.get('/tag-category/(?P<category_name>[^/]+)/?')
def get(self, ctx, category_name): def get_tag_category(ctx, params):
auth.verify_privilege(ctx.user, 'tag_categories:view') auth.verify_privilege(ctx.user, 'tag_categories:view')
category = tag_categories.get_category_by_name(category_name) category = tag_categories.get_category_by_name(params['category_name'])
return _serialize(ctx, category) return _serialize(ctx, category)
def put(self, ctx, category_name): @routes.put('/tag-category/(?P<category_name>[^/]+)/?')
category = tag_categories.get_category_by_name(category_name) def update_tag_category(ctx, params):
category = tag_categories.get_category_by_name(params['category_name'])
util.verify_version(category, ctx) util.verify_version(category, ctx)
if ctx.has_param('name'): if ctx.has_param('name'):
auth.verify_privilege(ctx.user, 'tag_categories:edit:name') auth.verify_privilege(ctx.user, 'tag_categories:edit:name')
@ -49,8 +51,9 @@ class TagCategoryDetailApi(BaseApi):
tags.export_to_json() tags.export_to_json()
return _serialize(ctx, category) return _serialize(ctx, category)
def delete(self, ctx, category_name): @routes.delete('/tag-category/(?P<category_name>[^/]+)/?')
category = tag_categories.get_category_by_name(category_name) def delete_tag_category(ctx, params):
category = tag_categories.get_category_by_name(params['category_name'])
util.verify_version(category, ctx) util.verify_version(category, ctx)
auth.verify_privilege(ctx.user, 'tag_categories:delete') auth.verify_privilege(ctx.user, 'tag_categories:delete')
tag_categories.delete_category(category) tag_categories.delete_category(category)
@ -59,10 +62,10 @@ class TagCategoryDetailApi(BaseApi):
tags.export_to_json() tags.export_to_json()
return {} return {}
class DefaultTagCategoryApi(BaseApi): @routes.put('/tag-category/(?P<category_name>[^/]+)/default/?')
def put(self, ctx, category_name): def set_tag_category_as_default(ctx, params):
auth.verify_privilege(ctx.user, 'tag_categories:set_default') auth.verify_privilege(ctx.user, 'tag_categories:set_default')
category = tag_categories.get_category_by_name(category_name) category = tag_categories.get_category_by_name(params['category_name'])
tag_categories.set_default_category(category) tag_categories.set_default_category(category)
snapshots.save_entity_modification(category, ctx.user) snapshots.save_entity_modification(category, ctx.user)
ctx.session.commit() ctx.session.commit()

View file

@ -1,6 +1,8 @@
from szurubooru import search from szurubooru import search
from szurubooru.api.base_api import BaseApi
from szurubooru.func import auth, users, util from szurubooru.func import auth, users, util
from szurubooru.rest import routes
_search_executor = search.Executor(search.configs.UserSearchConfig())
def _serialize(ctx, user, **kwargs): def _serialize(ctx, user, **kwargs):
return users.serialize_user( return users.serialize_user(
@ -9,18 +11,14 @@ def _serialize(ctx, user, **kwargs):
options=util.get_serialization_options(ctx), options=util.get_serialization_options(ctx),
**kwargs) **kwargs)
class UserListApi(BaseApi): @routes.get('/users/?')
def __init__(self): def get_users(ctx, _params=None):
super().__init__()
self._search_executor = search.Executor(
search.configs.UserSearchConfig())
def get(self, ctx):
auth.verify_privilege(ctx.user, 'users:list') auth.verify_privilege(ctx.user, 'users:list')
return self._search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, lambda user: _serialize(ctx, user)) ctx, lambda user: _serialize(ctx, user))
def post(self, ctx): @routes.post('/users/?')
def create_user(ctx, _params=None):
auth.verify_privilege(ctx.user, 'users:create') auth.verify_privilege(ctx.user, 'users:create')
name = ctx.get_param_as_string('name', required=True) name = ctx.get_param_as_string('name', required=True)
password = ctx.get_param_as_string('password', required=True) password = ctx.get_param_as_string('password', required=True)
@ -38,15 +36,16 @@ class UserListApi(BaseApi):
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, user, force_show_email=True) return _serialize(ctx, user, force_show_email=True)
class UserDetailApi(BaseApi): @routes.get('/user/(?P<user_name>[^/]+)/?')
def get(self, ctx, user_name): def get_user(ctx, params):
user = users.get_user_by_name(user_name) user = users.get_user_by_name(params['user_name'])
if ctx.user.user_id != user.user_id: if ctx.user.user_id != user.user_id:
auth.verify_privilege(ctx.user, 'users:view') auth.verify_privilege(ctx.user, 'users:view')
return _serialize(ctx, user) return _serialize(ctx, user)
def put(self, ctx, user_name): @routes.put('/user/(?P<user_name>[^/]+)/?')
user = users.get_user_by_name(user_name) def update_user(ctx, params):
user = users.get_user_by_name(params['user_name'])
util.verify_version(user, ctx) util.verify_version(user, ctx)
infix = 'self' if ctx.user.user_id == user.user_id else 'any' infix = 'self' if ctx.user.user_id == user.user_id else 'any'
if ctx.has_param('name'): if ctx.has_param('name'):
@ -73,8 +72,9 @@ class UserDetailApi(BaseApi):
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, user) return _serialize(ctx, user)
def delete(self, ctx, user_name): @routes.delete('/user/(?P<user_name>[^/]+)/?')
user = users.get_user_by_name(user_name) def delete_user(ctx, params):
user = users.get_user_by_name(params['user_name'])
util.verify_version(user, ctx) util.verify_version(user, ctx)
infix = 'self' if ctx.user.user_id == user.user_id else 'any' infix = 'self' if ctx.user.user_id == user.user_id else 'any'
auth.verify_privilege(ctx.user, 'users:delete:%s' % infix) auth.verify_privilege(ctx.user, 'users:delete:%s' % infix)

View file

@ -1,124 +0,0 @@
''' Exports create_app. '''
import os
import logging
import coloredlogs
import falcon
from szurubooru import api, config, errors, middleware
def _on_auth_error(ex, _request, _response, _params):
raise falcon.HTTPForbidden(
title='Authentication error', description=str(ex))
def _on_validation_error(ex, _request, _response, _params):
raise falcon.HTTPBadRequest(title='Validation error', description=str(ex))
def _on_search_error(ex, _request, _response, _params):
raise falcon.HTTPBadRequest(title='Search error', description=str(ex))
def _on_integrity_error(ex, _request, _response, _params):
raise falcon.HTTPConflict(
title='Integrity violation', description=ex.args[0])
def _on_not_found_error(ex, _request, _response, _params):
raise falcon.HTTPNotFound(title='Not found', description=str(ex))
def _on_processing_error(ex, _request, _response, _params):
raise falcon.HTTPBadRequest(title='Processing error', description=str(ex))
def create_method_not_allowed(allowed_methods):
allowed = ', '.join(allowed_methods)
def method_not_allowed(request, response, **_kwargs):
response.status = falcon.status_codes.HTTP_405
response.set_header('Allow', allowed)
request.context.output = {
'title': 'Method not allowed',
'description': 'Allowed methods: %r' % allowed_methods,
}
return method_not_allowed
def validate_config():
'''
Check whether config doesn't contain errors that might prove
lethal at runtime.
'''
from szurubooru.func.auth import RANK_MAP
for privilege, rank in config.config['privileges'].items():
if rank not in RANK_MAP.values():
raise errors.ConfigError(
'Rank %r for privilege %r is missing' % (rank, privilege))
if config.config['default_rank'] not in RANK_MAP.values():
raise errors.ConfigError(
'Default rank %r is not on the list of known ranks' % (
config.config['default_rank']))
for key in ['base_url', 'api_url', 'data_url', 'data_dir']:
if not config.config[key]:
raise errors.ConfigError(
'Service is not configured: %r is missing' % key)
if not os.path.isabs(config.config['data_dir']):
raise errors.ConfigError(
'data_dir must be an absolute path')
for key in ['schema', 'host', 'port', 'user', 'pass', 'name']:
if not config.config['database'][key]:
raise errors.ConfigError(
'Database is not configured: %r is missing' % key)
def create_app():
''' Create a WSGI compatible App object. '''
validate_config()
falcon.responders.create_method_not_allowed = create_method_not_allowed
coloredlogs.install(fmt='[%(asctime)-15s] %(name)s %(message)s')
if config.config['debug']:
logging.getLogger('szurubooru').setLevel(logging.INFO)
if config.config['show_sql']:
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
app = falcon.API(
request_type=api.Request,
middleware=[
middleware.RequireJson(),
middleware.CachePurger(),
middleware.ContextAdapter(),
middleware.DbSession(),
middleware.Authenticator(),
middleware.RequestLogger(),
])
app.add_error_handler(errors.AuthError, _on_auth_error)
app.add_error_handler(errors.IntegrityError, _on_integrity_error)
app.add_error_handler(errors.ValidationError, _on_validation_error)
app.add_error_handler(errors.SearchError, _on_search_error)
app.add_error_handler(errors.NotFoundError, _on_not_found_error)
app.add_error_handler(errors.ProcessingError, _on_processing_error)
app.add_route('/users/', api.UserListApi())
app.add_route('/user/{user_name}', api.UserDetailApi())
app.add_route('/password-reset/{user_name}', api.PasswordResetApi())
app.add_route('/tag-categories/', api.TagCategoryListApi())
app.add_route('/tag-category/{category_name}', api.TagCategoryDetailApi())
app.add_route('/tag-category/{category_name}/default', api.DefaultTagCategoryApi())
app.add_route('/tags/', api.TagListApi())
app.add_route('/tag/{tag_name}', api.TagDetailApi())
app.add_route('/tag-merge/', api.TagMergeApi())
app.add_route('/tag-siblings/{tag_name}', api.TagSiblingsApi())
app.add_route('/posts/', api.PostListApi())
app.add_route('/post/{post_id}', api.PostDetailApi())
app.add_route('/post/{post_id}/score', api.PostScoreApi())
app.add_route('/post/{post_id}/favorite', api.PostFavoriteApi())
app.add_route('/post/{post_id}/around', api.PostsAroundApi())
app.add_route('/comments/', api.CommentListApi())
app.add_route('/comment/{comment_id}', api.CommentDetailApi())
app.add_route('/comment/{comment_id}/score', api.CommentScoreApi())
app.add_route('/info/', api.InfoApi())
app.add_route('/featured-post/', api.PostFeatureApi())
app.add_route('/snapshots/', api.SnapshotListApi())
return app

View file

@ -0,0 +1,79 @@
''' Exports create_app. '''
import os
import logging
import coloredlogs
from szurubooru import config, errors, rest
# pylint: disable=unused-import
from szurubooru import api, middleware
def _on_auth_error(ex):
raise rest.errors.HttpForbidden(
title='Authentication error', description=str(ex))
def _on_validation_error(ex):
raise rest.errors.HttpBadRequest(
title='Validation error', description=str(ex))
def _on_search_error(ex):
raise rest.errors.HttpBadRequest(
title='Search error', description=str(ex))
def _on_integrity_error(ex):
raise rest.errors.HttpConflict(
title='Integrity violation', description=ex.args[0])
def _on_not_found_error(ex):
raise rest.errors.HttpNotFound(
title='Not found', description=str(ex))
def _on_processing_error(ex):
raise rest.errors.HttpBadRequest(
title='Processing error', description=str(ex))
def validate_config():
'''
Check whether config doesn't contain errors that might prove
lethal at runtime.
'''
from szurubooru.func.auth import RANK_MAP
for privilege, rank in config.config['privileges'].items():
if rank not in RANK_MAP.values():
raise errors.ConfigError(
'Rank %r for privilege %r is missing' % (rank, privilege))
if config.config['default_rank'] not in RANK_MAP.values():
raise errors.ConfigError(
'Default rank %r is not on the list of known ranks' % (
config.config['default_rank']))
for key in ['base_url', 'api_url', 'data_url', 'data_dir']:
if not config.config[key]:
raise errors.ConfigError(
'Service is not configured: %r is missing' % key)
if not os.path.isabs(config.config['data_dir']):
raise errors.ConfigError(
'data_dir must be an absolute path')
for key in ['schema', 'host', 'port', 'user', 'pass', 'name']:
if not config.config['database'][key]:
raise errors.ConfigError(
'Database is not configured: %r is missing' % key)
def create_app():
''' Create a WSGI compatible App object. '''
validate_config()
coloredlogs.install(fmt='[%(asctime)-15s] %(name)s %(message)s')
if config.config['debug']:
logging.getLogger('szurubooru').setLevel(logging.INFO)
if config.config['show_sql']:
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
rest.errors.handle(errors.AuthError, _on_auth_error)
rest.errors.handle(errors.ValidationError, _on_validation_error)
rest.errors.handle(errors.SearchError, _on_search_error)
rest.errors.handle(errors.IntegrityError, _on_integrity_error)
rest.errors.handle(errors.NotFoundError, _on_not_found_error)
rest.errors.handle(errors.ProcessingError, _on_processing_error)
return rest.application

View file

@ -32,13 +32,6 @@ def get_tag_category_snapshot(category):
'default': True if category.default else False, 'default': True if category.default else False,
} }
# pylint: disable=invalid-name
serializers = {
'tag': get_tag_snapshot,
'tag_category': get_tag_category_snapshot,
'post': get_post_snapshot,
}
def get_previous_snapshot(snapshot): def get_previous_snapshot(snapshot):
assert snapshot assert snapshot
return db.session \ return db.session \
@ -87,6 +80,12 @@ def get_serialized_history(entity):
def _save(operation, entity, auth_user): def _save(operation, entity, auth_user):
assert operation assert operation
assert entity assert entity
serializers = {
'tag': get_tag_snapshot,
'tag_category': get_tag_category_snapshot,
'post': get_post_snapshot,
}
resource_type, resource_id, resource_repr = db.util.get_resource_info(entity) resource_type, resource_id, resource_repr = db.util.get_resource_info(entity)
now = datetime.datetime.utcnow() now = datetime.datetime.utcnow()

View file

@ -11,6 +11,9 @@ def snake_case_to_lower_camel_case(text):
return components[0].lower() + \ return components[0].lower() + \
''.join(word[0].upper() + word[1:].lower() for word in components[1:]) ''.join(word[0].upper() + word[1:].lower() for word in components[1:])
def snake_case_to_upper_train_case(text):
return '-'.join(word[0].upper() + word[1:].lower() for word in text.split('_'))
def snake_case_to_lower_camel_case_keys(source): def snake_case_to_lower_camel_case_keys(source):
target = {} target = {}
for key, value in source.items(): for key, value in source.items():

View file

@ -1,8 +1,6 @@
''' Various hooks that get executed for each request. ''' ''' Various hooks that get executed for each request. '''
from szurubooru.middleware.authenticator import Authenticator import szurubooru.middleware.db_session
from szurubooru.middleware.context_adapter import ContextAdapter import szurubooru.middleware.authenticator
from szurubooru.middleware.require_json import RequireJson import szurubooru.middleware.cache_purger
from szurubooru.middleware.db_session import DbSession import szurubooru.middleware.request_logger
from szurubooru.middleware.cache_purger import CachePurger
from szurubooru.middleware.request_logger import RequestLogger

View file

@ -1,51 +1,44 @@
import base64 import base64
import falcon
from szurubooru import db, errors from szurubooru import db, errors
from szurubooru.func import auth, users from szurubooru.func import auth, users
from szurubooru.rest import middleware
from szurubooru.rest.errors import HttpBadRequest
class Authenticator(object): def _authenticate(username, password):
'''
Authenticates every request and put information on active user in the
request context.
'''
def process_request(self, request, _response):
''' Bind the user to request. Update last login time if needed. '''
request.context.user = self._get_user(request)
if request.get_param_as_bool('bump-login') \
and request.context.user.user_id:
users.bump_user_login_time(request.context.user)
request.context.session.commit()
def _get_user(self, request):
if not request.auth:
return self._create_anonymous_user()
try:
auth_type, user_and_password = request.auth.split(' ', 1)
if auth_type.lower() != 'basic':
raise falcon.HTTPBadRequest(
'Invalid authentication type',
'Only basic authorization is supported.')
username, password = base64.decodebytes(
user_and_password.encode('ascii')).decode('utf8').split(':')
return self._authenticate(username, password)
except ValueError as err:
msg = 'Basic authentication header value not properly formed. ' \
+ 'Supplied header {0}. Got error: {1}'
raise falcon.HTTPBadRequest(
'Malformed authentication request',
msg.format(request.auth, str(err)))
def _authenticate(self, username, password):
''' Try to authenticate user. Throw AuthError for invalid users. ''' ''' Try to authenticate user. Throw AuthError for invalid users. '''
user = users.get_user_by_name(username) user = users.get_user_by_name(username)
if not auth.is_valid_password(user, password): if not auth.is_valid_password(user, password):
raise errors.AuthError('Invalid password.') raise errors.AuthError('Invalid password.')
return user return user
def _create_anonymous_user(self): def _create_anonymous_user():
user = db.User() user = db.User()
user.name = None user.name = None
user.rank = 'anonymous' user.rank = 'anonymous'
return user return user
def _get_user(ctx):
if not ctx.has_header('Authorization'):
return _create_anonymous_user()
try:
auth_type, user_and_password = ctx.get_header('Authorization').split(' ', 1)
if auth_type.lower() != 'basic':
raise HttpBadRequest(
'Only basic HTTP authentication is supported.')
username, password = base64.decodebytes(
user_and_password.encode('ascii')).decode('utf8').split(':')
return _authenticate(username, password)
except ValueError as err:
msg = 'Basic authentication header value are not properly formed. ' \
+ 'Supplied header {0}. Got error: {1}'
raise HttpBadRequest(
msg.format(ctx.get_header('Authorization'), str(err)))
@middleware.pre_hook
def process_request(ctx):
''' Bind the user to request. Update last login time if needed. '''
ctx.user = _get_user(ctx)
if ctx.get_param_as_bool('bump-login') and ctx.user.user_id:
users.bump_user_login_time(ctx.user)
ctx.session.commit()

View file

@ -1,6 +1,7 @@
from szurubooru.func import cache from szurubooru.func import cache
from szurubooru.rest import middleware
class CachePurger(object): @middleware.pre_hook
def process_request(self, request, _response): def process_request(ctx):
if request.method != 'GET': if ctx.method != 'GET':
cache.purge() cache.purge()

View file

@ -1,65 +0,0 @@
import cgi
import datetime
import json
import falcon
def json_serializer(obj):
''' JSON serializer for objects not serializable by default JSON code '''
if isinstance(obj, datetime.datetime):
serial = obj.isoformat('T') + 'Z'
return serial
raise TypeError('Type not serializable')
class ContextAdapter(object):
'''
1. Deserialize API requests into the context:
- Pass GET parameters
- Handle multipart/form-data file uploads
- Handle JSON requests
2. Serialize API responses from the context as JSON.
'''
def process_request(self, request, _response):
request.context.files = {}
request.context.input = {}
request.context.output = None
# pylint: disable=protected-access
for key, value in request._params.items():
request.context.input[key] = value
if request.content_length in (None, 0):
return
if request.content_type and 'multipart/form-data' in request.content_type:
# obscure, claims to "avoid a bug in cgi.FieldStorage"
request.env.setdefault('QUERY_STRING', '')
form = cgi.FieldStorage(fp=request.stream, environ=request.env)
for key in form:
if key != 'metadata':
_original_file_name = getattr(form[key], 'filename', None)
request.context.files[key] = form.getvalue(key)
body = form.getvalue('metadata')
else:
body = request.stream.read()
if not body:
raise falcon.HTTPBadRequest(
'Empty request body',
'A valid JSON document is required.')
try:
if isinstance(body, bytes):
body = body.decode('utf-8')
for key, value in json.loads(body).items():
request.context.input[key] = value
except (ValueError, UnicodeDecodeError):
raise falcon.HTTPBadRequest(
'Malformed JSON',
'Could not decode the request body. The '
'JSON was incorrect or not encoded as UTF-8.')
def process_response(self, request, response, _resource):
if request.context.output:
response.body = json.dumps(
request.context.output, default=json_serializer, indent=2)

View file

@ -1,14 +1,11 @@
import logging
from szurubooru import db from szurubooru import db
from szurubooru.rest import middleware
logger = logging.getLogger(__name__) @middleware.pre_hook
def _process_request(ctx):
class DbSession(object): ctx.session = db.session()
''' Attaches database session to the context of every request. '''
def process_request(self, request, _response):
request.context.session = db.session()
db.reset_query_count() db.reset_query_count()
def process_response(self, _request, _response, _resource): @middleware.post_hook
def _process_response(_ctx):
db.session.remove() db.session.remove()

View file

@ -1,16 +1,14 @@
import logging import logging
from szurubooru import db from szurubooru import db
from szurubooru.rest import middleware
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RequestLogger(object): @middleware.post_hook
def process_request(self, request, _response): def process_response(ctx):
pass
def process_response(self, request, _response, _resource):
logger.info( logger.info(
'%s %s (user=%s, queries=%d)', '%s %s (user=%s, queries=%d)',
request.method, ctx.method,
request.url, ctx.url,
request.context.user.name, ctx.user.name,
db.get_query_count()) db.get_query_count())

View file

@ -1,9 +0,0 @@
import falcon
class RequireJson(object):
''' Sanitizes requests so that only JSON is accepted. '''
def process_request(self, request, _response):
if not request.client_accepts_json:
raise falcon.HTTPNotAcceptable(
'This API only supports responses encoded as JSON.')

View file

@ -0,0 +1,2 @@
from szurubooru.rest.app import application
from szurubooru.rest.context import Context

View file

@ -0,0 +1,124 @@
import cgi
import io
import json
import re
from datetime import datetime
from szurubooru.func import util
from szurubooru.rest import errors, middleware, routes, context
def _json_serializer(obj):
''' JSON serializer for objects not serializable by default JSON code '''
if isinstance(obj, datetime):
serial = obj.isoformat('T') + 'Z'
return serial
raise TypeError('Type not serializable')
def _dump_json(obj):
return json.dumps(obj, default=_json_serializer, indent=2)
def _read(env):
length = int(env.get('CONTENT_LENGTH', 0))
output = io.BytesIO()
while length > 0:
part = env['wsgi.input'].read(min(length, 1024*200))
if not part:
break
output.write(part)
length -= len(part)
output.seek(0)
return output
def _get_headers(env):
headers = {}
for key, value in env.items():
if key.startswith('HTTP_'):
key = util.snake_case_to_upper_train_case(key[5:])
headers[key] = value
return headers
def _create_context(env):
method = env['REQUEST_METHOD']
path = '/' + env['PATH_INFO'].lstrip('/')
headers = _get_headers(env)
# obscure, claims to "avoid a bug in cgi.FieldStorage"
env.setdefault('QUERY_STRING', '')
files = {}
params = {}
request_stream = _read(env)
form = cgi.FieldStorage(fp=request_stream, environ=env)
if form.list:
for key in form:
if key != 'metadata':
if isinstance(form[key], cgi.MiniFieldStorage):
params[key] = form.getvalue(key)
else:
_original_file_name = getattr(form[key], 'filename', None)
files[key] = form.getvalue(key)
if 'metadata' in form:
body = form.getvalue('metadata')
else:
body = request_stream.read()
else:
body = None
if body:
try:
if isinstance(body, bytes):
body = body.decode('utf-8')
for key, value in json.loads(body).items():
params[key] = value
except (ValueError, UnicodeDecodeError):
raise errors.HttpBadRequest(
'Could not decode the request body. The JSON '
'was incorrect or was not encoded as UTF-8.')
return context.Context(method, path, headers, params, files)
def application(env, start_response):
try:
ctx = _create_context(env)
if not 'application/json' in ctx.get_header('Accept'):
raise errors.HttpNotAcceptable(
'This API only supports JSON responses.')
for url, allowed_methods in routes.routes.items():
match = re.fullmatch(url, ctx.url)
if not match:
continue
if ctx.method not in allowed_methods:
raise errors.HttpMethodNotAllowed(
'Allowed methods: %r' % allowed_methods)
for hook in middleware.pre_hooks:
hook(ctx)
handler = allowed_methods[ctx.method]
try:
response = handler(ctx, match.groupdict())
except Exception as ex:
for exception_type, handler in errors.error_handlers.items():
if isinstance(ex, exception_type):
handler(ex)
raise
finally:
for hook in middleware.post_hooks:
hook(ctx)
start_response('200', [('content-type', 'application/json')])
return (_dump_json(response).encode('utf-8'),)
raise errors.HttpNotFound(
'Requested path ' + ctx.url + ' was not found.')
except errors.BaseHttpError as ex:
start_response(
'%d %s' % (ex.code, ex.reason),
[('content-type', 'application/json')])
return (_dump_json({
'title': ex.title,
'description': ex.description,
}).encode('utf-8'),)

View file

@ -1,4 +1,3 @@
import falcon
from szurubooru import errors from szurubooru import errors
from szurubooru.func import net from szurubooru.func import net
@ -7,8 +6,9 @@ def _lower_first(source):
def _param_wrapper(func): def _param_wrapper(func):
def wrapper(self, name, required=False, default=None, **kwargs): def wrapper(self, name, required=False, default=None, **kwargs):
if name in self.input: # pylint: disable=protected-access
value = self.input[name] if name in self._params:
value = self._params[name]
try: try:
value = func(self, value, **kwargs) value = func(self, value, **kwargs)
except errors.InvalidParameterError as ex: except errors.InvalidParameterError as ex:
@ -22,34 +22,46 @@ def _param_wrapper(func):
'Required parameter %r is missing.' % name) 'Required parameter %r is missing.' % name)
return wrapper return wrapper
class Context(object): class Context():
def __init__(self): # pylint: disable=too-many-arguments
self.session = None def __init__(self, method, url, headers=None, params=None, files=None):
self.user = None self.method = method
self.files = {} self.url = url
self.input = {} self._headers = headers or {}
self.output = None self._params = params or {}
self.settings = {} self._files = files or {}
def has_param(self, name): # provided by middleware
return name in self.input # self.session = None
# self.user = None
def has_header(self, name):
return name in self._headers
def get_header(self, name):
return self._headers.get(name, None)
def has_file(self, name): def has_file(self, name):
return name in self.files or name + 'Url' in self.input return name in self._files or name + 'Url' in self._params
def get_file(self, name, required=False): def get_file(self, name, required=False):
if name in self.files: if name in self._files:
return self.files[name] return self._files[name]
if name + 'Url' in self.input: if name + 'Url' in self._params:
return net.download(self.input[name + 'Url']) return net.download(self._params[name + 'Url'])
if not required: if not required:
return None return None
raise errors.MissingRequiredFileError( raise errors.MissingRequiredFileError(
'Required file %r is missing.' % name) 'Required file %r is missing.' % name)
def has_param(self, name):
return name in self._params
@_param_wrapper @_param_wrapper
def get_param_as_list(self, value): def get_param_as_list(self, value):
if not isinstance(value, list): if not isinstance(value, list):
if ',' in value:
return value.split(',')
return [value] return [value]
return value return value
@ -86,6 +98,3 @@ class Context(object):
if value in ['0', 'n', 'no', 'nope', 'f', 'false']: if value in ['0', 'n', 'no', 'nope', 'f', 'false']:
return False return False
raise errors.InvalidParameterError('The value must be a boolean value.') raise errors.InvalidParameterError('The value must be a boolean value.')
class Request(falcon.Request):
context_type = Context

View file

@ -0,0 +1,37 @@
error_handlers = {} # pylint: disable=invalid-name
class BaseHttpError(RuntimeError):
code = None
reason = None
def __init__(self, description, title=None):
super().__init__()
self.description = description
self.title = title or self.reason
class HttpBadRequest(BaseHttpError):
code = 400
reason = 'Bad Request'
class HttpForbidden(BaseHttpError):
code = 403
reason = 'Forbidden'
class HttpNotFound(BaseHttpError):
code = 404
reason = 'Not Found'
class HttpNotAcceptable(BaseHttpError):
code = 406
reason = 'Not Acceptable'
class HttpConflict(BaseHttpError):
code = 409
reason = 'Conflict'
class HttpMethodNotAllowed(BaseHttpError):
code = 405
reason = 'Method Not Allowed'
def handle(exception_type, handler):
error_handlers[exception_type] = handler

View file

@ -0,0 +1,9 @@
# pylint: disable=invalid-name
pre_hooks = []
post_hooks = []
def pre_hook(handler):
pre_hooks.append(handler)
def post_hook(handler):
post_hooks.insert(0, handler)

View file

@ -0,0 +1,27 @@
from collections import defaultdict
routes = defaultdict(dict) # pylint: disable=invalid-name
def get(url):
def wrapper(handler):
routes[url]['GET'] = handler
return handler
return wrapper
def put(url):
def wrapper(handler):
routes[url]['PUT'] = handler
return handler
return wrapper
def post(url):
def wrapper(handler):
routes[url]['POST'] = handler
return handler
return wrapper
def delete(url):
def wrapper(handler):
routes[url]['DELETE'] = handler
return handler
return wrapper

View file

@ -1,89 +1,78 @@
import datetime
import pytest import pytest
import unittest.mock
from datetime import datetime
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import util, posts from szurubooru.func import comments, posts
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx( def inject_config(config_injector):
tmpdir, config_injector, context_factory, post_factory, user_factory): config_injector({'privileges': {'comments:create': db.User.RANK_REGULAR}})
config_injector({
'data_dir': str(tmpdir),
'data_url': 'http://example.com',
'privileges': {'comments:create': db.User.RANK_REGULAR},
'thumbnails': {'avatar_width': 200},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.post_factory = post_factory
ret.user_factory = user_factory
ret.api = api.CommentListApi()
return ret
def test_creating_comment(test_ctx, fake_datetime): def test_creating_comment(
post = test_ctx.post_factory() user_factory, post_factory, context_factory, fake_datetime):
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) post = post_factory()
user = user_factory(rank=db.User.RANK_REGULAR)
db.session.add_all([post, user]) db.session.add_all([post, user])
db.session.flush() db.session.flush()
with fake_datetime('1997-01-01'): with unittest.mock.patch('szurubooru.func.comments.serialize_comment'), \
result = test_ctx.api.post( fake_datetime('1997-01-01'):
test_ctx.context_factory( comments.serialize_comment.return_value = 'serialized comment'
input={'text': 'input', 'postId': post.post_id}, result = api.comment_api.create_comment(
context_factory(
params={'text': 'input', 'postId': post.post_id},
user=user)) user=user))
assert result['text'] == 'input' assert result == 'serialized comment'
assert 'id' in result
assert 'user' in result
assert 'name' in result['user']
assert 'postId' in result
comment = db.session.query(db.Comment).one() comment = db.session.query(db.Comment).one()
assert comment.text == 'input' assert comment.text == 'input'
assert comment.creation_time == datetime.datetime(1997, 1, 1) assert comment.creation_time == datetime(1997, 1, 1)
assert comment.last_edit_time is None assert comment.last_edit_time is None
assert comment.user and comment.user.user_id == user.user_id assert comment.user and comment.user.user_id == user.user_id
assert comment.post and comment.post.post_id == post.post_id assert comment.post and comment.post.post_id == post.post_id
@pytest.mark.parametrize('input', [ @pytest.mark.parametrize('params', [
{'text': None}, {'text': None},
{'text': ''}, {'text': ''},
{'text': [None]}, {'text': [None]},
{'text': ['']}, {'text': ['']},
]) ])
def test_trying_to_pass_invalid_input(test_ctx, input): def test_trying_to_pass_invalid_params(
post = test_ctx.post_factory() user_factory, post_factory, context_factory, params):
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) post = post_factory()
user = user_factory(rank=db.User.RANK_REGULAR)
db.session.add_all([post, user]) db.session.add_all([post, user])
db.session.flush() db.session.flush()
real_input = {'text': 'input', 'postId': post.post_id} real_params = {'text': 'input', 'postId': post.post_id}
for key, value in input.items(): for key, value in params.items():
real_input[key] = value real_params[key] = value
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
test_ctx.api.post( api.comment_api.create_comment(
test_ctx.context_factory(input=real_input, user=user)) context_factory(params=real_params, user=user))
@pytest.mark.parametrize('field', ['text', 'postId']) @pytest.mark.parametrize('field', ['text', 'postId'])
def test_trying_to_omit_mandatory_field(test_ctx, field): def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
input = { params = {
'text': 'input', 'text': 'input',
'postId': 1, 'postId': 1,
} }
del input[field] del params[field]
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
test_ctx.api.post( api.comment_api.create_comment(
test_ctx.context_factory( context_factory(
input={}, params={},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_comment_non_existing(test_ctx): def test_trying_to_comment_non_existing(user_factory, context_factory):
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=db.User.RANK_REGULAR)
db.session.add_all([user]) db.session.add_all([user])
db.session.flush() db.session.flush()
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
test_ctx.api.post( api.comment_api.create_comment(
test_ctx.context_factory( context_factory(
input={'text': 'bad', 'postId': 5}, user=user)) params={'text': 'bad', 'postId': 5}, user=user))
def test_trying_to_create_without_privileges(test_ctx): def test_trying_to_create_without_privileges(user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.post( api.comment_api.create_comment(
test_ctx.context_factory( context_factory(
input={}, params={},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=db.User.RANK_ANONYMOUS)))

View file

@ -1,61 +1,56 @@
import pytest import pytest
from datetime import datetime
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import util, comments from szurubooru.func import comments
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx(config_injector, context_factory, user_factory, comment_factory): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'comments:delete:own': db.User.RANK_REGULAR, 'comments:delete:own': db.User.RANK_REGULAR,
'comments:delete:any': db.User.RANK_MODERATOR, 'comments:delete:any': db.User.RANK_MODERATOR,
}, },
}) })
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.comment_factory = comment_factory
ret.api = api.CommentDetailApi()
return ret
def test_deleting_own_comment(test_ctx): def test_deleting_own_comment(user_factory, comment_factory, context_factory):
user = test_ctx.user_factory() user = user_factory()
comment = test_ctx.comment_factory(user=user) comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
result = test_ctx.api.delete( result = api.comment_api.delete_comment(
test_ctx.context_factory(input={'version': 1}, user=user), context_factory(params={'version': 1}, user=user),
comment.comment_id) {'comment_id': comment.comment_id})
assert result == {} assert result == {}
assert db.session.query(db.Comment).count() == 0 assert db.session.query(db.Comment).count() == 0
def test_deleting_someones_else_comment(test_ctx): def test_deleting_someones_else_comment(
user1 = test_ctx.user_factory(rank=db.User.RANK_REGULAR) user_factory, comment_factory, context_factory):
user2 = test_ctx.user_factory(rank=db.User.RANK_MODERATOR) user1 = user_factory(rank=db.User.RANK_REGULAR)
comment = test_ctx.comment_factory(user=user1) user2 = user_factory(rank=db.User.RANK_MODERATOR)
comment = comment_factory(user=user1)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
result = test_ctx.api.delete( result = api.comment_api.delete_comment(
test_ctx.context_factory(input={'version': 1}, user=user2), context_factory(params={'version': 1}, user=user2),
comment.comment_id) {'comment_id': comment.comment_id})
assert db.session.query(db.Comment).count() == 0 assert db.session.query(db.Comment).count() == 0
def test_trying_to_delete_someones_else_comment_without_privileges(test_ctx): def test_trying_to_delete_someones_else_comment_without_privileges(
user1 = test_ctx.user_factory(rank=db.User.RANK_REGULAR) user_factory, comment_factory, context_factory):
user2 = test_ctx.user_factory(rank=db.User.RANK_REGULAR) user1 = user_factory(rank=db.User.RANK_REGULAR)
comment = test_ctx.comment_factory(user=user1) user2 = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user1)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.delete( api.comment_api.delete_comment(
test_ctx.context_factory(input={'version': 1}, user=user2), context_factory(params={'version': 1}, user=user2),
comment.comment_id) {'comment_id': comment.comment_id})
assert db.session.query(db.Comment).count() == 1 assert db.session.query(db.Comment).count() == 1
def test_trying_to_delete_non_existing(test_ctx): def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(comments.CommentNotFoundError): with pytest.raises(comments.CommentNotFoundError):
test_ctx.api.delete( api.comment_api.delete_comment(
test_ctx.context_factory( context_factory(
input={'version': 1}, params={'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
1) {'comment_id': 1})

View file

@ -1,152 +1,134 @@
import datetime
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import util, comments, scores from szurubooru.func import comments, scores
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx( def inject_config(config_injector):
tmpdir, config_injector, context_factory, user_factory, comment_factory): config_injector({'privileges': {'comments:score': db.User.RANK_REGULAR}})
config_injector({
'data_dir': str(tmpdir),
'data_url': 'http://example.com',
'privileges': {
'comments:score': db.User.RANK_REGULAR,
'users:edit:any:email': db.User.RANK_MODERATOR,
},
'thumbnails': {'avatar_width': 200},
})
db.session.flush()
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.comment_factory = comment_factory
ret.api = api.CommentScoreApi()
return ret
def test_simple_rating(test_ctx, fake_datetime): def test_simple_rating(
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) user_factory, comment_factory, context_factory, fake_datetime):
comment = test_ctx.comment_factory(user=user) user = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
comments.serialize_comment.return_value = 'serialized comment'
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = test_ctx.api.put( result = api.comment_api.set_comment_score(
test_ctx.context_factory(input={'score': 1}, user=user), context_factory(params={'score': 1}, user=user),
comment.comment_id) {'comment_id': comment.comment_id})
assert 'text' in result assert result == 'serialized comment'
assert db.session.query(db.CommentScore).count() == 1 assert db.session.query(db.CommentScore).count() == 1
assert comment is not None assert comment is not None
assert comment.score == 1 assert comment.score == 1
def test_updating_rating(test_ctx, fake_datetime): def test_updating_rating(
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) user_factory, comment_factory, context_factory, fake_datetime):
comment = test_ctx.comment_factory(user=user) user = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = test_ctx.api.put( result = api.comment_api.set_comment_score(
test_ctx.context_factory(input={'score': 1}, user=user), context_factory(params={'score': 1}, user=user),
comment.comment_id) {'comment_id': comment.comment_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = test_ctx.api.put( result = api.comment_api.set_comment_score(
test_ctx.context_factory(input={'score': -1}, user=user), context_factory(params={'score': -1}, user=user),
comment.comment_id) {'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one() comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 1 assert db.session.query(db.CommentScore).count() == 1
assert comment.score == -1 assert comment.score == -1
def test_updating_rating_to_zero(test_ctx, fake_datetime): def test_updating_rating_to_zero(
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) user_factory, comment_factory, context_factory, fake_datetime):
comment = test_ctx.comment_factory(user=user) user = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = test_ctx.api.put( result = api.comment_api.set_comment_score(
test_ctx.context_factory(input={'score': 1}, user=user), context_factory(params={'score': 1}, user=user),
comment.comment_id) {'comment_id': comment.comment_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = test_ctx.api.put( result = api.comment_api.set_comment_score(
test_ctx.context_factory(input={'score': 0}, user=user), context_factory(params={'score': 0}, user=user),
comment.comment_id) {'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one() comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 0 assert db.session.query(db.CommentScore).count() == 0
assert comment.score == 0 assert comment.score == 0
def test_deleting_rating(test_ctx, fake_datetime): def test_deleting_rating(
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) user_factory, comment_factory, context_factory, fake_datetime):
comment = test_ctx.comment_factory(user=user) user = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = test_ctx.api.put( result = api.comment_api.set_comment_score(
test_ctx.context_factory(input={'score': 1}, user=user), context_factory(params={'score': 1}, user=user),
comment.comment_id) {'comment_id': comment.comment_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = test_ctx.api.delete( result = api.comment_api.delete_comment_score(
test_ctx.context_factory(user=user), comment.comment_id) context_factory(user=user),
{'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one() comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 0 assert db.session.query(db.CommentScore).count() == 0
assert comment.score == 0 assert comment.score == 0
def test_ratings_from_multiple_users(test_ctx, fake_datetime): def test_ratings_from_multiple_users(
user1 = test_ctx.user_factory(rank=db.User.RANK_REGULAR) user_factory, comment_factory, context_factory, fake_datetime):
user2 = test_ctx.user_factory(rank=db.User.RANK_REGULAR) user1 = user_factory(rank=db.User.RANK_REGULAR)
comment = test_ctx.comment_factory() user2 = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory()
db.session.add_all([user1, user2, comment]) db.session.add_all([user1, user2, comment])
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = test_ctx.api.put( result = api.comment_api.set_comment_score(
test_ctx.context_factory(input={'score': 1}, user=user1), context_factory(params={'score': 1}, user=user1),
comment.comment_id) {'comment_id': comment.comment_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = test_ctx.api.put( result = api.comment_api.set_comment_score(
test_ctx.context_factory(input={'score': -1}, user=user2), context_factory(params={'score': -1}, user=user2),
comment.comment_id) {'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one() comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 2 assert db.session.query(db.CommentScore).count() == 2
assert comment.score == 0 assert comment.score == 0
@pytest.mark.parametrize('input,expected_exception', [ def test_trying_to_omit_mandatory_field(
({'score': None}, errors.ValidationError), user_factory, comment_factory, context_factory):
({'score': ''}, errors.ValidationError), user = user_factory()
({'score': -2}, scores.InvalidScoreValueError), comment = comment_factory(user=user)
({'score': 2}, scores.InvalidScoreValueError),
({'score': [1]}, errors.ValidationError),
])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
user = test_ctx.user_factory()
comment = test_ctx.comment_factory(user=user)
db.session.add(comment)
db.session.commit()
with pytest.raises(expected_exception):
test_ctx.api.put(
test_ctx.context_factory(input=input, user=user),
comment.comment_id)
def test_trying_to_omit_mandatory_field(test_ctx):
user = test_ctx.user_factory()
comment = test_ctx.comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
test_ctx.api.put( api.comment_api.set_comment_score(
test_ctx.context_factory(input={}, user=user), context_factory(params={}, user=user),
comment.comment_id) {'comment_id': comment.comment_id})
def test_trying_to_update_non_existing(test_ctx): def test_trying_to_update_non_existing(
user_factory, comment_factory, context_factory):
with pytest.raises(comments.CommentNotFoundError): with pytest.raises(comments.CommentNotFoundError):
test_ctx.api.put( api.comment_api.set_comment_score(
test_ctx.context_factory( context_factory(
input={'score': 1}, params={'score': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
5) {'comment_id': 5})
def test_trying_to_rate_without_privileges(test_ctx): def test_trying_to_rate_without_privileges(
comment = test_ctx.comment_factory() user_factory, comment_factory, context_factory):
comment = comment_factory()
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.put( api.comment_api.set_comment_score(
test_ctx.context_factory( context_factory(
input={'score': 1}, params={'score': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=db.User.RANK_ANONYMOUS)),
comment.comment_id) {'comment_id': comment.comment_id})

View file

@ -1,76 +1,65 @@
import datetime
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import util, comments from szurubooru.func import comments
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx( def inject_config(config_injector):
tmpdir, context_factory, config_injector, user_factory, comment_factory):
config_injector({ config_injector({
'data_dir': str(tmpdir),
'data_url': 'http://example.com',
'privileges': { 'privileges': {
'comments:list': db.User.RANK_REGULAR, 'comments:list': db.User.RANK_REGULAR,
'comments:view': db.User.RANK_REGULAR, 'comments:view': db.User.RANK_REGULAR,
'users:edit:any:email': db.User.RANK_MODERATOR,
}, },
'thumbnails': {'avatar_width': 200},
}) })
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.comment_factory = comment_factory
ret.list_api = api.CommentListApi()
ret.detail_api = api.CommentDetailApi()
return ret
def test_retrieving_multiple(test_ctx): def test_retrieving_multiple(user_factory, comment_factory, context_factory):
comment1 = test_ctx.comment_factory(text='text 1') comment1 = comment_factory(text='text 1')
comment2 = test_ctx.comment_factory(text='text 2') comment2 = comment_factory(text='text 2')
db.session.add_all([comment1, comment2]) db.session.add_all([comment1, comment2])
result = test_ctx.list_api.get( with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
test_ctx.context_factory( comments.serialize_comment.return_value = 'serialized comment'
input={'query': '', 'page': 1}, result = api.comment_api.get_comments(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) context_factory(
assert result['query'] == '' params={'query': '', 'page': 1},
assert result['page'] == 1 user=user_factory(rank=db.User.RANK_REGULAR)))
assert result['pageSize'] == 100 assert result == {
assert result['total'] == 2 'query': '',
assert [c['text'] for c in result['results']] == ['text 1', 'text 2'] 'page': 1,
'pageSize': 100,
'total': 2,
'results': ['serialized comment', 'serialized comment'],
}
def test_trying_to_retrieve_multiple_without_privileges(test_ctx): def test_trying_to_retrieve_multiple_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.list_api.get( api.comment_api.get_comments(
test_ctx.context_factory( context_factory(
input={'query': '', 'page': 1}, params={'query': '', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_retrieving_single(test_ctx): def test_retrieving_single(user_factory, comment_factory, context_factory):
comment = test_ctx.comment_factory(text='dummy text') comment = comment_factory(text='dummy text')
db.session.add(comment) db.session.add(comment)
db.session.flush() db.session.flush()
result = test_ctx.detail_api.get( with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
test_ctx.context_factory( comments.serialize_comment.return_value = 'serialized comment'
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), result = api.comment_api.get_comment(
comment.comment_id) context_factory(
assert 'id' in result user=user_factory(rank=db.User.RANK_REGULAR)),
assert 'lastEditTime' in result {'comment_id': comment.comment_id})
assert 'creationTime' in result assert result == 'serialized comment'
assert 'text' in result
assert 'user' in result
assert 'name' in result['user']
assert 'postId' in result
def test_trying_to_retrieve_single_non_existing(test_ctx): def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(comments.CommentNotFoundError): with pytest.raises(comments.CommentNotFoundError):
test_ctx.detail_api.get( api.comment_api.get_comment(
test_ctx.context_factory( context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
5) {'comment_id': 5})
def test_trying_to_retrieve_single_without_privileges(test_ctx): def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.detail_api.get( api.comment_api.get_comment(
test_ctx.context_factory( context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), {'comment_id': 5})
5)

View file

@ -1,103 +1,94 @@
import datetime
import pytest import pytest
import unittest.mock
from datetime import datetime
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import util, comments from szurubooru.func import comments
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx( def inject_config(config_injector):
tmpdir, config_injector, context_factory, user_factory, comment_factory):
config_injector({ config_injector({
'data_dir': str(tmpdir),
'data_url': 'http://example.com',
'privileges': { 'privileges': {
'comments:edit:own': db.User.RANK_REGULAR, 'comments:edit:own': db.User.RANK_REGULAR,
'comments:edit:any': db.User.RANK_MODERATOR, 'comments:edit:any': db.User.RANK_MODERATOR,
'users:edit:any:email': db.User.RANK_MODERATOR,
}, },
'thumbnails': {'avatar_width': 200},
}) })
db.session.flush()
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.comment_factory = comment_factory
ret.api = api.CommentDetailApi()
return ret
def test_simple_updating(test_ctx, fake_datetime): def test_simple_updating(
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) user_factory, comment_factory, context_factory, fake_datetime):
comment = test_ctx.comment_factory(user=user) user = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
with fake_datetime('1997-12-01'): with unittest.mock.patch('szurubooru.func.comments.serialize_comment'), \
result = test_ctx.api.put( fake_datetime('1997-12-01'):
test_ctx.context_factory( comments.serialize_comment.return_value = 'serialized comment'
input={'text': 'new text', 'version': 1}, user=user), result = api.comment_api.update_comment(
comment.comment_id) context_factory(
assert result['text'] == 'new text' params={'text': 'new text', 'version': 1}, user=user),
comment = db.session.query(db.Comment).one() {'comment_id': comment.comment_id})
assert comment is not None assert result == 'serialized comment'
assert comment.text == 'new text' assert comment.last_edit_time == datetime(1997, 12, 1)
assert comment.last_edit_time is not None
@pytest.mark.parametrize('input,expected_exception', [ @pytest.mark.parametrize('params,expected_exception', [
({'text': None}, comments.EmptyCommentTextError), ({'text': None}, comments.EmptyCommentTextError),
({'text': ''}, comments.EmptyCommentTextError), ({'text': ''}, comments.EmptyCommentTextError),
({'text': []}, comments.EmptyCommentTextError), ({'text': []}, comments.EmptyCommentTextError),
({'text': [None]}, errors.ValidationError), ({'text': [None]}, errors.ValidationError),
({'text': ['']}, comments.EmptyCommentTextError), ({'text': ['']}, comments.EmptyCommentTextError),
]) ])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception): def test_trying_to_pass_invalid_params(
user = test_ctx.user_factory() user_factory, comment_factory, context_factory, params, expected_exception):
comment = test_ctx.comment_factory(user=user) user = user_factory()
comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
with pytest.raises(expected_exception): with pytest.raises(expected_exception):
test_ctx.api.put( api.comment_api.update_comment(
test_ctx.context_factory( context_factory(
input={**input, **{'version': 1}}, user=user), params={**params, **{'version': 1}}, user=user),
comment.comment_id) {'comment_id': comment.comment_id})
def test_trying_to_omit_mandatory_field(test_ctx): def test_trying_to_omit_mandatory_field(
user = test_ctx.user_factory() user_factory, comment_factory, context_factory):
comment = test_ctx.comment_factory(user=user) user = user_factory()
comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
test_ctx.api.put( api.comment_api.update_comment(
test_ctx.context_factory(input={'version': 1}, user=user), context_factory(params={'version': 1}, user=user),
comment.comment_id) {'comment_id': comment.comment_id})
def test_trying_to_update_non_existing(test_ctx): def test_trying_to_update_non_existing(user_factory, context_factory):
with pytest.raises(comments.CommentNotFoundError): with pytest.raises(comments.CommentNotFoundError):
test_ctx.api.put( api.comment_api.update_comment(
test_ctx.context_factory( context_factory(
input={'text': 'new text'}, params={'text': 'new text'},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
5) {'comment_id': 5})
def test_trying_to_update_someones_comment_without_privileges(test_ctx): def test_trying_to_update_someones_comment_without_privileges(
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) user_factory, comment_factory, context_factory):
user2 = test_ctx.user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=db.User.RANK_REGULAR)
comment = test_ctx.comment_factory(user=user) user2 = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.put( api.comment_api.update_comment(
test_ctx.context_factory( context_factory(
input={'text': 'new text', 'version': 1}, user=user2), params={'text': 'new text', 'version': 1}, user=user2),
comment.comment_id) {'comment_id': comment.comment_id})
def test_updating_someones_comment_with_privileges(test_ctx): def test_updating_someones_comment_with_privileges(
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) user_factory, comment_factory, context_factory):
user2 = test_ctx.user_factory(rank=db.User.RANK_MODERATOR) user = user_factory(rank=db.User.RANK_REGULAR)
comment = test_ctx.comment_factory(user=user) user2 = user_factory(rank=db.User.RANK_MODERATOR)
comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
try: with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
test_ctx.api.put( api.comment_api.update_comment(
test_ctx.context_factory( context_factory(
input={'text': 'new text', 'version': 1}, user=user2), params={'text': 'new text', 'version': 1}, user=user2),
comment.comment_id) {'comment_id': comment.comment_id})
except:
pytest.fail()

View file

@ -31,9 +31,8 @@ def test_info_api(
}, },
} }
info_api = api.InfoApi()
with fake_datetime('2016-01-01 13:00'): with fake_datetime('2016-01-01 13:00'):
assert info_api.get(context_factory()) == { assert api.info_api.get_info(context_factory()) == {
'postCount': 2, 'postCount': 2,
'diskUsage': 3, 'diskUsage': 3,
'featuredPost': None, 'featuredPost': None,
@ -44,7 +43,7 @@ def test_info_api(
} }
directory.join('test2.txt').write('abc') directory.join('test2.txt').write('abc')
with fake_datetime('2016-01-01 13:59'): with fake_datetime('2016-01-01 13:59'):
assert info_api.get(context_factory()) == { assert api.info_api.get_info(context_factory()) == {
'postCount': 2, 'postCount': 2,
'diskUsage': 3, # still 3 - it's cached 'diskUsage': 3, # still 3 - it's cached
'featuredPost': None, 'featuredPost': None,
@ -54,7 +53,7 @@ def test_info_api(
'config': expected_config_key, 'config': expected_config_key,
} }
with fake_datetime('2016-01-01 14:01'): with fake_datetime('2016-01-01 14:01'):
assert info_api.get(context_factory()) == { assert api.info_api.get_info(context_factory()) == {
'postCount': 2, 'postCount': 2,
'diskUsage': 6, # cache expired 'diskUsage': 6, # cache expired
'featuredPost': None, 'featuredPost': None,

View file

@ -1,25 +1,23 @@
from datetime import datetime
from unittest import mock
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import auth, mailer from szurubooru.func import auth, mailer
@pytest.fixture @pytest.fixture(autouse=True)
def password_reset_api(config_injector): def inject_config(tmpdir, config_injector):
config_injector({ config_injector({
'secret': 'x', 'secret': 'x',
'base_url': 'http://example.com/', 'base_url': 'http://example.com/',
'name': 'Test instance', 'name': 'Test instance',
}) })
return api.PasswordResetApi()
def test_reset_sending_email( def test_reset_sending_email(context_factory, user_factory):
password_reset_api, context_factory, user_factory):
db.session.add(user_factory( db.session.add(user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
for getter in ['u1', 'user@example.com']: for initiating_user in ['u1', 'user@example.com']:
mailer.send_mail = mock.MagicMock() with unittest.mock.patch('szurubooru.func.mailer.send_mail'):
assert password_reset_api.get(context_factory(), getter) == {} assert api.password_reset_api.start_password_reset(
context_factory(), {'user_name': initiating_user}) == {}
mailer.send_mail.assert_called_once_with( mailer.send_mail.assert_called_once_with(
'noreply@Test instance', 'noreply@Test instance',
'user@example.com', 'user@example.com',
@ -29,43 +27,44 @@ def test_reset_sending_email(
'ink: http://example.com/password-reset/u1:4ac0be176fb36' + 'ink: http://example.com/password-reset/u1:4ac0be176fb36' +
'4f13ee6b634c43220e2\nOtherwise, please ignore this email.') '4f13ee6b634c43220e2\nOtherwise, please ignore this email.')
def test_trying_to_reset_non_existing(password_reset_api, context_factory): def test_trying_to_reset_non_existing(context_factory):
with pytest.raises(errors.NotFoundError): with pytest.raises(errors.NotFoundError):
password_reset_api.get(context_factory(), 'u1') api.password_reset_api.start_password_reset(
context_factory(), {'user_name': 'u1'})
def test_trying_to_reset_without_email( def test_trying_to_reset_without_email(context_factory, user_factory):
password_reset_api, context_factory, user_factory):
db.session.add(user_factory(name='u1', rank=db.User.RANK_REGULAR, email=None)) db.session.add(user_factory(name='u1', rank=db.User.RANK_REGULAR, email=None))
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
password_reset_api.get(context_factory(), 'u1') api.password_reset_api.start_password_reset(
context_factory(), {'user_name': 'u1'})
def test_confirming_with_good_token( def test_confirming_with_good_token(context_factory, user_factory):
password_reset_api, context_factory, user_factory):
user = user_factory( user = user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com') name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')
old_hash = user.password_hash old_hash = user.password_hash
db.session.add(user) db.session.add(user)
context = context_factory( context = context_factory(
input={'token': '4ac0be176fb364f13ee6b634c43220e2'}) params={'token': '4ac0be176fb364f13ee6b634c43220e2'})
result = password_reset_api.post(context, 'u1') result = api.password_reset_api.finish_password_reset(
context, {'user_name': 'u1'})
assert user.password_hash != old_hash assert user.password_hash != old_hash
assert auth.is_valid_password(user, result['password']) is True assert auth.is_valid_password(user, result['password']) is True
def test_trying_to_confirm_non_existing(password_reset_api, context_factory): def test_trying_to_confirm_non_existing(context_factory):
with pytest.raises(errors.NotFoundError): with pytest.raises(errors.NotFoundError):
password_reset_api.post(context_factory(), 'u1') api.password_reset_api.finish_password_reset(
context_factory(), {'user_name': 'u1'})
def test_trying_to_confirm_without_token( def test_trying_to_confirm_without_token(context_factory, user_factory):
password_reset_api, context_factory, user_factory):
db.session.add(user_factory( db.session.add(user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
password_reset_api.post(context_factory(input={}), 'u1') api.password_reset_api.finish_password_reset(
context_factory(params={}), {'user_name': 'u1'})
def test_trying_to_confirm_with_bad_token( def test_trying_to_confirm_with_bad_token(context_factory, user_factory):
password_reset_api, context_factory, user_factory):
db.session.add(user_factory( db.session.add(user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
password_reset_api.post( api.password_reset_api.finish_password_reset(
context_factory(input={'token': 'bad'}), 'u1') context_factory(params={'token': 'bad'}), {'user_name': 'u1'})

View file

@ -1,7 +1,5 @@
import datetime
import os
import unittest.mock
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import posts, tags, snapshots, net from szurubooru.func import posts, tags, snapshots, net
@ -35,9 +33,9 @@ def test_creating_minimal_posts(
posts.create_post.return_value = (post, []) posts.create_post.return_value = (post, [])
posts.serialize_post.return_value = 'serialized post' posts.serialize_post.return_value = 'serialized post'
result = api.PostListApi().post( result = api.post_api.create_post(
context_factory( context_factory(
input={ params={
'safety': 'safe', 'safety': 'safe',
'tags': ['tag1', 'tag2'], 'tags': ['tag1', 'tag2'],
}, },
@ -79,9 +77,9 @@ def test_creating_full_posts(context_factory, post_factory, user_factory):
posts.create_post.return_value = (post, []) posts.create_post.return_value = (post, [])
posts.serialize_post.return_value = 'serialized post' posts.serialize_post.return_value = 'serialized post'
result = api.PostListApi().post( result = api.post_api.create_post(
context_factory( context_factory(
input={ params={
'safety': 'safe', 'safety': 'safe',
'tags': ['tag1', 'tag2'], 'tags': ['tag1', 'tag2'],
'relations': [1, 2], 'relations': [1, 2],
@ -122,9 +120,9 @@ def test_anonymous_uploads(
'privileges': {'posts:create:anonymous': db.User.RANK_REGULAR}, 'privileges': {'posts:create:anonymous': db.User.RANK_REGULAR},
}) })
posts.create_post.return_value = [post, []] posts.create_post.return_value = [post, []]
api.PostListApi().post( api.post_api.create_post(
context_factory( context_factory(
input={ params={
'safety': 'safe', 'safety': 'safe',
'tags': ['tag1', 'tag2'], 'tags': ['tag1', 'tag2'],
'anonymous': 'True', 'anonymous': 'True',
@ -154,9 +152,9 @@ def test_creating_from_url_saves_source(
}) })
net.download.return_value = b'content' net.download.return_value = b'content'
posts.create_post.return_value = [post, []] posts.create_post.return_value = [post, []]
api.PostListApi().post( api.post_api.create_post(
context_factory( context_factory(
input={ params={
'safety': 'safe', 'safety': 'safe',
'tags': ['tag1', 'tag2'], 'tags': ['tag1', 'tag2'],
'contentUrl': 'example.com', 'contentUrl': 'example.com',
@ -185,9 +183,9 @@ def test_creating_from_url_with_source_specified(
}) })
net.download.return_value = b'content' net.download.return_value = b'content'
posts.create_post.return_value = [post, []] posts.create_post.return_value = [post, []]
api.PostListApi().post( api.post_api.create_post(
context_factory( context_factory(
input={ params={
'safety': 'safe', 'safety': 'safe',
'tags': ['tag1', 'tag2'], 'tags': ['tag1', 'tag2'],
'contentUrl': 'example.com', 'contentUrl': 'example.com',
@ -201,23 +199,23 @@ def test_creating_from_url_with_source_specified(
@pytest.mark.parametrize('field', ['tags', 'safety']) @pytest.mark.parametrize('field', ['tags', 'safety'])
def test_trying_to_omit_mandatory_field(context_factory, user_factory, field): def test_trying_to_omit_mandatory_field(context_factory, user_factory, field):
input = { params = {
'safety': 'safe', 'safety': 'safe',
'tags': ['tag1', 'tag2'], 'tags': ['tag1', 'tag2'],
} }
del input[field] del params[field]
with pytest.raises(errors.MissingRequiredParameterError): with pytest.raises(errors.MissingRequiredParameterError):
api.PostListApi().post( api.post_api.create_post(
context_factory( context_factory(
input=input, params=params,
files={'content': '...'}, files={'content': '...'},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_omit_content(context_factory, user_factory): def test_trying_to_omit_content(context_factory, user_factory):
with pytest.raises(errors.MissingRequiredFileError): with pytest.raises(errors.MissingRequiredFileError):
api.PostListApi().post( api.post_api.create_post(
context_factory( context_factory(
input={ params={
'safety': 'safe', 'safety': 'safe',
'tags': ['tag1', 'tag2'], 'tags': ['tag1', 'tag2'],
}, },
@ -225,9 +223,8 @@ def test_trying_to_omit_content(context_factory, user_factory):
def test_trying_to_create_post_without_privileges(context_factory, user_factory): def test_trying_to_create_post_without_privileges(context_factory, user_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.PostListApi().post( api.post_api.create_post(context_factory(
context_factory( params='whatever',
input='whatever',
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_trying_to_create_tags_without_privileges( def test_trying_to_create_tags_without_privileges(
@ -243,9 +240,9 @@ def test_trying_to_create_tags_without_privileges(
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ unittest.mock.patch('szurubooru.func.posts.update_post_content'), \
unittest.mock.patch('szurubooru.func.posts.update_post_tags'): unittest.mock.patch('szurubooru.func.posts.update_post_tags'):
posts.update_post_tags.return_value = ['new-tag'] posts.update_post_tags.return_value = ['new-tag']
api.PostListApi().post( api.post_api.create_post(
context_factory( context_factory(
input={ params={
'safety': 'safe', 'safety': 'safe',
'tags': ['tag1', 'tag2'], 'tags': ['tag1', 'tag2'],
}, },

View file

@ -1,49 +1,37 @@
import pytest import pytest
import os import unittest.mock
from datetime import datetime from szurubooru import api, db, errors
from szurubooru import api, config, db, errors from szurubooru.func import posts, tags
from szurubooru.func import util, posts
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx( def inject_config(config_injector):
tmpdir, config_injector, context_factory, post_factory, user_factory): config_injector({'privileges': {'posts:delete': db.User.RANK_REGULAR}})
config_injector({
'data_dir': str(tmpdir),
'privileges': {
'posts:delete': db.User.RANK_REGULAR,
},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.post_factory = post_factory
ret.api = api.PostDetailApi()
return ret
def test_deleting(test_ctx): def test_deleting(user_factory, post_factory, context_factory):
db.session.add(test_ctx.post_factory(id=1)) db.session.add(post_factory(id=1))
db.session.commit() db.session.commit()
result = test_ctx.api.delete( with unittest.mock.patch('szurubooru.func.tags.export_to_json'):
test_ctx.context_factory( result = api.post_api.delete_post(
input={'version': 1}, context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), params={'version': 1},
1) user=user_factory(rank=db.User.RANK_REGULAR)),
{'post_id': 1})
assert result == {} assert result == {}
assert db.session.query(db.Post).count() == 0 assert db.session.query(db.Post).count() == 0
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json')) tags.export_to_json.assert_called_once_with()
def test_trying_to_delete_non_existing(test_ctx): def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
test_ctx.api.delete( api.post_api.delete_post(
test_ctx.context_factory( context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), '999') {'post_id': 999})
def test_trying_to_delete_without_privileges(test_ctx): def test_trying_to_delete_without_privileges(
db.session.add(test_ctx.post_factory(id=1)) user_factory, post_factory, context_factory):
db.session.add(post_factory(id=1))
db.session.commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.delete( api.post_api.delete_post(
test_ctx.context_factory( context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), {'post_id': 1})
1)
assert db.session.query(db.Post).count() == 1 assert db.session.query(db.Post).count() == 1

View file

@ -1,132 +1,129 @@
import datetime
import pytest import pytest
import unittest.mock
from datetime import datetime
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import util, posts from szurubooru.func import posts
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx( def inject_config(config_injector):
tmpdir, config_injector, context_factory, user_factory, post_factory): config_injector({'privileges': {'posts:favorite': db.User.RANK_REGULAR}})
config_injector({
'data_dir': str(tmpdir),
'data_url': 'http://example.com',
'privileges': {
'posts:favorite': db.User.RANK_REGULAR,
'users:edit:any:email': db.User.RANK_MODERATOR,
},
'thumbnails': {'avatar_width': 200},
})
db.session.flush()
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.post_factory = post_factory
ret.api = api.PostFavoriteApi()
return ret
def test_adding_to_favorites(test_ctx, fake_datetime): def test_adding_to_favorites(
post = test_ctx.post_factory() user_factory, post_factory, context_factory, fake_datetime):
post = post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
assert post.score == 0 assert post.score == 0
with fake_datetime('1997-12-01'): with unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
result = test_ctx.api.post( fake_datetime('1997-12-01'):
test_ctx.context_factory(user=test_ctx.user_factory()), posts.serialize_post.return_value = 'serialized post'
post.post_id) result = api.post_api.add_post_to_favorites(
assert 'id' in result context_factory(user=user_factory()),
{'post_id': post.post_id})
assert result == 'serialized post'
post = db.session.query(db.Post).one() post = db.session.query(db.Post).one()
assert db.session.query(db.PostFavorite).count() == 1 assert db.session.query(db.PostFavorite).count() == 1
assert post is not None assert post is not None
assert post.favorite_count == 1 assert post.favorite_count == 1
assert post.score == 1 assert post.score == 1
def test_removing_from_favorites(test_ctx, fake_datetime): def test_removing_from_favorites(
user = test_ctx.user_factory() user_factory, post_factory, context_factory, fake_datetime):
post = test_ctx.post_factory() user = user_factory()
post = post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
assert post.score == 0 assert post.score == 0
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = test_ctx.api.post( api.post_api.add_post_to_favorites(
test_ctx.context_factory(user=user), context_factory(user=user),
post.post_id) {'post_id': post.post_id})
assert post.score == 1 assert post.score == 1
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = test_ctx.api.delete( api.post_api.delete_post_from_favorites(
test_ctx.context_factory(user=user), context_factory(user=user),
post.post_id) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(db.Post).one()
assert post.score == 1 assert post.score == 1
assert db.session.query(db.PostFavorite).count() == 0 assert db.session.query(db.PostFavorite).count() == 0
assert post.favorite_count == 0 assert post.favorite_count == 0
def test_favoriting_twice(test_ctx, fake_datetime): def test_favoriting_twice(
user = test_ctx.user_factory() user_factory, post_factory, context_factory, fake_datetime):
post = test_ctx.post_factory() user = user_factory()
post = post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = test_ctx.api.post( api.post_api.add_post_to_favorites(
test_ctx.context_factory(user=user), context_factory(user=user),
post.post_id) {'post_id': post.post_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = test_ctx.api.post( api.post_api.add_post_to_favorites(
test_ctx.context_factory(user=user), context_factory(user=user),
post.post_id) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(db.Post).one()
assert db.session.query(db.PostFavorite).count() == 1 assert db.session.query(db.PostFavorite).count() == 1
assert post.favorite_count == 1 assert post.favorite_count == 1
def test_removing_twice(test_ctx, fake_datetime): def test_removing_twice(
user = test_ctx.user_factory() user_factory, post_factory, context_factory, fake_datetime):
post = test_ctx.post_factory() user = user_factory()
post = post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = test_ctx.api.post( api.post_api.add_post_to_favorites(
test_ctx.context_factory(user=user), context_factory(user=user),
post.post_id) {'post_id': post.post_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = test_ctx.api.delete( api.post_api.delete_post_from_favorites(
test_ctx.context_factory(user=user), context_factory(user=user),
post.post_id) {'post_id': post.post_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = test_ctx.api.delete( api.post_api.delete_post_from_favorites(
test_ctx.context_factory(user=user), context_factory(user=user),
post.post_id) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(db.Post).one()
assert db.session.query(db.PostFavorite).count() == 0 assert db.session.query(db.PostFavorite).count() == 0
assert post.favorite_count == 0 assert post.favorite_count == 0
def test_favorites_from_multiple_users(test_ctx, fake_datetime): def test_favorites_from_multiple_users(
user1 = test_ctx.user_factory() user_factory, post_factory, context_factory, fake_datetime):
user2 = test_ctx.user_factory() user1 = user_factory()
post = test_ctx.post_factory() user2 = user_factory()
post = post_factory()
db.session.add_all([user1, user2, post]) db.session.add_all([user1, user2, post])
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = test_ctx.api.post( api.post_api.add_post_to_favorites(
test_ctx.context_factory(user=user1), context_factory(user=user1),
post.post_id) {'post_id': post.post_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = test_ctx.api.post( api.post_api.add_post_to_favorites(
test_ctx.context_factory(user=user2), context_factory(user=user2),
post.post_id) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(db.Post).one()
assert db.session.query(db.PostFavorite).count() == 2 assert db.session.query(db.PostFavorite).count() == 2
assert post.favorite_count == 2 assert post.favorite_count == 2
assert post.last_favorite_time == datetime.datetime(1997, 12, 2) assert post.last_favorite_time == datetime(1997, 12, 2)
def test_trying_to_update_non_existing(test_ctx): def test_trying_to_update_non_existing(user_factory, context_factory):
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
test_ctx.api.post( api.post_api.add_post_to_favorites(
test_ctx.context_factory(user=test_ctx.user_factory()), 5) context_factory(user=user_factory()),
{'post_id': 5})
def test_trying_to_rate_without_privileges(test_ctx): def test_trying_to_rate_without_privileges(
post = test_ctx.post_factory() user_factory, post_factory, context_factory):
post = post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.post( api.post_api.add_post_to_favorites(
test_ctx.context_factory( context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), {'post_id': post.post_id})
post.post_id)

View file

@ -1,107 +1,100 @@
import datetime
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import util, posts from szurubooru.func import posts
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx( def inject_config(config_injector):
tmpdir, context_factory, config_injector, user_factory, post_factory):
config_injector({ config_injector({
'data_dir': str(tmpdir),
'data_url': 'http://example.com',
'privileges': { 'privileges': {
'posts:feature': db.User.RANK_REGULAR, 'posts:feature': db.User.RANK_REGULAR,
'posts:view': db.User.RANK_REGULAR, 'posts:view': db.User.RANK_REGULAR,
}, },
}) })
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.post_factory = post_factory
ret.api = api.PostFeatureApi()
return ret
def test_no_featured_post(test_ctx): def test_no_featured_post(user_factory, post_factory, context_factory):
assert posts.try_get_featured_post() is None assert posts.try_get_featured_post() is None
result = test_ctx.api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert result is None
def test_featuring(test_ctx): def test_featuring(user_factory, post_factory, context_factory):
db.session.add(test_ctx.post_factory(id=1)) db.session.add(post_factory(id=1))
db.session.commit() db.session.commit()
assert not posts.get_post_by_id(1).is_featured assert not posts.get_post_by_id(1).is_featured
result = test_ctx.api.post( with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
test_ctx.context_factory( posts.serialize_post.return_value = 'serialized post'
input={'id': 1}, result = api.post_api.set_featured_post(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) context_factory(
params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
assert result == 'serialized post'
assert posts.try_get_featured_post() is not None assert posts.try_get_featured_post() is not None
assert posts.try_get_featured_post().post_id == 1 assert posts.try_get_featured_post().post_id == 1
assert posts.get_post_by_id(1).is_featured assert posts.get_post_by_id(1).is_featured
assert 'id' in result result = api.post_api.get_featured_post(
assert 'snapshots' in result context_factory(
assert 'comments' in result user=user_factory(rank=db.User.RANK_REGULAR)))
result = test_ctx.api.get( assert result == 'serialized post'
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert 'id' in result
assert 'snapshots' in result
assert 'comments' in result
def test_trying_to_feature_the_same_post_twice(test_ctx): def test_trying_to_omit_required_parameter(
db.session.add(test_ctx.post_factory(id=1)) user_factory, post_factory, context_factory):
with pytest.raises(errors.MissingRequiredParameterError):
api.post_api.set_featured_post(
context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_feature_the_same_post_twice(
user_factory, post_factory, context_factory):
db.session.add(post_factory(id=1))
db.session.commit() db.session.commit()
test_ctx.api.post( with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
test_ctx.context_factory( api.post_api.set_featured_post(
input={'id': 1}, context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
with pytest.raises(posts.PostAlreadyFeaturedError): with pytest.raises(posts.PostAlreadyFeaturedError):
test_ctx.api.post( api.post_api.set_featured_post(
test_ctx.context_factory( context_factory(
input={'id': 1}, params={'id': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
def test_featuring_one_post_after_another(test_ctx, fake_datetime): def test_featuring_one_post_after_another(
db.session.add(test_ctx.post_factory(id=1)) user_factory, post_factory, context_factory, fake_datetime):
db.session.add(test_ctx.post_factory(id=2)) db.session.add(post_factory(id=1))
db.session.add(post_factory(id=2))
db.session.commit() db.session.commit()
assert posts.try_get_featured_post() is None assert posts.try_get_featured_post() is None
assert not posts.get_post_by_id(1).is_featured assert not posts.get_post_by_id(1).is_featured
assert not posts.get_post_by_id(2).is_featured assert not posts.get_post_by_id(2).is_featured
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997'): with fake_datetime('1997'):
result = test_ctx.api.post( result = api.post_api.set_featured_post(
test_ctx.context_factory( context_factory(
input={'id': 1}, params={'id': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
with fake_datetime('1998'): with fake_datetime('1998'):
result = test_ctx.api.post( result = api.post_api.set_featured_post(
test_ctx.context_factory( context_factory(
input={'id': 2}, params={'id': 2},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
assert posts.try_get_featured_post() is not None assert posts.try_get_featured_post() is not None
assert posts.try_get_featured_post().post_id == 2 assert posts.try_get_featured_post().post_id == 2
assert not posts.get_post_by_id(1).is_featured assert not posts.get_post_by_id(1).is_featured
assert posts.get_post_by_id(2).is_featured assert posts.get_post_by_id(2).is_featured
def test_trying_to_feature_non_existing(test_ctx): def test_trying_to_feature_non_existing(user_factory, context_factory):
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
test_ctx.api.post( api.post_api.set_featured_post(
test_ctx.context_factory( context_factory(
input={'id': 1}, params={'id': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_feature_without_privileges(test_ctx): def test_trying_to_feature_without_privileges(user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.post( api.post_api.set_featured_post(
test_ctx.context_factory( context_factory(
input={'id': 1}, params={'id': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_getting_featured_post_without_privileges_to_view(test_ctx): def test_getting_featured_post_without_privileges_to_view(
try: user_factory, context_factory):
test_ctx.api.get( api.post_api.get_featured_post(
test_ctx.context_factory( context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)))
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
except:
pytest.fail()

View file

@ -1,147 +1,132 @@
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import util, posts, scores from szurubooru.func import posts, scores
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx( def inject_config(config_injector):
tmpdir, config_injector, context_factory, user_factory, post_factory): config_injector({'privileges': {'posts:score': db.User.RANK_REGULAR}})
config_injector({
'data_dir': str(tmpdir),
'data_url': 'http://example.com',
'privileges': {'posts:score': db.User.RANK_REGULAR},
'thumbnails': {'avatar_width': 200},
})
db.session.flush()
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.post_factory = post_factory
ret.api = api.PostScoreApi()
return ret
def test_simple_rating(test_ctx, fake_datetime): def test_simple_rating(
post = test_ctx.post_factory() user_factory, post_factory, context_factory, fake_datetime):
post = post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
posts.serialize_post.return_value = 'serialized post'
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = test_ctx.api.put( result = api.post_api.set_post_score(
test_ctx.context_factory( context_factory(
input={'score': 1}, user=test_ctx.user_factory()), params={'score': 1}, user=user_factory()),
post.post_id) {'post_id': post.post_id})
assert 'id' in result assert result == 'serialized post'
post = db.session.query(db.Post).one() post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 1 assert db.session.query(db.PostScore).count() == 1
assert post is not None assert post is not None
assert post.score == 1 assert post.score == 1
def test_updating_rating(test_ctx, fake_datetime): def test_updating_rating(
user = test_ctx.user_factory() user_factory, post_factory, context_factory, fake_datetime):
post = test_ctx.post_factory() user = user_factory()
post = post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = test_ctx.api.put( result = api.post_api.set_post_score(
test_ctx.context_factory(input={'score': 1}, user=user), context_factory(params={'score': 1}, user=user),
post.post_id) {'post_id': post.post_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = test_ctx.api.put( result = api.post_api.set_post_score(
test_ctx.context_factory(input={'score': -1}, user=user), context_factory(params={'score': -1}, user=user),
post.post_id) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 1 assert db.session.query(db.PostScore).count() == 1
assert post.score == -1 assert post.score == -1
def test_updating_rating_to_zero(test_ctx, fake_datetime): def test_updating_rating_to_zero(
user = test_ctx.user_factory() user_factory, post_factory, context_factory, fake_datetime):
post = test_ctx.post_factory() user = user_factory()
post = post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = test_ctx.api.put( result = api.post_api.set_post_score(
test_ctx.context_factory(input={'score': 1}, user=user), context_factory(params={'score': 1}, user=user),
post.post_id) {'post_id': post.post_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = test_ctx.api.put( result = api.post_api.set_post_score(
test_ctx.context_factory(input={'score': 0}, user=user), context_factory(params={'score': 0}, user=user),
post.post_id) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 0 assert db.session.query(db.PostScore).count() == 0
assert post.score == 0 assert post.score == 0
def test_deleting_rating(test_ctx, fake_datetime): def test_deleting_rating(
user = test_ctx.user_factory() user_factory, post_factory, context_factory, fake_datetime):
post = test_ctx.post_factory() user = user_factory()
post = post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = test_ctx.api.put( result = api.post_api.set_post_score(
test_ctx.context_factory(input={'score': 1}, user=user), context_factory(params={'score': 1}, user=user),
post.post_id) {'post_id': post.post_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = test_ctx.api.delete( result = api.post_api.delete_post_score(
test_ctx.context_factory(user=user), post.post_id) context_factory(user=user),
{'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 0 assert db.session.query(db.PostScore).count() == 0
assert post.score == 0 assert post.score == 0
def test_ratings_from_multiple_users(test_ctx, fake_datetime): def test_ratings_from_multiple_users(
user1 = test_ctx.user_factory() user_factory, post_factory, context_factory, fake_datetime):
user2 = test_ctx.user_factory() user1 = user_factory()
post = test_ctx.post_factory() user2 = user_factory()
post = post_factory()
db.session.add_all([user1, user2, post]) db.session.add_all([user1, user2, post])
db.session.commit() db.session.commit()
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = test_ctx.api.put( result = api.post_api.set_post_score(
test_ctx.context_factory(input={'score': 1}, user=user1), context_factory(params={'score': 1}, user=user1),
post.post_id) {'post_id': post.post_id})
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = test_ctx.api.put( result = api.post_api.set_post_score(
test_ctx.context_factory(input={'score': -1}, user=user2), context_factory(params={'score': -1}, user=user2),
post.post_id) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 2 assert db.session.query(db.PostScore).count() == 2
assert post.score == 0 assert post.score == 0
@pytest.mark.parametrize('input,expected_exception', [ def test_trying_to_omit_mandatory_field(
({'score': None}, errors.ValidationError), user_factory, post_factory, context_factory):
({'score': ''}, errors.ValidationError), post = post_factory()
({'score': -2}, scores.InvalidScoreValueError),
({'score': 2}, scores.InvalidScoreValueError),
({'score': [1]}, errors.ValidationError),
])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
post = test_ctx.post_factory()
db.session.add(post)
db.session.commit()
with pytest.raises(expected_exception):
test_ctx.api.put(
test_ctx.context_factory(input=input, user=test_ctx.user_factory()),
post.post_id)
def test_trying_to_omit_mandatory_field(test_ctx):
post = test_ctx.post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
test_ctx.api.put( api.post_api.set_post_score(
test_ctx.context_factory(input={}, user=test_ctx.user_factory()), context_factory(params={}, user=user_factory()),
post.post_id) {'post_id': post.post_id})
def test_trying_to_update_non_existing(test_ctx): def test_trying_to_update_non_existing(
user_factory, post_factory, context_factory):
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
test_ctx.api.put( api.post_api.set_post_score(
test_ctx.context_factory( context_factory(params={'score': 1}, user=user_factory()),
input={'score': 1}, {'post_id': 5})
user=test_ctx.user_factory()),
5)
def test_trying_to_rate_without_privileges(test_ctx): def test_trying_to_rate_without_privileges(
post = test_ctx.post_factory() user_factory, post_factory, context_factory):
post = post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.put( api.post_api.set_post_score(
test_ctx.context_factory( context_factory(
input={'score': 1}, params={'score': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=db.User.RANK_ANONYMOUS)),
post.post_id) {'post_id': post.post_id})

View file

@ -1,105 +1,97 @@
import datetime
import pytest import pytest
import unittest.mock
from datetime import datetime
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import util, posts from szurubooru.func import posts
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx( def inject_config(tmpdir, config_injector):
tmpdir, context_factory, config_injector, user_factory, post_factory):
config_injector({ config_injector({
'data_dir': str(tmpdir),
'data_url': 'http://example.com',
'privileges': { 'privileges': {
'posts:list': db.User.RANK_REGULAR, 'posts:list': db.User.RANK_REGULAR,
'posts:view': db.User.RANK_REGULAR, 'posts:view': db.User.RANK_REGULAR,
}, },
'thumbnails': {'avatar_width': 200},
}) })
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.post_factory = post_factory
ret.list_api = api.PostListApi()
ret.detail_api = api.PostDetailApi()
return ret
def test_retrieving_multiple(test_ctx): def test_retrieving_multiple(user_factory, post_factory, context_factory):
post1 = test_ctx.post_factory(id=1) post1 = post_factory(id=1)
post2 = test_ctx.post_factory(id=2) post2 = post_factory(id=2)
db.session.add_all([post1, post2]) db.session.add_all([post1, post2])
result = test_ctx.list_api.get( with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
test_ctx.context_factory( posts.serialize_post.return_value = 'serialized post'
input={'query': '', 'page': 1}, result = api.post_api.get_posts(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) context_factory(
assert result['query'] == '' params={'query': '', 'page': 1},
assert result['page'] == 1 user=user_factory(rank=db.User.RANK_REGULAR)))
assert result['pageSize'] == 100 assert result == {
assert result['total'] == 2 'query': '',
assert [t['id'] for t in result['results']] == [2, 1] 'page': 1,
'pageSize': 100,
'total': 2,
'results': ['serialized post', 'serialized post'],
}
def test_using_special_tokens( def test_using_special_tokens(user_factory, post_factory, context_factory):
test_ctx, config_injector): auth_user = user_factory(rank=db.User.RANK_REGULAR)
auth_user = test_ctx.user_factory(rank=db.User.RANK_REGULAR) post1 = post_factory(id=1)
post1 = test_ctx.post_factory(id=1) post2 = post_factory(id=2)
post2 = test_ctx.post_factory(id=2)
post1.favorited_by = [db.PostFavorite( post1.favorited_by = [db.PostFavorite(
user=auth_user, time=datetime.datetime.utcnow())] user=auth_user, time=datetime.utcnow())]
db.session.add_all([post1, post2, auth_user]) db.session.add_all([post1, post2, auth_user])
db.session.flush() db.session.flush()
result = test_ctx.list_api.get( with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
test_ctx.context_factory( posts.serialize_post.side_effect = \
input={'query': 'special:fav', 'page': 1}, lambda post, *_args, **_kwargs: \
'serialized post %d' % post.post_id
result = api.post_api.get_posts(
context_factory(
params={'query': 'special:fav', 'page': 1},
user=auth_user)) user=auth_user))
assert result['query'] == 'special:fav' assert result == {
assert result['page'] == 1 'query': 'special:fav',
assert result['pageSize'] == 100 'page': 1,
assert result['total'] == 1 'pageSize': 100,
assert [t['id'] for t in result['results']] == [1] 'total': 1,
'results': ['serialized post 1'],
}
def test_trying_to_use_special_tokens_without_logging_in( def test_trying_to_use_special_tokens_without_logging_in(
test_ctx, config_injector): user_factory, post_factory, context_factory, config_injector):
config_injector({ config_injector({
'privileges': {'posts:list': 'anonymous'}, 'privileges': {'posts:list': 'anonymous'},
}) })
with pytest.raises(errors.SearchError): with pytest.raises(errors.SearchError):
test_ctx.list_api.get( api.post_api.get_posts(
test_ctx.context_factory( context_factory(
input={'query': 'special:fav', 'page': 1}, params={'query': 'special:fav', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_trying_to_retrieve_multiple_without_privileges(test_ctx): def test_trying_to_retrieve_multiple_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.list_api.get( api.post_api.get_posts(
test_ctx.context_factory( context_factory(
input={'query': '', 'page': 1}, params={'query': '', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_retrieving_single(test_ctx): def test_retrieving_single(user_factory, post_factory, context_factory):
db.session.add(test_ctx.post_factory(id=1)) db.session.add(post_factory(id=1))
result = test_ctx.detail_api.get( with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
test_ctx.context_factory( posts.serialize_post.return_value = 'serialized post'
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 1) result = api.post_api.get_post(
assert 'id' in result context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
assert 'snapshots' in result {'post_id': 1})
assert 'comments' in result assert result == 'serialized post'
def test_trying_to_retrieve_invalid_id(test_ctx): def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(posts.InvalidPostIdError):
test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'-')
def test_trying_to_retrieve_single_non_existing(test_ctx):
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
test_ctx.detail_api.get( api.post_api.get_post(
test_ctx.context_factory( context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), {'post_id': 999})
'999')
def test_trying_to_retrieve_single_without_privileges(test_ctx): def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.detail_api.get( api.post_api.get_post(
test_ctx.context_factory( context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), {'post_id': 999})
'999')

View file

@ -1,12 +1,11 @@
import datetime
import os
import unittest.mock
import pytest import pytest
import unittest.mock
from datetime import datetime
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import posts, tags, snapshots, net from szurubooru.func import posts, tags, snapshots, net
def test_post_updating( @pytest.fixture(autouse=True)
config_injector, context_factory, post_factory, user_factory, fake_datetime): def inject_config(tmpdir, config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'posts:edit:tags': db.User.RANK_REGULAR, 'posts:edit:tags': db.User.RANK_REGULAR,
@ -17,8 +16,12 @@ def test_post_updating(
'posts:edit:notes': db.User.RANK_REGULAR, 'posts:edit:notes': db.User.RANK_REGULAR,
'posts:edit:flags': db.User.RANK_REGULAR, 'posts:edit:flags': db.User.RANK_REGULAR,
'posts:edit:thumbnail': db.User.RANK_REGULAR, 'posts:edit:thumbnail': db.User.RANK_REGULAR,
'tags:create': db.User.RANK_MODERATOR,
}, },
}) })
def test_post_updating(
context_factory, post_factory, user_factory, fake_datetime):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=db.User.RANK_REGULAR)
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
@ -35,14 +38,13 @@ def test_post_updating(
unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \ unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \ unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ unittest.mock.patch('szurubooru.func.tags.export_to_json'), \
unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'): unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \
fake_datetime('1997-01-01'):
posts.serialize_post.return_value = 'serialized post' posts.serialize_post.return_value = 'serialized post'
with fake_datetime('1997-01-01'): result = api.post_api.update_post(
result = api.PostDetailApi().put(
context_factory( context_factory(
input={ params={
'version': 1, 'version': 1,
'safety': 'safe', 'safety': 'safe',
'tags': ['tag1', 'tag2'], 'tags': ['tag1', 'tag2'],
@ -56,7 +58,7 @@ def test_post_updating(
'thumbnail': 'post-thumbnail', 'thumbnail': 'post-thumbnail',
}, },
user=auth_user), user=auth_user),
post.post_id) {'post_id': post.post_id})
assert result == 'serialized post' assert result == 'serialized post'
posts.create_post.assert_not_called() posts.create_post.assert_not_called()
@ -71,13 +73,10 @@ def test_post_updating(
posts.serialize_post.assert_called_once_with(post, auth_user, options=None) posts.serialize_post.assert_called_once_with(post, auth_user, options=None)
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
snapshots.save_entity_modification.assert_called_once_with(post, auth_user) snapshots.save_entity_modification.assert_called_once_with(post, auth_user)
assert post.last_edit_time == datetime.datetime(1997, 1, 1) assert post.last_edit_time == datetime(1997, 1, 1)
def test_uploading_from_url_saves_source( def test_uploading_from_url_saves_source(
config_injector, context_factory, post_factory, user_factory): context_factory, post_factory, user_factory):
config_injector({
'privileges': {'posts:edit:content': db.User.RANK_REGULAR},
})
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
@ -88,23 +87,17 @@ def test_uploading_from_url_saves_source(
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ unittest.mock.patch('szurubooru.func.posts.update_post_content'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'): unittest.mock.patch('szurubooru.func.posts.update_post_source'):
net.download.return_value = b'content' net.download.return_value = b'content'
api.PostDetailApi().put( api.post_api.update_post(
context_factory( context_factory(
input={'contentUrl': 'example.com', 'version': 1}, params={'contentUrl': 'example.com', 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
post.post_id) {'post_id': post.post_id})
net.download.assert_called_once_with('example.com') net.download.assert_called_once_with('example.com')
posts.update_post_content.assert_called_once_with(post, b'content') posts.update_post_content.assert_called_once_with(post, b'content')
posts.update_post_source.assert_called_once_with(post, 'example.com') posts.update_post_source.assert_called_once_with(post, 'example.com')
def test_uploading_from_url_with_source_specified( def test_uploading_from_url_with_source_specified(
config_injector, context_factory, post_factory, user_factory): context_factory, post_factory, user_factory):
config_injector({
'privileges': {
'posts:edit:content': db.User.RANK_REGULAR,
'posts:edit:source': db.User.RANK_REGULAR,
},
})
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
@ -115,27 +108,27 @@ def test_uploading_from_url_with_source_specified(
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ unittest.mock.patch('szurubooru.func.posts.update_post_content'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'): unittest.mock.patch('szurubooru.func.posts.update_post_source'):
net.download.return_value = b'content' net.download.return_value = b'content'
api.PostDetailApi().put( api.post_api.update_post(
context_factory( context_factory(
input={ params={
'contentUrl': 'example.com', 'contentUrl': 'example.com',
'source': 'example2.com', 'source': 'example2.com',
'version': 1}, 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
post.post_id) {'post_id': post.post_id})
net.download.assert_called_once_with('example.com') net.download.assert_called_once_with('example.com')
posts.update_post_content.assert_called_once_with(post, b'content') posts.update_post_content.assert_called_once_with(post, b'content')
posts.update_post_source.assert_called_once_with(post, 'example2.com') posts.update_post_source.assert_called_once_with(post, 'example2.com')
def test_trying_to_update_non_existing(context_factory, user_factory): def test_trying_to_update_non_existing(context_factory, user_factory):
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
api.PostDetailApi().put( api.post_api.update_post(
context_factory( context_factory(
input='whatever', params='whatever',
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
1) {'post_id': 1})
@pytest.mark.parametrize('privilege,files,input', [ @pytest.mark.parametrize('privilege,files,params', [
('posts:edit:tags', {}, {'tags': '...'}), ('posts:edit:tags', {}, {'tags': '...'}),
('posts:edit:safety', {}, {'safety': '...'}), ('posts:edit:safety', {}, {'safety': '...'}),
('posts:edit:source', {}, {'source': '...'}), ('posts:edit:source', {}, {'source': '...'}),
@ -146,43 +139,28 @@ def test_trying_to_update_non_existing(context_factory, user_factory):
('posts:edit:thumbnail', {'thumbnail': '...'}, {}), ('posts:edit:thumbnail', {'thumbnail': '...'}, {}),
]) ])
def test_trying_to_update_field_without_privileges( def test_trying_to_update_field_without_privileges(
config_injector, context_factory, post_factory, user_factory, files, params, privilege):
context_factory,
post_factory,
user_factory,
files,
input,
privilege):
config_injector({
'privileges': {privilege: db.User.RANK_REGULAR},
})
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.PostDetailApi().put( api.post_api.update_post(
context_factory( context_factory(
input={**input, **{'version': 1}}, params={**params, **{'version': 1}},
files=files, files=files,
user=user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=db.User.RANK_ANONYMOUS)),
post.post_id) {'post_id': post.post_id})
def test_trying_to_create_tags_without_privileges( def test_trying_to_create_tags_without_privileges(
config_injector, context_factory, post_factory, user_factory): context_factory, post_factory, user_factory):
config_injector({
'privileges': {
'posts:edit:tags': db.User.RANK_REGULAR,
'tags:create': db.User.RANK_ADMINISTRATOR,
},
})
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
with pytest.raises(errors.AuthError), \ with pytest.raises(errors.AuthError), \
unittest.mock.patch('szurubooru.func.posts.update_post_tags'): unittest.mock.patch('szurubooru.func.posts.update_post_tags'):
posts.update_post_tags.return_value = ['new-tag'] posts.update_post_tags.return_value = ['new-tag']
api.PostDetailApi().put( api.post_api.update_post(
context_factory( context_factory(
input={'tags': ['tag1', 'tag2'], 'version': 1}, params={'tags': ['tag1', 'tag2'], 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
post.post_id) {'post_id': post.post_id})

View file

@ -1,11 +1,10 @@
import datetime
import pytest import pytest
from datetime import datetime
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import util, tags
def snapshot_factory(): def snapshot_factory():
snapshot = db.Snapshot() snapshot = db.Snapshot()
snapshot.creation_time = datetime.datetime(1999, 1, 1) snapshot.creation_time = datetime(1999, 1, 1)
snapshot.resource_type = 'dummy' snapshot.resource_type = 'dummy'
snapshot.resource_id = 1 snapshot.resource_id = 1
snapshot.resource_repr = 'dummy' snapshot.resource_repr = 'dummy'
@ -13,37 +12,30 @@ def snapshot_factory():
snapshot.data = '{}' snapshot.data = '{}'
return snapshot return snapshot
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx(context_factory, config_injector, user_factory): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {'snapshots:list': db.User.RANK_REGULAR},
'snapshots:list': db.User.RANK_REGULAR,
},
'thumbnails': {'avatar_width': 200},
}) })
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.api = api.SnapshotListApi()
return ret
def test_retrieving_multiple(test_ctx): def test_retrieving_multiple(user_factory, context_factory):
snapshot1 = snapshot_factory() snapshot1 = snapshot_factory()
snapshot2 = snapshot_factory() snapshot2 = snapshot_factory()
db.session.add_all([snapshot1, snapshot2]) db.session.add_all([snapshot1, snapshot2])
result = test_ctx.api.get( result = api.snapshot_api.get_snapshots(
test_ctx.context_factory( context_factory(
input={'query': '', 'page': 1}, params={'query': '', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
assert result['query'] == '' assert result['query'] == ''
assert result['page'] == 1 assert result['page'] == 1
assert result['pageSize'] == 100 assert result['pageSize'] == 100
assert result['total'] == 2 assert result['total'] == 2
assert len(result['results']) == 2 assert len(result['results']) == 2
def test_trying_to_retrieve_multiple_without_privileges(test_ctx): def test_trying_to_retrieve_multiple_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.get( api.snapshot_api.get_snapshots(
test_ctx.context_factory( context_factory(
input={'query': '', 'page': 1}, params={'query': '', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=db.User.RANK_ANONYMOUS)))

View file

@ -1,94 +1,50 @@
import os
import pytest import pytest
from szurubooru import api, config, db, errors import unittest.mock
from szurubooru.func import util, tag_categories from szurubooru import api, db, errors
from szurubooru.func import tag_categories, tags
@pytest.fixture def _update_category_name(category, name):
def test_ctx(tmpdir, config_injector, context_factory, user_factory): category.name = name
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({ config_injector({
'data_dir': str(tmpdir),
'tag_category_name_regex': '^[^!]+$',
'privileges': {'tag_categories:create': db.User.RANK_REGULAR}, 'privileges': {'tag_categories:create': db.User.RANK_REGULAR},
}) })
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.api = api.TagCategoryListApi()
return ret
def test_creating_category(test_ctx): def test_creating_category(user_factory, context_factory):
result = test_ctx.api.post( with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \
test_ctx.context_factory( unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \
input={'name': 'meta', 'color': 'black'}, unittest.mock.patch('szurubooru.func.tags.export_to_json'):
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) tag_categories.update_category_name.side_effect = _update_category_name
assert len(result['snapshots']) == 1 tag_categories.serialize_category.return_value = 'serialized category'
del result['snapshots'] result = api.tag_category_api.create_tag_category(
assert result == { context_factory(
'name': 'meta', params={'name': 'meta', 'color': 'black'},
'color': 'black', user=user_factory(rank=db.User.RANK_REGULAR)))
'usages': 0, assert result == 'serialized category'
'default': True,
'version': 1,
}
category = db.session.query(db.TagCategory).one() category = db.session.query(db.TagCategory).one()
assert category.name == 'meta' assert category.name == 'meta'
assert category.color == 'black' assert category.color == 'black'
assert category.tag_count == 0 assert category.tag_count == 0
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json')) tags.export_to_json.assert_called_once_with()
@pytest.mark.parametrize('input', [
{'name': None},
{'name': ''},
{'name': '!bad'},
{'color': None},
{'color': ''},
{'color': 'a' * 100},
])
def test_trying_to_pass_invalid_input(test_ctx, input):
real_input = {
'name': 'okay',
'color': 'okay',
}
for key, value in input.items():
real_input[key] = value
with pytest.raises(errors.ValidationError):
test_ctx.api.post(
test_ctx.context_factory(
input=real_input,
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
@pytest.mark.parametrize('field', ['name', 'color']) @pytest.mark.parametrize('field', ['name', 'color'])
def test_trying_to_omit_mandatory_field(test_ctx, field): def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
input = { params = {
'name': 'meta', 'name': 'meta',
'color': 'black', 'color': 'black',
} }
del input[field] del params[field]
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
test_ctx.api.post( api.tag_category_api.create_tag_category(
test_ctx.context_factory( context_factory(
input=input, params=params,
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_use_existing_name(test_ctx): def test_trying_to_create_without_privileges(user_factory, context_factory):
result = test_ctx.api.post(
test_ctx.context_factory(
input={'name': 'meta', 'color': 'black'},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
with pytest.raises(tag_categories.TagCategoryAlreadyExistsError):
result = test_ctx.api.post(
test_ctx.context_factory(
input={'name': 'meta', 'color': 'black'},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
with pytest.raises(tag_categories.TagCategoryAlreadyExistsError):
result = test_ctx.api.post(
test_ctx.context_factory(
input={'name': 'META', 'color': 'black'},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_create_without_privileges(test_ctx):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.post( api.tag_category_api.create_tag_category(
test_ctx.context_factory( context_factory(
input={'name': 'meta', 'color': 'black'}, params={'name': 'meta', 'color': 'black'},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=db.User.RANK_ANONYMOUS)))

View file

@ -1,84 +1,70 @@
import pytest import pytest
import os import unittest.mock
from datetime import datetime from szurubooru import api, db, errors
from szurubooru import api, config, db, errors from szurubooru.func import tag_categories, tags
from szurubooru.func import util, tags, tag_categories
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx( def inject_config(config_injector):
tmpdir,
config_injector,
context_factory,
tag_factory,
tag_category_factory,
user_factory):
config_injector({ config_injector({
'data_dir': str(tmpdir), 'privileges': {'tag_categories:delete': db.User.RANK_REGULAR},
'privileges': {
'tag_categories:delete': db.User.RANK_REGULAR,
},
}) })
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_factory = tag_factory
ret.tag_category_factory = tag_category_factory
ret.api = api.TagCategoryDetailApi()
return ret
def test_deleting(test_ctx): def test_deleting(user_factory, tag_category_factory, context_factory):
db.session.add(test_ctx.tag_category_factory(name='root')) db.session.add(tag_category_factory(name='root'))
db.session.add(test_ctx.tag_category_factory(name='category')) db.session.add(tag_category_factory(name='category'))
db.session.commit() db.session.commit()
result = test_ctx.api.delete( with unittest.mock.patch('szurubooru.func.tags.export_to_json'):
test_ctx.context_factory( result = api.tag_category_api.delete_tag_category(
input={'version': 1}, context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), params={'version': 1},
'category') user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': 'category'})
assert result == {} assert result == {}
assert db.session.query(db.TagCategory).count() == 1 assert db.session.query(db.TagCategory).count() == 1
assert db.session.query(db.TagCategory).one().name == 'root' assert db.session.query(db.TagCategory).one().name == 'root'
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json')) tags.export_to_json.assert_called_once_with()
def test_trying_to_delete_used(test_ctx, tag_factory): def test_trying_to_delete_used(
category = test_ctx.tag_category_factory(name='category') user_factory, tag_category_factory, tag_factory, context_factory):
category = tag_category_factory(name='category')
db.session.add(category) db.session.add(category)
db.session.flush() db.session.flush()
tag = test_ctx.tag_factory(names=['tag'], category=category) tag = tag_factory(names=['tag'], category=category)
db.session.add(tag) db.session.add(tag)
db.session.commit() db.session.commit()
with pytest.raises(tag_categories.TagCategoryIsInUseError): with pytest.raises(tag_categories.TagCategoryIsInUseError):
test_ctx.api.delete( api.tag_category_api.delete_tag_category(
test_ctx.context_factory( context_factory(
input={'version': 1}, params={'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
'category') {'category_name': 'category'})
assert db.session.query(db.TagCategory).count() == 1 assert db.session.query(db.TagCategory).count() == 1
def test_trying_to_delete_last(test_ctx, tag_factory): def test_trying_to_delete_last(
db.session.add(test_ctx.tag_category_factory(name='root')) user_factory, tag_category_factory, context_factory):
db.session.add(tag_category_factory(name='root'))
db.session.commit() db.session.commit()
with pytest.raises(tag_categories.TagCategoryIsInUseError): with pytest.raises(tag_categories.TagCategoryIsInUseError):
result = test_ctx.api.delete( api.tag_category_api.delete_tag_category(
test_ctx.context_factory( context_factory(
input={'version': 1}, params={'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
'root') {'category_name': 'root'})
def test_trying_to_delete_non_existing(test_ctx): def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(tag_categories.TagCategoryNotFoundError): with pytest.raises(tag_categories.TagCategoryNotFoundError):
test_ctx.api.delete( api.tag_category_api.delete_tag_category(
test_ctx.context_factory( context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), {'category_name': 'bad'})
'bad')
def test_trying_to_delete_without_privileges(test_ctx): def test_trying_to_delete_without_privileges(
db.session.add(test_ctx.tag_category_factory(name='category')) user_factory, tag_category_factory, context_factory):
db.session.add(tag_category_factory(name='category'))
db.session.commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.delete( api.tag_category_api.delete_tag_category(
test_ctx.context_factory( context_factory(
input={'version': 1}, params={'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=db.User.RANK_ANONYMOUS)),
'category') {'category_name': 'category'})
assert db.session.query(db.TagCategory).count() == 1 assert db.session.query(db.TagCategory).count() == 1

View file

@ -1,42 +1,31 @@
import datetime
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import util, tag_categories from szurubooru.func import tag_categories
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx( def inject_config(config_injector):
context_factory, config_injector, user_factory, tag_category_factory):
config_injector({ config_injector({
'privileges': { 'privileges': {
'tag_categories:list': db.User.RANK_REGULAR, 'tag_categories:list': db.User.RANK_REGULAR,
'tag_categories:view': db.User.RANK_REGULAR, 'tag_categories:view': db.User.RANK_REGULAR,
}, },
'thumbnails': {'avatar_width': 200},
}) })
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_category_factory = tag_category_factory
ret.list_api = api.TagCategoryListApi()
ret.detail_api = api.TagCategoryDetailApi()
return ret
def test_retrieving_multiple(test_ctx): def test_retrieving_multiple(
user_factory, tag_category_factory, context_factory):
db.session.add_all([ db.session.add_all([
test_ctx.tag_category_factory(name='c1'), tag_category_factory(name='c1'),
test_ctx.tag_category_factory(name='c2'), tag_category_factory(name='c2'),
]) ])
result = test_ctx.list_api.get( result = api.tag_category_api.get_tag_categories(
test_ctx.context_factory( context_factory(user=user_factory(rank=db.User.RANK_REGULAR)))
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert [cat['name'] for cat in result['results']] == ['c1', 'c2'] assert [cat['name'] for cat in result['results']] == ['c1', 'c2']
def test_retrieving_single(test_ctx): def test_retrieving_single(user_factory, tag_category_factory, context_factory):
db.session.add(test_ctx.tag_category_factory(name='cat')) db.session.add(tag_category_factory(name='cat'))
result = test_ctx.detail_api.get( result = api.tag_category_api.get_tag_category(
test_ctx.context_factory( context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), {'category_name': 'cat'})
'cat')
assert result == { assert result == {
'name': 'cat', 'name': 'cat',
'color': 'dummy', 'color': 'dummy',
@ -46,16 +35,15 @@ def test_retrieving_single(test_ctx):
'version': 1, 'version': 1,
} }
def test_trying_to_retrieve_single_non_existing(test_ctx): def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(tag_categories.TagCategoryNotFoundError): with pytest.raises(tag_categories.TagCategoryNotFoundError):
test_ctx.detail_api.get( api.tag_category_api.get_tag_category(
test_ctx.context_factory( context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), {'category_name': '-'})
'-')
def test_trying_to_retrieve_single_without_privileges(test_ctx): def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.detail_api.get( api.tag_category_api.get_tag_category(
test_ctx.context_factory( context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), {'category_name': '-'})
'-')

View file

@ -1,137 +1,104 @@
import os
import pytest import pytest
from szurubooru import api, config, db, errors import unittest.mock
from szurubooru.func import util, tag_categories from szurubooru import api, db, errors
from szurubooru.func import tag_categories, tags
@pytest.fixture def _update_category_name(category, name):
def test_ctx( category.name = name
tmpdir,
config_injector, @pytest.fixture(autouse=True)
context_factory, def inject_config(config_injector):
user_factory,
tag_category_factory):
config_injector({ config_injector({
'data_dir': str(tmpdir),
'tag_category_name_regex': '^[^!]*$',
'privileges': { 'privileges': {
'tag_categories:edit:name': db.User.RANK_REGULAR, 'tag_categories:edit:name': db.User.RANK_REGULAR,
'tag_categories:edit:color': db.User.RANK_REGULAR, 'tag_categories:edit:color': db.User.RANK_REGULAR,
'tag_categories:set_default': db.User.RANK_REGULAR,
}, },
}) })
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_category_factory = tag_category_factory
ret.api = api.TagCategoryDetailApi()
return ret
def test_simple_updating(test_ctx): def test_simple_updating(user_factory, tag_category_factory, context_factory):
category = test_ctx.tag_category_factory(name='name', color='black') category = tag_category_factory(name='name', color='black')
db.session.add(category) db.session.add(category)
db.session.commit() db.session.commit()
result = test_ctx.api.put( with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \
test_ctx.context_factory( unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \
input={ unittest.mock.patch('szurubooru.func.tag_categories.update_category_color'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
tag_categories.update_category_name.side_effect = _update_category_name
tag_categories.serialize_category.return_value = 'serialized category'
result = api.tag_category_api.update_tag_category(
context_factory(
params={
'name': 'changed', 'name': 'changed',
'color': 'white', 'color': 'white',
'version': 1, 'version': 1,
}, },
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
'name') {'category_name': 'name'})
assert len(result['snapshots']) == 1 assert result == 'serialized category'
del result['snapshots'] tag_categories.update_category_name.assert_called_once_with(category, 'changed')
assert result == { tag_categories.update_category_color.assert_called_once_with(category, 'white')
'name': 'changed', tags.export_to_json.assert_called_once_with()
'color': 'white',
'usages': 0,
'default': False,
'version': 2,
}
assert tag_categories.try_get_category_by_name('name') is None
category = tag_categories.get_category_by_name('changed')
assert category is not None
assert category.name == 'changed'
assert category.color == 'white'
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json'))
@pytest.mark.parametrize('input,expected_exception', [
({'name': None}, tag_categories.InvalidTagCategoryNameError),
({'name': ''}, tag_categories.InvalidTagCategoryNameError),
({'name': '!bad'}, tag_categories.InvalidTagCategoryNameError),
({'color': None}, tag_categories.InvalidTagCategoryColorError),
({'color': ''}, tag_categories.InvalidTagCategoryColorError),
({'color': '; float:left'}, tag_categories.InvalidTagCategoryColorError),
])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
db.session.add(test_ctx.tag_category_factory(name='meta', color='black'))
db.session.commit()
with pytest.raises(expected_exception):
test_ctx.api.put(
test_ctx.context_factory(
input={**input, **{'version': 1}},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'meta')
@pytest.mark.parametrize('field', ['name', 'color']) @pytest.mark.parametrize('field', ['name', 'color'])
def test_omitting_optional_field(test_ctx, field): def test_omitting_optional_field(
db.session.add(test_ctx.tag_category_factory(name='name', color='black')) user_factory, tag_category_factory, context_factory, field):
db.session.add(tag_category_factory(name='name', color='black'))
db.session.commit() db.session.commit()
input = { params = {
'name': 'changed', 'name': 'changed',
'color': 'white', 'color': 'white',
} }
del input[field] del params[field]
result = test_ctx.api.put( with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \
test_ctx.context_factory( unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \
input={**input, **{'version': 1}}, unittest.mock.patch('szurubooru.func.tags.export_to_json'):
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), api.tag_category_api.update_tag_category(
'name') context_factory(
assert result is not None params={**params, **{'version': 1}},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': 'name'})
def test_trying_to_update_non_existing(test_ctx): def test_trying_to_update_non_existing(user_factory, context_factory):
with pytest.raises(tag_categories.TagCategoryNotFoundError): with pytest.raises(tag_categories.TagCategoryNotFoundError):
test_ctx.api.put( api.tag_category_api.update_tag_category(
test_ctx.context_factory( context_factory(
input={'name': ['dummy']}, params={'name': ['dummy']},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
'bad') {'category_name': 'bad'})
@pytest.mark.parametrize('new_name', ['cat', 'CAT']) @pytest.mark.parametrize('params', [
def test_reusing_own_name(test_ctx, new_name):
db.session.add(test_ctx.tag_category_factory(name='cat', color='black'))
db.session.commit()
result = test_ctx.api.put(
test_ctx.context_factory(
input={'name': new_name, 'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'cat')
assert result['name'] == new_name
category = tag_categories.get_category_by_name('cat')
assert category.name == new_name
@pytest.mark.parametrize('dup_name', ['cat1', 'CAT1'])
def test_trying_to_use_existing_name(test_ctx, dup_name):
db.session.add_all([
test_ctx.tag_category_factory(name='cat1', color='black'),
test_ctx.tag_category_factory(name='cat2', color='black')])
db.session.commit()
with pytest.raises(tag_categories.TagCategoryAlreadyExistsError):
test_ctx.api.put(
test_ctx.context_factory(
input={'name': dup_name, 'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'cat2')
@pytest.mark.parametrize('input', [
{'name': 'whatever'}, {'name': 'whatever'},
{'color': 'whatever'}, {'color': 'whatever'},
]) ])
def test_trying_to_update_without_privileges(test_ctx, input): def test_trying_to_update_without_privileges(
db.session.add(test_ctx.tag_category_factory(name='dummy')) user_factory, tag_category_factory, context_factory, params):
db.session.add(tag_category_factory(name='dummy'))
db.session.commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.put( api.tag_category_api.update_tag_category(
test_ctx.context_factory( context_factory(
input={**input, **{'version': 1}}, params={**params, **{'version': 1}},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=db.User.RANK_ANONYMOUS)),
'dummy') {'category_name': 'dummy'})
def test_set_as_default(user_factory, tag_category_factory, context_factory):
category = tag_category_factory(name='name', color='black')
db.session.add(category)
db.session.commit()
with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \
unittest.mock.patch('szurubooru.func.tag_categories.set_default_category'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
tag_categories.update_category_name.side_effect = _update_category_name
tag_categories.serialize_category.return_value = 'serialized category'
result = api.tag_category_api.set_tag_category_as_default(
context_factory(
params={
'name': 'changed',
'color': 'white',
'version': 1,
},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': 'name'})
assert result == 'serialized category'
tag_categories.set_default_category.assert_called_once_with(category)

View file

@ -1,187 +1,77 @@
import datetime
import os
import pytest import pytest
from szurubooru import api, config, db, errors import unittest.mock
from szurubooru.func import util, tags, tag_categories, cache from szurubooru import api, db, errors
from szurubooru.func import tags, tag_categories
def assert_relations(relations, expected_tag_names): @pytest.fixture(autouse=True)
actual_names = sorted([rel.names[0].name for rel in relations]) def inject_config(config_injector):
assert actual_names == sorted(expected_tag_names) config_injector({'privileges': {'tags:create': db.User.RANK_REGULAR}})
@pytest.fixture def test_creating_simple_tags(tag_factory, user_factory, context_factory):
def test_ctx( with unittest.mock.patch('szurubooru.func.tags.create_tag'), \
tmpdir, config_injector, context_factory, user_factory, tag_factory): unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'), \
config_injector({ unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
'data_dir': str(tmpdir), unittest.mock.patch('szurubooru.func.tags.export_to_json'):
'tag_name_regex': '^[^!]*$', tags.get_or_create_tags_by_names.return_value = ([], [])
'tag_category_name_regex': '^[^!]*$', tags.create_tag.return_value = tag_factory()
'privileges': {'tags:create': db.User.RANK_REGULAR}, tags.serialize_tag.return_value = 'serialized tag'
}) result = api.tag_api.create_tag(
db.session.add_all([ context_factory(
db.TagCategory(name) for name in ['meta', 'character', 'copyright']]) params={
db.session.flush()
cache.purge()
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_factory = tag_factory
ret.api = api.TagListApi()
return ret
def test_creating_simple_tags(test_ctx, fake_datetime):
with fake_datetime('1997-12-01'):
result = test_ctx.api.post(
test_ctx.context_factory(
input={
'names': ['tag1', 'tag2'], 'names': ['tag1', 'tag2'],
'category': 'meta', 'category': 'meta',
'description': 'desc', 'description': 'desc',
'suggestions': [],
'implications': [],
},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert len(result['snapshots']) == 1
del result['snapshots']
assert result == {
'names': ['tag1', 'tag2'],
'category': 'meta',
'description': 'desc',
'suggestions': [],
'implications': [],
'creationTime': datetime.datetime(1997, 12, 1),
'lastEditTime': None,
'usages': 0,
'version': 1,
}
tag = tags.get_tag_by_name('tag1')
assert [tag_name.name for tag_name in tag.names] == ['tag1', 'tag2']
assert tag.category.name == 'meta'
assert tag.last_edit_time is None
assert tag.post_count == 0
assert_relations(tag.suggestions, [])
assert_relations(tag.implications, [])
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json'))
@pytest.mark.parametrize('input,expected_exception', [
({'names': None}, tags.InvalidTagNameError),
({'names': []}, tags.InvalidTagNameError),
({'names': [None]}, tags.InvalidTagNameError),
({'names': ['']}, tags.InvalidTagNameError),
({'names': ['!bad']}, tags.InvalidTagNameError),
({'names': ['x' * 65]}, tags.InvalidTagNameError),
({'category': None}, tag_categories.TagCategoryNotFoundError),
({'category': ''}, tag_categories.TagCategoryNotFoundError),
({'category': '!bad'}, tag_categories.TagCategoryNotFoundError),
({'suggestions': ['good', '!bad']}, tags.InvalidTagNameError),
({'implications': ['good', '!bad']}, tags.InvalidTagNameError),
])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
real_input={
'names': ['tag1', 'tag2'],
'category': 'meta',
'suggestions': [],
'implications': [],
}
for key, value in input.items():
real_input[key] = value
with pytest.raises(expected_exception):
test_ctx.api.post(
test_ctx.context_factory(
input=real_input,
user=test_ctx.user_factory()))
@pytest.mark.parametrize('field', ['names', 'category'])
def test_trying_to_omit_mandatory_field(test_ctx, field):
input = {
'names': ['tag1', 'tag2'],
'category': 'meta',
'suggestions': [],
'implications': [],
}
del input[field]
with pytest.raises(errors.ValidationError):
test_ctx.api.post(
test_ctx.context_factory(
input=input,
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
@pytest.mark.parametrize('field', ['implications', 'suggestions'])
def test_omitting_optional_field(test_ctx, field):
input = {
'names': ['tag1', 'tag2'],
'category': 'meta',
'suggestions': [],
'implications': [],
}
del input[field]
result = test_ctx.api.post(
test_ctx.context_factory(
input=input,
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert result is not None
def test_creating_new_category(test_ctx):
with pytest.raises(tag_categories.TagCategoryNotFoundError):
test_ctx.api.post(
test_ctx.context_factory(
input={
'names': ['main'],
'category': 'new',
'suggestions': [],
'implications': [],
}, user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
@pytest.mark.parametrize('input,expected_suggestions,expected_implications', [
# new relations
({
'names': ['main'],
'category': 'meta',
'suggestions': ['sug1', 'sug2'], 'suggestions': ['sug1', 'sug2'],
'implications': ['imp1', 'imp2'], 'implications': ['imp1', 'imp2'],
}, ['sug1', 'sug2'], ['imp1', 'imp2']), },
# overlapping relations user=user_factory(rank=db.User.RANK_REGULAR)))
({ assert result == 'serialized tag'
'names': ['main'], tags.create_tag.assert_called_once_with(
'category': 'meta', ['tag1', 'tag2'], 'meta', ['sug1', 'sug2'], ['imp1', 'imp2'])
'suggestions': ['sug', 'shared'], tags.export_to_json.assert_called_once_with()
'implications': ['shared', 'imp'],
}, ['shared', 'sug'], ['imp', 'shared']),
# duplicate relations
({
'names': ['main'],
'category': 'meta',
'suggestions': ['sug', 'SUG'],
'implications': ['imp', 'IMP'],
}, ['sug'], ['imp']),
# overlapping duplicate relations
({
'names': ['main'],
'category': 'meta',
'suggestions': ['shared1', 'shared2'],
'implications': ['SHARED1', 'SHARED2'],
}, ['shared1', 'shared2'], ['shared1', 'shared2']),
])
def test_creating_new_suggestions_and_implications(
test_ctx, input, expected_suggestions, expected_implications):
result = test_ctx.api.post(
test_ctx.context_factory(
input=input, user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert result['suggestions'] == expected_suggestions
assert result['implications'] == expected_implications
tag = tags.get_tag_by_name('main')
assert_relations(tag.suggestions, expected_suggestions)
assert_relations(tag.implications, expected_implications)
for name in ['main'] + expected_suggestions + expected_implications:
assert tags.try_get_tag_by_name(name) is not None
def test_trying_to_create_tag_without_privileges(test_ctx): @pytest.mark.parametrize('field', ['names', 'category'])
def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
params = {
'names': ['tag1', 'tag2'],
'category': 'meta',
'suggestions': [],
'implications': [],
}
del params[field]
with pytest.raises(errors.ValidationError):
api.tag_api.create_tag(
context_factory(
params=params,
user=user_factory(rank=db.User.RANK_REGULAR)))
@pytest.mark.parametrize('field', ['implications', 'suggestions'])
def test_omitting_optional_field(
tag_factory, user_factory, context_factory, field):
params = {
'names': ['tag1', 'tag2'],
'category': 'meta',
'suggestions': [],
'implications': [],
}
del params[field]
with unittest.mock.patch('szurubooru.func.tags.create_tag'), \
unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
tags.create_tag.return_value = tag_factory()
api.tag_api.create_tag(
context_factory(
params=params,
user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_create_tag_without_privileges(user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.post( api.tag_api.create_tag(
test_ctx.context_factory( context_factory(
input={ params={
'names': ['tag'], 'names': ['tag'],
'category': 'meta', 'category': 'meta',
'suggestions': ['tag'], 'suggestions': ['tag'],
'implications': [], 'implications': [],
}, },
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=db.User.RANK_ANONYMOUS)))

View file

@ -1,50 +1,55 @@
import pytest import pytest
import os import unittest.mock
from datetime import datetime from szurubooru import api, db, errors
from szurubooru import api, config, db, errors from szurubooru.func import tags
from szurubooru.func import util, tags
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx( def inject_config(config_injector):
tmpdir, config_injector, context_factory, tag_factory, user_factory): config_injector({'privileges': {'tags:delete': db.User.RANK_REGULAR}})
config_injector({
'data_dir': str(tmpdir),
'privileges': {
'tags:delete': db.User.RANK_REGULAR,
},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_factory = tag_factory
ret.api = api.TagDetailApi()
return ret
def test_deleting(test_ctx): def test_deleting(user_factory, tag_factory, context_factory):
db.session.add(test_ctx.tag_factory(names=['tag'])) db.session.add(tag_factory(names=['tag']))
db.session.commit() db.session.commit()
result = test_ctx.api.delete( with unittest.mock.patch('szurubooru.func.tags.export_to_json'):
test_ctx.context_factory( result = api.tag_api.delete_tag(
input={'version': 1}, context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), params={'version': 1},
'tag') user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'tag'})
assert result == {} assert result == {}
assert db.session.query(db.Tag).count() == 0 assert db.session.query(db.Tag).count() == 0
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json')) tags.export_to_json.assert_called_once_with()
def test_trying_to_delete_non_existing(test_ctx): def test_deleting_used(user_factory, tag_factory, context_factory, post_factory):
tag = tag_factory(names=['tag'])
post = post_factory()
post.tags.append(tag)
db.session.add_all([tag, post])
db.session.commit()
with unittest.mock.patch('szurubooru.func.tags.export_to_json'):
api.tag_api.delete_tag(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'tag'})
db.session.refresh(post)
assert db.session.query(db.Tag).count() == 0
assert post.tags == []
def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(tags.TagNotFoundError): with pytest.raises(tags.TagNotFoundError):
test_ctx.api.delete( api.tag_api.delete_tag(
test_ctx.context_factory( context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'bad') {'tag_name': 'bad'})
def test_trying_to_delete_without_privileges(test_ctx): def test_trying_to_delete_without_privileges(
db.session.add(test_ctx.tag_factory(names=['tag'])) user_factory, tag_factory, context_factory):
db.session.add(tag_factory(names=['tag']))
db.session.commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.delete( api.tag_api.delete_tag(
test_ctx.context_factory( context_factory(
input={'version': 1}, params={'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=db.User.RANK_ANONYMOUS)),
'tag') {'tag_name': 'tag'})
assert db.session.query(db.Tag).count() == 1 assert db.session.query(db.Tag).count() == 1

View file

@ -1,34 +1,15 @@
import datetime
import os
import pytest import pytest
from szurubooru import api, config, db, errors import unittest.mock
from szurubooru.func import util, tags from szurubooru import api, db, errors
from szurubooru.func import tags
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx( def inject_config(config_injector):
tmpdir, config_injector({'privileges': {'tags:merge': db.User.RANK_REGULAR}})
config_injector,
context_factory,
user_factory,
tag_factory,
tag_category_factory):
config_injector({
'data_dir': str(tmpdir),
'privileges': {
'tags:merge': db.User.RANK_REGULAR,
},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_factory = tag_factory
ret.tag_category_factory = tag_category_factory
ret.api = api.TagMergeApi()
return ret
def test_merging_with_usages(test_ctx, fake_datetime, post_factory): def test_merging(user_factory, tag_factory, context_factory, post_factory):
source_tag = test_ctx.tag_factory(names=['source']) source_tag = tag_factory(names=['source'])
target_tag = test_ctx.tag_factory(names=['target']) target_tag = tag_factory(names=['target'])
db.session.add_all([source_tag, target_tag]) db.session.add_all([source_tag, target_tag])
db.session.flush() db.session.flush()
assert source_tag.post_count == 0 assert source_tag.post_count == 0
@ -39,73 +20,78 @@ def test_merging_with_usages(test_ctx, fake_datetime, post_factory):
db.session.commit() db.session.commit()
assert source_tag.post_count == 1 assert source_tag.post_count == 1
assert target_tag.post_count == 0 assert target_tag.post_count == 0
with fake_datetime('1997-12-01'): with unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
result = test_ctx.api.post( unittest.mock.patch('szurubooru.func.tags.merge_tags'), \
test_ctx.context_factory( unittest.mock.patch('szurubooru.func.tags.export_to_json'):
input={ result = api.tag_api.merge_tags(
context_factory(
params={
'removeVersion': 1, 'removeVersion': 1,
'mergeToVersion': 1, 'mergeToVersion': 1,
'remove': 'source', 'remove': 'source',
'mergeTo': 'target', 'mergeTo': 'target',
}, },
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
assert tags.try_get_tag_by_name('source') is None tags.merge_tags.called_once_with(source_tag, target_tag)
assert tags.get_tag_by_name('target').post_count == 1 tags.export_to_json.assert_called_once_with()
@pytest.mark.parametrize( @pytest.mark.parametrize(
'field', ['remove', 'mergeTo', 'removeVersion', 'mergeToVersion']) 'field', ['remove', 'mergeTo', 'removeVersion', 'mergeToVersion'])
def test_trying_to_omit_mandatory_field(test_ctx, field): def test_trying_to_omit_mandatory_field(
user_factory, tag_factory, context_factory, field):
db.session.add_all([ db.session.add_all([
test_ctx.tag_factory(names=['source']), tag_factory(names=['source']),
test_ctx.tag_factory(names=['target']), tag_factory(names=['target']),
]) ])
db.session.commit() db.session.commit()
input = { params = {
'removeVersion': 1, 'removeVersion': 1,
'mergeToVersion': 1, 'mergeToVersion': 1,
'remove': 'source', 'remove': 'source',
'mergeTo': 'target', 'mergeTo': 'target',
} }
del input[field] del params[field]
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
test_ctx.api.post( api.tag_api.merge_tags(
test_ctx.context_factory( context_factory(
input=input, params=params,
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_merge_non_existing(test_ctx): def test_trying_to_merge_non_existing(
db.session.add(test_ctx.tag_factory(names=['good'])) user_factory, tag_factory, context_factory):
db.session.add(tag_factory(names=['good']))
db.session.commit() db.session.commit()
with pytest.raises(tags.TagNotFoundError): with pytest.raises(tags.TagNotFoundError):
test_ctx.api.post( api.tag_api.merge_tags(
test_ctx.context_factory( context_factory(
input={'remove': 'good', 'mergeTo': 'bad'}, params={'remove': 'good', 'mergeTo': 'bad'},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
with pytest.raises(tags.TagNotFoundError): with pytest.raises(tags.TagNotFoundError):
test_ctx.api.post( api.tag_api.merge_tags(
test_ctx.context_factory( context_factory(
input={'remove': 'bad', 'mergeTo': 'good'}, params={'remove': 'bad', 'mergeTo': 'good'},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=db.User.RANK_REGULAR)))
@pytest.mark.parametrize('input', [ @pytest.mark.parametrize('params', [
{'names': 'whatever'}, {'names': 'whatever'},
{'category': 'whatever'}, {'category': 'whatever'},
{'suggestions': ['whatever']}, {'suggestions': ['whatever']},
{'implications': ['whatever']}, {'implications': ['whatever']},
]) ])
def test_trying_to_merge_without_privileges(test_ctx, input): def test_trying_to_merge_without_privileges(
user_factory, tag_factory, context_factory, params):
db.session.add_all([ db.session.add_all([
test_ctx.tag_factory(names=['source']), tag_factory(names=['source']),
test_ctx.tag_factory(names=['target']), tag_factory(names=['target']),
]) ])
db.session.commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.post( api.tag_api.merge_tags(
test_ctx.context_factory( context_factory(
input={ params={
'removeVersion': 1, 'removeVersion': 1,
'mergeToVersion': 1, 'mergeToVersion': 1,
'remove': 'source', 'remove': 'source',
'mergeTo': 'target', 'mergeTo': 'target',
}, },
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=db.User.RANK_ANONYMOUS)))

View file

@ -1,82 +1,64 @@
import datetime
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import util, tags from szurubooru.func import tags
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx( def inject_config(config_injector):
context_factory,
config_injector,
user_factory,
tag_factory,
tag_category_factory):
config_injector({ config_injector({
'privileges': { 'privileges': {
'tags:list': db.User.RANK_REGULAR, 'tags:list': db.User.RANK_REGULAR,
'tags:view': db.User.RANK_REGULAR, 'tags:view': db.User.RANK_REGULAR,
}, },
'thumbnails': {'avatar_width': 200},
}) })
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_factory = tag_factory
ret.tag_category_factory = tag_category_factory
ret.list_api = api.TagListApi()
ret.detail_api = api.TagDetailApi()
return ret
def test_retrieving_multiple(test_ctx): def test_retrieving_multiple(user_factory, tag_factory, context_factory):
tag1 = test_ctx.tag_factory(names=['t1']) tag1 = tag_factory(names=['t1'])
tag2 = test_ctx.tag_factory(names=['t2']) tag2 = tag_factory(names=['t2'])
db.session.add_all([tag1, tag2]) db.session.add_all([tag1, tag2])
result = test_ctx.list_api.get( with unittest.mock.patch('szurubooru.func.tags.serialize_tag'):
test_ctx.context_factory( tags.serialize_tag.return_value = 'serialized tag'
input={'query': '', 'page': 1}, result = api.tag_api.get_tags(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) context_factory(
assert result['query'] == '' params={'query': '', 'page': 1},
assert result['page'] == 1 user=user_factory(rank=db.User.RANK_REGULAR)))
assert result['pageSize'] == 100
assert result['total'] == 2
assert [t['names'] for t in result['results']] == [['t1'], ['t2']]
def test_trying_to_retrieve_multiple_without_privileges(test_ctx):
with pytest.raises(errors.AuthError):
test_ctx.list_api.get(
test_ctx.context_factory(
input={'query': '', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_retrieving_single(test_ctx):
category = test_ctx.tag_category_factory(name='meta')
db.session.add(test_ctx.tag_factory(names=['tag'], category=category))
result = test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'tag')
assert result == { assert result == {
'names': ['tag'], 'query': '',
'category': 'meta', 'page': 1,
'description': None, 'pageSize': 100,
'creationTime': datetime.datetime(1996, 1, 1), 'total': 2,
'lastEditTime': None, 'results': ['serialized tag', 'serialized tag'],
'suggestions': [],
'implications': [],
'usages': 0,
'snapshots': [],
'version': 1,
} }
def test_trying_to_retrieve_single_non_existing(test_ctx): def test_trying_to_retrieve_multiple_without_privileges(
with pytest.raises(tags.TagNotFoundError): user_factory, context_factory):
test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'-')
def test_trying_to_retrieve_single_without_privileges(test_ctx):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.detail_api.get( api.tag_api.get_tags(
test_ctx.context_factory( context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), params={'query': '', 'page': 1},
'-') user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_retrieving_single(user_factory, tag_factory, context_factory):
db.session.add(tag_factory(names=['tag']))
with unittest.mock.patch('szurubooru.func.tags.serialize_tag'):
tags.serialize_tag.return_value = 'serialized tag'
result = api.tag_api.get_tag(
context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'tag'})
assert result == 'serialized tag'
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(tags.TagNotFoundError):
api.tag_api.get_tag(
context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': '-'})
def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError):
api.tag_api.get_tag(
context_factory(
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
{'tag_name': '-'})

View file

@ -1,56 +1,47 @@
import datetime
import pytest import pytest
import unittest.mock
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import util, tags from szurubooru.func import tags
def assert_results(result, expected_tag_names_and_occurrences): @pytest.fixture(autouse=True)
actual_tag_names_and_occurences = [] def inject_config(config_injector):
for item in result['results']: config_injector({'privileges': {'tags:view': db.User.RANK_REGULAR}})
tag_name = item['tag']['names'][0]
occurrences = item['occurrences']
actual_tag_names_and_occurences.append((tag_name, occurrences))
assert actual_tag_names_and_occurences == expected_tag_names_and_occurrences
@pytest.fixture def test_get_tag_siblings(user_factory, tag_factory, post_factory, context_factory):
def test_ctx( db.session.add(tag_factory(names=['tag']))
context_factory, config_injector, user_factory, tag_factory, post_factory): with unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
config_injector({ unittest.mock.patch('szurubooru.func.tags.get_tag_siblings'):
'privileges': { tags.serialize_tag.side_effect = \
'tags:view': db.User.RANK_REGULAR, lambda tag, *args, **kwargs: \
'serialized tag %s' % tag.names[0].name
tags.get_tag_siblings.return_value = [
(tag_factory(names=['sib1']), 1),
(tag_factory(names=['sib2']), 3),
]
result = api.tag_api.get_tag_siblings(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'tag'})
assert result == {
'results': [
{
'tag': 'serialized tag sib1',
'occurrences': 1,
}, },
'thumbnails': {'avatar_width': 200}, {
}) 'tag': 'serialized tag sib2',
ret = util.dotdict() 'occurrences': 3,
ret.context_factory = context_factory },
ret.user_factory = user_factory ],
ret.tag_factory = tag_factory }
ret.post_factory = post_factory
ret.api = api.TagSiblingsApi()
return ret
def test_used_with_others(test_ctx): def test_trying_to_retrieve_non_existing(user_factory, context_factory):
tag1 = test_ctx.tag_factory(names=['tag1'])
tag2 = test_ctx.tag_factory(names=['tag2'])
post = test_ctx.post_factory()
post.tags = [tag1, tag2]
db.session.add_all([post, tag1, tag2])
result = test_ctx.api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag1')
assert_results(result, [('tag2', 1)])
result = test_ctx.api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag2')
assert_results(result, [('tag1', 1)])
def test_trying_to_retrieve_non_existing(test_ctx):
with pytest.raises(tags.TagNotFoundError): with pytest.raises(tags.TagNotFoundError):
test_ctx.api.get( api.tag_api.get_tag_siblings(
test_ctx.context_factory( context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), '-') {'tag_name': '-'})
def test_trying_to_retrieve_without_privileges(test_ctx): def test_trying_to_retrieve_without_privileges(user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.get( api.tag_api.get_tag_siblings(
test_ctx.context_factory( context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), '-') {'tag_name': '-'})

View file

@ -1,20 +1,11 @@
import datetime
import os
import pytest import pytest
from szurubooru import api, config, db, errors import unittest.mock
from szurubooru.func import util, tags, tag_categories, cache from szurubooru import api, db, errors
from szurubooru.func import tags
def assert_relations(relations, expected_tag_names): @pytest.fixture(autouse=True)
actual_names = sorted([rel.names[0].name for rel in relations]) def inject_config(config_injector):
assert actual_names == sorted(expected_tag_names)
@pytest.fixture
def test_ctx(
tmpdir, config_injector, context_factory, user_factory, tag_factory):
config_injector({ config_injector({
'data_dir': str(tmpdir),
'tag_name_regex': '^[^!]*$',
'tag_category_name_regex': '^[^!]*$',
'privileges': { 'privileges': {
'tags:create': db.User.RANK_REGULAR, 'tags:create': db.User.RANK_REGULAR,
'tags:edit:names': db.User.RANK_REGULAR, 'tags:edit:names': db.User.RANK_REGULAR,
@ -24,118 +15,115 @@ def test_ctx(
'tags:edit:implications': db.User.RANK_REGULAR, 'tags:edit:implications': db.User.RANK_REGULAR,
}, },
}) })
db.session.add_all([
db.TagCategory(name) for name in ['meta', 'character', 'copyright']])
db.session.commit()
cache.purge()
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_factory = tag_factory
ret.api = api.TagDetailApi()
return ret
def test_simple_updating(test_ctx, fake_datetime): def test_simple_updating(user_factory, tag_factory, context_factory, fake_datetime):
tag = test_ctx.tag_factory(names=['tag1', 'tag2']) auth_user = user_factory(rank=db.User.RANK_REGULAR)
tag = tag_factory(names=['tag1', 'tag2'])
db.session.add(tag) db.session.add(tag)
db.session.commit() db.session.commit()
with fake_datetime('1997-12-01'): with unittest.mock.patch('szurubooru.func.tags.create_tag'), \
result = test_ctx.api.put( unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'), \
test_ctx.context_factory( unittest.mock.patch('szurubooru.func.tags.update_tag_names'), \
input={ unittest.mock.patch('szurubooru.func.tags.update_tag_category_name'), \
unittest.mock.patch('szurubooru.func.tags.update_tag_description'), \
unittest.mock.patch('szurubooru.func.tags.update_tag_suggestions'), \
unittest.mock.patch('szurubooru.func.tags.update_tag_implications'), \
unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
tags.get_or_create_tags_by_names.return_value = ([], [])
tags.serialize_tag.return_value = 'serialized tag'
result = api.tag_api.update_tag(
context_factory(
params={
'version': 1, 'version': 1,
'names': ['tag3'], 'names': ['tag3'],
'category': 'character', 'category': 'character',
'description': 'desc', 'description': 'desc',
'suggestions': ['sug1', 'sug2'],
'implications': ['imp1', 'imp2'],
}, },
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), user=auth_user),
'tag1') {'tag_name': 'tag1'})
assert len(result['snapshots']) == 1 assert result == 'serialized tag'
del result['snapshots'] tags.create_tag.assert_not_called()
assert result == { tags.update_tag_names.assert_called_once_with(tag, ['tag3'])
'names': ['tag3'], tags.update_tag_category_name.assert_called_once_with(tag, 'character')
'category': 'character', tags.update_tag_description.assert_called_once_with(tag, 'desc')
'description': 'desc', tags.update_tag_suggestions.assert_called_once_with(tag, ['sug1', 'sug2'])
'suggestions': [], tags.update_tag_implications.assert_called_once_with(tag, ['imp1', 'imp2'])
'implications': [], tags.serialize_tag.assert_called_once_with(tag, options=None)
'creationTime': datetime.datetime(1996, 1, 1),
'lastEditTime': datetime.datetime(1997, 12, 1),
'usages': 0,
'version': 2,
}
assert tags.try_get_tag_by_name('tag1') is None
assert tags.try_get_tag_by_name('tag2') is None
tag = tags.get_tag_by_name('tag3')
assert tag is not None
assert [tag_name.name for tag_name in tag.names] == ['tag3']
assert tag.category.name == 'character'
assert tag.suggestions == []
assert tag.implications == []
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json'))
@pytest.mark.parametrize('input,expected_exception', [
({'names': None}, tags.InvalidTagNameError),
({'names': []}, tags.InvalidTagNameError),
({'names': [None]}, tags.InvalidTagNameError),
({'names': ['']}, tags.InvalidTagNameError),
({'names': ['!bad']}, tags.InvalidTagNameError),
({'names': ['x' * 65]}, tags.InvalidTagNameError),
({'category': None}, tag_categories.TagCategoryNotFoundError),
({'category': ''}, tag_categories.TagCategoryNotFoundError),
({'category': '!bad'}, tag_categories.TagCategoryNotFoundError),
({'suggestions': ['good', '!bad']}, tags.InvalidTagNameError),
({'implications': ['good', '!bad']}, tags.InvalidTagNameError),
])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
db.session.add(test_ctx.tag_factory(names=['tag1']))
db.session.commit()
with pytest.raises(expected_exception):
test_ctx.api.put(
test_ctx.context_factory(
input={**input, **{'version': 1}},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'tag1')
@pytest.mark.parametrize( @pytest.mark.parametrize(
'field', ['names', 'category', 'description', 'implications', 'suggestions']) 'field', ['names', 'category', 'description', 'implications', 'suggestions'])
def test_omitting_optional_field(test_ctx, field): def test_omitting_optional_field(
db.session.add(test_ctx.tag_factory(names=['tag'])) user_factory, tag_factory, context_factory, field):
db.session.add(tag_factory(names=['tag']))
db.session.commit() db.session.commit()
input = { params = {
'names': ['tag1', 'tag2'], 'names': ['tag1', 'tag2'],
'category': 'meta', 'category': 'meta',
'description': 'desc', 'description': 'desc',
'suggestions': [], 'suggestions': [],
'implications': [], 'implications': [],
} }
del input[field] del params[field]
result = test_ctx.api.put( with unittest.mock.patch('szurubooru.func.tags.create_tag'), \
test_ctx.context_factory( unittest.mock.patch('szurubooru.func.tags.update_tag_names'), \
input={**input, **{'version': 1}}, unittest.mock.patch('szurubooru.func.tags.update_tag_category_name'), \
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
'tag') unittest.mock.patch('szurubooru.func.tags.export_to_json'):
assert result is not None api.tag_api.update_tag(
context_factory(
params={**params, **{'version': 1}},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'tag'})
def test_trying_to_update_non_existing(test_ctx): def test_trying_to_update_non_existing(user_factory, context_factory):
with pytest.raises(tags.TagNotFoundError): with pytest.raises(tags.TagNotFoundError):
test_ctx.api.put( api.tag_api.update_tag(
test_ctx.context_factory( context_factory(
input={'names': ['dummy']}, params={'names': ['dummy']},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
'tag1') {'tag_name': 'tag1'})
@pytest.mark.parametrize('input', [ @pytest.mark.parametrize('params', [
{'names': 'whatever'}, {'names': 'whatever'},
{'category': 'whatever'}, {'category': 'whatever'},
{'suggestions': ['whatever']}, {'suggestions': ['whatever']},
{'implications': ['whatever']}, {'implications': ['whatever']},
]) ])
def test_trying_to_update_without_privileges(test_ctx, input): def test_trying_to_update_without_privileges(
db.session.add(test_ctx.tag_factory(names=['tag'])) user_factory, tag_factory, context_factory, params):
db.session.add(tag_factory(names=['tag']))
db.session.commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.put( api.tag_api.update_tag(
test_ctx.context_factory( context_factory(
input={**input, **{'version': 1}}, params={**params, **{'version': 1}},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=db.User.RANK_ANONYMOUS)),
'tag') {'tag_name': 'tag'})
def test_trying_to_create_tags_without_privileges(
config_injector, context_factory, tag_factory, user_factory):
tag = tag_factory(names=['tag'])
db.session.add(tag)
db.session.commit()
config_injector({'privileges': {
'tags:create': db.User.RANK_ADMINISTRATOR,
'tags:edit:suggestions': db.User.RANK_REGULAR,
'tags:edit:implications': db.User.RANK_REGULAR,
}})
with unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'):
tags.get_or_create_tags_by_names.return_value = ([], ['new-tag'])
with pytest.raises(errors.AuthError):
api.tag_api.update_tag(
context_factory(
params={'suggestions': ['tag1', 'tag2'], 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'tag'})
with pytest.raises(errors.AuthError):
api.tag_api.update_tag(
context_factory(
params={'implications': ['tag1', 'tag2'], 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'tag'})

View file

@ -1,230 +1,79 @@
import datetime
import pytest import pytest
from szurubooru import api, config, db, errors import unittest.mock
from szurubooru.func import auth, util, users from szurubooru import api, db, errors
from szurubooru.func import users
EMPTY_PIXEL = \ @pytest.fixture(autouse=True)
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \ def inject_config(config_injector):
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \ config_injector({'privileges': {'users:create': 'regular'}})
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
@pytest.fixture def test_creating_user(user_factory, context_factory, fake_datetime):
def test_ctx(tmpdir, config_injector, context_factory, user_factory): user = user_factory()
config_injector({ with unittest.mock.patch('szurubooru.func.users.create_user'), \
'secret': '', unittest.mock.patch('szurubooru.func.users.update_user_name'), \
'user_name_regex': '[^!]{3,}', unittest.mock.patch('szurubooru.func.users.update_user_password'), \
'password_regex': '[^!]{3,}', unittest.mock.patch('szurubooru.func.users.update_user_email'), \
'default_rank': db.User.RANK_REGULAR, unittest.mock.patch('szurubooru.func.users.update_user_rank'), \
'thumbnails': {'avatar_width': 200, 'avatar_height': 200}, unittest.mock.patch('szurubooru.func.users.update_user_avatar'), \
'privileges': {'users:create': 'anonymous'}, unittest.mock.patch('szurubooru.func.users.serialize_user'), \
'data_dir': str(tmpdir.mkdir('data')), fake_datetime('1969-02-12'):
'data_url': 'http://example.com/data/', users.serialize_user.return_value = 'serialized user'
}) users.create_user.return_value = user
ret = util.dotdict() result = api.user_api.create_user(
ret.context_factory = context_factory context_factory(
ret.user_factory = user_factory params={
ret.api = api.UserListApi()
return ret
def test_creating_user(test_ctx, fake_datetime):
with fake_datetime('1969-02-12'):
result = test_ctx.api.post(
test_ctx.context_factory(
input={
'name': 'chewie1', 'name': 'chewie1',
'email': 'asd@asd.asd', 'email': 'asd@asd.asd',
'password': 'oks', 'password': 'oks',
'rank': 'moderator',
'avatarStyle': 'manual',
}, },
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) files={'avatar': b'...'},
assert result == { user=user_factory(rank=db.User.RANK_REGULAR)))
'avatarStyle': 'gravatar', assert result == 'serialized user'
'avatarUrl': 'https://gravatar.com/avatar/' + users.create_user.assert_called_once_with('chewie1', 'oks', 'asd@asd.asd')
'6f370c8c7109534c3d5c394123a477d7?d=retro&s=200', assert not users.update_user_name.called
'creationTime': datetime.datetime(1969, 2, 12), assert not users.update_user_password.called
'lastLoginTime': None, assert not users.update_user_email.called
'name': 'chewie1', users.update_user_rank.called_once_with(user, 'moderator')
'rank': 'administrator', users.update_user_avatar.called_once_with(user, 'manual', b'...')
'email': 'asd@asd.asd',
'commentCount': 0,
'likedPostCount': 0,
'dislikedPostCount': 0,
'favoritePostCount': 0,
'uploadedPostCount': 0,
'version': 1,
}
user = users.get_user_by_name('chewie1')
assert user.name == 'chewie1'
assert user.email == 'asd@asd.asd'
assert user.rank == db.User.RANK_ADMINISTRATOR
assert auth.is_valid_password(user, 'oks') is True
assert auth.is_valid_password(user, 'invalid') is False
def test_first_user_becomes_admin_others_not(test_ctx):
result1 = test_ctx.api.post(
test_ctx.context_factory(
input={
'name': 'chewie1',
'email': 'asd@asd.asd',
'password': 'oks',
},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
result2 = test_ctx.api.post(
test_ctx.context_factory(
input={
'name': 'chewie2',
'email': 'asd@asd.asd',
'password': 'sok',
},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
assert result1['rank'] == 'administrator'
assert result2['rank'] == 'regular'
first_user = users.get_user_by_name('chewie1')
other_user = users.get_user_by_name('chewie2')
assert first_user.rank == db.User.RANK_ADMINISTRATOR
assert other_user.rank == db.User.RANK_REGULAR
def test_first_user_does_not_become_admin_if_they_dont_wish_so(test_ctx):
result = test_ctx.api.post(
test_ctx.context_factory(
input={
'name': 'chewie1',
'email': 'asd@asd.asd',
'password': 'oks',
'rank': 'regular',
},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
assert result['rank'] == 'regular'
def test_trying_to_become_someone_else(test_ctx):
test_ctx.api.post(
test_ctx.context_factory(
input={
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
with pytest.raises(users.UserAlreadyExistsError):
test_ctx.api.post(
test_ctx.context_factory(
input={
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
with pytest.raises(users.UserAlreadyExistsError):
test_ctx.api.post(
test_ctx.context_factory(
input={
'name': 'CHEWIE',
'email': 'asd@asd.asd',
'password': 'oks',
},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
@pytest.mark.parametrize('input,expected_exception', [
({'name': None}, users.InvalidUserNameError),
({'name': ''}, users.InvalidUserNameError),
({'name': '!bad'}, users.InvalidUserNameError),
({'name': 'x' * 51}, users.InvalidUserNameError),
({'password': None}, users.InvalidPasswordError),
({'password': ''}, users.InvalidPasswordError),
({'password': '!bad'}, users.InvalidPasswordError),
({'rank': None}, users.InvalidRankError),
({'rank': ''}, users.InvalidRankError),
({'rank': 'bad'}, users.InvalidRankError),
({'rank': 'anonymous'}, users.InvalidRankError),
({'rank': 'nobody'}, users.InvalidRankError),
({'email': 'bad'}, users.InvalidEmailError),
({'email': 'x@' * 65 + '.com'}, users.InvalidEmailError),
({'avatarStyle': None}, users.InvalidAvatarError),
({'avatarStyle': ''}, users.InvalidAvatarError),
({'avatarStyle': 'invalid'}, users.InvalidAvatarError),
({'avatarStyle': 'manual'}, users.InvalidAvatarError), # missing file
])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
real_input={
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
}
for key, value in input.items():
real_input[key] = value
with pytest.raises(expected_exception):
test_ctx.api.post(
test_ctx.context_factory(
input=real_input,
user=test_ctx.user_factory(
name='u1', rank=db.User.RANK_ADMINISTRATOR)))
@pytest.mark.parametrize('field', ['name', 'password']) @pytest.mark.parametrize('field', ['name', 'password'])
def test_trying_to_omit_mandatory_field(test_ctx, field): def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
input = { params = {
'name': 'chewie', 'name': 'chewie',
'email': 'asd@asd.asd', 'email': 'asd@asd.asd',
'password': 'oks', 'password': 'oks',
} }
del input[field] user = user_factory()
with pytest.raises(errors.ValidationError): auth_user = user_factory(rank=db.User.RANK_REGULAR)
test_ctx.api.post( del params[field]
test_ctx.context_factory( with unittest.mock.patch('szurubooru.func.users.create_user'), \
input=input, pytest.raises(errors.MissingRequiredParameterError):
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) users.create_user.return_value = user
api.user_api.create_user(context_factory(params=params, user=auth_user))
@pytest.mark.parametrize('field', ['rank', 'email', 'avatarStyle']) @pytest.mark.parametrize('field', ['rank', 'email', 'avatarStyle'])
def test_omitting_optional_field(test_ctx, field): def test_omitting_optional_field(user_factory, context_factory, field):
input = { params = {
'name': 'chewie', 'name': 'chewie',
'email': 'asd@asd.asd', 'email': 'asd@asd.asd',
'password': 'oks', 'password': 'oks',
'rank': 'moderator', 'rank': 'moderator',
'avatarStyle': 'manual', 'avatarStyle': 'gravatar',
} }
del input[field] del params[field]
result = test_ctx.api.post( user = user_factory()
test_ctx.context_factory( auth_user = user_factory(rank=db.User.RANK_MODERATOR)
input=input, with unittest.mock.patch('szurubooru.func.users.create_user'), \
files={'avatar': EMPTY_PIXEL}, unittest.mock.patch('szurubooru.func.users.update_user_avatar'), \
user=test_ctx.user_factory(rank=db.User.RANK_MODERATOR))) unittest.mock.patch('szurubooru.func.users.serialize_user'):
assert result is not None users.create_user.return_value = user
api.user_api.create_user(
context_factory(params=params, user=auth_user))
def test_mods_trying_to_become_admin(test_ctx): def test_trying_to_create_user_without_privileges(context_factory, user_factory):
user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_MODERATOR)
user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_MODERATOR)
db.session.add_all([user1, user2])
context = test_ctx.context_factory(input={
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
'rank': 'administrator',
}, user=user1)
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.post(context) api.user_api.create_user(context_factory(
params='whatever',
def test_admin_creating_mod_account(test_ctx): user=user_factory(rank=db.User.RANK_ANONYMOUS)))
user = test_ctx.user_factory(rank=db.User.RANK_ADMINISTRATOR)
db.session.add(user)
context = test_ctx.context_factory(input={
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
'rank': 'moderator',
}, user=user)
result = test_ctx.api.post(context)
assert result['rank'] == 'moderator'
def test_uploading_avatar(test_ctx):
response = test_ctx.api.post(
test_ctx.context_factory(
input={
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
'avatarStyle': 'manual',
},
files={'avatar': EMPTY_PIXEL},
user=test_ctx.user_factory(rank=db.User.RANK_MODERATOR)))
user = users.get_user_by_name('chewie')
assert user.avatar_style == user.AVATAR_MANUAL
assert response['avatarUrl'] == 'http://example.com/data/avatars/chewie.png'

View file

@ -1,54 +1,52 @@
import pytest import pytest
from datetime import datetime
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import util, users from szurubooru.func import users
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx(config_injector, context_factory, user_factory): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'users:delete:self': db.User.RANK_REGULAR, 'users:delete:self': db.User.RANK_REGULAR,
'users:delete:any': db.User.RANK_MODERATOR, 'users:delete:any': db.User.RANK_MODERATOR,
}, },
}) })
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.api = api.UserDetailApi()
return ret
def test_deleting_oneself(test_ctx): def test_deleting_oneself(user_factory, context_factory):
user = test_ctx.user_factory(name='u', rank=db.User.RANK_REGULAR) user = user_factory(name='u', rank=db.User.RANK_REGULAR)
db.session.add(user) db.session.add(user)
db.session.commit() db.session.commit()
result = test_ctx.api.delete( result = api.user_api.delete_user(
test_ctx.context_factory(input={'version': 1}, user=user), 'u') context_factory(
params={'version': 1}, user=user), {'user_name': 'u'})
assert result == {} assert result == {}
assert db.session.query(db.User).count() == 0 assert db.session.query(db.User).count() == 0
def test_deleting_someone_else(test_ctx): def test_deleting_someone_else(user_factory, context_factory):
user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_REGULAR) user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR)
user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_MODERATOR) user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR)
db.session.add_all([user1, user2]) db.session.add_all([user1, user2])
db.session.commit() db.session.commit()
test_ctx.api.delete( api.user_api.delete_user(
test_ctx.context_factory(input={'version': 1}, user=user2), 'u1') context_factory(
params={'version': 1}, user=user2), {'user_name': 'u1'})
assert db.session.query(db.User).count() == 1 assert db.session.query(db.User).count() == 1
def test_trying_to_delete_someone_else_without_privileges(test_ctx): def test_trying_to_delete_someone_else_without_privileges(
user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_REGULAR) user_factory, context_factory):
user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_REGULAR) user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR)
user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR)
db.session.add_all([user1, user2]) db.session.add_all([user1, user2])
db.session.commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.delete( api.user_api.delete_user(
test_ctx.context_factory(input={'version': 1}, user=user2), 'u1') context_factory(
params={'version': 1}, user=user2), {'user_name': 'u1'})
assert db.session.query(db.User).count() == 2 assert db.session.query(db.User).count() == 2
def test_trying_to_delete_non_existing(test_ctx): def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(users.UserNotFoundError): with pytest.raises(users.UserNotFoundError):
test_ctx.api.delete( api.user_api.delete_user(
test_ctx.context_factory( context_factory(
input={'version': 1}, params={'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=db.User.RANK_REGULAR)),
'bad') {'user_name': 'bad'})

View file

@ -1,83 +1,64 @@
import datetime import unittest.mock
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, errors
from szurubooru.func import util, users from szurubooru.func import users
@pytest.fixture @pytest.fixture(autouse=True)
def test_ctx(context_factory, config_injector, user_factory): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'users:list': db.User.RANK_REGULAR, 'users:list': db.User.RANK_REGULAR,
'users:view': db.User.RANK_REGULAR, 'users:view': db.User.RANK_REGULAR,
'users:edit:any:email': db.User.RANK_MODERATOR, 'users:edit:any:email': db.User.RANK_MODERATOR,
}, },
'thumbnails': {'avatar_width': 200},
}) })
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.list_api = api.UserListApi()
ret.detail_api = api.UserDetailApi()
return ret
def test_retrieving_multiple(test_ctx): def test_retrieving_multiple(user_factory, context_factory):
user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_MODERATOR) user1 = user_factory(name='u1', rank=db.User.RANK_MODERATOR)
user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_MODERATOR) user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR)
db.session.add_all([user1, user2]) db.session.add_all([user1, user2])
result = test_ctx.list_api.get( with unittest.mock.patch('szurubooru.func.users.serialize_user'):
test_ctx.context_factory( users.serialize_user.return_value = 'serialized user'
input={'query': '', 'page': 1}, result = api.user_api.get_users(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR))) context_factory(
assert result['query'] == '' params={'query': '', 'page': 1},
assert result['page'] == 1 user=user_factory(rank=db.User.RANK_REGULAR)))
assert result['pageSize'] == 100
assert result['total'] == 2
assert [u['name'] for u in result['results']] == ['u1', 'u2']
def test_trying_to_retrieve_multiple_without_privileges(test_ctx):
with pytest.raises(errors.AuthError):
test_ctx.list_api.get(
test_ctx.context_factory(
input={'query': '', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_retrieving_single(test_ctx):
db.session.add(test_ctx.user_factory(name='u1', rank=db.User.RANK_REGULAR))
result = test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'u1')
assert result == { assert result == {
'name': 'u1', 'query': '',
'rank': db.User.RANK_REGULAR, 'page': 1,
'creationTime': datetime.datetime(1997, 1, 1), 'pageSize': 100,
'lastLoginTime': None, 'total': 2,
'avatarStyle': 'gravatar', 'results': ['serialized user', 'serialized user'],
'avatarUrl': 'https://gravatar.com/avatar/' +
'275876e34cf609db118f3d84b799a790?d=retro&s=200',
'email': False,
'commentCount': 0,
'likedPostCount': False,
'dislikedPostCount': False,
'favoritePostCount': 0,
'uploadedPostCount': 0,
'version': 1,
} }
assert result['email'] is False
assert result['likedPostCount'] is False
assert result['dislikedPostCount'] is False
def test_trying_to_retrieve_single_non_existing(test_ctx): def test_trying_to_retrieve_multiple_without_privileges(
with pytest.raises(users.UserNotFoundError): user_factory, context_factory):
test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'-')
def test_trying_to_retrieve_single_without_privileges(test_ctx):
db.session.add(test_ctx.user_factory(name='u1', rank=db.User.RANK_REGULAR))
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.detail_api.get( api.user_api.get_users(
test_ctx.context_factory( context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), params={'query': '', 'page': 1},
'u1') user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_retrieving_single(user_factory, context_factory):
user = user_factory(name='u1', rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=db.User.RANK_REGULAR)
db.session.add(user)
with unittest.mock.patch('szurubooru.func.users.serialize_user'):
users.serialize_user.return_value = 'serialized user'
result = api.user_api.get_user(
context_factory(user=auth_user), {'user_name': 'u1'})
assert result == 'serialized user'
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
with pytest.raises(users.UserNotFoundError):
api.user_api.get_user(
context_factory(user=auth_user), {'user_name': '-'})
def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_ANONYMOUS)
db.session.add(user_factory(name='u1', rank=db.User.RANK_REGULAR))
with pytest.raises(errors.AuthError):
api.user_api.get_user(
context_factory(user=auth_user), {'user_name': 'u1'})

View file

@ -1,20 +1,12 @@
import datetime
import pytest import pytest
from szurubooru import api, config, db, errors import unittest.mock
from szurubooru.func import auth, util, users from datetime import datetime
from szurubooru import api, db, errors
from szurubooru.func import users
EMPTY_PIXEL = \ @pytest.fixture(autouse=True)
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \ def inject_config(config_injector):
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
@pytest.fixture
def test_ctx(tmpdir, config_injector, context_factory, user_factory):
config_injector({ config_injector({
'secret': '',
'user_name_regex': '^[^!]{3,}$',
'password_regex': '^[^!]{3,}$',
'thumbnails': {'avatar_width': 200, 'avatar_height': 200},
'privileges': { 'privileges': {
'users:edit:self:name': db.User.RANK_REGULAR, 'users:edit:self:name': db.User.RANK_REGULAR,
'users:edit:self:pass': db.User.RANK_REGULAR, 'users:edit:self:pass': db.User.RANK_REGULAR,
@ -27,203 +19,97 @@ def test_ctx(tmpdir, config_injector, context_factory, user_factory):
'users:edit:any:rank': db.User.RANK_ADMINISTRATOR, 'users:edit:any:rank': db.User.RANK_ADMINISTRATOR,
'users:edit:any:avatar': db.User.RANK_ADMINISTRATOR, 'users:edit:any:avatar': db.User.RANK_ADMINISTRATOR,
}, },
'data_dir': str(tmpdir.mkdir('data')),
'data_url': 'http://example.com/data/',
}) })
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.api = api.UserDetailApi()
return ret
def test_updating_user(test_ctx): def test_updating_user(context_factory, user_factory):
user = test_ctx.user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
auth_user = user_factory(rank=db.User.RANK_ADMINISTRATOR)
db.session.add(user) db.session.add(user)
result = test_ctx.api.put( db.session.flush()
test_ctx.context_factory(
input={ with unittest.mock.patch('szurubooru.func.users.create_user'), \
unittest.mock.patch('szurubooru.func.users.update_user_name'), \
unittest.mock.patch('szurubooru.func.users.update_user_password'), \
unittest.mock.patch('szurubooru.func.users.update_user_email'), \
unittest.mock.patch('szurubooru.func.users.update_user_rank'), \
unittest.mock.patch('szurubooru.func.users.update_user_avatar'), \
unittest.mock.patch('szurubooru.func.users.serialize_user'):
users.serialize_user.return_value = 'serialized user'
result = api.user_api.update_user(
context_factory(
params={
'version': 1, 'version': 1,
'name': 'chewie', 'name': 'chewie',
'email': 'asd@asd.asd', 'email': 'asd@asd.asd',
'password': 'oks', 'password': 'oks',
'rank': 'moderator', 'rank': 'moderator',
'avatarStyle': 'gravatar', 'avatarStyle': 'manual',
}, },
user=user), files={
'u1') 'avatar': b'...',
assert result == { },
'avatarStyle': 'gravatar', user=auth_user),
'avatarUrl': 'https://gravatar.com/avatar/' + {'user_name': 'u1'})
'6f370c8c7109534c3d5c394123a477d7?d=retro&s=200',
'creationTime': datetime.datetime(1997, 1, 1),
'lastLoginTime': None,
'email': 'asd@asd.asd',
'name': 'chewie',
'rank': 'moderator',
'commentCount': 0,
'likedPostCount': 0,
'dislikedPostCount': 0,
'favoritePostCount': 0,
'uploadedPostCount': 0,
'version': 2,
}
user = users.get_user_by_name('chewie')
assert user.name == 'chewie'
assert user.email == 'asd@asd.asd'
assert user.rank == db.User.RANK_MODERATOR
assert user.avatar_style == user.AVATAR_GRAVATAR
assert auth.is_valid_password(user, 'oks') is True
assert auth.is_valid_password(user, 'invalid') is False
@pytest.mark.parametrize('input,expected_exception', [ assert result == 'serialized user'
({'name': None}, users.InvalidUserNameError), users.create_user.assert_not_called()
({'name': ''}, users.InvalidUserNameError), users.update_user_name.assert_called_once_with(user, 'chewie')
({'name': '!bad'}, users.InvalidUserNameError), users.update_user_password.assert_called_once_with(user, 'oks')
({'name': 'x' * 51}, users.InvalidUserNameError), users.update_user_email.assert_called_once_with(user, 'asd@asd.asd')
({'password': None}, users.InvalidPasswordError), users.update_user_rank.assert_called_once_with(user, 'moderator', auth_user)
({'password': ''}, users.InvalidPasswordError), users.update_user_avatar.assert_called_once_with(user, 'manual', b'...')
({'password': '!bad'}, users.InvalidPasswordError), users.serialize_user.assert_called_once_with(user, auth_user, options=None)
({'rank': None}, users.InvalidRankError),
({'rank': ''}, users.InvalidRankError),
({'rank': 'bad'}, users.InvalidRankError),
({'rank': 'anonymous'}, users.InvalidRankError),
({'rank': 'nobody'}, users.InvalidRankError),
({'email': 'bad'}, users.InvalidEmailError),
({'email': 'x@' * 65 + '.com'}, users.InvalidEmailError),
({'avatarStyle': None}, users.InvalidAvatarError),
({'avatarStyle': ''}, users.InvalidAvatarError),
({'avatarStyle': 'invalid'}, users.InvalidAvatarError),
({'avatarStyle': 'manual'}, users.InvalidAvatarError), # missing file
])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
user = test_ctx.user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
db.session.add(user)
with pytest.raises(expected_exception):
test_ctx.api.put(
test_ctx.context_factory(
input={**input, **{'version': 1}},
user=user),
'u1')
@pytest.mark.parametrize( @pytest.mark.parametrize(
'field', ['name', 'email', 'password', 'rank', 'avatarStyle']) 'field', ['name', 'email', 'password', 'rank', 'avatarStyle'])
def test_omitting_optional_field(test_ctx, field): def test_omitting_optional_field(user_factory, context_factory, field):
user = test_ctx.user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
db.session.add(user) db.session.add(user)
input = { params = {
'name': 'chewie', 'name': 'chewie',
'email': 'asd@asd.asd', 'email': 'asd@asd.asd',
'password': 'oks', 'password': 'oks',
'rank': 'moderator', 'rank': 'moderator',
'avatarStyle': 'gravatar', 'avatarStyle': 'gravatar',
} }
del input[field] del params[field]
result = test_ctx.api.put( with unittest.mock.patch('szurubooru.func.users.create_user'), \
test_ctx.context_factory( unittest.mock.patch('szurubooru.func.users.update_user_name'), \
input={**input, **{'version': 1}}, unittest.mock.patch('szurubooru.func.users.update_user_password'), \
files={'avatar': EMPTY_PIXEL}, unittest.mock.patch('szurubooru.func.users.update_user_email'), \
unittest.mock.patch('szurubooru.func.users.update_user_rank'), \
unittest.mock.patch('szurubooru.func.users.update_user_avatar'), \
unittest.mock.patch('szurubooru.func.users.serialize_user'):
api.user_api.update_user(
context_factory(
params={**params, **{'version': 1}},
files={'avatar': b'...'},
user=user), user=user),
'u1') {'user_name': 'u1'})
assert result is not None
def test_trying_to_update_non_existing(test_ctx): def test_trying_to_update_non_existing(user_factory, context_factory):
user = test_ctx.user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
db.session.add(user) db.session.add(user)
with pytest.raises(users.UserNotFoundError): with pytest.raises(users.UserNotFoundError):
test_ctx.api.put(test_ctx.context_factory(user=user), 'u2') api.user_api.update_user(
context_factory(user=user), {'user_name': 'u2'})
def test_removing_email(test_ctx): @pytest.mark.parametrize('params', [
user = test_ctx.user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
db.session.add(user)
test_ctx.api.put(
test_ctx.context_factory(
input={'email': '', 'version': 1}, user=user), 'u1')
assert users.get_user_by_name('u1').email is None
@pytest.mark.parametrize('input', [
{'name': 'whatever'}, {'name': 'whatever'},
{'email': 'whatever'}, {'email': 'whatever'},
{'rank': 'whatever'}, {'rank': 'whatever'},
{'password': 'whatever'}, {'password': 'whatever'},
{'avatarStyle': 'whatever'}, {'avatarStyle': 'whatever'},
]) ])
def test_trying_to_update_someone_else(test_ctx, input): def test_trying_to_update_field_without_privileges(
user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_REGULAR) user_factory, context_factory, params):
user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_REGULAR) user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR)
user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR)
db.session.add_all([user1, user2]) db.session.add_all([user1, user2])
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
test_ctx.api.put( api.user_api.update_user(
test_ctx.context_factory( context_factory(
input={**input, **{'version': 1}}, params={**params, **{'version': 1}},
user=user1), user=user1),
user2.name) {'user_name': user2.name})
def test_trying_to_become_someone_else(test_ctx):
user1 = test_ctx.user_factory(name='me', rank=db.User.RANK_REGULAR)
user2 = test_ctx.user_factory(name='her', rank=db.User.RANK_REGULAR)
db.session.add_all([user1, user2])
with pytest.raises(users.UserAlreadyExistsError):
test_ctx.api.put(
test_ctx.context_factory(
input={'name': 'her', 'version': 1}, user=user1),
'me')
with pytest.raises(users.UserAlreadyExistsError):
test_ctx.api.put(
test_ctx.context_factory(
input={'name': 'HER', 'version': 1}, user=user1),
'me')
def test_trying_to_make_someone_into_someone_else(test_ctx):
user1 = test_ctx.user_factory(name='him', rank=db.User.RANK_REGULAR)
user2 = test_ctx.user_factory(name='her', rank=db.User.RANK_REGULAR)
user3 = test_ctx.user_factory(name='me', rank=db.User.RANK_MODERATOR)
db.session.add_all([user1, user2, user3])
with pytest.raises(users.UserAlreadyExistsError):
test_ctx.api.put(
test_ctx.context_factory(
input={'name': 'her', 'version': 1}, user=user3),
'him')
with pytest.raises(users.UserAlreadyExistsError):
test_ctx.api.put(
test_ctx.context_factory(
input={'name': 'HER', 'version': 1}, user=user3),
'him')
def test_renaming_someone_else(test_ctx):
user1 = test_ctx.user_factory(name='him', rank=db.User.RANK_REGULAR)
user2 = test_ctx.user_factory(name='me', rank=db.User.RANK_MODERATOR)
db.session.add_all([user1, user2])
test_ctx.api.put(
test_ctx.context_factory(
input={'name': 'himself', 'version': 1}, user=user2),
'him')
test_ctx.api.put(
test_ctx.context_factory(
input={'name': 'HIMSELF', 'version': 2}, user=user2),
'himself')
def test_mods_trying_to_become_admin(test_ctx):
user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_MODERATOR)
user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_MODERATOR)
db.session.add_all([user1, user2])
context = test_ctx.context_factory(
input={'rank': 'administrator', 'version': 1},
user=user1)
with pytest.raises(errors.AuthError):
test_ctx.api.put(context, user1.name)
with pytest.raises(errors.AuthError):
test_ctx.api.put(context, user2.name)
def test_uploading_avatar(test_ctx):
user = test_ctx.user_factory(name='u1', rank=db.User.RANK_MODERATOR)
db.session.add(user)
response = test_ctx.api.put(
test_ctx.context_factory(
input={'avatarStyle': 'manual', 'version': 1},
files={'avatar': EMPTY_PIXEL},
user=user),
'u1')
user = users.get_user_by_name('u1')
assert user.avatar_style == user.AVATAR_MANUAL
assert response['avatarUrl'] == \
'http://example.com/data/avatars/u1.png'

View file

@ -5,7 +5,7 @@ import uuid
import pytest import pytest
import freezegun import freezegun
import sqlalchemy import sqlalchemy
from szurubooru import api, config, db from szurubooru import api, config, db, rest
from szurubooru.func import util from szurubooru.func import util
class QueryCounter(object): class QueryCounter(object):
@ -74,12 +74,14 @@ def session(query_logger):
@pytest.fixture @pytest.fixture
def context_factory(session): def context_factory(session):
def factory(request=None, input=None, files=None, user=None): def factory(params=None, files=None, user=None):
ctx = api.Context() ctx = rest.Context(
ctx.input = input or {} method=None,
url=None,
headers={},
params=params or {},
files=files or {})
ctx.session = session ctx.session = session
ctx.request = request or {}
ctx.files = files or {}
ctx.user = user or db.User() ctx.user = user or db.User()
return ctx return ctx
return factory return factory

View file

@ -1,32 +1,30 @@
import unittest.mock import unittest.mock
import pytest import pytest
from szurubooru import api, errors from szurubooru import rest, errors
from szurubooru.func import net from szurubooru.func import net
def test_has_param(): def test_has_param():
ctx = api.Context() ctx = rest.Context(method=None, url=None, params={'key': 'value'})
ctx.input = {'key': 'value'}
assert ctx.has_param('key') assert ctx.has_param('key')
assert not ctx.has_param('key2') assert not ctx.has_param('key2')
def test_get_file(): def test_get_file():
ctx = api.Context() ctx = rest.Context(method=None, url=None, files={'key': b'content'})
ctx.files = {'key': b'content'}
assert ctx.get_file('key') == b'content' assert ctx.get_file('key') == b'content'
assert ctx.get_file('key2') is None assert ctx.get_file('key2') is None
def test_get_file_from_url(): def test_get_file_from_url():
with unittest.mock.patch('szurubooru.func.net.download'): with unittest.mock.patch('szurubooru.func.net.download'):
net.download.return_value = b'content' net.download.return_value = b'content'
ctx = api.Context() ctx = rest.Context(
ctx.input = {'keyUrl': 'example.com'} method=None, url=None, params={'keyUrl': 'example.com'})
assert ctx.get_file('key') == b'content' assert ctx.get_file('key') == b'content'
assert ctx.get_file('key2') is None assert ctx.get_file('key2') is None
net.download.assert_called_once_with('example.com') net.download.assert_called_once_with('example.com')
def test_getting_list_parameter(): def test_getting_list_parameter():
ctx = api.Context() ctx = rest.Context(
ctx.input = {'key': 'value', 'list': ['1', '2', '3']} method=None, url=None, params={'key': 'value', 'list': ['1', '2', '3']})
assert ctx.get_param_as_list('key') == ['value'] assert ctx.get_param_as_list('key') == ['value']
assert ctx.get_param_as_list('key2') is None assert ctx.get_param_as_list('key2') is None
assert ctx.get_param_as_list('key2', default=['def']) == ['def'] assert ctx.get_param_as_list('key2', default=['def']) == ['def']
@ -35,8 +33,8 @@ def test_getting_list_parameter():
ctx.get_param_as_list('key2', required=True) ctx.get_param_as_list('key2', required=True)
def test_getting_string_parameter(): def test_getting_string_parameter():
ctx = api.Context() ctx = rest.Context(
ctx.input = {'key': 'value', 'list': ['1', '2', '3']} method=None, url=None, params={'key': 'value', 'list': ['1', '2', '3']})
assert ctx.get_param_as_string('key') == 'value' assert ctx.get_param_as_string('key') == 'value'
assert ctx.get_param_as_string('key2') is None assert ctx.get_param_as_string('key2') is None
assert ctx.get_param_as_string('key2', default='def') == 'def' assert ctx.get_param_as_string('key2', default='def') == 'def'
@ -45,8 +43,10 @@ def test_getting_string_parameter():
ctx.get_param_as_string('key2', required=True) ctx.get_param_as_string('key2', required=True)
def test_getting_int_parameter(): def test_getting_int_parameter():
ctx = api.Context() ctx = rest.Context(
ctx.input = {'key': '50', 'err': 'invalid', 'list': [1, 2, 3]} method=None,
url=None,
params={'key': '50', 'err': 'invalid', 'list': [1, 2, 3]})
assert ctx.get_param_as_int('key') == 50 assert ctx.get_param_as_int('key') == 50
assert ctx.get_param_as_int('key2') is None assert ctx.get_param_as_int('key2') is None
assert ctx.get_param_as_int('key2', default=5) == 5 assert ctx.get_param_as_int('key2', default=5) == 5
@ -65,8 +65,7 @@ def test_getting_int_parameter():
def test_getting_bool_parameter(): def test_getting_bool_parameter():
def test(value): def test(value):
ctx = api.Context() ctx = rest.Context(method=None, url=None, params={'key': value})
ctx.input = {'key': value}
return ctx.get_param_as_bool('key') return ctx.get_param_as_bool('key')
assert test('1') is True assert test('1') is True
@ -94,7 +93,7 @@ def test_getting_bool_parameter():
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
test(['1', '2']) test(['1', '2'])
ctx = api.Context() ctx = rest.Context(method=None, url=None)
assert ctx.get_param_as_bool('non-existing') is None assert ctx.get_param_as_bool('non-existing') is None
assert ctx.get_param_as_bool('non-existing', default=True) is True assert ctx.get_param_as_bool('non-existing', default=True) is True
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):