server/password-reset: try to construct full URL

This commit is contained in:
rr- 2018-07-08 10:04:09 +02:00
parent d85e746a65
commit c9cb9aa539
5 changed files with 23 additions and 8 deletions

View file

@ -13,7 +13,7 @@ MAIL_BODY = (
@rest.routes.get('/password-reset/(?P<user_name>[^/]+)/?')
def start_password_reset(
_ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
user_name = params['user_name']
user = users.get_user_by_name_or_email(user_name)
if not user.email:
@ -21,12 +21,19 @@ def start_password_reset(
'User %r hasn\'t supplied email. Cannot reset password.' % (
user_name))
token = auth.generate_authentication_token(user)
url = '/password-reset/%s:%s' % (user.name, token)
if 'HTTP_ORIGIN' in ctx.env:
url = ctx.env['HTTP_ORIGIN'].rstrip('/')
else:
url = ''
url += '/password-reset/%s:%s' % (user.name, token)
mailer.send_mail(
'noreply@%s' % config.config['name'],
user.email,
MAIL_SUBJECT.format(name=config.config['name']),
MAIL_BODY.format(name=config.config['name'], url=url))
return {}

View file

@ -63,7 +63,7 @@ def _create_context(env: Dict[str, Any]) -> context.Context:
'Could not decode the request body. The JSON '
'was incorrect or was not encoded as UTF-8.')
return context.Context(method, path, headers, params, files)
return context.Context(env, method, path, headers, params, files)
def application(

View file

@ -11,11 +11,13 @@ Response = Optional[Dict[str, Any]]
class Context:
def __init__(
self,
env: Dict[str, Any],
method: str,
url: str,
headers: Dict[str, str] = None,
params: Request = None,
files: Dict[str, bytes] = None) -> None:
self.env = env
self.method = method
self.url = url
self._headers = headers or {}

View file

@ -95,6 +95,7 @@ def session(query_logger): # pylint: disable=unused-argument
def context_factory(session):
def factory(params=None, files=None, user=None, headers=None):
ctx = rest.Context(
env={'HTTP_ORIGIN': 'http://example.com'},
method=None,
url=None,
headers=headers or {},

View file

@ -6,13 +6,14 @@ from szurubooru.func import net
def test_has_param():
ctx = rest.Context(method=None, url=None, params={'key': 'value'})
ctx = rest.Context(env={}, method=None, url=None, params={'key': 'value'})
assert ctx.has_param('key')
assert not ctx.has_param('non-existing')
def test_get_file():
ctx = rest.Context(method=None, url=None, files={'key': b'content'})
ctx = rest.Context(
env={}, method=None, url=None, files={'key': b'content'})
assert ctx.get_file('key') == b'content'
with pytest.raises(errors.ValidationError):
ctx.get_file('non-existing')
@ -22,7 +23,7 @@ def test_get_file_from_url():
with unittest.mock.patch('szurubooru.func.net.download'):
net.download.return_value = b'content'
ctx = rest.Context(
method=None, url=None, params={'keyUrl': 'example.com'})
env={}, method=None, url=None, params={'keyUrl': 'example.com'})
assert ctx.get_file('key') == b'content'
net.download.assert_called_once_with('example.com')
with pytest.raises(errors.ValidationError):
@ -31,6 +32,7 @@ def test_get_file_from_url():
def test_getting_list_parameter():
ctx = rest.Context(
env={},
method=None,
url=None,
params={'key': 'value', 'list': ['1', '2', '3']})
@ -43,6 +45,7 @@ def test_getting_list_parameter():
def test_getting_string_parameter():
ctx = rest.Context(
env={},
method=None,
url=None,
params={'key': 'value', 'list': ['1', '2', '3']})
@ -55,6 +58,7 @@ def test_getting_string_parameter():
def test_getting_int_parameter():
ctx = rest.Context(
env={},
method=None,
url=None,
params={'key': '50', 'err': 'invalid', 'list': [1, 2, 3]})
@ -76,7 +80,8 @@ def test_getting_int_parameter():
def test_getting_bool_parameter():
def test(value):
ctx = rest.Context(method=None, url=None, params={'key': value})
ctx = rest.Context(
env={}, method=None, url=None, params={'key': value})
return ctx.get_param_as_bool('key')
assert test('1') is True
@ -104,7 +109,7 @@ def test_getting_bool_parameter():
with pytest.raises(errors.ValidationError):
test(['1', '2'])
ctx = rest.Context(method=None, url=None)
ctx = rest.Context(env={}, method=None, url=None)
with pytest.raises(errors.ValidationError):
ctx.get_param_as_bool('non-existing')
assert ctx.get_param_as_bool('non-existing', default=True) is True