diff --git a/mtprotoproxy.py b/mtprotoproxy.py index f543b26..2fc0acc 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -155,6 +155,13 @@ class LayeredStreamWriterBase: def close(self): return self.upstream.close() + def abort(self): + return self.upstream.transport.abort() + + @property + def transport(self): + return self.upstream.transport + class CryptoWrappedStreamReader(LayeredStreamReaderBase): def __init__(self, upstream, decryptor, block_size=1): @@ -616,7 +623,7 @@ async def do_middleproxy_handshake(dc_idx, cl_ip, cl_port): async def handle_client(reader_clt, writer_clt): clt_data = await handle_handshake(reader_clt, writer_clt) if not clt_data: - writer_clt.close() + writer_clt.transport.abort() return reader_clt, writer_clt, user, dc_idx, enc_key_and_iv = clt_data @@ -633,7 +640,7 @@ async def handle_client(reader_clt, writer_clt): tg_data = await do_middleproxy_handshake(dc_idx, cl_ip, cl_port) if not tg_data: - writer_clt.close() + writer_clt.transport.abort() return reader_tg, writer_tg = tg_data @@ -668,11 +675,11 @@ async def handle_client(reader_clt, writer_clt): update_stats(user, octets=len(data)) wr.write(data) await wr.drain() - except (ConnectionResetError, BrokenPipeError, OSError, AttributeError, - asyncio.streams.IncompleteReadError, TimeoutError) as e: - wr.close() + except (OSError, AttributeError, asyncio.streams.IncompleteReadError) as e: # print_err(e) + pass finally: + wr.transport.abort() update_stats(user, curr_connects_x2=-1) asyncio.ensure_future(connect_reader_to_writer(reader_tg, writer_clt, user)) @@ -683,7 +690,7 @@ async def handle_client_wrapper(reader, writer): try: await handle_client(reader, writer) except (asyncio.IncompleteReadError, ConnectionResetError, TimeoutError): - writer.close() + writer.transport.abort() async def stats_printer(): @@ -739,6 +746,25 @@ def print_tg_info(): print("{}: tg://proxy?{}".format(user, params_encodeded), flush=True) +def loop_exception_handler(loop, context): + exception = context.get("exception") + transport = context.get("transport") + if exception: + if isinstance(exception, TimeoutError): + if transport: + print_err("Timeout, killing transport") + transport.abort() + return + if isinstance(exception, OSError): + IGNORE_ERRNO = { + 10038 # operation on non-socket on Windows, likely because fd == -1 + } + if exception.errno in IGNORE_ERRNO: + return + + loop.default_exception_handler(context) + + def main(): init_stats() @@ -747,6 +773,8 @@ def main(): asyncio.set_event_loop(loop) loop = asyncio.get_event_loop() + loop.set_exception_handler(loop_exception_handler) + stats_printer_task = asyncio.Task(stats_printer()) asyncio.ensure_future(stats_printer_task)