diff --git a/src/mpd_now_playable/config/model.py b/src/mpd_now_playable/config/model.py index 0fde2d1..a112042 100644 --- a/src/mpd_now_playable/config/model.py +++ b/src/mpd_now_playable/config/model.py @@ -35,7 +35,7 @@ class WebsocketsReceiverConfig(BaseReceiverConfig): #: The hostname you'd like your WebSockets server to listen on. In most #: cases the default behaviour, which binds to all network interfaces, will #: be fine. - host: Optional[Host | tuple[Host, ...]] = None + host: Optional[Host] = None ReceiverConfig = Annotated[ diff --git a/src/mpd_now_playable/receivers/websockets/receiver.py b/src/mpd_now_playable/receivers/websockets/receiver.py index a125b6b..d218f35 100644 --- a/src/mpd_now_playable/receivers/websockets/receiver.py +++ b/src/mpd_now_playable/receivers/websockets/receiver.py @@ -2,7 +2,7 @@ from pathlib import Path import ormsgpack from websockets import broadcast -from websockets.server import WebSocketServerProtocol, serve +from websockets.asyncio.server import Server, ServerConnection, serve from yarl import URL from ...config.model import WebsocketsReceiverConfig @@ -24,12 +24,11 @@ def default(value: object) -> object: class WebsocketsReceiver(Receiver): config: WebsocketsReceiverConfig player: Player - connections: set[WebSocketServerProtocol] + server: Server last_status: bytes = MSGPACK_NULL def __init__(self, config: WebsocketsReceiverConfig): self.config = config - self.connections = set() @classmethod def loop_factory(cls) -> DefaultLoopFactory: @@ -37,18 +36,14 @@ class WebsocketsReceiver(Receiver): async def start(self, player: Player) -> None: self.player = player - await serve( + self.server = await serve( self.handle, host=self.config.host, port=self.config.port, reuse_port=True ) - async def handle(self, conn: WebSocketServerProtocol) -> None: - self.connections.add(conn) + async def handle(self, conn: ServerConnection) -> None: await conn.send(self.last_status) - try: - await conn.wait_closed() - finally: - self.connections.remove(conn) + await conn.wait_closed() async def update(self, playback: Playback) -> None: self.last_status = ormsgpack.packb(playback, default=default) - broadcast(self.connections, self.last_status) + broadcast(self.server.connections, self.last_status)