diff --git a/mtprotoproxy.py b/mtprotoproxy.py index 2d177a1..741591a 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -442,42 +442,49 @@ class TgConnectionPool: def __init__(self): self.pools = {} - async def open_tg_connection(self, host, port): + async def open_tg_connection(self, host, port, init_func=None): task = asyncio.open_connection(host, port, limit=get_to_clt_bufsize()) reader_tgt, writer_tgt = await asyncio.wait_for(task, timeout=config.TG_CONNECT_TIMEOUT) + + set_keepalive(writer_tgt.get_extra_info("socket")) + set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize()) + + if init_func: + return await asyncio.wait_for(init_func(host, port, reader_tgt, writer_tgt), + timeout=config.TG_CONNECT_TIMEOUT) return reader_tgt, writer_tgt - def register_host_port(self, host, port): - if (host, port) not in self.pools: - self.pools[(host, port)] = [] + def register_host_port(self, host, port, init_func): + if (host, port, init_func) not in self.pools: + self.pools[(host, port, init_func)] = [] - while len(self.pools[(host, port)]) < TgConnectionPool.MAX_CONNS_IN_POOL: - connect_task = asyncio.ensure_future(self.open_tg_connection(host, port)) - self.pools[(host, port)].append(connect_task) + while len(self.pools[(host, port, init_func)]) < TgConnectionPool.MAX_CONNS_IN_POOL: + connect_task = asyncio.ensure_future(self.open_tg_connection(host, port, init_func)) + self.pools[(host, port, init_func)].append(connect_task) - async def get_connection(self, host, port): - self.register_host_port(host, port) + async def get_connection(self, host, port, init_func=None): + self.register_host_port(host, port, init_func) ret = None - for task in self.pools[(host, port)][::]: + for task in self.pools[(host, port, init_func)][::]: if task.done(): if task.exception(): - self.pools[(host, port)].remove(task) + self.pools[(host, port, init_func)].remove(task) continue - reader, writer = task.result() + reader, writer, *other = task.result() if writer.transport.is_closing(): - self.pools[(host, port)].remove(task) + self.pools[(host, port, init_func)].remove(task) continue if not ret: - self.pools[(host, port)].remove(task) - ret = (reader, writer) + self.pools[(host, port, init_func)].remove(task) + ret = (reader, writer, *other) - self.register_host_port(host, port) + self.register_host_port(host, port, init_func) if ret: return ret - return await self.open_tg_connection(host, port) + return await self.open_tg_connection(host, port, init_func) tg_connection_pool = TgConnectionPool() @@ -1269,13 +1276,13 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): except ConnectionRefusedError as E: print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT) return False + except ConnectionAbortedError as E: + print_err("The Telegram server connection is bad: %d (%s %s) %s" % (dc_idx, addr, port, E)) + return False except (OSError, asyncio.TimeoutError) as E: print_err("Unable to connect to", dc, TG_DATACENTER_PORT) return False - set_keepalive(writer_tgt.get_extra_info("socket")) - set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize()) - while True: rnd = bytearray(myrandom.getrandbytes(HANDSHAKE_LEN)) if rnd[:1] in RESERVED_NONCE_FIRST_CHARS: @@ -1338,20 +1345,50 @@ def get_middleproxy_aes_key_and_iv(nonce_srv, nonce_clt, clt_ts, srv_ip, clt_por return key, iv -async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): +async def middleproxy_handshake_after_connect(host, port, reader_tgt, writer_tgt): + """ The first stage of middleproxy handshake """ START_SEQ_NO = -2 NONCE_LEN = 16 RPC_NONCE = b"\xaa\x87\xcb\x7a" - RPC_HANDSHAKE = b"\xf5\xee\x82\x76" CRYPTO_AES = b"\x01\x00\x00\x00" RPC_NONCE_ANS_LEN = 32 - RPC_HANDSHAKE_ANS_LEN = 32 + writer_tgt = MTProtoFrameStreamWriter(writer_tgt, START_SEQ_NO) + key_selector = PROXY_SECRET[:4] + crypto_ts = int.to_bytes(int(time.time()) % (256**4), 4, "little") + + nonce = myrandom.getrandbytes(NONCE_LEN) + + msg = RPC_NONCE + key_selector + CRYPTO_AES + crypto_ts + nonce + + writer_tgt.write(msg) + await writer_tgt.drain() + + reader_tgt = MTProtoFrameStreamReader(reader_tgt, START_SEQ_NO) + ans = await reader_tgt.read(get_to_clt_bufsize()) + + if len(ans) != RPC_NONCE_ANS_LEN: + raise ConnectionAbortedError("bad rpc answer length") + + rpc_type, rpc_key_selector, rpc_schema, rpc_crypto_ts, rpc_nonce = ( + ans[:4], ans[4:8], ans[8:12], ans[12:16], ans[16:32] + ) + + if rpc_type != RPC_NONCE or rpc_key_selector != key_selector or rpc_schema != CRYPTO_AES: + raise ConnectionAbortedError("bad rpc answer") + + return reader_tgt, writer_tgt, nonce, rpc_nonce, crypto_ts + + +async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): + RPC_HANDSHAKE = b"\xf5\xee\x82\x76" # pass as consts to simplify code RPC_FLAGS = b"\x00\x00\x00\x00" + RPC_HANDSHAKE_ANS_LEN = 32 + global my_ip_info global tg_connection_pool @@ -1368,43 +1405,19 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): addr, port = myrandom.choice(TG_MIDDLE_PROXIES_V4[dc_idx]) try: - reader_tgt, writer_tgt = await tg_connection_pool.get_connection(addr, port) + ret = await tg_connection_pool.get_connection(addr, port, + middleproxy_handshake_after_connect) + reader_tgt, writer_tgt, nonce, rpc_nonce, crypto_ts = ret except ConnectionRefusedError as E: print_err("The Telegram server %d (%s %s) is refusing connections" % (dc_idx, addr, port)) return False + except ConnectionAbortedError as E: + print_err("The Telegram server connection is bad: %d (%s %s) %s" % (dc_idx, addr, port, E)) + return False except (OSError, asyncio.TimeoutError) as E: print_err("Unable to connect to the Telegram server %d (%s %s)" % (dc_idx, addr, port)) return False - set_keepalive(writer_tgt.get_extra_info("socket")) - set_bufsizes(writer_tgt.get_extra_info("socket"), get_to_clt_bufsize(), get_to_tg_bufsize()) - - writer_tgt = MTProtoFrameStreamWriter(writer_tgt, START_SEQ_NO) - - key_selector = PROXY_SECRET[:4] - crypto_ts = int.to_bytes(int(time.time()) % (256**4), 4, "little") - - nonce = myrandom.getrandbytes(NONCE_LEN) - - msg = RPC_NONCE + key_selector + CRYPTO_AES + crypto_ts + nonce - - writer_tgt.write(msg) - await writer_tgt.drain() - - old_reader = reader_tgt - reader_tgt = MTProtoFrameStreamReader(reader_tgt, START_SEQ_NO) - ans = await reader_tgt.read(get_to_clt_bufsize()) - - if len(ans) != RPC_NONCE_ANS_LEN: - return False - - rpc_type, rpc_key_selector, rpc_schema, rpc_crypto_ts, rpc_nonce = ( - ans[:4], ans[4:8], ans[8:12], ans[12:16], ans[16:32] - ) - - if rpc_type != RPC_NONCE or rpc_key_selector != key_selector or rpc_schema != CRYPTO_AES: - return False - # get keys tg_ip, tg_port = writer_tgt.upstream.get_extra_info('peername')[:2] my_ip, my_port = writer_tgt.upstream.get_extra_info('sockname')[:2]