From c2bbf7b62ca58cf6360d6c3bad67fe2f4e02f40f Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 14 Aug 2016 10:45:00 +0200 Subject: [PATCH] server/general: add assertions --- server/szurubooru/func/auth.py | 4 ++++ server/szurubooru/func/comments.py | 1 + server/szurubooru/func/favorites.py | 9 +++++++++ server/szurubooru/func/net.py | 1 + server/szurubooru/func/posts.py | 17 +++++++++++++++++ server/szurubooru/func/scores.py | 8 ++++++++ server/szurubooru/func/snapshots.py | 8 ++++++++ server/szurubooru/func/tag_categories.py | 3 +++ server/szurubooru/func/tags.py | 10 ++++++++++ server/szurubooru/func/users.py | 16 ++++++++++++++++ 10 files changed, 77 insertions(+) diff --git a/server/szurubooru/func/auth.py b/server/szurubooru/func/auth.py index 95d7f16f..e41c5ee1 100644 --- a/server/szurubooru/func/auth.py +++ b/server/szurubooru/func/auth.py @@ -40,6 +40,7 @@ def create_password(): return ''.join(random.choice(alphabet[l]) for l in list(pattern)) def is_valid_password(user, password): + assert user salt, valid_hash = user.password_salt, user.password_hash possible_hashes = [ get_password_hash(salt, password), @@ -48,6 +49,7 @@ def is_valid_password(user, password): return valid_hash in possible_hashes def has_privilege(user, privilege_name): + assert user all_ranks = list(RANK_MAP.keys()) assert privilege_name in config.config['privileges'] assert user.rank in all_ranks @@ -57,11 +59,13 @@ def has_privilege(user, privilege_name): return user.rank in good_ranks def verify_privilege(user, privilege_name): + assert user if not has_privilege(user, privilege_name): raise errors.AuthError('Insufficient privileges to do this.') def generate_authentication_token(user): ''' Generate nonguessable challenge (e.g. links in password reminder). ''' + assert user digest = hashlib.md5() digest.update(config.config['secret'].encode('utf8')) digest.update(user.password_salt.encode('utf8')) diff --git a/server/szurubooru/func/comments.py b/server/szurubooru/func/comments.py index 65f9edf1..136c0c09 100644 --- a/server/szurubooru/func/comments.py +++ b/server/szurubooru/func/comments.py @@ -42,6 +42,7 @@ def create_comment(user, post, text): return comment def update_comment_text(comment, text): + assert comment if not text: raise EmptyCommentTextError('Comment text cannot be empty.') comment.text = text diff --git a/server/szurubooru/func/favorites.py b/server/szurubooru/func/favorites.py index 40406a6d..d9255765 100644 --- a/server/szurubooru/func/favorites.py +++ b/server/szurubooru/func/favorites.py @@ -5,23 +5,32 @@ from szurubooru.func import scores class InvalidFavoriteTargetError(errors.ValidationError): pass def _get_table_info(entity): + assert entity resource_type, _, _ = db.util.get_resource_info(entity) if resource_type == 'post': return db.PostFavorite, lambda table: table.post_id raise InvalidFavoriteTargetError() def _get_fav_entity(entity, user): + assert entity + assert user return db.util.get_aux_entity(db.session, _get_table_info, entity, user) def has_favorited(entity, user): + assert entity + assert user return _get_fav_entity(entity, user) is not None def unset_favorite(entity, user): + assert entity + assert user fav_entity = _get_fav_entity(entity, user) if fav_entity: db.session.delete(fav_entity) def set_favorite(entity, user): + assert entity + assert user try: scores.set_score(entity, user, 1) except scores.InvalidScoreTargetError: diff --git a/server/szurubooru/func/net.py b/server/szurubooru/func/net.py index 48a93522..c9d96089 100644 --- a/server/szurubooru/func/net.py +++ b/server/szurubooru/func/net.py @@ -1,6 +1,7 @@ import urllib.request def download(url): + assert url request = urllib.request.Request(url) request.add_header('Referer', url) with urllib.request.urlopen(request) as handle: diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index e321a81a..d8cc59bb 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -36,27 +36,33 @@ FLAG_MAP = { } def get_post_content_url(post): + assert post return '%s/posts/%d.%s' % ( config.config['data_url'].rstrip('/'), post.post_id, mime.get_extension(post.mime_type) or 'dat') def get_post_thumbnail_url(post): + assert post return '%s/generated-thumbnails/%d.jpg' % ( config.config['data_url'].rstrip('/'), post.post_id) def get_post_content_path(post): + assert post return 'posts/%d.%s' % ( post.post_id, mime.get_extension(post.mime_type) or 'dat') def get_post_thumbnail_path(post): + assert post return 'generated-thumbnails/%d.jpg' % (post.post_id) def get_post_thumbnail_backup_path(post): + assert post return 'posts/custom-thumbnails/%d.dat' % (post.post_id) def serialize_note(note): + assert note return { 'polygon': note.polygon, 'text': note.text, @@ -175,6 +181,7 @@ def create_post(content, tag_names, user): return (post, new_tags) def update_post_safety(post, safety): + assert post safety = util.flip(SAFETY_MAP).get(safety, None) if not safety: raise InvalidPostSafetyError( @@ -182,11 +189,13 @@ def update_post_safety(post, safety): post.safety = safety def update_post_source(post, source): + assert post if util.value_exceeds_column_size(source, db.Post.source): raise InvalidPostSourceError('Source is too long.') post.source = source def update_post_content(post, content): + assert post if not content: raise InvalidPostContentError('Post content missing.') post.mime_type = mime.get_mime_type(content) @@ -227,6 +236,7 @@ def update_post_content(post, content): update_post_thumbnail(post, content=None, do_delete=False) def update_post_thumbnail(post, content=None, do_delete=True): + assert post if not content: content = files.get(get_post_content_path(post)) if do_delete: @@ -236,6 +246,7 @@ def update_post_thumbnail(post, content=None, do_delete=True): generate_post_thumbnail(post) def generate_post_thumbnail(post): + assert post if files.has(get_post_thumbnail_backup_path(post)): content = files.get(get_post_thumbnail_backup_path(post)) else: @@ -250,11 +261,13 @@ def generate_post_thumbnail(post): files.save(get_post_thumbnail_path(post), EMPTY_PIXEL) def update_post_tags(post, tag_names): + assert post existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names) post.tags = existing_tags + new_tags return new_tags def update_post_relations(post, new_post_ids): + assert post old_posts = post.relations old_post_ids = [p.post_id for p in old_posts] new_posts = db.session \ @@ -274,6 +287,7 @@ def update_post_relations(post, new_post_ids): relation.relations.append(post) def update_post_notes(post, notes): + assert post post.notes = [] for note in notes: for field in ('polygon', 'text'): @@ -309,6 +323,7 @@ def update_post_notes(post, notes): db.PostNote(polygon=note['polygon'], text=str(note['text']))) def update_post_flags(post, flags): + assert post target_flags = [] for flag in flags: flag = util.flip(FLAG_MAP).get(flag, None) @@ -319,6 +334,7 @@ def update_post_flags(post, flags): post.flags = target_flags def feature_post(post, user): + assert post post_feature = db.PostFeature() post_feature.time = datetime.datetime.utcnow() post_feature.post = post @@ -326,4 +342,5 @@ def feature_post(post, user): db.session.add(post_feature) def delete(post): + assert post db.session.delete(post) diff --git a/server/szurubooru/func/scores.py b/server/szurubooru/func/scores.py index 4ed86320..75c623ec 100644 --- a/server/szurubooru/func/scores.py +++ b/server/szurubooru/func/scores.py @@ -6,6 +6,7 @@ class InvalidScoreTargetError(errors.ValidationError): pass class InvalidScoreValueError(errors.ValidationError): pass def _get_table_info(entity): + assert entity resource_type, _, _ = db.util.get_resource_info(entity) if resource_type == 'post': return db.PostScore, lambda table: table.post_id @@ -14,14 +15,19 @@ def _get_table_info(entity): raise InvalidScoreTargetError() def _get_score_entity(entity, user): + assert user return db.util.get_aux_entity(db.session, _get_table_info, entity, user) def delete_score(entity, user): + assert entity + assert user score_entity = _get_score_entity(entity, user) if score_entity: db.session.delete(score_entity) def get_score(entity, user): + assert entity + assert user table, get_column = _get_table_info(entity) row = db.session \ .query(table.score) \ @@ -31,6 +37,8 @@ def get_score(entity, user): return row[0] if row else 0 def set_score(entity, user, score): + assert entity + assert user if not score: delete_score(entity, user) try: diff --git a/server/szurubooru/func/snapshots.py b/server/szurubooru/func/snapshots.py index c65e47b2..fdfa2b34 100644 --- a/server/szurubooru/func/snapshots.py +++ b/server/szurubooru/func/snapshots.py @@ -40,6 +40,7 @@ serializers = { } def get_previous_snapshot(snapshot): + assert snapshot return db.session \ .query(db.Snapshot) \ .filter(db.Snapshot.resource_type == snapshot.resource_type) \ @@ -50,6 +51,7 @@ def get_previous_snapshot(snapshot): .first() def get_snapshots(entity): + assert entity resource_type, resource_id, _ = db.util.get_resource_info(entity) return db.session \ .query(db.Snapshot) \ @@ -59,6 +61,7 @@ def get_snapshots(entity): .all() def serialize_snapshot(snapshot, earlier_snapshot=()): + assert snapshot if earlier_snapshot is (): earlier_snapshot = get_previous_snapshot(snapshot) return { @@ -82,6 +85,8 @@ def get_serialized_history(entity): return ret def _save(operation, entity, auth_user): + assert operation + assert entity resource_type, resource_id, resource_repr = db.util.get_resource_info(entity) now = datetime.datetime.utcnow() @@ -115,10 +120,13 @@ def _save(operation, entity, auth_user): db.session.add(snapshot) def save_entity_creation(entity, auth_user): + assert entity _save(db.Snapshot.OPERATION_CREATED, entity, auth_user) def save_entity_modification(entity, auth_user): + assert entity _save(db.Snapshot.OPERATION_MODIFIED, entity, auth_user) def save_entity_deletion(entity, auth_user): + assert entity _save(db.Snapshot.OPERATION_DELETED, entity, auth_user) diff --git a/server/szurubooru/func/tag_categories.py b/server/szurubooru/func/tag_categories.py index ea84b8a7..ddc4f895 100644 --- a/server/szurubooru/func/tag_categories.py +++ b/server/szurubooru/func/tag_categories.py @@ -37,6 +37,7 @@ def create_category(name, color): return category def update_category_name(category, name): + assert category if not name: raise InvalidTagCategoryNameError('Name cannot be empty.') expr = sqlalchemy.func.lower(db.TagCategory.name) == name.lower() @@ -52,6 +53,7 @@ def update_category_name(category, name): category.name = name def update_category_color(category, color): + assert category if not color: raise InvalidTagCategoryColorError('Color cannot be empty.') if not re.match(r'^#?[0-9a-z]+$', color): @@ -103,6 +105,7 @@ def get_default_category(): return category def set_default_category(category): + assert category old_category = try_get_default_category() if old_category: old_category.default = False diff --git a/server/szurubooru/func/tags.py b/server/szurubooru/func/tags.py index cdec0e6c..83579fd6 100644 --- a/server/szurubooru/func/tags.py +++ b/server/szurubooru/func/tags.py @@ -20,6 +20,7 @@ def _verify_name_validity(name): raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex) def _get_plain_names(tag): + assert tag return [tag_name.name for tag_name in tag.names] def _lower_list(names): @@ -157,6 +158,7 @@ def get_or_create_tags_by_names(names): return existing_tags, new_tags def get_tag_siblings(tag): + assert tag tag_alias = sqlalchemy.orm.aliased(db.Tag) pt_alias1 = sqlalchemy.orm.aliased(db.PostTag) pt_alias2 = sqlalchemy.orm.aliased(db.PostTag) @@ -172,6 +174,7 @@ def get_tag_siblings(tag): return result def delete(source_tag): + assert source_tag db.session.execute( sqlalchemy.sql.expression.delete(db.TagSuggestion) \ .where(db.TagSuggestion.child_id == source_tag.tag_id)) @@ -181,6 +184,8 @@ def delete(source_tag): db.session.delete(source_tag) def merge_tags(source_tag, target_tag): + assert source_tag + assert target_tag if source_tag.tag_id == target_tag.tag_id: raise InvalidTagRelationError('Cannot merge tag with itself.') pt1 = db.PostTag @@ -205,9 +210,11 @@ def create_tag(names, category_name, suggestions, implications): return tag def update_tag_category_name(tag, category_name): + assert tag tag.category = tag_categories.get_category_by_name(category_name) def update_tag_names(tag, names): + assert tag names = util.icase_unique([name for name in names if name]) if not len(names): raise InvalidTagNameError('At least one name must be specified.') @@ -232,16 +239,19 @@ def update_tag_names(tag, names): tag.names.append(db.TagName(name)) def update_tag_implications(tag, relations): + assert tag if _check_name_intersection(_get_plain_names(tag), relations): raise InvalidTagRelationError('Tag cannot imply itself.') tag.implications = get_tags_by_names(relations) def update_tag_suggestions(tag, relations): + assert tag if _check_name_intersection(_get_plain_names(tag), relations): raise InvalidTagRelationError('Tag cannot suggest itself.') tag.suggestions = get_tags_by_names(relations) def update_tag_description(tag, description): + assert tag if util.value_exceeds_column_size(description, db.Tag.description): raise InvalidTagDescriptionError('Description is too long.') tag.description = description diff --git a/server/szurubooru/func/users.py b/server/szurubooru/func/users.py index 6771b892..2f00a389 100644 --- a/server/szurubooru/func/users.py +++ b/server/szurubooru/func/users.py @@ -16,15 +16,20 @@ def _get_avatar_path(name): return 'avatars/' + name.lower() + '.png' def _get_avatar_url(user): + assert user if user.avatar_style == user.AVATAR_GRAVATAR: + assert user.email or user.name return 'https://gravatar.com/avatar/%s?d=retro&s=%d' % ( util.get_md5((user.email or user.name).lower()), config.config['thumbnails']['avatar_width']) else: + assert user.name return '%s/avatars/%s.png' % ( config.config['data_url'].rstrip('/'), user.name.lower()) def _get_email(user, authenticated_user, force_show_email): + assert user + assert authenticated_user if not force_show_email \ and authenticated_user.user_id != user.user_id \ and not auth.has_privilege(authenticated_user, 'users:edit:any:email'): @@ -32,11 +37,15 @@ def _get_email(user, authenticated_user, force_show_email): return user.email def _get_liked_post_count(user, authenticated_user): + assert user + assert authenticated_user if authenticated_user.user_id != user.user_id: return False return user.liked_post_count def _get_disliked_post_count(user, authenticated_user): + assert user + assert authenticated_user if authenticated_user.user_id != user.user_id: return False return user.disliked_post_count @@ -113,6 +122,7 @@ def create_user(name, password, email): return user def update_user_name(user, name): + assert user if not name: raise InvalidUserNameError('Name cannot be empty.') if util.value_exceeds_column_size(name, db.User.name): @@ -130,6 +140,7 @@ def update_user_name(user, name): user.name = name def update_user_password(user, password): + assert user if not password: raise InvalidPasswordError('Password cannot be empty.') password_regex = config.config['password_regex'] @@ -140,6 +151,7 @@ def update_user_password(user, password): user.password_hash = auth.get_password_hash(user.password_salt, password) def update_user_email(user, email): + assert user if email: email = email.strip() if not email: @@ -151,6 +163,7 @@ def update_user_email(user, email): user.email = email def update_user_rank(user, rank, authenticated_user): + assert user if not rank: raise InvalidRankError('Rank cannot be empty.') rank = util.flip(auth.RANK_MAP).get(rank.strip(), None) @@ -166,6 +179,7 @@ def update_user_rank(user, rank, authenticated_user): user.rank = rank def update_user_avatar(user, avatar_style, avatar_content): + assert user if avatar_style == 'gravatar': user.avatar_style = user.AVATAR_GRAVATAR elif avatar_style == 'manual': @@ -186,9 +200,11 @@ def update_user_avatar(user, avatar_style, avatar_content): avatar_style, ['gravatar', 'manual'])) def bump_user_login_time(user): + assert user user.last_login_time = datetime.datetime.utcnow() def reset_user_password(user): + assert user password = auth.create_password() user.password_salt = auth.create_password() user.password_hash = auth.get_password_hash(user.password_salt, password)