diff --git a/mtprotoproxy.py b/mtprotoproxy.py index 5a2577d..402bb8b 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -436,6 +436,48 @@ class MyRandom(random.Random): myrandom = MyRandom() +class TgConnectionPool: + MAX_CONNS_IN_POOL = 4 + + def __init__(self): + self.pools = {} + + async def open_tg_connection(self, host, port): + task = asyncio.open_connection(host, port, limit=get_to_clt_bufsize()) + reader_tgt, writer_tgt = await asyncio.wait_for(task, timeout=config.TG_CONNECT_TIMEOUT) + return reader_tgt, writer_tgt + + def register_host_port(self, host, port): + if (host, port) not in self.pools: + self.pools[(host, port)] = set() + + while len(self.pools[(host, port)]) < TgConnectionPool.MAX_CONNS_IN_POOL: + connect_task = asyncio.ensure_future(self.open_tg_connection(host, port)) + self.pools[(host, port)].add(connect_task) + + async def get_connection(self, host, port): + self.register_host_port(host, port) + + for task in self.pools[(host, port)].copy(): + if task.done(): + self.pools[(host, port)].remove(task) + self.register_host_port(host, port) + + if task.exception(): + continue + + reader, writer = task.result() + if writer.transport.is_closing(): + continue + + return reader, writer + + return await self.open_tg_connection(host, port) + + +tg_connection_pool = TgConnectionPool() + + class LayeredStreamReaderBase: __slots__ = ("upstream", ) @@ -1196,21 +1238,6 @@ async def handle_handshake(reader, writer): return False -async def open_connection_tryer(addr, port, limit, timeout, max_attempts=3): - for attempt in range(max_attempts-1): - try: - task = asyncio.open_connection(addr, port, limit=limit) - reader_tgt, writer_tgt = await asyncio.wait_for(task, timeout=timeout) - return reader_tgt, writer_tgt - except (OSError, asyncio.TimeoutError): - continue - - # the last attempt - task = asyncio.open_connection(addr, port, limit=limit) - reader_tgt, writer_tgt = await asyncio.wait_for(task, timeout=timeout) - return reader_tgt, writer_tgt - - async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): RESERVED_NONCE_FIRST_CHARS = [b"\xef"] RESERVED_NONCE_BEGININGS = [b"\x48\x45\x41\x44", b"\x50\x4F\x53\x54", @@ -1219,6 +1246,7 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): RESERVED_NONCE_CONTINUES = [b"\x00\x00\x00\x00"] global my_ip_info + global tg_connection_pool dc_idx = abs(dc_idx) - 1 @@ -1232,8 +1260,7 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): dc = TG_DATACENTERS_V4[dc_idx] try: - reader_tgt, writer_tgt = await open_connection_tryer( - dc, TG_DATACENTER_PORT, limit=get_to_clt_bufsize(), timeout=config.TG_CONNECT_TIMEOUT) + reader_tgt, writer_tgt = await tg_connection_pool.get_connection(dc, TG_DATACENTER_PORT) except ConnectionRefusedError as E: print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT) return False @@ -1321,6 +1348,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): RPC_FLAGS = b"\x00\x00\x00\x00" global my_ip_info + global tg_connection_pool use_ipv6_tg = (my_ip_info["ipv6"] and (config.PREFER_IPV6 or not my_ip_info["ipv4"])) use_ipv6_clt = (":" in cl_ip) @@ -1335,8 +1363,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): addr, port = myrandom.choice(TG_MIDDLE_PROXIES_V4[dc_idx]) try: - reader_tgt, writer_tgt = await open_connection_tryer(addr, port, limit=get_to_clt_bufsize(), - timeout=config.TG_CONNECT_TIMEOUT) + reader_tgt, writer_tgt = await tg_connection_pool.get_connection(addr, port) except ConnectionRefusedError as E: print_err("The Telegram server %d (%s %s) is refusing connections" % (dc_idx, addr, port)) return False