diff --git a/mtprotoproxy.py b/mtprotoproxy.py index 9531f92..6d4ea38 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -81,7 +81,7 @@ async def handle_handshake(reader, writer): encryptor = create_aes(key=enc_key, iv=int.from_bytes(enc_iv, "big")) decrypted = decryptor.decrypt(handshake) - + check_val = decrypted[MAGIC_VAL_POS:MAGIC_VAL_POS+4] if check_val != MAGIC_VAL_TO_CHECK: continue @@ -134,7 +134,7 @@ async def do_handshake(dc, dec_key_and_iv=None): enc_key_and_iv = rnd[SKIP_LEN:SKIP_LEN+KEY_LEN+IV_LEN] enc_key, enc_iv = enc_key_and_iv[:KEY_LEN], enc_key_and_iv[KEY_LEN:] encryptor = create_aes(key=enc_key, iv=int.from_bytes(enc_iv, "big")) - + rnd_enc = rnd[:MAGIC_VAL_POS] + encryptor.encrypt(rnd)[MAGIC_VAL_POS:] writer_tgt.write(rnd_enc) @@ -240,9 +240,15 @@ def main(): loop = asyncio.get_event_loop() stats_printer_task = asyncio.Task(stats_printer()) asyncio.ensure_future(stats_printer_task) - task = asyncio.start_server(handle_client_wrapper, - "0.0.0.0", PORT, loop=loop) - server = loop.run_until_complete(task) + + task_v4 = asyncio.start_server(handle_client_wrapper, + '0.0.0.0', PORT, loop=loop) + server_v4 = loop.run_until_complete(task_v4) + + if socket.has_ipv6: + task_v6 = asyncio.start_server(handle_client_wrapper, + '::', PORT, loop=loop) + server_v6 = loop.run_until_complete(task_v6) try: loop.run_forever() @@ -251,8 +257,13 @@ def main(): stats_printer_task.cancel() - server.close() - loop.run_until_complete(server.wait_closed()) + server_v4.close() + loop.run_until_complete(server_v4.wait_closed()) + + if socket.has_ipv6: + server_v6.close() + loop.run_until_complete(server_v6.wait_closed()) + loop.close()