[Mlir-commits] [mlir] [MLIR][NFC] Stop depending on func.func in affine LoopUtils via &Region (PR #83325)

Alexey Z. llvmlistbot at llvm.org
Wed Feb 28 12:06:52 PST 2024


https://github.com/last5bits created https://github.com/llvm/llvm-project/pull/83325

Note: this is an experiment illustrating an alternative approach to PR 82079.

Instead of having the affine LoopUtils routines look up the enclosing func.func, pass them a reference to the outermost region, which makes it possible to use these routines in downstream dialects that have their own function-like ops.

>From 25ca48c92799cfe59612ece465fd79a7b22e16e2 Mon Sep 17 00:00:00 2001
From: Alexey Zhikhartsev <alexey.zhikhar at gmail.com>
Date: Wed, 28 Feb 2024 11:09:08 -0500
Subject: [PATCH] [MLIR][NFC] Stop depending on func.func in affine LoopUtils
 via &Region

Instead of looking up the enclosing func.func, pass a reference to the
outermost region, which makes it possible to use affine LoopUtils
routines in downstream dialects that have their own function-like ops.
---
 mlir/include/mlir/Dialect/Affine/LoopUtils.h  |  30 +++--
 mlir/include/mlir/Dialect/Affine/Utils.h      |   2 +-
 .../Transforms/AffineDataCopyGeneration.cpp   |  15 ++-
 .../Affine/Transforms/AffineLoopNormalize.cpp |   3 +-
 .../Dialect/Affine/Transforms/LoopTiling.cpp  |   2 +-
 .../Dialect/Affine/Transforms/LoopUnroll.cpp  |  14 +-
 .../Affine/Transforms/LoopUnrollAndJam.cpp    |   3 +-
 .../Transforms/PipelineDataTransfer.cpp       |   3 +-
 .../Dialect/Affine/Utils/LoopFusionUtils.cpp  |   7 +-
 mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp   | 124 ++++++++++--------
 mlir/lib/Dialect/Affine/Utils/Utils.cpp       |   5 +-
 .../SCF/TransformOps/SCFTransformOps.cpp      |   9 +-
 .../lib/Dialect/Affine/TestAffineDataCopy.cpp |  21 +--
 .../Affine/TestAffineLoopParametricTiling.cpp |   2 +-
 .../lib/Dialect/Affine/TestLoopFusion.cpp     |   5 +-
 .../Dialect/Affine/TestVectorizationUtils.cpp |   3 +-
 16 files changed, 141 insertions(+), 107 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/LoopUtils.h b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
index 723a262f24acc5..445290bce76684 100644
--- a/mlir/include/mlir/Dialect/Affine/LoopUtils.h
+++ b/mlir/include/mlir/Dialect/Affine/LoopUtils.h
@@ -43,7 +43,7 @@ struct MemRefRegion;
 
 /// Unrolls this for operation completely if the trip count is known to be
 /// constant. Returns failure otherwise.
-LogicalResult loopUnrollFull(AffineForOp forOp);
+LogicalResult loopUnrollFull(Region &topRegion, AffineForOp forOp);
 
 /// Unrolls this for operation by the specified unroll factor. Returns failure
 /// if the loop cannot be unrolled either due to restrictions or due to invalid
@@ -52,13 +52,14 @@ LogicalResult loopUnrollFull(AffineForOp forOp);
 /// When `cleanUpUnroll` is true, we can ensure the cleanup loop is unrolled
 /// regardless of the unroll factor.
 LogicalResult loopUnrollByFactor(
-    AffineForOp forOp, uint64_t unrollFactor,
+    Region &topRegion, AffineForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn = nullptr,
     bool cleanUpUnroll = false);
 
 /// Unrolls this loop by the specified unroll factor or its trip count,
 /// whichever is lower.
-LogicalResult loopUnrollUpToFactor(AffineForOp forOp, uint64_t unrollFactor);
+LogicalResult loopUnrollUpToFactor(Region &topRegion, AffineForOp forOp,
+                                   uint64_t unrollFactor);
 
 /// Returns true if `loops` is a perfectly nested loop nest, where loops appear
 /// in it from outermost to innermost.
@@ -75,34 +76,35 @@ void getPerfectlyNestedLoops(SmallVectorImpl<AffineForOp> &nestedLoops,
 /// with iteration arguments performing supported reductions and its inner loops
 /// can have iteration arguments. Returns success if the loop is successfully
 /// unroll-jammed.
-LogicalResult loopUnrollJamByFactor(AffineForOp forOp,
+LogicalResult loopUnrollJamByFactor(Region &topRegion, AffineForOp forOp,
                                     uint64_t unrollJamFactor);
 
 /// Unrolls and jams this loop by the specified factor or by the trip count (if
 /// constant), whichever is lower.
-LogicalResult loopUnrollJamUpToFactor(AffineForOp forOp,
+LogicalResult loopUnrollJamUpToFactor(Region &topRegion, AffineForOp forOp,
                                       uint64_t unrollJamFactor);
 
 /// Promotes the loop body of a AffineForOp to its containing block if the loop
 /// was known to have a single iteration.
-LogicalResult promoteIfSingleIteration(AffineForOp forOp);
+LogicalResult promoteIfSingleIteration(Region &topRegion, AffineForOp forOp);
 
 /// Promotes all single iteration AffineForOp's in the Function, i.e., moves
 /// their body into the containing Block.
-void promoteSingleIterationLoops(func::FuncOp f);
+void promoteSingleIterationLoops(Region &region);
 
 /// Skew the operations in an affine.for's body with the specified
 /// operation-wise shifts. The shifts are with respect to the original execution
 /// order, and are multiplied by the loop 'step' before being applied. If
 /// `unrollPrologueEpilogue` is set, fully unroll the prologue and epilogue
 /// loops when possible.
-LogicalResult affineForOpBodySkew(AffineForOp forOp, ArrayRef<uint64_t> shifts,
+LogicalResult affineForOpBodySkew(Region &topRegion, AffineForOp forOp,
+                                  ArrayRef<uint64_t> shifts,
                                   bool unrollPrologueEpilogue = false);
 
 /// Identify valid and profitable bands of loops to tile. This is currently just
 /// a temporary placeholder to test the mechanics of tiled code generation.
 /// Returns all maximal outermost perfect loop nests to tile.
-void getTileableBands(func::FuncOp f,
+void getTileableBands(Region &region,
                       std::vector<SmallVector<AffineForOp, 6>> *bands);
 
 /// Tiles the specified band of perfectly nested loops creating tile-space loops
@@ -190,14 +192,15 @@ struct AffineCopyOptions {
 /// encountered. For memrefs for whose element types a size in bytes can't be
 /// computed (`index` type), their capacity is not accounted for and the
 /// `fastMemCapacityBytes` copy option would be non-functional in such cases.
-LogicalResult affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
+LogicalResult affineDataCopyGenerate(Region &topRegion, Block::iterator begin,
+                                     Block::iterator end,
                                      const AffineCopyOptions &copyOptions,
                                      std::optional<Value> filterMemRef,
                                      DenseSet<Operation *> &copyNests);
 
 /// A convenience version of affineDataCopyGenerate for all ops in the body of
 /// an AffineForOp.
-LogicalResult affineDataCopyGenerate(AffineForOp forOp,
+LogicalResult affineDataCopyGenerate(Region &topRegion, AffineForOp forOp,
                                      const AffineCopyOptions &copyOptions,
                                      std::optional<Value> filterMemRef,
                                      DenseSet<Operation *> &copyNests);
@@ -225,7 +228,8 @@ struct CopyGenerateResult {
 ///
 /// Also note that certain options in `copyOptions` aren't looked at anymore,
 /// like slowMemorySpace.
-LogicalResult generateCopyForMemRegion(const MemRefRegion &memrefRegion,
+LogicalResult generateCopyForMemRegion(Region &topRegion,
+                                       const MemRefRegion &memrefRegion,
                                        Operation *analyzedOp,
                                        const AffineCopyOptions &copyOptions,
                                        CopyGenerateResult &result);
@@ -273,7 +277,7 @@ void mapLoopToProcessorIds(scf::ForOp forOp, ArrayRef<Value> processorId,
                            ArrayRef<Value> numProcessors);
 
 /// Gathers all AffineForOps in 'func.func' grouped by loop depth.
-void gatherLoops(func::FuncOp func,
+void gatherLoops(Region &region,
                  std::vector<SmallVector<AffineForOp, 2>> &depthToLoops);
 
 /// Creates an AffineForOp while ensuring that the lower and upper bounds are
diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h
index 67c7a964feefd7..3a327e53d9f508 100644
--- a/mlir/include/mlir/Dialect/Affine/Utils.h
+++ b/mlir/include/mlir/Dialect/Affine/Utils.h
@@ -172,7 +172,7 @@ void normalizeAffineParallel(AffineParallelOp op);
 /// loop has been normalized (or is already in the normal form). If
 /// `promoteSingleIter` is true, the loop is simply promoted if it has a single
 /// iteration.
-LogicalResult normalizeAffineFor(AffineForOp op,
+LogicalResult normalizeAffineFor(Region &region, AffineForOp op,
                                  bool promoteSingleIter = false);
 
 /// Traverse `e` and return an AffineExpr where all occurrences of `dim` have
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
index 331b0f1b2c2b1c..17637d172c6cfa 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineDataCopyGeneration.cpp
@@ -116,6 +116,8 @@ void AffineDataCopyGeneration::runOnBlock(Block *block,
   AffineCopyOptions copyOptions = {generateDma, slowMemorySpace,
                                    fastMemorySpace, tagMemorySpace,
                                    fastMemCapacityBytes};
+  auto &topRegion =
+      block->getParent()->getParentOfType<func::FuncOp>().getBody();
 
   // Every affine.for op in the block starts and ends a block range for copying;
   // in addition, a contiguous sequence of operations starting with a
@@ -139,8 +141,9 @@ void AffineDataCopyGeneration::runOnBlock(Block *block,
     // If you hit a non-copy for loop, we will split there.
     if ((forOp = dyn_cast<AffineForOp>(&*it)) && copyNests.count(forOp) == 0) {
       // Perform the copying up unti this 'for' op first.
-      (void)affineDataCopyGenerate(/*begin=*/curBegin, /*end=*/it, copyOptions,
-                                   /*filterMemRef=*/std::nullopt, copyNests);
+      (void)affineDataCopyGenerate(topRegion, /*begin=*/curBegin, /*end=*/it,
+                                   copyOptions, /*filterMemRef=*/std::nullopt,
+                                   copyNests);
 
       // Returns true if the footprint is known to exceed capacity.
       auto exceedsCapacity = [&](AffineForOp forOp) {
@@ -172,8 +175,8 @@ void AffineDataCopyGeneration::runOnBlock(Block *block,
         // Inner loop copies have their own scope - we don't thus update
         // consumed capacity. The footprint check above guarantees this inner
         // loop's footprint fits.
-        (void)affineDataCopyGenerate(/*begin=*/it, /*end=*/std::next(it),
-                                     copyOptions,
+        (void)affineDataCopyGenerate(topRegion, /*begin=*/it,
+                                     /*end=*/std::next(it), copyOptions,
                                      /*filterMemRef=*/std::nullopt, copyNests);
       }
       // Get to the next load or store op after 'forOp'.
@@ -196,7 +199,7 @@ void AffineDataCopyGeneration::runOnBlock(Block *block,
     assert(!curBegin->hasTrait<OpTrait::IsTerminator>() &&
            "can't be a terminator");
     // Exclude the affine.yield - hence, the std::prev.
-    (void)affineDataCopyGenerate(/*begin=*/curBegin,
+    (void)affineDataCopyGenerate(topRegion, /*begin=*/curBegin,
                                  /*end=*/std::prev(block->end()), copyOptions,
                                  /*filterMemRef=*/std::nullopt, copyNests);
   }
@@ -225,7 +228,7 @@ void AffineDataCopyGeneration::runOnOperation() {
     // continuation of the walk or the collection of load/store ops.
     nest->walk([&](Operation *op) {
       if (auto forOp = dyn_cast<AffineForOp>(op))
-        (void)promoteIfSingleIteration(forOp);
+        (void)promoteIfSingleIteration(f.getBody(), forOp);
       else if (isa<AffineLoadOp, AffineStoreOp>(op))
         copyOps.push_back(op);
     });
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp
index 5cc38f70517261..6773027461f19d 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopNormalize.cpp
@@ -38,11 +38,12 @@ struct AffineLoopNormalizePass
   }
 
   void runOnOperation() override {
+    auto &topRegion = getOperation().getBody();
     getOperation().walk([&](Operation *op) {
       if (auto affineParallel = dyn_cast<AffineParallelOp>(op))
         normalizeAffineParallel(affineParallel);
       else if (auto affineFor = dyn_cast<AffineForOp>(op))
-        (void)normalizeAffineFor(affineFor, promoteSingleIter);
+        (void)normalizeAffineFor(topRegion, affineFor, promoteSingleIter);
     });
   }
 };
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
index 2650a06d198eab..022d327b197a40 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopTiling.cpp
@@ -238,7 +238,7 @@ void LoopTiling::getTileSizes(ArrayRef<AffineForOp> band,
 void LoopTiling::runOnOperation() {
   // Bands of loops to tile.
   std::vector<SmallVector<AffineForOp, 6>> bands;
-  getTileableBands(getOperation(), &bands);
+  getTileableBands(getOperation().getBody(), &bands);
 
   // Tile each band.
   for (auto &band : bands) {
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
index 57df7ada91654c..9bff4e4b53a902 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnroll.cpp
@@ -108,7 +108,7 @@ void LoopUnroll::runOnOperation() {
         loops.push_back(forOp);
     });
     for (auto forOp : loops)
-      (void)loopUnrollFull(forOp);
+      (void)loopUnrollFull(func.getBody(), forOp);
     return;
   }
 
@@ -131,18 +131,20 @@ void LoopUnroll::runOnOperation() {
 /// Unrolls a 'affine.for' op. Returns success if the loop was unrolled,
 /// failure otherwise. The default unroll factor is 4.
 LogicalResult LoopUnroll::runOnAffineForOp(AffineForOp forOp) {
+  auto &topRegion = forOp->getParentOfType<func::FuncOp>().getBody();
+
   // Use the function callback if one was provided.
   if (getUnrollFactor)
-    return loopUnrollByFactor(forOp, getUnrollFactor(forOp),
+    return loopUnrollByFactor(topRegion, forOp, getUnrollFactor(forOp),
                               /*annotateFn=*/nullptr, cleanUpUnroll);
   // Unroll completely if full loop unroll was specified.
   if (unrollFull)
-    return loopUnrollFull(forOp);
+    return loopUnrollFull(topRegion, forOp);
   // Otherwise, unroll by the given unroll factor.
   if (unrollUpToFactor)
-    return loopUnrollUpToFactor(forOp, unrollFactor);
-  return loopUnrollByFactor(forOp, unrollFactor, /*annotateFn=*/nullptr,
-                            cleanUpUnroll);
+    return loopUnrollUpToFactor(topRegion, forOp, unrollFactor);
+  return loopUnrollByFactor(topRegion, forOp, unrollFactor,
+                            /*annotateFn=*/nullptr, cleanUpUnroll);
 }
 
 std::unique_ptr<OperationPass<func::FuncOp>> mlir::affine::createLoopUnrollPass(
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp
index a79160df7575a3..a53299807ea831 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopUnrollAndJam.cpp
@@ -90,6 +90,7 @@ void LoopUnrollAndJam::runOnOperation() {
   // unroll-and-jammed by this pass. However, runOnAffineForOp can be called on
   // any for operation.
   auto &entryBlock = getOperation().front();
+  auto &topRegion = getOperation().getBody();
   if (auto forOp = dyn_cast<AffineForOp>(entryBlock.front()))
-    (void)loopUnrollJamByFactor(forOp, unrollJamFactor);
+    (void)loopUnrollJamByFactor(topRegion, forOp, unrollJamFactor);
 }
diff --git a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
index deb530b4cf1c95..f12a7b16262748 100644
--- a/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/PipelineDataTransfer.cpp
@@ -373,7 +373,8 @@ void PipelineDataTransfer::runOnAffineForOp(AffineForOp forOp) {
     return;
   }
 
-  if (failed(affineForOpBodySkew(forOp, shifts))) {
+  auto &topRegion = forOp->getParentOfType<func::FuncOp>().getBody();
+  if (failed(affineForOpBodySkew(topRegion, forOp, shifts))) {
     LLVM_DEBUG(llvm::dbgs() << "op body skewing failed - unexpected\n";);
     return;
   }
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index fb45528ad5e7d1..534f2c01ddb4e7 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -463,9 +464,11 @@ void mlir::affine::fuseLoops(AffineForOp srcForOp, AffineForOp dstForOp,
       // Patch reduction loop - only ones that are sibling-fused with the
       // destination loop - into the parent loop.
       (void)promoteSingleIterReductionLoop(forOp, true);
-    else
+    else {
       // Promote any single iteration slice loops.
-      (void)promoteIfSingleIteration(forOp);
+      auto &topRegion = forOp->getParentOfType<func::FuncOp>().getBody();
+      (void)promoteIfSingleIteration(topRegion, forOp);
+    }
   }
 }
 
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index 3794ef2dabe1e0..1eb889f28c6081 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -128,7 +128,8 @@ static void replaceIterArgsAndYieldResults(AffineForOp forOp) {
 
 /// Promotes the loop body of a forOp to its containing block if the forOp
 /// was known to have a single iteration.
-LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) {
+LogicalResult mlir::affine::promoteIfSingleIteration(Region &topRegion,
+                                                     AffineForOp forOp) {
   std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
   if (!tripCount || *tripCount != 1)
     return failure();
@@ -142,7 +143,7 @@ LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) {
   auto *parentBlock = forOp->getBlock();
   if (!iv.use_empty()) {
     if (forOp.hasConstantLowerBound()) {
-      OpBuilder topBuilder(forOp->getParentOfType<func::FuncOp>().getBody());
+      OpBuilder topBuilder(topRegion);
       auto constOp = topBuilder.create<arith::ConstantIndexOp>(
           forOp.getLoc(), forOp.getConstantLowerBound());
       iv.replaceAllUsesWith(constOp);
@@ -182,7 +183,7 @@ LogicalResult mlir::affine::promoteIfSingleIteration(AffineForOp forOp) {
 static AffineForOp generateShiftedLoop(
     AffineMap lbMap, AffineMap ubMap,
     const std::vector<std::pair<uint64_t, ArrayRef<Operation *>>> &opGroupQueue,
-    unsigned offset, AffineForOp srcForOp, OpBuilder b) {
+    unsigned offset, AffineForOp srcForOp, Region &topRegion, OpBuilder b) {
   auto lbOperands = srcForOp.getLowerBoundOperands();
   auto ubOperands = srcForOp.getUpperBoundOperands();
 
@@ -218,7 +219,7 @@ static AffineForOp generateShiftedLoop(
     for (auto *op : ops)
       bodyBuilder.clone(*op, operandMap);
   };
-  if (succeeded(promoteIfSingleIteration(loopChunk)))
+  if (succeeded(promoteIfSingleIteration(topRegion, loopChunk)))
     return AffineForOp();
   return loopChunk;
 }
@@ -234,7 +235,8 @@ static AffineForOp generateShiftedLoop(
 // asserts preservation of SSA dominance. A check for that as well as that for
 // memory-based dependence preservation check rests with the users of this
 // method.
-LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp,
+LogicalResult mlir::affine::affineForOpBodySkew(Region &topRegion,
+                                                AffineForOp forOp,
                                                 ArrayRef<uint64_t> shifts,
                                                 bool unrollPrologueEpilogue) {
   assert(forOp.getBody()->getOperations().size() == shifts.size() &&
@@ -308,14 +310,15 @@ LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp,
         res = generateShiftedLoop(
             b.getShiftedAffineMap(origLbMap, lbShift),
             b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step),
-            opGroupQueue, /*offset=*/0, forOp, b);
+            opGroupQueue, /*offset=*/0, forOp, topRegion, b);
         // Entire loop for the queued op groups generated, empty it.
         opGroupQueue.clear();
         lbShift += tripCount * step;
       } else {
         res = generateShiftedLoop(b.getShiftedAffineMap(origLbMap, lbShift),
                                   b.getShiftedAffineMap(origLbMap, d),
-                                  opGroupQueue, /*offset=*/0, forOp, b);
+                                  opGroupQueue, /*offset=*/0, forOp, topRegion,
+                                  b);
         lbShift = d * step;
       }
 
@@ -345,9 +348,10 @@ LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp,
   // and their loops completed.
   for (unsigned i = 0, e = opGroupQueue.size(); i < e; ++i) {
     uint64_t ubShift = (opGroupQueue[i].first + tripCount) * step;
-    epilogue = generateShiftedLoop(b.getShiftedAffineMap(origLbMap, lbShift),
-                                   b.getShiftedAffineMap(origLbMap, ubShift),
-                                   opGroupQueue, /*offset=*/i, forOp, b);
+    epilogue =
+        generateShiftedLoop(b.getShiftedAffineMap(origLbMap, lbShift),
+                            b.getShiftedAffineMap(origLbMap, ubShift),
+                            opGroupQueue, /*offset=*/i, forOp, topRegion, b);
     lbShift = ubShift;
     if (!prologue)
       prologue = epilogue;
@@ -357,9 +361,9 @@ LogicalResult mlir::affine::affineForOpBodySkew(AffineForOp forOp,
   forOp.erase();
 
   if (unrollPrologueEpilogue && prologue)
-    (void)loopUnrollFull(prologue);
+    (void)loopUnrollFull(topRegion, prologue);
   if (unrollPrologueEpilogue && !epilogue && epilogue != prologue)
-    (void)loopUnrollFull(epilogue);
+    (void)loopUnrollFull(topRegion, epilogue);
 
   return success();
 }
@@ -879,10 +883,10 @@ void mlir::affine::getPerfectlyNestedLoops(
 /// a temporary placeholder to test the mechanics of tiled code generation.
 /// Returns all maximal outermost perfect loop nests to tile.
 void mlir::affine::getTileableBands(
-    func::FuncOp f, std::vector<SmallVector<AffineForOp, 6>> *bands) {
+    Region &region, std::vector<SmallVector<AffineForOp, 6>> *bands) {
   // Get maximal perfect nest of 'affine.for' insts starting from root
   // (inclusive).
-  for (AffineForOp forOp : f.getOps<AffineForOp>()) {
+  for (AffineForOp forOp : region.getOps<AffineForOp>()) {
     SmallVector<AffineForOp, 6> band;
     getPerfectlyNestedLoops(band, forOp);
     bands->push_back(band);
@@ -890,28 +894,30 @@ void mlir::affine::getTileableBands(
 }
 
 /// Unrolls this loop completely.
-LogicalResult mlir::affine::loopUnrollFull(AffineForOp forOp) {
+LogicalResult mlir::affine::loopUnrollFull(Region &topRegion,
+                                           AffineForOp forOp) {
   std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
   if (mayBeConstantTripCount.has_value()) {
     uint64_t tripCount = *mayBeConstantTripCount;
     if (tripCount == 0)
       return success();
     if (tripCount == 1)
-      return promoteIfSingleIteration(forOp);
-    return loopUnrollByFactor(forOp, tripCount);
+      return promoteIfSingleIteration(topRegion, forOp);
+    return loopUnrollByFactor(topRegion, forOp, tripCount);
   }
   return failure();
 }
 
 /// Unrolls this loop by the specified factor or by the trip count (if constant)
 /// whichever is lower.
-LogicalResult mlir::affine::loopUnrollUpToFactor(AffineForOp forOp,
+LogicalResult mlir::affine::loopUnrollUpToFactor(Region &topRegion,
+                                                 AffineForOp forOp,
                                                  uint64_t unrollFactor) {
   std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
   if (mayBeConstantTripCount.has_value() &&
       *mayBeConstantTripCount < unrollFactor)
-    return loopUnrollByFactor(forOp, *mayBeConstantTripCount);
-  return loopUnrollByFactor(forOp, unrollFactor);
+    return loopUnrollByFactor(topRegion, forOp, *mayBeConstantTripCount);
+  return loopUnrollByFactor(topRegion, forOp, unrollFactor);
 }
 
 /// Generates unrolled copies of AffineForOp 'loopBodyBlock', with associated
@@ -978,7 +984,8 @@ static void generateUnrolledLoop(
 
 /// Helper to generate cleanup loop for unroll or unroll-and-jam when the trip
 /// count is not a multiple of `unrollFactor`.
-static LogicalResult generateCleanupLoopForUnroll(AffineForOp forOp,
+static LogicalResult generateCleanupLoopForUnroll(Region &topRegion,
+                                                  AffineForOp forOp,
                                                   uint64_t unrollFactor) {
   // Insert the cleanup loop right after 'forOp'.
   OpBuilder builder(forOp->getBlock(), std::next(Block::iterator(forOp)));
@@ -1003,7 +1010,7 @@ static LogicalResult generateCleanupLoopForUnroll(AffineForOp forOp,
 
   cleanupForOp.setLowerBound(cleanupOperands, cleanupMap);
   // Promote the loop body up if this has turned into a single iteration loop.
-  (void)promoteIfSingleIteration(cleanupForOp);
+  (void)promoteIfSingleIteration(topRegion, cleanupForOp);
 
   // Adjust upper bound of the original loop; this is the same as the lower
   // bound of the cleanup loop.
@@ -1014,7 +1021,7 @@ static LogicalResult generateCleanupLoopForUnroll(AffineForOp forOp,
 /// Unrolls this loop by the specified factor. Returns success if the loop
 /// is successfully unrolled.
 LogicalResult mlir::affine::loopUnrollByFactor(
-    AffineForOp forOp, uint64_t unrollFactor,
+    Region &topRegion, AffineForOp forOp, uint64_t unrollFactor,
     function_ref<void(unsigned, Operation *, OpBuilder)> annotateFn,
     bool cleanUpUnroll) {
   assert(unrollFactor > 0 && "unroll factor should be positive");
@@ -1022,7 +1029,7 @@ LogicalResult mlir::affine::loopUnrollByFactor(
   std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
   if (unrollFactor == 1) {
     if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
-        failed(promoteIfSingleIteration(forOp)))
+        failed(promoteIfSingleIteration(topRegion, forOp)))
       return failure();
     return success();
   }
@@ -1035,7 +1042,7 @@ LogicalResult mlir::affine::loopUnrollByFactor(
   if (mayBeConstantTripCount && *mayBeConstantTripCount < unrollFactor) {
     if (cleanUpUnroll) {
       // Unroll the cleanup loop if cleanUpUnroll is specified.
-      return loopUnrollFull(forOp);
+      return loopUnrollFull(topRegion, forOp);
     }
 
     return failure();
@@ -1052,8 +1059,8 @@ LogicalResult mlir::affine::loopUnrollByFactor(
       return failure();
     if (cleanUpUnroll)
       // Force unroll including cleanup loop
-      return loopUnrollFull(forOp);
-    if (failed(generateCleanupLoopForUnroll(forOp, unrollFactor)))
+      return loopUnrollFull(topRegion, forOp);
+    if (failed(generateCleanupLoopForUnroll(topRegion, forOp, unrollFactor)))
       assert(false && "cleanup loop lower bound map for single result lower "
                       "and upper bound maps can always be determined");
   }
@@ -1076,17 +1083,18 @@ LogicalResult mlir::affine::loopUnrollByFactor(
       /*iterArgs=*/iterArgs, /*yieldedValues=*/yieldedValues);
 
   // Promote the loop body up if this has turned into a single iteration loop.
-  (void)promoteIfSingleIteration(forOp);
+  (void)promoteIfSingleIteration(topRegion, forOp);
   return success();
 }
 
-LogicalResult mlir::affine::loopUnrollJamUpToFactor(AffineForOp forOp,
+LogicalResult mlir::affine::loopUnrollJamUpToFactor(Region &topRegion,
+                                                    AffineForOp forOp,
                                                     uint64_t unrollJamFactor) {
   std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
   if (mayBeConstantTripCount.has_value() &&
       *mayBeConstantTripCount < unrollJamFactor)
-    return loopUnrollJamByFactor(forOp, *mayBeConstantTripCount);
-  return loopUnrollJamByFactor(forOp, unrollJamFactor);
+    return loopUnrollJamByFactor(topRegion, forOp, *mayBeConstantTripCount);
+  return loopUnrollJamByFactor(topRegion, forOp, unrollJamFactor);
 }
 
 /// Check if all control operands of all loops are defined outside of `forOp`
@@ -1131,14 +1139,15 @@ struct JamBlockGatherer {
 };
 
 /// Unrolls and jams this loop by the specified factor.
-LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
+LogicalResult mlir::affine::loopUnrollJamByFactor(Region &topRegion,
+                                                  AffineForOp forOp,
                                                   uint64_t unrollJamFactor) {
   assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
 
   std::optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
   if (unrollJamFactor == 1) {
     if (mayBeConstantTripCount && *mayBeConstantTripCount == 1 &&
-        failed(promoteIfSingleIteration(forOp)))
+        failed(promoteIfSingleIteration(topRegion, forOp)))
       return failure();
     return success();
   }
@@ -1185,7 +1194,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
     if (forOp.getLowerBoundMap().getNumResults() != 1 ||
         forOp.getUpperBoundMap().getNumResults() != 1)
       return failure();
-    if (failed(generateCleanupLoopForUnroll(forOp, unrollJamFactor)))
+    if (failed(generateCleanupLoopForUnroll(topRegion, forOp, unrollJamFactor)))
       assert(false && "cleanup loop lower bound map for single result lower "
                       "and upper bound maps can always be determined");
   }
@@ -1321,7 +1330,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
   }
 
   // Promote the loop body up if this has turned into a single iteration loop.
-  (void)promoteIfSingleIteration(forOp);
+  (void)promoteIfSingleIteration(topRegion, forOp);
   return success();
 }
 
@@ -1968,8 +1977,8 @@ emitRemarkForBlock(Block &block) {
 /// output argument `nEnd` is set to the new end. `sizeInBytes` is set to the
 /// size of the fast buffer allocated.
 static LogicalResult generateCopy(
-    const MemRefRegion &region, Block *block, Block::iterator begin,
-    Block::iterator end, Block *copyPlacementBlock,
+    const MemRefRegion &region, Block *block, Region &topRegion,
+    Block::iterator begin, Block::iterator end, Block *copyPlacementBlock,
     Block::iterator copyInPlacementStart, Block::iterator copyOutPlacementStart,
     const AffineCopyOptions &copyOptions, DenseMap<Value, Value> &fastBufferMap,
     DenseSet<Operation *> &copyNests, uint64_t *sizeInBytes,
@@ -1977,9 +1986,9 @@ static LogicalResult generateCopy(
   *nBegin = begin;
   *nEnd = end;
 
-  func::FuncOp f = begin->getParentOfType<func::FuncOp>();
-  OpBuilder topBuilder(f.getBody());
-  Value zeroIndex = topBuilder.create<arith::ConstantIndexOp>(f.getLoc(), 0);
+  OpBuilder topBuilder(topRegion);
+  Value zeroIndex =
+      topBuilder.create<arith::ConstantIndexOp>(topRegion.getLoc(), 0);
 
   *sizeInBytes = 0;
 
@@ -1997,8 +2006,7 @@ static LogicalResult generateCopy(
   OpBuilder &b = region.isWrite() ? epilogue : prologue;
 
   // Builder to create constants at the top level.
-  auto func = copyPlacementBlock->getParent()->getParentOfType<func::FuncOp>();
-  OpBuilder top(func.getBody());
+  OpBuilder top(topRegion);
 
   auto loc = region.loc;
   auto memref = region.memref;
@@ -2301,11 +2309,10 @@ static bool getFullMemRefAsRegion(Operation *op, unsigned numParamLoopIVs,
   return true;
 }
 
-LogicalResult
-mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
-                                     const AffineCopyOptions &copyOptions,
-                                     std::optional<Value> filterMemRef,
-                                     DenseSet<Operation *> &copyNests) {
+LogicalResult mlir::affine::affineDataCopyGenerate(
+    Region &topRegion, Block::iterator begin, Block::iterator end,
+    const AffineCopyOptions &copyOptions, std::optional<Value> filterMemRef,
+    DenseSet<Operation *> &copyNests) {
   if (begin == end)
     return success();
 
@@ -2450,10 +2457,11 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
 
           uint64_t sizeInBytes;
           Block::iterator nBegin, nEnd;
-          LogicalResult iRet = generateCopy(
-              *regionEntry.second, block, begin, end, copyPlacementBlock,
-              copyInPlacementStart, copyOutPlacementStart, copyOptions,
-              fastBufferMap, copyNests, &sizeInBytes, &nBegin, &nEnd);
+          LogicalResult iRet =
+              generateCopy(*regionEntry.second, block, topRegion, begin, end,
+                           copyPlacementBlock, copyInPlacementStart,
+                           copyOutPlacementStart, copyOptions, fastBufferMap,
+                           copyNests, &sizeInBytes, &nBegin, &nEnd);
           if (succeeded(iRet)) {
             // begin/end could have been invalidated, and need update.
             begin = nBegin;
@@ -2492,15 +2500,15 @@ mlir::affine::affineDataCopyGenerate(Block::iterator begin, Block::iterator end,
 // A convenience version of affineDataCopyGenerate for all ops in the body of
 // an AffineForOp.
 LogicalResult mlir::affine::affineDataCopyGenerate(
-    AffineForOp forOp, const AffineCopyOptions &copyOptions,
+    Region &topRegion, AffineForOp forOp, const AffineCopyOptions &copyOptions,
     std::optional<Value> filterMemRef, DenseSet<Operation *> &copyNests) {
-  return affineDataCopyGenerate(forOp.getBody()->begin(),
+  return affineDataCopyGenerate(topRegion, forOp.getBody()->begin(),
                                 std::prev(forOp.getBody()->end()), copyOptions,
                                 filterMemRef, copyNests);
 }
 
 LogicalResult mlir::affine::generateCopyForMemRegion(
-    const MemRefRegion &memrefRegion, Operation *analyzedOp,
+    Region &topRegion, const MemRefRegion &memrefRegion, Operation *analyzedOp,
     const AffineCopyOptions &copyOptions, CopyGenerateResult &result) {
   Block *block = analyzedOp->getBlock();
   auto begin = analyzedOp->getIterator();
@@ -2508,8 +2516,8 @@ LogicalResult mlir::affine::generateCopyForMemRegion(
   DenseMap<Value, Value> fastBufferMap;
   DenseSet<Operation *> copyNests;
 
-  auto err = generateCopy(memrefRegion, block, begin, end, block, begin, end,
-                          copyOptions, fastBufferMap, copyNests,
+  auto err = generateCopy(memrefRegion, block, topRegion, begin, end, block,
+                          begin, end, copyOptions, fastBufferMap, copyNests,
                           &result.sizeInBytes, &begin, &end);
   if (failed(err))
     return err;
@@ -2544,8 +2552,8 @@ gatherLoopsInBlock(Block *block, unsigned currLoopDepth,
 
 /// Gathers all AffineForOps in 'func.func' grouped by loop depth.
 void mlir::affine::gatherLoops(
-    func::FuncOp func, std::vector<SmallVector<AffineForOp, 2>> &depthToLoops) {
-  for (auto &block : func)
+    Region &region, std::vector<SmallVector<AffineForOp, 2>> &depthToLoops) {
+  for (auto &block : region)
     gatherLoopsInBlock(&block, /*currLoopDepth=*/0, depthToLoops);
 
   // Remove last loop level from output since it's empty.
diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
index 3dc5539cde3d98..e5e016b7a78699 100644
--- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp
@@ -552,9 +552,10 @@ void mlir::affine::normalizeAffineParallel(AffineParallelOp op) {
   op.setUpperBounds(ranges.getOperands(), newUpperMap);
 }
 
-LogicalResult mlir::affine::normalizeAffineFor(AffineForOp op,
+LogicalResult mlir::affine::normalizeAffineFor(Region &topRegion,
+                                               AffineForOp op,
                                                bool promoteSingleIter) {
-  if (promoteSingleIter && succeeded(promoteIfSingleIteration(op)))
+  if (promoteSingleIter && succeeded(promoteIfSingleIteration(topRegion, op)))
     return success();
 
   // Check if the forop is already normalized.
diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
index bc2fe5772af9d6..7908d2f63ebe36 100644
--- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
+++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
@@ -308,10 +308,13 @@ transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter,
                                     transform::ApplyToEachResultList &results,
                                     transform::TransformState &state) {
   LogicalResult result(failure());
-  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op))
+  if (scf::ForOp scfFor = dyn_cast<scf::ForOp>(op)) {
     result = loopUnrollByFactor(scfFor, getFactor());
-  else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op))
-    result = loopUnrollByFactor(affineFor, getFactor());
+  }
+  else if (AffineForOp affineFor = dyn_cast<AffineForOp>(op)) {
+    auto &topRegion = affineFor->getParentOfType<func::FuncOp>().getBody();
+    result = loopUnrollByFactor(topRegion, affineFor, getFactor());
+  }
 
   if (failed(result)) {
     DiagnosedSilenceableFailure diag = emitSilenceableError()
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
index b418a457473a8e..8fd6d5f7c9758e 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineDataCopy.cpp
@@ -60,7 +60,8 @@ struct TestAffineDataCopy
 void TestAffineDataCopy::runOnOperation() {
   // Gather all AffineForOps by loop depth.
   std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
-  gatherLoops(getOperation(), depthToLoops);
+  auto &topRegion = getOperation().getBody();
+  gatherLoops(topRegion, depthToLoops);
   if (depthToLoops.empty())
     return;
 
@@ -93,15 +94,16 @@ void TestAffineDataCopy::runOnOperation() {
                                    /*fastMemCapacityBytes=*/32 * 1024 * 1024UL};
   DenseSet<Operation *> copyNests;
   if (clMemRefFilter) {
-    if (failed(affineDataCopyGenerate(loopNest, copyOptions, load.getMemRef(),
-                                      copyNests)))
+    if (failed(affineDataCopyGenerate(topRegion, loopNest, copyOptions,
+                                      load.getMemRef(), copyNests)))
       return;
   } else if (clTestGenerateCopyForMemRegion) {
     CopyGenerateResult result;
     MemRefRegion region(loopNest.getLoc());
     if (failed(region.compute(load, /*loopDepth=*/0)))
       return;
-    if (failed(generateCopyForMemRegion(region, loopNest, copyOptions, result)))
+    if (failed(generateCopyForMemRegion(topRegion, region, loopNest,
+                                        copyOptions, result)))
       return;
   }
 
@@ -112,12 +114,15 @@ void TestAffineDataCopy::runOnOperation() {
     // With a post order walk, the erasure of loops does not affect
     // continuation of the walk or the collection of load/store ops.
     nest->walk([&](Operation *op) {
-      if (auto forOp = dyn_cast<AffineForOp>(op))
-        (void)promoteIfSingleIteration(forOp);
-      else if (auto loadOp = dyn_cast<AffineLoadOp>(op))
+      if (auto forOp = dyn_cast<AffineForOp>(op)) {
+        auto &topRegion = forOp->getParentOfType<func::FuncOp>().getBody();
+        (void)promoteIfSingleIteration(topRegion, forOp);
+      } else if (auto loadOp = dyn_cast<AffineLoadOp>(op)) {
         copyOps.push_back(loadOp);
-      else if (auto storeOp = dyn_cast<AffineStoreOp>(op))
+      }
+      else if (auto storeOp = dyn_cast<AffineStoreOp>(op)) {
         copyOps.push_back(storeOp);
+      }
     });
   }
 
diff --git a/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp b/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp
index f8e76356c43215..83ca705e591abf 100644
--- a/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestAffineLoopParametricTiling.cpp
@@ -76,7 +76,7 @@ getTilingParameters(ArrayRef<AffineForOp> band,
 void TestAffineLoopParametricTiling::runOnOperation() {
   // Bands of loops to tile.
   std::vector<SmallVector<AffineForOp, 6>> bands;
-  getTileableBands(getOperation(), &bands);
+  getTileableBands(getOperation().getBody(), &bands);
 
   // Tile each band.
   for (MutableArrayRef<AffineForOp> band : bands) {
diff --git a/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp b/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp
index f4f1593dc53e2b..5938a19f62a54a 100644
--- a/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestLoopFusion.cpp
@@ -181,12 +181,13 @@ static bool iterateLoops(ArrayRef<SmallVector<AffineForOp, 2>> depthToLoops,
 
 void TestLoopFusion::runOnOperation() {
   std::vector<SmallVector<AffineForOp, 2>> depthToLoops;
+  auto &topRegion = getOperation().getBody();
   if (clTestLoopFusionTransformation) {
     // Run loop fusion until a fixed point is reached.
     do {
       depthToLoops.clear();
       // Gather all AffineForOps by loop depth.
-      gatherLoops(getOperation(), depthToLoops);
+      gatherLoops(topRegion, depthToLoops);
 
       // Try to fuse all combinations of src/dst loop nests in 'depthToLoops'.
     } while (iterateLoops(depthToLoops, testLoopFusionTransformation,
@@ -195,7 +196,7 @@ void TestLoopFusion::runOnOperation() {
   }
 
   // Gather all AffineForOps by loop depth.
-  gatherLoops(getOperation(), depthToLoops);
+  gatherLoops(topRegion, depthToLoops);
 
   // Run tests on all combinations of src/dst loop nests in 'depthToLoops'.
   if (clTestDependenceCheck)
diff --git a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
index b497f8d75fde75..0e2334bc22d5ed 100644
--- a/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
+++ b/mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
@@ -230,7 +230,8 @@ void VectorizerTestPass::testComposeMaps(llvm::raw_ostream &outs) {
 /// Test for 'vectorizeAffineLoopNest' utility.
 void VectorizerTestPass::testVecAffineLoopNest(llvm::raw_ostream &outs) {
   std::vector<SmallVector<AffineForOp, 2>> loops;
-  gatherLoops(getOperation(), loops);
+  auto &topRegion = getOperation().getBody();
+  gatherLoops(topRegion, loops);
 
   // Expected only one loop nest.
   if (loops.empty() || loops[0].size() != 1)



More information about the Mlir-commits mailing list