mirror of
https://github.com/alexbers/mtprotoproxy.git
synced 2026-03-13 23:03:09 +00:00
more reliable logic of connection closing on errors
This commit is contained in:
@@ -155,6 +155,13 @@ class LayeredStreamWriterBase:
|
||||
def close(self):
|
||||
return self.upstream.close()
|
||||
|
||||
def abort(self):
|
||||
return self.upstream.transport.abort()
|
||||
|
||||
@property
|
||||
def transport(self):
|
||||
return self.upstream.transport
|
||||
|
||||
|
||||
class CryptoWrappedStreamReader(LayeredStreamReaderBase):
|
||||
def __init__(self, upstream, decryptor, block_size=1):
|
||||
@@ -616,7 +623,7 @@ async def do_middleproxy_handshake(dc_idx, cl_ip, cl_port):
|
||||
async def handle_client(reader_clt, writer_clt):
|
||||
clt_data = await handle_handshake(reader_clt, writer_clt)
|
||||
if not clt_data:
|
||||
writer_clt.close()
|
||||
writer_clt.transport.abort()
|
||||
return
|
||||
|
||||
reader_clt, writer_clt, user, dc_idx, enc_key_and_iv = clt_data
|
||||
@@ -633,7 +640,7 @@ async def handle_client(reader_clt, writer_clt):
|
||||
tg_data = await do_middleproxy_handshake(dc_idx, cl_ip, cl_port)
|
||||
|
||||
if not tg_data:
|
||||
writer_clt.close()
|
||||
writer_clt.transport.abort()
|
||||
return
|
||||
|
||||
reader_tg, writer_tg = tg_data
|
||||
@@ -668,11 +675,11 @@ async def handle_client(reader_clt, writer_clt):
|
||||
update_stats(user, octets=len(data))
|
||||
wr.write(data)
|
||||
await wr.drain()
|
||||
except (ConnectionResetError, BrokenPipeError, OSError, AttributeError,
|
||||
asyncio.streams.IncompleteReadError, TimeoutError) as e:
|
||||
wr.close()
|
||||
except (OSError, AttributeError, asyncio.streams.IncompleteReadError) as e:
|
||||
# print_err(e)
|
||||
pass
|
||||
finally:
|
||||
wr.transport.abort()
|
||||
update_stats(user, curr_connects_x2=-1)
|
||||
|
||||
asyncio.ensure_future(connect_reader_to_writer(reader_tg, writer_clt, user))
|
||||
@@ -683,7 +690,7 @@ async def handle_client_wrapper(reader, writer):
|
||||
try:
|
||||
await handle_client(reader, writer)
|
||||
except (asyncio.IncompleteReadError, ConnectionResetError, TimeoutError):
|
||||
writer.close()
|
||||
writer.transport.abort()
|
||||
|
||||
|
||||
async def stats_printer():
|
||||
@@ -739,6 +746,25 @@ def print_tg_info():
|
||||
print("{}: tg://proxy?{}".format(user, params_encodeded), flush=True)
|
||||
|
||||
|
||||
def loop_exception_handler(loop, context):
|
||||
exception = context.get("exception")
|
||||
transport = context.get("transport")
|
||||
if exception:
|
||||
if isinstance(exception, TimeoutError):
|
||||
if transport:
|
||||
print_err("Timeout, killing transport")
|
||||
transport.abort()
|
||||
return
|
||||
if isinstance(exception, OSError):
|
||||
IGNORE_ERRNO = {
|
||||
10038 # operation on non-socket on Windows, likely because fd == -1
|
||||
}
|
||||
if exception.errno in IGNORE_ERRNO:
|
||||
return
|
||||
|
||||
loop.default_exception_handler(context)
|
||||
|
||||
|
||||
def main():
|
||||
init_stats()
|
||||
|
||||
@@ -747,6 +773,8 @@ def main():
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.set_exception_handler(loop_exception_handler)
|
||||
|
||||
stats_printer_task = asyncio.Task(stats_printer())
|
||||
asyncio.ensure_future(stats_printer_task)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user