// cranelift_codegen/egraph/cost.rs
//! Cost functions for egraph representation.
use crate::ir::Opcode;
/// A cost of computing some value in the program.
///
/// Costs are measured in an arbitrary unit that we represent in a
/// `u32`. The ordering is meant to be meaningful, but the value of a
/// single unit is arbitrary (and "not to scale"). We use a collection
/// of heuristics to try to make this approximation at least usable.
///
/// The `u32` is split into two fields: the high bits hold the
/// accumulated opcode cost and the low `DEPTH_BITS` (8) bits hold the
/// depth of the expression. Because the opcode cost occupies the high
/// bits, plain integer comparison orders costs by opcode cost first
/// and breaks ties by depth.
///
/// We start by defining costs for each opcode (see `pure_op_cost`
/// below). The cost of computing some value, initially, is the cost
/// of its opcode, plus the cost of computing its inputs.
///
/// We then adjust the cost according to loop nests: for each
/// loop-nest level, we multiply by 1024. Because we only have 32
/// bits, we limit this scaling to a loop-level of two (i.e., multiply
/// by 2^20 ~= 1M).
///
/// NOTE(review): no loop-nest scaling is implemented in this file;
/// confirm whether it is applied elsewhere or whether the paragraph
/// above is stale.
///
/// Arithmetic on costs is always saturating: we don't want to wrap
/// around and return to a tiny cost when adding the costs of two very
/// expensive operations. It is better to approximate and lose some
/// precision than to lose the ordering by wrapping.
///
/// Finally, we reserve the highest value, `u32::MAX`, as a sentinel
/// that means "infinite". Any opcode cost that reaches or exceeds
/// `MAX_OP_COST` saturates *to* infinity (this is done by
/// `Cost::new`), and adding anything to an infinite cost yields
/// infinity again. An infinite cost is used to represent a value
/// that cannot be computed, or otherwise serve as a sentinel when
/// performing search for the lowest-cost representation of a value.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct Cost(u32);
/// Human-readable rendering of a `Cost`: the infinite sentinel prints as
/// `Cost::Infinite`; every other value prints its unpacked op cost and depth.
impl core::fmt::Debug for Cost {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // The sentinel has no meaningful fields, so handle it up front.
        if *self == Cost::infinity() {
            return f.write_str("Cost::Infinite");
        }

        let mut repr = f.debug_struct("Cost::Finite");
        repr.field("op_cost", &self.op_cost());
        repr.field("depth", &self.depth());
        repr.finish()
    }
}
impl Ord for Cost {
#[inline]
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
// We make sure that the high bits are the op cost and the low bits are
// the depth. This means that we can use normal integer comparison to
// order by op cost and then depth.
//
// We want to break op cost ties with depth (rather than the other way
// around). When the op cost is the same, we prefer shallow and wide
// expressions to narrow and deep expressions and breaking ties with
// `depth` gives us that. For example, `(a + b) + (c + d)` is preferred
// to `((a + b) + c) + d`. This is beneficial because it exposes more
// instruction-level parallelism and shortens live ranges.
self.0.cmp(&other.0)
}
}
impl PartialOrd for Cost {
    /// Costs are totally ordered, so this always defers to `Ord::cmp` and
    /// never returns `None`.
    #[inline]
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl Cost {
    /// Number of low bits of the packed `u32` that store the depth.
    const DEPTH_BITS: u8 = 8;
    /// Mask selecting the depth bits of the packed `u32`.
    const DEPTH_MASK: u32 = (1 << Self::DEPTH_BITS) - 1;
    /// Mask selecting the op-cost bits of the packed `u32`.
    const OP_COST_MASK: u32 = !Self::DEPTH_MASK;
    /// Op cost at or above which `new` yields the infinite cost.
    const MAX_OP_COST: u32 = Self::OP_COST_MASK >> Self::DEPTH_BITS;

    /// The infinite cost: the all-ones sentinel `u32::MAX`.
    pub(crate) fn infinity() -> Cost {
        // `Cost` is only a heuristic and arithmetic on it always saturates,
        // so the largest `u32` is a good enough stand-in for "infinite".
        Cost(u32::MAX)
    }

    /// The zero cost: no op cost and no depth.
    pub(crate) fn zero() -> Cost {
        Cost(0)
    }

    /// Pack an opcode cost and a depth into a `Cost`.
    ///
    /// An `opcode_cost` greater than or equal to `MAX_OP_COST` is not
    /// representable as a finite cost; the result is then infinity and
    /// `depth` is ignored.
    fn new(opcode_cost: u32, depth: u8) -> Cost {
        if opcode_cost < Self::MAX_OP_COST {
            let shifted_op_cost = opcode_cost << Self::DEPTH_BITS;
            Cost(shifted_op_cost | u32::from(depth))
        } else {
            Self::infinity()
        }
    }

    /// Extract the depth stored in the low bits.
    fn depth(&self) -> u8 {
        let low_bits = self.0 & Self::DEPTH_MASK;
        u8::try_from(low_bits).expect("value masked to the depth bits always fits in a u8")
    }

    /// Extract the op cost stored in the high bits.
    fn op_cost(&self) -> u32 {
        let high_bits = self.0 & Self::OP_COST_MASK;
        high_bits >> Self::DEPTH_BITS
    }

    /// Compute the cost of the operation and its given operands.
    ///
    /// The op cost of the result is the opcode's own cost plus the op costs
    /// of all operands (saturating to infinity); its depth is one more than
    /// the deepest operand, saturating at `u8::MAX`.
    ///
    /// Caller is responsible for checking that the opcode came from an instruction
    /// that satisfies `inst_predicates::is_pure_for_egraph()`.
    pub(crate) fn of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self {
        let own_cost = pure_op_cost(op);
        let operands_cost: Cost = operand_costs.into_iter().sum();
        let combined = own_cost + operands_cost;

        // This node sits one level above its deepest operand.
        let nested_depth = combined.depth().saturating_add(1);
        Cost::new(combined.op_cost(), nested_depth)
    }
}
impl std::iter::Sum<Cost> for Cost {
    /// Add up every cost in `iter`, starting from `Cost::zero()`.
    ///
    /// An empty iterator therefore sums to the zero cost.
    fn sum<I: Iterator<Item = Cost>>(iter: I) -> Self {
        let mut total = Self::zero();
        for cost in iter {
            total = total + cost;
        }
        total
    }
}
impl std::default::Default for Cost {
fn default() -> Cost {
Cost::zero()
}
}
impl std::ops::Add<Cost> for Cost {
    type Output = Cost;

    /// Combine two costs: op costs are added (saturating), and the depth of
    /// the result is the larger of the two depths.
    ///
    /// If the summed op cost reaches `MAX_OP_COST`, `Cost::new` turns the
    /// result into infinity.
    fn add(self, other: Cost) -> Cost {
        let total_op_cost = self.op_cost().saturating_add(other.op_cost());
        let deepest = self.depth().max(other.depth());
        Cost::new(total_op_cost, deepest)
    }
}
/// Return the cost of a *pure* opcode.
///
/// The result always has depth zero; only the op cost varies with the
/// opcode's category.
///
/// Caller is responsible for checking that the opcode came from an instruction
/// that satisfies `inst_predicates::is_pure_for_egraph()`.
fn pure_op_cost(op: Opcode) -> Cost {
    let op_cost = match op {
        // Constants.
        Opcode::Iconst | Opcode::F32const | Opcode::F64const => 1,

        // Extends/reduces.
        Opcode::Uextend | Opcode::Sextend | Opcode::Ireduce | Opcode::Iconcat | Opcode::Isplit => 2,

        // "Simple" arithmetic.
        Opcode::Iadd
        | Opcode::Isub
        | Opcode::Band
        | Opcode::Bor
        | Opcode::Bxor
        | Opcode::Bnot
        | Opcode::Ishl
        | Opcode::Ushr
        | Opcode::Sshr => 3,

        // Everything else (pure.)
        _ => 4,
    };
    Cost::new(op_cost, 0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Adding finite costs sums the op costs and keeps the larger depth,
    /// regardless of operand order.
    #[test]
    fn add_cost() {
        let shallow = Cost::new(5, 2);
        let deep = Cost::new(37, 3);
        let expected = Cost::new(42, 3);

        assert_eq!(shallow + deep, expected);
        assert_eq!(deep + shallow, expected);
    }

    /// Infinity absorbs any finite cost on either side.
    #[test]
    fn add_infinity() {
        let finite = Cost::new(5, 2);
        let infinite = Cost::infinity();

        assert_eq!(finite + infinite, Cost::infinity());
        assert_eq!(infinite + finite, Cost::infinity());
    }

    /// A sum whose op cost reaches `MAX_OP_COST` becomes infinity.
    #[test]
    fn op_cost_saturates_to_infinity() {
        let near_max = Cost::new(Cost::MAX_OP_COST - 10, 2);
        let pushes_over = Cost::new(11, 2);

        assert_eq!(near_max + pushes_over, Cost::infinity());
        assert_eq!(pushes_over + near_max, Cost::infinity());
    }

    /// Incrementing an already-maximal depth stays at `u8::MAX`.
    #[test]
    fn depth_saturates_to_max_depth() {
        let max_depth = Cost::new(10, u8::MAX);
        let shallow = Cost::new(10, 1);
        let expected = Cost::new(21, u8::MAX);

        assert_eq!(
            Cost::of_pure_op(Opcode::Iconst, [max_depth, shallow]),
            expected
        );
        assert_eq!(
            Cost::of_pure_op(Opcode::Iconst, [shallow, max_depth]),
            expected
        );
    }
}