[Mlir-commits] [mlir] [MLIR] Let matchers work on int ranges (PR #102494)

Tobias Gysi llvmlistbot at llvm.org
Thu Aug 8 08:47:44 PDT 2024


https://github.com/gysit created https://github.com/llvm/llvm-project/pull/102494

This commit extends the matchers, specifically the m_NonZero matcher, to match not only constants but also operations that implement the InferIntRangeInterface. The matcher can then decide, based on the inferred range, whether the matched value is non-zero.

This extension of the matchers is useful, for example, when hoisting divisions out of a loop: a division can only be hoisted if its divisor is known to be non-zero. After this change, additional divisions can be hoisted, for example, when the divisor is the result of an operation that models the number of threads in a team of threads.

>From c9e3ac0c5c592faedf971a6cf8bb027d4a774ee4 Mon Sep 17 00:00:00 2001
From: Tobias Gysi <tobias.gysi at nextsilicon.com>
Date: Thu, 8 Aug 2024 15:46:18 +0000
Subject: [PATCH] [MLIR] Let matchers work on int ranges

This commit extends the matchers, specifically the m_NonZero matcher,
to match not only constants but also operations that implement the
InferIntRangeInterface. The matcher can then decide, based on the
inferred range, whether the matched value is non-zero.

This extension of the matchers is useful, for example, when hoisting
divisions out of a loop: a division can only be hoisted if its divisor is
known to be non-zero. After this change, additional divisions can be
hoisted, for example, when the divisor is the result of an operation that
models the number of threads in a team of threads.
---
 mlir/include/mlir/IR/Matchers.h               | 69 +++++++++++++++++--
 .../Arith/hoist-speculatable-division.mlir    | 52 ++++++++++++++
 2 files changed, 117 insertions(+), 4 deletions(-)
 create mode 100644 mlir/test/Dialect/Arith/hoist-speculatable-division.mlir

diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h
index f6417f62d09e8c..a9d149c7d225fa 100644
--- a/mlir/include/mlir/IR/Matchers.h
+++ b/mlir/include/mlir/IR/Matchers.h
@@ -18,6 +18,7 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 
 namespace mlir {
 
@@ -100,6 +101,39 @@ struct constant_op_binder {
   }
 };
 
+/// A matcher that matches single-result operations that implement the
+/// `InferIntRangeInterface`, and binds the inferred range of that result.
+struct infer_int_range_op_binder {
+  IntegerValueRange *bind_value;
+
+  explicit infer_int_range_op_binder(IntegerValueRange *bind_value)
+      : bind_value(bind_value) {}
+
+  bool match(Operation *op) {
+    // Only single-result ops are accepted so the bound range is unambiguous.
+    auto inferIntRangeOp = dyn_cast<InferIntRangeInterface>(op);
+    if (!inferIntRangeOp || op->getNumResults() != 1)
+      return false;
+
+    // Set the range of all integer operands to the maximal range.
+    SmallVector<IntegerValueRange> argRanges;
+    argRanges.reserve(op->getNumOperands());
+    for (Value operand : op->getOperands())
+      argRanges.emplace_back(IntegerValueRange::getMaxRange(operand));
+
+    // Infer the result range, ignoring it if it is uninitialized.
+    bool matched = false;
+    auto setResultRange = [&](Value value, const IntegerValueRange &range) {
+      if (range.isUninitialized() || value != op->getResult(0))
+        return;
+      *bind_value = range;
+      matched = true;
+    };
+    inferIntRangeOp.inferResultRangesFromOptional(argRanges, setResultRange);
+    return matched;
+  }
+};
+
 /// The matcher that matches operations that have the specified attribute
 /// name, and binds the attribute value.
 template <typename AttrT>
@@ -219,6 +253,31 @@ struct constant_int_predicate_matcher {
   }
 };
 
+/// Matches a constant scalar / vector splat / tensor splat integer, or an op
+/// result with an inferred integer range, if its range satisfies a predicate.
+struct constant_int_range_predicate_matcher {
+  bool (*predicate)(const ConstantIntRanges &);
+
+  bool match(Attribute attr) {
+    APInt constVal;
+    bool isConst = constant_int_value_binder(&constVal).match(attr);
+    return isConst && predicate(ConstantIntRanges::constant(constVal));
+  }
+
+  bool match(Operation *op) {
+    // A constant is checked as the single-point range [value, value].
+    APInt constVal;
+    if (constant_int_value_binder(&constVal).match(op))
+      return predicate(ConstantIntRanges::constant(constVal));
+
+    // Fall back to the range inferred through `InferIntRangeInterface`.
+    IntegerValueRange inferred;
+    if (!infer_int_range_op_binder(&inferred).match(op))
+      return false;
+    return predicate(inferred.getValue());
+  }
+};
+
 /// The matcher that matches a certain kind of op.
 template <typename OpClass>
 struct op_matcher {
@@ -379,10 +438,12 @@ inline detail::constant_int_predicate_matcher m_Zero() {
   return {[](const APInt &value) { return 0 == value; }};
 }
 
-/// Matches a constant scalar / vector splat / tensor splat integer that is any
-/// non-zero value.
-inline detail::constant_int_predicate_matcher m_NonZero() {
-  return {[](const APInt &value) { return 0 != value; }};
+/// Matches a constant scalar / vector splat / tensor splat integer, or a value
+/// whose inferred range rules out zero via its unsigned or signed bounds.
+inline detail::constant_int_range_predicate_matcher m_NonZero() {
+  return {[](const ConstantIntRanges &range) {
+    return range.umin().ugt(0) || range.smin().sgt(0) || range.smax().slt(0);
+  }};
+}
 
 /// Matches a constant scalar / vector splat / tensor splat integer one.
diff --git a/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
new file mode 100644
index 00000000000000..b711ad0cc07112
--- /dev/null
+++ b/mlir/test/Dialect/Arith/hoist-speculatable-division.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt -loop-invariant-code-motion -split-input-file %s | FileCheck %s
+
+// Verify a division by a non-zero constant is hoisted, but one by zero is not.
+
+// CHECK-LABEL: func @match_non_zero_constant
+func.func @match_non_zero_constant(%arg0: i32) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+  // CHECK: %[[CST0:.*]] = arith.constant 0 : i32
+  %cst0 = arith.constant 0 : i32
+  // CHECK: %[[CST1:.*]] = arith.constant 1 : i32
+  %cst1 = arith.constant 1 : i32
+  // CHECK: = arith.divui %{{.*}}, %[[CST1]]
+  // CHECK: scf.for
+  scf.for %arg2 = %lb to %ub step %step {
+    // CHECK: = arith.divui %{{.*}}, %[[CST0]]
+    %0 = arith.divui %arg0, %cst0 : i32
+    %1 = arith.divui %arg0, %cst1 : i32
+  }
+  return
+}
+
+// -----
+
+// Verify that a division is hoisted out of the loop if the divisor's range,
+// inferred via the InferIntRangeInterface, excludes zero, and that it stays
+// inside the loop if the inferred range may contain zero.
+
+// CHECK-LABEL: func @match_integer_range
+func.func @match_integer_range(%arg0: i8) {
+  %lb = arith.constant 0 : index
+  %ub = arith.constant 10 : index
+  %step = arith.constant 1 : index
+
+  // CHECK: %[[VAL0:.*]] = test.with_bounds {smax = 127 : i8, smin = -128 : i8,
+  %0 = test.with_bounds {smax = 127 : i8, smin = -128 : i8, umax = -1 : i8, umin = 0 : i8} : i8
+  // CHECK: %[[VAL1:.*]] = test.with_bounds {smax = 127 : i8, smin = 1 : i8,
+  %1 = test.with_bounds {smax = 127 : i8, smin = 1 : i8, umax = 127 : i8, umin = 1 : i8} : i8
+  // CHECK: %[[VAL2:.*]] = test.with_bounds {smax = -1 : i8, smin = -128 : i8,
+  %2 = test.with_bounds {smax = -1 : i8, smin = -128 : i8, umax = 255 : i8, umin = 128 : i8} : i8
+  // CHECK: = arith.divui %{{.*}}, %[[VAL1]]
+  // CHECK: = arith.divui %{{.*}}, %[[VAL2]]
+  // CHECK: scf.for
+  scf.for %arg2 = %lb to %ub step %step {
+    // CHECK: = arith.divui %{{.*}}, %[[VAL0]]
+    %3 = arith.divui %arg0, %0 : i8
+    %4 = arith.divui %arg0, %1 : i8
+    %5 = arith.divui %arg0, %2 : i8
+  }
+  return
+}



More information about the Mlir-commits mailing list