[Mlir-commits] [mlir] bea77ed - [mlir][Vector] Fold vector.step compared to constant (#161615)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Oct 15 08:40:35 PDT 2025


Author: James Newling
Date: 2025-10-15T08:40:30-07:00
New Revision: bea77ed52e2d714a9f1a836673733dc5f44a29e3

URL: https://github.com/llvm/llvm-project/commit/bea77ed52e2d714a9f1a836673733dc5f44a29e3
DIFF: https://github.com/llvm/llvm-project/commit/bea77ed52e2d714a9f1a836673733dc5f44a29e3.diff

LOG: [mlir][Vector] Fold vector.step compared to constant (#161615)

This PR adds a canonicalizer to vector.step that folds an arith.cmpi of
vector.step against a constant iff the result of the fold is a splat
value. An alternative would be to always constant fold vector.step, but
that might result in some very large/cumbersome constants.

I do wonder if vector.step might be better represented as some sort of
attribute in the arith dialect, like %step = arith.constant iota<32> :
vector<32xindex>.

---------

Signed-off-by: James Newling <james.newling at gmail.com>

Added: 
    mlir/test/Dialect/Vector/canonicalize/vector-step.mlir

Modified: 
    mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
    mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 6e79085afac9f..6e15b1e7df606 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2999,6 +2999,7 @@ def Vector_StepOp : Vector_Op<"step", [
   }];
   let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
   let assemblyFormat = "attr-dict `:` type($result)";
+  let hasCanonicalizer = 1;
 }
 
 def Vector_YieldOp : Vector_Op<"yield", [

diff  --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 58256b0ade9f6..45c54c7587c69 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7601,6 +7601,111 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), result);
 }
 
+namespace {
+
+/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
+/// constant large enough such that the result is the same at all indices.
+///
+/// For example, rewrite the 'greater than' comparison below,
+///
+/// ```mlir
+/// %cst = arith.constant dense<7> : vector<3xindex>
+/// %stp = vector.step : vector<3xindex>
+/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
+/// ```
+///
+/// as,
+///
+/// ```mlir
+/// %out = arith.constant dense<false> : vector<3xi1>
+/// ```
+///
+/// Above `[0, 1, 2] > [7, 7, 7]` => `[false, false, false]`. Because the result
+/// is false at ALL indices we fold. If the constant was 1, then
+/// `[0, 1, 2] > [1, 1, 1]` => `[false, false, true]` and we do not fold,
+/// conservatively preferring the 'compact' vector.step representation.
+///
+/// Note: this folder only works for the case where the constant (`%cst` above)
+/// is the second operand of the comparison. The arith.cmpi canonicalizer will
+/// ensure that constants are always second (on the right).
+struct StepCompareFolder : public OpRewritePattern<StepOp> {
+  using Base::Base;
+
+  LogicalResult matchAndRewrite(StepOp stepOp,
+                                PatternRewriter &rewriter) const override {
+    const int64_t stepSize = stepOp.getResult().getType().getNumElements();
+
+    for (OpOperand &use : stepOp.getResult().getUses()) {
+      auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
+      if (!cmpiOp)
+        continue;
+
+      // Require step as operand 0; cmpi canonicalization puts constants last.
+      const unsigned stepOperandNumber = use.getOperandNumber();
+      if (stepOperandNumber != 0)
+        continue;
+
+      // Check that operand 1 is a constant; skip this use otherwise.
+      unsigned constOperandNumber = 1;
+      Value otherOperand = cmpiOp.getOperand(constOperandNumber);
+      std::optional<int64_t> maybeConstValue =
+          getConstantIntValue(otherOperand);
+      if (!maybeConstValue.has_value())
+        continue;
+
+      int64_t constValue = maybeConstValue.value();
+      arith::CmpIPredicate pred = cmpiOp.getPredicate();
+
+      auto maybeSplat = [&]() -> std::optional<bool> {
+        // ult/uge: all step elements (at most stepSize - 1) are < constValue.
+        if ((pred == arith::CmpIPredicate::ult ||
+             pred == arith::CmpIPredicate::uge) &&
+            stepSize <= constValue)
+          return pred == arith::CmpIPredicate::ult;
+
+        // ule/ugt: all step elements (at most stepSize - 1) are <= constValue.
+        if ((pred == arith::CmpIPredicate::ule ||
+             pred == arith::CmpIPredicate::ugt) &&
+            stepSize - 1 <= constValue) {
+          return pred == arith::CmpIPredicate::ule;
+        }
+
+        // eq/ne: constValue exceeds every step element, so none can be equal.
+        if ((pred == arith::CmpIPredicate::eq ||
+             pred == arith::CmpIPredicate::ne) &&
+            stepSize <= constValue)
+          return pred == arith::CmpIPredicate::ne;
+
+        return std::nullopt;
+      }();
+
+      if (!maybeSplat.has_value())
+        continue;
+
+      rewriter.setInsertionPointAfter(cmpiOp);
+
+      auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
+      if (!type)
+        continue;
+
+      auto boolAttr = DenseElementsAttr::get(type, maybeSplat.value());
+      Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
+                                                    type, boolAttr);
+
+      rewriter.replaceOp(cmpiOp, splat);
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
+void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<StepCompareFolder>(context); // Splat-fold of step vs. constant.
+}
+
 //===----------------------------------------------------------------------===//
 // Vector Masking Utilities
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
new file mode 100644
index 0000000000000..023a0e52b65dc
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
@@ -0,0 +1,311 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
+
+///===----------------------------------------------===//
+///  Tests of `StepCompareFolder`
+///===----------------------------------------------===//
+
+
+///===------------------------------------===//
+///  Tests of `ugt` (unsigned greater than)
+///===------------------------------------===//
+
+// CHECK-LABEL: @ugt_constant_3_lhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @ugt_constant_3_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 3 > [0, 1, 2] => [true, true, true] => true for all indices => fold
+  %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_ugt_constant_2_lhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @negative_ugt_constant_2_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 2 > [0, 1, 2] => [true, true, false] => not same for all indices => don't fold
+  %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @ugt_constant_3_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @ugt_constant_3_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] > 3 => [false, false, false] => false for all indices => fold
+  %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @ugt_constant_max_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @ugt_constant_max_rhs() -> vector<3xi1> {
+  // The largest (signed) i64 possible; [0, 1, 2] > max => false for all indices => fold
+  %cst = arith.constant dense<0x7fffffffffffffff> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ugt, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+
+// -----
+
+// CHECK-LABEL: @ugt_constant_2_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @ugt_constant_2_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] > 2 => [false, false, false] => false for all indices => fold
+  %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_ugt_constant_1_rhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @negative_ugt_constant_1_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] > 1 => [false, false, true] => not same for all indices => don't fold
+  %1 = arith.cmpi ugt, %0, %cst: vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+///  Tests of `uge` (unsigned greater than or equal)
+///===------------------------------------===//
+
+
+// CHECK-LABEL: @uge_constant_2_lhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @uge_constant_2_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 2 >= [0, 1, 2] => [true, true, true] => true for all indices => fold
+  %1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_uge_constant_1_lhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @negative_uge_constant_1_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 1 >= [0, 1, 2] => [true, false, false] => not same for all indices => don't fold
+  %1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @uge_constant_3_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @uge_constant_3_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] >= 3 => [false, false, false] => false for all indices => fold
+  %1 = arith.cmpi uge, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_uge_constant_2_rhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @negative_uge_constant_2_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] >= 2 => [false, false, true] => not same for all indices => don't fold
+  %1 = arith.cmpi uge, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+
+///===------------------------------------===//
+///  Tests of `ult` (unsigned less than)
+///===------------------------------------===//
+
+
+// CHECK-LABEL: @ult_constant_2_lhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @ult_constant_2_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 2 < [0, 1, 2] => [false, false, false] => false for all indices => fold
+  %1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_ult_constant_1_lhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @negative_ult_constant_1_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // 1 < [0, 1, 2] => [false, false, true] => not same for all indices => don't fold
+  %1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @ult_constant_3_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @ult_constant_3_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] < 3 => [true, true, true] => true for all indices => fold
+  %1 = arith.cmpi ult, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_ult_constant_2_rhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @negative_ult_constant_2_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  // [0, 1, 2] < 2 => [true, true, false] => not same for all indices => don't fold
+  %1 = arith.cmpi ult, %0, %cst : vector<3xindex>
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+///  Tests of `ule` (unsigned less than or equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @ule_constant_3_lhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @ule_constant_3_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ule, %cst, %0 : vector<3xindex> // 3 <= [0, 1, 2] => all false => fold
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_ule_constant_2_lhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @negative_ule_constant_2_lhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ule, %cst, %0 : vector<3xindex> // 2 <= [0, 1, 2] => [false, false, true] => don't fold
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @ule_constant_2_rhs
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @ule_constant_2_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ule, %0, %cst : vector<3xindex> // [0, 1, 2] <= 2 => all true => fold
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_ule_constant_1_rhs
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @negative_ule_constant_1_rhs() -> vector<3xi1> {
+  %cst = arith.constant dense<1> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ule, %0, %cst: vector<3xindex> // [0, 1, 2] <= 1 => [true, true, false] => don't fold
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+///  Tests of `eq` (equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @eq_constant_3
+//       CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @eq_constant_3() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi eq, %0, %cst: vector<3xindex> // [0, 1, 2] == 3 => all false => fold
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_eq_constant_2
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @negative_eq_constant_2() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi eq, %0, %cst: vector<3xindex> // [0, 1, 2] == 2 => [false, false, true] => don't fold
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+///===------------------------------------===//
+///  Tests of `ne` (not equal)
+///===------------------------------------===//
+
+// CHECK-LABEL: @ne_constant_3
+//       CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+//       CHECK: return %[[CST]] : vector<3xi1>
+func.func @ne_constant_3() -> vector<3xi1> {
+  %cst = arith.constant dense<3> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ne, %0, %cst: vector<3xindex> // [0, 1, 2] != 3 => all true => fold
+  return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @negative_ne_constant_2
+//       CHECK: %[[CMP:.*]] = arith.cmpi
+//       CHECK: return %[[CMP]]
+func.func @negative_ne_constant_2() -> vector<3xi1> {
+  %cst = arith.constant dense<2> : vector<3xindex>
+  %0 = vector.step : vector<3xindex>
+  %1 = arith.cmpi ne, %0, %cst: vector<3xindex> // [0, 1, 2] != 2 => [true, true, false] => don't fold
+  return %1 : vector<3xi1>
+}
+


        


More information about the Mlir-commits mailing list