Add more ban API endpoints

This commit is contained in:
rebel 2023-05-17 01:28:44 +02:00
parent a7f6547f18
commit 68c2f8f676
11 changed files with 222 additions and 36 deletions

View file

@ -116,7 +116,9 @@ privileges:
'posts:bulk-edit:tags': power
'posts:bulk-edit:safety': power
'posts:bulk-edit:delete': power
'posts:ban': moderator
'posts:ban:create': moderator
'posts:ban:delete': moderator
'posts:ban:list': moderator
'tags:create': regular
'tags:edit:names': power

View file

@ -10,3 +10,4 @@ import szurubooru.api.tag_category_api
import szurubooru.api.upload_api
import szurubooru.api.user_api
import szurubooru.api.user_token_api
import szurubooru.api.ban_api

View file

@ -0,0 +1,58 @@
from datetime import datetime
from typing import Dict, List, Optional
from server.szurubooru.api import post_api
from server.szurubooru.func import posts
from server.szurubooru.model.bans import PostBan
from szurubooru import db, errors, model, rest, search
from szurubooru.func import (
auth,
bans,
serialization,
snapshots,
versions,
)
def _get_ban_by_hash(hash: str) -> Optional[PostBan]:
try:
return bans.get_bans_by_hash(hash)
except:
return None
_search_executor = search.Executor(search.configs.BanSearchConfig())
def _serialize(ctx: rest.Context, ban: model.PostBan) -> rest.Response:
return bans.serialize_ban(
ban, options=serialization.get_serialization_options(ctx)
)
@rest.routes.post("/post-ban/(?P<post_id>[^/]+)/?")
def ban_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, "posts:ban:create")
post = post_api._get_post(params)
versions.verify_version(post, ctx)
posts.ban(bans.create_ban(post))
snapshots.delete(post, ctx.user)
posts.delete(post)
ctx.session.commit()
return {}
@rest.routes.delete("/post-ban/(?P<image_hash>[^/]+)/?")
def unban_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, "posts:ban:delete")
ban = _get_ban_by_hash(params["image_hash"])
bans.delete(ban)
ctx.session.commit()
return {}
@rest.routes.get("/post-ban/?")
def get_bans(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, "posts:ban:list")
return _search_executor.execute_and_serialize(
ctx, lambda tag: _serialize(ctx, tag)
)

View file

@ -183,18 +183,6 @@ def delete_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
return {}
@rest.routes.post("/post-ban/(?P<post_id>[^/]+)/?")
def ban_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
auth.verify_privilege(ctx.user, "posts:ban")
post = _get_post(params)
versions.verify_version(post, ctx)
posts.ban(posts.create_ban(post))
snapshots.delete(post, ctx.user)
posts.delete(post)
ctx.session.commit()
return {}
@rest.routes.post("/post-merge/?")
def merge_posts(
ctx: rest.Context, _params: Dict[str, str] = {}

View file

@ -0,0 +1,68 @@
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple
from szurubooru import db, errors, model, rest
from szurubooru.func import (
serialization,
)
class PostBannedError(errors.ValidationError):
def __init__(self, message: str = "This file was banned", extra_fields: Dict[str, str] = None) -> None:
super().__init__(message, extra_fields)
class HashNotBannedError(errors.ValidationError):
pass
class TagSerializer(serialization.BaseSerializer):
def __init__(self, ban: model.PostBan) -> None:
self.ban = ban
def _serializers(self) -> Dict[str, Callable[[], Any]]:
return {
"checksum": self.serialize_checksum,
"time": self.serialize_time
}
def serialize_checksum(self) -> Any:
return self.ban.checksum
def serialize_time(self) -> Any:
return self.ban.time
def create_ban(post: model.Post) -> model.PostBan:
ban = model.PostBan()
ban.checksum = post.checksum
ban.time = datetime.utcnow()
db.session.add(ban)
return ban
def try_get_ban_by_checksum(checksum: str) -> Optional[model.PostBan]:
return (
db.session.query(model.PostBan)
.filter(model.PostBan.checksum == checksum)
.one_or_none()
)
def get_bans_by_hash(hash: str) -> model.PostBan:
ban = try_get_ban_by_checksum(hash)
if ban is None:
raise HashNotBannedError("Hash %s is not banned" % hash)
return ban
def delete(ban: model.PostBan) -> None:
db.session.delete(ban)
def serialize_ban(
ban: model.PostBan, options: List[str] = []
) -> Optional[rest.Response]:
if not ban:
return None
return serialization.BaseSerializer(ban).serialize(options)

View file

@ -4,6 +4,7 @@ from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple
import sqlalchemy as sa
from server.szurubooru.func.bans import PostBannedError
from szurubooru import config, db, errors, model, rest
from szurubooru.func import (
@ -50,11 +51,6 @@ class PostAlreadyUploadedError(errors.ValidationError):
)
class PostBannedError(errors.ValidationError):
def __init__(self, message: str = "This file was banned", extra_fields: Dict[str, str] = None) -> None:
super().__init__(message, extra_fields)
class InvalidPostIdError(errors.ValidationError):
pass
@ -431,15 +427,6 @@ def create_post(
return post, new_tags
def create_ban(post: model.Post) -> model.PostBan:
ban = model.PostBan()
ban.checksum = post.checksum
ban.time = datetime.utcnow()
db.session.add(ban)
return ban
def update_post_safety(post: model.Post, safety: str) -> None:
assert post
safety = util.flip(SAFETY_MAP).get(safety, None)

View file

@ -1,11 +1,11 @@
import szurubooru.model.util
from szurubooru.model.base import Base
from szurubooru.model.bans import PostBan
from szurubooru.model.comment import Comment, CommentScore
from szurubooru.model.pool import Pool, PoolName, PoolPost
from szurubooru.model.pool_category import PoolCategory
from szurubooru.model.post import (
Post,
PostBan,
PostFavorite,
PostFeature,
PostNote,

View file

@ -0,0 +1,12 @@
from typing import List
import sqlalchemy as sa
from szurubooru.model.base import Base
class PostBan(Base):
__tablename__ = "post_ban"
ban_id = sa.Column("id", sa.Integer, primary_key=True)
checksum = sa.Column("checksum", sa.Unicode(64), nullable=False)
time = sa.Column("time", sa.DateTime, nullable=False)

View file

@ -94,14 +94,6 @@ class PostFavorite(Base):
)
class PostBan(Base):
__tablename__ = "post_ban"
ban_id = sa.Column("id", sa.Integer, primary_key=True)
checksum = sa.Column("checksum", sa.Unicode(64), nullable=False)
time = sa.Column("time", sa.DateTime, nullable=False)
class PostNote(Base):
__tablename__ = "post_note"

View file

@ -4,3 +4,4 @@ from .post_search_config import PostSearchConfig
from .snapshot_search_config import SnapshotSearchConfig
from .tag_search_config import TagSearchConfig
from .user_search_config import UserSearchConfig
from .ban_search_config import BanSearchConfig

View file

@ -0,0 +1,77 @@
from typing import Dict, Tuple
import sqlalchemy as sa
from szurubooru import db, model
from szurubooru.func import util
from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import (
BaseSearchConfig,
Filter,
)
from szurubooru.search.typing import SaColumn, SaQuery
class BanSearchConfig(BaseSearchConfig):
def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery:
strategy = (
sa.orm.lazyload if _disable_eager_loads else sa.orm.subqueryload
)
return (
db.session.query(model.PostBan)
.options(
sa.orm.defer(model.PostBan.checksum),
sa.orm.defer(model.PostBan.time),
sa.orm.defer(model.PostBan.ban_id)
)
)
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(model.PostBan)
def create_around_query(self) -> SaQuery:
raise NotImplementedError()
def finalize_query(self, query: SaQuery) -> SaQuery:
return query.order_by(model.PostBan.time.asc())
@property
def anonymous_filter(self) -> Filter:
return search_util.create_subquery_filter(
model.PostBan.checksum,
model.PostBan.time,
search_util.create_str_filter,
)
@property
def named_filters(self) -> Dict[str, Filter]:
return util.unalias_dict(
[
(
["time"],
search_util.create_date_filter(
model.PostBan.time,
),
),
(
["checksum"],
search_util.create_subquery_filter(
model.PostBan.checksum,
search_util.create_str_filter,
),
)
]
)
@property
def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]:
return util.unalias_dict(
[
(
["random"],
(sa.sql.expression.func.random(), self.SORT_NONE),
),
(["checksum"], (model.PostBan.checksum, self.SORT_ASC)),
(["time"], (model.PostBan.time, self.SORT_ASC))
]
)