[Mlir-commits] [mlir] [mlir][affine] Make AffineForEmptyLoopFolder as folder function (PR #163929)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Oct 17 02:01:48 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-affine

Author: lonely eagle (linuxlonelyeagle)

<details>
<summary>Changes</summary>

Removed `hasCanonicalizer` from AffineForOp. Turned the AffineForEmptyLoopFolder rewrite pattern into a static function that is called from the `AffineForOp::fold` hook.

---
Full diff: https://github.com/llvm/llvm-project/pull/163929.diff


2 Files Affected:

- (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (-1) 
- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+62-75) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index e52b7d2090d53..12a79358d42f1 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -330,7 +330,6 @@ def AffineForOp : Affine_Op<"for",
     Speculation::Speculatability getSpeculatability();
   }];
 
-  let hasCanonicalizer = 1;
   let hasCustomAssemblyFormat = 1;
   let hasFolder = 1;
   let hasRegionVerifier = 1;
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 7e5ce26b5f733..c02a7a56ae87b 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2460,6 +2460,67 @@ static LogicalResult foldLoopBounds(AffineForOp forOp) {
   return success(folded);
 }
 
+/// Returns constant trip count in trivial cases.
+static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
+  int64_t step = forOp.getStepAsInt();
+  if (!forOp.hasConstantBounds() || step <= 0)
+    return std::nullopt;
+  int64_t lb = forOp.getConstantLowerBound();
+  int64_t ub = forOp.getConstantUpperBound();
+  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
+}
+
+/// Fold the empty loop.
+static LogicalResult AffineForEmptyLoopFolder(AffineForOp forOp) {
+  if (!llvm::hasSingleElement(*forOp.getBody()))
+    return failure();
+  if (forOp.getNumResults() == 0)
+    return success();
+  std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
+  if (tripCount == 0) {
+    // The initial values of the iteration arguments would be the op's
+    // results.
+    forOp.getResults().replaceAllUsesWith(forOp.getInits());
+    return success();
+  }
+  SmallVector<Value, 4> replacements;
+  auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
+  auto iterArgs = forOp.getRegionIterArgs();
+  bool hasValDefinedOutsideLoop = false;
+  bool iterArgsNotInOrder = false;
+  for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
+    Value val = yieldOp.getOperand(i);
+    auto *iterArgIt = llvm::find(iterArgs, val);
+    // TODO: It should be possible to perform a replacement by computing the
+    // last value of the IV based on the bounds and the step.
+    if (val == forOp.getInductionVar())
+      return failure();
+    if (iterArgIt == iterArgs.end()) {
+      // `val` is defined outside of the loop.
+      assert(forOp.isDefinedOutsideOfLoop(val) &&
+             "must be defined outside of the loop");
+      hasValDefinedOutsideLoop = true;
+      replacements.push_back(val);
+    } else {
+      unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
+      if (pos != i)
+        iterArgsNotInOrder = true;
+      replacements.push_back(forOp.getInits()[pos]);
+    }
+  }
+  // Bail out when the trip count is unknown and the loop returns any value
+  // defined outside of the loop or any iterArg out of order.
+  if (!tripCount.has_value() &&
+      (hasValDefinedOutsideLoop || iterArgsNotInOrder))
+    return failure();
+  // Bail out when the loop iterates more than once and it returns any iterArg
+  // out of order.
+  if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
+    return failure();
+  forOp.getResults().replaceAllUsesWith(replacements);
+  return success();
+}
+
 /// Canonicalize the bounds of the given loop.
 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
   SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
@@ -2491,81 +2552,6 @@ static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
   return success();
 }
 
-namespace {
-/// Returns constant trip count in trivial cases.
-static std::optional<uint64_t> getTrivialConstantTripCount(AffineForOp forOp) {
-  int64_t step = forOp.getStepAsInt();
-  if (!forOp.hasConstantBounds() || step <= 0)
-    return std::nullopt;
-  int64_t lb = forOp.getConstantLowerBound();
-  int64_t ub = forOp.getConstantUpperBound();
-  return ub - lb <= 0 ? 0 : (ub - lb + step - 1) / step;
-}
-
-/// This is a pattern to fold trivially empty loop bodies.
-/// TODO: This should be moved into the folding hook.
-struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
-  using OpRewritePattern<AffineForOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(AffineForOp forOp,
-                                PatternRewriter &rewriter) const override {
-    // Check that the body only contains a yield.
-    if (!llvm::hasSingleElement(*forOp.getBody()))
-      return failure();
-    if (forOp.getNumResults() == 0)
-      return success();
-    std::optional<uint64_t> tripCount = getTrivialConstantTripCount(forOp);
-    if (tripCount == 0) {
-      // The initial values of the iteration arguments would be the op's
-      // results.
-      rewriter.replaceOp(forOp, forOp.getInits());
-      return success();
-    }
-    SmallVector<Value, 4> replacements;
-    auto yieldOp = cast<AffineYieldOp>(forOp.getBody()->getTerminator());
-    auto iterArgs = forOp.getRegionIterArgs();
-    bool hasValDefinedOutsideLoop = false;
-    bool iterArgsNotInOrder = false;
-    for (unsigned i = 0, e = yieldOp->getNumOperands(); i < e; ++i) {
-      Value val = yieldOp.getOperand(i);
-      auto *iterArgIt = llvm::find(iterArgs, val);
-      // TODO: It should be possible to perform a replacement by computing the
-      // last value of the IV based on the bounds and the step.
-      if (val == forOp.getInductionVar())
-        return failure();
-      if (iterArgIt == iterArgs.end()) {
-        // `val` is defined outside of the loop.
-        assert(forOp.isDefinedOutsideOfLoop(val) &&
-               "must be defined outside of the loop");
-        hasValDefinedOutsideLoop = true;
-        replacements.push_back(val);
-      } else {
-        unsigned pos = std::distance(iterArgs.begin(), iterArgIt);
-        if (pos != i)
-          iterArgsNotInOrder = true;
-        replacements.push_back(forOp.getInits()[pos]);
-      }
-    }
-    // Bail out when the trip count is unknown and the loop returns any value
-    // defined outside of the loop or any iterArg out of order.
-    if (!tripCount.has_value() &&
-        (hasValDefinedOutsideLoop || iterArgsNotInOrder))
-      return failure();
-    // Bail out when the loop iterates more than once and it returns any iterArg
-    // out of order.
-    if (tripCount.has_value() && tripCount.value() >= 2 && iterArgsNotInOrder)
-      return failure();
-    rewriter.replaceOp(forOp, replacements);
-    return success();
-  }
-};
-} // namespace
-
-void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
-                                              MLIRContext *context) {
-  results.add<AffineForEmptyLoopFolder>(context);
-}
-
 OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
   assert((point.isParent() || point == getRegion()) && "invalid region point");
 
@@ -2615,6 +2601,7 @@ LogicalResult AffineForOp::fold(FoldAdaptor adaptor,
                                 SmallVectorImpl<OpFoldResult> &results) {
   bool folded = succeeded(foldLoopBounds(*this));
   folded |= succeeded(canonicalizeLoopBounds(*this));
+  folded |= succeeded(AffineForEmptyLoopFolder(*this));
   if (hasTrivialZeroTripCount(*this) && getNumResults() != 0) {
     // The initial values of the loop-carried variables (iter_args) are the
     // results of the op. But this must be avoided for an affine.for op that

``````````

</details>


https://github.com/llvm/llvm-project/pull/163929


More information about the Mlir-commits mailing list