mirror of
https://github.com/alexbers/mtprotoproxy.git
synced 2026-03-20 17:45:49 +00:00
use explicit base classes for layered streams
This commit is contained in:
139
mtprotoproxy.py
139
mtprotoproxy.py
@@ -118,29 +118,51 @@ def update_stats(user, connects=0, curr_connects_x2=0, octets=0):
|
|||||||
octets=octets)
|
octets=octets)
|
||||||
|
|
||||||
|
|
||||||
class CryptoWrappedStreamReader:
|
class LayeredStreamReaderBase:
|
||||||
def __init__(self, stream, decryptor, block_size=1):
|
def __init__(self, upstream):
|
||||||
self.stream = stream
|
self.upstream = upstream
|
||||||
|
|
||||||
|
async def read(self, n):
|
||||||
|
return await self.upstream.read(n)
|
||||||
|
|
||||||
|
async def readexactly(self, n):
|
||||||
|
return await self.upstream.readexactly(n)
|
||||||
|
|
||||||
|
|
||||||
|
class LayeredStreamWriterBase:
|
||||||
|
def __init__(self, upstream):
|
||||||
|
self.upstream = upstream
|
||||||
|
|
||||||
|
def write(self, data):
|
||||||
|
return self.upstream.write(data)
|
||||||
|
|
||||||
|
async def drain(self):
|
||||||
|
return await self.upstream.drain()
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
return self.upstream.close()
|
||||||
|
|
||||||
|
|
||||||
|
class CryptoWrappedStreamReader(LayeredStreamReaderBase):
|
||||||
|
def __init__(self, upstream, decryptor, block_size=1):
|
||||||
|
self.upstream = upstream
|
||||||
self.decryptor = decryptor
|
self.decryptor = decryptor
|
||||||
self.block_size = block_size
|
self.block_size = block_size
|
||||||
self.buf = bytearray()
|
self.buf = bytearray()
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
|
||||||
return getattr(self.stream, attr)
|
|
||||||
|
|
||||||
async def read(self, n):
|
async def read(self, n):
|
||||||
if self.buf:
|
if self.buf:
|
||||||
ret = bytes(self.buf)
|
ret = bytes(self.buf)
|
||||||
self.buf.clear()
|
self.buf.clear()
|
||||||
return ret
|
return ret
|
||||||
else:
|
else:
|
||||||
data = await self.stream.read(n)
|
data = await self.upstream.read(n)
|
||||||
if not data:
|
if not data:
|
||||||
return b""
|
return b""
|
||||||
|
|
||||||
needed_till_full_block = -len(data) % self.block_size
|
needed_till_full_block = -len(data) % self.block_size
|
||||||
if needed_till_full_block > 0:
|
if needed_till_full_block > 0:
|
||||||
data += self.stream.readexactly(needed_till_full_block)
|
data += self.upstream.readexactly(needed_till_full_block)
|
||||||
return self.decryptor.decrypt(data)
|
return self.decryptor.decrypt(data)
|
||||||
|
|
||||||
async def readexactly(self, n):
|
async def readexactly(self, n):
|
||||||
@@ -149,7 +171,7 @@ class CryptoWrappedStreamReader:
|
|||||||
needed_till_full_block = -to_read % self.block_size
|
needed_till_full_block = -to_read % self.block_size
|
||||||
|
|
||||||
to_read_block_aligned = to_read + needed_till_full_block
|
to_read_block_aligned = to_read + needed_till_full_block
|
||||||
data = await self.stream.readexactly(to_read_block_aligned)
|
data = await self.upstream.readexactly(to_read_block_aligned)
|
||||||
self.buf += self.decryptor.decrypt(data)
|
self.buf += self.decryptor.decrypt(data)
|
||||||
|
|
||||||
ret = bytes(self.buf[:n])
|
ret = bytes(self.buf[:n])
|
||||||
@@ -157,38 +179,32 @@ class CryptoWrappedStreamReader:
|
|||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
class CryptoWrappedStreamWriter:
|
class CryptoWrappedStreamWriter(LayeredStreamWriterBase):
|
||||||
def __init__(self, stream, encryptor, block_size=1):
|
def __init__(self, upstream, encryptor, block_size=1):
|
||||||
self.stream = stream
|
self.upstream = upstream
|
||||||
self.encryptor = encryptor
|
self.encryptor = encryptor
|
||||||
self.block_size = block_size
|
self.block_size = block_size
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
|
||||||
return getattr(self.stream, attr)
|
|
||||||
|
|
||||||
def write(self, data):
|
def write(self, data):
|
||||||
if len(data) % self.block_size != 0:
|
if len(data) % self.block_size != 0:
|
||||||
print("BUG: writing %d bytes not aligned to block size %d" % (
|
print("BUG: writing %d bytes not aligned to block size %d" % (
|
||||||
len(data), self.block_size))
|
len(data), self.block_size))
|
||||||
return 0
|
return 0
|
||||||
q = self.encryptor.encrypt(data)
|
q = self.encryptor.encrypt(data)
|
||||||
return self.stream.write(q)
|
return self.upstream.write(q)
|
||||||
|
|
||||||
|
|
||||||
class MTProtoFrameStreamReader:
|
class MTProtoFrameStreamReader(LayeredStreamReaderBase):
|
||||||
def __init__(self, stream, seq_no=0):
|
def __init__(self, upstream, seq_no=0):
|
||||||
self.stream = stream
|
self.upstream = upstream
|
||||||
self.seq_no = seq_no
|
self.seq_no = seq_no
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
|
||||||
return getattr(self.stream, attr)
|
|
||||||
|
|
||||||
async def read(self, buf_size):
|
async def read(self, buf_size):
|
||||||
msg_len_bytes = await self.stream.readexactly(4)
|
msg_len_bytes = await self.upstream.readexactly(4)
|
||||||
msg_len = int.from_bytes(msg_len_bytes, "little")
|
msg_len = int.from_bytes(msg_len_bytes, "little")
|
||||||
# skip paddings
|
# skip paddings
|
||||||
while msg_len == 4:
|
while msg_len == 4:
|
||||||
msg_len_bytes = await self.stream.readexactly(4)
|
msg_len_bytes = await self.upstream.readexactly(4)
|
||||||
msg_len = int.from_bytes(msg_len_bytes, "little")
|
msg_len = int.from_bytes(msg_len_bytes, "little")
|
||||||
|
|
||||||
len_is_bad = (msg_len % len(PADDING_FILLER) != 0)
|
len_is_bad = (msg_len % len(PADDING_FILLER) != 0)
|
||||||
@@ -196,7 +212,7 @@ class MTProtoFrameStreamReader:
|
|||||||
print("msg_len is bad, closing connection", msg_len)
|
print("msg_len is bad, closing connection", msg_len)
|
||||||
return b""
|
return b""
|
||||||
|
|
||||||
msg_seq_bytes = await self.stream.readexactly(4)
|
msg_seq_bytes = await self.upstream.readexactly(4)
|
||||||
msg_seq = int.from_bytes(msg_seq_bytes, "little", signed=True)
|
msg_seq = int.from_bytes(msg_seq_bytes, "little", signed=True)
|
||||||
if msg_seq != self.seq_no:
|
if msg_seq != self.seq_no:
|
||||||
print("unexpected seq_no")
|
print("unexpected seq_no")
|
||||||
@@ -204,8 +220,8 @@ class MTProtoFrameStreamReader:
|
|||||||
|
|
||||||
self.seq_no += 1
|
self.seq_no += 1
|
||||||
|
|
||||||
data = await self.stream.readexactly(msg_len - 4 - 4 - 4)
|
data = await self.upstream.readexactly(msg_len - 4 - 4 - 4)
|
||||||
checksum_bytes = await self.stream.readexactly(4)
|
checksum_bytes = await self.upstream.readexactly(4)
|
||||||
checksum = int.from_bytes(checksum_bytes, "little")
|
checksum = int.from_bytes(checksum_bytes, "little")
|
||||||
|
|
||||||
computed_checksum = binascii.crc32(msg_len_bytes + msg_seq_bytes + data)
|
computed_checksum = binascii.crc32(msg_len_bytes + msg_seq_bytes + data)
|
||||||
@@ -214,14 +230,11 @@ class MTProtoFrameStreamReader:
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
class MTProtoFrameStreamWriter:
|
class MTProtoFrameStreamWriter(LayeredStreamWriterBase):
|
||||||
def __init__(self, stream, seq_no=0):
|
def __init__(self, upstream, seq_no=0):
|
||||||
self.stream = stream
|
self.upstream = upstream
|
||||||
self.seq_no = seq_no
|
self.seq_no = seq_no
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
|
||||||
return getattr(self.stream, attr)
|
|
||||||
|
|
||||||
def write(self, msg):
|
def write(self, msg):
|
||||||
len_bytes = int.to_bytes(len(msg) + 4 + 4 + 4, 4, "little")
|
len_bytes = int.to_bytes(len(msg) + 4 + 4 + 4, 4, "little")
|
||||||
seq_bytes = int.to_bytes(self.seq_no, 4, "little", signed=True)
|
seq_bytes = int.to_bytes(self.seq_no, 4, "little", signed=True)
|
||||||
@@ -233,42 +246,33 @@ class MTProtoFrameStreamWriter:
|
|||||||
full_msg = msg_without_checksum + checksum
|
full_msg = msg_without_checksum + checksum
|
||||||
padding = PADDING_FILLER * ((-len(full_msg) % CBC_PADDING) // len(PADDING_FILLER))
|
padding = PADDING_FILLER * ((-len(full_msg) % CBC_PADDING) // len(PADDING_FILLER))
|
||||||
|
|
||||||
return self.stream.write(full_msg + padding)
|
return self.upstream.write(full_msg + padding)
|
||||||
|
|
||||||
|
|
||||||
class MTProtoCompactFrameStreamReader:
|
class MTProtoCompactFrameStreamReader(LayeredStreamReaderBase):
|
||||||
def __init__(self, stream):
|
|
||||||
self.stream = stream
|
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
|
||||||
return getattr(self.stream, attr)
|
|
||||||
|
|
||||||
async def read(self, buf_size):
|
async def read(self, buf_size):
|
||||||
msg_len_bytes = await self.stream.readexactly(1)
|
msg_len_bytes = await self.upstream.readexactly(1)
|
||||||
msg_len = int.from_bytes(msg_len_bytes, "little")
|
msg_len = int.from_bytes(msg_len_bytes, "little")
|
||||||
|
|
||||||
if msg_len >= 0x80:
|
if msg_len >= 0x80:
|
||||||
msg_len -= 0x80
|
msg_len -= 0x80
|
||||||
|
|
||||||
if msg_len == 0x7f:
|
if msg_len == 0x7f:
|
||||||
msg_len_bytes = await self.stream.readexactly(3)
|
msg_len_bytes = await self.upstream.readexactly(3)
|
||||||
msg_len = int.from_bytes(msg_len_bytes, "little")
|
msg_len = int.from_bytes(msg_len_bytes, "little")
|
||||||
|
|
||||||
msg_len *= 4
|
msg_len *= 4
|
||||||
|
|
||||||
data = await self.stream.readexactly(msg_len)
|
data = await self.upstream.readexactly(msg_len)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
class MTProtoCompactFrameStreamWriter:
|
class MTProtoCompactFrameStreamWriter(LayeredStreamWriterBase):
|
||||||
def __init__(self, stream, seq_no=0):
|
def __init__(self, upstream, seq_no=0):
|
||||||
self.stream = stream
|
self.upstream = upstream
|
||||||
self.seq_no = seq_no
|
self.seq_no = seq_no
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
|
||||||
return getattr(self.stream, attr)
|
|
||||||
|
|
||||||
def write(self, data):
|
def write(self, data):
|
||||||
SMALL_PKT_BORDER = 0x7f
|
SMALL_PKT_BORDER = 0x7f
|
||||||
LARGE_PKT_BORGER = 256 ** 3
|
LARGE_PKT_BORGER = 256 ** 3
|
||||||
@@ -280,27 +284,21 @@ class MTProtoCompactFrameStreamWriter:
|
|||||||
len_div_four = len(data) // 4
|
len_div_four = len(data) // 4
|
||||||
|
|
||||||
if len_div_four < SMALL_PKT_BORDER:
|
if len_div_four < SMALL_PKT_BORDER:
|
||||||
return self.stream.write(bytes([len_div_four]) + data)
|
return self.upstream.write(bytes([len_div_four]) + data)
|
||||||
elif len_div_four < LARGE_PKT_BORGER:
|
elif len_div_four < LARGE_PKT_BORGER:
|
||||||
return self.stream.write(b'\x7f' + bytes(int.to_bytes(len_div_four, 3, 'little')) +
|
return self.upstream.write(b'\x7f' + bytes(int.to_bytes(len_div_four, 3, 'little')) +
|
||||||
data)
|
data)
|
||||||
else:
|
else:
|
||||||
print("Attempted to send too large pkt len =", len(data))
|
print("Attempted to send too large pkt len =", len(data))
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
class ProxyReqStreamReader:
|
class ProxyReqStreamReader(LayeredStreamReaderBase):
|
||||||
def __init__(self, stream):
|
|
||||||
self.stream = stream
|
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
|
||||||
return getattr(self.stream, attr)
|
|
||||||
|
|
||||||
async def read(self, msg):
|
async def read(self, msg):
|
||||||
RPC_PROXY_ANS = b"\x0d\xda\x03\x44"
|
RPC_PROXY_ANS = b"\x0d\xda\x03\x44"
|
||||||
RPC_CLOSE_EXT = b"\xa2\x34\xb6\x5e"
|
RPC_CLOSE_EXT = b"\xa2\x34\xb6\x5e"
|
||||||
|
|
||||||
data = await self.stream.read(1)
|
data = await self.upstream.read(1)
|
||||||
|
|
||||||
if len(data) < 4:
|
if len(data) < 4:
|
||||||
return b""
|
return b""
|
||||||
@@ -316,12 +314,9 @@ class ProxyReqStreamReader:
|
|||||||
return conn_data
|
return conn_data
|
||||||
|
|
||||||
|
|
||||||
class ProxyReqStreamWriter:
|
class ProxyReqStreamWriter(LayeredStreamWriterBase):
|
||||||
def __init__(self, stream):
|
def __init__(self, upstream):
|
||||||
self.stream = stream
|
self.upstream = upstream
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
|
||||||
return getattr(self.stream, attr)
|
|
||||||
|
|
||||||
def write(self, msg):
|
def write(self, msg):
|
||||||
RPC_PROXY_REQ = b"\xee\xf1\xce\x36"
|
RPC_PROXY_REQ = b"\xee\xf1\xce\x36"
|
||||||
@@ -343,7 +338,7 @@ class ProxyReqStreamWriter:
|
|||||||
full_msg += bytes([len(AD_TAG)]) + AD_TAG + FOUR_BYTES_ALIGNER
|
full_msg += bytes([len(AD_TAG)]) + AD_TAG + FOUR_BYTES_ALIGNER
|
||||||
full_msg += msg
|
full_msg += msg
|
||||||
|
|
||||||
return self.stream.write(full_msg)
|
return self.upstream.write(full_msg)
|
||||||
|
|
||||||
|
|
||||||
async def handle_handshake(reader, writer):
|
async def handle_handshake(reader, writer):
|
||||||
@@ -511,8 +506,8 @@ async def do_middleproxy_handshake(dc_idx):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# get keys
|
# get keys
|
||||||
tg_ip, tg_port = writer_tgt.stream.get_extra_info('peername')
|
tg_ip, tg_port = writer_tgt.upstream.get_extra_info('peername')
|
||||||
my_ip, my_port = writer_tgt.stream.get_extra_info('sockname')
|
my_ip, my_port = writer_tgt.upstream.get_extra_info('sockname')
|
||||||
|
|
||||||
global my_ip_info
|
global my_ip_info
|
||||||
if my_ip_info["ipv4"]:
|
if my_ip_info["ipv4"]:
|
||||||
@@ -542,11 +537,11 @@ async def do_middleproxy_handshake(dc_idx):
|
|||||||
# TODO: pass client ip and port here for statistics
|
# TODO: pass client ip and port here for statistics
|
||||||
handshake = RPC_HANDSHAKE + RPC_FLAGS + SENDER_PID + PEER_PID
|
handshake = RPC_HANDSHAKE + RPC_FLAGS + SENDER_PID + PEER_PID
|
||||||
|
|
||||||
writer_tgt.stream = CryptoWrappedStreamWriter(writer_tgt.stream, encryptor, block_size=16)
|
writer_tgt.upstream = CryptoWrappedStreamWriter(writer_tgt.upstream, encryptor, block_size=16)
|
||||||
writer_tgt.write(handshake)
|
writer_tgt.write(handshake)
|
||||||
await writer_tgt.drain()
|
await writer_tgt.drain()
|
||||||
|
|
||||||
reader_tgt.stream = CryptoWrappedStreamReader(reader_tgt.stream, decryptor, block_size=16)
|
reader_tgt.upstream = CryptoWrappedStreamReader(reader_tgt.upstream, decryptor, block_size=16)
|
||||||
|
|
||||||
handshake_ans = await reader_tgt.read(1)
|
handshake_ans = await reader_tgt.read(1)
|
||||||
if len(handshake_ans) != RPC_HANDSHAKE_ANS_LEN:
|
if len(handshake_ans) != RPC_HANDSHAKE_ANS_LEN:
|
||||||
@@ -669,7 +664,7 @@ def init_ip_info():
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if USE_MIDDLE_PROXY and not my_ip_info["ipv4"]: #and not my_ip_info["ipv6"]:
|
if USE_MIDDLE_PROXY and not my_ip_info["ipv4"]: # and not my_ip_info["ipv6"]:
|
||||||
print("Failed to determine your ip, advertising disabled", flush=True)
|
print("Failed to determine your ip, advertising disabled", flush=True)
|
||||||
USE_MIDDLE_PROXY = False
|
USE_MIDDLE_PROXY = False
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user