[Mlir-commits] [mlir] [mlir][affine] re-land implement `promoteIfSingleIteration` for `AffineForOp` (PR #72805)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 19 11:57:47 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-affine

Author: Maksim Levental (makslevental)

<details>
<summary>Changes</summary>

I had to revert https://github.com/llvm/llvm-project/pull/72547 because I didn't notice that `promoteIfSingleIteration` depends on `func::FuncOp`:

```cpp
if (forOp.hasConstantLowerBound()) {
  OpBuilder topBuilder(forOp->getParentOfType<func::FuncOp>().getBody());
  auto constOp = topBuilder.create<arith::ConstantIndexOp>(
      forOp.getLoc(), forOp.getConstantLowerBound());
```

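That is, the original code hoists the `arith.constant` to the start of the body of the nearest enclosing `func`. The alternative I implemented here instead inserts it at the start of the block that contains the outermost `affine.for` reached by walking up through the enclosing `affine.for` ops: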

```cpp
if (forOp.hasConstantLowerBound()) {
  Operation *parentOp = forOp.getOperation();
  while (isa<AffineForOp>(parentOp->getParentOp()))
    parentOp = parentOp->getParentOp();
  Block *parentBlock = parentOp->getBlock();
  OpBuilder topBuilder(parentBlock, parentBlock->begin());
```

This removes the dependency on `func::FuncOp`, but I wanted to make sure this approach is acceptable.

---

Patch is 29.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/72805.diff


11 Files Affected:

- (modified) mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h (+1-14) 
- (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.h (+25-4) 
- (modified) mlir/include/mlir/Dialect/Affine/IR/AffineOps.td (+1-1) 
- (modified) mlir/include/mlir/Dialect/Affine/LoopUtils.h (-4) 
- (modified) mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp (+1-79) 
- (modified) mlir/lib/Dialect/Affine/IR/AffineOps.cpp (+144-7) 
- (modified) mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp (+2-1) 
- (modified) mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp (+3-3) 
- (modified) mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp (+13-70) 
- (modified) mlir/lib/Dialect/Affine/Utils/Utils.cpp (+2-1) 
- (modified) mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp (+2-1) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h b/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
index 92f3d5a2c4925b1..c629c3a1c562322 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
+
 #include <optional>
 
 namespace mlir {
@@ -29,20 +30,6 @@ namespace affine {
 class AffineForOp;
 class NestedPattern;
 
-/// Returns the trip count of the loop as an affine map with its corresponding
-/// operands if the latter is expressible as an affine expression, and nullptr
-/// otherwise. This method always succeeds as long as the lower bound is not a
-/// multi-result map. The trip count expression is simplified before returning.
-/// This method only utilizes map composition to construct lower and upper
-/// bounds before computing the trip count expressions
-void getTripCountMapAndOperands(AffineForOp forOp, AffineMap *map,
-                                SmallVectorImpl<Value> *operands);
-
-/// Returns the trip count of the loop if it's a constant, std::nullopt
-/// otherwise. This uses affine expression analysis and is able to determine
-/// constant trip count in non-trivial cases.
-std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);
-
 /// Returns the greatest known integral divisor of the trip count. Affine
 /// expression analysis is used (indirectly through getTripCount), and
 /// this method is thus able to determine non-trivial divisors.
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index f070d0488619063..f763cf339159a50 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -117,7 +117,8 @@ class AffineDmaStartOp
   /// Returns the affine map used to access the source memref.
   AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
   AffineMapAttr getSrcMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getSrcMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getSrcMapAttrStrName()));
   }
 
   /// Returns the source memref affine map indices for this DMA operation.
@@ -156,7 +157,8 @@ class AffineDmaStartOp
   /// Returns the affine map used to access the destination memref.
   AffineMap getDstMap() { return getDstMapAttr().getValue(); }
   AffineMapAttr getDstMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getDstMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getDstMapAttrStrName()));
   }
 
   /// Returns the destination memref indices for this DMA operation.
@@ -185,7 +187,8 @@ class AffineDmaStartOp
   /// Returns the affine map used to access the tag memref.
   AffineMap getTagMap() { return getTagMapAttr().getValue(); }
   AffineMapAttr getTagMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getTagMapAttrStrName()));
   }
 
   /// Returns the tag memref indices for this DMA operation.
@@ -307,7 +310,8 @@ class AffineDmaWaitOp
   /// Returns the affine map used to access the tag memref.
   AffineMap getTagMap() { return getTagMapAttr().getValue(); }
   AffineMapAttr getTagMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getTagMapAttrStrName()));
   }
 
   /// Returns the tag memref index for this DMA operation.
@@ -465,6 +469,23 @@ AffineForOp getForInductionVarOwner(Value val);
 /// AffineParallelOp.
 AffineParallelOp getAffineParallelInductionVarOwner(Value val);
 
+/// Helper to replace uses of loop carried values (iter_args) and loop
+/// yield values while promoting single iteration affine.for ops.
+void replaceIterArgsAndYieldResults(AffineForOp forOp);
+
+/// Returns the trip count of the loop as an affine expression if the latter is
+/// expressible as an affine expression, and nullptr otherwise. The trip count
+/// expression is simplified before returning. This method only utilizes map
+/// composition to construct lower and upper bounds before computing the trip
+/// count expressions.
+void getTripCountMapAndOperands(AffineForOp forOp, AffineMap *tripCountMap,
+                                SmallVectorImpl<Value> *tripCountOperands);
+
+/// Returns the trip count of the loop if it's a constant, std::nullopt
+/// otherwise. This uses affine expression analysis and is able to determine
+/// constant trip count in non-trivial cases.
+std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);
+
 /// Extracts the induction variables from a list of AffineForOps and places them
 /// in the output argument `ivs`.
 void extractForInductionVars(ArrayRef<AffineForOp> forInsts,
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index f9578cf37d5d768..b4ea6122ed4c0e0 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -121,7 +121,7 @@ def AffineForOp : Affine_Op<"for",
      ImplicitAffineTerminator, ConditionallySpeculatable,
      RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
      ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
-      "getSingleUpperBound", "getYieldedValuesMutable",
+      "getSingleUpperBound", "getYieldedValuesMutable", "promoteIfSingleIteration",
       "replaceWithAdditionalYields"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
      ["getEntrySuccessorOperands"]>]> {
diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
index 723a262f24acc51..1e3b3bffea7b838 100644
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -83,10 +83,6 @@ LogicalResult loopUnrollJamByFactor(AffineForOp forOp,
 LogicalResult loopUnrollJamUpToFactor(AffineForOp forOp,
                                       uint64_t unrollJamFactor);
 
-/// Promotes the loop body of a AffineForOp to its containing block if the loop
-/// was known to have a single iteration.
-LogicalResult promoteIfSingleIteration(AffineForOp forOp);
-
 /// Promotes all single iteration AffineForOp's in the Function, i.e., moves
 /// their body into the containing Block.
 void promoteSingleIterationLoops(func::FuncOp f);
diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
index e645afe7cd3e8fa..24f119464b416a7 100644
--- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
@@ -12,7 +12,6 @@
 
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 
-#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
@@ -20,9 +19,9 @@
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Support/MathExtras.h"
 
-#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallString.h"
+
 #include <numeric>
 #include <optional>
 #include <type_traits>
@@ -30,83 +29,6 @@
 using namespace mlir;
 using namespace mlir::affine;
 
-/// Returns the trip count of the loop as an affine expression if the latter is
-/// expressible as an affine expression, and nullptr otherwise. The trip count
-/// expression is simplified before returning. This method only utilizes map
-/// composition to construct lower and upper bounds before computing the trip
-/// count expressions.
-void mlir::affine::getTripCountMapAndOperands(
-    AffineForOp forOp, AffineMap *tripCountMap,
-    SmallVectorImpl<Value> *tripCountOperands) {
-  MLIRContext *context = forOp.getContext();
-  int64_t step = forOp.getStepAsInt();
-  int64_t loopSpan;
-  if (forOp.hasConstantBounds()) {
-    int64_t lb = forOp.getConstantLowerBound();
-    int64_t ub = forOp.getConstantUpperBound();
-    loopSpan = ub - lb;
-    if (loopSpan < 0)
-      loopSpan = 0;
-    *tripCountMap = AffineMap::getConstantMap(ceilDiv(loopSpan, step), context);
-    tripCountOperands->clear();
-    return;
-  }
-  auto lbMap = forOp.getLowerBoundMap();
-  auto ubMap = forOp.getUpperBoundMap();
-  if (lbMap.getNumResults() != 1) {
-    *tripCountMap = AffineMap();
-    return;
-  }
-
-  // Difference of each upper bound expression from the single lower bound
-  // expression (divided by the step) provides the expressions for the trip
-  // count map.
-  AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());
-
-  SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
-                                         lbMap.getResult(0));
-  auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
-                                   lbSplatExpr, context);
-  AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());
-
-  AffineValueMap tripCountValueMap;
-  AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
-  for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
-    tripCountValueMap.setResult(i,
-                                tripCountValueMap.getResult(i).ceilDiv(step));
-
-  *tripCountMap = tripCountValueMap.getAffineMap();
-  tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
-                            tripCountValueMap.getOperands().end());
-}
-
-/// Returns the trip count of the loop if it's a constant, std::nullopt
-/// otherwise. This method uses affine expression analysis (in turn using
-/// getTripCount) and is able to determine constant trip count in non-trivial
-/// cases.
-std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
-  SmallVector<Value, 4> operands;
-  AffineMap map;
-  getTripCountMapAndOperands(forOp, &map, &operands);
-
-  if (!map)
-    return std::nullopt;
-
-  // Take the min if all trip counts are constant.
-  std::optional<uint64_t> tripCount;
-  for (auto resultExpr : map.getResults()) {
-    if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
-      if (tripCount.has_value())
-        tripCount =
-            std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
-      else
-        tripCount = constExpr.getValue();
-    } else
-      return std::nullopt;
-  }
-  return tripCount;
-}
-
 /// Returns the greatest known integral divisor of the trip count. Affine
 /// expression analysis is used (indirectly through getTripCount), and
 /// this method is thus able to determine non-trivial divisors.
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 05496e70716a2a1..8716d7a3525b526 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/AffineExprVisitor.h"
@@ -23,6 +24,7 @@
 #include "llvm/ADT/SmallVectorExtras.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+
 #include <numeric>
 #include <optional>
 
@@ -2440,6 +2442,69 @@ std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
   return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
 }
 
+void mlir::affine::replaceIterArgsAndYieldResults(AffineForOp forOp) {
+  // Replace uses of iter arguments with iter operands (initial values).
+  OperandRange iterOperands = forOp.getInits();
+  MutableArrayRef<BlockArgument> iterArgs = forOp.getRegionIterArgs();
+  for (auto [operand, arg] : llvm::zip(iterOperands, iterArgs))
+    arg.replaceAllUsesWith(operand);
+
+  // Replace uses of loop results with the values yielded by the loop.
+  ResultRange outerResults = forOp.getResults();
+  OperandRange innerResults = forOp.getBody()->getTerminator()->getOperands();
+  for (auto [outer, inner] : llvm::zip(outerResults, innerResults))
+    outer.replaceAllUsesWith(inner);
+}
+
+LogicalResult AffineForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
+  auto forOp = cast<AffineForOp>(getOperation());
+  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  if (!tripCount || *tripCount != 1)
+    return failure();
+
+  // TODO: extend this for arbitrary affine bounds.
+  if (forOp.getLowerBoundMap().getNumResults() != 1)
+    return failure();
+
+  // Replaces all IV uses to its single iteration value.
+  BlockArgument iv = forOp.getInductionVar();
+  if (!iv.use_empty()) {
+    if (forOp.hasConstantLowerBound()) {
+      Operation *parentOp = forOp.getOperation();
+      while (isa<AffineForOp>(parentOp->getParentOp()))
+        parentOp = parentOp->getParentOp();
+      Block *parentBlock = parentOp->getBlock();
+      OpBuilder topBuilder(parentBlock, parentBlock->begin());
+      auto constOp = topBuilder.create<arith::ConstantIndexOp>(
+          forOp.getLoc(), forOp.getConstantLowerBound());
+      iv.replaceAllUsesWith(constOp);
+    } else {
+      OperandRange lbOperands = forOp.getLowerBoundOperands();
+      AffineMap lbMap = forOp.getLowerBoundMap();
+      OpBuilder builder(forOp);
+      if (lbMap == builder.getDimIdentityMap()) {
+        // No need of generating an affine.apply.
+        iv.replaceAllUsesWith(lbOperands[0]);
+      } else {
+        auto affineApplyOp =
+            builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
+        iv.replaceAllUsesWith(affineApplyOp);
+      }
+    }
+  }
+
+  replaceIterArgsAndYieldResults(forOp);
+
+  // Move the loop body operations, except for its terminator, to the loop's
+  // containing block.
+  forOp.getBody()->back().erase();
+  Block *parentBlock = forOp->getBlock();
+  parentBlock->getOperations().splice(Block::iterator(forOp),
+                                      forOp.getBody()->getOperations());
+  forOp.erase();
+  return success();
+}
+
 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
     RewriterBase &rewriter, ValueRange newInitOperands,
     bool replaceInitOperandUsesInLoop,
@@ -2538,6 +2603,79 @@ AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
   return nullptr;
 }
 
+/// Returns the trip count of the loop as an affine expression if the latter is
+/// expressible as an affine expression, and nullptr otherwise. The trip count
+/// expression is simplified before returning. This method only utilizes map
+/// composition to construct lower and upper bounds before computing the trip
+/// count expressions.
+void mlir::affine::getTripCountMapAndOperands(
+    AffineForOp forOp, AffineMap *tripCountMap,
+    SmallVectorImpl<Value> *tripCountOperands) {
+  MLIRContext *context = forOp.getContext();
+  int64_t step = forOp.getStepAsInt();
+  int64_t loopSpan;
+  if (forOp.hasConstantBounds()) {
+    int64_t lb = forOp.getConstantLowerBound();
+    int64_t ub = forOp.getConstantUpperBound();
+    loopSpan = ub - lb;
+    if (loopSpan < 0)
+      loopSpan = 0;
+    *tripCountMap = AffineMap::getConstantMap(ceilDiv(loopSpan, step), context);
+    tripCountOperands->clear();
+    return;
+  }
+  auto lbMap = forOp.getLowerBoundMap();
+  auto ubMap = forOp.getUpperBoundMap();
+  if (lbMap.getNumResults() != 1) {
+    *tripCountMap = AffineMap();
+    return;
+  }
+
+  // Difference of each upper bound expression from the single lower bound
+  // expression (divided by the step) provides the expressions for the trip
+  // count map.
+  AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());
+
+  SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
+                                         lbMap.getResult(0));
+  auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
+                                   lbSplatExpr, context);
+  AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());
+
+  AffineValueMap tripCountValueMap;
+  AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
+  for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
+    tripCountValueMap.setResult(i,
+                                tripCountValueMap.getResult(i).ceilDiv(step));
+
+  *tripCountMap = tripCountValueMap.getAffineMap();
+  tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
+                            tripCountValueMap.getOperands().end());
+}
+
+std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
+  SmallVector<Value, 4> operands;
+  AffineMap map;
+  getTripCountMapAndOperands(forOp, &map, &operands);
+
+  if (!map)
+    return std::nullopt;
+
+  // Take the min if all trip counts are constant.
+  std::optional<uint64_t> tripCount;
+  for (auto resultExpr : map.getResults()) {
+    if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
+      if (tripCount.has_value())
+        tripCount =
+            std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
+      else
+        tripCount = constExpr.getValue();
+    } else
+      return std::nullopt;
+  }
+  return tripCount;
+}
+
 /// Extracts the induction variables from a list of AffineForOps and returns
 /// them.
 void mlir::affine::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
@@ -2905,8 +3043,7 @@ static void composeSetAndOperands(IntegerSet &set,
 }
 
 /// Canonicalize an affine if op's conditional (integer set + operands).
-LogicalResult AffineIfOp::fold(FoldAdaptor,
-                               SmallVectorImpl<OpFoldResult> &) {
+LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
   auto set = getIntegerSet();
   SmallVector<Value, 4> operands(getOperands());
   composeSetAndOperands(set, operands);
@@ -2997,11 +3134,11 @@ static LogicalResult
 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
                        Operation::operand_range mapOperands,
                        MemRefType memrefType, unsigned numIndexOperands) {
-    AffineMap map = mapAttr.getValue();
-    if (map.getNumResults() != memrefType.getRank())
-      return op->emitOpError("affine map num results must equal memref rank");
-    if (map.getNumInputs() != numIndexOperands)
-      return op->emitOpError("expects as many subscripts as affine map inputs");
+  AffineMap map = mapAttr.getValue();
+  if (map.getNumResults() != memrefType.getRank())
+    return op->emitOpError("affine map num results must equal memref rank");
+  if (map.getNumInputs() != numIndexOperands)
+    return op->emitOpError("expects as many subscripts as affine map inputs");
 
   Region *scope = getAffineScope(op);
   for (auto idx : mapOperands) {
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 331b0f1b2c2b1c6..31b90a60472c1f1 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -219,13 +219,14 @@ void AffineDataCopyGeneration::runOnOperation() {
 
   // Promote any single iteration loops in the copy nests and collect
   // load/stores to simplify.
+  IRRewriter rewriter(f.getContext());
   SmallVector<Operation *, 4> copyOps;
   for (Operation *nest : copyNests)
     // With a post order walk, the erasure of loops does not affect
     // continuation of the walk or the collection of load/store ops.
     nest->walk([&](Operation *op) {
       if (auto forOp = dyn_cast<AffineForOp>(op))
-        (void)promoteIfSingleIteration(forOp);
+        (void)forOp.promoteIfSingleIteration(rewriter);
       else if (isa<AffineLoadOp, AffineStoreOp>(op))
         copyOps.push_back(op);
     });
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index 5053b08ee0834cd..d11e77544e24ea5 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -457,6 +457,7 @@ void m...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/72805


More information about the Mlir-commits mailing list