diff --git a/server/szurubooru/api/password_reset_api.py b/server/szurubooru/api/password_reset_api.py index 83afeaea..887d2f0d 100644 --- a/server/szurubooru/api/password_reset_api.py +++ b/server/szurubooru/api/password_reset_api.py @@ -13,7 +13,7 @@ MAIL_BODY = ( @rest.routes.get('/password-reset/(?P[^/]+)/?') def start_password_reset( - _ctx: rest.Context, params: Dict[str, str]) -> rest.Response: + ctx: rest.Context, params: Dict[str, str]) -> rest.Response: user_name = params['user_name'] user = users.get_user_by_name_or_email(user_name) if not user.email: @@ -21,12 +21,19 @@ def start_password_reset( 'User %r hasn\'t supplied email. Cannot reset password.' % ( user_name)) token = auth.generate_authentication_token(user) - url = '/password-reset/%s:%s' % (user.name, token) + + if 'HTTP_ORIGIN' in ctx.env: + url = ctx.env['HTTP_ORIGIN'].rstrip('/') + else: + url = '' + url += '/password-reset/%s:%s' % (user.name, token) + mailer.send_mail( 'noreply@%s' % config.config['name'], user.email, MAIL_SUBJECT.format(name=config.config['name']), MAIL_BODY.format(name=config.config['name'], url=url)) + return {} diff --git a/server/szurubooru/rest/app.py b/server/szurubooru/rest/app.py index ea2a2877..8c9efba1 100644 --- a/server/szurubooru/rest/app.py +++ b/server/szurubooru/rest/app.py @@ -63,7 +63,7 @@ def _create_context(env: Dict[str, Any]) -> context.Context: 'Could not decode the request body. The JSON ' 'was incorrect or was not encoded as UTF-8.') - return context.Context(method, path, headers, params, files) + return context.Context(env, method, path, headers, params, files) def application( diff --git a/server/szurubooru/rest/context.py b/server/szurubooru/rest/context.py index 62367545..2aad101a 100644 --- a/server/szurubooru/rest/context.py +++ b/server/szurubooru/rest/context.py @@ -11,11 +11,13 @@ Response = Optional[Dict[str, Any]] class Context: def __init__( self, + env: Dict[str, Any], method: str, url: str, headers: Dict[str, str] = None, params: Request = None, files: Dict[str, bytes] = None) -> None: + self.env = env self.method = method self.url = url self._headers = headers or {} diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index f6eeee18..27a107aa 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -95,6 +95,7 @@ def session(query_logger): # pylint: disable=unused-argument def context_factory(session): def factory(params=None, files=None, user=None, headers=None): ctx = rest.Context( + env={'HTTP_ORIGIN': 'http://example.com'}, method=None, url=None, headers=headers or {}, diff --git a/server/szurubooru/tests/rest/test_context.py b/server/szurubooru/tests/rest/test_context.py index e112ebbe..681d9d70 100644 --- a/server/szurubooru/tests/rest/test_context.py +++ b/server/szurubooru/tests/rest/test_context.py @@ -6,13 +6,14 @@ from szurubooru.func import net def test_has_param(): - ctx = rest.Context(method=None, url=None, params={'key': 'value'}) + ctx = rest.Context(env={}, method=None, url=None, params={'key': 'value'}) assert ctx.has_param('key') assert not ctx.has_param('non-existing') def test_get_file(): - ctx = rest.Context(method=None, url=None, files={'key': b'content'}) + ctx = rest.Context( + env={}, method=None, url=None, files={'key': b'content'}) assert ctx.get_file('key') == b'content' with pytest.raises(errors.ValidationError): ctx.get_file('non-existing') @@ -22,7 +23,7 @@ def test_get_file_from_url(): with unittest.mock.patch('szurubooru.func.net.download'): net.download.return_value = b'content' ctx = rest.Context( - method=None, url=None, params={'keyUrl': 'example.com'}) + env={}, method=None, url=None, params={'keyUrl': 'example.com'}) assert ctx.get_file('key') == b'content' net.download.assert_called_once_with('example.com') with pytest.raises(errors.ValidationError): @@ -31,6 +32,7 @@ def test_get_file_from_url(): def test_getting_list_parameter(): ctx = rest.Context( + env={}, method=None, url=None, params={'key': 'value', 'list': ['1', '2', '3']}) @@ -43,6 +45,7 @@ def test_getting_list_parameter(): def test_getting_string_parameter(): ctx = rest.Context( + env={}, method=None, url=None, params={'key': 'value', 'list': ['1', '2', '3']}) @@ -55,6 +58,7 @@ def test_getting_string_parameter(): def test_getting_int_parameter(): ctx = rest.Context( + env={}, method=None, url=None, params={'key': '50', 'err': 'invalid', 'list': [1, 2, 3]}) @@ -76,7 +80,8 @@ def test_getting_int_parameter(): def test_getting_bool_parameter(): def test(value): - ctx = rest.Context(method=None, url=None, params={'key': value}) + ctx = rest.Context( + env={}, method=None, url=None, params={'key': value}) return ctx.get_param_as_bool('key') assert test('1') is True @@ -104,7 +109,7 @@ def test_getting_bool_parameter(): with pytest.raises(errors.ValidationError): test(['1', '2']) - ctx = rest.Context(method=None, url=None) + ctx = rest.Context(env={}, method=None, url=None) with pytest.raises(errors.ValidationError): ctx.get_param_as_bool('non-existing') assert ctx.get_param_as_bool('non-existing', default=True) is True