server/api: move all io mgmt to context

where input/output includes files, JSON metadata and GET parameters.
Additionally, formalize context with a new class, Context.
This commit is contained in:
rr- 2016-04-15 17:54:21 +02:00
parent 07ea920def
commit 3d4ceb13b8
17 changed files with 211 additions and 149 deletions

View file

@ -2,3 +2,4 @@
from szurubooru.api.password_reset_api import PasswordResetApi from szurubooru.api.password_reset_api import PasswordResetApi
from szurubooru.api.user_api import UserListApi, UserDetailApi from szurubooru.api.user_api import UserListApi, UserDetailApi
from szurubooru.api.context import Context, Request

View file

@ -3,13 +3,13 @@ import types
def _bind_method(target, desired_method_name): def _bind_method(target, desired_method_name):
actual_method = getattr(target, desired_method_name) actual_method = getattr(target, desired_method_name)
def _wrapper_method(_self, request, _response, *args, **kwargs): def _wrapper_method(_self, request, _response, *args, **kwargs):
request.context.result = actual_method( request.context.output = \
request.context, *args, **kwargs) actual_method(request.context, *args, **kwargs)
return types.MethodType(_wrapper_method, target) return types.MethodType(_wrapper_method, target)
class BaseApi(object): class BaseApi(object):
''' '''
A wrapper around falcon's API interface that eases context and result A wrapper around falcon's API interface that eases input and output
management. management.
''' '''

View file

@ -0,0 +1,57 @@
import falcon
from szurubooru import errors
class Context(object):
def __init__(self):
self.session = None
self.user = None
self.files = {}
self.input = {}
self.output = None
def has_param(self, name):
return name in self.input
def get_file(self, name):
return self.files.get(name, None)
def get_param_as_string(self, name, required=False, default=None):
if name in self.input:
param = self.input[name]
if isinstance(param, list):
param = ','.join(param)
return param
if not required:
return default
raise errors.ValidationError('Required paramter %r is missing.' % name)
def get_param_as_int(
self, name, required=False, min=None, max=None, default=None):
if name in self.input:
val = self.input[name]
try:
val = int(val)
except (ValueError, TypeError):
raise errors.ValidationError(
'Parameter %r is invalid: the value must be an integer.'
% name)
if min is not None and val < min:
raise errors.ValidationError(
'Parameter %r is invalid: the value must be at least %r.'
% (name, min))
if max is not None and val > max:
raise errors.ValidationError(
'Parameter %r is invalid: the value may not exceed %r.'
% (name, max))
return val
if not required:
return default
raise errors.ValidationError(
'Required parameter %r is missing.' % name)
class Request(falcon.Request):
context_type = Context

View file

@ -9,9 +9,9 @@ MAIL_BODY = \
'Otherwise, please ignore this email.' 'Otherwise, please ignore this email.'
class PasswordResetApi(BaseApi): class PasswordResetApi(BaseApi):
def get(self, context, user_name): def get(self, ctx, user_name):
''' Send a mail with secure token to the correlated user. ''' ''' Send a mail with secure token to the correlated user. '''
user = users.get_by_name_or_email(context.session, user_name) user = users.get_by_name_or_email(ctx.session, user_name)
if not user: if not user:
raise errors.NotFoundError('User %r not found.' % user_name) raise errors.NotFoundError('User %r not found.' % user_name)
if not user.email: if not user.email:
@ -27,17 +27,15 @@ class PasswordResetApi(BaseApi):
MAIL_BODY.format(name=config.config['name'], url=url)) MAIL_BODY.format(name=config.config['name'], url=url))
return {} return {}
def post(self, context, user_name): def post(self, ctx, user_name):
''' Verify token from mail, generate a new password and return it. ''' ''' Verify token from mail, generate a new password and return it. '''
user = users.get_by_name_or_email(context.session, user_name) user = users.get_by_name_or_email(ctx.session, user_name)
if not user: if not user:
raise errors.NotFoundError('User %r not found.' % user_name) raise errors.NotFoundError('User %r not found.' % user_name)
good_token = auth.generate_authentication_token(user) good_token = auth.generate_authentication_token(user)
if not 'token' in context.request: token = ctx.get_param_as_string('token', required=True)
raise errors.ValidationError('Missing password reset token.')
token = context.request['token']
if token != good_token: if token != good_token:
raise errors.ValidationError('Invalid password reset token.') raise errors.ValidationError('Invalid password reset token.')
new_password = users.reset_password(user) new_password = users.reset_password(user)
context.session.commit() ctx.session.commit()
return {'password': new_password} return {'password': new_password}

View file

@ -34,96 +34,94 @@ class UserListApi(BaseApi):
super().__init__() super().__init__()
self._search_executor = search.SearchExecutor(search.UserSearchConfig()) self._search_executor = search.SearchExecutor(search.UserSearchConfig())
def get(self, context): def get(self, ctx):
auth.verify_privilege(context.user, 'users:list') auth.verify_privilege(ctx.user, 'users:list')
query = context.get_param_as_string('query') query = ctx.get_param_as_string('query')
page = context.get_param_as_int('page', 1) page = ctx.get_param_as_int('page', default=1, min=1)
page_size = min(100, context.get_param_as_int('pageSize', required=False) or 100) page_size = ctx.get_param_as_int(
'pageSize', default=100, min=1, max=100)
count, user_list = self._search_executor.execute( count, user_list = self._search_executor.execute(
context.session, query, page, page_size) ctx.session, query, page, page_size)
return { return {
'query': query, 'query': query,
'page': page, 'page': page,
'pageSize': page_size, 'pageSize': page_size,
'total': count, 'total': count,
'users': [_serialize_user(context.user, user) for user in user_list], 'users': [_serialize_user(ctx.user, user) for user in user_list],
} }
def post(self, context): def post(self, ctx):
auth.verify_privilege(context.user, 'users:create') auth.verify_privilege(ctx.user, 'users:create')
try: name = ctx.get_param_as_string('name', required=True)
name = context.request['name'].strip() password = ctx.get_param_as_string('password', required=True)
password = context.request['password'] email = ctx.get_param_as_string('email', required=True)
email = context.request['email'].strip()
except KeyError as ex:
raise errors.ValidationError('Field %r not found.' % ex.args[0])
if users.get_by_name(context.session, name): if users.get_by_name(ctx.session, name):
raise errors.IntegrityError('User %r already exists.' % name) raise errors.IntegrityError('User %r already exists.' % name)
user = users.create_user(context.session, name, password, email) user = users.create_user(ctx.session, name, password, email)
context.session.add(user) ctx.session.add(user)
context.session.commit() ctx.session.commit()
return {'user': _serialize_user(context.user, user)} return {'user': _serialize_user(ctx.user, user)}
class UserDetailApi(BaseApi): class UserDetailApi(BaseApi):
def get(self, context, user_name): def get(self, ctx, user_name):
auth.verify_privilege(context.user, 'users:view') auth.verify_privilege(ctx.user, 'users:view')
user = users.get_by_name(context.session, user_name) user = users.get_by_name(ctx.session, user_name)
if not user: if not user:
raise errors.NotFoundError('User %r not found.' % user_name) raise errors.NotFoundError('User %r not found.' % user_name)
return {'user': _serialize_user(context.user, user)} return {'user': _serialize_user(ctx.user, user)}
def put(self, context, user_name): def put(self, ctx, user_name):
user = users.get_by_name(context.session, user_name) user = users.get_by_name(ctx.session, user_name)
if not user: if not user:
raise errors.NotFoundError('User %r not found.' % user_name) raise errors.NotFoundError('User %r not found.' % user_name)
if context.user.user_id == user.user_id: if ctx.user.user_id == user.user_id:
infix = 'self' infix = 'self'
else: else:
infix = 'any' infix = 'any'
if 'name' in context.request: if ctx.has_param('name'):
auth.verify_privilege(context.user, 'users:edit:%s:name' % infix) auth.verify_privilege(ctx.user, 'users:edit:%s:name' % infix)
other_user = users.get_by_name(context.session, context.request['name']) other_user = users.get_by_name(ctx.session, ctx.get_param_as_string('name'))
if other_user and other_user.user_id != user.user_id: if other_user and other_user.user_id != user.user_id:
raise errors.IntegrityError('User %r already exists.' % user.name) raise errors.IntegrityError('User %r already exists.' % user.name)
users.update_name(user, context.request['name']) users.update_name(user, ctx.get_param_as_string('name'))
if 'password' in context.request: if ctx.has_param('password'):
auth.verify_privilege(context.user, 'users:edit:%s:pass' % infix) auth.verify_privilege(ctx.user, 'users:edit:%s:pass' % infix)
users.update_password(user, context.request['password']) users.update_password(user, ctx.get_param_as_string('password'))
if 'email' in context.request: if ctx.has_param('email'):
auth.verify_privilege(context.user, 'users:edit:%s:email' % infix) auth.verify_privilege(ctx.user, 'users:edit:%s:email' % infix)
users.update_email(user, context.request['email']) users.update_email(user, ctx.get_param_as_string('email'))
if 'rank' in context.request: if ctx.has_param('rank'):
auth.verify_privilege(context.user, 'users:edit:%s:rank' % infix) auth.verify_privilege(ctx.user, 'users:edit:%s:rank' % infix)
users.update_rank(user, context.request['rank'], context.user) users.update_rank(user, ctx.get_param_as_string('rank'), ctx.user)
if 'avatarStyle' in context.request: if ctx.has_param('avatarStyle'):
auth.verify_privilege(context.user, 'users:edit:%s:avatar' % infix) auth.verify_privilege(ctx.user, 'users:edit:%s:avatar' % infix)
users.update_avatar( users.update_avatar(
user, user,
context.request['avatarStyle'], ctx.get_param_as_string('avatarStyle'),
context.files.get('avatar') or None) ctx.get_file('avatar'))
context.session.commit() ctx.session.commit()
return {'user': _serialize_user(context.user, user)} return {'user': _serialize_user(ctx.user, user)}
def delete(self, context, user_name): def delete(self, ctx, user_name):
user = users.get_by_name(context.session, user_name) user = users.get_by_name(ctx.session, user_name)
if not user: if not user:
raise errors.NotFoundError('User %r not found.' % user_name) raise errors.NotFoundError('User %r not found.' % user_name)
if context.user.user_id == user.user_id: if ctx.user.user_id == user.user_id:
infix = 'self' infix = 'self'
else: else:
infix = 'any' infix = 'any'
auth.verify_privilege(context.user, 'users:delete:%s' % infix) auth.verify_privilege(ctx.user, 'users:delete:%s' % infix)
context.session.delete(user) ctx.session.delete(user)
context.session.commit() ctx.session.commit()
return {} return {}

View file

@ -6,22 +6,6 @@ import sqlalchemy.orm
from szurubooru import api, config, errors, middleware from szurubooru import api, config, errors, middleware
from szurubooru.util import misc from szurubooru.util import misc
class _CustomRequest(falcon.Request):
context_type = misc.dotdict
def get_param_as_string(self, name, required=False, store=None, default=None):
params = self._params
if name in params:
param = params[name]
if isinstance(param, list):
param = ','.join(param)
if store is not None:
store[name] = param
return param
if not required:
return default
raise falcon.HTTPMissingParam(name)
def _on_auth_error(ex, _request, _response, _params): def _on_auth_error(ex, _request, _response, _params):
raise falcon.HTTPForbidden( raise falcon.HTTPForbidden(
title='Authentication error', description=str(ex)) title='Authentication error', description=str(ex))
@ -56,11 +40,10 @@ def create_app():
scoped_session = sqlalchemy.orm.scoped_session(session_maker) scoped_session = sqlalchemy.orm.scoped_session(session_maker)
app = falcon.API( app = falcon.API(
request_type=_CustomRequest, request_type=api.Request,
middleware=[ middleware=[
middleware.ImbueContext(),
middleware.RequireJson(), middleware.RequireJson(),
middleware.JsonTranslator(), middleware.ContextAdapter(),
middleware.DbSession(scoped_session), middleware.DbSession(scoped_session),
middleware.Authenticator(), middleware.Authenticator(),
]) ])

View file

@ -1,7 +1,6 @@
''' Various hooks that get executed for each request. ''' ''' Various hooks that get executed for each request. '''
from szurubooru.middleware.authenticator import Authenticator from szurubooru.middleware.authenticator import Authenticator
from szurubooru.middleware.json_translator import JsonTranslator from szurubooru.middleware.context_adapter import ContextAdapter
from szurubooru.middleware.require_json import RequireJson from szurubooru.middleware.require_json import RequireJson
from szurubooru.middleware.db_session import DbSession from szurubooru.middleware.db_session import DbSession
from szurubooru.middleware.imbue_context import ImbueContext

View file

@ -10,17 +10,23 @@ def json_serializer(obj):
return serial return serial
raise TypeError('Type not serializable') raise TypeError('Type not serializable')
class JsonTranslator(object): class ContextAdapter(object):
''' '''
Translates API requests and API responses to JSON using requests' 1. Deserialize API requests into the context:
context. - Pass GET parameters
- Handle multipart/form-data file uploads
- Handle JSON requests
2. Serialize API responses from the context as JSON.
''' '''
def process_request(self, request, _response): def process_request(self, request, _response):
request.context.files = {}
request.context.input = {}
for key, value in request._params.items():
request.context.input[key] = value
if request.content_length in (None, 0): if request.content_length in (None, 0):
return return
request.context.files = {}
if 'multipart/form-data' in (request.content_type or ''): if 'multipart/form-data' in (request.content_type or ''):
# obscure, claims to "avoid a bug in cgi.FieldStorage" # obscure, claims to "avoid a bug in cgi.FieldStorage"
request.env.setdefault('QUERY_STRING', '') request.env.setdefault('QUERY_STRING', '')
@ -43,7 +49,8 @@ class JsonTranslator(object):
if isinstance(body, bytes): if isinstance(body, bytes):
body = body.decode('utf-8') body = body.decode('utf-8')
request.context.request = json.loads(body) for key, value in json.loads(body).items():
request.context.input[key] = value
except (ValueError, UnicodeDecodeError): except (ValueError, UnicodeDecodeError):
raise falcon.HTTPError( raise falcon.HTTPError(
falcon.HTTP_401, falcon.HTTP_401,
@ -52,7 +59,7 @@ class JsonTranslator(object):
'JSON was incorrect or not encoded as UTF-8.') 'JSON was incorrect or not encoded as UTF-8.')
def process_response(self, request, response, _resource): def process_response(self, request, response, _resource):
if 'result' not in request.context: if not request.context.output:
return return
response.body = json.dumps( response.body = json.dumps(
request.context.result, default=json_serializer, indent=2) request.context.output, default=json_serializer, indent=2)

View file

@ -1,8 +0,0 @@
class ImbueContext(object):
''' Decorates context with methods from falcon's request. '''
def process_request(self, request, _response):
request.context.get_param_as_string = request.get_param_as_string
request.context.get_param_as_bool = request.get_param_as_bool
request.context.get_param_as_int = request.get_param_as_int
request.context.get_param_as_list = request.get_param_as_list

View file

@ -20,8 +20,6 @@ class SearchExecutor(object):
Parse input and return tuple containing total record count and filtered Parse input and return tuple containing total record count and filtered
entities. entities.
''' '''
page = max(1, int(page))
page_size = max(1, int(page_size))
filter_query = self._prepare(session, query_text) filter_query = self._prepare(session, query_text)
entities = filter_query \ entities = filter_query \
.offset((page - 1) * page_size).limit(page_size).all() .offset((page - 1) * page_size).limit(page_size).all()

View file

@ -0,0 +1,43 @@
import pytest
from szurubooru import api, errors
def test_has_param():
ctx = api.Context()
ctx.input = {'key': 'value'}
assert ctx.has_param('key')
assert not ctx.has_param('key2')
def test_get_file():
ctx = api.Context()
ctx.files = {'key': b'content'}
assert ctx.get_file('key') == b'content'
assert ctx.get_file('key2') is None
def test_getting_string_parameter():
ctx = api.Context()
ctx.input = {'key': 'value', 'list': ['1', '2', '3']}
assert ctx.get_param_as_string('key') == 'value'
assert ctx.get_param_as_string('key2') is None
assert ctx.get_param_as_string('key2', default='def') == 'def'
assert ctx.get_param_as_string('list') == '1,2,3' # falcon issue #749
with pytest.raises(errors.ValidationError):
ctx.get_param_as_string('key2', required=True)
def test_getting_int_parameter():
ctx = api.Context()
ctx.input = {'key': '50', 'err': 'invalid', 'list': [1, 2, 3]}
assert ctx.get_param_as_int('key') == 50
assert ctx.get_param_as_int('key2') is None
assert ctx.get_param_as_int('key2', default=5) == 5
with pytest.raises(errors.ValidationError):
ctx.get_param_as_int('list')
with pytest.raises(errors.ValidationError):
ctx.get_param_as_int('key2', required=True)
with pytest.raises(errors.ValidationError):
ctx.get_param_as_int('err')
with pytest.raises(errors.ValidationError):
assert ctx.get_param_as_int('key', min=50) == 50
ctx.get_param_as_int('key', min=51)
with pytest.raises(errors.ValidationError):
assert ctx.get_param_as_int('key', max=50) == 50
ctx.get_param_as_int('key', max=49)

View file

@ -58,21 +58,21 @@ def test_confirmation_no_token(password_reset_api, context_factory, session):
user = mock_user('u1', 'regular_user', 'user@example.com') user = mock_user('u1', 'regular_user', 'user@example.com')
session.add(user) session.add(user)
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
password_reset_api.post(context_factory(request={}), 'u1') password_reset_api.post(context_factory(input={}), 'u1')
def test_confirmation_bad_token(password_reset_api, context_factory, session): def test_confirmation_bad_token(password_reset_api, context_factory, session):
user = mock_user('u1', 'regular_user', 'user@example.com') user = mock_user('u1', 'regular_user', 'user@example.com')
session.add(user) session.add(user)
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
password_reset_api.post( password_reset_api.post(
context_factory(request={'token': 'bad'}), 'u1') context_factory(input={'token': 'bad'}), 'u1')
def test_confirmation_good_token(password_reset_api, context_factory, session): def test_confirmation_good_token(password_reset_api, context_factory, session):
user = mock_user('u1', 'regular_user', 'user@example.com') user = mock_user('u1', 'regular_user', 'user@example.com')
old_hash = user.password_hash old_hash = user.password_hash
session.add(user) session.add(user)
context = context_factory( context = context_factory(
request={'token': '4ac0be176fb364f13ee6b634c43220e2'}) input={'token': '4ac0be176fb364f13ee6b634c43220e2'})
result = password_reset_api.post(context, 'u1') result = password_reset_api.post(context, 'u1')
assert user.password_hash != old_hash assert user.password_hash != old_hash
assert auth.is_valid_password(user, result['password']) is True assert auth.is_valid_password(user, result['password']) is True

View file

@ -26,7 +26,7 @@ def test_creating_users(
user_list_api.post( user_list_api.post(
context_factory( context_factory(
request={ input={
'name': 'chewie1', 'name': 'chewie1',
'email': 'asd@asd.asd', 'email': 'asd@asd.asd',
'password': 'oks', 'password': 'oks',
@ -34,7 +34,7 @@ def test_creating_users(
user=user_factory(rank='regular_user'))) user=user_factory(rank='regular_user')))
user_list_api.post( user_list_api.post(
context_factory( context_factory(
request={ input={
'name': 'chewie2', 'name': 'chewie2',
'email': 'asd@asd.asd', 'email': 'asd@asd.asd',
'password': 'sok', 'password': 'sok',
@ -68,7 +68,7 @@ def test_creating_user_that_already_exists(
}) })
user_list_api.post( user_list_api.post(
context_factory( context_factory(
request={ input={
'name': 'chewie', 'name': 'chewie',
'email': 'asd@asd.asd', 'email': 'asd@asd.asd',
'password': 'oks', 'password': 'oks',
@ -77,7 +77,7 @@ def test_creating_user_that_already_exists(
with pytest.raises(errors.IntegrityError): with pytest.raises(errors.IntegrityError):
user_list_api.post( user_list_api.post(
context_factory( context_factory(
request={ input={
'name': 'chewie', 'name': 'chewie',
'email': 'asd@asd.asd', 'email': 'asd@asd.asd',
'password': 'oks', 'password': 'oks',
@ -86,7 +86,7 @@ def test_creating_user_that_already_exists(
with pytest.raises(errors.IntegrityError): with pytest.raises(errors.IntegrityError):
user_list_api.post( user_list_api.post(
context_factory( context_factory(
request={ input={
'name': 'CHEWIE', 'name': 'CHEWIE',
'email': 'asd@asd.asd', 'email': 'asd@asd.asd',
'password': 'oks', 'password': 'oks',
@ -109,4 +109,4 @@ def test_missing_field(
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
user_list_api.post( user_list_api.post(
context_factory( context_factory(
request=request, user=user_factory(rank='regular_user'))) input=request, user=user_factory(rank='regular_user')))

View file

@ -27,7 +27,7 @@ def test_retrieving_multiple(
session.add_all([user1, user2]) session.add_all([user1, user2])
result = user_list_api.get( result = user_list_api.get(
context_factory( context_factory(
params={'query': '', 'page': 1}, input={'query': '', 'page': 1},
user=user_factory(rank='regular_user'))) user=user_factory(rank='regular_user')))
assert result['query'] == '' assert result['query'] == ''
assert result['page'] == 1 assert result['page'] == 1
@ -44,7 +44,7 @@ def test_retrieving_multiple_without_privileges(
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
user_list_api.get( user_list_api.get(
context_factory( context_factory(
params={'query': '', 'page': 1}, input={'query': '', 'page': 1},
user=user_factory(rank='anonymous'))) user=user_factory(rank='anonymous')))
def test_retrieving_multiple_with_privileges( def test_retrieving_multiple_with_privileges(
@ -55,7 +55,7 @@ def test_retrieving_multiple_with_privileges(
}) })
result = user_list_api.get( result = user_list_api.get(
context_factory( context_factory(
params={'query': 'asd', 'page': 1}, input={'query': 'asd', 'page': 1},
user=user_factory(rank='regular_user'))) user=user_factory(rank='regular_user')))
assert result['query'] == 'asd' assert result['query'] == 'asd'
assert result['page'] == 1 assert result['page'] == 1
@ -79,7 +79,7 @@ def test_retrieving_single(
session.add(user) session.add(user)
result = user_detail_api.get( result = user_detail_api.get(
context_factory( context_factory(
params={'query': '', 'page': 1}, input={'query': '', 'page': 1},
user=user_factory(rank='regular_user')), user=user_factory(rank='regular_user')),
'u1') 'u1')
assert result['user']['id'] == user.user_id assert result['user']['id'] == user.user_id
@ -98,7 +98,7 @@ def test_retrieving_non_existing(
with pytest.raises(errors.NotFoundError): with pytest.raises(errors.NotFoundError):
user_detail_api.get( user_detail_api.get(
context_factory( context_factory(
params={'query': '', 'page': 1}, input={'query': '', 'page': 1},
user=user_factory(rank='regular_user')), user=user_factory(rank='regular_user')),
'-') '-')
@ -111,6 +111,6 @@ def test_retrieving_single_without_privileges(
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
user_detail_api.get( user_detail_api.get(
context_factory( context_factory(
params={'query': '', 'page': 1}, input={'query': '', 'page': 1},
user=user_factory(rank='anonymous')), user=user_factory(rank='anonymous')),
'-') '-')

View file

@ -31,7 +31,7 @@ def test_updating_user(
session.add(user) session.add(user)
user_detail_api.put( user_detail_api.put(
context_factory( context_factory(
request={ input={
'name': 'chewie', 'name': 'chewie',
'email': 'asd@asd.asd', 'email': 'asd@asd.asd',
'password': 'oks', 'password': 'oks',
@ -96,7 +96,7 @@ def test_removing_email(
user = user_factory(name='u1', rank='admin') user = user_factory(name='u1', rank='admin')
session.add(user) session.add(user)
user_detail_api.put( user_detail_api.put(
context_factory(request={'email': ''}, user=user), 'u1') context_factory(input={'email': ''}, user=user), 'u1')
assert session.query(db.User).filter_by(name='u1').one().email is None assert session.query(db.User).filter_by(name='u1').one().email is None
@pytest.mark.parametrize('request', [ @pytest.mark.parametrize('request', [
@ -128,7 +128,7 @@ def test_invalid_inputs(
user = user_factory(name='u1', rank='admin') user = user_factory(name='u1', rank='admin')
session.add(user) session.add(user)
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
user_detail_api.put(context_factory(request=request, user=user), 'u1') user_detail_api.put(context_factory(input=request, user=user), 'u1')
@pytest.mark.parametrize('request', [ @pytest.mark.parametrize('request', [
{'name': 'whatever'}, {'name': 'whatever'},
@ -159,7 +159,7 @@ def test_user_trying_to_update_someone_else(
session.add_all([user1, user2]) session.add_all([user1, user2])
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
user_detail_api.put( user_detail_api.put(
context_factory(request=request, user=user1), user2.name) context_factory(input=request, user=user1), user2.name)
def test_user_trying_to_become_someone_else( def test_user_trying_to_become_someone_else(
session, session,
@ -176,11 +176,11 @@ def test_user_trying_to_become_someone_else(
session.add_all([user1, user2]) session.add_all([user1, user2])
with pytest.raises(errors.IntegrityError): with pytest.raises(errors.IntegrityError):
user_detail_api.put( user_detail_api.put(
context_factory(request={'name': 'her'}, user=user1), context_factory(input={'name': 'her'}, user=user1),
'me') 'me')
with pytest.raises(errors.IntegrityError): with pytest.raises(errors.IntegrityError):
user_detail_api.put( user_detail_api.put(
context_factory(request={'name': 'HER'}, user=user1), 'me') context_factory(input={'name': 'HER'}, user=user1), 'me')
def test_mods_trying_to_become_admin( def test_mods_trying_to_become_admin(
session, session,
@ -198,7 +198,7 @@ def test_mods_trying_to_become_admin(
user1 = user_factory(name='u1', rank='mod') user1 = user_factory(name='u1', rank='mod')
user2 = user_factory(name='u2', rank='mod') user2 = user_factory(name='u2', rank='mod')
session.add_all([user1, user2]) session.add_all([user1, user2])
context = context_factory(request={'rank': 'admin'}, user=user1) context = context_factory(input={'rank': 'admin'}, user=user1)
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
user_detail_api.put(context, user1.name) user_detail_api.put(context, user1.name)
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
@ -227,7 +227,7 @@ def test_uploading_avatar(
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
response = user_detail_api.put( response = user_detail_api.put(
context_factory( context_factory(
request={'avatarStyle': 'manual'}, input={'avatarStyle': 'manual'},
files={'avatar': empty_pixel}, files={'avatar': empty_pixel},
user=user), user=user),
'u1') 'u1')

View file

@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
import pytest import pytest
import sqlalchemy import sqlalchemy
from szurubooru import db, config from szurubooru import api, config, db
from szurubooru.util import misc from szurubooru.util import misc
@pytest.fixture @pytest.fixture
@ -15,28 +15,14 @@ def session():
@pytest.fixture @pytest.fixture
def context_factory(session): def context_factory(session):
def factory(request=None, params=None, files=None, user=None): def factory(request=None, input=None, files=None, user=None):
params = params or {} ctx = api.Context()
def get_param_as_string(key, default=None, required=False): ctx.input = input or {}
if key not in params: ctx.session = session
if required: ctx.request = request or {}
raise RuntimeError('Param is missing!') ctx.files = files or {}
return default ctx.user = user or db.User()
return params[key] return ctx
def get_param_as_int(key, default=None, required=False):
if key not in params:
if required:
raise RuntimeError('Param is missing!')
return default
return int(params[key])
context = misc.dotdict()
context.session = session
context.request = request or {}
context.files = files or {}
context.user = user or db.User()
context.get_param_as_string = get_param_as_string
context.get_param_as_int = get_param_as_int
return context
return factory return factory
@pytest.fixture @pytest.fixture

View file

@ -123,7 +123,7 @@ def test_combining_tokens(session, verify_unpaged, input, expected_user_names):
(2, 1, 2, ['u2']), (2, 1, 2, ['u2']),
(3, 1, 2, []), (3, 1, 2, []),
(0, 1, 2, ['u1']), (0, 1, 2, ['u1']),
(0, 0, 2, ['u1']), (0, 0, 2, []),
]) ])
def test_paging( def test_paging(
session, executor, page, page_size, session, executor, page, page_size,