From 07267b9737ec3f6f6fb340d568166cd1f12c4d3f Mon Sep 17 00:00:00 2001 From: noirscape Date: Wed, 4 Jan 2023 21:04:53 +0100 Subject: [PATCH] server: better iterable logic --- server/szurubooru/func/posts.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index f502d9a0..eeed83ec 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -2,6 +2,7 @@ import hmac import logging from collections import namedtuple from datetime import datetime +from itertools import tee, chain, islice, izip from typing import Any, Callable, Dict, List, Optional, Tuple import sqlalchemy as sa @@ -96,6 +97,13 @@ FLAG_MAP = { model.Post.FLAG_SOUND: "sound", } +# https://stackoverflow.com/a/1012089 +def _get_nearby_iter(post_list): + previous_item, current_item, next_item = tee(post_list, 3) + previous_item = chain([None], previous_item) + next_item = chain(islice(next_item, 1, None), [None]) + return izip(previous_item, current_item, next_item) + def get_post_security_hash(id: int) -> str: return hmac.new( @@ -982,12 +990,10 @@ def get_pools_nearby( first_post_id = pool.posts[0].post_id, last_post_id = pool.posts[-1].post_id, - for idx, pool_post in enumerate(pool.posts): - if post.post_id == pool_post.post_id: - if post.post_id != first_post_id: - prev_post_id = pool.posts[idx-1].post_id - if post.post_id != last_post_id: - next_post_id = pool.posts[idx+1].post_id + for previous_item, current_item, next_item in _get_nearby_iter(pool.posts): + if post.post_id == current_item.post_id: + prev_post_id = previous_item.post_id + next_post_id = next_item.post_id break resp_entry = PoolPostsNearby(