diff --git a/server/szurubooru/api/info_api.py b/server/szurubooru/api/info_api.py index 757b09cf..53b63962 100644 --- a/server/szurubooru/api/info_api.py +++ b/server/szurubooru/api/info_api.py @@ -43,6 +43,8 @@ def get_info(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response: "tagNameRegex": config.config["tag_name_regex"], "tagCategoryNameRegex": config.config["tag_category_name_regex"], "defaultUserRank": config.config["default_rank"], + "defaultTagBlocklist": config.config["default_tag_blocklist"], + "defaultTagBlocklistForAnonymous": config.config["default_tag_blocklist_for_anonymous"], "enableSafety": config.config["enable_safety"], "contactEmail": config.config["contact_email"], "canSendMails": bool(config.config["smtp"]["host"]), diff --git a/server/szurubooru/api/tag_api.py b/server/szurubooru/api/tag_api.py index 6b4c807e..724e71d7 100644 --- a/server/szurubooru/api/tag_api.py +++ b/server/szurubooru/api/tag_api.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import Dict, List, Optional from szurubooru import db, model, rest, search -from szurubooru.func import auth, serialization, snapshots, tags, versions +from szurubooru.func import auth, serialization, snapshots, tags, versions, users _search_executor = search.Executor(search.configs.TagSearchConfig()) diff --git a/server/szurubooru/api/user_api.py b/server/szurubooru/api/user_api.py index a6196cb8..11a88459 100644 --- a/server/szurubooru/api/user_api.py +++ b/server/szurubooru/api/user_api.py @@ -1,7 +1,7 @@ -from typing import Any, Dict +from typing import Any, Dict, List -from szurubooru import model, rest, search -from szurubooru.func import auth, serialization, users, versions +from szurubooru import db, model, rest, search +from szurubooru.func import auth, serialization, snapshots, users, versions, tags _search_executor = search.Executor(search.configs.UserSearchConfig()) @@ -17,6 +17,18 @@ def _serialize( ) +def _create_tag_if_needed(tag_names: List[str], user: model.User) -> None: + # Taken from tag_api.py + if not tag_names: + return + _existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names) + if len(new_tags): + auth.verify_privilege(user, "tags:create") + db.session.flush() + for tag in new_tags: + snapshots.create(tag, user) + + @rest.routes.get("/users/?") def get_users( ctx: rest.Context, _params: Dict[str, str] = {} @@ -50,6 +62,10 @@ def create_user( ) ctx.session.add(user) ctx.session.commit() + to_add, _ = users.update_user_blocklist(user, None) + for e in to_add: + ctx.session.add(e) + ctx.session.commit() return _serialize(ctx, user, force_show_email=True) @@ -80,6 +96,16 @@ def update_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: if ctx.has_param("rank"): auth.verify_privilege(ctx.user, "users:edit:%s:rank" % infix) users.update_user_rank(user, ctx.get_param_as_string("rank"), ctx.user) + if ctx.has_param("blocklist"): + auth.verify_privilege(ctx.user, "users:edit:%s:blocklist" % infix) + blocklist = ctx.get_param_as_string_list("blocklist") + _create_tag_if_needed(blocklist, user) # Non-existing tags are created. + blocklist_tags = tags.get_tags_by_names(blocklist) + to_add, to_remove = users.update_user_blocklist(user, blocklist_tags) + for e in to_remove: + ctx.session.delete(e) + for e in to_add: + ctx.session.add(e) if ctx.has_param("avatarStyle"): auth.verify_privilege(ctx.user, "users:edit:%s:avatar" % infix) users.update_user_avatar( diff --git a/server/szurubooru/func/tags.py b/server/szurubooru/func/tags.py index 28a2a76b..10eaebf1 100644 --- a/server/szurubooru/func/tags.py +++ b/server/szurubooru/func/tags.py @@ -159,6 +159,9 @@ def get_tag_by_name(name: str) -> model.Tag: def get_tags_by_names(names: List[str]) -> List[model.Tag]: + """ + Returns a list of all tags which names include all the letters from the input list + """ names = util.icase_unique(names) if len(names) == 0: return [] @@ -175,6 +178,24 @@ def get_tags_by_names(names: List[str]) -> List[model.Tag]: ) +def get_tags_by_exact_names(names: List[str]) -> List[model.Tag]: + """ + Returns a list of tags matching the names from the input list + """ + entries = [] + if len(names) == 0: + return [] + names = [name.lower() for name in names] + entries = ( + db.session.query(model.Tag) + .join(model.TagName) + .filter( + sa.func.lower(model.TagName.name).in_(names) + ) + .all()) + return entries + + def get_or_create_tags_by_names( names: List[str], ) -> Tuple[List[model.Tag], List[model.Tag]]: diff --git a/server/szurubooru/func/users.py b/server/szurubooru/func/users.py index 5cbe3cc0..c1be991e 100644 --- a/server/szurubooru/func/users.py +++ b/server/szurubooru/func/users.py @@ -1,3 +1,4 @@ +import copy import re from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Union @@ -5,7 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Union import sqlalchemy as sa from szurubooru import config, db, errors, model, rest -from szurubooru.func import auth, files, images, serialization, util +from szurubooru.func import auth, files, images, serialization, util, tags class UserNotFoundError(errors.NotFoundError): @@ -107,6 +108,7 @@ class UserSerializer(serialization.BaseSerializer): "lastLoginTime": self.serialize_last_login_time, "version": self.serialize_version, "rank": self.serialize_rank, + "blocklist": self.serialize_blocklist, "avatarStyle": self.serialize_avatar_style, "avatarUrl": self.serialize_avatar_url, "commentCount": self.serialize_comment_count, @@ -138,6 +140,9 @@ class UserSerializer(serialization.BaseSerializer): def serialize_avatar_url(self) -> Any: return get_avatar_url(self.user) + def serialize_blocklist(self) -> Any: + return [tags.serialize_tag(tag) for tag in get_blocklist_tag_from_user(self.user)] + def serialize_comment_count(self) -> Any: return self.user.comment_count @@ -294,6 +299,66 @@ def update_user_rank( user.rank = rank +def get_blocklist_from_user(user: model.User) -> List[model.UserTagBlocklist]: + """ + Return the UserTagBlocklist objects related to given user + """ + rez = (db.session.query(model.UserTagBlocklist) + .filter( + model.UserTagBlocklist.user_id == user.user_id + ) + .all()) + return rez + + +def get_blocklist_tag_from_user(user: model.User) -> List[model.UserTagBlocklist]: + """ + Return the Tags blocklisted by given user + """ + rez = (db.session.query(model.UserTagBlocklist.tag_id) + .filter( + model.UserTagBlocklist.user_id == user.user_id + )) + rez2 = (db.session.query(model.Tag) + .filter( + model.Tag.tag_id.in_(rez) + ).all()) + return rez2 + + +def update_user_blocklist(user: model.User, new_blocklist_tags: Optional[List[model.Tag]]) -> List[List[model.UserTagBlocklist]]: + """ + Modify blocklist for given user. + If new_blocklist_tags is None, set the blocklist to configured default tag blocklist. + """ + assert user + to_add: List[model.UserTagBlocklist] = [] + to_remove: List[model.UserTagBlocklist] = [] + + if new_blocklist_tags is None: # We're creating the user, use default config blocklist + if 'default_tag_blocklist' in config.config.keys(): + for e in tags.get_tags_by_exact_names(config.config['default_tag_blocklist'].split(' ')): + to_add.append(model.UserTagBlocklist(user_id=user.user_id, tag_id=e.tag_id)) + else: + new_blocklist_ids: List[int] = [e.tag_id for e in new_blocklist_tags] + previous_blocklist_tags: List[model.Tag] = get_blocklist_from_user(user) + previous_blocklist_ids: List[int] = [e.tag_id for e in previous_blocklist_tags] + original_previous_blocklist_ids = copy.copy(previous_blocklist_ids) + + ## Remove tags no longer in the new list + for i in range(len(original_previous_blocklist_ids)): + old_tag_id = original_previous_blocklist_ids[i] + if old_tag_id not in new_blocklist_ids: + to_remove.append(previous_blocklist_tags[i]) + previous_blocklist_ids.remove(old_tag_id) + + ## Add tags not yet in the original list + for new_tag_id in new_blocklist_ids: + if new_tag_id not in previous_blocklist_ids: + to_add.append(model.UserTagBlocklist(user_id=user.user_id, tag_id=new_tag_id)) + return to_add, to_remove + + def update_user_avatar( user: model.User, avatar_style: str, avatar_content: Optional[bytes] = None ) -> None: diff --git a/server/szurubooru/migrations/versions/9ba5e3a6ee7c_add_blocklist.py b/server/szurubooru/migrations/versions/9ba5e3a6ee7c_add_blocklist.py new file mode 100644 index 00000000..317f2c96 --- /dev/null +++ b/server/szurubooru/migrations/versions/9ba5e3a6ee7c_add_blocklist.py @@ -0,0 +1,30 @@ +''' +Add blocklist related fields + +add_blocklist + +Revision ID: 9ba5e3a6ee7c +Created at: 2023-05-20 22:28:10.824954 +''' + +import sqlalchemy as sa +from alembic import op + +revision = '9ba5e3a6ee7c' +down_revision = 'adcd63ff76a2' +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "user_tag_blocklist", + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("tag_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["user.id"]), + sa.ForeignKeyConstraint(["tag_id"], ["tag.id"]), + sa.PrimaryKeyConstraint("user_id", "tag_id"), + ) + +def downgrade(): + op.drop_table('user_tag_blocklist') diff --git a/server/szurubooru/model/__init__.py b/server/szurubooru/model/__init__.py index 21a178ef..de28af4c 100644 --- a/server/szurubooru/model/__init__.py +++ b/server/szurubooru/model/__init__.py @@ -16,4 +16,4 @@ from szurubooru.model.post import ( from szurubooru.model.snapshot import Snapshot from szurubooru.model.tag import Tag, TagImplication, TagName, TagSuggestion from szurubooru.model.tag_category import TagCategory -from szurubooru.model.user import User, UserToken +from szurubooru.model.user import UserTagBlocklist, User, UserToken diff --git a/server/szurubooru/model/user.py b/server/szurubooru/model/user.py index 41a9b30b..5186f429 100644 --- a/server/szurubooru/model/user.py +++ b/server/szurubooru/model/user.py @@ -5,6 +5,46 @@ from szurubooru.model.comment import Comment from szurubooru.model.post import Post, PostFavorite, PostScore +class UserTagBlocklist(Base): + __tablename__ = "user_tag_blocklist" + + user_id = sa.Column( + "user_id", + sa.Integer, + sa.ForeignKey("user.id"), + primary_key=True, + nullable=False, + index=True, + ) + tag_id = sa.Column( + "tag_id", + sa.Integer, + sa.ForeignKey("tag.id"), + primary_key=True, + nullable=False, + index=True, + ) + + tag = sa.orm.relationship( + "Tag", + backref=sa.orm.backref("user_tag_blocklist", cascade="all, delete-orphan"), + ) + user = sa.orm.relationship( + "User", + backref=sa.orm.backref("user_tag_blocklist", cascade="all, delete-orphan"), + ) + + def __init__(self, user_id: int=None, tag_id: int=None, user=None, tag=None) -> None: + if user_id is not None: + self.user_id = user_id + if tag_id is not None: + self.tag_id = tag_id + if user is not None: + self.user = user + if tag is not None: + self.tag = tag + + class User(Base): __tablename__ = "user" @@ -35,6 +75,7 @@ class User(Base): "avatar_style", sa.Unicode(32), nullable=False, default=AVATAR_GRAVATAR ) + blocklist = sa.orm.relationship("UserTagBlocklist") comments = sa.orm.relationship("Comment") @property diff --git a/server/szurubooru/search/configs/post_search_config.py b/server/szurubooru/search/configs/post_search_config.py index 8d4672d4..d482bcc6 100644 --- a/server/szurubooru/search/configs/post_search_config.py +++ b/server/szurubooru/search/configs/post_search_config.py @@ -2,9 +2,9 @@ from typing import Any, Dict, Optional, Tuple import sqlalchemy as sa -from szurubooru import db, errors, model -from szurubooru.func import util -from szurubooru.search import criteria, tokens +from szurubooru import config, db, errors, model +from szurubooru.func import tags, users, util +from szurubooru.search import criteria, parser, tokens from szurubooru.search.configs import util as search_util from szurubooru.search.configs.base_search_config import ( BaseSearchConfig, @@ -178,6 +178,32 @@ class PostSearchConfig(BaseSearchConfig): new_special_tokens.append(token) search_query.special_tokens = new_special_tokens + blocklist_to_use = "" + + if self.user: # Ensure there's a user object + if (self.user.rank == model.User.RANK_ANONYMOUS) and config.config["default_tag_blocklist_for_anonymous"]: + # Anonymous user, if configured to use default blocklist, do so + blocklist_to_use = config.config["default_tag_blocklist"] + else: + # Registered user, use their blocklist + user_blocklist_tags = users.get_blocklist_tag_from_user(self.user) + if user_blocklist_tags: + user_blocklist = db.session.query(model.Tag.first_name).filter( + model.Tag.tag_id.in_([e.tag_id for e in user_blocklist_tags]) + ).all() + blocklist_to_use = [e[0] for e in user_blocklist] + blocklist_to_use = " ".join(blocklist_to_use) + + if len(blocklist_to_use) > 0: + # TODO Sort an already parsed and checked version instead? + blocklist_query = parser.Parser().parse(blocklist_to_use) + search_query_orig_list = [e.criterion.original_text for e in search_query.anonymous_tokens] + for t in blocklist_query.anonymous_tokens: + if t.criterion.original_text in search_query_orig_list: + continue + t.negated = True + search_query.anonymous_tokens.append(t) + def create_around_query(self) -> SaQuery: return db.session.query(model.Post).options(sa.orm.lazyload("*"))