diff --git a/Dockerfile b/Dockerfile index 3622790..174bff1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM alpine:3.6 +FROM alpine:3.8 RUN adduser tgproxy -u 10000 -D diff --git a/mtprotoproxy.py b/mtprotoproxy.py index 80716b1..28cbc1d 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -149,11 +149,12 @@ PREFER_IPV6 = config.get("PREFER_IPV6", socket.has_ipv6) FAST_MODE = config.get("FAST_MODE", True) STATS_PRINT_PERIOD = config.get("STATS_PRINT_PERIOD", 600) PROXY_INFO_UPDATE_PERIOD = config.get("PROXY_INFO_UPDATE_PERIOD", 24*60*60) -TO_CLT_BUFSIZE = config.get("TO_CLT_BUFSIZE", 8192) +TO_CLT_BUFSIZE = config.get("TO_CLT_BUFSIZE", 16384) TO_TG_BUFSIZE = config.get("TO_TG_BUFSIZE", 65536) CLIENT_KEEPALIVE = config.get("CLIENT_KEEPALIVE", 10*60) CLIENT_HANDSHAKE_TIMEOUT = config.get("CLIENT_HANDSHAKE_TIMEOUT", 10) CLIENT_ACK_TIMEOUT = config.get("CLIENT_ACK_TIMEOUT", 5*60) +TG_CONNECT_TIMEOUT = config.get("TG_CONNECT_TIMEOUT", 10) TG_DATACENTER_PORT = 443 @@ -430,11 +431,6 @@ class MTProtoIntermediateFrameStreamReader(LayeredStreamReaderBase): msg_len -= 0x80000000 data = await self.upstream.readexactly(msg_len) - - if msg_len % 4 != 0: - cut_border = msg_len - (msg_len % 4) - data = data[:cut_border] - return data, extra @@ -446,6 +442,38 @@ class MTProtoIntermediateFrameStreamWriter(LayeredStreamWriterBase): return self.upstream.write(int.to_bytes(len(data), 4, 'little') + data) +class MTProtoSecureIntermediateFrameStreamReader(LayeredStreamReaderBase): + async def read(self, buf_size): + msg_len_bytes = await self.upstream.readexactly(4) + msg_len = int.from_bytes(msg_len_bytes, "little") + + extra = {} + if msg_len > 0x80000000: + extra["QUICKACK_FLAG"] = True + msg_len -= 0x80000000 + + data = await self.upstream.readexactly(msg_len) + + if msg_len % 4 != 0: + cut_border = msg_len - (msg_len % 4) + data = data[:cut_border] + + return data, extra + + +class MTProtoSecureIntermediateFrameStreamWriter(LayeredStreamWriterBase): + def write(self, data, extra={}): + MAX_PADDING_LEN = 4 + if extra.get("SIMPLE_ACK"): + # TODO: make this unpredictable + return self.upstream.write(data) + else: + padding_len = random.randrange(MAX_PADDING_LEN) + padding = bytearray([random.randrange(256) for i in range(padding_len)]) + padded_data_len_bytes = int.to_bytes(len(data) + padding_len, 4, 'little') + return self.upstream.write(padded_data_len_bytes + data + padding) + + class ProxyReqStreamReader(LayeredStreamReaderBase): async def read(self, msg): RPC_PROXY_ANS = b"\x0d\xda\x03\x44" @@ -504,6 +532,7 @@ class ProxyReqStreamWriter(LayeredStreamWriterBase): FLAG_HAS_AD_TAG = 0x8 FLAG_MAGIC = 0x1000 FLAG_EXTMODE2 = 0x20000 + FLAG_PAD = 0x8000000 FLAG_INTERMEDIATE = 0x20000000 FLAG_ABRIDGED = 0x40000000 FLAG_QUICKACK = 0x80000000 @@ -518,6 +547,8 @@ class ProxyReqStreamWriter(LayeredStreamWriterBase): flags |= FLAG_ABRIDGED elif self.proto_tag == PROTO_TAG_INTERMEDIATE: flags |= FLAG_INTERMEDIATE + elif self.proto_tag == PROTO_TAG_SECURE: + flags |= FLAG_INTERMEDIATE | FLAG_PAD if extra.get("QUICKACK_FLAG"): flags |= FLAG_QUICKACK @@ -591,6 +622,21 @@ def set_bufsizes(sock, recv_buf, send_buf): sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, send_buf) +async def open_connection_tryer(addr, port, limit, timeout, max_attempts=3): + for attempt in range(max_attempts-1): + try: + task = asyncio.open_connection(addr, port, limit=limit) + reader_tgt, writer_tgt = await asyncio.wait_for(task, timeout=timeout) + return reader_tgt, writer_tgt + except (OSError, asyncio.TimeoutError): + continue + + # the last attempt + task = asyncio.open_connection(addr, port, limit=limit) + reader_tgt, writer_tgt = await asyncio.wait_for(task, timeout=timeout) + return reader_tgt, writer_tgt + + async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): RESERVED_NONCE_FIRST_CHARS = [b"\xef"] RESERVED_NONCE_BEGININGS = [b"\x48\x45\x41\x44", b"\x50\x4F\x53\x54", @@ -609,18 +655,18 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): dc = TG_DATACENTERS_V4[dc_idx] try: - reader_tgt, writer_tgt = await asyncio.open_connection(dc, TG_DATACENTER_PORT, - limit=TO_CLT_BUFSIZE) - set_keepalive(writer_tgt.get_extra_info("socket")) - set_bufsizes(writer_tgt.get_extra_info("socket"), TO_CLT_BUFSIZE, TO_TG_BUFSIZE) - + reader_tgt, writer_tgt = await open_connection_tryer( + dc, TG_DATACENTER_PORT, limit=TO_CLT_BUFSIZE, timeout=TG_CONNECT_TIMEOUT) except ConnectionRefusedError as E: print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT) return False - except OSError as E: + except (OSError, asyncio.TimeoutError) as E: print_err("Unable to connect to", dc, TG_DATACENTER_PORT) return False + set_keepalive(writer_tgt.get_extra_info("socket")) + set_bufsizes(writer_tgt.get_extra_info("socket"), TO_CLT_BUFSIZE, TO_TG_BUFSIZE) + while True: rnd = bytearray([random.randrange(0, 256) for i in range(HANDSHAKE_LEN)]) if rnd[:1] in RESERVED_NONCE_FIRST_CHARS: @@ -710,16 +756,18 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): addr, port = random.choice(TG_MIDDLE_PROXIES_V4[dc_idx]) try: - reader_tgt, writer_tgt = await asyncio.open_connection(addr, port, limit=TO_CLT_BUFSIZE) - set_keepalive(writer_tgt.get_extra_info("socket")) - set_bufsizes(writer_tgt.get_extra_info("socket"), TO_CLT_BUFSIZE, TO_TG_BUFSIZE) + reader_tgt, writer_tgt = await open_connection_tryer(addr, port, limit=TO_CLT_BUFSIZE, + timeout=TG_CONNECT_TIMEOUT) except ConnectionRefusedError as E: print_err("Got connection refused while trying to connect to", addr, port) return False - except OSError as E: + except (OSError, asyncio.TimeoutError) as E: print_err("Unable to connect to", addr, port) return False + set_keepalive(writer_tgt.get_extra_info("socket")) + set_bufsizes(writer_tgt.get_extra_info("socket"), TO_CLT_BUFSIZE, TO_TG_BUFSIZE) + writer_tgt = MTProtoFrameStreamWriter(writer_tgt, START_SEQ_NO) key_selector = PROXY_SECRET[:4] @@ -862,9 +910,12 @@ async def handle_client(reader_clt, writer_clt): if proto_tag == PROTO_TAG_ABRIDGED: reader_clt = MTProtoCompactFrameStreamReader(reader_clt) writer_clt = MTProtoCompactFrameStreamWriter(writer_clt) - elif proto_tag in (PROTO_TAG_INTERMEDIATE, PROTO_TAG_SECURE): + elif proto_tag == PROTO_TAG_INTERMEDIATE: reader_clt = MTProtoIntermediateFrameStreamReader(reader_clt) writer_clt = MTProtoIntermediateFrameStreamWriter(writer_clt) + elif proto_tag == PROTO_TAG_SECURE: + reader_clt = MTProtoSecureIntermediateFrameStreamReader(reader_clt) + writer_clt = MTProtoSecureIntermediateFrameStreamWriter(writer_clt) else: return