mirror of
https://github.com/alexbers/mtprotoproxy.git
synced 2026-03-14 07:13:09 +00:00
adaptive buffer sizes
This commit is contained in:
@@ -56,9 +56,11 @@ STATS_PRINT_PERIOD = config.get("STATS_PRINT_PERIOD", 600)
|
|||||||
PROXY_INFO_UPDATE_PERIOD = config.get("PROXY_INFO_UPDATE_PERIOD", 24*60*60)
|
PROXY_INFO_UPDATE_PERIOD = config.get("PROXY_INFO_UPDATE_PERIOD", 24*60*60)
|
||||||
|
|
||||||
# max socket buffer size to the client direction, the more the faster, but more RAM hungry
|
# max socket buffer size to the client direction, the more the faster, but more RAM hungry
|
||||||
TO_CLT_BUFSIZE = config.get("TO_CLT_BUFSIZE", 16384)
|
# can be tuple for adaptive case: (low, users_margin, high)
|
||||||
|
TO_CLT_BUFSIZE = config.get("TO_CLT_BUFSIZE", (16384, 100, 131072))
|
||||||
|
|
||||||
# max socket buffer size to the telegram servers direction
|
# max socket buffer size to the telegram servers direction
|
||||||
|
# also can be tuple
|
||||||
TO_TG_BUFSIZE = config.get("TO_TG_BUFSIZE", 65536)
|
TO_TG_BUFSIZE = config.get("TO_TG_BUFSIZE", 65536)
|
||||||
|
|
||||||
# keepalive period for clients in secs
|
# keepalive period for clients in secs
|
||||||
@@ -276,6 +278,31 @@ def update_stats(user, connects=0, curr_connects=0, octets=0, msgs=0):
|
|||||||
octets=octets, msgs=msgs)
|
octets=octets, msgs=msgs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_curr_connects_count():
|
||||||
|
global stats
|
||||||
|
|
||||||
|
all_connects = 0
|
||||||
|
for user, stat in stats.items():
|
||||||
|
all_connects += stat["curr_connects"]
|
||||||
|
return all_connects
|
||||||
|
|
||||||
|
|
||||||
|
def get_to_tg_bufsize():
|
||||||
|
if isinstance(TO_TG_BUFSIZE, int):
|
||||||
|
return TO_TG_BUFSIZE
|
||||||
|
|
||||||
|
low, margin, high = TO_TG_BUFSIZE
|
||||||
|
return high if get_curr_connects_count() < margin else low
|
||||||
|
|
||||||
|
|
||||||
|
def get_to_clt_bufsize():
|
||||||
|
if isinstance(TO_CLT_BUFSIZE, int):
|
||||||
|
return TO_CLT_BUFSIZE
|
||||||
|
|
||||||
|
low, margin, high = TO_CLT_BUFSIZE
|
||||||
|
return high if get_curr_connects_count() < margin else low
|
||||||
|
|
||||||
|
|
||||||
class LayeredStreamReaderBase:
|
class LayeredStreamReaderBase:
|
||||||
def __init__(self, upstream):
|
def __init__(self, upstream):
|
||||||
self.upstream = upstream
|
self.upstream = upstream
|
||||||
@@ -722,7 +749,7 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
reader_tgt, writer_tgt = await open_connection_tryer(
|
reader_tgt, writer_tgt = await open_connection_tryer(
|
||||||
dc, TG_DATACENTER_PORT, limit=TO_CLT_BUFSIZE, timeout=TG_CONNECT_TIMEOUT)
|
dc, TG_DATACENTER_PORT, limit=get_to_clt_bufsize(), timeout=TG_CONNECT_TIMEOUT)
|
||||||
except ConnectionRefusedError as E:
|
except ConnectionRefusedError as E:
|
||||||
print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT)
|
print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT)
|
||||||
return False
|
return False
|
||||||
@@ -731,7 +758,7 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
set_keepalive(writer_tgt.get_extra_info("socket"))
|
set_keepalive(writer_tgt.get_extra_info("socket"))
|
||||||
set_bufsizes(writer_tgt.get_extra_info("socket"), TO_CLT_BUFSIZE, TO_TG_BUFSIZE)
|
set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize())
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
rnd = bytearray([random.randrange(0, 256) for i in range(HANDSHAKE_LEN)])
|
rnd = bytearray([random.randrange(0, 256) for i in range(HANDSHAKE_LEN)])
|
||||||
@@ -822,7 +849,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
|
|||||||
addr, port = random.choice(TG_MIDDLE_PROXIES_V4[dc_idx])
|
addr, port = random.choice(TG_MIDDLE_PROXIES_V4[dc_idx])
|
||||||
|
|
||||||
try:
|
try:
|
||||||
reader_tgt, writer_tgt = await open_connection_tryer(addr, port, limit=TO_CLT_BUFSIZE,
|
reader_tgt, writer_tgt = await open_connection_tryer(addr, port, limit=get_to_clt_bufsize(),
|
||||||
timeout=TG_CONNECT_TIMEOUT)
|
timeout=TG_CONNECT_TIMEOUT)
|
||||||
except ConnectionRefusedError as E:
|
except ConnectionRefusedError as E:
|
||||||
print_err("Got connection refused while trying to connect to", addr, port)
|
print_err("Got connection refused while trying to connect to", addr, port)
|
||||||
@@ -832,7 +859,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
set_keepalive(writer_tgt.get_extra_info("socket"))
|
set_keepalive(writer_tgt.get_extra_info("socket"))
|
||||||
set_bufsizes(writer_tgt.get_extra_info("socket"), TO_CLT_BUFSIZE, TO_TG_BUFSIZE)
|
set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize())
|
||||||
|
|
||||||
writer_tgt = MTProtoFrameStreamWriter(writer_tgt, START_SEQ_NO)
|
writer_tgt = MTProtoFrameStreamWriter(writer_tgt, START_SEQ_NO)
|
||||||
|
|
||||||
@@ -848,7 +875,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
|
|||||||
|
|
||||||
old_reader = reader_tgt
|
old_reader = reader_tgt
|
||||||
reader_tgt = MTProtoFrameStreamReader(reader_tgt, START_SEQ_NO)
|
reader_tgt = MTProtoFrameStreamReader(reader_tgt, START_SEQ_NO)
|
||||||
ans = await reader_tgt.read(TO_CLT_BUFSIZE)
|
ans = await reader_tgt.read(get_to_clt_bufsize())
|
||||||
|
|
||||||
if len(ans) != RPC_NONCE_ANS_LEN:
|
if len(ans) != RPC_NONCE_ANS_LEN:
|
||||||
return False
|
return False
|
||||||
@@ -931,7 +958,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
|
|||||||
async def handle_client(reader_clt, writer_clt):
|
async def handle_client(reader_clt, writer_clt):
|
||||||
set_keepalive(writer_clt.get_extra_info("socket"), CLIENT_KEEPALIVE, attempts=3)
|
set_keepalive(writer_clt.get_extra_info("socket"), CLIENT_KEEPALIVE, attempts=3)
|
||||||
set_ack_timeout(writer_clt.get_extra_info("socket"), CLIENT_ACK_TIMEOUT)
|
set_ack_timeout(writer_clt.get_extra_info("socket"), CLIENT_ACK_TIMEOUT)
|
||||||
set_bufsizes(writer_clt.get_extra_info("socket"), TO_TG_BUFSIZE, TO_CLT_BUFSIZE)
|
set_bufsizes(writer_clt.get_extra_info("socket"), get_to_tg_bufsize(), get_to_clt_bufsize())
|
||||||
|
|
||||||
cl_ip, cl_port = writer_clt.get_extra_info('peername')[:2]
|
cl_ip, cl_port = writer_clt.get_extra_info('peername')[:2]
|
||||||
try:
|
try:
|
||||||
@@ -1020,9 +1047,9 @@ async def handle_client(reader_clt, writer_clt):
|
|||||||
# print_err(e)
|
# print_err(e)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
tg_to_clt = connect_reader_to_writer(reader_tg, writer_clt, user, TO_CLT_BUFSIZE,
|
tg_to_clt = connect_reader_to_writer(reader_tg, writer_clt, user, get_to_clt_bufsize(),
|
||||||
block_short_first_pkt=BLOCK_SHORT_FIRST_PKT)
|
block_short_first_pkt=BLOCK_SHORT_FIRST_PKT)
|
||||||
clt_to_tg = connect_reader_to_writer(reader_clt, writer_tg, user, TO_TG_BUFSIZE)
|
clt_to_tg = connect_reader_to_writer(reader_clt, writer_tg, user, get_to_tg_bufsize())
|
||||||
task_tg_to_clt = asyncio.ensure_future(tg_to_clt)
|
task_tg_to_clt = asyncio.ensure_future(tg_to_clt)
|
||||||
task_clt_to_tg = asyncio.ensure_future(clt_to_tg)
|
task_clt_to_tg = asyncio.ensure_future(clt_to_tg)
|
||||||
|
|
||||||
@@ -1238,12 +1265,12 @@ def main():
|
|||||||
reuse_port = hasattr(socket, "SO_REUSEPORT")
|
reuse_port = hasattr(socket, "SO_REUSEPORT")
|
||||||
|
|
||||||
task_v4 = asyncio.start_server(handle_client_wrapper, LISTEN_ADDR_IPV4, PORT,
|
task_v4 = asyncio.start_server(handle_client_wrapper, LISTEN_ADDR_IPV4, PORT,
|
||||||
limit=TO_TG_BUFSIZE, reuse_port=reuse_port, loop=loop)
|
limit=get_to_tg_bufsize(), reuse_port=reuse_port, loop=loop)
|
||||||
server_v4 = loop.run_until_complete(task_v4)
|
server_v4 = loop.run_until_complete(task_v4)
|
||||||
|
|
||||||
if socket.has_ipv6:
|
if socket.has_ipv6:
|
||||||
task_v6 = asyncio.start_server(handle_client_wrapper, LISTEN_ADDR_IPV6, PORT,
|
task_v6 = asyncio.start_server(handle_client_wrapper, LISTEN_ADDR_IPV6, PORT,
|
||||||
limit=TO_TG_BUFSIZE, reuse_port=reuse_port, loop=loop)
|
limit=get_to_tg_bufsize(), reuse_port=reuse_port, loop=loop)
|
||||||
server_v6 = loop.run_until_complete(task_v6)
|
server_v6 = loop.run_until_complete(task_v6)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user