Add proxy protocol v1/v2 support (#119)

* add proxy protocol v1/v2 support

With fake-tls enabled, it was still quite hard to use mtprotoproxy
as backend behing some reverse https/tls proxy (nginx, haproxy, etc)
because it still need client address & port info.
With nginx already configured to use stream proxy with proxy protocol,
it was impossibe to connect due additional proxy header transmission
before real hadshake.
Adding general support of proxy protocol fixed both issues.

New config option PROXY_PROTOCOL = True enables transparent support,
unproxied incoming connections will still be accepted.
Since reverse proxy needs to be trusted, option disabled by default.

References:
* https://www.haproxy.com/blog/haproxy/proxy-protocol/
* http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt

* slightly optimize proxy v1 error path

* rework proxy handler

* deny direct connection with enabled PROXY_PROTOCOl per specs
* simplify proxy header checking
* use textual form of proxy v1 header
* drop useless find() call

* fix client address logging
This commit is contained in:
Vladislav Grishenko
2019-08-14 23:03:01 +05:00
committed by Alexander Bersenev
parent d9fa5b222a
commit f51a4bfe34

View File

@@ -121,6 +121,9 @@ def init_config():
# allows to connect in tls mode only
conf_dict.setdefault("TLS_ONLY", False)
# support proxy protocol v1/v2 for incoming connections
conf_dict.setdefault("PROXY_PROTOCOL", False)
# set the tls domain for the proxy, has an influence only on starting message
conf_dict.setdefault("TLS_DOMAIN", "google.com")
@@ -789,7 +792,7 @@ async def handle_bad_client(reader_clt, writer_clt, handshake):
set_instant_rst(writer_clt.get_extra_info("socket"))
set_bufsizes(writer_clt.get_extra_info("socket"), BUF_SIZE, BUF_SIZE)
if not config.MASK:
if not config.MASK or handshake is None:
while await reader_clt.read(BUF_SIZE):
# just consume all the data
pass
@@ -834,7 +837,7 @@ async def handle_bad_client(reader_clt, writer_clt, handshake):
writer_srv.transport.abort()
async def handle_pseudo_tls_handshake(handshake, reader, writer):
async def handle_pseudo_tls_handshake(handshake, reader, writer, peer):
global used_handshakes
TLS_VERS = b"\x03\x03"
@@ -853,8 +856,7 @@ async def handle_pseudo_tls_handshake(handshake, reader, writer):
digest = handshake[DIGEST_POS: DIGEST_POS + DIGEST_LEN]
if digest in used_handshakes:
ip = writer.get_extra_info('peername')[0]
print_err("Active TLS fingerprinting detected from %s, handling it" % ip)
print_err("Active TLS fingerprinting detected from %s, handling it" % peer[0])
return False
sess_id_len = handshake[SESSION_ID_LEN_POS]
@@ -898,17 +900,80 @@ async def handle_pseudo_tls_handshake(handshake, reader, writer):
return False
async def handle_proxy_protocol(reader, peer=None):
PROXY_SIGNATURE = b"PROXY "
PROXY_MIN_LEN = 6
PROXY_TCP4 = b"TCP4"
PROXY_TCP6 = b"TCP6"
PROXY_UNKNOWN = b"UNKNOWN"
PROXY2_SIGNATURE = b"\x0d\x0a\x0d\x0a\x00\x0d\x0a\x51\x55\x49\x54\x0a"
PROXY2_MIN_LEN = 16
PROXY2_AF_UNSPEC = 0x0
PROXY2_AF_INET = 0x1
PROXY2_AF_INET6 = 0x2
header = await reader.readexactly(PROXY_MIN_LEN)
if header.startswith(PROXY_SIGNATURE):
# proxy header v1
header += await reader.readuntil(b"\r\n")
_, proxy_fam, *proxy_addr = header[:-2].split(b" ")
if proxy_fam in (PROXY_TCP4, PROXY_TCP6):
if len(proxy_addr) == 4:
src_addr = proxy_addr[0].decode('ascii')
src_port = proxy_addr[2].decode('ascii')
return (src_addr, src_port)
elif proxy_fam == PROXY_UNKNOWN:
return peer
return False
header += await reader.readexactly(PROXY2_MIN_LEN - PROXY_MIN_LEN)
if header.startswith(PROXY2_SIGNATURE):
# proxy header v2
proxy_ver = header[12]
if proxy_ver & 0xf0 != 0x20:
return False
proxy_len = int.from_bytes(header[14:16], "big")
proxy_addr = await reader.readexactly(proxy_len)
if proxy_ver == 0x21:
proxy_fam = header[13] >> 4
if proxy_fam == PROXY2_AF_INET:
if proxy_len >= (4 + 2)*2:
src_addr = socket.inet_ntop(socket.AF_INET, proxy_addr[:4])
src_port = int.from_bytes(proxy_addr[8:10], "big")
return (src_addr, src_port)
elif proxy_fam == PROXY2_AF_INET6:
if proxy_len >= (16 + 2)*2:
src_addr = socket.inet_ntop(socket.AF_INET6, proxy_addr[:16])
src_port = int.from_bytes(proxy_addr[32:34], "big")
return (src_addr, src_port)
elif proxy_fam == PROXY2_AF_UNSPEC:
return peer
elif proxy_ver == 0x20:
return peer
return False
async def handle_handshake(reader, writer):
global used_handshakes
TLS_START_BYTES = b"\x16\x03\x01\x02\x00\x01\x00\x01\xfc\x03\x03"
EMPTY_READ_BUF_SIZE = 4096
peer = writer.get_extra_info('peername')[:2]
if config.PROXY_PROTOCOL:
peer = await handle_proxy_protocol(reader, peer)
if not peer:
await handle_bad_client(reader, writer, None)
return False
handshake = await reader.readexactly(HANDSHAKE_LEN)
if handshake.startswith(TLS_START_BYTES):
handshake += await reader.readexactly(TLS_HANDSHAKE_LEN - HANDSHAKE_LEN)
tls_handshake_result = await handle_pseudo_tls_handshake(handshake, reader, writer)
tls_handshake_result = await handle_pseudo_tls_handshake(handshake, reader, writer, peer)
if not tls_handshake_result:
await handle_bad_client(reader, writer, handshake)
@@ -926,8 +991,7 @@ async def handle_handshake(reader, writer):
enc_prekey, enc_iv = enc_prekey_and_iv[:PREKEY_LEN], enc_prekey_and_iv[PREKEY_LEN:]
if dec_prekey_and_iv in used_handshakes:
ip = writer.get_extra_info('peername')[0]
print_err("Active fingerprinting detected from %s, handling it" % ip)
print_err("Active fingerprinting detected from %s, handling it" % peer[0])
await handle_bad_client(reader, writer, handshake)
return False
@@ -957,7 +1021,7 @@ async def handle_handshake(reader, writer):
reader = CryptoWrappedStreamReader(reader, decryptor)
writer = CryptoWrappedStreamWriter(writer, encryptor)
return reader, writer, proto_tag, user, dc_idx, enc_key + enc_iv
return reader, writer, proto_tag, user, dc_idx, enc_key + enc_iv, peer
await handle_bad_client(reader, writer, handshake)
return False
@@ -1212,7 +1276,6 @@ async def handle_client(reader_clt, writer_clt):
set_ack_timeout(writer_clt.get_extra_info("socket"), config.CLIENT_ACK_TIMEOUT)
set_bufsizes(writer_clt.get_extra_info("socket"), get_to_tg_bufsize(), get_to_clt_bufsize())
cl_ip, cl_port = writer_clt.get_extra_info('peername')[:2]
try:
clt_data = await asyncio.wait_for(handle_handshake(reader_clt, writer_clt),
timeout=config.CLIENT_HANDSHAKE_TIMEOUT)
@@ -1222,7 +1285,8 @@ async def handle_client(reader_clt, writer_clt):
if not clt_data:
return
reader_clt, writer_clt, proto_tag, user, dc_idx, enc_key_and_iv = clt_data
reader_clt, writer_clt, proto_tag, user, dc_idx, enc_key_and_iv, peer = clt_data
cl_ip, cl_port = peer
update_stats(user, connects=1)