チョコラスクのブログ

野生のコラッタです。

CakeCTF 2023 writeup

CakeCTF 2023 にチーム BunkyoWesterns で参加し、3位を取ることができました。

私はAVTOKYOとそのafter partyに参加したあと、レッドヌオー*1の状態で残っていたcrypto2問を解きました。

[crypto] Iron Door

次のソースコードが与えられます。

from Crypto.Util.number import getPrime, isPrime, getRandomRange, inverse, long_to_bytes
from hashlib import sha256
import os
import secrets
import signal


def h1(s: bytes) -> int:
    return int(sha256(s).hexdigest()[:40], 16)

def h2(s: bytes) -> int:
    return int(sha256(s).hexdigest()[:50], 16)

# curl https://2ton.com.au/getprimes/random/2048
q = 10855513673631576111128223823852736449477157478532599346149798456480046295301804051241065889011325365880913306008412551904076052471122611452376081547036735239632288113679547636623259366213606049138707852292092112050063109859313962494299170083993779092369158943914238361319111011578572373690710592496259566364509116075924022901254475268634373605622622819175188430725220937505841972729299849489897919215186283271358563435213401606699495614442883722640159518278016175412036195925520819094170566201390405214956943009778470165405468498916560026056350145271115393499136665394120928021623404456783443510225848755994295718931
p = 2*q + 1

assert isPrime(p)
assert isPrime(q)

g = 3
flag = os.getenv("FLAG", "neko{nanmo_omoi_tsukanai_owari}")
x = getRandomRange(0, q)
y = pow(g, x, p)
salt = secrets.token_bytes(16)


def sign(m: bytes):
    z = h1(m)
    k = inverse(h2(long_to_bytes(x + z)), q)
    r = h2(long_to_bytes(pow(g, k, p)))
    s = (z + x*r) * inverse(k, q) % q
    return r, s


def verify(m: bytes, r: int, s: int):
    z = h1(m)
    sinv = inverse(s, q)
    gk = pow(g, sinv*z, p) * pow(y, sinv*r, p) % p
    r2 = h2(long_to_bytes(gk))
    return r == r2

# integrity check
r, s = sign(salt)
assert verify(salt, r, s)

signal.alarm(1000)


print("salt =", salt.hex())
print("p =", p)
print("g =", g)
print("y =", y)

while True:
    choice = input("[s]ign or [v]erify:").strip()
    if choice == "s":
        print("=== sign ===")
        m = input("m = ").strip().encode()
        if b"goma" in m:
            exit()

        r, s = sign(m + salt)
        # print("r =", r) #  do you really need?
        print("s =", s)

    elif choice == "v":
        print("=== verify ===")
        m = input("m = ").strip().encode()
        r = int(input("r = "))
        s = int(input("s = "))
        assert 0 < r < q
        assert 0 < s < q

        ok = verify(m + salt, r, s)
        if ok and m == b"hirake goma":
            print(flag)
        elif ok:
            print("OK")
            exit()
        else:
            print("NG")
            exit()

    else:
        exit()

DSAが実装されており、goma を含まない好きな文字列の署名を何度でも取得できます。ただし、署名は r, s のうち s の方しか得られません。hirake goma の署名を入力できればフラグが得られます。

q は2047bitですが、k^{-1}, r は200bit、z は160bitと小さいです。これを利用し、LLLアルゴリズムを適用したいです。

以下、k^{-1}k と置きなおすことにします。k は200bitです。署名は次の関係式を満たします。

\displaystyle{
s=kz+krx \pmod{q}
}

x は小さくないので消去することを考えます。異なる2つの文字列の署名 s_{1}, s_{2} を取得し、この関係式から x を消去すると次のようになります。

\displaystyle{
k_{2}r_{2}s_{1}-k_{1}r_{1}s_{2}=k_{2}r_{2}k_{1}z_{1}-k_{1}r_{1}k_{2}z_{2} \pmod{q}
}

k_{1}r_{1}, k_{2}r_{2} は400bit、k_{2}r_{2}k_{1}z_{1}-k_{1}r_{1}k_{2}z_{2} は760bit程度です。400bitの k_{1}r_{1}, k_{2}r_{2} の組は 2^{800} 通りありますが、この中からランダムに選ばれるとすると、 k_{2}r_{2}s_{1}-k_{1}r_{1}s_{2} が760bit以下になる確率は 1/2^{2048-760}=1/2^{1288} 程度になりそうです。よって、この関係式から係数は一意に定まりそうです。

次のような行列に対してLLLアルゴリズムを適用します (結果のベクトルの成分が同じくらいの大きさになるよう、2^{360} 倍の重みをつけています)。

\displaystyle{
\begin{pmatrix}
s_{1} & 2^{360} & 0 \\ 
s_{2} & 0 & 2^{360} \\ 
q & 0 & 0
\end{pmatrix}
}

すると、1行目に次のようなベクトル

\displaystyle{
(k_{2}r_{2}k_{1}z_{1}-k_{1}r_{1}k_{2}z_{2}, 2^{360}k_{2}r_{2}, -2^{360}k_{1}r_{1})
}

が現れることが期待できます。実際は、

\displaystyle{
g=\gcd(k_{2}r_{2}k_{1}z_{1}-k_{1}r_{1}k_{2}z_{2}, k_{2}r_{2}, -k_{1}r_{1})
}

で割ったベクトルが出てくるので、得られたベクトルを何倍かしてみる必要があります (これに気づかずハマりました)。

さらに、k_{1} を何倍かしたものが \gcd(k_{1}r_{1}, k_{2}r_{2}k_{1}z_{1}-k_{1}r_{1}k_{2}z_{2}) に一致することから、k_{1} が求まります。よって、秘密鍵 x が求まり、署名を計算できます。

from Crypto.Util.number import *
from hashlib import sha256
import socket

def recvuntil(client, delim=b'\n'):
    buf = b''
    while delim not in buf:
        buf += client.recv(1)
    return buf

host = 'crypto.2023.cakectf.com'
port = 10321

client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client.connect((host, port))

def h1(s: bytes) -> int:
    return int(sha256(s).hexdigest()[:40], 16) # 160bit

def h2(s: bytes) -> int:
    return int(sha256(s).hexdigest()[:50], 16) # 200bit

q = 10855513673631576111128223823852736449477157478532599346149798456480046295301804051241065889011325365880913306008412551904076052471122611452376081547036735239632288113679547636623259366213606049138707852292092112050063109859313962494299170083993779092369158943914238361319111011578572373690710592496259566364509116075924022901254475268634373605622622819175188430725220937505841972729299849489897919215186283271358563435213401606699495614442883722640159518278016175412036195925520819094170566201390405214956943009778470165405468498916560026056350145271115393499136665394120928021623404456783443510225848755994295718931
# 2047bit
p = 2*q + 1
g = 3

s = []

recvuntil(client, b'salt = ')
salt = bytes.fromhex(recvuntil(client).strip().decode())

recvuntil(client, b'y = ')
y = int(recvuntil(client).strip().decode())

recvuntil(client, b'[s]ign or [v]erify:')
client.send(b's\n')
recvuntil(client, b'm = ')
client.send(b'a\n')
recvuntil(client, b's = ')
s.append(int(recvuntil(client).strip().decode()))

recvuntil(client, b'[s]ign or [v]erify:')
client.send(b's\n')
recvuntil(client, b'm = ')
client.send(b'b\n')
recvuntil(client, b's = ')
s.append(int(recvuntil(client).strip().decode()))

z = [h1(b'a'+salt), h1(b'b'+salt)]

mat = [[0]*3 for _ in range(3)]
w1 = 2**360

mat[0][0] = s[0]
mat[1][0] = s[1]
mat[2][0] = q
for i in range(2):
    mat[i][i+1] = w1

res = Matrix(mat).LLL()
res = res[0]

k2r2 = res[1]//w1
k1r1 = -res[2]//w1
if k1r1<0:
    k1r1 = -k1r1
    k2r2 = -k2r2

x = -1
k1 = GCD(int(res[0]), int(k1r1))
for i in range(1, 256):
    if x!=-1:
        break
    k1r1i = k1r1*i
    if int(k1r1i).bit_length() > 400:
        break
    for j in range(1, 256):
        if k1*i%j!=0:
            continue
        k1i = k1*i//j
        if int(k1i).bit_length() > 200:
            continue
        if int(k1r1i)%int(k1i) != 0:
            continue
        r1 = k1r1i//k1i
        x1 = (s[0]-k1i*z[0]) * pow(int(k1r1i),int(-1),int(q)) % q
        if pow(int(g), int(x1), int(p)) == y:
            x = x1
            break

assert pow(int(g), int(x), int(p)) == y

def sign(m: bytes):
    z = h1(m)
    k = inverse(h2(long_to_bytes(int(x + z))), int(q))
    r = h2(long_to_bytes(pow(int(g), int(k), int(p))))
    s = (z + x*r) * inverse(int(k), int(q)) % q
    return r, s

r, s = sign(b"hirake goma"+salt)

recvuntil(client, b'[s]ign or [v]erify:')
client.send(b'v\n')
recvuntil(client, b'm = ')
client.send(b'hirake goma\n')
recvuntil(client, b'r = ')
client.send(str(int(r)).encode()+b'\n')
recvuntil(client, b's = ')
client.send(str(int(s)).encode()+b'\n')
print(recvuntil(client))
# CakeCTF{im_r3a11y_afraid_0f_truncating_hash_dig3st_13ading_unint3nd3d}

ちなみに、AVTOKYOの会場でもLLLの張り紙がありました。

[crypto, pwn] decryptyou

x64のELFバイナリおよびそのソースコードが与えられます。

#include <fcntl.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <gmp.h>
#include <unistd.h>

char flag[100];

typedef struct {
  struct {
    mpz_t u; mpz_t p; mpz_t q; mpz_t dp; mpz_t dq;
  } priv;
  struct {
    mpz_t n; mpz_t e;
  } pub;
} rsacrt_t;

/** @fn random_prime
 *  @brief Generate a random prime number.
 *  @param p: `mpz_t` variable to store the prime.
 *  @param nbits: Bit length of the prime.
 */
void random_prime(mpz_t p, int nbits) {
  gmp_randstate_t state;
  gmp_randinit_default(state);
  gmp_randseed_ui(state, rand());
  mpz_urandomb(p, state, nbits);
  mpz_nextprime(p, p);
  gmp_randclear(state);
}

/** @fn rsa_keygen
 *  @brief Generate a key pair for RSA-CRT.
 *  @param rsa: `rsacrt_t` structure to store the key.
 */
void rsa_keygen(rsacrt_t *rsa) {
  mpz_t p1, q1;
  mpz_inits(rsa->pub.n, rsa->pub.e,
            rsa->priv.u, rsa->priv.p, rsa->priv.q, rsa->priv.dp, rsa->priv.dq,
            p1, q1, NULL);

  /* Generate RSA parameters */
  mpz_set_ui(rsa->pub.e, 65537);
  random_prime(rsa->priv.p, 512);
  random_prime(rsa->priv.q, 512);
  mpz_sub_ui(p1, rsa->priv.p, 1);
  mpz_sub_ui(q1, rsa->priv.q, 1);
  mpz_mul(rsa->pub.n, rsa->priv.p, rsa->priv.q);     // n = p * q
  mpz_invert(rsa->priv.dp, rsa->pub.e, p1);          // dp = e^-1 mod p-1
  mpz_invert(rsa->priv.dq, rsa->pub.e, q1);          // dq = e^-1 mod q-1
  mpz_invert(rsa->priv.u, rsa->priv.q, rsa->priv.p); // u = q^-1 mod p
}

/** @fn challenge
 *  @brief Can you solve this?
 *  @param rsa: `rsacrt_t` structure containing a key pair.
 */
void challenge(rsacrt_t *rsa) {
  char buf[0x200];
  gmp_randstate_t state;
  mpz_t x, m, c, cp, cq, mp, mq;
  mpz_inits(x, m, c, cp, cq, mp, mq, NULL);

  /* Generate a random number and encrypt it */
  gmp_randinit_default(state);
  gmp_randseed_ui(state, rand());
  mpz_urandomb(x, state, 512);
  mpz_powm_ui(c, x, 1333, rsa->pub.n); // c = x^1333 mod n

  gmp_printf("n = %Zd\n", rsa->pub.n);
  gmp_printf("c = %Zd\n", c);

  for (;;) {
    /* Input ciphertext */
    printf("c = ");
    if (scanf("%s", buf) != 1
        || mpz_set_str(c, buf, 10) != 0) {
      fputs("Invalid input", stderr);
      exit(0);
    }

    /* Calculate plaintext */
    mpz_mod(cp, c, rsa->priv.p);
    mpz_mod(cq, c, rsa->priv.q);
    mpz_powm(mp, cp, rsa->priv.dp, rsa->priv.p);
    mpz_powm(mq, cq, rsa->priv.dq, rsa->priv.q);
    // m = (((mp - mq) * u mod p) * q + mq) mod n
    mpz_set(m, mp);
    mpz_sub(m, m, mq);
    mpz_mul(m, m, rsa->priv.u);
    mpz_mod(m, m, rsa->priv.p);
    mpz_mul(m, m, rsa->priv.q);
    mpz_add(m, m, mq);
    mpz_mod(m, m, rsa->pub.n);
    gmp_printf("m = %Zd\n", m);

    /* Check plaintext */
    if (mpz_cmp(m, x) == 0) {
      printf("Congratulations!\n"
             "Here is the flag: %s\n", flag);
      break;
    }
  }
}

/**
 * Entry point
 */
int main() {
  rsacrt_t rsa;
  rsa_keygen(&rsa);
  challenge(&rsa);
  return 0;
}

__attribute__((constructor))
void setup(void) {
  int seed;
  int fd;
  setvbuf(stdin, NULL, _IONBF, 0);
  setvbuf(stdout, NULL, _IONBF, 0);
  setvbuf(stderr, NULL, _IONBF, 0);

  // Get random seed
  if ((fd = open("/dev/urandom", O_RDONLY)) == -1 ||
      read(fd, &seed, sizeof(seed)) != sizeof(seed)) {
    perror("setup failed");
    exit(1);
  }
  close(fd);
  srand(seed);

  // Read flag
  if ((fd = open("/flag.txt", O_RDONLY)) == -1 ||
      read(fd, flag, sizeof(flag)) <= 0) {
    perror("flag not found");
    exit(1);
  }
  close(fd);
}

RSA-CRT (Xornetさんの記事) が実装されており、好きな暗号文の復号が何度でもできます。復号結果が秘密の平文 x に一致すればフラグが得られます。秘密の平文は、異なる e で暗号化されて与えられます。

Xornetさんの記事にも書かれているように、RSA-CRTに対してはfault attackとよばれる攻撃が知られています。また、ソースコードを見ると、buf に何文字でも入力できて、バッファオーバーフローを起こせることがわかります。そこで、このバッファオーバーフローを利用して他の変数を書き換え、復号結果をバグらせることを考えます。

まず試しに600文字入力してみると、次のようにエラーで終了します。

$ ./chall
n = 163912677711753578402018546176455166536369875636338193731886158802473769013912049898122897926394608600927104053030671754067514800594514581474793929508470115692922105904584925001255043938954803175655673822917028179462162860487277678838112468685591852540103929171670267328627307915111147885925404253306808639467
c = 155877206961297913842830834243296718889261927396211762783856436181379059789452860217137398965239284018340690345395308951268718814696496977891214876021281588402561336181375737671009919552744668200793593310735852050446035635182453541062847860321390334387118632127846297570721880821977571616202437792052368270066
c = 111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111111
GNU MP: Cannot allocate memory (size=6602459592)
zsh: IOT instruction  ./chall

ここで、変数は mpz_t という型ですが、これは次のように定義されています(https://github.com/alisw/GMP/blob/master/gmp-h.in#L150)。

typedef struct
{
  int _mp_alloc;       /* Number of *limbs* allocated and pointed
                  to by the _mp_d field.  */
  int _mp_size;            /* abs(_mp_size) is the number of limbs the
                  last field points to.  If _mp_size is
                  negative this is a negative number.  */
  mp_limb_t *_mp_d;     /* Pointer to the limbs.  */
} __mpz_struct;

確保している領域のサイズ、値のサイズ、値へのポインタを持っているようです。

これを踏まえて、gdbで暗号文入力直後のスタックの状態を確認すると、buf+592 の位置から u, p, q, dp, dq, n, e が順に並んでいるのが見れます*2

よって、u を0に書き換えることができます (サイズは適当に、ポインタは適当に参照できそうなアドレスにすればよいです)。

平文の\bmod{p},\bmod{q} からCRTで平文を復元する部分は次のように行われています。

\displaystyle{
m=((m_{p}-m_{q})u \bmod{p})q+m_{q}
}

これを見ると、u の値が変わっても、復号結果の\bmod{q} は変わらないことがわかります。よって、バグった復号結果と正しい平文の差を求め、n との \gcd を取ることで q が求まります。

これで、秘密の平文 x を復号することができます。あとは、u=0 になっている状態で、復号すると x に一致するような暗号文を入力する必要があります。u=0 のときは m=m_{q}、つまり正しい平文の\bmod{q} が返りますが、x は512bit (p, q も512bit) なので、1/2の確率で x=(x\bmod{q}) になります。よって、普通に x の暗号文を入力すればよいです。

from Crypto.Util.number import *
import socket

def recvuntil(client, delim=b'\n'):
    buf = b''
    while delim not in buf:
        buf += client.recv(1)
    return buf

host = 'crypto.2023.cakectf.com'
port = 10666

client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
client.connect((host, port))

recvuntil(client, b'n = ')
n = int(recvuntil(client))
recvuntil(client, b'c = ')
c0 = int(recvuntil(client))

e0 = 1333
e = 0x10001

print("n =", n)
print("c0 =", c0)

def query(c):
    recvuntil(client, b'c = ')
    client.send(c+b'\n')
    recvuntil(client, b'm = ')
    return int(recvuntil(client))

m1 = int("1"*300)
c = str(pow(m1, e, n)).encode()

c1 = c + b'\x00'*(592-len(c)) + long_to_bytes(0x8, 4)[::-1] + long_to_bytes(0x8, 4)[::-1] + long_to_bytes(0x400000, 8)[::-1] 
m2 = query(c1)

q = GCD(m1-m2, n)
assert 1 < q < n
p = n//q

print("p =", p)
print("q =", q)

d0 = inverse(e0, (p-1)*(q-1))
m0 = pow(c0, d0, n)
print("m0 =", m0)

c = pow(m0, e, n)

query(str(c).encode())
print(recvuntil(client))
print(recvuntil(client))
# CakeCTF{h4lf_crypt0_h4lf_pWn_l1k3_c4k3!?}

cryptoとpwnのタグが付いていて少し身構えましたが、とても面白かったです。

*1:お酒をたくさん飲んだヌオーのことです。

*2:CTF中はなぜか buf+608 に u があると勘違いしており、p も書き換えてしまっていましたが、ちょうど解法が破綻せずうまくいっていました。