From 3d4ceb13b8f10b7cbe9f5e135f4091cd9d5e2ed2 Mon Sep 17 00:00:00 2001 From: rr- Date: Fri, 15 Apr 2016 17:54:21 +0200 Subject: [PATCH] server/api: move all io mgmt to context where input/output includes files, JSON metadata and GET parameters. Additionally, formalize context with a new class, Context. --- server/szurubooru/api/__init__.py | 1 + server/szurubooru/api/base_api.py | 6 +- server/szurubooru/api/context.py | 57 ++++++++++ server/szurubooru/api/password_reset_api.py | 14 ++- server/szurubooru/api/user_api.py | 102 +++++++++--------- server/szurubooru/app.py | 21 +--- server/szurubooru/middleware/__init__.py | 3 +- ...{json_translator.py => context_adapter.py} | 23 ++-- server/szurubooru/middleware/imbue_context.py | 8 -- server/szurubooru/search/search_executor.py | 2 - server/szurubooru/tests/api/test_context.py | 43 ++++++++ .../tests/api/test_password_reset.py | 6 +- .../tests/api/test_user_creating.py | 12 +-- .../tests/api/test_user_retrieval.py | 12 +-- .../tests/api/test_user_updating.py | 16 +-- server/szurubooru/tests/conftest.py | 32 ++---- .../tests/search/test_user_search_config.py | 2 +- 17 files changed, 211 insertions(+), 149 deletions(-) create mode 100644 server/szurubooru/api/context.py rename server/szurubooru/middleware/{json_translator.py => context_adapter.py} (74%) delete mode 100644 server/szurubooru/middleware/imbue_context.py create mode 100644 server/szurubooru/tests/api/test_context.py diff --git a/server/szurubooru/api/__init__.py b/server/szurubooru/api/__init__.py index 4f79c4b4..b51523e4 100644 --- a/server/szurubooru/api/__init__.py +++ b/server/szurubooru/api/__init__.py @@ -2,3 +2,4 @@ from szurubooru.api.password_reset_api import PasswordResetApi from szurubooru.api.user_api import UserListApi, UserDetailApi +from szurubooru.api.context import Context, Request diff --git a/server/szurubooru/api/base_api.py b/server/szurubooru/api/base_api.py index 6170d47c..9fd2096c 100644 --- a/server/szurubooru/api/base_api.py +++ b/server/szurubooru/api/base_api.py @@ -3,13 +3,13 @@ import types def _bind_method(target, desired_method_name): actual_method = getattr(target, desired_method_name) def _wrapper_method(_self, request, _response, *args, **kwargs): - request.context.result = actual_method( - request.context, *args, **kwargs) + request.context.output = \ + actual_method(request.context, *args, **kwargs) return types.MethodType(_wrapper_method, target) class BaseApi(object): ''' - A wrapper around falcon's API interface that eases context and result + A wrapper around falcon's API interface that eases input and output management. ''' diff --git a/server/szurubooru/api/context.py b/server/szurubooru/api/context.py new file mode 100644 index 00000000..72dea321 --- /dev/null +++ b/server/szurubooru/api/context.py @@ -0,0 +1,57 @@ +import falcon +from szurubooru import errors + +class Context(object): + def __init__(self): + self.session = None + self.user = None + self.files = {} + self.input = {} + self.output = None + + def has_param(self, name): + return name in self.input + + def get_file(self, name): + return self.files.get(name, None) + + def get_param_as_string(self, name, required=False, default=None): + if name in self.input: + param = self.input[name] + if isinstance(param, list): + param = ','.join(param) + return param + if not required: + return default + raise errors.ValidationError('Required paramter %r is missing.' % name) + + def get_param_as_int( + self, name, required=False, min=None, max=None, default=None): + if name in self.input: + val = self.input[name] + try: + val = int(val) + except (ValueError, TypeError): + raise errors.ValidationError( + 'Parameter %r is invalid: the value must be an integer.' + % name) + + if min is not None and val < min: + raise errors.ValidationError( + 'Parameter %r is invalid: the value must be at least %r.' + % (name, min)) + + if max is not None and val > max: + raise errors.ValidationError( + 'Parameter %r is invalid: the value may not exceed %r.' + % (name, max)) + + return val + + if not required: + return default + raise errors.ValidationError( + 'Required parameter %r is missing.' % name) + +class Request(falcon.Request): + context_type = Context diff --git a/server/szurubooru/api/password_reset_api.py b/server/szurubooru/api/password_reset_api.py index 47c1f127..04895090 100644 --- a/server/szurubooru/api/password_reset_api.py +++ b/server/szurubooru/api/password_reset_api.py @@ -9,9 +9,9 @@ MAIL_BODY = \ 'Otherwise, please ignore this email.' class PasswordResetApi(BaseApi): - def get(self, context, user_name): + def get(self, ctx, user_name): ''' Send a mail with secure token to the correlated user. ''' - user = users.get_by_name_or_email(context.session, user_name) + user = users.get_by_name_or_email(ctx.session, user_name) if not user: raise errors.NotFoundError('User %r not found.' % user_name) if not user.email: @@ -27,17 +27,15 @@ class PasswordResetApi(BaseApi): MAIL_BODY.format(name=config.config['name'], url=url)) return {} - def post(self, context, user_name): + def post(self, ctx, user_name): ''' Verify token from mail, generate a new password and return it. ''' - user = users.get_by_name_or_email(context.session, user_name) + user = users.get_by_name_or_email(ctx.session, user_name) if not user: raise errors.NotFoundError('User %r not found.' % user_name) good_token = auth.generate_authentication_token(user) - if not 'token' in context.request: - raise errors.ValidationError('Missing password reset token.') - token = context.request['token'] + token = ctx.get_param_as_string('token', required=True) if token != good_token: raise errors.ValidationError('Invalid password reset token.') new_password = users.reset_password(user) - context.session.commit() + ctx.session.commit() return {'password': new_password} diff --git a/server/szurubooru/api/user_api.py b/server/szurubooru/api/user_api.py index 697f940d..71a64a06 100644 --- a/server/szurubooru/api/user_api.py +++ b/server/szurubooru/api/user_api.py @@ -34,96 +34,94 @@ class UserListApi(BaseApi): super().__init__() self._search_executor = search.SearchExecutor(search.UserSearchConfig()) - def get(self, context): - auth.verify_privilege(context.user, 'users:list') - query = context.get_param_as_string('query') - page = context.get_param_as_int('page', 1) - page_size = min(100, context.get_param_as_int('pageSize', required=False) or 100) + def get(self, ctx): + auth.verify_privilege(ctx.user, 'users:list') + query = ctx.get_param_as_string('query') + page = ctx.get_param_as_int('page', default=1, min=1) + page_size = ctx.get_param_as_int( + 'pageSize', default=100, min=1, max=100) count, user_list = self._search_executor.execute( - context.session, query, page, page_size) + ctx.session, query, page, page_size) return { 'query': query, 'page': page, 'pageSize': page_size, 'total': count, - 'users': [_serialize_user(context.user, user) for user in user_list], + 'users': [_serialize_user(ctx.user, user) for user in user_list], } - def post(self, context): - auth.verify_privilege(context.user, 'users:create') + def post(self, ctx): + auth.verify_privilege(ctx.user, 'users:create') - try: - name = context.request['name'].strip() - password = context.request['password'] - email = context.request['email'].strip() - except KeyError as ex: - raise errors.ValidationError('Field %r not found.' % ex.args[0]) + name = ctx.get_param_as_string('name', required=True) + password = ctx.get_param_as_string('password', required=True) + email = ctx.get_param_as_string('email', required=True) - if users.get_by_name(context.session, name): + if users.get_by_name(ctx.session, name): raise errors.IntegrityError('User %r already exists.' % name) - user = users.create_user(context.session, name, password, email) - context.session.add(user) - context.session.commit() - return {'user': _serialize_user(context.user, user)} + user = users.create_user(ctx.session, name, password, email) + ctx.session.add(user) + ctx.session.commit() + return {'user': _serialize_user(ctx.user, user)} class UserDetailApi(BaseApi): - def get(self, context, user_name): - auth.verify_privilege(context.user, 'users:view') - user = users.get_by_name(context.session, user_name) + def get(self, ctx, user_name): + auth.verify_privilege(ctx.user, 'users:view') + user = users.get_by_name(ctx.session, user_name) if not user: raise errors.NotFoundError('User %r not found.' % user_name) - return {'user': _serialize_user(context.user, user)} + return {'user': _serialize_user(ctx.user, user)} - def put(self, context, user_name): - user = users.get_by_name(context.session, user_name) + def put(self, ctx, user_name): + user = users.get_by_name(ctx.session, user_name) if not user: raise errors.NotFoundError('User %r not found.' % user_name) - if context.user.user_id == user.user_id: + if ctx.user.user_id == user.user_id: infix = 'self' else: infix = 'any' - if 'name' in context.request: - auth.verify_privilege(context.user, 'users:edit:%s:name' % infix) - other_user = users.get_by_name(context.session, context.request['name']) + if ctx.has_param('name'): + auth.verify_privilege(ctx.user, 'users:edit:%s:name' % infix) + other_user = users.get_by_name(ctx.session, ctx.get_param_as_string('name')) if other_user and other_user.user_id != user.user_id: raise errors.IntegrityError('User %r already exists.' % user.name) - users.update_name(user, context.request['name']) + users.update_name(user, ctx.get_param_as_string('name')) - if 'password' in context.request: - auth.verify_privilege(context.user, 'users:edit:%s:pass' % infix) - users.update_password(user, context.request['password']) + if ctx.has_param('password'): + auth.verify_privilege(ctx.user, 'users:edit:%s:pass' % infix) + users.update_password(user, ctx.get_param_as_string('password')) - if 'email' in context.request: - auth.verify_privilege(context.user, 'users:edit:%s:email' % infix) - users.update_email(user, context.request['email']) + if ctx.has_param('email'): + auth.verify_privilege(ctx.user, 'users:edit:%s:email' % infix) + users.update_email(user, ctx.get_param_as_string('email')) - if 'rank' in context.request: - auth.verify_privilege(context.user, 'users:edit:%s:rank' % infix) - users.update_rank(user, context.request['rank'], context.user) + if ctx.has_param('rank'): + auth.verify_privilege(ctx.user, 'users:edit:%s:rank' % infix) + users.update_rank(user, ctx.get_param_as_string('rank'), ctx.user) - if 'avatarStyle' in context.request: - auth.verify_privilege(context.user, 'users:edit:%s:avatar' % infix) + if ctx.has_param('avatarStyle'): + auth.verify_privilege(ctx.user, 'users:edit:%s:avatar' % infix) users.update_avatar( user, - context.request['avatarStyle'], - context.files.get('avatar') or None) + ctx.get_param_as_string('avatarStyle'), + ctx.get_file('avatar')) - context.session.commit() - return {'user': _serialize_user(context.user, user)} + ctx.session.commit() + return {'user': _serialize_user(ctx.user, user)} - def delete(self, context, user_name): - user = users.get_by_name(context.session, user_name) + def delete(self, ctx, user_name): + user = users.get_by_name(ctx.session, user_name) if not user: raise errors.NotFoundError('User %r not found.' % user_name) - if context.user.user_id == user.user_id: + if ctx.user.user_id == user.user_id: infix = 'self' else: infix = 'any' - auth.verify_privilege(context.user, 'users:delete:%s' % infix) - context.session.delete(user) - context.session.commit() + auth.verify_privilege(ctx.user, 'users:delete:%s' % infix) + ctx.session.delete(user) + ctx.session.commit() return {} diff --git a/server/szurubooru/app.py b/server/szurubooru/app.py index da00d7fa..0481cea7 100644 --- a/server/szurubooru/app.py +++ b/server/szurubooru/app.py @@ -6,22 +6,6 @@ import sqlalchemy.orm from szurubooru import api, config, errors, middleware from szurubooru.util import misc -class _CustomRequest(falcon.Request): - context_type = misc.dotdict - - def get_param_as_string(self, name, required=False, store=None, default=None): - params = self._params - if name in params: - param = params[name] - if isinstance(param, list): - param = ','.join(param) - if store is not None: - store[name] = param - return param - if not required: - return default - raise falcon.HTTPMissingParam(name) - def _on_auth_error(ex, _request, _response, _params): raise falcon.HTTPForbidden( title='Authentication error', description=str(ex)) @@ -56,11 +40,10 @@ def create_app(): scoped_session = sqlalchemy.orm.scoped_session(session_maker) app = falcon.API( - request_type=_CustomRequest, + request_type=api.Request, middleware=[ - middleware.ImbueContext(), middleware.RequireJson(), - middleware.JsonTranslator(), + middleware.ContextAdapter(), middleware.DbSession(scoped_session), middleware.Authenticator(), ]) diff --git a/server/szurubooru/middleware/__init__.py b/server/szurubooru/middleware/__init__.py index e958790f..84665901 100644 --- a/server/szurubooru/middleware/__init__.py +++ b/server/szurubooru/middleware/__init__.py @@ -1,7 +1,6 @@ ''' Various hooks that get executed for each request. ''' from szurubooru.middleware.authenticator import Authenticator -from szurubooru.middleware.json_translator import JsonTranslator +from szurubooru.middleware.context_adapter import ContextAdapter from szurubooru.middleware.require_json import RequireJson from szurubooru.middleware.db_session import DbSession -from szurubooru.middleware.imbue_context import ImbueContext diff --git a/server/szurubooru/middleware/json_translator.py b/server/szurubooru/middleware/context_adapter.py similarity index 74% rename from server/szurubooru/middleware/json_translator.py rename to server/szurubooru/middleware/context_adapter.py index e59d2228..e0db8a94 100644 --- a/server/szurubooru/middleware/json_translator.py +++ b/server/szurubooru/middleware/context_adapter.py @@ -10,17 +10,23 @@ def json_serializer(obj): return serial raise TypeError('Type not serializable') -class JsonTranslator(object): +class ContextAdapter(object): ''' - Translates API requests and API responses to JSON using requests' - context. + 1. Deserialize API requests into the context: + - Pass GET parameters + - Handle multipart/form-data file uploads + - Handle JSON requests + 2. Serialize API responses from the context as JSON. ''' - def process_request(self, request, _response): + request.context.files = {} + request.context.input = {} + for key, value in request._params.items(): + request.context.input[key] = value + if request.content_length in (None, 0): return - request.context.files = {} if 'multipart/form-data' in (request.content_type or ''): # obscure, claims to "avoid a bug in cgi.FieldStorage" request.env.setdefault('QUERY_STRING', '') @@ -43,7 +49,8 @@ class JsonTranslator(object): if isinstance(body, bytes): body = body.decode('utf-8') - request.context.request = json.loads(body) + for key, value in json.loads(body).items(): + request.context.input[key] = value except (ValueError, UnicodeDecodeError): raise falcon.HTTPError( falcon.HTTP_401, @@ -52,7 +59,7 @@ class JsonTranslator(object): 'JSON was incorrect or not encoded as UTF-8.') def process_response(self, request, response, _resource): - if 'result' not in request.context: + if not request.context.output: return response.body = json.dumps( - request.context.result, default=json_serializer, indent=2) + request.context.output, default=json_serializer, indent=2) diff --git a/server/szurubooru/middleware/imbue_context.py b/server/szurubooru/middleware/imbue_context.py deleted file mode 100644 index da1a91da..00000000 --- a/server/szurubooru/middleware/imbue_context.py +++ /dev/null @@ -1,8 +0,0 @@ -class ImbueContext(object): - ''' Decorates context with methods from falcon's request. ''' - - def process_request(self, request, _response): - request.context.get_param_as_string = request.get_param_as_string - request.context.get_param_as_bool = request.get_param_as_bool - request.context.get_param_as_int = request.get_param_as_int - request.context.get_param_as_list = request.get_param_as_list diff --git a/server/szurubooru/search/search_executor.py b/server/szurubooru/search/search_executor.py index 992a4a65..8cbd6c38 100644 --- a/server/szurubooru/search/search_executor.py +++ b/server/szurubooru/search/search_executor.py @@ -20,8 +20,6 @@ class SearchExecutor(object): Parse input and return tuple containing total record count and filtered entities. ''' - page = max(1, int(page)) - page_size = max(1, int(page_size)) filter_query = self._prepare(session, query_text) entities = filter_query \ .offset((page - 1) * page_size).limit(page_size).all() diff --git a/server/szurubooru/tests/api/test_context.py b/server/szurubooru/tests/api/test_context.py new file mode 100644 index 00000000..2fce5bec --- /dev/null +++ b/server/szurubooru/tests/api/test_context.py @@ -0,0 +1,43 @@ +import pytest +from szurubooru import api, errors + +def test_has_param(): + ctx = api.Context() + ctx.input = {'key': 'value'} + assert ctx.has_param('key') + assert not ctx.has_param('key2') + +def test_get_file(): + ctx = api.Context() + ctx.files = {'key': b'content'} + assert ctx.get_file('key') == b'content' + assert ctx.get_file('key2') is None + +def test_getting_string_parameter(): + ctx = api.Context() + ctx.input = {'key': 'value', 'list': ['1', '2', '3']} + assert ctx.get_param_as_string('key') == 'value' + assert ctx.get_param_as_string('key2') is None + assert ctx.get_param_as_string('key2', default='def') == 'def' + assert ctx.get_param_as_string('list') == '1,2,3' # falcon issue #749 + with pytest.raises(errors.ValidationError): + ctx.get_param_as_string('key2', required=True) + +def test_getting_int_parameter(): + ctx = api.Context() + ctx.input = {'key': '50', 'err': 'invalid', 'list': [1, 2, 3]} + assert ctx.get_param_as_int('key') == 50 + assert ctx.get_param_as_int('key2') is None + assert ctx.get_param_as_int('key2', default=5) == 5 + with pytest.raises(errors.ValidationError): + ctx.get_param_as_int('list') + with pytest.raises(errors.ValidationError): + ctx.get_param_as_int('key2', required=True) + with pytest.raises(errors.ValidationError): + ctx.get_param_as_int('err') + with pytest.raises(errors.ValidationError): + assert ctx.get_param_as_int('key', min=50) == 50 + ctx.get_param_as_int('key', min=51) + with pytest.raises(errors.ValidationError): + assert ctx.get_param_as_int('key', max=50) == 50 + ctx.get_param_as_int('key', max=49) diff --git a/server/szurubooru/tests/api/test_password_reset.py b/server/szurubooru/tests/api/test_password_reset.py index b29590f0..9d48563c 100644 --- a/server/szurubooru/tests/api/test_password_reset.py +++ b/server/szurubooru/tests/api/test_password_reset.py @@ -58,21 +58,21 @@ def test_confirmation_no_token(password_reset_api, context_factory, session): user = mock_user('u1', 'regular_user', 'user@example.com') session.add(user) with pytest.raises(errors.ValidationError): - password_reset_api.post(context_factory(request={}), 'u1') + password_reset_api.post(context_factory(input={}), 'u1') def test_confirmation_bad_token(password_reset_api, context_factory, session): user = mock_user('u1', 'regular_user', 'user@example.com') session.add(user) with pytest.raises(errors.ValidationError): password_reset_api.post( - context_factory(request={'token': 'bad'}), 'u1') + context_factory(input={'token': 'bad'}), 'u1') def test_confirmation_good_token(password_reset_api, context_factory, session): user = mock_user('u1', 'regular_user', 'user@example.com') old_hash = user.password_hash session.add(user) context = context_factory( - request={'token': '4ac0be176fb364f13ee6b634c43220e2'}) + input={'token': '4ac0be176fb364f13ee6b634c43220e2'}) result = password_reset_api.post(context, 'u1') assert user.password_hash != old_hash assert auth.is_valid_password(user, result['password']) is True diff --git a/server/szurubooru/tests/api/test_user_creating.py b/server/szurubooru/tests/api/test_user_creating.py index b9d1204a..c999b72d 100644 --- a/server/szurubooru/tests/api/test_user_creating.py +++ b/server/szurubooru/tests/api/test_user_creating.py @@ -26,7 +26,7 @@ def test_creating_users( user_list_api.post( context_factory( - request={ + input={ 'name': 'chewie1', 'email': 'asd@asd.asd', 'password': 'oks', @@ -34,7 +34,7 @@ def test_creating_users( user=user_factory(rank='regular_user'))) user_list_api.post( context_factory( - request={ + input={ 'name': 'chewie2', 'email': 'asd@asd.asd', 'password': 'sok', @@ -68,7 +68,7 @@ def test_creating_user_that_already_exists( }) user_list_api.post( context_factory( - request={ + input={ 'name': 'chewie', 'email': 'asd@asd.asd', 'password': 'oks', @@ -77,7 +77,7 @@ def test_creating_user_that_already_exists( with pytest.raises(errors.IntegrityError): user_list_api.post( context_factory( - request={ + input={ 'name': 'chewie', 'email': 'asd@asd.asd', 'password': 'oks', @@ -86,7 +86,7 @@ def test_creating_user_that_already_exists( with pytest.raises(errors.IntegrityError): user_list_api.post( context_factory( - request={ + input={ 'name': 'CHEWIE', 'email': 'asd@asd.asd', 'password': 'oks', @@ -109,4 +109,4 @@ def test_missing_field( with pytest.raises(errors.ValidationError): user_list_api.post( context_factory( - request=request, user=user_factory(rank='regular_user'))) + input=request, user=user_factory(rank='regular_user'))) diff --git a/server/szurubooru/tests/api/test_user_retrieval.py b/server/szurubooru/tests/api/test_user_retrieval.py index 6b017dd6..845b3c09 100644 --- a/server/szurubooru/tests/api/test_user_retrieval.py +++ b/server/szurubooru/tests/api/test_user_retrieval.py @@ -27,7 +27,7 @@ def test_retrieving_multiple( session.add_all([user1, user2]) result = user_list_api.get( context_factory( - params={'query': '', 'page': 1}, + input={'query': '', 'page': 1}, user=user_factory(rank='regular_user'))) assert result['query'] == '' assert result['page'] == 1 @@ -44,7 +44,7 @@ def test_retrieving_multiple_without_privileges( with pytest.raises(errors.AuthError): user_list_api.get( context_factory( - params={'query': '', 'page': 1}, + input={'query': '', 'page': 1}, user=user_factory(rank='anonymous'))) def test_retrieving_multiple_with_privileges( @@ -55,7 +55,7 @@ def test_retrieving_multiple_with_privileges( }) result = user_list_api.get( context_factory( - params={'query': 'asd', 'page': 1}, + input={'query': 'asd', 'page': 1}, user=user_factory(rank='regular_user'))) assert result['query'] == 'asd' assert result['page'] == 1 @@ -79,7 +79,7 @@ def test_retrieving_single( session.add(user) result = user_detail_api.get( context_factory( - params={'query': '', 'page': 1}, + input={'query': '', 'page': 1}, user=user_factory(rank='regular_user')), 'u1') assert result['user']['id'] == user.user_id @@ -98,7 +98,7 @@ def test_retrieving_non_existing( with pytest.raises(errors.NotFoundError): user_detail_api.get( context_factory( - params={'query': '', 'page': 1}, + input={'query': '', 'page': 1}, user=user_factory(rank='regular_user')), '-') @@ -111,6 +111,6 @@ def test_retrieving_single_without_privileges( with pytest.raises(errors.AuthError): user_detail_api.get( context_factory( - params={'query': '', 'page': 1}, + input={'query': '', 'page': 1}, user=user_factory(rank='anonymous')), '-') diff --git a/server/szurubooru/tests/api/test_user_updating.py b/server/szurubooru/tests/api/test_user_updating.py index 79205838..38e95b49 100644 --- a/server/szurubooru/tests/api/test_user_updating.py +++ b/server/szurubooru/tests/api/test_user_updating.py @@ -31,7 +31,7 @@ def test_updating_user( session.add(user) user_detail_api.put( context_factory( - request={ + input={ 'name': 'chewie', 'email': 'asd@asd.asd', 'password': 'oks', @@ -96,7 +96,7 @@ def test_removing_email( user = user_factory(name='u1', rank='admin') session.add(user) user_detail_api.put( - context_factory(request={'email': ''}, user=user), 'u1') + context_factory(input={'email': ''}, user=user), 'u1') assert session.query(db.User).filter_by(name='u1').one().email is None @pytest.mark.parametrize('request', [ @@ -128,7 +128,7 @@ def test_invalid_inputs( user = user_factory(name='u1', rank='admin') session.add(user) with pytest.raises(errors.ValidationError): - user_detail_api.put(context_factory(request=request, user=user), 'u1') + user_detail_api.put(context_factory(input=request, user=user), 'u1') @pytest.mark.parametrize('request', [ {'name': 'whatever'}, @@ -159,7 +159,7 @@ def test_user_trying_to_update_someone_else( session.add_all([user1, user2]) with pytest.raises(errors.AuthError): user_detail_api.put( - context_factory(request=request, user=user1), user2.name) + context_factory(input=request, user=user1), user2.name) def test_user_trying_to_become_someone_else( session, @@ -176,11 +176,11 @@ def test_user_trying_to_become_someone_else( session.add_all([user1, user2]) with pytest.raises(errors.IntegrityError): user_detail_api.put( - context_factory(request={'name': 'her'}, user=user1), + context_factory(input={'name': 'her'}, user=user1), 'me') with pytest.raises(errors.IntegrityError): user_detail_api.put( - context_factory(request={'name': 'HER'}, user=user1), 'me') + context_factory(input={'name': 'HER'}, user=user1), 'me') def test_mods_trying_to_become_admin( session, @@ -198,7 +198,7 @@ def test_mods_trying_to_become_admin( user1 = user_factory(name='u1', rank='mod') user2 = user_factory(name='u2', rank='mod') session.add_all([user1, user2]) - context = context_factory(request={'rank': 'admin'}, user=user1) + context = context_factory(input={'rank': 'admin'}, user=user1) with pytest.raises(errors.AuthError): user_detail_api.put(context, user1.name) with pytest.raises(errors.AuthError): @@ -227,7 +227,7 @@ def test_uploading_avatar( b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' response = user_detail_api.put( context_factory( - request={'avatarStyle': 'manual'}, + input={'avatarStyle': 'manual'}, files={'avatar': empty_pixel}, user=user), 'u1') diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index 7e1851c4..14fb7d02 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -1,7 +1,7 @@ from datetime import datetime import pytest import sqlalchemy -from szurubooru import db, config +from szurubooru import api, config, db from szurubooru.util import misc @pytest.fixture @@ -15,28 +15,14 @@ def session(): @pytest.fixture def context_factory(session): - def factory(request=None, params=None, files=None, user=None): - params = params or {} - def get_param_as_string(key, default=None, required=False): - if key not in params: - if required: - raise RuntimeError('Param is missing!') - return default - return params[key] - def get_param_as_int(key, default=None, required=False): - if key not in params: - if required: - raise RuntimeError('Param is missing!') - return default - return int(params[key]) - context = misc.dotdict() - context.session = session - context.request = request or {} - context.files = files or {} - context.user = user or db.User() - context.get_param_as_string = get_param_as_string - context.get_param_as_int = get_param_as_int - return context + def factory(request=None, input=None, files=None, user=None): + ctx = api.Context() + ctx.input = input or {} + ctx.session = session + ctx.request = request or {} + ctx.files = files or {} + ctx.user = user or db.User() + return ctx return factory @pytest.fixture diff --git a/server/szurubooru/tests/search/test_user_search_config.py b/server/szurubooru/tests/search/test_user_search_config.py index 07c9d009..b1ec9337 100644 --- a/server/szurubooru/tests/search/test_user_search_config.py +++ b/server/szurubooru/tests/search/test_user_search_config.py @@ -123,7 +123,7 @@ def test_combining_tokens(session, verify_unpaged, input, expected_user_names): (2, 1, 2, ['u2']), (3, 1, 2, []), (0, 1, 2, ['u1']), - (0, 0, 2, ['u1']), + (0, 0, 2, []), ]) def test_paging( session, executor, page, page_size,