diff --git a/mtprotoproxy.py b/mtprotoproxy.py index 71e4bcd..4020056 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -56,9 +56,11 @@ STATS_PRINT_PERIOD = config.get("STATS_PRINT_PERIOD", 600) PROXY_INFO_UPDATE_PERIOD = config.get("PROXY_INFO_UPDATE_PERIOD", 24*60*60) # max socket buffer size to the client direction, the more the faster, but more RAM hungry -TO_CLT_BUFSIZE = config.get("TO_CLT_BUFSIZE", 16384) +# can be tuple for adaptive case: (low, users_margin, high) +TO_CLT_BUFSIZE = config.get("TO_CLT_BUFSIZE", (16384, 100, 131072)) # max socket buffer size to the telegram servers direction +# also can be tuple TO_TG_BUFSIZE = config.get("TO_TG_BUFSIZE", 65536) # keepalive period for clients in secs @@ -276,6 +278,31 @@ def update_stats(user, connects=0, curr_connects=0, octets=0, msgs=0): octets=octets, msgs=msgs) +def get_curr_connects_count(): + global stats + + all_connects = 0 + for user, stat in stats.items(): + all_connects += stat["curr_connects"] + return all_connects + + +def get_to_tg_bufsize(): + if isinstance(TO_TG_BUFSIZE, int): + return TO_TG_BUFSIZE + + low, margin, high = TO_TG_BUFSIZE + return high if get_curr_connects_count() < margin else low + + +def get_to_clt_bufsize(): + if isinstance(TO_CLT_BUFSIZE, int): + return TO_CLT_BUFSIZE + + low, margin, high = TO_CLT_BUFSIZE + return high if get_curr_connects_count() < margin else low + + class LayeredStreamReaderBase: def __init__(self, upstream): self.upstream = upstream @@ -722,7 +749,7 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): try: reader_tgt, writer_tgt = await open_connection_tryer( - dc, TG_DATACENTER_PORT, limit=TO_CLT_BUFSIZE, timeout=TG_CONNECT_TIMEOUT) + dc, TG_DATACENTER_PORT, limit=get_to_clt_bufsize(), timeout=TG_CONNECT_TIMEOUT) except ConnectionRefusedError as E: print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT) return False @@ -731,7 +758,7 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): return False set_keepalive(writer_tgt.get_extra_info("socket")) - set_bufsizes(writer_tgt.get_extra_info("socket"), TO_CLT_BUFSIZE, TO_TG_BUFSIZE) + set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize()) while True: rnd = bytearray([random.randrange(0, 256) for i in range(HANDSHAKE_LEN)]) @@ -822,7 +849,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): addr, port = random.choice(TG_MIDDLE_PROXIES_V4[dc_idx]) try: - reader_tgt, writer_tgt = await open_connection_tryer(addr, port, limit=TO_CLT_BUFSIZE, + reader_tgt, writer_tgt = await open_connection_tryer(addr, port, limit=get_to_clt_bufsize(), timeout=TG_CONNECT_TIMEOUT) except ConnectionRefusedError as E: print_err("Got connection refused while trying to connect to", addr, port) @@ -832,7 +859,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): return False set_keepalive(writer_tgt.get_extra_info("socket")) - set_bufsizes(writer_tgt.get_extra_info("socket"), TO_CLT_BUFSIZE, TO_TG_BUFSIZE) + set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize()) writer_tgt = MTProtoFrameStreamWriter(writer_tgt, START_SEQ_NO) @@ -848,7 +875,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): old_reader = reader_tgt reader_tgt = MTProtoFrameStreamReader(reader_tgt, START_SEQ_NO) - ans = await reader_tgt.read(TO_CLT_BUFSIZE) + ans = await reader_tgt.read(get_to_clt_bufsize()) if len(ans) != RPC_NONCE_ANS_LEN: return False @@ -931,7 +958,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): async def handle_client(reader_clt, writer_clt): set_keepalive(writer_clt.get_extra_info("socket"), CLIENT_KEEPALIVE, attempts=3) set_ack_timeout(writer_clt.get_extra_info("socket"), CLIENT_ACK_TIMEOUT) - set_bufsizes(writer_clt.get_extra_info("socket"), TO_TG_BUFSIZE, TO_CLT_BUFSIZE) + set_bufsizes(writer_clt.get_extra_info("socket"), get_to_tg_bufsize(), get_to_clt_bufsize()) cl_ip, cl_port = writer_clt.get_extra_info('peername')[:2] try: @@ -1020,9 +1047,9 @@ async def handle_client(reader_clt, writer_clt): # print_err(e) pass - tg_to_clt = connect_reader_to_writer(reader_tg, writer_clt, user, TO_CLT_BUFSIZE, + tg_to_clt = connect_reader_to_writer(reader_tg, writer_clt, user, get_to_clt_bufsize(), block_short_first_pkt=BLOCK_SHORT_FIRST_PKT) - clt_to_tg = connect_reader_to_writer(reader_clt, writer_tg, user, TO_TG_BUFSIZE) + clt_to_tg = connect_reader_to_writer(reader_clt, writer_tg, user, get_to_tg_bufsize()) task_tg_to_clt = asyncio.ensure_future(tg_to_clt) task_clt_to_tg = asyncio.ensure_future(clt_to_tg) @@ -1238,12 +1265,12 @@ def main(): reuse_port = hasattr(socket, "SO_REUSEPORT") task_v4 = asyncio.start_server(handle_client_wrapper, LISTEN_ADDR_IPV4, PORT, - limit=TO_TG_BUFSIZE, reuse_port=reuse_port, loop=loop) + limit=get_to_tg_bufsize(), reuse_port=reuse_port, loop=loop) server_v4 = loop.run_until_complete(task_v4) if socket.has_ipv6: task_v6 = asyncio.start_server(handle_client_wrapper, LISTEN_ADDR_IPV6, PORT, - limit=TO_TG_BUFSIZE, reuse_port=reuse_port, loop=loop) + limit=get_to_tg_bufsize(), reuse_port=reuse_port, loop=loop) server_v6 = loop.run_until_complete(task_v6) try: