[Mlir-commits] [mlir] aa6be2f - [mlir][affine] implement `promoteIfSingleIteration` for `AffineForOp` (#72547)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Nov 19 10:27:46 PST 2023


Author: Maksim Levental
Date: 2023-11-19T12:27:41-06:00
New Revision: aa6be2f7c94ea3302fcc1ab034a85cd375eaa800

URL: https://github.com/llvm/llvm-project/commit/aa6be2f7c94ea3302fcc1ab034a85cd375eaa800
DIFF: https://github.com/llvm/llvm-project/commit/aa6be2f7c94ea3302fcc1ab034a85cd375eaa800.diff

LOG: [mlir][affine] implement `promoteIfSingleIteration` for `AffineForOp` (#72547)

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
    mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
    mlir/include/mlir/Dialect/Affine/LoopUtils.h
    mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
    mlir/lib/Dialect/Affine/IR/AffineOps.cpp
    mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
    mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
    mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
    mlir/lib/Dialect/Affine/Utils/Utils.cpp
    mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp

Removed: 
    


################################################################################
diff --git a/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h b/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
index 92f3d5a2c4925b1..c629c3a1c562322 100644
--- a/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
+++ b/mlir/include/mlir/Dialect/Affine/Analysis/LoopAnalysis.h
@@ -15,6 +15,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/ArrayRef.h"
+
 #include <optional>
 
 namespace mlir {
@@ -29,20 +30,6 @@ namespace affine {
 class AffineForOp;
 class NestedPattern;
 
-/// Returns the trip count of the loop as an affine map with its corresponding
-/// operands if the latter is expressible as an affine expression, and nullptr
-/// otherwise. This method always succeeds as long as the lower bound is not a
-/// multi-result map. The trip count expression is simplified before returning.
-/// This method only utilizes map composition to construct lower and upper
-/// bounds before computing the trip count expressions
-void getTripCountMapAndOperands(AffineForOp forOp, AffineMap *map,
-                                SmallVectorImpl<Value> *operands);
-
-/// Returns the trip count of the loop if it's a constant, std::nullopt
-/// otherwise. This uses affine expression analysis and is able to determine
-/// constant trip count in non-trivial cases.
-std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);
-
 /// Returns the greatest known integral divisor of the trip count. Affine
 /// expression analysis is used (indirectly through getTripCount), and
 /// this method is thus able to determine non-trivial divisors.

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index f070d0488619063..f763cf339159a50 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -117,7 +117,8 @@ class AffineDmaStartOp
   /// Returns the affine map used to access the source memref.
   AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
   AffineMapAttr getSrcMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getSrcMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getSrcMapAttrStrName()));
   }
 
   /// Returns the source memref affine map indices for this DMA operation.
@@ -156,7 +157,8 @@ class AffineDmaStartOp
   /// Returns the affine map used to access the destination memref.
   AffineMap getDstMap() { return getDstMapAttr().getValue(); }
   AffineMapAttr getDstMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getDstMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getDstMapAttrStrName()));
   }
 
   /// Returns the destination memref indices for this DMA operation.
@@ -185,7 +187,8 @@ class AffineDmaStartOp
   /// Returns the affine map used to access the tag memref.
   AffineMap getTagMap() { return getTagMapAttr().getValue(); }
   AffineMapAttr getTagMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getTagMapAttrStrName()));
   }
 
   /// Returns the tag memref indices for this DMA operation.
@@ -307,7 +310,8 @@ class AffineDmaWaitOp
   /// Returns the affine map used to access the tag memref.
   AffineMap getTagMap() { return getTagMapAttr().getValue(); }
   AffineMapAttr getTagMapAttr() {
-    return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
+    return cast<AffineMapAttr>(
+        *(*this)->getInherentAttr(getTagMapAttrStrName()));
   }
 
   /// Returns the tag memref index for this DMA operation.
@@ -465,6 +469,23 @@ AffineForOp getForInductionVarOwner(Value val);
 /// AffineParallelOp.
 AffineParallelOp getAffineParallelInductionVarOwner(Value val);
 
+/// Helper to replace uses of loop carried values (iter_args) and loop
+/// yield values while promoting single iteration affine.for ops.
+void replaceIterArgsAndYieldResults(AffineForOp forOp);
+
+/// Returns the trip count of the loop as an affine expression if the latter is
+/// expressible as an affine expression, and nullptr otherwise. The trip count
+/// expression is simplified before returning. This method only utilizes map
+/// composition to construct lower and upper bounds before computing the trip
+/// count expressions.
+void getTripCountMapAndOperands(AffineForOp forOp, AffineMap *tripCountMap,
+                                SmallVectorImpl<Value> *tripCountOperands);
+
+/// Returns the trip count of the loop if it's a constant, std::nullopt
+/// otherwise. This uses affine expression analysis and is able to determine
+/// constant trip count in non-trivial cases.
+std::optional<uint64_t> getConstantTripCount(AffineForOp forOp);
+
 /// Extracts the induction variables from a list of AffineForOps and places them
 /// in the output argument `ivs`.
 void extractForInductionVars(ArrayRef<AffineForOp> forInsts,

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index f9578cf37d5d768..b4ea6122ed4c0e0 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -121,7 +121,7 @@ def AffineForOp : Affine_Op<"for",
      ImplicitAffineTerminator, ConditionallySpeculatable,
      RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
      ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
-      "getSingleUpperBound", "getYieldedValuesMutable",
+      "getSingleUpperBound", "getYieldedValuesMutable", "promoteIfSingleIteration",
       "replaceWithAdditionalYields"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
      ["getEntrySuccessorOperands"]>]> {

diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
index 723a262f24acc51..1e3b3bffea7b838 100644
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -83,10 +83,6 @@ LogicalResult loopUnrollJamByFactor(AffineForOp forOp,
 LogicalResult loopUnrollJamUpToFactor(AffineForOp forOp,
                                       uint64_t unrollJamFactor);
 
-/// Promotes the loop body of a AffineForOp to its containing block if the loop
-/// was known to have a single iteration.
-LogicalResult promoteIfSingleIteration(AffineForOp forOp);
-
 /// Promotes all single iteration AffineForOp's in the Function, i.e., moves
 /// their body into the containing Block.
 void promoteSingleIterationLoops(func::FuncOp f);

diff --git a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
index e645afe7cd3e8fa..24f119464b416a7 100644
--- a/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/LoopAnalysis.cpp
@@ -12,7 +12,6 @@
 
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 
-#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
@@ -20,9 +19,9 @@
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Support/MathExtras.h"
 
-#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallString.h"
+
 #include <numeric>
 #include <optional>
 #include <type_traits>
@@ -30,83 +29,6 @@
 using namespace mlir;
 using namespace mlir::affine;
 
-/// Returns the trip count of the loop as an affine expression if the latter is
-/// expressible as an affine expression, and nullptr otherwise. The trip count
-/// expression is simplified before returning. This method only utilizes map
-/// composition to construct lower and upper bounds before computing the trip
-/// count expressions.
-void mlir::affine::getTripCountMapAndOperands(
-    AffineForOp forOp, AffineMap *tripCountMap,
-    SmallVectorImpl<Value> *tripCountOperands) {
-  MLIRContext *context = forOp.getContext();
-  int64_t step = forOp.getStepAsInt();
-  int64_t loopSpan;
-  if (forOp.hasConstantBounds()) {
-    int64_t lb = forOp.getConstantLowerBound();
-    int64_t ub = forOp.getConstantUpperBound();
-    loopSpan = ub - lb;
-    if (loopSpan < 0)
-      loopSpan = 0;
-    *tripCountMap = AffineMap::getConstantMap(ceilDiv(loopSpan, step), context);
-    tripCountOperands->clear();
-    return;
-  }
-  auto lbMap = forOp.getLowerBoundMap();
-  auto ubMap = forOp.getUpperBoundMap();
-  if (lbMap.getNumResults() != 1) {
-    *tripCountMap = AffineMap();
-    return;
-  }
-
-  // Difference of each upper bound expression from the single lower bound
-  // expression (divided by the step) provides the expressions for the trip
-  // count map.
-  AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());
-
-  SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
-                                         lbMap.getResult(0));
-  auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
-                                   lbSplatExpr, context);
-  AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());
-
-  AffineValueMap tripCountValueMap;
-  AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
-  for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
-    tripCountValueMap.setResult(i,
-                                tripCountValueMap.getResult(i).ceilDiv(step));
-
-  *tripCountMap = tripCountValueMap.getAffineMap();
-  tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
-                            tripCountValueMap.getOperands().end());
-}
-
-/// Returns the trip count of the loop if it's a constant, std::nullopt
-/// otherwise. This method uses affine expression analysis (in turn using
-/// getTripCount) and is able to determine constant trip count in non-trivial
-/// cases.
-std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
-  SmallVector<Value, 4> operands;
-  AffineMap map;
-  getTripCountMapAndOperands(forOp, &map, &operands);
-
-  if (!map)
-    return std::nullopt;
-
-  // Take the min if all trip counts are constant.
-  std::optional<uint64_t> tripCount;
-  for (auto resultExpr : map.getResults()) {
-    if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
-      if (tripCount.has_value())
-        tripCount =
-            std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
-      else
-        tripCount = constExpr.getValue();
-    } else
-      return std::nullopt;
-  }
-  return tripCount;
-}
-
 /// Returns the greatest known integral divisor of the trip count. Affine
 /// expression analysis is used (indirectly through getTripCount), and
 /// this method is thus able to determine non-trivial divisors.

diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index d22a7539fb75018..b0e0aade3fa5481 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -7,7 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/AffineExprVisitor.h"
@@ -2448,6 +2450,65 @@ std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
   return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
 }
 
+void mlir::affine::replaceIterArgsAndYieldResults(AffineForOp forOp) {
+  // Replace uses of iter arguments with iter operands (initial values).
+  OperandRange iterOperands = forOp.getInits();
+  MutableArrayRef<BlockArgument> iterArgs = forOp.getRegionIterArgs();
+  for (auto [operand, arg] : llvm::zip(iterOperands, iterArgs))
+    arg.replaceAllUsesWith(operand);
+
+  // Replace uses of loop results with the values yielded by the loop.
+  ResultRange outerResults = forOp.getResults();
+  OperandRange innerResults = forOp.getBody()->getTerminator()->getOperands();
+  for (auto [outer, inner] : llvm::zip(outerResults, innerResults))
+    outer.replaceAllUsesWith(inner);
+}
+
+LogicalResult AffineForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
+  auto forOp = cast<AffineForOp>(getOperation());
+  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+  if (!tripCount || *tripCount != 1)
+    return failure();
+
+  // TODO: extend this for arbitrary affine bounds.
+  if (forOp.getLowerBoundMap().getNumResults() != 1)
+    return failure();
+
+  // Replaces all IV uses to its single iteration value.
+  BlockArgument iv = forOp.getInductionVar();
+  Block *parentBlock = forOp->getBlock();
+  if (!iv.use_empty()) {
+    if (forOp.hasConstantLowerBound()) {
+      OpBuilder topBuilder(forOp->getParentOfType<func::FuncOp>().getBody());
+      auto constOp = topBuilder.create<arith::ConstantIndexOp>(
+          forOp.getLoc(), forOp.getConstantLowerBound());
+      iv.replaceAllUsesWith(constOp);
+    } else {
+      OperandRange lbOperands = forOp.getLowerBoundOperands();
+      AffineMap lbMap = forOp.getLowerBoundMap();
+      OpBuilder builder(forOp);
+      if (lbMap == builder.getDimIdentityMap()) {
+        // No need of generating an affine.apply.
+        iv.replaceAllUsesWith(lbOperands[0]);
+      } else {
+        auto affineApplyOp =
+            builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
+        iv.replaceAllUsesWith(affineApplyOp);
+      }
+    }
+  }
+
+  replaceIterArgsAndYieldResults(forOp);
+
+  // Move the loop body operations, except for its terminator, to the loop's
+  // containing block.
+  forOp.getBody()->back().erase();
+  parentBlock->getOperations().splice(Block::iterator(forOp),
+                                      forOp.getBody()->getOperations());
+  forOp.erase();
+  return success();
+}
+
 FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
     RewriterBase &rewriter, ValueRange newInitOperands,
     bool replaceInitOperandUsesInLoop,
@@ -2546,6 +2607,79 @@ AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) {
   return nullptr;
 }
 
+/// Returns the trip count of the loop as an affine expression if the latter is
+/// expressible as an affine expression, and nullptr otherwise. The trip count
+/// expression is simplified before returning. This method only utilizes map
+/// composition to construct lower and upper bounds before computing the trip
+/// count expressions.
+void mlir::affine::getTripCountMapAndOperands(
+    AffineForOp forOp, AffineMap *tripCountMap,
+    SmallVectorImpl<Value> *tripCountOperands) {
+  MLIRContext *context = forOp.getContext();
+  int64_t step = forOp.getStepAsInt();
+  int64_t loopSpan;
+  if (forOp.hasConstantBounds()) {
+    int64_t lb = forOp.getConstantLowerBound();
+    int64_t ub = forOp.getConstantUpperBound();
+    loopSpan = ub - lb;
+    if (loopSpan < 0)
+      loopSpan = 0;
+    *tripCountMap = AffineMap::getConstantMap(ceilDiv(loopSpan, step), context);
+    tripCountOperands->clear();
+    return;
+  }
+  auto lbMap = forOp.getLowerBoundMap();
+  auto ubMap = forOp.getUpperBoundMap();
+  if (lbMap.getNumResults() != 1) {
+    *tripCountMap = AffineMap();
+    return;
+  }
+
+  // Difference of each upper bound expression from the single lower bound
+  // expression (divided by the step) provides the expressions for the trip
+  // count map.
+  AffineValueMap ubValueMap(ubMap, forOp.getUpperBoundOperands());
+
+  SmallVector<AffineExpr, 4> lbSplatExpr(ubValueMap.getNumResults(),
+                                         lbMap.getResult(0));
+  auto lbMapSplat = AffineMap::get(lbMap.getNumDims(), lbMap.getNumSymbols(),
+                                   lbSplatExpr, context);
+  AffineValueMap lbSplatValueMap(lbMapSplat, forOp.getLowerBoundOperands());
+
+  AffineValueMap tripCountValueMap;
+  AffineValueMap::difference(ubValueMap, lbSplatValueMap, &tripCountValueMap);
+  for (unsigned i = 0, e = tripCountValueMap.getNumResults(); i < e; ++i)
+    tripCountValueMap.setResult(i,
+                                tripCountValueMap.getResult(i).ceilDiv(step));
+
+  *tripCountMap = tripCountValueMap.getAffineMap();
+  tripCountOperands->assign(tripCountValueMap.getOperands().begin(),
+                            tripCountValueMap.getOperands().end());
+}
+
+std::optional<uint64_t> mlir::affine::getConstantTripCount(AffineForOp forOp) {
+  SmallVector<Value, 4> operands;
+  AffineMap map;
+  getTripCountMapAndOperands(forOp, &map, &operands);
+
+  if (!map)
+    return std::nullopt;
+
+  // Take the min if all trip counts are constant.
+  std::optional<uint64_t> tripCount;
+  for (auto resultExpr : map.getResults()) {
+    if (auto constExpr = dyn_cast<AffineConstantExpr>(resultExpr)) {
+      if (tripCount.has_value())
+        tripCount =
+            std::min(*tripCount, static_cast<uint64_t>(constExpr.getValue()));
+      else
+        tripCount = constExpr.getValue();
+    } else
+      return std::nullopt;
+  }
+  return tripCount;
+}
+
 /// Extracts the induction variables from a list of AffineForOps and returns
 /// them.
 void mlir::affine::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
@@ -2913,8 +3047,7 @@ static void composeSetAndOperands(IntegerSet &set,
 }
 
 /// Canonicalize an affine if op's conditional (integer set + operands).
-LogicalResult AffineIfOp::fold(FoldAdaptor,
-                               SmallVectorImpl<OpFoldResult> &) {
+LogicalResult AffineIfOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
   auto set = getIntegerSet();
   SmallVector<Value, 4> operands(getOperands());
   composeSetAndOperands(set, operands);
@@ -3005,11 +3138,11 @@ static LogicalResult
 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
                        Operation::operand_range mapOperands,
                        MemRefType memrefType, unsigned numIndexOperands) {
-    AffineMap map = mapAttr.getValue();
-    if (map.getNumResults() != memrefType.getRank())
-      return op->emitOpError("affine map num results must equal memref rank");
-    if (map.getNumInputs() != numIndexOperands)
-      return op->emitOpError("expects as many subscripts as affine map inputs");
+  AffineMap map = mapAttr.getValue();
+  if (map.getNumResults() != memrefType.getRank())
+    return op->emitOpError("affine map num results must equal memref rank");
+  if (map.getNumInputs() != numIndexOperands)
+    return op->emitOpError("expects as many subscripts as affine map inputs");
 
   Region *scope = getAffineScope(op);
   for (auto idx : mapOperands) {

diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 331b0f1b2c2b1c6..31b90a60472c1f1 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -219,13 +219,14 @@ void AffineDataCopyGeneration::runOnOperation() {
 
   // Promote any single iteration loops in the copy nests and collect
   // load/stores to simplify.
+  IRRewriter rewriter(f.getContext());
   SmallVector<Operation *, 4> copyOps;
   for (Operation *nest : copyNests)
     // With a post order walk, the erasure of loops does not affect
     // continuation of the walk or the collection of load/store ops.
     nest->walk([&](Operation *op) {
       if (auto forOp = dyn_cast<AffineForOp>(op))
-        (void)promoteIfSingleIteration(forOp);
+        (void)forOp.promoteIfSingleIteration(rewriter);
       else if (isa<AffineLoadOp, AffineStoreOp>(op))
         copyOps.push_back(op);
     });

diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index 77ee1d4c7a02d97..c4cf32fcb697b51 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -456,6 +456,7 @@ void mlir::affine::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
     return (buildSliceTripCountMap(srcSlice, &sliceTripCountMap) &&
             (getSliceIterationCount(sliceTripCountMap) == 1));
   };
+  IRRewriter rewriter(srcForOp.getContext());
   // Fix up and if possible, eliminate single iteration loops.
   for (AffineForOp forOp : sliceLoops) {
     if (isLoopParallelAndContainsReduction(forOp) &&
@@ -463,9 +464,8 @@ void mlir::affine::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
       // Patch reduction loop - only ones that are sibling-fused with the
       // destination loop - into the parent loop.
       (void)promoteSingleIterReductionLoop(forOp, true);
-    else
-      // Promote any single iteration slice loops.
-      (void)promoteIfSingleIteration(forOp);
+    else // Promote any single iteration slice loops.
+      (void)forOp.promoteIfSingleIteration(rewriter);
   }
 }
 

diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 3794ef2dabe1e0a..2c27c724007e799 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -110,68 +110,6 @@ getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
     lb.erase();
 }
 
-/// Helper to replace uses of loop carried values (iter_args) and loop
-/// yield values while promoting single iteration affine.for ops.
-static void replaceIterArgsAndYieldResults(AffineForOp forOp) {
-  // Replace uses of iter arguments with iter operands (initial values).
-  auto iterOperands = forOp.getInits();
-  auto iterArgs = forOp.getRegionIterArgs();
-  for (auto e : llvm::zip(iterOperands, iterArgs))
-    std::get<1>(e).replaceAllUsesWith(std::get<0>(e));
-
-  // Replace uses of loop results with the values yielded by the loop.
-  auto outerResults = forOp.getResults();
-  auto innerResults = forOp.getBody()->getTerminator()->getOperands();
-  for (auto e : llvm::zip(outerResults, innerResults))
-    std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
-}
-
-/// Promotes the loop body of a forOp to its containing block if the forOp
-/// was known to have a single iteration.
-LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) {
-  std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
-  if (!tripCount || *tripCount != 1)
-    return failure();
-
-  // TODO: extend this for arbitrary affine bounds.
-  if (forOp.getLowerBoundMap().getNumResults() != 1)
-    return failure();
-
-  // Replaces all IV uses to its single iteration value.
-  auto iv = forOp.getInductionVar();
-  auto *parentBlock = forOp->getBlock();
-  if (!iv.use_empty()) {
-    if (forOp.hasConstantLowerBound()) {
-      OpBuilder topBuilder(forOp->getParentOfType<func::FuncOp>().getBody());
-      auto constOp = topBuilder.create<arith::ConstantIndexOp>(
-          forOp.getLoc(), forOp.getConstantLowerBound());
-      iv.replaceAllUsesWith(constOp);
-    } else {
-      auto lbOperands = forOp.getLowerBoundOperands();
-      auto lbMap = forOp.getLowerBoundMap();
-      OpBuilder builder(forOp);
-      if (lbMap == builder.getDimIdentityMap()) {
-        // No need of generating an affine.apply.
-        iv.replaceAllUsesWith(lbOperands[0]);
-      } else {
-        auto affineApplyOp =
-            builder.create<AffineApplyOp>(forOp.getLoc(), lbMap, lbOperands);
-        iv.replaceAllUsesWith(affineApplyOp);
-      }
-    }
-  }
-
-  replaceIterArgsAndYieldResults(forOp);
-
-  // Move the loop body operations, except for its terminator, to the loop's
-  // containing block.
-  forOp.getBody()->back().erase();
-  parentBlock->getOperations().splice(Block::iterator(forOp),
-                                      forOp.getBody()->getOperations());
-  forOp.erase();
-  return success();
-}
-
 /// Generates an affine.for op with the specified lower and upper bounds
 /// while generating the right IV remappings to realize shifts for operations in
 /// its body. The operations that go into the loop body are specified in
@@ -218,7 +156,9 @@ static AffineForOp generateShiftedLoop(
     for (auto *op : ops)
       bodyBuilder.clone(*op, operandMap);
   };
-  if (succeeded(promoteIfSingleIteration(loopChunk)))
+
+  IRRewriter rewriter(loopChunk.getContext());
+  if (succeeded(loopChunk.promoteIfSingleIteration(rewriter)))
     return AffineForOp();
   return loopChunk;
 }
@@ -892,12 +832,13 @@ void mlir::affine::getTileableBands(
 /// Unrolls this loop completely.
 LogicalResult mlir::affine::loopUnrollFull(AffineForOp forOp) {
   std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
+  IRRewriter rewriter(forOp.getContext());
   if (mayBeConstantTripCount.has_value()) {
     uint64_t tripCount = *mayBeConstantTripCount;
     if (tripCount == 0)
       return success();
     if (tripCount == 1)
-      return promoteIfSingleIteration(forOp);
+      return forOp.promoteIfSingleIteration(rewriter);
     return loopUnrollByFactor(forOp, tripCount);
   }
   return failure();
@@ -1003,7 +944,8 @@ static LogicalResult generateCleanupLoopForUnroll(AffineForOp forOp,
 
   cleanupForOp.setLowerBound(cleanupOperands, cleanupMap);
   // Promote the loop body up if this has turned into a single iteration loop.
-  (void)promoteIfSingleIteration(cleanupForOp);
+  IRRewriter rewriter(cleanupForOp.getContext());
+  (void)cleanupForOp.promoteIfSingleIteration(rewriter);
 
   // Adjust upper bound of the original loop; this is the same as the lower
   // bound of the cleanup loop.
@@ -1019,10 +961,11 @@ LogicalResult mlir::affine::loopUnrollByFactor(
     bool cleanUpUnroll) {
   assert(unrollFactor > 0 && "unroll factor should be positive");
 
+  IRRewriter rewriter(forOp.getContext());
   std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
   if (unrollFactor == 1) {
     if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
-        failed(promoteIfSingleIteration(forOp)))
+        failed(forOp.promoteIfSingleIteration(rewriter)))
       return failure();
     return success();
   }
@@ -1076,7 +1019,7 @@ LogicalResult mlir::affine::loopUnrollByFactor(
       /*iterArgs=*/iterArgs, /*yieldedValues=*/yieldedValues);
 
   // Promote the loop body up if this has turned into a single iteration loop.
-  (void)promoteIfSingleIteration(forOp);
+  (void)forOp.promoteIfSingleIteration(rewriter);
   return success();
 }
 
@@ -1135,10 +1078,11 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
                                                   uint64_t unrollJamFactor) {
   assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
 
+  IRRewriter rewriter(forOp.getContext());
   std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
   if (unrollJamFactor == 1) {
     if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
-        failed(promoteIfSingleIteration(forOp)))
+        failed(forOp.promoteIfSingleIteration(rewriter)))
       return failure();
     return success();
   }
@@ -1198,7 +1142,6 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
   // `unrollJamFactor` copies of its iterOperands, iter_args and yield
   // operands.
   SmallVector<AffineForOp, 4> newLoopsWithIterArgs;
-  IRRewriter rewriter(forOp.getContext());
   for (AffineForOp oldForOp : loopsWithIterArgs) {
     SmallVector<Value> dupIterOperands, dupYieldOperands;
     ValueRange oldIterOperands = oldForOp.getInits();
@@ -1321,7 +1264,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
   }
 
   // Promote the loop body up if this has turned into a single iteration loop.
-  (void)promoteIfSingleIteration(forOp);
+  (void)forOp.promoteIfSingleIteration(rewriter);
   return success();
 }
 

diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 50a052fb8b74e70..90abea1bb776bf9 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -552,7 +552,8 @@ void mlir::affine::normalizeAffineParallel(AffineParallelOp op) {
 
 LogicalResult mlir::affine::normalizeAffineFor(AffineForOp op,
                                                bool promoteSingleIter) {
-  if (promoteSingleIter && succeeded(promoteIfSingleIteration(op)))
+  IRRewriter rewriter(op.getContext());
+  if (promoteSingleIter && succeeded(op.promoteIfSingleIteration(rewriter)))
     return success();
 
   // Check if the forop is already normalized.

diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
index b418a457473a8ec..e7913a1cb4eb4db 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -107,13 +107,14 @@ void TestAffineDataCopy::runOnOperation() {
 
   // Promote any single iteration loops in the copy nests and simplify
   // load/stores.
+  IRRewriter rewriter(&getContext());
   SmallVector<Operation *, 4> copyOps;
   for (Operation *nest : copyNests) {
     // With a post order walk, the erasure of loops does not affect
     // continuation of the walk or the collection of load/store ops.
     nest->walk([&](Operation *op) {
       if (auto forOp = dyn_cast<AffineForOp>(op))
-        (void)promoteIfSingleIteration(forOp);
+        (void)forOp.promoteIfSingleIteration(rewriter);
       else if (auto loadOp = dyn_cast<AffineLoadOp>(op))
         copyOps.push_back(loadOp);
       else if (auto storeOp = dyn_cast<AffineStoreOp>(op))


        


More information about the Mlir-commits mailing list