From 1a59a74d634f74627800816b2378fa59a0c81292 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 27 Nov 2016 18:42:14 +0100 Subject: [PATCH] server/image-hash: add image search engine --- INSTALL.md | 8 +++ config.yaml.dist | 5 ++ server/requirements.txt | 4 ++ server/szurubooru/facade.py | 3 ++ server/szurubooru/func/image_hash.py | 59 ++++++++++++++++++++++ server/szurubooru/func/posts.py | 45 ++++++++++++++++- server/szurubooru/func/util.py | 5 ++ server/szurubooru/tests/conftest.py | 8 +++ server/szurubooru/tests/func/test_posts.py | 12 ++++- 9 files changed, 145 insertions(+), 4 deletions(-) create mode 100644 server/szurubooru/func/image_hash.py diff --git a/INSTALL.md b/INSTALL.md index bf758338..c4ccdcc3 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -9,6 +9,7 @@ user@host:~$ sudo pacman -S python user@host:~$ sudo pacman -S python-pip user@host:~$ sudo pacman -S ffmpeg user@host:~$ sudo pacman -S npm +user@host:~$ sudo pacman -S elasticsearch user@host:~$ sudo pip install virtualenv user@host:~$ python --version Python 3.5.1 @@ -43,6 +44,13 @@ user@host:~$ sudo -i -u postgres psql -c "ALTER USER szuru PASSWORD 'dog';" +### Setting up elasticsearch + +```console +user@host:~$ sudo systemctl start elasticsearch +user@host:~$ sudo systemctl enable elasticsearch +``` + ### Preparing environment Getting `szurubooru`: diff --git a/config.yaml.dist b/config.yaml.dist index 0e18de47..e59ab4fb 100644 --- a/config.yaml.dist +++ b/config.yaml.dist @@ -32,6 +32,11 @@ smtp: user: # example: bot pass: # example: groovy123 +# used for reverse image search +elasticsearch: + host: localhost + port: 9200 + limits: users_per_page: 20 posts_per_page: 40 diff --git a/server/requirements.txt b/server/requirements.txt index 6d172cee..c200222d 100644 --- a/server/requirements.txt +++ b/server/requirements.txt @@ -7,3 +7,7 @@ pytest-cov>=2.2.1 freezegun>=0.3.6 coloredlogs==5.0 pycodestyle>=2.0.0 +image-match>=1.1.0 +scipy>=0.18.1 +elasticsearch>=5.0.0 +elasticsearch-dsl>=5.0.0 diff --git a/server/szurubooru/facade.py b/server/szurubooru/facade.py index 1d26aadb..5211ca81 100644 --- a/server/szurubooru/facade.py +++ b/server/szurubooru/facade.py @@ -5,6 +5,7 @@ import logging import coloredlogs import sqlalchemy.orm.exc from szurubooru import config, errors, rest +from szurubooru.func import posts # pylint: disable=unused-import from szurubooru import api, middleware @@ -87,6 +88,8 @@ def create_app(): if config.config['show_sql']: logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO) + posts.populate_reverse_search() + rest.errors.handle(errors.AuthError, _on_auth_error) rest.errors.handle(errors.ValidationError, _on_validation_error) rest.errors.handle(errors.SearchError, _on_search_error) diff --git a/server/szurubooru/func/image_hash.py b/server/szurubooru/func/image_hash.py new file mode 100644 index 00000000..9c392c5f --- /dev/null +++ b/server/szurubooru/func/image_hash.py @@ -0,0 +1,59 @@ +import elasticsearch +import elasticsearch_dsl +from image_match.elasticsearch_driver import SignatureES +from szurubooru import config + + +# pylint: disable=invalid-name +es = elasticsearch.Elasticsearch([{ + 'host': config.config['elasticsearch']['host'], + 'port': config.config['elasticsearch']['port'], +}]) +session = SignatureES(es, index='szurubooru') + + +def add_image(path, image_content): + if not path or not image_content: + return + session.add_image(path=path, img=image_content, bytestream=True) + + +def delete_image(path): + if not path: + return + try: + es.delete_by_query( + index=session.index, + doc_type=session.doc_type, + body={'query': {'term': {'path': path}}}) + except elasticsearch.exceptions.NotFoundError: + pass + + +def search_by_image(image_content): + for result in session.search_image( + path=image_content, # sic + bytestream=True): + yield { + 'score': result['score'], + 'dist': result['dist'], + 'path': result['path'], + } + + +def purge(): + es.delete_by_query( + index=session.index, + doc_type=session.doc_type, + body={'query': {'match_all': {}}}) + + +def get_all_paths(): + try: + search = ( + elasticsearch_dsl.Search( + using=es, index=session.index, doc_type=session.doc_type) + .source(['path'])) + return set(h.path for h in search.scan()) + except elasticsearch.exceptions.NotFoundError: + return set() diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index e03fe637..6ff3a87b 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -2,7 +2,7 @@ import datetime import sqlalchemy from szurubooru import config, db, errors from szurubooru.func import ( - users, scores, comments, tags, util, mime, images, files) + users, scores, comments, tags, util, mime, images, files, image_hash) EMPTY_PIXEL = \ @@ -260,13 +260,22 @@ def _after_post_update(_mapper, _connection, post): _sync_post_content(post) +@sqlalchemy.events.event.listens_for(db.Post, 'before_delete') +def _before_post_delete(_mapper, _connection, post): + image_hash.delete_image(post.post_id) + + def _sync_post_content(post): regenerate_thumb = False if hasattr(post, '__content'): - files.save(get_post_content_path(post), getattr(post, '__content')) + content = getattr(post, '__content') + files.save(get_post_content_path(post), content) delattr(post, '__content') regenerate_thumb = True + if post.type in (db.Post.TYPE_IMAGE, db.Post.TYPE_ANIMATION): + image_hash.delete_image(post.post_id) + image_hash.add_image(post.post_id, content) if hasattr(post, '__thumbnail'): if getattr(post, '__thumbnail'): @@ -521,3 +530,35 @@ def merge_posts(source_post, target_post, replace_content): if replace_content: content = files.get(get_post_content_path(source_post)) update_post_content(target_post, content) + + +def search_by_image(image_content): + for result in image_hash.search_by_image(image_content): + yield { + 'score': result['score'], + 'dist': result['dist'], + 'post': get_post_by_id(result['path']) + } + + +def populate_reverse_search(): + excluded_post_ids = image_hash.get_all_paths() + + post_ids_to_hash = (db.session + .query(db.Post.post_id) + .filter( + (db.Post.type == db.Post.TYPE_IMAGE) | + (db.Post.type == db.Post.TYPE_ANIMATION)) + .filter(~db.Post.post_id.in_(excluded_post_ids)) + .order_by(db.Post.post_id.asc()) + .all()) + + for post_ids_chunk in util.chunks(post_ids_to_hash, 100): + posts_chunk = (db.session + .query(db.Post) + .filter(db.Post.post_id.in_(post_ids_chunk)) + .all()) + for post in posts_chunk: + content_path = get_post_content_path(post) + if files.has(content_path): + image_hash.add_image(post.post_id, files.get(content_path)) diff --git a/server/szurubooru/func/util.py b/server/szurubooru/func/util.py index 14c4440b..497878ee 100644 --- a/server/szurubooru/func/util.py +++ b/server/szurubooru/func/util.py @@ -162,3 +162,8 @@ def value_exceeds_column_size(value, column): if max_length is None: return False return len(value) > max_length + + +def chunks(source_list, part_size): + for i in range(0, len(source_list), part_size): + yield source_list[i:i + part_size] diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index 642ac6a4..5dd47576 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -3,6 +3,7 @@ import contextlib import os import random import string +from unittest.mock import patch from datetime import datetime import pytest import freezegun @@ -154,6 +155,13 @@ def tag_factory(): return factory +@pytest.yield_fixture(autouse=True) +def skip_post_hashing(): + with patch('szurubooru.func.image_hash.add_image'), \ + patch('szurubooru.func.image_hash.delete_image'): + yield + + @pytest.fixture def post_factory(): # pylint: disable=invalid-name diff --git a/server/szurubooru/tests/func/test_posts.py b/server/szurubooru/tests/func/test_posts.py index 5b462941..6465f51e 100644 --- a/server/szurubooru/tests/func/test_posts.py +++ b/server/szurubooru/tests/func/test_posts.py @@ -3,7 +3,8 @@ from unittest.mock import patch from datetime import datetime import pytest from szurubooru import db -from szurubooru.func import (posts, users, comments, tags, images, files, util) +from szurubooru.func import ( + posts, users, comments, tags, images, files, util, image_hash) @pytest.mark.parametrize('input_mime_type,expected_url', [ @@ -316,13 +317,20 @@ def test_update_post_content_for_new_post( else: assert not post.post_id assert not os.path.exists(output_file_path) - posts.update_post_content(post, read_asset(input_file)) + content = read_asset(input_file) + posts.update_post_content(post, content) assert not os.path.exists(output_file_path) db.session.flush() assert post.mime_type == expected_mime_type assert post.type == expected_type assert post.checksum == 'crc' assert os.path.exists(output_file_path) + if post.type in (db.Post.TYPE_IMAGE, db.Post.TYPE_ANIMATION): + image_hash.delete_image.assert_called_once_with(post.post_id) + image_hash.add_image.assert_called_once_with(post.post_id, content) + else: + image_hash.delete_image.assert_not_called() + image_hash.add_image.assert_not_called() def test_update_post_content_to_existing_content(