[Mlir-commits] [mlir] [MLIR] Let matchers work on int ranges (PR #102494)
Tobias Gysi
llvmlistbot at llvm.org
Thu Aug 8 08:47:44 PDT 2024
https://github.com/gysit created https://github.com/llvm/llvm-project/pull/102494
The commit extends the matchers, specifically the m_NonZero matcher, to match not only constants but also operations that implement the InferIntRangeInterface. The matcher can then decide, based on the inferred range, whether the matched value is non-zero.
This extension of the matchers is useful, for example, when hoisting divisions out of a loop: a division can only be hoisted if its divisor is non-zero. After this change, additional divisions can be hoisted, for example, when the divisor is the result of an operation that models the number of threads in a team of threads.
>From c9e3ac0c5c592faedf971a6cf8bb027d4a774ee4 Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Thu, 8 Aug 2024 15:46:18 +0000
Subject: [PATCH] [MLIR] Let matchers work on int ranges
The commit extends the matchers, specifically the m_NonZero matcher,
to match not only constants but also operations that implement the
InferIntRangeInterface. The matcher can then decide, based on the
inferred range, whether the matched value is non-zero.
This extension of the matchers is useful, for example, when hoisting
divisions out of a loop: a division can only be hoisted if its divisor
is non-zero. After this change, additional divisions can be hoisted,
for example, when the divisor is the result of an operation that models
the number of threads in a team of threads.
---
mlir/include/mlir/IR/Matchers.h | 69 +++++++++++++++++--
.../Arith/hoist-speculatable-division.mlir | 52 ++++++++++++++
2 files changed, 117 insertions(+), 4 deletions(-)
create mode 100644 mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index f6417f62d09e8c..a9d149c7d225fa 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
namespace mlir {
@@ -100,6 +101,39 @@ struct constant_op_binder {
}
};
+/// A matcher that matches operations that implement the
+/// `InferIntRangeInterface`, and binds the inferred range of result #0.
+struct infer_int_range_op_binder {
+  IntegerValueRange *bind_value;
+
+  infer_int_range_op_binder(IntegerValueRange *bind_value)
+      : bind_value(bind_value) {}
+
+  bool match(Operation *op) {
+    auto inferIntRangeOp = dyn_cast<InferIntRangeInterface>(op);
+    if (!inferIntRangeOp)
+      return false;
+
+    // Set the range of all integer operands to the maximal range.
+    SmallVector<IntegerValueRange> argRanges;
+    argRanges.reserve(op->getNumOperands());
+    for (Value operand : op->getOperands())
+      argRanges.emplace_back(IntegerValueRange::getMaxRange(operand));
+
+    // Infer the result ranges if possible; only result #0 gets bound.
+    bool matched = false;
+    auto setResultRanges = [&](Value value,
+                               const IntegerValueRange &resultRange) {
+      if (resultRange.isUninitialized() || value != op->getResult(0))
+        return;
+      *bind_value = resultRange;
+      matched = true;
+    };
+    inferIntRangeOp.inferResultRangesFromOptional(argRanges, setResultRanges);
+    return matched;
+  }
+};
+
/// The matcher that matches operations that have the specified attribute
/// name, and binds the attribute value.
template <typename AttrT>
@@ -219,6 +253,31 @@ struct constant_int_predicate_matcher {
}
};
+/// A matcher that matches a target constant scalar / vector splat / tensor
+/// splat integer value or an inferred integer range fulfilling a predicate.
+struct constant_int_range_predicate_matcher {
+  bool (*predicate)(const ConstantIntRanges &);
+
+  bool match(Attribute attr) {
+    APInt value;
+    return constant_int_value_binder(&value).match(attr) &&
+           predicate(ConstantIntRanges::constant(value));
+  }
+
+  bool match(Operation *op) {
+    // Try to match a constant integer value first.
+    APInt value;
+    if (constant_int_value_binder(&value).match(op))
+      return predicate(ConstantIntRanges::constant(value));
+
+    // Otherwise, use the range inferred via `InferIntRangeInterface`, with
+    // all operands treated as unbounded.
+    IntegerValueRange range;
+    return infer_int_range_op_binder(&range).match(op) &&
+           predicate(range.getValue());
+  }
+};
+
/// The matcher that matches a certain kind of op.
template <typename OpClass>
struct op_matcher {
@@ -379,10 +438,12 @@ inline detail::constant_int_predicate_matcher m_Zero() {
return {[](const APInt &value) { return 0 == value; }};
}
-/// Matches a constant scalar / vector splat / tensor splat integer that is any
-/// non-zero value.
-inline detail::constant_int_predicate_matcher m_NonZero() {
- return {[](const APInt &value) { return 0 != value; }};
+/// Matches a constant scalar / vector splat / tensor splat integer or an
+/// integer range whose unsigned or signed bounds exclude zero.
+inline detail::constant_int_range_predicate_matcher m_NonZero() {
+  return {[](const ConstantIntRanges &range) {
+    return range.umin().ugt(0) || range.smin().sgt(0) || range.smax().slt(0);
+  }};
}
/// Matches a constant scalar / vector splat / tensor splat integer one.
diff --git a/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
new file mode 100644
index 00000000000000..b711ad0cc07112
--- /dev/null
+++ b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt -loop-invariant-code-motion -split-input-file %s | FileCheck %s
+
+// Verify a division by a non-zero constant is hoisted out of the loop.
+
+// CHECK-LABEL: func @match_non_zero_constant
+func.func @match_non_zero_constant(%arg0: i32) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : i32
+  %cst0 = arith.constant 0 : i32
+  // CHECK: %[[CST1:.*]] = arith.constant 1 : i32
+  %cst1 = arith.constant 1 : i32
+  // CHECK: = arith.divui %{{.*}}, %[[CST1]]
+  // CHECK: scf.for
+  scf.for %arg2 = %lb to %ub step %step {
+    // CHECK: = arith.divui %{{.*}}, %[[CST0]]
+    %0 = arith.divui %arg0, %cst0 : i32
+    %1 = arith.divui %arg0, %cst1 : i32
+  }
+  return
+}
+
+// -----
+
+// Verify a division by a non-zero integer whose range is known due to the
+// InferIntRangeInterface is hoisted out of the loop.
+
+// CHECK-LABEL: func @match_integer_range
+func.func @match_integer_range(%arg0: i8) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+
+  // CHECK: %[[VAL0:.*]] = test.with_bounds {smax = 127 : i8, smin = -128 : i8,
+  %0 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  // CHECK: %[[VAL1:.*]] = test.with_bounds {smax = 127 : i8, smin = 1 : i8,
+  %1 = test.with_bounds {smax = 127 : i8, smin = 1 : i8, umax = 127 : i8, umin = 1 : i8} : i8
+  // CHECK: %[[VAL2:.*]] = test.with_bounds {smax = -1 : i8, smin = -128 : i8,
+  %2 = test.with_bounds {smax = -1 : i8, smin = -128 : i8, umax = 255 : i8, umin = 128 : i8} : i8
+  // CHECK: = arith.divui %{{.*}}, %[[VAL1]]
+  // CHECK: = arith.divui %{{.*}}, %[[VAL2]]
+  // CHECK: scf.for
+  scf.for %arg2 = %lb to %ub step %step {
+    // Only the range of %0 contains zero, so only its division stays here.
+    // CHECK: = arith.divui %{{.*}}, %[[VAL0]]
+    %3 = arith.divui %arg0, %0 : i8
+    %4 = arith.divui %arg0, %1 : i8
+    %5 = arith.divui %arg0, %2 : i8
+  }
+  return
+}
More information about the Mlir-commits
mailing list