1use std::{collections::HashMap, fmt::Debug, sync::Arc};
3
4use anyhow::{bail, Context as _};
5use futures::{future::AbortHandle, stream::Abortable};
6use tokio::sync::{
7 watch::{self, Receiver, Sender},
8 RwLock, RwLockReadGuard,
9};
10use tracing::{error, warn, Instrument};
11
12use crate::config::ConfigManager;
13
14type LockedConfig = Arc<RwLock<HashMap<String, String>>>;
15type WatchCache = Arc<RwLock<HashMap<String, Receiver<HashMap<String, String>>>>>;
17
18struct ConfigReceiver {
20 pub name: String,
21 pub receiver: Receiver<HashMap<String, String>>,
22}
23
24#[derive(Default)]
27struct AbortHandles {
28 handles: Vec<AbortHandle>,
29}
30
31impl Drop for AbortHandles {
32 fn drop(&mut self) {
33 for handle in &self.handles {
34 handle.abort();
35 }
36 }
37}
38
39pub struct ConfigBundle {
50 merged_config: LockedConfig,
52 config_names: Vec<String>,
54 changed_receiver: Receiver<()>,
56 _handles: Arc<AbortHandles>,
60 _changed_notifier: Arc<Sender<()>>,
63}
64
65impl Clone for ConfigBundle {
66 fn clone(&self) -> Self {
67 let mut changed_receiver = self.changed_receiver.clone();
71 changed_receiver.mark_changed();
72 ConfigBundle {
73 merged_config: self.merged_config.clone(),
74 config_names: self.config_names.clone(),
75 changed_receiver,
76 _changed_notifier: self._changed_notifier.clone(),
77 _handles: self._handles.clone(),
78 }
79 }
80}
81
82impl Debug for ConfigBundle {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 f.debug_struct("ConfigBundle")
85 .field("merged_config", &self.merged_config)
86 .finish()
87 }
88}
89
90impl ConfigBundle {
91 #[must_use]
98 async fn new(receivers: Vec<ConfigReceiver>) -> Self {
99 let (abort_handles, mut registrations): (Vec<_>, Vec<_>) =
101 std::iter::repeat_with(AbortHandle::new_pair)
102 .take(receivers.len())
103 .unzip();
104 let (changed_notifier, changed_receiver) = watch::channel(());
106 let changed_notifier = Arc::new(changed_notifier);
107 let mut bundle = ConfigBundle {
108 merged_config: Arc::default(),
109 config_names: receivers.iter().map(|r| r.name.clone()).collect(),
110 changed_receiver,
111 _changed_notifier: changed_notifier.clone(),
112 _handles: Arc::new(AbortHandles {
113 handles: abort_handles,
114 }),
115 };
116 let ordered_configs: Arc<Vec<Receiver<HashMap<String, String>>>> =
117 Arc::new(receivers.iter().map(|r| r.receiver.clone()).collect());
118 update_merge(&bundle.merged_config, &changed_notifier, &ordered_configs).await;
119 for ConfigReceiver { name, mut receiver } in receivers {
121 let reg = registrations
124 .pop()
125 .expect("missing registration, this is developer error");
126 let cloned_name = name.clone();
127 let ordered_receivers = ordered_configs.clone();
128 let merged_config = bundle.merged_config.clone();
129 let notifier = changed_notifier.clone();
130 tokio::spawn(
131 Abortable::new(
132 async move {
133 loop {
134 match receiver.changed().await {
135 Ok(()) => {
136 update_merge(&merged_config, ¬ifier, &ordered_receivers)
137 .await;
138 }
139 Err(e) => {
140 warn!(error = %e, %name, "config sender dropped, updates will not be delivered");
141 return;
142 }
143 }
144 }
145 },
146 reg,
147 )
148 .instrument(tracing::trace_span!("config_update", name = %cloned_name)),
149 );
150 }
151 bundle.changed_receiver.mark_changed();
156 bundle
157 }
158
159 pub async fn get_config(&self) -> RwLockReadGuard<'_, HashMap<String, String>> {
162 self.merged_config.read().await
163 }
164
165 pub async fn changed(
171 &mut self,
172 ) -> anyhow::Result<RwLockReadGuard<'_, HashMap<String, String>>> {
173 if let Err(e) = self.changed_receiver.changed().await {
178 error!(error = %e, "Config changed receiver errored, this means that the config sender has dropped and the whole bundle has failed");
181 bail!("failed to read receiver: {e}");
182 }
183 Ok(self.merged_config.read().await)
184 }
185
186 #[must_use]
188 pub fn config_names(&self) -> &Vec<String> {
189 &self.config_names
190 }
191}
192
193#[derive(Clone)]
195pub struct BundleGenerator {
196 store: Arc<dyn ConfigManager>,
197 watch_cache: WatchCache,
198}
199
200impl BundleGenerator {
201 #[must_use]
203 pub fn new(store: Arc<dyn ConfigManager>) -> Self {
204 Self {
205 store,
206 watch_cache: Arc::default(),
207 }
208 }
209
210 pub async fn generate(&self, config_names: Vec<String>) -> anyhow::Result<ConfigBundle> {
213 let receivers: Vec<ConfigReceiver> =
214 futures::future::join_all(config_names.into_iter().map(|name| self.get_receiver(name)))
215 .await
216 .into_iter()
217 .collect::<anyhow::Result<_>>()?;
218 Ok(ConfigBundle::new(receivers).await)
219 }
220
221 async fn get_receiver(&self, name: String) -> anyhow::Result<ConfigReceiver> {
222 if let Some(receiver) = self.watch_cache.read().await.get(&name) {
224 return Ok(ConfigReceiver {
225 name,
226 receiver: receiver.clone(),
227 });
228 }
229
230 let receiver = self
231 .store
232 .watch(&name)
233 .await
234 .context(format!("error setting up watcher for {name}"))?;
235 self.watch_cache
236 .write()
237 .await
238 .insert(name.clone(), receiver.clone());
239 Ok(ConfigReceiver { name, receiver })
240 }
241}
242
243async fn update_merge(
244 merged_config: &RwLock<HashMap<String, String>>,
245 changed_notifier: &Sender<()>,
246 ordered_receivers: &[Receiver<HashMap<String, String>>],
247) {
248 let mut hashmap = merged_config.write().await;
251 hashmap.clear();
252
253 for recv in ordered_receivers {
258 hashmap.extend(recv.borrow().clone());
259 }
260 changed_notifier.send_replace(());
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    use tokio::sync::watch;

    /// Builds an owned config map from string pairs.
    fn config(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    /// Waits up to 50ms for the bundle's next change notification and returns an owned
    /// snapshot of the merged config, so no read guard outlives the call.
    async fn next_config(bundle: &mut ConfigBundle) -> HashMap<String, String> {
        let guard = tokio::time::timeout(Duration::from_millis(50), bundle.changed())
            .await
            .expect("conf should have been present")
            .expect("Should have received a config");
        (*guard).clone()
    }

    #[tokio::test]
    async fn test_config_bundle() {
        let (foo_tx, foo_rx) = watch::channel(config(&[("foo", "bar")]));
        let (bar_tx, bar_rx) = watch::channel(HashMap::new());
        let (baz_tx, baz_rx) = watch::channel(HashMap::new());

        let mut bundle = ConfigBundle::new(vec![
            ConfigReceiver {
                name: "foo".to_string(),
                receiver: foo_rx,
            },
            ConfigReceiver {
                name: "bar".to_string(),
                receiver: bar_rx,
            },
            ConfigReceiver {
                name: "baz".to_string(),
                receiver: baz_rx,
            },
        ])
        .await;

        assert_eq!(*bundle.get_config().await, config(&[("foo", "bar")]));
        assert_eq!(
            bundle.config_names(),
            &vec!["foo".to_string(), "bar".to_string(), "baz".to_string()]
        );

        // A fresh bundle is marked changed, so the first wait yields the initial merge.
        assert_eq!(next_config(&mut bundle).await, config(&[("foo", "bar")]));

        // `bar` is merged after `foo`, so its value for the shared key wins.
        bar_tx.send_replace(config(&[("foo", "baz")]));
        assert_eq!(next_config(&mut bundle).await, config(&[("foo", "baz")]));

        baz_tx.send_replace(config(&[("star", "wars")]));
        assert_eq!(
            next_config(&mut bundle).await,
            config(&[("foo", "baz"), ("star", "wars")])
        );

        // Updating `foo` adds a new key, but `bar` still overrides `foo` on the shared key.
        foo_tx.send_replace(config(&[("starship", "troopers"), ("foo", "bar")]));
        assert_eq!(
            next_config(&mut bundle).await,
            config(&[("foo", "baz"), ("star", "wars"), ("starship", "troopers")])
        );
    }

    #[tokio::test]
    async fn test_clone_sees_current_config() {
        // Keep the sender alive for the whole test so the watcher task keeps running.
        let (_tx, rx) = watch::channel(config(&[("k", "v")]));
        let mut bundle = ConfigBundle::new(vec![ConfigReceiver {
            name: "only".to_string(),
            receiver: rx,
        }])
        .await;

        // Consume the initial notification on the original handle.
        assert_eq!(next_config(&mut bundle).await, config(&[("k", "v")]));

        // A clone is marked changed, so it also resolves immediately with the current merge.
        let mut cloned = bundle.clone();
        assert_eq!(next_config(&mut cloned).await, config(&[("k", "v")]));
        assert_eq!(cloned.config_names(), bundle.config_names());
    }
}