cranelift_codegen/
unionfind.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
//! Simple union-find data structure.

use crate::trace;
use cranelift_entity::{packed_option::ReservedValue, EntityRef, SecondaryMap};
use std::hash::Hash;
use std::mem::swap;

/// A union-find (disjoint-set) data structure over `Idx`s, each of which
/// identifies an eclass. Callers register `Idx`s with `add` (each starting out
/// as its own eclass) and can then merge eclasses together with `union`.
///
/// Running `union(a, b)` will change the canonical `Idx` of `a` or `b`.
/// Usually, this is chosen based on what will minimize path lengths (union by
/// rank), but it is also possible to _pin_ an eclass, such that its canonical
/// `Idx` won't change unless it gets unioned with another pinned eclass.
///
/// In the context of the egraph pass, merging two pinned eclasses
/// is very unlikely to happen – we do not know a single concrete test case
/// where it does. The only situation where it might happen looks as follows:
///
/// 1. We encounter terms `A` and `B`, and the optimizer does not find any
///    reason to union them together.
/// 2. We encounter a term `C`, and we rewrite `C -> A`, and separately, `C -> B`.
///
/// Unless `C` somehow includes some crucial hint without which it is hard to
/// notice that `A = B`, there's probably a rewrite rule that we should add.
///
/// Worst case, if we do merge two pinned eclasses, some nodes will essentially
/// disappear from the GVN map, which only affects the quality of the generated
/// code.
#[derive(Clone, Debug, PartialEq)]
pub struct UnionFind<Idx: EntityRef> {
    /// Parent pointer of each `Idx`. A root (the canonical representative of
    /// its eclass) is its own parent; entries never passed to `add` hold the
    /// `Val` default, i.e. `Idx::reserved_value()`.
    parent: SecondaryMap<Idx, Val<Idx>>,
    /// The `rank` table is used to perform the union operations optimally,
    /// without creating unnecessarily long paths. Pins are represented by
    /// eclasses with a rank of `u8::MAX`.
    ///
    /// `rank[x]` is an upper bound on the height of the subtree rooted at `x`.
    /// The subtree is guaranteed to have at least `2**rank[x]` elements,
    /// unless `rank` has been artificially inflated by pinning. Ranks are only
    /// written for roots; once `x` is merged under another root, `rank[x]` is
    /// no longer maintained.
    rank: SecondaryMap<Idx, u8>,

    /// Number of times `union` merged two pinned eclasses, which breaks one of
    /// the pins. Intended to be surfaced in the pass statistics.
    pub(crate) pinned_union_count: u64,
}

/// A parent pointer, as stored in `UnionFind::parent`.
///
/// This newtype exists so that its `Default` can be `Idx::reserved_value()`:
/// presumably that is what `SecondaryMap` yields for ids that were never
/// passed to `UnionFind::add`, which lets those ids be told apart from
/// registered ones (`add` rejects the reserved value).
#[derive(Clone, Debug, PartialEq)]
struct Val<Idx>(Idx);

impl<Idx: EntityRef + ReservedValue> Default for Val<Idx> {
    fn default() -> Self {
        Self(Idx::reserved_value())
    }
}

impl<Idx> UnionFind<Idx>
where
    Idx: EntityRef + Hash + std::fmt::Display + Ord + ReservedValue,
{
    /// Creates an empty `UnionFind` whose internal tables are pre-sized for
    /// `cap` entries.
    pub fn with_capacity(cap: usize) -> Self {
        let parent = SecondaryMap::with_capacity(cap);
        let rank = SecondaryMap::with_capacity(cap);
        Self {
            parent,
            rank,
            pinned_union_count: 0,
        }
    }

    /// Registers `id` as a singleton eclass: it becomes its own parent.
    ///
    /// Every `Idx` must be registered this way before it is queried
    /// (`find`, `find_and_update`, `pin_index`) or passed to `union`.
    ///
    /// In debug builds, panics if `id` is the reserved sentinel value.
    pub fn add(&mut self, id: Idx) {
        debug_assert!(Idx::reserved_value() != id);
        self.parent[id] = Val(id);
    }

    /// Returns the canonical representative of the eclass containing `node`
    /// by following parent pointers up to the root. The structure is left
    /// unchanged (no path compression); see `find_and_update` for that.
    pub fn find(&self, mut node: Idx) -> Idx {
        loop {
            let parent = self.parent[node].0;
            if parent == node {
                return node;
            }
            node = parent;
        }
    }

    /// Returns the canonical representative of the eclass containing `node`,
    /// shortening the path to the root along the way so later lookups on
    /// this chain are cheaper.
    ///
    /// Uses path halving (Tarjan and Van Leeuwen): every visited node is
    /// re-pointed at its grandparent, and the walk continues from there.
    ///
    /// In debug builds, panics if the walk starts at or reaches the reserved
    /// sentinel value (for example, because `node` was never `add`ed).
    pub fn find_and_update(&mut self, mut node: Idx) -> Idx {
        debug_assert!(Idx::reserved_value() != node);
        loop {
            let parent = self.parent[node].0;
            if parent == node {
                break;
            }
            let grandparent = self.parent[parent].0;
            debug_assert!(Idx::reserved_value() != grandparent);
            // Skip one level: `node` now points straight at its grandparent.
            self.parent[node] = Val(grandparent);
            node = grandparent;
        }
        debug_assert!(Idx::reserved_value() != node);
        node
    }

    /// Pins the eclass containing `node` and returns its canonical `Idx`,
    /// which stays stable across later `union`s.
    ///
    /// A `union` always changes the canonical representative of one of the two
    /// eclasses being merged. When a pinned eclass is merged with an unpinned
    /// one, it is the unpinned eclass whose representative changes.
    ///
    /// If two pinned eclasses are unioned, one of the pins gets broken, which
    /// is counted in `pinned_union_count` so it can be reported in the
    /// statistics for the pass. No concrete test case that triggers this is
    /// known.
    pub fn pin_index(&mut self, node: Idx) -> Idx {
        let root = self.find_and_update(node);
        // Pinning is encoded as the maximal rank; `union` never lets a root of
        // lower rank take over from it.
        self.rank[root] = u8::MAX;
        root
    }

    /// Merges the eclasses containing `a` and `b`. Does nothing if they are
    /// already the same eclass.
    ///
    /// The root with the strictly larger rank stays canonical. On a tie, the
    /// root of `a` stays canonical and its rank is incremented. When both
    /// roots are pinned (both at rank `u8::MAX`), the increment would overflow.
    /// In that case the root of `a` stays canonical, the root of `b` stops
    /// being canonical despite its pin, and `pinned_union_count` is bumped.
    pub fn union(&mut self, a: Idx, b: Idx) {
        // `survivor` starts as `a`'s root; it is swapped below if `b`'s root
        // outranks it.
        let mut survivor = self.find_and_update(a);
        let mut absorbed = self.find_and_update(b);
        if survivor == absorbed {
            return;
        }

        let survivor_rank = self.rank[survivor];
        let absorbed_rank = self.rank[absorbed];
        if survivor_rank == absorbed_rank {
            let bumped = match survivor_rank.checked_add(1) {
                Some(next) => next,
                None => self.record_pinned_union(),
            };
            self.rank[survivor] = bumped;
        } else if survivor_rank < absorbed_rank {
            swap(&mut survivor, &mut absorbed);
        }

        self.parent[absorbed] = Val(survivor);
        trace!("union: {}, {}", survivor, absorbed);
    }

    /// Slow path of `union`, taken when both merged roots are pinned. Records
    /// the broken pin and returns the rank the surviving root keeps
    /// (`u8::MAX`).
    ///
    /// This should only occur if we rewrite X -> Y and X -> Z, yet neither
    /// Y -> Z nor Z -> Y can be established without the "hint" provided by X.
    /// This probably means we're missing an optimization rule.
    #[cold]
    fn record_pinned_union(&mut self) -> u8 {
        self.pinned_union_count += 1;
        u8::MAX
    }
}