// cranelift_codegen/egraph/cost.rs

//! Cost functions for egraph representation.

use crate::ir::Opcode;

/// A cost of computing some value in the program.
///
/// Costs are measured in arbitrary units that we represent in a
/// `u32`. The ordering is meant to be meaningful, but the value of a
/// single unit is arbitrary (and "not to scale"). We use a collection
/// of heuristics to try to make this approximation at least usable.
///
/// The `u32` packs two fields: the low `DEPTH_BITS` (8) bits hold the
/// depth of the expression and the remaining high 24 bits hold the
/// accumulated opcode cost. Comparing the raw integers therefore
/// orders by opcode cost first and breaks ties by depth.
///
/// We start by defining costs for each opcode (see `pure_op_cost`
/// below). The cost of computing some value is the cost of its
/// opcode, plus the cost of computing its inputs; its depth is one
/// more than the deepest of its inputs (see `Cost::of_pure_op`).
///
/// NOTE(review): an earlier revision of this comment described
/// scaling costs by loop-nest level (multiplying by 1024 per level).
/// Nothing in this file performs such scaling -- confirm whether any
/// caller still does before relying on it.
///
/// Arithmetic on costs is always saturating: we don't want to wrap
/// around and return to a tiny cost when adding the costs of two very
/// expensive operations. It is better to approximate and lose some
/// precision than to lose the ordering by wrapping. The depth field
/// saturates at `u8::MAX`.
///
/// Finally, we reserve the highest value, `u32::MAX`, as a sentinel
/// that means "infinite". An opcode cost that reaches the largest
/// value the opcode-cost field can hold saturates to this sentinel
/// (see `Cost::new`), and adding anything to an infinite cost yields
/// an infinite cost again. An infinite cost is used to represent a
/// value that cannot be computed, or otherwise serve as a sentinel
/// when performing search for the lowest-cost representation of a
/// value.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct Cost(u32);

impl core::fmt::Debug for Cost {
    /// Render the cost for diagnostics.
    ///
    /// The infinite sentinel is printed as the bare name `Cost::Infinite`;
    /// any other value is printed as a `Cost::Finite` struct showing its
    /// opcode cost and depth.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // The sentinel gets a bare name instead of being split into an
        // op-cost and a depth.
        if *self == Cost::infinity() {
            return f.write_str("Cost::Infinite");
        }

        let mut out = f.debug_struct("Cost::Finite");
        out.field("op_cost", &self.op_cost());
        out.field("depth", &self.depth());
        out.finish()
    }
}

impl Ord for Cost {
    /// Total order on costs: by opcode cost, then by depth.
    #[inline]
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // The packed representation keeps the op cost in the high bits and
        // the depth in the low bits, so a plain integer comparison sorts by
        // op cost first and uses depth only as the tie-breaker.
        //
        // Depth is deliberately the *secondary* key. Among expressions of
        // equal op cost we want the shallow, wide one rather than the narrow,
        // deep one: `(a + b) + (c + d)` beats `((a + b) + c) + d`. The wider
        // shape exposes more instruction-level parallelism and shortens live
        // ranges.
        let (lhs, rhs) = (self.0, other.0);
        lhs.cmp(&rhs)
    }
}

impl PartialOrd for Cost {
    /// Costs are totally ordered, so this always returns `Some` and simply
    /// defers to the `Ord` implementation.
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl Cost {
    const DEPTH_BITS: u8 = 8;
    const DEPTH_MASK: u32 = (1 << Self::DEPTH_BITS) - 1;
    const OP_COST_MASK: u32 = !Self::DEPTH_MASK;
    const MAX_OP_COST: u32 = Self::OP_COST_MASK >> Self::DEPTH_BITS;

    pub(crate) fn infinity() -> Cost {
        // 2^32 - 1 is, uh, pretty close to infinite... (we use `Cost`
        // only for heuristics and always saturate so this suffices!)
        Cost(u32::MAX)
    }

    pub(crate) fn zero() -> Cost {
        Cost(0)
    }

    /// Construct a new `Cost` from the given parts.
    ///
    /// If the opcode cost is greater than or equal to the maximum representable
    /// opcode cost, then the resulting `Cost` saturates to infinity.
    fn new(opcode_cost: u32, depth: u8) -> Cost {
        if opcode_cost >= Self::MAX_OP_COST {
            Self::infinity()
        } else {
            Cost(opcode_cost << Self::DEPTH_BITS | u32::from(depth))
        }
    }

    fn depth(&self) -> u8 {
        let depth = self.0 & Self::DEPTH_MASK;
        u8::try_from(depth).unwrap()
    }

    fn op_cost(&self) -> u32 {
        (self.0 & Self::OP_COST_MASK) >> Self::DEPTH_BITS
    }

    /// Compute the cost of the operation and its given operands.
    ///
    /// Caller is responsible for checking that the opcode came from an instruction
    /// that satisfies `inst_predicates::is_pure_for_egraph()`.
    pub(crate) fn of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self {
        let c = pure_op_cost(op) + operand_costs.into_iter().sum();
        Cost::new(c.op_cost(), c.depth().saturating_add(1))
    }
}

impl std::iter::Sum<Cost> for Cost {
    /// Add up every cost yielded by `iter`, starting from the zero cost.
    ///
    /// An empty iterator therefore sums to `Cost::zero()`.
    fn sum<I: Iterator<Item = Cost>>(iter: I) -> Self {
        let mut total = Self::zero();
        for cost in iter {
            total = total + cost;
        }
        total
    }
}

impl std::default::Default for Cost {
    fn default() -> Cost {
        Cost::zero()
    }
}

impl std::ops::Add<Cost> for Cost {
    type Output = Cost;

    /// Combine two costs.
    ///
    /// The opcode costs are added (saturating, so an overflowing or infinite
    /// operand produces infinity via `Cost::new`), and the depth of the
    /// result is the larger of the two depths.
    fn add(self, other: Cost) -> Cost {
        let deepest = self.depth().max(other.depth());
        let combined = self.op_cost().saturating_add(other.op_cost());
        Cost::new(combined, deepest)
    }
}

/// Return the cost of a *pure* opcode.
///
/// The result always has depth zero; only the opcode-cost weight depends
/// on `op`.
///
/// Caller is responsible for checking that the opcode came from an instruction
/// that satisfies `inst_predicates::is_pure_for_egraph()`.
fn pure_op_cost(op: Opcode) -> Cost {
    let weight = match op {
        // Constants.
        Opcode::Iconst | Opcode::F32const | Opcode::F64const => 1,

        // Extends/reduces.
        Opcode::Uextend | Opcode::Sextend | Opcode::Ireduce | Opcode::Iconcat | Opcode::Isplit => 2,

        // "Simple" arithmetic.
        Opcode::Iadd
        | Opcode::Isub
        | Opcode::Band
        | Opcode::Bor
        | Opcode::Bxor
        | Opcode::Bnot
        | Opcode::Ishl
        | Opcode::Ushr
        | Opcode::Sshr => 3,

        // Everything else (pure.)
        _ => 4,
    };
    Cost::new(weight, 0)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Adding two finite costs sums the op costs and keeps the larger depth,
    /// regardless of operand order.
    #[test]
    fn add_cost() {
        let small = Cost::new(5, 2);
        let large = Cost::new(37, 3);
        let expected = Cost::new(42, 3);
        assert_eq!(small + large, expected);
        assert_eq!(large + small, expected);
    }

    /// Infinity absorbs any finite cost on either side of the addition.
    #[test]
    fn add_infinity() {
        let finite = Cost::new(5, 2);
        let infinite = Cost::infinity();
        assert_eq!(finite + infinite, Cost::infinity());
        assert_eq!(infinite + finite, Cost::infinity());
    }

    /// An op-cost sum that goes past the representable maximum becomes
    /// infinity instead of wrapping around.
    #[test]
    fn op_cost_saturates_to_infinity() {
        let nearly_max = Cost::new(Cost::MAX_OP_COST - 10, 2);
        let bump = Cost::new(11, 2);
        assert_eq!(nearly_max + bump, Cost::infinity());
        assert_eq!(bump + nearly_max, Cost::infinity());
    }

    /// Incrementing an already-maximal depth stays at `u8::MAX`.
    #[test]
    fn depth_saturates_to_max_depth() {
        let deep = Cost::new(10, u8::MAX);
        let shallow = Cost::new(10, 1);
        let expected = Cost::new(21, u8::MAX);
        assert_eq!(Cost::of_pure_op(Opcode::Iconst, [deep, shallow]), expected);
        assert_eq!(Cost::of_pure_op(Opcode::Iconst, [shallow, deep]), expected);
    }
}