diff --git a/server/szurubooru/api/__init__.py b/server/szurubooru/api/__init__.py index 70756588..75c13b24 100644 --- a/server/szurubooru/api/__init__.py +++ b/server/szurubooru/api/__init__.py @@ -1,3 +1,3 @@ ''' Falcon-compatible API facades. ''' -from szurubooru.api.users import UserListApi, UserDetailApi +from szurubooru.api.user_api import UserListApi, UserDetailApi diff --git a/server/szurubooru/api/base_api.py b/server/szurubooru/api/base_api.py new file mode 100644 index 00000000..12e5ed00 --- /dev/null +++ b/server/szurubooru/api/base_api.py @@ -0,0 +1,28 @@ +''' Exports BaseApi. ''' + +import types + +def _bind_method(target, desired_method_name): + actual_method = getattr(target, desired_method_name) + def _wrapper_method(self, request, response, *args, **kwargs): + request.context.result = actual_method(request.context, *args, **kwargs) + return types.MethodType(_wrapper_method, target) + +class BaseApi(object): + ''' + A wrapper around falcon's API interface that eases context and result + management. + ''' + + def __init__(self): + self._translate_routes() + + def _translate_routes(self): + for method_name in ['GET', 'PUT', 'POST', 'DELETE']: + desired_method_name = method_name.lower() + falcon_method_name = 'on_%s' % method_name.lower() + if hasattr(self, desired_method_name): + setattr( + self, + falcon_method_name, + _bind_method(self, desired_method_name)) diff --git a/server/szurubooru/api/users.py b/server/szurubooru/api/user_api.py similarity index 58% rename from server/szurubooru/api/users.py rename to server/szurubooru/api/user_api.py index 9fddd11f..6e1824af 100644 --- a/server/szurubooru/api/users.py +++ b/server/szurubooru/api/user_api.py @@ -3,6 +3,7 @@ import re import falcon from szurubooru.services.errors import IntegrityError +from szurubooru.api.base_api import BaseApi def _serialize_user(authenticated_user, user): ret = { @@ -17,28 +18,29 @@ def _serialize_user(authenticated_user, user): ret['email'] = user.email return ret -class UserListApi(object): +class UserListApi(BaseApi): ''' API for lists of users. ''' def __init__(self, config, auth_service, user_service): + super().__init__() self._config = config self._auth_service = auth_service self._user_service = user_service - def on_get(self, request, response): + def get(self, context): ''' Retrieves a list of users. ''' - self._auth_service.verify_privilege(request.context.user, 'users:list') - request.context.result = {'message': 'Searching for users'} + self._auth_service.verify_privilege(context.user, 'users:list') + return {'message': 'Searching for users'} - def on_post(self, request, response): + def post(self, context): ''' Creates a new user. ''' - self._auth_service.verify_privilege(request.context.user, 'users:create') + self._auth_service.verify_privilege(context.user, 'users:create') name_regex = self._config['service']['user_name_regex'] password_regex = self._config['service']['password_regex'] try: - name = request.context.request['name'] - password = request.context.request['password'] - email = request.context.request['email'].strip() + name = context.request['name'] + password = context.request['password'] + email = context.request['email'].strip() if not email: email = None except KeyError as ex: @@ -55,31 +57,29 @@ class UserListApi(object): 'Malformed data', 'Password must validate %r expression' % password_regex) - session = request.context.session try: - user = self._user_service.create_user(session, name, password, email) - session.commit() + user = self._user_service.create_user( + context.session, name, password, email) + context.session.commit() except: raise IntegrityError('User %r already exists.' % name) - request.context.result = { - 'user': _serialize_user(request.context.user, user)} + return {'user': _serialize_user(context.user, user)} -class UserDetailApi(object): +class UserDetailApi(BaseApi): ''' API for individual users. ''' def __init__(self, config, auth_service, user_service): + super().__init__() self._config = config self._auth_service = auth_service self._user_service = user_service - def on_get(self, request, response, user_name): + def get(self, context, user_name): ''' Retrieves an user. ''' - self._auth_service.verify_privilege(request.context.user, 'users:view') - session = request.context.session - user = self._user_service.get_by_name(session, user_name) - request.context.result = { - 'user': _serialize_user(request.context.user, user)} + self._auth_service.verify_privilege(context.user, 'users:view') + user = self._user_service.get_by_name(context.session, user_name) + return {'user': _serialize_user(context.user, user)} - def on_put(self, request, response, user_name): + def put(self, context, user_name): ''' Updates an existing user. ''' - self._auth_service.verify_privilege(request.context.user, 'users:edit') - request.context.result = {'message': 'Updating user ' + user_name} + self._auth_service.verify_privilege(context.user, 'users:edit') + return {'message': 'Updating user ' + user_name}