From 2bb0ef0b1fc3511d8ec0da784ae506a9eb24d81f Mon Sep 17 00:00:00 2001 From: Alexander Bersenev Date: Wed, 12 Feb 2020 15:41:05 +0500 Subject: [PATCH] simplify initialization and stats --- mtprotoproxy.py | 111 ++++++++++++++++++++++++++++-------------------- 1 file changed, 65 insertions(+), 46 deletions(-) diff --git a/mtprotoproxy.py b/mtprotoproxy.py index 22a6bde..90327d2 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -93,6 +93,9 @@ last_clients_with_same_handshake = collections.Counter() proxy_start_time = 0 proxy_links = [] +stats = collections.Counter() +user_stats = collections.defaultdict(collections.Counter) + config = {} @@ -380,12 +383,11 @@ def print_err(*params): print(*params, file=sys.stderr, flush=True) -def init_stats(): - global stats +def ensure_users_in_user_stats(): global user_stats - stats = collections.Counter() - user_stats = {user: collections.Counter() for user in config.USERS} + for user in config.USERS: + user_stats[user].update() def init_proxy_start_time(): @@ -400,9 +402,6 @@ def update_stats(**kw_stats): def update_user_stats(user, **kw_stats): global user_stats - - if user not in user_stats: - user_stats[user] = collections.Counter() user_stats[user].update(**kw_stats) @@ -2137,6 +2136,7 @@ def setup_signals(): if hasattr(signal, 'SIGUSR2'): def reload_signal(signum, frame): init_config() + ensure_users_in_user_stats() apply_upstream_proxy_settings() print("Config reloaded", flush=True, file=sys.stderr) print_tg_info() @@ -2192,41 +2192,11 @@ def loop_exception_handler(loop, context): loop.default_exception_handler(context) -def main(): - setup_files_limit() - setup_signals() - try_setup_uvloop() - - init_stats() - init_proxy_start_time() - - if sys.platform == "win32": - loop = asyncio.ProactorEventLoop() - asyncio.set_event_loop(loop) - - loop = asyncio.get_event_loop() - loop.set_exception_handler(loop_exception_handler) - - stats_printer_task = asyncio.Task(stats_printer()) - asyncio.ensure_future(stats_printer_task) - - if config.USE_MIDDLE_PROXY: - middle_proxy_updater_task = asyncio.Task(update_middle_proxy_info()) - asyncio.ensure_future(middle_proxy_updater_task) - - if config.GET_TIME_PERIOD: - time_get_task = asyncio.Task(get_srv_time()) - asyncio.ensure_future(time_get_task) - - get_cert_len_task = asyncio.Task(get_mask_host_cert_len()) - asyncio.ensure_future(get_cert_len_task) - - clear_resolving_cache_task = asyncio.Task(clear_ip_resolving_cache()) - asyncio.ensure_future(clear_resolving_cache_task) +def create_servers(loop): + servers = [] reuse_port = hasattr(socket, "SO_REUSEPORT") has_unix = hasattr(socket, "AF_UNIX") - servers = [] if config.LISTEN_ADDR_IPV4: task = asyncio.start_server(handle_client_wrapper, config.LISTEN_ADDR_IPV4, config.PORT, @@ -2255,15 +2225,66 @@ def main(): config.METRICS_PORT) servers.append(loop.run_until_complete(task)) + return servers + + +def create_utilitary_tasks(loop): + tasks = [] + + stats_printer_task = asyncio.Task(stats_printer()) + tasks.append(stats_printer_task) + + if config.USE_MIDDLE_PROXY: + middle_proxy_updater_task = asyncio.Task(update_middle_proxy_info()) + tasks.append(middle_proxy_updater_task) + + if config.GET_TIME_PERIOD: + time_get_task = asyncio.Task(get_srv_time()) + tasks.append(time_get_task) + + get_cert_len_task = asyncio.Task(get_mask_host_cert_len()) + tasks.append(get_cert_len_task) + + clear_resolving_cache_task = asyncio.Task(clear_ip_resolving_cache()) + tasks.append(clear_resolving_cache_task) + + return tasks + + +def main(): + init_config() + ensure_users_in_user_stats() + apply_upstream_proxy_settings() + init_ip_info() + print_tg_info() + + setup_files_limit() + setup_signals() + try_setup_uvloop() + + init_proxy_start_time() + + if sys.platform == "win32": + loop = asyncio.ProactorEventLoop() + asyncio.set_event_loop(loop) + + loop = asyncio.get_event_loop() + loop.set_exception_handler(loop_exception_handler) + + for task in create_utilitary_tasks(loop): + asyncio.ensure_future(task) + + servers = create_servers(loop) + try: loop.run_forever() except KeyboardInterrupt: pass - try: + if hasattr(asyncio, "all_tasks"): tasks = asyncio.all_tasks(loop) - except AttributeError: - # for compatibility with python 3.6 + else: + # for compatibility with Python 3.6 tasks = asyncio.Task.all_tasks(loop) for task in tasks: @@ -2273,6 +2294,8 @@ def main(): server.close() loop.run_until_complete(server.wait_closed()) + has_unix = hasattr(socket, "AF_UNIX") + if config.LISTEN_UNIX_SOCK and has_unix: remove_unix_socket(config.LISTEN_UNIX_SOCK) @@ -2280,8 +2303,4 @@ def main(): if __name__ == "__main__": - init_config() - apply_upstream_proxy_settings() - init_ip_info() - print_tg_info() main()