_[CA CTF 2022]_ - Memory Acceleration
Description
- While everyone was asleep, you were pushing the capabilities of your technology to the max. Night after night, you frantically tried to repair the encrypted parts of your brain, reversing custom protocols implemented by your father, wanting to pinpoint exactly what damage had been done and constantly keeping notes because of your inability to form new memories. On one of those nights, you had a flashback. Your father had always talked about a new technology and how it would change the galaxy. You realized that he had used it on you. This technology dealt with a proof-of-work function and decentralized networks. With Virgil’s help, you had a “Eureka!” moment, but his approach, brute forcing, meant draining all your energy. Can you find a quicker way to validate new memory blocks?
Objective
- Break a custom hash function with z3 and minor brute forcing.
Difficulty
medium/medium-hard
Flag
HTB{K14u5_h45_z3_h42d_c0d3d_1n_m3m02y}
Downloadables
source.py: contains the server code.
Attack
Analyzing the source code
Looking at the source.py script, we can see that our goal is to find a way to make the phash function output 0 four times in a row. The server loop works like this:
- A memory block is loaded and awaits validation.
- The debugging interface asks for 2 keys.
- The phash of the block and the two keys is calculated. If it is 0 we continue; if not, the server calls exit().
- The keys and the next memory are appended to the previous block, and the process is repeated.
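To make step 4 concrete: every block after the first embeds the keys we submitted earlier, so each round has to be solved on the exact string the server prints. Below is a minimal sketch that mirrors the block bookkeeping in the server's main() loop. The function name build_blocks and the memory strings are hypothetical placeholders; the real memories come from memories.txt on the server.

```python
# Sketch of how the server's main() grows the block that gets hashed.
# build_blocks and the sample memories are hypothetical, for illustration only.

def build_blocks(memories, accepted_keys):
    """Return the block the server asks us to validate in each round.

    accepted_keys[i] is the (first_key, second_key) pair accepted in round i.
    """
    blocks = []
    block = ""
    for counter, memory in enumerate(memories):
        block += memory                      # block += MEMORIES[counter]
        blocks.append(block)                 # "You need to validate this memory block: ..."
        if counter < len(accepted_keys):
            first_key, second_key = accepted_keys[counter]
            block += f" ({first_key}, {second_key}). "
        else:
            break                            # no accepted keys yet for this round
    return blocks


print(build_blocks(["mem-A", "mem-B"], [(3, 1234)]))
# ['mem-A', 'mem-A (3, 1234). mem-B']
```

In other words, we cannot precompute keys for all four rounds: the block for round n+1 only exists once the keys for round n are known.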
Steps 1, 2 and 4 are not that interesting to analyze, so we are going to jump straight to the phash function.
```python
proof_of_work = phash(block, first_key, second_key)
```
phash is imported from another module:
```python
from pofwork import phash
```
So let’s take a look at the pofwork file.
We first see an sbox, the sub function, the rotl function and the phash function.
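The sub function converts an integer into its big-endian bytes and substitutes each byte through the standard AES S-box (the same table listed on Wikipedia). Since the AES S-box is a permutation of 0..255, sub can be undone byte by byte, which will come in handy later. Apart from that, it's not interesting.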
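Because the AES S-box is a permutation, sub can be reversed with an inverse lookup table, which is what the isub helper in the solver relies on. The sketch below shows the round trip. It assumes a stand-in long_to_bytes built on int.to_bytes (behaving like Crypto.Util.number.long_to_bytes for non-negative integers) so that only the standard library is needed; SBOX, INV_SBOX and the helper names are illustrative. It also shows a pitfall worth keeping in mind: a leading 0x00 byte disappears once bytes are turned into an integer.

```python
# AES S-box, copied from the challenge's pofwork module.
SBOX = [
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
]

# Inverse table: INV_SBOX[SBOX[i]] == i for every byte value i.
INV_SBOX = [0] * 256
for i, v in enumerate(SBOX):
    INV_SBOX[v] = i


def long_to_bytes(n):
    # Stand-in for Crypto.Util.number.long_to_bytes: minimal big-endian bytes.
    return n.to_bytes(max(1, (n.bit_length() + 7) // 8), "big")


def sub(b):
    # Same as the challenge: integer -> bytes -> S-box lookup per byte.
    return bytes(SBOX[i] for i in long_to_bytes(b))


def isub(b):
    # Inverse substitution: maps S-box outputs back to their inputs.
    return bytes(INV_SBOX[i] for i in b)


target = bytes([0x7c, 0x77, 0x7b])          # bytes we would like sub(key2) to be
key2 = int.from_bytes(isub(target), "big")  # 0x010203
assert sub(key2) == target

# Pitfall: SBOX[0] == 0x63, so a leading 0x63 inverts to a leading 0x00,
# which vanishes once the bytes become an integer.
short = int.from_bytes(isub(bytes([0x63, 0x7c])), "big")
print(sub(short))  # b'|' (only one byte survives)
```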
The rotl function is just a standard 32-bit left rotation. Also nothing interesting there.
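For completeness, a tiny illustration of what rotl does compared to a plain shift (the function body is copied from pofwork; the inputs are arbitrary examples):

```python
def rotl(n, b):
    # 32-bit left rotation, as in the challenge's pofwork module.
    return ((n << b) | (n >> (32 - b))) & 0xffffffff


# Bits shifted out on the left re-enter on the right...
print(hex(rotl(0x80000001, 3)))              # 0xc
# ...whereas a masked plain shift simply loses them.
print(hex((0x80000001 << 3) & 0xffffffff))   # 0x8
```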
Let’s now look at the juicy phash function.
We see that it first manipulates the memory block it takes as input. It computes the MD5 digest of the block (16 bytes), repeats it 4 times (64 bytes) and then splits the result into sixteen 32-bit big-endian integers.
```python
block = md5(block.encode()).digest()
block = 4 * block
blocks = [int.from_bytes(block[i:i+4], 'big') for i in range(0, len(block), 4)]
```
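A quick sketch to check the shape of this expansion (expand is a hypothetical helper name and the input strings are arbitrary examples). Because the digest is repeated, word i equals word i+4, and since round i of the 13-round loop reads blocks[i] through blocks[i+3], the last round touches index 15, the final word.

```python
from hashlib import md5


def expand(block):
    # Same steps as phash: MD5 digest, repeat 4 times, split into 32-bit words.
    digest = md5(block.encode()).digest()
    data = 4 * digest
    return [int.from_bytes(data[i:i + 4], "big") for i in range(0, len(data), 4)]


words = expand("example memory")  # hypothetical block contents
print(len(words))  # 16
# The repeated digest means the word list has period 4.
assert all(words[i] == words[i + 4] for i in range(12))
# Round i uses indices i..i+3, so the highest index read is 15.
assert max(i + 3 for i in range(13)) == len(words) - 1
```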
After that, some interesting variables are initialized. Note that key1 is used here, as the initial value of x.
```python
m = 0xffffffff
rv1, rv2 = 0x2423380b4d045, 0x3b30fa7ccaa83
x, y, z, u = key1, 0x39ef52e9f30b3, 0x253ea615d0215, 0x2cd1372d21d77
```
After that, 13 rounds of chaos basically happen, with a lot of bitwise operations, and h is calculated from rv1.
```python
for i in range(13):
    x, y, z, u = blocks[i] ^ x, blocks[i+1] ^ y, blocks[i+2] ^ z, blocks[i+3] ^ u
    rv1 ^= (x := (x & m) * (m + (y >> 16)) ^ rotl(z, 3))
    rv2 ^= (y := (y & m) * (m + (z >> 16)) ^ rotl(x, 3))
    rv1, rv2 = rv2, rv1
    rv1 = sub(rv1)
    rv1 = bytes_to_long(rv1)

h = rv1 + 0x6276137d7 & m
```
Finally, key2 is passed to sub, and some bitwise operations are performed with h and the substituted key2 bytes:
```python
for i, d in enumerate(key2):
    a = (h << 1) & m
    b = (h << 3) & m
    c = (h >> 4) & m
    h ^= (a + b + c - d)
    h += h
    h &= m

h *= u * z
h &= m

return h
```
A small recap of all the interesting stuff we found out until now:
- key1 is only used in the first half, to calculate the value of h before key2 is mixed in.
- key2 and h are then combined for the final hash with some simpler bitwise operations.
Finding the vulnerability
There are many things wrong with this hash function, but our goal is simply to make it output 0. If we break it down into simpler problems, we can see that it has parts that can be solved with z3 and a little bit of brute forcing.
Exploitation
Connecting to the server
A pretty basic script for connecting to the server with pwntools:
```python
from pwn import *

if __name__ == '__main__':
    r = remote('0.0.0.0', 1337)  # r is used as a global by the helpers below
    pwn()
```
Getting the first point
When we connect to the server, it sends us a block of memory that needs validation. In order to fetch it from the server we can use:
```python
def getBlockToValidate():
    debug_msg = r.recvline()
    block = debug_msg.decode().strip()[len('DEBUG MSG - You need to validate this memory block: '):]
    return block
```
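The slicing can be tried offline on a raw line. This is a small variant of getBlockToValidate under a hypothetical name (parse_block) that takes the bytes directly and adds a prefix check, which is an addition of ours: if the number of skipped lines ever gets out of sync with the server, it fails loudly instead of silently hashing a wrong string. The sample line is made up in the format the server's main() sends.

```python
PREFIX = "DEBUG MSG - You need to validate this memory block: "


def parse_block(debug_msg):
    # Same slicing as getBlockToValidate, but on a raw line (bytes).
    line = debug_msg.decode().strip()
    if not line.startswith(PREFIX):
        raise ValueError(f"unexpected message: {line!r}")
    return line[len(PREFIX):]


# Hypothetical line in the server's format (second round, keys already embedded).
sample = b"DEBUG MSG - You need to validate this memory block: mem-A (3, 1234). mem-B\n"
print(parse_block(sample))  # mem-A (3, 1234). mem-B
```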
z3 magic
As hinted above, the hash function needs to be dissected into sections so that z3 only has to deal with a solvable system.
More specifically, we can treat the entire first half of the hash function as a black box that produces h, and focus only on the second one.
```python
h = rv1 + 0x6276137d7 & m
key2 = sub(key2)

for i, d in enumerate(key2):
    a = (h << 1) & m
    b = (h << 3) & m
    c = (h >> 4) & m
    h ^= (a + b + c - d)
    h += h
    h &= m

h *= u * z
h &= m
```
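How can we simplify the problem further for z3?
For a fixed pair (block, key1), h is deterministic, so we can compute it locally.
We can also see that sub(key2) is reversible (the AES S-box is a permutation), so we can solve for the substituted bytes directly and map them back to key2 afterwards. Finally, the last 2 lines are irrelevant for our goal: if h is already 0 before h *= u * z, it stays 0.
So the problem becomes:
If we control key2 and h is a known value, how can we make h = 0 after the following bitwise operations?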
```python
for i, d in enumerate(key2):
    a = (h << 1) & m
    b = (h << 3) & m
    c = (h >> 4) & m
    h ^= (a + b + c - d)
    h += h
    h &= m
```
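To get a feel for this reduced problem, here is a sketch of a single iteration of the loop in isolation (mix_round and last_byte_for are hypothetical helper names, the round body is copied from phash). It brute forces only the final byte: given the h entering the last iteration, which d in 0..255 sends it straight to 0?

```python
M = 0xffffffff


def mix_round(h, d):
    # One iteration of the key2 loop from phash.
    a = (h << 1) & M
    b = (h << 3) & M
    c = (h >> 4) & M
    h ^= (a + b + c - d)
    h += h
    return h & M


def last_byte_for(h):
    # Brute force the final byte: which d (0..255) sends h straight to 0?
    for d in range(256):
        if mix_round(h, d) == 0:
            return d
    return None


print(last_byte_for(1))           # 9
print(last_byte_for(16))          # 145
print(last_byte_for(0x80000000))  # None
```

Working through the algebra, the closing h += h means only the low 31 bits of h ^ (a + b + c - d) matter, so the last byte is forced: d ≡ 9h + (h >> 4) (mod 2^31), and for most values of h no byte in 0..255 satisfies this. z3's job is to steer the earlier bytes so the final h lands in that narrow set; one plausible intuition for why some key1 values end up unsat is that no choice of earlier bytes reaches it.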
z3 to the rescue. Let’s create a model:
- Create a solver.
- Create the bit vectors.
- Add our restrictions.
- Solve the model if it is solvable.
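```python
def findSecondKeyWithZ3(block, key1):
    # phashFirstHalf, phashSecondHalf and rotr are defined in the full solver.
    s = Solver()

    # One 32-bit variable per byte of sub(key2): we look for a 6-byte key.
    xs = list(BitVecs('c0 c1 c2 c3 c4 c5', 32))
    h = BitVec('hs', 32)

    target_h = phashFirstHalf(block, key1)
    s.add(h == target_h)
    # A leading 0x63 (= sbox[0]) would become a leading 0x00 after isub,
    # which long_to_bytes drops on the server, shortening key2 to 5 bytes.
    s.add(xs[0] != 0x63)

    for i, e in enumerate(xs):
        s.add(ULE(e, 0xff))  # unsigned comparison: each e is a byte 0..255
        # rotr is really a logical right shift, matching (h >> 4) & m.
        h ^= ((h << 1) + (h << 3) + rotr(h, 4) - e)
        h += h  # 32-bit bit-vectors wrap, matching the & m in phash

    s.add(h == 0)

    if s.check() == sat:
        m = s.model()
        key2_bytes = bytes([m[c].as_long() for c in xs])

        # Double-check the model against the plain-Python second half.
        assert phashSecondHalf(key2_bytes, target_h) == 0
        return key2_bytes
    else:
        return 'unsat'
```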
Brute forcing key1
You might be wondering why findSecondKeyWithZ3() takes key1 as an input. Well, not every key1 creates a solvable system for z3, so we need to brute force it a little bit. But don’t worry, it’s a tiny brute force.
```python
def findKeys(block):
    first_key = 0
    while 1:
        second_key = findSecondKeyWithZ3(block, first_key)
        if (second_key != 'unsat'):
            return(first_key, bytes_to_long(isub(second_key)))
        first_key += 1
```
Sending the keys
```python
def sendKeys(first_key, second_key):
    r.sendlineafter(b'DEBUG MSG - Enter first key: ', bytes(str(first_key), 'Latin'))
    r.sendlineafter(b'DEBUG MSG - Enter second key: ', bytes(str(second_key), 'Latin'))
```
Getting the flag
A final recap of all the above:
- We connected to the server.
- We got the block of memory.
- We found out that a section of our hash function can be solved with z3.
- We brute forced key1 such that the z3 system would be solvable.
- We ran our solver and found key2.
- We finally sent the keys and got the next block.
- Repeating this process 4 times gets us the flag :)
This recap can be represented in code by the pwn() function:
```python
def pwn():
    for _ in range(4):
        skipMessages()
        block = getBlockToValidate()
        print(f"Fetched block of memory that needs validation")
        first_key, second_key = findKeys(block)
        print(f"Found valid keys: {first_key}, {second_key} for block: {block}")
        sendKeys(first_key, second_key)
        print("Moving on to the next block")
    r.interactive()
```
Challenge Code
The server code is presented below:
```python
import socketserver
import signal
from pofwork import phash


DEBUG_MSG = "DEBUG MSG - "
WELCOME_MSG = """Virgil says:
Klaus I'm connecting the serial debugger to your memory.
Please stay still. We don't want anything wrong to happen.
Ok you should be able to see debug messages now..\n\n"""


with open('memories.txt', 'r') as f:
    MEMORIES = [m.strip() for m in f.readlines()]


class Handler(socketserver.BaseRequestHandler):
    def handle(self):
        signal.alarm(0)
        main(self.request)


class ReusableTCPServer(socketserver.ForkingMixIn, socketserver.TCPServer):
    pass


def sendMessage(s, msg):
    s.send(msg.encode())


def recieveMessage(s, msg):
    sendMessage(s, msg)
    return s.recv(4096).decode().strip()


def main(s):
    block = ""
    counter = 0
    sendMessage(s, WELCOME_MSG)

    while True:
        block += MEMORIES[counter]

        sendMessage(s, DEBUG_MSG +
                    f"You need to validate this memory block: {block}\n")

        first_key = recieveMessage(s, DEBUG_MSG + "Enter first key: ")
        second_key = recieveMessage(s, DEBUG_MSG + "Enter second key: ")

        try:
            first_key, second_key = int(first_key), int(second_key)
            proof_of_work = phash(block, first_key, second_key)
        except:
            sendMessage(s, "\nVirgil says: \n"
                        "Be carefull Klaus!! You don't want to damage yourself.\n"
                        "Let's start over.")
            exit()

        if proof_of_work == 0:
            block += f" ({first_key}, {second_key}). "
            sendMessage(s, "\nVirgil says: \nWow you formed a new memory!!\n")
            counter += 1
            sendMessage(
                s, f"Let's try again {4 - counter} times just to be sure!\n\n")
        else:
            sendMessage(s, DEBUG_MSG + f"Incorect proof of work\n"
                        "\nVirgil says: \n"
                        "You calculated something wrong Klaus we need to start over.")
            exit()

        if counter == 4:
            sendMessage(s, "It seems that everything are working fine.\n"
                        "Wait what is that...\n"
                        "Klaus this is important!!\n"
                        "This can help you find your father!!\n"
                        f"{MEMORIES[-1]}")
            exit()


if __name__ == '__main__':
    socketserver.TCPServer.allow_reuse_address = True
    server = ReusableTCPServer(("0.0.0.0", 1337), Handler)
    server.serve_forever()
```
The pofwork module:
```python
from hashlib import md5
from Crypto.Util.number import long_to_bytes, bytes_to_long


sbox = [
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16
]


def rotl(n, b):
    return ((n << b) | (n >> (32 - b))) & 0xffffffff


def sub(b):
    b = long_to_bytes(b)
    return bytes([sbox[i] for i in b])


def phash(block, key1, key2):
    block = md5(block.encode()).digest()
    block = 4 * block
    blocks = [bytes_to_long(block[i:i+4]) for i in range(0, len(block), 4)]

    m = 0xffffffff
    rv1, rv2 = 0x2423380b4d045, 0x3b30fa7ccaa83
    x, y, z, u = key1, 0x39ef52e9f30b3, 0x253ea615d0215, 0x2cd1372d21d77

    for i in range(13):
        x, y = blocks[i] ^ x, blocks[i+1] ^ y
        z, u = blocks[i+2] ^ z, blocks[i+3] ^ u
        rv1 ^= (x := (x & m) * (m + (y >> 16)) ^ rotl(z, 3))
        rv2 ^= (y := (y & m) * (m + (z >> 16)) ^ rotl(x, 3))
        rv1, rv2 = rv2, rv1
        rv1 = sub(rv1)
        rv1 = bytes_to_long(rv1)

    h = rv1 + 0x6276137d7 & m
    key2 = sub(key2)

    for i, d in enumerate(key2):
        a = (h << 1) & m
        b = (h << 3) & m
        c = (h >> 4) & m
        h ^= (a + b + c - d)
        h += h
        h &= m

    h *= u * z
    h &= m

    return h
```
Solver
```python
import random
from hashlib import md5
from Crypto.Util.number import *
from z3 import *
from pwn import *


sbox = [
    0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
    0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
    0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
    0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
    0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
    0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
    0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
    0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
    0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
    0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
    0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
    0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
    0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
    0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
    0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
    0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16
]


def rotl(n, b):
    return ((n << b) | (n >> (32 - b))) & 0xffffffff


def rotr(v, n):
    # Despite the name, this is a logical right shift (no rotation): the mask
    # clears the bits that z3's arithmetic >> would sign-extend.
    return (v >> n) & ((1 << (32 - n)) - 1)


def sub(b):
    b = long_to_bytes(b)
    return bytes([sbox[i] for i in b])


def isub(b):
    # Inverse of the S-box lookup (the AES S-box is a permutation).
    return bytes([sbox.index(i) for i in b])


def phashFirstHalf(text, key1):
    # Everything in phash up to (and including) the computation of h.
    text = md5(text.encode()).digest()
    text = 4 * text
    text = [int.from_bytes(text[i:i+4], 'big') for i in range(0, len(text), 4)]

    m = 0xffffffff
    rv1, rv2 = 0x2423380b4d045, 0x3b30fa7ccaa83
    x, y, z, u = key1, 0x39ef52e9f30b3, 0x253ea615d0215, 0x2cd1372d21d77

    for i in range(13):
        x, y, z, u = text[i] ^ x, text[i+1] ^ y, text[i+2] ^ z, text[i+3] ^ u
        rv1 ^= (x := (x & m) * (m + (y >> 16)) ^ rotl(z, 3))
        rv2 ^= (y := (y & m) * (m + (z >> 16)) ^ rotl(x, 3))
        rv1, rv2 = rv2, rv1
        rv1 = sub(rv1)
        rv1 = bytes_to_long(rv1)

    h = rv1 + 0x6276137d7 & m
    return h


def phashSecondHalf(s, target_h):
    # The key2 loop of phash, run on already-substituted bytes s.
    h = target_h
    for i, d in enumerate(s):
        a = (h << 1) & 0xffffffff
        b = (h << 3) & 0xffffffff
        c = (h >> 4) & 0xffffffff
        h ^= (a + b + c - d)
        h += h
        h &= 0xffffffff
    return h


def skipMessages():
    # The welcome banner (first round) and Virgil's feedback (later rounds)
    # are both 5 lines long.
    for _ in range(5):
        r.recvline()


def getBlockToValidate():
    debug_msg = r.recvline()
    block = debug_msg.decode().strip()[len(
        'DEBUG MSG - You need to validate this memory block: '):]
    return block


def findSecondKeyWithZ3(block, key1):
    s = Solver()

    target_h = phashFirstHalf(block, key1)

    # One 32-bit variable per byte of sub(key2): we look for a 6-byte key.
    xs = list(BitVecs('c0 c1 c2 c3 c4 c5', 32))
    h = BitVec('hs', 32)

    s.add(h == target_h)
    # A leading 0x63 (= sbox[0]) would become a leading 0x00 after isub,
    # which long_to_bytes drops on the server, shortening key2 to 5 bytes.
    s.add(xs[0] != 0x63)

    for i, e in enumerate(xs):
        s.add(ULE(e, 0xff))  # unsigned comparison: each e is a byte 0..255
        h ^= ((h << 1) + (h << 3) + rotr(h, 4) - e)
        h += h

    s.add(h == 0)

    if s.check() == sat:
        m = s.model()
        key2_bytes = bytes([m[c].as_long() for c in xs])

        assert phashSecondHalf(key2_bytes, target_h) == 0
        return key2_bytes
    else:
        return 'unsat'


def findKeys(block):
    first_key = 0
    while 1:
        second_key = findSecondKeyWithZ3(block, first_key)
        if (second_key != 'unsat'):
            return(first_key, bytes_to_long(isub(second_key)))
        first_key += 1


def sendKeys(first_key, second_key):
    r.sendlineafter(b'DEBUG MSG - Enter first key: ',
                    bytes(str(first_key), 'Latin'))
    r.sendlineafter(b'DEBUG MSG - Enter second key: ',
                    bytes(str(second_key), 'Latin'))


def pwn():
    for _ in range(4):
        skipMessages()
        block = getBlockToValidate()
        print(f"Fetched block of memory that needs validation")
        first_key, second_key = findKeys(block)
        print(
            f"Found valid keys: {first_key}, {second_key} for block: {block}")
        sendKeys(first_key, second_key)
        print("Moving on to the next block")
    r.interactive()


if __name__ == '__main__':
    r = remote('localhost', 1337)
    pwn()
```