client+server: allow for wildcards in Accept headers

This commit is contained in:
Shyam Sunder 2021-09-25 12:20:28 -04:00
parent 636498ad38
commit fff0999e6a
15 changed files with 156 additions and 109 deletions

View file

@ -1,32 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<meta charset='utf-8'/>
<meta name='viewport' content='width=device-width, initial-scale=1, maximum-scale=1'/>
<meta name='theme-color' content='#24aadd'/>
<meta name='apple-mobile-web-app-capable' content='yes'/>
<meta name='apple-mobile-web-app-status-bar-style' content='black'/>
<meta name='msapplication-TileColor' content='#ffffff'/>
<meta name="msapplication-TileImage" content="/img/mstile-150x150.png">
<title>Loading...</title>
<!-- Base HTML Placeholder -->
<link href='css/app.min.css' rel='stylesheet' type='text/css'/>
<link href='css/vendor.min.css' rel='stylesheet' type='text/css'/>
<link rel='shortcut icon' type='image/png' href='img/favicon.png'/>
<link rel='apple-touch-icon' sizes='180x180' href='img/apple-touch-icon.png'/>
<link rel='apple-touch-startup-image' href='img/apple-touch-startup-image-640x1136.png' media='(device-width: 320px) and (device-height: 568px) and (-webkit-device-pixel-ratio: 2) and (orientation: portrait)'/>
<link rel='apple-touch-startup-image' href='img/apple-touch-startup-image-750x1294.png' media='(device-width: 375px) and (device-height: 667px) and (-webkit-device-pixel-ratio: 2) and (orientation: portrait)'/>
<link rel='apple-touch-startup-image' href='img/apple-touch-startup-image-1242x2148.png' media='(device-width: 414px) and (device-height: 736px) and (-webkit-device-pixel-ratio: 3) and (orientation: portrait)'/>
<link rel='apple-touch-startup-image' href='img/apple-touch-startup-image-1125x2436.png' media='(device-width: 375px) and (device-height: 812px) and (-webkit-device-pixel-ratio: 3) and (orientation: portrait)'/>
<link rel='apple-touch-startup-image' href='img/apple-touch-startup-image-1536x2048.png' media='(min-device-width: 768px) and (max-device-width: 1024px) and (-webkit-min-device-pixel-ratio: 2) and (orientation: portrait)'/>
<link rel='apple-touch-startup-image' href='img/apple-touch-startup-image-1668x2224.png' media='(min-device-width: 834px) and (max-device-width: 834px) and (-webkit-min-device-pixel-ratio: 2) and (orientation: portrait)'/>
<link rel='apple-touch-startup-image' href='img/apple-touch-startup-image-2048x2732.png' media='(min-device-width: 1024px) and (max-device-width: 1024px) and (-webkit-min-device-pixel-ratio: 2) and (orientation: portrait)'/>
<link rel='manifest' href='manifest.json'/>
</head>
<body>
<div id='top-navigation-holder'></div>
<div id='content-holder'></div>
<script type='text/javascript' src='js/vendor.min.js'></script>
<script type='text/javascript' src='js/app.min.js'></script>
</body>
</html>

View file

@ -53,14 +53,6 @@ http {
gzip_proxied expired no-cache no-store private auth; gzip_proxied expired no-cache no-store private auth;
gzip_types text/plain application/json; gzip_types text/plain application/json;
if ($http_x_forwarded_host = '') {
set $http_x_forwarded_host $host;
}
if ($http_x_forwarded_proto = '') {
set $http_x_forwarded_proto 'http';
}
if ($request_uri ~* "/api/(.*)") { if ($request_uri ~* "/api/(.*)") {
proxy_pass http://backend/$1; proxy_pass http://backend/$1;
} }
@ -111,16 +103,10 @@ http {
return 406 "API requests should be sent to the /api prefix"; return 406 "API requests should be sent to the /api prefix";
} }
if ($http_x_forwarded_host = '') { if ($request_uri ~* "/(.*)") {
set $http_x_forwarded_host $host; proxy_pass http://backend/html/$1;
} }
if ($http_x_forwarded_proto = '') {
set $http_x_forwarded_proto 'http';
}
proxy_pass http://backend;
error_page 500 502 503 504 @badproxy; error_page 500 502 503 504 @badproxy;
} }

View file

@ -89,8 +89,12 @@ user@host:szuru$ docker-compose down
Some users may wish to access the service at a different base URI, such Some users may wish to access the service at a different base URI, such
as `http://example.com/szuru/`, commonly when sharing multiple HTTP as `http://example.com/szuru/`, commonly when sharing multiple HTTP
services on one domain using a reverse proxy. In this case, simply set services on one domain using a reverse proxy. This can be configured in
`BASE_URL="/szuru/"` in your `.env` file. either of the following ways:
- Set the 'domain' value in `config.yaml` to include the prefix, i.e.:
`domain: "http://example.com/szuru" # omit trailing slash`
- Configure the reverse proxy to pass the `X-Forwarded-Prefix` header.
Note that this will require a reverse proxy to function. You should set Note that this will require a reverse proxy to function. You should set
your reverse proxy to proxy `http(s)://example.com/szuru` to your reverse proxy to proxy `http(s)://example.com/szuru` to
@ -102,14 +106,16 @@ user@host:szuru$ docker-compose down
proxy_http_version 1.1; proxy_http_version 1.1;
proxy_pass http://<internal IP or hostname of frontend container>/; proxy_pass http://<internal IP or hostname of frontend container>/;
proxy_set_header Host $http_host; proxy_set_header Host $http_host;
proxy_set_header Upgrade $http_upgrade; proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade"; proxy_set_header Connection "upgrade";
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Prefix /szuru;
proxy_set_header X-Scheme $scheme;
proxy_set_header X-Real-IP $remote_addr; // optional...
proxy_set_header X-Forwarded-Proto $scheme; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Script-Name /szuru; proxy_set_header X-Scheme $scheme;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-Proto $scheme;
} }
``` ```

View file

@ -10,10 +10,6 @@ BUILD_INFO=latest
# otherwise the port specified here will be publicly accessible # otherwise the port specified here will be publicly accessible
PORT=8080 PORT=8080
# URL base to run szurubooru under
# See "Additional Features" section in INSTALL.md
BASE_URL=/
# Directory to store image data # Directory to store image data
MOUNT_DATA=/var/local/szurubooru/data MOUNT_DATA=/var/local/szurubooru/data

View file

@ -31,7 +31,6 @@ services:
- server - server
environment: environment:
BACKEND_HOST: server BACKEND_HOST: server
BASE_URL:
volumes: volumes:
- "${MOUNT_DATA}:/data:ro" - "${MOUNT_DATA}:/data:ro"
ports: ports:

View file

@ -66,7 +66,7 @@ def get_info(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
return ret return ret
@rest.routes.get(r"/manifest\.json") @rest.routes.get(r"/manifest(?:\.json)?")
def generate_manifest( def generate_manifest(
ctx: rest.Context, _params: Dict[str, str] = {} ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response: ) -> rest.Response:

View file

@ -123,7 +123,7 @@ def _get_html_template(
"script", type="text/javascript", src="js/app.min.js" "script", type="text/javascript", src="js/app.min.js"
): ):
pass pass
return doc return doc.getvalue()
def _get_post_id(params: Dict[str, str]) -> int: def _get_post_id(params: Dict[str, str]) -> int:
@ -139,7 +139,7 @@ def _get_post(params: Dict[str, str]) -> model.Post:
return posts.get_post_by_id(_get_post_id(params)) return posts.get_post_by_id(_get_post_id(params))
@rest.routes.get("/post/(?P<post_id>[^/]+)/?", accept="text/html") @rest.routes.get("/html/post/(?P<post_id>[^/]+)/?", accept="text/html")
def get_post_html( def get_post_html(
ctx: rest.Context, params: Dict[str, str] = {} ctx: rest.Context, params: Dict[str, str] = {}
) -> rest.Response: ) -> rest.Response:
@ -166,7 +166,7 @@ def get_post_html(
) )
@rest.routes.get("/.*", accept="text/html") @rest.routes.get("/html/.*", accept="text/html")
def default_route( def default_route(
ctx: rest.Context, _params: Dict[str, str] = {} ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response: ) -> rest.Response:

View file

@ -25,15 +25,7 @@ def start_password_reset(
) )
token = auth.generate_authentication_token(user) token = auth.generate_authentication_token(user)
if config.config["domain"]: url = f"{ctx.url_prefix}/password-reset/{user.name}:{token}"
url = config.config["domain"]
elif "HTTP_ORIGIN" in ctx.env:
url = ctx.env["HTTP_ORIGIN"].rstrip("/")
elif "HTTP_REFERER" in ctx.env:
url = ctx.env["HTTP_REFERER"].rstrip("/")
else:
url = ""
url += "/password-reset/%s:%s" % (user.name, token)
mailer.send_mail( mailer.send_mail(
config.config["smtp"]["from"], config.config["smtp"]["from"],

View file

@ -21,8 +21,8 @@ def _json_serializer(obj: Any) -> str:
def _serialize_response_body(obj: Any, accept: str) -> str: def _serialize_response_body(obj: Any, accept: str) -> str:
if accept == "application/json": if accept == "application/json":
return json.dumps(obj, default=_json_serializer, indent=2) return json.dumps(obj, default=_json_serializer, indent=2)
if accept == "text/html": if "text/" in accept:
return obj.getvalue() return obj
raise ValueError("Unhandled response type %s" % accept) raise ValueError("Unhandled response type %s" % accept)
@ -40,18 +40,6 @@ def _create_context(env: Dict[str, Any]) -> context.Context:
path = "/" + env["PATH_INFO"].lstrip("/") path = "/" + env["PATH_INFO"].lstrip("/")
path = path.encode("latin-1").decode("utf-8") # PEP-3333 path = path.encode("latin-1").decode("utf-8") # PEP-3333
headers = _get_headers(env) headers = _get_headers(env)
_raw_accept = headers.get("Accept", "text/html")
if "application/json" in _raw_accept:
accept = "application/json"
elif "text/html" in _raw_accept:
accept = "text/html"
else:
raise errors.HttpNotAcceptable(
"ValidationError",
"This API only supports the following response types: "
"application/json, text/html",
)
if config.config["domain"]: if config.config["domain"]:
url_prefix = config.config["domain"].rstrip("/") url_prefix = config.config["domain"].rstrip("/")
@ -88,7 +76,13 @@ def _create_context(env: Dict[str, Any]) -> context.Context:
) )
return context.Context( return context.Context(
env, method, path, headers, accept, url_prefix, params, files env=env,
method=method,
url=path,
headers=headers,
url_prefix=url_prefix,
params=params,
files=files,
) )
@ -97,15 +91,32 @@ def application(
) -> Tuple[bytes]: ) -> Tuple[bytes]:
try: try:
ctx = _create_context(env) ctx = _create_context(env)
for url, allowed_methods in routes.routes[ctx.accept].items(): for url, allowed_methods in routes.routes.items():
match = re.fullmatch(url, ctx.url) match = re.fullmatch(url, ctx.url)
if match: if match:
if ctx.method not in allowed_methods: if ctx.method not in allowed_methods:
raise errors.HttpMethodNotAllowed( raise errors.HttpMethodNotAllowed(
"ValidationError", "ValidationError",
"Allowed methods: %r" % allowed_methods, "Allowed methods: %s"
% ", ".join(allowed_methods.keys()),
) )
handler = allowed_methods[ctx.method] handler, allowed_accept = allowed_methods[ctx.method]
if not any(
map(
lambda a: a in ctx.get_header("Accept"),
[
allowed_accept,
allowed_accept.split("/")[0] + "/*",
"*/*",
],
)
):
raise errors.HttpNotAcceptable(
"ValidationError",
"This route only supports %s responses."
% allowed_accept,
)
ctx.accept = allowed_accept
break break
else: else:
raise errors.HttpNotFound( raise errors.HttpNotFound(

View file

@ -1,20 +1,18 @@
from collections import defaultdict from collections import defaultdict
from typing import Callable, Dict from typing import Callable, Dict, Tuple
from szurubooru.rest.context import Context, Response from szurubooru.rest.context import Context, Response
RouteHandler = Callable[[Context, Dict[str, str]], Response] RouteHandler = Callable[[Context, Dict[str, str]], Response]
routes = { # type: Dict[Dict[str, Dict[str, RouteHandler]]] routes = defaultdict(dict)
"application/json": defaultdict(dict), # type: Dict[str, Dict[str, Tuple[RouteHandler, str]]]
"text/html": defaultdict(dict),
}
def get( def get(
url: str, accept: str = "application/json" url: str, accept: str = "application/json"
) -> Callable[[RouteHandler], RouteHandler]: ) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler: RouteHandler) -> RouteHandler: def wrapper(handler: RouteHandler) -> RouteHandler:
routes[accept][url]["GET"] = handler routes[url]["GET"] = (handler, accept)
return handler return handler
return wrapper return wrapper
@ -24,7 +22,7 @@ def put(
url: str, accept: str = "application/json" url: str, accept: str = "application/json"
) -> Callable[[RouteHandler], RouteHandler]: ) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler: RouteHandler) -> RouteHandler: def wrapper(handler: RouteHandler) -> RouteHandler:
routes[accept][url]["PUT"] = handler routes[url]["PUT"] = (handler, accept)
return handler return handler
return wrapper return wrapper
@ -34,7 +32,7 @@ def post(
url: str, accept: str = "application/json" url: str, accept: str = "application/json"
) -> Callable[[RouteHandler], RouteHandler]: ) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler: RouteHandler) -> RouteHandler: def wrapper(handler: RouteHandler) -> RouteHandler:
routes[accept][url]["POST"] = handler routes[url]["POST"] = (handler, accept)
return handler return handler
return wrapper return wrapper
@ -44,7 +42,7 @@ def delete(
url: str, accept: str = "application/json" url: str, accept: str = "application/json"
) -> Callable[[RouteHandler], RouteHandler]: ) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler: RouteHandler) -> RouteHandler: def wrapper(handler: RouteHandler) -> RouteHandler:
routes[accept][url]["DELETE"] = handler routes[url]["DELETE"] = (handler, accept)
return handler return handler
return wrapper return wrapper

View file

@ -94,3 +94,29 @@ def test_info_api(
"serverTime": datetime(2016, 1, 3, 13, 1), "serverTime": datetime(2016, 1, 3, 13, 1),
"config": expected_config_key, "config": expected_config_key,
} }
def test_manifest(config_injector, context_factory):
config_injector({"name": "test installation"})
ctx = context_factory()
ctx.url_prefix = "/someprefix"
expected_manifest = {
"name": "test installation",
"icons": [
{
"src": "/someprefix/img/android-chrome-192x192.png",
"type": "image/png",
"sizes": "192x192",
},
{
"src": "/someprefix/img/android-chrome-512x512.png",
"type": "image/png",
"sizes": "512x512",
},
],
"start_url": "/someprefix/",
"theme_color": "#24aadd",
"background_color": "#ffffff",
"display": "standalone",
}
assert api.info_api.generate_manifest(ctx) == expected_manifest

View file

@ -0,0 +1,40 @@
from unittest.mock import patch
import pytest
import yattag
from szurubooru import api, db
from szurubooru.func import auth, posts
def _make_meta_tag(name, content):
doc = yattag.Doc()
doc.stag("meta", name=name, content=content)
return doc.getvalue()
@pytest.mark.parametrize("view_priv", [True, False])
def test_get_post_html(
config_injector, context_factory, post_factory, view_priv
):
config_injector(
{
"name": "test installation",
"data_url": "data/",
}
)
ctx = context_factory()
ctx.url_prefix = "/someprefix"
db.session.add(post_factory(id=1))
db.session.flush()
with patch("szurubooru.func.auth.has_privilege"), patch(
"szurubooru.func.posts.get_post_content_url"
):
auth.has_privilege.return_value = view_priv
posts.get_post_content_url.return_value = "/content-url"
ret = api.opengraph_api.get_post_html(ctx, {"post_id": 1})
assert _make_meta_tag("og:site_name", "test installation") in ret
assert _make_meta_tag("og:title", "Post 1 - test installation") in ret
if view_priv:
assert _make_meta_tag("og:image", "/content-url") in ret

View file

@ -27,11 +27,13 @@ def test_reset_sending_email(context_factory, user_factory):
) )
) )
db.session.flush() db.session.flush()
ctx = context_factory()
ctx.url_prefix = "http://example.com"
for initiating_user in ["u1", "user@example.com"]: for initiating_user in ["u1", "user@example.com"]:
with patch("szurubooru.func.mailer.send_mail"): with patch("szurubooru.func.mailer.send_mail"):
assert ( assert (
api.password_reset_api.start_password_reset( api.password_reset_api.start_password_reset(
context_factory(), {"user_name": initiating_user} ctx, {"user_name": initiating_user}
) )
== {} == {}
) )

View file

@ -67,12 +67,13 @@ def nontransacted_session(query_logger, postgresql_db):
@pytest.fixture @pytest.fixture
def context_factory(session): def context_factory(session):
def factory(params=None, files=None, user=None, headers=None): def factory(params=None, files=None, user=None, headers=None, accept=None):
ctx = rest.Context( ctx = rest.Context(
env={"HTTP_ORIGIN": "http://example.com"}, env={"HTTP_ORIGIN": "http://example.com"},
method=None, method=None,
url=None, url=None,
headers=headers or {}, headers=headers or {},
accept=accept or None,
params=params or {}, params=params or {},
files=files or {}, files=files or {},
) )

View file

@ -9,11 +9,26 @@ from szurubooru.rest import errors
def test_process_request_no_header(context_factory): def test_process_request_no_header(context_factory):
ctx = context_factory() ctx = context_factory(accept="application/json")
authenticator.process_request(ctx) authenticator.process_request(ctx)
assert ctx.user.name is None assert ctx.user.name is None
def test_process_request_non_rest(context_factory, user_factory):
user = user_factory()
ctx = context_factory(
headers={"Authorization": "Basic dGVzdFVzZXI6dGVzdFBhc3N3b3Jk"},
accept="text/html",
)
with patch("szurubooru.func.auth.is_valid_password"), patch(
"szurubooru.func.users.get_user_by_name"
):
users.get_user_by_name.return_value = user
auth.is_valid_password.return_value = True
authenticator.process_request(ctx)
assert ctx.user.name is None
def test_process_request_bump_login(context_factory, user_factory): def test_process_request_bump_login(context_factory, user_factory):
user = user_factory() user = user_factory()
db.session.add(user) db.session.add(user)
@ -21,6 +36,7 @@ def test_process_request_bump_login(context_factory, user_factory):
ctx = context_factory( ctx = context_factory(
headers={"Authorization": "Basic dGVzdFVzZXI6dGVzdFRva2Vu"}, headers={"Authorization": "Basic dGVzdFVzZXI6dGVzdFRva2Vu"},
params={"bump-login": "true"}, params={"bump-login": "true"},
accept="application/json",
) )
with patch("szurubooru.func.auth.is_valid_password"), patch( with patch("szurubooru.func.auth.is_valid_password"), patch(
"szurubooru.func.users.get_user_by_name" "szurubooru.func.users.get_user_by_name"
@ -40,6 +56,7 @@ def test_process_request_bump_login_with_token(
ctx = context_factory( ctx = context_factory(
headers={"Authorization": "Token dGVzdFVzZXI6dGVzdFRva2Vu"}, headers={"Authorization": "Token dGVzdFVzZXI6dGVzdFRva2Vu"},
params={"bump-login": "true"}, params={"bump-login": "true"},
accept="application/json",
) )
with patch("szurubooru.func.auth.is_valid_token"), patch( with patch("szurubooru.func.auth.is_valid_token"), patch(
"szurubooru.func.users.get_user_by_name" "szurubooru.func.users.get_user_by_name"
@ -55,7 +72,8 @@ def test_process_request_bump_login_with_token(
def test_process_request_basic_auth_valid(context_factory, user_factory): def test_process_request_basic_auth_valid(context_factory, user_factory):
user = user_factory() user = user_factory()
ctx = context_factory( ctx = context_factory(
headers={"Authorization": "Basic dGVzdFVzZXI6dGVzdFBhc3N3b3Jk"} headers={"Authorization": "Basic dGVzdFVzZXI6dGVzdFBhc3N3b3Jk"},
accept="application/json",
) )
with patch("szurubooru.func.auth.is_valid_password"), patch( with patch("szurubooru.func.auth.is_valid_password"), patch(
"szurubooru.func.users.get_user_by_name" "szurubooru.func.users.get_user_by_name"
@ -69,7 +87,8 @@ def test_process_request_basic_auth_valid(context_factory, user_factory):
def test_process_request_token_auth_valid(context_factory, user_token_factory): def test_process_request_token_auth_valid(context_factory, user_token_factory):
user_token = user_token_factory() user_token = user_token_factory()
ctx = context_factory( ctx = context_factory(
headers={"Authorization": "Token dGVzdFVzZXI6dGVzdFRva2Vu"} headers={"Authorization": "Token dGVzdFVzZXI6dGVzdFRva2Vu"},
accept="application/json",
) )
with patch("szurubooru.func.auth.is_valid_token"), patch( with patch("szurubooru.func.auth.is_valid_token"), patch(
"szurubooru.func.users.get_user_by_name" "szurubooru.func.users.get_user_by_name"
@ -82,6 +101,9 @@ def test_process_request_token_auth_valid(context_factory, user_token_factory):
def test_process_request_bad_header(context_factory): def test_process_request_bad_header(context_factory):
ctx = context_factory(headers={"Authorization": "Secret SuperSecretValue"}) ctx = context_factory(
headers={"Authorization": "Secret SuperSecretValue"},
accept="application/json",
)
with pytest.raises(errors.HttpBadRequest): with pytest.raises(errors.HttpBadRequest):
authenticator.process_request(ctx) authenticator.process_request(ctx)