diff --git a/server/config.yaml.dist b/server/config.yaml.dist index 2d4ff0c8..63bc99fa 100644 --- a/server/config.yaml.dist +++ b/server/config.yaml.dist @@ -116,7 +116,9 @@ privileges: 'posts:bulk-edit:tags': power 'posts:bulk-edit:safety': power 'posts:bulk-edit:delete': power - 'posts:ban': moderator + 'posts:ban:create': moderator + 'posts:ban:delete': moderator + 'posts:ban:list': moderator 'tags:create': regular 'tags:edit:names': power diff --git a/server/szurubooru/api/__init__.py b/server/szurubooru/api/__init__.py index d9b7ecba..b1ba6b3b 100644 --- a/server/szurubooru/api/__init__.py +++ b/server/szurubooru/api/__init__.py @@ -10,3 +10,4 @@ import szurubooru.api.tag_category_api import szurubooru.api.upload_api import szurubooru.api.user_api import szurubooru.api.user_token_api +import szurubooru.api.ban_api \ No newline at end of file diff --git a/server/szurubooru/api/ban_api.py b/server/szurubooru/api/ban_api.py new file mode 100644 index 00000000..b6238815 --- /dev/null +++ b/server/szurubooru/api/ban_api.py @@ -0,0 +1,58 @@ +from datetime import datetime +from typing import Dict, List, Optional +from server.szurubooru.api import post_api +from server.szurubooru.func import posts +from server.szurubooru.model.bans import PostBan + +from szurubooru import db, errors, model, rest, search +from szurubooru.func import ( + auth, + bans, + serialization, + snapshots, + versions, +) + +def _get_ban_by_hash(hash: str) -> Optional[PostBan]: + try: + return bans.get_bans_by_hash(hash) + except: + return None + + +_search_executor = search.Executor(search.configs.BanSearchConfig()) + + +def _serialize(ctx: rest.Context, ban: model.PostBan) -> rest.Response: + return bans.serialize_ban( + ban, options=serialization.get_serialization_options(ctx) + ) + + +@rest.routes.post("/post-ban/(?P[^/]+)/?") +def ban_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: + auth.verify_privilege(ctx.user, "posts:ban:create") + post = post_api._get_post(params) + versions.verify_version(post, ctx) + posts.ban(bans.create_ban(post)) + snapshots.delete(post, ctx.user) + posts.delete(post) + ctx.session.commit() + return {} + + +@rest.routes.delete("/post-ban/(?P[^/]+)/?") +def unban_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: + auth.verify_privilege(ctx.user, "posts:ban:delete") + ban = _get_ban_by_hash(params["image_hash"]) + bans.delete(ban) + ctx.session.commit() + return {} + + +@rest.routes.get("/post-ban/?") +def get_bans(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response: + auth.verify_privilege(ctx.user, "posts:ban:list") + return _search_executor.execute_and_serialize( + ctx, lambda tag: _serialize(ctx, tag) + ) diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index 6b20eee3..daba7f7e 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -183,18 +183,6 @@ def delete_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: return {} -@rest.routes.post("/post-ban/(?P[^/]+)/?") -def ban_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: - auth.verify_privilege(ctx.user, "posts:ban") - post = _get_post(params) - versions.verify_version(post, ctx) - posts.ban(posts.create_ban(post)) - snapshots.delete(post, ctx.user) - posts.delete(post) - ctx.session.commit() - return {} - - @rest.routes.post("/post-merge/?") def merge_posts( ctx: rest.Context, _params: Dict[str, str] = {} diff --git a/server/szurubooru/func/bans.py b/server/szurubooru/func/bans.py new file mode 100644 index 00000000..35946cc4 --- /dev/null +++ b/server/szurubooru/func/bans.py @@ -0,0 +1,68 @@ +from datetime import datetime +from typing import Any, Callable, Dict, List, Optional, Tuple + +from szurubooru import db, errors, model, rest +from szurubooru.func import ( + serialization, +) + +class PostBannedError(errors.ValidationError): + def __init__(self, message: str = "This file was banned", extra_fields: Dict[str, str] = None) -> None: + super().__init__(message, extra_fields) + + +class HashNotBannedError(errors.ValidationError): + pass + + +class TagSerializer(serialization.BaseSerializer): + def __init__(self, ban: model.PostBan) -> None: + self.ban = ban + + def _serializers(self) -> Dict[str, Callable[[], Any]]: + return { + "checksum": self.serialize_checksum, + "time": self.serialize_time + } + + def serialize_checksum(self) -> Any: + return self.ban.checksum + + def serialize_time(self) -> Any: + return self.ban.time + + +def create_ban(post: model.Post) -> model.PostBan: + ban = model.PostBan() + ban.checksum = post.checksum + ban.time = datetime.utcnow() + + db.session.add(ban) + return ban + + +def try_get_ban_by_checksum(checksum: str) -> Optional[model.PostBan]: + return ( + db.session.query(model.PostBan) + .filter(model.PostBan.checksum == checksum) + .one_or_none() + ) + + +def get_bans_by_hash(hash: str) -> model.PostBan: + ban = try_get_ban_by_checksum(hash) + if ban is None: + raise HashNotBannedError("Hash %s is not banned" % hash) + return ban + + +def delete(ban: model.PostBan) -> None: + db.session.delete(ban) + + +def serialize_ban( + ban: model.PostBan, options: List[str] = [] +) -> Optional[rest.Response]: + if not ban: + return None + return serialization.BaseSerializer(ban).serialize(options) diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 507e77a3..7ec779f7 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -4,6 +4,7 @@ from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Tuple import sqlalchemy as sa +from server.szurubooru.func.bans import PostBannedError from szurubooru import config, db, errors, model, rest from szurubooru.func import ( @@ -50,11 +51,6 @@ class PostAlreadyUploadedError(errors.ValidationError): ) -class PostBannedError(errors.ValidationError): - def __init__(self, message: str = "This file was banned", extra_fields: Dict[str, str] = None) -> None: - super().__init__(message, extra_fields) - - class InvalidPostIdError(errors.ValidationError): pass @@ -431,15 +427,6 @@ def create_post( return post, new_tags -def create_ban(post: model.Post) -> model.PostBan: - ban = model.PostBan() - ban.checksum = post.checksum - ban.time = datetime.utcnow() - - db.session.add(ban) - return ban - - def update_post_safety(post: model.Post, safety: str) -> None: assert post safety = util.flip(SAFETY_MAP).get(safety, None) diff --git a/server/szurubooru/model/__init__.py b/server/szurubooru/model/__init__.py index a6642525..70d20452 100644 --- a/server/szurubooru/model/__init__.py +++ b/server/szurubooru/model/__init__.py @@ -1,11 +1,11 @@ import szurubooru.model.util from szurubooru.model.base import Base +from szurubooru.model.bans import PostBan from szurubooru.model.comment import Comment, CommentScore from szurubooru.model.pool import Pool, PoolName, PoolPost from szurubooru.model.pool_category import PoolCategory from szurubooru.model.post import ( Post, - PostBan, PostFavorite, PostFeature, PostNote, diff --git a/server/szurubooru/model/bans.py b/server/szurubooru/model/bans.py new file mode 100644 index 00000000..8abeb7f7 --- /dev/null +++ b/server/szurubooru/model/bans.py @@ -0,0 +1,12 @@ +from typing import List + +import sqlalchemy as sa + +from szurubooru.model.base import Base + +class PostBan(Base): + __tablename__ = "post_ban" + + ban_id = sa.Column("id", sa.Integer, primary_key=True) + checksum = sa.Column("checksum", sa.Unicode(64), nullable=False) + time = sa.Column("time", sa.DateTime, nullable=False) diff --git a/server/szurubooru/model/post.py b/server/szurubooru/model/post.py index 51393bf4..49e748dc 100644 --- a/server/szurubooru/model/post.py +++ b/server/szurubooru/model/post.py @@ -94,14 +94,6 @@ class PostFavorite(Base): ) -class PostBan(Base): - __tablename__ = "post_ban" - - ban_id = sa.Column("id", sa.Integer, primary_key=True) - checksum = sa.Column("checksum", sa.Unicode(64), nullable=False) - time = sa.Column("time", sa.DateTime, nullable=False) - - class PostNote(Base): __tablename__ = "post_note" diff --git a/server/szurubooru/search/configs/__init__.py b/server/szurubooru/search/configs/__init__.py index c7218131..dec5ed7f 100644 --- a/server/szurubooru/search/configs/__init__.py +++ b/server/szurubooru/search/configs/__init__.py @@ -4,3 +4,4 @@ from .post_search_config import PostSearchConfig from .snapshot_search_config import SnapshotSearchConfig from .tag_search_config import TagSearchConfig from .user_search_config import UserSearchConfig +from .ban_search_config import BanSearchConfig \ No newline at end of file diff --git a/server/szurubooru/search/configs/ban_search_config.py b/server/szurubooru/search/configs/ban_search_config.py new file mode 100644 index 00000000..8977f4a9 --- /dev/null +++ b/server/szurubooru/search/configs/ban_search_config.py @@ -0,0 +1,77 @@ +from typing import Dict, Tuple + +import sqlalchemy as sa + +from szurubooru import db, model +from szurubooru.func import util +from szurubooru.search.configs import util as search_util +from szurubooru.search.configs.base_search_config import ( + BaseSearchConfig, + Filter, +) +from szurubooru.search.typing import SaColumn, SaQuery + + +class BanSearchConfig(BaseSearchConfig): + def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery: + strategy = ( + sa.orm.lazyload if _disable_eager_loads else sa.orm.subqueryload + ) + return ( + db.session.query(model.PostBan) + .options( + sa.orm.defer(model.PostBan.checksum), + sa.orm.defer(model.PostBan.time), + sa.orm.defer(model.PostBan.ban_id) + ) + ) + + def create_count_query(self, _disable_eager_loads: bool) -> SaQuery: + return db.session.query(model.PostBan) + + def create_around_query(self) -> SaQuery: + raise NotImplementedError() + + def finalize_query(self, query: SaQuery) -> SaQuery: + return query.order_by(model.PostBan.time.asc()) + + @property + def anonymous_filter(self) -> Filter: + return search_util.create_subquery_filter( + model.PostBan.checksum, + model.PostBan.time, + search_util.create_str_filter, + ) + + @property + def named_filters(self) -> Dict[str, Filter]: + return util.unalias_dict( + [ + ( + ["time"], + search_util.create_date_filter( + model.PostBan.time, + ), + ), + ( + ["checksum"], + search_util.create_subquery_filter( + model.PostBan.checksum, + search_util.create_str_filter, + ), + ) + ] + ) + + @property + def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]: + return util.unalias_dict( + [ + ( + ["random"], + (sa.sql.expression.func.random(), self.SORT_NONE), + ), + (["checksum"], (model.PostBan.checksum, self.SORT_ASC)), + (["time"], (model.PostBan.time, self.SORT_ASC)) + ] + )