server/general: add assertions

This commit is contained in:
rr- 2016-08-14 10:45:00 +02:00
parent bb86e9bf56
commit c2bbf7b62c
10 changed files with 77 additions and 0 deletions

View file

@ -40,6 +40,7 @@ def create_password():
return ''.join(random.choice(alphabet[l]) for l in list(pattern))
def is_valid_password(user, password):
assert user
salt, valid_hash = user.password_salt, user.password_hash
possible_hashes = [
get_password_hash(salt, password),
@ -48,6 +49,7 @@ def is_valid_password(user, password):
return valid_hash in possible_hashes
def has_privilege(user, privilege_name):
assert user
all_ranks = list(RANK_MAP.keys())
assert privilege_name in config.config['privileges']
assert user.rank in all_ranks
@ -57,11 +59,13 @@ def has_privilege(user, privilege_name):
return user.rank in good_ranks
def verify_privilege(user, privilege_name):
assert user
if not has_privilege(user, privilege_name):
raise errors.AuthError('Insufficient privileges to do this.')
def generate_authentication_token(user):
''' Generate nonguessable challenge (e.g. links in password reminder). '''
assert user
digest = hashlib.md5()
digest.update(config.config['secret'].encode('utf8'))
digest.update(user.password_salt.encode('utf8'))

View file

@ -42,6 +42,7 @@ def create_comment(user, post, text):
return comment
def update_comment_text(comment, text):
assert comment
if not text:
raise EmptyCommentTextError('Comment text cannot be empty.')
comment.text = text

View file

@ -5,23 +5,32 @@ from szurubooru.func import scores
class InvalidFavoriteTargetError(errors.ValidationError): pass
def _get_table_info(entity):
assert entity
resource_type, _, _ = db.util.get_resource_info(entity)
if resource_type == 'post':
return db.PostFavorite, lambda table: table.post_id
raise InvalidFavoriteTargetError()
def _get_fav_entity(entity, user):
assert entity
assert user
return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
def has_favorited(entity, user):
assert entity
assert user
return _get_fav_entity(entity, user) is not None
def unset_favorite(entity, user):
assert entity
assert user
fav_entity = _get_fav_entity(entity, user)
if fav_entity:
db.session.delete(fav_entity)
def set_favorite(entity, user):
assert entity
assert user
try:
scores.set_score(entity, user, 1)
except scores.InvalidScoreTargetError:

View file

@ -1,6 +1,7 @@
import urllib.request
def download(url):
assert url
request = urllib.request.Request(url)
request.add_header('Referer', url)
with urllib.request.urlopen(request) as handle:

View file

@ -36,27 +36,33 @@ FLAG_MAP = {
}
def get_post_content_url(post):
assert post
return '%s/posts/%d.%s' % (
config.config['data_url'].rstrip('/'),
post.post_id,
mime.get_extension(post.mime_type) or 'dat')
def get_post_thumbnail_url(post):
assert post
return '%s/generated-thumbnails/%d.jpg' % (
config.config['data_url'].rstrip('/'),
post.post_id)
def get_post_content_path(post):
assert post
return 'posts/%d.%s' % (
post.post_id, mime.get_extension(post.mime_type) or 'dat')
def get_post_thumbnail_path(post):
assert post
return 'generated-thumbnails/%d.jpg' % (post.post_id)
def get_post_thumbnail_backup_path(post):
assert post
return 'posts/custom-thumbnails/%d.dat' % (post.post_id)
def serialize_note(note):
assert note
return {
'polygon': note.polygon,
'text': note.text,
@ -175,6 +181,7 @@ def create_post(content, tag_names, user):
return (post, new_tags)
def update_post_safety(post, safety):
assert post
safety = util.flip(SAFETY_MAP).get(safety, None)
if not safety:
raise InvalidPostSafetyError(
@ -182,11 +189,13 @@ def update_post_safety(post, safety):
post.safety = safety
def update_post_source(post, source):
assert post
if util.value_exceeds_column_size(source, db.Post.source):
raise InvalidPostSourceError('Source is too long.')
post.source = source
def update_post_content(post, content):
assert post
if not content:
raise InvalidPostContentError('Post content missing.')
post.mime_type = mime.get_mime_type(content)
@ -227,6 +236,7 @@ def update_post_content(post, content):
update_post_thumbnail(post, content=None, do_delete=False)
def update_post_thumbnail(post, content=None, do_delete=True):
assert post
if not content:
content = files.get(get_post_content_path(post))
if do_delete:
@ -236,6 +246,7 @@ def update_post_thumbnail(post, content=None, do_delete=True):
generate_post_thumbnail(post)
def generate_post_thumbnail(post):
assert post
if files.has(get_post_thumbnail_backup_path(post)):
content = files.get(get_post_thumbnail_backup_path(post))
else:
@ -250,11 +261,13 @@ def generate_post_thumbnail(post):
files.save(get_post_thumbnail_path(post), EMPTY_PIXEL)
def update_post_tags(post, tag_names):
assert post
existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
post.tags = existing_tags + new_tags
return new_tags
def update_post_relations(post, new_post_ids):
assert post
old_posts = post.relations
old_post_ids = [p.post_id for p in old_posts]
new_posts = db.session \
@ -274,6 +287,7 @@ def update_post_relations(post, new_post_ids):
relation.relations.append(post)
def update_post_notes(post, notes):
assert post
post.notes = []
for note in notes:
for field in ('polygon', 'text'):
@ -309,6 +323,7 @@ def update_post_notes(post, notes):
db.PostNote(polygon=note['polygon'], text=str(note['text'])))
def update_post_flags(post, flags):
assert post
target_flags = []
for flag in flags:
flag = util.flip(FLAG_MAP).get(flag, None)
@ -319,6 +334,7 @@ def update_post_flags(post, flags):
post.flags = target_flags
def feature_post(post, user):
assert post
post_feature = db.PostFeature()
post_feature.time = datetime.datetime.utcnow()
post_feature.post = post
@ -326,4 +342,5 @@ def feature_post(post, user):
db.session.add(post_feature)
def delete(post):
assert post
db.session.delete(post)

View file

@ -6,6 +6,7 @@ class InvalidScoreTargetError(errors.ValidationError): pass
class InvalidScoreValueError(errors.ValidationError): pass
def _get_table_info(entity):
assert entity
resource_type, _, _ = db.util.get_resource_info(entity)
if resource_type == 'post':
return db.PostScore, lambda table: table.post_id
@ -14,14 +15,19 @@ def _get_table_info(entity):
raise InvalidScoreTargetError()
def _get_score_entity(entity, user):
assert user
return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
def delete_score(entity, user):
assert entity
assert user
score_entity = _get_score_entity(entity, user)
if score_entity:
db.session.delete(score_entity)
def get_score(entity, user):
assert entity
assert user
table, get_column = _get_table_info(entity)
row = db.session \
.query(table.score) \
@ -31,6 +37,8 @@ def get_score(entity, user):
return row[0] if row else 0
def set_score(entity, user, score):
assert entity
assert user
if not score:
delete_score(entity, user)
try:

View file

@ -40,6 +40,7 @@ serializers = {
}
def get_previous_snapshot(snapshot):
assert snapshot
return db.session \
.query(db.Snapshot) \
.filter(db.Snapshot.resource_type == snapshot.resource_type) \
@ -50,6 +51,7 @@ def get_previous_snapshot(snapshot):
.first()
def get_snapshots(entity):
assert entity
resource_type, resource_id, _ = db.util.get_resource_info(entity)
return db.session \
.query(db.Snapshot) \
@ -59,6 +61,7 @@ def get_snapshots(entity):
.all()
def serialize_snapshot(snapshot, earlier_snapshot=()):
assert snapshot
if earlier_snapshot is ():
earlier_snapshot = get_previous_snapshot(snapshot)
return {
@ -82,6 +85,8 @@ def get_serialized_history(entity):
return ret
def _save(operation, entity, auth_user):
assert operation
assert entity
resource_type, resource_id, resource_repr = db.util.get_resource_info(entity)
now = datetime.datetime.utcnow()
@ -115,10 +120,13 @@ def _save(operation, entity, auth_user):
db.session.add(snapshot)
def save_entity_creation(entity, auth_user):
assert entity
_save(db.Snapshot.OPERATION_CREATED, entity, auth_user)
def save_entity_modification(entity, auth_user):
assert entity
_save(db.Snapshot.OPERATION_MODIFIED, entity, auth_user)
def save_entity_deletion(entity, auth_user):
assert entity
_save(db.Snapshot.OPERATION_DELETED, entity, auth_user)

View file

@ -37,6 +37,7 @@ def create_category(name, color):
return category
def update_category_name(category, name):
assert category
if not name:
raise InvalidTagCategoryNameError('Name cannot be empty.')
expr = sqlalchemy.func.lower(db.TagCategory.name) == name.lower()
@ -52,6 +53,7 @@ def update_category_name(category, name):
category.name = name
def update_category_color(category, color):
assert category
if not color:
raise InvalidTagCategoryColorError('Color cannot be empty.')
if not re.match(r'^#?[0-9a-z]+$', color):
@ -103,6 +105,7 @@ def get_default_category():
return category
def set_default_category(category):
assert category
old_category = try_get_default_category()
if old_category:
old_category.default = False

View file

@ -20,6 +20,7 @@ def _verify_name_validity(name):
raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex)
def _get_plain_names(tag):
assert tag
return [tag_name.name for tag_name in tag.names]
def _lower_list(names):
@ -157,6 +158,7 @@ def get_or_create_tags_by_names(names):
return existing_tags, new_tags
def get_tag_siblings(tag):
assert tag
tag_alias = sqlalchemy.orm.aliased(db.Tag)
pt_alias1 = sqlalchemy.orm.aliased(db.PostTag)
pt_alias2 = sqlalchemy.orm.aliased(db.PostTag)
@ -172,6 +174,7 @@ def get_tag_siblings(tag):
return result
def delete(source_tag):
assert source_tag
db.session.execute(
sqlalchemy.sql.expression.delete(db.TagSuggestion) \
.where(db.TagSuggestion.child_id == source_tag.tag_id))
@ -181,6 +184,8 @@ def delete(source_tag):
db.session.delete(source_tag)
def merge_tags(source_tag, target_tag):
assert source_tag
assert target_tag
if source_tag.tag_id == target_tag.tag_id:
raise InvalidTagRelationError('Cannot merge tag with itself.')
pt1 = db.PostTag
@ -205,9 +210,11 @@ def create_tag(names, category_name, suggestions, implications):
return tag
def update_tag_category_name(tag, category_name):
assert tag
tag.category = tag_categories.get_category_by_name(category_name)
def update_tag_names(tag, names):
assert tag
names = util.icase_unique([name for name in names if name])
if not len(names):
raise InvalidTagNameError('At least one name must be specified.')
@ -232,16 +239,19 @@ def update_tag_names(tag, names):
tag.names.append(db.TagName(name))
def update_tag_implications(tag, relations):
assert tag
if _check_name_intersection(_get_plain_names(tag), relations):
raise InvalidTagRelationError('Tag cannot imply itself.')
tag.implications = get_tags_by_names(relations)
def update_tag_suggestions(tag, relations):
assert tag
if _check_name_intersection(_get_plain_names(tag), relations):
raise InvalidTagRelationError('Tag cannot suggest itself.')
tag.suggestions = get_tags_by_names(relations)
def update_tag_description(tag, description):
assert tag
if util.value_exceeds_column_size(description, db.Tag.description):
raise InvalidTagDescriptionError('Description is too long.')
tag.description = description

View file

@ -16,15 +16,20 @@ def _get_avatar_path(name):
return 'avatars/' + name.lower() + '.png'
def _get_avatar_url(user):
assert user
if user.avatar_style == user.AVATAR_GRAVATAR:
assert user.email or user.name
return 'https://gravatar.com/avatar/%s?d=retro&s=%d' % (
util.get_md5((user.email or user.name).lower()),
config.config['thumbnails']['avatar_width'])
else:
assert user.name
return '%s/avatars/%s.png' % (
config.config['data_url'].rstrip('/'), user.name.lower())
def _get_email(user, authenticated_user, force_show_email):
assert user
assert authenticated_user
if not force_show_email \
and authenticated_user.user_id != user.user_id \
and not auth.has_privilege(authenticated_user, 'users:edit:any:email'):
@ -32,11 +37,15 @@ def _get_email(user, authenticated_user, force_show_email):
return user.email
def _get_liked_post_count(user, authenticated_user):
assert user
assert authenticated_user
if authenticated_user.user_id != user.user_id:
return False
return user.liked_post_count
def _get_disliked_post_count(user, authenticated_user):
assert user
assert authenticated_user
if authenticated_user.user_id != user.user_id:
return False
return user.disliked_post_count
@ -113,6 +122,7 @@ def create_user(name, password, email):
return user
def update_user_name(user, name):
assert user
if not name:
raise InvalidUserNameError('Name cannot be empty.')
if util.value_exceeds_column_size(name, db.User.name):
@ -130,6 +140,7 @@ def update_user_name(user, name):
user.name = name
def update_user_password(user, password):
assert user
if not password:
raise InvalidPasswordError('Password cannot be empty.')
password_regex = config.config['password_regex']
@ -140,6 +151,7 @@ def update_user_password(user, password):
user.password_hash = auth.get_password_hash(user.password_salt, password)
def update_user_email(user, email):
assert user
if email:
email = email.strip()
if not email:
@ -151,6 +163,7 @@ def update_user_email(user, email):
user.email = email
def update_user_rank(user, rank, authenticated_user):
assert user
if not rank:
raise InvalidRankError('Rank cannot be empty.')
rank = util.flip(auth.RANK_MAP).get(rank.strip(), None)
@ -166,6 +179,7 @@ def update_user_rank(user, rank, authenticated_user):
user.rank = rank
def update_user_avatar(user, avatar_style, avatar_content):
assert user
if avatar_style == 'gravatar':
user.avatar_style = user.AVATAR_GRAVATAR
elif avatar_style == 'manual':
@ -186,9 +200,11 @@ def update_user_avatar(user, avatar_style, avatar_content):
avatar_style, ['gravatar', 'manual']))
def bump_user_login_time(user):
assert user
user.last_login_time = datetime.datetime.utcnow()
def reset_user_password(user):
assert user
password = auth.create_password()
user.password_salt = auth.create_password()
user.password_hash = auth.get_password_hash(user.password_salt, password)