diff --git a/mtprotoproxy.py b/mtprotoproxy.py index 1f6d497..7f72340 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -738,6 +738,33 @@ class ProxyReqStreamWriter(LayeredStreamWriterBase): return self.upstream.write(full_msg) +def try_setsockopt(sock, level, option, value): + try: + sock.setsockopt(level, option, value) + except OSError as E: + pass + + +def set_keepalive(sock, interval=40, attempts=5): + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + if hasattr(socket, "TCP_KEEPIDLE"): + try_setsockopt(sock, socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, interval) + if hasattr(socket, "TCP_KEEPINTVL"): + try_setsockopt(sock, socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval) + if hasattr(socket, "TCP_KEEPCNT"): + try_setsockopt(sock, socket.IPPROTO_TCP, socket.TCP_KEEPCNT, attempts) + + +def set_ack_timeout(sock, timeout): + if hasattr(socket, "TCP_USER_TIMEOUT"): + try_setsockopt(sock, socket.IPPROTO_TCP, socket.TCP_USER_TIMEOUT, timeout*1000) + + +def set_bufsizes(sock, recv_buf, send_buf): + try_setsockopt(sock, socket.SOL_SOCKET, socket.SO_RCVBUF, recv_buf) + try_setsockopt(sock, socket.SOL_SOCKET, socket.SO_SNDBUF, send_buf) + + async def handle_pseudo_tls_handshake(handshake, reader, writer): global used_handshakes @@ -813,7 +840,12 @@ async def handle_handshake(reader, writer): if handshake.startswith(TLS_START_BYTES) and not config.DISABLE_TLS: handshake += await reader.readexactly(TLS_HANDSHAKE_LEN - HANDSHAKE_LEN) tls_handshake_result = await handle_pseudo_tls_handshake(handshake, reader, writer) + if not tls_handshake_result: + if hasattr(socket, "SO_LINGER"): + INSTANT_RST = b"\x01\x00\x00\x00\x00\x00\x00\x00" + try_setsockopt(writer.get_extra_info("socket"), + socket.SOL_SOCKET, socket.SO_LINGER, INSTANT_RST) return False reader, writer = tls_handshake_result handshake = await reader.readexactly(HANDSHAKE_LEN) @@ -867,33 +899,6 @@ async def handle_handshake(reader, writer): return False -def try_setsockopt(sock, level, option, value): - try: - sock.setsockopt(level, option, value) - except OSError as E: - pass - - -def set_keepalive(sock, interval=40, attempts=5): - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - if hasattr(socket, "TCP_KEEPIDLE"): - try_setsockopt(sock, socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, interval) - if hasattr(socket, "TCP_KEEPINTVL"): - try_setsockopt(sock, socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval) - if hasattr(socket, "TCP_KEEPCNT"): - try_setsockopt(sock, socket.IPPROTO_TCP, socket.TCP_KEEPCNT, attempts) - - -def set_ack_timeout(sock, timeout): - if hasattr(socket, "TCP_USER_TIMEOUT"): - try_setsockopt(sock, socket.IPPROTO_TCP, socket.TCP_USER_TIMEOUT, timeout*1000) - - -def set_bufsizes(sock, recv_buf, send_buf): - try_setsockopt(sock, socket.SOL_SOCKET, socket.SO_RCVBUF, recv_buf) - try_setsockopt(sock, socket.SOL_SOCKET, socket.SO_SNDBUF, send_buf) - - async def open_connection_tryer(addr, port, limit, timeout, max_attempts=3): for attempt in range(max_attempts-1): try: