From 4c78cf8c4792e3eba8508f607bc3a4544d1f666f Mon Sep 17 00:00:00 2001 From: Shyam Sunder Date: Sat, 7 Mar 2020 20:43:20 -0500 Subject: [PATCH] server/image_search: implement reverse search functionality in postgres This will remove the dependency on the Elasticsearch database. The search query is passed currently as raw SQL. Proper implementation using SQLAlchemy will need custom ORM classed to be made. Additional config parameter "allow_broken_uploads" has been added. --- server/config.yaml.dist | 4 + server/szurubooru/api/post_api.py | 6 +- server/szurubooru/facade.py | 10 +- server/szurubooru/func/image_hash.py | 232 +++++++-------------------- server/szurubooru/func/posts.py | 138 ++++++++++------ server/szurubooru/model/__init__.py | 3 +- server/szurubooru/model/post.py | 25 +++ 7 files changed, 187 insertions(+), 231 deletions(-) diff --git a/server/config.yaml.dist b/server/config.yaml.dist index 236630d0..47fd89a8 100644 --- a/server/config.yaml.dist +++ b/server/config.yaml.dist @@ -21,11 +21,15 @@ thumbnails: post_width: 300 post_height: 300 +# automatically convert animated GIF uploads to video formats convert: gif: to_webm: false to_mp4: false +# allow posts to be uploaded even if some image processing errors occur +allow_broken_uploads: false + # used to send password reset e-mails smtp: host: # example: localhost diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index 58d0708f..84094845 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -262,9 +262,9 @@ def get_posts_by_image( 'similarPosts': [ { - 'distance': lookalike.distance, - 'post': _serialize_post(ctx, lookalike.post), + 'distance': distance, + 'post': _serialize_post(ctx, post), } - for lookalike in lookalikes + for distance, post in lookalikes ], } diff --git a/server/szurubooru/facade.py b/server/szurubooru/facade.py index e10b53d9..3220b56e 100644 --- a/server/szurubooru/facade.py +++ b/server/szurubooru/facade.py @@ -8,7 +8,7 @@ import coloredlogs import sqlalchemy as sa import sqlalchemy.orm.exc from szurubooru import config, db, errors, rest -from szurubooru.func import posts, file_uploads, image_hash +from szurubooru.func import file_uploads # pylint: disable=unused-import from szurubooru import api, middleware @@ -129,13 +129,7 @@ def create_app() -> Callable[[Any, Any], Any]: purge_thread.daemon = True purge_thread.start() - try: - image_hash.get_session().cluster.health( - wait_for_status='yellow', request_timeout=120) - posts.populate_reverse_search() - db.session.commit() - except errors.ThirdPartyError: - pass + db.session.commit() rest.errors.handle(errors.AuthError, _on_auth_error) rest.errors.handle(errors.ValidationError, _on_validation_error) diff --git a/server/szurubooru/func/image_hash.py b/server/szurubooru/func/image_hash.py index e5ae6a37..c3bc2321 100644 --- a/server/szurubooru/func/image_hash.py +++ b/server/szurubooru/func/image_hash.py @@ -2,8 +2,7 @@ import logging from io import BytesIO from datetime import datetime from typing import Any, Optional, Tuple, Set, List, Callable -import elasticsearch -import elasticsearch_dsl +import math import numpy as np from PIL import Image from szurubooru import config, errors @@ -24,30 +23,25 @@ N = 9 P = None SAMPLE_WORDS = 16 MAX_WORDS = 63 -ES_DOC_TYPE = 'image' -ES_MAX_RESULTS = 100 +SIG_CHUNK_BITS = 32 + +SIG_BASE = 2*N_LEVELS + 2 +SIG_CHUNK_WIDTH = int(SIG_CHUNK_BITS / math.log2(SIG_BASE)) +SIG_CHUNK_NUMS = 8*N*N / SIG_CHUNK_WIDTH +assert 8*N*N % SIG_CHUNK_WIDTH == 0 Window = Tuple[Tuple[float, float], Tuple[float, float]] -NpMatrix = Any - - -def get_session() -> elasticsearch.Elasticsearch: - extra_args = {} - if config.config['elasticsearch']['pass']: - extra_args['http_auth'] = ( - config.config['elasticsearch']['user'], - config.config['elasticsearch']['pass']) - extra_args['scheme'] = 'https' - extra_args['port'] = 443 - return elasticsearch.Elasticsearch([{ - 'host': config.config['elasticsearch']['host'], - 'port': config.config['elasticsearch']['port'], - }], **extra_args) +NpMatrix = np.ndarray def _preprocess_image(content: bytes) -> NpMatrix: - img = Image.open(BytesIO(content)) - return np.asarray(img.convert('L'), dtype=np.uint8) + try: + img = Image.open(BytesIO(content)) + return np.asarray(img.convert('L'), dtype=np.uint8) + except IOError: + raise errors.ProcessingError( + 'Unable to generate a signature hash ' + 'for this image.') def _crop_image( @@ -175,7 +169,31 @@ def _compute_differentials(grey_level_matrix: NpMatrix) -> NpMatrix: lower_right_neighbors])) -def _generate_signature(content: bytes) -> NpMatrix: +def _words_to_int(word_array: NpMatrix) -> List[int]: + width = word_array.shape[1] + coding_vector = 3**np.arange(width) + return np.dot(word_array + 1, coding_vector).astype(int).tolist() + + +def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix: + word_positions = np.linspace( + 0, array.shape[0], n, endpoint=False).astype('int') + assert k <= array.shape[0] + assert word_positions.shape[0] <= array.shape[0] + words = np.zeros((n, k)).astype('int8') + for i, pos in enumerate(word_positions): + if pos + k <= array.shape[0]: + words[i] = array[pos:pos + k] + else: + temp = array[pos:].copy() + temp.resize(k, refcheck=False) + words[i] = temp + words[words > 0] = 1 + words[words < 0] = -1 + return words + + +def generate_signature(content: bytes) -> NpMatrix: im_array = _preprocess_image(content) image_limits = _crop_image( im_array, @@ -192,40 +210,15 @@ def _generate_signature(content: bytes) -> NpMatrix: return np.ravel(diff_matrix).astype('int8') -def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix: - word_positions = np.linspace( - 0, array.shape[0], n, endpoint=False).astype('int') - assert k <= array.shape[0] - assert word_positions.shape[0] <= array.shape[0] - words = np.zeros((n, k)).astype('int8') - for i, pos in enumerate(word_positions): - if pos + k <= array.shape[0]: - words[i] = array[pos:pos + k] - else: - temp = array[pos:].copy() - temp.resize(k) - words[i] = temp - _max_contrast(words) - words = _words_to_int(words) - return words +def generate_words(signature: NpMatrix) -> List[int]: + return _words_to_int(_get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS)) -def _words_to_int(word_array: NpMatrix) -> NpMatrix: - width = word_array.shape[1] - coding_vector = 3**np.arange(width) - return np.dot(word_array + 1, coding_vector) - - -def _max_contrast(array: NpMatrix) -> None: - array[array > 0] = 1 - array[array < 0] = -1 - - -def _normalized_distance( - target_array: NpMatrix, +def normalized_distance( + target_array: Any, vec: NpMatrix, nan_value: float = 1.0) -> List[float]: - target_array = target_array.astype(int) + target_array = np.array(target_array).astype(int) vec = vec.astype(int) topvec = np.linalg.norm(vec - target_array, axis=1) norm1 = np.linalg.norm(vec, axis=0) @@ -235,124 +228,21 @@ def _normalized_distance( return finvec -def _safety_blanket(default_param_factory: Callable[[], Any]) -> Callable: - def wrapper_outer(target_function: Callable) -> Callable: - def wrapper_inner(*args: Any, **kwargs: Any) -> Any: - try: - return target_function(*args, **kwargs) - except elasticsearch.exceptions.NotFoundError: - # index not yet created, will be created dynamically by - # add_image() - return default_param_factory() - except elasticsearch.exceptions.ElasticsearchException as ex: - logger.warning('Problem with elastic search: %s', ex) - raise errors.ThirdPartyError( - 'Error connecting to elastic search.') - except IOError: - raise errors.ProcessingError('Not an image.') - except Exception as ex: - raise errors.ThirdPartyError('Unknown error (%s).' % ex) - return wrapper_inner - return wrapper_outer +def pack_signature(signature: NpMatrix) -> bytes: + base = 2 * N_LEVELS + 1 + coding_vector = np.flipud(SIG_BASE**np.arange(SIG_CHUNK_WIDTH)) + return np.array([ + np.dot(x, coding_vector) for x in + np.reshape(signature + N_LEVELS, (-1, SIG_CHUNK_WIDTH)) + ]).astype(f'uint{SIG_CHUNK_BITS}').tobytes() -class Lookalike: - def __init__(self, score: int, distance: float, path: Any) -> None: - self.score = score - self.distance = distance - self.path = path - - -@_safety_blanket(lambda: None) -def add_image(path: str, image_content: bytes) -> None: - assert path - assert image_content - signature = _generate_signature(image_content) - words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS) - - record = { - 'signature': signature.tolist(), - 'path': path, - 'timestamp': datetime.now(), - } - for i in range(MAX_WORDS): - record['simple_word_' + str(i)] = words[i].tolist() - - get_session().index( - index=config.config['elasticsearch']['index'], - doc_type=ES_DOC_TYPE, - body=record, - refresh=True) - - -@_safety_blanket(lambda: None) -def delete_image(path: str) -> None: - assert path - get_session().delete_by_query( - index=config.config['elasticsearch']['index'], - doc_type=ES_DOC_TYPE, - body={'query': {'term': {'path': path}}}) - - -@_safety_blanket(lambda: []) -def search_by_image(image_content: bytes) -> List[Lookalike]: - signature = _generate_signature(image_content) - words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS) - - res = get_session().search( - index=config.config['elasticsearch']['index'], - doc_type=ES_DOC_TYPE, - body={ - 'query': - { - 'bool': - { - 'should': - [ - {'term': {'simple_word_%d' % i: word.tolist()}} - for i, word in enumerate(words) - ] - } - }, - '_source': {'excludes': ['simple_word_*']}}, - size=ES_MAX_RESULTS, - timeout='10s')['hits']['hits'] - - if len(res) == 0: - return [] - - sigs = np.array([x['_source']['signature'] for x in res]) - dists = _normalized_distance(sigs, np.array(signature)) - - ids = set() # type: Set[int] - ret = [] - for item, dist in zip(res, dists): - id = item['_id'] - score = item['_score'] - path = item['_source']['path'] - if id in ids: - continue - ids.add(id) - if dist < DISTANCE_CUTOFF: - ret.append(Lookalike(score=score, distance=dist, path=path)) - return ret - - -@_safety_blanket(lambda: None) -def purge() -> None: - get_session().delete_by_query( - index=config.config['elasticsearch']['index'], - doc_type=ES_DOC_TYPE, - body={'query': {'match_all': {}}}, - refresh=True) - - -@_safety_blanket(lambda: set()) -def get_all_paths() -> Set[str]: - search = ( - elasticsearch_dsl.Search( - using=get_session(), - index=config.config['elasticsearch']['index'], - doc_type=ES_DOC_TYPE) - .source(['path'])) - return set(h.path for h in search.scan()) +def unpack_signature(packed: bytes) -> NpMatrix: + base = 2 * N_LEVELS + 1 + return np.ravel(np.array([ + [ + int(digit) - N_LEVELS for digit in + np.base_repr(e, base=SIG_BASE).zfill(SIG_CHUNK_WIDTH) + ] for e in + np.frombuffer(packed, dtype=f'uint{SIG_CHUNK_BITS}') + ]).astype('int8')) diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 73407f66..7b41a268 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -1,3 +1,4 @@ +import logging import hmac from typing import Any, Optional, Tuple, List, Dict, Callable from datetime import datetime @@ -8,6 +9,9 @@ from szurubooru.func import ( mime, images, files, image_hash, serialization, snapshots) +logger = logging.getLogger(__name__) + + EMPTY_PIXEL = ( b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' @@ -60,12 +64,6 @@ class InvalidPostFlagError(errors.ValidationError): pass -class PostLookalike(image_hash.Lookalike): - def __init__(self, score: int, distance: float, post: model.Post) -> None: - super().__init__(score, distance, post.post_id) - self.post = post - - SAFETY_MAP = { model.Post.SAFETY_SAFE: 'safe', model.Post.SAFETY_SKETCHY: 'sketchy', @@ -402,7 +400,6 @@ def _after_post_update( def _before_post_delete( _mapper: Any, _connection: Any, post: model.Post) -> None: if post.post_id: - image_hash.delete_image(post.post_id) if config.config['delete_source_files']: files.delete(get_post_content_path(post)) files.delete(get_post_thumbnail_path(post)) @@ -416,10 +413,6 @@ def _sync_post_content(post: model.Post) -> None: files.save(get_post_content_path(post), content) delattr(post, '__content') regenerate_thumb = True - if post.post_id and post.type in ( - model.Post.TYPE_IMAGE, model.Post.TYPE_ANIMATION): - image_hash.delete_image(post.post_id) - image_hash.add_image(post.post_id, content) if hasattr(post, '__thumbnail'): if getattr(post, '__thumbnail'): @@ -485,14 +478,56 @@ def test_sound(post: model.Post, content: bytes) -> None: update_post_flags(post, flags) +def purge_post_signature(post: model.Post) -> None: + old_signature = ( + db.session + .query(model.PostSignature) + .filter(model.PostSignature.post_id == post.post_id) + .one_or_none()) + if old_signature: + db.session.delete(old_signature) + + +def generate_post_signature(post: model.Post, content: bytes) -> None: + try: + unpacked_signature = image_hash.generate_signature(content) + packed_signature = image_hash.pack_signature(unpacked_signature) + words = image_hash.generate_words(unpacked_signature) + + db.session.add(model.PostSignature( + post=post, signature=packed_signature, words=words)) + except errors.ProcessingError: + if not config.config['allow_broken_uploads']: + raise InvalidPostContentError( + 'Unable to generate image hash data.') + + +def update_all_post_signatures() -> None: + posts_to_hash = ( + db.session + .query(model.Post) + .filter( + (model.Post.type == model.Post.TYPE_IMAGE) | + (model.Post.type == model.Post.TYPE_ANIMATION)) + .filter(model.Post.signature == None) + .order_by(model.Post.post_id.asc()) + .all()) + for post in posts_to_hash: + logger.info('Generating hash info for %d', post.post_id) + generate_post_signature(post, files.get(get_post_content_path(post))) + + def update_post_content(post: model.Post, content: Optional[bytes]) -> None: assert post if not content: raise InvalidPostContentError('Post content missing.') + + update_signature = False post.mime_type = mime.get_mime_type(content) if mime.is_flash(post.mime_type): post.type = model.Post.TYPE_FLASH elif mime.is_image(post.mime_type): + update_signature = True if mime.is_animated_gif(content): post.type = model.Post.TYPE_ANIMATION else: @@ -515,18 +550,30 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None: and other_post.post_id != post.post_id: raise PostAlreadyUploadedError(other_post) + if update_signature: + purge_post_signature(post) + post.signature = generate_post_signature(post, content) + post.file_size = len(content) try: image = images.Image(content) post.canvas_width = image.width post.canvas_height = image.height except errors.ProcessingError: - post.canvas_width = None - post.canvas_height = None + if not config.config['allow_broken_uploads']: + raise InvalidPostContentError( + 'Unable to process image metadata') + else: + post.canvas_width = None + post.canvas_height = None if (post.canvas_width is not None and post.canvas_width <= 0) \ or (post.canvas_height is not None and post.canvas_height <= 0): - post.canvas_width = None - post.canvas_height = None + if not config.config['allow_broken_uploads']: + raise InvalidPostContentError( + 'Invalid image dimensions returned during processing') + else: + post.canvas_width = None + post.canvas_height = None setattr(post, '__content', content) @@ -751,6 +798,8 @@ def merge_posts( if replace_content: content = files.get(get_post_content_path(source_post)) transfer_flags(source_post.post_id, target_post.post_id) + purge_post_signature(source_post) + purge_post_signature(target_post) delete(source_post) db.session.flush() @@ -768,38 +817,31 @@ def search_by_image_exact(image_content: bytes) -> Optional[model.Post]: .one_or_none()) -def search_by_image(image_content: bytes) -> List[PostLookalike]: - ret = [] - for result in image_hash.search_by_image(image_content): - post = try_get_post_by_id(result.path) - if post: - ret.append(PostLookalike( - score=result.score, - distance=result.distance, - post=post)) - return ret +def search_by_image(image_content: bytes) -> List[Tuple[float, model.Post]]: + query_signature = image_hash.generate_signature(image_content) + query_words = image_hash.generate_words(query_signature) + dbquery = ''' + SELECT s.post_id, s.signature, count(a.query) AS score + FROM post_signature AS s, unnest(s.words, :q) AS a(word, query) + WHERE a.word = a.query + GROUP BY s.post_id + ORDER BY score DESC LIMIT 100; + ''' -def populate_reverse_search() -> None: - excluded_post_ids = image_hash.get_all_paths() - - post_ids_to_hash = ( - db.session - .query(model.Post.post_id) - .filter( - (model.Post.type == model.Post.TYPE_IMAGE) | - (model.Post.type == model.Post.TYPE_ANIMATION)) - .filter(~model.Post.post_id.in_(excluded_post_ids)) - .order_by(model.Post.post_id.asc()) - .all()) - - for post_ids_chunk in util.chunks(post_ids_to_hash, 100): - posts_chunk = ( - db.session - .query(model.Post) - .filter(model.Post.post_id.in_(post_ids_chunk)) - .all()) - for post in posts_chunk: - content_path = get_post_content_path(post) - if files.has(content_path): - image_hash.add_image(post.post_id, files.get(content_path)) + candidates = db.session.execute(dbquery, {'q': query_words}) + data = tuple(zip(*[ + (post_id, image_hash.unpack_signature(packedsig)) + for post_id, packedsig, score in candidates + ])) + if data: + candidate_post_ids, sigarray = data + distances = image_hash.normalized_distance(sigarray, query_signature) + return [ + (distance, try_get_post_by_id(candidate_post_id)) + for candidate_post_id, distance + in zip(candidate_post_ids, distances) + if distance < image_hash.DISTANCE_CUTOFF + ] + else: + return [] diff --git a/server/szurubooru/model/__init__.py b/server/szurubooru/model/__init__.py index 4892b974..a043bd1f 100644 --- a/server/szurubooru/model/__init__.py +++ b/server/szurubooru/model/__init__.py @@ -9,7 +9,8 @@ from szurubooru.model.post import ( PostFavorite, PostScore, PostNote, - PostFeature) + PostFeature, + PostSignature) from szurubooru.model.comment import Comment, CommentScore from szurubooru.model.snapshot import Snapshot import szurubooru.model.util diff --git a/server/szurubooru/model/post.py b/server/szurubooru/model/post.py index 30e2900c..ec2cf471 100644 --- a/server/szurubooru/model/post.py +++ b/server/szurubooru/model/post.py @@ -143,6 +143,26 @@ class PostTag(Base): self.tag_id = tag_id +class PostSignature(Base): + __tablename__ = 'post_signature' + + post_id = sa.Column( + 'post_id', + sa.Integer, + sa.ForeignKey('post.id'), + primary_key=True, + nullable=False, + index=True) + signature = sa.Column('signature', sa.LargeBinary, nullable=False) + words = sa.Column( + 'words', + sa.dialects.postgresql.ARRAY(sa.Integer, dimensions=1), + nullable=False, + index=True) + + post = sa.orm.relationship('Post') + + class Post(Base): __tablename__ = 'post' @@ -184,6 +204,11 @@ class Post(Base): # foreign tables user = sa.orm.relationship('User') tags = sa.orm.relationship('Tag', backref='posts', secondary='post_tag') + signature = sa.orm.relationship( + 'PostSignature', + uselist=False, + cascade='all, delete-orphan', + lazy='joined') relations = sa.orm.relationship( 'Post', secondary='post_relation',