server: refactor + add type hinting

- Added type hinting (for now, 3.5-compatible)
- Split `db` namespace into `db` module and `model` namespace
- Changed elastic search to be created lazily for each operation
- Changed to class based approach in entity serialization to allow
  stronger typing
- Removed `required` argument from `context.get_*` family of functions;
  now it's implied if `default` argument is omitted
- Changed `unalias_dict` implementation to use less magic inputs
This commit is contained in:
rr- 2017-02-04 01:08:12 +01:00
parent abf1fc2b2d
commit ad842ee8a5
116 changed files with 2868 additions and 2037 deletions

View file

@ -8,7 +8,7 @@ import zlib
import concurrent.futures import concurrent.futures
import logging import logging
import coloredlogs import coloredlogs
import sqlalchemy import sqlalchemy as sa
from szurubooru import config, db from szurubooru import config, db
from szurubooru.func import files, images, posts, comments from szurubooru.func import files, images, posts, comments
@ -42,8 +42,8 @@ def get_v1_session(args):
port=args.port, port=args.port,
name=args.name) name=args.name)
logger.info('Connecting to %r...', dsn) logger.info('Connecting to %r...', dsn)
engine = sqlalchemy.create_engine(dsn) engine = sa.create_engine(dsn)
session_maker = sqlalchemy.orm.sessionmaker(bind=engine) session_maker = sa.orm.sessionmaker(bind=engine)
return session_maker() return session_maker()
def parse_args(): def parse_args():

14
server/mypy.ini Normal file
View file

@ -0,0 +1,14 @@
[mypy]
ignore_missing_imports = True
follow_imports = skip
disallow_untyped_calls = True
disallow_untyped_defs = True
check_untyped_defs = True
disallow_subclassing_any = False
warn_redundant_casts = True
warn_unused_ignores = True
strict_optional = True
strict_boolean = False
[mypy-szurubooru.tests.*]
ignore_errors=True

View file

@ -1,31 +1,44 @@
import datetime from typing import Dict
from szurubooru import search from datetime import datetime
from szurubooru.rest import routes from szurubooru import search, rest, model
from szurubooru.func import auth, comments, posts, scores, util, versions from szurubooru.func import (
auth, comments, posts, scores, versions, serialization)
_search_executor = search.Executor(search.configs.CommentSearchConfig()) _search_executor = search.Executor(search.configs.CommentSearchConfig())
def _serialize(ctx, comment, **kwargs): def _get_comment(params: Dict[str, str]) -> model.Comment:
try:
comment_id = int(params['comment_id'])
except TypeError:
raise comments.InvalidCommentIdError(
'Invalid comment ID: %r.' % params['comment_id'])
return comments.get_comment_by_id(comment_id)
def _serialize(
ctx: rest.Context, comment: model.Comment) -> rest.Response:
return comments.serialize_comment( return comments.serialize_comment(
comment, comment,
ctx.user, ctx.user,
options=util.get_serialization_options(ctx), **kwargs) options=serialization.get_serialization_options(ctx))
@routes.get('/comments/?') @rest.routes.get('/comments/?')
def get_comments(ctx, _params=None): def get_comments(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:list') auth.verify_privilege(ctx.user, 'comments:list')
return _search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, lambda comment: _serialize(ctx, comment)) ctx, lambda comment: _serialize(ctx, comment))
@routes.post('/comments/?') @rest.routes.post('/comments/?')
def create_comment(ctx, _params=None): def create_comment(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:create') auth.verify_privilege(ctx.user, 'comments:create')
text = ctx.get_param_as_string('text', required=True) text = ctx.get_param_as_string('text')
post_id = ctx.get_param_as_int('postId', required=True) post_id = ctx.get_param_as_int('postId')
post = posts.get_post_by_id(post_id) post = posts.get_post_by_id(post_id)
comment = comments.create_comment(ctx.user, post, text) comment = comments.create_comment(ctx.user, post, text)
ctx.session.add(comment) ctx.session.add(comment)
@ -33,30 +46,30 @@ def create_comment(ctx, _params=None):
return _serialize(ctx, comment) return _serialize(ctx, comment)
@routes.get('/comment/(?P<comment_id>[^/]+)/?') @rest.routes.get('/comment/(?P<comment_id>[^/]+)/?')
def get_comment(ctx, params): def get_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:view') auth.verify_privilege(ctx.user, 'comments:view')
comment = comments.get_comment_by_id(params['comment_id']) comment = _get_comment(params)
return _serialize(ctx, comment) return _serialize(ctx, comment)
@routes.put('/comment/(?P<comment_id>[^/]+)/?') @rest.routes.put('/comment/(?P<comment_id>[^/]+)/?')
def update_comment(ctx, params): def update_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
comment = comments.get_comment_by_id(params['comment_id']) comment = _get_comment(params)
versions.verify_version(comment, ctx) versions.verify_version(comment, ctx)
versions.bump_version(comment) versions.bump_version(comment)
infix = 'own' if ctx.user.user_id == comment.user_id else 'any' infix = 'own' if ctx.user.user_id == comment.user_id else 'any'
text = ctx.get_param_as_string('text', required=True) text = ctx.get_param_as_string('text')
auth.verify_privilege(ctx.user, 'comments:edit:%s' % infix) auth.verify_privilege(ctx.user, 'comments:edit:%s' % infix)
comments.update_comment_text(comment, text) comments.update_comment_text(comment, text)
comment.last_edit_time = datetime.datetime.utcnow() comment.last_edit_time = datetime.utcnow()
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, comment) return _serialize(ctx, comment)
@routes.delete('/comment/(?P<comment_id>[^/]+)/?') @rest.routes.delete('/comment/(?P<comment_id>[^/]+)/?')
def delete_comment(ctx, params): def delete_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
comment = comments.get_comment_by_id(params['comment_id']) comment = _get_comment(params)
versions.verify_version(comment, ctx) versions.verify_version(comment, ctx)
infix = 'own' if ctx.user.user_id == comment.user_id else 'any' infix = 'own' if ctx.user.user_id == comment.user_id else 'any'
auth.verify_privilege(ctx.user, 'comments:delete:%s' % infix) auth.verify_privilege(ctx.user, 'comments:delete:%s' % infix)
@ -65,20 +78,22 @@ def delete_comment(ctx, params):
return {} return {}
@routes.put('/comment/(?P<comment_id>[^/]+)/score/?') @rest.routes.put('/comment/(?P<comment_id>[^/]+)/score/?')
def set_comment_score(ctx, params): def set_comment_score(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:score') auth.verify_privilege(ctx.user, 'comments:score')
score = ctx.get_param_as_int('score', required=True) score = ctx.get_param_as_int('score')
comment = comments.get_comment_by_id(params['comment_id']) comment = _get_comment(params)
scores.set_score(comment, ctx.user, score) scores.set_score(comment, ctx.user, score)
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, comment) return _serialize(ctx, comment)
@routes.delete('/comment/(?P<comment_id>[^/]+)/score/?') @rest.routes.delete('/comment/(?P<comment_id>[^/]+)/score/?')
def delete_comment_score(ctx, params): def delete_comment_score(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:score') auth.verify_privilege(ctx.user, 'comments:score')
comment = comments.get_comment_by_id(params['comment_id']) comment = _get_comment(params)
scores.delete_score(comment, ctx.user) scores.delete_score(comment, ctx.user)
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, comment) return _serialize(ctx, comment)

View file

@ -1,19 +1,20 @@
import datetime
import os import os
from szurubooru import config from typing import Optional, Dict
from szurubooru.rest import routes from datetime import datetime, timedelta
from szurubooru import config, rest
from szurubooru.func import posts, users, util from szurubooru.func import posts, users, util
_cache_time = None _cache_time = None # type: Optional[datetime]
_cache_result = None _cache_result = None # type: Optional[int]
def _get_disk_usage(): def _get_disk_usage() -> int:
global _cache_time, _cache_result # pylint: disable=global-statement global _cache_time, _cache_result # pylint: disable=global-statement
threshold = datetime.timedelta(hours=48) threshold = timedelta(hours=48)
now = datetime.datetime.utcnow() now = datetime.utcnow()
if _cache_time and _cache_time > now - threshold: if _cache_time and _cache_time > now - threshold:
assert _cache_result
return _cache_result return _cache_result
total_size = 0 total_size = 0
for dir_path, _, file_names in os.walk(config.config['data_dir']): for dir_path, _, file_names in os.walk(config.config['data_dir']):
@ -25,8 +26,9 @@ def _get_disk_usage():
return total_size return total_size
@routes.get('/info/?') @rest.routes.get('/info/?')
def get_info(ctx, _params=None): def get_info(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
post_feature = posts.try_get_current_post_feature() post_feature = posts.try_get_current_post_feature()
return { return {
'postCount': posts.get_post_count(), 'postCount': posts.get_post_count(),
@ -38,7 +40,7 @@ def get_info(ctx, _params=None):
'featuringUser': 'featuringUser':
users.serialize_user(post_feature.user, ctx.user) users.serialize_user(post_feature.user, ctx.user)
if post_feature else None, if post_feature else None,
'serverTime': datetime.datetime.utcnow(), 'serverTime': datetime.utcnow(),
'config': { 'config': {
'userNameRegex': config.config['user_name_regex'], 'userNameRegex': config.config['user_name_regex'],
'passwordRegex': config.config['password_regex'], 'passwordRegex': config.config['password_regex'],

View file

@ -1,5 +1,5 @@
from szurubooru import config, errors from typing import Dict
from szurubooru.rest import routes from szurubooru import config, errors, rest
from szurubooru.func import auth, mailer, users, versions from szurubooru.func import auth, mailer, users, versions
@ -10,9 +10,9 @@ MAIL_BODY = \
'Otherwise, please ignore this email.' 'Otherwise, please ignore this email.'
@routes.get('/password-reset/(?P<user_name>[^/]+)/?') @rest.routes.get('/password-reset/(?P<user_name>[^/]+)/?')
def start_password_reset(_ctx, params): def start_password_reset(
''' Send a mail with secure token to the correlated user. ''' _ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
user_name = params['user_name'] user_name = params['user_name']
user = users.get_user_by_name_or_email(user_name) user = users.get_user_by_name_or_email(user_name)
if not user.email: if not user.email:
@ -30,13 +30,13 @@ def start_password_reset(_ctx, params):
return {} return {}
@routes.post('/password-reset/(?P<user_name>[^/]+)/?') @rest.routes.post('/password-reset/(?P<user_name>[^/]+)/?')
def finish_password_reset(ctx, params): def finish_password_reset(
''' Verify token from mail, generate a new password and return it. ''' ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
user_name = params['user_name'] user_name = params['user_name']
user = users.get_user_by_name_or_email(user_name) user = users.get_user_by_name_or_email(user_name)
good_token = auth.generate_authentication_token(user) good_token = auth.generate_authentication_token(user)
token = ctx.get_param_as_string('token', required=True) token = ctx.get_param_as_string('token')
if token != good_token: if token != good_token:
raise errors.ValidationError('Invalid password reset token.') raise errors.ValidationError('Invalid password reset token.')
new_password = users.reset_user_password(user) new_password = users.reset_user_password(user)

View file

@ -1,44 +1,60 @@
import datetime from typing import Optional, Dict
from szurubooru import search, db, errors from datetime import datetime
from szurubooru.rest import routes from szurubooru import db, model, errors, rest, search
from szurubooru.func import ( from szurubooru.func import (
auth, tags, posts, snapshots, favorites, scores, util, versions) auth, tags, posts, snapshots, favorites, scores, serialization, versions)
_search_executor = search.Executor(search.configs.PostSearchConfig()) _search_executor_config = search.configs.PostSearchConfig()
_search_executor = search.Executor(_search_executor_config)
def _serialize_post(ctx, post): def _get_post_id(params: Dict[str, str]) -> int:
try:
return int(params['post_id'])
except TypeError:
raise posts.InvalidPostIdError(
'Invalid post ID: %r.' % params['post_id'])
def _get_post(params: Dict[str, str]) -> model.Post:
return posts.get_post_by_id(_get_post_id(params))
def _serialize_post(
ctx: rest.Context, post: Optional[model.Post]) -> rest.Response:
return posts.serialize_post( return posts.serialize_post(
post, post,
ctx.user, ctx.user,
options=util.get_serialization_options(ctx)) options=serialization.get_serialization_options(ctx))
@routes.get('/posts/?') @rest.routes.get('/posts/?')
def get_posts(ctx, _params=None): def get_posts(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:list') auth.verify_privilege(ctx.user, 'posts:list')
_search_executor.config.user = ctx.user _search_executor_config.user = ctx.user
return _search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, lambda post: _serialize_post(ctx, post)) ctx, lambda post: _serialize_post(ctx, post))
@routes.post('/posts/?') @rest.routes.post('/posts/?')
def create_post(ctx, _params=None): def create_post(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
anonymous = ctx.get_param_as_bool('anonymous', default=False) anonymous = ctx.get_param_as_bool('anonymous', default=False)
if anonymous: if anonymous:
auth.verify_privilege(ctx.user, 'posts:create:anonymous') auth.verify_privilege(ctx.user, 'posts:create:anonymous')
else: else:
auth.verify_privilege(ctx.user, 'posts:create:identified') auth.verify_privilege(ctx.user, 'posts:create:identified')
content = ctx.get_file('content', required=True) content = ctx.get_file('content')
tag_names = ctx.get_param_as_list('tags', required=False, default=[]) tag_names = ctx.get_param_as_list('tags', default=[])
safety = ctx.get_param_as_string('safety', required=True) safety = ctx.get_param_as_string('safety')
source = ctx.get_param_as_string('source', required=False, default=None) source = ctx.get_param_as_string('source', default='')
if ctx.has_param('contentUrl') and not source: if ctx.has_param('contentUrl') and not source:
source = ctx.get_param_as_string('contentUrl') source = ctx.get_param_as_string('contentUrl', default='')
relations = ctx.get_param_as_list('relations', required=False) or [] relations = ctx.get_param_as_list('relations', default=[])
notes = ctx.get_param_as_list('notes', required=False) or [] notes = ctx.get_param_as_list('notes', default=[])
flags = ctx.get_param_as_list('flags', required=False) or [] flags = ctx.get_param_as_list('flags', default=[])
post, new_tags = posts.create_post( post, new_tags = posts.create_post(
content, tag_names, None if anonymous else ctx.user) content, tag_names, None if anonymous else ctx.user)
@ -61,16 +77,16 @@ def create_post(ctx, _params=None):
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.get('/post/(?P<post_id>[^/]+)/?') @rest.routes.get('/post/(?P<post_id>[^/]+)/?')
def get_post(ctx, params): def get_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:view') auth.verify_privilege(ctx.user, 'posts:view')
post = posts.get_post_by_id(params['post_id']) post = _get_post(params)
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.put('/post/(?P<post_id>[^/]+)/?') @rest.routes.put('/post/(?P<post_id>[^/]+)/?')
def update_post(ctx, params): def update_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
post = posts.get_post_by_id(params['post_id']) post = _get_post(params)
versions.verify_version(post, ctx) versions.verify_version(post, ctx)
versions.bump_version(post) versions.bump_version(post)
if ctx.has_file('content'): if ctx.has_file('content'):
@ -104,7 +120,7 @@ def update_post(ctx, params):
if ctx.has_file('thumbnail'): if ctx.has_file('thumbnail'):
auth.verify_privilege(ctx.user, 'posts:edit:thumbnail') auth.verify_privilege(ctx.user, 'posts:edit:thumbnail')
posts.update_post_thumbnail(post, ctx.get_file('thumbnail')) posts.update_post_thumbnail(post, ctx.get_file('thumbnail'))
post.last_edit_time = datetime.datetime.utcnow() post.last_edit_time = datetime.utcnow()
ctx.session.flush() ctx.session.flush()
snapshots.modify(post, ctx.user) snapshots.modify(post, ctx.user)
ctx.session.commit() ctx.session.commit()
@ -112,10 +128,10 @@ def update_post(ctx, params):
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.delete('/post/(?P<post_id>[^/]+)/?') @rest.routes.delete('/post/(?P<post_id>[^/]+)/?')
def delete_post(ctx, params): def delete_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:delete') auth.verify_privilege(ctx.user, 'posts:delete')
post = posts.get_post_by_id(params['post_id']) post = _get_post(params)
versions.verify_version(post, ctx) versions.verify_version(post, ctx)
snapshots.delete(post, ctx.user) snapshots.delete(post, ctx.user)
posts.delete(post) posts.delete(post)
@ -124,13 +140,14 @@ def delete_post(ctx, params):
return {} return {}
@routes.post('/post-merge/?') @rest.routes.post('/post-merge/?')
def merge_posts(ctx, _params=None): def merge_posts(
source_post_id = ctx.get_param_as_string('remove', required=True) or '' ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
target_post_id = ctx.get_param_as_string('mergeTo', required=True) or '' source_post_id = ctx.get_param_as_int('remove')
replace_content = ctx.get_param_as_bool('replaceContent') target_post_id = ctx.get_param_as_int('mergeTo')
source_post = posts.get_post_by_id(source_post_id) source_post = posts.get_post_by_id(source_post_id)
target_post = posts.get_post_by_id(target_post_id) target_post = posts.get_post_by_id(target_post_id)
replace_content = ctx.get_param_as_bool('replaceContent')
versions.verify_version(source_post, ctx, 'removeVersion') versions.verify_version(source_post, ctx, 'removeVersion')
versions.verify_version(target_post, ctx, 'mergeToVersion') versions.verify_version(target_post, ctx, 'mergeToVersion')
versions.bump_version(target_post) versions.bump_version(target_post)
@ -141,16 +158,18 @@ def merge_posts(ctx, _params=None):
return _serialize_post(ctx, target_post) return _serialize_post(ctx, target_post)
@routes.get('/featured-post/?') @rest.routes.get('/featured-post/?')
def get_featured_post(ctx, _params=None): def get_featured_post(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
post = posts.try_get_featured_post() post = posts.try_get_featured_post()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.post('/featured-post/?') @rest.routes.post('/featured-post/?')
def set_featured_post(ctx, _params=None): def set_featured_post(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:feature') auth.verify_privilege(ctx.user, 'posts:feature')
post_id = ctx.get_param_as_int('id', required=True) post_id = ctx.get_param_as_int('id')
post = posts.get_post_by_id(post_id) post = posts.get_post_by_id(post_id)
featured_post = posts.try_get_featured_post() featured_post = posts.try_get_featured_post()
if featured_post and featured_post.post_id == post.post_id: if featured_post and featured_post.post_id == post.post_id:
@ -162,55 +181,61 @@ def set_featured_post(ctx, _params=None):
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.put('/post/(?P<post_id>[^/]+)/score/?') @rest.routes.put('/post/(?P<post_id>[^/]+)/score/?')
def set_post_score(ctx, params): def set_post_score(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:score') auth.verify_privilege(ctx.user, 'posts:score')
post = posts.get_post_by_id(params['post_id']) post = _get_post(params)
score = ctx.get_param_as_int('score', required=True) score = ctx.get_param_as_int('score')
scores.set_score(post, ctx.user, score) scores.set_score(post, ctx.user, score)
ctx.session.commit() ctx.session.commit()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.delete('/post/(?P<post_id>[^/]+)/score/?') @rest.routes.delete('/post/(?P<post_id>[^/]+)/score/?')
def delete_post_score(ctx, params): def delete_post_score(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:score') auth.verify_privilege(ctx.user, 'posts:score')
post = posts.get_post_by_id(params['post_id']) post = _get_post(params)
scores.delete_score(post, ctx.user) scores.delete_score(post, ctx.user)
ctx.session.commit() ctx.session.commit()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.post('/post/(?P<post_id>[^/]+)/favorite/?') @rest.routes.post('/post/(?P<post_id>[^/]+)/favorite/?')
def add_post_to_favorites(ctx, params): def add_post_to_favorites(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:favorite') auth.verify_privilege(ctx.user, 'posts:favorite')
post = posts.get_post_by_id(params['post_id']) post = _get_post(params)
favorites.set_favorite(post, ctx.user) favorites.set_favorite(post, ctx.user)
ctx.session.commit() ctx.session.commit()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.delete('/post/(?P<post_id>[^/]+)/favorite/?') @rest.routes.delete('/post/(?P<post_id>[^/]+)/favorite/?')
def delete_post_from_favorites(ctx, params): def delete_post_from_favorites(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:favorite') auth.verify_privilege(ctx.user, 'posts:favorite')
post = posts.get_post_by_id(params['post_id']) post = _get_post(params)
favorites.unset_favorite(post, ctx.user) favorites.unset_favorite(post, ctx.user)
ctx.session.commit() ctx.session.commit()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@routes.get('/post/(?P<post_id>[^/]+)/around/?') @rest.routes.get('/post/(?P<post_id>[^/]+)/around/?')
def get_posts_around(ctx, params): def get_posts_around(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:list') auth.verify_privilege(ctx.user, 'posts:list')
_search_executor.config.user = ctx.user _search_executor_config.user = ctx.user
post_id = _get_post_id(params)
return _search_executor.get_around_and_serialize( return _search_executor.get_around_and_serialize(
ctx, params['post_id'], lambda post: _serialize_post(ctx, post)) ctx, post_id, lambda post: _serialize_post(ctx, post))
@routes.post('/posts/reverse-search/?') @rest.routes.post('/posts/reverse-search/?')
def get_posts_by_image(ctx, _params=None): def get_posts_by_image(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:reverse_search') auth.verify_privilege(ctx.user, 'posts:reverse_search')
content = ctx.get_file('content', required=True) content = ctx.get_file('content')
try: try:
lookalikes = posts.search_by_image(content) lookalikes = posts.search_by_image(content)

View file

@ -1,13 +1,14 @@
from szurubooru import search from typing import Dict
from szurubooru.rest import routes from szurubooru import search, rest
from szurubooru.func import auth, snapshots from szurubooru.func import auth, snapshots
_search_executor = search.Executor(search.configs.SnapshotSearchConfig()) _search_executor = search.Executor(search.configs.SnapshotSearchConfig())
@routes.get('/snapshots/?') @rest.routes.get('/snapshots/?')
def get_snapshots(ctx, _params=None): def get_snapshots(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'snapshots:list') auth.verify_privilege(ctx.user, 'snapshots:list')
return _search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, lambda snapshot: snapshots.serialize_snapshot(snapshot, ctx.user)) ctx, lambda snapshot: snapshots.serialize_snapshot(snapshot, ctx.user))

View file

@ -1,18 +1,22 @@
import datetime from typing import Optional, List, Dict
from szurubooru import db, search from datetime import datetime
from szurubooru.rest import routes from szurubooru import db, model, search, rest
from szurubooru.func import auth, tags, snapshots, util, versions from szurubooru.func import auth, tags, snapshots, serialization, versions
_search_executor = search.Executor(search.configs.TagSearchConfig()) _search_executor = search.Executor(search.configs.TagSearchConfig())
def _serialize(ctx, tag): def _serialize(ctx: rest.Context, tag: model.Tag) -> rest.Response:
return tags.serialize_tag( return tags.serialize_tag(
tag, options=util.get_serialization_options(ctx)) tag, options=serialization.get_serialization_options(ctx))
def _create_if_needed(tag_names, user): def _get_tag(params: Dict[str, str]) -> model.Tag:
return tags.get_tag_by_name(params['tag_name'])
def _create_if_needed(tag_names: List[str], user: model.User) -> None:
if not tag_names: if not tag_names:
return return
_existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names) _existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
@ -23,25 +27,22 @@ def _create_if_needed(tag_names, user):
snapshots.create(tag, user) snapshots.create(tag, user)
@routes.get('/tags/?') @rest.routes.get('/tags/?')
def get_tags(ctx, _params=None): def get_tags(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tags:list') auth.verify_privilege(ctx.user, 'tags:list')
return _search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, lambda tag: _serialize(ctx, tag)) ctx, lambda tag: _serialize(ctx, tag))
@routes.post('/tags/?') @rest.routes.post('/tags/?')
def create_tag(ctx, _params=None): def create_tag(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tags:create') auth.verify_privilege(ctx.user, 'tags:create')
names = ctx.get_param_as_list('names', required=True) names = ctx.get_param_as_list('names')
category = ctx.get_param_as_string('category', required=True) category = ctx.get_param_as_string('category')
description = ctx.get_param_as_string( description = ctx.get_param_as_string('description', default='')
'description', required=False, default=None) suggestions = ctx.get_param_as_list('suggestions', default=[])
suggestions = ctx.get_param_as_list( implications = ctx.get_param_as_list('implications', default=[])
'suggestions', required=False, default=[])
implications = ctx.get_param_as_list(
'implications', required=False, default=[])
_create_if_needed(suggestions, ctx.user) _create_if_needed(suggestions, ctx.user)
_create_if_needed(implications, ctx.user) _create_if_needed(implications, ctx.user)
@ -56,16 +57,16 @@ def create_tag(ctx, _params=None):
return _serialize(ctx, tag) return _serialize(ctx, tag)
@routes.get('/tag/(?P<tag_name>.+)') @rest.routes.get('/tag/(?P<tag_name>.+)')
def get_tag(ctx, params): def get_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'tags:view') auth.verify_privilege(ctx.user, 'tags:view')
tag = tags.get_tag_by_name(params['tag_name']) tag = _get_tag(params)
return _serialize(ctx, tag) return _serialize(ctx, tag)
@routes.put('/tag/(?P<tag_name>.+)') @rest.routes.put('/tag/(?P<tag_name>.+)')
def update_tag(ctx, params): def update_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
tag = tags.get_tag_by_name(params['tag_name']) tag = _get_tag(params)
versions.verify_version(tag, ctx) versions.verify_version(tag, ctx)
versions.bump_version(tag) versions.bump_version(tag)
if ctx.has_param('names'): if ctx.has_param('names'):
@ -78,7 +79,7 @@ def update_tag(ctx, params):
if ctx.has_param('description'): if ctx.has_param('description'):
auth.verify_privilege(ctx.user, 'tags:edit:description') auth.verify_privilege(ctx.user, 'tags:edit:description')
tags.update_tag_description( tags.update_tag_description(
tag, ctx.get_param_as_string('description', default=None)) tag, ctx.get_param_as_string('description'))
if ctx.has_param('suggestions'): if ctx.has_param('suggestions'):
auth.verify_privilege(ctx.user, 'tags:edit:suggestions') auth.verify_privilege(ctx.user, 'tags:edit:suggestions')
suggestions = ctx.get_param_as_list('suggestions') suggestions = ctx.get_param_as_list('suggestions')
@ -89,7 +90,7 @@ def update_tag(ctx, params):
implications = ctx.get_param_as_list('implications') implications = ctx.get_param_as_list('implications')
_create_if_needed(implications, ctx.user) _create_if_needed(implications, ctx.user)
tags.update_tag_implications(tag, implications) tags.update_tag_implications(tag, implications)
tag.last_edit_time = datetime.datetime.utcnow() tag.last_edit_time = datetime.utcnow()
ctx.session.flush() ctx.session.flush()
snapshots.modify(tag, ctx.user) snapshots.modify(tag, ctx.user)
ctx.session.commit() ctx.session.commit()
@ -97,9 +98,9 @@ def update_tag(ctx, params):
return _serialize(ctx, tag) return _serialize(ctx, tag)
@routes.delete('/tag/(?P<tag_name>.+)') @rest.routes.delete('/tag/(?P<tag_name>.+)')
def delete_tag(ctx, params): def delete_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
tag = tags.get_tag_by_name(params['tag_name']) tag = _get_tag(params)
versions.verify_version(tag, ctx) versions.verify_version(tag, ctx)
auth.verify_privilege(ctx.user, 'tags:delete') auth.verify_privilege(ctx.user, 'tags:delete')
snapshots.delete(tag, ctx.user) snapshots.delete(tag, ctx.user)
@ -109,10 +110,11 @@ def delete_tag(ctx, params):
return {} return {}
@routes.post('/tag-merge/?') @rest.routes.post('/tag-merge/?')
def merge_tags(ctx, _params=None): def merge_tags(
source_tag_name = ctx.get_param_as_string('remove', required=True) or '' ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
target_tag_name = ctx.get_param_as_string('mergeTo', required=True) or '' source_tag_name = ctx.get_param_as_string('remove')
target_tag_name = ctx.get_param_as_string('mergeTo')
source_tag = tags.get_tag_by_name(source_tag_name) source_tag = tags.get_tag_by_name(source_tag_name)
target_tag = tags.get_tag_by_name(target_tag_name) target_tag = tags.get_tag_by_name(target_tag_name)
versions.verify_version(source_tag, ctx, 'removeVersion') versions.verify_version(source_tag, ctx, 'removeVersion')
@ -126,10 +128,11 @@ def merge_tags(ctx, _params=None):
return _serialize(ctx, target_tag) return _serialize(ctx, target_tag)
@routes.get('/tag-siblings/(?P<tag_name>.+)') @rest.routes.get('/tag-siblings/(?P<tag_name>.+)')
def get_tag_siblings(ctx, params): def get_tag_siblings(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'tags:view') auth.verify_privilege(ctx.user, 'tags:view')
tag = tags.get_tag_by_name(params['tag_name']) tag = _get_tag(params)
result = tags.get_tag_siblings(tag) result = tags.get_tag_siblings(tag)
serialized_siblings = [] serialized_siblings = []
for sibling, occurrences in result: for sibling, occurrences in result:

View file

@ -1,15 +1,18 @@
from szurubooru.rest import routes from typing import Dict
from szurubooru import model, rest
from szurubooru.func import ( from szurubooru.func import (
auth, tags, tag_categories, snapshots, util, versions) auth, tags, tag_categories, snapshots, serialization, versions)
def _serialize(ctx, category): def _serialize(
ctx: rest.Context, category: model.TagCategory) -> rest.Response:
return tag_categories.serialize_category( return tag_categories.serialize_category(
category, options=util.get_serialization_options(ctx)) category, options=serialization.get_serialization_options(ctx))
@routes.get('/tag-categories/?') @rest.routes.get('/tag-categories/?')
def get_tag_categories(ctx, _params=None): def get_tag_categories(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tag_categories:list') auth.verify_privilege(ctx.user, 'tag_categories:list')
categories = tag_categories.get_all_categories() categories = tag_categories.get_all_categories()
return { return {
@ -17,11 +20,12 @@ def get_tag_categories(ctx, _params=None):
} }
@routes.post('/tag-categories/?') @rest.routes.post('/tag-categories/?')
def create_tag_category(ctx, _params=None): def create_tag_category(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tag_categories:create') auth.verify_privilege(ctx.user, 'tag_categories:create')
name = ctx.get_param_as_string('name', required=True) name = ctx.get_param_as_string('name')
color = ctx.get_param_as_string('color', required=True) color = ctx.get_param_as_string('color')
category = tag_categories.create_category(name, color) category = tag_categories.create_category(name, color)
ctx.session.add(category) ctx.session.add(category)
ctx.session.flush() ctx.session.flush()
@ -31,15 +35,17 @@ def create_tag_category(ctx, _params=None):
return _serialize(ctx, category) return _serialize(ctx, category)
@routes.get('/tag-category/(?P<category_name>[^/]+)/?') @rest.routes.get('/tag-category/(?P<category_name>[^/]+)/?')
def get_tag_category(ctx, params): def get_tag_category(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'tag_categories:view') auth.verify_privilege(ctx.user, 'tag_categories:view')
category = tag_categories.get_category_by_name(params['category_name']) category = tag_categories.get_category_by_name(params['category_name'])
return _serialize(ctx, category) return _serialize(ctx, category)
@routes.put('/tag-category/(?P<category_name>[^/]+)/?') @rest.routes.put('/tag-category/(?P<category_name>[^/]+)/?')
def update_tag_category(ctx, params): def update_tag_category(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
category = tag_categories.get_category_by_name( category = tag_categories.get_category_by_name(
params['category_name'], lock=True) params['category_name'], lock=True)
versions.verify_version(category, ctx) versions.verify_version(category, ctx)
@ -59,8 +65,9 @@ def update_tag_category(ctx, params):
return _serialize(ctx, category) return _serialize(ctx, category)
@routes.delete('/tag-category/(?P<category_name>[^/]+)/?') @rest.routes.delete('/tag-category/(?P<category_name>[^/]+)/?')
def delete_tag_category(ctx, params): def delete_tag_category(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
category = tag_categories.get_category_by_name( category = tag_categories.get_category_by_name(
params['category_name'], lock=True) params['category_name'], lock=True)
versions.verify_version(category, ctx) versions.verify_version(category, ctx)
@ -72,8 +79,9 @@ def delete_tag_category(ctx, params):
return {} return {}
@routes.put('/tag-category/(?P<category_name>[^/]+)/default/?') @rest.routes.put('/tag-category/(?P<category_name>[^/]+)/default/?')
def set_tag_category_as_default(ctx, params): def set_tag_category_as_default(
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, 'tag_categories:set_default') auth.verify_privilege(ctx.user, 'tag_categories:set_default')
category = tag_categories.get_category_by_name( category = tag_categories.get_category_by_name(
params['category_name'], lock=True) params['category_name'], lock=True)

View file

@ -1,10 +1,12 @@
from szurubooru.rest import routes from typing import Dict
from szurubooru import rest
from szurubooru.func import auth, file_uploads from szurubooru.func import auth, file_uploads
@routes.post('/uploads/?') @rest.routes.post('/uploads/?')
def create_temporary_file(ctx, _params=None): def create_temporary_file(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'uploads:create') auth.verify_privilege(ctx.user, 'uploads:create')
content = ctx.get_file('content', required=True, allow_tokens=False) content = ctx.get_file('content', allow_tokens=False)
token = file_uploads.save(content) token = file_uploads.save(content)
return {'token': token} return {'token': token}

View file

@ -1,56 +1,57 @@
from szurubooru import search from typing import Any, Dict
from szurubooru.rest import routes from szurubooru import model, search, rest
from szurubooru.func import auth, users, util, versions from szurubooru.func import auth, users, serialization, versions
_search_executor = search.Executor(search.configs.UserSearchConfig()) _search_executor = search.Executor(search.configs.UserSearchConfig())
def _serialize(ctx, user, **kwargs): def _serialize(
ctx: rest.Context, user: model.User, **kwargs: Any) -> rest.Response:
return users.serialize_user( return users.serialize_user(
user, user,
ctx.user, ctx.user,
options=util.get_serialization_options(ctx), options=serialization.get_serialization_options(ctx),
**kwargs) **kwargs)
@routes.get('/users/?') @rest.routes.get('/users/?')
def get_users(ctx, _params=None): def get_users(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'users:list') auth.verify_privilege(ctx.user, 'users:list')
return _search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, lambda user: _serialize(ctx, user)) ctx, lambda user: _serialize(ctx, user))
@routes.post('/users/?') @rest.routes.post('/users/?')
def create_user(ctx, _params=None): def create_user(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
auth.verify_privilege(ctx.user, 'users:create') auth.verify_privilege(ctx.user, 'users:create')
name = ctx.get_param_as_string('name', required=True) name = ctx.get_param_as_string('name')
password = ctx.get_param_as_string('password', required=True) password = ctx.get_param_as_string('password')
email = ctx.get_param_as_string('email', required=False, default='') email = ctx.get_param_as_string('email', default='')
user = users.create_user(name, password, email) user = users.create_user(name, password, email)
if ctx.has_param('rank'): if ctx.has_param('rank'):
users.update_user_rank( users.update_user_rank(user, ctx.get_param_as_string('rank'), ctx.user)
user, ctx.get_param_as_string('rank'), ctx.user)
if ctx.has_param('avatarStyle'): if ctx.has_param('avatarStyle'):
users.update_user_avatar( users.update_user_avatar(
user, user,
ctx.get_param_as_string('avatarStyle'), ctx.get_param_as_string('avatarStyle'),
ctx.get_file('avatar')) ctx.get_file('avatar', default=b''))
ctx.session.add(user) ctx.session.add(user)
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, user, force_show_email=True) return _serialize(ctx, user, force_show_email=True)
@routes.get('/user/(?P<user_name>[^/]+)/?') @rest.routes.get('/user/(?P<user_name>[^/]+)/?')
def get_user(ctx, params): def get_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
user = users.get_user_by_name(params['user_name']) user = users.get_user_by_name(params['user_name'])
if ctx.user.user_id != user.user_id: if ctx.user.user_id != user.user_id:
auth.verify_privilege(ctx.user, 'users:view') auth.verify_privilege(ctx.user, 'users:view')
return _serialize(ctx, user) return _serialize(ctx, user)
@routes.put('/user/(?P<user_name>[^/]+)/?') @rest.routes.put('/user/(?P<user_name>[^/]+)/?')
def update_user(ctx, params): def update_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
user = users.get_user_by_name(params['user_name']) user = users.get_user_by_name(params['user_name'])
versions.verify_version(user, ctx) versions.verify_version(user, ctx)
versions.bump_version(user) versions.bump_version(user)
@ -74,13 +75,13 @@ def update_user(ctx, params):
users.update_user_avatar( users.update_user_avatar(
user, user,
ctx.get_param_as_string('avatarStyle'), ctx.get_param_as_string('avatarStyle'),
ctx.get_file('avatar')) ctx.get_file('avatar', default=b''))
ctx.session.commit() ctx.session.commit()
return _serialize(ctx, user) return _serialize(ctx, user)
@routes.delete('/user/(?P<user_name>[^/]+)/?') @rest.routes.delete('/user/(?P<user_name>[^/]+)/?')
def delete_user(ctx, params): def delete_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
user = users.get_user_by_name(params['user_name']) user = users.get_user_by_name(params['user_name'])
versions.verify_version(user, ctx) versions.verify_version(user, ctx)
infix = 'self' if ctx.user.user_id == user.user_id else 'any' infix = 'self' if ctx.user.user_id == user.user_id else 'any'

View file

@ -1,8 +1,9 @@
from typing import Dict
import os import os
import yaml import yaml
def merge(left, right): def merge(left: Dict, right: Dict) -> Dict:
for key in right: for key in right:
if key in left: if key in left:
if isinstance(left[key], dict) and isinstance(right[key], dict): if isinstance(left[key], dict) and isinstance(right[key], dict):
@ -14,7 +15,7 @@ def merge(left, right):
return left return left
def read_config(): def read_config() -> Dict:
with open('../config.yaml.dist') as handle: with open('../config.yaml.dist') as handle:
ret = yaml.load(handle.read()) ret = yaml.load(handle.read())
if os.path.exists('../config.yaml'): if os.path.exists('../config.yaml'):

36
server/szurubooru/db.py Normal file
View file

@ -0,0 +1,36 @@
from typing import Any
import threading
import sqlalchemy as sa
import sqlalchemy.orm
from szurubooru import config
# pylint: disable=invalid-name
_data = threading.local()
_engine = sa.create_engine(config.config['database']) # type: Any
sessionmaker = sa.orm.sessionmaker(bind=_engine, autoflush=False) # type: Any
session = sa.orm.scoped_session(sessionmaker) # type: Any
def get_session() -> Any:
global session
return session
def set_sesssion(new_session: Any) -> None:
global session
session = new_session
def reset_query_count() -> None:
_data.query_count = 0
def get_query_count() -> int:
return _data.query_count
def _bump_query_count() -> None:
_data.query_count = getattr(_data, 'query_count', 0) + 1
sa.event.listen(_engine, 'after_execute', lambda *args: _bump_query_count())

View file

@ -1,17 +0,0 @@
from szurubooru.db.base import Base
from szurubooru.db.user import User
from szurubooru.db.tag_category import TagCategory
from szurubooru.db.tag import (Tag, TagName, TagSuggestion, TagImplication)
from szurubooru.db.post import (
Post,
PostTag,
PostRelation,
PostFavorite,
PostScore,
PostNote,
PostFeature)
from szurubooru.db.comment import (Comment, CommentScore)
from szurubooru.db.snapshot import Snapshot
from szurubooru.db.session import (
session, sessionmaker, reset_query_count, get_query_count)
import szurubooru.db.util

View file

@ -1,27 +0,0 @@
import threading
import sqlalchemy
from szurubooru import config
# pylint: disable=invalid-name
_engine = sqlalchemy.create_engine(config.config['database'])
sessionmaker = sqlalchemy.orm.sessionmaker(bind=_engine, autoflush=False)
session = sqlalchemy.orm.scoped_session(sessionmaker)
_data = threading.local()
def reset_query_count():
_data.query_count = 0
def get_query_count():
return _data.query_count
def _bump_query_count():
_data.query_count = getattr(_data, 'query_count', 0) + 1
sqlalchemy.event.listen(
_engine, 'after_execute', lambda *args: _bump_query_count())

View file

@ -1,34 +0,0 @@
from sqlalchemy.inspection import inspect
def get_resource_info(entity):
serializers = {
'tag': lambda tag: tag.first_name,
'tag_category': lambda category: category.name,
'comment': lambda comment: comment.comment_id,
'post': lambda post: post.post_id,
}
resource_type = entity.__table__.name
assert resource_type in serializers
primary_key = inspect(entity).identity
assert primary_key is not None
assert len(primary_key) == 1
resource_name = serializers[resource_type](entity)
assert resource_name
resource_pkey = primary_key[0]
assert resource_pkey
return (resource_type, resource_pkey, resource_name)
def get_aux_entity(session, get_table_info, entity, user):
table, get_column = get_table_info(entity)
return session \
.query(table) \
.filter(get_column(table) == get_column(entity)) \
.filter(table.user_id == user.user_id) \
.one_or_none()

View file

@ -1,5 +1,11 @@
from typing import Dict
class BaseError(RuntimeError): class BaseError(RuntimeError):
def __init__(self, message='Unknown error', extra_fields=None): def __init__(
self,
message: str='Unknown error',
extra_fields: Dict[str, str]=None) -> None:
super().__init__(message) super().__init__(message)
self.extra_fields = extra_fields self.extra_fields = extra_fields

View file

@ -2,7 +2,10 @@ import os
import time import time
import logging import logging
import threading import threading
from typing import Callable, Any, Type
import coloredlogs import coloredlogs
import sqlalchemy as sa
import sqlalchemy.orm.exc import sqlalchemy.orm.exc
from szurubooru import config, db, errors, rest from szurubooru import config, db, errors, rest
from szurubooru.func import posts, file_uploads from szurubooru.func import posts, file_uploads
@ -10,7 +13,10 @@ from szurubooru.func import posts, file_uploads
from szurubooru import api, middleware from szurubooru import api, middleware
def _map_error(ex, target_class, title): def _map_error(
ex: Exception,
target_class: Type[rest.errors.BaseHttpError],
title: str) -> rest.errors.BaseHttpError:
return target_class( return target_class(
name=type(ex).__name__, name=type(ex).__name__,
title=title, title=title,
@ -18,38 +24,38 @@ def _map_error(ex, target_class, title):
extra_fields=getattr(ex, 'extra_fields', {})) extra_fields=getattr(ex, 'extra_fields', {}))
def _on_auth_error(ex): def _on_auth_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpForbidden, 'Authentication error') raise _map_error(ex, rest.errors.HttpForbidden, 'Authentication error')
def _on_validation_error(ex): def _on_validation_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpBadRequest, 'Validation error') raise _map_error(ex, rest.errors.HttpBadRequest, 'Validation error')
def _on_search_error(ex): def _on_search_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpBadRequest, 'Search error') raise _map_error(ex, rest.errors.HttpBadRequest, 'Search error')
def _on_integrity_error(ex): def _on_integrity_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpConflict, 'Integrity violation') raise _map_error(ex, rest.errors.HttpConflict, 'Integrity violation')
def _on_not_found_error(ex): def _on_not_found_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpNotFound, 'Not found') raise _map_error(ex, rest.errors.HttpNotFound, 'Not found')
def _on_processing_error(ex): def _on_processing_error(ex: Exception) -> None:
raise _map_error(ex, rest.errors.HttpBadRequest, 'Processing error') raise _map_error(ex, rest.errors.HttpBadRequest, 'Processing error')
def _on_third_party_error(ex): def _on_third_party_error(ex: Exception) -> None:
raise _map_error( raise _map_error(
ex, ex,
rest.errors.HttpInternalServerError, rest.errors.HttpInternalServerError,
'Server configuration error') 'Server configuration error')
def _on_stale_data_error(_ex): def _on_stale_data_error(_ex: Exception) -> None:
raise rest.errors.HttpConflict( raise rest.errors.HttpConflict(
name='IntegrityError', name='IntegrityError',
title='Integrity violation', title='Integrity violation',
@ -58,7 +64,7 @@ def _on_stale_data_error(_ex):
'Please try again.')) 'Please try again.'))
def validate_config(): def validate_config() -> None:
''' '''
Check whether config doesn't contain errors that might prove Check whether config doesn't contain errors that might prove
lethal at runtime. lethal at runtime.
@ -86,7 +92,7 @@ def validate_config():
raise errors.ConfigError('Database is not configured') raise errors.ConfigError('Database is not configured')
def purge_old_uploads(): def purge_old_uploads() -> None:
while True: while True:
try: try:
file_uploads.purge_old_uploads() file_uploads.purge_old_uploads()
@ -95,7 +101,7 @@ def purge_old_uploads():
time.sleep(60 * 5) time.sleep(60 * 5)
def create_app(): def create_app() -> Callable[[Any, Any], Any]:
''' Create a WSGI compatible App object. ''' ''' Create a WSGI compatible App object. '''
validate_config() validate_config()
coloredlogs.install(fmt='[%(asctime)-15s] %(name)s %(message)s') coloredlogs.install(fmt='[%(asctime)-15s] %(name)s %(message)s')
@ -122,7 +128,7 @@ def create_app():
rest.errors.handle(errors.NotFoundError, _on_not_found_error) rest.errors.handle(errors.NotFoundError, _on_not_found_error)
rest.errors.handle(errors.ProcessingError, _on_processing_error) rest.errors.handle(errors.ProcessingError, _on_processing_error)
rest.errors.handle(errors.ThirdPartyError, _on_third_party_error) rest.errors.handle(errors.ThirdPartyError, _on_third_party_error)
rest.errors.handle(sqlalchemy.orm.exc.StaleDataError, _on_stale_data_error) rest.errors.handle(sa.orm.exc.StaleDataError, _on_stale_data_error)
return rest.application return rest.application

View file

@ -1,22 +1,22 @@
import hashlib import hashlib
import random import random
from collections import OrderedDict from collections import OrderedDict
from szurubooru import config, db, errors from szurubooru import config, model, errors
from szurubooru.func import util from szurubooru.func import util
RANK_MAP = OrderedDict([ RANK_MAP = OrderedDict([
(db.User.RANK_ANONYMOUS, 'anonymous'), (model.User.RANK_ANONYMOUS, 'anonymous'),
(db.User.RANK_RESTRICTED, 'restricted'), (model.User.RANK_RESTRICTED, 'restricted'),
(db.User.RANK_REGULAR, 'regular'), (model.User.RANK_REGULAR, 'regular'),
(db.User.RANK_POWER, 'power'), (model.User.RANK_POWER, 'power'),
(db.User.RANK_MODERATOR, 'moderator'), (model.User.RANK_MODERATOR, 'moderator'),
(db.User.RANK_ADMINISTRATOR, 'administrator'), (model.User.RANK_ADMINISTRATOR, 'administrator'),
(db.User.RANK_NOBODY, 'nobody'), (model.User.RANK_NOBODY, 'nobody'),
]) ])
def get_password_hash(salt, password): def get_password_hash(salt: str, password: str) -> str:
''' Retrieve new-style password hash. ''' ''' Retrieve new-style password hash. '''
digest = hashlib.sha256() digest = hashlib.sha256()
digest.update(config.config['secret'].encode('utf8')) digest.update(config.config['secret'].encode('utf8'))
@ -25,7 +25,7 @@ def get_password_hash(salt, password):
return digest.hexdigest() return digest.hexdigest()
def get_legacy_password_hash(salt, password): def get_legacy_password_hash(salt: str, password: str) -> str:
''' Retrieve old-style password hash. ''' ''' Retrieve old-style password hash. '''
digest = hashlib.sha1() digest = hashlib.sha1()
digest.update(b'1A2/$_4xVa') digest.update(b'1A2/$_4xVa')
@ -34,7 +34,7 @@ def get_legacy_password_hash(salt, password):
return digest.hexdigest() return digest.hexdigest()
def create_password(): def create_password() -> str:
alphabet = { alphabet = {
'c': list('bcdfghijklmnpqrstvwxyz'), 'c': list('bcdfghijklmnpqrstvwxyz'),
'v': list('aeiou'), 'v': list('aeiou'),
@ -44,7 +44,7 @@ def create_password():
return ''.join(random.choice(alphabet[l]) for l in list(pattern)) return ''.join(random.choice(alphabet[l]) for l in list(pattern))
def is_valid_password(user, password): def is_valid_password(user: model.User, password: str) -> bool:
assert user assert user
salt, valid_hash = user.password_salt, user.password_hash salt, valid_hash = user.password_salt, user.password_hash
possible_hashes = [ possible_hashes = [
@ -54,7 +54,7 @@ def is_valid_password(user, password):
return valid_hash in possible_hashes return valid_hash in possible_hashes
def has_privilege(user, privilege_name): def has_privilege(user: model.User, privilege_name: str) -> bool:
assert user assert user
all_ranks = list(RANK_MAP.keys()) all_ranks = list(RANK_MAP.keys())
assert privilege_name in config.config['privileges'] assert privilege_name in config.config['privileges']
@ -65,13 +65,13 @@ def has_privilege(user, privilege_name):
return user.rank in good_ranks return user.rank in good_ranks
def verify_privilege(user, privilege_name): def verify_privilege(user: model.User, privilege_name: str) -> None:
assert user assert user
if not has_privilege(user, privilege_name): if not has_privilege(user, privilege_name):
raise errors.AuthError('Insufficient privileges to do this.') raise errors.AuthError('Insufficient privileges to do this.')
def generate_authentication_token(user): def generate_authentication_token(user: model.User) -> str:
''' Generate nonguessable challenge (e.g. links in password reminder). ''' ''' Generate nonguessable challenge (e.g. links in password reminder). '''
assert user assert user
digest = hashlib.md5() digest = hashlib.md5()

View file

@ -1,21 +1,21 @@
from typing import Any, List, Dict
from datetime import datetime from datetime import datetime
class LruCacheItem: class LruCacheItem:
def __init__(self, key, value): def __init__(self, key: object, value: Any) -> None:
self.key = key self.key = key
self.value = value self.value = value
self.timestamp = datetime.utcnow() self.timestamp = datetime.utcnow()
class LruCache: class LruCache:
def __init__(self, length, delta=None): def __init__(self, length: int) -> None:
self.length = length self.length = length
self.delta = delta self.hash = {} # type: Dict[object, LruCacheItem]
self.hash = {} self.item_list = [] # type: List[LruCacheItem]
self.item_list = []
def insert_item(self, item): def insert_item(self, item: LruCacheItem) -> None:
if item.key in self.hash: if item.key in self.hash:
item_index = next( item_index = next(
i i
@ -31,11 +31,11 @@ class LruCache:
self.hash[item.key] = item self.hash[item.key] = item
self.item_list.insert(0, item) self.item_list.insert(0, item)
def remove_all(self): def remove_all(self) -> None:
self.hash = {} self.hash = {}
self.item_list = [] self.item_list = []
def remove_item(self, item): def remove_item(self, item: LruCacheItem) -> None:
del self.hash[item.key] del self.hash[item.key]
del self.item_list[self.item_list.index(item)] del self.item_list[self.item_list.index(item)]
@ -43,22 +43,22 @@ class LruCache:
_CACHE = LruCache(length=100) _CACHE = LruCache(length=100)
def purge(): def purge() -> None:
_CACHE.remove_all() _CACHE.remove_all()
def has(key): def has(key: object) -> bool:
return key in _CACHE.hash return key in _CACHE.hash
def get(key): def get(key: object) -> Any:
return _CACHE.hash[key].value return _CACHE.hash[key].value
def remove(key): def remove(key: object) -> None:
if has(key): if has(key):
del _CACHE.hash[key] del _CACHE.hash[key]
def put(key, value): def put(key: object, value: Any) -> None:
_CACHE.insert_item(LruCacheItem(key, value)) _CACHE.insert_item(LruCacheItem(key, value))

View file

@ -1,6 +1,7 @@
import datetime from datetime import datetime
from szurubooru import db, errors from typing import Any, Optional, List, Dict, Callable
from szurubooru.func import users, scores, util from szurubooru import db, model, errors, rest
from szurubooru.func import users, scores, util, serialization
class InvalidCommentIdError(errors.ValidationError): class InvalidCommentIdError(errors.ValidationError):
@ -15,52 +16,87 @@ class EmptyCommentTextError(errors.ValidationError):
pass pass
def serialize_comment(comment, auth_user, options=None): class CommentSerializer(serialization.BaseSerializer):
return util.serialize_entity( def __init__(self, comment: model.Comment, auth_user: model.User) -> None:
comment, self.comment = comment
{ self.auth_user = auth_user
'id': lambda: comment.comment_id,
'user': def _serializers(self) -> Dict[str, Callable[[], Any]]:
lambda: users.serialize_micro_user(comment.user, auth_user), return {
'postId': lambda: comment.post.post_id, 'id': self.serialize_id,
'version': lambda: comment.version, 'user': self.serialize_user,
'text': lambda: comment.text, 'postId': self.serialize_post_id,
'creationTime': lambda: comment.creation_time, 'version': self.serialize_version,
'lastEditTime': lambda: comment.last_edit_time, 'text': self.serialize_text,
'score': lambda: comment.score, 'creationTime': self.serialize_creation_time,
'ownScore': lambda: scores.get_score(comment, auth_user), 'lastEditTime': self.serialize_last_edit_time,
}, 'score': self.serialize_score,
options) 'ownScore': self.serialize_own_score,
}
def serialize_id(self) -> Any:
return self.comment.comment_id
def serialize_user(self) -> Any:
return users.serialize_micro_user(self.comment.user, self.auth_user)
def serialize_post_id(self) -> Any:
return self.comment.post.post_id
def serialize_version(self) -> Any:
return self.comment.version
def serialize_text(self) -> Any:
return self.comment.text
def serialize_creation_time(self) -> Any:
return self.comment.creation_time
def serialize_last_edit_time(self) -> Any:
return self.comment.last_edit_time
def serialize_score(self) -> Any:
return self.comment.score
def serialize_own_score(self) -> Any:
return scores.get_score(self.comment, self.auth_user)
def try_get_comment_by_id(comment_id): def serialize_comment(
try: comment: model.Comment,
auth_user: model.User,
options: List[str]=[]) -> rest.Response:
if comment is None:
return None
return CommentSerializer(comment, auth_user).serialize(options)
def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]:
comment_id = int(comment_id) comment_id = int(comment_id)
except ValueError:
raise InvalidCommentIdError('Invalid comment ID: %r.' % comment_id)
return db.session \ return db.session \
.query(db.Comment) \ .query(model.Comment) \
.filter(db.Comment.comment_id == comment_id) \ .filter(model.Comment.comment_id == comment_id) \
.one_or_none() .one_or_none()
def get_comment_by_id(comment_id): def get_comment_by_id(comment_id: int) -> model.Comment:
comment = try_get_comment_by_id(comment_id) comment = try_get_comment_by_id(comment_id)
if comment: if comment:
return comment return comment
raise CommentNotFoundError('Comment %r not found.' % comment_id) raise CommentNotFoundError('Comment %r not found.' % comment_id)
def create_comment(user, post, text): def create_comment(
comment = db.Comment() user: model.User, post: model.Post, text: str) -> model.Comment:
comment = model.Comment()
comment.user = user comment.user = user
comment.post = post comment.post = post
update_comment_text(comment, text) update_comment_text(comment, text)
comment.creation_time = datetime.datetime.utcnow() comment.creation_time = datetime.utcnow()
return comment return comment
def update_comment_text(comment, text): def update_comment_text(comment: model.Comment, text: str) -> None:
assert comment assert comment
if not text: if not text:
raise EmptyCommentTextError('Comment text cannot be empty.') raise EmptyCommentTextError('Comment text cannot be empty.')

View file

@ -1,21 +1,26 @@
def get_list_diff(old, new): from typing import List, Dict, Any
value = {'type': 'list change', 'added': [], 'removed': []}
def get_list_diff(old: List[Any], new: List[Any]) -> Any:
equal = True equal = True
removed = [] # type: List[Any]
added = [] # type: List[Any]
for item in old: for item in old:
if item not in new: if item not in new:
equal = False equal = False
value['removed'].append(item) removed.append(item)
for item in new: for item in new:
if item not in old: if item not in old:
equal = False equal = False
value['added'].append(item) added.append(item)
return None if equal else value return None if equal else {
'type': 'list change', 'added': added, 'removed': removed}
def get_dict_diff(old, new): def get_dict_diff(old: Dict[str, Any], new: Dict[str, Any]) -> Any:
value = {} value = {}
equal = True equal = True

View file

@ -1,32 +1,34 @@
import datetime from typing import Any, Optional, Callable, Tuple
from szurubooru import db, errors from datetime import datetime
from szurubooru import db, model, errors
class InvalidFavoriteTargetError(errors.ValidationError): class InvalidFavoriteTargetError(errors.ValidationError):
pass pass
def _get_table_info(entity): def _get_table_info(
entity: model.Base) -> Tuple[model.Base, Callable[[model.Base], Any]]:
assert entity assert entity
resource_type, _, _ = db.util.get_resource_info(entity) resource_type, _, _ = model.util.get_resource_info(entity)
if resource_type == 'post': if resource_type == 'post':
return db.PostFavorite, lambda table: table.post_id return model.PostFavorite, lambda table: table.post_id
raise InvalidFavoriteTargetError() raise InvalidFavoriteTargetError()
def _get_fav_entity(entity, user): def _get_fav_entity(entity: model.Base, user: model.User) -> model.Base:
assert entity assert entity
assert user assert user
return db.util.get_aux_entity(db.session, _get_table_info, entity, user) return model.util.get_aux_entity(db.session, _get_table_info, entity, user)
def has_favorited(entity, user): def has_favorited(entity: model.Base, user: model.User) -> bool:
assert entity assert entity
assert user assert user
return _get_fav_entity(entity, user) is not None return _get_fav_entity(entity, user) is not None
def unset_favorite(entity, user): def unset_favorite(entity: model.Base, user: Optional[model.User]) -> None:
assert entity assert entity
assert user assert user
fav_entity = _get_fav_entity(entity, user) fav_entity = _get_fav_entity(entity, user)
@ -34,7 +36,7 @@ def unset_favorite(entity, user):
db.session.delete(fav_entity) db.session.delete(fav_entity)
def set_favorite(entity, user): def set_favorite(entity: model.Base, user: Optional[model.User]) -> None:
from szurubooru.func import scores from szurubooru.func import scores
assert entity assert entity
assert user assert user
@ -48,5 +50,5 @@ def set_favorite(entity, user):
fav_entity = table() fav_entity = table()
setattr(fav_entity, get_column(table).name, get_column(entity)) setattr(fav_entity, get_column(table).name, get_column(entity))
fav_entity.user = user fav_entity.user = user
fav_entity.time = datetime.datetime.utcnow() fav_entity.time = datetime.utcnow()
db.session.add(fav_entity) db.session.add(fav_entity)

View file

@ -1,27 +1,28 @@
import datetime from typing import Optional
from datetime import datetime, timedelta
from szurubooru.func import files, util from szurubooru.func import files, util
MAX_MINUTES = 60 MAX_MINUTES = 60
def _get_path(checksum): def _get_path(checksum: str) -> str:
return 'temporary-uploads/%s.dat' % checksum return 'temporary-uploads/%s.dat' % checksum
def purge_old_uploads(): def purge_old_uploads() -> None:
now = datetime.datetime.now() now = datetime.now()
for file in files.scan('temporary-uploads'): for file in files.scan('temporary-uploads'):
file_time = datetime.datetime.fromtimestamp(file.stat().st_ctime) file_time = datetime.fromtimestamp(file.stat().st_ctime)
if now - file_time > datetime.timedelta(minutes=MAX_MINUTES): if now - file_time > timedelta(minutes=MAX_MINUTES):
files.delete('temporary-uploads/%s' % file.name) files.delete('temporary-uploads/%s' % file.name)
def get(checksum): def get(checksum: str) -> Optional[bytes]:
return files.get('temporary-uploads/%s.dat' % checksum) return files.get('temporary-uploads/%s.dat' % checksum)
def save(content): def save(content: bytes) -> str:
checksum = util.get_sha1(content) checksum = util.get_sha1(content)
path = _get_path(checksum) path = _get_path(checksum)
if not files.has(path): if not files.has(path):

View file

@ -1,32 +1,33 @@
from typing import Any, Optional, List
import os import os
from szurubooru import config from szurubooru import config
def _get_full_path(path): def _get_full_path(path: str) -> str:
return os.path.join(config.config['data_dir'], path) return os.path.join(config.config['data_dir'], path)
def delete(path): def delete(path: str) -> None:
full_path = _get_full_path(path) full_path = _get_full_path(path)
if os.path.exists(full_path): if os.path.exists(full_path):
os.unlink(full_path) os.unlink(full_path)
def has(path): def has(path: str) -> bool:
return os.path.exists(_get_full_path(path)) return os.path.exists(_get_full_path(path))
def scan(path): def scan(path: str) -> List[os.DirEntry]:
if has(path): if has(path):
return os.scandir(_get_full_path(path)) return list(os.scandir(_get_full_path(path)))
return [] return []
def move(source_path, target_path): def move(source_path: str, target_path: str) -> None:
return os.rename(_get_full_path(source_path), _get_full_path(target_path)) os.rename(_get_full_path(source_path), _get_full_path(target_path))
def get(path): def get(path: str) -> Optional[bytes]:
full_path = _get_full_path(path) full_path = _get_full_path(path)
if not os.path.exists(full_path): if not os.path.exists(full_path):
return None return None
@ -34,7 +35,7 @@ def get(path):
return handle.read() return handle.read()
def save(path, content): def save(path: str, content: bytes) -> None:
full_path = _get_full_path(path) full_path = _get_full_path(path)
os.makedirs(os.path.dirname(full_path), exist_ok=True) os.makedirs(os.path.dirname(full_path), exist_ok=True)
with open(full_path, 'wb') as handle: with open(full_path, 'wb') as handle:

View file

@ -1,6 +1,7 @@
import logging import logging
from io import BytesIO from io import BytesIO
from datetime import datetime from datetime import datetime
from typing import Any, Optional, Tuple, Set, List, Callable
import elasticsearch import elasticsearch
import elasticsearch_dsl import elasticsearch_dsl
import numpy as np import numpy as np
@ -10,13 +11,8 @@ from szurubooru import config, errors
# pylint: disable=invalid-name # pylint: disable=invalid-name
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
es = elasticsearch.Elasticsearch([{
'host': config.config['elasticsearch']['host'],
'port': config.config['elasticsearch']['port'],
}])
# Math based on paper from H. Chi Wong, Marshall Bern and David Goldberg
# Math based on paper from H. Chi Wong, Marshall Bern and David Goldber
# Math code taken from https://github.com/ascribe/image-match # Math code taken from https://github.com/ascribe/image-match
# (which is licensed under Apache 2 license) # (which is licensed under Apache 2 license)
@ -32,14 +28,27 @@ MAX_WORDS = 63
ES_DOC_TYPE = 'image' ES_DOC_TYPE = 'image'
ES_MAX_RESULTS = 100 ES_MAX_RESULTS = 100
Window = Tuple[Tuple[float, float], Tuple[float, float]]
NpMatrix = Any
def _preprocess_image(image_or_path):
img = Image.open(BytesIO(image_or_path)) def _get_session() -> elasticsearch.Elasticsearch:
return elasticsearch.Elasticsearch([{
'host': config.config['elasticsearch']['host'],
'port': config.config['elasticsearch']['port'],
}])
def _preprocess_image(content: bytes) -> NpMatrix:
img = Image.open(BytesIO(content))
img = img.convert('RGB') img = img.convert('RGB')
return rgb2gray(np.asarray(img, dtype=np.uint8)) return rgb2gray(np.asarray(img, dtype=np.uint8))
def _crop_image(image, lower_percentile, upper_percentile): def _crop_image(
image: NpMatrix,
lower_percentile: float,
upper_percentile: float) -> Window:
rw = np.cumsum(np.sum(np.abs(np.diff(image, axis=1)), axis=1)) rw = np.cumsum(np.sum(np.abs(np.diff(image, axis=1)), axis=1))
cw = np.cumsum(np.sum(np.abs(np.diff(image, axis=0)), axis=0)) cw = np.cumsum(np.sum(np.abs(np.diff(image, axis=0)), axis=0))
upper_column_limit = np.searchsorted( upper_column_limit = np.searchsorted(
@ -56,16 +65,19 @@ def _crop_image(image, lower_percentile, upper_percentile):
if lower_column_limit > upper_column_limit: if lower_column_limit > upper_column_limit:
lower_column_limit = int(lower_percentile / 100. * image.shape[1]) lower_column_limit = int(lower_percentile / 100. * image.shape[1])
upper_column_limit = int(upper_percentile / 100. * image.shape[1]) upper_column_limit = int(upper_percentile / 100. * image.shape[1])
return [ return (
(lower_row_limit, upper_row_limit), (lower_row_limit, upper_row_limit),
(lower_column_limit, upper_column_limit)] (lower_column_limit, upper_column_limit))
def _normalize_and_threshold(diff_array, identical_tolerance, n_levels): def _normalize_and_threshold(
diff_array: NpMatrix,
identical_tolerance: float,
n_levels: int) -> None:
mask = np.abs(diff_array) < identical_tolerance mask = np.abs(diff_array) < identical_tolerance
diff_array[mask] = 0. diff_array[mask] = 0.
if np.all(mask): if np.all(mask):
return None return
positive_cutoffs = np.percentile( positive_cutoffs = np.percentile(
diff_array[diff_array > 0.], np.linspace(0, 100, n_levels + 1)) diff_array[diff_array > 0.], np.linspace(0, 100, n_levels + 1))
negative_cutoffs = np.percentile( negative_cutoffs = np.percentile(
@ -82,18 +94,24 @@ def _normalize_and_threshold(diff_array, identical_tolerance, n_levels):
diff_array[ diff_array[
(diff_array <= interval[0]) & (diff_array >= interval[1])] = \ (diff_array <= interval[0]) & (diff_array >= interval[1])] = \
-(level + 1) -(level + 1)
return None
def _compute_grid_points(image, n, window=None): def _compute_grid_points(
image: NpMatrix,
n: float,
window: Window=None) -> Tuple[NpMatrix, NpMatrix]:
if window is None: if window is None:
window = [(0, image.shape[0]), (0, image.shape[1])] window = ((0, image.shape[0]), (0, image.shape[1]))
x_coords = np.linspace(window[0][0], window[0][1], n + 2, dtype=int)[1:-1] x_coords = np.linspace(window[0][0], window[0][1], n + 2, dtype=int)[1:-1]
y_coords = np.linspace(window[1][0], window[1][1], n + 2, dtype=int)[1:-1] y_coords = np.linspace(window[1][0], window[1][1], n + 2, dtype=int)[1:-1]
return x_coords, y_coords return x_coords, y_coords
def _compute_mean_level(image, x_coords, y_coords, p): def _compute_mean_level(
image: NpMatrix,
x_coords: NpMatrix,
y_coords: NpMatrix,
p: Optional[float]) -> NpMatrix:
if p is None: if p is None:
p = max([2.0, int(0.5 + min(image.shape) / 20.)]) p = max([2.0, int(0.5 + min(image.shape) / 20.)])
avg_grey = np.zeros((x_coords.shape[0], y_coords.shape[0])) avg_grey = np.zeros((x_coords.shape[0], y_coords.shape[0]))
@ -108,7 +126,7 @@ def _compute_mean_level(image, x_coords, y_coords, p):
return avg_grey return avg_grey
def _compute_differentials(grey_level_matrix): def _compute_differentials(grey_level_matrix: NpMatrix) -> NpMatrix:
flipped = np.fliplr(grey_level_matrix) flipped = np.fliplr(grey_level_matrix)
right_neighbors = -np.concatenate( right_neighbors = -np.concatenate(
( (
@ -152,8 +170,8 @@ def _compute_differentials(grey_level_matrix):
lower_right_neighbors])) lower_right_neighbors]))
def _generate_signature(path_or_image): def _generate_signature(content: bytes) -> NpMatrix:
im_array = _preprocess_image(path_or_image) im_array = _preprocess_image(content)
image_limits = _crop_image( image_limits = _crop_image(
im_array, im_array,
lower_percentile=LOWER_PERCENTILE, lower_percentile=LOWER_PERCENTILE,
@ -169,7 +187,7 @@ def _generate_signature(path_or_image):
return np.ravel(diff_matrix).astype('int8') return np.ravel(diff_matrix).astype('int8')
def _get_words(array, k, n): def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix:
word_positions = np.linspace( word_positions = np.linspace(
0, array.shape[0], n, endpoint=False).astype('int') 0, array.shape[0], n, endpoint=False).astype('int')
assert k <= array.shape[0] assert k <= array.shape[0]
@ -187,21 +205,23 @@ def _get_words(array, k, n):
return words return words
def _words_to_int(word_array): def _words_to_int(word_array: NpMatrix) -> NpMatrix:
width = word_array.shape[1] width = word_array.shape[1]
coding_vector = 3**np.arange(width) coding_vector = 3**np.arange(width)
return np.dot(word_array + 1, coding_vector) return np.dot(word_array + 1, coding_vector)
def _max_contrast(array): def _max_contrast(array: NpMatrix) -> None:
array[array > 0] = 1 array[array > 0] = 1
array[array < 0] = -1 array[array < 0] = -1
return None
def _normalized_distance(_target_array, _vec, nan_value=1.0): def _normalized_distance(
target_array = _target_array.astype(int) target_array: NpMatrix,
vec = _vec.astype(int) vec: NpMatrix,
nan_value: float=1.0) -> List[float]:
target_array = target_array.astype(int)
vec = vec.astype(int)
topvec = np.linalg.norm(vec - target_array, axis=1) topvec = np.linalg.norm(vec - target_array, axis=1)
norm1 = np.linalg.norm(vec, axis=0) norm1 = np.linalg.norm(vec, axis=0)
norm2 = np.linalg.norm(target_array, axis=1) norm2 = np.linalg.norm(target_array, axis=1)
@ -210,9 +230,9 @@ def _normalized_distance(_target_array, _vec, nan_value=1.0):
return finvec return finvec
def _safety_blanket(default_param_factory): def _safety_blanket(default_param_factory: Callable[[], Any]) -> Callable:
def wrapper_outer(target_function): def wrapper_outer(target_function: Callable) -> Callable:
def wrapper_inner(*args, **kwargs): def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
try: try:
return target_function(*args, **kwargs) return target_function(*args, **kwargs)
except elasticsearch.exceptions.NotFoundError: except elasticsearch.exceptions.NotFoundError:
@ -226,20 +246,20 @@ def _safety_blanket(default_param_factory):
except IOError: except IOError:
raise errors.ProcessingError('Not an image.') raise errors.ProcessingError('Not an image.')
except Exception as ex: except Exception as ex:
raise errors.ThirdPartyError('Unknown error (%s).', ex) raise errors.ThirdPartyError('Unknown error (%s).' % ex)
return wrapper_inner return wrapper_inner
return wrapper_outer return wrapper_outer
class Lookalike: class Lookalike:
def __init__(self, score, distance, path): def __init__(self, score: int, distance: float, path: Any) -> None:
self.score = score self.score = score
self.distance = distance self.distance = distance
self.path = path self.path = path
@_safety_blanket(lambda: None) @_safety_blanket(lambda: None)
def add_image(path, image_content): def add_image(path: str, image_content: bytes) -> None:
assert path assert path
assert image_content assert image_content
signature = _generate_signature(image_content) signature = _generate_signature(image_content)
@ -253,7 +273,7 @@ def add_image(path, image_content):
for i in range(MAX_WORDS): for i in range(MAX_WORDS):
record['simple_word_' + str(i)] = words[i].tolist() record['simple_word_' + str(i)] = words[i].tolist()
es.index( _get_session().index(
index=config.config['elasticsearch']['index'], index=config.config['elasticsearch']['index'],
doc_type=ES_DOC_TYPE, doc_type=ES_DOC_TYPE,
body=record, body=record,
@ -261,20 +281,20 @@ def add_image(path, image_content):
@_safety_blanket(lambda: None) @_safety_blanket(lambda: None)
def delete_image(path): def delete_image(path: str) -> None:
assert path assert path
es.delete_by_query( _get_session().delete_by_query(
index=config.config['elasticsearch']['index'], index=config.config['elasticsearch']['index'],
doc_type=ES_DOC_TYPE, doc_type=ES_DOC_TYPE,
body={'query': {'term': {'path': path}}}) body={'query': {'term': {'path': path}}})
@_safety_blanket(lambda: []) @_safety_blanket(lambda: [])
def search_by_image(image_content): def search_by_image(image_content: bytes) -> List[Lookalike]:
signature = _generate_signature(image_content) signature = _generate_signature(image_content)
words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS) words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS)
res = es.search( res = _get_session().search(
index=config.config['elasticsearch']['index'], index=config.config['elasticsearch']['index'],
doc_type=ES_DOC_TYPE, doc_type=ES_DOC_TYPE,
body={ body={
@ -299,7 +319,7 @@ def search_by_image(image_content):
sigs = np.array([x['_source']['signature'] for x in res]) sigs = np.array([x['_source']['signature'] for x in res])
dists = _normalized_distance(sigs, np.array(signature)) dists = _normalized_distance(sigs, np.array(signature))
ids = set() ids = set() # type: Set[int]
ret = [] ret = []
for item, dist in zip(res, dists): for item, dist in zip(res, dists):
id = item['_id'] id = item['_id']
@ -314,8 +334,8 @@ def search_by_image(image_content):
@_safety_blanket(lambda: None) @_safety_blanket(lambda: None)
def purge(): def purge() -> None:
es.delete_by_query( _get_session().delete_by_query(
index=config.config['elasticsearch']['index'], index=config.config['elasticsearch']['index'],
doc_type=ES_DOC_TYPE, doc_type=ES_DOC_TYPE,
body={'query': {'match_all': {}}}, body={'query': {'match_all': {}}},
@ -323,10 +343,10 @@ def purge():
@_safety_blanket(lambda: set()) @_safety_blanket(lambda: set())
def get_all_paths(): def get_all_paths() -> Set[str]:
search = ( search = (
elasticsearch_dsl.Search( elasticsearch_dsl.Search(
using=es, using=_get_session(),
index=config.config['elasticsearch']['index'], index=config.config['elasticsearch']['index'],
doc_type=ES_DOC_TYPE) doc_type=ES_DOC_TYPE)
.source(['path'])) .source(['path']))

View file

@ -1,3 +1,4 @@
from typing import List
import logging import logging
import json import json
import shlex import shlex
@ -15,23 +16,23 @@ _SCALE_FIT_FMT = \
class Image: class Image:
def __init__(self, content): def __init__(self, content: bytes) -> None:
self.content = content self.content = content
self._reload_info() self._reload_info()
@property @property
def width(self): def width(self) -> int:
return self.info['streams'][0]['width'] return self.info['streams'][0]['width']
@property @property
def height(self): def height(self) -> int:
return self.info['streams'][0]['height'] return self.info['streams'][0]['height']
@property @property
def frames(self): def frames(self) -> int:
return self.info['streams'][0]['nb_read_frames'] return self.info['streams'][0]['nb_read_frames']
def resize_fill(self, width, height): def resize_fill(self, width: int, height: int) -> None:
cli = [ cli = [
'-i', '{path}', '-i', '{path}',
'-f', 'image2', '-f', 'image2',
@ -53,7 +54,7 @@ class Image:
assert self.content assert self.content
self._reload_info() self._reload_info()
def to_png(self): def to_png(self) -> bytes:
return self._execute([ return self._execute([
'-i', '{path}', '-i', '{path}',
'-f', 'image2', '-f', 'image2',
@ -63,7 +64,7 @@ class Image:
'-', '-',
]) ])
def to_jpeg(self): def to_jpeg(self) -> bytes:
return self._execute([ return self._execute([
'-f', 'lavfi', '-f', 'lavfi',
'-i', 'color=white:s=%dx%d' % (self.width, self.height), '-i', 'color=white:s=%dx%d' % (self.width, self.height),
@ -76,7 +77,7 @@ class Image:
'-', '-',
]) ])
def _execute(self, cli, program='ffmpeg'): def _execute(self, cli: List[str], program: str='ffmpeg') -> bytes:
extension = mime.get_extension(mime.get_mime_type(self.content)) extension = mime.get_extension(mime.get_mime_type(self.content))
assert extension assert extension
with util.create_temp_file(suffix='.' + extension) as handle: with util.create_temp_file(suffix='.' + extension) as handle:
@ -99,7 +100,7 @@ class Image:
'Error while processing image.\n' + err.decode('utf-8')) 'Error while processing image.\n' + err.decode('utf-8'))
return out return out
def _reload_info(self): def _reload_info(self) -> None:
self.info = json.loads(self._execute([ self.info = json.loads(self._execute([
'-i', '{path}', '-i', '{path}',
'-of', 'json', '-of', 'json',

View file

@ -3,7 +3,7 @@ import email.mime.text
from szurubooru import config from szurubooru import config
def send_mail(sender, recipient, subject, body): def send_mail(sender: str, recipient: str, subject: str, body: str) -> None:
msg = email.mime.text.MIMEText(body) msg = email.mime.text.MIMEText(body)
msg['Subject'] = subject msg['Subject'] = subject
msg['From'] = sender msg['From'] = sender

View file

@ -1,7 +1,8 @@
import re import re
from typing import Optional
def get_mime_type(content): def get_mime_type(content: bytes) -> str:
if not content: if not content:
return 'application/octet-stream' return 'application/octet-stream'
@ -26,7 +27,7 @@ def get_mime_type(content):
return 'application/octet-stream' return 'application/octet-stream'
def get_extension(mime_type): def get_extension(mime_type: str) -> Optional[str]:
extension_map = { extension_map = {
'application/x-shockwave-flash': 'swf', 'application/x-shockwave-flash': 'swf',
'image/gif': 'gif', 'image/gif': 'gif',
@ -39,19 +40,19 @@ def get_extension(mime_type):
return extension_map.get((mime_type or '').strip().lower(), None) return extension_map.get((mime_type or '').strip().lower(), None)
def is_flash(mime_type): def is_flash(mime_type: str) -> bool:
return mime_type.lower() == 'application/x-shockwave-flash' return mime_type.lower() == 'application/x-shockwave-flash'
def is_video(mime_type): def is_video(mime_type: str) -> bool:
return mime_type.lower() in ('application/ogg', 'video/mp4', 'video/webm') return mime_type.lower() in ('application/ogg', 'video/mp4', 'video/webm')
def is_image(mime_type): def is_image(mime_type: str) -> bool:
return mime_type.lower() in ('image/jpeg', 'image/png', 'image/gif') return mime_type.lower() in ('image/jpeg', 'image/png', 'image/gif')
def is_animated_gif(content): def is_animated_gif(content: bytes) -> bool:
pattern = b'\x21\xF9\x04[\x00-\xFF]{4}\x00[\x2C\x21]' pattern = b'\x21\xF9\x04[\x00-\xFF]{4}\x00[\x2C\x21]'
return get_mime_type(content) == 'image/gif' \ return get_mime_type(content) == 'image/gif' \
and len(re.findall(pattern, content)) > 1 and len(re.findall(pattern, content)) > 1

View file

@ -2,7 +2,7 @@ import urllib.request
from szurubooru import errors from szurubooru import errors
def download(url): def download(url: str) -> bytes:
assert url assert url
request = urllib.request.Request(url) request = urllib.request.Request(url)
request.add_header('Referer', url) request.add_header('Referer', url)

View file

@ -1,8 +1,10 @@
import datetime from typing import Any, Optional, Tuple, List, Dict, Callable
import sqlalchemy from datetime import datetime
from szurubooru import config, db, errors import sqlalchemy as sa
from szurubooru import config, db, model, errors, rest
from szurubooru.func import ( from szurubooru.func import (
users, scores, comments, tags, util, mime, images, files, image_hash) users, scores, comments, tags, util,
mime, images, files, image_hash, serialization)
EMPTY_PIXEL = \ EMPTY_PIXEL = \
@ -20,7 +22,7 @@ class PostAlreadyFeaturedError(errors.ValidationError):
class PostAlreadyUploadedError(errors.ValidationError): class PostAlreadyUploadedError(errors.ValidationError):
def __init__(self, other_post): def __init__(self, other_post: model.Post) -> None:
super().__init__( super().__init__(
'Post already uploaded (%d)' % other_post.post_id, 'Post already uploaded (%d)' % other_post.post_id,
{ {
@ -58,30 +60,30 @@ class InvalidPostFlagError(errors.ValidationError):
class PostLookalike(image_hash.Lookalike): class PostLookalike(image_hash.Lookalike):
def __init__(self, score, distance, post): def __init__(self, score: int, distance: float, post: model.Post) -> None:
super().__init__(score, distance, post.post_id) super().__init__(score, distance, post.post_id)
self.post = post self.post = post
SAFETY_MAP = { SAFETY_MAP = {
db.Post.SAFETY_SAFE: 'safe', model.Post.SAFETY_SAFE: 'safe',
db.Post.SAFETY_SKETCHY: 'sketchy', model.Post.SAFETY_SKETCHY: 'sketchy',
db.Post.SAFETY_UNSAFE: 'unsafe', model.Post.SAFETY_UNSAFE: 'unsafe',
} }
TYPE_MAP = { TYPE_MAP = {
db.Post.TYPE_IMAGE: 'image', model.Post.TYPE_IMAGE: 'image',
db.Post.TYPE_ANIMATION: 'animation', model.Post.TYPE_ANIMATION: 'animation',
db.Post.TYPE_VIDEO: 'video', model.Post.TYPE_VIDEO: 'video',
db.Post.TYPE_FLASH: 'flash', model.Post.TYPE_FLASH: 'flash',
} }
FLAG_MAP = { FLAG_MAP = {
db.Post.FLAG_LOOP: 'loop', model.Post.FLAG_LOOP: 'loop',
} }
def get_post_content_url(post): def get_post_content_url(post: model.Post) -> str:
assert post assert post
return '%s/posts/%d.%s' % ( return '%s/posts/%d.%s' % (
config.config['data_url'].rstrip('/'), config.config['data_url'].rstrip('/'),
@ -89,31 +91,31 @@ def get_post_content_url(post):
mime.get_extension(post.mime_type) or 'dat') mime.get_extension(post.mime_type) or 'dat')
def get_post_thumbnail_url(post): def get_post_thumbnail_url(post: model.Post) -> str:
assert post assert post
return '%s/generated-thumbnails/%d.jpg' % ( return '%s/generated-thumbnails/%d.jpg' % (
config.config['data_url'].rstrip('/'), config.config['data_url'].rstrip('/'),
post.post_id) post.post_id)
def get_post_content_path(post): def get_post_content_path(post: model.Post) -> str:
assert post assert post
assert post.post_id assert post.post_id
return 'posts/%d.%s' % ( return 'posts/%d.%s' % (
post.post_id, mime.get_extension(post.mime_type) or 'dat') post.post_id, mime.get_extension(post.mime_type) or 'dat')
def get_post_thumbnail_path(post): def get_post_thumbnail_path(post: model.Post) -> str:
assert post assert post
return 'generated-thumbnails/%d.jpg' % (post.post_id) return 'generated-thumbnails/%d.jpg' % (post.post_id)
def get_post_thumbnail_backup_path(post): def get_post_thumbnail_backup_path(post: model.Post) -> str:
assert post assert post
return 'posts/custom-thumbnails/%d.dat' % (post.post_id) return 'posts/custom-thumbnails/%d.dat' % (post.post_id)
def serialize_note(note): def serialize_note(note: model.PostNote) -> rest.Response:
assert note assert note
return { return {
'polygon': note.polygon, 'polygon': note.polygon,
@ -121,113 +123,216 @@ def serialize_note(note):
} }
def serialize_post(post, auth_user, options=None): class PostSerializer(serialization.BaseSerializer):
return util.serialize_entity( def __init__(self, post: model.Post, auth_user: model.User) -> None:
post, self.post = post
self.auth_user = auth_user
def _serializers(self) -> Dict[str, Callable[[], Any]]:
return {
'id': self.serialize_id,
'version': self.serialize_version,
'creationTime': self.serialize_creation_time,
'lastEditTime': self.serialize_last_edit_time,
'safety': self.serialize_safety,
'source': self.serialize_source,
'type': self.serialize_type,
'mimeType': self.serialize_mime,
'checksum': self.serialize_checksum,
'fileSize': self.serialize_file_size,
'canvasWidth': self.serialize_canvas_width,
'canvasHeight': self.serialize_canvas_height,
'contentUrl': self.serialize_content_url,
'thumbnailUrl': self.serialize_thumbnail_url,
'flags': self.serialize_flags,
'tags': self.serialize_tags,
'relations': self.serialize_relations,
'user': self.serialize_user,
'score': self.serialize_score,
'ownScore': self.serialize_own_score,
'ownFavorite': self.serialize_own_favorite,
'tagCount': self.serialize_tag_count,
'favoriteCount': self.serialize_favorite_count,
'commentCount': self.serialize_comment_count,
'noteCount': self.serialize_note_count,
'relationCount': self.serialize_relation_count,
'featureCount': self.serialize_feature_count,
'lastFeatureTime': self.serialize_last_feature_time,
'favoritedBy': self.serialize_favorited_by,
'hasCustomThumbnail': self.serialize_has_custom_thumbnail,
'notes': self.serialize_notes,
'comments': self.serialize_comments,
}
def serialize_id(self) -> Any:
return self.post.post_id
def serialize_version(self) -> Any:
return self.post.version
def serialize_creation_time(self) -> Any:
return self.post.creation_time
def serialize_last_edit_time(self) -> Any:
return self.post.last_edit_time
def serialize_safety(self) -> Any:
return SAFETY_MAP[self.post.safety]
def serialize_source(self) -> Any:
return self.post.source
def serialize_type(self) -> Any:
return TYPE_MAP[self.post.type]
def serialize_mime(self) -> Any:
return self.post.mime_type
def serialize_checksum(self) -> Any:
return self.post.checksum
def serialize_file_size(self) -> Any:
return self.post.file_size
def serialize_canvas_width(self) -> Any:
return self.post.canvas_width
def serialize_canvas_height(self) -> Any:
return self.post.canvas_height
def serialize_content_url(self) -> Any:
return get_post_content_url(self.post)
def serialize_thumbnail_url(self) -> Any:
return get_post_thumbnail_url(self.post)
def serialize_flags(self) -> Any:
return self.post.flags
def serialize_tags(self) -> Any:
return [tag.names[0].name for tag in tags.sort_tags(self.post.tags)]
def serialize_relations(self) -> Any:
return sorted(
{ {
'id': lambda: post.post_id, post['id']: post
'version': lambda: post.version, for post in [
'creationTime': lambda: post.creation_time, serialize_micro_post(rel, self.auth_user)
'lastEditTime': lambda: post.last_edit_time, for rel in self.post.relations]
'safety': lambda: SAFETY_MAP[post.safety],
'source': lambda: post.source,
'type': lambda: TYPE_MAP[post.type],
'mimeType': lambda: post.mime_type,
'checksum': lambda: post.checksum,
'fileSize': lambda: post.file_size,
'canvasWidth': lambda: post.canvas_width,
'canvasHeight': lambda: post.canvas_height,
'contentUrl': lambda: get_post_content_url(post),
'thumbnailUrl': lambda: get_post_thumbnail_url(post),
'flags': lambda: post.flags,
'tags': lambda: [
tag.names[0].name for tag in tags.sort_tags(post.tags)],
'relations': lambda: sorted(
{
post['id']:
post for post in [
serialize_micro_post(rel, auth_user)
for rel in post.relations]
}.values(), }.values(),
key=lambda post: post['id']), key=lambda post: post['id'])
'user': lambda: users.serialize_micro_user(post.user, auth_user),
'score': lambda: post.score, def serialize_user(self) -> Any:
'ownScore': lambda: scores.get_score(post, auth_user), return users.serialize_micro_user(self.post.user, self.auth_user)
'ownFavorite': lambda: len([
user for user in post.favorited_by def serialize_score(self) -> Any:
if user.user_id == auth_user.user_id] return self.post.score
) > 0,
'tagCount': lambda: post.tag_count, def serialize_own_score(self) -> Any:
'favoriteCount': lambda: post.favorite_count, return scores.get_score(self.post, self.auth_user)
'commentCount': lambda: post.comment_count,
'noteCount': lambda: post.note_count, def serialize_own_favorite(self) -> Any:
'relationCount': lambda: post.relation_count, return len([
'featureCount': lambda: post.feature_count, user for user in self.post.favorited_by
'lastFeatureTime': lambda: post.last_feature_time, if user.user_id == self.auth_user.user_id]
'favoritedBy': lambda: [ ) > 0
users.serialize_micro_user(rel.user, auth_user)
for rel in post.favorited_by def serialize_tag_count(self) -> Any:
], return self.post.tag_count
'hasCustomThumbnail':
lambda: files.has(get_post_thumbnail_backup_path(post)), def serialize_favorite_count(self) -> Any:
'notes': lambda: sorted( return self.post.favorite_count
[serialize_note(note) for note in post.notes],
key=lambda x: x['polygon']), def serialize_comment_count(self) -> Any:
'comments': lambda: [ return self.post.comment_count
comments.serialize_comment(comment, auth_user)
def serialize_note_count(self) -> Any:
return self.post.note_count
def serialize_relation_count(self) -> Any:
return self.post.relation_count
def serialize_feature_count(self) -> Any:
return self.post.feature_count
def serialize_last_feature_time(self) -> Any:
return self.post.last_feature_time
def serialize_favorited_by(self) -> Any:
return [
users.serialize_micro_user(rel.user, self.auth_user)
for rel in self.post.favorited_by
]
def serialize_has_custom_thumbnail(self) -> Any:
return files.has(get_post_thumbnail_backup_path(self.post))
def serialize_notes(self) -> Any:
return sorted(
[serialize_note(note) for note in self.post.notes],
key=lambda x: x['polygon'])
def serialize_comments(self) -> Any:
return [
comments.serialize_comment(comment, self.auth_user)
for comment in sorted( for comment in sorted(
post.comments, self.post.comments,
key=lambda comment: comment.creation_time)], key=lambda comment: comment.creation_time)]
},
options)
def serialize_micro_post(post, auth_user): def serialize_post(
post: Optional[model.Post],
auth_user: model.User,
options: List[str]=[]) -> Optional[rest.Response]:
if not post:
return None
return PostSerializer(post, auth_user).serialize(options)
def serialize_micro_post(
post: model.Post, auth_user: model.User) -> Optional[rest.Response]:
return serialize_post( return serialize_post(
post, post, auth_user=auth_user, options=['id', 'thumbnailUrl'])
auth_user=auth_user,
options=['id', 'thumbnailUrl'])
def get_post_count(): def get_post_count() -> int:
return db.session.query(sqlalchemy.func.count(db.Post.post_id)).one()[0] return db.session.query(sa.func.count(model.Post.post_id)).one()[0]
def try_get_post_by_id(post_id): def try_get_post_by_id(post_id: int) -> Optional[model.Post]:
try:
post_id = int(post_id)
except ValueError:
raise InvalidPostIdError('Invalid post ID: %r.' % post_id)
return db.session \ return db.session \
.query(db.Post) \ .query(model.Post) \
.filter(db.Post.post_id == post_id) \ .filter(model.Post.post_id == post_id) \
.one_or_none() .one_or_none()
def get_post_by_id(post_id): def get_post_by_id(post_id: int) -> model.Post:
post = try_get_post_by_id(post_id) post = try_get_post_by_id(post_id)
if not post: if not post:
raise PostNotFoundError('Post %r not found.' % post_id) raise PostNotFoundError('Post %r not found.' % post_id)
return post return post
def try_get_current_post_feature(): def try_get_current_post_feature() -> Optional[model.PostFeature]:
return db.session \ return db.session \
.query(db.PostFeature) \ .query(model.PostFeature) \
.order_by(db.PostFeature.time.desc()) \ .order_by(model.PostFeature.time.desc()) \
.first() .first()
def try_get_featured_post(): def try_get_featured_post() -> Optional[model.Post]:
post_feature = try_get_current_post_feature() post_feature = try_get_current_post_feature()
return post_feature.post if post_feature else None return post_feature.post if post_feature else None
def create_post(content, tag_names, user): def create_post(
post = db.Post() content: bytes,
post.safety = db.Post.SAFETY_SAFE tag_names: List[str],
user: Optional[model.User]) -> Tuple[model.Post, List[model.Tag]]:
post = model.Post()
post.safety = model.Post.SAFETY_SAFE
post.user = user post.user = user
post.creation_time = datetime.datetime.utcnow() post.creation_time = datetime.utcnow()
post.flags = [] post.flags = []
post.type = '' post.type = ''
@ -240,7 +345,7 @@ def create_post(content, tag_names, user):
return (post, new_tags) return (post, new_tags)
def update_post_safety(post, safety): def update_post_safety(post: model.Post, safety: str) -> None:
assert post assert post
safety = util.flip(SAFETY_MAP).get(safety, None) safety = util.flip(SAFETY_MAP).get(safety, None)
if not safety: if not safety:
@ -249,30 +354,33 @@ def update_post_safety(post, safety):
post.safety = safety post.safety = safety
def update_post_source(post, source): def update_post_source(post: model.Post, source: Optional[str]) -> None:
assert post assert post
if util.value_exceeds_column_size(source, db.Post.source): if util.value_exceeds_column_size(source, model.Post.source):
raise InvalidPostSourceError('Source is too long.') raise InvalidPostSourceError('Source is too long.')
post.source = source post.source = source or None
@sqlalchemy.events.event.listens_for(db.Post, 'after_insert') @sa.events.event.listens_for(model.Post, 'after_insert')
def _after_post_insert(_mapper, _connection, post): def _after_post_insert(
_mapper: Any, _connection: Any, post: model.Post) -> None:
_sync_post_content(post) _sync_post_content(post)
@sqlalchemy.events.event.listens_for(db.Post, 'after_update') @sa.events.event.listens_for(model.Post, 'after_update')
def _after_post_update(_mapper, _connection, post): def _after_post_update(
_mapper: Any, _connection: Any, post: model.Post) -> None:
_sync_post_content(post) _sync_post_content(post)
@sqlalchemy.events.event.listens_for(db.Post, 'before_delete') @sa.events.event.listens_for(model.Post, 'before_delete')
def _before_post_delete(_mapper, _connection, post): def _before_post_delete(
_mapper: Any, _connection: Any, post: model.Post) -> None:
if post.post_id: if post.post_id:
image_hash.delete_image(post.post_id) image_hash.delete_image(post.post_id)
def _sync_post_content(post): def _sync_post_content(post: model.Post) -> None:
regenerate_thumb = False regenerate_thumb = False
if hasattr(post, '__content'): if hasattr(post, '__content'):
@ -281,7 +389,7 @@ def _sync_post_content(post):
delattr(post, '__content') delattr(post, '__content')
regenerate_thumb = True regenerate_thumb = True
if post.post_id and post.type in ( if post.post_id and post.type in (
db.Post.TYPE_IMAGE, db.Post.TYPE_ANIMATION): model.Post.TYPE_IMAGE, model.Post.TYPE_ANIMATION):
image_hash.delete_image(post.post_id) image_hash.delete_image(post.post_id)
image_hash.add_image(post.post_id, content) image_hash.add_image(post.post_id, content)
@ -299,29 +407,29 @@ def _sync_post_content(post):
generate_post_thumbnail(post) generate_post_thumbnail(post)
def update_post_content(post, content): def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
assert post assert post
if not content: if not content:
raise InvalidPostContentError('Post content missing.') raise InvalidPostContentError('Post content missing.')
post.mime_type = mime.get_mime_type(content) post.mime_type = mime.get_mime_type(content)
if mime.is_flash(post.mime_type): if mime.is_flash(post.mime_type):
post.type = db.Post.TYPE_FLASH post.type = model.Post.TYPE_FLASH
elif mime.is_image(post.mime_type): elif mime.is_image(post.mime_type):
if mime.is_animated_gif(content): if mime.is_animated_gif(content):
post.type = db.Post.TYPE_ANIMATION post.type = model.Post.TYPE_ANIMATION
else: else:
post.type = db.Post.TYPE_IMAGE post.type = model.Post.TYPE_IMAGE
elif mime.is_video(post.mime_type): elif mime.is_video(post.mime_type):
post.type = db.Post.TYPE_VIDEO post.type = model.Post.TYPE_VIDEO
else: else:
raise InvalidPostContentError( raise InvalidPostContentError(
'Unhandled file type: %r' % post.mime_type) 'Unhandled file type: %r' % post.mime_type)
post.checksum = util.get_sha1(content) post.checksum = util.get_sha1(content)
other_post = db.session \ other_post = db.session \
.query(db.Post) \ .query(model.Post) \
.filter(db.Post.checksum == post.checksum) \ .filter(model.Post.checksum == post.checksum) \
.filter(db.Post.post_id != post.post_id) \ .filter(model.Post.post_id != post.post_id) \
.one_or_none() .one_or_none()
if other_post \ if other_post \
and other_post.post_id \ and other_post.post_id \
@ -343,18 +451,20 @@ def update_post_content(post, content):
setattr(post, '__content', content) setattr(post, '__content', content)
def update_post_thumbnail(post, content=None): def update_post_thumbnail(
post: model.Post, content: Optional[bytes]=None) -> None:
assert post assert post
setattr(post, '__thumbnail', content) setattr(post, '__thumbnail', content)
def generate_post_thumbnail(post): def generate_post_thumbnail(post: model.Post) -> None:
assert post assert post
if files.has(get_post_thumbnail_backup_path(post)): if files.has(get_post_thumbnail_backup_path(post)):
content = files.get(get_post_thumbnail_backup_path(post)) content = files.get(get_post_thumbnail_backup_path(post))
else: else:
content = files.get(get_post_content_path(post)) content = files.get(get_post_content_path(post))
try: try:
assert content
image = images.Image(content) image = images.Image(content)
image.resize_fill( image.resize_fill(
int(config.config['thumbnails']['post_width']), int(config.config['thumbnails']['post_width']),
@ -364,14 +474,15 @@ def generate_post_thumbnail(post):
files.save(get_post_thumbnail_path(post), EMPTY_PIXEL) files.save(get_post_thumbnail_path(post), EMPTY_PIXEL)
def update_post_tags(post, tag_names): def update_post_tags(
post: model.Post, tag_names: List[str]) -> List[model.Tag]:
assert post assert post
existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names) existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
post.tags = existing_tags + new_tags post.tags = existing_tags + new_tags
return new_tags return new_tags
def update_post_relations(post, new_post_ids): def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None:
assert post assert post
try: try:
new_post_ids = [int(id) for id in new_post_ids] new_post_ids = [int(id) for id in new_post_ids]
@ -382,8 +493,8 @@ def update_post_relations(post, new_post_ids):
old_post_ids = [int(p.post_id) for p in old_posts] old_post_ids = [int(p.post_id) for p in old_posts]
if new_post_ids: if new_post_ids:
new_posts = db.session \ new_posts = db.session \
.query(db.Post) \ .query(model.Post) \
.filter(db.Post.post_id.in_(new_post_ids)) \ .filter(model.Post.post_id.in_(new_post_ids)) \
.all() .all()
else: else:
new_posts = [] new_posts = []
@ -402,7 +513,7 @@ def update_post_relations(post, new_post_ids):
relation.relations.append(post) relation.relations.append(post)
def update_post_notes(post, notes): def update_post_notes(post: model.Post, notes: Any) -> None:
assert post assert post
post.notes = [] post.notes = []
for note in notes: for note in notes:
@ -433,13 +544,13 @@ def update_post_notes(post, notes):
except ValueError: except ValueError:
raise InvalidPostNoteError( raise InvalidPostNoteError(
'A point in note\'s polygon must be numeric.') 'A point in note\'s polygon must be numeric.')
if util.value_exceeds_column_size(note['text'], db.PostNote.text): if util.value_exceeds_column_size(note['text'], model.PostNote.text):
raise InvalidPostNoteError('Note text is too long.') raise InvalidPostNoteError('Note text is too long.')
post.notes.append( post.notes.append(
db.PostNote(polygon=note['polygon'], text=str(note['text']))) model.PostNote(polygon=note['polygon'], text=str(note['text'])))
def update_post_flags(post, flags): def update_post_flags(post: model.Post, flags: List[str]) -> None:
assert post assert post
target_flags = [] target_flags = []
for flag in flags: for flag in flags:
@ -451,88 +562,95 @@ def update_post_flags(post, flags):
post.flags = target_flags post.flags = target_flags
def feature_post(post, user): def feature_post(post: model.Post, user: Optional[model.User]) -> None:
assert post assert post
post_feature = db.PostFeature() post_feature = model.PostFeature()
post_feature.time = datetime.datetime.utcnow() post_feature.time = datetime.utcnow()
post_feature.post = post post_feature.post = post
post_feature.user = user post_feature.user = user
db.session.add(post_feature) db.session.add(post_feature)
def delete(post): def delete(post: model.Post) -> None:
assert post assert post
db.session.delete(post) db.session.delete(post)
def merge_posts(source_post, target_post, replace_content): def merge_posts(
source_post: model.Post,
target_post: model.Post,
replace_content: bool) -> None:
assert source_post assert source_post
assert target_post assert target_post
if source_post.post_id == target_post.post_id: if source_post.post_id == target_post.post_id:
raise InvalidPostRelationError('Cannot merge post with itself.') raise InvalidPostRelationError('Cannot merge post with itself.')
def merge_tables(table, anti_dup_func, source_post_id, target_post_id): def merge_tables(
table: model.Base,
anti_dup_func: Optional[Callable[[model.Base, model.Base], bool]],
source_post_id: int,
target_post_id: int) -> None:
alias1 = table alias1 = table
alias2 = sqlalchemy.orm.util.aliased(table) alias2 = sa.orm.util.aliased(table)
update_stmt = ( update_stmt = (
sqlalchemy.sql.expression.update(alias1) sa.sql.expression.update(alias1)
.where(alias1.post_id == source_post_id)) .where(alias1.post_id == source_post_id))
if anti_dup_func is not None: if anti_dup_func is not None:
update_stmt = ( update_stmt = (
update_stmt update_stmt
.where( .where(
~sqlalchemy.exists() ~sa.exists()
.where(anti_dup_func(alias1, alias2)) .where(anti_dup_func(alias1, alias2))
.where(alias2.post_id == target_post_id))) .where(alias2.post_id == target_post_id)))
update_stmt = update_stmt.values(post_id=target_post_id) update_stmt = update_stmt.values(post_id=target_post_id)
db.session.execute(update_stmt) db.session.execute(update_stmt)
def merge_tags(source_post_id, target_post_id): def merge_tags(source_post_id: int, target_post_id: int) -> None:
merge_tables( merge_tables(
db.PostTag, model.PostTag,
lambda alias1, alias2: alias1.tag_id == alias2.tag_id, lambda alias1, alias2: alias1.tag_id == alias2.tag_id,
source_post_id, source_post_id,
target_post_id) target_post_id)
def merge_scores(source_post_id, target_post_id): def merge_scores(source_post_id: int, target_post_id: int) -> None:
merge_tables( merge_tables(
db.PostScore, model.PostScore,
lambda alias1, alias2: alias1.user_id == alias2.user_id, lambda alias1, alias2: alias1.user_id == alias2.user_id,
source_post_id, source_post_id,
target_post_id) target_post_id)
def merge_favorites(source_post_id, target_post_id): def merge_favorites(source_post_id: int, target_post_id: int) -> None:
merge_tables( merge_tables(
db.PostFavorite, model.PostFavorite,
lambda alias1, alias2: alias1.user_id == alias2.user_id, lambda alias1, alias2: alias1.user_id == alias2.user_id,
source_post_id, source_post_id,
target_post_id) target_post_id)
def merge_comments(source_post_id, target_post_id): def merge_comments(source_post_id: int, target_post_id: int) -> None:
merge_tables(db.Comment, None, source_post_id, target_post_id) merge_tables(model.Comment, None, source_post_id, target_post_id)
def merge_relations(source_post_id, target_post_id): def merge_relations(source_post_id: int, target_post_id: int) -> None:
alias1 = db.PostRelation alias1 = model.PostRelation
alias2 = sqlalchemy.orm.util.aliased(db.PostRelation) alias2 = sa.orm.util.aliased(model.PostRelation)
update_stmt = ( update_stmt = (
sqlalchemy.sql.expression.update(alias1) sa.sql.expression.update(alias1)
.where(alias1.parent_id == source_post_id) .where(alias1.parent_id == source_post_id)
.where(alias1.child_id != target_post_id) .where(alias1.child_id != target_post_id)
.where( .where(
~sqlalchemy.exists() ~sa.exists()
.where(alias2.child_id == alias1.child_id) .where(alias2.child_id == alias1.child_id)
.where(alias2.parent_id == target_post_id)) .where(alias2.parent_id == target_post_id))
.values(parent_id=target_post_id)) .values(parent_id=target_post_id))
db.session.execute(update_stmt) db.session.execute(update_stmt)
update_stmt = ( update_stmt = (
sqlalchemy.sql.expression.update(alias1) sa.sql.expression.update(alias1)
.where(alias1.child_id == source_post_id) .where(alias1.child_id == source_post_id)
.where(alias1.parent_id != target_post_id) .where(alias1.parent_id != target_post_id)
.where( .where(
~sqlalchemy.exists() ~sa.exists()
.where(alias2.parent_id == alias1.parent_id) .where(alias2.parent_id == alias1.parent_id)
.where(alias2.child_id == target_post_id)) .where(alias2.child_id == target_post_id))
.values(child_id=target_post_id)) .values(child_id=target_post_id))
@ -553,15 +671,15 @@ def merge_posts(source_post, target_post, replace_content):
update_post_content(target_post, content) update_post_content(target_post, content)
def search_by_image_exact(image_content): def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
checksum = util.get_sha1(image_content) checksum = util.get_sha1(image_content)
return db.session \ return db.session \
.query(db.Post) \ .query(model.Post) \
.filter(db.Post.checksum == checksum) \ .filter(model.Post.checksum == checksum) \
.one_or_none() .one_or_none()
def search_by_image(image_content): def search_by_image(image_content: bytes) -> List[PostLookalike]:
ret = [] ret = []
for result in image_hash.search_by_image(image_content): for result in image_hash.search_by_image(image_content):
ret.append(PostLookalike( ret.append(PostLookalike(
@ -571,24 +689,24 @@ def search_by_image(image_content):
return ret return ret
def populate_reverse_search(): def populate_reverse_search() -> None:
excluded_post_ids = image_hash.get_all_paths() excluded_post_ids = image_hash.get_all_paths()
post_ids_to_hash = ( post_ids_to_hash = (
db.session db.session
.query(db.Post.post_id) .query(model.Post.post_id)
.filter( .filter(
(db.Post.type == db.Post.TYPE_IMAGE) | (model.Post.type == model.Post.TYPE_IMAGE) |
(db.Post.type == db.Post.TYPE_ANIMATION)) (model.Post.type == model.Post.TYPE_ANIMATION))
.filter(~db.Post.post_id.in_(excluded_post_ids)) .filter(~model.Post.post_id.in_(excluded_post_ids))
.order_by(db.Post.post_id.asc()) .order_by(model.Post.post_id.asc())
.all()) .all())
for post_ids_chunk in util.chunks(post_ids_to_hash, 100): for post_ids_chunk in util.chunks(post_ids_to_hash, 100):
posts_chunk = ( posts_chunk = (
db.session db.session
.query(db.Post) .query(model.Post)
.filter(db.Post.post_id.in_(post_ids_chunk)) .filter(model.Post.post_id.in_(post_ids_chunk))
.all()) .all())
for post in posts_chunk: for post in posts_chunk:
content_path = get_post_content_path(post) content_path = get_post_content_path(post)

View file

@ -1,5 +1,6 @@
import datetime import datetime
from szurubooru import db, errors from typing import Any, Tuple, Callable
from szurubooru import db, model, errors
class InvalidScoreTargetError(errors.ValidationError): class InvalidScoreTargetError(errors.ValidationError):
@ -10,22 +11,23 @@ class InvalidScoreValueError(errors.ValidationError):
pass pass
def _get_table_info(entity): def _get_table_info(
entity: model.Base) -> Tuple[model.Base, Callable[[model.Base], Any]]:
assert entity assert entity
resource_type, _, _ = db.util.get_resource_info(entity) resource_type, _, _ = model.util.get_resource_info(entity)
if resource_type == 'post': if resource_type == 'post':
return db.PostScore, lambda table: table.post_id return model.PostScore, lambda table: table.post_id
elif resource_type == 'comment': elif resource_type == 'comment':
return db.CommentScore, lambda table: table.comment_id return model.CommentScore, lambda table: table.comment_id
raise InvalidScoreTargetError() raise InvalidScoreTargetError()
def _get_score_entity(entity, user): def _get_score_entity(entity: model.Base, user: model.User) -> model.Base:
assert user assert user
return db.util.get_aux_entity(db.session, _get_table_info, entity, user) return model.util.get_aux_entity(db.session, _get_table_info, entity, user)
def delete_score(entity, user): def delete_score(entity: model.Base, user: model.User) -> None:
assert entity assert entity
assert user assert user
score_entity = _get_score_entity(entity, user) score_entity = _get_score_entity(entity, user)
@ -33,7 +35,7 @@ def delete_score(entity, user):
db.session.delete(score_entity) db.session.delete(score_entity)
def get_score(entity, user): def get_score(entity: model.Base, user: model.User) -> int:
assert entity assert entity
assert user assert user
table, get_column = _get_table_info(entity) table, get_column = _get_table_info(entity)
@ -45,7 +47,7 @@ def get_score(entity, user):
return row[0] if row else 0 return row[0] if row else 0
def set_score(entity, user, score): def set_score(entity: model.Base, user: model.User, score: int) -> None:
from szurubooru.func import favorites from szurubooru.func import favorites
assert entity assert entity
assert user assert user

View file

@ -0,0 +1,27 @@
from typing import Any, Optional, List, Dict, Callable
from szurubooru import db, model, rest, errors
def get_serialization_options(ctx: rest.Context) -> List[str]:
return ctx.get_param_as_list('fields', default=[])
class BaseSerializer:
_fields = {} # type: Dict[str, Callable[[model.Base], Any]]
def serialize(self, options: List[str]) -> Any:
field_factories = self._serializers()
if not options:
options = list(field_factories.keys())
ret = {}
for key in options:
if key not in field_factories:
raise errors.ValidationError(
'Invalid key: %r. Valid keys: %r.' % (
key, list(sorted(field_factories.keys()))))
factory = field_factories[key]
ret[key] = factory()
return ret
def _serializers(self) -> Dict[str, Callable[[], Any]]:
raise NotImplementedError()

View file

@ -1,9 +1,10 @@
from typing import Any, Optional, Dict, Callable
from datetime import datetime from datetime import datetime
from szurubooru import db from szurubooru import db, model
from szurubooru.func import diff, users from szurubooru.func import diff, users
def get_tag_category_snapshot(category): def get_tag_category_snapshot(category: model.TagCategory) -> Dict[str, Any]:
assert category assert category
return { return {
'name': category.name, 'name': category.name,
@ -12,7 +13,7 @@ def get_tag_category_snapshot(category):
} }
def get_tag_snapshot(tag): def get_tag_snapshot(tag: model.Tag) -> Dict[str, Any]:
assert tag assert tag
return { return {
'names': [tag_name.name for tag_name in tag.names], 'names': [tag_name.name for tag_name in tag.names],
@ -22,7 +23,7 @@ def get_tag_snapshot(tag):
} }
def get_post_snapshot(post): def get_post_snapshot(post: model.Post) -> Dict[str, Any]:
assert post assert post
return { return {
'source': post.source, 'source': post.source,
@ -45,10 +46,11 @@ _snapshot_factories = {
'tag_category': lambda entity: get_tag_category_snapshot(entity), 'tag_category': lambda entity: get_tag_category_snapshot(entity),
'tag': lambda entity: get_tag_snapshot(entity), 'tag': lambda entity: get_tag_snapshot(entity),
'post': lambda entity: get_post_snapshot(entity), 'post': lambda entity: get_post_snapshot(entity),
} } # type: Dict[model.Base, Callable[[model.Base], Dict[str ,Any]]]
def serialize_snapshot(snapshot, auth_user): def serialize_snapshot(
snapshot: model.Snapshot, auth_user: model.User) -> Dict[str, Any]:
assert snapshot assert snapshot
return { return {
'operation': snapshot.operation, 'operation': snapshot.operation,
@ -60,11 +62,14 @@ def serialize_snapshot(snapshot, auth_user):
} }
def _create(operation, entity, auth_user): def _create(
operation: str,
entity: model.Base,
auth_user: Optional[model.User]) -> model.Snapshot:
resource_type, resource_pkey, resource_name = ( resource_type, resource_pkey, resource_name = (
db.util.get_resource_info(entity)) model.util.get_resource_info(entity))
snapshot = db.Snapshot() snapshot = model.Snapshot()
snapshot.creation_time = datetime.utcnow() snapshot.creation_time = datetime.utcnow()
snapshot.operation = operation snapshot.operation = operation
snapshot.resource_type = resource_type snapshot.resource_type = resource_type
@ -74,33 +79,33 @@ def _create(operation, entity, auth_user):
return snapshot return snapshot
def create(entity, auth_user): def create(entity: model.Base, auth_user: Optional[model.User]) -> None:
assert entity assert entity
snapshot = _create(db.Snapshot.OPERATION_CREATED, entity, auth_user) snapshot = _create(model.Snapshot.OPERATION_CREATED, entity, auth_user)
snapshot_factory = _snapshot_factories[snapshot.resource_type] snapshot_factory = _snapshot_factories[snapshot.resource_type]
snapshot.data = snapshot_factory(entity) snapshot.data = snapshot_factory(entity)
db.session.add(snapshot) db.session.add(snapshot)
# pylint: disable=protected-access # pylint: disable=protected-access
def modify(entity, auth_user): def modify(entity: model.Base, auth_user: Optional[model.User]) -> None:
assert entity assert entity
model = next( table = next(
( (
model cls
for model in db.Base._decl_class_registry.values() for cls in model.Base._decl_class_registry.values()
if hasattr(model, '__table__') if hasattr(cls, '__table__')
and model.__table__.fullname == entity.__table__.fullname and cls.__table__.fullname == entity.__table__.fullname
), ),
None) None)
assert model assert table
snapshot = _create(db.Snapshot.OPERATION_MODIFIED, entity, auth_user) snapshot = _create(model.Snapshot.OPERATION_MODIFIED, entity, auth_user)
snapshot_factory = _snapshot_factories[snapshot.resource_type] snapshot_factory = _snapshot_factories[snapshot.resource_type]
detached_session = db.sessionmaker() detached_session = db.sessionmaker()
detached_entity = detached_session.query(model).get(snapshot.resource_pkey) detached_entity = detached_session.query(table).get(snapshot.resource_pkey)
assert detached_entity, 'Entity not found in DB, have you committed it?' assert detached_entity, 'Entity not found in DB, have you committed it?'
detached_snapshot = snapshot_factory(detached_entity) detached_snapshot = snapshot_factory(detached_entity)
detached_session.close() detached_session.close()
@ -113,19 +118,23 @@ def modify(entity, auth_user):
db.session.add(snapshot) db.session.add(snapshot)
def delete(entity, auth_user): def delete(entity: model.Base, auth_user: Optional[model.User]) -> None:
assert entity assert entity
snapshot = _create(db.Snapshot.OPERATION_DELETED, entity, auth_user) snapshot = _create(model.Snapshot.OPERATION_DELETED, entity, auth_user)
snapshot_factory = _snapshot_factories[snapshot.resource_type] snapshot_factory = _snapshot_factories[snapshot.resource_type]
snapshot.data = snapshot_factory(entity) snapshot.data = snapshot_factory(entity)
db.session.add(snapshot) db.session.add(snapshot)
def merge(source_entity, target_entity, auth_user): def merge(
source_entity: model.Base,
target_entity: model.Base,
auth_user: Optional[model.User]) -> None:
assert source_entity assert source_entity
assert target_entity assert target_entity
snapshot = _create(db.Snapshot.OPERATION_MERGED, source_entity, auth_user) snapshot = _create(
model.Snapshot.OPERATION_MERGED, source_entity, auth_user)
resource_type, _resource_pkey, resource_name = ( resource_type, _resource_pkey, resource_name = (
db.util.get_resource_info(target_entity)) model.util.get_resource_info(target_entity))
snapshot.data = [resource_type, resource_name] snapshot.data = [resource_type, resource_name]
db.session.add(snapshot) db.session.add(snapshot)

View file

@ -1,7 +1,8 @@
import re import re
import sqlalchemy from typing import Any, Optional, Dict, List, Callable
from szurubooru import config, db, errors import sqlalchemy as sa
from szurubooru.func import util, cache from szurubooru import config, db, model, errors, rest
from szurubooru.func import util, serialization, cache
DEFAULT_CATEGORY_NAME_CACHE_KEY = 'default-tag-category' DEFAULT_CATEGORY_NAME_CACHE_KEY = 'default-tag-category'
@ -27,28 +28,52 @@ class InvalidTagCategoryColorError(errors.ValidationError):
pass pass
def _verify_name_validity(name): def _verify_name_validity(name: str) -> None:
name_regex = config.config['tag_category_name_regex'] name_regex = config.config['tag_category_name_regex']
if not re.match(name_regex, name): if not re.match(name_regex, name):
raise InvalidTagCategoryNameError( raise InvalidTagCategoryNameError(
'Name must satisfy regex %r.' % name_regex) 'Name must satisfy regex %r.' % name_regex)
def serialize_category(category, options=None): class TagCategorySerializer(serialization.BaseSerializer):
return util.serialize_entity( def __init__(self, category: model.TagCategory) -> None:
category, self.category = category
{
'name': lambda: category.name, def _serializers(self) -> Dict[str, Callable[[], Any]]:
'version': lambda: category.version, return {
'color': lambda: category.color, 'name': self.serialize_name,
'usages': lambda: category.tag_count, 'version': self.serialize_version,
'default': lambda: category.default, 'color': self.serialize_color,
}, 'usages': self.serialize_usages,
options) 'default': self.serialize_default,
}
def serialize_name(self) -> Any:
return self.category.name
def serialize_version(self) -> Any:
return self.category.version
def serialize_color(self) -> Any:
return self.category.color
def serialize_usages(self) -> Any:
return self.category.tag_count
def serialize_default(self) -> Any:
return self.category.default
def create_category(name, color): def serialize_category(
category = db.TagCategory() category: Optional[model.TagCategory],
options: List[str]=[]) -> Optional[rest.Response]:
if not category:
return None
return TagCategorySerializer(category).serialize(options)
def create_category(name: str, color: str) -> model.TagCategory:
category = model.TagCategory()
update_category_name(category, name) update_category_name(category, name)
update_category_color(category, color) update_category_color(category, color)
if not get_all_categories(): if not get_all_categories():
@ -56,64 +81,66 @@ def create_category(name, color):
return category return category
def update_category_name(category, name): def update_category_name(category: model.TagCategory, name: str) -> None:
assert category assert category
if not name: if not name:
raise InvalidTagCategoryNameError('Name cannot be empty.') raise InvalidTagCategoryNameError('Name cannot be empty.')
expr = sqlalchemy.func.lower(db.TagCategory.name) == name.lower() expr = sa.func.lower(model.TagCategory.name) == name.lower()
if category.tag_category_id: if category.tag_category_id:
expr = expr & ( expr = expr & (
db.TagCategory.tag_category_id != category.tag_category_id) model.TagCategory.tag_category_id != category.tag_category_id)
already_exists = db.session.query(db.TagCategory).filter(expr).count() > 0 already_exists = (
db.session.query(model.TagCategory).filter(expr).count() > 0)
if already_exists: if already_exists:
raise TagCategoryAlreadyExistsError( raise TagCategoryAlreadyExistsError(
'A category with this name already exists.') 'A category with this name already exists.')
if util.value_exceeds_column_size(name, db.TagCategory.name): if util.value_exceeds_column_size(name, model.TagCategory.name):
raise InvalidTagCategoryNameError('Name is too long.') raise InvalidTagCategoryNameError('Name is too long.')
_verify_name_validity(name) _verify_name_validity(name)
category.name = name category.name = name
cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY) cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY)
def update_category_color(category, color): def update_category_color(category: model.TagCategory, color: str) -> None:
assert category assert category
if not color: if not color:
raise InvalidTagCategoryColorError('Color cannot be empty.') raise InvalidTagCategoryColorError('Color cannot be empty.')
if not re.match(r'^#?[0-9a-z]+$', color): if not re.match(r'^#?[0-9a-z]+$', color):
raise InvalidTagCategoryColorError('Invalid color.') raise InvalidTagCategoryColorError('Invalid color.')
if util.value_exceeds_column_size(color, db.TagCategory.color): if util.value_exceeds_column_size(color, model.TagCategory.color):
raise InvalidTagCategoryColorError('Color is too long.') raise InvalidTagCategoryColorError('Color is too long.')
category.color = color category.color = color
def try_get_category_by_name(name, lock=False): def try_get_category_by_name(
name: str, lock: bool=False) -> Optional[model.TagCategory]:
query = db.session \ query = db.session \
.query(db.TagCategory) \ .query(model.TagCategory) \
.filter(sqlalchemy.func.lower(db.TagCategory.name) == name.lower()) .filter(sa.func.lower(model.TagCategory.name) == name.lower())
if lock: if lock:
query = query.with_lockmode('update') query = query.with_lockmode('update')
return query.one_or_none() return query.one_or_none()
def get_category_by_name(name, lock=False): def get_category_by_name(name: str, lock: bool=False) -> model.TagCategory:
category = try_get_category_by_name(name, lock) category = try_get_category_by_name(name, lock)
if not category: if not category:
raise TagCategoryNotFoundError('Tag category %r not found.' % name) raise TagCategoryNotFoundError('Tag category %r not found.' % name)
return category return category
def get_all_category_names(): def get_all_category_names() -> List[str]:
return [row[0] for row in db.session.query(db.TagCategory.name).all()] return [row[0] for row in db.session.query(model.TagCategory.name).all()]
def get_all_categories(): def get_all_categories() -> List[model.TagCategory]:
return db.session.query(db.TagCategory).all() return db.session.query(model.TagCategory).all()
def try_get_default_category(lock=False): def try_get_default_category(lock: bool=False) -> Optional[model.TagCategory]:
query = db.session \ query = db.session \
.query(db.TagCategory) \ .query(model.TagCategory) \
.filter(db.TagCategory.default) .filter(model.TagCategory.default)
if lock: if lock:
query = query.with_lockmode('update') query = query.with_lockmode('update')
category = query.first() category = query.first()
@ -121,22 +148,22 @@ def try_get_default_category(lock=False):
# category, get the first record available. # category, get the first record available.
if not category: if not category:
query = db.session \ query = db.session \
.query(db.TagCategory) \ .query(model.TagCategory) \
.order_by(db.TagCategory.tag_category_id.asc()) .order_by(model.TagCategory.tag_category_id.asc())
if lock: if lock:
query = query.with_lockmode('update') query = query.with_lockmode('update')
category = query.first() category = query.first()
return category return category
def get_default_category(lock=False): def get_default_category(lock: bool=False) -> model.TagCategory:
category = try_get_default_category(lock) category = try_get_default_category(lock)
if not category: if not category:
raise TagCategoryNotFoundError('No tag category created yet.') raise TagCategoryNotFoundError('No tag category created yet.')
return category return category
def get_default_category_name(): def get_default_category_name() -> str:
if cache.has(DEFAULT_CATEGORY_NAME_CACHE_KEY): if cache.has(DEFAULT_CATEGORY_NAME_CACHE_KEY):
return cache.get(DEFAULT_CATEGORY_NAME_CACHE_KEY) return cache.get(DEFAULT_CATEGORY_NAME_CACHE_KEY)
default_category = get_default_category() default_category = get_default_category()
@ -145,7 +172,7 @@ def get_default_category_name():
return default_category_name return default_category_name
def set_default_category(category): def set_default_category(category: model.TagCategory) -> None:
assert category assert category
old_category = try_get_default_category(lock=True) old_category = try_get_default_category(lock=True)
if old_category: if old_category:
@ -156,7 +183,7 @@ def set_default_category(category):
cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY) cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY)
def delete_category(category): def delete_category(category: model.TagCategory) -> None:
assert category assert category
if len(get_all_category_names()) == 1: if len(get_all_category_names()) == 1:
raise TagCategoryIsInUseError('Cannot delete the last category.') raise TagCategoryIsInUseError('Cannot delete the last category.')

View file

@ -1,10 +1,11 @@
import datetime
import json import json
import os import os
import re import re
import sqlalchemy from typing import Any, Optional, Tuple, List, Dict, Callable
from szurubooru import config, db, errors from datetime import datetime
from szurubooru.func import util, tag_categories import sqlalchemy as sa
from szurubooru import config, db, model, errors, rest
from szurubooru.func import util, tag_categories, serialization
class TagNotFoundError(errors.NotFoundError): class TagNotFoundError(errors.NotFoundError):
@ -35,31 +36,32 @@ class InvalidTagDescriptionError(errors.ValidationError):
pass pass
def _verify_name_validity(name): def _verify_name_validity(name: str) -> None:
if util.value_exceeds_column_size(name, db.TagName.name): if util.value_exceeds_column_size(name, model.TagName.name):
raise InvalidTagNameError('Name is too long.') raise InvalidTagNameError('Name is too long.')
name_regex = config.config['tag_name_regex'] name_regex = config.config['tag_name_regex']
if not re.match(name_regex, name): if not re.match(name_regex, name):
raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex) raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex)
def _get_names(tag): def _get_names(tag: model.Tag) -> List[str]:
assert tag assert tag
return [tag_name.name for tag_name in tag.names] return [tag_name.name for tag_name in tag.names]
def _lower_list(names): def _lower_list(names: List[str]) -> List[str]:
return [name.lower() for name in names] return [name.lower() for name in names]
def _check_name_intersection(names1, names2, case_sensitive): def _check_name_intersection(
names1: List[str], names2: List[str], case_sensitive: bool) -> bool:
if not case_sensitive: if not case_sensitive:
names1 = _lower_list(names1) names1 = _lower_list(names1)
names2 = _lower_list(names2) names2 = _lower_list(names2)
return len(set(names1).intersection(names2)) > 0 return len(set(names1).intersection(names2)) > 0
def sort_tags(tags): def sort_tags(tags: List[model.Tag]) -> List[model.Tag]:
default_category_name = tag_categories.get_default_category_name() default_category_name = tag_categories.get_default_category_name()
return sorted( return sorted(
tags, tags,
@ -70,35 +72,70 @@ def sort_tags(tags):
) )
def serialize_tag(tag, options=None): class TagSerializer(serialization.BaseSerializer):
return util.serialize_entity( def __init__(self, tag: model.Tag) -> None:
tag, self.tag = tag
{
'names': lambda: [tag_name.name for tag_name in tag.names], def _serializers(self) -> Dict[str, Callable[[], Any]]:
'category': lambda: tag.category.name, return {
'version': lambda: tag.version, 'names': self.serialize_names,
'description': lambda: tag.description, 'category': self.serialize_category,
'creationTime': lambda: tag.creation_time, 'version': self.serialize_version,
'lastEditTime': lambda: tag.last_edit_time, 'description': self.serialize_description,
'usages': lambda: tag.post_count, 'creationTime': self.serialize_creation_time,
'suggestions': lambda: [ 'lastEditTime': self.serialize_last_edit_time,
'usages': self.serialize_usages,
'suggestions': self.serialize_suggestions,
'implications': self.serialize_implications,
}
def serialize_names(self) -> Any:
return [tag_name.name for tag_name in self.tag.names]
def serialize_category(self) -> Any:
return self.tag.category.name
def serialize_version(self) -> Any:
return self.tag.version
def serialize_description(self) -> Any:
return self.tag.description
def serialize_creation_time(self) -> Any:
return self.tag.creation_time
def serialize_last_edit_time(self) -> Any:
return self.tag.last_edit_time
def serialize_usages(self) -> Any:
return self.tag.post_count
def serialize_suggestions(self) -> Any:
return [
relation.names[0].name relation.names[0].name
for relation in sort_tags(tag.suggestions)], for relation in sort_tags(self.tag.suggestions)]
'implications': lambda: [
def serialize_implications(self) -> Any:
return [
relation.names[0].name relation.names[0].name
for relation in sort_tags(tag.implications)], for relation in sort_tags(self.tag.implications)]
},
options)
def export_to_json(): def serialize_tag(
tags = {} tag: model.Tag, options: List[str]=[]) -> Optional[rest.Response]:
categories = {} if not tag:
return None
return TagSerializer(tag).serialize(options)
def export_to_json() -> None:
tags = {} # type: Dict[int, Any]
categories = {} # type: Dict[int, Any]
for result in db.session.query( for result in db.session.query(
db.TagCategory.tag_category_id, model.TagCategory.tag_category_id,
db.TagCategory.name, model.TagCategory.name,
db.TagCategory.color).all(): model.TagCategory.color).all():
categories[result[0]] = { categories[result[0]] = {
'name': result[1], 'name': result[1],
'color': result[2], 'color': result[2],
@ -106,8 +143,8 @@ def export_to_json():
for result in ( for result in (
db.session db.session
.query(db.TagName.tag_id, db.TagName.name) .query(model.TagName.tag_id, model.TagName.name)
.order_by(db.TagName.order) .order_by(model.TagName.order)
.all()): .all()):
if not result[0] in tags: if not result[0] in tags:
tags[result[0]] = {'names': []} tags[result[0]] = {'names': []}
@ -115,8 +152,10 @@ def export_to_json():
for result in ( for result in (
db.session db.session
.query(db.TagSuggestion.parent_id, db.TagName.name) .query(model.TagSuggestion.parent_id, model.TagName.name)
.join(db.TagName, db.TagName.tag_id == db.TagSuggestion.child_id) .join(
model.TagName,
model.TagName.tag_id == model.TagSuggestion.child_id)
.all()): .all()):
if 'suggestions' not in tags[result[0]]: if 'suggestions' not in tags[result[0]]:
tags[result[0]]['suggestions'] = [] tags[result[0]]['suggestions'] = []
@ -124,17 +163,19 @@ def export_to_json():
for result in ( for result in (
db.session db.session
.query(db.TagImplication.parent_id, db.TagName.name) .query(model.TagImplication.parent_id, model.TagName.name)
.join(db.TagName, db.TagName.tag_id == db.TagImplication.child_id) .join(
model.TagName,
model.TagName.tag_id == model.TagImplication.child_id)
.all()): .all()):
if 'implications' not in tags[result[0]]: if 'implications' not in tags[result[0]]:
tags[result[0]]['implications'] = [] tags[result[0]]['implications'] = []
tags[result[0]]['implications'].append(result[1]) tags[result[0]]['implications'].append(result[1])
for result in db.session.query( for result in db.session.query(
db.Tag.tag_id, model.Tag.tag_id,
db.Tag.category_id, model.Tag.category_id,
db.Tag.post_count).all(): model.Tag.post_count).all():
tags[result[0]]['category'] = categories[result[1]]['name'] tags[result[0]]['category'] = categories[result[1]]['name']
tags[result[0]]['usages'] = result[2] tags[result[0]]['usages'] = result[2]
@ -148,33 +189,34 @@ def export_to_json():
handle.write(json.dumps(output, separators=(',', ':'))) handle.write(json.dumps(output, separators=(',', ':')))
def try_get_tag_by_name(name): def try_get_tag_by_name(name: str) -> Optional[model.Tag]:
return ( return (
db.session db.session
.query(db.Tag) .query(model.Tag)
.join(db.TagName) .join(model.TagName)
.filter(sqlalchemy.func.lower(db.TagName.name) == name.lower()) .filter(sa.func.lower(model.TagName.name) == name.lower())
.one_or_none()) .one_or_none())
def get_tag_by_name(name): def get_tag_by_name(name: str) -> model.Tag:
tag = try_get_tag_by_name(name) tag = try_get_tag_by_name(name)
if not tag: if not tag:
raise TagNotFoundError('Tag %r not found.' % name) raise TagNotFoundError('Tag %r not found.' % name)
return tag return tag
def get_tags_by_names(names): def get_tags_by_names(names: List[str]) -> List[model.Tag]:
names = util.icase_unique(names) names = util.icase_unique(names)
if len(names) == 0: if len(names) == 0:
return [] return []
expr = sqlalchemy.sql.false() expr = sa.sql.false()
for name in names: for name in names:
expr = expr | (sqlalchemy.func.lower(db.TagName.name) == name.lower()) expr = expr | (sa.func.lower(model.TagName.name) == name.lower())
return db.session.query(db.Tag).join(db.TagName).filter(expr).all() return db.session.query(model.Tag).join(model.TagName).filter(expr).all()
def get_or_create_tags_by_names(names): def get_or_create_tags_by_names(
names: List[str]) -> Tuple[List[model.Tag], List[model.Tag]]:
names = util.icase_unique(names) names = util.icase_unique(names)
existing_tags = get_tags_by_names(names) existing_tags = get_tags_by_names(names)
new_tags = [] new_tags = []
@ -197,86 +239,87 @@ def get_or_create_tags_by_names(names):
return existing_tags, new_tags return existing_tags, new_tags
def get_tag_siblings(tag): def get_tag_siblings(tag: model.Tag) -> List[model.Tag]:
assert tag assert tag
tag_alias = sqlalchemy.orm.aliased(db.Tag) tag_alias = sa.orm.aliased(model.Tag)
pt_alias1 = sqlalchemy.orm.aliased(db.PostTag) pt_alias1 = sa.orm.aliased(model.PostTag)
pt_alias2 = sqlalchemy.orm.aliased(db.PostTag) pt_alias2 = sa.orm.aliased(model.PostTag)
result = ( result = (
db.session db.session
.query(tag_alias, sqlalchemy.func.count(pt_alias2.post_id)) .query(tag_alias, sa.func.count(pt_alias2.post_id))
.join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id) .join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id)
.join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id) .join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id)
.filter(pt_alias2.tag_id == tag.tag_id) .filter(pt_alias2.tag_id == tag.tag_id)
.filter(pt_alias1.tag_id != tag.tag_id) .filter(pt_alias1.tag_id != tag.tag_id)
.group_by(tag_alias.tag_id) .group_by(tag_alias.tag_id)
.order_by(sqlalchemy.func.count(pt_alias2.post_id).desc()) .order_by(sa.func.count(pt_alias2.post_id).desc())
.limit(50)) .limit(50))
return result return result
def delete(source_tag): def delete(source_tag: model.Tag) -> None:
assert source_tag assert source_tag
db.session.execute( db.session.execute(
sqlalchemy.sql.expression.delete(db.TagSuggestion) sa.sql.expression.delete(model.TagSuggestion)
.where(db.TagSuggestion.child_id == source_tag.tag_id)) .where(model.TagSuggestion.child_id == source_tag.tag_id))
db.session.execute( db.session.execute(
sqlalchemy.sql.expression.delete(db.TagImplication) sa.sql.expression.delete(model.TagImplication)
.where(db.TagImplication.child_id == source_tag.tag_id)) .where(model.TagImplication.child_id == source_tag.tag_id))
db.session.delete(source_tag) db.session.delete(source_tag)
def merge_tags(source_tag, target_tag): def merge_tags(source_tag: model.Tag, target_tag: model.Tag) -> None:
assert source_tag assert source_tag
assert target_tag assert target_tag
if source_tag.tag_id == target_tag.tag_id: if source_tag.tag_id == target_tag.tag_id:
raise InvalidTagRelationError('Cannot merge tag with itself.') raise InvalidTagRelationError('Cannot merge tag with itself.')
def merge_posts(source_tag_id, target_tag_id): def merge_posts(source_tag_id: int, target_tag_id: int) -> None:
alias1 = db.PostTag alias1 = model.PostTag
alias2 = sqlalchemy.orm.util.aliased(db.PostTag) alias2 = sa.orm.util.aliased(model.PostTag)
update_stmt = ( update_stmt = (
sqlalchemy.sql.expression.update(alias1) sa.sql.expression.update(alias1)
.where(alias1.tag_id == source_tag_id)) .where(alias1.tag_id == source_tag_id))
update_stmt = ( update_stmt = (
update_stmt update_stmt
.where( .where(
~sqlalchemy.exists() ~sa.exists()
.where(alias1.post_id == alias2.post_id) .where(alias1.post_id == alias2.post_id)
.where(alias2.tag_id == target_tag_id))) .where(alias2.tag_id == target_tag_id)))
update_stmt = update_stmt.values(tag_id=target_tag_id) update_stmt = update_stmt.values(tag_id=target_tag_id)
db.session.execute(update_stmt) db.session.execute(update_stmt)
def merge_relations(table, source_tag_id, target_tag_id): def merge_relations(
table: model.Base, source_tag_id: int, target_tag_id: int) -> None:
alias1 = table alias1 = table
alias2 = sqlalchemy.orm.util.aliased(table) alias2 = sa.orm.util.aliased(table)
update_stmt = ( update_stmt = (
sqlalchemy.sql.expression.update(alias1) sa.sql.expression.update(alias1)
.where(alias1.parent_id == source_tag_id) .where(alias1.parent_id == source_tag_id)
.where(alias1.child_id != target_tag_id) .where(alias1.child_id != target_tag_id)
.where( .where(
~sqlalchemy.exists() ~sa.exists()
.where(alias2.child_id == alias1.child_id) .where(alias2.child_id == alias1.child_id)
.where(alias2.parent_id == target_tag_id)) .where(alias2.parent_id == target_tag_id))
.values(parent_id=target_tag_id)) .values(parent_id=target_tag_id))
db.session.execute(update_stmt) db.session.execute(update_stmt)
update_stmt = ( update_stmt = (
sqlalchemy.sql.expression.update(alias1) sa.sql.expression.update(alias1)
.where(alias1.child_id == source_tag_id) .where(alias1.child_id == source_tag_id)
.where(alias1.parent_id != target_tag_id) .where(alias1.parent_id != target_tag_id)
.where( .where(
~sqlalchemy.exists() ~sa.exists()
.where(alias2.parent_id == alias1.parent_id) .where(alias2.parent_id == alias1.parent_id)
.where(alias2.child_id == target_tag_id)) .where(alias2.child_id == target_tag_id))
.values(child_id=target_tag_id)) .values(child_id=target_tag_id))
db.session.execute(update_stmt) db.session.execute(update_stmt)
def merge_suggestions(source_tag_id, target_tag_id): def merge_suggestions(source_tag_id: int, target_tag_id: int) -> None:
merge_relations(db.TagSuggestion, source_tag_id, target_tag_id) merge_relations(model.TagSuggestion, source_tag_id, target_tag_id)
def merge_implications(source_tag_id, target_tag_id): def merge_implications(source_tag_id: int, target_tag_id: int) -> None:
merge_relations(db.TagImplication, source_tag_id, target_tag_id) merge_relations(model.TagImplication, source_tag_id, target_tag_id)
merge_posts(source_tag.tag_id, target_tag.tag_id) merge_posts(source_tag.tag_id, target_tag.tag_id)
merge_suggestions(source_tag.tag_id, target_tag.tag_id) merge_suggestions(source_tag.tag_id, target_tag.tag_id)
@ -284,9 +327,13 @@ def merge_tags(source_tag, target_tag):
delete(source_tag) delete(source_tag)
def create_tag(names, category_name, suggestions, implications): def create_tag(
tag = db.Tag() names: List[str],
tag.creation_time = datetime.datetime.utcnow() category_name: str,
suggestions: List[str],
implications: List[str]) -> model.Tag:
tag = model.Tag()
tag.creation_time = datetime.utcnow()
update_tag_names(tag, names) update_tag_names(tag, names)
update_tag_category_name(tag, category_name) update_tag_category_name(tag, category_name)
update_tag_suggestions(tag, suggestions) update_tag_suggestions(tag, suggestions)
@ -294,12 +341,12 @@ def create_tag(names, category_name, suggestions, implications):
return tag return tag
def update_tag_category_name(tag, category_name): def update_tag_category_name(tag: model.Tag, category_name: str) -> None:
assert tag assert tag
tag.category = tag_categories.get_category_by_name(category_name) tag.category = tag_categories.get_category_by_name(category_name)
def update_tag_names(tag, names): def update_tag_names(tag: model.Tag, names: List[str]) -> None:
# sanitize # sanitize
assert tag assert tag
names = util.icase_unique([name for name in names if name]) names = util.icase_unique([name for name in names if name])
@ -309,12 +356,12 @@ def update_tag_names(tag, names):
_verify_name_validity(name) _verify_name_validity(name)
# check for existing tags # check for existing tags
expr = sqlalchemy.sql.false() expr = sa.sql.false()
for name in names: for name in names:
expr = expr | (sqlalchemy.func.lower(db.TagName.name) == name.lower()) expr = expr | (sa.func.lower(model.TagName.name) == name.lower())
if tag.tag_id: if tag.tag_id:
expr = expr & (db.TagName.tag_id != tag.tag_id) expr = expr & (model.TagName.tag_id != tag.tag_id)
existing_tags = db.session.query(db.TagName).filter(expr).all() existing_tags = db.session.query(model.TagName).filter(expr).all()
if len(existing_tags): if len(existing_tags):
raise TagAlreadyExistsError( raise TagAlreadyExistsError(
'One of names is already used by another tag.') 'One of names is already used by another tag.')
@ -326,7 +373,7 @@ def update_tag_names(tag, names):
# add wanted items # add wanted items
for name in names: for name in names:
if not _check_name_intersection(_get_names(tag), [name], True): if not _check_name_intersection(_get_names(tag), [name], True):
tag.names.append(db.TagName(name, None)) tag.names.append(model.TagName(name, -1))
# set alias order to match the request # set alias order to match the request
for i, name in enumerate(names): for i, name in enumerate(names):
@ -336,7 +383,7 @@ def update_tag_names(tag, names):
# TODO: what to do with relations that do not yet exist? # TODO: what to do with relations that do not yet exist?
def update_tag_implications(tag, relations): def update_tag_implications(tag: model.Tag, relations: List[str]) -> None:
assert tag assert tag
if _check_name_intersection(_get_names(tag), relations, False): if _check_name_intersection(_get_names(tag), relations, False):
raise InvalidTagRelationError('Tag cannot imply itself.') raise InvalidTagRelationError('Tag cannot imply itself.')
@ -344,15 +391,15 @@ def update_tag_implications(tag, relations):
# TODO: what to do with relations that do not yet exist? # TODO: what to do with relations that do not yet exist?
def update_tag_suggestions(tag, relations): def update_tag_suggestions(tag: model.Tag, relations: List[str]) -> None:
assert tag assert tag
if _check_name_intersection(_get_names(tag), relations, False): if _check_name_intersection(_get_names(tag), relations, False):
raise InvalidTagRelationError('Tag cannot suggest itself.') raise InvalidTagRelationError('Tag cannot suggest itself.')
tag.suggestions = get_tags_by_names(relations) tag.suggestions = get_tags_by_names(relations)
def update_tag_description(tag, description): def update_tag_description(tag: model.Tag, description: str) -> None:
assert tag assert tag
if util.value_exceeds_column_size(description, db.Tag.description): if util.value_exceeds_column_size(description, model.Tag.description):
raise InvalidTagDescriptionError('Description is too long.') raise InvalidTagDescriptionError('Description is too long.')
tag.description = description tag.description = description or None

View file

@ -1,8 +1,9 @@
import datetime
import re import re
from sqlalchemy import func from typing import Any, Optional, Union, List, Dict, Callable
from szurubooru import config, db, errors from datetime import datetime
from szurubooru.func import auth, util, files, images import sqlalchemy as sa
from szurubooru import config, db, model, errors, rest
from szurubooru.func import auth, util, serialization, files, images
class UserNotFoundError(errors.NotFoundError): class UserNotFoundError(errors.NotFoundError):
@ -33,11 +34,11 @@ class InvalidAvatarError(errors.ValidationError):
pass pass
def get_avatar_path(user_name): def get_avatar_path(user_name: str) -> str:
return 'avatars/' + user_name.lower() + '.png' return 'avatars/' + user_name.lower() + '.png'
def get_avatar_url(user): def get_avatar_url(user: model.User) -> str:
assert user assert user
if user.avatar_style == user.AVATAR_GRAVATAR: if user.avatar_style == user.AVATAR_GRAVATAR:
assert user.email or user.name assert user.email or user.name
@ -49,7 +50,10 @@ def get_avatar_url(user):
config.config['data_url'].rstrip('/'), user.name.lower()) config.config['data_url'].rstrip('/'), user.name.lower())
def get_email(user, auth_user, force_show_email): def get_email(
user: model.User,
auth_user: model.User,
force_show_email: bool) -> Union[bool, str]:
assert user assert user
assert auth_user assert auth_user
if not force_show_email \ if not force_show_email \
@ -59,7 +63,8 @@ def get_email(user, auth_user, force_show_email):
return user.email return user.email
def get_liked_post_count(user, auth_user): def get_liked_post_count(
user: model.User, auth_user: model.User) -> Union[bool, int]:
assert user assert user
assert auth_user assert auth_user
if auth_user.user_id != user.user_id: if auth_user.user_id != user.user_id:
@ -67,7 +72,8 @@ def get_liked_post_count(user, auth_user):
return user.liked_post_count return user.liked_post_count
def get_disliked_post_count(user, auth_user): def get_disliked_post_count(
user: model.User, auth_user: model.User) -> Union[bool, int]:
assert user assert user
assert auth_user assert auth_user
if auth_user.user_id != user.user_id: if auth_user.user_id != user.user_id:
@ -75,91 +81,144 @@ def get_disliked_post_count(user, auth_user):
return user.disliked_post_count return user.disliked_post_count
def serialize_user(user, auth_user, options=None, force_show_email=False): class UserSerializer(serialization.BaseSerializer):
return util.serialize_entity( def __init__(
user, self,
{ user: model.User,
'name': lambda: user.name, auth_user: model.User,
'creationTime': lambda: user.creation_time, force_show_email: bool=False) -> None:
'lastLoginTime': lambda: user.last_login_time, self.user = user
'version': lambda: user.version, self.auth_user = auth_user
'rank': lambda: user.rank, self.force_show_email = force_show_email
'avatarStyle': lambda: user.avatar_style,
'avatarUrl': lambda: get_avatar_url(user), def _serializers(self) -> Dict[str, Callable[[], Any]]:
'commentCount': lambda: user.comment_count, return {
'uploadedPostCount': lambda: user.post_count, 'name': self.serialize_name,
'favoritePostCount': lambda: user.favorite_post_count, 'creationTime': self.serialize_creation_time,
'likedPostCount': 'lastLoginTime': self.serialize_last_login_time,
lambda: get_liked_post_count(user, auth_user), 'version': self.serialize_version,
'dislikedPostCount': 'rank': self.serialize_rank,
lambda: get_disliked_post_count(user, auth_user), 'avatarStyle': self.serialize_avatar_style,
'email': 'avatarUrl': self.serialize_avatar_url,
lambda: get_email(user, auth_user, force_show_email), 'commentCount': self.serialize_comment_count,
}, 'uploadedPostCount': self.serialize_uploaded_post_count,
options) 'favoritePostCount': self.serialize_favorite_post_count,
'likedPostCount': self.serialize_liked_post_count,
'dislikedPostCount': self.serialize_disliked_post_count,
'email': self.serialize_email,
}
def serialize_name(self) -> Any:
return self.user.name
def serialize_creation_time(self) -> Any:
return self.user.creation_time
def serialize_last_login_time(self) -> Any:
return self.user.last_login_time
def serialize_version(self) -> Any:
return self.user.version
def serialize_rank(self) -> Any:
return self.user.rank
def serialize_avatar_style(self) -> Any:
return self.user.avatar_style
def serialize_avatar_url(self) -> Any:
return get_avatar_url(self.user)
def serialize_comment_count(self) -> Any:
return self.user.comment_count
def serialize_uploaded_post_count(self) -> Any:
return self.user.post_count
def serialize_favorite_post_count(self) -> Any:
return self.user.favorite_post_count
def serialize_liked_post_count(self) -> Any:
return get_liked_post_count(self.user, self.auth_user)
def serialize_disliked_post_count(self) -> Any:
return get_disliked_post_count(self.user, self.auth_user)
def serialize_email(self) -> Any:
return get_email(self.user, self.auth_user, self.force_show_email)
def serialize_micro_user(user, auth_user): def serialize_user(
user: Optional[model.User],
auth_user: model.User,
options: List[str]=[],
force_show_email: bool=False) -> Optional[rest.Response]:
if not user:
return None
return UserSerializer(user, auth_user, force_show_email).serialize(options)
def serialize_micro_user(
user: Optional[model.User],
auth_user: model.User) -> Optional[rest.Response]:
return serialize_user( return serialize_user(
user, user, auth_user=auth_user, options=['name', 'avatarUrl'])
auth_user=auth_user,
options=['name', 'avatarUrl'])
def get_user_count(): def get_user_count() -> int:
return db.session.query(db.User).count() return db.session.query(model.User).count()
def try_get_user_by_name(name): def try_get_user_by_name(name: str) -> Optional[model.User]:
return db.session \ return db.session \
.query(db.User) \ .query(model.User) \
.filter(func.lower(db.User.name) == func.lower(name)) \ .filter(sa.func.lower(model.User.name) == sa.func.lower(name)) \
.one_or_none() .one_or_none()
def get_user_by_name(name): def get_user_by_name(name: str) -> model.User:
user = try_get_user_by_name(name) user = try_get_user_by_name(name)
if not user: if not user:
raise UserNotFoundError('User %r not found.' % name) raise UserNotFoundError('User %r not found.' % name)
return user return user
def try_get_user_by_name_or_email(name_or_email): def try_get_user_by_name_or_email(name_or_email: str) -> Optional[model.User]:
return ( return (
db.session db.session
.query(db.User) .query(model.User)
.filter( .filter(
(func.lower(db.User.name) == func.lower(name_or_email)) | (sa.func.lower(model.User.name) == sa.func.lower(name_or_email)) |
(func.lower(db.User.email) == func.lower(name_or_email))) (sa.func.lower(model.User.email) == sa.func.lower(name_or_email)))
.one_or_none()) .one_or_none())
def get_user_by_name_or_email(name_or_email): def get_user_by_name_or_email(name_or_email: str) -> model.User:
user = try_get_user_by_name_or_email(name_or_email) user = try_get_user_by_name_or_email(name_or_email)
if not user: if not user:
raise UserNotFoundError('User %r not found.' % name_or_email) raise UserNotFoundError('User %r not found.' % name_or_email)
return user return user
def create_user(name, password, email): def create_user(name: str, password: str, email: str) -> model.User:
user = db.User() user = model.User()
update_user_name(user, name) update_user_name(user, name)
update_user_password(user, password) update_user_password(user, password)
update_user_email(user, email) update_user_email(user, email)
if get_user_count() > 0: if get_user_count() > 0:
user.rank = util.flip(auth.RANK_MAP)[config.config['default_rank']] user.rank = util.flip(auth.RANK_MAP)[config.config['default_rank']]
else: else:
user.rank = db.User.RANK_ADMINISTRATOR user.rank = model.User.RANK_ADMINISTRATOR
user.creation_time = datetime.datetime.utcnow() user.creation_time = datetime.utcnow()
user.avatar_style = db.User.AVATAR_GRAVATAR user.avatar_style = model.User.AVATAR_GRAVATAR
return user return user
def update_user_name(user, name): def update_user_name(user: model.User, name: str) -> None:
assert user assert user
if not name: if not name:
raise InvalidUserNameError('Name cannot be empty.') raise InvalidUserNameError('Name cannot be empty.')
if util.value_exceeds_column_size(name, db.User.name): if util.value_exceeds_column_size(name, model.User.name):
raise InvalidUserNameError('User name is too long.') raise InvalidUserNameError('User name is too long.')
name = name.strip() name = name.strip()
name_regex = config.config['user_name_regex'] name_regex = config.config['user_name_regex']
@ -174,7 +233,7 @@ def update_user_name(user, name):
user.name = name user.name = name
def update_user_password(user, password): def update_user_password(user: model.User, password: str) -> None:
assert user assert user
if not password: if not password:
raise InvalidPasswordError('Password cannot be empty.') raise InvalidPasswordError('Password cannot be empty.')
@ -186,20 +245,18 @@ def update_user_password(user, password):
user.password_hash = auth.get_password_hash(user.password_salt, password) user.password_hash = auth.get_password_hash(user.password_salt, password)
def update_user_email(user, email): def update_user_email(user: model.User, email: str) -> None:
assert user assert user
if email:
email = email.strip() email = email.strip()
if not email: if util.value_exceeds_column_size(email, model.User.email):
email = None
if email and util.value_exceeds_column_size(email, db.User.email):
raise InvalidEmailError('Email is too long.') raise InvalidEmailError('Email is too long.')
if not util.is_valid_email(email): if not util.is_valid_email(email):
raise InvalidEmailError('E-mail is invalid.') raise InvalidEmailError('E-mail is invalid.')
user.email = email user.email = email or None
def update_user_rank(user, rank, auth_user): def update_user_rank(
user: model.User, rank: str, auth_user: model.User) -> None:
assert user assert user
if not rank: if not rank:
raise InvalidRankError('Rank cannot be empty.') raise InvalidRankError('Rank cannot be empty.')
@ -208,7 +265,7 @@ def update_user_rank(user, rank, auth_user):
if not rank: if not rank:
raise InvalidRankError( raise InvalidRankError(
'Rank can be either of %r.' % all_ranks) 'Rank can be either of %r.' % all_ranks)
if rank in (db.User.RANK_ANONYMOUS, db.User.RANK_NOBODY): if rank in (model.User.RANK_ANONYMOUS, model.User.RANK_NOBODY):
raise InvalidRankError('Rank %r cannot be used.' % auth.RANK_MAP[rank]) raise InvalidRankError('Rank %r cannot be used.' % auth.RANK_MAP[rank])
if all_ranks.index(auth_user.rank) \ if all_ranks.index(auth_user.rank) \
< all_ranks.index(rank) and get_user_count() > 0: < all_ranks.index(rank) and get_user_count() > 0:
@ -216,7 +273,10 @@ def update_user_rank(user, rank, auth_user):
user.rank = rank user.rank = rank
def update_user_avatar(user, avatar_style, avatar_content=None): def update_user_avatar(
user: model.User,
avatar_style: str,
avatar_content: Optional[bytes]=None) -> None:
assert user assert user
if avatar_style == 'gravatar': if avatar_style == 'gravatar':
user.avatar_style = user.AVATAR_GRAVATAR user.avatar_style = user.AVATAR_GRAVATAR
@ -238,12 +298,12 @@ def update_user_avatar(user, avatar_style, avatar_content=None):
avatar_style, ['gravatar', 'manual'])) avatar_style, ['gravatar', 'manual']))
def bump_user_login_time(user): def bump_user_login_time(user: model.User) -> None:
assert user assert user
user.last_login_time = datetime.datetime.utcnow() user.last_login_time = datetime.utcnow()
def reset_user_password(user): def reset_user_password(user: model.User) -> str:
assert user assert user
password = auth.create_password() password = auth.create_password()
user.password_salt = auth.create_password() user.password_salt = auth.create_password()

View file

@ -2,52 +2,39 @@ import os
import hashlib import hashlib
import re import re
import tempfile import tempfile
from typing import (
Any, Optional, Union, Tuple, List, Dict, Generator, Callable, TypeVar)
from datetime import datetime, timedelta from datetime import datetime, timedelta
from contextlib import contextmanager from contextlib import contextmanager
from szurubooru import errors from szurubooru import errors
def snake_case_to_lower_camel_case(text): T = TypeVar('T')
def snake_case_to_lower_camel_case(text: str) -> str:
components = text.split('_') components = text.split('_')
return components[0].lower() + \ return components[0].lower() + \
''.join(word[0].upper() + word[1:].lower() for word in components[1:]) ''.join(word[0].upper() + word[1:].lower() for word in components[1:])
def snake_case_to_upper_train_case(text): def snake_case_to_upper_train_case(text: str) -> str:
return '-'.join( return '-'.join(
word[0].upper() + word[1:].lower() for word in text.split('_')) word[0].upper() + word[1:].lower() for word in text.split('_'))
def snake_case_to_lower_camel_case_keys(source): def snake_case_to_lower_camel_case_keys(
source: Dict[str, Any]) -> Dict[str, Any]:
target = {} target = {}
for key, value in source.items(): for key, value in source.items():
target[snake_case_to_lower_camel_case(key)] = value target[snake_case_to_lower_camel_case(key)] = value
return target return target
def get_serialization_options(ctx):
return ctx.get_param_as_list('fields', required=False, default=None)
def serialize_entity(entity, field_factories, options):
if not entity:
return None
if not options or len(options) == 0:
options = field_factories.keys()
ret = {}
for key in options:
if key not in field_factories:
raise errors.ValidationError('Invalid key: %r. Valid keys: %r.' % (
key, list(sorted(field_factories.keys()))))
factory = field_factories[key]
ret[key] = factory()
return ret
@contextmanager @contextmanager
def create_temp_file(**kwargs): def create_temp_file(**kwargs: Any) -> Generator:
(handle, path) = tempfile.mkstemp(**kwargs) (descriptor, path) = tempfile.mkstemp(**kwargs)
os.close(handle) os.close(descriptor)
try: try:
with open(path, 'r+b') as handle: with open(path, 'r+b') as handle:
yield handle yield handle
@ -55,17 +42,15 @@ def create_temp_file(**kwargs):
os.remove(path) os.remove(path)
def unalias_dict(input_dict): def unalias_dict(source: List[Tuple[List[str], T]]) -> Dict[str, T]:
output_dict = {} output_dict = {} # type: Dict[str, T]
for key_list, value in input_dict.items(): for aliases, value in source:
if isinstance(key_list, str): for alias in aliases:
key_list = [key_list] output_dict[alias] = value
for key in key_list:
output_dict[key] = value
return output_dict return output_dict
def get_md5(source): def get_md5(source: Union[str, bytes]) -> str:
if not isinstance(source, bytes): if not isinstance(source, bytes):
source = source.encode('utf-8') source = source.encode('utf-8')
md5 = hashlib.md5() md5 = hashlib.md5()
@ -73,7 +58,7 @@ def get_md5(source):
return md5.hexdigest() return md5.hexdigest()
def get_sha1(source): def get_sha1(source: Union[str, bytes]) -> str:
if not isinstance(source, bytes): if not isinstance(source, bytes):
source = source.encode('utf-8') source = source.encode('utf-8')
sha1 = hashlib.sha1() sha1 = hashlib.sha1()
@ -81,24 +66,25 @@ def get_sha1(source):
return sha1.hexdigest() return sha1.hexdigest()
def flip(source): def flip(source: Dict[Any, Any]) -> Dict[Any, Any]:
return {v: k for k, v in source.items()} return {v: k for k, v in source.items()}
def is_valid_email(email): def is_valid_email(email: Optional[str]) -> bool:
''' Return whether given email address is valid or empty. ''' ''' Return whether given email address is valid or empty. '''
return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email) return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email) is not None
class dotdict(dict): # pylint: disable=invalid-name class dotdict(dict): # pylint: disable=invalid-name
''' dot.notation access to dictionary attributes. ''' ''' dot.notation access to dictionary attributes. '''
def __getattr__(self, attr): def __getattr__(self, attr: str) -> Any:
return self.get(attr) return self.get(attr)
__setattr__ = dict.__setitem__ __setattr__ = dict.__setitem__
__delattr__ = dict.__delitem__ __delattr__ = dict.__delitem__
def parse_time_range(value): def parse_time_range(value: str) -> Tuple[datetime, datetime]:
''' Return tuple containing min/max time for given text representation. ''' ''' Return tuple containing min/max time for given text representation. '''
one_day = timedelta(days=1) one_day = timedelta(days=1)
one_second = timedelta(seconds=1) one_second = timedelta(seconds=1)
@ -146,9 +132,9 @@ def parse_time_range(value):
raise errors.ValidationError('Invalid date format: %r.' % value) raise errors.ValidationError('Invalid date format: %r.' % value)
def icase_unique(source): def icase_unique(source: List[str]) -> List[str]:
target = [] target = [] # type: List[str]
target_low = [] target_low = [] # type: List[str]
for source_item in source: for source_item in source:
if source_item.lower() not in target_low: if source_item.lower() not in target_low:
target.append(source_item) target.append(source_item)
@ -156,7 +142,7 @@ def icase_unique(source):
return target return target
def value_exceeds_column_size(value, column): def value_exceeds_column_size(value: Optional[str], column: Any) -> bool:
if not value: if not value:
return False return False
max_length = column.property.columns[0].type.length max_length = column.property.columns[0].type.length
@ -165,6 +151,6 @@ def value_exceeds_column_size(value, column):
return len(value) > max_length return len(value) > max_length
def chunks(source_list, part_size): def chunks(source_list: List[Any], part_size: int) -> Generator:
for i in range(0, len(source_list), part_size): for i in range(0, len(source_list), part_size):
yield source_list[i:i + part_size] yield source_list[i:i + part_size]

View file

@ -1,8 +1,11 @@
from szurubooru import errors from szurubooru import errors, rest, model
def verify_version(entity, context, field_name='version'): def verify_version(
actual_version = context.get_param_as_int(field_name, required=True) entity: model.Base,
context: rest.Context,
field_name: str='version') -> None:
actual_version = context.get_param_as_int(field_name)
expected_version = entity.version expected_version = entity.version
if actual_version != expected_version: if actual_version != expected_version:
raise errors.IntegrityError( raise errors.IntegrityError(
@ -10,5 +13,5 @@ def verify_version(entity, context, field_name='version'):
'Please try again.') 'Please try again.')
def bump_version(entity): def bump_version(entity: model.Base) -> None:
entity.version = entity.version + 1 entity.version = entity.version + 1

View file

@ -1,11 +1,11 @@
import base64 import base64
from szurubooru import db, errors from typing import Optional
from szurubooru import db, model, errors, rest
from szurubooru.func import auth, users from szurubooru.func import auth, users
from szurubooru.rest import middleware
from szurubooru.rest.errors import HttpBadRequest from szurubooru.rest.errors import HttpBadRequest
def _authenticate(username, password): def _authenticate(username: str, password: str) -> model.User:
''' Try to authenticate user. Throw AuthError for invalid users. ''' ''' Try to authenticate user. Throw AuthError for invalid users. '''
user = users.get_user_by_name(username) user = users.get_user_by_name(username)
if not auth.is_valid_password(user, password): if not auth.is_valid_password(user, password):
@ -13,16 +13,9 @@ def _authenticate(username, password):
return user return user
def _create_anonymous_user(): def _get_user(ctx: rest.Context) -> Optional[model.User]:
user = db.User()
user.name = None
user.rank = 'anonymous'
return user
def _get_user(ctx):
if not ctx.has_header('Authorization'): if not ctx.has_header('Authorization'):
return _create_anonymous_user() return None
try: try:
auth_type, credentials = ctx.get_header('Authorization').split(' ', 1) auth_type, credentials = ctx.get_header('Authorization').split(' ', 1)
@ -41,10 +34,12 @@ def _get_user(ctx):
msg.format(ctx.get_header('Authorization'), str(err))) msg.format(ctx.get_header('Authorization'), str(err)))
@middleware.pre_hook @rest.middleware.pre_hook
def process_request(ctx): def process_request(ctx: rest.Context) -> None:
''' Bind the user to request. Update last login time if needed. ''' ''' Bind the user to request. Update last login time if needed. '''
ctx.user = _get_user(ctx) auth_user = _get_user(ctx)
if ctx.get_param_as_bool('bump-login') and ctx.user.user_id: if auth_user:
ctx.user = auth_user
if ctx.get_param_as_bool('bump-login', default=False) and ctx.user.user_id:
users.bump_user_login_time(ctx.user) users.bump_user_login_time(ctx.user)
ctx.session.commit() ctx.session.commit()

View file

@ -1,8 +1,9 @@
from szurubooru import rest
from szurubooru.func import cache from szurubooru.func import cache
from szurubooru.rest import middleware from szurubooru.rest import middleware
@middleware.pre_hook @middleware.pre_hook
def process_request(ctx): def process_request(ctx: rest.Context) -> None:
if ctx.method != 'GET': if ctx.method != 'GET':
cache.purge() cache.purge()

View file

@ -1,5 +1,5 @@
import logging import logging
from szurubooru import db from szurubooru import db, rest
from szurubooru.rest import middleware from szurubooru.rest import middleware
@ -7,12 +7,12 @@ logger = logging.getLogger(__name__)
@middleware.pre_hook @middleware.pre_hook
def process_request(_ctx): def process_request(_ctx: rest.Context) -> None:
db.reset_query_count() db.reset_query_count()
@middleware.post_hook @middleware.post_hook
def process_response(ctx): def process_response(ctx: rest.Context) -> None:
logger.info( logger.info(
'%s %s (user=%s, queries=%d)', '%s %s (user=%s, queries=%d)',
ctx.method, ctx.method,

View file

@ -2,7 +2,7 @@ import os
import sys import sys
import alembic import alembic
import sqlalchemy import sqlalchemy as sa
import logging.config import logging.config
# make szurubooru module importable # make szurubooru module importable
@ -48,7 +48,7 @@ def run_migrations_online():
In this scenario we need to create an Engine In this scenario we need to create an Engine
and associate a connection with the context. and associate a connection with the context.
''' '''
connectable = sqlalchemy.engine_from_config( connectable = sa.engine_from_config(
alembic_config.get_section(alembic_config.config_ini_section), alembic_config.get_section(alembic_config.config_ini_section),
prefix='sqlalchemy.', prefix='sqlalchemy.',
poolclass=sqlalchemy.pool.NullPool) poolclass=sqlalchemy.pool.NullPool)

View file

@ -0,0 +1,15 @@
from szurubooru.model.base import Base
from szurubooru.model.user import User
from szurubooru.model.tag_category import TagCategory
from szurubooru.model.tag import Tag, TagName, TagSuggestion, TagImplication
from szurubooru.model.post import (
Post,
PostTag,
PostRelation,
PostFavorite,
PostScore,
PostNote,
PostFeature)
from szurubooru.model.comment import Comment, CommentScore
from szurubooru.model.snapshot import Snapshot
import szurubooru.model.util

View file

@ -1,7 +1,8 @@
from sqlalchemy import Column, Integer, DateTime, UnicodeText, ForeignKey from sqlalchemy import Column, Integer, DateTime, UnicodeText, ForeignKey
from sqlalchemy.orm import relationship, backref from sqlalchemy.orm import relationship, backref
from sqlalchemy.sql.expression import func from sqlalchemy.sql.expression import func
from szurubooru.db.base import Base from szurubooru.db import get_session
from szurubooru.model.base import Base
class CommentScore(Base): class CommentScore(Base):
@ -48,12 +49,12 @@ class Comment(Base):
'CommentScore', cascade='all, delete-orphan', lazy='joined') 'CommentScore', cascade='all, delete-orphan', lazy='joined')
@property @property
def score(self): def score(self) -> int:
from szurubooru.db import session return (
return session \ get_session()
.query(func.sum(CommentScore.score)) \ .query(func.sum(CommentScore.score))
.filter(CommentScore.comment_id == self.comment_id) \ .filter(CommentScore.comment_id == self.comment_id)
.one()[0] or 0 .one()[0] or 0)
__mapper_args__ = { __mapper_args__ = {
'version_id_col': version, 'version_id_col': version,

View file

@ -3,8 +3,8 @@ from sqlalchemy import (
Column, Integer, DateTime, Unicode, UnicodeText, PickleType, ForeignKey) Column, Integer, DateTime, Unicode, UnicodeText, PickleType, ForeignKey)
from sqlalchemy.orm import ( from sqlalchemy.orm import (
relationship, column_property, object_session, backref) relationship, column_property, object_session, backref)
from szurubooru.db.base import Base from szurubooru.model.base import Base
from szurubooru.db.comment import Comment from szurubooru.model.comment import Comment
class PostFeature(Base): class PostFeature(Base):
@ -17,10 +17,9 @@ class PostFeature(Base):
'user_id', Integer, ForeignKey('user.id'), nullable=False, index=True) 'user_id', Integer, ForeignKey('user.id'), nullable=False, index=True)
time = Column('time', DateTime, nullable=False) time = Column('time', DateTime, nullable=False)
post = relationship('Post') post = relationship('Post') # type: Post
user = relationship( user = relationship(
'User', 'User', backref=backref('post_features', cascade='all, delete-orphan'))
backref=backref('post_features', cascade='all, delete-orphan'))
class PostScore(Base): class PostScore(Base):
@ -104,7 +103,7 @@ class PostRelation(Base):
nullable=False, nullable=False,
index=True) index=True)
def __init__(self, parent_id, child_id): def __init__(self, parent_id: int, child_id: int) -> None:
self.parent_id = parent_id self.parent_id = parent_id
self.child_id = child_id self.child_id = child_id
@ -127,7 +126,7 @@ class PostTag(Base):
nullable=False, nullable=False,
index=True) index=True)
def __init__(self, post_id, tag_id): def __init__(self, post_id: int, tag_id: int) -> None:
self.post_id = post_id self.post_id = post_id
self.tag_id = tag_id self.tag_id = tag_id
@ -197,7 +196,7 @@ class Post(Base):
canvas_area = column_property(canvas_width * canvas_height) canvas_area = column_property(canvas_width * canvas_height)
@property @property
def is_featured(self): def is_featured(self) -> bool:
featured_post = object_session(self) \ featured_post = object_session(self) \
.query(PostFeature) \ .query(PostFeature) \
.order_by(PostFeature.time.desc()) \ .order_by(PostFeature.time.desc()) \

View file

@ -1,7 +1,7 @@
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from sqlalchemy import ( from sqlalchemy import (
Column, Integer, DateTime, Unicode, PickleType, ForeignKey) Column, Integer, DateTime, Unicode, PickleType, ForeignKey)
from szurubooru.db.base import Base from szurubooru.model.base import Base
class Snapshot(Base): class Snapshot(Base):

View file

@ -2,8 +2,8 @@ from sqlalchemy import (
Column, Integer, DateTime, Unicode, UnicodeText, ForeignKey) Column, Integer, DateTime, Unicode, UnicodeText, ForeignKey)
from sqlalchemy.orm import relationship, column_property from sqlalchemy.orm import relationship, column_property
from sqlalchemy.sql.expression import func, select from sqlalchemy.sql.expression import func, select
from szurubooru.db.base import Base from szurubooru.model.base import Base
from szurubooru.db.post import PostTag from szurubooru.model.post import PostTag
class TagSuggestion(Base): class TagSuggestion(Base):
@ -24,7 +24,7 @@ class TagSuggestion(Base):
primary_key=True, primary_key=True,
index=True) index=True)
def __init__(self, parent_id, child_id): def __init__(self, parent_id: int, child_id: int) -> None:
self.parent_id = parent_id self.parent_id = parent_id
self.child_id = child_id self.child_id = child_id
@ -47,7 +47,7 @@ class TagImplication(Base):
primary_key=True, primary_key=True,
index=True) index=True)
def __init__(self, parent_id, child_id): def __init__(self, parent_id: int, child_id: int) -> None:
self.parent_id = parent_id self.parent_id = parent_id
self.child_id = child_id self.child_id = child_id
@ -61,7 +61,7 @@ class TagName(Base):
name = Column('name', Unicode(64), nullable=False, unique=True) name = Column('name', Unicode(64), nullable=False, unique=True)
order = Column('ord', Integer, nullable=False, index=True) order = Column('ord', Integer, nullable=False, index=True)
def __init__(self, name, order): def __init__(self, name: str, order: int) -> None:
self.name = name self.name = name
self.order = order self.order = order

View file

@ -1,8 +1,9 @@
from typing import Optional
from sqlalchemy import Column, Integer, Unicode, Boolean, table from sqlalchemy import Column, Integer, Unicode, Boolean, table
from sqlalchemy.orm import column_property from sqlalchemy.orm import column_property
from sqlalchemy.sql.expression import func, select from sqlalchemy.sql.expression import func, select
from szurubooru.db.base import Base from szurubooru.model.base import Base
from szurubooru.db.tag import Tag from szurubooru.model.tag import Tag
class TagCategory(Base): class TagCategory(Base):
@ -14,7 +15,7 @@ class TagCategory(Base):
color = Column('color', Unicode(32), nullable=False, default='#000000') color = Column('color', Unicode(32), nullable=False, default='#000000')
default = Column('default', Boolean, nullable=False, default=False) default = Column('default', Boolean, nullable=False, default=False)
def __init__(self, name=None): def __init__(self, name: Optional[str]=None) -> None:
self.name = name self.name = name
tag_count = column_property( tag_count = column_property(

View file

@ -1,9 +1,7 @@
from sqlalchemy import Column, Integer, Unicode, DateTime import sqlalchemy as sa
from sqlalchemy.orm import relationship from szurubooru.model.base import Base
from sqlalchemy.sql.expression import func from szurubooru.model.post import Post, PostScore, PostFavorite
from szurubooru.db.base import Base from szurubooru.model.comment import Comment
from szurubooru.db.post import Post, PostScore, PostFavorite
from szurubooru.db.comment import Comment
class User(Base): class User(Base):
@ -20,63 +18,64 @@ class User(Base):
RANK_ADMINISTRATOR = 'administrator' RANK_ADMINISTRATOR = 'administrator'
RANK_NOBODY = 'nobody' # unattainable, used for privileges RANK_NOBODY = 'nobody' # unattainable, used for privileges
user_id = Column('id', Integer, primary_key=True) user_id = sa.Column('id', sa.Integer, primary_key=True)
creation_time = Column('creation_time', DateTime, nullable=False) creation_time = sa.Column('creation_time', sa.DateTime, nullable=False)
last_login_time = Column('last_login_time', DateTime) last_login_time = sa.Column('last_login_time', sa.DateTime)
version = Column('version', Integer, default=1, nullable=False) version = sa.Column('version', sa.Integer, default=1, nullable=False)
name = Column('name', Unicode(50), nullable=False, unique=True) name = sa.Column('name', sa.Unicode(50), nullable=False, unique=True)
password_hash = Column('password_hash', Unicode(64), nullable=False) password_hash = sa.Column('password_hash', sa.Unicode(64), nullable=False)
password_salt = Column('password_salt', Unicode(32)) password_salt = sa.Column('password_salt', sa.Unicode(32))
email = Column('email', Unicode(64), nullable=True) email = sa.Column('email', sa.Unicode(64), nullable=True)
rank = Column('rank', Unicode(32), nullable=False) rank = sa.Column('rank', sa.Unicode(32), nullable=False)
avatar_style = Column( avatar_style = sa.Column(
'avatar_style', Unicode(32), nullable=False, default=AVATAR_GRAVATAR) 'avatar_style', sa.Unicode(32), nullable=False,
default=AVATAR_GRAVATAR)
comments = relationship('Comment') comments = sa.orm.relationship('Comment')
@property @property
def post_count(self): def post_count(self) -> int:
from szurubooru.db import session from szurubooru.db import session
return ( return (
session session
.query(func.sum(1)) .query(sa.sql.expression.func.sum(1))
.filter(Post.user_id == self.user_id) .filter(Post.user_id == self.user_id)
.one()[0] or 0) .one()[0] or 0)
@property @property
def comment_count(self): def comment_count(self) -> int:
from szurubooru.db import session from szurubooru.db import session
return ( return (
session session
.query(func.sum(1)) .query(sa.sql.expression.func.sum(1))
.filter(Comment.user_id == self.user_id) .filter(Comment.user_id == self.user_id)
.one()[0] or 0) .one()[0] or 0)
@property @property
def favorite_post_count(self): def favorite_post_count(self) -> int:
from szurubooru.db import session from szurubooru.db import session
return ( return (
session session
.query(func.sum(1)) .query(sa.sql.expression.func.sum(1))
.filter(PostFavorite.user_id == self.user_id) .filter(PostFavorite.user_id == self.user_id)
.one()[0] or 0) .one()[0] or 0)
@property @property
def liked_post_count(self): def liked_post_count(self) -> int:
from szurubooru.db import session from szurubooru.db import session
return ( return (
session session
.query(func.sum(1)) .query(sa.sql.expression.func.sum(1))
.filter(PostScore.user_id == self.user_id) .filter(PostScore.user_id == self.user_id)
.filter(PostScore.score == 1) .filter(PostScore.score == 1)
.one()[0] or 0) .one()[0] or 0)
@property @property
def disliked_post_count(self): def disliked_post_count(self) -> int:
from szurubooru.db import session from szurubooru.db import session
return ( return (
session session
.query(func.sum(1)) .query(sa.sql.expression.func.sum(1))
.filter(PostScore.user_id == self.user_id) .filter(PostScore.user_id == self.user_id)
.filter(PostScore.score == -1) .filter(PostScore.score == -1)
.one()[0] or 0) .one()[0] or 0)

View file

@ -0,0 +1,42 @@
from typing import Tuple, Any, Dict, Callable, Union, Optional
import sqlalchemy as sa
from szurubooru.model.base import Base
from szurubooru.model.user import User
def get_resource_info(entity: Base) -> Tuple[Any, Any, Union[str, int]]:
serializers = {
'tag': lambda tag: tag.first_name,
'tag_category': lambda category: category.name,
'comment': lambda comment: comment.comment_id,
'post': lambda post: post.post_id,
} # type: Dict[str, Callable[[Base], Any]]
resource_type = entity.__table__.name
assert resource_type in serializers
primary_key = sa.inspection.inspect(entity).identity # type: Any
assert primary_key is not None
assert len(primary_key) == 1
resource_name = serializers[resource_type](entity) # type: Union[str, int]
assert resource_name
resource_pkey = primary_key[0] # type: Any
assert resource_pkey
return (resource_type, resource_pkey, resource_name)
def get_aux_entity(
session: Any,
get_table_info: Callable[[Base], Tuple[Base, Callable[[Base], Any]]],
entity: Base,
user: User) -> Optional[Base]:
table, get_column = get_table_info(entity)
return (
session
.query(table)
.filter(get_column(table) == get_column(entity))
.filter(table.user_id == user.user_id)
.one_or_none())

View file

@ -1,2 +1,2 @@
from szurubooru.rest.app import application from szurubooru.rest.app import application
from szurubooru.rest.context import Context from szurubooru.rest.context import Context, Response

View file

@ -2,13 +2,14 @@ import urllib.parse
import cgi import cgi
import json import json
import re import re
from typing import Dict, Any, Callable, Tuple
from datetime import datetime from datetime import datetime
from szurubooru import db from szurubooru import db
from szurubooru.func import util from szurubooru.func import util
from szurubooru.rest import errors, middleware, routes, context from szurubooru.rest import errors, middleware, routes, context
def _json_serializer(obj): def _json_serializer(obj: Any) -> str:
''' JSON serializer for objects not serializable by default JSON code ''' ''' JSON serializer for objects not serializable by default JSON code '''
if isinstance(obj, datetime): if isinstance(obj, datetime):
serial = obj.isoformat('T') + 'Z' serial = obj.isoformat('T') + 'Z'
@ -16,12 +17,12 @@ def _json_serializer(obj):
raise TypeError('Type not serializable') raise TypeError('Type not serializable')
def _dump_json(obj): def _dump_json(obj: Any) -> str:
return json.dumps(obj, default=_json_serializer, indent=2) return json.dumps(obj, default=_json_serializer, indent=2)
def _get_headers(env): def _get_headers(env: Dict[str, Any]) -> Dict[str, str]:
headers = {} headers = {} # type: Dict[str, str]
for key, value in env.items(): for key, value in env.items():
if key.startswith('HTTP_'): if key.startswith('HTTP_'):
key = util.snake_case_to_upper_train_case(key[5:]) key = util.snake_case_to_upper_train_case(key[5:])
@ -29,7 +30,7 @@ def _get_headers(env):
return headers return headers
def _create_context(env): def _create_context(env: Dict[str, Any]) -> context.Context:
method = env['REQUEST_METHOD'] method = env['REQUEST_METHOD']
path = '/' + env['PATH_INFO'].lstrip('/') path = '/' + env['PATH_INFO'].lstrip('/')
headers = _get_headers(env) headers = _get_headers(env)
@ -64,7 +65,9 @@ def _create_context(env):
return context.Context(method, path, headers, params, files) return context.Context(method, path, headers, params, files)
def application(env, start_response): def application(
env: Dict[str, Any],
start_response: Callable[[str, Any], Any]) -> Tuple[bytes]:
try: try:
ctx = _create_context(env) ctx = _create_context(env)
if 'application/json' not in ctx.get_header('Accept'): if 'application/json' not in ctx.get_header('Accept'):
@ -106,9 +109,9 @@ def application(env, start_response):
return (_dump_json(response).encode('utf-8'),) return (_dump_json(response).encode('utf-8'),)
except Exception as ex: except Exception as ex:
for exception_type, handler in errors.error_handlers.items(): for exception_type, ex_handler in errors.error_handlers.items():
if isinstance(ex, exception_type): if isinstance(ex, exception_type):
handler(ex) ex_handler(ex)
raise raise
except errors.BaseHttpError as ex: except errors.BaseHttpError as ex:

View file

@ -1,111 +1,158 @@
from szurubooru import errors from typing import Any, Union, List, Dict, Optional, cast
from szurubooru import model, errors
from szurubooru.func import net, file_uploads from szurubooru.func import net, file_uploads
def _lower_first(source): MISSING = object()
return source[0].lower() + source[1:] Request = Dict[str, Any]
Response = Optional[Dict[str, Any]]
def _param_wrapper(func):
def wrapper(self, name, required=False, default=None, **kwargs):
# pylint: disable=protected-access
if name in self._params:
value = self._params[name]
try:
value = func(self, value, **kwargs)
except errors.InvalidParameterError as ex:
raise errors.InvalidParameterError(
'Parameter %r is invalid: %s' % (
name, _lower_first(str(ex))))
return value
if not required:
return default
raise errors.MissingRequiredParameterError(
'Required parameter %r is missing.' % name)
return wrapper
class Context: class Context:
def __init__(self, method, url, headers=None, params=None, files=None): def __init__(
self,
method: str,
url: str,
headers: Dict[str, str]=None,
params: Request=None,
files: Dict[str, bytes]=None) -> None:
self.method = method self.method = method
self.url = url self.url = url
self._headers = headers or {} self._headers = headers or {}
self._params = params or {} self._params = params or {}
self._files = files or {} self._files = files or {}
# provided by middleware self.user = model.User()
# self.session = None self.user.name = None
# self.user = None self.user.rank = 'anonymous'
def has_header(self, name): self.session = None # type: Any
def has_header(self, name: str) -> bool:
return name in self._headers return name in self._headers
def get_header(self, name): def get_header(self, name: str) -> str:
return self._headers.get(name, None) return self._headers.get(name, '')
def has_file(self, name, allow_tokens=True): def has_file(self, name: str, allow_tokens: bool=True) -> bool:
return ( return (
name in self._files or name in self._files or
name + 'Url' in self._params or name + 'Url' in self._params or
(allow_tokens and name + 'Token' in self._params)) (allow_tokens and name + 'Token' in self._params))
def get_file(self, name, required=False, allow_tokens=True): def get_file(
ret = None self,
if name in self._files: name: str,
ret = self._files[name] default: Union[object, bytes]=MISSING,
elif name + 'Url' in self._params: allow_tokens: bool=True) -> bytes:
ret = net.download(self._params[name + 'Url']) if name in self._files and self._files[name]:
elif allow_tokens and name + 'Token' in self._params: return self._files[name]
if name + 'Url' in self._params:
return net.download(self._params[name + 'Url'])
if allow_tokens and name + 'Token' in self._params:
ret = file_uploads.get(self._params[name + 'Token']) ret = file_uploads.get(self._params[name + 'Token'])
if required and not ret: if ret:
return ret
elif default is not MISSING:
raise errors.MissingOrExpiredRequiredFileError( raise errors.MissingOrExpiredRequiredFileError(
'Required file %r is missing or has expired.' % name) 'Required file %r is missing or has expired.' % name)
if required and not ret:
if default is not MISSING:
return cast(bytes, default)
raise errors.MissingRequiredFileError( raise errors.MissingRequiredFileError(
'Required file %r is missing.' % name) 'Required file %r is missing.' % name)
return ret
def has_param(self, name): def has_param(self, name: str) -> bool:
return name in self._params return name in self._params
@_param_wrapper def get_param_as_list(
def get_param_as_list(self, value): self,
if not isinstance(value, list): name: str,
default: Union[object, List[Any]]=MISSING) -> List[Any]:
if name not in self._params:
if default is not MISSING:
return cast(List[Any], default)
raise errors.MissingRequiredParameterError(
'Required parameter %r is missing.' % name)
value = self._params[name]
if type(value) is str:
if ',' in value: if ',' in value:
return value.split(',') return value.split(',')
return [value] return [value]
if type(value) is list:
return value return value
raise errors.InvalidParameterError(
'Parameter %r must be a list.' % name)
@_param_wrapper def get_param_as_string(
def get_param_as_string(self, value): self,
if isinstance(value, list): name: str,
default: Union[object, str]=MISSING) -> str:
if name not in self._params:
if default is not MISSING:
return cast(str, default)
raise errors.MissingRequiredParameterError(
'Required parameter %r is missing.' % name)
value = self._params[name]
try: try:
value = ','.join(value) if value is None:
except TypeError: return ''
raise errors.InvalidParameterError('Expected simple string.') if type(value) is list:
return ','.join(value)
if type(value) is int or type(value) is float:
return str(value)
if type(value) is str:
return value return value
except TypeError:
pass
raise errors.InvalidParameterError(
'Parameter %r must be a string value.' % name)
@_param_wrapper def get_param_as_int(
def get_param_as_int(self, value, min=None, max=None): self,
name: str,
default: Union[object, int]=MISSING,
min: Optional[int]=None,
max: Optional[int]=None) -> int:
if name not in self._params:
if default is not MISSING:
return cast(int, default)
raise errors.MissingRequiredParameterError(
'Required parameter %r is missing.' % name)
value = self._params[name]
try: try:
value = int(value) value = int(value)
except (ValueError, TypeError):
raise errors.InvalidParameterError(
'The value must be an integer.')
if min is not None and value < min: if min is not None and value < min:
raise errors.InvalidParameterError( raise errors.InvalidParameterError(
'The value must be at least %r.' % min) 'Parameter %r must be at least %r.' % (name, min))
if max is not None and value > max: if max is not None and value > max:
raise errors.InvalidParameterError( raise errors.InvalidParameterError(
'The value may not exceed %r.' % max) 'Parameter %r may not exceed %r.' % (name, max))
return value return value
except (ValueError, TypeError):
pass
raise errors.InvalidParameterError(
'Parameter %r must be an integer value.' % name)
@_param_wrapper def get_param_as_bool(
def get_param_as_bool(self, value): self,
name: str,
default: Union[object, bool]=MISSING) -> bool:
if name not in self._params:
if default is not MISSING:
return cast(bool, default)
raise errors.MissingRequiredParameterError(
'Required parameter %r is missing.' % name)
value = self._params[name]
try:
value = str(value).lower() value = str(value).lower()
except TypeError:
pass
if value in ['1', 'y', 'yes', 'yeah', 'yep', 'yup', 't', 'true']: if value in ['1', 'y', 'yes', 'yeah', 'yep', 'yup', 't', 'true']:
return True return True
if value in ['0', 'n', 'no', 'nope', 'f', 'false']: if value in ['0', 'n', 'no', 'nope', 'f', 'false']:
return False return False
raise errors.InvalidParameterError( raise errors.InvalidParameterError(
'The value must be a boolean value.') 'Parameter %r must be a boolean value.' % name)

View file

@ -1,11 +1,19 @@
from typing import Callable, Type, Dict
error_handlers = {} # pylint: disable=invalid-name error_handlers = {} # pylint: disable=invalid-name
class BaseHttpError(RuntimeError): class BaseHttpError(RuntimeError):
code = None code = -1
reason = None reason = ''
def __init__(self, name, description, title=None, extra_fields=None): def __init__(
self,
name: str,
description: str,
title: str=None,
extra_fields: Dict[str, str]=None) -> None:
super().__init__() super().__init__()
# error name for programmers # error name for programmers
self.name = name self.name = name
@ -52,5 +60,7 @@ class HttpInternalServerError(BaseHttpError):
reason = 'Internal Server Error' reason = 'Internal Server Error'
def handle(exception_type, handler): def handle(
exception_type: Type[Exception],
handler: Callable[[Exception], None]) -> None:
error_handlers[exception_type] = handler error_handlers[exception_type] = handler

View file

@ -1,11 +1,15 @@
from typing import Callable
from szurubooru.rest.context import Context
# pylint: disable=invalid-name # pylint: disable=invalid-name
pre_hooks = [] pre_hooks = [] # type: List[Callable[[Context], None]]
post_hooks = [] post_hooks = [] # type: List[Callable[[Context], None]]
def pre_hook(handler): def pre_hook(handler: Callable) -> None:
pre_hooks.append(handler) pre_hooks.append(handler)
def post_hook(handler): def post_hook(handler: Callable) -> None:
post_hooks.insert(0, handler) post_hooks.insert(0, handler)

View file

@ -1,32 +1,36 @@
from typing import Callable, Dict, Any
from collections import defaultdict from collections import defaultdict
from szurubooru.rest.context import Context, Response
routes = defaultdict(dict) # pylint: disable=invalid-name # pylint: disable=invalid-name
RouteHandler = Callable[[Context, Dict[str, str]], Response]
routes = defaultdict(dict) # type: Dict[str, Dict[str, RouteHandler]]
def get(url): def get(url: str) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler): def wrapper(handler: RouteHandler) -> RouteHandler:
routes[url]['GET'] = handler routes[url]['GET'] = handler
return handler return handler
return wrapper return wrapper
def put(url): def put(url: str) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler): def wrapper(handler: RouteHandler) -> RouteHandler:
routes[url]['PUT'] = handler routes[url]['PUT'] = handler
return handler return handler
return wrapper return wrapper
def post(url): def post(url: str) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler): def wrapper(handler: RouteHandler) -> RouteHandler:
routes[url]['POST'] = handler routes[url]['POST'] = handler
return handler return handler
return wrapper return wrapper
def delete(url): def delete(url: str) -> Callable[[RouteHandler], RouteHandler]:
def wrapper(handler): def wrapper(handler: RouteHandler) -> RouteHandler:
routes[url]['DELETE'] = handler routes[url]['DELETE'] = handler
return handler return handler
return wrapper return wrapper

View file

@ -1,38 +1,47 @@
from szurubooru.search import tokens from typing import Optional, Tuple, Dict, Callable
from szurubooru.search import tokens, criteria
from szurubooru.search.query import SearchQuery
from szurubooru.search.typing import SaColumn, SaQuery
Filter = Callable[[SaQuery, Optional[criteria.BaseCriterion], bool], SaQuery]
class BaseSearchConfig: class BaseSearchConfig:
SORT_NONE = tokens.SortToken.SORT_NONE
SORT_ASC = tokens.SortToken.SORT_ASC SORT_ASC = tokens.SortToken.SORT_ASC
SORT_DESC = tokens.SortToken.SORT_DESC SORT_DESC = tokens.SortToken.SORT_DESC
def on_search_query_parsed(self, search_query): def on_search_query_parsed(self, search_query: SearchQuery) -> None:
pass pass
def create_filter_query(self, _disable_eager_loads): def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery:
raise NotImplementedError() raise NotImplementedError()
def create_count_query(self, disable_eager_loads): def create_count_query(self, disable_eager_loads: bool) -> SaQuery:
raise NotImplementedError() raise NotImplementedError()
def create_around_query(self): def create_around_query(self) -> SaQuery:
raise NotImplementedError() raise NotImplementedError()
def finalize_query(self, query: SaQuery) -> SaQuery:
return query
@property @property
def id_column(self): def id_column(self) -> SaColumn:
return None return None
@property @property
def anonymous_filter(self): def anonymous_filter(self) -> Optional[Filter]:
return None return None
@property @property
def special_filters(self): def special_filters(self) -> Dict[str, Filter]:
return {} return {}
@property @property
def named_filters(self): def named_filters(self) -> Dict[str, Filter]:
return {} return {}
@property @property
def sort_columns(self): def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
return {} return {}

View file

@ -1,59 +1,62 @@
from sqlalchemy.sql.expression import func from typing import Tuple, Dict
from szurubooru import db import sqlalchemy as sa
from szurubooru import db, model
from szurubooru.search.typing import SaColumn, SaQuery
from szurubooru.search.configs import util as search_util from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import BaseSearchConfig from szurubooru.search.configs.base_search_config import (
BaseSearchConfig, Filter)
class CommentSearchConfig(BaseSearchConfig): class CommentSearchConfig(BaseSearchConfig):
def create_filter_query(self, _disable_eager_loads): def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(db.Comment).join(db.User) return db.session.query(model.Comment).join(model.User)
def create_count_query(self, disable_eager_loads): def create_count_query(self, disable_eager_loads: bool) -> SaQuery:
return self.create_filter_query(disable_eager_loads) return self.create_filter_query(disable_eager_loads)
def create_around_query(self): def create_around_query(self) -> SaQuery:
raise NotImplementedError() raise NotImplementedError()
def finalize_query(self, query): def finalize_query(self, query: SaQuery) -> SaQuery:
return query.order_by(db.Comment.creation_time.desc()) return query.order_by(model.Comment.creation_time.desc())
@property @property
def anonymous_filter(self): def anonymous_filter(self) -> SaQuery:
return search_util.create_str_filter(db.Comment.text) return search_util.create_str_filter(model.Comment.text)
@property @property
def named_filters(self): def named_filters(self) -> Dict[str, Filter]:
return { return {
'id': search_util.create_num_filter(db.Comment.comment_id), 'id': search_util.create_num_filter(model.Comment.comment_id),
'post': search_util.create_num_filter(db.Comment.post_id), 'post': search_util.create_num_filter(model.Comment.post_id),
'user': search_util.create_str_filter(db.User.name), 'user': search_util.create_str_filter(model.User.name),
'author': search_util.create_str_filter(db.User.name), 'author': search_util.create_str_filter(model.User.name),
'text': search_util.create_str_filter(db.Comment.text), 'text': search_util.create_str_filter(model.Comment.text),
'creation-date': 'creation-date':
search_util.create_date_filter(db.Comment.creation_time), search_util.create_date_filter(model.Comment.creation_time),
'creation-time': 'creation-time':
search_util.create_date_filter(db.Comment.creation_time), search_util.create_date_filter(model.Comment.creation_time),
'last-edit-date': 'last-edit-date':
search_util.create_date_filter(db.Comment.last_edit_time), search_util.create_date_filter(model.Comment.last_edit_time),
'last-edit-time': 'last-edit-time':
search_util.create_date_filter(db.Comment.last_edit_time), search_util.create_date_filter(model.Comment.last_edit_time),
'edit-date': 'edit-date':
search_util.create_date_filter(db.Comment.last_edit_time), search_util.create_date_filter(model.Comment.last_edit_time),
'edit-time': 'edit-time':
search_util.create_date_filter(db.Comment.last_edit_time), search_util.create_date_filter(model.Comment.last_edit_time),
} }
@property @property
def sort_columns(self): def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
return { return {
'random': (func.random(), None), 'random': (sa.sql.expression.func.random(), self.SORT_NONE),
'user': (db.User.name, self.SORT_ASC), 'user': (model.User.name, self.SORT_ASC),
'author': (db.User.name, self.SORT_ASC), 'author': (model.User.name, self.SORT_ASC),
'post': (db.Comment.post_id, self.SORT_DESC), 'post': (model.Comment.post_id, self.SORT_DESC),
'creation-date': (db.Comment.creation_time, self.SORT_DESC), 'creation-date': (model.Comment.creation_time, self.SORT_DESC),
'creation-time': (db.Comment.creation_time, self.SORT_DESC), 'creation-time': (model.Comment.creation_time, self.SORT_DESC),
'last-edit-date': (db.Comment.last_edit_time, self.SORT_DESC), 'last-edit-date': (model.Comment.last_edit_time, self.SORT_DESC),
'last-edit-time': (db.Comment.last_edit_time, self.SORT_DESC), 'last-edit-time': (model.Comment.last_edit_time, self.SORT_DESC),
'edit-date': (db.Comment.last_edit_time, self.SORT_DESC), 'edit-date': (model.Comment.last_edit_time, self.SORT_DESC),
'edit-time': (db.Comment.last_edit_time, self.SORT_DESC), 'edit-time': (model.Comment.last_edit_time, self.SORT_DESC),
} }

View file

@ -1,13 +1,16 @@
from sqlalchemy.orm import subqueryload, lazyload, defer, aliased from typing import Any, Optional, Tuple, Dict
from sqlalchemy.sql.expression import func import sqlalchemy as sa
from szurubooru import db, errors from szurubooru import db, model, errors
from szurubooru.func import util from szurubooru.func import util
from szurubooru.search import criteria, tokens from szurubooru.search import criteria, tokens
from szurubooru.search.typing import SaColumn, SaQuery
from szurubooru.search.query import SearchQuery
from szurubooru.search.configs import util as search_util from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import BaseSearchConfig from szurubooru.search.configs.base_search_config import (
BaseSearchConfig, Filter)
def _enum_transformer(available_values, value): def _enum_transformer(available_values: Dict[str, Any], value: str) -> str:
try: try:
return available_values[value.lower()] return available_values[value.lower()]
except KeyError: except KeyError:
@ -16,71 +19,82 @@ def _enum_transformer(available_values, value):
value, list(sorted(available_values.keys())))) value, list(sorted(available_values.keys()))))
def _type_transformer(value): def _type_transformer(value: str) -> str:
available_values = { available_values = {
'image': db.Post.TYPE_IMAGE, 'image': model.Post.TYPE_IMAGE,
'animation': db.Post.TYPE_ANIMATION, 'animation': model.Post.TYPE_ANIMATION,
'animated': db.Post.TYPE_ANIMATION, 'animated': model.Post.TYPE_ANIMATION,
'anim': db.Post.TYPE_ANIMATION, 'anim': model.Post.TYPE_ANIMATION,
'gif': db.Post.TYPE_ANIMATION, 'gif': model.Post.TYPE_ANIMATION,
'video': db.Post.TYPE_VIDEO, 'video': model.Post.TYPE_VIDEO,
'webm': db.Post.TYPE_VIDEO, 'webm': model.Post.TYPE_VIDEO,
'flash': db.Post.TYPE_FLASH, 'flash': model.Post.TYPE_FLASH,
'swf': db.Post.TYPE_FLASH, 'swf': model.Post.TYPE_FLASH,
} }
return _enum_transformer(available_values, value) return _enum_transformer(available_values, value)
def _safety_transformer(value): def _safety_transformer(value: str) -> str:
available_values = { available_values = {
'safe': db.Post.SAFETY_SAFE, 'safe': model.Post.SAFETY_SAFE,
'sketchy': db.Post.SAFETY_SKETCHY, 'sketchy': model.Post.SAFETY_SKETCHY,
'questionable': db.Post.SAFETY_SKETCHY, 'questionable': model.Post.SAFETY_SKETCHY,
'unsafe': db.Post.SAFETY_UNSAFE, 'unsafe': model.Post.SAFETY_UNSAFE,
} }
return _enum_transformer(available_values, value) return _enum_transformer(available_values, value)
def _create_score_filter(score): def _create_score_filter(score: int) -> Filter:
def wrapper(query, criterion, negated): def wrapper(
query: SaQuery,
criterion: Optional[criteria.BaseCriterion],
negated: bool) -> SaQuery:
assert criterion
if not getattr(criterion, 'internal', False): if not getattr(criterion, 'internal', False):
raise errors.SearchError( raise errors.SearchError(
'Votes cannot be seen publicly. Did you mean %r?' 'Votes cannot be seen publicly. Did you mean %r?'
% 'special:liked') % 'special:liked')
user_alias = aliased(db.User) user_alias = sa.orm.aliased(model.User)
score_alias = aliased(db.PostScore) score_alias = sa.orm.aliased(model.PostScore)
expr = score_alias.score == score expr = score_alias.score == score
expr = expr & search_util.apply_str_criterion_to_column( expr = expr & search_util.apply_str_criterion_to_column(
user_alias.name, criterion) user_alias.name, criterion)
if negated: if negated:
expr = ~expr expr = ~expr
ret = query \ ret = query \
.join(score_alias, score_alias.post_id == db.Post.post_id) \ .join(score_alias, score_alias.post_id == model.Post.post_id) \
.join(user_alias, user_alias.user_id == score_alias.user_id) \ .join(user_alias, user_alias.user_id == score_alias.user_id) \
.filter(expr) .filter(expr)
return ret return ret
return wrapper return wrapper
def _create_user_filter(): def _create_user_filter() -> Filter:
def wrapper(query, criterion, negated): def wrapper(
query: SaQuery,
criterion: Optional[criteria.BaseCriterion],
negated: bool) -> SaQuery:
assert criterion
if isinstance(criterion, criteria.PlainCriterion) \ if isinstance(criterion, criteria.PlainCriterion) \
and not criterion.value: and not criterion.value:
# pylint: disable=singleton-comparison # pylint: disable=singleton-comparison
expr = db.Post.user_id == None expr = model.Post.user_id == None
if negated: if negated:
expr = ~expr expr = ~expr
return query.filter(expr) return query.filter(expr)
return search_util.create_subquery_filter( return search_util.create_subquery_filter(
db.Post.user_id, model.Post.user_id,
db.User.user_id, model.User.user_id,
db.User.name, model.User.name,
search_util.create_str_filter)(query, criterion, negated) search_util.create_str_filter)(query, criterion, negated)
return wrapper return wrapper
class PostSearchConfig(BaseSearchConfig): class PostSearchConfig(BaseSearchConfig):
def on_search_query_parsed(self, search_query): def __init__(self) -> None:
self.user = None # type: Optional[model.User]
def on_search_query_parsed(self, search_query: SearchQuery) -> SaQuery:
new_special_tokens = [] new_special_tokens = []
for token in search_query.special_tokens: for token in search_query.special_tokens:
if token.value in ('fav', 'liked', 'disliked'): if token.value in ('fav', 'liked', 'disliked'):
@ -91,7 +105,7 @@ class PostSearchConfig(BaseSearchConfig):
criterion = criteria.PlainCriterion( criterion = criteria.PlainCriterion(
original_text=self.user.name, original_text=self.user.name,
value=self.user.name) value=self.user.name)
criterion.internal = True setattr(criterion, 'internal', True)
search_query.named_tokens.append( search_query.named_tokens.append(
tokens.NamedToken( tokens.NamedToken(
name=token.value, name=token.value,
@ -101,160 +115,324 @@ class PostSearchConfig(BaseSearchConfig):
new_special_tokens.append(token) new_special_tokens.append(token)
search_query.special_tokens = new_special_tokens search_query.special_tokens = new_special_tokens
def create_around_query(self): def create_around_query(self) -> SaQuery:
return db.session.query(db.Post).options(lazyload('*')) return db.session.query(model.Post).options(sa.orm.lazyload('*'))
def create_filter_query(self, disable_eager_loads): def create_filter_query(self, disable_eager_loads: bool) -> SaQuery:
strategy = lazyload if disable_eager_loads else subqueryload strategy = (
return db.session.query(db.Post) \ sa.orm.lazyload
if disable_eager_loads
else sa.orm.subqueryload)
return db.session.query(model.Post) \
.options( .options(
lazyload('*'), sa.orm.lazyload('*'),
# use config optimized for official client # use config optimized for official client
# defer(db.Post.score), # sa.orm.defer(model.Post.score),
# defer(db.Post.favorite_count), # sa.orm.defer(model.Post.favorite_count),
# defer(db.Post.comment_count), # sa.orm.defer(model.Post.comment_count),
defer(db.Post.last_favorite_time), sa.orm.defer(model.Post.last_favorite_time),
defer(db.Post.feature_count), sa.orm.defer(model.Post.feature_count),
defer(db.Post.last_feature_time), sa.orm.defer(model.Post.last_feature_time),
defer(db.Post.last_comment_creation_time), sa.orm.defer(model.Post.last_comment_creation_time),
defer(db.Post.last_comment_edit_time), sa.orm.defer(model.Post.last_comment_edit_time),
defer(db.Post.note_count), sa.orm.defer(model.Post.note_count),
defer(db.Post.tag_count), sa.orm.defer(model.Post.tag_count),
strategy(db.Post.tags).subqueryload(db.Tag.names), strategy(model.Post.tags).subqueryload(model.Tag.names),
strategy(db.Post.tags).defer(db.Tag.post_count), strategy(model.Post.tags).defer(model.Tag.post_count),
strategy(db.Post.tags).lazyload(db.Tag.implications), strategy(model.Post.tags).lazyload(model.Tag.implications),
strategy(db.Post.tags).lazyload(db.Tag.suggestions)) strategy(model.Post.tags).lazyload(model.Tag.suggestions))
def create_count_query(self, _disable_eager_loads): def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(db.Post) return db.session.query(model.Post)
def finalize_query(self, query): def finalize_query(self, query: SaQuery) -> SaQuery:
return query.order_by(db.Post.post_id.desc()) return query.order_by(model.Post.post_id.desc())
@property @property
def id_column(self): def id_column(self) -> SaColumn:
return db.Post.post_id return model.Post.post_id
@property @property
def anonymous_filter(self): def anonymous_filter(self) -> Optional[Filter]:
return search_util.create_subquery_filter( return search_util.create_subquery_filter(
db.Post.post_id, model.Post.post_id,
db.PostTag.post_id, model.PostTag.post_id,
db.TagName.name, model.TagName.name,
search_util.create_str_filter, search_util.create_str_filter,
lambda subquery: subquery.join(db.Tag).join(db.TagName)) lambda subquery: subquery.join(model.Tag).join(model.TagName))
@property @property
def named_filters(self): def named_filters(self) -> Dict[str, Filter]:
return util.unalias_dict({ return util.unalias_dict([
'id': search_util.create_num_filter(db.Post.post_id), (
'tag': search_util.create_subquery_filter( ['id'],
db.Post.post_id, search_util.create_num_filter(model.Post.post_id)
db.PostTag.post_id, ),
db.TagName.name,
(
['tag'],
search_util.create_subquery_filter(
model.Post.post_id,
model.PostTag.post_id,
model.TagName.name,
search_util.create_str_filter, search_util.create_str_filter,
lambda subquery: subquery.join(db.Tag).join(db.TagName)), lambda subquery:
'score': search_util.create_num_filter(db.Post.score), subquery.join(model.Tag).join(model.TagName))
('uploader', 'upload', 'submit'): ),
_create_user_filter(),
'comment': search_util.create_subquery_filter( (
db.Post.post_id, ['score'],
db.Comment.post_id, search_util.create_num_filter(model.Post.score)
db.User.name, ),
(
['uploader', 'upload', 'submit'],
_create_user_filter()
),
(
['comment'],
search_util.create_subquery_filter(
model.Post.post_id,
model.Comment.post_id,
model.User.name,
search_util.create_str_filter, search_util.create_str_filter,
lambda subquery: subquery.join(db.User)), lambda subquery: subquery.join(model.User))
'fav': search_util.create_subquery_filter( ),
db.Post.post_id,
db.PostFavorite.post_id, (
db.User.name, ['fav'],
search_util.create_subquery_filter(
model.Post.post_id,
model.PostFavorite.post_id,
model.User.name,
search_util.create_str_filter, search_util.create_str_filter,
lambda subquery: subquery.join(db.User)), lambda subquery: subquery.join(model.User))
'liked': _create_score_filter(1), ),
'disliked': _create_score_filter(-1),
'tag-count': search_util.create_num_filter(db.Post.tag_count), (
'comment-count': ['liked'],
search_util.create_num_filter(db.Post.comment_count), _create_score_filter(1)
'fav-count': ),
search_util.create_num_filter(db.Post.favorite_count), (
'note-count': search_util.create_num_filter(db.Post.note_count), ['disliked'],
'relation-count': _create_score_filter(-1)
search_util.create_num_filter(db.Post.relation_count), ),
'feature-count':
search_util.create_num_filter(db.Post.feature_count), (
'type': ['tag-count'],
search_util.create_num_filter(model.Post.tag_count)
),
(
['comment-count'],
search_util.create_num_filter(model.Post.comment_count)
),
(
['fav-count'],
search_util.create_num_filter(model.Post.favorite_count)
),
(
['note-count'],
search_util.create_num_filter(model.Post.note_count)
),
(
['relation-count'],
search_util.create_num_filter(model.Post.relation_count)
),
(
['feature-count'],
search_util.create_num_filter(model.Post.feature_count)
),
(
['type'],
search_util.create_str_filter( search_util.create_str_filter(
db.Post.type, _type_transformer), model.Post.type, _type_transformer)
'content-checksum': search_util.create_str_filter( ),
db.Post.checksum),
'file-size': search_util.create_num_filter(db.Post.file_size), (
('image-width', 'width'): ['content-checksum'],
search_util.create_num_filter(db.Post.canvas_width), search_util.create_str_filter(model.Post.checksum)
('image-height', 'height'): ),
search_util.create_num_filter(db.Post.canvas_height),
('image-area', 'area'): (
search_util.create_num_filter(db.Post.canvas_area), ['file-size'],
('creation-date', 'creation-time', 'date', 'time'): search_util.create_num_filter(model.Post.file_size)
search_util.create_date_filter(db.Post.creation_time), ),
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'):
search_util.create_date_filter(db.Post.last_edit_time), (
('comment-date', 'comment-time'): ['image-width', 'width'],
search_util.create_num_filter(model.Post.canvas_width)
),
(
['image-height', 'height'],
search_util.create_num_filter(model.Post.canvas_height)
),
(
['image-area', 'area'],
search_util.create_num_filter(model.Post.canvas_area)
),
(
['creation-date', 'creation-time', 'date', 'time'],
search_util.create_date_filter(model.Post.creation_time)
),
(
['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'],
search_util.create_date_filter(model.Post.last_edit_time)
),
(
['comment-date', 'comment-time'],
search_util.create_date_filter( search_util.create_date_filter(
db.Post.last_comment_creation_time), model.Post.last_comment_creation_time)
('fav-date', 'fav-time'): ),
search_util.create_date_filter(db.Post.last_favorite_time),
('feature-date', 'feature-time'): (
search_util.create_date_filter(db.Post.last_feature_time), ['fav-date', 'fav-time'],
('safety', 'rating'): search_util.create_date_filter(model.Post.last_favorite_time)
),
(
['feature-date', 'feature-time'],
search_util.create_date_filter(model.Post.last_feature_time)
),
(
['safety', 'rating'],
search_util.create_str_filter( search_util.create_str_filter(
db.Post.safety, _safety_transformer), model.Post.safety, _safety_transformer)
}) ),
])
@property @property
def sort_columns(self): def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
return util.unalias_dict({ return util.unalias_dict([
'random': (func.random(), None), (
'id': (db.Post.post_id, self.SORT_DESC), ['random'],
'score': (db.Post.score, self.SORT_DESC), (sa.sql.expression.func.random(), self.SORT_NONE)
'tag-count': (db.Post.tag_count, self.SORT_DESC), ),
'comment-count': (db.Post.comment_count, self.SORT_DESC),
'fav-count': (db.Post.favorite_count, self.SORT_DESC), (
'note-count': (db.Post.note_count, self.SORT_DESC), ['id'],
'relation-count': (db.Post.relation_count, self.SORT_DESC), (model.Post.post_id, self.SORT_DESC)
'feature-count': (db.Post.feature_count, self.SORT_DESC), ),
'file-size': (db.Post.file_size, self.SORT_DESC),
('image-width', 'width'): (
(db.Post.canvas_width, self.SORT_DESC), ['score'],
('image-height', 'height'): (model.Post.score, self.SORT_DESC)
(db.Post.canvas_height, self.SORT_DESC), ),
('image-area', 'area'):
(db.Post.canvas_area, self.SORT_DESC), (
('creation-date', 'creation-time', 'date', 'time'): ['tag-count'],
(db.Post.creation_time, self.SORT_DESC), (model.Post.tag_count, self.SORT_DESC)
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): ),
(db.Post.last_edit_time, self.SORT_DESC),
('comment-date', 'comment-time'): (
(db.Post.last_comment_creation_time, self.SORT_DESC), ['comment-count'],
('fav-date', 'fav-time'): (model.Post.comment_count, self.SORT_DESC)
(db.Post.last_favorite_time, self.SORT_DESC), ),
('feature-date', 'feature-time'):
(db.Post.last_feature_time, self.SORT_DESC), (
}) ['fav-count'],
(model.Post.favorite_count, self.SORT_DESC)
),
(
['note-count'],
(model.Post.note_count, self.SORT_DESC)
),
(
['relation-count'],
(model.Post.relation_count, self.SORT_DESC)
),
(
['feature-count'],
(model.Post.feature_count, self.SORT_DESC)
),
(
['file-size'],
(model.Post.file_size, self.SORT_DESC)
),
(
['image-width', 'width'],
(model.Post.canvas_width, self.SORT_DESC)
),
(
['image-height', 'height'],
(model.Post.canvas_height, self.SORT_DESC)
),
(
['image-area', 'area'],
(model.Post.canvas_area, self.SORT_DESC)
),
(
['creation-date', 'creation-time', 'date', 'time'],
(model.Post.creation_time, self.SORT_DESC)
),
(
['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'],
(model.Post.last_edit_time, self.SORT_DESC)
),
(
['comment-date', 'comment-time'],
(model.Post.last_comment_creation_time, self.SORT_DESC)
),
(
['fav-date', 'fav-time'],
(model.Post.last_favorite_time, self.SORT_DESC)
),
(
['feature-date', 'feature-time'],
(model.Post.last_feature_time, self.SORT_DESC)
),
])
@property @property
def special_filters(self): def special_filters(self) -> Dict[str, Filter]:
return { return {
# handled by parsed # handled by parser
'fav': None, 'fav': self.noop_filter,
'liked': None, 'liked': self.noop_filter,
'disliked': None, 'disliked': self.noop_filter,
'tumbleweed': self.tumbleweed_filter, 'tumbleweed': self.tumbleweed_filter,
} }
def tumbleweed_filter(self, query, negated): def noop_filter(
expr = \ self,
(db.Post.comment_count == 0) \ query: SaQuery,
& (db.Post.favorite_count == 0) \ _criterion: Optional[criteria.BaseCriterion],
& (db.Post.score == 0) _negated: bool) -> SaQuery:
return query
def tumbleweed_filter(
self,
query: SaQuery,
_criterion: Optional[criteria.BaseCriterion],
negated: bool) -> SaQuery:
expr = (
(model.Post.comment_count == 0)
& (model.Post.favorite_count == 0)
& (model.Post.score == 0))
if negated: if negated:
expr = ~expr expr = ~expr
return query.filter(expr) return query.filter(expr)

View file

@ -1,28 +1,37 @@
from szurubooru import db from typing import Dict
from szurubooru import db, model
from szurubooru.search.typing import SaQuery
from szurubooru.search.configs import util as search_util from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import BaseSearchConfig from szurubooru.search.configs.base_search_config import (
BaseSearchConfig, Filter)
class SnapshotSearchConfig(BaseSearchConfig): class SnapshotSearchConfig(BaseSearchConfig):
def create_filter_query(self, _disable_eager_loads): def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(db.Snapshot) return db.session.query(model.Snapshot)
def create_count_query(self, _disable_eager_loads): def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(db.Snapshot) return db.session.query(model.Snapshot)
def create_around_query(self): def create_around_query(self) -> SaQuery:
raise NotImplementedError() raise NotImplementedError()
def finalize_query(self, query): def finalize_query(self, query: SaQuery) -> SaQuery:
return query.order_by(db.Snapshot.creation_time.desc()) return query.order_by(model.Snapshot.creation_time.desc())
@property @property
def named_filters(self): def named_filters(self) -> Dict[str, Filter]:
return { return {
'type': search_util.create_str_filter(db.Snapshot.resource_type), 'type':
'id': search_util.create_str_filter(db.Snapshot.resource_name), search_util.create_str_filter(model.Snapshot.resource_type),
'date': search_util.create_date_filter(db.Snapshot.creation_time), 'id':
'time': search_util.create_date_filter(db.Snapshot.creation_time), search_util.create_str_filter(model.Snapshot.resource_name),
'operation': search_util.create_str_filter(db.Snapshot.operation), 'date':
'user': search_util.create_str_filter(db.User.name), search_util.create_date_filter(model.Snapshot.creation_time),
'time':
search_util.create_date_filter(model.Snapshot.creation_time),
'operation':
search_util.create_str_filter(model.Snapshot.operation),
'user':
search_util.create_str_filter(model.User.name),
} }

View file

@ -1,79 +1,134 @@
from sqlalchemy.orm import subqueryload, lazyload, defer from typing import Tuple, Dict
from sqlalchemy.sql.expression import func import sqlalchemy as sa
from szurubooru import db from szurubooru import db, model
from szurubooru.func import util from szurubooru.func import util
from szurubooru.search.typing import SaColumn, SaQuery
from szurubooru.search.configs import util as search_util from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import BaseSearchConfig from szurubooru.search.configs.base_search_config import (
BaseSearchConfig, Filter)
class TagSearchConfig(BaseSearchConfig): class TagSearchConfig(BaseSearchConfig):
def create_filter_query(self, _disable_eager_loads): def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery:
strategy = lazyload if _disable_eager_loads else subqueryload strategy = (
return db.session.query(db.Tag) \ sa.orm.lazyload
.join(db.TagCategory) \ if _disable_eager_loads
else sa.orm.subqueryload)
return db.session.query(model.Tag) \
.join(model.TagCategory) \
.options( .options(
defer(db.Tag.first_name), sa.orm.defer(model.Tag.first_name),
defer(db.Tag.suggestion_count), sa.orm.defer(model.Tag.suggestion_count),
defer(db.Tag.implication_count), sa.orm.defer(model.Tag.implication_count),
defer(db.Tag.post_count), sa.orm.defer(model.Tag.post_count),
strategy(db.Tag.names), strategy(model.Tag.names),
strategy(db.Tag.suggestions).joinedload(db.Tag.names), strategy(model.Tag.suggestions).joinedload(model.Tag.names),
strategy(db.Tag.implications).joinedload(db.Tag.names)) strategy(model.Tag.implications).joinedload(model.Tag.names))
def create_count_query(self, _disable_eager_loads): def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(db.Tag) return db.session.query(model.Tag)
def create_around_query(self): def create_around_query(self) -> SaQuery:
raise NotImplementedError() raise NotImplementedError()
def finalize_query(self, query): def finalize_query(self, query: SaQuery) -> SaQuery:
return query.order_by(db.Tag.first_name.asc()) return query.order_by(model.Tag.first_name.asc())
@property @property
def anonymous_filter(self): def anonymous_filter(self) -> Filter:
return search_util.create_subquery_filter( return search_util.create_subquery_filter(
db.Tag.tag_id, model.Tag.tag_id,
db.TagName.tag_id, model.TagName.tag_id,
db.TagName.name, model.TagName.name,
search_util.create_str_filter) search_util.create_str_filter)
@property @property
def named_filters(self): def named_filters(self) -> Dict[str, Filter]:
return util.unalias_dict({ return util.unalias_dict([
'name': search_util.create_subquery_filter( (
db.Tag.tag_id, ['name'],
db.TagName.tag_id, search_util.create_subquery_filter(
db.TagName.name, model.Tag.tag_id,
search_util.create_str_filter), model.TagName.tag_id,
'category': search_util.create_subquery_filter( model.TagName.name,
db.Tag.category_id, search_util.create_str_filter)
db.TagCategory.tag_category_id, ),
db.TagCategory.name,
search_util.create_str_filter), (
('creation-date', 'creation-time'): ['category'],
search_util.create_date_filter(db.Tag.creation_time), search_util.create_subquery_filter(
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): model.Tag.category_id,
search_util.create_date_filter(db.Tag.last_edit_time), model.TagCategory.tag_category_id,
('usage-count', 'post-count', 'usages'): model.TagCategory.name,
search_util.create_num_filter(db.Tag.post_count), search_util.create_str_filter)
'suggestion-count': ),
search_util.create_num_filter(db.Tag.suggestion_count),
'implication-count': (
search_util.create_num_filter(db.Tag.implication_count), ['creation-date', 'creation-time'],
}) search_util.create_date_filter(model.Tag.creation_time)
),
(
['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'],
search_util.create_date_filter(model.Tag.last_edit_time)
),
(
['usage-count', 'post-count', 'usages'],
search_util.create_num_filter(model.Tag.post_count)
),
(
['suggestion-count'],
search_util.create_num_filter(model.Tag.suggestion_count)
),
(
['implication-count'],
search_util.create_num_filter(model.Tag.implication_count)
),
])
@property @property
def sort_columns(self): def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
return util.unalias_dict({ return util.unalias_dict([
'random': (func.random(), None), (
'name': (db.Tag.first_name, self.SORT_ASC), ['random'],
'category': (db.TagCategory.name, self.SORT_ASC), (sa.sql.expression.func.random(), self.SORT_NONE)
('creation-date', 'creation-time'): ),
(db.Tag.creation_time, self.SORT_DESC),
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): (
(db.Tag.last_edit_time, self.SORT_DESC), ['name'],
('usage-count', 'post-count', 'usages'): (model.Tag.first_name, self.SORT_ASC)
(db.Tag.post_count, self.SORT_DESC), ),
'suggestion-count': (db.Tag.suggestion_count, self.SORT_DESC),
'implication-count': (db.Tag.implication_count, self.SORT_DESC), (
}) ['category'],
(model.TagCategory.name, self.SORT_ASC)
),
(
['creation-date', 'creation-time'],
(model.Tag.creation_time, self.SORT_DESC)
),
(
['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'],
(model.Tag.last_edit_time, self.SORT_DESC)
),
(
['usage-count', 'post-count', 'usages'],
(model.Tag.post_count, self.SORT_DESC)
),
(
['suggestion-count'],
(model.Tag.suggestion_count, self.SORT_DESC)
),
(
['implication-count'],
(model.Tag.implication_count, self.SORT_DESC)
),
])

View file

@ -1,53 +1,57 @@
from sqlalchemy.sql.expression import func from typing import Tuple, Dict
from szurubooru import db import sqlalchemy as sa
from szurubooru import db, model
from szurubooru.search.typing import SaColumn, SaQuery
from szurubooru.search.configs import util as search_util from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import BaseSearchConfig from szurubooru.search.configs.base_search_config import (
BaseSearchConfig, Filter)
class UserSearchConfig(BaseSearchConfig): class UserSearchConfig(BaseSearchConfig):
def create_filter_query(self, _disable_eager_loads): def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(db.User) return db.session.query(model.User)
def create_count_query(self, _disable_eager_loads): def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(db.User) return db.session.query(model.User)
def create_around_query(self): def create_around_query(self) -> SaQuery:
raise NotImplementedError() raise NotImplementedError()
def finalize_query(self, query): def finalize_query(self, query: SaQuery) -> SaQuery:
return query.order_by(db.User.name.asc()) return query.order_by(model.User.name.asc())
@property @property
def anonymous_filter(self): def anonymous_filter(self) -> Filter:
return search_util.create_str_filter(db.User.name) return search_util.create_str_filter(model.User.name)
@property @property
def named_filters(self): def named_filters(self) -> Dict[str, Filter]:
return { return {
'name': search_util.create_str_filter(db.User.name), 'name':
search_util.create_str_filter(model.User.name),
'creation-date': 'creation-date':
search_util.create_date_filter(db.User.creation_time), search_util.create_date_filter(model.User.creation_time),
'creation-time': 'creation-time':
search_util.create_date_filter(db.User.creation_time), search_util.create_date_filter(model.User.creation_time),
'last-login-date': 'last-login-date':
search_util.create_date_filter(db.User.last_login_time), search_util.create_date_filter(model.User.last_login_time),
'last-login-time': 'last-login-time':
search_util.create_date_filter(db.User.last_login_time), search_util.create_date_filter(model.User.last_login_time),
'login-date': 'login-date':
search_util.create_date_filter(db.User.last_login_time), search_util.create_date_filter(model.User.last_login_time),
'login-time': 'login-time':
search_util.create_date_filter(db.User.last_login_time), search_util.create_date_filter(model.User.last_login_time),
} }
@property @property
def sort_columns(self): def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
return { return {
'random': (func.random(), None), 'random': (sa.sql.expression.func.random(), self.SORT_NONE),
'name': (db.User.name, self.SORT_ASC), 'name': (model.User.name, self.SORT_ASC),
'creation-date': (db.User.creation_time, self.SORT_DESC), 'creation-date': (model.User.creation_time, self.SORT_DESC),
'creation-time': (db.User.creation_time, self.SORT_DESC), 'creation-time': (model.User.creation_time, self.SORT_DESC),
'last-login-date': (db.User.last_login_time, self.SORT_DESC), 'last-login-date': (model.User.last_login_time, self.SORT_DESC),
'last-login-time': (db.User.last_login_time, self.SORT_DESC), 'last-login-time': (model.User.last_login_time, self.SORT_DESC),
'login-date': (db.User.last_login_time, self.SORT_DESC), 'login-date': (model.User.last_login_time, self.SORT_DESC),
'login-time': (db.User.last_login_time, self.SORT_DESC), 'login-time': (model.User.last_login_time, self.SORT_DESC),
} }

View file

@ -1,10 +1,13 @@
import sqlalchemy from typing import Any, Optional, Callable
import sqlalchemy as sa
from szurubooru import db, errors from szurubooru import db, errors
from szurubooru.func import util from szurubooru.func import util
from szurubooru.search import criteria from szurubooru.search import criteria
from szurubooru.search.typing import SaColumn, SaQuery
from szurubooru.search.configs.base_search_config import Filter
def wildcard_transformer(value): def wildcard_transformer(value: str) -> str:
return ( return (
value value
.replace('\\', '\\\\') .replace('\\', '\\\\')
@ -13,24 +16,21 @@ def wildcard_transformer(value):
.replace('*', '%')) .replace('*', '%'))
def apply_num_criterion_to_column(column, criterion): def apply_num_criterion_to_column(
''' column: Any, criterion: criteria.BaseCriterion) -> Any:
Decorate SQLAlchemy filter on given column using supplied criterion.
'''
try: try:
if isinstance(criterion, criteria.PlainCriterion): if isinstance(criterion, criteria.PlainCriterion):
expr = column == int(criterion.value) expr = column == int(criterion.value)
elif isinstance(criterion, criteria.ArrayCriterion): elif isinstance(criterion, criteria.ArrayCriterion):
expr = column.in_(int(value) for value in criterion.values) expr = column.in_(int(value) for value in criterion.values)
elif isinstance(criterion, criteria.RangedCriterion): elif isinstance(criterion, criteria.RangedCriterion):
assert criterion.min_value != '' \ assert criterion.min_value or criterion.max_value
or criterion.max_value != '' if criterion.min_value and criterion.max_value:
if criterion.min_value != '' and criterion.max_value != '':
expr = column.between( expr = column.between(
int(criterion.min_value), int(criterion.max_value)) int(criterion.min_value), int(criterion.max_value))
elif criterion.min_value != '': elif criterion.min_value:
expr = column >= int(criterion.min_value) expr = column >= int(criterion.min_value)
elif criterion.max_value != '': elif criterion.max_value:
expr = column <= int(criterion.max_value) expr = column <= int(criterion.max_value)
else: else:
assert False assert False
@ -40,10 +40,13 @@ def apply_num_criterion_to_column(column, criterion):
return expr return expr
def create_num_filter(column): def create_num_filter(column: Any) -> Filter:
def wrapper(query, criterion, negated): def wrapper(
expr = apply_num_criterion_to_column( query: SaQuery,
column, criterion) criterion: Optional[criteria.BaseCriterion],
negated: bool) -> SaQuery:
assert criterion
expr = apply_num_criterion_to_column(column, criterion)
if negated: if negated:
expr = ~expr expr = ~expr
return query.filter(expr) return query.filter(expr)
@ -51,14 +54,13 @@ def create_num_filter(column):
def apply_str_criterion_to_column( def apply_str_criterion_to_column(
column, criterion, transformer=wildcard_transformer): column: SaColumn,
''' criterion: criteria.BaseCriterion,
Decorate SQLAlchemy filter on given column using supplied criterion. transformer: Callable[[str], str]=wildcard_transformer) -> SaQuery:
'''
if isinstance(criterion, criteria.PlainCriterion): if isinstance(criterion, criteria.PlainCriterion):
expr = column.ilike(transformer(criterion.value)) expr = column.ilike(transformer(criterion.value))
elif isinstance(criterion, criteria.ArrayCriterion): elif isinstance(criterion, criteria.ArrayCriterion):
expr = sqlalchemy.sql.false() expr = sa.sql.false()
for value in criterion.values: for value in criterion.values:
expr = expr | column.ilike(transformer(value)) expr = expr | column.ilike(transformer(value))
elif isinstance(criterion, criteria.RangedCriterion): elif isinstance(criterion, criteria.RangedCriterion):
@ -68,8 +70,15 @@ def apply_str_criterion_to_column(
return expr return expr
def create_str_filter(column, transformer=wildcard_transformer): def create_str_filter(
def wrapper(query, criterion, negated): column: SaColumn,
transformer: Callable[[str], str]=wildcard_transformer
) -> Filter:
def wrapper(
query: SaQuery,
criterion: Optional[criteria.BaseCriterion],
negated: bool) -> SaQuery:
assert criterion
expr = apply_str_criterion_to_column( expr = apply_str_criterion_to_column(
column, criterion, transformer) column, criterion, transformer)
if negated: if negated:
@ -78,16 +87,13 @@ def create_str_filter(column, transformer=wildcard_transformer):
return wrapper return wrapper
def apply_date_criterion_to_column(column, criterion): def apply_date_criterion_to_column(
''' column: SaQuery, criterion: criteria.BaseCriterion) -> SaQuery:
Decorate SQLAlchemy filter on given column using supplied criterion.
Parse the datetime inside the criterion.
'''
if isinstance(criterion, criteria.PlainCriterion): if isinstance(criterion, criteria.PlainCriterion):
min_date, max_date = util.parse_time_range(criterion.value) min_date, max_date = util.parse_time_range(criterion.value)
expr = column.between(min_date, max_date) expr = column.between(min_date, max_date)
elif isinstance(criterion, criteria.ArrayCriterion): elif isinstance(criterion, criteria.ArrayCriterion):
expr = sqlalchemy.sql.false() expr = sa.sql.false()
for value in criterion.values: for value in criterion.values:
min_date, max_date = util.parse_time_range(value) min_date, max_date = util.parse_time_range(value)
expr = expr | column.between(min_date, max_date) expr = expr | column.between(min_date, max_date)
@ -108,10 +114,13 @@ def apply_date_criterion_to_column(column, criterion):
return expr return expr
def create_date_filter(column): def create_date_filter(column: SaColumn) -> Filter:
def wrapper(query, criterion, negated): def wrapper(
expr = apply_date_criterion_to_column( query: SaQuery,
column, criterion) criterion: Optional[criteria.BaseCriterion],
negated: bool) -> SaQuery:
assert criterion
expr = apply_date_criterion_to_column(column, criterion)
if negated: if negated:
expr = ~expr expr = ~expr
return query.filter(expr) return query.filter(expr)
@ -119,18 +128,22 @@ def create_date_filter(column):
def create_subquery_filter( def create_subquery_filter(
left_id_column, left_id_column: SaColumn,
right_id_column, right_id_column: SaColumn,
filter_column, filter_column: SaColumn,
filter_factory, filter_factory: SaColumn,
subquery_decorator=None): subquery_decorator: Callable[[SaQuery], None]=None) -> Filter:
filter_func = filter_factory(filter_column) filter_func = filter_factory(filter_column)
def wrapper(query, criterion, negated): def wrapper(
query: SaQuery,
criterion: Optional[criteria.BaseCriterion],
negated: bool) -> SaQuery:
assert criterion
subquery = db.session.query(right_id_column.label('foreign_id')) subquery = db.session.query(right_id_column.label('foreign_id'))
if subquery_decorator: if subquery_decorator:
subquery = subquery_decorator(subquery) subquery = subquery_decorator(subquery)
subquery = subquery.options(sqlalchemy.orm.lazyload('*')) subquery = subquery.options(sa.orm.lazyload('*'))
subquery = filter_func(subquery, criterion, False) subquery = filter_func(subquery, criterion, False)
subquery = subquery.subquery('t') subquery = subquery.subquery('t')
expression = left_id_column.in_(subquery) expression = left_id_column.in_(subquery)

View file

@ -1,34 +1,42 @@
class _BaseCriterion: from typing import Optional, List, Callable
def __init__(self, original_text): from szurubooru.search.typing import SaQuery
class BaseCriterion:
def __init__(self, original_text: str) -> None:
self.original_text = original_text self.original_text = original_text
def __repr__(self): def __repr__(self) -> str:
return self.original_text return self.original_text
class RangedCriterion(_BaseCriterion): class RangedCriterion(BaseCriterion):
def __init__(self, original_text, min_value, max_value): def __init__(
self,
original_text: str,
min_value: Optional[str],
max_value: Optional[str]) -> None:
super().__init__(original_text) super().__init__(original_text)
self.min_value = min_value self.min_value = min_value
self.max_value = max_value self.max_value = max_value
def __hash__(self): def __hash__(self) -> int:
return hash(('range', self.min_value, self.max_value)) return hash(('range', self.min_value, self.max_value))
class PlainCriterion(_BaseCriterion): class PlainCriterion(BaseCriterion):
def __init__(self, original_text, value): def __init__(self, original_text: str, value: str) -> None:
super().__init__(original_text) super().__init__(original_text)
self.value = value self.value = value
def __hash__(self): def __hash__(self) -> int:
return hash(self.value) return hash(self.value)
class ArrayCriterion(_BaseCriterion): class ArrayCriterion(BaseCriterion):
def __init__(self, original_text, values): def __init__(self, original_text: str, values: List[str]) -> None:
super().__init__(original_text) super().__init__(original_text)
self.values = values self.values = values
def __hash__(self): def __hash__(self) -> int:
return hash(tuple(['array'] + self.values)) return hash(tuple(['array'] + self.values))

View file

@ -1,14 +1,18 @@
import sqlalchemy from typing import Union, Tuple, List, Dict, Callable
from szurubooru import db, errors import sqlalchemy as sa
from szurubooru import db, model, errors, rest
from szurubooru.func import cache from szurubooru.func import cache
from szurubooru.search import tokens, parser from szurubooru.search import tokens, parser
from szurubooru.search.typing import SaQuery
from szurubooru.search.query import SearchQuery
from szurubooru.search.configs.base_search_config import BaseSearchConfig
def _format_dict_keys(source): def _format_dict_keys(source: Dict) -> List[str]:
return list(sorted(source.keys())) return list(sorted(source.keys()))
def _get_order(order, default_order): def _get_order(order: str, default_order: str) -> Union[bool, str]:
if order == tokens.SortToken.SORT_DEFAULT: if order == tokens.SortToken.SORT_DEFAULT:
return default_order or tokens.SortToken.SORT_ASC return default_order or tokens.SortToken.SORT_ASC
if order == tokens.SortToken.SORT_NEGATED_DEFAULT: if order == tokens.SortToken.SORT_NEGATED_DEFAULT:
@ -26,50 +30,57 @@ class Executor:
delegates sqlalchemy filter decoration to SearchConfig instances. delegates sqlalchemy filter decoration to SearchConfig instances.
''' '''
def __init__(self, search_config): def __init__(self, search_config: BaseSearchConfig) -> None:
self.config = search_config self.config = search_config
self.parser = parser.Parser() self.parser = parser.Parser()
def get_around(self, query_text, entity_id): def get_around(
self,
query_text: str,
entity_id: int) -> Tuple[model.Base, model.Base]:
search_query = self.parser.parse(query_text) search_query = self.parser.parse(query_text)
self.config.on_search_query_parsed(search_query) self.config.on_search_query_parsed(search_query)
filter_query = ( filter_query = (
self.config self.config
.create_around_query() .create_around_query()
.options(sqlalchemy.orm.lazyload('*'))) .options(sa.orm.lazyload('*')))
filter_query = self._prepare_db_query( filter_query = self._prepare_db_query(
filter_query, search_query, False) filter_query, search_query, False)
prev_filter_query = ( prev_filter_query = (
filter_query filter_query
.filter(self.config.id_column > entity_id) .filter(self.config.id_column > entity_id)
.order_by(None) .order_by(None)
.order_by(sqlalchemy.func.abs( .order_by(sa.func.abs(self.config.id_column - entity_id).asc())
self.config.id_column - entity_id).asc())
.limit(1)) .limit(1))
next_filter_query = ( next_filter_query = (
filter_query filter_query
.filter(self.config.id_column < entity_id) .filter(self.config.id_column < entity_id)
.order_by(None) .order_by(None)
.order_by(sqlalchemy.func.abs( .order_by(sa.func.abs(self.config.id_column - entity_id).asc())
self.config.id_column - entity_id).asc())
.limit(1)) .limit(1))
return [ return (
prev_filter_query.one_or_none(), prev_filter_query.one_or_none(),
next_filter_query.one_or_none()] next_filter_query.one_or_none())
def get_around_and_serialize(self, ctx, entity_id, serializer): def get_around_and_serialize(
entities = self.get_around(ctx.get_param_as_string('query'), entity_id) self,
ctx: rest.Context,
entity_id: int,
serializer: Callable[[model.Base], rest.Response]
) -> rest.Response:
entities = self.get_around(
ctx.get_param_as_string('query', default=''), entity_id)
return { return {
'prev': serializer(entities[0]), 'prev': serializer(entities[0]),
'next': serializer(entities[1]), 'next': serializer(entities[1]),
} }
def execute(self, query_text, page, page_size): def execute(
''' self,
Parse input and return tuple containing total record count and filtered query_text: str,
entities. page: int,
''' page_size: int
) -> Tuple[int, List[model.Base]]:
search_query = self.parser.parse(query_text) search_query = self.parser.parse(query_text)
self.config.on_search_query_parsed(search_query) self.config.on_search_query_parsed(search_query)
@ -83,7 +94,7 @@ class Executor:
return cache.get(key) return cache.get(key)
filter_query = self.config.create_filter_query(disable_eager_loads) filter_query = self.config.create_filter_query(disable_eager_loads)
filter_query = filter_query.options(sqlalchemy.orm.lazyload('*')) filter_query = filter_query.options(sa.orm.lazyload('*'))
filter_query = self._prepare_db_query(filter_query, search_query, True) filter_query = self._prepare_db_query(filter_query, search_query, True)
entities = filter_query \ entities = filter_query \
.offset(max(page - 1, 0) * page_size) \ .offset(max(page - 1, 0) * page_size) \
@ -91,11 +102,11 @@ class Executor:
.all() .all()
count_query = self.config.create_count_query(disable_eager_loads) count_query = self.config.create_count_query(disable_eager_loads)
count_query = count_query.options(sqlalchemy.orm.lazyload('*')) count_query = count_query.options(sa.orm.lazyload('*'))
count_query = self._prepare_db_query(count_query, search_query, False) count_query = self._prepare_db_query(count_query, search_query, False)
count_statement = count_query \ count_statement = count_query \
.statement \ .statement \
.with_only_columns([sqlalchemy.func.count()]) \ .with_only_columns([sa.func.count()]) \
.order_by(None) .order_by(None)
count = db.session.execute(count_statement).scalar() count = db.session.execute(count_statement).scalar()
@ -103,8 +114,12 @@ class Executor:
cache.put(key, ret) cache.put(key, ret)
return ret return ret
def execute_and_serialize(self, ctx, serializer): def execute_and_serialize(
query = ctx.get_param_as_string('query') self,
ctx: rest.Context,
serializer: Callable[[model.Base], rest.Response]
) -> rest.Response:
query = ctx.get_param_as_string('query', default='')
page = ctx.get_param_as_int('page', default=1, min=1) page = ctx.get_param_as_int('page', default=1, min=1)
page_size = ctx.get_param_as_int( page_size = ctx.get_param_as_int(
'pageSize', default=100, min=1, max=100) 'pageSize', default=100, min=1, max=100)
@ -117,48 +132,51 @@ class Executor:
'results': [serializer(entity) for entity in entities], 'results': [serializer(entity) for entity in entities],
} }
def _prepare_db_query(self, db_query, search_query, use_sort): def _prepare_db_query(
''' Parse input and return SQLAlchemy query. ''' self,
db_query: SaQuery,
for token in search_query.anonymous_tokens: search_query: SearchQuery,
use_sort: bool) -> SaQuery:
for anon_token in search_query.anonymous_tokens:
if not self.config.anonymous_filter: if not self.config.anonymous_filter:
raise errors.SearchError( raise errors.SearchError(
'Anonymous tokens are not valid in this context.') 'Anonymous tokens are not valid in this context.')
db_query = self.config.anonymous_filter( db_query = self.config.anonymous_filter(
db_query, token.criterion, token.negated) db_query, anon_token.criterion, anon_token.negated)
for token in search_query.named_tokens: for named_token in search_query.named_tokens:
if token.name not in self.config.named_filters: if named_token.name not in self.config.named_filters:
raise errors.SearchError( raise errors.SearchError(
'Unknown named token: %r. Available named tokens: %r.' % ( 'Unknown named token: %r. Available named tokens: %r.' % (
token.name, named_token.name,
_format_dict_keys(self.config.named_filters))) _format_dict_keys(self.config.named_filters)))
db_query = self.config.named_filters[token.name]( db_query = self.config.named_filters[named_token.name](
db_query, token.criterion, token.negated) db_query, named_token.criterion, named_token.negated)
for token in search_query.special_tokens: for sp_token in search_query.special_tokens:
if token.value not in self.config.special_filters: if sp_token.value not in self.config.special_filters:
raise errors.SearchError( raise errors.SearchError(
'Unknown special token: %r. ' 'Unknown special token: %r. '
'Available special tokens: %r.' % ( 'Available special tokens: %r.' % (
token.value, sp_token.value,
_format_dict_keys(self.config.special_filters))) _format_dict_keys(self.config.special_filters)))
db_query = self.config.special_filters[token.value]( db_query = self.config.special_filters[sp_token.value](
db_query, token.negated) db_query, None, sp_token.negated)
if use_sort: if use_sort:
for token in search_query.sort_tokens: for sort_token in search_query.sort_tokens:
if token.name not in self.config.sort_columns: if sort_token.name not in self.config.sort_columns:
raise errors.SearchError( raise errors.SearchError(
'Unknown sort token: %r. ' 'Unknown sort token: %r. '
'Available sort tokens: %r.' % ( 'Available sort tokens: %r.' % (
token.name, sort_token.name,
_format_dict_keys(self.config.sort_columns))) _format_dict_keys(self.config.sort_columns)))
column, default_order = self.config.sort_columns[token.name] column, default_order = (
order = _get_order(token.order, default_order) self.config.sort_columns[sort_token.name])
if order == token.SORT_ASC: order = _get_order(sort_token.order, default_order)
if order == sort_token.SORT_ASC:
db_query = db_query.order_by(column.asc()) db_query = db_query.order_by(column.asc())
elif order == token.SORT_DESC: elif order == sort_token.SORT_DESC:
db_query = db_query.order_by(column.desc()) db_query = db_query.order_by(column.desc())
db_query = self.config.finalize_query(db_query) db_query = self.config.finalize_query(db_query)

View file

@ -1,9 +1,12 @@
import re import re
from typing import List
from szurubooru import errors from szurubooru import errors
from szurubooru.search import criteria, tokens from szurubooru.search import criteria, tokens
from szurubooru.search.query import SearchQuery
def _create_criterion(original_value, value): def _create_criterion(
original_value: str, value: str) -> criteria.BaseCriterion:
if ',' in value: if ',' in value:
return criteria.ArrayCriterion( return criteria.ArrayCriterion(
original_value, value.split(',')) original_value, value.split(','))
@ -15,12 +18,12 @@ def _create_criterion(original_value, value):
return criteria.PlainCriterion(original_value, value) return criteria.PlainCriterion(original_value, value)
def _parse_anonymous(value, negated): def _parse_anonymous(value: str, negated: bool) -> tokens.AnonymousToken:
criterion = _create_criterion(value, value) criterion = _create_criterion(value, value)
return tokens.AnonymousToken(criterion, negated) return tokens.AnonymousToken(criterion, negated)
def _parse_named(key, value, negated): def _parse_named(key: str, value: str, negated: bool) -> tokens.NamedToken:
original_value = value original_value = value
if key.endswith('-min'): if key.endswith('-min'):
key = key[:-4] key = key[:-4]
@ -32,11 +35,11 @@ def _parse_named(key, value, negated):
return tokens.NamedToken(key, criterion, negated) return tokens.NamedToken(key, criterion, negated)
def _parse_special(value, negated): def _parse_special(value: str, negated: bool) -> tokens.SpecialToken:
return tokens.SpecialToken(value, negated) return tokens.SpecialToken(value, negated)
def _parse_sort(value, negated): def _parse_sort(value: str, negated: bool) -> tokens.SortToken:
if value.count(',') == 0: if value.count(',') == 0:
order_str = None order_str = None
elif value.count(',') == 1: elif value.count(',') == 1:
@ -67,23 +70,8 @@ def _parse_sort(value, negated):
return tokens.SortToken(value, order) return tokens.SortToken(value, order)
class SearchQuery:
def __init__(self):
self.anonymous_tokens = []
self.named_tokens = []
self.special_tokens = []
self.sort_tokens = []
def __hash__(self):
return hash((
tuple(self.anonymous_tokens),
tuple(self.named_tokens),
tuple(self.special_tokens),
tuple(self.sort_tokens)))
class Parser: class Parser:
def parse(self, query_text): def parse(self, query_text: str) -> SearchQuery:
query = SearchQuery() query = SearchQuery()
for chunk in re.split(r'\s+', (query_text or '').lower()): for chunk in re.split(r'\s+', (query_text or '').lower()):
if not chunk: if not chunk:

View file

@ -0,0 +1,16 @@
from szurubooru.search import tokens
class SearchQuery:
def __init__(self) -> None:
self.anonymous_tokens = [] # type: List[tokens.AnonymousToken]
self.named_tokens = [] # type: List[tokens.NamedToken]
self.special_tokens = [] # type: List[tokens.SpecialToken]
self.sort_tokens = [] # type: List[tokens.SortToken]
def __hash__(self) -> int:
return hash((
tuple(self.anonymous_tokens),
tuple(self.named_tokens),
tuple(self.special_tokens),
tuple(self.sort_tokens)))

View file

@ -1,39 +1,44 @@
from szurubooru.search.criteria import BaseCriterion
class AnonymousToken: class AnonymousToken:
def __init__(self, criterion, negated): def __init__(self, criterion: BaseCriterion, negated: bool) -> None:
self.criterion = criterion self.criterion = criterion
self.negated = negated self.negated = negated
def __hash__(self): def __hash__(self) -> int:
return hash((self.criterion, self.negated)) return hash((self.criterion, self.negated))
class NamedToken(AnonymousToken): class NamedToken(AnonymousToken):
def __init__(self, name, criterion, negated): def __init__(
self, name: str, criterion: BaseCriterion, negated: bool) -> None:
super().__init__(criterion, negated) super().__init__(criterion, negated)
self.name = name self.name = name
def __hash__(self): def __hash__(self) -> int:
return hash((self.name, self.criterion, self.negated)) return hash((self.name, self.criterion, self.negated))
class SortToken: class SortToken:
SORT_DESC = 'desc' SORT_DESC = 'desc'
SORT_ASC = 'asc' SORT_ASC = 'asc'
SORT_NONE = ''
SORT_DEFAULT = 'default' SORT_DEFAULT = 'default'
SORT_NEGATED_DEFAULT = 'negated default' SORT_NEGATED_DEFAULT = 'negated default'
def __init__(self, name, order): def __init__(self, name: str, order: str) -> None:
self.name = name self.name = name
self.order = order self.order = order
def __hash__(self): def __hash__(self) -> int:
return hash((self.name, self.order)) return hash((self.name, self.order))
class SpecialToken: class SpecialToken:
def __init__(self, value, negated): def __init__(self, value: str, negated: bool) -> None:
self.value = value self.value = value
self.negated = negated self.negated = negated
def __hash__(self): def __hash__(self) -> int:
return hash((self.value, self.negated)) return hash((self.value, self.negated))

View file

@ -0,0 +1,6 @@
from typing import Any, Callable
SaColumn = Any
SaQuery = Any
SaQueryFactory = Callable[[], SaQuery]

View file

@ -1,19 +1,20 @@
from datetime import datetime from datetime import datetime
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import comments, posts from szurubooru.func import comments, posts
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'comments:create': db.User.RANK_REGULAR}}) config_injector(
{'privileges': {'comments:create': model.User.RANK_REGULAR}})
def test_creating_comment( def test_creating_comment(
user_factory, post_factory, context_factory, fake_datetime): user_factory, post_factory, context_factory, fake_datetime):
post = post_factory() post = post_factory()
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=model.User.RANK_REGULAR)
db.session.add_all([post, user]) db.session.add_all([post, user])
db.session.flush() db.session.flush()
with patch('szurubooru.func.comments.serialize_comment'), \ with patch('szurubooru.func.comments.serialize_comment'), \
@ -24,7 +25,7 @@ def test_creating_comment(
params={'text': 'input', 'postId': post.post_id}, params={'text': 'input', 'postId': post.post_id},
user=user)) user=user))
assert result == 'serialized comment' assert result == 'serialized comment'
comment = db.session.query(db.Comment).one() comment = db.session.query(model.Comment).one()
assert comment.text == 'input' assert comment.text == 'input'
assert comment.creation_time == datetime(1997, 1, 1) assert comment.creation_time == datetime(1997, 1, 1)
assert comment.last_edit_time is None assert comment.last_edit_time is None
@ -41,7 +42,7 @@ def test_creating_comment(
def test_trying_to_pass_invalid_params( def test_trying_to_pass_invalid_params(
user_factory, post_factory, context_factory, params): user_factory, post_factory, context_factory, params):
post = post_factory() post = post_factory()
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=model.User.RANK_REGULAR)
db.session.add_all([post, user]) db.session.add_all([post, user])
db.session.flush() db.session.flush()
real_params = {'text': 'input', 'postId': post.post_id} real_params = {'text': 'input', 'postId': post.post_id}
@ -63,11 +64,11 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
api.comment_api.create_comment( api.comment_api.create_comment(
context_factory( context_factory(
params={}, params={},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_comment_non_existing(user_factory, context_factory): def test_trying_to_comment_non_existing(user_factory, context_factory):
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=model.User.RANK_REGULAR)
db.session.add_all([user]) db.session.add_all([user])
db.session.flush() db.session.flush()
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
@ -81,4 +82,4 @@ def test_trying_to_create_without_privileges(user_factory, context_factory):
api.comment_api.create_comment( api.comment_api.create_comment(
context_factory( context_factory(
params={}, params={},
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=model.User.RANK_ANONYMOUS)))

View file

@ -1,5 +1,5 @@
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import comments from szurubooru.func import comments
@ -7,8 +7,8 @@ from szurubooru.func import comments
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'comments:delete:own': db.User.RANK_REGULAR, 'comments:delete:own': model.User.RANK_REGULAR,
'comments:delete:any': db.User.RANK_MODERATOR, 'comments:delete:any': model.User.RANK_MODERATOR,
}, },
}) })
@ -22,26 +22,26 @@ def test_deleting_own_comment(user_factory, comment_factory, context_factory):
context_factory(params={'version': 1}, user=user), context_factory(params={'version': 1}, user=user),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
assert result == {} assert result == {}
assert db.session.query(db.Comment).count() == 0 assert db.session.query(model.Comment).count() == 0
def test_deleting_someones_else_comment( def test_deleting_someones_else_comment(
user_factory, comment_factory, context_factory): user_factory, comment_factory, context_factory):
user1 = user_factory(rank=db.User.RANK_REGULAR) user1 = user_factory(rank=model.User.RANK_REGULAR)
user2 = user_factory(rank=db.User.RANK_MODERATOR) user2 = user_factory(rank=model.User.RANK_MODERATOR)
comment = comment_factory(user=user1) comment = comment_factory(user=user1)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
api.comment_api.delete_comment( api.comment_api.delete_comment(
context_factory(params={'version': 1}, user=user2), context_factory(params={'version': 1}, user=user2),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
assert db.session.query(db.Comment).count() == 0 assert db.session.query(model.Comment).count() == 0
def test_trying_to_delete_someones_else_comment_without_privileges( def test_trying_to_delete_someones_else_comment_without_privileges(
user_factory, comment_factory, context_factory): user_factory, comment_factory, context_factory):
user1 = user_factory(rank=db.User.RANK_REGULAR) user1 = user_factory(rank=model.User.RANK_REGULAR)
user2 = user_factory(rank=db.User.RANK_REGULAR) user2 = user_factory(rank=model.User.RANK_REGULAR)
comment = comment_factory(user=user1) comment = comment_factory(user=user1)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
@ -49,7 +49,7 @@ def test_trying_to_delete_someones_else_comment_without_privileges(
api.comment_api.delete_comment( api.comment_api.delete_comment(
context_factory(params={'version': 1}, user=user2), context_factory(params={'version': 1}, user=user2),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
assert db.session.query(db.Comment).count() == 1 assert db.session.query(model.Comment).count() == 1
def test_trying_to_delete_non_existing(user_factory, context_factory): def test_trying_to_delete_non_existing(user_factory, context_factory):
@ -57,5 +57,5 @@ def test_trying_to_delete_non_existing(user_factory, context_factory):
api.comment_api.delete_comment( api.comment_api.delete_comment(
context_factory( context_factory(
params={'version': 1}, params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'comment_id': 1}) {'comment_id': 1})

View file

@ -1,17 +1,18 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import comments from szurubooru.func import comments
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'comments:score': db.User.RANK_REGULAR}}) config_injector(
{'privileges': {'comments:score': model.User.RANK_REGULAR}})
def test_simple_rating( def test_simple_rating(
user_factory, comment_factory, context_factory, fake_datetime): user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=model.User.RANK_REGULAR)
comment = comment_factory(user=user) comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
@ -22,14 +23,14 @@ def test_simple_rating(
context_factory(params={'score': 1}, user=user), context_factory(params={'score': 1}, user=user),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
assert result == 'serialized comment' assert result == 'serialized comment'
assert db.session.query(db.CommentScore).count() == 1 assert db.session.query(model.CommentScore).count() == 1
assert comment is not None assert comment is not None
assert comment.score == 1 assert comment.score == 1
def test_updating_rating( def test_updating_rating(
user_factory, comment_factory, context_factory, fake_datetime): user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=model.User.RANK_REGULAR)
comment = comment_factory(user=user) comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
@ -42,14 +43,14 @@ def test_updating_rating(
api.comment_api.set_comment_score( api.comment_api.set_comment_score(
context_factory(params={'score': -1}, user=user), context_factory(params={'score': -1}, user=user),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one() comment = db.session.query(model.Comment).one()
assert db.session.query(db.CommentScore).count() == 1 assert db.session.query(model.CommentScore).count() == 1
assert comment.score == -1 assert comment.score == -1
def test_updating_rating_to_zero( def test_updating_rating_to_zero(
user_factory, comment_factory, context_factory, fake_datetime): user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=model.User.RANK_REGULAR)
comment = comment_factory(user=user) comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
@ -62,14 +63,14 @@ def test_updating_rating_to_zero(
api.comment_api.set_comment_score( api.comment_api.set_comment_score(
context_factory(params={'score': 0}, user=user), context_factory(params={'score': 0}, user=user),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one() comment = db.session.query(model.Comment).one()
assert db.session.query(db.CommentScore).count() == 0 assert db.session.query(model.CommentScore).count() == 0
assert comment.score == 0 assert comment.score == 0
def test_deleting_rating( def test_deleting_rating(
user_factory, comment_factory, context_factory, fake_datetime): user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=model.User.RANK_REGULAR)
comment = comment_factory(user=user) comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
@ -82,15 +83,15 @@ def test_deleting_rating(
api.comment_api.delete_comment_score( api.comment_api.delete_comment_score(
context_factory(user=user), context_factory(user=user),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one() comment = db.session.query(model.Comment).one()
assert db.session.query(db.CommentScore).count() == 0 assert db.session.query(model.CommentScore).count() == 0
assert comment.score == 0 assert comment.score == 0
def test_ratings_from_multiple_users( def test_ratings_from_multiple_users(
user_factory, comment_factory, context_factory, fake_datetime): user_factory, comment_factory, context_factory, fake_datetime):
user1 = user_factory(rank=db.User.RANK_REGULAR) user1 = user_factory(rank=model.User.RANK_REGULAR)
user2 = user_factory(rank=db.User.RANK_REGULAR) user2 = user_factory(rank=model.User.RANK_REGULAR)
comment = comment_factory() comment = comment_factory()
db.session.add_all([user1, user2, comment]) db.session.add_all([user1, user2, comment])
db.session.commit() db.session.commit()
@ -103,8 +104,8 @@ def test_ratings_from_multiple_users(
api.comment_api.set_comment_score( api.comment_api.set_comment_score(
context_factory(params={'score': -1}, user=user2), context_factory(params={'score': -1}, user=user2),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
comment = db.session.query(db.Comment).one() comment = db.session.query(model.Comment).one()
assert db.session.query(db.CommentScore).count() == 2 assert db.session.query(model.CommentScore).count() == 2
assert comment.score == 0 assert comment.score == 0
@ -125,7 +126,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
api.comment_api.set_comment_score( api.comment_api.set_comment_score(
context_factory( context_factory(
params={'score': 1}, params={'score': 1},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'comment_id': 5}) {'comment_id': 5})
@ -138,5 +139,5 @@ def test_trying_to_rate_without_privileges(
api.comment_api.set_comment_score( api.comment_api.set_comment_score(
context_factory( context_factory(
params={'score': 1}, params={'score': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})

View file

@ -1,6 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import comments from szurubooru.func import comments
@ -8,8 +8,8 @@ from szurubooru.func import comments
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'comments:list': db.User.RANK_REGULAR, 'comments:list': model.User.RANK_REGULAR,
'comments:view': db.User.RANK_REGULAR, 'comments:view': model.User.RANK_REGULAR,
}, },
}) })
@ -24,7 +24,7 @@ def test_retrieving_multiple(user_factory, comment_factory, context_factory):
result = api.comment_api.get_comments( result = api.comment_api.get_comments(
context_factory( context_factory(
params={'query': '', 'page': 1}, params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
assert result == { assert result == {
'query': '', 'query': '',
'page': 1, 'page': 1,
@ -40,7 +40,7 @@ def test_trying_to_retrieve_multiple_without_privileges(
api.comment_api.get_comments( api.comment_api.get_comments(
context_factory( context_factory(
params={'query': '', 'page': 1}, params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=model.User.RANK_ANONYMOUS)))
def test_retrieving_single(user_factory, comment_factory, context_factory): def test_retrieving_single(user_factory, comment_factory, context_factory):
@ -51,7 +51,7 @@ def test_retrieving_single(user_factory, comment_factory, context_factory):
comments.serialize_comment.return_value = 'serialized comment' comments.serialize_comment.return_value = 'serialized comment'
result = api.comment_api.get_comment( result = api.comment_api.get_comment(
context_factory( context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'comment_id': comment.comment_id}) {'comment_id': comment.comment_id})
assert result == 'serialized comment' assert result == 'serialized comment'
@ -60,7 +60,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(comments.CommentNotFoundError): with pytest.raises(comments.CommentNotFoundError):
api.comment_api.get_comment( api.comment_api.get_comment(
context_factory( context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'comment_id': 5}) {'comment_id': 5})
@ -68,5 +68,5 @@ def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory): user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.comment_api.get_comment( api.comment_api.get_comment(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'comment_id': 5}) {'comment_id': 5})

View file

@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import comments from szurubooru.func import comments
@ -9,15 +9,15 @@ from szurubooru.func import comments
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'comments:edit:own': db.User.RANK_REGULAR, 'comments:edit:own': model.User.RANK_REGULAR,
'comments:edit:any': db.User.RANK_MODERATOR, 'comments:edit:any': model.User.RANK_MODERATOR,
}, },
}) })
def test_simple_updating( def test_simple_updating(
user_factory, comment_factory, context_factory, fake_datetime): user_factory, comment_factory, context_factory, fake_datetime):
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=model.User.RANK_REGULAR)
comment = comment_factory(user=user) comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
@ -73,14 +73,14 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
api.comment_api.update_comment( api.comment_api.update_comment(
context_factory( context_factory(
params={'text': 'new text'}, params={'text': 'new text'},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'comment_id': 5}) {'comment_id': 5})
def test_trying_to_update_someones_comment_without_privileges( def test_trying_to_update_someones_comment_without_privileges(
user_factory, comment_factory, context_factory): user_factory, comment_factory, context_factory):
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=model.User.RANK_REGULAR)
user2 = user_factory(rank=db.User.RANK_REGULAR) user2 = user_factory(rank=model.User.RANK_REGULAR)
comment = comment_factory(user=user) comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()
@ -93,8 +93,8 @@ def test_trying_to_update_someones_comment_without_privileges(
def test_updating_someones_comment_with_privileges( def test_updating_someones_comment_with_privileges(
user_factory, comment_factory, context_factory): user_factory, comment_factory, context_factory):
user = user_factory(rank=db.User.RANK_REGULAR) user = user_factory(rank=model.User.RANK_REGULAR)
user2 = user_factory(rank=db.User.RANK_MODERATOR) user2 = user_factory(rank=model.User.RANK_MODERATOR)
comment = comment_factory(user=user) comment = comment_factory(user=user)
db.session.add(comment) db.session.add(comment)
db.session.commit() db.session.commit()

View file

@ -1,6 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import auth, mailer from szurubooru.func import auth, mailer
@ -15,7 +15,7 @@ def inject_config(config_injector):
def test_reset_sending_email(context_factory, user_factory): def test_reset_sending_email(context_factory, user_factory):
db.session.add(user_factory( db.session.add(user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) name='u1', rank=model.User.RANK_REGULAR, email='user@example.com'))
db.session.flush() db.session.flush()
for initiating_user in ['u1', 'user@example.com']: for initiating_user in ['u1', 'user@example.com']:
with patch('szurubooru.func.mailer.send_mail'): with patch('szurubooru.func.mailer.send_mail'):
@ -39,7 +39,7 @@ def test_trying_to_reset_non_existing(context_factory):
def test_trying_to_reset_without_email(context_factory, user_factory): def test_trying_to_reset_without_email(context_factory, user_factory):
db.session.add( db.session.add(
user_factory(name='u1', rank=db.User.RANK_REGULAR, email=None)) user_factory(name='u1', rank=model.User.RANK_REGULAR, email=None))
db.session.flush() db.session.flush()
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
api.password_reset_api.start_password_reset( api.password_reset_api.start_password_reset(
@ -48,7 +48,7 @@ def test_trying_to_reset_without_email(context_factory, user_factory):
def test_confirming_with_good_token(context_factory, user_factory): def test_confirming_with_good_token(context_factory, user_factory):
user = user_factory( user = user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com') name='u1', rank=model.User.RANK_REGULAR, email='user@example.com')
old_hash = user.password_hash old_hash = user.password_hash
db.session.add(user) db.session.add(user)
db.session.flush() db.session.flush()
@ -68,7 +68,7 @@ def test_trying_to_confirm_non_existing(context_factory):
def test_trying_to_confirm_without_token(context_factory, user_factory): def test_trying_to_confirm_without_token(context_factory, user_factory):
db.session.add(user_factory( db.session.add(user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) name='u1', rank=model.User.RANK_REGULAR, email='user@example.com'))
db.session.flush() db.session.flush()
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
api.password_reset_api.finish_password_reset( api.password_reset_api.finish_password_reset(
@ -77,7 +77,7 @@ def test_trying_to_confirm_without_token(context_factory, user_factory):
def test_trying_to_confirm_with_bad_token(context_factory, user_factory): def test_trying_to_confirm_with_bad_token(context_factory, user_factory):
db.session.add(user_factory( db.session.add(user_factory(
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) name='u1', rank=model.User.RANK_REGULAR, email='user@example.com'))
db.session.flush() db.session.flush()
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
api.password_reset_api.finish_password_reset( api.password_reset_api.finish_password_reset(

View file

@ -1,6 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import posts, tags, snapshots, net from szurubooru.func import posts, tags, snapshots, net
@ -8,16 +8,16 @@ from szurubooru.func import posts, tags, snapshots, net
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'posts:create:anonymous': db.User.RANK_REGULAR, 'posts:create:anonymous': model.User.RANK_REGULAR,
'posts:create:identified': db.User.RANK_REGULAR, 'posts:create:identified': model.User.RANK_REGULAR,
'tags:create': db.User.RANK_REGULAR, 'tags:create': model.User.RANK_REGULAR,
}, },
}) })
def test_creating_minimal_posts( def test_creating_minimal_posts(
context_factory, post_factory, user_factory): context_factory, post_factory, user_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
@ -53,20 +53,20 @@ def test_creating_minimal_posts(
posts.update_post_thumbnail.assert_called_once_with( posts.update_post_thumbnail.assert_called_once_with(
post, 'post-thumbnail') post, 'post-thumbnail')
posts.update_post_safety.assert_called_once_with(post, 'safe') posts.update_post_safety.assert_called_once_with(post, 'safe')
posts.update_post_source.assert_called_once_with(post, None) posts.update_post_source.assert_called_once_with(post, '')
posts.update_post_relations.assert_called_once_with(post, []) posts.update_post_relations.assert_called_once_with(post, [])
posts.update_post_notes.assert_called_once_with(post, []) posts.update_post_notes.assert_called_once_with(post, [])
posts.update_post_flags.assert_called_once_with(post, []) posts.update_post_flags.assert_called_once_with(post, [])
posts.update_post_thumbnail.assert_called_once_with( posts.update_post_thumbnail.assert_called_once_with(
post, 'post-thumbnail') post, 'post-thumbnail')
posts.serialize_post.assert_called_once_with( posts.serialize_post.assert_called_once_with(
post, auth_user, options=None) post, auth_user, options=[])
snapshots.create.assert_called_once_with(post, auth_user) snapshots.create.assert_called_once_with(post, auth_user)
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
def test_creating_full_posts(context_factory, post_factory, user_factory): def test_creating_full_posts(context_factory, post_factory, user_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
@ -109,14 +109,14 @@ def test_creating_full_posts(context_factory, post_factory, user_factory):
posts.update_post_flags.assert_called_once_with( posts.update_post_flags.assert_called_once_with(
post, ['flag1', 'flag2']) post, ['flag1', 'flag2'])
posts.serialize_post.assert_called_once_with( posts.serialize_post.assert_called_once_with(
post, auth_user, options=None) post, auth_user, options=[])
snapshots.create.assert_called_once_with(post, auth_user) snapshots.create.assert_called_once_with(post, auth_user)
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
def test_anonymous_uploads( def test_anonymous_uploads(
config_injector, context_factory, post_factory, user_factory): config_injector, context_factory, post_factory, user_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
@ -126,7 +126,7 @@ def test_anonymous_uploads(
patch('szurubooru.func.posts.create_post'), \ patch('szurubooru.func.posts.create_post'), \
patch('szurubooru.func.posts.update_post_source'): patch('szurubooru.func.posts.update_post_source'):
config_injector({ config_injector({
'privileges': {'posts:create:anonymous': db.User.RANK_REGULAR}, 'privileges': {'posts:create:anonymous': model.User.RANK_REGULAR},
}) })
posts.create_post.return_value = [post, []] posts.create_post.return_value = [post, []]
api.post_api.create_post( api.post_api.create_post(
@ -146,7 +146,7 @@ def test_anonymous_uploads(
def test_creating_from_url_saves_source( def test_creating_from_url_saves_source(
config_injector, context_factory, post_factory, user_factory): config_injector, context_factory, post_factory, user_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
@ -157,7 +157,7 @@ def test_creating_from_url_saves_source(
patch('szurubooru.func.posts.create_post'), \ patch('szurubooru.func.posts.create_post'), \
patch('szurubooru.func.posts.update_post_source'): patch('szurubooru.func.posts.update_post_source'):
config_injector({ config_injector({
'privileges': {'posts:create:identified': db.User.RANK_REGULAR}, 'privileges': {'posts:create:identified': model.User.RANK_REGULAR},
}) })
net.download.return_value = b'content' net.download.return_value = b'content'
posts.create_post.return_value = [post, []] posts.create_post.return_value = [post, []]
@ -177,7 +177,7 @@ def test_creating_from_url_saves_source(
def test_creating_from_url_with_source_specified( def test_creating_from_url_with_source_specified(
config_injector, context_factory, post_factory, user_factory): config_injector, context_factory, post_factory, user_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
@ -188,7 +188,7 @@ def test_creating_from_url_with_source_specified(
patch('szurubooru.func.posts.create_post'), \ patch('szurubooru.func.posts.create_post'), \
patch('szurubooru.func.posts.update_post_source'): patch('szurubooru.func.posts.update_post_source'):
config_injector({ config_injector({
'privileges': {'posts:create:identified': db.User.RANK_REGULAR}, 'privileges': {'posts:create:identified': model.User.RANK_REGULAR},
}) })
net.download.return_value = b'content' net.download.return_value = b'content'
posts.create_post.return_value = [post, []] posts.create_post.return_value = [post, []]
@ -218,14 +218,14 @@ def test_trying_to_omit_mandatory_field(context_factory, user_factory, field):
context_factory( context_factory(
params=params, params=params,
files={'content': '...'}, files={'content': '...'},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
@pytest.mark.parametrize( @pytest.mark.parametrize(
'field', ['tags', 'relations', 'source', 'notes', 'flags']) 'field', ['tags', 'relations', 'source', 'notes', 'flags'])
def test_omitting_optional_field( def test_omitting_optional_field(
field, context_factory, post_factory, user_factory): field, context_factory, post_factory, user_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
@ -268,10 +268,10 @@ def test_errors_not_spending_ids(
'post_height': 300, 'post_height': 300,
}, },
'privileges': { 'privileges': {
'posts:create:identified': db.User.RANK_REGULAR, 'posts:create:identified': model.User.RANK_REGULAR,
}, },
}) })
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
# successful request # successful request
with patch('szurubooru.func.posts.serialize_post'), \ with patch('szurubooru.func.posts.serialize_post'), \
@ -316,7 +316,7 @@ def test_trying_to_omit_content(context_factory, user_factory):
'safety': 'safe', 'safety': 'safe',
'tags': ['tag1', 'tag2'], 'tags': ['tag1', 'tag2'],
}, },
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_create_post_without_privileges( def test_trying_to_create_post_without_privileges(
@ -324,16 +324,16 @@ def test_trying_to_create_post_without_privileges(
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.post_api.create_post(context_factory( api.post_api.create_post(context_factory(
params='whatever', params='whatever',
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=model.User.RANK_ANONYMOUS)))
def test_trying_to_create_tags_without_privileges( def test_trying_to_create_tags_without_privileges(
config_injector, context_factory, user_factory): config_injector, context_factory, user_factory):
config_injector({ config_injector({
'privileges': { 'privileges': {
'posts:create:anonymous': db.User.RANK_REGULAR, 'posts:create:anonymous': model.User.RANK_REGULAR,
'posts:create:identified': db.User.RANK_REGULAR, 'posts:create:identified': model.User.RANK_REGULAR,
'tags:create': db.User.RANK_ADMINISTRATOR, 'tags:create': model.User.RANK_ADMINISTRATOR,
}, },
}) })
with pytest.raises(errors.AuthError), \ with pytest.raises(errors.AuthError), \
@ -349,4 +349,4 @@ def test_trying_to_create_tags_without_privileges(
files={ files={
'content': posts.EMPTY_PIXEL, 'content': posts.EMPTY_PIXEL,
}, },
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))

View file

@ -1,16 +1,16 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import posts, tags, snapshots from szurubooru.func import posts, tags, snapshots
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'posts:delete': db.User.RANK_REGULAR}}) config_injector({'privileges': {'posts:delete': model.User.RANK_REGULAR}})
def test_deleting(user_factory, post_factory, context_factory): def test_deleting(user_factory, post_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory(id=1) post = post_factory(id=1)
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
@ -20,7 +20,7 @@ def test_deleting(user_factory, post_factory, context_factory):
context_factory(params={'version': 1}, user=auth_user), context_factory(params={'version': 1}, user=auth_user),
{'post_id': 1}) {'post_id': 1})
assert result == {} assert result == {}
assert db.session.query(db.Post).count() == 0 assert db.session.query(model.Post).count() == 0
snapshots.delete.assert_called_once_with(post, auth_user) snapshots.delete.assert_called_once_with(post, auth_user)
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
@ -28,7 +28,7 @@ def test_deleting(user_factory, post_factory, context_factory):
def test_trying_to_delete_non_existing(user_factory, context_factory): def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
api.post_api.delete_post( api.post_api.delete_post(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'post_id': 999}) {'post_id': 999})
@ -38,6 +38,6 @@ def test_trying_to_delete_without_privileges(
db.session.commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.post_api.delete_post( api.post_api.delete_post(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'post_id': 1}) {'post_id': 1})
assert db.session.query(db.Post).count() == 1 assert db.session.query(model.Post).count() == 1

View file

@ -1,13 +1,14 @@
from datetime import datetime from datetime import datetime
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import posts from szurubooru.func import posts
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'posts:favorite': db.User.RANK_REGULAR}}) config_injector(
{'privileges': {'posts:favorite': model.User.RANK_REGULAR}})
def test_adding_to_favorites( def test_adding_to_favorites(
@ -23,8 +24,8 @@ def test_adding_to_favorites(
context_factory(user=user_factory()), context_factory(user=user_factory()),
{'post_id': post.post_id}) {'post_id': post.post_id})
assert result == 'serialized post' assert result == 'serialized post'
post = db.session.query(db.Post).one() post = db.session.query(model.Post).one()
assert db.session.query(db.PostFavorite).count() == 1 assert db.session.query(model.PostFavorite).count() == 1
assert post is not None assert post is not None
assert post.favorite_count == 1 assert post.favorite_count == 1
assert post.score == 1 assert post.score == 1
@ -47,9 +48,9 @@ def test_removing_from_favorites(
api.post_api.delete_post_from_favorites( api.post_api.delete_post_from_favorites(
context_factory(user=user), context_factory(user=user),
{'post_id': post.post_id}) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(model.Post).one()
assert post.score == 1 assert post.score == 1
assert db.session.query(db.PostFavorite).count() == 0 assert db.session.query(model.PostFavorite).count() == 0
assert post.favorite_count == 0 assert post.favorite_count == 0
@ -68,8 +69,8 @@ def test_favoriting_twice(
api.post_api.add_post_to_favorites( api.post_api.add_post_to_favorites(
context_factory(user=user), context_factory(user=user),
{'post_id': post.post_id}) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(model.Post).one()
assert db.session.query(db.PostFavorite).count() == 1 assert db.session.query(model.PostFavorite).count() == 1
assert post.favorite_count == 1 assert post.favorite_count == 1
@ -92,8 +93,8 @@ def test_removing_twice(
api.post_api.delete_post_from_favorites( api.post_api.delete_post_from_favorites(
context_factory(user=user), context_factory(user=user),
{'post_id': post.post_id}) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(model.Post).one()
assert db.session.query(db.PostFavorite).count() == 0 assert db.session.query(model.PostFavorite).count() == 0
assert post.favorite_count == 0 assert post.favorite_count == 0
@ -113,8 +114,8 @@ def test_favorites_from_multiple_users(
api.post_api.add_post_to_favorites( api.post_api.add_post_to_favorites(
context_factory(user=user2), context_factory(user=user2),
{'post_id': post.post_id}) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(model.Post).one()
assert db.session.query(db.PostFavorite).count() == 2 assert db.session.query(model.PostFavorite).count() == 2
assert post.favorite_count == 2 assert post.favorite_count == 2
assert post.last_favorite_time == datetime(1997, 12, 2) assert post.last_favorite_time == datetime(1997, 12, 2)
@ -133,5 +134,5 @@ def test_trying_to_rate_without_privileges(
db.session.commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.post_api.add_post_to_favorites( api.post_api.add_post_to_favorites(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'post_id': post.post_id}) {'post_id': post.post_id})

View file

@ -1,6 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import posts, snapshots from szurubooru.func import posts, snapshots
@ -8,14 +8,14 @@ from szurubooru.func import posts, snapshots
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'posts:feature': db.User.RANK_REGULAR, 'posts:feature': model.User.RANK_REGULAR,
'posts:view': db.User.RANK_REGULAR, 'posts:view': model.User.RANK_REGULAR,
}, },
}) })
def test_featuring(user_factory, post_factory, context_factory): def test_featuring(user_factory, post_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory(id=1) post = post_factory(id=1)
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
@ -31,7 +31,7 @@ def test_featuring(user_factory, post_factory, context_factory):
assert posts.get_post_by_id(1).is_featured assert posts.get_post_by_id(1).is_featured
result = api.post_api.get_featured_post( result = api.post_api.get_featured_post(
context_factory( context_factory(
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
assert result == 'serialized post' assert result == 'serialized post'
snapshots.modify.assert_called_once_with(post, auth_user) snapshots.modify.assert_called_once_with(post, auth_user)
@ -40,7 +40,7 @@ def test_trying_to_omit_required_parameter(user_factory, context_factory):
with pytest.raises(errors.MissingRequiredParameterError): with pytest.raises(errors.MissingRequiredParameterError):
api.post_api.set_featured_post( api.post_api.set_featured_post(
context_factory( context_factory(
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_feature_the_same_post_twice( def test_trying_to_feature_the_same_post_twice(
@ -51,12 +51,12 @@ def test_trying_to_feature_the_same_post_twice(
api.post_api.set_featured_post( api.post_api.set_featured_post(
context_factory( context_factory(
params={'id': 1}, params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
with pytest.raises(posts.PostAlreadyFeaturedError): with pytest.raises(posts.PostAlreadyFeaturedError):
api.post_api.set_featured_post( api.post_api.set_featured_post(
context_factory( context_factory(
params={'id': 1}, params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
def test_featuring_one_post_after_another( def test_featuring_one_post_after_another(
@ -72,12 +72,12 @@ def test_featuring_one_post_after_another(
api.post_api.set_featured_post( api.post_api.set_featured_post(
context_factory( context_factory(
params={'id': 1}, params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
with fake_datetime('1998'): with fake_datetime('1998'):
api.post_api.set_featured_post( api.post_api.set_featured_post(
context_factory( context_factory(
params={'id': 2}, params={'id': 2},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
assert posts.try_get_featured_post() is not None assert posts.try_get_featured_post() is not None
assert posts.try_get_featured_post().post_id == 2 assert posts.try_get_featured_post().post_id == 2
assert not posts.get_post_by_id(1).is_featured assert not posts.get_post_by_id(1).is_featured
@ -89,7 +89,7 @@ def test_trying_to_feature_non_existing(user_factory, context_factory):
api.post_api.set_featured_post( api.post_api.set_featured_post(
context_factory( context_factory(
params={'id': 1}, params={'id': 1},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_feature_without_privileges(user_factory, context_factory): def test_trying_to_feature_without_privileges(user_factory, context_factory):
@ -97,10 +97,10 @@ def test_trying_to_feature_without_privileges(user_factory, context_factory):
api.post_api.set_featured_post( api.post_api.set_featured_post(
context_factory( context_factory(
params={'id': 1}, params={'id': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=model.User.RANK_ANONYMOUS)))
def test_getting_featured_post_without_privileges_to_view( def test_getting_featured_post_without_privileges_to_view(
user_factory, context_factory): user_factory, context_factory):
api.post_api.get_featured_post( api.post_api.get_featured_post(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS))) context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)))

View file

@ -1,16 +1,16 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import posts, snapshots from szurubooru.func import posts, snapshots
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'posts:merge': db.User.RANK_REGULAR}}) config_injector({'privileges': {'posts:merge': model.User.RANK_REGULAR}})
def test_merging(user_factory, context_factory, post_factory): def test_merging(user_factory, context_factory, post_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
source_post = post_factory() source_post = post_factory()
target_post = post_factory() target_post = post_factory()
db.session.add_all([source_post, target_post]) db.session.add_all([source_post, target_post])
@ -25,6 +25,7 @@ def test_merging(user_factory, context_factory, post_factory):
'mergeToVersion': 1, 'mergeToVersion': 1,
'remove': source_post.post_id, 'remove': source_post.post_id,
'mergeTo': target_post.post_id, 'mergeTo': target_post.post_id,
'replaceContent': False,
}, },
user=auth_user)) user=auth_user))
posts.merge_posts.called_once_with(source_post, target_post) posts.merge_posts.called_once_with(source_post, target_post)
@ -45,13 +46,14 @@ def test_trying_to_omit_mandatory_field(
'mergeToVersion': 1, 'mergeToVersion': 1,
'remove': source_post.post_id, 'remove': source_post.post_id,
'mergeTo': target_post.post_id, 'mergeTo': target_post.post_id,
'replaceContent': False,
} }
del params[field] del params[field]
with pytest.raises(errors.ValidationError): with pytest.raises(errors.ValidationError):
api.post_api.merge_posts( api.post_api.merge_posts(
context_factory( context_factory(
params=params, params=params,
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_merge_non_existing( def test_trying_to_merge_non_existing(
@ -63,12 +65,12 @@ def test_trying_to_merge_non_existing(
api.post_api.merge_posts( api.post_api.merge_posts(
context_factory( context_factory(
params={'remove': post.post_id, 'mergeTo': 999}, params={'remove': post.post_id, 'mergeTo': 999},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
api.post_api.merge_posts( api.post_api.merge_posts(
context_factory( context_factory(
params={'remove': 999, 'mergeTo': post.post_id}, params={'remove': 999, 'mergeTo': post.post_id},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_merge_without_privileges( def test_trying_to_merge_without_privileges(
@ -85,5 +87,6 @@ def test_trying_to_merge_without_privileges(
'mergeToVersion': 1, 'mergeToVersion': 1,
'remove': source_post.post_id, 'remove': source_post.post_id,
'mergeTo': target_post.post_id, 'mergeTo': target_post.post_id,
'replaceContent': False,
}, },
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=model.User.RANK_ANONYMOUS)))

View file

@ -1,12 +1,12 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import posts from szurubooru.func import posts
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'posts:score': db.User.RANK_REGULAR}}) config_injector({'privileges': {'posts:score': model.User.RANK_REGULAR}})
def test_simple_rating( def test_simple_rating(
@ -22,8 +22,8 @@ def test_simple_rating(
params={'score': 1}, user=user_factory()), params={'score': 1}, user=user_factory()),
{'post_id': post.post_id}) {'post_id': post.post_id})
assert result == 'serialized post' assert result == 'serialized post'
post = db.session.query(db.Post).one() post = db.session.query(model.Post).one()
assert db.session.query(db.PostScore).count() == 1 assert db.session.query(model.PostScore).count() == 1
assert post is not None assert post is not None
assert post.score == 1 assert post.score == 1
@ -43,8 +43,8 @@ def test_updating_rating(
api.post_api.set_post_score( api.post_api.set_post_score(
context_factory(params={'score': -1}, user=user), context_factory(params={'score': -1}, user=user),
{'post_id': post.post_id}) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(model.Post).one()
assert db.session.query(db.PostScore).count() == 1 assert db.session.query(model.PostScore).count() == 1
assert post.score == -1 assert post.score == -1
@ -63,8 +63,8 @@ def test_updating_rating_to_zero(
api.post_api.set_post_score( api.post_api.set_post_score(
context_factory(params={'score': 0}, user=user), context_factory(params={'score': 0}, user=user),
{'post_id': post.post_id}) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(model.Post).one()
assert db.session.query(db.PostScore).count() == 0 assert db.session.query(model.PostScore).count() == 0
assert post.score == 0 assert post.score == 0
@ -83,8 +83,8 @@ def test_deleting_rating(
api.post_api.delete_post_score( api.post_api.delete_post_score(
context_factory(user=user), context_factory(user=user),
{'post_id': post.post_id}) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(model.Post).one()
assert db.session.query(db.PostScore).count() == 0 assert db.session.query(model.PostScore).count() == 0
assert post.score == 0 assert post.score == 0
@ -104,8 +104,8 @@ def test_ratings_from_multiple_users(
api.post_api.set_post_score( api.post_api.set_post_score(
context_factory(params={'score': -1}, user=user2), context_factory(params={'score': -1}, user=user2),
{'post_id': post.post_id}) {'post_id': post.post_id})
post = db.session.query(db.Post).one() post = db.session.query(model.Post).one()
assert db.session.query(db.PostScore).count() == 2 assert db.session.query(model.PostScore).count() == 2
assert post.score == 0 assert post.score == 0
@ -136,5 +136,5 @@ def test_trying_to_rate_without_privileges(
api.post_api.set_post_score( api.post_api.set_post_score(
context_factory( context_factory(
params={'score': 1}, params={'score': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'post_id': post.post_id}) {'post_id': post.post_id})

View file

@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import posts from szurubooru.func import posts
@ -9,8 +9,8 @@ from szurubooru.func import posts
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'posts:list': db.User.RANK_REGULAR, 'posts:list': model.User.RANK_REGULAR,
'posts:view': db.User.RANK_REGULAR, 'posts:view': model.User.RANK_REGULAR,
}, },
}) })
@ -25,7 +25,7 @@ def test_retrieving_multiple(user_factory, post_factory, context_factory):
result = api.post_api.get_posts( result = api.post_api.get_posts(
context_factory( context_factory(
params={'query': '', 'page': 1}, params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
assert result == { assert result == {
'query': '', 'query': '',
'page': 1, 'page': 1,
@ -36,10 +36,10 @@ def test_retrieving_multiple(user_factory, post_factory, context_factory):
def test_using_special_tokens(user_factory, post_factory, context_factory): def test_using_special_tokens(user_factory, post_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
post1 = post_factory(id=1) post1 = post_factory(id=1)
post2 = post_factory(id=2) post2 = post_factory(id=2)
post1.favorited_by = [db.PostFavorite( post1.favorited_by = [model.PostFavorite(
user=auth_user, time=datetime.utcnow())] user=auth_user, time=datetime.utcnow())]
db.session.add_all([post1, post2, auth_user]) db.session.add_all([post1, post2, auth_user])
db.session.flush() db.session.flush()
@ -68,7 +68,7 @@ def test_trying_to_use_special_tokens_without_logging_in(
api.post_api.get_posts( api.post_api.get_posts(
context_factory( context_factory(
params={'query': 'special:fav', 'page': 1}, params={'query': 'special:fav', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=model.User.RANK_ANONYMOUS)))
def test_trying_to_retrieve_multiple_without_privileges( def test_trying_to_retrieve_multiple_without_privileges(
@ -77,7 +77,7 @@ def test_trying_to_retrieve_multiple_without_privileges(
api.post_api.get_posts( api.post_api.get_posts(
context_factory( context_factory(
params={'query': '', 'page': 1}, params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=model.User.RANK_ANONYMOUS)))
def test_retrieving_single(user_factory, post_factory, context_factory): def test_retrieving_single(user_factory, post_factory, context_factory):
@ -86,7 +86,7 @@ def test_retrieving_single(user_factory, post_factory, context_factory):
with patch('szurubooru.func.posts.serialize_post'): with patch('szurubooru.func.posts.serialize_post'):
posts.serialize_post.return_value = 'serialized post' posts.serialize_post.return_value = 'serialized post'
result = api.post_api.get_post( result = api.post_api.get_post(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'post_id': 1}) {'post_id': 1})
assert result == 'serialized post' assert result == 'serialized post'
@ -94,7 +94,7 @@ def test_retrieving_single(user_factory, post_factory, context_factory):
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(posts.PostNotFoundError): with pytest.raises(posts.PostNotFoundError):
api.post_api.get_post( api.post_api.get_post(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'post_id': 999}) {'post_id': 999})
@ -102,5 +102,5 @@ def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory): user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.post_api.get_post( api.post_api.get_post(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'post_id': 999}) {'post_id': 999})

View file

@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import posts, tags, snapshots, net from szurubooru.func import posts, tags, snapshots, net
@ -9,22 +9,22 @@ from szurubooru.func import posts, tags, snapshots, net
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'posts:edit:tags': db.User.RANK_REGULAR, 'posts:edit:tags': model.User.RANK_REGULAR,
'posts:edit:content': db.User.RANK_REGULAR, 'posts:edit:content': model.User.RANK_REGULAR,
'posts:edit:safety': db.User.RANK_REGULAR, 'posts:edit:safety': model.User.RANK_REGULAR,
'posts:edit:source': db.User.RANK_REGULAR, 'posts:edit:source': model.User.RANK_REGULAR,
'posts:edit:relations': db.User.RANK_REGULAR, 'posts:edit:relations': model.User.RANK_REGULAR,
'posts:edit:notes': db.User.RANK_REGULAR, 'posts:edit:notes': model.User.RANK_REGULAR,
'posts:edit:flags': db.User.RANK_REGULAR, 'posts:edit:flags': model.User.RANK_REGULAR,
'posts:edit:thumbnail': db.User.RANK_REGULAR, 'posts:edit:thumbnail': model.User.RANK_REGULAR,
'tags:create': db.User.RANK_MODERATOR, 'tags:create': model.User.RANK_MODERATOR,
}, },
}) })
def test_post_updating( def test_post_updating(
context_factory, post_factory, user_factory, fake_datetime): context_factory, post_factory, user_factory, fake_datetime):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
post = post_factory() post = post_factory()
db.session.add(post) db.session.add(post)
db.session.flush() db.session.flush()
@ -76,7 +76,7 @@ def test_post_updating(
posts.update_post_flags.assert_called_once_with( posts.update_post_flags.assert_called_once_with(
post, ['flag1', 'flag2']) post, ['flag1', 'flag2'])
posts.serialize_post.assert_called_once_with( posts.serialize_post.assert_called_once_with(
post, auth_user, options=None) post, auth_user, options=[])
snapshots.modify.assert_called_once_with(post, auth_user) snapshots.modify.assert_called_once_with(post, auth_user)
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
assert post.last_edit_time == datetime(1997, 1, 1) assert post.last_edit_time == datetime(1997, 1, 1)
@ -97,7 +97,7 @@ def test_uploading_from_url_saves_source(
api.post_api.update_post( api.post_api.update_post(
context_factory( context_factory(
params={'contentUrl': 'example.com', 'version': 1}, params={'contentUrl': 'example.com', 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'post_id': post.post_id}) {'post_id': post.post_id})
net.download.assert_called_once_with('example.com') net.download.assert_called_once_with('example.com')
posts.update_post_content.assert_called_once_with(post, b'content') posts.update_post_content.assert_called_once_with(post, b'content')
@ -122,7 +122,7 @@ def test_uploading_from_url_with_source_specified(
'contentUrl': 'example.com', 'contentUrl': 'example.com',
'source': 'example2.com', 'source': 'example2.com',
'version': 1}, 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'post_id': post.post_id}) {'post_id': post.post_id})
net.download.assert_called_once_with('example.com') net.download.assert_called_once_with('example.com')
posts.update_post_content.assert_called_once_with(post, b'content') posts.update_post_content.assert_called_once_with(post, b'content')
@ -134,7 +134,7 @@ def test_trying_to_update_non_existing(context_factory, user_factory):
api.post_api.update_post( api.post_api.update_post(
context_factory( context_factory(
params='whatever', params='whatever',
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'post_id': 1}) {'post_id': 1})
@ -158,7 +158,7 @@ def test_trying_to_update_field_without_privileges(
context_factory( context_factory(
params={**params, **{'version': 1}}, params={**params, **{'version': 1}},
files=files, files=files,
user=user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'post_id': post.post_id}) {'post_id': post.post_id})
@ -173,5 +173,5 @@ def test_trying_to_create_tags_without_privileges(
api.post_api.update_post( api.post_api.update_post(
context_factory( context_factory(
params={'tags': ['tag1', 'tag2'], 'version': 1}, params={'tags': ['tag1', 'tag2'], 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'post_id': post.post_id}) {'post_id': post.post_id})

View file

@ -1,10 +1,10 @@
from datetime import datetime from datetime import datetime
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
def snapshot_factory(): def snapshot_factory():
snapshot = db.Snapshot() snapshot = model.Snapshot()
snapshot.creation_time = datetime(1999, 1, 1) snapshot.creation_time = datetime(1999, 1, 1)
snapshot.resource_type = 'dummy' snapshot.resource_type = 'dummy'
snapshot.resource_pkey = 1 snapshot.resource_pkey = 1
@ -17,7 +17,7 @@ def snapshot_factory():
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': {'snapshots:list': db.User.RANK_REGULAR}, 'privileges': {'snapshots:list': model.User.RANK_REGULAR},
}) })
@ -29,7 +29,7 @@ def test_retrieving_multiple(user_factory, context_factory):
result = api.snapshot_api.get_snapshots( result = api.snapshot_api.get_snapshots(
context_factory( context_factory(
params={'query': '', 'page': 1}, params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
assert result['query'] == '' assert result['query'] == ''
assert result['page'] == 1 assert result['page'] == 1
assert result['pageSize'] == 100 assert result['pageSize'] == 100
@ -43,4 +43,4 @@ def test_trying_to_retrieve_multiple_without_privileges(
api.snapshot_api.get_snapshots( api.snapshot_api.get_snapshots(
context_factory( context_factory(
params={'query': '', 'page': 1}, params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=model.User.RANK_ANONYMOUS)))

View file

@ -1,6 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import tag_categories, tags, snapshots from szurubooru.func import tag_categories, tags, snapshots
@ -11,13 +11,13 @@ def _update_category_name(category, name):
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': {'tag_categories:create': db.User.RANK_REGULAR}, 'privileges': {'tag_categories:create': model.User.RANK_REGULAR},
}) })
def test_creating_category( def test_creating_category(
tag_category_factory, user_factory, context_factory): tag_category_factory, user_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
category = tag_category_factory(name='meta') category = tag_category_factory(name='meta')
db.session.add(category) db.session.add(category)
@ -49,7 +49,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
api.tag_category_api.create_tag_category( api.tag_category_api.create_tag_category(
context_factory( context_factory(
params=params, params=params,
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_create_without_privileges(user_factory, context_factory): def test_trying_to_create_without_privileges(user_factory, context_factory):
@ -57,4 +57,4 @@ def test_trying_to_create_without_privileges(user_factory, context_factory):
api.tag_category_api.create_tag_category( api.tag_category_api.create_tag_category(
context_factory( context_factory(
params={'name': 'meta', 'color': 'black'}, params={'name': 'meta', 'color': 'black'},
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=model.User.RANK_ANONYMOUS)))

View file

@ -1,18 +1,18 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import tag_categories, tags, snapshots from szurubooru.func import tag_categories, tags, snapshots
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': {'tag_categories:delete': db.User.RANK_REGULAR}, 'privileges': {'tag_categories:delete': model.User.RANK_REGULAR},
}) })
def test_deleting(user_factory, tag_category_factory, context_factory): def test_deleting(user_factory, tag_category_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
category = tag_category_factory(name='category') category = tag_category_factory(name='category')
db.session.add(tag_category_factory(name='root')) db.session.add(tag_category_factory(name='root'))
db.session.add(category) db.session.add(category)
@ -23,8 +23,8 @@ def test_deleting(user_factory, tag_category_factory, context_factory):
context_factory(params={'version': 1}, user=auth_user), context_factory(params={'version': 1}, user=auth_user),
{'category_name': 'category'}) {'category_name': 'category'})
assert result == {} assert result == {}
assert db.session.query(db.TagCategory).count() == 1 assert db.session.query(model.TagCategory).count() == 1
assert db.session.query(db.TagCategory).one().name == 'root' assert db.session.query(model.TagCategory).one().name == 'root'
snapshots.delete.assert_called_once_with(category, auth_user) snapshots.delete.assert_called_once_with(category, auth_user)
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
@ -41,9 +41,9 @@ def test_trying_to_delete_used(
api.tag_category_api.delete_tag_category( api.tag_category_api.delete_tag_category(
context_factory( context_factory(
params={'version': 1}, params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'category_name': 'category'}) {'category_name': 'category'})
assert db.session.query(db.TagCategory).count() == 1 assert db.session.query(model.TagCategory).count() == 1
def test_trying_to_delete_last( def test_trying_to_delete_last(
@ -54,14 +54,14 @@ def test_trying_to_delete_last(
api.tag_category_api.delete_tag_category( api.tag_category_api.delete_tag_category(
context_factory( context_factory(
params={'version': 1}, params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'category_name': 'root'}) {'category_name': 'root'})
def test_trying_to_delete_non_existing(user_factory, context_factory): def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(tag_categories.TagCategoryNotFoundError): with pytest.raises(tag_categories.TagCategoryNotFoundError):
api.tag_category_api.delete_tag_category( api.tag_category_api.delete_tag_category(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'category_name': 'bad'}) {'category_name': 'bad'})
@ -73,6 +73,6 @@ def test_trying_to_delete_without_privileges(
api.tag_category_api.delete_tag_category( api.tag_category_api.delete_tag_category(
context_factory( context_factory(
params={'version': 1}, params={'version': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'category_name': 'category'}) {'category_name': 'category'})
assert db.session.query(db.TagCategory).count() == 1 assert db.session.query(model.TagCategory).count() == 1

View file

@ -1,5 +1,5 @@
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import tag_categories from szurubooru.func import tag_categories
@ -7,8 +7,8 @@ from szurubooru.func import tag_categories
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'tag_categories:list': db.User.RANK_REGULAR, 'tag_categories:list': model.User.RANK_REGULAR,
'tag_categories:view': db.User.RANK_REGULAR, 'tag_categories:view': model.User.RANK_REGULAR,
}, },
}) })
@ -21,7 +21,7 @@ def test_retrieving_multiple(
]) ])
db.session.flush() db.session.flush()
result = api.tag_category_api.get_tag_categories( result = api.tag_category_api.get_tag_categories(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR))) context_factory(user=user_factory(rank=model.User.RANK_REGULAR)))
assert [cat['name'] for cat in result['results']] == ['c1', 'c2'] assert [cat['name'] for cat in result['results']] == ['c1', 'c2']
@ -30,7 +30,7 @@ def test_retrieving_single(
db.session.add(tag_category_factory(name='cat')) db.session.add(tag_category_factory(name='cat'))
db.session.flush() db.session.flush()
result = api.tag_category_api.get_tag_category( result = api.tag_category_api.get_tag_category(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'category_name': 'cat'}) {'category_name': 'cat'})
assert result == { assert result == {
'name': 'cat', 'name': 'cat',
@ -44,7 +44,7 @@ def test_retrieving_single(
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(tag_categories.TagCategoryNotFoundError): with pytest.raises(tag_categories.TagCategoryNotFoundError):
api.tag_category_api.get_tag_category( api.tag_category_api.get_tag_category(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'category_name': '-'}) {'category_name': '-'})
@ -52,5 +52,5 @@ def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory): user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.tag_category_api.get_tag_category( api.tag_category_api.get_tag_category(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'category_name': '-'}) {'category_name': '-'})

View file

@ -1,6 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import tag_categories, tags, snapshots from szurubooru.func import tag_categories, tags, snapshots
@ -12,15 +12,15 @@ def _update_category_name(category, name):
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'tag_categories:edit:name': db.User.RANK_REGULAR, 'tag_categories:edit:name': model.User.RANK_REGULAR,
'tag_categories:edit:color': db.User.RANK_REGULAR, 'tag_categories:edit:color': model.User.RANK_REGULAR,
'tag_categories:set_default': db.User.RANK_REGULAR, 'tag_categories:set_default': model.User.RANK_REGULAR,
}, },
}) })
def test_simple_updating(user_factory, tag_category_factory, context_factory): def test_simple_updating(user_factory, tag_category_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
category = tag_category_factory(name='name', color='black') category = tag_category_factory(name='name', color='black')
db.session.add(category) db.session.add(category)
db.session.flush() db.session.flush()
@ -61,7 +61,7 @@ def test_omitting_optional_field(
api.tag_category_api.update_tag_category( api.tag_category_api.update_tag_category(
context_factory( context_factory(
params={**params, **{'version': 1}}, params={**params, **{'version': 1}},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'category_name': 'name'}) {'category_name': 'name'})
@ -70,7 +70,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
api.tag_category_api.update_tag_category( api.tag_category_api.update_tag_category(
context_factory( context_factory(
params={'name': ['dummy']}, params={'name': ['dummy']},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'category_name': 'bad'}) {'category_name': 'bad'})
@ -86,7 +86,7 @@ def test_trying_to_update_without_privileges(
api.tag_category_api.update_tag_category( api.tag_category_api.update_tag_category(
context_factory( context_factory(
params={**params, **{'version': 1}}, params={**params, **{'version': 1}},
user=user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'category_name': 'dummy'}) {'category_name': 'dummy'})
@ -106,7 +106,7 @@ def test_set_as_default(user_factory, tag_category_factory, context_factory):
'color': 'white', 'color': 'white',
'version': 1, 'version': 1,
}, },
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'category_name': 'name'}) {'category_name': 'name'})
assert result == 'serialized category' assert result == 'serialized category'
tag_categories.set_default_category.assert_called_once_with(category) tag_categories.set_default_category.assert_called_once_with(category)

View file

@ -1,16 +1,16 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import tags, snapshots from szurubooru.func import tags, snapshots
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'tags:create': db.User.RANK_REGULAR}}) config_injector({'privileges': {'tags:create': model.User.RANK_REGULAR}})
def test_creating_simple_tags(tag_factory, user_factory, context_factory): def test_creating_simple_tags(tag_factory, user_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
tag = tag_factory() tag = tag_factory()
with patch('szurubooru.func.tags.create_tag'), \ with patch('szurubooru.func.tags.create_tag'), \
patch('szurubooru.func.tags.get_or_create_tags_by_names'), \ patch('szurubooru.func.tags.get_or_create_tags_by_names'), \
@ -50,7 +50,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
api.tag_api.create_tag( api.tag_api.create_tag(
context_factory( context_factory(
params=params, params=params,
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
@pytest.mark.parametrize('field', ['implications', 'suggestions']) @pytest.mark.parametrize('field', ['implications', 'suggestions'])
@ -70,7 +70,7 @@ def test_omitting_optional_field(
api.tag_api.create_tag( api.tag_api.create_tag(
context_factory( context_factory(
params=params, params=params,
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_create_tag_without_privileges( def test_trying_to_create_tag_without_privileges(
@ -84,4 +84,4 @@ def test_trying_to_create_tag_without_privileges(
'suggestions': ['tag'], 'suggestions': ['tag'],
'implications': [], 'implications': [],
}, },
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=model.User.RANK_ANONYMOUS)))

View file

@ -1,16 +1,16 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import tags, snapshots from szurubooru.func import tags, snapshots
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'tags:delete': db.User.RANK_REGULAR}}) config_injector({'privileges': {'tags:delete': model.User.RANK_REGULAR}})
def test_deleting(user_factory, tag_factory, context_factory): def test_deleting(user_factory, tag_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
tag = tag_factory(names=['tag']) tag = tag_factory(names=['tag'])
db.session.add(tag) db.session.add(tag)
db.session.commit() db.session.commit()
@ -20,7 +20,7 @@ def test_deleting(user_factory, tag_factory, context_factory):
context_factory(params={'version': 1}, user=auth_user), context_factory(params={'version': 1}, user=auth_user),
{'tag_name': 'tag'}) {'tag_name': 'tag'})
assert result == {} assert result == {}
assert db.session.query(db.Tag).count() == 0 assert db.session.query(model.Tag).count() == 0
snapshots.delete.assert_called_once_with(tag, auth_user) snapshots.delete.assert_called_once_with(tag, auth_user)
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
@ -36,17 +36,17 @@ def test_deleting_used(
api.tag_api.delete_tag( api.tag_api.delete_tag(
context_factory( context_factory(
params={'version': 1}, params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': 'tag'}) {'tag_name': 'tag'})
db.session.refresh(post) db.session.refresh(post)
assert db.session.query(db.Tag).count() == 0 assert db.session.query(model.Tag).count() == 0
assert post.tags == [] assert post.tags == []
def test_trying_to_delete_non_existing(user_factory, context_factory): def test_trying_to_delete_non_existing(user_factory, context_factory):
with pytest.raises(tags.TagNotFoundError): with pytest.raises(tags.TagNotFoundError):
api.tag_api.delete_tag( api.tag_api.delete_tag(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': 'bad'}) {'tag_name': 'bad'})
@ -58,6 +58,6 @@ def test_trying_to_delete_without_privileges(
api.tag_api.delete_tag( api.tag_api.delete_tag(
context_factory( context_factory(
params={'version': 1}, params={'version': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'tag_name': 'tag'}) {'tag_name': 'tag'})
assert db.session.query(db.Tag).count() == 1 assert db.session.query(model.Tag).count() == 1

View file

@ -1,16 +1,16 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import tags, snapshots from szurubooru.func import tags, snapshots
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'tags:merge': db.User.RANK_REGULAR}}) config_injector({'privileges': {'tags:merge': model.User.RANK_REGULAR}})
def test_merging(user_factory, tag_factory, context_factory, post_factory): def test_merging(user_factory, tag_factory, context_factory, post_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
source_tag = tag_factory(names=['source']) source_tag = tag_factory(names=['source'])
target_tag = tag_factory(names=['target']) target_tag = tag_factory(names=['target'])
db.session.add_all([source_tag, target_tag]) db.session.add_all([source_tag, target_tag])
@ -62,7 +62,7 @@ def test_trying_to_omit_mandatory_field(
api.tag_api.merge_tags( api.tag_api.merge_tags(
context_factory( context_factory(
params=params, params=params,
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_merge_non_existing( def test_trying_to_merge_non_existing(
@ -73,12 +73,12 @@ def test_trying_to_merge_non_existing(
api.tag_api.merge_tags( api.tag_api.merge_tags(
context_factory( context_factory(
params={'remove': 'good', 'mergeTo': 'bad'}, params={'remove': 'good', 'mergeTo': 'bad'},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
with pytest.raises(tags.TagNotFoundError): with pytest.raises(tags.TagNotFoundError):
api.tag_api.merge_tags( api.tag_api.merge_tags(
context_factory( context_factory(
params={'remove': 'bad', 'mergeTo': 'good'}, params={'remove': 'bad', 'mergeTo': 'good'},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
def test_trying_to_merge_without_privileges( def test_trying_to_merge_without_privileges(
@ -97,4 +97,4 @@ def test_trying_to_merge_without_privileges(
'remove': 'source', 'remove': 'source',
'mergeTo': 'target', 'mergeTo': 'target',
}, },
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=model.User.RANK_ANONYMOUS)))

View file

@ -1,6 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import tags from szurubooru.func import tags
@ -8,8 +8,8 @@ from szurubooru.func import tags
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'tags:list': db.User.RANK_REGULAR, 'tags:list': model.User.RANK_REGULAR,
'tags:view': db.User.RANK_REGULAR, 'tags:view': model.User.RANK_REGULAR,
}, },
}) })
@ -24,7 +24,7 @@ def test_retrieving_multiple(user_factory, tag_factory, context_factory):
result = api.tag_api.get_tags( result = api.tag_api.get_tags(
context_factory( context_factory(
params={'query': '', 'page': 1}, params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
assert result == { assert result == {
'query': '', 'query': '',
'page': 1, 'page': 1,
@ -40,7 +40,7 @@ def test_trying_to_retrieve_multiple_without_privileges(
api.tag_api.get_tags( api.tag_api.get_tags(
context_factory( context_factory(
params={'query': '', 'page': 1}, params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=model.User.RANK_ANONYMOUS)))
def test_retrieving_single(user_factory, tag_factory, context_factory): def test_retrieving_single(user_factory, tag_factory, context_factory):
@ -50,7 +50,7 @@ def test_retrieving_single(user_factory, tag_factory, context_factory):
tags.serialize_tag.return_value = 'serialized tag' tags.serialize_tag.return_value = 'serialized tag'
result = api.tag_api.get_tag( result = api.tag_api.get_tag(
context_factory( context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': 'tag'}) {'tag_name': 'tag'})
assert result == 'serialized tag' assert result == 'serialized tag'
@ -59,7 +59,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
with pytest.raises(tags.TagNotFoundError): with pytest.raises(tags.TagNotFoundError):
api.tag_api.get_tag( api.tag_api.get_tag(
context_factory( context_factory(
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': '-'}) {'tag_name': '-'})
@ -68,5 +68,5 @@ def test_trying_to_retrieve_single_without_privileges(
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.tag_api.get_tag( api.tag_api.get_tag(
context_factory( context_factory(
user=user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'tag_name': '-'}) {'tag_name': '-'})

View file

@ -1,12 +1,12 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import tags from szurubooru.func import tags
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_config(config_injector): def inject_config(config_injector):
config_injector({'privileges': {'tags:view': db.User.RANK_REGULAR}}) config_injector({'privileges': {'tags:view': model.User.RANK_REGULAR}})
def test_get_tag_siblings(user_factory, tag_factory, context_factory): def test_get_tag_siblings(user_factory, tag_factory, context_factory):
@ -21,7 +21,7 @@ def test_get_tag_siblings(user_factory, tag_factory, context_factory):
(tag_factory(names=['sib2']), 3), (tag_factory(names=['sib2']), 3),
] ]
result = api.tag_api.get_tag_siblings( result = api.tag_api.get_tag_siblings(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': 'tag'}) {'tag_name': 'tag'})
assert result == { assert result == {
'results': [ 'results': [
@ -40,12 +40,12 @@ def test_get_tag_siblings(user_factory, tag_factory, context_factory):
def test_trying_to_retrieve_non_existing(user_factory, context_factory): def test_trying_to_retrieve_non_existing(user_factory, context_factory):
with pytest.raises(tags.TagNotFoundError): with pytest.raises(tags.TagNotFoundError):
api.tag_api.get_tag_siblings( api.tag_api.get_tag_siblings(
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': '-'}) {'tag_name': '-'})
def test_trying_to_retrieve_without_privileges(user_factory, context_factory): def test_trying_to_retrieve_without_privileges(user_factory, context_factory):
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.tag_api.get_tag_siblings( api.tag_api.get_tag_siblings(
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'tag_name': '-'}) {'tag_name': '-'})

View file

@ -1,6 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import tags, snapshots from szurubooru.func import tags, snapshots
@ -8,18 +8,18 @@ from szurubooru.func import tags, snapshots
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'tags:create': db.User.RANK_REGULAR, 'tags:create': model.User.RANK_REGULAR,
'tags:edit:names': db.User.RANK_REGULAR, 'tags:edit:names': model.User.RANK_REGULAR,
'tags:edit:category': db.User.RANK_REGULAR, 'tags:edit:category': model.User.RANK_REGULAR,
'tags:edit:description': db.User.RANK_REGULAR, 'tags:edit:description': model.User.RANK_REGULAR,
'tags:edit:suggestions': db.User.RANK_REGULAR, 'tags:edit:suggestions': model.User.RANK_REGULAR,
'tags:edit:implications': db.User.RANK_REGULAR, 'tags:edit:implications': model.User.RANK_REGULAR,
}, },
}) })
def test_simple_updating(user_factory, tag_factory, context_factory): def test_simple_updating(user_factory, tag_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
tag = tag_factory(names=['tag1', 'tag2']) tag = tag_factory(names=['tag1', 'tag2'])
db.session.add(tag) db.session.add(tag)
db.session.commit() db.session.commit()
@ -56,8 +56,7 @@ def test_simple_updating(user_factory, tag_factory, context_factory):
tag, ['sug1', 'sug2']) tag, ['sug1', 'sug2'])
tags.update_tag_implications.assert_called_once_with( tags.update_tag_implications.assert_called_once_with(
tag, ['imp1', 'imp2']) tag, ['imp1', 'imp2'])
tags.serialize_tag.assert_called_once_with( tags.serialize_tag.assert_called_once_with(tag, options=[])
tag, options=None)
snapshots.modify.assert_called_once_with(tag, auth_user) snapshots.modify.assert_called_once_with(tag, auth_user)
tags.export_to_json.assert_called_once_with() tags.export_to_json.assert_called_once_with()
@ -90,7 +89,7 @@ def test_omitting_optional_field(
api.tag_api.update_tag( api.tag_api.update_tag(
context_factory( context_factory(
params={**params, **{'version': 1}}, params={**params, **{'version': 1}},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': 'tag'}) {'tag_name': 'tag'})
@ -99,7 +98,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
api.tag_api.update_tag( api.tag_api.update_tag(
context_factory( context_factory(
params={'names': ['dummy']}, params={'names': ['dummy']},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': 'tag1'}) {'tag_name': 'tag1'})
@ -117,7 +116,7 @@ def test_trying_to_update_without_privileges(
api.tag_api.update_tag( api.tag_api.update_tag(
context_factory( context_factory(
params={**params, **{'version': 1}}, params={**params, **{'version': 1}},
user=user_factory(rank=db.User.RANK_ANONYMOUS)), user=user_factory(rank=model.User.RANK_ANONYMOUS)),
{'tag_name': 'tag'}) {'tag_name': 'tag'})
@ -127,9 +126,9 @@ def test_trying_to_create_tags_without_privileges(
db.session.add(tag) db.session.add(tag)
db.session.commit() db.session.commit()
config_injector({'privileges': { config_injector({'privileges': {
'tags:create': db.User.RANK_ADMINISTRATOR, 'tags:create': model.User.RANK_ADMINISTRATOR,
'tags:edit:suggestions': db.User.RANK_REGULAR, 'tags:edit:suggestions': model.User.RANK_REGULAR,
'tags:edit:implications': db.User.RANK_REGULAR, 'tags:edit:implications': model.User.RANK_REGULAR,
}}) }})
with patch('szurubooru.func.tags.get_or_create_tags_by_names'): with patch('szurubooru.func.tags.get_or_create_tags_by_names'):
tags.get_or_create_tags_by_names.return_value = ([], ['new-tag']) tags.get_or_create_tags_by_names.return_value = ([], ['new-tag'])
@ -137,12 +136,12 @@ def test_trying_to_create_tags_without_privileges(
api.tag_api.update_tag( api.tag_api.update_tag(
context_factory( context_factory(
params={'suggestions': ['tag1', 'tag2'], 'version': 1}, params={'suggestions': ['tag1', 'tag2'], 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': 'tag'}) {'tag_name': 'tag'})
db.session.rollback() db.session.rollback()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.tag_api.update_tag( api.tag_api.update_tag(
context_factory( context_factory(
params={'implications': ['tag1', 'tag2'], 'version': 1}, params={'implications': ['tag1', 'tag2'], 'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'tag_name': 'tag'}) {'tag_name': 'tag'})

View file

@ -1,6 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import users from szurubooru.func import users
@ -31,7 +31,7 @@ def test_creating_user(user_factory, context_factory, fake_datetime):
'avatarStyle': 'manual', 'avatarStyle': 'manual',
}, },
files={'avatar': b'...'}, files={'avatar': b'...'},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
assert result == 'serialized user' assert result == 'serialized user'
users.create_user.assert_called_once_with( users.create_user.assert_called_once_with(
'chewie1', 'oks', 'asd@asd.asd') 'chewie1', 'oks', 'asd@asd.asd')
@ -50,7 +50,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
'password': 'oks', 'password': 'oks',
} }
user = user_factory() user = user_factory()
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
del params[field] del params[field]
with patch('szurubooru.func.users.create_user'), \ with patch('szurubooru.func.users.create_user'), \
pytest.raises(errors.MissingRequiredParameterError): pytest.raises(errors.MissingRequiredParameterError):
@ -70,7 +70,7 @@ def test_omitting_optional_field(user_factory, context_factory, field):
} }
del params[field] del params[field]
user = user_factory() user = user_factory()
auth_user = user_factory(rank=db.User.RANK_MODERATOR) auth_user = user_factory(rank=model.User.RANK_MODERATOR)
with patch('szurubooru.func.users.create_user'), \ with patch('szurubooru.func.users.create_user'), \
patch('szurubooru.func.users.update_user_avatar'), \ patch('szurubooru.func.users.update_user_avatar'), \
patch('szurubooru.func.users.serialize_user'): patch('szurubooru.func.users.serialize_user'):
@ -84,4 +84,4 @@ def test_trying_to_create_user_without_privileges(
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.user_api.create_user(context_factory( api.user_api.create_user(context_factory(
params='whatever', params='whatever',
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=model.User.RANK_ANONYMOUS)))

View file

@ -1,5 +1,5 @@
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import users from szurubooru.func import users
@ -7,45 +7,45 @@ from szurubooru.func import users
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'users:delete:self': db.User.RANK_REGULAR, 'users:delete:self': model.User.RANK_REGULAR,
'users:delete:any': db.User.RANK_MODERATOR, 'users:delete:any': model.User.RANK_MODERATOR,
}, },
}) })
def test_deleting_oneself(user_factory, context_factory): def test_deleting_oneself(user_factory, context_factory):
user = user_factory(name='u', rank=db.User.RANK_REGULAR) user = user_factory(name='u', rank=model.User.RANK_REGULAR)
db.session.add(user) db.session.add(user)
db.session.commit() db.session.commit()
result = api.user_api.delete_user( result = api.user_api.delete_user(
context_factory( context_factory(
params={'version': 1}, user=user), {'user_name': 'u'}) params={'version': 1}, user=user), {'user_name': 'u'})
assert result == {} assert result == {}
assert db.session.query(db.User).count() == 0 assert db.session.query(model.User).count() == 0
def test_deleting_someone_else(user_factory, context_factory): def test_deleting_someone_else(user_factory, context_factory):
user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR) user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR)
user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR) user2 = user_factory(name='u2', rank=model.User.RANK_MODERATOR)
db.session.add_all([user1, user2]) db.session.add_all([user1, user2])
db.session.commit() db.session.commit()
api.user_api.delete_user( api.user_api.delete_user(
context_factory( context_factory(
params={'version': 1}, user=user2), {'user_name': 'u1'}) params={'version': 1}, user=user2), {'user_name': 'u1'})
assert db.session.query(db.User).count() == 1 assert db.session.query(model.User).count() == 1
def test_trying_to_delete_someone_else_without_privileges( def test_trying_to_delete_someone_else_without_privileges(
user_factory, context_factory): user_factory, context_factory):
user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR) user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR)
user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR) user2 = user_factory(name='u2', rank=model.User.RANK_REGULAR)
db.session.add_all([user1, user2]) db.session.add_all([user1, user2])
db.session.commit() db.session.commit()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.user_api.delete_user( api.user_api.delete_user(
context_factory( context_factory(
params={'version': 1}, user=user2), {'user_name': 'u1'}) params={'version': 1}, user=user2), {'user_name': 'u1'})
assert db.session.query(db.User).count() == 2 assert db.session.query(model.User).count() == 2
def test_trying_to_delete_non_existing(user_factory, context_factory): def test_trying_to_delete_non_existing(user_factory, context_factory):
@ -53,5 +53,5 @@ def test_trying_to_delete_non_existing(user_factory, context_factory):
api.user_api.delete_user( api.user_api.delete_user(
context_factory( context_factory(
params={'version': 1}, params={'version': 1},
user=user_factory(rank=db.User.RANK_REGULAR)), user=user_factory(rank=model.User.RANK_REGULAR)),
{'user_name': 'bad'}) {'user_name': 'bad'})

View file

@ -1,6 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import users from szurubooru.func import users
@ -8,16 +8,16 @@ from szurubooru.func import users
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'users:list': db.User.RANK_REGULAR, 'users:list': model.User.RANK_REGULAR,
'users:view': db.User.RANK_REGULAR, 'users:view': model.User.RANK_REGULAR,
'users:edit:any:email': db.User.RANK_MODERATOR, 'users:edit:any:email': model.User.RANK_MODERATOR,
}, },
}) })
def test_retrieving_multiple(user_factory, context_factory): def test_retrieving_multiple(user_factory, context_factory):
user1 = user_factory(name='u1', rank=db.User.RANK_MODERATOR) user1 = user_factory(name='u1', rank=model.User.RANK_MODERATOR)
user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR) user2 = user_factory(name='u2', rank=model.User.RANK_MODERATOR)
db.session.add_all([user1, user2]) db.session.add_all([user1, user2])
db.session.flush() db.session.flush()
with patch('szurubooru.func.users.serialize_user'): with patch('szurubooru.func.users.serialize_user'):
@ -25,7 +25,7 @@ def test_retrieving_multiple(user_factory, context_factory):
result = api.user_api.get_users( result = api.user_api.get_users(
context_factory( context_factory(
params={'query': '', 'page': 1}, params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_REGULAR))) user=user_factory(rank=model.User.RANK_REGULAR)))
assert result == { assert result == {
'query': '', 'query': '',
'page': 1, 'page': 1,
@ -41,12 +41,12 @@ def test_trying_to_retrieve_multiple_without_privileges(
api.user_api.get_users( api.user_api.get_users(
context_factory( context_factory(
params={'query': '', 'page': 1}, params={'query': '', 'page': 1},
user=user_factory(rank=db.User.RANK_ANONYMOUS))) user=user_factory(rank=model.User.RANK_ANONYMOUS)))
def test_retrieving_single(user_factory, context_factory): def test_retrieving_single(user_factory, context_factory):
user = user_factory(name='u1', rank=db.User.RANK_REGULAR) user = user_factory(name='u1', rank=model.User.RANK_REGULAR)
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
db.session.add(user) db.session.add(user)
db.session.flush() db.session.flush()
with patch('szurubooru.func.users.serialize_user'): with patch('szurubooru.func.users.serialize_user'):
@ -57,7 +57,7 @@ def test_retrieving_single(user_factory, context_factory):
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_REGULAR) auth_user = user_factory(rank=model.User.RANK_REGULAR)
with pytest.raises(users.UserNotFoundError): with pytest.raises(users.UserNotFoundError):
api.user_api.get_user( api.user_api.get_user(
context_factory(user=auth_user), {'user_name': '-'}) context_factory(user=auth_user), {'user_name': '-'})
@ -65,8 +65,8 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
def test_trying_to_retrieve_single_without_privileges( def test_trying_to_retrieve_single_without_privileges(
user_factory, context_factory): user_factory, context_factory):
auth_user = user_factory(rank=db.User.RANK_ANONYMOUS) auth_user = user_factory(rank=model.User.RANK_ANONYMOUS)
db.session.add(user_factory(name='u1', rank=db.User.RANK_REGULAR)) db.session.add(user_factory(name='u1', rank=model.User.RANK_REGULAR))
db.session.flush() db.session.flush()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):
api.user_api.get_user( api.user_api.get_user(

View file

@ -1,6 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, errors from szurubooru import api, db, model, errors
from szurubooru.func import users from szurubooru.func import users
@ -8,23 +8,23 @@ from szurubooru.func import users
def inject_config(config_injector): def inject_config(config_injector):
config_injector({ config_injector({
'privileges': { 'privileges': {
'users:edit:self:name': db.User.RANK_REGULAR, 'users:edit:self:name': model.User.RANK_REGULAR,
'users:edit:self:pass': db.User.RANK_REGULAR, 'users:edit:self:pass': model.User.RANK_REGULAR,
'users:edit:self:email': db.User.RANK_REGULAR, 'users:edit:self:email': model.User.RANK_REGULAR,
'users:edit:self:rank': db.User.RANK_MODERATOR, 'users:edit:self:rank': model.User.RANK_MODERATOR,
'users:edit:self:avatar': db.User.RANK_MODERATOR, 'users:edit:self:avatar': model.User.RANK_MODERATOR,
'users:edit:any:name': db.User.RANK_MODERATOR, 'users:edit:any:name': model.User.RANK_MODERATOR,
'users:edit:any:pass': db.User.RANK_MODERATOR, 'users:edit:any:pass': model.User.RANK_MODERATOR,
'users:edit:any:email': db.User.RANK_MODERATOR, 'users:edit:any:email': model.User.RANK_MODERATOR,
'users:edit:any:rank': db.User.RANK_ADMINISTRATOR, 'users:edit:any:rank': model.User.RANK_ADMINISTRATOR,
'users:edit:any:avatar': db.User.RANK_ADMINISTRATOR, 'users:edit:any:avatar': model.User.RANK_ADMINISTRATOR,
}, },
}) })
def test_updating_user(context_factory, user_factory): def test_updating_user(context_factory, user_factory):
user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR)
auth_user = user_factory(rank=db.User.RANK_ADMINISTRATOR) auth_user = user_factory(rank=model.User.RANK_ADMINISTRATOR)
db.session.add(user) db.session.add(user)
db.session.flush() db.session.flush()
@ -63,13 +63,13 @@ def test_updating_user(context_factory, user_factory):
users.update_user_avatar.assert_called_once_with( users.update_user_avatar.assert_called_once_with(
user, 'manual', b'...') user, 'manual', b'...')
users.serialize_user.assert_called_once_with( users.serialize_user.assert_called_once_with(
user, auth_user, options=None) user, auth_user, options=[])
@pytest.mark.parametrize( @pytest.mark.parametrize(
'field', ['name', 'email', 'password', 'rank', 'avatarStyle']) 'field', ['name', 'email', 'password', 'rank', 'avatarStyle'])
def test_omitting_optional_field(user_factory, context_factory, field): def test_omitting_optional_field(user_factory, context_factory, field):
user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR)
db.session.add(user) db.session.add(user)
db.session.flush() db.session.flush()
params = { params = {
@ -96,7 +96,7 @@ def test_omitting_optional_field(user_factory, context_factory, field):
def test_trying_to_update_non_existing(user_factory, context_factory): def test_trying_to_update_non_existing(user_factory, context_factory):
user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR)
db.session.add(user) db.session.add(user)
db.session.flush() db.session.flush()
with pytest.raises(users.UserNotFoundError): with pytest.raises(users.UserNotFoundError):
@ -113,8 +113,8 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
]) ])
def test_trying_to_update_field_without_privileges( def test_trying_to_update_field_without_privileges(
user_factory, context_factory, params): user_factory, context_factory, params):
user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR) user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR)
user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR) user2 = user_factory(name='u2', rank=model.User.RANK_REGULAR)
db.session.add_all([user1, user2]) db.session.add_all([user1, user2])
db.session.flush() db.session.flush()
with pytest.raises(errors.AuthError): with pytest.raises(errors.AuthError):

Some files were not shown because too many files have changed in this diff Show more