XorshiftStream
authored by keymoon
What? An XSS challenge in Crypto?
import os
import secrets
from Crypto.Util.strxor import strxor
class XorshiftStream:
def __init__(self, key: int):
self.state = key % 2**64
def _next(self):
self.state = (self.state ^ (self.state << 13)) % 2**64
self.state = (self.state ^ (self.state >> 7)) % 2**64
self.state = (self.state ^ (self.state << 17)) % 2**64
return self.state
def encrypt(self, data: bytes):
ct = b""
for i in range(0, len(data), 8):
pt_block = data[i : i + 8]
ct += (int.from_bytes(pt_block, "little") ^ self._next()).to_bytes(
8, "little"
)[: len(pt_block)]
return ct
FLAG = os.environ.get("FLAG", "fakeflag").encode()
xss = XorshiftStream(secrets.randbelow(2**64))
key = secrets.token_bytes(len(FLAG))
c = xss.encrypt(key.hex().encode() + strxor(key, FLAG))
print(c.hex())
状態を$s \in ${$0, 1, \dots, 2^{64}-1$}として、各ステップにおいて次の写像$F$により更新しています。
\begin{align}
s \leftarrow s \oplus (s \ll 13) \pmod {2^{64}}\\
s \leftarrow s \oplus (s \gg 7) \pmod {2^{64}}\\
s \leftarrow s \oplus (s \ll 17) \pmod {2^{64}}
\end{align}
そして、暗号文$c$はc = xss.encrypt(key.hex().encode() + strxor(key, FLAG))で生成されていますが、最初のkeyはsecrets.token_bytes(len(FLAG)なので、仮にlen(flag)が$n$バイトだとすると、key.hex().encode() は$2n$バイトとなります。そして、strxor(key, FLAG)はFLAGとxorしているため、$n$バイトになり、合計$3n$バイトであることがわかります。
解法
今回の嬉しいところはkeyがhexエンコードされているところです。
つまり、どのような鍵だったとしても0~9, a~fのいずれかの値になることが確定します。
以上の考察から、encを64ビットブロックに直して未知の初期状態s0を変数にしてXorShiftを回し、復号結果がhexの範囲になるような制約をつけることで初期状態を復元することができ、これはz3で実装できます。それ以外にも解法はあるとは思うのですが、私がz3に慣れていないので、使わせてください。
まず、準備をしていきましょう。
pt_block = data[i : i + 8]であることから、$8\times2=16$文字ずつ区切る必要があります。
つまり、
from z3 import *
import struct
from Crypto.Util.strxor import strxor
def z3_next(state):
state = (state ^ (state << 13))
state = (state ^ LShR(state,7))
state = (state ^ (state << 17))
return state
enc = "142d35c86db4e4bb82ca5965ca1d6bd55c0ffeb35c8a5825f00819821cd775c4c091391f5eb5671b251f5722f1b47e539122f7e5eadc00eee8a6a631928a0c14c57c7e05b6575067c336090f85618c8e181eeddbb3c6e177ad0f9b16d23c777b313e62b877148f06014e8bf3bc156bf88eedd123ba513dfd6fcb32446e41a5b719412939f5b98ffd54c2b5e44f4f7a927ecaff337cddf19fa4e38cbe01162a1b54bb43b0678adf2801d893655a74c656779f9a807c3125b5a30f4800a8"
key_range = 2*len(enc)//3
enc_key = enc[:key_range]
enc_flag = enc[key_range:]
seed_state = BitVec('seed_state',64)
ct_u64blocks = []
for i in range(len(enc_key)//16):
block_hex = enc_key[i*(16):(i+1)*16]
block_bytes = bytes.fromhex(block_hex)
block_u64 = struct.unpack("<Q", block_bytes)[0]
ct_u64blocks.append(BitVecVal(block_u64, 64))
s = Solver()
for r, ct_block in enumerate(ct_u64blocks):
state = seed_state
for _ in range(r+1):
state = z3_next(state)
pt_block = state ^ ct_block
for byte_index in range(8):
pt_byte = Extract(8 * (byte_index+1) - 1, 8 * byte_index, pt_block)
is_digit = And(pt_byte >= 0x30, pt_byte <= 0x39)
is_lower = And(pt_byte >= 0x61, pt_byte <= 0x66)
s.add(Or(is_digit, is_lower))
init_state = 0
if s.check() == sat:
init_state = s.model()[seed_state].as_long()
print("[+] init_state: "+hex(init_state))
else:
print("no")
raise SystemExit(1)
[+] init_state: 0xb57e0c257802dad1
ということで、初期状態が求まりました。後は、復元していきましょう。平文$P$は
$$
P = \text{key.hex().encode()} || \text{key} \oplus FLAG)
$$
となっているため、最初にkey.hex().encode()からkeyを求め、その後、keyとencの最後$n$バイトでXORを取ればフラグが得られるでしょう。
ただし、hexエンコードされているため、元々の文字列の長さは$\text{len}(enc)/2$となる点と、unpackする際に必ず8バイト必要なため、足りない場合はパディングする必要があります。
state = init_state
len_plain = len(enc)//2
plain = b""
for i in range(0, len(enc), 16):
state = next(state)
plain_bytes = bytes.fromhex(enc[i:i+16])
if len(plain_bytes) < 8:
plain_bytes += b'\x00'*(8-len(plain_bytes))
key_block = struct.unpack("<Q", plain_bytes)[0]
plain += struct.pack("<Q", key_block^state)
plain = plain[:len_plain]
key = bytes.fromhex(plain[:key_range//2].decode())
xor_enc = plain[key_range//2:]
以上の操作をまとめたコードは以下になります。
from z3 import *
import struct
from Crypto.Util.strxor import strxor
def next(state):
state = (state ^ (state << 13)) % 2**64
state = (state ^ (state >> 7)) % 2**64
state = (state ^ (state << 17)) % 2**64
return state
def z3_next(state):
state = (state ^ (state << 13))
state = (state ^ LShR(state,7))
state = (state ^ (state << 17))
return state
enc = "142d35c86db4e4bb82ca5965ca1d6bd55c0ffeb35c8a5825f00819821cd775c4c091391f5eb5671b251f5722f1b47e539122f7e5eadc00eee8a6a631928a0c14c57c7e05b6575067c336090f85618c8e181eeddbb3c6e177ad0f9b16d23c777b313e62b877148f06014e8bf3bc156bf88eedd123ba513dfd6fcb32446e41a5b719412939f5b98ffd54c2b5e44f4f7a927ecaff337cddf19fa4e38cbe01162a1b54bb43b0678adf2801d893655a74c656779f9a807c3125b5a30f4800a8"
key_range = 2*len(enc)//3
enc_key = enc[:key_range]
enc_flag = enc[key_range:]
seed_state = BitVec('seed_state',64)
ct_u64blocks = []
for i in range(len(enc_key)//16):
block_hex = enc_key[i*(16):(i+1)*16]
block_bytes = bytes.fromhex(block_hex)
block_u64 = struct.unpack("<Q", block_bytes)[0]
ct_u64blocks.append(BitVecVal(block_u64, 64))
s = Solver()
for r, ct_block in enumerate(ct_u64blocks):
state = seed_state
for _ in range(r+1):
state = z3_next(state)
pt_block = state ^ ct_block
for byte_index in range(8):
pt_byte = Extract(8 * (byte_index+1) - 1, 8 * byte_index, pt_block)
is_digit = And(pt_byte >= 0x30, pt_byte <= 0x39)
is_lower = And(pt_byte >= 0x61, pt_byte <= 0x66)
s.add(Or(is_digit, is_lower))
init_state = 0
if s.check() == sat:
init_state = s.model()[seed_state].as_long()
print("[+] init_state: "+hex(init_state))
else:
print("no")
raise SystemExit(1)
state = init_state
len_plain = len(enc)//2
plain = b""
for i in range(0, len(enc), 16):
state = next(state)
plain_bytes = bytes.fromhex(enc[i:i+16])
if len(plain_bytes) < 8:
plain_bytes += b'\x00'*(8-len(plain_bytes))
key_block = struct.unpack("<Q", plain_bytes)[0]
plain += struct.pack("<Q", key_block^state)
plain = plain[:len_plain]
key = bytes.fromhex(plain[:key_range//2].decode())
xor_enc = plain[key_range//2:]
print(strxor(key, xor_enc).decode())
Flag:Alpaca{I'v3_n3v3r_seen_4_c1ient_51d3_CryptoWeb_ch4ll3ng3_0nce!}