0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

IBM Cloud の Watson Studio と Watson Machine learningサービスで、Federated Learningを動かす(API編, 準同型暗号オプション有効化)

Last updated at Posted at 2023-07-10

本記事の目的

本記事では、IBM Cloud上で提供される Cloud Pak for Data as a Service(以降、CP4DaaSと表記) を対象に、Watson Studio および Machine learning サービスを使用した Federated Learning 機能で準同型暗号オプションを適用する方法をご紹介します。

1. Federated Learningと準同型暗号

Federated Learning(FL, 連合学習)

Federated Learning(FL, 連合学習)は、さまざまな場所に分散して存在するデータを一か所に集めることなく、共同で1つの機械学習モデルをトレーニングするための方法です。
さまざまな場所の例として、企業内、企業のコンソーシアム内、複数のデータセンターまたは複数のクラウド内、あるいはエッジ・デバイス上などがあります。
さまざまな場所でローカルにモデルのトレーニングを実行し、各場所でトレーニングされたモデルのパラメーターのみを中央へ集約することで1つのモデルを作成します。
データセキュリティ、プライバシー、および規制遵守要件に対処するとともに、データの移動とそれに関連するコストを抑えられる利点があります。

準同型暗号(HE, Homomorphic Encryption)

準同型暗号は、公開鍵暗号方式の1つの形式です。
この方式で暗号化されたデータを暗号化したままの状態で計算した結果と、元の暗号化されていないデータに対する計算結果とが一致する特徴があります。

Federated Learningで準同型暗号を適用するイメージ

Federated Learningにおいてトレーニング中のデータ・プライバシー保護を強化する手段として、準同型暗号を利用するというアイデアがあります。
Federated Learningでは、さまざまな場所でローカルにモデルのトレーニングを実行し、各場所でトレーニングされたモデルのパラメーターのみを中央へ集約することで1つのモデルを作成します。
この時、モデルのパラメーターを暗号化しておくことで、セキュリティを強化することができます。
中央では暗号化したモデルのパラメーターをそのまま集約し、1つのモデルを生成します。

image.png
出典:Federated Learning FHE Demo
https://dataplatform.cloud.ibm.com/exchange/public/entry/view/aa449d3939b73847c502bd7822d0949a

2. Cloud Pak for Data as a Service の Federated Learning で利用可能な準同型暗号化オプションの概要

Federated Learning の 日本語表記について
Cloud Pak for Data as a Service 日本語版製品資料では「Federated Learning」を「統合学習」と表記していますが、本記事内では一般的に使われることの多い「連合学習」と表記しています。
製品資料をご参照いただく際には、「連合学習」と「統合学習」がともに「Federated Learning」を意味することを前提としてご参照ください。

サポートされるフレームワーク

以下のフレームワークがサポートされています。通常のFederated Learningでサポートされる加重平均融合方式やXGBoostがサポートされていない点についてご注意ください。

(2023/6/30時点)
完全同型暗号化 (FHE) は、以下のモデル・フレームワークの単純な平均融合方式をサポートします。
最新のサポート状況は、製品資料をご覧ください。

  • Tensorflow
  • Pytorch
  • Scikit-learn 分類
  • Scikit-learn 回帰

製品資料:セキュリティーとプライバシーのための同型暗号化の適用 > サポートされるフレームワークと融合方式
https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fl-homo.html

パーティーの要件

(2023/6/30時点)
Linux環境のみのサポートです。また、暗号化に使用する証明書およびRSA鍵はパーティー側で準備することが必要です。
最新のサポート状況は、製品資料をご覧ください。

  • Linux x86 システムで実行します。
  • すべての当事者に共通する認証局を識別するルート証明書を使用して構成します。
  • 以下の表で説明されている属性を使用して、RSA 公開鍵と秘密鍵のペアを構成します。
  • 認証局によって発行されたパーティーの証明書を使用して構成します。 RSA 公開鍵がパーティーの証明書に含まれている必要があります。注: 自己署名証明書を使用するように選択することもできます。

製品資料:セキュリティーとプライバシーのための同型暗号化の適用 > パーティーの要件
https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fl-homo.html

RSA鍵の要件

(2023/6/30時点)最新のサポート状況は、製品資料をご覧ください。

属性 要件
鍵のサイズ 4096 ビット
公開鍵指数 65537
パスワード なし
ハッシュ・アルゴリズム SHA256
ファイル・フォーマット 鍵ファイルと証明書ファイルは「PEM」形式でなければなりません

製品資料:セキュリティーとプライバシーのための同型暗号化の適用 > RSA鍵の要件
https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fl-homo.html

暗号化オプション

暗号化レベルが高いほど、セキュリティーと精度が向上し、より高いリソース消費量 (計算、メモリー、ネットワーク帯域幅など) が必要になります。 デフォルトは暗号化レベル 1 です。

暗号化レベル cipher_specの指定値(REST API利用時に指定)
1 encryption_level_1
2 encryption_level_2
3 encryption_level_3
4 encryption_level_4

(詳細)APIドキュメント:Create a new WML training > federated_learning > crypto
https://cloud.ibm.com/apidocs/machine-learning#trainings-create

3. Cloud Pak for Data as a Service で 準同型暗号化オプションを有効化した Federated Learning を動かす

チュートリアルとしてFederated Learning機能で準同型暗号化を適用するサンプルNotebookが公開されています。ご自身のWatson Studioプロジェクトへ追加して実行することができます。
本章ではサンプルNotebookの各セクションに従って、準同型暗号化を適用したFederated Learningを実行します。
サンプルNotebookの説明文は英語ですが、本記事では翻訳及び一部の説明を補足しています。

製品資料:API の連合学習の同型暗号化サンプル
https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fl-fhe-sample.html

本記事は2023/6/30時点のサンプルNotebookの内容に基づきます。
サンプルNotebookの内容は不定期に更新される可能性があり、本記事の内容と最新のサンプルNotebookの内容に差異がある場合、最新のサンプルNoteobokを正としてご参照ください。

1. Introduction: 導入

IBM Federated Learning を使用すると、ローカル・データ・セットを共有することなく、ローカル・データ・セットを保有する複数の分散型パーティにまたがって機械学習モデルをトレーニングすることができます。このような関係者は、例えば企業内、企業のコンソーシアム内、複数のデータセンターまたは複数のクラウド内、あるいはエッジ・デバイス上に存在することができます。これにより、ノード間でデータを共有することなく、集合的な機械学習モデルを構築することができます。データセキュリティ、プライバシー、および規制遵守要件に対処するとともに、データの移動とそれに関連するコストを排除することができます。

連合学習のトレーニング・プロセスでは、関係者はローカルに学習した機械学習モデルを構築し、これらのローカルモデルをアグリゲーター(集約者)に送信します。アグリゲーターは、ローカル・モデルを融合して集約モデルを作成し、このモデルを関係者に送り返し、次のラウンドのトレーニングを継続させます。

詳細については、IBM Federated Learningのドキュメントを参照してください。

2. Prerequisites: 前提条件

  1. 準同型暗号を使ったFederated Learningでは、パーティー(=データを保有するローカル・トレーニングの実行者)に以下の要件としてLinux x86プラットフォームでの実行が必要です。従って、このNotebookはLinux x86プラットフォーム上で実行する必要があります。

  2. Notebook実行環境にて、準同型暗号化をサポートするIBM Watson Machine Learning Pythonクライアントパッケージをインストールします。このNotebookの次のセクションにあるセルを使用することができます。

    • WML Python クライアントパッケージに加えて、Python3.10用のFLパッケージfl-rt22.2-py3.10と準同型暗号用のパッケージfl-cryptoを追加で指定します。
    pip install 'ibm_watson_machine_learning[fl-rt22.2-py3.10,fl-crypto]'
    
  3. ご自身のIBM cloud accountにて:

    • ユーザーのIAM IDを取得します。IBM cloudのUsersページにアクセスし、該当のユーザーをクリックし、「詳細」をクリックしてIAM IDをコピーして取得します。IAM IDの形式は、IBMid-xxxxxxです。
    • API keys のページでAPIキーを生成し、取得します。
  4. Watson Machine Learningサービスインスタンスを作成します。無料のプランが提供されています。

    • このNotebookで使用するプロジェクトを作成するか、既存のプロジェクトを使用します。

    • プロジェクトへアクセスし、 プロジェクトのID を取得します。

    • プロジェクトへ Watson Machine Learning サービスを関連付けます。

    • Watson Machine Learning サービスインスタンスのロケーション名を確認します。

      WMLサービス・インスタンスの作成リージョン ロケーション名
      米国(ダラス) us-south
      英国(ロンドン) eu-gb
      アジア太平洋(東京) jp-tok
      ヨーロッパ(フランクフルト) eu-de

3. Basic setup: 基本的な設定

このNotebookを実行するPython環境内に、準同型暗号化をサポートするIBM Watson Machine Learning Pythonクライアントパッケージをインストールします。
既にインストール済みの場合も、最新バージョンへ更新します。

%pip install --upgrade 'ibm_watson_machinProject e_learning[fl-rt22.2-py3.10,fl-crypto]'

User action: 以下のセルを実行する前に、セル内の必須TBDを自分の情報に置き換え、オプションTBDを確認してください。

セル内の必須TBD
PROJECT_ID = '' # TBD [mandatory] プロジェクトID
CLOUD_USERID = '' # TBD [mandatory] IBM Cloud IAMユーザーID. IBMid-xxxxxx の形式 
IAM_APIKEY = '' # TBD [mandatory] APIキー
WML_SERVICES_LOCATION = '' # TBD [mandatory] WMLインスタンスのロケーション. us-south, eu-gb, eu-de, jp-tok のいずれか
セル内の必須TBD 記載例
PROJECT_ID = 'XXXXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX' # TBD [mandatory] プロジェクトID
CLOUD_USERID = 'IBMid-xxxxxx' # TBD [mandatory] IBM Cloud IAMユーザーID. IBMid-xxxxxx の形式 
IAM_APIKEY = 'XXXXXXXXXXXXXXXXXXXX' # TBD [mandatory] APIキー
WML_SERVICES_LOCATION = 'jp-tok' # TBD [mandatory] WMLインスタンスのロケーション. us-south, eu-gb, eu-de, jp-tok のいずれか
セル全体
import os
import subprocess
import urllib3
import requests
urllib3.disable_warnings()

cmd = subprocess.Popen("pip list | grep 'ibm-watson-machine-learning'", 
    shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
wml_installed = len(cmd.communicate()[0]) > 0
if not wml_installed:
    raise Exception('ibm-watson-machine-learning package must be installed in the environment')

base_dir = os.getcwd() # TBD [optional] A base directory under which the notebook work directory will be created. Default is the current work directory.
nb_dir = os.path.join(base_dir, 'fl_fhe_nb')
data_path = os.path.join(nb_dir, 'data')
model_path = os.path.join(nb_dir, 'model')
crypto_path = os.path.join(nb_dir, 'crypto')
exec_path = os.path.join(nb_dir, 'exec')
if not os.path.exists(data_path):
    os.makedirs(data_path)
if not os.path.exists(model_path):
    os.makedirs(model_path)
if not os.path.exists(crypto_path):
    os.makedirs(crypto_path)
if not os.path.exists(exec_path):
    os.makedirs(exec_path)
os.chdir(exec_path)

PROJECT_ID = '' # TBD [mandatory] See the prerequisites section for details.
CLOUD_USERID = '' # TBD [mandatory] See the prerequisites section for details.
IAM_APIKEY = '' # TBD [mandatory] See the prerequisites section for details.
WML_SERVICES_LOCATION = '' # TBD [mandatory] See the prerequisites section for details.
WML_SERVICES_URL = 'https://' + WML_SERVICES_LOCATION + '.ml.cloud.ibm.com'
NUM_RTS = int(3) # TBD [optional] This parameter enables to specify the number of parties for a training experiment.
SW_SPEC_NAME = 'runtime-22.2-py3.10'
HW_SPEC_NAME = 'S'
RSC_TAGS = ['wml_fl_fhe_nb_example']
TIMEOUT_TRAINING_SEC = 600
crypto_file_ext = 'v1'
asym_file_is = crypto_path + "/is_asym_" + crypto_file_ext + ".pem"
cert_file_is = crypto_path + "/is_cert_" + crypto_file_ext + ".pem"
asym_file_sb = crypto_path + "/sb_asym_" + crypto_file_ext + "_"
csr_file_sb = crypto_path + "/sb_csr_" + crypto_file_ext + "_"
cert_file_sb = crypto_path + "/sb_cert_" + crypto_file_ext + "_"
prt_data_file_prefix = 'data_party_'
NUM_MODELS = int(1)
MODEL_NAME = 'pytorch'
MODEL_TYPE = 'pytorch-onnx_1.12'
INIT_MODEL_FILE_NAME = 'pt_mnist_init_model.zip'
INIT_MODEL_URL = 'https://github.com/IBMDataScience/sample-notebooks/raw/master/Files/pt_mnist_init_model.zip'
DATA_HANDLER_FILE_NAME = 'mnist_pytorch_data_handler.py'
DATA_HANDLER_CLASS_NAME = 'MnistPytorchDataHandler'
DATASET_FILE_NAME = 'mnist.npz'
DATASET_URL = 'https://api.dataplatform.cloud.ibm.com/v2/gallery-assets/entries/85ae67d0cf85df6cf114d0664194dc3b/data'

hearbeat_resp = requests.get(WML_SERVICES_URL + "/wml_services/training/heartbeat", verify=False)
print("Heartbeat response %s" % hearbeat_resp.content.decode("utf-8"))

4. Create a WML client: WML Clientの作成

このセクションでは、WMLインスタンスとの対話を可能にするWMLクライアントを作成し、アクティブ化します。

from ibm_watson_machine_learning import APIClient

wml_credentials = {
    "url": WML_SERVICES_URL,
    "apikey": IAM_APIKEY
}
wml_client = APIClient(wml_credentials)
wml_client.set.default_project(PROJECT_ID)

5. Create WML assets: WML資産の作成

このNotebookで作成するWMLアセットは、初期モデルとリモートトレーニングシステムです。
このセクションでは、新しいアセットを作成するか、このNotebookの前のセッションで作成され、削除されなかったアセットを再利用することができます。

方法1. Create new assets: WML資産の新規作成

Store initial model assets in the cluster
初期モデルの資産をクラスターに保存する

Federated Learningには、初期状態の未訓練モデル資産が必要です。このNotebookでは、未訓練の Pytorch モデルを使用します。
詳細は、製品資料:初期モデルの作成を参照してください。

まず、事前に作成済みのサンプル初期モデルをダウンロードします。

import shutil
print("Downloading initial model")
init_model_file_path = os.path.join(model_path, INIT_MODEL_FILE_NAME)
with requests.get(INIT_MODEL_URL, stream=True) as r:
    with open(init_model_file_path, 'wb') as f:
        shutil.copyfileobj(r.raw, f)
print('Model stored in: ' + str(init_model_file_path))
print("Done")

次に、初期モデルをアセットとしてクラスターにアップロードします。

print("Storing initial model")
sw_spec_id = wml_client.software_specifications.get_id_by_name(SW_SPEC_NAME)
untrained_model_ids = {}
model_metadata = {
    wml_client.repository.ModelMetaNames.NAME: MODEL_NAME,
    wml_client.repository.ModelMetaNames.TYPE: MODEL_TYPE,
    wml_client.repository.ModelMetaNames.SOFTWARE_SPEC_UID: sw_spec_id,
    wml_client.repository.ModelMetaNames.TAGS: RSC_TAGS
}
untrained_model_details = wml_client.repository.store_model(os.path.join(model_path, INIT_MODEL_FILE_NAME), model_metadata)
untrained_model_ids[MODEL_NAME] = wml_client.repository.get_model_id(untrained_model_details)
print('Model id: ' + str(untrained_model_ids[MODEL_NAME]))
print('Done')

(オプション)サンプル初期モデルの構造を確認
サンプルNotebookに含まれない手順です。
初期モデルの構造を確認したい場合、以下の手順を実施してください。

# ダウンロードした初期モデルのzipファイルを解凍するコマンド
command = ["unzip", "-o", init_model_file_path]
subprocess.call(command)

# モデルの構造を表示
import torch

init_model_file_name = "pytorch_sequence.pt"
model = torch.load(init_model_file_name)
print(model)

PyTorchのSequentialが利用されています。
入力として手書き数字のデータセットMNISTに含まれる1枚の画像(28×28=784ピクセル)を使用し、画像に含まれる手書き数字が1~10(10クラス)のいずれに分類されるかの確率を出力する構造となっています。

モデルの構造
Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=256, bias=True)
  (2): ReLU()
  (3): Linear(in_features=256, out_features=100, bias=True)
  (4): ReLU()
  (5): Linear(in_features=100, out_features=50, bias=True)
  (6): ReLU()
  (7): Linear(in_features=50, out_features=10, bias=True)
  (8): LogSoftmax(dim=1)
)

(参考)Pytorch SEQUENTIAL
https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html

Create remote training systems in the cluster
クラスタ内にリモートトレーニングシステムを構築

リモートトレーニングシステム(RTS)アセットは、FLトレーニングのためにアグリゲータに接続するパーティを定義します。
詳細については、製品資料:ステップ 2: リモート・トレーニング・システムを作成するを参照してください。

print("Creating Remote Training Systems")
remote_training_systems = []
for i in range(NUM_RTS):
    rts_metadata = {
        wml_client.remote_training_systems.ConfigurationMetaNames.NAME: "Party_"+str(i),
        wml_client.remote_training_systems.ConfigurationMetaNames.TAGS: RSC_TAGS,
        wml_client.remote_training_systems.ConfigurationMetaNames.ORGANIZATION: {"name" : "IBM", "region": "US"},
        wml_client.remote_training_systems.ConfigurationMetaNames.ALLOWED_IDENTITIES: [{"id": CLOUD_USERID, "type": "user"}],
        wml_client.remote_training_systems.ConfigurationMetaNames.REMOTE_ADMIN: {"id": CLOUD_USERID, "type":"user"}
    }
    rts = wml_client.remote_training_systems.store(rts_metadata)
    rts_id = wml_client.remote_training_systems.get_id(rts)
    print('Remote training system Party_' + str(i) + ' id: ' + str(rts_id))
    remote_training_systems.append({'id': rts_id, 'required': True})
print('Done')

方法2. Reuse existing assets: 既存のWML資産の再利用

以前のセッションで作成されたWML資産を再利用する場合は、以下のセルを実行してください。
このコードによって、既存の資産からNotebook内部のリストを作成することができます。これらのリストは、このNotebookの後の操作で使用されます。

import json

FORCE_REBUILD_DS = False

print("Models:")
if 'untrained_model_ids' not in globals() or FORCE_REBUILD_DS or \
    len(untrained_model_ids) != NUM_MODELS:
    untrained_model_ids = {}
    load_models_dict = True
else:
    load_models_dict = False
models = wml_client.repository.get_model_details(get_all=True)
for m in models['resources']:
    md = m['metadata']
    if not 'tags' in md or md['tags'] != RSC_TAGS:
        continue
    if load_models_dict:
        untrained_model_ids[md['name']] = md['id']
    print('{}: {}'.format(md['name'],md['id']))

print("Remote Training Systems:")
if 'remote_training_systems' not in globals() or FORCE_REBUILD_DS or \
    len(remote_training_systems) != NUM_RTS:
    remote_training_systems = []
    load_rts_lst = True
else:
    load_rts_lst = False
rts = wml_client.remote_training_systems.get_details()
for r in rts['resources']:
    md = r['metadata']
    if not 'tags' in md or md['tags'] != RSC_TAGS:
        continue
    if load_rts_lst:
        remote_training_systems.append({'id': md['id'], 'required': True})
    print('{}: {}'.format(md['name'],md['id']))

6. Create parties data: Party用のトレーニング・データ作成

ここでは、MNISTデータセットをダウンロードし、各Partyのためにサブセットに分割します。
そして、データハンドラーを定義し、保存します。

Download data set and split it for the parties
データセットをダウンロードし、パーティ用に分割する

import os
import requests
import numpy as np
import shutil

def load_mnist(normalize=True, download_dir=''):
    """
    Download MNIST training data from source used in `keras.datasets.load_mnist`
    :param normalize: whether or not to normalize data
    :type normalize: bool
    :param download_dir: directory to download data
    :type download_dir: `str`
    :return: 2 tuples containing training and testing data respectively
    :rtype (`np.ndarray`, `np.ndarray`), (`np.ndarray`, `np.ndarray`)
    """
    local_file = os.path.join(download_dir, DATASET_FILE_NAME)
    if not os.path.isfile(local_file):
        with requests.get(DATASET_URL, stream=True) as r:
            with open(local_file, 'wb') as f:
                shutil.copyfileobj(r.raw, f)
        with np.load(local_file, allow_pickle=True) as mnist:
            x_train, y_train = mnist['x_train'], mnist['y_train']
            x_test, y_test = mnist['x_test'], mnist['y_test']
            if normalize:
                x_train = x_train.astype('float32')
                x_test = x_test.astype('float32')
                x_train /= 255
                x_test /= 255
        np.savez(local_file, x_train=x_train, y_train=y_train,
                 x_test=x_test, y_test=y_test)
    else:
        with np.load(local_file, allow_pickle=True) as mnist:
            x_train, y_train = mnist['x_train'], mnist['y_train']
            x_test, y_test = mnist['x_test'], mnist['y_test']
    return (x_train, y_train), (x_test, y_test)

def save_mnist_party_data(nb_dp_per_party, should_stratify, party_folder, dataset_folder):
    """
    Saves MNIST party data
    :param nb_dp_per_party: the number of data points each party should have
    :type nb_dp_per_party: `list[int]`
    :param should_stratify: True if data should be assigned proportional to source class distributions
    :type should_stratify: `bool`
    :param party_folder: folder to save party data
    :type party_folder: `str`
    :param dataset_folder: folder to save dataset
    :type data_path: `str`
    :param dataset_folder: folder to save dataset
    :type dataset_folder: `str`
    """
    if not os.path.exists(dataset_folder):
        os.makedirs(dataset_folder)
    (x_train, y_train), (x_test, y_test) = load_mnist(download_dir=dataset_folder)
    labels, train_counts = np.unique(y_train, return_counts=True)
    te_labels, test_counts = np.unique(y_test, return_counts=True)
    diff_labels = np.all(np.isin(labels, te_labels))
    num_train = np.shape(y_train)[0]
    num_test = np.shape(y_test)[0]
    num_labels = np.shape(np.unique(y_test))[0]
    nb_parties = len(nb_dp_per_party)
    if should_stratify:
        train_probs = {
            label: train_counts[label] / float(num_train) for label in labels}
        test_probs = {label: test_counts[label] /
                        float(num_test) for label in te_labels}
    else:
        train_probs = {label: 1.0 / len(labels) for label in labels}
        test_probs = {label: 1.0 / len(te_labels) for label in te_labels}
    for idx, dp in enumerate(nb_dp_per_party):
        train_p = np.array([train_probs[y_train[idx]]
                            for idx in range(num_train)])
        train_p /= np.sum(train_p)
        train_indices = np.random.choice(num_train, dp, p=train_p)
        test_p = np.array([test_probs[y_test[idx]] for idx in range(num_test)])
        test_p /= np.sum(test_p)
        test_indices = np.random.choice(
            num_test, int(num_test / nb_parties), p=test_p)
        x_train_pi = x_train[train_indices]
        y_train_pi = y_train[train_indices]
        x_test_pi = x_test[test_indices]
        y_test_pi = y_test[test_indices]
        name_file = prt_data_file_prefix + str(idx) + '.npz'
        name_file = os.path.join(party_folder, name_file)
        np.savez(name_file, x_train=x_train_pi, y_train=y_train_pi,
                 x_test=x_test_pi, y_test=y_test_pi)
    print('Data saved in ' + party_folder)
    return

save_mnist_party_data(nb_dp_per_party=[200 for _ in range(NUM_RTS)], should_stratify=False, 
    party_folder=data_path, dataset_folder=data_path)
print('Done')

Define and store a data handler
データハンドラーを定義し保存する

ここでは、PyTorchを使って学習するための、MNISTデータセットのデータハンドラPythonファイルを作成します。
その他の詳細については、製品資料:データ・ハンドラーの作成を参照してください。

%%writefile mnist_pytorch_data_handler.py
import numpy as np
from ibmfl.data.data_handler import DataHandler

class MnistPytorchDataHandler(DataHandler):
    """
    Data handler for the MNIST dataset to train using PyTorch.
    """

    def __init__(self, data_config=None):
        super().__init__()
        self.file_name = None
        if data_config is not None:
            if 'npz_file' in data_config:
                self.file_name = data_config['npz_file']
        # Load the datasets.
        (self.x_train, self.y_train), (self.x_test, self.y_test) = self.load_dataset()
        # Pre-process the datasets.
        self.preprocess()

    def get_data(self):
        """
        Gets pre-process mnist training and testing data.

        :return: training data
        :rtype: `tuple`
        """
        return (self.x_train, self.y_train), (self.x_test, self.y_test)

    def load_dataset(self, nb_points=500):
        """
        Loads the training and testing datasets from a given local path.
        If no local path is provided, it will download the original MNIST \
        dataset online, and reduce the dataset size to contain \
        500 data points per training and testing dataset.
        Because this method
        is for testing it takes as input the number of datapoints, nb_points,
        to be included in the training and testing set.

        :param nb_points: Number of data points to be included in each set if
        no local dataset is provided.
        :type nb_points: `int`
        :return: training and testing datasets
        :rtype: `tuple`
        """
        try:
            data_train = np.load(self.file_name)
            x_train = data_train['x_train']
            y_train = data_train['y_train']
            x_test = data_train['x_test']
            y_test = data_train['y_test']
        except Exception:
            raise IOError('Unable to load training data from path '
                            'provided in config file: ' +
                            self.file_name)
        return (x_train, y_train), (x_test, y_test)

    def preprocess(self):
        """
        Preprocesses the training and testing dataset, \
        e.g., reshape the images according to self.channels_first; \
        convert the labels to binary class matrices.

        :return: None
        """
        img_rows, img_cols = 28, 28
        self.x_train = self.x_train.astype('float32').reshape(self.x_train.shape[0], 1, img_rows, img_cols)
        self.x_test = self.x_test.astype('float32').reshape(self.x_test.shape[0], 1,img_rows, img_cols)
        self.y_train = self.y_train.astype('int64')
        self.y_test = self.y_test.astype('int64')
import shutil
shutil.move(os.path.join('.', DATA_HANDLER_FILE_NAME), os.path.join(data_path, DATA_HANDLER_FILE_NAME))

7. Create parties cryptographic elements: パーティーの暗号化要素を作成する

このセクションでは、Federated Learningのトレーニングで暗号化を適用するために必要な証明書と鍵のファイルを作成します。
このセクションでは、暗号化ファイルを作成するために、Python cryptography パッケージを使用する方法と、openssl を使用する方法の2つが提供されています。どちらか一方の方法を使用してください

準同型暗号鍵はFLトレーニングごとに自動的に生成され、関係者間で安全に配布されます。実験に参加するパーティーだけが、生成された秘密鍵にアクセスできます。
この生成・配布プロセスを円滑にするために、FLトレーニングの実施前に以下の手順を実行する必要があります。

  • FLトレーニングに参加するすべてのパーティは、単一の認証局の利用について合意しなければならない。
  • 各パーティは、合意された認証局から証明書を提供されなければならない。
  • 各パーティは、RSA鍵ペアを提供されなければならない。RSA*公開鍵は、前述のパーティ証明書に含まれていなければならない。

パーティーのRSAキーペアと証明書は、以下のパラメータとガイドラインを使用して生成する必要があります。

  • 鍵の種類 鍵の種類:RSA
  • 鍵のサイズ:4096ビット
  • 公開指数:65537
  • RSA*鍵ファイルにパスワードをかけない
  • ハッシュアルゴリズム: SHA256
  • 鍵ファイルと証明書ファイルはPEM形式である必要があります

各パーティは、以下のファイルへのパスが設定されている必要があります。

  • 認証局の証明書
  • 認証局が発行した各パーティーごとの証明書(パーティーのRSA公開鍵を含む)
  • パーティーのRSA秘密鍵。

この構成の詳細については、Notebookの「Launch parties」セクションで説明します。
このNotebookでは、自己署名証明書の生成とプロビジョニングを行います。

Method 1: Using Python Cryptography package

方法1:Python Cryptographyパッケージの利用

Method 1: Using Python Cryptography package
import os
import datetime
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography import x509
from cryptography.x509.oid import NameOID

class CryptoRsa():

    KEY_SIZE = 4096
    PUBLIC_EXPONENT = 65537
    CRYPTO_HASH = hashes.SHA256()

    def __init__(self):
        self.private_key = CryptoRsa.generate_key()

    def generate_key():
        private_key = rsa.generate_private_key(
            public_exponent=CryptoRsa.PUBLIC_EXPONENT,
            key_size=CryptoRsa.KEY_SIZE,
        )
        return private_key

    def get_public_key(self, type: str = "obj"):
        if self.private_key is None:
            raise Exception("self.private_key is None")
        if type == "obj":
            ret = self.private_key.public_key()
        elif type == "pem":
            ret = self.private_key.public_key().public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo
            )
        else:
            raise Exception("Invalid type=" + repr(type))
        return ret

    def write_key_file(self, file_path: str):
        if self.private_key is None:
            raise Exception("self.private_key is None")
        pem = self.private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption()
        )
        with open(file_path, "wb") as key_file:
            key_file.write(pem)
        return

if not os.path.exists(crypto_path):
    os.makedirs(crypto_path)

issuer = x509.Name([
    x509.NameAttribute(NameOID.COUNTRY_NAME, u"US"),
    x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, u"California"),
    x509.NameAttribute(NameOID.LOCALITY_NAME, u"San Francisco"),
    x509.NameAttribute(NameOID.ORGANIZATION_NAME, u"Issuer Company"),
    x509.NameAttribute(NameOID.COMMON_NAME, u"mysite.com"),
])
subject = x509.Name([
    x509.NameAttribute(NameOID.COUNTRY_NAME, u"US"),
    x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, u"California"),
    x509.NameAttribute(NameOID.LOCALITY_NAME, u"San Francisco"),
    x509.NameAttribute(NameOID.ORGANIZATION_NAME, u"Subject Company"),
    x509.NameAttribute(NameOID.COMMON_NAME, u"mysite.com"),
])

issuer_key = CryptoRsa()
issuer_key.write_key_file(asym_file_is)

cert_is = x509.CertificateBuilder().subject_name(
    issuer
).issuer_name(
    issuer
).public_key(
    issuer_key.get_public_key()
).serial_number(
    x509.random_serial_number()
).not_valid_before(
    datetime.datetime.utcnow()
).not_valid_after(
    datetime.datetime.utcnow() + datetime.timedelta(days=1000)
).add_extension(
    x509.SubjectAlternativeName([x509.DNSName(u"localhost")]),
    critical=False,
).sign(issuer_key.private_key, CryptoRsa.CRYPTO_HASH)
with open(cert_file_is, "wb") as f:
    f.write(cert_is.public_bytes(serialization.Encoding.PEM))

for idx in range(NUM_RTS):
    asym_file_path = asym_file_sb+str(idx)+".pem"
    cert_file_path = cert_file_sb+str(idx)+".pem"
    subject_key = CryptoRsa()
    subject_key.write_key_file(asym_file_path)
    cert_sb = x509.CertificateBuilder().subject_name(
        subject
    ).issuer_name(
        issuer
    ).public_key(
        subject_key.get_public_key()
    ).serial_number(
        x509.random_serial_number()
    ).not_valid_before(
        datetime.datetime.utcnow()
    ).not_valid_after(
        datetime.datetime.utcnow() + datetime.timedelta(days=1000)
    ).add_extension(
        x509.SubjectAlternativeName([x509.DNSName(u"localhost")]),
        critical=False,
    ).sign(issuer_key.private_key, CryptoRsa.CRYPTO_HASH)
    with open(cert_file_path, "wb") as f:
        f.write(cert_sb.public_bytes(serialization.Encoding.PEM))

print('Done')

Method 2: Using openssl**

方法2:opensslを使用する

Method 2: Using openssl
import os

if not os.path.exists(cert_file_is):
    ret = os.system("openssl req -x509 -newkey rsa:4096 -sha256 -days 365 -nodes "
        "-subj \"/C=US/ST=California/L=San Francisco/O=Issuer Company/OU=Org/CN=www.iscompany.com\" -keyout " + 
        str(asym_file_is) + " -out " + str(cert_file_is))
    if ret != 0:
        raise Exception("openssl for issuer failed: {}".format(ret))

for idx in range(NUM_RTS):
    asym_file_path = asym_file_sb+str(idx)+".pem"
    csr_file_path = csr_file_sb+str(idx)+".pem"
    cert_file_path = cert_file_sb+str(idx)+".pem"
    if not os.path.exists(cert_file_path):
        ret = os.system("openssl req -newkey rsa:4096 -nodes -subj "
            "\"/C=US/ST=California/L=San Francisco/O=SB Company/OU=Org/CN=www.sbcompany.com\" -keyout " +
            str(asym_file_path) + " -out " + str(csr_file_path))
        if ret != 0:
            raise Exception("openssl for subject step 1 failed: {}".format(ret))
        ret = os.system("openssl x509 -req -CAcreateserial -CA " + str(cert_file_is) + " -CAkey " + str(asym_file_is) +
            " -sha256 -days 365 -in " + str(csr_file_path) + " -out " + str(cert_file_path))
        if ret != 0:
            raise Exception("openssl for subject step 2 failed: {}".format(ret))

print('Done')

8.Launch aggregator: アグリゲータの起動

このセクションでは、実験用のFederated Learningアグリゲータを起動します。
アグリゲーターが起動すると、今回のFLトレーニングを識別するためのIDが割り振られます。
サンプルNotebookのコードでは、変数training_idにこの値がセットされています。

準同型暗号で実験を実行するには、アグリゲーターの設定に以下の融合タイプを指定する必要があります。
"fusion_type": "crypto_iter_avg"

また、アグリゲーターの設定へ、必要な暗号化レベルを指定する crypto オブジェクトを含めることもできます。例えば、以下のようになります。

"crypto": {
    "cipher_spec": "encryption_level_1"
}

このオブジェクトが指定されない場合、デフォルトの 暗号化レベル1 encryption_level_1 が使用されます。

暗号化レベルは、レベル1からレベル4までの4段階があります。暗号化レベルが高いほど、セキュリティと精度が向上し、より高いリソース消費(計算、メモリ、ネットワーク帯域幅など)を必要とします。セキュリティレベルは、暗号化システムの強度に対応し、通常、攻撃者がシステムを破るために実行しなければならない操作の数で測定されます。精度レベルは、暗号化システムの結果の精度に対応する。精度レベルが高いということは、暗号演算が浮動小数点の前後の桁数まで正確であることを意味します。精度レベルが高いほど、暗号化操作によるモデルの精度の低下を抑えることができます。

  • encryption_level_1 暗号化レベル1は、高いセキュリティと良好な精度を提供し、デフォルトのレベルである。
  • encryption_level_2 暗号化レベル2は、高いセキュリティと高い精度を提供し、レベル1より多くのリソースを必要とします。
  • encryption_level_3 暗号化レベル3は、高いセキュリティと優れた精度を提供し、レベル2より多くのリソースを必要とします。
  • encryption_level_4 暗号化レベル4は、特別に高いセキュリティと高い精度を提供し、レベル3より多くのリソースを必要とします。
    アグリゲーターの起動に関する詳細については、製品資料:アグリゲーターの開始 (管理者)を参照してください。
fl_conf = {
    "model": {
      "type": MODEL_NAME,
      "spec": {
        "id": untrained_model_ids[MODEL_NAME]
      },
      "model_file": "pytorch_sequence.pt"
    },
    "fusion_type": "crypto_iter_avg",
    "crypto": {
      "cipher_spec": "encryption_level_1"
    },
    "epochs": 1,
    "rounds": 2,
    "metrics": "accuracy",
    "remote_training": {
      "max_timeout": TIMEOUT_TRAINING_SEC,
      "quorum": 1,
      "remote_training_systems": remote_training_systems,
    },
    "software_spec": {
      "name": SW_SPEC_NAME
    },
    "hardware_spec": {
      "name": HW_SPEC_NAME
    }
}
aggregator_metadata = {
    wml_client.training.ConfigurationMetaNames.NAME: 'aggregator_he',
    wml_client.training.ConfigurationMetaNames.DESCRIPTION: '',
    wml_client.training.ConfigurationMetaNames.TAGS: RSC_TAGS,
    wml_client.training.ConfigurationMetaNames.TRAINING_DATA_REFERENCES: [],
    wml_client.training.ConfigurationMetaNames.TRAINING_RESULTS_REFERENCE: {
        "type": "container",
        "name": "outputData",
        "connection": {},
        "location": {
          "path": "."
        }
    },
    wml_client.training.ConfigurationMetaNames.FEDERATED_LEARNING: fl_conf
}
print("Prepared config for aggregator with model type {}".format(MODEL_NAME))
aggregator = wml_client.training.run(aggregator_metadata, asynchronous=True)
print("Created Aggregator")
training_id = wml_client.training.get_id(aggregator)
print("Training id: " + str(training_id))
print ("RTS: " + str(remote_training_systems))

9. Launch parties: パーティーの起動

このセクションでは、実験用の Federated Learning パーティーを起動します。

FLトレーニングで準同型暗号を適用するには、パーティの構成オプションlocal_trainingオブジェクトの中に cryptoオブジェクトを含める必要があり、このオブジェクトへパーティに必要な証明書と鍵ファイルを指定します。

(詳細)APIドキュメント:Create a new WML training > federated_learning > crypto
https://cloud.ibm.com/apidocs/machine-learning#trainings-create

パーティーの起動に関するその他の詳細については製品資料:アグリゲーター (パーティー) への接続 を参照してください。

import os

for idx, prt in enumerate(remote_training_systems):
    party_metadata = {
        wml_client.remote_training_systems.ConfigurationMetaNames.LOCAL_TRAINING: {
            "info": {
                "crypto": {
                    "key_manager": {
                        "key_mgr_info": {
                            "distribution": {
                                "ca_cert_file_path": cert_file_is,
                                "my_cert_file_path": cert_file_sb+str(idx)+'.pem',
                                "asym_key_file_path": asym_file_sb+str(idx)+'.pem'
                            }
                        }
                    }
                }
            }
        },
        wml_client.remote_training_systems.ConfigurationMetaNames.DATA_HANDLER: {
            "info": {
                "npz_file": os.path.join(data_path, prt_data_file_prefix+str(idx)+'.npz')
            },
            "name": DATA_HANDLER_CLASS_NAME,
            "path": os.path.join(data_path, DATA_HANDLER_FILE_NAME)
        }
    }
    print("Connecting party id {} to aggregator id {}, model type {}".format(prt['id'], training_id, MODEL_NAME))
    party = wml_client.remote_training_systems.create_party(prt['id'], party_metadata)
    party.monitor_logs("ERROR")
    party.run(aggregator_id=training_id, asynchronous=True, verify=False)
    print("Party {} is running".format(prt['id']))
print('Done')

10. Monitor execution status of the training: トレーニングの実行状況をモニターする

このセクションでは、FLトレーニングの実行状況をモニターすることができます。
training_idを指定することで、今回のFLトレーニングの実行状況をモニターしています。
FLトレーニングの実行における各ステータスの詳細については、製品資料:エクスペリメントのモニターを参照してください。(

import time
import json

def monitor_training(training_id):
    print('Monitoring training id: {}'.format(training_id))
    MAX_ITER = 240
    SLP_TIME_SEC = 10
    aggregator_status = wml_client.training.get_status(training_id)
    aggregator_state = aggregator_status['state']
    iter = 0
    while iter < MAX_ITER and 'completed' != aggregator_state and 'failed' != aggregator_state and 'canceled' != aggregator_state:
        print("Elapsed time: {} seconds, State: {}".format(iter*SLP_TIME_SEC, aggregator_state))
        time.sleep(SLP_TIME_SEC)
        aggregator_status = wml_client.training.get_status(training_id)
        aggregator_state = aggregator_status['state']
        iter += 1
    if iter >= MAX_ITER:
        raise Exception("Training did not finish after {} seconds".format(iter*SLP_TIME_SEC))
    print("Final status: " + json.dumps(aggregator_status, indent=4))

if 'training_id' in globals():
    monitor_training(training_id)
else:
    trn = wml_client.training.get_details(get_all=True)
    for t in trn['resources']:
        md = t['metadata']
        if 'tags' in md and md['tags'] == RSC_TAGS:
            monitor_training(md['id'])

(オプション)トレーニング済みモデルをプロジェクトへ保存する

サンプルNotebookに含まれない手順です。
もし、学習したモデルを予測に利用したい場合、以下の手順でモデルの保存を実施してください。

モデルをプロジェクトへ保存
##### 関数定義
import ibm_boto3
from ibm_botocore.client import Config, ClientError

def setup_cos_client(wslib_instance, ibm_cloud_api_key):    

"""Cloud Object Storageクライアントのセットアップ"""

projet_storage_metadata = wslib.here.get_storage()

COS_ENDPOINT = "https://s3.private." + projet_storage_metadata["properties"]["bucket_region"] + ".cloud-object-storage.appdomain.cloud"
COS_RESOURCE_INSTANCE_ID = projet_storage_metadata["properties"]["credentials"]["viewer"]["resource_key_crn"]
COS_APIKEY = ibm_cloud_api_key

cos_client = ibm_boto3.resource("s3",
    ibm_api_key_id=COS_APIKEY,
    ibm_service_instance_id=COS_RESOURCE_INSTANCE_ID,
    config=Config(signature_version="oauth"),
    endpoint_url=COS_ENDPOINT
)

return cos_client


def get_cos_file_content(cos_client, bucket_name, file_name):

"""COSバケット上のファイル内容の取得"""

file_content = cos_client.Object(bucket_name, file_name).get()
data = json.loads(file_content["Body"].read())

return data


def save_model_to_repositry(training_id, base_model_name, trained_model_data, model_name=None):

"""プロジェクトにtrained modelを保存"""

# model metadataの組み立て
if model_name is None:
    model_name = f"{base_model_name}_{training_id}"
    
sw_spec_id = wml_client.software_specifications.get_id_by_name(trained_model_data["software_spec"]["name"])

metadata = {
    wml_client.repository.ModelMetaNames.NAME: model_name,
    wml_client.repository.ModelMetaNames.SOFTWARE_SPEC_UID: sw_spec_id,
    wml_client.repository.ModelMetaNames.TYPE: trained_model_data["type"]

}

# プロジェクトへtrained modelを保存
details = wml_client.repository.store_model(model=training_id, meta_props=metadata)

return details

##### 処理の実行
print('Save FL trained model')
# COSからtrained modelの情報を取得
BUCKET_NAME = wslib.here.get_storage()['properties']['bucket_name']
ITEM_NAME = training_id + "/assets/" + training_id + "/resources/wml_model/request.json"
cos_client = setup_cos_client(wslib, IAM_APIKEY)
trained_model_data = get_cos_file_content(cos_client=cos_client, bucket_name=BUCKET_NAME, file_name=ITEM_NAME)

# WMLクライアントのセットアップ
from ibm_watson_machine_learning import APIClient

wml_credentials = {
    "url": WML_SERVICES_URL,
    "apikey": IAM_APIKEY
}
wml_client = APIClient(wml_credentials)
wml_client.set.default_project(PROJECT_ID) #「3. Basic setup: 基本的な設定」のセクションで設定したPROJECT_IDを利用

# プロジェクトへtrained modelを保存
model_save_name = f"{MODEL_NAME}_{training_id}"  # モデルの名前を指定。今回は「3. Basic setup: 基本的な設定」のセクションで設定したMODEL_NAMEとtraining_idを組み合わせて利用。
details = save_model_to_repositry(training_id=training_id, base_model_name=MODEL_NAME, trained_model_data=trained_model_data, model_name=model_save_name)
model_id = details["metadata"]["id"]
print(f"Save FL trained model Done model id:{model_id}")
  • モデルがプロジェクトへ保存されたことを確認します。
    • モデルの一覧へ、一つ前の手順で保存した学習済みモデルと、初期モデル(「5. Create WML assets: WML資産の作成」のセクションで保存した初期状態の未訓練モデル資産)が表示されます。
      • PytorchのモデルはTYPEが pytorch-onnx_1.12 と表示されており、ONNX形式として保持されていることがわかります。
プロジェクトへ保存されたモデルの一覧
# モデル一覧
wml_client.repository.list_models()
実行例:プロジェクトへ保存されたモデルの一覧
------------------------------------  ---------------------------------------------------------------  ------------------------  -----------------  ----------  ----------------
ID                                    NAME                                                             CREATED                   TYPE               SPEC_STATE  SPEC_REPLACEMENT
4813469b-7f00-46fe-b9d6-5307e8b67e0d  pt_mnist_init_model                                              2023-06-16T11:13:04.002Z  pytorch-onnx_1.12  supported
b3cfb015-5c34-4320-aa0c-4280092bbfa1  pytorch_bba9a5de-0f8d-4364-88fe-f62b54cea62f                     2023-06-08T11:49:22.002Z  pytorch-onnx_1.12  supported

11. Cleanup: クリーンアップ

このNotebookを使って作成したFLトレーニングジョブ、資産、ローカルファイルを削除するには、このセクションを使用します。

Remove training jobs
FLトレーニングジョブの削除

print('Removing training jobs')
trn = wml_client.training.get_details(get_all=True)
for t in trn['resources']:
    md = t['metadata']
    if 'tags' in md and md['tags'] == RSC_TAGS:
        wml_client.training.cancel(md['id'], hard_delete=True)
        print('Deleted {}: {}'.format(md['name'],md['id']))
print('Done')

Remove remote training systems and models
リモート・トレーニング・システムとモデルの削除

print('Removing remote training systems')
rts = wml_client.remote_training_systems.get_details(get_all=True)
for r in rts['resources']:
    md = r['metadata']
    if 'tags' in md and md['tags'] == RSC_TAGS:
        wml_client.repository.delete(md['id'])
        print('Deleted {}: {}'.format(md['name'],md['id']))

print('Removing models')
models = wml_client.repository.get_model_details(get_all=True)
for m in models['resources']:
    md = m['metadata']
    if 'tags' in md and md['tags'] == RSC_TAGS:
        wml_client.repository.delete(md['id'])
        print('Deleted {}: {}'.format(md['name'],md['id']))

print('Done')

Remove local files
ローカルファイルの削除

import shutil
shutil.rmtree(nb_dir)

4. 関連資料

CP4DaaSのFederated Learningに関連する資料です。

製品資料:IBM Federated Learning
https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fed-lea.html

製品資料:Applying homomorphic encryption for security and privacy
https://dataplatform.cloud.ibm.com/docs/content/wsj/analyze-data/fl-homo.html

APIドキュメント:Watson Machine Learning
https://cloud.ibm.com/apidocs/machine-learning-cp

Pythonライブラリ:ibm-watson-machine-learning
https://pypi.org/project/ibm-watson-machine-learning/

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?