0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

AlpacaHack Round5: XorshiftStream Upsolve

Posted at

XorshiftStream

authored by keymoon

What? An XSS challenge in Crypto?

import os
import secrets
from Crypto.Util.strxor import strxor

class XorshiftStream:
    def __init__(self, key: int):
        self.state = key % 2**64

    def _next(self):
        self.state = (self.state ^ (self.state << 13)) % 2**64
        self.state = (self.state ^ (self.state >> 7)) % 2**64
        self.state = (self.state ^ (self.state << 17)) % 2**64
        return self.state

    def encrypt(self, data: bytes):
        ct = b""
        for i in range(0, len(data), 8):
            pt_block = data[i : i + 8]
            ct += (int.from_bytes(pt_block, "little") ^ self._next()).to_bytes(
                8, "little"
            )[: len(pt_block)]
        return ct

FLAG = os.environ.get("FLAG", "fakeflag").encode()

xss = XorshiftStream(secrets.randbelow(2**64))
key = secrets.token_bytes(len(FLAG))

c = xss.encrypt(key.hex().encode() + strxor(key, FLAG))
print(c.hex())

状態を$s \in ${$0, 1, \dots, 2^{64}-1$}として、各ステップにおいて次の写像$F$により更新しています。

\begin{align}
s \leftarrow s \oplus (s \ll 13) \pmod {2^{64}}\\
s \leftarrow s \oplus (s \gg 7) \pmod {2^{64}}\\
s \leftarrow s \oplus (s \ll 17) \pmod {2^{64}}
\end{align}

そして、暗号文$c$はc = xss.encrypt(key.hex().encode() + strxor(key, FLAG))で生成されていますが、最初のkeyはsecrets.token_bytes(len(FLAG)なので、仮にlen(flag)が$n$バイトだとすると、key.hex().encode() は$2n$バイトとなります。そして、strxor(key, FLAG)FLAGとxorしているため、$n$バイトになり、合計$3n$バイトであることがわかります。

解法

今回の嬉しいところはkeyがhexエンコードされているところです。
つまり、どのような鍵だったとしても0~9, a~fのいずれかの値になることが確定します。

以上の考察から、encを64ビットブロックに直して未知の初期状態s0を変数にしてXorShiftを回し、復号結果がhexの範囲になるような制約をつけることで初期状態を復元することができ、これはz3で実装できます。それ以外にも解法はあるとは思うのですが、私がz3に慣れていないので、使わせてください。

まず、準備をしていきましょう。
pt_block = data[i : i + 8]であることから、$8\times2=16$文字ずつ区切る必要があります。
つまり、

from z3 import *
import struct 
from Crypto.Util.strxor import strxor

def z3_next(state):
    state = (state ^ (state << 13))
    state = (state ^ LShR(state,7))
    state = (state ^ (state << 17))
    return state

enc = "142d35c86db4e4bb82ca5965ca1d6bd55c0ffeb35c8a5825f00819821cd775c4c091391f5eb5671b251f5722f1b47e539122f7e5eadc00eee8a6a631928a0c14c57c7e05b6575067c336090f85618c8e181eeddbb3c6e177ad0f9b16d23c777b313e62b877148f06014e8bf3bc156bf88eedd123ba513dfd6fcb32446e41a5b719412939f5b98ffd54c2b5e44f4f7a927ecaff337cddf19fa4e38cbe01162a1b54bb43b0678adf2801d893655a74c656779f9a807c3125b5a30f4800a8"
key_range = 2*len(enc)//3
enc_key = enc[:key_range]
enc_flag = enc[key_range:] 


seed_state = BitVec('seed_state',64)
ct_u64blocks = []
for i in range(len(enc_key)//16):
        block_hex = enc_key[i*(16):(i+1)*16]
        block_bytes = bytes.fromhex(block_hex)
        block_u64 = struct.unpack("<Q", block_bytes)[0]
        ct_u64blocks.append(BitVecVal(block_u64, 64))

s = Solver()
for r, ct_block in enumerate(ct_u64blocks):
        state = seed_state
        for _ in range(r+1):
            state = z3_next(state)
        pt_block = state ^ ct_block
        for byte_index in range(8):
            pt_byte = Extract(8 * (byte_index+1) - 1, 8 * byte_index, pt_block)
            is_digit = And(pt_byte >= 0x30, pt_byte <= 0x39)
            is_lower = And(pt_byte >= 0x61, pt_byte <= 0x66)
            s.add(Or(is_digit, is_lower))

init_state = 0
if s.check() == sat:
    init_state = s.model()[seed_state].as_long()
    print("[+] init_state: "+hex(init_state))
else:
    print("no")
    raise SystemExit(1)
[+] init_state: 0xb57e0c257802dad1

ということで、初期状態が求まりました。後は、復元していきましょう。平文$P$は
$$
P = \text{key.hex().encode()} || \text{key} \oplus FLAG)
$$
となっているため、最初にkey.hex().encode()からkeyを求め、その後、keyencの最後$n$バイトでXORを取ればフラグが得られるでしょう。
ただし、hexエンコードされているため、元々の文字列の長さは$\text{len}(enc)/2$となる点と、unpackする際に必ず8バイト必要なため、足りない場合はパディングする必要があります。

state = init_state
len_plain = len(enc)//2
plain = b""
for i in range(0, len(enc), 16):
    state = next(state)
    plain_bytes = bytes.fromhex(enc[i:i+16])
    if len(plain_bytes) < 8:
         plain_bytes += b'\x00'*(8-len(plain_bytes))
    key_block = struct.unpack("<Q", plain_bytes)[0]
    plain += struct.pack("<Q", key_block^state)

plain = plain[:len_plain]
key = bytes.fromhex(plain[:key_range//2].decode())
xor_enc = plain[key_range//2:]

以上の操作をまとめたコードは以下になります。

from z3 import *
import struct 
from Crypto.Util.strxor import strxor

def next(state):
    state = (state ^ (state << 13)) % 2**64
    state = (state ^ (state >> 7)) % 2**64
    state = (state ^ (state << 17)) % 2**64
    return state


def z3_next(state):
    state = (state ^ (state << 13))
    state = (state ^ LShR(state,7))
    state = (state ^ (state << 17))
    return state

enc = "142d35c86db4e4bb82ca5965ca1d6bd55c0ffeb35c8a5825f00819821cd775c4c091391f5eb5671b251f5722f1b47e539122f7e5eadc00eee8a6a631928a0c14c57c7e05b6575067c336090f85618c8e181eeddbb3c6e177ad0f9b16d23c777b313e62b877148f06014e8bf3bc156bf88eedd123ba513dfd6fcb32446e41a5b719412939f5b98ffd54c2b5e44f4f7a927ecaff337cddf19fa4e38cbe01162a1b54bb43b0678adf2801d893655a74c656779f9a807c3125b5a30f4800a8"
key_range = 2*len(enc)//3
enc_key = enc[:key_range]
enc_flag = enc[key_range:] 


seed_state = BitVec('seed_state',64)
ct_u64blocks = []
for i in range(len(enc_key)//16):
        block_hex = enc_key[i*(16):(i+1)*16]
        block_bytes = bytes.fromhex(block_hex)
        block_u64 = struct.unpack("<Q", block_bytes)[0]
        ct_u64blocks.append(BitVecVal(block_u64, 64))

s = Solver()
for r, ct_block in enumerate(ct_u64blocks):
        state = seed_state
        for _ in range(r+1):
            state = z3_next(state)
        pt_block = state ^ ct_block
        for byte_index in range(8):
            pt_byte = Extract(8 * (byte_index+1) - 1, 8 * byte_index, pt_block)
            is_digit = And(pt_byte >= 0x30, pt_byte <= 0x39)
            is_lower = And(pt_byte >= 0x61, pt_byte <= 0x66)
            s.add(Or(is_digit, is_lower))

init_state = 0
if s.check() == sat:
    init_state = s.model()[seed_state].as_long()
    print("[+] init_state: "+hex(init_state))
else:
    print("no")
    raise SystemExit(1)

state = init_state
len_plain = len(enc)//2
plain = b""
for i in range(0, len(enc), 16):
    state = next(state)
    plain_bytes = bytes.fromhex(enc[i:i+16])
    if len(plain_bytes) < 8:
         plain_bytes += b'\x00'*(8-len(plain_bytes))
    key_block = struct.unpack("<Q", plain_bytes)[0]
    plain += struct.pack("<Q", key_block^state)

plain = plain[:len_plain]
key = bytes.fromhex(plain[:key_range//2].decode())
xor_enc = plain[key_range//2:]

print(strxor(key, xor_enc).decode())

Flag:Alpaca{I'v3_n3v3r_seen_4_c1ient_51d3_CryptoWeb_ch4ll3ng3_0nce!}

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?