redis/
subscription_tracker.rs1#![allow(dead_code)]
2
3use core::str;
4use std::collections::HashSet;
5
6use crate::{Arg, Cmd, Pipeline};
7
/// Tracks the channels, patterns and shard channels the connection is
/// currently subscribed to.
///
/// Each name is stored as the raw bytes that were sent, so lookups are
/// byte-for-byte (and therefore case-sensitive).
#[derive(Default)]
pub(crate) struct SubscriptionTracker {
    /// Channels added via `SUBSCRIBE`.
    subscriptions: HashSet<Vec<u8>>,
    /// Patterns added via `PSUBSCRIBE`.
    p_subscriptions: HashSet<Vec<u8>>,
    /// Shard channels added via `SSUBSCRIBE`.
    s_subscriptions: HashSet<Vec<u8>>,
}
14
/// The kind of subscription change a request makes.
///
/// Each variant corresponds to one Redis command name (matched
/// case-insensitively by the tracker). The subscribe-style variants add
/// names to a tracked set; the unsubscribe-style variants remove them.
///
/// The enum is fieldless, so it is cheap to copy and compare.
// NOTE(review): `Sunsubscribe` is spelled inconsistently with `PUnsubscribe`;
// it is kept as-is because other code in the crate may refer to it by name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum SubscriptionAction {
    /// `SUBSCRIBE`: add channels.
    Subscribe,
    /// `UNSUBSCRIBE`: remove channels.
    Unsubscribe,
    /// `PSUBSCRIBE`: add patterns.
    PSubscribe,
    /// `PUNSUBSCRIBE`: remove patterns.
    PUnsubscribe,
    /// `SSUBSCRIBE`: add shard channels.
    SSubscribe,
    /// `SUNSUBSCRIBE`: remove shard channels.
    Sunsubscribe,
}
23
24impl SubscriptionAction {
25 fn additive(&self) -> bool {
26 match self {
27 SubscriptionAction::Subscribe
28 | SubscriptionAction::PSubscribe
29 | SubscriptionAction::SSubscribe => true,
30
31 SubscriptionAction::Unsubscribe
32 | SubscriptionAction::PUnsubscribe
33 | SubscriptionAction::Sunsubscribe => false,
34 }
35 }
36}
37
38impl SubscriptionTracker {
39 pub(crate) fn update_with_request(
40 &mut self,
41 action: SubscriptionAction,
42 args: impl Iterator<Item = Vec<u8>>,
43 ) {
44 let set = match action {
45 SubscriptionAction::Subscribe | SubscriptionAction::Unsubscribe => {
46 &mut self.subscriptions
47 }
48 SubscriptionAction::PSubscribe | SubscriptionAction::PUnsubscribe => {
49 &mut self.p_subscriptions
50 }
51 SubscriptionAction::SSubscribe | SubscriptionAction::Sunsubscribe => {
52 &mut self.s_subscriptions
53 }
54 };
55
56 if action.additive() {
57 for sub in args {
58 set.insert(sub);
59 }
60 } else {
61 for sub in args {
62 set.remove(&sub);
63 }
64 }
65 }
66
67 pub(crate) fn update_with_cmd<'a>(&'a mut self, cmd: &'a Cmd) {
68 let mut args_iter = cmd.args_iter();
69 let first_arg = args_iter.next();
70
71 let Some(Arg::Simple(first_arg)) = first_arg else {
72 return;
73 };
74 let Ok(first_arg) = str::from_utf8(first_arg) else {
75 return;
76 };
77
78 let args = args_iter.filter_map(|arg| match arg {
79 Arg::Simple(arg) => Some(arg.to_vec()),
80 Arg::Cursor => None,
81 });
82
83 let action = if first_arg.eq_ignore_ascii_case("SUBSCRIBE") {
84 SubscriptionAction::Subscribe
85 } else if first_arg.eq_ignore_ascii_case("PSUBSCRIBE") {
86 SubscriptionAction::PSubscribe
87 } else if first_arg.eq_ignore_ascii_case("SSUBSCRIBE") {
88 SubscriptionAction::SSubscribe
89 } else if first_arg.eq_ignore_ascii_case("UNSUBSCRIBE") {
90 SubscriptionAction::Unsubscribe
91 } else if first_arg.eq_ignore_ascii_case("PUNSUBSCRIBE") {
92 SubscriptionAction::PUnsubscribe
93 } else if first_arg.eq_ignore_ascii_case("SUNSUBSCRIBE") {
94 SubscriptionAction::Sunsubscribe
95 } else {
96 return;
97 };
98 self.update_with_request(action, args);
99 }
100
101 pub(crate) fn update_with_pipeline<'a>(&'a mut self, pipe: &'a Pipeline) {
102 for cmd in pipe.cmd_iter() {
103 self.update_with_cmd(cmd);
104 }
105 }
106
107 pub(crate) fn get_subscription_pipeline(&self) -> Pipeline {
108 let mut pipeline = crate::pipe();
109 if !self.subscriptions.is_empty() {
110 let cmd = pipeline.cmd("SUBSCRIBE");
111 for channel in self.subscriptions.iter() {
112 cmd.arg(channel);
113 }
114 }
115 if !self.s_subscriptions.is_empty() {
116 let cmd = pipeline.cmd("SSUBSCRIBE");
117 for channel in self.s_subscriptions.iter() {
118 cmd.arg(channel);
119 }
120 }
121 if !self.p_subscriptions.is_empty() {
122 let cmd = pipeline.cmd("PSUBSCRIBE");
123 for channel in self.p_subscriptions.iter() {
124 cmd.arg(channel);
125 }
126 }
127
128 pipeline
129 }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{cmd, pipe};

    /// Asserts that `result` serializes to exactly the same bytes as
    /// `expected`. On mismatch, the panic message shows `result`'s
    /// serialization as text.
    fn assert_packed_eq(result: &crate::Pipeline, expected: &crate::Pipeline) {
        let actual = result.get_packed_pipeline();
        assert_eq!(
            actual,
            expected.get_packed_pipeline(),
            "{}",
            String::from_utf8(actual.clone()).unwrap()
        );
    }

    #[test]
    fn test_add_and_remove_subscriptions() {
        let mut tracker = SubscriptionTracker::default();

        tracker.update_with_cmd(cmd("subscribe").arg("foo").arg("bar"));
        tracker.update_with_cmd(cmd("PSUBSCRIBE").arg("fo*o").arg("b*ar"));
        tracker.update_with_cmd(cmd("SSUBSCRIBE").arg("sfoo").arg("sbar"));
        tracker.update_with_cmd(cmd("unsubscribe").arg("foo"));
        tracker.update_with_cmd(cmd("Punsubscribe").arg("b*ar"));
        // Names are case-sensitive: "SBAR" must not remove "sbar".
        tracker.update_with_cmd(cmd("Sunsubscribe").arg("sfoo").arg("SBAR"));
        // Non-subscription commands must leave the tracked state untouched.
        tracker.update_with_cmd(cmd("GET").arg("sfoo"));

        let mut expected = pipe();
        expected
            .cmd("SUBSCRIBE")
            .arg("bar")
            .cmd("SSUBSCRIBE")
            .arg("sbar")
            .cmd("PSUBSCRIBE")
            .arg("fo*o");
        assert_packed_eq(&tracker.get_subscription_pipeline(), &expected);
    }

    #[test]
    fn test_skip_empty_subscriptions() {
        let mut tracker = SubscriptionTracker::default();

        tracker.update_with_cmd(cmd("subscribe").arg("foo").arg("bar"));
        tracker.update_with_cmd(cmd("PSUBSCRIBE").arg("fo*o").arg("b*ar"));
        tracker.update_with_cmd(cmd("unsubscribe").arg("foo").arg("bar"));
        tracker.update_with_cmd(cmd("punsubscribe").arg("fo*o"));

        // Only the pattern set still has entries, so only PSUBSCRIBE is emitted.
        let mut expected = pipe();
        expected.cmd("PSUBSCRIBE").arg("b*ar");
        assert_packed_eq(&tracker.get_subscription_pipeline(), &expected);
    }

    #[test]
    fn test_add_and_remove_subscriptions_with_pipeline() {
        let mut tracker = SubscriptionTracker::default();

        let mut input = pipe();
        input
            .cmd("subscribe")
            .arg("foo")
            .arg("bar")
            .cmd("PSUBSCRIBE")
            .arg("fo*o")
            .arg("b*ar")
            .cmd("SSUBSCRIBE")
            .arg("sfoo")
            .arg("sbar")
            .cmd("unsubscribe")
            .arg("foo")
            .cmd("Punsubscribe")
            .arg("b*ar")
            .cmd("Sunsubscribe")
            .arg("sfoo")
            .arg("SBAR");
        tracker.update_with_pipeline(&input);

        let mut expected = pipe();
        expected
            .cmd("SUBSCRIBE")
            .arg("bar")
            .cmd("SSUBSCRIBE")
            .arg("sbar")
            .cmd("PSUBSCRIBE")
            .arg("fo*o");
        assert_packed_eq(&tracker.get_subscription_pipeline(), &expected);
    }

    #[test]
    fn test_only_unsubscribe_from_existing_subscriptions() {
        let mut tracker = SubscriptionTracker::default();

        // Unsubscribing from an untracked name is a no-op and must not block
        // a later subscription to it.
        tracker.update_with_cmd(cmd("unsubscribe").arg("foo"));
        tracker.update_with_cmd(cmd("subscribe").arg("foo"));

        let mut expected = pipe();
        expected.cmd("SUBSCRIBE").arg("foo");
        assert_packed_eq(&tracker.get_subscription_pipeline(), &expected);
    }
}