@kenmaroです。
普段は主に秘密計算、準同型暗号などの記事について投稿しています。
秘密計算に関連するまとめの記事に関しては以下をご覧ください。
注意
この実装は動作確認のための簡易パラメータを使用しています。
このプログラムは、セキュリティパラメータ
parameters.SetSecurityLevel(HEStd_NotSet);
を使用しており、本番システムでは必要なパラメータが変更され、実行時間や精度が大きく変更される可能性が高いです。
あくまでも動作確認のためのパラメータとお考えください。
parameters.SetSecurityLevel(HEStd_128_classic);
セキュリティパラメータ128ビットを使用した際は、多項式の長さを16ビットに設定する必要があり、実行時間が大幅に(少なくとも16倍以上に)上昇しました。
概要
前回の記事 OpenFHE チュートリアル: CKKS形式で演算+近似演算+ブートストラップの組み合わせにチャレンジしてみたかった。
で言及した、OpenFHEライブラリの機能
- CKKS形式のPRE(プロキシ再暗号化)はどのように使用するのか
- CKKS形式のThreshold 暗号はどのように使うのか(マルチパーティ復号)
について今回はチュートリアルを行ってみたいと思います。
プロキシ再暗号化について
プロキシ再暗号化とは、簡単に言えば、Aの鍵で暗号化された暗号を、Bが復号可能な暗号に、
暗号化を解除することなく(つまり復号することなく) 暗号文を変換するという技術です。
これを実行することで、例えば準同型暗号を用いたアクセスコントールなどをアプリケーションに組み込むことが可能になります。
もともとAさんが暗号化したデータに対して、あるタイミングでBさんがその内容を参照したい、
という際に、BさんはAさんに対して、復号権限をリクエストします。
もしAさんがそのリクエストを許可すれば、サーバはプロキシ再暗号化を行い、Bさんに暗号文を送信します。BさんはAさんからもらった暗号文を自分の秘密鍵を使用して復号し、内容をみることができます。
今回は、この内容を簡単にプログラムとしてモックアップしてみた後に、
どのくらいの時間でこのプロキシ再暗号化が実行できるか実測をしてみたいと思います。
プログラム内容
やっている内容は、
- Aの鍵ペアにより暗号文(c)を生成
- cそのもの、cに乗算をしたもの、cにブートストラップを施したもの
- 上のそれぞれに対して、Bの鍵ペアへとプロキシ再暗号化を実行(プロキシ再暗号化にかかった時間を測定)
- Bの鍵ペアで復号を実行し、結果を確認
を行っています。
#define PROFILE
#include "openfhe.h"
#include "stdio.h"
#include <random>
#include <vector>
#include <cassert>
using namespace lbcrypto;
using namespace std;
using namespace std::chrono;
inline double get_time_msec(void) {
return static_cast<double>(duration_cast<nanoseconds>(steady_clock::now().time_since_epoch()).count()) / 1000000;
}
void SimpleBootstrapExample();
int main(int argc, char* argv[]) {
SimpleBootstrapExample();
}
void SimpleBootstrapExample() {
CCParams<CryptoContextCKKSRNS> parameters;
SecretKeyDist secretKeyDist = UNIFORM_TERNARY;
parameters.SetSecretKeyDist(secretKeyDist);
parameters.SetSecurityLevel(HEStd_NotSet);
parameters.SetRingDim(1 << 12);
#if NATIVEINT == 128 && !defined(__EMSCRIPTEN__)
ScalingTechnique rescaleTech = FIXEDAUTO;
usint dcrtBits = 78;
usint firstMod = 89;
#else
ScalingTechnique rescaleTech = FLEXIBLEAUTO;
usint dcrtBits = 59;
usint firstMod = 60;
#endif
parameters.SetScalingModSize(dcrtBits);
parameters.SetScalingTechnique(rescaleTech);
parameters.SetFirstModSize(firstMod);
std::vector<uint32_t> levelBudget = {4, 4};
uint32_t approxBootstrapDepth = 8;
uint32_t levelsUsedBeforeBootstrap = 12;
usint depth =
levelsUsedBeforeBootstrap + FHECKKSRNS::GetBootstrapDepth(approxBootstrapDepth, levelBudget, secretKeyDist);
parameters.SetMultiplicativeDepth(depth);
printf("this is my depth %d\n", depth);
CryptoContext<DCRTPoly> cryptoContext = GenCryptoContext(parameters);
cryptoContext->Enable(PKE);
cryptoContext->Enable(KEYSWITCH);
cryptoContext->Enable(LEVELEDSHE);
cryptoContext->Enable(ADVANCEDSHE);
cryptoContext->Enable(FHE);
cryptoContext->Enable(PRE);
usint ringDim = cryptoContext->GetRingDimension();
// This is the maximum number of slots that can be used for full packing.
usint numSlots = ringDim / 2;
std::cout << "CKKS scheme is using ring dimension " << ringDim << std::endl << std::endl;
cryptoContext->EvalBootstrapSetup(levelBudget);
auto keyPair1 = cryptoContext->KeyGen();
cryptoContext->EvalMultKeyGen(keyPair1.secretKey);
cryptoContext->EvalBootstrapKeyGen(keyPair1.secretKey, numSlots);
KeyPair<DCRTPoly> keyPair2 = cryptoContext->KeyGen();
EvalKey<DCRTPoly> evalKey = cryptoContext->ReKeyGen(keyPair1.secretKey, keyPair2.publicKey);
// Making plaintext vector
std::vector<double> x;
for (int i = 0; i < 10; i++) {
x.push_back(i - 5);
}
size_t encodedLength = x.size();
Plaintext ptxt = cryptoContext->MakeCKKSPackedPlaintext(x, 1, 0);
ptxt->SetLength(encodedLength);
std::cout << "Input: " << ptxt << std::endl;
// Encryption
Ciphertext<DCRTPoly> c = cryptoContext->Encrypt(keyPair1.publicKey, ptxt);
// some operation
Ciphertext<DCRTPoly> c_mul = cryptoContext->EvalMultAndRelinearize(c, c);
// bootstrap operation
Ciphertext<DCRTPoly> c_bs = cryptoContext->EvalBootstrap(c);
// ReEncryption
Ciphertext<DCRTPoly> c_re = cryptoContext->ReEncrypt(c, evalKey);
Ciphertext<DCRTPoly> c_mul_re = cryptoContext->ReEncrypt(c_mul, evalKey);
Ciphertext<DCRTPoly> c_bs_re = cryptoContext->ReEncrypt(c_bs, evalKey);
double start = get_time_msec();
Plaintext res_c;
cryptoContext->Decrypt(keyPair2.secretKey, c_re, &res_c);
res_c->SetLength(encodedLength);
std::cout << "res_c\n\t" << res_c << std::endl;
Plaintext res_c_mul;
cryptoContext->Decrypt(keyPair2.secretKey, c_mul_re, &res_c_mul);
res_c_mul->SetLength(encodedLength);
std::cout << "res_c_mul\n\t" << res_c_mul << std::endl;
Plaintext res_c_bs;
cryptoContext->Decrypt(keyPair2.secretKey, c_bs_re, &res_c_bs);
res_c_bs->SetLength(encodedLength);
std::cout << "res_c_bs\n\t" << res_c_bs << std::endl;
double end = get_time_msec();
printf("total_time: %f\n", end - start);
}
実行結果
Input: (-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, ... ); Estimated precision: 59 bits
res_c
(-5, -4, -3, -2, -1, 1.28295e-14, 1, 2, 3, 4, ... ); Estimated precision: 45 bits
res_c_mul
(25, 16, 9, 4, 1, -4.06569e-15, 1, 4, 9, 16, ... ); Estimated precision: 47 bits
res_c_bs
(-5, -4, -3, -2, -1, 4.27316e-06, 1, 2, 3, 4, ... ); Estimated precision: 18 bits
total_time: 631.670871
結果としてはプロキシ再暗号化が
- 暗号文そのもの、乗算結果、ブートストラップ結果
それぞれに対して実行できており、結果も問題ないということを確認できました。
また、使用したパラメータの元で、プロキシ再暗号化は約0.6秒程度で実行できていました。
マルチパーティ復号について
マルチパーティ復号について少しだけ説明した後、実際にチュートリアルを実行してみます。
マルチパーティ復号は、Threshold 暗号化(閾値暗号化)とも呼ばれます。
プログラム内容
少し長くなる+いろいろな鍵が出てきてややこしいので、コメントを随時入れます。
#include "openfhe.h"
using namespace lbcrypto;
void RunCKKS();
int main(int argc, char* argv[]) {
RunCKKS();
return 0;
}
void RunCKKS() {
usint batchSize = 16;
CCParams<CryptoContextCKKSRNS> parameters;
parameters.SetMultiplicativeDepth(3);
parameters.SetScalingModSize(50);
parameters.SetBatchSize(batchSize);
CryptoContext<DCRTPoly> cc = GenCryptoContext(parameters);
// enable features that you wish to use
cc->Enable(PKE);
cc->Enable(KEYSWITCH);
cc->Enable(LEVELEDSHE);
cc->Enable(ADVANCEDSHE);
cc->Enable(MULTIPARTY);
////////////////////////////////////////////////////////////
// Set-up of parameters
////////////////////////////////////////////////////////////
// Output the generated parameters
std::cout << "p = " << cc->GetCryptoParameters()->GetPlaintextModulus() << std::endl;
std::cout << "n = " << cc->GetCryptoParameters()->GetElementParams()->GetCyclotomicOrder() / 2 << std::endl;
std::cout << "log2 q = " << log2(cc->GetCryptoParameters()->GetElementParams()->GetModulus().ConvertToDouble())
<< std::endl;
// Initialize Public Key Containers
KeyPair<DCRTPoly> kp1;
KeyPair<DCRTPoly> kp2;
KeyPair<DCRTPoly> kpMultiparty;
////////////////////////////////////////////////////////////
// Perform Key Generation Operation
////////////////////////////////////////////////////////////
std::cout << "Running key generation (used for source data)..." << std::endl;
// Round 1 (party A)
std::cout << "Round 1 (party A) started." << std::endl;
kp1 = cc->KeyGen();
まずパーティ1のための鍵ペアkp1
を生成しています。
// Generate evalmult key part for A
auto evalMultKey = cc->KeySwitchGen(kp1.secretKey, kp1.secretKey);
パーティ1:乗算した後の自分自身への鍵交換に必要な鍵evalMultKey
を生成しています。
// Generate evalsum key part for A
cc->EvalSumKeyGen(kp1.secretKey);
auto evalSumKeys =
std::make_shared<std::map<usint, EvalKey<DCRTPoly>>>(cc->GetEvalSumKeyMap(kp1.secretKey->GetKeyTag()));
std::cout << "Round 1 of key generation completed." << std::endl;
パーティ1: 足し算に必要な鍵evalSumKeys
を生成しています。
// Round 2 (party B)
std::cout << "Round 2 (party B) started." << std::endl;
std::cout << "Joint public key for (s_a + s_b) is generated..." << std::endl;
kp2 = cc->MultipartyKeyGen(kp1.publicKey);
auto evalMultKey2 = cc->MultiKeySwitchGen(kp2.secretKey, kp2.secretKey, evalMultKey);
パーティ2:マルチパーティ計算に必要な鍵ペアkp2
を生成しています。
また、乗算を計算するための、自分自身への鍵交換に必要な鍵evalMultKey2
を生成しています。
std::cout << "Joint evaluation addtion key for (s_a + s_b) is generated..." << std::endl;
auto evalAddAB = cc->MultiAddEvalKeys(evalMultKey, evalMultKey2, kp2.publicKey->GetKeyTag());
パーティ1とパーティ2がそれぞれ保有する暗号のシェア
パーティ2のシェアs_b
, パーティ1のシェアs_a
にたいして、足し算を行うための鍵
evalAddAB
を生成しています。
std::cout << "Joint evaluation multiplication key (s_a + s_b) is transformed "
"into s_b*(s_a + s_b)..."
<< std::endl;
auto evalMultBAB = cc->MultiMultEvalKey(kp2.secretKey, evalAddAB, kp2.publicKey->GetKeyTag());
パーティ2のシェアs_b
, パーティ1のシェアs_a
に対して、パーティ2のシェアs_b
を乗算するための鍵evalMultBAB
を生成しています。
auto evalSumKeysB = cc->MultiEvalSumKeyGen(kp2.secretKey, evalSumKeys, kp2.publicKey->GetKeyTag());
パーティ1のシェアを加算するための鍵evalSumKeys
に対応した、
パーティ2のシェアを加算するための鍵evalSumKeysB
を生成しています。
std::cout << "Joint evaluation summation key for (s_a + s_b) is generated..." << std::endl;
auto evalSumKeysJoin = cc->MultiAddEvalSumKeys(evalSumKeys, evalSumKeysB, kp2.publicKey->GetKeyTag());
cc->InsertEvalSumKey(evalSumKeysJoin);
std::cout << "Round 2 of key generation completed." << std::endl;
先ほどのパーティ1のシェアを加算するための鍵evalSumKeys
、
パーティ2のシェアを加算するための鍵evalSumKeysB
から、
シェア全体(s_a
, s_b
)に対して加算を実行できる鍵evalSumKeysJoin
を生成し、
CryptoContext
に登録しています。
std::cout << "Round 3 (party A) started." << std::endl;
std::cout << "Joint key (s_a + s_b) is transformed into s_a*(s_a + s_b)..." << std::endl;
auto evalMultAAB = cc->MultiMultEvalKey(kp1.secretKey, evalAddAB, kp2.publicKey->GetKeyTag());
各シェア(s_a
, s_b
)に対して、加算を行うための鍵evalAddAB
を用いて、
s_a + s_b
に対してs_a
を乗算するための鍵evalMultAAB
を生成しています。
std::cout << "Computing the final evaluation multiplication key for (s_a + "
"s_b)*(s_a + s_b)..."
<< std::endl;
auto evalMultFinal = cc->MultiAddEvalMultKeys(evalMultAAB, evalMultBAB, evalAddAB->GetKeyTag());
cc->InsertEvalMultKey({evalMultFinal});
std::cout << "Round 3 of key generation completed." << std::endl;
各シェア(s_a
, s_b
)に対して、
(s_b)*(s_a + s_b)
を実行できる鍵evalMultBBA
と、
(s_a)*(s_a + s_b)
を実行できる鍵evalMultABA
と、
を利用して、
(s_b)*(s_a + s_b) + (s_a)*(s_a + s_b)
を実行できる鍵evalMultFinal
を生成し、
それをCryptoContext
に登録しています。
////////////////////////////////////////////////////////////
// Encode source data
////////////////////////////////////////////////////////////
std::vector<double> vectorOfInts1 = {1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 1, 0};
std::vector<double> vectorOfInts2 = {1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0};
std::vector<double> vectorOfInts3 = {2, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 0};
Plaintext plaintext1 = cc->MakeCKKSPackedPlaintext(vectorOfInts1);
Plaintext plaintext2 = cc->MakeCKKSPackedPlaintext(vectorOfInts2);
Plaintext plaintext3 = cc->MakeCKKSPackedPlaintext(vectorOfInts3);
////////////////////////////////////////////////////////////
// Encryption
////////////////////////////////////////////////////////////
Ciphertext<DCRTPoly> ciphertext1;
Ciphertext<DCRTPoly> ciphertext2;
Ciphertext<DCRTPoly> ciphertext3;
ciphertext1 = cc->Encrypt(kp2.publicKey, plaintext1);
ciphertext2 = cc->Encrypt(kp2.publicKey, plaintext2);
ciphertext3 = cc->Encrypt(kp2.publicKey, plaintext3);
ここまでですが、kp2 (パーティ2が生成した鍵ペア)を使って、
3つのベクトルを暗号化しています。
////////////////////////////////////////////////////////////
// EvalAdd Operation on Re-Encrypted Data
////////////////////////////////////////////////////////////
Ciphertext<DCRTPoly> ciphertextAdd12;
Ciphertext<DCRTPoly> ciphertextAdd123;
ciphertextAdd12 = cc->EvalAdd(ciphertext1, ciphertext2);
ciphertextAdd123 = cc->EvalAdd(ciphertextAdd12, ciphertext3);
auto ciphertextMultTemp = cc->EvalMult(ciphertext1, ciphertext3);
auto ciphertextMult = cc->ModReduce(ciphertextMultTemp);
auto ciphertextEvalSum = cc->EvalSum(ciphertext3, batchSize);
3つの暗号文(c1
, c2
, c3
)に対して、
ここまでで実行した演算は以下の通りです。
c12 = c1 + c2
c123 = c12 + c3
cmult = c1 * c3
csum = sum(c3)
(ベクトルの要素を全て足し上げる演算)
////////////////////////////////////////////////////////////
// Decryption after Accumulation Operation on Encrypted Data with Multiparty
////////////////////////////////////////////////////////////
Plaintext plaintextAddNew1;
Plaintext plaintextAddNew2;
Plaintext plaintextAddNew3;
DCRTPoly partialPlaintext1;
DCRTPoly partialPlaintext2;
DCRTPoly partialPlaintext3;
Plaintext plaintextMultipartyNew;
const std::shared_ptr<CryptoParametersBase<DCRTPoly>> cryptoParams = kp1.secretKey->GetCryptoParameters();
const std::shared_ptr<typename DCRTPoly::Params> elementParams = cryptoParams->GetElementParams();
// distributed decryption
auto ciphertextPartial1 = cc->MultipartyDecryptLead({ciphertextAdd123}, kp1.secretKey);
auto ciphertextPartial2 = cc->MultipartyDecryptMain({ciphertextAdd123}, kp2.secretKey);
std::vector<Ciphertext<DCRTPoly>> partialCiphertextVec;
partialCiphertextVec.push_back(ciphertextPartial1[0]);
partialCiphertextVec.push_back(ciphertextPartial2[0]);
cc->MultipartyDecryptFusion(partialCiphertextVec, &plaintextMultipartyNew);
std::cout << "\n Original Plaintext: \n" << std::endl;
std::cout << plaintext1 << std::endl;
std::cout << plaintext2 << std::endl;
std::cout << plaintext3 << std::endl;
plaintextMultipartyNew->SetLength(plaintext1->GetLength());
std::cout << "\n Resulting Fused Plaintext: \n" << std::endl;
std::cout << plaintextMultipartyNew << std::endl;
std::cout << "\n";
ここで行っているのは、先ほどの
c123 = c12 + c3
の結果のc123
を、パーティ1とパーティ2がそれぞれ部分復号し、
最後にそれらの結果を結合して(cc->MultipartyDecryptionFusion
)、最終的な結果を得ているということです。
Plaintext plaintextMultipartyMult;
ciphertextPartial1 = cc->MultipartyDecryptLead({ciphertextMult}, kp1.secretKey);
ciphertextPartial2 = cc->MultipartyDecryptMain({ciphertextMult}, kp2.secretKey);
std::vector<Ciphertext<DCRTPoly>> partialCiphertextVecMult;
partialCiphertextVecMult.push_back(ciphertextPartial1[0]);
partialCiphertextVecMult.push_back(ciphertextPartial2[0]);
cc->MultipartyDecryptFusion(partialCiphertextVecMult, &plaintextMultipartyMult);
plaintextMultipartyMult->SetLength(plaintext1->GetLength());
std::cout << "\n Resulting Fused Plaintext after Multiplication of plaintexts 1 "
"and 3: \n"
<< std::endl;
std::cout << plaintextMultipartyMult << std::endl;
std::cout << "\n";
ここで行っているのは、先ほどの
cmult = c1 * c3
の結果のcmult
を、パーティ1とパーティ2がそれぞれ部分復号し、
最後にそれらの結果を結合して(cc->MultipartyDecryptionFusion
)、最終的な結果を得ているということです。
Plaintext plaintextMultipartyEvalSum;
ciphertextPartial1 = cc->MultipartyDecryptLead({ciphertextEvalSum}, kp1.secretKey);
ciphertextPartial2 = cc->MultipartyDecryptMain({ciphertextEvalSum}, kp2.secretKey);
std::vector<Ciphertext<DCRTPoly>> partialCiphertextVecEvalSum;
partialCiphertextVecEvalSum.push_back(ciphertextPartial1[0]);
partialCiphertextVecEvalSum.push_back(ciphertextPartial2[0]);
cc->MultipartyDecryptFusion(partialCiphertextVecEvalSum, &plaintextMultipartyEvalSum);
plaintextMultipartyEvalSum->SetLength(plaintext1->GetLength());
std::cout << "\n Fused result after the Summation of ciphertext 3: "
"\n"
<< std::endl;
std::cout << plaintextMultipartyEvalSum << std::endl;
}
ここで行っているのは、先ほどの
csum = sum(c3)
(ベクトルの要素を全て足し上げる演算)
の結果のcsum
を、パーティ1とパーティ2がそれぞれ部分復号し、
最後にそれらの結果を結合して(cc->MultipartyDecryptionFusion
)、最終的な結果を得ているということです。
実行結果
Original Plaintext:
(1, 2, 3, 4, 5, 6, 5, 4, 3, 2, 1, ... ); Estimated precision: 50 bits
(1, 0, 0, 1, 1, ... ); Estimated precision: 50 bits
(2, 2, 3, 4, 5, 6, 7, 8, 9, 10, ... ); Estimated precision: 50 bits
Resulting Fused Plaintext:
(4, 4, 6, 9, 11, 12, 12, 12, 12, 12, 1, 2.04281e-14, ... ); Estimated precision: 43 bits
まず、c123 = c1 + c2 + c3
の結果について、確かにDecryptFusion
で閾値復号した結果が正しいことが確認できます。
Resulting Fused Plaintext after Multiplication of plaintexts 1 and 3:
(2, 4, 9, 16, 25, 36, 35, 32, 27, 20, -4.20108e-13, -1.60583e-12, ... ); Estimated precision: 39 bits
次に、cmult = c1 * c3
の結果が正しく得られていることが確認できます。
Fused result after the Summation of ciphertext 3:
(56, 56, 56, 56, 56, 56, 56, 56, 56, 56, 56, 56, ... ); Estimated precision: 43 bits
最後に、csum = sum(c3)
(ベクトルの要素を全て足し上げる演算) の結果(56が正しい結果)
が得られていることがわかります。
まとめ
今回は、OpenFHEライブラリを用いて、
- CKKS形式のPRE(プロキシ再暗号化)はどのように使用するのか
- CKKS形式のThreshold 暗号はどのように使うのか(マルチパーティ復号)
について今回はチュートリアルを実行してみました。
プロキシ再暗号化はコンセプトがつかみやすく、使い所がありそうですが、
マルチパーティに関しては鍵が乱立しているので理解するのが難しいこと、使えそうなシナリオを考えるのが難しいな、という印象を受けました。
しかしながら、CKKS形式に明示的に上記二つの機能を実装しているライブラリは
OpenFHEのみだと思われるので、それらをアプリケーションに応用していくのもこれから面白そうだなと感じています。
いろいろと使い方に慣れる必要がありそうですが、もう少しOpenFHEについて追ってみようと思います。
今回はこの辺で。