mirror of
https://github.com/alexbers/mtprotoproxy.git
synced 2026-03-21 18:15:50 +00:00
support quickack flag
This commit is contained in:
@@ -168,7 +168,7 @@ class LayeredStreamWriterBase:
|
|||||||
def __init__(self, upstream):
|
def __init__(self, upstream):
|
||||||
self.upstream = upstream
|
self.upstream = upstream
|
||||||
|
|
||||||
def write(self, data):
|
def write(self, data, extra={}):
|
||||||
return self.upstream.write(data)
|
return self.upstream.write(data)
|
||||||
|
|
||||||
def write_eof(self):
|
def write_eof(self):
|
||||||
@@ -230,7 +230,7 @@ class CryptoWrappedStreamWriter(LayeredStreamWriterBase):
|
|||||||
self.encryptor = encryptor
|
self.encryptor = encryptor
|
||||||
self.block_size = block_size
|
self.block_size = block_size
|
||||||
|
|
||||||
def write(self, data):
|
def write(self, data, extra={}):
|
||||||
if len(data) % self.block_size != 0:
|
if len(data) % self.block_size != 0:
|
||||||
print_err("BUG: writing %d bytes not aligned to block size %d" % (
|
print_err("BUG: writing %d bytes not aligned to block size %d" % (
|
||||||
len(data), self.block_size))
|
len(data), self.block_size))
|
||||||
@@ -280,7 +280,7 @@ class MTProtoFrameStreamWriter(LayeredStreamWriterBase):
|
|||||||
self.upstream = upstream
|
self.upstream = upstream
|
||||||
self.seq_no = seq_no
|
self.seq_no = seq_no
|
||||||
|
|
||||||
def write(self, msg):
|
def write(self, msg, extra={}):
|
||||||
len_bytes = int.to_bytes(len(msg) + 4 + 4 + 4, 4, "little")
|
len_bytes = int.to_bytes(len(msg) + 4 + 4 + 4, 4, "little")
|
||||||
seq_bytes = int.to_bytes(self.seq_no, 4, "little", signed=True)
|
seq_bytes = int.to_bytes(self.seq_no, 4, "little", signed=True)
|
||||||
self.seq_no += 1
|
self.seq_no += 1
|
||||||
@@ -299,7 +299,9 @@ class MTProtoCompactFrameStreamReader(LayeredStreamReaderBase):
|
|||||||
msg_len_bytes = await self.upstream.readexactly(1)
|
msg_len_bytes = await self.upstream.readexactly(1)
|
||||||
msg_len = int.from_bytes(msg_len_bytes, "little")
|
msg_len = int.from_bytes(msg_len_bytes, "little")
|
||||||
|
|
||||||
|
extra = {"QUICKACK_FLAG": False}
|
||||||
if msg_len >= 0x80:
|
if msg_len >= 0x80:
|
||||||
|
extra["QUICKACK_FLAG"] = True
|
||||||
msg_len -= 0x80
|
msg_len -= 0x80
|
||||||
|
|
||||||
if msg_len == 0x7f:
|
if msg_len == 0x7f:
|
||||||
@@ -310,11 +312,11 @@ class MTProtoCompactFrameStreamReader(LayeredStreamReaderBase):
|
|||||||
|
|
||||||
data = await self.upstream.readexactly(msg_len)
|
data = await self.upstream.readexactly(msg_len)
|
||||||
|
|
||||||
return data
|
return data, extra
|
||||||
|
|
||||||
|
|
||||||
class MTProtoCompactFrameStreamWriter(LayeredStreamWriterBase):
|
class MTProtoCompactFrameStreamWriter(LayeredStreamWriterBase):
|
||||||
def write(self, data):
|
def write(self, data, extra={}):
|
||||||
SMALL_PKT_BORDER = 0x7f
|
SMALL_PKT_BORDER = 0x7f
|
||||||
LARGE_PKT_BORGER = 256 ** 3
|
LARGE_PKT_BORGER = 256 ** 3
|
||||||
|
|
||||||
@@ -327,8 +329,7 @@ class MTProtoCompactFrameStreamWriter(LayeredStreamWriterBase):
|
|||||||
if len_div_four < SMALL_PKT_BORDER:
|
if len_div_four < SMALL_PKT_BORDER:
|
||||||
return self.upstream.write(bytes([len_div_four]) + data)
|
return self.upstream.write(bytes([len_div_four]) + data)
|
||||||
elif len_div_four < LARGE_PKT_BORGER:
|
elif len_div_four < LARGE_PKT_BORGER:
|
||||||
return self.upstream.write(b'\x7f' + bytes(int.to_bytes(len_div_four, 3, 'little')) +
|
return self.upstream.write(b'\x7f' + int.to_bytes(len_div_four, 3, 'little') + data)
|
||||||
data)
|
|
||||||
else:
|
else:
|
||||||
print_err("Attempted to send too large pkt len =", len(data))
|
print_err("Attempted to send too large pkt len =", len(data))
|
||||||
return 0
|
return 0
|
||||||
@@ -339,16 +340,18 @@ class MTProtoIntermediateFrameStreamReader(LayeredStreamReaderBase):
|
|||||||
msg_len_bytes = await self.upstream.readexactly(4)
|
msg_len_bytes = await self.upstream.readexactly(4)
|
||||||
msg_len = int.from_bytes(msg_len_bytes, "little")
|
msg_len = int.from_bytes(msg_len_bytes, "little")
|
||||||
|
|
||||||
|
extra = {}
|
||||||
if msg_len > 0x80000000:
|
if msg_len > 0x80000000:
|
||||||
|
extra["QUICKACK_FLAG"] = True
|
||||||
msg_len -= 0x80000000
|
msg_len -= 0x80000000
|
||||||
|
|
||||||
data = await self.upstream.readexactly(msg_len)
|
data = await self.upstream.readexactly(msg_len)
|
||||||
|
|
||||||
return data
|
return data, extra
|
||||||
|
|
||||||
|
|
||||||
class MTProtoIntermediateFrameStreamWriter(LayeredStreamWriterBase):
|
class MTProtoIntermediateFrameStreamWriter(LayeredStreamWriterBase):
|
||||||
def write(self, data):
|
def write(self, data, extra={}):
|
||||||
return self.upstream.write(int.to_bytes(len(data), 4, 'little') + data)
|
return self.upstream.write(int.to_bytes(len(data), 4, 'little') + data)
|
||||||
|
|
||||||
|
|
||||||
@@ -392,30 +395,41 @@ class ProxyReqStreamWriter(LayeredStreamWriterBase):
|
|||||||
self.our_ip_port += int.to_bytes(my_port, 4, "little")
|
self.our_ip_port += int.to_bytes(my_port, 4, "little")
|
||||||
self.out_conn_id = bytearray([random.randrange(0, 256) for i in range(8)])
|
self.out_conn_id = bytearray([random.randrange(0, 256) for i in range(8)])
|
||||||
|
|
||||||
if proto_tag == PROTO_TAG_ABRIDGED:
|
self.proto_tag = proto_tag
|
||||||
self.last_flag_byte = b"\x40"
|
|
||||||
elif proto_tag == PROTO_TAG_INTERMEDIATE:
|
|
||||||
self.last_flag_byte = b"\x20"
|
|
||||||
else:
|
|
||||||
self.last_flag_byte = b"\x00"
|
|
||||||
|
|
||||||
def write(self, msg):
|
def write(self, msg, extra={}):
|
||||||
RPC_PROXY_REQ = b"\xee\xf1\xce\x36"
|
RPC_PROXY_REQ = b"\xee\xf1\xce\x36"
|
||||||
EXTRA_SIZE = b"\x18\x00\x00\x00"
|
EXTRA_SIZE = b"\x18\x00\x00\x00"
|
||||||
PROXY_TAG = b"\xae\x26\x1e\xdb"
|
PROXY_TAG = b"\xae\x26\x1e\xdb"
|
||||||
FOUR_BYTES_ALIGNER = b"\x00\x00\x00"
|
FOUR_BYTES_ALIGNER = b"\x00\x00\x00"
|
||||||
|
|
||||||
|
FLAG_NOT_ENCRYPTED = 0x2
|
||||||
|
FLAG_HAS_AD_TAG = 0x8
|
||||||
|
FLAG_MAGIC = 0x1000
|
||||||
|
FLAG_EXTMODE2 = 0x20000
|
||||||
|
FLAG_INTERMEDIATE = 0x20000000
|
||||||
|
FLAG_ABRIDGED = 0x40000000
|
||||||
|
FLAG_QUICKACK = 0x80000000
|
||||||
|
|
||||||
if len(msg) % 4 != 0:
|
if len(msg) % 4 != 0:
|
||||||
print_err("BUG: attempted to send msg with len %d" % len(msg))
|
print_err("BUG: attempted to send msg with len %d" % len(msg))
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
flags = FLAG_HAS_AD_TAG | FLAG_MAGIC | FLAG_EXTMODE2
|
||||||
|
|
||||||
|
if self.proto_tag == PROTO_TAG_ABRIDGED:
|
||||||
|
flags |= FLAG_ABRIDGED
|
||||||
|
elif self.proto_tag == PROTO_TAG_INTERMEDIATE:
|
||||||
|
flags |= FLAG_INTERMEDIATE
|
||||||
|
|
||||||
|
if extra.get("QUICKACK_FLAG"):
|
||||||
|
flags |= FLAG_QUICKACK
|
||||||
|
|
||||||
if msg.startswith(b"\x00" * 8):
|
if msg.startswith(b"\x00" * 8):
|
||||||
flags = b"\x0a\x10\x02" + self.last_flag_byte
|
flags |= FLAG_NOT_ENCRYPTED
|
||||||
else:
|
|
||||||
flags = b"\x08\x10\x02" + self.last_flag_byte
|
|
||||||
|
|
||||||
full_msg = bytearray()
|
full_msg = bytearray()
|
||||||
full_msg += RPC_PROXY_REQ + flags + self.out_conn_id
|
full_msg += RPC_PROXY_REQ + int.to_bytes(flags, 4, "little") + self.out_conn_id
|
||||||
full_msg += self.remote_ip_port + self.our_ip_port + EXTRA_SIZE + PROXY_TAG
|
full_msg += self.remote_ip_port + self.our_ip_port + EXTRA_SIZE + PROXY_TAG
|
||||||
full_msg += bytes([len(AD_TAG)]) + AD_TAG + FOUR_BYTES_ALIGNER
|
full_msg += bytes([len(AD_TAG)]) + AD_TAG + FOUR_BYTES_ALIGNER
|
||||||
full_msg += msg
|
full_msg += msg
|
||||||
@@ -744,6 +758,11 @@ async def handle_client(reader_clt, writer_clt):
|
|||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
data = await rd.read(READ_BUF_SIZE)
|
data = await rd.read(READ_BUF_SIZE)
|
||||||
|
if isinstance(data, tuple):
|
||||||
|
data, extra = data
|
||||||
|
else:
|
||||||
|
extra = {}
|
||||||
|
|
||||||
if not data:
|
if not data:
|
||||||
wr.write_eof()
|
wr.write_eof()
|
||||||
await wr.drain()
|
await wr.drain()
|
||||||
@@ -751,7 +770,7 @@ async def handle_client(reader_clt, writer_clt):
|
|||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
update_stats(user, octets=len(data))
|
update_stats(user, octets=len(data))
|
||||||
wr.write(data)
|
wr.write(data, extra)
|
||||||
await wr.drain()
|
await wr.drain()
|
||||||
except (OSError, AttributeError, asyncio.streams.IncompleteReadError) as e:
|
except (OSError, AttributeError, asyncio.streams.IncompleteReadError) as e:
|
||||||
# print_err(e)
|
# print_err(e)
|
||||||
|
|||||||
Reference in New Issue
Block a user