diff --git a/README.md b/README.md index e6be07c..b2fed90 100644 --- a/README.md +++ b/README.md @@ -24,3 +24,4 @@ The proxy can be launched: - with a custom config: `python3 mtprotoproxy.py [configfile]` - several times, clients will be automaticaly balanced between instances - using *PyPy* interprteter +- with runtime statistics exported for [Prometheus](https://prometheus.io/): using [prometheus](https://github.com/alexbers/mtprotoproxy/tree/prometheus) branch diff --git a/mtprotoproxy.py b/mtprotoproxy.py index ee5af3e..80716b1 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -148,11 +148,12 @@ PREFER_IPV6 = config.get("PREFER_IPV6", socket.has_ipv6) # disables tg->client trafic reencryption, faster but less secure FAST_MODE = config.get("FAST_MODE", True) STATS_PRINT_PERIOD = config.get("STATS_PRINT_PERIOD", 600) -PROXY_INFO_UPDATE_PERIOD = config.get("PROXY_INFO_UPDATE_PERIOD", 60*60*24) -READ_BUF_SIZE = config.get("READ_BUF_SIZE", 16384) -WRITE_BUF_SIZE = config.get("WRITE_BUF_SIZE", 65536) -CLIENT_KEEPALIVE = config.get("CLIENT_KEEPALIVE", 60*30) +PROXY_INFO_UPDATE_PERIOD = config.get("PROXY_INFO_UPDATE_PERIOD", 24*60*60) +TO_CLT_BUFSIZE = config.get("TO_CLT_BUFSIZE", 8192) +TO_TG_BUFSIZE = config.get("TO_TG_BUFSIZE", 65536) +CLIENT_KEEPALIVE = config.get("CLIENT_KEEPALIVE", 10*60) CLIENT_HANDSHAKE_TIMEOUT = config.get("CLIENT_HANDSHAKE_TIMEOUT", 10) +CLIENT_ACK_TIMEOUT = config.get("CLIENT_ACK_TIMEOUT", 5*60) TG_DATACENTER_PORT = 443 @@ -203,6 +204,7 @@ DC_IDX_POS = 60 PROTO_TAG_ABRIDGED = b"\xef\xef\xef\xef" PROTO_TAG_INTERMEDIATE = b"\xee\xee\xee\xee" +PROTO_TAG_SECURE = b"\xdd\xdd\xdd\xdd" CBC_PADDING = 16 PADDING_FILLER = b"\x04\x00\x00\x00" @@ -222,14 +224,14 @@ def init_stats(): stats = {user: collections.Counter() for user in USERS} -def update_stats(user, connects=0, curr_connects=0, octets=0): +def update_stats(user, connects=0, curr_connects=0, octets=0, msgs=0): global stats if user not in stats: stats[user] = collections.Counter() stats[user].update(connects=connects, curr_connects=curr_connects, - octets=octets) + octets=octets, msgs=msgs) class LayeredStreamReaderBase: @@ -429,6 +431,10 @@ class MTProtoIntermediateFrameStreamReader(LayeredStreamReaderBase): data = await self.upstream.readexactly(msg_len) + if msg_len % 4 != 0: + cut_border = msg_len - (msg_len % 4) + data = data[:cut_border] + return data, extra @@ -548,7 +554,7 @@ async def handle_handshake(reader, writer): decrypted = decryptor.decrypt(handshake) proto_tag = decrypted[PROTO_TAG_POS:PROTO_TAG_POS+4] - if proto_tag not in (PROTO_TAG_ABRIDGED, PROTO_TAG_INTERMEDIATE): + if proto_tag not in (PROTO_TAG_ABRIDGED, PROTO_TAG_INTERMEDIATE, PROTO_TAG_SECURE): continue dc_idx = int.from_bytes(decrypted[DC_IDX_POS:DC_IDX_POS+2], "little", signed=True) @@ -557,13 +563,34 @@ async def handle_handshake(reader, writer): writer = CryptoWrappedStreamWriter(writer, encryptor) return reader, writer, proto_tag, user, dc_idx, enc_key + enc_iv - while await reader.read(READ_BUF_SIZE): + EMPTY_READ_BUF_SIZE = 4096 + while await reader.read(EMPTY_READ_BUF_SIZE): # just consume all the data pass return False +def set_keepalive(sock, interval=40, attempts=5): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + if hasattr(socket, "TCP_KEEPIDLE"): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, interval) + if hasattr(socket, "TCP_KEEPINTVL"): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval) + if hasattr(socket, "TCP_KEEPCNT"): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, attempts) + + +def set_ack_timeout(sock, timeout): + if hasattr(socket, "TCP_USER_TIMEOUT"): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_USER_TIMEOUT, timeout*1000) + + +def set_bufsizes(sock, recv_buf, send_buf): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, recv_buf) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, send_buf) + + async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): RESERVED_NONCE_FIRST_CHARS = [b"\xef"] RESERVED_NONCE_BEGININGS = [b"\x48\x45\x41\x44", b"\x50\x4F\x53\x54", @@ -583,7 +610,10 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): try: reader_tgt, writer_tgt = await asyncio.open_connection(dc, TG_DATACENTER_PORT, - limit=READ_BUF_SIZE) + limit=TO_CLT_BUFSIZE) + set_keepalive(writer_tgt.get_extra_info("socket")) + set_bufsizes(writer_tgt.get_extra_info("socket"), TO_CLT_BUFSIZE, TO_TG_BUFSIZE) + except ConnectionRefusedError as E: print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT) return False @@ -653,21 +683,6 @@ def get_middleproxy_aes_key_and_iv(nonce_srv, nonce_clt, clt_ts, srv_ip, clt_por return key, iv -def set_keepalive(sock, interval=40, attempts=5): - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - if hasattr(socket, "TCP_KEEPIDLE"): - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, interval) - if hasattr(socket, "TCP_KEEPINTVL"): - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval) - if hasattr(socket, "TCP_KEEPCNT"): - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, attempts) - - -def set_bufsizes(sock, recv_buf=READ_BUF_SIZE, send_buf=WRITE_BUF_SIZE): - sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, recv_buf) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, send_buf) - - async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): START_SEQ_NO = -2 NONCE_LEN = 16 @@ -695,9 +710,9 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): addr, port = random.choice(TG_MIDDLE_PROXIES_V4[dc_idx]) try: - reader_tgt, writer_tgt = await asyncio.open_connection(addr, port, limit=READ_BUF_SIZE) + reader_tgt, writer_tgt = await asyncio.open_connection(addr, port, limit=TO_CLT_BUFSIZE) set_keepalive(writer_tgt.get_extra_info("socket")) - set_bufsizes(writer_tgt.get_extra_info("socket")) + set_bufsizes(writer_tgt.get_extra_info("socket"), TO_CLT_BUFSIZE, TO_TG_BUFSIZE) except ConnectionRefusedError as E: print_err("Got connection refused while trying to connect to", addr, port) return False @@ -719,7 +734,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): old_reader = reader_tgt reader_tgt = MTProtoFrameStreamReader(reader_tgt, START_SEQ_NO) - ans = await reader_tgt.read(READ_BUF_SIZE) + ans = await reader_tgt.read(TO_CLT_BUFSIZE) if len(ans) != RPC_NONCE_ANS_LEN: return False @@ -800,8 +815,9 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): async def handle_client(reader_clt, writer_clt): - set_keepalive(writer_clt.get_extra_info("socket"), CLIENT_KEEPALIVE) - set_bufsizes(writer_clt.get_extra_info("socket")) + set_keepalive(writer_clt.get_extra_info("socket"), CLIENT_KEEPALIVE, attempts=3) + set_ack_timeout(writer_clt.get_extra_info("socket"), CLIENT_ACK_TIMEOUT) + set_bufsizes(writer_clt.get_extra_info("socket"), TO_TG_BUFSIZE, TO_CLT_BUFSIZE) try: clt_data = await asyncio.wait_for(handle_handshake(reader_clt, writer_clt), @@ -846,16 +862,16 @@ async def handle_client(reader_clt, writer_clt): if proto_tag == PROTO_TAG_ABRIDGED: reader_clt = MTProtoCompactFrameStreamReader(reader_clt) writer_clt = MTProtoCompactFrameStreamWriter(writer_clt) - elif proto_tag == PROTO_TAG_INTERMEDIATE: + elif proto_tag in (PROTO_TAG_INTERMEDIATE, PROTO_TAG_SECURE): reader_clt = MTProtoIntermediateFrameStreamReader(reader_clt) writer_clt = MTProtoIntermediateFrameStreamWriter(writer_clt) else: return - async def connect_reader_to_writer(rd, wr, user): + async def connect_reader_to_writer(rd, wr, user, rd_buf_size): try: while True: - data = await rd.read(READ_BUF_SIZE) + data = await rd.read(rd_buf_size) if isinstance(data, tuple): data, extra = data else: @@ -866,15 +882,17 @@ async def handle_client(reader_clt, writer_clt): await wr.drain() return else: - update_stats(user, octets=len(data)) + update_stats(user, octets=len(data), msgs=1) wr.write(data, extra) await wr.drain() except (OSError, asyncio.streams.IncompleteReadError) as e: # print_err(e) pass - task_tg_to_clt = asyncio.ensure_future(connect_reader_to_writer(reader_tg, writer_clt, user)) - task_clt_to_tg = asyncio.ensure_future(connect_reader_to_writer(reader_clt, writer_tg, user)) + tg_to_clt = connect_reader_to_writer(reader_tg, writer_clt, user, TO_CLT_BUFSIZE) + clt_to_tg = connect_reader_to_writer(reader_clt, writer_tg, user, TO_TG_BUFSIZE) + task_tg_to_clt = asyncio.ensure_future(tg_to_clt) + task_clt_to_tg = asyncio.ensure_future(clt_to_tg) update_stats(user, curr_connects=1) await asyncio.wait([task_tg_to_clt, task_clt_to_tg], return_when=asyncio.FIRST_COMPLETED) @@ -902,9 +920,9 @@ async def stats_printer(): print("Stats for", time.strftime("%d.%m.%Y %H:%M:%S")) for user, stat in stats.items(): - print("%s: %d connects (%d current), %.2f MB" % ( + print("%s: %d connects (%d current), %.2f MB, %d msgs" % ( user, stat["connects"], stat["curr_connects"], - stat["octets"] / 1000000)) + stat["octets"] / 1000000, stat["msgs"])) print(flush=True) @@ -1032,6 +1050,10 @@ def print_tg_info(): params_encodeded = urllib.parse.urlencode(params, safe=':') print("{}: tg://proxy?{}".format(user, params_encodeded), flush=True) + params = {"server": ip, "port": PORT, "secret": "dd" + secret} + params_encodeded = urllib.parse.urlencode(params, safe=':') + print("{}: tg://proxy?{} (beta)".format(user, params_encodeded), flush=True) + def loop_exception_handler(loop, context): exception = context.get("exception") @@ -1039,15 +1061,24 @@ def loop_exception_handler(loop, context): if exception: if isinstance(exception, TimeoutError): if transport: - print_err("Timeout, killing transport") transport.abort() return if isinstance(exception, OSError): IGNORE_ERRNO = { - 10038 # operation on non-socket on Windows, likely because fd == -1 + 10038, # operation on non-socket on Windows, likely because fd == -1 + 121, # the semaphore timeout period has expired on Windows + } + + FORCE_CLOSE_ERRNO = { + 113, # no route to host + } if exception.errno in IGNORE_ERRNO: return + elif exception.errno in FORCE_CLOSE_ERRNO: + if transport: + transport.abort() + return loop.default_exception_handler(context) @@ -1072,12 +1103,12 @@ def main(): reuse_port = hasattr(socket, "SO_REUSEPORT") task_v4 = asyncio.start_server(handle_client_wrapper, '0.0.0.0', PORT, - limit=READ_BUF_SIZE, reuse_port=reuse_port, loop=loop) + limit=TO_TG_BUFSIZE, reuse_port=reuse_port, loop=loop) server_v4 = loop.run_until_complete(task_v4) if socket.has_ipv6: task_v6 = asyncio.start_server(handle_client_wrapper, '::', PORT, - limit=READ_BUF_SIZE, reuse_port=reuse_port, loop=loop) + limit=TO_TG_BUFSIZE, reuse_port=reuse_port, loop=loop) server_v6 = loop.run_until_complete(task_v6) try: