[Mlir-commits] [mlir] [mlir][Vector] Fold vector.step compared to constant (PR #161615)

Jakub Kuderski llvmlistbot at llvm.org
Fri Oct 10 17:12:18 PDT 2025


================
@@ -7524,6 +7524,103 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), result);
 }
 
+namespace {
+
+/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
+/// constant large enough such that the result is the same at all indices.
+///
+/// For example, rewrite the 'greater than' comparison below,
+///
+/// %cst = arith.constant dense<7> : vector<3xindex>
+/// %stp = vector.step : vector<3xindex>
+/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
+///
+/// as,
+///
+/// %out = arith.constant dense<false> : vector<3xi1>.
+///
+/// Above [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
+/// false at ALL indices we fold. If the constant was 1, then
+/// [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do NOT fold,
+/// conservatively preferring the 'compact' vector.step representation.
+///
+/// Only fixed-length step vectors are handled. For a scalable result such as
+/// `vector<[4]xindex>` the number of lanes is only known at runtime, so the
+/// static element count does not bound the step values.
+///
+/// The constant is read as a signed 64-bit value; negative constants (which
+/// are very large values under the unsigned predicates) are left unfolded.
+struct StepCompareFolder : public OpRewritePattern<StepOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(StepOp stepOp,
+                                PatternRewriter &rewriter) const override {
+    auto stepType = stepOp.getResult().getType();
+    // getNumElements() on a scalable vector only reports the minimum
+    // (vscale == 1) lane count; the real step values can exceed it.
+    if (stepType.isScalable())
+      return failure();
+    const int64_t stepSize = stepType.getNumElements();
+
+    for (OpOperand &use : stepOp.getResult().getUses()) {
+      auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner());
+      if (!cmpiOp)
+        continue;
+
+      // arith.cmpi canonicalizer makes constants final operands, so only the
+      // case where the step is the LHS (operand 0) is considered.
+      if (use.getOperandNumber() != 0)
+        continue;
+
+      // Check that the RHS is a constant integer value.
+      std::optional<int64_t> maybeConstValue =
+          getConstantIntValue(cmpiOp.getRhs());
+      if (!maybeConstValue.has_value())
+        continue;
+
+      const int64_t constValue = maybeConstValue.value();
+      const arith::CmpIPredicate pred = cmpiOp.getPredicate();
+
+      // Step lanes take the values [0, stepSize - 1]. Return the uniform
+      // comparison result if every lane agrees, std::nullopt otherwise.
+      auto maybeSplat = [&]() -> std::optional<bool> {
+        switch (pred) {
+        // Every lane i satisfies i < stepSize <= constValue.
+        case arith::CmpIPredicate::ult:
+        case arith::CmpIPredicate::uge:
+          if (stepSize <= constValue)
+            return pred == arith::CmpIPredicate::ult;
+          return std::nullopt;
+        // Every lane i satisfies i <= stepSize - 1 <= constValue.
+        case arith::CmpIPredicate::ule:
+        case arith::CmpIPredicate::ugt:
+          if (stepSize - 1 <= constValue)
+            return pred == arith::CmpIPredicate::ule;
+          return std::nullopt;
+        // No lane can equal a constant outside [0, stepSize - 1].
+        case arith::CmpIPredicate::eq:
+        case arith::CmpIPredicate::ne:
+          if (stepSize <= constValue)
+            return pred == arith::CmpIPredicate::ne;
+          return std::nullopt;
+        default:
+          return std::nullopt;
+        }
+      }();
+
+      if (!maybeSplat.has_value())
+        continue;
+
+      // Validate the result type before touching the rewriter state.
+      auto type = dyn_cast<VectorType>(cmpiOp.getResult().getType());
+      if (!type)
+        continue;
+
+      rewriter.setInsertionPointAfter(cmpiOp);
+      DenseElementsAttr boolAttr =
+          DenseElementsAttr::get(type, maybeSplat.value());
+      Value splat = mlir::arith::ConstantOp::create(rewriter, cmpiOp.getLoc(),
+                                                    type, boolAttr);
+
+      rewriter.replaceOp(cmpiOp, splat);
+      return success();
+    }
+
+    return failure();
+  }
+};
+} // namespace
+
+void StepOp::getCanonicalizationPatterns(RewritePatternSet &results,
+                                         MLIRContext *context) {
+  results.add<StepCompareFolder>(context);
----------------
kuhar wrote:

Ah, right, it's an issue with layering. Fine to keep as-is then.

https://github.com/llvm/llvm-project/pull/161615


More information about the Mlir-commits mailing list