mirror of
https://github.com/alexbers/mtprotoproxy.git
synced 2026-03-13 23:03:09 +00:00
use explicit base classes for layered streams
This commit is contained in:
141
mtprotoproxy.py
141
mtprotoproxy.py
@@ -118,29 +118,51 @@ def update_stats(user, connects=0, curr_connects_x2=0, octets=0):
|
||||
octets=octets)
|
||||
|
||||
|
||||
class CryptoWrappedStreamReader:
|
||||
def __init__(self, stream, decryptor, block_size=1):
|
||||
self.stream = stream
|
||||
class LayeredStreamReaderBase:
|
||||
def __init__(self, upstream):
|
||||
self.upstream = upstream
|
||||
|
||||
async def read(self, n):
|
||||
return await self.upstream.read(n)
|
||||
|
||||
async def readexactly(self, n):
|
||||
return await self.upstream.readexactly(n)
|
||||
|
||||
|
||||
class LayeredStreamWriterBase:
|
||||
def __init__(self, upstream):
|
||||
self.upstream = upstream
|
||||
|
||||
def write(self, data):
|
||||
return self.upstream.write(data)
|
||||
|
||||
async def drain(self):
|
||||
return await self.upstream.drain()
|
||||
|
||||
def close(self):
|
||||
return self.upstream.close()
|
||||
|
||||
|
||||
class CryptoWrappedStreamReader(LayeredStreamReaderBase):
|
||||
def __init__(self, upstream, decryptor, block_size=1):
|
||||
self.upstream = upstream
|
||||
self.decryptor = decryptor
|
||||
self.block_size = block_size
|
||||
self.buf = bytearray()
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.stream, attr)
|
||||
|
||||
async def read(self, n):
|
||||
if self.buf:
|
||||
ret = bytes(self.buf)
|
||||
self.buf.clear()
|
||||
return ret
|
||||
else:
|
||||
data = await self.stream.read(n)
|
||||
data = await self.upstream.read(n)
|
||||
if not data:
|
||||
return b""
|
||||
|
||||
needed_till_full_block = -len(data) % self.block_size
|
||||
if needed_till_full_block > 0:
|
||||
data += self.stream.readexactly(needed_till_full_block)
|
||||
data += self.upstream.readexactly(needed_till_full_block)
|
||||
return self.decryptor.decrypt(data)
|
||||
|
||||
async def readexactly(self, n):
|
||||
@@ -149,7 +171,7 @@ class CryptoWrappedStreamReader:
|
||||
needed_till_full_block = -to_read % self.block_size
|
||||
|
||||
to_read_block_aligned = to_read + needed_till_full_block
|
||||
data = await self.stream.readexactly(to_read_block_aligned)
|
||||
data = await self.upstream.readexactly(to_read_block_aligned)
|
||||
self.buf += self.decryptor.decrypt(data)
|
||||
|
||||
ret = bytes(self.buf[:n])
|
||||
@@ -157,38 +179,32 @@ class CryptoWrappedStreamReader:
|
||||
return ret
|
||||
|
||||
|
||||
class CryptoWrappedStreamWriter:
|
||||
def __init__(self, stream, encryptor, block_size=1):
|
||||
self.stream = stream
|
||||
class CryptoWrappedStreamWriter(LayeredStreamWriterBase):
|
||||
def __init__(self, upstream, encryptor, block_size=1):
|
||||
self.upstream = upstream
|
||||
self.encryptor = encryptor
|
||||
self.block_size = block_size
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.stream, attr)
|
||||
|
||||
def write(self, data):
|
||||
if len(data) % self.block_size != 0:
|
||||
print("BUG: writing %d bytes not aligned to block size %d" % (
|
||||
len(data), self.block_size))
|
||||
return 0
|
||||
q = self.encryptor.encrypt(data)
|
||||
return self.stream.write(q)
|
||||
return self.upstream.write(q)
|
||||
|
||||
|
||||
class MTProtoFrameStreamReader:
|
||||
def __init__(self, stream, seq_no=0):
|
||||
self.stream = stream
|
||||
class MTProtoFrameStreamReader(LayeredStreamReaderBase):
|
||||
def __init__(self, upstream, seq_no=0):
|
||||
self.upstream = upstream
|
||||
self.seq_no = seq_no
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.stream, attr)
|
||||
|
||||
async def read(self, buf_size):
|
||||
msg_len_bytes = await self.stream.readexactly(4)
|
||||
msg_len_bytes = await self.upstream.readexactly(4)
|
||||
msg_len = int.from_bytes(msg_len_bytes, "little")
|
||||
# skip paddings
|
||||
while msg_len == 4:
|
||||
msg_len_bytes = await self.stream.readexactly(4)
|
||||
msg_len_bytes = await self.upstream.readexactly(4)
|
||||
msg_len = int.from_bytes(msg_len_bytes, "little")
|
||||
|
||||
len_is_bad = (msg_len % len(PADDING_FILLER) != 0)
|
||||
@@ -196,7 +212,7 @@ class MTProtoFrameStreamReader:
|
||||
print("msg_len is bad, closing connection", msg_len)
|
||||
return b""
|
||||
|
||||
msg_seq_bytes = await self.stream.readexactly(4)
|
||||
msg_seq_bytes = await self.upstream.readexactly(4)
|
||||
msg_seq = int.from_bytes(msg_seq_bytes, "little", signed=True)
|
||||
if msg_seq != self.seq_no:
|
||||
print("unexpected seq_no")
|
||||
@@ -204,8 +220,8 @@ class MTProtoFrameStreamReader:
|
||||
|
||||
self.seq_no += 1
|
||||
|
||||
data = await self.stream.readexactly(msg_len - 4 - 4 - 4)
|
||||
checksum_bytes = await self.stream.readexactly(4)
|
||||
data = await self.upstream.readexactly(msg_len - 4 - 4 - 4)
|
||||
checksum_bytes = await self.upstream.readexactly(4)
|
||||
checksum = int.from_bytes(checksum_bytes, "little")
|
||||
|
||||
computed_checksum = binascii.crc32(msg_len_bytes + msg_seq_bytes + data)
|
||||
@@ -214,14 +230,11 @@ class MTProtoFrameStreamReader:
|
||||
return data
|
||||
|
||||
|
||||
class MTProtoFrameStreamWriter:
|
||||
def __init__(self, stream, seq_no=0):
|
||||
self.stream = stream
|
||||
class MTProtoFrameStreamWriter(LayeredStreamWriterBase):
|
||||
def __init__(self, upstream, seq_no=0):
|
||||
self.upstream = upstream
|
||||
self.seq_no = seq_no
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.stream, attr)
|
||||
|
||||
def write(self, msg):
|
||||
len_bytes = int.to_bytes(len(msg) + 4 + 4 + 4, 4, "little")
|
||||
seq_bytes = int.to_bytes(self.seq_no, 4, "little", signed=True)
|
||||
@@ -233,42 +246,33 @@ class MTProtoFrameStreamWriter:
|
||||
full_msg = msg_without_checksum + checksum
|
||||
padding = PADDING_FILLER * ((-len(full_msg) % CBC_PADDING) // len(PADDING_FILLER))
|
||||
|
||||
return self.stream.write(full_msg + padding)
|
||||
return self.upstream.write(full_msg + padding)
|
||||
|
||||
|
||||
class MTProtoCompactFrameStreamReader:
|
||||
def __init__(self, stream):
|
||||
self.stream = stream
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.stream, attr)
|
||||
|
||||
class MTProtoCompactFrameStreamReader(LayeredStreamReaderBase):
|
||||
async def read(self, buf_size):
|
||||
msg_len_bytes = await self.stream.readexactly(1)
|
||||
msg_len_bytes = await self.upstream.readexactly(1)
|
||||
msg_len = int.from_bytes(msg_len_bytes, "little")
|
||||
|
||||
if msg_len >= 0x80:
|
||||
msg_len -= 0x80
|
||||
|
||||
if msg_len == 0x7f:
|
||||
msg_len_bytes = await self.stream.readexactly(3)
|
||||
msg_len_bytes = await self.upstream.readexactly(3)
|
||||
msg_len = int.from_bytes(msg_len_bytes, "little")
|
||||
|
||||
msg_len *= 4
|
||||
|
||||
data = await self.stream.readexactly(msg_len)
|
||||
data = await self.upstream.readexactly(msg_len)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class MTProtoCompactFrameStreamWriter:
|
||||
def __init__(self, stream, seq_no=0):
|
||||
self.stream = stream
|
||||
class MTProtoCompactFrameStreamWriter(LayeredStreamWriterBase):
|
||||
def __init__(self, upstream, seq_no=0):
|
||||
self.upstream = upstream
|
||||
self.seq_no = seq_no
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.stream, attr)
|
||||
|
||||
def write(self, data):
|
||||
SMALL_PKT_BORDER = 0x7f
|
||||
LARGE_PKT_BORGER = 256 ** 3
|
||||
@@ -280,27 +284,21 @@ class MTProtoCompactFrameStreamWriter:
|
||||
len_div_four = len(data) // 4
|
||||
|
||||
if len_div_four < SMALL_PKT_BORDER:
|
||||
return self.stream.write(bytes([len_div_four]) + data)
|
||||
return self.upstream.write(bytes([len_div_four]) + data)
|
||||
elif len_div_four < LARGE_PKT_BORGER:
|
||||
return self.stream.write(b'\x7f' + bytes(int.to_bytes(len_div_four, 3, 'little')) +
|
||||
return self.upstream.write(b'\x7f' + bytes(int.to_bytes(len_div_four, 3, 'little')) +
|
||||
data)
|
||||
else:
|
||||
print("Attempted to send too large pkt len =", len(data))
|
||||
return 0
|
||||
|
||||
|
||||
class ProxyReqStreamReader:
|
||||
def __init__(self, stream):
|
||||
self.stream = stream
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.stream, attr)
|
||||
|
||||
class ProxyReqStreamReader(LayeredStreamReaderBase):
|
||||
async def read(self, msg):
|
||||
RPC_PROXY_ANS = b"\x0d\xda\x03\x44"
|
||||
RPC_CLOSE_EXT = b"\xa2\x34\xb6\x5e"
|
||||
|
||||
data = await self.stream.read(1)
|
||||
data = await self.upstream.read(1)
|
||||
|
||||
if len(data) < 4:
|
||||
return b""
|
||||
@@ -316,12 +314,9 @@ class ProxyReqStreamReader:
|
||||
return conn_data
|
||||
|
||||
|
||||
class ProxyReqStreamWriter:
|
||||
def __init__(self, stream):
|
||||
self.stream = stream
|
||||
|
||||
def __getattr__(self, attr):
|
||||
return getattr(self.stream, attr)
|
||||
class ProxyReqStreamWriter(LayeredStreamWriterBase):
|
||||
def __init__(self, upstream):
|
||||
self.upstream = upstream
|
||||
|
||||
def write(self, msg):
|
||||
RPC_PROXY_REQ = b"\xee\xf1\xce\x36"
|
||||
@@ -343,7 +338,7 @@ class ProxyReqStreamWriter:
|
||||
full_msg += bytes([len(AD_TAG)]) + AD_TAG + FOUR_BYTES_ALIGNER
|
||||
full_msg += msg
|
||||
|
||||
return self.stream.write(full_msg)
|
||||
return self.upstream.write(full_msg)
|
||||
|
||||
|
||||
async def handle_handshake(reader, writer):
|
||||
@@ -511,8 +506,8 @@ async def do_middleproxy_handshake(dc_idx):
|
||||
return False
|
||||
|
||||
# get keys
|
||||
tg_ip, tg_port = writer_tgt.stream.get_extra_info('peername')
|
||||
my_ip, my_port = writer_tgt.stream.get_extra_info('sockname')
|
||||
tg_ip, tg_port = writer_tgt.upstream.get_extra_info('peername')
|
||||
my_ip, my_port = writer_tgt.upstream.get_extra_info('sockname')
|
||||
|
||||
global my_ip_info
|
||||
if my_ip_info["ipv4"]:
|
||||
@@ -542,11 +537,11 @@ async def do_middleproxy_handshake(dc_idx):
|
||||
# TODO: pass client ip and port here for statistics
|
||||
handshake = RPC_HANDSHAKE + RPC_FLAGS + SENDER_PID + PEER_PID
|
||||
|
||||
writer_tgt.stream = CryptoWrappedStreamWriter(writer_tgt.stream, encryptor, block_size=16)
|
||||
writer_tgt.upstream = CryptoWrappedStreamWriter(writer_tgt.upstream, encryptor, block_size=16)
|
||||
writer_tgt.write(handshake)
|
||||
await writer_tgt.drain()
|
||||
|
||||
reader_tgt.stream = CryptoWrappedStreamReader(reader_tgt.stream, decryptor, block_size=16)
|
||||
reader_tgt.upstream = CryptoWrappedStreamReader(reader_tgt.upstream, decryptor, block_size=16)
|
||||
|
||||
handshake_ans = await reader_tgt.read(1)
|
||||
if len(handshake_ans) != RPC_HANDSHAKE_ANS_LEN:
|
||||
@@ -570,7 +565,7 @@ async def handle_client(reader_clt, writer_clt):
|
||||
return
|
||||
|
||||
reader_clt, writer_clt, user, dc_idx, enc_key_and_iv = clt_data
|
||||
|
||||
|
||||
update_stats(user, connects=1)
|
||||
|
||||
if not USE_MIDDLE_PROXY:
|
||||
@@ -669,7 +664,7 @@ def init_ip_info():
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if USE_MIDDLE_PROXY and not my_ip_info["ipv4"]: #and not my_ip_info["ipv6"]:
|
||||
if USE_MIDDLE_PROXY and not my_ip_info["ipv4"]: # and not my_ip_info["ipv6"]:
|
||||
print("Failed to determine your ip, advertising disabled", flush=True)
|
||||
USE_MIDDLE_PROXY = False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user