PostgreSQL 同一テーブル定義のデータ移行
はじめに
同じテーブル定義を持つPostgreSQLデータベース間でデータ移行が必要になったため、
Pythonを使ったデータ移行スクリプトや概要を備忘録として残します。
処理概要
- テーブル定義の比較
- データのバルクインサート
- エラーハンドリング
データのバルクインサート
当初は単純なINSERTによる逐次処理でデータ移行を行っていましたが、
パフォーマンスに課題があったため、PostgreSQLのCOPY機能を使用しています。
- 約5000件のデータ移行時のパフォーマンス比較
- 逐次処理(INSERT): 数分程度
- COPY機能を使用した処理: 数秒程度
環境
以下の環境で動作を確認しています。
Dockerイメージ
FROM python:3.12
ライブラリ
以下のライブラリを使用しています。
- psycopg2-binary: PostgreSQLデータベースに接続するためのライブラリ
pip install psycopg2-binary
手順
1. PostgreSQLへの接続
PostgreSQLデータベースへの接続を行います。
def connect_db(host, dbname, user, password, port):
conn = psycopg2.connect(
host=host, dbname=dbname, user=user, password=password, port=port
)
return conn
2. テーブルの存在確認
移行元・移行先のテーブルが存在するか確認します。
def check_table_exists(conn, table_name):
with conn.cursor() as cur:
cur.execute(
sql.SQL(
"""
SELECT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_name = %s
)
"""
),
[table_name],
)
return cur.fetchone()[0]
3. テーブル定義の取得と比較
テーブル定義を取得し、移行元と移行先の差異を比較します。
差異があれば詳細を出力し、処理を中断します。
def compare_table_definitions(src_def, dest_def, table_name):
differences = []
src_columns = {col[0]: col for col in src_def}
dest_columns = {col[0]: col for col in dest_def}
for col_name, src_col in src_columns.items():
if (dest_col := dest_columns.get(col_name)) is None:
differences.append(f"移行先には存在しないカラム: {col_name}")
else:
if src_col != dest_col:
differences.append(
f"カラム '{col_name}' の差異: 移行元: {src_col}, 移行先: {dest_col}"
)
for col_name in dest_columns.keys():
if col_name not in src_columns:
differences.append(f"移行元には存在しないカラム: {col_name}")
if differences:
print(f"テーブル '{table_name}' に差異があります:")
for diff in differences:
print(f" - {diff}")
return False
return True
4. データのバルクインサート
バッチ単位でデータを挿入し、大量データを効率的に処理します。
def bulk_insert_data(dest_conn, table_name, columns, rows, total_rows, batch_num):
temp_file_path = f"{table_name}_data.csv"
with open(temp_file_path, "w", newline="", encoding="utf-8") as f:
writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL, escapechar="\\")
for row in rows:
processed_row = []
for field in row:
if field is None:
processed_row.append("NULL")
elif field == "":
processed_row.append("")
else:
processed_row.append(field)
writer.writerow(processed_row)
with open(temp_file_path, "r", encoding="utf-8") as f:
dest_conn.cursor().copy_expert(
sql.SQL(
"COPY {} ({}) FROM STDIN WITH (FORMAT csv, DELIMITER ',', NULL 'NULL')"
).format(
sql.Identifier(table_name),
sql.SQL(", ").join(map(sql.Identifier, columns)),
),
f,
)
os.remove(temp_file_path)
print(f"テーブル '{table_name}' の進捗: バッチ {batch_num} - {len(rows)} 件挿入 (累計: {total_rows} 件)")
- 注意点
- トランザクション管理: エラー発生時には全テーブルデータをロールバック
- インデックスとトリガーの無効化: パフォーマンス向上のため、無効化箇所をコメントアウト
全ソースコード
import psycopg2
from psycopg2 import sql
import time
import csv
import os
# PostgreSQL接続用の関数
def connect_db(host, dbname, user, password, port):
conn = psycopg2.connect(
host=host, dbname=dbname, user=user, password=password, port=port
)
return conn
# テーブルの存在確認を行う関数
def check_table_exists(conn, table_name):
with conn.cursor() as cur:
cur.execute(
sql.SQL(
"""
SELECT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_name = %s
)
"""
),
[table_name],
)
return cur.fetchone()[0]
# テーブル定義を取得する関数
def get_table_definition(conn, table_name):
with conn.cursor() as cur:
cur.execute(
sql.SQL(
"""
SELECT column_name, data_type, is_nullable, character_maximum_length
FROM information_schema.columns
WHERE table_name = %s
ORDER BY ordinal_position
"""
),
[table_name],
)
return cur.fetchall()
# テーブル定義を比較し、差分を出力する関数
def compare_table_definitions(src_def, dest_def, table_name):
differences = []
src_columns = {col[0]: col for col in src_def}
dest_columns = {col[0]: col for col in dest_def}
# 移行元に存在し、移行先に存在しないカラム
for col_name, src_col in src_columns.items():
if (dest_col := dest_columns.get(col_name)) is None:
differences.append(f"移行先には存在しないカラム: {col_name}")
else:
if src_col != dest_col:
differences.append(
f"カラム '{col_name}' の差異: 移行元: {src_col}, 移行先: {dest_col}"
)
# 移行先に存在し、移行元に存在しないカラム
for col_name in dest_columns.keys():
if col_name not in src_columns:
differences.append(f"移行元には存在しないカラム: {col_name}")
if differences:
print(f"テーブル '{table_name}' に差異があります:")
for diff in differences:
print(f" - {diff}")
return False
return True
# バルクデータ移行処理
def bulk_insert_data(dest_conn, table_name, columns, rows, total_rows, batch_num):
with dest_conn.cursor() as dest_cur:
# 一時ファイルのパスを定義
temp_file_path = f"{table_name}_data.csv"
# COPYコマンドを使用するための一時ファイル作成
with open(temp_file_path, "w", newline="", encoding="utf-8") as f:
writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL, escapechar="\\")
for row in rows:
processed_row = []
for field in row:
if field is None:
processed_row.append("NULL") # NULLを文字列として書き込む
elif field == "":
processed_row.append("") # 空文字として処理
else:
# 改行を含むフィールドをクォートで囲む
if isinstance(field, str) and ("\n" in field or "\r" in field):
# ダブルクォートをエスケープしてフィールドをクォートで囲む
processed_row.append(f'{field.replace("\"", "\"\"")}')
else:
processed_row.append(
field
) # その他のフィールドはそのまま書き込む
writer.writerow(processed_row)
# COPYコマンドでデータを挿入
with open(temp_file_path, "r", encoding="utf-8") as f:
dest_cur.copy_expert(
sql.SQL(
"COPY {} ({}) FROM STDIN WITH (FORMAT csv, DELIMITER ',', NULL 'NULL')"
).format(
sql.Identifier(table_name),
sql.SQL(", ").join(map(sql.Identifier, columns)),
),
f,
)
# 一時ファイルを削除
os.remove(temp_file_path)
# 挿入が成功したバッチごとに進捗状況を出力
print(
f"テーブル '{table_name}' の進捗: バッチ {batch_num} - {len(rows)} 件挿入 (累計: {total_rows} 件)"
)
# データ移行処理
def migrate_data(src_conn, dest_conn, table_name, batch_size=10000):
total_rows = 0 # 総データ数
batch_num = 1 # バッチ番号
with src_conn.cursor() as src_cur:
src_cur.execute(sql.SQL("SELECT * FROM {}").format(sql.Identifier(table_name)))
columns = [desc[0] for desc in src_cur.description] # カラム名を取得
# データをバッチで読み込んでバルクインサート
while True:
rows = src_cur.fetchmany(batch_size)
if not rows:
break # データがなくなったら終了
total_rows += len(rows)
bulk_insert_data(
dest_conn, table_name, columns, rows, total_rows, batch_num
)
batch_num += 1 # 次のバッチへ
# メイン処理
def main():
# 移行元データベース情報
src_db = {
"host": "{host}",
"dbname": "{dbname}",
"user": "{user}",
"password": "{password}",
"port": "{port}",
}
# 移行先データベース情報
dest_db = {
"host": "{host}",
"dbname": "{dbname}",
"user": "{user}",
"password": "{password}",
"port": "{port}",
}
# 移行対象テーブル
tables_to_migrate = [
"users",
"events",
]
# データベース接続
src_conn = connect_db(**src_db)
dest_conn = connect_db(**dest_db)
try:
all_definitions_match = (
True # 全テーブルの定義が一致しているか確認するためのフラグ
)
# 全体の移行開始時間を記録
start_time = time.time()
# トランザクション開始
dest_conn.autocommit = False
# 全てのテーブルの定義を確認
for table_name in tables_to_migrate:
# テーブルが存在するかを確認
src_table_exists = check_table_exists(src_conn, table_name)
dest_table_exists = check_table_exists(dest_conn, table_name)
if not src_table_exists:
print(f"移行元のデータベースにテーブル '{table_name}' が存在しません。")
all_definitions_match = False
continue
if not dest_table_exists:
print(f"移行先のデータベースにテーブル '{table_name}' が存在しません。")
all_definitions_match = False
continue
# テーブル定義を取得して比較
src_def = get_table_definition(src_conn, table_name)
dest_def = get_table_definition(dest_conn, table_name)
if not compare_table_definitions(src_def, dest_def, table_name):
all_definitions_match = False
# 全テーブルの定義が一致していれば移行を開始
if all_definitions_match:
print("全テーブルの定義が一致しています。データ移行を開始します。")
for table_name in tables_to_migrate:
# テーブルごとの移行開始時間を記録
table_start_time = time.time()
# インデックスとトリガーを無効化
# with dest_conn.cursor() as dest_cur:
# dest_cur.execute(
# sql.SQL("ALTER TABLE {} DISABLE TRIGGER ALL").format(
# sql.Identifier(table_name)
# )
# )
migrate_data(src_conn, dest_conn, table_name)
# インデックスとトリガーを再有効化
# with dest_conn.cursor() as dest_cur:
# dest_cur.execute(
# sql.SQL("ALTER TABLE {} ENABLE TRIGGER ALL").format(
# sql.Identifier(table_name)
# )
# )
# テーブルごとの移行時間を出力
table_elapsed_time = time.time() - table_start_time
print(
f"テーブル '{table_name}' のデータ移行が完了しました。所要時間: {table_elapsed_time:.2f} 秒"
)
# 全体の移行時間を出力
total_elapsed_time = time.time() - start_time
print(
f"全てのテーブルのデータ移行が完了しました。総所要時間: {total_elapsed_time:.2f} 秒"
)
# すべてのデータ移行が完了したのでコミット
dest_conn.commit()
print("データ移行が正常に完了しました。")
else:
print("一部のテーブル定義に差異があるため、データ移行を中止します。")
except Exception as e:
# エラーが発生した場合、ロールバック
dest_conn.rollback()
print(f"エラーが発生しました。すべての変更をロールバックしました: {e}")
finally:
# 接続を閉じる
src_conn.close()
dest_conn.close()
if __name__ == "__main__":
main()