cranelift_codegen/unionfind.rs
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
//! Simple union-find data structure.
use crate::trace;
use cranelift_entity::{packed_option::ReservedValue, EntityRef, SecondaryMap};
use std::hash::Hash;
use std::mem::swap;
/// A union-find data structure. It tracks `Idx`s, each registered with
/// `add` as its own eclass, and can merge eclasses together.
///
/// Running `union(a, b)` will change the canonical `Idx` of either `a` or `b`.
/// Usually, this is chosen based on what will minimize path lengths,
/// but it is also possible to _pin_ an eclass, such that its canonical `Idx`
/// won't change unless it gets unioned with another pinned eclass.
///
/// In the context of the egraph pass, merging two pinned eclasses
/// is very unlikely to happen – we do not know of a single concrete test case
/// where it does. The only situation where it might happen looks as follows:
///
/// 1. We encounter terms `A` and `B`, and the optimizer does not find any
///    reason to union them together.
/// 2. We encounter a term `C`, and we rewrite `C -> A`, and separately, `C -> B`.
///
/// Unless `C` somehow includes some crucial hint without which it is hard to
/// notice that `A = B`, there's probably a rewrite rule that we should add.
///
/// In the worst case, if we do merge two pinned eclasses, some nodes will
/// essentially disappear from the GVN map, which only affects the quality of
/// the generated code.
#[derive(Clone, Debug, PartialEq)]
pub struct UnionFind<Idx: EntityRef> {
    /// Parent pointer of each `Idx` in the union-find forest. A root (the
    /// canonical `Idx` of its eclass) is its own parent; `add` initializes
    /// every registered entry that way. Slots that were never `add`ed read
    /// as `Val::default()`, i.e. the reserved index.
    parent: SecondaryMap<Idx, Val<Idx>>,
    /// The `rank` table is used to perform the union operations optimally,
    /// without creating unnecessarily long paths. Pins are represented by
    /// eclasses with a rank of `u8::MAX`.
    ///
    /// `rank[x]` is an upper bound on the height of the subtree rooted at `x`.
    /// The subtree is guaranteed to have at least `2**rank[x]` elements,
    /// unless `rank` has been artificially inflated by pinning.
    rank: SecondaryMap<Idx, u8>,
    /// Number of `union` calls in which both roots already had rank
    /// `u8::MAX` (i.e. both eclasses were pinned), so one of the pins was
    /// broken. Visible to the rest of the crate so the egraph pass can report
    /// it in its statistics.
    pub(crate) pinned_union_count: u64,
}
/// A parent-pointer slot stored in `UnionFind`'s `parent` table.
///
/// Wrapping the index in a newtype lets the slot have its own `Default`
/// (the reserved index), which marks entries that were never added.
#[derive(Clone, Debug, PartialEq)]
struct Val<Idx>(Idx);
impl<Idx: EntityRef + ReservedValue> Default for Val<Idx> {
fn default() -> Self {
Self(Idx::reserved_value())
}
}
impl<Idx: EntityRef + Hash + std::fmt::Display + Ord + ReservedValue> UnionFind<Idx> {
    /// Creates an empty `UnionFind` whose `parent` and `rank` tables are
    /// pre-sized for `cap` entries, with no broken pins recorded yet.
    pub fn with_capacity(cap: usize) -> Self {
        let parent = SecondaryMap::with_capacity(cap);
        let rank = SecondaryMap::with_capacity(cap);
        Self {
            parent,
            rank,
            pinned_union_count: 0,
        }
    }

    /// Registers `id` as a singleton equivalence class by making it its own
    /// parent. Every `Idx` must be registered this way before it is queried
    /// or unioned.
    ///
    /// `id` must not be the reserved index (checked in debug builds only).
    pub fn add(&mut self, id: Idx) {
        debug_assert!(id != Idx::reserved_value());
        self.parent[id] = Val(id);
    }

    /// Returns the canonical (root) `Idx` of the class containing `node`.
    ///
    /// This is read-only: it follows parent pointers until it reaches an
    /// element that is its own parent, without shortening the path. Prefer
    /// `find_and_update` when `&mut self` is available.
    pub fn find(&self, node: Idx) -> Idx {
        let mut current = node;
        loop {
            let up = self.parent[current].0;
            if up == current {
                return current;
            }
            current = up;
        }
    }

    /// Returns the canonical `Idx` of the class containing `node`, shortening
    /// the path on the way so later lookups for `node` (and the elements
    /// above it) are cheaper.
    ///
    /// Uses path halving (Tarjan and van Leeuwen): each visited element is
    /// re-pointed at its grandparent, and the walk resumes from there.
    pub fn find_and_update(&mut self, node: Idx) -> Idx {
        debug_assert!(node != Idx::reserved_value());
        let mut current = node;
        loop {
            let up = self.parent[current].0;
            if up == current {
                break;
            }
            let grandparent = self.parent[up].0;
            debug_assert!(grandparent != Idx::reserved_value());
            self.parent[current] = Val(grandparent);
            current = grandparent;
        }
        debug_assert!(current != Idx::reserved_value());
        current
    }

    /// Pins the eclass containing `node` and returns its canonical `Idx`.
    ///
    /// Every `union` necessarily changes the representative of one of the two
    /// eclasses it merges. Pinning raises the root's rank to `u8::MAX`, so
    /// when a pinned eclass is merged with an unpinned one, it is the unpinned
    /// eclass whose representative changes.
    ///
    /// If two pinned eclasses are unioned, one pin is broken and
    /// `pinned_union_count` is incremented so the pass can report it. No
    /// concrete test case that triggers this is known.
    pub fn pin_index(&mut self, node: Idx) -> Idx {
        let root = self.find_and_update(node);
        self.rank[root] = u8::MAX;
        root
    }

    /// Merges the equivalence classes containing `a` and `b`.
    ///
    /// The root with the higher rank stays the representative. On a tie,
    /// `a`'s root wins and its rank is bumped by one. If that bump would
    /// overflow, both roots were pinned: the rank stays `u8::MAX` and
    /// `pinned_union_count` is incremented.
    pub fn union(&mut self, a: Idx, b: Idx) {
        use std::cmp::Ordering;

        let mut winner = self.find_and_update(a);
        let mut loser = self.find_and_update(b);
        if winner == loser {
            return;
        }

        match self.rank[winner].cmp(&self.rank[loser]) {
            Ordering::Less => swap(&mut winner, &mut loser),
            Ordering::Greater => {}
            Ordering::Equal => {
                let bumped = self.rank[winner].checked_add(1).unwrap_or_else(
                    #[cold]
                    || {
                        // Both roots are pinned.
                        //
                        // This should only occur if we rewrite X -> Y and
                        // X -> Z, yet neither Y -> Z nor Z -> Y can be
                        // established without the "hint" provided by X. This
                        // probably means we're missing an optimization rule.
                        self.pinned_union_count += 1;
                        u8::MAX
                    },
                );
                self.rank[winner] = bumped;
            }
        }

        self.parent[loser] = Val(winner);
        trace!("union: {}, {}", winner, loser);
    }
}