simplify end of data detection and refactoring

This commit is contained in:
Alexander Bersenev
2018-06-04 16:45:08 +05:00
parent 0648b41c17
commit 44ab6fef13

View File

@@ -130,16 +130,18 @@ class CryptoWrappedStreamReader:
async def read(self, n):
if self.buf:
ret = self.buf
ret = bytes(self.buf)
self.buf.clear()
return ret
else:
readed = await self.stream.read(n)
data = await self.stream.read(n)
if not data:
return b""
needed_till_full_block = -len(readed) % self.block_size
needed_till_full_block = -len(data) % self.block_size
if needed_till_full_block > 0:
readed += self.stream.readexactly(needed_till_full_block)
return self.decryptor.decrypt(readed)
data += self.stream.readexactly(needed_till_full_block)
return self.decryptor.decrypt(data)
async def readexactly(self, n):
if n > len(self.buf):
@@ -189,17 +191,15 @@ class MTProtoFrameStreamReader:
msg_len_bytes = await self.stream.readexactly(4)
msg_len = int.from_bytes(msg_len_bytes, "little")
len_is_impossible = (msg_len % len(PADDING_FILLER) != 0)
if not MIN_MSG_LEN <= msg_len <= MAX_MSG_LEN or len_is_impossible:
len_is_bad = (msg_len % len(PADDING_FILLER) != 0)
if not MIN_MSG_LEN <= msg_len <= MAX_MSG_LEN or len_is_bad:
print("msg_len is bad, closing connection", msg_len)
self.stream.feed_eof()
return b""
msg_seq_bytes = await self.stream.readexactly(4)
msg_seq = int.from_bytes(msg_seq_bytes, "little", signed=True)
if msg_seq != self.seq_no:
print("unexpected seq_no")
self.stream.feed_eof()
return b""
self.seq_no += 1
@@ -210,11 +210,32 @@ class MTProtoFrameStreamReader:
computed_checksum = binascii.crc32(msg_len_bytes + msg_seq_bytes + data)
if computed_checksum != checksum:
self.stream.feed_eof()
return b""
return data
class MTProtoFrameStreamWriter:
def __init__(self, stream, seq_no=0):
self.stream = stream
self.seq_no = seq_no
def __getattr__(self, attr):
return getattr(self.stream, attr)
def write(self, msg):
len_bytes = int.to_bytes(len(msg) + 4 + 4 + 4, 4, "little")
seq_bytes = int.to_bytes(self.seq_no, 4, "little", signed=True)
self.seq_no += 1
msg_without_checksum = len_bytes + seq_bytes + msg
checksum = int.to_bytes(binascii.crc32(msg_without_checksum), 4, "little")
full_msg = msg_without_checksum + checksum
padding = PADDING_FILLER * ((-len(full_msg) % CBC_PADDING) // len(PADDING_FILLER))
return self.stream.write(full_msg + padding)
class MTProtoCompactFrameStreamReader:
def __init__(self, stream):
self.stream = stream
@@ -268,28 +289,6 @@ class MTProtoCompactFrameStreamWriter:
return 0
class MTProtoFrameStreamWriter:
def __init__(self, stream, seq_no=0):
self.stream = stream
self.seq_no = seq_no
def __getattr__(self, attr):
return getattr(self.stream, attr)
def write(self, msg):
len_bytes = int.to_bytes(len(msg) + 4 + 4 + 4, 4, "little")
seq_bytes = int.to_bytes(self.seq_no, 4, "little", signed=True)
self.seq_no += 1
msg_without_checksum = len_bytes + seq_bytes + msg
checksum = int.to_bytes(binascii.crc32(msg_without_checksum), 4, "little")
full_msg = msg_without_checksum + checksum
padding = PADDING_FILLER * ((-len(full_msg) % CBC_PADDING) // len(PADDING_FILLER))
return self.stream.write(full_msg + padding)
class ProxyReqStreamReader:
def __init__(self, stream):
self.stream = stream
@@ -308,7 +307,6 @@ class ProxyReqStreamReader:
ans_type, ans_flags, conn_id, conn_data = data[:4], data[4:8], data[8:16], data[16:]
if ans_type == RPC_CLOSE_EXT:
self.feed_eof()
return b""
if ans_type != RPC_PROXY_ANS: