redis/
subscription_tracker.rs

1#![allow(dead_code)]
2
3use core::str;
4use std::collections::HashSet;
5
6use crate::{Arg, Cmd, Pipeline};
7
/// Records which pub/sub subscriptions are currently active, so that
/// `get_subscription_pipeline` can rebuild them as subscribe commands.
#[derive(Debug, Default)]
pub(crate) struct SubscriptionTracker {
    /// Channels tracked from `SUBSCRIBE` / `UNSUBSCRIBE`.
    subscriptions: HashSet<Vec<u8>>,
    /// Shard channels tracked from `SSUBSCRIBE` / `SUNSUBSCRIBE`.
    s_subscriptions: HashSet<Vec<u8>>,
    /// Patterns tracked from `PSUBSCRIBE` / `PUNSUBSCRIBE`.
    p_subscriptions: HashSet<Vec<u8>>,
}
14
/// The kind of pub/sub request that changes the tracked subscriptions.
///
/// Each variant corresponds to one Redis command, matched case-insensitively
/// by `SubscriptionTracker::update_with_cmd`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum SubscriptionAction {
    /// `SUBSCRIBE`: add plain channels.
    Subscribe,
    /// `UNSUBSCRIBE`: remove plain channels.
    Unsubscribe,
    /// `PSUBSCRIBE`: add patterns.
    PSubscribe,
    /// `PUNSUBSCRIBE`: remove patterns.
    PUnsubscribe,
    /// `SSUBSCRIBE`: add shard channels.
    SSubscribe,
    /// `SUNSUBSCRIBE`: remove shard channels.
    Sunsubscribe,
}
23
24impl SubscriptionAction {
25    fn additive(&self) -> bool {
26        match self {
27            SubscriptionAction::Subscribe
28            | SubscriptionAction::PSubscribe
29            | SubscriptionAction::SSubscribe => true,
30
31            SubscriptionAction::Unsubscribe
32            | SubscriptionAction::PUnsubscribe
33            | SubscriptionAction::Sunsubscribe => false,
34        }
35    }
36}
37
38impl SubscriptionTracker {
39    pub(crate) fn update_with_request(
40        &mut self,
41        action: SubscriptionAction,
42        args: impl Iterator<Item = Vec<u8>>,
43    ) {
44        let set = match action {
45            SubscriptionAction::Subscribe | SubscriptionAction::Unsubscribe => {
46                &mut self.subscriptions
47            }
48            SubscriptionAction::PSubscribe | SubscriptionAction::PUnsubscribe => {
49                &mut self.p_subscriptions
50            }
51            SubscriptionAction::SSubscribe | SubscriptionAction::Sunsubscribe => {
52                &mut self.s_subscriptions
53            }
54        };
55
56        if action.additive() {
57            for sub in args {
58                set.insert(sub);
59            }
60        } else {
61            for sub in args {
62                set.remove(&sub);
63            }
64        }
65    }
66
67    pub(crate) fn update_with_cmd<'a>(&'a mut self, cmd: &'a Cmd) {
68        let mut args_iter = cmd.args_iter();
69        let first_arg = args_iter.next();
70
71        let Some(Arg::Simple(first_arg)) = first_arg else {
72            return;
73        };
74        let Ok(first_arg) = str::from_utf8(first_arg) else {
75            return;
76        };
77
78        let args = args_iter.filter_map(|arg| match arg {
79            Arg::Simple(arg) => Some(arg.to_vec()),
80            Arg::Cursor => None,
81        });
82
83        let action = if first_arg.eq_ignore_ascii_case("SUBSCRIBE") {
84            SubscriptionAction::Subscribe
85        } else if first_arg.eq_ignore_ascii_case("PSUBSCRIBE") {
86            SubscriptionAction::PSubscribe
87        } else if first_arg.eq_ignore_ascii_case("SSUBSCRIBE") {
88            SubscriptionAction::SSubscribe
89        } else if first_arg.eq_ignore_ascii_case("UNSUBSCRIBE") {
90            SubscriptionAction::Unsubscribe
91        } else if first_arg.eq_ignore_ascii_case("PUNSUBSCRIBE") {
92            SubscriptionAction::PUnsubscribe
93        } else if first_arg.eq_ignore_ascii_case("SUNSUBSCRIBE") {
94            SubscriptionAction::Sunsubscribe
95        } else {
96            return;
97        };
98        self.update_with_request(action, args);
99    }
100
101    pub(crate) fn update_with_pipeline<'a>(&'a mut self, pipe: &'a Pipeline) {
102        for cmd in pipe.cmd_iter() {
103            self.update_with_cmd(cmd);
104        }
105    }
106
107    pub(crate) fn get_subscription_pipeline(&self) -> Pipeline {
108        let mut pipeline = crate::pipe();
109        if !self.subscriptions.is_empty() {
110            let cmd = pipeline.cmd("SUBSCRIBE");
111            for channel in self.subscriptions.iter() {
112                cmd.arg(channel);
113            }
114        }
115        if !self.s_subscriptions.is_empty() {
116            let cmd = pipeline.cmd("SSUBSCRIBE");
117            for channel in self.s_subscriptions.iter() {
118                cmd.arg(channel);
119            }
120        }
121        if !self.p_subscriptions.is_empty() {
122            let cmd = pipeline.cmd("PSUBSCRIBE");
123            for channel in self.p_subscriptions.iter() {
124                cmd.arg(channel);
125            }
126        }
127
128        pipeline
129    }
130}
131
#[cfg(test)]
mod tests {
    use crate::{cmd, pipe};

    use super::*;

    /// Asserts that two pipelines pack to identical bytes.
    ///
    /// On mismatch, the actual pipeline is printed as text.
    fn assert_same_pipeline(actual: &Pipeline, expected: &Pipeline) {
        let actual_packed = actual.get_packed_pipeline();
        assert_eq!(
            actual_packed,
            expected.get_packed_pipeline(),
            "{}",
            String::from_utf8(actual.get_packed_pipeline()).unwrap()
        );
    }

    #[test]
    fn test_add_and_remove_subscriptions() {
        let mut tracked = SubscriptionTracker::default();

        tracked.update_with_cmd(cmd("subscribe").arg("foo").arg("bar"));
        tracked.update_with_cmd(cmd("PSUBSCRIBE").arg("fo*o").arg("b*ar"));
        tracked.update_with_cmd(cmd("SSUBSCRIBE").arg("sfoo").arg("sbar"));
        tracked.update_with_cmd(cmd("unsubscribe").arg("foo"));
        tracked.update_with_cmd(cmd("Punsubscribe").arg("b*ar"));
        tracked.update_with_cmd(cmd("Sunsubscribe").arg("sfoo").arg("SBAR"));
        // Commands unrelated to pub/sub must leave the tracker untouched.
        tracked.update_with_cmd(cmd("GET").arg("sfoo"));

        let mut want = pipe();
        want.cmd("SUBSCRIBE")
            .arg("bar")
            .cmd("SSUBSCRIBE")
            .arg("sbar")
            .cmd("PSUBSCRIBE")
            .arg("fo*o");
        assert_same_pipeline(&tracked.get_subscription_pipeline(), &want);
    }

    #[test]
    fn test_skip_empty_subscriptions() {
        let mut tracked = SubscriptionTracker::default();

        tracked.update_with_cmd(cmd("subscribe").arg("foo").arg("bar"));
        tracked.update_with_cmd(cmd("PSUBSCRIBE").arg("fo*o").arg("b*ar"));
        tracked.update_with_cmd(cmd("unsubscribe").arg("foo").arg("bar"));
        tracked.update_with_cmd(cmd("punsubscribe").arg("fo*o"));

        let mut want = pipe();
        want.cmd("PSUBSCRIBE").arg("b*ar");
        assert_same_pipeline(&tracked.get_subscription_pipeline(), &want);
    }

    #[test]
    fn test_add_and_remove_subscriptions_with_pipeline() {
        let mut tracked = SubscriptionTracker::default();

        let mut requests = pipe();
        requests
            .cmd("subscribe")
            .arg("foo")
            .arg("bar")
            .cmd("PSUBSCRIBE")
            .arg("fo*o")
            .arg("b*ar")
            .cmd("SSUBSCRIBE")
            .arg("sfoo")
            .arg("sbar")
            .cmd("unsubscribe")
            .arg("foo")
            .cmd("Punsubscribe")
            .arg("b*ar")
            .cmd("Sunsubscribe")
            .arg("sfoo")
            .arg("SBAR");
        tracked.update_with_pipeline(&requests);

        let mut want = pipe();
        want.cmd("SUBSCRIBE")
            .arg("bar")
            .cmd("SSUBSCRIBE")
            .arg("sbar")
            .cmd("PSUBSCRIBE")
            .arg("fo*o");
        assert_same_pipeline(&tracked.get_subscription_pipeline(), &want);
    }

    #[test]
    fn test_only_unsubscribe_from_existing_subscriptions() {
        let mut tracked = SubscriptionTracker::default();

        tracked.update_with_cmd(cmd("unsubscribe").arg("foo"));
        tracked.update_with_cmd(cmd("subscribe").arg("foo"));

        let mut want = pipe();
        want.cmd("SUBSCRIBE").arg("foo");
        assert_same_pipeline(&tracked.get_subscription_pipeline(), &want);
    }
}