mirror of
https://github.com/alexbers/mtprotoproxy.git
synced 2026-03-13 23:03:09 +00:00
simplify end of data detection and refactoring
This commit is contained in:
@@ -130,16 +130,18 @@ class CryptoWrappedStreamReader:
|
||||
|
||||
async def read(self, n):
|
||||
if self.buf:
|
||||
ret = self.buf
|
||||
ret = bytes(self.buf)
|
||||
self.buf.clear()
|
||||
return ret
|
||||
else:
|
||||
readed = await self.stream.read(n)
|
||||
data = await self.stream.read(n)
|
||||
if not data:
|
||||
return b""
|
||||
|
||||
needed_till_full_block = -len(readed) % self.block_size
|
||||
needed_till_full_block = -len(data) % self.block_size
|
||||
if needed_till_full_block > 0:
|
||||
readed += self.stream.readexactly(needed_till_full_block)
|
||||
return self.decryptor.decrypt(readed)
|
||||
data += self.stream.readexactly(needed_till_full_block)
|
||||
return self.decryptor.decrypt(data)
|
||||
|
||||
async def readexactly(self, n):
|
||||
if n > len(self.buf):
|
||||
@@ -189,17 +191,15 @@ class MTProtoFrameStreamReader:
|
||||
msg_len_bytes = await self.stream.readexactly(4)
|
||||
msg_len = int.from_bytes(msg_len_bytes, "little")
|
||||
|
||||
len_is_impossible = (msg_len % len(PADDING_FILLER) != 0)
|
||||
if not MIN_MSG_LEN <= msg_len <= MAX_MSG_LEN or len_is_impossible:
|
||||
len_is_bad = (msg_len % len(PADDING_FILLER) != 0)
|
||||
if not MIN_MSG_LEN <= msg_len <= MAX_MSG_LEN or len_is_bad:
|
||||
print("msg_len is bad, closing connection", msg_len)
|
||||
self.stream.feed_eof()
|
||||
return b""
|
||||
|
||||
msg_seq_bytes = await self.stream.readexactly(4)
|
||||
msg_seq = int.from_bytes(msg_seq_bytes, "little", signed=True)
|
||||
if msg_seq != self.seq_no:
|
||||
print("unexpected seq_no")
|
||||
self.stream.feed_eof()
|
||||
return b""
|
||||
|
||||
self.seq_no += 1
|
||||
@@ -210,11 +210,32 @@ class MTProtoFrameStreamReader:
|
||||
|
||||
computed_checksum = binascii.crc32(msg_len_bytes + msg_seq_bytes + data)
|
||||
if computed_checksum != checksum:
|
||||
self.stream.feed_eof()
|
||||
return b""
|
||||
return data
|
||||
|
||||
|
||||
class MTProtoFrameStreamWriter:
|
||||
def __init__(self, stream, seq_no=0):
|
||||
self.stream = stream
|
||||
self.seq_no = seq_no
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.stream, attr)
|
||||
|
||||
def write(self, msg):
|
||||
len_bytes = int.to_bytes(len(msg) + 4 + 4 + 4, 4, "little")
|
||||
seq_bytes = int.to_bytes(self.seq_no, 4, "little", signed=True)
|
||||
self.seq_no += 1
|
||||
|
||||
msg_without_checksum = len_bytes + seq_bytes + msg
|
||||
checksum = int.to_bytes(binascii.crc32(msg_without_checksum), 4, "little")
|
||||
|
||||
full_msg = msg_without_checksum + checksum
|
||||
padding = PADDING_FILLER * ((-len(full_msg) % CBC_PADDING) // len(PADDING_FILLER))
|
||||
|
||||
return self.stream.write(full_msg + padding)
|
||||
|
||||
|
||||
class MTProtoCompactFrameStreamReader:
|
||||
def __init__(self, stream):
|
||||
self.stream = stream
|
||||
@@ -268,28 +289,6 @@ class MTProtoCompactFrameStreamWriter:
|
||||
return 0
|
||||
|
||||
|
||||
class MTProtoFrameStreamWriter:
|
||||
def __init__(self, stream, seq_no=0):
|
||||
self.stream = stream
|
||||
self.seq_no = seq_no
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.stream, attr)
|
||||
|
||||
def write(self, msg):
|
||||
len_bytes = int.to_bytes(len(msg) + 4 + 4 + 4, 4, "little")
|
||||
seq_bytes = int.to_bytes(self.seq_no, 4, "little", signed=True)
|
||||
self.seq_no += 1
|
||||
|
||||
msg_without_checksum = len_bytes + seq_bytes + msg
|
||||
checksum = int.to_bytes(binascii.crc32(msg_without_checksum), 4, "little")
|
||||
|
||||
full_msg = msg_without_checksum + checksum
|
||||
padding = PADDING_FILLER * ((-len(full_msg) % CBC_PADDING) // len(PADDING_FILLER))
|
||||
|
||||
return self.stream.write(full_msg + padding)
|
||||
|
||||
|
||||
class ProxyReqStreamReader:
|
||||
def __init__(self, stream):
|
||||
self.stream = stream
|
||||
@@ -308,7 +307,6 @@ class ProxyReqStreamReader:
|
||||
|
||||
ans_type, ans_flags, conn_id, conn_data = data[:4], data[4:8], data[8:16], data[16:]
|
||||
if ans_type == RPC_CLOSE_EXT:
|
||||
self.feed_eof()
|
||||
return b""
|
||||
|
||||
if ans_type != RPC_PROXY_ANS:
|
||||
|
||||
Reference in New Issue
Block a user