diff --git a/config.yaml.dist b/config.yaml.dist index a38393a0..1f4a2307 100644 --- a/config.yaml.dist +++ b/config.yaml.dist @@ -36,6 +36,7 @@ smtp: elasticsearch: host: localhost port: 9200 + index: szurubooru limits: users_per_page: 20 diff --git a/server/szurubooru/func/image_hash.py b/server/szurubooru/func/image_hash.py index 8f8c4d72..ae368488 100644 --- a/server/szurubooru/func/image_hash.py +++ b/server/szurubooru/func/image_hash.py @@ -12,7 +12,10 @@ es = elasticsearch.Elasticsearch([{ 'host': config.config['elasticsearch']['host'], 'port': config.config['elasticsearch']['port'], }]) -session = SignatureES(es, index='szurubooru') + + +def _get_session(): + return SignatureES(es, index=config.config['elasticsearch']['index']) def _safe_blanket(default_param_factory): @@ -48,6 +51,7 @@ class Lookalike: def add_image(path, image_content): if not path or not image_content: return + session = _get_session() session.add_image(path=path, img=image_content, bytestream=True) @@ -55,6 +59,7 @@ def add_image(path, image_content): def delete_image(path): if not path: return + session = _get_session() es.delete_by_query( index=session.index, doc_type=session.doc_type, @@ -64,6 +69,7 @@ def delete_image(path): @_safe_blanket(lambda: []) def search_by_image(image_content): ret = [] + session = _get_session() for result in session.search_image( path=image_content, # sic bytestream=True): @@ -76,6 +82,7 @@ def search_by_image(image_content): @_safe_blanket(lambda: None) def purge(): + session = _get_session() es.delete_by_query( index=session.index, doc_type=session.doc_type, @@ -84,6 +91,7 @@ def purge(): @_safe_blanket(lambda: set()) def get_all_paths(): + session = _get_session() search = ( elasticsearch_dsl.Search( using=es, index=session.index, doc_type=session.doc_type) diff --git a/server/szurubooru/tests/api/test_post_creating.py b/server/szurubooru/tests/api/test_post_creating.py index 341990c8..9737a73b 100644 --- a/server/szurubooru/tests/api/test_post_creating.py +++ b/server/szurubooru/tests/api/test_post_creating.py @@ -258,7 +258,8 @@ def test_omitting_optional_field( def test_errors_not_spending_ids( - config_injector, tmpdir, context_factory, read_asset, user_factory): + config_injector, tmpdir, context_factory, read_asset, user_factory, + skip_post_hashing): config_injector({ 'data_dir': str(tmpdir.mkdir('data')), 'data_url': 'example.com', diff --git a/server/szurubooru/tests/assets/jpeg-similar.jpg b/server/szurubooru/tests/assets/jpeg-similar.jpg new file mode 100644 index 00000000..af612092 Binary files /dev/null and b/server/szurubooru/tests/assets/jpeg-similar.jpg differ diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index 5dd47576..db34ee02 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -155,7 +155,7 @@ def tag_factory(): return factory -@pytest.yield_fixture(autouse=True) +@pytest.yield_fixture def skip_post_hashing(): with patch('szurubooru.func.image_hash.add_image'), \ patch('szurubooru.func.image_hash.delete_image'): @@ -163,7 +163,7 @@ def skip_post_hashing(): @pytest.fixture -def post_factory(): +def post_factory(skip_post_hashing): # pylint: disable=invalid-name def factory( id=None, diff --git a/server/szurubooru/tests/func/test_image_hash.py b/server/szurubooru/tests/func/test_image_hash.py new file mode 100644 index 00000000..3bbdeba2 --- /dev/null +++ b/server/szurubooru/tests/func/test_image_hash.py @@ -0,0 +1,24 @@ +from time import sleep +from szurubooru.func import image_hash + + +def test_hashing(read_asset, config_injector): + config_injector({'elasticsearch': {'index': 'szurubooru_test'}}) + image_hash.purge() + image_hash.add_image('test', read_asset('jpeg.jpg')) + + sleep(0.1) + + paths = image_hash.get_all_paths() + results_exact = image_hash.search_by_image(read_asset('jpeg.jpg')) + results_similar = image_hash.search_by_image(read_asset('jpeg-similar.jpg')) + + assert len(paths) == 1 + assert len(results_exact) == 1 + assert len(results_similar) == 1 + assert results_exact[0].path == 'test' + assert results_exact[0].score == 63 + assert results_exact[0].distance == 0 + assert results_similar[0].path == 'test' + assert results_similar[0].score == 26 + assert abs(results_similar[0].distance - 0.189390583) < 1e-8 diff --git a/server/szurubooru/tests/func/test_posts.py b/server/szurubooru/tests/func/test_posts.py index f44fb03b..a10cc6d9 100644 --- a/server/szurubooru/tests/func/test_posts.py +++ b/server/szurubooru/tests/func/test_posts.py @@ -290,15 +290,8 @@ def test_update_post_source_with_too_long_string(): ), ]) def test_update_post_content_for_new_post( - tmpdir, - config_injector, - post_factory, - read_asset, - is_existing, - input_file, - expected_mime_type, - expected_type, - output_file_name): + tmpdir, config_injector, post_factory, read_asset, is_existing, + input_file, expected_mime_type, expected_type, output_file_name): with patch('szurubooru.func.util.get_sha1'): util.get_sha1.return_value = 'crc' config_injector({