diff --git a/mtprotoproxy.py b/mtprotoproxy.py index 74b4654..28cbc1d 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -431,11 +431,6 @@ class MTProtoIntermediateFrameStreamReader(LayeredStreamReaderBase): msg_len -= 0x80000000 data = await self.upstream.readexactly(msg_len) - - if msg_len % 4 != 0: - cut_border = msg_len - (msg_len % 4) - data = data[:cut_border] - return data, extra @@ -447,6 +442,38 @@ class MTProtoIntermediateFrameStreamWriter(LayeredStreamWriterBase): return self.upstream.write(int.to_bytes(len(data), 4, 'little') + data) +class MTProtoSecureIntermediateFrameStreamReader(LayeredStreamReaderBase): + async def read(self, buf_size): + msg_len_bytes = await self.upstream.readexactly(4) + msg_len = int.from_bytes(msg_len_bytes, "little") + + extra = {} + if msg_len > 0x80000000: + extra["QUICKACK_FLAG"] = True + msg_len -= 0x80000000 + + data = await self.upstream.readexactly(msg_len) + + if msg_len % 4 != 0: + cut_border = msg_len - (msg_len % 4) + data = data[:cut_border] + + return data, extra + + +class MTProtoSecureIntermediateFrameStreamWriter(LayeredStreamWriterBase): + def write(self, data, extra={}): + MAX_PADDING_LEN = 4 + if extra.get("SIMPLE_ACK"): + # TODO: make this unpredictable + return self.upstream.write(data) + else: + padding_len = random.randrange(MAX_PADDING_LEN) + padding = bytearray([random.randrange(256) for i in range(padding_len)]) + padded_data_len_bytes = int.to_bytes(len(data) + padding_len, 4, 'little') + return self.upstream.write(padded_data_len_bytes + data + padding) + + class ProxyReqStreamReader(LayeredStreamReaderBase): async def read(self, msg): RPC_PROXY_ANS = b"\x0d\xda\x03\x44" @@ -505,6 +532,7 @@ class ProxyReqStreamWriter(LayeredStreamWriterBase): FLAG_HAS_AD_TAG = 0x8 FLAG_MAGIC = 0x1000 FLAG_EXTMODE2 = 0x20000 + FLAG_PAD = 0x8000000 FLAG_INTERMEDIATE = 0x20000000 FLAG_ABRIDGED = 0x40000000 FLAG_QUICKACK = 0x80000000 @@ -519,6 +547,8 @@ class ProxyReqStreamWriter(LayeredStreamWriterBase): flags |= FLAG_ABRIDGED elif self.proto_tag == PROTO_TAG_INTERMEDIATE: flags |= FLAG_INTERMEDIATE + elif self.proto_tag == PROTO_TAG_SECURE: + flags |= FLAG_INTERMEDIATE | FLAG_PAD if extra.get("QUICKACK_FLAG"): flags |= FLAG_QUICKACK @@ -880,9 +910,12 @@ async def handle_client(reader_clt, writer_clt): if proto_tag == PROTO_TAG_ABRIDGED: reader_clt = MTProtoCompactFrameStreamReader(reader_clt) writer_clt = MTProtoCompactFrameStreamWriter(writer_clt) - elif proto_tag in (PROTO_TAG_INTERMEDIATE, PROTO_TAG_SECURE): + elif proto_tag == PROTO_TAG_INTERMEDIATE: reader_clt = MTProtoIntermediateFrameStreamReader(reader_clt) writer_clt = MTProtoIntermediateFrameStreamWriter(writer_clt) + elif proto_tag == PROTO_TAG_SECURE: + reader_clt = MTProtoSecureIntermediateFrameStreamReader(reader_clt) + writer_clt = MTProtoSecureIntermediateFrameStreamWriter(writer_clt) else: return