From 85f012b02f0f87e28a4ba1ebe18e81a8257a923c Mon Sep 17 00:00:00 2001 From: Deka Jello Date: Thu, 18 Apr 2024 10:24:40 -0500 Subject: [PATCH] Rewrite get_pool_posts_around to not use raw sql --- server/szurubooru/api/post_api.py | 4 +- server/szurubooru/func/posts.py | 90 +++++++++---------------------- 2 files changed, 28 insertions(+), 66 deletions(-) diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index fe5d1a3a..96588ce5 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -291,10 +291,10 @@ def get_pool_posts_around( auth.verify_privilege(ctx.user, "posts:list") auth.verify_privilege(ctx.user, "pools:list") auth.verify_privilege(ctx.user, "pools:view") - _search_executor_config.user = ctx.user + _search_executor_config.user = ctx.user # never calling _search_executor so why are we setting user? post = _get_post(params) results = posts.get_pool_posts_around(post) - return posts.serialize_pool_posts_around(results) + return posts.serialize_pool_posts_around(ctx, results) @rest.routes.post("/posts/reverse-search/?") diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index ad298c49..21136773 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -135,16 +135,12 @@ def get_post_content_path(post: model.Post) -> str: ) -def get_post_thumbnail_path_from_id(post_id: int) -> str: - return "generated-thumbnails/%d_%s.jpg" % ( - post_id, - get_post_security_hash(post_id), - ) - - def get_post_thumbnail_path(post: model.Post) -> str: assert post - return get_post_thumbnail_path_from_id(post.post_id) + return "generated-thumbnails/%d_%s.jpg" % ( + post.post_id, + get_post_security_hash(post.post_id), + ) def get_post_thumbnail_backup_path(post: model.Post) -> str: @@ -977,61 +973,27 @@ def search_by_image(image_content: bytes) -> List[Tuple[float, model.Post]]: PoolPostsAround = namedtuple('PoolPostsAround', 'pool first_post prev_post next_post last_post') + def get_pool_posts_around(post: model.Post) -> List[PoolPostsAround]: - return [] - - around = dict() - pool_ids = set() - post_ids = set() - - dbquery = """ - SELECT around.ord, around.pool_id, around.post_id, around.delta - FROM pool_post pp, - LATERAL get_pool_posts_around(pp.pool_id, pp.post_id) around - WHERE pp.post_id = :post_id; - """ - - for order, pool_id, post_id, delta in db.session.execute(dbquery, {"post_id": post.post_id}): - if pool_id not in around: - around[pool_id] = [None, None, None, None] - if delta == -2: - around[pool_id][0] = post_id - elif delta == -1: - around[pool_id][1] = post_id - elif delta == 1: - around[pool_id][2] = post_id - elif delta == 2: - around[pool_id][3] = post_id - pool_ids.add(pool_id) - post_ids.add(post_id) - - pools = dict() - posts = dict() - - for pool in db.session.query(model.Pool).filter(model.Pool.pool_id.in_(pool_ids)).all(): - pools[pool.pool_id] = pool - - for result in db.session.query(model.Post.post_id).filter(model.Post.post_id.in_(post_ids)).all(): - post_id = result[0] - posts[post_id] = { "id": post_id, "thumbnailUrl": get_post_thumbnail_path_from_id(post_id) } - results = [] + for pool in post.pools: + first_post, prev_post, next_post, last_post = None, None, None, None - for pool_id, entry in around.items(): - first_post = None - prev_post = None - next_post = None - last_post = None - if entry[1] is not None: - prev_post = posts[entry[1]] - if entry[0] is not None: - first_post = posts[entry[0]] - if entry[2] is not None: - next_post = posts[entry[2]] - if entry[3] is not None: - last_post = posts[entry[3]] - results.append(PoolPostsAround(pools[pool_id], first_post, prev_post, next_post, last_post)) + # find index of current post: + index_in_pool = list(map(lambda p: p.post_id, pool.posts)).index(post.post_id) + # collect first, prev, next, last post: + if index_in_pool > 0: + first_post = pool.posts[0] + prev_post = pool.posts[index_in_pool - 1] + if index_in_pool < len(pool.posts) - 1: + next_post = pool.posts[index_in_pool + 1] + last_post = pool.posts[-1] + + around = PoolPostsAround(pool, first_post, prev_post, next_post, last_post) + logger.info("===============> WE NOW HAVE: %s", around) + results.append(around) + return results @@ -1042,14 +1004,14 @@ def sort_pool_posts_around(around: List[PoolPostsAround]) -> List[PoolPostsAroun ) -def serialize_pool_posts_around(around: List[PoolPostsAround]) -> Optional[rest.Response]: +def serialize_pool_posts_around(ctx: rest.Context, around: List[PoolPostsAround]) -> Optional[rest.Response]: return [ { "pool": pools.serialize_micro_pool(entry.pool), - "firstPost": entry.first_post, - "prevPost": entry.prev_post, - "nextPost": entry.next_post, - "lastPost": entry.last_post + "firstPost": serialize_micro_post(ctx.user, entry.first_post), + "prevPost": serialize_micro_post(ctx.user, entry.prev_post), + "nextPost": serialize_micro_post(ctx.user, entry.next_post), + "lastPost": serialize_micro_post(ctx.user, entry.last_post) } for entry in sort_pool_posts_around(around) ]