[Mlir-commits] [mlir] [mlir][Vector] Fold vector.step compared to constant (PR #161615)

Jakub Kuderski llvmlistbot at llvm.org
Thu Oct 2 09:44:38 PDT 2025


================
@@ -7524,6 +7524,101 @@ void StepOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
   setResultRanges(getResult(), result);
 }
 
+namespace {
+
+/// Fold `vector.step -> arith.cmpi` when the step value is compared to a
+/// constant large enough such that the result is the same at all indices.
+///
+/// For example, rewrite the 'greater than' comparison below,
+///
+/// %cst = arith.constant dense<7> : vector<3xindex>
+/// %stp = vector.step : vector<3xindex>
+/// %out = arith.cmpi ugt, %stp, %cst : vector<3xindex>
+///
+/// as,
+///
+/// %out = arith.constant dense<false> : vector<3xi1>
+///
+/// Above, [0, 1, 2] > [7, 7, 7] => [false, false, false]. Because the result is
+/// false at ALL indices, we fold. If the constant were 1, then
+/// [0, 1, 2] > [1, 1, 1] => [false, false, true] and we do not fold,
+/// conservatively preferring the 'compact' vector.step representation.
+struct StepCompareFolder : public OpRewritePattern<StepOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(StepOp stepOp,
+                                PatternRewriter &rewriter) const override {
+    const int64_t stepSize = stepOp.getResult().getType().getNumElements();
+
+    for (auto &use : stepOp.getResult().getUses()) {
+      if (auto cmpiOp = dyn_cast<arith::CmpIOp>(use.getOwner())) {
+        const unsigned stepOperandNumber = use.getOperandNumber();
+
+        // arith.cmpi canonicalizer makes constants final operands.
+        if (stepOperandNumber != 0)
+          continue;
+
+        // Check that operand 1 is a constant.
+        unsigned constOperandNumber = 1;
+        Value otherOperand = cmpiOp.getOperand(constOperandNumber);
+        auto maybeConstValue = getConstantIntValue(otherOperand);
+        if (!maybeConstValue.has_value())
+          continue;
+
+        int64_t constValue = maybeConstValue.value();
+        arith::CmpIPredicate pred = cmpiOp.getPredicate();
+
+        auto maybeSplat = [&]() -> std::optional<bool> {
+          // Handle ult (unsigned less than) and uge (unsigned greater equal).
+          if ((pred == arith::CmpIPredicate::ult ||
+               pred == arith::CmpIPredicate::uge) &&
+              stepSize <= constValue)
+            return pred == arith::CmpIPredicate::ult;
+
+          // Handle ule and ugt.
+          if ((pred == arith::CmpIPredicate::ule ||
+               pred == arith::CmpIPredicate::ugt) &&
+              stepSize <= constValue + 1)
----------------
kuhar wrote:

Are you sure this doesn't overflow?

https://github.com/llvm/llvm-project/pull/161615

