mirror of
https://github.com/alexbers/mtprotoproxy.git
synced 2026-03-14 07:13:09 +00:00
init pooled connections to save one more round trip time
This commit is contained in:
119
mtprotoproxy.py
119
mtprotoproxy.py
@@ -442,42 +442,49 @@ class TgConnectionPool:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.pools = {}
|
self.pools = {}
|
||||||
|
|
||||||
async def open_tg_connection(self, host, port):
|
async def open_tg_connection(self, host, port, init_func=None):
|
||||||
task = asyncio.open_connection(host, port, limit=get_to_clt_bufsize())
|
task = asyncio.open_connection(host, port, limit=get_to_clt_bufsize())
|
||||||
reader_tgt, writer_tgt = await asyncio.wait_for(task, timeout=config.TG_CONNECT_TIMEOUT)
|
reader_tgt, writer_tgt = await asyncio.wait_for(task, timeout=config.TG_CONNECT_TIMEOUT)
|
||||||
|
|
||||||
|
set_keepalive(writer_tgt.get_extra_info("socket"))
|
||||||
|
set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize())
|
||||||
|
|
||||||
|
if init_func:
|
||||||
|
return await asyncio.wait_for(init_func(host, port, reader_tgt, writer_tgt),
|
||||||
|
timeout=config.TG_CONNECT_TIMEOUT)
|
||||||
return reader_tgt, writer_tgt
|
return reader_tgt, writer_tgt
|
||||||
|
|
||||||
def register_host_port(self, host, port):
|
def register_host_port(self, host, port, init_func):
|
||||||
if (host, port) not in self.pools:
|
if (host, port, init_func) not in self.pools:
|
||||||
self.pools[(host, port)] = []
|
self.pools[(host, port, init_func)] = []
|
||||||
|
|
||||||
while len(self.pools[(host, port)]) < TgConnectionPool.MAX_CONNS_IN_POOL:
|
while len(self.pools[(host, port, init_func)]) < TgConnectionPool.MAX_CONNS_IN_POOL:
|
||||||
connect_task = asyncio.ensure_future(self.open_tg_connection(host, port))
|
connect_task = asyncio.ensure_future(self.open_tg_connection(host, port, init_func))
|
||||||
self.pools[(host, port)].append(connect_task)
|
self.pools[(host, port, init_func)].append(connect_task)
|
||||||
|
|
||||||
async def get_connection(self, host, port):
|
async def get_connection(self, host, port, init_func=None):
|
||||||
self.register_host_port(host, port)
|
self.register_host_port(host, port, init_func)
|
||||||
|
|
||||||
ret = None
|
ret = None
|
||||||
for task in self.pools[(host, port)][::]:
|
for task in self.pools[(host, port, init_func)][::]:
|
||||||
if task.done():
|
if task.done():
|
||||||
if task.exception():
|
if task.exception():
|
||||||
self.pools[(host, port)].remove(task)
|
self.pools[(host, port, init_func)].remove(task)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
reader, writer = task.result()
|
reader, writer, *other = task.result()
|
||||||
if writer.transport.is_closing():
|
if writer.transport.is_closing():
|
||||||
self.pools[(host, port)].remove(task)
|
self.pools[(host, port, init_func)].remove(task)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if not ret:
|
if not ret:
|
||||||
self.pools[(host, port)].remove(task)
|
self.pools[(host, port, init_func)].remove(task)
|
||||||
ret = (reader, writer)
|
ret = (reader, writer, *other)
|
||||||
|
|
||||||
self.register_host_port(host, port)
|
self.register_host_port(host, port, init_func)
|
||||||
if ret:
|
if ret:
|
||||||
return ret
|
return ret
|
||||||
return await self.open_tg_connection(host, port)
|
return await self.open_tg_connection(host, port, init_func)
|
||||||
|
|
||||||
|
|
||||||
tg_connection_pool = TgConnectionPool()
|
tg_connection_pool = TgConnectionPool()
|
||||||
@@ -1269,13 +1276,13 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None):
|
|||||||
except ConnectionRefusedError as E:
|
except ConnectionRefusedError as E:
|
||||||
print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT)
|
print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT)
|
||||||
return False
|
return False
|
||||||
|
except ConnectionAbortedError as E:
|
||||||
|
print_err("The Telegram server connection is bad: %d (%s %s) %s" % (dc_idx, addr, port, E))
|
||||||
|
return False
|
||||||
except (OSError, asyncio.TimeoutError) as E:
|
except (OSError, asyncio.TimeoutError) as E:
|
||||||
print_err("Unable to connect to", dc, TG_DATACENTER_PORT)
|
print_err("Unable to connect to", dc, TG_DATACENTER_PORT)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
set_keepalive(writer_tgt.get_extra_info("socket"))
|
|
||||||
set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize())
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
rnd = bytearray(myrandom.getrandbytes(HANDSHAKE_LEN))
|
rnd = bytearray(myrandom.getrandbytes(HANDSHAKE_LEN))
|
||||||
if rnd[:1] in RESERVED_NONCE_FIRST_CHARS:
|
if rnd[:1] in RESERVED_NONCE_FIRST_CHARS:
|
||||||
@@ -1338,20 +1345,50 @@ def get_middleproxy_aes_key_and_iv(nonce_srv, nonce_clt, clt_ts, srv_ip, clt_por
|
|||||||
return key, iv
|
return key, iv
|
||||||
|
|
||||||
|
|
||||||
async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
|
async def middleproxy_handshake_after_connect(host, port, reader_tgt, writer_tgt):
|
||||||
|
""" The first stage of middleproxy handshake """
|
||||||
START_SEQ_NO = -2
|
START_SEQ_NO = -2
|
||||||
NONCE_LEN = 16
|
NONCE_LEN = 16
|
||||||
|
|
||||||
RPC_NONCE = b"\xaa\x87\xcb\x7a"
|
RPC_NONCE = b"\xaa\x87\xcb\x7a"
|
||||||
RPC_HANDSHAKE = b"\xf5\xee\x82\x76"
|
|
||||||
CRYPTO_AES = b"\x01\x00\x00\x00"
|
CRYPTO_AES = b"\x01\x00\x00\x00"
|
||||||
|
|
||||||
RPC_NONCE_ANS_LEN = 32
|
RPC_NONCE_ANS_LEN = 32
|
||||||
RPC_HANDSHAKE_ANS_LEN = 32
|
|
||||||
|
|
||||||
|
writer_tgt = MTProtoFrameStreamWriter(writer_tgt, START_SEQ_NO)
|
||||||
|
key_selector = PROXY_SECRET[:4]
|
||||||
|
crypto_ts = int.to_bytes(int(time.time()) % (256**4), 4, "little")
|
||||||
|
|
||||||
|
nonce = myrandom.getrandbytes(NONCE_LEN)
|
||||||
|
|
||||||
|
msg = RPC_NONCE + key_selector + CRYPTO_AES + crypto_ts + nonce
|
||||||
|
|
||||||
|
writer_tgt.write(msg)
|
||||||
|
await writer_tgt.drain()
|
||||||
|
|
||||||
|
reader_tgt = MTProtoFrameStreamReader(reader_tgt, START_SEQ_NO)
|
||||||
|
ans = await reader_tgt.read(get_to_clt_bufsize())
|
||||||
|
|
||||||
|
if len(ans) != RPC_NONCE_ANS_LEN:
|
||||||
|
raise ConnectionAbortedError("bad rpc answer length")
|
||||||
|
|
||||||
|
rpc_type, rpc_key_selector, rpc_schema, rpc_crypto_ts, rpc_nonce = (
|
||||||
|
ans[:4], ans[4:8], ans[8:12], ans[12:16], ans[16:32]
|
||||||
|
)
|
||||||
|
|
||||||
|
if rpc_type != RPC_NONCE or rpc_key_selector != key_selector or rpc_schema != CRYPTO_AES:
|
||||||
|
raise ConnectionAbortedError("bad rpc answer")
|
||||||
|
|
||||||
|
return reader_tgt, writer_tgt, nonce, rpc_nonce, crypto_ts
|
||||||
|
|
||||||
|
|
||||||
|
async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
|
||||||
|
RPC_HANDSHAKE = b"\xf5\xee\x82\x76"
|
||||||
# pass as consts to simplify code
|
# pass as consts to simplify code
|
||||||
RPC_FLAGS = b"\x00\x00\x00\x00"
|
RPC_FLAGS = b"\x00\x00\x00\x00"
|
||||||
|
|
||||||
|
RPC_HANDSHAKE_ANS_LEN = 32
|
||||||
|
|
||||||
global my_ip_info
|
global my_ip_info
|
||||||
global tg_connection_pool
|
global tg_connection_pool
|
||||||
|
|
||||||
@@ -1368,43 +1405,19 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
|
|||||||
addr, port = myrandom.choice(TG_MIDDLE_PROXIES_V4[dc_idx])
|
addr, port = myrandom.choice(TG_MIDDLE_PROXIES_V4[dc_idx])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
reader_tgt, writer_tgt = await tg_connection_pool.get_connection(addr, port)
|
ret = await tg_connection_pool.get_connection(addr, port,
|
||||||
|
middleproxy_handshake_after_connect)
|
||||||
|
reader_tgt, writer_tgt, nonce, rpc_nonce, crypto_ts = ret
|
||||||
except ConnectionRefusedError as E:
|
except ConnectionRefusedError as E:
|
||||||
print_err("The Telegram server %d (%s %s) is refusing connections" % (dc_idx, addr, port))
|
print_err("The Telegram server %d (%s %s) is refusing connections" % (dc_idx, addr, port))
|
||||||
return False
|
return False
|
||||||
|
except ConnectionAbortedError as E:
|
||||||
|
print_err("The Telegram server connection is bad: %d (%s %s) %s" % (dc_idx, addr, port, E))
|
||||||
|
return False
|
||||||
except (OSError, asyncio.TimeoutError) as E:
|
except (OSError, asyncio.TimeoutError) as E:
|
||||||
print_err("Unable to connect to the Telegram server %d (%s %s)" % (dc_idx, addr, port))
|
print_err("Unable to connect to the Telegram server %d (%s %s)" % (dc_idx, addr, port))
|
||||||
return False
|
return False
|
||||||
|
|
||||||
set_keepalive(writer_tgt.get_extra_info("socket"))
|
|
||||||
set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize())
|
|
||||||
|
|
||||||
writer_tgt = MTProtoFrameStreamWriter(writer_tgt, START_SEQ_NO)
|
|
||||||
|
|
||||||
key_selector = PROXY_SECRET[:4]
|
|
||||||
crypto_ts = int.to_bytes(int(time.time()) % (256**4), 4, "little")
|
|
||||||
|
|
||||||
nonce = myrandom.getrandbytes(NONCE_LEN)
|
|
||||||
|
|
||||||
msg = RPC_NONCE + key_selector + CRYPTO_AES + crypto_ts + nonce
|
|
||||||
|
|
||||||
writer_tgt.write(msg)
|
|
||||||
await writer_tgt.drain()
|
|
||||||
|
|
||||||
old_reader = reader_tgt
|
|
||||||
reader_tgt = MTProtoFrameStreamReader(reader_tgt, START_SEQ_NO)
|
|
||||||
ans = await reader_tgt.read(get_to_clt_bufsize())
|
|
||||||
|
|
||||||
if len(ans) != RPC_NONCE_ANS_LEN:
|
|
||||||
return False
|
|
||||||
|
|
||||||
rpc_type, rpc_key_selector, rpc_schema, rpc_crypto_ts, rpc_nonce = (
|
|
||||||
ans[:4], ans[4:8], ans[8:12], ans[12:16], ans[16:32]
|
|
||||||
)
|
|
||||||
|
|
||||||
if rpc_type != RPC_NONCE or rpc_key_selector != key_selector or rpc_schema != CRYPTO_AES:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# get keys
|
# get keys
|
||||||
tg_ip, tg_port = writer_tgt.upstream.get_extra_info('peername')[:2]
|
tg_ip, tg_port = writer_tgt.upstream.get_extra_info('peername')[:2]
|
||||||
my_ip, my_port = writer_tgt.upstream.get_extra_info('sockname')[:2]
|
my_ip, my_port = writer_tgt.upstream.get_extra_info('sockname')[:2]
|
||||||
|
|||||||
Reference in New Issue
Block a user