[Mlir-commits] [mlir] [MLIR] Let matchers work on int ranges (PR #102494)
Tobias Gysi
llvmlistbot at llvm.org
Fri Aug 9 00:57:31 PDT 2024
https://github.com/gysit updated https://github.com/llvm/llvm-project/pull/102494
>From a746ed759d34aad28677335020443ab0ccb7c37a Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Thu, 8 Aug 2024 15:46:18 +0000
Subject: [PATCH] [MLIR] Use InferIntRangeInterface in matchers
This commit extends the matchers with a new m_NonZeroUI matcher that
matches not only constants but also operations that implement the
InferIntRangeInterface. It then decides, based on the inferred unsigned
range, whether the matched value is non-zero.
This extension of the matchers is useful, for example, when hoisting
divisions out of a loop: a division can only be hoisted if its divisor
is known to be non-zero. After this change, additional divisions can be
hoisted, for example when the divisor is the result of an operation that
models the number of threads in a team of threads.
---
mlir/include/mlir/IR/Matchers.h | 65 +++++++++++++++++++
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 4 +-
.../Arith/hoist-speculatable-division.mlir | 54 +++++++++++++++
3 files changed, 121 insertions(+), 2 deletions(-)
create mode 100644 mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index f6417f62d09e8c..3799a32d957387 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
namespace mlir {
@@ -100,6 +101,39 @@ struct constant_op_binder {
}
};
+/// A matcher that matches single-result operations that implement the
+/// `InferIntRangeInterface`, and binds the range inferred for that result.
+struct infer_int_range_op_binder {
+  IntegerValueRange *bind_value;
+
+  infer_int_range_op_binder(IntegerValueRange *bind_value)
+      : bind_value(bind_value) {}
+
+  bool match(Operation *op) {
+    // `match` only receives the op, not which of its results is being matched
+    // (e.g. by `matchPattern` on a Value). Reject multi-result ops instead of
+    // risking binding the range of result #0 to a different result.
+    auto inferIntRangeOp = dyn_cast<InferIntRangeInterface>(op);
+    if (!inferIntRangeOp || op->getNumResults() != 1)
+      return false;
+
+    // Assume nothing about the operands: give each one the maximal range.
+    SmallVector<IntegerValueRange> argRanges;
+    argRanges.reserve(op->getNumOperands());
+    for (Value operand : op->getOperands())
+      argRanges.emplace_back(IntegerValueRange::getMaxRange(operand));
+
+    // Infer the result range if possible. An uninitialized range does not
+    // count as a match and leaves `bind_value` untouched.
+    bool matched = false;
+    auto setResultRanges = [&](Value value,
+                               const IntegerValueRange &resultRange) {
+      if (resultRange.isUninitialized() || value != op->getResult(0))
+        return;
+      *bind_value = resultRange;
+      matched = true;
+    };
+    inferIntRangeOp.inferResultRangesFromOptional(argRanges, setResultRanges);
+    return matched;
+  }
+};
+
/// The matcher that matches operations that have the specified attribute
/// name, and binds the attribute value.
template <typename AttrT>
@@ -219,6 +253,31 @@ struct constant_int_predicate_matcher {
}
};
+/// Matches a constant integer (scalar, vector splat, or tensor splat) or an op
+/// whose integer result range can be inferred (see `infer_int_range_op_binder`),
+/// and succeeds when `predicate` holds for that range. A matched constant is
+/// wrapped with `ConstantIntRanges::constant` before the predicate is applied.
+struct constant_int_range_predicate_matcher {
+  bool (*predicate)(const ConstantIntRanges &);
+
+  bool match(Attribute attr) {
+    APInt constValue;
+    if (!constant_int_value_binder(&constValue).match(attr))
+      return false;
+    return predicate(ConstantIntRanges::constant(constValue));
+  }
+
+  bool match(Operation *op) {
+    // Constants are checked first.
+    APInt constValue;
+    if (constant_int_value_binder(&constValue).match(op))
+      return predicate(ConstantIntRanges::constant(constValue));
+
+    // Otherwise, fall back to the range the op itself infers through
+    // `InferIntRangeInterface`.
+    IntegerValueRange inferredRange;
+    if (!infer_int_range_op_binder(&inferredRange).match(op))
+      return false;
+    return predicate(inferredRange.getValue());
+  }
+};
+
/// The matcher that matches a certain kind of op.
template <typename OpClass>
struct op_matcher {
@@ -385,6 +444,12 @@ inline detail::constant_int_predicate_matcher m_NonZero() {
return {[](const APInt &value) { return 0 != value; }};
}
+/// Matches a value that is provably non-zero when interpreted as unsigned:
+/// either a non-zero constant scalar / vector splat / tensor splat integer, or
+/// the result of an op whose inferred unsigned integer range excludes zero.
+inline detail::constant_int_range_predicate_matcher m_NonZeroUI() {
+  // The range excludes zero exactly when its unsigned minimum is not zero.
+  auto excludesZero = [](const ConstantIntRanges &range) {
+    return !range.umin().isZero();
+  };
+  return {excludesZero};
+}
+
/// Matches a constant scalar / vector splat / tensor splat integer one.
inline detail::constant_int_predicate_matcher m_One() {
return {[](const APInt &value) { return 1 == value; }};
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 641b7d7e2d13be..4614c1f84c066f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -597,8 +597,8 @@ OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
// X / 0 => UB
- return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
- : Speculation::NotSpeculatable;
+  // Speculatable when the divisor is a non-zero constant or an op result whose
+  // inferred unsigned range excludes zero (see m_NonZeroUI).
+ return matchPattern(getRhs(), m_NonZeroUI()) ? Speculation::Speculatable
+ : Speculation::NotSpeculatable;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
new file mode 100644
index 00000000000000..acca5813655658
--- /dev/null
+++ b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt -loop-invariant-code-motion -split-input-file %s | FileCheck %s
+
+// Verify a division by a non-zero constant is hoisted out of the loop.
+// A division by the constant zero is not speculatable and stays in the loop.
+
+// CHECK-LABEL: func @match_non_zero_constant
+func.func @match_non_zero_constant(%arg0: i32) {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 10 : index
+ %step = arith.constant 1 : index
+ // CHECK: %[[CST0:.*]] = arith.constant 0 : i32
+ %cst0 = arith.constant 0 : i32
+ // CHECK: %[[CST1:.*]] = arith.constant 1 : i32
+ %cst1 = arith.constant 1 : i32
+ // CHECK: = arith.divui %{{.*}}, %[[CST1]]
+ // CHECK: scf.for
+ scf.for %arg2= %lb to %ub step %step {
+ // Divisor is the constant 0: expected to remain inside the loop body.
+ // CHECK: = arith.divui %{{.*}}, %[[CST0]]
+ %0 = arith.divui %arg0, %cst0 : i32
+ // Divisor is the constant 1: expected to be hoisted before the loop.
+ %1 = arith.divui %arg0, %cst1 : i32
+ }
+ return
+}
+
+// -----
+
+// Verify a division by a non-zero integer whose range is known due to the
+// InferIntRangeInterface is hoisted out of the loop.
+
+// CHECK-LABEL: func @match_integer_range
+// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
+func.func @match_integer_range(%arg0: i8) {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 10 : index
+ %step = arith.constant 1 : index
+
+ // Unsigned range [0, 255] includes zero: a division by it may not be hoisted.
+ // CHECK: %[[VAL0:.*]] = test.with_bounds
+ %0 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+ // Unsigned ranges [1, 255] and [42, 255] exclude zero: divisions are hoisted.
+ // CHECK: %[[VAL1:.*]] = test.with_bounds
+ %1 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 1 : i8} : i8
+ // CHECK: %[[VAL2:.*]] = test.with_bounds
+ %2 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 42 : i8} : i8
+ // CHECK: = arith.divui %[[ARG0]], %[[VAL1]]
+ // CHECK: = arith.divui %[[ARG0]], %[[VAL2]]
+ // CHECK: scf.for
+ scf.for %arg2= %lb to %ub step %step {
+ // CHECK: = arith.divui %[[ARG0]], %[[VAL0]]
+ %3 = arith.divui %arg0, %0 : i8
+ %4 = arith.divui %arg0, %1 : i8
+ %5 = arith.divui %arg0, %2 : i8
+ // Divisor is a function argument with no known range: expected to stay in
+ // the loop body.
+ // CHECK: = arith.divui %[[ARG0]], %[[ARG0]]
+ %6 = arith.divui %arg0, %arg0 : i8
+ }
+ return
+}
More information about the Mlir-commits
mailing list