diff --git a/mtprotoproxy.py b/mtprotoproxy.py index 2885ab3..b2aec89 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -16,6 +16,7 @@ import sys import re import runpy import signal +import os TG_DATACENTER_PORT = 443 @@ -194,6 +195,9 @@ def init_config(): # listen address for IPv6 conf_dict.setdefault("LISTEN_ADDR_IPV6", "::") + # listen unix socket + conf_dict.setdefault("LISTEN_UNIX_SOCK", "") + # allow access to config by attributes config = type("config", (dict,), conf_dict)(conf_dict) @@ -1745,6 +1749,15 @@ def try_setup_uvloop(): pass +def remove_unix_socket(path): + from stat import S_ISSOCK + try: + if S_ISSOCK(os.stat(path).st_mode): + os.unlink(path) + except (FileNotFoundError, NotADirectoryError): + pass + + def loop_exception_handler(loop, context): exception = context.get("exception") transport = context.get("transport") @@ -1802,16 +1815,25 @@ def main(): asyncio.ensure_future(get_cert_len_task) reuse_port = hasattr(socket, "SO_REUSEPORT") + has_unix = hasattr(socket, "AF_UNIX") + servers = [] if config.LISTEN_ADDR_IPV4: - task_v4 = asyncio.start_server(handle_client_wrapper, config.LISTEN_ADDR_IPV4, config.PORT, - limit=get_to_tg_bufsize(), reuse_port=reuse_port, loop=loop) - server_v4 = loop.run_until_complete(task_v4) + task = asyncio.start_server(handle_client_wrapper, config.LISTEN_ADDR_IPV4, config.PORT, + limit=get_to_tg_bufsize(), reuse_port=reuse_port, loop=loop) + servers.append(loop.run_until_complete(task)) if config.LISTEN_ADDR_IPV6 and socket.has_ipv6: - task_v6 = asyncio.start_server(handle_client_wrapper, config.LISTEN_ADDR_IPV6, config.PORT, - limit=get_to_tg_bufsize(), reuse_port=reuse_port, loop=loop) - server_v6 = loop.run_until_complete(task_v6) + task = asyncio.start_server(handle_client_wrapper, config.LISTEN_ADDR_IPV6, config.PORT, + limit=get_to_tg_bufsize(), reuse_port=reuse_port, loop=loop) + servers.append(loop.run_until_complete(task)) + + if config.LISTEN_UNIX_SOCK and has_unix: + remove_unix_socket(config.LISTEN_UNIX_SOCK) + task = asyncio.start_unix_server(handle_client_wrapper, config.LISTEN_UNIX_SOCK, + limit=get_to_tg_bufsize(), loop=loop) + servers.append(loop.run_until_complete(task)) + os.chmod(config.LISTEN_UNIX_SOCK, 0o666) try: loop.run_forever() @@ -1821,13 +1843,12 @@ def main(): for task in asyncio.Task.all_tasks(): task.cancel() - if config.LISTEN_ADDR_IPV4: - server_v4.close() - loop.run_until_complete(server_v4.wait_closed()) + for server in servers: + server.close() + loop.run_until_complete(server.wait_closed()) - if config.LISTEN_ADDR_IPV6 and socket.has_ipv6: - server_v6.close() - loop.run_until_complete(server_v6.wait_closed()) + if config.LISTEN_UNIX_SOCK and has_unix: + remove_unix_socket(config.LISTEN_UNIX_SOCK) loop.close()