-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmessage.py
More file actions
214 lines (162 loc) · 6.22 KB
/
message.py
File metadata and controls
214 lines (162 loc) · 6.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
from queue import Queue
from enum import Enum
from struct import unpack, pack
import constants
"""Handle BitTorrent Protocol message parsing duties"""
class MessageException(Exception):
    """Error raised when peer data cannot be interpreted as a message."""
class MessageParser:
    """Incrementally turn raw bytes received from a peer into messages.

    Every message on the wire starts with a 4-byte big-endian length
    prefix, followed by exactly that many bytes (one id byte plus an
    optional payload; a length of zero is a keep-alive).

    A chunk read from the peer is not aligned to message boundaries: it
    may hold several messages, and it may end part-way through a message
    or even part-way through a length prefix. Whatever cannot be decoded
    yet is kept in ``incomplete_message`` and prepended to the next chunk.
    """

    # Size in bytes of the length prefix that starts every message.
    _LENGTH_PREFIX_SIZE = 4

    def __init__(self):
        # Unconsumed tail of the previous chunk (partial prefix or message).
        self.incomplete_message = b''
        self.counter = 0

    def __call__(self, bytestring):
        """Return a generator yielding every complete message in the stream.

        ``bytestring`` is the next chunk of raw bytes from the peer. Any
        trailing partial message (including a partial length prefix) is
        buffered for the next call instead of being treated as an error.

        This is a generator: nothing is parsed or buffered until it is
        iterated, and it must be exhausted for the unread tail of the
        chunk to be preserved.

        Raises ValueError (from ``MessageType``) if a message carries an
        id byte that is not a known message type.
        """
        if self.incomplete_message:
            bytestring = self.incomplete_message + bytestring
            self.incomplete_message = b''

        # Walk the buffer with an offset rather than re-slicing the whole
        # remainder after every message.
        offset = 0
        total = len(bytestring)
        while offset < total:
            body_start = offset + self._LENGTH_PREFIX_SIZE
            if body_start > total:
                # The chunk ended inside the length prefix itself: keep the
                # fragment and wait for the rest.
                self.incomplete_message = bytestring[offset:]
                break

            length = unpack('!I', bytestring[offset:body_start])[0]
            body_end = body_start + length
            if body_end > total:
                # Prefix is complete but the body is not: keep prefix + body.
                self.incomplete_message = bytestring[offset:]
                break

            offset = body_end
            yield self._parse_message(bytestring[body_start:body_end])

    @staticmethod
    def _parse_message(bytestring):
        """Build the Message object for one message body (prefix removed).

        An empty body is a keep-alive. Otherwise the first byte is the
        message id and the remaining bytes, if any, are the payload.
        """
        if not bytestring:
            return Message.factory(MessageType.KEEP_ALIVE)

        type_id = int.from_bytes(bytestring[:1], byteorder='big')
        payload = bytestring[1:]
        # Raises ValueError for ids that MessageType does not define.
        message_type = MessageType(type_id)
        if payload:
            return Message.factory(message_type, payload)
        return Message.factory(message_type)
class MessageType(Enum):
    """Kinds of message handled by the message classes and the parser.

    Non-negative values are the ids that MessageParser reads from the
    single id byte of a message (``MessageType(type_id)``). Negative
    values can never come out of that one-byte decode, so they are only
    ever assigned explicitly by the code.
    """
    # Not produced anywhere in the visible code; presumably a local
    # "connection closed" signal -- TODO confirm against callers.
    CLOSE = -3
    # Assigned by HandShakeMessage, whose to_bytes() returns the raw payload.
    HANDSHAKE = -2
    # Assigned by MessageParser for a message whose body is empty.
    KEEP_ALIVE = -1
    # The ids below are matched against the id byte of an incoming message;
    # presumably the standard BitTorrent peer-wire ids -- verify against the spec.
    CHOKE = 0
    UNCHOKE = 1
    INTERESTED = 2
    UNINTERESTED = 3
    HAVE = 4
    BITFIELD = 5
    REQUEST = 6
    PIECE = 7
class Message:
    """A single peer message: a MessageType plus an optional payload."""

    @staticmethod
    def factory(message_type, raw_payload=None):
        """Create the most specific Message object for ``message_type``.

        PIECE and HANDSHAKE get their dedicated subclasses; every other
        type is represented by a plain Message.
        """
        if message_type == MessageType.PIECE:
            return PieceMessage(raw_payload)
        if message_type == MessageType.HANDSHAKE:
            return HandShakeMessage(raw_payload)
        return Message(message_type, raw_payload)

    def __init__(self, message_type, payload=None):
        self.type = message_type
        self.payload = payload

    def to_bytes(self):
        """Serialize as ``<4-byte length><1-byte id><payload>``.

        A keep-alive has no id byte at all: it is just a zero length
        prefix. (Its enum value is negative, so trying to encode it as an
        id byte used to raise OverflowError.)

        Raises OverflowError for other types whose value does not fit in
        one unsigned byte (the remaining negative pseudo-types).
        """
        if self.type == MessageType.KEEP_ALIVE:
            return pack('!I', 0)

        payload = self.payload or b''
        # The length prefix counts the id byte plus the payload.
        length_bytes = pack('!I', len(payload) + 1)
        id_byte = int.to_bytes(self.type.value, length=1, byteorder='big')
        return length_bytes + id_byte + payload

    def __repr__(self):
        return 'Message [ type: {} | payload: {} ]'.format(self.type.name, self.payload or 'No Payload')
class HandShakeMessage(Message):
    """Handshake message whose payload already is the full wire format."""

    def __init__(self, payload):
        super().__init__(message_type=MessageType.HANDSHAKE, payload=payload)

    def to_bytes(self):
        """Return the payload unchanged: no length prefix or id byte is added."""
        return self.payload
class PieceMessage(Message):
    """A PIECE message: a block of data plus where it belongs.

    The raw payload is ``<4-byte index><4-byte offset><block>`` with both
    integers big-endian. After construction ``index`` and ``offset`` hold
    the decoded integers and ``payload`` holds only the block bytes.
    """

    def __init__(self, raw_payload):
        """Split ``raw_payload`` (bytes-like, must not be None) into its parts."""
        self.index = int.from_bytes(raw_payload[:4], byteorder='big')
        self.offset = int.from_bytes(raw_payload[4:8], byteorder='big')
        super().__init__(MessageType.PIECE, raw_payload[8:])

    def to_bytes(self):
        """Serialize as ``<length><id><index><offset><block>``.

        The inherited implementation would emit only the block, silently
        dropping the index and offset that ``__init__`` stripped off, so
        the result could not be parsed back into the same message.
        """
        block = self.payload or b''
        # Length covers the id byte (1) + index (4) + offset (4) + block.
        header = pack('!IBII', 9 + len(block), self.type.value, self.index, self.offset)
        return header + block

    def __repr__(self):
        return 'Message [ type: {} | index: {} ]'.format(self.type.name, self.index)
def _strip_message(message, index):
"""Splits an array at the given index"""
if index > len(message):
raise ValueError('Index Not Applicable | {} {}'.format(message, index))
strip = message[:index]
leftover = message[index:]
return strip, leftover
def is_handshake(data):
    """Tell whether ``data`` starts like a handshake message.

    A handshake opens with one byte holding the length of the protocol
    string, followed by the protocol string itself (``constants.PSTR``).
    Only that leading portion of ``data`` is compared.
    """
    expected_prefix = bytes([len(constants.PSTR)]) + constants.PSTR
    prefix_len = len(constants.PSTR) + 1  # +1 for the leading length byte
    return data[:prefix_len] == expected_prefix
def get_handshake(client_id, info_hash):
    """Build the handshake message we send to a peer.

    Wire layout:
        <pstrlen><pstr><reserved><info_hash><peer_id>

    ``client_id`` is a str and is encoded before being appended;
    ``info_hash`` is used as given.
    """
    parts = (
        bytes([len(constants.PSTR)]),
        constants.PSTR,
        constants.RESERVED,
        info_hash,
        client_id.encode(),
    )
    return Message.factory(MessageType.HANDSHAKE, b"".join(parts))
def parse_handshake(data):
    """Interpret a handshake received from a peer.

    Returns a dict with:
        'info_hash': the raw info-hash bytes,
        'peer_id':   the peer id as a str,
        'extra':     any bytes following the handshake (the peer may have
                     sent further messages in the same chunk).

    A peer id is an arbitrary byte string and need not be valid UTF-8, so
    undecodable bytes are rendered as backslash escapes rather than
    raising UnicodeDecodeError; ids that are valid UTF-8 decode exactly
    as before.
    """
    # Skip the leading length byte, the protocol string and the reserved bytes.
    info_hash_start = 1 + len(constants.PSTR) + len(constants.RESERVED)
    peer_id_start = info_hash_start + constants.INFO_HASH_LEN
    extra_start = peer_id_start + constants.PEER_ID_LEN

    return {
        'info_hash': data[info_hash_start:peer_id_start],
        'peer_id': data[peer_id_start:extra_start].decode('utf-8', errors='backslashreplace'),
        'extra': data[extra_start:],
    }
class MessageQueue(Queue):
    """A Queue that can be iterated until it is closed.

    Iterating blocks on ``get()`` and yields each item in turn. Every
    item retrieved is acknowledged with ``task_done()`` once the consumer
    has finished with it, and iteration stops when the STOP sentinel put
    by ``close()`` is reached.
    """

    # Unique marker object; compared by identity, never by value.
    STOP = object()

    def close(self):
        """Enqueue the sentinel that makes iteration stop."""
        self.put(self.STOP)

    def __iter__(self):
        while True:
            item = self.get()
            if item is self.STOP:
                # The sentinel counts as a retrieved task as well.
                self.task_done()
                return
            try:
                yield item
            finally:
                # Runs when the consumer asks for the next item, raises
                # into the generator, or closes it.
                self.task_done()
def message_queue_worker(message_queue, callback):
    """Feed every item taken from ``message_queue`` to ``callback``.

    Keeps consuming until iteration over the queue ends. Acknowledging
    each item as done is left to the queue's own iterator.
    """
    for item in message_queue:
        callback(item)