server/api: move all io mgmt to context

where input/output includes files, JSON metadata and GET parameters.
Additionally, formalize context with a new class, Context.
This commit is contained in:
rr- 2016-04-15 17:54:21 +02:00
parent 07ea920def
commit 3d4ceb13b8
17 changed files with 211 additions and 149 deletions

View file

@ -2,3 +2,4 @@
from szurubooru.api.password_reset_api import PasswordResetApi
from szurubooru.api.user_api import UserListApi, UserDetailApi
from szurubooru.api.context import Context, Request

View file

@ -3,13 +3,13 @@ import types
def _bind_method(target, desired_method_name):
actual_method = getattr(target, desired_method_name)
def _wrapper_method(_self, request, _response, *args, **kwargs):
request.context.result = actual_method(
request.context, *args, **kwargs)
request.context.output = \
actual_method(request.context, *args, **kwargs)
return types.MethodType(_wrapper_method, target)
class BaseApi(object):
'''
A wrapper around falcon's API interface that eases context and result
A wrapper around falcon's API interface that eases input and output
management.
'''

View file

@ -0,0 +1,57 @@
import falcon
from szurubooru import errors
class Context(object):
def __init__(self):
self.session = None
self.user = None
self.files = {}
self.input = {}
self.output = None
def has_param(self, name):
return name in self.input
def get_file(self, name):
return self.files.get(name, None)
def get_param_as_string(self, name, required=False, default=None):
if name in self.input:
param = self.input[name]
if isinstance(param, list):
param = ','.join(param)
return param
if not required:
return default
raise errors.ValidationError('Required paramter %r is missing.' % name)
def get_param_as_int(
self, name, required=False, min=None, max=None, default=None):
if name in self.input:
val = self.input[name]
try:
val = int(val)
except (ValueError, TypeError):
raise errors.ValidationError(
'Parameter %r is invalid: the value must be an integer.'
% name)
if min is not None and val < min:
raise errors.ValidationError(
'Parameter %r is invalid: the value must be at least %r.'
% (name, min))
if max is not None and val > max:
raise errors.ValidationError(
'Parameter %r is invalid: the value may not exceed %r.'
% (name, max))
return val
if not required:
return default
raise errors.ValidationError(
'Required parameter %r is missing.' % name)
class Request(falcon.Request):
context_type = Context

View file

@ -9,9 +9,9 @@ MAIL_BODY = \
'Otherwise, please ignore this email.'
class PasswordResetApi(BaseApi):
def get(self, context, user_name):
def get(self, ctx, user_name):
''' Send a mail with secure token to the correlated user. '''
user = users.get_by_name_or_email(context.session, user_name)
user = users.get_by_name_or_email(ctx.session, user_name)
if not user:
raise errors.NotFoundError('User %r not found.' % user_name)
if not user.email:
@ -27,17 +27,15 @@ class PasswordResetApi(BaseApi):
MAIL_BODY.format(name=config.config['name'], url=url))
return {}
def post(self, context, user_name):
def post(self, ctx, user_name):
''' Verify token from mail, generate a new password and return it. '''
user = users.get_by_name_or_email(context.session, user_name)
user = users.get_by_name_or_email(ctx.session, user_name)
if not user:
raise errors.NotFoundError('User %r not found.' % user_name)
good_token = auth.generate_authentication_token(user)
if not 'token' in context.request:
raise errors.ValidationError('Missing password reset token.')
token = context.request['token']
token = ctx.get_param_as_string('token', required=True)
if token != good_token:
raise errors.ValidationError('Invalid password reset token.')
new_password = users.reset_password(user)
context.session.commit()
ctx.session.commit()
return {'password': new_password}

View file

@ -34,96 +34,94 @@ class UserListApi(BaseApi):
super().__init__()
self._search_executor = search.SearchExecutor(search.UserSearchConfig())
def get(self, context):
auth.verify_privilege(context.user, 'users:list')
query = context.get_param_as_string('query')
page = context.get_param_as_int('page', 1)
page_size = min(100, context.get_param_as_int('pageSize', required=False) or 100)
def get(self, ctx):
auth.verify_privilege(ctx.user, 'users:list')
query = ctx.get_param_as_string('query')
page = ctx.get_param_as_int('page', default=1, min=1)
page_size = ctx.get_param_as_int(
'pageSize', default=100, min=1, max=100)
count, user_list = self._search_executor.execute(
context.session, query, page, page_size)
ctx.session, query, page, page_size)
return {
'query': query,
'page': page,
'pageSize': page_size,
'total': count,
'users': [_serialize_user(context.user, user) for user in user_list],
'users': [_serialize_user(ctx.user, user) for user in user_list],
}
def post(self, context):
auth.verify_privilege(context.user, 'users:create')
def post(self, ctx):
auth.verify_privilege(ctx.user, 'users:create')
try:
name = context.request['name'].strip()
password = context.request['password']
email = context.request['email'].strip()
except KeyError as ex:
raise errors.ValidationError('Field %r not found.' % ex.args[0])
name = ctx.get_param_as_string('name', required=True)
password = ctx.get_param_as_string('password', required=True)
email = ctx.get_param_as_string('email', required=True)
if users.get_by_name(context.session, name):
if users.get_by_name(ctx.session, name):
raise errors.IntegrityError('User %r already exists.' % name)
user = users.create_user(context.session, name, password, email)
context.session.add(user)
context.session.commit()
return {'user': _serialize_user(context.user, user)}
user = users.create_user(ctx.session, name, password, email)
ctx.session.add(user)
ctx.session.commit()
return {'user': _serialize_user(ctx.user, user)}
class UserDetailApi(BaseApi):
def get(self, context, user_name):
auth.verify_privilege(context.user, 'users:view')
user = users.get_by_name(context.session, user_name)
def get(self, ctx, user_name):
auth.verify_privilege(ctx.user, 'users:view')
user = users.get_by_name(ctx.session, user_name)
if not user:
raise errors.NotFoundError('User %r not found.' % user_name)
return {'user': _serialize_user(context.user, user)}
return {'user': _serialize_user(ctx.user, user)}
def put(self, context, user_name):
user = users.get_by_name(context.session, user_name)
def put(self, ctx, user_name):
user = users.get_by_name(ctx.session, user_name)
if not user:
raise errors.NotFoundError('User %r not found.' % user_name)
if context.user.user_id == user.user_id:
if ctx.user.user_id == user.user_id:
infix = 'self'
else:
infix = 'any'
if 'name' in context.request:
auth.verify_privilege(context.user, 'users:edit:%s:name' % infix)
other_user = users.get_by_name(context.session, context.request['name'])
if ctx.has_param('name'):
auth.verify_privilege(ctx.user, 'users:edit:%s:name' % infix)
other_user = users.get_by_name(ctx.session, ctx.get_param_as_string('name'))
if other_user and other_user.user_id != user.user_id:
raise errors.IntegrityError('User %r already exists.' % user.name)
users.update_name(user, context.request['name'])
users.update_name(user, ctx.get_param_as_string('name'))
if 'password' in context.request:
auth.verify_privilege(context.user, 'users:edit:%s:pass' % infix)
users.update_password(user, context.request['password'])
if ctx.has_param('password'):
auth.verify_privilege(ctx.user, 'users:edit:%s:pass' % infix)
users.update_password(user, ctx.get_param_as_string('password'))
if 'email' in context.request:
auth.verify_privilege(context.user, 'users:edit:%s:email' % infix)
users.update_email(user, context.request['email'])
if ctx.has_param('email'):
auth.verify_privilege(ctx.user, 'users:edit:%s:email' % infix)
users.update_email(user, ctx.get_param_as_string('email'))
if 'rank' in context.request:
auth.verify_privilege(context.user, 'users:edit:%s:rank' % infix)
users.update_rank(user, context.request['rank'], context.user)
if ctx.has_param('rank'):
auth.verify_privilege(ctx.user, 'users:edit:%s:rank' % infix)
users.update_rank(user, ctx.get_param_as_string('rank'), ctx.user)
if 'avatarStyle' in context.request:
auth.verify_privilege(context.user, 'users:edit:%s:avatar' % infix)
if ctx.has_param('avatarStyle'):
auth.verify_privilege(ctx.user, 'users:edit:%s:avatar' % infix)
users.update_avatar(
user,
context.request['avatarStyle'],
context.files.get('avatar') or None)
ctx.get_param_as_string('avatarStyle'),
ctx.get_file('avatar'))
context.session.commit()
return {'user': _serialize_user(context.user, user)}
ctx.session.commit()
return {'user': _serialize_user(ctx.user, user)}
def delete(self, context, user_name):
user = users.get_by_name(context.session, user_name)
def delete(self, ctx, user_name):
user = users.get_by_name(ctx.session, user_name)
if not user:
raise errors.NotFoundError('User %r not found.' % user_name)
if context.user.user_id == user.user_id:
if ctx.user.user_id == user.user_id:
infix = 'self'
else:
infix = 'any'
auth.verify_privilege(context.user, 'users:delete:%s' % infix)
context.session.delete(user)
context.session.commit()
auth.verify_privilege(ctx.user, 'users:delete:%s' % infix)
ctx.session.delete(user)
ctx.session.commit()
return {}

View file

@ -6,22 +6,6 @@ import sqlalchemy.orm
from szurubooru import api, config, errors, middleware
from szurubooru.util import misc
class _CustomRequest(falcon.Request):
context_type = misc.dotdict
def get_param_as_string(self, name, required=False, store=None, default=None):
params = self._params
if name in params:
param = params[name]
if isinstance(param, list):
param = ','.join(param)
if store is not None:
store[name] = param
return param
if not required:
return default
raise falcon.HTTPMissingParam(name)
def _on_auth_error(ex, _request, _response, _params):
raise falcon.HTTPForbidden(
title='Authentication error', description=str(ex))
@ -56,11 +40,10 @@ def create_app():
scoped_session = sqlalchemy.orm.scoped_session(session_maker)
app = falcon.API(
request_type=_CustomRequest,
request_type=api.Request,
middleware=[
middleware.ImbueContext(),
middleware.RequireJson(),
middleware.JsonTranslator(),
middleware.ContextAdapter(),
middleware.DbSession(scoped_session),
middleware.Authenticator(),
])

View file

@ -1,7 +1,6 @@
''' Various hooks that get executed for each request. '''
from szurubooru.middleware.authenticator import Authenticator
from szurubooru.middleware.json_translator import JsonTranslator
from szurubooru.middleware.context_adapter import ContextAdapter
from szurubooru.middleware.require_json import RequireJson
from szurubooru.middleware.db_session import DbSession
from szurubooru.middleware.imbue_context import ImbueContext

View file

@ -10,17 +10,23 @@ def json_serializer(obj):
return serial
raise TypeError('Type not serializable')
class JsonTranslator(object):
class ContextAdapter(object):
'''
Translates API requests and API responses to JSON using requests'
context.
1. Deserialize API requests into the context:
- Pass GET parameters
- Handle multipart/form-data file uploads
- Handle JSON requests
2. Serialize API responses from the context as JSON.
'''
def process_request(self, request, _response):
request.context.files = {}
request.context.input = {}
for key, value in request._params.items():
request.context.input[key] = value
if request.content_length in (None, 0):
return
request.context.files = {}
if 'multipart/form-data' in (request.content_type or ''):
# obscure, claims to "avoid a bug in cgi.FieldStorage"
request.env.setdefault('QUERY_STRING', '')
@ -43,7 +49,8 @@ class JsonTranslator(object):
if isinstance(body, bytes):
body = body.decode('utf-8')
request.context.request = json.loads(body)
for key, value in json.loads(body).items():
request.context.input[key] = value
except (ValueError, UnicodeDecodeError):
raise falcon.HTTPError(
falcon.HTTP_401,
@ -52,7 +59,7 @@ class JsonTranslator(object):
'JSON was incorrect or not encoded as UTF-8.')
def process_response(self, request, response, _resource):
if 'result' not in request.context:
if not request.context.output:
return
response.body = json.dumps(
request.context.result, default=json_serializer, indent=2)
request.context.output, default=json_serializer, indent=2)

View file

@ -1,8 +0,0 @@
class ImbueContext(object):
''' Decorates context with methods from falcon's request. '''
def process_request(self, request, _response):
request.context.get_param_as_string = request.get_param_as_string
request.context.get_param_as_bool = request.get_param_as_bool
request.context.get_param_as_int = request.get_param_as_int
request.context.get_param_as_list = request.get_param_as_list

View file

@ -20,8 +20,6 @@ class SearchExecutor(object):
Parse input and return tuple containing total record count and filtered
entities.
'''
page = max(1, int(page))
page_size = max(1, int(page_size))
filter_query = self._prepare(session, query_text)
entities = filter_query \
.offset((page - 1) * page_size).limit(page_size).all()

View file

@ -0,0 +1,43 @@
import pytest
from szurubooru import api, errors
def test_has_param():
ctx = api.Context()
ctx.input = {'key': 'value'}
assert ctx.has_param('key')
assert not ctx.has_param('key2')
def test_get_file():
ctx = api.Context()
ctx.files = {'key': b'content'}
assert ctx.get_file('key') == b'content'
assert ctx.get_file('key2') is None
def test_getting_string_parameter():
ctx = api.Context()
ctx.input = {'key': 'value', 'list': ['1', '2', '3']}
assert ctx.get_param_as_string('key') == 'value'
assert ctx.get_param_as_string('key2') is None
assert ctx.get_param_as_string('key2', default='def') == 'def'
assert ctx.get_param_as_string('list') == '1,2,3' # falcon issue #749
with pytest.raises(errors.ValidationError):
ctx.get_param_as_string('key2', required=True)
def test_getting_int_parameter():
ctx = api.Context()
ctx.input = {'key': '50', 'err': 'invalid', 'list': [1, 2, 3]}
assert ctx.get_param_as_int('key') == 50
assert ctx.get_param_as_int('key2') is None
assert ctx.get_param_as_int('key2', default=5) == 5
with pytest.raises(errors.ValidationError):
ctx.get_param_as_int('list')
with pytest.raises(errors.ValidationError):
ctx.get_param_as_int('key2', required=True)
with pytest.raises(errors.ValidationError):
ctx.get_param_as_int('err')
with pytest.raises(errors.ValidationError):
assert ctx.get_param_as_int('key', min=50) == 50
ctx.get_param_as_int('key', min=51)
with pytest.raises(errors.ValidationError):
assert ctx.get_param_as_int('key', max=50) == 50
ctx.get_param_as_int('key', max=49)

View file

@ -58,21 +58,21 @@ def test_confirmation_no_token(password_reset_api, context_factory, session):
user = mock_user('u1', 'regular_user', 'user@example.com')
session.add(user)
with pytest.raises(errors.ValidationError):
password_reset_api.post(context_factory(request={}), 'u1')
password_reset_api.post(context_factory(input={}), 'u1')
def test_confirmation_bad_token(password_reset_api, context_factory, session):
user = mock_user('u1', 'regular_user', 'user@example.com')
session.add(user)
with pytest.raises(errors.ValidationError):
password_reset_api.post(
context_factory(request={'token': 'bad'}), 'u1')
context_factory(input={'token': 'bad'}), 'u1')
def test_confirmation_good_token(password_reset_api, context_factory, session):
user = mock_user('u1', 'regular_user', 'user@example.com')
old_hash = user.password_hash
session.add(user)
context = context_factory(
request={'token': '4ac0be176fb364f13ee6b634c43220e2'})
input={'token': '4ac0be176fb364f13ee6b634c43220e2'})
result = password_reset_api.post(context, 'u1')
assert user.password_hash != old_hash
assert auth.is_valid_password(user, result['password']) is True

View file

@ -26,7 +26,7 @@ def test_creating_users(
user_list_api.post(
context_factory(
request={
input={
'name': 'chewie1',
'email': 'asd@asd.asd',
'password': 'oks',
@ -34,7 +34,7 @@ def test_creating_users(
user=user_factory(rank='regular_user')))
user_list_api.post(
context_factory(
request={
input={
'name': 'chewie2',
'email': 'asd@asd.asd',
'password': 'sok',
@ -68,7 +68,7 @@ def test_creating_user_that_already_exists(
})
user_list_api.post(
context_factory(
request={
input={
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
@ -77,7 +77,7 @@ def test_creating_user_that_already_exists(
with pytest.raises(errors.IntegrityError):
user_list_api.post(
context_factory(
request={
input={
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
@ -86,7 +86,7 @@ def test_creating_user_that_already_exists(
with pytest.raises(errors.IntegrityError):
user_list_api.post(
context_factory(
request={
input={
'name': 'CHEWIE',
'email': 'asd@asd.asd',
'password': 'oks',
@ -109,4 +109,4 @@ def test_missing_field(
with pytest.raises(errors.ValidationError):
user_list_api.post(
context_factory(
request=request, user=user_factory(rank='regular_user')))
input=request, user=user_factory(rank='regular_user')))

View file

@ -27,7 +27,7 @@ def test_retrieving_multiple(
session.add_all([user1, user2])
result = user_list_api.get(
context_factory(
params={'query': '', 'page': 1},
input={'query': '', 'page': 1},
user=user_factory(rank='regular_user')))
assert result['query'] == ''
assert result['page'] == 1
@ -44,7 +44,7 @@ def test_retrieving_multiple_without_privileges(
with pytest.raises(errors.AuthError):
user_list_api.get(
context_factory(
params={'query': '', 'page': 1},
input={'query': '', 'page': 1},
user=user_factory(rank='anonymous')))
def test_retrieving_multiple_with_privileges(
@ -55,7 +55,7 @@ def test_retrieving_multiple_with_privileges(
})
result = user_list_api.get(
context_factory(
params={'query': 'asd', 'page': 1},
input={'query': 'asd', 'page': 1},
user=user_factory(rank='regular_user')))
assert result['query'] == 'asd'
assert result['page'] == 1
@ -79,7 +79,7 @@ def test_retrieving_single(
session.add(user)
result = user_detail_api.get(
context_factory(
params={'query': '', 'page': 1},
input={'query': '', 'page': 1},
user=user_factory(rank='regular_user')),
'u1')
assert result['user']['id'] == user.user_id
@ -98,7 +98,7 @@ def test_retrieving_non_existing(
with pytest.raises(errors.NotFoundError):
user_detail_api.get(
context_factory(
params={'query': '', 'page': 1},
input={'query': '', 'page': 1},
user=user_factory(rank='regular_user')),
'-')
@ -111,6 +111,6 @@ def test_retrieving_single_without_privileges(
with pytest.raises(errors.AuthError):
user_detail_api.get(
context_factory(
params={'query': '', 'page': 1},
input={'query': '', 'page': 1},
user=user_factory(rank='anonymous')),
'-')

View file

@ -31,7 +31,7 @@ def test_updating_user(
session.add(user)
user_detail_api.put(
context_factory(
request={
input={
'name': 'chewie',
'email': 'asd@asd.asd',
'password': 'oks',
@ -96,7 +96,7 @@ def test_removing_email(
user = user_factory(name='u1', rank='admin')
session.add(user)
user_detail_api.put(
context_factory(request={'email': ''}, user=user), 'u1')
context_factory(input={'email': ''}, user=user), 'u1')
assert session.query(db.User).filter_by(name='u1').one().email is None
@pytest.mark.parametrize('request', [
@ -128,7 +128,7 @@ def test_invalid_inputs(
user = user_factory(name='u1', rank='admin')
session.add(user)
with pytest.raises(errors.ValidationError):
user_detail_api.put(context_factory(request=request, user=user), 'u1')
user_detail_api.put(context_factory(input=request, user=user), 'u1')
@pytest.mark.parametrize('request', [
{'name': 'whatever'},
@ -159,7 +159,7 @@ def test_user_trying_to_update_someone_else(
session.add_all([user1, user2])
with pytest.raises(errors.AuthError):
user_detail_api.put(
context_factory(request=request, user=user1), user2.name)
context_factory(input=request, user=user1), user2.name)
def test_user_trying_to_become_someone_else(
session,
@ -176,11 +176,11 @@ def test_user_trying_to_become_someone_else(
session.add_all([user1, user2])
with pytest.raises(errors.IntegrityError):
user_detail_api.put(
context_factory(request={'name': 'her'}, user=user1),
context_factory(input={'name': 'her'}, user=user1),
'me')
with pytest.raises(errors.IntegrityError):
user_detail_api.put(
context_factory(request={'name': 'HER'}, user=user1), 'me')
context_factory(input={'name': 'HER'}, user=user1), 'me')
def test_mods_trying_to_become_admin(
session,
@ -198,7 +198,7 @@ def test_mods_trying_to_become_admin(
user1 = user_factory(name='u1', rank='mod')
user2 = user_factory(name='u2', rank='mod')
session.add_all([user1, user2])
context = context_factory(request={'rank': 'admin'}, user=user1)
context = context_factory(input={'rank': 'admin'}, user=user1)
with pytest.raises(errors.AuthError):
user_detail_api.put(context, user1.name)
with pytest.raises(errors.AuthError):
@ -227,7 +227,7 @@ def test_uploading_avatar(
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
response = user_detail_api.put(
context_factory(
request={'avatarStyle': 'manual'},
input={'avatarStyle': 'manual'},
files={'avatar': empty_pixel},
user=user),
'u1')

View file

@ -1,7 +1,7 @@
from datetime import datetime
import pytest
import sqlalchemy
from szurubooru import db, config
from szurubooru import api, config, db
from szurubooru.util import misc
@pytest.fixture
@ -15,28 +15,14 @@ def session():
@pytest.fixture
def context_factory(session):
def factory(request=None, params=None, files=None, user=None):
params = params or {}
def get_param_as_string(key, default=None, required=False):
if key not in params:
if required:
raise RuntimeError('Param is missing!')
return default
return params[key]
def get_param_as_int(key, default=None, required=False):
if key not in params:
if required:
raise RuntimeError('Param is missing!')
return default
return int(params[key])
context = misc.dotdict()
context.session = session
context.request = request or {}
context.files = files or {}
context.user = user or db.User()
context.get_param_as_string = get_param_as_string
context.get_param_as_int = get_param_as_int
return context
def factory(request=None, input=None, files=None, user=None):
ctx = api.Context()
ctx.input = input or {}
ctx.session = session
ctx.request = request or {}
ctx.files = files or {}
ctx.user = user or db.User()
return ctx
return factory
@pytest.fixture

View file

@ -123,7 +123,7 @@ def test_combining_tokens(session, verify_unpaged, input, expected_user_names):
(2, 1, 2, ['u2']),
(3, 1, 2, []),
(0, 1, 2, ['u1']),
(0, 0, 2, ['u1']),
(0, 0, 2, []),
])
def test_paging(
session, executor, page, page_size,