server/image_search: implement reverse search functionality in postgres

This will remove the dependency on the Elasticsearch database.

The search query is passed currently as raw SQL. Proper implementation
using SQLAlchemy will need custom ORM classed to be made.

Additional config parameter "allow_broken_uploads" has been added.
This commit is contained in:
Shyam Sunder 2020-03-07 20:43:20 -05:00
parent a616cf6987
commit 4c78cf8c47
7 changed files with 187 additions and 231 deletions

View file

@ -21,11 +21,15 @@ thumbnails:
post_width: 300
post_height: 300
# automatically convert animated GIF uploads to video formats
convert:
gif:
to_webm: false
to_mp4: false
# allow posts to be uploaded even if some image processing errors occur
allow_broken_uploads: false
# used to send password reset e-mails
smtp:
host: # example: localhost

View file

@ -262,9 +262,9 @@ def get_posts_by_image(
'similarPosts':
[
{
'distance': lookalike.distance,
'post': _serialize_post(ctx, lookalike.post),
'distance': distance,
'post': _serialize_post(ctx, post),
}
for lookalike in lookalikes
for distance, post in lookalikes
],
}

View file

@ -8,7 +8,7 @@ import coloredlogs
import sqlalchemy as sa
import sqlalchemy.orm.exc
from szurubooru import config, db, errors, rest
from szurubooru.func import posts, file_uploads, image_hash
from szurubooru.func import file_uploads
# pylint: disable=unused-import
from szurubooru import api, middleware
@ -129,13 +129,7 @@ def create_app() -> Callable[[Any, Any], Any]:
purge_thread.daemon = True
purge_thread.start()
try:
image_hash.get_session().cluster.health(
wait_for_status='yellow', request_timeout=120)
posts.populate_reverse_search()
db.session.commit()
except errors.ThirdPartyError:
pass
rest.errors.handle(errors.AuthError, _on_auth_error)
rest.errors.handle(errors.ValidationError, _on_validation_error)

View file

@ -2,8 +2,7 @@ import logging
from io import BytesIO
from datetime import datetime
from typing import Any, Optional, Tuple, Set, List, Callable
import elasticsearch
import elasticsearch_dsl
import math
import numpy as np
from PIL import Image
from szurubooru import config, errors
@ -24,30 +23,25 @@ N = 9
P = None
SAMPLE_WORDS = 16
MAX_WORDS = 63
ES_DOC_TYPE = 'image'
ES_MAX_RESULTS = 100
SIG_CHUNK_BITS = 32
SIG_BASE = 2*N_LEVELS + 2
SIG_CHUNK_WIDTH = int(SIG_CHUNK_BITS / math.log2(SIG_BASE))
SIG_CHUNK_NUMS = 8*N*N / SIG_CHUNK_WIDTH
assert 8*N*N % SIG_CHUNK_WIDTH == 0
Window = Tuple[Tuple[float, float], Tuple[float, float]]
NpMatrix = Any
def get_session() -> elasticsearch.Elasticsearch:
extra_args = {}
if config.config['elasticsearch']['pass']:
extra_args['http_auth'] = (
config.config['elasticsearch']['user'],
config.config['elasticsearch']['pass'])
extra_args['scheme'] = 'https'
extra_args['port'] = 443
return elasticsearch.Elasticsearch([{
'host': config.config['elasticsearch']['host'],
'port': config.config['elasticsearch']['port'],
}], **extra_args)
NpMatrix = np.ndarray
def _preprocess_image(content: bytes) -> NpMatrix:
try:
img = Image.open(BytesIO(content))
return np.asarray(img.convert('L'), dtype=np.uint8)
except IOError:
raise errors.ProcessingError(
'Unable to generate a signature hash '
'for this image.')
def _crop_image(
@ -175,7 +169,31 @@ def _compute_differentials(grey_level_matrix: NpMatrix) -> NpMatrix:
lower_right_neighbors]))
def _generate_signature(content: bytes) -> NpMatrix:
def _words_to_int(word_array: NpMatrix) -> List[int]:
width = word_array.shape[1]
coding_vector = 3**np.arange(width)
return np.dot(word_array + 1, coding_vector).astype(int).tolist()
def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix:
word_positions = np.linspace(
0, array.shape[0], n, endpoint=False).astype('int')
assert k <= array.shape[0]
assert word_positions.shape[0] <= array.shape[0]
words = np.zeros((n, k)).astype('int8')
for i, pos in enumerate(word_positions):
if pos + k <= array.shape[0]:
words[i] = array[pos:pos + k]
else:
temp = array[pos:].copy()
temp.resize(k, refcheck=False)
words[i] = temp
words[words > 0] = 1
words[words < 0] = -1
return words
def generate_signature(content: bytes) -> NpMatrix:
im_array = _preprocess_image(content)
image_limits = _crop_image(
im_array,
@ -192,40 +210,15 @@ def _generate_signature(content: bytes) -> NpMatrix:
return np.ravel(diff_matrix).astype('int8')
def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix:
word_positions = np.linspace(
0, array.shape[0], n, endpoint=False).astype('int')
assert k <= array.shape[0]
assert word_positions.shape[0] <= array.shape[0]
words = np.zeros((n, k)).astype('int8')
for i, pos in enumerate(word_positions):
if pos + k <= array.shape[0]:
words[i] = array[pos:pos + k]
else:
temp = array[pos:].copy()
temp.resize(k)
words[i] = temp
_max_contrast(words)
words = _words_to_int(words)
return words
def generate_words(signature: NpMatrix) -> List[int]:
return _words_to_int(_get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS))
def _words_to_int(word_array: NpMatrix) -> NpMatrix:
width = word_array.shape[1]
coding_vector = 3**np.arange(width)
return np.dot(word_array + 1, coding_vector)
def _max_contrast(array: NpMatrix) -> None:
array[array > 0] = 1
array[array < 0] = -1
def _normalized_distance(
target_array: NpMatrix,
def normalized_distance(
target_array: Any,
vec: NpMatrix,
nan_value: float = 1.0) -> List[float]:
target_array = target_array.astype(int)
target_array = np.array(target_array).astype(int)
vec = vec.astype(int)
topvec = np.linalg.norm(vec - target_array, axis=1)
norm1 = np.linalg.norm(vec, axis=0)
@ -235,124 +228,21 @@ def _normalized_distance(
return finvec
def _safety_blanket(default_param_factory: Callable[[], Any]) -> Callable:
def wrapper_outer(target_function: Callable) -> Callable:
def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
try:
return target_function(*args, **kwargs)
except elasticsearch.exceptions.NotFoundError:
# index not yet created, will be created dynamically by
# add_image()
return default_param_factory()
except elasticsearch.exceptions.ElasticsearchException as ex:
logger.warning('Problem with elastic search: %s', ex)
raise errors.ThirdPartyError(
'Error connecting to elastic search.')
except IOError:
raise errors.ProcessingError('Not an image.')
except Exception as ex:
raise errors.ThirdPartyError('Unknown error (%s).' % ex)
return wrapper_inner
return wrapper_outer
def pack_signature(signature: NpMatrix) -> bytes:
base = 2 * N_LEVELS + 1
coding_vector = np.flipud(SIG_BASE**np.arange(SIG_CHUNK_WIDTH))
return np.array([
np.dot(x, coding_vector) for x in
np.reshape(signature + N_LEVELS, (-1, SIG_CHUNK_WIDTH))
]).astype(f'uint{SIG_CHUNK_BITS}').tobytes()
class Lookalike:
def __init__(self, score: int, distance: float, path: Any) -> None:
self.score = score
self.distance = distance
self.path = path
@_safety_blanket(lambda: None)
def add_image(path: str, image_content: bytes) -> None:
assert path
assert image_content
signature = _generate_signature(image_content)
words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS)
record = {
'signature': signature.tolist(),
'path': path,
'timestamp': datetime.now(),
}
for i in range(MAX_WORDS):
record['simple_word_' + str(i)] = words[i].tolist()
get_session().index(
index=config.config['elasticsearch']['index'],
doc_type=ES_DOC_TYPE,
body=record,
refresh=True)
@_safety_blanket(lambda: None)
def delete_image(path: str) -> None:
assert path
get_session().delete_by_query(
index=config.config['elasticsearch']['index'],
doc_type=ES_DOC_TYPE,
body={'query': {'term': {'path': path}}})
@_safety_blanket(lambda: [])
def search_by_image(image_content: bytes) -> List[Lookalike]:
signature = _generate_signature(image_content)
words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS)
res = get_session().search(
index=config.config['elasticsearch']['index'],
doc_type=ES_DOC_TYPE,
body={
'query':
{
'bool':
{
'should':
def unpack_signature(packed: bytes) -> NpMatrix:
base = 2 * N_LEVELS + 1
return np.ravel(np.array([
[
{'term': {'simple_word_%d' % i: word.tolist()}}
for i, word in enumerate(words)
]
}
},
'_source': {'excludes': ['simple_word_*']}},
size=ES_MAX_RESULTS,
timeout='10s')['hits']['hits']
if len(res) == 0:
return []
sigs = np.array([x['_source']['signature'] for x in res])
dists = _normalized_distance(sigs, np.array(signature))
ids = set() # type: Set[int]
ret = []
for item, dist in zip(res, dists):
id = item['_id']
score = item['_score']
path = item['_source']['path']
if id in ids:
continue
ids.add(id)
if dist < DISTANCE_CUTOFF:
ret.append(Lookalike(score=score, distance=dist, path=path))
return ret
@_safety_blanket(lambda: None)
def purge() -> None:
get_session().delete_by_query(
index=config.config['elasticsearch']['index'],
doc_type=ES_DOC_TYPE,
body={'query': {'match_all': {}}},
refresh=True)
@_safety_blanket(lambda: set())
def get_all_paths() -> Set[str]:
search = (
elasticsearch_dsl.Search(
using=get_session(),
index=config.config['elasticsearch']['index'],
doc_type=ES_DOC_TYPE)
.source(['path']))
return set(h.path for h in search.scan())
int(digit) - N_LEVELS for digit in
np.base_repr(e, base=SIG_BASE).zfill(SIG_CHUNK_WIDTH)
] for e in
np.frombuffer(packed, dtype=f'uint{SIG_CHUNK_BITS}')
]).astype('int8'))

View file

@ -1,3 +1,4 @@
import logging
import hmac
from typing import Any, Optional, Tuple, List, Dict, Callable
from datetime import datetime
@ -8,6 +9,9 @@ from szurubooru.func import (
mime, images, files, image_hash, serialization, snapshots)
logger = logging.getLogger(__name__)
EMPTY_PIXEL = (
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
@ -60,12 +64,6 @@ class InvalidPostFlagError(errors.ValidationError):
pass
class PostLookalike(image_hash.Lookalike):
def __init__(self, score: int, distance: float, post: model.Post) -> None:
super().__init__(score, distance, post.post_id)
self.post = post
SAFETY_MAP = {
model.Post.SAFETY_SAFE: 'safe',
model.Post.SAFETY_SKETCHY: 'sketchy',
@ -402,7 +400,6 @@ def _after_post_update(
def _before_post_delete(
_mapper: Any, _connection: Any, post: model.Post) -> None:
if post.post_id:
image_hash.delete_image(post.post_id)
if config.config['delete_source_files']:
files.delete(get_post_content_path(post))
files.delete(get_post_thumbnail_path(post))
@ -416,10 +413,6 @@ def _sync_post_content(post: model.Post) -> None:
files.save(get_post_content_path(post), content)
delattr(post, '__content')
regenerate_thumb = True
if post.post_id and post.type in (
model.Post.TYPE_IMAGE, model.Post.TYPE_ANIMATION):
image_hash.delete_image(post.post_id)
image_hash.add_image(post.post_id, content)
if hasattr(post, '__thumbnail'):
if getattr(post, '__thumbnail'):
@ -485,14 +478,56 @@ def test_sound(post: model.Post, content: bytes) -> None:
update_post_flags(post, flags)
def purge_post_signature(post: model.Post) -> None:
old_signature = (
db.session
.query(model.PostSignature)
.filter(model.PostSignature.post_id == post.post_id)
.one_or_none())
if old_signature:
db.session.delete(old_signature)
def generate_post_signature(post: model.Post, content: bytes) -> None:
try:
unpacked_signature = image_hash.generate_signature(content)
packed_signature = image_hash.pack_signature(unpacked_signature)
words = image_hash.generate_words(unpacked_signature)
db.session.add(model.PostSignature(
post=post, signature=packed_signature, words=words))
except errors.ProcessingError:
if not config.config['allow_broken_uploads']:
raise InvalidPostContentError(
'Unable to generate image hash data.')
def update_all_post_signatures() -> None:
posts_to_hash = (
db.session
.query(model.Post)
.filter(
(model.Post.type == model.Post.TYPE_IMAGE) |
(model.Post.type == model.Post.TYPE_ANIMATION))
.filter(model.Post.signature == None)
.order_by(model.Post.post_id.asc())
.all())
for post in posts_to_hash:
logger.info('Generating hash info for %d', post.post_id)
generate_post_signature(post, files.get(get_post_content_path(post)))
def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
assert post
if not content:
raise InvalidPostContentError('Post content missing.')
update_signature = False
post.mime_type = mime.get_mime_type(content)
if mime.is_flash(post.mime_type):
post.type = model.Post.TYPE_FLASH
elif mime.is_image(post.mime_type):
update_signature = True
if mime.is_animated_gif(content):
post.type = model.Post.TYPE_ANIMATION
else:
@ -515,16 +550,28 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
and other_post.post_id != post.post_id:
raise PostAlreadyUploadedError(other_post)
if update_signature:
purge_post_signature(post)
post.signature = generate_post_signature(post, content)
post.file_size = len(content)
try:
image = images.Image(content)
post.canvas_width = image.width
post.canvas_height = image.height
except errors.ProcessingError:
if not config.config['allow_broken_uploads']:
raise InvalidPostContentError(
'Unable to process image metadata')
else:
post.canvas_width = None
post.canvas_height = None
if (post.canvas_width is not None and post.canvas_width <= 0) \
or (post.canvas_height is not None and post.canvas_height <= 0):
if not config.config['allow_broken_uploads']:
raise InvalidPostContentError(
'Invalid image dimensions returned during processing')
else:
post.canvas_width = None
post.canvas_height = None
setattr(post, '__content', content)
@ -751,6 +798,8 @@ def merge_posts(
if replace_content:
content = files.get(get_post_content_path(source_post))
transfer_flags(source_post.post_id, target_post.post_id)
purge_post_signature(source_post)
purge_post_signature(target_post)
delete(source_post)
db.session.flush()
@ -768,38 +817,31 @@ def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
.one_or_none())
def search_by_image(image_content: bytes) -> List[PostLookalike]:
ret = []
for result in image_hash.search_by_image(image_content):
post = try_get_post_by_id(result.path)
if post:
ret.append(PostLookalike(
score=result.score,
distance=result.distance,
post=post))
return ret
def search_by_image(image_content: bytes) -> List[Tuple[float, model.Post]]:
query_signature = image_hash.generate_signature(image_content)
query_words = image_hash.generate_words(query_signature)
dbquery = '''
SELECT s.post_id, s.signature, count(a.query) AS score
FROM post_signature AS s, unnest(s.words, :q) AS a(word, query)
WHERE a.word = a.query
GROUP BY s.post_id
ORDER BY score DESC LIMIT 100;
'''
def populate_reverse_search() -> None:
excluded_post_ids = image_hash.get_all_paths()
post_ids_to_hash = (
db.session
.query(model.Post.post_id)
.filter(
(model.Post.type == model.Post.TYPE_IMAGE) |
(model.Post.type == model.Post.TYPE_ANIMATION))
.filter(~model.Post.post_id.in_(excluded_post_ids))
.order_by(model.Post.post_id.asc())
.all())
for post_ids_chunk in util.chunks(post_ids_to_hash, 100):
posts_chunk = (
db.session
.query(model.Post)
.filter(model.Post.post_id.in_(post_ids_chunk))
.all())
for post in posts_chunk:
content_path = get_post_content_path(post)
if files.has(content_path):
image_hash.add_image(post.post_id, files.get(content_path))
candidates = db.session.execute(dbquery, {'q': query_words})
data = tuple(zip(*[
(post_id, image_hash.unpack_signature(packedsig))
for post_id, packedsig, score in candidates
]))
if data:
candidate_post_ids, sigarray = data
distances = image_hash.normalized_distance(sigarray, query_signature)
return [
(distance, try_get_post_by_id(candidate_post_id))
for candidate_post_id, distance
in zip(candidate_post_ids, distances)
if distance < image_hash.DISTANCE_CUTOFF
]
else:
return []

View file

@ -9,7 +9,8 @@ from szurubooru.model.post import (
PostFavorite,
PostScore,
PostNote,
PostFeature)
PostFeature,
PostSignature)
from szurubooru.model.comment import Comment, CommentScore
from szurubooru.model.snapshot import Snapshot
import szurubooru.model.util

View file

@ -143,6 +143,26 @@ class PostTag(Base):
self.tag_id = tag_id
class PostSignature(Base):
__tablename__ = 'post_signature'
post_id = sa.Column(
'post_id',
sa.Integer,
sa.ForeignKey('post.id'),
primary_key=True,
nullable=False,
index=True)
signature = sa.Column('signature', sa.LargeBinary, nullable=False)
words = sa.Column(
'words',
sa.dialects.postgresql.ARRAY(sa.Integer, dimensions=1),
nullable=False,
index=True)
post = sa.orm.relationship('Post')
class Post(Base):
__tablename__ = 'post'
@ -184,6 +204,11 @@ class Post(Base):
# foreign tables
user = sa.orm.relationship('User')
tags = sa.orm.relationship('Tag', backref='posts', secondary='post_tag')
signature = sa.orm.relationship(
'PostSignature',
uselist=False,
cascade='all, delete-orphan',
lazy='joined')
relations = sa.orm.relationship(
'Post',
secondary='post_relation',