@kenmaroです。
普段は主に秘密計算、準同型暗号などの記事について投稿しています。
秘密計算に関連するまとめの記事に関しては以下をご覧ください。
概要
格子暗号ベースの準同型暗号のこれからのデファクトスタンダードになりそうな
ライブラリ、OpenFHE
について、以前の記事
でチュートリアルをいくつかやってみました。
今回は、さらに応用を加えたチュートリアルということで、
- CKKS形式でブートストラップ手法を用いて線形回帰モデルの学習を実装
というテーマでプログラムを実装したので、そちらの解説を行いたいと思います。
これまでのチュートリアルで使用したプログラムは、全て
こちらに上げていますので、興味ある方はご覧ください。
また、実行結果だけ興味のある方は、この後の 実行結果 の章まで飛んでください。
注意
この実装は動作確認のための簡易パラメータを使用しています。
このプログラムは、セキュリティパラメータ
parameters.SetSecurityLevel(HEStd_NotSet);
を使用しており、本番システムでは必要なパラメータが変更され、実行時間や精度が大きく変更される可能性が高いです。
あくまでも動作確認のためのパラメータとお考えください。
parameters.SetSecurityLevel(HEStd_128_classic);
セキュリティパラメータ128ビットを使用した際は、多項式の長さを16ビットに設定する必要があり、実行時間が大幅に(少なくとも16倍以上に)上昇しました。
128 bit セキュリティパラメータ使用時
time_bootstrap_w_b: 102952 msec
time_batch: 119383 msec
1回のブートストラップに10秒程度、
1バッチの学習に120秒程度実行時間がかかっています。
よって、概算すると 2分 * 10バッチ * 30 = 10時間程度の学習時間となりそうです。
線形回帰学習のアルゴリズム
基本的に準同型暗号を用いるとき、線形演算をベースにしてアルゴリズムを組んでいきます。
たとえば、非線形問題のソルバーを準同型暗号で演算しようとすると、例えば行列の固有値を求めたりとか、ラグランジュの未定定数方程式を解いたりすると思いますが、そのような計算は準同型暗号(ここでは格子暗号)では難しいです。
したがって、基本的には機械学習でモデルを学習するのと同じように、
勾配法によって学習パラメータを地道にアップデートするような構成にします。
線形回帰モデルにおいて今回は、普通の二乗誤差をロスとして定義し、
学習パラメータとなる重みとバイアスをアップデートしていきます。
なぜ線形回帰モデルの学習を実装する価値があるのか
なぜ線形回帰モデルを学習するようなプログラムを組んでいるのか、
と疑問に思った方もいらっしゃるかもしれません、ScikitLearnであれば一瞬でモデルを学習できますよね。
今回重要なのは、モデルを学習する際にサーバに送信する学習データ(train_x, train_y) 及び、
サーバで学習される重みパラメータ(w, b)の登場人物は全て暗号状態で学習することが可能です。
Microsoft SEAL でできたことと何が違うのか?
格子暗号のアプリケーションを作成する際に一番使用されているライブラリは、いまのところ
Microsoftリサーチが開発、メンテナンスしている SEAL ライブラリでした。
SEALライブラリには、CKKS形式という浮動小数のベクトルを暗号化することのできるCKKS形式という格子暗号形式が実装されており、浮動小数とベクトルで線形演算をする、というのは機械学習的アプリケーションではほぼ必須であるため、この形式が実際のアプリケーションを組み上げるときには必須となっていました。
わたしの今までの記事で紹介してきた、TFHE形式(トーラス上で構成した格子暗号)にも
たくさんの利点がありましたが、速度と精度に大きな問題があり、ハードウェアの開発無しでは実用アプリケーションは難しいだろう、というのがディベロッパーの正直な印象だったかと思います。
さて、ここでOpenFHEについてまとめた前回記事では、基本的にはDuality Technologyをはじめとする世界の格子暗号研究トップランナーが、今までのPalisadeライブラリの延長上として開発を進めてきたライブラリだということを言及しました。
しかし、いくつかの実装において、今までの格子暗号ライブラリの上位互換となりうる機能が導入されており、その中でも、CKKS形式における
- プロキシ再暗号化
- Threshold暗号
- ブートストラップ
の3つの機能があり、それらについて今までチュートリアルとしていくつかプログラムを解説しました。
この中でも、一般的なアプリケーションであれば、ブートストラップ法が一番強力であり、
今までCKKS形式の実装はレベル準同型暗号であった、つまり、回路の深さには制限があり、一定の処理を行った後には、サーバからクライアントに暗号を一度戻し、復号処理を行ってノイズを削減し、
再度暗号化してサーバサイドでまた準同型演算を行い、これらのやりとりを何度も駆使して目的の演算を実行する、という制約がありました。
例えば、線形回帰モデルに対しても、勾配法を用いて学習を行う際は、学習データをバッチに分けそれらに対してパラメータを更新する作業を複数エポック回実行する必要がありました。
このときに計算に必要な乗算はかなり大きくなることが予想されるため、多くのタイミングで先ほどのサーバとクライアントの間の通信を発生させ、学習を収束させる必要があったのです。
この手法はレベル準同型暗号という制約を理解していれば避けては通れないシステム構成でした。
しかし、今回のOpenFHEにはCKKS形式のブートストラップ法が実装されており、レベルを消費した暗号であっても、レベルを再消費できるように戻すような演算が実装されています。
前回の記事で、ブートストラップと準同型暗号を組み合わせたチュートリアルコードを実装し、
思ったよりもかなり演算速度が高速なことを知り、これはブートストラップを使用して線形回帰モデルの学習も可能だな、と感じたため、今回実装してみました。
ブートストラップを各バッチ毎に重みパラメータに対して施すことで、
学習の全過程の演算を、サーバサイドで完結することができています。
プログラムの解説
プログラムはあまり綺麗に書こうとはせずに動くものを作ったので、600行くらいのプログラムになってしまっていますが、こちらになります。
int main(int argc, char* argv[]) {
int attrs = 10;
int ds = 100;
int bs = 10;
double lr = 0.05;
int epoch = 30;
のmain関数のパラメータは、少し解説すると、
attrs: データの次元数(特徴量の数)
ds: データサイズ(100件のデータに対して学習を実行)
bs: 学習する際のバッチサイズ
lr: 学習率
epoch: エポック数
となっています。
学習のメインループですが、
for (int k = 0; k < epoch; k++) {
Struct_VVVD_VVD batched_data = make_mini_batches(datas.xs, datas.ys, bs);
for (int i = 0; i < batched_data.xs.size(); i++) {
printf("epoch: %d, batch %d/%ld\n", k, i, batched_data.xs.size());
vector<double> batch_x = flatten_xs(batched_data.xs[i]);
vector<double> batch_y = batched_data.ys[i];
vector<double> ppd_batch_y = pp_y(batch_y, attrs, bs);
auto enc_xs = encrypt(batch_x, kc);
auto enc_ys = encrypt(ppd_batch_y, kc);
//printf("enc done\n");
start = get_time_msec();
w_b = main_batch(enc_xs, enc_ys, w_b, kc, attrs, bs, lr);
stop = get_time_msec(
となっており、epoch
のループの中にbatch
のループが回っています。
イニシャライズとして重みは初期化された後暗号化され、サーバに送られます。
batch
のループ毎に、クライアントはバッチに含まれる学習データを暗号化し、サーバに送信します。
サーバは順伝搬、逆伝搬を実行し、重みをサーバサイドでアップデートします。
アップデートが完了したら、サーバは重みに対してブートストラップを施し、レベルを初期化します。
以下、エポックのループが全て消化されるまでこれを繰り返します。
サーバサイドの学習コードのメイン関数は、
Struct_enc_w_b main_batch(Ciphertext<DCRTPoly>& x, Ciphertext<DCRTPoly>& y, Struct_enc_w_b& w_b,
Struct_key_and_context& kc, int attrs, int bs, double lr) {
auto y_hat = forward(x, w_b, kc, attrs, bs);
auto loss = calc_loss(y_hat, y, kc, attrs, bs);
Struct_update_dw_db dw_db = backward(y_hat, y, x, kc, attrs, bs, lr);
Struct_enc_w_b new_w_b = update(w_b, dw_db, kc, attrs, bs);
new_w_b.loss = loss;
return new_w_b;
}
この関数で行われています。
順伝搬や逆伝搬の式は全てラップしているため、ここで詳細は書きませんが、
興味のある方はそれぞれの関数の中身をチェックしてみてください。
実行結果
テスト環境は、
Processor: 2.3 GHz 8-Core Intel Core i9 (8 Core 16 thread)
Memory: 64 GB 2667 MHz DDR4
のMac book Pro です。
また、先ほど言及したように、
- データ数:100
- バッチ数:10
- データのカラム数: 10
- エポック数: 30
で実験しています。
30エポックも行う必要はなかったのですが、以下のようにロスが下がっていることが確認できました。
ブートストラップを各処理毎に入れたとしても、
実行時間は1638秒
となりました。(約30分)
二乗誤差の推移
注意
このプログラムは、セキュリティパラメータ
parameters.SetSecurityLevel(HEStd_NotSet);
を使用しており、本番システムでは必要なパラメータが変更され、実行時間や精度が大きく変更される可能性が高いです。
あくまでも動作確認のためのパラメータとお考えください。
parameters.SetSecurityLevel(HEStd_128_classic);
セキュリティパラメータ128ビットを使用した際は、多項式の長さを16ビットに設定する必要があり、実行時間が大幅に(少なくとも16倍以上に)上昇しました。
128 bit セキュリティパラメータ使用時
time_bootstrap_w_b: 102952 msec
time_batch: 119383 msec
1回のブートストラップに10秒程度、
1バッチの学習に120秒程度実行時間がかかっています。
よって、概算すると 2分 * 10バッチ * 30 = 10時間程度の学習時間となりそうです。
しかしながら、これが一度も一時復号などを実行せずに完了しているというのは、自分の中ではかなり革命的な結果となりました。
まとめ
今回は、OpenFHEチュートリアルとして、線形回帰モデルの学習を実装してみました。
かなり汚いコードではあるのですが、やりたかったところまでばーっと実装することができてよかったです。
次回はおそらく、OpenFHE + Webassembly というテーマをやってみたいのですが、
こちらに関しては実装が重くなる可能性があるので、簡易的なチュートリアルに止めることになると思います。
以上、巷ではChatGPTなどいろいろな画期的なAIモデルが誕生し注目を集めている中、
線形回帰モデルの学習を実装しているという秘密計算界隈ですが、
このCKKS+ブートストラップ(+プロキシ再暗号化など)は実装面ではブレークスルーになると思います。
これからも是非ウォッチしていきましょう!
今回はこの辺で。