kenmaro です。
秘密計算、特に準同型暗号のことについて記事を書いています。
秘密計算エンジニアとして得た全ての知見をまとめた記事はこちら。
https://qiita.com/kenmaro/items/74c3147ccb8c7ce7c60c
これから準同型暗号について勉強したいリサーチャー、エンジニアの方へのロードマップはこちら。
https://qiita.com/kenmaro/items/f2d4fb84833c308a4d29
今話題のゼロ知識証明について解説した記事はこちら。
https://qiita.com/kenmaro/items/d968375793fe754575fe
概要
2016年のCrytptoNetsの論文以降、
格子暗号を用いてAIモデルの推論(もしくは学習)を行うような実装論文や、
暗号状態で実行できるモデルを汎用化させるようなコンパイラの研究や実装が進んできました。
その中でもTFHEと呼ばれる任意の演算(非線形演算でもよい)を暗号状態で実行できる、
「完全準同型暗号」を用いてAIモデルを実行できるライブラリはとても限られています。
そのうちの一つである、ZamaAIのライブラリ「concrete-ml」を解説したいと考えています。
できること
concrete-ml の特徴は以下のようにまとめることができます。
- 暗号状態での「推論」ができる(学習は暗号では行えない、平文で行う)
- 線形モデル、木構造のモデル、ニューラルネット(畳み込みニューラルネットなども含む)に対応している
- 最大で16ビットの精度での推論が可能(学習されたモデルを16ビット量子化する作業が必要)
- 前処理には非対応(たとえばテキストをベクトルにしたり、特徴量の抽出、クラスタリング、標準化など)
また、これらの今できないことについては随時実装中であるとのこと。
チュートリアル
公式のチュートリアル
では、以下のようなことを試すことができます。
- 自作の簡易データセットに対する、ロジスティック回帰の推論(8ビット精度)
- MNISTデータセットに対する、ニューラルネットモデルの推論(隠れ層無しのニューラルネット: 784 --> 10 へのマッピング+シグモイド関数)
- Titanicデータセットに対する、LGBMモデルの推論
実行環境
私のノートパソコンで動かしてみました。
OS Mac (Intel)
Core 8Core 16 threads
Memory 64GB
また、pip install でMacのpyenvにインストールしたところ、
チュートリアルのcompile のところでエラーとなったため、
彼らのDocker image をpull してDocker 環境でチュートリアルを実行しました。
ロジスティック回帰の推論
プログラムからのログは以下のようになります。
In clear : [0 0 0 0 1 0 1 0 1 1 0 0 1 0 0 1 1 1 0 0]
In FHE : [0 0 0 0 1 0 1 0 1 1 0 0 1 0 0 1 1 1 0 0]
Similarity: 100%
time: 0.08336877822875977
この時の実行時間は、20件全てに対して暗号状態で推論を実行した時の時間になります。
MNIST の推論
MNISTに対するニューラルネットモデルは、畳み込みを使用したものではなく、全結合層を1層だけ使ったモデルです。
入力層: 784
出力層: 10
活性化関数: シグモイド
11ビット精度でこの推論を行う、というようなログが吐き出されましたが、
実際私はFHEでの推論計算は1件のみのデータで実行しましたが、私の8コア16スレッドのマシンではかなり時間がかかりました。
LGBMの推論
設定パラメータ
- 木の深さ (1~5)
- 木の数(1~5)
上記のパラメータで(平文で)グリッドサーチを行い、
一番良かったものに対して暗号状態での推論を実行しています。
木の数がmax5というのはかなり小さいモデルで行ったデモということになります。
プログラムからのログ
Best hyper-parameters found in 4.85s : {'learning_rate': 1, 'max_depth': 4, 'n_estimators': 4}
Best hyper-parameters found in 45.20s : {'learning_rate': 0.1, 'max_depth': 4, 'n_bits': 2, 'n_estimators': 4}
100%|█████████████████████████████████████████████████████████████████████████████| 418/418 [02:42<00:00, 2.57it/s]
Key generation time: 1.62s
Total execution time for 418 inferences: 162.71s
Execution time per inference in FHE: 0.39s
Prediction similarity between both Concrete-ML models (quantized clear and FHE): 100.00%
実際のグリッドサーチにて、
- 木の深さ = 4
- 木の数 = 4
が選出され、それに対する暗号での推論実行時間は
0.39秒(1データに対する推論時間)
であることがわかります。
Cifar10のVGGモデルでの推論
このチュートリアルについては、私の環境だとかなり実行が厳しいレベルで重い計算が必要だったため、
ZamaAIからのドキュメントに書かれている結果をそのまま書きたいと思います。
平文での実行と暗号文での実行
このチュートリアルでは、TFHEで量子化した(ビット精度に制約のある)演算を全てのレイヤに適用すると、
精度に支障が出てくるということが説明されています。
特に、入力層の最初の1層は生の画像がそのまま入ってくるレイヤですが、ここではもっと精度が必要となるため、
このチュートリアルでは最初の一層目だけは平文で計算し、残りのレイヤを暗号状態で計算するようなストーリーになっています。
平文で入力層を実行することはユースケースとしては厳しいのではないか、と思われる方も多いかと思いますが、
実際の量子化したモデルでしか実行できない制約下のなかで、この方針は現在の技術の限界であることがわかります。
暗号パラメータのチューニングに必要性
concrete-mlで使用する、暗号状態でのLUT参照は、基本的に全ての非線形関数で実行される必要があります。
このとき、LUTにどのくらいの精度を定義するか(LUTの結果は確率で値がずれることがあるため、その確率を小さくしたければより計算を多くしなければならない)によって実行時間がかなり変わってきます。
concrete-ml では、このLUTに関する暗号パラメータの設定を、
- p_error
- global_p_error
というパラメータ(のどちらかを設定することで)で設定することができます。
- p_error は1つのLUTの値にブレが発生することをどの程度の確率で許容するか
- global_p_errorは、モデル全体を計算した後の結果に、LUT起因の値のブレをどの程度の確率で許容するか
という意味を持っており、global_p_error を設定すると、モデルの構成によりp_errorが内部で自動的に計算されることになります。
このp_error を設定することで実行時間には大きな影響があるため、大きなモデルになればなるほどこの値をチューニングする必要が出てきます。詳しくは彼らのドキュメントを参照してください。
実行速度
チューニングされた p_error の元で
Time to compile: 109 seconds
Time to keygen: 30 seconds
Time to infer: 1738 seconds
という結果が出ています。しかし、実行インスタンスは
こちらの128コアあるかなりリッチなインスタンスであること、
また、先述の通り最初の1層目は平文で実行されていること(ここをTEEで行うとかの構成は考えられそうですが)
に注意です。
いずれにしても、かなり重いことに違いはないですね。
Brevitas からの学習されたモデルでCifar10を推論
Xilinxが提供しているBrevitas というライブラリは、
量子化されたモデルを明示的に学習することのできるライブラリです。
ここをみると、
[MODEL]
ARCH: CNV
PRETRAINED_URL: https://github.com/Xilinx/brevitas/releases/download/bnn_pynq-r0/cnv_2w2a-0702987f.pth
EVAL_LOG: https://github.com/Xilinx/brevitas/releases/download/cnv_test_ref-r0/cnv_2w2a_eval-5aaca4c6.txt
DATASET: CIFAR10
IN_CHANNELS: 3
NUM_CLASSES: 10
[QUANT]
WEIGHT_BIT_WIDTH: 2
ACT_BIT_WIDTH: 2
IN_BIT_WIDTH: 8
のようなBrevitas への設定ファイルを確認することができます。
ここの、[QUANT]で、
- 重みパラメータは2ビット量子化
- 活性化関数を通るときは2ビット量子化
- インプットとなる画像は8ビット量子化
という制約の元でモデルの学習を行うことがわかります。
ここで学習されたモデルにたいして、ほとんどそのまま実行できるのですが、
concrete-ml では以下のような変更を加えています。
- MaxPooling 層は、AvgPooling へと変更(これは比較演算であるMaxPoolingはLUT参照が大量に発生するため、加算でできるAvgを使ってFHEフレンドリーにしている。)
- 画像の正規化はモデルの外で実行する(平文で実行し、正規化された画像を暗号化して推論する構成。)
また、
evaluate_torch_cml.py
の実行を通して、サーキットはFHEで行うものを使い、
実際の計算は平文で実行してFHEモデルでの精度などを検証しています。
これは、この検証を暗号状態で行おうとするとコストがかかりすぎるために、
出てくるエラーなどの値を平文状態でも加味できるように検証用の関数が整備されているということです。
最後に、実際に暗号状態で1枚の画像を推論すると、
先程のような128コア程度の高価なマシンで約10時間程実行に時間がかかる、というような検証結果となっています。
Brevitas で転移学習されたモデルでCifar10, Cifar100を推論
このシナリオについても、ほとんど前項のモデルと同じです。
このチュートリアルでは、実際の暗号文での推論は行われていません。
あくまで実行されているのは、FHEのためにコンパイルしたモデルに対して、平文状態でシミュレーションとして推論を行っているまでです。
理由としては、FHEで実際に計算すると非常に長時間かかるためです。
ただし、ここで強調されているのは、Cifar100のようなデータセットにおいても、
- 量子化されたモデルが画像分類タスクで有効であるということ
- FHE用にコンパイルされたモデルでも精度が落ちないこと
- FHEでの実際の推論はとても時間がかかる(おそらく日単位)が、実行可能であること
- 専用ハードウェアなどが出来れば高速化が図れること
です。
まとめ
今回は、ZamaAIが公開している concrete-ml を実際にチュートリアルとして動かしてみました。
だいたいやれることはわかりましたし、やはりソフトウェアの高速化のみだとFHEの実行時間は厳しいなと思いました。
LigntGBMについてもかなり木の数が小さかったので、このくらいの数であればまあ実行できるくらいの時間なんだな、と感じました。
チュートリアルの後半で使用された、
XilinxのBrevitasで量子化したモデルを学習する、というところはTFHEととても相性がいいなと思いました。
さらに、LUTのp_errorなどをチューニングすることで実行時間を短くしつつ、
モデルの精度を落とさないギリギリのところにチューニングしていくところが(ある程度)自動化されているところは、
ZamaAIがとても親切に実装してくれていてすごいなあと思いました。
- tfhe-rs
- concrete-np
- concrete-ml
というようにコアとなるライブラリを全てOSSとして公開しているZamaAIですが、
他のライブラリについてもチュートリアルを行い記事にしてみようと思っておりますので興味のある方はそちらもご覧ください。
今回はこの辺で。