[Mlir-commits] [mlir] [MLIR] Let matchers work on int ranges (PR #102494)

Tobias Gysi llvmlistbot at llvm.org
Fri Aug 9 00:57:31 PDT 2024


https://github.com/gysit updated https://github.com/llvm/llvm-project/pull/102494

>From a746ed759d34aad28677335020443ab0ccb7c37a Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Thu, 8 Aug 2024 15:46:18 +0000
Subject: [PATCH] [MLIR] Use InferIntRangeInterface in matchers

The commit extends the matchers with a new m_NonZeroUI matcher that
matches not only constants but also operations that implement the
InferIntRangeInterface. It then decides, based on the inferred unsigned
range, whether the matched value is non-zero. The speculatability check
of arith.divui now uses this matcher instead of m_NonZero.

This extension is useful, for example, when hoisting divisions out of a
loop: a division can only be hoisted if its divisor is known to be
non-zero. After this change, additional divisions can be hoisted, for
example, when the divisor is the result of an operation that models the
number of threads in a team of threads.
---
 mlir/include/mlir/IR/Matchers.h               | 65 +++++++++++++++++++
 mlir/lib/Dialect/Arith/IR/ArithOps.cpp        |  4 +-
 .../Arith/hoist-speculatable-division.mlir    | 54 +++++++++++++++
 3 files changed, 121 insertions(+), 2 deletions(-)
 create mode 100644 mlir/test/Dialect/Arith/hoist-speculatable-division.mlir

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index f6417f62d09e8c..3799a32d957387 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -18,6 +18,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 
 namespace mlir {
 
@@ -100,6 +101,39 @@ struct constant_op_binder {
   }
 };
 
+/// A matcher that matches single-result operations that implement the
+/// `InferIntRangeInterface`, and binds the range inferred for that result.
+struct infer_int_range_op_binder {
+  IntegerValueRange *bind_value;
+
+  infer_int_range_op_binder(IntegerValueRange *bind_value)
+      : bind_value(bind_value) {}
+
+  bool match(Operation *op) {
+    // Multi-result ops are rejected, as the matched value may not be result 0.
+    auto inferIntRangeOp = dyn_cast<InferIntRangeInterface>(op);
+    if (!inferIntRangeOp || op->getNumResults() != 1)
+      return false;
+
+    // Set the range of all integer operands to the maximal range.
+    SmallVector<IntegerValueRange> argRanges;
+    argRanges.reserve(op->getNumOperands());
+    for (Value operand : op->getOperands())
+      argRanges.emplace_back(IntegerValueRange::getMaxRange(operand));
+    // Infer the range of the single result, if possible.
+    bool matched = false;
+    auto setResultRanges = [&](Value value,
+                               const IntegerValueRange &resultRange) {
+      if (resultRange.isUninitialized() || value != op->getResult(0))
+        return;
+      *bind_value = resultRange;
+      matched = true;
+    };
+    inferIntRangeOp.inferResultRangesFromOptional(argRanges, setResultRanges);
+    return matched;
+  }
+};
+
 /// The matcher that matches operations that have the specified attribute
 /// name, and binds the attribute value.
 template <typename AttrT>
@@ -219,6 +253,31 @@ struct constant_int_predicate_matcher {
   }
 };
 
+/// Matches a constant scalar / vector splat / tensor splat integer, or an op
+/// implementing `InferIntRangeInterface` with an inferred result range, if
+/// the constant or the inferred range fulfills a predicate.
+struct constant_int_range_predicate_matcher {
+  bool (*predicate)(const ConstantIntRanges &);
+
+  bool match(Attribute attr) {
+    APInt value;
+    return constant_int_value_binder(&value).match(attr) &&
+           predicate(ConstantIntRanges::constant(value));
+  }
+
+  bool match(Operation *op) {
+    // Try to match a constant integer value first.
+    APInt value;
+    if (constant_int_value_binder(&value).match(op))
+      return predicate(ConstantIntRanges::constant(value));
+
+    // Otherwise, try an op that implements `InferIntRangeInterface`.
+    IntegerValueRange range;
+    return infer_int_range_op_binder(&range).match(op) &&
+           predicate(range.getValue());
+  }
+};
+
 /// The matcher that matches a certain kind of op.
 template <typename OpClass>
 struct op_matcher {
@@ -385,6 +444,12 @@ inline detail::constant_int_predicate_matcher m_NonZero() {
   return {[](const APInt &value) { return 0 != value; }};
 }
 
+/// Matches a constant scalar / vector splat / tensor splat integer, or a value
+/// whose unsigned range inferred via `InferIntRangeInterface` excludes zero.
+inline detail::constant_int_range_predicate_matcher m_NonZeroUI() {
+  return {[](const ConstantIntRanges &cr) { return !cr.umin().isZero(); }};
+}
+
 /// Matches a constant scalar / vector splat / tensor splat integer one.
 inline detail::constant_int_predicate_matcher m_One() {
   return {[](const APInt &value) { return 1 == value; }};
diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
index 641b7d7e2d13be..4614c1f84c066f 100644
--- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
+++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp
@@ -597,8 +597,8 @@ OpFoldResult arith::DivUIOp::fold(FoldAdaptor adaptor) {
 
 Speculation::Speculatability arith::DivUIOp::getSpeculatability() {
   // X / 0 => UB
-  return matchPattern(getRhs(), m_NonZero()) ? Speculation::Speculatable
-                                             : Speculation::NotSpeculatable;
+  bool nonZeroRhs = matchPattern(getRhs(), m_NonZeroUI());
+  return nonZeroRhs ? Speculation::Speculatable : Speculation::NotSpeculatable;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
new file mode 100644
index 00000000000000..acca5813655658
--- /dev/null
+++ b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
@@ -0,0 +1,54 @@
+// RUN: mlir-opt -loop-invariant-code-motion -split-input-file %s | FileCheck %s
+
+// Check that only the division by the non-zero constant is hoisted.
+
+// CHECK-LABEL: func @match_non_zero_constant
+func.func @match_non_zero_constant(%arg0: i32) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : i32
+  %cst0 = arith.constant 0 : i32
+  // CHECK: %[[CST1:.*]] = arith.constant 1 : i32
+  %cst1 = arith.constant 1 : i32
+  // CHECK: = arith.divui %{{.*}}, %[[CST1]]
+  // CHECK: scf.for
+  scf.for %iv = %lb to %ub step %step {
+    // CHECK: = arith.divui %{{.*}}, %[[CST0]]
+    %div_by_zero = arith.divui %arg0, %cst0 : i32
+    %div_by_one = arith.divui %arg0, %cst1 : i32
+  }
+  return
+}
+
+// -----
+
+// Check that divisions by values whose InferIntRangeInterface range excludes
+// zero are hoisted, while divisions by possibly-zero values stay in the loop.
+
+// CHECK-LABEL: func @match_integer_range
+// CHECK-SAME: %[[ARG0:[[:alnum:]]+]]
+func.func @match_integer_range(%arg0: i8) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+
+  // CHECK: %[[VAL0:.*]] = test.with_bounds
+  %maybe_zero = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 0 : i8} : i8
+  // CHECK: %[[VAL1:.*]] = test.with_bounds
+  %at_least_1 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 1 : i8} : i8
+  // CHECK: %[[VAL2:.*]] = test.with_bounds
+  %at_least_42 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = 255 : i8, umin = 42 : i8} : i8
+  // CHECK: = arith.divui %[[ARG0]], %[[VAL1]]
+  // CHECK: = arith.divui %[[ARG0]], %[[VAL2]]
+  // CHECK: scf.for
+  scf.for %iv = %lb to %ub step %step {
+    // CHECK: = arith.divui %[[ARG0]], %[[VAL0]]
+    %q0 = arith.divui %arg0, %maybe_zero : i8
+    %q1 = arith.divui %arg0, %at_least_1 : i8
+    %q2 = arith.divui %arg0, %at_least_42 : i8
+    // CHECK: = arith.divui %[[ARG0]], %[[ARG0]]
+    %q3 = arith.divui %arg0, %arg0 : i8
+  }
+  return
+}



More information about the Mlir-commits mailing list