server/general: ditch falcon for in-house WSGI app

For quite some time, I hated Falcon's class maps approach that caused
more chaos than good for Szurubooru. I've taken a look at the other
frameworks (hug, flask, etc) again, but they all looked too
bloated/over-engineered. I decided to just talk to WSGI myself.

Regex-based routing may not be the fastest in the world, but I'm fine
with response time of 10 ms for cached /posts.
This commit is contained in:
rr- 2016-08-14 12:35:14 +02:00
parent d102c9bdba
commit af62f8c45a
61 changed files with 2447 additions and 3096 deletions

View file

@ -1,6 +1,7 @@
[basic]
function-rgx=^_?[a-z_][a-z0-9_]{2,}$|^test_
method-rgx=^[a-z_][a-z0-9_]{2,}$|^test_
const-rgx=^[A-Z_]+$|^_[a-zA-Z_]*$
good-names=ex,_,logger
[variables]

View file

@ -10,7 +10,7 @@ import argparse
import os.path
import sys
import waitress
from szurubooru.app import create_app
from szurubooru.facade import create_app
def main():
parser = argparse.ArgumentParser('Starts szurubooru using waitress.')

View file

@ -1,6 +1,5 @@
alembic>=0.8.5
pyyaml>=3.11
falcon>=0.3.0
psycopg2>=2.6.1
SQLAlchemy>=1.0.12
pytest>=2.9.1

View file

@ -1,27 +1,8 @@
''' Falcon-compatible API facades. '''
from szurubooru.api.password_reset_api import PasswordResetApi
from szurubooru.api.user_api import UserListApi, UserDetailApi
from szurubooru.api.tag_api import (
TagListApi,
TagDetailApi,
TagMergeApi,
TagSiblingsApi)
from szurubooru.api.tag_category_api import (
TagCategoryListApi,
TagCategoryDetailApi,
DefaultTagCategoryApi)
from szurubooru.api.comment_api import (
CommentListApi,
CommentDetailApi,
CommentScoreApi)
from szurubooru.api.post_api import (
PostListApi,
PostDetailApi,
PostFeatureApi,
PostScoreApi,
PostFavoriteApi,
PostsAroundApi)
from szurubooru.api.snapshot_api import SnapshotListApi
from szurubooru.api.info_api import InfoApi
from szurubooru.api.context import Context, Request
import szurubooru.api.info_api
import szurubooru.api.user_api
import szurubooru.api.post_api
import szurubooru.api.tag_api
import szurubooru.api.tag_category_api
import szurubooru.api.comment_api
import szurubooru.api.password_reset_api
import szurubooru.api.snapshot_api

View file

@ -1,27 +0,0 @@
import types
def _bind_method(target, desired_method_name):
actual_method = getattr(target, desired_method_name)
def _wrapper_method(_self, request, _response, *args, **kwargs):
request.context.output = \
actual_method(request.context, *args, **kwargs)
return types.MethodType(_wrapper_method, target)
class BaseApi(object):
'''
A wrapper around falcon's API interface that eases input and output
management.
'''
def __init__(self):
self._translate_routes()
def _translate_routes(self):
for method_name in ['GET', 'PUT', 'POST', 'DELETE']:
desired_method_name = method_name.lower()
falcon_method_name = 'on_%s' % method_name.lower()
if hasattr(self, desired_method_name):
setattr(
self,
falcon_method_name,
_bind_method(self, desired_method_name))

View file

@ -1,7 +1,9 @@
import datetime
from szurubooru import search
from szurubooru.api.base_api import BaseApi
from szurubooru.func import auth, comments, posts, scores, util
from szurubooru.rest import routes
_search_executor = search.Executor(search.configs.CommentSearchConfig())
def _serialize(ctx, comment, **kwargs):
return comments.serialize_comment(
@ -9,67 +11,65 @@ def _serialize(ctx, comment, **kwargs):
ctx.user,
options=util.get_serialization_options(ctx), **kwargs)
class CommentListApi(BaseApi):
def __init__(self):
super().__init__()
self._search_executor = search.Executor(
search.configs.CommentSearchConfig())
@routes.get('/comments/?')
def get_comments(ctx, _params=None):
auth.verify_privilege(ctx.user, 'comments:list')
return _search_executor.execute_and_serialize(
ctx, lambda comment: _serialize(ctx, comment))
def get(self, ctx):
auth.verify_privilege(ctx.user, 'comments:list')
return self._search_executor.execute_and_serialize(
ctx,
lambda comment: _serialize(ctx, comment))
@routes.post('/comments/?')
def create_comment(ctx, _params=None):
auth.verify_privilege(ctx.user, 'comments:create')
text = ctx.get_param_as_string('text', required=True)
post_id = ctx.get_param_as_int('postId', required=True)
post = posts.get_post_by_id(post_id)
comment = comments.create_comment(ctx.user, post, text)
ctx.session.add(comment)
ctx.session.commit()
return _serialize(ctx, comment)
def post(self, ctx):
auth.verify_privilege(ctx.user, 'comments:create')
text = ctx.get_param_as_string('text', required=True)
post_id = ctx.get_param_as_int('postId', required=True)
post = posts.get_post_by_id(post_id)
comment = comments.create_comment(ctx.user, post, text)
ctx.session.add(comment)
ctx.session.commit()
return _serialize(ctx, comment)
@routes.get('/comment/(?P<comment_id>[^/]+)/?')
def get_comment(ctx, params):
auth.verify_privilege(ctx.user, 'comments:view')
comment = comments.get_comment_by_id(params['comment_id'])
return _serialize(ctx, comment)
class CommentDetailApi(BaseApi):
def get(self, ctx, comment_id):
auth.verify_privilege(ctx.user, 'comments:view')
comment = comments.get_comment_by_id(comment_id)
return _serialize(ctx, comment)
@routes.put('/comment/(?P<comment_id>[^/]+)/?')
def update_comment(ctx, params):
comment = comments.get_comment_by_id(params['comment_id'])
util.verify_version(comment, ctx)
infix = 'own' if ctx.user.user_id == comment.user_id else 'any'
text = ctx.get_param_as_string('text', required=True)
auth.verify_privilege(ctx.user, 'comments:edit:%s' % infix)
comments.update_comment_text(comment, text)
util.bump_version(comment)
comment.last_edit_time = datetime.datetime.utcnow()
ctx.session.commit()
return _serialize(ctx, comment)
def put(self, ctx, comment_id):
comment = comments.get_comment_by_id(comment_id)
util.verify_version(comment, ctx)
infix = 'own' if ctx.user.user_id == comment.user_id else 'any'
text = ctx.get_param_as_string('text', required=True)
auth.verify_privilege(ctx.user, 'comments:edit:%s' % infix)
comments.update_comment_text(comment, text)
util.bump_version(comment)
comment.last_edit_time = datetime.datetime.utcnow()
ctx.session.commit()
return _serialize(ctx, comment)
@routes.delete('/comment/(?P<comment_id>[^/]+)/?')
def delete_comment(ctx, params):
comment = comments.get_comment_by_id(params['comment_id'])
util.verify_version(comment, ctx)
infix = 'own' if ctx.user.user_id == comment.user_id else 'any'
auth.verify_privilege(ctx.user, 'comments:delete:%s' % infix)
ctx.session.delete(comment)
ctx.session.commit()
return {}
def delete(self, ctx, comment_id):
comment = comments.get_comment_by_id(comment_id)
util.verify_version(comment, ctx)
infix = 'own' if ctx.user.user_id == comment.user_id else 'any'
auth.verify_privilege(ctx.user, 'comments:delete:%s' % infix)
ctx.session.delete(comment)
ctx.session.commit()
return {}
@routes.put('/comment/(?P<comment_id>[^/]+)/score/?')
def set_comment_score(ctx, params):
auth.verify_privilege(ctx.user, 'comments:score')
score = ctx.get_param_as_int('score', required=True)
comment = comments.get_comment_by_id(params['comment_id'])
scores.set_score(comment, ctx.user, score)
ctx.session.commit()
return _serialize(ctx, comment)
class CommentScoreApi(BaseApi):
def put(self, ctx, comment_id):
auth.verify_privilege(ctx.user, 'comments:score')
score = ctx.get_param_as_int('score', required=True)
comment = comments.get_comment_by_id(comment_id)
scores.set_score(comment, ctx.user, score)
ctx.session.commit()
return _serialize(ctx, comment)
def delete(self, ctx, comment_id):
auth.verify_privilege(ctx.user, 'comments:score')
comment = comments.get_comment_by_id(comment_id)
scores.delete_score(comment, ctx.user)
ctx.session.commit()
return _serialize(ctx, comment)
@routes.delete('/comment/(?P<comment_id>[^/]+)/score/?')
def delete_comment_score(ctx, params):
auth.verify_privilege(ctx.user, 'comments:score')
comment = comments.get_comment_by_id(params['comment_id'])
scores.delete_score(comment, ctx.user)
ctx.session.commit()
return _serialize(ctx, comment)

View file

@ -1,47 +1,46 @@
import datetime
import os
from szurubooru import config
from szurubooru.api.base_api import BaseApi
from szurubooru.func import posts, users, util
from szurubooru.rest import routes
class InfoApi(BaseApi):
def __init__(self):
super().__init__()
self._cache_time = None
self._cache_result = None
_cache_time = None
_cache_result = None
def get(self, ctx):
post_feature = posts.try_get_current_post_feature()
return {
'postCount': posts.get_post_count(),
'diskUsage': self._get_disk_usage(),
'featuredPost': posts.serialize_post(post_feature.post, ctx.user) \
if post_feature else None,
'featuringTime': post_feature.time if post_feature else None,
'featuringUser': users.serialize_user(post_feature.user, ctx.user) \
if post_feature else None,
'serverTime': datetime.datetime.utcnow(),
'config': {
'userNameRegex': config.config['user_name_regex'],
'passwordRegex': config.config['password_regex'],
'tagNameRegex': config.config['tag_name_regex'],
'tagCategoryNameRegex': config.config['tag_category_name_regex'],
'defaultUserRank': config.config['default_rank'],
'privileges': util.snake_case_to_lower_camel_case_keys(
config.config['privileges']),
},
}
def _get_disk_usage():
global _cache_time, _cache_result # pylint: disable=global-statement
threshold = datetime.timedelta(hours=1)
now = datetime.datetime.utcnow()
if _cache_time and _cache_time > now - threshold:
return _cache_result
total_size = 0
for dir_path, _, file_names in os.walk(config.config['data_dir']):
for file_name in file_names:
file_path = os.path.join(dir_path, file_name)
total_size += os.path.getsize(file_path)
_cache_time = now
_cache_result = total_size
return total_size
def _get_disk_usage(self):
threshold = datetime.timedelta(hours=1)
now = datetime.datetime.utcnow()
if self._cache_time and self._cache_time > now - threshold:
return self._cache_result
total_size = 0
for dir_path, _, file_names in os.walk(config.config['data_dir']):
for file_name in file_names:
file_path = os.path.join(dir_path, file_name)
total_size += os.path.getsize(file_path)
self._cache_time = now
self._cache_result = total_size
return total_size
@routes.get('/info/?')
def get_info(ctx, _params=None):
post_feature = posts.try_get_current_post_feature()
return {
'postCount': posts.get_post_count(),
'diskUsage': _get_disk_usage(),
'featuredPost': posts.serialize_post(post_feature.post, ctx.user) \
if post_feature else None,
'featuringTime': post_feature.time if post_feature else None,
'featuringUser': users.serialize_user(post_feature.user, ctx.user) \
if post_feature else None,
'serverTime': datetime.datetime.utcnow(),
'config': {
'userNameRegex': config.config['user_name_regex'],
'passwordRegex': config.config['password_regex'],
'tagNameRegex': config.config['tag_name_regex'],
'tagCategoryNameRegex': config.config['tag_category_name_regex'],
'defaultUserRank': config.config['default_rank'],
'privileges': util.snake_case_to_lower_camel_case_keys(
config.config['privileges']),
},
}

View file

@ -1,6 +1,6 @@
from szurubooru import config, errors
from szurubooru.api.base_api import BaseApi
from szurubooru.func import auth, mailer, users, util
from szurubooru.rest import routes
MAIL_SUBJECT = 'Password reset for {name}'
MAIL_BODY = \
@ -8,32 +8,35 @@ MAIL_BODY = \
'If you wish to proceed, click this link: {url}\n' \
'Otherwise, please ignore this email.'
class PasswordResetApi(BaseApi):
def get(self, _ctx, user_name):
''' Send a mail with secure token to the correlated user. '''
user = users.get_user_by_name_or_email(user_name)
if not user.email:
raise errors.ValidationError(
'User %r hasn\'t supplied email. Cannot reset password.' % (
user_name))
token = auth.generate_authentication_token(user)
url = '%s/password-reset/%s:%s' % (
config.config['base_url'].rstrip('/'), user.name, token)
mailer.send_mail(
'noreply@%s' % config.config['name'],
user.email,
MAIL_SUBJECT.format(name=config.config['name']),
MAIL_BODY.format(name=config.config['name'], url=url))
return {}
@routes.get('/password-reset/(?P<user_name>[^/]+)/?')
def start_password_reset(_ctx, params):
''' Send a mail with secure token to the correlated user. '''
user_name = params['user_name']
user = users.get_user_by_name_or_email(user_name)
if not user.email:
raise errors.ValidationError(
'User %r hasn\'t supplied email. Cannot reset password.' % (
user_name))
token = auth.generate_authentication_token(user)
url = '%s/password-reset/%s:%s' % (
config.config['base_url'].rstrip('/'), user.name, token)
mailer.send_mail(
'noreply@%s' % config.config['name'],
user.email,
MAIL_SUBJECT.format(name=config.config['name']),
MAIL_BODY.format(name=config.config['name'], url=url))
return {}
def post(self, ctx, user_name):
''' Verify token from mail, generate a new password and return it. '''
user = users.get_user_by_name_or_email(user_name)
good_token = auth.generate_authentication_token(user)
token = ctx.get_param_as_string('token', required=True)
if token != good_token:
raise errors.ValidationError('Invalid password reset token.')
new_password = users.reset_user_password(user)
util.bump_version(user)
ctx.session.commit()
return {'password': new_password}
@routes.post('/password-reset/(?P<user_name>[^/]+)/?')
def finish_password_reset(ctx, params):
''' Verify token from mail, generate a new password and return it. '''
user_name = params['user_name']
user = users.get_user_by_name_or_email(user_name)
good_token = auth.generate_authentication_token(user)
token = ctx.get_param_as_string('token', required=True)
if token != good_token:
raise errors.ValidationError('Invalid password reset token.')
new_password = users.reset_user_password(user)
util.bump_version(user)
ctx.session.commit()
return {'password': new_password}

View file

@ -1,7 +1,9 @@
import datetime
from szurubooru import search
from szurubooru.api.base_api import BaseApi
from szurubooru.func import auth, tags, posts, snapshots, favorites, scores, util
from szurubooru.rest import routes
_search_executor = search.Executor(search.configs.PostSearchConfig())
def _serialize_post(ctx, post):
return posts.serialize_post(
@ -9,165 +11,161 @@ def _serialize_post(ctx, post):
ctx.user,
options=util.get_serialization_options(ctx))
class PostListApi(BaseApi):
def __init__(self):
super().__init__()
self._search_executor = search.Executor(
search.configs.PostSearchConfig())
@routes.get('/posts/?')
def get_posts(ctx, _params=None):
auth.verify_privilege(ctx.user, 'posts:list')
_search_executor.config.user = ctx.user
return _search_executor.execute_and_serialize(
ctx, lambda post: _serialize_post(ctx, post))
def get(self, ctx):
auth.verify_privilege(ctx.user, 'posts:list')
self._search_executor.config.user = ctx.user
return self._search_executor.execute_and_serialize(
ctx, lambda post: _serialize_post(ctx, post))
@routes.post('/posts/?')
def create_post(ctx, _params=None):
anonymous = ctx.get_param_as_bool('anonymous', default=False)
if anonymous:
auth.verify_privilege(ctx.user, 'posts:create:anonymous')
else:
auth.verify_privilege(ctx.user, 'posts:create:identified')
content = ctx.get_file('content', required=True)
tag_names = ctx.get_param_as_list('tags', required=True)
safety = ctx.get_param_as_string('safety', required=True)
source = ctx.get_param_as_string('source', required=False, default=None)
if ctx.has_param('contentUrl') and not source:
source = ctx.get_param_as_string('contentUrl')
relations = ctx.get_param_as_list('relations', required=False) or []
notes = ctx.get_param_as_list('notes', required=False) or []
flags = ctx.get_param_as_list('flags', required=False) or []
def post(self, ctx):
anonymous = ctx.get_param_as_bool('anonymous', default=False)
if anonymous:
auth.verify_privilege(ctx.user, 'posts:create:anonymous')
else:
auth.verify_privilege(ctx.user, 'posts:create:identified')
content = ctx.get_file('content', required=True)
tag_names = ctx.get_param_as_list('tags', required=True)
safety = ctx.get_param_as_string('safety', required=True)
source = ctx.get_param_as_string('source', required=False, default=None)
if ctx.has_param('contentUrl') and not source:
source = ctx.get_param_as_string('contentUrl')
relations = ctx.get_param_as_list('relations', required=False) or []
notes = ctx.get_param_as_list('notes', required=False) or []
flags = ctx.get_param_as_list('flags', required=False) or []
post, new_tags = posts.create_post(
content, tag_names, None if anonymous else ctx.user)
if len(new_tags):
auth.verify_privilege(ctx.user, 'tags:create')
posts.update_post_safety(post, safety)
posts.update_post_source(post, source)
posts.update_post_relations(post, relations)
posts.update_post_notes(post, notes)
posts.update_post_flags(post, flags)
if ctx.has_file('thumbnail'):
posts.update_post_thumbnail(post, ctx.get_file('thumbnail'))
ctx.session.add(post)
snapshots.save_entity_creation(post, ctx.user)
ctx.session.commit()
tags.export_to_json()
return _serialize_post(ctx, post)
post, new_tags = posts.create_post(
content, tag_names, None if anonymous else ctx.user)
@routes.get('/post/(?P<post_id>[^/]+)/?')
def get_post(ctx, params):
auth.verify_privilege(ctx.user, 'posts:view')
post = posts.get_post_by_id(params['post_id'])
return _serialize_post(ctx, post)
@routes.put('/post/(?P<post_id>[^/]+)/?')
def update_post(ctx, params):
post = posts.get_post_by_id(params['post_id'])
util.verify_version(post, ctx)
if ctx.has_file('content'):
auth.verify_privilege(ctx.user, 'posts:edit:content')
posts.update_post_content(post, ctx.get_file('content'))
if ctx.has_param('tags'):
auth.verify_privilege(ctx.user, 'posts:edit:tags')
new_tags = posts.update_post_tags(post, ctx.get_param_as_list('tags'))
if len(new_tags):
auth.verify_privilege(ctx.user, 'tags:create')
posts.update_post_safety(post, safety)
posts.update_post_source(post, source)
posts.update_post_relations(post, relations)
posts.update_post_notes(post, notes)
posts.update_post_flags(post, flags)
if ctx.has_file('thumbnail'):
posts.update_post_thumbnail(post, ctx.get_file('thumbnail'))
ctx.session.add(post)
snapshots.save_entity_creation(post, ctx.user)
ctx.session.commit()
tags.export_to_json()
return _serialize_post(ctx, post)
if ctx.has_param('safety'):
auth.verify_privilege(ctx.user, 'posts:edit:safety')
posts.update_post_safety(post, ctx.get_param_as_string('safety'))
if ctx.has_param('source'):
auth.verify_privilege(ctx.user, 'posts:edit:source')
posts.update_post_source(post, ctx.get_param_as_string('source'))
elif ctx.has_param('contentUrl'):
posts.update_post_source(post, ctx.get_param_as_string('contentUrl'))
if ctx.has_param('relations'):
auth.verify_privilege(ctx.user, 'posts:edit:relations')
posts.update_post_relations(post, ctx.get_param_as_list('relations'))
if ctx.has_param('notes'):
auth.verify_privilege(ctx.user, 'posts:edit:notes')
posts.update_post_notes(post, ctx.get_param_as_list('notes'))
if ctx.has_param('flags'):
auth.verify_privilege(ctx.user, 'posts:edit:flags')
posts.update_post_flags(post, ctx.get_param_as_list('flags'))
if ctx.has_file('thumbnail'):
auth.verify_privilege(ctx.user, 'posts:edit:thumbnail')
posts.update_post_thumbnail(post, ctx.get_file('thumbnail'))
util.bump_version(post)
post.last_edit_time = datetime.datetime.utcnow()
ctx.session.flush()
snapshots.save_entity_modification(post, ctx.user)
ctx.session.commit()
tags.export_to_json()
return _serialize_post(ctx, post)
class PostDetailApi(BaseApi):
def get(self, ctx, post_id):
auth.verify_privilege(ctx.user, 'posts:view')
post = posts.get_post_by_id(post_id)
return _serialize_post(ctx, post)
@routes.delete('/post/(?P<post_id>[^/]+)/?')
def delete_post(ctx, params):
auth.verify_privilege(ctx.user, 'posts:delete')
post = posts.get_post_by_id(params['post_id'])
util.verify_version(post, ctx)
snapshots.save_entity_deletion(post, ctx.user)
posts.delete(post)
ctx.session.commit()
tags.export_to_json()
return {}
def put(self, ctx, post_id):
post = posts.get_post_by_id(post_id)
util.verify_version(post, ctx)
if ctx.has_file('content'):
auth.verify_privilege(ctx.user, 'posts:edit:content')
posts.update_post_content(post, ctx.get_file('content'))
if ctx.has_param('tags'):
auth.verify_privilege(ctx.user, 'posts:edit:tags')
new_tags = posts.update_post_tags(post, ctx.get_param_as_list('tags'))
if len(new_tags):
auth.verify_privilege(ctx.user, 'tags:create')
if ctx.has_param('safety'):
auth.verify_privilege(ctx.user, 'posts:edit:safety')
posts.update_post_safety(post, ctx.get_param_as_string('safety'))
if ctx.has_param('source'):
auth.verify_privilege(ctx.user, 'posts:edit:source')
posts.update_post_source(post, ctx.get_param_as_string('source'))
elif ctx.has_param('contentUrl'):
posts.update_post_source(post, ctx.get_param_as_string('contentUrl'))
if ctx.has_param('relations'):
auth.verify_privilege(ctx.user, 'posts:edit:relations')
posts.update_post_relations(post, ctx.get_param_as_list('relations'))
if ctx.has_param('notes'):
auth.verify_privilege(ctx.user, 'posts:edit:notes')
posts.update_post_notes(post, ctx.get_param_as_list('notes'))
if ctx.has_param('flags'):
auth.verify_privilege(ctx.user, 'posts:edit:flags')
posts.update_post_flags(post, ctx.get_param_as_list('flags'))
if ctx.has_file('thumbnail'):
auth.verify_privilege(ctx.user, 'posts:edit:thumbnail')
posts.update_post_thumbnail(post, ctx.get_file('thumbnail'))
util.bump_version(post)
post.last_edit_time = datetime.datetime.utcnow()
ctx.session.flush()
snapshots.save_entity_modification(post, ctx.user)
ctx.session.commit()
tags.export_to_json()
return _serialize_post(ctx, post)
@routes.get('/featured-post/?')
def get_featured_post(ctx, _params=None):
post = posts.try_get_featured_post()
return _serialize_post(ctx, post)
def delete(self, ctx, post_id):
auth.verify_privilege(ctx.user, 'posts:delete')
post = posts.get_post_by_id(post_id)
util.verify_version(post, ctx)
snapshots.save_entity_deletion(post, ctx.user)
posts.delete(post)
ctx.session.commit()
tags.export_to_json()
return {}
@routes.post('/featured-post/?')
def set_featured_post(ctx, _params=None):
auth.verify_privilege(ctx.user, 'posts:feature')
post_id = ctx.get_param_as_int('id', required=True)
post = posts.get_post_by_id(post_id)
featured_post = posts.try_get_featured_post()
if featured_post and featured_post.post_id == post.post_id:
raise posts.PostAlreadyFeaturedError(
'Post %r is already featured.' % post_id)
posts.feature_post(post, ctx.user)
if featured_post:
snapshots.save_entity_modification(featured_post, ctx.user)
snapshots.save_entity_modification(post, ctx.user)
ctx.session.commit()
return _serialize_post(ctx, post)
class PostFeatureApi(BaseApi):
def post(self, ctx):
auth.verify_privilege(ctx.user, 'posts:feature')
post_id = ctx.get_param_as_int('id', required=True)
post = posts.get_post_by_id(post_id)
featured_post = posts.try_get_featured_post()
if featured_post and featured_post.post_id == post.post_id:
raise posts.PostAlreadyFeaturedError(
'Post %r is already featured.' % post_id)
posts.feature_post(post, ctx.user)
if featured_post:
snapshots.save_entity_modification(featured_post, ctx.user)
snapshots.save_entity_modification(post, ctx.user)
ctx.session.commit()
return _serialize_post(ctx, post)
@routes.put('/post/(?P<post_id>[^/]+)/score/?')
def set_post_score(ctx, params):
auth.verify_privilege(ctx.user, 'posts:score')
post = posts.get_post_by_id(params['post_id'])
score = ctx.get_param_as_int('score', required=True)
scores.set_score(post, ctx.user, score)
ctx.session.commit()
return _serialize_post(ctx, post)
def get(self, ctx):
post = posts.try_get_featured_post()
return _serialize_post(ctx, post)
@routes.delete('/post/(?P<post_id>[^/]+)/score/?')
def delete_post_score(ctx, params):
auth.verify_privilege(ctx.user, 'posts:score')
post = posts.get_post_by_id(params['post_id'])
scores.delete_score(post, ctx.user)
ctx.session.commit()
return _serialize_post(ctx, post)
class PostScoreApi(BaseApi):
def put(self, ctx, post_id):
auth.verify_privilege(ctx.user, 'posts:score')
post = posts.get_post_by_id(post_id)
score = ctx.get_param_as_int('score', required=True)
scores.set_score(post, ctx.user, score)
ctx.session.commit()
return _serialize_post(ctx, post)
@routes.post('/post/(?P<post_id>[^/]+)/favorite/?')
def add_post_to_favorites(ctx, params):
auth.verify_privilege(ctx.user, 'posts:favorite')
post = posts.get_post_by_id(params['post_id'])
favorites.set_favorite(post, ctx.user)
ctx.session.commit()
return _serialize_post(ctx, post)
def delete(self, ctx, post_id):
auth.verify_privilege(ctx.user, 'posts:score')
post = posts.get_post_by_id(post_id)
scores.delete_score(post, ctx.user)
ctx.session.commit()
return _serialize_post(ctx, post)
@routes.delete('/post/(?P<post_id>[^/]+)/favorite/?')
def delete_post_from_favorites(ctx, params):
auth.verify_privilege(ctx.user, 'posts:favorite')
post = posts.get_post_by_id(params['post_id'])
favorites.unset_favorite(post, ctx.user)
ctx.session.commit()
return _serialize_post(ctx, post)
class PostFavoriteApi(BaseApi):
def post(self, ctx, post_id):
auth.verify_privilege(ctx.user, 'posts:favorite')
post = posts.get_post_by_id(post_id)
favorites.set_favorite(post, ctx.user)
ctx.session.commit()
return _serialize_post(ctx, post)
def delete(self, ctx, post_id):
auth.verify_privilege(ctx.user, 'posts:favorite')
post = posts.get_post_by_id(post_id)
favorites.unset_favorite(post, ctx.user)
ctx.session.commit()
return _serialize_post(ctx, post)
class PostsAroundApi(BaseApi):
def __init__(self):
super().__init__()
self._search_executor = search.Executor(
search.configs.PostSearchConfig())
def get(self, ctx, post_id):
auth.verify_privilege(ctx.user, 'posts:list')
self._search_executor.config.user = ctx.user
return self._search_executor.get_around_and_serialize(
ctx, post_id, lambda post: _serialize_post(ctx, post))
@routes.get('/post/(?P<post_id>[^/]+)/around/?')
def get_posts_around(ctx, params):
auth.verify_privilege(ctx.user, 'posts:list')
_search_executor.config.user = ctx.user
return _search_executor.get_around_and_serialize(
ctx, params['post_id'], lambda post: _serialize_post(ctx, post))

View file

@ -1,14 +1,12 @@
from szurubooru import search
from szurubooru.api.base_api import BaseApi
from szurubooru.func import auth, snapshots
from szurubooru.rest import routes
class SnapshotListApi(BaseApi):
def __init__(self):
super().__init__()
self._search_executor = search.Executor(
search.configs.SnapshotSearchConfig())
_search_executor = search.Executor(
search.configs.SnapshotSearchConfig())
def get(self, ctx):
auth.verify_privilege(ctx.user, 'snapshots:list')
return self._search_executor.execute_and_serialize(
ctx, snapshots.serialize_snapshot)
@routes.get('/snapshots/?')
def get_snapshots(ctx, _params=None):
auth.verify_privilege(ctx.user, 'snapshots:list')
return _search_executor.execute_and_serialize(
ctx, snapshots.serialize_snapshot)

View file

@ -1,7 +1,9 @@
import datetime
from szurubooru import db, search
from szurubooru.api.base_api import BaseApi
from szurubooru.func import auth, tags, util, snapshots
from szurubooru.rest import routes
_search_executor = search.Executor(search.configs.TagSearchConfig())
def _serialize(ctx, tag):
return tags.serialize_tag(
@ -17,116 +19,112 @@ def _create_if_needed(tag_names, user):
for tag in new_tags:
snapshots.save_entity_creation(tag, user)
class TagListApi(BaseApi):
def __init__(self):
super().__init__()
self._search_executor = search.Executor(
search.configs.TagSearchConfig())
@routes.get('/tags/?')
def get_tags(ctx, _params=None):
auth.verify_privilege(ctx.user, 'tags:list')
return _search_executor.execute_and_serialize(
ctx, lambda tag: _serialize(ctx, tag))
def get(self, ctx):
auth.verify_privilege(ctx.user, 'tags:list')
return self._search_executor.execute_and_serialize(
ctx, lambda tag: _serialize(ctx, tag))
@routes.post('/tags/?')
def create_tag(ctx, _params=None):
auth.verify_privilege(ctx.user, 'tags:create')
def post(self, ctx):
auth.verify_privilege(ctx.user, 'tags:create')
names = ctx.get_param_as_list('names', required=True)
category = ctx.get_param_as_string('category', required=True)
description = ctx.get_param_as_string(
'description', required=False, default=None)
suggestions = ctx.get_param_as_list(
'suggestions', required=False, default=[])
implications = ctx.get_param_as_list(
'implications', required=False, default=[])
names = ctx.get_param_as_list('names', required=True)
category = ctx.get_param_as_string('category', required=True) or ''
description = ctx.get_param_as_string(
'description', required=False, default=None)
suggestions = ctx.get_param_as_list(
'suggestions', required=False, default=[])
implications = ctx.get_param_as_list(
'implications', required=False, default=[])
_create_if_needed(suggestions, ctx.user)
_create_if_needed(implications, ctx.user)
tag = tags.create_tag(names, category, suggestions, implications)
tags.update_tag_description(tag, description)
ctx.session.add(tag)
ctx.session.flush()
snapshots.save_entity_creation(tag, ctx.user)
ctx.session.commit()
tags.export_to_json()
return _serialize(ctx, tag)
@routes.get('/tag/(?P<tag_name>[^/]+)/?')
def get_tag(ctx, params):
auth.verify_privilege(ctx.user, 'tags:view')
tag = tags.get_tag_by_name(params['tag_name'])
return _serialize(ctx, tag)
@routes.put('/tag/(?P<tag_name>[^/]+)/?')
def update_tag(ctx, params):
tag = tags.get_tag_by_name(params['tag_name'])
util.verify_version(tag, ctx)
if ctx.has_param('names'):
auth.verify_privilege(ctx.user, 'tags:edit:names')
tags.update_tag_names(tag, ctx.get_param_as_list('names'))
if ctx.has_param('category'):
auth.verify_privilege(ctx.user, 'tags:edit:category')
tags.update_tag_category_name(
tag, ctx.get_param_as_string('category'))
if ctx.has_param('description'):
auth.verify_privilege(ctx.user, 'tags:edit:description')
tags.update_tag_description(
tag, ctx.get_param_as_string('description', default=None))
if ctx.has_param('suggestions'):
auth.verify_privilege(ctx.user, 'tags:edit:suggestions')
suggestions = ctx.get_param_as_list('suggestions')
_create_if_needed(suggestions, ctx.user)
tags.update_tag_suggestions(tag, suggestions)
if ctx.has_param('implications'):
auth.verify_privilege(ctx.user, 'tags:edit:implications')
implications = ctx.get_param_as_list('implications')
_create_if_needed(implications, ctx.user)
tags.update_tag_implications(tag, implications)
util.bump_version(tag)
tag.last_edit_time = datetime.datetime.utcnow()
ctx.session.flush()
snapshots.save_entity_modification(tag, ctx.user)
ctx.session.commit()
tags.export_to_json()
return _serialize(ctx, tag)
tag = tags.create_tag(names, category, suggestions, implications)
tags.update_tag_description(tag, description)
ctx.session.add(tag)
ctx.session.flush()
snapshots.save_entity_creation(tag, ctx.user)
ctx.session.commit()
tags.export_to_json()
return _serialize(ctx, tag)
@routes.delete('/tag/(?P<tag_name>[^/]+)/?')
def delete_tag(ctx, params):
tag = tags.get_tag_by_name(params['tag_name'])
util.verify_version(tag, ctx)
auth.verify_privilege(ctx.user, 'tags:delete')
snapshots.save_entity_deletion(tag, ctx.user)
tags.delete(tag)
ctx.session.commit()
tags.export_to_json()
return {}
class TagDetailApi(BaseApi):
def get(self, ctx, tag_name):
auth.verify_privilege(ctx.user, 'tags:view')
tag = tags.get_tag_by_name(tag_name)
return _serialize(ctx, tag)
@routes.post('/tag-merge/?')
def merge_tags(ctx, _params=None):
source_tag_name = ctx.get_param_as_string('remove', required=True) or ''
target_tag_name = ctx.get_param_as_string('mergeTo', required=True) or ''
source_tag = tags.get_tag_by_name(source_tag_name)
target_tag = tags.get_tag_by_name(target_tag_name)
util.verify_version(source_tag, ctx, 'removeVersion')
util.verify_version(target_tag, ctx, 'mergeToVersion')
auth.verify_privilege(ctx.user, 'tags:merge')
tags.merge_tags(source_tag, target_tag)
snapshots.save_entity_deletion(source_tag, ctx.user)
util.bump_version(target_tag)
ctx.session.commit()
tags.export_to_json()
return _serialize(ctx, target_tag)
def put(self, ctx, tag_name):
tag = tags.get_tag_by_name(tag_name)
util.verify_version(tag, ctx)
if ctx.has_param('names'):
auth.verify_privilege(ctx.user, 'tags:edit:names')
tags.update_tag_names(tag, ctx.get_param_as_list('names'))
if ctx.has_param('category'):
auth.verify_privilege(ctx.user, 'tags:edit:category')
tags.update_tag_category_name(
tag, ctx.get_param_as_string('category') or '')
if ctx.has_param('description'):
auth.verify_privilege(ctx.user, 'tags:edit:description')
tags.update_tag_description(
tag, ctx.get_param_as_string('description', default=None))
if ctx.has_param('suggestions'):
auth.verify_privilege(ctx.user, 'tags:edit:suggestions')
suggestions = ctx.get_param_as_list('suggestions')
_create_if_needed(suggestions, ctx.user)
tags.update_tag_suggestions(tag, suggestions)
if ctx.has_param('implications'):
auth.verify_privilege(ctx.user, 'tags:edit:implications')
implications = ctx.get_param_as_list('implications')
_create_if_needed(implications, ctx.user)
tags.update_tag_implications(tag, implications)
util.bump_version(tag)
tag.last_edit_time = datetime.datetime.utcnow()
ctx.session.flush()
snapshots.save_entity_modification(tag, ctx.user)
ctx.session.commit()
tags.export_to_json()
return _serialize(ctx, tag)
def delete(self, ctx, tag_name):
tag = tags.get_tag_by_name(tag_name)
util.verify_version(tag, ctx)
auth.verify_privilege(ctx.user, 'tags:delete')
snapshots.save_entity_deletion(tag, ctx.user)
tags.delete(tag)
ctx.session.commit()
tags.export_to_json()
return {}
class TagMergeApi(BaseApi):
def post(self, ctx):
source_tag_name = ctx.get_param_as_string('remove', required=True) or ''
target_tag_name = ctx.get_param_as_string('mergeTo', required=True) or ''
source_tag = tags.get_tag_by_name(source_tag_name)
target_tag = tags.get_tag_by_name(target_tag_name)
util.verify_version(source_tag, ctx, 'removeVersion')
util.verify_version(target_tag, ctx, 'mergeToVersion')
if source_tag.tag_id == target_tag.tag_id:
raise tags.InvalidTagRelationError('Cannot merge tag with itself.')
auth.verify_privilege(ctx.user, 'tags:merge')
snapshots.save_entity_deletion(source_tag, ctx.user)
tags.merge_tags(source_tag, target_tag)
util.bump_version(target_tag)
ctx.session.commit()
tags.export_to_json()
return _serialize(ctx, target_tag)
class TagSiblingsApi(BaseApi):
def get(self, ctx, tag_name):
auth.verify_privilege(ctx.user, 'tags:view')
tag = tags.get_tag_by_name(tag_name)
result = tags.get_tag_siblings(tag)
serialized_siblings = []
for sibling, occurrences in result:
serialized_siblings.append({
'tag': _serialize(ctx, sibling),
'occurrences': occurrences
})
return {'results': serialized_siblings}
@routes.get('/tag-siblings/(?P<tag_name>[^/]+)/?')
def get_tag_siblings(ctx, params):
auth.verify_privilege(ctx.user, 'tags:view')
tag = tags.get_tag_by_name(params['tag_name'])
result = tags.get_tag_siblings(tag)
serialized_siblings = []
for sibling, occurrences in result:
serialized_siblings.append({
'tag': _serialize(ctx, sibling),
'occurrences': occurrences
})
return {'results': serialized_siblings}

View file

@ -1,70 +1,73 @@
from szurubooru.api.base_api import BaseApi
from szurubooru.rest import routes
from szurubooru.func import auth, tags, tag_categories, util, snapshots
def _serialize(ctx, category):
return tag_categories.serialize_category(
category, options=util.get_serialization_options(ctx))
class TagCategoryListApi(BaseApi):
def get(self, ctx):
auth.verify_privilege(ctx.user, 'tag_categories:list')
categories = tag_categories.get_all_categories()
return {
'results': [_serialize(ctx, category) for category in categories],
}
@routes.get('/tag-categories/?')
def get_tag_categories(ctx, _params=None):
auth.verify_privilege(ctx.user, 'tag_categories:list')
categories = tag_categories.get_all_categories()
return {
'results': [_serialize(ctx, category) for category in categories],
}
def post(self, ctx):
auth.verify_privilege(ctx.user, 'tag_categories:create')
name = ctx.get_param_as_string('name', required=True)
color = ctx.get_param_as_string('color', required=True)
category = tag_categories.create_category(name, color)
ctx.session.add(category)
ctx.session.flush()
snapshots.save_entity_creation(category, ctx.user)
ctx.session.commit()
tags.export_to_json()
return _serialize(ctx, category)
@routes.post('/tag-categories/?')
def create_tag_category(ctx, _params=None):
auth.verify_privilege(ctx.user, 'tag_categories:create')
name = ctx.get_param_as_string('name', required=True)
color = ctx.get_param_as_string('color', required=True)
category = tag_categories.create_category(name, color)
ctx.session.add(category)
ctx.session.flush()
snapshots.save_entity_creation(category, ctx.user)
ctx.session.commit()
tags.export_to_json()
return _serialize(ctx, category)
class TagCategoryDetailApi(BaseApi):
def get(self, ctx, category_name):
auth.verify_privilege(ctx.user, 'tag_categories:view')
category = tag_categories.get_category_by_name(category_name)
return _serialize(ctx, category)
@routes.get('/tag-category/(?P<category_name>[^/]+)/?')
def get_tag_category(ctx, params):
auth.verify_privilege(ctx.user, 'tag_categories:view')
category = tag_categories.get_category_by_name(params['category_name'])
return _serialize(ctx, category)
def put(self, ctx, category_name):
category = tag_categories.get_category_by_name(category_name)
util.verify_version(category, ctx)
if ctx.has_param('name'):
auth.verify_privilege(ctx.user, 'tag_categories:edit:name')
tag_categories.update_category_name(
category, ctx.get_param_as_string('name'))
if ctx.has_param('color'):
auth.verify_privilege(ctx.user, 'tag_categories:edit:color')
tag_categories.update_category_color(
category, ctx.get_param_as_string('color'))
util.bump_version(category)
ctx.session.flush()
snapshots.save_entity_modification(category, ctx.user)
ctx.session.commit()
tags.export_to_json()
return _serialize(ctx, category)
@routes.put('/tag-category/(?P<category_name>[^/]+)/?')
def update_tag_category(ctx, params):
category = tag_categories.get_category_by_name(params['category_name'])
util.verify_version(category, ctx)
if ctx.has_param('name'):
auth.verify_privilege(ctx.user, 'tag_categories:edit:name')
tag_categories.update_category_name(
category, ctx.get_param_as_string('name'))
if ctx.has_param('color'):
auth.verify_privilege(ctx.user, 'tag_categories:edit:color')
tag_categories.update_category_color(
category, ctx.get_param_as_string('color'))
util.bump_version(category)
ctx.session.flush()
snapshots.save_entity_modification(category, ctx.user)
ctx.session.commit()
tags.export_to_json()
return _serialize(ctx, category)
def delete(self, ctx, category_name):
category = tag_categories.get_category_by_name(category_name)
util.verify_version(category, ctx)
auth.verify_privilege(ctx.user, 'tag_categories:delete')
tag_categories.delete_category(category)
snapshots.save_entity_deletion(category, ctx.user)
ctx.session.commit()
tags.export_to_json()
return {}
@routes.delete('/tag-category/(?P<category_name>[^/]+)/?')
def delete_tag_category(ctx, params):
category = tag_categories.get_category_by_name(params['category_name'])
util.verify_version(category, ctx)
auth.verify_privilege(ctx.user, 'tag_categories:delete')
tag_categories.delete_category(category)
snapshots.save_entity_deletion(category, ctx.user)
ctx.session.commit()
tags.export_to_json()
return {}
class DefaultTagCategoryApi(BaseApi):
def put(self, ctx, category_name):
auth.verify_privilege(ctx.user, 'tag_categories:set_default')
category = tag_categories.get_category_by_name(category_name)
tag_categories.set_default_category(category)
snapshots.save_entity_modification(category, ctx.user)
ctx.session.commit()
tags.export_to_json()
return _serialize(ctx, category)
@routes.put('/tag-category/(?P<category_name>[^/]+)/default/?')
def set_tag_category_as_default(ctx, params):
auth.verify_privilege(ctx.user, 'tag_categories:set_default')
category = tag_categories.get_category_by_name(params['category_name'])
tag_categories.set_default_category(category)
snapshots.save_entity_modification(category, ctx.user)
ctx.session.commit()
tags.export_to_json()
return _serialize(ctx, category)

View file

@ -1,6 +1,8 @@
from szurubooru import search
from szurubooru.api.base_api import BaseApi
from szurubooru.func import auth, users, util
from szurubooru.rest import routes
_search_executor = search.Executor(search.configs.UserSearchConfig())
def _serialize(ctx, user, **kwargs):
return users.serialize_user(
@ -9,75 +11,73 @@ def _serialize(ctx, user, **kwargs):
options=util.get_serialization_options(ctx),
**kwargs)
class UserListApi(BaseApi):
def __init__(self):
super().__init__()
self._search_executor = search.Executor(
search.configs.UserSearchConfig())
@routes.get('/users/?')
def get_users(ctx, _params=None):
auth.verify_privilege(ctx.user, 'users:list')
return _search_executor.execute_and_serialize(
ctx, lambda user: _serialize(ctx, user))
def get(self, ctx):
auth.verify_privilege(ctx.user, 'users:list')
return self._search_executor.execute_and_serialize(
ctx, lambda user: _serialize(ctx, user))
@routes.post('/users/?')
def create_user(ctx, _params=None):
auth.verify_privilege(ctx.user, 'users:create')
name = ctx.get_param_as_string('name', required=True)
password = ctx.get_param_as_string('password', required=True)
email = ctx.get_param_as_string('email', required=False, default='')
user = users.create_user(name, password, email)
if ctx.has_param('rank'):
users.update_user_rank(
user, ctx.get_param_as_string('rank'), ctx.user)
if ctx.has_param('avatarStyle'):
users.update_user_avatar(
user,
ctx.get_param_as_string('avatarStyle'),
ctx.get_file('avatar'))
ctx.session.add(user)
ctx.session.commit()
return _serialize(ctx, user, force_show_email=True)
def post(self, ctx):
auth.verify_privilege(ctx.user, 'users:create')
name = ctx.get_param_as_string('name', required=True)
password = ctx.get_param_as_string('password', required=True)
email = ctx.get_param_as_string('email', required=False, default='')
user = users.create_user(name, password, email)
if ctx.has_param('rank'):
users.update_user_rank(
user, ctx.get_param_as_string('rank'), ctx.user)
if ctx.has_param('avatarStyle'):
users.update_user_avatar(
user,
ctx.get_param_as_string('avatarStyle'),
ctx.get_file('avatar'))
ctx.session.add(user)
ctx.session.commit()
return _serialize(ctx, user, force_show_email=True)
@routes.get('/user/(?P<user_name>[^/]+)/?')
def get_user(ctx, params):
user = users.get_user_by_name(params['user_name'])
if ctx.user.user_id != user.user_id:
auth.verify_privilege(ctx.user, 'users:view')
return _serialize(ctx, user)
class UserDetailApi(BaseApi):
def get(self, ctx, user_name):
user = users.get_user_by_name(user_name)
if ctx.user.user_id != user.user_id:
auth.verify_privilege(ctx.user, 'users:view')
return _serialize(ctx, user)
@routes.put('/user/(?P<user_name>[^/]+)/?')
def update_user(ctx, params):
user = users.get_user_by_name(params['user_name'])
util.verify_version(user, ctx)
infix = 'self' if ctx.user.user_id == user.user_id else 'any'
if ctx.has_param('name'):
auth.verify_privilege(ctx.user, 'users:edit:%s:name' % infix)
users.update_user_name(user, ctx.get_param_as_string('name'))
if ctx.has_param('password'):
auth.verify_privilege(ctx.user, 'users:edit:%s:pass' % infix)
users.update_user_password(
user, ctx.get_param_as_string('password'))
if ctx.has_param('email'):
auth.verify_privilege(ctx.user, 'users:edit:%s:email' % infix)
users.update_user_email(user, ctx.get_param_as_string('email'))
if ctx.has_param('rank'):
auth.verify_privilege(ctx.user, 'users:edit:%s:rank' % infix)
users.update_user_rank(
user, ctx.get_param_as_string('rank'), ctx.user)
if ctx.has_param('avatarStyle'):
auth.verify_privilege(ctx.user, 'users:edit:%s:avatar' % infix)
users.update_user_avatar(
user,
ctx.get_param_as_string('avatarStyle'),
ctx.get_file('avatar'))
util.bump_version(user)
ctx.session.commit()
return _serialize(ctx, user)
def put(self, ctx, user_name):
user = users.get_user_by_name(user_name)
util.verify_version(user, ctx)
infix = 'self' if ctx.user.user_id == user.user_id else 'any'
if ctx.has_param('name'):
auth.verify_privilege(ctx.user, 'users:edit:%s:name' % infix)
users.update_user_name(user, ctx.get_param_as_string('name'))
if ctx.has_param('password'):
auth.verify_privilege(ctx.user, 'users:edit:%s:pass' % infix)
users.update_user_password(
user, ctx.get_param_as_string('password'))
if ctx.has_param('email'):
auth.verify_privilege(ctx.user, 'users:edit:%s:email' % infix)
users.update_user_email(user, ctx.get_param_as_string('email'))
if ctx.has_param('rank'):
auth.verify_privilege(ctx.user, 'users:edit:%s:rank' % infix)
users.update_user_rank(
user, ctx.get_param_as_string('rank'), ctx.user)
if ctx.has_param('avatarStyle'):
auth.verify_privilege(ctx.user, 'users:edit:%s:avatar' % infix)
users.update_user_avatar(
user,
ctx.get_param_as_string('avatarStyle'),
ctx.get_file('avatar'))
util.bump_version(user)
ctx.session.commit()
return _serialize(ctx, user)
def delete(self, ctx, user_name):
user = users.get_user_by_name(user_name)
util.verify_version(user, ctx)
infix = 'self' if ctx.user.user_id == user.user_id else 'any'
auth.verify_privilege(ctx.user, 'users:delete:%s' % infix)
ctx.session.delete(user)
ctx.session.commit()
return {}
@routes.delete('/user/(?P<user_name>[^/]+)/?')
def delete_user(ctx, params):
user = users.get_user_by_name(params['user_name'])
util.verify_version(user, ctx)
infix = 'self' if ctx.user.user_id == user.user_id else 'any'
auth.verify_privilege(ctx.user, 'users:delete:%s' % infix)
ctx.session.delete(user)
ctx.session.commit()
return {}

View file

@ -1,124 +0,0 @@
''' Exports create_app. '''
import os
import logging
import coloredlogs
import falcon
from szurubooru import api, config, errors, middleware
def _on_auth_error(ex, _request, _response, _params):
raise falcon.HTTPForbidden(
title='Authentication error', description=str(ex))
def _on_validation_error(ex, _request, _response, _params):
raise falcon.HTTPBadRequest(title='Validation error', description=str(ex))
def _on_search_error(ex, _request, _response, _params):
raise falcon.HTTPBadRequest(title='Search error', description=str(ex))
def _on_integrity_error(ex, _request, _response, _params):
raise falcon.HTTPConflict(
title='Integrity violation', description=ex.args[0])
def _on_not_found_error(ex, _request, _response, _params):
raise falcon.HTTPNotFound(title='Not found', description=str(ex))
def _on_processing_error(ex, _request, _response, _params):
raise falcon.HTTPBadRequest(title='Processing error', description=str(ex))
def create_method_not_allowed(allowed_methods):
allowed = ', '.join(allowed_methods)
def method_not_allowed(request, response, **_kwargs):
response.status = falcon.status_codes.HTTP_405
response.set_header('Allow', allowed)
request.context.output = {
'title': 'Method not allowed',
'description': 'Allowed methods: %r' % allowed_methods,
}
return method_not_allowed
def validate_config():
'''
Check whether config doesn't contain errors that might prove
lethal at runtime.
'''
from szurubooru.func.auth import RANK_MAP
for privilege, rank in config.config['privileges'].items():
if rank not in RANK_MAP.values():
raise errors.ConfigError(
'Rank %r for privilege %r is missing' % (rank, privilege))
if config.config['default_rank'] not in RANK_MAP.values():
raise errors.ConfigError(
'Default rank %r is not on the list of known ranks' % (
config.config['default_rank']))
for key in ['base_url', 'api_url', 'data_url', 'data_dir']:
if not config.config[key]:
raise errors.ConfigError(
'Service is not configured: %r is missing' % key)
if not os.path.isabs(config.config['data_dir']):
raise errors.ConfigError(
'data_dir must be an absolute path')
for key in ['schema', 'host', 'port', 'user', 'pass', 'name']:
if not config.config['database'][key]:
raise errors.ConfigError(
'Database is not configured: %r is missing' % key)
def create_app():
''' Create a WSGI compatible App object. '''
validate_config()
falcon.responders.create_method_not_allowed = create_method_not_allowed
coloredlogs.install(fmt='[%(asctime)-15s] %(name)s %(message)s')
if config.config['debug']:
logging.getLogger('szurubooru').setLevel(logging.INFO)
if config.config['show_sql']:
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
app = falcon.API(
request_type=api.Request,
middleware=[
middleware.RequireJson(),
middleware.CachePurger(),
middleware.ContextAdapter(),
middleware.DbSession(),
middleware.Authenticator(),
middleware.RequestLogger(),
])
app.add_error_handler(errors.AuthError, _on_auth_error)
app.add_error_handler(errors.IntegrityError, _on_integrity_error)
app.add_error_handler(errors.ValidationError, _on_validation_error)
app.add_error_handler(errors.SearchError, _on_search_error)
app.add_error_handler(errors.NotFoundError, _on_not_found_error)
app.add_error_handler(errors.ProcessingError, _on_processing_error)
app.add_route('/users/', api.UserListApi())
app.add_route('/user/{user_name}', api.UserDetailApi())
app.add_route('/password-reset/{user_name}', api.PasswordResetApi())
app.add_route('/tag-categories/', api.TagCategoryListApi())
app.add_route('/tag-category/{category_name}', api.TagCategoryDetailApi())
app.add_route('/tag-category/{category_name}/default', api.DefaultTagCategoryApi())
app.add_route('/tags/', api.TagListApi())
app.add_route('/tag/{tag_name}', api.TagDetailApi())
app.add_route('/tag-merge/', api.TagMergeApi())
app.add_route('/tag-siblings/{tag_name}', api.TagSiblingsApi())
app.add_route('/posts/', api.PostListApi())
app.add_route('/post/{post_id}', api.PostDetailApi())
app.add_route('/post/{post_id}/score', api.PostScoreApi())
app.add_route('/post/{post_id}/favorite', api.PostFavoriteApi())
app.add_route('/post/{post_id}/around', api.PostsAroundApi())
app.add_route('/comments/', api.CommentListApi())
app.add_route('/comment/{comment_id}', api.CommentDetailApi())
app.add_route('/comment/{comment_id}/score', api.CommentScoreApi())
app.add_route('/info/', api.InfoApi())
app.add_route('/featured-post/', api.PostFeatureApi())
app.add_route('/snapshots/', api.SnapshotListApi())
return app

View file

@ -0,0 +1,79 @@
''' Exports create_app. '''
import os
import logging
import coloredlogs
from szurubooru import config, errors, rest
# pylint: disable=unused-import
from szurubooru import api, middleware
def _on_auth_error(ex):
raise rest.errors.HttpForbidden(
title='Authentication error', description=str(ex))
def _on_validation_error(ex):
raise rest.errors.HttpBadRequest(
title='Validation error', description=str(ex))
def _on_search_error(ex):
raise rest.errors.HttpBadRequest(
title='Search error', description=str(ex))
def _on_integrity_error(ex):
raise rest.errors.HttpConflict(
title='Integrity violation', description=ex.args[0])
def _on_not_found_error(ex):
raise rest.errors.HttpNotFound(
title='Not found', description=str(ex))
def _on_processing_error(ex):
raise rest.errors.HttpBadRequest(
title='Processing error', description=str(ex))
def validate_config():
'''
Check whether config doesn't contain errors that might prove
lethal at runtime.
'''
from szurubooru.func.auth import RANK_MAP
for privilege, rank in config.config['privileges'].items():
if rank not in RANK_MAP.values():
raise errors.ConfigError(
'Rank %r for privilege %r is missing' % (rank, privilege))
if config.config['default_rank'] not in RANK_MAP.values():
raise errors.ConfigError(
'Default rank %r is not on the list of known ranks' % (
config.config['default_rank']))
for key in ['base_url', 'api_url', 'data_url', 'data_dir']:
if not config.config[key]:
raise errors.ConfigError(
'Service is not configured: %r is missing' % key)
if not os.path.isabs(config.config['data_dir']):
raise errors.ConfigError(
'data_dir must be an absolute path')
for key in ['schema', 'host', 'port', 'user', 'pass', 'name']:
if not config.config['database'][key]:
raise errors.ConfigError(
'Database is not configured: %r is missing' % key)
def create_app():
''' Create a WSGI compatible App object. '''
validate_config()
coloredlogs.install(fmt='[%(asctime)-15s] %(name)s %(message)s')
if config.config['debug']:
logging.getLogger('szurubooru').setLevel(logging.INFO)
if config.config['show_sql']:
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
rest.errors.handle(errors.AuthError, _on_auth_error)
rest.errors.handle(errors.ValidationError, _on_validation_error)
rest.errors.handle(errors.SearchError, _on_search_error)
rest.errors.handle(errors.IntegrityError, _on_integrity_error)
rest.errors.handle(errors.NotFoundError, _on_not_found_error)
rest.errors.handle(errors.ProcessingError, _on_processing_error)
return rest.application

View file

@ -32,13 +32,6 @@ def get_tag_category_snapshot(category):
'default': True if category.default else False,
}
# pylint: disable=invalid-name
serializers = {
'tag': get_tag_snapshot,
'tag_category': get_tag_category_snapshot,
'post': get_post_snapshot,
}
def get_previous_snapshot(snapshot):
assert snapshot
return db.session \
@ -87,6 +80,12 @@ def get_serialized_history(entity):
def _save(operation, entity, auth_user):
assert operation
assert entity
serializers = {
'tag': get_tag_snapshot,
'tag_category': get_tag_category_snapshot,
'post': get_post_snapshot,
}
resource_type, resource_id, resource_repr = db.util.get_resource_info(entity)
now = datetime.datetime.utcnow()

View file

@ -11,6 +11,9 @@ def snake_case_to_lower_camel_case(text):
return components[0].lower() + \
''.join(word[0].upper() + word[1:].lower() for word in components[1:])
def snake_case_to_upper_train_case(text):
return '-'.join(word[0].upper() + word[1:].lower() for word in text.split('_'))
def snake_case_to_lower_camel_case_keys(source):
target = {}
for key, value in source.items():

View file

@ -1,8 +1,6 @@
''' Various hooks that get executed for each request. '''
from szurubooru.middleware.authenticator import Authenticator
from szurubooru.middleware.context_adapter import ContextAdapter
from szurubooru.middleware.require_json import RequireJson
from szurubooru.middleware.db_session import DbSession
from szurubooru.middleware.cache_purger import CachePurger
from szurubooru.middleware.request_logger import RequestLogger
import szurubooru.middleware.db_session
import szurubooru.middleware.authenticator
import szurubooru.middleware.cache_purger
import szurubooru.middleware.request_logger

View file

@ -1,51 +1,44 @@
import base64
import falcon
from szurubooru import db, errors
from szurubooru.func import auth, users
from szurubooru.rest import middleware
from szurubooru.rest.errors import HttpBadRequest
class Authenticator(object):
'''
Authenticates every request and put information on active user in the
request context.
'''
def _authenticate(username, password):
''' Try to authenticate user. Throw AuthError for invalid users. '''
user = users.get_user_by_name(username)
if not auth.is_valid_password(user, password):
raise errors.AuthError('Invalid password.')
return user
def process_request(self, request, _response):
''' Bind the user to request. Update last login time if needed. '''
request.context.user = self._get_user(request)
if request.get_param_as_bool('bump-login') \
and request.context.user.user_id:
users.bump_user_login_time(request.context.user)
request.context.session.commit()
def _create_anonymous_user():
user = db.User()
user.name = None
user.rank = 'anonymous'
return user
def _get_user(self, request):
if not request.auth:
return self._create_anonymous_user()
def _get_user(ctx):
if not ctx.has_header('Authorization'):
return _create_anonymous_user()
try:
auth_type, user_and_password = request.auth.split(' ', 1)
if auth_type.lower() != 'basic':
raise falcon.HTTPBadRequest(
'Invalid authentication type',
'Only basic authorization is supported.')
username, password = base64.decodebytes(
user_and_password.encode('ascii')).decode('utf8').split(':')
return self._authenticate(username, password)
except ValueError as err:
msg = 'Basic authentication header value not properly formed. ' \
+ 'Supplied header {0}. Got error: {1}'
raise falcon.HTTPBadRequest(
'Malformed authentication request',
msg.format(request.auth, str(err)))
try:
auth_type, user_and_password = ctx.get_header('Authorization').split(' ', 1)
if auth_type.lower() != 'basic':
raise HttpBadRequest(
'Only basic HTTP authentication is supported.')
username, password = base64.decodebytes(
user_and_password.encode('ascii')).decode('utf8').split(':')
return _authenticate(username, password)
except ValueError as err:
msg = 'Basic authentication header value are not properly formed. ' \
+ 'Supplied header {0}. Got error: {1}'
raise HttpBadRequest(
msg.format(ctx.get_header('Authorization'), str(err)))
def _authenticate(self, username, password):
''' Try to authenticate user. Throw AuthError for invalid users. '''
user = users.get_user_by_name(username)
if not auth.is_valid_password(user, password):
raise errors.AuthError('Invalid password.')
return user
def _create_anonymous_user(self):
user = db.User()
user.name = None
user.rank = 'anonymous'
return user
@middleware.pre_hook
def process_request(ctx):
''' Bind the user to request. Update last login time if needed. '''
ctx.user = _get_user(ctx)
if ctx.get_param_as_bool('bump-login') and ctx.user.user_id:
users.bump_user_login_time(ctx.user)
ctx.session.commit()

View file

@ -1,6 +1,7 @@
from szurubooru.func import cache
from szurubooru.rest import middleware
class CachePurger(object):
def process_request(self, request, _response):
if request.method != 'GET':
cache.purge()
@middleware.pre_hook
def process_request(ctx):
if ctx.method != 'GET':
cache.purge()

View file

@ -1,65 +0,0 @@
import cgi
import datetime
import json
import falcon
def json_serializer(obj):
''' JSON serializer for objects not serializable by default JSON code '''
if isinstance(obj, datetime.datetime):
serial = obj.isoformat('T') + 'Z'
return serial
raise TypeError('Type not serializable')
class ContextAdapter(object):
'''
1. Deserialize API requests into the context:
- Pass GET parameters
- Handle multipart/form-data file uploads
- Handle JSON requests
2. Serialize API responses from the context as JSON.
'''
def process_request(self, request, _response):
request.context.files = {}
request.context.input = {}
request.context.output = None
# pylint: disable=protected-access
for key, value in request._params.items():
request.context.input[key] = value
if request.content_length in (None, 0):
return
if request.content_type and 'multipart/form-data' in request.content_type:
# obscure, claims to "avoid a bug in cgi.FieldStorage"
request.env.setdefault('QUERY_STRING', '')
form = cgi.FieldStorage(fp=request.stream, environ=request.env)
for key in form:
if key != 'metadata':
_original_file_name = getattr(form[key], 'filename', None)
request.context.files[key] = form.getvalue(key)
body = form.getvalue('metadata')
else:
body = request.stream.read()
if not body:
raise falcon.HTTPBadRequest(
'Empty request body',
'A valid JSON document is required.')
try:
if isinstance(body, bytes):
body = body.decode('utf-8')
for key, value in json.loads(body).items():
request.context.input[key] = value
except (ValueError, UnicodeDecodeError):
raise falcon.HTTPBadRequest(
'Malformed JSON',
'Could not decode the request body. The '
'JSON was incorrect or not encoded as UTF-8.')
def process_response(self, request, response, _resource):
if request.context.output:
response.body = json.dumps(
request.context.output, default=json_serializer, indent=2)

View file

@ -1,14 +1,11 @@
import logging
from szurubooru import db
from szurubooru.rest import middleware
logger = logging.getLogger(__name__)
@middleware.pre_hook
def _process_request(ctx):
ctx.session = db.session()
db.reset_query_count()
class DbSession(object):
''' Attaches database session to the context of every request. '''
def process_request(self, request, _response):
request.context.session = db.session()
db.reset_query_count()
def process_response(self, _request, _response, _resource):
db.session.remove()
@middleware.post_hook
def _process_response(_ctx):
db.session.remove()

View file

@ -1,16 +1,14 @@
import logging
from szurubooru import db
from szurubooru.rest import middleware
logger = logging.getLogger(__name__)
class RequestLogger(object):
def process_request(self, request, _response):
pass
def process_response(self, request, _response, _resource):
logger.info(
'%s %s (user=%s, queries=%d)',
request.method,
request.url,
request.context.user.name,
db.get_query_count())
@middleware.post_hook
def process_response(ctx):
logger.info(
'%s %s (user=%s, queries=%d)',
ctx.method,
ctx.url,
ctx.user.name,
db.get_query_count())

View file

@ -1,9 +0,0 @@
import falcon
class RequireJson(object):
''' Sanitizes requests so that only JSON is accepted. '''
def process_request(self, request, _response):
if not request.client_accepts_json:
raise falcon.HTTPNotAcceptable(
'This API only supports responses encoded as JSON.')

View file

@ -0,0 +1,2 @@
from szurubooru.rest.app import application
from szurubooru.rest.context import Context

View file

@ -0,0 +1,124 @@
import cgi
import io
import json
import re
from datetime import datetime
from szurubooru.func import util
from szurubooru.rest import errors, middleware, routes, context
def _json_serializer(obj):
''' JSON serializer for objects not serializable by default JSON code '''
if isinstance(obj, datetime):
serial = obj.isoformat('T') + 'Z'
return serial
raise TypeError('Type not serializable')
def _dump_json(obj):
return json.dumps(obj, default=_json_serializer, indent=2)
def _read(env):
length = int(env.get('CONTENT_LENGTH', 0))
output = io.BytesIO()
while length > 0:
part = env['wsgi.input'].read(min(length, 1024*200))
if not part:
break
output.write(part)
length -= len(part)
output.seek(0)
return output
def _get_headers(env):
headers = {}
for key, value in env.items():
if key.startswith('HTTP_'):
key = util.snake_case_to_upper_train_case(key[5:])
headers[key] = value
return headers
def _create_context(env):
method = env['REQUEST_METHOD']
path = '/' + env['PATH_INFO'].lstrip('/')
headers = _get_headers(env)
# obscure, claims to "avoid a bug in cgi.FieldStorage"
env.setdefault('QUERY_STRING', '')
files = {}
params = {}
request_stream = _read(env)
form = cgi.FieldStorage(fp=request_stream, environ=env)
if form.list:
for key in form:
if key != 'metadata':
if isinstance(form[key], cgi.MiniFieldStorage):
params[key] = form.getvalue(key)
else:
_original_file_name = getattr(form[key], 'filename', None)
files[key] = form.getvalue(key)
if 'metadata' in form:
body = form.getvalue('metadata')
else:
body = request_stream.read()
else:
body = None
if body:
try:
if isinstance(body, bytes):
body = body.decode('utf-8')
for key, value in json.loads(body).items():
params[key] = value
except (ValueError, UnicodeDecodeError):
raise errors.HttpBadRequest(
'Could not decode the request body. The JSON '
'was incorrect or was not encoded as UTF-8.')
return context.Context(method, path, headers, params, files)
def application(env, start_response):
try:
ctx = _create_context(env)
if not 'application/json' in ctx.get_header('Accept'):
raise errors.HttpNotAcceptable(
'This API only supports JSON responses.')
for url, allowed_methods in routes.routes.items():
match = re.fullmatch(url, ctx.url)
if not match:
continue
if ctx.method not in allowed_methods:
raise errors.HttpMethodNotAllowed(
'Allowed methods: %r' % allowed_methods)
for hook in middleware.pre_hooks:
hook(ctx)
handler = allowed_methods[ctx.method]
try:
response = handler(ctx, match.groupdict())
except Exception as ex:
for exception_type, handler in errors.error_handlers.items():
if isinstance(ex, exception_type):
handler(ex)
raise
finally:
for hook in middleware.post_hooks:
hook(ctx)
start_response('200', [('content-type', 'application/json')])
return (_dump_json(response).encode('utf-8'),)
raise errors.HttpNotFound(
'Requested path ' + ctx.url + ' was not found.')
except errors.BaseHttpError as ex:
start_response(
'%d %s' % (ex.code, ex.reason),
[('content-type', 'application/json')])
return (_dump_json({
'title': ex.title,
'description': ex.description,
}).encode('utf-8'),)

View file

@ -1,4 +1,3 @@
import falcon
from szurubooru import errors
from szurubooru.func import net
@ -7,8 +6,9 @@ def _lower_first(source):
def _param_wrapper(func):
def wrapper(self, name, required=False, default=None, **kwargs):
if name in self.input:
value = self.input[name]
# pylint: disable=protected-access
if name in self._params:
value = self._params[name]
try:
value = func(self, value, **kwargs)
except errors.InvalidParameterError as ex:
@ -22,34 +22,46 @@ def _param_wrapper(func):
'Required parameter %r is missing.' % name)
return wrapper
class Context(object):
def __init__(self):
self.session = None
self.user = None
self.files = {}
self.input = {}
self.output = None
self.settings = {}
class Context():
# pylint: disable=too-many-arguments
def __init__(self, method, url, headers=None, params=None, files=None):
self.method = method
self.url = url
self._headers = headers or {}
self._params = params or {}
self._files = files or {}
def has_param(self, name):
return name in self.input
# provided by middleware
# self.session = None
# self.user = None
def has_header(self, name):
return name in self._headers
def get_header(self, name):
return self._headers.get(name, None)
def has_file(self, name):
return name in self.files or name + 'Url' in self.input
return name in self._files or name + 'Url' in self._params
def get_file(self, name, required=False):
if name in self.files:
return self.files[name]
if name + 'Url' in self.input:
return net.download(self.input[name + 'Url'])
if name in self._files:
return self._files[name]
if name + 'Url' in self._params:
return net.download(self._params[name + 'Url'])
if not required:
return None
raise errors.MissingRequiredFileError(
'Required file %r is missing.' % name)
def has_param(self, name):
return name in self._params
@_param_wrapper
def get_param_as_list(self, value):
if not isinstance(value, list):
if ',' in value:
return value.split(',')
return [value]
return value
@ -86,6 +98,3 @@ class Context(object):
if value in ['0', 'n', 'no', 'nope', 'f', 'false']:
return False
raise errors.InvalidParameterError('The value must be a boolean value.')
class Request(falcon.Request):
context_type = Context

View file

@ -0,0 +1,37 @@
error_handlers = {} # pylint: disable=invalid-name
class BaseHttpError(RuntimeError):
code = None
reason = None
def __init__(self, description, title=None):
super().__init__()
self.description = description
self.title = title or self.reason
class HttpBadRequest(BaseHttpError):
code = 400
reason = 'Bad Request'
class HttpForbidden(BaseHttpError):
code = 403
reason = 'Forbidden'
class HttpNotFound(BaseHttpError):
code = 404
reason = 'Not Found'
class HttpNotAcceptable(BaseHttpError):
code = 406
reason = 'Not Acceptable'
class HttpConflict(BaseHttpError):
code = 409
reason = 'Conflict'
class HttpMethodNotAllowed(BaseHttpError):
code = 405
reason = 'Method Not Allowed'
def handle(exception_type, handler):
error_handlers[exception_type] = handler

View file

@ -0,0 +1,9 @@
# pylint: disable=invalid-name
pre_hooks = []
post_hooks = []
def pre_hook(handler):
pre_hooks.append(handler)
def post_hook(handler):
post_hooks.insert(0, handler)

View file

@ -0,0 +1,27 @@
from collections import defaultdict
routes = defaultdict(dict) # pylint: disable=invalid-name
def get(url):
def wrapper(handler):
routes[url]['GET'] = handler
return handler
return wrapper
def put(url):
def wrapper(handler):
routes[url]['PUT'] = handler
return handler
return wrapper
def post(url):
def wrapper(handler):
routes[url]['POST'] = handler
return handler
return wrapper
def delete(url):
def wrapper(handler):
routes[url]['DELETE'] = handler
return handler
return wrapper

View file

@ -1,89 +1,78 @@
import datetime
import pytest
import unittest.mock
from datetime import datetime
from szurubooru import api, db, errors
from szurubooru.func import util, posts
from szurubooru.func import comments, posts
@pytest.fixture
def test_ctx(
tmpdir, config_injector, context_factory, post_factory, user_factory):
config_injector({
'data_dir': str(tmpdir),
'data_url': 'http://example.com',
'privileges': {'comments:create': db.User.RANK_REGULAR},
'thumbnails': {'avatar_width': 200},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.post_factory = post_factory
ret.user_factory = user_factory
ret.api = api.CommentListApi()
return ret
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'comments:create': db.User.RANK_REGULAR}})
def test_creating_comment(test_ctx, fake_datetime):
post = test_ctx.post_factory()
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR)
def test_creating_comment(
user_factory, post_factory, context_factory, fake_datetime):
post = post_factory()
user = user_factory(rank=db.User.RANK_REGULAR)
db.session.add_all([post, user])
db.session.flush()
with fake_datetime('1997-01-01'):
result = test_ctx.api.post(
test_ctx.context_factory(
input={'text': 'input', 'postId': post.post_id},
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'), \
fake_datetime('1997-01-01'):
comments.serialize_comment.return_value = 'serialized comment'
result = api.comment_api.create_comment(
context_factory(
params={'text': 'input', 'postId': post.post_id},
user=user))
assert result['text'] == 'input'
assert 'id' in result
assert 'user' in result
assert 'name' in result['user']
assert 'postId' in result
comment = db.session.query(db.Comment).one()
assert comment.text == 'input'
assert comment.creation_time == datetime.datetime(1997, 1, 1)
assert comment.last_edit_time is None
assert comment.user and comment.user.user_id == user.user_id
assert comment.post and comment.post.post_id == post.post_id
assert result == 'serialized comment'
comment = db.session.query(db.Comment).one()
assert comment.text == 'input'
assert comment.creation_time == datetime(1997, 1, 1)
assert comment.last_edit_time is None
assert comment.user and comment.user.user_id == user.user_id
assert comment.post and comment.post.post_id == post.post_id
@pytest.mark.parametrize('input', [
@pytest.mark.parametrize('params', [
{'text': None},
{'text': ''},
{'text': [None]},
{'text': ['']},
])
def test_trying_to_pass_invalid_input(test_ctx, input):
post = test_ctx.post_factory()
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR)
def test_trying_to_pass_invalid_params(
user_factory, post_factory, context_factory, params):
post = post_factory()
user = user_factory(rank=db.User.RANK_REGULAR)
db.session.add_all([post, user])
db.session.flush()
real_input = {'text': 'input', 'postId': post.post_id}
for key, value in input.items():
real_input[key] = value
real_params = {'text': 'input', 'postId': post.post_id}
for key, value in params.items():
real_params[key] = value
with pytest.raises(errors.ValidationError):
test_ctx.api.post(
test_ctx.context_factory(input=real_input, user=user))
api.comment_api.create_comment(
context_factory(params=real_params, user=user))
@pytest.mark.parametrize('field', ['text', 'postId'])
def test_trying_to_omit_mandatory_field(test_ctx, field):
input = {
def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
params = {
'text': 'input',
'postId': 1,
}
del input[field]
del params[field]
with pytest.raises(errors.ValidationError):
test_ctx.api.post(
test_ctx.context_factory(
input={},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
api.comment_api.create_comment(
context_factory(
params={},
user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_comment_non_existing(test_ctx):
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR)
def test_trying_to_comment_non_existing(user_factory, context_factory):
user = user_factory(rank=db.User.RANK_REGULAR)
db.session.add_all([user])
db.session.flush()
with pytest.raises(posts.PostNotFoundError):
test_ctx.api.post(
test_ctx.context_factory(
input={'text': 'bad', 'postId': 5}, user=user))
api.comment_api.create_comment(
context_factory(
params={'text': 'bad', 'postId': 5}, user=user))
def test_trying_to_create_without_privileges(test_ctx):
def test_trying_to_create_without_privileges(user_factory, context_factory):
with pytest.raises(errors.AuthError):
test_ctx.api.post(
test_ctx.context_factory(
input={},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
api.comment_api.create_comment(
context_factory(
params={},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))

View file

@ -1,61 +1,56 @@
import pytest
from datetime import datetime
from szurubooru import api, db, errors
from szurubooru.func import util, comments
from szurubooru.func import comments
@pytest.fixture
def test_ctx(config_injector, context_factory, user_factory, comment_factory):
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'privileges': {
'comments:delete:own': db.User.RANK_REGULAR,
'comments:delete:any': db.User.RANK_MODERATOR,
},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.comment_factory = comment_factory
ret.api = api.CommentDetailApi()
return ret
def test_deleting_own_comment(test_ctx):
user = test_ctx.user_factory()
comment = test_ctx.comment_factory(user=user)
def test_deleting_own_comment(user_factory, comment_factory, context_factory):
user = user_factory()
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()
result = test_ctx.api.delete(
test_ctx.context_factory(input={'version': 1}, user=user),
comment.comment_id)
result = api.comment_api.delete_comment(
context_factory(params={'version': 1}, user=user),
{'comment_id': comment.comment_id})
assert result == {}
assert db.session.query(db.Comment).count() == 0
def test_deleting_someones_else_comment(test_ctx):
user1 = test_ctx.user_factory(rank=db.User.RANK_REGULAR)
user2 = test_ctx.user_factory(rank=db.User.RANK_MODERATOR)
comment = test_ctx.comment_factory(user=user1)
def test_deleting_someones_else_comment(
user_factory, comment_factory, context_factory):
user1 = user_factory(rank=db.User.RANK_REGULAR)
user2 = user_factory(rank=db.User.RANK_MODERATOR)
comment = comment_factory(user=user1)
db.session.add(comment)
db.session.commit()
result = test_ctx.api.delete(
test_ctx.context_factory(input={'version': 1}, user=user2),
comment.comment_id)
result = api.comment_api.delete_comment(
context_factory(params={'version': 1}, user=user2),
{'comment_id': comment.comment_id})
assert db.session.query(db.Comment).count() == 0
def test_trying_to_delete_someones_else_comment_without_privileges(test_ctx):
user1 = test_ctx.user_factory(rank=db.User.RANK_REGULAR)
user2 = test_ctx.user_factory(rank=db.User.RANK_REGULAR)
comment = test_ctx.comment_factory(user=user1)
def test_trying_to_delete_someones_else_comment_without_privileges(
user_factory, comment_factory, context_factory):
user1 = user_factory(rank=db.User.RANK_REGULAR)
user2 = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user1)
db.session.add(comment)
db.session.commit()
with pytest.raises(errors.AuthError):
test_ctx.api.delete(
test_ctx.context_factory(input={'version': 1}, user=user2),
comment.comment_id)
api.comment_api.delete_comment(
context_factory(params={'version': 1}, user=user2),
{'comment_id': comment.comment_id})
assert db.session.query(db.Comment).count() == 1
def test_trying_to_delete_non_existing(test_ctx):
def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(comments.CommentNotFoundError):
test_ctx.api.delete(
test_ctx.context_factory(
input={'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
1)
api.comment_api.delete_comment(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'comment_id': 1})

View file

@ -1,152 +1,134 @@
import datetime
import pytest
import unittest.mock
from szurubooru import api, db, errors
from szurubooru.func import util, comments, scores
from szurubooru.func import comments, scores
@pytest.fixture
def test_ctx(
tmpdir, config_injector, context_factory, user_factory, comment_factory):
config_injector({
'data_dir': str(tmpdir),
'data_url': 'http://example.com',
'privileges': {
'comments:score': db.User.RANK_REGULAR,
'users:edit:any:email': db.User.RANK_MODERATOR,
},
'thumbnails': {'avatar_width': 200},
})
db.session.flush()
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.comment_factory = comment_factory
ret.api = api.CommentScoreApi()
return ret
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'comments:score': db.User.RANK_REGULAR}})
def test_simple_rating(test_ctx, fake_datetime):
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR)
comment = test_ctx.comment_factory(user=user)
def test_simple_rating(
user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()
with fake_datetime('1997-12-01'):
result = test_ctx.api.put(
test_ctx.context_factory(input={'score': 1}, user=user),
comment.comment_id)
assert 'text' in result
assert db.session.query(db.CommentScore).count() == 1
assert comment is not None
assert comment.score == 1
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
comments.serialize_comment.return_value = 'serialized comment'
with fake_datetime('1997-12-01'):
result = api.comment_api.set_comment_score(
context_factory(params={'score': 1}, user=user),
{'comment_id': comment.comment_id})
assert result == 'serialized comment'
assert db.session.query(db.CommentScore).count() == 1
assert comment is not None
assert comment.score == 1
def test_updating_rating(test_ctx, fake_datetime):
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR)
comment = test_ctx.comment_factory(user=user)
def test_updating_rating(
user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()
with fake_datetime('1997-12-01'):
result = test_ctx.api.put(
test_ctx.context_factory(input={'score': 1}, user=user),
comment.comment_id)
with fake_datetime('1997-12-02'):
result = test_ctx.api.put(
test_ctx.context_factory(input={'score': -1}, user=user),
comment.comment_id)
comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 1
assert comment.score == -1
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
with fake_datetime('1997-12-01'):
result = api.comment_api.set_comment_score(
context_factory(params={'score': 1}, user=user),
{'comment_id': comment.comment_id})
with fake_datetime('1997-12-02'):
result = api.comment_api.set_comment_score(
context_factory(params={'score': -1}, user=user),
{'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 1
assert comment.score == -1
def test_updating_rating_to_zero(test_ctx, fake_datetime):
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR)
comment = test_ctx.comment_factory(user=user)
def test_updating_rating_to_zero(
user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()
with fake_datetime('1997-12-01'):
result = test_ctx.api.put(
test_ctx.context_factory(input={'score': 1}, user=user),
comment.comment_id)
with fake_datetime('1997-12-02'):
result = test_ctx.api.put(
test_ctx.context_factory(input={'score': 0}, user=user),
comment.comment_id)
comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 0
assert comment.score == 0
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
with fake_datetime('1997-12-01'):
result = api.comment_api.set_comment_score(
context_factory(params={'score': 1}, user=user),
{'comment_id': comment.comment_id})
with fake_datetime('1997-12-02'):
result = api.comment_api.set_comment_score(
context_factory(params={'score': 0}, user=user),
{'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 0
assert comment.score == 0
def test_deleting_rating(test_ctx, fake_datetime):
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR)
comment = test_ctx.comment_factory(user=user)
def test_deleting_rating(
user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()
with fake_datetime('1997-12-01'):
result = test_ctx.api.put(
test_ctx.context_factory(input={'score': 1}, user=user),
comment.comment_id)
with fake_datetime('1997-12-02'):
result = test_ctx.api.delete(
test_ctx.context_factory(user=user), comment.comment_id)
comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 0
assert comment.score == 0
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
with fake_datetime('1997-12-01'):
result = api.comment_api.set_comment_score(
context_factory(params={'score': 1}, user=user),
{'comment_id': comment.comment_id})
with fake_datetime('1997-12-02'):
result = api.comment_api.delete_comment_score(
context_factory(user=user),
{'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 0
assert comment.score == 0
def test_ratings_from_multiple_users(test_ctx, fake_datetime):
user1 = test_ctx.user_factory(rank=db.User.RANK_REGULAR)
user2 = test_ctx.user_factory(rank=db.User.RANK_REGULAR)
comment = test_ctx.comment_factory()
def test_ratings_from_multiple_users(
user_factory, comment_factory, context_factory, fake_datetime):
user1 = user_factory(rank=db.User.RANK_REGULAR)
user2 = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory()
db.session.add_all([user1, user2, comment])
db.session.commit()
with fake_datetime('1997-12-01'):
result = test_ctx.api.put(
test_ctx.context_factory(input={'score': 1}, user=user1),
comment.comment_id)
with fake_datetime('1997-12-02'):
result = test_ctx.api.put(
test_ctx.context_factory(input={'score': -1}, user=user2),
comment.comment_id)
comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 2
assert comment.score == 0
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
with fake_datetime('1997-12-01'):
result = api.comment_api.set_comment_score(
context_factory(params={'score': 1}, user=user1),
{'comment_id': comment.comment_id})
with fake_datetime('1997-12-02'):
result = api.comment_api.set_comment_score(
context_factory(params={'score': -1}, user=user2),
{'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one()
assert db.session.query(db.CommentScore).count() == 2
assert comment.score == 0
@pytest.mark.parametrize('input,expected_exception', [
({'score': None}, errors.ValidationError),
({'score': ''}, errors.ValidationError),
({'score': -2}, scores.InvalidScoreValueError),
({'score': 2}, scores.InvalidScoreValueError),
({'score': [1]}, errors.ValidationError),
])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
user = test_ctx.user_factory()
comment = test_ctx.comment_factory(user=user)
db.session.add(comment)
db.session.commit()
with pytest.raises(expected_exception):
test_ctx.api.put(
test_ctx.context_factory(input=input, user=user),
comment.comment_id)
def test_trying_to_omit_mandatory_field(test_ctx):
user = test_ctx.user_factory()
comment = test_ctx.comment_factory(user=user)
def test_trying_to_omit_mandatory_field(
user_factory, comment_factory, context_factory):
user = user_factory()
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()
with pytest.raises(errors.ValidationError):
test_ctx.api.put(
test_ctx.context_factory(input={}, user=user),
comment.comment_id)
api.comment_api.set_comment_score(
context_factory(params={}, user=user),
{'comment_id': comment.comment_id})
def test_trying_to_update_non_existing(test_ctx):
def test_trying_to_update_non_existing(
user_factory, comment_factory, context_factory):
with pytest.raises(comments.CommentNotFoundError):
test_ctx.api.put(
test_ctx.context_factory(
input={'score': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
5)
api.comment_api.set_comment_score(
context_factory(
params={'score': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'comment_id': 5})
def test_trying_to_rate_without_privileges(test_ctx):
comment = test_ctx.comment_factory()
def test_trying_to_rate_without_privileges(
user_factory, comment_factory, context_factory):
comment = comment_factory()
db.session.add(comment)
db.session.commit()
with pytest.raises(errors.AuthError):
test_ctx.api.put(
test_ctx.context_factory(
input={'score': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)),
comment.comment_id)
api.comment_api.set_comment_score(
context_factory(
params={'score': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
{'comment_id': comment.comment_id})

View file

@ -1,76 +1,65 @@
import datetime
import pytest
import unittest.mock
from szurubooru import api, db, errors
from szurubooru.func import util, comments
from szurubooru.func import comments
@pytest.fixture
def test_ctx(
tmpdir, context_factory, config_injector, user_factory, comment_factory):
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'data_dir': str(tmpdir),
'data_url': 'http://example.com',
'privileges': {
'comments:list': db.User.RANK_REGULAR,
'comments:view': db.User.RANK_REGULAR,
'users:edit:any:email': db.User.RANK_MODERATOR,
},
'thumbnails': {'avatar_width': 200},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.comment_factory = comment_factory
ret.list_api = api.CommentListApi()
ret.detail_api = api.CommentDetailApi()
return ret
def test_retrieving_multiple(test_ctx):
comment1 = test_ctx.comment_factory(text='text 1')
comment2 = test_ctx.comment_factory(text='text 2')
def test_retrieving_multiple(user_factory, comment_factory, context_factory):
comment1 = comment_factory(text='text 1')
comment2 = comment_factory(text='text 2')
db.session.add_all([comment1, comment2])
result = test_ctx.list_api.get(
test_ctx.context_factory(
input={'query': '', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert result['query'] == ''
assert result['page'] == 1
assert result['pageSize'] == 100
assert result['total'] == 2
assert [c['text'] for c in result['results']] == ['text 1', 'text 2']
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
comments.serialize_comment.return_value = 'serialized comment'
result = api.comment_api.get_comments(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
assert result == {
'query': '',
'page': 1,
'pageSize': 100,
'total': 2,
'results': ['serialized comment', 'serialized comment'],
}
def test_trying_to_retrieve_multiple_without_privileges(test_ctx):
def test_trying_to_retrieve_multiple_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError):
test_ctx.list_api.get(
test_ctx.context_factory(
input={'query': '', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
api.comment_api.get_comments(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_retrieving_single(test_ctx):
comment = test_ctx.comment_factory(text='dummy text')
def test_retrieving_single(user_factory, comment_factory, context_factory):
comment = comment_factory(text='dummy text')
db.session.add(comment)
db.session.flush()
result = test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
comment.comment_id)
assert 'id' in result
assert 'lastEditTime' in result
assert 'creationTime' in result
assert 'text' in result
assert 'user' in result
assert 'name' in result['user']
assert 'postId' in result
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
comments.serialize_comment.return_value = 'serialized comment'
result = api.comment_api.get_comment(
context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)),
{'comment_id': comment.comment_id})
assert result == 'serialized comment'
def test_trying_to_retrieve_single_non_existing(test_ctx):
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(comments.CommentNotFoundError):
test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
5)
api.comment_api.get_comment(
context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)),
{'comment_id': 5})
def test_trying_to_retrieve_single_without_privileges(test_ctx):
def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError):
test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)),
5)
api.comment_api.get_comment(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
{'comment_id': 5})

View file

@ -1,103 +1,94 @@
import datetime
import pytest
import unittest.mock
from datetime import datetime
from szurubooru import api, db, errors
from szurubooru.func import util, comments
from szurubooru.func import comments
@pytest.fixture
def test_ctx(
tmpdir, config_injector, context_factory, user_factory, comment_factory):
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'data_dir': str(tmpdir),
'data_url': 'http://example.com',
'privileges': {
'comments:edit:own': db.User.RANK_REGULAR,
'comments:edit:any': db.User.RANK_MODERATOR,
'users:edit:any:email': db.User.RANK_MODERATOR,
},
'thumbnails': {'avatar_width': 200},
})
db.session.flush()
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.comment_factory = comment_factory
ret.api = api.CommentDetailApi()
return ret
def test_simple_updating(test_ctx, fake_datetime):
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR)
comment = test_ctx.comment_factory(user=user)
def test_simple_updating(
user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()
with fake_datetime('1997-12-01'):
result = test_ctx.api.put(
test_ctx.context_factory(
input={'text': 'new text', 'version': 1}, user=user),
comment.comment_id)
assert result['text'] == 'new text'
comment = db.session.query(db.Comment).one()
assert comment is not None
assert comment.text == 'new text'
assert comment.last_edit_time is not None
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'), \
fake_datetime('1997-12-01'):
comments.serialize_comment.return_value = 'serialized comment'
result = api.comment_api.update_comment(
context_factory(
params={'text': 'new text', 'version': 1}, user=user),
{'comment_id': comment.comment_id})
assert result == 'serialized comment'
assert comment.last_edit_time == datetime(1997, 12, 1)
@pytest.mark.parametrize('input,expected_exception', [
@pytest.mark.parametrize('params,expected_exception', [
({'text': None}, comments.EmptyCommentTextError),
({'text': ''}, comments.EmptyCommentTextError),
({'text': []}, comments.EmptyCommentTextError),
({'text': [None]}, errors.ValidationError),
({'text': ['']}, comments.EmptyCommentTextError),
])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
user = test_ctx.user_factory()
comment = test_ctx.comment_factory(user=user)
def test_trying_to_pass_invalid_params(
user_factory, comment_factory, context_factory, params, expected_exception):
user = user_factory()
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()
with pytest.raises(expected_exception):
test_ctx.api.put(
test_ctx.context_factory(
input={**input, **{'version': 1}}, user=user),
comment.comment_id)
api.comment_api.update_comment(
context_factory(
params={**params, **{'version': 1}}, user=user),
{'comment_id': comment.comment_id})
def test_trying_to_omit_mandatory_field(test_ctx):
user = test_ctx.user_factory()
comment = test_ctx.comment_factory(user=user)
def test_trying_to_omit_mandatory_field(
user_factory, comment_factory, context_factory):
user = user_factory()
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()
with pytest.raises(errors.ValidationError):
test_ctx.api.put(
test_ctx.context_factory(input={'version': 1}, user=user),
comment.comment_id)
api.comment_api.update_comment(
context_factory(params={'version': 1}, user=user),
{'comment_id': comment.comment_id})
def test_trying_to_update_non_existing(test_ctx):
def test_trying_to_update_non_existing(user_factory, context_factory):
with pytest.raises(comments.CommentNotFoundError):
test_ctx.api.put(
test_ctx.context_factory(
input={'text': 'new text'},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
5)
api.comment_api.update_comment(
context_factory(
params={'text': 'new text'},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'comment_id': 5})
def test_trying_to_update_someones_comment_without_privileges(test_ctx):
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR)
user2 = test_ctx.user_factory(rank=db.User.RANK_REGULAR)
comment = test_ctx.comment_factory(user=user)
def test_trying_to_update_someones_comment_without_privileges(
user_factory, comment_factory, context_factory):
user = user_factory(rank=db.User.RANK_REGULAR)
user2 = user_factory(rank=db.User.RANK_REGULAR)
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()
with pytest.raises(errors.AuthError):
test_ctx.api.put(
test_ctx.context_factory(
input={'text': 'new text', 'version': 1}, user=user2),
comment.comment_id)
api.comment_api.update_comment(
context_factory(
params={'text': 'new text', 'version': 1}, user=user2),
{'comment_id': comment.comment_id})
def test_updating_someones_comment_with_privileges(test_ctx):
user = test_ctx.user_factory(rank=db.User.RANK_REGULAR)
user2 = test_ctx.user_factory(rank=db.User.RANK_MODERATOR)
comment = test_ctx.comment_factory(user=user)
def test_updating_someones_comment_with_privileges(
user_factory, comment_factory, context_factory):
user = user_factory(rank=db.User.RANK_REGULAR)
user2 = user_factory(rank=db.User.RANK_MODERATOR)
comment = comment_factory(user=user)
db.session.add(comment)
db.session.commit()
try:
test_ctx.api.put(
test_ctx.context_factory(
input={'text': 'new text', 'version': 1}, user=user2),
comment.comment_id)
except:
pytest.fail()
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'):
api.comment_api.update_comment(
context_factory(
params={'text': 'new text', 'version': 1}, user=user2),
{'comment_id': comment.comment_id})

View file

@ -31,9 +31,8 @@ def test_info_api(
},
}
info_api = api.InfoApi()
with fake_datetime('2016-01-01 13:00'):
assert info_api.get(context_factory()) == {
assert api.info_api.get_info(context_factory()) == {
'postCount': 2,
'diskUsage': 3,
'featuredPost': None,
@ -44,7 +43,7 @@ def test_info_api(
}
directory.join('test2.txt').write('abc')
with fake_datetime('2016-01-01 13:59'):
assert info_api.get(context_factory()) == {
assert api.info_api.get_info(context_factory()) == {
'postCount': 2,
'diskUsage': 3, # still 3 - it's cached
'featuredPost': None,
@ -54,7 +53,7 @@ def test_info_api(
'config': expected_config_key,
}
with fake_datetime('2016-01-01 14:01'):
assert info_api.get(context_factory()) == {
assert api.info_api.get_info(context_factory()) == {
'postCount': 2,
'diskUsage': 6, # cache expired
'featuredPost': None,

View file

@ -1,71 +1,70 @@
from datetime import datetime
from unittest import mock
import pytest
import unittest.mock
from szurubooru import api, db, errors
from szurubooru.func import auth, mailer
@pytest.fixture
def password_reset_api(config_injector):
@pytest.fixture(autouse=True)
def inject_config(tmpdir, config_injector):
config_injector({
'secret': 'x',
'base_url': 'http://example.com/',
'name': 'Test instance',
})
return api.PasswordResetApi()
def test_reset_sending_email(
password_reset_api, context_factory, user_factory):
def test_reset_sending_email(context_factory, user_factory):
db.session.add(user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
for getter in ['u1', 'user@example.com']:
mailer.send_mail = mock.MagicMock()
assert password_reset_api.get(context_factory(), getter) == {}
mailer.send_mail.assert_called_once_with(
'noreply@Test instance',
'user@example.com',
'Password reset for Test instance',
'You (or someone else) requested to reset your password ' +
'on Test instance.\nIf you wish to proceed, click this l' +
'ink: http://example.com/password-reset/u1:4ac0be176fb36' +
'4f13ee6b634c43220e2\nOtherwise, please ignore this email.')
for initiating_user in ['u1', 'user@example.com']:
with unittest.mock.patch('szurubooru.func.mailer.send_mail'):
assert api.password_reset_api.start_password_reset(
context_factory(), {'user_name': initiating_user}) == {}
mailer.send_mail.assert_called_once_with(
'noreply@Test instance',
'user@example.com',
'Password reset for Test instance',
'You (or someone else) requested to reset your password ' +
'on Test instance.\nIf you wish to proceed, click this l' +
'ink: http://example.com/password-reset/u1:4ac0be176fb36' +
'4f13ee6b634c43220e2\nOtherwise, please ignore this email.')
def test_trying_to_reset_non_existing(password_reset_api, context_factory):
def test_trying_to_reset_non_existing(context_factory):
with pytest.raises(errors.NotFoundError):
password_reset_api.get(context_factory(), 'u1')
api.password_reset_api.start_password_reset(
context_factory(), {'user_name': 'u1'})
def test_trying_to_reset_without_email(
password_reset_api, context_factory, user_factory):
def test_trying_to_reset_without_email(context_factory, user_factory):
db.session.add(user_factory(name='u1', rank=db.User.RANK_REGULAR, email=None))
with pytest.raises(errors.ValidationError):
password_reset_api.get(context_factory(), 'u1')
api.password_reset_api.start_password_reset(
context_factory(), {'user_name': 'u1'})
def test_confirming_with_good_token(
password_reset_api, context_factory, user_factory):
def test_confirming_with_good_token(context_factory, user_factory):
user = user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')
old_hash = user.password_hash
db.session.add(user)
context = context_factory(
input={'token': '4ac0be176fb364f13ee6b634c43220e2'})
result = password_reset_api.post(context, 'u1')
params={'token': '4ac0be176fb364f13ee6b634c43220e2'})
result = api.password_reset_api.finish_password_reset(
context, {'user_name': 'u1'})
assert user.password_hash != old_hash
assert auth.is_valid_password(user, result['password']) is True
def test_trying_to_confirm_non_existing(password_reset_api, context_factory):
def test_trying_to_confirm_non_existing(context_factory):
with pytest.raises(errors.NotFoundError):
password_reset_api.post(context_factory(), 'u1')
api.password_reset_api.finish_password_reset(
context_factory(), {'user_name': 'u1'})
def test_trying_to_confirm_without_token(
password_reset_api, context_factory, user_factory):
def test_trying_to_confirm_without_token(context_factory, user_factory):
db.session.add(user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
with pytest.raises(errors.ValidationError):
password_reset_api.post(context_factory(input={}), 'u1')
api.password_reset_api.finish_password_reset(
context_factory(params={}), {'user_name': 'u1'})
def test_trying_to_confirm_with_bad_token(
password_reset_api, context_factory, user_factory):
def test_trying_to_confirm_with_bad_token(context_factory, user_factory):
db.session.add(user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
with pytest.raises(errors.ValidationError):
password_reset_api.post(
context_factory(input={'token': 'bad'}), 'u1')
api.password_reset_api.finish_password_reset(
context_factory(params={'token': 'bad'}), {'user_name': 'u1'})

View file

@ -1,7 +1,5 @@
import datetime
import os
import unittest.mock
import pytest
import unittest.mock
from szurubooru import api, db, errors
from szurubooru.func import posts, tags, snapshots, net
@ -35,9 +33,9 @@ def test_creating_minimal_posts(
posts.create_post.return_value = (post, [])
posts.serialize_post.return_value = 'serialized post'
result = api.PostListApi().post(
result = api.post_api.create_post(
context_factory(
input={
params={
'safety': 'safe',
'tags': ['tag1', 'tag2'],
},
@ -79,9 +77,9 @@ def test_creating_full_posts(context_factory, post_factory, user_factory):
posts.create_post.return_value = (post, [])
posts.serialize_post.return_value = 'serialized post'
result = api.PostListApi().post(
result = api.post_api.create_post(
context_factory(
input={
params={
'safety': 'safe',
'tags': ['tag1', 'tag2'],
'relations': [1, 2],
@ -122,9 +120,9 @@ def test_anonymous_uploads(
'privileges': {'posts:create:anonymous': db.User.RANK_REGULAR},
})
posts.create_post.return_value = [post, []]
api.PostListApi().post(
api.post_api.create_post(
context_factory(
input={
params={
'safety': 'safe',
'tags': ['tag1', 'tag2'],
'anonymous': 'True',
@ -154,9 +152,9 @@ def test_creating_from_url_saves_source(
})
net.download.return_value = b'content'
posts.create_post.return_value = [post, []]
api.PostListApi().post(
api.post_api.create_post(
context_factory(
input={
params={
'safety': 'safe',
'tags': ['tag1', 'tag2'],
'contentUrl': 'example.com',
@ -185,9 +183,9 @@ def test_creating_from_url_with_source_specified(
})
net.download.return_value = b'content'
posts.create_post.return_value = [post, []]
api.PostListApi().post(
api.post_api.create_post(
context_factory(
input={
params={
'safety': 'safe',
'tags': ['tag1', 'tag2'],
'contentUrl': 'example.com',
@ -201,23 +199,23 @@ def test_creating_from_url_with_source_specified(
@pytest.mark.parametrize('field', ['tags', 'safety'])
def test_trying_to_omit_mandatory_field(context_factory, user_factory, field):
input = {
params = {
'safety': 'safe',
'tags': ['tag1', 'tag2'],
}
del input[field]
del params[field]
with pytest.raises(errors.MissingRequiredParameterError):
api.PostListApi().post(
api.post_api.create_post(
context_factory(
input=input,
params=params,
files={'content': '...'},
user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_omit_content(context_factory, user_factory):
with pytest.raises(errors.MissingRequiredFileError):
api.PostListApi().post(
api.post_api.create_post(
context_factory(
input={
params={
'safety': 'safe',
'tags': ['tag1', 'tag2'],
},
@ -225,10 +223,9 @@ def test_trying_to_omit_content(context_factory, user_factory):
def test_trying_to_create_post_without_privileges(context_factory, user_factory):
with pytest.raises(errors.AuthError):
api.PostListApi().post(
context_factory(
input='whatever',
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
api.post_api.create_post(context_factory(
params='whatever',
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_trying_to_create_tags_without_privileges(
config_injector, context_factory, user_factory):
@ -243,9 +240,9 @@ def test_trying_to_create_tags_without_privileges(
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \
unittest.mock.patch('szurubooru.func.posts.update_post_tags'):
posts.update_post_tags.return_value = ['new-tag']
api.PostListApi().post(
api.post_api.create_post(
context_factory(
input={
params={
'safety': 'safe',
'tags': ['tag1', 'tag2'],
},

View file

@ -1,49 +1,37 @@
import pytest
import os
from datetime import datetime
from szurubooru import api, config, db, errors
from szurubooru.func import util, posts
import unittest.mock
from szurubooru import api, db, errors
from szurubooru.func import posts, tags
@pytest.fixture
def test_ctx(
tmpdir, config_injector, context_factory, post_factory, user_factory):
config_injector({
'data_dir': str(tmpdir),
'privileges': {
'posts:delete': db.User.RANK_REGULAR,
},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.post_factory = post_factory
ret.api = api.PostDetailApi()
return ret
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'posts:delete': db.User.RANK_REGULAR}})
def test_deleting(test_ctx):
db.session.add(test_ctx.post_factory(id=1))
def test_deleting(user_factory, post_factory, context_factory):
db.session.add(post_factory(id=1))
db.session.commit()
result = test_ctx.api.delete(
test_ctx.context_factory(
input={'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
1)
assert result == {}
assert db.session.query(db.Post).count() == 0
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json'))
with unittest.mock.patch('szurubooru.func.tags.export_to_json'):
result = api.post_api.delete_post(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'post_id': 1})
assert result == {}
assert db.session.query(db.Post).count() == 0
tags.export_to_json.assert_called_once_with()
def test_trying_to_delete_non_existing(test_ctx):
def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(posts.PostNotFoundError):
test_ctx.api.delete(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), '999')
api.post_api.delete_post(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
{'post_id': 999})
def test_trying_to_delete_without_privileges(test_ctx):
db.session.add(test_ctx.post_factory(id=1))
def test_trying_to_delete_without_privileges(
user_factory, post_factory, context_factory):
db.session.add(post_factory(id=1))
db.session.commit()
with pytest.raises(errors.AuthError):
test_ctx.api.delete(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)),
1)
api.post_api.delete_post(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
{'post_id': 1})
assert db.session.query(db.Post).count() == 1

View file

@ -1,132 +1,129 @@
import datetime
import pytest
import unittest.mock
from datetime import datetime
from szurubooru import api, db, errors
from szurubooru.func import util, posts
from szurubooru.func import posts
@pytest.fixture
def test_ctx(
tmpdir, config_injector, context_factory, user_factory, post_factory):
config_injector({
'data_dir': str(tmpdir),
'data_url': 'http://example.com',
'privileges': {
'posts:favorite': db.User.RANK_REGULAR,
'users:edit:any:email': db.User.RANK_MODERATOR,
},
'thumbnails': {'avatar_width': 200},
})
db.session.flush()
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.post_factory = post_factory
ret.api = api.PostFavoriteApi()
return ret
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'posts:favorite': db.User.RANK_REGULAR}})
def test_adding_to_favorites(test_ctx, fake_datetime):
post = test_ctx.post_factory()
def test_adding_to_favorites(
user_factory, post_factory, context_factory, fake_datetime):
post = post_factory()
db.session.add(post)
db.session.commit()
assert post.score == 0
with fake_datetime('1997-12-01'):
result = test_ctx.api.post(
test_ctx.context_factory(user=test_ctx.user_factory()),
post.post_id)
assert 'id' in result
post = db.session.query(db.Post).one()
assert db.session.query(db.PostFavorite).count() == 1
assert post is not None
assert post.favorite_count == 1
assert post.score == 1
with unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
fake_datetime('1997-12-01'):
posts.serialize_post.return_value = 'serialized post'
result = api.post_api.add_post_to_favorites(
context_factory(user=user_factory()),
{'post_id': post.post_id})
assert result == 'serialized post'
post = db.session.query(db.Post).one()
assert db.session.query(db.PostFavorite).count() == 1
assert post is not None
assert post.favorite_count == 1
assert post.score == 1
def test_removing_from_favorites(test_ctx, fake_datetime):
user = test_ctx.user_factory()
post = test_ctx.post_factory()
def test_removing_from_favorites(
user_factory, post_factory, context_factory, fake_datetime):
user = user_factory()
post = post_factory()
db.session.add(post)
db.session.commit()
assert post.score == 0
with fake_datetime('1997-12-01'):
result = test_ctx.api.post(
test_ctx.context_factory(user=user),
post.post_id)
assert post.score == 1
with fake_datetime('1997-12-02'):
result = test_ctx.api.delete(
test_ctx.context_factory(user=user),
post.post_id)
post = db.session.query(db.Post).one()
assert post.score == 1
assert db.session.query(db.PostFavorite).count() == 0
assert post.favorite_count == 0
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'):
api.post_api.add_post_to_favorites(
context_factory(user=user),
{'post_id': post.post_id})
assert post.score == 1
with fake_datetime('1997-12-02'):
api.post_api.delete_post_from_favorites(
context_factory(user=user),
{'post_id': post.post_id})
post = db.session.query(db.Post).one()
assert post.score == 1
assert db.session.query(db.PostFavorite).count() == 0
assert post.favorite_count == 0
def test_favoriting_twice(test_ctx, fake_datetime):
user = test_ctx.user_factory()
post = test_ctx.post_factory()
def test_favoriting_twice(
user_factory, post_factory, context_factory, fake_datetime):
user = user_factory()
post = post_factory()
db.session.add(post)
db.session.commit()
with fake_datetime('1997-12-01'):
result = test_ctx.api.post(
test_ctx.context_factory(user=user),
post.post_id)
with fake_datetime('1997-12-02'):
result = test_ctx.api.post(
test_ctx.context_factory(user=user),
post.post_id)
post = db.session.query(db.Post).one()
assert db.session.query(db.PostFavorite).count() == 1
assert post.favorite_count == 1
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'):
api.post_api.add_post_to_favorites(
context_factory(user=user),
{'post_id': post.post_id})
with fake_datetime('1997-12-02'):
api.post_api.add_post_to_favorites(
context_factory(user=user),
{'post_id': post.post_id})
post = db.session.query(db.Post).one()
assert db.session.query(db.PostFavorite).count() == 1
assert post.favorite_count == 1
def test_removing_twice(test_ctx, fake_datetime):
user = test_ctx.user_factory()
post = test_ctx.post_factory()
def test_removing_twice(
user_factory, post_factory, context_factory, fake_datetime):
user = user_factory()
post = post_factory()
db.session.add(post)
db.session.commit()
with fake_datetime('1997-12-01'):
result = test_ctx.api.post(
test_ctx.context_factory(user=user),
post.post_id)
with fake_datetime('1997-12-02'):
result = test_ctx.api.delete(
test_ctx.context_factory(user=user),
post.post_id)
with fake_datetime('1997-12-02'):
result = test_ctx.api.delete(
test_ctx.context_factory(user=user),
post.post_id)
post = db.session.query(db.Post).one()
assert db.session.query(db.PostFavorite).count() == 0
assert post.favorite_count == 0
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'):
api.post_api.add_post_to_favorites(
context_factory(user=user),
{'post_id': post.post_id})
with fake_datetime('1997-12-02'):
api.post_api.delete_post_from_favorites(
context_factory(user=user),
{'post_id': post.post_id})
with fake_datetime('1997-12-02'):
api.post_api.delete_post_from_favorites(
context_factory(user=user),
{'post_id': post.post_id})
post = db.session.query(db.Post).one()
assert db.session.query(db.PostFavorite).count() == 0
assert post.favorite_count == 0
def test_favorites_from_multiple_users(test_ctx, fake_datetime):
user1 = test_ctx.user_factory()
user2 = test_ctx.user_factory()
post = test_ctx.post_factory()
def test_favorites_from_multiple_users(
user_factory, post_factory, context_factory, fake_datetime):
user1 = user_factory()
user2 = user_factory()
post = post_factory()
db.session.add_all([user1, user2, post])
db.session.commit()
with fake_datetime('1997-12-01'):
result = test_ctx.api.post(
test_ctx.context_factory(user=user1),
post.post_id)
with fake_datetime('1997-12-02'):
result = test_ctx.api.post(
test_ctx.context_factory(user=user2),
post.post_id)
post = db.session.query(db.Post).one()
assert db.session.query(db.PostFavorite).count() == 2
assert post.favorite_count == 2
assert post.last_favorite_time == datetime.datetime(1997, 12, 2)
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'):
api.post_api.add_post_to_favorites(
context_factory(user=user1),
{'post_id': post.post_id})
with fake_datetime('1997-12-02'):
api.post_api.add_post_to_favorites(
context_factory(user=user2),
{'post_id': post.post_id})
post = db.session.query(db.Post).one()
assert db.session.query(db.PostFavorite).count() == 2
assert post.favorite_count == 2
assert post.last_favorite_time == datetime(1997, 12, 2)
def test_trying_to_update_non_existing(test_ctx):
def test_trying_to_update_non_existing(user_factory, context_factory):
with pytest.raises(posts.PostNotFoundError):
test_ctx.api.post(
test_ctx.context_factory(user=test_ctx.user_factory()), 5)
api.post_api.add_post_to_favorites(
context_factory(user=user_factory()),
{'post_id': 5})
def test_trying_to_rate_without_privileges(test_ctx):
post = test_ctx.post_factory()
def test_trying_to_rate_without_privileges(
user_factory, post_factory, context_factory):
post = post_factory()
db.session.add(post)
db.session.commit()
with pytest.raises(errors.AuthError):
test_ctx.api.post(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)),
post.post_id)
api.post_api.add_post_to_favorites(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
{'post_id': post.post_id})

View file

@ -1,107 +1,100 @@
import datetime
import pytest
import unittest.mock
from szurubooru import api, db, errors
from szurubooru.func import util, posts
from szurubooru.func import posts
@pytest.fixture
def test_ctx(
tmpdir, context_factory, config_injector, user_factory, post_factory):
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'data_dir': str(tmpdir),
'data_url': 'http://example.com',
'privileges': {
'posts:feature': db.User.RANK_REGULAR,
'posts:view': db.User.RANK_REGULAR,
},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.post_factory = post_factory
ret.api = api.PostFeatureApi()
return ret
def test_no_featured_post(test_ctx):
def test_no_featured_post(user_factory, post_factory, context_factory):
assert posts.try_get_featured_post() is None
result = test_ctx.api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert result is None
def test_featuring(test_ctx):
db.session.add(test_ctx.post_factory(id=1))
def test_featuring(user_factory, post_factory, context_factory):
db.session.add(post_factory(id=1))
db.session.commit()
assert not posts.get_post_by_id(1).is_featured
result = test_ctx.api.post(
test_ctx.context_factory(
input={'id': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert posts.try_get_featured_post() is not None
assert posts.try_get_featured_post().post_id == 1
assert posts.get_post_by_id(1).is_featured
assert 'id' in result
assert 'snapshots' in result
assert 'comments' in result
result = test_ctx.api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert 'id' in result
assert 'snapshots' in result
assert 'comments' in result
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
posts.serialize_post.return_value = 'serialized post'
result = api.post_api.set_featured_post(
context_factory(
params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
assert result == 'serialized post'
assert posts.try_get_featured_post() is not None
assert posts.try_get_featured_post().post_id == 1
assert posts.get_post_by_id(1).is_featured
result = api.post_api.get_featured_post(
context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)))
assert result == 'serialized post'
def test_trying_to_feature_the_same_post_twice(test_ctx):
db.session.add(test_ctx.post_factory(id=1))
def test_trying_to_omit_required_parameter(
user_factory, post_factory, context_factory):
with pytest.raises(errors.MissingRequiredParameterError):
api.post_api.set_featured_post(
context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_feature_the_same_post_twice(
user_factory, post_factory, context_factory):
db.session.add(post_factory(id=1))
db.session.commit()
test_ctx.api.post(
test_ctx.context_factory(
input={'id': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
with pytest.raises(posts.PostAlreadyFeaturedError):
test_ctx.api.post(
test_ctx.context_factory(
input={'id': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
api.post_api.set_featured_post(
context_factory(
params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
with pytest.raises(posts.PostAlreadyFeaturedError):
api.post_api.set_featured_post(
context_factory(
params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
def test_featuring_one_post_after_another(test_ctx, fake_datetime):
db.session.add(test_ctx.post_factory(id=1))
db.session.add(test_ctx.post_factory(id=2))
def test_featuring_one_post_after_another(
user_factory, post_factory, context_factory, fake_datetime):
db.session.add(post_factory(id=1))
db.session.add(post_factory(id=2))
db.session.commit()
assert posts.try_get_featured_post() is None
assert not posts.get_post_by_id(1).is_featured
assert not posts.get_post_by_id(2).is_featured
with fake_datetime('1997'):
result = test_ctx.api.post(
test_ctx.context_factory(
input={'id': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
with fake_datetime('1998'):
result = test_ctx.api.post(
test_ctx.context_factory(
input={'id': 2},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert posts.try_get_featured_post() is not None
assert posts.try_get_featured_post().post_id == 2
assert not posts.get_post_by_id(1).is_featured
assert posts.get_post_by_id(2).is_featured
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997'):
result = api.post_api.set_featured_post(
context_factory(
params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
with fake_datetime('1998'):
result = api.post_api.set_featured_post(
context_factory(
params={'id': 2},
user=user_factory(rank=db.User.RANK_REGULAR)))
assert posts.try_get_featured_post() is not None
assert posts.try_get_featured_post().post_id == 2
assert not posts.get_post_by_id(1).is_featured
assert posts.get_post_by_id(2).is_featured
def test_trying_to_feature_non_existing(test_ctx):
def test_trying_to_feature_non_existing(user_factory, context_factory):
with pytest.raises(posts.PostNotFoundError):
test_ctx.api.post(
test_ctx.context_factory(
input={'id': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
api.post_api.set_featured_post(
context_factory(
params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_feature_without_privileges(test_ctx):
def test_trying_to_feature_without_privileges(user_factory, context_factory):
with pytest.raises(errors.AuthError):
test_ctx.api.post(
test_ctx.context_factory(
input={'id': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
api.post_api.set_featured_post(
context_factory(
params={'id': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_getting_featured_post_without_privileges_to_view(test_ctx):
try:
test_ctx.api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
except:
pytest.fail()
def test_getting_featured_post_without_privileges_to_view(
user_factory, context_factory):
api.post_api.get_featured_post(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)))

View file

@ -1,147 +1,132 @@
import pytest
import unittest.mock
from szurubooru import api, db, errors
from szurubooru.func import util, posts, scores
from szurubooru.func import posts, scores
@pytest.fixture
def test_ctx(
tmpdir, config_injector, context_factory, user_factory, post_factory):
config_injector({
'data_dir': str(tmpdir),
'data_url': 'http://example.com',
'privileges': {'posts:score': db.User.RANK_REGULAR},
'thumbnails': {'avatar_width': 200},
})
db.session.flush()
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.post_factory = post_factory
ret.api = api.PostScoreApi()
return ret
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'posts:score': db.User.RANK_REGULAR}})
def test_simple_rating(test_ctx, fake_datetime):
post = test_ctx.post_factory()
def test_simple_rating(
user_factory, post_factory, context_factory, fake_datetime):
post = post_factory()
db.session.add(post)
db.session.commit()
with fake_datetime('1997-12-01'):
result = test_ctx.api.put(
test_ctx.context_factory(
input={'score': 1}, user=test_ctx.user_factory()),
post.post_id)
assert 'id' in result
post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 1
assert post is not None
assert post.score == 1
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
posts.serialize_post.return_value = 'serialized post'
with fake_datetime('1997-12-01'):
result = api.post_api.set_post_score(
context_factory(
params={'score': 1}, user=user_factory()),
{'post_id': post.post_id})
assert result == 'serialized post'
post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 1
assert post is not None
assert post.score == 1
def test_updating_rating(test_ctx, fake_datetime):
user = test_ctx.user_factory()
post = test_ctx.post_factory()
def test_updating_rating(
user_factory, post_factory, context_factory, fake_datetime):
user = user_factory()
post = post_factory()
db.session.add(post)
db.session.commit()
with fake_datetime('1997-12-01'):
result = test_ctx.api.put(
test_ctx.context_factory(input={'score': 1}, user=user),
post.post_id)
with fake_datetime('1997-12-02'):
result = test_ctx.api.put(
test_ctx.context_factory(input={'score': -1}, user=user),
post.post_id)
post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 1
assert post.score == -1
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'):
result = api.post_api.set_post_score(
context_factory(params={'score': 1}, user=user),
{'post_id': post.post_id})
with fake_datetime('1997-12-02'):
result = api.post_api.set_post_score(
context_factory(params={'score': -1}, user=user),
{'post_id': post.post_id})
post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 1
assert post.score == -1
def test_updating_rating_to_zero(test_ctx, fake_datetime):
user = test_ctx.user_factory()
post = test_ctx.post_factory()
def test_updating_rating_to_zero(
user_factory, post_factory, context_factory, fake_datetime):
user = user_factory()
post = post_factory()
db.session.add(post)
db.session.commit()
with fake_datetime('1997-12-01'):
result = test_ctx.api.put(
test_ctx.context_factory(input={'score': 1}, user=user),
post.post_id)
with fake_datetime('1997-12-02'):
result = test_ctx.api.put(
test_ctx.context_factory(input={'score': 0}, user=user),
post.post_id)
post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 0
assert post.score == 0
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'):
result = api.post_api.set_post_score(
context_factory(params={'score': 1}, user=user),
{'post_id': post.post_id})
with fake_datetime('1997-12-02'):
result = api.post_api.set_post_score(
context_factory(params={'score': 0}, user=user),
{'post_id': post.post_id})
post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 0
assert post.score == 0
def test_deleting_rating(test_ctx, fake_datetime):
user = test_ctx.user_factory()
post = test_ctx.post_factory()
def test_deleting_rating(
user_factory, post_factory, context_factory, fake_datetime):
user = user_factory()
post = post_factory()
db.session.add(post)
db.session.commit()
with fake_datetime('1997-12-01'):
result = test_ctx.api.put(
test_ctx.context_factory(input={'score': 1}, user=user),
post.post_id)
with fake_datetime('1997-12-02'):
result = test_ctx.api.delete(
test_ctx.context_factory(user=user), post.post_id)
post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 0
assert post.score == 0
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'):
result = api.post_api.set_post_score(
context_factory(params={'score': 1}, user=user),
{'post_id': post.post_id})
with fake_datetime('1997-12-02'):
result = api.post_api.delete_post_score(
context_factory(user=user),
{'post_id': post.post_id})
post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 0
assert post.score == 0
def test_ratings_from_multiple_users(test_ctx, fake_datetime):
user1 = test_ctx.user_factory()
user2 = test_ctx.user_factory()
post = test_ctx.post_factory()
def test_ratings_from_multiple_users(
user_factory, post_factory, context_factory, fake_datetime):
user1 = user_factory()
user2 = user_factory()
post = post_factory()
db.session.add_all([user1, user2, post])
db.session.commit()
with fake_datetime('1997-12-01'):
result = test_ctx.api.put(
test_ctx.context_factory(input={'score': 1}, user=user1),
post.post_id)
with fake_datetime('1997-12-02'):
result = test_ctx.api.put(
test_ctx.context_factory(input={'score': -1}, user=user2),
post.post_id)
post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 2
assert post.score == 0
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
with fake_datetime('1997-12-01'):
result = api.post_api.set_post_score(
context_factory(params={'score': 1}, user=user1),
{'post_id': post.post_id})
with fake_datetime('1997-12-02'):
result = api.post_api.set_post_score(
context_factory(params={'score': -1}, user=user2),
{'post_id': post.post_id})
post = db.session.query(db.Post).one()
assert db.session.query(db.PostScore).count() == 2
assert post.score == 0
@pytest.mark.parametrize('input,expected_exception', [
({'score': None}, errors.ValidationError),
({'score': ''}, errors.ValidationError),
({'score': -2}, scores.InvalidScoreValueError),
({'score': 2}, scores.InvalidScoreValueError),
({'score': [1]}, errors.ValidationError),
])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
post = test_ctx.post_factory()
db.session.add(post)
db.session.commit()
with pytest.raises(expected_exception):
test_ctx.api.put(
test_ctx.context_factory(input=input, user=test_ctx.user_factory()),
post.post_id)
def test_trying_to_omit_mandatory_field(test_ctx):
post = test_ctx.post_factory()
def test_trying_to_omit_mandatory_field(
user_factory, post_factory, context_factory):
post = post_factory()
db.session.add(post)
db.session.commit()
with pytest.raises(errors.ValidationError):
test_ctx.api.put(
test_ctx.context_factory(input={}, user=test_ctx.user_factory()),
post.post_id)
api.post_api.set_post_score(
context_factory(params={}, user=user_factory()),
{'post_id': post.post_id})
def test_trying_to_update_non_existing(test_ctx):
def test_trying_to_update_non_existing(
user_factory, post_factory, context_factory):
with pytest.raises(posts.PostNotFoundError):
test_ctx.api.put(
test_ctx.context_factory(
input={'score': 1},
user=test_ctx.user_factory()),
5)
api.post_api.set_post_score(
context_factory(params={'score': 1}, user=user_factory()),
{'post_id': 5})
def test_trying_to_rate_without_privileges(test_ctx):
post = test_ctx.post_factory()
def test_trying_to_rate_without_privileges(
user_factory, post_factory, context_factory):
post = post_factory()
db.session.add(post)
db.session.commit()
with pytest.raises(errors.AuthError):
test_ctx.api.put(
test_ctx.context_factory(
input={'score': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)),
post.post_id)
api.post_api.set_post_score(
context_factory(
params={'score': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
{'post_id': post.post_id})

View file

@ -1,105 +1,97 @@
import datetime
import pytest
import unittest.mock
from datetime import datetime
from szurubooru import api, db, errors
from szurubooru.func import util, posts
from szurubooru.func import posts
@pytest.fixture
def test_ctx(
tmpdir, context_factory, config_injector, user_factory, post_factory):
@pytest.fixture(autouse=True)
def inject_config(tmpdir, config_injector):
config_injector({
'data_dir': str(tmpdir),
'data_url': 'http://example.com',
'privileges': {
'posts:list': db.User.RANK_REGULAR,
'posts:view': db.User.RANK_REGULAR,
},
'thumbnails': {'avatar_width': 200},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.post_factory = post_factory
ret.list_api = api.PostListApi()
ret.detail_api = api.PostDetailApi()
return ret
def test_retrieving_multiple(test_ctx):
post1 = test_ctx.post_factory(id=1)
post2 = test_ctx.post_factory(id=2)
def test_retrieving_multiple(user_factory, post_factory, context_factory):
post1 = post_factory(id=1)
post2 = post_factory(id=2)
db.session.add_all([post1, post2])
result = test_ctx.list_api.get(
test_ctx.context_factory(
input={'query': '', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert result['query'] == ''
assert result['page'] == 1
assert result['pageSize'] == 100
assert result['total'] == 2
assert [t['id'] for t in result['results']] == [2, 1]
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
posts.serialize_post.return_value = 'serialized post'
result = api.post_api.get_posts(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
assert result == {
'query': '',
'page': 1,
'pageSize': 100,
'total': 2,
'results': ['serialized post', 'serialized post'],
}
def test_using_special_tokens(
test_ctx, config_injector):
auth_user = test_ctx.user_factory(rank=db.User.RANK_REGULAR)
post1 = test_ctx.post_factory(id=1)
post2 = test_ctx.post_factory(id=2)
def test_using_special_tokens(user_factory, post_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
post1 = post_factory(id=1)
post2 = post_factory(id=2)
post1.favorited_by = [db.PostFavorite(
user=auth_user, time=datetime.datetime.utcnow())]
user=auth_user, time=datetime.utcnow())]
db.session.add_all([post1, post2, auth_user])
db.session.flush()
result = test_ctx.list_api.get(
test_ctx.context_factory(
input={'query': 'special:fav', 'page': 1},
user=auth_user))
assert result['query'] == 'special:fav'
assert result['page'] == 1
assert result['pageSize'] == 100
assert result['total'] == 1
assert [t['id'] for t in result['results']] == [1]
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
posts.serialize_post.side_effect = \
lambda post, *_args, **_kwargs: \
'serialized post %d' % post.post_id
result = api.post_api.get_posts(
context_factory(
params={'query': 'special:fav', 'page': 1},
user=auth_user))
assert result == {
'query': 'special:fav',
'page': 1,
'pageSize': 100,
'total': 1,
'results': ['serialized post 1'],
}
def test_trying_to_use_special_tokens_without_logging_in(
test_ctx, config_injector):
user_factory, post_factory, context_factory, config_injector):
config_injector({
'privileges': {'posts:list': 'anonymous'},
})
with pytest.raises(errors.SearchError):
test_ctx.list_api.get(
test_ctx.context_factory(
input={'query': 'special:fav', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
api.post_api.get_posts(
context_factory(
params={'query': 'special:fav', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_trying_to_retrieve_multiple_without_privileges(test_ctx):
def test_trying_to_retrieve_multiple_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError):
test_ctx.list_api.get(
test_ctx.context_factory(
input={'query': '', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
api.post_api.get_posts(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_retrieving_single(test_ctx):
db.session.add(test_ctx.post_factory(id=1))
result = test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 1)
assert 'id' in result
assert 'snapshots' in result
assert 'comments' in result
def test_retrieving_single(user_factory, post_factory, context_factory):
db.session.add(post_factory(id=1))
with unittest.mock.patch('szurubooru.func.posts.serialize_post'):
posts.serialize_post.return_value = 'serialized post'
result = api.post_api.get_post(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
{'post_id': 1})
assert result == 'serialized post'
def test_trying_to_retrieve_invalid_id(test_ctx):
with pytest.raises(posts.InvalidPostIdError):
test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'-')
def test_trying_to_retrieve_single_non_existing(test_ctx):
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(posts.PostNotFoundError):
test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'999')
api.post_api.get_post(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
{'post_id': 999})
def test_trying_to_retrieve_single_without_privileges(test_ctx):
def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError):
test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)),
'999')
api.post_api.get_post(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
{'post_id': 999})

View file

@ -1,12 +1,11 @@
import datetime
import os
import unittest.mock
import pytest
import unittest.mock
from datetime import datetime
from szurubooru import api, db, errors
from szurubooru.func import posts, tags, snapshots, net
def test_post_updating(
config_injector, context_factory, post_factory, user_factory, fake_datetime):
@pytest.fixture(autouse=True)
def inject_config(tmpdir, config_injector):
config_injector({
'privileges': {
'posts:edit:tags': db.User.RANK_REGULAR,
@ -17,46 +16,49 @@ def test_post_updating(
'posts:edit:notes': db.User.RANK_REGULAR,
'posts:edit:flags': db.User.RANK_REGULAR,
'posts:edit:thumbnail': db.User.RANK_REGULAR,
'tags:create': db.User.RANK_MODERATOR,
},
})
def test_post_updating(
context_factory, post_factory, user_factory, fake_datetime):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
post = post_factory()
db.session.add(post)
db.session.flush()
with unittest.mock.patch('szurubooru.func.posts.create_post'), \
unittest.mock.patch('szurubooru.func.posts.update_post_tags'), \
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \
unittest.mock.patch('szurubooru.func.posts.update_post_thumbnail'), \
unittest.mock.patch('szurubooru.func.posts.update_post_safety'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'), \
unittest.mock.patch('szurubooru.func.posts.update_post_relations'), \
unittest.mock.patch('szurubooru.func.posts.update_post_notes'), \
unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \
unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'):
unittest.mock.patch('szurubooru.func.posts.update_post_tags'), \
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \
unittest.mock.patch('szurubooru.func.posts.update_post_thumbnail'), \
unittest.mock.patch('szurubooru.func.posts.update_post_safety'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'), \
unittest.mock.patch('szurubooru.func.posts.update_post_relations'), \
unittest.mock.patch('szurubooru.func.posts.update_post_notes'), \
unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \
unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \
fake_datetime('1997-01-01'):
posts.serialize_post.return_value = 'serialized post'
with fake_datetime('1997-01-01'):
result = api.PostDetailApi().put(
context_factory(
input={
'version': 1,
'safety': 'safe',
'tags': ['tag1', 'tag2'],
'relations': [1, 2],
'source': 'source',
'notes': ['note1', 'note2'],
'flags': ['flag1', 'flag2'],
},
files={
'content': 'post-content',
'thumbnail': 'post-thumbnail',
},
user=auth_user),
post.post_id)
result = api.post_api.update_post(
context_factory(
params={
'version': 1,
'safety': 'safe',
'tags': ['tag1', 'tag2'],
'relations': [1, 2],
'source': 'source',
'notes': ['note1', 'note2'],
'flags': ['flag1', 'flag2'],
},
files={
'content': 'post-content',
'thumbnail': 'post-thumbnail',
},
user=auth_user),
{'post_id': post.post_id})
assert result == 'serialized post'
posts.create_post.assert_not_called()
@ -71,71 +73,62 @@ def test_post_updating(
posts.serialize_post.assert_called_once_with(post, auth_user, options=None)
tags.export_to_json.assert_called_once_with()
snapshots.save_entity_modification.assert_called_once_with(post, auth_user)
assert post.last_edit_time == datetime.datetime(1997, 1, 1)
assert post.last_edit_time == datetime(1997, 1, 1)
def test_uploading_from_url_saves_source(
config_injector, context_factory, post_factory, user_factory):
config_injector({
'privileges': {'posts:edit:content': db.User.RANK_REGULAR},
})
context_factory, post_factory, user_factory):
post = post_factory()
db.session.add(post)
db.session.flush()
with unittest.mock.patch('szurubooru.func.net.download'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \
unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'):
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \
unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'):
net.download.return_value = b'content'
api.PostDetailApi().put(
api.post_api.update_post(
context_factory(
input={'contentUrl': 'example.com', 'version': 1},
params={'contentUrl': 'example.com', 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
post.post_id)
{'post_id': post.post_id})
net.download.assert_called_once_with('example.com')
posts.update_post_content.assert_called_once_with(post, b'content')
posts.update_post_source.assert_called_once_with(post, 'example.com')
def test_uploading_from_url_with_source_specified(
config_injector, context_factory, post_factory, user_factory):
config_injector({
'privileges': {
'posts:edit:content': db.User.RANK_REGULAR,
'posts:edit:source': db.User.RANK_REGULAR,
},
})
context_factory, post_factory, user_factory):
post = post_factory()
db.session.add(post)
db.session.flush()
with unittest.mock.patch('szurubooru.func.net.download'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \
unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'):
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \
unittest.mock.patch('szurubooru.func.snapshots.save_entity_modification'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post'), \
unittest.mock.patch('szurubooru.func.posts.update_post_content'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'):
net.download.return_value = b'content'
api.PostDetailApi().put(
api.post_api.update_post(
context_factory(
input={
params={
'contentUrl': 'example.com',
'source': 'example2.com',
'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
post.post_id)
{'post_id': post.post_id})
net.download.assert_called_once_with('example.com')
posts.update_post_content.assert_called_once_with(post, b'content')
posts.update_post_source.assert_called_once_with(post, 'example2.com')
def test_trying_to_update_non_existing(context_factory, user_factory):
with pytest.raises(posts.PostNotFoundError):
api.PostDetailApi().put(
api.post_api.update_post(
context_factory(
input='whatever',
params='whatever',
user=user_factory(rank=db.User.RANK_REGULAR)),
1)
{'post_id': 1})
@pytest.mark.parametrize('privilege,files,input', [
@pytest.mark.parametrize('privilege,files,params', [
('posts:edit:tags', {}, {'tags': '...'}),
('posts:edit:safety', {}, {'safety': '...'}),
('posts:edit:source', {}, {'source': '...'}),
@ -146,43 +139,28 @@ def test_trying_to_update_non_existing(context_factory, user_factory):
('posts:edit:thumbnail', {'thumbnail': '...'}, {}),
])
def test_trying_to_update_field_without_privileges(
config_injector,
context_factory,
post_factory,
user_factory,
files,
input,
privilege):
config_injector({
'privileges': {privilege: db.User.RANK_REGULAR},
})
context_factory, post_factory, user_factory, files, params, privilege):
post = post_factory()
db.session.add(post)
db.session.flush()
with pytest.raises(errors.AuthError):
api.PostDetailApi().put(
api.post_api.update_post(
context_factory(
input={**input, **{'version': 1}},
params={**params, **{'version': 1}},
files=files,
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
post.post_id)
{'post_id': post.post_id})
def test_trying_to_create_tags_without_privileges(
config_injector, context_factory, post_factory, user_factory):
config_injector({
'privileges': {
'posts:edit:tags': db.User.RANK_REGULAR,
'tags:create': db.User.RANK_ADMINISTRATOR,
},
})
context_factory, post_factory, user_factory):
post = post_factory()
db.session.add(post)
db.session.flush()
with pytest.raises(errors.AuthError), \
unittest.mock.patch('szurubooru.func.posts.update_post_tags'):
posts.update_post_tags.return_value = ['new-tag']
api.PostDetailApi().put(
api.post_api.update_post(
context_factory(
input={'tags': ['tag1', 'tag2'], 'version': 1},
params={'tags': ['tag1', 'tag2'], 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
post.post_id)
{'post_id': post.post_id})

View file

@ -1,11 +1,10 @@
import datetime
import pytest
from datetime import datetime
from szurubooru import api, db, errors
from szurubooru.func import util, tags
def snapshot_factory():
snapshot = db.Snapshot()
snapshot.creation_time = datetime.datetime(1999, 1, 1)
snapshot.creation_time = datetime(1999, 1, 1)
snapshot.resource_type = 'dummy'
snapshot.resource_id = 1
snapshot.resource_repr = 'dummy'
@ -13,37 +12,30 @@ def snapshot_factory():
snapshot.data = '{}'
return snapshot
@pytest.fixture
def test_ctx(context_factory, config_injector, user_factory):
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'privileges': {
'snapshots:list': db.User.RANK_REGULAR,
},
'thumbnails': {'avatar_width': 200},
'privileges': {'snapshots:list': db.User.RANK_REGULAR},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.api = api.SnapshotListApi()
return ret
def test_retrieving_multiple(test_ctx):
def test_retrieving_multiple(user_factory, context_factory):
snapshot1 = snapshot_factory()
snapshot2 = snapshot_factory()
db.session.add_all([snapshot1, snapshot2])
result = test_ctx.api.get(
test_ctx.context_factory(
input={'query': '', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
result = api.snapshot_api.get_snapshots(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
assert result['query'] == ''
assert result['page'] == 1
assert result['pageSize'] == 100
assert result['total'] == 2
assert len(result['results']) == 2
def test_trying_to_retrieve_multiple_without_privileges(test_ctx):
def test_trying_to_retrieve_multiple_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError):
test_ctx.api.get(
test_ctx.context_factory(
input={'query': '', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
api.snapshot_api.get_snapshots(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))

View file

@ -1,94 +1,50 @@
import os
import pytest
from szurubooru import api, config, db, errors
from szurubooru.func import util, tag_categories
import unittest.mock
from szurubooru import api, db, errors
from szurubooru.func import tag_categories, tags
@pytest.fixture
def test_ctx(tmpdir, config_injector, context_factory, user_factory):
def _update_category_name(category, name):
category.name = name
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'data_dir': str(tmpdir),
'tag_category_name_regex': '^[^!]+$',
'privileges': {'tag_categories:create': db.User.RANK_REGULAR},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.api = api.TagCategoryListApi()
return ret
def test_creating_category(test_ctx):
result = test_ctx.api.post(
test_ctx.context_factory(
input={'name': 'meta', 'color': 'black'},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert len(result['snapshots']) == 1
del result['snapshots']
assert result == {
'name': 'meta',
'color': 'black',
'usages': 0,
'default': True,
'version': 1,
}
category = db.session.query(db.TagCategory).one()
assert category.name == 'meta'
assert category.color == 'black'
assert category.tag_count == 0
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json'))
@pytest.mark.parametrize('input', [
{'name': None},
{'name': ''},
{'name': '!bad'},
{'color': None},
{'color': ''},
{'color': 'a' * 100},
])
def test_trying_to_pass_invalid_input(test_ctx, input):
real_input = {
'name': 'okay',
'color': 'okay',
}
for key, value in input.items():
real_input[key] = value
with pytest.raises(errors.ValidationError):
test_ctx.api.post(
test_ctx.context_factory(
input=real_input,
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
def test_creating_category(user_factory, context_factory):
with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \
unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
tag_categories.update_category_name.side_effect = _update_category_name
tag_categories.serialize_category.return_value = 'serialized category'
result = api.tag_category_api.create_tag_category(
context_factory(
params={'name': 'meta', 'color': 'black'},
user=user_factory(rank=db.User.RANK_REGULAR)))
assert result == 'serialized category'
category = db.session.query(db.TagCategory).one()
assert category.name == 'meta'
assert category.color == 'black'
assert category.tag_count == 0
tags.export_to_json.assert_called_once_with()
@pytest.mark.parametrize('field', ['name', 'color'])
def test_trying_to_omit_mandatory_field(test_ctx, field):
input = {
def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
params = {
'name': 'meta',
'color': 'black',
}
del input[field]
del params[field]
with pytest.raises(errors.ValidationError):
test_ctx.api.post(
test_ctx.context_factory(
input=input,
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
api.tag_category_api.create_tag_category(
context_factory(
params=params,
user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_use_existing_name(test_ctx):
result = test_ctx.api.post(
test_ctx.context_factory(
input={'name': 'meta', 'color': 'black'},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
with pytest.raises(tag_categories.TagCategoryAlreadyExistsError):
result = test_ctx.api.post(
test_ctx.context_factory(
input={'name': 'meta', 'color': 'black'},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
with pytest.raises(tag_categories.TagCategoryAlreadyExistsError):
result = test_ctx.api.post(
test_ctx.context_factory(
input={'name': 'META', 'color': 'black'},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_create_without_privileges(test_ctx):
def test_trying_to_create_without_privileges(user_factory, context_factory):
with pytest.raises(errors.AuthError):
test_ctx.api.post(
test_ctx.context_factory(
input={'name': 'meta', 'color': 'black'},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
api.tag_category_api.create_tag_category(
context_factory(
params={'name': 'meta', 'color': 'black'},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))

View file

@ -1,84 +1,70 @@
import pytest
import os
from datetime import datetime
from szurubooru import api, config, db, errors
from szurubooru.func import util, tags, tag_categories
import unittest.mock
from szurubooru import api, db, errors
from szurubooru.func import tag_categories, tags
@pytest.fixture
def test_ctx(
tmpdir,
config_injector,
context_factory,
tag_factory,
tag_category_factory,
user_factory):
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'data_dir': str(tmpdir),
'privileges': {
'tag_categories:delete': db.User.RANK_REGULAR,
},
'privileges': {'tag_categories:delete': db.User.RANK_REGULAR},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_factory = tag_factory
ret.tag_category_factory = tag_category_factory
ret.api = api.TagCategoryDetailApi()
return ret
def test_deleting(test_ctx):
db.session.add(test_ctx.tag_category_factory(name='root'))
db.session.add(test_ctx.tag_category_factory(name='category'))
def test_deleting(user_factory, tag_category_factory, context_factory):
db.session.add(tag_category_factory(name='root'))
db.session.add(tag_category_factory(name='category'))
db.session.commit()
result = test_ctx.api.delete(
test_ctx.context_factory(
input={'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'category')
assert result == {}
assert db.session.query(db.TagCategory).count() == 1
assert db.session.query(db.TagCategory).one().name == 'root'
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json'))
with unittest.mock.patch('szurubooru.func.tags.export_to_json'):
result = api.tag_category_api.delete_tag_category(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': 'category'})
assert result == {}
assert db.session.query(db.TagCategory).count() == 1
assert db.session.query(db.TagCategory).one().name == 'root'
tags.export_to_json.assert_called_once_with()
def test_trying_to_delete_used(test_ctx, tag_factory):
category = test_ctx.tag_category_factory(name='category')
def test_trying_to_delete_used(
user_factory, tag_category_factory, tag_factory, context_factory):
category = tag_category_factory(name='category')
db.session.add(category)
db.session.flush()
tag = test_ctx.tag_factory(names=['tag'], category=category)
tag = tag_factory(names=['tag'], category=category)
db.session.add(tag)
db.session.commit()
with pytest.raises(tag_categories.TagCategoryIsInUseError):
test_ctx.api.delete(
test_ctx.context_factory(
input={'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'category')
api.tag_category_api.delete_tag_category(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': 'category'})
assert db.session.query(db.TagCategory).count() == 1
def test_trying_to_delete_last(test_ctx, tag_factory):
db.session.add(test_ctx.tag_category_factory(name='root'))
def test_trying_to_delete_last(
user_factory, tag_category_factory, context_factory):
db.session.add(tag_category_factory(name='root'))
db.session.commit()
with pytest.raises(tag_categories.TagCategoryIsInUseError):
result = test_ctx.api.delete(
test_ctx.context_factory(
input={'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'root')
api.tag_category_api.delete_tag_category(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': 'root'})
def test_trying_to_delete_non_existing(test_ctx):
def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(tag_categories.TagCategoryNotFoundError):
test_ctx.api.delete(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'bad')
api.tag_category_api.delete_tag_category(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': 'bad'})
def test_trying_to_delete_without_privileges(test_ctx):
db.session.add(test_ctx.tag_category_factory(name='category'))
def test_trying_to_delete_without_privileges(
user_factory, tag_category_factory, context_factory):
db.session.add(tag_category_factory(name='category'))
db.session.commit()
with pytest.raises(errors.AuthError):
test_ctx.api.delete(
test_ctx.context_factory(
input={'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)),
'category')
api.tag_category_api.delete_tag_category(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
{'category_name': 'category'})
assert db.session.query(db.TagCategory).count() == 1

View file

@ -1,42 +1,31 @@
import datetime
import pytest
from szurubooru import api, db, errors
from szurubooru.func import util, tag_categories
from szurubooru.func import tag_categories
@pytest.fixture
def test_ctx(
context_factory, config_injector, user_factory, tag_category_factory):
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'privileges': {
'tag_categories:list': db.User.RANK_REGULAR,
'tag_categories:view': db.User.RANK_REGULAR,
},
'thumbnails': {'avatar_width': 200},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_category_factory = tag_category_factory
ret.list_api = api.TagCategoryListApi()
ret.detail_api = api.TagCategoryDetailApi()
return ret
def test_retrieving_multiple(test_ctx):
def test_retrieving_multiple(
user_factory, tag_category_factory, context_factory):
db.session.add_all([
test_ctx.tag_category_factory(name='c1'),
test_ctx.tag_category_factory(name='c2'),
tag_category_factory(name='c1'),
tag_category_factory(name='c2'),
])
result = test_ctx.list_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
result = api.tag_category_api.get_tag_categories(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)))
assert [cat['name'] for cat in result['results']] == ['c1', 'c2']
def test_retrieving_single(test_ctx):
db.session.add(test_ctx.tag_category_factory(name='cat'))
result = test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'cat')
def test_retrieving_single(user_factory, tag_category_factory, context_factory):
db.session.add(tag_category_factory(name='cat'))
result = api.tag_category_api.get_tag_category(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': 'cat'})
assert result == {
'name': 'cat',
'color': 'dummy',
@ -46,16 +35,15 @@ def test_retrieving_single(test_ctx):
'version': 1,
}
def test_trying_to_retrieve_single_non_existing(test_ctx):
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(tag_categories.TagCategoryNotFoundError):
test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'-')
api.tag_category_api.get_tag_category(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': '-'})
def test_trying_to_retrieve_single_without_privileges(test_ctx):
def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError):
test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)),
'-')
api.tag_category_api.get_tag_category(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
{'category_name': '-'})

View file

@ -1,137 +1,104 @@
import os
import pytest
from szurubooru import api, config, db, errors
from szurubooru.func import util, tag_categories
import unittest.mock
from szurubooru import api, db, errors
from szurubooru.func import tag_categories, tags
@pytest.fixture
def test_ctx(
tmpdir,
config_injector,
context_factory,
user_factory,
tag_category_factory):
def _update_category_name(category, name):
category.name = name
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'data_dir': str(tmpdir),
'tag_category_name_regex': '^[^!]*$',
'privileges': {
'tag_categories:edit:name': db.User.RANK_REGULAR,
'tag_categories:edit:color': db.User.RANK_REGULAR,
'tag_categories:set_default': db.User.RANK_REGULAR,
},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_category_factory = tag_category_factory
ret.api = api.TagCategoryDetailApi()
return ret
def test_simple_updating(test_ctx):
category = test_ctx.tag_category_factory(name='name', color='black')
def test_simple_updating(user_factory, tag_category_factory, context_factory):
category = tag_category_factory(name='name', color='black')
db.session.add(category)
db.session.commit()
result = test_ctx.api.put(
test_ctx.context_factory(
input={
'name': 'changed',
'color': 'white',
'version': 1,
},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'name')
assert len(result['snapshots']) == 1
del result['snapshots']
assert result == {
'name': 'changed',
'color': 'white',
'usages': 0,
'default': False,
'version': 2,
}
assert tag_categories.try_get_category_by_name('name') is None
category = tag_categories.get_category_by_name('changed')
assert category is not None
assert category.name == 'changed'
assert category.color == 'white'
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json'))
@pytest.mark.parametrize('input,expected_exception', [
({'name': None}, tag_categories.InvalidTagCategoryNameError),
({'name': ''}, tag_categories.InvalidTagCategoryNameError),
({'name': '!bad'}, tag_categories.InvalidTagCategoryNameError),
({'color': None}, tag_categories.InvalidTagCategoryColorError),
({'color': ''}, tag_categories.InvalidTagCategoryColorError),
({'color': '; float:left'}, tag_categories.InvalidTagCategoryColorError),
])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
db.session.add(test_ctx.tag_category_factory(name='meta', color='black'))
db.session.commit()
with pytest.raises(expected_exception):
test_ctx.api.put(
test_ctx.context_factory(
input={**input, **{'version': 1}},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'meta')
with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \
unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \
unittest.mock.patch('szurubooru.func.tag_categories.update_category_color'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
tag_categories.update_category_name.side_effect = _update_category_name
tag_categories.serialize_category.return_value = 'serialized category'
result = api.tag_category_api.update_tag_category(
context_factory(
params={
'name': 'changed',
'color': 'white',
'version': 1,
},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': 'name'})
assert result == 'serialized category'
tag_categories.update_category_name.assert_called_once_with(category, 'changed')
tag_categories.update_category_color.assert_called_once_with(category, 'white')
tags.export_to_json.assert_called_once_with()
@pytest.mark.parametrize('field', ['name', 'color'])
def test_omitting_optional_field(test_ctx, field):
db.session.add(test_ctx.tag_category_factory(name='name', color='black'))
def test_omitting_optional_field(
user_factory, tag_category_factory, context_factory, field):
db.session.add(tag_category_factory(name='name', color='black'))
db.session.commit()
input = {
params = {
'name': 'changed',
'color': 'white',
}
del input[field]
result = test_ctx.api.put(
test_ctx.context_factory(
input={**input, **{'version': 1}},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'name')
assert result is not None
del params[field]
with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \
unittest.mock.patch('szurubooru.func.tag_categories.update_category_name'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
api.tag_category_api.update_tag_category(
context_factory(
params={**params, **{'version': 1}},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': 'name'})
def test_trying_to_update_non_existing(test_ctx):
def test_trying_to_update_non_existing(user_factory, context_factory):
with pytest.raises(tag_categories.TagCategoryNotFoundError):
test_ctx.api.put(
test_ctx.context_factory(
input={'name': ['dummy']},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'bad')
api.tag_category_api.update_tag_category(
context_factory(
params={'name': ['dummy']},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': 'bad'})
@pytest.mark.parametrize('new_name', ['cat', 'CAT'])
def test_reusing_own_name(test_ctx, new_name):
db.session.add(test_ctx.tag_category_factory(name='cat', color='black'))
db.session.commit()
result = test_ctx.api.put(
test_ctx.context_factory(
input={'name': new_name, 'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'cat')
assert result['name'] == new_name
category = tag_categories.get_category_by_name('cat')
assert category.name == new_name
@pytest.mark.parametrize('dup_name', ['cat1', 'CAT1'])
def test_trying_to_use_existing_name(test_ctx, dup_name):
db.session.add_all([
test_ctx.tag_category_factory(name='cat1', color='black'),
test_ctx.tag_category_factory(name='cat2', color='black')])
db.session.commit()
with pytest.raises(tag_categories.TagCategoryAlreadyExistsError):
test_ctx.api.put(
test_ctx.context_factory(
input={'name': dup_name, 'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'cat2')
@pytest.mark.parametrize('input', [
@pytest.mark.parametrize('params', [
{'name': 'whatever'},
{'color': 'whatever'},
])
def test_trying_to_update_without_privileges(test_ctx, input):
db.session.add(test_ctx.tag_category_factory(name='dummy'))
def test_trying_to_update_without_privileges(
user_factory, tag_category_factory, context_factory, params):
db.session.add(tag_category_factory(name='dummy'))
db.session.commit()
with pytest.raises(errors.AuthError):
test_ctx.api.put(
test_ctx.context_factory(
input={**input, **{'version': 1}},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)),
'dummy')
api.tag_category_api.update_tag_category(
context_factory(
params={**params, **{'version': 1}},
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
{'category_name': 'dummy'})
def test_set_as_default(user_factory, tag_category_factory, context_factory):
category = tag_category_factory(name='name', color='black')
db.session.add(category)
db.session.commit()
with unittest.mock.patch('szurubooru.func.tag_categories.serialize_category'), \
unittest.mock.patch('szurubooru.func.tag_categories.set_default_category'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
tag_categories.update_category_name.side_effect = _update_category_name
tag_categories.serialize_category.return_value = 'serialized category'
result = api.tag_category_api.set_tag_category_as_default(
context_factory(
params={
'name': 'changed',
'color': 'white',
'version': 1,
},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'category_name': 'name'})
assert result == 'serialized category'
tag_categories.set_default_category.assert_called_once_with(category)

View file

@ -1,187 +1,77 @@
import datetime
import os
import pytest
from szurubooru import api, config, db, errors
from szurubooru.func import util, tags, tag_categories, cache
import unittest.mock
from szurubooru import api, db, errors
from szurubooru.func import tags, tag_categories
def assert_relations(relations, expected_tag_names):
actual_names = sorted([rel.names[0].name for rel in relations])
assert actual_names == sorted(expected_tag_names)
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'tags:create': db.User.RANK_REGULAR}})
@pytest.fixture
def test_ctx(
tmpdir, config_injector, context_factory, user_factory, tag_factory):
config_injector({
'data_dir': str(tmpdir),
'tag_name_regex': '^[^!]*$',
'tag_category_name_regex': '^[^!]*$',
'privileges': {'tags:create': db.User.RANK_REGULAR},
})
db.session.add_all([
db.TagCategory(name) for name in ['meta', 'character', 'copyright']])
db.session.flush()
cache.purge()
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_factory = tag_factory
ret.api = api.TagListApi()
return ret
def test_creating_simple_tags(test_ctx, fake_datetime):
with fake_datetime('1997-12-01'):
result = test_ctx.api.post(
test_ctx.context_factory(
input={
def test_creating_simple_tags(tag_factory, user_factory, context_factory):
with unittest.mock.patch('szurubooru.func.tags.create_tag'), \
unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'), \
unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
tags.get_or_create_tags_by_names.return_value = ([], [])
tags.create_tag.return_value = tag_factory()
tags.serialize_tag.return_value = 'serialized tag'
result = api.tag_api.create_tag(
context_factory(
params={
'names': ['tag1', 'tag2'],
'category': 'meta',
'description': 'desc',
'suggestions': [],
'implications': [],
'suggestions': ['sug1', 'sug2'],
'implications': ['imp1', 'imp2'],
},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert len(result['snapshots']) == 1
del result['snapshots']
assert result == {
'names': ['tag1', 'tag2'],
'category': 'meta',
'description': 'desc',
'suggestions': [],
'implications': [],
'creationTime': datetime.datetime(1997, 12, 1),
'lastEditTime': None,
'usages': 0,
'version': 1,
}
tag = tags.get_tag_by_name('tag1')
assert [tag_name.name for tag_name in tag.names] == ['tag1', 'tag2']
assert tag.category.name == 'meta'
assert tag.last_edit_time is None
assert tag.post_count == 0
assert_relations(tag.suggestions, [])
assert_relations(tag.implications, [])
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json'))
@pytest.mark.parametrize('input,expected_exception', [
({'names': None}, tags.InvalidTagNameError),
({'names': []}, tags.InvalidTagNameError),
({'names': [None]}, tags.InvalidTagNameError),
({'names': ['']}, tags.InvalidTagNameError),
({'names': ['!bad']}, tags.InvalidTagNameError),
({'names': ['x' * 65]}, tags.InvalidTagNameError),
({'category': None}, tag_categories.TagCategoryNotFoundError),
({'category': ''}, tag_categories.TagCategoryNotFoundError),
({'category': '!bad'}, tag_categories.TagCategoryNotFoundError),
({'suggestions': ['good', '!bad']}, tags.InvalidTagNameError),
({'implications': ['good', '!bad']}, tags.InvalidTagNameError),
])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
real_input={
'names': ['tag1', 'tag2'],
'category': 'meta',
'suggestions': [],
'implications': [],
}
for key, value in input.items():
real_input[key] = value
with pytest.raises(expected_exception):
test_ctx.api.post(
test_ctx.context_factory(
input=real_input,
user=test_ctx.user_factory()))
user=user_factory(rank=db.User.RANK_REGULAR)))
assert result == 'serialized tag'
tags.create_tag.assert_called_once_with(
['tag1', 'tag2'], 'meta', ['sug1', 'sug2'], ['imp1', 'imp2'])
tags.export_to_json.assert_called_once_with()
@pytest.mark.parametrize('field', ['names', 'category'])
def test_trying_to_omit_mandatory_field(test_ctx, field):
input = {
def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
params = {
'names': ['tag1', 'tag2'],
'category': 'meta',
'suggestions': [],
'implications': [],
}
del input[field]
del params[field]
with pytest.raises(errors.ValidationError):
test_ctx.api.post(
test_ctx.context_factory(
input=input,
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
api.tag_api.create_tag(
context_factory(
params=params,
user=user_factory(rank=db.User.RANK_REGULAR)))
@pytest.mark.parametrize('field', ['implications', 'suggestions'])
def test_omitting_optional_field(test_ctx, field):
input = {
def test_omitting_optional_field(
tag_factory, user_factory, context_factory, field):
params = {
'names': ['tag1', 'tag2'],
'category': 'meta',
'suggestions': [],
'implications': [],
}
del input[field]
result = test_ctx.api.post(
test_ctx.context_factory(
input=input,
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert result is not None
del params[field]
with unittest.mock.patch('szurubooru.func.tags.create_tag'), \
unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
tags.create_tag.return_value = tag_factory()
api.tag_api.create_tag(
context_factory(
params=params,
user=user_factory(rank=db.User.RANK_REGULAR)))
def test_creating_new_category(test_ctx):
with pytest.raises(tag_categories.TagCategoryNotFoundError):
test_ctx.api.post(
test_ctx.context_factory(
input={
'names': ['main'],
'category': 'new',
'suggestions': [],
'implications': [],
}, user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
@pytest.mark.parametrize('input,expected_suggestions,expected_implications', [
# new relations
({
'names': ['main'],
'category': 'meta',
'suggestions': ['sug1', 'sug2'],
'implications': ['imp1', 'imp2'],
}, ['sug1', 'sug2'], ['imp1', 'imp2']),
# overlapping relations
({
'names': ['main'],
'category': 'meta',
'suggestions': ['sug', 'shared'],
'implications': ['shared', 'imp'],
}, ['shared', 'sug'], ['imp', 'shared']),
# duplicate relations
({
'names': ['main'],
'category': 'meta',
'suggestions': ['sug', 'SUG'],
'implications': ['imp', 'IMP'],
}, ['sug'], ['imp']),
# overlapping duplicate relations
({
'names': ['main'],
'category': 'meta',
'suggestions': ['shared1', 'shared2'],
'implications': ['SHARED1', 'SHARED2'],
}, ['shared1', 'shared2'], ['shared1', 'shared2']),
])
def test_creating_new_suggestions_and_implications(
test_ctx, input, expected_suggestions, expected_implications):
result = test_ctx.api.post(
test_ctx.context_factory(
input=input, user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert result['suggestions'] == expected_suggestions
assert result['implications'] == expected_implications
tag = tags.get_tag_by_name('main')
assert_relations(tag.suggestions, expected_suggestions)
assert_relations(tag.implications, expected_implications)
for name in ['main'] + expected_suggestions + expected_implications:
assert tags.try_get_tag_by_name(name) is not None
def test_trying_to_create_tag_without_privileges(test_ctx):
def test_trying_to_create_tag_without_privileges(user_factory, context_factory):
with pytest.raises(errors.AuthError):
test_ctx.api.post(
test_ctx.context_factory(
input={
api.tag_api.create_tag(
context_factory(
params={
'names': ['tag'],
'category': 'meta',
'suggestions': ['tag'],
'implications': [],
},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
user=user_factory(rank=db.User.RANK_ANONYMOUS)))

View file

@ -1,50 +1,55 @@
import pytest
import os
from datetime import datetime
from szurubooru import api, config, db, errors
from szurubooru.func import util, tags
import unittest.mock
from szurubooru import api, db, errors
from szurubooru.func import tags
@pytest.fixture
def test_ctx(
tmpdir, config_injector, context_factory, tag_factory, user_factory):
config_injector({
'data_dir': str(tmpdir),
'privileges': {
'tags:delete': db.User.RANK_REGULAR,
},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_factory = tag_factory
ret.api = api.TagDetailApi()
return ret
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'tags:delete': db.User.RANK_REGULAR}})
def test_deleting(test_ctx):
db.session.add(test_ctx.tag_factory(names=['tag']))
def test_deleting(user_factory, tag_factory, context_factory):
db.session.add(tag_factory(names=['tag']))
db.session.commit()
result = test_ctx.api.delete(
test_ctx.context_factory(
input={'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'tag')
assert result == {}
assert db.session.query(db.Tag).count() == 0
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json'))
with unittest.mock.patch('szurubooru.func.tags.export_to_json'):
result = api.tag_api.delete_tag(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'tag'})
assert result == {}
assert db.session.query(db.Tag).count() == 0
tags.export_to_json.assert_called_once_with()
def test_trying_to_delete_non_existing(test_ctx):
def test_deleting_used(user_factory, tag_factory, context_factory, post_factory):
tag = tag_factory(names=['tag'])
post = post_factory()
post.tags.append(tag)
db.session.add_all([tag, post])
db.session.commit()
with unittest.mock.patch('szurubooru.func.tags.export_to_json'):
api.tag_api.delete_tag(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'tag'})
db.session.refresh(post)
assert db.session.query(db.Tag).count() == 0
assert post.tags == []
def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(tags.TagNotFoundError):
test_ctx.api.delete(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'bad')
api.tag_api.delete_tag(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'bad'})
def test_trying_to_delete_without_privileges(test_ctx):
db.session.add(test_ctx.tag_factory(names=['tag']))
def test_trying_to_delete_without_privileges(
user_factory, tag_factory, context_factory):
db.session.add(tag_factory(names=['tag']))
db.session.commit()
with pytest.raises(errors.AuthError):
test_ctx.api.delete(
test_ctx.context_factory(
input={'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)),
'tag')
api.tag_api.delete_tag(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
{'tag_name': 'tag'})
assert db.session.query(db.Tag).count() == 1

View file

@ -1,34 +1,15 @@
import datetime
import os
import pytest
from szurubooru import api, config, db, errors
from szurubooru.func import util, tags
import unittest.mock
from szurubooru import api, db, errors
from szurubooru.func import tags
@pytest.fixture
def test_ctx(
tmpdir,
config_injector,
context_factory,
user_factory,
tag_factory,
tag_category_factory):
config_injector({
'data_dir': str(tmpdir),
'privileges': {
'tags:merge': db.User.RANK_REGULAR,
},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_factory = tag_factory
ret.tag_category_factory = tag_category_factory
ret.api = api.TagMergeApi()
return ret
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'tags:merge': db.User.RANK_REGULAR}})
def test_merging_with_usages(test_ctx, fake_datetime, post_factory):
source_tag = test_ctx.tag_factory(names=['source'])
target_tag = test_ctx.tag_factory(names=['target'])
def test_merging(user_factory, tag_factory, context_factory, post_factory):
source_tag = tag_factory(names=['source'])
target_tag = tag_factory(names=['target'])
db.session.add_all([source_tag, target_tag])
db.session.flush()
assert source_tag.post_count == 0
@ -39,73 +20,78 @@ def test_merging_with_usages(test_ctx, fake_datetime, post_factory):
db.session.commit()
assert source_tag.post_count == 1
assert target_tag.post_count == 0
with fake_datetime('1997-12-01'):
result = test_ctx.api.post(
test_ctx.context_factory(
input={
with unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
unittest.mock.patch('szurubooru.func.tags.merge_tags'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
result = api.tag_api.merge_tags(
context_factory(
params={
'removeVersion': 1,
'mergeToVersion': 1,
'remove': 'source',
'mergeTo': 'target',
},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert tags.try_get_tag_by_name('source') is None
assert tags.get_tag_by_name('target').post_count == 1
user=user_factory(rank=db.User.RANK_REGULAR)))
tags.merge_tags.called_once_with(source_tag, target_tag)
tags.export_to_json.assert_called_once_with()
@pytest.mark.parametrize(
'field', ['remove', 'mergeTo', 'removeVersion', 'mergeToVersion'])
def test_trying_to_omit_mandatory_field(test_ctx, field):
def test_trying_to_omit_mandatory_field(
user_factory, tag_factory, context_factory, field):
db.session.add_all([
test_ctx.tag_factory(names=['source']),
test_ctx.tag_factory(names=['target']),
tag_factory(names=['source']),
tag_factory(names=['target']),
])
db.session.commit()
input = {
params = {
'removeVersion': 1,
'mergeToVersion': 1,
'remove': 'source',
'mergeTo': 'target',
}
del input[field]
del params[field]
with pytest.raises(errors.ValidationError):
test_ctx.api.post(
test_ctx.context_factory(
input=input,
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
api.tag_api.merge_tags(
context_factory(
params=params,
user=user_factory(rank=db.User.RANK_REGULAR)))
def test_trying_to_merge_non_existing(test_ctx):
db.session.add(test_ctx.tag_factory(names=['good']))
def test_trying_to_merge_non_existing(
user_factory, tag_factory, context_factory):
db.session.add(tag_factory(names=['good']))
db.session.commit()
with pytest.raises(tags.TagNotFoundError):
test_ctx.api.post(
test_ctx.context_factory(
input={'remove': 'good', 'mergeTo': 'bad'},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
api.tag_api.merge_tags(
context_factory(
params={'remove': 'good', 'mergeTo': 'bad'},
user=user_factory(rank=db.User.RANK_REGULAR)))
with pytest.raises(tags.TagNotFoundError):
test_ctx.api.post(
test_ctx.context_factory(
input={'remove': 'bad', 'mergeTo': 'good'},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
api.tag_api.merge_tags(
context_factory(
params={'remove': 'bad', 'mergeTo': 'good'},
user=user_factory(rank=db.User.RANK_REGULAR)))
@pytest.mark.parametrize('input', [
@pytest.mark.parametrize('params', [
{'names': 'whatever'},
{'category': 'whatever'},
{'suggestions': ['whatever']},
{'implications': ['whatever']},
])
def test_trying_to_merge_without_privileges(test_ctx, input):
def test_trying_to_merge_without_privileges(
user_factory, tag_factory, context_factory, params):
db.session.add_all([
test_ctx.tag_factory(names=['source']),
test_ctx.tag_factory(names=['target']),
tag_factory(names=['source']),
tag_factory(names=['target']),
])
db.session.commit()
with pytest.raises(errors.AuthError):
test_ctx.api.post(
test_ctx.context_factory(
input={
api.tag_api.merge_tags(
context_factory(
params={
'removeVersion': 1,
'mergeToVersion': 1,
'remove': 'source',
'mergeTo': 'target',
},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
user=user_factory(rank=db.User.RANK_ANONYMOUS)))

View file

@ -1,82 +1,64 @@
import datetime
import pytest
import unittest.mock
from szurubooru import api, db, errors
from szurubooru.func import util, tags
from szurubooru.func import tags
@pytest.fixture
def test_ctx(
context_factory,
config_injector,
user_factory,
tag_factory,
tag_category_factory):
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'privileges': {
'tags:list': db.User.RANK_REGULAR,
'tags:view': db.User.RANK_REGULAR,
},
'thumbnails': {'avatar_width': 200},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_factory = tag_factory
ret.tag_category_factory = tag_category_factory
ret.list_api = api.TagListApi()
ret.detail_api = api.TagDetailApi()
return ret
def test_retrieving_multiple(test_ctx):
tag1 = test_ctx.tag_factory(names=['t1'])
tag2 = test_ctx.tag_factory(names=['t2'])
def test_retrieving_multiple(user_factory, tag_factory, context_factory):
tag1 = tag_factory(names=['t1'])
tag2 = tag_factory(names=['t2'])
db.session.add_all([tag1, tag2])
result = test_ctx.list_api.get(
test_ctx.context_factory(
input={'query': '', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert result['query'] == ''
assert result['page'] == 1
assert result['pageSize'] == 100
assert result['total'] == 2
assert [t['names'] for t in result['results']] == [['t1'], ['t2']]
with unittest.mock.patch('szurubooru.func.tags.serialize_tag'):
tags.serialize_tag.return_value = 'serialized tag'
result = api.tag_api.get_tags(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
assert result == {
'query': '',
'page': 1,
'pageSize': 100,
'total': 2,
'results': ['serialized tag', 'serialized tag'],
}
def test_trying_to_retrieve_multiple_without_privileges(test_ctx):
def test_trying_to_retrieve_multiple_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError):
test_ctx.list_api.get(
test_ctx.context_factory(
input={'query': '', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
api.tag_api.get_tags(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_retrieving_single(test_ctx):
category = test_ctx.tag_category_factory(name='meta')
db.session.add(test_ctx.tag_factory(names=['tag'], category=category))
result = test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'tag')
assert result == {
'names': ['tag'],
'category': 'meta',
'description': None,
'creationTime': datetime.datetime(1996, 1, 1),
'lastEditTime': None,
'suggestions': [],
'implications': [],
'usages': 0,
'snapshots': [],
'version': 1,
}
def test_retrieving_single(user_factory, tag_factory, context_factory):
db.session.add(tag_factory(names=['tag']))
with unittest.mock.patch('szurubooru.func.tags.serialize_tag'):
tags.serialize_tag.return_value = 'serialized tag'
result = api.tag_api.get_tag(
context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'tag'})
assert result == 'serialized tag'
def test_trying_to_retrieve_single_non_existing(test_ctx):
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(tags.TagNotFoundError):
test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'-')
api.tag_api.get_tag(
context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': '-'})
def test_trying_to_retrieve_single_without_privileges(test_ctx):
def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError):
test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)),
'-')
api.tag_api.get_tag(
context_factory(
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
{'tag_name': '-'})

View file

@ -1,56 +1,47 @@
import datetime
import pytest
import unittest.mock
from szurubooru import api, db, errors
from szurubooru.func import util, tags
from szurubooru.func import tags
def assert_results(result, expected_tag_names_and_occurrences):
actual_tag_names_and_occurences = []
for item in result['results']:
tag_name = item['tag']['names'][0]
occurrences = item['occurrences']
actual_tag_names_and_occurences.append((tag_name, occurrences))
assert actual_tag_names_and_occurences == expected_tag_names_and_occurrences
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'tags:view': db.User.RANK_REGULAR}})
@pytest.fixture
def test_ctx(
context_factory, config_injector, user_factory, tag_factory, post_factory):
config_injector({
'privileges': {
'tags:view': db.User.RANK_REGULAR,
},
'thumbnails': {'avatar_width': 200},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_factory = tag_factory
ret.post_factory = post_factory
ret.api = api.TagSiblingsApi()
return ret
def test_get_tag_siblings(user_factory, tag_factory, post_factory, context_factory):
db.session.add(tag_factory(names=['tag']))
with unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
unittest.mock.patch('szurubooru.func.tags.get_tag_siblings'):
tags.serialize_tag.side_effect = \
lambda tag, *args, **kwargs: \
'serialized tag %s' % tag.names[0].name
tags.get_tag_siblings.return_value = [
(tag_factory(names=['sib1']), 1),
(tag_factory(names=['sib2']), 3),
]
result = api.tag_api.get_tag_siblings(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'tag'})
assert result == {
'results': [
{
'tag': 'serialized tag sib1',
'occurrences': 1,
},
{
'tag': 'serialized tag sib2',
'occurrences': 3,
},
],
}
def test_used_with_others(test_ctx):
tag1 = test_ctx.tag_factory(names=['tag1'])
tag2 = test_ctx.tag_factory(names=['tag2'])
post = test_ctx.post_factory()
post.tags = [tag1, tag2]
db.session.add_all([post, tag1, tag2])
result = test_ctx.api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag1')
assert_results(result, [('tag2', 1)])
result = test_ctx.api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag2')
assert_results(result, [('tag1', 1)])
def test_trying_to_retrieve_non_existing(test_ctx):
def test_trying_to_retrieve_non_existing(user_factory, context_factory):
with pytest.raises(tags.TagNotFoundError):
test_ctx.api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), '-')
api.tag_api.get_tag_siblings(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': '-'})
def test_trying_to_retrieve_without_privileges(test_ctx):
def test_trying_to_retrieve_without_privileges(user_factory, context_factory):
with pytest.raises(errors.AuthError):
test_ctx.api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)), '-')
api.tag_api.get_tag_siblings(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
{'tag_name': '-'})

View file

@ -1,20 +1,11 @@
import datetime
import os
import pytest
from szurubooru import api, config, db, errors
from szurubooru.func import util, tags, tag_categories, cache
import unittest.mock
from szurubooru import api, db, errors
from szurubooru.func import tags
def assert_relations(relations, expected_tag_names):
actual_names = sorted([rel.names[0].name for rel in relations])
assert actual_names == sorted(expected_tag_names)
@pytest.fixture
def test_ctx(
tmpdir, config_injector, context_factory, user_factory, tag_factory):
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'data_dir': str(tmpdir),
'tag_name_regex': '^[^!]*$',
'tag_category_name_regex': '^[^!]*$',
'privileges': {
'tags:create': db.User.RANK_REGULAR,
'tags:edit:names': db.User.RANK_REGULAR,
@ -24,118 +15,115 @@ def test_ctx(
'tags:edit:implications': db.User.RANK_REGULAR,
},
})
db.session.add_all([
db.TagCategory(name) for name in ['meta', 'character', 'copyright']])
db.session.commit()
cache.purge()
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.tag_factory = tag_factory
ret.api = api.TagDetailApi()
return ret
def test_simple_updating(test_ctx, fake_datetime):
tag = test_ctx.tag_factory(names=['tag1', 'tag2'])
def test_simple_updating(user_factory, tag_factory, context_factory, fake_datetime):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
tag = tag_factory(names=['tag1', 'tag2'])
db.session.add(tag)
db.session.commit()
with fake_datetime('1997-12-01'):
result = test_ctx.api.put(
test_ctx.context_factory(
input={
with unittest.mock.patch('szurubooru.func.tags.create_tag'), \
unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'), \
unittest.mock.patch('szurubooru.func.tags.update_tag_names'), \
unittest.mock.patch('szurubooru.func.tags.update_tag_category_name'), \
unittest.mock.patch('szurubooru.func.tags.update_tag_description'), \
unittest.mock.patch('szurubooru.func.tags.update_tag_suggestions'), \
unittest.mock.patch('szurubooru.func.tags.update_tag_implications'), \
unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
tags.get_or_create_tags_by_names.return_value = ([], [])
tags.serialize_tag.return_value = 'serialized tag'
result = api.tag_api.update_tag(
context_factory(
params={
'version': 1,
'names': ['tag3'],
'category': 'character',
'description': 'desc',
'suggestions': ['sug1', 'sug2'],
'implications': ['imp1', 'imp2'],
},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'tag1')
assert len(result['snapshots']) == 1
del result['snapshots']
assert result == {
'names': ['tag3'],
'category': 'character',
'description': 'desc',
'suggestions': [],
'implications': [],
'creationTime': datetime.datetime(1996, 1, 1),
'lastEditTime': datetime.datetime(1997, 12, 1),
'usages': 0,
'version': 2,
}
assert tags.try_get_tag_by_name('tag1') is None
assert tags.try_get_tag_by_name('tag2') is None
tag = tags.get_tag_by_name('tag3')
assert tag is not None
assert [tag_name.name for tag_name in tag.names] == ['tag3']
assert tag.category.name == 'character'
assert tag.suggestions == []
assert tag.implications == []
assert os.path.exists(os.path.join(config.config['data_dir'], 'tags.json'))
@pytest.mark.parametrize('input,expected_exception', [
({'names': None}, tags.InvalidTagNameError),
({'names': []}, tags.InvalidTagNameError),
({'names': [None]}, tags.InvalidTagNameError),
({'names': ['']}, tags.InvalidTagNameError),
({'names': ['!bad']}, tags.InvalidTagNameError),
({'names': ['x' * 65]}, tags.InvalidTagNameError),
({'category': None}, tag_categories.TagCategoryNotFoundError),
({'category': ''}, tag_categories.TagCategoryNotFoundError),
({'category': '!bad'}, tag_categories.TagCategoryNotFoundError),
({'suggestions': ['good', '!bad']}, tags.InvalidTagNameError),
({'implications': ['good', '!bad']}, tags.InvalidTagNameError),
])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
db.session.add(test_ctx.tag_factory(names=['tag1']))
db.session.commit()
with pytest.raises(expected_exception):
test_ctx.api.put(
test_ctx.context_factory(
input={**input, **{'version': 1}},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'tag1')
user=auth_user),
{'tag_name': 'tag1'})
assert result == 'serialized tag'
tags.create_tag.assert_not_called()
tags.update_tag_names.assert_called_once_with(tag, ['tag3'])
tags.update_tag_category_name.assert_called_once_with(tag, 'character')
tags.update_tag_description.assert_called_once_with(tag, 'desc')
tags.update_tag_suggestions.assert_called_once_with(tag, ['sug1', 'sug2'])
tags.update_tag_implications.assert_called_once_with(tag, ['imp1', 'imp2'])
tags.serialize_tag.assert_called_once_with(tag, options=None)
@pytest.mark.parametrize(
'field', ['names', 'category', 'description', 'implications', 'suggestions'])
def test_omitting_optional_field(test_ctx, field):
db.session.add(test_ctx.tag_factory(names=['tag']))
def test_omitting_optional_field(
user_factory, tag_factory, context_factory, field):
db.session.add(tag_factory(names=['tag']))
db.session.commit()
input = {
params = {
'names': ['tag1', 'tag2'],
'category': 'meta',
'description': 'desc',
'suggestions': [],
'implications': [],
}
del input[field]
result = test_ctx.api.put(
test_ctx.context_factory(
input={**input, **{'version': 1}},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'tag')
assert result is not None
del params[field]
with unittest.mock.patch('szurubooru.func.tags.create_tag'), \
unittest.mock.patch('szurubooru.func.tags.update_tag_names'), \
unittest.mock.patch('szurubooru.func.tags.update_tag_category_name'), \
unittest.mock.patch('szurubooru.func.tags.serialize_tag'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'):
api.tag_api.update_tag(
context_factory(
params={**params, **{'version': 1}},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'tag'})
def test_trying_to_update_non_existing(test_ctx):
def test_trying_to_update_non_existing(user_factory, context_factory):
with pytest.raises(tags.TagNotFoundError):
test_ctx.api.put(
test_ctx.context_factory(
input={'names': ['dummy']},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'tag1')
api.tag_api.update_tag(
context_factory(
params={'names': ['dummy']},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'tag1'})
@pytest.mark.parametrize('input', [
@pytest.mark.parametrize('params', [
{'names': 'whatever'},
{'category': 'whatever'},
{'suggestions': ['whatever']},
{'implications': ['whatever']},
])
def test_trying_to_update_without_privileges(test_ctx, input):
db.session.add(test_ctx.tag_factory(names=['tag']))
def test_trying_to_update_without_privileges(
user_factory, tag_factory, context_factory, params):
db.session.add(tag_factory(names=['tag']))
db.session.commit()
with pytest.raises(errors.AuthError):
test_ctx.api.put(
test_ctx.context_factory(
input={**input, **{'version': 1}},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)),
'tag')
api.tag_api.update_tag(
context_factory(
params={**params, **{'version': 1}},
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
{'tag_name': 'tag'})
def test_trying_to_create_tags_without_privileges(
config_injector, context_factory, tag_factory, user_factory):
tag = tag_factory(names=['tag'])
db.session.add(tag)
db.session.commit()
config_injector({'privileges': {
'tags:create': db.User.RANK_ADMINISTRATOR,
'tags:edit:suggestions': db.User.RANK_REGULAR,
'tags:edit:implications': db.User.RANK_REGULAR,
}})
with unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'):
tags.get_or_create_tags_by_names.return_value = ([], ['new-tag'])
with pytest.raises(errors.AuthError):
api.tag_api.update_tag(
context_factory(
params={'suggestions': ['tag1', 'tag2'], 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'tag'})
with pytest.raises(errors.AuthError):
api.tag_api.update_tag(
context_factory(
params={'implications': ['tag1', 'tag2'], 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'tag_name': 'tag'})

View file

@ -1,230 +1,79 @@
import datetime
import pytest
from szurubooru import api, config, db, errors
from szurubooru.func import auth, util, users
import unittest.mock
from szurubooru import api, db, errors
from szurubooru.func import users
EMPTY_PIXEL = \
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({'privileges': {'users:create': 'regular'}})
@pytest.fixture
def test_ctx(tmpdir, config_injector, context_factory, user_factory):
config_injector({
'secret': '',
'user_name_regex': '[^!]{3,}',
'password_regex': '[^!]{3,}',
'default_rank': db.User.RANK_REGULAR,
'thumbnails': {'avatar_width': 200, 'avatar_height': 200},
'privileges': {'users:create': 'anonymous'},
'data_dir': str(tmpdir.mkdir('data')),
'data_url': 'http://example.com/data/',
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.api = api.UserListApi()
return ret
def test_creating_user(test_ctx, fake_datetime):
with fake_datetime('1969-02-12'):
result = test_ctx.api.post(
test_ctx.context_factory(
input={
def test_creating_user(user_factory, context_factory, fake_datetime):
user = user_factory()
with unittest.mock.patch('szurubooru.func.users.create_user'), \
unittest.mock.patch('szurubooru.func.users.update_user_name'), \
unittest.mock.patch('szurubooru.func.users.update_user_password'), \
unittest.mock.patch('szurubooru.func.users.update_user_email'), \
unittest.mock.patch('szurubooru.func.users.update_user_rank'), \
unittest.mock.patch('szurubooru.func.users.update_user_avatar'), \
unittest.mock.patch('szurubooru.func.users.serialize_user'), \
fake_datetime('1969-02-12'):
users.serialize_user.return_value = 'serialized user'
users.create_user.return_value = user
result = api.user_api.create_user(
context_factory(
params={
'name': 'chewie1',
'email': 'asd@asd.asd',
'password': 'oks',
'rank': 'moderator',
'avatarStyle': 'manual',
},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert result == {
'avatarStyle': 'gravatar',
'avatarUrl': 'https://gravatar.com/avatar/' +
'6f370c8c7109534c3d5c394123a477d7?d=retro&s=200',
'creationTime': datetime.datetime(1969, 2, 12),
'lastLoginTime': None,
'name': 'chewie1',
'rank': 'administrator',
'email': 'asd@asd.asd',
'commentCount': 0,
'likedPostCount': 0,
'dislikedPostCount': 0,
'favoritePostCount': 0,
'uploadedPostCount': 0,
'version': 1,
}
user = users.get_user_by_name('chewie1')
assert user.name == 'chewie1'
assert user.email == 'asd@asd.asd'
assert user.rank == db.User.RANK_ADMINISTRATOR
assert auth.is_valid_password(user, 'oks') is True
assert auth.is_valid_password(user, 'invalid') is False
def test_first_user_becomes_admin_others_not(test_ctx):
result1 = test_ctx.api.post(
test_ctx.context_factory(
input={
'name': 'chewie1',
'email': 'asd@asd.asd',
'password': 'oks',
},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
result2 = test_ctx.api.post(
test_ctx.context_factory(
input={
'name': 'chewie2',
'email': 'asd@asd.asd',
'password': 'sok',
},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
assert result1['rank'] == 'administrator'
assert result2['rank'] == 'regular'
first_user = users.get_user_by_name('chewie1')
other_user = users.get_user_by_name('chewie2')
assert first_user.rank == db.User.RANK_ADMINISTRATOR
assert other_user.rank == db.User.RANK_REGULAR
def test_first_user_does_not_become_admin_if_they_dont_wish_so(test_ctx):
result = test_ctx.api.post(
test_ctx.context_factory(
input={
'name': 'chewie1',
'email': 'asd@asd.asd',
'password': 'oks',
'rank': 'regular',
},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
assert result['rank'] == 'regular'
def test_trying_to_become_someone_else(test_ctx):
test_ctx.api.post(
test_ctx.context_factory(
input={
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
with pytest.raises(users.UserAlreadyExistsError):
test_ctx.api.post(
test_ctx.context_factory(
input={
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
with pytest.raises(users.UserAlreadyExistsError):
test_ctx.api.post(
test_ctx.context_factory(
input={
'name': 'CHEWIE',
'email': 'asd@asd.asd',
'password': 'oks',
},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
@pytest.mark.parametrize('input,expected_exception', [
({'name': None}, users.InvalidUserNameError),
({'name': ''}, users.InvalidUserNameError),
({'name': '!bad'}, users.InvalidUserNameError),
({'name': 'x' * 51}, users.InvalidUserNameError),
({'password': None}, users.InvalidPasswordError),
({'password': ''}, users.InvalidPasswordError),
({'password': '!bad'}, users.InvalidPasswordError),
({'rank': None}, users.InvalidRankError),
({'rank': ''}, users.InvalidRankError),
({'rank': 'bad'}, users.InvalidRankError),
({'rank': 'anonymous'}, users.InvalidRankError),
({'rank': 'nobody'}, users.InvalidRankError),
({'email': 'bad'}, users.InvalidEmailError),
({'email': 'x@' * 65 + '.com'}, users.InvalidEmailError),
({'avatarStyle': None}, users.InvalidAvatarError),
({'avatarStyle': ''}, users.InvalidAvatarError),
({'avatarStyle': 'invalid'}, users.InvalidAvatarError),
({'avatarStyle': 'manual'}, users.InvalidAvatarError), # missing file
])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
real_input={
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
}
for key, value in input.items():
real_input[key] = value
with pytest.raises(expected_exception):
test_ctx.api.post(
test_ctx.context_factory(
input=real_input,
user=test_ctx.user_factory(
name='u1', rank=db.User.RANK_ADMINISTRATOR)))
files={'avatar': b'...'},
user=user_factory(rank=db.User.RANK_REGULAR)))
assert result == 'serialized user'
users.create_user.assert_called_once_with('chewie1', 'oks', 'asd@asd.asd')
assert not users.update_user_name.called
assert not users.update_user_password.called
assert not users.update_user_email.called
users.update_user_rank.called_once_with(user, 'moderator')
users.update_user_avatar.called_once_with(user, 'manual', b'...')
@pytest.mark.parametrize('field', ['name', 'password'])
def test_trying_to_omit_mandatory_field(test_ctx, field):
input = {
def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
params = {
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
}
del input[field]
with pytest.raises(errors.ValidationError):
test_ctx.api.post(
test_ctx.context_factory(
input=input,
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
user = user_factory()
auth_user = user_factory(rank=db.User.RANK_REGULAR)
del params[field]
with unittest.mock.patch('szurubooru.func.users.create_user'), \
pytest.raises(errors.MissingRequiredParameterError):
users.create_user.return_value = user
api.user_api.create_user(context_factory(params=params, user=auth_user))
@pytest.mark.parametrize('field', ['rank', 'email', 'avatarStyle'])
def test_omitting_optional_field(test_ctx, field):
input = {
def test_omitting_optional_field(user_factory, context_factory, field):
params = {
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
'rank': 'moderator',
'avatarStyle': 'manual',
'avatarStyle': 'gravatar',
}
del input[field]
result = test_ctx.api.post(
test_ctx.context_factory(
input=input,
files={'avatar': EMPTY_PIXEL},
user=test_ctx.user_factory(rank=db.User.RANK_MODERATOR)))
assert result is not None
del params[field]
user = user_factory()
auth_user = user_factory(rank=db.User.RANK_MODERATOR)
with unittest.mock.patch('szurubooru.func.users.create_user'), \
unittest.mock.patch('szurubooru.func.users.update_user_avatar'), \
unittest.mock.patch('szurubooru.func.users.serialize_user'):
users.create_user.return_value = user
api.user_api.create_user(
context_factory(params=params, user=auth_user))
def test_mods_trying_to_become_admin(test_ctx):
user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_MODERATOR)
user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_MODERATOR)
db.session.add_all([user1, user2])
context = test_ctx.context_factory(input={
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
'rank': 'administrator',
}, user=user1)
def test_trying_to_create_user_without_privileges(context_factory, user_factory):
with pytest.raises(errors.AuthError):
test_ctx.api.post(context)
def test_admin_creating_mod_account(test_ctx):
user = test_ctx.user_factory(rank=db.User.RANK_ADMINISTRATOR)
db.session.add(user)
context = test_ctx.context_factory(input={
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
'rank': 'moderator',
}, user=user)
result = test_ctx.api.post(context)
assert result['rank'] == 'moderator'
def test_uploading_avatar(test_ctx):
response = test_ctx.api.post(
test_ctx.context_factory(
input={
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
'avatarStyle': 'manual',
},
files={'avatar': EMPTY_PIXEL},
user=test_ctx.user_factory(rank=db.User.RANK_MODERATOR)))
user = users.get_user_by_name('chewie')
assert user.avatar_style == user.AVATAR_MANUAL
assert response['avatarUrl'] == 'http://example.com/data/avatars/chewie.png'
api.user_api.create_user(context_factory(
params='whatever',
user=user_factory(rank=db.User.RANK_ANONYMOUS)))

View file

@ -1,54 +1,52 @@
import pytest
from datetime import datetime
from szurubooru import api, db, errors
from szurubooru.func import util, users
from szurubooru.func import users
@pytest.fixture
def test_ctx(config_injector, context_factory, user_factory):
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'privileges': {
'users:delete:self': db.User.RANK_REGULAR,
'users:delete:any': db.User.RANK_MODERATOR,
},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.api = api.UserDetailApi()
return ret
def test_deleting_oneself(test_ctx):
user = test_ctx.user_factory(name='u', rank=db.User.RANK_REGULAR)
def test_deleting_oneself(user_factory, context_factory):
user = user_factory(name='u', rank=db.User.RANK_REGULAR)
db.session.add(user)
db.session.commit()
result = test_ctx.api.delete(
test_ctx.context_factory(input={'version': 1}, user=user), 'u')
result = api.user_api.delete_user(
context_factory(
params={'version': 1}, user=user), {'user_name': 'u'})
assert result == {}
assert db.session.query(db.User).count() == 0
def test_deleting_someone_else(test_ctx):
user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_REGULAR)
user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_MODERATOR)
def test_deleting_someone_else(user_factory, context_factory):
user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR)
user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR)
db.session.add_all([user1, user2])
db.session.commit()
test_ctx.api.delete(
test_ctx.context_factory(input={'version': 1}, user=user2), 'u1')
api.user_api.delete_user(
context_factory(
params={'version': 1}, user=user2), {'user_name': 'u1'})
assert db.session.query(db.User).count() == 1
def test_trying_to_delete_someone_else_without_privileges(test_ctx):
user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_REGULAR)
user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_REGULAR)
def test_trying_to_delete_someone_else_without_privileges(
user_factory, context_factory):
user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR)
user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR)
db.session.add_all([user1, user2])
db.session.commit()
with pytest.raises(errors.AuthError):
test_ctx.api.delete(
test_ctx.context_factory(input={'version': 1}, user=user2), 'u1')
api.user_api.delete_user(
context_factory(
params={'version': 1}, user=user2), {'user_name': 'u1'})
assert db.session.query(db.User).count() == 2
def test_trying_to_delete_non_existing(test_ctx):
def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(users.UserNotFoundError):
test_ctx.api.delete(
test_ctx.context_factory(
input={'version': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'bad')
api.user_api.delete_user(
context_factory(
params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)),
{'user_name': 'bad'})

View file

@ -1,83 +1,64 @@
import datetime
import unittest.mock
import pytest
from szurubooru import api, db, errors
from szurubooru.func import util, users
from szurubooru.func import users
@pytest.fixture
def test_ctx(context_factory, config_injector, user_factory):
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'privileges': {
'users:list': db.User.RANK_REGULAR,
'users:view': db.User.RANK_REGULAR,
'users:edit:any:email': db.User.RANK_MODERATOR,
},
'thumbnails': {'avatar_width': 200},
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.list_api = api.UserListApi()
ret.detail_api = api.UserDetailApi()
return ret
def test_retrieving_multiple(test_ctx):
user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_MODERATOR)
user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_MODERATOR)
def test_retrieving_multiple(user_factory, context_factory):
user1 = user_factory(name='u1', rank=db.User.RANK_MODERATOR)
user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR)
db.session.add_all([user1, user2])
result = test_ctx.list_api.get(
test_ctx.context_factory(
input={'query': '', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)))
assert result['query'] == ''
assert result['page'] == 1
assert result['pageSize'] == 100
assert result['total'] == 2
assert [u['name'] for u in result['results']] == ['u1', 'u2']
with unittest.mock.patch('szurubooru.func.users.serialize_user'):
users.serialize_user.return_value = 'serialized user'
result = api.user_api.get_users(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_REGULAR)))
assert result == {
'query': '',
'page': 1,
'pageSize': 100,
'total': 2,
'results': ['serialized user', 'serialized user'],
}
def test_trying_to_retrieve_multiple_without_privileges(test_ctx):
def test_trying_to_retrieve_multiple_without_privileges(
user_factory, context_factory):
with pytest.raises(errors.AuthError):
test_ctx.list_api.get(
test_ctx.context_factory(
input={'query': '', 'page': 1},
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)))
api.user_api.get_users(
context_factory(
params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
def test_retrieving_single(test_ctx):
db.session.add(test_ctx.user_factory(name='u1', rank=db.User.RANK_REGULAR))
result = test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'u1')
assert result == {
'name': 'u1',
'rank': db.User.RANK_REGULAR,
'creationTime': datetime.datetime(1997, 1, 1),
'lastLoginTime': None,
'avatarStyle': 'gravatar',
'avatarUrl': 'https://gravatar.com/avatar/' +
'275876e34cf609db118f3d84b799a790?d=retro&s=200',
'email': False,
'commentCount': 0,
'likedPostCount': False,
'dislikedPostCount': False,
'favoritePostCount': 0,
'uploadedPostCount': 0,
'version': 1,
}
assert result['email'] is False
assert result['likedPostCount'] is False
assert result['dislikedPostCount'] is False
def test_retrieving_single(user_factory, context_factory):
user = user_factory(name='u1', rank=db.User.RANK_REGULAR)
auth_user = user_factory(rank=db.User.RANK_REGULAR)
db.session.add(user)
with unittest.mock.patch('szurubooru.func.users.serialize_user'):
users.serialize_user.return_value = 'serialized user'
result = api.user_api.get_user(
context_factory(user=auth_user), {'user_name': 'u1'})
assert result == 'serialized user'
def test_trying_to_retrieve_single_non_existing(test_ctx):
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR)
with pytest.raises(users.UserNotFoundError):
test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)),
'-')
api.user_api.get_user(
context_factory(user=auth_user), {'user_name': '-'})
def test_trying_to_retrieve_single_without_privileges(test_ctx):
db.session.add(test_ctx.user_factory(name='u1', rank=db.User.RANK_REGULAR))
def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_ANONYMOUS)
db.session.add(user_factory(name='u1', rank=db.User.RANK_REGULAR))
with pytest.raises(errors.AuthError):
test_ctx.detail_api.get(
test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_ANONYMOUS)),
'u1')
api.user_api.get_user(
context_factory(user=auth_user), {'user_name': 'u1'})

View file

@ -1,20 +1,12 @@
import datetime
import pytest
from szurubooru import api, config, db, errors
from szurubooru.func import auth, util, users
import unittest.mock
from datetime import datetime
from szurubooru import api, db, errors
from szurubooru.func import users
EMPTY_PIXEL = \
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
@pytest.fixture
def test_ctx(tmpdir, config_injector, context_factory, user_factory):
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'secret': '',
'user_name_regex': '^[^!]{3,}$',
'password_regex': '^[^!]{3,}$',
'thumbnails': {'avatar_width': 200, 'avatar_height': 200},
'privileges': {
'users:edit:self:name': db.User.RANK_REGULAR,
'users:edit:self:pass': db.User.RANK_REGULAR,
@ -27,203 +19,97 @@ def test_ctx(tmpdir, config_injector, context_factory, user_factory):
'users:edit:any:rank': db.User.RANK_ADMINISTRATOR,
'users:edit:any:avatar': db.User.RANK_ADMINISTRATOR,
},
'data_dir': str(tmpdir.mkdir('data')),
'data_url': 'http://example.com/data/',
})
ret = util.dotdict()
ret.context_factory = context_factory
ret.user_factory = user_factory
ret.api = api.UserDetailApi()
return ret
def test_updating_user(test_ctx):
user = test_ctx.user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
def test_updating_user(context_factory, user_factory):
user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
auth_user = user_factory(rank=db.User.RANK_ADMINISTRATOR)
db.session.add(user)
result = test_ctx.api.put(
test_ctx.context_factory(
input={
'version': 1,
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
'rank': 'moderator',
'avatarStyle': 'gravatar',
},
user=user),
'u1')
assert result == {
'avatarStyle': 'gravatar',
'avatarUrl': 'https://gravatar.com/avatar/' +
'6f370c8c7109534c3d5c394123a477d7?d=retro&s=200',
'creationTime': datetime.datetime(1997, 1, 1),
'lastLoginTime': None,
'email': 'asd@asd.asd',
'name': 'chewie',
'rank': 'moderator',
'commentCount': 0,
'likedPostCount': 0,
'dislikedPostCount': 0,
'favoritePostCount': 0,
'uploadedPostCount': 0,
'version': 2,
}
user = users.get_user_by_name('chewie')
assert user.name == 'chewie'
assert user.email == 'asd@asd.asd'
assert user.rank == db.User.RANK_MODERATOR
assert user.avatar_style == user.AVATAR_GRAVATAR
assert auth.is_valid_password(user, 'oks') is True
assert auth.is_valid_password(user, 'invalid') is False
db.session.flush()
@pytest.mark.parametrize('input,expected_exception', [
({'name': None}, users.InvalidUserNameError),
({'name': ''}, users.InvalidUserNameError),
({'name': '!bad'}, users.InvalidUserNameError),
({'name': 'x' * 51}, users.InvalidUserNameError),
({'password': None}, users.InvalidPasswordError),
({'password': ''}, users.InvalidPasswordError),
({'password': '!bad'}, users.InvalidPasswordError),
({'rank': None}, users.InvalidRankError),
({'rank': ''}, users.InvalidRankError),
({'rank': 'bad'}, users.InvalidRankError),
({'rank': 'anonymous'}, users.InvalidRankError),
({'rank': 'nobody'}, users.InvalidRankError),
({'email': 'bad'}, users.InvalidEmailError),
({'email': 'x@' * 65 + '.com'}, users.InvalidEmailError),
({'avatarStyle': None}, users.InvalidAvatarError),
({'avatarStyle': ''}, users.InvalidAvatarError),
({'avatarStyle': 'invalid'}, users.InvalidAvatarError),
({'avatarStyle': 'manual'}, users.InvalidAvatarError), # missing file
])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):
user = test_ctx.user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
db.session.add(user)
with pytest.raises(expected_exception):
test_ctx.api.put(
test_ctx.context_factory(
input={**input, **{'version': 1}},
user=user),
'u1')
with unittest.mock.patch('szurubooru.func.users.create_user'), \
unittest.mock.patch('szurubooru.func.users.update_user_name'), \
unittest.mock.patch('szurubooru.func.users.update_user_password'), \
unittest.mock.patch('szurubooru.func.users.update_user_email'), \
unittest.mock.patch('szurubooru.func.users.update_user_rank'), \
unittest.mock.patch('szurubooru.func.users.update_user_avatar'), \
unittest.mock.patch('szurubooru.func.users.serialize_user'):
users.serialize_user.return_value = 'serialized user'
result = api.user_api.update_user(
context_factory(
params={
'version': 1,
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
'rank': 'moderator',
'avatarStyle': 'manual',
},
files={
'avatar': b'...',
},
user=auth_user),
{'user_name': 'u1'})
assert result == 'serialized user'
users.create_user.assert_not_called()
users.update_user_name.assert_called_once_with(user, 'chewie')
users.update_user_password.assert_called_once_with(user, 'oks')
users.update_user_email.assert_called_once_with(user, 'asd@asd.asd')
users.update_user_rank.assert_called_once_with(user, 'moderator', auth_user)
users.update_user_avatar.assert_called_once_with(user, 'manual', b'...')
users.serialize_user.assert_called_once_with(user, auth_user, options=None)
@pytest.mark.parametrize(
'field', ['name', 'email', 'password', 'rank', 'avatarStyle'])
def test_omitting_optional_field(test_ctx, field):
user = test_ctx.user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
def test_omitting_optional_field(user_factory, context_factory, field):
user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
db.session.add(user)
input = {
params = {
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
'rank': 'moderator',
'avatarStyle': 'gravatar',
}
del input[field]
result = test_ctx.api.put(
test_ctx.context_factory(
input={**input, **{'version': 1}},
files={'avatar': EMPTY_PIXEL},
user=user),
'u1')
assert result is not None
del params[field]
with unittest.mock.patch('szurubooru.func.users.create_user'), \
unittest.mock.patch('szurubooru.func.users.update_user_name'), \
unittest.mock.patch('szurubooru.func.users.update_user_password'), \
unittest.mock.patch('szurubooru.func.users.update_user_email'), \
unittest.mock.patch('szurubooru.func.users.update_user_rank'), \
unittest.mock.patch('szurubooru.func.users.update_user_avatar'), \
unittest.mock.patch('szurubooru.func.users.serialize_user'):
api.user_api.update_user(
context_factory(
params={**params, **{'version': 1}},
files={'avatar': b'...'},
user=user),
{'user_name': 'u1'})
def test_trying_to_update_non_existing(test_ctx):
user = test_ctx.user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
def test_trying_to_update_non_existing(user_factory, context_factory):
user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
db.session.add(user)
with pytest.raises(users.UserNotFoundError):
test_ctx.api.put(test_ctx.context_factory(user=user), 'u2')
api.user_api.update_user(
context_factory(user=user), {'user_name': 'u2'})
def test_removing_email(test_ctx):
user = test_ctx.user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
db.session.add(user)
test_ctx.api.put(
test_ctx.context_factory(
input={'email': '', 'version': 1}, user=user), 'u1')
assert users.get_user_by_name('u1').email is None
@pytest.mark.parametrize('input', [
@pytest.mark.parametrize('params', [
{'name': 'whatever'},
{'email': 'whatever'},
{'rank': 'whatever'},
{'password': 'whatever'},
{'avatarStyle': 'whatever'},
])
def test_trying_to_update_someone_else(test_ctx, input):
user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_REGULAR)
user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_REGULAR)
def test_trying_to_update_field_without_privileges(
user_factory, context_factory, params):
user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR)
user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR)
db.session.add_all([user1, user2])
with pytest.raises(errors.AuthError):
test_ctx.api.put(
test_ctx.context_factory(
input={**input, **{'version': 1}},
api.user_api.update_user(
context_factory(
params={**params, **{'version': 1}},
user=user1),
user2.name)
def test_trying_to_become_someone_else(test_ctx):
user1 = test_ctx.user_factory(name='me', rank=db.User.RANK_REGULAR)
user2 = test_ctx.user_factory(name='her', rank=db.User.RANK_REGULAR)
db.session.add_all([user1, user2])
with pytest.raises(users.UserAlreadyExistsError):
test_ctx.api.put(
test_ctx.context_factory(
input={'name': 'her', 'version': 1}, user=user1),
'me')
with pytest.raises(users.UserAlreadyExistsError):
test_ctx.api.put(
test_ctx.context_factory(
input={'name': 'HER', 'version': 1}, user=user1),
'me')
def test_trying_to_make_someone_into_someone_else(test_ctx):
user1 = test_ctx.user_factory(name='him', rank=db.User.RANK_REGULAR)
user2 = test_ctx.user_factory(name='her', rank=db.User.RANK_REGULAR)
user3 = test_ctx.user_factory(name='me', rank=db.User.RANK_MODERATOR)
db.session.add_all([user1, user2, user3])
with pytest.raises(users.UserAlreadyExistsError):
test_ctx.api.put(
test_ctx.context_factory(
input={'name': 'her', 'version': 1}, user=user3),
'him')
with pytest.raises(users.UserAlreadyExistsError):
test_ctx.api.put(
test_ctx.context_factory(
input={'name': 'HER', 'version': 1}, user=user3),
'him')
def test_renaming_someone_else(test_ctx):
user1 = test_ctx.user_factory(name='him', rank=db.User.RANK_REGULAR)
user2 = test_ctx.user_factory(name='me', rank=db.User.RANK_MODERATOR)
db.session.add_all([user1, user2])
test_ctx.api.put(
test_ctx.context_factory(
input={'name': 'himself', 'version': 1}, user=user2),
'him')
test_ctx.api.put(
test_ctx.context_factory(
input={'name': 'HIMSELF', 'version': 2}, user=user2),
'himself')
def test_mods_trying_to_become_admin(test_ctx):
user1 = test_ctx.user_factory(name='u1', rank=db.User.RANK_MODERATOR)
user2 = test_ctx.user_factory(name='u2', rank=db.User.RANK_MODERATOR)
db.session.add_all([user1, user2])
context = test_ctx.context_factory(
input={'rank': 'administrator', 'version': 1},
user=user1)
with pytest.raises(errors.AuthError):
test_ctx.api.put(context, user1.name)
with pytest.raises(errors.AuthError):
test_ctx.api.put(context, user2.name)
def test_uploading_avatar(test_ctx):
user = test_ctx.user_factory(name='u1', rank=db.User.RANK_MODERATOR)
db.session.add(user)
response = test_ctx.api.put(
test_ctx.context_factory(
input={'avatarStyle': 'manual', 'version': 1},
files={'avatar': EMPTY_PIXEL},
user=user),
'u1')
user = users.get_user_by_name('u1')
assert user.avatar_style == user.AVATAR_MANUAL
assert response['avatarUrl'] == \
'http://example.com/data/avatars/u1.png'
{'user_name': user2.name})

View file

@ -5,7 +5,7 @@ import uuid
import pytest
import freezegun
import sqlalchemy
from szurubooru import api, config, db
from szurubooru import api, config, db, rest
from szurubooru.func import util
class QueryCounter(object):
@ -74,12 +74,14 @@ def session(query_logger):
@pytest.fixture
def context_factory(session):
def factory(request=None, input=None, files=None, user=None):
ctx = api.Context()
ctx.input = input or {}
def factory(params=None, files=None, user=None):
ctx = rest.Context(
method=None,
url=None,
headers={},
params=params or {},
files=files or {})
ctx.session = session
ctx.request = request or {}
ctx.files = files or {}
ctx.user = user or db.User()
return ctx
return factory

View file

@ -1,32 +1,30 @@
import unittest.mock
import pytest
from szurubooru import api, errors
from szurubooru import rest, errors
from szurubooru.func import net
def test_has_param():
ctx = api.Context()
ctx.input = {'key': 'value'}
ctx = rest.Context(method=None, url=None, params={'key': 'value'})
assert ctx.has_param('key')
assert not ctx.has_param('key2')
def test_get_file():
ctx = api.Context()
ctx.files = {'key': b'content'}
ctx = rest.Context(method=None, url=None, files={'key': b'content'})
assert ctx.get_file('key') == b'content'
assert ctx.get_file('key2') is None
def test_get_file_from_url():
with unittest.mock.patch('szurubooru.func.net.download'):
net.download.return_value = b'content'
ctx = api.Context()
ctx.input = {'keyUrl': 'example.com'}
ctx = rest.Context(
method=None, url=None, params={'keyUrl': 'example.com'})
assert ctx.get_file('key') == b'content'
assert ctx.get_file('key2') is None
net.download.assert_called_once_with('example.com')
def test_getting_list_parameter():
ctx = api.Context()
ctx.input = {'key': 'value', 'list': ['1', '2', '3']}
ctx = rest.Context(
method=None, url=None, params={'key': 'value', 'list': ['1', '2', '3']})
assert ctx.get_param_as_list('key') == ['value']
assert ctx.get_param_as_list('key2') is None
assert ctx.get_param_as_list('key2', default=['def']) == ['def']
@ -35,8 +33,8 @@ def test_getting_list_parameter():
ctx.get_param_as_list('key2', required=True)
def test_getting_string_parameter():
ctx = api.Context()
ctx.input = {'key': 'value', 'list': ['1', '2', '3']}
ctx = rest.Context(
method=None, url=None, params={'key': 'value', 'list': ['1', '2', '3']})
assert ctx.get_param_as_string('key') == 'value'
assert ctx.get_param_as_string('key2') is None
assert ctx.get_param_as_string('key2', default='def') == 'def'
@ -45,8 +43,10 @@ def test_getting_string_parameter():
ctx.get_param_as_string('key2', required=True)
def test_getting_int_parameter():
ctx = api.Context()
ctx.input = {'key': '50', 'err': 'invalid', 'list': [1, 2, 3]}
ctx = rest.Context(
method=None,
url=None,
params={'key': '50', 'err': 'invalid', 'list': [1, 2, 3]})
assert ctx.get_param_as_int('key') == 50
assert ctx.get_param_as_int('key2') is None
assert ctx.get_param_as_int('key2', default=5) == 5
@ -65,8 +65,7 @@ def test_getting_int_parameter():
def test_getting_bool_parameter():
def test(value):
ctx = api.Context()
ctx.input = {'key': value}
ctx = rest.Context(method=None, url=None, params={'key': value})
return ctx.get_param_as_bool('key')
assert test('1') is True
@ -94,7 +93,7 @@ def test_getting_bool_parameter():
with pytest.raises(errors.ValidationError):
test(['1', '2'])
ctx = api.Context()
ctx = rest.Context(method=None, url=None)
assert ctx.get_param_as_bool('non-existing') is None
assert ctx.get_param_as_bool('non-existing', default=True) is True
with pytest.raises(errors.ValidationError):