Blocklist: Add backend elements:
- Add default blocklist to user when created - Tags are created if added to a user blocklist - Add matching migration to DB to add the user blocklist table - Various other things
This commit is contained in:
parent
e5f61d2c31
commit
82721c0bcb
9 changed files with 220 additions and 9 deletions
|
@ -43,6 +43,8 @@ def get_info(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
|
||||||
"tagNameRegex": config.config["tag_name_regex"],
|
"tagNameRegex": config.config["tag_name_regex"],
|
||||||
"tagCategoryNameRegex": config.config["tag_category_name_regex"],
|
"tagCategoryNameRegex": config.config["tag_category_name_regex"],
|
||||||
"defaultUserRank": config.config["default_rank"],
|
"defaultUserRank": config.config["default_rank"],
|
||||||
|
"defaultTagBlocklist": config.config["default_tag_blocklist"],
|
||||||
|
"defaultTagBlocklistForAnonymous": config.config["default_tag_blocklist_for_anonymous"],
|
||||||
"enableSafety": config.config["enable_safety"],
|
"enableSafety": config.config["enable_safety"],
|
||||||
"contactEmail": config.config["contact_email"],
|
"contactEmail": config.config["contact_email"],
|
||||||
"canSendMails": bool(config.config["smtp"]["host"]),
|
"canSendMails": bool(config.config["smtp"]["host"]),
|
||||||
|
|
|
@ -2,7 +2,7 @@ from datetime import datetime
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from szurubooru import db, model, rest, search
|
from szurubooru import db, model, rest, search
|
||||||
from szurubooru.func import auth, serialization, snapshots, tags, versions
|
from szurubooru.func import auth, serialization, snapshots, tags, versions, users
|
||||||
|
|
||||||
_search_executor = search.Executor(search.configs.TagSearchConfig())
|
_search_executor = search.Executor(search.configs.TagSearchConfig())
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from typing import Any, Dict
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
from szurubooru import model, rest, search
|
from szurubooru import db, model, rest, search
|
||||||
from szurubooru.func import auth, serialization, users, versions
|
from szurubooru.func import auth, serialization, snapshots, users, versions, tags
|
||||||
|
|
||||||
_search_executor = search.Executor(search.configs.UserSearchConfig())
|
_search_executor = search.Executor(search.configs.UserSearchConfig())
|
||||||
|
|
||||||
|
@ -17,6 +17,18 @@ def _serialize(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_tag_if_needed(tag_names: List[str], user: model.User) -> None:
|
||||||
|
# Taken from tag_api.py
|
||||||
|
if not tag_names:
|
||||||
|
return
|
||||||
|
_existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
|
||||||
|
if len(new_tags):
|
||||||
|
auth.verify_privilege(user, "tags:create")
|
||||||
|
db.session.flush()
|
||||||
|
for tag in new_tags:
|
||||||
|
snapshots.create(tag, user)
|
||||||
|
|
||||||
|
|
||||||
@rest.routes.get("/users/?")
|
@rest.routes.get("/users/?")
|
||||||
def get_users(
|
def get_users(
|
||||||
ctx: rest.Context, _params: Dict[str, str] = {}
|
ctx: rest.Context, _params: Dict[str, str] = {}
|
||||||
|
@ -50,6 +62,10 @@ def create_user(
|
||||||
)
|
)
|
||||||
ctx.session.add(user)
|
ctx.session.add(user)
|
||||||
ctx.session.commit()
|
ctx.session.commit()
|
||||||
|
to_add, _ = users.update_user_blocklist(user, None)
|
||||||
|
for e in to_add:
|
||||||
|
ctx.session.add(e)
|
||||||
|
ctx.session.commit()
|
||||||
|
|
||||||
return _serialize(ctx, user, force_show_email=True)
|
return _serialize(ctx, user, force_show_email=True)
|
||||||
|
|
||||||
|
@ -80,6 +96,16 @@ def update_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
|
||||||
if ctx.has_param("rank"):
|
if ctx.has_param("rank"):
|
||||||
auth.verify_privilege(ctx.user, "users:edit:%s:rank" % infix)
|
auth.verify_privilege(ctx.user, "users:edit:%s:rank" % infix)
|
||||||
users.update_user_rank(user, ctx.get_param_as_string("rank"), ctx.user)
|
users.update_user_rank(user, ctx.get_param_as_string("rank"), ctx.user)
|
||||||
|
if ctx.has_param("blocklist"):
|
||||||
|
auth.verify_privilege(ctx.user, "users:edit:%s:blocklist" % infix)
|
||||||
|
blocklist = ctx.get_param_as_string_list("blocklist")
|
||||||
|
_create_tag_if_needed(blocklist, user) # Non-existing tags are created.
|
||||||
|
blocklist_tags = tags.get_tags_by_names(blocklist)
|
||||||
|
to_add, to_remove = users.update_user_blocklist(user, blocklist_tags)
|
||||||
|
for e in to_remove:
|
||||||
|
ctx.session.delete(e)
|
||||||
|
for e in to_add:
|
||||||
|
ctx.session.add(e)
|
||||||
if ctx.has_param("avatarStyle"):
|
if ctx.has_param("avatarStyle"):
|
||||||
auth.verify_privilege(ctx.user, "users:edit:%s:avatar" % infix)
|
auth.verify_privilege(ctx.user, "users:edit:%s:avatar" % infix)
|
||||||
users.update_user_avatar(
|
users.update_user_avatar(
|
||||||
|
|
|
@ -159,6 +159,9 @@ def get_tag_by_name(name: str) -> model.Tag:
|
||||||
|
|
||||||
|
|
||||||
def get_tags_by_names(names: List[str]) -> List[model.Tag]:
|
def get_tags_by_names(names: List[str]) -> List[model.Tag]:
|
||||||
|
"""
|
||||||
|
Returns a list of all tags which names include all the letters from the input list
|
||||||
|
"""
|
||||||
names = util.icase_unique(names)
|
names = util.icase_unique(names)
|
||||||
if len(names) == 0:
|
if len(names) == 0:
|
||||||
return []
|
return []
|
||||||
|
@ -175,6 +178,24 @@ def get_tags_by_names(names: List[str]) -> List[model.Tag]:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_tags_by_exact_names(names: List[str]) -> List[model.Tag]:
|
||||||
|
"""
|
||||||
|
Returns a list of tags matching the names from the input list
|
||||||
|
"""
|
||||||
|
entries = []
|
||||||
|
if len(names) == 0:
|
||||||
|
return []
|
||||||
|
names = [name.lower() for name in names]
|
||||||
|
entries = (
|
||||||
|
db.session.query(model.Tag)
|
||||||
|
.join(model.TagName)
|
||||||
|
.filter(
|
||||||
|
sa.func.lower(model.TagName.name).in_(names)
|
||||||
|
)
|
||||||
|
.all())
|
||||||
|
return entries
|
||||||
|
|
||||||
|
|
||||||
def get_or_create_tags_by_names(
|
def get_or_create_tags_by_names(
|
||||||
names: List[str],
|
names: List[str],
|
||||||
) -> Tuple[List[model.Tag], List[model.Tag]]:
|
) -> Tuple[List[model.Tag], List[model.Tag]]:
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import copy
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Callable, Dict, List, Optional, Union
|
from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
|
@ -5,7 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
from szurubooru import config, db, errors, model, rest
|
from szurubooru import config, db, errors, model, rest
|
||||||
from szurubooru.func import auth, files, images, serialization, util
|
from szurubooru.func import auth, files, images, serialization, util, tags
|
||||||
|
|
||||||
|
|
||||||
class UserNotFoundError(errors.NotFoundError):
|
class UserNotFoundError(errors.NotFoundError):
|
||||||
|
@ -107,6 +108,7 @@ class UserSerializer(serialization.BaseSerializer):
|
||||||
"lastLoginTime": self.serialize_last_login_time,
|
"lastLoginTime": self.serialize_last_login_time,
|
||||||
"version": self.serialize_version,
|
"version": self.serialize_version,
|
||||||
"rank": self.serialize_rank,
|
"rank": self.serialize_rank,
|
||||||
|
"blocklist": self.serialize_blocklist,
|
||||||
"avatarStyle": self.serialize_avatar_style,
|
"avatarStyle": self.serialize_avatar_style,
|
||||||
"avatarUrl": self.serialize_avatar_url,
|
"avatarUrl": self.serialize_avatar_url,
|
||||||
"commentCount": self.serialize_comment_count,
|
"commentCount": self.serialize_comment_count,
|
||||||
|
@ -138,6 +140,9 @@ class UserSerializer(serialization.BaseSerializer):
|
||||||
def serialize_avatar_url(self) -> Any:
|
def serialize_avatar_url(self) -> Any:
|
||||||
return get_avatar_url(self.user)
|
return get_avatar_url(self.user)
|
||||||
|
|
||||||
|
def serialize_blocklist(self) -> Any:
|
||||||
|
return [tags.serialize_tag(tag) for tag in get_blocklist_tag_from_user(self.user)]
|
||||||
|
|
||||||
def serialize_comment_count(self) -> Any:
|
def serialize_comment_count(self) -> Any:
|
||||||
return self.user.comment_count
|
return self.user.comment_count
|
||||||
|
|
||||||
|
@ -294,6 +299,66 @@ def update_user_rank(
|
||||||
user.rank = rank
|
user.rank = rank
|
||||||
|
|
||||||
|
|
||||||
|
def get_blocklist_from_user(user: model.User) -> List[model.UserTagBlocklist]:
|
||||||
|
"""
|
||||||
|
Return the UserTagBlocklist objects related to given user
|
||||||
|
"""
|
||||||
|
rez = (db.session.query(model.UserTagBlocklist)
|
||||||
|
.filter(
|
||||||
|
model.UserTagBlocklist.user_id == user.user_id
|
||||||
|
)
|
||||||
|
.all())
|
||||||
|
return rez
|
||||||
|
|
||||||
|
|
||||||
|
def get_blocklist_tag_from_user(user: model.User) -> List[model.UserTagBlocklist]:
|
||||||
|
"""
|
||||||
|
Return the Tags blocklisted by given user
|
||||||
|
"""
|
||||||
|
rez = (db.session.query(model.UserTagBlocklist.tag_id)
|
||||||
|
.filter(
|
||||||
|
model.UserTagBlocklist.user_id == user.user_id
|
||||||
|
))
|
||||||
|
rez2 = (db.session.query(model.Tag)
|
||||||
|
.filter(
|
||||||
|
model.Tag.tag_id.in_(rez)
|
||||||
|
).all())
|
||||||
|
return rez2
|
||||||
|
|
||||||
|
|
||||||
|
def update_user_blocklist(user: model.User, new_blocklist_tags: Optional[List[model.Tag]]) -> List[List[model.UserTagBlocklist]]:
|
||||||
|
"""
|
||||||
|
Modify blocklist for given user.
|
||||||
|
If new_blocklist_tags is None, set the blocklist to configured default tag blocklist.
|
||||||
|
"""
|
||||||
|
assert user
|
||||||
|
to_add: List[model.UserTagBlocklist] = []
|
||||||
|
to_remove: List[model.UserTagBlocklist] = []
|
||||||
|
|
||||||
|
if new_blocklist_tags is None: # We're creating the user, use default config blocklist
|
||||||
|
if 'default_tag_blocklist' in config.config.keys():
|
||||||
|
for e in tags.get_tags_by_exact_names(config.config['default_tag_blocklist'].split(' ')):
|
||||||
|
to_add.append(model.UserTagBlocklist(user_id=user.user_id, tag_id=e.tag_id))
|
||||||
|
else:
|
||||||
|
new_blocklist_ids: List[int] = [e.tag_id for e in new_blocklist_tags]
|
||||||
|
previous_blocklist_tags: List[model.Tag] = get_blocklist_from_user(user)
|
||||||
|
previous_blocklist_ids: List[int] = [e.tag_id for e in previous_blocklist_tags]
|
||||||
|
original_previous_blocklist_ids = copy.copy(previous_blocklist_ids)
|
||||||
|
|
||||||
|
## Remove tags no longer in the new list
|
||||||
|
for i in range(len(original_previous_blocklist_ids)):
|
||||||
|
old_tag_id = original_previous_blocklist_ids[i]
|
||||||
|
if old_tag_id not in new_blocklist_ids:
|
||||||
|
to_remove.append(previous_blocklist_tags[i])
|
||||||
|
previous_blocklist_ids.remove(old_tag_id)
|
||||||
|
|
||||||
|
## Add tags not yet in the original list
|
||||||
|
for new_tag_id in new_blocklist_ids:
|
||||||
|
if new_tag_id not in previous_blocklist_ids:
|
||||||
|
to_add.append(model.UserTagBlocklist(user_id=user.user_id, tag_id=new_tag_id))
|
||||||
|
return to_add, to_remove
|
||||||
|
|
||||||
|
|
||||||
def update_user_avatar(
|
def update_user_avatar(
|
||||||
user: model.User, avatar_style: str, avatar_content: Optional[bytes] = None
|
user: model.User, avatar_style: str, avatar_content: Optional[bytes] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
@ -0,0 +1,30 @@
|
||||||
|
'''
|
||||||
|
Add blocklist related fields
|
||||||
|
|
||||||
|
add_blocklist
|
||||||
|
|
||||||
|
Revision ID: 9ba5e3a6ee7c
|
||||||
|
Created at: 2023-05-20 22:28:10.824954
|
||||||
|
'''
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision = '9ba5e3a6ee7c'
|
||||||
|
down_revision = 'adcd63ff76a2'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
op.create_table(
|
||||||
|
"user_tag_blocklist",
|
||||||
|
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("tag_id", sa.Integer(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
|
||||||
|
sa.ForeignKeyConstraint(["tag_id"], ["tag.id"]),
|
||||||
|
sa.PrimaryKeyConstraint("user_id", "tag_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
op.drop_table('user_tag_blocklist')
|
|
@ -16,4 +16,4 @@ from szurubooru.model.post import (
|
||||||
from szurubooru.model.snapshot import Snapshot
|
from szurubooru.model.snapshot import Snapshot
|
||||||
from szurubooru.model.tag import Tag, TagImplication, TagName, TagSuggestion
|
from szurubooru.model.tag import Tag, TagImplication, TagName, TagSuggestion
|
||||||
from szurubooru.model.tag_category import TagCategory
|
from szurubooru.model.tag_category import TagCategory
|
||||||
from szurubooru.model.user import User, UserToken
|
from szurubooru.model.user import UserTagBlocklist, User, UserToken
|
||||||
|
|
|
@ -5,6 +5,46 @@ from szurubooru.model.comment import Comment
|
||||||
from szurubooru.model.post import Post, PostFavorite, PostScore
|
from szurubooru.model.post import Post, PostFavorite, PostScore
|
||||||
|
|
||||||
|
|
||||||
|
class UserTagBlocklist(Base):
|
||||||
|
__tablename__ = "user_tag_blocklist"
|
||||||
|
|
||||||
|
user_id = sa.Column(
|
||||||
|
"user_id",
|
||||||
|
sa.Integer,
|
||||||
|
sa.ForeignKey("user.id"),
|
||||||
|
primary_key=True,
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
tag_id = sa.Column(
|
||||||
|
"tag_id",
|
||||||
|
sa.Integer,
|
||||||
|
sa.ForeignKey("tag.id"),
|
||||||
|
primary_key=True,
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
tag = sa.orm.relationship(
|
||||||
|
"Tag",
|
||||||
|
backref=sa.orm.backref("user_tag_blocklist", cascade="all, delete-orphan"),
|
||||||
|
)
|
||||||
|
user = sa.orm.relationship(
|
||||||
|
"User",
|
||||||
|
backref=sa.orm.backref("user_tag_blocklist", cascade="all, delete-orphan"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, user_id: int=None, tag_id: int=None, user=None, tag=None) -> None:
|
||||||
|
if user_id is not None:
|
||||||
|
self.user_id = user_id
|
||||||
|
if tag_id is not None:
|
||||||
|
self.tag_id = tag_id
|
||||||
|
if user is not None:
|
||||||
|
self.user = user
|
||||||
|
if tag is not None:
|
||||||
|
self.tag = tag
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
__tablename__ = "user"
|
__tablename__ = "user"
|
||||||
|
|
||||||
|
@ -35,6 +75,7 @@ class User(Base):
|
||||||
"avatar_style", sa.Unicode(32), nullable=False, default=AVATAR_GRAVATAR
|
"avatar_style", sa.Unicode(32), nullable=False, default=AVATAR_GRAVATAR
|
||||||
)
|
)
|
||||||
|
|
||||||
|
blocklist = sa.orm.relationship("UserTagBlocklist")
|
||||||
comments = sa.orm.relationship("Comment")
|
comments = sa.orm.relationship("Comment")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
@ -2,9 +2,9 @@ from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
from szurubooru import db, errors, model
|
from szurubooru import config, db, errors, model
|
||||||
from szurubooru.func import util
|
from szurubooru.func import tags, users, util
|
||||||
from szurubooru.search import criteria, tokens
|
from szurubooru.search import criteria, parser, tokens
|
||||||
from szurubooru.search.configs import util as search_util
|
from szurubooru.search.configs import util as search_util
|
||||||
from szurubooru.search.configs.base_search_config import (
|
from szurubooru.search.configs.base_search_config import (
|
||||||
BaseSearchConfig,
|
BaseSearchConfig,
|
||||||
|
@ -178,6 +178,32 @@ class PostSearchConfig(BaseSearchConfig):
|
||||||
new_special_tokens.append(token)
|
new_special_tokens.append(token)
|
||||||
search_query.special_tokens = new_special_tokens
|
search_query.special_tokens = new_special_tokens
|
||||||
|
|
||||||
|
blocklist_to_use = ""
|
||||||
|
|
||||||
|
if self.user: # Ensure there's a user object
|
||||||
|
if (self.user.rank == model.User.RANK_ANONYMOUS) and config.config["default_tag_blocklist_for_anonymous"]:
|
||||||
|
# Anonymous user, if configured to use default blocklist, do so
|
||||||
|
blocklist_to_use = config.config["default_tag_blocklist"]
|
||||||
|
else:
|
||||||
|
# Registered user, use their blocklist
|
||||||
|
user_blocklist_tags = users.get_blocklist_tag_from_user(self.user)
|
||||||
|
if user_blocklist_tags:
|
||||||
|
user_blocklist = db.session.query(model.Tag.first_name).filter(
|
||||||
|
model.Tag.tag_id.in_([e.tag_id for e in user_blocklist_tags])
|
||||||
|
).all()
|
||||||
|
blocklist_to_use = [e[0] for e in user_blocklist]
|
||||||
|
blocklist_to_use = " ".join(blocklist_to_use)
|
||||||
|
|
||||||
|
if len(blocklist_to_use) > 0:
|
||||||
|
# TODO Sort an already parsed and checked version instead?
|
||||||
|
blocklist_query = parser.Parser().parse(blocklist_to_use)
|
||||||
|
search_query_orig_list = [e.criterion.original_text for e in search_query.anonymous_tokens]
|
||||||
|
for t in blocklist_query.anonymous_tokens:
|
||||||
|
if t.criterion.original_text in search_query_orig_list:
|
||||||
|
continue
|
||||||
|
t.negated = True
|
||||||
|
search_query.anonymous_tokens.append(t)
|
||||||
|
|
||||||
def create_around_query(self) -> SaQuery:
|
def create_around_query(self) -> SaQuery:
|
||||||
return db.session.query(model.Post).options(sa.orm.lazyload("*"))
|
return db.session.query(model.Post).options(sa.orm.lazyload("*"))
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue