[Mlir-commits] [mlir] [mlir][affine] re-land implement `promoteIfSingleIteration` for `AffineForOp` (PR #72805)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Nov 22 10:51:42 PST 2023


================
@@ -2440,6 +2442,69 @@ std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
   return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
 }
 
+/// Forwards the loop-carried values of `forOp` to their users:
+///  - every use of a region iter_arg is redirected to the matching init
+///    operand of the loop, and
+///  - every use of a loop result is redirected to the matching operand of the
+///    loop body's terminator.
+/// Nothing is erased here. Uses are rewritten with Value::replaceAllUsesWith
+/// directly, so no rewriter/listener is notified of these changes.
+void mlir::affine::replaceIterArgsAndYieldResults(AffineForOp forOp) {
+  // Region iter_args -> the initial values the loop was seeded with.
+  for (auto [regionArg, init] :
+       llvm::zip(forOp.getRegionIterArgs(), forOp.getInits()))
+    regionArg.replaceAllUsesWith(init);
+
+  // Loop results -> the values yielded by the body's terminator.
+  Operation *terminator = forOp.getBody()->getTerminator();
+  for (auto [result, yielded] :
+       llvm::zip(forOp.getResults(), terminator->getOperands()))
+    result.replaceAllUsesWith(yielded);
+}
+
+LogicalResult AffineForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
+  auto forOp = cast<AffineForOp>(getOperation());
+  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  if (!tripCount || *tripCount != 1)
+    return failure();
+
+  // TODO: extend this for arbitrary affine bounds.
+  if (forOp.getLowerBoundMap().getNumResults() != 1)
+    return failure();
+
+  // Replaces all IV uses to its single iteration value.
+  BlockArgument iv = forOp.getInductionVar();
+  if (!iv.use_empty()) {
+    if (forOp.hasConstantLowerBound()) {
+      Operation *parentOp = forOp.getOperation();
+      while (isa<AffineForOp>(parentOp->getParentOp()))
+        parentOp = parentOp->getParentOp();
+      Block *parentBlock = parentOp->getBlock();
+      OpBuilder topBuilder(parentBlock, parentBlock->begin());
+      auto constOp = topBuilder.create<arith::ConstantIndexOp>(
+          forOp.getLoc(), forOp.getConstantLowerBound());
+      iv.replaceAllUsesWith(constOp);
+    } else {
+      OperandRange lbOperands = forOp.getLowerBoundOperands();
+      AffineMap lbMap = forOp.getLowerBoundMap();
+      OpBuilder builder(forOp);
+      if (lbMap == builder.getDimIdentityMap()) {
+        // No need of generating an affine.apply.
+        iv.replaceAllUsesWith(lbOperands[0]);
+      } else {
+        auto affineApplyOp =
+            builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
+        iv.replaceAllUsesWith(affineApplyOp);
+      }
+    }
+  }
+
+  replaceIterArgsAndYieldResults(forOp);
+
+  // Move the loop body operations, except for its terminator, to the loop's
+  // containing block.
+  forOp.getBody()->back().erase();
+  Block *parentBlock = forOp->getBlock();
+  parentBlock->getOperations().splice(Block::iterator(forOp),
+                                      forOp.getBody()->getOperations());
+  forOp.erase();
----------------
srcarroll wrote:

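I think it's safer to use `rewriter.eraseOp(forOp)` here, and to go through the rewriter similarly on line 2500 above. Erasing ops directly bypasses the rewriter, so this could break when it is used inside a rewrite pattern class: the rewriter is never notified that the op was erased. It has happened to me before.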

https://github.com/llvm/llvm-project/pull/72805


More information about the Mlir-commits mailing list