From dc9da63fbcb1839b2281a4f4aae9b4629c7fd78c Mon Sep 17 00:00:00 2001 From: Alexander Bersenev Date: Mon, 18 Jun 2018 18:33:48 +0500 Subject: [PATCH] support quickack flag --- mtprotoproxy.py | 61 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 40 insertions(+), 21 deletions(-) diff --git a/mtprotoproxy.py b/mtprotoproxy.py index 1bc05fd..e11a510 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -168,7 +168,7 @@ class LayeredStreamWriterBase: def __init__(self, upstream): self.upstream = upstream - def write(self, data): + def write(self, data, extra={}): return self.upstream.write(data) def write_eof(self): @@ -230,7 +230,7 @@ class CryptoWrappedStreamWriter(LayeredStreamWriterBase): self.encryptor = encryptor self.block_size = block_size - def write(self, data): + def write(self, data, extra={}): if len(data) % self.block_size != 0: print_err("BUG: writing %d bytes not aligned to block size %d" % ( len(data), self.block_size)) @@ -280,7 +280,7 @@ class MTProtoFrameStreamWriter(LayeredStreamWriterBase): self.upstream = upstream self.seq_no = seq_no - def write(self, msg): + def write(self, msg, extra={}): len_bytes = int.to_bytes(len(msg) + 4 + 4 + 4, 4, "little") seq_bytes = int.to_bytes(self.seq_no, 4, "little", signed=True) self.seq_no += 1 @@ -299,7 +299,9 @@ class MTProtoCompactFrameStreamReader(LayeredStreamReaderBase): msg_len_bytes = await self.upstream.readexactly(1) msg_len = int.from_bytes(msg_len_bytes, "little") + extra = {"QUICKACK_FLAG": False} if msg_len >= 0x80: + extra["QUICKACK_FLAG"] = True msg_len -= 0x80 if msg_len == 0x7f: @@ -310,11 +312,11 @@ class MTProtoCompactFrameStreamReader(LayeredStreamReaderBase): data = await self.upstream.readexactly(msg_len) - return data + return data, extra class MTProtoCompactFrameStreamWriter(LayeredStreamWriterBase): - def write(self, data): + def write(self, data, extra={}): SMALL_PKT_BORDER = 0x7f LARGE_PKT_BORGER = 256 ** 3 @@ -327,8 +329,7 @@ class MTProtoCompactFrameStreamWriter(LayeredStreamWriterBase): if len_div_four < SMALL_PKT_BORDER: return self.upstream.write(bytes([len_div_four]) + data) elif len_div_four < LARGE_PKT_BORGER: - return self.upstream.write(b'\x7f' + bytes(int.to_bytes(len_div_four, 3, 'little')) + - data) + return self.upstream.write(b'\x7f' + int.to_bytes(len_div_four, 3, 'little') + data) else: print_err("Attempted to send too large pkt len =", len(data)) return 0 @@ -339,16 +340,18 @@ class MTProtoIntermediateFrameStreamReader(LayeredStreamReaderBase): msg_len_bytes = await self.upstream.readexactly(4) msg_len = int.from_bytes(msg_len_bytes, "little") + extra = {} if msg_len > 0x80000000: + extra["QUICKACK_FLAG"] = True msg_len -= 0x80000000 data = await self.upstream.readexactly(msg_len) - return data + return data, extra class MTProtoIntermediateFrameStreamWriter(LayeredStreamWriterBase): - def write(self, data): + def write(self, data, extra={}): return self.upstream.write(int.to_bytes(len(data), 4, 'little') + data) @@ -392,30 +395,41 @@ class ProxyReqStreamWriter(LayeredStreamWriterBase): self.our_ip_port += int.to_bytes(my_port, 4, "little") self.out_conn_id = bytearray([random.randrange(0, 256) for i in range(8)]) - if proto_tag == PROTO_TAG_ABRIDGED: - self.last_flag_byte = b"\x40" - elif proto_tag == PROTO_TAG_INTERMEDIATE: - self.last_flag_byte = b"\x20" - else: - self.last_flag_byte = b"\x00" + self.proto_tag = proto_tag - def write(self, msg): + def write(self, msg, extra={}): RPC_PROXY_REQ = b"\xee\xf1\xce\x36" EXTRA_SIZE = b"\x18\x00\x00\x00" PROXY_TAG = b"\xae\x26\x1e\xdb" FOUR_BYTES_ALIGNER = b"\x00\x00\x00" + FLAG_NOT_ENCRYPTED = 0x2 + FLAG_HAS_AD_TAG = 0x8 + FLAG_MAGIC = 0x1000 + FLAG_EXTMODE2 = 0x20000 + FLAG_INTERMEDIATE = 0x20000000 + FLAG_ABRIDGED = 0x40000000 + FLAG_QUICKACK = 0x80000000 + if len(msg) % 4 != 0: print_err("BUG: attempted to send msg with len %d" % len(msg)) return 0 + flags = FLAG_HAS_AD_TAG | FLAG_MAGIC | FLAG_EXTMODE2 + + if self.proto_tag == PROTO_TAG_ABRIDGED: + flags |= FLAG_ABRIDGED + elif self.proto_tag == PROTO_TAG_INTERMEDIATE: + flags |= FLAG_INTERMEDIATE + + if extra.get("QUICKACK_FLAG"): + flags |= FLAG_QUICKACK + if msg.startswith(b"\x00" * 8): - flags = b"\x0a\x10\x02" + self.last_flag_byte - else: - flags = b"\x08\x10\x02" + self.last_flag_byte + flags |= FLAG_NOT_ENCRYPTED full_msg = bytearray() - full_msg += RPC_PROXY_REQ + flags + self.out_conn_id + full_msg += RPC_PROXY_REQ + int.to_bytes(flags, 4, "little") + self.out_conn_id full_msg += self.remote_ip_port + self.our_ip_port + EXTRA_SIZE + PROXY_TAG full_msg += bytes([len(AD_TAG)]) + AD_TAG + FOUR_BYTES_ALIGNER full_msg += msg @@ -744,6 +758,11 @@ async def handle_client(reader_clt, writer_clt): try: while True: data = await rd.read(READ_BUF_SIZE) + if isinstance(data, tuple): + data, extra = data + else: + extra = {} + if not data: wr.write_eof() await wr.drain() @@ -751,7 +770,7 @@ async def handle_client(reader_clt, writer_clt): return else: update_stats(user, octets=len(data)) - wr.write(data) + wr.write(data, extra) await wr.drain() except (OSError, AttributeError, asyncio.streams.IncompleteReadError) as e: # print_err(e)