wasmcloud_provider_messaging_kafka/
lib.rs1use std::collections::{HashMap, HashSet};
4use std::sync::Arc;
5
6use anyhow::{bail, Context as _, Result};
7use bytes::Bytes;
8use kafka::producer::{Producer, Record};
9use tokio::spawn;
10use tokio::sync::oneshot::Sender;
11use tokio::sync::RwLock;
12use tokio::task::JoinHandle;
13use tokio::time::Duration;
14use tokio_stream::StreamExt;
15use tracing::{debug, error, instrument, warn};
16use wasmcloud_provider_sdk::{
17 get_connection, run_provider, Context, LinkConfig, LinkDeleteInfo, Provider,
18};
19use wasmcloud_provider_sdk::{initialize_observability, serve_provider_exports};
20use wasmcloud_tracing::context::TraceContextInjector;
21
22mod client;
23use client::{AsyncKafkaClient, AsyncKafkaConsumer};
24
25mod bindings {
26 wit_bindgen_wrpc::generate!({
27 with: {
28 "wasmcloud:messaging/consumer@0.2.0": generate,
29 "wasmcloud:messaging/handler@0.2.0": generate,
30 "wasmcloud:messaging/types@0.2.0": generate,
31 },
32 });
33}
34use bindings::wasmcloud::messaging::types::BrokerMessage;
35
/// Link secret/config key holding a comma-separated list of Kafka bootstrap hosts.
const KAFKA_HOSTS_CONFIG_KEY: &str = "hosts";
/// Host list used when the link supplies no `hosts` value.
const DEFAULT_HOST: &str = "127.0.0.1:9092";

/// Link config key naming the Kafka topic the link's consumer subscribes to.
const KAFKA_TOPIC_CONFIG_KEY: &str = "topic";
/// Topic used when the link supplies no `topic` value.
const DEFAULT_TOPIC: &str = "my-topic";

/// Link config key for the (optional) consumer group.
const KAFKA_CONSUMER_GROUP_CONFIG_KEY: &str = "consumer_group";

/// Link config key for a comma-separated list of partitions to consume from.
const KAFKA_CONSUMER_PARTITIONS_CONFIG_KEY: &str = "consumer_partitions";

/// Link config key for a comma-separated list of partitions `publish` writes to.
const KAFKA_PRODUCER_PARTITIONS_CONFIG_KEY: &str = "producer_partitions";

/// How long (in seconds) to wait for a consumer task to finish after asking it to stop.
const CONSUMER_STOP_TIMEOUT_SECS: u64 = 5;
57
58pub async fn run() -> Result<()> {
59 KafkaMessagingProvider::run().await
60}
61
62#[allow(dead_code)]
64struct KafkaConnection {
65 hosts: Vec<String>,
67 client: AsyncKafkaClient,
69 consumer: JoinHandle<anyhow::Result<()>>,
71 consumer_stop_tx: Sender<()>,
73 consumer_partitions: Vec<i32>,
75 producer_partitions: Vec<i32>,
77 consumer_group: Option<String>,
79}
80
81#[derive(Clone, Default)]
82pub struct KafkaMessagingProvider {
83 connections: Arc<RwLock<HashMap<String, KafkaConnection>>>,
87}
88
89impl KafkaMessagingProvider {
90 pub fn name() -> &'static str {
91 "messaging-kafka-provider"
92 }
93
94 pub async fn run() -> anyhow::Result<()> {
95 initialize_observability!(
96 KafkaMessagingProvider::name(),
97 std::env::var_os("PROVIDER_MESSAGING_KAFKA_FLAMEGRAPH_PATH")
98 );
99
100 let provider = Self::default();
101 let shutdown = run_provider(provider.clone(), KafkaMessagingProvider::name())
102 .await
103 .context("failed to run provider")?;
104 let connection = get_connection();
105 let wrpc = connection
106 .get_wrpc_client(connection.provider_key())
107 .await?;
108 serve_provider_exports(&wrpc, provider, shutdown, bindings::serve)
109 .await
110 .context("failed to serve provider exports")
111 }
112}
113
114fn extract_hosts_from_link_config(link_config: &LinkConfig) -> Vec<String> {
118 let maybe_hosts = link_config
123 .secrets
124 .iter()
125 .find_map(|(k, v)| {
126 match (k, v.as_string()) {
127 (k, Some(v)) if *k == KAFKA_HOSTS_CONFIG_KEY => Some(String::from(v)),
128 _ => None,
129 }
130 })
131 .or_else(|| {
132 warn!("secret value [{KAFKA_HOSTS_CONFIG_KEY}] was not found in secrets. Prefer storing sensitive values in secrets");
133 link_config
134 .config
135 .iter()
136 .find_map(|(k, v)| {
137 if *k == KAFKA_HOSTS_CONFIG_KEY {
138 Some(v.to_string())
139 } else {
140 None
141 }
142 })
143 });
144
145 maybe_hosts
146 .unwrap_or_else(|| DEFAULT_HOST.to_string())
147 .trim()
148 .split(',')
149 .map(std::string::ToString::to_string)
150 .collect::<Vec<String>>()
151}
152
153fn extract_topic_from_config(config: &HashMap<String, String>) -> &str {
157 config
158 .iter()
159 .find_map(|(k, v)| {
160 if *k == KAFKA_TOPIC_CONFIG_KEY {
161 Some(v.as_str())
162 } else {
163 None
164 }
165 })
166 .unwrap_or(DEFAULT_TOPIC)
167 .trim()
168}
169
170impl Provider for KafkaMessagingProvider {
171 #[instrument(skip_all, fields(source_id))]
173 async fn receive_link_config_as_target(&self, link_config: LinkConfig<'_>) -> Result<()> {
174 let LinkConfig {
175 link_name,
176 source_id,
177 config,
178 ..
179 } = link_config;
180 debug!(link_name, source_id, "receiving link as target");
181 let hosts = extract_hosts_from_link_config(&link_config);
183 let topic = extract_topic_from_config(config);
184 let consumer_group = config
185 .get(KAFKA_CONSUMER_GROUP_CONFIG_KEY)
186 .map(String::to_string);
187 let consumer_partitions = config
188 .get(KAFKA_CONSUMER_PARTITIONS_CONFIG_KEY)
189 .map(String::to_string)
190 .unwrap_or_default()
191 .split(',')
192 .map(|s| s.into())
193 .collect::<HashSet<String>>()
194 .iter()
195 .filter_map(|v| v.parse::<i32>().ok())
196 .collect::<Vec<i32>>();
197 let producer_partitions = config
198 .get(KAFKA_PRODUCER_PARTITIONS_CONFIG_KEY)
199 .map(String::to_string)
200 .unwrap_or_default()
201 .split(',')
202 .map(|s| s.into())
203 .collect::<HashSet<String>>()
204 .iter()
205 .filter_map(|v| v.parse::<i32>().ok())
206 .collect::<Vec<i32>>();
207
208 let client = AsyncKafkaClient::from_hosts(hosts.clone()).await.with_context(|| {
210 warn!(
211 source_id,
212 "failed to create Kafka client for component",
213 );
214 format!("failed to build async kafka client for component [{source_id}], messages won't be received")
215 })?;
216
217 let _consumer_group = consumer_group.clone();
219 let _consumer_partitions = consumer_partitions.clone();
220 debug!(topic, ?consumer_partitions, "creating kafka async consumer");
221 let consumer = AsyncKafkaConsumer::from_async_client(client, move |mut b| {
222 b = b.with_topic(topic.into());
223 b = b.with_topic_partitions(topic.into(), _consumer_partitions.as_slice());
224 if let Some(g) = _consumer_group {
225 b = b.with_group(g);
226 }
227 b
228 }).await.with_context(|| {
229 warn!(
230 source_id,
231 "failed to build consumer from Kafka client for component",
232 );
233 format!("failed to build consumer from kafka client for component [{source_id}], messages won't be received")
234 })?;
235
236 let client = AsyncKafkaClient::from_hosts(hosts.clone()).await.with_context(|| {
238 warn!(
239 source_id,
240 "failed to create Kafka client for component",
241 );
242 format!("failed to build async kafka client for component [{source_id}], messages won't be received")
243 })?;
244
245 let component_id: Arc<str> = source_id.into();
247 let subject: Arc<str> = topic.into();
248
249 let (stop_listener_tx, mut stop_listener_rx) = tokio::sync::oneshot::channel();
251
252 let (mut stream, inner_stop_tx) = match consumer
254 .messages()
255 .await
256 .context("failed to start listening to consumer messages")
257 {
258 Ok(v) => v,
259 Err(e) => {
260 warn!("failed listening to consumer message stream: {e}");
261 bail!(e);
262 }
263 };
264
265 let task = spawn(async move {
268 let wrpc = get_connection().get_wrpc_client(&component_id).await?;
269
270 loop {
272 tokio::select! {
273 _ = &mut stop_listener_rx => {
275 if let Err(()) = inner_stop_tx.send(()) {
276 bail!("failed to send stop consumer");
277 }
278 return Ok(());
279 },
280
281 Some(msg) = stream.next() => {
285 let component_id = Arc::clone(&component_id);
286 let wrpc = wrpc.clone();
287 let subject = Arc::clone(&subject);
288 tokio::spawn(async move {
289 if let Err(e) = bindings::wasmcloud::messaging::handler::handle_message(
290 &wrpc,
291 None,
292 &BrokerMessage {
293 body: msg.value.into(),
294 reply_to: Some(format!("{subject}.reply")),
296 subject: subject.to_string(),
297 },
298 )
299 .await
300 {
301 warn!(
302 subject = subject.to_string(),
303 component_id = component_id.to_string(),
304 "unable to send subscription: {e:?}",
305 );
306 }
307 });
308 }
309 }
310 }
311 });
312
313 let mut connections = self.connections.write().await;
315 connections.insert(
316 source_id.to_string(),
317 KafkaConnection {
318 client,
319 consumer: task,
320 consumer_stop_tx: stop_listener_tx,
321 hosts,
322 consumer_partitions,
323 producer_partitions,
324 consumer_group,
325 },
326 );
327
328 Ok(())
329 }
330
331 #[instrument(level = "info", skip_all, fields(source_id = info.get_source_id()))]
333 async fn delete_link_as_target(&self, info: impl LinkDeleteInfo) -> Result<()> {
334 let component_id = info.get_source_id();
335 debug!(component_id, "deleting link for component");
336
337 let mut connections = self.connections.write().await;
339 let Some(KafkaConnection {
340 consumer,
341 consumer_stop_tx,
342 ..
343 }) = connections.remove(component_id)
344 else {
345 debug!("Linkdef deleted for non-existent consumer, ignoring");
346 return Ok(());
347 };
348
349 if let Err(()) = consumer_stop_tx.send(()) {
351 bail!("failed to send stop consumer");
352 }
353 let _ = tokio::time::timeout(Duration::from_secs(CONSUMER_STOP_TIMEOUT_SECS), consumer)
354 .await
355 .context("consumer task did not exit cleanly")?;
356
357 Ok(())
358 }
359
360 async fn shutdown(&self) -> Result<()> {
362 let mut connections = self.connections.write().await;
363 for (
364 _source_id,
365 KafkaConnection {
366 consumer,
367 consumer_stop_tx,
368 ..
369 },
370 ) in connections.drain()
371 {
372 consumer_stop_tx
373 .send(())
374 .map_err(|_| anyhow::anyhow!("failed to send consumer stop"))?;
375 if let Err(err) =
376 tokio::try_join!(consumer).context("consumer task did not exit cleanly")
377 {
378 error!(?err, "failed to stop consumer task cleanly");
379 };
380 }
381 Ok(())
382 }
383}
384
385impl bindings::exports::wasmcloud::messaging::consumer::Handler<Option<Context>>
387 for KafkaMessagingProvider
388{
389 #[instrument(
390 skip_all,
391 fields(subject = %msg.subject, reply_to = ?msg.reply_to, body_len = %msg.body.len())
392 )]
393 async fn publish(
394 &self,
395 ctx: Option<Context>,
396 msg: BrokerMessage,
397 ) -> Result<std::result::Result<(), String>> {
398 let trace_ctx = match ctx {
400 Some(Context { ref tracing, .. }) if !tracing.is_empty() => tracing
401 .iter()
402 .map(|(k, v)| (k.to_string(), v.to_string()))
403 .collect::<Vec<(String, String)>>(),
404
405 _ => TraceContextInjector::default_with_span()
406 .iter()
407 .map(|(k, v)| (k.to_string(), v.to_string()))
408 .collect(),
409 };
410 wasmcloud_tracing::context::attach_span_context(&trace_ctx);
411 debug!(?msg, "publishing message");
412
413 let ctx = ctx.as_ref().context("unexpectedly missing context")?;
414 let Some(component_id) = ctx.component.as_ref() else {
415 bail!("context unexpectedly missing component ID");
416 };
417
418 let connections = self.connections.read().await;
420 let Some(KafkaConnection {
421 hosts,
422 producer_partitions,
423 ..
424 }) = connections.get(component_id)
425 else {
426 warn!(component_id, "failed to get connection for component");
427 return Ok(Err(format!(
428 "failed to get connection for component [{component_id}]"
429 )));
430 };
431
432 let mut producer = Producer::from_hosts(hosts.clone())
434 .create()
435 .context("failed to build kafka producer")?;
436
437 debug!(subject = msg.subject, "sending message");
440 match producer_partitions[..] {
441 [] => {
443 producer
444 .send(&Record::<(), Vec<u8>>::from_key_value(
445 &msg.subject,
446 (),
447 msg.body.to_vec(),
448 ))
449 .context("failed to send record")?;
450 }
451 _ => {
453 for partition in producer_partitions {
454 producer
455 .send(
456 &Record::<(), Vec<u8>>::from_key_value(
457 &msg.subject,
458 (),
459 msg.body.to_vec(),
460 )
461 .with_partition(*partition),
462 )
463 .with_context(|| {
464 format!("failed to send record to partition [{partition}]")
465 })?;
466 }
467 }
468 }
469
470 Ok(Ok(()))
471 }
472
473 #[instrument(skip_all)]
474 async fn request(
475 &self,
476 ctx: Option<Context>,
477 _subject: String,
478 _body: Bytes,
479 _timeout_ms: u32,
480 ) -> Result<std::result::Result<BrokerMessage, String>> {
481 let trace_ctx = match ctx {
483 Some(Context { ref tracing, .. }) if !tracing.is_empty() => tracing
484 .iter()
485 .map(|(k, v)| (k.to_string(), v.to_string()))
486 .collect::<Vec<(String, String)>>(),
487
488 _ => TraceContextInjector::default_with_span()
489 .iter()
490 .map(|(k, v)| (k.to_string(), v.to_string()))
491 .collect(),
492 };
493 wasmcloud_tracing::context::attach_span_context(&trace_ctx);
494
495 error!("not implemented (Kafka does not officially support the request-reply paradigm)");
499 Ok(Err(
500 "not implemented (Kafka does not officially support the request-reply paradigm)"
501 .to_string(),
502 ))
503 }
504}