wasmtime/compile/
stratify.rs

//! Stratification of call graphs for parallel bottom-up inlining.
//!
//! This module takes a call graph and constructs a `Strata`, which is
//! essentially a parallel execution plan. A `Strata` consists of an ordered
//! sequence of layers, and each layer is an unordered set of functions. The
//! `i`th layer must be processed before the `i + 1`th layer, but functions
//! within the same layer may be processed in any order (and in parallel).
//!
//! For example, when given the following tree-like call graph:
//!
//! ```text
//! +---+   +---+   +---+
//! | a |-->| b |-->| c |
//! +---+   +---+   +---+
//!   |       |
//!   |       |     +---+
//!   |       '---->| d |
//!   |             +---+
//!   |
//!   |     +---+   +---+
//!   '---->| e |-->| f |
//!         +---+   +---+
//!           |
//!           |     +---+
//!           '---->| g |
//!                 +---+
//! ```
//!
//! then stratification will produce these layers:
//!
//! ```text
//! [
//!     {c, d, f, g},
//!     {b, e},
//!     {a},
//! ]
//! ```
//!
//! Our goal in constructing the layers is to maximize potential parallelism at
//! each layer. Logically, we do this by finding the strongly-connected
//! components (SCCs) of the input call graph and peeling off all of the leaves
//! of the SCCs' condensation (i.e. the DAG that the SCCs form; see the
//! documentation for the `StronglyConnectedComponents::evaporation` method for
//! details). These leaves become the strata's first layer. The layer's
//! components are removed from the condensation graph, and we repeat the
//! process, so that the condensation's new leaves become the strata's second
//! layer, and so on, until the condensation graph is empty and all components
//! have been processed. In practice we don't actually mutate the condensation
//! graph or remove its nodes; instead we count how many unprocessed
//! dependencies each component has, and a component is ready for inclusion in a
//! layer once its unprocessed-dependencies count reaches zero.
52
53use super::{
54    call_graph::CallGraph,
55    scc::{Scc, StronglyConnectedComponents},
56    *,
57};
58use std::{fmt::Debug, ops::Range};
59use wasmtime_environ::{EntityRef, SecondaryMap};
60
/// The result of stratifying a call graph: an ordered list of layers that
/// serves as a parallel-execution plan for bottom-up inlining.
///
/// Layers are stored flattened: the elements of every layer live in
/// `layer_elems`, and each entry of `layers` is the index range that one layer
/// occupies within that vector.
///
/// See the module doc comment for more details.
pub struct Strata<Node> {
    /// One index range into `layer_elems` per layer, in processing order.
    layers: Vec<Range<u32>>,
    /// The elements of all layers, concatenated in layer order.
    layer_elems: Vec<Node>,
}
69
70impl<Node: Debug> Debug for Strata<Node> {
71    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
72        struct Layers<'a, Node>(&'a Strata<Node>);
73
74        impl<'a, Node: Debug> Debug for Layers<'a, Node> {
75            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
76                let mut f = f.debug_list();
77                for layer in self.0.layers() {
78                    f.entry(&layer);
79                }
80                f.finish()
81            }
82        }
83
84        f.debug_struct("Strata")
85            .field("layers", &Layers(self))
86            .finish()
87    }
88}
89
90impl<Node> Strata<Node> {
91    /// Stratify the given call graph, yielding a `Strata` parallel-execution
92    /// plan.
93    pub fn new(nodes: impl IntoIterator<Item = Node>, call_graph: &CallGraph<Node>) -> Self
94    where
95        Node: EntityRef + Debug,
96    {
97        log::trace!("Stratifying {call_graph:#?}");
98
99        let components =
100            StronglyConnectedComponents::new(nodes, |node| call_graph.edges(node).iter().copied());
101        let evaporation = components.evaporation(|node| call_graph.edges(node).iter().copied());
102
103        // A map from each component to the count of how many call-graph
104        // dependencies to other components it has that have not been fulfilled
105        // yet. These counts are decremented as we assign a component's dependencies
106        // to layers.
107        let mut unfulfilled_deps_count = SecondaryMap::<Scc, u32>::with_capacity(components.len());
108        for to_component in components.keys() {
109            for from_component in evaporation.reverse_edges(to_component) {
110                unfulfilled_deps_count[*from_component] += 1;
111            }
112        }
113
114        // Build the strata.
115        //
116        // The first layer is formed by searching through all components for those
117        // that have a zero unfulfilled-deps count. When we finish a layer, we
118        // iterate over each of component in that layer and decrement the
119        // unfulfilled-deps count of every other component that depends on the
120        // newly-assigned-to-a-layer component. Any component that then reaches a
121        // zero unfulfilled-dep count is added to the next layer. This proceeds to a
122        // fixed point, similarly to GC tracing and ref-count decrementing.
123
124        let mut layers: Vec<Range<u32>> = vec![];
125        let mut layer_elems: Vec<Node> = Vec::with_capacity(call_graph.nodes().len());
126
127        let mut current_layer: Vec<Scc> = components
128            .keys()
129            .filter(|scc| unfulfilled_deps_count[*scc] == 0)
130            .collect();
131        debug_assert!(
132            !current_layer.is_empty() || call_graph.nodes().len() == 0,
133            "the first layer can only be empty when the call graph itself is empty"
134        );
135
136        let mut next_layer = vec![];
137
138        while !current_layer.is_empty() {
139            debug_assert!(next_layer.is_empty());
140
141            for dependee in &current_layer {
142                for depender in evaporation.reverse_edges(*dependee) {
143                    debug_assert!(unfulfilled_deps_count[*depender] > 0);
144                    unfulfilled_deps_count[*depender] -= 1;
145                    if unfulfilled_deps_count[*depender] == 0 {
146                        next_layer.push(*depender);
147                    }
148                }
149            }
150
151            layers.push(extend_with_range(
152                &mut layer_elems,
153                current_layer
154                    .drain(..)
155                    .flat_map(|scc| components.nodes(scc).iter().copied()),
156            ));
157
158            std::mem::swap(&mut next_layer, &mut current_layer);
159        }
160
161        debug_assert!(
162            unfulfilled_deps_count.values().all(|c| *c == 0),
163            "after every component is assigned to a layer, all dependencies should be fulfilled"
164        );
165
166        let result = Strata {
167            layers,
168            layer_elems,
169        };
170        log::trace!("  -> {result:#?}");
171        result
172    }
173
174    /// Iterate over the layers of this `Strata`.
175    ///
176    /// The `i`th layer must be processed before the `i + 1`th layer, but the
177    /// functions within a layer may be processed in any order and in parallel.
178    pub fn layers(&self) -> impl ExactSizeIterator<Item = &[Node]> {
179        self.layers.iter().map(|range| {
180            let start = usize::try_from(range.start).unwrap();
181            let end = usize::try_from(range.end).unwrap();
182            &self.layer_elems[start..end]
183        })
184    }
185}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal function identifier, used as the call-graph node type in
    /// these tests.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct Function(u32);
    wasmtime_environ::entity_impl!(Function);

    /// A small builder for test call graphs: it records which functions exist
    /// and whom they call, and can check the resulting stratification.
    struct Functions {
        /// The callees of each defined function, in the order the calls were
        /// defined.
        calls: SecondaryMap<Function, Vec<Function>>,
    }

    impl Default for Functions {
        fn default() -> Self {
            // Best-effort logger setup so that `log::trace!` output can be seen
            // when debugging a test. The result is deliberately ignored.
            let _ = env_logger::try_init();
            Self {
                calls: Default::default(),
            }
        }
    }

    impl Functions {
        /// Ensure that function `f` is defined, without adding any calls.
        fn define_func(&mut self, f: u32) -> &mut Self {
            let f = Function::from_u32(f);
            if self.calls.get(f).is_none() {
                self.calls[f] = vec![];
            }
            self
        }

        /// Record a call from `caller` to `callee`, defining both functions if
        /// necessary.
        fn define_call(&mut self, caller: u32, callee: u32) -> &mut Self {
            self.define_func(caller);
            self.define_func(callee);
            let caller = Function::from_u32(caller);
            let callee = Function::from_u32(callee);
            self.calls[caller].push(callee);
            self
        }

        /// Record a call from `caller` to each of `callees`.
        fn define_calls(
            &mut self,
            caller: u32,
            callees: impl IntoIterator<Item = u32>,
        ) -> &mut Self {
            for callee in callees {
                self.define_call(caller, callee);
            }
            self
        }

        /// Build a `CallGraph` from the recorded calls and stratify it.
        fn stratify(&self) -> Strata<Function> {
            let call_graph = CallGraph::new(self.calls.keys(), |f, calls| {
                calls.extend_from_slice(&self.calls[f]);
                Ok(())
            })
            .unwrap();
            Strata::<Function>::new(self.calls.keys(), &call_graph)
        }

        /// Assert that stratifying the recorded call graph yields exactly the
        /// `expected` layers, in order.
        ///
        /// The order of functions within a layer is not significant, so both
        /// the expected and the actual layers are sorted before comparing.
        fn assert_stratification(&self, mut expected: Vec<Vec<u32>>) {
            for layer in &mut expected {
                layer.sort();
            }
            log::trace!("expected stratification = {expected:?}");

            let actual = self
                .stratify()
                .layers()
                .map(|layer| {
                    let mut layer = layer.iter().map(|f| f.as_u32()).collect::<Vec<_>>();
                    layer.sort();
                    layer
                })
                .collect::<Vec<_>>();
            log::trace!("actual stratification = {actual:?}");

            // Check the layer count first: `zip` below would otherwise
            // silently ignore any extra layers on either side.
            assert_eq!(expected.len(), actual.len());
            for (expected, actual) in expected.into_iter().zip(actual) {
                log::trace!("expected layer = {expected:?}");
                log::trace!("  actual layer = {actual:?}");

                assert_eq!(expected, actual);
            }
        }
    }

    #[test]
    fn test_disconnected_functions() {
        // +---+   +---+   +---+
        // | 0 |   | 1 |   | 2 |
        // +---+   +---+   +---+
        Functions::default()
            .define_func(0)
            .define_func(1)
            .define_func(2)
            .assert_stratification(vec![vec![0, 1, 2]]);
    }

    #[test]
    fn test_chained_functions() {
        // +---+   +---+   +---+
        // | 0 |-->| 1 |-->| 2 |
        // +---+   +---+   +---+
        Functions::default()
            .define_call(0, 1)
            .define_call(1, 2)
            .assert_stratification(vec![vec![2], vec![1], vec![0]]);
    }

    #[test]
    fn test_cycle() {
        //   ,---------------.
        //   V               |
        // +---+   +---+   +---+
        // | 0 |-->| 1 |-->| 2 |
        // +---+   +---+   +---+
        Functions::default()
            .define_call(0, 1)
            .define_call(1, 2)
            .define_call(2, 0)
            .assert_stratification(vec![vec![0, 1, 2]]);
    }

    #[test]
    fn test_tree() {
        // +---+   +---+   +---+
        // | 0 |-->| 1 |-->| 2 |
        // +---+   +---+   +---+
        //   |       |
        //   |       |     +---+
        //   |       '---->| 3 |
        //   |             +---+
        //   |
        //   |     +---+   +---+
        //   '---->| 4 |-->| 5 |
        //         +---+   +---+
        //           |
        //           |     +---+
        //           '---->| 6 |
        //                 +---+
        Functions::default()
            .define_calls(0, [1, 4])
            .define_calls(1, [2, 3])
            .define_calls(4, [5, 6])
            .assert_stratification(vec![vec![2, 3, 5, 6], vec![1, 4], vec![0]]);
    }

    #[test]
    fn test_chain_of_cycles() {
        //   ,-----.
        //   |     |
        //   V     |
        // +---+   |
        // | 0 |---'
        // +---+
        //   |
        //   V
        // +---+    +---+
        // | 1 |<-->| 2 |
        // +---+    +---+
        //  |
        //  | ,----------------.
        //  | |                |
        //  V |                V
        // +---+    +---+    +---+
        // | 3 |<---| 4 |<---| 5 |
        // +---+    +---+    +---+
        Functions::default()
            .define_calls(0, [0, 1])
            .define_calls(1, [2, 3])
            .define_calls(2, [1])
            .define_calls(3, [5])
            .define_calls(4, [3])
            .define_calls(5, [4])
            .assert_stratification(vec![vec![3, 4, 5], vec![1, 2], vec![0]]);
    }

    #[test]
    fn test_multiple_edges_to_same_component() {
        // +---+           +---+
        // | 0 |           | 1 |
        // +---+           +---+
        //   ^               ^
        //   |               |
        //   V               V
        // +---+           +---+
        // | 2 |           | 3 |
        // +---+           +---+
        //   |               |
        //   `------. ,------'
        //          | |
        //          V V
        //         +---+
        //         | 4 |
        //         +---+
        //           ^
        //           |
        //           V
        //         +---+
        //         | 5 |
        //         +---+
        Functions::default()
            .define_calls(0, [2])
            .define_calls(1, [3])
            .define_calls(2, [0, 4])
            .define_calls(3, [1, 4])
            .define_calls(4, [5])
            .define_calls(5, [4])
            .assert_stratification(vec![vec![4, 5], vec![0, 1, 2, 3]]);
    }
}