diff --git a/mtprotoproxy.py b/mtprotoproxy.py index 9593fea..2d44bc9 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -984,6 +984,23 @@ def gen_x25519_public_key(): return int.to_bytes((n*n) % P, length=32, byteorder="little") +async def connect_reader_to_writer(reader, writer): + try: + while True: + data = await reader.read(BUF_SIZE) + + if not data: + if not writer.transport.is_closing(): + writer.write_eof() + await writer.drain() + return + + writer.write(data) + await writer.drain() + except (OSError, asyncio.IncompleteReadError) as e: + pass + + async def handle_bad_client(reader_clt, writer_clt, handshake): BUF_SIZE = 8192 CONNECT_TIMEOUT = 5 @@ -1003,22 +1020,6 @@ async def handle_bad_client(reader_clt, writer_clt, handshake): pass return - async def connect_reader_to_writer(reader, writer): - try: - while True: - data = await reader.read(BUF_SIZE) - - if not data: - if not writer.transport.is_closing(): - writer.write_eof() - await writer.drain() - return - - writer.write(data) - await writer.drain() - except (OSError, asyncio.IncompleteReadError) as e: - pass - writer_srv = None try: host = mask_host_cached_ip or config.MASK_HOST @@ -1122,7 +1123,6 @@ async def handle_fake_tls_handshake(handshake, reader, writer, peer): last_clients_with_time_skew[peer[0]] = (time.time() - timestamp) // 60 continue - http_data = myrandom.getrandbytes(fake_cert_len) srv_hello = TLS_VERS + b"\x00"*DIGEST_LEN + bytes([sess_id_len]) + sess_id @@ -1546,6 +1546,32 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): return reader_tgt, writer_tgt +async def tg_connect_reader_to_writer(rd, wr, user, rd_buf_size, is_upstream): + try: + while True: + data = await rd.read(rd_buf_size) + if isinstance(data, tuple): + data, extra = data + else: + extra = {} + + if not data: + wr.write_eof() + await wr.drain() + return + else: + if is_upstream: + update_user_stats(user, octets_from_client=len(data), msgs_from_client=1) + else: + update_user_stats(user, octets_to_client=len(data), msgs_to_client=1) + + wr.write(data, extra) + await wr.drain() + except (OSError, asyncio.IncompleteReadError) as e: + # print_err(e) + pass + + async def handle_client(reader_clt, writer_clt): set_keepalive(writer_clt.get_extra_info("socket"), config.CLIENT_KEEPALIVE, attempts=3) set_ack_timeout(writer_clt.get_extra_info("socket"), config.CLIENT_ACK_TIMEOUT) @@ -1608,33 +1634,10 @@ async def handle_client(reader_clt, writer_clt): else: return - async def connect_reader_to_writer(rd, wr, user, rd_buf_size, is_upstream): - try: - while True: - data = await rd.read(rd_buf_size) - if isinstance(data, tuple): - data, extra = data - else: - extra = {} - - if not data: - wr.write_eof() - await wr.drain() - return - else: - if is_upstream: - update_user_stats(user, octets_from_client=len(data), msgs_from_client=1) - else: - update_user_stats(user, octets_to_client=len(data), msgs_to_client=1) - - wr.write(data, extra) - await wr.drain() - except (OSError, asyncio.IncompleteReadError) as e: - # print_err(e) - pass - - tg_to_clt = connect_reader_to_writer(reader_tg, writer_clt, user, get_to_clt_bufsize(), False) - clt_to_tg = connect_reader_to_writer(reader_clt, writer_tg, user, get_to_tg_bufsize(), True) + tg_to_clt = tg_connect_reader_to_writer(reader_tg, writer_clt, user, + get_to_clt_bufsize(), False) + clt_to_tg = tg_connect_reader_to_writer(reader_clt, writer_tg, + user, get_to_tg_bufsize(), True) task_tg_to_clt = asyncio.ensure_future(tg_to_clt) task_clt_to_tg = asyncio.ensure_future(clt_to_tg)