[Mlir-commits] [mlir] [mlir][Vector] Fold vector.step compared to constant (PR #161615)
James Newling
llvmlistbot at llvm.org
Thu Oct 2 10:39:48 PDT 2025
https://github.com/newling updated https://github.com/llvm/llvm-project/pull/161615
>From 4555ad22c26a2b2a7e5f335d5d7e6eefb814e35f Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Wed, 1 Oct 2025 18:07:36 -0700
Subject: [PATCH 1/3] add folder
Signed-off-by: James Newling <james.newling at gmail.com>
---
.../mlir/Dialect/Vector/IR/VectorOps.td | 1 +
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 114 ++++++
.../Vector/canonicalize/vector-step.mlir | 379 ++++++++++++++++++
3 files changed, 494 insertions(+)
create mode 100644 mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index 252c0b72456df..dbb5d0f659159 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -3045,6 +3045,7 @@ def Vector_StepOp : Vector_Op<"step", [
}];
let results = (outs VectorOfRankAndType<[1], [Index]>:$result);
let assemblyFormat = "attr-dict `:` type($result)";
+ let hasCanonicalizer = 1;
}
def Vector_YieldOp : Vector_Op<"yield", [
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index eb4686997c1b9..6a18df72c5335 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7524,6 +7524,120 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), result);
}
+namespace {
+
+/// Constant fold vector.step when it is compared to constant with arith.cmpi
+/// and the result is the same at all indices. For example, rewrite:
+///
+/// %cst = arith.constant dense<7> : vector<3xindex>
+/// %0 = vector.step : vector<3xindex>
+/// %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
+///
+/// as
+///
+/// %out = arith.constant dense<false> : vector<3xi1>
+///
+/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
+/// false at ALL indices we fold to the constant. false. If the constant was 1,
+/// then [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not constant
+/// fold, preferring the more 'compact' vector.step representation.
+struct StepCompareFolder : public OpRewritePattern<StepOp> {
+ using OpRewritePattern::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(StepOp stepOp,
+ PatternRewriter &rewriter) const override {
+
+ int64_t stepSize = stepOp.getResult().getType().getNumElements();
+
+ for (auto &use : stepOp.getResult().getUses()) {
+ if (auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner())) {
+ unsigned stepOperandNumber = use.getOperandNumber();
+
+ // arith.cmpi has a canonicalizer to put constants on operand 1. Let it
+ // run first.
+ if (stepOperandNumber != 0) {
+ continue;
+ }
+
+ // Check that operand 1 is a constant.
+ unsigned otherOperandNumber = 1;
+ Value otherOperand = cmpiOp.getOperand(otherOperandNumber);
+ auto maybeConstValue = getConstantIntValue(otherOperand);
+ if (!maybeConstValue.has_value())
+ continue;
+ int64_t constValue = maybeConstValue.value();
+
+ arith::CmpIPredicate pred = cmpiOp.getPredicate();
+
+ auto maybeSplat = [&]() -> std::optional<bool> {
+ // Handle ult (unsigned less than) and uge (unsigned greater equal).
+ // Examples where stepSize = constValue = 3, for the 4
+ // cases of [ult, uge] x [stepOperandNumber = 0, 1]:
+ //
+ // pred stepOperandNumber
+ // ==== =================
+ // ult 0 [0, 1, 2] < 3 ==> true.
+ // ult 1 3 < [0, 1, 2] ==> false.
+ // uge 0 [0, 1, 2] >= 3 ==> true.
+ // uge 1 3 >= [0, 1, 2] ==> false.
+ //
+ // If constValue is any smaller, the comparison is not constant.
+ if (pred == arith::CmpIPredicate::ult ||
+ pred == arith::CmpIPredicate::uge) {
+ if (stepSize <= constValue) {
+ return pred == arith::CmpIPredicate::ult;
+ }
+ }
+
+ // Handle ule and ugt.
+ //
+ // pred stepOperandNumber
+ // ==== =================
+ // ule 0 [0, 1, 2] <= 2 ==> true
+ // (stepSize = 3, constValue = 2).
+ if (pred == arith::CmpIPredicate::ule ||
+ pred == arith::CmpIPredicate::ugt) {
+ if (stepSize <= constValue + 1) {
+ return pred == arith::CmpIPredicate::ule;
+ }
+ }
+
+ // Handle eq and ne
+ if (pred == arith::CmpIPredicate::eq ||
+ pred == arith::CmpIPredicate::ne) {
+ if (stepSize <= constValue) {
+ return pred == arith::CmpIPredicate::ne;
+ }
+ }
+
+ return std::optional<bool>();
+ }();
+
+ if (!maybeSplat.has_value())
+ continue;
+
+ rewriter.setInsertionPointAfter(cmpiOp);
+ auto boolConst = mlir::arith::ConstantOp::create(
+ rewriter, cmpiOp.getLoc(),
+ rewriter.getBoolAttr(maybeSplat.value()));
+ auto splat = vector::BroadcastOp::create(
+ rewriter, cmpiOp.getLoc(), cmpiOp.getResult().getType(), boolConst);
+
+ rewriter.replaceOp(cmpiOp, splat.getResult());
+ return success();
+ }
+ }
+
+ return failure();
+ }
+};
+} // namespace
+
+void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<StepCompareFolder>(context);
+}
+
//===----------------------------------------------------------------------===//
// Vector Masking Utilities
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
new file mode 100644
index 0000000000000..e213c78f5ea42
--- /dev/null
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
@@ -0,0 +1,379 @@
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+
+///===----------------------------------------------===//
+/// Tests of `StepCompareFolder`
+///===----------------------------------------------===//
+
+
+///===--------------===//
+/// Tests of `ugt` (unsigned greater than)
+///===--------------===//
+
+// CHECK-LABEL: @check_ugt_constant_3_lhs
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ugt_constant_3_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 3 > [0, 1, 2] => true
+ %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_2_lhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ugt_constant_2_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 2 > [0, 1, 2] => not constant
+ %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_1_lhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ugt_constant_1_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 1 > [0, 1, 2] => not constant
+ %1 = arith.cmpi ugt, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_3_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ugt_constant_3_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] > 3 => false
+ %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_2_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ugt_constant_2_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] > 2 => false
+ %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_1_rhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ugt_constant_1_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] > 1 => not constant
+ %1 = arith.cmpi ugt, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+///===--------------===//
+/// Tests of `uge` (unsigned greater than or equal)
+///===--------------===//
+
+// CHECK-LABEL: @check_uge_constant_3_lhs
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_uge_constant_3_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 3 >= [0, 1, 2] => true
+ %1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_2_lhs
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_uge_constant_2_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 2 >= [0, 1, 2] => true
+ %1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_1_lhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_uge_constant_1_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // 1 >= [0, 1, 2] => not constant
+ %1 = arith.cmpi uge, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_3_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_uge_constant_3_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] >= 3 => false
+ %1 = arith.cmpi uge, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_2_rhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_uge_constant_2_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] >= 2 => not constant
+ %1 = arith.cmpi uge, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_uge_constant_1_rhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_uge_constant_1_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ // [0, 1, 2] >= 1 => not constant
+ %1 = arith.cmpi uge, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+
+
+///===--------------===//
+/// Tests of `ult` (unsigned less than)
+///===--------------===//
+
+// CHECK-LABEL: @check_ult_constant_3_lhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ult_constant_3_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_2_lhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ult_constant_2_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_1_lhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ult_constant_1_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ult, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_3_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ult_constant_3_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ult, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_2_rhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ult_constant_2_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ult, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ult_constant_1_rhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ult_constant_1_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ult, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+///===--------------===//
+/// Tests of `ule` (unsigned less than or equal)
+///===--------------===//
+
+// CHECK-LABEL: @check_ule_constant_3_lhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ule_constant_3_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ule, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_2_lhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ule_constant_2_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ule, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_1_lhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ule_constant_1_lhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ule, %cst, %0 : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_3_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ule_constant_3_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ule, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_2_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ule_constant_2_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ule, %0, %cst : vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ule_constant_1_rhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ule_constant_1_rhs() -> vector<3xi1> {
+ %cst = arith.constant dense<1> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ule, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+///===--------------===//
+/// Tests of `eq` (equal)
+///===--------------===//
+
+// CHECK-LABEL: @check_eq_constant_3
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_eq_constant_3() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi eq, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_eq_constant_2
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_eq_constant_2() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi eq, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+///===--------------===//
+/// Tests of `ne` (not equal)
+///===--------------===//
+
+// CHECK-LABEL: @check_ne_constant_3
+// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ne_constant_3() -> vector<3xi1> {
+ %cst = arith.constant dense<3> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ne, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// CHECK-LABEL: @check_ne_constant_2
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ne_constant_2() -> vector<3xi1> {
+ %cst = arith.constant dense<2> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ne, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
>From 7232237305d5374c802995d42fe5e6126852911e Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 2 Oct 2025 09:07:52 -0700
Subject: [PATCH 2/3] cosmetics
Signed-off-by: James Newling <james.newling at gmail.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 101 +++++++-----------
.../Vector/canonicalize/vector-step.mlir | 24 ++---
2 files changed, 53 insertions(+), 72 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 6a18df72c5335..306be186308b0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7526,89 +7526,66 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
namespace {
-/// Constant fold vector.step when it is compared to constant with arith.cmpi
-/// and the result is the same at all indices. For example, rewrite:
+/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
+/// constant large enough such that the result is the same at all indices.
+///
+/// For example, rewrite the 'greater than' comparison below,
///
/// %cst = arith.constant dense<7> : vector<3xindex>
-/// %0 = vector.step : vector<3xindex>
-/// %1 = arith.cmpi ugt, %0, %cst : vector<3xindex>
+/// %stp = vector.step : vector<3xindex>
+/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
///
-/// as
+/// as,
///
-/// %out = arith.constant dense<false> : vector<3xi1>
+/// %out = arith.constant dense<false> : vector<3xi1>.
///
/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
-/// false at ALL indices we fold to the constant. false. If the constant was 1,
-/// then [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not constant
-/// fold, preferring the more 'compact' vector.step representation.
+/// false at ALL indices we fold. If the constant was 1, then
+/// [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not fold,
+/// conservatively preferring the 'compact' vector.step representation.
struct StepCompareFolder : public OpRewritePattern<StepOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(StepOp stepOp,
PatternRewriter &rewriter) const override {
-
- int64_t stepSize = stepOp.getResult().getType().getNumElements();
+ const int64_t stepSize = stepOp.getResult().getType().getNumElements();
for (auto &use : stepOp.getResult().getUses()) {
if (auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner())) {
- unsigned stepOperandNumber = use.getOperandNumber();
+ const unsigned stepOperandNumber = use.getOperandNumber();
- // arith.cmpi has a canonicalizer to put constants on operand 1. Let it
- // run first.
- if (stepOperandNumber != 0) {
+ // arith.cmpi canonicalizer makes constants final operands.
+ if (stepOperandNumber != 0)
continue;
- }
// Check that operand 1 is a constant.
- unsigned otherOperandNumber = 1;
- Value otherOperand = cmpiOp.getOperand(otherOperandNumber);
+ unsigned constOperandNumber = 1;
+ Value otherOperand = cmpiOp.getOperand(constOperandNumber);
auto maybeConstValue = getConstantIntValue(otherOperand);
if (!maybeConstValue.has_value())
continue;
- int64_t constValue = maybeConstValue.value();
+ int64_t constValue = maybeConstValue.value();
arith::CmpIPredicate pred = cmpiOp.getPredicate();
auto maybeSplat = [&]() -> std::optional<bool> {
// Handle ult (unsigned less than) and uge (unsigned greater equal).
- // Examples where stepSize = constValue = 3, for the 4
- // cases of [ult, uge] x [stepOperandNumber = 0, 1]:
- //
- // pred stepOperandNumber
- // ==== =================
- // ult 0 [0, 1, 2] < 3 ==> true.
- // ult 1 3 < [0, 1, 2] ==> false.
- // uge 0 [0, 1, 2] >= 3 ==> true.
- // uge 1 3 >= [0, 1, 2] ==> false.
- //
- // If constValue is any smaller, the comparison is not constant.
- if (pred == arith::CmpIPredicate::ult ||
- pred == arith::CmpIPredicate::uge) {
- if (stepSize <= constValue) {
- return pred == arith::CmpIPredicate::ult;
- }
- }
+ if ((pred == arith::CmpIPredicate::ult ||
+ pred == arith::CmpIPredicate::uge) &&
+ stepSize <= constValue)
+ return pred == arith::CmpIPredicate::ult;
// Handle ule and ugt.
- //
- // pred stepOperandNumber
- // ==== =================
- // ule 0 [0, 1, 2] <= 2 ==> true
- // (stepSize = 3, constValue = 2).
- if (pred == arith::CmpIPredicate::ule ||
- pred == arith::CmpIPredicate::ugt) {
- if (stepSize <= constValue + 1) {
- return pred == arith::CmpIPredicate::ule;
- }
- }
+ if ((pred == arith::CmpIPredicate::ule ||
+ pred == arith::CmpIPredicate::ugt) &&
+ stepSize <= constValue + 1)
+ return pred == arith::CmpIPredicate::ule;
- // Handle eq and ne
- if (pred == arith::CmpIPredicate::eq ||
- pred == arith::CmpIPredicate::ne) {
- if (stepSize <= constValue) {
- return pred == arith::CmpIPredicate::ne;
- }
- }
+ // Handle eq and ne.
+ if ((pred == arith::CmpIPredicate::eq ||
+ pred == arith::CmpIPredicate::ne) &&
+ stepSize <= constValue)
+ return pred == arith::CmpIPredicate::ne;
return std::optional<bool>();
}();
@@ -7617,13 +7594,17 @@ struct StepCompareFolder : public OpRewritePattern<StepOp> {
continue;
rewriter.setInsertionPointAfter(cmpiOp);
- auto boolConst = mlir::arith::ConstantOp::create(
- rewriter, cmpiOp.getLoc(),
- rewriter.getBoolAttr(maybeSplat.value()));
- auto splat = vector::BroadcastOp::create(
- rewriter, cmpiOp.getLoc(), cmpiOp.getResult().getType(), boolConst);
- rewriter.replaceOp(cmpiOp, splat.getResult());
+ auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
+ if (!type)
+ continue;
+
+ DenseElementsAttr boolAttr =
+ DenseElementsAttr::get(type, maybeSplat.value());
+ Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
+ type, boolAttr);
+
+ rewriter.replaceOp(cmpiOp, splat);
return success();
}
}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
index e213c78f5ea42..effeb3d9c093a 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
@@ -5,9 +5,9 @@
///===----------------------------------------------===//
-///===--------------===//
+///===------------------------------------===//
/// Tests of `ugt` (unsigned greater than)
-///===--------------===//
+///===------------------------------------===//
// CHECK-LABEL: @check_ugt_constant_3_lhs
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
@@ -87,9 +87,9 @@ func.func @check_ugt_constant_1_rhs() -> vector<3xi1> {
// -----
-///===--------------===//
+///===------------------------------------===//
/// Tests of `uge` (unsigned greater than or equal)
-///===--------------===//
+///===------------------------------------===//
// CHECK-LABEL: @check_uge_constant_3_lhs
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
@@ -171,9 +171,9 @@ func.func @check_uge_constant_1_rhs() -> vector<3xi1> {
-///===--------------===//
+///===------------------------------------===//
/// Tests of `ult` (unsigned less than)
-///===--------------===//
+///===------------------------------------===//
// CHECK-LABEL: @check_ult_constant_3_lhs
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
@@ -247,9 +247,9 @@ func.func @check_ult_constant_1_rhs() -> vector<3xi1> {
// -----
-///===--------------===//
+///===------------------------------------===//
/// Tests of `ule` (unsigned less than or equal)
-///===--------------===//
+///===------------------------------------===//
// CHECK-LABEL: @check_ule_constant_3_lhs
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
@@ -323,9 +323,9 @@ func.func @check_ule_constant_1_rhs() -> vector<3xi1> {
// -----
-///===--------------===//
+///===------------------------------------===//
/// Tests of `eq` (equal)
-///===--------------===//
+///===------------------------------------===//
// CHECK-LABEL: @check_eq_constant_3
// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
@@ -351,9 +351,9 @@ func.func @check_eq_constant_2() -> vector<3xi1> {
// -----
-///===--------------===//
+///===------------------------------------===//
/// Tests of `ne` (not equal)
-///===--------------===//
+///===------------------------------------===//
// CHECK-LABEL: @check_ne_constant_3
// CHECK: %[[CST:.*]] = arith.constant dense<true> : vector<3xi1>
>From 934606f8cad1e2fa1fc3dec4a460a178db0d169f Mon Sep 17 00:00:00 2001
From: James Newling <james.newling at gmail.com>
Date: Thu, 2 Oct 2025 10:42:28 -0700
Subject: [PATCH 3/3] address Jakub's comments
Signed-off-by: James Newling <james.newling at gmail.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 99 ++++++++++---------
.../Vector/canonicalize/vector-step.mlir | 31 +++++-
2 files changed, 83 insertions(+), 47 deletions(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 306be186308b0..2a7dff2a99e88 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -7551,62 +7551,69 @@ struct StepCompareFolder : public OpRewritePattern<StepOp> {
+ // vector.step over a scalable vector (e.g. vector<[4]xindex>) yields
+ // vscale * N elements, while its static shape only records N. The number of
+ // lanes, and hence the comparison result, is unknown here: bail out.
+ if (stepOp.getResult().getType().isScalable())
+ return failure();
+
const int64_t stepSize = stepOp.getResult().getType().getNumElements();
for (auto &use : stepOp.getResult().getUses()) {
- if (auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner())) {
- const unsigned stepOperandNumber = use.getOperandNumber();
-
- // arith.cmpi canonicalizer makes constants final operands.
- if (stepOperandNumber != 0)
- continue;
-
- // Check that operand 1 is a constant.
- unsigned constOperandNumber = 1;
- Value otherOperand = cmpiOp.getOperand(constOperandNumber);
- auto maybeConstValue = getConstantIntValue(otherOperand);
- if (!maybeConstValue.has_value())
- continue;
+ auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
+ if (!cmpiOp)
+ continue;
- int64_t constValue = maybeConstValue.value();
- arith::CmpIPredicate pred = cmpiOp.getPredicate();
+ // arith.cmpi canonicalizer makes constants final operands.
+ const unsigned stepOperandNumber = use.getOperandNumber();
+ if (stepOperandNumber != 0)
+ continue;
- auto maybeSplat = [&]() -> std::optional<bool> {
- // Handle ult (unsigned less than) and uge (unsigned greater equal).
- if ((pred == arith::CmpIPredicate::ult ||
- pred == arith::CmpIPredicate::uge) &&
- stepSize <= constValue)
- return pred == arith::CmpIPredicate::ult;
+ // Check that operand 1 is a constant.
+ unsigned constOperandNumber = 1;
+ Value otherOperand = cmpiOp.getOperand(constOperandNumber);
+ auto maybeConstValue = getConstantIntValue(otherOperand);
+ if (!maybeConstValue.has_value())
+ continue;
- // Handle ule and ugt.
- if ((pred == arith::CmpIPredicate::ule ||
- pred == arith::CmpIPredicate::ugt) &&
- stepSize <= constValue + 1)
- return pred == arith::CmpIPredicate::ule;
+ int64_t constValue = maybeConstValue.value();
+ arith::CmpIPredicate pred = cmpiOp.getPredicate();
+
+ auto maybeSplat = [&]() -> std::optional<bool> {
+ // Handle ult (unsigned less than) and uge (unsigned greater equal).
+ if ((pred == arith::CmpIPredicate::ult ||
+ pred == arith::CmpIPredicate::uge) &&
+ stepSize <= constValue)
+ return pred == arith::CmpIPredicate::ult;
+
+ // Handle ule and ugt.
+ if ((pred == arith::CmpIPredicate::ule ||
+ pred == arith::CmpIPredicate::ugt) &&
+ stepSize - 1 <= constValue)
+ return pred == arith::CmpIPredicate::ule;
- // Handle eq and ne.
- if ((pred == arith::CmpIPredicate::eq ||
- pred == arith::CmpIPredicate::ne) &&
- stepSize <= constValue)
- return pred == arith::CmpIPredicate::ne;
+ // Handle eq and ne.
+ if ((pred == arith::CmpIPredicate::eq ||
+ pred == arith::CmpIPredicate::ne) &&
+ stepSize <= constValue)
+ return pred == arith::CmpIPredicate::ne;
- return std::optional<bool>();
- }();
+ return std::nullopt;
+ }();
- if (!maybeSplat.has_value())
- continue;
+ if (!maybeSplat.has_value())
+ continue;
- rewriter.setInsertionPointAfter(cmpiOp);
+ rewriter.setInsertionPointAfter(cmpiOp);
- auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
- if (!type)
- continue;
+ auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
+ if (!type)
+ continue;
- DenseElementsAttr boolAttr =
- DenseElementsAttr::get(type, maybeSplat.value());
- Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
- type, boolAttr);
+ DenseElementsAttr boolAttr =
+ DenseElementsAttr::get(type, maybeSplat.value());
+ Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
+ type, boolAttr);
- rewriter.replaceOp(cmpiOp, splat);
- return success();
- }
+ rewriter.replaceOp(cmpiOp, splat);
+ return success();
}
return failure();
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
index effeb3d9c093a..eb997438d2d51 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-step.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file -allow-unregistered-dialect | FileCheck %s
+// RUN: mlir-opt %s -canonicalize="test-convergence" -split-input-file | FileCheck %s
///===----------------------------------------------===//
/// Tests of `StepCompareFolder`
@@ -59,6 +59,35 @@ func.func @check_ugt_constant_3_rhs() -> vector<3xi1> {
return %1 : vector<3xi1>
}
+// -----
+
+// CHECK-LABEL: @check_ugt_constant_max_rhs
+// CHECK: %[[CST:.*]] = arith.constant dense<false> : vector<3xi1>
+// CHECK: return %[[CST]] : vector<3xi1>
+func.func @check_ugt_constant_max_rhs() -> vector<3xi1> {
+ // The largest i64 possible:
+ %cst = arith.constant dense<0x7fffffffffffffff> : vector<3xindex>
+ %0 = vector.step : vector<3xindex>
+ %1 = arith.cmpi ugt, %0, %cst: vector<3xindex>
+ return %1 : vector<3xi1>
+}
+
+// -----
+
+// A scalable vector.step has a runtime length (vscale * 3 here), so the
+// comparison result is not known statically and must not be folded.
+// CHECK-LABEL: @check_ult_constant_scalable_rhs
+// CHECK: %[[CMP:.*]] = arith.cmpi
+// CHECK: return %[[CMP]]
+func.func @check_ult_constant_scalable_rhs() -> vector<[3]xi1> {
+ %cst = arith.constant dense<3> : vector<[3]xindex>
+ %0 = vector.step : vector<[3]xindex>
+ // [0, 1, ..., vscale * 3 - 1] < 3 => not constant when vscale > 1
+ %1 = arith.cmpi ult, %0, %cst : vector<[3]xindex>
+ return %1 : vector<[3]xi1>
+}
+
+
// -----
// CHECK-LABEL: @check_ugt_constant_2_rhs
More information about the Mlir-commits
mailing list