1use std::{collections::HashMap, fmt::Debug, sync::Arc};
3
4use anyhow::{bail, Context};
5use async_nats::jetstream::kv::{Operation, Store};
6use futures::{future::AbortHandle, stream::Abortable, TryStreamExt};
7use tokio::sync::{
8 watch::{self, Receiver, Sender},
9 RwLock, RwLockReadGuard,
10};
11use tracing::{error, warn, Instrument};
12
13type LockedConfig = Arc<RwLock<HashMap<String, String>>>;
14type WatchCache = Arc<RwLock<HashMap<String, Receiver<HashMap<String, String>>>>>;
16
17struct ConfigReceiver {
19 pub name: String,
20 pub receiver: Receiver<HashMap<String, String>>,
21}
22
23#[derive(Default)]
26struct AbortHandles {
27 handles: Vec<AbortHandle>,
28}
29
30impl Drop for AbortHandles {
31 fn drop(&mut self) {
32 for handle in &self.handles {
33 handle.abort();
34 }
35 }
36}
37
38pub struct ConfigBundle {
49 merged_config: LockedConfig,
51 config_names: Vec<String>,
53 changed_receiver: Receiver<()>,
55 _handles: Arc<AbortHandles>,
59 _changed_notifier: Arc<Sender<()>>,
62}
63
64impl Clone for ConfigBundle {
65 fn clone(&self) -> Self {
66 let mut changed_receiver = self.changed_receiver.clone();
70 changed_receiver.mark_changed();
71 ConfigBundle {
72 merged_config: self.merged_config.clone(),
73 config_names: self.config_names.clone(),
74 changed_receiver,
75 _changed_notifier: self._changed_notifier.clone(),
76 _handles: self._handles.clone(),
77 }
78 }
79}
80
81impl Debug for ConfigBundle {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("ConfigBundle")
84 .field("merged_config", &self.merged_config)
85 .finish()
86 }
87}
88
89impl ConfigBundle {
90 #[must_use]
97 async fn new(receivers: Vec<ConfigReceiver>) -> Self {
98 let (abort_handles, mut registrations): (Vec<_>, Vec<_>) =
100 std::iter::repeat_with(AbortHandle::new_pair)
101 .take(receivers.len())
102 .unzip();
103 let (changed_notifier, changed_receiver) = watch::channel(());
105 let changed_notifier = Arc::new(changed_notifier);
106 let mut bundle = ConfigBundle {
107 merged_config: Arc::default(),
108 config_names: receivers.iter().map(|r| r.name.clone()).collect(),
109 changed_receiver,
110 _changed_notifier: changed_notifier.clone(),
111 _handles: Arc::new(AbortHandles {
112 handles: abort_handles,
113 }),
114 };
115 let ordered_configs: Arc<Vec<Receiver<HashMap<String, String>>>> =
116 Arc::new(receivers.iter().map(|r| r.receiver.clone()).collect());
117 update_merge(&bundle.merged_config, &changed_notifier, &ordered_configs).await;
118 for ConfigReceiver { name, mut receiver } in receivers {
120 let reg = registrations
123 .pop()
124 .expect("missing registration, this is developer error");
125 let cloned_name = name.clone();
126 let ordered_receivers = ordered_configs.clone();
127 let merged_config = bundle.merged_config.clone();
128 let notifier = changed_notifier.clone();
129 tokio::spawn(
130 Abortable::new(
131 async move {
132 loop {
133 match receiver.changed().await {
134 Ok(()) => {
135 update_merge(&merged_config, ¬ifier, &ordered_receivers)
136 .await;
137 }
138 Err(e) => {
139 warn!(error = %e, %name, "config sender dropped, updates will not be delivered");
140 return;
141 }
142 }
143 }
144 },
145 reg,
146 )
147 .instrument(tracing::trace_span!("config_update", name = %cloned_name)),
148 );
149 }
150 bundle.changed_receiver.mark_changed();
155 bundle
156 }
157
158 pub async fn get_config(&self) -> RwLockReadGuard<'_, HashMap<String, String>> {
161 self.merged_config.read().await
162 }
163
164 pub async fn changed(
170 &mut self,
171 ) -> anyhow::Result<RwLockReadGuard<'_, HashMap<String, String>>> {
172 if let Err(e) = self.changed_receiver.changed().await {
177 error!(error = %e, "Config changed receiver errored, this means that the config sender has dropped and the whole bundle has failed");
180 bail!("failed to read receiver: {e}");
181 }
182 Ok(self.merged_config.read().await)
183 }
184
185 #[must_use]
187 pub fn config_names(&self) -> &Vec<String> {
188 &self.config_names
189 }
190}
191
192#[derive(Clone)]
194pub struct BundleGenerator {
195 store: Store,
196 watch_cache: WatchCache,
197 watch_handles: Arc<RwLock<AbortHandles>>,
198}
199
200impl BundleGenerator {
201 #[must_use]
203 pub fn new(store: Store) -> Self {
204 Self {
205 store,
206 watch_cache: Arc::default(),
207 watch_handles: Arc::default(),
208 }
209 }
210
211 pub async fn generate(&self, config_names: Vec<String>) -> anyhow::Result<ConfigBundle> {
214 let receivers: Vec<ConfigReceiver> =
215 futures::future::join_all(config_names.into_iter().map(|name| self.get_receiver(name)))
216 .await
217 .into_iter()
218 .collect::<anyhow::Result<_>>()?;
219 Ok(ConfigBundle::new(receivers).await)
220 }
221
222 async fn get_receiver(&self, name: String) -> anyhow::Result<ConfigReceiver> {
223 if let Some(receiver) = self.watch_cache.read().await.get(&name) {
225 return Ok(ConfigReceiver {
226 name,
227 receiver: receiver.clone(),
228 });
229 }
230
231 let config: HashMap<String, String> = match self.store.get(&name).await {
235 Ok(Some(data)) => serde_json::from_slice(&data)
236 .context("Data corruption error, unable to decode data from store")?,
237 Ok(None) => return Err(anyhow::anyhow!("Config {} does not exist", name)),
238 Err(e) => return Err(anyhow::anyhow!("Error fetching config {}: {}", name, e)),
239 };
240
241 let (tx, rx) = watch::channel(config);
244 let (done, wait) = tokio::sync::oneshot::channel();
245 let (handle, reg) = AbortHandle::new_pair();
246 tokio::task::spawn(Abortable::new(
247 watcher_loop(self.store.clone(), name.clone(), tx, done),
248 reg,
249 ));
250
251 wait.await
252 .context("Error waiting for watcher to start")?
253 .context("Error waiting for watcher to start")?;
254
255 self.watch_handles.write().await.handles.push(handle);
261 self.watch_cache
262 .write()
263 .await
264 .insert(name.clone(), rx.clone());
265
266 Ok(ConfigReceiver { name, receiver: rx })
267 }
268}
269
270async fn watcher_loop(
271 store: Store,
272 name: String,
273 tx: watch::Sender<HashMap<String, String>>,
274 done: tokio::sync::oneshot::Sender<anyhow::Result<()>>,
275) {
276 let mut watcher = match store.watch(&name).await {
278 Ok(watcher) => {
279 done.send(Ok(())).expect(
280 "Receiver for watcher setup should not have been dropped. This is programmer error",
281 );
282 watcher
283 }
284 Err(e) => {
285 done.send(Err(anyhow::anyhow!(
286 "Error setting up watcher for {}: {}",
287 name,
288 e
289 )))
290 .expect(
291 "Receiver for watcher setup should not have been dropped. This is programmer error",
292 );
293 return;
294 }
295 };
296 loop {
297 match watcher.try_next().await {
298 Ok(Some(entry)) if matches!(entry.operation, Operation::Delete | Operation::Purge) => {
299 tx.send_replace(HashMap::new());
303 }
304 Ok(Some(entry)) => {
305 let config: HashMap<String, String> = match serde_json::from_slice(&entry.value) {
306 Ok(config) => config,
307 Err(e) => {
308 error!(%name, error = %e, "Error decoding config from store during watch");
309 continue;
310 }
311 };
312 tx.send_if_modified(|current| {
313 if current == &config {
314 false
315 } else {
316 *current = config;
317 true
318 }
319 });
320 }
321 Ok(None) => {
322 error!(%name, "Watcher for config has closed");
323 return;
324 }
325 Err(e) => {
326 error!(%name, error = %e, "Error reading from watcher for config. Will wait for next entry");
327 continue;
328 }
329 }
330 }
331}
332
333async fn update_merge(
334 merged_config: &RwLock<HashMap<String, String>>,
335 changed_notifier: &Sender<()>,
336 ordered_receivers: &[Receiver<HashMap<String, String>>],
337) {
338 let mut hashmap = merged_config.write().await;
341 hashmap.clear();
342
343 for recv in ordered_receivers {
348 hashmap.extend(recv.borrow().clone());
349 }
350 changed_notifier.send_replace(());
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    use tokio::sync::watch;

    /// Upper bound on how long each step waits for the bundle to report a change.
    const CHANGE_TIMEOUT: Duration = Duration::from_millis(50);

    /// Builds an owned config map from string-slice pairs.
    fn config(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    /// Waits for the next change notification and returns a snapshot of the merged config.
    /// The read guard is released before returning.
    async fn next_config(bundle: &mut ConfigBundle) -> HashMap<String, String> {
        let guard = tokio::time::timeout(CHANGE_TIMEOUT, bundle.changed())
            .await
            .expect("conf should have been present")
            .expect("Should have received a config");
        (*guard).clone()
    }

    #[tokio::test]
    async fn test_config_bundle() {
        let (foo_tx, foo_rx) = watch::channel(config(&[("foo", "bar")]));
        let (bar_tx, bar_rx) = watch::channel(HashMap::new());
        let (baz_tx, baz_rx) = watch::channel(HashMap::new());

        let receivers: Vec<ConfigReceiver> = vec![("foo", foo_rx), ("bar", bar_rx), ("baz", baz_rx)]
            .into_iter()
            .map(|(name, receiver)| ConfigReceiver {
                name: name.to_string(),
                receiver,
            })
            .collect();
        let mut bundle = ConfigBundle::new(receivers).await;

        // Right after construction only "foo" has contributed anything.
        assert_eq!(*bundle.get_config().await, config(&[("foo", "bar")]));

        // A freshly built bundle is already flagged as changed.
        let _ = tokio::time::timeout(CHANGE_TIMEOUT, bundle.changed())
            .await
            .expect("Should have received a config");

        // "bar" is merged after "foo", so its value for the shared key wins.
        bar_tx.send_replace(config(&[("foo", "baz")]));
        assert_eq!(next_config(&mut bundle).await, config(&[("foo", "baz")]));

        baz_tx.send_replace(config(&[("star", "wars")]));
        assert_eq!(
            next_config(&mut bundle).await,
            config(&[("foo", "baz"), ("star", "wars")])
        );

        // Updating the lowest-priority config adds new keys but cannot override "bar".
        foo_tx.send_replace(config(&[("starship", "troopers"), ("foo", "bar")]));
        assert_eq!(
            next_config(&mut bundle).await,
            config(&[("foo", "baz"), ("star", "wars"), ("starship", "troopers")]),
        );
    }
}