diff --git a/server/szurubooru/func/tags.py b/server/szurubooru/func/tags.py index a3944dd3..0d988edb 100644 --- a/server/szurubooru/func/tags.py +++ b/server/szurubooru/func/tags.py @@ -55,34 +55,51 @@ def serialize_tag(tag, options=None): options) def export_to_json(): - output = { - 'tags': [], - 'categories': [], - } - all_tags = db.session \ - .query(db.Tag) \ - .options( - sqlalchemy.orm.joinedload('suggestions'), - sqlalchemy.orm.joinedload('implications')) \ - .all() - for tag in all_tags: - item = { - 'names': [tag_name.name for tag_name in tag.names], - 'usages': tag.post_count, - 'category': tag.category.name, + tags = {} + categories = {} + + for result in db.session.query( + db.TagCategory.tag_category_id, + db.TagCategory.name, + db.TagCategory.color).all(): + categories[result[0]] = { + 'name': result[1], + 'color': result[2], } - if len(tag.suggestions): - item['suggestions'] = \ - [rel.names[0].name for rel in tag.suggestions] - if len(tag.implications): - item['implications'] = \ - [rel.names[0].name for rel in tag.implications] - output['tags'].append(item) - for category in tag_categories.get_all_categories(): - output['categories'].append({ - 'name': category.name, - 'color': category.color, - }) + + for result in db.session.query(db.TagName.tag_id, db.TagName.name).all(): + if not result[0] in tags: + tags[result[0]] = {'names': []} + tags[result[0]]['names'].append(result[1]) + + for result in db.session \ + .query(db.TagSuggestion.parent_id, db.TagName.name) \ + .join(db.TagName, db.TagName.tag_id == db.TagSuggestion.child_id) \ + .all(): + if not 'suggestions' in tags[result[0]]: + tags[result[0]]['suggestions'] = [] + tags[result[0]]['suggestions'].append(result[1]) + + for result in db.session \ + .query(db.TagImplication.parent_id, db.TagName.name) \ + .join(db.TagName, db.TagName.tag_id == db.TagImplication.child_id) \ + .all(): + if not 'implications' in tags[result[0]]: + tags[result[0]]['implications'] = [] + tags[result[0]]['implications'].append(result[1]) + + for result in db.session.query( + db.Tag.tag_id, + db.Tag.category_id, + db.Tag.post_count).all(): + tags[result[0]]['category'] = categories[result[1]]['name'] + tags[result[0]]['usages'] = result[2] + + output = { + 'categories': list(categories.values()), + 'tags': list(tags.values()), + } + export_path = os.path.join(config.config['data_dir'], 'tags.json') with open(export_path, 'w') as handle: handle.write(json.dumps(output, separators=(',', ':'))) diff --git a/server/szurubooru/tests/api/test_tag_export.py b/server/szurubooru/tests/api/test_tag_export.py index 7ea84fb6..29fb7e14 100644 --- a/server/szurubooru/tests/api/test_tag_export.py +++ b/server/szurubooru/tests/api/test_tag_export.py @@ -8,6 +8,7 @@ def test_export( tmpdir, query_counter, config_injector, + post_factory, tag_factory, tag_category_factory): config_injector({ @@ -22,10 +23,12 @@ def test_export( imp1 = tag_factory(names=['imp1'], category=cat1) imp2 = tag_factory(names=['imp2'], category=cat1) tag = tag_factory(names=['alias1', 'alias2'], category=cat2) - tag.post_count = 1 db.session.add_all([tag, sug1, sug2, imp1, imp2, cat1, cat2]) + post = post_factory() + post.tags = [tag] db.session.flush() db.session.add_all([ + post, db.TagSuggestion(tag.tag_id, sug1.tag_id), db.TagSuggestion(tag.tag_id, sug2.tag_id), db.TagImplication(tag.tag_id, imp1.tag_id), @@ -35,7 +38,7 @@ def test_export( with query_counter: tags.export_to_json() - assert len(query_counter.statements) == 2 + assert len(query_counter.statements) == 5 export_path = os.path.join(config.config['data_dir'], 'tags.json') assert os.path.exists(export_path)