[Mlir-commits] [mlir] [MLIR] Let matchers work on int ranges (PR #102494)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Aug 9 02:03:05 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-arith
Author: Tobias Gysi (gysit)
<details>
<summary>Changes</summary>
This commit adds an `m_NonZeroU` matcher that, unlike the `m_NonZero` matcher, matches not only constants but also operations that implement the `InferIntRangeInterface`. It can then match a non-zero value
based on the inferred range. Additionally, the commit uses the new matcher in the `getSpeculatability` functions of Arith's unsigned integer divisions (`arith.divui` and `arith.ceildivui`). At the moment, the matcher only looks at the defining operation to avoid expensive IR walks.
These range-based matchers can be useful when hoisting divisions out of a loop, which requires knowing that the divisor is non-zero. Just checking for a constant divisor may not be sufficient if the divisor is, for example, the result of an operation that returns the number of threads in a team of threads.
---
Full diff: https://github.com/llvm/llvm-project/pull/102494.diff
3 Files Affected:
- (modified) mlir/include/mlir/IR/Matchers.h (+68)
- (modified) mlir/lib/Dialect/Arith/IR/ArithOps.cpp (+4-4)
- (added) mlir/test/Dialect/Arith/hoist-speculatable-division.mlir (+54)
``````````diff
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index f6417f62d09e8c..32b15265ca2c83 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
namespace mlir {
@@ -100,6 +101,41 @@ struct constant_op_binder {
}
};
+/// A matcher that matches single-result operations that implement the
+/// `InferIntRangeInterface` interface, and binds the range inferred for their
+/// result. Operand ranges are conservatively assumed to be maximal, so only
+/// the operation itself (not its producers) contributes to the inferred range.
+struct infer_int_range_op_binder {
+  IntegerValueRange *bind_value;
+
+  explicit infer_int_range_op_binder(IntegerValueRange *bind_value)
+      : bind_value(bind_value) {}
+
+  bool match(Operation *op) {
+    auto inferIntRangeOp = dyn_cast<InferIntRangeInterface>(op);
+    if (!inferIntRangeOp)
+      return false;
+
+    // Matchers receive the defining op of a value, so for a multi-result op
+    // we cannot tell which result is being matched. Binding the range of
+    // result 0 would then be wrong for values produced by any other result.
+    if (op->getNumResults() != 1)
+      return false;
+
+    // Set the range of all integer operands to the maximal range.
+    SmallVector<IntegerValueRange> argRanges;
+    argRanges.reserve(op->getNumOperands());
+    for (Value operand : op->getOperands())
+      argRanges.emplace_back(IntegerValueRange::getMaxRange(operand));
+
+    // Infer the result range if possible.
+    bool matched = false;
+    Value result = op->getResult(0);
+    auto setResultRanges = [&](Value value,
+                               const IntegerValueRange &resultRange) {
+      // An uninitialized range means inference produced no information.
+      if (resultRange.isUninitialized())
+        return;
+      if (value != result)
+        return;
+      *bind_value = resultRange;
+      matched = true;
+    };
+    inferIntRangeOp.inferResultRangesFromOptional(argRanges, setResultRanges);
+    return matched;
+  }
+};
+
/// The matcher that matches operations that have the specified attribute
/// name, and binds the attribute value.
template <typename AttrT>
@@ -219,6 +255,31 @@ struct constant_int_predicate_matcher {
}
};
+/// A matcher that accepts a constant scalar / vector splat / tensor splat
+/// integer value, or the integer range inferred for an operation implementing
+/// `InferIntRangeInterface`, whenever that range satisfies `predicate`.
+struct constant_int_range_predicate_matcher {
+  bool (*predicate)(const ConstantIntRanges &);
+
+  bool match(Attribute attr) {
+    APInt constVal;
+    if (!constant_int_value_binder(&constVal).match(attr))
+      return false;
+    return predicate(ConstantIntRanges::constant(constVal));
+  }
+
+  bool match(Operation *op) {
+    // Prefer an exact constant, which yields a single-value range.
+    APInt constVal;
+    if (constant_int_value_binder(&constVal).match(op))
+      return predicate(ConstantIntRanges::constant(constVal));
+
+    // Fall back to the range inferred through `InferIntRangeInterface`.
+    IntegerValueRange inferred;
+    if (!infer_int_range_op_binder(&inferred).match(op))
+      return false;
+    return predicate(inferred.getValue());
+  }
+};
+
/// The matcher that matches a certain kind of op.
template <typename OpClass>
struct op_matcher {
@@ -385,6 +446,13 @@ inline detail::constant_int_predicate_matcher m_NonZero() {
return {[](const APInt &value) { return 0 != value; }};
}
+/// Matches a constant scalar / vector splat / tensor splat integer, or an
+/// inferred integer range, whose unsigned minimum is non-zero. Note that this
+/// matcher interprets the target value as an unsigned integer.
+inline detail::constant_int_range_predicate_matcher m_NonZeroU() {
+  auto excludesZero = [](const ConstantIntRanges &range) {
+    return !range.umin().isZero();
+  };
+  return {excludesZero};
+}
+
/// Matches a constant scalar / vector splat / tensor splat integer one.
inline detail::constant_int_predicate_matcher m_One() {
return {[](const APInt &value) { return 1 == value; }};
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 641b7d7e2d13be..31f09d54f9759f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -597,8 +597,8 @@ OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
// X / 0 => UB
- return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
- : Speculation::NotSpeculatable;
+  // Speculatable only when the divisor is provably non-zero, either as a
+  // constant or through its inferred unsigned range.
+  if (matchPattern(getRhs(), m_NonZeroU()))
+    return Speculation::Speculatable;
+  return Speculation::NotSpeculatable;
}
//===----------------------------------------------------------------------===//
@@ -676,8 +676,8 @@ OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
// X / 0 => UB
- return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
- : Speculation::NotSpeculatable;
+  // Speculatable only when the divisor is provably non-zero, either as a
+  // constant or through its inferred unsigned range.
+  if (matchPattern(getRhs(), m_NonZeroU()))
+    return Speculation::Speculatable;
+  return Speculation::NotSpeculatable;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
new file mode 100644
index 00000000000000..b4bd525b502562
--- /dev/null
+++ b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt -loop-invariant-code-motion -split-input-file %s | FileCheck %s
+
+// Verify a division by a non-zero constant is hoisted out of the loop, while
+// a division by the constant zero stays inside it.
+
+// CHECK-LABEL: func @match_non_zero_constant
+func.func @match_non_zero_constant(%arg0: i32) {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 10 : index
+ %step = arith.constant 1 : index
+ // CHECK: %[[CST0:.*]] = arith.constant 0 : i32
+ %cst0 = arith.constant 0 : i32
+ // CHECK: %[[CST1:.*]] = arith.constant 1 : i32
+ %cst1 = arith.constant 1 : i32
+ // The division by one is expected before the loop (hoisted).
+ // CHECK: = arith.divui %{{.*}}, %[[CST1]]
+ // CHECK: scf.for
+ scf.for %idx = %lb to %ub step %step {
+ // Dividing by zero is UB, so this division is not speculatable and must
+ // remain inside the loop body.
+ // CHECK: = arith.divui %{{.*}}, %[[CST0]]
+ %0 = arith.divui %arg0, %cst0 : i32
+ %1 = arith.divui %arg0, %cst1 : i32
+ }
+ return
+}
+
+// -----
+
+// Verify a division by a non-zero integer whose range is known due to the
+// InferIntRangeInterface is hoisted out of the loop. Divisors whose unsigned
+// range may include zero, or that have no defining operation, are not hoisted.
+
+// CHECK-LABEL: func @match_integer_range
+// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
+func.func @match_integer_range(%arg0: i8) {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 10 : index
+ %step = arith.constant 1 : index
+
+ // Ops from the `test` dialect that report the given bounds; only the
+ // unsigned minimum (umin) matters for the unsigned divisions below.
+ // CHECK: %[[VAL0:.*]] = test.with_bounds
+ %0 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+ // CHECK: %[[VAL1:.*]] = test.with_bounds
+ %1 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 1 : i8} : i8
+ // CHECK: %[[VAL2:.*]] = test.with_bounds
+ %2 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 42 : i8} : i8
+ // %1 (umin = 1) and %2 (umin = 42) are provably non-zero, so both divisions
+ // are expected before the loop.
+ // CHECK: = arith.ceildivui %[[ARG0]], %[[VAL1]]
+ // CHECK: = arith.divui %[[ARG0]], %[[VAL2]]
+ // CHECK: scf.for
+ scf.for %idx = %lb to %ub step %step {
+ // %0 has umin = 0, so it may be zero and this division stays in the loop.
+ // CHECK: = arith.divui %[[ARG0]], %[[VAL0]]
+ %3 = arith.divui %arg0, %0 : i8
+ %4 = arith.ceildivui %arg0, %1 : i8
+ %5 = arith.divui %arg0, %2 : i8
+ // %arg0 is a function argument with no defining operation, so no range can
+ // be inferred for it and this division stays in the loop.
+ // CHECK: = arith.divui %[[ARG0]], %[[ARG0]]
+ %6 = arith.divui %arg0, %arg0 : i8
+ }
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/102494
More information about the Mlir-commits
mailing list