redis/
subscription_tracker.rs
#![allow(dead_code)]
use core::str;
use std::collections::HashSet;
use crate::{Arg, Cmd, Pipeline};
#[derive(Default)]
pub(crate) struct SubscriptionTracker {
subscriptions: HashSet<Vec<u8>>,
s_subscriptions: HashSet<Vec<u8>>,
p_subscriptions: HashSet<Vec<u8>>,
}
pub(crate) enum SubscriptionAction {
Subscribe,
Unsubscribe,
PSubscribe,
PUnsubscribe,
SSubscribe,
Sunsubscribe,
}
impl SubscriptionAction {
fn additive(&self) -> bool {
match self {
SubscriptionAction::Subscribe
| SubscriptionAction::PSubscribe
| SubscriptionAction::SSubscribe => true,
SubscriptionAction::Unsubscribe
| SubscriptionAction::PUnsubscribe
| SubscriptionAction::Sunsubscribe => false,
}
}
}
impl SubscriptionTracker {
pub(crate) fn update_with_request(
&mut self,
action: SubscriptionAction,
args: impl Iterator<Item = Vec<u8>>,
) {
let set = match action {
SubscriptionAction::Subscribe | SubscriptionAction::Unsubscribe => {
&mut self.subscriptions
}
SubscriptionAction::PSubscribe | SubscriptionAction::PUnsubscribe => {
&mut self.p_subscriptions
}
SubscriptionAction::SSubscribe | SubscriptionAction::Sunsubscribe => {
&mut self.s_subscriptions
}
};
if action.additive() {
for sub in args {
set.insert(sub);
}
} else {
for sub in args {
set.remove(&sub);
}
}
}
pub(crate) fn update_with_cmd<'a>(&'a mut self, cmd: &'a Cmd) {
let mut args_iter = cmd.args_iter();
let first_arg = args_iter.next();
let Some(Arg::Simple(first_arg)) = first_arg else {
return;
};
let Ok(first_arg) = str::from_utf8(first_arg) else {
return;
};
let args = args_iter.filter_map(|arg| match arg {
Arg::Simple(arg) => Some(arg.to_vec()),
Arg::Cursor => None,
});
let action = if first_arg.eq_ignore_ascii_case("SUBSCRIBE") {
SubscriptionAction::Subscribe
} else if first_arg.eq_ignore_ascii_case("PSUBSCRIBE") {
SubscriptionAction::PSubscribe
} else if first_arg.eq_ignore_ascii_case("SSUBSCRIBE") {
SubscriptionAction::SSubscribe
} else if first_arg.eq_ignore_ascii_case("UNSUBSCRIBE") {
SubscriptionAction::Unsubscribe
} else if first_arg.eq_ignore_ascii_case("PUNSUBSCRIBE") {
SubscriptionAction::PUnsubscribe
} else if first_arg.eq_ignore_ascii_case("SUNSUBSCRIBE") {
SubscriptionAction::Sunsubscribe
} else {
return;
};
self.update_with_request(action, args);
}
pub(crate) fn update_with_pipeline<'a>(&'a mut self, pipe: &'a Pipeline) {
for cmd in pipe.cmd_iter() {
self.update_with_cmd(cmd);
}
}
pub(crate) fn get_subscription_pipeline(&self) -> Pipeline {
let mut pipeline = crate::pipe();
if !self.subscriptions.is_empty() {
let cmd = pipeline.cmd("SUBSCRIBE");
for channel in self.subscriptions.iter() {
cmd.arg(channel);
}
}
if !self.s_subscriptions.is_empty() {
let cmd = pipeline.cmd("SSUBSCRIBE");
for channel in self.s_subscriptions.iter() {
cmd.arg(channel);
}
}
if !self.p_subscriptions.is_empty() {
let cmd = pipeline.cmd("PSUBSCRIBE");
for channel in self.p_subscriptions.iter() {
cmd.arg(channel);
}
}
pipeline
}
}
#[cfg(test)]
mod tests {
use crate::{cmd, pipe};
use super::*;
#[test]
fn test_add_and_remove_subscriptions() {
let mut tracker = SubscriptionTracker::default();
tracker.update_with_cmd(cmd("subscribe").arg("foo").arg("bar"));
tracker.update_with_cmd(cmd("PSUBSCRIBE").arg("fo*o").arg("b*ar"));
tracker.update_with_cmd(cmd("SSUBSCRIBE").arg("sfoo").arg("sbar"));
tracker.update_with_cmd(cmd("unsubscribe").arg("foo"));
tracker.update_with_cmd(cmd("Punsubscribe").arg("b*ar"));
tracker.update_with_cmd(cmd("Sunsubscribe").arg("sfoo").arg("SBAR"));
tracker.update_with_cmd(cmd("GET").arg("sfoo"));
let result = tracker.get_subscription_pipeline();
let mut expected = pipe();
expected
.cmd("SUBSCRIBE")
.arg("bar")
.cmd("SSUBSCRIBE")
.arg("sbar")
.cmd("PSUBSCRIBE")
.arg("fo*o");
assert_eq!(
result.get_packed_pipeline(),
expected.get_packed_pipeline(),
"{}",
String::from_utf8(result.get_packed_pipeline()).unwrap()
);
}
#[test]
fn test_skip_empty_subscriptions() {
let mut tracker = SubscriptionTracker::default();
tracker.update_with_cmd(cmd("subscribe").arg("foo").arg("bar"));
tracker.update_with_cmd(cmd("PSUBSCRIBE").arg("fo*o").arg("b*ar"));
tracker.update_with_cmd(cmd("unsubscribe").arg("foo").arg("bar"));
tracker.update_with_cmd(cmd("punsubscribe").arg("fo*o"));
let result = tracker.get_subscription_pipeline();
let mut expected = pipe();
expected.cmd("PSUBSCRIBE").arg("b*ar");
assert_eq!(
result.get_packed_pipeline(),
expected.get_packed_pipeline(),
"{}",
String::from_utf8(result.get_packed_pipeline()).unwrap()
);
}
#[test]
fn test_add_and_remove_subscriptions_with_pipeline() {
let mut tracker = SubscriptionTracker::default();
tracker.update_with_pipeline(
pipe()
.cmd("subscribe")
.arg("foo")
.arg("bar")
.cmd("PSUBSCRIBE")
.arg("fo*o")
.arg("b*ar")
.cmd("SSUBSCRIBE")
.arg("sfoo")
.arg("sbar")
.cmd("unsubscribe")
.arg("foo")
.cmd("Punsubscribe")
.arg("b*ar")
.cmd("Sunsubscribe")
.arg("sfoo")
.arg("SBAR"),
);
let result = tracker.get_subscription_pipeline();
let mut expected = pipe();
expected
.cmd("SUBSCRIBE")
.arg("bar")
.cmd("SSUBSCRIBE")
.arg("sbar")
.cmd("PSUBSCRIBE")
.arg("fo*o");
assert_eq!(
result.get_packed_pipeline(),
expected.get_packed_pipeline(),
"{}",
String::from_utf8(result.get_packed_pipeline()).unwrap()
);
}
#[test]
fn test_only_unsubscribe_from_existing_subscriptions() {
let mut tracker = SubscriptionTracker::default();
tracker.update_with_cmd(cmd("unsubscribe").arg("foo"));
tracker.update_with_cmd(cmd("subscribe").arg("foo"));
let result = tracker.get_subscription_pipeline();
let mut expected = pipe();
expected.cmd("SUBSCRIBE").arg("foo");
assert_eq!(
result.get_packed_pipeline(),
expected.get_packed_pipeline(),
"{}",
String::from_utf8(result.get_packed_pipeline()).unwrap()
);
}
}