0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

PostgreSQL 同一テーブル定義のデータ移行

Posted at

PostgreSQL 同一テーブル定義のデータ移行

はじめに

同じテーブル定義を持つPostgreSQLデータベース間でデータ移行が必要になったため、
Pythonを使ったデータ移行スクリプトや概要を備忘録として残します。

処理概要

  • テーブル定義の比較
  • データのバルクインサート
  • エラーハンドリング

データのバルクインサート

当初は単純なINSERTによる逐次処理でデータ移行を行っていましたが、
パフォーマンスに課題があったため、PostgreSQLのCOPY機能を使用しています。

  • 約5000件のデータ移行時のパフォーマンス比較
    • 逐次処理(INSERT): 数分程度
    • COPY機能を使用した処理: 数秒程度

環境

以下の環境で動作を確認しています。

Dockerイメージ

FROM python:3.12

ライブラリ

以下のライブラリを使用しています。

  • psycopg2-binary: PostgreSQLデータベースに接続するためのライブラリ
    pip install psycopg2-binary

手順

1. PostgreSQLへの接続

PostgreSQLデータベースへの接続を行います。

def connect_db(host, dbname, user, password, port):
    conn = psycopg2.connect(
        host=host, dbname=dbname, user=user, password=password, port=port
    )
    return conn

2. テーブルの存在確認

移行元・移行先のテーブルが存在するか確認します。

def check_table_exists(conn, table_name):
    with conn.cursor() as cur:
        cur.execute(
            sql.SQL(
                """
                SELECT EXISTS (
                    SELECT 1 
                    FROM information_schema.tables 
                    WHERE table_name = %s
                )
                """
            ),
            [table_name],
        )
        return cur.fetchone()[0]

3. テーブル定義の取得と比較

テーブル定義を取得し、移行元と移行先の差異を比較します。
差異があれば詳細を出力し、処理を中断します。

def compare_table_definitions(src_def, dest_def, table_name):
    differences = []
    src_columns = {col[0]: col for col in src_def}
    dest_columns = {col[0]: col for col in dest_def}

    for col_name, src_col in src_columns.items():
        if (dest_col := dest_columns.get(col_name)) is None:
            differences.append(f"移行先には存在しないカラム: {col_name}")
        else:
            if src_col != dest_col:
                differences.append(
                    f"カラム '{col_name}' の差異: 移行元: {src_col}, 移行先: {dest_col}"
                )

    for col_name in dest_columns.keys():
        if col_name not in src_columns:
            differences.append(f"移行元には存在しないカラム: {col_name}")

    if differences:
        print(f"テーブル '{table_name}' に差異があります:")
        for diff in differences:
            print(f"  - {diff}")
        return False
    return True

4. データのバルクインサート

バッチ単位でデータを挿入し、大量データを効率的に処理します。

def bulk_insert_data(dest_conn, table_name, columns, rows, total_rows, batch_num):
    temp_file_path = f"{table_name}_data.csv"
    with open(temp_file_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL, escapechar="\\")
        for row in rows:
            processed_row = []
            for field in row:
                if field is None:
                    processed_row.append("NULL")
                elif field == "":
                    processed_row.append("")
                else:
                    processed_row.append(field)
            writer.writerow(processed_row)

    with open(temp_file_path, "r", encoding="utf-8") as f:
        dest_conn.cursor().copy_expert(
            sql.SQL(
                "COPY {} ({}) FROM STDIN WITH (FORMAT csv, DELIMITER ',', NULL 'NULL')"
            ).format(
                sql.Identifier(table_name),
                sql.SQL(", ").join(map(sql.Identifier, columns)),
            ),
            f,
        )

    os.remove(temp_file_path)
    print(f"テーブル '{table_name}' の進捗: バッチ {batch_num} - {len(rows)} 件挿入 (累計: {total_rows} 件)")
  • 注意点
    • トランザクション管理: エラー発生時には全テーブルデータをロールバック
    • インデックスとトリガーの無効化: パフォーマンス向上のため、無効化箇所をコメントアウト

全ソースコード

import psycopg2
from psycopg2 import sql
import time
import csv
import os


# PostgreSQL接続用の関数
def connect_db(host, dbname, user, password, port):
    conn = psycopg2.connect(
        host=host, dbname=dbname, user=user, password=password, port=port
    )
    return conn


# テーブルの存在確認を行う関数
def check_table_exists(conn, table_name):
    with conn.cursor() as cur:
        cur.execute(
            sql.SQL(
                """
                SELECT EXISTS (
                    SELECT 1 
                    FROM information_schema.tables 
                    WHERE table_name = %s
                )
                """
            ),
            [table_name],
        )
        return cur.fetchone()[0]


# テーブル定義を取得する関数
def get_table_definition(conn, table_name):
    with conn.cursor() as cur:
        cur.execute(
            sql.SQL(
                """
                SELECT column_name, data_type, is_nullable, character_maximum_length 
                FROM information_schema.columns 
                WHERE table_name = %s
                ORDER BY ordinal_position
                """
            ),
            [table_name],
        )
        return cur.fetchall()


# テーブル定義を比較し、差分を出力する関数
def compare_table_definitions(src_def, dest_def, table_name):
    differences = []
    src_columns = {col[0]: col for col in src_def}
    dest_columns = {col[0]: col for col in dest_def}

    # 移行元に存在し、移行先に存在しないカラム
    for col_name, src_col in src_columns.items():
        if (dest_col := dest_columns.get(col_name)) is None:
            differences.append(f"移行先には存在しないカラム: {col_name}")
        else:
            if src_col != dest_col:
                differences.append(
                    f"カラム '{col_name}' の差異: 移行元: {src_col}, 移行先: {dest_col}"
                )

    # 移行先に存在し、移行元に存在しないカラム
    for col_name in dest_columns.keys():
        if col_name not in src_columns:
            differences.append(f"移行元には存在しないカラム: {col_name}")

    if differences:
        print(f"テーブル '{table_name}' に差異があります:")
        for diff in differences:
            print(f"  - {diff}")
        return False
    return True


# バルクデータ移行処理
def bulk_insert_data(dest_conn, table_name, columns, rows, total_rows, batch_num):
    with dest_conn.cursor() as dest_cur:
        # 一時ファイルのパスを定義
        temp_file_path = f"{table_name}_data.csv"

        # COPYコマンドを使用するための一時ファイル作成
        with open(temp_file_path, "w", newline="", encoding="utf-8") as f:
            writer = csv.writer(f, quoting=csv.QUOTE_MINIMAL, escapechar="\\")

            for row in rows:
                processed_row = []
                for field in row:
                    if field is None:
                        processed_row.append("NULL")  # NULLを文字列として書き込む
                    elif field == "":
                        processed_row.append("")  # 空文字として処理
                    else:
                        # 改行を含むフィールドをクォートで囲む
                        if isinstance(field, str) and ("\n" in field or "\r" in field):
                            # ダブルクォートをエスケープしてフィールドをクォートで囲む
                            processed_row.append(f'{field.replace("\"", "\"\"")}')
                        else:
                            processed_row.append(
                                field
                            )  # その他のフィールドはそのまま書き込む

                writer.writerow(processed_row)

        # COPYコマンドでデータを挿入
        with open(temp_file_path, "r", encoding="utf-8") as f:
            dest_cur.copy_expert(
                sql.SQL(
                    "COPY {} ({}) FROM STDIN WITH (FORMAT csv, DELIMITER ',', NULL 'NULL')"
                ).format(
                    sql.Identifier(table_name),
                    sql.SQL(", ").join(map(sql.Identifier, columns)),
                ),
                f,
            )

        # 一時ファイルを削除
        os.remove(temp_file_path)

        # 挿入が成功したバッチごとに進捗状況を出力
        print(
            f"テーブル '{table_name}' の進捗: バッチ {batch_num} - {len(rows)} 件挿入 (累計: {total_rows} 件)"
        )


# データ移行処理
def migrate_data(src_conn, dest_conn, table_name, batch_size=10000):
    total_rows = 0  # 総データ数
    batch_num = 1  # バッチ番号
    with src_conn.cursor() as src_cur:
        src_cur.execute(sql.SQL("SELECT * FROM {}").format(sql.Identifier(table_name)))
        columns = [desc[0] for desc in src_cur.description]  # カラム名を取得

        # データをバッチで読み込んでバルクインサート
        while True:
            rows = src_cur.fetchmany(batch_size)
            if not rows:
                break  # データがなくなったら終了

            total_rows += len(rows)
            bulk_insert_data(
                dest_conn, table_name, columns, rows, total_rows, batch_num
            )
            batch_num += 1  # 次のバッチへ


# メイン処理
def main():
    # 移行元データベース情報
    src_db = {
        "host": "{host}",
        "dbname": "{dbname}",
        "user": "{user}",
        "password": "{password}",
        "port": "{port}",
    }
    # 移行先データベース情報
    dest_db = {
        "host": "{host}",
        "dbname": "{dbname}",
        "user": "{user}",
        "password": "{password}",
        "port": "{port}",
    }

    # 移行対象テーブル
    tables_to_migrate = [
        "users",
        "events",
    ]

    # データベース接続
    src_conn = connect_db(**src_db)
    dest_conn = connect_db(**dest_db)

    try:
        all_definitions_match = (
            True  # 全テーブルの定義が一致しているか確認するためのフラグ
        )

        # 全体の移行開始時間を記録
        start_time = time.time()

        # トランザクション開始
        dest_conn.autocommit = False

        # 全てのテーブルの定義を確認
        for table_name in tables_to_migrate:
            # テーブルが存在するかを確認
            src_table_exists = check_table_exists(src_conn, table_name)
            dest_table_exists = check_table_exists(dest_conn, table_name)

            if not src_table_exists:
                print(f"移行元のデータベースにテーブル '{table_name}' が存在しません。")
                all_definitions_match = False
                continue

            if not dest_table_exists:
                print(f"移行先のデータベースにテーブル '{table_name}' が存在しません。")
                all_definitions_match = False
                continue

            # テーブル定義を取得して比較
            src_def = get_table_definition(src_conn, table_name)
            dest_def = get_table_definition(dest_conn, table_name)

            if not compare_table_definitions(src_def, dest_def, table_name):
                all_definitions_match = False

        # 全テーブルの定義が一致していれば移行を開始
        if all_definitions_match:
            print("全テーブルの定義が一致しています。データ移行を開始します。")
            for table_name in tables_to_migrate:
                # テーブルごとの移行開始時間を記録
                table_start_time = time.time()

                # インデックスとトリガーを無効化
                # with dest_conn.cursor() as dest_cur:
                #     dest_cur.execute(
                #         sql.SQL("ALTER TABLE {} DISABLE TRIGGER ALL").format(
                #             sql.Identifier(table_name)
                #         )
                #     )

                migrate_data(src_conn, dest_conn, table_name)

                # インデックスとトリガーを再有効化
                # with dest_conn.cursor() as dest_cur:
                #     dest_cur.execute(
                #         sql.SQL("ALTER TABLE {} ENABLE TRIGGER ALL").format(
                #             sql.Identifier(table_name)
                #         )
                #     )

                # テーブルごとの移行時間を出力
                table_elapsed_time = time.time() - table_start_time
                print(
                    f"テーブル '{table_name}' のデータ移行が完了しました。所要時間: {table_elapsed_time:.2f}"
                )

            # 全体の移行時間を出力
            total_elapsed_time = time.time() - start_time
            print(
                f"全てのテーブルのデータ移行が完了しました。総所要時間: {total_elapsed_time:.2f}"
            )

            # すべてのデータ移行が完了したのでコミット
            dest_conn.commit()
            print("データ移行が正常に完了しました。")
        else:
            print("一部のテーブル定義に差異があるため、データ移行を中止します。")

    except Exception as e:
        # エラーが発生した場合、ロールバック
        dest_conn.rollback()
        print(f"エラーが発生しました。すべての変更をロールバックしました: {e}")

    finally:
        # 接続を閉じる
        src_conn.close()
        dest_conn.close()


if __name__ == "__main__":
    main()
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?