diff --git a/src/clients/networkmanager/mod.rs b/src/clients/networkmanager/mod.rs index 4825d0b..b085a12 100644 --- a/src/clients/networkmanager/mod.rs +++ b/src/clients/networkmanager/mod.rs @@ -1,9 +1,10 @@ use color_eyre::Result; use color_eyre::eyre::Ok; use futures_lite::StreamExt; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use tokio::sync::{RwLock, broadcast}; +use tokio::task::JoinHandle; use tracing::debug; use zbus::Connection; use zbus::zvariant::ObjectPath; @@ -45,17 +46,27 @@ struct ClientInner { controller_sender: broadcast::Sender, sender: broadcast::Sender, devices: RwLock>>, + watchers: RwLock, Device>>, + // TODO: Maybe find some way to late-init a dbus connection here + // so we can just clone it when we need it instead of awaiting it every time +} + +#[derive(Debug)] +struct Device { + state_watcher: JoinHandle>, } impl ClientInner { fn new() -> ClientInner { let (controller_sender, _) = broadcast::channel(64); let (sender, _) = broadcast::channel(8); - let devices = RwLock::new(HashSet::::new()); + let devices = RwLock::new(HashSet::new()); + let watchers = RwLock::new(HashMap::new()); ClientInner { controller_sender, sender, devices, + watchers, } } @@ -70,11 +81,11 @@ impl ClientInner { Ok(()) } - pub fn subscribe(&self) -> broadcast::Receiver { + fn subscribe(&self) -> broadcast::Receiver { self.controller_sender.subscribe() } - pub fn get_sender(&self) -> broadcast::Sender { + fn get_sender(&self) -> broadcast::Sender { self.sender.clone() } @@ -94,6 +105,8 @@ impl ClientInner { .map(ObjectPath::to_owned) .collect::>(); + // TODO: Use `self.watchers` instead of `self.devices`, which requires creating all property watchers straightaway + // Atomic read-then-write of `devices` let mut devices_locked = self.devices.write().await; let devices_snapshot = devices_locked.clone(); @@ -105,9 +118,16 @@ impl ClientInner { spawn(self.watch_device(added_device.to_owned())); } - let _removed_devices = devices_snapshot.difference(&new_devices); - // TODO: Store join handles for watchers and abort them when their device is removed // TODO: Inform module of removed devices + let removed_devices = devices_snapshot.difference(&new_devices); + for removed_device in removed_devices { + let mut watchers = self.watchers.write().await; + let device = watchers.get(removed_device).unwrap(); + device.state_watcher.abort(); + watchers.remove(removed_device); + + debug!("D-bus device state watcher for {} stopped", removed_device); + } } Ok(()) @@ -117,13 +137,13 @@ impl ClientInner { &'static self, mut receiver: broadcast::Receiver, ) -> Result<()> { + let dbus_connection = Connection::system().await?; + while let Result::Ok(event) = receiver.recv().await { match event { ModuleToClientEvent::NewController => { debug!("Client received NewController event"); - let dbus_connection = Connection::system().await?; - // We create a local clone here to avoid holding the lock for too long let devices_snapshot = self.devices.read().await.clone(); @@ -148,16 +168,20 @@ impl ClientInner { } async fn watch_device(&'static self, path: ObjectPath<'static>) -> Result<()> { - let dbus_connection = Connection::system().await?; - let device = DeviceDbusProxy::new(&dbus_connection, path).await?; + debug_assert!(!self.watchers.read().await.contains_key(&path)); - spawn(self.watch_device_state(device)); + let state_watcher = spawn(self.watch_device_state(path.clone())); + self.watchers + .write() + .await + .insert(path, Device { state_watcher }); Ok(()) } - async fn watch_device_state(&'static self, device: DeviceDbusProxy<'_>) -> Result<()> { - let path = device.inner().path(); + async fn watch_device_state(&'static self, path: ObjectPath<'_>) -> Result<()> { + let dbus_connection = Connection::system().await?; + let device = DeviceDbusProxy::new(&dbus_connection, path.clone()).await?; debug!("D-Bus device state watcher for {} starting", path);