LoginSignup
1
0

【準同型暗号】OpenFHEの1.1系で追加された格子暗号の最先端を解説!(閾値型暗号)

Posted at

概要

今回は、準同型暗号(暗号化したままデータを計算できる暗号形式)
の中でも注目されている「格子暗号」のオープンソースライブラリ

OpenFHE

の最新バージョン v1.1 系でできる様になった機能について解説します。
また、実際に動かすことのできるサンプルコードも紹介します。

忙しい人のためのまとめ

新機能

OpenFHEのv1.1を使用すると、次の様な機能を使えるよ

①閾値型のCKKS形式の暗号が使えるよ
②CKKS形式からFHEW(TFHE)形式に暗号状態のまま変換、その逆変換もできるよ

この記事ではそのうち、
閾値型の暗号
について解説するよ

何が美味しいの?

閾値型

閾値型の格子暗号形式が使えると

  • 秘密鍵を分割して、部分復号を行えるよ
  • 部分復号結果を集約すると平文に戻せるよ
  • こうすることで秘密鍵を1パーティが保持しておく必要はないのでセキュリティ的に嬉しいよ

解説

それでは、実際にOpenFHE1.1系で追加された機能について解説していきます。

1.1系で追加されたこの二つの機能

  • 閾値型暗号
  • 暗号形式の動的変換

は、理論研究でも非常に最先端の内容であり、それらをOSSの形で実装しているOpenFHEは
まさに最先端の格子暗号ライブラリになっています。

そのほかにも、

などを中心に(ほかにもたくさんライブラリはありますが)、たくさんのOSSが存在しています。

閾値型の格子暗号についてはLattigo にも実装がされていますが、
CKKS形式に対して実装しているのはOpenFHEがおそらく初めてではないでしょうか。

暗号形式の変換については、OpenFHEが初めての実装であると思います。

閾値型格子暗号とは

閾値型格子暗号とは何か?と疑問に思われる方もいるかもしれませんが、
一言で言うと「秘密鍵」を保持するパーティを分割できる形式のことです。

秘密鍵の取り扱い

秘密鍵を使うシーンは、暗号を平文に戻す「復号」を実行する時です。

公開鍵暗号を使用する上で、秘密鍵は基本的にユーザが保持し、
自身に届いた暗号を復号して中身を見る時に使用します。

秘密鍵の保存と使用上の注意

秘密鍵は絶対に盗まれてはいけない鍵ですから、取り扱いには注意が必要です。
誰かに盗まれてはいけないため、自分で管理するか、もしくは「秘密鍵自体を別の鍵で暗号化」
して、暗号化した秘密鍵をどこかに保存したりすることがシステム構築では考えられます。

「秘密鍵を保管」する面ではこれで大丈夫ですが、実際に復号する時は秘密鍵を使用する必要がありますから、
復号は基本的にユーザによって実行される必要がありました。

復号を委託することはとてもリスキー

もしユーザがこの復号処理を一部でも他の人に任せたい時、秘密鍵を他の人に渡すことになるため、
この「復号を誰かに任せる」行動はとてもリスキーでした。

閾値型では秘密鍵は分散できる

閾値型の格子暗号を使うと、秘密鍵を分割することが可能になります。
また、復号を行う時も分割した秘密鍵を使って復号することができます。
秘密鍵をN個に分割したとすると、サーバはN個に分割された秘密鍵を取得するだけで、
秘密鍵そのものはどのサーバも保持しないため、リスクを分散できます。

ユーザの視点で言えば、自分の鍵を必ずしも自分で管理する必要がなくなるため、使い勝手も向上します。

例(5つに分散する場合)

例えば、5つに秘密鍵を分割する場合

  • ユーザは秘密鍵を生成後、5つに分割
  • ユーザは分割された秘密鍵を5つの別々の場所(復号サーバ1〜5)に保存
  • ユーザは公開鍵でデータを暗号化
  • ユーザはデータを計算サーバに送信
  • 計算サーバは暗号化されたデータに対して何か計算し結果の暗号文を作る
  • ユーザはデータの復号を依頼する
  • 計算サーバは復号サーバ1〜5に結果の暗号文を送信
  • 復号サーバ1〜5は部分復号を実行し、部分復号結果1〜5をそれぞれ取得
  • 復号サーバのうちどこかのサーバは、部分復号結果1〜5を集約することで結果の平文を見ることができる

というシナリオになります。

この時に、復号サーバ1〜5は、単体では平文を見ることはできず、
5つの部分復号結果が合わさって初めて平文を見ることができるため、
リスクを分散しています。
これはMPC(秘密分散)の考え方そのものであり、それを準同型暗号の方式にも応用したものになっています。

実装

今回は、3パーティでの閾値型の構成で、

暗号を2つ用意し、
内積計算、
近似式でのReLU関数の評価
閾値型の復号
値の確認

を一連の流れでやってみます。
参考にしたのは

こちらの公式のexample です。

example.cpp
void TCKKSInnerReLUThenBoot(enum ScalingTechnique scaleTech) {
    if (scaleTech != ScalingTechnique::FIXEDMANUAL && scaleTech != ScalingTechnique::FIXEDAUTO &&
        scaleTech != ScalingTechnique::FLEXIBLEAUTO && scaleTech != ScalingTechnique::FLEXIBLEAUTOEXT) {
        std::string errMsg = "ERROR: Scaling technique is not supported!";
        OPENFHE_THROW(config_error, errMsg);
    }

    CCParams<CryptoContextCKKSRNS> parameters;
    // A. Specify main parameters
    /*  A1) Secret key distribution
	* The secret key distribution for CKKS should either be SPARSE_TERNARY or UNIFORM_TERNARY.
	* The SPARSE_TERNARY distribution was used in the original CKKS paper,
	* but in this example, we use UNIFORM_TERNARY because this is included in the homomorphic
	* encryption standard.
	*/
    SecretKeyDist secretKeyDist = UNIFORM_TERNARY;
    parameters.SetSecretKeyDist(secretKeyDist);

    /*  A2) Desired security level based on FHE standards.
	* In this example, we use the "NotSet" option, so the example can run more quickly with
	* a smaller ring dimension. Note that this should be used only in
	* non-production environments, or by experts who understand the security
	* implications of their choices. In production-like environments, we recommend using
	* HEStd_128_classic, HEStd_192_classic, or HEStd_256_classic for 128-bit, 192-bit,
	* or 256-bit security, respectively. If you choose one of these as your security level,
	* you do not need to set the ring dimension.
	*/
    parameters.SetSecurityLevel(HEStd_128_classic);

    /*  A3) Scaling parameters.
	* By default, we set the modulus sizes and rescaling technique to the following values
	* to obtain a good precision and performance tradeoff. We recommend keeping the parameters
	* below unless you are an FHE expert.
	*/
    usint dcrtBits = 50;
    usint firstMod = 60;

    parameters.SetScalingModSize(dcrtBits);
    parameters.SetScalingTechnique(scaleTech);
    parameters.SetFirstModSize(firstMod);

    /*  A4) Multiplicative depth.
    * The multiplicative depth detemins the computational capability of the instantiated scheme. It should be set
    * according the following formula:
    * multDepth >= desired_depth + interactive_bootstrapping_depth
    * where,
    *   The desired_depth is the depth of the computation, as chosen by the user.
    *   The interactive_bootstrapping_depth is either 3 or 4, depending on the ciphertext compression mode: COMPACT vs SLACK (see below)
    * Example 1, if you want to perform a computation of depth 24, you can set multDepth to 10, use 6 levels
    * for computation and 4 for interactive bootstrapping. You will need to bootstrap 3 times.
    */
    parameters.SetMultiplicativeDepth(10);
    parameters.SetKeySwitchTechnique(KeySwitchTechnique::HYBRID);

    uint32_t batchSize = 16;
    parameters.SetBatchSize(batchSize);

    /*  Protocol-specific parameters (SLACK or COMPACT)
    * SLACK (default) uses larger masks, which makes it more secure theoretically. However, it is also slightly less efficient.
    * COMPACT uses smaller masks, which makes it more efficient. However, it is relatively less secure theoretically.
    * Both options can be used for practical security.
    * The following table summarizes the differences between SLACK and COMPACT:
    * Parameter	        SLACK	                                        COMPACT
    * Mask size	        Larger	                                        Smaller
    * Security	        More secure	                                    Less secure
    * Efficiency	    Less efficient	                                More efficient
    * Recommended use	For applications where security is paramount	For applications where efficiency is paramount
    */
    auto compressionLevel = COMPRESSION_LEVEL::COMPACT;
    parameters.SetInteractiveBootCompressionLevel(compressionLevel);

    CryptoContext<DCRTPoly> cryptoContext = GenCryptoContext(parameters);

    cryptoContext->Enable(PKE);
    cryptoContext->Enable(KEYSWITCH);
    cryptoContext->Enable(LEVELEDSHE);
    cryptoContext->Enable(ADVANCEDSHE);
    cryptoContext->Enable(MULTIPARTY);

    usint ringDim = cryptoContext->GetRingDimension();
    // This is the maximum number of slots that can be used for full packing.
    usint maxNumSlots = ringDim / 2;
    std::cout << "TCKKS scheme is using ring dimension " << ringDim << std::endl;
    std::cout << "TCKKS scheme number of slots         " << batchSize << std::endl;
    std::cout << "TCKKS scheme max number of slots     " << maxNumSlots << std::endl;
    std::cout << "TCKKS example with Scaling Technique " << scaleTech << std::endl;

    const usint numParties = 3;

    std::cout << "\n===========================IntMPBoot protocol parameters===========================\n";
    std::cout << "num of parties: " << numParties << "\n";
    std::cout << "===============================================================\n";

    // double eps = 0.0001;

    // Initialize Public Key Containers
    KeyPair<DCRTPoly> kp1;  // Party 1
    KeyPair<DCRTPoly> kp2;  // Party 2
    KeyPair<DCRTPoly> kp3;  // Lead party - who finalizes interactive bootstrapping

    KeyPair<DCRTPoly> kpMultiparty;

    ////////////////////////////////////////////////////////////
    // Perform Key Generation Operation
    ////////////////////////////////////////////////////////////

    // Round 1 (party A)
    kp1 = cryptoContext->KeyGen();

    // Generate evalmult key part for A
    auto evalMultKey = cryptoContext->KeySwitchGen(kp1.secretKey, kp1.secretKey);

    std::vector<int32_t> indices = {1};
    // Generate evalsum key part for A
    cryptoContext->EvalSumKeyGen(kp1.secretKey);
    auto evalSumKeys = std::make_shared<std::map<usint, EvalKey<DCRTPoly>>>(
        cryptoContext->GetEvalSumKeyMap(kp1.secretKey->GetKeyTag()));

    cryptoContext->EvalAtIndexKeyGen(kp1.secretKey, indices);
    auto evalAtIndexKeys = std::make_shared<std::map<usint, EvalKey<DCRTPoly>>>(
        cryptoContext->GetEvalAutomorphismKeyMap(kp1.secretKey->GetKeyTag()));

    // Round 2 (party B)
    kp2                  = cryptoContext->MultipartyKeyGen(kp1.publicKey);
    auto evalMultKey2    = cryptoContext->MultiKeySwitchGen(kp2.secretKey, kp2.secretKey, evalMultKey);
    auto evalMultAB      = cryptoContext->MultiAddEvalKeys(evalMultKey, evalMultKey2, kp2.publicKey->GetKeyTag());
    auto evalMultBAB     = cryptoContext->MultiMultEvalKey(kp2.secretKey, evalMultAB, kp2.publicKey->GetKeyTag());
    auto evalSumKeysB    = cryptoContext->MultiEvalSumKeyGen(kp2.secretKey, evalSumKeys, kp2.publicKey->GetKeyTag());
    auto evalSumKeysJoin = cryptoContext->MultiAddEvalSumKeys(evalSumKeys, evalSumKeysB, kp2.publicKey->GetKeyTag());
    cryptoContext->InsertEvalSumKey(evalSumKeysJoin);
    auto evalMultAAB   = cryptoContext->MultiMultEvalKey(kp1.secretKey, evalMultAB, kp2.publicKey->GetKeyTag());
    auto evalMultFinal = cryptoContext->MultiAddEvalMultKeys(evalMultAAB, evalMultBAB, evalMultAB->GetKeyTag());
    cryptoContext->InsertEvalMultKey({evalMultFinal});

    auto evalAtIndexKeysB =
        cryptoContext->MultiEvalAtIndexKeyGen(kp2.secretKey, evalAtIndexKeys, indices, kp2.publicKey->GetKeyTag());
    auto evalAtIndexKeysJoin =
        cryptoContext->MultiAddEvalAutomorphismKeys(evalAtIndexKeys, evalAtIndexKeysB, kp2.publicKey->GetKeyTag());
    cryptoContext->InsertEvalAutomorphismKey(evalAtIndexKeysJoin);

    /////////////////////
    // Round 3 (party C) - Lead Party (who encrypts and finalizes the bootstrapping protocol)
    kp3                 = cryptoContext->MultipartyKeyGen(kp2.publicKey);
    auto evalMultKey3   = cryptoContext->MultiKeySwitchGen(kp3.secretKey, kp3.secretKey, evalMultKey);
    auto evalMultABC    = cryptoContext->MultiAddEvalKeys(evalMultAB, evalMultKey3, kp3.publicKey->GetKeyTag());
    auto evalMultBABC   = cryptoContext->MultiMultEvalKey(kp2.secretKey, evalMultABC, kp3.publicKey->GetKeyTag());
    auto evalMultAABC   = cryptoContext->MultiMultEvalKey(kp1.secretKey, evalMultABC, kp3.publicKey->GetKeyTag());
    auto evalMultCABC   = cryptoContext->MultiMultEvalKey(kp3.secretKey, evalMultABC, kp3.publicKey->GetKeyTag());
    auto evalMultABABC  = cryptoContext->MultiAddEvalMultKeys(evalMultBABC, evalMultAABC, evalMultBABC->GetKeyTag());
    auto evalMultFinal2 = cryptoContext->MultiAddEvalMultKeys(evalMultABABC, evalMultCABC, evalMultCABC->GetKeyTag());
    cryptoContext->InsertEvalMultKey({evalMultFinal2});

    auto evalAtIndexKeysC =
        cryptoContext->MultiEvalAtIndexKeyGen(kp3.secretKey, evalAtIndexKeys, indices, kp3.publicKey->GetKeyTag());
    auto evalAtIndexKeysJoin2 =
        cryptoContext->MultiAddEvalAutomorphismKeys(evalAtIndexKeys, evalAtIndexKeysC, kp3.publicKey->GetKeyTag());
    cryptoContext->InsertEvalAutomorphismKey(evalAtIndexKeysJoin2);

    auto evalSumKeysC     = cryptoContext->MultiEvalSumKeyGen(kp3.secretKey, evalSumKeys, kp3.publicKey->GetKeyTag());
    auto evalSumKeysJoin2 = cryptoContext->MultiAddEvalSumKeys(evalSumKeys, evalSumKeysC, kp3.publicKey->GetKeyTag());
    cryptoContext->InsertEvalSumKey(evalSumKeysJoin2);

    if (!kp1.good()) {
        std::cout << "Key generation failed!" << std::endl;
        exit(1);
    }
    if (!kp2.good()) {
        std::cout << "Key generation failed!" << std::endl;
        exit(1);
    }
    if (!kp3.good()) {
        std::cout << "Key generation failed!" << std::endl;
        exit(1);
    }

    // END of Key Generation

    //std::vector<std::complex<double>> input({-4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0});
    std::vector<double> m1({1.0, 2.0, 3.0});
    std::vector<double> m2({1, 1, 0});
    std::vector<double> unit_one{1, 0, 0};
    size_t encodedLength = m1.size();

    Plaintext p1         = cryptoContext->MakeCKKSPackedPlaintext(m1);
    Plaintext p2         = cryptoContext->MakeCKKSPackedPlaintext(m2);
    Plaintext p_unit_one = cryptoContext->MakeCKKSPackedPlaintext(unit_one);

    // Chebyshev coefficients
    // std::vector<double> coefficients({1.0, 0.558971, 0.0, -0.0943712, 0.0, 0.0215023, 0.0, -0.00505348, 0.0, 0.00119324,
    //                                   0.0, -0.000281928, 0.0, 0.0000664347, 0.0, -0.0000148709});
    // Input range
    double a = -4;
    double b = 4;

    //Plaintext pt1       = cryptoContext->MakeCKKSPackedPlaintext(input);
    //usint encodedLength = input.size();

    auto ct1 = cryptoContext->Encrypt(kp3.publicKey, p1);

    // ct1 = cryptoContext->EvalChebyshevSeries(ct1, coefficients, a, b);
    auto ct2 = cryptoContext->Encrypt(kp3.publicKey, p2);
    // auto ct_unit_one = cryptoContext->Encrypt(kp3.publicKey, p_unit_one);

    // ct1 = cryptoContext->EvalInnerProduct(ct1, ct2, encodedLength);
    //auto c_chev  = cryptoContext->EvalMultAndRelinearize(c_inner, ct_unit_one);
    ct1 = cryptoContext->EvalMult(ct1, ct2);
    // auto ct1_rot = cryptoContext->EvalAtIndex(ct1, indices[0]);
    ct1 = cryptoContext->EvalAdd(ct1, ct1);

    uint32_t polyDegree = 20;
    ct1 = cryptoContext->EvalChebyshevFunction([](double x) -> double { return relu(x); }, ct1, a, b, polyDegree);

    auto cp1 = cryptoContext->MultipartyDecryptMain({ct1}, kp1.secretKey);
    auto cp2 = cryptoContext->MultipartyDecryptMain({ct1}, kp2.secretKey);
    printf("debug yo\n");
    auto cp3 = cryptoContext->MultipartyDecryptLead({ct1}, kp3.secretKey);
    printf("debug yo2\n");
    vector<Ciphertext<DCRTPoly>> pcv;
    pcv.push_back(cp1[0]);
    pcv.push_back(cp2[0]);
    pcv.push_back(cp3[0]);

    Plaintext pmp;
    cryptoContext->MultipartyDecryptFusion(pcv, &pmp);
    printf("debug9\n");
    pmp->SetLength(encodedLength);

    std::cout << "Computed Res: \n\t" << pmp->GetCKKSPackedValue() << std::endl;
    // printf("done chev\n");

    // INTERACTIVE BOOTSTRAPPING STARTS

    ct1 = cryptoContext->IntMPBootAdjustScale(ct1);
    printf("debug1\n");

    // Leading party (party B) generates a Common Random Poly (crp) at max coefficient modulus (QNumPrime).
    // a is sampled at random uniformly from R_{Q}
    auto crp = cryptoContext->IntMPBootRandomElementGen(kp3.publicKey);
    printf("debug2\n");
    // Each party generates its own shares: maskedDecryptionShare and reEncryptionShare
    // (h_{0,i}, h_{1,i}) = (masked decryption share, re-encryption share)
    // we use a vector inseat of std::pair for Python API compatibility
    vector<Ciphertext<DCRTPoly>> sharesPair0;  // for Party A
    vector<Ciphertext<DCRTPoly>> sharesPair1;  // for Party B
    vector<Ciphertext<DCRTPoly>> sharesPair2;  // for Party C

    // extract c1 - element-wise
    auto c1 = ct1->Clone();
    c1->GetElements().erase(c1->GetElements().begin());
    printf("debug3\n");
    // masked decryption on the client: c1 = a*s1
    sharesPair0 = cryptoContext->IntMPBootDecrypt(kp1.secretKey, c1, crp);
    sharesPair1 = cryptoContext->IntMPBootDecrypt(kp2.secretKey, c1, crp);
    sharesPair2 = cryptoContext->IntMPBootDecrypt(kp3.secretKey, c1, crp);
    printf("debug4\n");

    vector<vector<Ciphertext<DCRTPoly>>> sharesPairVec;
    sharesPairVec.push_back(sharesPair0);
    sharesPairVec.push_back(sharesPair1);
    sharesPairVec.push_back(sharesPair2);

    // Party B finalizes the protocol by aggregating the shares and reEncrypting the results
    auto aggregatedSharesPair = cryptoContext->IntMPBootAdd(sharesPairVec);
    printf("debug5\n");
    auto ciphertextOutput = cryptoContext->IntMPBootEncrypt(kp3.publicKey, aggregatedSharesPair, crp, ct1);
    printf("debug6\n");

    // INTERACTIVE BOOTSTRAPPING ENDS

    // distributed decryption

    auto ciphertextPartial1 = cryptoContext->MultipartyDecryptMain({ciphertextOutput}, kp1.secretKey);
    auto ciphertextPartial2 = cryptoContext->MultipartyDecryptMain({ciphertextOutput}, kp2.secretKey);
    printf("debug7\n");
    auto ciphertextPartial3 = cryptoContext->MultipartyDecryptLead({ciphertextOutput}, kp3.secretKey);
    printf("debug8\n");
    vector<Ciphertext<DCRTPoly>> partialCiphertextVec;
    partialCiphertextVec.push_back(ciphertextPartial1[0]);
    partialCiphertextVec.push_back(ciphertextPartial2[0]);
    partialCiphertextVec.push_back(ciphertextPartial3[0]);

    Plaintext plaintextMultiparty;
    cryptoContext->MultipartyDecryptFusion(partialCiphertextVec, &plaintextMultiparty);
    printf("debug9\n");
    plaintextMultiparty->SetLength(encodedLength);

    std::cout << "Computed Res: \n\t" << plaintextMultiparty->GetCKKSPackedValue() << std::endl;

    // checkApproximateEquality(plaintextResult->GetCKKSPackedValue(), plaintextMultiparty->GetCKKSPackedValue(),
    //                          encodedLength, eps);

    std::cout << "\n============================ INTERACTIVE DECRYPTION ENDED ============================\n";

    std::cout << "\nTCKKSCollectiveBoot FHE example with rescaling technique: " << scaleTech << " Completed!"
              << std::endl;
}

結論から言うと、
マルチパーティ構成の時の回転操作を行う鍵の生成がよく分からず、
内積を取ることができませんでした。


terminate called after throwing an instance of 'lbcrypto::math_error'
  what():  /openfhe-development-1.1.1/src/pke/lib/encoding/ckkspackedencoding.cpp:535 The decryption failed because the approximation error is too high. Check the parameters.

のように復号に失敗してしまう結果になりました。

しかし、回転作業を除いた形では、

Computed Res:
	[ (2.01378,0) (3.9929,0) (-2.80926e-09,0) ]

の様に、
掛け算 --> 足し算 --> ReLU --> 閾値型復号

の動作を確認することができました。

解決できたらまたアップデートしたいと思います。

今回はこの辺で。

@kenmaro

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0