_[CA CTF 2022]_ - Memory Acceleration

Description

  • While everyone was asleep, you were pushing the capabilities of your technology to the max. Night after night, you frantically tried to repair the encrypted parts of your brain, reversing custom protocols implemented by your father, wanting to pinpoint exactly what damage had been done and constantly keeping notes because of your inability of forming new memories. On one of those nights, you had a flashback. Your father had always talked about a new technology and how it would change the galaxy. You realized that he had used it on you. This technology dealt with a proof of a work function and decentralized networks. Along with Virgil’s help, you had a “Eureka!” moment, but his approach, brute forcing, meant draining all your energy. Can you find a quicker way to validate new memory blocks?

Objective

  • Break a custom hash function with z3 and minor brute forcing.

Difficulty

  • medium/medium-hard

Flag

  • HTB{K14u5_h45_z3_h42d_c0d3d_1n_m3m02y}

Downloadables

  • source.py: contains the server code.

Attack

Analyzing the source code

Looking at the source.py script, we can see that our goal is to make the phash function output 0 four times in a row:

  1. A memory block is loaded and awaits validation.
  2. Two keys are requested by the debugging interface.
  3. phash is computed over the block and the two keys. If the result is 0 we continue, otherwise the server exits.
  4. The keys are appended to the block together with the next memory, and the process repeats.

Steps 1, 2 and 4 are not that interesting to analyze, so let's jump straight to the phash call:

    proof_of_work = phash(block, first_key, second_key)

phash is imported from a separate module:

from pofwork import phash

So let's take a look at the pofwork module.

It starts with an sbox, followed by the sub, rotl and phash functions.

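The sub function converts an integer to bytes with long_to_bytes and substitutes each byte using the standard AES S-box (the one listed on Wikipedia). On its own it is not interesting, but note that it is a byte-wise permutation and therefore invertible.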

The rotl function is a standard 32-bit left rotation. Nothing interesting there either.
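A quick sanity check of rotl, copied verbatim from pofwork.py. One caveat: it only behaves as a pure 32-bit rotation when its input fits in 32 bits. In phash, z starts from a constant wider than 32 bits, so the extra high bits leak into the result through n >> (32 - b) before the final mask. This does not matter for the attack, since we never need to invert the first half.

```python
def rotl(n: int, b: int) -> int:
    # Copied from pofwork.py: rotate left by b within 32 bits
    return ((n << b) | (n >> (32 - b))) & 0xffffffff


# A 32-bit input behaves like a true rotation: the top bit wraps around to bit 0
print(hex(rotl(0x80000001, 1)))  # 0x3

# A wider input is not rotated cleanly: bit 32 is shifted down into bit 1,
# even though the low 32 bits of the input are all zero
print(hex(rotl(1 << 32, 1)))  # 0x2
```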

Now let's look at the juicy phash function.

First, the memory block it receives as input is transformed: it is hashed with MD5, the 16-byte digest is repeated four times, and the resulting 64 bytes are split into sixteen 32-bit words.

    block = md5(block.encode()).digest()
    block = 4 * block
    blocks = [int.from_bytes(block[i:i+4],'big') for i in range(0, len(block), 4)]
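The expansion step can be reproduced on its own with just hashlib. In this standalone sketch (the name expand_block is ours, not the challenge's), the 16-byte MD5 digest repeated four times yields exactly 16 words, which is what the 13-round loop needs since its largest index is blocks[12 + 3].

```python
from hashlib import md5


def expand_block(block: str) -> list:
    """Mirror the first three lines of phash: MD5, repeat 4x, split into 32-bit words."""
    digest = md5(block.encode()).digest()  # 16 bytes
    data = 4 * digest                      # 64 bytes: the digest repeated four times
    # Big-endian 32-bit words, equivalent to bytes_to_long on 4-byte slices
    return [int.from_bytes(data[i:i + 4], 'big') for i in range(0, len(data), 4)]


words = expand_block("example memory")
print(len(words))               # 16
print(words[:4] == words[12:])  # True: the same four words repeat
```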

Next, some interesting variables are initialized. Note that key1 is used here, as the initial value of x.

    m = 0xffffffff
    rv1, rv2 = 0x2423380b4d045, 0x3b30fa7ccaa83
    x, y, z, u = key1, 0x39ef52e9f30b3, 0x253ea615d0215, 0x2cd1372d21d77

Then 13 rounds of chaos follow, full of bitwise operations, and h is derived from the final value of rv1:

    for i in range(13):
        x, y, z, u = blocks[i] ^ x, blocks[i+1] ^ y, blocks[i+2] ^ z, blocks[i+3] ^ u
        rv1 ^= (x := (x & m) * (m + (y >> 16)) ^ rotl(z, 3))
        rv2 ^= (y := (y & m) * (m + (z >> 16)) ^ rotl(x, 3))
        rv1, rv2 = rv2, rv1
        rv1 = sub(rv1)
        rv1 = bytes_to_long(rv1)

    h = rv1 + 0x6276137d7 & m

Finally, key2 is passed through sub, each resulting byte is mixed into h with a few simpler bitwise operations, and the result is multiplied by u * z:

    key2 = sub(key2)

    for i, d in enumerate(key2):
        a = (h << 1) & m
        b = (h << 3) & m
        c = (h >> 4) & m
        h ^= (a + b + c - d)
        h += h
        h &= m

    h *= u * z
    h &= m

    return h

A small recap of the interesting findings so far:

  1. key1 only influences the value of h that comes out of the first 13 rounds.
  2. After that, the remaining computation uses only h, the bytes of sub(key2) and a final multiplication by u * z, through a much simpler set of bitwise operations.

Finding the vulnerability

There are many things wrong with this hash function, but our only goal is to make it output 0. If we split it into smaller problems, we can see that the second half can be solved with z3, with a little brute forcing of key1 on top.

Exploitation

Connecting to the server

A pretty basic script for connecting to the server with pwntools:

from pwn import remote

if __name__ == '__main__':
    # r is a global: the helper functions read from and write to this connection
    r = remote('0.0.0.0', 1337)
    pwn()

Getting the first point

When we connect, the server sends the memory block that needs to be validated. We can read it with:

def getBlockToValidate():
    debug_msg = r.recvline()
    block = debug_msg.decode().strip()[len('DEBUG MSG - You need to validate this memory block: '):]
    return block
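To check the slicing offline, here is a hypothetical helper that performs the same prefix removal on a raw line, but fails loudly if the server ever sends something unexpected instead of silently returning a truncated string. The sample line is made up for illustration.

```python
PREFIX = 'DEBUG MSG - You need to validate this memory block: '


def parse_block_line(line: bytes) -> str:
    """Extract the memory block from one line received from the server (hypothetical helper)."""
    text = line.decode().strip()
    if not text.startswith(PREFIX):
        raise ValueError(f'unexpected line: {text!r}')
    return text[len(PREFIX):]


sample = b'DEBUG MSG - You need to validate this memory block: some memory (1, 2). more\n'
print(parse_block_line(sample))  # some memory (1, 2). more
```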

z3 magic

As hinted above, the hash function needs to be split into sections so that z3 only has to solve a system it can actually handle.

More specifically, we can compute the entire first half of the hash function directly in Python, since it does not depend on key2, and give z3 only the second half:

    h = rv1 + 0x6276137d7 & m
    key2 = sub(key2)

    for i, d in enumerate(key2):
        a = (h << 1) & m
        b = (h << 3) & m
        c = (h >> 4) & m
        h ^= (a + b + c - d)
        h += h
        h &= m

    h *= u * z
    h &= m

How can we simplify the problem more for z3?

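For a fixed (block, key1) pair, the value of h entering the loop is deterministic, so we can compute it ourselves.
sub(key2) is also reversible, because the AES S-box is a permutation of the byte values, so we can choose the bytes the loop consumes and map them back to key2. Finally, the last two lines only multiply h by u * z modulo 2**32, so if the loop already leaves h at 0 the result stays 0 and those lines can be ignored.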
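Because the S-box is a permutation, any byte string we want the loop to consume can be turned back into a key2 integer with the inverse table. This sketch uses stand-ins for pycryptodome's long_to_bytes/bytes_to_long built on int.to_bytes/int.from_bytes, and the helper name key2_for is hypothetical. It also shows one pitfall: sbox[0] == 0x63, so a target string starting with 0x63 would need a leading zero byte in key2, which the integer encoding cannot carry.

```python
SBOX = [
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16
]

# Inverse table: INV_SBOX[SBOX[i]] == i for every byte i
INV_SBOX = [0] * 256
for i, v in enumerate(SBOX):
    INV_SBOX[v] = i


def long_to_bytes(n: int) -> bytes:
    # Stand-in for Crypto.Util.number.long_to_bytes: minimal big-endian encoding
    return n.to_bytes(max(1, (n.bit_length() + 7) // 8), 'big')


def sub(n: int) -> bytes:
    # Same as pofwork.sub: encode the integer, then substitute every byte
    return bytes(SBOX[b] for b in long_to_bytes(n))


def key2_for(target: bytes) -> int:
    """Return a key2 such that sub(key2) == target (hypothetical helper)."""
    raw = bytes(INV_SBOX[t] for t in target)
    if raw[0] == 0:
        # target[0] == 0x63 maps back to 0x00, and that leading zero would be lost
        raise ValueError('first target byte must not be 0x63')
    return int.from_bytes(raw, 'big')


key2 = key2_for(bytes([0x01, 0x02, 0x03, 0x04, 0x05, 0x06]))
print(sub(key2).hex())  # 010203040506
```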

So the problem becomes:

If we control the bytes of sub(key2) and h is a known value, how can we make h equal 0 after the following loop?

    for i, d in enumerate(key2):
        a = (h << 1) & m
        b = (h << 3) & m
        c = (h >> 4) & m
        h ^= (a + b + c - d)
        h += h
        h &= m
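Before handing this to z3, it helps to see how tight the constraint is. Every round ends with h += h; h &= m, which throws away the top bit, so the last round returns 0 exactly when h and a + b + c - d agree modulo 2**31. For a given incoming h there is therefore at most one byte d that works, and for a typical h there is none: the earlier bytes have to steer h into that narrow set, which is the part we leave to z3. The pure-Python sketch below (names round_step and last_byte_for are ours) checks this reasoning against brute force.

```python
MASK32 = 0xFFFFFFFF


def round_step(h: int, d: int) -> int:
    """One iteration of the key2 loop in pofwork.phash; d is one byte of sub(key2)."""
    a = (h << 1) & MASK32
    b = (h << 3) & MASK32
    c = (h >> 4) & MASK32
    h ^= (a + b + c - d)
    h += h
    return h & MASK32


def last_byte_for(h: int):
    """Return the only byte d with round_step(h, d) == 0, or None if no byte works.

    Doubling and masking to 32 bits discards the top bit, so
        round_step(h, d) == 0  <=>  h ^ (a + b + c - d) == 0  (mod 2**31)
                               <=>  d == a + b + c - h        (mod 2**31)
    and at most one value in 0..255 can satisfy that congruence.
    """
    a = (h << 1) & MASK32
    b = (h << 3) & MASK32
    c = (h >> 4) & MASK32
    d = (a + b + c - h) % (1 << 31)
    return d if d <= 0xFF else None


print(last_byte_for(1))           # 9, and round_step(1, 9) == 0
print(last_byte_for(0x12345678))  # None: no single byte zeroes this h
```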

z3 to the rescue. Let’s create a model.

  1. Create a solver.
  2. Create the bit vectors for the key bytes and for h.
  3. Add our constraints.
  4. Check whether the system is satisfiable and, if it is, read the key bytes from the model.
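def findSecondKeyWithZ3(block, key1):
    s = Solver()

    # Six unknowns: the bytes of sub(key2), i.e. what the loop actually consumes
    xs = list(BitVecs('c0 c1 c2 c3 c4 c5', 32))
    h = BitVec('hs', 32)

    # The first half does not involve key2, so compute it concretely in Python
    target_h = phashFirstHalf(block, key1)
    s.add(h == target_h)

    # isub(0x63) == 0x00, and bytes_to_long would drop that leading zero byte,
    # so the first substituted byte must not be 0x63
    s.add(xs[0] != 0x63)

    for e in xs:
        s.add(ULE(e, 0xff))  # each unknown is one byte (unsigned comparison)
        # 32-bit BitVec arithmetic wraps like "& m"; rotr() from the full solver
        # is a logical right shift, matching (h >> 4) & m
        h ^= ((h << 1) + (h << 3) + rotr(h, 4) - e)
        h += h

    s.add(h == 0)

    if s.check() == sat:
        m = s.model()
        key2_bytes = bytes([m[c].as_long() for c in xs])

        # Sanity check against a plain-Python replay of the loop
        assert phashSecondHalf(key2_bytes, target_h) == 0
        return key2_bytes
    else:
        return 'unsat'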
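Why does the z3 model use rotr(h, 4) instead of h >> 4? In z3's Python API, >> on a BitVec is an arithmetic (sign-extending) shift, while pofwork.py shifts a non-negative Python integer, which is a logical shift. The solver's rotr helper (defined in the full solver listing; despite its name it is not a rotation) masks off the copied sign bits. This pure-Python sketch emulates both kinds of shift on 32-bit values to show the difference; z3 also provides LShR for a logical shift directly.

```python
MASK32 = 0xFFFFFFFF


def to_signed32(v: int) -> int:
    """Reinterpret a 32-bit pattern as a two's-complement signed integer."""
    v &= MASK32
    return v - (1 << 32) if v & 0x80000000 else v


def arithmetic_shr(v: int, n: int) -> int:
    """Sign-extending shift, i.e. what z3's >> does on a 32-bit BitVec."""
    return (to_signed32(v) >> n) & MASK32


def logical_shr(v: int, n: int) -> int:
    """Zero-filling shift, i.e. what (h >> 4) & m does in pofwork.py (h >= 0)."""
    return (v & MASK32) >> n


def masked_shr(v: int, n: int) -> int:
    """The solver's rotr(): arithmetic shift, then clear the n copied sign bits."""
    return arithmetic_shr(v, n) & ((1 << (32 - n)) - 1)


print(hex(arithmetic_shr(0x80000000, 4)))  # 0xf8000000
print(hex(logical_shr(0x80000000, 4)))     # 0x8000000
print(hex(masked_shr(0x80000000, 4)))      # 0x8000000
```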

Brute forcing key1

You might be wondering why findSecondKeyWithZ3() takes key1 as an input. Not every key1 produces a solvable system for z3, so we need to brute force it a little: we simply try key1 = 0, 1, 2, ... until z3 finds a solution. Don't worry, it's a tiny brute force.

def findKeys(block):
    first_key = 0
    while 1:
        second_key = findSecondKeyWithZ3(block, first_key)
        if (second_key != 'unsat'):
            # The server applies sub() to key2, so undo the S-box before sending
            return(first_key, bytes_to_long(isub(second_key)))
        first_key += 1

Sending the keys

def sendKeys(first_key, second_key):
    # The server parses both keys with int(), so send them as decimal strings
    r.sendlineafter(b'DEBUG MSG - Enter first key: ', bytes(str(first_key), 'Latin'))
    r.sendlineafter(b'DEBUG MSG - Enter second key: ', bytes(str(second_key), 'Latin'))

Getting the flag

A final recap of all the above:

  1. We connected to the server.
  2. We got the block of memory.
  3. We found out that a section of the hash function can be solved with z3.
  4. We brute forced key1 until the z3 system became solvable.
  5. We ran our solver and found key2.
  6. We sent the keys and got the next block.
  7. Repeating this process 4 times gives us the flag :)

This recap is represented in code by the pwn() function.

def pwn():
    for _ in range(4):
        skipMessages()
        block = getBlockToValidate()
        print(f"Fetched block of memory that needs validation")
        first_key, second_key = findKeys(block)
        print(f"Found valid keys: {first_key}, {second_key} for block: {block}")
        sendKeys(first_key, second_key)
        print("Moving on to the next block")
    r.interactive()
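The magic number 5 in skipMessages() (defined in the full solver listing) comes from the server's output framing. pwntools' recvline() consumes one newline-terminated line, sendlineafter() has already consumed the "Enter second key: " prompt, and both the welcome banner and the success message sent after each accepted pair of keys span exactly five lines. The strings below are copied from the challenge code, with the success message shown for the first round; count_lines is a hypothetical helper.

```python
WELCOME_MSG = """Virgil says:
Klaus I'm connecting the serial debugger to your memory.
Please stay still. We don't want anything wrong to happen.
Ok you should be able to see debug messages now..\n\n"""

counter = 1  # value after the first accepted pair of keys
SUCCESS_MSG = ("\nVirgil says: \nWow you formed a new memory!!\n"
               f"Let's try again {4 - counter} times just to be sure!\n\n")


def count_lines(text: str) -> int:
    # Number of recvline() calls needed to consume text (every line ends in '\n')
    return text.count('\n')


print(count_lines(WELCOME_MSG))  # 5
print(count_lines(SUCCESS_MSG))  # 5
```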

Challenge Code

The server code is presented below:

import socketserver
import signal
from pofwork import phash


DEBUG_MSG = "DEBUG MSG - "
WELCOME_MSG = """Virgil says:
Klaus I'm connecting the serial debugger to your memory.
Please stay still. We don't want anything wrong to happen.
Ok you should be able to see debug messages now..\n\n"""


with open('memories.txt', 'r') as f:
    MEMORIES = [m.strip() for m in f.readlines()]


class Handler(socketserver.BaseRequestHandler):
    def handle(self):
        signal.alarm(0)
        main(self.request)


class ReusableTCPServer(socketserver.ForkingMixIn, socketserver.TCPServer):
    pass


def sendMessage(s, msg):
    s.send(msg.encode())


def recieveMessage(s, msg):
    sendMessage(s, msg)
    return s.recv(4096).decode().strip()


def main(s):
    block = ""
    counter = 0
    sendMessage(s, WELCOME_MSG)

    while True:
        block += MEMORIES[counter]

        sendMessage(s, DEBUG_MSG +
                    f"You need to validate this memory block: {block}\n")

        first_key = recieveMessage(s, DEBUG_MSG + "Enter first key: ")
        second_key = recieveMessage(s, DEBUG_MSG + "Enter second key: ")

        try:
            first_key, second_key = int(first_key), int(second_key)
            proof_of_work = phash(block, first_key, second_key)
        except:
            sendMessage(s, "\nVirgil says: \n"
                        "Be carefull Klaus!! You don't want to damage yourself.\n"
                        "Let's start over.")
            exit()

        if proof_of_work == 0:
            block += f" ({first_key}, {second_key}). "
            sendMessage(s, "\nVirgil says: \nWow you formed a new memory!!\n")
            counter += 1
            sendMessage(
                s, f"Let's try again {4 - counter} times just to be sure!\n\n")
        else:
            sendMessage(s, DEBUG_MSG + f"Incorect proof of work\n"
                        "\nVirgil says: \n"
                        "You calculated something wrong Klaus we need to start over.")
            exit()

        if counter == 4:
            sendMessage(s, "It seems that everything are working fine.\n"
                        "Wait what is that...\n"
                        "Klaus this is important!!\n"
                        "This can help you find your father!!\n"
                        f"{MEMORIES[-1]}")
            exit()


if __name__ == '__main__':
    socketserver.TCPServer.allow_reuse_address = True
    server = ReusableTCPServer(("0.0.0.0", 1337), Handler)
    server.serve_forever()

The pofwork module:

from hashlib import md5
from Crypto.Util.number import long_to_bytes, bytes_to_long


sbox = [
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16
]


def rotl(n, b):
    return ((n << b) | (n >> (32 - b))) & 0xffffffff


def sub(b):
    b = long_to_bytes(b)
    return bytes([sbox[i] for i in b])


def phash(block, key1, key2):
    block = md5(block.encode()).digest()
    block = 4 * block
    blocks = [bytes_to_long(block[i:i+4]) for i in range(0, len(block), 4)]

    m = 0xffffffff
    rv1, rv2 = 0x2423380b4d045, 0x3b30fa7ccaa83
    x, y, z, u = key1, 0x39ef52e9f30b3, 0x253ea615d0215, 0x2cd1372d21d77

    for i in range(13):
        x, y = blocks[i] ^ x, blocks[i+1] ^ y
        z, u = blocks[i+2] ^ z, blocks[i+3] ^ u
        rv1 ^= (x := (x & m) * (m + (y >> 16)) ^ rotl(z, 3))
        rv2 ^= (y := (y & m) * (m + (z >> 16)) ^ rotl(x, 3))
        rv1, rv2 = rv2, rv1
        rv1 = sub(rv1)
        rv1 = bytes_to_long(rv1)

    h = rv1 + 0x6276137d7 & m
    key2 = sub(key2)

    for i, d in enumerate(key2):
        a = (h << 1) & m
        b = (h << 3) & m
        c = (h >> 4) & m
        h ^= (a + b + c - d)
        h += h
        h &= m

    h *= u * z
    h &= m

    return h

Solver

  1import random
  2from hashlib import md5
  3from Crypto.Util.number import *
  4from z3 import *
  5from pwn import *
  6
  7
  8sbox = [
  9    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
 10    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
 11    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
 12    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
 13    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
 14    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
 15    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
 16    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
 17    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
 18    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
 19    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
 20    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
 21    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
 22    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
 23    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
 24    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16
 25]
 26
 27
 28def rotl(n, b):
 29    return ((n << b) | (n >> (32 - b))) & 0xffffffff
 30
 31
 32def rotr(v, n):
 33    return (v >> n) & ((1 << (32 - n)) - 1)
 34
 35
 36def sub(b):
 37    b = long_to_bytes(b)
 38    return bytes([sbox[i] for i in b])
 39
 40
 41def isub(b):
 42    return bytes([sbox.index(i) for i in b])
 43
 44
 45def phashFirstHalf(text, key1):
 46    text = md5(text.encode()).digest()
 47    text = 4 * text
 48    text = [int.from_bytes(text[i:i+4], 'big') for i in range(0, len(text), 4)]
 49
 50    m = 0xffffffff
 51    rv1, rv2 = 0x2423380b4d045, 0x3b30fa7ccaa83
 52    x, y, z, u = key1, 0x39ef52e9f30b3, 0x253ea615d0215, 0x2cd1372d21d77
 53
 54    for i in range(13):
 55        x, y, z, u = text[i] ^ x, text[i+1] ^ y, text[i+2] ^ z, text[i+3] ^ u
 56        rv1 ^= (x := (x & m) * (m + (y >> 16)) ^ rotl(z, 3))
 57        rv2 ^= (y := (y & m) * (m + (z >> 16)) ^ rotl(x, 3))
 58        rv1, rv2 = rv2, rv1
 59        rv1 = sub(rv1)
 60        rv1 = bytes_to_long(rv1)
 61
 62    h = rv1 + 0x6276137d7 & m
 63    return h
 64
 65
 66def phashSecondHalf(s, target_h):
 67    h = target_h
 68    for i, d in enumerate(s):
 69        a = (h << 1) & 0xffffffff
 70        b = (h << 3) & 0xffffffff
 71        c = (h >> 4) & 0xffffffff
 72        h ^= (a + b + c - d)
 73        h += h
 74        h &= 0xffffffff
 75    return h
 76
 77
 78def skipMessages():
 79    for _ in range(5):
 80        r.recvline()
 81
 82
 83def getBlockToValidate():
 84    debug_msg = r.recvline()
 85    block = debug_msg.decode().strip()[len(
 86        'DEBUG MSG - You need to validate this memory block: '):]
 87    return block
 88
 89
 90def findSecondKeyWithZ3(block, key1):
 91    s = Solver()
 92
 93    target_h = phashFirstHalf(block, key1)
 94
 95    xs = list(BitVecs('c0 c1 c2 c3 c4 c5', 32))
 96    h = BitVec('hs', 32)
 97
 98    s.add(h == target_h)
 99
100    for i, e in enumerate(xs):
101        s.add(e >= 0)
102        s.add(e < 255)
103        h ^= ((h << 1) + (h << 3) + rotr(h, 4) - e)
104        h += h
105
106    s.add(h == 0)
107
108    if (s.check() != unsat):
109        m = s.model()
110        s = bytes([m[c].as_long() for c in xs])
111
112        assert phashSecondHalf(s, target_h) == 0
113        return(s)
114    else:
115        return('unsat')
116
117
118def findKeys(block):
119    first_key = 0
120    while 1:
121        second_key = findSecondKeyWithZ3(block, first_key)
122        if (second_key != 'unsat'):
123            return(first_key, bytes_to_long(isub(second_key)))
124        first_key += 1
125
126
127def sendKeys(first_key, second_key):
128    r.sendlineafter(b'DEBUG MSG - Enter first key: ',
129                    bytes(str(first_key), 'Latin'))
130    r.sendlineafter(b'DEBUG MSG - Enter second key: ',
131                    bytes(str(second_key), 'Latin'))
132
133
134def pwn():
135    for _ in range(4):
136        skipMessages()
137        block = getBlockToValidate()
138        print(f"Fetched block of memory that needs validation")
139        first_key, second_key = findKeys(block)
140        print(
141            f"Found valid keys: {first_key}, {second_key} for block: {block}")
142        sendKeys(first_key, second_key)
143        print("Moving on to the next block")
144    r.interactive()
145
146
147if __name__ == '__main__':
148    r = remote('localhost', 1337)
149    pwn()