adaptive buffer sizes

This commit is contained in:
Alexander Bersenev
2019-05-09 02:51:36 +05:00
parent d48c177e36
commit 6f70ff3003

View File

@@ -56,9 +56,11 @@ STATS_PRINT_PERIOD = config.get("STATS_PRINT_PERIOD", 600)
PROXY_INFO_UPDATE_PERIOD = config.get("PROXY_INFO_UPDATE_PERIOD", 24*60*60)
# max socket buffer size to the client direction, the more the faster, but more RAM hungry
TO_CLT_BUFSIZE = config.get("TO_CLT_BUFSIZE", 16384)
# can be tuple for adaptive case: (low, users_margin, high)
TO_CLT_BUFSIZE = config.get("TO_CLT_BUFSIZE", (16384, 100, 131072))
# max socket buffer size to the telegram servers direction
# also can be tuple
TO_TG_BUFSIZE = config.get("TO_TG_BUFSIZE", 65536)
# keepalive period for clients in secs
@@ -276,6 +278,31 @@ def update_stats(user, connects=0, curr_connects=0, octets=0, msgs=0):
octets=octets, msgs=msgs)
def get_curr_connects_count():
global stats
all_connects = 0
for user, stat in stats.items():
all_connects += stat["curr_connects"]
return all_connects
def get_to_tg_bufsize():
if isinstance(TO_TG_BUFSIZE, int):
return TO_TG_BUFSIZE
low, margin, high = TO_TG_BUFSIZE
return high if get_curr_connects_count() < margin else low
def get_to_clt_bufsize():
if isinstance(TO_CLT_BUFSIZE, int):
return TO_CLT_BUFSIZE
low, margin, high = TO_CLT_BUFSIZE
return high if get_curr_connects_count() < margin else low
class LayeredStreamReaderBase:
def __init__(self, upstream):
self.upstream = upstream
@@ -722,7 +749,7 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None):
try:
reader_tgt, writer_tgt = await open_connection_tryer(
dc, TG_DATACENTER_PORT, limit=TO_CLT_BUFSIZE, timeout=TG_CONNECT_TIMEOUT)
dc, TG_DATACENTER_PORT, limit=get_to_clt_bufsize(), timeout=TG_CONNECT_TIMEOUT)
except ConnectionRefusedError as E:
print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT)
return False
@@ -731,7 +758,7 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None):
return False
set_keepalive(writer_tgt.get_extra_info("socket"))
set_bufsizes(writer_tgt.get_extra_info("socket"), TO_CLT_BUFSIZE, TO_TG_BUFSIZE)
set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize())
while True:
rnd = bytearray([random.randrange(0, 256) for i in range(HANDSHAKE_LEN)])
@@ -822,7 +849,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
addr, port = random.choice(TG_MIDDLE_PROXIES_V4[dc_idx])
try:
reader_tgt, writer_tgt = await open_connection_tryer(addr, port, limit=TO_CLT_BUFSIZE,
reader_tgt, writer_tgt = await open_connection_tryer(addr, port, limit=get_to_clt_bufsize(),
timeout=TG_CONNECT_TIMEOUT)
except ConnectionRefusedError as E:
print_err("Got connection refused while trying to connect to", addr, port)
@@ -832,7 +859,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
return False
set_keepalive(writer_tgt.get_extra_info("socket"))
set_bufsizes(writer_tgt.get_extra_info("socket"), TO_CLT_BUFSIZE, TO_TG_BUFSIZE)
set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize())
writer_tgt = MTProtoFrameStreamWriter(writer_tgt, START_SEQ_NO)
@@ -848,7 +875,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
old_reader = reader_tgt
reader_tgt = MTProtoFrameStreamReader(reader_tgt, START_SEQ_NO)
ans = await reader_tgt.read(TO_CLT_BUFSIZE)
ans = await reader_tgt.read(get_to_clt_bufsize())
if len(ans) != RPC_NONCE_ANS_LEN:
return False
@@ -931,7 +958,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
async def handle_client(reader_clt, writer_clt):
set_keepalive(writer_clt.get_extra_info("socket"), CLIENT_KEEPALIVE, attempts=3)
set_ack_timeout(writer_clt.get_extra_info("socket"), CLIENT_ACK_TIMEOUT)
set_bufsizes(writer_clt.get_extra_info("socket"), TO_TG_BUFSIZE, TO_CLT_BUFSIZE)
set_bufsizes(writer_clt.get_extra_info("socket"), get_to_tg_bufsize(), get_to_clt_bufsize())
cl_ip, cl_port = writer_clt.get_extra_info('peername')[:2]
try:
@@ -1020,9 +1047,9 @@ async def handle_client(reader_clt, writer_clt):
# print_err(e)
pass
tg_to_clt = connect_reader_to_writer(reader_tg, writer_clt, user, TO_CLT_BUFSIZE,
tg_to_clt = connect_reader_to_writer(reader_tg, writer_clt, user, get_to_clt_bufsize(),
block_short_first_pkt=BLOCK_SHORT_FIRST_PKT)
clt_to_tg = connect_reader_to_writer(reader_clt, writer_tg, user, TO_TG_BUFSIZE)
clt_to_tg = connect_reader_to_writer(reader_clt, writer_tg, user, get_to_tg_bufsize())
task_tg_to_clt = asyncio.ensure_future(tg_to_clt)
task_clt_to_tg = asyncio.ensure_future(clt_to_tg)
@@ -1238,12 +1265,12 @@ def main():
reuse_port = hasattr(socket, "SO_REUSEPORT")
task_v4 = asyncio.start_server(handle_client_wrapper, LISTEN_ADDR_IPV4, PORT,
limit=TO_TG_BUFSIZE, reuse_port=reuse_port, loop=loop)
limit=get_to_tg_bufsize(), reuse_port=reuse_port, loop=loop)
server_v4 = loop.run_until_complete(task_v4)
if socket.has_ipv6:
task_v6 = asyncio.start_server(handle_client_wrapper, LISTEN_ADDR_IPV6, PORT,
limit=TO_TG_BUFSIZE, reuse_port=reuse_port, loop=loop)
limit=get_to_tg_bufsize(), reuse_port=reuse_port, loop=loop)
server_v6 = loop.run_until_complete(task_v6)
try: