init pooled connections to save one more round trip time

This commit is contained in:
Alexander Bersenev
2019-11-13 02:31:51 +05:00
parent 8c15fc8fe0
commit 4a4d449a34

View File

@@ -442,42 +442,49 @@ class TgConnectionPool:
def __init__(self): def __init__(self):
self.pools = {} self.pools = {}
async def open_tg_connection(self, host, port): async def open_tg_connection(self, host, port, init_func=None):
task = asyncio.open_connection(host, port, limit=get_to_clt_bufsize()) task = asyncio.open_connection(host, port, limit=get_to_clt_bufsize())
reader_tgt, writer_tgt = await asyncio.wait_for(task, timeout=config.TG_CONNECT_TIMEOUT) reader_tgt, writer_tgt = await asyncio.wait_for(task, timeout=config.TG_CONNECT_TIMEOUT)
set_keepalive(writer_tgt.get_extra_info("socket"))
set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize())
if init_func:
return await asyncio.wait_for(init_func(host, port, reader_tgt, writer_tgt),
timeout=config.TG_CONNECT_TIMEOUT)
return reader_tgt, writer_tgt return reader_tgt, writer_tgt
def register_host_port(self, host, port): def register_host_port(self, host, port, init_func):
if (host, port) not in self.pools: if (host, port, init_func) not in self.pools:
self.pools[(host, port)] = [] self.pools[(host, port, init_func)] = []
while len(self.pools[(host, port)]) < TgConnectionPool.MAX_CONNS_IN_POOL: while len(self.pools[(host, port, init_func)]) < TgConnectionPool.MAX_CONNS_IN_POOL:
connect_task = asyncio.ensure_future(self.open_tg_connection(host, port)) connect_task = asyncio.ensure_future(self.open_tg_connection(host, port, init_func))
self.pools[(host, port)].append(connect_task) self.pools[(host, port, init_func)].append(connect_task)
async def get_connection(self, host, port): async def get_connection(self, host, port, init_func=None):
self.register_host_port(host, port) self.register_host_port(host, port, init_func)
ret = None ret = None
for task in self.pools[(host, port)][::]: for task in self.pools[(host, port, init_func)][::]:
if task.done(): if task.done():
if task.exception(): if task.exception():
self.pools[(host, port)].remove(task) self.pools[(host, port, init_func)].remove(task)
continue continue
reader, writer = task.result() reader, writer, *other = task.result()
if writer.transport.is_closing(): if writer.transport.is_closing():
self.pools[(host, port)].remove(task) self.pools[(host, port, init_func)].remove(task)
continue continue
if not ret: if not ret:
self.pools[(host, port)].remove(task) self.pools[(host, port, init_func)].remove(task)
ret = (reader, writer) ret = (reader, writer, *other)
self.register_host_port(host, port) self.register_host_port(host, port, init_func)
if ret: if ret:
return ret return ret
return await self.open_tg_connection(host, port) return await self.open_tg_connection(host, port, init_func)
tg_connection_pool = TgConnectionPool() tg_connection_pool = TgConnectionPool()
@@ -1269,13 +1276,13 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None):
except ConnectionRefusedError as E: except ConnectionRefusedError as E:
print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT) print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT)
return False return False
except ConnectionAbortedError as E:
print_err("The Telegram server connection is bad: %d (%s %s) %s" % (dc_idx, addr, port, E))
return False
except (OSError, asyncio.TimeoutError) as E: except (OSError, asyncio.TimeoutError) as E:
print_err("Unable to connect to", dc, TG_DATACENTER_PORT) print_err("Unable to connect to", dc, TG_DATACENTER_PORT)
return False return False
set_keepalive(writer_tgt.get_extra_info("socket"))
set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize())
while True: while True:
rnd = bytearray(myrandom.getrandbytes(HANDSHAKE_LEN)) rnd = bytearray(myrandom.getrandbytes(HANDSHAKE_LEN))
if rnd[:1] in RESERVED_NONCE_FIRST_CHARS: if rnd[:1] in RESERVED_NONCE_FIRST_CHARS:
@@ -1338,20 +1345,50 @@ def get_middleproxy_aes_key_and_iv(nonce_srv, nonce_clt, clt_ts, srv_ip, clt_por
return key, iv return key, iv
async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): async def middleproxy_handshake_after_connect(host, port, reader_tgt, writer_tgt):
""" The first stage of middleproxy handshake """
START_SEQ_NO = -2 START_SEQ_NO = -2
NONCE_LEN = 16 NONCE_LEN = 16
RPC_NONCE = b"\xaa\x87\xcb\x7a" RPC_NONCE = b"\xaa\x87\xcb\x7a"
RPC_HANDSHAKE = b"\xf5\xee\x82\x76"
CRYPTO_AES = b"\x01\x00\x00\x00" CRYPTO_AES = b"\x01\x00\x00\x00"
RPC_NONCE_ANS_LEN = 32 RPC_NONCE_ANS_LEN = 32
RPC_HANDSHAKE_ANS_LEN = 32
writer_tgt = MTProtoFrameStreamWriter(writer_tgt, START_SEQ_NO)
key_selector = PROXY_SECRET[:4]
crypto_ts = int.to_bytes(int(time.time()) % (256**4), 4, "little")
nonce = myrandom.getrandbytes(NONCE_LEN)
msg = RPC_NONCE + key_selector + CRYPTO_AES + crypto_ts + nonce
writer_tgt.write(msg)
await writer_tgt.drain()
reader_tgt = MTProtoFrameStreamReader(reader_tgt, START_SEQ_NO)
ans = await reader_tgt.read(get_to_clt_bufsize())
if len(ans) != RPC_NONCE_ANS_LEN:
raise ConnectionAbortedError("bad rpc answer length")
rpc_type, rpc_key_selector, rpc_schema, rpc_crypto_ts, rpc_nonce = (
ans[:4], ans[4:8], ans[8:12], ans[12:16], ans[16:32]
)
if rpc_type != RPC_NONCE or rpc_key_selector != key_selector or rpc_schema != CRYPTO_AES:
raise ConnectionAbortedError("bad rpc answer")
return reader_tgt, writer_tgt, nonce, rpc_nonce, crypto_ts
async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
RPC_HANDSHAKE = b"\xf5\xee\x82\x76"
# pass as consts to simplify code # pass as consts to simplify code
RPC_FLAGS = b"\x00\x00\x00\x00" RPC_FLAGS = b"\x00\x00\x00\x00"
RPC_HANDSHAKE_ANS_LEN = 32
global my_ip_info global my_ip_info
global tg_connection_pool global tg_connection_pool
@@ -1368,43 +1405,19 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
addr, port = myrandom.choice(TG_MIDDLE_PROXIES_V4[dc_idx]) addr, port = myrandom.choice(TG_MIDDLE_PROXIES_V4[dc_idx])
try: try:
reader_tgt, writer_tgt = await tg_connection_pool.get_connection(addr, port) ret = await tg_connection_pool.get_connection(addr, port,
middleproxy_handshake_after_connect)
reader_tgt, writer_tgt, nonce, rpc_nonce, crypto_ts = ret
except ConnectionRefusedError as E: except ConnectionRefusedError as E:
print_err("The Telegram server %d (%s %s) is refusing connections" % (dc_idx, addr, port)) print_err("The Telegram server %d (%s %s) is refusing connections" % (dc_idx, addr, port))
return False return False
except ConnectionAbortedError as E:
print_err("The Telegram server connection is bad: %d (%s %s) %s" % (dc_idx, addr, port, E))
return False
except (OSError, asyncio.TimeoutError) as E: except (OSError, asyncio.TimeoutError) as E:
print_err("Unable to connect to the Telegram server %d (%s %s)" % (dc_idx, addr, port)) print_err("Unable to connect to the Telegram server %d (%s %s)" % (dc_idx, addr, port))
return False return False
set_keepalive(writer_tgt.get_extra_info("socket"))
set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize())
writer_tgt = MTProtoFrameStreamWriter(writer_tgt, START_SEQ_NO)
key_selector = PROXY_SECRET[:4]
crypto_ts = int.to_bytes(int(time.time()) % (256**4), 4, "little")
nonce = myrandom.getrandbytes(NONCE_LEN)
msg = RPC_NONCE + key_selector + CRYPTO_AES + crypto_ts + nonce
writer_tgt.write(msg)
await writer_tgt.drain()
old_reader = reader_tgt
reader_tgt = MTProtoFrameStreamReader(reader_tgt, START_SEQ_NO)
ans = await reader_tgt.read(get_to_clt_bufsize())
if len(ans) != RPC_NONCE_ANS_LEN:
return False
rpc_type, rpc_key_selector, rpc_schema, rpc_crypto_ts, rpc_nonce = (
ans[:4], ans[4:8], ans[8:12], ans[12:16], ans[16:32]
)
if rpc_type != RPC_NONCE or rpc_key_selector != key_selector or rpc_schema != CRYPTO_AES:
return False
# get keys # get keys
tg_ip, tg_port = writer_tgt.upstream.get_extra_info('peername')[:2] tg_ip, tg_port = writer_tgt.upstream.get_extra_info('peername')[:2]
my_ip, my_port = writer_tgt.upstream.get_extra_info('sockname')[:2] my_ip, my_port = writer_tgt.upstream.get_extra_info('sockname')[:2]