diff --git a/server/szurubooru/func/auth.py b/server/szurubooru/func/auth.py index a71285ab..a9c40b0f 100644 --- a/server/szurubooru/func/auth.py +++ b/server/szurubooru/func/auth.py @@ -1,13 +1,13 @@ -import uuid - import hashlib import random from collections import OrderedDict from nacl.exceptions import InvalidkeyError -from nacl.pwhash import argon2id, verify from szurubooru import config, model, errors, db from szurubooru.func import util +from nacl.pwhash import argon2id, verify +import uuid + RANK_MAP = OrderedDict([ (model.User.RANK_ANONYMOUS, 'anonymous'), diff --git a/server/szurubooru/func/cache.py b/server/szurubooru/func/cache.py index 260dfbab..01e46592 100644 --- a/server/szurubooru/func/cache.py +++ b/server/szurubooru/func/cache.py @@ -1,5 +1,5 @@ -from datetime import datetime from typing import Any, List, Dict +from datetime import datetime class LruCacheItem: diff --git a/server/szurubooru/func/comments.py b/server/szurubooru/func/comments.py index 97fdf0b2..9f882831 100644 --- a/server/szurubooru/func/comments.py +++ b/server/szurubooru/func/comments.py @@ -1,6 +1,5 @@ from datetime import datetime from typing import Any, Optional, List, Dict, Callable - from szurubooru import db, model, errors, rest from szurubooru.func import users, scores, serialization @@ -76,9 +75,9 @@ def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]: comment_id = int(comment_id) return ( db.session - .query(model.Comment) - .filter(model.Comment.comment_id == comment_id) - .one_or_none()) + .query(model.Comment) + .filter(model.Comment.comment_id == comment_id) + .one_or_none()) def get_comment_by_id(comment_id: int) -> model.Comment: diff --git a/server/szurubooru/func/favorites.py b/server/szurubooru/func/favorites.py index 6e19a59b..f567bfad 100644 --- a/server/szurubooru/func/favorites.py +++ b/server/szurubooru/func/favorites.py @@ -1,6 +1,5 @@ -from datetime import datetime from typing import Any, Optional, Callable, Tuple - +from datetime import datetime from szurubooru import db, model, errors diff --git a/server/szurubooru/func/file_uploads.py b/server/szurubooru/func/file_uploads.py index ece9a252..e7f93d83 100644 --- a/server/szurubooru/func/file_uploads.py +++ b/server/szurubooru/func/file_uploads.py @@ -1,8 +1,8 @@ -from datetime import datetime, timedelta from typing import Optional - +from datetime import datetime, timedelta from szurubooru.func import files, util + MAX_MINUTES = 60 diff --git a/server/szurubooru/func/files.py b/server/szurubooru/func/files.py index 097ac309..fa9f36fd 100644 --- a/server/szurubooru/func/files.py +++ b/server/szurubooru/func/files.py @@ -1,7 +1,5 @@ from typing import Any, Optional, List - import os - from szurubooru import config diff --git a/server/szurubooru/func/image_hash.py b/server/szurubooru/func/image_hash.py index 456752bb..dae84435 100644 --- a/server/szurubooru/func/image_hash.py +++ b/server/szurubooru/func/image_hash.py @@ -1,14 +1,12 @@ import logging +from io import BytesIO from datetime import datetime from typing import Any, Optional, Tuple, Set, List, Callable - import elasticsearch import elasticsearch_dsl import numpy as np -from PIL import Image -from io import BytesIO from skimage.color import rgb2gray - +from PIL import Image from szurubooru import config, errors # pylint: disable=invalid-name @@ -135,7 +133,7 @@ def _compute_differentials(grey_level_matrix: NpMatrix) -> NpMatrix: np.diff(grey_level_matrix), ( np.zeros(grey_level_matrix.shape[0]) - .reshape((grey_level_matrix.shape[0], 1)) + .reshape((grey_level_matrix.shape[0], 1)) ) ), axis=1) down_neighbors = -np.concatenate( @@ -143,7 +141,7 @@ def _compute_differentials(grey_level_matrix: NpMatrix) -> NpMatrix: np.diff(grey_level_matrix, axis=0), ( np.zeros(grey_level_matrix.shape[1]) - .reshape((1, grey_level_matrix.shape[1])) + .reshape((1, grey_level_matrix.shape[1])) ) )) left_neighbors = -np.concatenate( @@ -209,7 +207,7 @@ def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix: def _words_to_int(word_array: NpMatrix) -> NpMatrix: width = word_array.shape[1] - coding_vector = 3 ** np.arange(width) + coding_vector = 3**np.arange(width) return np.dot(word_array + 1, coding_vector) @@ -249,9 +247,7 @@ def _safety_blanket(default_param_factory: Callable[[], Any]) -> Callable: raise errors.ProcessingError('Not an image.') except Exception as ex: raise errors.ThirdPartyError('Unknown error (%s).' % ex) - return wrapper_inner - return wrapper_outer @@ -353,5 +349,5 @@ def get_all_paths() -> Set[str]: using=_get_session(), index=config.config['elasticsearch']['index'], doc_type=ES_DOC_TYPE) - .source(['path'])) + .source(['path'])) return set(h.path for h in search.scan()) diff --git a/server/szurubooru/func/images.py b/server/szurubooru/func/images.py index 749d0898..0bf84ed5 100644 --- a/server/szurubooru/func/images.py +++ b/server/szurubooru/func/images.py @@ -1,15 +1,16 @@ -import json +from typing import List import logging -import math +import json import shlex import subprocess -from typing import List - +import math from szurubooru import errors from szurubooru.func import mime, util + logger = logging.getLogger(__name__) + _SCALE_FIT_FMT = ( r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)') diff --git a/server/szurubooru/func/mailer.py b/server/szurubooru/func/mailer.py index fbf96927..76682f11 100644 --- a/server/szurubooru/func/mailer.py +++ b/server/szurubooru/func/mailer.py @@ -1,6 +1,5 @@ -import email.mime.text import smtplib - +import email.mime.text from szurubooru import config diff --git a/server/szurubooru/func/mime.py b/server/szurubooru/func/mime.py index 327ad8b5..12e358c0 100644 --- a/server/szurubooru/func/mime.py +++ b/server/szurubooru/func/mime.py @@ -1,6 +1,5 @@ -from typing import Optional - import re +from typing import Optional APPLICATION_SWF = 'application/x-shockwave-flash' IMAGE_JPEG = 'image/jpeg' @@ -64,4 +63,4 @@ def is_image(mime_type: str) -> bool: def is_animated_gif(content: bytes) -> bool: pattern = b'\x21\xF9\x04[\x00-\xFF]{4}\x00[\x2C\x21]' return get_mime_type(content) == IMAGE_GIF \ - and len(re.findall(pattern, content)) > 1 + and len(re.findall(pattern, content)) > 1 diff --git a/server/szurubooru/func/net.py b/server/szurubooru/func/net.py index c8651bf5..e6326c06 100644 --- a/server/szurubooru/func/net.py +++ b/server/szurubooru/func/net.py @@ -1,6 +1,6 @@ import urllib.request - -from szurubooru import config, errors +from szurubooru import config +from szurubooru import errors def download(url: str) -> bytes: diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 70a87d21..4e524387 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -1,14 +1,13 @@ -from datetime import datetime -from typing import Any, Optional, Tuple, List, Dict, Callable - import hmac +from typing import Any, Optional, Tuple, List, Dict, Callable +from datetime import datetime import sqlalchemy as sa - from szurubooru import config, db, model, errors, rest from szurubooru.func import ( users, scores, comments, tags, util, mime, images, files, image_hash, serialization, snapshots) + EMPTY_PIXEL = ( b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' @@ -238,8 +237,8 @@ class PostSerializer(serialization.BaseSerializer): { post['id']: post for post in [ - serialize_micro_post(rel, self.auth_user) - for rel in self.post.relations] + serialize_micro_post(rel, self.auth_user) + for rel in self.post.relations] }.values(), key=lambda post: post['id']) @@ -323,9 +322,9 @@ def get_post_count() -> int: def try_get_post_by_id(post_id: int) -> Optional[model.Post]: return ( db.session - .query(model.Post) - .filter(model.Post.post_id == post_id) - .one_or_none()) + .query(model.Post) + .filter(model.Post.post_id == post_id) + .one_or_none()) def get_post_by_id(post_id: int) -> model.Post: @@ -338,9 +337,9 @@ def get_post_by_id(post_id: int) -> model.Post: def try_get_current_post_feature() -> Optional[model.PostFeature]: return ( db.session - .query(model.PostFeature) - .order_by(model.PostFeature.time.desc()) - .first()) + .query(model.PostFeature) + .order_by(model.PostFeature.time.desc()) + .first()) def try_get_featured_post() -> Optional[model.Post]: @@ -487,10 +486,10 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None: post.checksum = util.get_sha1(content) other_post = ( db.session - .query(model.Post) - .filter(model.Post.checksum == post.checksum) - .filter(model.Post.post_id != post.post_id) - .one_or_none()) + .query(model.Post) + .filter(model.Post.checksum == post.checksum) + .filter(model.Post.post_id != post.post_id) + .one_or_none()) if other_post \ and other_post.post_id \ and other_post.post_id != post.post_id: @@ -554,9 +553,9 @@ def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None: if new_post_ids: new_posts = ( db.session - .query(model.Post) - .filter(model.Post.post_id.in_(new_post_ids)) - .all()) + .query(model.Post) + .filter(model.Post.post_id.in_(new_post_ids)) + .all()) else: new_posts = [] if len(new_posts) != len(new_post_ids): @@ -655,13 +654,15 @@ def merge_posts( alias2 = sa.orm.util.aliased(table) update_stmt = ( sa.sql.expression.update(alias1) - .where(alias1.post_id == source_post_id)) + .where(alias1.post_id == source_post_id)) if anti_dup_func is not None: update_stmt = ( - update_stmt.where(~sa.exists() - .where(anti_dup_func(alias1, alias2)) - .where(alias2.post_id == target_post_id))) + update_stmt + .where( + ~sa.exists() + .where(anti_dup_func(alias1, alias2)) + .where(alias2.post_id == target_post_id))) update_stmt = update_stmt.values(post_id=target_post_id) db.session.execute(update_stmt) @@ -695,24 +696,24 @@ def merge_posts( alias2 = sa.orm.util.aliased(model.PostRelation) update_stmt = ( sa.sql.expression.update(alias1) - .where(alias1.parent_id == source_post_id) - .where(alias1.child_id != target_post_id) - .where( + .where(alias1.parent_id == source_post_id) + .where(alias1.child_id != target_post_id) + .where( ~sa.exists() - .where(alias2.child_id == alias1.child_id) - .where(alias2.parent_id == target_post_id)) - .values(parent_id=target_post_id)) + .where(alias2.child_id == alias1.child_id) + .where(alias2.parent_id == target_post_id)) + .values(parent_id=target_post_id)) db.session.execute(update_stmt) update_stmt = ( sa.sql.expression.update(alias1) - .where(alias1.child_id == source_post_id) - .where(alias1.parent_id != target_post_id) - .where( + .where(alias1.child_id == source_post_id) + .where(alias1.parent_id != target_post_id) + .where( ~sa.exists() - .where(alias2.parent_id == alias1.parent_id) - .where(alias2.child_id == target_post_id)) - .values(child_id=target_post_id)) + .where(alias2.parent_id == alias1.parent_id) + .where(alias2.child_id == target_post_id)) + .values(child_id=target_post_id)) db.session.execute(update_stmt) merge_tags(source_post.post_id, target_post.post_id) @@ -733,9 +734,10 @@ def merge_posts( def search_by_image_exact(image_content: bytes) -> Optional[model.Post]: checksum = util.get_sha1(image_content) return ( - db.session.query(model.Post) - .filter(model.Post.checksum == checksum) - .one_or_none()) + db.session + .query(model.Post) + .filter(model.Post.checksum == checksum) + .one_or_none()) def search_by_image(image_content: bytes) -> List[PostLookalike]: @@ -754,19 +756,21 @@ def populate_reverse_search() -> None: excluded_post_ids = image_hash.get_all_paths() post_ids_to_hash = ( - db.session.query(model.Post.post_id) - .filter( + db.session + .query(model.Post.post_id) + .filter( (model.Post.type == model.Post.TYPE_IMAGE) | (model.Post.type == model.Post.TYPE_ANIMATION)) - .filter(~model.Post.post_id.in_(excluded_post_ids)) - .order_by(model.Post.post_id.asc()) - .all()) + .filter(~model.Post.post_id.in_(excluded_post_ids)) + .order_by(model.Post.post_id.asc()) + .all()) for post_ids_chunk in util.chunks(post_ids_to_hash, 100): posts_chunk = ( - db.session.query(model.Post) - .filter(model.Post.post_id.in_(post_ids_chunk)) - .all()) + db.session + .query(model.Post) + .filter(model.Post.post_id.in_(post_ids_chunk)) + .all()) for post in posts_chunk: content_path = get_post_content_path(post) if files.has(content_path): diff --git a/server/szurubooru/func/scores.py b/server/szurubooru/func/scores.py index a35206bc..615fd981 100644 --- a/server/szurubooru/func/scores.py +++ b/server/szurubooru/func/scores.py @@ -1,6 +1,5 @@ import datetime from typing import Any, Tuple, Callable - from szurubooru import db, model, errors @@ -41,10 +40,11 @@ def get_score(entity: model.Base, user: model.User) -> int: assert user table, get_column = _get_table_info(entity) row = ( - db.session.query(table.score) - .filter(get_column(table) == get_column(entity)) - .filter(table.user_id == user.user_id) - .one_or_none()) + db.session + .query(table.score) + .filter(get_column(table) == get_column(entity)) + .filter(table.user_id == user.user_id) + .one_or_none()) return row[0] if row else 0 diff --git a/server/szurubooru/func/serialization.py b/server/szurubooru/func/serialization.py index 42a5413b..699fb473 100644 --- a/server/szurubooru/func/serialization.py +++ b/server/szurubooru/func/serialization.py @@ -1,5 +1,4 @@ from typing import Any, List, Dict, Callable - from szurubooru import model, rest, errors diff --git a/server/szurubooru/func/snapshots.py b/server/szurubooru/func/snapshots.py index dded0eda..240c3bce 100644 --- a/server/szurubooru/func/snapshots.py +++ b/server/szurubooru/func/snapshots.py @@ -1,6 +1,5 @@ -from datetime import datetime from typing import Any, Optional, Dict, Callable - +from datetime import datetime from szurubooru import db, model from szurubooru.func import diff, users @@ -97,7 +96,7 @@ def modify(entity: model.Base, auth_user: Optional[model.User]) -> None: cls for cls in model.Base._decl_class_registry.values() if hasattr(cls, '__table__') - and cls.__table__.fullname == entity.__table__.fullname + and cls.__table__.fullname == entity.__table__.fullname ), None) assert table diff --git a/server/szurubooru/func/tag_categories.py b/server/szurubooru/func/tag_categories.py index 3a6a3853..f1951a8c 100644 --- a/server/szurubooru/func/tag_categories.py +++ b/server/szurubooru/func/tag_categories.py @@ -1,11 +1,10 @@ -from typing import Any, Optional, Dict, List, Callable - import re +from typing import Any, Optional, Dict, List, Callable import sqlalchemy as sa - from szurubooru import config, db, model, errors, rest from szurubooru.func import util, serialization, cache + DEFAULT_CATEGORY_NAME_CACHE_KEY = 'default-tag-category' @@ -89,9 +88,9 @@ def update_category_name(category: model.TagCategory, name: str) -> None: expr = sa.func.lower(model.TagCategory.name) == name.lower() if category.tag_category_id: expr = expr & ( - model.TagCategory.tag_category_id != category.tag_category_id) + model.TagCategory.tag_category_id != category.tag_category_id) already_exists = ( - db.session.query(model.TagCategory).filter(expr).count() > 0) + db.session.query(model.TagCategory).filter(expr).count() > 0) if already_exists: raise TagCategoryAlreadyExistsError( 'A category with this name already exists.') @@ -116,8 +115,9 @@ def update_category_color(category: model.TagCategory, color: str) -> None: def try_get_category_by_name( name: str, lock: bool = False) -> Optional[model.TagCategory]: query = ( - db.session.query(model.TagCategory) - .filter(sa.func.lower(model.TagCategory.name) == name.lower())) + db.session + .query(model.TagCategory) + .filter(sa.func.lower(model.TagCategory.name) == name.lower())) if lock: query = query.with_lockmode('update') return query.one_or_none() @@ -141,8 +141,9 @@ def get_all_categories() -> List[model.TagCategory]: def try_get_default_category( lock: bool = False) -> Optional[model.TagCategory]: query = ( - db.session.query(model.TagCategory) - .filter(model.TagCategory.default)) + db.session + .query(model.TagCategory) + .filter(model.TagCategory.default)) if lock: query = query.with_lockmode('update') category = query.first() @@ -150,8 +151,9 @@ def try_get_default_category( # category, get the first record available. if not category: query = ( - db.session.query(model.TagCategory) - .order_by(model.TagCategory.tag_category_id.asc())) + db.session + .query(model.TagCategory) + .order_by(model.TagCategory.tag_category_id.asc())) if lock: query = query.with_lockmode('update') category = query.first() diff --git a/server/szurubooru/func/tags.py b/server/szurubooru/func/tags.py index 0064da4f..7d92f1e7 100644 --- a/server/szurubooru/func/tags.py +++ b/server/szurubooru/func/tags.py @@ -1,9 +1,9 @@ -from datetime import datetime -from typing import Any, Optional, Tuple, List, Dict, Callable - +import json +import os import re +from typing import Any, Optional, Tuple, List, Dict, Callable +from datetime import datetime import sqlalchemy as sa - from szurubooru import config, db, model, errors, rest from szurubooru.func import util, tag_categories, serialization @@ -138,10 +138,11 @@ def serialize_tag( def try_get_tag_by_name(name: str) -> Optional[model.Tag]: return ( - db.session.query(model.Tag) - .join(model.TagName) - .filter(sa.func.lower(model.TagName.name) == name.lower()) - .one_or_none()) + db.session + .query(model.Tag) + .join(model.TagName) + .filter(sa.func.lower(model.TagName.name) == name.lower()) + .one_or_none()) def get_tag_by_name(name: str) -> model.Tag: @@ -157,12 +158,12 @@ def get_tags_by_names(names: List[str]) -> List[model.Tag]: return [] return ( db.session.query(model.Tag) - .join(model.TagName) - .filter( + .join(model.TagName) + .filter( sa.sql.or_( sa.func.lower(model.TagName.name) == name.lower() for name in names)) - .all()) + .all()) def get_or_create_tags_by_names( @@ -195,15 +196,16 @@ def get_tag_siblings(tag: model.Tag) -> List[model.Tag]: pt_alias1 = sa.orm.aliased(model.PostTag) pt_alias2 = sa.orm.aliased(model.PostTag) result = ( - db.session.query(tag_alias, sa.func.count(pt_alias2.post_id)) - .join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id) - .join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id) - .filter(pt_alias2.tag_id == tag.tag_id) - .filter(pt_alias1.tag_id != tag.tag_id) - .group_by(tag_alias.tag_id) - .order_by(sa.func.count(pt_alias2.post_id).desc()) - .order_by(tag_alias.first_name) - .limit(50)) + db.session + .query(tag_alias, sa.func.count(pt_alias2.post_id)) + .join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id) + .join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id) + .filter(pt_alias2.tag_id == tag.tag_id) + .filter(pt_alias1.tag_id != tag.tag_id) + .group_by(tag_alias.tag_id) + .order_by(sa.func.count(pt_alias2.post_id).desc()) + .order_by(tag_alias.first_name) + .limit(50)) return result @@ -211,10 +213,10 @@ def delete(source_tag: model.Tag) -> None: assert source_tag db.session.execute( sa.sql.expression.delete(model.TagSuggestion) - .where(model.TagSuggestion.child_id == source_tag.tag_id)) + .where(model.TagSuggestion.child_id == source_tag.tag_id)) db.session.execute( sa.sql.expression.delete(model.TagImplication) - .where(model.TagImplication.child_id == source_tag.tag_id)) + .where(model.TagImplication.child_id == source_tag.tag_id)) db.session.delete(source_tag) @@ -229,12 +231,13 @@ def merge_tags(source_tag: model.Tag, target_tag: model.Tag) -> None: alias2 = sa.orm.util.aliased(model.PostTag) update_stmt = ( sa.sql.expression.update(alias1) - .where(alias1.tag_id == source_tag_id)) + .where(alias1.tag_id == source_tag_id)) update_stmt = ( update_stmt - .where(~sa.exists() - .where(alias1.post_id == alias2.post_id) - .where(alias2.tag_id == target_tag_id))) + .where( + ~sa.exists() + .where(alias1.post_id == alias2.post_id) + .where(alias2.tag_id == target_tag_id))) update_stmt = update_stmt.values(tag_id=target_tag_id) db.session.execute(update_stmt) @@ -244,22 +247,24 @@ def merge_tags(source_tag: model.Tag, target_tag: model.Tag) -> None: alias2 = sa.orm.util.aliased(table) update_stmt = ( sa.sql.expression.update(alias1) - .where(alias1.parent_id == source_tag_id) - .where(alias1.child_id != target_tag_id) - .where(~sa.exists() - .where(alias2.child_id == alias1.child_id) - .where(alias2.parent_id == target_tag_id)) - .values(parent_id=target_tag_id)) + .where(alias1.parent_id == source_tag_id) + .where(alias1.child_id != target_tag_id) + .where( + ~sa.exists() + .where(alias2.child_id == alias1.child_id) + .where(alias2.parent_id == target_tag_id)) + .values(parent_id=target_tag_id)) db.session.execute(update_stmt) update_stmt = ( sa.sql.expression.update(alias1) - .where(alias1.child_id == source_tag_id) - .where(alias1.parent_id != target_tag_id) - .where(~sa.exists() - .where(alias2.parent_id == alias1.parent_id) - .where(alias2.child_id == target_tag_id)) - .values(child_id=target_tag_id)) + .where(alias1.child_id == source_tag_id) + .where(alias1.parent_id != target_tag_id) + .where( + ~sa.exists() + .where(alias2.parent_id == alias1.parent_id) + .where(alias2.child_id == target_tag_id)) + .values(child_id=target_tag_id)) db.session.execute(update_stmt) def merge_suggestions(source_tag_id: int, target_tag_id: int) -> None: diff --git a/server/szurubooru/func/util.py b/server/szurubooru/func/util.py index 164b2545..ba2d4dc9 100644 --- a/server/szurubooru/func/util.py +++ b/server/szurubooru/func/util.py @@ -1,21 +1,20 @@ -from contextlib import contextmanager -from datetime import datetime, timedelta -from typing import Any, Optional, Union, Tuple, List, Dict, Generator, TypeVar - -import hashlib import os +import hashlib import re import tempfile - +from typing import Any, Optional, Union, Tuple, List, Dict, Generator, TypeVar +from datetime import datetime, timedelta +from contextlib import contextmanager from szurubooru import errors + T = TypeVar('T') def snake_case_to_lower_camel_case(text: str) -> str: components = text.split('_') return components[0].lower() + \ - ''.join(word[0].upper() + word[1:].lower() for word in components[1:]) + ''.join(word[0].upper() + word[1:].lower() for word in components[1:]) def snake_case_to_upper_train_case(text: str) -> str: @@ -87,7 +86,6 @@ def is_valid_email(email: Optional[str]) -> bool: class dotdict(dict): # pylint: disable=invalid-name ''' dot.notation access to dictionary attributes. ''' - def __getattr__(self, attr: str) -> Any: return self.get(attr)