## Find Trails

In [1]:
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from functools import reduce
import random

def get_ddt(sbox):
    """Build the difference distribution table (DDT) of an S-box.

    Entry ``[din][dout]`` counts the ordered input pairs ``(a, b)`` with
    ``a ^ b == din`` and ``sbox[a] ^ sbox[b] == dout``.

    Args:
        sbox: Sequence of integer S-box outputs, indexed by input value.

    Returns:
        A square integer numpy array with ``len(sbox)`` rows and columns.
    """
    size = len(sbox)
    table = np.zeros((size, size), dtype=int)
    for a, out_a in enumerate(sbox):
        for b, out_b in enumerate(sbox):
            table[a ^ b, out_a ^ out_b] += 1
    return table

def to_nibbles(pt: bytes):
    """Split every byte of *pt* into two 4-bit values, low nibble first.

    Returns:
        A list of ``2 * len(pt)`` integers in the range 0..15.
    """
    nibbles = []
    for byte in pt:
        nibbles.append(byte & 0xf)
        nibbles.append((byte >> 4) & 0xf)
    return nibbles

def from_nibbles(nibs):
    """Pack consecutive (low, high) nibble pairs back into bytes.

    A trailing unpaired nibble, if any, is ignored.
    """
    low_nibbles = nibs[0::2]
    high_nibbles = nibs[1::2]
    return bytes(lo + (hi << 4) for lo, hi in zip(low_nibbles, high_nibbles))

def tobits(x: int, nbits: int):
    """Return the lowest *nbits* bits of *x* as a list, least-significant first."""
    return [(x >> shift) & 1 for shift in range(nbits)]

def frombits(bits):
    """Combine a least-significant-first bit sequence into one integer."""
    value = 0
    for position, bit in enumerate(bits):
        value += bit << position
    return value

def toblks(arr, bl: int):
    """Cut *arr* into consecutive slices of length *bl*.

    Only complete blocks are returned; leftover trailing items are dropped.
    """
    nblocks = len(arr) // bl
    return [arr[bl * idx:bl * (idx + 1)] for idx in range(nblocks)]

# Bit permutation of the 64-bit state: entry i is the destination position of
# source bit i.
PERM = np.array(
    [0, 16, 32, 48, 1, 17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51,
     4, 20, 36, 52, 5, 21, 37, 53, 6, 22, 38, 54, 7, 23, 39, 55,
     8, 24, 40, 56, 9, 25, 41, 57, 10, 26, 42, 58, 11, 27, 43, 59,
     12, 28, 44, 60, 13, 29, 45, 61, 14, 30, 46, 62, 15, 31, 47, 63],
    dtype=np.uint64,
)
# For a permutation of 0..n-1, argsort yields the inverse permutation.
INVPERM = np.argsort(PERM).astype(np.uint64)

# 4-bit S-box and its inverse.
SBOX = np.array(
    [3, 10, 6, 8, 15, 1, 13, 4, 11, 2, 5, 0, 7, 14, 9, 12],
    dtype=np.uint64,
)
INVSBOX = np.argsort(SBOX).astype(np.uint64)

NROUNDS = 14

# PTCTDIFF = [(41, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),(26, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]),(26, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),(26, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]),(256, [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),(131, [0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0]),(70, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0]),(36, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]),(131, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),(67, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),(70, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),(137, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])]

In [2]:
# Keep only the S-box difference transitions whose DDT count exceeds the
# threshold, then map every input difference to one such output difference.
threshold = 4
ddt = get_ddt(SBOX)
ddt_thres = ddt > threshold
# When several output differences qualify for one input difference, the
# largest one wins (later keys overwrite earlier ones).
d = {din: dout for din in range(16) for dout in range(16) if ddt_thres[din, dout]}
# Raises KeyError if some input difference has no transition above threshold.
DDT_SBOX = np.array([d[din] for din in range(16)], dtype=np.uint64)

In [3]:
from numba import jit, uint64, byte, types, float64
import sys
sys.path.append("../src")
from cipher import expandkey, encrypt, decrypt

# numba signature: takes a 1-D byte array, returns uint64.
@jit(uint64(byte[:]))
def byte2int(blk):
    """Pack the first 8 bytes of *blk* into one integer, little-endian.

    Byte ``blk[i]`` is placed at bit offset ``8*i`` of the result. Only
    indices 0..7 are read.
    """
    r = 0
    for i in range(8):
        # Byte i contributes at bit offset 8*i.
        r += blk[i] << (i*8)
    return r

# numba signature: takes a uint64, returns a 1-D byte array.
@jit(byte[:](uint64))
def int2byte(x):
    """Split *x* into 8 bytes, least-significant byte first.

    Returns a new ``np.uint8`` array of length 8 where element ``i`` holds
    bits ``8*i .. 8*i+7`` of *x*.
    """
    r = np.zeros(8, dtype=np.uint8)
    for i in range(8): r[i] = (x >> (i*8))&0xff
    return r

# numba signature: takes a 1-D byte array, returns a 1-D uint64 array.
@jit(uint64[:](byte[:]))
def toints(pt):
    """Convert a byte array into one uint64 per complete 8-byte group.

    Each group is packed with ``byte2int``. Trailing bytes that do not fill a
    whole group are ignored (the output length is ``len(pt)//8``).
    """
    r = np.zeros(len(pt)//8, dtype=np.uint64)
    for i in range(len(pt)//8): r[i] = byte2int(pt[8*i:8*i+8])
    return r

# numba signature: uint64 -> uint64.
@jit(uint64(uint64))
def sub(p):
    """Map each of the 16 nibbles of *p* through the ``DDT_SBOX`` table.

    ``DDT_SBOX`` is the module-level table of chosen output differences, so
    this propagates a 64-bit difference through the S-box layer along the
    selected transitions.
    """
    r = 0
    for i in range(16):
        # `<<` binds tighter than `|`: the looked-up nibble is shifted back to
        # position i before being OR-ed into the result.
        r = r | DDT_SBOX[(p >> (i*4)) & 0xf] << (i*4)
    return r

# numba signature: uint64 -> uint64.
@jit(uint64(uint64))
def perm(p):
    """Apply the bit permutation ``PERM`` to the 64-bit value *p*.

    Bit ``i`` of the input is moved to bit ``PERM[i]`` of the output.
    """
    r = 0
    for i in range(64):
        r |= ((p >> i) & 1) << PERM[i]
    return r

# numba signature: (uint64, uint64) -> (uint64, float64).
@jit(types.Tuple((uint64, float64))(uint64, uint64))
def encryptblk(p, nrounds):
    """Propagate the difference *p* along the chosen S-box transitions.

    Runs *nrounds* rounds of ``sub`` followed by ``perm``, then one final
    ``sub`` without a permutation.

    Returns:
        ``(difference, trailprob)`` where ``trailprob`` is the product, over
        all 16 nibbles of each of the *nrounds* rounds, of
        ``ddt[x][DDT_SBOX[x]] / 16``. The final ``sub`` does not contribute
        to ``trailprob``.
    """
    trailprob = 1.0
    for k in range(nrounds):
        for i in range(16):
            # x is the i-th nibble of the current difference.
            x = (p >> (i*4))&0xf
            # Probability of the chosen transition x -> DDT_SBOX[x].
            trailprob *= ddt[x][DDT_SBOX[x]]/16
        p = sub(p)
        p = perm(p)
    # Final S-box layer; not included in trailprob.
    p = sub(p)
    return p, trailprob

def getctrecovered(ctdiff):
    """Return the set of nibble indices (0 = lowest) that are non-zero in *ctdiff*.

    Only the 16 nibbles of the low 64 bits are inspected.
    """
    active = set()
    for idx in range(16):
        if (ctdiff >> (4 * idx)) & 0xf:
            active.add(idx)
    return active

In [4]:
from itertools import product

# (onein, (ptdiff, ctdiff))
# goodtrails[k] holds the best (lowest "one in N") trail found so far whose
# output difference is non-zero in nibble k, or None if no trail reached it.
goodtrails = [None]*16
# Input differences of trails that were replaced by a strictly better one.
rejected = []
for _ptdiff in product(range(64), repeat=2):
    # (a, b) and (b, a) give the same difference. Replacement below is strict
    # (`>`), so the mirrored pair can never change the result; skip it.
    if _ptdiff[0] > _ptdiff[1]:
        continue

    # Set the chosen bit positions; equal positions give a 1-bit difference.
    ptdiff = 0
    for a in _ptdiff:
        ptdiff |= 1 << a
    ctdiff, prob = encryptblk(ptdiff, 12)

    ct_recovered = getctrecovered(ctdiff)
    onein = 1/prob
    for k in ct_recovered:
        kt = goodtrails[k]
        if kt is None or kt[0] > onein:
            goodtrails[k] = (onein, (ptdiff, ctdiff))
            if kt is not None: rejected.append(kt[1][0])

# Keep trails better than one in 100000 and expand both differences to bit
# lists. Slots that no trail reached are still None and are skipped instead
# of crashing the tuple unpacking.
goodtrails = [
    (kt[0], tobits(kt[1][0], 64), tobits(kt[1][1], 64))
    for kt in goodtrails
    if kt is not None and kt[0] < 100000
]

## Exploit Trails

In [5]:
import os

def gen_pairs(ptdiff, npairs):
    """Generate random 8-byte blocks plus partners at a fixed XOR difference.

    Args:
        ptdiff: Sequence of bit flags. A non-zero entry at index ``i`` flips
            bit ``i % 8`` of byte ``i // 8`` within every block.
        npairs: Number of random 8-byte blocks to draw.

    Returns:
        ``bytes`` of length ``16 * npairs``: the random blocks, followed by
        the same blocks with the difference applied, in the same order.
    """
    pairs = os.urandom(8*npairs)
    # The flipped positions are the same for every block, so derive them once
    # instead of rescanning ptdiff for each pair.
    flips = [(i // 8, 1 << (i % 8)) for i, d in enumerate(ptdiff) if d != 0]
    p2 = bytearray(pairs)
    for base in range(0, 8 * npairs, 8):
        for offset, bit in flips:
            p2[base + offset] ^= bit
    return pairs + bytes(p2)

def getdistr(arr):
    """Return the relative frequency of each value 0..15 in *arr*.

    Args:
        arr: Sized iterable of integers usable as indices into a 16-slot list.

    Returns:
        A list of 16 floats; entry ``v`` is the fraction of items equal to ``v``.
    """
    counts = [0] * 16
    for value in arr:
        counts[value] += 1
    total = len(arr)
    return [count / total for count in counts]

def server_encrypt(sendtoserver):
    """Send plaintext to the challenge server and return its reply.

    Protocol as implemented here: read a line beginning with
    ``"Encrypted flag: "`` followed by hex, discard one line, send
    *sendtoserver* hex-encoded, discard one line, then read the hex-encoded
    encryption of the submitted data.

    Args:
        sendtoserver: Raw plaintext bytes to submit.

    Returns:
        ``(encrypted_flag, ciphertext)`` as ``bytes``.
    """
    from nclib import Netcat
    nc = Netcat(("0.0.0.0", 1337))
    try:
        r = nc.recvline().strip()[len("Encrypted flag: "):]
        encrypted_flag = bytes.fromhex(r.decode())
        nc.recvline()
        nc.sendline(sendtoserver.hex())
        nc.recvline()
        enc = nc.recvline().strip()
    finally:
        # Always release the socket, including when the exchange fails midway.
        nc.close()
    return encrypted_flag, bytes.fromhex(enc.decode())

In [6]:
# Spread a budget of 10000000 bytes over all trails: each pair costs 16 bytes
# (two 8-byte blocks), so every trail gets NPAIR pairs.
NPAIR = 10000000 // (16 * len(goodtrails))
# One batch of plaintext pairs per trail, concatenated in trail order.
sendtoserver = b"".join(gen_pairs(ptd, NPAIR) for onein, ptd, ctd in goodtrails)
encrypted_flag, pairs = server_encrypt(sendtoserver)
print("Encrypted flag:", encrypted_flag.hex())

# naked_flag = byte2int(np.frombuffer(b'9~FS1idk', dtype=np.uint8))
# KEYEX = expandkey(naked_flag ^ randomizer)

Encrypted flag: 44d75f186123acf26b66624af2c95601db77473eadb8f07558853031d807f8fb


In [7]:
kalldistr = []
ptr = 0
# Every trail owns NPAIR pairs of 8-byte ciphertext blocks in the reply.
chunk = NPAIR * 8 * 2


def partialdec(c, k):
    """Undo the last S-box on ciphertext nibble *c* under key-nibble guess *k*."""
    return INVSBOX[c ^ k]


for onein, ptd, ctd in goodtrails:
    nibs = to_nibbles(pairs[ptr:ptr + chunk])
    ptr += chunk
    # First half: ciphertexts of the random blocks; second half: ciphertexts
    # of their partners. Pair them up as 16-nibble blocks.
    half = len(nibs) // 2
    enc = list(zip(toblks(nibs[:half], 16), toblks(nibs[half:], 16)))

    # Expected output difference of the trail, one value per nibble.
    activect = [frombits(x) for x in toblks(ctd, 4)]
    alldistr = {}
    for didx, diff in enumerate(activect):
        if diff != 0:
            # For every key-nibble guess, the distribution of the partially
            # decrypted difference at this nibble position.
            per_guess = []
            for kguess in range(16):
                cdiff = [
                    partialdec(c1[didx], kguess) ^ partialdec(c2[didx], kguess)
                    for c1, c2 in enc
                ]
                per_guess.append(getdistr(cdiff))
            alldistr[(didx, diff)] = per_guess
    kalldistr.append(alldistr)

    print(f"One in {onein}, recovers: {[k[0] for k in alldistr.keys()]}")
        
    # for (didx, diff), dists in alldistr.items():
    #     plt.plot([d[diff] for d in dists])
    #     plt.axvline(x = (int(KEYEX[-1]) >> (4*didx))&0xf)
    #     plt.show()

One in 17179.869184, recovers: [1]
One in 22239.9981598543, recovers: [2, 3]
One in 4503.599627370496, recovers: [3]
One in 17179.869184, recovers: [4]
One in 1677.7216, recovers: [5]
One in 4886.718345671111, recovers: [6]
One in 687.19476736, recovers: [7]
One in 32025.597350190194, recovers: [8, 11]
One in 4886.718345671111, recovers: [9]
One in 4048.668109456143, recovers: [10, 11]
One in 1281.0238940076079, recovers: [11]
One in 4503.599627370496, recovers: [12]
One in 687.19476736, recovers: [13]
One in 1281.0238940076079, recovers: [14]
One in 281.474976710656, recovers: [15]


In [8]:
from math import log2

# keyrec[i] is the set of still-possible values for last-round key nibble i.
keyrec = [set(range(16)) for _ in range(16)]
for alldistr in kalldistr:

    for (didx, diff), dists in alldistr.items():
        if len(keyrec[didx]) == 1: continue
        # Rank key guesses by how often the expected difference appeared.
        ranked = sorted(
            ((guess, dist[diff]) for guess, dist in enumerate(dists)),
            key=lambda x: -x[1],
        )
        freqs = [freq for _, freq in ranked]
        hi, lo = max(freqs), min(freqs)
        if hi == lo:
            # Every guess scored the same: this trail carries no information
            # for the nibble, and the ratio below would divide by zero.
            continue
        # Keep at most the two best guesses, and the runner-up only if it lies
        # within a fifth of the observed spread from the best.
        keyrec[didx] &= {
            guess for guess, freq in ranked[:2] if (hi - freq) / (hi - lo) < 1/5
        }

# assert all([(int(KEYEX[-1]) >> (i*4))&0xf in k2 for i,k2 in enumerate(keyrec)])
assert min(map(len, keyrec)) > 0, "Rerun trails got wrong"
nbitbrute = sum(log2(len(cands)) for cands in keyrec)
print("Number of bits to bruteforce:", nbitbrute)
assert nbitbrute <= 24, "Rerun lmao not gonna brute that much"

Number of bits to bruteforce: 4.0


In [9]:
from typing import List
from itertools import product
from numba import jit, njit, byte, uint64, int32
import numpy as np

# Inverse of the 4-bit S-box as a lookup table.
INVSBOX = np.array(
    [11, 5, 9, 0, 7, 10, 2, 12, 3, 14, 1, 8, 15, 6, 13, 4],
    dtype=np.uint64,
)
# Multiplicative inverse of 11704981291924017277 modulo 2**64.
# NOTE(review): presumably the multiplier used by cipher.expandkey - confirm.
INVEX = pow(11704981291924017277, -1, 2**64)

# numba signature: uint64 -> uint64.
@jit(uint64(uint64))
def invsub(p:int) -> int:
    """Apply the ``INVSBOX`` lookup to each of the 16 nibbles of *p*."""
    r = 0
    for i in range(16):
        # `<<` binds tighter than `|=`: the looked-up nibble is shifted back
        # to position i before being merged.
        r |= INVSBOX[(p >> (i*4)) & 0xf] << (i*4)
    return r

# numba signature: uint64 -> uint64.
@jit(uint64(uint64))
def unexpandkey(key:int) -> int:
    """Step a 64-bit round key backwards ``NROUNDS-1`` times.

    Each step applies ``invsub`` and then multiplies by ``INVEX``, keeping the
    low 64 bits.
    NOTE(review): presumably this inverts the schedule in ``cipher.expandkey``
    to turn the last round key into the master key - confirm against that
    module.
    """
    c = key
    mask = (1<<64) - 1
    for i in range(NROUNDS-1):
        c = invsub(c)
        c *= INVEX
        # Reduce modulo 2**64.
        c &= mask
    return c

def totoint(key):
    """Pack the first 16 nibbles of *key* into an integer, lowest nibble first."""
    return sum(key[idx] << (4 * idx) for idx in range(16))

# One known plaintext/ciphertext block is used to verify each candidate key.
pt, ct = sendtoserver[:8], pairs[:8]

candidates = []
for i, keytry in enumerate(product(*keyrec)):
    # Assemble the last-round key guess and walk it back through the schedule.
    k = unexpandkey(totoint(keytry))
    if encrypt(pt, k) != ct:
        if i % 10000 == 0:
            print(i, end="\r")
        continue
    print("Possible flag:", "SEE{" + decrypt(encrypted_flag, k).decode() + "}")
    break

0Possible flag: SEE{Sl1dinG_D1ffeR3nt14L_BAb5y:1kcKj}
