server/password-reset: try to construct full URL
This commit is contained in:
parent
d85e746a65
commit
c9cb9aa539
5 changed files with 23 additions and 8 deletions
|
@ -13,7 +13,7 @@ MAIL_BODY = (
|
|||
|
||||
@rest.routes.get('/password-reset/(?P<user_name>[^/]+)/?')
|
||||
def start_password_reset(
|
||||
_ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||
user_name = params['user_name']
|
||||
user = users.get_user_by_name_or_email(user_name)
|
||||
if not user.email:
|
||||
|
@ -21,12 +21,19 @@ def start_password_reset(
|
|||
'User %r hasn\'t supplied email. Cannot reset password.' % (
|
||||
user_name))
|
||||
token = auth.generate_authentication_token(user)
|
||||
url = '/password-reset/%s:%s' % (user.name, token)
|
||||
|
||||
if 'HTTP_ORIGIN' in ctx.env:
|
||||
url = ctx.env['HTTP_ORIGIN'].rstrip('/')
|
||||
else:
|
||||
url = ''
|
||||
url += '/password-reset/%s:%s' % (user.name, token)
|
||||
|
||||
mailer.send_mail(
|
||||
'noreply@%s' % config.config['name'],
|
||||
user.email,
|
||||
MAIL_SUBJECT.format(name=config.config['name']),
|
||||
MAIL_BODY.format(name=config.config['name'], url=url))
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
|
|
|
@ -63,7 +63,7 @@ def _create_context(env: Dict[str, Any]) -> context.Context:
|
|||
'Could not decode the request body. The JSON '
|
||||
'was incorrect or was not encoded as UTF-8.')
|
||||
|
||||
return context.Context(method, path, headers, params, files)
|
||||
return context.Context(env, method, path, headers, params, files)
|
||||
|
||||
|
||||
def application(
|
||||
|
|
|
@ -11,11 +11,13 @@ Response = Optional[Dict[str, Any]]
|
|||
class Context:
|
||||
def __init__(
|
||||
self,
|
||||
env: Dict[str, Any],
|
||||
method: str,
|
||||
url: str,
|
||||
headers: Dict[str, str] = None,
|
||||
params: Request = None,
|
||||
files: Dict[str, bytes] = None) -> None:
|
||||
self.env = env
|
||||
self.method = method
|
||||
self.url = url
|
||||
self._headers = headers or {}
|
||||
|
|
|
@ -95,6 +95,7 @@ def session(query_logger): # pylint: disable=unused-argument
|
|||
def context_factory(session):
|
||||
def factory(params=None, files=None, user=None, headers=None):
|
||||
ctx = rest.Context(
|
||||
env={'HTTP_ORIGIN': 'http://example.com'},
|
||||
method=None,
|
||||
url=None,
|
||||
headers=headers or {},
|
||||
|
|
|
@ -6,13 +6,14 @@ from szurubooru.func import net
|
|||
|
||||
|
||||
def test_has_param():
|
||||
ctx = rest.Context(method=None, url=None, params={'key': 'value'})
|
||||
ctx = rest.Context(env={}, method=None, url=None, params={'key': 'value'})
|
||||
assert ctx.has_param('key')
|
||||
assert not ctx.has_param('non-existing')
|
||||
|
||||
|
||||
def test_get_file():
|
||||
ctx = rest.Context(method=None, url=None, files={'key': b'content'})
|
||||
ctx = rest.Context(
|
||||
env={}, method=None, url=None, files={'key': b'content'})
|
||||
assert ctx.get_file('key') == b'content'
|
||||
with pytest.raises(errors.ValidationError):
|
||||
ctx.get_file('non-existing')
|
||||
|
@ -22,7 +23,7 @@ def test_get_file_from_url():
|
|||
with unittest.mock.patch('szurubooru.func.net.download'):
|
||||
net.download.return_value = b'content'
|
||||
ctx = rest.Context(
|
||||
method=None, url=None, params={'keyUrl': 'example.com'})
|
||||
env={}, method=None, url=None, params={'keyUrl': 'example.com'})
|
||||
assert ctx.get_file('key') == b'content'
|
||||
net.download.assert_called_once_with('example.com')
|
||||
with pytest.raises(errors.ValidationError):
|
||||
|
@ -31,6 +32,7 @@ def test_get_file_from_url():
|
|||
|
||||
def test_getting_list_parameter():
|
||||
ctx = rest.Context(
|
||||
env={},
|
||||
method=None,
|
||||
url=None,
|
||||
params={'key': 'value', 'list': ['1', '2', '3']})
|
||||
|
@ -43,6 +45,7 @@ def test_getting_list_parameter():
|
|||
|
||||
def test_getting_string_parameter():
|
||||
ctx = rest.Context(
|
||||
env={},
|
||||
method=None,
|
||||
url=None,
|
||||
params={'key': 'value', 'list': ['1', '2', '3']})
|
||||
|
@ -55,6 +58,7 @@ def test_getting_string_parameter():
|
|||
|
||||
def test_getting_int_parameter():
|
||||
ctx = rest.Context(
|
||||
env={},
|
||||
method=None,
|
||||
url=None,
|
||||
params={'key': '50', 'err': 'invalid', 'list': [1, 2, 3]})
|
||||
|
@ -76,7 +80,8 @@ def test_getting_int_parameter():
|
|||
|
||||
def test_getting_bool_parameter():
|
||||
def test(value):
|
||||
ctx = rest.Context(method=None, url=None, params={'key': value})
|
||||
ctx = rest.Context(
|
||||
env={}, method=None, url=None, params={'key': value})
|
||||
return ctx.get_param_as_bool('key')
|
||||
|
||||
assert test('1') is True
|
||||
|
@ -104,7 +109,7 @@ def test_getting_bool_parameter():
|
|||
with pytest.raises(errors.ValidationError):
|
||||
test(['1', '2'])
|
||||
|
||||
ctx = rest.Context(method=None, url=None)
|
||||
ctx = rest.Context(env={}, method=None, url=None)
|
||||
with pytest.raises(errors.ValidationError):
|
||||
ctx.get_param_as_bool('non-existing')
|
||||
assert ctx.get_param_as_bool('non-existing', default=True) is True
|
||||
|
|
Loading…
Reference in a new issue