[Mlir-commits] [mlir] b72ac6f - [MLIR] Let matchers work on int ranges (#102494)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Aug 14 05:44:27 PDT 2024


Author: Tobias Gysi
Date: 2024-08-14T14:44:23+02:00
New Revision: b72ac6f97a5335e79659711d49db8f2694c02a0e

URL: https://github.com/llvm/llvm-project/commit/b72ac6f97a5335e79659711d49db8f2694c02a0e
DIFF: https://github.com/llvm/llvm-project/commit/b72ac6f97a5335e79659711d49db8f2694c02a0e.diff

LOG: [MLIR] Let matchers work on int ranges (#102494)

This commit adds three matchers that, unlike the m_NonZero matcher,
match not only constants but also operations that implement the
InferIntRangeInterface. Based on the inferred range, these matchers can
match a value that is non-zero or a value that is not minus one.
Additionally, the commit uses the new matchers in the getSpeculatability
functions of Arith's signed and unsigned integer divisions. At the moment,
the matchers only look at the defining operation to avoid expensive IR walks.

These range-based matchers can be useful when hoisting divisions out of
a loop, which requires knowing that the divisor is non-zero and, for signed
divisions, not minus one. Just checking for a constant divisor may not be
sufficient if the divisor is, for example, the result of an operation that
returns the number of threads in a team of threads.

Added: 
    

Modified: 
    mlir/include/mlir/IR/Matchers.h
    mlir/lib/Dialect/Arith/IR/ArithOps.cpp
    mlir/test/Transforms/loop-invariant-code-motion.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index f6417f62d09e8c..6fa5a47109d20d 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -18,6 +18,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 
 namespace mlir {
 
@@ -100,6 +101,45 @@ struct constant_op_binder {
   }
 };
 
+/// A matcher that matches single-result operations implementing the
+/// `InferIntRangeInterface` interface, and binds the range inferred for the
+/// result. Operand ranges are assumed to be maximal, so only information
+/// local to the operation is used; no IR walk is performed.
+///
+/// Multi-result operations are rejected: when matching a `Value`, only its
+/// defining operation reaches `match`, so the matcher cannot tell which
+/// result is being queried and must not bind another result's range.
+struct infer_int_range_op_binder {
+  IntegerValueRange *bind_value;
+
+  explicit infer_int_range_op_binder(IntegerValueRange *bind_value)
+      : bind_value(bind_value) {}
+
+  bool match(Operation *op) {
+    auto inferIntRangeOp = dyn_cast<InferIntRangeInterface>(op);
+    if (!inferIntRangeOp || op->getNumResults() != 1)
+      return false;
+
+    // Set the range of all integer operands to the maximal range.
+    SmallVector<IntegerValueRange> argRanges =
+        llvm::map_to_vector(op->getOperands(), IntegerValueRange::getMaxRange);
+
+    // Infer the result range if possible.
+    bool matched = false;
+    auto setResultRanges = [&](Value value,
+                               const IntegerValueRange &resultRange) {
+      if (resultRange.isUninitialized())
+        return;
+      if (value != op->getResult(0))
+        return;
+      *bind_value = resultRange;
+      matched = true;
+    };
+    inferIntRangeOp.inferResultRangesFromOptional(argRanges, setResultRanges);
+    return matched;
+  }
+};
+
 /// The matcher that matches operations that have the specified attribute
 /// name, and binds the attribute value.
 template <typename AttrT>
@@ -219,6 +253,33 @@ struct constant_int_predicate_matcher {
   }
 };
 
+/// A matcher that matches a constant scalar / vector splat / tensor splat
+/// integer value, or an operation implementing `InferIntRangeInterface`,
+/// whose integer range fulfills a predicate. A constant is checked as the
+/// single-value range `ConstantIntRanges::constant(value)`.
+struct constant_int_range_predicate_matcher {
+  bool (*predicate)(const ConstantIntRanges &);
+
+  bool match(Attribute attr) {
+    APInt value;
+    return constant_int_value_binder(&value).match(attr) &&
+           predicate(ConstantIntRanges::constant(value));
+  }
+
+  bool match(Operation *op) {
+    // Try to match a constant integer value first.
+    APInt value;
+    if (constant_int_value_binder(&value).match(op))
+      return predicate(ConstantIntRanges::constant(value));
+
+    // Otherwise, try to match an operation that implements the
+    // `InferIntRangeInterface` interface and test its inferred range.
+    IntegerValueRange range;
+    return infer_int_range_op_binder(&range).match(op) &&
+           predicate(range.getValue());
+  }
+};
+
 /// The matcher that matches a certain kind of op.
 template <typename OpClass>
 struct op_matcher {
@@ -385,6 +444,34 @@ inline detail::constant_int_predicate_matcher m_NonZero() {
   return {[](const APInt &value) { return 0 != value; }};
 }
 
+/// Matches a constant scalar / vector splat / tensor splat integer, or an
+/// operation implementing `InferIntRangeInterface`, whose value is known to be
+/// non-zero. Note that this matcher interprets the target value as an
+/// unsigned integer.
+inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU() {
+  return {[](const ConstantIntRanges &range) { return range.umin().ugt(0); }};
+}
+
+/// Matches a constant scalar / vector splat / tensor splat integer, or an
+/// operation implementing `InferIntRangeInterface`, whose value is known to be
+/// non-zero. Note that this matcher interprets the target value as a signed
+/// integer.
+inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS() {
+  return {[](const ConstantIntRanges &range) {
+    return range.smin().sgt(0) || range.smax().slt(0);
+  }};
+}
+
+/// Matches a constant scalar / vector splat / tensor splat integer, or an
+/// operation implementing `InferIntRangeInterface`, whose value is known not
+/// to be minus one. Note that this matcher interprets the target value as a
+/// signed integer.
+inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS() {
+  return {[](const ConstantIntRanges &range) {
+    return range.smin().sgt(-1) || range.smax().slt(-1);
+  }};
+}
+
 /// Matches a constant scalar / vector splat / tensor splat integer one.
 inline detail::constant_int_predicate_matcher m_One() {
   return {[](const APInt &value) { return 1 == value; }};

diff  --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 641b7d7e2d13be..254f54d9e459e1 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -595,10 +595,17 @@ OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
   return div0 ? Attribute() : result;
 }
 
-Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
+/// Unsigned division by zero is UB, so speculation is only allowed when
+/// `divisor` is provably non-zero (as a constant or via its inferred range).
+static Speculation::Speculatability getDivUISpeculatability(Value divisor) {
   // X / 0 => UB
-  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
-                                             : Speculation::NotSpeculatable;
+  const bool nonZero = matchPattern(divisor, m_IntRangeWithoutZeroU());
+  return nonZero ? Speculation::Speculatable
+                 : Speculation::NotSpeculatable;
+}
+
+Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
+  return getDivUISpeculatability(getRhs());
 }
 
 //===----------------------------------------------------------------------===//
@@ -624,16 +631,21 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
   return overflowOrDiv0 ? Attribute() : result;
 }
 
-Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
-  bool mayHaveUB = true;
-
-  APInt constRHS;
+/// Returns whether a signed division by `divisor` is speculatable. Besides
+/// X / 0, INT_MIN / -1 is UB; as the dividend is not inspected, any divisor
+/// whose range may contain -1 is conservatively treated as not speculatable.
+static Speculation::Speculatability getDivSISpeculatability(Value divisor) {
   // X / 0 => UB
   // INT_MIN / -1 => UB
-  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
-    mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
+  if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
+      matchPattern(divisor, m_IntRangeWithoutNegOneS()))
+    return Speculation::Speculatable;
 
-  return mayHaveUB ? Speculation::NotSpeculatable : Speculation::Speculatable;
+  return Speculation::NotSpeculatable;
+}
+
+Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
+  return getDivSISpeculatability(getRhs());
 }
 
 //===----------------------------------------------------------------------===//
@@ -675,9 +687,8 @@ OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
 }
 
 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
   // X / 0 => UB
-  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
-                                             : Speculation::NotSpeculatable;
+  return getDivUISpeculatability(getRhs());
 }
 
 //===----------------------------------------------------------------------===//
@@ -746,15 +756,9 @@ OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
 }
 
 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
-  bool mayHaveUB = true;
-
-  APInt constRHS;
   // X / 0 => UB
   // INT_MIN / -1 => UB
-  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
-    mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
-
-  return mayHaveUB ? Speculation::NotSpeculatable : Speculation::Speculatable;
+  return getDivSISpeculatability(getRhs());
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index dcc314f36ae0a8..47a49465e8a7cd 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -923,6 +923,121 @@ func.func @speculate_ceildivsi_const(
   return
 }
 
+// The range-based tests below use `test.with_bounds` to give the divisor a
+// known integer range: divisions must only be hoisted when that range
+// excludes zero (and, for signed divisions, minus one).
+func.func @no_speculate_divui_range(
+// CHECK-LABEL: @no_speculate_divui_range(
+    %num: i8, %lb: index, %ub: index, %step: index) {
+  %denom = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  scf.for %i = %lb to %ub step %step {
+// CHECK: scf.for
+// CHECK: arith.divui
+    %val = arith.divui %num, %denom : i8
+  }
+
+  return
+}
+
+func.func @no_speculate_divsi_range(
+// CHECK-LABEL: @no_speculate_divsi_range(
+    %num: i8, %lb: index, %ub: index, %step: index) {
+  %denom0 = test.with_bounds {smax = -1 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  %denom1 = test.with_bounds {smax = 127 : i8, smin = 0 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  scf.for %i = %lb to %ub step %step {
+// CHECK: scf.for
+// CHECK-COUNT-2: arith.divsi
+    %val0 = arith.divsi %num, %denom0 : i8
+    %val1 = arith.divsi %num, %denom1 : i8
+  }
+
+  return
+}
+
+func.func @no_speculate_ceildivui_range(
+// CHECK-LABEL: @no_speculate_ceildivui_range(
+    %num: i8, %lb: index, %ub: index, %step: index) {
+  %denom = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  scf.for %i = %lb to %ub step %step {
+// CHECK: scf.for
+// CHECK: arith.ceildivui
+    %val = arith.ceildivui %num, %denom : i8
+  }
+
+  return
+}
+
+func.func @no_speculate_ceildivsi_range(
+// CHECK-LABEL: @no_speculate_ceildivsi_range(
+    %num: i8, %lb: index, %ub: index, %step: index) {
+  %denom0 = test.with_bounds {smax = -1 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  %denom1 = test.with_bounds {smax = 127 : i8, smin = 0 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  scf.for %i = %lb to %ub step %step {
+// CHECK: scf.for
+// CHECK-COUNT-2: arith.ceildivsi
+    %val0 = arith.ceildivsi %num, %denom0 : i8
+    %val1 = arith.ceildivsi %num, %denom1 : i8
+  }
+
+  return
+}
+
+func.func @speculate_divui_range(
+// CHECK-LABEL: @speculate_divui_range(
+    %num: i8, %lb: index, %ub: index, %step: index) {
+  %denom = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 1 : i8} : i8
+  scf.for %i = %lb to %ub step %step {
+// CHECK: arith.divui
+// CHECK: scf.for
+    %val = arith.divui %num, %denom : i8
+  }
+
+  return
+}
+
+func.func @speculate_divsi_range(
+// CHECK-LABEL: @speculate_divsi_range(
+    %num: i8, %lb: index, %ub: index, %step: index) {
+  %denom0 = test.with_bounds {smax = 127 : i8, smin = 1 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  %denom1 = test.with_bounds {smax = -2 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  scf.for %i = %lb to %ub step %step {
+// CHECK-COUNT-2: arith.divsi
+// CHECK: scf.for
+    %val0 = arith.divsi %num, %denom0 : i8
+    %val1 = arith.divsi %num, %denom1 : i8
+  }
+
+  return
+}
+
+func.func @speculate_ceildivui_range(
+// CHECK-LABEL: @speculate_ceildivui_range(
+    %num: i8, %lb: index, %ub: index, %step: index) {
+  %denom = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 1 : i8} : i8
+  scf.for %i = %lb to %ub step %step {
+// CHECK: arith.ceildivui
+// CHECK: scf.for
+    %val = arith.ceildivui %num, %denom : i8
+  }
+
+  return
+}
+
+func.func @speculate_ceildivsi_range(
+// CHECK-LABEL: @speculate_ceildivsi_range(
+    %num: i8, %lb: index, %ub: index, %step: index) {
+  %denom0 = test.with_bounds {smax = 127 : i8, smin = 1 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  %denom1 = test.with_bounds {smax = -2 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  scf.for %i = %lb to %ub step %step {
+// CHECK-COUNT-2: arith.ceildivsi
+// CHECK: scf.for
+    %val0 = arith.ceildivsi %num, %denom0 : i8
+    %val1 = arith.ceildivsi %num, %denom1 : i8
+  }
+
+  return
+}
+
 // -----
 
 func.func @speculate_static_pack_and_unpack(%source: tensor<128x256xf32>,


        


More information about the Mlir-commits mailing list