wasmcloud_provider_messaging_kafka/lib.rs

//! Implementation for wasmcloud:messaging

use std::collections::{HashMap, HashSet};
use std::sync::Arc;

use anyhow::{bail, Context as _, Result};
use bytes::Bytes;
use kafka::producer::{Producer, Record};
use tokio::spawn;
use tokio::sync::oneshot::Sender;
use tokio::sync::RwLock;
use tokio::task::JoinHandle;
use tokio::time::Duration;
use tokio_stream::StreamExt;
use tracing::{debug, error, instrument, warn};
use wasmcloud_provider_sdk::{
    get_connection, run_provider, Context, LinkConfig, LinkDeleteInfo, Provider,
};
use wasmcloud_provider_sdk::{initialize_observability, serve_provider_exports};
use wasmcloud_tracing::context::TraceContextInjector;

mod client;
use client::{AsyncKafkaClient, AsyncKafkaConsumer};

mod bindings {
    wit_bindgen_wrpc::generate!({
        with: {
            "wasmcloud:messaging/consumer@0.2.0": generate,
            "wasmcloud:messaging/handler@0.2.0": generate,
            "wasmcloud:messaging/types@0.2.0": generate,
        },
    });
}
use bindings::wasmcloud::messaging::types::BrokerMessage;

/// Config key for hosts, accepted as a comma-separated string
const KAFKA_HOSTS_CONFIG_KEY: &str = "hosts";
const DEFAULT_HOST: &str = "127.0.0.1:9092";

/// Config key for the topic, accepted as a single string
const KAFKA_TOPIC_CONFIG_KEY: &str = "topic";
const DEFAULT_TOPIC: &str = "my-topic";

/// Config key for specifying a consumer group
const KAFKA_CONSUMER_GROUP_CONFIG_KEY: &str = "consumer_group";

/// Config key for specifying one or more comma-delimited partition(s)
/// to use when consuming values
const KAFKA_CONSUMER_PARTITIONS_CONFIG_KEY: &str = "consumer_partitions";

/// Config key for specifying one or more comma-delimited partition(s)
/// to use when producing values
const KAFKA_PRODUCER_PARTITIONS_CONFIG_KEY: &str = "producer_partitions";

/// Number of seconds to wait for a consumer to stop after triggering it
const CONSUMER_STOP_TIMEOUT_SECS: u64 = 5;

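// An illustrative (hypothetical) link configuration for this provider; every
// key is optional, and the values below are examples rather than defaults:
//
//   hosts               = "kafka-0:9092,kafka-1:9092"  (ideally set as a secret)
//   topic               = "orders"
//   consumer_group      = "orders-processors"
//   consumer_partitions = "0,1"
//   producer_partitions = "0"
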
pub async fn run() -> Result<()> {
    KafkaMessagingProvider::run().await
}

/// A struct that contains a consumer task handle and the host connection strings
#[allow(dead_code)]
struct KafkaConnection {
    /// Hosts that the connection is using
    hosts: Vec<String>,
    /// Kafka client that can be used for one-off operations
    client: AsyncKafkaClient,
    /// Handle to the tokio task that consumes messages
    consumer: JoinHandle<anyhow::Result<()>>,
    /// Channel used to signal the consumer task to stop
    consumer_stop_tx: Sender<()>,
    /// Topic partition(s) on which the consumer is consuming messages
    consumer_partitions: Vec<i32>,
    /// Topic partition(s) on which the producer is sending messages
    producer_partitions: Vec<i32>,
    /// Consumer group
    consumer_group: Option<String>,
}

#[derive(Clone, Default)]
pub struct KafkaMessagingProvider {
    // Map of component ID to the KafkaConnection on which messages are consumed.
    //
    // When a link is put we spawn a tokio task to handle messages, and when the
    // link is deleted the task is stopped.
    connections: Arc<RwLock<HashMap<String, KafkaConnection>>>,
}

impl KafkaMessagingProvider {
    pub fn name() -> &'static str {
        "messaging-kafka-provider"
    }

    pub async fn run() -> anyhow::Result<()> {
        initialize_observability!(
            KafkaMessagingProvider::name(),
            std::env::var_os("PROVIDER_MESSAGING_KAFKA_FLAMEGRAPH_PATH")
        );

        let provider = Self::default();
        let shutdown = run_provider(provider.clone(), KafkaMessagingProvider::name())
            .await
            .context("failed to run provider")?;
        let connection = get_connection();
        let wrpc = connection
            .get_wrpc_client(connection.provider_key())
            .await?;
        serve_provider_exports(&wrpc, provider, shutdown, bindings::serve)
            .await
            .context("failed to serve provider exports")
    }
}

/// Extract hostnames (separated by commas, found under key [`KAFKA_HOSTS_CONFIG_KEY`])
/// from link secrets or config
///
/// If no hostnames are found, [`DEFAULT_HOST`] is split (by ',') and returned.
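///
/// A hypothetical example: a value of `"kafka-0:9092,kafka-1:9092"` yields
/// `vec!["kafka-0:9092", "kafka-1:9092"]`.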
fn extract_hosts_from_link_config(link_config: &LinkConfig) -> Vec<String> {
    // Collect comma-separated hosts into a Vec<String>
    //
    // This value could come from either secrets or regular config (for backwards
    // compatibility), but we want to make sure we warn if it is pulled from config.
    let maybe_hosts = link_config
        .secrets
        .iter()
        .find_map(|(k, v)| match (k, v.as_string()) {
            (k, Some(v)) if *k == KAFKA_HOSTS_CONFIG_KEY => Some(String::from(v)),
            _ => None,
        })
        .or_else(|| {
            warn!("secret value [{KAFKA_HOSTS_CONFIG_KEY}] was not found in secrets. Prefer storing sensitive values in secrets");
            link_config
                .config
                .iter()
                .find_map(|(k, v)| {
                    if *k == KAFKA_HOSTS_CONFIG_KEY {
                        Some(v.to_string())
                    } else {
                        None
                    }
                })
        });

    maybe_hosts
        .unwrap_or_else(|| DEFAULT_HOST.to_string())
        .trim()
        .split(',')
        .map(std::string::ToString::to_string)
        .collect::<Vec<String>>()
}

/// Extract a topic (found under key [`KAFKA_TOPIC_CONFIG_KEY`]) from a config hashmap
///
/// If no topic is found, [`DEFAULT_TOPIC`] is returned.
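///
/// A hypothetical example: a config entry of `topic = " orders "` yields
/// `"orders"` (surrounding whitespace is trimmed).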
fn extract_topic_from_config(config: &HashMap<String, String>) -> &str {
    config
        .iter()
        .find_map(|(k, v)| {
            if *k == KAFKA_TOPIC_CONFIG_KEY {
                Some(v.as_str())
            } else {
                None
            }
        })
        .unwrap_or(DEFAULT_TOPIC)
        .trim()
}

impl Provider for KafkaMessagingProvider {
    /// Called when a link is established with this provider as the *target* of the link.
    #[instrument(skip_all, fields(source_id))]
    async fn receive_link_config_as_target(&self, link_config: LinkConfig<'_>) -> Result<()> {
        let LinkConfig {
            link_name,
            source_id,
            config,
            ..
        } = link_config;
        debug!(link_name, source_id, "receiving link as target");
        // Collect various values from config (if present)
        let hosts = extract_hosts_from_link_config(&link_config);
        let topic = extract_topic_from_config(config);
        let consumer_group = config
            .get(KAFKA_CONSUMER_GROUP_CONFIG_KEY)
            .map(String::to_string);
        // Parse the comma-delimited partition lists, de-duplicating entries and
        // discarding any that are not valid i32 values
        let consumer_partitions = config
            .get(KAFKA_CONSUMER_PARTITIONS_CONFIG_KEY)
            .map(String::to_string)
            .unwrap_or_default()
            .split(',')
            .map(|s| s.into())
            .collect::<HashSet<String>>()
            .iter()
            .filter_map(|v| v.parse::<i32>().ok())
            .collect::<Vec<i32>>();
        let producer_partitions = config
            .get(KAFKA_PRODUCER_PARTITIONS_CONFIG_KEY)
            .map(String::to_string)
            .unwrap_or_default()
            .split(',')
            .map(|s| s.into())
            .collect::<HashSet<String>>()
            .iter()
            .filter_map(|v| v.parse::<i32>().ok())
            .collect::<Vec<i32>>();

        // Build a client for use with the consumer
        let client = AsyncKafkaClient::from_hosts(hosts.clone()).await.with_context(|| {
            warn!(
                source_id,
                "failed to create Kafka client for component",
            );
            format!("failed to build async kafka client for component [{source_id}], messages won't be received")
        })?;

        // Build a consumer configured with our given client
        let _consumer_group = consumer_group.clone();
        let _consumer_partitions = consumer_partitions.clone();
        debug!(topic, ?consumer_partitions, "creating kafka async consumer");
        let consumer = AsyncKafkaConsumer::from_async_client(client, move |mut b| {
            b = b.with_topic(topic.into());
            b = b.with_topic_partitions(topic.into(), _consumer_partitions.as_slice());
            if let Some(g) = _consumer_group {
                b = b.with_group(g);
            }
            b
        }).await.with_context(|| {
            warn!(
                source_id,
                "failed to build consumer from Kafka client for component",
            );
            format!("failed to build consumer from kafka client for component [{source_id}], messages won't be received")
        })?;

        // Build a second client to store in the connection
        let client = AsyncKafkaClient::from_hosts(hosts.clone()).await.with_context(|| {
            warn!(
                source_id,
                "failed to create Kafka client for component",
            );
            format!("failed to build async kafka client for component [{source_id}], messages won't be received")
        })?;

        // Store reusable information for use when processing new messages
        let component_id: Arc<str> = source_id.into();
        let subject: Arc<str> = topic.into();

        // Allow triggering listeners to stop
        let (stop_listener_tx, mut stop_listener_rx) = tokio::sync::oneshot::channel();

        // Start listening for incoming messages
        let (mut stream, inner_stop_tx) = match consumer
            .messages()
            .await
            .context("failed to start listening to consumer messages")
        {
            Ok(v) => v,
            Err(e) => {
                warn!("failed listening to consumer message stream: {e}");
                bail!(e);
            }
        };

        // StartOffset::Latest only processes new messages, but Earliest will send every message.
        // This could be a linkdef tunable value in the future.
        let task = spawn(async move {
            let wrpc = get_connection().get_wrpc_client(&component_id).await?;

            // Listen to messages forever until we're instructed to stop
            loop {
                tokio::select! {
                    // Handle calls to stop the listener
                    _ = &mut stop_listener_rx => {
                        if let Err(()) = inner_stop_tx.send(()) {
                            bail!("failed to send stop consumer");
                        }
                        return Ok(());
                    },

                    // Listen to the next messages in the stream
                    //
                    // This stream will essentially never stop producing values.
                    Some(msg) = stream.next() => {
                        let component_id = Arc::clone(&component_id);
                        let wrpc = wrpc.clone();
                        let subject = Arc::clone(&subject);
                        tokio::spawn(async move {
                            if let Err(e) = bindings::wasmcloud::messaging::handler::handle_message(
                                &wrpc,
                                None,
                                &BrokerMessage {
                                    body: msg.value.into(),
                                    // By default, we always append '.reply' for reply topics
                                    reply_to: Some(format!("{subject}.reply")),
                                    subject: subject.to_string(),
                                },
                            )
                                .await
                            {
                                warn!(
                                    subject = subject.to_string(),
                                    component_id = component_id.to_string(),
                                    "unable to send subscription: {e:?}",
                                );
                            }
                        });
                    }
                }
            }
        });

        // Save the newly-created task that constantly listens for messages to the provider
        let mut connections = self.connections.write().await;
        connections.insert(
            source_id.to_string(),
            KafkaConnection {
                client,
                consumer: task,
                consumer_stop_tx: stop_listener_tx,
                hosts,
                consumer_partitions,
                producer_partitions,
                consumer_group,
            },
        );

        Ok(())
    }

    /// Handle notification that a link is dropped: close the connection
    #[instrument(level = "info", skip_all, fields(source_id = info.get_source_id()))]
    async fn delete_link_as_target(&self, info: impl LinkDeleteInfo) -> Result<()> {
        let component_id = info.get_source_id();
        debug!(component_id, "deleting link for component");

        // Find the connection and remove it from the HashMap
        let mut connections = self.connections.write().await;
        let Some(KafkaConnection {
            consumer,
            consumer_stop_tx,
            ..
        }) = connections.remove(component_id)
        else {
            debug!("Linkdef deleted for non-existent consumer, ignoring");
            return Ok(());
        };

        // Signal the consumer to stop, then wait for it to close out
        if let Err(()) = consumer_stop_tx.send(()) {
            bail!("failed to send stop consumer");
        }
        let _ = tokio::time::timeout(Duration::from_secs(CONSUMER_STOP_TIMEOUT_SECS), consumer)
            .await
            .context("timed out waiting for consumer task to exit")?;

        Ok(())
    }

    /// Handle shutdown request with any cleanup necessary
    async fn shutdown(&self) -> Result<()> {
        let mut connections = self.connections.write().await;
        for (
            _source_id,
            KafkaConnection {
                consumer,
                consumer_stop_tx,
                ..
            },
        ) in connections.drain()
        {
            consumer_stop_tx
                .send(())
                .map_err(|_| anyhow::anyhow!("failed to send consumer stop"))?;
            if let Err(err) =
                tokio::try_join!(consumer).context("consumer task did not exit cleanly")
            {
                error!(?err, "failed to stop consumer task cleanly");
            }
        }
        Ok(())
    }
}

/// Implement the 'wasmcloud:messaging' capability provider interface
impl bindings::exports::wasmcloud::messaging::consumer::Handler<Option<Context>>
    for KafkaMessagingProvider
{
    #[instrument(
        skip_all,
        fields(subject = %msg.subject, reply_to = ?msg.reply_to, body_len = %msg.body.len())
    )]
    async fn publish(
        &self,
        ctx: Option<Context>,
        msg: BrokerMessage,
    ) -> Result<std::result::Result<(), String>> {
        // Extract tracing information from invocation context, if present
        let trace_ctx = match ctx {
            Some(Context { ref tracing, .. }) if !tracing.is_empty() => tracing
                .iter()
                .map(|(k, v)| (k.to_string(), v.to_string()))
                .collect::<Vec<(String, String)>>(),

            _ => TraceContextInjector::default_with_span()
                .iter()
                .map(|(k, v)| (k.to_string(), v.to_string()))
                .collect(),
        };
        wasmcloud_tracing::context::attach_span_context(&trace_ctx);
        debug!(?msg, "publishing message");

        let ctx = ctx.as_ref().context("unexpectedly missing context")?;
        let Some(component_id) = ctx.component.as_ref() else {
            bail!("context unexpectedly missing component ID");
        };

        // Retrieve a usable Kafka connection for our component
        let connections = self.connections.read().await;
        let Some(KafkaConnection {
            hosts,
            producer_partitions,
            ..
        }) = connections.get(component_id)
        else {
            warn!(component_id, "failed to get connection for component");
            return Ok(Err(format!(
                "failed to get connection for component [{component_id}]"
            )));
        };

        // Create a producer we'll use to send
        let mut producer = Producer::from_hosts(hosts.clone())
            .create()
            .context("failed to build kafka producer")?;

        // For every partition we're configured to produce to, send out a record;
        // if no partitions are configured, use the unspecified (default) partition
        debug!(subject = msg.subject, "sending message");
        match producer_partitions[..] {
            // Send to the default ("unspecified") partition
            [] => {
                producer
                    .send(&Record::<(), Vec<u8>>::from_key_value(
                        &msg.subject,
                        (),
                        msg.body.to_vec(),
                    ))
                    .context("failed to send record")?;
            }
            // If one or more partitions are configured, publish to each of them
            _ => {
                for partition in producer_partitions {
                    producer
                        .send(
                            &Record::<(), Vec<u8>>::from_key_value(
                                &msg.subject,
                                (),
                                msg.body.to_vec(),
                            )
                            .with_partition(*partition),
                        )
                        .with_context(|| {
                            format!("failed to send record to partition [{partition}]")
                        })?;
                }
            }
        }

        Ok(Ok(()))
    }

    #[instrument(skip_all)]
    async fn request(
        &self,
        ctx: Option<Context>,
        _subject: String,
        _body: Bytes,
        _timeout_ms: u32,
    ) -> Result<std::result::Result<BrokerMessage, String>> {
        // Extract tracing information from invocation context, if present
        let trace_ctx = match ctx {
            Some(Context { ref tracing, .. }) if !tracing.is_empty() => tracing
                .iter()
                .map(|(k, v)| (k.to_string(), v.to_string()))
                .collect::<Vec<(String, String)>>(),

            _ => TraceContextInjector::default_with_span()
                .iter()
                .map(|(k, v)| (k.to_string(), v.to_string()))
                .collect(),
        };
        wasmcloud_tracing::context::attach_span_context(&trace_ctx);

        // Kafka does not support request-reply in the traditional sense: you can publish to a
        // topic and get an acknowledgement that the message was received, but you can't get a
        // reply from a consumer on the other side.
        error!("not implemented (Kafka does not officially support the request-reply paradigm)");
        Ok(Err(
            "not implemented (Kafka does not officially support the request-reply paradigm)"
                .to_string(),
        ))
    }
}
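
// A minimal sanity check for the topic extraction helper above; a sketch added
// for illustration, exercising only code defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn topic_extraction_falls_back_to_default() {
        // No topic key present: the default topic should be returned
        assert_eq!(extract_topic_from_config(&HashMap::new()), DEFAULT_TOPIC);

        // A configured topic is returned with surrounding whitespace trimmed
        let config = HashMap::from([(KAFKA_TOPIC_CONFIG_KEY.to_string(), " orders ".to_string())]);
        assert_eq!(extract_topic_from_config(&config), "orders");
    }
}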