diff --git a/mtprotoproxy.py b/mtprotoproxy.py index a7bfafc..218fa36 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -86,6 +86,7 @@ mask_host_cached_ip = None last_clients_with_time_skew = {} last_clients_with_first_pkt_error = collections.Counter() last_clients_with_same_handshake = collections.Counter() +proxy_start_time = 0 config = {} @@ -171,6 +172,9 @@ def init_config(): # delay in seconds between stats printing conf_dict.setdefault("STATS_PRINT_PERIOD", 600) + # delay in seconds between metric sending, if enabled + conf_dict.setdefault("SEND_METRICS_PERIOD", 15) + # delay in seconds between middle proxy info updates conf_dict.setdefault("PROXY_INFO_UPDATE_PERIOD", 24*60*60) @@ -208,6 +212,15 @@ def init_config(): # listen unix socket conf_dict.setdefault("LISTEN_UNIX_SOCK", "") + # prometheus push gateway addr to send metrics, disabled by default + conf_dict.setdefault("METRICS_HOST", None) + + # prometheus push gateway port + conf_dict.setdefault("METRICS_PORT", 9091) + + # prometheus push gateway identity string, by default proxy addr and port will be used + conf_dict.setdefault("METRICS_ID", None) + # allow access to config by attributes config = type("config", (dict,), conf_dict)(conf_dict) @@ -301,24 +314,35 @@ def print_err(*params): def init_stats(): global stats - stats = {user: collections.Counter() for user in config.USERS} + global user_stats + + stats = collections.Counter() + user_stats = {user: collections.Counter() for user in config.USERS} -def update_stats(user, connects=0, curr_connects=0, octets=0, msgs=0): +def init_proxy_start_time(): + global proxy_start_time + proxy_start_time = time.time() + + +def update_stats(**kw_stats): global stats + stats.update(**kw_stats) - if user not in stats: - stats[user] = collections.Counter() - stats[user].update(connects=connects, curr_connects=curr_connects, - octets=octets, msgs=msgs) +def update_user_stats(user, **kw_stats): + global user_stats + + if user not in user_stats: + user_stats[user] = collections.Counter() + user_stats[user].update(**kw_stats) def get_curr_connects_count(): - global stats + global user_stats all_connects = 0 - for user, stat in stats.items(): + for user, stat in user_stats.items(): all_connects += stat["curr_connects"] return all_connects @@ -832,6 +856,8 @@ async def handle_bad_client(reader_clt, writer_clt, handshake): global mask_host_cached_ip + update_stats(bad_clients=1) + if writer_clt.transport.is_closing(): return @@ -922,7 +948,6 @@ async def handle_fake_tls_handshake(handshake, reader, writer, peer): DIGEST_HALFLEN = 16 DIGEST_POS = 11 - SESSION_ID_LEN_POS = DIGEST_POS + DIGEST_LEN SESSION_ID_POS = SESSION_ID_LEN_POS + 1 @@ -1387,7 +1412,7 @@ async def handle_client(reader_clt, writer_clt): reader_clt, writer_clt, proto_tag, user, dc_idx, enc_key_and_iv, peer = clt_data cl_ip, cl_port = peer - update_stats(user, connects=1) + update_user_stats(user, connects=1) connect_directly = (not config.USE_MIDDLE_PROXY or disable_middle_proxy) @@ -1457,7 +1482,7 @@ async def handle_client(reader_clt, writer_clt): await wr.drain() return else: - update_stats(user, octets=len(data), msgs=1) + update_user_stats(user, octets=len(data), msgs=1) wr.write(data, extra) await wr.drain() except (OSError, asyncio.streams.IncompleteReadError) as e: @@ -1470,11 +1495,11 @@ async def handle_client(reader_clt, writer_clt): task_tg_to_clt = asyncio.ensure_future(tg_to_clt) task_clt_to_tg = asyncio.ensure_future(clt_to_tg) - update_stats(user, curr_connects=1) + update_user_stats(user, curr_connects=1) tcp_limit_hit = ( user in config.USER_MAX_TCP_CONNS and - stats[user]["curr_connects"] > config.USER_MAX_TCP_CONNS[user] + user_stats[user]["curr_connects"] > config.USER_MAX_TCP_CONNS[user] ) user_expired = ( @@ -1484,13 +1509,13 @@ async def handle_client(reader_clt, writer_clt): user_data_quota_hit = ( user in config.USER_DATA_QUOTA and - stats[user]["octets"] > config.USER_DATA_QUOTA[user] + user_stats[user]["octets"] > config.USER_DATA_QUOTA[user] ) if (not tcp_limit_hit) and (not user_expired) and (not user_data_quota_hit): await asyncio.wait([task_tg_to_clt, task_clt_to_tg], return_when=asyncio.FIRST_COMPLETED) - update_stats(user, curr_connects=-1) + update_user_stats(user, curr_connects=-1) task_tg_to_clt.cancel() task_clt_to_tg.cancel() @@ -1512,7 +1537,7 @@ async def handle_client_wrapper(reader, writer): async def stats_printer(): - global stats + global user_stats global last_clients_with_time_skew global last_clients_with_first_pkt_error global last_clients_with_same_handshake @@ -1521,7 +1546,7 @@ async def stats_printer(): await asyncio.sleep(config.STATS_PRINT_PERIOD) print("Stats for", time.strftime("%d.%m.%Y %H:%M:%S")) - for user, stat in stats.items(): + for user, stat in user_stats.items(): print("%s: %d connects (%d current), %.2f MB, %d msgs" % ( user, stat["connects"], stat["curr_connects"], stat["octets"] / 1000000, stat["msgs"])) @@ -1547,6 +1572,124 @@ async def stats_printer(): last_clients_with_same_handshake.clear() +def make_metrics_pkt(host, instance_name, metrics): + pkt_body_list = [] + used_names = set() + + for name, m_type, desc, val in metrics: + if name not in used_names: + pkt_body_list.append("# HELP %s %s" % (name, desc)) + pkt_body_list.append("# TYPE %s %s" % (name, m_type)) + used_names.add(name) + + if isinstance(val, dict): + tags = [] + for tag, tag_val in val.items(): + if tag == "val": + continue + tag_val = tag_val.replace('"', r'\"') + tags.append('%s="%s"' % (tag, tag_val)) + pkt_body_list.append("%s{%s} %s" % (name, ",".join(tags), val["val"])) + else: + pkt_body_list.append("%s %s" % (name, val)) + pkt_body = "\n".join(pkt_body_list) + "\n" + + instance_name = urllib.parse.quote_plus(instance_name) + + pkt_header_list = [] + pkt_header_list.append("PUT /metrics/job/mtprotoproxy/instance/%s HTTP/1.1" % instance_name) + pkt_header_list.append("Accept-Encoding: identity") + pkt_header_list.append("Content-Length: %d" % len(pkt_body)) + pkt_header_list.append("Host: %s" % host) + pkt_header_list.append("User-Agent: Python-urllib/3.7") + pkt_header_list.append("Content-Type: text/plain; version=0.0.4; charset=utf-8") + pkt_header_list.append("Connection: close") + + pkt_header = "\r\n".join(pkt_header_list) + + pkt = pkt_header + "\r\n\r\n" + pkt_body + return pkt + + +async def send_metrics(host, port): + global stats + global user_stats + global my_ip_info + global proxy_start_time + global last_clients_with_time_skew + global last_clients_with_first_pkt_error + global last_clients_with_same_handshake + + instance_name = config.METRICS_ID + if not instance_name: + if my_ip_info.get("ipv4"): + instance_name = "%s:%d" % (my_ip_info["ipv4"], config.PORT) + elif my_ip_info.get("ipv6"): + instance_name = "%s:%d" % (my_ip_info["ipv6"], config.PORT) + else: + instance_name = "%s:%d" % ("unknown_ip", config.PORT) + + metrics = [] + metrics.append(["uptime", "counter", "proxy uptime", time.time() - proxy_start_time]) + metrics.append(["bad_clients", "counter", "clients with bad secret", stats["bad_clients"]]) + + user_metrics_desc = [ + ["connects_all", "counter", "all connects", "connects"], + ["connects_curr", "gauge", "current connects", "curr_connects"], + ["octets", "counter", "octets proxied", "octets"], + ["msgs", "counter", "msgs proxied", "msgs"], + ] + + for m_name, m_type, m_desc, stat_key in user_metrics_desc: + for user, stat in user_stats.items(): + metric = {"user": user, "val": stat[stat_key]} + metrics.append([m_name, m_type, m_desc, metric]) + + metrics_host_hdr = "%s:%s" % (host, port) + pkt = make_metrics_pkt(metrics_host_hdr, instance_name, metrics) + + reader, writer = await asyncio.open_connection(host, port) + writer.write(pkt.encode()) + await writer.drain() + + http_vers = (await reader.readuntil(b" "))[:-1] + http_statuscode = (await reader.readuntil(b" "))[:-1] + writer.close() + + return http_vers == b"HTTP/1.1" and http_statuscode == b"202" + + +async def metrics_sender(): + SEND_METRICS_ENABLED_CHECK_PERIOD = 60 + SEND_METRICS_TIMEOUT = 10 + + last_error_msg = "" + + while True: + if not config.METRICS_HOST: + await asyncio.sleep(config.SEND_METRICS_ENABLED_CHECK_PERIOD) + continue + await asyncio.sleep(config.SEND_METRICS_PERIOD) + + error_msg = "" + try: + task = send_metrics(config.METRICS_HOST, config.METRICS_PORT) + sent = await asyncio.wait_for(task, timeout=SEND_METRICS_TIMEOUT) + if not sent: + error_msg = "The METRICS_HOST %s refused metrics" % config.METRICS_HOST + except ConnectionRefusedError: + error_msg = "The METRICS_HOST %s is refusing connections" % config.METRICS_HOST + except (TimeoutError, asyncio.TimeoutError): + error_msg = "Got timeout while sending metrics to METRICS_HOST %s" % config.METRICS_HOST + except Exception as E: + error_msg = ("Got exception while sending metrics to METRICS_HOST %s: %s" % + (config.METRICS_HOST, E)) + + if error_msg and error_msg != last_error_msg: + print_err(error_msg) + last_error_msg = error_msg + + async def make_https_req(url, host="core.telegram.org"): """ Make request, return resp body and headers. """ SSL_PORT = 443 @@ -1921,6 +2064,7 @@ def main(): try_setup_uvloop() init_stats() + init_proxy_start_time() if sys.platform == "win32": loop = asyncio.ProactorEventLoop() @@ -1932,6 +2076,9 @@ def main(): stats_printer_task = asyncio.Task(stats_printer()) asyncio.ensure_future(stats_printer_task) + metrics_sender_task = asyncio.Task(metrics_sender()) + asyncio.ensure_future(metrics_sender_task) + if config.USE_MIDDLE_PROXY: middle_proxy_updater_task = asyncio.Task(update_middle_proxy_info()) asyncio.ensure_future(middle_proxy_updater_task)