[Mlir-commits] [mlir] [MLIR] Let matchers work on int ranges (PR #102494)
Tobias Gysi
llvmlistbot at llvm.org
Fri Aug 9 02:02:06 PDT 2024
https://github.com/gysit updated https://github.com/llvm/llvm-project/pull/102494
>From a4287d60b7577787998ab8adc6056ccfe46d66d3 Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Thu, 8 Aug 2024 15:46:18 +0000
Subject: [PATCH] [MLIR] Use InferIntRangeInterface in matchers
This commit adds an m_NonZeroU matcher that, unlike the m_NonZero
matcher, matches not only constants but also operations that
implement the InferIntRangeInterface. It can then match a non-zero value
based on the inferred unsigned range. Additionally, the commit uses the new
matcher in the getSpeculatability functions of Arith's unsigned integer
division operations (divui and ceildivui). At the moment, the matcher only
looks at the defining operation to avoid expensive IR walks.
These range-based matchers can be useful when hoisting divisions out of
a loop, which requires knowing that the divisor is non-zero. Just checking
for a constant divisor may not be sufficient if the divisor is,
for example, the result of an operation that returns the number of
threads of a team of threads.
---
mlir/include/mlir/IR/Matchers.h | 68 +++++++++++++++++++
mlir/lib/Dialect/Arith/IR/ArithOps.cpp | 8 +--
.../Arith/hoist-speculatable-division.mlir | 54 +++++++++++++++
3 files changed, 126 insertions(+), 4 deletions(-)
create mode 100644 mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index f6417f62d09e8c..32b15265ca2c83 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
namespace mlir {
@@ -100,6 +101,41 @@ struct constant_op_binder {
}
};
+/// A matcher that matches single-result operations that implement the
+/// `InferIntRangeInterface` interface, and binds the inferred result range.
+struct infer_int_range_op_binder {
+  IntegerValueRange *bind_value;
+
+  explicit infer_int_range_op_binder(IntegerValueRange *bind_value)
+      : bind_value(bind_value) {}
+
+  bool match(Operation *op) {
+    // Matchers only see the matched value's defining op, so only a
+    // single-result op identifies unambiguously which range is asked for.
+    auto inferIntRangeOp = dyn_cast<InferIntRangeInterface>(op);
+    if (!inferIntRangeOp || op->getNumResults() != 1)
+      return false;
+
+    // Operands are assumed to take any value of their type (maximal range).
+    SmallVector<IntegerValueRange> argRanges;
+    argRanges.reserve(op->getNumOperands());
+    for (Value operand : op->getOperands())
+      argRanges.emplace_back(IntegerValueRange::getMaxRange(operand));
+
+    // Infer the result range if possible, and bind it on success.
+    Value result = op->getResult(0);
+    bool matched = false;
+    auto setResultRanges = [&](Value value, const IntegerValueRange &range) {
+      if (value != result || range.isUninitialized())
+        return;
+      *bind_value = range;
+      matched = true;
+    };
+    inferIntRangeOp.inferResultRangesFromOptional(argRanges, setResultRanges);
+    return matched;
+  }
+};
+
/// The matcher that matches operations that have the specified attribute
/// name, and binds the attribute value.
template <typename AttrT>
@@ -219,6 +255,31 @@ struct constant_int_predicate_matcher {
}
};
+/// A matcher that matches a constant scalar / vector splat / tensor splat
+/// integer value, or an inferred integer range, that satisfies a predicate.
+struct constant_int_range_predicate_matcher {
+  bool (*predicate)(const ConstantIntRanges &);
+
+  bool match(Attribute attr) {
+    APInt constValue;
+    if (!constant_int_value_binder(&constValue).match(attr))
+      return false;
+    return predicate(ConstantIntRanges::constant(constValue));
+  }
+
+  bool match(Operation *op) {
+    APInt constValue;
+    if (constant_int_value_binder(&constValue).match(op))
+      return predicate(ConstantIntRanges::constant(constValue));
+
+    // Not a constant: use the range inferred via `InferIntRangeInterface`.
+    IntegerValueRange inferred;
+    if (!infer_int_range_op_binder(&inferred).match(op))
+      return false;
+    return predicate(inferred.getValue());
+  }
+};
+
/// The matcher that matches a certain kind of op.
template <typename OpClass>
struct op_matcher {
@@ -385,6 +446,13 @@ inline detail::constant_int_predicate_matcher m_NonZero() {
return {[](const APInt &value) { return 0 != value; }};
}
+/// Matches a constant scalar / vector splat / tensor splat integer, or a value
+/// whose inferred unsigned range excludes zero (i.e. its `umin` is non-zero).
+/// The matched value is interpreted as an unsigned integer.
+inline detail::constant_int_range_predicate_matcher m_NonZeroU() {
+  return {[](const ConstantIntRanges &range) { return range.umin() != 0; }};
+}
+
/// Matches a constant scalar / vector splat / tensor splat integer one.
inline detail::constant_int_predicate_matcher m_One() {
return {[](const APInt &value) { return 1 == value; }};
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 641b7d7e2d13be..31f09d54f9759f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -597,8 +597,8 @@ OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
// X / 0 => UB
- return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
- : Speculation::NotSpeculatable;
+  bool rhsNonZero = matchPattern(getRhs(), m_NonZeroU());
+  return rhsNonZero ? Speculation::Speculatable : Speculation::NotSpeculatable;
}
//===----------------------------------------------------------------------===//
@@ -676,8 +676,8 @@ OpFoldResult arith::CeilDivUIOp::fold(FoldAdaptor adaptor) {
Speculation::Speculatability arith::CeilDivUIOp::getSpeculatability() {
// X / 0 => UB
- return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
- : Speculation::NotSpeculatable;
+  bool rhsNonZero = matchPattern(getRhs(), m_NonZeroU());
+  return rhsNonZero ? Speculation::Speculatable : Speculation::NotSpeculatable;
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
new file mode 100644
index 00000000000000..b4bd525b502562
--- /dev/null
+++ b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt -loop-invariant-code-motion -split-input-file %s | FileCheck %s
+
+// Verify divisions by a non-zero constant are hoisted, but not divisions by zero.
+
+// CHECK-LABEL: func @match_non_zero_constant
+func.func @match_non_zero_constant(%arg0: i32) {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 10 : index
+ %step = arith.constant 1 : index
+ // CHECK: %[[CST0:.*]] = arith.constant 0 : i32
+ %cst0 = arith.constant 0 : i32
+ // CHECK: %[[CST1:.*]] = arith.constant 1 : i32
+ %cst1 = arith.constant 1 : i32
+ // CHECK: = arith.divui %{{.*}}, %[[CST1]]
+ // CHECK: scf.for
+ scf.for %idx = %lb to %ub step %step {
+ // CHECK: = arith.divui %{{.*}}, %[[CST0]]
+ %0 = arith.divui %arg0, %cst0 : i32 // Divisor is zero: must stay in the loop.
+ %1 = arith.divui %arg0, %cst1 : i32 // Divisor is a non-zero constant: hoisted.
+ }
+ return
+}
+
+// -----
+
+// Verify a division whose divisor has a non-zero unsigned range inferred via
+// InferIntRangeInterface is hoisted, while the other divisions stay in the loop.
+
+// CHECK-LABEL: func @match_integer_range
+// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
+func.func @match_integer_range(%arg0: i8) {
+ %lb = arith.constant 0 : index
+ %ub = arith.constant 10 : index
+ %step = arith.constant 1 : index
+
+ // CHECK: %[[VAL0:.*]] = test.with_bounds
+ %0 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+ // CHECK: %[[VAL1:.*]] = test.with_bounds
+ %1 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 1 : i8} : i8
+ // CHECK: %[[VAL2:.*]] = test.with_bounds
+ %2 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 42 : i8} : i8
+ // CHECK: = arith.ceildivui %[[ARG0]], %[[VAL1]]
+ // CHECK: = arith.divui %[[ARG0]], %[[VAL2]]
+ // CHECK: scf.for
+ scf.for %idx = %lb to %ub step %step {
+ // CHECK: = arith.divui %[[ARG0]], %[[VAL0]]
+ %3 = arith.divui %arg0, %0 : i8 // umin = 0: divisor may be zero, not hoisted.
+ %4 = arith.ceildivui %arg0, %1 : i8 // umin = 1: hoisted.
+ %5 = arith.divui %arg0, %2 : i8 // umin = 42: hoisted.
+ // CHECK: = arith.divui %[[ARG0]], %[[ARG0]]
+ %6 = arith.divui %arg0, %arg0 : i8 // Block argument: no defining op to infer a range from.
+ }
+ return
+}
More information about the Mlir-commits
mailing list