mirror of
https://github.com/alexbers/mtprotoproxy.git
synced 2026-03-13 23:03:09 +00:00
init pooled connections to save one more round trip time
This commit is contained in:
119
mtprotoproxy.py
119
mtprotoproxy.py
@@ -442,42 +442,49 @@ class TgConnectionPool:
|
||||
def __init__(self):
|
||||
self.pools = {}
|
||||
|
||||
async def open_tg_connection(self, host, port):
|
||||
async def open_tg_connection(self, host, port, init_func=None):
|
||||
task = asyncio.open_connection(host, port, limit=get_to_clt_bufsize())
|
||||
reader_tgt, writer_tgt = await asyncio.wait_for(task, timeout=config.TG_CONNECT_TIMEOUT)
|
||||
|
||||
set_keepalive(writer_tgt.get_extra_info("socket"))
|
||||
set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize())
|
||||
|
||||
if init_func:
|
||||
return await asyncio.wait_for(init_func(host, port, reader_tgt, writer_tgt),
|
||||
timeout=config.TG_CONNECT_TIMEOUT)
|
||||
return reader_tgt, writer_tgt
|
||||
|
||||
def register_host_port(self, host, port):
|
||||
if (host, port) not in self.pools:
|
||||
self.pools[(host, port)] = []
|
||||
def register_host_port(self, host, port, init_func):
|
||||
if (host, port, init_func) not in self.pools:
|
||||
self.pools[(host, port, init_func)] = []
|
||||
|
||||
while len(self.pools[(host, port)]) < TgConnectionPool.MAX_CONNS_IN_POOL:
|
||||
connect_task = asyncio.ensure_future(self.open_tg_connection(host, port))
|
||||
self.pools[(host, port)].append(connect_task)
|
||||
while len(self.pools[(host, port, init_func)]) < TgConnectionPool.MAX_CONNS_IN_POOL:
|
||||
connect_task = asyncio.ensure_future(self.open_tg_connection(host, port, init_func))
|
||||
self.pools[(host, port, init_func)].append(connect_task)
|
||||
|
||||
async def get_connection(self, host, port):
|
||||
self.register_host_port(host, port)
|
||||
async def get_connection(self, host, port, init_func=None):
|
||||
self.register_host_port(host, port, init_func)
|
||||
|
||||
ret = None
|
||||
for task in self.pools[(host, port)][::]:
|
||||
for task in self.pools[(host, port, init_func)][::]:
|
||||
if task.done():
|
||||
if task.exception():
|
||||
self.pools[(host, port)].remove(task)
|
||||
self.pools[(host, port, init_func)].remove(task)
|
||||
continue
|
||||
|
||||
reader, writer = task.result()
|
||||
reader, writer, *other = task.result()
|
||||
if writer.transport.is_closing():
|
||||
self.pools[(host, port)].remove(task)
|
||||
self.pools[(host, port, init_func)].remove(task)
|
||||
continue
|
||||
|
||||
if not ret:
|
||||
self.pools[(host, port)].remove(task)
|
||||
ret = (reader, writer)
|
||||
self.pools[(host, port, init_func)].remove(task)
|
||||
ret = (reader, writer, *other)
|
||||
|
||||
self.register_host_port(host, port)
|
||||
self.register_host_port(host, port, init_func)
|
||||
if ret:
|
||||
return ret
|
||||
return await self.open_tg_connection(host, port)
|
||||
return await self.open_tg_connection(host, port, init_func)
|
||||
|
||||
|
||||
tg_connection_pool = TgConnectionPool()
|
||||
@@ -1269,13 +1276,13 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None):
|
||||
except ConnectionRefusedError as E:
|
||||
print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT)
|
||||
return False
|
||||
except ConnectionAbortedError as E:
|
||||
print_err("The Telegram server connection is bad: %d (%s %s) %s" % (dc_idx, addr, port, E))
|
||||
return False
|
||||
except (OSError, asyncio.TimeoutError) as E:
|
||||
print_err("Unable to connect to", dc, TG_DATACENTER_PORT)
|
||||
return False
|
||||
|
||||
set_keepalive(writer_tgt.get_extra_info("socket"))
|
||||
set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize())
|
||||
|
||||
while True:
|
||||
rnd = bytearray(myrandom.getrandbytes(HANDSHAKE_LEN))
|
||||
if rnd[:1] in RESERVED_NONCE_FIRST_CHARS:
|
||||
@@ -1338,20 +1345,50 @@ def get_middleproxy_aes_key_and_iv(nonce_srv, nonce_clt, clt_ts, srv_ip, clt_por
|
||||
return key, iv
|
||||
|
||||
|
||||
async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
|
||||
async def middleproxy_handshake_after_connect(host, port, reader_tgt, writer_tgt):
|
||||
""" The first stage of middleproxy handshake """
|
||||
START_SEQ_NO = -2
|
||||
NONCE_LEN = 16
|
||||
|
||||
RPC_NONCE = b"\xaa\x87\xcb\x7a"
|
||||
RPC_HANDSHAKE = b"\xf5\xee\x82\x76"
|
||||
CRYPTO_AES = b"\x01\x00\x00\x00"
|
||||
|
||||
RPC_NONCE_ANS_LEN = 32
|
||||
RPC_HANDSHAKE_ANS_LEN = 32
|
||||
|
||||
writer_tgt = MTProtoFrameStreamWriter(writer_tgt, START_SEQ_NO)
|
||||
key_selector = PROXY_SECRET[:4]
|
||||
crypto_ts = int.to_bytes(int(time.time()) % (256**4), 4, "little")
|
||||
|
||||
nonce = myrandom.getrandbytes(NONCE_LEN)
|
||||
|
||||
msg = RPC_NONCE + key_selector + CRYPTO_AES + crypto_ts + nonce
|
||||
|
||||
writer_tgt.write(msg)
|
||||
await writer_tgt.drain()
|
||||
|
||||
reader_tgt = MTProtoFrameStreamReader(reader_tgt, START_SEQ_NO)
|
||||
ans = await reader_tgt.read(get_to_clt_bufsize())
|
||||
|
||||
if len(ans) != RPC_NONCE_ANS_LEN:
|
||||
raise ConnectionAbortedError("bad rpc answer length")
|
||||
|
||||
rpc_type, rpc_key_selector, rpc_schema, rpc_crypto_ts, rpc_nonce = (
|
||||
ans[:4], ans[4:8], ans[8:12], ans[12:16], ans[16:32]
|
||||
)
|
||||
|
||||
if rpc_type != RPC_NONCE or rpc_key_selector != key_selector or rpc_schema != CRYPTO_AES:
|
||||
raise ConnectionAbortedError("bad rpc answer")
|
||||
|
||||
return reader_tgt, writer_tgt, nonce, rpc_nonce, crypto_ts
|
||||
|
||||
|
||||
async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
|
||||
RPC_HANDSHAKE = b"\xf5\xee\x82\x76"
|
||||
# pass as consts to simplify code
|
||||
RPC_FLAGS = b"\x00\x00\x00\x00"
|
||||
|
||||
RPC_HANDSHAKE_ANS_LEN = 32
|
||||
|
||||
global my_ip_info
|
||||
global tg_connection_pool
|
||||
|
||||
@@ -1368,43 +1405,19 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
|
||||
addr, port = myrandom.choice(TG_MIDDLE_PROXIES_V4[dc_idx])
|
||||
|
||||
try:
|
||||
reader_tgt, writer_tgt = await tg_connection_pool.get_connection(addr, port)
|
||||
ret = await tg_connection_pool.get_connection(addr, port,
|
||||
middleproxy_handshake_after_connect)
|
||||
reader_tgt, writer_tgt, nonce, rpc_nonce, crypto_ts = ret
|
||||
except ConnectionRefusedError as E:
|
||||
print_err("The Telegram server %d (%s %s) is refusing connections" % (dc_idx, addr, port))
|
||||
return False
|
||||
except ConnectionAbortedError as E:
|
||||
print_err("The Telegram server connection is bad: %d (%s %s) %s" % (dc_idx, addr, port, E))
|
||||
return False
|
||||
except (OSError, asyncio.TimeoutError) as E:
|
||||
print_err("Unable to connect to the Telegram server %d (%s %s)" % (dc_idx, addr, port))
|
||||
return False
|
||||
|
||||
set_keepalive(writer_tgt.get_extra_info("socket"))
|
||||
set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize())
|
||||
|
||||
writer_tgt = MTProtoFrameStreamWriter(writer_tgt, START_SEQ_NO)
|
||||
|
||||
key_selector = PROXY_SECRET[:4]
|
||||
crypto_ts = int.to_bytes(int(time.time()) % (256**4), 4, "little")
|
||||
|
||||
nonce = myrandom.getrandbytes(NONCE_LEN)
|
||||
|
||||
msg = RPC_NONCE + key_selector + CRYPTO_AES + crypto_ts + nonce
|
||||
|
||||
writer_tgt.write(msg)
|
||||
await writer_tgt.drain()
|
||||
|
||||
old_reader = reader_tgt
|
||||
reader_tgt = MTProtoFrameStreamReader(reader_tgt, START_SEQ_NO)
|
||||
ans = await reader_tgt.read(get_to_clt_bufsize())
|
||||
|
||||
if len(ans) != RPC_NONCE_ANS_LEN:
|
||||
return False
|
||||
|
||||
rpc_type, rpc_key_selector, rpc_schema, rpc_crypto_ts, rpc_nonce = (
|
||||
ans[:4], ans[4:8], ans[8:12], ans[12:16], ans[16:32]
|
||||
)
|
||||
|
||||
if rpc_type != RPC_NONCE or rpc_key_selector != key_selector or rpc_schema != CRYPTO_AES:
|
||||
return False
|
||||
|
||||
# get keys
|
||||
tg_ip, tg_port = writer_tgt.upstream.get_extra_info('peername')[:2]
|
||||
my_ip, my_port = writer_tgt.upstream.get_extra_info('sockname')[:2]
|
||||
|
||||
Reference in New Issue
Block a user