diff --git a/server/config.yaml.dist b/server/config.yaml.dist index 236630d0..47fd89a8 100644 --- a/server/config.yaml.dist +++ b/server/config.yaml.dist @@ -21,11 +21,15 @@ thumbnails: post_width: 300 post_height: 300 +# automatically convert animated GIF uploads to video formats convert: gif: to_webm: false to_mp4: false +# allow posts to be uploaded even if some image processing errors occur +allow_broken_uploads: false + # used to send password reset e-mails smtp: host: # example: localhost diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index 58d0708f..84094845 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -262,9 +262,9 @@ def get_posts_by_image( 'similarPosts': [ { - 'distance': lookalike.distance, - 'post': _serialize_post(ctx, lookalike.post), + 'distance': distance, + 'post': _serialize_post(ctx, post), } - for lookalike in lookalikes + for distance, post in lookalikes ], } diff --git a/server/szurubooru/facade.py b/server/szurubooru/facade.py index e10b53d9..3220b56e 100644 --- a/server/szurubooru/facade.py +++ b/server/szurubooru/facade.py @@ -8,7 +8,7 @@ import coloredlogs import sqlalchemy as sa import sqlalchemy.orm.exc from szurubooru import config, db, errors, rest -from szurubooru.func import posts, file_uploads, image_hash +from szurubooru.func import file_uploads # pylint: disable=unused-import from szurubooru import api, middleware @@ -129,13 +129,7 @@ def create_app() -> Callable[[Any, Any], Any]: purge_thread.daemon = True purge_thread.start() - try: - image_hash.get_session().cluster.health( - wait_for_status='yellow', request_timeout=120) - posts.populate_reverse_search() - db.session.commit() - except errors.ThirdPartyError: - pass + db.session.commit() rest.errors.handle(errors.AuthError, _on_auth_error) rest.errors.handle(errors.ValidationError, _on_validation_error) diff --git a/server/szurubooru/func/image_hash.py b/server/szurubooru/func/image_hash.py index e5ae6a37..c3bc2321 100644 --- a/server/szurubooru/func/image_hash.py +++ b/server/szurubooru/func/image_hash.py @@ -2,8 +2,7 @@ import logging from io import BytesIO from datetime import datetime from typing import Any, Optional, Tuple, Set, List, Callable -import elasticsearch -import elasticsearch_dsl +import math import numpy as np from PIL import Image from szurubooru import config, errors @@ -24,30 +23,25 @@ N = 9 P = None SAMPLE_WORDS = 16 MAX_WORDS = 63 -ES_DOC_TYPE = 'image' -ES_MAX_RESULTS = 100 +SIG_CHUNK_BITS = 32 + +SIG_BASE = 2*N_LEVELS + 2 +SIG_CHUNK_WIDTH = int(SIG_CHUNK_BITS / math.log2(SIG_BASE)) +SIG_CHUNK_NUMS = 8*N*N / SIG_CHUNK_WIDTH +assert 8*N*N % SIG_CHUNK_WIDTH == 0 Window = Tuple[Tuple[float, float], Tuple[float, float]] -NpMatrix = Any - - -def get_session() -> elasticsearch.Elasticsearch: - extra_args = {} - if config.config['elasticsearch']['pass']: - extra_args['http_auth'] = ( - config.config['elasticsearch']['user'], - config.config['elasticsearch']['pass']) - extra_args['scheme'] = 'https' - extra_args['port'] = 443 - return elasticsearch.Elasticsearch([{ - 'host': config.config['elasticsearch']['host'], - 'port': config.config['elasticsearch']['port'], - }], **extra_args) +NpMatrix = np.ndarray def _preprocess_image(content: bytes) -> NpMatrix: - img = Image.open(BytesIO(content)) - return np.asarray(img.convert('L'), dtype=np.uint8) + try: + img = Image.open(BytesIO(content)) + return np.asarray(img.convert('L'), dtype=np.uint8) + except IOError: + raise errors.ProcessingError( + 'Unable to generate a signature hash ' + 'for this image.') def _crop_image( @@ -175,7 +169,31 @@ def _compute_differentials(grey_level_matrix: NpMatrix) -> NpMatrix: lower_right_neighbors])) -def _generate_signature(content: bytes) -> NpMatrix: +def _words_to_int(word_array: NpMatrix) -> List[int]: + width = word_array.shape[1] + coding_vector = 3**np.arange(width) + return np.dot(word_array + 1, coding_vector).astype(int).tolist() + + +def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix: + word_positions = np.linspace( + 0, array.shape[0], n, endpoint=False).astype('int') + assert k <= array.shape[0] + assert word_positions.shape[0] <= array.shape[0] + words = np.zeros((n, k)).astype('int8') + for i, pos in enumerate(word_positions): + if pos + k <= array.shape[0]: + words[i] = array[pos:pos + k] + else: + temp = array[pos:].copy() + temp.resize(k, refcheck=False) + words[i] = temp + words[words > 0] = 1 + words[words < 0] = -1 + return words + + +def generate_signature(content: bytes) -> NpMatrix: im_array = _preprocess_image(content) image_limits = _crop_image( im_array, @@ -192,40 +210,15 @@ def _generate_signature(content: bytes) -> NpMatrix: return np.ravel(diff_matrix).astype('int8') -def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix: - word_positions = np.linspace( - 0, array.shape[0], n, endpoint=False).astype('int') - assert k <= array.shape[0] - assert word_positions.shape[0] <= array.shape[0] - words = np.zeros((n, k)).astype('int8') - for i, pos in enumerate(word_positions): - if pos + k <= array.shape[0]: - words[i] = array[pos:pos + k] - else: - temp = array[pos:].copy() - temp.resize(k) - words[i] = temp - _max_contrast(words) - words = _words_to_int(words) - return words +def generate_words(signature: NpMatrix) -> List[int]: + return _words_to_int(_get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS)) -def _words_to_int(word_array: NpMatrix) -> NpMatrix: - width = word_array.shape[1] - coding_vector = 3**np.arange(width) - return np.dot(word_array + 1, coding_vector) - - -def _max_contrast(array: NpMatrix) -> None: - array[array > 0] = 1 - array[array < 0] = -1 - - -def _normalized_distance( - target_array: NpMatrix, +def normalized_distance( + target_array: Any, vec: NpMatrix, nan_value: float = 1.0) -> List[float]: - target_array = target_array.astype(int) + target_array = np.array(target_array).astype(int) vec = vec.astype(int) topvec = np.linalg.norm(vec - target_array, axis=1) norm1 = np.linalg.norm(vec, axis=0) @@ -235,124 +228,21 @@ def _normalized_distance( return finvec -def _safety_blanket(default_param_factory: Callable[[], Any]) -> Callable: - def wrapper_outer(target_function: Callable) -> Callable: - def wrapper_inner(*args: Any, **kwargs: Any) -> Any: - try: - return target_function(*args, **kwargs) - except elasticsearch.exceptions.NotFoundError: - # index not yet created, will be created dynamically by - # add_image() - return default_param_factory() - except elasticsearch.exceptions.ElasticsearchException as ex: - logger.warning('Problem with elastic search: %s', ex) - raise errors.ThirdPartyError( - 'Error connecting to elastic search.') - except IOError: - raise errors.ProcessingError('Not an image.') - except Exception as ex: - raise errors.ThirdPartyError('Unknown error (%s).' % ex) - return wrapper_inner - return wrapper_outer +def pack_signature(signature: NpMatrix) -> bytes: + base = 2 * N_LEVELS + 1 + coding_vector = np.flipud(SIG_BASE**np.arange(SIG_CHUNK_WIDTH)) + return np.array([ + np.dot(x, coding_vector) for x in + np.reshape(signature + N_LEVELS, (-1, SIG_CHUNK_WIDTH)) + ]).astype(f'uint{SIG_CHUNK_BITS}').tobytes() -class Lookalike: - def __init__(self, score: int, distance: float, path: Any) -> None: - self.score = score - self.distance = distance - self.path = path - - -@_safety_blanket(lambda: None) -def add_image(path: str, image_content: bytes) -> None: - assert path - assert image_content - signature = _generate_signature(image_content) - words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS) - - record = { - 'signature': signature.tolist(), - 'path': path, - 'timestamp': datetime.now(), - } - for i in range(MAX_WORDS): - record['simple_word_' + str(i)] = words[i].tolist() - - get_session().index( - index=config.config['elasticsearch']['index'], - doc_type=ES_DOC_TYPE, - body=record, - refresh=True) - - -@_safety_blanket(lambda: None) -def delete_image(path: str) -> None: - assert path - get_session().delete_by_query( - index=config.config['elasticsearch']['index'], - doc_type=ES_DOC_TYPE, - body={'query': {'term': {'path': path}}}) - - -@_safety_blanket(lambda: []) -def search_by_image(image_content: bytes) -> List[Lookalike]: - signature = _generate_signature(image_content) - words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS) - - res = get_session().search( - index=config.config['elasticsearch']['index'], - doc_type=ES_DOC_TYPE, - body={ - 'query': - { - 'bool': - { - 'should': - [ - {'term': {'simple_word_%d' % i: word.tolist()}} - for i, word in enumerate(words) - ] - } - }, - '_source': {'excludes': ['simple_word_*']}}, - size=ES_MAX_RESULTS, - timeout='10s')['hits']['hits'] - - if len(res) == 0: - return [] - - sigs = np.array([x['_source']['signature'] for x in res]) - dists = _normalized_distance(sigs, np.array(signature)) - - ids = set() # type: Set[int] - ret = [] - for item, dist in zip(res, dists): - id = item['_id'] - score = item['_score'] - path = item['_source']['path'] - if id in ids: - continue - ids.add(id) - if dist < DISTANCE_CUTOFF: - ret.append(Lookalike(score=score, distance=dist, path=path)) - return ret - - -@_safety_blanket(lambda: None) -def purge() -> None: - get_session().delete_by_query( - index=config.config['elasticsearch']['index'], - doc_type=ES_DOC_TYPE, - body={'query': {'match_all': {}}}, - refresh=True) - - -@_safety_blanket(lambda: set()) -def get_all_paths() -> Set[str]: - search = ( - elasticsearch_dsl.Search( - using=get_session(), - index=config.config['elasticsearch']['index'], - doc_type=ES_DOC_TYPE) - .source(['path'])) - return set(h.path for h in search.scan()) +def unpack_signature(packed: bytes) -> NpMatrix: + base = 2 * N_LEVELS + 1 + return np.ravel(np.array([ + [ + int(digit) - N_LEVELS for digit in + np.base_repr(e, base=SIG_BASE).zfill(SIG_CHUNK_WIDTH) + ] for e in + np.frombuffer(packed, dtype=f'uint{SIG_CHUNK_BITS}') + ]).astype('int8')) diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 73407f66..7b41a268 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -1,3 +1,4 @@ +import logging import hmac from typing import Any, Optional, Tuple, List, Dict, Callable from datetime import datetime @@ -8,6 +9,9 @@ from szurubooru.func import ( mime, images, files, image_hash, serialization, snapshots) +logger = logging.getLogger(__name__) + + EMPTY_PIXEL = ( b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' @@ -60,12 +64,6 @@ class InvalidPostFlagError(errors.ValidationError): pass -class PostLookalike(image_hash.Lookalike): - def __init__(self, score: int, distance: float, post: model.Post) -> None: - super().__init__(score, distance, post.post_id) - self.post = post - - SAFETY_MAP = { model.Post.SAFETY_SAFE: 'safe', model.Post.SAFETY_SKETCHY: 'sketchy', @@ -402,7 +400,6 @@ def _after_post_update( def _before_post_delete( _mapper: Any, _connection: Any, post: model.Post) -> None: if post.post_id: - image_hash.delete_image(post.post_id) if config.config['delete_source_files']: files.delete(get_post_content_path(post)) files.delete(get_post_thumbnail_path(post)) @@ -416,10 +413,6 @@ def _sync_post_content(post: model.Post) -> None: files.save(get_post_content_path(post), content) delattr(post, '__content') regenerate_thumb = True - if post.post_id and post.type in ( - model.Post.TYPE_IMAGE, model.Post.TYPE_ANIMATION): - image_hash.delete_image(post.post_id) - image_hash.add_image(post.post_id, content) if hasattr(post, '__thumbnail'): if getattr(post, '__thumbnail'): @@ -485,14 +478,56 @@ def test_sound(post: model.Post, content: bytes) -> None: update_post_flags(post, flags) +def purge_post_signature(post: model.Post) -> None: + old_signature = ( + db.session + .query(model.PostSignature) + .filter(model.PostSignature.post_id == post.post_id) + .one_or_none()) + if old_signature: + db.session.delete(old_signature) + + +def generate_post_signature(post: model.Post, content: bytes) -> None: + try: + unpacked_signature = image_hash.generate_signature(content) + packed_signature = image_hash.pack_signature(unpacked_signature) + words = image_hash.generate_words(unpacked_signature) + + db.session.add(model.PostSignature( + post=post, signature=packed_signature, words=words)) + except errors.ProcessingError: + if not config.config['allow_broken_uploads']: + raise InvalidPostContentError( + 'Unable to generate image hash data.') + + +def update_all_post_signatures() -> None: + posts_to_hash = ( + db.session + .query(model.Post) + .filter( + (model.Post.type == model.Post.TYPE_IMAGE) | + (model.Post.type == model.Post.TYPE_ANIMATION)) + .filter(model.Post.signature == None) + .order_by(model.Post.post_id.asc()) + .all()) + for post in posts_to_hash: + logger.info('Generating hash info for %d', post.post_id) + generate_post_signature(post, files.get(get_post_content_path(post))) + + def update_post_content(post: model.Post, content: Optional[bytes]) -> None: assert post if not content: raise InvalidPostContentError('Post content missing.') + + update_signature = False post.mime_type = mime.get_mime_type(content) if mime.is_flash(post.mime_type): post.type = model.Post.TYPE_FLASH elif mime.is_image(post.mime_type): + update_signature = True if mime.is_animated_gif(content): post.type = model.Post.TYPE_ANIMATION else: @@ -515,18 +550,30 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None: and other_post.post_id != post.post_id: raise PostAlreadyUploadedError(other_post) + if update_signature: + purge_post_signature(post) + post.signature = generate_post_signature(post, content) + post.file_size = len(content) try: image = images.Image(content) post.canvas_width = image.width post.canvas_height = image.height except errors.ProcessingError: - post.canvas_width = None - post.canvas_height = None + if not config.config['allow_broken_uploads']: + raise InvalidPostContentError( + 'Unable to process image metadata') + else: + post.canvas_width = None + post.canvas_height = None if (post.canvas_width is not None and post.canvas_width <= 0) \ or (post.canvas_height is not None and post.canvas_height <= 0): - post.canvas_width = None - post.canvas_height = None + if not config.config['allow_broken_uploads']: + raise InvalidPostContentError( + 'Invalid image dimensions returned during processing') + else: + post.canvas_width = None + post.canvas_height = None setattr(post, '__content', content) @@ -751,6 +798,8 @@ def merge_posts( if replace_content: content = files.get(get_post_content_path(source_post)) transfer_flags(source_post.post_id, target_post.post_id) + purge_post_signature(source_post) + purge_post_signature(target_post) delete(source_post) db.session.flush() @@ -768,38 +817,31 @@ def search_by_image_exact(image_content: bytes) -> Optional[model.Post]: .one_or_none()) -def search_by_image(image_content: bytes) -> List[PostLookalike]: - ret = [] - for result in image_hash.search_by_image(image_content): - post = try_get_post_by_id(result.path) - if post: - ret.append(PostLookalike( - score=result.score, - distance=result.distance, - post=post)) - return ret +def search_by_image(image_content: bytes) -> List[Tuple[float, model.Post]]: + query_signature = image_hash.generate_signature(image_content) + query_words = image_hash.generate_words(query_signature) + dbquery = ''' + SELECT s.post_id, s.signature, count(a.query) AS score + FROM post_signature AS s, unnest(s.words, :q) AS a(word, query) + WHERE a.word = a.query + GROUP BY s.post_id + ORDER BY score DESC LIMIT 100; + ''' -def populate_reverse_search() -> None: - excluded_post_ids = image_hash.get_all_paths() - - post_ids_to_hash = ( - db.session - .query(model.Post.post_id) - .filter( - (model.Post.type == model.Post.TYPE_IMAGE) | - (model.Post.type == model.Post.TYPE_ANIMATION)) - .filter(~model.Post.post_id.in_(excluded_post_ids)) - .order_by(model.Post.post_id.asc()) - .all()) - - for post_ids_chunk in util.chunks(post_ids_to_hash, 100): - posts_chunk = ( - db.session - .query(model.Post) - .filter(model.Post.post_id.in_(post_ids_chunk)) - .all()) - for post in posts_chunk: - content_path = get_post_content_path(post) - if files.has(content_path): - image_hash.add_image(post.post_id, files.get(content_path)) + candidates = db.session.execute(dbquery, {'q': query_words}) + data = tuple(zip(*[ + (post_id, image_hash.unpack_signature(packedsig)) + for post_id, packedsig, score in candidates + ])) + if data: + candidate_post_ids, sigarray = data + distances = image_hash.normalized_distance(sigarray, query_signature) + return [ + (distance, try_get_post_by_id(candidate_post_id)) + for candidate_post_id, distance + in zip(candidate_post_ids, distances) + if distance < image_hash.DISTANCE_CUTOFF + ] + else: + return [] diff --git a/server/szurubooru/model/__init__.py b/server/szurubooru/model/__init__.py index 4892b974..a043bd1f 100644 --- a/server/szurubooru/model/__init__.py +++ b/server/szurubooru/model/__init__.py @@ -9,7 +9,8 @@ from szurubooru.model.post import ( PostFavorite, PostScore, PostNote, - PostFeature) + PostFeature, + PostSignature) from szurubooru.model.comment import Comment, CommentScore from szurubooru.model.snapshot import Snapshot import szurubooru.model.util diff --git a/server/szurubooru/model/post.py b/server/szurubooru/model/post.py index 30e2900c..ec2cf471 100644 --- a/server/szurubooru/model/post.py +++ b/server/szurubooru/model/post.py @@ -143,6 +143,26 @@ class PostTag(Base): self.tag_id = tag_id +class PostSignature(Base): + __tablename__ = 'post_signature' + + post_id = sa.Column( + 'post_id', + sa.Integer, + sa.ForeignKey('post.id'), + primary_key=True, + nullable=False, + index=True) + signature = sa.Column('signature', sa.LargeBinary, nullable=False) + words = sa.Column( + 'words', + sa.dialects.postgresql.ARRAY(sa.Integer, dimensions=1), + nullable=False, + index=True) + + post = sa.orm.relationship('Post') + + class Post(Base): __tablename__ = 'post' @@ -184,6 +204,11 @@ class Post(Base): # foreign tables user = sa.orm.relationship('User') tags = sa.orm.relationship('Tag', backref='posts', secondary='post_tag') + signature = sa.orm.relationship( + 'PostSignature', + uselist=False, + cascade='all, delete-orphan', + lazy='joined') relations = sa.orm.relationship( 'Post', secondary='post_relation',