Merge branch 'master' into pypi

This commit is contained in:
Alexander Bersenev
2018-07-05 16:27:06 +05:00
3 changed files with 83 additions and 45 deletions

View File

@@ -2,10 +2,9 @@ FROM alpine:3.6
RUN adduser tgproxy -u 10000 -D RUN adduser tgproxy -u 10000 -D
RUN apk add --no-cache python3 py3-crypto ca-certificates libcap RUN apk add --no-cache python3 py3-cryptography ca-certificates libcap
COPY mtprotoproxy.py config.py /home/tgproxy/ COPY mtprotoproxy.py config.py /home/tgproxy/
COPY pyaes/*.py /home/tgproxy/pyaes/
RUN chown -R tgproxy:tgproxy /home/tgproxy RUN chown -R tgproxy:tgproxy /home/tgproxy
RUN setcap cap_net_bind_service=+ep /usr/bin/python3.6 RUN setcap cap_net_bind_service=+ep /usr/bin/python3.6

View File

@@ -16,4 +16,12 @@ To advertise a channel get a tag from **@MTProxybot** and write it to *config.py
## Performance ## ## Performance ##
The proxy performance should be enough to comfortably serve about 4 000 simultaneous users on The proxy performance should be enough to comfortably serve about 4 000 simultaneous users on
the smallest VDS instance with 1 CPU core and 1024MB RAM. the VDS instance with 1 CPU core and 1024MB RAM.
## Advanced Usage ##
The proxy can be launched:
- with a custom config: `python3 mtprotoproxy.py [configfile]`
- several times, clients will be automaticaly balanced between instances
- using *PyPy* interprteter
- with runtime statistics exported for [Prometheus](https://prometheus.io/): using [prometheus](https://github.com/alexbers/mtprotoproxy/tree/prometheus) branch

View File

@@ -153,11 +153,12 @@ PREFER_IPV6 = config.get("PREFER_IPV6", socket.has_ipv6)
# disables tg->client trafic reencryption, faster but less secure # disables tg->client trafic reencryption, faster but less secure
FAST_MODE = config.get("FAST_MODE", True) FAST_MODE = config.get("FAST_MODE", True)
STATS_PRINT_PERIOD = config.get("STATS_PRINT_PERIOD", 600) STATS_PRINT_PERIOD = config.get("STATS_PRINT_PERIOD", 600)
PROXY_INFO_UPDATE_PERIOD = config.get("PROXY_INFO_UPDATE_PERIOD", 60*60*24) PROXY_INFO_UPDATE_PERIOD = config.get("PROXY_INFO_UPDATE_PERIOD", 24*60*60)
READ_BUF_SIZE = config.get("READ_BUF_SIZE", 16384) TO_CLT_BUFSIZE = config.get("TO_CLT_BUFSIZE", 8192)
WRITE_BUF_SIZE = config.get("WRITE_BUF_SIZE", 65536) TO_TG_BUFSIZE = config.get("TO_TG_BUFSIZE", 65536)
CLIENT_KEEPALIVE = config.get("CLIENT_KEEPALIVE", 60*30) CLIENT_KEEPALIVE = config.get("CLIENT_KEEPALIVE", 10*60)
CLIENT_HANDSHAKE_TIMEOUT = config.get("CLIENT_HANDSHAKE_TIMEOUT", 10) CLIENT_HANDSHAKE_TIMEOUT = config.get("CLIENT_HANDSHAKE_TIMEOUT", 10)
CLIENT_ACK_TIMEOUT = config.get("CLIENT_ACK_TIMEOUT", 5*60)
TG_DATACENTER_PORT = 443 TG_DATACENTER_PORT = 443
@@ -208,6 +209,7 @@ DC_IDX_POS = 60
PROTO_TAG_ABRIDGED = b"\xef\xef\xef\xef" PROTO_TAG_ABRIDGED = b"\xef\xef\xef\xef"
PROTO_TAG_INTERMEDIATE = b"\xee\xee\xee\xee" PROTO_TAG_INTERMEDIATE = b"\xee\xee\xee\xee"
PROTO_TAG_SECURE = b"\xdd\xdd\xdd\xdd"
CBC_PADDING = 16 CBC_PADDING = 16
PADDING_FILLER = b"\x04\x00\x00\x00" PADDING_FILLER = b"\x04\x00\x00\x00"
@@ -227,14 +229,14 @@ def init_stats():
stats = {user: collections.Counter() for user in USERS} stats = {user: collections.Counter() for user in USERS}
def update_stats(user, connects=0, curr_connects=0, octets=0): def update_stats(user, connects=0, curr_connects=0, octets=0, msgs=0):
global stats global stats
if user not in stats: if user not in stats:
stats[user] = collections.Counter() stats[user] = collections.Counter()
stats[user].update(connects=connects, curr_connects=curr_connects, stats[user].update(connects=connects, curr_connects=curr_connects,
octets=octets) octets=octets, msgs=msgs)
class LayeredStreamReaderBase: class LayeredStreamReaderBase:
@@ -434,6 +436,10 @@ class MTProtoIntermediateFrameStreamReader(LayeredStreamReaderBase):
data = await self.upstream.readexactly(msg_len) data = await self.upstream.readexactly(msg_len)
if msg_len % 4 != 0:
cut_border = msg_len - (msg_len % 4)
data = data[:cut_border]
return data, extra return data, extra
@@ -553,7 +559,7 @@ async def handle_handshake(reader, writer):
decrypted = decryptor.decrypt(handshake) decrypted = decryptor.decrypt(handshake)
proto_tag = decrypted[PROTO_TAG_POS:PROTO_TAG_POS+4] proto_tag = decrypted[PROTO_TAG_POS:PROTO_TAG_POS+4]
if proto_tag not in (PROTO_TAG_ABRIDGED, PROTO_TAG_INTERMEDIATE): if proto_tag not in (PROTO_TAG_ABRIDGED, PROTO_TAG_INTERMEDIATE, PROTO_TAG_SECURE):
continue continue
dc_idx = int.from_bytes(decrypted[DC_IDX_POS:DC_IDX_POS+2], "little", signed=True) dc_idx = int.from_bytes(decrypted[DC_IDX_POS:DC_IDX_POS+2], "little", signed=True)
@@ -562,13 +568,34 @@ async def handle_handshake(reader, writer):
writer = CryptoWrappedStreamWriter(writer, encryptor) writer = CryptoWrappedStreamWriter(writer, encryptor)
return reader, writer, proto_tag, user, dc_idx, enc_key + enc_iv return reader, writer, proto_tag, user, dc_idx, enc_key + enc_iv
while await reader.read(READ_BUF_SIZE): EMPTY_READ_BUF_SIZE = 4096
while await reader.read(EMPTY_READ_BUF_SIZE):
# just consume all the data # just consume all the data
pass pass
return False return False
def set_keepalive(sock, interval=40, attempts=5):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
if hasattr(socket, "TCP_KEEPIDLE"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, interval)
if hasattr(socket, "TCP_KEEPINTVL"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval)
if hasattr(socket, "TCP_KEEPCNT"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, attempts)
def set_ack_timeout(sock, timeout):
if hasattr(socket, "TCP_USER_TIMEOUT"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_USER_TIMEOUT, timeout*1000)
def set_bufsizes(sock, recv_buf, send_buf):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, recv_buf)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, send_buf)
async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None): async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None):
RESERVED_NONCE_FIRST_CHARS = [b"\xef"] RESERVED_NONCE_FIRST_CHARS = [b"\xef"]
RESERVED_NONCE_BEGININGS = [b"\x48\x45\x41\x44", b"\x50\x4F\x53\x54", RESERVED_NONCE_BEGININGS = [b"\x48\x45\x41\x44", b"\x50\x4F\x53\x54",
@@ -588,7 +615,10 @@ async def do_direct_handshake(proto_tag, dc_idx, dec_key_and_iv=None):
try: try:
reader_tgt, writer_tgt = await asyncio.open_connection(dc, TG_DATACENTER_PORT, reader_tgt, writer_tgt = await asyncio.open_connection(dc, TG_DATACENTER_PORT,
limit=READ_BUF_SIZE) limit=TO_CLT_BUFSIZE)
set_keepalive(writer_tgt.get_extra_info("socket"))
set_bufsizes(writer_tgt.get_extra_info("socket"), TO_CLT_BUFSIZE, TO_TG_BUFSIZE)
except ConnectionRefusedError as E: except ConnectionRefusedError as E:
print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT) print_err("Got connection refused while trying to connect to", dc, TG_DATACENTER_PORT)
return False return False
@@ -658,21 +688,6 @@ def get_middleproxy_aes_key_and_iv(nonce_srv, nonce_clt, clt_ts, srv_ip, clt_por
return key, iv return key, iv
def set_keepalive(sock, interval=40, attempts=5):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
if hasattr(socket, "TCP_KEEPIDLE"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, interval)
if hasattr(socket, "TCP_KEEPINTVL"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, interval)
if hasattr(socket, "TCP_KEEPCNT"):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, attempts)
def set_bufsizes(sock, recv_buf=READ_BUF_SIZE, send_buf=WRITE_BUF_SIZE):
sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, recv_buf)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, send_buf)
async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port): async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
START_SEQ_NO = -2 START_SEQ_NO = -2
NONCE_LEN = 16 NONCE_LEN = 16
@@ -700,9 +715,9 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
addr, port = random.choice(TG_MIDDLE_PROXIES_V4[dc_idx]) addr, port = random.choice(TG_MIDDLE_PROXIES_V4[dc_idx])
try: try:
reader_tgt, writer_tgt = await asyncio.open_connection(addr, port, limit=READ_BUF_SIZE) reader_tgt, writer_tgt = await asyncio.open_connection(addr, port, limit=TO_CLT_BUFSIZE)
set_keepalive(writer_tgt.get_extra_info("socket")) set_keepalive(writer_tgt.get_extra_info("socket"))
set_bufsizes(writer_tgt.get_extra_info("socket")) set_bufsizes(writer_tgt.get_extra_info("socket"), TO_CLT_BUFSIZE, TO_TG_BUFSIZE)
except ConnectionRefusedError as E: except ConnectionRefusedError as E:
print_err("Got connection refused while trying to connect to", addr, port) print_err("Got connection refused while trying to connect to", addr, port)
return False return False
@@ -724,7 +739,7 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
old_reader = reader_tgt old_reader = reader_tgt
reader_tgt = MTProtoFrameStreamReader(reader_tgt, START_SEQ_NO) reader_tgt = MTProtoFrameStreamReader(reader_tgt, START_SEQ_NO)
ans = await reader_tgt.read(READ_BUF_SIZE) ans = await reader_tgt.read(TO_CLT_BUFSIZE)
if len(ans) != RPC_NONCE_ANS_LEN: if len(ans) != RPC_NONCE_ANS_LEN:
return False return False
@@ -805,8 +820,9 @@ async def do_middleproxy_handshake(proto_tag, dc_idx, cl_ip, cl_port):
async def handle_client(reader_clt, writer_clt): async def handle_client(reader_clt, writer_clt):
set_keepalive(writer_clt.get_extra_info("socket"), CLIENT_KEEPALIVE) set_keepalive(writer_clt.get_extra_info("socket"), CLIENT_KEEPALIVE, attempts=3)
set_bufsizes(writer_clt.get_extra_info("socket")) set_ack_timeout(writer_clt.get_extra_info("socket"), CLIENT_ACK_TIMEOUT)
set_bufsizes(writer_clt.get_extra_info("socket"), TO_TG_BUFSIZE, TO_CLT_BUFSIZE)
try: try:
clt_data = await asyncio.wait_for(handle_handshake(reader_clt, writer_clt), clt_data = await asyncio.wait_for(handle_handshake(reader_clt, writer_clt),
@@ -851,16 +867,16 @@ async def handle_client(reader_clt, writer_clt):
if proto_tag == PROTO_TAG_ABRIDGED: if proto_tag == PROTO_TAG_ABRIDGED:
reader_clt = MTProtoCompactFrameStreamReader(reader_clt) reader_clt = MTProtoCompactFrameStreamReader(reader_clt)
writer_clt = MTProtoCompactFrameStreamWriter(writer_clt) writer_clt = MTProtoCompactFrameStreamWriter(writer_clt)
elif proto_tag == PROTO_TAG_INTERMEDIATE: elif proto_tag in (PROTO_TAG_INTERMEDIATE, PROTO_TAG_SECURE):
reader_clt = MTProtoIntermediateFrameStreamReader(reader_clt) reader_clt = MTProtoIntermediateFrameStreamReader(reader_clt)
writer_clt = MTProtoIntermediateFrameStreamWriter(writer_clt) writer_clt = MTProtoIntermediateFrameStreamWriter(writer_clt)
else: else:
return return
async def connect_reader_to_writer(rd, wr, user): async def connect_reader_to_writer(rd, wr, user, rd_buf_size):
try: try:
while True: while True:
data = await rd.read(READ_BUF_SIZE) data = await rd.read(rd_buf_size)
if isinstance(data, tuple): if isinstance(data, tuple):
data, extra = data data, extra = data
else: else:
@@ -871,15 +887,17 @@ async def handle_client(reader_clt, writer_clt):
await wr.drain() await wr.drain()
return return
else: else:
update_stats(user, octets=len(data)) update_stats(user, octets=len(data), msgs=1)
wr.write(data, extra) wr.write(data, extra)
await wr.drain() await wr.drain()
except (OSError, asyncio.streams.IncompleteReadError) as e: except (OSError, asyncio.streams.IncompleteReadError) as e:
# print_err(e) # print_err(e)
pass pass
task_tg_to_clt = asyncio.ensure_future(connect_reader_to_writer(reader_tg, writer_clt, user)) tg_to_clt = connect_reader_to_writer(reader_tg, writer_clt, user, TO_CLT_BUFSIZE)
task_clt_to_tg = asyncio.ensure_future(connect_reader_to_writer(reader_clt, writer_tg, user)) clt_to_tg = connect_reader_to_writer(reader_clt, writer_tg, user, TO_TG_BUFSIZE)
task_tg_to_clt = asyncio.ensure_future(tg_to_clt)
task_clt_to_tg = asyncio.ensure_future(clt_to_tg)
update_stats(user, curr_connects=1) update_stats(user, curr_connects=1)
await asyncio.wait([task_tg_to_clt, task_clt_to_tg], return_when=asyncio.FIRST_COMPLETED) await asyncio.wait([task_tg_to_clt, task_clt_to_tg], return_when=asyncio.FIRST_COMPLETED)
@@ -907,9 +925,9 @@ async def stats_printer():
print("Stats for", time.strftime("%d.%m.%Y %H:%M:%S")) print("Stats for", time.strftime("%d.%m.%Y %H:%M:%S"))
for user, stat in stats.items(): for user, stat in stats.items():
print("%s: %d connects (%d current), %.2f MB" % ( print("%s: %d connects (%d current), %.2f MB, %d msgs" % (
user, stat["connects"], stat["curr_connects"], user, stat["connects"], stat["curr_connects"],
stat["octets"] / 1000000)) stat["octets"] / 1000000, stat["msgs"]))
print(flush=True) print(flush=True)
@@ -1037,6 +1055,10 @@ def print_tg_info():
params_encodeded = urllib.parse.urlencode(params, safe=':') params_encodeded = urllib.parse.urlencode(params, safe=':')
print("{}: tg://proxy?{}".format(user, params_encodeded), flush=True) print("{}: tg://proxy?{}".format(user, params_encodeded), flush=True)
params = {"server": ip, "port": PORT, "secret": "dd" + secret}
params_encodeded = urllib.parse.urlencode(params, safe=':')
print("{}: tg://proxy?{} (beta)".format(user, params_encodeded), flush=True)
def loop_exception_handler(loop, context): def loop_exception_handler(loop, context):
exception = context.get("exception") exception = context.get("exception")
@@ -1044,15 +1066,24 @@ def loop_exception_handler(loop, context):
if exception: if exception:
if isinstance(exception, TimeoutError): if isinstance(exception, TimeoutError):
if transport: if transport:
print_err("Timeout, killing transport")
transport.abort() transport.abort()
return return
if isinstance(exception, OSError): if isinstance(exception, OSError):
IGNORE_ERRNO = { IGNORE_ERRNO = {
10038 # operation on non-socket on Windows, likely because fd == -1 10038, # operation on non-socket on Windows, likely because fd == -1
121, # the semaphore timeout period has expired on Windows
}
FORCE_CLOSE_ERRNO = {
113, # no route to host
} }
if exception.errno in IGNORE_ERRNO: if exception.errno in IGNORE_ERRNO:
return return
elif exception.errno in FORCE_CLOSE_ERRNO:
if transport:
transport.abort()
return
loop.default_exception_handler(context) loop.default_exception_handler(context)
@@ -1074,15 +1105,15 @@ def main():
middle_proxy_updater_task = asyncio.Task(update_middle_proxy_info()) middle_proxy_updater_task = asyncio.Task(update_middle_proxy_info())
asyncio.ensure_future(middle_proxy_updater_task) asyncio.ensure_future(middle_proxy_updater_task)
reuse_port = (sys.platform != "win32") reuse_port = hasattr(socket, "SO_REUSEPORT")
task_v4 = asyncio.start_server(handle_client_wrapper, '0.0.0.0', PORT, task_v4 = asyncio.start_server(handle_client_wrapper, '0.0.0.0', PORT,
limit=READ_BUF_SIZE, reuse_port=reuse_port, loop=loop) limit=TO_TG_BUFSIZE, reuse_port=reuse_port, loop=loop)
server_v4 = loop.run_until_complete(task_v4) server_v4 = loop.run_until_complete(task_v4)
if socket.has_ipv6: if socket.has_ipv6:
task_v6 = asyncio.start_server(handle_client_wrapper, '::', PORT, task_v6 = asyncio.start_server(handle_client_wrapper, '::', PORT,
limit=READ_BUF_SIZE, reuse_port=reuse_port, loop=loop) limit=TO_TG_BUFSIZE, reuse_port=reuse_port, loop=loop)
server_v6 = loop.run_until_complete(task_v6) server_v6 = loop.run_until_complete(task_v6)
try: try: