From 121a8974de1ec3dd349a304ba5ba4bfaa1116acb Mon Sep 17 00:00:00 2001 From: Vladislav Grishenko Date: Sat, 17 Aug 2019 15:11:49 +0500 Subject: [PATCH] add unix socket support (#127) Config option LISTEN_UNIX_SOCK = "/path/to/socket.file" allows to listen on specified unix socket in additional to (or instead of) configured ip addresses. Listening on a socket can be useful for connection from local reverse proxy w/o wasting tcp ports and network subsystem resources just for inter-process communication. Default value is empty - socket not used. --- mtprotoproxy.py | 45 +++++++++++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 12 deletions(-) diff --git a/mtprotoproxy.py b/mtprotoproxy.py index 2885ab3..b2aec89 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -16,6 +16,7 @@ import sys import re import runpy import signal +import os TG_DATACENTER_PORT = 443 @@ -194,6 +195,9 @@ def init_config(): # listen address for IPv6 conf_dict.setdefault("LISTEN_ADDR_IPV6", "::") + # listen unix socket + conf_dict.setdefault("LISTEN_UNIX_SOCK", "") + # allow access to config by attributes config = type("config", (dict,), conf_dict)(conf_dict) @@ -1745,6 +1749,15 @@ def try_setup_uvloop(): pass +def remove_unix_socket(path): + from stat import S_ISSOCK + try: + if S_ISSOCK(os.stat(path).st_mode): + os.unlink(path) + except (FileNotFoundError, NotADirectoryError): + pass + + def loop_exception_handler(loop, context): exception = context.get("exception") transport = context.get("transport") @@ -1802,16 +1815,25 @@ def main(): asyncio.ensure_future(get_cert_len_task) reuse_port = hasattr(socket, "SO_REUSEPORT") + has_unix = hasattr(socket, "AF_UNIX") + servers = [] if config.LISTEN_ADDR_IPV4: - task_v4 = asyncio.start_server(handle_client_wrapper, config.LISTEN_ADDR_IPV4, config.PORT, - limit=get_to_tg_bufsize(), reuse_port=reuse_port, loop=loop) - server_v4 = loop.run_until_complete(task_v4) + task = asyncio.start_server(handle_client_wrapper, config.LISTEN_ADDR_IPV4, config.PORT, + limit=get_to_tg_bufsize(), reuse_port=reuse_port, loop=loop) + servers.append(loop.run_until_complete(task)) if config.LISTEN_ADDR_IPV6 and socket.has_ipv6: - task_v6 = asyncio.start_server(handle_client_wrapper, config.LISTEN_ADDR_IPV6, config.PORT, - limit=get_to_tg_bufsize(), reuse_port=reuse_port, loop=loop) - server_v6 = loop.run_until_complete(task_v6) + task = asyncio.start_server(handle_client_wrapper, config.LISTEN_ADDR_IPV6, config.PORT, + limit=get_to_tg_bufsize(), reuse_port=reuse_port, loop=loop) + servers.append(loop.run_until_complete(task)) + + if config.LISTEN_UNIX_SOCK and has_unix: + remove_unix_socket(config.LISTEN_UNIX_SOCK) + task = asyncio.start_unix_server(handle_client_wrapper, config.LISTEN_UNIX_SOCK, + limit=get_to_tg_bufsize(), loop=loop) + servers.append(loop.run_until_complete(task)) + os.chmod(config.LISTEN_UNIX_SOCK, 0o666) try: loop.run_forever() @@ -1821,13 +1843,12 @@ def main(): for task in asyncio.Task.all_tasks(): task.cancel() - if config.LISTEN_ADDR_IPV4: - server_v4.close() - loop.run_until_complete(server_v4.wait_closed()) + for server in servers: + server.close() + loop.run_until_complete(server.wait_closed()) - if config.LISTEN_ADDR_IPV6 and socket.has_ipv6: - server_v6.close() - loop.run_until_complete(server_v6.wait_closed()) + if config.LISTEN_UNIX_SOCK and has_unix: + remove_unix_socket(config.LISTEN_UNIX_SOCK) loop.close()