diff --git a/mtprotoproxy.py b/mtprotoproxy.py index 82f83a7..2497363 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -118,29 +118,51 @@ def update_stats(user, connects=0, curr_connects_x2=0, octets=0): octets=octets) -class CryptoWrappedStreamReader: - def __init__(self, stream, decryptor, block_size=1): - self.stream = stream +class LayeredStreamReaderBase: + def __init__(self, upstream): + self.upstream = upstream + + async def read(self, n): + return await self.upstream.read(n) + + async def readexactly(self, n): + return await self.upstream.readexactly(n) + + +class LayeredStreamWriterBase: + def __init__(self, upstream): + self.upstream = upstream + + def write(self, data): + return self.upstream.write(data) + + async def drain(self): + return await self.upstream.drain() + + def close(self): + return self.upstream.close() + + +class CryptoWrappedStreamReader(LayeredStreamReaderBase): + def __init__(self, upstream, decryptor, block_size=1): + self.upstream = upstream self.decryptor = decryptor self.block_size = block_size self.buf = bytearray() - def __getattr__(self, attr): - return getattr(self.stream, attr) - async def read(self, n): if self.buf: ret = bytes(self.buf) self.buf.clear() return ret else: - data = await self.stream.read(n) + data = await self.upstream.read(n) if not data: return b"" needed_till_full_block = -len(data) % self.block_size if needed_till_full_block > 0: - data += self.stream.readexactly(needed_till_full_block) + data += self.upstream.readexactly(needed_till_full_block) return self.decryptor.decrypt(data) async def readexactly(self, n): @@ -149,7 +171,7 @@ class CryptoWrappedStreamReader: needed_till_full_block = -to_read % self.block_size to_read_block_aligned = to_read + needed_till_full_block - data = await self.stream.readexactly(to_read_block_aligned) + data = await self.upstream.readexactly(to_read_block_aligned) self.buf += self.decryptor.decrypt(data) ret = bytes(self.buf[:n]) @@ -157,38 +179,32 @@ class CryptoWrappedStreamReader: return ret -class CryptoWrappedStreamWriter: - def __init__(self, stream, encryptor, block_size=1): - self.stream = stream +class CryptoWrappedStreamWriter(LayeredStreamWriterBase): + def __init__(self, upstream, encryptor, block_size=1): + self.upstream = upstream self.encryptor = encryptor self.block_size = block_size - def __getattr__(self, attr): - return getattr(self.stream, attr) - def write(self, data): if len(data) % self.block_size != 0: print("BUG: writing %d bytes not aligned to block size %d" % ( len(data), self.block_size)) return 0 q = self.encryptor.encrypt(data) - return self.stream.write(q) + return self.upstream.write(q) -class MTProtoFrameStreamReader: - def __init__(self, stream, seq_no=0): - self.stream = stream +class MTProtoFrameStreamReader(LayeredStreamReaderBase): + def __init__(self, upstream, seq_no=0): + self.upstream = upstream self.seq_no = seq_no - def __getattr__(self, attr): - return getattr(self.stream, attr) - async def read(self, buf_size): - msg_len_bytes = await self.stream.readexactly(4) + msg_len_bytes = await self.upstream.readexactly(4) msg_len = int.from_bytes(msg_len_bytes, "little") # skip paddings while msg_len == 4: - msg_len_bytes = await self.stream.readexactly(4) + msg_len_bytes = await self.upstream.readexactly(4) msg_len = int.from_bytes(msg_len_bytes, "little") len_is_bad = (msg_len % len(PADDING_FILLER) != 0) @@ -196,7 +212,7 @@ class MTProtoFrameStreamReader: print("msg_len is bad, closing connection", msg_len) return b"" - msg_seq_bytes = await self.stream.readexactly(4) + msg_seq_bytes = await self.upstream.readexactly(4) msg_seq = int.from_bytes(msg_seq_bytes, "little", signed=True) if msg_seq != self.seq_no: print("unexpected seq_no") @@ -204,8 +220,8 @@ class MTProtoFrameStreamReader: self.seq_no += 1 - data = await self.stream.readexactly(msg_len - 4 - 4 - 4) - checksum_bytes = await self.stream.readexactly(4) + data = await self.upstream.readexactly(msg_len - 4 - 4 - 4) + checksum_bytes = await self.upstream.readexactly(4) checksum = int.from_bytes(checksum_bytes, "little") computed_checksum = binascii.crc32(msg_len_bytes + msg_seq_bytes + data) @@ -214,14 +230,11 @@ class MTProtoFrameStreamReader: return data -class MTProtoFrameStreamWriter: - def __init__(self, stream, seq_no=0): - self.stream = stream +class MTProtoFrameStreamWriter(LayeredStreamWriterBase): + def __init__(self, upstream, seq_no=0): + self.upstream = upstream self.seq_no = seq_no - def __getattr__(self, attr): - return getattr(self.stream, attr) - def write(self, msg): len_bytes = int.to_bytes(len(msg) + 4 + 4 + 4, 4, "little") seq_bytes = int.to_bytes(self.seq_no, 4, "little", signed=True) @@ -233,42 +246,33 @@ class MTProtoFrameStreamWriter: full_msg = msg_without_checksum + checksum padding = PADDING_FILLER * ((-len(full_msg) % CBC_PADDING) // len(PADDING_FILLER)) - return self.stream.write(full_msg + padding) + return self.upstream.write(full_msg + padding) -class MTProtoCompactFrameStreamReader: - def __init__(self, stream): - self.stream = stream - - def __getattr__(self, attr): - return getattr(self.stream, attr) - +class MTProtoCompactFrameStreamReader(LayeredStreamReaderBase): async def read(self, buf_size): - msg_len_bytes = await self.stream.readexactly(1) + msg_len_bytes = await self.upstream.readexactly(1) msg_len = int.from_bytes(msg_len_bytes, "little") if msg_len >= 0x80: msg_len -= 0x80 if msg_len == 0x7f: - msg_len_bytes = await self.stream.readexactly(3) + msg_len_bytes = await self.upstream.readexactly(3) msg_len = int.from_bytes(msg_len_bytes, "little") msg_len *= 4 - data = await self.stream.readexactly(msg_len) + data = await self.upstream.readexactly(msg_len) return data -class MTProtoCompactFrameStreamWriter: - def __init__(self, stream, seq_no=0): - self.stream = stream +class MTProtoCompactFrameStreamWriter(LayeredStreamWriterBase): + def __init__(self, upstream, seq_no=0): + self.upstream = upstream self.seq_no = seq_no - def __getattr__(self, attr): - return getattr(self.stream, attr) - def write(self, data): SMALL_PKT_BORDER = 0x7f LARGE_PKT_BORGER = 256 ** 3 @@ -280,27 +284,21 @@ class MTProtoCompactFrameStreamWriter: len_div_four = len(data) // 4 if len_div_four < SMALL_PKT_BORDER: - return self.stream.write(bytes([len_div_four]) + data) + return self.upstream.write(bytes([len_div_four]) + data) elif len_div_four < LARGE_PKT_BORGER: - return self.stream.write(b'\x7f' + bytes(int.to_bytes(len_div_four, 3, 'little')) + + return self.upstream.write(b'\x7f' + bytes(int.to_bytes(len_div_four, 3, 'little')) + data) else: print("Attempted to send too large pkt len =", len(data)) return 0 -class ProxyReqStreamReader: - def __init__(self, stream): - self.stream = stream - - def __getattr__(self, attr): - return getattr(self.stream, attr) - +class ProxyReqStreamReader(LayeredStreamReaderBase): async def read(self, msg): RPC_PROXY_ANS = b"\x0d\xda\x03\x44" RPC_CLOSE_EXT = b"\xa2\x34\xb6\x5e" - data = await self.stream.read(1) + data = await self.upstream.read(1) if len(data) < 4: return b"" @@ -316,12 +314,9 @@ class ProxyReqStreamReader: return conn_data -class ProxyReqStreamWriter: - def __init__(self, stream): - self.stream = stream - - def __getattr__(self, attr): - return getattr(self.stream, attr) +class ProxyReqStreamWriter(LayeredStreamWriterBase): + def __init__(self, upstream): + self.upstream = upstream def write(self, msg): RPC_PROXY_REQ = b"\xee\xf1\xce\x36" @@ -343,7 +338,7 @@ class ProxyReqStreamWriter: full_msg += bytes([len(AD_TAG)]) + AD_TAG + FOUR_BYTES_ALIGNER full_msg += msg - return self.stream.write(full_msg) + return self.upstream.write(full_msg) async def handle_handshake(reader, writer): @@ -511,8 +506,8 @@ async def do_middleproxy_handshake(dc_idx): return False # get keys - tg_ip, tg_port = writer_tgt.stream.get_extra_info('peername') - my_ip, my_port = writer_tgt.stream.get_extra_info('sockname') + tg_ip, tg_port = writer_tgt.upstream.get_extra_info('peername') + my_ip, my_port = writer_tgt.upstream.get_extra_info('sockname') global my_ip_info if my_ip_info["ipv4"]: @@ -542,11 +537,11 @@ async def do_middleproxy_handshake(dc_idx): # TODO: pass client ip and port here for statistics handshake = RPC_HANDSHAKE + RPC_FLAGS + SENDER_PID + PEER_PID - writer_tgt.stream = CryptoWrappedStreamWriter(writer_tgt.stream, encryptor, block_size=16) + writer_tgt.upstream = CryptoWrappedStreamWriter(writer_tgt.upstream, encryptor, block_size=16) writer_tgt.write(handshake) await writer_tgt.drain() - reader_tgt.stream = CryptoWrappedStreamReader(reader_tgt.stream, decryptor, block_size=16) + reader_tgt.upstream = CryptoWrappedStreamReader(reader_tgt.upstream, decryptor, block_size=16) handshake_ans = await reader_tgt.read(1) if len(handshake_ans) != RPC_HANDSHAKE_ANS_LEN: @@ -570,7 +565,7 @@ async def handle_client(reader_clt, writer_clt): return reader_clt, writer_clt, user, dc_idx, enc_key_and_iv = clt_data - + update_stats(user, connects=1) if not USE_MIDDLE_PROXY: @@ -669,7 +664,7 @@ def init_ip_info(): except Exception: pass - if USE_MIDDLE_PROXY and not my_ip_info["ipv4"]: #and not my_ip_info["ipv6"]: + if USE_MIDDLE_PROXY and not my_ip_info["ipv4"]: # and not my_ip_info["ipv6"]: print("Failed to determine your ip, advertising disabled", flush=True) USE_MIDDLE_PROXY = False