[Mlir-commits] [mlir] [MLIR] Let matchers work on int ranges (PR #102494)

Tobias Gysi llvmlistbot at llvm.org
Tue Aug 13 09:29:30 PDT 2024


https://github.com/gysit updated https://github.com/llvm/llvm-project/pull/102494

>From 03f4053a934bd1ed044713961086c2c7ae6fdc06 Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Thu, 8 Aug 2024 15:46:18 +0000
Subject: [PATCH 1/3] [MLIR] Use InferIntRangeInterface in matchers

This commit adds an m_NonZeroU matcher that, unlike the m_NonZero
matcher, matches not only constants but also operations that implement
the InferIntRangeInterface. It can then match a non-zero value based on
the inferred range. Additionally, the commit uses the new matcher in the
getSpeculatability function of Arith's unsigned integer divisions. At
the moment, the matcher only looks at the defining operation to avoid
expensive IR walks.

These range-based matchers can be useful when hoisting divisions out of
a loop, which requires knowing that the divisor is non-zero. Just
checking for a constant divisor may not be sufficient if the divisor is,
for example, the result of an operation that returns the number of
threads of a team of threads.
---
 mlir/include/mlir/IR/Matchers.h               | 68 +++++++++++++++++++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |  8 +--
 .../Arith/hoist-speculatable-division.mlir    | 54 +++++++++++++++
 3 files changed, 126 insertions(+), 4 deletions(-)
 create mode 100644 mlir/test/Dialect/Arith/hoist-speculatable-division.mlir

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index f6417f62d09e8c..32b15265ca2c83 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -18,6 +18,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 
 namespace mlir {
 
@@ -100,6 +101,41 @@ struct constant_op_binder {
   }
 };
 
+/// A matcher that matches operations implementing `InferIntRangeInterface`
+/// and binds the range inferred for the operation's first result.
+struct infer_int_range_op_binder {
+  IntegerValueRange *bind_value;
+
+  explicit infer_int_range_op_binder(IntegerValueRange *bind_value)
+      : bind_value(bind_value) {}
+
+  bool match(Operation *op) {
+    auto inferIntRangeOp = dyn_cast<InferIntRangeInterface>(op);
+    if (!inferIntRangeOp)
+      return false;
+
+    // Set the range of all integer operands to the maximal range.
+    SmallVector<IntegerValueRange> argRanges;
+    argRanges.reserve(op->getNumOperands());
+    for (Value operand : op->getOperands())
+      argRanges.emplace_back(IntegerValueRange::getMaxRange(operand));
+
+    // Infer the result result range if possible.
+    bool matched = false;
+    auto setResultRanges = [&](Value value,
+                               const IntegerValueRange &argRanges) {
+      if (argRanges.isUninitialized())
+        return;
+      if (value != op->getResult(0))
+        return;
+      *bind_value = argRanges;
+      matched = true;
+    };
+    inferIntRangeOp.inferResultRangesFromOptional(argRanges, setResultRanges);
+    return matched;
+  }
+};
+
 /// The matcher that matches operations that have the specified attribute
 /// name, and binds the attribute value.
 template <typename AttrT>
@@ -219,6 +255,31 @@ struct constant_int_predicate_matcher {
   }
 };
 
+/// A matcher that matches a constant scalar / vector splat / tensor splat
+/// integer value or a constant integer range that fulfills a predicate.
+struct constant_int_range_predicate_matcher {
+  bool (*predicate)(const ConstantIntRanges &);
+
+  bool match(Attribute attr) {
+    APInt value;
+    return constant_int_value_binder(&value).match(attr) &&
+           predicate(ConstantIntRanges::constant(value));
+  }
+
+  bool match(Operation *op) {
+    // Try to match a constant integer value first.
+    APInt value;
+    if (constant_int_value_binder(&value).match(op))
+      return predicate(ConstantIntRanges::constant(value));
+
+    // Otherwise, try to match an operation that implements the
+    // `InferIntRangeInterface` interface.
+    IntegerValueRange range;
+    return infer_int_range_op_binder(&range).match(op) &&
+           predicate(range.getValue());
+  }
+};
+
 /// The matcher that matches a certain kind of op.
 template <typename OpClass>
 struct op_matcher {
@@ -385,6 +446,13 @@ inline detail::constant_int_predicate_matcher m_NonZero() {
   return {[](const APInt &value) { return 0 != value; }};
 }
 
+/// Matches a constant scalar / vector splat / tensor splat integer or an
+/// unsigned integer range that does not contain zero. Note that this matcher
+/// interprets the target value as an unsigned integer.
+inline detail::constant_int_range_predicate_matcher m_NonZeroU() {
+  return {[](const ConstantIntRanges &range) { return range.umin().ugt(0); }};
+}
+
 /// Matches a constant scalar / vector splat / tensor splat integer one.
 inline detail::constant_int_predicate_matcher m_One() {
   return {[](const APInt &value) { return 1 == value; }};
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 641b7d7e2d13be..31f09d54f9759f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -597,8 +597,8 @@ OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
 
 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
   // X / 0 => UB
-  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
-                                             : Speculation::NotSpeculatable;
+  return matchPattern(getRhs(), m_NonZeroU()) ? Speculation::Speculatable
+                                              : Speculation::NotSpeculatable;
 }
 
 //===----------------------------------------------------------------------===//
@@ -676,8 +676,8 @@ OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
 
 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
   // X / 0 => UB
-  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
-                                             : Speculation::NotSpeculatable;
+  return matchPattern(getRhs(), m_NonZeroU()) ? Speculation::Speculatable
+                                              : Speculation::NotSpeculatable;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
new file mode 100644
index 00000000000000..b4bd525b502562
--- /dev/null
+++ b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt -loop-invariant-code-motion -split-input-file %s | FileCheck %s
+
+// Verify a division by a non-zero constant is hoisted out of the loop.
+
+// CHECK-LABEL: func @match_non_zero_constant
+func.func @match_non_zero_constant(%arg0: i32) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : i32
+  %cst0 = arith.constant 0 : i32
+  // CHECK: %[[CST1:.*]] = arith.constant 1 : i32
+  %cst1 = arith.constant 1 : i32
+  // CHECK: = arith.divui %{{.*}}, %[[CST1]]
+  // CHECK: scf.for
+  scf.for %idx = %lb to %ub step %step {
+    // CHECK: = arith.divui %{{.*}}, %[[CST0]]
+    %0 = arith.divui %arg0, %cst0 : i32
+    %1 = arith.divui %arg0, %cst1 : i32
+  }
+  return
+}
+
+// -----
+
+// Verify a division by a non-zero integer whose range is known due to the
+// InferIntRangeInterface is hoisted out of the loop.
+
+// CHECK-LABEL: func @match_integer_range
+// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
+func.func @match_integer_range(%arg0: i8) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+
+  // CHECK: %[[VAL0:.*]] = test.with_bounds
+  %0 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  // CHECK: %[[VAL1:.*]] = test.with_bounds
+  %1 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 1 : i8} : i8
+  // CHECK: %[[VAL2:.*]] = test.with_bounds
+  %2 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 42 : i8} : i8
+  // CHECK: = arith.ceildivui %[[ARG0]], %[[VAL1]]
+  // CHECK: = arith.divui %[[ARG0]], %[[VAL2]]
+  // CHECK: scf.for
+  scf.for %idx = %lb to %ub step %step {
+    // CHECK: = arith.divui %[[ARG0]], %[[VAL0]]
+    %3 = arith.divui %arg0, %0 : i8
+    %4 = arith.ceildivui %arg0, %1 : i8
+    %5 = arith.divui %arg0, %2 : i8
+    // CHECK: = arith.divui %[[ARG0]], %[[ARG0]]
+    %6 = arith.divui %arg0, %arg0 : i8
+  }
+  return
+}

>From 2bafe97d0256e084c32b2be608be17e431193aad Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Fri, 9 Aug 2024 14:31:51 +0000
Subject: [PATCH 2/3] Address review comment.

---
 mlir/include/mlir/IR/Matchers.h | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 32b15265ca2c83..77f0c87a3ecd16 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -115,10 +115,8 @@ struct infer_int_range_op_binder {
       return false;
 
     // Set the range of all integer operands to the maximal range.
-    SmallVector<IntegerValueRange> argRanges;
-    argRanges.reserve(op->getNumOperands());
-    for (Value operand : op->getOperands())
-      argRanges.emplace_back(IntegerValueRange::getMaxRange(operand));
+    SmallVector<IntegerValueRange> argRanges =
+        llvm::map_to_vector(op->getOperands(), IntegerValueRange::getMaxRange);
 
     // Infer the result result range if possible.
     bool matched = false;

>From a52421b8d213cad10171785d6bbf591ae2d7bc9e Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Tue, 13 Aug 2024 16:28:30 +0000
Subject: [PATCH 3/3] Add support for signed divisions.

---
 mlir/include/mlir/IR/Matchers.h               |  9 +++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 47 ++++++++-------
 .../Arith/hoist-speculatable-division.mlir    | 59 +++++++++++++++++--
 3 files changed, 88 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 77f0c87a3ecd16..2ffcf96ce15891 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -451,6 +451,15 @@ inline detail::constant_int_range_predicate_matcher m_NonZeroU() {
   return {[](const ConstantIntRanges &range) { return range.umin().ugt(0); }};
 }
 
+/// Matches a constant scalar / vector splat / tensor splat integer or a
+/// signed integer range that does not contain zero, i.e., that is strictly
+/// positive or strictly negative. The target value is interpreted as signed.
+inline detail::constant_int_range_predicate_matcher m_NonZeroS() {
+  return {[](const ConstantIntRanges &range) {
+    return range.smin().sgt(0) || range.smax().slt(0);
+  }};
+}
+
 /// Matches a constant scalar / vector splat / tensor splat integer one.
 inline detail::constant_int_predicate_matcher m_One() {
   return {[](const APInt &value) { return 1 == value; }};
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 31f09d54f9759f..fa19525c8a305e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -595,10 +595,15 @@ OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
   return div0 ? Attribute() : result;
 }
 
-Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
+/// Returns whether an unsigned division by `divisor` is speculatable.
+static Speculation::Speculatability getDivUISpeculatability(Value divisor) {
   // X / 0 => UB
-  return matchPattern(getRhs(), m_NonZeroU()) ? Speculation::Speculatable
-                                              : Speculation::NotSpeculatable;
+  return matchPattern(divisor, m_NonZeroU()) ? Speculation::Speculatable
+                                             : Speculation::NotSpeculatable;
+}
+
+Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
+  return getDivUISpeculatability(getRhs());
 }
 
 //===----------------------------------------------------------------------===//
@@ -624,16 +629,24 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
   return overflowOrDiv0 ? Attribute() : result;
 }
 
-Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
-  bool mayHaveUB = true;
+/// Returns whether a signed division by `divisor` is speculatable. This
+/// function conservatively assumes that all signed divisions by -1 are not
+/// speculatable.
+static Speculation::Speculatability getDivSISpeculatability(Value divisor) {
+  // INT_MIN / -1 => UB
+  APInt constDivisor;
+  if (matchPattern(divisor, m_ConstantInt(&constDivisor)) &&
+      constDivisor.isAllOnes())
+    return Speculation::NotSpeculatable;
 
-  APInt constRHS;
   // X / 0 => UB
-  // INT_MIN / -1 => UB
-  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
-    mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
+  if (matchPattern(divisor, m_NonZeroS()))
+    return Speculation::Speculatable;
+  return Speculation::NotSpeculatable;
+}
 
-  return mayHaveUB ? Speculation::NotSpeculatable : Speculation::Speculatable;
+Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
+  return getDivSISpeculatability(getRhs());
 }
 
 //===----------------------------------------------------------------------===//
@@ -675,9 +688,7 @@ OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
 }
 
 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
-  // X / 0 => UB
-  return matchPattern(getRhs(), m_NonZeroU()) ? Speculation::Speculatable
-                                              : Speculation::NotSpeculatable;
+  return getDivUISpeculatability(getRhs());
 }
 
 //===----------------------------------------------------------------------===//
@@ -746,15 +757,7 @@ OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
 }
 
 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
-  bool mayHaveUB = true;
-
-  APInt constRHS;
-  // X / 0 => UB
-  // INT_MIN / -1 => UB
-  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
-    mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
-
-  return mayHaveUB ? Speculation::NotSpeculatable : Speculation::Speculatable;
+  return getDivSISpeculatability(getRhs());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
index b4bd525b502562..8b1f53dd3b6cdf 100644
--- a/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
+++ b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt -loop-invariant-code-motion -split-input-file %s | FileCheck %s
 
-// Verify a division by a non-zero constant is hoisted out of the loop.
+// Verify a speculatable division by a non-zero constant (and constant that is
+// not minus one for signed integers) is hoisted out of the loop.
 
 // CHECK-LABEL: func @match_non_zero_constant
 func.func @match_non_zero_constant(%arg0: i32) {
@@ -11,24 +12,40 @@ func.func @match_non_zero_constant(%arg0: i32) {
   %cst0 = arith.constant 0 : i32
   // CHECK: %[[CST1:.*]] = arith.constant 1 : i32
   %cst1 = arith.constant 1 : i32
+  // CHECK: %[[CSTM1:.*]] = arith.constant -1 : i32
+  %cstm1 = arith.constant -1 : i32
+  // CHECK: %[[CSTM10:.*]] = arith.constant -10 : i32
+  %cstm10 = arith.constant -10 : i32
+
   // CHECK: = arith.divui %{{.*}}, %[[CST1]]
+  // CHECK: = arith.ceildivui %{{.*}}, %[[CSTM1]]
+  // CHECK: = arith.divsi %{{.*}}, %[[CST1]]
+  // CHECK: = arith.divsi %{{.*}}, %[[CSTM10]]
   // CHECK: scf.for
   scf.for %idx = %lb to %ub step %step {
     // CHECK: = arith.divui %{{.*}}, %[[CST0]]
     %0 = arith.divui %arg0, %cst0 : i32
     %1 = arith.divui %arg0, %cst1 : i32
+    %2 = arith.ceildivui %arg0, %cstm1 : i32
+
+    // CHECK: = arith.divsi %{{.*}}, %[[CST0]]
+    %3 = arith.divsi %arg0, %cst0 : i32
+    %4 = arith.divsi %arg0, %cst1 : i32
+    %5 = arith.divsi %arg0, %cstm10 : i32
+    // CHECK: = arith.ceildivsi %{{.*}}, %[[CSTM1]]
+    %6 = arith.ceildivsi %arg0, %cstm1 : i32
   }
   return
 }
 
 // -----
 
-// Verify a division by a non-zero integer whose range is known due to the
-// InferIntRangeInterface is hoisted out of the loop.
+// Verify a division by a non-zero integer whose unsigned range is known due to
+// the InferIntRangeInterface is hoisted out of the loop.
 
-// CHECK-LABEL: func @match_integer_range
+// CHECK-LABEL: func @match_unsigned_integer_range
 // CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
-func.func @match_integer_range(%arg0: i8) {
+func.func @match_unsigned_integer_range(%arg0: i8) {
   %lb = arith.constant 0 : index
   %ub = arith.constant 10 : index
   %step = arith.constant 1 : index
@@ -52,3 +69,35 @@ func.func @match_integer_range(%arg0: i8) {
   }
   return
 }
+
+// -----
+
+// Verify a division by a non-zero integer whose signed range is known due to
+// the InferIntRangeInterface is hoisted out of the loop.
+
+// CHECK-LABEL: func @match_signed_integer_range
+// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
+func.func @match_signed_integer_range(%arg0: i8) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+
+  // CHECK: %[[VAL0:.*]] = test.with_bounds
+  %0 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  // CHECK: %[[VAL1:.*]] = test.with_bounds
+  %1 = test.with_bounds {smax = 127 : i8, smin = 1 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  // CHECK: %[[VAL2:.*]] = test.with_bounds
+  %2 = test.with_bounds {smax = -1 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  // CHECK: = arith.ceildivsi %[[ARG0]], %[[VAL1]]
+  // CHECK: = arith.divsi %[[ARG0]], %[[VAL2]]
+  // CHECK: scf.for
+  scf.for %idx = %lb to %ub step %step {
+    // CHECK: = arith.divsi %[[ARG0]], %[[VAL0]]
+    %3 = arith.divsi %arg0, %0 : i8
+    %4 = arith.ceildivsi %arg0, %1 : i8
+    %5 = arith.divsi %arg0, %2 : i8
+    // CHECK: = arith.divsi %[[ARG0]], %[[ARG0]]
+    %6 = arith.divsi %arg0, %arg0 : i8
+  }
+  return
+}



More information about the Mlir-commits mailing list