diff --git a/libp2p/transports/wstransport.nim b/libp2p/transports/wstransport.nim index 584c673094..b75a009251 100644 --- a/libp2p/transports/wstransport.nim +++ b/libp2p/transports/wstransport.nim @@ -135,6 +135,7 @@ type WsTransport* = ref object of Transport httpservers: seq[HttpServer] wsserver: WSServer connections: array[Direction, seq[WsStream]] + connectionCleanupFuts: seq[Future[void]] acceptLoop: Future[void] handshakeFuts: seq[Future[void]] acceptResults: AsyncQueue[RawConn] @@ -408,12 +409,11 @@ method stop*(self: WsTransport) {.async: (raises: []).} = await procCall Transport(self).stop() # call base var toWait: seq[Future[void]] - if not isNil(self.acceptLoop) and not self.acceptLoop.finished: + if not self.acceptLoop.isNil: toWait.add(self.acceptLoop.cancelAndWait()) for fut in self.handshakeFuts: - if not fut.finished: - toWait.add(fut.cancelAndWait()) + toWait.add(fut.cancelAndWait()) for server in self.httpservers: server.stop() @@ -421,13 +421,17 @@ method stop*(self: WsTransport) {.async: (raises: []).} = await allFutures(toWait) + # stop connections and wait for them to be closed discard await allFinished( self.connections[Direction.In].mapIt(it.close()) & self.connections[Direction.Out].mapIt(it.close()) ) + self.connectionCleanupFuts.keepItIf(not it.finished) + discard await allFinished(self.connectionCleanupFuts) self.httpservers = @[] self.handshakeFuts = @[] + self.connectionCleanupFuts = @[] self.acceptLoop = nil trace "Transport stopped" except CatchableError as e: @@ -469,7 +473,9 @@ proc connHandler( self.connections[dir].keepItIf(it != conn) trace "Cleaned up client" - asyncSpawn onClose() + self.connectionCleanupFuts.keepItIf(not it.finished) + self.connectionCleanupFuts.add(onClose()) + return conn method accept*(