use explicit base classes for layered streams

This commit is contained in:
Alexander Bersenev
2018-06-04 18:03:42 +05:00
parent 44ab6fef13
commit 94be19087c

View File

@@ -118,29 +118,51 @@ def update_stats(user, connects=0, curr_connects_x2=0, octets=0):
octets=octets)
class CryptoWrappedStreamReader:
def __init__(self, stream, decryptor, block_size=1):
self.stream = stream
class LayeredStreamReaderBase:
def __init__(self, upstream):
self.upstream = upstream
async def read(self, n):
return await self.upstream.read(n)
async def readexactly(self, n):
return await self.upstream.readexactly(n)
class LayeredStreamWriterBase:
def __init__(self, upstream):
self.upstream = upstream
def write(self, data):
return self.upstream.write(data)
async def drain(self):
return await self.upstream.drain()
def close(self):
return self.upstream.close()
class CryptoWrappedStreamReader(LayeredStreamReaderBase):
def __init__(self, upstream, decryptor, block_size=1):
self.upstream = upstream
self.decryptor = decryptor
self.block_size = block_size
self.buf = bytearray()
def __getattr__(self, attr):
return getattr(self.stream, attr)
async def read(self, n):
if self.buf:
ret = bytes(self.buf)
self.buf.clear()
return ret
else:
data = await self.stream.read(n)
data = await self.upstream.read(n)
if not data:
return b""
needed_till_full_block = -len(data) % self.block_size
if needed_till_full_block > 0:
data += self.stream.readexactly(needed_till_full_block)
data += self.upstream.readexactly(needed_till_full_block)
return self.decryptor.decrypt(data)
async def readexactly(self, n):
@@ -149,7 +171,7 @@ class CryptoWrappedStreamReader:
needed_till_full_block = -to_read % self.block_size
to_read_block_aligned = to_read + needed_till_full_block
data = await self.stream.readexactly(to_read_block_aligned)
data = await self.upstream.readexactly(to_read_block_aligned)
self.buf += self.decryptor.decrypt(data)
ret = bytes(self.buf[:n])
@@ -157,38 +179,32 @@ class CryptoWrappedStreamReader:
return ret
class CryptoWrappedStreamWriter:
def __init__(self, stream, encryptor, block_size=1):
self.stream = stream
class CryptoWrappedStreamWriter(LayeredStreamWriterBase):
def __init__(self, upstream, encryptor, block_size=1):
self.upstream = upstream
self.encryptor = encryptor
self.block_size = block_size
def __getattr__(self, attr):
return getattr(self.stream, attr)
def write(self, data):
if len(data) % self.block_size != 0:
print("BUG: writing %d bytes not aligned to block size %d" % (
len(data), self.block_size))
return 0
q = self.encryptor.encrypt(data)
return self.stream.write(q)
return self.upstream.write(q)
class MTProtoFrameStreamReader:
def __init__(self, stream, seq_no=0):
self.stream = stream
class MTProtoFrameStreamReader(LayeredStreamReaderBase):
def __init__(self, upstream, seq_no=0):
self.upstream = upstream
self.seq_no = seq_no
def __getattr__(self, attr):
return getattr(self.stream, attr)
async def read(self, buf_size):
msg_len_bytes = await self.stream.readexactly(4)
msg_len_bytes = await self.upstream.readexactly(4)
msg_len = int.from_bytes(msg_len_bytes, "little")
# skip paddings
while msg_len == 4:
msg_len_bytes = await self.stream.readexactly(4)
msg_len_bytes = await self.upstream.readexactly(4)
msg_len = int.from_bytes(msg_len_bytes, "little")
len_is_bad = (msg_len % len(PADDING_FILLER) != 0)
@@ -196,7 +212,7 @@ class MTProtoFrameStreamReader:
print("msg_len is bad, closing connection", msg_len)
return b""
msg_seq_bytes = await self.stream.readexactly(4)
msg_seq_bytes = await self.upstream.readexactly(4)
msg_seq = int.from_bytes(msg_seq_bytes, "little", signed=True)
if msg_seq != self.seq_no:
print("unexpected seq_no")
@@ -204,8 +220,8 @@ class MTProtoFrameStreamReader:
self.seq_no += 1
data = await self.stream.readexactly(msg_len - 4 - 4 - 4)
checksum_bytes = await self.stream.readexactly(4)
data = await self.upstream.readexactly(msg_len - 4 - 4 - 4)
checksum_bytes = await self.upstream.readexactly(4)
checksum = int.from_bytes(checksum_bytes, "little")
computed_checksum = binascii.crc32(msg_len_bytes + msg_seq_bytes + data)
@@ -214,14 +230,11 @@ class MTProtoFrameStreamReader:
return data
class MTProtoFrameStreamWriter:
def __init__(self, stream, seq_no=0):
self.stream = stream
class MTProtoFrameStreamWriter(LayeredStreamWriterBase):
def __init__(self, upstream, seq_no=0):
self.upstream = upstream
self.seq_no = seq_no
def __getattr__(self, attr):
return getattr(self.stream, attr)
def write(self, msg):
len_bytes = int.to_bytes(len(msg) + 4 + 4 + 4, 4, "little")
seq_bytes = int.to_bytes(self.seq_no, 4, "little", signed=True)
@@ -233,42 +246,33 @@ class MTProtoFrameStreamWriter:
full_msg = msg_without_checksum + checksum
padding = PADDING_FILLER * ((-len(full_msg) % CBC_PADDING) // len(PADDING_FILLER))
return self.stream.write(full_msg + padding)
return self.upstream.write(full_msg + padding)
class MTProtoCompactFrameStreamReader:
def __init__(self, stream):
self.stream = stream
def __getattr__(self, attr):
return getattr(self.stream, attr)
class MTProtoCompactFrameStreamReader(LayeredStreamReaderBase):
async def read(self, buf_size):
msg_len_bytes = await self.stream.readexactly(1)
msg_len_bytes = await self.upstream.readexactly(1)
msg_len = int.from_bytes(msg_len_bytes, "little")
if msg_len >= 0x80:
msg_len -= 0x80
if msg_len == 0x7f:
msg_len_bytes = await self.stream.readexactly(3)
msg_len_bytes = await self.upstream.readexactly(3)
msg_len = int.from_bytes(msg_len_bytes, "little")
msg_len *= 4
data = await self.stream.readexactly(msg_len)
data = await self.upstream.readexactly(msg_len)
return data
class MTProtoCompactFrameStreamWriter:
def __init__(self, stream, seq_no=0):
self.stream = stream
class MTProtoCompactFrameStreamWriter(LayeredStreamWriterBase):
def __init__(self, upstream, seq_no=0):
self.upstream = upstream
self.seq_no = seq_no
def __getattr__(self, attr):
return getattr(self.stream, attr)
def write(self, data):
SMALL_PKT_BORDER = 0x7f
LARGE_PKT_BORGER = 256 ** 3
@@ -280,27 +284,21 @@ class MTProtoCompactFrameStreamWriter:
len_div_four = len(data) // 4
if len_div_four < SMALL_PKT_BORDER:
return self.stream.write(bytes([len_div_four]) + data)
return self.upstream.write(bytes([len_div_four]) + data)
elif len_div_four < LARGE_PKT_BORGER:
return self.stream.write(b'\x7f' + bytes(int.to_bytes(len_div_four, 3, 'little')) +
return self.upstream.write(b'\x7f' + bytes(int.to_bytes(len_div_four, 3, 'little')) +
data)
else:
print("Attempted to send too large pkt len =", len(data))
return 0
class ProxyReqStreamReader:
def __init__(self, stream):
self.stream = stream
def __getattr__(self, attr):
return getattr(self.stream, attr)
class ProxyReqStreamReader(LayeredStreamReaderBase):
async def read(self, msg):
RPC_PROXY_ANS = b"\x0d\xda\x03\x44"
RPC_CLOSE_EXT = b"\xa2\x34\xb6\x5e"
data = await self.stream.read(1)
data = await self.upstream.read(1)
if len(data) < 4:
return b""
@@ -316,12 +314,9 @@ class ProxyReqStreamReader:
return conn_data
class ProxyReqStreamWriter:
def __init__(self, stream):
self.stream = stream
def __getattr__(self, attr):
return getattr(self.stream, attr)
class ProxyReqStreamWriter(LayeredStreamWriterBase):
def __init__(self, upstream):
self.upstream = upstream
def write(self, msg):
RPC_PROXY_REQ = b"\xee\xf1\xce\x36"
@@ -343,7 +338,7 @@ class ProxyReqStreamWriter:
full_msg += bytes([len(AD_TAG)]) + AD_TAG + FOUR_BYTES_ALIGNER
full_msg += msg
return self.stream.write(full_msg)
return self.upstream.write(full_msg)
async def handle_handshake(reader, writer):
@@ -511,8 +506,8 @@ async def do_middleproxy_handshake(dc_idx):
return False
# get keys
tg_ip, tg_port = writer_tgt.stream.get_extra_info('peername')
my_ip, my_port = writer_tgt.stream.get_extra_info('sockname')
tg_ip, tg_port = writer_tgt.upstream.get_extra_info('peername')
my_ip, my_port = writer_tgt.upstream.get_extra_info('sockname')
global my_ip_info
if my_ip_info["ipv4"]:
@@ -542,11 +537,11 @@ async def do_middleproxy_handshake(dc_idx):
# TODO: pass client ip and port here for statistics
handshake = RPC_HANDSHAKE + RPC_FLAGS + SENDER_PID + PEER_PID
writer_tgt.stream = CryptoWrappedStreamWriter(writer_tgt.stream, encryptor, block_size=16)
writer_tgt.upstream = CryptoWrappedStreamWriter(writer_tgt.upstream, encryptor, block_size=16)
writer_tgt.write(handshake)
await writer_tgt.drain()
reader_tgt.stream = CryptoWrappedStreamReader(reader_tgt.stream, decryptor, block_size=16)
reader_tgt.upstream = CryptoWrappedStreamReader(reader_tgt.upstream, decryptor, block_size=16)
handshake_ans = await reader_tgt.read(1)
if len(handshake_ans) != RPC_HANDSHAKE_ANS_LEN:
@@ -570,7 +565,7 @@ async def handle_client(reader_clt, writer_clt):
return
reader_clt, writer_clt, user, dc_idx, enc_key_and_iv = clt_data
update_stats(user, connects=1)
if not USE_MIDDLE_PROXY:
@@ -669,7 +664,7 @@ def init_ip_info():
except Exception:
pass
if USE_MIDDLE_PROXY and not my_ip_info["ipv4"]: #and not my_ip_info["ipv6"]:
if USE_MIDDLE_PROXY and not my_ip_info["ipv4"]: # and not my_ip_info["ipv6"]:
print("Failed to determine your ip, advertising disabled", flush=True)
USE_MIDDLE_PROXY = False