server/api: move all io mgmt to context
where input/output includes files, JSON metadata and GET parameters. Additionally, formalize context with a new class, Context.
This commit is contained in:
parent
07ea920def
commit
3d4ceb13b8
17 changed files with 211 additions and 149 deletions
|
@ -2,3 +2,4 @@
|
|||
|
||||
from szurubooru.api.password_reset_api import PasswordResetApi
|
||||
from szurubooru.api.user_api import UserListApi, UserDetailApi
|
||||
from szurubooru.api.context import Context, Request
|
||||
|
|
|
@ -3,13 +3,13 @@ import types
|
|||
def _bind_method(target, desired_method_name):
|
||||
actual_method = getattr(target, desired_method_name)
|
||||
def _wrapper_method(_self, request, _response, *args, **kwargs):
|
||||
request.context.result = actual_method(
|
||||
request.context, *args, **kwargs)
|
||||
request.context.output = \
|
||||
actual_method(request.context, *args, **kwargs)
|
||||
return types.MethodType(_wrapper_method, target)
|
||||
|
||||
class BaseApi(object):
|
||||
'''
|
||||
A wrapper around falcon's API interface that eases context and result
|
||||
A wrapper around falcon's API interface that eases input and output
|
||||
management.
|
||||
'''
|
||||
|
||||
|
|
57
server/szurubooru/api/context.py
Normal file
57
server/szurubooru/api/context.py
Normal file
|
@ -0,0 +1,57 @@
|
|||
import falcon
|
||||
from szurubooru import errors
|
||||
|
||||
class Context(object):
|
||||
def __init__(self):
|
||||
self.session = None
|
||||
self.user = None
|
||||
self.files = {}
|
||||
self.input = {}
|
||||
self.output = None
|
||||
|
||||
def has_param(self, name):
|
||||
return name in self.input
|
||||
|
||||
def get_file(self, name):
|
||||
return self.files.get(name, None)
|
||||
|
||||
def get_param_as_string(self, name, required=False, default=None):
|
||||
if name in self.input:
|
||||
param = self.input[name]
|
||||
if isinstance(param, list):
|
||||
param = ','.join(param)
|
||||
return param
|
||||
if not required:
|
||||
return default
|
||||
raise errors.ValidationError('Required paramter %r is missing.' % name)
|
||||
|
||||
def get_param_as_int(
|
||||
self, name, required=False, min=None, max=None, default=None):
|
||||
if name in self.input:
|
||||
val = self.input[name]
|
||||
try:
|
||||
val = int(val)
|
||||
except (ValueError, TypeError):
|
||||
raise errors.ValidationError(
|
||||
'Parameter %r is invalid: the value must be an integer.'
|
||||
% name)
|
||||
|
||||
if min is not None and val < min:
|
||||
raise errors.ValidationError(
|
||||
'Parameter %r is invalid: the value must be at least %r.'
|
||||
% (name, min))
|
||||
|
||||
if max is not None and val > max:
|
||||
raise errors.ValidationError(
|
||||
'Parameter %r is invalid: the value may not exceed %r.'
|
||||
% (name, max))
|
||||
|
||||
return val
|
||||
|
||||
if not required:
|
||||
return default
|
||||
raise errors.ValidationError(
|
||||
'Required parameter %r is missing.' % name)
|
||||
|
||||
class Request(falcon.Request):
|
||||
context_type = Context
|
|
@ -9,9 +9,9 @@ MAIL_BODY = \
|
|||
'Otherwise, please ignore this email.'
|
||||
|
||||
class PasswordResetApi(BaseApi):
|
||||
def get(self, context, user_name):
|
||||
def get(self, ctx, user_name):
|
||||
''' Send a mail with secure token to the correlated user. '''
|
||||
user = users.get_by_name_or_email(context.session, user_name)
|
||||
user = users.get_by_name_or_email(ctx.session, user_name)
|
||||
if not user:
|
||||
raise errors.NotFoundError('User %r not found.' % user_name)
|
||||
if not user.email:
|
||||
|
@ -27,17 +27,15 @@ class PasswordResetApi(BaseApi):
|
|||
MAIL_BODY.format(name=config.config['name'], url=url))
|
||||
return {}
|
||||
|
||||
def post(self, context, user_name):
|
||||
def post(self, ctx, user_name):
|
||||
''' Verify token from mail, generate a new password and return it. '''
|
||||
user = users.get_by_name_or_email(context.session, user_name)
|
||||
user = users.get_by_name_or_email(ctx.session, user_name)
|
||||
if not user:
|
||||
raise errors.NotFoundError('User %r not found.' % user_name)
|
||||
good_token = auth.generate_authentication_token(user)
|
||||
if not 'token' in context.request:
|
||||
raise errors.ValidationError('Missing password reset token.')
|
||||
token = context.request['token']
|
||||
token = ctx.get_param_as_string('token', required=True)
|
||||
if token != good_token:
|
||||
raise errors.ValidationError('Invalid password reset token.')
|
||||
new_password = users.reset_password(user)
|
||||
context.session.commit()
|
||||
ctx.session.commit()
|
||||
return {'password': new_password}
|
||||
|
|
|
@ -34,96 +34,94 @@ class UserListApi(BaseApi):
|
|||
super().__init__()
|
||||
self._search_executor = search.SearchExecutor(search.UserSearchConfig())
|
||||
|
||||
def get(self, context):
|
||||
auth.verify_privilege(context.user, 'users:list')
|
||||
query = context.get_param_as_string('query')
|
||||
page = context.get_param_as_int('page', 1)
|
||||
page_size = min(100, context.get_param_as_int('pageSize', required=False) or 100)
|
||||
def get(self, ctx):
|
||||
auth.verify_privilege(ctx.user, 'users:list')
|
||||
query = ctx.get_param_as_string('query')
|
||||
page = ctx.get_param_as_int('page', default=1, min=1)
|
||||
page_size = ctx.get_param_as_int(
|
||||
'pageSize', default=100, min=1, max=100)
|
||||
count, user_list = self._search_executor.execute(
|
||||
context.session, query, page, page_size)
|
||||
ctx.session, query, page, page_size)
|
||||
return {
|
||||
'query': query,
|
||||
'page': page,
|
||||
'pageSize': page_size,
|
||||
'total': count,
|
||||
'users': [_serialize_user(context.user, user) for user in user_list],
|
||||
'users': [_serialize_user(ctx.user, user) for user in user_list],
|
||||
}
|
||||
|
||||
def post(self, context):
|
||||
auth.verify_privilege(context.user, 'users:create')
|
||||
def post(self, ctx):
|
||||
auth.verify_privilege(ctx.user, 'users:create')
|
||||
|
||||
try:
|
||||
name = context.request['name'].strip()
|
||||
password = context.request['password']
|
||||
email = context.request['email'].strip()
|
||||
except KeyError as ex:
|
||||
raise errors.ValidationError('Field %r not found.' % ex.args[0])
|
||||
name = ctx.get_param_as_string('name', required=True)
|
||||
password = ctx.get_param_as_string('password', required=True)
|
||||
email = ctx.get_param_as_string('email', required=True)
|
||||
|
||||
if users.get_by_name(context.session, name):
|
||||
if users.get_by_name(ctx.session, name):
|
||||
raise errors.IntegrityError('User %r already exists.' % name)
|
||||
user = users.create_user(context.session, name, password, email)
|
||||
context.session.add(user)
|
||||
context.session.commit()
|
||||
return {'user': _serialize_user(context.user, user)}
|
||||
user = users.create_user(ctx.session, name, password, email)
|
||||
ctx.session.add(user)
|
||||
ctx.session.commit()
|
||||
return {'user': _serialize_user(ctx.user, user)}
|
||||
|
||||
class UserDetailApi(BaseApi):
|
||||
def get(self, context, user_name):
|
||||
auth.verify_privilege(context.user, 'users:view')
|
||||
user = users.get_by_name(context.session, user_name)
|
||||
def get(self, ctx, user_name):
|
||||
auth.verify_privilege(ctx.user, 'users:view')
|
||||
user = users.get_by_name(ctx.session, user_name)
|
||||
if not user:
|
||||
raise errors.NotFoundError('User %r not found.' % user_name)
|
||||
return {'user': _serialize_user(context.user, user)}
|
||||
return {'user': _serialize_user(ctx.user, user)}
|
||||
|
||||
def put(self, context, user_name):
|
||||
user = users.get_by_name(context.session, user_name)
|
||||
def put(self, ctx, user_name):
|
||||
user = users.get_by_name(ctx.session, user_name)
|
||||
if not user:
|
||||
raise errors.NotFoundError('User %r not found.' % user_name)
|
||||
|
||||
if context.user.user_id == user.user_id:
|
||||
if ctx.user.user_id == user.user_id:
|
||||
infix = 'self'
|
||||
else:
|
||||
infix = 'any'
|
||||
|
||||
if 'name' in context.request:
|
||||
auth.verify_privilege(context.user, 'users:edit:%s:name' % infix)
|
||||
other_user = users.get_by_name(context.session, context.request['name'])
|
||||
if ctx.has_param('name'):
|
||||
auth.verify_privilege(ctx.user, 'users:edit:%s:name' % infix)
|
||||
other_user = users.get_by_name(ctx.session, ctx.get_param_as_string('name'))
|
||||
if other_user and other_user.user_id != user.user_id:
|
||||
raise errors.IntegrityError('User %r already exists.' % user.name)
|
||||
users.update_name(user, context.request['name'])
|
||||
users.update_name(user, ctx.get_param_as_string('name'))
|
||||
|
||||
if 'password' in context.request:
|
||||
auth.verify_privilege(context.user, 'users:edit:%s:pass' % infix)
|
||||
users.update_password(user, context.request['password'])
|
||||
if ctx.has_param('password'):
|
||||
auth.verify_privilege(ctx.user, 'users:edit:%s:pass' % infix)
|
||||
users.update_password(user, ctx.get_param_as_string('password'))
|
||||
|
||||
if 'email' in context.request:
|
||||
auth.verify_privilege(context.user, 'users:edit:%s:email' % infix)
|
||||
users.update_email(user, context.request['email'])
|
||||
if ctx.has_param('email'):
|
||||
auth.verify_privilege(ctx.user, 'users:edit:%s:email' % infix)
|
||||
users.update_email(user, ctx.get_param_as_string('email'))
|
||||
|
||||
if 'rank' in context.request:
|
||||
auth.verify_privilege(context.user, 'users:edit:%s:rank' % infix)
|
||||
users.update_rank(user, context.request['rank'], context.user)
|
||||
if ctx.has_param('rank'):
|
||||
auth.verify_privilege(ctx.user, 'users:edit:%s:rank' % infix)
|
||||
users.update_rank(user, ctx.get_param_as_string('rank'), ctx.user)
|
||||
|
||||
if 'avatarStyle' in context.request:
|
||||
auth.verify_privilege(context.user, 'users:edit:%s:avatar' % infix)
|
||||
if ctx.has_param('avatarStyle'):
|
||||
auth.verify_privilege(ctx.user, 'users:edit:%s:avatar' % infix)
|
||||
users.update_avatar(
|
||||
user,
|
||||
context.request['avatarStyle'],
|
||||
context.files.get('avatar') or None)
|
||||
ctx.get_param_as_string('avatarStyle'),
|
||||
ctx.get_file('avatar'))
|
||||
|
||||
context.session.commit()
|
||||
return {'user': _serialize_user(context.user, user)}
|
||||
ctx.session.commit()
|
||||
return {'user': _serialize_user(ctx.user, user)}
|
||||
|
||||
def delete(self, context, user_name):
|
||||
user = users.get_by_name(context.session, user_name)
|
||||
def delete(self, ctx, user_name):
|
||||
user = users.get_by_name(ctx.session, user_name)
|
||||
if not user:
|
||||
raise errors.NotFoundError('User %r not found.' % user_name)
|
||||
|
||||
if context.user.user_id == user.user_id:
|
||||
if ctx.user.user_id == user.user_id:
|
||||
infix = 'self'
|
||||
else:
|
||||
infix = 'any'
|
||||
|
||||
auth.verify_privilege(context.user, 'users:delete:%s' % infix)
|
||||
context.session.delete(user)
|
||||
context.session.commit()
|
||||
auth.verify_privilege(ctx.user, 'users:delete:%s' % infix)
|
||||
ctx.session.delete(user)
|
||||
ctx.session.commit()
|
||||
return {}
|
||||
|
|
|
@ -6,22 +6,6 @@ import sqlalchemy.orm
|
|||
from szurubooru import api, config, errors, middleware
|
||||
from szurubooru.util import misc
|
||||
|
||||
class _CustomRequest(falcon.Request):
|
||||
context_type = misc.dotdict
|
||||
|
||||
def get_param_as_string(self, name, required=False, store=None, default=None):
|
||||
params = self._params
|
||||
if name in params:
|
||||
param = params[name]
|
||||
if isinstance(param, list):
|
||||
param = ','.join(param)
|
||||
if store is not None:
|
||||
store[name] = param
|
||||
return param
|
||||
if not required:
|
||||
return default
|
||||
raise falcon.HTTPMissingParam(name)
|
||||
|
||||
def _on_auth_error(ex, _request, _response, _params):
|
||||
raise falcon.HTTPForbidden(
|
||||
title='Authentication error', description=str(ex))
|
||||
|
@ -56,11 +40,10 @@ def create_app():
|
|||
scoped_session = sqlalchemy.orm.scoped_session(session_maker)
|
||||
|
||||
app = falcon.API(
|
||||
request_type=_CustomRequest,
|
||||
request_type=api.Request,
|
||||
middleware=[
|
||||
middleware.ImbueContext(),
|
||||
middleware.RequireJson(),
|
||||
middleware.JsonTranslator(),
|
||||
middleware.ContextAdapter(),
|
||||
middleware.DbSession(scoped_session),
|
||||
middleware.Authenticator(),
|
||||
])
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
''' Various hooks that get executed for each request. '''
|
||||
|
||||
from szurubooru.middleware.authenticator import Authenticator
|
||||
from szurubooru.middleware.json_translator import JsonTranslator
|
||||
from szurubooru.middleware.context_adapter import ContextAdapter
|
||||
from szurubooru.middleware.require_json import RequireJson
|
||||
from szurubooru.middleware.db_session import DbSession
|
||||
from szurubooru.middleware.imbue_context import ImbueContext
|
||||
|
|
|
@ -10,17 +10,23 @@ def json_serializer(obj):
|
|||
return serial
|
||||
raise TypeError('Type not serializable')
|
||||
|
||||
class JsonTranslator(object):
|
||||
class ContextAdapter(object):
|
||||
'''
|
||||
Translates API requests and API responses to JSON using requests'
|
||||
context.
|
||||
1. Deserialize API requests into the context:
|
||||
- Pass GET parameters
|
||||
- Handle multipart/form-data file uploads
|
||||
- Handle JSON requests
|
||||
2. Serialize API responses from the context as JSON.
|
||||
'''
|
||||
|
||||
def process_request(self, request, _response):
|
||||
request.context.files = {}
|
||||
request.context.input = {}
|
||||
for key, value in request._params.items():
|
||||
request.context.input[key] = value
|
||||
|
||||
if request.content_length in (None, 0):
|
||||
return
|
||||
|
||||
request.context.files = {}
|
||||
if 'multipart/form-data' in (request.content_type or ''):
|
||||
# obscure, claims to "avoid a bug in cgi.FieldStorage"
|
||||
request.env.setdefault('QUERY_STRING', '')
|
||||
|
@ -43,7 +49,8 @@ class JsonTranslator(object):
|
|||
if isinstance(body, bytes):
|
||||
body = body.decode('utf-8')
|
||||
|
||||
request.context.request = json.loads(body)
|
||||
for key, value in json.loads(body).items():
|
||||
request.context.input[key] = value
|
||||
except (ValueError, UnicodeDecodeError):
|
||||
raise falcon.HTTPError(
|
||||
falcon.HTTP_401,
|
||||
|
@ -52,7 +59,7 @@ class JsonTranslator(object):
|
|||
'JSON was incorrect or not encoded as UTF-8.')
|
||||
|
||||
def process_response(self, request, response, _resource):
|
||||
if 'result' not in request.context:
|
||||
if not request.context.output:
|
||||
return
|
||||
response.body = json.dumps(
|
||||
request.context.result, default=json_serializer, indent=2)
|
||||
request.context.output, default=json_serializer, indent=2)
|
|
@ -1,8 +0,0 @@
|
|||
class ImbueContext(object):
|
||||
''' Decorates context with methods from falcon's request. '''
|
||||
|
||||
def process_request(self, request, _response):
|
||||
request.context.get_param_as_string = request.get_param_as_string
|
||||
request.context.get_param_as_bool = request.get_param_as_bool
|
||||
request.context.get_param_as_int = request.get_param_as_int
|
||||
request.context.get_param_as_list = request.get_param_as_list
|
|
@ -20,8 +20,6 @@ class SearchExecutor(object):
|
|||
Parse input and return tuple containing total record count and filtered
|
||||
entities.
|
||||
'''
|
||||
page = max(1, int(page))
|
||||
page_size = max(1, int(page_size))
|
||||
filter_query = self._prepare(session, query_text)
|
||||
entities = filter_query \
|
||||
.offset((page - 1) * page_size).limit(page_size).all()
|
||||
|
|
43
server/szurubooru/tests/api/test_context.py
Normal file
43
server/szurubooru/tests/api/test_context.py
Normal file
|
@ -0,0 +1,43 @@
|
|||
import pytest
|
||||
from szurubooru import api, errors
|
||||
|
||||
def test_has_param():
|
||||
ctx = api.Context()
|
||||
ctx.input = {'key': 'value'}
|
||||
assert ctx.has_param('key')
|
||||
assert not ctx.has_param('key2')
|
||||
|
||||
def test_get_file():
|
||||
ctx = api.Context()
|
||||
ctx.files = {'key': b'content'}
|
||||
assert ctx.get_file('key') == b'content'
|
||||
assert ctx.get_file('key2') is None
|
||||
|
||||
def test_getting_string_parameter():
|
||||
ctx = api.Context()
|
||||
ctx.input = {'key': 'value', 'list': ['1', '2', '3']}
|
||||
assert ctx.get_param_as_string('key') == 'value'
|
||||
assert ctx.get_param_as_string('key2') is None
|
||||
assert ctx.get_param_as_string('key2', default='def') == 'def'
|
||||
assert ctx.get_param_as_string('list') == '1,2,3' # falcon issue #749
|
||||
with pytest.raises(errors.ValidationError):
|
||||
ctx.get_param_as_string('key2', required=True)
|
||||
|
||||
def test_getting_int_parameter():
|
||||
ctx = api.Context()
|
||||
ctx.input = {'key': '50', 'err': 'invalid', 'list': [1, 2, 3]}
|
||||
assert ctx.get_param_as_int('key') == 50
|
||||
assert ctx.get_param_as_int('key2') is None
|
||||
assert ctx.get_param_as_int('key2', default=5) == 5
|
||||
with pytest.raises(errors.ValidationError):
|
||||
ctx.get_param_as_int('list')
|
||||
with pytest.raises(errors.ValidationError):
|
||||
ctx.get_param_as_int('key2', required=True)
|
||||
with pytest.raises(errors.ValidationError):
|
||||
ctx.get_param_as_int('err')
|
||||
with pytest.raises(errors.ValidationError):
|
||||
assert ctx.get_param_as_int('key', min=50) == 50
|
||||
ctx.get_param_as_int('key', min=51)
|
||||
with pytest.raises(errors.ValidationError):
|
||||
assert ctx.get_param_as_int('key', max=50) == 50
|
||||
ctx.get_param_as_int('key', max=49)
|
|
@ -58,21 +58,21 @@ def test_confirmation_no_token(password_reset_api, context_factory, session):
|
|||
user = mock_user('u1', 'regular_user', 'user@example.com')
|
||||
session.add(user)
|
||||
with pytest.raises(errors.ValidationError):
|
||||
password_reset_api.post(context_factory(request={}), 'u1')
|
||||
password_reset_api.post(context_factory(input={}), 'u1')
|
||||
|
||||
def test_confirmation_bad_token(password_reset_api, context_factory, session):
|
||||
user = mock_user('u1', 'regular_user', 'user@example.com')
|
||||
session.add(user)
|
||||
with pytest.raises(errors.ValidationError):
|
||||
password_reset_api.post(
|
||||
context_factory(request={'token': 'bad'}), 'u1')
|
||||
context_factory(input={'token': 'bad'}), 'u1')
|
||||
|
||||
def test_confirmation_good_token(password_reset_api, context_factory, session):
|
||||
user = mock_user('u1', 'regular_user', 'user@example.com')
|
||||
old_hash = user.password_hash
|
||||
session.add(user)
|
||||
context = context_factory(
|
||||
request={'token': '4ac0be176fb364f13ee6b634c43220e2'})
|
||||
input={'token': '4ac0be176fb364f13ee6b634c43220e2'})
|
||||
result = password_reset_api.post(context, 'u1')
|
||||
assert user.password_hash != old_hash
|
||||
assert auth.is_valid_password(user, result['password']) is True
|
||||
|
|
|
@ -26,7 +26,7 @@ def test_creating_users(
|
|||
|
||||
user_list_api.post(
|
||||
context_factory(
|
||||
request={
|
||||
input={
|
||||
'name': 'chewie1',
|
||||
'email': 'asd@asd.asd',
|
||||
'password': 'oks',
|
||||
|
@ -34,7 +34,7 @@ def test_creating_users(
|
|||
user=user_factory(rank='regular_user')))
|
||||
user_list_api.post(
|
||||
context_factory(
|
||||
request={
|
||||
input={
|
||||
'name': 'chewie2',
|
||||
'email': 'asd@asd.asd',
|
||||
'password': 'sok',
|
||||
|
@ -68,7 +68,7 @@ def test_creating_user_that_already_exists(
|
|||
})
|
||||
user_list_api.post(
|
||||
context_factory(
|
||||
request={
|
||||
input={
|
||||
'name': 'chewie',
|
||||
'email': 'asd@asd.asd',
|
||||
'password': 'oks',
|
||||
|
@ -77,7 +77,7 @@ def test_creating_user_that_already_exists(
|
|||
with pytest.raises(errors.IntegrityError):
|
||||
user_list_api.post(
|
||||
context_factory(
|
||||
request={
|
||||
input={
|
||||
'name': 'chewie',
|
||||
'email': 'asd@asd.asd',
|
||||
'password': 'oks',
|
||||
|
@ -86,7 +86,7 @@ def test_creating_user_that_already_exists(
|
|||
with pytest.raises(errors.IntegrityError):
|
||||
user_list_api.post(
|
||||
context_factory(
|
||||
request={
|
||||
input={
|
||||
'name': 'CHEWIE',
|
||||
'email': 'asd@asd.asd',
|
||||
'password': 'oks',
|
||||
|
@ -109,4 +109,4 @@ def test_missing_field(
|
|||
with pytest.raises(errors.ValidationError):
|
||||
user_list_api.post(
|
||||
context_factory(
|
||||
request=request, user=user_factory(rank='regular_user')))
|
||||
input=request, user=user_factory(rank='regular_user')))
|
||||
|
|
|
@ -27,7 +27,7 @@ def test_retrieving_multiple(
|
|||
session.add_all([user1, user2])
|
||||
result = user_list_api.get(
|
||||
context_factory(
|
||||
params={'query': '', 'page': 1},
|
||||
input={'query': '', 'page': 1},
|
||||
user=user_factory(rank='regular_user')))
|
||||
assert result['query'] == ''
|
||||
assert result['page'] == 1
|
||||
|
@ -44,7 +44,7 @@ def test_retrieving_multiple_without_privileges(
|
|||
with pytest.raises(errors.AuthError):
|
||||
user_list_api.get(
|
||||
context_factory(
|
||||
params={'query': '', 'page': 1},
|
||||
input={'query': '', 'page': 1},
|
||||
user=user_factory(rank='anonymous')))
|
||||
|
||||
def test_retrieving_multiple_with_privileges(
|
||||
|
@ -55,7 +55,7 @@ def test_retrieving_multiple_with_privileges(
|
|||
})
|
||||
result = user_list_api.get(
|
||||
context_factory(
|
||||
params={'query': 'asd', 'page': 1},
|
||||
input={'query': 'asd', 'page': 1},
|
||||
user=user_factory(rank='regular_user')))
|
||||
assert result['query'] == 'asd'
|
||||
assert result['page'] == 1
|
||||
|
@ -79,7 +79,7 @@ def test_retrieving_single(
|
|||
session.add(user)
|
||||
result = user_detail_api.get(
|
||||
context_factory(
|
||||
params={'query': '', 'page': 1},
|
||||
input={'query': '', 'page': 1},
|
||||
user=user_factory(rank='regular_user')),
|
||||
'u1')
|
||||
assert result['user']['id'] == user.user_id
|
||||
|
@ -98,7 +98,7 @@ def test_retrieving_non_existing(
|
|||
with pytest.raises(errors.NotFoundError):
|
||||
user_detail_api.get(
|
||||
context_factory(
|
||||
params={'query': '', 'page': 1},
|
||||
input={'query': '', 'page': 1},
|
||||
user=user_factory(rank='regular_user')),
|
||||
'-')
|
||||
|
||||
|
@ -111,6 +111,6 @@ def test_retrieving_single_without_privileges(
|
|||
with pytest.raises(errors.AuthError):
|
||||
user_detail_api.get(
|
||||
context_factory(
|
||||
params={'query': '', 'page': 1},
|
||||
input={'query': '', 'page': 1},
|
||||
user=user_factory(rank='anonymous')),
|
||||
'-')
|
||||
|
|
|
@ -31,7 +31,7 @@ def test_updating_user(
|
|||
session.add(user)
|
||||
user_detail_api.put(
|
||||
context_factory(
|
||||
request={
|
||||
input={
|
||||
'name': 'chewie',
|
||||
'email': 'asd@asd.asd',
|
||||
'password': 'oks',
|
||||
|
@ -96,7 +96,7 @@ def test_removing_email(
|
|||
user = user_factory(name='u1', rank='admin')
|
||||
session.add(user)
|
||||
user_detail_api.put(
|
||||
context_factory(request={'email': ''}, user=user), 'u1')
|
||||
context_factory(input={'email': ''}, user=user), 'u1')
|
||||
assert session.query(db.User).filter_by(name='u1').one().email is None
|
||||
|
||||
@pytest.mark.parametrize('request', [
|
||||
|
@ -128,7 +128,7 @@ def test_invalid_inputs(
|
|||
user = user_factory(name='u1', rank='admin')
|
||||
session.add(user)
|
||||
with pytest.raises(errors.ValidationError):
|
||||
user_detail_api.put(context_factory(request=request, user=user), 'u1')
|
||||
user_detail_api.put(context_factory(input=request, user=user), 'u1')
|
||||
|
||||
@pytest.mark.parametrize('request', [
|
||||
{'name': 'whatever'},
|
||||
|
@ -159,7 +159,7 @@ def test_user_trying_to_update_someone_else(
|
|||
session.add_all([user1, user2])
|
||||
with pytest.raises(errors.AuthError):
|
||||
user_detail_api.put(
|
||||
context_factory(request=request, user=user1), user2.name)
|
||||
context_factory(input=request, user=user1), user2.name)
|
||||
|
||||
def test_user_trying_to_become_someone_else(
|
||||
session,
|
||||
|
@ -176,11 +176,11 @@ def test_user_trying_to_become_someone_else(
|
|||
session.add_all([user1, user2])
|
||||
with pytest.raises(errors.IntegrityError):
|
||||
user_detail_api.put(
|
||||
context_factory(request={'name': 'her'}, user=user1),
|
||||
context_factory(input={'name': 'her'}, user=user1),
|
||||
'me')
|
||||
with pytest.raises(errors.IntegrityError):
|
||||
user_detail_api.put(
|
||||
context_factory(request={'name': 'HER'}, user=user1), 'me')
|
||||
context_factory(input={'name': 'HER'}, user=user1), 'me')
|
||||
|
||||
def test_mods_trying_to_become_admin(
|
||||
session,
|
||||
|
@ -198,7 +198,7 @@ def test_mods_trying_to_become_admin(
|
|||
user1 = user_factory(name='u1', rank='mod')
|
||||
user2 = user_factory(name='u2', rank='mod')
|
||||
session.add_all([user1, user2])
|
||||
context = context_factory(request={'rank': 'admin'}, user=user1)
|
||||
context = context_factory(input={'rank': 'admin'}, user=user1)
|
||||
with pytest.raises(errors.AuthError):
|
||||
user_detail_api.put(context, user1.name)
|
||||
with pytest.raises(errors.AuthError):
|
||||
|
@ -227,7 +227,7 @@ def test_uploading_avatar(
|
|||
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
|
||||
response = user_detail_api.put(
|
||||
context_factory(
|
||||
request={'avatarStyle': 'manual'},
|
||||
input={'avatarStyle': 'manual'},
|
||||
files={'avatar': empty_pixel},
|
||||
user=user),
|
||||
'u1')
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from datetime import datetime
|
||||
import pytest
|
||||
import sqlalchemy
|
||||
from szurubooru import db, config
|
||||
from szurubooru import api, config, db
|
||||
from szurubooru.util import misc
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -15,28 +15,14 @@ def session():
|
|||
|
||||
@pytest.fixture
|
||||
def context_factory(session):
|
||||
def factory(request=None, params=None, files=None, user=None):
|
||||
params = params or {}
|
||||
def get_param_as_string(key, default=None, required=False):
|
||||
if key not in params:
|
||||
if required:
|
||||
raise RuntimeError('Param is missing!')
|
||||
return default
|
||||
return params[key]
|
||||
def get_param_as_int(key, default=None, required=False):
|
||||
if key not in params:
|
||||
if required:
|
||||
raise RuntimeError('Param is missing!')
|
||||
return default
|
||||
return int(params[key])
|
||||
context = misc.dotdict()
|
||||
context.session = session
|
||||
context.request = request or {}
|
||||
context.files = files or {}
|
||||
context.user = user or db.User()
|
||||
context.get_param_as_string = get_param_as_string
|
||||
context.get_param_as_int = get_param_as_int
|
||||
return context
|
||||
def factory(request=None, input=None, files=None, user=None):
|
||||
ctx = api.Context()
|
||||
ctx.input = input or {}
|
||||
ctx.session = session
|
||||
ctx.request = request or {}
|
||||
ctx.files = files or {}
|
||||
ctx.user = user or db.User()
|
||||
return ctx
|
||||
return factory
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
@ -123,7 +123,7 @@ def test_combining_tokens(session, verify_unpaged, input, expected_user_names):
|
|||
(2, 1, 2, ['u2']),
|
||||
(3, 1, 2, []),
|
||||
(0, 1, 2, ['u1']),
|
||||
(0, 0, 2, ['u1']),
|
||||
(0, 0, 2, []),
|
||||
])
|
||||
def test_paging(
|
||||
session, executor, page, page_size,
|
||||
|
|
Loading…
Reference in a new issue