diff --git a/mtprotoproxy.py b/mtprotoproxy.py index 402bb8b..2d177a1 100755 --- a/mtprotoproxy.py +++ b/mtprotoproxy.py @@ -449,29 +449,34 @@ class TgConnectionPool: def register_host_port(self, host, port): if (host, port) not in self.pools: - self.pools[(host, port)] = set() + self.pools[(host, port)] = [] while len(self.pools[(host, port)]) < TgConnectionPool.MAX_CONNS_IN_POOL: connect_task = asyncio.ensure_future(self.open_tg_connection(host, port)) - self.pools[(host, port)].add(connect_task) + self.pools[(host, port)].append(connect_task) async def get_connection(self, host, port): self.register_host_port(host, port) - for task in self.pools[(host, port)].copy(): + ret = None + for task in self.pools[(host, port)][::]: if task.done(): - self.pools[(host, port)].remove(task) - self.register_host_port(host, port) - if task.exception(): + self.pools[(host, port)].remove(task) continue reader, writer = task.result() if writer.transport.is_closing(): + self.pools[(host, port)].remove(task) continue - return reader, writer + if not ret: + self.pools[(host, port)].remove(task) + ret = (reader, writer) + self.register_host_port(host, port) + if ret: + return ret return await self.open_tg_connection(host, port)