Cleanup func imports, and small formatting changes.

ReAnzu 2018-02-25 17:30:48 -06:00
parent a526a56767
commit 796563f772
18 changed files with 147 additions and 145 deletions

View file

@@ -1,13 +1,13 @@
+import uuid
 import hashlib
 import random
 from collections import OrderedDict
 from nacl.exceptions import InvalidkeyError
+from nacl.pwhash import argon2id, verify
 from szurubooru import config, model, errors, db
 from szurubooru.func import util
-from nacl.pwhash import argon2id, verify
-import uuid

 RANK_MAP = OrderedDict([
     (model.User.RANK_ANONYMOUS, 'anonymous'),

View file

@@ -1,5 +1,5 @@
-from typing import Any, List, Dict
 from datetime import datetime
+from typing import Any, List, Dict


 class LruCacheItem:

View file

@ -1,5 +1,6 @@
from datetime import datetime from datetime import datetime
from typing import Any, Optional, List, Dict, Callable from typing import Any, Optional, List, Dict, Callable
from szurubooru import db, model, errors, rest from szurubooru import db, model, errors, rest
from szurubooru.func import users, scores, serialization from szurubooru.func import users, scores, serialization
@ -75,9 +76,9 @@ def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]:
comment_id = int(comment_id) comment_id = int(comment_id)
return ( return (
db.session db.session
.query(model.Comment) .query(model.Comment)
.filter(model.Comment.comment_id == comment_id) .filter(model.Comment.comment_id == comment_id)
.one_or_none()) .one_or_none())
def get_comment_by_id(comment_id: int) -> model.Comment: def get_comment_by_id(comment_id: int) -> model.Comment:

View file

@@ -1,5 +1,6 @@
-from typing import Any, Optional, Callable, Tuple
 from datetime import datetime
+from typing import Any, Optional, Callable, Tuple
 from szurubooru import db, model, errors

View file

@@ -1,7 +1,7 @@
-from typing import Optional
 from datetime import datetime, timedelta
+from typing import Optional
 from szurubooru.func import files, util

 MAX_MINUTES = 60

View file

@@ -1,5 +1,7 @@
 from typing import Any, Optional, List
 import os
 from szurubooru import config

View file

@@ -1,12 +1,14 @@
 import logging
+from io import BytesIO
 from datetime import datetime
 from typing import Any, Optional, Tuple, Set, List, Callable
 import elasticsearch
 import elasticsearch_dsl
 import numpy as np
-from skimage.color import rgb2gray
 from PIL import Image
-from io import BytesIO
+from skimage.color import rgb2gray
 from szurubooru import config, errors

 # pylint: disable=invalid-name
@@ -133,7 +135,7 @@ def _compute_differentials(grey_level_matrix: NpMatrix) -> NpMatrix:
             np.diff(grey_level_matrix),
             (
                 np.zeros(grey_level_matrix.shape[0])
                 .reshape((grey_level_matrix.shape[0], 1))
             )
         ), axis=1)
     down_neighbors = -np.concatenate(
@@ -141,7 +143,7 @@ def _compute_differentials(grey_level_matrix: NpMatrix) -> NpMatrix:
             np.diff(grey_level_matrix, axis=0),
             (
                 np.zeros(grey_level_matrix.shape[1])
                 .reshape((1, grey_level_matrix.shape[1]))
             )
         ))
     left_neighbors = -np.concatenate(
@@ -207,7 +209,7 @@ def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix:
 def _words_to_int(word_array: NpMatrix) -> NpMatrix:
     width = word_array.shape[1]
-    coding_vector = 3**np.arange(width)
+    coding_vector = 3 ** np.arange(width)
     return np.dot(word_array + 1, coding_vector)
@@ -247,7 +249,9 @@ def _safety_blanket(default_param_factory: Callable[[], Any]) -> Callable:
                 raise errors.ProcessingError('Not an image.')
             except Exception as ex:
                 raise errors.ThirdPartyError('Unknown error (%s).' % ex)
         return wrapper_inner
     return wrapper_outer
@@ -349,5 +353,5 @@ def get_all_paths() -> Set[str]:
             using=_get_session(),
             index=config.config['elasticsearch']['index'],
             doc_type=ES_DOC_TYPE)
         .source(['path']))
     return set(h.path for h in search.scan())

View file

@@ -1,16 +1,15 @@
-from typing import List
-import logging
 import json
+import logging
+import math
 import shlex
 import subprocess
-import math
+from typing import List
 from szurubooru import errors
 from szurubooru.func import mime, util

 logger = logging.getLogger(__name__)

 _SCALE_FIT_FMT = (
     r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)')

View file

@@ -1,5 +1,6 @@
-import smtplib
 import email.mime.text
+import smtplib
 from szurubooru import config

View file

@@ -1,6 +1,7 @@
-import re
 from typing import Optional
+import re

 APPLICATION_SWF = 'application/x-shockwave-flash'
 IMAGE_JPEG = 'image/jpeg'
 IMAGE_PNG = 'image/png'
@@ -63,4 +64,4 @@ def is_image(mime_type: str) -> bool:
 def is_animated_gif(content: bytes) -> bool:
     pattern = b'\x21\xF9\x04[\x00-\xFF]{4}\x00[\x2C\x21]'
     return get_mime_type(content) == IMAGE_GIF \
         and len(re.findall(pattern, content)) > 1

View file

@@ -1,6 +1,6 @@
 import urllib.request
-from szurubooru import config
-from szurubooru import errors
+from szurubooru import config, errors


 def download(url: str) -> bytes:

View file

@@ -1,13 +1,14 @@
-import hmac
-from typing import Any, Optional, Tuple, List, Dict, Callable
 from datetime import datetime
+from typing import Any, Optional, Tuple, List, Dict, Callable
+import hmac
 import sqlalchemy as sa
 from szurubooru import config, db, model, errors, rest
 from szurubooru.func import (
     users, scores, comments, tags, util,
     mime, images, files, image_hash, serialization, snapshots)

 EMPTY_PIXEL = (
     b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
     b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
@@ -237,8 +238,8 @@ class PostSerializer(serialization.BaseSerializer):
             {
                 post['id']: post
                 for post in [
                     serialize_micro_post(rel, self.auth_user)
                     for rel in self.post.relations]
             }.values(),
             key=lambda post: post['id'])
@@ -322,9 +323,9 @@ def get_post_count() -> int:
 def try_get_post_by_id(post_id: int) -> Optional[model.Post]:
     return (
         db.session
         .query(model.Post)
         .filter(model.Post.post_id == post_id)
         .one_or_none())


 def get_post_by_id(post_id: int) -> model.Post:
@@ -337,9 +338,9 @@ def get_post_by_id(post_id: int) -> model.Post:
 def try_get_current_post_feature() -> Optional[model.PostFeature]:
     return (
         db.session
         .query(model.PostFeature)
         .order_by(model.PostFeature.time.desc())
         .first())


 def try_get_featured_post() -> Optional[model.Post]:
@@ -486,10 +487,10 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
     post.checksum = util.get_sha1(content)
     other_post = (
         db.session
         .query(model.Post)
         .filter(model.Post.checksum == post.checksum)
         .filter(model.Post.post_id != post.post_id)
         .one_or_none())
     if other_post \
             and other_post.post_id \
             and other_post.post_id != post.post_id:
@@ -553,9 +554,9 @@ def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None:
     if new_post_ids:
         new_posts = (
             db.session
             .query(model.Post)
             .filter(model.Post.post_id.in_(new_post_ids))
             .all())
     else:
         new_posts = []
     if len(new_posts) != len(new_post_ids):
@@ -654,15 +655,13 @@ def merge_posts(
         alias2 = sa.orm.util.aliased(table)
         update_stmt = (
             sa.sql.expression.update(alias1)
             .where(alias1.post_id == source_post_id))
         if anti_dup_func is not None:
             update_stmt = (
-                update_stmt
-                .where(
-                    ~sa.exists()
-                    .where(anti_dup_func(alias1, alias2))
-                    .where(alias2.post_id == target_post_id)))
+                update_stmt.where(~sa.exists()
+                    .where(anti_dup_func(alias1, alias2))
+                    .where(alias2.post_id == target_post_id)))
         update_stmt = update_stmt.values(post_id=target_post_id)
         db.session.execute(update_stmt)
@@ -696,24 +695,24 @@ def merge_posts(
         alias2 = sa.orm.util.aliased(model.PostRelation)
         update_stmt = (
             sa.sql.expression.update(alias1)
             .where(alias1.parent_id == source_post_id)
             .where(alias1.child_id != target_post_id)
             .where(
                 ~sa.exists()
                 .where(alias2.child_id == alias1.child_id)
                 .where(alias2.parent_id == target_post_id))
             .values(parent_id=target_post_id))
         db.session.execute(update_stmt)

         update_stmt = (
             sa.sql.expression.update(alias1)
             .where(alias1.child_id == source_post_id)
             .where(alias1.parent_id != target_post_id)
             .where(
                 ~sa.exists()
                 .where(alias2.parent_id == alias1.parent_id)
                 .where(alias2.child_id == target_post_id))
             .values(child_id=target_post_id))
         db.session.execute(update_stmt)

     merge_tags(source_post.post_id, target_post.post_id)
@@ -734,10 +733,9 @@ def merge_posts(
 def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
     checksum = util.get_sha1(image_content)
     return (
-        db.session
-            .query(model.Post)
-            .filter(model.Post.checksum == checksum)
-            .one_or_none())
+        db.session.query(model.Post)
+        .filter(model.Post.checksum == checksum)
+        .one_or_none())


 def search_by_image(image_content: bytes) -> List[PostLookalike]:
@@ -756,21 +754,19 @@ def populate_reverse_search() -> None:
     excluded_post_ids = image_hash.get_all_paths()

     post_ids_to_hash = (
-        db.session
-            .query(model.Post.post_id)
-            .filter(
+        db.session.query(model.Post.post_id)
+        .filter(
             (model.Post.type == model.Post.TYPE_IMAGE) |
             (model.Post.type == model.Post.TYPE_ANIMATION))
         .filter(~model.Post.post_id.in_(excluded_post_ids))
         .order_by(model.Post.post_id.asc())
         .all())

     for post_ids_chunk in util.chunks(post_ids_to_hash, 100):
         posts_chunk = (
-            db.session
-                .query(model.Post)
-                .filter(model.Post.post_id.in_(post_ids_chunk))
-                .all())
+            db.session.query(model.Post)
+            .filter(model.Post.post_id.in_(post_ids_chunk))
+            .all())
         for post in posts_chunk:
             content_path = get_post_content_path(post)
             if files.has(content_path):

View file

@@ -1,5 +1,6 @@
 import datetime
 from typing import Any, Tuple, Callable
 from szurubooru import db, model, errors
@@ -40,11 +41,10 @@ def get_score(entity: model.Base, user: model.User) -> int:
     assert user
     table, get_column = _get_table_info(entity)
     row = (
-        db.session
-            .query(table.score)
-            .filter(get_column(table) == get_column(entity))
-            .filter(table.user_id == user.user_id)
-            .one_or_none())
+        db.session.query(table.score)
+        .filter(get_column(table) == get_column(entity))
+        .filter(table.user_id == user.user_id)
+        .one_or_none())
     return row[0] if row else 0

View file

@@ -1,4 +1,5 @@
 from typing import Any, List, Dict, Callable
 from szurubooru import model, rest, errors

View file

@@ -1,5 +1,6 @@
-from typing import Any, Optional, Dict, Callable
 from datetime import datetime
+from typing import Any, Optional, Dict, Callable
 from szurubooru import db, model
 from szurubooru.func import diff, users
@@ -96,7 +97,7 @@ def modify(entity: model.Base, auth_user: Optional[model.User]) -> None:
             cls
             for cls in model.Base._decl_class_registry.values()
             if hasattr(cls, '__table__')
             and cls.__table__.fullname == entity.__table__.fullname
         ),
         None)
     assert table

View file

@@ -1,10 +1,11 @@
-import re
 from typing import Any, Optional, Dict, List, Callable
+import re
 import sqlalchemy as sa
 from szurubooru import config, db, model, errors, rest
 from szurubooru.func import util, serialization, cache

 DEFAULT_CATEGORY_NAME_CACHE_KEY = 'default-tag-category'
@@ -88,9 +89,9 @@ def update_category_name(category: model.TagCategory, name: str) -> None:
     expr = sa.func.lower(model.TagCategory.name) == name.lower()
     if category.tag_category_id:
         expr = expr & (
             model.TagCategory.tag_category_id != category.tag_category_id)
     already_exists = (
         db.session.query(model.TagCategory).filter(expr).count() > 0)
     if already_exists:
         raise TagCategoryAlreadyExistsError(
             'A category with this name already exists.')
@@ -115,9 +116,8 @@ def update_category_color(category: model.TagCategory, color: str) -> None:
 def try_get_category_by_name(
         name: str, lock: bool = False) -> Optional[model.TagCategory]:
     query = (
-        db.session
-            .query(model.TagCategory)
-            .filter(sa.func.lower(model.TagCategory.name) == name.lower()))
+        db.session.query(model.TagCategory)
+        .filter(sa.func.lower(model.TagCategory.name) == name.lower()))
     if lock:
         query = query.with_lockmode('update')
     return query.one_or_none()
@@ -141,9 +141,8 @@ def get_all_categories() -> List[model.TagCategory]:
 def try_get_default_category(
         lock: bool = False) -> Optional[model.TagCategory]:
     query = (
-        db.session
-            .query(model.TagCategory)
-            .filter(model.TagCategory.default))
+        db.session.query(model.TagCategory)
+        .filter(model.TagCategory.default))
     if lock:
         query = query.with_lockmode('update')
     category = query.first()
@@ -151,9 +150,8 @@ def try_get_default_category(
     # category, get the first record available.
     if not category:
         query = (
-            db.session
-                .query(model.TagCategory)
-                .order_by(model.TagCategory.tag_category_id.asc()))
+            db.session.query(model.TagCategory)
+            .order_by(model.TagCategory.tag_category_id.asc()))
         if lock:
             query = query.with_lockmode('update')
         category = query.first()

View file

@@ -1,9 +1,9 @@
 import json
 import os
-import re
-from typing import Any, Optional, Tuple, List, Dict, Callable
 from datetime import datetime
+from typing import Any, Optional, Tuple, List, Dict, Callable
+import re
 import sqlalchemy as sa
 from szurubooru import config, db, model, errors, rest
 from szurubooru.func import util, tag_categories, serialization
@@ -138,11 +138,10 @@ def serialize_tag(
 def try_get_tag_by_name(name: str) -> Optional[model.Tag]:
     return (
-        db.session
-            .query(model.Tag)
-            .join(model.TagName)
-            .filter(sa.func.lower(model.TagName.name) == name.lower())
-            .one_or_none())
+        db.session.query(model.Tag)
+        .join(model.TagName)
+        .filter(sa.func.lower(model.TagName.name) == name.lower())
+        .one_or_none())


 def get_tag_by_name(name: str) -> model.Tag:
@@ -158,12 +157,12 @@ def get_tags_by_names(names: List[str]) -> List[model.Tag]:
         return []
     return (
         db.session.query(model.Tag)
         .join(model.TagName)
         .filter(
             sa.sql.or_(
                 sa.func.lower(model.TagName.name) == name.lower()
                 for name in names))
         .all())


 def get_or_create_tags_by_names(
@@ -196,16 +195,15 @@ def get_tag_siblings(tag: model.Tag) -> List[model.Tag]:
     pt_alias1 = sa.orm.aliased(model.PostTag)
     pt_alias2 = sa.orm.aliased(model.PostTag)
     result = (
-        db.session
-            .query(tag_alias, sa.func.count(pt_alias2.post_id))
-            .join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id)
-            .join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id)
-            .filter(pt_alias2.tag_id == tag.tag_id)
-            .filter(pt_alias1.tag_id != tag.tag_id)
-            .group_by(tag_alias.tag_id)
-            .order_by(sa.func.count(pt_alias2.post_id).desc())
-            .order_by(tag_alias.first_name)
-            .limit(50))
+        db.session.query(tag_alias, sa.func.count(pt_alias2.post_id))
+        .join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id)
+        .join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id)
+        .filter(pt_alias2.tag_id == tag.tag_id)
+        .filter(pt_alias1.tag_id != tag.tag_id)
+        .group_by(tag_alias.tag_id)
+        .order_by(sa.func.count(pt_alias2.post_id).desc())
+        .order_by(tag_alias.first_name)
+        .limit(50))
     return result
@@ -213,10 +211,10 @@ def delete(source_tag: model.Tag) -> None:
     assert source_tag
     db.session.execute(
         sa.sql.expression.delete(model.TagSuggestion)
         .where(model.TagSuggestion.child_id == source_tag.tag_id))
     db.session.execute(
         sa.sql.expression.delete(model.TagImplication)
         .where(model.TagImplication.child_id == source_tag.tag_id))
     db.session.delete(source_tag)
@@ -231,13 +229,12 @@ def merge_tags(source_tag: model.Tag, target_tag: model.Tag) -> None:
         alias2 = sa.orm.util.aliased(model.PostTag)
         update_stmt = (
             sa.sql.expression.update(alias1)
             .where(alias1.tag_id == source_tag_id))
         update_stmt = (
             update_stmt
-            .where(
-                ~sa.exists()
-                .where(alias1.post_id == alias2.post_id)
-                .where(alias2.tag_id == target_tag_id)))
+            .where(~sa.exists()
+                .where(alias1.post_id == alias2.post_id)
+                .where(alias2.tag_id == target_tag_id)))
         update_stmt = update_stmt.values(tag_id=target_tag_id)
         db.session.execute(update_stmt)
@@ -247,24 +244,22 @@ def merge_tags(source_tag: model.Tag, target_tag: model.Tag) -> None:
         alias2 = sa.orm.util.aliased(table)
         update_stmt = (
             sa.sql.expression.update(alias1)
             .where(alias1.parent_id == source_tag_id)
             .where(alias1.child_id != target_tag_id)
-            .where(
-                ~sa.exists()
-                .where(alias2.child_id == alias1.child_id)
-                .where(alias2.parent_id == target_tag_id))
-            .values(parent_id=target_tag_id))
+            .where(~sa.exists()
+                .where(alias2.child_id == alias1.child_id)
+                .where(alias2.parent_id == target_tag_id))
+            .values(parent_id=target_tag_id))
         db.session.execute(update_stmt)

         update_stmt = (
             sa.sql.expression.update(alias1)
             .where(alias1.child_id == source_tag_id)
             .where(alias1.parent_id != target_tag_id)
-            .where(
-                ~sa.exists()
-                .where(alias2.parent_id == alias1.parent_id)
-                .where(alias2.child_id == target_tag_id))
-            .values(child_id=target_tag_id))
+            .where(~sa.exists()
+                .where(alias2.parent_id == alias1.parent_id)
+                .where(alias2.child_id == target_tag_id))
+            .values(child_id=target_tag_id))
         db.session.execute(update_stmt)

     def merge_suggestions(source_tag_id: int, target_tag_id: int) -> None:

View file

@ -1,12 +1,13 @@
import os from contextlib import contextmanager
from datetime import datetime, timedelta
from typing import Any, Optional, Union, Tuple, List, Dict, Generator, TypeVar
import hashlib import hashlib
import os
import re import re
import tempfile import tempfile
from typing import Any, Optional, Union, Tuple, List, Dict, Generator, TypeVar
from datetime import datetime, timedelta
from contextlib import contextmanager
from szurubooru import errors
from szurubooru import errors
T = TypeVar('T') T = TypeVar('T')
@@ -14,7 +15,7 @@ T = TypeVar('T')
 def snake_case_to_lower_camel_case(text: str) -> str:
     components = text.split('_')
     return components[0].lower() + \
         ''.join(word[0].upper() + word[1:].lower() for word in components[1:])


 def snake_case_to_upper_train_case(text: str) -> str:
@@ -86,6 +87,7 @@ def is_valid_email(email: Optional[str]) -> bool:
 class dotdict(dict):  # pylint: disable=invalid-name
     ''' dot.notation access to dictionary attributes. '''
     def __getattr__(self, attr: str) -> Any:
         return self.get(attr)