LoginSignup
0
0

More than 1 year has passed since last update.

TFHEベースの暗号を用いた数値計算ライブラリ【concrete-numpy】を使ってみた

Posted at

kenmaro です。
秘密計算、特に準同型暗号のことについて記事を書いています。

秘密計算エンジニアとして得た全ての知見をまとめた記事はこちら。
https://qiita.com/kenmaro/items/74c3147ccb8c7ce7c60c

これから準同型暗号について勉強したいリサーチャー、エンジニアの方へのロードマップはこちら。
https://qiita.com/kenmaro/items/f2d4fb84833c308a4d29

今話題のゼロ知識証明について解説した記事はこちら。
https://qiita.com/kenmaro/items/d968375793fe754575fe

概要

TFHEにプログラマブルブートストラップという拡張機能を搭載したTFHE亜種を2018年に発表した
フランスのスタートアップ ZamaAIですが、

  • tfhe-rs
  • concrete-np
  • concrete-ml

というコアライブラリをオープンソースとして公開しています。

今回はその中の、concrete-np のチュートリアルを動かしてみたいと思います。

実行環境

私のノートパソコンで動かしてみました。

OS Mac (Intel)
Core 8Core 16 threads
Memory 64GB

のpyenv 環境に、

pip install concrete-numpy

としてインストールしました。

チュートリアルの内容

ここをフォローしていきます。

暗号同士の足し算

test_add.py
import time
import concrete.numpy as cnp

def add(x, y):
    return x + y

compiler = cnp.Compiler(add, {"x": "encrypted", "y": "encrypted"})
inputset = [(2, 3), (0, 0), (1, 6), (7, 7), (7, 1), (3, 2), (6, 1), (1, 7), (4, 5), (5, 4)]

t1 = time.time()
print(f"Compiling...")
circuit = compiler.compile(inputset)
t2 = time.time()
print(f"time: {t2-t1}")

t1 = time.time()
print(f"Generating keys...")
circuit.keygen()
t2 = time.time()
print(f"time: {t2-t1}")

t1 = time.time()
examples = [(3, 4), (1, 2), (7, 7), (0, 0)]
for example in examples:
    encrypted_example = circuit.encrypt(*example)
    encrypted_result = circuit.run(encrypted_example)
    result = circuit.decrypt(encrypted_result)
    print(f"Evaluation of {' + '.join(map(str, example))} homomorphically = {result}")
t2 = time.time()
print(f"time: {t2-t1}")

入力したタプルについて暗号化して足し算をするチュートリアルです。
ここで、入力している input_setについてですが、
実際に入るようなデータをiterator の形で渡す必要があると書かれています。
これにより、TFHEでの計算に必要なビット数などを算出しコンパイルするということです。

実行結果

Compiling...
time: 0.27693891525268555
Generating keys...
time: 21.034945964813232
Evaluation of 3 + 4 homomorphically = 7
Evaluation of 1 + 2 homomorphically = 3
Evaluation of 7 + 7 homomorphically = 14
Evaluation of 0 + 0 homomorphically = 0
time: 0.00379180908203125

デコレータによる表記

test_decorator.py
import concrete.numpy as cnp

@cnp.compiler({"x": "encrypted"})
def f(x):
    return x + 42

inputset = range(10)
circuit = f.compile(inputset)

assert circuit.encrypt_run_decrypt(10) == f(10)

デコレータを使うことで、関数をコンパイルして暗号状態で走る回路へと変換することができます。
先ほどの例に比べると、f.compileのところで回路のコンパイルが走るためそこに時間がかかります。

実行結果

time: 0.2761857509613037
time: 179.76507902145386

ルックアップテーブルを参照

ルックアップテーブルを作成しておき、暗号化されたインデックスに対してルックアップテーブルを参照することのできるチュートリアルは以下になっています。

インデックス情報を暗号文にし、プログラマブルブートストラップを用いて参照しています。

test_lut.py
import time
import concrete.numpy as cnp

table = cnp.LookupTable([2, -1, 3, 0])

@cnp.compiler({"x": "encrypted"})
def f(x):
    return table[x]

t1 = time.time()
inputset = range(4)
circuit = f.compile(inputset)
t2 = time.time()
print(f"time: {t2-t1}")

t1 = time.time()
assert circuit.encrypt_run_decrypt(0) == table[0] == 2
t2 = time.time()
print(f"time: {t2-t1}")
assert circuit.encrypt_run_decrypt(1) == table[1] == -1
assert circuit.encrypt_run_decrypt(2) == table[2] == 3
assert circuit.encrypt_run_decrypt(3) == table[3] == 0

実行結果

time: 0.2784440517425537
time: 13.55506420135498

LUTを関数から自動で作成

test_lut_fuse.py
import concrete.numpy as cnp
import numpy as np

@cnp.compiler({"x": "encrypted"})
def f(x):
    return (42 * np.sin(x)).astype(np.int64) // 10

inputset = range(8)
circuit = f.compile(inputset)

for x in range(8):
    assert circuit.encrypt_run_decrypt(x) == f(x)

この例では、sin関数やfloor関数を内部に持つ関数 f に対して、
関数をコンパイルすることでLUT化し、それを参照することでこの関数を実行している例です。
先ほどのようにルックアップテーブルを自分でリストのように作ってインデックスを参照するのではなく、
あくまで非線形関数を実行するような構成でルックアップテーブルが内部で参照されます。

自分で作ったサンプルプログラム(ベクトル同士の内積)

以上を踏まえて、以下のようなプログラムを作成してみました。
やっていることは非常に単純ですが、3次元同士のベクトルの内積です。
一方のベクトルは暗号文で、もう一方のベクトルは平文で計算しています。

このとき、どちらも暗号文でやろうとすると、

only dot product between encrypted and clear is supported
return %2

のようなエラー文がでてサポートされていないことがわかりました。

test_dot.py
import time
import concrete.numpy as cnp
import numpy as np

n_bits = 3

def dot(x, y):
    #return np.dot(x, y)
    return np.dot(x,y)

t1 = time.time()
#compiler = cnp.Compiler(dot, {"x": "encrypted", "y": "encrypted"})
compiler = cnp.Compiler(dot, {"x": "encrypted", "y": "clear"})
t2 = time.time()
print(f"time: {t2-t1}")


x = np.array([1,2,3])
y = np.array([1,1,4])


inputset = [
    (
        np.random.randint(0, 2**n_bits, size=x.shape),
        np.random.randint(0, 2**n_bits, size=y.shape),
    )
    for _ in range(5)
]

t1 = time.time()
circuit = compiler.compile(inputset)
t2 = time.time()
print(f"time: {t2-t1}")


clear_evaluation = dot(x, y)

t1 = time.time()
enc = circuit.encrypt(x, y)
run = circuit.run(enc)
res = circuit.decrypt(run)
t2 = time.time()
print(f"time: {t2-t1}")

print(clear_evaluation)
print(res)

実行結果

time: 4.1961669921875e-05
time: 0.30979108810424805
time: 167.0315079689026
15
15

内積計算で167秒実行時間がかかる、というのはやはり重たいですね、、、

自分で作ったサンプルプログラム(比較演算)

結論から言うと、以下のようなプログラムはRunTimeErrorが発生して通りませんでした。

test_compare.py
import time
import concrete.numpy as cnp
import numpy as np

n_bits = 3

@cnp.compiler({"x": "encrypted", "y": "encrypted"})
def my_evaluation(x, y):
    if x > y:
        return 10
    else:
        return 20


inputset = [
    (
        np.random.randint(0, 2**n_bits),
        np.random.randint(0, 2**n_bits),
    )
    for _ in range(5)
]
print(inputset)

t1 = time.time()
my_circuit = my_evaluation.compile(inputset)
t2 = time.time()
print(f"time: {t2-t1}")


x = 10
y = 12


clear_evaluation = my_evaluation(x, y)

t1 = time.time()
enc = my_circuit.encrypt(x, y)
run = my_circuit.run(enc)
res = my_circuit.decrypt(run)
t2 = time.time()
print(f"time: {t2-t1}")

print(clear_evaluation)
print(res)

そのあと、np.greater がサポートされているようだったので、

test_compare.py
import time
import concrete.numpy as cnp
import numpy as np

n_bits = 3

@cnp.compiler({"x": "encrypted", "y": "encrypted"})
def my_evaluation(x, y):
    return np.greater(x, y)


inputset = [
    (
        np.random.randint(0, 2**n_bits),
        np.random.randint(0, 2**n_bits),
    )
    for _ in range(5)
]
print(inputset)

t1 = time.time()
my_circuit = my_evaluation.compile(inputset)
t2 = time.time()
print(f"time: {t2-t1}")


x = 10
y = 12


clear_evaluation = my_evaluation(x, y)

t1 = time.time()
enc = my_circuit.encrypt(x, y)
run = my_circuit.run(enc)
res = my_circuit.decrypt(run)
t2 = time.time()
print(f"time: {t2-t1}")

print(clear_evaluation)
print(res)

で試してみましたが、これもRunTimeErrorで実行できませんでした。

この辺りは私がよくライブラリを理解していないためにこうなっている可能性も大いにあるので、
もう少し勉強してから出直したいと思います。
検証したかったのは、暗号文同士の比較演算でした。

まとめ

今回は、concrete-numpy を用いて、
用意されているチュートリアルを一通りやってみて、
最後にベクトルの内積を計算するようなサンプルを自分で書いて実行してみました。

使い所がいまいちまだわかりませんが、
このレイヤでいろいろな計算を実行してみて、他にどんなことができるのか
確かめてみたいと思います。

今回はこの辺で。

@kenmaro

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0