diff --git a/mtprotoproxy.py b/mtprotoproxy.py index 8cee2d0..20662fd 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -15,82 +15,6 @@ import re import runpy import signal -if len(sys.argv) < 2: - config = runpy.run_module("config") -elif len(sys.argv) == 2: - # launch with own config - config = runpy.run_path(sys.argv[1]) -else: - # undocumented way of launching - config = {} - config["PORT"] = int(sys.argv[1]) - secrets = sys.argv[2].split(",") - config["USERS"] = {"user%d" % i: secrets[i].zfill(32) for i in range(len(secrets))} - if len(sys.argv) > 3: - config["AD_TAG"] = sys.argv[3] - -PORT = config["PORT"] -USERS = config["USERS"] -AD_TAG = bytes.fromhex(config.get("AD_TAG", "")) - -# load advanced settings - -# if IPv6 avaliable, use it by default -PREFER_IPV6 = config.get("PREFER_IPV6", socket.has_ipv6) - -# disables tg->client trafic reencryption, faster but less secure -FAST_MODE = config.get("FAST_MODE", True) - -# doesn't allow to connect in not-secure mode -SECURE_ONLY = config.get("SECURE_ONLY", False) - -# user tcp connection limits, the mapping from name to the integer limit -# one client can create many tcp connections, up to 8 -USER_MAX_TCP_CONNS = config.get("USER_MAX_TCP_CONNS", {}) - -# expiration date for users in format of day/month/year -USER_EXPIRATIONS = config.get("USER_EXPIRATIONS", {}) - -# length of used handshake randoms for active fingerprinting protection -REPLAY_CHECK_LEN = config.get("REPLAY_CHECK_LEN", 32768) - -# block bad first packets to even more protect against replay-based fingerprinting -BLOCK_IF_FIRST_PKT_BAD = config.get("BLOCK_IF_FIRST_PKT_BAD", True) - -# delay in seconds between stats printing -STATS_PRINT_PERIOD = config.get("STATS_PRINT_PERIOD", 600) - -# delay in seconds between middle proxy info updates -PROXY_INFO_UPDATE_PERIOD = config.get("PROXY_INFO_UPDATE_PERIOD", 24*60*60) - -# delay in seconds between time getting, zero means disabled -GET_TIME_PERIOD = config.get("GET_TIME_PERIOD", 10*60) - -# max socket buffer size to the client direction, the more the faster, but more RAM hungry -# can be the tuple (low, users_margin, high) for the adaptive case. If no much users, use high -TO_CLT_BUFSIZE = config.get("TO_CLT_BUFSIZE", (16384, 100, 131072)) - -# max socket buffer size to the telegram servers direction, also can be the tuple -TO_TG_BUFSIZE = config.get("TO_TG_BUFSIZE", 65536) - -# keepalive period for clients in secs -CLIENT_KEEPALIVE = config.get("CLIENT_KEEPALIVE", 10*60) - -# drop client after this timeout if the handshake fail -CLIENT_HANDSHAKE_TIMEOUT = config.get("CLIENT_HANDSHAKE_TIMEOUT", 10) - -# if client doesn't confirm data for this number of seconds, it is dropped -CLIENT_ACK_TIMEOUT = config.get("CLIENT_ACK_TIMEOUT", 5*60) - -# telegram servers connect timeout in seconds -TG_CONNECT_TIMEOUT = config.get("TG_CONNECT_TIMEOUT", 10) - -# listen address for IPv4 -LISTEN_ADDR_IPV4 = config.get("LISTEN_ADDR_IPV4", "0.0.0.0") - -# listen address for IPv6 -LISTEN_ADDR_IPV6 = config.get("LISTEN_ADDR_IPV6", "::") - TG_DATACENTER_PORT = 443 @@ -121,8 +45,6 @@ TG_MIDDLE_PROXIES_V6 = { 5: [("2001:b28:f23f:f005::d", 8888)], -5: [("2001:67c:04e8:f004::d", 8888)] } -USE_MIDDLE_PROXY = (len(AD_TAG) == 16) - PROXY_SECRET = bytes.fromhex( "c4f9faca9678e6bb48ad6c7e2ce5c0d24430645d554addeb55419e034da62721" + "d046eaab6e52ab14a95a443ecfb3463e79a05a66612adf9caeda8be9a80da698" + @@ -148,9 +70,103 @@ PADDING_FILLER = b"\x04\x00\x00\x00" MIN_MSG_LEN = 12 MAX_MSG_LEN = 2 ** 24 + my_ip_info = {"ipv4": None, "ipv6": None} used_handshakes = collections.OrderedDict() +config = {} + + +def init_config(): + global config + # we use conf_dict to protect the original config from exceptions when reloading + if len(sys.argv) < 2: + conf_dict = runpy.run_module("config") + elif len(sys.argv) == 2: + # launch with own config + conf_dict = runpy.run_path(sys.argv[1]) + else: + # undocumented way of launching + conf_dict = {} + conf_dict["PORT"] = int(sys.argv[1]) + secrets = sys.argv[2].split(",") + conf_dict["USERS"] = {"user%d" % i: secrets[i].zfill(32) for i in range(len(secrets))} + if len(sys.argv) > 3: + conf_dict["AD_TAG"] = sys.argv[3] + + conf_dict = {k: v for k, v in conf_dict.items() if k.isupper()} + + conf_dict.setdefault("PORT", 3255) + conf_dict.setdefault("USERS", {"tg": "00000000000000000000000000000000"}) + conf_dict["AD_TAG"] = bytes.fromhex(conf_dict.get("AD_TAG", "")) + + # load advanced settings + + # use middle proxy, necessary to show ad + conf_dict.setdefault("USE_MIDDLE_PROXY", len(conf_dict["AD_TAG"]) == 16) + + # if IPv6 avaliable, use it by default + conf_dict.setdefault("PREFER_IPV6", socket.has_ipv6) + + # disables tg->client trafic reencryption, faster but less secure + conf_dict.setdefault("FAST_MODE", True) + + # doesn't allow to connect in not-secure mode + conf_dict.setdefault("SECURE_ONLY", False) + + # user tcp connection limits, the mapping from name to the integer limit + # one client can create many tcp connections, up to 8 + conf_dict.setdefault("USER_MAX_TCP_CONNS", {}) + + # expiration date for users in format of day/month/year + conf_dict.setdefault("USER_EXPIRATIONS", {}) + for user in conf_dict["USER_EXPIRATIONS"]: + expiration = datetime.datetime.strptime(conf_dict["USER_EXPIRATIONS"][user], "%d/%m/%Y") + conf_dict["USER_EXPIRATIONS"][user] = expiration + + # length of used handshake randoms for active fingerprinting protection + conf_dict.setdefault("REPLAY_CHECK_LEN", 32768) + + # block bad first packets to even more protect against replay-based fingerprinting + conf_dict.setdefault("BLOCK_IF_FIRST_PKT_BAD", True) + + # delay in seconds between stats printing + conf_dict.setdefault("STATS_PRINT_PERIOD", 600) + + # delay in seconds between middle proxy info updates + conf_dict.setdefault("PROXY_INFO_UPDATE_PERIOD", 24*60*60) + + # delay in seconds between time getting, zero means disabled + conf_dict.setdefault("GET_TIME_PERIOD", 10*60) + + # max socket buffer size to the client direction, the more the faster, but more RAM hungry + # can be the tuple (low, users_margin, high) for the adaptive case. If no much users, use high + conf_dict.setdefault("TO_CLT_BUFSIZE", (16384, 100, 131072)) + + # max socket buffer size to the telegram servers direction, also can be the tuple + conf_dict.setdefault("TO_TG_BUFSIZE", 65536) + + # keepalive period for clients in secs + conf_dict.setdefault("CLIENT_KEEPALIVE", 10*60) + + # drop client after this timeout if the handshake fail + conf_dict.setdefault("CLIENT_HANDSHAKE_TIMEOUT", 10) + + # if client doesn't confirm data for this number of seconds, it is dropped + conf_dict.setdefault("CLIENT_ACK_TIMEOUT", 5*60) + + # telegram servers connect timeout in seconds + conf_dict.setdefault("TG_CONNECT_TIMEOUT", 10) + + # listen address for IPv4 + conf_dict.setdefault("LISTEN_ADDR_IPV4", "0.0.0.0") + + # listen address for IPv6 + conf_dict.setdefault("LISTEN_ADDR_IPV6", "::") + + # allow access to config by attributes + config = type("config", (dict,), conf_dict)(conf_dict) + def setup_files_limit(): try: @@ -163,7 +179,7 @@ def setup_files_limit(): pass -def setup_debug(): +def setup_signals(): if hasattr(signal, 'SIGUSR1'): def debug_signal(signum, frame): import pdb @@ -171,6 +187,13 @@ def setup_debug(): signal.signal(signal.SIGUSR1, debug_signal) + if hasattr(signal, 'SIGUSR2'): + def reload_signal(signum, frame): + init_config() + print("Config reloaded", flush=True, file=sys.stderr) + + signal.signal(signal.SIGUSR2, reload_signal) + def try_setup_uvloop(): try: @@ -276,7 +299,7 @@ def print_err(*params): def init_stats(): global stats - stats = {user: collections.Counter() for user in USERS} + stats = {user: collections.Counter() for user in config.USERS} def update_stats(user, connects=0, curr_connects=0, octets=0, msgs=0): @@ -299,18 +322,18 @@ def get_curr_connects_count(): def get_to_tg_bufsize(): - if isinstance(TO_TG_BUFSIZE, int): - return TO_TG_BUFSIZE + if isinstance(config.TO_TG_BUFSIZE, int): + return config.TO_TG_BUFSIZE - low, margin, high = TO_TG_BUFSIZE + low, margin, high = config.TO_TG_BUFSIZE return high if get_curr_connects_count() < margin else low def get_to_clt_bufsize(): - if isinstance(TO_CLT_BUFSIZE, int): - return TO_CLT_BUFSIZE + if isinstance(config.TO_CLT_BUFSIZE, int): + return config.TO_CLT_BUFSIZE - low, margin, high = TO_CLT_BUFSIZE + low, margin, high = config.TO_CLT_BUFSIZE return high if get_curr_connects_count() < margin else low @@ -638,7 +661,7 @@ class ProxyReqStreamWriter(LayeredStreamWriterBase): full_msg = bytearray() full_msg += RPC_PROXY_REQ + int.to_bytes(flags, 4, "little") + self.out_conn_id full_msg += self.remote_ip_port + self.our_ip_port + EXTRA_SIZE + PROXY_TAG - full_msg += bytes([len(AD_TAG)]) + AD_TAG + FOUR_BYTES_ALIGNER + full_msg += bytes([len(config.AD_TAG)]) + config.AD_TAG + FOUR_BYTES_ALIGNER full_msg += msg self.first_flag_byte = b"\x08" @@ -664,8 +687,8 @@ async def handle_handshake(reader, writer): return False - for user in USERS: - secret = bytes.fromhex(USERS[user]) + for user in config.USERS: + secret = bytes.fromhex(config.USERS[user]) dec_key = hashlib.sha256(dec_prekey + secret).digest() decryptor = create_aes_ctr(key=dec_key, iv=int.from_bytes(dec_iv, "big")) @@ -679,12 +702,12 @@ async def handle_handshake(reader, writer): if proto_tag not in (PROTO_TAG_ABRIDGED, PROTO_TAG_INTERMEDIATE, PROTO_TAG_SECURE): continue - if SECURE_ONLY and proto_tag != PROTO_TAG_SECURE: + if config.SECURE_ONLY and proto_tag != PROTO_TAG_SECURE: continue dc_idx = int.from_bytes(decrypted[DC_IDX_POS:DC_IDX_POS+2], "little", signed=True) - while len(used_handshakes) >= REPLAY_CHECK_LEN: + while len(used_handshakes) >= config.REPLAY_CHECK_LEN: used_handshakes.popitem(last=False) used_handshakes[dec_prekey_and_iv] = True @@ -749,7 +772,7 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): dc_idx = abs(dc_idx) - 1 - if PREFER_IPV6: + if config.PREFER_IPV6: if not 0 <= dc_idx < len(TG_DATACENTERS_V6): return False dc = TG_DATACENTERS_V6[dc_idx] @@ -760,7 +783,7 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): try: reader_tgt, writer_tgt = await open_connection_tryer( - dc, TG_DATACENTER_PORT, limit=get_to_clt_bufsize(), timeout=TG_CONNECT_TIMEOUT) + dc, TG_DATACENTER_PORT, limit=get_to_clt_bufsize(), timeout=config.TG_CONNECT_TIMEOUT) except ConnectionRefusedError as E: print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT) return False @@ -847,7 +870,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): # pass as consts to simplify code RPC_FLAGS = b"\x00\x00\x00\x00" - use_ipv6_tg = PREFER_IPV6 + use_ipv6_tg = config.PREFER_IPV6 use_ipv6_clt = (":" in cl_ip) if use_ipv6_tg: @@ -861,7 +884,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): try: reader_tgt, writer_tgt = await open_connection_tryer(addr, port, limit=get_to_clt_bufsize(), - timeout=TG_CONNECT_TIMEOUT) + timeout=config.TG_CONNECT_TIMEOUT) except ConnectionRefusedError as E: print_err("Got connection refused while trying to connect to", addr, port) return False @@ -967,14 +990,14 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): async def handle_client(reader_clt, writer_clt): - set_keepalive(writer_clt.get_extra_info("socket"), CLIENT_KEEPALIVE, attempts=3) - set_ack_timeout(writer_clt.get_extra_info("socket"), CLIENT_ACK_TIMEOUT) + set_keepalive(writer_clt.get_extra_info("socket"), config.CLIENT_KEEPALIVE, attempts=3) + set_ack_timeout(writer_clt.get_extra_info("socket"), config.CLIENT_ACK_TIMEOUT) set_bufsizes(writer_clt.get_extra_info("socket"), get_to_tg_bufsize(), get_to_clt_bufsize()) cl_ip, cl_port = writer_clt.get_extra_info('peername')[:2] try: clt_data = await asyncio.wait_for(handle_handshake(reader_clt, writer_clt), - timeout=CLIENT_HANDSHAKE_TIMEOUT) + timeout=config.CLIENT_HANDSHAKE_TIMEOUT) except asyncio.TimeoutError: return @@ -985,8 +1008,8 @@ async def handle_client(reader_clt, writer_clt): update_stats(user, connects=1) - if not USE_MIDDLE_PROXY: - if FAST_MODE: + if not config.USE_MIDDLE_PROXY: + if config.FAST_MODE: tg_data = await do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=enc_key_and_iv) else: tg_data = await do_direct_handshake(proto_tag, dc_idx) @@ -998,7 +1021,7 @@ async def handle_client(reader_clt, writer_clt): reader_tg, writer_tg = tg_data - if not USE_MIDDLE_PROXY and FAST_MODE: + if not config.USE_MIDDLE_PROXY and config.FAST_MODE: class FakeEncryptor: def encrypt(self, data): return data @@ -1010,7 +1033,7 @@ async def handle_client(reader_clt, writer_clt): reader_tg.decryptor = FakeDecryptor() writer_clt.encryptor = FakeEncryptor() - if USE_MIDDLE_PROXY: + if config.USE_MIDDLE_PROXY: if proto_tag == PROTO_TAG_ABRIDGED: reader_clt = MTProtoCompactFrameStreamReader(reader_clt) writer_clt = MTProtoCompactFrameStreamWriter(writer_clt) @@ -1058,7 +1081,7 @@ async def handle_client(reader_clt, writer_clt): pass tg_to_clt = connect_reader_to_writer(reader_tg, writer_clt, user, get_to_clt_bufsize(), - block_if_first_pkt_bad=BLOCK_IF_FIRST_PKT_BAD) + block_if_first_pkt_bad=config.BLOCK_IF_FIRST_PKT_BAD) clt_to_tg = connect_reader_to_writer(reader_clt, writer_tg, user, get_to_tg_bufsize()) task_tg_to_clt = asyncio.ensure_future(tg_to_clt) task_clt_to_tg = asyncio.ensure_future(clt_to_tg) @@ -1066,13 +1089,13 @@ async def handle_client(reader_clt, writer_clt): update_stats(user, curr_connects=1) tcp_limit_hit = ( - user in USER_MAX_TCP_CONNS and - stats[user]["curr_connects"] > USER_MAX_TCP_CONNS[user] + user in config.USER_MAX_TCP_CONNS and + stats[user]["curr_connects"] > config.USER_MAX_TCP_CONNS[user] ) user_expired = ( - user in USER_EXPIRATIONS and - datetime.datetime.now() > datetime.datetime.strptime(USER_EXPIRATIONS[user], "%d/%m/%Y") + user in config.USER_EXPIRATIONS and + datetime.datetime.now() > config.USER_EXPIRATIONS[user] ) if (not tcp_limit_hit) and (not user_expired): @@ -1098,7 +1121,7 @@ async def handle_client_wrapper(reader, writer): async def stats_printer(): global stats while True: - await asyncio.sleep(STATS_PRINT_PERIOD) + await asyncio.sleep(config.STATS_PRINT_PERIOD) print("Stats for", time.strftime("%d.%m.%Y %H:%M:%S")) for user, stat in stats.items(): @@ -1126,7 +1149,6 @@ async def make_https_req(url, host="core.telegram.org"): async def get_srv_time(): - global USE_MIDDLE_PROXY TIME_SYNC_ADDR = "https://core.telegram.org/getProxySecret" MAX_TIME_SKEW = 30 @@ -1142,22 +1164,22 @@ async def get_srv_time(): srv_time = datetime.datetime.strptime(line, "%a, %d %b %Y %H:%M:%S %Z") now_time = datetime.datetime.utcnow() time_diff = (now_time-srv_time).total_seconds() - if USE_MIDDLE_PROXY and abs(time_diff) > MAX_TIME_SKEW: + if config.USE_MIDDLE_PROXY and abs(time_diff) > MAX_TIME_SKEW: print_err("Time skew detected, please set the clock") print_err("Server time:", srv_time, "your time:", now_time) print_err("Disabling advertising to continue serving") - USE_MIDDLE_PROXY = False + config.USE_MIDDLE_PROXY = False want_to_reenable_advertising = True elif want_to_reenable_advertising and abs(time_diff) <= MAX_TIME_SKEW: print_err("Time is ok, reenabling advertising") - USE_MIDDLE_PROXY = True + config.USE_MIDDLE_PROXY = True want_to_reenable_advertising = False except Exception as E: print_err("Error getting server time", E) - await asyncio.sleep(GET_TIME_PERIOD) + await asyncio.sleep(config.GET_TIME_PERIOD) async def update_middle_proxy_info(): @@ -1213,12 +1235,10 @@ async def update_middle_proxy_info(): except Exception as E: print_err("Error updating middle proxy secret, using old", E) - await asyncio.sleep(PROXY_INFO_UPDATE_PERIOD) + await asyncio.sleep(config.PROXY_INFO_UPDATE_PERIOD) def init_ip_info(): - global USE_MIDDLE_PROXY - global PREFER_IPV6 global my_ip_info def get_ip_from_url(url): @@ -1240,17 +1260,17 @@ def init_ip_info(): my_ip_info["ipv4"] = get_ip_from_url(IPV4_URL1) or get_ip_from_url(IPV4_URL2) my_ip_info["ipv6"] = get_ip_from_url(IPV6_URL1) or get_ip_from_url(IPV6_URL2) - if PREFER_IPV6: + if config.PREFER_IPV6: if my_ip_info["ipv6"]: print_err("IPv6 found, using it for external communication") else: - PREFER_IPV6 = False + config.PREFER_IPV6 = False - if USE_MIDDLE_PROXY: - if ((not PREFER_IPV6 and not my_ip_info["ipv4"]) or - (PREFER_IPV6 and not my_ip_info["ipv6"])): + if config.USE_MIDDLE_PROXY: + if ((not config.PREFER_IPV6 and not my_ip_info["ipv4"]) or + (config.PREFER_IPV6 and not my_ip_info["ipv6"])): print_err("Failed to determine your ip, advertising disabled") - USE_MIDDLE_PROXY = False + config.USE_MIDDLE_PROXY = False def print_tg_info(): @@ -1260,14 +1280,14 @@ def print_tg_info(): if not ip_addrs: ip_addrs = ["YOUR_IP"] - for user, secret in sorted(USERS.items(), key=lambda x: x[0]): + for user, secret in sorted(config.USERS.items(), key=lambda x: x[0]): for ip in ip_addrs: - if not SECURE_ONLY: - params = {"server": ip, "port": PORT, "secret": secret} + if not config.SECURE_ONLY: + params = {"server": ip, "port": config.PORT, "secret": secret} params_encodeded = urllib.parse.urlencode(params, safe=':') print("{}: tg://proxy?{}".format(user, params_encodeded), flush=True) - params = {"server": ip, "port": PORT, "secret": "dd" + secret} + params = {"server": ip, "port": config.PORT, "secret": "dd" + secret} params_encodeded = urllib.parse.urlencode(params, safe=':') print("{}: tg://proxy?{}".format(user, params_encodeded), flush=True) if secret in ["00000000000000000000000000000000", "0123456789abcdef0123456789abcdef"]: @@ -1305,7 +1325,7 @@ def loop_exception_handler(loop, context): def main(): setup_files_limit() - setup_debug() + setup_signals() try_setup_uvloop() init_stats() @@ -1320,22 +1340,22 @@ def main(): stats_printer_task = asyncio.Task(stats_printer()) asyncio.ensure_future(stats_printer_task) - if USE_MIDDLE_PROXY: + if config.USE_MIDDLE_PROXY: middle_proxy_updater_task = asyncio.Task(update_middle_proxy_info()) asyncio.ensure_future(middle_proxy_updater_task) - if GET_TIME_PERIOD: + if config.GET_TIME_PERIOD: time_get_task = asyncio.Task(get_srv_time()) asyncio.ensure_future(time_get_task) reuse_port = hasattr(socket, "SO_REUSEPORT") - task_v4 = asyncio.start_server(handle_client_wrapper, LISTEN_ADDR_IPV4, PORT, + task_v4 = asyncio.start_server(handle_client_wrapper, config.LISTEN_ADDR_IPV4, config.PORT, limit=get_to_tg_bufsize(), reuse_port=reuse_port, loop=loop) server_v4 = loop.run_until_complete(task_v4) if socket.has_ipv6: - task_v6 = asyncio.start_server(handle_client_wrapper, LISTEN_ADDR_IPV6, PORT, + task_v6 = asyncio.start_server(handle_client_wrapper, config.LISTEN_ADDR_IPV6, config.PORT, limit=get_to_tg_bufsize(), reuse_port=reuse_port, loop=loop) server_v6 = loop.run_until_complete(task_v6) @@ -1357,6 +1377,7 @@ def main(): if __name__ == "__main__": + init_config() init_ip_info() print_tg_info() main()