チョコラスクのブログ

野生のコラッタです。

IERAE CTF 2024 writeup

IERAE CTF 2024にチームBunkyoWesternsで参加しました。

結果は224チーム中1位でした!

cryptoジャンルの問題は全完することができ、うち2問でfirst bloodを取れました!

[crypto] Weak PRNG

次のようなPythonスクリプトが与えられます。

#!/usr/bin/env python

from os import getenv
import random
import secrets

FLAG = getenv("FLAG", "TEST{TEST_FLAG}")


def main():
    # Python uses the Mersenne Twister (MT19937) as the core generator.
    # Setup Random Number Generator
    rng = random.Random()
    rng.seed(secrets.randbits(32))

    secret = rng.getrandbits(32)

    print("Welcome!")
    print("Recover the initial output and input them to get the flag.")

    while True:
        print("--------------------")
        print("Menu")
        print("1. Get next 16 random data")
        print("2. Submit your answer")
        print("3. Quit")
        print("Enter your choice (1-3)")
        choice = input("> ").strip()

        if choice == "1":
            print("Here are your random data:")
            for _ in range(16):
                print(rng.getrandbits(32))
        elif choice == "2":
            print("Enter the secret decimal number")
            try:
                num = int(input("> ").strip())

                if num == secret:
                    print("Correct! Here is your flag:")
                    print(FLAG)
                else:
                    print("Incorrect number. Bye!")
                break
            except (ValueError, EOFError):
                print("Invalid input. Exiting.")
                break
        elif choice == "3":
            print("Bye!")
            break
        else:
            print("Invalid choice. Please enter 1, 2 or 3.")
            continue


if __name__ == "__main__":
    main()

乱数予測の問題です。最初に32bitの乱数 secret が生成されます。その後、何度でも乱数の出力を得ることができ、最初に生成された乱数 secret を当てることができればフラグが得られます。乱数生成にはPythonのrandomモジュールが用いられています。

Pythonのrandomモジュールはメルセンヌツイスタとよばれる擬似乱数生成アルゴリズムを用いています。メルセンヌツイスタの内部状態は32bit整数624個からなり、32bit整数624個の乱数出力が得られれば内部状態を復元可能です。内部状態の更新の巻き戻しも可能なので、初期の内部状態を復元でき、最初の乱数が求まります。

この記事に載っている実装がほぼそのまま使えました。

import random

# https://inaz2.hatenablog.com/entry/2016/03/07/194147
def untemper(x):
    x = unBitshiftRightXor(x, 18)
    x = unBitshiftLeftXor(x, 15, 0xefc60000)
    x = unBitshiftLeftXor(x, 7, 0x9d2c5680)
    x = unBitshiftRightXor(x, 11)
    return x

def unBitshiftRightXor(x, shift):
    i = 1
    y = x
    while i * shift < 32:
        z = y >> shift
        y = x ^ z
        i += 1
    return y

def unBitshiftLeftXor(x, shift, mask):
    i = 1
    y = x
    while i * shift < 32:
        z = y << shift
        y = x ^ (z & mask)
        i += 1
    return y

def get_prev_state(state):
    for i in range(623, -1, -1):
        result = 0
        tmp = state[i]
        tmp ^= state[(i + 397) % 624]
        if ((tmp & 0x80000000) == 0x80000000):
            tmp ^= 0x9908b0df
        result = (tmp << 1) & 0x80000000
        tmp = state[(i - 1 + 624) % 624]
        tmp ^= state[(i + 396) % 624]
        if ((tmp & 0x80000000) == 0x80000000):
            tmp ^= 0x9908b0df
            result |= 1
        result |= (tmp << 1) & 0x7fffffff
        state[i] = result
    return state

from ptrlib import *

sock = Socket("nc xxx.xxx.xxx.xxx xxxxx")

outputs = []
for _ in range(39):
    sock.sendlineafter("> ", "1")
    sock.recvuntil("Here are your random data:\n")
    for i in range(16):
        outputs.append(int(sock.recvline()))

mt_state = [untemper(x) for x in outputs]
prev_mt_state = get_prev_state(mt_state)
random.setstate((3, tuple(prev_mt_state + [0]), None))

secret = [random.getrandbits(32) for _ in range(624)][-1]
sock.sendlineafter("> ", "2")
sock.sendline(str(secret))
sock.interactive()

[crypto] splitting

次のようなSageMathのスクリプトが与えられます。

#!/usr/bin/env sage

from Crypto.Util.number import *
from os import getenv

FLAG = getenv("FLAG", "TEST{TEST_FLAG}").encode()
f = bytes_to_long(FLAG)

p = random_prime(2^128)
Fp = GF(p)
a, b = Fp.random_element(), Fp.random_element()
E = EllipticCurve(Fp, [a, b])

print(a)
print(b)
print(p)

gens = list(E.gens())
if len(gens) < 2:
    gens.append(ZZ(Fp.random_element()) * E.gens()[0])

res = []
while f > 0:
    r = Fp.random_element()
    res.append(ZZ(r) * gens[f & 1])
    f >>= 1

for R in res:
    print(R.xy())

128bitの素数 p および \mathbb{F}_{p} 上の楕円曲線がランダムに生成され、E(\mathbb{F}_{p}) の生成元 G_{0}, G_{1} を取っています。ただし、E(\mathbb{F}_{p}) が1つの元のみで生成される場合、G_{1} はランダムな点が選ばれます。その後、フラグの各ビット b について、ランダムな値 r が選ばれ、点 R=rG_{b} の座標が得られます。

フラグが固定なので、サーバーに何度も接続して情報を得ることが重要そうです。色々な解法がありそうですが、自分は E(\mathbb{F}_{p}) が1つの元のみで生成される場合に G_{1} はランダムな点であることに着目しました。

E の位数 n が偶数で、E(\mathbb{F}_{p}) が1つの元のみで生成される場合だけを考えます (そうでない場合は無視します)。このとき、G_{0} は位数 n ですが、G_{1} はランダムなので確率1/2で (n/2)G_{1}=O を満たします。よって、R(n/2)R=O を満たす確率は、b=0 のときは r が偶数である確率なので1/2、b=1 のときは r が偶数または  (n/2)G_{1}=O が成り立つ確率なので3/4となります。

よって、何度もサーバーに接続して R(n/2)R=O を満たす回数を数え、回数が接続回数の半分より十分多ければ対応するビットは1とわかります。

from ptrlib import *
from Crypto.Util.number import *

L = 567
cnt = [0]*L

t = 0
while t < 150:
    sock = Socket("nc xxx.xxx.xxx.xxx xxxxx")
    a = int(sock.recvline())
    b = int(sock.recvline())
    p = int(sock.recvline())

    Fp = GF(p)
    E = EllipticCurve(Fp, [a, b])
    n = E.order()
    gens = list(E.gens())
    if n%2!=0 or len(gens)>1:
        sock.close()
        continue
    
    t += 1
    for i in range(L):
        P = eval(sock.recvline())
        P = E(P)
        Q = (n//2)*P
        if Q==E(0):
            cnt[i] += 1

f = ""    
for v in cnt:
    if v > 90:
        f += "1"
    else:
        f += "0"
f = f[::-1]
print(long_to_bytes(int(f,2)))

[crypto] cluster

次のようなPythonスクリプトとその実行結果が与えられます。

from secret import p, q, r, flag
from Crypto.Util.number import isPrime, bytes_to_long

N = p * q * r

assert isPrime(p) and p.bit_length() < 1024
assert isPrime(q) and q.bit_length() < 1024
assert isPrime(r) and r.bit_length() < 1024
assert p ** 2 + q ** 2 + r ** 2 + (p * q + q * r + r * p) == 6 * N



m = bytes_to_long(flag)
e = 65537
c = pow(m, e, N)

print(f'{c = }')

3つの素数 p, q, r を使ったRSA暗号でフラグが暗号化されています。p^{2}+q^{2}+r^{2}+pq+qr+rp=6pqr が成り立つことがわかっています。N=pqr の値は与えられません。

N が与えられないので、p^{2}+q^{2}+r^{2}+pq+qr+rp=6pqr の解 (p, q, r) を列挙する必要がありそうです。東大入試の過去問x^{2}+y^{2}+z^{2}=xyz の整数解が無限個あることを示す問題があったことを思い出しました。この問題は、(x, y, z) が解なら (y, z, yz-x) も解であることを使って解けます。同様の方法を試してみることにします。

(p, q, r) (素数でなくてもよい) が p^{2}+q^{2}+r^{2}+pq+qr+rp=6pqr の解のとき、x2次方程式 x^{2}-(6qr-q-r)x+q^{2}+qr+r^{2}=0x=p を解に持ちます。この方程式のもう一つの解を p' とすると、解と係数の関係より p+p'=6qr-q-r です。よって、(6qr-q-r-p, q, r) も解であることがわかります。同様に (p, 6pr-p-r-q, r) も解であることがわかります。(p, q, r)=(1, 1, 1) から始めてこの方法で解を無数に構築することができます。

東大の過去問の方程式について調べると、この方程式はマルコフ方程式とよばれていることがわかります。「markov equation integer solution」などのキーワードで調べると、この論文が見つかります。この論文によると、p^{2}+q^{2}+r^{2}+pq+qr+rp=6pqr の解は上記の方法で生成されるもので全てであることが示せるようです。

1024bit以下の解を全て列挙してみると解は35000個程度しかないことがわかるので、これらのうち素数であるもの (3個しかない) について復号を試せばよいです。

from Crypto.Util.number import *

c = 803065078252547393812982498895211019353977926969143481455672761264443519482121067346644328911375984166893647468186232810673857290127114177258405196432172412966170401425497369188710097376895361641046391686887615687734454887428130745946475159776034046370464137762008371294039825175819408224450178007611894599399705434991448459196552982074660952318580952594830076838718297573226980847848142642550316589863549823042663312178673956251841439218528410295177672591802052069297783
e = 65537

st = set()
def dfs(p,q,r):
    if (p,q,r) in st:
        return
    if r.bit_length()>1024:
        return
    v = p ** 2 + q ** 2 + r ** 2 + (p * q + q * r + r * p)
    n = p*q*r
    assert v==6*n
    st.add((p,q,r))
    p1 = (q**2+r**2+q*r)//p
    q1 = (p**2+r**2+p*r)//q
    dfs(*sorted([p1,q,r]))
    dfs(*sorted([p,q1,r]))

dfs(1,1,1)

for p,q,r in st:
    n = p*q*r
    if not(isPrime(p) and isPrime(q) and isPrime(r)):
        continue
    print(p,q,r)
    d = pow(e,-1,(p-1)*(q-1)*(r-1))
    m = pow(c,d,n)
    flag = long_to_bytes(m)
    print(flag)
    # if b'IERAE' in flag:
    #     print(flag)

この方法で全ての解が列挙できることを確認する前に実装はしていましたがフラグが出ず、他の解があるのかと思い調べたところ、これで全ての解が出てくることを示す論文が見つかったので、実装ミスかなと思いながら試しに復号結果が IERAE を含むとき出力するのではなく全ての復号結果を出力するようにしてみたところ、CTF{xxx} というフラグが出てきて、カス!となりました。作問者の方はフラグフォーマットには注意をお願いします...

[crypto] Heady Heights

次のようなSageMathのスクリプトとその実行結果が与えられます。

from sage.all import *
import flag

BITS = 88
K = 8


def random(lower_bound=0, upper_bound=2 ^ BITS, bits=None):
    return ZZ.random_element(lower_bound, upper_bound)


def random_bits(bits):
    return random(2 ^ (bits - 1), 2 ^ bits)


p = next_prime(random_bits(BITS))
m = ZZ(flag.FLAG.encode().hex(), 16)

a, b, E = None, None, None
P = None
Q = None
R = None

while True:
    a = random_bits(BITS)
    b = random_bits(BITS)
    E = EllipticCurve(Zmod(p ^ K), [a, b])
    try:
        P = E.lift_x(1337)
        break
    except:
        continue

while True:
    secret_key = random(upper_bound=p ^ (K - 1))
    x0 = (secret_key * m) % (p ^ K)
    try:
        R = E.lift_x(x0)
        break
    except:
        continue

Q = secret_key * P


def xy(P):
    t = P.xy()
    return ZZ(t[0]), ZZ(t[1])


x1, y1 = xy(P)
x2, y2 = xy(Q)
x3, y3 = xy(R)

print((x1, x2, x3))
print((y1, y2, y3))

\mathbb{Z}/p^{8}\mathbb{Z} 上の楕円曲線 E の点 P(x_{1},y_{1}), Q(x_{2},y_{2}), R(x_{3},y_{3}) が与えられます。ここで、Px 座標は1337です。Q は、0以上 p^{7} 未満のランダムな整数 secret_key (s とする) を用いて Q=sP と計算されています。R は、s とフラグの積\bmod{p^{8}}x 座標とする点です。p の値や E の方程式 y^{2}=x^{3}+ax+b は与えられませんが、p, a, b は88bitであることがわかっています。

\mathbb{Z}/p^{n}\mathbb{Z} (n\geq 2) 上の楕円曲線におけるDLPについてはzer0pts CTF 2021 pure divisionで出題されており、p 進数体上にliftすることで解けることが知られています。よって、p, a, b の値がわかれば、この方法で secret_key が求まるので、Rx 座標からフラグが求まります。

p の値を求めるところが問題です。c_{i}=y_{i}^{2}-x_{i}^{3} とおきます。y_{i}^{2}=x_{i}^{3}+ax_{i}+b\bmod{p^{8}} (i=1, 2, 3) から b を消去すると、a=(c_{i}-c_{j})/(x_{i}-x_{j})\bmod{p^{8}} が得られます。よって、n=(x_{1}-x_{2})(c_{1}-c_{3})-(x_{1}-x_{3})(c_{1}-c_{2})p^{8} で割り切れます。

n は2800bit程度あるので素因数分解することは難しそうです*1。そこで、a, bx_{1} の値が小さいことに着目します。y_{1}^{2}=x_{1}^{3}+ax_{1}+b\bmod{p^{8}} において、x_{1}=1337 で、a, b は88bitであることから、x_{1}^{3}+ax_{1}+b は100bit程度の値になります。よって、x多項式 x-y_{1}^{2}n の約数 p^{8} をmodとして小さい根 x_{1}^{3}+ax_{1}+b を持つので、Coppersmith法でその根が求まります。SageMathのsmall_roots関数を用いる際のパラメータは、p^{8} のbit数が n のbit数の1/4程度なので \beta=1/4 とします。根は n^{\beta^{2}} より十分小さいので、\epsilon を適当に設定することで根が求まります。

x_{1}^{3}+ax_{1}+b が求まれば、p^{8}=\mathrm{gcd}(x_{1}^{3}+ax_{1}+b-y_{1}^{2}, n) より p が求まります。

from gmpy2 import iroot
from Crypto.Util.number import *

x1,x2,x3 = (1337, 108758038897050520831860923441402897201224898270547825657705075428051130846061735614252293345445641285591980004736447964462956581141116321772403519125859758137648644808920743070411296325521866392898376475395494, 5438451076181919949694350690364579526012926958491719881284366792649670689294870931317007945903275017524668258922051576064401873439529896167369498669912618211164397682696947429627504905294350782410183543966679528)
y1,y2,y3 = (2356240417146305163212384832005924367753484871437731042165238964932920608988096746757585282365391701455222258919772283748442969489163122612874542328479985011793178437324509351503404273134948028573603448460822465, 5224211491008373131406603536527981755345757742567201307027247664784412223361972085071271594280642689356776497337283996518196426296230388008390390705691353643411319840725993589925599219787596133403802269715179842, 1255469150673352477643406441586559401886808227235272570913194477760462899397412967437903450228715079681927518702031385236882455686813595191144244687009073603134094899106009798791920033413388436982273752206346286)

c1 = y1**2-x1**3
c2 = y2**2-x2**3
c3 = y3**2-x3**3

n = abs((x1-x2)*(c1-c3)-(x1-x3)*(c1-c2))
for i in range(2,100000):
    while n%i==0:
        n//=i

R.<x> = PolynomialRing(Zmod(n))
f = x-y1^2

beta = 0.25
eps = beta^2/4
res = f.small_roots(beta=beta, epsilon=eps)
p8 = GCD(int(res[0]-y1^2), n)
p = int(iroot(int(p8), 8)[0])
print(p)

K = 8
a = (c1-c2)*pow(int(x1-x2),int(-1),int(p**K))%(p**K)
b = (c1-a*x1)%(p**K)
print(a)
print(b)

Fp = GF(p)
Qp = pAdicField(p, K)
E = EllipticCurve(Qp, [a, b])
N = EllipticCurve(GF(p), [a, b]).order()

S = E(x1,y1)
T = E(x2,y2)
NS = N * S
a = Fp(-NS[0] / (p * NS[1]))

n = 0
l = 1
Sp = S
Tp = T
ds = []
while Tp != 0:
    NTp = N*Tp
    w = -NTp[0] / NTp[1]
    b = w / p^l
    d = Fp(Integer(b)/a)
    ds.append(Integer(d))
    Tp = Tp - Integer(d)*Sp
    Sp = p*Sp
    n += 1
    l += 1
    if n > K:
        break

solve = 0
for i in range(len(ds)):
    solve += ds[i] * p^i

flag = pow(int(solve),int(-1),int(p^K))*x3%(p^(K))
print(long_to_bytes(int(flag)))

非想定だったようですが、欲しい素数を約数に持つような n を (RSAの公開鍵などで与えられているわけではなく) 作ったうえでCoppersmith法を適用するという流れが面白かったです。

[crypto] Free Your Mind

次のようなSageMathのスクリプトが与えられます。

#!/usr/bin/env sage

from sage.all import *
import flag
import sys
import os

sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', buffering=1)
sys.stderr = os.fdopen(sys.stderr.fileno(), 'w', buffering=1)
sys.stdin = os.fdopen(sys.stdin.fileno(), 'r', buffering=1)

# Initial configuration
UCF = UniversalCyclotomicField()
zeta = UCF.zeta(11).to_cyclotomic_field()

eta = zeta^2 + zeta^-2

# NF is a finite Galois extension of the rational field, hence NF is a simple extension of the rational field
NF.<omega> = NumberField(eta.minpoly())
deg = NF.degree()

CRITERIA = 2^16
BITS = 2048

def main():
    print("Welcome to *Number Theoretic* Encryption Oracle!")

    print("Enter your request key:")
    coeffs = []
    # you can enter any element of NF because omega's coefficient is controllable :)
    # ...but smaller value is prohibited to prevent cheating!
    for _ in range(deg):
        x = input().replace(" ", "")
        if "/" in x: 
            p, q = x.split("/")
            p = int(p)
            q = int(q)
            g = gcd(p, q)
            p /= g
            q /= g
            if abs(p) < CRITERIA or abs(q) < CRITERIA:
                print("You can't cheat!")
                return
            coeffs.append(QQ(p) / q)
        else:
            x = int(x)
            if abs(x) < CRITERIA:
                print("You can't cheat!")
                return
            coeffs.append(x)

    print("We got your request key")
    # your element is here
    alpha = sum(x * omega^i for i, x in enumerate(coeffs))

    # we use the norm of `alpha` for RSA's public key
    # class group order is not used here, because it is too complicated to compute :(
    e = alpha.norm()
    set_random_seed(int(str(alpha).encode().hex(), 16) * current_randstate().ZZ_seed()) # Add to the entropy pool
    p = random_prime(2^BITS)
    q = random_prime(2^BITS)

    N = p * q
    d = Mod(e, (p - 1) * (q - 1))^-1
    # print(f"{e = }") <- you don't need it because you can compute it, isn't it?
    print(f"{N = }")
    msg = ZZ(flag.FLAG.encode().hex(), 16)
    assert msg < N
    c = ZZ(Mod(msg, N)^e)
    assert Mod(c, N)^d == msg
    print(f"Enc(flag) = {c}") # Can you get the FLAG?

if __name__ == "__main__":
    try:
        main()
    except:
        print("Are you kidding me?")

\zeta を1の原始11乗根として、\omega=\zeta^{2}+\zeta^{-2} とし、F=\mathbb{Q}(\omega) という代数体を考えています。F\mathbb{Q} の5次拡大です。F の元 \alpha を入力すると、\alpha のノルム N_{F}(\alpha) を公開鍵 e の値とするRSA暗号でフラグを暗号化した結果が得られます。ただし、\alpha=\sum_{i=0}^{4}\alpha_{i}\omega^{i} (\alpha_{i}\in \mathbb{Q}) と表したときの \alpha_{i} の値はある程度大きい必要があります。

まずノルムの定義を思い出しておきます。代数体 F=\mathbb{Q}(\omega) に対し、\omega\mathbb{Q} 上の最小多項式 (\omega を根に持つ最小次数の有理数係数多項式) を f(x) とし、f(x) の根を \omega_{0}, \ldots , \omega_{n-1} とします。共役写像 \sigma_{k} を、\alpha=\sum_{i=0}^{n-1}\alpha_{i}\omega^{i} (\alpha_{i}\in \mathbb{Q}) に対し \sigma_{k}(\alpha)=\sum_{i=0}^{n-1}\alpha_{i}\omega_{k}^{i} で定めます。このとき、\alpha のノルムは N_{F}(\alpha)=\prod_{k=0}^{n-1}\sigma_{k}(\alpha) で定義されます。

例えば、K=\mathbb{Q}(\sqrt{2}) とすると、共役写像は恒等写像\sigma(x+y\sqrt{2})=x-y\sqrt{2} の2つなので、N_{K}(x+y\sqrt{2})=(x+y\sqrt{2})(x-y\sqrt{2})=x^{2}-2y^{2} です。

ノルムは乗法的です。つまり、\alpha, \beta\in F に対し、N_{F}(\alpha\beta)=N_{F}(\alpha)N_{F}(\beta) が成り立ちます。

F の整数環を O_{F} とします。結論から言うと、\alphaO_{F} の単数とすれば、e=N_{F}(\alpha)=\pm 1 となり、フラグがそのまま得られます。

O_{F} の単数とは、\alpha\beta=1 を満たす \beta\in O_{F} が存在するような O_{F} の元 \alpha のことをいいます。例えば、有理整数環 \mathbb{Z} の単数は \pm 1 のみですが、\mathbb{Q}(\sqrt{2}) の整数環 \mathbb{Z}[\sqrt{2}] の単数は無数にあります。実際、(1+\sqrt{2})(-1+\sqrt{2})=1 なので 1+\sqrt{2}\mathbb{Z}[\sqrt{2}] の単数であり、\pm(1+\sqrt{2})^{n} (n=0,\pm 1, \pm 2, \ldots) は全て \mathbb{Z}[\sqrt{2}] の単数です。

\alphaO_{F} の単数のとき、ノルムの乗法性より N_{F}(\alpha)\mathbb{Z} の単数なので、N_{F}(\alpha)=\pm 1 です。

この問題の F においても、O_{F} の単数は無数に存在します。実際、SageMathのドキュメントを見ながら、O_{F} の単数のなす群 (単数群) を計算してみると次のようになります。

sage: UCF = UniversalCyclotomicField() 
....: zeta = UCF.zeta(11).to_cyclotomic_field() 
....:  
....: eta = zeta^2 + zeta^-2                                                     
sage: NF.<omega> = NumberField(eta.minpoly())                                    
sage: UK = UnitGroup(NF)                                                         
sage: UK                                                                         
Unit group with structure C2 x Z x Z x Z x Z of Number Field in omega with defining polynomial x^5 + x^4 - 4*x^3 - 3*x^2 + 3*x + 1

単数群は4つの位数無限の元から生成されるようです。O_{F} の単数群の生成元を O_{F} の基本単数といいます。基本単数は次のようにしてSageMathで計算できます。

sage: UK.fundamental_units()                                                     
[omega^4 - 4*omega^2 + 2,
 -omega^3 + 3*omega,
 -omega^3 + 3*omega - 1,
 omega^2 - 2]

これを何乗かすれば、\omega の式で表したときの係数が大きい単数が得られます。

sage: u = UK.fundamental_units()[0]                                              
sage: alpha = u^50                                                                   
sage: alpha                                                                          
-35550146891575*omega^4 + 11066005530460*omega^3 + 127746807020466*omega^2 - 60895014313320*omega - 26994779733015
sage: alpha.norm()                                                                   
1

\alpha としてこの値をサーバーに入力してみるとフラグが得られました。

first bloodでした。大学のとき代数的整数論を勉強していたのですぐ解けました。

[rev, crypto] Fortress

x86-64 ELFのバイナリおよび、それをサーバーとして実行するためのDockerfileが与えられます。

実行してみると、任意の平文に対する暗号文を取得する機能、フラグの暗号文を取得する機能があることがわかります。

$ ./fortress 
1. Encrypt
2. Get Encrypted Flag
3. Exit
> 1
Enter plaintext (Base64): aaaa 
Encrypted (Base64): xB4kEGgYs6/VBWXH4SMdy58n0Sk2EGd020hoEaHJlu4=
1. Encrypt
2. Get Encrypted Flag
3. Exit
> 2
Encrypted Flag: +vcCpLPOaYDVYOQpAsvGpv1qy/hH4EpEmGiMuBHYi54=
1. Encrypt
2. Get Encrypted Flag
3. Exit
> 3

また、Dockerfileに長さ96の16進文字列をファイル key.txt に書き込む処理があることから、鍵は48bytesで固定のようです。

revパートはほぼチームメンバーのkanonさんにやっていただきましたが、簡単に説明します。はじめに key.txt から鍵を読み込み、128bytesの内部状態を初期化しています。その後、まずフラグを flag.txt から読み込んで暗号化しています。メニューの2番を選択すると、このときの暗号文が出力されます (メニューの2番を選択するたびに都度暗号化しているわけではない)。

入力された平文は、base64デコードされた後、32bytesごとに暗号化され、base64エンコードされて出力されます。 暗号化の処理はkanonさんがPythonで書き起こしてくれたコードをもとに説明します。

内部状態は8bytesのブロック16個 (key_0, key_1, ..., key_15) に分かれています。まず、内部状態の一部をAESENC命令で処理したものを平文にXORし、暗号文を作っています。AESENC命令はAESの1ラウンド分の処理 (SubBytes, ShiftRows, MixColumns, AddRoundKey) を行う命令らしいです。ちなみに、AESENCをエミュレートする関数はptrlibに intel_aesenc として実装されています (ptrlibすごい)。

aes_return_1 = aes_1round(key_2, key_3, key_10, key_11)
aes_return_2 = aes_1round(bxor(key_8, key_0), bxor(key_9, key_1), key_4, key_5)

v33 = bxor(flag_copy[0], aes_return_1[:8])
v34 = bxor(flag_copy[1], aes_return_1[8:])
v35 = bxor(flag_copy[2], aes_return_2[:8])
v36 = bxor(flag_copy[3], aes_return_2[8:])
ct.extend([v33, v34, v35, v36])

その後、内部状態の更新が行われます。内部状態の更新処理もAESENCとXORの組み合わせですが、key_0, key_1, key_8, key_9 の部分には暗号文がXORされているのがポイントです。

aes_ret1 = aes_1round(key_0, key_1, key_14, key_15)
aes_ret2 = aes_1round(key_4, key_5, key_2, key_3)

key_0_tmp = bxor(key_14, v33)
key_1_tmp = bxor(key_15, v34)
key_4 = bxor(key_12, key_2)
key_5 = bxor(key_13, key_3)

aes_ret3 = aes_1round(key_8, key_9, key_6, key_7)
aes_ret4 = aes_1round(key_10, key_11, key_8, key_9)
key_8 = bxor(key_6, v35)
key_9 = bxor(key_7, v36)
key_14 = bxor(key_12, key_0)
key_15 = bxor(key_13, key_1)

key_1 = key_1_tmp
key_0 = key_0_tmp

key_2, key_3 = aes_ret1[:8], aes_ret1[8:]
key_6, key_7 = aes_ret2[:8], aes_ret2[8:]
key_10, key_11 = aes_ret3[:8], aes_ret3[8:]
key_12, key_13 = aes_ret4[:8], aes_ret4[8:]

cryptoパートに取り掛かります。まず、暗号化を1回行った後の内部状態が完全にわかっているとして、暗号化前の内部状態を復元できるかを考えます。これが出来ないと解ける気がしませんが、正攻法で出来るようにも見えなかったので*2、z3を試してみたところ無事復元できました。実装においてはsboxが非線形な点が困りますが、josephさんのRTACTF 2023 1R-AESのwriteupを参考にしました。このAESの実装も使っています。

from ptrlib import *
from z3 import *
import base64

from aes import s_box, shift_rows, add_round_key, inv_shift_rows, inv_sub_bytes

def bxor(a, b):
    return [x^y for x,y in zip(a,b)]

def bytes2matrix(text):
    return [list(text[i:i+4]) for i in range(0, len(text), 4)]

def matrix2bytes(matrix):
    return sum(matrix, [])

def xtime(a):
    return If((a & 0x80)==0, (a << 1), (((a << 1) ^ 0x1B) & 0xFF))

def mix_single_column(a):
    t = a[0] ^ a[1] ^ a[2] ^ a[3]
    u = a[0]
    a[0] ^= t ^ xtime(a[0] ^ a[1])
    a[1] ^= t ^ xtime(a[1] ^ a[2])
    a[2] ^= t ^ xtime(a[2] ^ a[3])
    a[3] ^= t ^ xtime(a[3] ^ u)

def mix_columns(s):
    for i in range(4):
        mix_single_column(s[i])

def sub_bytes(s):
    for i in range(4):
        for j in range(4):
            s[i][j] = z3_SBOX(s[i][j])

def aes_1round(a, b, c, d):
    block = bytes2matrix(a+b)
    sub_bytes(block)
    shift_rows(block)
    mix_columns(block)
    add_round_key(block, bytes2matrix(c+d))
    return matrix2bytes(block)

solver = Solver()
z3_SBOX = Function('z3_SBOX', BitVecSort(8), BitVecSort(8))

for i in range(len(s_box)):
    solver.add(z3_SBOX(i) == s_box[i])

state = [[BitVec(f'k_{i}_{j}', 8) for j in range(8)] for i in range(16)]

state_next = [0]*16

v0 = aes_1round(state[2], state[3], state[10], state[11])
v1 = aes_1round(bxor(state[0],state[8]), bxor(state[1],state[9]), state[4], state[5])

w0 = aes_1round(state[0],state[1],state[14],state[15])
w1 = aes_1round(state[4],state[5],state[2],state[3])
w2 = aes_1round(state[8],state[9],state[6],state[7])
w3 = aes_1round(state[10],state[11],state[8],state[9])

sock = Socket("nc xxx.xxx.xxx.xxx xxxxx")
sock.sendline("2")
res = sock.recvlineafter("Encrypted Flag: ").decode()
ciphertext = list(base64.b64decode(res))
ciphertext = [ciphertext[i:i+8] for i in range(0,32,8)]

state_next[0] = bxor(state[14], ciphertext[0])
state_next[1] = bxor(state[15], ciphertext[1])
state_next[2] = w0[:8]
state_next[3] = w0[8:]
state_next[4] = bxor(state[2], state[12])
state_next[5] = bxor(state[3], state[13])
state_next[6] = w1[:8]
state_next[7] = w1[8:]
state_next[8] = bxor(state[6], ciphertext[2])
state_next[9] = bxor(state[7], ciphertext[3])
state_next[10] = w2[:8]
state_next[11] = w2[8:]
state_next[12] = w3[:8]
state_next[13] = w3[8:]
state_next[14] = bxor(state[0], state[12])
state_next[15] = bxor(state[1], state[13])

state_next_res = [[161, 242, 146, 82, 34, 29, 135, 155], [86, 168, 207, 23, 4, 170, 222, 96], [126, 145, 68, 64, 52, 200, 213, 235], [234, 93, 97, 36, 250, 144, 51, 232], [13, 39, 147, 1, 150, 101, 75, 183], [243, 192, 168, 249, 13, 24, 166, 135], [19, 250, 249, 216, 201, 38, 182, 14], [60, 154, 216, 182, 41, 6, 175, 73], [9, 113, 215, 109, 250, 115, 202, 242], [243, 57, 226, 102, 174, 112, 56, 211], [104, 122, 157, 81, 82, 71, 247, 211], [236, 176, 62, 108, 81, 85, 151, 183], [56, 30, 250, 116, 206, 75, 170, 38], [172, 241, 237, 125, 84, 40, 16, 16], [30, 225, 84, 144, 194, 14, 172, 12], [174, 137, 125, 157, 37, 255, 41, 39]]

for i in range(16):
    for j in range(8):
        solver.add(state_next_res[i][j] == state_next[i][j])

print('solving...')
print(solver.check())
m = solver.model()
print([[m[k].as_long() for k in state[i]] for i in range(16)])

次に、入力する平文が内部状態の更新に関わることに着目し、入力する平文を1ビット反転させることを考えます。このとき暗号文も1ビット反転するので、次の内部状態における key_0, key_1, key_8, key_9 のいずれか1ビットが反転します。すると、次の暗号文にXORされる aes_1round(bxor(key_8, key_0), bxor(key_9, key_1), key_4, key_5) において、AESENCのstateの入力が1ビット反転します。

AESENCのstateの入力が1ビット反転すると、SubBytesによりそのビットが属す1バイトが変化し、MixColumnsによりその変化が4バイトに拡散されます。この4バイトの差分は、ビット反転した1バイトの値のみに依存するので、1バイトずつ全探索することでstateの入力を求めることができます。

これにより bxor(key_8, key_0)bxor(key_9, key_1) が求まります。さらに、AESENCのround keyは最後のAddRoundKeyでXORされるだけなので、key_4, key_5 も求まります。

以下のスクリプトは、固定の平文を6回連続で暗号化したときの、各暗号化直後の bxor(key_8, key_0), bxor(key_9, key_1), key_4, key_5 を求めるスクリプトです。

from ptrlib import *
import os
import base64

def bxor(a, b):
    return bytes(x^y for x,y in zip(a,b))

def aes_1round(a, b, c, d):
    return intel_aesenc(a + b, c + d)

# plaintext = os.urandom(32)
# plaintext = list(plaintext)
plaintext = [75, 135, 135, 189, 81, 46, 25, 4, 134, 146, 62, 168, 44, 52, 206, 99, 135, 9, 175, 103, 246, 253, 53, 101, 222, 218, 20, 244, 168, 37, 156, 133]

f = open("result.txt","w")

N = 6
for t in range(N):
    msg = base64.b64encode(bytes(plaintext))
    sock = Socket("nc xxx.xxx.xxx.xxx xxxxx")
    for _ in range(t):
        sock.sendline("1")
        sock.sendline(msg)
        sock.recvlineafter("Encrypted (Base64): ").decode()
    sock.sendline("1")
    sock.sendline(msg)
    base64.b64decode(sock.recvlineafter("Encrypted (Base64): ").decode())
    sock.sendline("1")
    sock.sendline(msg)
    res0 = base64.b64decode(sock.recvlineafter("Encrypted (Base64): ").decode())
    sock.close()

    state01 = []
    for k in range(16):
        diffs1 = [None]*256
        diffs2 = [None]*256
        a = os.urandom(16)
        c = os.urandom(16)
        for i in range(256):
            a0 = a[:k]+bytes([i])+a[k+1:]
            a1 = a[:k]+bytes([i^1])+a[k+1:]
            a2 = a[:k]+bytes([i^2])+a[k+1:]
            e = intel_aesenc(a0, c)
            e1 = intel_aesenc(a1, c)
            e2 = intel_aesenc(a2, c)
            diffs1[i] = bxor(e, e1)
            diffs2[i] = bxor(e, e2)

        plaintext1 = list(plaintext)
        plaintext1[k] ^= 1
        msg1 = base64.b64encode(bytes(plaintext1))
        plaintext2 = list(plaintext)
        plaintext2[k] ^= 2
        msg2 = base64.b64encode(bytes(plaintext2))

        sock = Socket("nc xxx.xxx.xxx.xxx xxxxx")
        for _ in range(t):
            sock.sendline("1")
            sock.sendline(msg)
            sock.recvlineafter("Encrypted (Base64): ").decode()
        sock.sendline("1")
        sock.sendline(msg1)
        base64.b64decode(sock.recvlineafter("Encrypted (Base64): ").decode())
        sock.sendline("1")
        sock.sendline(msg)
        res1 = base64.b64decode(sock.recvlineafter("Encrypted (Base64): ").decode())
        sock.close()

        sock = Socket("nc xxx.xxx.xxx.xxx xxxxx")
        for _ in range(t):
            sock.sendline("1")
            sock.sendline(msg)
            sock.recvlineafter("Encrypted (Base64): ").decode()
        sock.sendline("1")
        sock.sendline(msg2)
        base64.b64decode(sock.recvlineafter("Encrypted (Base64): ").decode())
        sock.sendline("1")
        sock.sendline(msg)
        res2 = base64.b64decode(sock.recvlineafter("Encrypted (Base64): ").decode())
        sock.close()

        res01 = bxor(res0, res1)
        res02 = bxor(res0, res2)

        for i in range(256):
            if diffs1[i] == res01[16:] and diffs2[i] == res02[16:]:
                state01.append(i)
                break

    state45 = list(bxor(intel_aesenc(state01, b"\x00"*16), bxor(res0[16:], plaintext[16:])))

    print(state01, file=f)
    print(state45, file=f)

以上で内部状態128bytesのうち32bytes分が求まりましたが、ここで考察に行き詰まってしまったので、ここまでで得られた結果をz3の制約に追加することで解けないかを試してみることにしました。固定の平文を6回連続で暗号化したときの各暗号文の結果と、各暗号化直後の内部状態32bytes分を求めた結果をz3の制約に追加したところ、なんと10分程度*3で初期の内部状態を復元することができました!

from ptrlib import *
from z3 import *
import base64

from aes import s_box, shift_rows, add_round_key, inv_shift_rows, inv_sub_bytes

def bxor(a, b):
    return [x^y for x,y in zip(a,b)]

def bytes2matrix(text):
    return [list(text[i:i+4]) for i in range(0, len(text), 4)]

def matrix2bytes(matrix):
    return sum(matrix, [])

def xtime(a):
    return If((a & 0x80)==0, (a << 1), (((a << 1) ^ 0x1B) & 0xFF))

def mix_single_column(a):
    t = a[0] ^ a[1] ^ a[2] ^ a[3]
    u = a[0]
    a[0] ^= t ^ xtime(a[0] ^ a[1])
    a[1] ^= t ^ xtime(a[1] ^ a[2])
    a[2] ^= t ^ xtime(a[2] ^ a[3])
    a[3] ^= t ^ xtime(a[3] ^ u)

def mix_columns(s):
    for i in range(4):
        mix_single_column(s[i])

def sub_bytes(s):
    for i in range(4):
        for j in range(4):
            s[i][j] = z3_SBOX(s[i][j])

def aes_1round(a, b, c, d):
    block = bytes2matrix(a+b)
    sub_bytes(block)
    shift_rows(block)
    mix_columns(block)
    add_round_key(block, bytes2matrix(c+d))
    return matrix2bytes(block)

solver = Solver()
z3_SBOX = Function('z3_SBOX', BitVecSort(8), BitVecSort(8))
for i in range(len(s_box)):
    solver.add(z3_SBOX(i) == s_box[i])

state = [[BitVec(f'k_{i}_{j}', 8) for j in range(8)] for i in range(16)]
state0 = state[:]
plaintext = [75, 135, 135, 189, 81, 46, 25, 4, 134, 146, 62, 168, 44, 52, 206, 99, 135, 9, 175, 103, 246, 253, 53, 101, 222, 218, 20, 244, 168, 37, 156, 133]

sock = Socket("nc xxx.xxx.xxx.xxx xxxxx")

msg = base64.b64encode(bytes(plaintext))
sock.sendline("1")
sock.sendline(msg)
res = sock.recvlineafter("Encrypted (Base64): ").decode()

N = 6
ciphertexts = []
for _ in range(N):
    sock.sendline("1")
    sock.sendline(msg)
    res = sock.recvlineafter("Encrypted (Base64): ").decode()
    ciphertexts.append(res)
ciphertexts = [list(base64.b64decode(s)) for s in ciphertexts]
sock.close()

state01 = []
state45 = []
with open("result.txt","r") as f:
    for _ in range(N):
        state01.append(eval(f.readline()))
        state45.append(eval(f.readline()))

for k in range(N):
    for i in range(8):
        solver.add(state[4][i]==state45[k][i])
        solver.add(state[5][i]==state45[k][i+8])
        solver.add(state[0][i]^state[8][i] == state01[k][i])
        solver.add(state[1][i]^state[9][i] == state01[k][i+8])

    state_next = [0]*16

    v0 = aes_1round(state[2], state[3], state[10], state[11])
    v1 = aes_1round(bxor(state[0],state[8]), bxor(state[1],state[9]), state[4], state[5])

    w0 = aes_1round(state[0],state[1],state[14],state[15])
    w1 = aes_1round(state[4],state[5],state[2],state[3])
    w2 = aes_1round(state[8],state[9],state[6],state[7])
    w3 = aes_1round(state[10],state[11],state[8],state[9])

    t0 = bxor(v0[:8], plaintext[:8])
    t1 = bxor(v0[8:], plaintext[8:16])
    t2 = bxor(v1[:8], plaintext[16:24])
    t3 = bxor(v1[8:], plaintext[24:])
    t = [t0,t1,t2,t3]
    
    for i in range(4):
        for j in range(8):
            solver.add(ciphertexts[k][i*8+j]==t[i][j])

    state_next[0] = bxor(state[14], t0)
    state_next[1] = bxor(state[15], t1)
    state_next[2] = w0[:8]
    state_next[3] = w0[8:]
    state_next[4] = bxor(state[2], state[12])
    state_next[5] = bxor(state[3], state[13])
    state_next[6] = w1[:8]
    state_next[7] = w1[8:]
    state_next[8] = bxor(state[6], t2)
    state_next[9] = bxor(state[7], t3)
    state_next[10] = w2[:8]
    state_next[11] = w2[8:]
    state_next[12] = w3[:8]
    state_next[13] = w3[8:]
    state_next[14] = bxor(state[0], state[12])
    state_next[15] = bxor(state[1], state[13])

    state = state_next

print('solving...')
print(solver.check())
m = solver.model()
print([[m[k].as_long() for k in state0[i]] for i in range(16)])

IERAE NIGHTで作問者の方に想定解を教えてもらったところ、base64の処理にバグがあり、それを使って差分攻撃ができるとのことでした (あとで復習します)。base64の部分はろくに読まずに数個の入力を試しただけで普通のbase64っぽいですねーみたいなことを言ってしまった気がするので、z3が効かなかったら戦犯になるところでした...

*1:実際はyafuという素因数分解ツールで素因数分解ができたようで、これが想定解だったらしいです。

*2:とCTF中は思っていましたが、落ち着いて考えると普通に逆算できますね...。

*3:入力する平文によっては30分経っても解が見つからないケースもありましたが、入力する平文を変えることでうまくいきました。