Blocklist: Add backend elements:

- Add default blocklist to user when created
- Tags are created if added to a user blocklist
- Add matching migration to DB to add the user blocklist table
- Various other things
This commit is contained in:
Soblow (Opale) Xaselgio 2024-03-03 16:53:23 +01:00 committed by Lugrim
parent e5f61d2c31
commit 82721c0bcb
No known key found for this signature in database
GPG key ID: 1CF1D1FB9A327611
9 changed files with 220 additions and 9 deletions

View file

@ -43,6 +43,8 @@ def get_info(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
"tagNameRegex": config.config["tag_name_regex"], "tagNameRegex": config.config["tag_name_regex"],
"tagCategoryNameRegex": config.config["tag_category_name_regex"], "tagCategoryNameRegex": config.config["tag_category_name_regex"],
"defaultUserRank": config.config["default_rank"], "defaultUserRank": config.config["default_rank"],
"defaultTagBlocklist": config.config["default_tag_blocklist"],
"defaultTagBlocklistForAnonymous": config.config["default_tag_blocklist_for_anonymous"],
"enableSafety": config.config["enable_safety"], "enableSafety": config.config["enable_safety"],
"contactEmail": config.config["contact_email"], "contactEmail": config.config["contact_email"],
"canSendMails": bool(config.config["smtp"]["host"]), "canSendMails": bool(config.config["smtp"]["host"]),

View file

@ -2,7 +2,7 @@ from datetime import datetime
from typing import Dict, List, Optional from typing import Dict, List, Optional
from szurubooru import db, model, rest, search from szurubooru import db, model, rest, search
from szurubooru.func import auth, serialization, snapshots, tags, versions from szurubooru.func import auth, serialization, snapshots, tags, versions, users
_search_executor = search.Executor(search.configs.TagSearchConfig()) _search_executor = search.Executor(search.configs.TagSearchConfig())

View file

@ -1,7 +1,7 @@
from typing import Any, Dict from typing import Any, Dict, List
from szurubooru import model, rest, search from szurubooru import db, model, rest, search
from szurubooru.func import auth, serialization, users, versions from szurubooru.func import auth, serialization, snapshots, users, versions, tags
_search_executor = search.Executor(search.configs.UserSearchConfig()) _search_executor = search.Executor(search.configs.UserSearchConfig())
@ -17,6 +17,18 @@ def _serialize(
) )
def _create_tag_if_needed(tag_names: List[str], user: model.User) -> None:
# Taken from tag_api.py
if not tag_names:
return
_existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
if len(new_tags):
auth.verify_privilege(user, "tags:create")
db.session.flush()
for tag in new_tags:
snapshots.create(tag, user)
@rest.routes.get("/users/?") @rest.routes.get("/users/?")
def get_users( def get_users(
ctx: rest.Context, _params: Dict[str, str] = {} ctx: rest.Context, _params: Dict[str, str] = {}
@ -50,6 +62,10 @@ def create_user(
) )
ctx.session.add(user) ctx.session.add(user)
ctx.session.commit() ctx.session.commit()
to_add, _ = users.update_user_blocklist(user, None)
for e in to_add:
ctx.session.add(e)
ctx.session.commit()
return _serialize(ctx, user, force_show_email=True) return _serialize(ctx, user, force_show_email=True)
@ -80,6 +96,16 @@ def update_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
if ctx.has_param("rank"): if ctx.has_param("rank"):
auth.verify_privilege(ctx.user, "users:edit:%s:rank" % infix) auth.verify_privilege(ctx.user, "users:edit:%s:rank" % infix)
users.update_user_rank(user, ctx.get_param_as_string("rank"), ctx.user) users.update_user_rank(user, ctx.get_param_as_string("rank"), ctx.user)
if ctx.has_param("blocklist"):
auth.verify_privilege(ctx.user, "users:edit:%s:blocklist" % infix)
blocklist = ctx.get_param_as_string_list("blocklist")
_create_tag_if_needed(blocklist, user) # Non-existing tags are created.
blocklist_tags = tags.get_tags_by_names(blocklist)
to_add, to_remove = users.update_user_blocklist(user, blocklist_tags)
for e in to_remove:
ctx.session.delete(e)
for e in to_add:
ctx.session.add(e)
if ctx.has_param("avatarStyle"): if ctx.has_param("avatarStyle"):
auth.verify_privilege(ctx.user, "users:edit:%s:avatar" % infix) auth.verify_privilege(ctx.user, "users:edit:%s:avatar" % infix)
users.update_user_avatar( users.update_user_avatar(

View file

@ -159,6 +159,9 @@ def get_tag_by_name(name: str) -> model.Tag:
def get_tags_by_names(names: List[str]) -> List[model.Tag]: def get_tags_by_names(names: List[str]) -> List[model.Tag]:
"""
Returns a list of all tags which names include all the letters from the input list
"""
names = util.icase_unique(names) names = util.icase_unique(names)
if len(names) == 0: if len(names) == 0:
return [] return []
@ -175,6 +178,24 @@ def get_tags_by_names(names: List[str]) -> List[model.Tag]:
) )
def get_tags_by_exact_names(names: List[str]) -> List[model.Tag]:
"""
Returns a list of tags matching the names from the input list
"""
entries = []
if len(names) == 0:
return []
names = [name.lower() for name in names]
entries = (
db.session.query(model.Tag)
.join(model.TagName)
.filter(
sa.func.lower(model.TagName.name).in_(names)
)
.all())
return entries
def get_or_create_tags_by_names( def get_or_create_tags_by_names(
names: List[str], names: List[str],
) -> Tuple[List[model.Tag], List[model.Tag]]: ) -> Tuple[List[model.Tag], List[model.Tag]]:

View file

@ -1,3 +1,4 @@
import copy
import re import re
from datetime import datetime from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
@ -5,7 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import sqlalchemy as sa import sqlalchemy as sa
from szurubooru import config, db, errors, model, rest from szurubooru import config, db, errors, model, rest
from szurubooru.func import auth, files, images, serialization, util from szurubooru.func import auth, files, images, serialization, util, tags
class UserNotFoundError(errors.NotFoundError): class UserNotFoundError(errors.NotFoundError):
@ -107,6 +108,7 @@ class UserSerializer(serialization.BaseSerializer):
"lastLoginTime": self.serialize_last_login_time, "lastLoginTime": self.serialize_last_login_time,
"version": self.serialize_version, "version": self.serialize_version,
"rank": self.serialize_rank, "rank": self.serialize_rank,
"blocklist": self.serialize_blocklist,
"avatarStyle": self.serialize_avatar_style, "avatarStyle": self.serialize_avatar_style,
"avatarUrl": self.serialize_avatar_url, "avatarUrl": self.serialize_avatar_url,
"commentCount": self.serialize_comment_count, "commentCount": self.serialize_comment_count,
@ -138,6 +140,9 @@ class UserSerializer(serialization.BaseSerializer):
def serialize_avatar_url(self) -> Any: def serialize_avatar_url(self) -> Any:
return get_avatar_url(self.user) return get_avatar_url(self.user)
def serialize_blocklist(self) -> Any:
return [tags.serialize_tag(tag) for tag in get_blocklist_tag_from_user(self.user)]
def serialize_comment_count(self) -> Any: def serialize_comment_count(self) -> Any:
return self.user.comment_count return self.user.comment_count
@ -294,6 +299,66 @@ def update_user_rank(
user.rank = rank user.rank = rank
def get_blocklist_from_user(user: model.User) -> List[model.UserTagBlocklist]:
"""
Return the UserTagBlocklist objects related to given user
"""
rez = (db.session.query(model.UserTagBlocklist)
.filter(
model.UserTagBlocklist.user_id == user.user_id
)
.all())
return rez
def get_blocklist_tag_from_user(user: model.User) -> List[model.UserTagBlocklist]:
"""
Return the Tags blocklisted by given user
"""
rez = (db.session.query(model.UserTagBlocklist.tag_id)
.filter(
model.UserTagBlocklist.user_id == user.user_id
))
rez2 = (db.session.query(model.Tag)
.filter(
model.Tag.tag_id.in_(rez)
).all())
return rez2
def update_user_blocklist(user: model.User, new_blocklist_tags: Optional[List[model.Tag]]) -> List[List[model.UserTagBlocklist]]:
"""
Modify blocklist for given user.
If new_blocklist_tags is None, set the blocklist to configured default tag blocklist.
"""
assert user
to_add: List[model.UserTagBlocklist] = []
to_remove: List[model.UserTagBlocklist] = []
if new_blocklist_tags is None: # We're creating the user, use default config blocklist
if 'default_tag_blocklist' in config.config.keys():
for e in tags.get_tags_by_exact_names(config.config['default_tag_blocklist'].split(' ')):
to_add.append(model.UserTagBlocklist(user_id=user.user_id, tag_id=e.tag_id))
else:
new_blocklist_ids: List[int] = [e.tag_id for e in new_blocklist_tags]
previous_blocklist_tags: List[model.Tag] = get_blocklist_from_user(user)
previous_blocklist_ids: List[int] = [e.tag_id for e in previous_blocklist_tags]
original_previous_blocklist_ids = copy.copy(previous_blocklist_ids)
## Remove tags no longer in the new list
for i in range(len(original_previous_blocklist_ids)):
old_tag_id = original_previous_blocklist_ids[i]
if old_tag_id not in new_blocklist_ids:
to_remove.append(previous_blocklist_tags[i])
previous_blocklist_ids.remove(old_tag_id)
## Add tags not yet in the original list
for new_tag_id in new_blocklist_ids:
if new_tag_id not in previous_blocklist_ids:
to_add.append(model.UserTagBlocklist(user_id=user.user_id, tag_id=new_tag_id))
return to_add, to_remove
def update_user_avatar( def update_user_avatar(
user: model.User, avatar_style: str, avatar_content: Optional[bytes] = None user: model.User, avatar_style: str, avatar_content: Optional[bytes] = None
) -> None: ) -> None:

View file

@ -0,0 +1,30 @@
'''
Add blocklist related fields
add_blocklist
Revision ID: 9ba5e3a6ee7c
Created at: 2023-05-20 22:28:10.824954
'''
import sqlalchemy as sa
from alembic import op
revision = '9ba5e3a6ee7c'
down_revision = 'adcd63ff76a2'
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"user_tag_blocklist",
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("tag_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
sa.ForeignKeyConstraint(["tag_id"], ["tag.id"]),
sa.PrimaryKeyConstraint("user_id", "tag_id"),
)
def downgrade():
op.drop_table('user_tag_blocklist')

View file

@ -16,4 +16,4 @@ from szurubooru.model.post import (
from szurubooru.model.snapshot import Snapshot from szurubooru.model.snapshot import Snapshot
from szurubooru.model.tag import Tag, TagImplication, TagName, TagSuggestion from szurubooru.model.tag import Tag, TagImplication, TagName, TagSuggestion
from szurubooru.model.tag_category import TagCategory from szurubooru.model.tag_category import TagCategory
from szurubooru.model.user import User, UserToken from szurubooru.model.user import UserTagBlocklist, User, UserToken

View file

@ -5,6 +5,46 @@ from szurubooru.model.comment import Comment
from szurubooru.model.post import Post, PostFavorite, PostScore from szurubooru.model.post import Post, PostFavorite, PostScore
class UserTagBlocklist(Base):
__tablename__ = "user_tag_blocklist"
user_id = sa.Column(
"user_id",
sa.Integer,
sa.ForeignKey("user.id"),
primary_key=True,
nullable=False,
index=True,
)
tag_id = sa.Column(
"tag_id",
sa.Integer,
sa.ForeignKey("tag.id"),
primary_key=True,
nullable=False,
index=True,
)
tag = sa.orm.relationship(
"Tag",
backref=sa.orm.backref("user_tag_blocklist", cascade="all, delete-orphan"),
)
user = sa.orm.relationship(
"User",
backref=sa.orm.backref("user_tag_blocklist", cascade="all, delete-orphan"),
)
def __init__(self, user_id: int=None, tag_id: int=None, user=None, tag=None) -> None:
if user_id is not None:
self.user_id = user_id
if tag_id is not None:
self.tag_id = tag_id
if user is not None:
self.user = user
if tag is not None:
self.tag = tag
class User(Base): class User(Base):
__tablename__ = "user" __tablename__ = "user"
@ -35,6 +75,7 @@ class User(Base):
"avatar_style", sa.Unicode(32), nullable=False, default=AVATAR_GRAVATAR "avatar_style", sa.Unicode(32), nullable=False, default=AVATAR_GRAVATAR
) )
blocklist = sa.orm.relationship("UserTagBlocklist")
comments = sa.orm.relationship("Comment") comments = sa.orm.relationship("Comment")
@property @property

View file

@ -2,9 +2,9 @@ from typing import Any, Dict, Optional, Tuple
import sqlalchemy as sa import sqlalchemy as sa
from szurubooru import db, errors, model from szurubooru import config, db, errors, model
from szurubooru.func import util from szurubooru.func import tags, users, util
from szurubooru.search import criteria, tokens from szurubooru.search import criteria, parser, tokens
from szurubooru.search.configs import util as search_util from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import ( from szurubooru.search.configs.base_search_config import (
BaseSearchConfig, BaseSearchConfig,
@ -178,6 +178,32 @@ class PostSearchConfig(BaseSearchConfig):
new_special_tokens.append(token) new_special_tokens.append(token)
search_query.special_tokens = new_special_tokens search_query.special_tokens = new_special_tokens
blocklist_to_use = ""
if self.user: # Ensure there's a user object
if (self.user.rank == model.User.RANK_ANONYMOUS) and config.config["default_tag_blocklist_for_anonymous"]:
# Anonymous user, if configured to use default blocklist, do so
blocklist_to_use = config.config["default_tag_blocklist"]
else:
# Registered user, use their blocklist
user_blocklist_tags = users.get_blocklist_tag_from_user(self.user)
if user_blocklist_tags:
user_blocklist = db.session.query(model.Tag.first_name).filter(
model.Tag.tag_id.in_([e.tag_id for e in user_blocklist_tags])
).all()
blocklist_to_use = [e[0] for e in user_blocklist]
blocklist_to_use = " ".join(blocklist_to_use)
if len(blocklist_to_use) > 0:
# TODO Sort an already parsed and checked version instead?
blocklist_query = parser.Parser().parse(blocklist_to_use)
search_query_orig_list = [e.criterion.original_text for e in search_query.anonymous_tokens]
for t in blocklist_query.anonymous_tokens:
if t.criterion.original_text in search_query_orig_list:
continue
t.negated = True
search_query.anonymous_tokens.append(t)
def create_around_query(self) -> SaQuery: def create_around_query(self) -> SaQuery:
return db.session.query(model.Post).options(sa.orm.lazyload("*")) return db.session.query(model.Post).options(sa.orm.lazyload("*"))