szurubooru/server/migrate-v1

390 lines
14 KiB
Text
Raw Normal View History

#!/usr/bin/env python3
import os
import sys
import datetime
import argparse
import json
import zlib
import concurrent.futures
import logging
import coloredlogs
import sqlalchemy
from szurubooru import config, db
from szurubooru.func import files, images, posts, comments
# Format shared by every log line of this script: timestamp, then message.
_LOG_FORMAT = '[%(asctime)-15s] %(message)s'
# coloredlogs.install() without a logger argument configures the root logger.
coloredlogs.install(fmt=_LOG_FORMAT)
# Module logger; its records propagate to the handler installed above.
logger = logging.getLogger(__name__)
def read_file(path):
    """Return the full contents of the file at *path* as ``bytes``."""
    with open(path, mode='rb') as stream:
        data = stream.read()
    return data
def get_v1_session(args):
    """Open an ORM session on the v1.x MySQL database.

    Args:
        args: parsed command line (see ``parse_args``); reads ``user``,
            ``password``, ``host``, ``port`` and ``name``.

    Returns:
        A new session from ``sqlalchemy.orm.sessionmaker`` bound to a
        ``mysql+pymysql`` engine.
    """
    # Local import keeps this block self-contained.
    from urllib.parse import quote_plus

    def build_dsn(password):
        return (
            '{schema}://{user}:{password}@{host}:{port}/{name}?charset=utf8'
            .format(
                schema='mysql+pymysql',
                user=args.user,
                password=password,
                host=args.host,
                port=args.port,
                name=args.name))

    # Characters such as '@', ':' or '/' in the password would otherwise be
    # parsed as URL delimiters, so the password is URL-encoded.
    dsn = build_dsn(quote_plus(args.password))
    # Log a masked DSN so the real password never reaches the log output.
    logger.info('Connecting to %r...', build_dsn('***'))
    engine = sqlalchemy.create_engine(dsn)
    session_maker = sqlalchemy.orm.sessionmaker(bind=engine)
    return session_maker()
def _build_arg_parser():
    """Return the argument parser describing this script's command line."""
    parser = argparse.ArgumentParser(
        description='Migrate database from szurubooru v1.x to v2.x.\n\n')
    parser.add_argument(
        '--data-dir',
        metavar='PATH', required=True,
        help='root directory of v1.x deploy')
    parser.add_argument(
        '--host',
        metavar='HOST', required=True,
        help='name of v1.x database host')
    parser.add_argument(
        '--port',
        metavar='NUM', type=int, default=3306,
        help='port to v1.x database host')
    parser.add_argument(
        '--name',
        # Was 'HOST', which mislabelled the database-name argument in --help.
        metavar='NAME', required=True,
        help='v1.x database name')
    parser.add_argument(
        '--user',
        metavar='NAME', required=True,
        help='v1.x database user')
    parser.add_argument(
        '--password',
        metavar='PASSWORD', required=True,
        help='v1.x database password')
    return parser


def parse_args(argv=None):
    """Parse the migration's command-line arguments.

    Args:
        argv: list of arguments to parse; ``None`` (the default) makes
            argparse use ``sys.argv[1:]``, matching the previous behavior.

    Returns:
        ``argparse.Namespace`` with ``data_dir``, ``host``, ``port`` (int,
        default 3306), ``name``, ``user`` and ``password``.
    """
    return _build_arg_parser().parse_args(argv)
def exec(session, query):
    """Run the SQL *query* on *session* and yield each result row as a dict.

    The name shadows the ``exec`` builtin inside this module; it is kept
    because the import functions call it by this name.

    Args:
        session: SQLAlchemy session to execute on.
        query: plain SQL text. It is wrapped in ``sqlalchemy.text``, so any
            ``:name`` sequence in it would be treated as a bind parameter.

    Yields:
        ``dict`` mapping column name to value, one per row.
    """
    # SQLAlchemy 2.0 rejects plain strings in Session.execute; text() is
    # accepted by older releases as well.
    result = session.execute(sqlalchemy.text(query))
    # Column names are identical for every row, so look them up once.
    columns = list(result.keys())
    for row in result:
        yield dict(zip(columns, row))
def import_users(v1_data_dir, v1_session, v2_session):
    """Import every v1.x user, including manually uploaded avatars.

    Primary keys are copied unchanged. For users whose v1.x avatar style maps
    to ``db.User.AVATAR_MANUAL`` the avatar file is resized and saved as
    ``avatars/<lowercased name>.png``. Commits once at the end.

    Raises:
        KeyError: on an ``accessRank`` or ``avatarStyle`` value not mapped
            below.
        OSError: if a manual avatar file cannot be read.
    """
    v1_rank_to_v2_rank = {
        6: db.User.RANK_ADMINISTRATOR,
        5: db.User.RANK_MODERATOR,
        4: db.User.RANK_POWER,
        3: db.User.RANK_REGULAR,
        2: db.User.RANK_RESTRICTED,
        1: db.User.RANK_ANONYMOUS,
    }
    v1_style_to_v2_style = {
        2: db.User.AVATAR_MANUAL,
        1: db.User.AVATAR_GRAVATAR,
        0: db.User.AVATAR_GRAVATAR,
    }
    for v1_user in exec(v1_session, 'SELECT * FROM users'):
        logger.info('Importing user %s...', v1_user['name'])
        user = db.User()
        user.user_id = v1_user['id']
        user.name = v1_user['name']
        user.password_salt = v1_user['passwordSalt']
        user.password_hash = v1_user['passwordHash']
        user.email = v1_user['email']
        user.rank = v1_rank_to_v2_rank[v1_user['accessRank']]
        user.creation_time = v1_user['creationTime']
        user.last_login_time = v1_user['lastLoginTime']
        user.avatar_style = v1_style_to_v2_style[v1_user['avatarStyle']]
        v2_session.add(user)
        if user.avatar_style == db.User.AVATAR_MANUAL:
            _import_manual_avatar(v1_data_dir, user)
    v2_session.commit()


def _import_manual_avatar(v1_data_dir, user):
    """Resize *user*'s v1.x avatar file and store it as their v2.x PNG."""
    avatar_path = os.path.join(
        v1_data_dir, 'public_html', 'data', 'avatars', str(user.user_id))
    image = images.Image(read_file(avatar_path))
    thumbnail_config = config.config['thumbnails']
    image.resize_fill(
        int(thumbnail_config['avatar_width']),
        int(thumbnail_config['avatar_height']))
    files.save('avatars/' + user.name.lower() + '.png', image.to_png())
def import_tag_categories(v1_session, v2_session):
    """Create one v2.x tag category per distinct v1.x tag category name.

    Ids are assigned sequentially starting at 0, in query order. Every
    category gets the color ``'default'``. Commits before returning.

    Returns:
        dict mapping v1.x category name to the new ``tag_category_id``.
    """
    category_to_id_map = {}
    for row in exec(v1_session, 'SELECT DISTINCT category FROM tags'):
        logger.info('Importing tag category %s...', row['category'])
        category = db.TagCategory()
        category.tag_category_id = len(category_to_id_map)
        category.name = row['category']
        category.color = 'default'
        v2_session.add(category)
        category_to_id_map[category.name] = category.tag_category_id
    # Persist the categories before tags reference them by raw category_id,
    # matching every other import step, which commits its own rows.
    v2_session.commit()
    return category_to_id_map
def import_tags(category_to_id_map, v1_session, v2_session):
    """Import every non-banned v1.x tag with its single name and category.

    Args:
        category_to_id_map: v1.x category name -> v2.x tag_category_id, as
            returned by ``import_tag_categories``.

    Returns:
        list of ids of banned v1.x tags, which were not imported.
    """
    banned_tag_ids = []
    for v1_tag in exec(v1_session, 'SELECT * FROM tags'):
        tag_name = v1_tag['name']
        logger.info('Importing tag %s...', tag_name)
        if not v1_tag['banned']:
            tag = db.Tag()
            tag.tag_id = v1_tag['id']
            tag.names = [db.TagName(name=tag_name)]
            tag.category_id = category_to_id_map[v1_tag['category']]
            tag.creation_time = v1_tag['creationTime']
            tag.last_edit_time = v1_tag['lastEditTime']
            v2_session.add(tag)
        else:
            logger.info('Ignored banned tag %s', tag_name)
            banned_tag_ids.append(v1_tag['id'])
    v2_session.commit()
    return banned_tag_ids
def import_tag_relations(unused_tag_ids, v1_session, v2_session):
    """Import v1.x tag relations as v2.x implications or suggestions.

    Relations with ``type`` 1 become ``db.TagImplication``; every other type
    becomes ``db.TagSuggestion``. Relations touching any tag id in
    *unused_tag_ids* are skipped.
    """
    logger.info('Importing tag relations...')
    # A set gives O(1) membership tests; scanning the list twice per row is
    # quadratic overall.
    skipped_tag_ids = set(unused_tag_ids)
    for row in exec(v1_session, 'SELECT * FROM tagRelations'):
        parent_id, child_id = row['tag1id'], row['tag2id']
        if parent_id in skipped_tag_ids or child_id in skipped_tag_ids:
            continue
        relation_model = (
            db.TagImplication if row['type'] == 1 else db.TagSuggestion)
        v2_session.add(relation_model(parent_id=parent_id, child_id=child_id))
    v2_session.commit()
def import_posts(v1_session, v2_session):
    """Import v1.x post metadata (not file content) into v2.x.

    YouTube posts (``contentType`` 4) have no v2.x type and are skipped.
    Primary keys are copied unchanged. Commits once at the end.

    Returns:
        list of ids of v1.x posts that were skipped.

    Raises:
        KeyError: on a ``contentType`` or ``safety`` value not mapped below.
    """
    post_type_map = {
        1: db.Post.TYPE_IMAGE,
        2: db.Post.TYPE_FLASH,
        3: db.Post.TYPE_VIDEO,
        5: db.Post.TYPE_ANIMATION,
    }
    safety_map = {
        1: db.Post.SAFETY_SAFE,
        2: db.Post.SAFETY_SKETCHY,
        3: db.Post.SAFETY_UNSAFE,
    }
    unused_post_ids = []
    for row in exec(v1_session, 'SELECT * FROM posts'):
        logger.info('Importing post %d...', row['id'])
        if row['contentType'] == 4:
            # Logger.warn is a deprecated alias of Logger.warning.
            logger.warning('Ignoring youtube post %d', row['id'])
            unused_post_ids.append(row['id'])
            continue
        post = db.Post()
        post.post_id = row['id']
        post.user_id = row['userId']
        post.type = post_type_map[row['contentType']]
        post.source = row['source']
        post.canvas_width = row['imageWidth']
        post.canvas_height = row['imageHeight']
        post.file_size = row['originalFileSize']
        post.creation_time = row['creationTime']
        post.last_edit_time = row['lastEditTime']
        post.checksum = row['contentChecksum']
        post.mime_type = row['contentMimeType']
        post.safety = safety_map[row['safety']]
        # Bit 0 of the v1.x flags column maps to the v2.x loop flag.
        if row['flags'] & 1:
            post.flags = [db.Post.FLAG_LOOP]
        v2_session.add(post)
    v2_session.commit()
    return unused_post_ids
def _import_post_content_for_post(
        unused_post_ids, v1_data_dir, v1_session, v2_session, row, post):
    """Copy one v1.x post's content file (and custom thumbnail) to v2.x.

    Saves the content file, saves a ``-custom-thumb`` file as the thumbnail
    backup when one exists, then regenerates the post thumbnail.

    Args:
        unused_post_ids: ids of v1.x posts that were not imported.
        v1_data_dir: root directory of the v1.x deployment.
        v1_session: unused; kept for signature compatibility.
        v2_session: unused; kept for signature compatibility.
        row: the v1.x ``posts`` row as a dict.
        post: the matching v2.x ``db.Post``, or ``None``.

    Raises:
        ValueError: if the row was imported but *post* is ``None``.
        OSError: if the v1.x content file cannot be read.
    """
    logger.info('Importing post %d content...', row['id'])
    if row['id'] in unused_post_ids:
        logger.warning('Ignoring unimported post %d', row['id'])
        return
    # An explicit check: ``assert`` would be stripped under ``python -O``.
    if post is None:
        raise ValueError('No v2.x post found for v1.x post %d' % row['id'])
    posts_dir = os.path.join(v1_data_dir, 'public_html', 'data', 'posts')
    source_content_path = os.path.join(posts_dir, row['name'])
    source_thumb_path = os.path.join(posts_dir, row['name'] + '-custom-thumb')
    files.save(
        posts.get_post_content_path(post), read_file(source_content_path))
    if os.path.exists(source_thumb_path):
        files.save(
            posts.get_post_thumbnail_backup_path(post),
            read_file(source_thumb_path))
    posts.generate_post_thumbnail(post)
def import_post_content(unused_post_ids, v1_data_dir, v1_session, v2_session):
    """Copy every v1.x post's files into v2.x storage using 10 worker threads.

    Rows are fetched on the calling thread. Each row is then handled by
    ``_import_post_content_for_post`` in a ``ThreadPoolExecutor``. A failure
    in one post is logged with its traceback instead of being silently
    discarded, and a failure count is logged at the end.
    """
    skipped_post_ids = set(unused_post_ids)
    rows = list(exec(v1_session, 'SELECT * FROM posts'))
    # Named so that it does not shadow the ``posts`` helper module.
    posts_by_id = {
        post.post_id: post for post in v2_session.query(db.Post).all()}
    # NOTE(review): these db.Post instances belong to v2_session and are used
    # from worker threads; SQLAlchemy sessions are not thread-safe, so confirm
    # the per-post helpers never trigger lazy loads.
    failed_count = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        future_to_post_id = {}
        for row in rows:
            future = executor.submit(
                _import_post_content_for_post,
                skipped_post_ids, v1_data_dir, v1_session, v2_session,
                row, posts_by_id.get(row['id']))
            future_to_post_id[future] = row['id']
        # Previously the futures were dropped, so worker exceptions vanished.
        for future in concurrent.futures.as_completed(future_to_post_id):
            error = future.exception()
            if error is None:
                continue
            failed_count += 1
            logger.error(
                'Failed to import content of post %d: %s',
                future_to_post_id[future], error, exc_info=error)
    if failed_count:
        logger.error('Failed to import content of %d post(s)', failed_count)
def import_post_tags(unused_post_ids, v1_session, v2_session):
    """Import v1.x post-tag links.

    Links are skipped when the post is in *unused_post_ids*, or when the tag
    does not exist in v2.x (e.g. a banned tag that ``import_tags`` skipped).
    Without the second check such links would violate the tag foreign key
    on commit.
    """
    logger.info('Importing post tags...')
    skipped_post_ids = set(unused_post_ids)
    known_tag_ids = {
        tag_id for (tag_id,) in v2_session.query(db.Tag.tag_id)}
    for row in exec(v1_session, 'SELECT * FROM postTags'):
        if row['postId'] in skipped_post_ids:
            continue
        if row['tagId'] not in known_tag_ids:
            logger.warning(
                'Ignoring unimported tag %s on post %s',
                row['tagId'], row['postId'])
            continue
        v2_session.add(db.PostTag(post_id=row['postId'], tag_id=row['tagId']))
    v2_session.commit()
def import_post_notes(unused_post_ids, v1_session, v2_session):
    """Import v1.x post notes as rectangular v2.x note polygons.

    The v1.x ``x``/``y``/``width``/``height`` values are divided by 100
    (presumably percentages converted to 0..1 fractions; confirm against the
    v2.x note format). Notes on posts in *unused_post_ids* are skipped.
    """
    logger.info('Importing post notes...')
    skipped_post_ids = set(unused_post_ids)
    for row in exec(v1_session, 'SELECT * FROM postNotes'):
        if row['postId'] in skipped_post_ids:
            continue
        post_note = db.PostNote()
        post_note.post_id = row['postId']
        post_note.text = row['text']
        post_note.polygon = _note_rectangle_polygon(
            row['x'], row['y'], row['width'], row['height'])
        v2_session.add(post_note)
    v2_session.commit()


def _note_rectangle_polygon(x, y, width, height):
    """Return the 4 corners (clockwise from top-left) of a scaled rectangle."""
    left, top = x / 100, y / 100
    right = left + width / 100
    # The bottom edge uses the height; it previously used the width, which
    # made every note square.
    bottom = top + height / 100
    return [(left, top), (right, top), (right, bottom), (left, bottom)]
def import_post_relations(unused_post_ids, v1_session, v2_session):
    """Import v1.x post relations, skipping any that touch an unused post."""
    logger.info('Importing post relations...')
    # Set membership is O(1); the list would be scanned twice per row.
    skipped_post_ids = set(unused_post_ids)
    for row in exec(v1_session, 'SELECT * FROM postRelations'):
        parent_id, child_id = row['post1id'], row['post2id']
        if parent_id in skipped_post_ids or child_id in skipped_post_ids:
            continue
        v2_session.add(
            db.PostRelation(parent_id=parent_id, child_id=child_id))
    v2_session.commit()
def import_post_favorites(unused_post_ids, v1_session, v2_session):
    """Import v1.x favorites, skipping those on posts in *unused_post_ids*.

    A falsy favorite ``time`` (e.g. NULL) is replaced by
    ``datetime.datetime.min``.
    """
    logger.info('Importing post favorites...')
    # Set membership is O(1) instead of a list scan per row.
    skipped_post_ids = set(unused_post_ids)
    for row in exec(v1_session, 'SELECT * FROM favorites'):
        if row['postId'] in skipped_post_ids:
            continue
        v2_session.add(
            db.PostFavorite(
                post_id=row['postId'],
                user_id=row['userId'],
                time=row['time'] or datetime.datetime.min))
    v2_session.commit()
def import_comments(unused_post_ids, v1_session, v2_session):
    """Import v1.x comments whose post exists in v2.x.

    Comments on posts in *unused_post_ids*, or on posts missing from v2.x,
    are skipped with a warning. Primary keys are copied unchanged.
    """
    skipped_post_ids = set(unused_post_ids)
    # One query for all imported post ids replaces a lookup per comment.
    existing_post_ids = {
        post_id for (post_id,) in v2_session.query(db.Post.post_id)}
    for row in exec(v1_session, 'SELECT * FROM comments'):
        logger.info('Importing comment %d...', row['id'])
        if row['postId'] in skipped_post_ids:
            logger.warning(
                'Ignoring comment for unimported post %d', row['postId'])
            continue
        if row['postId'] not in existing_post_ids:
            logger.warning(
                'Ignoring comment for non existing post %d', row['postId'])
            continue
        comment = db.Comment()
        comment.comment_id = row['id']
        comment.user_id = row['userId']
        comment.post_id = row['postId']
        comment.creation_time = row['creationTime']
        comment.last_edit_time = row['lastEditTime']
        comment.text = row['text']
        v2_session.add(comment)
    v2_session.commit()
def import_scores(v1_session, v2_session):
    """Import v1.x post and comment scores.

    A row with a ``postId`` becomes ``db.PostScore``, otherwise a row with a
    ``commentId`` becomes ``db.CommentScore``. Rows whose target is missing in
    v2.x, or that reference neither, are skipped with a warning. A falsy
    ``time`` is replaced by ``datetime.datetime.min``.
    """
    logger.info('Importing scores...')
    for row in exec(v1_session, 'SELECT * FROM scores'):
        if row['postId']:
            post = posts.try_get_post_by_id(row['postId'])
            if not post:
                logger.warning(
                    'Ignoring score for unimported post %d', row['postId'])
                continue
            score = db.PostScore()
            score.post = post
        elif row['commentId']:
            comment = comments.try_get_comment_by_id(row['commentId'])
            if not comment:
                logger.warning(
                    'Ignoring score for unimported comment %d',
                    row['commentId'])
                continue
            score = db.CommentScore()
            score.comment = comment
        else:
            # Without this branch ``score`` would still refer to the previous
            # row's object (or be unbound on the first row), and the lines
            # below would silently overwrite an already-imported score.
            logger.warning('Ignoring score attached to no post or comment')
            continue
        score.score = row['score']
        score.time = row['time'] or datetime.datetime.min
        score.user_id = row['userId']
        v2_session.add(score)
    v2_session.commit()
def import_snapshots(v1_session, v2_session):
    """Import v1.x post and tag snapshots, oldest first.

    Each v1.x ``data`` blob is decompressed with ``zlib`` using
    ``wbits=-15`` (raw deflate) and decoded as UTF-8 JSON.

    - Post snapshots get ``contentChecksum`` renamed to ``checksum`` and a
      dict-valued ``tags`` flattened to a list of its values.
    - Tag snapshots lose ``banned`` and have ``name`` turned into a
      one-element ``names`` list.

    Raises:
        KeyError: on an unmapped ``operation``/``type`` value, or for a tag
            snapshot whose data has neither ``name`` nor ``names``.
    """
    logger.info('Importing snapshots...')
    operation_by_v1_code = {
        0: db.Snapshot.OPERATION_CREATED,
        1: db.Snapshot.OPERATION_MODIFIED,
        2: db.Snapshot.OPERATION_DELETED,
    }
    resource_type_by_v1_code = {
        0: 'post',
        1: 'tag',
    }
    query = 'SELECT * FROM snapshots ORDER BY time ASC'
    for v1_snapshot in exec(v1_session, query):
        snapshot = db.Snapshot()
        snapshot.creation_time = v1_snapshot['time']
        snapshot.user_id = v1_snapshot['userId']
        snapshot.operation = operation_by_v1_code[v1_snapshot['operation']]
        resource_type = resource_type_by_v1_code[v1_snapshot['type']]
        snapshot.resource_type = resource_type
        snapshot.resource_id = v1_snapshot['primaryKey']
        raw_json = zlib.decompress(v1_snapshot['data'], -15)
        data = json.loads(raw_json.decode('utf-8'))
        if resource_type == 'post':
            if 'contentChecksum' in data:
                data['checksum'] = data.pop('contentChecksum')
            if 'tags' in data and isinstance(data['tags'], dict):
                data['tags'] = list(data['tags'].values())
            snapshot.resource_repr = v1_snapshot['primaryKey']
        elif resource_type == 'tag':
            if 'banned' in data:
                data.pop('banned')
            if 'name' in data:
                data['names'] = [data.pop('name')]
            snapshot.resource_repr = data['names'][0]
        snapshot.data = data
        v2_session.add(snapshot)
    v2_session.commit()
def main():
    """Migrate a szurubooru v1.x deployment into an empty v2.x database.

    Reads paths and credentials from the command line. Exits with status 1
    before importing anything if the v2.x database already contains tags,
    posts, comments or users. Otherwise runs every import step in
    dependency order.
    """
    args = parse_args()
    v1_data_dir = args.data_dir
    v1_session = get_v1_session(args)
    v2_session = db.session

    guarded_models = (db.Tag, db.Post, db.Comment, db.User)
    if any(v2_session.query(model).count() for model in guarded_models):
        logger.error('v2.x database is dirty! Aborting.')
        sys.exit(1)

    # NOTE(review): rows are inserted with primary keys copied from v1.x; if
    # v2.x ids are generated from database sequences, those sequences may
    # need advancing after the import. Confirm against the v2.x schema.
    import_users(v1_data_dir, v1_session, v2_session)
    category_ids = import_tag_categories(v1_session, v2_session)
    banned_tag_ids = import_tags(category_ids, v1_session, v2_session)
    import_tag_relations(banned_tag_ids, v1_session, v2_session)
    skipped_post_ids = import_posts(v1_session, v2_session)
    import_post_content(
        skipped_post_ids, v1_data_dir, v1_session, v2_session)
    post_dependent_steps = (
        import_post_tags,
        import_post_notes,
        import_post_relations,
        import_post_favorites,
        import_comments,
    )
    for import_step in post_dependent_steps:
        import_step(skipped_post_ids, v1_session, v2_session)
    import_scores(v1_session, v2_session)
    import_snapshots(v1_session, v2_session)


if __name__ == '__main__':
    main()