diff --git a/mtprotoproxy.py b/mtprotoproxy.py index c0a7177..82f83a7 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -130,16 +130,18 @@ class CryptoWrappedStreamReader: async def read(self, n): if self.buf: - ret = self.buf + ret = bytes(self.buf) self.buf.clear() return ret else: - readed = await self.stream.read(n) + data = await self.stream.read(n) + if not data: + return b"" - needed_till_full_block = -len(readed) % self.block_size + needed_till_full_block = -len(data) % self.block_size if needed_till_full_block > 0: - readed += self.stream.readexactly(needed_till_full_block) - return self.decryptor.decrypt(readed) + data += self.stream.readexactly(needed_till_full_block) + return self.decryptor.decrypt(data) async def readexactly(self, n): if n > len(self.buf): @@ -189,17 +191,15 @@ class MTProtoFrameStreamReader: msg_len_bytes = await self.stream.readexactly(4) msg_len = int.from_bytes(msg_len_bytes, "little") - len_is_impossible = (msg_len % len(PADDING_FILLER) != 0) - if not MIN_MSG_LEN <= msg_len <= MAX_MSG_LEN or len_is_impossible: + len_is_bad = (msg_len % len(PADDING_FILLER) != 0) + if not MIN_MSG_LEN <= msg_len <= MAX_MSG_LEN or len_is_bad: print("msg_len is bad, closing connection", msg_len) - self.stream.feed_eof() return b"" msg_seq_bytes = await self.stream.readexactly(4) msg_seq = int.from_bytes(msg_seq_bytes, "little", signed=True) if msg_seq != self.seq_no: print("unexpected seq_no") - self.stream.feed_eof() return b"" self.seq_no += 1 @@ -210,11 +210,32 @@ class MTProtoFrameStreamReader: computed_checksum = binascii.crc32(msg_len_bytes + msg_seq_bytes + data) if computed_checksum != checksum: - self.stream.feed_eof() return b"" return data +class MTProtoFrameStreamWriter: + def __init__(self, stream, seq_no=0): + self.stream = stream + self.seq_no = seq_no + + def __getattr__(self, attr): + return getattr(self.stream, attr) + + def write(self, msg): + len_bytes = int.to_bytes(len(msg) + 4 + 4 + 4, 4, "little") + seq_bytes = int.to_bytes(self.seq_no, 4, "little", signed=True) + self.seq_no += 1 + + msg_without_checksum = len_bytes + seq_bytes + msg + checksum = int.to_bytes(binascii.crc32(msg_without_checksum), 4, "little") + + full_msg = msg_without_checksum + checksum + padding = PADDING_FILLER * ((-len(full_msg) % CBC_PADDING) // len(PADDING_FILLER)) + + return self.stream.write(full_msg + padding) + + class MTProtoCompactFrameStreamReader: def __init__(self, stream): self.stream = stream @@ -268,28 +289,6 @@ class MTProtoCompactFrameStreamWriter: return 0 -class MTProtoFrameStreamWriter: - def __init__(self, stream, seq_no=0): - self.stream = stream - self.seq_no = seq_no - - def __getattr__(self, attr): - return getattr(self.stream, attr) - - def write(self, msg): - len_bytes = int.to_bytes(len(msg) + 4 + 4 + 4, 4, "little") - seq_bytes = int.to_bytes(self.seq_no, 4, "little", signed=True) - self.seq_no += 1 - - msg_without_checksum = len_bytes + seq_bytes + msg - checksum = int.to_bytes(binascii.crc32(msg_without_checksum), 4, "little") - - full_msg = msg_without_checksum + checksum - padding = PADDING_FILLER * ((-len(full_msg) % CBC_PADDING) // len(PADDING_FILLER)) - - return self.stream.write(full_msg + padding) - - class ProxyReqStreamReader: def __init__(self, stream): self.stream = stream @@ -308,7 +307,6 @@ class ProxyReqStreamReader: ans_type, ans_flags, conn_id, conn_data = data[:4], data[4:8], data[8:16], data[16:] if ans_type == RPC_CLOSE_EXT: - self.feed_eof() return b"" if ans_type != RPC_PROXY_ANS: