diff --git a/Dockerfile b/Dockerfile index 62be9c4..3622790 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,12 +2,12 @@ FROM alpine:3.6 RUN adduser tgproxy -u 10000 -D -RUN apk add --no-cache python3 py3-crypto ca-certificates +RUN apk add --no-cache python3 py3-cryptography ca-certificates libcap COPY mtprotoproxy.py config.py /home/tgproxy/ -COPY pyaes/*.py /home/tgproxy/pyaes/ RUN chown -R tgproxy:tgproxy /home/tgproxy +RUN setcap cap_net_bind_service=+ep /usr/bin/python3.6 USER tgproxy diff --git a/README.md b/README.md index 4cd1689..e6be07c 100644 --- a/README.md +++ b/README.md @@ -16,4 +16,11 @@ To advertise a channel get a tag from **@MTProxybot** and write it to *config.py ## Performance ## The proxy performance should be enough to comfortably serve about 4 000 simultaneous users on -the smallest VDS instance with 1 CPU core and 1024MB RAM. +the VDS instance with 1 CPU core and 1024MB RAM. + +## Advanced Usage ## + +The proxy can be launched: +- with a custom config: `python3 mtprotoproxy.py [configfile]` +- several times, clients will be automaticaly balanced between instances +- using *PyPy* interprteter diff --git a/docker-compose.yml b/docker-compose.yml index f460e7a..0073e8e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -3,5 +3,5 @@ services: mtprotoproxy: build: . restart: unless-stopped - mem_limit: 1024m network_mode: "host" +# mem_limit: 1024m diff --git a/mtprotoproxy.py b/mtprotoproxy.py index f404666..ee5af3e 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -12,9 +12,54 @@ import binascii import sys import re import runpy - +import signal try: + import uvloop + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +except ImportError: + pass + + +def try_use_cryptography_module(): + from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes + from cryptography.hazmat.backends import default_backend + + def create_aes_ctr(key, iv): + class EncryptorAdapter: + def __init__(self, cipher): + self.encryptor = cipher.encryptor() + self.decryptor = cipher.decryptor() + + def encrypt(self, data): + return self.encryptor.update(data) + + def decrypt(self, data): + return self.decryptor.update(data) + + iv_bytes = int.to_bytes(iv, 16, "big") + cipher = Cipher(algorithms.AES(key), modes.CTR(iv_bytes), default_backend()) + return EncryptorAdapter(cipher) + + def create_aes_cbc(key, iv): + class EncryptorAdapter: + def __init__(self, cipher): + self.encryptor = cipher.encryptor() + self.decryptor = cipher.decryptor() + + def encrypt(self, data): + return self.encryptor.update(data) + + def decrypt(self, data): + return self.decryptor.update(data) + + cipher = Cipher(algorithms.AES(key), modes.CBC(iv), default_backend()) + return EncryptorAdapter(cipher) + + return create_aes_ctr, create_aes_cbc + + +def try_use_pycrypto_or_pycryptodome_module(): from Crypto.Cipher import AES from Crypto.Util import Counter @@ -25,11 +70,16 @@ try: def create_aes_cbc(key, iv): return AES.new(key, AES.MODE_CBC, iv) -except ImportError: - print("Failed to find pycryptodome or pycrypto, using slow AES implementation", - flush=True, file=sys.stderr) + return create_aes_ctr, create_aes_cbc + + +def use_slow_bundled_cryptography_module(): import pyaes + msg = "To make the program a *lot* faster, please install cryptography module: " + msg += "pip install cryptography\n" + print(msg, flush=True, file=sys.stderr) + def create_aes_ctr(key, iv): ctr = pyaes.Counter(iv) return pyaes.AESModeOfOperationCTR(key, ctr) @@ -49,8 +99,17 @@ except ImportError: mode = pyaes.AESModeOfOperationCBC(key, iv) return EncryptorAdapter(mode) + return create_aes_ctr, create_aes_cbc +try: + create_aes_ctr, create_aes_cbc = try_use_cryptography_module() +except ImportError: + try: + create_aes_ctr, create_aes_cbc = try_use_pycrypto_or_pycryptodome_module() + except ImportError: + create_aes_ctr, create_aes_cbc = use_slow_bundled_cryptography_module() + try: import resource soft_fd_limit, hard_fd_limit = resource.getrlimit(resource.RLIMIT_NOFILE) @@ -60,13 +119,29 @@ except (ValueError, OSError): except ImportError: pass -if len(sys.argv) > 1: +if hasattr(signal, 'SIGUSR1'): + def debug_signal(signum, frame): + import pdb + pdb.set_trace() + + signal.signal(signal.SIGUSR1, debug_signal) + +if len(sys.argv) < 2: + config = runpy.run_module("config") +elif len(sys.argv) == 2: config = runpy.run_path(sys.argv[1]) else: - config = runpy.run_module("config") + # undocumented way of launching + config = {} + config["PORT"] = int(sys.argv[1]) + secrets = sys.argv[2].split(",") + config["USERS"] = {"user%d" % i: secrets[i].zfill(32) for i in range(len(secrets))} + if len(sys.argv) > 3: + config["AD_TAG"] = sys.argv[3] PORT = config["PORT"] USERS = config["USERS"] +AD_TAG = bytes.fromhex(config.get("AD_TAG", "")) # load advanced settings PREFER_IPV6 = config.get("PREFER_IPV6", socket.has_ipv6) @@ -77,7 +152,7 @@ PROXY_INFO_UPDATE_PERIOD = config.get("PROXY_INFO_UPDATE_PERIOD", 60*60*24) READ_BUF_SIZE = config.get("READ_BUF_SIZE", 16384) WRITE_BUF_SIZE = config.get("WRITE_BUF_SIZE", 65536) CLIENT_KEEPALIVE = config.get("CLIENT_KEEPALIVE", 60*30) -AD_TAG = bytes.fromhex(config.get("AD_TAG", "")) +CLIENT_HANDSHAKE_TIMEOUT = config.get("CLIENT_HANDSHAKE_TIMEOUT", 10) TG_DATACENTER_PORT = 443 @@ -147,13 +222,13 @@ def init_stats(): stats = {user: collections.Counter() for user in USERS} -def update_stats(user, connects=0, curr_connects_x2=0, octets=0): +def update_stats(user, connects=0, curr_connects=0, octets=0): global stats if user not in stats: stats[user] = collections.Counter() - stats[user].update(connects=connects, curr_connects_x2=curr_connects_x2, + stats[user].update(connects=connects, curr_connects=curr_connects, octets=octets) @@ -481,6 +556,11 @@ async def handle_handshake(reader, writer): reader = CryptoWrappedStreamReader(reader, decryptor) writer = CryptoWrappedStreamWriter(writer, encryptor) return reader, writer, proto_tag, user, dc_idx, enc_key + enc_iv + + while await reader.read(READ_BUF_SIZE): + # just consume all the data + pass + return False @@ -723,9 +803,13 @@ async def handle_client(reader_clt, writer_clt): set_keepalive(writer_clt.get_extra_info("socket"), CLIENT_KEEPALIVE) set_bufsizes(writer_clt.get_extra_info("socket")) - clt_data = await handle_handshake(reader_clt, writer_clt) + try: + clt_data = await asyncio.wait_for(handle_handshake(reader_clt, writer_clt), + timeout=CLIENT_HANDSHAKE_TIMEOUT) + except asyncio.TimeoutError: + return + if not clt_data: - writer_clt.transport.abort() return reader_clt, writer_clt, proto_tag, user, dc_idx, enc_key_and_iv = clt_data @@ -742,7 +826,6 @@ async def handle_client(reader_clt, writer_clt): tg_data = await do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port) if not tg_data: - writer_clt.transport.abort() return reader_tg, writer_tg = tg_data @@ -770,7 +853,6 @@ async def handle_client(reader_clt, writer_clt): return async def connect_reader_to_writer(rd, wr, user): - update_stats(user, curr_connects_x2=1) try: while True: data = await rd.read(READ_BUF_SIZE) @@ -782,27 +864,34 @@ async def handle_client(reader_clt, writer_clt): if not data: wr.write_eof() await wr.drain() - wr.close() return else: update_stats(user, octets=len(data)) wr.write(data, extra) await wr.drain() - except (OSError, AttributeError, asyncio.streams.IncompleteReadError) as e: + except (OSError, asyncio.streams.IncompleteReadError) as e: # print_err(e) pass - finally: - wr.transport.abort() - update_stats(user, curr_connects_x2=-1) - asyncio.ensure_future(connect_reader_to_writer(reader_tg, writer_clt, user)) - asyncio.ensure_future(connect_reader_to_writer(reader_clt, writer_tg, user)) + task_tg_to_clt = asyncio.ensure_future(connect_reader_to_writer(reader_tg, writer_clt, user)) + task_clt_to_tg = asyncio.ensure_future(connect_reader_to_writer(reader_clt, writer_tg, user)) + + update_stats(user, curr_connects=1) + await asyncio.wait([task_tg_to_clt, task_clt_to_tg], return_when=asyncio.FIRST_COMPLETED) + update_stats(user, curr_connects=-1) + + task_tg_to_clt.cancel() + task_clt_to_tg.cancel() + + writer_tg.transport.abort() async def handle_client_wrapper(reader, writer): try: await handle_client(reader, writer) except (asyncio.IncompleteReadError, ConnectionResetError, TimeoutError): + pass + finally: writer.transport.abort() @@ -814,7 +903,7 @@ async def stats_printer(): print("Stats for", time.strftime("%d.%m.%Y %H:%M:%S")) for user, stat in stats.items(): print("%s: %d connects (%d current), %.2f MB" % ( - user, stat["connects"], stat["curr_connects_x2"] // 2, + user, stat["connects"], stat["curr_connects"], stat["octets"] / 1000000)) print(flush=True) @@ -966,7 +1055,7 @@ def loop_exception_handler(loop, context): def main(): init_stats() - if sys.platform == 'win32': + if sys.platform == "win32": loop = asyncio.ProactorEventLoop() asyncio.set_event_loop(loop) @@ -980,13 +1069,15 @@ def main(): middle_proxy_updater_task = asyncio.Task(update_middle_proxy_info()) asyncio.ensure_future(middle_proxy_updater_task) - task_v4 = asyncio.start_server(handle_client_wrapper, - '0.0.0.0', PORT, limit=READ_BUF_SIZE, loop=loop) + reuse_port = hasattr(socket, "SO_REUSEPORT") + + task_v4 = asyncio.start_server(handle_client_wrapper, '0.0.0.0', PORT, + limit=READ_BUF_SIZE, reuse_port=reuse_port, loop=loop) server_v4 = loop.run_until_complete(task_v4) if socket.has_ipv6: - task_v6 = asyncio.start_server(handle_client_wrapper, - '::', PORT, limit=READ_BUF_SIZE, loop=loop) + task_v6 = asyncio.start_server(handle_client_wrapper, '::', PORT, + limit=READ_BUF_SIZE, reuse_port=reuse_port, loop=loop) server_v6 = loop.run_until_complete(task_v6) try: