#はじめに
この記事で「できない」という主張をしていましたができました(Pythonのハッシュ衝突攻撃の考察2: 辞書のキー検索を故意に衝突させられました)。この記事は2年前に考えていたことの記録として残しておきます。
この記事は、Pythonにおける悪意を持ったdict, setに対するハッシュ衝突攻撃を考察します。一般的なハッシュ,辞書の実装を理解していることを前提とします。
C++の場合はmap, setの衝突を以下の記事のように簡単に引き起こせます。Pythonはどうでしょうか?
サマリ
- Pythonでも実装として辞書の衝突は起こり、O(N)で各操作が発生しうる
-
(旧記載)ただし、そのためには61ビットのハッシュの衝突をさせなければならないので現実的でない
- 数値入力で衝突を引き起こす入力をすることは現実的でない
- 文字列に関しても、ハッシュの衝突を発生させることは現実的でない
-
(旧記載)このため、競技プログラミングの範囲ではhackケースにpythonの辞書を攻撃することは困難
- 定数倍をO(9)くらいにはできます。数値の辞書への追加は
mod (2**61-1)
した値を右から5bitずつ使って乱数を生成させるので、例えば50bit同じpaddingをした数値を入れれば、多くの操作で5bit x 10回分のハッシュの値を衝突させることができます(まとめ、に記載しています)
- 定数倍をO(9)くらいにはできます。数値の辞書への追加は
結論
(旧記載)
Pythonではハッシュの衝突で各操作がO(N)になることは考えずに辞書は使ってよい。各操作がO(9)くらいになる数列が作れる程度が現実的。
と書いていましたができました(Pythonのハッシュ衝突攻撃の考察2: 辞書のキー検索を故意に衝突させられました)
まずは実例: PythonでO(N)の辞書操作がかかる例
2の61乗-1
の倍数を入れ続けるとO(N)になります。
サンプルコードは、2^60 - 1
,2^61 - 1
, 2^62 - 1
の倍数を20000個代入するコードです。
2^60-1 elapsed_time:0.0[sec]
2^61-1 elapsed_time:1.8965110778808594[sec]
2^62-1 elapsed_time:0.0[sec]
2^60-1
と2^62-1
の倍数がほぼ瞬時に処理できているのに対して、2^61-1
の倍数の処理には大きな時間がかかっています。
2^61-1
の倍数を2万個はint64の領域を超えるので、競技プログラミングでは現実的ではありませんが、衝突を発生させることは可能ということがわかりました。
まず、これで、Pythonでもうまく値を選ぶとハッシュ衝突が生じそうということになります。
例えば、この処理を、int64の範囲で自由に起こすことができれば、競技プログラミングで、ハッシュ衝突攻撃となるケースを作れそうです。さて、できるのでしょうか?
復習: ハッシュの実装例
復習です。C++の場合、以下のように実装されています。
- 数値の場合は適当なmodをとり(文字列の場合はハッシュを取り)そのindexに一致する辞書の配列に値を代入します。
- 例えば、そのdictのmodのサイズが8で、0,4,5,8,12,16を入れようとすると、
dict[0] = [0, 8, 16]
dict[4] = [4, 12]
dict[5] = [5]
のようになります。冒頭にあげたC++の例では(元記事に記載の通り)mod 107897
が同じキーになる値を入れ続けると、dictへのアクセスがO(1)ではなくO(N)になります。
Pythonにおけるハッシュ計算
Pythonの場合は、(C++のように)1回のハッシュでindexを決めてその先をlistにするのではなくて、当たりが来るまで疑似乱数でindexを計算し続けるです。(当たりとは、追加の際は空のindex, 削除・参照の場合は該当のindexのことです)
https://stackoverflow.com/questions/327311/how-are-pythons-built-in-dictionaries-implemented/9022835#9022835
この情報はとても参考になります。
- ハッシュのindexの数は、もちろん無限ではなく、ハッシュに入れる全体の数に合わせた適当なサイズです(最も小さいのはint 8から始まる)。全体の数に応じて動的にサイズが拡張されます。
- C++と違って、このハッシュエントリはリストではありません。各箱はただ1つの値を持ち(key, value)が格納されています。
- ハッシュにアクセスする際、indexを計算し、そのエントリが探しているものでなければ、キーを用いた疑似乱数(後述)を使ってindexを次のエントリのindexに変換します。これをあたりが見つかるまで繰り返します。
具体的にコードで例を見ていきましょう。
- 最初のキーはhash()関数で計算されます。これをオリジナルindexと呼びます。
- 使われるのは、siphash24です。Source-PyDict_GetItemにあるとおり、Source-PyObject_Hash
がcallされます。 - 数値に対しては
mod (2**61-1)
と考えてよいです。
- 使われるのは、siphash24です。Source-PyDict_GetItemにあるとおり、Source-PyObject_Hash
- 次の$index$を作る疑似乱数コードは、今の$index$を$i$として
perturb
という疑似乱数があるうえで、i = (i*5 + perturb + 1) & mask
です。詳しくはこの下のコードで。- このロジックが記載されている例: そのキーを挿入するときの関数lookup(find_empty_slot)
- このロジックが記載されている例: キーを検索するときの関数(_Py_dict_lookup)
- maskはハッシュのサイズに応じて動的に変わります。ここの定義を見てください。
- 候補となるindexに対しては、
dictkeys_get_index()
を使ってそのindexが欲しかったvalueのものかを検証します。
indexの探索は以下のようになっています。
const size_t mask = DK_MASK(keys); /* 今のハッシュサイズ相当のマスク (例えばint32) */
size_t i = hash & mask; /* hashは最初オリジナルindex*/
Py_ssize_t ix = dictkeys_get_index(keys, i); /* 一発目の検索。これで空きが見つかれば以下のforは回らない*/
/* perturb = 疑似乱数 = index候補 */
for (size_t perturb = hash; ix >= 0;) { /* dictkeys_get_indexが0 = 空きがあたるまで検索 */
perturb >>= PERTURB_SHIFT; // PERTURB_SHIFT = 5 でstaticです
i = (i*5 + perturb + 1) & mask;
ix = dictkeys_get_index(keys, i);
}
ハッシュのサイズは以下のようになります。競技プログラミングでよく見る $2 \times 10^{5}$くらいの場合、int32の空間となります。
* int8 for dk_size <= 128
* int16 for 256 <= dk_size <= 2**15
* int32 for 2**16 <= dk_size <= 2**31
* int64 for 2**32 <= dk_size
疑似コードによる衝突の様子
サンプルコードを書きました。実行結果を見てみましょう。これは上記のコードをシミュレートするものです。
// 0,1,2,3を入れようとすると、それぞれ初回のhashでindex=0,1,2,3をつかみに行き1度のindex計算で代入したいindexを見つけられました
hash sim. you want to add [0, 1, 2, 3] to hashtable
OK 0 0 ok ix= 0 ngcnt 0
OK 1 1 ok ix= 1 ngcnt 0
OK 2 2 ok ix= 2 ngcnt 0
OK 3 3 ok ix= 3 ngcnt 0
// 2^61-1の倍数を入れようとすると、それぞれ初回のhashでindex=0をつかみに行き、
// そのあとのリトライでも同じようなindexをつかみに行く(0->1->6->31)のため、1回の代入でO(N)かかるようになります
hash sim. you want to add [0, 2305843009213693951, 4611686018427387902, 6917529027641081853] to hashtable
OK 0 0 ok ix= 0 ngcnt 0
NG ix= 0 ngcnt 0 perturb 0
OK 2305843009213693951 0 ok ix= 1 ngcnt 1
NG ix= 0 ngcnt 0 perturb 0
NG ix= 1 ngcnt 1 perturb 0
OK 4611686018427387902 0 ok ix= 6 ngcnt 2
NG ix= 0 ngcnt 0 perturb 0
NG ix= 1 ngcnt 1 perturb 0
NG ix= 6 ngcnt 2 perturb 0
OK 6917529027641081853 0 ok ix= 31 ngcnt 3
なぜ、2^61-1の倍数に弱いのか?
hash関数が決まった値(この場合は0)を返すからです。以下のPythonの実行結果を見てください。
print(hash(0)) # 0
print(hash(1)) # 1
print(hash(2**61-2)) # 2305843009213693950
print(hash((2**61 - 1) * 1)) # 0
print(hash((2**61 - 1) * 2)) # 0
Pythonのハッシュ関数は、2**61-1
の倍数に対して0を返します。このため、常に、perturbが0から始まり、疑似乱数の発生が必ず同じ遷移になります。これは0である必要はなく、通知が同じであれば衝突します。siphash24は通知に対して、mod (2**61-1)
だったので、例えば、
[(261 ) * i + 1 - i for i in range(100)] # mod(261-1)が1の数列
このhashは1になるので、衝突させられます。
まとめ: Pythonのhashは衝突させられるのか?(競技プログラミング観点で)
できました(Pythonのハッシュ衝突攻撃の考察2: 辞書のキー検索を故意に衝突させられました)
(旧記載)現実的にNoです。
- 数値入力の場合、通常、$10^{18}$の範囲内です。これは、$2^{60}$にも満たないため、すべてユニークなハッシュ値を取ります。
- 文字列の入力の場合、ハッシュを衝突させられる文字をN個用意できるなら可能です。=現実的ではありません
(旧記載)少し考えてみると、O(9)程度までの衝突は引き起こせます。$perturb$は左から5bitごとに疑似乱数生成を行うので、たとえば
sim([i << 50 for i in range(1000)])
という数列を流し込むと、各操作がO(9)くらいにはなります。
(旧記載)ということで、Pythonではハッシュの衝突でO(N)になることは考えずに辞書は使ってよさそうです。
Pythonのハッシュ衝突の例
import time
p = 0
d = dict()
k = (2 ** (60*1)) -1
start = time.time()
d = dict()
for i in range(20000):
d[k * i] = True
print("2^60-1 elapsed_time:{0}".format(time.time() - start) + "[sec]")
k = (2 ** (61*1)) -1
start = time.time()
d = dict()
for i in range(20000):
d[k * i] = True
print("2^61-1 elapsed_time:{0}".format(time.time() - start) + "[sec]")
k = (2 ** (62*1)) -1
start = time.time()
d = dict()
for i in range(20000):
d[k * i] = True
print("2^62-1 elapsed_time:{0}".format(time.time() - start) + "[sec]")
hashシミュレータ
def sim(candi):
print("hash sim. you want to add {0} to hashtable".format(candi))
PERTURB_SHIFT = 5
used = set() # 今回は簡単に使っているかのテストのみ
nn = 5
data = []
datb = []
mask = 2**16 - 1 # 簡易化のため、ハッシュサイズはある程度大きいと仮定
res = []
final = []
visited = set()
finalres = -1
for initval in candi:
hashval = initval
hashval = hash(hashval)
i = hashval & mask # これについて計算を開始
perturb = hashval # forの初期化部分
ngcnt = 0 # ixを計算しなおした数
while i in used:
print("NG ix=", i, "ngcnt", ngcnt, "perturb", perturb)
ngcnt +=1
perturb >>= PERTURB_SHIFT
i = (i * 5 + perturb + 1) & mask
print("OK", initval ,hashval, "ok ix=", i, "ngcnt", ngcnt, )
used.add(i)
sim([0,1,2,3])
sim([i* (2**61 - 1) for i in range(4)])
print(bin(1 * 2**64 - 1))
print(bin(2 * (2**64 - 1)))
print(bin(2**64 - 1).count("1"))