Fully Padded RSA
authored by yu212
eは大きいし、パディングもしてるし、絶対に安全です!
import os
from Crypto.Util.number import *
from math import gcd
flag = os.environ.get("FLAG", "Alpaca{dummy}")
assert len(flag) <= 40
e1 = 65517
e2 = 65577
while True:
p = getPrime(512)
q = getPrime(512)
if gcd((p-1)*(q-1), e1) == gcd((p-1)*(q-1), e2) == 1:
break
n = p * q
padded_flag = long_to_bytes(n)[:-len(flag)] + flag.encode()
m = bytes_to_long(padded_flag)
assert m < n
c1 = pow(m, e1, n)
c2 = pow(m, e2, n)
print(f"{n = }")
print(f"{c1 = }")
print(f"{c2 = }")
$e_1, e_2$と公開鍵が2つあり、$c_1,c_2$と暗号文が2つ公開されています。
これは、Common Modulus Attackが使えます。
(参考: https://zenn.dev/anko/articles/ctf-crypto-rsa)
\displaylines{c_1 = m^{e_1} \pmod n\\ c_2 = m^{e_2} \pmod n}
とありますが、ここで、$e_1, e_2$について拡張ユークリッド互除法を用いると
g = gcd(e_1, e_2) = ue_1+ve_2
となり、
m^g = m^{ue_1+ve_2} = c_1^uc_2^v \pmod n
と表すことができます。もし、$e_1, e_2$が互いに素のときか$g$が小さい場合、Low Public Exponent Attack($m^e < N$のとき、$e$乗根を取ってあげることでmが求まる攻撃)を用いることでmを求めることができます。
ということで実際にやってみましょう。
gcd(e_1,e_2) = -1013e_1+1092e_2 = 3
よって、$g=3$なので、$m^3$が取得できました。しかし、今回のmは以下の処理をしています。
padded_flag = long_to_bytes(n)[:-len(flag)] + flag.encode()
これにより、$m$の値は大きいことがわかり、$m^3 > N$となってしまいます。
解法
先ほどのパディングの処理を考えてみましょう。
これは、nのバイト列のうち下位len(flag)バイトをflagをエンコードしたバイト列の置き換えています。
$l=len(flag)$とすると、flagをエンコードしたバイト列を
\mathrm{enc(flag)} = (f_0, f_1, \dots, f_{l-1})
と表せます。
そして、nのバイト列は
\text{bytes(n)} = (n_0, n_1, \dots, n_{k-1})
となり、今回のパディング処理からpadded_flagは
\text{enc(flag)} = (n_0, n_1, \dots, n_{k-l-1}, f_0, f_1, \dots, f_{l-1})
となります。
今回のflagは40未満であることがわかっているため、以下のように考えられます。
\text{enc}(\text{flag}) = (n_0, n_1, \dots, n_{k-l-1},
\underbrace{f_0, f_1, \dots, f_{l-1})}_{\leq40\ \text{bytes}}
よって、このenc(flag)を抜き出し、3乗根を計算することで、差分が求まり、それをnから引いていけば解けます。
実際、nとpadded_flagの差を$m'$とすると、求めたい平文$m$は
m = n-m'
と書けます。このとき
m^3 \equiv (n-m')^3 equiv -m'^3 \pmod n
となるため、$m^3 \pmod n$を$m_3$とおくと
m'^3 \equiv -m_3 \equiv n-m_3 \pmod n
さらに、$m'^3 < n$が成立するため
m'^3 = n - m_3
が成立します。以上より、$m'^3$の3乗根を計算して$m'$を求め、$m = n - m'$とすることで平文を復元できます。
これを実装すると以下になります。
from sympy import *
from Crypto.Util.number import long_to_bytes
n = 91102717210596990388603678426683097953697889897819753293818443119019220403217013812232251320814152567699322671590559119510246139859891156830672838769529887961956970370968572962306584295059185945752892892100975462391203805852473243296747559459800718013816237662990504689724747628304890125129146326331097856907
c1 = 84316690833236468829386139306045298111202426584048821548102362931269993141514516100633466389955824290011995159677864206138653174440904170622039293036862729884826231898868928186453091113165643576890891297150845933751243965934735328928976655465009980896153972226679588496970771925581698573227941539852081781874
c2 = 74682069306151159606579889187354529286195652598555930926994495384029865435810129236911316774977007932641783161876484392995815937986886903514990618178943843429073696833993271982336114314882872652681858748846455760309012235191324385691614015531641062111894149446939460102878320469041435024347449132388644171970
def modexp(base, exp, mod):
if exp >= 0:
return pow(base, exp, mod)
base_inv = invert(base, mod)
return pow(base_inv, -exp, mod)
e1 = 65517
e2 = 65577
g = gcd(e1, e2)
assert g == 3
e1p = e1 // g
e2p = e2 // g
u, v, g2 = gcdex(e1p, e2p)
u = int(u)
v = int(v)
assert g2 == 1
m1 = modexp(c1, u, n)
m2 = modexp(c2, v, n)
m3 = (m1 * m2) % n
t3 = n - m3
t, exact = integer_nthroot(t3, 3)
m = n - t
flag = long_to_bytes(m)
## 綺麗にするため書く
prefix = b"Alpaca{"
start = flag.rfind(prefix)
end = flag.find(b"}", start)
print(flag[start:end + 1].decode("utf-8"))
Flag: Alpaca{p4dd1n6_mu57_u53_r4nd0m_v41u3s}
解法2
未知のflagを整数$x$として考えてみましょう。
すると、$m$は
m(x) = \text{Prefix} ・256^{L} + x
と表せます。ただし、$L \leq 40$です。
これを用いて、多項式
f(x) = m(x)^3 - m_3 \equiv 0 \pmod n
を立て、Coppersmithで解けばフラグが出てきます。
今回cusoを使いました。
from sympy import gcd, gcdex, invert
from Crypto.Util.number import long_to_bytes, bytes_to_long
from sage.all import *
import cuso
n = 91102717210596990388603678426683097953697889897819753293818443119019220403217013812232251320814152567699322671590559119510246139859891156830672838769529887961956970370968572962306584295059185945752892892100975462391203805852473243296747559459800718013816237662990504689724747628304890125129146326331097856907
c1 = 84316690833236468829386139306045298111202426584048821548102362931269993141514516100633466389955824290011995159677864206138653174440904170622039293036862729884826231898868928186453091113165643576890891297150845933751243965934735328928976655465009980896153972226679588496970771925581698573227941539852081781874
c2 = 74682069306151159606579889187354529286195652598555930926994495384029865435810129236911316774977007932641783161876484392995815937986886903514990618178943843429073696833993271982336114314882872652681858748846455760309012235191324385691614015531641062111894149446939460102878320469041435024347449132388644171970
len_flag = 40
def modexp(base, exp, mod):
if exp >= 0:
return pow(base, exp, mod)
base_inv = invert(base, mod)
return pow(base_inv, -exp, mod)
e1 = 65517
e2 = 65577
g = gcd(e1, e2)
assert g == 3
e1p = e1 // g
e2p = e2 // g
u, v, g2 = gcdex(e1p, e2p)
u = int(u)
v = int(v)
assert g2 == 1
m1 = modexp(c1, u, n)
m2 = modexp(c2, v, n)
m3 = (m1 * m2) % n
n_bytes = long_to_bytes(n)
prefix_bytes = n_bytes[:-len_flag]
Prefix = bytes_to_long(prefix_bytes)
x = var('x')
m_poly = Prefix * (1 << (8 * len_flag)) + x
f = m_poly**3 - int(m3)
relations = [f]
bounds = {
x: (0, 1 << (8 * len_flag))
}
roots = cuso.find_small_roots(
relations,
bounds,
modulus=n,
)
assert len(roots) > 0
root = roots[0]
x_val = int(root[x])
print(f"flag:{long_to_bytes(x_val).decode()} ")