server: refactor + add type hinting
- Added type hinting (for now, 3.5-compatible) - Split `db` namespace into `db` module and `model` namespace - Changed elastic search to be created lazily for each operation - Changed to class based approach in entity serialization to allow stronger typing - Removed `required` argument from `context.get_*` family of functions; now it's implied if `default` argument is omitted - Changed `unalias_dict` implementation to use less magic inputs
This commit is contained in:
parent
abf1fc2b2d
commit
ad842ee8a5
116 changed files with 2868 additions and 2037 deletions
|
@ -8,7 +8,7 @@ import zlib
|
||||||
import concurrent.futures
|
import concurrent.futures
|
||||||
import logging
|
import logging
|
||||||
import coloredlogs
|
import coloredlogs
|
||||||
import sqlalchemy
|
import sqlalchemy as sa
|
||||||
from szurubooru import config, db
|
from szurubooru import config, db
|
||||||
from szurubooru.func import files, images, posts, comments
|
from szurubooru.func import files, images, posts, comments
|
||||||
|
|
||||||
|
@ -42,8 +42,8 @@ def get_v1_session(args):
|
||||||
port=args.port,
|
port=args.port,
|
||||||
name=args.name)
|
name=args.name)
|
||||||
logger.info('Connecting to %r...', dsn)
|
logger.info('Connecting to %r...', dsn)
|
||||||
engine = sqlalchemy.create_engine(dsn)
|
engine = sa.create_engine(dsn)
|
||||||
session_maker = sqlalchemy.orm.sessionmaker(bind=engine)
|
session_maker = sa.orm.sessionmaker(bind=engine)
|
||||||
return session_maker()
|
return session_maker()
|
||||||
|
|
||||||
def parse_args():
|
def parse_args():
|
||||||
|
|
14
server/mypy.ini
Normal file
14
server/mypy.ini
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
[mypy]
|
||||||
|
ignore_missing_imports = True
|
||||||
|
follow_imports = skip
|
||||||
|
disallow_untyped_calls = True
|
||||||
|
disallow_untyped_defs = True
|
||||||
|
check_untyped_defs = True
|
||||||
|
disallow_subclassing_any = False
|
||||||
|
warn_redundant_casts = True
|
||||||
|
warn_unused_ignores = True
|
||||||
|
strict_optional = True
|
||||||
|
strict_boolean = False
|
||||||
|
|
||||||
|
[mypy-szurubooru.tests.*]
|
||||||
|
ignore_errors=True
|
|
@ -1,31 +1,44 @@
|
||||||
import datetime
|
from typing import Dict
|
||||||
from szurubooru import search
|
from datetime import datetime
|
||||||
from szurubooru.rest import routes
|
from szurubooru import search, rest, model
|
||||||
from szurubooru.func import auth, comments, posts, scores, util, versions
|
from szurubooru.func import (
|
||||||
|
auth, comments, posts, scores, versions, serialization)
|
||||||
|
|
||||||
|
|
||||||
_search_executor = search.Executor(search.configs.CommentSearchConfig())
|
_search_executor = search.Executor(search.configs.CommentSearchConfig())
|
||||||
|
|
||||||
|
|
||||||
def _serialize(ctx, comment, **kwargs):
|
def _get_comment(params: Dict[str, str]) -> model.Comment:
|
||||||
|
try:
|
||||||
|
comment_id = int(params['comment_id'])
|
||||||
|
except TypeError:
|
||||||
|
raise comments.InvalidCommentIdError(
|
||||||
|
'Invalid comment ID: %r.' % params['comment_id'])
|
||||||
|
return comments.get_comment_by_id(comment_id)
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize(
|
||||||
|
ctx: rest.Context, comment: model.Comment) -> rest.Response:
|
||||||
return comments.serialize_comment(
|
return comments.serialize_comment(
|
||||||
comment,
|
comment,
|
||||||
ctx.user,
|
ctx.user,
|
||||||
options=util.get_serialization_options(ctx), **kwargs)
|
options=serialization.get_serialization_options(ctx))
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/comments/?')
|
@rest.routes.get('/comments/?')
|
||||||
def get_comments(ctx, _params=None):
|
def get_comments(
|
||||||
|
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'comments:list')
|
auth.verify_privilege(ctx.user, 'comments:list')
|
||||||
return _search_executor.execute_and_serialize(
|
return _search_executor.execute_and_serialize(
|
||||||
ctx, lambda comment: _serialize(ctx, comment))
|
ctx, lambda comment: _serialize(ctx, comment))
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/comments/?')
|
@rest.routes.post('/comments/?')
|
||||||
def create_comment(ctx, _params=None):
|
def create_comment(
|
||||||
|
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'comments:create')
|
auth.verify_privilege(ctx.user, 'comments:create')
|
||||||
text = ctx.get_param_as_string('text', required=True)
|
text = ctx.get_param_as_string('text')
|
||||||
post_id = ctx.get_param_as_int('postId', required=True)
|
post_id = ctx.get_param_as_int('postId')
|
||||||
post = posts.get_post_by_id(post_id)
|
post = posts.get_post_by_id(post_id)
|
||||||
comment = comments.create_comment(ctx.user, post, text)
|
comment = comments.create_comment(ctx.user, post, text)
|
||||||
ctx.session.add(comment)
|
ctx.session.add(comment)
|
||||||
|
@ -33,30 +46,30 @@ def create_comment(ctx, _params=None):
|
||||||
return _serialize(ctx, comment)
|
return _serialize(ctx, comment)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/comment/(?P<comment_id>[^/]+)/?')
|
@rest.routes.get('/comment/(?P<comment_id>[^/]+)/?')
|
||||||
def get_comment(ctx, params):
|
def get_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'comments:view')
|
auth.verify_privilege(ctx.user, 'comments:view')
|
||||||
comment = comments.get_comment_by_id(params['comment_id'])
|
comment = _get_comment(params)
|
||||||
return _serialize(ctx, comment)
|
return _serialize(ctx, comment)
|
||||||
|
|
||||||
|
|
||||||
@routes.put('/comment/(?P<comment_id>[^/]+)/?')
|
@rest.routes.put('/comment/(?P<comment_id>[^/]+)/?')
|
||||||
def update_comment(ctx, params):
|
def update_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
comment = comments.get_comment_by_id(params['comment_id'])
|
comment = _get_comment(params)
|
||||||
versions.verify_version(comment, ctx)
|
versions.verify_version(comment, ctx)
|
||||||
versions.bump_version(comment)
|
versions.bump_version(comment)
|
||||||
infix = 'own' if ctx.user.user_id == comment.user_id else 'any'
|
infix = 'own' if ctx.user.user_id == comment.user_id else 'any'
|
||||||
text = ctx.get_param_as_string('text', required=True)
|
text = ctx.get_param_as_string('text')
|
||||||
auth.verify_privilege(ctx.user, 'comments:edit:%s' % infix)
|
auth.verify_privilege(ctx.user, 'comments:edit:%s' % infix)
|
||||||
comments.update_comment_text(comment, text)
|
comments.update_comment_text(comment, text)
|
||||||
comment.last_edit_time = datetime.datetime.utcnow()
|
comment.last_edit_time = datetime.utcnow()
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize(ctx, comment)
|
return _serialize(ctx, comment)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete('/comment/(?P<comment_id>[^/]+)/?')
|
@rest.routes.delete('/comment/(?P<comment_id>[^/]+)/?')
|
||||||
def delete_comment(ctx, params):
|
def delete_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
comment = comments.get_comment_by_id(params['comment_id'])
|
comment = _get_comment(params)
|
||||||
versions.verify_version(comment, ctx)
|
versions.verify_version(comment, ctx)
|
||||||
infix = 'own' if ctx.user.user_id == comment.user_id else 'any'
|
infix = 'own' if ctx.user.user_id == comment.user_id else 'any'
|
||||||
auth.verify_privilege(ctx.user, 'comments:delete:%s' % infix)
|
auth.verify_privilege(ctx.user, 'comments:delete:%s' % infix)
|
||||||
|
@ -65,20 +78,22 @@ def delete_comment(ctx, params):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@routes.put('/comment/(?P<comment_id>[^/]+)/score/?')
|
@rest.routes.put('/comment/(?P<comment_id>[^/]+)/score/?')
|
||||||
def set_comment_score(ctx, params):
|
def set_comment_score(
|
||||||
|
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'comments:score')
|
auth.verify_privilege(ctx.user, 'comments:score')
|
||||||
score = ctx.get_param_as_int('score', required=True)
|
score = ctx.get_param_as_int('score')
|
||||||
comment = comments.get_comment_by_id(params['comment_id'])
|
comment = _get_comment(params)
|
||||||
scores.set_score(comment, ctx.user, score)
|
scores.set_score(comment, ctx.user, score)
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize(ctx, comment)
|
return _serialize(ctx, comment)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete('/comment/(?P<comment_id>[^/]+)/score/?')
|
@rest.routes.delete('/comment/(?P<comment_id>[^/]+)/score/?')
|
||||||
def delete_comment_score(ctx, params):
|
def delete_comment_score(
|
||||||
|
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'comments:score')
|
auth.verify_privilege(ctx.user, 'comments:score')
|
||||||
comment = comments.get_comment_by_id(params['comment_id'])
|
comment = _get_comment(params)
|
||||||
scores.delete_score(comment, ctx.user)
|
scores.delete_score(comment, ctx.user)
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize(ctx, comment)
|
return _serialize(ctx, comment)
|
||||||
|
|
|
@ -1,19 +1,20 @@
|
||||||
import datetime
|
|
||||||
import os
|
import os
|
||||||
from szurubooru import config
|
from typing import Optional, Dict
|
||||||
from szurubooru.rest import routes
|
from datetime import datetime, timedelta
|
||||||
|
from szurubooru import config, rest
|
||||||
from szurubooru.func import posts, users, util
|
from szurubooru.func import posts, users, util
|
||||||
|
|
||||||
|
|
||||||
_cache_time = None
|
_cache_time = None # type: Optional[datetime]
|
||||||
_cache_result = None
|
_cache_result = None # type: Optional[int]
|
||||||
|
|
||||||
|
|
||||||
def _get_disk_usage():
|
def _get_disk_usage() -> int:
|
||||||
global _cache_time, _cache_result # pylint: disable=global-statement
|
global _cache_time, _cache_result # pylint: disable=global-statement
|
||||||
threshold = datetime.timedelta(hours=48)
|
threshold = timedelta(hours=48)
|
||||||
now = datetime.datetime.utcnow()
|
now = datetime.utcnow()
|
||||||
if _cache_time and _cache_time > now - threshold:
|
if _cache_time and _cache_time > now - threshold:
|
||||||
|
assert _cache_result
|
||||||
return _cache_result
|
return _cache_result
|
||||||
total_size = 0
|
total_size = 0
|
||||||
for dir_path, _, file_names in os.walk(config.config['data_dir']):
|
for dir_path, _, file_names in os.walk(config.config['data_dir']):
|
||||||
|
@ -25,8 +26,9 @@ def _get_disk_usage():
|
||||||
return total_size
|
return total_size
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/info/?')
|
@rest.routes.get('/info/?')
|
||||||
def get_info(ctx, _params=None):
|
def get_info(
|
||||||
|
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
post_feature = posts.try_get_current_post_feature()
|
post_feature = posts.try_get_current_post_feature()
|
||||||
return {
|
return {
|
||||||
'postCount': posts.get_post_count(),
|
'postCount': posts.get_post_count(),
|
||||||
|
@ -38,7 +40,7 @@ def get_info(ctx, _params=None):
|
||||||
'featuringUser':
|
'featuringUser':
|
||||||
users.serialize_user(post_feature.user, ctx.user)
|
users.serialize_user(post_feature.user, ctx.user)
|
||||||
if post_feature else None,
|
if post_feature else None,
|
||||||
'serverTime': datetime.datetime.utcnow(),
|
'serverTime': datetime.utcnow(),
|
||||||
'config': {
|
'config': {
|
||||||
'userNameRegex': config.config['user_name_regex'],
|
'userNameRegex': config.config['user_name_regex'],
|
||||||
'passwordRegex': config.config['password_regex'],
|
'passwordRegex': config.config['password_regex'],
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from szurubooru import config, errors
|
from typing import Dict
|
||||||
from szurubooru.rest import routes
|
from szurubooru import config, errors, rest
|
||||||
from szurubooru.func import auth, mailer, users, versions
|
from szurubooru.func import auth, mailer, users, versions
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,9 +10,9 @@ MAIL_BODY = \
|
||||||
'Otherwise, please ignore this email.'
|
'Otherwise, please ignore this email.'
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/password-reset/(?P<user_name>[^/]+)/?')
|
@rest.routes.get('/password-reset/(?P<user_name>[^/]+)/?')
|
||||||
def start_password_reset(_ctx, params):
|
def start_password_reset(
|
||||||
''' Send a mail with secure token to the correlated user. '''
|
_ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
user_name = params['user_name']
|
user_name = params['user_name']
|
||||||
user = users.get_user_by_name_or_email(user_name)
|
user = users.get_user_by_name_or_email(user_name)
|
||||||
if not user.email:
|
if not user.email:
|
||||||
|
@ -30,13 +30,13 @@ def start_password_reset(_ctx, params):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/password-reset/(?P<user_name>[^/]+)/?')
|
@rest.routes.post('/password-reset/(?P<user_name>[^/]+)/?')
|
||||||
def finish_password_reset(ctx, params):
|
def finish_password_reset(
|
||||||
''' Verify token from mail, generate a new password and return it. '''
|
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
user_name = params['user_name']
|
user_name = params['user_name']
|
||||||
user = users.get_user_by_name_or_email(user_name)
|
user = users.get_user_by_name_or_email(user_name)
|
||||||
good_token = auth.generate_authentication_token(user)
|
good_token = auth.generate_authentication_token(user)
|
||||||
token = ctx.get_param_as_string('token', required=True)
|
token = ctx.get_param_as_string('token')
|
||||||
if token != good_token:
|
if token != good_token:
|
||||||
raise errors.ValidationError('Invalid password reset token.')
|
raise errors.ValidationError('Invalid password reset token.')
|
||||||
new_password = users.reset_user_password(user)
|
new_password = users.reset_user_password(user)
|
||||||
|
|
|
@ -1,44 +1,60 @@
|
||||||
import datetime
|
from typing import Optional, Dict
|
||||||
from szurubooru import search, db, errors
|
from datetime import datetime
|
||||||
from szurubooru.rest import routes
|
from szurubooru import db, model, errors, rest, search
|
||||||
from szurubooru.func import (
|
from szurubooru.func import (
|
||||||
auth, tags, posts, snapshots, favorites, scores, util, versions)
|
auth, tags, posts, snapshots, favorites, scores, serialization, versions)
|
||||||
|
|
||||||
|
|
||||||
_search_executor = search.Executor(search.configs.PostSearchConfig())
|
_search_executor_config = search.configs.PostSearchConfig()
|
||||||
|
_search_executor = search.Executor(_search_executor_config)
|
||||||
|
|
||||||
|
|
||||||
def _serialize_post(ctx, post):
|
def _get_post_id(params: Dict[str, str]) -> int:
|
||||||
|
try:
|
||||||
|
return int(params['post_id'])
|
||||||
|
except TypeError:
|
||||||
|
raise posts.InvalidPostIdError(
|
||||||
|
'Invalid post ID: %r.' % params['post_id'])
|
||||||
|
|
||||||
|
|
||||||
|
def _get_post(params: Dict[str, str]) -> model.Post:
|
||||||
|
return posts.get_post_by_id(_get_post_id(params))
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_post(
|
||||||
|
ctx: rest.Context, post: Optional[model.Post]) -> rest.Response:
|
||||||
return posts.serialize_post(
|
return posts.serialize_post(
|
||||||
post,
|
post,
|
||||||
ctx.user,
|
ctx.user,
|
||||||
options=util.get_serialization_options(ctx))
|
options=serialization.get_serialization_options(ctx))
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/posts/?')
|
@rest.routes.get('/posts/?')
|
||||||
def get_posts(ctx, _params=None):
|
def get_posts(
|
||||||
|
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'posts:list')
|
auth.verify_privilege(ctx.user, 'posts:list')
|
||||||
_search_executor.config.user = ctx.user
|
_search_executor_config.user = ctx.user
|
||||||
return _search_executor.execute_and_serialize(
|
return _search_executor.execute_and_serialize(
|
||||||
ctx, lambda post: _serialize_post(ctx, post))
|
ctx, lambda post: _serialize_post(ctx, post))
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/posts/?')
|
@rest.routes.post('/posts/?')
|
||||||
def create_post(ctx, _params=None):
|
def create_post(
|
||||||
|
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
anonymous = ctx.get_param_as_bool('anonymous', default=False)
|
anonymous = ctx.get_param_as_bool('anonymous', default=False)
|
||||||
if anonymous:
|
if anonymous:
|
||||||
auth.verify_privilege(ctx.user, 'posts:create:anonymous')
|
auth.verify_privilege(ctx.user, 'posts:create:anonymous')
|
||||||
else:
|
else:
|
||||||
auth.verify_privilege(ctx.user, 'posts:create:identified')
|
auth.verify_privilege(ctx.user, 'posts:create:identified')
|
||||||
content = ctx.get_file('content', required=True)
|
content = ctx.get_file('content')
|
||||||
tag_names = ctx.get_param_as_list('tags', required=False, default=[])
|
tag_names = ctx.get_param_as_list('tags', default=[])
|
||||||
safety = ctx.get_param_as_string('safety', required=True)
|
safety = ctx.get_param_as_string('safety')
|
||||||
source = ctx.get_param_as_string('source', required=False, default=None)
|
source = ctx.get_param_as_string('source', default='')
|
||||||
if ctx.has_param('contentUrl') and not source:
|
if ctx.has_param('contentUrl') and not source:
|
||||||
source = ctx.get_param_as_string('contentUrl')
|
source = ctx.get_param_as_string('contentUrl', default='')
|
||||||
relations = ctx.get_param_as_list('relations', required=False) or []
|
relations = ctx.get_param_as_list('relations', default=[])
|
||||||
notes = ctx.get_param_as_list('notes', required=False) or []
|
notes = ctx.get_param_as_list('notes', default=[])
|
||||||
flags = ctx.get_param_as_list('flags', required=False) or []
|
flags = ctx.get_param_as_list('flags', default=[])
|
||||||
|
|
||||||
post, new_tags = posts.create_post(
|
post, new_tags = posts.create_post(
|
||||||
content, tag_names, None if anonymous else ctx.user)
|
content, tag_names, None if anonymous else ctx.user)
|
||||||
|
@ -61,16 +77,16 @@ def create_post(ctx, _params=None):
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/post/(?P<post_id>[^/]+)/?')
|
@rest.routes.get('/post/(?P<post_id>[^/]+)/?')
|
||||||
def get_post(ctx, params):
|
def get_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'posts:view')
|
auth.verify_privilege(ctx.user, 'posts:view')
|
||||||
post = posts.get_post_by_id(params['post_id'])
|
post = _get_post(params)
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.put('/post/(?P<post_id>[^/]+)/?')
|
@rest.routes.put('/post/(?P<post_id>[^/]+)/?')
|
||||||
def update_post(ctx, params):
|
def update_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
post = posts.get_post_by_id(params['post_id'])
|
post = _get_post(params)
|
||||||
versions.verify_version(post, ctx)
|
versions.verify_version(post, ctx)
|
||||||
versions.bump_version(post)
|
versions.bump_version(post)
|
||||||
if ctx.has_file('content'):
|
if ctx.has_file('content'):
|
||||||
|
@ -104,7 +120,7 @@ def update_post(ctx, params):
|
||||||
if ctx.has_file('thumbnail'):
|
if ctx.has_file('thumbnail'):
|
||||||
auth.verify_privilege(ctx.user, 'posts:edit:thumbnail')
|
auth.verify_privilege(ctx.user, 'posts:edit:thumbnail')
|
||||||
posts.update_post_thumbnail(post, ctx.get_file('thumbnail'))
|
posts.update_post_thumbnail(post, ctx.get_file('thumbnail'))
|
||||||
post.last_edit_time = datetime.datetime.utcnow()
|
post.last_edit_time = datetime.utcnow()
|
||||||
ctx.session.flush()
|
ctx.session.flush()
|
||||||
snapshots.modify(post, ctx.user)
|
snapshots.modify(post, ctx.user)
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
|
@ -112,10 +128,10 @@ def update_post(ctx, params):
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete('/post/(?P<post_id>[^/]+)/?')
|
@rest.routes.delete('/post/(?P<post_id>[^/]+)/?')
|
||||||
def delete_post(ctx, params):
|
def delete_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'posts:delete')
|
auth.verify_privilege(ctx.user, 'posts:delete')
|
||||||
post = posts.get_post_by_id(params['post_id'])
|
post = _get_post(params)
|
||||||
versions.verify_version(post, ctx)
|
versions.verify_version(post, ctx)
|
||||||
snapshots.delete(post, ctx.user)
|
snapshots.delete(post, ctx.user)
|
||||||
posts.delete(post)
|
posts.delete(post)
|
||||||
|
@ -124,13 +140,14 @@ def delete_post(ctx, params):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/post-merge/?')
|
@rest.routes.post('/post-merge/?')
|
||||||
def merge_posts(ctx, _params=None):
|
def merge_posts(
|
||||||
source_post_id = ctx.get_param_as_string('remove', required=True) or ''
|
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
target_post_id = ctx.get_param_as_string('mergeTo', required=True) or ''
|
source_post_id = ctx.get_param_as_int('remove')
|
||||||
replace_content = ctx.get_param_as_bool('replaceContent')
|
target_post_id = ctx.get_param_as_int('mergeTo')
|
||||||
source_post = posts.get_post_by_id(source_post_id)
|
source_post = posts.get_post_by_id(source_post_id)
|
||||||
target_post = posts.get_post_by_id(target_post_id)
|
target_post = posts.get_post_by_id(target_post_id)
|
||||||
|
replace_content = ctx.get_param_as_bool('replaceContent')
|
||||||
versions.verify_version(source_post, ctx, 'removeVersion')
|
versions.verify_version(source_post, ctx, 'removeVersion')
|
||||||
versions.verify_version(target_post, ctx, 'mergeToVersion')
|
versions.verify_version(target_post, ctx, 'mergeToVersion')
|
||||||
versions.bump_version(target_post)
|
versions.bump_version(target_post)
|
||||||
|
@ -141,16 +158,18 @@ def merge_posts(ctx, _params=None):
|
||||||
return _serialize_post(ctx, target_post)
|
return _serialize_post(ctx, target_post)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/featured-post/?')
|
@rest.routes.get('/featured-post/?')
|
||||||
def get_featured_post(ctx, _params=None):
|
def get_featured_post(
|
||||||
|
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
post = posts.try_get_featured_post()
|
post = posts.try_get_featured_post()
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/featured-post/?')
|
@rest.routes.post('/featured-post/?')
|
||||||
def set_featured_post(ctx, _params=None):
|
def set_featured_post(
|
||||||
|
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'posts:feature')
|
auth.verify_privilege(ctx.user, 'posts:feature')
|
||||||
post_id = ctx.get_param_as_int('id', required=True)
|
post_id = ctx.get_param_as_int('id')
|
||||||
post = posts.get_post_by_id(post_id)
|
post = posts.get_post_by_id(post_id)
|
||||||
featured_post = posts.try_get_featured_post()
|
featured_post = posts.try_get_featured_post()
|
||||||
if featured_post and featured_post.post_id == post.post_id:
|
if featured_post and featured_post.post_id == post.post_id:
|
||||||
|
@ -162,55 +181,61 @@ def set_featured_post(ctx, _params=None):
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.put('/post/(?P<post_id>[^/]+)/score/?')
|
@rest.routes.put('/post/(?P<post_id>[^/]+)/score/?')
|
||||||
def set_post_score(ctx, params):
|
def set_post_score(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'posts:score')
|
auth.verify_privilege(ctx.user, 'posts:score')
|
||||||
post = posts.get_post_by_id(params['post_id'])
|
post = _get_post(params)
|
||||||
score = ctx.get_param_as_int('score', required=True)
|
score = ctx.get_param_as_int('score')
|
||||||
scores.set_score(post, ctx.user, score)
|
scores.set_score(post, ctx.user, score)
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete('/post/(?P<post_id>[^/]+)/score/?')
|
@rest.routes.delete('/post/(?P<post_id>[^/]+)/score/?')
|
||||||
def delete_post_score(ctx, params):
|
def delete_post_score(
|
||||||
|
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'posts:score')
|
auth.verify_privilege(ctx.user, 'posts:score')
|
||||||
post = posts.get_post_by_id(params['post_id'])
|
post = _get_post(params)
|
||||||
scores.delete_score(post, ctx.user)
|
scores.delete_score(post, ctx.user)
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/post/(?P<post_id>[^/]+)/favorite/?')
|
@rest.routes.post('/post/(?P<post_id>[^/]+)/favorite/?')
|
||||||
def add_post_to_favorites(ctx, params):
|
def add_post_to_favorites(
|
||||||
|
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'posts:favorite')
|
auth.verify_privilege(ctx.user, 'posts:favorite')
|
||||||
post = posts.get_post_by_id(params['post_id'])
|
post = _get_post(params)
|
||||||
favorites.set_favorite(post, ctx.user)
|
favorites.set_favorite(post, ctx.user)
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete('/post/(?P<post_id>[^/]+)/favorite/?')
|
@rest.routes.delete('/post/(?P<post_id>[^/]+)/favorite/?')
|
||||||
def delete_post_from_favorites(ctx, params):
|
def delete_post_from_favorites(
|
||||||
|
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'posts:favorite')
|
auth.verify_privilege(ctx.user, 'posts:favorite')
|
||||||
post = posts.get_post_by_id(params['post_id'])
|
post = _get_post(params)
|
||||||
favorites.unset_favorite(post, ctx.user)
|
favorites.unset_favorite(post, ctx.user)
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize_post(ctx, post)
|
return _serialize_post(ctx, post)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/post/(?P<post_id>[^/]+)/around/?')
|
@rest.routes.get('/post/(?P<post_id>[^/]+)/around/?')
|
||||||
def get_posts_around(ctx, params):
|
def get_posts_around(
|
||||||
|
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'posts:list')
|
auth.verify_privilege(ctx.user, 'posts:list')
|
||||||
_search_executor.config.user = ctx.user
|
_search_executor_config.user = ctx.user
|
||||||
|
post_id = _get_post_id(params)
|
||||||
return _search_executor.get_around_and_serialize(
|
return _search_executor.get_around_and_serialize(
|
||||||
ctx, params['post_id'], lambda post: _serialize_post(ctx, post))
|
ctx, post_id, lambda post: _serialize_post(ctx, post))
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/posts/reverse-search/?')
|
@rest.routes.post('/posts/reverse-search/?')
|
||||||
def get_posts_by_image(ctx, _params=None):
|
def get_posts_by_image(
|
||||||
|
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'posts:reverse_search')
|
auth.verify_privilege(ctx.user, 'posts:reverse_search')
|
||||||
content = ctx.get_file('content', required=True)
|
content = ctx.get_file('content')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
lookalikes = posts.search_by_image(content)
|
lookalikes = posts.search_by_image(content)
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
from szurubooru import search
|
from typing import Dict
|
||||||
from szurubooru.rest import routes
|
from szurubooru import search, rest
|
||||||
from szurubooru.func import auth, snapshots
|
from szurubooru.func import auth, snapshots
|
||||||
|
|
||||||
|
|
||||||
_search_executor = search.Executor(search.configs.SnapshotSearchConfig())
|
_search_executor = search.Executor(search.configs.SnapshotSearchConfig())
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/snapshots/?')
|
@rest.routes.get('/snapshots/?')
|
||||||
def get_snapshots(ctx, _params=None):
|
def get_snapshots(
|
||||||
|
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'snapshots:list')
|
auth.verify_privilege(ctx.user, 'snapshots:list')
|
||||||
return _search_executor.execute_and_serialize(
|
return _search_executor.execute_and_serialize(
|
||||||
ctx, lambda snapshot: snapshots.serialize_snapshot(snapshot, ctx.user))
|
ctx, lambda snapshot: snapshots.serialize_snapshot(snapshot, ctx.user))
|
||||||
|
|
|
@ -1,18 +1,22 @@
|
||||||
import datetime
|
from typing import Optional, List, Dict
|
||||||
from szurubooru import db, search
|
from datetime import datetime
|
||||||
from szurubooru.rest import routes
|
from szurubooru import db, model, search, rest
|
||||||
from szurubooru.func import auth, tags, snapshots, util, versions
|
from szurubooru.func import auth, tags, snapshots, serialization, versions
|
||||||
|
|
||||||
|
|
||||||
_search_executor = search.Executor(search.configs.TagSearchConfig())
|
_search_executor = search.Executor(search.configs.TagSearchConfig())
|
||||||
|
|
||||||
|
|
||||||
def _serialize(ctx, tag):
|
def _serialize(ctx: rest.Context, tag: model.Tag) -> rest.Response:
|
||||||
return tags.serialize_tag(
|
return tags.serialize_tag(
|
||||||
tag, options=util.get_serialization_options(ctx))
|
tag, options=serialization.get_serialization_options(ctx))
|
||||||
|
|
||||||
|
|
||||||
def _create_if_needed(tag_names, user):
|
def _get_tag(params: Dict[str, str]) -> model.Tag:
|
||||||
|
return tags.get_tag_by_name(params['tag_name'])
|
||||||
|
|
||||||
|
|
||||||
|
def _create_if_needed(tag_names: List[str], user: model.User) -> None:
|
||||||
if not tag_names:
|
if not tag_names:
|
||||||
return
|
return
|
||||||
_existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
|
_existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
|
||||||
|
@ -23,25 +27,22 @@ def _create_if_needed(tag_names, user):
|
||||||
snapshots.create(tag, user)
|
snapshots.create(tag, user)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/tags/?')
|
@rest.routes.get('/tags/?')
|
||||||
def get_tags(ctx, _params=None):
|
def get_tags(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'tags:list')
|
auth.verify_privilege(ctx.user, 'tags:list')
|
||||||
return _search_executor.execute_and_serialize(
|
return _search_executor.execute_and_serialize(
|
||||||
ctx, lambda tag: _serialize(ctx, tag))
|
ctx, lambda tag: _serialize(ctx, tag))
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/tags/?')
|
@rest.routes.post('/tags/?')
|
||||||
def create_tag(ctx, _params=None):
|
def create_tag(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'tags:create')
|
auth.verify_privilege(ctx.user, 'tags:create')
|
||||||
|
|
||||||
names = ctx.get_param_as_list('names', required=True)
|
names = ctx.get_param_as_list('names')
|
||||||
category = ctx.get_param_as_string('category', required=True)
|
category = ctx.get_param_as_string('category')
|
||||||
description = ctx.get_param_as_string(
|
description = ctx.get_param_as_string('description', default='')
|
||||||
'description', required=False, default=None)
|
suggestions = ctx.get_param_as_list('suggestions', default=[])
|
||||||
suggestions = ctx.get_param_as_list(
|
implications = ctx.get_param_as_list('implications', default=[])
|
||||||
'suggestions', required=False, default=[])
|
|
||||||
implications = ctx.get_param_as_list(
|
|
||||||
'implications', required=False, default=[])
|
|
||||||
|
|
||||||
_create_if_needed(suggestions, ctx.user)
|
_create_if_needed(suggestions, ctx.user)
|
||||||
_create_if_needed(implications, ctx.user)
|
_create_if_needed(implications, ctx.user)
|
||||||
|
@ -56,16 +57,16 @@ def create_tag(ctx, _params=None):
|
||||||
return _serialize(ctx, tag)
|
return _serialize(ctx, tag)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/tag/(?P<tag_name>.+)')
|
@rest.routes.get('/tag/(?P<tag_name>.+)')
|
||||||
def get_tag(ctx, params):
|
def get_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'tags:view')
|
auth.verify_privilege(ctx.user, 'tags:view')
|
||||||
tag = tags.get_tag_by_name(params['tag_name'])
|
tag = _get_tag(params)
|
||||||
return _serialize(ctx, tag)
|
return _serialize(ctx, tag)
|
||||||
|
|
||||||
|
|
||||||
@routes.put('/tag/(?P<tag_name>.+)')
|
@rest.routes.put('/tag/(?P<tag_name>.+)')
|
||||||
def update_tag(ctx, params):
|
def update_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
tag = tags.get_tag_by_name(params['tag_name'])
|
tag = _get_tag(params)
|
||||||
versions.verify_version(tag, ctx)
|
versions.verify_version(tag, ctx)
|
||||||
versions.bump_version(tag)
|
versions.bump_version(tag)
|
||||||
if ctx.has_param('names'):
|
if ctx.has_param('names'):
|
||||||
|
@ -78,7 +79,7 @@ def update_tag(ctx, params):
|
||||||
if ctx.has_param('description'):
|
if ctx.has_param('description'):
|
||||||
auth.verify_privilege(ctx.user, 'tags:edit:description')
|
auth.verify_privilege(ctx.user, 'tags:edit:description')
|
||||||
tags.update_tag_description(
|
tags.update_tag_description(
|
||||||
tag, ctx.get_param_as_string('description', default=None))
|
tag, ctx.get_param_as_string('description'))
|
||||||
if ctx.has_param('suggestions'):
|
if ctx.has_param('suggestions'):
|
||||||
auth.verify_privilege(ctx.user, 'tags:edit:suggestions')
|
auth.verify_privilege(ctx.user, 'tags:edit:suggestions')
|
||||||
suggestions = ctx.get_param_as_list('suggestions')
|
suggestions = ctx.get_param_as_list('suggestions')
|
||||||
|
@ -89,7 +90,7 @@ def update_tag(ctx, params):
|
||||||
implications = ctx.get_param_as_list('implications')
|
implications = ctx.get_param_as_list('implications')
|
||||||
_create_if_needed(implications, ctx.user)
|
_create_if_needed(implications, ctx.user)
|
||||||
tags.update_tag_implications(tag, implications)
|
tags.update_tag_implications(tag, implications)
|
||||||
tag.last_edit_time = datetime.datetime.utcnow()
|
tag.last_edit_time = datetime.utcnow()
|
||||||
ctx.session.flush()
|
ctx.session.flush()
|
||||||
snapshots.modify(tag, ctx.user)
|
snapshots.modify(tag, ctx.user)
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
|
@ -97,9 +98,9 @@ def update_tag(ctx, params):
|
||||||
return _serialize(ctx, tag)
|
return _serialize(ctx, tag)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete('/tag/(?P<tag_name>.+)')
|
@rest.routes.delete('/tag/(?P<tag_name>.+)')
|
||||||
def delete_tag(ctx, params):
|
def delete_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
tag = tags.get_tag_by_name(params['tag_name'])
|
tag = _get_tag(params)
|
||||||
versions.verify_version(tag, ctx)
|
versions.verify_version(tag, ctx)
|
||||||
auth.verify_privilege(ctx.user, 'tags:delete')
|
auth.verify_privilege(ctx.user, 'tags:delete')
|
||||||
snapshots.delete(tag, ctx.user)
|
snapshots.delete(tag, ctx.user)
|
||||||
|
@ -109,10 +110,11 @@ def delete_tag(ctx, params):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/tag-merge/?')
|
@rest.routes.post('/tag-merge/?')
|
||||||
def merge_tags(ctx, _params=None):
|
def merge_tags(
|
||||||
source_tag_name = ctx.get_param_as_string('remove', required=True) or ''
|
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
target_tag_name = ctx.get_param_as_string('mergeTo', required=True) or ''
|
source_tag_name = ctx.get_param_as_string('remove')
|
||||||
|
target_tag_name = ctx.get_param_as_string('mergeTo')
|
||||||
source_tag = tags.get_tag_by_name(source_tag_name)
|
source_tag = tags.get_tag_by_name(source_tag_name)
|
||||||
target_tag = tags.get_tag_by_name(target_tag_name)
|
target_tag = tags.get_tag_by_name(target_tag_name)
|
||||||
versions.verify_version(source_tag, ctx, 'removeVersion')
|
versions.verify_version(source_tag, ctx, 'removeVersion')
|
||||||
|
@ -126,10 +128,11 @@ def merge_tags(ctx, _params=None):
|
||||||
return _serialize(ctx, target_tag)
|
return _serialize(ctx, target_tag)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/tag-siblings/(?P<tag_name>.+)')
|
@rest.routes.get('/tag-siblings/(?P<tag_name>.+)')
|
||||||
def get_tag_siblings(ctx, params):
|
def get_tag_siblings(
|
||||||
|
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'tags:view')
|
auth.verify_privilege(ctx.user, 'tags:view')
|
||||||
tag = tags.get_tag_by_name(params['tag_name'])
|
tag = _get_tag(params)
|
||||||
result = tags.get_tag_siblings(tag)
|
result = tags.get_tag_siblings(tag)
|
||||||
serialized_siblings = []
|
serialized_siblings = []
|
||||||
for sibling, occurrences in result:
|
for sibling, occurrences in result:
|
||||||
|
|
|
@ -1,15 +1,18 @@
|
||||||
from szurubooru.rest import routes
|
from typing import Dict
|
||||||
|
from szurubooru import model, rest
|
||||||
from szurubooru.func import (
|
from szurubooru.func import (
|
||||||
auth, tags, tag_categories, snapshots, util, versions)
|
auth, tags, tag_categories, snapshots, serialization, versions)
|
||||||
|
|
||||||
|
|
||||||
def _serialize(ctx, category):
|
def _serialize(
|
||||||
|
ctx: rest.Context, category: model.TagCategory) -> rest.Response:
|
||||||
return tag_categories.serialize_category(
|
return tag_categories.serialize_category(
|
||||||
category, options=util.get_serialization_options(ctx))
|
category, options=serialization.get_serialization_options(ctx))
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/tag-categories/?')
|
@rest.routes.get('/tag-categories/?')
|
||||||
def get_tag_categories(ctx, _params=None):
|
def get_tag_categories(
|
||||||
|
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'tag_categories:list')
|
auth.verify_privilege(ctx.user, 'tag_categories:list')
|
||||||
categories = tag_categories.get_all_categories()
|
categories = tag_categories.get_all_categories()
|
||||||
return {
|
return {
|
||||||
|
@ -17,11 +20,12 @@ def get_tag_categories(ctx, _params=None):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/tag-categories/?')
|
@rest.routes.post('/tag-categories/?')
|
||||||
def create_tag_category(ctx, _params=None):
|
def create_tag_category(
|
||||||
|
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'tag_categories:create')
|
auth.verify_privilege(ctx.user, 'tag_categories:create')
|
||||||
name = ctx.get_param_as_string('name', required=True)
|
name = ctx.get_param_as_string('name')
|
||||||
color = ctx.get_param_as_string('color', required=True)
|
color = ctx.get_param_as_string('color')
|
||||||
category = tag_categories.create_category(name, color)
|
category = tag_categories.create_category(name, color)
|
||||||
ctx.session.add(category)
|
ctx.session.add(category)
|
||||||
ctx.session.flush()
|
ctx.session.flush()
|
||||||
|
@ -31,15 +35,17 @@ def create_tag_category(ctx, _params=None):
|
||||||
return _serialize(ctx, category)
|
return _serialize(ctx, category)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/tag-category/(?P<category_name>[^/]+)/?')
|
@rest.routes.get('/tag-category/(?P<category_name>[^/]+)/?')
|
||||||
def get_tag_category(ctx, params):
|
def get_tag_category(
|
||||||
|
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'tag_categories:view')
|
auth.verify_privilege(ctx.user, 'tag_categories:view')
|
||||||
category = tag_categories.get_category_by_name(params['category_name'])
|
category = tag_categories.get_category_by_name(params['category_name'])
|
||||||
return _serialize(ctx, category)
|
return _serialize(ctx, category)
|
||||||
|
|
||||||
|
|
||||||
@routes.put('/tag-category/(?P<category_name>[^/]+)/?')
|
@rest.routes.put('/tag-category/(?P<category_name>[^/]+)/?')
|
||||||
def update_tag_category(ctx, params):
|
def update_tag_category(
|
||||||
|
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
category = tag_categories.get_category_by_name(
|
category = tag_categories.get_category_by_name(
|
||||||
params['category_name'], lock=True)
|
params['category_name'], lock=True)
|
||||||
versions.verify_version(category, ctx)
|
versions.verify_version(category, ctx)
|
||||||
|
@ -59,8 +65,9 @@ def update_tag_category(ctx, params):
|
||||||
return _serialize(ctx, category)
|
return _serialize(ctx, category)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete('/tag-category/(?P<category_name>[^/]+)/?')
|
@rest.routes.delete('/tag-category/(?P<category_name>[^/]+)/?')
|
||||||
def delete_tag_category(ctx, params):
|
def delete_tag_category(
|
||||||
|
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
category = tag_categories.get_category_by_name(
|
category = tag_categories.get_category_by_name(
|
||||||
params['category_name'], lock=True)
|
params['category_name'], lock=True)
|
||||||
versions.verify_version(category, ctx)
|
versions.verify_version(category, ctx)
|
||||||
|
@ -72,8 +79,9 @@ def delete_tag_category(ctx, params):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
@routes.put('/tag-category/(?P<category_name>[^/]+)/default/?')
|
@rest.routes.put('/tag-category/(?P<category_name>[^/]+)/default/?')
|
||||||
def set_tag_category_as_default(ctx, params):
|
def set_tag_category_as_default(
|
||||||
|
ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'tag_categories:set_default')
|
auth.verify_privilege(ctx.user, 'tag_categories:set_default')
|
||||||
category = tag_categories.get_category_by_name(
|
category = tag_categories.get_category_by_name(
|
||||||
params['category_name'], lock=True)
|
params['category_name'], lock=True)
|
||||||
|
|
|
@ -1,10 +1,12 @@
|
||||||
from szurubooru.rest import routes
|
from typing import Dict
|
||||||
|
from szurubooru import rest
|
||||||
from szurubooru.func import auth, file_uploads
|
from szurubooru.func import auth, file_uploads
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/uploads/?')
|
@rest.routes.post('/uploads/?')
|
||||||
def create_temporary_file(ctx, _params=None):
|
def create_temporary_file(
|
||||||
|
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'uploads:create')
|
auth.verify_privilege(ctx.user, 'uploads:create')
|
||||||
content = ctx.get_file('content', required=True, allow_tokens=False)
|
content = ctx.get_file('content', allow_tokens=False)
|
||||||
token = file_uploads.save(content)
|
token = file_uploads.save(content)
|
||||||
return {'token': token}
|
return {'token': token}
|
||||||
|
|
|
@ -1,56 +1,57 @@
|
||||||
from szurubooru import search
|
from typing import Any, Dict
|
||||||
from szurubooru.rest import routes
|
from szurubooru import model, search, rest
|
||||||
from szurubooru.func import auth, users, util, versions
|
from szurubooru.func import auth, users, serialization, versions
|
||||||
|
|
||||||
|
|
||||||
_search_executor = search.Executor(search.configs.UserSearchConfig())
|
_search_executor = search.Executor(search.configs.UserSearchConfig())
|
||||||
|
|
||||||
|
|
||||||
def _serialize(ctx, user, **kwargs):
|
def _serialize(
|
||||||
|
ctx: rest.Context, user: model.User, **kwargs: Any) -> rest.Response:
|
||||||
return users.serialize_user(
|
return users.serialize_user(
|
||||||
user,
|
user,
|
||||||
ctx.user,
|
ctx.user,
|
||||||
options=util.get_serialization_options(ctx),
|
options=serialization.get_serialization_options(ctx),
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/users/?')
|
@rest.routes.get('/users/?')
|
||||||
def get_users(ctx, _params=None):
|
def get_users(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'users:list')
|
auth.verify_privilege(ctx.user, 'users:list')
|
||||||
return _search_executor.execute_and_serialize(
|
return _search_executor.execute_and_serialize(
|
||||||
ctx, lambda user: _serialize(ctx, user))
|
ctx, lambda user: _serialize(ctx, user))
|
||||||
|
|
||||||
|
|
||||||
@routes.post('/users/?')
|
@rest.routes.post('/users/?')
|
||||||
def create_user(ctx, _params=None):
|
def create_user(
|
||||||
|
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
|
||||||
auth.verify_privilege(ctx.user, 'users:create')
|
auth.verify_privilege(ctx.user, 'users:create')
|
||||||
name = ctx.get_param_as_string('name', required=True)
|
name = ctx.get_param_as_string('name')
|
||||||
password = ctx.get_param_as_string('password', required=True)
|
password = ctx.get_param_as_string('password')
|
||||||
email = ctx.get_param_as_string('email', required=False, default='')
|
email = ctx.get_param_as_string('email', default='')
|
||||||
user = users.create_user(name, password, email)
|
user = users.create_user(name, password, email)
|
||||||
if ctx.has_param('rank'):
|
if ctx.has_param('rank'):
|
||||||
users.update_user_rank(
|
users.update_user_rank(user, ctx.get_param_as_string('rank'), ctx.user)
|
||||||
user, ctx.get_param_as_string('rank'), ctx.user)
|
|
||||||
if ctx.has_param('avatarStyle'):
|
if ctx.has_param('avatarStyle'):
|
||||||
users.update_user_avatar(
|
users.update_user_avatar(
|
||||||
user,
|
user,
|
||||||
ctx.get_param_as_string('avatarStyle'),
|
ctx.get_param_as_string('avatarStyle'),
|
||||||
ctx.get_file('avatar'))
|
ctx.get_file('avatar', default=b''))
|
||||||
ctx.session.add(user)
|
ctx.session.add(user)
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize(ctx, user, force_show_email=True)
|
return _serialize(ctx, user, force_show_email=True)
|
||||||
|
|
||||||
|
|
||||||
@routes.get('/user/(?P<user_name>[^/]+)/?')
|
@rest.routes.get('/user/(?P<user_name>[^/]+)/?')
|
||||||
def get_user(ctx, params):
|
def get_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
user = users.get_user_by_name(params['user_name'])
|
user = users.get_user_by_name(params['user_name'])
|
||||||
if ctx.user.user_id != user.user_id:
|
if ctx.user.user_id != user.user_id:
|
||||||
auth.verify_privilege(ctx.user, 'users:view')
|
auth.verify_privilege(ctx.user, 'users:view')
|
||||||
return _serialize(ctx, user)
|
return _serialize(ctx, user)
|
||||||
|
|
||||||
|
|
||||||
@routes.put('/user/(?P<user_name>[^/]+)/?')
|
@rest.routes.put('/user/(?P<user_name>[^/]+)/?')
|
||||||
def update_user(ctx, params):
|
def update_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
user = users.get_user_by_name(params['user_name'])
|
user = users.get_user_by_name(params['user_name'])
|
||||||
versions.verify_version(user, ctx)
|
versions.verify_version(user, ctx)
|
||||||
versions.bump_version(user)
|
versions.bump_version(user)
|
||||||
|
@ -74,13 +75,13 @@ def update_user(ctx, params):
|
||||||
users.update_user_avatar(
|
users.update_user_avatar(
|
||||||
user,
|
user,
|
||||||
ctx.get_param_as_string('avatarStyle'),
|
ctx.get_param_as_string('avatarStyle'),
|
||||||
ctx.get_file('avatar'))
|
ctx.get_file('avatar', default=b''))
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
return _serialize(ctx, user)
|
return _serialize(ctx, user)
|
||||||
|
|
||||||
|
|
||||||
@routes.delete('/user/(?P<user_name>[^/]+)/?')
|
@rest.routes.delete('/user/(?P<user_name>[^/]+)/?')
|
||||||
def delete_user(ctx, params):
|
def delete_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
user = users.get_user_by_name(params['user_name'])
|
user = users.get_user_by_name(params['user_name'])
|
||||||
versions.verify_version(user, ctx)
|
versions.verify_version(user, ctx)
|
||||||
infix = 'self' if ctx.user.user_id == user.user_id else 'any'
|
infix = 'self' if ctx.user.user_id == user.user_id else 'any'
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
|
from typing import Dict
|
||||||
import os
|
import os
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
|
||||||
def merge(left, right):
|
def merge(left: Dict, right: Dict) -> Dict:
|
||||||
for key in right:
|
for key in right:
|
||||||
if key in left:
|
if key in left:
|
||||||
if isinstance(left[key], dict) and isinstance(right[key], dict):
|
if isinstance(left[key], dict) and isinstance(right[key], dict):
|
||||||
|
@ -14,7 +15,7 @@ def merge(left, right):
|
||||||
return left
|
return left
|
||||||
|
|
||||||
|
|
||||||
def read_config():
|
def read_config() -> Dict:
|
||||||
with open('../config.yaml.dist') as handle:
|
with open('../config.yaml.dist') as handle:
|
||||||
ret = yaml.load(handle.read())
|
ret = yaml.load(handle.read())
|
||||||
if os.path.exists('../config.yaml'):
|
if os.path.exists('../config.yaml'):
|
||||||
|
|
36
server/szurubooru/db.py
Normal file
36
server/szurubooru/db.py
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
from typing import Any
|
||||||
|
import threading
|
||||||
|
import sqlalchemy as sa
|
||||||
|
import sqlalchemy.orm
|
||||||
|
from szurubooru import config
|
||||||
|
|
||||||
|
# pylint: disable=invalid-name
|
||||||
|
_data = threading.local()
|
||||||
|
_engine = sa.create_engine(config.config['database']) # type: Any
|
||||||
|
sessionmaker = sa.orm.sessionmaker(bind=_engine, autoflush=False) # type: Any
|
||||||
|
session = sa.orm.scoped_session(sessionmaker) # type: Any
|
||||||
|
|
||||||
|
|
||||||
|
def get_session() -> Any:
|
||||||
|
global session
|
||||||
|
return session
|
||||||
|
|
||||||
|
|
||||||
|
def set_sesssion(new_session: Any) -> None:
|
||||||
|
global session
|
||||||
|
session = new_session
|
||||||
|
|
||||||
|
|
||||||
|
def reset_query_count() -> None:
|
||||||
|
_data.query_count = 0
|
||||||
|
|
||||||
|
|
||||||
|
def get_query_count() -> int:
|
||||||
|
return _data.query_count
|
||||||
|
|
||||||
|
|
||||||
|
def _bump_query_count() -> None:
|
||||||
|
_data.query_count = getattr(_data, 'query_count', 0) + 1
|
||||||
|
|
||||||
|
|
||||||
|
sa.event.listen(_engine, 'after_execute', lambda *args: _bump_query_count())
|
|
@ -1,17 +0,0 @@
|
||||||
from szurubooru.db.base import Base
|
|
||||||
from szurubooru.db.user import User
|
|
||||||
from szurubooru.db.tag_category import TagCategory
|
|
||||||
from szurubooru.db.tag import (Tag, TagName, TagSuggestion, TagImplication)
|
|
||||||
from szurubooru.db.post import (
|
|
||||||
Post,
|
|
||||||
PostTag,
|
|
||||||
PostRelation,
|
|
||||||
PostFavorite,
|
|
||||||
PostScore,
|
|
||||||
PostNote,
|
|
||||||
PostFeature)
|
|
||||||
from szurubooru.db.comment import (Comment, CommentScore)
|
|
||||||
from szurubooru.db.snapshot import Snapshot
|
|
||||||
from szurubooru.db.session import (
|
|
||||||
session, sessionmaker, reset_query_count, get_query_count)
|
|
||||||
import szurubooru.db.util
|
|
|
@ -1,27 +0,0 @@
|
||||||
import threading
|
|
||||||
import sqlalchemy
|
|
||||||
from szurubooru import config
|
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
|
||||||
_engine = sqlalchemy.create_engine(config.config['database'])
|
|
||||||
sessionmaker = sqlalchemy.orm.sessionmaker(bind=_engine, autoflush=False)
|
|
||||||
session = sqlalchemy.orm.scoped_session(sessionmaker)
|
|
||||||
|
|
||||||
_data = threading.local()
|
|
||||||
|
|
||||||
|
|
||||||
def reset_query_count():
|
|
||||||
_data.query_count = 0
|
|
||||||
|
|
||||||
|
|
||||||
def get_query_count():
|
|
||||||
return _data.query_count
|
|
||||||
|
|
||||||
|
|
||||||
def _bump_query_count():
|
|
||||||
_data.query_count = getattr(_data, 'query_count', 0) + 1
|
|
||||||
|
|
||||||
|
|
||||||
sqlalchemy.event.listen(
|
|
||||||
_engine, 'after_execute', lambda *args: _bump_query_count())
|
|
|
@ -1,34 +0,0 @@
|
||||||
from sqlalchemy.inspection import inspect
|
|
||||||
|
|
||||||
|
|
||||||
def get_resource_info(entity):
|
|
||||||
serializers = {
|
|
||||||
'tag': lambda tag: tag.first_name,
|
|
||||||
'tag_category': lambda category: category.name,
|
|
||||||
'comment': lambda comment: comment.comment_id,
|
|
||||||
'post': lambda post: post.post_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
resource_type = entity.__table__.name
|
|
||||||
assert resource_type in serializers
|
|
||||||
|
|
||||||
primary_key = inspect(entity).identity
|
|
||||||
assert primary_key is not None
|
|
||||||
assert len(primary_key) == 1
|
|
||||||
|
|
||||||
resource_name = serializers[resource_type](entity)
|
|
||||||
assert resource_name
|
|
||||||
|
|
||||||
resource_pkey = primary_key[0]
|
|
||||||
assert resource_pkey
|
|
||||||
|
|
||||||
return (resource_type, resource_pkey, resource_name)
|
|
||||||
|
|
||||||
|
|
||||||
def get_aux_entity(session, get_table_info, entity, user):
|
|
||||||
table, get_column = get_table_info(entity)
|
|
||||||
return session \
|
|
||||||
.query(table) \
|
|
||||||
.filter(get_column(table) == get_column(entity)) \
|
|
||||||
.filter(table.user_id == user.user_id) \
|
|
||||||
.one_or_none()
|
|
|
@ -1,5 +1,11 @@
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
class BaseError(RuntimeError):
|
class BaseError(RuntimeError):
|
||||||
def __init__(self, message='Unknown error', extra_fields=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str='Unknown error',
|
||||||
|
extra_fields: Dict[str, str]=None) -> None:
|
||||||
super().__init__(message)
|
super().__init__(message)
|
||||||
self.extra_fields = extra_fields
|
self.extra_fields = extra_fields
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,10 @@ import os
|
||||||
import time
|
import time
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
|
from typing import Callable, Any, Type
|
||||||
|
|
||||||
import coloredlogs
|
import coloredlogs
|
||||||
|
import sqlalchemy as sa
|
||||||
import sqlalchemy.orm.exc
|
import sqlalchemy.orm.exc
|
||||||
from szurubooru import config, db, errors, rest
|
from szurubooru import config, db, errors, rest
|
||||||
from szurubooru.func import posts, file_uploads
|
from szurubooru.func import posts, file_uploads
|
||||||
|
@ -10,7 +13,10 @@ from szurubooru.func import posts, file_uploads
|
||||||
from szurubooru import api, middleware
|
from szurubooru import api, middleware
|
||||||
|
|
||||||
|
|
||||||
def _map_error(ex, target_class, title):
|
def _map_error(
|
||||||
|
ex: Exception,
|
||||||
|
target_class: Type[rest.errors.BaseHttpError],
|
||||||
|
title: str) -> rest.errors.BaseHttpError:
|
||||||
return target_class(
|
return target_class(
|
||||||
name=type(ex).__name__,
|
name=type(ex).__name__,
|
||||||
title=title,
|
title=title,
|
||||||
|
@ -18,38 +24,38 @@ def _map_error(ex, target_class, title):
|
||||||
extra_fields=getattr(ex, 'extra_fields', {}))
|
extra_fields=getattr(ex, 'extra_fields', {}))
|
||||||
|
|
||||||
|
|
||||||
def _on_auth_error(ex):
|
def _on_auth_error(ex: Exception) -> None:
|
||||||
raise _map_error(ex, rest.errors.HttpForbidden, 'Authentication error')
|
raise _map_error(ex, rest.errors.HttpForbidden, 'Authentication error')
|
||||||
|
|
||||||
|
|
||||||
def _on_validation_error(ex):
|
def _on_validation_error(ex: Exception) -> None:
|
||||||
raise _map_error(ex, rest.errors.HttpBadRequest, 'Validation error')
|
raise _map_error(ex, rest.errors.HttpBadRequest, 'Validation error')
|
||||||
|
|
||||||
|
|
||||||
def _on_search_error(ex):
|
def _on_search_error(ex: Exception) -> None:
|
||||||
raise _map_error(ex, rest.errors.HttpBadRequest, 'Search error')
|
raise _map_error(ex, rest.errors.HttpBadRequest, 'Search error')
|
||||||
|
|
||||||
|
|
||||||
def _on_integrity_error(ex):
|
def _on_integrity_error(ex: Exception) -> None:
|
||||||
raise _map_error(ex, rest.errors.HttpConflict, 'Integrity violation')
|
raise _map_error(ex, rest.errors.HttpConflict, 'Integrity violation')
|
||||||
|
|
||||||
|
|
||||||
def _on_not_found_error(ex):
|
def _on_not_found_error(ex: Exception) -> None:
|
||||||
raise _map_error(ex, rest.errors.HttpNotFound, 'Not found')
|
raise _map_error(ex, rest.errors.HttpNotFound, 'Not found')
|
||||||
|
|
||||||
|
|
||||||
def _on_processing_error(ex):
|
def _on_processing_error(ex: Exception) -> None:
|
||||||
raise _map_error(ex, rest.errors.HttpBadRequest, 'Processing error')
|
raise _map_error(ex, rest.errors.HttpBadRequest, 'Processing error')
|
||||||
|
|
||||||
|
|
||||||
def _on_third_party_error(ex):
|
def _on_third_party_error(ex: Exception) -> None:
|
||||||
raise _map_error(
|
raise _map_error(
|
||||||
ex,
|
ex,
|
||||||
rest.errors.HttpInternalServerError,
|
rest.errors.HttpInternalServerError,
|
||||||
'Server configuration error')
|
'Server configuration error')
|
||||||
|
|
||||||
|
|
||||||
def _on_stale_data_error(_ex):
|
def _on_stale_data_error(_ex: Exception) -> None:
|
||||||
raise rest.errors.HttpConflict(
|
raise rest.errors.HttpConflict(
|
||||||
name='IntegrityError',
|
name='IntegrityError',
|
||||||
title='Integrity violation',
|
title='Integrity violation',
|
||||||
|
@ -58,7 +64,7 @@ def _on_stale_data_error(_ex):
|
||||||
'Please try again.'))
|
'Please try again.'))
|
||||||
|
|
||||||
|
|
||||||
def validate_config():
|
def validate_config() -> None:
|
||||||
'''
|
'''
|
||||||
Check whether config doesn't contain errors that might prove
|
Check whether config doesn't contain errors that might prove
|
||||||
lethal at runtime.
|
lethal at runtime.
|
||||||
|
@ -86,7 +92,7 @@ def validate_config():
|
||||||
raise errors.ConfigError('Database is not configured')
|
raise errors.ConfigError('Database is not configured')
|
||||||
|
|
||||||
|
|
||||||
def purge_old_uploads():
|
def purge_old_uploads() -> None:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
file_uploads.purge_old_uploads()
|
file_uploads.purge_old_uploads()
|
||||||
|
@ -95,7 +101,7 @@ def purge_old_uploads():
|
||||||
time.sleep(60 * 5)
|
time.sleep(60 * 5)
|
||||||
|
|
||||||
|
|
||||||
def create_app():
|
def create_app() -> Callable[[Any, Any], Any]:
|
||||||
''' Create a WSGI compatible App object. '''
|
''' Create a WSGI compatible App object. '''
|
||||||
validate_config()
|
validate_config()
|
||||||
coloredlogs.install(fmt='[%(asctime)-15s] %(name)s %(message)s')
|
coloredlogs.install(fmt='[%(asctime)-15s] %(name)s %(message)s')
|
||||||
|
@ -122,7 +128,7 @@ def create_app():
|
||||||
rest.errors.handle(errors.NotFoundError, _on_not_found_error)
|
rest.errors.handle(errors.NotFoundError, _on_not_found_error)
|
||||||
rest.errors.handle(errors.ProcessingError, _on_processing_error)
|
rest.errors.handle(errors.ProcessingError, _on_processing_error)
|
||||||
rest.errors.handle(errors.ThirdPartyError, _on_third_party_error)
|
rest.errors.handle(errors.ThirdPartyError, _on_third_party_error)
|
||||||
rest.errors.handle(sqlalchemy.orm.exc.StaleDataError, _on_stale_data_error)
|
rest.errors.handle(sa.orm.exc.StaleDataError, _on_stale_data_error)
|
||||||
|
|
||||||
return rest.application
|
return rest.application
|
||||||
|
|
||||||
|
|
|
@ -1,22 +1,22 @@
|
||||||
import hashlib
|
import hashlib
|
||||||
import random
|
import random
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from szurubooru import config, db, errors
|
from szurubooru import config, model, errors
|
||||||
from szurubooru.func import util
|
from szurubooru.func import util
|
||||||
|
|
||||||
|
|
||||||
RANK_MAP = OrderedDict([
|
RANK_MAP = OrderedDict([
|
||||||
(db.User.RANK_ANONYMOUS, 'anonymous'),
|
(model.User.RANK_ANONYMOUS, 'anonymous'),
|
||||||
(db.User.RANK_RESTRICTED, 'restricted'),
|
(model.User.RANK_RESTRICTED, 'restricted'),
|
||||||
(db.User.RANK_REGULAR, 'regular'),
|
(model.User.RANK_REGULAR, 'regular'),
|
||||||
(db.User.RANK_POWER, 'power'),
|
(model.User.RANK_POWER, 'power'),
|
||||||
(db.User.RANK_MODERATOR, 'moderator'),
|
(model.User.RANK_MODERATOR, 'moderator'),
|
||||||
(db.User.RANK_ADMINISTRATOR, 'administrator'),
|
(model.User.RANK_ADMINISTRATOR, 'administrator'),
|
||||||
(db.User.RANK_NOBODY, 'nobody'),
|
(model.User.RANK_NOBODY, 'nobody'),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
||||||
def get_password_hash(salt, password):
|
def get_password_hash(salt: str, password: str) -> str:
|
||||||
''' Retrieve new-style password hash. '''
|
''' Retrieve new-style password hash. '''
|
||||||
digest = hashlib.sha256()
|
digest = hashlib.sha256()
|
||||||
digest.update(config.config['secret'].encode('utf8'))
|
digest.update(config.config['secret'].encode('utf8'))
|
||||||
|
@ -25,7 +25,7 @@ def get_password_hash(salt, password):
|
||||||
return digest.hexdigest()
|
return digest.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def get_legacy_password_hash(salt, password):
|
def get_legacy_password_hash(salt: str, password: str) -> str:
|
||||||
''' Retrieve old-style password hash. '''
|
''' Retrieve old-style password hash. '''
|
||||||
digest = hashlib.sha1()
|
digest = hashlib.sha1()
|
||||||
digest.update(b'1A2/$_4xVa')
|
digest.update(b'1A2/$_4xVa')
|
||||||
|
@ -34,7 +34,7 @@ def get_legacy_password_hash(salt, password):
|
||||||
return digest.hexdigest()
|
return digest.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def create_password():
|
def create_password() -> str:
|
||||||
alphabet = {
|
alphabet = {
|
||||||
'c': list('bcdfghijklmnpqrstvwxyz'),
|
'c': list('bcdfghijklmnpqrstvwxyz'),
|
||||||
'v': list('aeiou'),
|
'v': list('aeiou'),
|
||||||
|
@ -44,7 +44,7 @@ def create_password():
|
||||||
return ''.join(random.choice(alphabet[l]) for l in list(pattern))
|
return ''.join(random.choice(alphabet[l]) for l in list(pattern))
|
||||||
|
|
||||||
|
|
||||||
def is_valid_password(user, password):
|
def is_valid_password(user: model.User, password: str) -> bool:
|
||||||
assert user
|
assert user
|
||||||
salt, valid_hash = user.password_salt, user.password_hash
|
salt, valid_hash = user.password_salt, user.password_hash
|
||||||
possible_hashes = [
|
possible_hashes = [
|
||||||
|
@ -54,7 +54,7 @@ def is_valid_password(user, password):
|
||||||
return valid_hash in possible_hashes
|
return valid_hash in possible_hashes
|
||||||
|
|
||||||
|
|
||||||
def has_privilege(user, privilege_name):
|
def has_privilege(user: model.User, privilege_name: str) -> bool:
|
||||||
assert user
|
assert user
|
||||||
all_ranks = list(RANK_MAP.keys())
|
all_ranks = list(RANK_MAP.keys())
|
||||||
assert privilege_name in config.config['privileges']
|
assert privilege_name in config.config['privileges']
|
||||||
|
@ -65,13 +65,13 @@ def has_privilege(user, privilege_name):
|
||||||
return user.rank in good_ranks
|
return user.rank in good_ranks
|
||||||
|
|
||||||
|
|
||||||
def verify_privilege(user, privilege_name):
|
def verify_privilege(user: model.User, privilege_name: str) -> None:
|
||||||
assert user
|
assert user
|
||||||
if not has_privilege(user, privilege_name):
|
if not has_privilege(user, privilege_name):
|
||||||
raise errors.AuthError('Insufficient privileges to do this.')
|
raise errors.AuthError('Insufficient privileges to do this.')
|
||||||
|
|
||||||
|
|
||||||
def generate_authentication_token(user):
|
def generate_authentication_token(user: model.User) -> str:
|
||||||
''' Generate nonguessable challenge (e.g. links in password reminder). '''
|
''' Generate nonguessable challenge (e.g. links in password reminder). '''
|
||||||
assert user
|
assert user
|
||||||
digest = hashlib.md5()
|
digest = hashlib.md5()
|
||||||
|
|
|
@ -1,21 +1,21 @@
|
||||||
|
from typing import Any, List, Dict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
class LruCacheItem:
|
class LruCacheItem:
|
||||||
def __init__(self, key, value):
|
def __init__(self, key: object, value: Any) -> None:
|
||||||
self.key = key
|
self.key = key
|
||||||
self.value = value
|
self.value = value
|
||||||
self.timestamp = datetime.utcnow()
|
self.timestamp = datetime.utcnow()
|
||||||
|
|
||||||
|
|
||||||
class LruCache:
|
class LruCache:
|
||||||
def __init__(self, length, delta=None):
|
def __init__(self, length: int) -> None:
|
||||||
self.length = length
|
self.length = length
|
||||||
self.delta = delta
|
self.hash = {} # type: Dict[object, LruCacheItem]
|
||||||
self.hash = {}
|
self.item_list = [] # type: List[LruCacheItem]
|
||||||
self.item_list = []
|
|
||||||
|
|
||||||
def insert_item(self, item):
|
def insert_item(self, item: LruCacheItem) -> None:
|
||||||
if item.key in self.hash:
|
if item.key in self.hash:
|
||||||
item_index = next(
|
item_index = next(
|
||||||
i
|
i
|
||||||
|
@ -31,11 +31,11 @@ class LruCache:
|
||||||
self.hash[item.key] = item
|
self.hash[item.key] = item
|
||||||
self.item_list.insert(0, item)
|
self.item_list.insert(0, item)
|
||||||
|
|
||||||
def remove_all(self):
|
def remove_all(self) -> None:
|
||||||
self.hash = {}
|
self.hash = {}
|
||||||
self.item_list = []
|
self.item_list = []
|
||||||
|
|
||||||
def remove_item(self, item):
|
def remove_item(self, item: LruCacheItem) -> None:
|
||||||
del self.hash[item.key]
|
del self.hash[item.key]
|
||||||
del self.item_list[self.item_list.index(item)]
|
del self.item_list[self.item_list.index(item)]
|
||||||
|
|
||||||
|
@ -43,22 +43,22 @@ class LruCache:
|
||||||
_CACHE = LruCache(length=100)
|
_CACHE = LruCache(length=100)
|
||||||
|
|
||||||
|
|
||||||
def purge():
|
def purge() -> None:
|
||||||
_CACHE.remove_all()
|
_CACHE.remove_all()
|
||||||
|
|
||||||
|
|
||||||
def has(key):
|
def has(key: object) -> bool:
|
||||||
return key in _CACHE.hash
|
return key in _CACHE.hash
|
||||||
|
|
||||||
|
|
||||||
def get(key):
|
def get(key: object) -> Any:
|
||||||
return _CACHE.hash[key].value
|
return _CACHE.hash[key].value
|
||||||
|
|
||||||
|
|
||||||
def remove(key):
|
def remove(key: object) -> None:
|
||||||
if has(key):
|
if has(key):
|
||||||
del _CACHE.hash[key]
|
del _CACHE.hash[key]
|
||||||
|
|
||||||
|
|
||||||
def put(key, value):
|
def put(key: object, value: Any) -> None:
|
||||||
_CACHE.insert_item(LruCacheItem(key, value))
|
_CACHE.insert_item(LruCacheItem(key, value))
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import datetime
|
from datetime import datetime
|
||||||
from szurubooru import db, errors
|
from typing import Any, Optional, List, Dict, Callable
|
||||||
from szurubooru.func import users, scores, util
|
from szurubooru import db, model, errors, rest
|
||||||
|
from szurubooru.func import users, scores, util, serialization
|
||||||
|
|
||||||
|
|
||||||
class InvalidCommentIdError(errors.ValidationError):
|
class InvalidCommentIdError(errors.ValidationError):
|
||||||
|
@ -15,52 +16,87 @@ class EmptyCommentTextError(errors.ValidationError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def serialize_comment(comment, auth_user, options=None):
|
class CommentSerializer(serialization.BaseSerializer):
|
||||||
return util.serialize_entity(
|
def __init__(self, comment: model.Comment, auth_user: model.User) -> None:
|
||||||
comment,
|
self.comment = comment
|
||||||
{
|
self.auth_user = auth_user
|
||||||
'id': lambda: comment.comment_id,
|
|
||||||
'user':
|
def _serializers(self) -> Dict[str, Callable[[], Any]]:
|
||||||
lambda: users.serialize_micro_user(comment.user, auth_user),
|
return {
|
||||||
'postId': lambda: comment.post.post_id,
|
'id': self.serialize_id,
|
||||||
'version': lambda: comment.version,
|
'user': self.serialize_user,
|
||||||
'text': lambda: comment.text,
|
'postId': self.serialize_post_id,
|
||||||
'creationTime': lambda: comment.creation_time,
|
'version': self.serialize_version,
|
||||||
'lastEditTime': lambda: comment.last_edit_time,
|
'text': self.serialize_text,
|
||||||
'score': lambda: comment.score,
|
'creationTime': self.serialize_creation_time,
|
||||||
'ownScore': lambda: scores.get_score(comment, auth_user),
|
'lastEditTime': self.serialize_last_edit_time,
|
||||||
},
|
'score': self.serialize_score,
|
||||||
options)
|
'ownScore': self.serialize_own_score,
|
||||||
|
}
|
||||||
|
|
||||||
|
def serialize_id(self) -> Any:
|
||||||
|
return self.comment.comment_id
|
||||||
|
|
||||||
|
def serialize_user(self) -> Any:
|
||||||
|
return users.serialize_micro_user(self.comment.user, self.auth_user)
|
||||||
|
|
||||||
|
def serialize_post_id(self) -> Any:
|
||||||
|
return self.comment.post.post_id
|
||||||
|
|
||||||
|
def serialize_version(self) -> Any:
|
||||||
|
return self.comment.version
|
||||||
|
|
||||||
|
def serialize_text(self) -> Any:
|
||||||
|
return self.comment.text
|
||||||
|
|
||||||
|
def serialize_creation_time(self) -> Any:
|
||||||
|
return self.comment.creation_time
|
||||||
|
|
||||||
|
def serialize_last_edit_time(self) -> Any:
|
||||||
|
return self.comment.last_edit_time
|
||||||
|
|
||||||
|
def serialize_score(self) -> Any:
|
||||||
|
return self.comment.score
|
||||||
|
|
||||||
|
def serialize_own_score(self) -> Any:
|
||||||
|
return scores.get_score(self.comment, self.auth_user)
|
||||||
|
|
||||||
|
|
||||||
def try_get_comment_by_id(comment_id):
|
def serialize_comment(
|
||||||
try:
|
comment: model.Comment,
|
||||||
|
auth_user: model.User,
|
||||||
|
options: List[str]=[]) -> rest.Response:
|
||||||
|
if comment is None:
|
||||||
|
return None
|
||||||
|
return CommentSerializer(comment, auth_user).serialize(options)
|
||||||
|
|
||||||
|
|
||||||
|
def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]:
|
||||||
comment_id = int(comment_id)
|
comment_id = int(comment_id)
|
||||||
except ValueError:
|
|
||||||
raise InvalidCommentIdError('Invalid comment ID: %r.' % comment_id)
|
|
||||||
return db.session \
|
return db.session \
|
||||||
.query(db.Comment) \
|
.query(model.Comment) \
|
||||||
.filter(db.Comment.comment_id == comment_id) \
|
.filter(model.Comment.comment_id == comment_id) \
|
||||||
.one_or_none()
|
.one_or_none()
|
||||||
|
|
||||||
|
|
||||||
def get_comment_by_id(comment_id):
|
def get_comment_by_id(comment_id: int) -> model.Comment:
|
||||||
comment = try_get_comment_by_id(comment_id)
|
comment = try_get_comment_by_id(comment_id)
|
||||||
if comment:
|
if comment:
|
||||||
return comment
|
return comment
|
||||||
raise CommentNotFoundError('Comment %r not found.' % comment_id)
|
raise CommentNotFoundError('Comment %r not found.' % comment_id)
|
||||||
|
|
||||||
|
|
||||||
def create_comment(user, post, text):
|
def create_comment(
|
||||||
comment = db.Comment()
|
user: model.User, post: model.Post, text: str) -> model.Comment:
|
||||||
|
comment = model.Comment()
|
||||||
comment.user = user
|
comment.user = user
|
||||||
comment.post = post
|
comment.post = post
|
||||||
update_comment_text(comment, text)
|
update_comment_text(comment, text)
|
||||||
comment.creation_time = datetime.datetime.utcnow()
|
comment.creation_time = datetime.utcnow()
|
||||||
return comment
|
return comment
|
||||||
|
|
||||||
|
|
||||||
def update_comment_text(comment, text):
|
def update_comment_text(comment: model.Comment, text: str) -> None:
|
||||||
assert comment
|
assert comment
|
||||||
if not text:
|
if not text:
|
||||||
raise EmptyCommentTextError('Comment text cannot be empty.')
|
raise EmptyCommentTextError('Comment text cannot be empty.')
|
||||||
|
|
|
@ -1,21 +1,26 @@
|
||||||
def get_list_diff(old, new):
|
from typing import List, Dict, Any
|
||||||
value = {'type': 'list change', 'added': [], 'removed': []}
|
|
||||||
|
|
||||||
|
def get_list_diff(old: List[Any], new: List[Any]) -> Any:
|
||||||
equal = True
|
equal = True
|
||||||
|
removed = [] # type: List[Any]
|
||||||
|
added = [] # type: List[Any]
|
||||||
|
|
||||||
for item in old:
|
for item in old:
|
||||||
if item not in new:
|
if item not in new:
|
||||||
equal = False
|
equal = False
|
||||||
value['removed'].append(item)
|
removed.append(item)
|
||||||
|
|
||||||
for item in new:
|
for item in new:
|
||||||
if item not in old:
|
if item not in old:
|
||||||
equal = False
|
equal = False
|
||||||
value['added'].append(item)
|
added.append(item)
|
||||||
|
|
||||||
return None if equal else value
|
return None if equal else {
|
||||||
|
'type': 'list change', 'added': added, 'removed': removed}
|
||||||
|
|
||||||
|
|
||||||
def get_dict_diff(old, new):
|
def get_dict_diff(old: Dict[str, Any], new: Dict[str, Any]) -> Any:
|
||||||
value = {}
|
value = {}
|
||||||
equal = True
|
equal = True
|
||||||
|
|
||||||
|
|
|
@ -1,32 +1,34 @@
|
||||||
import datetime
|
from typing import Any, Optional, Callable, Tuple
|
||||||
from szurubooru import db, errors
|
from datetime import datetime
|
||||||
|
from szurubooru import db, model, errors
|
||||||
|
|
||||||
|
|
||||||
class InvalidFavoriteTargetError(errors.ValidationError):
|
class InvalidFavoriteTargetError(errors.ValidationError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _get_table_info(entity):
|
def _get_table_info(
|
||||||
|
entity: model.Base) -> Tuple[model.Base, Callable[[model.Base], Any]]:
|
||||||
assert entity
|
assert entity
|
||||||
resource_type, _, _ = db.util.get_resource_info(entity)
|
resource_type, _, _ = model.util.get_resource_info(entity)
|
||||||
if resource_type == 'post':
|
if resource_type == 'post':
|
||||||
return db.PostFavorite, lambda table: table.post_id
|
return model.PostFavorite, lambda table: table.post_id
|
||||||
raise InvalidFavoriteTargetError()
|
raise InvalidFavoriteTargetError()
|
||||||
|
|
||||||
|
|
||||||
def _get_fav_entity(entity, user):
|
def _get_fav_entity(entity: model.Base, user: model.User) -> model.Base:
|
||||||
assert entity
|
assert entity
|
||||||
assert user
|
assert user
|
||||||
return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
|
return model.util.get_aux_entity(db.session, _get_table_info, entity, user)
|
||||||
|
|
||||||
|
|
||||||
def has_favorited(entity, user):
|
def has_favorited(entity: model.Base, user: model.User) -> bool:
|
||||||
assert entity
|
assert entity
|
||||||
assert user
|
assert user
|
||||||
return _get_fav_entity(entity, user) is not None
|
return _get_fav_entity(entity, user) is not None
|
||||||
|
|
||||||
|
|
||||||
def unset_favorite(entity, user):
|
def unset_favorite(entity: model.Base, user: Optional[model.User]) -> None:
|
||||||
assert entity
|
assert entity
|
||||||
assert user
|
assert user
|
||||||
fav_entity = _get_fav_entity(entity, user)
|
fav_entity = _get_fav_entity(entity, user)
|
||||||
|
@ -34,7 +36,7 @@ def unset_favorite(entity, user):
|
||||||
db.session.delete(fav_entity)
|
db.session.delete(fav_entity)
|
||||||
|
|
||||||
|
|
||||||
def set_favorite(entity, user):
|
def set_favorite(entity: model.Base, user: Optional[model.User]) -> None:
|
||||||
from szurubooru.func import scores
|
from szurubooru.func import scores
|
||||||
assert entity
|
assert entity
|
||||||
assert user
|
assert user
|
||||||
|
@ -48,5 +50,5 @@ def set_favorite(entity, user):
|
||||||
fav_entity = table()
|
fav_entity = table()
|
||||||
setattr(fav_entity, get_column(table).name, get_column(entity))
|
setattr(fav_entity, get_column(table).name, get_column(entity))
|
||||||
fav_entity.user = user
|
fav_entity.user = user
|
||||||
fav_entity.time = datetime.datetime.utcnow()
|
fav_entity.time = datetime.utcnow()
|
||||||
db.session.add(fav_entity)
|
db.session.add(fav_entity)
|
||||||
|
|
|
@ -1,27 +1,28 @@
|
||||||
import datetime
|
from typing import Optional
|
||||||
|
from datetime import datetime, timedelta
|
||||||
from szurubooru.func import files, util
|
from szurubooru.func import files, util
|
||||||
|
|
||||||
|
|
||||||
MAX_MINUTES = 60
|
MAX_MINUTES = 60
|
||||||
|
|
||||||
|
|
||||||
def _get_path(checksum):
|
def _get_path(checksum: str) -> str:
|
||||||
return 'temporary-uploads/%s.dat' % checksum
|
return 'temporary-uploads/%s.dat' % checksum
|
||||||
|
|
||||||
|
|
||||||
def purge_old_uploads():
|
def purge_old_uploads() -> None:
|
||||||
now = datetime.datetime.now()
|
now = datetime.now()
|
||||||
for file in files.scan('temporary-uploads'):
|
for file in files.scan('temporary-uploads'):
|
||||||
file_time = datetime.datetime.fromtimestamp(file.stat().st_ctime)
|
file_time = datetime.fromtimestamp(file.stat().st_ctime)
|
||||||
if now - file_time > datetime.timedelta(minutes=MAX_MINUTES):
|
if now - file_time > timedelta(minutes=MAX_MINUTES):
|
||||||
files.delete('temporary-uploads/%s' % file.name)
|
files.delete('temporary-uploads/%s' % file.name)
|
||||||
|
|
||||||
|
|
||||||
def get(checksum):
|
def get(checksum: str) -> Optional[bytes]:
|
||||||
return files.get('temporary-uploads/%s.dat' % checksum)
|
return files.get('temporary-uploads/%s.dat' % checksum)
|
||||||
|
|
||||||
|
|
||||||
def save(content):
|
def save(content: bytes) -> str:
|
||||||
checksum = util.get_sha1(content)
|
checksum = util.get_sha1(content)
|
||||||
path = _get_path(checksum)
|
path = _get_path(checksum)
|
||||||
if not files.has(path):
|
if not files.has(path):
|
||||||
|
|
|
@ -1,32 +1,33 @@
|
||||||
|
from typing import Any, Optional, List
|
||||||
import os
|
import os
|
||||||
from szurubooru import config
|
from szurubooru import config
|
||||||
|
|
||||||
|
|
||||||
def _get_full_path(path):
|
def _get_full_path(path: str) -> str:
|
||||||
return os.path.join(config.config['data_dir'], path)
|
return os.path.join(config.config['data_dir'], path)
|
||||||
|
|
||||||
|
|
||||||
def delete(path):
|
def delete(path: str) -> None:
|
||||||
full_path = _get_full_path(path)
|
full_path = _get_full_path(path)
|
||||||
if os.path.exists(full_path):
|
if os.path.exists(full_path):
|
||||||
os.unlink(full_path)
|
os.unlink(full_path)
|
||||||
|
|
||||||
|
|
||||||
def has(path):
|
def has(path: str) -> bool:
|
||||||
return os.path.exists(_get_full_path(path))
|
return os.path.exists(_get_full_path(path))
|
||||||
|
|
||||||
|
|
||||||
def scan(path):
|
def scan(path: str) -> List[os.DirEntry]:
|
||||||
if has(path):
|
if has(path):
|
||||||
return os.scandir(_get_full_path(path))
|
return list(os.scandir(_get_full_path(path)))
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def move(source_path, target_path):
|
def move(source_path: str, target_path: str) -> None:
|
||||||
return os.rename(_get_full_path(source_path), _get_full_path(target_path))
|
os.rename(_get_full_path(source_path), _get_full_path(target_path))
|
||||||
|
|
||||||
|
|
||||||
def get(path):
|
def get(path: str) -> Optional[bytes]:
|
||||||
full_path = _get_full_path(path)
|
full_path = _get_full_path(path)
|
||||||
if not os.path.exists(full_path):
|
if not os.path.exists(full_path):
|
||||||
return None
|
return None
|
||||||
|
@ -34,7 +35,7 @@ def get(path):
|
||||||
return handle.read()
|
return handle.read()
|
||||||
|
|
||||||
|
|
||||||
def save(path, content):
|
def save(path: str, content: bytes) -> None:
|
||||||
full_path = _get_full_path(path)
|
full_path = _get_full_path(path)
|
||||||
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
os.makedirs(os.path.dirname(full_path), exist_ok=True)
|
||||||
with open(full_path, 'wb') as handle:
|
with open(full_path, 'wb') as handle:
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import Any, Optional, Tuple, Set, List, Callable
|
||||||
import elasticsearch
|
import elasticsearch
|
||||||
import elasticsearch_dsl
|
import elasticsearch_dsl
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
@ -10,13 +11,8 @@ from szurubooru import config, errors
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
es = elasticsearch.Elasticsearch([{
|
|
||||||
'host': config.config['elasticsearch']['host'],
|
|
||||||
'port': config.config['elasticsearch']['port'],
|
|
||||||
}])
|
|
||||||
|
|
||||||
|
# Math based on paper from H. Chi Wong, Marshall Bern and David Goldberg
|
||||||
# Math based on paper from H. Chi Wong, Marshall Bern and David Goldber
|
|
||||||
# Math code taken from https://github.com/ascribe/image-match
|
# Math code taken from https://github.com/ascribe/image-match
|
||||||
# (which is licensed under Apache 2 license)
|
# (which is licensed under Apache 2 license)
|
||||||
|
|
||||||
|
@ -32,14 +28,27 @@ MAX_WORDS = 63
|
||||||
ES_DOC_TYPE = 'image'
|
ES_DOC_TYPE = 'image'
|
||||||
ES_MAX_RESULTS = 100
|
ES_MAX_RESULTS = 100
|
||||||
|
|
||||||
|
Window = Tuple[Tuple[float, float], Tuple[float, float]]
|
||||||
|
NpMatrix = Any
|
||||||
|
|
||||||
def _preprocess_image(image_or_path):
|
|
||||||
img = Image.open(BytesIO(image_or_path))
|
def _get_session() -> elasticsearch.Elasticsearch:
|
||||||
|
return elasticsearch.Elasticsearch([{
|
||||||
|
'host': config.config['elasticsearch']['host'],
|
||||||
|
'port': config.config['elasticsearch']['port'],
|
||||||
|
}])
|
||||||
|
|
||||||
|
|
||||||
|
def _preprocess_image(content: bytes) -> NpMatrix:
|
||||||
|
img = Image.open(BytesIO(content))
|
||||||
img = img.convert('RGB')
|
img = img.convert('RGB')
|
||||||
return rgb2gray(np.asarray(img, dtype=np.uint8))
|
return rgb2gray(np.asarray(img, dtype=np.uint8))
|
||||||
|
|
||||||
|
|
||||||
def _crop_image(image, lower_percentile, upper_percentile):
|
def _crop_image(
|
||||||
|
image: NpMatrix,
|
||||||
|
lower_percentile: float,
|
||||||
|
upper_percentile: float) -> Window:
|
||||||
rw = np.cumsum(np.sum(np.abs(np.diff(image, axis=1)), axis=1))
|
rw = np.cumsum(np.sum(np.abs(np.diff(image, axis=1)), axis=1))
|
||||||
cw = np.cumsum(np.sum(np.abs(np.diff(image, axis=0)), axis=0))
|
cw = np.cumsum(np.sum(np.abs(np.diff(image, axis=0)), axis=0))
|
||||||
upper_column_limit = np.searchsorted(
|
upper_column_limit = np.searchsorted(
|
||||||
|
@ -56,16 +65,19 @@ def _crop_image(image, lower_percentile, upper_percentile):
|
||||||
if lower_column_limit > upper_column_limit:
|
if lower_column_limit > upper_column_limit:
|
||||||
lower_column_limit = int(lower_percentile / 100. * image.shape[1])
|
lower_column_limit = int(lower_percentile / 100. * image.shape[1])
|
||||||
upper_column_limit = int(upper_percentile / 100. * image.shape[1])
|
upper_column_limit = int(upper_percentile / 100. * image.shape[1])
|
||||||
return [
|
return (
|
||||||
(lower_row_limit, upper_row_limit),
|
(lower_row_limit, upper_row_limit),
|
||||||
(lower_column_limit, upper_column_limit)]
|
(lower_column_limit, upper_column_limit))
|
||||||
|
|
||||||
|
|
||||||
def _normalize_and_threshold(diff_array, identical_tolerance, n_levels):
|
def _normalize_and_threshold(
|
||||||
|
diff_array: NpMatrix,
|
||||||
|
identical_tolerance: float,
|
||||||
|
n_levels: int) -> None:
|
||||||
mask = np.abs(diff_array) < identical_tolerance
|
mask = np.abs(diff_array) < identical_tolerance
|
||||||
diff_array[mask] = 0.
|
diff_array[mask] = 0.
|
||||||
if np.all(mask):
|
if np.all(mask):
|
||||||
return None
|
return
|
||||||
positive_cutoffs = np.percentile(
|
positive_cutoffs = np.percentile(
|
||||||
diff_array[diff_array > 0.], np.linspace(0, 100, n_levels + 1))
|
diff_array[diff_array > 0.], np.linspace(0, 100, n_levels + 1))
|
||||||
negative_cutoffs = np.percentile(
|
negative_cutoffs = np.percentile(
|
||||||
|
@ -82,18 +94,24 @@ def _normalize_and_threshold(diff_array, identical_tolerance, n_levels):
|
||||||
diff_array[
|
diff_array[
|
||||||
(diff_array <= interval[0]) & (diff_array >= interval[1])] = \
|
(diff_array <= interval[0]) & (diff_array >= interval[1])] = \
|
||||||
-(level + 1)
|
-(level + 1)
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _compute_grid_points(image, n, window=None):
|
def _compute_grid_points(
|
||||||
|
image: NpMatrix,
|
||||||
|
n: float,
|
||||||
|
window: Window=None) -> Tuple[NpMatrix, NpMatrix]:
|
||||||
if window is None:
|
if window is None:
|
||||||
window = [(0, image.shape[0]), (0, image.shape[1])]
|
window = ((0, image.shape[0]), (0, image.shape[1]))
|
||||||
x_coords = np.linspace(window[0][0], window[0][1], n + 2, dtype=int)[1:-1]
|
x_coords = np.linspace(window[0][0], window[0][1], n + 2, dtype=int)[1:-1]
|
||||||
y_coords = np.linspace(window[1][0], window[1][1], n + 2, dtype=int)[1:-1]
|
y_coords = np.linspace(window[1][0], window[1][1], n + 2, dtype=int)[1:-1]
|
||||||
return x_coords, y_coords
|
return x_coords, y_coords
|
||||||
|
|
||||||
|
|
||||||
def _compute_mean_level(image, x_coords, y_coords, p):
|
def _compute_mean_level(
|
||||||
|
image: NpMatrix,
|
||||||
|
x_coords: NpMatrix,
|
||||||
|
y_coords: NpMatrix,
|
||||||
|
p: Optional[float]) -> NpMatrix:
|
||||||
if p is None:
|
if p is None:
|
||||||
p = max([2.0, int(0.5 + min(image.shape) / 20.)])
|
p = max([2.0, int(0.5 + min(image.shape) / 20.)])
|
||||||
avg_grey = np.zeros((x_coords.shape[0], y_coords.shape[0]))
|
avg_grey = np.zeros((x_coords.shape[0], y_coords.shape[0]))
|
||||||
|
@ -108,7 +126,7 @@ def _compute_mean_level(image, x_coords, y_coords, p):
|
||||||
return avg_grey
|
return avg_grey
|
||||||
|
|
||||||
|
|
||||||
def _compute_differentials(grey_level_matrix):
|
def _compute_differentials(grey_level_matrix: NpMatrix) -> NpMatrix:
|
||||||
flipped = np.fliplr(grey_level_matrix)
|
flipped = np.fliplr(grey_level_matrix)
|
||||||
right_neighbors = -np.concatenate(
|
right_neighbors = -np.concatenate(
|
||||||
(
|
(
|
||||||
|
@ -152,8 +170,8 @@ def _compute_differentials(grey_level_matrix):
|
||||||
lower_right_neighbors]))
|
lower_right_neighbors]))
|
||||||
|
|
||||||
|
|
||||||
def _generate_signature(path_or_image):
|
def _generate_signature(content: bytes) -> NpMatrix:
|
||||||
im_array = _preprocess_image(path_or_image)
|
im_array = _preprocess_image(content)
|
||||||
image_limits = _crop_image(
|
image_limits = _crop_image(
|
||||||
im_array,
|
im_array,
|
||||||
lower_percentile=LOWER_PERCENTILE,
|
lower_percentile=LOWER_PERCENTILE,
|
||||||
|
@ -169,7 +187,7 @@ def _generate_signature(path_or_image):
|
||||||
return np.ravel(diff_matrix).astype('int8')
|
return np.ravel(diff_matrix).astype('int8')
|
||||||
|
|
||||||
|
|
||||||
def _get_words(array, k, n):
|
def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix:
|
||||||
word_positions = np.linspace(
|
word_positions = np.linspace(
|
||||||
0, array.shape[0], n, endpoint=False).astype('int')
|
0, array.shape[0], n, endpoint=False).astype('int')
|
||||||
assert k <= array.shape[0]
|
assert k <= array.shape[0]
|
||||||
|
@ -187,21 +205,23 @@ def _get_words(array, k, n):
|
||||||
return words
|
return words
|
||||||
|
|
||||||
|
|
||||||
def _words_to_int(word_array):
|
def _words_to_int(word_array: NpMatrix) -> NpMatrix:
|
||||||
width = word_array.shape[1]
|
width = word_array.shape[1]
|
||||||
coding_vector = 3**np.arange(width)
|
coding_vector = 3**np.arange(width)
|
||||||
return np.dot(word_array + 1, coding_vector)
|
return np.dot(word_array + 1, coding_vector)
|
||||||
|
|
||||||
|
|
||||||
def _max_contrast(array):
|
def _max_contrast(array: NpMatrix) -> None:
|
||||||
array[array > 0] = 1
|
array[array > 0] = 1
|
||||||
array[array < 0] = -1
|
array[array < 0] = -1
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _normalized_distance(_target_array, _vec, nan_value=1.0):
|
def _normalized_distance(
|
||||||
target_array = _target_array.astype(int)
|
target_array: NpMatrix,
|
||||||
vec = _vec.astype(int)
|
vec: NpMatrix,
|
||||||
|
nan_value: float=1.0) -> List[float]:
|
||||||
|
target_array = target_array.astype(int)
|
||||||
|
vec = vec.astype(int)
|
||||||
topvec = np.linalg.norm(vec - target_array, axis=1)
|
topvec = np.linalg.norm(vec - target_array, axis=1)
|
||||||
norm1 = np.linalg.norm(vec, axis=0)
|
norm1 = np.linalg.norm(vec, axis=0)
|
||||||
norm2 = np.linalg.norm(target_array, axis=1)
|
norm2 = np.linalg.norm(target_array, axis=1)
|
||||||
|
@ -210,9 +230,9 @@ def _normalized_distance(_target_array, _vec, nan_value=1.0):
|
||||||
return finvec
|
return finvec
|
||||||
|
|
||||||
|
|
||||||
def _safety_blanket(default_param_factory):
|
def _safety_blanket(default_param_factory: Callable[[], Any]) -> Callable:
|
||||||
def wrapper_outer(target_function):
|
def wrapper_outer(target_function: Callable) -> Callable:
|
||||||
def wrapper_inner(*args, **kwargs):
|
def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
|
||||||
try:
|
try:
|
||||||
return target_function(*args, **kwargs)
|
return target_function(*args, **kwargs)
|
||||||
except elasticsearch.exceptions.NotFoundError:
|
except elasticsearch.exceptions.NotFoundError:
|
||||||
|
@ -226,20 +246,20 @@ def _safety_blanket(default_param_factory):
|
||||||
except IOError:
|
except IOError:
|
||||||
raise errors.ProcessingError('Not an image.')
|
raise errors.ProcessingError('Not an image.')
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
raise errors.ThirdPartyError('Unknown error (%s).', ex)
|
raise errors.ThirdPartyError('Unknown error (%s).' % ex)
|
||||||
return wrapper_inner
|
return wrapper_inner
|
||||||
return wrapper_outer
|
return wrapper_outer
|
||||||
|
|
||||||
|
|
||||||
class Lookalike:
|
class Lookalike:
|
||||||
def __init__(self, score, distance, path):
|
def __init__(self, score: int, distance: float, path: Any) -> None:
|
||||||
self.score = score
|
self.score = score
|
||||||
self.distance = distance
|
self.distance = distance
|
||||||
self.path = path
|
self.path = path
|
||||||
|
|
||||||
|
|
||||||
@_safety_blanket(lambda: None)
|
@_safety_blanket(lambda: None)
|
||||||
def add_image(path, image_content):
|
def add_image(path: str, image_content: bytes) -> None:
|
||||||
assert path
|
assert path
|
||||||
assert image_content
|
assert image_content
|
||||||
signature = _generate_signature(image_content)
|
signature = _generate_signature(image_content)
|
||||||
|
@ -253,7 +273,7 @@ def add_image(path, image_content):
|
||||||
for i in range(MAX_WORDS):
|
for i in range(MAX_WORDS):
|
||||||
record['simple_word_' + str(i)] = words[i].tolist()
|
record['simple_word_' + str(i)] = words[i].tolist()
|
||||||
|
|
||||||
es.index(
|
_get_session().index(
|
||||||
index=config.config['elasticsearch']['index'],
|
index=config.config['elasticsearch']['index'],
|
||||||
doc_type=ES_DOC_TYPE,
|
doc_type=ES_DOC_TYPE,
|
||||||
body=record,
|
body=record,
|
||||||
|
@ -261,20 +281,20 @@ def add_image(path, image_content):
|
||||||
|
|
||||||
|
|
||||||
@_safety_blanket(lambda: None)
|
@_safety_blanket(lambda: None)
|
||||||
def delete_image(path):
|
def delete_image(path: str) -> None:
|
||||||
assert path
|
assert path
|
||||||
es.delete_by_query(
|
_get_session().delete_by_query(
|
||||||
index=config.config['elasticsearch']['index'],
|
index=config.config['elasticsearch']['index'],
|
||||||
doc_type=ES_DOC_TYPE,
|
doc_type=ES_DOC_TYPE,
|
||||||
body={'query': {'term': {'path': path}}})
|
body={'query': {'term': {'path': path}}})
|
||||||
|
|
||||||
|
|
||||||
@_safety_blanket(lambda: [])
|
@_safety_blanket(lambda: [])
|
||||||
def search_by_image(image_content):
|
def search_by_image(image_content: bytes) -> List[Lookalike]:
|
||||||
signature = _generate_signature(image_content)
|
signature = _generate_signature(image_content)
|
||||||
words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS)
|
words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS)
|
||||||
|
|
||||||
res = es.search(
|
res = _get_session().search(
|
||||||
index=config.config['elasticsearch']['index'],
|
index=config.config['elasticsearch']['index'],
|
||||||
doc_type=ES_DOC_TYPE,
|
doc_type=ES_DOC_TYPE,
|
||||||
body={
|
body={
|
||||||
|
@ -299,7 +319,7 @@ def search_by_image(image_content):
|
||||||
sigs = np.array([x['_source']['signature'] for x in res])
|
sigs = np.array([x['_source']['signature'] for x in res])
|
||||||
dists = _normalized_distance(sigs, np.array(signature))
|
dists = _normalized_distance(sigs, np.array(signature))
|
||||||
|
|
||||||
ids = set()
|
ids = set() # type: Set[int]
|
||||||
ret = []
|
ret = []
|
||||||
for item, dist in zip(res, dists):
|
for item, dist in zip(res, dists):
|
||||||
id = item['_id']
|
id = item['_id']
|
||||||
|
@ -314,8 +334,8 @@ def search_by_image(image_content):
|
||||||
|
|
||||||
|
|
||||||
@_safety_blanket(lambda: None)
|
@_safety_blanket(lambda: None)
|
||||||
def purge():
|
def purge() -> None:
|
||||||
es.delete_by_query(
|
_get_session().delete_by_query(
|
||||||
index=config.config['elasticsearch']['index'],
|
index=config.config['elasticsearch']['index'],
|
||||||
doc_type=ES_DOC_TYPE,
|
doc_type=ES_DOC_TYPE,
|
||||||
body={'query': {'match_all': {}}},
|
body={'query': {'match_all': {}}},
|
||||||
|
@ -323,10 +343,10 @@ def purge():
|
||||||
|
|
||||||
|
|
||||||
@_safety_blanket(lambda: set())
|
@_safety_blanket(lambda: set())
|
||||||
def get_all_paths():
|
def get_all_paths() -> Set[str]:
|
||||||
search = (
|
search = (
|
||||||
elasticsearch_dsl.Search(
|
elasticsearch_dsl.Search(
|
||||||
using=es,
|
using=_get_session(),
|
||||||
index=config.config['elasticsearch']['index'],
|
index=config.config['elasticsearch']['index'],
|
||||||
doc_type=ES_DOC_TYPE)
|
doc_type=ES_DOC_TYPE)
|
||||||
.source(['path']))
|
.source(['path']))
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from typing import List
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import shlex
|
import shlex
|
||||||
|
@ -15,23 +16,23 @@ _SCALE_FIT_FMT = \
|
||||||
|
|
||||||
|
|
||||||
class Image:
|
class Image:
|
||||||
def __init__(self, content):
|
def __init__(self, content: bytes) -> None:
|
||||||
self.content = content
|
self.content = content
|
||||||
self._reload_info()
|
self._reload_info()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def width(self):
|
def width(self) -> int:
|
||||||
return self.info['streams'][0]['width']
|
return self.info['streams'][0]['width']
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def height(self):
|
def height(self) -> int:
|
||||||
return self.info['streams'][0]['height']
|
return self.info['streams'][0]['height']
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def frames(self):
|
def frames(self) -> int:
|
||||||
return self.info['streams'][0]['nb_read_frames']
|
return self.info['streams'][0]['nb_read_frames']
|
||||||
|
|
||||||
def resize_fill(self, width, height):
|
def resize_fill(self, width: int, height: int) -> None:
|
||||||
cli = [
|
cli = [
|
||||||
'-i', '{path}',
|
'-i', '{path}',
|
||||||
'-f', 'image2',
|
'-f', 'image2',
|
||||||
|
@ -53,7 +54,7 @@ class Image:
|
||||||
assert self.content
|
assert self.content
|
||||||
self._reload_info()
|
self._reload_info()
|
||||||
|
|
||||||
def to_png(self):
|
def to_png(self) -> bytes:
|
||||||
return self._execute([
|
return self._execute([
|
||||||
'-i', '{path}',
|
'-i', '{path}',
|
||||||
'-f', 'image2',
|
'-f', 'image2',
|
||||||
|
@ -63,7 +64,7 @@ class Image:
|
||||||
'-',
|
'-',
|
||||||
])
|
])
|
||||||
|
|
||||||
def to_jpeg(self):
|
def to_jpeg(self) -> bytes:
|
||||||
return self._execute([
|
return self._execute([
|
||||||
'-f', 'lavfi',
|
'-f', 'lavfi',
|
||||||
'-i', 'color=white:s=%dx%d' % (self.width, self.height),
|
'-i', 'color=white:s=%dx%d' % (self.width, self.height),
|
||||||
|
@ -76,7 +77,7 @@ class Image:
|
||||||
'-',
|
'-',
|
||||||
])
|
])
|
||||||
|
|
||||||
def _execute(self, cli, program='ffmpeg'):
|
def _execute(self, cli: List[str], program: str='ffmpeg') -> bytes:
|
||||||
extension = mime.get_extension(mime.get_mime_type(self.content))
|
extension = mime.get_extension(mime.get_mime_type(self.content))
|
||||||
assert extension
|
assert extension
|
||||||
with util.create_temp_file(suffix='.' + extension) as handle:
|
with util.create_temp_file(suffix='.' + extension) as handle:
|
||||||
|
@ -99,7 +100,7 @@ class Image:
|
||||||
'Error while processing image.\n' + err.decode('utf-8'))
|
'Error while processing image.\n' + err.decode('utf-8'))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def _reload_info(self):
|
def _reload_info(self) -> None:
|
||||||
self.info = json.loads(self._execute([
|
self.info = json.loads(self._execute([
|
||||||
'-i', '{path}',
|
'-i', '{path}',
|
||||||
'-of', 'json',
|
'-of', 'json',
|
||||||
|
|
|
@ -3,7 +3,7 @@ import email.mime.text
|
||||||
from szurubooru import config
|
from szurubooru import config
|
||||||
|
|
||||||
|
|
||||||
def send_mail(sender, recipient, subject, body):
|
def send_mail(sender: str, recipient: str, subject: str, body: str) -> None:
|
||||||
msg = email.mime.text.MIMEText(body)
|
msg = email.mime.text.MIMEText(body)
|
||||||
msg['Subject'] = subject
|
msg['Subject'] = subject
|
||||||
msg['From'] = sender
|
msg['From'] = sender
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import re
|
import re
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
def get_mime_type(content):
|
def get_mime_type(content: bytes) -> str:
|
||||||
if not content:
|
if not content:
|
||||||
return 'application/octet-stream'
|
return 'application/octet-stream'
|
||||||
|
|
||||||
|
@ -26,7 +27,7 @@ def get_mime_type(content):
|
||||||
return 'application/octet-stream'
|
return 'application/octet-stream'
|
||||||
|
|
||||||
|
|
||||||
def get_extension(mime_type):
|
def get_extension(mime_type: str) -> Optional[str]:
|
||||||
extension_map = {
|
extension_map = {
|
||||||
'application/x-shockwave-flash': 'swf',
|
'application/x-shockwave-flash': 'swf',
|
||||||
'image/gif': 'gif',
|
'image/gif': 'gif',
|
||||||
|
@ -39,19 +40,19 @@ def get_extension(mime_type):
|
||||||
return extension_map.get((mime_type or '').strip().lower(), None)
|
return extension_map.get((mime_type or '').strip().lower(), None)
|
||||||
|
|
||||||
|
|
||||||
def is_flash(mime_type):
|
def is_flash(mime_type: str) -> bool:
|
||||||
return mime_type.lower() == 'application/x-shockwave-flash'
|
return mime_type.lower() == 'application/x-shockwave-flash'
|
||||||
|
|
||||||
|
|
||||||
def is_video(mime_type):
|
def is_video(mime_type: str) -> bool:
|
||||||
return mime_type.lower() in ('application/ogg', 'video/mp4', 'video/webm')
|
return mime_type.lower() in ('application/ogg', 'video/mp4', 'video/webm')
|
||||||
|
|
||||||
|
|
||||||
def is_image(mime_type):
|
def is_image(mime_type: str) -> bool:
|
||||||
return mime_type.lower() in ('image/jpeg', 'image/png', 'image/gif')
|
return mime_type.lower() in ('image/jpeg', 'image/png', 'image/gif')
|
||||||
|
|
||||||
|
|
||||||
def is_animated_gif(content):
|
def is_animated_gif(content: bytes) -> bool:
|
||||||
pattern = b'\x21\xF9\x04[\x00-\xFF]{4}\x00[\x2C\x21]'
|
pattern = b'\x21\xF9\x04[\x00-\xFF]{4}\x00[\x2C\x21]'
|
||||||
return get_mime_type(content) == 'image/gif' \
|
return get_mime_type(content) == 'image/gif' \
|
||||||
and len(re.findall(pattern, content)) > 1
|
and len(re.findall(pattern, content)) > 1
|
||||||
|
|
|
@ -2,7 +2,7 @@ import urllib.request
|
||||||
from szurubooru import errors
|
from szurubooru import errors
|
||||||
|
|
||||||
|
|
||||||
def download(url):
|
def download(url: str) -> bytes:
|
||||||
assert url
|
assert url
|
||||||
request = urllib.request.Request(url)
|
request = urllib.request.Request(url)
|
||||||
request.add_header('Referer', url)
|
request.add_header('Referer', url)
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
import datetime
|
from typing import Any, Optional, Tuple, List, Dict, Callable
|
||||||
import sqlalchemy
|
from datetime import datetime
|
||||||
from szurubooru import config, db, errors
|
import sqlalchemy as sa
|
||||||
|
from szurubooru import config, db, model, errors, rest
|
||||||
from szurubooru.func import (
|
from szurubooru.func import (
|
||||||
users, scores, comments, tags, util, mime, images, files, image_hash)
|
users, scores, comments, tags, util,
|
||||||
|
mime, images, files, image_hash, serialization)
|
||||||
|
|
||||||
|
|
||||||
EMPTY_PIXEL = \
|
EMPTY_PIXEL = \
|
||||||
|
@ -20,7 +22,7 @@ class PostAlreadyFeaturedError(errors.ValidationError):
|
||||||
|
|
||||||
|
|
||||||
class PostAlreadyUploadedError(errors.ValidationError):
|
class PostAlreadyUploadedError(errors.ValidationError):
|
||||||
def __init__(self, other_post):
|
def __init__(self, other_post: model.Post) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
'Post already uploaded (%d)' % other_post.post_id,
|
'Post already uploaded (%d)' % other_post.post_id,
|
||||||
{
|
{
|
||||||
|
@ -58,30 +60,30 @@ class InvalidPostFlagError(errors.ValidationError):
|
||||||
|
|
||||||
|
|
||||||
class PostLookalike(image_hash.Lookalike):
|
class PostLookalike(image_hash.Lookalike):
|
||||||
def __init__(self, score, distance, post):
|
def __init__(self, score: int, distance: float, post: model.Post) -> None:
|
||||||
super().__init__(score, distance, post.post_id)
|
super().__init__(score, distance, post.post_id)
|
||||||
self.post = post
|
self.post = post
|
||||||
|
|
||||||
|
|
||||||
SAFETY_MAP = {
|
SAFETY_MAP = {
|
||||||
db.Post.SAFETY_SAFE: 'safe',
|
model.Post.SAFETY_SAFE: 'safe',
|
||||||
db.Post.SAFETY_SKETCHY: 'sketchy',
|
model.Post.SAFETY_SKETCHY: 'sketchy',
|
||||||
db.Post.SAFETY_UNSAFE: 'unsafe',
|
model.Post.SAFETY_UNSAFE: 'unsafe',
|
||||||
}
|
}
|
||||||
|
|
||||||
TYPE_MAP = {
|
TYPE_MAP = {
|
||||||
db.Post.TYPE_IMAGE: 'image',
|
model.Post.TYPE_IMAGE: 'image',
|
||||||
db.Post.TYPE_ANIMATION: 'animation',
|
model.Post.TYPE_ANIMATION: 'animation',
|
||||||
db.Post.TYPE_VIDEO: 'video',
|
model.Post.TYPE_VIDEO: 'video',
|
||||||
db.Post.TYPE_FLASH: 'flash',
|
model.Post.TYPE_FLASH: 'flash',
|
||||||
}
|
}
|
||||||
|
|
||||||
FLAG_MAP = {
|
FLAG_MAP = {
|
||||||
db.Post.FLAG_LOOP: 'loop',
|
model.Post.FLAG_LOOP: 'loop',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_post_content_url(post):
|
def get_post_content_url(post: model.Post) -> str:
|
||||||
assert post
|
assert post
|
||||||
return '%s/posts/%d.%s' % (
|
return '%s/posts/%d.%s' % (
|
||||||
config.config['data_url'].rstrip('/'),
|
config.config['data_url'].rstrip('/'),
|
||||||
|
@ -89,31 +91,31 @@ def get_post_content_url(post):
|
||||||
mime.get_extension(post.mime_type) or 'dat')
|
mime.get_extension(post.mime_type) or 'dat')
|
||||||
|
|
||||||
|
|
||||||
def get_post_thumbnail_url(post):
|
def get_post_thumbnail_url(post: model.Post) -> str:
|
||||||
assert post
|
assert post
|
||||||
return '%s/generated-thumbnails/%d.jpg' % (
|
return '%s/generated-thumbnails/%d.jpg' % (
|
||||||
config.config['data_url'].rstrip('/'),
|
config.config['data_url'].rstrip('/'),
|
||||||
post.post_id)
|
post.post_id)
|
||||||
|
|
||||||
|
|
||||||
def get_post_content_path(post):
|
def get_post_content_path(post: model.Post) -> str:
|
||||||
assert post
|
assert post
|
||||||
assert post.post_id
|
assert post.post_id
|
||||||
return 'posts/%d.%s' % (
|
return 'posts/%d.%s' % (
|
||||||
post.post_id, mime.get_extension(post.mime_type) or 'dat')
|
post.post_id, mime.get_extension(post.mime_type) or 'dat')
|
||||||
|
|
||||||
|
|
||||||
def get_post_thumbnail_path(post):
|
def get_post_thumbnail_path(post: model.Post) -> str:
|
||||||
assert post
|
assert post
|
||||||
return 'generated-thumbnails/%d.jpg' % (post.post_id)
|
return 'generated-thumbnails/%d.jpg' % (post.post_id)
|
||||||
|
|
||||||
|
|
||||||
def get_post_thumbnail_backup_path(post):
|
def get_post_thumbnail_backup_path(post: model.Post) -> str:
|
||||||
assert post
|
assert post
|
||||||
return 'posts/custom-thumbnails/%d.dat' % (post.post_id)
|
return 'posts/custom-thumbnails/%d.dat' % (post.post_id)
|
||||||
|
|
||||||
|
|
||||||
def serialize_note(note):
|
def serialize_note(note: model.PostNote) -> rest.Response:
|
||||||
assert note
|
assert note
|
||||||
return {
|
return {
|
||||||
'polygon': note.polygon,
|
'polygon': note.polygon,
|
||||||
|
@ -121,113 +123,216 @@ def serialize_note(note):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def serialize_post(post, auth_user, options=None):
|
class PostSerializer(serialization.BaseSerializer):
|
||||||
return util.serialize_entity(
|
def __init__(self, post: model.Post, auth_user: model.User) -> None:
|
||||||
post,
|
self.post = post
|
||||||
|
self.auth_user = auth_user
|
||||||
|
|
||||||
|
def _serializers(self) -> Dict[str, Callable[[], Any]]:
|
||||||
|
return {
|
||||||
|
'id': self.serialize_id,
|
||||||
|
'version': self.serialize_version,
|
||||||
|
'creationTime': self.serialize_creation_time,
|
||||||
|
'lastEditTime': self.serialize_last_edit_time,
|
||||||
|
'safety': self.serialize_safety,
|
||||||
|
'source': self.serialize_source,
|
||||||
|
'type': self.serialize_type,
|
||||||
|
'mimeType': self.serialize_mime,
|
||||||
|
'checksum': self.serialize_checksum,
|
||||||
|
'fileSize': self.serialize_file_size,
|
||||||
|
'canvasWidth': self.serialize_canvas_width,
|
||||||
|
'canvasHeight': self.serialize_canvas_height,
|
||||||
|
'contentUrl': self.serialize_content_url,
|
||||||
|
'thumbnailUrl': self.serialize_thumbnail_url,
|
||||||
|
'flags': self.serialize_flags,
|
||||||
|
'tags': self.serialize_tags,
|
||||||
|
'relations': self.serialize_relations,
|
||||||
|
'user': self.serialize_user,
|
||||||
|
'score': self.serialize_score,
|
||||||
|
'ownScore': self.serialize_own_score,
|
||||||
|
'ownFavorite': self.serialize_own_favorite,
|
||||||
|
'tagCount': self.serialize_tag_count,
|
||||||
|
'favoriteCount': self.serialize_favorite_count,
|
||||||
|
'commentCount': self.serialize_comment_count,
|
||||||
|
'noteCount': self.serialize_note_count,
|
||||||
|
'relationCount': self.serialize_relation_count,
|
||||||
|
'featureCount': self.serialize_feature_count,
|
||||||
|
'lastFeatureTime': self.serialize_last_feature_time,
|
||||||
|
'favoritedBy': self.serialize_favorited_by,
|
||||||
|
'hasCustomThumbnail': self.serialize_has_custom_thumbnail,
|
||||||
|
'notes': self.serialize_notes,
|
||||||
|
'comments': self.serialize_comments,
|
||||||
|
}
|
||||||
|
|
||||||
|
def serialize_id(self) -> Any:
|
||||||
|
return self.post.post_id
|
||||||
|
|
||||||
|
def serialize_version(self) -> Any:
|
||||||
|
return self.post.version
|
||||||
|
|
||||||
|
def serialize_creation_time(self) -> Any:
|
||||||
|
return self.post.creation_time
|
||||||
|
|
||||||
|
def serialize_last_edit_time(self) -> Any:
|
||||||
|
return self.post.last_edit_time
|
||||||
|
|
||||||
|
def serialize_safety(self) -> Any:
|
||||||
|
return SAFETY_MAP[self.post.safety]
|
||||||
|
|
||||||
|
def serialize_source(self) -> Any:
|
||||||
|
return self.post.source
|
||||||
|
|
||||||
|
def serialize_type(self) -> Any:
|
||||||
|
return TYPE_MAP[self.post.type]
|
||||||
|
|
||||||
|
def serialize_mime(self) -> Any:
|
||||||
|
return self.post.mime_type
|
||||||
|
|
||||||
|
def serialize_checksum(self) -> Any:
|
||||||
|
return self.post.checksum
|
||||||
|
|
||||||
|
def serialize_file_size(self) -> Any:
|
||||||
|
return self.post.file_size
|
||||||
|
|
||||||
|
def serialize_canvas_width(self) -> Any:
|
||||||
|
return self.post.canvas_width
|
||||||
|
|
||||||
|
def serialize_canvas_height(self) -> Any:
|
||||||
|
return self.post.canvas_height
|
||||||
|
|
||||||
|
def serialize_content_url(self) -> Any:
|
||||||
|
return get_post_content_url(self.post)
|
||||||
|
|
||||||
|
def serialize_thumbnail_url(self) -> Any:
|
||||||
|
return get_post_thumbnail_url(self.post)
|
||||||
|
|
||||||
|
def serialize_flags(self) -> Any:
|
||||||
|
return self.post.flags
|
||||||
|
|
||||||
|
def serialize_tags(self) -> Any:
|
||||||
|
return [tag.names[0].name for tag in tags.sort_tags(self.post.tags)]
|
||||||
|
|
||||||
|
def serialize_relations(self) -> Any:
|
||||||
|
return sorted(
|
||||||
{
|
{
|
||||||
'id': lambda: post.post_id,
|
post['id']: post
|
||||||
'version': lambda: post.version,
|
for post in [
|
||||||
'creationTime': lambda: post.creation_time,
|
serialize_micro_post(rel, self.auth_user)
|
||||||
'lastEditTime': lambda: post.last_edit_time,
|
for rel in self.post.relations]
|
||||||
'safety': lambda: SAFETY_MAP[post.safety],
|
|
||||||
'source': lambda: post.source,
|
|
||||||
'type': lambda: TYPE_MAP[post.type],
|
|
||||||
'mimeType': lambda: post.mime_type,
|
|
||||||
'checksum': lambda: post.checksum,
|
|
||||||
'fileSize': lambda: post.file_size,
|
|
||||||
'canvasWidth': lambda: post.canvas_width,
|
|
||||||
'canvasHeight': lambda: post.canvas_height,
|
|
||||||
'contentUrl': lambda: get_post_content_url(post),
|
|
||||||
'thumbnailUrl': lambda: get_post_thumbnail_url(post),
|
|
||||||
'flags': lambda: post.flags,
|
|
||||||
'tags': lambda: [
|
|
||||||
tag.names[0].name for tag in tags.sort_tags(post.tags)],
|
|
||||||
'relations': lambda: sorted(
|
|
||||||
{
|
|
||||||
post['id']:
|
|
||||||
post for post in [
|
|
||||||
serialize_micro_post(rel, auth_user)
|
|
||||||
for rel in post.relations]
|
|
||||||
}.values(),
|
}.values(),
|
||||||
key=lambda post: post['id']),
|
key=lambda post: post['id'])
|
||||||
'user': lambda: users.serialize_micro_user(post.user, auth_user),
|
|
||||||
'score': lambda: post.score,
|
def serialize_user(self) -> Any:
|
||||||
'ownScore': lambda: scores.get_score(post, auth_user),
|
return users.serialize_micro_user(self.post.user, self.auth_user)
|
||||||
'ownFavorite': lambda: len([
|
|
||||||
user for user in post.favorited_by
|
def serialize_score(self) -> Any:
|
||||||
if user.user_id == auth_user.user_id]
|
return self.post.score
|
||||||
) > 0,
|
|
||||||
'tagCount': lambda: post.tag_count,
|
def serialize_own_score(self) -> Any:
|
||||||
'favoriteCount': lambda: post.favorite_count,
|
return scores.get_score(self.post, self.auth_user)
|
||||||
'commentCount': lambda: post.comment_count,
|
|
||||||
'noteCount': lambda: post.note_count,
|
def serialize_own_favorite(self) -> Any:
|
||||||
'relationCount': lambda: post.relation_count,
|
return len([
|
||||||
'featureCount': lambda: post.feature_count,
|
user for user in self.post.favorited_by
|
||||||
'lastFeatureTime': lambda: post.last_feature_time,
|
if user.user_id == self.auth_user.user_id]
|
||||||
'favoritedBy': lambda: [
|
) > 0
|
||||||
users.serialize_micro_user(rel.user, auth_user)
|
|
||||||
for rel in post.favorited_by
|
def serialize_tag_count(self) -> Any:
|
||||||
],
|
return self.post.tag_count
|
||||||
'hasCustomThumbnail':
|
|
||||||
lambda: files.has(get_post_thumbnail_backup_path(post)),
|
def serialize_favorite_count(self) -> Any:
|
||||||
'notes': lambda: sorted(
|
return self.post.favorite_count
|
||||||
[serialize_note(note) for note in post.notes],
|
|
||||||
key=lambda x: x['polygon']),
|
def serialize_comment_count(self) -> Any:
|
||||||
'comments': lambda: [
|
return self.post.comment_count
|
||||||
comments.serialize_comment(comment, auth_user)
|
|
||||||
|
def serialize_note_count(self) -> Any:
|
||||||
|
return self.post.note_count
|
||||||
|
|
||||||
|
def serialize_relation_count(self) -> Any:
|
||||||
|
return self.post.relation_count
|
||||||
|
|
||||||
|
def serialize_feature_count(self) -> Any:
|
||||||
|
return self.post.feature_count
|
||||||
|
|
||||||
|
def serialize_last_feature_time(self) -> Any:
|
||||||
|
return self.post.last_feature_time
|
||||||
|
|
||||||
|
def serialize_favorited_by(self) -> Any:
|
||||||
|
return [
|
||||||
|
users.serialize_micro_user(rel.user, self.auth_user)
|
||||||
|
for rel in self.post.favorited_by
|
||||||
|
]
|
||||||
|
|
||||||
|
def serialize_has_custom_thumbnail(self) -> Any:
|
||||||
|
return files.has(get_post_thumbnail_backup_path(self.post))
|
||||||
|
|
||||||
|
def serialize_notes(self) -> Any:
|
||||||
|
return sorted(
|
||||||
|
[serialize_note(note) for note in self.post.notes],
|
||||||
|
key=lambda x: x['polygon'])
|
||||||
|
|
||||||
|
def serialize_comments(self) -> Any:
|
||||||
|
return [
|
||||||
|
comments.serialize_comment(comment, self.auth_user)
|
||||||
for comment in sorted(
|
for comment in sorted(
|
||||||
post.comments,
|
self.post.comments,
|
||||||
key=lambda comment: comment.creation_time)],
|
key=lambda comment: comment.creation_time)]
|
||||||
},
|
|
||||||
options)
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_micro_post(post, auth_user):
|
def serialize_post(
|
||||||
|
post: Optional[model.Post],
|
||||||
|
auth_user: model.User,
|
||||||
|
options: List[str]=[]) -> Optional[rest.Response]:
|
||||||
|
if not post:
|
||||||
|
return None
|
||||||
|
return PostSerializer(post, auth_user).serialize(options)
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_micro_post(
|
||||||
|
post: model.Post, auth_user: model.User) -> Optional[rest.Response]:
|
||||||
return serialize_post(
|
return serialize_post(
|
||||||
post,
|
post, auth_user=auth_user, options=['id', 'thumbnailUrl'])
|
||||||
auth_user=auth_user,
|
|
||||||
options=['id', 'thumbnailUrl'])
|
|
||||||
|
|
||||||
|
|
||||||
def get_post_count():
|
def get_post_count() -> int:
|
||||||
return db.session.query(sqlalchemy.func.count(db.Post.post_id)).one()[0]
|
return db.session.query(sa.func.count(model.Post.post_id)).one()[0]
|
||||||
|
|
||||||
|
|
||||||
def try_get_post_by_id(post_id):
|
def try_get_post_by_id(post_id: int) -> Optional[model.Post]:
|
||||||
try:
|
|
||||||
post_id = int(post_id)
|
|
||||||
except ValueError:
|
|
||||||
raise InvalidPostIdError('Invalid post ID: %r.' % post_id)
|
|
||||||
return db.session \
|
return db.session \
|
||||||
.query(db.Post) \
|
.query(model.Post) \
|
||||||
.filter(db.Post.post_id == post_id) \
|
.filter(model.Post.post_id == post_id) \
|
||||||
.one_or_none()
|
.one_or_none()
|
||||||
|
|
||||||
|
|
||||||
def get_post_by_id(post_id):
|
def get_post_by_id(post_id: int) -> model.Post:
|
||||||
post = try_get_post_by_id(post_id)
|
post = try_get_post_by_id(post_id)
|
||||||
if not post:
|
if not post:
|
||||||
raise PostNotFoundError('Post %r not found.' % post_id)
|
raise PostNotFoundError('Post %r not found.' % post_id)
|
||||||
return post
|
return post
|
||||||
|
|
||||||
|
|
||||||
def try_get_current_post_feature():
|
def try_get_current_post_feature() -> Optional[model.PostFeature]:
|
||||||
return db.session \
|
return db.session \
|
||||||
.query(db.PostFeature) \
|
.query(model.PostFeature) \
|
||||||
.order_by(db.PostFeature.time.desc()) \
|
.order_by(model.PostFeature.time.desc()) \
|
||||||
.first()
|
.first()
|
||||||
|
|
||||||
|
|
||||||
def try_get_featured_post():
|
def try_get_featured_post() -> Optional[model.Post]:
|
||||||
post_feature = try_get_current_post_feature()
|
post_feature = try_get_current_post_feature()
|
||||||
return post_feature.post if post_feature else None
|
return post_feature.post if post_feature else None
|
||||||
|
|
||||||
|
|
||||||
def create_post(content, tag_names, user):
|
def create_post(
|
||||||
post = db.Post()
|
content: bytes,
|
||||||
post.safety = db.Post.SAFETY_SAFE
|
tag_names: List[str],
|
||||||
|
user: Optional[model.User]) -> Tuple[model.Post, List[model.Tag]]:
|
||||||
|
post = model.Post()
|
||||||
|
post.safety = model.Post.SAFETY_SAFE
|
||||||
post.user = user
|
post.user = user
|
||||||
post.creation_time = datetime.datetime.utcnow()
|
post.creation_time = datetime.utcnow()
|
||||||
post.flags = []
|
post.flags = []
|
||||||
|
|
||||||
post.type = ''
|
post.type = ''
|
||||||
|
@ -240,7 +345,7 @@ def create_post(content, tag_names, user):
|
||||||
return (post, new_tags)
|
return (post, new_tags)
|
||||||
|
|
||||||
|
|
||||||
def update_post_safety(post, safety):
|
def update_post_safety(post: model.Post, safety: str) -> None:
|
||||||
assert post
|
assert post
|
||||||
safety = util.flip(SAFETY_MAP).get(safety, None)
|
safety = util.flip(SAFETY_MAP).get(safety, None)
|
||||||
if not safety:
|
if not safety:
|
||||||
|
@ -249,30 +354,33 @@ def update_post_safety(post, safety):
|
||||||
post.safety = safety
|
post.safety = safety
|
||||||
|
|
||||||
|
|
||||||
def update_post_source(post, source):
|
def update_post_source(post: model.Post, source: Optional[str]) -> None:
|
||||||
assert post
|
assert post
|
||||||
if util.value_exceeds_column_size(source, db.Post.source):
|
if util.value_exceeds_column_size(source, model.Post.source):
|
||||||
raise InvalidPostSourceError('Source is too long.')
|
raise InvalidPostSourceError('Source is too long.')
|
||||||
post.source = source
|
post.source = source or None
|
||||||
|
|
||||||
|
|
||||||
@sqlalchemy.events.event.listens_for(db.Post, 'after_insert')
|
@sa.events.event.listens_for(model.Post, 'after_insert')
|
||||||
def _after_post_insert(_mapper, _connection, post):
|
def _after_post_insert(
|
||||||
|
_mapper: Any, _connection: Any, post: model.Post) -> None:
|
||||||
_sync_post_content(post)
|
_sync_post_content(post)
|
||||||
|
|
||||||
|
|
||||||
@sqlalchemy.events.event.listens_for(db.Post, 'after_update')
|
@sa.events.event.listens_for(model.Post, 'after_update')
|
||||||
def _after_post_update(_mapper, _connection, post):
|
def _after_post_update(
|
||||||
|
_mapper: Any, _connection: Any, post: model.Post) -> None:
|
||||||
_sync_post_content(post)
|
_sync_post_content(post)
|
||||||
|
|
||||||
|
|
||||||
@sqlalchemy.events.event.listens_for(db.Post, 'before_delete')
|
@sa.events.event.listens_for(model.Post, 'before_delete')
|
||||||
def _before_post_delete(_mapper, _connection, post):
|
def _before_post_delete(
|
||||||
|
_mapper: Any, _connection: Any, post: model.Post) -> None:
|
||||||
if post.post_id:
|
if post.post_id:
|
||||||
image_hash.delete_image(post.post_id)
|
image_hash.delete_image(post.post_id)
|
||||||
|
|
||||||
|
|
||||||
def _sync_post_content(post):
|
def _sync_post_content(post: model.Post) -> None:
|
||||||
regenerate_thumb = False
|
regenerate_thumb = False
|
||||||
|
|
||||||
if hasattr(post, '__content'):
|
if hasattr(post, '__content'):
|
||||||
|
@ -281,7 +389,7 @@ def _sync_post_content(post):
|
||||||
delattr(post, '__content')
|
delattr(post, '__content')
|
||||||
regenerate_thumb = True
|
regenerate_thumb = True
|
||||||
if post.post_id and post.type in (
|
if post.post_id and post.type in (
|
||||||
db.Post.TYPE_IMAGE, db.Post.TYPE_ANIMATION):
|
model.Post.TYPE_IMAGE, model.Post.TYPE_ANIMATION):
|
||||||
image_hash.delete_image(post.post_id)
|
image_hash.delete_image(post.post_id)
|
||||||
image_hash.add_image(post.post_id, content)
|
image_hash.add_image(post.post_id, content)
|
||||||
|
|
||||||
|
@ -299,29 +407,29 @@ def _sync_post_content(post):
|
||||||
generate_post_thumbnail(post)
|
generate_post_thumbnail(post)
|
||||||
|
|
||||||
|
|
||||||
def update_post_content(post, content):
|
def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
|
||||||
assert post
|
assert post
|
||||||
if not content:
|
if not content:
|
||||||
raise InvalidPostContentError('Post content missing.')
|
raise InvalidPostContentError('Post content missing.')
|
||||||
post.mime_type = mime.get_mime_type(content)
|
post.mime_type = mime.get_mime_type(content)
|
||||||
if mime.is_flash(post.mime_type):
|
if mime.is_flash(post.mime_type):
|
||||||
post.type = db.Post.TYPE_FLASH
|
post.type = model.Post.TYPE_FLASH
|
||||||
elif mime.is_image(post.mime_type):
|
elif mime.is_image(post.mime_type):
|
||||||
if mime.is_animated_gif(content):
|
if mime.is_animated_gif(content):
|
||||||
post.type = db.Post.TYPE_ANIMATION
|
post.type = model.Post.TYPE_ANIMATION
|
||||||
else:
|
else:
|
||||||
post.type = db.Post.TYPE_IMAGE
|
post.type = model.Post.TYPE_IMAGE
|
||||||
elif mime.is_video(post.mime_type):
|
elif mime.is_video(post.mime_type):
|
||||||
post.type = db.Post.TYPE_VIDEO
|
post.type = model.Post.TYPE_VIDEO
|
||||||
else:
|
else:
|
||||||
raise InvalidPostContentError(
|
raise InvalidPostContentError(
|
||||||
'Unhandled file type: %r' % post.mime_type)
|
'Unhandled file type: %r' % post.mime_type)
|
||||||
|
|
||||||
post.checksum = util.get_sha1(content)
|
post.checksum = util.get_sha1(content)
|
||||||
other_post = db.session \
|
other_post = db.session \
|
||||||
.query(db.Post) \
|
.query(model.Post) \
|
||||||
.filter(db.Post.checksum == post.checksum) \
|
.filter(model.Post.checksum == post.checksum) \
|
||||||
.filter(db.Post.post_id != post.post_id) \
|
.filter(model.Post.post_id != post.post_id) \
|
||||||
.one_or_none()
|
.one_or_none()
|
||||||
if other_post \
|
if other_post \
|
||||||
and other_post.post_id \
|
and other_post.post_id \
|
||||||
|
@ -343,18 +451,20 @@ def update_post_content(post, content):
|
||||||
setattr(post, '__content', content)
|
setattr(post, '__content', content)
|
||||||
|
|
||||||
|
|
||||||
def update_post_thumbnail(post, content=None):
|
def update_post_thumbnail(
|
||||||
|
post: model.Post, content: Optional[bytes]=None) -> None:
|
||||||
assert post
|
assert post
|
||||||
setattr(post, '__thumbnail', content)
|
setattr(post, '__thumbnail', content)
|
||||||
|
|
||||||
|
|
||||||
def generate_post_thumbnail(post):
|
def generate_post_thumbnail(post: model.Post) -> None:
|
||||||
assert post
|
assert post
|
||||||
if files.has(get_post_thumbnail_backup_path(post)):
|
if files.has(get_post_thumbnail_backup_path(post)):
|
||||||
content = files.get(get_post_thumbnail_backup_path(post))
|
content = files.get(get_post_thumbnail_backup_path(post))
|
||||||
else:
|
else:
|
||||||
content = files.get(get_post_content_path(post))
|
content = files.get(get_post_content_path(post))
|
||||||
try:
|
try:
|
||||||
|
assert content
|
||||||
image = images.Image(content)
|
image = images.Image(content)
|
||||||
image.resize_fill(
|
image.resize_fill(
|
||||||
int(config.config['thumbnails']['post_width']),
|
int(config.config['thumbnails']['post_width']),
|
||||||
|
@ -364,14 +474,15 @@ def generate_post_thumbnail(post):
|
||||||
files.save(get_post_thumbnail_path(post), EMPTY_PIXEL)
|
files.save(get_post_thumbnail_path(post), EMPTY_PIXEL)
|
||||||
|
|
||||||
|
|
||||||
def update_post_tags(post, tag_names):
|
def update_post_tags(
|
||||||
|
post: model.Post, tag_names: List[str]) -> List[model.Tag]:
|
||||||
assert post
|
assert post
|
||||||
existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
|
existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
|
||||||
post.tags = existing_tags + new_tags
|
post.tags = existing_tags + new_tags
|
||||||
return new_tags
|
return new_tags
|
||||||
|
|
||||||
|
|
||||||
def update_post_relations(post, new_post_ids):
|
def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None:
|
||||||
assert post
|
assert post
|
||||||
try:
|
try:
|
||||||
new_post_ids = [int(id) for id in new_post_ids]
|
new_post_ids = [int(id) for id in new_post_ids]
|
||||||
|
@ -382,8 +493,8 @@ def update_post_relations(post, new_post_ids):
|
||||||
old_post_ids = [int(p.post_id) for p in old_posts]
|
old_post_ids = [int(p.post_id) for p in old_posts]
|
||||||
if new_post_ids:
|
if new_post_ids:
|
||||||
new_posts = db.session \
|
new_posts = db.session \
|
||||||
.query(db.Post) \
|
.query(model.Post) \
|
||||||
.filter(db.Post.post_id.in_(new_post_ids)) \
|
.filter(model.Post.post_id.in_(new_post_ids)) \
|
||||||
.all()
|
.all()
|
||||||
else:
|
else:
|
||||||
new_posts = []
|
new_posts = []
|
||||||
|
@ -402,7 +513,7 @@ def update_post_relations(post, new_post_ids):
|
||||||
relation.relations.append(post)
|
relation.relations.append(post)
|
||||||
|
|
||||||
|
|
||||||
def update_post_notes(post, notes):
|
def update_post_notes(post: model.Post, notes: Any) -> None:
|
||||||
assert post
|
assert post
|
||||||
post.notes = []
|
post.notes = []
|
||||||
for note in notes:
|
for note in notes:
|
||||||
|
@ -433,13 +544,13 @@ def update_post_notes(post, notes):
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise InvalidPostNoteError(
|
raise InvalidPostNoteError(
|
||||||
'A point in note\'s polygon must be numeric.')
|
'A point in note\'s polygon must be numeric.')
|
||||||
if util.value_exceeds_column_size(note['text'], db.PostNote.text):
|
if util.value_exceeds_column_size(note['text'], model.PostNote.text):
|
||||||
raise InvalidPostNoteError('Note text is too long.')
|
raise InvalidPostNoteError('Note text is too long.')
|
||||||
post.notes.append(
|
post.notes.append(
|
||||||
db.PostNote(polygon=note['polygon'], text=str(note['text'])))
|
model.PostNote(polygon=note['polygon'], text=str(note['text'])))
|
||||||
|
|
||||||
|
|
||||||
def update_post_flags(post, flags):
|
def update_post_flags(post: model.Post, flags: List[str]) -> None:
|
||||||
assert post
|
assert post
|
||||||
target_flags = []
|
target_flags = []
|
||||||
for flag in flags:
|
for flag in flags:
|
||||||
|
@ -451,88 +562,95 @@ def update_post_flags(post, flags):
|
||||||
post.flags = target_flags
|
post.flags = target_flags
|
||||||
|
|
||||||
|
|
||||||
def feature_post(post, user):
|
def feature_post(post: model.Post, user: Optional[model.User]) -> None:
|
||||||
assert post
|
assert post
|
||||||
post_feature = db.PostFeature()
|
post_feature = model.PostFeature()
|
||||||
post_feature.time = datetime.datetime.utcnow()
|
post_feature.time = datetime.utcnow()
|
||||||
post_feature.post = post
|
post_feature.post = post
|
||||||
post_feature.user = user
|
post_feature.user = user
|
||||||
db.session.add(post_feature)
|
db.session.add(post_feature)
|
||||||
|
|
||||||
|
|
||||||
def delete(post):
|
def delete(post: model.Post) -> None:
|
||||||
assert post
|
assert post
|
||||||
db.session.delete(post)
|
db.session.delete(post)
|
||||||
|
|
||||||
|
|
||||||
def merge_posts(source_post, target_post, replace_content):
|
def merge_posts(
|
||||||
|
source_post: model.Post,
|
||||||
|
target_post: model.Post,
|
||||||
|
replace_content: bool) -> None:
|
||||||
assert source_post
|
assert source_post
|
||||||
assert target_post
|
assert target_post
|
||||||
if source_post.post_id == target_post.post_id:
|
if source_post.post_id == target_post.post_id:
|
||||||
raise InvalidPostRelationError('Cannot merge post with itself.')
|
raise InvalidPostRelationError('Cannot merge post with itself.')
|
||||||
|
|
||||||
def merge_tables(table, anti_dup_func, source_post_id, target_post_id):
|
def merge_tables(
|
||||||
|
table: model.Base,
|
||||||
|
anti_dup_func: Optional[Callable[[model.Base, model.Base], bool]],
|
||||||
|
source_post_id: int,
|
||||||
|
target_post_id: int) -> None:
|
||||||
alias1 = table
|
alias1 = table
|
||||||
alias2 = sqlalchemy.orm.util.aliased(table)
|
alias2 = sa.orm.util.aliased(table)
|
||||||
update_stmt = (
|
update_stmt = (
|
||||||
sqlalchemy.sql.expression.update(alias1)
|
sa.sql.expression.update(alias1)
|
||||||
.where(alias1.post_id == source_post_id))
|
.where(alias1.post_id == source_post_id))
|
||||||
|
|
||||||
if anti_dup_func is not None:
|
if anti_dup_func is not None:
|
||||||
update_stmt = (
|
update_stmt = (
|
||||||
update_stmt
|
update_stmt
|
||||||
.where(
|
.where(
|
||||||
~sqlalchemy.exists()
|
~sa.exists()
|
||||||
.where(anti_dup_func(alias1, alias2))
|
.where(anti_dup_func(alias1, alias2))
|
||||||
.where(alias2.post_id == target_post_id)))
|
.where(alias2.post_id == target_post_id)))
|
||||||
|
|
||||||
update_stmt = update_stmt.values(post_id=target_post_id)
|
update_stmt = update_stmt.values(post_id=target_post_id)
|
||||||
db.session.execute(update_stmt)
|
db.session.execute(update_stmt)
|
||||||
|
|
||||||
def merge_tags(source_post_id, target_post_id):
|
def merge_tags(source_post_id: int, target_post_id: int) -> None:
|
||||||
merge_tables(
|
merge_tables(
|
||||||
db.PostTag,
|
model.PostTag,
|
||||||
lambda alias1, alias2: alias1.tag_id == alias2.tag_id,
|
lambda alias1, alias2: alias1.tag_id == alias2.tag_id,
|
||||||
source_post_id,
|
source_post_id,
|
||||||
target_post_id)
|
target_post_id)
|
||||||
|
|
||||||
def merge_scores(source_post_id, target_post_id):
|
def merge_scores(source_post_id: int, target_post_id: int) -> None:
|
||||||
merge_tables(
|
merge_tables(
|
||||||
db.PostScore,
|
model.PostScore,
|
||||||
lambda alias1, alias2: alias1.user_id == alias2.user_id,
|
lambda alias1, alias2: alias1.user_id == alias2.user_id,
|
||||||
source_post_id,
|
source_post_id,
|
||||||
target_post_id)
|
target_post_id)
|
||||||
|
|
||||||
def merge_favorites(source_post_id, target_post_id):
|
def merge_favorites(source_post_id: int, target_post_id: int) -> None:
|
||||||
merge_tables(
|
merge_tables(
|
||||||
db.PostFavorite,
|
model.PostFavorite,
|
||||||
lambda alias1, alias2: alias1.user_id == alias2.user_id,
|
lambda alias1, alias2: alias1.user_id == alias2.user_id,
|
||||||
source_post_id,
|
source_post_id,
|
||||||
target_post_id)
|
target_post_id)
|
||||||
|
|
||||||
def merge_comments(source_post_id, target_post_id):
|
def merge_comments(source_post_id: int, target_post_id: int) -> None:
|
||||||
merge_tables(db.Comment, None, source_post_id, target_post_id)
|
merge_tables(model.Comment, None, source_post_id, target_post_id)
|
||||||
|
|
||||||
def merge_relations(source_post_id, target_post_id):
|
def merge_relations(source_post_id: int, target_post_id: int) -> None:
|
||||||
alias1 = db.PostRelation
|
alias1 = model.PostRelation
|
||||||
alias2 = sqlalchemy.orm.util.aliased(db.PostRelation)
|
alias2 = sa.orm.util.aliased(model.PostRelation)
|
||||||
update_stmt = (
|
update_stmt = (
|
||||||
sqlalchemy.sql.expression.update(alias1)
|
sa.sql.expression.update(alias1)
|
||||||
.where(alias1.parent_id == source_post_id)
|
.where(alias1.parent_id == source_post_id)
|
||||||
.where(alias1.child_id != target_post_id)
|
.where(alias1.child_id != target_post_id)
|
||||||
.where(
|
.where(
|
||||||
~sqlalchemy.exists()
|
~sa.exists()
|
||||||
.where(alias2.child_id == alias1.child_id)
|
.where(alias2.child_id == alias1.child_id)
|
||||||
.where(alias2.parent_id == target_post_id))
|
.where(alias2.parent_id == target_post_id))
|
||||||
.values(parent_id=target_post_id))
|
.values(parent_id=target_post_id))
|
||||||
db.session.execute(update_stmt)
|
db.session.execute(update_stmt)
|
||||||
|
|
||||||
update_stmt = (
|
update_stmt = (
|
||||||
sqlalchemy.sql.expression.update(alias1)
|
sa.sql.expression.update(alias1)
|
||||||
.where(alias1.child_id == source_post_id)
|
.where(alias1.child_id == source_post_id)
|
||||||
.where(alias1.parent_id != target_post_id)
|
.where(alias1.parent_id != target_post_id)
|
||||||
.where(
|
.where(
|
||||||
~sqlalchemy.exists()
|
~sa.exists()
|
||||||
.where(alias2.parent_id == alias1.parent_id)
|
.where(alias2.parent_id == alias1.parent_id)
|
||||||
.where(alias2.child_id == target_post_id))
|
.where(alias2.child_id == target_post_id))
|
||||||
.values(child_id=target_post_id))
|
.values(child_id=target_post_id))
|
||||||
|
@ -553,15 +671,15 @@ def merge_posts(source_post, target_post, replace_content):
|
||||||
update_post_content(target_post, content)
|
update_post_content(target_post, content)
|
||||||
|
|
||||||
|
|
||||||
def search_by_image_exact(image_content):
|
def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
|
||||||
checksum = util.get_sha1(image_content)
|
checksum = util.get_sha1(image_content)
|
||||||
return db.session \
|
return db.session \
|
||||||
.query(db.Post) \
|
.query(model.Post) \
|
||||||
.filter(db.Post.checksum == checksum) \
|
.filter(model.Post.checksum == checksum) \
|
||||||
.one_or_none()
|
.one_or_none()
|
||||||
|
|
||||||
|
|
||||||
def search_by_image(image_content):
|
def search_by_image(image_content: bytes) -> List[PostLookalike]:
|
||||||
ret = []
|
ret = []
|
||||||
for result in image_hash.search_by_image(image_content):
|
for result in image_hash.search_by_image(image_content):
|
||||||
ret.append(PostLookalike(
|
ret.append(PostLookalike(
|
||||||
|
@ -571,24 +689,24 @@ def search_by_image(image_content):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def populate_reverse_search():
|
def populate_reverse_search() -> None:
|
||||||
excluded_post_ids = image_hash.get_all_paths()
|
excluded_post_ids = image_hash.get_all_paths()
|
||||||
|
|
||||||
post_ids_to_hash = (
|
post_ids_to_hash = (
|
||||||
db.session
|
db.session
|
||||||
.query(db.Post.post_id)
|
.query(model.Post.post_id)
|
||||||
.filter(
|
.filter(
|
||||||
(db.Post.type == db.Post.TYPE_IMAGE) |
|
(model.Post.type == model.Post.TYPE_IMAGE) |
|
||||||
(db.Post.type == db.Post.TYPE_ANIMATION))
|
(model.Post.type == model.Post.TYPE_ANIMATION))
|
||||||
.filter(~db.Post.post_id.in_(excluded_post_ids))
|
.filter(~model.Post.post_id.in_(excluded_post_ids))
|
||||||
.order_by(db.Post.post_id.asc())
|
.order_by(model.Post.post_id.asc())
|
||||||
.all())
|
.all())
|
||||||
|
|
||||||
for post_ids_chunk in util.chunks(post_ids_to_hash, 100):
|
for post_ids_chunk in util.chunks(post_ids_to_hash, 100):
|
||||||
posts_chunk = (
|
posts_chunk = (
|
||||||
db.session
|
db.session
|
||||||
.query(db.Post)
|
.query(model.Post)
|
||||||
.filter(db.Post.post_id.in_(post_ids_chunk))
|
.filter(model.Post.post_id.in_(post_ids_chunk))
|
||||||
.all())
|
.all())
|
||||||
for post in posts_chunk:
|
for post in posts_chunk:
|
||||||
content_path = get_post_content_path(post)
|
content_path = get_post_content_path(post)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import datetime
|
import datetime
|
||||||
from szurubooru import db, errors
|
from typing import Any, Tuple, Callable
|
||||||
|
from szurubooru import db, model, errors
|
||||||
|
|
||||||
|
|
||||||
class InvalidScoreTargetError(errors.ValidationError):
|
class InvalidScoreTargetError(errors.ValidationError):
|
||||||
|
@ -10,22 +11,23 @@ class InvalidScoreValueError(errors.ValidationError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _get_table_info(entity):
|
def _get_table_info(
|
||||||
|
entity: model.Base) -> Tuple[model.Base, Callable[[model.Base], Any]]:
|
||||||
assert entity
|
assert entity
|
||||||
resource_type, _, _ = db.util.get_resource_info(entity)
|
resource_type, _, _ = model.util.get_resource_info(entity)
|
||||||
if resource_type == 'post':
|
if resource_type == 'post':
|
||||||
return db.PostScore, lambda table: table.post_id
|
return model.PostScore, lambda table: table.post_id
|
||||||
elif resource_type == 'comment':
|
elif resource_type == 'comment':
|
||||||
return db.CommentScore, lambda table: table.comment_id
|
return model.CommentScore, lambda table: table.comment_id
|
||||||
raise InvalidScoreTargetError()
|
raise InvalidScoreTargetError()
|
||||||
|
|
||||||
|
|
||||||
def _get_score_entity(entity, user):
|
def _get_score_entity(entity: model.Base, user: model.User) -> model.Base:
|
||||||
assert user
|
assert user
|
||||||
return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
|
return model.util.get_aux_entity(db.session, _get_table_info, entity, user)
|
||||||
|
|
||||||
|
|
||||||
def delete_score(entity, user):
|
def delete_score(entity: model.Base, user: model.User) -> None:
|
||||||
assert entity
|
assert entity
|
||||||
assert user
|
assert user
|
||||||
score_entity = _get_score_entity(entity, user)
|
score_entity = _get_score_entity(entity, user)
|
||||||
|
@ -33,7 +35,7 @@ def delete_score(entity, user):
|
||||||
db.session.delete(score_entity)
|
db.session.delete(score_entity)
|
||||||
|
|
||||||
|
|
||||||
def get_score(entity, user):
|
def get_score(entity: model.Base, user: model.User) -> int:
|
||||||
assert entity
|
assert entity
|
||||||
assert user
|
assert user
|
||||||
table, get_column = _get_table_info(entity)
|
table, get_column = _get_table_info(entity)
|
||||||
|
@ -45,7 +47,7 @@ def get_score(entity, user):
|
||||||
return row[0] if row else 0
|
return row[0] if row else 0
|
||||||
|
|
||||||
|
|
||||||
def set_score(entity, user, score):
|
def set_score(entity: model.Base, user: model.User, score: int) -> None:
|
||||||
from szurubooru.func import favorites
|
from szurubooru.func import favorites
|
||||||
assert entity
|
assert entity
|
||||||
assert user
|
assert user
|
||||||
|
|
27
server/szurubooru/func/serialization.py
Normal file
27
server/szurubooru/func/serialization.py
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
from typing import Any, Optional, List, Dict, Callable
|
||||||
|
from szurubooru import db, model, rest, errors
|
||||||
|
|
||||||
|
|
||||||
|
def get_serialization_options(ctx: rest.Context) -> List[str]:
|
||||||
|
return ctx.get_param_as_list('fields', default=[])
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSerializer:
|
||||||
|
_fields = {} # type: Dict[str, Callable[[model.Base], Any]]
|
||||||
|
|
||||||
|
def serialize(self, options: List[str]) -> Any:
|
||||||
|
field_factories = self._serializers()
|
||||||
|
if not options:
|
||||||
|
options = list(field_factories.keys())
|
||||||
|
ret = {}
|
||||||
|
for key in options:
|
||||||
|
if key not in field_factories:
|
||||||
|
raise errors.ValidationError(
|
||||||
|
'Invalid key: %r. Valid keys: %r.' % (
|
||||||
|
key, list(sorted(field_factories.keys()))))
|
||||||
|
factory = field_factories[key]
|
||||||
|
ret[key] = factory()
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def _serializers(self) -> Dict[str, Callable[[], Any]]:
|
||||||
|
raise NotImplementedError()
|
|
@ -1,9 +1,10 @@
|
||||||
|
from typing import Any, Optional, Dict, Callable
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from szurubooru import db
|
from szurubooru import db, model
|
||||||
from szurubooru.func import diff, users
|
from szurubooru.func import diff, users
|
||||||
|
|
||||||
|
|
||||||
def get_tag_category_snapshot(category):
|
def get_tag_category_snapshot(category: model.TagCategory) -> Dict[str, Any]:
|
||||||
assert category
|
assert category
|
||||||
return {
|
return {
|
||||||
'name': category.name,
|
'name': category.name,
|
||||||
|
@ -12,7 +13,7 @@ def get_tag_category_snapshot(category):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_tag_snapshot(tag):
|
def get_tag_snapshot(tag: model.Tag) -> Dict[str, Any]:
|
||||||
assert tag
|
assert tag
|
||||||
return {
|
return {
|
||||||
'names': [tag_name.name for tag_name in tag.names],
|
'names': [tag_name.name for tag_name in tag.names],
|
||||||
|
@ -22,7 +23,7 @@ def get_tag_snapshot(tag):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_post_snapshot(post):
|
def get_post_snapshot(post: model.Post) -> Dict[str, Any]:
|
||||||
assert post
|
assert post
|
||||||
return {
|
return {
|
||||||
'source': post.source,
|
'source': post.source,
|
||||||
|
@ -45,10 +46,11 @@ _snapshot_factories = {
|
||||||
'tag_category': lambda entity: get_tag_category_snapshot(entity),
|
'tag_category': lambda entity: get_tag_category_snapshot(entity),
|
||||||
'tag': lambda entity: get_tag_snapshot(entity),
|
'tag': lambda entity: get_tag_snapshot(entity),
|
||||||
'post': lambda entity: get_post_snapshot(entity),
|
'post': lambda entity: get_post_snapshot(entity),
|
||||||
}
|
} # type: Dict[model.Base, Callable[[model.Base], Dict[str ,Any]]]
|
||||||
|
|
||||||
|
|
||||||
def serialize_snapshot(snapshot, auth_user):
|
def serialize_snapshot(
|
||||||
|
snapshot: model.Snapshot, auth_user: model.User) -> Dict[str, Any]:
|
||||||
assert snapshot
|
assert snapshot
|
||||||
return {
|
return {
|
||||||
'operation': snapshot.operation,
|
'operation': snapshot.operation,
|
||||||
|
@ -60,11 +62,14 @@ def serialize_snapshot(snapshot, auth_user):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _create(operation, entity, auth_user):
|
def _create(
|
||||||
|
operation: str,
|
||||||
|
entity: model.Base,
|
||||||
|
auth_user: Optional[model.User]) -> model.Snapshot:
|
||||||
resource_type, resource_pkey, resource_name = (
|
resource_type, resource_pkey, resource_name = (
|
||||||
db.util.get_resource_info(entity))
|
model.util.get_resource_info(entity))
|
||||||
|
|
||||||
snapshot = db.Snapshot()
|
snapshot = model.Snapshot()
|
||||||
snapshot.creation_time = datetime.utcnow()
|
snapshot.creation_time = datetime.utcnow()
|
||||||
snapshot.operation = operation
|
snapshot.operation = operation
|
||||||
snapshot.resource_type = resource_type
|
snapshot.resource_type = resource_type
|
||||||
|
@ -74,33 +79,33 @@ def _create(operation, entity, auth_user):
|
||||||
return snapshot
|
return snapshot
|
||||||
|
|
||||||
|
|
||||||
def create(entity, auth_user):
|
def create(entity: model.Base, auth_user: Optional[model.User]) -> None:
|
||||||
assert entity
|
assert entity
|
||||||
snapshot = _create(db.Snapshot.OPERATION_CREATED, entity, auth_user)
|
snapshot = _create(model.Snapshot.OPERATION_CREATED, entity, auth_user)
|
||||||
snapshot_factory = _snapshot_factories[snapshot.resource_type]
|
snapshot_factory = _snapshot_factories[snapshot.resource_type]
|
||||||
snapshot.data = snapshot_factory(entity)
|
snapshot.data = snapshot_factory(entity)
|
||||||
db.session.add(snapshot)
|
db.session.add(snapshot)
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
def modify(entity, auth_user):
|
def modify(entity: model.Base, auth_user: Optional[model.User]) -> None:
|
||||||
assert entity
|
assert entity
|
||||||
|
|
||||||
model = next(
|
table = next(
|
||||||
(
|
(
|
||||||
model
|
cls
|
||||||
for model in db.Base._decl_class_registry.values()
|
for cls in model.Base._decl_class_registry.values()
|
||||||
if hasattr(model, '__table__')
|
if hasattr(cls, '__table__')
|
||||||
and model.__table__.fullname == entity.__table__.fullname
|
and cls.__table__.fullname == entity.__table__.fullname
|
||||||
),
|
),
|
||||||
None)
|
None)
|
||||||
assert model
|
assert table
|
||||||
|
|
||||||
snapshot = _create(db.Snapshot.OPERATION_MODIFIED, entity, auth_user)
|
snapshot = _create(model.Snapshot.OPERATION_MODIFIED, entity, auth_user)
|
||||||
snapshot_factory = _snapshot_factories[snapshot.resource_type]
|
snapshot_factory = _snapshot_factories[snapshot.resource_type]
|
||||||
|
|
||||||
detached_session = db.sessionmaker()
|
detached_session = db.sessionmaker()
|
||||||
detached_entity = detached_session.query(model).get(snapshot.resource_pkey)
|
detached_entity = detached_session.query(table).get(snapshot.resource_pkey)
|
||||||
assert detached_entity, 'Entity not found in DB, have you committed it?'
|
assert detached_entity, 'Entity not found in DB, have you committed it?'
|
||||||
detached_snapshot = snapshot_factory(detached_entity)
|
detached_snapshot = snapshot_factory(detached_entity)
|
||||||
detached_session.close()
|
detached_session.close()
|
||||||
|
@ -113,19 +118,23 @@ def modify(entity, auth_user):
|
||||||
db.session.add(snapshot)
|
db.session.add(snapshot)
|
||||||
|
|
||||||
|
|
||||||
def delete(entity, auth_user):
|
def delete(entity: model.Base, auth_user: Optional[model.User]) -> None:
|
||||||
assert entity
|
assert entity
|
||||||
snapshot = _create(db.Snapshot.OPERATION_DELETED, entity, auth_user)
|
snapshot = _create(model.Snapshot.OPERATION_DELETED, entity, auth_user)
|
||||||
snapshot_factory = _snapshot_factories[snapshot.resource_type]
|
snapshot_factory = _snapshot_factories[snapshot.resource_type]
|
||||||
snapshot.data = snapshot_factory(entity)
|
snapshot.data = snapshot_factory(entity)
|
||||||
db.session.add(snapshot)
|
db.session.add(snapshot)
|
||||||
|
|
||||||
|
|
||||||
def merge(source_entity, target_entity, auth_user):
|
def merge(
|
||||||
|
source_entity: model.Base,
|
||||||
|
target_entity: model.Base,
|
||||||
|
auth_user: Optional[model.User]) -> None:
|
||||||
assert source_entity
|
assert source_entity
|
||||||
assert target_entity
|
assert target_entity
|
||||||
snapshot = _create(db.Snapshot.OPERATION_MERGED, source_entity, auth_user)
|
snapshot = _create(
|
||||||
|
model.Snapshot.OPERATION_MERGED, source_entity, auth_user)
|
||||||
resource_type, _resource_pkey, resource_name = (
|
resource_type, _resource_pkey, resource_name = (
|
||||||
db.util.get_resource_info(target_entity))
|
model.util.get_resource_info(target_entity))
|
||||||
snapshot.data = [resource_type, resource_name]
|
snapshot.data = [resource_type, resource_name]
|
||||||
db.session.add(snapshot)
|
db.session.add(snapshot)
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
import re
|
import re
|
||||||
import sqlalchemy
|
from typing import Any, Optional, Dict, List, Callable
|
||||||
from szurubooru import config, db, errors
|
import sqlalchemy as sa
|
||||||
from szurubooru.func import util, cache
|
from szurubooru import config, db, model, errors, rest
|
||||||
|
from szurubooru.func import util, serialization, cache
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CATEGORY_NAME_CACHE_KEY = 'default-tag-category'
|
DEFAULT_CATEGORY_NAME_CACHE_KEY = 'default-tag-category'
|
||||||
|
@ -27,28 +28,52 @@ class InvalidTagCategoryColorError(errors.ValidationError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _verify_name_validity(name):
|
def _verify_name_validity(name: str) -> None:
|
||||||
name_regex = config.config['tag_category_name_regex']
|
name_regex = config.config['tag_category_name_regex']
|
||||||
if not re.match(name_regex, name):
|
if not re.match(name_regex, name):
|
||||||
raise InvalidTagCategoryNameError(
|
raise InvalidTagCategoryNameError(
|
||||||
'Name must satisfy regex %r.' % name_regex)
|
'Name must satisfy regex %r.' % name_regex)
|
||||||
|
|
||||||
|
|
||||||
def serialize_category(category, options=None):
|
class TagCategorySerializer(serialization.BaseSerializer):
|
||||||
return util.serialize_entity(
|
def __init__(self, category: model.TagCategory) -> None:
|
||||||
category,
|
self.category = category
|
||||||
{
|
|
||||||
'name': lambda: category.name,
|
def _serializers(self) -> Dict[str, Callable[[], Any]]:
|
||||||
'version': lambda: category.version,
|
return {
|
||||||
'color': lambda: category.color,
|
'name': self.serialize_name,
|
||||||
'usages': lambda: category.tag_count,
|
'version': self.serialize_version,
|
||||||
'default': lambda: category.default,
|
'color': self.serialize_color,
|
||||||
},
|
'usages': self.serialize_usages,
|
||||||
options)
|
'default': self.serialize_default,
|
||||||
|
}
|
||||||
|
|
||||||
|
def serialize_name(self) -> Any:
|
||||||
|
return self.category.name
|
||||||
|
|
||||||
|
def serialize_version(self) -> Any:
|
||||||
|
return self.category.version
|
||||||
|
|
||||||
|
def serialize_color(self) -> Any:
|
||||||
|
return self.category.color
|
||||||
|
|
||||||
|
def serialize_usages(self) -> Any:
|
||||||
|
return self.category.tag_count
|
||||||
|
|
||||||
|
def serialize_default(self) -> Any:
|
||||||
|
return self.category.default
|
||||||
|
|
||||||
|
|
||||||
def create_category(name, color):
|
def serialize_category(
|
||||||
category = db.TagCategory()
|
category: Optional[model.TagCategory],
|
||||||
|
options: List[str]=[]) -> Optional[rest.Response]:
|
||||||
|
if not category:
|
||||||
|
return None
|
||||||
|
return TagCategorySerializer(category).serialize(options)
|
||||||
|
|
||||||
|
|
||||||
|
def create_category(name: str, color: str) -> model.TagCategory:
|
||||||
|
category = model.TagCategory()
|
||||||
update_category_name(category, name)
|
update_category_name(category, name)
|
||||||
update_category_color(category, color)
|
update_category_color(category, color)
|
||||||
if not get_all_categories():
|
if not get_all_categories():
|
||||||
|
@ -56,64 +81,66 @@ def create_category(name, color):
|
||||||
return category
|
return category
|
||||||
|
|
||||||
|
|
||||||
def update_category_name(category, name):
|
def update_category_name(category: model.TagCategory, name: str) -> None:
|
||||||
assert category
|
assert category
|
||||||
if not name:
|
if not name:
|
||||||
raise InvalidTagCategoryNameError('Name cannot be empty.')
|
raise InvalidTagCategoryNameError('Name cannot be empty.')
|
||||||
expr = sqlalchemy.func.lower(db.TagCategory.name) == name.lower()
|
expr = sa.func.lower(model.TagCategory.name) == name.lower()
|
||||||
if category.tag_category_id:
|
if category.tag_category_id:
|
||||||
expr = expr & (
|
expr = expr & (
|
||||||
db.TagCategory.tag_category_id != category.tag_category_id)
|
model.TagCategory.tag_category_id != category.tag_category_id)
|
||||||
already_exists = db.session.query(db.TagCategory).filter(expr).count() > 0
|
already_exists = (
|
||||||
|
db.session.query(model.TagCategory).filter(expr).count() > 0)
|
||||||
if already_exists:
|
if already_exists:
|
||||||
raise TagCategoryAlreadyExistsError(
|
raise TagCategoryAlreadyExistsError(
|
||||||
'A category with this name already exists.')
|
'A category with this name already exists.')
|
||||||
if util.value_exceeds_column_size(name, db.TagCategory.name):
|
if util.value_exceeds_column_size(name, model.TagCategory.name):
|
||||||
raise InvalidTagCategoryNameError('Name is too long.')
|
raise InvalidTagCategoryNameError('Name is too long.')
|
||||||
_verify_name_validity(name)
|
_verify_name_validity(name)
|
||||||
category.name = name
|
category.name = name
|
||||||
cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY)
|
cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY)
|
||||||
|
|
||||||
|
|
||||||
def update_category_color(category, color):
|
def update_category_color(category: model.TagCategory, color: str) -> None:
|
||||||
assert category
|
assert category
|
||||||
if not color:
|
if not color:
|
||||||
raise InvalidTagCategoryColorError('Color cannot be empty.')
|
raise InvalidTagCategoryColorError('Color cannot be empty.')
|
||||||
if not re.match(r'^#?[0-9a-z]+$', color):
|
if not re.match(r'^#?[0-9a-z]+$', color):
|
||||||
raise InvalidTagCategoryColorError('Invalid color.')
|
raise InvalidTagCategoryColorError('Invalid color.')
|
||||||
if util.value_exceeds_column_size(color, db.TagCategory.color):
|
if util.value_exceeds_column_size(color, model.TagCategory.color):
|
||||||
raise InvalidTagCategoryColorError('Color is too long.')
|
raise InvalidTagCategoryColorError('Color is too long.')
|
||||||
category.color = color
|
category.color = color
|
||||||
|
|
||||||
|
|
||||||
def try_get_category_by_name(name, lock=False):
|
def try_get_category_by_name(
|
||||||
|
name: str, lock: bool=False) -> Optional[model.TagCategory]:
|
||||||
query = db.session \
|
query = db.session \
|
||||||
.query(db.TagCategory) \
|
.query(model.TagCategory) \
|
||||||
.filter(sqlalchemy.func.lower(db.TagCategory.name) == name.lower())
|
.filter(sa.func.lower(model.TagCategory.name) == name.lower())
|
||||||
if lock:
|
if lock:
|
||||||
query = query.with_lockmode('update')
|
query = query.with_lockmode('update')
|
||||||
return query.one_or_none()
|
return query.one_or_none()
|
||||||
|
|
||||||
|
|
||||||
def get_category_by_name(name, lock=False):
|
def get_category_by_name(name: str, lock: bool=False) -> model.TagCategory:
|
||||||
category = try_get_category_by_name(name, lock)
|
category = try_get_category_by_name(name, lock)
|
||||||
if not category:
|
if not category:
|
||||||
raise TagCategoryNotFoundError('Tag category %r not found.' % name)
|
raise TagCategoryNotFoundError('Tag category %r not found.' % name)
|
||||||
return category
|
return category
|
||||||
|
|
||||||
|
|
||||||
def get_all_category_names():
|
def get_all_category_names() -> List[str]:
|
||||||
return [row[0] for row in db.session.query(db.TagCategory.name).all()]
|
return [row[0] for row in db.session.query(model.TagCategory.name).all()]
|
||||||
|
|
||||||
|
|
||||||
def get_all_categories():
|
def get_all_categories() -> List[model.TagCategory]:
|
||||||
return db.session.query(db.TagCategory).all()
|
return db.session.query(model.TagCategory).all()
|
||||||
|
|
||||||
|
|
||||||
def try_get_default_category(lock=False):
|
def try_get_default_category(lock: bool=False) -> Optional[model.TagCategory]:
|
||||||
query = db.session \
|
query = db.session \
|
||||||
.query(db.TagCategory) \
|
.query(model.TagCategory) \
|
||||||
.filter(db.TagCategory.default)
|
.filter(model.TagCategory.default)
|
||||||
if lock:
|
if lock:
|
||||||
query = query.with_lockmode('update')
|
query = query.with_lockmode('update')
|
||||||
category = query.first()
|
category = query.first()
|
||||||
|
@ -121,22 +148,22 @@ def try_get_default_category(lock=False):
|
||||||
# category, get the first record available.
|
# category, get the first record available.
|
||||||
if not category:
|
if not category:
|
||||||
query = db.session \
|
query = db.session \
|
||||||
.query(db.TagCategory) \
|
.query(model.TagCategory) \
|
||||||
.order_by(db.TagCategory.tag_category_id.asc())
|
.order_by(model.TagCategory.tag_category_id.asc())
|
||||||
if lock:
|
if lock:
|
||||||
query = query.with_lockmode('update')
|
query = query.with_lockmode('update')
|
||||||
category = query.first()
|
category = query.first()
|
||||||
return category
|
return category
|
||||||
|
|
||||||
|
|
||||||
def get_default_category(lock=False):
|
def get_default_category(lock: bool=False) -> model.TagCategory:
|
||||||
category = try_get_default_category(lock)
|
category = try_get_default_category(lock)
|
||||||
if not category:
|
if not category:
|
||||||
raise TagCategoryNotFoundError('No tag category created yet.')
|
raise TagCategoryNotFoundError('No tag category created yet.')
|
||||||
return category
|
return category
|
||||||
|
|
||||||
|
|
||||||
def get_default_category_name():
|
def get_default_category_name() -> str:
|
||||||
if cache.has(DEFAULT_CATEGORY_NAME_CACHE_KEY):
|
if cache.has(DEFAULT_CATEGORY_NAME_CACHE_KEY):
|
||||||
return cache.get(DEFAULT_CATEGORY_NAME_CACHE_KEY)
|
return cache.get(DEFAULT_CATEGORY_NAME_CACHE_KEY)
|
||||||
default_category = get_default_category()
|
default_category = get_default_category()
|
||||||
|
@ -145,7 +172,7 @@ def get_default_category_name():
|
||||||
return default_category_name
|
return default_category_name
|
||||||
|
|
||||||
|
|
||||||
def set_default_category(category):
|
def set_default_category(category: model.TagCategory) -> None:
|
||||||
assert category
|
assert category
|
||||||
old_category = try_get_default_category(lock=True)
|
old_category = try_get_default_category(lock=True)
|
||||||
if old_category:
|
if old_category:
|
||||||
|
@ -156,7 +183,7 @@ def set_default_category(category):
|
||||||
cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY)
|
cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY)
|
||||||
|
|
||||||
|
|
||||||
def delete_category(category):
|
def delete_category(category: model.TagCategory) -> None:
|
||||||
assert category
|
assert category
|
||||||
if len(get_all_category_names()) == 1:
|
if len(get_all_category_names()) == 1:
|
||||||
raise TagCategoryIsInUseError('Cannot delete the last category.')
|
raise TagCategoryIsInUseError('Cannot delete the last category.')
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
import datetime
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import sqlalchemy
|
from typing import Any, Optional, Tuple, List, Dict, Callable
|
||||||
from szurubooru import config, db, errors
|
from datetime import datetime
|
||||||
from szurubooru.func import util, tag_categories
|
import sqlalchemy as sa
|
||||||
|
from szurubooru import config, db, model, errors, rest
|
||||||
|
from szurubooru.func import util, tag_categories, serialization
|
||||||
|
|
||||||
|
|
||||||
class TagNotFoundError(errors.NotFoundError):
|
class TagNotFoundError(errors.NotFoundError):
|
||||||
|
@ -35,31 +36,32 @@ class InvalidTagDescriptionError(errors.ValidationError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _verify_name_validity(name):
|
def _verify_name_validity(name: str) -> None:
|
||||||
if util.value_exceeds_column_size(name, db.TagName.name):
|
if util.value_exceeds_column_size(name, model.TagName.name):
|
||||||
raise InvalidTagNameError('Name is too long.')
|
raise InvalidTagNameError('Name is too long.')
|
||||||
name_regex = config.config['tag_name_regex']
|
name_regex = config.config['tag_name_regex']
|
||||||
if not re.match(name_regex, name):
|
if not re.match(name_regex, name):
|
||||||
raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex)
|
raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex)
|
||||||
|
|
||||||
|
|
||||||
def _get_names(tag):
|
def _get_names(tag: model.Tag) -> List[str]:
|
||||||
assert tag
|
assert tag
|
||||||
return [tag_name.name for tag_name in tag.names]
|
return [tag_name.name for tag_name in tag.names]
|
||||||
|
|
||||||
|
|
||||||
def _lower_list(names):
|
def _lower_list(names: List[str]) -> List[str]:
|
||||||
return [name.lower() for name in names]
|
return [name.lower() for name in names]
|
||||||
|
|
||||||
|
|
||||||
def _check_name_intersection(names1, names2, case_sensitive):
|
def _check_name_intersection(
|
||||||
|
names1: List[str], names2: List[str], case_sensitive: bool) -> bool:
|
||||||
if not case_sensitive:
|
if not case_sensitive:
|
||||||
names1 = _lower_list(names1)
|
names1 = _lower_list(names1)
|
||||||
names2 = _lower_list(names2)
|
names2 = _lower_list(names2)
|
||||||
return len(set(names1).intersection(names2)) > 0
|
return len(set(names1).intersection(names2)) > 0
|
||||||
|
|
||||||
|
|
||||||
def sort_tags(tags):
|
def sort_tags(tags: List[model.Tag]) -> List[model.Tag]:
|
||||||
default_category_name = tag_categories.get_default_category_name()
|
default_category_name = tag_categories.get_default_category_name()
|
||||||
return sorted(
|
return sorted(
|
||||||
tags,
|
tags,
|
||||||
|
@ -70,35 +72,70 @@ def sort_tags(tags):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def serialize_tag(tag, options=None):
|
class TagSerializer(serialization.BaseSerializer):
|
||||||
return util.serialize_entity(
|
def __init__(self, tag: model.Tag) -> None:
|
||||||
tag,
|
self.tag = tag
|
||||||
{
|
|
||||||
'names': lambda: [tag_name.name for tag_name in tag.names],
|
def _serializers(self) -> Dict[str, Callable[[], Any]]:
|
||||||
'category': lambda: tag.category.name,
|
return {
|
||||||
'version': lambda: tag.version,
|
'names': self.serialize_names,
|
||||||
'description': lambda: tag.description,
|
'category': self.serialize_category,
|
||||||
'creationTime': lambda: tag.creation_time,
|
'version': self.serialize_version,
|
||||||
'lastEditTime': lambda: tag.last_edit_time,
|
'description': self.serialize_description,
|
||||||
'usages': lambda: tag.post_count,
|
'creationTime': self.serialize_creation_time,
|
||||||
'suggestions': lambda: [
|
'lastEditTime': self.serialize_last_edit_time,
|
||||||
|
'usages': self.serialize_usages,
|
||||||
|
'suggestions': self.serialize_suggestions,
|
||||||
|
'implications': self.serialize_implications,
|
||||||
|
}
|
||||||
|
|
||||||
|
def serialize_names(self) -> Any:
|
||||||
|
return [tag_name.name for tag_name in self.tag.names]
|
||||||
|
|
||||||
|
def serialize_category(self) -> Any:
|
||||||
|
return self.tag.category.name
|
||||||
|
|
||||||
|
def serialize_version(self) -> Any:
|
||||||
|
return self.tag.version
|
||||||
|
|
||||||
|
def serialize_description(self) -> Any:
|
||||||
|
return self.tag.description
|
||||||
|
|
||||||
|
def serialize_creation_time(self) -> Any:
|
||||||
|
return self.tag.creation_time
|
||||||
|
|
||||||
|
def serialize_last_edit_time(self) -> Any:
|
||||||
|
return self.tag.last_edit_time
|
||||||
|
|
||||||
|
def serialize_usages(self) -> Any:
|
||||||
|
return self.tag.post_count
|
||||||
|
|
||||||
|
def serialize_suggestions(self) -> Any:
|
||||||
|
return [
|
||||||
relation.names[0].name
|
relation.names[0].name
|
||||||
for relation in sort_tags(tag.suggestions)],
|
for relation in sort_tags(self.tag.suggestions)]
|
||||||
'implications': lambda: [
|
|
||||||
|
def serialize_implications(self) -> Any:
|
||||||
|
return [
|
||||||
relation.names[0].name
|
relation.names[0].name
|
||||||
for relation in sort_tags(tag.implications)],
|
for relation in sort_tags(self.tag.implications)]
|
||||||
},
|
|
||||||
options)
|
|
||||||
|
|
||||||
|
|
||||||
def export_to_json():
|
def serialize_tag(
|
||||||
tags = {}
|
tag: model.Tag, options: List[str]=[]) -> Optional[rest.Response]:
|
||||||
categories = {}
|
if not tag:
|
||||||
|
return None
|
||||||
|
return TagSerializer(tag).serialize(options)
|
||||||
|
|
||||||
|
|
||||||
|
def export_to_json() -> None:
|
||||||
|
tags = {} # type: Dict[int, Any]
|
||||||
|
categories = {} # type: Dict[int, Any]
|
||||||
|
|
||||||
for result in db.session.query(
|
for result in db.session.query(
|
||||||
db.TagCategory.tag_category_id,
|
model.TagCategory.tag_category_id,
|
||||||
db.TagCategory.name,
|
model.TagCategory.name,
|
||||||
db.TagCategory.color).all():
|
model.TagCategory.color).all():
|
||||||
categories[result[0]] = {
|
categories[result[0]] = {
|
||||||
'name': result[1],
|
'name': result[1],
|
||||||
'color': result[2],
|
'color': result[2],
|
||||||
|
@ -106,8 +143,8 @@ def export_to_json():
|
||||||
|
|
||||||
for result in (
|
for result in (
|
||||||
db.session
|
db.session
|
||||||
.query(db.TagName.tag_id, db.TagName.name)
|
.query(model.TagName.tag_id, model.TagName.name)
|
||||||
.order_by(db.TagName.order)
|
.order_by(model.TagName.order)
|
||||||
.all()):
|
.all()):
|
||||||
if not result[0] in tags:
|
if not result[0] in tags:
|
||||||
tags[result[0]] = {'names': []}
|
tags[result[0]] = {'names': []}
|
||||||
|
@ -115,8 +152,10 @@ def export_to_json():
|
||||||
|
|
||||||
for result in (
|
for result in (
|
||||||
db.session
|
db.session
|
||||||
.query(db.TagSuggestion.parent_id, db.TagName.name)
|
.query(model.TagSuggestion.parent_id, model.TagName.name)
|
||||||
.join(db.TagName, db.TagName.tag_id == db.TagSuggestion.child_id)
|
.join(
|
||||||
|
model.TagName,
|
||||||
|
model.TagName.tag_id == model.TagSuggestion.child_id)
|
||||||
.all()):
|
.all()):
|
||||||
if 'suggestions' not in tags[result[0]]:
|
if 'suggestions' not in tags[result[0]]:
|
||||||
tags[result[0]]['suggestions'] = []
|
tags[result[0]]['suggestions'] = []
|
||||||
|
@ -124,17 +163,19 @@ def export_to_json():
|
||||||
|
|
||||||
for result in (
|
for result in (
|
||||||
db.session
|
db.session
|
||||||
.query(db.TagImplication.parent_id, db.TagName.name)
|
.query(model.TagImplication.parent_id, model.TagName.name)
|
||||||
.join(db.TagName, db.TagName.tag_id == db.TagImplication.child_id)
|
.join(
|
||||||
|
model.TagName,
|
||||||
|
model.TagName.tag_id == model.TagImplication.child_id)
|
||||||
.all()):
|
.all()):
|
||||||
if 'implications' not in tags[result[0]]:
|
if 'implications' not in tags[result[0]]:
|
||||||
tags[result[0]]['implications'] = []
|
tags[result[0]]['implications'] = []
|
||||||
tags[result[0]]['implications'].append(result[1])
|
tags[result[0]]['implications'].append(result[1])
|
||||||
|
|
||||||
for result in db.session.query(
|
for result in db.session.query(
|
||||||
db.Tag.tag_id,
|
model.Tag.tag_id,
|
||||||
db.Tag.category_id,
|
model.Tag.category_id,
|
||||||
db.Tag.post_count).all():
|
model.Tag.post_count).all():
|
||||||
tags[result[0]]['category'] = categories[result[1]]['name']
|
tags[result[0]]['category'] = categories[result[1]]['name']
|
||||||
tags[result[0]]['usages'] = result[2]
|
tags[result[0]]['usages'] = result[2]
|
||||||
|
|
||||||
|
@ -148,33 +189,34 @@ def export_to_json():
|
||||||
handle.write(json.dumps(output, separators=(',', ':')))
|
handle.write(json.dumps(output, separators=(',', ':')))
|
||||||
|
|
||||||
|
|
||||||
def try_get_tag_by_name(name):
|
def try_get_tag_by_name(name: str) -> Optional[model.Tag]:
|
||||||
return (
|
return (
|
||||||
db.session
|
db.session
|
||||||
.query(db.Tag)
|
.query(model.Tag)
|
||||||
.join(db.TagName)
|
.join(model.TagName)
|
||||||
.filter(sqlalchemy.func.lower(db.TagName.name) == name.lower())
|
.filter(sa.func.lower(model.TagName.name) == name.lower())
|
||||||
.one_or_none())
|
.one_or_none())
|
||||||
|
|
||||||
|
|
||||||
def get_tag_by_name(name):
|
def get_tag_by_name(name: str) -> model.Tag:
|
||||||
tag = try_get_tag_by_name(name)
|
tag = try_get_tag_by_name(name)
|
||||||
if not tag:
|
if not tag:
|
||||||
raise TagNotFoundError('Tag %r not found.' % name)
|
raise TagNotFoundError('Tag %r not found.' % name)
|
||||||
return tag
|
return tag
|
||||||
|
|
||||||
|
|
||||||
def get_tags_by_names(names):
|
def get_tags_by_names(names: List[str]) -> List[model.Tag]:
|
||||||
names = util.icase_unique(names)
|
names = util.icase_unique(names)
|
||||||
if len(names) == 0:
|
if len(names) == 0:
|
||||||
return []
|
return []
|
||||||
expr = sqlalchemy.sql.false()
|
expr = sa.sql.false()
|
||||||
for name in names:
|
for name in names:
|
||||||
expr = expr | (sqlalchemy.func.lower(db.TagName.name) == name.lower())
|
expr = expr | (sa.func.lower(model.TagName.name) == name.lower())
|
||||||
return db.session.query(db.Tag).join(db.TagName).filter(expr).all()
|
return db.session.query(model.Tag).join(model.TagName).filter(expr).all()
|
||||||
|
|
||||||
|
|
||||||
def get_or_create_tags_by_names(names):
|
def get_or_create_tags_by_names(
|
||||||
|
names: List[str]) -> Tuple[List[model.Tag], List[model.Tag]]:
|
||||||
names = util.icase_unique(names)
|
names = util.icase_unique(names)
|
||||||
existing_tags = get_tags_by_names(names)
|
existing_tags = get_tags_by_names(names)
|
||||||
new_tags = []
|
new_tags = []
|
||||||
|
@ -197,86 +239,87 @@ def get_or_create_tags_by_names(names):
|
||||||
return existing_tags, new_tags
|
return existing_tags, new_tags
|
||||||
|
|
||||||
|
|
||||||
def get_tag_siblings(tag):
|
def get_tag_siblings(tag: model.Tag) -> List[model.Tag]:
|
||||||
assert tag
|
assert tag
|
||||||
tag_alias = sqlalchemy.orm.aliased(db.Tag)
|
tag_alias = sa.orm.aliased(model.Tag)
|
||||||
pt_alias1 = sqlalchemy.orm.aliased(db.PostTag)
|
pt_alias1 = sa.orm.aliased(model.PostTag)
|
||||||
pt_alias2 = sqlalchemy.orm.aliased(db.PostTag)
|
pt_alias2 = sa.orm.aliased(model.PostTag)
|
||||||
result = (
|
result = (
|
||||||
db.session
|
db.session
|
||||||
.query(tag_alias, sqlalchemy.func.count(pt_alias2.post_id))
|
.query(tag_alias, sa.func.count(pt_alias2.post_id))
|
||||||
.join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id)
|
.join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id)
|
||||||
.join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id)
|
.join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id)
|
||||||
.filter(pt_alias2.tag_id == tag.tag_id)
|
.filter(pt_alias2.tag_id == tag.tag_id)
|
||||||
.filter(pt_alias1.tag_id != tag.tag_id)
|
.filter(pt_alias1.tag_id != tag.tag_id)
|
||||||
.group_by(tag_alias.tag_id)
|
.group_by(tag_alias.tag_id)
|
||||||
.order_by(sqlalchemy.func.count(pt_alias2.post_id).desc())
|
.order_by(sa.func.count(pt_alias2.post_id).desc())
|
||||||
.limit(50))
|
.limit(50))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def delete(source_tag):
|
def delete(source_tag: model.Tag) -> None:
|
||||||
assert source_tag
|
assert source_tag
|
||||||
db.session.execute(
|
db.session.execute(
|
||||||
sqlalchemy.sql.expression.delete(db.TagSuggestion)
|
sa.sql.expression.delete(model.TagSuggestion)
|
||||||
.where(db.TagSuggestion.child_id == source_tag.tag_id))
|
.where(model.TagSuggestion.child_id == source_tag.tag_id))
|
||||||
db.session.execute(
|
db.session.execute(
|
||||||
sqlalchemy.sql.expression.delete(db.TagImplication)
|
sa.sql.expression.delete(model.TagImplication)
|
||||||
.where(db.TagImplication.child_id == source_tag.tag_id))
|
.where(model.TagImplication.child_id == source_tag.tag_id))
|
||||||
db.session.delete(source_tag)
|
db.session.delete(source_tag)
|
||||||
|
|
||||||
|
|
||||||
def merge_tags(source_tag, target_tag):
|
def merge_tags(source_tag: model.Tag, target_tag: model.Tag) -> None:
|
||||||
assert source_tag
|
assert source_tag
|
||||||
assert target_tag
|
assert target_tag
|
||||||
if source_tag.tag_id == target_tag.tag_id:
|
if source_tag.tag_id == target_tag.tag_id:
|
||||||
raise InvalidTagRelationError('Cannot merge tag with itself.')
|
raise InvalidTagRelationError('Cannot merge tag with itself.')
|
||||||
|
|
||||||
def merge_posts(source_tag_id, target_tag_id):
|
def merge_posts(source_tag_id: int, target_tag_id: int) -> None:
|
||||||
alias1 = db.PostTag
|
alias1 = model.PostTag
|
||||||
alias2 = sqlalchemy.orm.util.aliased(db.PostTag)
|
alias2 = sa.orm.util.aliased(model.PostTag)
|
||||||
update_stmt = (
|
update_stmt = (
|
||||||
sqlalchemy.sql.expression.update(alias1)
|
sa.sql.expression.update(alias1)
|
||||||
.where(alias1.tag_id == source_tag_id))
|
.where(alias1.tag_id == source_tag_id))
|
||||||
update_stmt = (
|
update_stmt = (
|
||||||
update_stmt
|
update_stmt
|
||||||
.where(
|
.where(
|
||||||
~sqlalchemy.exists()
|
~sa.exists()
|
||||||
.where(alias1.post_id == alias2.post_id)
|
.where(alias1.post_id == alias2.post_id)
|
||||||
.where(alias2.tag_id == target_tag_id)))
|
.where(alias2.tag_id == target_tag_id)))
|
||||||
update_stmt = update_stmt.values(tag_id=target_tag_id)
|
update_stmt = update_stmt.values(tag_id=target_tag_id)
|
||||||
db.session.execute(update_stmt)
|
db.session.execute(update_stmt)
|
||||||
|
|
||||||
def merge_relations(table, source_tag_id, target_tag_id):
|
def merge_relations(
|
||||||
|
table: model.Base, source_tag_id: int, target_tag_id: int) -> None:
|
||||||
alias1 = table
|
alias1 = table
|
||||||
alias2 = sqlalchemy.orm.util.aliased(table)
|
alias2 = sa.orm.util.aliased(table)
|
||||||
update_stmt = (
|
update_stmt = (
|
||||||
sqlalchemy.sql.expression.update(alias1)
|
sa.sql.expression.update(alias1)
|
||||||
.where(alias1.parent_id == source_tag_id)
|
.where(alias1.parent_id == source_tag_id)
|
||||||
.where(alias1.child_id != target_tag_id)
|
.where(alias1.child_id != target_tag_id)
|
||||||
.where(
|
.where(
|
||||||
~sqlalchemy.exists()
|
~sa.exists()
|
||||||
.where(alias2.child_id == alias1.child_id)
|
.where(alias2.child_id == alias1.child_id)
|
||||||
.where(alias2.parent_id == target_tag_id))
|
.where(alias2.parent_id == target_tag_id))
|
||||||
.values(parent_id=target_tag_id))
|
.values(parent_id=target_tag_id))
|
||||||
db.session.execute(update_stmt)
|
db.session.execute(update_stmt)
|
||||||
|
|
||||||
update_stmt = (
|
update_stmt = (
|
||||||
sqlalchemy.sql.expression.update(alias1)
|
sa.sql.expression.update(alias1)
|
||||||
.where(alias1.child_id == source_tag_id)
|
.where(alias1.child_id == source_tag_id)
|
||||||
.where(alias1.parent_id != target_tag_id)
|
.where(alias1.parent_id != target_tag_id)
|
||||||
.where(
|
.where(
|
||||||
~sqlalchemy.exists()
|
~sa.exists()
|
||||||
.where(alias2.parent_id == alias1.parent_id)
|
.where(alias2.parent_id == alias1.parent_id)
|
||||||
.where(alias2.child_id == target_tag_id))
|
.where(alias2.child_id == target_tag_id))
|
||||||
.values(child_id=target_tag_id))
|
.values(child_id=target_tag_id))
|
||||||
db.session.execute(update_stmt)
|
db.session.execute(update_stmt)
|
||||||
|
|
||||||
def merge_suggestions(source_tag_id, target_tag_id):
|
def merge_suggestions(source_tag_id: int, target_tag_id: int) -> None:
|
||||||
merge_relations(db.TagSuggestion, source_tag_id, target_tag_id)
|
merge_relations(model.TagSuggestion, source_tag_id, target_tag_id)
|
||||||
|
|
||||||
def merge_implications(source_tag_id, target_tag_id):
|
def merge_implications(source_tag_id: int, target_tag_id: int) -> None:
|
||||||
merge_relations(db.TagImplication, source_tag_id, target_tag_id)
|
merge_relations(model.TagImplication, source_tag_id, target_tag_id)
|
||||||
|
|
||||||
merge_posts(source_tag.tag_id, target_tag.tag_id)
|
merge_posts(source_tag.tag_id, target_tag.tag_id)
|
||||||
merge_suggestions(source_tag.tag_id, target_tag.tag_id)
|
merge_suggestions(source_tag.tag_id, target_tag.tag_id)
|
||||||
|
@ -284,9 +327,13 @@ def merge_tags(source_tag, target_tag):
|
||||||
delete(source_tag)
|
delete(source_tag)
|
||||||
|
|
||||||
|
|
||||||
def create_tag(names, category_name, suggestions, implications):
|
def create_tag(
|
||||||
tag = db.Tag()
|
names: List[str],
|
||||||
tag.creation_time = datetime.datetime.utcnow()
|
category_name: str,
|
||||||
|
suggestions: List[str],
|
||||||
|
implications: List[str]) -> model.Tag:
|
||||||
|
tag = model.Tag()
|
||||||
|
tag.creation_time = datetime.utcnow()
|
||||||
update_tag_names(tag, names)
|
update_tag_names(tag, names)
|
||||||
update_tag_category_name(tag, category_name)
|
update_tag_category_name(tag, category_name)
|
||||||
update_tag_suggestions(tag, suggestions)
|
update_tag_suggestions(tag, suggestions)
|
||||||
|
@ -294,12 +341,12 @@ def create_tag(names, category_name, suggestions, implications):
|
||||||
return tag
|
return tag
|
||||||
|
|
||||||
|
|
||||||
def update_tag_category_name(tag, category_name):
|
def update_tag_category_name(tag: model.Tag, category_name: str) -> None:
|
||||||
assert tag
|
assert tag
|
||||||
tag.category = tag_categories.get_category_by_name(category_name)
|
tag.category = tag_categories.get_category_by_name(category_name)
|
||||||
|
|
||||||
|
|
||||||
def update_tag_names(tag, names):
|
def update_tag_names(tag: model.Tag, names: List[str]) -> None:
|
||||||
# sanitize
|
# sanitize
|
||||||
assert tag
|
assert tag
|
||||||
names = util.icase_unique([name for name in names if name])
|
names = util.icase_unique([name for name in names if name])
|
||||||
|
@ -309,12 +356,12 @@ def update_tag_names(tag, names):
|
||||||
_verify_name_validity(name)
|
_verify_name_validity(name)
|
||||||
|
|
||||||
# check for existing tags
|
# check for existing tags
|
||||||
expr = sqlalchemy.sql.false()
|
expr = sa.sql.false()
|
||||||
for name in names:
|
for name in names:
|
||||||
expr = expr | (sqlalchemy.func.lower(db.TagName.name) == name.lower())
|
expr = expr | (sa.func.lower(model.TagName.name) == name.lower())
|
||||||
if tag.tag_id:
|
if tag.tag_id:
|
||||||
expr = expr & (db.TagName.tag_id != tag.tag_id)
|
expr = expr & (model.TagName.tag_id != tag.tag_id)
|
||||||
existing_tags = db.session.query(db.TagName).filter(expr).all()
|
existing_tags = db.session.query(model.TagName).filter(expr).all()
|
||||||
if len(existing_tags):
|
if len(existing_tags):
|
||||||
raise TagAlreadyExistsError(
|
raise TagAlreadyExistsError(
|
||||||
'One of names is already used by another tag.')
|
'One of names is already used by another tag.')
|
||||||
|
@ -326,7 +373,7 @@ def update_tag_names(tag, names):
|
||||||
# add wanted items
|
# add wanted items
|
||||||
for name in names:
|
for name in names:
|
||||||
if not _check_name_intersection(_get_names(tag), [name], True):
|
if not _check_name_intersection(_get_names(tag), [name], True):
|
||||||
tag.names.append(db.TagName(name, None))
|
tag.names.append(model.TagName(name, -1))
|
||||||
|
|
||||||
# set alias order to match the request
|
# set alias order to match the request
|
||||||
for i, name in enumerate(names):
|
for i, name in enumerate(names):
|
||||||
|
@ -336,7 +383,7 @@ def update_tag_names(tag, names):
|
||||||
|
|
||||||
|
|
||||||
# TODO: what to do with relations that do not yet exist?
|
# TODO: what to do with relations that do not yet exist?
|
||||||
def update_tag_implications(tag, relations):
|
def update_tag_implications(tag: model.Tag, relations: List[str]) -> None:
|
||||||
assert tag
|
assert tag
|
||||||
if _check_name_intersection(_get_names(tag), relations, False):
|
if _check_name_intersection(_get_names(tag), relations, False):
|
||||||
raise InvalidTagRelationError('Tag cannot imply itself.')
|
raise InvalidTagRelationError('Tag cannot imply itself.')
|
||||||
|
@ -344,15 +391,15 @@ def update_tag_implications(tag, relations):
|
||||||
|
|
||||||
|
|
||||||
# TODO: what to do with relations that do not yet exist?
|
# TODO: what to do with relations that do not yet exist?
|
||||||
def update_tag_suggestions(tag, relations):
|
def update_tag_suggestions(tag: model.Tag, relations: List[str]) -> None:
|
||||||
assert tag
|
assert tag
|
||||||
if _check_name_intersection(_get_names(tag), relations, False):
|
if _check_name_intersection(_get_names(tag), relations, False):
|
||||||
raise InvalidTagRelationError('Tag cannot suggest itself.')
|
raise InvalidTagRelationError('Tag cannot suggest itself.')
|
||||||
tag.suggestions = get_tags_by_names(relations)
|
tag.suggestions = get_tags_by_names(relations)
|
||||||
|
|
||||||
|
|
||||||
def update_tag_description(tag, description):
|
def update_tag_description(tag: model.Tag, description: str) -> None:
|
||||||
assert tag
|
assert tag
|
||||||
if util.value_exceeds_column_size(description, db.Tag.description):
|
if util.value_exceeds_column_size(description, model.Tag.description):
|
||||||
raise InvalidTagDescriptionError('Description is too long.')
|
raise InvalidTagDescriptionError('Description is too long.')
|
||||||
tag.description = description
|
tag.description = description or None
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import datetime
|
|
||||||
import re
|
import re
|
||||||
from sqlalchemy import func
|
from typing import Any, Optional, Union, List, Dict, Callable
|
||||||
from szurubooru import config, db, errors
|
from datetime import datetime
|
||||||
from szurubooru.func import auth, util, files, images
|
import sqlalchemy as sa
|
||||||
|
from szurubooru import config, db, model, errors, rest
|
||||||
|
from szurubooru.func import auth, util, serialization, files, images
|
||||||
|
|
||||||
|
|
||||||
class UserNotFoundError(errors.NotFoundError):
|
class UserNotFoundError(errors.NotFoundError):
|
||||||
|
@ -33,11 +34,11 @@ class InvalidAvatarError(errors.ValidationError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def get_avatar_path(user_name):
|
def get_avatar_path(user_name: str) -> str:
|
||||||
return 'avatars/' + user_name.lower() + '.png'
|
return 'avatars/' + user_name.lower() + '.png'
|
||||||
|
|
||||||
|
|
||||||
def get_avatar_url(user):
|
def get_avatar_url(user: model.User) -> str:
|
||||||
assert user
|
assert user
|
||||||
if user.avatar_style == user.AVATAR_GRAVATAR:
|
if user.avatar_style == user.AVATAR_GRAVATAR:
|
||||||
assert user.email or user.name
|
assert user.email or user.name
|
||||||
|
@ -49,7 +50,10 @@ def get_avatar_url(user):
|
||||||
config.config['data_url'].rstrip('/'), user.name.lower())
|
config.config['data_url'].rstrip('/'), user.name.lower())
|
||||||
|
|
||||||
|
|
||||||
def get_email(user, auth_user, force_show_email):
|
def get_email(
|
||||||
|
user: model.User,
|
||||||
|
auth_user: model.User,
|
||||||
|
force_show_email: bool) -> Union[bool, str]:
|
||||||
assert user
|
assert user
|
||||||
assert auth_user
|
assert auth_user
|
||||||
if not force_show_email \
|
if not force_show_email \
|
||||||
|
@ -59,7 +63,8 @@ def get_email(user, auth_user, force_show_email):
|
||||||
return user.email
|
return user.email
|
||||||
|
|
||||||
|
|
||||||
def get_liked_post_count(user, auth_user):
|
def get_liked_post_count(
|
||||||
|
user: model.User, auth_user: model.User) -> Union[bool, int]:
|
||||||
assert user
|
assert user
|
||||||
assert auth_user
|
assert auth_user
|
||||||
if auth_user.user_id != user.user_id:
|
if auth_user.user_id != user.user_id:
|
||||||
|
@ -67,7 +72,8 @@ def get_liked_post_count(user, auth_user):
|
||||||
return user.liked_post_count
|
return user.liked_post_count
|
||||||
|
|
||||||
|
|
||||||
def get_disliked_post_count(user, auth_user):
|
def get_disliked_post_count(
|
||||||
|
user: model.User, auth_user: model.User) -> Union[bool, int]:
|
||||||
assert user
|
assert user
|
||||||
assert auth_user
|
assert auth_user
|
||||||
if auth_user.user_id != user.user_id:
|
if auth_user.user_id != user.user_id:
|
||||||
|
@ -75,91 +81,144 @@ def get_disliked_post_count(user, auth_user):
|
||||||
return user.disliked_post_count
|
return user.disliked_post_count
|
||||||
|
|
||||||
|
|
||||||
def serialize_user(user, auth_user, options=None, force_show_email=False):
|
class UserSerializer(serialization.BaseSerializer):
|
||||||
return util.serialize_entity(
|
def __init__(
|
||||||
user,
|
self,
|
||||||
{
|
user: model.User,
|
||||||
'name': lambda: user.name,
|
auth_user: model.User,
|
||||||
'creationTime': lambda: user.creation_time,
|
force_show_email: bool=False) -> None:
|
||||||
'lastLoginTime': lambda: user.last_login_time,
|
self.user = user
|
||||||
'version': lambda: user.version,
|
self.auth_user = auth_user
|
||||||
'rank': lambda: user.rank,
|
self.force_show_email = force_show_email
|
||||||
'avatarStyle': lambda: user.avatar_style,
|
|
||||||
'avatarUrl': lambda: get_avatar_url(user),
|
def _serializers(self) -> Dict[str, Callable[[], Any]]:
|
||||||
'commentCount': lambda: user.comment_count,
|
return {
|
||||||
'uploadedPostCount': lambda: user.post_count,
|
'name': self.serialize_name,
|
||||||
'favoritePostCount': lambda: user.favorite_post_count,
|
'creationTime': self.serialize_creation_time,
|
||||||
'likedPostCount':
|
'lastLoginTime': self.serialize_last_login_time,
|
||||||
lambda: get_liked_post_count(user, auth_user),
|
'version': self.serialize_version,
|
||||||
'dislikedPostCount':
|
'rank': self.serialize_rank,
|
||||||
lambda: get_disliked_post_count(user, auth_user),
|
'avatarStyle': self.serialize_avatar_style,
|
||||||
'email':
|
'avatarUrl': self.serialize_avatar_url,
|
||||||
lambda: get_email(user, auth_user, force_show_email),
|
'commentCount': self.serialize_comment_count,
|
||||||
},
|
'uploadedPostCount': self.serialize_uploaded_post_count,
|
||||||
options)
|
'favoritePostCount': self.serialize_favorite_post_count,
|
||||||
|
'likedPostCount': self.serialize_liked_post_count,
|
||||||
|
'dislikedPostCount': self.serialize_disliked_post_count,
|
||||||
|
'email': self.serialize_email,
|
||||||
|
}
|
||||||
|
|
||||||
|
def serialize_name(self) -> Any:
|
||||||
|
return self.user.name
|
||||||
|
|
||||||
|
def serialize_creation_time(self) -> Any:
|
||||||
|
return self.user.creation_time
|
||||||
|
|
||||||
|
def serialize_last_login_time(self) -> Any:
|
||||||
|
return self.user.last_login_time
|
||||||
|
|
||||||
|
def serialize_version(self) -> Any:
|
||||||
|
return self.user.version
|
||||||
|
|
||||||
|
def serialize_rank(self) -> Any:
|
||||||
|
return self.user.rank
|
||||||
|
|
||||||
|
def serialize_avatar_style(self) -> Any:
|
||||||
|
return self.user.avatar_style
|
||||||
|
|
||||||
|
def serialize_avatar_url(self) -> Any:
|
||||||
|
return get_avatar_url(self.user)
|
||||||
|
|
||||||
|
def serialize_comment_count(self) -> Any:
|
||||||
|
return self.user.comment_count
|
||||||
|
|
||||||
|
def serialize_uploaded_post_count(self) -> Any:
|
||||||
|
return self.user.post_count
|
||||||
|
|
||||||
|
def serialize_favorite_post_count(self) -> Any:
|
||||||
|
return self.user.favorite_post_count
|
||||||
|
|
||||||
|
def serialize_liked_post_count(self) -> Any:
|
||||||
|
return get_liked_post_count(self.user, self.auth_user)
|
||||||
|
|
||||||
|
def serialize_disliked_post_count(self) -> Any:
|
||||||
|
return get_disliked_post_count(self.user, self.auth_user)
|
||||||
|
|
||||||
|
def serialize_email(self) -> Any:
|
||||||
|
return get_email(self.user, self.auth_user, self.force_show_email)
|
||||||
|
|
||||||
|
|
||||||
def serialize_micro_user(user, auth_user):
|
def serialize_user(
|
||||||
|
user: Optional[model.User],
|
||||||
|
auth_user: model.User,
|
||||||
|
options: List[str]=[],
|
||||||
|
force_show_email: bool=False) -> Optional[rest.Response]:
|
||||||
|
if not user:
|
||||||
|
return None
|
||||||
|
return UserSerializer(user, auth_user, force_show_email).serialize(options)
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_micro_user(
|
||||||
|
user: Optional[model.User],
|
||||||
|
auth_user: model.User) -> Optional[rest.Response]:
|
||||||
return serialize_user(
|
return serialize_user(
|
||||||
user,
|
user, auth_user=auth_user, options=['name', 'avatarUrl'])
|
||||||
auth_user=auth_user,
|
|
||||||
options=['name', 'avatarUrl'])
|
|
||||||
|
|
||||||
|
|
||||||
def get_user_count():
|
def get_user_count() -> int:
|
||||||
return db.session.query(db.User).count()
|
return db.session.query(model.User).count()
|
||||||
|
|
||||||
|
|
||||||
def try_get_user_by_name(name):
|
def try_get_user_by_name(name: str) -> Optional[model.User]:
|
||||||
return db.session \
|
return db.session \
|
||||||
.query(db.User) \
|
.query(model.User) \
|
||||||
.filter(func.lower(db.User.name) == func.lower(name)) \
|
.filter(sa.func.lower(model.User.name) == sa.func.lower(name)) \
|
||||||
.one_or_none()
|
.one_or_none()
|
||||||
|
|
||||||
|
|
||||||
def get_user_by_name(name):
|
def get_user_by_name(name: str) -> model.User:
|
||||||
user = try_get_user_by_name(name)
|
user = try_get_user_by_name(name)
|
||||||
if not user:
|
if not user:
|
||||||
raise UserNotFoundError('User %r not found.' % name)
|
raise UserNotFoundError('User %r not found.' % name)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def try_get_user_by_name_or_email(name_or_email):
|
def try_get_user_by_name_or_email(name_or_email: str) -> Optional[model.User]:
|
||||||
return (
|
return (
|
||||||
db.session
|
db.session
|
||||||
.query(db.User)
|
.query(model.User)
|
||||||
.filter(
|
.filter(
|
||||||
(func.lower(db.User.name) == func.lower(name_or_email)) |
|
(sa.func.lower(model.User.name) == sa.func.lower(name_or_email)) |
|
||||||
(func.lower(db.User.email) == func.lower(name_or_email)))
|
(sa.func.lower(model.User.email) == sa.func.lower(name_or_email)))
|
||||||
.one_or_none())
|
.one_or_none())
|
||||||
|
|
||||||
|
|
||||||
def get_user_by_name_or_email(name_or_email):
|
def get_user_by_name_or_email(name_or_email: str) -> model.User:
|
||||||
user = try_get_user_by_name_or_email(name_or_email)
|
user = try_get_user_by_name_or_email(name_or_email)
|
||||||
if not user:
|
if not user:
|
||||||
raise UserNotFoundError('User %r not found.' % name_or_email)
|
raise UserNotFoundError('User %r not found.' % name_or_email)
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def create_user(name, password, email):
|
def create_user(name: str, password: str, email: str) -> model.User:
|
||||||
user = db.User()
|
user = model.User()
|
||||||
update_user_name(user, name)
|
update_user_name(user, name)
|
||||||
update_user_password(user, password)
|
update_user_password(user, password)
|
||||||
update_user_email(user, email)
|
update_user_email(user, email)
|
||||||
if get_user_count() > 0:
|
if get_user_count() > 0:
|
||||||
user.rank = util.flip(auth.RANK_MAP)[config.config['default_rank']]
|
user.rank = util.flip(auth.RANK_MAP)[config.config['default_rank']]
|
||||||
else:
|
else:
|
||||||
user.rank = db.User.RANK_ADMINISTRATOR
|
user.rank = model.User.RANK_ADMINISTRATOR
|
||||||
user.creation_time = datetime.datetime.utcnow()
|
user.creation_time = datetime.utcnow()
|
||||||
user.avatar_style = db.User.AVATAR_GRAVATAR
|
user.avatar_style = model.User.AVATAR_GRAVATAR
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def update_user_name(user, name):
|
def update_user_name(user: model.User, name: str) -> None:
|
||||||
assert user
|
assert user
|
||||||
if not name:
|
if not name:
|
||||||
raise InvalidUserNameError('Name cannot be empty.')
|
raise InvalidUserNameError('Name cannot be empty.')
|
||||||
if util.value_exceeds_column_size(name, db.User.name):
|
if util.value_exceeds_column_size(name, model.User.name):
|
||||||
raise InvalidUserNameError('User name is too long.')
|
raise InvalidUserNameError('User name is too long.')
|
||||||
name = name.strip()
|
name = name.strip()
|
||||||
name_regex = config.config['user_name_regex']
|
name_regex = config.config['user_name_regex']
|
||||||
|
@ -174,7 +233,7 @@ def update_user_name(user, name):
|
||||||
user.name = name
|
user.name = name
|
||||||
|
|
||||||
|
|
||||||
def update_user_password(user, password):
|
def update_user_password(user: model.User, password: str) -> None:
|
||||||
assert user
|
assert user
|
||||||
if not password:
|
if not password:
|
||||||
raise InvalidPasswordError('Password cannot be empty.')
|
raise InvalidPasswordError('Password cannot be empty.')
|
||||||
|
@ -186,20 +245,18 @@ def update_user_password(user, password):
|
||||||
user.password_hash = auth.get_password_hash(user.password_salt, password)
|
user.password_hash = auth.get_password_hash(user.password_salt, password)
|
||||||
|
|
||||||
|
|
||||||
def update_user_email(user, email):
|
def update_user_email(user: model.User, email: str) -> None:
|
||||||
assert user
|
assert user
|
||||||
if email:
|
|
||||||
email = email.strip()
|
email = email.strip()
|
||||||
if not email:
|
if util.value_exceeds_column_size(email, model.User.email):
|
||||||
email = None
|
|
||||||
if email and util.value_exceeds_column_size(email, db.User.email):
|
|
||||||
raise InvalidEmailError('Email is too long.')
|
raise InvalidEmailError('Email is too long.')
|
||||||
if not util.is_valid_email(email):
|
if not util.is_valid_email(email):
|
||||||
raise InvalidEmailError('E-mail is invalid.')
|
raise InvalidEmailError('E-mail is invalid.')
|
||||||
user.email = email
|
user.email = email or None
|
||||||
|
|
||||||
|
|
||||||
def update_user_rank(user, rank, auth_user):
|
def update_user_rank(
|
||||||
|
user: model.User, rank: str, auth_user: model.User) -> None:
|
||||||
assert user
|
assert user
|
||||||
if not rank:
|
if not rank:
|
||||||
raise InvalidRankError('Rank cannot be empty.')
|
raise InvalidRankError('Rank cannot be empty.')
|
||||||
|
@ -208,7 +265,7 @@ def update_user_rank(user, rank, auth_user):
|
||||||
if not rank:
|
if not rank:
|
||||||
raise InvalidRankError(
|
raise InvalidRankError(
|
||||||
'Rank can be either of %r.' % all_ranks)
|
'Rank can be either of %r.' % all_ranks)
|
||||||
if rank in (db.User.RANK_ANONYMOUS, db.User.RANK_NOBODY):
|
if rank in (model.User.RANK_ANONYMOUS, model.User.RANK_NOBODY):
|
||||||
raise InvalidRankError('Rank %r cannot be used.' % auth.RANK_MAP[rank])
|
raise InvalidRankError('Rank %r cannot be used.' % auth.RANK_MAP[rank])
|
||||||
if all_ranks.index(auth_user.rank) \
|
if all_ranks.index(auth_user.rank) \
|
||||||
< all_ranks.index(rank) and get_user_count() > 0:
|
< all_ranks.index(rank) and get_user_count() > 0:
|
||||||
|
@ -216,7 +273,10 @@ def update_user_rank(user, rank, auth_user):
|
||||||
user.rank = rank
|
user.rank = rank
|
||||||
|
|
||||||
|
|
||||||
def update_user_avatar(user, avatar_style, avatar_content=None):
|
def update_user_avatar(
|
||||||
|
user: model.User,
|
||||||
|
avatar_style: str,
|
||||||
|
avatar_content: Optional[bytes]=None) -> None:
|
||||||
assert user
|
assert user
|
||||||
if avatar_style == 'gravatar':
|
if avatar_style == 'gravatar':
|
||||||
user.avatar_style = user.AVATAR_GRAVATAR
|
user.avatar_style = user.AVATAR_GRAVATAR
|
||||||
|
@ -238,12 +298,12 @@ def update_user_avatar(user, avatar_style, avatar_content=None):
|
||||||
avatar_style, ['gravatar', 'manual']))
|
avatar_style, ['gravatar', 'manual']))
|
||||||
|
|
||||||
|
|
||||||
def bump_user_login_time(user):
|
def bump_user_login_time(user: model.User) -> None:
|
||||||
assert user
|
assert user
|
||||||
user.last_login_time = datetime.datetime.utcnow()
|
user.last_login_time = datetime.utcnow()
|
||||||
|
|
||||||
|
|
||||||
def reset_user_password(user):
|
def reset_user_password(user: model.User) -> str:
|
||||||
assert user
|
assert user
|
||||||
password = auth.create_password()
|
password = auth.create_password()
|
||||||
user.password_salt = auth.create_password()
|
user.password_salt = auth.create_password()
|
||||||
|
|
|
@ -2,52 +2,39 @@ import os
|
||||||
import hashlib
|
import hashlib
|
||||||
import re
|
import re
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from typing import (
|
||||||
|
Any, Optional, Union, Tuple, List, Dict, Generator, Callable, TypeVar)
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from szurubooru import errors
|
from szurubooru import errors
|
||||||
|
|
||||||
|
|
||||||
def snake_case_to_lower_camel_case(text):
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
|
||||||
|
def snake_case_to_lower_camel_case(text: str) -> str:
|
||||||
components = text.split('_')
|
components = text.split('_')
|
||||||
return components[0].lower() + \
|
return components[0].lower() + \
|
||||||
''.join(word[0].upper() + word[1:].lower() for word in components[1:])
|
''.join(word[0].upper() + word[1:].lower() for word in components[1:])
|
||||||
|
|
||||||
|
|
||||||
def snake_case_to_upper_train_case(text):
|
def snake_case_to_upper_train_case(text: str) -> str:
|
||||||
return '-'.join(
|
return '-'.join(
|
||||||
word[0].upper() + word[1:].lower() for word in text.split('_'))
|
word[0].upper() + word[1:].lower() for word in text.split('_'))
|
||||||
|
|
||||||
|
|
||||||
def snake_case_to_lower_camel_case_keys(source):
|
def snake_case_to_lower_camel_case_keys(
|
||||||
|
source: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
target = {}
|
target = {}
|
||||||
for key, value in source.items():
|
for key, value in source.items():
|
||||||
target[snake_case_to_lower_camel_case(key)] = value
|
target[snake_case_to_lower_camel_case(key)] = value
|
||||||
return target
|
return target
|
||||||
|
|
||||||
|
|
||||||
def get_serialization_options(ctx):
|
|
||||||
return ctx.get_param_as_list('fields', required=False, default=None)
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_entity(entity, field_factories, options):
|
|
||||||
if not entity:
|
|
||||||
return None
|
|
||||||
if not options or len(options) == 0:
|
|
||||||
options = field_factories.keys()
|
|
||||||
ret = {}
|
|
||||||
for key in options:
|
|
||||||
if key not in field_factories:
|
|
||||||
raise errors.ValidationError('Invalid key: %r. Valid keys: %r.' % (
|
|
||||||
key, list(sorted(field_factories.keys()))))
|
|
||||||
factory = field_factories[key]
|
|
||||||
ret[key] = factory()
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def create_temp_file(**kwargs):
|
def create_temp_file(**kwargs: Any) -> Generator:
|
||||||
(handle, path) = tempfile.mkstemp(**kwargs)
|
(descriptor, path) = tempfile.mkstemp(**kwargs)
|
||||||
os.close(handle)
|
os.close(descriptor)
|
||||||
try:
|
try:
|
||||||
with open(path, 'r+b') as handle:
|
with open(path, 'r+b') as handle:
|
||||||
yield handle
|
yield handle
|
||||||
|
@ -55,17 +42,15 @@ def create_temp_file(**kwargs):
|
||||||
os.remove(path)
|
os.remove(path)
|
||||||
|
|
||||||
|
|
||||||
def unalias_dict(input_dict):
|
def unalias_dict(source: List[Tuple[List[str], T]]) -> Dict[str, T]:
|
||||||
output_dict = {}
|
output_dict = {} # type: Dict[str, T]
|
||||||
for key_list, value in input_dict.items():
|
for aliases, value in source:
|
||||||
if isinstance(key_list, str):
|
for alias in aliases:
|
||||||
key_list = [key_list]
|
output_dict[alias] = value
|
||||||
for key in key_list:
|
|
||||||
output_dict[key] = value
|
|
||||||
return output_dict
|
return output_dict
|
||||||
|
|
||||||
|
|
||||||
def get_md5(source):
|
def get_md5(source: Union[str, bytes]) -> str:
|
||||||
if not isinstance(source, bytes):
|
if not isinstance(source, bytes):
|
||||||
source = source.encode('utf-8')
|
source = source.encode('utf-8')
|
||||||
md5 = hashlib.md5()
|
md5 = hashlib.md5()
|
||||||
|
@ -73,7 +58,7 @@ def get_md5(source):
|
||||||
return md5.hexdigest()
|
return md5.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def get_sha1(source):
|
def get_sha1(source: Union[str, bytes]) -> str:
|
||||||
if not isinstance(source, bytes):
|
if not isinstance(source, bytes):
|
||||||
source = source.encode('utf-8')
|
source = source.encode('utf-8')
|
||||||
sha1 = hashlib.sha1()
|
sha1 = hashlib.sha1()
|
||||||
|
@ -81,24 +66,25 @@ def get_sha1(source):
|
||||||
return sha1.hexdigest()
|
return sha1.hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def flip(source):
|
def flip(source: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||||
return {v: k for k, v in source.items()}
|
return {v: k for k, v in source.items()}
|
||||||
|
|
||||||
|
|
||||||
def is_valid_email(email):
|
def is_valid_email(email: Optional[str]) -> bool:
|
||||||
''' Return whether given email address is valid or empty. '''
|
''' Return whether given email address is valid or empty. '''
|
||||||
return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email)
|
return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email) is not None
|
||||||
|
|
||||||
|
|
||||||
class dotdict(dict): # pylint: disable=invalid-name
|
class dotdict(dict): # pylint: disable=invalid-name
|
||||||
''' dot.notation access to dictionary attributes. '''
|
''' dot.notation access to dictionary attributes. '''
|
||||||
def __getattr__(self, attr):
|
def __getattr__(self, attr: str) -> Any:
|
||||||
return self.get(attr)
|
return self.get(attr)
|
||||||
|
|
||||||
__setattr__ = dict.__setitem__
|
__setattr__ = dict.__setitem__
|
||||||
__delattr__ = dict.__delitem__
|
__delattr__ = dict.__delitem__
|
||||||
|
|
||||||
|
|
||||||
def parse_time_range(value):
|
def parse_time_range(value: str) -> Tuple[datetime, datetime]:
|
||||||
''' Return tuple containing min/max time for given text representation. '''
|
''' Return tuple containing min/max time for given text representation. '''
|
||||||
one_day = timedelta(days=1)
|
one_day = timedelta(days=1)
|
||||||
one_second = timedelta(seconds=1)
|
one_second = timedelta(seconds=1)
|
||||||
|
@ -146,9 +132,9 @@ def parse_time_range(value):
|
||||||
raise errors.ValidationError('Invalid date format: %r.' % value)
|
raise errors.ValidationError('Invalid date format: %r.' % value)
|
||||||
|
|
||||||
|
|
||||||
def icase_unique(source):
|
def icase_unique(source: List[str]) -> List[str]:
|
||||||
target = []
|
target = [] # type: List[str]
|
||||||
target_low = []
|
target_low = [] # type: List[str]
|
||||||
for source_item in source:
|
for source_item in source:
|
||||||
if source_item.lower() not in target_low:
|
if source_item.lower() not in target_low:
|
||||||
target.append(source_item)
|
target.append(source_item)
|
||||||
|
@ -156,7 +142,7 @@ def icase_unique(source):
|
||||||
return target
|
return target
|
||||||
|
|
||||||
|
|
||||||
def value_exceeds_column_size(value, column):
|
def value_exceeds_column_size(value: Optional[str], column: Any) -> bool:
|
||||||
if not value:
|
if not value:
|
||||||
return False
|
return False
|
||||||
max_length = column.property.columns[0].type.length
|
max_length = column.property.columns[0].type.length
|
||||||
|
@ -165,6 +151,6 @@ def value_exceeds_column_size(value, column):
|
||||||
return len(value) > max_length
|
return len(value) > max_length
|
||||||
|
|
||||||
|
|
||||||
def chunks(source_list, part_size):
|
def chunks(source_list: List[Any], part_size: int) -> Generator:
|
||||||
for i in range(0, len(source_list), part_size):
|
for i in range(0, len(source_list), part_size):
|
||||||
yield source_list[i:i + part_size]
|
yield source_list[i:i + part_size]
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
from szurubooru import errors
|
from szurubooru import errors, rest, model
|
||||||
|
|
||||||
|
|
||||||
def verify_version(entity, context, field_name='version'):
|
def verify_version(
|
||||||
actual_version = context.get_param_as_int(field_name, required=True)
|
entity: model.Base,
|
||||||
|
context: rest.Context,
|
||||||
|
field_name: str='version') -> None:
|
||||||
|
actual_version = context.get_param_as_int(field_name)
|
||||||
expected_version = entity.version
|
expected_version = entity.version
|
||||||
if actual_version != expected_version:
|
if actual_version != expected_version:
|
||||||
raise errors.IntegrityError(
|
raise errors.IntegrityError(
|
||||||
|
@ -10,5 +13,5 @@ def verify_version(entity, context, field_name='version'):
|
||||||
'Please try again.')
|
'Please try again.')
|
||||||
|
|
||||||
|
|
||||||
def bump_version(entity):
|
def bump_version(entity: model.Base) -> None:
|
||||||
entity.version = entity.version + 1
|
entity.version = entity.version + 1
|
||||||
|
|
|
@ -1,11 +1,11 @@
|
||||||
import base64
|
import base64
|
||||||
from szurubooru import db, errors
|
from typing import Optional
|
||||||
|
from szurubooru import db, model, errors, rest
|
||||||
from szurubooru.func import auth, users
|
from szurubooru.func import auth, users
|
||||||
from szurubooru.rest import middleware
|
|
||||||
from szurubooru.rest.errors import HttpBadRequest
|
from szurubooru.rest.errors import HttpBadRequest
|
||||||
|
|
||||||
|
|
||||||
def _authenticate(username, password):
|
def _authenticate(username: str, password: str) -> model.User:
|
||||||
''' Try to authenticate user. Throw AuthError for invalid users. '''
|
''' Try to authenticate user. Throw AuthError for invalid users. '''
|
||||||
user = users.get_user_by_name(username)
|
user = users.get_user_by_name(username)
|
||||||
if not auth.is_valid_password(user, password):
|
if not auth.is_valid_password(user, password):
|
||||||
|
@ -13,16 +13,9 @@ def _authenticate(username, password):
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
def _create_anonymous_user():
|
def _get_user(ctx: rest.Context) -> Optional[model.User]:
|
||||||
user = db.User()
|
|
||||||
user.name = None
|
|
||||||
user.rank = 'anonymous'
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
def _get_user(ctx):
|
|
||||||
if not ctx.has_header('Authorization'):
|
if not ctx.has_header('Authorization'):
|
||||||
return _create_anonymous_user()
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
auth_type, credentials = ctx.get_header('Authorization').split(' ', 1)
|
auth_type, credentials = ctx.get_header('Authorization').split(' ', 1)
|
||||||
|
@ -41,10 +34,12 @@ def _get_user(ctx):
|
||||||
msg.format(ctx.get_header('Authorization'), str(err)))
|
msg.format(ctx.get_header('Authorization'), str(err)))
|
||||||
|
|
||||||
|
|
||||||
@middleware.pre_hook
|
@rest.middleware.pre_hook
|
||||||
def process_request(ctx):
|
def process_request(ctx: rest.Context) -> None:
|
||||||
''' Bind the user to request. Update last login time if needed. '''
|
''' Bind the user to request. Update last login time if needed. '''
|
||||||
ctx.user = _get_user(ctx)
|
auth_user = _get_user(ctx)
|
||||||
if ctx.get_param_as_bool('bump-login') and ctx.user.user_id:
|
if auth_user:
|
||||||
|
ctx.user = auth_user
|
||||||
|
if ctx.get_param_as_bool('bump-login', default=False) and ctx.user.user_id:
|
||||||
users.bump_user_login_time(ctx.user)
|
users.bump_user_login_time(ctx.user)
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
|
from szurubooru import rest
|
||||||
from szurubooru.func import cache
|
from szurubooru.func import cache
|
||||||
from szurubooru.rest import middleware
|
from szurubooru.rest import middleware
|
||||||
|
|
||||||
|
|
||||||
@middleware.pre_hook
|
@middleware.pre_hook
|
||||||
def process_request(ctx):
|
def process_request(ctx: rest.Context) -> None:
|
||||||
if ctx.method != 'GET':
|
if ctx.method != 'GET':
|
||||||
cache.purge()
|
cache.purge()
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
from szurubooru import db
|
from szurubooru import db, rest
|
||||||
from szurubooru.rest import middleware
|
from szurubooru.rest import middleware
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,12 +7,12 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@middleware.pre_hook
|
@middleware.pre_hook
|
||||||
def process_request(_ctx):
|
def process_request(_ctx: rest.Context) -> None:
|
||||||
db.reset_query_count()
|
db.reset_query_count()
|
||||||
|
|
||||||
|
|
||||||
@middleware.post_hook
|
@middleware.post_hook
|
||||||
def process_response(ctx):
|
def process_response(ctx: rest.Context) -> None:
|
||||||
logger.info(
|
logger.info(
|
||||||
'%s %s (user=%s, queries=%d)',
|
'%s %s (user=%s, queries=%d)',
|
||||||
ctx.method,
|
ctx.method,
|
||||||
|
|
|
@ -2,7 +2,7 @@ import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import alembic
|
import alembic
|
||||||
import sqlalchemy
|
import sqlalchemy as sa
|
||||||
import logging.config
|
import logging.config
|
||||||
|
|
||||||
# make szurubooru module importable
|
# make szurubooru module importable
|
||||||
|
@ -48,7 +48,7 @@ def run_migrations_online():
|
||||||
In this scenario we need to create an Engine
|
In this scenario we need to create an Engine
|
||||||
and associate a connection with the context.
|
and associate a connection with the context.
|
||||||
'''
|
'''
|
||||||
connectable = sqlalchemy.engine_from_config(
|
connectable = sa.engine_from_config(
|
||||||
alembic_config.get_section(alembic_config.config_ini_section),
|
alembic_config.get_section(alembic_config.config_ini_section),
|
||||||
prefix='sqlalchemy.',
|
prefix='sqlalchemy.',
|
||||||
poolclass=sqlalchemy.pool.NullPool)
|
poolclass=sqlalchemy.pool.NullPool)
|
||||||
|
|
15
server/szurubooru/model/__init__.py
Normal file
15
server/szurubooru/model/__init__.py
Normal file
|
@ -0,0 +1,15 @@
|
||||||
|
from szurubooru.model.base import Base
|
||||||
|
from szurubooru.model.user import User
|
||||||
|
from szurubooru.model.tag_category import TagCategory
|
||||||
|
from szurubooru.model.tag import Tag, TagName, TagSuggestion, TagImplication
|
||||||
|
from szurubooru.model.post import (
|
||||||
|
Post,
|
||||||
|
PostTag,
|
||||||
|
PostRelation,
|
||||||
|
PostFavorite,
|
||||||
|
PostScore,
|
||||||
|
PostNote,
|
||||||
|
PostFeature)
|
||||||
|
from szurubooru.model.comment import Comment, CommentScore
|
||||||
|
from szurubooru.model.snapshot import Snapshot
|
||||||
|
import szurubooru.model.util
|
|
@ -1,7 +1,8 @@
|
||||||
from sqlalchemy import Column, Integer, DateTime, UnicodeText, ForeignKey
|
from sqlalchemy import Column, Integer, DateTime, UnicodeText, ForeignKey
|
||||||
from sqlalchemy.orm import relationship, backref
|
from sqlalchemy.orm import relationship, backref
|
||||||
from sqlalchemy.sql.expression import func
|
from sqlalchemy.sql.expression import func
|
||||||
from szurubooru.db.base import Base
|
from szurubooru.db import get_session
|
||||||
|
from szurubooru.model.base import Base
|
||||||
|
|
||||||
|
|
||||||
class CommentScore(Base):
|
class CommentScore(Base):
|
||||||
|
@ -48,12 +49,12 @@ class Comment(Base):
|
||||||
'CommentScore', cascade='all, delete-orphan', lazy='joined')
|
'CommentScore', cascade='all, delete-orphan', lazy='joined')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def score(self):
|
def score(self) -> int:
|
||||||
from szurubooru.db import session
|
return (
|
||||||
return session \
|
get_session()
|
||||||
.query(func.sum(CommentScore.score)) \
|
.query(func.sum(CommentScore.score))
|
||||||
.filter(CommentScore.comment_id == self.comment_id) \
|
.filter(CommentScore.comment_id == self.comment_id)
|
||||||
.one()[0] or 0
|
.one()[0] or 0)
|
||||||
|
|
||||||
__mapper_args__ = {
|
__mapper_args__ = {
|
||||||
'version_id_col': version,
|
'version_id_col': version,
|
|
@ -3,8 +3,8 @@ from sqlalchemy import (
|
||||||
Column, Integer, DateTime, Unicode, UnicodeText, PickleType, ForeignKey)
|
Column, Integer, DateTime, Unicode, UnicodeText, PickleType, ForeignKey)
|
||||||
from sqlalchemy.orm import (
|
from sqlalchemy.orm import (
|
||||||
relationship, column_property, object_session, backref)
|
relationship, column_property, object_session, backref)
|
||||||
from szurubooru.db.base import Base
|
from szurubooru.model.base import Base
|
||||||
from szurubooru.db.comment import Comment
|
from szurubooru.model.comment import Comment
|
||||||
|
|
||||||
|
|
||||||
class PostFeature(Base):
|
class PostFeature(Base):
|
||||||
|
@ -17,10 +17,9 @@ class PostFeature(Base):
|
||||||
'user_id', Integer, ForeignKey('user.id'), nullable=False, index=True)
|
'user_id', Integer, ForeignKey('user.id'), nullable=False, index=True)
|
||||||
time = Column('time', DateTime, nullable=False)
|
time = Column('time', DateTime, nullable=False)
|
||||||
|
|
||||||
post = relationship('Post')
|
post = relationship('Post') # type: Post
|
||||||
user = relationship(
|
user = relationship(
|
||||||
'User',
|
'User', backref=backref('post_features', cascade='all, delete-orphan'))
|
||||||
backref=backref('post_features', cascade='all, delete-orphan'))
|
|
||||||
|
|
||||||
|
|
||||||
class PostScore(Base):
|
class PostScore(Base):
|
||||||
|
@ -104,7 +103,7 @@ class PostRelation(Base):
|
||||||
nullable=False,
|
nullable=False,
|
||||||
index=True)
|
index=True)
|
||||||
|
|
||||||
def __init__(self, parent_id, child_id):
|
def __init__(self, parent_id: int, child_id: int) -> None:
|
||||||
self.parent_id = parent_id
|
self.parent_id = parent_id
|
||||||
self.child_id = child_id
|
self.child_id = child_id
|
||||||
|
|
||||||
|
@ -127,7 +126,7 @@ class PostTag(Base):
|
||||||
nullable=False,
|
nullable=False,
|
||||||
index=True)
|
index=True)
|
||||||
|
|
||||||
def __init__(self, post_id, tag_id):
|
def __init__(self, post_id: int, tag_id: int) -> None:
|
||||||
self.post_id = post_id
|
self.post_id = post_id
|
||||||
self.tag_id = tag_id
|
self.tag_id = tag_id
|
||||||
|
|
||||||
|
@ -197,7 +196,7 @@ class Post(Base):
|
||||||
canvas_area = column_property(canvas_width * canvas_height)
|
canvas_area = column_property(canvas_width * canvas_height)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_featured(self):
|
def is_featured(self) -> bool:
|
||||||
featured_post = object_session(self) \
|
featured_post = object_session(self) \
|
||||||
.query(PostFeature) \
|
.query(PostFeature) \
|
||||||
.order_by(PostFeature.time.desc()) \
|
.order_by(PostFeature.time.desc()) \
|
|
@ -1,7 +1,7 @@
|
||||||
from sqlalchemy.orm import relationship
|
from sqlalchemy.orm import relationship
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Column, Integer, DateTime, Unicode, PickleType, ForeignKey)
|
Column, Integer, DateTime, Unicode, PickleType, ForeignKey)
|
||||||
from szurubooru.db.base import Base
|
from szurubooru.model.base import Base
|
||||||
|
|
||||||
|
|
||||||
class Snapshot(Base):
|
class Snapshot(Base):
|
|
@ -2,8 +2,8 @@ from sqlalchemy import (
|
||||||
Column, Integer, DateTime, Unicode, UnicodeText, ForeignKey)
|
Column, Integer, DateTime, Unicode, UnicodeText, ForeignKey)
|
||||||
from sqlalchemy.orm import relationship, column_property
|
from sqlalchemy.orm import relationship, column_property
|
||||||
from sqlalchemy.sql.expression import func, select
|
from sqlalchemy.sql.expression import func, select
|
||||||
from szurubooru.db.base import Base
|
from szurubooru.model.base import Base
|
||||||
from szurubooru.db.post import PostTag
|
from szurubooru.model.post import PostTag
|
||||||
|
|
||||||
|
|
||||||
class TagSuggestion(Base):
|
class TagSuggestion(Base):
|
||||||
|
@ -24,7 +24,7 @@ class TagSuggestion(Base):
|
||||||
primary_key=True,
|
primary_key=True,
|
||||||
index=True)
|
index=True)
|
||||||
|
|
||||||
def __init__(self, parent_id, child_id):
|
def __init__(self, parent_id: int, child_id: int) -> None:
|
||||||
self.parent_id = parent_id
|
self.parent_id = parent_id
|
||||||
self.child_id = child_id
|
self.child_id = child_id
|
||||||
|
|
||||||
|
@ -47,7 +47,7 @@ class TagImplication(Base):
|
||||||
primary_key=True,
|
primary_key=True,
|
||||||
index=True)
|
index=True)
|
||||||
|
|
||||||
def __init__(self, parent_id, child_id):
|
def __init__(self, parent_id: int, child_id: int) -> None:
|
||||||
self.parent_id = parent_id
|
self.parent_id = parent_id
|
||||||
self.child_id = child_id
|
self.child_id = child_id
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ class TagName(Base):
|
||||||
name = Column('name', Unicode(64), nullable=False, unique=True)
|
name = Column('name', Unicode(64), nullable=False, unique=True)
|
||||||
order = Column('ord', Integer, nullable=False, index=True)
|
order = Column('ord', Integer, nullable=False, index=True)
|
||||||
|
|
||||||
def __init__(self, name, order):
|
def __init__(self, name: str, order: int) -> None:
|
||||||
self.name = name
|
self.name = name
|
||||||
self.order = order
|
self.order = order
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
|
from typing import Optional
|
||||||
from sqlalchemy import Column, Integer, Unicode, Boolean, table
|
from sqlalchemy import Column, Integer, Unicode, Boolean, table
|
||||||
from sqlalchemy.orm import column_property
|
from sqlalchemy.orm import column_property
|
||||||
from sqlalchemy.sql.expression import func, select
|
from sqlalchemy.sql.expression import func, select
|
||||||
from szurubooru.db.base import Base
|
from szurubooru.model.base import Base
|
||||||
from szurubooru.db.tag import Tag
|
from szurubooru.model.tag import Tag
|
||||||
|
|
||||||
|
|
||||||
class TagCategory(Base):
|
class TagCategory(Base):
|
||||||
|
@ -14,7 +15,7 @@ class TagCategory(Base):
|
||||||
color = Column('color', Unicode(32), nullable=False, default='#000000')
|
color = Column('color', Unicode(32), nullable=False, default='#000000')
|
||||||
default = Column('default', Boolean, nullable=False, default=False)
|
default = Column('default', Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
def __init__(self, name=None):
|
def __init__(self, name: Optional[str]=None) -> None:
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
tag_count = column_property(
|
tag_count = column_property(
|
|
@ -1,9 +1,7 @@
|
||||||
from sqlalchemy import Column, Integer, Unicode, DateTime
|
import sqlalchemy as sa
|
||||||
from sqlalchemy.orm import relationship
|
from szurubooru.model.base import Base
|
||||||
from sqlalchemy.sql.expression import func
|
from szurubooru.model.post import Post, PostScore, PostFavorite
|
||||||
from szurubooru.db.base import Base
|
from szurubooru.model.comment import Comment
|
||||||
from szurubooru.db.post import Post, PostScore, PostFavorite
|
|
||||||
from szurubooru.db.comment import Comment
|
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
|
@ -20,63 +18,64 @@ class User(Base):
|
||||||
RANK_ADMINISTRATOR = 'administrator'
|
RANK_ADMINISTRATOR = 'administrator'
|
||||||
RANK_NOBODY = 'nobody' # unattainable, used for privileges
|
RANK_NOBODY = 'nobody' # unattainable, used for privileges
|
||||||
|
|
||||||
user_id = Column('id', Integer, primary_key=True)
|
user_id = sa.Column('id', sa.Integer, primary_key=True)
|
||||||
creation_time = Column('creation_time', DateTime, nullable=False)
|
creation_time = sa.Column('creation_time', sa.DateTime, nullable=False)
|
||||||
last_login_time = Column('last_login_time', DateTime)
|
last_login_time = sa.Column('last_login_time', sa.DateTime)
|
||||||
version = Column('version', Integer, default=1, nullable=False)
|
version = sa.Column('version', sa.Integer, default=1, nullable=False)
|
||||||
name = Column('name', Unicode(50), nullable=False, unique=True)
|
name = sa.Column('name', sa.Unicode(50), nullable=False, unique=True)
|
||||||
password_hash = Column('password_hash', Unicode(64), nullable=False)
|
password_hash = sa.Column('password_hash', sa.Unicode(64), nullable=False)
|
||||||
password_salt = Column('password_salt', Unicode(32))
|
password_salt = sa.Column('password_salt', sa.Unicode(32))
|
||||||
email = Column('email', Unicode(64), nullable=True)
|
email = sa.Column('email', sa.Unicode(64), nullable=True)
|
||||||
rank = Column('rank', Unicode(32), nullable=False)
|
rank = sa.Column('rank', sa.Unicode(32), nullable=False)
|
||||||
avatar_style = Column(
|
avatar_style = sa.Column(
|
||||||
'avatar_style', Unicode(32), nullable=False, default=AVATAR_GRAVATAR)
|
'avatar_style', sa.Unicode(32), nullable=False,
|
||||||
|
default=AVATAR_GRAVATAR)
|
||||||
|
|
||||||
comments = relationship('Comment')
|
comments = sa.orm.relationship('Comment')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def post_count(self):
|
def post_count(self) -> int:
|
||||||
from szurubooru.db import session
|
from szurubooru.db import session
|
||||||
return (
|
return (
|
||||||
session
|
session
|
||||||
.query(func.sum(1))
|
.query(sa.sql.expression.func.sum(1))
|
||||||
.filter(Post.user_id == self.user_id)
|
.filter(Post.user_id == self.user_id)
|
||||||
.one()[0] or 0)
|
.one()[0] or 0)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def comment_count(self):
|
def comment_count(self) -> int:
|
||||||
from szurubooru.db import session
|
from szurubooru.db import session
|
||||||
return (
|
return (
|
||||||
session
|
session
|
||||||
.query(func.sum(1))
|
.query(sa.sql.expression.func.sum(1))
|
||||||
.filter(Comment.user_id == self.user_id)
|
.filter(Comment.user_id == self.user_id)
|
||||||
.one()[0] or 0)
|
.one()[0] or 0)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def favorite_post_count(self):
|
def favorite_post_count(self) -> int:
|
||||||
from szurubooru.db import session
|
from szurubooru.db import session
|
||||||
return (
|
return (
|
||||||
session
|
session
|
||||||
.query(func.sum(1))
|
.query(sa.sql.expression.func.sum(1))
|
||||||
.filter(PostFavorite.user_id == self.user_id)
|
.filter(PostFavorite.user_id == self.user_id)
|
||||||
.one()[0] or 0)
|
.one()[0] or 0)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def liked_post_count(self):
|
def liked_post_count(self) -> int:
|
||||||
from szurubooru.db import session
|
from szurubooru.db import session
|
||||||
return (
|
return (
|
||||||
session
|
session
|
||||||
.query(func.sum(1))
|
.query(sa.sql.expression.func.sum(1))
|
||||||
.filter(PostScore.user_id == self.user_id)
|
.filter(PostScore.user_id == self.user_id)
|
||||||
.filter(PostScore.score == 1)
|
.filter(PostScore.score == 1)
|
||||||
.one()[0] or 0)
|
.one()[0] or 0)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def disliked_post_count(self):
|
def disliked_post_count(self) -> int:
|
||||||
from szurubooru.db import session
|
from szurubooru.db import session
|
||||||
return (
|
return (
|
||||||
session
|
session
|
||||||
.query(func.sum(1))
|
.query(sa.sql.expression.func.sum(1))
|
||||||
.filter(PostScore.user_id == self.user_id)
|
.filter(PostScore.user_id == self.user_id)
|
||||||
.filter(PostScore.score == -1)
|
.filter(PostScore.score == -1)
|
||||||
.one()[0] or 0)
|
.one()[0] or 0)
|
42
server/szurubooru/model/util.py
Normal file
42
server/szurubooru/model/util.py
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
from typing import Tuple, Any, Dict, Callable, Union, Optional
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from szurubooru.model.base import Base
|
||||||
|
from szurubooru.model.user import User
|
||||||
|
|
||||||
|
|
||||||
|
def get_resource_info(entity: Base) -> Tuple[Any, Any, Union[str, int]]:
|
||||||
|
serializers = {
|
||||||
|
'tag': lambda tag: tag.first_name,
|
||||||
|
'tag_category': lambda category: category.name,
|
||||||
|
'comment': lambda comment: comment.comment_id,
|
||||||
|
'post': lambda post: post.post_id,
|
||||||
|
} # type: Dict[str, Callable[[Base], Any]]
|
||||||
|
|
||||||
|
resource_type = entity.__table__.name
|
||||||
|
assert resource_type in serializers
|
||||||
|
|
||||||
|
primary_key = sa.inspection.inspect(entity).identity # type: Any
|
||||||
|
assert primary_key is not None
|
||||||
|
assert len(primary_key) == 1
|
||||||
|
|
||||||
|
resource_name = serializers[resource_type](entity) # type: Union[str, int]
|
||||||
|
assert resource_name
|
||||||
|
|
||||||
|
resource_pkey = primary_key[0] # type: Any
|
||||||
|
assert resource_pkey
|
||||||
|
|
||||||
|
return (resource_type, resource_pkey, resource_name)
|
||||||
|
|
||||||
|
|
||||||
|
def get_aux_entity(
|
||||||
|
session: Any,
|
||||||
|
get_table_info: Callable[[Base], Tuple[Base, Callable[[Base], Any]]],
|
||||||
|
entity: Base,
|
||||||
|
user: User) -> Optional[Base]:
|
||||||
|
table, get_column = get_table_info(entity)
|
||||||
|
return (
|
||||||
|
session
|
||||||
|
.query(table)
|
||||||
|
.filter(get_column(table) == get_column(entity))
|
||||||
|
.filter(table.user_id == user.user_id)
|
||||||
|
.one_or_none())
|
|
@ -1,2 +1,2 @@
|
||||||
from szurubooru.rest.app import application
|
from szurubooru.rest.app import application
|
||||||
from szurubooru.rest.context import Context
|
from szurubooru.rest.context import Context, Response
|
||||||
|
|
|
@ -2,13 +2,14 @@ import urllib.parse
|
||||||
import cgi
|
import cgi
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
from typing import Dict, Any, Callable, Tuple
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from szurubooru import db
|
from szurubooru import db
|
||||||
from szurubooru.func import util
|
from szurubooru.func import util
|
||||||
from szurubooru.rest import errors, middleware, routes, context
|
from szurubooru.rest import errors, middleware, routes, context
|
||||||
|
|
||||||
|
|
||||||
def _json_serializer(obj):
|
def _json_serializer(obj: Any) -> str:
|
||||||
''' JSON serializer for objects not serializable by default JSON code '''
|
''' JSON serializer for objects not serializable by default JSON code '''
|
||||||
if isinstance(obj, datetime):
|
if isinstance(obj, datetime):
|
||||||
serial = obj.isoformat('T') + 'Z'
|
serial = obj.isoformat('T') + 'Z'
|
||||||
|
@ -16,12 +17,12 @@ def _json_serializer(obj):
|
||||||
raise TypeError('Type not serializable')
|
raise TypeError('Type not serializable')
|
||||||
|
|
||||||
|
|
||||||
def _dump_json(obj):
|
def _dump_json(obj: Any) -> str:
|
||||||
return json.dumps(obj, default=_json_serializer, indent=2)
|
return json.dumps(obj, default=_json_serializer, indent=2)
|
||||||
|
|
||||||
|
|
||||||
def _get_headers(env):
|
def _get_headers(env: Dict[str, Any]) -> Dict[str, str]:
|
||||||
headers = {}
|
headers = {} # type: Dict[str, str]
|
||||||
for key, value in env.items():
|
for key, value in env.items():
|
||||||
if key.startswith('HTTP_'):
|
if key.startswith('HTTP_'):
|
||||||
key = util.snake_case_to_upper_train_case(key[5:])
|
key = util.snake_case_to_upper_train_case(key[5:])
|
||||||
|
@ -29,7 +30,7 @@ def _get_headers(env):
|
||||||
return headers
|
return headers
|
||||||
|
|
||||||
|
|
||||||
def _create_context(env):
|
def _create_context(env: Dict[str, Any]) -> context.Context:
|
||||||
method = env['REQUEST_METHOD']
|
method = env['REQUEST_METHOD']
|
||||||
path = '/' + env['PATH_INFO'].lstrip('/')
|
path = '/' + env['PATH_INFO'].lstrip('/')
|
||||||
headers = _get_headers(env)
|
headers = _get_headers(env)
|
||||||
|
@ -64,7 +65,9 @@ def _create_context(env):
|
||||||
return context.Context(method, path, headers, params, files)
|
return context.Context(method, path, headers, params, files)
|
||||||
|
|
||||||
|
|
||||||
def application(env, start_response):
|
def application(
|
||||||
|
env: Dict[str, Any],
|
||||||
|
start_response: Callable[[str, Any], Any]) -> Tuple[bytes]:
|
||||||
try:
|
try:
|
||||||
ctx = _create_context(env)
|
ctx = _create_context(env)
|
||||||
if 'application/json' not in ctx.get_header('Accept'):
|
if 'application/json' not in ctx.get_header('Accept'):
|
||||||
|
@ -106,9 +109,9 @@ def application(env, start_response):
|
||||||
return (_dump_json(response).encode('utf-8'),)
|
return (_dump_json(response).encode('utf-8'),)
|
||||||
|
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
for exception_type, handler in errors.error_handlers.items():
|
for exception_type, ex_handler in errors.error_handlers.items():
|
||||||
if isinstance(ex, exception_type):
|
if isinstance(ex, exception_type):
|
||||||
handler(ex)
|
ex_handler(ex)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
except errors.BaseHttpError as ex:
|
except errors.BaseHttpError as ex:
|
||||||
|
|
|
@ -1,111 +1,158 @@
|
||||||
from szurubooru import errors
|
from typing import Any, Union, List, Dict, Optional, cast
|
||||||
|
from szurubooru import model, errors
|
||||||
from szurubooru.func import net, file_uploads
|
from szurubooru.func import net, file_uploads
|
||||||
|
|
||||||
|
|
||||||
def _lower_first(source):
|
MISSING = object()
|
||||||
return source[0].lower() + source[1:]
|
Request = Dict[str, Any]
|
||||||
|
Response = Optional[Dict[str, Any]]
|
||||||
|
|
||||||
def _param_wrapper(func):
|
|
||||||
def wrapper(self, name, required=False, default=None, **kwargs):
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
if name in self._params:
|
|
||||||
value = self._params[name]
|
|
||||||
try:
|
|
||||||
value = func(self, value, **kwargs)
|
|
||||||
except errors.InvalidParameterError as ex:
|
|
||||||
raise errors.InvalidParameterError(
|
|
||||||
'Parameter %r is invalid: %s' % (
|
|
||||||
name, _lower_first(str(ex))))
|
|
||||||
return value
|
|
||||||
if not required:
|
|
||||||
return default
|
|
||||||
raise errors.MissingRequiredParameterError(
|
|
||||||
'Required parameter %r is missing.' % name)
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
class Context:
|
class Context:
|
||||||
def __init__(self, method, url, headers=None, params=None, files=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
url: str,
|
||||||
|
headers: Dict[str, str]=None,
|
||||||
|
params: Request=None,
|
||||||
|
files: Dict[str, bytes]=None) -> None:
|
||||||
self.method = method
|
self.method = method
|
||||||
self.url = url
|
self.url = url
|
||||||
self._headers = headers or {}
|
self._headers = headers or {}
|
||||||
self._params = params or {}
|
self._params = params or {}
|
||||||
self._files = files or {}
|
self._files = files or {}
|
||||||
|
|
||||||
# provided by middleware
|
self.user = model.User()
|
||||||
# self.session = None
|
self.user.name = None
|
||||||
# self.user = None
|
self.user.rank = 'anonymous'
|
||||||
|
|
||||||
def has_header(self, name):
|
self.session = None # type: Any
|
||||||
|
|
||||||
|
def has_header(self, name: str) -> bool:
|
||||||
return name in self._headers
|
return name in self._headers
|
||||||
|
|
||||||
def get_header(self, name):
|
def get_header(self, name: str) -> str:
|
||||||
return self._headers.get(name, None)
|
return self._headers.get(name, '')
|
||||||
|
|
||||||
def has_file(self, name, allow_tokens=True):
|
def has_file(self, name: str, allow_tokens: bool=True) -> bool:
|
||||||
return (
|
return (
|
||||||
name in self._files or
|
name in self._files or
|
||||||
name + 'Url' in self._params or
|
name + 'Url' in self._params or
|
||||||
(allow_tokens and name + 'Token' in self._params))
|
(allow_tokens and name + 'Token' in self._params))
|
||||||
|
|
||||||
def get_file(self, name, required=False, allow_tokens=True):
|
def get_file(
|
||||||
ret = None
|
self,
|
||||||
if name in self._files:
|
name: str,
|
||||||
ret = self._files[name]
|
default: Union[object, bytes]=MISSING,
|
||||||
elif name + 'Url' in self._params:
|
allow_tokens: bool=True) -> bytes:
|
||||||
ret = net.download(self._params[name + 'Url'])
|
if name in self._files and self._files[name]:
|
||||||
elif allow_tokens and name + 'Token' in self._params:
|
return self._files[name]
|
||||||
|
|
||||||
|
if name + 'Url' in self._params:
|
||||||
|
return net.download(self._params[name + 'Url'])
|
||||||
|
|
||||||
|
if allow_tokens and name + 'Token' in self._params:
|
||||||
ret = file_uploads.get(self._params[name + 'Token'])
|
ret = file_uploads.get(self._params[name + 'Token'])
|
||||||
if required and not ret:
|
if ret:
|
||||||
|
return ret
|
||||||
|
elif default is not MISSING:
|
||||||
raise errors.MissingOrExpiredRequiredFileError(
|
raise errors.MissingOrExpiredRequiredFileError(
|
||||||
'Required file %r is missing or has expired.' % name)
|
'Required file %r is missing or has expired.' % name)
|
||||||
if required and not ret:
|
|
||||||
|
if default is not MISSING:
|
||||||
|
return cast(bytes, default)
|
||||||
raise errors.MissingRequiredFileError(
|
raise errors.MissingRequiredFileError(
|
||||||
'Required file %r is missing.' % name)
|
'Required file %r is missing.' % name)
|
||||||
return ret
|
|
||||||
|
|
||||||
def has_param(self, name):
|
def has_param(self, name: str) -> bool:
|
||||||
return name in self._params
|
return name in self._params
|
||||||
|
|
||||||
@_param_wrapper
|
def get_param_as_list(
|
||||||
def get_param_as_list(self, value):
|
self,
|
||||||
if not isinstance(value, list):
|
name: str,
|
||||||
|
default: Union[object, List[Any]]=MISSING) -> List[Any]:
|
||||||
|
if name not in self._params:
|
||||||
|
if default is not MISSING:
|
||||||
|
return cast(List[Any], default)
|
||||||
|
raise errors.MissingRequiredParameterError(
|
||||||
|
'Required parameter %r is missing.' % name)
|
||||||
|
value = self._params[name]
|
||||||
|
if type(value) is str:
|
||||||
if ',' in value:
|
if ',' in value:
|
||||||
return value.split(',')
|
return value.split(',')
|
||||||
return [value]
|
return [value]
|
||||||
|
if type(value) is list:
|
||||||
return value
|
return value
|
||||||
|
raise errors.InvalidParameterError(
|
||||||
|
'Parameter %r must be a list.' % name)
|
||||||
|
|
||||||
@_param_wrapper
|
def get_param_as_string(
|
||||||
def get_param_as_string(self, value):
|
self,
|
||||||
if isinstance(value, list):
|
name: str,
|
||||||
|
default: Union[object, str]=MISSING) -> str:
|
||||||
|
if name not in self._params:
|
||||||
|
if default is not MISSING:
|
||||||
|
return cast(str, default)
|
||||||
|
raise errors.MissingRequiredParameterError(
|
||||||
|
'Required parameter %r is missing.' % name)
|
||||||
|
value = self._params[name]
|
||||||
try:
|
try:
|
||||||
value = ','.join(value)
|
if value is None:
|
||||||
except TypeError:
|
return ''
|
||||||
raise errors.InvalidParameterError('Expected simple string.')
|
if type(value) is list:
|
||||||
|
return ','.join(value)
|
||||||
|
if type(value) is int or type(value) is float:
|
||||||
|
return str(value)
|
||||||
|
if type(value) is str:
|
||||||
return value
|
return value
|
||||||
|
except TypeError:
|
||||||
|
pass
|
||||||
|
raise errors.InvalidParameterError(
|
||||||
|
'Parameter %r must be a string value.' % name)
|
||||||
|
|
||||||
@_param_wrapper
|
def get_param_as_int(
|
||||||
def get_param_as_int(self, value, min=None, max=None):
|
self,
|
||||||
|
name: str,
|
||||||
|
default: Union[object, int]=MISSING,
|
||||||
|
min: Optional[int]=None,
|
||||||
|
max: Optional[int]=None) -> int:
|
||||||
|
if name not in self._params:
|
||||||
|
if default is not MISSING:
|
||||||
|
return cast(int, default)
|
||||||
|
raise errors.MissingRequiredParameterError(
|
||||||
|
'Required parameter %r is missing.' % name)
|
||||||
|
value = self._params[name]
|
||||||
try:
|
try:
|
||||||
value = int(value)
|
value = int(value)
|
||||||
except (ValueError, TypeError):
|
|
||||||
raise errors.InvalidParameterError(
|
|
||||||
'The value must be an integer.')
|
|
||||||
if min is not None and value < min:
|
if min is not None and value < min:
|
||||||
raise errors.InvalidParameterError(
|
raise errors.InvalidParameterError(
|
||||||
'The value must be at least %r.' % min)
|
'Parameter %r must be at least %r.' % (name, min))
|
||||||
if max is not None and value > max:
|
if max is not None and value > max:
|
||||||
raise errors.InvalidParameterError(
|
raise errors.InvalidParameterError(
|
||||||
'The value may not exceed %r.' % max)
|
'Parameter %r may not exceed %r.' % (name, max))
|
||||||
return value
|
return value
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
pass
|
||||||
|
raise errors.InvalidParameterError(
|
||||||
|
'Parameter %r must be an integer value.' % name)
|
||||||
|
|
||||||
@_param_wrapper
|
def get_param_as_bool(
|
||||||
def get_param_as_bool(self, value):
|
self,
|
||||||
|
name: str,
|
||||||
|
default: Union[object, bool]=MISSING) -> bool:
|
||||||
|
if name not in self._params:
|
||||||
|
if default is not MISSING:
|
||||||
|
return cast(bool, default)
|
||||||
|
raise errors.MissingRequiredParameterError(
|
||||||
|
'Required parameter %r is missing.' % name)
|
||||||
|
value = self._params[name]
|
||||||
|
try:
|
||||||
value = str(value).lower()
|
value = str(value).lower()
|
||||||
|
except TypeError:
|
||||||
|
pass
|
||||||
if value in ['1', 'y', 'yes', 'yeah', 'yep', 'yup', 't', 'true']:
|
if value in ['1', 'y', 'yes', 'yeah', 'yep', 'yup', 't', 'true']:
|
||||||
return True
|
return True
|
||||||
if value in ['0', 'n', 'no', 'nope', 'f', 'false']:
|
if value in ['0', 'n', 'no', 'nope', 'f', 'false']:
|
||||||
return False
|
return False
|
||||||
raise errors.InvalidParameterError(
|
raise errors.InvalidParameterError(
|
||||||
'The value must be a boolean value.')
|
'Parameter %r must be a boolean value.' % name)
|
||||||
|
|
|
@ -1,11 +1,19 @@
|
||||||
|
from typing import Callable, Type, Dict
|
||||||
|
|
||||||
|
|
||||||
error_handlers = {} # pylint: disable=invalid-name
|
error_handlers = {} # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class BaseHttpError(RuntimeError):
|
class BaseHttpError(RuntimeError):
|
||||||
code = None
|
code = -1
|
||||||
reason = None
|
reason = ''
|
||||||
|
|
||||||
def __init__(self, name, description, title=None, extra_fields=None):
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
title: str=None,
|
||||||
|
extra_fields: Dict[str, str]=None) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# error name for programmers
|
# error name for programmers
|
||||||
self.name = name
|
self.name = name
|
||||||
|
@ -52,5 +60,7 @@ class HttpInternalServerError(BaseHttpError):
|
||||||
reason = 'Internal Server Error'
|
reason = 'Internal Server Error'
|
||||||
|
|
||||||
|
|
||||||
def handle(exception_type, handler):
|
def handle(
|
||||||
|
exception_type: Type[Exception],
|
||||||
|
handler: Callable[[Exception], None]) -> None:
|
||||||
error_handlers[exception_type] = handler
|
error_handlers[exception_type] = handler
|
||||||
|
|
|
@ -1,11 +1,15 @@
|
||||||
|
from typing import Callable
|
||||||
|
from szurubooru.rest.context import Context
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
pre_hooks = []
|
pre_hooks = [] # type: List[Callable[[Context], None]]
|
||||||
post_hooks = []
|
post_hooks = [] # type: List[Callable[[Context], None]]
|
||||||
|
|
||||||
|
|
||||||
def pre_hook(handler):
|
def pre_hook(handler: Callable) -> None:
|
||||||
pre_hooks.append(handler)
|
pre_hooks.append(handler)
|
||||||
|
|
||||||
|
|
||||||
def post_hook(handler):
|
def post_hook(handler: Callable) -> None:
|
||||||
post_hooks.insert(0, handler)
|
post_hooks.insert(0, handler)
|
||||||
|
|
|
@ -1,32 +1,36 @@
|
||||||
|
from typing import Callable, Dict, Any
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from szurubooru.rest.context import Context, Response
|
||||||
|
|
||||||
|
|
||||||
routes = defaultdict(dict) # pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
|
RouteHandler = Callable[[Context, Dict[str, str]], Response]
|
||||||
|
routes = defaultdict(dict) # type: Dict[str, Dict[str, RouteHandler]]
|
||||||
|
|
||||||
|
|
||||||
def get(url):
|
def get(url: str) -> Callable[[RouteHandler], RouteHandler]:
|
||||||
def wrapper(handler):
|
def wrapper(handler: RouteHandler) -> RouteHandler:
|
||||||
routes[url]['GET'] = handler
|
routes[url]['GET'] = handler
|
||||||
return handler
|
return handler
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def put(url):
|
def put(url: str) -> Callable[[RouteHandler], RouteHandler]:
|
||||||
def wrapper(handler):
|
def wrapper(handler: RouteHandler) -> RouteHandler:
|
||||||
routes[url]['PUT'] = handler
|
routes[url]['PUT'] = handler
|
||||||
return handler
|
return handler
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def post(url):
|
def post(url: str) -> Callable[[RouteHandler], RouteHandler]:
|
||||||
def wrapper(handler):
|
def wrapper(handler: RouteHandler) -> RouteHandler:
|
||||||
routes[url]['POST'] = handler
|
routes[url]['POST'] = handler
|
||||||
return handler
|
return handler
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def delete(url):
|
def delete(url: str) -> Callable[[RouteHandler], RouteHandler]:
|
||||||
def wrapper(handler):
|
def wrapper(handler: RouteHandler) -> RouteHandler:
|
||||||
routes[url]['DELETE'] = handler
|
routes[url]['DELETE'] = handler
|
||||||
return handler
|
return handler
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
|
@ -1,38 +1,47 @@
|
||||||
from szurubooru.search import tokens
|
from typing import Optional, Tuple, Dict, Callable
|
||||||
|
from szurubooru.search import tokens, criteria
|
||||||
|
from szurubooru.search.query import SearchQuery
|
||||||
|
from szurubooru.search.typing import SaColumn, SaQuery
|
||||||
|
|
||||||
|
Filter = Callable[[SaQuery, Optional[criteria.BaseCriterion], bool], SaQuery]
|
||||||
|
|
||||||
|
|
||||||
class BaseSearchConfig:
|
class BaseSearchConfig:
|
||||||
|
SORT_NONE = tokens.SortToken.SORT_NONE
|
||||||
SORT_ASC = tokens.SortToken.SORT_ASC
|
SORT_ASC = tokens.SortToken.SORT_ASC
|
||||||
SORT_DESC = tokens.SortToken.SORT_DESC
|
SORT_DESC = tokens.SortToken.SORT_DESC
|
||||||
|
|
||||||
def on_search_query_parsed(self, search_query):
|
def on_search_query_parsed(self, search_query: SearchQuery) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def create_filter_query(self, _disable_eager_loads):
|
def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def create_count_query(self, disable_eager_loads):
|
def create_count_query(self, disable_eager_loads: bool) -> SaQuery:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def create_around_query(self):
|
def create_around_query(self) -> SaQuery:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def finalize_query(self, query: SaQuery) -> SaQuery:
|
||||||
|
return query
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id_column(self):
|
def id_column(self) -> SaColumn:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def anonymous_filter(self):
|
def anonymous_filter(self) -> Optional[Filter]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def special_filters(self):
|
def special_filters(self) -> Dict[str, Filter]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def named_filters(self):
|
def named_filters(self) -> Dict[str, Filter]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sort_columns(self):
|
def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
|
||||||
return {}
|
return {}
|
||||||
|
|
|
@ -1,59 +1,62 @@
|
||||||
from sqlalchemy.sql.expression import func
|
from typing import Tuple, Dict
|
||||||
from szurubooru import db
|
import sqlalchemy as sa
|
||||||
|
from szurubooru import db, model
|
||||||
|
from szurubooru.search.typing import SaColumn, SaQuery
|
||||||
from szurubooru.search.configs import util as search_util
|
from szurubooru.search.configs import util as search_util
|
||||||
from szurubooru.search.configs.base_search_config import BaseSearchConfig
|
from szurubooru.search.configs.base_search_config import (
|
||||||
|
BaseSearchConfig, Filter)
|
||||||
|
|
||||||
|
|
||||||
class CommentSearchConfig(BaseSearchConfig):
|
class CommentSearchConfig(BaseSearchConfig):
|
||||||
def create_filter_query(self, _disable_eager_loads):
|
def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery:
|
||||||
return db.session.query(db.Comment).join(db.User)
|
return db.session.query(model.Comment).join(model.User)
|
||||||
|
|
||||||
def create_count_query(self, disable_eager_loads):
|
def create_count_query(self, disable_eager_loads: bool) -> SaQuery:
|
||||||
return self.create_filter_query(disable_eager_loads)
|
return self.create_filter_query(disable_eager_loads)
|
||||||
|
|
||||||
def create_around_query(self):
|
def create_around_query(self) -> SaQuery:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def finalize_query(self, query):
|
def finalize_query(self, query: SaQuery) -> SaQuery:
|
||||||
return query.order_by(db.Comment.creation_time.desc())
|
return query.order_by(model.Comment.creation_time.desc())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def anonymous_filter(self):
|
def anonymous_filter(self) -> SaQuery:
|
||||||
return search_util.create_str_filter(db.Comment.text)
|
return search_util.create_str_filter(model.Comment.text)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def named_filters(self):
|
def named_filters(self) -> Dict[str, Filter]:
|
||||||
return {
|
return {
|
||||||
'id': search_util.create_num_filter(db.Comment.comment_id),
|
'id': search_util.create_num_filter(model.Comment.comment_id),
|
||||||
'post': search_util.create_num_filter(db.Comment.post_id),
|
'post': search_util.create_num_filter(model.Comment.post_id),
|
||||||
'user': search_util.create_str_filter(db.User.name),
|
'user': search_util.create_str_filter(model.User.name),
|
||||||
'author': search_util.create_str_filter(db.User.name),
|
'author': search_util.create_str_filter(model.User.name),
|
||||||
'text': search_util.create_str_filter(db.Comment.text),
|
'text': search_util.create_str_filter(model.Comment.text),
|
||||||
'creation-date':
|
'creation-date':
|
||||||
search_util.create_date_filter(db.Comment.creation_time),
|
search_util.create_date_filter(model.Comment.creation_time),
|
||||||
'creation-time':
|
'creation-time':
|
||||||
search_util.create_date_filter(db.Comment.creation_time),
|
search_util.create_date_filter(model.Comment.creation_time),
|
||||||
'last-edit-date':
|
'last-edit-date':
|
||||||
search_util.create_date_filter(db.Comment.last_edit_time),
|
search_util.create_date_filter(model.Comment.last_edit_time),
|
||||||
'last-edit-time':
|
'last-edit-time':
|
||||||
search_util.create_date_filter(db.Comment.last_edit_time),
|
search_util.create_date_filter(model.Comment.last_edit_time),
|
||||||
'edit-date':
|
'edit-date':
|
||||||
search_util.create_date_filter(db.Comment.last_edit_time),
|
search_util.create_date_filter(model.Comment.last_edit_time),
|
||||||
'edit-time':
|
'edit-time':
|
||||||
search_util.create_date_filter(db.Comment.last_edit_time),
|
search_util.create_date_filter(model.Comment.last_edit_time),
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sort_columns(self):
|
def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
|
||||||
return {
|
return {
|
||||||
'random': (func.random(), None),
|
'random': (sa.sql.expression.func.random(), self.SORT_NONE),
|
||||||
'user': (db.User.name, self.SORT_ASC),
|
'user': (model.User.name, self.SORT_ASC),
|
||||||
'author': (db.User.name, self.SORT_ASC),
|
'author': (model.User.name, self.SORT_ASC),
|
||||||
'post': (db.Comment.post_id, self.SORT_DESC),
|
'post': (model.Comment.post_id, self.SORT_DESC),
|
||||||
'creation-date': (db.Comment.creation_time, self.SORT_DESC),
|
'creation-date': (model.Comment.creation_time, self.SORT_DESC),
|
||||||
'creation-time': (db.Comment.creation_time, self.SORT_DESC),
|
'creation-time': (model.Comment.creation_time, self.SORT_DESC),
|
||||||
'last-edit-date': (db.Comment.last_edit_time, self.SORT_DESC),
|
'last-edit-date': (model.Comment.last_edit_time, self.SORT_DESC),
|
||||||
'last-edit-time': (db.Comment.last_edit_time, self.SORT_DESC),
|
'last-edit-time': (model.Comment.last_edit_time, self.SORT_DESC),
|
||||||
'edit-date': (db.Comment.last_edit_time, self.SORT_DESC),
|
'edit-date': (model.Comment.last_edit_time, self.SORT_DESC),
|
||||||
'edit-time': (db.Comment.last_edit_time, self.SORT_DESC),
|
'edit-time': (model.Comment.last_edit_time, self.SORT_DESC),
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,13 +1,16 @@
|
||||||
from sqlalchemy.orm import subqueryload, lazyload, defer, aliased
|
from typing import Any, Optional, Tuple, Dict
|
||||||
from sqlalchemy.sql.expression import func
|
import sqlalchemy as sa
|
||||||
from szurubooru import db, errors
|
from szurubooru import db, model, errors
|
||||||
from szurubooru.func import util
|
from szurubooru.func import util
|
||||||
from szurubooru.search import criteria, tokens
|
from szurubooru.search import criteria, tokens
|
||||||
|
from szurubooru.search.typing import SaColumn, SaQuery
|
||||||
|
from szurubooru.search.query import SearchQuery
|
||||||
from szurubooru.search.configs import util as search_util
|
from szurubooru.search.configs import util as search_util
|
||||||
from szurubooru.search.configs.base_search_config import BaseSearchConfig
|
from szurubooru.search.configs.base_search_config import (
|
||||||
|
BaseSearchConfig, Filter)
|
||||||
|
|
||||||
|
|
||||||
def _enum_transformer(available_values, value):
|
def _enum_transformer(available_values: Dict[str, Any], value: str) -> str:
|
||||||
try:
|
try:
|
||||||
return available_values[value.lower()]
|
return available_values[value.lower()]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
@ -16,71 +19,82 @@ def _enum_transformer(available_values, value):
|
||||||
value, list(sorted(available_values.keys()))))
|
value, list(sorted(available_values.keys()))))
|
||||||
|
|
||||||
|
|
||||||
def _type_transformer(value):
|
def _type_transformer(value: str) -> str:
|
||||||
available_values = {
|
available_values = {
|
||||||
'image': db.Post.TYPE_IMAGE,
|
'image': model.Post.TYPE_IMAGE,
|
||||||
'animation': db.Post.TYPE_ANIMATION,
|
'animation': model.Post.TYPE_ANIMATION,
|
||||||
'animated': db.Post.TYPE_ANIMATION,
|
'animated': model.Post.TYPE_ANIMATION,
|
||||||
'anim': db.Post.TYPE_ANIMATION,
|
'anim': model.Post.TYPE_ANIMATION,
|
||||||
'gif': db.Post.TYPE_ANIMATION,
|
'gif': model.Post.TYPE_ANIMATION,
|
||||||
'video': db.Post.TYPE_VIDEO,
|
'video': model.Post.TYPE_VIDEO,
|
||||||
'webm': db.Post.TYPE_VIDEO,
|
'webm': model.Post.TYPE_VIDEO,
|
||||||
'flash': db.Post.TYPE_FLASH,
|
'flash': model.Post.TYPE_FLASH,
|
||||||
'swf': db.Post.TYPE_FLASH,
|
'swf': model.Post.TYPE_FLASH,
|
||||||
}
|
}
|
||||||
return _enum_transformer(available_values, value)
|
return _enum_transformer(available_values, value)
|
||||||
|
|
||||||
|
|
||||||
def _safety_transformer(value):
|
def _safety_transformer(value: str) -> str:
|
||||||
available_values = {
|
available_values = {
|
||||||
'safe': db.Post.SAFETY_SAFE,
|
'safe': model.Post.SAFETY_SAFE,
|
||||||
'sketchy': db.Post.SAFETY_SKETCHY,
|
'sketchy': model.Post.SAFETY_SKETCHY,
|
||||||
'questionable': db.Post.SAFETY_SKETCHY,
|
'questionable': model.Post.SAFETY_SKETCHY,
|
||||||
'unsafe': db.Post.SAFETY_UNSAFE,
|
'unsafe': model.Post.SAFETY_UNSAFE,
|
||||||
}
|
}
|
||||||
return _enum_transformer(available_values, value)
|
return _enum_transformer(available_values, value)
|
||||||
|
|
||||||
|
|
||||||
def _create_score_filter(score):
|
def _create_score_filter(score: int) -> Filter:
|
||||||
def wrapper(query, criterion, negated):
|
def wrapper(
|
||||||
|
query: SaQuery,
|
||||||
|
criterion: Optional[criteria.BaseCriterion],
|
||||||
|
negated: bool) -> SaQuery:
|
||||||
|
assert criterion
|
||||||
if not getattr(criterion, 'internal', False):
|
if not getattr(criterion, 'internal', False):
|
||||||
raise errors.SearchError(
|
raise errors.SearchError(
|
||||||
'Votes cannot be seen publicly. Did you mean %r?'
|
'Votes cannot be seen publicly. Did you mean %r?'
|
||||||
% 'special:liked')
|
% 'special:liked')
|
||||||
user_alias = aliased(db.User)
|
user_alias = sa.orm.aliased(model.User)
|
||||||
score_alias = aliased(db.PostScore)
|
score_alias = sa.orm.aliased(model.PostScore)
|
||||||
expr = score_alias.score == score
|
expr = score_alias.score == score
|
||||||
expr = expr & search_util.apply_str_criterion_to_column(
|
expr = expr & search_util.apply_str_criterion_to_column(
|
||||||
user_alias.name, criterion)
|
user_alias.name, criterion)
|
||||||
if negated:
|
if negated:
|
||||||
expr = ~expr
|
expr = ~expr
|
||||||
ret = query \
|
ret = query \
|
||||||
.join(score_alias, score_alias.post_id == db.Post.post_id) \
|
.join(score_alias, score_alias.post_id == model.Post.post_id) \
|
||||||
.join(user_alias, user_alias.user_id == score_alias.user_id) \
|
.join(user_alias, user_alias.user_id == score_alias.user_id) \
|
||||||
.filter(expr)
|
.filter(expr)
|
||||||
return ret
|
return ret
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def _create_user_filter():
|
def _create_user_filter() -> Filter:
|
||||||
def wrapper(query, criterion, negated):
|
def wrapper(
|
||||||
|
query: SaQuery,
|
||||||
|
criterion: Optional[criteria.BaseCriterion],
|
||||||
|
negated: bool) -> SaQuery:
|
||||||
|
assert criterion
|
||||||
if isinstance(criterion, criteria.PlainCriterion) \
|
if isinstance(criterion, criteria.PlainCriterion) \
|
||||||
and not criterion.value:
|
and not criterion.value:
|
||||||
# pylint: disable=singleton-comparison
|
# pylint: disable=singleton-comparison
|
||||||
expr = db.Post.user_id == None
|
expr = model.Post.user_id == None
|
||||||
if negated:
|
if negated:
|
||||||
expr = ~expr
|
expr = ~expr
|
||||||
return query.filter(expr)
|
return query.filter(expr)
|
||||||
return search_util.create_subquery_filter(
|
return search_util.create_subquery_filter(
|
||||||
db.Post.user_id,
|
model.Post.user_id,
|
||||||
db.User.user_id,
|
model.User.user_id,
|
||||||
db.User.name,
|
model.User.name,
|
||||||
search_util.create_str_filter)(query, criterion, negated)
|
search_util.create_str_filter)(query, criterion, negated)
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class PostSearchConfig(BaseSearchConfig):
|
class PostSearchConfig(BaseSearchConfig):
|
||||||
def on_search_query_parsed(self, search_query):
|
def __init__(self) -> None:
|
||||||
|
self.user = None # type: Optional[model.User]
|
||||||
|
|
||||||
|
def on_search_query_parsed(self, search_query: SearchQuery) -> SaQuery:
|
||||||
new_special_tokens = []
|
new_special_tokens = []
|
||||||
for token in search_query.special_tokens:
|
for token in search_query.special_tokens:
|
||||||
if token.value in ('fav', 'liked', 'disliked'):
|
if token.value in ('fav', 'liked', 'disliked'):
|
||||||
|
@ -91,7 +105,7 @@ class PostSearchConfig(BaseSearchConfig):
|
||||||
criterion = criteria.PlainCriterion(
|
criterion = criteria.PlainCriterion(
|
||||||
original_text=self.user.name,
|
original_text=self.user.name,
|
||||||
value=self.user.name)
|
value=self.user.name)
|
||||||
criterion.internal = True
|
setattr(criterion, 'internal', True)
|
||||||
search_query.named_tokens.append(
|
search_query.named_tokens.append(
|
||||||
tokens.NamedToken(
|
tokens.NamedToken(
|
||||||
name=token.value,
|
name=token.value,
|
||||||
|
@ -101,160 +115,324 @@ class PostSearchConfig(BaseSearchConfig):
|
||||||
new_special_tokens.append(token)
|
new_special_tokens.append(token)
|
||||||
search_query.special_tokens = new_special_tokens
|
search_query.special_tokens = new_special_tokens
|
||||||
|
|
||||||
def create_around_query(self):
|
def create_around_query(self) -> SaQuery:
|
||||||
return db.session.query(db.Post).options(lazyload('*'))
|
return db.session.query(model.Post).options(sa.orm.lazyload('*'))
|
||||||
|
|
||||||
def create_filter_query(self, disable_eager_loads):
|
def create_filter_query(self, disable_eager_loads: bool) -> SaQuery:
|
||||||
strategy = lazyload if disable_eager_loads else subqueryload
|
strategy = (
|
||||||
return db.session.query(db.Post) \
|
sa.orm.lazyload
|
||||||
|
if disable_eager_loads
|
||||||
|
else sa.orm.subqueryload)
|
||||||
|
return db.session.query(model.Post) \
|
||||||
.options(
|
.options(
|
||||||
lazyload('*'),
|
sa.orm.lazyload('*'),
|
||||||
# use config optimized for official client
|
# use config optimized for official client
|
||||||
# defer(db.Post.score),
|
# sa.orm.defer(model.Post.score),
|
||||||
# defer(db.Post.favorite_count),
|
# sa.orm.defer(model.Post.favorite_count),
|
||||||
# defer(db.Post.comment_count),
|
# sa.orm.defer(model.Post.comment_count),
|
||||||
defer(db.Post.last_favorite_time),
|
sa.orm.defer(model.Post.last_favorite_time),
|
||||||
defer(db.Post.feature_count),
|
sa.orm.defer(model.Post.feature_count),
|
||||||
defer(db.Post.last_feature_time),
|
sa.orm.defer(model.Post.last_feature_time),
|
||||||
defer(db.Post.last_comment_creation_time),
|
sa.orm.defer(model.Post.last_comment_creation_time),
|
||||||
defer(db.Post.last_comment_edit_time),
|
sa.orm.defer(model.Post.last_comment_edit_time),
|
||||||
defer(db.Post.note_count),
|
sa.orm.defer(model.Post.note_count),
|
||||||
defer(db.Post.tag_count),
|
sa.orm.defer(model.Post.tag_count),
|
||||||
strategy(db.Post.tags).subqueryload(db.Tag.names),
|
strategy(model.Post.tags).subqueryload(model.Tag.names),
|
||||||
strategy(db.Post.tags).defer(db.Tag.post_count),
|
strategy(model.Post.tags).defer(model.Tag.post_count),
|
||||||
strategy(db.Post.tags).lazyload(db.Tag.implications),
|
strategy(model.Post.tags).lazyload(model.Tag.implications),
|
||||||
strategy(db.Post.tags).lazyload(db.Tag.suggestions))
|
strategy(model.Post.tags).lazyload(model.Tag.suggestions))
|
||||||
|
|
||||||
def create_count_query(self, _disable_eager_loads):
|
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
|
||||||
return db.session.query(db.Post)
|
return db.session.query(model.Post)
|
||||||
|
|
||||||
def finalize_query(self, query):
|
def finalize_query(self, query: SaQuery) -> SaQuery:
|
||||||
return query.order_by(db.Post.post_id.desc())
|
return query.order_by(model.Post.post_id.desc())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id_column(self):
|
def id_column(self) -> SaColumn:
|
||||||
return db.Post.post_id
|
return model.Post.post_id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def anonymous_filter(self):
|
def anonymous_filter(self) -> Optional[Filter]:
|
||||||
return search_util.create_subquery_filter(
|
return search_util.create_subquery_filter(
|
||||||
db.Post.post_id,
|
model.Post.post_id,
|
||||||
db.PostTag.post_id,
|
model.PostTag.post_id,
|
||||||
db.TagName.name,
|
model.TagName.name,
|
||||||
search_util.create_str_filter,
|
search_util.create_str_filter,
|
||||||
lambda subquery: subquery.join(db.Tag).join(db.TagName))
|
lambda subquery: subquery.join(model.Tag).join(model.TagName))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def named_filters(self):
|
def named_filters(self) -> Dict[str, Filter]:
|
||||||
return util.unalias_dict({
|
return util.unalias_dict([
|
||||||
'id': search_util.create_num_filter(db.Post.post_id),
|
(
|
||||||
'tag': search_util.create_subquery_filter(
|
['id'],
|
||||||
db.Post.post_id,
|
search_util.create_num_filter(model.Post.post_id)
|
||||||
db.PostTag.post_id,
|
),
|
||||||
db.TagName.name,
|
|
||||||
|
(
|
||||||
|
['tag'],
|
||||||
|
search_util.create_subquery_filter(
|
||||||
|
model.Post.post_id,
|
||||||
|
model.PostTag.post_id,
|
||||||
|
model.TagName.name,
|
||||||
search_util.create_str_filter,
|
search_util.create_str_filter,
|
||||||
lambda subquery: subquery.join(db.Tag).join(db.TagName)),
|
lambda subquery:
|
||||||
'score': search_util.create_num_filter(db.Post.score),
|
subquery.join(model.Tag).join(model.TagName))
|
||||||
('uploader', 'upload', 'submit'):
|
),
|
||||||
_create_user_filter(),
|
|
||||||
'comment': search_util.create_subquery_filter(
|
(
|
||||||
db.Post.post_id,
|
['score'],
|
||||||
db.Comment.post_id,
|
search_util.create_num_filter(model.Post.score)
|
||||||
db.User.name,
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['uploader', 'upload', 'submit'],
|
||||||
|
_create_user_filter()
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['comment'],
|
||||||
|
search_util.create_subquery_filter(
|
||||||
|
model.Post.post_id,
|
||||||
|
model.Comment.post_id,
|
||||||
|
model.User.name,
|
||||||
search_util.create_str_filter,
|
search_util.create_str_filter,
|
||||||
lambda subquery: subquery.join(db.User)),
|
lambda subquery: subquery.join(model.User))
|
||||||
'fav': search_util.create_subquery_filter(
|
),
|
||||||
db.Post.post_id,
|
|
||||||
db.PostFavorite.post_id,
|
(
|
||||||
db.User.name,
|
['fav'],
|
||||||
|
search_util.create_subquery_filter(
|
||||||
|
model.Post.post_id,
|
||||||
|
model.PostFavorite.post_id,
|
||||||
|
model.User.name,
|
||||||
search_util.create_str_filter,
|
search_util.create_str_filter,
|
||||||
lambda subquery: subquery.join(db.User)),
|
lambda subquery: subquery.join(model.User))
|
||||||
'liked': _create_score_filter(1),
|
),
|
||||||
'disliked': _create_score_filter(-1),
|
|
||||||
'tag-count': search_util.create_num_filter(db.Post.tag_count),
|
(
|
||||||
'comment-count':
|
['liked'],
|
||||||
search_util.create_num_filter(db.Post.comment_count),
|
_create_score_filter(1)
|
||||||
'fav-count':
|
),
|
||||||
search_util.create_num_filter(db.Post.favorite_count),
|
(
|
||||||
'note-count': search_util.create_num_filter(db.Post.note_count),
|
['disliked'],
|
||||||
'relation-count':
|
_create_score_filter(-1)
|
||||||
search_util.create_num_filter(db.Post.relation_count),
|
),
|
||||||
'feature-count':
|
|
||||||
search_util.create_num_filter(db.Post.feature_count),
|
(
|
||||||
'type':
|
['tag-count'],
|
||||||
|
search_util.create_num_filter(model.Post.tag_count)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['comment-count'],
|
||||||
|
search_util.create_num_filter(model.Post.comment_count)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['fav-count'],
|
||||||
|
search_util.create_num_filter(model.Post.favorite_count)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['note-count'],
|
||||||
|
search_util.create_num_filter(model.Post.note_count)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['relation-count'],
|
||||||
|
search_util.create_num_filter(model.Post.relation_count)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['feature-count'],
|
||||||
|
search_util.create_num_filter(model.Post.feature_count)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['type'],
|
||||||
search_util.create_str_filter(
|
search_util.create_str_filter(
|
||||||
db.Post.type, _type_transformer),
|
model.Post.type, _type_transformer)
|
||||||
'content-checksum': search_util.create_str_filter(
|
),
|
||||||
db.Post.checksum),
|
|
||||||
'file-size': search_util.create_num_filter(db.Post.file_size),
|
(
|
||||||
('image-width', 'width'):
|
['content-checksum'],
|
||||||
search_util.create_num_filter(db.Post.canvas_width),
|
search_util.create_str_filter(model.Post.checksum)
|
||||||
('image-height', 'height'):
|
),
|
||||||
search_util.create_num_filter(db.Post.canvas_height),
|
|
||||||
('image-area', 'area'):
|
(
|
||||||
search_util.create_num_filter(db.Post.canvas_area),
|
['file-size'],
|
||||||
('creation-date', 'creation-time', 'date', 'time'):
|
search_util.create_num_filter(model.Post.file_size)
|
||||||
search_util.create_date_filter(db.Post.creation_time),
|
),
|
||||||
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'):
|
|
||||||
search_util.create_date_filter(db.Post.last_edit_time),
|
(
|
||||||
('comment-date', 'comment-time'):
|
['image-width', 'width'],
|
||||||
|
search_util.create_num_filter(model.Post.canvas_width)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['image-height', 'height'],
|
||||||
|
search_util.create_num_filter(model.Post.canvas_height)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['image-area', 'area'],
|
||||||
|
search_util.create_num_filter(model.Post.canvas_area)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['creation-date', 'creation-time', 'date', 'time'],
|
||||||
|
search_util.create_date_filter(model.Post.creation_time)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'],
|
||||||
|
search_util.create_date_filter(model.Post.last_edit_time)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['comment-date', 'comment-time'],
|
||||||
search_util.create_date_filter(
|
search_util.create_date_filter(
|
||||||
db.Post.last_comment_creation_time),
|
model.Post.last_comment_creation_time)
|
||||||
('fav-date', 'fav-time'):
|
),
|
||||||
search_util.create_date_filter(db.Post.last_favorite_time),
|
|
||||||
('feature-date', 'feature-time'):
|
(
|
||||||
search_util.create_date_filter(db.Post.last_feature_time),
|
['fav-date', 'fav-time'],
|
||||||
('safety', 'rating'):
|
search_util.create_date_filter(model.Post.last_favorite_time)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['feature-date', 'feature-time'],
|
||||||
|
search_util.create_date_filter(model.Post.last_feature_time)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['safety', 'rating'],
|
||||||
search_util.create_str_filter(
|
search_util.create_str_filter(
|
||||||
db.Post.safety, _safety_transformer),
|
model.Post.safety, _safety_transformer)
|
||||||
})
|
),
|
||||||
|
])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sort_columns(self):
|
def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
|
||||||
return util.unalias_dict({
|
return util.unalias_dict([
|
||||||
'random': (func.random(), None),
|
(
|
||||||
'id': (db.Post.post_id, self.SORT_DESC),
|
['random'],
|
||||||
'score': (db.Post.score, self.SORT_DESC),
|
(sa.sql.expression.func.random(), self.SORT_NONE)
|
||||||
'tag-count': (db.Post.tag_count, self.SORT_DESC),
|
),
|
||||||
'comment-count': (db.Post.comment_count, self.SORT_DESC),
|
|
||||||
'fav-count': (db.Post.favorite_count, self.SORT_DESC),
|
(
|
||||||
'note-count': (db.Post.note_count, self.SORT_DESC),
|
['id'],
|
||||||
'relation-count': (db.Post.relation_count, self.SORT_DESC),
|
(model.Post.post_id, self.SORT_DESC)
|
||||||
'feature-count': (db.Post.feature_count, self.SORT_DESC),
|
),
|
||||||
'file-size': (db.Post.file_size, self.SORT_DESC),
|
|
||||||
('image-width', 'width'):
|
(
|
||||||
(db.Post.canvas_width, self.SORT_DESC),
|
['score'],
|
||||||
('image-height', 'height'):
|
(model.Post.score, self.SORT_DESC)
|
||||||
(db.Post.canvas_height, self.SORT_DESC),
|
),
|
||||||
('image-area', 'area'):
|
|
||||||
(db.Post.canvas_area, self.SORT_DESC),
|
(
|
||||||
('creation-date', 'creation-time', 'date', 'time'):
|
['tag-count'],
|
||||||
(db.Post.creation_time, self.SORT_DESC),
|
(model.Post.tag_count, self.SORT_DESC)
|
||||||
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'):
|
),
|
||||||
(db.Post.last_edit_time, self.SORT_DESC),
|
|
||||||
('comment-date', 'comment-time'):
|
(
|
||||||
(db.Post.last_comment_creation_time, self.SORT_DESC),
|
['comment-count'],
|
||||||
('fav-date', 'fav-time'):
|
(model.Post.comment_count, self.SORT_DESC)
|
||||||
(db.Post.last_favorite_time, self.SORT_DESC),
|
),
|
||||||
('feature-date', 'feature-time'):
|
|
||||||
(db.Post.last_feature_time, self.SORT_DESC),
|
(
|
||||||
})
|
['fav-count'],
|
||||||
|
(model.Post.favorite_count, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['note-count'],
|
||||||
|
(model.Post.note_count, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['relation-count'],
|
||||||
|
(model.Post.relation_count, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['feature-count'],
|
||||||
|
(model.Post.feature_count, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['file-size'],
|
||||||
|
(model.Post.file_size, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['image-width', 'width'],
|
||||||
|
(model.Post.canvas_width, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['image-height', 'height'],
|
||||||
|
(model.Post.canvas_height, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['image-area', 'area'],
|
||||||
|
(model.Post.canvas_area, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['creation-date', 'creation-time', 'date', 'time'],
|
||||||
|
(model.Post.creation_time, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'],
|
||||||
|
(model.Post.last_edit_time, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['comment-date', 'comment-time'],
|
||||||
|
(model.Post.last_comment_creation_time, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['fav-date', 'fav-time'],
|
||||||
|
(model.Post.last_favorite_time, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['feature-date', 'feature-time'],
|
||||||
|
(model.Post.last_feature_time, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def special_filters(self):
|
def special_filters(self) -> Dict[str, Filter]:
|
||||||
return {
|
return {
|
||||||
# handled by parsed
|
# handled by parser
|
||||||
'fav': None,
|
'fav': self.noop_filter,
|
||||||
'liked': None,
|
'liked': self.noop_filter,
|
||||||
'disliked': None,
|
'disliked': self.noop_filter,
|
||||||
'tumbleweed': self.tumbleweed_filter,
|
'tumbleweed': self.tumbleweed_filter,
|
||||||
}
|
}
|
||||||
|
|
||||||
def tumbleweed_filter(self, query, negated):
|
def noop_filter(
|
||||||
expr = \
|
self,
|
||||||
(db.Post.comment_count == 0) \
|
query: SaQuery,
|
||||||
& (db.Post.favorite_count == 0) \
|
_criterion: Optional[criteria.BaseCriterion],
|
||||||
& (db.Post.score == 0)
|
_negated: bool) -> SaQuery:
|
||||||
|
return query
|
||||||
|
|
||||||
|
def tumbleweed_filter(
|
||||||
|
self,
|
||||||
|
query: SaQuery,
|
||||||
|
_criterion: Optional[criteria.BaseCriterion],
|
||||||
|
negated: bool) -> SaQuery:
|
||||||
|
expr = (
|
||||||
|
(model.Post.comment_count == 0)
|
||||||
|
& (model.Post.favorite_count == 0)
|
||||||
|
& (model.Post.score == 0))
|
||||||
if negated:
|
if negated:
|
||||||
expr = ~expr
|
expr = ~expr
|
||||||
return query.filter(expr)
|
return query.filter(expr)
|
||||||
|
|
|
@ -1,28 +1,37 @@
|
||||||
from szurubooru import db
|
from typing import Dict
|
||||||
|
from szurubooru import db, model
|
||||||
|
from szurubooru.search.typing import SaQuery
|
||||||
from szurubooru.search.configs import util as search_util
|
from szurubooru.search.configs import util as search_util
|
||||||
from szurubooru.search.configs.base_search_config import BaseSearchConfig
|
from szurubooru.search.configs.base_search_config import (
|
||||||
|
BaseSearchConfig, Filter)
|
||||||
|
|
||||||
|
|
||||||
class SnapshotSearchConfig(BaseSearchConfig):
|
class SnapshotSearchConfig(BaseSearchConfig):
|
||||||
def create_filter_query(self, _disable_eager_loads):
|
def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery:
|
||||||
return db.session.query(db.Snapshot)
|
return db.session.query(model.Snapshot)
|
||||||
|
|
||||||
def create_count_query(self, _disable_eager_loads):
|
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
|
||||||
return db.session.query(db.Snapshot)
|
return db.session.query(model.Snapshot)
|
||||||
|
|
||||||
def create_around_query(self):
|
def create_around_query(self) -> SaQuery:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def finalize_query(self, query):
|
def finalize_query(self, query: SaQuery) -> SaQuery:
|
||||||
return query.order_by(db.Snapshot.creation_time.desc())
|
return query.order_by(model.Snapshot.creation_time.desc())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def named_filters(self):
|
def named_filters(self) -> Dict[str, Filter]:
|
||||||
return {
|
return {
|
||||||
'type': search_util.create_str_filter(db.Snapshot.resource_type),
|
'type':
|
||||||
'id': search_util.create_str_filter(db.Snapshot.resource_name),
|
search_util.create_str_filter(model.Snapshot.resource_type),
|
||||||
'date': search_util.create_date_filter(db.Snapshot.creation_time),
|
'id':
|
||||||
'time': search_util.create_date_filter(db.Snapshot.creation_time),
|
search_util.create_str_filter(model.Snapshot.resource_name),
|
||||||
'operation': search_util.create_str_filter(db.Snapshot.operation),
|
'date':
|
||||||
'user': search_util.create_str_filter(db.User.name),
|
search_util.create_date_filter(model.Snapshot.creation_time),
|
||||||
|
'time':
|
||||||
|
search_util.create_date_filter(model.Snapshot.creation_time),
|
||||||
|
'operation':
|
||||||
|
search_util.create_str_filter(model.Snapshot.operation),
|
||||||
|
'user':
|
||||||
|
search_util.create_str_filter(model.User.name),
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,79 +1,134 @@
|
||||||
from sqlalchemy.orm import subqueryload, lazyload, defer
|
from typing import Tuple, Dict
|
||||||
from sqlalchemy.sql.expression import func
|
import sqlalchemy as sa
|
||||||
from szurubooru import db
|
from szurubooru import db, model
|
||||||
from szurubooru.func import util
|
from szurubooru.func import util
|
||||||
|
from szurubooru.search.typing import SaColumn, SaQuery
|
||||||
from szurubooru.search.configs import util as search_util
|
from szurubooru.search.configs import util as search_util
|
||||||
from szurubooru.search.configs.base_search_config import BaseSearchConfig
|
from szurubooru.search.configs.base_search_config import (
|
||||||
|
BaseSearchConfig, Filter)
|
||||||
|
|
||||||
|
|
||||||
class TagSearchConfig(BaseSearchConfig):
|
class TagSearchConfig(BaseSearchConfig):
|
||||||
def create_filter_query(self, _disable_eager_loads):
|
def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery:
|
||||||
strategy = lazyload if _disable_eager_loads else subqueryload
|
strategy = (
|
||||||
return db.session.query(db.Tag) \
|
sa.orm.lazyload
|
||||||
.join(db.TagCategory) \
|
if _disable_eager_loads
|
||||||
|
else sa.orm.subqueryload)
|
||||||
|
return db.session.query(model.Tag) \
|
||||||
|
.join(model.TagCategory) \
|
||||||
.options(
|
.options(
|
||||||
defer(db.Tag.first_name),
|
sa.orm.defer(model.Tag.first_name),
|
||||||
defer(db.Tag.suggestion_count),
|
sa.orm.defer(model.Tag.suggestion_count),
|
||||||
defer(db.Tag.implication_count),
|
sa.orm.defer(model.Tag.implication_count),
|
||||||
defer(db.Tag.post_count),
|
sa.orm.defer(model.Tag.post_count),
|
||||||
strategy(db.Tag.names),
|
strategy(model.Tag.names),
|
||||||
strategy(db.Tag.suggestions).joinedload(db.Tag.names),
|
strategy(model.Tag.suggestions).joinedload(model.Tag.names),
|
||||||
strategy(db.Tag.implications).joinedload(db.Tag.names))
|
strategy(model.Tag.implications).joinedload(model.Tag.names))
|
||||||
|
|
||||||
def create_count_query(self, _disable_eager_loads):
|
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
|
||||||
return db.session.query(db.Tag)
|
return db.session.query(model.Tag)
|
||||||
|
|
||||||
def create_around_query(self):
|
def create_around_query(self) -> SaQuery:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def finalize_query(self, query):
|
def finalize_query(self, query: SaQuery) -> SaQuery:
|
||||||
return query.order_by(db.Tag.first_name.asc())
|
return query.order_by(model.Tag.first_name.asc())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def anonymous_filter(self):
|
def anonymous_filter(self) -> Filter:
|
||||||
return search_util.create_subquery_filter(
|
return search_util.create_subquery_filter(
|
||||||
db.Tag.tag_id,
|
model.Tag.tag_id,
|
||||||
db.TagName.tag_id,
|
model.TagName.tag_id,
|
||||||
db.TagName.name,
|
model.TagName.name,
|
||||||
search_util.create_str_filter)
|
search_util.create_str_filter)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def named_filters(self):
|
def named_filters(self) -> Dict[str, Filter]:
|
||||||
return util.unalias_dict({
|
return util.unalias_dict([
|
||||||
'name': search_util.create_subquery_filter(
|
(
|
||||||
db.Tag.tag_id,
|
['name'],
|
||||||
db.TagName.tag_id,
|
search_util.create_subquery_filter(
|
||||||
db.TagName.name,
|
model.Tag.tag_id,
|
||||||
search_util.create_str_filter),
|
model.TagName.tag_id,
|
||||||
'category': search_util.create_subquery_filter(
|
model.TagName.name,
|
||||||
db.Tag.category_id,
|
search_util.create_str_filter)
|
||||||
db.TagCategory.tag_category_id,
|
),
|
||||||
db.TagCategory.name,
|
|
||||||
search_util.create_str_filter),
|
(
|
||||||
('creation-date', 'creation-time'):
|
['category'],
|
||||||
search_util.create_date_filter(db.Tag.creation_time),
|
search_util.create_subquery_filter(
|
||||||
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'):
|
model.Tag.category_id,
|
||||||
search_util.create_date_filter(db.Tag.last_edit_time),
|
model.TagCategory.tag_category_id,
|
||||||
('usage-count', 'post-count', 'usages'):
|
model.TagCategory.name,
|
||||||
search_util.create_num_filter(db.Tag.post_count),
|
search_util.create_str_filter)
|
||||||
'suggestion-count':
|
),
|
||||||
search_util.create_num_filter(db.Tag.suggestion_count),
|
|
||||||
'implication-count':
|
(
|
||||||
search_util.create_num_filter(db.Tag.implication_count),
|
['creation-date', 'creation-time'],
|
||||||
})
|
search_util.create_date_filter(model.Tag.creation_time)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'],
|
||||||
|
search_util.create_date_filter(model.Tag.last_edit_time)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['usage-count', 'post-count', 'usages'],
|
||||||
|
search_util.create_num_filter(model.Tag.post_count)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['suggestion-count'],
|
||||||
|
search_util.create_num_filter(model.Tag.suggestion_count)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['implication-count'],
|
||||||
|
search_util.create_num_filter(model.Tag.implication_count)
|
||||||
|
),
|
||||||
|
])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sort_columns(self):
|
def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
|
||||||
return util.unalias_dict({
|
return util.unalias_dict([
|
||||||
'random': (func.random(), None),
|
(
|
||||||
'name': (db.Tag.first_name, self.SORT_ASC),
|
['random'],
|
||||||
'category': (db.TagCategory.name, self.SORT_ASC),
|
(sa.sql.expression.func.random(), self.SORT_NONE)
|
||||||
('creation-date', 'creation-time'):
|
),
|
||||||
(db.Tag.creation_time, self.SORT_DESC),
|
|
||||||
('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'):
|
(
|
||||||
(db.Tag.last_edit_time, self.SORT_DESC),
|
['name'],
|
||||||
('usage-count', 'post-count', 'usages'):
|
(model.Tag.first_name, self.SORT_ASC)
|
||||||
(db.Tag.post_count, self.SORT_DESC),
|
),
|
||||||
'suggestion-count': (db.Tag.suggestion_count, self.SORT_DESC),
|
|
||||||
'implication-count': (db.Tag.implication_count, self.SORT_DESC),
|
(
|
||||||
})
|
['category'],
|
||||||
|
(model.TagCategory.name, self.SORT_ASC)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['creation-date', 'creation-time'],
|
||||||
|
(model.Tag.creation_time, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'],
|
||||||
|
(model.Tag.last_edit_time, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['usage-count', 'post-count', 'usages'],
|
||||||
|
(model.Tag.post_count, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['suggestion-count'],
|
||||||
|
(model.Tag.suggestion_count, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
|
||||||
|
(
|
||||||
|
['implication-count'],
|
||||||
|
(model.Tag.implication_count, self.SORT_DESC)
|
||||||
|
),
|
||||||
|
])
|
||||||
|
|
|
@ -1,53 +1,57 @@
|
||||||
from sqlalchemy.sql.expression import func
|
from typing import Tuple, Dict
|
||||||
from szurubooru import db
|
import sqlalchemy as sa
|
||||||
|
from szurubooru import db, model
|
||||||
|
from szurubooru.search.typing import SaColumn, SaQuery
|
||||||
from szurubooru.search.configs import util as search_util
|
from szurubooru.search.configs import util as search_util
|
||||||
from szurubooru.search.configs.base_search_config import BaseSearchConfig
|
from szurubooru.search.configs.base_search_config import (
|
||||||
|
BaseSearchConfig, Filter)
|
||||||
|
|
||||||
|
|
||||||
class UserSearchConfig(BaseSearchConfig):
|
class UserSearchConfig(BaseSearchConfig):
|
||||||
def create_filter_query(self, _disable_eager_loads):
|
def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery:
|
||||||
return db.session.query(db.User)
|
return db.session.query(model.User)
|
||||||
|
|
||||||
def create_count_query(self, _disable_eager_loads):
|
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
|
||||||
return db.session.query(db.User)
|
return db.session.query(model.User)
|
||||||
|
|
||||||
def create_around_query(self):
|
def create_around_query(self) -> SaQuery:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def finalize_query(self, query):
|
def finalize_query(self, query: SaQuery) -> SaQuery:
|
||||||
return query.order_by(db.User.name.asc())
|
return query.order_by(model.User.name.asc())
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def anonymous_filter(self):
|
def anonymous_filter(self) -> Filter:
|
||||||
return search_util.create_str_filter(db.User.name)
|
return search_util.create_str_filter(model.User.name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def named_filters(self):
|
def named_filters(self) -> Dict[str, Filter]:
|
||||||
return {
|
return {
|
||||||
'name': search_util.create_str_filter(db.User.name),
|
'name':
|
||||||
|
search_util.create_str_filter(model.User.name),
|
||||||
'creation-date':
|
'creation-date':
|
||||||
search_util.create_date_filter(db.User.creation_time),
|
search_util.create_date_filter(model.User.creation_time),
|
||||||
'creation-time':
|
'creation-time':
|
||||||
search_util.create_date_filter(db.User.creation_time),
|
search_util.create_date_filter(model.User.creation_time),
|
||||||
'last-login-date':
|
'last-login-date':
|
||||||
search_util.create_date_filter(db.User.last_login_time),
|
search_util.create_date_filter(model.User.last_login_time),
|
||||||
'last-login-time':
|
'last-login-time':
|
||||||
search_util.create_date_filter(db.User.last_login_time),
|
search_util.create_date_filter(model.User.last_login_time),
|
||||||
'login-date':
|
'login-date':
|
||||||
search_util.create_date_filter(db.User.last_login_time),
|
search_util.create_date_filter(model.User.last_login_time),
|
||||||
'login-time':
|
'login-time':
|
||||||
search_util.create_date_filter(db.User.last_login_time),
|
search_util.create_date_filter(model.User.last_login_time),
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def sort_columns(self):
|
def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
|
||||||
return {
|
return {
|
||||||
'random': (func.random(), None),
|
'random': (sa.sql.expression.func.random(), self.SORT_NONE),
|
||||||
'name': (db.User.name, self.SORT_ASC),
|
'name': (model.User.name, self.SORT_ASC),
|
||||||
'creation-date': (db.User.creation_time, self.SORT_DESC),
|
'creation-date': (model.User.creation_time, self.SORT_DESC),
|
||||||
'creation-time': (db.User.creation_time, self.SORT_DESC),
|
'creation-time': (model.User.creation_time, self.SORT_DESC),
|
||||||
'last-login-date': (db.User.last_login_time, self.SORT_DESC),
|
'last-login-date': (model.User.last_login_time, self.SORT_DESC),
|
||||||
'last-login-time': (db.User.last_login_time, self.SORT_DESC),
|
'last-login-time': (model.User.last_login_time, self.SORT_DESC),
|
||||||
'login-date': (db.User.last_login_time, self.SORT_DESC),
|
'login-date': (model.User.last_login_time, self.SORT_DESC),
|
||||||
'login-time': (db.User.last_login_time, self.SORT_DESC),
|
'login-time': (model.User.last_login_time, self.SORT_DESC),
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,13 @@
|
||||||
import sqlalchemy
|
from typing import Any, Optional, Callable
|
||||||
|
import sqlalchemy as sa
|
||||||
from szurubooru import db, errors
|
from szurubooru import db, errors
|
||||||
from szurubooru.func import util
|
from szurubooru.func import util
|
||||||
from szurubooru.search import criteria
|
from szurubooru.search import criteria
|
||||||
|
from szurubooru.search.typing import SaColumn, SaQuery
|
||||||
|
from szurubooru.search.configs.base_search_config import Filter
|
||||||
|
|
||||||
|
|
||||||
def wildcard_transformer(value):
|
def wildcard_transformer(value: str) -> str:
|
||||||
return (
|
return (
|
||||||
value
|
value
|
||||||
.replace('\\', '\\\\')
|
.replace('\\', '\\\\')
|
||||||
|
@ -13,24 +16,21 @@ def wildcard_transformer(value):
|
||||||
.replace('*', '%'))
|
.replace('*', '%'))
|
||||||
|
|
||||||
|
|
||||||
def apply_num_criterion_to_column(column, criterion):
|
def apply_num_criterion_to_column(
|
||||||
'''
|
column: Any, criterion: criteria.BaseCriterion) -> Any:
|
||||||
Decorate SQLAlchemy filter on given column using supplied criterion.
|
|
||||||
'''
|
|
||||||
try:
|
try:
|
||||||
if isinstance(criterion, criteria.PlainCriterion):
|
if isinstance(criterion, criteria.PlainCriterion):
|
||||||
expr = column == int(criterion.value)
|
expr = column == int(criterion.value)
|
||||||
elif isinstance(criterion, criteria.ArrayCriterion):
|
elif isinstance(criterion, criteria.ArrayCriterion):
|
||||||
expr = column.in_(int(value) for value in criterion.values)
|
expr = column.in_(int(value) for value in criterion.values)
|
||||||
elif isinstance(criterion, criteria.RangedCriterion):
|
elif isinstance(criterion, criteria.RangedCriterion):
|
||||||
assert criterion.min_value != '' \
|
assert criterion.min_value or criterion.max_value
|
||||||
or criterion.max_value != ''
|
if criterion.min_value and criterion.max_value:
|
||||||
if criterion.min_value != '' and criterion.max_value != '':
|
|
||||||
expr = column.between(
|
expr = column.between(
|
||||||
int(criterion.min_value), int(criterion.max_value))
|
int(criterion.min_value), int(criterion.max_value))
|
||||||
elif criterion.min_value != '':
|
elif criterion.min_value:
|
||||||
expr = column >= int(criterion.min_value)
|
expr = column >= int(criterion.min_value)
|
||||||
elif criterion.max_value != '':
|
elif criterion.max_value:
|
||||||
expr = column <= int(criterion.max_value)
|
expr = column <= int(criterion.max_value)
|
||||||
else:
|
else:
|
||||||
assert False
|
assert False
|
||||||
|
@ -40,10 +40,13 @@ def apply_num_criterion_to_column(column, criterion):
|
||||||
return expr
|
return expr
|
||||||
|
|
||||||
|
|
||||||
def create_num_filter(column):
|
def create_num_filter(column: Any) -> Filter:
|
||||||
def wrapper(query, criterion, negated):
|
def wrapper(
|
||||||
expr = apply_num_criterion_to_column(
|
query: SaQuery,
|
||||||
column, criterion)
|
criterion: Optional[criteria.BaseCriterion],
|
||||||
|
negated: bool) -> SaQuery:
|
||||||
|
assert criterion
|
||||||
|
expr = apply_num_criterion_to_column(column, criterion)
|
||||||
if negated:
|
if negated:
|
||||||
expr = ~expr
|
expr = ~expr
|
||||||
return query.filter(expr)
|
return query.filter(expr)
|
||||||
|
@ -51,14 +54,13 @@ def create_num_filter(column):
|
||||||
|
|
||||||
|
|
||||||
def apply_str_criterion_to_column(
|
def apply_str_criterion_to_column(
|
||||||
column, criterion, transformer=wildcard_transformer):
|
column: SaColumn,
|
||||||
'''
|
criterion: criteria.BaseCriterion,
|
||||||
Decorate SQLAlchemy filter on given column using supplied criterion.
|
transformer: Callable[[str], str]=wildcard_transformer) -> SaQuery:
|
||||||
'''
|
|
||||||
if isinstance(criterion, criteria.PlainCriterion):
|
if isinstance(criterion, criteria.PlainCriterion):
|
||||||
expr = column.ilike(transformer(criterion.value))
|
expr = column.ilike(transformer(criterion.value))
|
||||||
elif isinstance(criterion, criteria.ArrayCriterion):
|
elif isinstance(criterion, criteria.ArrayCriterion):
|
||||||
expr = sqlalchemy.sql.false()
|
expr = sa.sql.false()
|
||||||
for value in criterion.values:
|
for value in criterion.values:
|
||||||
expr = expr | column.ilike(transformer(value))
|
expr = expr | column.ilike(transformer(value))
|
||||||
elif isinstance(criterion, criteria.RangedCriterion):
|
elif isinstance(criterion, criteria.RangedCriterion):
|
||||||
|
@ -68,8 +70,15 @@ def apply_str_criterion_to_column(
|
||||||
return expr
|
return expr
|
||||||
|
|
||||||
|
|
||||||
def create_str_filter(column, transformer=wildcard_transformer):
|
def create_str_filter(
|
||||||
def wrapper(query, criterion, negated):
|
column: SaColumn,
|
||||||
|
transformer: Callable[[str], str]=wildcard_transformer
|
||||||
|
) -> Filter:
|
||||||
|
def wrapper(
|
||||||
|
query: SaQuery,
|
||||||
|
criterion: Optional[criteria.BaseCriterion],
|
||||||
|
negated: bool) -> SaQuery:
|
||||||
|
assert criterion
|
||||||
expr = apply_str_criterion_to_column(
|
expr = apply_str_criterion_to_column(
|
||||||
column, criterion, transformer)
|
column, criterion, transformer)
|
||||||
if negated:
|
if negated:
|
||||||
|
@ -78,16 +87,13 @@ def create_str_filter(column, transformer=wildcard_transformer):
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def apply_date_criterion_to_column(column, criterion):
|
def apply_date_criterion_to_column(
|
||||||
'''
|
column: SaQuery, criterion: criteria.BaseCriterion) -> SaQuery:
|
||||||
Decorate SQLAlchemy filter on given column using supplied criterion.
|
|
||||||
Parse the datetime inside the criterion.
|
|
||||||
'''
|
|
||||||
if isinstance(criterion, criteria.PlainCriterion):
|
if isinstance(criterion, criteria.PlainCriterion):
|
||||||
min_date, max_date = util.parse_time_range(criterion.value)
|
min_date, max_date = util.parse_time_range(criterion.value)
|
||||||
expr = column.between(min_date, max_date)
|
expr = column.between(min_date, max_date)
|
||||||
elif isinstance(criterion, criteria.ArrayCriterion):
|
elif isinstance(criterion, criteria.ArrayCriterion):
|
||||||
expr = sqlalchemy.sql.false()
|
expr = sa.sql.false()
|
||||||
for value in criterion.values:
|
for value in criterion.values:
|
||||||
min_date, max_date = util.parse_time_range(value)
|
min_date, max_date = util.parse_time_range(value)
|
||||||
expr = expr | column.between(min_date, max_date)
|
expr = expr | column.between(min_date, max_date)
|
||||||
|
@ -108,10 +114,13 @@ def apply_date_criterion_to_column(column, criterion):
|
||||||
return expr
|
return expr
|
||||||
|
|
||||||
|
|
||||||
def create_date_filter(column):
|
def create_date_filter(column: SaColumn) -> Filter:
|
||||||
def wrapper(query, criterion, negated):
|
def wrapper(
|
||||||
expr = apply_date_criterion_to_column(
|
query: SaQuery,
|
||||||
column, criterion)
|
criterion: Optional[criteria.BaseCriterion],
|
||||||
|
negated: bool) -> SaQuery:
|
||||||
|
assert criterion
|
||||||
|
expr = apply_date_criterion_to_column(column, criterion)
|
||||||
if negated:
|
if negated:
|
||||||
expr = ~expr
|
expr = ~expr
|
||||||
return query.filter(expr)
|
return query.filter(expr)
|
||||||
|
@ -119,18 +128,22 @@ def create_date_filter(column):
|
||||||
|
|
||||||
|
|
||||||
def create_subquery_filter(
|
def create_subquery_filter(
|
||||||
left_id_column,
|
left_id_column: SaColumn,
|
||||||
right_id_column,
|
right_id_column: SaColumn,
|
||||||
filter_column,
|
filter_column: SaColumn,
|
||||||
filter_factory,
|
filter_factory: SaColumn,
|
||||||
subquery_decorator=None):
|
subquery_decorator: Callable[[SaQuery], None]=None) -> Filter:
|
||||||
filter_func = filter_factory(filter_column)
|
filter_func = filter_factory(filter_column)
|
||||||
|
|
||||||
def wrapper(query, criterion, negated):
|
def wrapper(
|
||||||
|
query: SaQuery,
|
||||||
|
criterion: Optional[criteria.BaseCriterion],
|
||||||
|
negated: bool) -> SaQuery:
|
||||||
|
assert criterion
|
||||||
subquery = db.session.query(right_id_column.label('foreign_id'))
|
subquery = db.session.query(right_id_column.label('foreign_id'))
|
||||||
if subquery_decorator:
|
if subquery_decorator:
|
||||||
subquery = subquery_decorator(subquery)
|
subquery = subquery_decorator(subquery)
|
||||||
subquery = subquery.options(sqlalchemy.orm.lazyload('*'))
|
subquery = subquery.options(sa.orm.lazyload('*'))
|
||||||
subquery = filter_func(subquery, criterion, False)
|
subquery = filter_func(subquery, criterion, False)
|
||||||
subquery = subquery.subquery('t')
|
subquery = subquery.subquery('t')
|
||||||
expression = left_id_column.in_(subquery)
|
expression = left_id_column.in_(subquery)
|
||||||
|
|
|
@ -1,34 +1,42 @@
|
||||||
class _BaseCriterion:
|
from typing import Optional, List, Callable
|
||||||
def __init__(self, original_text):
|
from szurubooru.search.typing import SaQuery
|
||||||
|
|
||||||
|
|
||||||
|
class BaseCriterion:
|
||||||
|
def __init__(self, original_text: str) -> None:
|
||||||
self.original_text = original_text
|
self.original_text = original_text
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
return self.original_text
|
return self.original_text
|
||||||
|
|
||||||
|
|
||||||
class RangedCriterion(_BaseCriterion):
|
class RangedCriterion(BaseCriterion):
|
||||||
def __init__(self, original_text, min_value, max_value):
|
def __init__(
|
||||||
|
self,
|
||||||
|
original_text: str,
|
||||||
|
min_value: Optional[str],
|
||||||
|
max_value: Optional[str]) -> None:
|
||||||
super().__init__(original_text)
|
super().__init__(original_text)
|
||||||
self.min_value = min_value
|
self.min_value = min_value
|
||||||
self.max_value = max_value
|
self.max_value = max_value
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self) -> int:
|
||||||
return hash(('range', self.min_value, self.max_value))
|
return hash(('range', self.min_value, self.max_value))
|
||||||
|
|
||||||
|
|
||||||
class PlainCriterion(_BaseCriterion):
|
class PlainCriterion(BaseCriterion):
|
||||||
def __init__(self, original_text, value):
|
def __init__(self, original_text: str, value: str) -> None:
|
||||||
super().__init__(original_text)
|
super().__init__(original_text)
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self) -> int:
|
||||||
return hash(self.value)
|
return hash(self.value)
|
||||||
|
|
||||||
|
|
||||||
class ArrayCriterion(_BaseCriterion):
|
class ArrayCriterion(BaseCriterion):
|
||||||
def __init__(self, original_text, values):
|
def __init__(self, original_text: str, values: List[str]) -> None:
|
||||||
super().__init__(original_text)
|
super().__init__(original_text)
|
||||||
self.values = values
|
self.values = values
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self) -> int:
|
||||||
return hash(tuple(['array'] + self.values))
|
return hash(tuple(['array'] + self.values))
|
||||||
|
|
|
@ -1,14 +1,18 @@
|
||||||
import sqlalchemy
|
from typing import Union, Tuple, List, Dict, Callable
|
||||||
from szurubooru import db, errors
|
import sqlalchemy as sa
|
||||||
|
from szurubooru import db, model, errors, rest
|
||||||
from szurubooru.func import cache
|
from szurubooru.func import cache
|
||||||
from szurubooru.search import tokens, parser
|
from szurubooru.search import tokens, parser
|
||||||
|
from szurubooru.search.typing import SaQuery
|
||||||
|
from szurubooru.search.query import SearchQuery
|
||||||
|
from szurubooru.search.configs.base_search_config import BaseSearchConfig
|
||||||
|
|
||||||
|
|
||||||
def _format_dict_keys(source):
|
def _format_dict_keys(source: Dict) -> List[str]:
|
||||||
return list(sorted(source.keys()))
|
return list(sorted(source.keys()))
|
||||||
|
|
||||||
|
|
||||||
def _get_order(order, default_order):
|
def _get_order(order: str, default_order: str) -> Union[bool, str]:
|
||||||
if order == tokens.SortToken.SORT_DEFAULT:
|
if order == tokens.SortToken.SORT_DEFAULT:
|
||||||
return default_order or tokens.SortToken.SORT_ASC
|
return default_order or tokens.SortToken.SORT_ASC
|
||||||
if order == tokens.SortToken.SORT_NEGATED_DEFAULT:
|
if order == tokens.SortToken.SORT_NEGATED_DEFAULT:
|
||||||
|
@ -26,50 +30,57 @@ class Executor:
|
||||||
delegates sqlalchemy filter decoration to SearchConfig instances.
|
delegates sqlalchemy filter decoration to SearchConfig instances.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, search_config):
|
def __init__(self, search_config: BaseSearchConfig) -> None:
|
||||||
self.config = search_config
|
self.config = search_config
|
||||||
self.parser = parser.Parser()
|
self.parser = parser.Parser()
|
||||||
|
|
||||||
def get_around(self, query_text, entity_id):
|
def get_around(
|
||||||
|
self,
|
||||||
|
query_text: str,
|
||||||
|
entity_id: int) -> Tuple[model.Base, model.Base]:
|
||||||
search_query = self.parser.parse(query_text)
|
search_query = self.parser.parse(query_text)
|
||||||
self.config.on_search_query_parsed(search_query)
|
self.config.on_search_query_parsed(search_query)
|
||||||
filter_query = (
|
filter_query = (
|
||||||
self.config
|
self.config
|
||||||
.create_around_query()
|
.create_around_query()
|
||||||
.options(sqlalchemy.orm.lazyload('*')))
|
.options(sa.orm.lazyload('*')))
|
||||||
filter_query = self._prepare_db_query(
|
filter_query = self._prepare_db_query(
|
||||||
filter_query, search_query, False)
|
filter_query, search_query, False)
|
||||||
prev_filter_query = (
|
prev_filter_query = (
|
||||||
filter_query
|
filter_query
|
||||||
.filter(self.config.id_column > entity_id)
|
.filter(self.config.id_column > entity_id)
|
||||||
.order_by(None)
|
.order_by(None)
|
||||||
.order_by(sqlalchemy.func.abs(
|
.order_by(sa.func.abs(self.config.id_column - entity_id).asc())
|
||||||
self.config.id_column - entity_id).asc())
|
|
||||||
.limit(1))
|
.limit(1))
|
||||||
next_filter_query = (
|
next_filter_query = (
|
||||||
filter_query
|
filter_query
|
||||||
.filter(self.config.id_column < entity_id)
|
.filter(self.config.id_column < entity_id)
|
||||||
.order_by(None)
|
.order_by(None)
|
||||||
.order_by(sqlalchemy.func.abs(
|
.order_by(sa.func.abs(self.config.id_column - entity_id).asc())
|
||||||
self.config.id_column - entity_id).asc())
|
|
||||||
.limit(1))
|
.limit(1))
|
||||||
return [
|
return (
|
||||||
prev_filter_query.one_or_none(),
|
prev_filter_query.one_or_none(),
|
||||||
next_filter_query.one_or_none()]
|
next_filter_query.one_or_none())
|
||||||
|
|
||||||
def get_around_and_serialize(self, ctx, entity_id, serializer):
|
def get_around_and_serialize(
|
||||||
entities = self.get_around(ctx.get_param_as_string('query'), entity_id)
|
self,
|
||||||
|
ctx: rest.Context,
|
||||||
|
entity_id: int,
|
||||||
|
serializer: Callable[[model.Base], rest.Response]
|
||||||
|
) -> rest.Response:
|
||||||
|
entities = self.get_around(
|
||||||
|
ctx.get_param_as_string('query', default=''), entity_id)
|
||||||
return {
|
return {
|
||||||
'prev': serializer(entities[0]),
|
'prev': serializer(entities[0]),
|
||||||
'next': serializer(entities[1]),
|
'next': serializer(entities[1]),
|
||||||
}
|
}
|
||||||
|
|
||||||
def execute(self, query_text, page, page_size):
|
def execute(
|
||||||
'''
|
self,
|
||||||
Parse input and return tuple containing total record count and filtered
|
query_text: str,
|
||||||
entities.
|
page: int,
|
||||||
'''
|
page_size: int
|
||||||
|
) -> Tuple[int, List[model.Base]]:
|
||||||
search_query = self.parser.parse(query_text)
|
search_query = self.parser.parse(query_text)
|
||||||
self.config.on_search_query_parsed(search_query)
|
self.config.on_search_query_parsed(search_query)
|
||||||
|
|
||||||
|
@ -83,7 +94,7 @@ class Executor:
|
||||||
return cache.get(key)
|
return cache.get(key)
|
||||||
|
|
||||||
filter_query = self.config.create_filter_query(disable_eager_loads)
|
filter_query = self.config.create_filter_query(disable_eager_loads)
|
||||||
filter_query = filter_query.options(sqlalchemy.orm.lazyload('*'))
|
filter_query = filter_query.options(sa.orm.lazyload('*'))
|
||||||
filter_query = self._prepare_db_query(filter_query, search_query, True)
|
filter_query = self._prepare_db_query(filter_query, search_query, True)
|
||||||
entities = filter_query \
|
entities = filter_query \
|
||||||
.offset(max(page - 1, 0) * page_size) \
|
.offset(max(page - 1, 0) * page_size) \
|
||||||
|
@ -91,11 +102,11 @@ class Executor:
|
||||||
.all()
|
.all()
|
||||||
|
|
||||||
count_query = self.config.create_count_query(disable_eager_loads)
|
count_query = self.config.create_count_query(disable_eager_loads)
|
||||||
count_query = count_query.options(sqlalchemy.orm.lazyload('*'))
|
count_query = count_query.options(sa.orm.lazyload('*'))
|
||||||
count_query = self._prepare_db_query(count_query, search_query, False)
|
count_query = self._prepare_db_query(count_query, search_query, False)
|
||||||
count_statement = count_query \
|
count_statement = count_query \
|
||||||
.statement \
|
.statement \
|
||||||
.with_only_columns([sqlalchemy.func.count()]) \
|
.with_only_columns([sa.func.count()]) \
|
||||||
.order_by(None)
|
.order_by(None)
|
||||||
count = db.session.execute(count_statement).scalar()
|
count = db.session.execute(count_statement).scalar()
|
||||||
|
|
||||||
|
@ -103,8 +114,12 @@ class Executor:
|
||||||
cache.put(key, ret)
|
cache.put(key, ret)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def execute_and_serialize(self, ctx, serializer):
|
def execute_and_serialize(
|
||||||
query = ctx.get_param_as_string('query')
|
self,
|
||||||
|
ctx: rest.Context,
|
||||||
|
serializer: Callable[[model.Base], rest.Response]
|
||||||
|
) -> rest.Response:
|
||||||
|
query = ctx.get_param_as_string('query', default='')
|
||||||
page = ctx.get_param_as_int('page', default=1, min=1)
|
page = ctx.get_param_as_int('page', default=1, min=1)
|
||||||
page_size = ctx.get_param_as_int(
|
page_size = ctx.get_param_as_int(
|
||||||
'pageSize', default=100, min=1, max=100)
|
'pageSize', default=100, min=1, max=100)
|
||||||
|
@ -117,48 +132,51 @@ class Executor:
|
||||||
'results': [serializer(entity) for entity in entities],
|
'results': [serializer(entity) for entity in entities],
|
||||||
}
|
}
|
||||||
|
|
||||||
def _prepare_db_query(self, db_query, search_query, use_sort):
|
def _prepare_db_query(
|
||||||
''' Parse input and return SQLAlchemy query. '''
|
self,
|
||||||
|
db_query: SaQuery,
|
||||||
for token in search_query.anonymous_tokens:
|
search_query: SearchQuery,
|
||||||
|
use_sort: bool) -> SaQuery:
|
||||||
|
for anon_token in search_query.anonymous_tokens:
|
||||||
if not self.config.anonymous_filter:
|
if not self.config.anonymous_filter:
|
||||||
raise errors.SearchError(
|
raise errors.SearchError(
|
||||||
'Anonymous tokens are not valid in this context.')
|
'Anonymous tokens are not valid in this context.')
|
||||||
db_query = self.config.anonymous_filter(
|
db_query = self.config.anonymous_filter(
|
||||||
db_query, token.criterion, token.negated)
|
db_query, anon_token.criterion, anon_token.negated)
|
||||||
|
|
||||||
for token in search_query.named_tokens:
|
for named_token in search_query.named_tokens:
|
||||||
if token.name not in self.config.named_filters:
|
if named_token.name not in self.config.named_filters:
|
||||||
raise errors.SearchError(
|
raise errors.SearchError(
|
||||||
'Unknown named token: %r. Available named tokens: %r.' % (
|
'Unknown named token: %r. Available named tokens: %r.' % (
|
||||||
token.name,
|
named_token.name,
|
||||||
_format_dict_keys(self.config.named_filters)))
|
_format_dict_keys(self.config.named_filters)))
|
||||||
db_query = self.config.named_filters[token.name](
|
db_query = self.config.named_filters[named_token.name](
|
||||||
db_query, token.criterion, token.negated)
|
db_query, named_token.criterion, named_token.negated)
|
||||||
|
|
||||||
for token in search_query.special_tokens:
|
for sp_token in search_query.special_tokens:
|
||||||
if token.value not in self.config.special_filters:
|
if sp_token.value not in self.config.special_filters:
|
||||||
raise errors.SearchError(
|
raise errors.SearchError(
|
||||||
'Unknown special token: %r. '
|
'Unknown special token: %r. '
|
||||||
'Available special tokens: %r.' % (
|
'Available special tokens: %r.' % (
|
||||||
token.value,
|
sp_token.value,
|
||||||
_format_dict_keys(self.config.special_filters)))
|
_format_dict_keys(self.config.special_filters)))
|
||||||
db_query = self.config.special_filters[token.value](
|
db_query = self.config.special_filters[sp_token.value](
|
||||||
db_query, token.negated)
|
db_query, None, sp_token.negated)
|
||||||
|
|
||||||
if use_sort:
|
if use_sort:
|
||||||
for token in search_query.sort_tokens:
|
for sort_token in search_query.sort_tokens:
|
||||||
if token.name not in self.config.sort_columns:
|
if sort_token.name not in self.config.sort_columns:
|
||||||
raise errors.SearchError(
|
raise errors.SearchError(
|
||||||
'Unknown sort token: %r. '
|
'Unknown sort token: %r. '
|
||||||
'Available sort tokens: %r.' % (
|
'Available sort tokens: %r.' % (
|
||||||
token.name,
|
sort_token.name,
|
||||||
_format_dict_keys(self.config.sort_columns)))
|
_format_dict_keys(self.config.sort_columns)))
|
||||||
column, default_order = self.config.sort_columns[token.name]
|
column, default_order = (
|
||||||
order = _get_order(token.order, default_order)
|
self.config.sort_columns[sort_token.name])
|
||||||
if order == token.SORT_ASC:
|
order = _get_order(sort_token.order, default_order)
|
||||||
|
if order == sort_token.SORT_ASC:
|
||||||
db_query = db_query.order_by(column.asc())
|
db_query = db_query.order_by(column.asc())
|
||||||
elif order == token.SORT_DESC:
|
elif order == sort_token.SORT_DESC:
|
||||||
db_query = db_query.order_by(column.desc())
|
db_query = db_query.order_by(column.desc())
|
||||||
|
|
||||||
db_query = self.config.finalize_query(db_query)
|
db_query = self.config.finalize_query(db_query)
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
import re
|
import re
|
||||||
|
from typing import List
|
||||||
from szurubooru import errors
|
from szurubooru import errors
|
||||||
from szurubooru.search import criteria, tokens
|
from szurubooru.search import criteria, tokens
|
||||||
|
from szurubooru.search.query import SearchQuery
|
||||||
|
|
||||||
|
|
||||||
def _create_criterion(original_value, value):
|
def _create_criterion(
|
||||||
|
original_value: str, value: str) -> criteria.BaseCriterion:
|
||||||
if ',' in value:
|
if ',' in value:
|
||||||
return criteria.ArrayCriterion(
|
return criteria.ArrayCriterion(
|
||||||
original_value, value.split(','))
|
original_value, value.split(','))
|
||||||
|
@ -15,12 +18,12 @@ def _create_criterion(original_value, value):
|
||||||
return criteria.PlainCriterion(original_value, value)
|
return criteria.PlainCriterion(original_value, value)
|
||||||
|
|
||||||
|
|
||||||
def _parse_anonymous(value, negated):
|
def _parse_anonymous(value: str, negated: bool) -> tokens.AnonymousToken:
|
||||||
criterion = _create_criterion(value, value)
|
criterion = _create_criterion(value, value)
|
||||||
return tokens.AnonymousToken(criterion, negated)
|
return tokens.AnonymousToken(criterion, negated)
|
||||||
|
|
||||||
|
|
||||||
def _parse_named(key, value, negated):
|
def _parse_named(key: str, value: str, negated: bool) -> tokens.NamedToken:
|
||||||
original_value = value
|
original_value = value
|
||||||
if key.endswith('-min'):
|
if key.endswith('-min'):
|
||||||
key = key[:-4]
|
key = key[:-4]
|
||||||
|
@ -32,11 +35,11 @@ def _parse_named(key, value, negated):
|
||||||
return tokens.NamedToken(key, criterion, negated)
|
return tokens.NamedToken(key, criterion, negated)
|
||||||
|
|
||||||
|
|
||||||
def _parse_special(value, negated):
|
def _parse_special(value: str, negated: bool) -> tokens.SpecialToken:
|
||||||
return tokens.SpecialToken(value, negated)
|
return tokens.SpecialToken(value, negated)
|
||||||
|
|
||||||
|
|
||||||
def _parse_sort(value, negated):
|
def _parse_sort(value: str, negated: bool) -> tokens.SortToken:
|
||||||
if value.count(',') == 0:
|
if value.count(',') == 0:
|
||||||
order_str = None
|
order_str = None
|
||||||
elif value.count(',') == 1:
|
elif value.count(',') == 1:
|
||||||
|
@ -67,23 +70,8 @@ def _parse_sort(value, negated):
|
||||||
return tokens.SortToken(value, order)
|
return tokens.SortToken(value, order)
|
||||||
|
|
||||||
|
|
||||||
class SearchQuery:
|
|
||||||
def __init__(self):
|
|
||||||
self.anonymous_tokens = []
|
|
||||||
self.named_tokens = []
|
|
||||||
self.special_tokens = []
|
|
||||||
self.sort_tokens = []
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
return hash((
|
|
||||||
tuple(self.anonymous_tokens),
|
|
||||||
tuple(self.named_tokens),
|
|
||||||
tuple(self.special_tokens),
|
|
||||||
tuple(self.sort_tokens)))
|
|
||||||
|
|
||||||
|
|
||||||
class Parser:
|
class Parser:
|
||||||
def parse(self, query_text):
|
def parse(self, query_text: str) -> SearchQuery:
|
||||||
query = SearchQuery()
|
query = SearchQuery()
|
||||||
for chunk in re.split(r'\s+', (query_text or '').lower()):
|
for chunk in re.split(r'\s+', (query_text or '').lower()):
|
||||||
if not chunk:
|
if not chunk:
|
||||||
|
|
16
server/szurubooru/search/query.py
Normal file
16
server/szurubooru/search/query.py
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
from szurubooru.search import tokens
|
||||||
|
|
||||||
|
|
||||||
|
class SearchQuery:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.anonymous_tokens = [] # type: List[tokens.AnonymousToken]
|
||||||
|
self.named_tokens = [] # type: List[tokens.NamedToken]
|
||||||
|
self.special_tokens = [] # type: List[tokens.SpecialToken]
|
||||||
|
self.sort_tokens = [] # type: List[tokens.SortToken]
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return hash((
|
||||||
|
tuple(self.anonymous_tokens),
|
||||||
|
tuple(self.named_tokens),
|
||||||
|
tuple(self.special_tokens),
|
||||||
|
tuple(self.sort_tokens)))
|
|
@ -1,39 +1,44 @@
|
||||||
|
from szurubooru.search.criteria import BaseCriterion
|
||||||
|
|
||||||
|
|
||||||
class AnonymousToken:
|
class AnonymousToken:
|
||||||
def __init__(self, criterion, negated):
|
def __init__(self, criterion: BaseCriterion, negated: bool) -> None:
|
||||||
self.criterion = criterion
|
self.criterion = criterion
|
||||||
self.negated = negated
|
self.negated = negated
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self) -> int:
|
||||||
return hash((self.criterion, self.negated))
|
return hash((self.criterion, self.negated))
|
||||||
|
|
||||||
|
|
||||||
class NamedToken(AnonymousToken):
|
class NamedToken(AnonymousToken):
|
||||||
def __init__(self, name, criterion, negated):
|
def __init__(
|
||||||
|
self, name: str, criterion: BaseCriterion, negated: bool) -> None:
|
||||||
super().__init__(criterion, negated)
|
super().__init__(criterion, negated)
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self) -> int:
|
||||||
return hash((self.name, self.criterion, self.negated))
|
return hash((self.name, self.criterion, self.negated))
|
||||||
|
|
||||||
|
|
||||||
class SortToken:
|
class SortToken:
|
||||||
SORT_DESC = 'desc'
|
SORT_DESC = 'desc'
|
||||||
SORT_ASC = 'asc'
|
SORT_ASC = 'asc'
|
||||||
|
SORT_NONE = ''
|
||||||
SORT_DEFAULT = 'default'
|
SORT_DEFAULT = 'default'
|
||||||
SORT_NEGATED_DEFAULT = 'negated default'
|
SORT_NEGATED_DEFAULT = 'negated default'
|
||||||
|
|
||||||
def __init__(self, name, order):
|
def __init__(self, name: str, order: str) -> None:
|
||||||
self.name = name
|
self.name = name
|
||||||
self.order = order
|
self.order = order
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self) -> int:
|
||||||
return hash((self.name, self.order))
|
return hash((self.name, self.order))
|
||||||
|
|
||||||
|
|
||||||
class SpecialToken:
|
class SpecialToken:
|
||||||
def __init__(self, value, negated):
|
def __init__(self, value: str, negated: bool) -> None:
|
||||||
self.value = value
|
self.value = value
|
||||||
self.negated = negated
|
self.negated = negated
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self) -> int:
|
||||||
return hash((self.value, self.negated))
|
return hash((self.value, self.negated))
|
||||||
|
|
6
server/szurubooru/search/typing.py
Normal file
6
server/szurubooru/search/typing.py
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
|
||||||
|
SaColumn = Any
|
||||||
|
SaQuery = Any
|
||||||
|
SaQueryFactory = Callable[[], SaQuery]
|
|
@ -1,19 +1,20 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import comments, posts
|
from szurubooru.func import comments, posts
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'comments:create': db.User.RANK_REGULAR}})
|
config_injector(
|
||||||
|
{'privileges': {'comments:create': model.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_creating_comment(
|
def test_creating_comment(
|
||||||
user_factory, post_factory, context_factory, fake_datetime):
|
user_factory, post_factory, context_factory, fake_datetime):
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
db.session.add_all([post, user])
|
db.session.add_all([post, user])
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
with patch('szurubooru.func.comments.serialize_comment'), \
|
with patch('szurubooru.func.comments.serialize_comment'), \
|
||||||
|
@ -24,7 +25,7 @@ def test_creating_comment(
|
||||||
params={'text': 'input', 'postId': post.post_id},
|
params={'text': 'input', 'postId': post.post_id},
|
||||||
user=user))
|
user=user))
|
||||||
assert result == 'serialized comment'
|
assert result == 'serialized comment'
|
||||||
comment = db.session.query(db.Comment).one()
|
comment = db.session.query(model.Comment).one()
|
||||||
assert comment.text == 'input'
|
assert comment.text == 'input'
|
||||||
assert comment.creation_time == datetime(1997, 1, 1)
|
assert comment.creation_time == datetime(1997, 1, 1)
|
||||||
assert comment.last_edit_time is None
|
assert comment.last_edit_time is None
|
||||||
|
@ -41,7 +42,7 @@ def test_creating_comment(
|
||||||
def test_trying_to_pass_invalid_params(
|
def test_trying_to_pass_invalid_params(
|
||||||
user_factory, post_factory, context_factory, params):
|
user_factory, post_factory, context_factory, params):
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
db.session.add_all([post, user])
|
db.session.add_all([post, user])
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
real_params = {'text': 'input', 'postId': post.post_id}
|
real_params = {'text': 'input', 'postId': post.post_id}
|
||||||
|
@ -63,11 +64,11 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
|
||||||
api.comment_api.create_comment(
|
api.comment_api.create_comment(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={},
|
params={},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_comment_non_existing(user_factory, context_factory):
|
def test_trying_to_comment_non_existing(user_factory, context_factory):
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
db.session.add_all([user])
|
db.session.add_all([user])
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
with pytest.raises(posts.PostNotFoundError):
|
with pytest.raises(posts.PostNotFoundError):
|
||||||
|
@ -81,4 +82,4 @@ def test_trying_to_create_without_privileges(user_factory, context_factory):
|
||||||
api.comment_api.create_comment(
|
api.comment_api.create_comment(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={},
|
params={},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import comments
|
from szurubooru.func import comments
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,8 +7,8 @@ from szurubooru.func import comments
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'comments:delete:own': db.User.RANK_REGULAR,
|
'comments:delete:own': model.User.RANK_REGULAR,
|
||||||
'comments:delete:any': db.User.RANK_MODERATOR,
|
'comments:delete:any': model.User.RANK_MODERATOR,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -22,26 +22,26 @@ def test_deleting_own_comment(user_factory, comment_factory, context_factory):
|
||||||
context_factory(params={'version': 1}, user=user),
|
context_factory(params={'version': 1}, user=user),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
assert result == {}
|
assert result == {}
|
||||||
assert db.session.query(db.Comment).count() == 0
|
assert db.session.query(model.Comment).count() == 0
|
||||||
|
|
||||||
|
|
||||||
def test_deleting_someones_else_comment(
|
def test_deleting_someones_else_comment(
|
||||||
user_factory, comment_factory, context_factory):
|
user_factory, comment_factory, context_factory):
|
||||||
user1 = user_factory(rank=db.User.RANK_REGULAR)
|
user1 = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
user2 = user_factory(rank=db.User.RANK_MODERATOR)
|
user2 = user_factory(rank=model.User.RANK_MODERATOR)
|
||||||
comment = comment_factory(user=user1)
|
comment = comment_factory(user=user1)
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
api.comment_api.delete_comment(
|
api.comment_api.delete_comment(
|
||||||
context_factory(params={'version': 1}, user=user2),
|
context_factory(params={'version': 1}, user=user2),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
assert db.session.query(db.Comment).count() == 0
|
assert db.session.query(model.Comment).count() == 0
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_delete_someones_else_comment_without_privileges(
|
def test_trying_to_delete_someones_else_comment_without_privileges(
|
||||||
user_factory, comment_factory, context_factory):
|
user_factory, comment_factory, context_factory):
|
||||||
user1 = user_factory(rank=db.User.RANK_REGULAR)
|
user1 = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
user2 = user_factory(rank=db.User.RANK_REGULAR)
|
user2 = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
comment = comment_factory(user=user1)
|
comment = comment_factory(user=user1)
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
@ -49,7 +49,7 @@ def test_trying_to_delete_someones_else_comment_without_privileges(
|
||||||
api.comment_api.delete_comment(
|
api.comment_api.delete_comment(
|
||||||
context_factory(params={'version': 1}, user=user2),
|
context_factory(params={'version': 1}, user=user2),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
assert db.session.query(db.Comment).count() == 1
|
assert db.session.query(model.Comment).count() == 1
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
||||||
|
@ -57,5 +57,5 @@ def test_trying_to_delete_non_existing(user_factory, context_factory):
|
||||||
api.comment_api.delete_comment(
|
api.comment_api.delete_comment(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'version': 1},
|
params={'version': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'comment_id': 1})
|
{'comment_id': 1})
|
||||||
|
|
|
@ -1,17 +1,18 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import comments
|
from szurubooru.func import comments
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'comments:score': db.User.RANK_REGULAR}})
|
config_injector(
|
||||||
|
{'privileges': {'comments:score': model.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_simple_rating(
|
def test_simple_rating(
|
||||||
user_factory, comment_factory, context_factory, fake_datetime):
|
user_factory, comment_factory, context_factory, fake_datetime):
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
comment = comment_factory(user=user)
|
comment = comment_factory(user=user)
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
@ -22,14 +23,14 @@ def test_simple_rating(
|
||||||
context_factory(params={'score': 1}, user=user),
|
context_factory(params={'score': 1}, user=user),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
assert result == 'serialized comment'
|
assert result == 'serialized comment'
|
||||||
assert db.session.query(db.CommentScore).count() == 1
|
assert db.session.query(model.CommentScore).count() == 1
|
||||||
assert comment is not None
|
assert comment is not None
|
||||||
assert comment.score == 1
|
assert comment.score == 1
|
||||||
|
|
||||||
|
|
||||||
def test_updating_rating(
|
def test_updating_rating(
|
||||||
user_factory, comment_factory, context_factory, fake_datetime):
|
user_factory, comment_factory, context_factory, fake_datetime):
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
comment = comment_factory(user=user)
|
comment = comment_factory(user=user)
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
@ -42,14 +43,14 @@ def test_updating_rating(
|
||||||
api.comment_api.set_comment_score(
|
api.comment_api.set_comment_score(
|
||||||
context_factory(params={'score': -1}, user=user),
|
context_factory(params={'score': -1}, user=user),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
comment = db.session.query(db.Comment).one()
|
comment = db.session.query(model.Comment).one()
|
||||||
assert db.session.query(db.CommentScore).count() == 1
|
assert db.session.query(model.CommentScore).count() == 1
|
||||||
assert comment.score == -1
|
assert comment.score == -1
|
||||||
|
|
||||||
|
|
||||||
def test_updating_rating_to_zero(
|
def test_updating_rating_to_zero(
|
||||||
user_factory, comment_factory, context_factory, fake_datetime):
|
user_factory, comment_factory, context_factory, fake_datetime):
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
comment = comment_factory(user=user)
|
comment = comment_factory(user=user)
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
@ -62,14 +63,14 @@ def test_updating_rating_to_zero(
|
||||||
api.comment_api.set_comment_score(
|
api.comment_api.set_comment_score(
|
||||||
context_factory(params={'score': 0}, user=user),
|
context_factory(params={'score': 0}, user=user),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
comment = db.session.query(db.Comment).one()
|
comment = db.session.query(model.Comment).one()
|
||||||
assert db.session.query(db.CommentScore).count() == 0
|
assert db.session.query(model.CommentScore).count() == 0
|
||||||
assert comment.score == 0
|
assert comment.score == 0
|
||||||
|
|
||||||
|
|
||||||
def test_deleting_rating(
|
def test_deleting_rating(
|
||||||
user_factory, comment_factory, context_factory, fake_datetime):
|
user_factory, comment_factory, context_factory, fake_datetime):
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
comment = comment_factory(user=user)
|
comment = comment_factory(user=user)
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
@ -82,15 +83,15 @@ def test_deleting_rating(
|
||||||
api.comment_api.delete_comment_score(
|
api.comment_api.delete_comment_score(
|
||||||
context_factory(user=user),
|
context_factory(user=user),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
comment = db.session.query(db.Comment).one()
|
comment = db.session.query(model.Comment).one()
|
||||||
assert db.session.query(db.CommentScore).count() == 0
|
assert db.session.query(model.CommentScore).count() == 0
|
||||||
assert comment.score == 0
|
assert comment.score == 0
|
||||||
|
|
||||||
|
|
||||||
def test_ratings_from_multiple_users(
|
def test_ratings_from_multiple_users(
|
||||||
user_factory, comment_factory, context_factory, fake_datetime):
|
user_factory, comment_factory, context_factory, fake_datetime):
|
||||||
user1 = user_factory(rank=db.User.RANK_REGULAR)
|
user1 = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
user2 = user_factory(rank=db.User.RANK_REGULAR)
|
user2 = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
comment = comment_factory()
|
comment = comment_factory()
|
||||||
db.session.add_all([user1, user2, comment])
|
db.session.add_all([user1, user2, comment])
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
@ -103,8 +104,8 @@ def test_ratings_from_multiple_users(
|
||||||
api.comment_api.set_comment_score(
|
api.comment_api.set_comment_score(
|
||||||
context_factory(params={'score': -1}, user=user2),
|
context_factory(params={'score': -1}, user=user2),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
comment = db.session.query(db.Comment).one()
|
comment = db.session.query(model.Comment).one()
|
||||||
assert db.session.query(db.CommentScore).count() == 2
|
assert db.session.query(model.CommentScore).count() == 2
|
||||||
assert comment.score == 0
|
assert comment.score == 0
|
||||||
|
|
||||||
|
|
||||||
|
@ -125,7 +126,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
|
||||||
api.comment_api.set_comment_score(
|
api.comment_api.set_comment_score(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'score': 1},
|
params={'score': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'comment_id': 5})
|
{'comment_id': 5})
|
||||||
|
|
||||||
|
|
||||||
|
@ -138,5 +139,5 @@ def test_trying_to_rate_without_privileges(
|
||||||
api.comment_api.set_comment_score(
|
api.comment_api.set_comment_score(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'score': 1},
|
params={'score': 1},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import comments
|
from szurubooru.func import comments
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,8 +8,8 @@ from szurubooru.func import comments
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'comments:list': db.User.RANK_REGULAR,
|
'comments:list': model.User.RANK_REGULAR,
|
||||||
'comments:view': db.User.RANK_REGULAR,
|
'comments:view': model.User.RANK_REGULAR,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ def test_retrieving_multiple(user_factory, comment_factory, context_factory):
|
||||||
result = api.comment_api.get_comments(
|
result = api.comment_api.get_comments(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'query': '', 'page': 1},
|
params={'query': '', 'page': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
assert result == {
|
assert result == {
|
||||||
'query': '',
|
'query': '',
|
||||||
'page': 1,
|
'page': 1,
|
||||||
|
@ -40,7 +40,7 @@ def test_trying_to_retrieve_multiple_without_privileges(
|
||||||
api.comment_api.get_comments(
|
api.comment_api.get_comments(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'query': '', 'page': 1},
|
params={'query': '', 'page': 1},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
|
||||||
|
|
||||||
|
|
||||||
def test_retrieving_single(user_factory, comment_factory, context_factory):
|
def test_retrieving_single(user_factory, comment_factory, context_factory):
|
||||||
|
@ -51,7 +51,7 @@ def test_retrieving_single(user_factory, comment_factory, context_factory):
|
||||||
comments.serialize_comment.return_value = 'serialized comment'
|
comments.serialize_comment.return_value = 'serialized comment'
|
||||||
result = api.comment_api.get_comment(
|
result = api.comment_api.get_comment(
|
||||||
context_factory(
|
context_factory(
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'comment_id': comment.comment_id})
|
{'comment_id': comment.comment_id})
|
||||||
assert result == 'serialized comment'
|
assert result == 'serialized comment'
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(comments.CommentNotFoundError):
|
with pytest.raises(comments.CommentNotFoundError):
|
||||||
api.comment_api.get_comment(
|
api.comment_api.get_comment(
|
||||||
context_factory(
|
context_factory(
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'comment_id': 5})
|
{'comment_id': 5})
|
||||||
|
|
||||||
|
|
||||||
|
@ -68,5 +68,5 @@ def test_trying_to_retrieve_single_without_privileges(
|
||||||
user_factory, context_factory):
|
user_factory, context_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.comment_api.get_comment(
|
api.comment_api.get_comment(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
|
||||||
{'comment_id': 5})
|
{'comment_id': 5})
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import comments
|
from szurubooru.func import comments
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,15 +9,15 @@ from szurubooru.func import comments
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'comments:edit:own': db.User.RANK_REGULAR,
|
'comments:edit:own': model.User.RANK_REGULAR,
|
||||||
'comments:edit:any': db.User.RANK_MODERATOR,
|
'comments:edit:any': model.User.RANK_MODERATOR,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_simple_updating(
|
def test_simple_updating(
|
||||||
user_factory, comment_factory, context_factory, fake_datetime):
|
user_factory, comment_factory, context_factory, fake_datetime):
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
comment = comment_factory(user=user)
|
comment = comment_factory(user=user)
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
@ -73,14 +73,14 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
|
||||||
api.comment_api.update_comment(
|
api.comment_api.update_comment(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'text': 'new text'},
|
params={'text': 'new text'},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'comment_id': 5})
|
{'comment_id': 5})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_update_someones_comment_without_privileges(
|
def test_trying_to_update_someones_comment_without_privileges(
|
||||||
user_factory, comment_factory, context_factory):
|
user_factory, comment_factory, context_factory):
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
user2 = user_factory(rank=db.User.RANK_REGULAR)
|
user2 = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
comment = comment_factory(user=user)
|
comment = comment_factory(user=user)
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
@ -93,8 +93,8 @@ def test_trying_to_update_someones_comment_without_privileges(
|
||||||
|
|
||||||
def test_updating_someones_comment_with_privileges(
|
def test_updating_someones_comment_with_privileges(
|
||||||
user_factory, comment_factory, context_factory):
|
user_factory, comment_factory, context_factory):
|
||||||
user = user_factory(rank=db.User.RANK_REGULAR)
|
user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
user2 = user_factory(rank=db.User.RANK_MODERATOR)
|
user2 = user_factory(rank=model.User.RANK_MODERATOR)
|
||||||
comment = comment_factory(user=user)
|
comment = comment_factory(user=user)
|
||||||
db.session.add(comment)
|
db.session.add(comment)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import auth, mailer
|
from szurubooru.func import auth, mailer
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ def inject_config(config_injector):
|
||||||
|
|
||||||
def test_reset_sending_email(context_factory, user_factory):
|
def test_reset_sending_email(context_factory, user_factory):
|
||||||
db.session.add(user_factory(
|
db.session.add(user_factory(
|
||||||
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
|
name='u1', rank=model.User.RANK_REGULAR, email='user@example.com'))
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
for initiating_user in ['u1', 'user@example.com']:
|
for initiating_user in ['u1', 'user@example.com']:
|
||||||
with patch('szurubooru.func.mailer.send_mail'):
|
with patch('szurubooru.func.mailer.send_mail'):
|
||||||
|
@ -39,7 +39,7 @@ def test_trying_to_reset_non_existing(context_factory):
|
||||||
|
|
||||||
def test_trying_to_reset_without_email(context_factory, user_factory):
|
def test_trying_to_reset_without_email(context_factory, user_factory):
|
||||||
db.session.add(
|
db.session.add(
|
||||||
user_factory(name='u1', rank=db.User.RANK_REGULAR, email=None))
|
user_factory(name='u1', rank=model.User.RANK_REGULAR, email=None))
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
with pytest.raises(errors.ValidationError):
|
with pytest.raises(errors.ValidationError):
|
||||||
api.password_reset_api.start_password_reset(
|
api.password_reset_api.start_password_reset(
|
||||||
|
@ -48,7 +48,7 @@ def test_trying_to_reset_without_email(context_factory, user_factory):
|
||||||
|
|
||||||
def test_confirming_with_good_token(context_factory, user_factory):
|
def test_confirming_with_good_token(context_factory, user_factory):
|
||||||
user = user_factory(
|
user = user_factory(
|
||||||
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')
|
name='u1', rank=model.User.RANK_REGULAR, email='user@example.com')
|
||||||
old_hash = user.password_hash
|
old_hash = user.password_hash
|
||||||
db.session.add(user)
|
db.session.add(user)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
@ -68,7 +68,7 @@ def test_trying_to_confirm_non_existing(context_factory):
|
||||||
|
|
||||||
def test_trying_to_confirm_without_token(context_factory, user_factory):
|
def test_trying_to_confirm_without_token(context_factory, user_factory):
|
||||||
db.session.add(user_factory(
|
db.session.add(user_factory(
|
||||||
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
|
name='u1', rank=model.User.RANK_REGULAR, email='user@example.com'))
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
with pytest.raises(errors.ValidationError):
|
with pytest.raises(errors.ValidationError):
|
||||||
api.password_reset_api.finish_password_reset(
|
api.password_reset_api.finish_password_reset(
|
||||||
|
@ -77,7 +77,7 @@ def test_trying_to_confirm_without_token(context_factory, user_factory):
|
||||||
|
|
||||||
def test_trying_to_confirm_with_bad_token(context_factory, user_factory):
|
def test_trying_to_confirm_with_bad_token(context_factory, user_factory):
|
||||||
db.session.add(user_factory(
|
db.session.add(user_factory(
|
||||||
name='u1', rank=db.User.RANK_REGULAR, email='user@example.com'))
|
name='u1', rank=model.User.RANK_REGULAR, email='user@example.com'))
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
with pytest.raises(errors.ValidationError):
|
with pytest.raises(errors.ValidationError):
|
||||||
api.password_reset_api.finish_password_reset(
|
api.password_reset_api.finish_password_reset(
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import posts, tags, snapshots, net
|
from szurubooru.func import posts, tags, snapshots, net
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,16 +8,16 @@ from szurubooru.func import posts, tags, snapshots, net
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'posts:create:anonymous': db.User.RANK_REGULAR,
|
'posts:create:anonymous': model.User.RANK_REGULAR,
|
||||||
'posts:create:identified': db.User.RANK_REGULAR,
|
'posts:create:identified': model.User.RANK_REGULAR,
|
||||||
'tags:create': db.User.RANK_REGULAR,
|
'tags:create': model.User.RANK_REGULAR,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_creating_minimal_posts(
|
def test_creating_minimal_posts(
|
||||||
context_factory, post_factory, user_factory):
|
context_factory, post_factory, user_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
@ -53,20 +53,20 @@ def test_creating_minimal_posts(
|
||||||
posts.update_post_thumbnail.assert_called_once_with(
|
posts.update_post_thumbnail.assert_called_once_with(
|
||||||
post, 'post-thumbnail')
|
post, 'post-thumbnail')
|
||||||
posts.update_post_safety.assert_called_once_with(post, 'safe')
|
posts.update_post_safety.assert_called_once_with(post, 'safe')
|
||||||
posts.update_post_source.assert_called_once_with(post, None)
|
posts.update_post_source.assert_called_once_with(post, '')
|
||||||
posts.update_post_relations.assert_called_once_with(post, [])
|
posts.update_post_relations.assert_called_once_with(post, [])
|
||||||
posts.update_post_notes.assert_called_once_with(post, [])
|
posts.update_post_notes.assert_called_once_with(post, [])
|
||||||
posts.update_post_flags.assert_called_once_with(post, [])
|
posts.update_post_flags.assert_called_once_with(post, [])
|
||||||
posts.update_post_thumbnail.assert_called_once_with(
|
posts.update_post_thumbnail.assert_called_once_with(
|
||||||
post, 'post-thumbnail')
|
post, 'post-thumbnail')
|
||||||
posts.serialize_post.assert_called_once_with(
|
posts.serialize_post.assert_called_once_with(
|
||||||
post, auth_user, options=None)
|
post, auth_user, options=[])
|
||||||
snapshots.create.assert_called_once_with(post, auth_user)
|
snapshots.create.assert_called_once_with(post, auth_user)
|
||||||
tags.export_to_json.assert_called_once_with()
|
tags.export_to_json.assert_called_once_with()
|
||||||
|
|
||||||
|
|
||||||
def test_creating_full_posts(context_factory, post_factory, user_factory):
|
def test_creating_full_posts(context_factory, post_factory, user_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
@ -109,14 +109,14 @@ def test_creating_full_posts(context_factory, post_factory, user_factory):
|
||||||
posts.update_post_flags.assert_called_once_with(
|
posts.update_post_flags.assert_called_once_with(
|
||||||
post, ['flag1', 'flag2'])
|
post, ['flag1', 'flag2'])
|
||||||
posts.serialize_post.assert_called_once_with(
|
posts.serialize_post.assert_called_once_with(
|
||||||
post, auth_user, options=None)
|
post, auth_user, options=[])
|
||||||
snapshots.create.assert_called_once_with(post, auth_user)
|
snapshots.create.assert_called_once_with(post, auth_user)
|
||||||
tags.export_to_json.assert_called_once_with()
|
tags.export_to_json.assert_called_once_with()
|
||||||
|
|
||||||
|
|
||||||
def test_anonymous_uploads(
|
def test_anonymous_uploads(
|
||||||
config_injector, context_factory, post_factory, user_factory):
|
config_injector, context_factory, post_factory, user_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
@ -126,7 +126,7 @@ def test_anonymous_uploads(
|
||||||
patch('szurubooru.func.posts.create_post'), \
|
patch('szurubooru.func.posts.create_post'), \
|
||||||
patch('szurubooru.func.posts.update_post_source'):
|
patch('szurubooru.func.posts.update_post_source'):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {'posts:create:anonymous': db.User.RANK_REGULAR},
|
'privileges': {'posts:create:anonymous': model.User.RANK_REGULAR},
|
||||||
})
|
})
|
||||||
posts.create_post.return_value = [post, []]
|
posts.create_post.return_value = [post, []]
|
||||||
api.post_api.create_post(
|
api.post_api.create_post(
|
||||||
|
@ -146,7 +146,7 @@ def test_anonymous_uploads(
|
||||||
|
|
||||||
def test_creating_from_url_saves_source(
|
def test_creating_from_url_saves_source(
|
||||||
config_injector, context_factory, post_factory, user_factory):
|
config_injector, context_factory, post_factory, user_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
@ -157,7 +157,7 @@ def test_creating_from_url_saves_source(
|
||||||
patch('szurubooru.func.posts.create_post'), \
|
patch('szurubooru.func.posts.create_post'), \
|
||||||
patch('szurubooru.func.posts.update_post_source'):
|
patch('szurubooru.func.posts.update_post_source'):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {'posts:create:identified': db.User.RANK_REGULAR},
|
'privileges': {'posts:create:identified': model.User.RANK_REGULAR},
|
||||||
})
|
})
|
||||||
net.download.return_value = b'content'
|
net.download.return_value = b'content'
|
||||||
posts.create_post.return_value = [post, []]
|
posts.create_post.return_value = [post, []]
|
||||||
|
@ -177,7 +177,7 @@ def test_creating_from_url_saves_source(
|
||||||
|
|
||||||
def test_creating_from_url_with_source_specified(
|
def test_creating_from_url_with_source_specified(
|
||||||
config_injector, context_factory, post_factory, user_factory):
|
config_injector, context_factory, post_factory, user_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
@ -188,7 +188,7 @@ def test_creating_from_url_with_source_specified(
|
||||||
patch('szurubooru.func.posts.create_post'), \
|
patch('szurubooru.func.posts.create_post'), \
|
||||||
patch('szurubooru.func.posts.update_post_source'):
|
patch('szurubooru.func.posts.update_post_source'):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {'posts:create:identified': db.User.RANK_REGULAR},
|
'privileges': {'posts:create:identified': model.User.RANK_REGULAR},
|
||||||
})
|
})
|
||||||
net.download.return_value = b'content'
|
net.download.return_value = b'content'
|
||||||
posts.create_post.return_value = [post, []]
|
posts.create_post.return_value = [post, []]
|
||||||
|
@ -218,14 +218,14 @@ def test_trying_to_omit_mandatory_field(context_factory, user_factory, field):
|
||||||
context_factory(
|
context_factory(
|
||||||
params=params,
|
params=params,
|
||||||
files={'content': '...'},
|
files={'content': '...'},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'field', ['tags', 'relations', 'source', 'notes', 'flags'])
|
'field', ['tags', 'relations', 'source', 'notes', 'flags'])
|
||||||
def test_omitting_optional_field(
|
def test_omitting_optional_field(
|
||||||
field, context_factory, post_factory, user_factory):
|
field, context_factory, post_factory, user_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
@ -268,10 +268,10 @@ def test_errors_not_spending_ids(
|
||||||
'post_height': 300,
|
'post_height': 300,
|
||||||
},
|
},
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'posts:create:identified': db.User.RANK_REGULAR,
|
'posts:create:identified': model.User.RANK_REGULAR,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
|
|
||||||
# successful request
|
# successful request
|
||||||
with patch('szurubooru.func.posts.serialize_post'), \
|
with patch('szurubooru.func.posts.serialize_post'), \
|
||||||
|
@ -316,7 +316,7 @@ def test_trying_to_omit_content(context_factory, user_factory):
|
||||||
'safety': 'safe',
|
'safety': 'safe',
|
||||||
'tags': ['tag1', 'tag2'],
|
'tags': ['tag1', 'tag2'],
|
||||||
},
|
},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_create_post_without_privileges(
|
def test_trying_to_create_post_without_privileges(
|
||||||
|
@ -324,16 +324,16 @@ def test_trying_to_create_post_without_privileges(
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.post_api.create_post(context_factory(
|
api.post_api.create_post(context_factory(
|
||||||
params='whatever',
|
params='whatever',
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_create_tags_without_privileges(
|
def test_trying_to_create_tags_without_privileges(
|
||||||
config_injector, context_factory, user_factory):
|
config_injector, context_factory, user_factory):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'posts:create:anonymous': db.User.RANK_REGULAR,
|
'posts:create:anonymous': model.User.RANK_REGULAR,
|
||||||
'posts:create:identified': db.User.RANK_REGULAR,
|
'posts:create:identified': model.User.RANK_REGULAR,
|
||||||
'tags:create': db.User.RANK_ADMINISTRATOR,
|
'tags:create': model.User.RANK_ADMINISTRATOR,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
with pytest.raises(errors.AuthError), \
|
with pytest.raises(errors.AuthError), \
|
||||||
|
@ -349,4 +349,4 @@ def test_trying_to_create_tags_without_privileges(
|
||||||
files={
|
files={
|
||||||
'content': posts.EMPTY_PIXEL,
|
'content': posts.EMPTY_PIXEL,
|
||||||
},
|
},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
|
|
|
@ -1,16 +1,16 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import posts, tags, snapshots
|
from szurubooru.func import posts, tags, snapshots
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'posts:delete': db.User.RANK_REGULAR}})
|
config_injector({'privileges': {'posts:delete': model.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_deleting(user_factory, post_factory, context_factory):
|
def test_deleting(user_factory, post_factory, context_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
post = post_factory(id=1)
|
post = post_factory(id=1)
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
@ -20,7 +20,7 @@ def test_deleting(user_factory, post_factory, context_factory):
|
||||||
context_factory(params={'version': 1}, user=auth_user),
|
context_factory(params={'version': 1}, user=auth_user),
|
||||||
{'post_id': 1})
|
{'post_id': 1})
|
||||||
assert result == {}
|
assert result == {}
|
||||||
assert db.session.query(db.Post).count() == 0
|
assert db.session.query(model.Post).count() == 0
|
||||||
snapshots.delete.assert_called_once_with(post, auth_user)
|
snapshots.delete.assert_called_once_with(post, auth_user)
|
||||||
tags.export_to_json.assert_called_once_with()
|
tags.export_to_json.assert_called_once_with()
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ def test_deleting(user_factory, post_factory, context_factory):
|
||||||
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(posts.PostNotFoundError):
|
with pytest.raises(posts.PostNotFoundError):
|
||||||
api.post_api.delete_post(
|
api.post_api.delete_post(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'post_id': 999})
|
{'post_id': 999})
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,6 +38,6 @@ def test_trying_to_delete_without_privileges(
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.post_api.delete_post(
|
api.post_api.delete_post(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
|
||||||
{'post_id': 1})
|
{'post_id': 1})
|
||||||
assert db.session.query(db.Post).count() == 1
|
assert db.session.query(model.Post).count() == 1
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import posts
|
from szurubooru.func import posts
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'posts:favorite': db.User.RANK_REGULAR}})
|
config_injector(
|
||||||
|
{'privileges': {'posts:favorite': model.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_adding_to_favorites(
|
def test_adding_to_favorites(
|
||||||
|
@ -23,8 +24,8 @@ def test_adding_to_favorites(
|
||||||
context_factory(user=user_factory()),
|
context_factory(user=user_factory()),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
assert result == 'serialized post'
|
assert result == 'serialized post'
|
||||||
post = db.session.query(db.Post).one()
|
post = db.session.query(model.Post).one()
|
||||||
assert db.session.query(db.PostFavorite).count() == 1
|
assert db.session.query(model.PostFavorite).count() == 1
|
||||||
assert post is not None
|
assert post is not None
|
||||||
assert post.favorite_count == 1
|
assert post.favorite_count == 1
|
||||||
assert post.score == 1
|
assert post.score == 1
|
||||||
|
@ -47,9 +48,9 @@ def test_removing_from_favorites(
|
||||||
api.post_api.delete_post_from_favorites(
|
api.post_api.delete_post_from_favorites(
|
||||||
context_factory(user=user),
|
context_factory(user=user),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
post = db.session.query(db.Post).one()
|
post = db.session.query(model.Post).one()
|
||||||
assert post.score == 1
|
assert post.score == 1
|
||||||
assert db.session.query(db.PostFavorite).count() == 0
|
assert db.session.query(model.PostFavorite).count() == 0
|
||||||
assert post.favorite_count == 0
|
assert post.favorite_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@ -68,8 +69,8 @@ def test_favoriting_twice(
|
||||||
api.post_api.add_post_to_favorites(
|
api.post_api.add_post_to_favorites(
|
||||||
context_factory(user=user),
|
context_factory(user=user),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
post = db.session.query(db.Post).one()
|
post = db.session.query(model.Post).one()
|
||||||
assert db.session.query(db.PostFavorite).count() == 1
|
assert db.session.query(model.PostFavorite).count() == 1
|
||||||
assert post.favorite_count == 1
|
assert post.favorite_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@ -92,8 +93,8 @@ def test_removing_twice(
|
||||||
api.post_api.delete_post_from_favorites(
|
api.post_api.delete_post_from_favorites(
|
||||||
context_factory(user=user),
|
context_factory(user=user),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
post = db.session.query(db.Post).one()
|
post = db.session.query(model.Post).one()
|
||||||
assert db.session.query(db.PostFavorite).count() == 0
|
assert db.session.query(model.PostFavorite).count() == 0
|
||||||
assert post.favorite_count == 0
|
assert post.favorite_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@ -113,8 +114,8 @@ def test_favorites_from_multiple_users(
|
||||||
api.post_api.add_post_to_favorites(
|
api.post_api.add_post_to_favorites(
|
||||||
context_factory(user=user2),
|
context_factory(user=user2),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
post = db.session.query(db.Post).one()
|
post = db.session.query(model.Post).one()
|
||||||
assert db.session.query(db.PostFavorite).count() == 2
|
assert db.session.query(model.PostFavorite).count() == 2
|
||||||
assert post.favorite_count == 2
|
assert post.favorite_count == 2
|
||||||
assert post.last_favorite_time == datetime(1997, 12, 2)
|
assert post.last_favorite_time == datetime(1997, 12, 2)
|
||||||
|
|
||||||
|
@ -133,5 +134,5 @@ def test_trying_to_rate_without_privileges(
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.post_api.add_post_to_favorites(
|
api.post_api.add_post_to_favorites(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import posts, snapshots
|
from szurubooru.func import posts, snapshots
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,14 +8,14 @@ from szurubooru.func import posts, snapshots
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'posts:feature': db.User.RANK_REGULAR,
|
'posts:feature': model.User.RANK_REGULAR,
|
||||||
'posts:view': db.User.RANK_REGULAR,
|
'posts:view': model.User.RANK_REGULAR,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_featuring(user_factory, post_factory, context_factory):
|
def test_featuring(user_factory, post_factory, context_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
post = post_factory(id=1)
|
post = post_factory(id=1)
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
@ -31,7 +31,7 @@ def test_featuring(user_factory, post_factory, context_factory):
|
||||||
assert posts.get_post_by_id(1).is_featured
|
assert posts.get_post_by_id(1).is_featured
|
||||||
result = api.post_api.get_featured_post(
|
result = api.post_api.get_featured_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
assert result == 'serialized post'
|
assert result == 'serialized post'
|
||||||
snapshots.modify.assert_called_once_with(post, auth_user)
|
snapshots.modify.assert_called_once_with(post, auth_user)
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ def test_trying_to_omit_required_parameter(user_factory, context_factory):
|
||||||
with pytest.raises(errors.MissingRequiredParameterError):
|
with pytest.raises(errors.MissingRequiredParameterError):
|
||||||
api.post_api.set_featured_post(
|
api.post_api.set_featured_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_feature_the_same_post_twice(
|
def test_trying_to_feature_the_same_post_twice(
|
||||||
|
@ -51,12 +51,12 @@ def test_trying_to_feature_the_same_post_twice(
|
||||||
api.post_api.set_featured_post(
|
api.post_api.set_featured_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'id': 1},
|
params={'id': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
with pytest.raises(posts.PostAlreadyFeaturedError):
|
with pytest.raises(posts.PostAlreadyFeaturedError):
|
||||||
api.post_api.set_featured_post(
|
api.post_api.set_featured_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'id': 1},
|
params={'id': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_featuring_one_post_after_another(
|
def test_featuring_one_post_after_another(
|
||||||
|
@ -72,12 +72,12 @@ def test_featuring_one_post_after_another(
|
||||||
api.post_api.set_featured_post(
|
api.post_api.set_featured_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'id': 1},
|
params={'id': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
with fake_datetime('1998'):
|
with fake_datetime('1998'):
|
||||||
api.post_api.set_featured_post(
|
api.post_api.set_featured_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'id': 2},
|
params={'id': 2},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
assert posts.try_get_featured_post() is not None
|
assert posts.try_get_featured_post() is not None
|
||||||
assert posts.try_get_featured_post().post_id == 2
|
assert posts.try_get_featured_post().post_id == 2
|
||||||
assert not posts.get_post_by_id(1).is_featured
|
assert not posts.get_post_by_id(1).is_featured
|
||||||
|
@ -89,7 +89,7 @@ def test_trying_to_feature_non_existing(user_factory, context_factory):
|
||||||
api.post_api.set_featured_post(
|
api.post_api.set_featured_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'id': 1},
|
params={'id': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_feature_without_privileges(user_factory, context_factory):
|
def test_trying_to_feature_without_privileges(user_factory, context_factory):
|
||||||
|
@ -97,10 +97,10 @@ def test_trying_to_feature_without_privileges(user_factory, context_factory):
|
||||||
api.post_api.set_featured_post(
|
api.post_api.set_featured_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'id': 1},
|
params={'id': 1},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
|
||||||
|
|
||||||
|
|
||||||
def test_getting_featured_post_without_privileges_to_view(
|
def test_getting_featured_post_without_privileges_to_view(
|
||||||
user_factory, context_factory):
|
user_factory, context_factory):
|
||||||
api.post_api.get_featured_post(
|
api.post_api.get_featured_post(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)))
|
||||||
|
|
|
@ -1,16 +1,16 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import posts, snapshots
|
from szurubooru.func import posts, snapshots
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'posts:merge': db.User.RANK_REGULAR}})
|
config_injector({'privileges': {'posts:merge': model.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_merging(user_factory, context_factory, post_factory):
|
def test_merging(user_factory, context_factory, post_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
source_post = post_factory()
|
source_post = post_factory()
|
||||||
target_post = post_factory()
|
target_post = post_factory()
|
||||||
db.session.add_all([source_post, target_post])
|
db.session.add_all([source_post, target_post])
|
||||||
|
@ -25,6 +25,7 @@ def test_merging(user_factory, context_factory, post_factory):
|
||||||
'mergeToVersion': 1,
|
'mergeToVersion': 1,
|
||||||
'remove': source_post.post_id,
|
'remove': source_post.post_id,
|
||||||
'mergeTo': target_post.post_id,
|
'mergeTo': target_post.post_id,
|
||||||
|
'replaceContent': False,
|
||||||
},
|
},
|
||||||
user=auth_user))
|
user=auth_user))
|
||||||
posts.merge_posts.called_once_with(source_post, target_post)
|
posts.merge_posts.called_once_with(source_post, target_post)
|
||||||
|
@ -45,13 +46,14 @@ def test_trying_to_omit_mandatory_field(
|
||||||
'mergeToVersion': 1,
|
'mergeToVersion': 1,
|
||||||
'remove': source_post.post_id,
|
'remove': source_post.post_id,
|
||||||
'mergeTo': target_post.post_id,
|
'mergeTo': target_post.post_id,
|
||||||
|
'replaceContent': False,
|
||||||
}
|
}
|
||||||
del params[field]
|
del params[field]
|
||||||
with pytest.raises(errors.ValidationError):
|
with pytest.raises(errors.ValidationError):
|
||||||
api.post_api.merge_posts(
|
api.post_api.merge_posts(
|
||||||
context_factory(
|
context_factory(
|
||||||
params=params,
|
params=params,
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_merge_non_existing(
|
def test_trying_to_merge_non_existing(
|
||||||
|
@ -63,12 +65,12 @@ def test_trying_to_merge_non_existing(
|
||||||
api.post_api.merge_posts(
|
api.post_api.merge_posts(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'remove': post.post_id, 'mergeTo': 999},
|
params={'remove': post.post_id, 'mergeTo': 999},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
with pytest.raises(posts.PostNotFoundError):
|
with pytest.raises(posts.PostNotFoundError):
|
||||||
api.post_api.merge_posts(
|
api.post_api.merge_posts(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'remove': 999, 'mergeTo': post.post_id},
|
params={'remove': 999, 'mergeTo': post.post_id},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_merge_without_privileges(
|
def test_trying_to_merge_without_privileges(
|
||||||
|
@ -85,5 +87,6 @@ def test_trying_to_merge_without_privileges(
|
||||||
'mergeToVersion': 1,
|
'mergeToVersion': 1,
|
||||||
'remove': source_post.post_id,
|
'remove': source_post.post_id,
|
||||||
'mergeTo': target_post.post_id,
|
'mergeTo': target_post.post_id,
|
||||||
|
'replaceContent': False,
|
||||||
},
|
},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import posts
|
from szurubooru.func import posts
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'posts:score': db.User.RANK_REGULAR}})
|
config_injector({'privileges': {'posts:score': model.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_simple_rating(
|
def test_simple_rating(
|
||||||
|
@ -22,8 +22,8 @@ def test_simple_rating(
|
||||||
params={'score': 1}, user=user_factory()),
|
params={'score': 1}, user=user_factory()),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
assert result == 'serialized post'
|
assert result == 'serialized post'
|
||||||
post = db.session.query(db.Post).one()
|
post = db.session.query(model.Post).one()
|
||||||
assert db.session.query(db.PostScore).count() == 1
|
assert db.session.query(model.PostScore).count() == 1
|
||||||
assert post is not None
|
assert post is not None
|
||||||
assert post.score == 1
|
assert post.score == 1
|
||||||
|
|
||||||
|
@ -43,8 +43,8 @@ def test_updating_rating(
|
||||||
api.post_api.set_post_score(
|
api.post_api.set_post_score(
|
||||||
context_factory(params={'score': -1}, user=user),
|
context_factory(params={'score': -1}, user=user),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
post = db.session.query(db.Post).one()
|
post = db.session.query(model.Post).one()
|
||||||
assert db.session.query(db.PostScore).count() == 1
|
assert db.session.query(model.PostScore).count() == 1
|
||||||
assert post.score == -1
|
assert post.score == -1
|
||||||
|
|
||||||
|
|
||||||
|
@ -63,8 +63,8 @@ def test_updating_rating_to_zero(
|
||||||
api.post_api.set_post_score(
|
api.post_api.set_post_score(
|
||||||
context_factory(params={'score': 0}, user=user),
|
context_factory(params={'score': 0}, user=user),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
post = db.session.query(db.Post).one()
|
post = db.session.query(model.Post).one()
|
||||||
assert db.session.query(db.PostScore).count() == 0
|
assert db.session.query(model.PostScore).count() == 0
|
||||||
assert post.score == 0
|
assert post.score == 0
|
||||||
|
|
||||||
|
|
||||||
|
@ -83,8 +83,8 @@ def test_deleting_rating(
|
||||||
api.post_api.delete_post_score(
|
api.post_api.delete_post_score(
|
||||||
context_factory(user=user),
|
context_factory(user=user),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
post = db.session.query(db.Post).one()
|
post = db.session.query(model.Post).one()
|
||||||
assert db.session.query(db.PostScore).count() == 0
|
assert db.session.query(model.PostScore).count() == 0
|
||||||
assert post.score == 0
|
assert post.score == 0
|
||||||
|
|
||||||
|
|
||||||
|
@ -104,8 +104,8 @@ def test_ratings_from_multiple_users(
|
||||||
api.post_api.set_post_score(
|
api.post_api.set_post_score(
|
||||||
context_factory(params={'score': -1}, user=user2),
|
context_factory(params={'score': -1}, user=user2),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
post = db.session.query(db.Post).one()
|
post = db.session.query(model.Post).one()
|
||||||
assert db.session.query(db.PostScore).count() == 2
|
assert db.session.query(model.PostScore).count() == 2
|
||||||
assert post.score == 0
|
assert post.score == 0
|
||||||
|
|
||||||
|
|
||||||
|
@ -136,5 +136,5 @@ def test_trying_to_rate_without_privileges(
|
||||||
api.post_api.set_post_score(
|
api.post_api.set_post_score(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'score': 1},
|
params={'score': 1},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import posts
|
from szurubooru.func import posts
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,8 +9,8 @@ from szurubooru.func import posts
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'posts:list': db.User.RANK_REGULAR,
|
'posts:list': model.User.RANK_REGULAR,
|
||||||
'posts:view': db.User.RANK_REGULAR,
|
'posts:view': model.User.RANK_REGULAR,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -25,7 +25,7 @@ def test_retrieving_multiple(user_factory, post_factory, context_factory):
|
||||||
result = api.post_api.get_posts(
|
result = api.post_api.get_posts(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'query': '', 'page': 1},
|
params={'query': '', 'page': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
assert result == {
|
assert result == {
|
||||||
'query': '',
|
'query': '',
|
||||||
'page': 1,
|
'page': 1,
|
||||||
|
@ -36,10 +36,10 @@ def test_retrieving_multiple(user_factory, post_factory, context_factory):
|
||||||
|
|
||||||
|
|
||||||
def test_using_special_tokens(user_factory, post_factory, context_factory):
|
def test_using_special_tokens(user_factory, post_factory, context_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
post1 = post_factory(id=1)
|
post1 = post_factory(id=1)
|
||||||
post2 = post_factory(id=2)
|
post2 = post_factory(id=2)
|
||||||
post1.favorited_by = [db.PostFavorite(
|
post1.favorited_by = [model.PostFavorite(
|
||||||
user=auth_user, time=datetime.utcnow())]
|
user=auth_user, time=datetime.utcnow())]
|
||||||
db.session.add_all([post1, post2, auth_user])
|
db.session.add_all([post1, post2, auth_user])
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
@ -68,7 +68,7 @@ def test_trying_to_use_special_tokens_without_logging_in(
|
||||||
api.post_api.get_posts(
|
api.post_api.get_posts(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'query': 'special:fav', 'page': 1},
|
params={'query': 'special:fav', 'page': 1},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_retrieve_multiple_without_privileges(
|
def test_trying_to_retrieve_multiple_without_privileges(
|
||||||
|
@ -77,7 +77,7 @@ def test_trying_to_retrieve_multiple_without_privileges(
|
||||||
api.post_api.get_posts(
|
api.post_api.get_posts(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'query': '', 'page': 1},
|
params={'query': '', 'page': 1},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
|
||||||
|
|
||||||
|
|
||||||
def test_retrieving_single(user_factory, post_factory, context_factory):
|
def test_retrieving_single(user_factory, post_factory, context_factory):
|
||||||
|
@ -86,7 +86,7 @@ def test_retrieving_single(user_factory, post_factory, context_factory):
|
||||||
with patch('szurubooru.func.posts.serialize_post'):
|
with patch('szurubooru.func.posts.serialize_post'):
|
||||||
posts.serialize_post.return_value = 'serialized post'
|
posts.serialize_post.return_value = 'serialized post'
|
||||||
result = api.post_api.get_post(
|
result = api.post_api.get_post(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'post_id': 1})
|
{'post_id': 1})
|
||||||
assert result == 'serialized post'
|
assert result == 'serialized post'
|
||||||
|
|
||||||
|
@ -94,7 +94,7 @@ def test_retrieving_single(user_factory, post_factory, context_factory):
|
||||||
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(posts.PostNotFoundError):
|
with pytest.raises(posts.PostNotFoundError):
|
||||||
api.post_api.get_post(
|
api.post_api.get_post(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'post_id': 999})
|
{'post_id': 999})
|
||||||
|
|
||||||
|
|
||||||
|
@ -102,5 +102,5 @@ def test_trying_to_retrieve_single_without_privileges(
|
||||||
user_factory, context_factory):
|
user_factory, context_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.post_api.get_post(
|
api.post_api.get_post(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
|
||||||
{'post_id': 999})
|
{'post_id': 999})
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import posts, tags, snapshots, net
|
from szurubooru.func import posts, tags, snapshots, net
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,22 +9,22 @@ from szurubooru.func import posts, tags, snapshots, net
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'posts:edit:tags': db.User.RANK_REGULAR,
|
'posts:edit:tags': model.User.RANK_REGULAR,
|
||||||
'posts:edit:content': db.User.RANK_REGULAR,
|
'posts:edit:content': model.User.RANK_REGULAR,
|
||||||
'posts:edit:safety': db.User.RANK_REGULAR,
|
'posts:edit:safety': model.User.RANK_REGULAR,
|
||||||
'posts:edit:source': db.User.RANK_REGULAR,
|
'posts:edit:source': model.User.RANK_REGULAR,
|
||||||
'posts:edit:relations': db.User.RANK_REGULAR,
|
'posts:edit:relations': model.User.RANK_REGULAR,
|
||||||
'posts:edit:notes': db.User.RANK_REGULAR,
|
'posts:edit:notes': model.User.RANK_REGULAR,
|
||||||
'posts:edit:flags': db.User.RANK_REGULAR,
|
'posts:edit:flags': model.User.RANK_REGULAR,
|
||||||
'posts:edit:thumbnail': db.User.RANK_REGULAR,
|
'posts:edit:thumbnail': model.User.RANK_REGULAR,
|
||||||
'tags:create': db.User.RANK_MODERATOR,
|
'tags:create': model.User.RANK_MODERATOR,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_post_updating(
|
def test_post_updating(
|
||||||
context_factory, post_factory, user_factory, fake_datetime):
|
context_factory, post_factory, user_factory, fake_datetime):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
post = post_factory()
|
post = post_factory()
|
||||||
db.session.add(post)
|
db.session.add(post)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
@ -76,7 +76,7 @@ def test_post_updating(
|
||||||
posts.update_post_flags.assert_called_once_with(
|
posts.update_post_flags.assert_called_once_with(
|
||||||
post, ['flag1', 'flag2'])
|
post, ['flag1', 'flag2'])
|
||||||
posts.serialize_post.assert_called_once_with(
|
posts.serialize_post.assert_called_once_with(
|
||||||
post, auth_user, options=None)
|
post, auth_user, options=[])
|
||||||
snapshots.modify.assert_called_once_with(post, auth_user)
|
snapshots.modify.assert_called_once_with(post, auth_user)
|
||||||
tags.export_to_json.assert_called_once_with()
|
tags.export_to_json.assert_called_once_with()
|
||||||
assert post.last_edit_time == datetime(1997, 1, 1)
|
assert post.last_edit_time == datetime(1997, 1, 1)
|
||||||
|
@ -97,7 +97,7 @@ def test_uploading_from_url_saves_source(
|
||||||
api.post_api.update_post(
|
api.post_api.update_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'contentUrl': 'example.com', 'version': 1},
|
params={'contentUrl': 'example.com', 'version': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
net.download.assert_called_once_with('example.com')
|
net.download.assert_called_once_with('example.com')
|
||||||
posts.update_post_content.assert_called_once_with(post, b'content')
|
posts.update_post_content.assert_called_once_with(post, b'content')
|
||||||
|
@ -122,7 +122,7 @@ def test_uploading_from_url_with_source_specified(
|
||||||
'contentUrl': 'example.com',
|
'contentUrl': 'example.com',
|
||||||
'source': 'example2.com',
|
'source': 'example2.com',
|
||||||
'version': 1},
|
'version': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
net.download.assert_called_once_with('example.com')
|
net.download.assert_called_once_with('example.com')
|
||||||
posts.update_post_content.assert_called_once_with(post, b'content')
|
posts.update_post_content.assert_called_once_with(post, b'content')
|
||||||
|
@ -134,7 +134,7 @@ def test_trying_to_update_non_existing(context_factory, user_factory):
|
||||||
api.post_api.update_post(
|
api.post_api.update_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
params='whatever',
|
params='whatever',
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'post_id': 1})
|
{'post_id': 1})
|
||||||
|
|
||||||
|
|
||||||
|
@ -158,7 +158,7 @@ def test_trying_to_update_field_without_privileges(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={**params, **{'version': 1}},
|
params={**params, **{'version': 1}},
|
||||||
files=files,
|
files=files,
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
|
|
||||||
|
|
||||||
|
@ -173,5 +173,5 @@ def test_trying_to_create_tags_without_privileges(
|
||||||
api.post_api.update_post(
|
api.post_api.update_post(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'tags': ['tag1', 'tag2'], 'version': 1},
|
params={'tags': ['tag1', 'tag2'], 'version': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'post_id': post.post_id})
|
{'post_id': post.post_id})
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
|
|
||||||
|
|
||||||
def snapshot_factory():
|
def snapshot_factory():
|
||||||
snapshot = db.Snapshot()
|
snapshot = model.Snapshot()
|
||||||
snapshot.creation_time = datetime(1999, 1, 1)
|
snapshot.creation_time = datetime(1999, 1, 1)
|
||||||
snapshot.resource_type = 'dummy'
|
snapshot.resource_type = 'dummy'
|
||||||
snapshot.resource_pkey = 1
|
snapshot.resource_pkey = 1
|
||||||
|
@ -17,7 +17,7 @@ def snapshot_factory():
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {'snapshots:list': db.User.RANK_REGULAR},
|
'privileges': {'snapshots:list': model.User.RANK_REGULAR},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ def test_retrieving_multiple(user_factory, context_factory):
|
||||||
result = api.snapshot_api.get_snapshots(
|
result = api.snapshot_api.get_snapshots(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'query': '', 'page': 1},
|
params={'query': '', 'page': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
assert result['query'] == ''
|
assert result['query'] == ''
|
||||||
assert result['page'] == 1
|
assert result['page'] == 1
|
||||||
assert result['pageSize'] == 100
|
assert result['pageSize'] == 100
|
||||||
|
@ -43,4 +43,4 @@ def test_trying_to_retrieve_multiple_without_privileges(
|
||||||
api.snapshot_api.get_snapshots(
|
api.snapshot_api.get_snapshots(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'query': '', 'page': 1},
|
params={'query': '', 'page': 1},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import tag_categories, tags, snapshots
|
from szurubooru.func import tag_categories, tags, snapshots
|
||||||
|
|
||||||
|
|
||||||
|
@ -11,13 +11,13 @@ def _update_category_name(category, name):
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {'tag_categories:create': db.User.RANK_REGULAR},
|
'privileges': {'tag_categories:create': model.User.RANK_REGULAR},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_creating_category(
|
def test_creating_category(
|
||||||
tag_category_factory, user_factory, context_factory):
|
tag_category_factory, user_factory, context_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
category = tag_category_factory(name='meta')
|
category = tag_category_factory(name='meta')
|
||||||
db.session.add(category)
|
db.session.add(category)
|
||||||
|
|
||||||
|
@ -49,7 +49,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
|
||||||
api.tag_category_api.create_tag_category(
|
api.tag_category_api.create_tag_category(
|
||||||
context_factory(
|
context_factory(
|
||||||
params=params,
|
params=params,
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_create_without_privileges(user_factory, context_factory):
|
def test_trying_to_create_without_privileges(user_factory, context_factory):
|
||||||
|
@ -57,4 +57,4 @@ def test_trying_to_create_without_privileges(user_factory, context_factory):
|
||||||
api.tag_category_api.create_tag_category(
|
api.tag_category_api.create_tag_category(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'name': 'meta', 'color': 'black'},
|
params={'name': 'meta', 'color': 'black'},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
|
||||||
|
|
|
@ -1,18 +1,18 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import tag_categories, tags, snapshots
|
from szurubooru.func import tag_categories, tags, snapshots
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {'tag_categories:delete': db.User.RANK_REGULAR},
|
'privileges': {'tag_categories:delete': model.User.RANK_REGULAR},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_deleting(user_factory, tag_category_factory, context_factory):
|
def test_deleting(user_factory, tag_category_factory, context_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
category = tag_category_factory(name='category')
|
category = tag_category_factory(name='category')
|
||||||
db.session.add(tag_category_factory(name='root'))
|
db.session.add(tag_category_factory(name='root'))
|
||||||
db.session.add(category)
|
db.session.add(category)
|
||||||
|
@ -23,8 +23,8 @@ def test_deleting(user_factory, tag_category_factory, context_factory):
|
||||||
context_factory(params={'version': 1}, user=auth_user),
|
context_factory(params={'version': 1}, user=auth_user),
|
||||||
{'category_name': 'category'})
|
{'category_name': 'category'})
|
||||||
assert result == {}
|
assert result == {}
|
||||||
assert db.session.query(db.TagCategory).count() == 1
|
assert db.session.query(model.TagCategory).count() == 1
|
||||||
assert db.session.query(db.TagCategory).one().name == 'root'
|
assert db.session.query(model.TagCategory).one().name == 'root'
|
||||||
snapshots.delete.assert_called_once_with(category, auth_user)
|
snapshots.delete.assert_called_once_with(category, auth_user)
|
||||||
tags.export_to_json.assert_called_once_with()
|
tags.export_to_json.assert_called_once_with()
|
||||||
|
|
||||||
|
@ -41,9 +41,9 @@ def test_trying_to_delete_used(
|
||||||
api.tag_category_api.delete_tag_category(
|
api.tag_category_api.delete_tag_category(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'version': 1},
|
params={'version': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'category_name': 'category'})
|
{'category_name': 'category'})
|
||||||
assert db.session.query(db.TagCategory).count() == 1
|
assert db.session.query(model.TagCategory).count() == 1
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_delete_last(
|
def test_trying_to_delete_last(
|
||||||
|
@ -54,14 +54,14 @@ def test_trying_to_delete_last(
|
||||||
api.tag_category_api.delete_tag_category(
|
api.tag_category_api.delete_tag_category(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'version': 1},
|
params={'version': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'category_name': 'root'})
|
{'category_name': 'root'})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(tag_categories.TagCategoryNotFoundError):
|
with pytest.raises(tag_categories.TagCategoryNotFoundError):
|
||||||
api.tag_category_api.delete_tag_category(
|
api.tag_category_api.delete_tag_category(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'category_name': 'bad'})
|
{'category_name': 'bad'})
|
||||||
|
|
||||||
|
|
||||||
|
@ -73,6 +73,6 @@ def test_trying_to_delete_without_privileges(
|
||||||
api.tag_category_api.delete_tag_category(
|
api.tag_category_api.delete_tag_category(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'version': 1},
|
params={'version': 1},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)),
|
||||||
{'category_name': 'category'})
|
{'category_name': 'category'})
|
||||||
assert db.session.query(db.TagCategory).count() == 1
|
assert db.session.query(model.TagCategory).count() == 1
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import tag_categories
|
from szurubooru.func import tag_categories
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,8 +7,8 @@ from szurubooru.func import tag_categories
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'tag_categories:list': db.User.RANK_REGULAR,
|
'tag_categories:list': model.User.RANK_REGULAR,
|
||||||
'tag_categories:view': db.User.RANK_REGULAR,
|
'tag_categories:view': model.User.RANK_REGULAR,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -21,7 +21,7 @@ def test_retrieving_multiple(
|
||||||
])
|
])
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
result = api.tag_category_api.get_tag_categories(
|
result = api.tag_category_api.get_tag_categories(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)))
|
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
assert [cat['name'] for cat in result['results']] == ['c1', 'c2']
|
assert [cat['name'] for cat in result['results']] == ['c1', 'c2']
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ def test_retrieving_single(
|
||||||
db.session.add(tag_category_factory(name='cat'))
|
db.session.add(tag_category_factory(name='cat'))
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
result = api.tag_category_api.get_tag_category(
|
result = api.tag_category_api.get_tag_category(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'category_name': 'cat'})
|
{'category_name': 'cat'})
|
||||||
assert result == {
|
assert result == {
|
||||||
'name': 'cat',
|
'name': 'cat',
|
||||||
|
@ -44,7 +44,7 @@ def test_retrieving_single(
|
||||||
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(tag_categories.TagCategoryNotFoundError):
|
with pytest.raises(tag_categories.TagCategoryNotFoundError):
|
||||||
api.tag_category_api.get_tag_category(
|
api.tag_category_api.get_tag_category(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'category_name': '-'})
|
{'category_name': '-'})
|
||||||
|
|
||||||
|
|
||||||
|
@ -52,5 +52,5 @@ def test_trying_to_retrieve_single_without_privileges(
|
||||||
user_factory, context_factory):
|
user_factory, context_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.tag_category_api.get_tag_category(
|
api.tag_category_api.get_tag_category(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
|
||||||
{'category_name': '-'})
|
{'category_name': '-'})
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import tag_categories, tags, snapshots
|
from szurubooru.func import tag_categories, tags, snapshots
|
||||||
|
|
||||||
|
|
||||||
|
@ -12,15 +12,15 @@ def _update_category_name(category, name):
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'tag_categories:edit:name': db.User.RANK_REGULAR,
|
'tag_categories:edit:name': model.User.RANK_REGULAR,
|
||||||
'tag_categories:edit:color': db.User.RANK_REGULAR,
|
'tag_categories:edit:color': model.User.RANK_REGULAR,
|
||||||
'tag_categories:set_default': db.User.RANK_REGULAR,
|
'tag_categories:set_default': model.User.RANK_REGULAR,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_simple_updating(user_factory, tag_category_factory, context_factory):
|
def test_simple_updating(user_factory, tag_category_factory, context_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
category = tag_category_factory(name='name', color='black')
|
category = tag_category_factory(name='name', color='black')
|
||||||
db.session.add(category)
|
db.session.add(category)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
@ -61,7 +61,7 @@ def test_omitting_optional_field(
|
||||||
api.tag_category_api.update_tag_category(
|
api.tag_category_api.update_tag_category(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={**params, **{'version': 1}},
|
params={**params, **{'version': 1}},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'category_name': 'name'})
|
{'category_name': 'name'})
|
||||||
|
|
||||||
|
|
||||||
|
@ -70,7 +70,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
|
||||||
api.tag_category_api.update_tag_category(
|
api.tag_category_api.update_tag_category(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'name': ['dummy']},
|
params={'name': ['dummy']},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'category_name': 'bad'})
|
{'category_name': 'bad'})
|
||||||
|
|
||||||
|
|
||||||
|
@ -86,7 +86,7 @@ def test_trying_to_update_without_privileges(
|
||||||
api.tag_category_api.update_tag_category(
|
api.tag_category_api.update_tag_category(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={**params, **{'version': 1}},
|
params={**params, **{'version': 1}},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)),
|
||||||
{'category_name': 'dummy'})
|
{'category_name': 'dummy'})
|
||||||
|
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ def test_set_as_default(user_factory, tag_category_factory, context_factory):
|
||||||
'color': 'white',
|
'color': 'white',
|
||||||
'version': 1,
|
'version': 1,
|
||||||
},
|
},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'category_name': 'name'})
|
{'category_name': 'name'})
|
||||||
assert result == 'serialized category'
|
assert result == 'serialized category'
|
||||||
tag_categories.set_default_category.assert_called_once_with(category)
|
tag_categories.set_default_category.assert_called_once_with(category)
|
||||||
|
|
|
@ -1,16 +1,16 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import tags, snapshots
|
from szurubooru.func import tags, snapshots
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'tags:create': db.User.RANK_REGULAR}})
|
config_injector({'privileges': {'tags:create': model.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_creating_simple_tags(tag_factory, user_factory, context_factory):
|
def test_creating_simple_tags(tag_factory, user_factory, context_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
tag = tag_factory()
|
tag = tag_factory()
|
||||||
with patch('szurubooru.func.tags.create_tag'), \
|
with patch('szurubooru.func.tags.create_tag'), \
|
||||||
patch('szurubooru.func.tags.get_or_create_tags_by_names'), \
|
patch('szurubooru.func.tags.get_or_create_tags_by_names'), \
|
||||||
|
@ -50,7 +50,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
|
||||||
api.tag_api.create_tag(
|
api.tag_api.create_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
params=params,
|
params=params,
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('field', ['implications', 'suggestions'])
|
@pytest.mark.parametrize('field', ['implications', 'suggestions'])
|
||||||
|
@ -70,7 +70,7 @@ def test_omitting_optional_field(
|
||||||
api.tag_api.create_tag(
|
api.tag_api.create_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
params=params,
|
params=params,
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_create_tag_without_privileges(
|
def test_trying_to_create_tag_without_privileges(
|
||||||
|
@ -84,4 +84,4 @@ def test_trying_to_create_tag_without_privileges(
|
||||||
'suggestions': ['tag'],
|
'suggestions': ['tag'],
|
||||||
'implications': [],
|
'implications': [],
|
||||||
},
|
},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
|
||||||
|
|
|
@ -1,16 +1,16 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import tags, snapshots
|
from szurubooru.func import tags, snapshots
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'tags:delete': db.User.RANK_REGULAR}})
|
config_injector({'privileges': {'tags:delete': model.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_deleting(user_factory, tag_factory, context_factory):
|
def test_deleting(user_factory, tag_factory, context_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
tag = tag_factory(names=['tag'])
|
tag = tag_factory(names=['tag'])
|
||||||
db.session.add(tag)
|
db.session.add(tag)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
@ -20,7 +20,7 @@ def test_deleting(user_factory, tag_factory, context_factory):
|
||||||
context_factory(params={'version': 1}, user=auth_user),
|
context_factory(params={'version': 1}, user=auth_user),
|
||||||
{'tag_name': 'tag'})
|
{'tag_name': 'tag'})
|
||||||
assert result == {}
|
assert result == {}
|
||||||
assert db.session.query(db.Tag).count() == 0
|
assert db.session.query(model.Tag).count() == 0
|
||||||
snapshots.delete.assert_called_once_with(tag, auth_user)
|
snapshots.delete.assert_called_once_with(tag, auth_user)
|
||||||
tags.export_to_json.assert_called_once_with()
|
tags.export_to_json.assert_called_once_with()
|
||||||
|
|
||||||
|
@ -36,17 +36,17 @@ def test_deleting_used(
|
||||||
api.tag_api.delete_tag(
|
api.tag_api.delete_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'version': 1},
|
params={'version': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'tag_name': 'tag'})
|
{'tag_name': 'tag'})
|
||||||
db.session.refresh(post)
|
db.session.refresh(post)
|
||||||
assert db.session.query(db.Tag).count() == 0
|
assert db.session.query(model.Tag).count() == 0
|
||||||
assert post.tags == []
|
assert post.tags == []
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(tags.TagNotFoundError):
|
with pytest.raises(tags.TagNotFoundError):
|
||||||
api.tag_api.delete_tag(
|
api.tag_api.delete_tag(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'tag_name': 'bad'})
|
{'tag_name': 'bad'})
|
||||||
|
|
||||||
|
|
||||||
|
@ -58,6 +58,6 @@ def test_trying_to_delete_without_privileges(
|
||||||
api.tag_api.delete_tag(
|
api.tag_api.delete_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'version': 1},
|
params={'version': 1},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)),
|
||||||
{'tag_name': 'tag'})
|
{'tag_name': 'tag'})
|
||||||
assert db.session.query(db.Tag).count() == 1
|
assert db.session.query(model.Tag).count() == 1
|
||||||
|
|
|
@ -1,16 +1,16 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import tags, snapshots
|
from szurubooru.func import tags, snapshots
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'tags:merge': db.User.RANK_REGULAR}})
|
config_injector({'privileges': {'tags:merge': model.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_merging(user_factory, tag_factory, context_factory, post_factory):
|
def test_merging(user_factory, tag_factory, context_factory, post_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
source_tag = tag_factory(names=['source'])
|
source_tag = tag_factory(names=['source'])
|
||||||
target_tag = tag_factory(names=['target'])
|
target_tag = tag_factory(names=['target'])
|
||||||
db.session.add_all([source_tag, target_tag])
|
db.session.add_all([source_tag, target_tag])
|
||||||
|
@ -62,7 +62,7 @@ def test_trying_to_omit_mandatory_field(
|
||||||
api.tag_api.merge_tags(
|
api.tag_api.merge_tags(
|
||||||
context_factory(
|
context_factory(
|
||||||
params=params,
|
params=params,
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_merge_non_existing(
|
def test_trying_to_merge_non_existing(
|
||||||
|
@ -73,12 +73,12 @@ def test_trying_to_merge_non_existing(
|
||||||
api.tag_api.merge_tags(
|
api.tag_api.merge_tags(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'remove': 'good', 'mergeTo': 'bad'},
|
params={'remove': 'good', 'mergeTo': 'bad'},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
with pytest.raises(tags.TagNotFoundError):
|
with pytest.raises(tags.TagNotFoundError):
|
||||||
api.tag_api.merge_tags(
|
api.tag_api.merge_tags(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'remove': 'bad', 'mergeTo': 'good'},
|
params={'remove': 'bad', 'mergeTo': 'good'},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_merge_without_privileges(
|
def test_trying_to_merge_without_privileges(
|
||||||
|
@ -97,4 +97,4 @@ def test_trying_to_merge_without_privileges(
|
||||||
'remove': 'source',
|
'remove': 'source',
|
||||||
'mergeTo': 'target',
|
'mergeTo': 'target',
|
||||||
},
|
},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import tags
|
from szurubooru.func import tags
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,8 +8,8 @@ from szurubooru.func import tags
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'tags:list': db.User.RANK_REGULAR,
|
'tags:list': model.User.RANK_REGULAR,
|
||||||
'tags:view': db.User.RANK_REGULAR,
|
'tags:view': model.User.RANK_REGULAR,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ def test_retrieving_multiple(user_factory, tag_factory, context_factory):
|
||||||
result = api.tag_api.get_tags(
|
result = api.tag_api.get_tags(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'query': '', 'page': 1},
|
params={'query': '', 'page': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
assert result == {
|
assert result == {
|
||||||
'query': '',
|
'query': '',
|
||||||
'page': 1,
|
'page': 1,
|
||||||
|
@ -40,7 +40,7 @@ def test_trying_to_retrieve_multiple_without_privileges(
|
||||||
api.tag_api.get_tags(
|
api.tag_api.get_tags(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'query': '', 'page': 1},
|
params={'query': '', 'page': 1},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
|
||||||
|
|
||||||
|
|
||||||
def test_retrieving_single(user_factory, tag_factory, context_factory):
|
def test_retrieving_single(user_factory, tag_factory, context_factory):
|
||||||
|
@ -50,7 +50,7 @@ def test_retrieving_single(user_factory, tag_factory, context_factory):
|
||||||
tags.serialize_tag.return_value = 'serialized tag'
|
tags.serialize_tag.return_value = 'serialized tag'
|
||||||
result = api.tag_api.get_tag(
|
result = api.tag_api.get_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'tag_name': 'tag'})
|
{'tag_name': 'tag'})
|
||||||
assert result == 'serialized tag'
|
assert result == 'serialized tag'
|
||||||
|
|
||||||
|
@ -59,7 +59,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(tags.TagNotFoundError):
|
with pytest.raises(tags.TagNotFoundError):
|
||||||
api.tag_api.get_tag(
|
api.tag_api.get_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'tag_name': '-'})
|
{'tag_name': '-'})
|
||||||
|
|
||||||
|
|
||||||
|
@ -68,5 +68,5 @@ def test_trying_to_retrieve_single_without_privileges(
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.tag_api.get_tag(
|
api.tag_api.get_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)),
|
||||||
{'tag_name': '-'})
|
{'tag_name': '-'})
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import tags
|
from szurubooru.func import tags
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({'privileges': {'tags:view': db.User.RANK_REGULAR}})
|
config_injector({'privileges': {'tags:view': model.User.RANK_REGULAR}})
|
||||||
|
|
||||||
|
|
||||||
def test_get_tag_siblings(user_factory, tag_factory, context_factory):
|
def test_get_tag_siblings(user_factory, tag_factory, context_factory):
|
||||||
|
@ -21,7 +21,7 @@ def test_get_tag_siblings(user_factory, tag_factory, context_factory):
|
||||||
(tag_factory(names=['sib2']), 3),
|
(tag_factory(names=['sib2']), 3),
|
||||||
]
|
]
|
||||||
result = api.tag_api.get_tag_siblings(
|
result = api.tag_api.get_tag_siblings(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'tag_name': 'tag'})
|
{'tag_name': 'tag'})
|
||||||
assert result == {
|
assert result == {
|
||||||
'results': [
|
'results': [
|
||||||
|
@ -40,12 +40,12 @@ def test_get_tag_siblings(user_factory, tag_factory, context_factory):
|
||||||
def test_trying_to_retrieve_non_existing(user_factory, context_factory):
|
def test_trying_to_retrieve_non_existing(user_factory, context_factory):
|
||||||
with pytest.raises(tags.TagNotFoundError):
|
with pytest.raises(tags.TagNotFoundError):
|
||||||
api.tag_api.get_tag_siblings(
|
api.tag_api.get_tag_siblings(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_REGULAR)),
|
context_factory(user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'tag_name': '-'})
|
{'tag_name': '-'})
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_retrieve_without_privileges(user_factory, context_factory):
|
def test_trying_to_retrieve_without_privileges(user_factory, context_factory):
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.tag_api.get_tag_siblings(
|
api.tag_api.get_tag_siblings(
|
||||||
context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)),
|
||||||
{'tag_name': '-'})
|
{'tag_name': '-'})
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import tags, snapshots
|
from szurubooru.func import tags, snapshots
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,18 +8,18 @@ from szurubooru.func import tags, snapshots
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'tags:create': db.User.RANK_REGULAR,
|
'tags:create': model.User.RANK_REGULAR,
|
||||||
'tags:edit:names': db.User.RANK_REGULAR,
|
'tags:edit:names': model.User.RANK_REGULAR,
|
||||||
'tags:edit:category': db.User.RANK_REGULAR,
|
'tags:edit:category': model.User.RANK_REGULAR,
|
||||||
'tags:edit:description': db.User.RANK_REGULAR,
|
'tags:edit:description': model.User.RANK_REGULAR,
|
||||||
'tags:edit:suggestions': db.User.RANK_REGULAR,
|
'tags:edit:suggestions': model.User.RANK_REGULAR,
|
||||||
'tags:edit:implications': db.User.RANK_REGULAR,
|
'tags:edit:implications': model.User.RANK_REGULAR,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_simple_updating(user_factory, tag_factory, context_factory):
|
def test_simple_updating(user_factory, tag_factory, context_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
tag = tag_factory(names=['tag1', 'tag2'])
|
tag = tag_factory(names=['tag1', 'tag2'])
|
||||||
db.session.add(tag)
|
db.session.add(tag)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
@ -56,8 +56,7 @@ def test_simple_updating(user_factory, tag_factory, context_factory):
|
||||||
tag, ['sug1', 'sug2'])
|
tag, ['sug1', 'sug2'])
|
||||||
tags.update_tag_implications.assert_called_once_with(
|
tags.update_tag_implications.assert_called_once_with(
|
||||||
tag, ['imp1', 'imp2'])
|
tag, ['imp1', 'imp2'])
|
||||||
tags.serialize_tag.assert_called_once_with(
|
tags.serialize_tag.assert_called_once_with(tag, options=[])
|
||||||
tag, options=None)
|
|
||||||
snapshots.modify.assert_called_once_with(tag, auth_user)
|
snapshots.modify.assert_called_once_with(tag, auth_user)
|
||||||
tags.export_to_json.assert_called_once_with()
|
tags.export_to_json.assert_called_once_with()
|
||||||
|
|
||||||
|
@ -90,7 +89,7 @@ def test_omitting_optional_field(
|
||||||
api.tag_api.update_tag(
|
api.tag_api.update_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={**params, **{'version': 1}},
|
params={**params, **{'version': 1}},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'tag_name': 'tag'})
|
{'tag_name': 'tag'})
|
||||||
|
|
||||||
|
|
||||||
|
@ -99,7 +98,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
|
||||||
api.tag_api.update_tag(
|
api.tag_api.update_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'names': ['dummy']},
|
params={'names': ['dummy']},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'tag_name': 'tag1'})
|
{'tag_name': 'tag1'})
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,7 +116,7 @@ def test_trying_to_update_without_privileges(
|
||||||
api.tag_api.update_tag(
|
api.tag_api.update_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={**params, **{'version': 1}},
|
params={**params, **{'version': 1}},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)),
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)),
|
||||||
{'tag_name': 'tag'})
|
{'tag_name': 'tag'})
|
||||||
|
|
||||||
|
|
||||||
|
@ -127,9 +126,9 @@ def test_trying_to_create_tags_without_privileges(
|
||||||
db.session.add(tag)
|
db.session.add(tag)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
config_injector({'privileges': {
|
config_injector({'privileges': {
|
||||||
'tags:create': db.User.RANK_ADMINISTRATOR,
|
'tags:create': model.User.RANK_ADMINISTRATOR,
|
||||||
'tags:edit:suggestions': db.User.RANK_REGULAR,
|
'tags:edit:suggestions': model.User.RANK_REGULAR,
|
||||||
'tags:edit:implications': db.User.RANK_REGULAR,
|
'tags:edit:implications': model.User.RANK_REGULAR,
|
||||||
}})
|
}})
|
||||||
with patch('szurubooru.func.tags.get_or_create_tags_by_names'):
|
with patch('szurubooru.func.tags.get_or_create_tags_by_names'):
|
||||||
tags.get_or_create_tags_by_names.return_value = ([], ['new-tag'])
|
tags.get_or_create_tags_by_names.return_value = ([], ['new-tag'])
|
||||||
|
@ -137,12 +136,12 @@ def test_trying_to_create_tags_without_privileges(
|
||||||
api.tag_api.update_tag(
|
api.tag_api.update_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'suggestions': ['tag1', 'tag2'], 'version': 1},
|
params={'suggestions': ['tag1', 'tag2'], 'version': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'tag_name': 'tag'})
|
{'tag_name': 'tag'})
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.tag_api.update_tag(
|
api.tag_api.update_tag(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'implications': ['tag1', 'tag2'], 'version': 1},
|
params={'implications': ['tag1', 'tag2'], 'version': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'tag_name': 'tag'})
|
{'tag_name': 'tag'})
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import users
|
from szurubooru.func import users
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,7 +31,7 @@ def test_creating_user(user_factory, context_factory, fake_datetime):
|
||||||
'avatarStyle': 'manual',
|
'avatarStyle': 'manual',
|
||||||
},
|
},
|
||||||
files={'avatar': b'...'},
|
files={'avatar': b'...'},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
assert result == 'serialized user'
|
assert result == 'serialized user'
|
||||||
users.create_user.assert_called_once_with(
|
users.create_user.assert_called_once_with(
|
||||||
'chewie1', 'oks', 'asd@asd.asd')
|
'chewie1', 'oks', 'asd@asd.asd')
|
||||||
|
@ -50,7 +50,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field):
|
||||||
'password': 'oks',
|
'password': 'oks',
|
||||||
}
|
}
|
||||||
user = user_factory()
|
user = user_factory()
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
del params[field]
|
del params[field]
|
||||||
with patch('szurubooru.func.users.create_user'), \
|
with patch('szurubooru.func.users.create_user'), \
|
||||||
pytest.raises(errors.MissingRequiredParameterError):
|
pytest.raises(errors.MissingRequiredParameterError):
|
||||||
|
@ -70,7 +70,7 @@ def test_omitting_optional_field(user_factory, context_factory, field):
|
||||||
}
|
}
|
||||||
del params[field]
|
del params[field]
|
||||||
user = user_factory()
|
user = user_factory()
|
||||||
auth_user = user_factory(rank=db.User.RANK_MODERATOR)
|
auth_user = user_factory(rank=model.User.RANK_MODERATOR)
|
||||||
with patch('szurubooru.func.users.create_user'), \
|
with patch('szurubooru.func.users.create_user'), \
|
||||||
patch('szurubooru.func.users.update_user_avatar'), \
|
patch('szurubooru.func.users.update_user_avatar'), \
|
||||||
patch('szurubooru.func.users.serialize_user'):
|
patch('szurubooru.func.users.serialize_user'):
|
||||||
|
@ -84,4 +84,4 @@ def test_trying_to_create_user_without_privileges(
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.user_api.create_user(context_factory(
|
api.user_api.create_user(context_factory(
|
||||||
params='whatever',
|
params='whatever',
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import users
|
from szurubooru.func import users
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,45 +7,45 @@ from szurubooru.func import users
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'users:delete:self': db.User.RANK_REGULAR,
|
'users:delete:self': model.User.RANK_REGULAR,
|
||||||
'users:delete:any': db.User.RANK_MODERATOR,
|
'users:delete:any': model.User.RANK_MODERATOR,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_deleting_oneself(user_factory, context_factory):
|
def test_deleting_oneself(user_factory, context_factory):
|
||||||
user = user_factory(name='u', rank=db.User.RANK_REGULAR)
|
user = user_factory(name='u', rank=model.User.RANK_REGULAR)
|
||||||
db.session.add(user)
|
db.session.add(user)
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
result = api.user_api.delete_user(
|
result = api.user_api.delete_user(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'version': 1}, user=user), {'user_name': 'u'})
|
params={'version': 1}, user=user), {'user_name': 'u'})
|
||||||
assert result == {}
|
assert result == {}
|
||||||
assert db.session.query(db.User).count() == 0
|
assert db.session.query(model.User).count() == 0
|
||||||
|
|
||||||
|
|
||||||
def test_deleting_someone_else(user_factory, context_factory):
|
def test_deleting_someone_else(user_factory, context_factory):
|
||||||
user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR)
|
user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR)
|
||||||
user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR)
|
user2 = user_factory(name='u2', rank=model.User.RANK_MODERATOR)
|
||||||
db.session.add_all([user1, user2])
|
db.session.add_all([user1, user2])
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
api.user_api.delete_user(
|
api.user_api.delete_user(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'version': 1}, user=user2), {'user_name': 'u1'})
|
params={'version': 1}, user=user2), {'user_name': 'u1'})
|
||||||
assert db.session.query(db.User).count() == 1
|
assert db.session.query(model.User).count() == 1
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_delete_someone_else_without_privileges(
|
def test_trying_to_delete_someone_else_without_privileges(
|
||||||
user_factory, context_factory):
|
user_factory, context_factory):
|
||||||
user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR)
|
user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR)
|
||||||
user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR)
|
user2 = user_factory(name='u2', rank=model.User.RANK_REGULAR)
|
||||||
db.session.add_all([user1, user2])
|
db.session.add_all([user1, user2])
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.user_api.delete_user(
|
api.user_api.delete_user(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'version': 1}, user=user2), {'user_name': 'u1'})
|
params={'version': 1}, user=user2), {'user_name': 'u1'})
|
||||||
assert db.session.query(db.User).count() == 2
|
assert db.session.query(model.User).count() == 2
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
def test_trying_to_delete_non_existing(user_factory, context_factory):
|
||||||
|
@ -53,5 +53,5 @@ def test_trying_to_delete_non_existing(user_factory, context_factory):
|
||||||
api.user_api.delete_user(
|
api.user_api.delete_user(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'version': 1},
|
params={'version': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)),
|
user=user_factory(rank=model.User.RANK_REGULAR)),
|
||||||
{'user_name': 'bad'})
|
{'user_name': 'bad'})
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import users
|
from szurubooru.func import users
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,16 +8,16 @@ from szurubooru.func import users
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'users:list': db.User.RANK_REGULAR,
|
'users:list': model.User.RANK_REGULAR,
|
||||||
'users:view': db.User.RANK_REGULAR,
|
'users:view': model.User.RANK_REGULAR,
|
||||||
'users:edit:any:email': db.User.RANK_MODERATOR,
|
'users:edit:any:email': model.User.RANK_MODERATOR,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_retrieving_multiple(user_factory, context_factory):
|
def test_retrieving_multiple(user_factory, context_factory):
|
||||||
user1 = user_factory(name='u1', rank=db.User.RANK_MODERATOR)
|
user1 = user_factory(name='u1', rank=model.User.RANK_MODERATOR)
|
||||||
user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR)
|
user2 = user_factory(name='u2', rank=model.User.RANK_MODERATOR)
|
||||||
db.session.add_all([user1, user2])
|
db.session.add_all([user1, user2])
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
with patch('szurubooru.func.users.serialize_user'):
|
with patch('szurubooru.func.users.serialize_user'):
|
||||||
|
@ -25,7 +25,7 @@ def test_retrieving_multiple(user_factory, context_factory):
|
||||||
result = api.user_api.get_users(
|
result = api.user_api.get_users(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'query': '', 'page': 1},
|
params={'query': '', 'page': 1},
|
||||||
user=user_factory(rank=db.User.RANK_REGULAR)))
|
user=user_factory(rank=model.User.RANK_REGULAR)))
|
||||||
assert result == {
|
assert result == {
|
||||||
'query': '',
|
'query': '',
|
||||||
'page': 1,
|
'page': 1,
|
||||||
|
@ -41,12 +41,12 @@ def test_trying_to_retrieve_multiple_without_privileges(
|
||||||
api.user_api.get_users(
|
api.user_api.get_users(
|
||||||
context_factory(
|
context_factory(
|
||||||
params={'query': '', 'page': 1},
|
params={'query': '', 'page': 1},
|
||||||
user=user_factory(rank=db.User.RANK_ANONYMOUS)))
|
user=user_factory(rank=model.User.RANK_ANONYMOUS)))
|
||||||
|
|
||||||
|
|
||||||
def test_retrieving_single(user_factory, context_factory):
|
def test_retrieving_single(user_factory, context_factory):
|
||||||
user = user_factory(name='u1', rank=db.User.RANK_REGULAR)
|
user = user_factory(name='u1', rank=model.User.RANK_REGULAR)
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
db.session.add(user)
|
db.session.add(user)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
with patch('szurubooru.func.users.serialize_user'):
|
with patch('szurubooru.func.users.serialize_user'):
|
||||||
|
@ -57,7 +57,7 @@ def test_retrieving_single(user_factory, context_factory):
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_REGULAR)
|
auth_user = user_factory(rank=model.User.RANK_REGULAR)
|
||||||
with pytest.raises(users.UserNotFoundError):
|
with pytest.raises(users.UserNotFoundError):
|
||||||
api.user_api.get_user(
|
api.user_api.get_user(
|
||||||
context_factory(user=auth_user), {'user_name': '-'})
|
context_factory(user=auth_user), {'user_name': '-'})
|
||||||
|
@ -65,8 +65,8 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory):
|
||||||
|
|
||||||
def test_trying_to_retrieve_single_without_privileges(
|
def test_trying_to_retrieve_single_without_privileges(
|
||||||
user_factory, context_factory):
|
user_factory, context_factory):
|
||||||
auth_user = user_factory(rank=db.User.RANK_ANONYMOUS)
|
auth_user = user_factory(rank=model.User.RANK_ANONYMOUS)
|
||||||
db.session.add(user_factory(name='u1', rank=db.User.RANK_REGULAR))
|
db.session.add(user_factory(name='u1', rank=model.User.RANK_REGULAR))
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
api.user_api.get_user(
|
api.user_api.get_user(
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from szurubooru import api, db, errors
|
from szurubooru import api, db, model, errors
|
||||||
from szurubooru.func import users
|
from szurubooru.func import users
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,23 +8,23 @@ from szurubooru.func import users
|
||||||
def inject_config(config_injector):
|
def inject_config(config_injector):
|
||||||
config_injector({
|
config_injector({
|
||||||
'privileges': {
|
'privileges': {
|
||||||
'users:edit:self:name': db.User.RANK_REGULAR,
|
'users:edit:self:name': model.User.RANK_REGULAR,
|
||||||
'users:edit:self:pass': db.User.RANK_REGULAR,
|
'users:edit:self:pass': model.User.RANK_REGULAR,
|
||||||
'users:edit:self:email': db.User.RANK_REGULAR,
|
'users:edit:self:email': model.User.RANK_REGULAR,
|
||||||
'users:edit:self:rank': db.User.RANK_MODERATOR,
|
'users:edit:self:rank': model.User.RANK_MODERATOR,
|
||||||
'users:edit:self:avatar': db.User.RANK_MODERATOR,
|
'users:edit:self:avatar': model.User.RANK_MODERATOR,
|
||||||
'users:edit:any:name': db.User.RANK_MODERATOR,
|
'users:edit:any:name': model.User.RANK_MODERATOR,
|
||||||
'users:edit:any:pass': db.User.RANK_MODERATOR,
|
'users:edit:any:pass': model.User.RANK_MODERATOR,
|
||||||
'users:edit:any:email': db.User.RANK_MODERATOR,
|
'users:edit:any:email': model.User.RANK_MODERATOR,
|
||||||
'users:edit:any:rank': db.User.RANK_ADMINISTRATOR,
|
'users:edit:any:rank': model.User.RANK_ADMINISTRATOR,
|
||||||
'users:edit:any:avatar': db.User.RANK_ADMINISTRATOR,
|
'users:edit:any:avatar': model.User.RANK_ADMINISTRATOR,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
def test_updating_user(context_factory, user_factory):
|
def test_updating_user(context_factory, user_factory):
|
||||||
user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
|
user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR)
|
||||||
auth_user = user_factory(rank=db.User.RANK_ADMINISTRATOR)
|
auth_user = user_factory(rank=model.User.RANK_ADMINISTRATOR)
|
||||||
db.session.add(user)
|
db.session.add(user)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
|
|
||||||
|
@ -63,13 +63,13 @@ def test_updating_user(context_factory, user_factory):
|
||||||
users.update_user_avatar.assert_called_once_with(
|
users.update_user_avatar.assert_called_once_with(
|
||||||
user, 'manual', b'...')
|
user, 'manual', b'...')
|
||||||
users.serialize_user.assert_called_once_with(
|
users.serialize_user.assert_called_once_with(
|
||||||
user, auth_user, options=None)
|
user, auth_user, options=[])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
'field', ['name', 'email', 'password', 'rank', 'avatarStyle'])
|
'field', ['name', 'email', 'password', 'rank', 'avatarStyle'])
|
||||||
def test_omitting_optional_field(user_factory, context_factory, field):
|
def test_omitting_optional_field(user_factory, context_factory, field):
|
||||||
user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
|
user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR)
|
||||||
db.session.add(user)
|
db.session.add(user)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
params = {
|
params = {
|
||||||
|
@ -96,7 +96,7 @@ def test_omitting_optional_field(user_factory, context_factory, field):
|
||||||
|
|
||||||
|
|
||||||
def test_trying_to_update_non_existing(user_factory, context_factory):
|
def test_trying_to_update_non_existing(user_factory, context_factory):
|
||||||
user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR)
|
user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR)
|
||||||
db.session.add(user)
|
db.session.add(user)
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
with pytest.raises(users.UserNotFoundError):
|
with pytest.raises(users.UserNotFoundError):
|
||||||
|
@ -113,8 +113,8 @@ def test_trying_to_update_non_existing(user_factory, context_factory):
|
||||||
])
|
])
|
||||||
def test_trying_to_update_field_without_privileges(
|
def test_trying_to_update_field_without_privileges(
|
||||||
user_factory, context_factory, params):
|
user_factory, context_factory, params):
|
||||||
user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR)
|
user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR)
|
||||||
user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR)
|
user2 = user_factory(name='u2', rank=model.User.RANK_REGULAR)
|
||||||
db.session.add_all([user1, user2])
|
db.session.add_all([user1, user2])
|
||||||
db.session.flush()
|
db.session.flush()
|
||||||
with pytest.raises(errors.AuthError):
|
with pytest.raises(errors.AuthError):
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Reference in a new issue