LoginSignup
3
7

More than 1 year has passed since last update.

準同型暗号を用いたロジスティック回帰の学習について知っている知識を全て書きます。

Posted at

@kenmaroです。
普段は主に秘密計算、準同型暗号などの記事について投稿しています
秘密計算に関連するまとめの記事に関しては以下をご覧ください。

概要

準同型暗号を用いたアプリケーションとして、機械学習はとても魅力的な応用先です。
その中でも、よく聞かれるのが、準同型暗号でデータを暗号化した状態で、モデルを学習できますか??
ということです。

今回はその、暗号化状態でモデルの学習ってできるのか?
そのなかでもロジスティック回帰の学習についてどのようなプロセスで暗号状態で行うことができるのか、
ということについて、今まで取り組んできたことなどを含めてまとめていこうと思います。

暗号状態でモデルの学習できるんですか???

結論はイエスですが、実用性の向上が不可欠だということです。
イエスという意味は、暗号状態で任意の計算を可能にする「完全準同型暗号」を用いることで、
時間はかかりますがどんな演算でも秘匿的に計算を行うことができるからです。
イエス「ですが、」という逆説がついている理由は、計算リソースや演算実行時間の制約があるためです。
可能ですが、いろいろ工夫をしないと普通に使えるレベルではない、ということです。
おそらくテック系の仕事をしている人であればこの文脈を理解していただけると思うのでこの辺にしときます。

暗号状態でもモデル学習できますか?という質問をよくいただきます、という話をしていました。
自分も自由自在に暗号化データに対して学習を行い、モデルも入力データも暗号化した状態で最適化、
暗号化モデルをクラウド上のMLaaSのような感じでデプロイ、
モデル使用ユーザはデータを暗号化したままモデルへと投入し、暗号状態のまま結果をもらう。

と一連のフローが全て暗号状態でできる世界、魅力的ですよね。

しかしながら、現実的にモデルの学習フローについて暗号状態で行うような研究はなかなか(数えるほどしか)行われていません。
それは、機械学習を行なった人ならわかると思いますが、学習プロセスに必要な計算フローというのは、学習済みモデルを使用して入力データに対して推論をする計算プロセスよりも格段に複雑であり、計算過程が長いからです。

しかしながら、比較的簡単なモデル、たとえばロジスティック回帰等のモデルであれば、
研究分野で暗号状態で学種させるアルゴリズムなどは存在しており、その実装に取り組んだりしてきました。

格子暗号を用いた時の推論、学習フローとは?

この後の項では数式やコードなどを実際に使い、
どのようなアルゴリズムでロジスティック回帰の学習を格子暗号を用いて行うか、
ということを詳解しますが、全体の処理フローについて最初に図示します。
これでだいたいのニュアンスを掴んでもらえればな、と思います。

まずは、ロジスティック回帰の推論フローから。
Screen Shot 2021-12-04 at 11.27.13.png

次に、学習フローです。
Screen Shot 2021-12-04 at 11.27.22.png

ここで、あれ、シグモイド関数はクライアント側で一時復号してやるの?
とか、クライアント側でやる工程多くない?
と思われる方がいると思います。
また、学習の方は、バッチごとに暗号をサーバとやりとりするのか、、通信かなり多くなるな、、と。
その気持ち、わかります。
以下のような処理フローであれば明快ですばらしいですよね。
Screen Shot 2021-12-04 at 11.41.34.png

これはつまり、クライアントは一度暗号化を実行したらサーバに投げ、
サーバが学習を完遂したあとにクライアントに結果(学習済みモデル)を返却、
クラインとは出来上がったモデルを復号して性能をチェックする。

というフローです。

理想的にはこの処理をシステムとして行うことですが、
CKKS等の暗号形式(レベル格子暗号)を用いると、
バッチ単位で少なくとも一時復号、再暗号化を行うことが必要となり、
エポック単位で処理を完遂することが難しくなります。

しかしながら、このようにサーバサイドで処理を完遂できるような
「完全」準同型暗号も存在し、最近提唱された「プログラマブルブートストラップ」
を用いることで、上のような理想的な処理も可能になる、未来も近いうちに来るのではないかと思います。
しかしそのためには処理速度の向上やハードウェアとの連携、PBSによる計算精度劣化問題の解決
などが必要となっています。
プログラマブルブートストラップについての解説や、考察等はいくつか記事に書いていますのでそちらをどうぞご覧ください。

格子暗号で新ブレイクスルー!! プログラマブルブートストラップを解説!!

全体としての処理の流れを把握した上で、
暗号状態でどうモデルの更新を行うのか、詳細に入っていきます。

ロジスティック回帰の学習

ロジスティック回帰の学習について暗号状態で実装しようとする時、
まずはロジスティック回帰のアルゴリズムについて今一度復習する必要があります。
MLを触っている人なら一度は勉強したことがあると思いますが、ロジスティック回帰のバックプロパゲーションでの重みの更新式は、
おどろくほどシンプルになります。

具体的には、ロス関数の全微分の式が、
Screen Shot 2021-12-03 at 17.33.56.png

このようにシンプルになり、したがって、重みの更新式は
Screen Shot 2021-12-03 at 17.34.00.png

このようになります。

このあたりはこのgihyoの記事がとてもわかりやすかったのでぜひ一度計算してみるとよりしっくりきます。

xにかかっている関数phiは、特徴量を作るための前処理のようなものなので、
phi(x)を欠けている作業は、入力データxを、予測値yとラベルtの差分にかけ、
学習率を掛け合わせたものが更新値となるため、準同型暗号を用いても比較的簡単に更新値を求められます。

格子暗号を用いた時の高速化処理

具体的に、この重み更新を行うことは可能なのですが、
一つ一つの重みを暗号化し、一つ一つの重みを使って更新値を計算すると、計算時間が大変なことになってしまいます。
それを回避するために、CKKS形式のパッキングを使用することになります。
パッキング手法による高速化などは以前記事にまとめたのでぜひご覧ください。
パッキングにより、簡単に言えばベクトルを一つの暗号文にすることが可能となり、
暗号状態でのベクトル演算のようなものが可能になります。

詳しくはIDASHとよばれる秘密計算実装コンテストで以前優勝したこの論文にとても丁寧に書かれていますのでぜひご覧ください。

IDASHでは、ゲノムデータの解析を準同型暗号を用いて行うことを想定したコンペティションであり、
この論文はCKKS形式を提唱した研究者たちによるロジスティック回帰学習の実装論文です。

一つ一つのパッキングステップが詳細に書かれています。
また、CKKS形式を用いると線形関数は先ほどのベクトルを用いることで可能である一方で、
登場するシグモイド関数は非線形関数であり、暗号状態で簡単に評価することができないのが準同型暗号の弱点なのですが、
この論文ではテイラー展開による近似を用い、3次や5次の項までの近似関数を用い、その制限の中でも学習をある程度の精度で実装しています。

SEAL による実装

とても明快に書かれた上の論文ですが、実際に実装しようとすると、SEALライブラリを用いることになります。
SEALライブラリは、マイクロソフトがOSSとして開発している格子暗号の低レイヤライブラリです。

以前よく使われる格子暗号のライブラリについてまとめた記事がありますので、ぜひご覧ください。
ここでSEALライブラリについても言及しています。

有名な格子暗号ライブラリの使用感をまとめてみた。

実装の上位レイヤ

低レイヤから組み上げて、上の論文で実装されているアルゴリズムを実装することになるのですが、
なかなか実装上のハードルは高くなる印象です。
低レイヤから全てのコードをここに書き出すと大変なことになりますので、ここではメイン関数のみを簡単に公開して、
どのようにロジスティック回帰の学習が準同型暗号を用いて実装されるのかについてお伝えします。

main.cpp
#include <iostream>
#include <fstream>
#include <string>
#include "capsuleflow.h"
#include <seal/seal.h>

#include <string>
#include <sstream>
#include <vector>

#include <stdio.h>

using namespace std;
using namespace seal;
using namespace capsuleflow;
using namespace std::chrono;
inline double get_time_sec(void){
  return static_cast<double> (duration_cast<nanoseconds>(steady_clock::now().time_since_epoch()).count())/1000000000;
}


template <typename T>
void print_vec(vector<T> x, int size){
  // 省略
}

vector<string> split(string& input, char delimiter)
{
    // 省略
}

vector<vector<double>> read_matrix(string file_path){
 // 省略

}


int main(){
  printf("hello, capsuleflow application!\n");
  double start, stop;

  vector<int> modc = {60, 40, 40, 40, 40, 60};
  int mc = 4;
  //int pm = 1 << 30;
  int poly_mod = 1 << 14;
  double scale = pow(2., 40);

  int attrs = 4;
  int bs = 40;
  int train_data_size = 1000;
  bool approx_sig = false;
  int power = 3;
  int range = 5;
  int ss = 1 << 13;
  double eta = 0.08;
  double lambda = 0.07;
  string reg = "L2";
  int epochs = 30;


  CKKSEncryptionService ser = CKKSEncryptionService(
    mc, scale, modc, poly_mod
  );

  LogisticRegressionTrainConfig config = LogisticRegressionTrainConfig(
    attrs=attrs, bs=bs, approx_sig=approx_sig, power=power,
    range=range, ss=ss, eta=eta, lambda=lambda, train_data_size=train_data_size,
    reg=reg, epochs=epochs
  );

  string data_file_x = "/ltrain/bank_data/x_train.csv";
  string data_file_y = "/ltrain/bank_data/y_train.csv";
  vector<vector<double>> train_x = read_matrix(data_file_x);
  vector<vector<double>> train_y = read_matrix(data_file_y);

  printf("train_x: size: (%d, %d)\n", train_x.size(), train_x[0].size());
  printf("train_y: size: (%d, %d)\n", train_y.size(), train_y[0].size());

  vector<double> w;
  double b;
  CKKSCtxt loss;

  Struct_VD_D vd_d_1 = LogisticRegressionTrainUtils::generate_logistic_regression_initial_weight_and_bias(config);
  w = vd_d_1.key1;
  b = vd_d_1.key2;
  printf("w: size: %d\n", w.size());
  print_vec<double>(w, w.size());
  printf("b: %f\n", b);



  DoubleVec1d all_loss;
  DoubleVec1d all_time;
  double total_start, total_stop;

  total_start = get_time_sec();
  for(int k=0; k<config.epochs; k++){
    printf("\n\n\n=====================================================");
    printf("%d th epoch / %d\n", k, config.epochs);
    DoubleVec3d batched_x, batched_label;
    Struct_VVVD_VVVD vvvd_vvvd_1 = LogisticRegressionTrainUtils::make_mini_batches(train_x, train_y, config);
    batched_x = vvvd_vvvd_1.key1;
    batched_label = vvvd_vvvd_1.key2;
    printf("batch_x: size: (%d, %d, %d)\n", batched_x.size(), batched_x[0].size(), batched_x[0][0].size());
    printf("batch_t: size: (%d, %d, %d)\n", batched_label.size(), batched_label[0].size(), batched_label[0][0].size());
    printf("batch_num: %d\n", config.bn);

    for(int i=0; i<config.bn; i++){
      start = get_time_sec();

      //=================================================
      // Client Side Process 1
      DoubleVec1d ppd_x = LogisticRegressionTrainUtils::preprocess_input(batched_x[i], config);
      DoubleVec1d ppd_t = LogisticRegressionTrainUtils::preprocess_label(batched_label[i], config);
      DoubleVec1d ppd_w = LogisticRegressionTrainUtils::preprocess_weight(w, config);
      DoubleVec1d ppd_b = LogisticRegressionTrainUtils::preprocess_bias(b, config);

      CKKSCtxt enc_x = LogisticRegressionTrainUtils::encrypt_input(ppd_x, config, ser);
      CKKSCtxt enc_t = LogisticRegressionTrainUtils::encrypt_bias(ppd_t, config, ser);
      CKKSCtxt enc_w = LogisticRegressionTrainUtils::encrypt_weight(ppd_w, config, ser);
      CKKSCtxt enc_b = LogisticRegressionTrainUtils::encrypt_bias(ppd_b, config, ser);

      //=================================================
      // Server Side Forward
      CKKSCtxt tmp1 = LogisticRegressionTrainUtils::apply_without_sigmoid(enc_x, enc_w, enc_b, config);

      //=================================================
      // Client Side Process 1
      DoubleVec1d tmp2 = LogisticRegressionTrainUtils::decrypt(tmp1, config, ser);
      DoubleVec1d tmp3 = LogisticRegressionTrainUtils::postprocess(tmp2, config);
      DoubleVec1d tmp4 = LogisticRegressionTrainUtils::apply_exact_sigmoid(tmp3, config);

      CKKSCtxt tmp5 = LogisticRegressionTrainUtils::encrypt_inference_result(tmp4, config, ser);

      //=================================================
      // Server Side Backward
      Struct_CKKS_CKKS_CKKS tmp6 = LogisticRegressionTrainUtils::calculate_update(tmp5, enc_t, enc_x, config);
      CKKSCtxt dw = tmp6.key1;
      CKKSCtxt db = tmp6.key2;
      loss = tmp6.key3;

      //=================================================
      // Client Side
      DoubleVec1d dw_dec = LogisticRegressionTrainUtils::decrypt_dw(dw, ser);
      DoubleVec1d db_dec = LogisticRegressionTrainUtils::decrypt_db(db, ser);


      Struct_VD_D tmp7 = LogisticRegressionTrainUtils::apply_updates(w, dw_dec, b, db_dec[0], config);
      w = tmp7.key1;
      b = tmp7.key2;

      stop = get_time_sec();

      double current_loss = LogisticRegressionTrainUtils::calculate_loss(loss, ser);
      all_loss.emplace_back(current_loss);
      printf("\n%d th batch / %d, time: %f, loss: %f\n", i, config.bn, stop-start, current_loss);
      print_vec<double>(w, w.size());
      all_time.emplace_back(stop-start);

      //print_vec<double>(all_loss, all_loss.size());

    }

    config.add_w_history(w);
    config.add_b_history(b);
    config.add_loss_history(LogisticRegressionTrainUtils::calculate_loss(loss, ser));

  }

  print_vec<double>(config.loss_history, config.loss_history.size());
  print_vec<double>(all_time, all_time.size());

  string loss_file_path = "/ltrain/build/result/res_loss.txt";
  IOUtils::write_vector_to_file(config.loss_history, loss_file_path, config.loss_history.size());
  IOUtils::write_weight_to_file(config.w_history, "/ltrain/build/result/res_weight.txt", config.w_history.size(), config.w_history[0].size());
  IOUtils::write_bias_to_file(config.b_history, "/ltrain/build/result/res_bias.txt", config.b_history.size());

  total_stop = get_time_sec();
  printf("total training time: %f\n", total_stop - total_start);

  return 0;
}


上位レイヤとは言えど、そこそこのコード量にはなっていることをご了承ください。

main.cpp
  vector<int> modc = {60, 40, 40, 40, 40, 60};
  int mc = 4;
  //int pm = 1 << 30;
  int poly_mod = 1 << 14;
  double scale = pow(2., 40);

まず、上記で以下のことを行なっています。

  • CKKS暗号形式の暗号パラメータ設定

ここでは、modc (modulus chain) を上記のように設定することで、レベル型の格子暗号を設定しています。
レベル型というのは、あらかじめ何回掛け算をできるか定めた準同型暗号のことです。
ここでは、レベル4であるため、4回まで掛け算ができますが、それ以上掛け算をしようとするとノイズが爆発し、復号不可能となるためできないような設定になっています。(SEALでは scale out of bounds というようなエラーとなります。)

main.cpp
  int attrs = 4;
  int bs = 40;
  int train_data_size = 1000;
  bool approx_sig = false;
  int power = 3;
  int range = 5;
  int ss = 1 << 13;
  double eta = 0.08;
  double lambda = 0.07;
  string reg = "L2";
  int epochs = 30;

次に、上記の部分で、ロジスティック回帰に必要なMLパラメータを設定しています。

  • attrs: ロジスティック回帰への入力のフィーチャの数
  • bs: バッチサイズ
  • train_data_size: 学習データの数 ( = len(trainx))
  • appx_sig: sigmoidをテイラー近似するかどうかのブーリアン
  • power: sigmoid をテイラー近似する際の次元 (3 であれば3次の項までで近似)
  • range: テイラー近似に使う定義域の絶対値
  • ss: スロットサイズ(平文多項式に入れることのできる平文の数)
  • eta: 学習率
  • lambda: L1, L2 などの正規化を使う時の制限パラメター
  • reg: L1,L2, none などで正規化の有無を指定
  • epochs : エポック数
main.cpp

  CKKSEncryptionService ser = CKKSEncryptionService(
    mc, scale, modc, poly_mod
  );

  LogisticRegressionTrainConfig config = LogisticRegressionTrainConfig(
    attrs=attrs, bs=bs, approx_sig=approx_sig, power=power,
    range=range, ss=ss, eta=eta, lambda=lambda, train_data_size=train_data_size,
    reg=reg, epochs=epochs
  );

上記で設定した暗号、ロジスティック回帰のパラメータを使い、

  • 暗号に使用する鍵(CKKSEncryptionService に格納される)
  • ロジスティック回帰のコンフィグ

をコンストラクトしています。

main.cpp
  string data_file_x = "/ltrain/bank_data/x_train.csv";
  string data_file_y = "/ltrain/bank_data/y_train.csv";
  vector<vector<double>> train_x = read_matrix(data_file_x);
  vector<vector<double>> train_y = read_matrix(data_file_y);

  printf("train_x: size: (%d, %d)\n", train_x.size(), train_x[0].size());
  printf("train_y: size: (%d, %d)\n", train_y.size(), train_y[0].size());

  vector<double> w;
  double b;
  CKKSCtxt loss;

  Struct_VD_D vd_d_1 = LogisticRegressionTrainUtils::generate_logistic_regression_initial_weight_and_bias(config);
  w = vd_d_1.key1;
  b = vd_d_1.key2;
  printf("w: size: %d\n", w.size());
  print_vec<double>(w, w.size());
  printf("b: %f\n", b);

データをcsvから読んだり、スタートする重みの初期化などを行なっています。

main.cpp
  for(int k=0; k<config.epochs; k++){
    printf("\n\n\n=====================================================");
    printf("%d th epoch / %d\n", k, config.epochs);
    DoubleVec3d batched_x, batched_label;
    Struct_VVVD_VVVD vvvd_vvvd_1 = LogisticRegressionTrainUtils::make_mini_batches(train_x, train_y, config);
    batched_x = vvvd_vvvd_1.key1;
    batched_label = vvvd_vvvd_1.key2;
    printf("batch_x: size: (%d, %d, %d)\n", batched_x.size(), batched_x[0].size(), batched_x[0][0].size());
    printf("batch_t: size: (%d, %d, %d)\n", batched_label.size(), batched_label[0].size(), batched_label[0][0].size());
    printf("batch_num: %d\n", config.bn);

エポックごとにバッチを生成しているだけです。

main.cpp
    for(int i=0; i<config.bn; i++){
      start = get_time_sec();

      //=================================================
      // Client Side Process 1
      DoubleVec1d ppd_x = LogisticRegressionTrainUtils::preprocess_input(batched_x[i], config);
      DoubleVec1d ppd_t = LogisticRegressionTrainUtils::preprocess_label(batched_label[i], config);
      DoubleVec1d ppd_w = LogisticRegressionTrainUtils::preprocess_weight(w, config);
      DoubleVec1d ppd_b = LogisticRegressionTrainUtils::preprocess_bias(b, config);

      CKKSCtxt enc_x = LogisticRegressionTrainUtils::encrypt_input(ppd_x, config, ser);
      CKKSCtxt enc_t = LogisticRegressionTrainUtils::encrypt_bias(ppd_t, config, ser);
      CKKSCtxt enc_w = LogisticRegressionTrainUtils::encrypt_weight(ppd_w, config, ser);
      CKKSCtxt enc_b = LogisticRegressionTrainUtils::encrypt_bias(ppd_b, config, ser);

各バッチごとに、クライアントは必要な前処理(論文を参照)を行い、
入力データ、重み、バイアスをそれぞれ暗号化します。
このあと今回のコードでは通信は行っていませんが、
計算サーバに暗号化したものをhttp, grpc 等を用いて送る必要があります。

main.cpp
      //=================================================
      // Server Side Forward
      CKKSCtxt tmp1 = LogisticRegressionTrainUtils::apply_without_sigmoid(enc_x, enc_w, enc_b, config);

計算サーバは暗号化された入力データ、重み、バイアスをもとに、
暗号状態で線形回帰を行います。
サーバサイドでシグモイドの評価をテイラー近似して行う場合は、その処理も行います。
実際、線形回帰は暗号状態で入力データと重みの内積を計算するというだけのものですが、
パッキングされた暗号に対して内積をとるときは特有のアルゴリズムを用いる必要があります。
詳しいことは以前まとめた以下の記事をご覧ください。

準同型暗号を用いた高速化の例

main.cpp
      //=================================================
      // Client Side Process 1
      DoubleVec1d tmp2 = LogisticRegressionTrainUtils::decrypt(tmp1, config, ser);
      DoubleVec1d tmp3 = LogisticRegressionTrainUtils::postprocess(tmp2, config);
      DoubleVec1d tmp4 = LogisticRegressionTrainUtils::apply_exact_sigmoid(tmp3, config);

      CKKSCtxt tmp5 = LogisticRegressionTrainUtils::encrypt_inference_result(tmp4, config, ser);

この場合、なるべく暗号化状態で非線形演算を回避するために、
シグモイド演算はクライアント側にデータを戻し、一時復号をした状態で行っています。
サーバ側でテイラー近似を使って行う時は、この一時復号は必要ではありません。
クライアントは、シグモイド関数を実行した後、計算サーバに逆伝搬の計算をしてもらうためにデータを再度暗号化し、サーバ再度に送り返します。

main.cpp
      //=================================================
      // Server Side Backward
      Struct_CKKS_CKKS_CKKS tmp6 = LogisticRegressionTrainUtils::calculate_update(tmp5, enc_t, enc_x, config);
      CKKSCtxt dw = tmp6.key1;
      CKKSCtxt db = tmp6.key2;
      loss = tmp6.key3;

サーバサイドは論文のアルゴリズムに従い、ロジスティック回帰の更新項を求めるための逆伝搬を計算します。
結果として、dw, db, loss の項を暗号状態で求めています。

main.cpp
      //=================================================
      // Client Side
      DoubleVec1d dw_dec = LogisticRegressionTrainUtils::decrypt_dw(dw, ser);
      DoubleVec1d db_dec = LogisticRegressionTrainUtils::decrypt_db(db, ser);


      Struct_VD_D tmp7 = LogisticRegressionTrainUtils::apply_updates(w, dw_dec, b, db_dec[0], config);
      w = tmp7.key1;
      b = tmp7.key2;

      stop = get_time_sec();

      double current_loss = LogisticRegressionTrainUtils::calculate_loss(loss, ser);
      all_loss.emplace_back(current_loss);
      printf("\n%d th batch / %d, time: %f, loss: %f\n", i, config.bn, stop-start, current_loss);
      print_vec<double>(w, w.size());
      all_time.emplace_back(stop-start);

      //print_vec<double>(all_loss, all_loss.size());

更新式をもらったクライアントは、dw, dbを復号し、自身の持っている重みに対して更新を実際に行います。
また、lossを復号し、ロスが下がっていることを確認したり、十分であると判断すれば学習をストップします。

このようにしてロスが一定に下がるまで、このプロセスを繰り返すことにより、
クライアントはサーバ側に入力データ、重み等を見せることなく、
ロジスティック回帰のモデル学習を行うことができます。

最後に

準同型暗号(格子暗号)を用いた機械学習モデルの学習、という分野の先行研究はあまりまだ行われていませんが、
どのようなフローで「現在の技術で」そこを実装できるか、というところに対して知っていることを全部書いてみました。
格子暗号の応用について興味のある学生や、研究者の方(どちらかというと実装屋の方)には参考になるのではないか、と思っています。
プログラマブルブートストラップの精度、速度の改善によっては、
CKKS等を用いて一時復号を含んで学習全体を執り行う、というような現在の「苦肉の策」的なアプローチを取る必要がなくなる可能性もあります。

これからも格子暗号の応用先には目を配っていきましょう。

今回はこの辺で。

@kenmaro

3
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
7