more reliable logic of connection closing on errors

This commit is contained in:
Alexander Bersenev
2018-06-07 18:38:56 +05:00
parent 07780602d1
commit 63b77ea637

View File

@@ -155,6 +155,13 @@ class LayeredStreamWriterBase:
def close(self):
return self.upstream.close()
def abort(self):
return self.upstream.transport.abort()
@property
def transport(self):
return self.upstream.transport
class CryptoWrappedStreamReader(LayeredStreamReaderBase):
def __init__(self, upstream, decryptor, block_size=1):
@@ -616,7 +623,7 @@ async def do_middleproxy_handshake(dc_idx, cl_ip, cl_port):
async def handle_client(reader_clt, writer_clt):
clt_data = await handle_handshake(reader_clt, writer_clt)
if not clt_data:
writer_clt.close()
writer_clt.transport.abort()
return
reader_clt, writer_clt, user, dc_idx, enc_key_and_iv = clt_data
@@ -633,7 +640,7 @@ async def handle_client(reader_clt, writer_clt):
tg_data = await do_middleproxy_handshake(dc_idx, cl_ip, cl_port)
if not tg_data:
writer_clt.close()
writer_clt.transport.abort()
return
reader_tg, writer_tg = tg_data
@@ -668,11 +675,11 @@ async def handle_client(reader_clt, writer_clt):
update_stats(user, octets=len(data))
wr.write(data)
await wr.drain()
except (ConnectionResetError, BrokenPipeError, OSError, AttributeError,
asyncio.streams.IncompleteReadError, TimeoutError) as e:
wr.close()
except (OSError, AttributeError, asyncio.streams.IncompleteReadError) as e:
# print_err(e)
pass
finally:
wr.transport.abort()
update_stats(user, curr_connects_x2=-1)
asyncio.ensure_future(connect_reader_to_writer(reader_tg, writer_clt, user))
@@ -683,7 +690,7 @@ async def handle_client_wrapper(reader, writer):
try:
await handle_client(reader, writer)
except (asyncio.IncompleteReadError, ConnectionResetError, TimeoutError):
writer.close()
writer.transport.abort()
async def stats_printer():
@@ -739,6 +746,25 @@ def print_tg_info():
print("{}: tg://proxy?{}".format(user, params_encodeded), flush=True)
def loop_exception_handler(loop, context):
exception = context.get("exception")
transport = context.get("transport")
if exception:
if isinstance(exception, TimeoutError):
if transport:
print_err("Timeout, killing transport")
transport.abort()
return
if isinstance(exception, OSError):
IGNORE_ERRNO = {
10038 # operation on non-socket on Windows, likely because fd == -1
}
if exception.errno in IGNORE_ERRNO:
return
loop.default_exception_handler(context)
def main():
init_stats()
@@ -747,6 +773,8 @@ def main():
asyncio.set_event_loop(loop)
loop = asyncio.get_event_loop()
loop.set_exception_handler(loop_exception_handler)
stats_printer_task = asyncio.Task(stats_printer())
asyncio.ensure_future(stats_printer_task)