diff --git a/server/szurubooru/api/tag_category_api.py b/server/szurubooru/api/tag_category_api.py index 882af10c..139da1d8 100644 --- a/server/szurubooru/api/tag_category_api.py +++ b/server/szurubooru/api/tag_category_api.py @@ -40,7 +40,8 @@ def get_tag_category(ctx, params): @routes.put('/tag-category/(?P[^/]+)/?') def update_tag_category(ctx, params): - category = tag_categories.get_category_by_name(params['category_name']) + category = tag_categories.get_category_by_name( + params['category_name'], lock=True) versions.verify_version(category, ctx) versions.bump_version(category) if ctx.has_param('name'): @@ -60,7 +61,8 @@ def update_tag_category(ctx, params): @routes.delete('/tag-category/(?P[^/]+)/?') def delete_tag_category(ctx, params): - category = tag_categories.get_category_by_name(params['category_name']) + category = tag_categories.get_category_by_name( + params['category_name'], lock=True) versions.verify_version(category, ctx) auth.verify_privilege(ctx.user, 'tag_categories:delete') tag_categories.delete_category(category) @@ -73,8 +75,10 @@ def delete_tag_category(ctx, params): @routes.put('/tag-category/(?P[^/]+)/default/?') def set_tag_category_as_default(ctx, params): auth.verify_privilege(ctx.user, 'tag_categories:set_default') - category = tag_categories.get_category_by_name(params['category_name']) + category = tag_categories.get_category_by_name( + params['category_name'], lock=True) tag_categories.set_default_category(category) + ctx.session.flush() snapshots.modify(category, ctx.user) ctx.session.commit() tags.export_to_json() diff --git a/server/szurubooru/func/cache.py b/server/szurubooru/func/cache.py index 78fb871d..148aa992 100644 --- a/server/szurubooru/func/cache.py +++ b/server/szurubooru/func/cache.py @@ -55,5 +55,10 @@ def get(key): return _CACHE.hash[key].value +def remove(key): + if has(key): + del _CACHE.hash[key] + + def put(key, value): _CACHE.insert_item(LruCacheItem(key, value)) diff --git a/server/szurubooru/func/tag_categories.py b/server/szurubooru/func/tag_categories.py index ac27c999..bca15df2 100644 --- a/server/szurubooru/func/tag_categories.py +++ b/server/szurubooru/func/tag_categories.py @@ -4,6 +4,9 @@ from szurubooru import config, db, errors from szurubooru.func import util, cache +DEFAULT_CATEGORY_NAME_CACHE_KEY = 'default-tag-category' + + class TagCategoryNotFoundError(errors.NotFoundError): pass @@ -69,6 +72,7 @@ def update_category_name(category, name): raise InvalidTagCategoryNameError('Name is too long.') _verify_name_validity(name) category.name = name + cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY) def update_category_color(category, color): @@ -82,15 +86,17 @@ def update_category_color(category, color): category.color = color -def try_get_category_by_name(name): - return db.session \ +def try_get_category_by_name(name, lock=False): + query = db.session \ .query(db.TagCategory) \ - .filter(sqlalchemy.func.lower(db.TagCategory.name) == name.lower()) \ - .one_or_none() + .filter(sqlalchemy.func.lower(db.TagCategory.name) == name.lower()) + if lock: + query = query.with_lockmode('update') + return query.one_or_none() -def get_category_by_name(name): - category = try_get_category_by_name(name) +def get_category_by_name(name, lock=False): + category = try_get_category_by_name(name, lock) if not category: raise TagCategoryNotFoundError('Tag category %r not found.' % name) return category @@ -104,38 +110,50 @@ def get_all_categories(): return db.session.query(db.TagCategory).all() -def try_get_default_category(): - key = 'default-tag-category' - if cache.has(key): - return cache.get(key) - category = db.session \ +def try_get_default_category(lock=False): + query = db.session \ .query(db.TagCategory) \ - .filter(db.TagCategory.default) \ - .first() + .filter(db.TagCategory.default) + if lock: + query = query.with_lockmode('update') + category = query.first() # if for some reason (e.g. as a result of migration) there's no default # category, get the first record available. if not category: - category = db.session \ + query = db.session \ .query(db.TagCategory) \ - .order_by(db.TagCategory.tag_category_id.asc()) \ - .first() - cache.put(key, category) + .order_by(db.TagCategory.tag_category_id.asc()) + if lock: + query = query.with_lockmode('update') + category = query.first() return category -def get_default_category(): - category = try_get_default_category() +def get_default_category(lock=False): + category = try_get_default_category(lock) if not category: raise TagCategoryNotFoundError('No tag category created yet.') return category +def get_default_category_name(): + if cache.has(DEFAULT_CATEGORY_NAME_CACHE_KEY): + return cache.get(DEFAULT_CATEGORY_NAME_CACHE_KEY) + default_category = try_get_default_category() + default_category_name = default_category.name if default_category else None + cache.put(DEFAULT_CATEGORY_NAME_CACHE_KEY, default_category_name) + return default_category_name + + def set_default_category(category): assert category - old_category = try_get_default_category() + old_category = try_get_default_category(lock=True) if old_category: + db.session.refresh(old_category) old_category.default = False + db.session.refresh(category) category.default = True + cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY) def delete_category(category): diff --git a/server/szurubooru/func/tags.py b/server/szurubooru/func/tags.py index 22ebde33..73257148 100644 --- a/server/szurubooru/func/tags.py +++ b/server/szurubooru/func/tags.py @@ -58,8 +58,7 @@ def _check_name_intersection(names1, names2, case_sensitive): def sort_tags(tags): - default_category = tag_categories.try_get_default_category() - default_category_name = default_category.name if default_category else None + default_category_name = tag_categories.get_default_category_name() return sorted( tags, key=lambda tag: ( @@ -170,7 +169,7 @@ def get_or_create_tags_by_names(names): names = util.icase_unique(names) existing_tags = get_tags_by_names(names) new_tags = [] - tag_category_name = tag_categories.get_default_category().name + tag_category_name = tag_categories.get_default_category_name() for name in names: found = False for existing_tag in existing_tags: diff --git a/server/szurubooru/tests/func/test_tag_categories.py b/server/szurubooru/tests/func/test_tag_categories.py index 1b036cea..70f0aa0e 100644 --- a/server/szurubooru/tests/func/test_tag_categories.py +++ b/server/szurubooru/tests/func/test_tag_categories.py @@ -181,16 +181,32 @@ def test_try_get_default_category_when_default(tag_category_factory): assert actual_default_category != category1 -def test_try_get_default_category_from_cache(tag_category_factory): +def test_get_default_category_name(tag_category_factory): + category1 = tag_category_factory() + category2 = tag_category_factory(default=True) + db.session.add_all([category1, category2]) + db.session.flush() + assert tag_categories.get_default_category_name() == category2.name + category2.default = False + db.session.flush() + cache.purge() + assert tag_categories.get_default_category_name() == category1.name + db.session.query(db.TagCategory).delete() + cache.purge() + assert tag_categories.get_default_category_name() is None + + +def test_get_default_category_name_caching(tag_category_factory): category1 = tag_category_factory() category2 = tag_category_factory() db.session.add_all([category1, category2]) db.session.flush() - tag_categories.try_get_default_category() - db.session.query(db.TagCategory).delete() - assert tag_categories.try_get_default_category() == category1 + tag_categories.get_default_category_name() + db.session.delete(category1) + db.session.flush() + assert tag_categories.get_default_category_name() == category1.name cache.purge() - assert tag_categories.try_get_default_category() is None + assert tag_categories.get_default_category_name() == category2.name def test_get_default_category():