server/image_search: implement reverse search functionality in postgres
This removes the dependency on the Elasticsearch database. The search query is currently passed as raw SQL; a proper implementation using SQLAlchemy will require custom ORM classes. An additional config parameter, "allow_broken_uploads", has been added.
parent a616cf6987
commit 4c78cf8c47

7 changed files with 187 additions and 231 deletions
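In outline, reverse search now works like this: every image post gets a signature vector (an int8 array of 8*N*N differentials), the signature is condensed into MAX_WORDS integer "words" stored in a PostgreSQL array column, and a lookup first collects candidate posts whose words overlap the query's words in SQL, then ranks them by normalized distance in Python. A condensed sketch using the function names from the diff below (an illustration of the flow, not code from the commit):

# Sketch of the pipeline this commit wires up (names from
# szurubooru.func.image_hash and szurubooru.func.posts).
signature = image_hash.generate_signature(content)   # int8 vector of differentials
words = image_hash.generate_words(signature)         # MAX_WORDS small integers
packed = image_hash.pack_signature(signature)        # bytes for the signature column
# store:  db.session.add(model.PostSignature(post=post, signature=packed, words=words))
# search: raw SQL counts word overlaps to pick at most 100 candidates, then
#         image_hash.normalized_distance() filters them against DISTANCE_CUTOFF.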
config.yaml.dist

@@ -21,11 +21,15 @@ thumbnails:
     post_width: 300
     post_height: 300
 
+# automatically convert animated GIF uploads to video formats
 convert:
     gif:
         to_webm: false
         to_mp4: false
 
+# allow posts to be uploaded even if some image processing errors occur
+allow_broken_uploads: false
+
 # used to send password reset e-mails
 smtp:
     host: # example: localhost
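Judging from the posts.py hunks later in this diff, the new flag works as follows: with allow_broken_uploads left at false, a post whose image metadata or signature hash cannot be computed is rejected with InvalidPostContentError; set to true, the post is accepted and the broken fields stay unset. The recurring pattern, condensed (process_image is a hypothetical stand-in for the metadata/signature steps shown further down):

try:
    process_image(content)
except errors.ProcessingError:
    if not config.config['allow_broken_uploads']:
        raise InvalidPostContentError('Unable to process image metadata')
    # otherwise: keep the post, leave canvas size / signature unset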
server/szurubooru/api/post_api.py

@@ -262,9 +262,9 @@ def get_posts_by_image(
         'similarPosts':
         [
             {
-                'distance': lookalike.distance,
-                'post': _serialize_post(ctx, lookalike.post),
+                'distance': distance,
+                'post': _serialize_post(ctx, post),
             }
-            for lookalike in lookalikes
+            for distance, post in lookalikes
         ],
     }
server/szurubooru/facade.py

@@ -8,7 +8,7 @@ import coloredlogs
 import sqlalchemy as sa
 import sqlalchemy.orm.exc
 from szurubooru import config, db, errors, rest
-from szurubooru.func import posts, file_uploads, image_hash
+from szurubooru.func import file_uploads
 # pylint: disable=unused-import
 from szurubooru import api, middleware
 
@@ -129,13 +129,7 @@ def create_app() -> Callable[[Any, Any], Any]:
     purge_thread.daemon = True
     purge_thread.start()
 
-    try:
-        image_hash.get_session().cluster.health(
-            wait_for_status='yellow', request_timeout=120)
-        posts.populate_reverse_search()
-        db.session.commit()
-    except errors.ThirdPartyError:
-        pass
+    db.session.commit()
 
     rest.errors.handle(errors.AuthError, _on_auth_error)
     rest.errors.handle(errors.ValidationError, _on_validation_error)
server/szurubooru/func/image_hash.py

@@ -2,8 +2,7 @@ import logging
 from io import BytesIO
 from datetime import datetime
 from typing import Any, Optional, Tuple, Set, List, Callable
-import elasticsearch
-import elasticsearch_dsl
+import math
 import numpy as np
 from PIL import Image
 from szurubooru import config, errors
@@ -24,30 +23,25 @@ N = 9
 P = None
 SAMPLE_WORDS = 16
 MAX_WORDS = 63
-ES_DOC_TYPE = 'image'
-ES_MAX_RESULTS = 100
+SIG_CHUNK_BITS = 32
+
+SIG_BASE = 2*N_LEVELS + 2
+SIG_CHUNK_WIDTH = int(SIG_CHUNK_BITS / math.log2(SIG_BASE))
+SIG_CHUNK_NUMS = 8*N*N / SIG_CHUNK_WIDTH
+assert 8*N*N % SIG_CHUNK_WIDTH == 0
 
 Window = Tuple[Tuple[float, float], Tuple[float, float]]
-NpMatrix = Any
-
-
-def get_session() -> elasticsearch.Elasticsearch:
-    extra_args = {}
-    if config.config['elasticsearch']['pass']:
-        extra_args['http_auth'] = (
-            config.config['elasticsearch']['user'],
-            config.config['elasticsearch']['pass'])
-        extra_args['scheme'] = 'https'
-        extra_args['port'] = 443
-    return elasticsearch.Elasticsearch([{
-        'host': config.config['elasticsearch']['host'],
-        'port': config.config['elasticsearch']['port'],
-    }], **extra_args)
+NpMatrix = np.ndarray
 
 
 def _preprocess_image(content: bytes) -> NpMatrix:
-    img = Image.open(BytesIO(content))
-    return np.asarray(img.convert('L'), dtype=np.uint8)
+    try:
+        img = Image.open(BytesIO(content))
+        return np.asarray(img.convert('L'), dtype=np.uint8)
+    except IOError:
+        raise errors.ProcessingError(
+            'Unable to generate a signature hash '
+            'for this image.')
 
 
 def _crop_image(
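With the constants above, the packing arithmetic works out as follows: each signature entry lies in [-N_LEVELS, N_LEVELS], so shifting by N_LEVELS yields digits in 0..2*N_LEVELS, packed in base SIG_BASE = 6 (a hand evaluation for reference, not code from the commit):

import math

N_LEVELS = 2
N = 9
SIG_CHUNK_BITS = 32
SIG_BASE = 2 * N_LEVELS + 2                                  # 6
SIG_CHUNK_WIDTH = int(SIG_CHUNK_BITS / math.log2(SIG_BASE))  # int(32 / 2.585)
assert SIG_CHUNK_WIDTH == 12
assert 8 * N * N == 648                  # signature length in base-6 digits
assert 648 / SIG_CHUNK_WIDTH == 54       # SIG_CHUNK_NUMS: 54 uint32 chunks
assert SIG_BASE ** SIG_CHUNK_WIDTH <= 2 ** SIG_CHUNK_BITS    # 6**12 fits a uint32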
@@ -175,7 +169,31 @@ def _compute_differentials(grey_level_matrix: NpMatrix) -> NpMatrix:
         lower_right_neighbors]))
 
 
-def _generate_signature(content: bytes) -> NpMatrix:
+def _words_to_int(word_array: NpMatrix) -> List[int]:
+    width = word_array.shape[1]
+    coding_vector = 3**np.arange(width)
+    return np.dot(word_array + 1, coding_vector).astype(int).tolist()
+
+
+def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix:
+    word_positions = np.linspace(
+        0, array.shape[0], n, endpoint=False).astype('int')
+    assert k <= array.shape[0]
+    assert word_positions.shape[0] <= array.shape[0]
+    words = np.zeros((n, k)).astype('int8')
+    for i, pos in enumerate(word_positions):
+        if pos + k <= array.shape[0]:
+            words[i] = array[pos:pos + k]
+        else:
+            temp = array[pos:].copy()
+            temp.resize(k, refcheck=False)
+            words[i] = temp
+    words[words > 0] = 1
+    words[words < 0] = -1
+    return words
+
+
+def generate_signature(content: bytes) -> NpMatrix:
     im_array = _preprocess_image(content)
     image_limits = _crop_image(
         im_array,
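_words_to_int encodes each word, whose entries are clamped to -1/0/1, as a single base-3 integer. A hand-checked toy example with a hypothetical 3-wide word (the real words are SAMPLE_WORDS = 16 wide):

import numpy as np

word = np.array([[-1, 0, 1]])        # one clamped word, width 3
coding_vector = 3 ** np.arange(3)    # [1, 3, 9]
# word + 1 shifts the entries into {0, 1, 2}: [0, 1, 2]
# 0*1 + 1*3 + 2*9 == 21
assert np.dot(word + 1, coding_vector).astype(int).tolist() == [21]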
@@ -192,40 +210,15 @@ def _generate_signature(content: bytes) -> NpMatrix:
     return np.ravel(diff_matrix).astype('int8')
 
 
-def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix:
-    word_positions = np.linspace(
-        0, array.shape[0], n, endpoint=False).astype('int')
-    assert k <= array.shape[0]
-    assert word_positions.shape[0] <= array.shape[0]
-    words = np.zeros((n, k)).astype('int8')
-    for i, pos in enumerate(word_positions):
-        if pos + k <= array.shape[0]:
-            words[i] = array[pos:pos + k]
-        else:
-            temp = array[pos:].copy()
-            temp.resize(k)
-            words[i] = temp
-    _max_contrast(words)
-    words = _words_to_int(words)
-    return words
+def generate_words(signature: NpMatrix) -> List[int]:
+    return _words_to_int(_get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS))
 
 
-def _words_to_int(word_array: NpMatrix) -> NpMatrix:
-    width = word_array.shape[1]
-    coding_vector = 3**np.arange(width)
-    return np.dot(word_array + 1, coding_vector)
-
-
-def _max_contrast(array: NpMatrix) -> None:
-    array[array > 0] = 1
-    array[array < 0] = -1
-
-
-def _normalized_distance(
-        target_array: NpMatrix,
+def normalized_distance(
+        target_array: Any,
         vec: NpMatrix,
         nan_value: float = 1.0) -> List[float]:
-    target_array = target_array.astype(int)
+    target_array = np.array(target_array).astype(int)
     vec = vec.astype(int)
     topvec = np.linalg.norm(vec - target_array, axis=1)
     norm1 = np.linalg.norm(vec, axis=0)
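generate_words is now the public entry point: it samples MAX_WORDS (63) windows of SAMPLE_WORDS (16) signature entries and base-3-encodes each one, so every word fits easily in an integer column. A quick shape-and-range check with a toy input (assumed usage, not from the commit):

import numpy as np

toy_signature = np.random.randint(-2, 3, size=8 * 9 * 9).astype('int8')
words = generate_words(toy_signature)
assert len(words) == 63                        # MAX_WORDS entries
assert all(0 <= w < 3 ** 16 for w in words)    # 16 base-3 digits per word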
@@ -235,124 +228,21 @@ def _normalized_distance(
     return finvec
 
 
-def _safety_blanket(default_param_factory: Callable[[], Any]) -> Callable:
-    def wrapper_outer(target_function: Callable) -> Callable:
-        def wrapper_inner(*args: Any, **kwargs: Any) -> Any:
-            try:
-                return target_function(*args, **kwargs)
-            except elasticsearch.exceptions.NotFoundError:
-                # index not yet created, will be created dynamically by
-                # add_image()
-                return default_param_factory()
-            except elasticsearch.exceptions.ElasticsearchException as ex:
-                logger.warning('Problem with elastic search: %s', ex)
-                raise errors.ThirdPartyError(
-                    'Error connecting to elastic search.')
-            except IOError:
-                raise errors.ProcessingError('Not an image.')
-            except Exception as ex:
-                raise errors.ThirdPartyError('Unknown error (%s).' % ex)
-        return wrapper_inner
-    return wrapper_outer
+def pack_signature(signature: NpMatrix) -> bytes:
+    base = 2 * N_LEVELS + 1
+    coding_vector = np.flipud(SIG_BASE**np.arange(SIG_CHUNK_WIDTH))
+    return np.array([
+        np.dot(x, coding_vector) for x in
+        np.reshape(signature + N_LEVELS, (-1, SIG_CHUNK_WIDTH))
+    ]).astype(f'uint{SIG_CHUNK_BITS}').tobytes()
 
 
-class Lookalike:
-    def __init__(self, score: int, distance: float, path: Any) -> None:
-        self.score = score
-        self.distance = distance
-        self.path = path
-
-
-@_safety_blanket(lambda: None)
-def add_image(path: str, image_content: bytes) -> None:
-    assert path
-    assert image_content
-    signature = _generate_signature(image_content)
-    words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS)
-
-    record = {
-        'signature': signature.tolist(),
-        'path': path,
-        'timestamp': datetime.now(),
-    }
-    for i in range(MAX_WORDS):
-        record['simple_word_' + str(i)] = words[i].tolist()
-
-    get_session().index(
-        index=config.config['elasticsearch']['index'],
-        doc_type=ES_DOC_TYPE,
-        body=record,
-        refresh=True)
-
-
-@_safety_blanket(lambda: None)
-def delete_image(path: str) -> None:
-    assert path
-    get_session().delete_by_query(
-        index=config.config['elasticsearch']['index'],
-        doc_type=ES_DOC_TYPE,
-        body={'query': {'term': {'path': path}}})
-
-
-@_safety_blanket(lambda: [])
-def search_by_image(image_content: bytes) -> List[Lookalike]:
-    signature = _generate_signature(image_content)
-    words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS)
-
-    res = get_session().search(
-        index=config.config['elasticsearch']['index'],
-        doc_type=ES_DOC_TYPE,
-        body={
-            'query':
-            {
-                'bool':
-                {
-                    'should':
-                    [
-                        {'term': {'simple_word_%d' % i: word.tolist()}}
-                        for i, word in enumerate(words)
-                    ]
-                }
-            },
-            '_source': {'excludes': ['simple_word_*']}},
-        size=ES_MAX_RESULTS,
-        timeout='10s')['hits']['hits']
-
-    if len(res) == 0:
-        return []
-
-    sigs = np.array([x['_source']['signature'] for x in res])
-    dists = _normalized_distance(sigs, np.array(signature))
-
-    ids = set()  # type: Set[int]
-    ret = []
-    for item, dist in zip(res, dists):
-        id = item['_id']
-        score = item['_score']
-        path = item['_source']['path']
-        if id in ids:
-            continue
-        ids.add(id)
-        if dist < DISTANCE_CUTOFF:
-            ret.append(Lookalike(score=score, distance=dist, path=path))
-    return ret
-
-
-@_safety_blanket(lambda: None)
-def purge() -> None:
-    get_session().delete_by_query(
-        index=config.config['elasticsearch']['index'],
-        doc_type=ES_DOC_TYPE,
-        body={'query': {'match_all': {}}},
-        refresh=True)
-
-
-@_safety_blanket(lambda: set())
-def get_all_paths() -> Set[str]:
-    search = (
-        elasticsearch_dsl.Search(
-            using=get_session(),
-            index=config.config['elasticsearch']['index'],
-            doc_type=ES_DOC_TYPE)
-        .source(['path']))
-    return set(h.path for h in search.scan())
+def unpack_signature(packed: bytes) -> NpMatrix:
+    base = 2 * N_LEVELS + 1
+    return np.ravel(np.array([
+        [
+            int(digit) - N_LEVELS for digit in
+            np.base_repr(e, base=SIG_BASE).zfill(SIG_CHUNK_WIDTH)
+        ] for e in
+        np.frombuffer(packed, dtype=f'uint{SIG_CHUNK_BITS}')
+    ]).astype('int8'))
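pack_signature and unpack_signature are intended as exact inverses: digits are shifted to non-negative, packed twelve base-6 digits per uint32 chunk, and recovered with np.base_repr. A plausible round-trip check (assumes the module-level names above; not from the commit's tests):

import numpy as np

sig = np.random.randint(-N_LEVELS, N_LEVELS + 1, size=8 * N * N).astype('int8')
packed = pack_signature(sig)    # 54 chunks * 4 bytes -> 216 bytes
assert len(packed) == 216
# tobytes()/frombuffer() use native byte order, so this round-trips on one machine
assert np.array_equal(unpack_signature(packed), sig)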
server/szurubooru/func/posts.py

@@ -1,3 +1,4 @@
+import logging
 import hmac
 from typing import Any, Optional, Tuple, List, Dict, Callable
 from datetime import datetime
@@ -8,6 +9,9 @@ from szurubooru.func import (
     mime, images, files, image_hash, serialization, snapshots)
 
 
+logger = logging.getLogger(__name__)
+
+
 EMPTY_PIXEL = (
     b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
     b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
@@ -60,12 +64,6 @@ class InvalidPostFlagError(errors.ValidationError):
     pass
 
 
-class PostLookalike(image_hash.Lookalike):
-    def __init__(self, score: int, distance: float, post: model.Post) -> None:
-        super().__init__(score, distance, post.post_id)
-        self.post = post
-
-
 SAFETY_MAP = {
     model.Post.SAFETY_SAFE: 'safe',
     model.Post.SAFETY_SKETCHY: 'sketchy',
@@ -402,7 +400,6 @@ def _after_post_update(
 def _before_post_delete(
         _mapper: Any, _connection: Any, post: model.Post) -> None:
     if post.post_id:
-        image_hash.delete_image(post.post_id)
         if config.config['delete_source_files']:
             files.delete(get_post_content_path(post))
             files.delete(get_post_thumbnail_path(post))
@@ -416,10 +413,6 @@ def _sync_post_content(post: model.Post) -> None:
         files.save(get_post_content_path(post), content)
         delattr(post, '__content')
         regenerate_thumb = True
-        if post.post_id and post.type in (
-                model.Post.TYPE_IMAGE, model.Post.TYPE_ANIMATION):
-            image_hash.delete_image(post.post_id)
-            image_hash.add_image(post.post_id, content)
 
     if hasattr(post, '__thumbnail'):
         if getattr(post, '__thumbnail'):
@@ -485,14 +478,56 @@ def test_sound(post: model.Post, content: bytes) -> None:
     update_post_flags(post, flags)
 
 
+def purge_post_signature(post: model.Post) -> None:
+    old_signature = (
+        db.session
+        .query(model.PostSignature)
+        .filter(model.PostSignature.post_id == post.post_id)
+        .one_or_none())
+    if old_signature:
+        db.session.delete(old_signature)
+
+
+def generate_post_signature(post: model.Post, content: bytes) -> None:
+    try:
+        unpacked_signature = image_hash.generate_signature(content)
+        packed_signature = image_hash.pack_signature(unpacked_signature)
+        words = image_hash.generate_words(unpacked_signature)
+
+        db.session.add(model.PostSignature(
+            post=post, signature=packed_signature, words=words))
+    except errors.ProcessingError:
+        if not config.config['allow_broken_uploads']:
+            raise InvalidPostContentError(
+                'Unable to generate image hash data.')
+
+
+def update_all_post_signatures() -> None:
+    posts_to_hash = (
+        db.session
+        .query(model.Post)
+        .filter(
+            (model.Post.type == model.Post.TYPE_IMAGE) |
+            (model.Post.type == model.Post.TYPE_ANIMATION))
+        .filter(model.Post.signature == None)
+        .order_by(model.Post.post_id.asc())
+        .all())
+    for post in posts_to_hash:
+        logger.info('Generating hash info for %d', post.post_id)
+        generate_post_signature(post, files.get(get_post_content_path(post)))
+
+
 def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
     assert post
     if not content:
         raise InvalidPostContentError('Post content missing.')
+
+    update_signature = False
     post.mime_type = mime.get_mime_type(content)
     if mime.is_flash(post.mime_type):
         post.type = model.Post.TYPE_FLASH
     elif mime.is_image(post.mime_type):
+        update_signature = True
         if mime.is_animated_gif(content):
             post.type = model.Post.TYPE_ANIMATION
         else:
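Nothing in this diff calls update_all_post_signatures itself (the old populate_reverse_search call was removed from facade.py), so an existing install presumably needs to backfill signatures manually after upgrading, along the lines of:

from szurubooru import db
from szurubooru.func import posts

posts.update_all_post_signatures()   # hashes image/animation posts lacking a signature
db.session.commit()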
@@ -515,18 +550,30 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
             and other_post.post_id != post.post_id:
         raise PostAlreadyUploadedError(other_post)
 
+    if update_signature:
+        purge_post_signature(post)
+        post.signature = generate_post_signature(post, content)
+
     post.file_size = len(content)
     try:
         image = images.Image(content)
         post.canvas_width = image.width
         post.canvas_height = image.height
     except errors.ProcessingError:
-        post.canvas_width = None
-        post.canvas_height = None
+        if not config.config['allow_broken_uploads']:
+            raise InvalidPostContentError(
+                'Unable to process image metadata')
+        else:
+            post.canvas_width = None
+            post.canvas_height = None
     if (post.canvas_width is not None and post.canvas_width <= 0) \
             or (post.canvas_height is not None and post.canvas_height <= 0):
-        post.canvas_width = None
-        post.canvas_height = None
+        if not config.config['allow_broken_uploads']:
+            raise InvalidPostContentError(
+                'Invalid image dimensions returned during processing')
+        else:
+            post.canvas_width = None
+            post.canvas_height = None
     setattr(post, '__content', content)
 
 
@@ -751,6 +798,8 @@ def merge_posts(
     if replace_content:
         content = files.get(get_post_content_path(source_post))
     transfer_flags(source_post.post_id, target_post.post_id)
+    purge_post_signature(source_post)
+    purge_post_signature(target_post)
 
     delete(source_post)
     db.session.flush()
@@ -768,38 +817,31 @@ def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
         .one_or_none())
 
 
-def search_by_image(image_content: bytes) -> List[PostLookalike]:
-    ret = []
-    for result in image_hash.search_by_image(image_content):
-        post = try_get_post_by_id(result.path)
-        if post:
-            ret.append(PostLookalike(
-                score=result.score,
-                distance=result.distance,
-                post=post))
-    return ret
-
-
-def populate_reverse_search() -> None:
-    excluded_post_ids = image_hash.get_all_paths()
-
-    post_ids_to_hash = (
-        db.session
-        .query(model.Post.post_id)
-        .filter(
-            (model.Post.type == model.Post.TYPE_IMAGE) |
-            (model.Post.type == model.Post.TYPE_ANIMATION))
-        .filter(~model.Post.post_id.in_(excluded_post_ids))
-        .order_by(model.Post.post_id.asc())
-        .all())
-
-    for post_ids_chunk in util.chunks(post_ids_to_hash, 100):
-        posts_chunk = (
-            db.session
-            .query(model.Post)
-            .filter(model.Post.post_id.in_(post_ids_chunk))
-            .all())
-        for post in posts_chunk:
-            content_path = get_post_content_path(post)
-            if files.has(content_path):
-                image_hash.add_image(post.post_id, files.get(content_path))
+def search_by_image(image_content: bytes) -> List[Tuple[float, model.Post]]:
+    query_signature = image_hash.generate_signature(image_content)
+    query_words = image_hash.generate_words(query_signature)
+
+    dbquery = '''
+    SELECT s.post_id, s.signature, count(a.query) AS score
+    FROM post_signature AS s, unnest(s.words, :q) AS a(word, query)
+    WHERE a.word = a.query
+    GROUP BY s.post_id
+    ORDER BY score DESC LIMIT 100;
+    '''
+
+    candidates = db.session.execute(dbquery, {'q': query_words})
+    data = tuple(zip(*[
+        (post_id, image_hash.unpack_signature(packedsig))
+        for post_id, packedsig, score in candidates
+    ]))
+    if data:
+        candidate_post_ids, sigarray = data
+        distances = image_hash.normalized_distance(sigarray, query_signature)
+        return [
+            (distance, try_get_post_by_id(candidate_post_id))
+            for candidate_post_id, distance
+            in zip(candidate_post_ids, distances)
+            if distance < image_hash.DISTANCE_CUTOFF
+        ]
+    else:
+        return []
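The FROM clause above leans on PostgreSQL's multi-argument unnest(), which zips the stored words array and the bound query array into parallel columns; positions where the two agree then count toward score. A minimal demonstration with made-up values (assumes a live session; not part of the commit):

# unnest(ARRAY[10, 20, 30], ARRAY[20, 20, 99]) AS a(word, query) yields the
# rows (10, 20), (20, 20), (30, 99); only (20, 20) satisfies word = query.
rows = db.session.execute(
    'SELECT count(*) FROM unnest(ARRAY[10, 20, 30], ARRAY[20, 20, 99]) '
    'AS a(word, query) WHERE a.word = a.query')
assert list(rows)[0][0] == 1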
server/szurubooru/model/__init__.py

@@ -9,7 +9,8 @@ from szurubooru.model.post import (
     PostFavorite,
     PostScore,
     PostNote,
-    PostFeature)
+    PostFeature,
+    PostSignature)
 from szurubooru.model.comment import Comment, CommentScore
 from szurubooru.model.snapshot import Snapshot
 import szurubooru.model.util
server/szurubooru/model/post.py

@@ -143,6 +143,26 @@ class PostTag(Base):
         self.tag_id = tag_id
 
 
+class PostSignature(Base):
+    __tablename__ = 'post_signature'
+
+    post_id = sa.Column(
+        'post_id',
+        sa.Integer,
+        sa.ForeignKey('post.id'),
+        primary_key=True,
+        nullable=False,
+        index=True)
+    signature = sa.Column('signature', sa.LargeBinary, nullable=False)
+    words = sa.Column(
+        'words',
+        sa.dialects.postgresql.ARRAY(sa.Integer, dimensions=1),
+        nullable=False,
+        index=True)
+
+    post = sa.orm.relationship('Post')
+
+
 class Post(Base):
     __tablename__ = 'post'
 
@@ -184,6 +204,11 @@ class Post(Base):
     # foreign tables
     user = sa.orm.relationship('User')
     tags = sa.orm.relationship('Tag', backref='posts', secondary='post_tag')
+    signature = sa.orm.relationship(
+        'PostSignature',
+        uselist=False,
+        cascade='all, delete-orphan',
+        lazy='joined')
     relations = sa.orm.relationship(
         'Post',
         secondary='post_relation',