server/image_search: implement reverse search functionality in postgres

This will remove the dependency on the Elasticsearch database.

The search query is currently passed as raw SQL. A proper implementation
using SQLAlchemy will require custom ORM classes to be written.

An additional config parameter, "allow_broken_uploads", has been added.
Author: Shyam Sunder
Date:   2020-03-07 20:43:20 -05:00
parent a616cf6987
commit 4c78cf8c47
7 changed files with 187 additions and 231 deletions
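
For context on the interim approach mentioned above: candidate posts are
selected by a single Postgres query that scores each stored signature by how
many of its 63 sampled words match the query image's words at the same
positions (unnest() over two arrays zips them element by element). A minimal
sketch of that lookup using the helpers introduced in this commit; the input
file name is invented for illustration:

    from szurubooru import db
    from szurubooru.func import image_hash

    # Hypothetical query image; signature and words come from the new
    # public helpers in szurubooru/func/image_hash.py.
    content = open('query.jpg', 'rb').read()
    signature = image_hash.generate_signature(content)
    words = image_hash.generate_words(signature)  # 63 ints, one per word

    # unnest(s.words, :q) expands both arrays side by side, so the
    # WHERE clause counts positional matches; the 100 best scorers are
    # then re-ranked by true signature distance.
    candidates = db.session.execute(
        '''
        SELECT s.post_id, s.signature, count(a.query) AS score
        FROM post_signature AS s, unnest(s.words, :q) AS a(word, query)
        WHERE a.word = a.query
        GROUP BY s.post_id
        ORDER BY score DESC LIMIT 100;
        ''',
        {'q': words})
    for post_id, packed_signature, score in candidates:
        print(post_id, score)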

server/config.yaml.dist

@@ -21,11 +21,15 @@ thumbnails:
   post_width: 300
   post_height: 300
 
+# automatically convert animated GIF uploads to video formats
 convert:
   gif:
     to_webm: false
     to_mp4: false
 
+# allow posts to be uploaded even if some image processing errors occur
+allow_broken_uploads: false
+
 # used to send password reset e-mails
 smtp:
   host: # example: localhost

server/szurubooru/api/post_api.py

@@ -262,9 +262,9 @@ def get_posts_by_image(
         'similarPosts':
         [
             {
-                'distance': lookalike.distance,
-                'post': _serialize_post(ctx, lookalike.post),
+                'distance': distance,
+                'post': _serialize_post(ctx, post),
             }
-            for lookalike in lookalikes
+            for distance, post in lookalikes
         ],
     }
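
The serializer change is purely mechanical: posts.search_by_image() now
returns plain (distance, post) tuples instead of PostLookalike objects, so
the response shape is unchanged. Roughly, with invented values:

    # Illustrative fragment of the reverse-search response; in reality
    # 'post' carries the full serialized post resource, not just an id.
    response_fragment = {
        'similarPosts': [
            {'distance': 0.18, 'post': {'id': 1337}},
            {'distance': 0.42, 'post': {'id': 42}},
        ],
    }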

server/szurubooru/facade.py

@@ -8,7 +8,7 @@ import coloredlogs
 import sqlalchemy as sa
 import sqlalchemy.orm.exc
 from szurubooru import config, db, errors, rest
-from szurubooru.func import posts, file_uploads, image_hash
+from szurubooru.func import file_uploads
 # pylint: disable=unused-import
 from szurubooru import api, middleware
@@ -129,13 +129,7 @@ def create_app() -> Callable[[Any, Any], Any]:
     purge_thread.daemon = True
     purge_thread.start()
 
-    try:
-        image_hash.get_session().cluster.health(
-            wait_for_status='yellow', request_timeout=120)
-        posts.populate_reverse_search()
-        db.session.commit()
-    except errors.ThirdPartyError:
-        pass
+    db.session.commit()
 
     rest.errors.handle(errors.AuthError, _on_auth_error)
     rest.errors.handle(errors.ValidationError, _on_validation_error)
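
With the Elasticsearch health check and populate_reverse_search() call gone
from create_app(), signatures for pre-existing posts are no longer built at
startup. A hypothetical one-off backfill using the helper added to
szurubooru/func/posts.py below (the script itself is not part of this
commit):

    from szurubooru import db
    from szurubooru.func import posts

    # Hashes every image/animation post that has no signature row yet;
    # see update_all_post_signatures() in the posts.py diff below.
    posts.update_all_post_signatures()
    db.session.commit()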

server/szurubooru/func/image_hash.py

@@ -2,8 +2,7 @@ import logging
 from io import BytesIO
 from datetime import datetime
 from typing import Any, Optional, Tuple, Set, List, Callable
-import elasticsearch
-import elasticsearch_dsl
+import math
 import numpy as np
 from PIL import Image
 from szurubooru import config, errors
@@ -24,30 +23,25 @@ N = 9
 P = None
 SAMPLE_WORDS = 16
 MAX_WORDS = 63
-ES_DOC_TYPE = 'image'
-ES_MAX_RESULTS = 100
+SIG_CHUNK_BITS = 32
+
+SIG_BASE = 2*N_LEVELS + 2
+SIG_CHUNK_WIDTH = int(SIG_CHUNK_BITS / math.log2(SIG_BASE))
+SIG_CHUNK_NUMS = 8*N*N / SIG_CHUNK_WIDTH
+assert 8*N*N % SIG_CHUNK_WIDTH == 0
 
 Window = Tuple[Tuple[float, float], Tuple[float, float]]
-NpMatrix = Any
+NpMatrix = np.ndarray
 
 
-def get_session() -> elasticsearch.Elasticsearch:
-    extra_args = {}
-    if config.config['elasticsearch']['pass']:
-        extra_args['http_auth'] = (
-            config.config['elasticsearch']['user'],
-            config.config['elasticsearch']['pass'])
-        extra_args['scheme'] = 'https'
-        extra_args['port'] = 443
-    return elasticsearch.Elasticsearch([{
-        'host': config.config['elasticsearch']['host'],
-        'port': config.config['elasticsearch']['port'],
-    }], **extra_args)
-
-
 def _preprocess_image(content: bytes) -> NpMatrix:
-    img = Image.open(BytesIO(content))
-    return np.asarray(img.convert('L'), dtype=np.uint8)
+    try:
+        img = Image.open(BytesIO(content))
+        return np.asarray(img.convert('L'), dtype=np.uint8)
+    except IOError:
+        raise errors.ProcessingError(
+            'Unable to generate a signature hash '
+            'for this image.')
 
 
 def _crop_image(
@@ -175,7 +169,31 @@ def _compute_differentials(grey_level_matrix: NpMatrix) -> NpMatrix:
         lower_right_neighbors]))
 
 
-def _generate_signature(content: bytes) -> NpMatrix:
+def _words_to_int(word_array: NpMatrix) -> List[int]:
+    width = word_array.shape[1]
+    coding_vector = 3**np.arange(width)
+    return np.dot(word_array + 1, coding_vector).astype(int).tolist()
+
+
+def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix:
+    word_positions = np.linspace(
+        0, array.shape[0], n, endpoint=False).astype('int')
+    assert k <= array.shape[0]
+    assert word_positions.shape[0] <= array.shape[0]
+    words = np.zeros((n, k)).astype('int8')
+    for i, pos in enumerate(word_positions):
+        if pos + k <= array.shape[0]:
+            words[i] = array[pos:pos + k]
+        else:
+            temp = array[pos:].copy()
+            temp.resize(k, refcheck=False)
+            words[i] = temp
+    words[words > 0] = 1
+    words[words < 0] = -1
+    return words
+
+
+def generate_signature(content: bytes) -> NpMatrix:
     im_array = _preprocess_image(content)
     image_limits = _crop_image(
         im_array,
@@ -192,40 +210,15 @@ def _generate_signature(content: bytes) -> NpMatrix:
     return np.ravel(diff_matrix).astype('int8')
 
 
-def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix:
-    word_positions = np.linspace(
-        0, array.shape[0], n, endpoint=False).astype('int')
-    assert k <= array.shape[0]
-    assert word_positions.shape[0] <= array.shape[0]
-    words = np.zeros((n, k)).astype('int8')
-    for i, pos in enumerate(word_positions):
-        if pos + k <= array.shape[0]:
-            words[i] = array[pos:pos + k]
-        else:
-            temp = array[pos:].copy()
-            temp.resize(k)
-            words[i] = temp
-    _max_contrast(words)
-    words = _words_to_int(words)
-    return words
-
-
-def _words_to_int(word_array: NpMatrix) -> NpMatrix:
-    width = word_array.shape[1]
-    coding_vector = 3**np.arange(width)
-    return np.dot(word_array + 1, coding_vector)
-
-
-def _max_contrast(array: NpMatrix) -> None:
-    array[array > 0] = 1
-    array[array < 0] = -1
-
-
-def _normalized_distance(
-        target_array: NpMatrix,
+def generate_words(signature: NpMatrix) -> List[int]:
+    return _words_to_int(_get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS))
+
+
+def normalized_distance(
+        target_array: Any,
         vec: NpMatrix,
         nan_value: float = 1.0) -> List[float]:
-    target_array = target_array.astype(int)
+    target_array = np.array(target_array).astype(int)
     vec = vec.astype(int)
     topvec = np.linalg.norm(vec - target_array, axis=1)
     norm1 = np.linalg.norm(vec, axis=0)
@@ -235,124 +228,21 @@ def _normalized_distance(
     return finvec
 
 
-def _safety_blanket(default_param_factory: Callable[[], Any]) -> Callable:
-    def wrapper_outer(target_function: Callable) -> Callable:
-        def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
-            try:
-                return target_function(*args, **kwargs)
-            except elasticsearch.exceptions.NotFoundError:
-                # index not yet created, will be created dynamically by
-                # add_image()
-                return default_param_factory()
-            except elasticsearch.exceptions.ElasticsearchException as ex:
-                logger.warning('Problem with elastic search: %s', ex)
-                raise errors.ThirdPartyError(
-                    'Error connecting to elastic search.')
-            except IOError:
-                raise errors.ProcessingError('Not an image.')
-            except Exception as ex:
-                raise errors.ThirdPartyError('Unknown error (%s).' % ex)
-        return wrapper_inner
-    return wrapper_outer
-
-
-class Lookalike:
-    def __init__(self, score: int, distance: float, path: Any) -> None:
-        self.score = score
-        self.distance = distance
-        self.path = path
-
-
-@_safety_blanket(lambda: None)
-def add_image(path: str, image_content: bytes) -> None:
-    assert path
-    assert image_content
-    signature = _generate_signature(image_content)
-    words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS)
-    record = {
-        'signature': signature.tolist(),
-        'path': path,
-        'timestamp': datetime.now(),
-    }
-    for i in range(MAX_WORDS):
-        record['simple_word_' + str(i)] = words[i].tolist()
-    get_session().index(
-        index=config.config['elasticsearch']['index'],
-        doc_type=ES_DOC_TYPE,
-        body=record,
-        refresh=True)
-
-
-@_safety_blanket(lambda: None)
-def delete_image(path: str) -> None:
-    assert path
-    get_session().delete_by_query(
-        index=config.config['elasticsearch']['index'],
-        doc_type=ES_DOC_TYPE,
-        body={'query': {'term': {'path': path}}})
-
-
-@_safety_blanket(lambda: [])
-def search_by_image(image_content: bytes) -> List[Lookalike]:
-    signature = _generate_signature(image_content)
-    words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS)
-    res = get_session().search(
-        index=config.config['elasticsearch']['index'],
-        doc_type=ES_DOC_TYPE,
-        body={
-            'query':
-            {
-                'bool':
-                {
-                    'should':
-                    [
-                        {'term': {'simple_word_%d' % i: word.tolist()}}
-                        for i, word in enumerate(words)
-                    ]
-                }
-            },
-            '_source': {'excludes': ['simple_word_*']}},
-        size=ES_MAX_RESULTS,
-        timeout='10s')['hits']['hits']
-    if len(res) == 0:
-        return []
-    sigs = np.array([x['_source']['signature'] for x in res])
-    dists = _normalized_distance(sigs, np.array(signature))
-    ids = set()  # type: Set[int]
-    ret = []
-    for item, dist in zip(res, dists):
-        id = item['_id']
-        score = item['_score']
-        path = item['_source']['path']
-        if id in ids:
-            continue
-        ids.add(id)
-        if dist < DISTANCE_CUTOFF:
-            ret.append(Lookalike(score=score, distance=dist, path=path))
-    return ret
-
-
-@_safety_blanket(lambda: None)
-def purge() -> None:
-    get_session().delete_by_query(
-        index=config.config['elasticsearch']['index'],
-        doc_type=ES_DOC_TYPE,
-        body={'query': {'match_all': {}}},
-        refresh=True)
-
-
-@_safety_blanket(lambda: set())
-def get_all_paths() -> Set[str]:
-    search = (
-        elasticsearch_dsl.Search(
-            using=get_session(),
-            index=config.config['elasticsearch']['index'],
-            doc_type=ES_DOC_TYPE)
-        .source(['path']))
-    return set(h.path for h in search.scan())
+def pack_signature(signature: NpMatrix) -> bytes:
+    base = 2 * N_LEVELS + 1
+    coding_vector = np.flipud(SIG_BASE**np.arange(SIG_CHUNK_WIDTH))
+    return np.array([
+        np.dot(x, coding_vector) for x in
+        np.reshape(signature + N_LEVELS, (-1, SIG_CHUNK_WIDTH))
+    ]).astype(f'uint{SIG_CHUNK_BITS}').tobytes()
+
+
+def unpack_signature(packed: bytes) -> NpMatrix:
+    base = 2 * N_LEVELS + 1
+    return np.ravel(np.array([
+        [
+            int(digit) - N_LEVELS for digit in
+            np.base_repr(e, base=SIG_BASE).zfill(SIG_CHUNK_WIDTH)
+        ] for e in
+        np.frombuffer(packed, dtype=f'uint{SIG_CHUNK_BITS}')
+    ]).astype('int8'))
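
The packing scheme stores each signature as fixed-width base-SIG_BASE
chunks: entries lie in [-N_LEVELS, N_LEVELS], so adding N_LEVELS turns each
one into a digit, and SIG_CHUNK_WIDTH digits fit into one uint32. Assuming
the upstream default N_LEVELS = 2 (defined above this hunk and not shown
here), SIG_BASE is 6, SIG_CHUNK_WIDTH is 12, and a round trip looks roughly
like this:

    import numpy as np
    from szurubooru.func import image_hash

    # A made-up 8*9*9 = 648-entry signature with values in [-2, 2],
    # matching what generate_signature() produces for the 9x9 grid.
    signature = np.random.randint(-2, 3, size=8 * 9 * 9).astype('int8')

    packed = image_hash.pack_signature(signature)
    assert len(packed) == (648 // 12) * 4  # 54 uint32 chunks, 216 bytes

    restored = image_hash.unpack_signature(packed)
    assert np.array_equal(signature, restored)  # lossless round trip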

server/szurubooru/func/posts.py

@@ -1,3 +1,4 @@
+import logging
 import hmac
 from typing import Any, Optional, Tuple, List, Dict, Callable
 from datetime import datetime
@@ -8,6 +9,9 @@ from szurubooru.func import (
     mime, images, files, image_hash, serialization, snapshots)
 
+logger = logging.getLogger(__name__)
+
+
 EMPTY_PIXEL = (
     b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
     b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
@@ -60,12 +64,6 @@ class InvalidPostFlagError(errors.ValidationError):
     pass
 
 
-class PostLookalike(image_hash.Lookalike):
-    def __init__(self, score: int, distance: float, post: model.Post) -> None:
-        super().__init__(score, distance, post.post_id)
-        self.post = post
-
-
 SAFETY_MAP = {
     model.Post.SAFETY_SAFE: 'safe',
     model.Post.SAFETY_SKETCHY: 'sketchy',
@@ -402,7 +400,6 @@ def _after_post_update(
 def _before_post_delete(
         _mapper: Any, _connection: Any, post: model.Post) -> None:
     if post.post_id:
-        image_hash.delete_image(post.post_id)
         if config.config['delete_source_files']:
             files.delete(get_post_content_path(post))
             files.delete(get_post_thumbnail_path(post))
@@ -416,10 +413,6 @@ def _sync_post_content(post: model.Post) -> None:
         files.save(get_post_content_path(post), content)
         delattr(post, '__content')
         regenerate_thumb = True
-        if post.post_id and post.type in (
-                model.Post.TYPE_IMAGE, model.Post.TYPE_ANIMATION):
-            image_hash.delete_image(post.post_id)
-            image_hash.add_image(post.post_id, content)
 
     if hasattr(post, '__thumbnail'):
         if getattr(post, '__thumbnail'):
@@ -485,14 +478,56 @@ def test_sound(post: model.Post, content: bytes) -> None:
     update_post_flags(post, flags)
 
 
+def purge_post_signature(post: model.Post) -> None:
+    old_signature = (
+        db.session
+        .query(model.PostSignature)
+        .filter(model.PostSignature.post_id == post.post_id)
+        .one_or_none())
+    if old_signature:
+        db.session.delete(old_signature)
+
+
+def generate_post_signature(post: model.Post, content: bytes) -> None:
+    try:
+        unpacked_signature = image_hash.generate_signature(content)
+        packed_signature = image_hash.pack_signature(unpacked_signature)
+        words = image_hash.generate_words(unpacked_signature)
+        db.session.add(model.PostSignature(
+            post=post, signature=packed_signature, words=words))
+    except errors.ProcessingError:
+        if not config.config['allow_broken_uploads']:
+            raise InvalidPostContentError(
+                'Unable to generate image hash data.')
+
+
+def update_all_post_signatures() -> None:
+    posts_to_hash = (
+        db.session
+        .query(model.Post)
+        .filter(
+            (model.Post.type == model.Post.TYPE_IMAGE) |
+            (model.Post.type == model.Post.TYPE_ANIMATION))
+        .filter(model.Post.signature == None)
+        .order_by(model.Post.post_id.asc())
+        .all())
+    for post in posts_to_hash:
+        logger.info('Generating hash info for %d', post.post_id)
+        generate_post_signature(post, files.get(get_post_content_path(post)))
+
+
 def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
     assert post
     if not content:
         raise InvalidPostContentError('Post content missing.')
 
+    update_signature = False
     post.mime_type = mime.get_mime_type(content)
     if mime.is_flash(post.mime_type):
         post.type = model.Post.TYPE_FLASH
     elif mime.is_image(post.mime_type):
+        update_signature = True
         if mime.is_animated_gif(content):
             post.type = model.Post.TYPE_ANIMATION
         else:
@@ -515,18 +550,30 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
             and other_post.post_id != post.post_id:
         raise PostAlreadyUploadedError(other_post)
 
+    if update_signature:
+        purge_post_signature(post)
+        post.signature = generate_post_signature(post, content)
+
     post.file_size = len(content)
     try:
         image = images.Image(content)
         post.canvas_width = image.width
         post.canvas_height = image.height
     except errors.ProcessingError:
-        post.canvas_width = None
-        post.canvas_height = None
+        if not config.config['allow_broken_uploads']:
+            raise InvalidPostContentError(
+                'Unable to process image metadata')
+        else:
+            post.canvas_width = None
+            post.canvas_height = None
     if (post.canvas_width is not None and post.canvas_width <= 0) \
             or (post.canvas_height is not None and post.canvas_height <= 0):
-        post.canvas_width = None
-        post.canvas_height = None
+        if not config.config['allow_broken_uploads']:
+            raise InvalidPostContentError(
+                'Invalid image dimensions returned during processing')
+        else:
+            post.canvas_width = None
+            post.canvas_height = None
 
     setattr(post, '__content', content)
@@ -751,6 +798,8 @@ def merge_posts(
     if replace_content:
         content = files.get(get_post_content_path(source_post))
     transfer_flags(source_post.post_id, target_post.post_id)
+    purge_post_signature(source_post)
+    purge_post_signature(target_post)
 
     delete(source_post)
     db.session.flush()
@@ -768,38 +817,31 @@ def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
         .one_or_none())
 
 
-def search_by_image(image_content: bytes) -> List[PostLookalike]:
-    ret = []
-    for result in image_hash.search_by_image(image_content):
-        post = try_get_post_by_id(result.path)
-        if post:
-            ret.append(PostLookalike(
-                score=result.score,
-                distance=result.distance,
-                post=post))
-    return ret
-
-
-def populate_reverse_search() -> None:
-    excluded_post_ids = image_hash.get_all_paths()
-
-    post_ids_to_hash = (
-        db.session
-        .query(model.Post.post_id)
-        .filter(
-            (model.Post.type == model.Post.TYPE_IMAGE) |
-            (model.Post.type == model.Post.TYPE_ANIMATION))
-        .filter(~model.Post.post_id.in_(excluded_post_ids))
-        .order_by(model.Post.post_id.asc())
-        .all())
-
-    for post_ids_chunk in util.chunks(post_ids_to_hash, 100):
-        posts_chunk = (
-            db.session
-            .query(model.Post)
-            .filter(model.Post.post_id.in_(post_ids_chunk))
-            .all())
-        for post in posts_chunk:
-            content_path = get_post_content_path(post)
-            if files.has(content_path):
-                image_hash.add_image(post.post_id, files.get(content_path))
+def search_by_image(image_content: bytes) -> List[Tuple[float, model.Post]]:
+    query_signature = image_hash.generate_signature(image_content)
+    query_words = image_hash.generate_words(query_signature)
+
+    dbquery = '''
+    SELECT s.post_id, s.signature, count(a.query) AS score
+    FROM post_signature AS s, unnest(s.words, :q) AS a(word, query)
+    WHERE a.word = a.query
+    GROUP BY s.post_id
+    ORDER BY score DESC LIMIT 100;
+    '''
+
+    candidates = db.session.execute(dbquery, {'q': query_words})
+    data = tuple(zip(*[
+        (post_id, image_hash.unpack_signature(packedsig))
+        for post_id, packedsig, score in candidates
+    ]))
+    if data:
+        candidate_post_ids, sigarray = data
+        distances = image_hash.normalized_distance(sigarray, query_signature)
+        return [
+            (distance, try_get_post_by_id(candidate_post_id))
+            for candidate_post_id, distance
+            in zip(candidate_post_ids, distances)
+            if distance < image_hash.DISTANCE_CUTOFF
+        ]
+    else:
+        return []
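
One non-obvious bit in the new search_by_image() above: tuple(zip(*rows))
transposes the per-candidate rows into parallel sequences, so every unpacked
signature can be handed to image_hash.normalized_distance() as one matrix.
The idiom in isolation, with invented values:

    import numpy as np

    # Candidate rows as (post_id, unpacked_signature) pairs.
    rows = [
        (1, np.zeros(648, dtype='int8')),
        (7, np.ones(648, dtype='int8')),
    ]
    post_ids, sigs = tuple(zip(*rows))
    assert post_ids == (1, 7)
    assert np.array(sigs).shape == (2, 648)  # one vectorized distance pass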

server/szurubooru/model/__init__.py

@@ -9,7 +9,8 @@ from szurubooru.model.post import (
     PostFavorite,
     PostScore,
     PostNote,
-    PostFeature)
+    PostFeature,
+    PostSignature)
 from szurubooru.model.comment import Comment, CommentScore
 from szurubooru.model.snapshot import Snapshot
 import szurubooru.model.util

server/szurubooru/model/post.py

@@ -143,6 +143,26 @@ class PostTag(Base):
         self.tag_id = tag_id
 
 
+class PostSignature(Base):
+    __tablename__ = 'post_signature'
+
+    post_id = sa.Column(
+        'post_id',
+        sa.Integer,
+        sa.ForeignKey('post.id'),
+        primary_key=True,
+        nullable=False,
+        index=True)
+    signature = sa.Column('signature', sa.LargeBinary, nullable=False)
+    words = sa.Column(
+        'words',
+        sa.dialects.postgresql.ARRAY(sa.Integer, dimensions=1),
+        nullable=False,
+        index=True)
+
+    post = sa.orm.relationship('Post')
+
+
 class Post(Base):
     __tablename__ = 'post'
 
@@ -184,6 +204,11 @@ class Post(Base):
     # foreign tables
     user = sa.orm.relationship('User')
     tags = sa.orm.relationship('Tag', backref='posts', secondary='post_tag')
+    signature = sa.orm.relationship(
+        'PostSignature',
+        uselist=False,
+        cascade='all, delete-orphan',
+        lazy='joined')
     relations = sa.orm.relationship(
         'Post',
         secondary='post_relation',
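
For completeness, a rough sketch of how the new model is written and read
from the ORM side (the id and payloads are invented; for real uploads
generate_post_signature() in szurubooru/func/posts.py does this):

    from szurubooru import db, model

    post = db.session.query(model.Post).get(1337)  # hypothetical post id
    db.session.add(model.PostSignature(
        post=post,
        signature=b'\x00' * 216,  # packed bytes from pack_signature()
        words=[0] * 63))          # 63 ints from generate_words()
    db.session.commit()

    # lazy='joined' loads post.signature together with the post, and
    # cascade='all, delete-orphan' drops the row when the post is deleted.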