手搓zip解码器
jerem1ah Lv4

zlib.decompress()

Reference

https://pyokagan.name/blog/2019-10-18-zlibinflate/ //文档

https://zhidao.baidu.com/question/310628609.html //<<= >>=

https://zhuanlan.zhihu.com/p/352145413 //大端序 小端序

https://www.cnblogs.com/junyuhuang/p/4138376.html //lz77

https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-wusp/fb98aa28-5cd7-407f-8869-a6cef1ff1ccb //lz77

https://blog.csdn.net/lzq20115395/article/details/78906863 //huffman

https://zhuanlan.zhihu.com/p/103908133 //huffman

https://github.com/BuptMerak/mrctf-2022-writeups/blob/main/offical/MISC.md //mrctf jpeg and the tree

https://nano.ac/posts/e8de5668/#jpeg-and-the-tree //nano’s solve

https://room2042.gitlab.io/writeup/2021-07-24-google_ctf-david_and_the_tree/ //google huffman tree

https://zhuanlan.zhihu.com/p/72044095 //jpeg

https://room2042.gitlab.io/writeup/2021-07-24-google_ctf-david_and_the_tree/

https://formats.kaitai.io/zip/python.html

https://github.com/nayuki/Simple-DEFLATE-decompressor/blob/master/python/deflatedecompress.py

https://github.com/luker983/google-ctf-2021/tree/main/misc

step by step

The zlib container

The zlib format:

  • 1 byte: CMF — Compression Method and compression info fields
  • 1 byte: FLG — Compression flags fields
  • Variable number of bytes: The deflate data
  • 4 bytes: ADLER32 — Adler-32 checksum over the original uncompressed data
  • End of file/bytestring

The CMF fields are as follows:

  • Bits 0 to 3: CM — Compression Method. Only CM=8 is defined by the zlib spec.
  • Bits 4 to 7: CINFO — Compression info. This is the base-2 logarithm of the LZ77 window size, minus eight. In other words, the window size is 2^(CINFO+8). The maximum window size that is allowed by the spec is 32768, which is 2^15. In other words, CINFO must be <= 7.

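The FLG fields are as follows:

  • Bits 0 to 4: FCHECK — Check bits for CMF and FLG; (CMF * 256 + FLG) must be a multiple of 31.
  • Bit 5: FDICT — Preset dictionary flag.
  • Bits 6 to 7: FLEVEL — Compression level.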

1 0000 0000 = 256

1 0000 0001 = 257

285-255=30

1 1110 = 30

255 + 30 = 285

1 0001 1101 = 285

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class BitReader():
    """Sequential reader over a bytes-like object, bit-wise (LSB first) or byte-wise."""

    def __init__(self, mem):
        self.mem = mem      # indexable byte sequence being read
        self.pos = 0        # index of the next unread byte in mem
        self.b = 0          # unread remainder of the byte currently being consumed
        self.numbits = 0    # number of bits still unread in self.b

    def read_byte(self):
        """Return the next whole byte, discarding any unread bits of the current one."""
        self.numbits = 0
        b = self.mem[self.pos]
        self.pos += 1
        return b

    def read_bit(self):
        """Return the next bit, taking bits from each byte starting at the least significant."""
        if self.numbits <= 0:
            self.b = self.read_byte()
            self.numbits = 8
        self.numbits -= 1
        bit = self.b & 1
        self.b >>= 1
        return bit

    def read_bits(self, n):
        """Return the next n bits as an integer; the first bit read is the lowest bit."""
        o = 0
        for i in range(n):
            o |= self.read_bit() << i
        return o

    def read_bytes(self, n):
        """Return the next n whole bytes as a little-endian integer."""
        # `self` was missing from the signature, so the method could not be called.
        o = 0
        for i in range(n):
            o |= self.read_byte() << (i * 8)
        return o
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
class Node():
    """One node of a binary Huffman tree."""

    def __init__(self):
        self.symbol = ''    # decoded symbol; stays '' until a codeword ends here
        self.left = None    # child followed on a 0 bit
        self.right = None   # child followed on a 1 bit

class Huffman():
    """Binary Huffman tree mapping codewords to symbols."""

    def __init__(self):
        self.root = Node()

    def insert(self, codeword, n, symbol):
        """Store `symbol` at the node reached by the `n`-bit `codeword`.

        The codeword is walked from its most significant bit to its least
        significant bit: a 1 bit goes right, a 0 bit goes left. Missing
        nodes along the path are created.
        """
        node = self.root
        for i in range(n - 1, -1, -1):
            b = codeword & (1 << i)
            if b:
                if node.right is None:
                    node.right = Node()
                next_node = node.right
            else:
                if node.left is None:
                    node.left = Node()
                next_node = node.left
            node = next_node
        node.symbol = symbol

1
2
3
4
5
6
7
8
9
10
def code_to_bytes(code, n):
    """Pack an `n`-bit Huffman `code` into bytes in DEFLATE bit order.

    The code's most significant bit is emitted first and lands in the least
    significant free bit of the current output byte; a new byte is started
    after every 8 bits.

    Returns:
        bytes: the packed code (a single zero byte when n == 0).
    """
    out = [0]
    numbits = 0
    for i in range(n - 1, -1, -1):
        if numbits >= 8:
            out.append(0)
            numbits = 0
        # Normalise the tested bit to 0/1 before shifting it into place;
        # shifting the raw mask `code & (1 << i)` puts it at the wrong position.
        out[-1] |= (1 if code & (1 << i) else 0) << numbits
        numbits += 1
    return bytes(out)
1
2
3
4
5
6
def decode_symbol(r, t):
    """Decode one symbol from bit reader `r` using Huffman tree `t`.

    Walks down from `t.root`, reading one bit per level (1 = right,
    0 = left), until a node without children is reached, and returns
    that node's symbol.
    """
    node = t.root
    while node.left or node.right:
        b = r.read_bit()
        node = node.right if b else node.left
    return node.symbol
1
2
3
def bl_list_to_tree(bl, alphabet):
    # Excerpt only: just the first statement of the function is shown here.
    MAX_BITS = max(bl)  # longest code length present in bl

alphabet==ABCD

bit length==2 1 3 3

bit length count 长度为0的有0个,长度为1的有1个,长度为2的有1个,长度为3的有2个 === 0 1 1 2

next_code[n] is the smallest codeword with code length n.

next_code=[0,0]

range 2,3,

bl每个符号的树深

bl_count 每个深度的符号个数

next_code每个深度的最左值。

image-20240513194522584

1
2
3
4
5
6
7
8
9
10
11
12
def bl_list_to_tree(bl, alphabet):
    """Build a Huffman tree from a list of code lengths.

    Args:
        bl: code length (tree depth) of each symbol; 0 means the symbol is unused.
        alphabet: the symbols, paired with `bl` position by position.

    Returns:
        HuffmanTree holding one codeword per symbol with a non-zero length.
    """
    MAX_BITS = max(bl)
    # bl_count[n]: how many symbols have code length n (length 0 is never counted).
    # The original iterated over the int `MAX_BITS + 1` directly; `range` was missing.
    bl_count = [sum(1 for x in bl if x == y and y != 0) for y in range(MAX_BITS + 1)]
    # next_code[n]: the smallest codeword of length n.
    next_code = [0, 0]
    for bits in range(2, MAX_BITS + 1):
        next_code.append((next_code[bits - 1] + bl_count[bits - 1]) << 1)
    t = HuffmanTree()
    for c, bitlen in zip(alphabet, bl):
        if bitlen != 0:
            t.insert(next_code[bitlen], bitlen, c)
            # Symbols of equal length get consecutive codewords.
            next_code[bitlen] += 1
    return t

由字母表和长度表推算出每个符号的二进制码字,进而构建huffman树

Full source code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
class BitReader:
    """Sequential reader over a bytes-like object, bit-wise (LSB first) or byte-wise."""

    def __init__(self, mem):
        self.mem = mem      # indexable byte sequence being read
        self.pos = 0        # index of the next unread byte in mem
        self.b = 0          # unread remainder of the byte currently being consumed
        self.numbits = 0    # number of bits still unread in self.b

    def read_byte(self):
        """Return the next whole byte."""
        self.numbits = 0  # discard unread bits
        b = self.mem[self.pos]
        self.pos += 1
        return b

    def read_bit(self):
        """Return the next bit, least significant bit of each byte first."""
        if self.numbits <= 0:
            self.b = self.read_byte()
            self.numbits = 8
        self.numbits -= 1
        # shift bit out of byte
        bit = self.b & 1
        self.b >>= 1
        return bit

    def read_bits(self, n):
        """Return the next n bits as an integer; the first bit read is the lowest bit."""
        o = 0
        for i in range(n):
            o |= self.read_bit() << i
        return o

    def read_bytes(self, n):
        """Return the next n whole bytes as an integer."""
        # read bytes as an integer in little-endian
        o = 0
        for i in range(n):
            o |= self.read_byte() << (8 * i)
        return o

def decompress(input):
    """Decompress a zlib-wrapped DEFLATE stream and return the original bytes.

    Args:
        input: bytes-like object holding the complete zlib stream.

    Raises:
        Exception: if CM is not 8, CINFO exceeds 7, the CMF/FLG check
            fails, or a preset dictionary is requested.
    """
    r = BitReader(input)
    CMF = r.read_byte()
    CM = CMF & 15  # Compression method
    if CM != 8:  # only CM=8 is supported
        raise Exception('invalid CM')
    CINFO = (CMF >> 4) & 15  # Compression info
    if CINFO > 7:
        raise Exception('invalid CINFO')
    FLG = r.read_byte()
    if (CMF * 256 + FLG) % 31 != 0:
        raise Exception('CMF+FLG checksum failed')
    FDICT = (FLG >> 5) & 1  # preset dictionary?
    if FDICT:
        raise Exception('preset dictionary not supported')
    out = inflate(r)  # decompress DEFLATE data
    # Consume the 4-byte Adler-32 trailer; for this exercise it is not verified.
    r.read_bytes(4)
    return out

def inflate(r):
    """Decode DEFLATE blocks from bit reader `r` until the final block.

    Each block starts with a 1-bit BFINAL flag and a 2-bit BTYPE
    (0 = stored, 1 = fixed Huffman, 2 = dynamic Huffman).

    Returns:
        bytes: the concatenated output of all blocks.

    Raises:
        Exception: if BTYPE is 3.
    """
    BFINAL = 0
    out = []
    while not BFINAL:
        BFINAL = r.read_bit()
        BTYPE = r.read_bits(2)
        if BTYPE == 0:
            inflate_block_no_compression(r, out)
        elif BTYPE == 1:
            inflate_block_fixed(r, out)
        elif BTYPE == 2:
            inflate_block_dynamic(r, out)
        else:
            raise Exception('invalid BTYPE')
    return bytes(out)

def inflate_block_no_compression(r, o):
    """Copy one stored (uncompressed) block from reader `r` into list `o`.

    Raises:
        Exception: if NLEN is not the one's complement of LEN.
    """
    # read_bytes works on whole bytes, so the unread bits of the current
    # byte are dropped before LEN is read.
    LEN = r.read_bytes(2)
    NLEN = r.read_bytes(2)
    # NLEN was read but never checked; a mismatch means the block header is corrupt.
    if (LEN ^ 0xFFFF) != NLEN:
        raise Exception('invalid LEN/NLEN in stored block')
    o.extend(r.read_byte() for _ in range(LEN))

def code_to_bytes(code, n):
    """Encode a code that is `n` bits long into bytes conformant with the DEFLATE spec.

    The code's most significant bit is emitted first and lands in the least
    significant free bit of the current output byte; a new byte is started
    after every 8 bits.
    """
    out = [0]
    numbits = 0
    for i in range(n - 1, -1, -1):
        if numbits >= 8:
            out.append(0)
            numbits = 0
        out[-1] |= (1 if code & (1 << i) else 0) << numbits
        numbits += 1
    return bytes(out)

class Node:
    """One node of a binary Huffman tree."""

    def __init__(self):
        self.symbol = ''    # decoded symbol; stays '' until a codeword ends here
        self.left = None    # child followed on a 0 bit
        self.right = None   # child followed on a 1 bit

class HuffmanTree:
    """Binary Huffman tree mapping codewords to symbols."""

    def __init__(self):
        self.root = Node()
        self.root.symbol = ''

    def insert(self, codeword, n, symbol):
        """Insert an entry into the tree mapping `codeword` of len `n` to `symbol`.

        The codeword is walked from its most significant bit to its least
        significant bit (1 = right, 0 = left); missing nodes are created.
        """
        node = self.root
        for i in range(n - 1, -1, -1):
            b = codeword & (1 << i)
            if b:
                next_node = node.right
                if next_node is None:
                    node.right = Node()
                    next_node = node.right
            else:
                next_node = node.left
                if next_node is None:
                    node.left = Node()
                    next_node = node.left
            node = next_node
        node.symbol = symbol

def decode_symbol(r, t):
    """Decodes one symbol from bitstream `r` using HuffmanTree `t`.

    Walks down from `t.root`, reading one bit per level (1 = right,
    0 = left), until a node without children is reached.
    """
    node = t.root
    while node.left or node.right:
        b = r.read_bit()
        node = node.right if b else node.left
    return node.symbol

# Extra-bit counts and base values for the DEFLATE length symbols
# (29 entries each) and distance symbols (30 entries each).
LengthExtraBits = [
    0, 0, 0, 0, 0, 0, 0, 0,
    1, 1, 1, 1,
    2, 2, 2, 2,
    3, 3, 3, 3,
    4, 4, 4, 4,
    5, 5, 5, 5,
    0,
]
LengthBase = [
    3, 4, 5, 6, 7, 8, 9, 10,
    11, 13, 15, 17,
    19, 23, 27, 31,
    35, 43, 51, 59,
    67, 83, 99, 115,
    131, 163, 195, 227,
    258,
]
DistanceExtraBits = [
    0, 0, 0, 0,
    1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7,
    8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13,
]
DistanceBase = [
    1, 2, 3, 4,
    5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385,
    513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577,
]

def inflate_block_data(r, literal_length_tree, distance_tree, out):
    """Decode the symbols of one Huffman-coded block into list `out`.

    Literal symbols (0..255) are appended directly; symbol 256 ends the
    block; larger symbols start a <length, distance> back-reference that
    copies already-produced output.
    """
    while True:
        sym = decode_symbol(r, literal_length_tree)
        if sym <= 255:  # Literal byte
            out.append(sym)
        elif sym == 256:  # End of block
            return
        else:  # <length, backward distance> pair
            sym -= 257
            length = r.read_bits(LengthExtraBits[sym]) + LengthBase[sym]
            dist_sym = decode_symbol(r, distance_tree)
            dist = r.read_bits(DistanceExtraBits[dist_sym]) + DistanceBase[dist_sym]
            # Copy one element at a time: when length > dist the copy
            # re-reads elements appended earlier in this same loop.
            for _ in range(length):
                out.append(out[-dist])

def bl_list_to_tree(bl, alphabet):
    """Build a Huffman tree from a list of code lengths.

    Args:
        bl: code length of each symbol; 0 means the symbol is unused.
        alphabet: the symbols, paired with `bl` position by position.

    Returns:
        HuffmanTree holding one codeword per symbol with a non-zero length.
    """
    MAX_BITS = max(bl)
    # bl_count[n]: how many symbols have code length n (length 0 is never
    # counted). Built in one pass instead of rescanning bl for every length.
    bl_count = [0] * (MAX_BITS + 1)
    for bitlen in bl:
        if bitlen != 0:
            bl_count[bitlen] += 1
    # next_code[n]: the next codeword to hand out for length n.
    next_code = [0, 0]
    for bits in range(2, MAX_BITS + 1):
        next_code.append((next_code[bits - 1] + bl_count[bits - 1]) << 1)
    t = HuffmanTree()
    for c, bitlen in zip(alphabet, bl):
        if bitlen != 0:
            t.insert(next_code[bitlen], bitlen, c)
            next_code[bitlen] += 1
    return t

# Order in which the code lengths of the code-length alphabet are stored
# in a dynamic block header.
CodeLengthCodesOrder = [
    16, 17, 18, 0,
    8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15,
]

def decode_trees(r):
    """Read a dynamic-Huffman block header from `r` and build its two trees.

    Returns:
        tuple: (literal_length_tree, distance_tree)

    Raises:
        Exception: if the code-length tree yields a symbol outside 0..18.
    """
    # The number of literal/length codes
    HLIT = r.read_bits(5) + 257

    # The number of distance codes
    HDIST = r.read_bits(5) + 1

    # The number of code length codes
    HCLEN = r.read_bits(4) + 4

    # Read code lengths for the code length alphabet
    code_length_tree_bl = [0 for _ in range(19)]
    for i in range(HCLEN):
        code_length_tree_bl[CodeLengthCodesOrder[i]] = r.read_bits(3)

    # Construct code length tree
    code_length_tree = bl_list_to_tree(code_length_tree_bl, range(19))

    # Read literal/length + distance code length list
    bl = []
    while len(bl) < HLIT + HDIST:
        sym = decode_symbol(r, code_length_tree)
        if 0 <= sym <= 15:  # literal value
            bl.append(sym)
        elif sym == 16:
            # copy the previous code length 3..6 times.
            # the next 2 bits indicate repeat length ( 0 = 3, ..., 3 = 6 )
            prev_code_length = bl[-1]
            repeat_length = r.read_bits(2) + 3
            bl.extend(prev_code_length for _ in range(repeat_length))
        elif sym == 17:
            # repeat code length 0 for 3..10 times. (3 bits of length)
            repeat_length = r.read_bits(3) + 3
            bl.extend(0 for _ in range(repeat_length))
        elif sym == 18:
            # repeat code length 0 for 11..138 times. (7 bits of length)
            repeat_length = r.read_bits(7) + 11
            bl.extend(0 for _ in range(repeat_length))
        else:
            raise Exception('invalid symbol')

    # Construct trees
    literal_length_tree = bl_list_to_tree(bl[:HLIT], range(286))
    distance_tree = bl_list_to_tree(bl[HLIT:], range(30))
    return literal_length_tree, distance_tree

def inflate_block_dynamic(r, o):
    """Decode one dynamic-Huffman block from `r`, appending its output to `o`."""
    literal_length_tree, distance_tree = decode_trees(r)
    inflate_block_data(r, literal_length_tree, distance_tree, o)

def inflate_block_fixed(r, o):
    """Decode one fixed-Huffman block from `r`, appending its output to `o`."""
    # Fixed literal/length code lengths: symbols 0..143 use 8 bits,
    # 144..255 use 9, 256..279 use 7 and 280..287 use 8.
    bl = [8] * 144 + [9] * (256 - 144) + [7] * (280 - 256) + [8] * (288 - 280)
    # All 288 lengths are counted when assigning codewords, but zip() pairs
    # only the first 286 with symbols, so 286 and 287 never enter the tree.
    literal_length_tree = bl_list_to_tree(bl, range(286))

    # All 30 distance symbols use 5-bit codes.
    bl = [5] * 30
    distance_tree = bl_list_to_tree(bl, range(30))

    inflate_block_data(r, literal_length_tree, distance_tree, o)

import zlib
# Round-trip check: compress with the standard zlib module, then decode
# the result with the hand-written decompress().
x = zlib.compress(b'Hello World!')
print(decompress(x)) # b'Hello World!'

google ctf solution 1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#!/usr/bin/env python3
import io

import deflatedecompress # modified from https://github.com/nayuki/Simple-DEFLATE-decompressor/blob/master/python/deflatedecompress.py
import zip # from https://formats.kaitai.io/zip/python.html


# Parse the archive with the Kaitai-generated `zip` module imported above.
parsed_zip = zip.Zip.from_file('challenge.zip')

flag = []
for section_number, section in enumerate(parsed_zip.sections):
    # Only local-file sections carry compressed data; skip everything else.
    if isinstance(section.body, zip.Zip.LocalFile):
        deflated_file = section.body.body
    else:
        continue

    # Decompressing populates the code table read below; the inflated
    # bytes themselves are not used.
    deflated_buf = deflatedecompress.BitInputStream(io.BytesIO(deflated_file))
    inflated_buf = deflatedecompress.Decompressor.decompress_to_bytes(deflated_buf)

    # NOTE(review): `hacky` is presumably a hook added in the modified
    # deflatedecompress module to expose the last Huffman code table -- confirm.
    table = deflatedecompress.hacky._code_bits_to_symbol

    for k, v in table.items():
        if v == ord('E'):
            # keep only the low 8 bits of the code assigned to 'E'
            k = k & 0xFF

            # reverse bit order
            rk = 0
            for _ in range(8):
                rk = rk << 1
                if k & 1 == 1:
                    rk = rk | 1
                k = k >> 1

            flag.append(rk)

print(bytes(flag).decode())

google ctf solution 2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
# huffman tree decoding adapted from Paul Tan's source code at https://pyokagan.name/blog/2019-10-18-zlibinflate/
import zipfile
import networkx as nx

# extract raw compressed data from zip file
def carve(zip_name, metadata_size):
    """Return the raw (still compressed) data of every member of a zip file.

    Args:
        zip_name: path of the zip archive.
        metadata_size: number of bytes to skip from each member's
            `header_offset` to reach the start of its compressed data.

    Returns:
        list of bytes objects, one per member, in `infolist()` order.
    """
    # list of tuples, (offset, size)
    files = []
    # parse zip file
    with zipfile.ZipFile(zip_name, 'r') as zip_file:
        # iterate over each file in zip
        for elem in zip_file.infolist():
            offset = elem.header_offset + metadata_size
            compress_size = elem.compress_size
            files.append((offset, compress_size))

    compressed_data = []
    with open(zip_name, 'rb') as raw_zip:
        for offset, size in files:
            # seek to offset and read compressed_size bytes
            raw_zip.seek(offset)
            compressed_data.append(raw_zip.read(size))

    return compressed_data

# parse raw compressed data to get huffman tree (ignore distance tree)
def get_huffman_tree(raw):
    """Parse the first block header of raw DEFLATE data and return its literal/length tree.

    The distance tree is built as well but discarded.
    """
    r = BitReader(raw)
    # read BFINAL and BTYPE; the values are unused, but the 3 header bits
    # must be consumed before the tree definitions
    BFINAL = r.read_bit()
    BTYPE = r.read_bits(2)
    # NOTE(review): BTYPE is not checked -- this assumes a dynamic-Huffman block.
    literal_length_tree, distance_tree = decode_trees(r)
    return literal_length_tree

### from https://pyokagan.name/blog/2019-10-18-zlibinflate/ ###
class BitReader:
    """Sequential reader over a bytes-like object, bit-wise (LSB first) or byte-wise."""

    def __init__(self, mem):
        self.mem = mem      # indexable byte sequence being read
        self.pos = 0        # index of the next unread byte in mem
        self.b = 0          # unread remainder of the byte currently being consumed
        self.numbits = 0    # number of bits still unread in self.b

    def read_byte(self):
        """Return the next whole byte."""
        self.numbits = 0  # discard unread bits
        b = self.mem[self.pos]
        self.pos += 1
        return b

    def read_bit(self):
        """Return the next bit, least significant bit of each byte first."""
        if self.numbits <= 0:
            self.b = self.read_byte()
            self.numbits = 8
        self.numbits -= 1
        # shift bit out of byte
        bit = self.b & 1
        self.b >>= 1
        return bit

    def read_bits(self, n):
        """Return the next n bits as an integer; the first bit read is the lowest bit."""
        o = 0
        for i in range(n):
            o |= self.read_bit() << i
        return o

class Node:
    """One node of a binary Huffman tree."""

    def __init__(self):
        self.symbol = ''    # decoded symbol; stays '' until a codeword ends here
        self.left = None    # child followed on a 0 bit
        self.right = None   # child followed on a 1 bit

class HuffmanTree:
    """Huffman tree that can be exported as a graph and prints the code of symbol 69 ('E')."""

    def __init__(self):
        self.root = Node()

    def make_graph(self):
        """Return the tree as a networkx DiGraph with edges labelled '0'/'1'."""
        g = nx.DiGraph()
        g = self.walk_graph(self.root, None, g)
        return g

    def make_tree(self):
        """Return the tree copied into a `Tree` object."""
        # NOTE(review): `Tree` is neither defined nor imported in this
        # listing (presumably treelib.Tree) -- confirm before calling.
        tree = Tree()
        tree = self.walk(self.root, None, tree)
        return tree

    def walk_graph(self, node, parent, g, edge_label=None):
        """Recursively add `node` and its subtree to graph `g`; returns `g`."""
        if not parent:
            g.add_node(id(node), label='root')
        else:
            # The label is the node's symbol ('' for nodes without one).
            label = node.symbol
            g.add_node(id(node), label=label)
            g.add_edge(id(parent), id(node), label=edge_label)
        if node.left:
            self.walk_graph(node.left, node, g, '0')
        if node.right:
            self.walk_graph(node.right, node, g, '1')
        return g

    def walk(self, node, parent, tree):
        """Recursively add `node` and its subtree to `tree`, keyed by symbol; returns `tree`."""
        if not parent:
            tree.create_node(node.symbol, node.symbol)
        else:
            tree.create_node(node.symbol, node.symbol, parent=parent.symbol)
        if node.left:
            self.walk(node.left, node, tree)
        if node.right:
            self.walk(node.right, node, tree)
        return tree

    def insert(self, codeword, n, symbol):
        """Insert an entry into the tree mapping `codeword` of len `n` to `symbol`.

        When `symbol` is 69 ('E'), the bit path taken is also printed
        (reversed and interpreted as a character, without a newline).
        """
        node = self.root

        # if inserting symbol 69 ('E'), follow bit path
        p = False
        bits = b''
        if symbol == 69:
            p = True
        for i in range(n - 1, -1, -1):
            b = codeword & (1 << i)
            if b:
                bits += b'1'
                next_node = node.right
                if next_node is None:
                    node.right = Node()
                    next_node = node.right
            else:
                bits += b'0'
                next_node = node.left
                if next_node is None:
                    node.left = Node()
                    next_node = node.left
            node = next_node
        # print bit path in reverse to get character of flag
        if p:
            print(chr(int(bits[::-1], 2)), end='')

        node.symbol = symbol

def decode_symbol(r, t):
    """Decodes one symbol from bitstream `r` using HuffmanTree `t`.

    Walks down from `t.root`, reading one bit per level (1 = right,
    0 = left), until a node without children is reached.
    """
    node = t.root
    while node.left or node.right:
        b = r.read_bit()
        node = node.right if b else node.left
    return node.symbol

# Extra-bit counts and base values for the DEFLATE length symbols
# (29 entries each) and distance symbols (30 entries each).
LengthExtraBits = [
    0, 0, 0, 0, 0, 0, 0, 0,
    1, 1, 1, 1,
    2, 2, 2, 2,
    3, 3, 3, 3,
    4, 4, 4, 4,
    5, 5, 5, 5,
    0,
]
LengthBase = [
    3, 4, 5, 6, 7, 8, 9, 10,
    11, 13, 15, 17,
    19, 23, 27, 31,
    35, 43, 51, 59,
    67, 83, 99, 115,
    131, 163, 195, 227,
    258,
]
DistanceExtraBits = [
    0, 0, 0, 0,
    1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7,
    8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13,
]
DistanceBase = [
    1, 2, 3, 4,
    5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193, 257, 385,
    513, 769, 1025, 1537, 2049, 3073, 4097, 6145, 8193, 12289, 16385, 24577,
]

def bl_list_to_tree(bl, alphabet):
    """Build a Huffman tree from a list of code lengths.

    Args:
        bl: code length of each symbol; 0 means the symbol is unused.
        alphabet: the symbols, paired with `bl` position by position.

    Returns:
        HuffmanTree holding one codeword per symbol with a non-zero length.
    """
    MAX_BITS = max(bl)
    # bl_count[n]: how many symbols have code length n (length 0 is never
    # counted). Built in one pass instead of rescanning bl for every length.
    bl_count = [0] * (MAX_BITS + 1)
    for bitlen in bl:
        if bitlen != 0:
            bl_count[bitlen] += 1
    # next_code[n]: the next codeword to hand out for length n.
    next_code = [0, 0]
    for bits in range(2, MAX_BITS + 1):
        next_code.append((next_code[bits - 1] + bl_count[bits - 1]) << 1)
    t = HuffmanTree()
    # The leftover debug bookkeeping (`test` / `a`, the literals absent from
    # the tree) was computed but never used, so it has been removed.
    for c, bitlen in zip(alphabet, bl):
        if bitlen != 0:
            t.insert(next_code[bitlen], bitlen, c)
            next_code[bitlen] += 1
    return t

# Order in which the code lengths of the code-length alphabet are stored
# in a dynamic block header.
CodeLengthCodesOrder = [
    16, 17, 18, 0,
    8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15,
]

def decode_trees(r):
    """Read a dynamic-Huffman block header from `r` and build its two trees.

    Returns:
        tuple: (literal_length_tree, distance_tree)

    Raises:
        Exception: if the code-length tree yields a symbol outside 0..18.
    """
    # The number of literal/length codes
    HLIT = r.read_bits(5) + 257

    # The number of distance codes
    HDIST = r.read_bits(5) + 1

    # The number of code length codes
    HCLEN = r.read_bits(4) + 4

    # Read code lengths for the code length alphabet
    code_length_tree_bl = [0 for _ in range(19)]
    for i in range(HCLEN):
        code_length_tree_bl[CodeLengthCodesOrder[i]] = r.read_bits(3)

    # Construct code length tree
    code_length_tree = bl_list_to_tree(code_length_tree_bl, range(19))

    # Read literal/length + distance code length list
    bl = []
    while len(bl) < HLIT + HDIST:
        sym = decode_symbol(r, code_length_tree)
        if 0 <= sym <= 15:  # literal value
            bl.append(sym)
        elif sym == 16:
            # copy the previous code length 3..6 times.
            # the next 2 bits indicate repeat length ( 0 = 3, ..., 3 = 6 )
            prev_code_length = bl[-1]
            repeat_length = r.read_bits(2) + 3
            bl.extend(prev_code_length for _ in range(repeat_length))
        elif sym == 17:
            # repeat code length 0 for 3..10 times. (3 bits of length)
            repeat_length = r.read_bits(3) + 3
            bl.extend(0 for _ in range(repeat_length))
        elif sym == 18:
            # repeat code length 0 for 11..138 times. (7 bits of length)
            repeat_length = r.read_bits(7) + 11
            bl.extend(0 for _ in range(repeat_length))
        else:
            raise Exception('invalid symbol')

    # Construct trees
    literal_length_tree = bl_list_to_tree(bl[:HLIT], range(286))
    distance_tree = bl_list_to_tree(bl[HLIT:], range(30))
    return literal_length_tree, distance_tree


if __name__ == "__main__":
    # get raw compressed data from challenge.zip
    # NOTE(review): 0x24 is presumably the size of each local file header
    # including its file name -- confirm against the archive.
    challenge_data = carve('challenge.zip', 0x24)

    # extract huffman tree, inserter will print when 'E' character is inserted
    for f in challenge_data:
        get_huffman_tree(f)

last

mastered 80%. learn for a year…

google ctf 2021, mrctf 2022

todo

  • next_code = [0,0] why
  • zip.compress() and question script
  • jpeg and the tree…and png decode
 Comments