1#![no_std]
9
10#[cfg(any(feature = "std", unix, windows))]
11#[macro_use]
12extern crate std;
13extern crate alloc;
14
15use alloc::boxed::Box;
16use anyhow::Error;
17use core::cell::Cell;
18use core::marker::PhantomData;
19use core::ops::Range;
20
// Select the platform backend and expose it as `imp`. The public types below
// (`FiberStack`, `Fiber`, `Suspend`, `Result`) wrap or delegate to the
// corresponding `imp` items.
cfg_if::cfg_if! {
    if #[cfg(not(feature = "std"))] {
        // Without the `std` feature, use the `nostd` backend.
        mod nostd;
        use nostd as imp;
        // NOTE(review): `stackswitch` is also built for the unix backend;
        // presumably it holds the shared stack-switching code — confirm.
        mod stackswitch;
    } else if #[cfg(miri)] {
        // Under Miri, use the dedicated `miri` backend.
        mod miri;
        use miri as imp;
    } else if #[cfg(windows)] {
        mod windows;
        use windows as imp;
    } else if #[cfg(unix)] {
        mod unix;
        use unix as imp;
        mod stackswitch;
    } else {
        // No backend exists for any other configuration.
        compile_error!("fibers are not supported on this platform");
    }
}
40
/// A stack on which a [`Fiber`] executes.
///
/// This is a thin wrapper around the platform backend's stack type.
pub struct FiberStack(imp::FiberStack);
43
44fn _assert_send_sync() {
45 fn _assert_send<T: Send>() {}
46 fn _assert_sync<T: Sync>() {}
47
48 _assert_send::<FiberStack>();
49 _assert_sync::<FiberStack>();
50}
51
/// A `Result` whose error type defaults to the platform backend's error
/// (`imp::Error`).
pub type Result<T, E = imp::Error> = core::result::Result<T, E>;
53
impl FiberStack {
    /// Creates a new fiber stack of `size`.
    ///
    /// NOTE(review): `size` is presumably a byte count that the backend may
    /// round or adjust (the tests pass sizes as small as 0 and 1). `zeroed`
    /// is forwarded to the backend unchanged and presumably requests
    /// zero-initialized memory. Confirm both against `imp`.
    ///
    /// # Errors
    ///
    /// Returns the backend's error if the stack cannot be created.
    pub fn new(size: usize, zeroed: bool) -> Result<Self> {
        Ok(Self(imp::FiberStack::new(size, zeroed)?))
    }

    /// Creates a fiber stack backed by a user-provided [`RuntimeFiberStack`].
    ///
    /// # Errors
    ///
    /// Returns the backend's error if the custom stack cannot be adopted.
    pub fn from_custom(custom: Box<dyn RuntimeFiberStack>) -> Result<Self> {
        Ok(Self(imp::FiberStack::from_custom(custom)?))
    }

    /// Creates a fiber stack from an existing memory region.
    ///
    /// # Safety
    ///
    /// NOTE(review): the exact contract is defined by the backend and is not
    /// visible here. Presumably the caller must guarantee all of the following
    /// (confirm against `imp::FiberStack::from_raw_parts`):
    /// - `bottom` points to the lowest address of a region of `len` bytes.
    /// - That region stays valid and otherwise unused for the lifetime of the
    ///   returned stack.
    /// - The first `guard_size` bytes of the region are reserved as a guard.
    ///
    /// # Errors
    ///
    /// Returns the backend's error if the region cannot be used as a stack.
    pub unsafe fn from_raw_parts(bottom: *mut u8, guard_size: usize, len: usize) -> Result<Self> {
        Ok(Self(unsafe {
            imp::FiberStack::from_raw_parts(bottom, guard_size, len)?
        }))
    }

    /// Returns the top of the stack, as reported by the backend.
    ///
    /// NOTE(review): `None` presumably means the backend does not expose this
    /// address — confirm.
    pub fn top(&self) -> Option<*mut u8> {
        self.0.top()
    }

    /// Returns the address range of the stack as integers, if the backend
    /// reports one.
    pub fn range(&self) -> Option<Range<usize>> {
        self.0.range()
    }

    /// Returns whether the backend reports this stack as coming from raw parts.
    ///
    /// Presumably this is `true` for stacks built with
    /// [`FiberStack::from_raw_parts`].
    pub fn is_from_raw_parts(&self) -> bool {
        self.0.is_from_raw_parts()
    }

    /// Returns the address range of the stack's guard region, if the backend
    /// reports one.
    pub fn guard_range(&self) -> Option<Range<*mut u8>> {
        self.0.guard_range()
    }
}
109
/// A factory that produces custom [`RuntimeFiberStack`]s.
///
/// # Safety
///
/// NOTE(review): the full contract is not visible in this file. The stacks
/// returned here are used to run fibers, so implementors must presumably:
/// - hand out memory that is valid and exclusively owned by the returned
///   stack;
/// - size that memory according to the `size` request.
///
/// Confirm against the backend.
pub unsafe trait RuntimeFiberStackCreator: Send + Sync {
    /// Creates a new custom stack for a request of `size`.
    ///
    /// `zeroed` presumably requests zero-filled memory — TODO confirm.
    ///
    /// # Errors
    ///
    /// Returns an [`anyhow::Error`] if the stack could not be created.
    fn new_stack(&self, size: usize, zeroed: bool) -> Result<Box<dyn RuntimeFiberStack>, Error>;
}
119
/// A user-provided fiber stack, usable through [`FiberStack::from_custom`].
///
/// # Safety
///
/// NOTE(review): the full contract is not visible in this file. Implementors
/// must presumably report addresses that describe memory valid for as long as
/// the stack is alive — confirm against the backend.
pub unsafe trait RuntimeFiberStack: Send + Sync {
    /// Returns the top of the stack.
    ///
    /// Presumably this is the highest address, since stacks grow downward on
    /// the supported targets — TODO confirm.
    fn top(&self) -> *mut u8;
    /// Returns the address range of the stack as integers.
    fn range(&self) -> Range<usize>;
    /// Returns the address range of the stack's guard region.
    fn guard_range(&self) -> Range<*mut u8>;
}
129
/// A function running on its own [`FiberStack`] that can suspend itself and be
/// resumed.
///
/// The type parameters describe the values that flow through the fiber:
/// - `Resume` is passed in by [`Fiber::resume`].
/// - `Yield` is handed back through [`Suspend::suspend`].
/// - `Return` is the fiber function's final result.
/// - `'a` bounds the closure the fiber runs.
pub struct Fiber<'a, Resume, Yield, Return> {
    // Stack the fiber runs on; `Some` until `into_stack` takes it.
    stack: Option<FiberStack>,
    // Platform backend state for this fiber.
    inner: imp::Fiber,
    // `true` while the fiber is running and after it returned or panicked;
    // reset to `false` each time the fiber yields.
    done: Cell<bool>,
    // Ties the fiber to `'a` and the value types without storing them.
    _phantom: PhantomData<&'a (Resume, Yield, Return)>,
}
136
/// Handle passed to a fiber's function.
///
/// The function uses it to suspend the fiber and hand a `Yield` value back to
/// the caller of [`Fiber::resume`].
pub struct Suspend<Resume, Yield, Return> {
    // Platform backend handle used to switch away from the fiber.
    inner: imp::Suspend,
    // Carries the value types without storing them.
    _phantom: PhantomData<(Resume, Yield, Return)>,
}
141
/// Value exchanged between the resumer and the fiber across a context switch.
enum RunResult<Resume, Yield, Return> {
    /// Placeholder state.
    ///
    /// NOTE(review): presumably stored by the backend while the fiber runs.
    /// `Fiber::resume` treats seeing it after a switch as unreachable.
    Executing,

    /// A value being passed into the fiber by `Fiber::resume`.
    Resuming(Resume),

    /// A value the fiber handed back via `Suspend::suspend`.
    Yield(Yield),

    /// The fiber's function finished with this value.
    Returned(Return),

    /// The fiber's function panicked; holds the caught panic payload.
    #[cfg(feature = "std")]
    Panicked(Box<dyn core::any::Any + Send>),
}
167
168impl<'a, Resume, Yield, Return> Fiber<'a, Resume, Yield, Return> {
169 pub fn new(
175 stack: FiberStack,
176 func: impl FnOnce(Resume, &mut Suspend<Resume, Yield, Return>) -> Return + 'a,
177 ) -> Result<Self> {
178 let inner = imp::Fiber::new(&stack.0, func)?;
179
180 Ok(Self {
181 stack: Some(stack),
182 inner,
183 done: Cell::new(false),
184 _phantom: PhantomData,
185 })
186 }
187
188 pub fn resume(&self, val: Resume) -> Result<Return, Yield> {
204 assert!(!self.done.replace(true), "cannot resume a finished fiber");
205 let result = Cell::new(RunResult::Resuming(val));
206 self.inner.resume(&self.stack().0, &result);
207 match result.into_inner() {
208 RunResult::Resuming(_) | RunResult::Executing => unreachable!(),
209 RunResult::Yield(y) => {
210 self.done.set(false);
211 Err(y)
212 }
213 RunResult::Returned(r) => Ok(r),
214 #[cfg(feature = "std")]
215 RunResult::Panicked(_payload) => {
216 use std::panic;
217 panic::resume_unwind(_payload);
218 }
219 }
220 }
221
222 pub fn done(&self) -> bool {
224 self.done.get()
225 }
226
227 pub fn stack(&self) -> &FiberStack {
229 self.stack.as_ref().unwrap()
230 }
231
232 pub fn into_stack(mut self) -> FiberStack {
234 assert!(self.done());
235 self.stack.take().unwrap()
236 }
237}
238
239impl<Resume, Yield, Return> Suspend<Resume, Yield, Return> {
240 pub fn suspend(&mut self, value: Yield) -> Resume {
250 self.inner
251 .switch::<Resume, Yield, Return>(RunResult::Yield(value))
252 }
253
254 fn execute(
255 inner: imp::Suspend,
256 initial: Resume,
257 func: impl FnOnce(Resume, &mut Suspend<Resume, Yield, Return>) -> Return,
258 ) {
259 let mut suspend = Suspend {
260 inner,
261 _phantom: PhantomData,
262 };
263
264 #[cfg(feature = "std")]
265 let result = {
266 use std::panic::{self, AssertUnwindSafe};
267 let result = panic::catch_unwind(AssertUnwindSafe(|| (func)(initial, &mut suspend)));
268 match result {
269 Ok(result) => RunResult::Returned(result),
270 Err(panic) => RunResult::Panicked(panic),
271 }
272 };
273
274 #[cfg(not(feature = "std"))]
279 let result = RunResult::Returned((func)(initial, &mut suspend));
280
281 suspend.inner.exit::<Resume, Yield, Return>(result);
282 }
283}
284
// Releases the backend fiber state. A fiber is expected to have run to
// completion before it is dropped; this is only checked by a debug assertion.
impl<A, B, C> Drop for Fiber<'_, A, B, C> {
    fn drop(&mut self) {
        debug_assert!(self.done.get(), "fiber dropped without finishing");
        // SAFETY: NOTE(review): the contract of `imp::Fiber::drop` is not
        // visible here. This call satisfies the following, which is presumably
        // what that contract requires:
        // - It happens exactly once, from `Drop`.
        // - It uses the same type parameters the fiber was created with in
        //   `Fiber::new`.
        // - `self.inner` is never used afterwards.
        unsafe {
            self.inner.drop::<A, B, C>();
        }
    }
}
293
#[cfg(test)]
mod tests {
    use super::{Fiber, FiberStack};
    use alloc::string::ToString;
    use std::cell::Cell;
    use std::rc::Rc;

    /// Creates a non-zeroed stack of `size`, panicking if creation fails.
    fn fiber_stack(size: usize) -> FiberStack {
        FiberStack::new(size, false).unwrap()
    }

    /// Fibers created with degenerate stack sizes must still run to completion.
    #[test]
    fn small_stacks() {
        Fiber::<(), (), ()>::new(fiber_stack(0), |_, _| {})
            .unwrap()
            .resume(())
            .unwrap();
        Fiber::<(), (), ()>::new(fiber_stack(1), |_, _| {})
            .unwrap()
            .resume(())
            .unwrap();
    }

    /// The fiber body runs only once the fiber is first resumed.
    #[test]
    fn smoke() {
        let hit = Rc::new(Cell::new(false));
        let hit2 = hit.clone();
        let fiber = Fiber::<(), (), ()>::new(fiber_stack(1024 * 1024), move |_, _| {
            hit2.set(true);
        })
        .unwrap();
        assert!(!hit.get());
        fiber.resume(()).unwrap();
        assert!(hit.get());
    }

    /// Each `suspend` surfaces as `Err` from `resume`; returning surfaces as `Ok`.
    #[test]
    fn suspend_and_resume() {
        let hit = Rc::new(Cell::new(false));
        let hit2 = hit.clone();
        let fiber = Fiber::<(), (), ()>::new(fiber_stack(1024 * 1024), move |_, s| {
            s.suspend(());
            hit2.set(true);
            s.suspend(());
        })
        .unwrap();
        assert!(!hit.get());
        assert!(fiber.resume(()).is_err());
        assert!(!hit.get());
        assert!(fiber.resume(()).is_err());
        assert!(hit.get());
        assert!(fiber.resume(()).is_ok());
        assert!(hit.get());
    }

    /// Backtraces captured on the fiber stack should include the host frame
    /// that resumed the fiber.
    #[test]
    fn backtrace_traces_to_host() {
        // Kept out of line so that its frame, searched for by name below,
        // appears in the backtrace.
        #[inline(never)]
        fn look_for_me() {
            run_test();
        }
        fn assert_contains_host() {
            let trace = backtrace::Backtrace::new();
            println!("{trace:?}");
            assert!(
                trace
                    .frames()
                    .iter()
                    .flat_map(|f| f.symbols())
                    .filter_map(|s| Some(s.name()?.to_string()))
                    .any(|s| s.contains("look_for_me"))
                    // The check is waived on the configurations below.
                    // NOTE(review): presumably backtraces do not cross the
                    // fiber stack switch there — confirm.
                    || cfg!(windows)
                    || cfg!(all(target_os = "macos", target_arch = "aarch64"))
                    || cfg!(target_arch = "arm")
                    || cfg!(asan)
                    || cfg!(miri)
            );
        }

        fn run_test() {
            let fiber = Fiber::<(), (), ()>::new(fiber_stack(1024 * 1024), move |(), s| {
                assert_contains_host();
                s.suspend(());
                assert_contains_host();
                s.suspend(());
                assert_contains_host();
            })
            .unwrap();
            assert!(fiber.resume(()).is_err());
            assert!(fiber.resume(()).is_err());
            assert!(fiber.resume(()).is_ok());
        }

        look_for_me();
    }

    /// A panic inside the fiber propagates out of `resume`, and values owned
    /// by the fiber closure are dropped.
    #[test]
    #[cfg(feature = "std")]
    fn panics_propagated() {
        use std::panic::{self, AssertUnwindSafe};

        let a = Rc::new(Cell::new(false));
        let b = SetOnDrop(a.clone());
        let fiber = Fiber::<(), (), ()>::new(fiber_stack(1024 * 1024), move |(), _s| {
            // Force the closure to capture (and therefore own) `b`.
            let _ = &b;
            panic!();
        })
        .unwrap();
        assert!(panic::catch_unwind(AssertUnwindSafe(|| fiber.resume(()))).is_err());
        assert!(a.get());

        /// Sets the shared flag when dropped.
        struct SetOnDrop(Rc<Cell<bool>>);

        impl Drop for SetOnDrop {
            fn drop(&mut self) {
                self.0.set(true);
            }
        }
    }

    /// Values flow in through `resume` and out through `suspend` and the
    /// function's return value.
    #[test]
    fn suspend_and_resume_values() {
        let fiber = Fiber::new(fiber_stack(1024 * 1024), move |first, s| {
            assert_eq!(first, 2.0);
            assert_eq!(s.suspend(4), 3.0);
            "hello".to_string()
        })
        .unwrap();
        assert_eq!(fiber.resume(2.0), Err(4));
        assert_eq!(fiber.resume(3.0), Ok("hello".to_string()));
    }
}
430 }
431}