diff --git a/mtprotoproxy.py b/mtprotoproxy.py index 6236191..4569642 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -121,6 +121,9 @@ def init_config(): # allows to connect in tls mode only conf_dict.setdefault("TLS_ONLY", False) + # support proxy protocol v1/v2 for incoming connections + conf_dict.setdefault("PROXY_PROTOCOL", False) + # set the tls domain for the proxy, has an influence only on starting message conf_dict.setdefault("TLS_DOMAIN", "google.com") @@ -789,7 +792,7 @@ async def handle_bad_client(reader_clt, writer_clt, handshake): set_instant_rst(writer_clt.get_extra_info("socket")) set_bufsizes(writer_clt.get_extra_info("socket"), BUF_SIZE, BUF_SIZE) - if not config.MASK: + if not config.MASK or handshake is None: while await reader_clt.read(BUF_SIZE): # just consume all the data pass @@ -834,7 +837,7 @@ async def handle_bad_client(reader_clt, writer_clt, handshake): writer_srv.transport.abort() -async def handle_pseudo_tls_handshake(handshake, reader, writer): +async def handle_pseudo_tls_handshake(handshake, reader, writer, peer): global used_handshakes TLS_VERS = b"\x03\x03" @@ -853,8 +856,7 @@ async def handle_pseudo_tls_handshake(handshake, reader, writer): digest = handshake[DIGEST_POS: DIGEST_POS + DIGEST_LEN] if digest in used_handshakes: - ip = writer.get_extra_info('peername')[0] - print_err("Active TLS fingerprinting detected from %s, handling it" % ip) + print_err("Active TLS fingerprinting detected from %s, handling it" % peer[0]) return False sess_id_len = handshake[SESSION_ID_LEN_POS] @@ -898,17 +900,80 @@ async def handle_pseudo_tls_handshake(handshake, reader, writer): return False +async def handle_proxy_protocol(reader, peer=None): + PROXY_SIGNATURE = b"PROXY " + PROXY_MIN_LEN = 6 + PROXY_TCP4 = b"TCP4" + PROXY_TCP6 = b"TCP6" + PROXY_UNKNOWN = b"UNKNOWN" + + PROXY2_SIGNATURE = b"\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a" + PROXY2_MIN_LEN = 16 + PROXY2_AF_UNSPEC = 0x0 + PROXY2_AF_INET = 0x1 + PROXY2_AF_INET6 = 0x2 + + header = await reader.readexactly(PROXY_MIN_LEN) + if header.startswith(PROXY_SIGNATURE): + # proxy header v1 + header += await reader.readuntil(b"\r\n") + _, proxy_fam, *proxy_addr = header[:-2].split(b" ") + if proxy_fam in (PROXY_TCP4, PROXY_TCP6): + if len(proxy_addr) == 4: + src_addr = proxy_addr[0].decode('ascii') + src_port = proxy_addr[2].decode('ascii') + return (src_addr, src_port) + elif proxy_fam == PROXY_UNKNOWN: + return peer + return False + + header += await reader.readexactly(PROXY2_MIN_LEN - PROXY_MIN_LEN) + if header.startswith(PROXY2_SIGNATURE): + # proxy header v2 + proxy_ver = header[12] + if proxy_ver & 0xf0 != 0x20: + return False + proxy_len = int.from_bytes(header[14:16], "big") + proxy_addr = await reader.readexactly(proxy_len) + if proxy_ver == 0x21: + proxy_fam = header[13] >> 4 + if proxy_fam == PROXY2_AF_INET: + if proxy_len >= (4 + 2)*2: + src_addr = socket.inet_ntop(socket.AF_INET, proxy_addr[:4]) + src_port = int.from_bytes(proxy_addr[8:10], "big") + return (src_addr, src_port) + elif proxy_fam == PROXY2_AF_INET6: + if proxy_len >= (16 + 2)*2: + src_addr = socket.inet_ntop(socket.AF_INET6, proxy_addr[:16]) + src_port = int.from_bytes(proxy_addr[32:34], "big") + return (src_addr, src_port) + elif proxy_fam == PROXY2_AF_UNSPEC: + return peer + elif proxy_ver == 0x20: + return peer + + return False + + async def handle_handshake(reader, writer): global used_handshakes TLS_START_BYTES = b"\x16\x03\x01\x02\x00\x01\x00\x01\xfc\x03\x03" EMPTY_READ_BUF_SIZE = 4096 + peer = writer.get_extra_info('peername')[:2] + + if config.PROXY_PROTOCOL: + peer = await handle_proxy_protocol(reader, peer) + if not peer: + await handle_bad_client(reader, writer, None) + return False + handshake = await reader.readexactly(HANDSHAKE_LEN) if handshake.startswith(TLS_START_BYTES): handshake += await reader.readexactly(TLS_HANDSHAKE_LEN - HANDSHAKE_LEN) - tls_handshake_result = await handle_pseudo_tls_handshake(handshake, reader, writer) + tls_handshake_result = await handle_pseudo_tls_handshake(handshake, reader, writer, peer) if not tls_handshake_result: await handle_bad_client(reader, writer, handshake) @@ -926,8 +991,7 @@ async def handle_handshake(reader, writer): enc_prekey, enc_iv = enc_prekey_and_iv[:PREKEY_LEN], enc_prekey_and_iv[PREKEY_LEN:] if dec_prekey_and_iv in used_handshakes: - ip = writer.get_extra_info('peername')[0] - print_err("Active fingerprinting detected from %s, handling it" % ip) + print_err("Active fingerprinting detected from %s, handling it" % peer[0]) await handle_bad_client(reader, writer, handshake) return False @@ -957,7 +1021,7 @@ async def handle_handshake(reader, writer): reader = CryptoWrappedStreamReader(reader, decryptor) writer = CryptoWrappedStreamWriter(writer, encryptor) - return reader, writer, proto_tag, user, dc_idx, enc_key + enc_iv + return reader, writer, proto_tag, user, dc_idx, enc_key + enc_iv, peer await handle_bad_client(reader, writer, handshake) return False @@ -1212,7 +1276,6 @@ async def handle_client(reader_clt, writer_clt): set_ack_timeout(writer_clt.get_extra_info("socket"), config.CLIENT_ACK_TIMEOUT) set_bufsizes(writer_clt.get_extra_info("socket"), get_to_tg_bufsize(), get_to_clt_bufsize()) - cl_ip, cl_port = writer_clt.get_extra_info('peername')[:2] try: clt_data = await asyncio.wait_for(handle_handshake(reader_clt, writer_clt), timeout=config.CLIENT_HANDSHAKE_TIMEOUT) @@ -1222,7 +1285,8 @@ async def handle_client(reader_clt, writer_clt): if not clt_data: return - reader_clt, writer_clt, proto_tag, user, dc_idx, enc_key_and_iv = clt_data + reader_clt, writer_clt, proto_tag, user, dc_idx, enc_key_and_iv, peer = clt_data + cl_ip, cl_port = peer update_stats(user, connects=1)