diff --git a/server/szurubooru/tests/api/test_post_retrieving.py b/server/szurubooru/tests/api/test_post_retrieving.py index ac984c24..8d8dc2a0 100644 --- a/server/szurubooru/tests/api/test_post_retrieving.py +++ b/server/szurubooru/tests/api/test_post_retrieving.py @@ -125,3 +125,55 @@ def test_trying_to_retrieve_single_without_privileges( context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), {"post_id": 999}, ) + +def test_get_pool_post_around(user_factory, post_factory, pool_factory, pool_post_factory, context_factory): + p1 = post_factory(id=1) + p2 = post_factory(id=2) + p3 = post_factory(id=3) + db.session.add_all([p1, p2, p3]) + + pool = pool_factory(id=1) + db.session.add(pool) + + pool_posts = [pool_post_factory(pool=pool, post=p1), pool_post_factory(pool=pool, post=p2), pool_post_factory(pool=pool, post=p3)] + db.session.add_all(pool_posts) + + result = api.post_api.get_pools_around(context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {"post_id": 2}) + assert result[0]["previousPost"]["id"] == 1 and result[0]["nextPost"]["id"] == 3 + +def test_get_pool_post_around_start(user_factory, post_factory, pool_factory, pool_post_factory, context_factory): + p1 = post_factory(id=1) + p2 = post_factory(id=2) + p3 = post_factory(id=3) + db.session.add_all([p1, p2, p3]) + + pool = pool_factory(id=1) + db.session.add(pool) + + pool_posts = [pool_post_factory(pool=pool, post=p1), pool_post_factory(pool=pool, post=p2), pool_post_factory(pool=pool, post=p3)] + db.session.add_all(pool_posts) + + result = api.post_api.get_pools_around(context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {"post_id": 1}) + assert result[0]["previousPost"] == None and result[0]["nextPost"]["id"] == 2 + +def test_get_pool_post_around_end(user_factory, post_factory, pool_factory, pool_post_factory, context_factory): + p1 = post_factory(id=1) + p2 = post_factory(id=2) + p3 = post_factory(id=3) + db.session.add_all([p1, p2, p3]) + + pool = pool_factory(id=1) + db.session.add(pool) + + pool_posts = [pool_post_factory(pool=pool, post=p1), pool_post_factory(pool=pool, post=p2), pool_post_factory(pool=pool, post=p3)] + db.session.add_all(pool_posts) + + result = api.post_api.get_pools_around(context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {"post_id": 3}) + assert result[0]["previousPost"]["id"] == 2 and result[0]["nextPost"] == None + +def test_get_pool_post_around_no_pool(user_factory, post_factory, pool_factory, pool_post_factory, context_factory): + p1 = post_factory(id=1) + db.session.add(p1) + + result = api.post_api.get_pools_around(context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), {"post_id": 1}) + assert result == []