[Mlir-commits] [mlir] [MLIR] Let matchers work on int ranges (PR #102494)

Tobias Gysi llvmlistbot at llvm.org
Wed Aug 14 05:01:33 PDT 2024


https://github.com/gysit updated https://github.com/llvm/llvm-project/pull/102494

>From a99690111b72e6bd80f7d1dbc84c63f43ab31085 Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Thu, 8 Aug 2024 15:46:18 +0000
Subject: [PATCH 1/6] [MLIR] Use InferIntRangeInterface in matchers

This commit adds an m_NonZeroU matcher that, unlike the m_NonZero
matcher, matches not only constants but also operations that implement
the InferIntRangeInterface. It can then match a non-zero value
based on the inferred range. Additionally, the commit uses the new
matcher in the getSpeculatability function of Arith's unsigned integer
division. At the moment, the matcher only looks at the defining
operation to avoid expensive IR walks.

These range-based matchers can be useful when hoisting divisions out of
a loop, which requires knowing the divisor is non-zero. Just checking
for a constant divisor may not be sufficient if the divisor is,
for example, the result of an operation that returns the number of
threads of a team of threads.
---
 mlir/include/mlir/IR/Matchers.h               | 68 +++++++++++++++++++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |  8 +--
 .../Arith/hoist-speculatable-division.mlir    | 54 +++++++++++++++
 3 files changed, 126 insertions(+), 4 deletions(-)
 create mode 100644 mlir/test/Dialect/Arith/hoist-speculatable-division.mlir

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index f6417f62d09e8c..32b15265ca2c83 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -18,6 +18,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 
 namespace mlir {
 
@@ -100,6 +101,41 @@ struct constant_op_binder {
   }
 };
 
+/// A matcher that matches operations that implement the
+/// `InferIntRangeInterface` interface, and binds the inferred range.
+struct infer_int_range_op_binder {
+  IntegerValueRange *bind_value;
+
+  explicit infer_int_range_op_binder(IntegerValueRange *bind_value)
+      : bind_value(bind_value) {}
+
+  bool match(Operation *op) {
+    auto inferIntRangeOp = dyn_cast<InferIntRangeInterface>(op);
+    if (!inferIntRangeOp)
+      return false;
+
+    // Set the range of all integer operands to the maximal range.
+    SmallVector<IntegerValueRange> argRanges;
+    argRanges.reserve(op->getNumOperands());
+    for (Value operand : op->getOperands())
+      argRanges.emplace_back(IntegerValueRange::getMaxRange(operand));
+
+    // Infer the result result range if possible.
+    bool matched = false;
+    auto setResultRanges = [&](Value value,
+                               const IntegerValueRange &argRanges) {
+      if (argRanges.isUninitialized())
+        return;
+      if (value != op->getResult(0))
+        return;
+      *bind_value = argRanges;
+      matched = true;
+    };
+    inferIntRangeOp.inferResultRangesFromOptional(argRanges, setResultRanges);
+    return matched;
+  }
+};
+
 /// The matcher that matches operations that have the specified attribute
 /// name, and binds the attribute value.
 template <typename AttrT>
@@ -219,6 +255,31 @@ struct constant_int_predicate_matcher {
   }
 };
 
+/// A matcher that matches a given constant scalar / vector splat / tensor
+/// splat integer value or a constant integer range that fulfills a predicate.
+struct constant_int_range_predicate_matcher {
+  bool (*predicate)(const ConstantIntRanges &);
+
+  bool match(Attribute attr) {
+    APInt value;
+    return constant_int_value_binder(&value).match(attr) &&
+           predicate(ConstantIntRanges::constant(value));
+  }
+
+  bool match(Operation *op) {
+    // Try to match a constant integer value first.
+    APInt value;
+    if (constant_int_value_binder(&value).match(op))
+      return predicate(ConstantIntRanges::constant(value));
+
+    // Otherwise, try to match an operation that implements the
+    // `InferIntRangeInterface` interface.
+    IntegerValueRange range;
+    return infer_int_range_op_binder(&range).match(op) &&
+           predicate(range.getValue());
+  }
+};
+
 /// The matcher that matches a certain kind of op.
 template <typename OpClass>
 struct op_matcher {
@@ -385,6 +446,13 @@ inline detail::constant_int_predicate_matcher m_NonZero() {
   return {[](const APInt &value) { return 0 != value; }};
 }
 
+/// Matches a constant scalar / vector splat / tensor splat integer or a
+/// unsigned integer range that does not contain zero. Note that this matcher
+/// interprets the target value as an unsigned integer.
+inline detail::constant_int_range_predicate_matcher m_NonZeroU() {
+  return {[](const ConstantIntRanges &range) { return range.umin().ugt(0); }};
+}
+
 /// Matches a constant scalar / vector splat / tensor splat integer one.
 inline detail::constant_int_predicate_matcher m_One() {
   return {[](const APInt &value) { return 1 == value; }};
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 641b7d7e2d13be..31f09d54f9759f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -597,8 +597,8 @@ OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
 
 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
   // X / 0 => UB
-  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
-                                             : Speculation::NotSpeculatable;
+  return matchPattern(getRhs(), m_NonZeroU()) ? Speculation::Speculatable
+                                              : Speculation::NotSpeculatable;
 }
 
 //===----------------------------------------------------------------------===//
@@ -676,8 +676,8 @@ OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
 
 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
   // X / 0 => UB
-  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
-                                             : Speculation::NotSpeculatable;
+  return matchPattern(getRhs(), m_NonZeroU()) ? Speculation::Speculatable
+                                              : Speculation::NotSpeculatable;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
new file mode 100644
index 00000000000000..b4bd525b502562
--- /dev/null
+++ b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt -loop-invariant-code-motion -split-input-file %s | FileCheck %s
+
+// Verify a division by a non-zero constant is hoisted out of the loop.
+
+// CHECK-LABEL: func @match_non_zero_constant
+func.func @match_non_zero_constant(%arg0: i32) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : i32
+  %cst0 = arith.constant 0 : i32
+  // CHECK: %[[CST1:.*]] = arith.constant 1 : i32
+  %cst1 = arith.constant 1 : i32
+  // CHECK: = arith.divui %{{.*}}, %[[CST1]]
+  // CHECK: scf.for
+  scf.for %idx = %lb to %ub step %step {
+    // CHECK: = arith.divui %{{.*}}, %[[CST0]]
+    %0 = arith.divui %arg0, %cst0 : i32
+    %1 = arith.divui %arg0, %cst1 : i32
+  }
+  return
+}
+
+// -----
+
+// Verify a division by a non-zero integer whose range is known due to the
+// InferIntRangeInterface is hoisted out of the loop.
+
+// CHECK-LABEL: func @match_integer_range
+// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
+func.func @match_integer_range(%arg0: i8) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+
+  // CHECK: %[[VAL0:.*]] = test.with_bounds
+  %0 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  // CHECK: %[[VAL1:.*]] = test.with_bounds
+  %1 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 1 : i8} : i8
+  // CHECK: %[[VAL2:.*]] = test.with_bounds
+  %2 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 42 : i8} : i8
+  // CHECK: = arith.ceildivui %[[ARG0]], %[[VAL1]]
+  // CHECK: = arith.divui %[[ARG0]], %[[VAL2]]
+  // CHECK: scf.for
+  scf.for %idx = %lb to %ub step %step {
+    // CHECK: = arith.divui %[[ARG0]], %[[VAL0]]
+    %3 = arith.divui %arg0, %0 : i8
+    %4 = arith.ceildivui %arg0, %1 : i8
+    %5 = arith.divui %arg0, %2 : i8
+    // CHECK: = arith.divui %[[ARG0]], %[[ARG0]]
+    %6 = arith.divui %arg0, %arg0 : i8
+  }
+  return
+}

>From 8b6b45e481a6d3ed45d7d1efa4eb4b92652538fd Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Fri, 9 Aug 2024 14:31:51 +0000
Subject: [PATCH 2/6] Address review comment.

---
 mlir/include/mlir/IR/Matchers.h | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 32b15265ca2c83..77f0c87a3ecd16 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -115,10 +115,8 @@ struct infer_int_range_op_binder {
       return false;
 
     // Set the range of all integer operands to the maximal range.
-    SmallVector<IntegerValueRange> argRanges;
-    argRanges.reserve(op->getNumOperands());
-    for (Value operand : op->getOperands())
-      argRanges.emplace_back(IntegerValueRange::getMaxRange(operand));
+    SmallVector<IntegerValueRange> argRanges =
+        llvm::map_to_vector(op->getOperands(), IntegerValueRange::getMaxRange);
 
     // Infer the result result range if possible.
     bool matched = false;

>From 17adb4139d2e881352b07f5895843b0a67e14ce8 Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Tue, 13 Aug 2024 16:28:30 +0000
Subject: [PATCH 3/6] Add support for signed divisions.

---
 mlir/include/mlir/IR/Matchers.h               |  9 +++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        | 47 ++++++++-------
 .../Arith/hoist-speculatable-division.mlir    | 59 +++++++++++++++++--
 3 files changed, 88 insertions(+), 27 deletions(-)

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 77f0c87a3ecd16..2ffcf96ce15891 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -451,6 +451,15 @@ inline detail::constant_int_range_predicate_matcher m_NonZeroU() {
   return {[](const ConstantIntRanges &range) { return range.umin().ugt(0); }};
 }
 
+/// Matches a constant scalar / vector splat / tensor splat integer or a
+/// signed integer range that does not contain zero. Note that this matcher
+/// interprets the target value as a signed integer.
+inline detail::constant_int_range_predicate_matcher m_NonZeroS() {
+  return {[](const ConstantIntRanges &range) {
+    return range.smin().sgt(0) || range.smax().slt(0);
+  }};
+}
+
 /// Matches a constant scalar / vector splat / tensor splat integer one.
 inline detail::constant_int_predicate_matcher m_One() {
   return {[](const APInt &value) { return 1 == value; }};
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 31f09d54f9759f..fa19525c8a305e 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -595,10 +595,15 @@ OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
   return div0 ? Attribute() : result;
 }
 
-Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
+/// Returns whether an unsigned division by `divisor` is speculatable.
+static Speculation::Speculatability getDivUISpeculatability(Value divisor) {
   // X / 0 => UB
-  return matchPattern(getRhs(), m_NonZeroU()) ? Speculation::Speculatable
-                                              : Speculation::NotSpeculatable;
+  return matchPattern(divisor, m_NonZeroU()) ? Speculation::Speculatable
+                                             : Speculation::NotSpeculatable;
+}
+
+Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
+  return getDivUISpeculatability(getRhs());
 }
 
 //===----------------------------------------------------------------------===//
@@ -624,16 +629,24 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
   return overflowOrDiv0 ? Attribute() : result;
 }
 
-Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
-  bool mayHaveUB = true;
+/// Returns whether a signed division by `divisor` is speculatable. This
+/// function conservatively assumes that all signed division by -1 are not
+/// speculatable.
+static Speculation::Speculatability getDivSISpeculatability(Value divisor) {
+  // INT_MIN / -1 => UB
+  APInt constDivisor;
+  if (matchPattern(divisor, m_ConstantInt(&constDivisor)) &&
+      constDivisor.isAllOnes())
+    return Speculation::NotSpeculatable;
 
-  APInt constRHS;
   // X / 0 => UB
-  // INT_MIN / -1 => UB
-  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
-    mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
+  if (matchPattern(divisor, m_NonZeroS()))
+    return Speculation::Speculatable;
+  return Speculation::NotSpeculatable;
+}
 
-  return mayHaveUB ? Speculation::NotSpeculatable : Speculation::Speculatable;
+Speculation::Speculatability arith::DivSIOp::getSpeculatability() {
+  return getDivSISpeculatability(getRhs());
 }
 
 //===----------------------------------------------------------------------===//
@@ -675,9 +688,7 @@ OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
 }
 
 Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
-  // X / 0 => UB
-  return matchPattern(getRhs(), m_NonZeroU()) ? Speculation::Speculatable
-                                              : Speculation::NotSpeculatable;
+  return getDivUISpeculatability(getRhs());
 }
 
 //===----------------------------------------------------------------------===//
@@ -746,15 +757,7 @@ OpFoldResult arith::CeilDivSIOp::fold(FoldAdaptor adaptor) {
 }
 
 Speculation::Speculatability arith::CeilDivSIOp::getSpeculatability() {
-  bool mayHaveUB = true;
-
-  APInt constRHS;
-  // X / 0 => UB
-  // INT_MIN / -1 => UB
-  if (matchPattern(getRhs(), m_ConstantInt(&constRHS)))
-    mayHaveUB = constRHS.isAllOnes() || constRHS.isZero();
-
-  return mayHaveUB ? Speculation::NotSpeculatable : Speculation::Speculatable;
+  return getDivSISpeculatability(getRhs());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
index b4bd525b502562..8b1f53dd3b6cdf 100644
--- a/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
+++ b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
@@ -1,6 +1,7 @@
 // RUN: mlir-opt -loop-invariant-code-motion -split-input-file %s | FileCheck %s
 
-// Verify a division by a non-zero constant is hoisted out of the loop.
+// Verify a speculatable division by a non-zero constant (and constant that is
+// not minus one for signed integers) is hoisted out of the loop.
 
 // CHECK-LABEL: func @match_non_zero_constant
 func.func @match_non_zero_constant(%arg0: i32) {
@@ -11,24 +12,40 @@ func.func @match_non_zero_constant(%arg0: i32) {
   %cst0 = arith.constant 0 : i32
   // CHECK: %[[CST1:.*]] = arith.constant 1 : i32
   %cst1 = arith.constant 1 : i32
+  // CHECK: %[[CSTM1:.*]] = arith.constant -1 : i32
+  %cstm1 = arith.constant -1 : i32
+  // CHECK: %[[CSTM10:.*]] = arith.constant -10 : i32
+  %cstm10 = arith.constant -10 : i32
+
   // CHECK: = arith.divui %{{.*}}, %[[CST1]]
+  // CHECK: = arith.ceildivui %{{.*}}, %[[CSTM1]]
+  // CHECK: = arith.divsi %{{.*}}, %[[CST1]]
+  // CHECK: = arith.divsi %{{.*}}, %[[CSTM10]]
   // CHECK: scf.for
   scf.for %idx = %lb to %ub step %step {
     // CHECK: = arith.divui %{{.*}}, %[[CST0]]
     %0 = arith.divui %arg0, %cst0 : i32
     %1 = arith.divui %arg0, %cst1 : i32
+    %2 = arith.ceildivui %arg0, %cstm1 : i32
+
+    // CHECK: = arith.divsi %{{.*}}, %[[CST0]]
+    %3 = arith.divsi %arg0, %cst0 : i32
+    %4 = arith.divsi %arg0, %cst1 : i32
+    %5 = arith.divsi %arg0, %cstm10 : i32
+    // CHECK: = arith.ceildivsi %{{.*}}, %[[CSTM1]]
+    %6 = arith.ceildivsi %arg0, %cstm1 : i32
   }
   return
 }
 
 // -----
 
-// Verify a division by a non-zero integer whose range is known due to the
-// InferIntRangeInterface is hoisted out of the loop.
+// Verify a division by a non-zero integer whose unsigned range is known due to
+// the InferIntRangeInterface is hoisted out of the loop.
 
-// CHECK-LABEL: func @match_integer_range
+// CHECK-LABEL: func @match_unsigned_integer_range
 // CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
-func.func @match_integer_range(%arg0: i8) {
+func.func @match_unsigned_integer_range(%arg0: i8) {
   %lb = arith.constant 0 : index
   %ub = arith.constant 10 : index
   %step = arith.constant 1 : index
@@ -52,3 +69,35 @@ func.func @match_integer_range(%arg0: i8) {
   }
   return
 }
+
+// -----
+
+// Verify a division by a non-zero integer whose signed range is known due to
+// the InferIntRangeInterface is hoisted out of the loop.
+
+// CHECK-LABEL: func @match_signed_integer_range
+// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
+func.func @match_signed_integer_range(%arg0: i8) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+
+  // CHECK: %[[VAL0:.*]] = test.with_bounds
+  %0 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  // CHECK: %[[VAL1:.*]] = test.with_bounds
+  %1 = test.with_bounds {smax = 127 : i8, smin = 1 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  // CHECK: %[[VAL2:.*]] = test.with_bounds
+  %2 = test.with_bounds {smax = -1 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  // CHECK: = arith.ceildivsi %[[ARG0]], %[[VAL1]]
+  // CHECK: = arith.divsi %[[ARG0]], %[[VAL2]]
+  // CHECK: scf.for
+  scf.for %idx = %lb to %ub step %step {
+    // CHECK: = arith.divsi %[[ARG0]], %[[VAL0]]
+    %3 = arith.divsi %arg0, %0 : i8
+    %4 = arith.ceildivsi %arg0, %1 : i8
+    %5 = arith.divsi %arg0, %2 : i8
+    // CHECK: = arith.divsi %[[ARG0]], %[[ARG0]]
+    %6 = arith.divsi %arg0, %arg0 : i8
+  }
+  return
+}

>From 539459859af31a5a6279d1a2f21002285befbcf9 Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Tue, 13 Aug 2024 18:40:17 +0000
Subject: [PATCH 4/6] Rename the new matchers.

---
 mlir/include/mlir/IR/Matchers.h        | 4 ++--
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 7 ++++---
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 2ffcf96ce15891..0e23faf84890e2 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -447,14 +447,14 @@ inline detail::constant_int_predicate_matcher m_NonZero() {
 /// Matches a constant scalar / vector splat / tensor splat integer or a
 /// unsigned integer range that does not contain zero. Note that this matcher
 /// interprets the target value as an unsigned integer.
-inline detail::constant_int_range_predicate_matcher m_NonZeroU() {
+inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroU() {
   return {[](const ConstantIntRanges &range) { return range.umin().ugt(0); }};
 }
 
 /// Matches a constant scalar / vector splat / tensor splat integer or a
 /// signed integer range that does not contain zero. Note that this matcher
 /// interprets the target value as a signed integer.
-inline detail::constant_int_range_predicate_matcher m_NonZeroS() {
+inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS() {
   return {[](const ConstantIntRanges &range) {
     return range.smin().sgt(0) || range.smax().slt(0);
   }};
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index fa19525c8a305e..12303b5f72d912 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -598,8 +598,9 @@ OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
 /// Returns whether an unsigned division by `divisor` is speculatable.
 static Speculation::Speculatability getDivUISpeculatability(Value divisor) {
   // X / 0 => UB
-  return matchPattern(divisor, m_NonZeroU()) ? Speculation::Speculatable
-                                             : Speculation::NotSpeculatable;
+  return matchPattern(divisor, m_IntRangeWithoutZeroU())
+             ? Speculation::Speculatable
+             : Speculation::NotSpeculatable;
 }
 
 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
@@ -640,7 +641,7 @@ static Speculation::Speculatability getDivSISpeculatability(Value divisor) {
     return Speculation::NotSpeculatable;
 
   // X / 0 => UB
-  if (matchPattern(divisor, m_NonZeroS()))
+  if (matchPattern(divisor, m_IntRangeWithoutZeroS()))
     return Speculation::Speculatable;
   return Speculation::NotSpeculatable;
 }

>From 298314a45911fa217e9b7e3019f5a3a127c40604 Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Wed, 14 Aug 2024 10:33:52 +0000
Subject: [PATCH 5/6] Fix an issue with the signed division.

Additionally, move the test cases to the existing
loop invariant code motion tests.
---
 mlir/include/mlir/IR/Matchers.h               |   9 ++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |  18 ++-
 .../Arith/hoist-speculatable-division.mlir    | 103 ----------------
 .../loop-invariant-code-motion.mlir           | 114 ++++++++++++++++++
 4 files changed, 131 insertions(+), 113 deletions(-)
 delete mode 100644 mlir/test/Dialect/Arith/hoist-speculatable-division.mlir

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index 0e23faf84890e2..6fa5a47109d20d 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -460,6 +460,15 @@ inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutZeroS() {
   }};
 }
 
+/// Matches a constant scalar / vector splat / tensor splat integer or a
+/// signed integer range that does not contain minus one. Note
+/// that this matcher interprets the target value as a signed integer.
+inline detail::constant_int_range_predicate_matcher m_IntRangeWithoutNegOneS() {
+  return {[](const ConstantIntRanges &range) {
+    return range.smin().sgt(-1) || range.smax().slt(-1);
+  }};
+}
+
 /// Matches a constant scalar / vector splat / tensor splat integer one.
 inline detail::constant_int_predicate_matcher m_One() {
   return {[](const APInt &value) { return 1 == value; }};
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 12303b5f72d912..254f54d9e459e1 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -598,9 +598,10 @@ OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
 /// Returns whether an unsigned division by `divisor` is speculatable.
 static Speculation::Speculatability getDivUISpeculatability(Value divisor) {
   // X / 0 => UB
-  return matchPattern(divisor, m_IntRangeWithoutZeroU())
-             ? Speculation::Speculatable
-             : Speculation::NotSpeculatable;
+  if (matchPattern(divisor, m_IntRangeWithoutZeroU()))
+    return Speculation::Speculatable;
+
+  return Speculation::NotSpeculatable;
 }
 
 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
@@ -634,15 +635,12 @@ OpFoldResult arith::DivSIOp::fold(FoldAdaptor adaptor) {
 /// function conservatively assumes that all signed division by -1 are not
 /// speculatable.
 static Speculation::Speculatability getDivSISpeculatability(Value divisor) {
-  // INT_MIN / -1 => UB
-  APInt constDivisor;
-  if (matchPattern(divisor, m_ConstantInt(&constDivisor)) &&
-      constDivisor.isAllOnes())
-    return Speculation::NotSpeculatable;
-
   // X / 0 => UB
-  if (matchPattern(divisor, m_IntRangeWithoutZeroS()))
+  // INT_MIN / -1 => UB
+  if (matchPattern(divisor, m_IntRangeWithoutZeroS()) &&
+      matchPattern(divisor, m_IntRangeWithoutNegOneS()))
     return Speculation::Speculatable;
+
   return Speculation::NotSpeculatable;
 }
 
diff --git a/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
deleted file mode 100644
index 8b1f53dd3b6cdf..00000000000000
--- a/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
+++ /dev/null
@@ -1,103 +0,0 @@
-// RUN: mlir-opt -loop-invariant-code-motion -split-input-file %s | FileCheck %s
-
-// Verify a speculatable division by a non-zero constant (and constant that is
-// not minus one for signed integers) is hoisted out of the loop.
-
-// CHECK-LABEL: func @match_non_zero_constant
-func.func @match_non_zero_constant(%arg0: i32) {
-  %lb = arith.constant 0 : index
-  %ub = arith.constant 10 : index
-  %step = arith.constant 1 : index
-  // CHECK: %[[CST0:.*]] = arith.constant 0 : i32
-  %cst0 = arith.constant 0 : i32
-  // CHECK: %[[CST1:.*]] = arith.constant 1 : i32
-  %cst1 = arith.constant 1 : i32
-  // CHECK: %[[CSTM1:.*]] = arith.constant -1 : i32
-  %cstm1 = arith.constant -1 : i32
-  // CHECK: %[[CSTM10:.*]] = arith.constant -10 : i32
-  %cstm10 = arith.constant -10 : i32
-
-  // CHECK: = arith.divui %{{.*}}, %[[CST1]]
-  // CHECK: = arith.ceildivui %{{.*}}, %[[CSTM1]]
-  // CHECK: = arith.divsi %{{.*}}, %[[CST1]]
-  // CHECK: = arith.divsi %{{.*}}, %[[CSTM10]]
-  // CHECK: scf.for
-  scf.for %idx = %lb to %ub step %step {
-    // CHECK: = arith.divui %{{.*}}, %[[CST0]]
-    %0 = arith.divui %arg0, %cst0 : i32
-    %1 = arith.divui %arg0, %cst1 : i32
-    %2 = arith.ceildivui %arg0, %cstm1 : i32
-
-    // CHECK: = arith.divsi %{{.*}}, %[[CST0]]
-    %3 = arith.divsi %arg0, %cst0 : i32
-    %4 = arith.divsi %arg0, %cst1 : i32
-    %5 = arith.divsi %arg0, %cstm10 : i32
-    // CHECK: = arith.ceildivsi %{{.*}}, %[[CSTM1]]
-    %6 = arith.ceildivsi %arg0, %cstm1 : i32
-  }
-  return
-}
-
-// -----
-
-// Verify a division by a non-zero integer whose unsigned range is known due to
-// the InferIntRangeInterface is hoisted out of the loop.
-
-// CHECK-LABEL: func @match_unsigned_integer_range
-// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
-func.func @match_unsigned_integer_range(%arg0: i8) {
-  %lb = arith.constant 0 : index
-  %ub = arith.constant 10 : index
-  %step = arith.constant 1 : index
-
-  // CHECK: %[[VAL0:.*]] = test.with_bounds
-  %0 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
-  // CHECK: %[[VAL1:.*]] = test.with_bounds
-  %1 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 1 : i8} : i8
-  // CHECK: %[[VAL2:.*]] = test.with_bounds
-  %2 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 42 : i8} : i8
-  // CHECK: = arith.ceildivui %[[ARG0]], %[[VAL1]]
-  // CHECK: = arith.divui %[[ARG0]], %[[VAL2]]
-  // CHECK: scf.for
-  scf.for %idx = %lb to %ub step %step {
-    // CHECK: = arith.divui %[[ARG0]], %[[VAL0]]
-    %3 = arith.divui %arg0, %0 : i8
-    %4 = arith.ceildivui %arg0, %1 : i8
-    %5 = arith.divui %arg0, %2 : i8
-    // CHECK: = arith.divui %[[ARG0]], %[[ARG0]]
-    %6 = arith.divui %arg0, %arg0 : i8
-  }
-  return
-}
-
-// -----
-
-// Verify a division by a non-zero integer whose signed range is known due to
-// the InferIntRangeInterface is hoisted out of the loop.
-
-// CHECK-LABEL: func @match_signed_integer_range
-// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
-func.func @match_signed_integer_range(%arg0: i8) {
-  %lb = arith.constant 0 : index
-  %ub = arith.constant 10 : index
-  %step = arith.constant 1 : index
-
-  // CHECK: %[[VAL0:.*]] = test.with_bounds
-  %0 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
-  // CHECK: %[[VAL1:.*]] = test.with_bounds
-  %1 = test.with_bounds {smax = 127 : i8, smin = 1 : i8, umax = 255 : i8, umin = 0 : i8} : i8
-  // CHECK: %[[VAL2:.*]] = test.with_bounds
-  %2 = test.with_bounds {smax = -1 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
-  // CHECK: = arith.ceildivsi %[[ARG0]], %[[VAL1]]
-  // CHECK: = arith.divsi %[[ARG0]], %[[VAL2]]
-  // CHECK: scf.for
-  scf.for %idx = %lb to %ub step %step {
-    // CHECK: = arith.divsi %[[ARG0]], %[[VAL0]]
-    %3 = arith.divsi %arg0, %0 : i8
-    %4 = arith.ceildivsi %arg0, %1 : i8
-    %5 = arith.divsi %arg0, %2 : i8
-    // CHECK: = arith.divsi %[[ARG0]], %[[ARG0]]
-    %6 = arith.divsi %arg0, %arg0 : i8
-  }
-  return
-}
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index dcc314f36ae0a8..063dcd15c4efbb 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -923,6 +923,120 @@ func.func @speculate_ceildivsi_const(
   return
 }
 
+func.func @no_speculate_divui_range(
+// CHECK-LABEL: @no_speculate_divui_range(
+    %num: i8, %lb: index, %ub: index, %step: index) {
+  %denom = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  scf.for %i = %lb to %ub step %step {
+// CHECK: scf.for
+// CHECK: arith.divui
+    %val = arith.divui %num, %denom : i8
+  }
+
+  return
+}
+
+func.func @no_speculate_divsi_range(
+// CHECK-LABEL: @no_speculate_divsi_range(
+    %num: i8, %lb: index, %ub: index, %step: index) {
+  %denom0 = test.with_bounds {smax = -1: i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  %denom1 = test.with_bounds {smax = 127 : i8, smin = 0 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  scf.for %i = %lb to %ub step %step {
+// CHECK: scf.for
+// CHECK-COUNT2: arith.divsi
+    %val0 = arith.divsi %num, %denom0 : i8
+    %val1 = arith.divsi %num, %denom1 : i8
+  }
+
+  return
+}
+
+func.func @no_speculate_ceildivui_range(
+// CHECK-LABEL: @no_speculate_ceildivui_range(
+    %num: i8, %lb: index, %ub: index, %step: index) {
+  %denom = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  scf.for %i = %lb to %ub step %step {
+// CHECK: scf.for
+// CHECK: arith.ceildivui
+    %val = arith.ceildivui %num, %denom : i8
+  }
+
+  return
+}
+
+func.func @no_speculate_ceildivsi_range(
+// CHECK-LABEL: @no_speculate_ceildivsi_range(
+    %num: i8, %lb: index, %ub: index, %step: index) {
+  %denom0 = test.with_bounds {smax = -1 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  %denom1 = test.with_bounds {smax = 127 : i8, smin = 0 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  scf.for %i = %lb to %ub step %step {
+// CHECK: scf.for
+// CHECK-COUNT2: arith.ceildivsi
+    %val0 = arith.ceildivsi %num, %denom0 : i8
+    %val1 = arith.ceildivsi %num, %denom1 : i8
+  }
+
+  return
+}
+
+func.func @speculate_divui_range(
+// CHECK-LABEL: @speculate_divui_range(
+    %num: i8, %lb: index, %ub: index, %step: index) {
+  %denom = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 1 : i8} : i8
+  scf.for %i = %lb to %ub step %step {
+// CHECK: arith.divui
+// CHECK: scf.for
+    %val = arith.divui %num, %denom : i8
+  }
+
+  return
+}
+
+func.func @speculate_divsi_range(
+// CHECK-LABEL: @speculate_divsi_range(
+    %num: i8, %lb: index, %ub: index, %step: index) {
+  %denom0 = test.with_bounds {smax = 127 : i8, smin = 1 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  %denom1 = test.with_bounds {smax = -2 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  scf.for %i = %lb to %ub step %step {
+// CHECK-COUNT2: arith.divsi
+// CHECK: scf.for
+    %val0 = arith.divsi %num, %denom0 : i8
+    %val1 = arith.divsi %num, %denom1 : i8
+
+  }
+
+  return
+}
+
+func.func @speculate_ceildivui_range(
+// CHECK-LABEL: @speculate_ceildivui_range(
+    %num: i8, %lb: index, %ub: index, %step: index) {
+  %denom = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 1 : i8} : i8
+  scf.for %i = %lb to %ub step %step {
+// CHECK: arith.ceildivui
+// CHECK: scf.for
+    %val = arith.ceildivui %num, %denom : i8
+  }
+
+  return
+}
+
+func.func @speculate_ceildivsi_range(
+// CHECK-LABEL: @speculate_ceildivsi_range(
+    %num: i8, %lb: index, %ub: index, %step: index) {
+  %denom0 = test.with_bounds {smax = 127 : i8, smin = 1 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  %denom1 = test.with_bounds {smax = -2 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  scf.for %i = %lb to %ub step %step {
+// CHECK-COUNT2: arith.ceildivsi
+// CHECK: scf.for
+    %val0 = arith.ceildivsi %num, %denom0 : i8
+    %val1 = arith.ceildivsi %num, %denom1 : i8
+
+  }
+
+  return
+}
+
 // -----
 
 func.func @speculate_static_pack_and_unpack(%source: tensor<128x256xf32>,

>From 196d21a66633932af4bc09bde02bd9207e170adb Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Wed, 14 Aug 2024 12:01:06 +0000
Subject: [PATCH 6/6] Fix check count.

---
 mlir/test/Transforms/loop-invariant-code-motion.mlir | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index 063dcd15c4efbb..47a49465e8a7cd 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -943,7 +943,7 @@ func.func @no_speculate_divsi_range(
   %denom1 = test.with_bounds {smax = 127 : i8, smin = 0 : i8, umax = 255 : i8, umin = 0 : i8} : i8
   scf.for %i = %lb to %ub step %step {
 // CHECK: scf.for
-// CHECK-COUNT2: arith.divsi
+// CHECK-COUNT-2: arith.divsi
     %val0 = arith.divsi %num, %denom0 : i8
     %val1 = arith.divsi %num, %denom1 : i8
   }
@@ -971,7 +971,7 @@ func.func @no_speculate_ceildivsi_range(
   %denom1 = test.with_bounds {smax = 127 : i8, smin = 0 : i8, umax = 255 : i8, umin = 0 : i8} : i8
   scf.for %i = %lb to %ub step %step {
 // CHECK: scf.for
-// CHECK-COUNT2: arith.ceildivsi
+// CHECK-COUNT-2: arith.ceildivsi
     %val0 = arith.ceildivsi %num, %denom0 : i8
     %val1 = arith.ceildivsi %num, %denom1 : i8
   }
@@ -998,7 +998,7 @@ func.func @speculate_divsi_range(
   %denom0 = test.with_bounds {smax = 127 : i8, smin = 1 : i8, umax = 255 : i8, umin = 0 : i8} : i8
   %denom1 = test.with_bounds {smax = -2 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
   scf.for %i = %lb to %ub step %step {
-// CHECK-COUNT2: arith.divsi
+// CHECK-COUNT-2: arith.divsi
 // CHECK: scf.for
     %val0 = arith.divsi %num, %denom0 : i8
     %val1 = arith.divsi %num, %denom1 : i8
@@ -1027,7 +1027,7 @@ func.func @speculate_ceildivsi_range(
   %denom0 = test.with_bounds {smax = 127 : i8, smin = 1 : i8, umax = 255 : i8, umin = 0 : i8} : i8
   %denom1 = test.with_bounds {smax = -2 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
   scf.for %i = %lb to %ub step %step {
-// CHECK-COUNT2: arith.ceildivsi
+// CHECK-COUNT-2: arith.ceildivsi
 // CHECK: scf.for
     %val0 = arith.ceildivsi %num, %denom0 : i8
     %val1 = arith.ceildivsi %num, %denom1 : i8



More information about the Mlir-commits mailing list