back/middleware: change context to dotdict

This commit is contained in:
rr- 2016-03-28 22:53:56 +02:00
parent 5a0ce0b49d
commit 509fd0620d
8 changed files with 46 additions and 27 deletions

View file

@ -24,19 +24,19 @@ class UserListApi(object):
def on_get(self, request, response): def on_get(self, request, response):
''' Retrieves a list of users. ''' ''' Retrieves a list of users. '''
self._auth_service.verify_privilege(request.context['user'], 'users:list') self._auth_service.verify_privilege(request.context.user, 'users:list')
request.context['result'] = {'message': 'Searching for users'} request.context.result = {'message': 'Searching for users'}
def on_post(self, request, response): def on_post(self, request, response):
''' Creates a new user. ''' ''' Creates a new user. '''
self._auth_service.verify_privilege(request.context['user'], 'users:create') self._auth_service.verify_privilege(request.context.user, 'users:create')
name_regex = self._config['service']['user_name_regex'] name_regex = self._config['service']['user_name_regex']
password_regex = self._config['service']['password_regex'] password_regex = self._config['service']['password_regex']
try: try:
name = request.context['doc']['name'] name = request.context.request['name']
password = request.context['doc']['password'] password = request.context.request['password']
email = request.context['doc']['email'].strip() email = request.context.request['email'].strip()
if not email: if not email:
email = None email = None
except KeyError as ex: except KeyError as ex:
@ -53,13 +53,13 @@ class UserListApi(object):
'Malformed data', 'Malformed data',
'Password must validate %r expression' % password_regex) 'Password must validate %r expression' % password_regex)
session = request.context['session'] session = request.context.session
try: try:
user = self._user_service.create_user(session, name, password, email) user = self._user_service.create_user(session, name, password, email)
session.commit() session.commit()
except: except:
raise IntegrityError('User %r already exists.' % name) raise IntegrityError('User %r already exists.' % name)
request.context['result'] = {'user': _serialize_user(user)} request.context.result = {'user': _serialize_user(user)}
class UserDetailApi(object): class UserDetailApi(object):
''' API for individual users. ''' ''' API for individual users. '''
@ -70,12 +70,12 @@ class UserDetailApi(object):
def on_get(self, request, response, user_name): def on_get(self, request, response, user_name):
''' Retrieves an user. ''' ''' Retrieves an user. '''
self._auth_service.verify_privilege(request.context['user'], 'users:view') self._auth_service.verify_privilege(request.context.user, 'users:view')
session = request.context['session'] session = request.context.session
user = self._user_service.get_by_name(session, user_name) user = self._user_service.get_by_name(session, user_name)
request.context['result'] = _serialize_user(user) request.context.result = _serialize_user(user)
def on_put(self, request, response, user_name): def on_put(self, request, response, user_name):
''' Updates an existing user. ''' ''' Updates an existing user. '''
self._auth_service.verify_privilege(request.context['user'], 'users:edit') self._auth_service.verify_privilege(request.context.user, 'users:edit')
request.context['result'] = {'message': 'Updating user ' + user_name} request.context.result = {'message': 'Updating user ' + user_name}

View file

@ -8,6 +8,10 @@ import szurubooru.api
import szurubooru.config import szurubooru.config
import szurubooru.middleware import szurubooru.middleware
import szurubooru.services import szurubooru.services
import szurubooru.util
class _CustomRequest(falcon.Request):
context_type = szurubooru.util.dotdict
def _on_auth_error(ex, req, resp, params): def _on_auth_error(ex, req, resp, params):
raise falcon.HTTPForbidden('Authentication error', str(ex)) raise falcon.HTTPForbidden('Authentication error', str(ex))
@ -40,12 +44,14 @@ def create_app():
user_list = szurubooru.api.UserListApi(config, auth_service, user_service) user_list = szurubooru.api.UserListApi(config, auth_service, user_service)
user = szurubooru.api.UserDetailApi(config, auth_service, user_service) user = szurubooru.api.UserDetailApi(config, auth_service, user_service)
app = falcon.API(middleware=[ app = falcon.API(
szurubooru.middleware.RequireJson(), request_type=_CustomRequest,
szurubooru.middleware.JsonTranslator(), middleware=[
szurubooru.middleware.DbSession(session_maker), szurubooru.middleware.RequireJson(),
szurubooru.middleware.Authenticator(auth_service, user_service), szurubooru.middleware.JsonTranslator(),
]) szurubooru.middleware.DbSession(session_maker),
szurubooru.middleware.Authenticator(auth_service, user_service),
])
app.add_error_handler(szurubooru.services.AuthError, _on_auth_error) app.add_error_handler(szurubooru.services.AuthError, _on_auth_error)
app.add_error_handler(szurubooru.services.IntegrityError, _on_integrity_error) app.add_error_handler(szurubooru.services.IntegrityError, _on_integrity_error)

View file

@ -17,7 +17,7 @@ class Authenticator(object):
def process_request(self, request, response): def process_request(self, request, response):
''' Executed before passing the request to the API. ''' ''' Executed before passing the request to the API. '''
request.context['user'] = self._get_user(request) request.context.user = self._get_user(request)
def _get_user(self, request): def _get_user(self, request):
if not request.auth: if not request.auth:
@ -34,7 +34,7 @@ class Authenticator(object):
username, password = base64.decodebytes( username, password = base64.decodebytes(
user_and_password.encode('ascii')).decode('utf8').split(':') user_and_password.encode('ascii')).decode('utf8').split(':')
session = request.context['session'] session = request.context.session
return self._authenticate(session, username, password) return self._authenticate(session, username, password)
except ValueError as err: except ValueError as err:
msg = 'Basic authentication header value not properly formed. ' \ msg = 'Basic authentication header value not properly formed. ' \

View file

@ -8,7 +8,11 @@ class DbSession(object):
def process_request(self, request, response): def process_request(self, request, response):
''' Executed before passing the request to the API. ''' ''' Executed before passing the request to the API. '''
request.context['session'] = self._session_factory() request.context.session = self._session_factory()
def process_response(self, request, response, resource): def process_response(self, request, response, resource):
request.context['session'].close() '''
Executed before passing the response to falcon.
Any commits to database need to happen explicitly in the API layer.
'''
request.context.session.close()

View file

@ -29,7 +29,7 @@ class JsonTranslator(object):
'A valid JSON document is required.') 'A valid JSON document is required.')
try: try:
request.context['doc'] = json.loads(body.decode('utf-8')) request.context.request = json.loads(body.decode('utf-8'))
except (ValueError, UnicodeDecodeError): except (ValueError, UnicodeDecodeError):
raise falcon.HTTPError( raise falcon.HTTPError(
falcon.HTTP_401, falcon.HTTP_401,
@ -41,5 +41,4 @@ class JsonTranslator(object):
''' Executed before passing the response to falcon. ''' ''' Executed before passing the response to falcon. '''
if 'result' not in request.context: if 'result' not in request.context:
return return
response.body = json.dumps( response.body = json.dumps(request.context.result, default=json_serial)
request.context['result'], default=json_serial)

View file

@ -1,4 +1,4 @@
''' Base model for every database resource. ''' ''' Base model for every database resource. '''
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base() # pylint: disable=C0103 Base = declarative_base() # pylint: disable=invalid-name

View file

@ -1,3 +1,5 @@
# pylint: disable=too-many-instance-attributes,too-few-public-methods
''' Exports User. ''' ''' Exports User. '''
import sqlalchemy as sa import sqlalchemy as sa

8
szurubooru/util.py Normal file
View file

@ -0,0 +1,8 @@
''' Exports dotdict. '''
class dotdict(dict): # pylint: disable=invalid-name
'''dot.notation access to dictionary attributes'''
def __getattr__(self, attr):
return self.get(attr)
__setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__