This commit is contained in:
Shyam Sunder 2022-01-17 16:14:44 -07:00 committed by GitHub
commit c448669fe4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
36 changed files with 780 additions and 255 deletions

View file

@ -8,7 +8,7 @@ COPY . ./
ARG BUILD_INFO="docker-latest" ARG BUILD_INFO="docker-latest"
ARG CLIENT_BUILD_ARGS="" ARG CLIENT_BUILD_ARGS=""
RUN BASE_URL="__BASEURL__" node build.js --gzip ${CLIENT_BUILD_ARGS} RUN node build.js --gzip ${CLIENT_BUILD_ARGS}
FROM --platform=$BUILDPLATFORM scratch as approot FROM --platform=$BUILDPLATFORM scratch as approot

View file

@ -30,26 +30,6 @@ const external_js = [
'underscore', 'underscore',
]; ];
const app_manifest = {
name: 'szurubooru',
icons: [
{
src: baseUrl() + 'img/android-chrome-192x192.png',
type: 'image/png',
sizes: '192x192'
},
{
src: baseUrl() + 'img/android-chrome-512x512.png',
type: 'image/png',
sizes: '512x512'
}
],
start_url: baseUrl(),
theme_color: '#24aadd',
background_color: '#ffffff',
display: 'standalone'
}
// ------------------------------------------------- // -------------------------------------------------
const fs = require('fs'); const fs = require('fs');
@ -72,10 +52,6 @@ function gzipFile(file) {
execSync('gzip -6 -k ' + file); execSync('gzip -6 -k ' + file);
} }
function baseUrl() {
return process.env.BASE_URL ? process.env.BASE_URL : '/';
}
// ------------------------------------------------- // -------------------------------------------------
function bundleHtml() { function bundleHtml() {
@ -90,10 +66,6 @@ function bundleHtml() {
}).trim(); }).trim();
} }
const baseHtml = readTextFile('./html/index.htm')
.replace('<!-- Base HTML Placeholder -->', `<base href="${baseUrl()}"/>`);
fs.writeFileSync('./public/index.htm', minifyHtml(baseHtml));
let compiledTemplateJs = [ let compiledTemplateJs = [
`'use strict';`, `'use strict';`,
`let _ = require('underscore');`, `let _ = require('underscore');`,
@ -266,9 +238,6 @@ function bundleBinaryAssets() {
function bundleWebAppFiles() { function bundleWebAppFiles() {
const Jimp = require('jimp'); const Jimp = require('jimp');
fs.writeFileSync('./public/manifest.json', JSON.stringify(app_manifest));
console.info('Generated app manifest');
Promise.all(webapp_icons.map(icon => { Promise.all(webapp_icons.map(icon => {
return Jimp.read('./img/app.png') return Jimp.read('./img/app.png')
.then(file => { .then(file => {

View file

@ -1,11 +1,12 @@
#!/usr/bin/dumb-init /bin/sh #!/usr/bin/dumb-init /bin/sh
# Create cache directory
mkdir -p /tmp/nginx-cache
chmod a+rwx /tmp/nginx-cache
# Integrate environment variables # Integrate environment variables
sed -i "s|__BACKEND__|${BACKEND_HOST}|" \ sed -i "s|__BACKEND__|${BACKEND_HOST}|" \
/etc/nginx/nginx.conf /etc/nginx/nginx.conf
sed -i "s|__BASEURL__|${BASE_URL:-/}|g" \
/var/www/index.htm \
/var/www/manifest.json
# Start server # Start server
exec nginx exec nginx

View file

@ -1,32 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<meta charset='utf-8'/>
<meta name='viewport' content='width=device-width, initial-scale=1, maximum-scale=1'/>
<meta name='theme-color' content='#24aadd'/>
<meta name='apple-mobile-web-app-capable' content='yes'/>
<meta name='apple-mobile-web-app-status-bar-style' content='black'/>
<meta name='msapplication-TileColor' content='#ffffff'/>
<meta name="msapplication-TileImage" content="/img/mstile-150x150.png">
<title>Loading...</title>
<!-- Base HTML Placeholder -->
<link href='css/app.min.css' rel='stylesheet' type='text/css'/>
<link href='css/vendor.min.css' rel='stylesheet' type='text/css'/>
<link rel='shortcut icon' type='image/png' href='img/favicon.png'/>
<link rel='apple-touch-icon' sizes='180x180' href='img/apple-touch-icon.png'/>
<link rel='apple-touch-startup-image' href='img/apple-touch-startup-image-640x1136.png' media='(device-width: 320px) and (device-height: 568px) and (-webkit-device-pixel-ratio: 2) and (orientation: portrait)'/>
<link rel='apple-touch-startup-image' href='img/apple-touch-startup-image-750x1294.png' media='(device-width: 375px) and (device-height: 667px) and (-webkit-device-pixel-ratio: 2) and (orientation: portrait)'/>
<link rel='apple-touch-startup-image' href='img/apple-touch-startup-image-1242x2148.png' media='(device-width: 414px) and (device-height: 736px) and (-webkit-device-pixel-ratio: 3) and (orientation: portrait)'/>
<link rel='apple-touch-startup-image' href='img/apple-touch-startup-image-1125x2436.png' media='(device-width: 375px) and (device-height: 812px) and (-webkit-device-pixel-ratio: 3) and (orientation: portrait)'/>
<link rel='apple-touch-startup-image' href='img/apple-touch-startup-image-1536x2048.png' media='(min-device-width: 768px) and (max-device-width: 1024px) and (-webkit-min-device-pixel-ratio: 2) and (orientation: portrait)'/>
<link rel='apple-touch-startup-image' href='img/apple-touch-startup-image-1668x2224.png' media='(min-device-width: 834px) and (max-device-width: 834px) and (-webkit-min-device-pixel-ratio: 2) and (orientation: portrait)'/>
<link rel='apple-touch-startup-image' href='img/apple-touch-startup-image-2048x2732.png' media='(min-device-width: 1024px) and (max-device-width: 1024px) and (-webkit-min-device-pixel-ratio: 2) and (orientation: portrait)'/>
<link rel='manifest' href='manifest.json'/>
</head>
<body>
<div id='top-navigation-holder'></div>
<div id='content-holder'></div>
<script type='text/javascript' src='js/vendor.min.js'></script>
<script type='text/javascript' src='js/app.min.js'></script>
</body>
</html>

View file

@ -19,6 +19,13 @@ http {
server_tokens off; server_tokens off;
keepalive_timeout 65; keepalive_timeout 65;
proxy_cache_path /tmp/nginx-cache
levels=1:2
keys_zone=spa_cache:4m
max_size=50m
inactive=60m
use_temp_path=off;
upstream backend { upstream backend {
server __BACKEND__:6666; server __BACKEND__:6666;
} }
@ -69,9 +76,8 @@ http {
error_page 404 @notfound; error_page 404 @notfound;
} }
location / { location ~ ^/(js|css|img|fonts)/.*$ {
root /var/www; root /var/www;
try_files $uri /index.htm;
sendfile on; sendfile on;
tcp_nopush on; tcp_nopush on;
@ -79,6 +85,36 @@ http {
gzip_static on; gzip_static on;
gzip_proxied expired no-cache no-store private auth; gzip_proxied expired no-cache no-store private auth;
error_page 404 @notfound;
}
location / {
tcp_nodelay on;
# remove unneeded auth headers to improve caching
proxy_set_header Authorization "";
proxy_cache spa_cache;
proxy_cache_use_stale error timeout updating http_500 http_502 http_503 http_504;
proxy_cache_background_update on;
proxy_cache_lock on;
gzip on;
gzip_comp_level 3;
gzip_min_length 20;
gzip_proxied any;
gzip_types text/plain application/json;
if ( $http_accept ~ "application/json" ) {
return 406 "API requests should be sent to the /api prefix";
}
if ($request_uri ~* "/(.*)") {
proxy_pass http://backend/html/$1;
}
error_page 500 502 503 504 @badproxy;
} }
location @unauthorized { location @unauthorized {

View file

@ -80,8 +80,8 @@ user@host:szuru$ docker-compose down
If you want to host your website on, (`http://example.com/`) but want If you want to host your website on, (`http://example.com/`) but want
to serve the images on a different domain, (`http://static.example.com/`) to serve the images on a different domain, (`http://static.example.com/`)
then you can run the backend container with an additional environment then you can configure the `data_url` variable in your `config.yaml`
variable `DATA_URL=http://static.example.com/`. Make sure that this (ex: `data_url: http://static.example.com/`). Make sure that this
additional host has access contents to the `/data` volume mounted in the additional host has access contents to the `/data` volume mounted in the
backend. backend.
@ -89,8 +89,9 @@ user@host:szuru$ docker-compose down
Some users may wish to access the service at a different base URI, such Some users may wish to access the service at a different base URI, such
as `http://example.com/szuru/`, commonly when sharing multiple HTTP as `http://example.com/szuru/`, commonly when sharing multiple HTTP
services on one domain using a reverse proxy. In this case, simply set services on one domain using a reverse proxy. For szurubooru to handle
`BASE_URL="/szuru/"` in your `.env` file. links properly, you must configure the reverse proxy to pass the new
URL prefix (in this case `/szuru`) in the `X-Forwarded-Prefix` header.
Note that this will require a reverse proxy to function. You should set Note that this will require a reverse proxy to function. You should set
your reverse proxy to proxy `http(s)://example.com/szuru` to your reverse proxy to proxy `http(s)://example.com/szuru` to
@ -102,14 +103,16 @@ user@host:szuru$ docker-compose down
proxy_http_version 1.1; proxy_http_version 1.1;
proxy_pass http://<internal IP or hostname of frontend container>/; proxy_pass http://<internal IP or hostname of frontend container>/;
proxy_set_header Host $http_host; proxy_set_header Host $http_host;
proxy_set_header Upgrade $http_upgrade; proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade"; proxy_set_header Connection "upgrade";
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; proxy_set_header X-Forwarded-Prefix /szuru;
proxy_set_header X-Scheme $scheme;
proxy_set_header X-Real-IP $remote_addr; # optional...
proxy_set_header X-Forwarded-Proto $scheme; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Script-Name /szuru; proxy_set_header X-Scheme $scheme;
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-Proto $scheme;
} }
``` ```

View file

@ -10,10 +10,6 @@ BUILD_INFO=latest
# otherwise the port specified here will be publicly accessible # otherwise the port specified here will be publicly accessible
PORT=8080 PORT=8080
# URL base to run szurubooru under
# See "Additional Features" section in INSTALL.md
BASE_URL=/
# Directory to store image data # Directory to store image data
MOUNT_DATA=/var/local/szurubooru/data MOUNT_DATA=/var/local/szurubooru/data

View file

@ -31,7 +31,6 @@ services:
- server - server
environment: environment:
BACKEND_HOST: server BACKEND_HOST: server
BASE_URL:
volumes: volumes:
- "${MOUNT_DATA}:/data:ro" - "${MOUNT_DATA}:/data:ro"
ports: ports:

View file

@ -31,6 +31,7 @@ RUN apk --no-cache add \
youtube_dl \ youtube_dl \
pillow-avif-plugin \ pillow-avif-plugin \
pyheif-pillow-opener \ pyheif-pillow-opener \
yattag \
&& apk --no-cache del py3-pip && apk --no-cache del py3-pip
COPY ./ /opt/app/ COPY ./ /opt/app/

View file

@ -3,11 +3,28 @@
# shown in the website title and on the front page # shown in the website title and on the front page
name: szurubooru name: szurubooru
# full url to the homepage of this szurubooru site, with no trailing slash
domain: # example: http://example.com
# used to salt the users' password hashes and generate filenames for static content # used to salt the users' password hashes and generate filenames for static content
secret: change secret: change
# set to the root web address for your instance
# example values:
# - `/` (default) is used when the domain is unknown
# - `https://szuru.example.com/` if you know the specific domain
# and is required if you want email-based password reset
# - `/baseprefix` if you want to host szurubooru on a specific
# prefix and share the domain with other applications
# - `https://www.example.com/szuru` combines both of the above
# also see: "Setting a specific base URI for proxying" in INSTALL.md
base_url: /
# !!should not be changed for the normal docker installation!!
# set to the root web address for static image content
# if it is a relative path with no leading `/`, then this will be
# appended to the base url.
# see: "Using a seperate domain to host static files" in INSTALL.md
# for more info on when to modify
data_url: data/
# Delete thumbnails and source files on post delete # Delete thumbnails and source files on post delete
# Original functionality is no, to mitigate the impacts of admins going # Original functionality is no, to mitigate the impacts of admins going
# on unchecked post purges. # on unchecked post purges.
@ -171,7 +188,6 @@ privileges:
## ONLY SET THESE IF DEPLOYING OUTSIDE OF DOCKER ## ONLY SET THESE IF DEPLOYING OUTSIDE OF DOCKER
#debug: 0 # generate server logs? #debug: 0 # generate server logs?
#show_sql: 0 # show sql in server logs? #show_sql: 0 # show sql in server logs?
#data_url: /data/
#data_dir: /var/www/data #data_dir: /var/www/data
## usage: schema://user:password@host:port/database_name ## usage: schema://user:password@host:port/database_name
## example: postgres://szuru:dog@localhost:5432/szuru_test ## example: postgres://szuru:dog@localhost:5432/szuru_test

View file

@ -11,4 +11,5 @@ pytz>=2018.3
pyRFC3339>=1.0 pyRFC3339>=1.0
pillow-avif-plugin>=1.1.0 pillow-avif-plugin>=1.1.0
pyheif-pillow-opener>=0.1.0 pyheif-pillow-opener>=0.1.0
yattag>=1.14.0
youtube_dl youtube_dl

View file

@ -1,5 +1,6 @@
import szurubooru.api.comment_api import szurubooru.api.comment_api
import szurubooru.api.info_api import szurubooru.api.info_api
import szurubooru.api.opengraph_api
import szurubooru.api.password_reset_api import szurubooru.api.password_reset_api
import szurubooru.api.pool_api import szurubooru.api.pool_api
import szurubooru.api.pool_category_api import szurubooru.api.pool_category_api

View file

@ -45,7 +45,7 @@ def get_info(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
"defaultUserRank": config.config["default_rank"], "defaultUserRank": config.config["default_rank"],
"enableSafety": config.config["enable_safety"], "enableSafety": config.config["enable_safety"],
"contactEmail": config.config["contact_email"], "contactEmail": config.config["contact_email"],
"canSendMails": bool(config.config["smtp"]["host"]), "canSendMails": util.can_send_mail(),
"privileges": util.snake_case_to_lower_camel_case_keys( "privileges": util.snake_case_to_lower_camel_case_keys(
config.config["privileges"] config.config["privileges"]
), ),
@ -64,3 +64,28 @@ def get_info(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
) )
ret["featuringTime"] = post_feature.time if post_feature else None ret["featuringTime"] = post_feature.time if post_feature else None
return ret return ret
@rest.routes.get(r"/manifest(?:\.json)?")
def generate_manifest(
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
return {
"name": config.config["name"],
"icons": [
{
"src": util.add_url_prefix("/img/android-chrome-192x192.png"),
"type": "image/png",
"sizes": "192x192",
},
{
"src": util.add_url_prefix("/img/android-chrome-512x512.png"),
"type": "image/png",
"sizes": "512x512",
},
],
"start_url": util.add_url_prefix(),
"theme_color": "#24aadd",
"background_color": "#ffffff",
"display": "standalone",
}

View file

@ -0,0 +1,231 @@
from typing import Callable, Dict
from yattag import Doc
from szurubooru import config, model, rest
from szurubooru.func import auth, posts, util
_default_meta_tags = {
"viewport": "width=device-width, initial-scale=1, maximum-scale=1",
"theme-color": "#24aadd",
"apple-mobile-web-app-capable": "yes",
"apple-mobile-web-app-status-bar-style": "black",
"msapplication-TileColor": "#ffffff",
"msapplication-TileImage": "/img/mstile-150x150.png",
}
_apple_touch_startup_images = {
"640x1136": {
"device-width": "320px",
"device-height": "568px",
"-webkit-device-pixel-ratio": 2,
"orientation": "portrait",
},
"750x1294": {
"device-width": "375px",
"device-height": "667px",
"-webkit-device-pixel-ratio": 2,
"orientation": "portrait",
},
"1242x2148": {
"device-width": "414px",
"device-height": "736px",
"-webkit-device-pixel-ratio": 3,
"orientation": "portrait",
},
"1125x2436": {
"device-width": "375px",
"device-height": "812px",
"-webkit-device-pixel-ratio": 3,
"orientation": "portrait",
},
"1536x2048": {
"min-device-width": "768px",
"max-device-width": "1024px",
"-webkit-min-device-pixel-ratio": 2,
"orientation": "portrait",
},
"1668x2224": {
"min-device-width": "834px",
"max-device-width": "834px",
"-webkit-min-device-pixel-ratio": 2,
"orientation": "portrait",
},
"2048x2732": {
"min-device-width": "1024px",
"max-device-width": "1024px",
"-webkit-min-device-pixel-ratio": 2,
"orientation": "portrait",
},
}
def _get_html_template(
title: str,
header_content: str = "",
) -> Doc:
doc = Doc()
doc.asis("<!DOCTYPE html>")
with doc.tag("html"):
with doc.tag("head"):
doc.stag("meta", charset="utf-8")
for name, content in _default_meta_tags.items():
doc.stag("meta", name=name, content=content)
with doc.tag("title"):
doc.text(title)
doc.stag(
"link",
rel="manifest",
href=util.add_url_prefix("/api/manifest.json"),
)
doc.stag(
"link",
href=util.add_url_prefix("/css/app.min.css"),
rel="stylesheet",
type="text/css",
)
doc.stag(
"link",
href=util.add_url_prefix("/css/vendor.min.css"),
rel="stylesheet",
type="text/css",
)
doc.stag(
"link",
rel="shortcut icon",
type="image/png",
href=util.add_url_prefix("/img/favicon.png"),
)
doc.stag(
"link",
rel="apple-touch-icon",
sizes="180x180",
href=util.add_url_prefix("/img/apple-touch-icon.png"),
)
for res, media in _apple_touch_startup_images.items():
doc.stag(
"link",
rel="apple-touch-startup-image",
href=util.add_url_prefix(
f"/img/apple-touch-startup-image-{res}.png"
),
media=" and ".join(
f"({k}: {v})" for k, v in media.items()
),
)
doc.stag("base", href=util.add_url_prefix())
doc.asis(header_content)
with doc.tag("body"):
with doc.tag("div", id="top-navigation-holder"):
pass
with doc.tag("div", id="content-holder"):
pass
with doc.tag(
"script",
type="text/javascript",
src=util.add_url_prefix("js/vendor.min.js"),
):
pass
with doc.tag(
"script",
type="text/javascript",
src=util.add_url_prefix("js/app.min.js"),
):
pass
return doc.getvalue()
def _get_post_id(params: Dict[str, str]) -> int:
try:
return int(params["post_id"])
except TypeError:
raise posts.InvalidPostIdError(
"Invalid post ID: %r." % params["post_id"]
)
def _get_post(params: Dict[str, str]) -> model.Post:
return posts.get_post_by_id(_get_post_id(params))
@rest.routes.get("/html/post/(?P<post_id>[^/]+)/?", accept="text/html")
def get_post_html(
ctx: rest.Context, params: Dict[str, str] = {}
) -> rest.Response:
try:
post = _get_post(params)
title = f"{config.config['name']} - Post #{_get_post_id(params)}"
except posts.InvalidPostIdError:
# Return the default template and let the browser JS handle the 404
return _get_html_template()
doc = Doc()
doc.stag("meta", name="og:site_name", content=config.config["name"])
doc.stag(
"meta",
name="og:url",
content=util.add_url_prefix(f"post/{params['post_id']}"),
)
doc.stag("meta", name="og:title", content=title),
doc.stag("meta", name="twitter:title", content=title),
doc.stag("meta", name="og:type", content="article"),
if not auth.anon_has_privilege("posts:view"):
return _get_html_template(title=title, header_content=doc.getvalue())
content_url = util.add_data_prefix(posts.get_post_content_path(post))
thumbnail_url = util.add_data_prefix(posts.get_post_thumbnail_path(post))
tag_string = " ".join(tag.first_name for tag in post.tags)
doc.stag("meta", name="og:image:alt", content=tag_string)
doc.stag(
"meta",
name="og:article:published_time",
content=post.creation_time.isoformat(),
)
if post.last_edit_time:
doc.stag(
"meta",
name="og:article:modified_time",
content=post.last_edit_time.isoformat(),
)
for tag in post.tags:
doc.stag("meta", name="article:tag", content=tag.first_name)
if post.type in (model.Post.TYPE_VIDEO,):
doc.stag("meta", name="twitter:card", content="player")
doc.stag("meta", name="og:video:url", content=content_url)
doc.stag("meta", name="twitter:player:stream", content=content_url)
doc.stag("meta", name="og:image:url", content=thumbnail_url)
if post.canvas_width and post.canvas_height:
doc.stag(
"meta", name="og:video:width", content=str(post.canvas_width)
)
doc.stag(
"meta", name="og:video:height", content=str(post.canvas_height)
)
doc.stag(
"meta",
name="twitter:player:width",
content=str(post.canvas_width),
)
doc.stag(
"meta",
name="twitter:player:height",
content=str(post.canvas_height),
)
doc.stag("link", name="preload", href=content_url, **{"as": "video"})
else:
doc.stag("meta", name="twitter:card", content="summary_large_image")
doc.stag("meta", name="og:image:url", content=content_url)
doc.stag("meta", name="twitter:image", content=content_url)
doc.stag("link", name="preload", href=content_url, **{"as": "image"})
return _get_html_template(title=title, header_content=doc.getvalue())
@rest.routes.get("/html/.*", accept="text/html")
def default_route(
ctx: rest.Context, _params: Dict[str, str] = {}
) -> rest.Response:
return _get_html_template(title=config.config["name"])

View file

@ -2,7 +2,7 @@ from hashlib import md5
from typing import Dict from typing import Dict
from szurubooru import config, errors, rest from szurubooru import config, errors, rest
from szurubooru.func import auth, mailer, users, versions from szurubooru.func import auth, mailer, users, util, versions
MAIL_SUBJECT = "Password reset for {name}" MAIL_SUBJECT = "Password reset for {name}"
MAIL_BODY = ( MAIL_BODY = (
@ -24,16 +24,7 @@ def start_password_reset(
% (user_name) % (user_name)
) )
token = auth.generate_authentication_token(user) token = auth.generate_authentication_token(user)
url = util.add_url_prefix(f"password-reset/{user.name}:{token}")
if config.config["domain"]:
url = config.config["domain"]
elif "HTTP_ORIGIN" in ctx.env:
url = ctx.env["HTTP_ORIGIN"].rstrip("/")
elif "HTTP_REFERER" in ctx.env:
url = ctx.env["HTTP_REFERER"].rstrip("/")
else:
url = ""
url += "/password-reset/%s:%s" % (user.name, token)
mailer.send_mail( mailer.send_mail(
config.config["smtp"]["from"], config.config["smtp"]["from"],

View file

@ -9,7 +9,8 @@ _search_executor = search.Executor(search.configs.PoolSearchConfig())
def _serialize(ctx: rest.Context, pool: model.Pool) -> rest.Response: def _serialize(ctx: rest.Context, pool: model.Pool) -> rest.Response:
return pools.serialize_pool( return pools.serialize_pool(
pool, options=serialization.get_serialization_options(ctx) pool,
options=serialization.get_serialization_options(ctx),
) )

View file

@ -35,7 +35,9 @@ def _serialize_post(
ctx: rest.Context, post: Optional[model.Post] ctx: rest.Context, post: Optional[model.Post]
) -> rest.Response: ) -> rest.Response:
return posts.serialize_post( return posts.serialize_post(
post, ctx.user, options=serialization.get_serialization_options(ctx) post,
ctx.user,
options=serialization.get_serialization_options(ctx),
) )

View file

@ -31,7 +31,6 @@ def _docker_config() -> Dict:
return { return {
"debug": True, "debug": True,
"show_sql": int(os.getenv("LOG_SQL", 0)), "show_sql": int(os.getenv("LOG_SQL", 0)),
"data_url": os.getenv("DATA_URL", "data/"),
"data_dir": "/data/", "data_dir": "/data/",
"database": "postgres://%(user)s:%(pass)s@%(host)s:%(port)d/%(db)s" "database": "postgres://%(user)s:%(pass)s@%(host)s:%(port)d/%(db)s"
% { % {

View file

@ -85,13 +85,7 @@ def validate_config() -> None:
% (config.config["default_rank"]) % (config.config["default_rank"])
) )
for key in ["data_url", "data_dir"]: if not os.path.isabs(config.config["data_dir"] or ""):
if not config.config[key]:
raise errors.ConfigError(
"Service is not configured: %r is missing" % key
)
if not os.path.isabs(config.config["data_dir"]):
raise errors.ConfigError("data_dir must be an absolute path") raise errors.ConfigError("data_dir must be an absolute path")
if not config.config["database"]: if not config.config["database"]:

View file

@ -106,6 +106,16 @@ def is_valid_token(user_token: Optional[model.UserToken]) -> bool:
return True return True
def anon_has_privilege(privilege_name: str) -> bool:
all_ranks = list(RANK_MAP.keys())
assert privilege_name in config.config["privileges"]
minimal_rank = util.flip(RANK_MAP)[
config.config["privileges"][privilege_name]
]
good_ranks = all_ranks[all_ranks.index(minimal_rank) :]
return model.User.RANK_ANONYMOUS in good_ranks
def has_privilege(user: model.User, privilege_name: str) -> bool: def has_privilege(user: model.User, privilege_name: str) -> bool:
assert user assert user
all_ranks = list(RANK_MAP.keys()) all_ranks = list(RANK_MAP.keys())

View file

@ -145,7 +145,8 @@ class PoolSerializer(serialization.BaseSerializer):
def serialize_pool( def serialize_pool(
pool: model.Pool, options: List[str] = [] pool: model.Pool,
options: List[str] = [],
) -> Optional[rest.Response]: ) -> Optional[rest.Response]:
if not pool: if not pool:
return None return None
@ -154,7 +155,8 @@ def serialize_pool(
def serialize_micro_pool(pool: model.Pool) -> Optional[rest.Response]: def serialize_micro_pool(pool: model.Pool) -> Optional[rest.Response]:
return serialize_pool( return serialize_pool(
pool, options=["id", "names", "category", "description", "postCount"] pool,
options=["id", "names", "category", "description", "postCount"],
) )

View file

@ -44,7 +44,9 @@ class PostAlreadyUploadedError(errors.ValidationError):
super().__init__( super().__init__(
"Post already uploaded (%d)" % other_post.post_id, "Post already uploaded (%d)" % other_post.post_id,
{ {
"otherPostUrl": get_post_content_url(other_post), "otherPostUrl": util.add_data_prefix(
get_post_content_path(other_post)
),
"otherPostId": other_post.post_id, "otherPostId": other_post.post_id,
}, },
) )
@ -105,25 +107,6 @@ def get_post_security_hash(id: int) -> str:
).hexdigest()[0:16] ).hexdigest()[0:16]
def get_post_content_url(post: model.Post) -> str:
assert post
return "%s/posts/%d_%s.%s" % (
config.config["data_url"].rstrip("/"),
post.post_id,
get_post_security_hash(post.post_id),
mime.get_extension(post.mime_type) or "dat",
)
def get_post_thumbnail_url(post: model.Post) -> str:
assert post
return "%s/generated-thumbnails/%d_%s.jpg" % (
config.config["data_url"].rstrip("/"),
post.post_id,
get_post_security_hash(post.post_id),
)
def get_post_content_path(post: model.Post) -> str: def get_post_content_path(post: model.Post) -> str:
assert post assert post
assert post.post_id assert post.post_id
@ -159,7 +142,11 @@ def serialize_note(note: model.PostNote) -> rest.Response:
class PostSerializer(serialization.BaseSerializer): class PostSerializer(serialization.BaseSerializer):
def __init__(self, post: model.Post, auth_user: model.User) -> None: def __init__(
self,
post: model.Post,
auth_user: model.User,
) -> None:
self.post = post self.post = post
self.auth_user = auth_user self.auth_user = auth_user
@ -241,10 +228,10 @@ class PostSerializer(serialization.BaseSerializer):
return self.post.canvas_height return self.post.canvas_height
def serialize_content_url(self) -> Any: def serialize_content_url(self) -> Any:
return get_post_content_url(self.post) return util.add_data_prefix(get_post_content_path(self.post))
def serialize_thumbnail_url(self) -> Any: def serialize_thumbnail_url(self) -> Any:
return get_post_thumbnail_url(self.post) return util.add_data_prefix(get_post_thumbnail_path(self.post))
def serialize_flags(self) -> Any: def serialize_flags(self) -> Any:
return self.post.flags return self.post.flags
@ -264,7 +251,7 @@ class PostSerializer(serialization.BaseSerializer):
{ {
post["id"]: post post["id"]: post
for post in [ for post in [
serialize_micro_post(rel, self.auth_user) serialize_micro_post(rel, self.auth_user, self.url_prefix)
for rel in self.post.relations for rel in self.post.relations
] ]
}.values(), }.values(),
@ -346,7 +333,9 @@ class PostSerializer(serialization.BaseSerializer):
def serialize_post( def serialize_post(
post: Optional[model.Post], auth_user: model.User, options: List[str] = [] post: Optional[model.Post],
auth_user: model.User,
options: List[str] = [],
) -> Optional[rest.Response]: ) -> Optional[rest.Response]:
if not post: if not post:
return None return None
@ -354,7 +343,8 @@ def serialize_post(
def serialize_micro_post( def serialize_micro_post(
post: model.Post, auth_user: model.User post: model.Post,
auth_user: model.User,
) -> Optional[rest.Response]: ) -> Optional[rest.Response]:
return serialize_post( return serialize_post(
post, auth_user=auth_user, options=["id", "thumbnailUrl"] post, auth_user=auth_user, options=["id", "thumbnailUrl"]

View file

@ -5,8 +5,9 @@ import tempfile
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, Generator, List, Optional, Tuple, TypeVar, Union from typing import Any, Dict, Generator, List, Optional, Tuple, TypeVar, Union
from urllib.parse import urlparse, urlunparse
from szurubooru import errors from szurubooru import config, errors
T = TypeVar("T") T = TypeVar("T")
@ -176,3 +177,34 @@ def get_column_size(column: Any) -> Optional[int]:
def chunks(source_list: List[Any], part_size: int) -> Generator: def chunks(source_list: List[Any], part_size: int) -> Generator:
for i in range(0, len(source_list), part_size): for i in range(0, len(source_list), part_size):
yield source_list[i : i + part_size] yield source_list[i : i + part_size]
def _get_url_prefix_parts() -> str:
parsed_base_url = list(urlparse(config.config["base_url"]))
if not all(parsed_base_url[0:2]):
parsed_base_url[0:2] = ["", ""]
parsed_base_url[2] = parsed_base_url[2].rstrip("/")
return parsed_base_url[0:3] + ["", "", ""]
def _get_data_prefix_parts() -> str:
parsed_base_url = _get_url_prefix_parts()
parsed_data_url = list(urlparse(config.config["data_url"]))
if not all(parsed_data_url[0:2]):
parsed_data_url[0:2] = parsed_base_url[0:2]
if not parsed_data_url[2].startswith("/"):
parsed_data_url[2] = parsed_base_url[2] + "/" + parsed_data_url[2]
parsed_data_url[2] = parsed_data_url[2].rstrip("/")
return parsed_data_url[0:3] + ["", "", ""]
def add_url_prefix(url: str = "") -> str:
return urlunparse(_get_url_prefix_parts()) + "/" + url.lstrip("/")
def add_data_prefix(url: str = "") -> str:
return urlunparse(_get_data_prefix_parts()) + "/" + url.lstrip("/")
def can_send_mail() -> bool:
return bool(config.config["smtp"]["host"] and _get_url_prefix_parts()[1])

View file

@ -13,10 +13,18 @@ def process_request(_ctx: rest.Context) -> None:
@middleware.post_hook @middleware.post_hook
def process_response(ctx: rest.Context) -> None: def process_response(ctx: rest.Context) -> None:
logger.info( if ctx.accept == "application/json":
"%s %s (user=%s, queries=%d)", logger.info(
ctx.method, "%s %s (user=%s, queries=%d)",
ctx.url, ctx.method,
ctx.user.name, ctx.url,
db.get_query_count(), ctx.user.name,
) db.get_query_count(),
)
elif ctx.accept == "text/html":
logger.info(
"HTML %s (user-agent='%s' queries=%d)",
ctx.url,
ctx.get_header("User-Agent"),
db.get_query_count(),
)

View file

@ -5,7 +5,7 @@ import urllib.parse
from datetime import datetime from datetime import datetime
from typing import Any, Callable, Dict, Tuple from typing import Any, Callable, Dict, Tuple
from szurubooru import db from szurubooru import config, db
from szurubooru.func import util from szurubooru.func import util
from szurubooru.rest import context, errors, middleware, routes from szurubooru.rest import context, errors, middleware, routes
@ -18,8 +18,12 @@ def _json_serializer(obj: Any) -> str:
raise TypeError("Type not serializable") raise TypeError("Type not serializable")
def _dump_json(obj: Any) -> str: def _serialize_response_body(obj: Any, accept: str) -> str:
return json.dumps(obj, default=_json_serializer, indent=2) if accept == "application/json":
return json.dumps(obj, default=_json_serializer, indent=2)
if "text/" in accept:
return obj
raise ValueError("Unhandled response type %s" % accept)
def _get_headers(env: Dict[str, Any]) -> Dict[str, str]: def _get_headers(env: Dict[str, Any]) -> Dict[str, str]:
@ -66,7 +70,14 @@ def _create_context(env: Dict[str, Any]) -> context.Context:
"was incorrect or was not encoded as UTF-8.", "was incorrect or was not encoded as UTF-8.",
) )
return context.Context(env, method, path, headers, params, files) return context.Context(
env=env,
method=method,
url=path,
headers=headers,
params=params,
files=files,
)
def application( def application(
@ -74,20 +85,32 @@ def application(
) -> Tuple[bytes]: ) -> Tuple[bytes]:
try: try:
ctx = _create_context(env) ctx = _create_context(env)
if "application/json" not in ctx.get_header("Accept"):
raise errors.HttpNotAcceptable(
"ValidationError", "This API only supports JSON responses."
)
for url, allowed_methods in routes.routes.items(): for url, allowed_methods in routes.routes.items():
match = re.fullmatch(url, ctx.url) match = re.fullmatch(url, ctx.url)
if match: if match:
if ctx.method not in allowed_methods: if ctx.method not in allowed_methods:
raise errors.HttpMethodNotAllowed( raise errors.HttpMethodNotAllowed(
"ValidationError", "ValidationError",
"Allowed methods: %r" % allowed_methods, "Allowed methods: %s"
% ", ".join(allowed_methods.keys()),
) )
handler = allowed_methods[ctx.method] handler, allowed_accept = allowed_methods[ctx.method]
if not any(
map(
lambda a: a in ctx.get_header("Accept"),
[
allowed_accept,
allowed_accept.split("/")[0] + "/*",
"*/*",
],
)
):
raise errors.HttpNotAcceptable(
"ValidationError",
"This route only supports %s responses."
% allowed_accept,
)
ctx.accept = allowed_accept
break break
else: else:
raise errors.HttpNotFound( raise errors.HttpNotFound(
@ -111,8 +134,10 @@ def application(
finally: finally:
db.session.remove() db.session.remove()
start_response("200", [("content-type", "application/json")]) start_response("200", [("content-type", ctx.accept)])
return (_dump_json(response).encode("utf-8"),) return (
_serialize_response_body(response, ctx.accept).encode("utf-8"),
)
except Exception as ex: except Exception as ex:
for exception_type, ex_handler in errors.error_handlers.items(): for exception_type, ex_handler in errors.error_handlers.items():
@ -133,4 +158,6 @@ def application(
if ex.extra_fields is not None: if ex.extra_fields is not None:
for key, value in ex.extra_fields.items(): for key, value in ex.extra_fields.items():
blob[key] = value blob[key] = value
return (_dump_json(blob).encode("utf-8"),) return (
_serialize_response_body(blob, "application/json").encode("utf-8"),
)

View file

@ -15,12 +15,14 @@ class Context:
method: str, method: str,
url: str, url: str,
headers: Dict[str, str] = None, headers: Dict[str, str] = None,
accept: str = None,
params: Request = None, params: Request = None,
files: Dict[str, bytes] = None, files: Dict[str, bytes] = None,
) -> None: ) -> None:
self.env = env self.env = env
self.method = method self.method = method
self.url = url self.url = url
self.accept = accept
self._headers = headers or {} self._headers = headers or {}
self._params = params or {} self._params = params or {}
self._files = files or {} self._files = files or {}

View file

@ -1,39 +1,48 @@
from collections import defaultdict from collections import defaultdict
from typing import Callable, Dict from typing import Callable, Dict, Tuple
from szurubooru.rest.context import Context, Response from szurubooru.rest.context import Context, Response
RouteHandler = Callable[[Context, Dict[str, str]], Response] RouteHandler = Callable[[Context, Dict[str, str]], Response]
routes = defaultdict(dict) # type: Dict[str, Dict[str, RouteHandler]] routes = defaultdict(dict)
# type: Dict[str, Dict[str, Tuple[RouteHandler, str]]]
def get(url: str) -> Callable[[RouteHandler], RouteHandler]: def get(
url: str, accept: str = "application/json"
) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler: RouteHandler) -> RouteHandler: def wrapper(handler: RouteHandler) -> RouteHandler:
routes[url]["GET"] = handler routes[url]["GET"] = (handler, accept)
return handler return handler
return wrapper return wrapper
def put(url: str) -> Callable[[RouteHandler], RouteHandler]: def put(
url: str, accept: str = "application/json"
) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler: RouteHandler) -> RouteHandler: def wrapper(handler: RouteHandler) -> RouteHandler:
routes[url]["PUT"] = handler routes[url]["PUT"] = (handler, accept)
return handler return handler
return wrapper return wrapper
def post(url: str) -> Callable[[RouteHandler], RouteHandler]: def post(
url: str, accept: str = "application/json"
) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler: RouteHandler) -> RouteHandler: def wrapper(handler: RouteHandler) -> RouteHandler:
routes[url]["POST"] = handler routes[url]["POST"] = (handler, accept)
return handler return handler
return wrapper return wrapper
def delete(url: str) -> Callable[[RouteHandler], RouteHandler]: def delete(
url: str, accept: str = "application/json"
) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler: RouteHandler) -> RouteHandler: def wrapper(handler: RouteHandler) -> RouteHandler:
routes[url]["DELETE"] = handler routes[url]["DELETE"] = (handler, accept)
return handler return handler
return wrapper return wrapper

View file

@ -18,6 +18,7 @@ def test_info_api(
config_injector( config_injector(
{ {
"name": "test installation", "name": "test installation",
"base_url": "https://www.example.com",
"contact_email": "test@example.com", "contact_email": "test@example.com",
"enable_safety": True, "enable_safety": True,
"data_dir": str(directory), "data_dir": str(directory),
@ -94,3 +95,33 @@ def test_info_api(
"serverTime": datetime(2016, 1, 3, 13, 1), "serverTime": datetime(2016, 1, 3, 13, 1),
"config": expected_config_key, "config": expected_config_key,
} }
def test_manifest(config_injector, context_factory):
config_injector(
{
"name": "test installation",
"base_url": "/someprefix",
}
)
ctx = context_factory()
expected_manifest = {
"name": "test installation",
"icons": [
{
"src": "/someprefix/img/android-chrome-192x192.png",
"type": "image/png",
"sizes": "192x192",
},
{
"src": "/someprefix/img/android-chrome-512x512.png",
"type": "image/png",
"sizes": "512x512",
},
],
"start_url": "/someprefix/",
"theme_color": "#24aadd",
"background_color": "#ffffff",
"display": "standalone",
}
assert api.info_api.generate_manifest(ctx) == expected_manifest

View file

@ -0,0 +1,103 @@
from unittest.mock import patch
import pytest
import yattag
from szurubooru import api, db, model
from szurubooru.func import auth, posts
def _make_meta_tag(name, content):
doc = yattag.Doc()
doc.stag("meta", name=name, content=content)
return doc.getvalue()
@pytest.mark.parametrize("view_priv", [True, False])
@pytest.mark.parametrize(
"post_type", [model.Post.TYPE_IMAGE, model.Post.TYPE_VIDEO]
)
def test_get_post_html(
config_injector, context_factory, post_factory, view_priv, post_type
):
config_injector(
{
"name": "testing",
"base_url": "/someprefix",
"data_url": "data",
}
)
post = post_factory(id=1, type=post_type)
post.canvas_width = 1920
post.canvas_height = 1080
db.session.add(post)
db.session.flush()
with patch("szurubooru.func.auth.anon_has_privilege"), patch(
"szurubooru.func.posts.get_post_content_path"
), patch("szurubooru.func.posts.get_post_thumbnail_path"):
auth.anon_has_privilege.return_value = view_priv
posts.get_post_content_path.return_value = "content-url"
posts.get_post_thumbnail_path.return_value = "thumbnail-url"
ret = api.opengraph_api.get_post_html(
context_factory(), {"post_id": 1}
)
assert _make_meta_tag("og:site_name", "testing") in ret
assert _make_meta_tag("og:url", "/someprefix/post/1") in ret
assert _make_meta_tag("og:title", "testing - Post #1") in ret
assert _make_meta_tag("twitter:title", "testing - Post #1") in ret
assert _make_meta_tag("og:type", "article") in ret
assert (
bool(
_make_meta_tag("og:article:published_time", "1996-01-01T00:00:00")
in ret
)
== view_priv
)
if post_type == model.Post.TYPE_VIDEO:
assert (
bool(_make_meta_tag("twitter:card", "player") in ret) == view_priv
)
assert (
bool(
_make_meta_tag(
"twitter:player:stream", "/someprefix/data/content-url"
)
in ret
)
== view_priv
)
assert (
bool(
_make_meta_tag("og:video:url", "/someprefix/data/content-url")
in ret
)
== view_priv
)
assert (
bool(
_make_meta_tag(
"og:image:url", "/someprefix/data/thumbnail-url"
)
in ret
)
== view_priv
)
assert (
bool(_make_meta_tag("og:video:width", "1920") in ret) == view_priv
)
assert (
bool(_make_meta_tag("og:video:height", "1080") in ret) == view_priv
)
else:
assert (
bool(_make_meta_tag("twitter:card", "summary_large_image") in ret)
== view_priv
)
assert (
bool(
_make_meta_tag("twitter:image", "/someprefix/data/content-url")
in ret
)
== view_priv
)

View file

@ -11,7 +11,7 @@ def inject_config(config_injector):
config_injector( config_injector(
{ {
"secret": "x", "secret": "x",
"domain": "http://example.com", "base_url": "http://example.com",
"name": "Test instance", "name": "Test instance",
"smtp": { "smtp": {
"from": "noreply@example.com", "from": "noreply@example.com",

View file

@ -41,19 +41,17 @@ def test_simple_updating(user_factory, pool_factory, context_factory):
): ):
posts.get_posts_by_ids.return_value = ([], []) posts.get_posts_by_ids.return_value = ([], [])
pools.serialize_pool.return_value = "serialized pool" pools.serialize_pool.return_value = "serialized pool"
result = api.pool_api.update_pool( ctx = context_factory(
context_factory( params={
params={ "version": 1,
"version": 1, "names": ["pool3"],
"names": ["pool3"], "category": "series",
"category": "series", "description": "desc",
"description": "desc", "posts": [1, 2],
"posts": [1, 2], },
}, user=auth_user,
user=auth_user,
),
{"pool_id": 1},
) )
result = api.pool_api.update_pool(ctx, {"pool_id": 1})
assert result == "serialized pool" assert result == "serialized pool"
pools.create_pool.assert_not_called() pools.create_pool.assert_not_called()
pools.update_pool_names.assert_called_once_with(pool, ["pool3"]) pools.update_pool_names.assert_called_once_with(pool, ["pool3"])

View file

@ -45,19 +45,18 @@ def test_creating_minimal_posts(context_factory, post_factory, user_factory):
posts.create_post.return_value = (post, []) posts.create_post.return_value = (post, [])
posts.serialize_post.return_value = "serialized post" posts.serialize_post.return_value = "serialized post"
result = api.post_api.create_post( ctx = context_factory(
context_factory( params={
params={ "safety": "safe",
"safety": "safe", "tags": ["tag1", "tag2"],
"tags": ["tag1", "tag2"], },
}, files={
files={ "content": "post-content",
"content": "post-content", "thumbnail": "post-thumbnail",
"thumbnail": "post-thumbnail", },
}, user=auth_user,
user=auth_user,
)
) )
result = api.post_api.create_post(ctx)
assert result == "serialized post" assert result == "serialized post"
posts.create_post.assert_called_once_with( posts.create_post.assert_called_once_with(
@ -102,22 +101,21 @@ def test_creating_full_posts(context_factory, post_factory, user_factory):
posts.create_post.return_value = (post, []) posts.create_post.return_value = (post, [])
posts.serialize_post.return_value = "serialized post" posts.serialize_post.return_value = "serialized post"
result = api.post_api.create_post( ctx = context_factory(
context_factory( params={
params={ "safety": "safe",
"safety": "safe", "tags": ["tag1", "tag2"],
"tags": ["tag1", "tag2"], "relations": [1, 2],
"relations": [1, 2], "source": "source",
"source": "source", "notes": ["note1", "note2"],
"notes": ["note1", "note2"], "flags": ["flag1", "flag2"],
"flags": ["flag1", "flag2"], },
}, files={
files={ "content": "post-content",
"content": "post-content", },
}, user=auth_user,
user=auth_user,
)
) )
result = api.post_api.create_post(ctx)
assert result == "serialized post" assert result == "serialized post"
posts.create_post.assert_called_once_with( posts.create_post.assert_called_once_with(
@ -333,7 +331,8 @@ def test_errors_not_spending_ids(
config_injector( config_injector(
{ {
"data_dir": str(tmpdir.mkdir("data")), "data_dir": str(tmpdir.mkdir("data")),
"data_url": "example.com", "base_url": "https://example.com/",
"data_url": "https://example.com/data",
"thumbnails": { "thumbnails": {
"post_width": 300, "post_width": 300,
"post_height": 300, "post_height": 300,

View file

@ -67,12 +67,13 @@ def nontransacted_session(query_logger, postgresql_db):
@pytest.fixture @pytest.fixture
def context_factory(session): def context_factory(session):
def factory(params=None, files=None, user=None, headers=None): def factory(params=None, files=None, user=None, headers=None, accept=None):
ctx = rest.Context( ctx = rest.Context(
env={"HTTP_ORIGIN": "http://example.com"}, env={"HTTP_ORIGIN": "http://example.com"},
method=None, method=None,
url=None, url=None,
headers=headers or {}, headers=headers or {},
accept=accept or None,
params=params or {}, params=params or {},
files=files or {}, files=files or {},
) )

View file

@ -17,34 +17,6 @@ from szurubooru.func import (
) )
@pytest.mark.parametrize(
"input_mime_type,expected_url",
[
("image/jpeg", "http://example.com/posts/1_244c8840887984c4.jpg"),
("image/gif", "http://example.com/posts/1_244c8840887984c4.gif"),
("totally/unknown", "http://example.com/posts/1_244c8840887984c4.dat"),
],
)
def test_get_post_url(input_mime_type, expected_url, config_injector):
config_injector({"data_url": "http://example.com/", "secret": "test"})
post = model.Post()
post.post_id = 1
post.mime_type = input_mime_type
assert posts.get_post_content_url(post) == expected_url
@pytest.mark.parametrize("input_mime_type", ["image/jpeg", "image/gif"])
def test_get_post_thumbnail_url(input_mime_type, config_injector):
config_injector({"data_url": "http://example.com/", "secret": "test"})
post = model.Post()
post.post_id = 1
post.mime_type = input_mime_type
assert (
posts.get_post_thumbnail_url(post)
== "http://example.com/generated-thumbnails/1_244c8840887984c4.jpg"
)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"input_mime_type,expected_path", "input_mime_type,expected_path",
[ [
@ -53,7 +25,10 @@ def test_get_post_thumbnail_url(input_mime_type, config_injector):
("totally/unknown", "posts/1_244c8840887984c4.dat"), ("totally/unknown", "posts/1_244c8840887984c4.dat"),
], ],
) )
def test_get_post_content_path(input_mime_type, expected_path): def test_get_post_content_path(
input_mime_type, expected_path, config_injector
):
config_injector({"secret": "test"})
post = model.Post() post = model.Post()
post.post_id = 1 post.post_id = 1
post.mime_type = input_mime_type post.mime_type = input_mime_type
@ -61,7 +36,8 @@ def test_get_post_content_path(input_mime_type, expected_path):
@pytest.mark.parametrize("input_mime_type", ["image/jpeg", "image/gif"]) @pytest.mark.parametrize("input_mime_type", ["image/jpeg", "image/gif"])
def test_get_post_thumbnail_path(input_mime_type): def test_get_post_thumbnail_path(input_mime_type, config_injector):
config_injector({"secret": "test"})
post = model.Post() post = model.Post()
post.post_id = 1 post.post_id = 1
post.mime_type = input_mime_type post.mime_type = input_mime_type
@ -72,7 +48,8 @@ def test_get_post_thumbnail_path(input_mime_type):
@pytest.mark.parametrize("input_mime_type", ["image/jpeg", "image/gif"]) @pytest.mark.parametrize("input_mime_type", ["image/jpeg", "image/gif"])
def test_get_post_thumbnail_backup_path(input_mime_type): def test_get_post_thumbnail_backup_path(input_mime_type, config_injector):
config_injector({"secret": "test"})
post = model.Post() post = model.Post()
post.post_id = 1 post.post_id = 1
post.mime_type = input_mime_type post.mime_type = input_mime_type
@ -105,7 +82,13 @@ def test_serialize_post(
pool_category_factory, pool_category_factory,
config_injector, config_injector,
): ):
config_injector({"data_url": "http://example.com/", "secret": "test"}) config_injector(
{
"secret": "test",
"base_url": "http://example.com/",
"data_url": "http://example.com/",
}
)
with patch("szurubooru.func.comments.serialize_comment"), patch( with patch("szurubooru.func.comments.serialize_comment"), patch(
"szurubooru.func.users.serialize_micro_user" "szurubooru.func.users.serialize_micro_user"
), patch("szurubooru.func.posts.files.has"): ), patch("szurubooru.func.posts.files.has"):
@ -277,17 +260,15 @@ def test_serialize_post(
def test_serialize_micro_post(post_factory, user_factory): def test_serialize_micro_post(post_factory, user_factory):
with patch("szurubooru.func.posts.get_post_thumbnail_url"): with patch("szurubooru.func.posts.get_post_thumbnail_path"):
posts.get_post_thumbnail_url.return_value = ( posts.get_post_thumbnail_path.return_value = "thumb.png"
"https://example.com/thumb.png"
)
auth_user = user_factory() auth_user = user_factory()
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
assert posts.serialize_micro_post(post, auth_user) == { assert posts.serialize_micro_post(post, auth_user) == {
"id": post.post_id, "id": post.post_id,
"thumbnailUrl": "https://example.com/thumb.png", "thumbnailUrl": "http://example.com/thumb.png",
} }
@ -519,7 +500,8 @@ def test_update_post_content_to_existing_content(
config_injector( config_injector(
{ {
"data_dir": str(tmpdir.mkdir("data")), "data_dir": str(tmpdir.mkdir("data")),
"data_url": "example.com", "base_url": "https://example.com/",
"data_url": "https://example.com/data",
"thumbnails": { "thumbnails": {
"post_width": 300, "post_width": 300,
"post_height": 300, "post_height": 300,

View file

@ -45,3 +45,93 @@ def test_parsing_date_time(fake_datetime, input, output):
) )
def test_icase_unique(input, output): def test_icase_unique(input, output):
assert util.icase_unique(input) == output assert util.icase_unique(input) == output
def test_url_generation(config_injector):
config_injector(
{
"base_url": "https://www.example.com/",
"data_url": "data/",
}
)
assert util.add_url_prefix() == "https://www.example.com/"
assert util.add_url_prefix("/post/1") == "https://www.example.com/post/1"
assert util.add_url_prefix("post/1") == "https://www.example.com/post/1"
assert util.add_data_prefix() == "https://www.example.com/data/"
assert (
util.add_data_prefix("posts/1.jpg")
== "https://www.example.com/data/posts/1.jpg"
)
assert (
util.add_data_prefix("/posts/1.jpg")
== "https://www.example.com/data/posts/1.jpg"
)
config_injector(
{
"base_url": "https://www.example.com/szuru/",
"data_url": "data/",
}
)
assert util.add_url_prefix() == "https://www.example.com/szuru/"
assert (
util.add_url_prefix("/post/1")
== "https://www.example.com/szuru/post/1"
)
assert (
util.add_url_prefix("post/1") == "https://www.example.com/szuru/post/1"
)
assert util.add_data_prefix() == "https://www.example.com/szuru/data/"
assert (
util.add_data_prefix("posts/1.jpg")
== "https://www.example.com/szuru/data/posts/1.jpg"
)
assert (
util.add_data_prefix("/posts/1.jpg")
== "https://www.example.com/szuru/data/posts/1.jpg"
)
config_injector(
{
"base_url": "https://www.example.com/szuru/",
"data_url": "/data/",
}
)
assert util.add_url_prefix() == "https://www.example.com/szuru/"
assert (
util.add_url_prefix("/post/1")
== "https://www.example.com/szuru/post/1"
)
assert (
util.add_url_prefix("post/1") == "https://www.example.com/szuru/post/1"
)
assert util.add_data_prefix() == "https://www.example.com/data/"
assert (
util.add_data_prefix("posts/1.jpg")
== "https://www.example.com/data/posts/1.jpg"
)
assert (
util.add_data_prefix("/posts/1.jpg")
== "https://www.example.com/data/posts/1.jpg"
)
config_injector(
{
"base_url": "https://www.example.com/szuru",
"data_url": "https://static.example.com/",
}
)
assert util.add_url_prefix() == "https://www.example.com/szuru/"
assert (
util.add_url_prefix("/post/1")
== "https://www.example.com/szuru/post/1"
)
assert (
util.add_url_prefix("post/1") == "https://www.example.com/szuru/post/1"
)
assert util.add_data_prefix() == "https://static.example.com/"
assert (
util.add_data_prefix("posts/1.jpg")
== "https://static.example.com/posts/1.jpg"
)
assert (
util.add_data_prefix("/posts/1.jpg")
== "https://static.example.com/posts/1.jpg"
)

View file

@ -9,7 +9,7 @@ from szurubooru.rest import errors
def test_process_request_no_header(context_factory): def test_process_request_no_header(context_factory):
ctx = context_factory() ctx = context_factory(accept="application/json")
authenticator.process_request(ctx) authenticator.process_request(ctx)
assert ctx.user.name is None assert ctx.user.name is None
@ -21,6 +21,7 @@ def test_process_request_bump_login(context_factory, user_factory):
ctx = context_factory( ctx = context_factory(
headers={"Authorization": "Basic dGVzdFVzZXI6dGVzdFRva2Vu"}, headers={"Authorization": "Basic dGVzdFVzZXI6dGVzdFRva2Vu"},
params={"bump-login": "true"}, params={"bump-login": "true"},
accept="application/json",
) )
with patch("szurubooru.func.auth.is_valid_password"), patch( with patch("szurubooru.func.auth.is_valid_password"), patch(
"szurubooru.func.users.get_user_by_name" "szurubooru.func.users.get_user_by_name"
@ -40,6 +41,7 @@ def test_process_request_bump_login_with_token(
ctx = context_factory( ctx = context_factory(
headers={"Authorization": "Token dGVzdFVzZXI6dGVzdFRva2Vu"}, headers={"Authorization": "Token dGVzdFVzZXI6dGVzdFRva2Vu"},
params={"bump-login": "true"}, params={"bump-login": "true"},
accept="application/json",
) )
with patch("szurubooru.func.auth.is_valid_token"), patch( with patch("szurubooru.func.auth.is_valid_token"), patch(
"szurubooru.func.users.get_user_by_name" "szurubooru.func.users.get_user_by_name"
@ -55,7 +57,8 @@ def test_process_request_bump_login_with_token(
def test_process_request_basic_auth_valid(context_factory, user_factory): def test_process_request_basic_auth_valid(context_factory, user_factory):
user = user_factory() user = user_factory()
ctx = context_factory( ctx = context_factory(
headers={"Authorization": "Basic dGVzdFVzZXI6dGVzdFBhc3N3b3Jk"} headers={"Authorization": "Basic dGVzdFVzZXI6dGVzdFBhc3N3b3Jk"},
accept="application/json",
) )
with patch("szurubooru.func.auth.is_valid_password"), patch( with patch("szurubooru.func.auth.is_valid_password"), patch(
"szurubooru.func.users.get_user_by_name" "szurubooru.func.users.get_user_by_name"
@ -69,7 +72,8 @@ def test_process_request_basic_auth_valid(context_factory, user_factory):
def test_process_request_token_auth_valid(context_factory, user_token_factory): def test_process_request_token_auth_valid(context_factory, user_token_factory):
user_token = user_token_factory() user_token = user_token_factory()
ctx = context_factory( ctx = context_factory(
headers={"Authorization": "Token dGVzdFVzZXI6dGVzdFRva2Vu"} headers={"Authorization": "Token dGVzdFVzZXI6dGVzdFRva2Vu"},
accept="application/json",
) )
with patch("szurubooru.func.auth.is_valid_token"), patch( with patch("szurubooru.func.auth.is_valid_token"), patch(
"szurubooru.func.users.get_user_by_name" "szurubooru.func.users.get_user_by_name"
@ -82,6 +86,9 @@ def test_process_request_token_auth_valid(context_factory, user_token_factory):
def test_process_request_bad_header(context_factory): def test_process_request_bad_header(context_factory):
ctx = context_factory(headers={"Authorization": "Secret SuperSecretValue"}) ctx = context_factory(
headers={"Authorization": "Secret SuperSecretValue"},
accept="application/json",
)
with pytest.raises(errors.HttpBadRequest): with pytest.raises(errors.HttpBadRequest):
authenticator.process_request(ctx) authenticator.process_request(ctx)