diff --git a/src/clients/networkmanager/mod.rs b/src/clients/networkmanager/mod.rs index b085a12..da7e5bf 100644 --- a/src/clients/networkmanager/mod.rs +++ b/src/clients/networkmanager/mod.rs @@ -45,28 +45,25 @@ impl Client { struct ClientInner { controller_sender: broadcast::Sender, sender: broadcast::Sender, - devices: RwLock>>, - watchers: RwLock, Device>>, + device_watchers: RwLock, DeviceWatcher>>, // TODO: Maybe find some way to late-init a dbus connection here // so we can just clone it when we need it instead of awaiting it every time } -#[derive(Debug)] -struct Device { - state_watcher: JoinHandle>, +#[derive(Clone, Debug)] +struct DeviceWatcher { + state_watcher: Arc>>, } impl ClientInner { fn new() -> ClientInner { let (controller_sender, _) = broadcast::channel(64); let (sender, _) = broadcast::channel(8); - let devices = RwLock::new(HashSet::new()); - let watchers = RwLock::new(HashMap::new()); + let device_watchers = RwLock::new(HashMap::new()); ClientInner { controller_sender, sender, - devices, - watchers, + device_watchers, } } @@ -98,35 +95,34 @@ impl ClientInner { let mut devices_changes = root.receive_all_devices_changed().await; while let Some(devices_change) = devices_changes.next().await { // The new list of devices from dbus, not to be confused with the added devices below - let new_devices = devices_change + let new_device_paths = devices_change .get() .await? .iter() .map(ObjectPath::to_owned) .collect::>(); - // TODO: Use `self.watchers` instead of `self.devices`, which requires creating all property watchers straightaway + let mut watchers = self.device_watchers.write().await; + let device_paths = watchers.keys().cloned().collect::>(); - // Atomic read-then-write of `devices` - let mut devices_locked = self.devices.write().await; - let devices_snapshot = devices_locked.clone(); - (*devices_locked).clone_from(&new_devices); - drop(devices_locked); + let added_device_paths = new_device_paths.difference(&device_paths); + for added_device_path in added_device_paths { + debug_assert!(!watchers.contains_key(added_device_path)); - let added_devices = new_devices.difference(&devices_snapshot); - for added_device in added_devices { - spawn(self.watch_device(added_device.to_owned())); + let watcher = self.watch_device(added_device_path.clone()); + watchers.insert(added_device_path.clone(), watcher); } // TODO: Inform module of removed devices - let removed_devices = devices_snapshot.difference(&new_devices); - for removed_device in removed_devices { - let mut watchers = self.watchers.write().await; - let device = watchers.get(removed_device).unwrap(); - device.state_watcher.abort(); - watchers.remove(removed_device); + let removed_device_paths = device_paths.difference(&new_device_paths); + for removed_device_path in removed_device_paths { + let watcher = watchers + .get(removed_device_path) + .expect("Device to be removed should be present in watchers"); + watcher.state_watcher.abort(); + watchers.remove(removed_device_path); - debug!("D-bus device state watcher for {} stopped", removed_device); + debug!("D-bus device watchers for {} stopped", removed_device_path); } } @@ -144,10 +140,7 @@ impl ClientInner { ModuleToClientEvent::NewController => { debug!("Client received NewController event"); - // We create a local clone here to avoid holding the lock for too long - let devices_snapshot = self.devices.read().await.clone(); - - for device_path in devices_snapshot { + for device_path in self.device_watchers.read().await.keys() { let device = DeviceDbusProxy::new(&dbus_connection, device_path).await?; let interface = device.interface().await?.to_string(); @@ -167,24 +160,18 @@ impl ClientInner { Ok(()) } - async fn watch_device(&'static self, path: ObjectPath<'static>) -> Result<()> { - debug_assert!(!self.watchers.read().await.contains_key(&path)); + fn watch_device(&'static self, path: ObjectPath<'static>) -> DeviceWatcher { + let state_watcher = Arc::new(spawn(self.watch_device_state(path))); - let state_watcher = spawn(self.watch_device_state(path.clone())); - self.watchers - .write() - .await - .insert(path, Device { state_watcher }); - - Ok(()) + DeviceWatcher { state_watcher } } async fn watch_device_state(&'static self, path: ObjectPath<'_>) -> Result<()> { + debug!("D-Bus device state watcher for {} starting", path); + let dbus_connection = Connection::system().await?; let device = DeviceDbusProxy::new(&dbus_connection, path.clone()).await?; - debug!("D-Bus device state watcher for {} starting", path); - let interface = device.interface().await?; let r#type = device.device_type().await?; @@ -208,8 +195,6 @@ impl ClientInner { })?; } - debug!("D-Bus device state watcher for {} ended", path); - Ok(()) } }