This commit is contained in:
Soblow 2024-11-18 18:32:35 +00:00 committed by GitHub
commit 9ae1f281eb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 446 additions and 11 deletions

View file

@ -68,6 +68,12 @@
</div> </div>
</li> </li>
<% } %> <% } %>
<% if (ctx.canEditBlocklist) { %>
<li class='blocklist'>
<%= ctx.makeTextInput({text: 'Blocklist'}) %>
</li>
<% } %>
</ul> </ul>
<div class='messages'></div> <div class='messages'></div>

View file

@ -89,6 +89,7 @@ class UserController {
canEditAvatar: api.hasPrivilege( canEditAvatar: api.hasPrivilege(
`users:edit:${infix}:avatar` `users:edit:${infix}:avatar`
), ),
canEditBlocklist: api.hasPrivilege(`users:edit:${infix}:blocklist`),
canEditAnything: api.hasPrivilege(`users:edit:${infix}`), canEditAnything: api.hasPrivilege(`users:edit:${infix}`),
canListTokens: api.hasPrivilege( canListTokens: api.hasPrivilege(
`userTokens:list:${infix}` `userTokens:list:${infix}`

View file

@ -3,11 +3,19 @@
const api = require("../api.js"); const api = require("../api.js");
const uri = require("../util/uri.js"); const uri = require("../util/uri.js");
const events = require("../events.js"); const events = require("../events.js");
const misc = require("../util/misc.js");
class User extends events.EventTarget { class User extends events.EventTarget {
constructor() { constructor() {
const TagList = require("./tag_list.js");
super(); super();
this._orig = {}; this._orig = {};
for (let obj of [this, this._orig]) {
obj._blocklist = new TagList();
}
this._updateFromResponse({}); this._updateFromResponse({});
} }
@ -71,6 +79,10 @@ class User extends events.EventTarget {
throw "Invalid operation"; throw "Invalid operation";
} }
get blocklist() {
return this._blocklist;
}
set name(value) { set name(value) {
this._name = value; this._name = value;
} }
@ -95,6 +107,10 @@ class User extends events.EventTarget {
this._password = value; this._password = value;
} }
set blocklist(value) {
this._blocklist = value || "";
}
static fromResponse(response) { static fromResponse(response) {
const ret = new User(); const ret = new User();
ret._updateFromResponse(response); ret._updateFromResponse(response);
@ -121,6 +137,11 @@ class User extends events.EventTarget {
if (this._rank !== this._orig._rank) { if (this._rank !== this._orig._rank) {
detail.rank = this._rank; detail.rank = this._rank;
} }
if (misc.arraysDiffer(this._blocklist, this._orig._blocklist)) {
detail.blocklist = this._blocklist.map(
(relation) => relation.names[0]
);
}
if (this._avatarStyle !== this._orig._avatarStyle) { if (this._avatarStyle !== this._orig._avatarStyle) {
detail.avatarStyle = this._avatarStyle; detail.avatarStyle = this._avatarStyle;
} }
@ -187,6 +208,10 @@ class User extends events.EventTarget {
_dislikedPostCount: response.dislikedPostCount, _dislikedPostCount: response.dislikedPostCount,
}; };
for (let obj of [this, this._orig]) {
obj._blocklist.sync(response.blocklist);
}
Object.assign(this, map); Object.assign(this, map);
Object.assign(this._orig, map); Object.assign(this._orig, map);

View file

@ -4,6 +4,8 @@ const events = require("../events.js");
const api = require("../api.js"); const api = require("../api.js");
const views = require("../util/views.js"); const views = require("../util/views.js");
const FileDropperControl = require("../controls/file_dropper_control.js"); const FileDropperControl = require("../controls/file_dropper_control.js");
const TagInputControl = require("../controls/tag_input_control.js")
const misc = require("../util/misc.js");
const template = views.getTemplate("user-edit"); const template = views.getTemplate("user-edit");
@ -41,6 +43,13 @@ class UserEditView extends events.EventTarget {
}); });
} }
if (this._blocklistFieldNode) {
new TagInputControl(
this._blocklistFieldNode,
this._user.blocklist
);
}
this._formNode.addEventListener("submit", (e) => this._evtSubmit(e)); this._formNode.addEventListener("submit", (e) => this._evtSubmit(e));
} }
@ -83,6 +92,10 @@ class UserEditView extends events.EventTarget {
? this._rankInputNode.value ? this._rankInputNode.value
: undefined, : undefined,
blocklist: this._blocklistFieldNode
? misc.splitByWhitespace(this._blocklistFieldNode.value)
: undefined,
avatarStyle: this._avatarStyleInputNode avatarStyle: this._avatarStyleInputNode
? this._avatarStyleInputNode.value ? this._avatarStyleInputNode.value
: undefined, : undefined,
@ -101,6 +114,10 @@ class UserEditView extends events.EventTarget {
return this._hostNode.querySelector("form"); return this._hostNode.querySelector("form");
} }
get _blocklistFieldNode() {
return this._formNode.querySelector(".blocklist input");
}
get _rankInputNode() { get _rankInputNode() {
return this._formNode.querySelector("[name=rank]"); return this._formNode.querySelector("[name=rank]");
} }

View file

@ -67,6 +67,12 @@ webhooks:
default_rank: regular default_rank: regular
# default blocklisted tags (space separated)
default_tag_blocklist: ''
# Apply blocklist for anonymous viewers too
default_tag_blocklist_for_anonymous: yes
privileges: privileges:
'users:create:self': anonymous # Registration permission 'users:create:self': anonymous # Registration permission
'users:create:any': administrator 'users:create:any': administrator
@ -76,11 +82,13 @@ privileges:
'users:edit:any:pass': moderator 'users:edit:any:pass': moderator
'users:edit:any:email': moderator 'users:edit:any:email': moderator
'users:edit:any:avatar': moderator 'users:edit:any:avatar': moderator
'users:edit:any:blocklist': moderator
'users:edit:any:rank': moderator 'users:edit:any:rank': moderator
'users:edit:self:name': regular 'users:edit:self:name': regular
'users:edit:self:pass': regular 'users:edit:self:pass': regular
'users:edit:self:email': regular 'users:edit:self:email': regular
'users:edit:self:avatar': regular 'users:edit:self:avatar': regular
'users:edit:self:blocklist': regular
'users:edit:self:rank': moderator # one can't promote themselves or anyone to upper rank than their own. 'users:edit:self:rank': moderator # one can't promote themselves or anyone to upper rank than their own.
'users:delete:any': administrator 'users:delete:any': administrator
'users:delete:self': regular 'users:delete:self': regular

View file

@ -43,6 +43,8 @@ def get_info(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
"tagNameRegex": config.config["tag_name_regex"], "tagNameRegex": config.config["tag_name_regex"],
"tagCategoryNameRegex": config.config["tag_category_name_regex"], "tagCategoryNameRegex": config.config["tag_category_name_regex"],
"defaultUserRank": config.config["default_rank"], "defaultUserRank": config.config["default_rank"],
"defaultTagBlocklist": config.config["default_tag_blocklist"],
"defaultTagBlocklistForAnonymous": config.config["default_tag_blocklist_for_anonymous"],
"enableSafety": config.config["enable_safety"], "enableSafety": config.config["enable_safety"],
"contactEmail": config.config["contact_email"], "contactEmail": config.config["contact_email"],
"canSendMails": bool(config.config["smtp"]["host"]), "canSendMails": bool(config.config["smtp"]["host"]),

View file

@ -2,7 +2,7 @@ from datetime import datetime
from typing import Dict, List, Optional from typing import Dict, List, Optional
from szurubooru import db, model, rest, search from szurubooru import db, model, rest, search
from szurubooru.func import auth, serialization, snapshots, tags, versions from szurubooru.func import auth, serialization, snapshots, tags, versions, users
_search_executor = search.Executor(search.configs.TagSearchConfig()) _search_executor = search.Executor(search.configs.TagSearchConfig())

View file

@ -1,7 +1,7 @@
from typing import Any, Dict from typing import Any, Dict, List
from szurubooru import model, rest, search from szurubooru import db, model, rest, search
from szurubooru.func import auth, serialization, users, versions from szurubooru.func import auth, serialization, snapshots, users, versions, tags
_search_executor = search.Executor(search.configs.UserSearchConfig()) _search_executor = search.Executor(search.configs.UserSearchConfig())
@ -17,6 +17,18 @@ def _serialize(
) )
def _create_tag_if_needed(tag_names: List[str], user: model.User) -> None:
# Taken from tag_api.py
if not tag_names:
return
_existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
if len(new_tags):
auth.verify_privilege(user, "tags:create")
db.session.flush()
for tag in new_tags:
snapshots.create(tag, user)
@rest.routes.get("/users/?") @rest.routes.get("/users/?")
def get_users( def get_users(
ctx: rest.Context, _params: Dict[str, str] = {} ctx: rest.Context, _params: Dict[str, str] = {}
@ -50,6 +62,10 @@ def create_user(
) )
ctx.session.add(user) ctx.session.add(user)
ctx.session.commit() ctx.session.commit()
to_add, _ = users.update_user_blocklist(user, None)
for e in to_add:
ctx.session.add(e)
ctx.session.commit()
return _serialize(ctx, user, force_show_email=True) return _serialize(ctx, user, force_show_email=True)
@ -80,6 +96,16 @@ def update_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
if ctx.has_param("rank"): if ctx.has_param("rank"):
auth.verify_privilege(ctx.user, "users:edit:%s:rank" % infix) auth.verify_privilege(ctx.user, "users:edit:%s:rank" % infix)
users.update_user_rank(user, ctx.get_param_as_string("rank"), ctx.user) users.update_user_rank(user, ctx.get_param_as_string("rank"), ctx.user)
if ctx.has_param("blocklist"):
auth.verify_privilege(ctx.user, "users:edit:%s:blocklist" % infix)
blocklist = ctx.get_param_as_string_list("blocklist")
_create_tag_if_needed(blocklist, user) # Non-existing tags are created.
blocklist_tags = tags.get_tags_by_names(blocklist)
to_add, to_remove = users.update_user_blocklist(user, blocklist_tags)
for e in to_remove:
ctx.session.delete(e)
for e in to_add:
ctx.session.add(e)
if ctx.has_param("avatarStyle"): if ctx.has_param("avatarStyle"):
auth.verify_privilege(ctx.user, "users:edit:%s:avatar" % infix) auth.verify_privilege(ctx.user, "users:edit:%s:avatar" % infix)
users.update_user_avatar( users.update_user_avatar(

View file

@ -159,6 +159,9 @@ def get_tag_by_name(name: str) -> model.Tag:
def get_tags_by_names(names: List[str]) -> List[model.Tag]: def get_tags_by_names(names: List[str]) -> List[model.Tag]:
"""
Returns a list of all tags which names include all the letters from the input list
"""
names = util.icase_unique(names) names = util.icase_unique(names)
if len(names) == 0: if len(names) == 0:
return [] return []
@ -175,6 +178,24 @@ def get_tags_by_names(names: List[str]) -> List[model.Tag]:
) )
def get_tags_by_exact_names(names: List[str]) -> List[model.Tag]:
"""
Returns a list of tags matching the names from the input list
"""
entries = []
if len(names) == 0:
return []
names = [name.lower() for name in names]
entries = (
db.session.query(model.Tag)
.join(model.TagName)
.filter(
sa.func.lower(model.TagName.name).in_(names)
)
.all())
return entries
def get_or_create_tags_by_names( def get_or_create_tags_by_names(
names: List[str], names: List[str],
) -> Tuple[List[model.Tag], List[model.Tag]]: ) -> Tuple[List[model.Tag], List[model.Tag]]:

View file

@ -1,3 +1,4 @@
import copy
import re import re
from datetime import datetime from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
@ -5,7 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
import sqlalchemy as sa import sqlalchemy as sa
from szurubooru import config, db, errors, model, rest from szurubooru import config, db, errors, model, rest
from szurubooru.func import auth, files, images, serialization, util from szurubooru.func import auth, files, images, serialization, util, tags
class UserNotFoundError(errors.NotFoundError): class UserNotFoundError(errors.NotFoundError):
@ -107,6 +108,7 @@ class UserSerializer(serialization.BaseSerializer):
"lastLoginTime": self.serialize_last_login_time, "lastLoginTime": self.serialize_last_login_time,
"version": self.serialize_version, "version": self.serialize_version,
"rank": self.serialize_rank, "rank": self.serialize_rank,
"blocklist": self.serialize_blocklist,
"avatarStyle": self.serialize_avatar_style, "avatarStyle": self.serialize_avatar_style,
"avatarUrl": self.serialize_avatar_url, "avatarUrl": self.serialize_avatar_url,
"commentCount": self.serialize_comment_count, "commentCount": self.serialize_comment_count,
@ -138,6 +140,9 @@ class UserSerializer(serialization.BaseSerializer):
def serialize_avatar_url(self) -> Any: def serialize_avatar_url(self) -> Any:
return get_avatar_url(self.user) return get_avatar_url(self.user)
def serialize_blocklist(self) -> Any:
return [tags.serialize_tag(tag) for tag in get_blocklist_tag_from_user(self.user)]
def serialize_comment_count(self) -> Any: def serialize_comment_count(self) -> Any:
return self.user.comment_count return self.user.comment_count
@ -294,6 +299,66 @@ def update_user_rank(
user.rank = rank user.rank = rank
def get_blocklist_from_user(user: model.User) -> List[model.UserTagBlocklist]:
"""
Return the UserTagBlocklist objects related to given user
"""
rez = (db.session.query(model.UserTagBlocklist)
.filter(
model.UserTagBlocklist.user_id == user.user_id
)
.all())
return rez
def get_blocklist_tag_from_user(user: model.User) -> List[model.UserTagBlocklist]:
"""
Return the Tags blocklisted by given user
"""
rez = (db.session.query(model.UserTagBlocklist.tag_id)
.filter(
model.UserTagBlocklist.user_id == user.user_id
))
rez2 = (db.session.query(model.Tag)
.filter(
model.Tag.tag_id.in_(rez)
).all())
return rez2
def update_user_blocklist(user: model.User, new_blocklist_tags: Optional[List[model.Tag]]) -> List[List[model.UserTagBlocklist]]:
"""
Modify blocklist for given user.
If new_blocklist_tags is None, set the blocklist to configured default tag blocklist.
"""
assert user
to_add: List[model.UserTagBlocklist] = []
to_remove: List[model.UserTagBlocklist] = []
if new_blocklist_tags is None: # We're creating the user, use default config blocklist
if 'default_tag_blocklist' in config.config.keys():
for e in tags.get_tags_by_exact_names(config.config['default_tag_blocklist'].split(' ')):
to_add.append(model.UserTagBlocklist(user_id=user.user_id, tag_id=e.tag_id))
else:
new_blocklist_ids: List[int] = [e.tag_id for e in new_blocklist_tags]
previous_blocklist_tags: List[model.Tag] = get_blocklist_from_user(user)
previous_blocklist_ids: List[int] = [e.tag_id for e in previous_blocklist_tags]
original_previous_blocklist_ids = copy.copy(previous_blocklist_ids)
## Remove tags no longer in the new list
for i in range(len(original_previous_blocklist_ids)):
old_tag_id = original_previous_blocklist_ids[i]
if old_tag_id not in new_blocklist_ids:
to_remove.append(previous_blocklist_tags[i])
previous_blocklist_ids.remove(old_tag_id)
## Add tags not yet in the original list
for new_tag_id in new_blocklist_ids:
if new_tag_id not in previous_blocklist_ids:
to_add.append(model.UserTagBlocklist(user_id=user.user_id, tag_id=new_tag_id))
return to_add, to_remove
def update_user_avatar( def update_user_avatar(
user: model.User, avatar_style: str, avatar_content: Optional[bytes] = None user: model.User, avatar_style: str, avatar_content: Optional[bytes] = None
) -> None: ) -> None:

View file

@ -0,0 +1,30 @@
'''
Add blocklist related fields
add_blocklist
Revision ID: 9ba5e3a6ee7c
Created at: 2023-05-20 22:28:10.824954
'''
import sqlalchemy as sa
from alembic import op
revision = '9ba5e3a6ee7c'
down_revision = 'adcd63ff76a2'
branch_labels = None
depends_on = None
def upgrade():
op.create_table(
"user_tag_blocklist",
sa.Column("user_id", sa.Integer(), nullable=False),
sa.Column("tag_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
sa.ForeignKeyConstraint(["tag_id"], ["tag.id"]),
sa.PrimaryKeyConstraint("user_id", "tag_id"),
)
def downgrade():
op.drop_table('user_tag_blocklist')

View file

@ -16,4 +16,4 @@ from szurubooru.model.post import (
from szurubooru.model.snapshot import Snapshot from szurubooru.model.snapshot import Snapshot
from szurubooru.model.tag import Tag, TagImplication, TagName, TagSuggestion from szurubooru.model.tag import Tag, TagImplication, TagName, TagSuggestion
from szurubooru.model.tag_category import TagCategory from szurubooru.model.tag_category import TagCategory
from szurubooru.model.user import User, UserToken from szurubooru.model.user import UserTagBlocklist, User, UserToken

View file

@ -5,6 +5,46 @@ from szurubooru.model.comment import Comment
from szurubooru.model.post import Post, PostFavorite, PostScore from szurubooru.model.post import Post, PostFavorite, PostScore
class UserTagBlocklist(Base):
__tablename__ = "user_tag_blocklist"
user_id = sa.Column(
"user_id",
sa.Integer,
sa.ForeignKey("user.id"),
primary_key=True,
nullable=False,
index=True,
)
tag_id = sa.Column(
"tag_id",
sa.Integer,
sa.ForeignKey("tag.id"),
primary_key=True,
nullable=False,
index=True,
)
tag = sa.orm.relationship(
"Tag",
backref=sa.orm.backref("user_tag_blocklist", cascade="all, delete-orphan"),
)
user = sa.orm.relationship(
"User",
backref=sa.orm.backref("user_tag_blocklist", cascade="all, delete-orphan"),
)
def __init__(self, user_id: int=None, tag_id: int=None, user=None, tag=None) -> None:
if user_id is not None:
self.user_id = user_id
if tag_id is not None:
self.tag_id = tag_id
if user is not None:
self.user = user
if tag is not None:
self.tag = tag
class User(Base): class User(Base):
__tablename__ = "user" __tablename__ = "user"
@ -35,6 +75,7 @@ class User(Base):
"avatar_style", sa.Unicode(32), nullable=False, default=AVATAR_GRAVATAR "avatar_style", sa.Unicode(32), nullable=False, default=AVATAR_GRAVATAR
) )
blocklist = sa.orm.relationship("UserTagBlocklist")
comments = sa.orm.relationship("Comment") comments = sa.orm.relationship("Comment")
@property @property

View file

@ -2,9 +2,9 @@ from typing import Any, Dict, Optional, Tuple
import sqlalchemy as sa import sqlalchemy as sa
from szurubooru import db, errors, model from szurubooru import config, db, errors, model
from szurubooru.func import util from szurubooru.func import tags, users, util
from szurubooru.search import criteria, tokens from szurubooru.search import criteria, parser, tokens
from szurubooru.search.configs import util as search_util from szurubooru.search.configs import util as search_util
from szurubooru.search.configs.base_search_config import ( from szurubooru.search.configs.base_search_config import (
BaseSearchConfig, BaseSearchConfig,
@ -178,6 +178,32 @@ class PostSearchConfig(BaseSearchConfig):
new_special_tokens.append(token) new_special_tokens.append(token)
search_query.special_tokens = new_special_tokens search_query.special_tokens = new_special_tokens
blocklist_to_use = ""
if self.user: # Ensure there's a user object
if (self.user.rank == model.User.RANK_ANONYMOUS) and config.config["default_tag_blocklist_for_anonymous"]:
# Anonymous user, if configured to use default blocklist, do so
blocklist_to_use = config.config["default_tag_blocklist"]
else:
# Registered user, use their blocklist
user_blocklist_tags = users.get_blocklist_tag_from_user(self.user)
if user_blocklist_tags:
user_blocklist = db.session.query(model.Tag.first_name).filter(
model.Tag.tag_id.in_([e.tag_id for e in user_blocklist_tags])
).all()
blocklist_to_use = [e[0] for e in user_blocklist]
blocklist_to_use = " ".join(blocklist_to_use)
if len(blocklist_to_use) > 0:
# TODO Sort an already parsed and checked version instead?
blocklist_query = parser.Parser().parse(blocklist_to_use)
search_query_orig_list = [e.criterion.original_text for e in search_query.anonymous_tokens]
for t in blocklist_query.anonymous_tokens:
if t.criterion.original_text in search_query_orig_list:
continue
t.negated = True
search_query.anonymous_tokens.append(t)
def create_around_query(self) -> SaQuery: def create_around_query(self) -> SaQuery:
return db.session.query(model.Post).options(sa.orm.lazyload("*")) return db.session.query(model.Post).options(sa.orm.lazyload("*"))

View file

@ -26,6 +26,8 @@ def test_info_api(
"tag_name_regex": "3", "tag_name_regex": "3",
"tag_category_name_regex": "4", "tag_category_name_regex": "4",
"default_rank": "5", "default_rank": "5",
"default_tag_blocklist": "testTag",
"default_tag_blocklist_for_anonymous": True,
"privileges": { "privileges": {
"test_key1": "test_value1", "test_key1": "test_value1",
"test_key2": "test_value2", "test_key2": "test_value2",
@ -48,6 +50,8 @@ def test_info_api(
"tagNameRegex": "3", "tagNameRegex": "3",
"tagCategoryNameRegex": "4", "tagCategoryNameRegex": "4",
"defaultUserRank": "5", "defaultUserRank": "5",
"defaultTagBlocklist": "testTag",
"defaultTagBlocklistForAnonymous": True,
"privileges": { "privileges": {
"testKey1": "test_value1", "testKey1": "test_value1",
"testKey2": "test_value2", "testKey2": "test_value2",

View file

@ -0,0 +1,139 @@
from datetime import datetime
from unittest.mock import patch
import pytest
from szurubooru import api, db, errors, model
from szurubooru.func import posts
## TODO: Add following tests:
## - Retrieve posts without blocklist active for current registered user
## - Retrieve posts with blocklist active for current registered user
## - Retrieve posts without blocklist active for anonymous user
## - Retrieve posts with blocklist active for anonymous user
## - Creation of user with default blocklist (test that user_blocklist entries are properly added to db, with right infos)
## - Modification of user with/without blocklist changes
## - Retrieve posts with a query including a blocklisted tag (it should include results with the tag)
## - Behavior when creating user with default blocklist and tags from this list don't exist (blocklist entry shouldn't be added)
## - Test all small functions used across blocklist features
def test_blocklist(user_factory, post_factory, context_factory, config_injector, user_blocklist_factory, tag_factory):
"""
Test that user blocklist is applied on post retrieval
"""
tag1 = tag_factory(names=['tag1'])
tag2 = tag_factory(names=['tag2'])
tag3 = tag_factory(names=['tag3'])
post1 = post_factory(id=11, tags=[tag1, tag2])
post2 = post_factory(id=12, tags=[tag1])
post3 = post_factory(id=13, tags=[tag2])
post4 = post_factory(id=14, tags=[tag3])
post5 = post_factory(id=15)
user1 = user_factory(rank=model.User.RANK_REGULAR)
blocklist1 = user_blocklist_factory(tag=tag1, user=user1)
config_injector({
"privileges": {
"posts:list": model.User.RANK_REGULAR,
}
})
db.session.add_all([tag1, tag2, tag3, user1, blocklist1, post1, post2, post3, post4, post5])
db.session.flush()
# We can't check that the posts we retrieve are the ones we want
with patch("szurubooru.func.posts.serialize_post"):
posts.serialize_post.side_effect = (
lambda post, *_args, **_kwargs: "serialized post %d" % post.post_id
)
result = api.post_api.get_posts(
context_factory(
params={"query": "", "offset": 0},
user=user1,
)
)
assert result == {
"query": "",
"offset": 0,
"limit": 100,
"total": 3,
"results": ["serialized post 15", "serialized post 14", "serialized post 13"],
}
# def test_blocklist_no_anonymous(user_factory, post_factory, context_factory, config_injector, tag_factory):
# """
# Test that default blocklist isn't applied on anonymous users on post retrieval if disabled in configuration
# """
# tag1 = tag_factory(names=['tag1'])
# post1 = post_factory(id=21, tags=[tag1])
# post2 = post_factory(id=22, tags=[tag1])
# post3 = post_factory(id=23)
# user1 = user_factory(rank=model.User.RANK_ANONYMOUS)
# config_injector({
# "default_tag_blocklist": "tag1",
# "default_tag_blocklist_for_anonymous": False,
# "privileges": {
# "posts:list": model.User.RANK_ANONYMOUS,
# }
# })
# db.session.add_all([tag1, post1, post2, post3])
# db.session.flush()
# with patch("szurubooru.func.posts.serialize_post"):
# posts.serialize_post.side_effect = (
# lambda post, *_args, **_kwargs: "serialized post %d" % post.post_id
# )
# result = api.post_api.get_posts(
# context_factory(
# params={"query": "", "offset": 0},
# user=user1,
# )
# )
# assert result == {
# "query": "",
# "offset": 0,
# "limit": 100,
# "total": 3,
# "results": ["serialized post 23", "serialized post 22", "serialized post 21"],
# }
def test_blocklist_anonymous(user_factory, post_factory, context_factory, config_injector, tag_factory):
"""
Test that default blocklist is applied on anonymous users on post retrieval if enabled in configuration
"""
tag1 = tag_factory(names=['tag1'])
tag2 = tag_factory(names=['tag2'])
tag3 = tag_factory(names=['tag3'])
post1 = post_factory(id=31, tags=[tag1, tag2])
post2 = post_factory(id=32, tags=[tag1])
post3 = post_factory(id=33, tags=[tag2])
post4 = post_factory(id=34, tags=[tag3])
post5 = post_factory(id=35)
config_injector({
"default_tag_blocklist": "tag3",
"default_tag_blocklist_for_anonymous": True,
"privileges": {
"posts:list": model.User.RANK_ANONYMOUS,
}
})
db.session.add_all([tag1, tag2, tag3, post1, post2, post3, post4, post5])
db.session.flush()
with patch("szurubooru.func.posts.serialize_post"):
posts.serialize_post.side_effect = (
lambda post, *_args, **_kwargs: "serialized post %d" % post.post_id
)
result = api.post_api.get_posts(
context_factory(
params={"query": "", "offset": 0},
user=user_factory(rank=model.User.RANK_ANONYMOUS),
)
)
assert result == {
"query": "",
"offset": 0,
"limit": 100,
"total": 4,
"results": ["serialized post 35", "serialized post 33", "serialized post 32", "serialized post 31"],
}
## TODO: Test when we add blocklist items to the query

View file

@ -21,6 +21,8 @@ def test_creating_user(user_factory, context_factory, fake_datetime):
"szurubooru.func.users.update_user_rank" "szurubooru.func.users.update_user_rank"
), patch( ), patch(
"szurubooru.func.users.update_user_avatar" "szurubooru.func.users.update_user_avatar"
), patch(
"szurubooru.func.users.update_user_blocklist"
), patch( ), patch(
"szurubooru.func.users.serialize_user" "szurubooru.func.users.serialize_user"
), fake_datetime( ), fake_datetime(
@ -28,6 +30,7 @@ def test_creating_user(user_factory, context_factory, fake_datetime):
): ):
users.serialize_user.return_value = "serialized user" users.serialize_user.return_value = "serialized user"
users.create_user.return_value = user users.create_user.return_value = user
users.update_user_blocklist.return_value = ([],[])
result = api.user_api.create_user( result = api.user_api.create_user(
context_factory( context_factory(
params={ params={
@ -50,6 +53,7 @@ def test_creating_user(user_factory, context_factory, fake_datetime):
assert not users.update_user_email.called assert not users.update_user_email.called
users.update_user_rank.called_once_with(user, "moderator") users.update_user_rank.called_once_with(user, "moderator")
users.update_user_avatar.called_once_with(user, "manual", b"...") users.update_user_avatar.called_once_with(user, "manual", b"...")
users.update_user_blocklist.called_once_with(user, None)
@pytest.mark.parametrize("field", ["name", "password"]) @pytest.mark.parametrize("field", ["name", "password"])

View file

@ -14,11 +14,13 @@ def inject_config(config_injector):
"users:edit:self:name": model.User.RANK_REGULAR, "users:edit:self:name": model.User.RANK_REGULAR,
"users:edit:self:pass": model.User.RANK_REGULAR, "users:edit:self:pass": model.User.RANK_REGULAR,
"users:edit:self:email": model.User.RANK_REGULAR, "users:edit:self:email": model.User.RANK_REGULAR,
"users:edit:self:blocklist": model.User.RANK_REGULAR,
"users:edit:self:rank": model.User.RANK_MODERATOR, "users:edit:self:rank": model.User.RANK_MODERATOR,
"users:edit:self:avatar": model.User.RANK_MODERATOR, "users:edit:self:avatar": model.User.RANK_MODERATOR,
"users:edit:any:name": model.User.RANK_MODERATOR, "users:edit:any:name": model.User.RANK_MODERATOR,
"users:edit:any:pass": model.User.RANK_MODERATOR, "users:edit:any:pass": model.User.RANK_MODERATOR,
"users:edit:any:email": model.User.RANK_MODERATOR, "users:edit:any:email": model.User.RANK_MODERATOR,
"users:edit:any:blocklist": model.User.RANK_MODERATOR,
"users:edit:any:rank": model.User.RANK_ADMINISTRATOR, "users:edit:any:rank": model.User.RANK_ADMINISTRATOR,
"users:edit:any:avatar": model.User.RANK_ADMINISTRATOR, "users:edit:any:avatar": model.User.RANK_ADMINISTRATOR,
}, },

View file

@ -136,6 +136,20 @@ def user_token_factory(user_factory):
return factory return factory
@pytest.fixture
def user_blocklist_factory(user_factory, tag_factory):
def factory(tag=None, user=None):
if user is None:
user = user_factory()
if tag is None:
tag = tag_factory()
return model.UserTagBlocklist(
tag=tag, user=user
)
return factory
@pytest.fixture @pytest.fixture
def tag_category_factory(): def tag_category_factory():
def factory(name=None, color="dummy", order=1, default=False): def factory(name=None, color="dummy", order=1, default=False):
@ -172,6 +186,7 @@ def post_factory():
id=None, id=None,
safety=model.Post.SAFETY_SAFE, safety=model.Post.SAFETY_SAFE,
type=model.Post.TYPE_IMAGE, type=model.Post.TYPE_IMAGE,
tags=[],
checksum="...", checksum="...",
): ):
post = model.Post() post = model.Post()
@ -182,6 +197,7 @@ def post_factory():
post.flags = [] post.flags = []
post.mime_type = "application/octet-stream" post.mime_type = "application/octet-stream"
post.creation_time = datetime(1996, 1, 1) post.creation_time = datetime(1996, 1, 1)
post.tags = tags
return post return post
return factory return factory

View file

@ -158,6 +158,7 @@ def test_serialize_user(user_factory):
"avatarUrl": "https://example.com/avatar.png", "avatarUrl": "https://example.com/avatar.png",
"likedPostCount": 66, "likedPostCount": 66,
"dislikedPostCount": 33, "dislikedPostCount": 33,
"blocklist": [],
"commentCount": 0, "commentCount": 0,
"favoritePostCount": 0, "favoritePostCount": 0,
"uploadedPostCount": 0, "uploadedPostCount": 0,
@ -235,7 +236,7 @@ def test_create_user_for_first_user(fake_datetime):
"szurubooru.func.users.update_user_password" "szurubooru.func.users.update_user_password"
), patch("szurubooru.func.users.update_user_email"), fake_datetime( ), patch("szurubooru.func.users.update_user_email"), fake_datetime(
"1997-01-01" "1997-01-01"
): ), patch("szurubooru.func.users.update_user_blocklist"):
user = users.create_user("name", "password", "email") user = users.create_user("name", "password", "email")
assert user.creation_time == datetime(1997, 1, 1) assert user.creation_time == datetime(1997, 1, 1)
assert user.last_login_time is None assert user.last_login_time is None
@ -251,7 +252,8 @@ def test_create_user_for_subsequent_users(user_factory, config_injector):
db.session.flush() db.session.flush()
with patch("szurubooru.func.users.update_user_name"), patch( with patch("szurubooru.func.users.update_user_name"), patch(
"szurubooru.func.users.update_user_email" "szurubooru.func.users.update_user_email"
), patch("szurubooru.func.users.update_user_password"): ), patch("szurubooru.func.users.update_user_password"
), patch("szurubooru.func.users.update_user_blocklist"):
user = users.create_user("name", "password", "email") user = users.create_user("name", "password", "email")
assert user.rank == model.User.RANK_REGULAR assert user.rank == model.User.RANK_REGULAR