diff --git a/mtprotoproxy.py b/mtprotoproxy.py index b3a229b..3d16f8f 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -72,6 +72,7 @@ else: PORT = config["PORT"] USERS = config["USERS"] +AD_TAG = bytes.fromhex(config.get("AD_TAG", "")) # load advanced settings PREFER_IPV6 = config.get("PREFER_IPV6", socket.has_ipv6) @@ -82,7 +83,7 @@ PROXY_INFO_UPDATE_PERIOD = config.get("PROXY_INFO_UPDATE_PERIOD", 60*60*24) READ_BUF_SIZE = config.get("READ_BUF_SIZE", 16384) WRITE_BUF_SIZE = config.get("WRITE_BUF_SIZE", 65536) CLIENT_KEEPALIVE = config.get("CLIENT_KEEPALIVE", 60*30) -AD_TAG = bytes.fromhex(config.get("AD_TAG", "")) +CLIENT_HANDSHAKE_TIMEOUT = config.get("CLIENT_HANDSHAKE_TIMEOUT", 10) TG_DATACENTER_PORT = 443 @@ -728,9 +729,13 @@ async def handle_client(reader_clt, writer_clt): set_keepalive(writer_clt.get_extra_info("socket"), CLIENT_KEEPALIVE) set_bufsizes(writer_clt.get_extra_info("socket")) - clt_data = await handle_handshake(reader_clt, writer_clt) + try: + clt_data = await asyncio.wait_for(handle_handshake(reader_clt, writer_clt), + timeout=CLIENT_HANDSHAKE_TIMEOUT) + except asyncio.TimeoutError: + return + if not clt_data: - writer_clt.transport.abort() return reader_clt, writer_clt, proto_tag, user, dc_idx, enc_key_and_iv = clt_data @@ -747,7 +752,6 @@ async def handle_client(reader_clt, writer_clt): tg_data = await do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port) if not tg_data: - writer_clt.transport.abort() return reader_tg, writer_tg = tg_data @@ -800,14 +804,19 @@ async def handle_client(reader_clt, writer_clt): wr.transport.abort() update_stats(user, curr_connects_x2=-1) - asyncio.ensure_future(connect_reader_to_writer(reader_tg, writer_clt, user)) - asyncio.ensure_future(connect_reader_to_writer(reader_clt, writer_tg, user)) + task_tg_to_clt = connect_reader_to_writer(reader_tg, writer_clt, user) + task_clt_to_tg = connect_reader_to_writer(reader_clt, writer_tg, user) + + await asyncio.wait([task_tg_to_clt, task_clt_to_tg], return_when=asyncio.FIRST_COMPLETED) + writer_tg.transport.abort() async def handle_client_wrapper(reader, writer): try: await handle_client(reader, writer) except (asyncio.IncompleteReadError, ConnectionResetError, TimeoutError): + pass + finally: writer.transport.abort()