[Mlir-commits] [mlir] [mlir][TilingInterface] Use `LoopLikeOpInterface` in tiling using SCF to unify tiling with `scf.for` and `scf.forall`. (PR #77874)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jan 11 21:33:11 PST 2024


llvmbot wrote:


@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (MaheshRavishankar)

<details>
<summary>Changes</summary>

Using `LoopLikeOpInterface` as the basis for the implementation unifies all the tiling logic for both `scf.for` and `scf.forall`. The only difference is the actual loop generation.
Instead of a separate entry point for each loop type, the loop type is now specified in the options passed to the tiling method.
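
To illustrate the single-entry-point design, here is a minimal, self-contained sketch. The types are simplified stand-ins in a hypothetical `sketch` namespace, not the real MLIR classes. Only the `LoopType` enum, the `loopType` field and the `setLoopType` setter mirror the patch. The stand-in `tileUsingSCF` takes no rewriter or op and merely reports which loop op would be generated.

```cpp
#include <string>

// Simplified stand-ins for illustration only; these are NOT the MLIR classes.
namespace sketch {

struct SCFTilingOptions {
  // Mirrors the enum, field and setter added by the patch.
  enum class LoopType { ForOp, ForallOp };
  LoopType loopType = LoopType::ForOp;
  SCFTilingOptions &setLoopType(LoopType type) {
    loopType = type;
    return *this;
  }
};

// Hypothetical single entry point: the shared tiling logic would live here,
// and only the loop generation step depends on the requested loop type.
inline std::string tileUsingSCF(const SCFTilingOptions &options) {
  switch (options.loopType) {
  case SCFTilingOptions::LoopType::ForOp:
    return "scf.for";
  case SCFTilingOptions::LoopType::ForallOp:
    return "scf.forall";
  }
  return "unknown";
}

} // namespace sketch
```

With these stand-ins, default-constructed options select `scf.for`, and calling `setLoopType(sketch::SCFTilingOptions::LoopType::ForallOp)` on the options first selects `scf.forall`.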

This is a breaking change. The API changes are:

1) The `scf::tileUsingSCFForOp` is renamed to `scf::tileUsingSCF`.
2) The `scf::tileUsingSCFForallOp` is deprecated. The same
   functionality is obtained by using `scf::tileUsingSCF` and setting
   the loop type in `scf::SCFTilingOptions` passed into this method to
   `scf::SCFTilingOptions::LoopType::ForallOp` (using the
   `setLoopType` method).
3) The `scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp` is
   renamed to `scf::tileConsumerAndFuseProducersUsingSCF`. The use of
   the `controlFn` in `scf::SCFTileAndFuseOptions` allows implementing
   any strategy, with the default callback implementing the greedy fusion.
4) The `scf::SCFTilingResult` and `scf::SCFTileAndFuseResult` now use
   `SmallVector<LoopLikeOpInterface>`.
5) To make `scf::ForallOp` implement the parts of
   `LoopLikeOpInterface` needed, the `getOutputBlockArguments()`
   method is replaced with `getRegionIterArgs()`.

These changes bring the tiling and fusion capabilities using `scf.forall` on par with what was already supported by `scf.for`.
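
Item 5 relies on a convention visible in the diff: a loop body's block arguments start with the induction variables, followed by the region iter args, so `ForallOp::getRegionIterArgs()` drops the first `getRank()` arguments. Below is a minimal sketch of that slicing. Block arguments are modeled as plain strings, and `regionIterArgs` is a hypothetical helper, not an MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper, not MLIR API: models `getArguments().drop_front(rank)`.
// `blockArgs` holds the induction variables first, then the iter args.
inline std::vector<std::string>
regionIterArgs(const std::vector<std::string> &blockArgs, std::size_t rank) {
  assert(rank <= blockArgs.size() && "rank exceeds number of block arguments");
  return std::vector<std::string>(
      blockArgs.begin() + static_cast<std::ptrdiff_t>(rank), blockArgs.end());
}
```

For a rank-2 loop whose block arguments are `%i, %j, %out0, %out1`, this returns `%out0, %out1`, the same values the removed `getOutputBlockArguments()` accessor computed in the diff.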

---

Patch is 107.39 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/77874.diff


18 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+6-13) 
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+22-19) 
- (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+5-5) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp (+2-2) 
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+47-3) 
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+411-398) 
- (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+4-4) 
- (modified) mlir/lib/Interfaces/LoopLikeInterface.cpp (+19-17) 
- (modified) mlir/test/Dialect/Linalg/generalize-tensor-unpack-tile.mlir (+4-4) 
- (modified) mlir/test/Dialect/Linalg/tile-conv.mlir (+1-1) 
- (modified) mlir/test/Dialect/Linalg/tile-tensors.mlir (+1-1) 
- (modified) mlir/test/Dialect/Tensor/tiling.mlir (+10-14) 
- (added) mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-scfforall.mlir (+176) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir (+2-3) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir (+9-9) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir (+148-2) 
- (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+95-3) 
- (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td (+24-2) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 8d65d3dd820baf..819d88df0ae57f 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -135,9 +135,9 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
 
 def ForOp : SCF_Op<"for",
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
-       ["getInitsMutable", "getSingleInductionVar", "getSingleLowerBound",
-        "getSingleStep", "getSingleUpperBound", "getYieldedValuesMutable",
-        "getLoopResults", "promoteIfSingleIteration",
+       ["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar", 
+        "getSingleLowerBound", "getSingleStep", "getSingleUpperBound",
+        "getYieldedValuesMutable", "getLoopResults", "promoteIfSingleIteration",
         "replaceWithAdditionalYields"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
@@ -259,10 +259,6 @@ def ForOp : SCF_Op<"for",
 
     Value getInductionVar() { return getBody()->getArgument(0); }
 
-    Block::BlockArgListType getRegionIterArgs() {
-      return getBody()->getArguments().drop_front(getNumInductionVars());
-    }
-
     /// Return the `index`-th region iteration argument.
     BlockArgument getRegionIterArg(unsigned index) {
       assert(index < getNumRegionIterArgs() &&
@@ -304,8 +300,9 @@ def ForallOp : SCF_Op<"forall", [
        AttrSizedOperandSegments,
        AutomaticAllocationScope,
        DeclareOpInterfaceMethods<LoopLikeOpInterface,
-          ["promoteIfSingleIteration", "getSingleInductionVar",
-          "getSingleLowerBound", "getSingleUpperBound", "getSingleStep"]>,
+          ["getInitsMutable", "getRegionIterArgs", "getSingleInductionVar", 
+           "getSingleLowerBound", "getSingleUpperBound", "getSingleStep",
+           "promoteIfSingleIteration", "replaceWithAdditionalYields"]>,
        RecursiveMemoryEffects,
        SingleBlockImplicitTerminator<"scf::InParallelOp">,
        DeclareOpInterfaceMethods<RegionBranchOpInterface>,
@@ -585,10 +582,6 @@ def ForallOp : SCF_Op<"forall", [
                                     getNumDynamicControlOperands() + getRank());
     }
 
-    ArrayRef<BlockArgument> getOutputBlockArguments() {
-      return getBody()->getArguments().drop_front(getRank());
-    }
-
     ::mlir::ValueRange getInductionVars() {
       return getBody()->getArguments().take_front(getRank());
     }
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 5d2d78e6e6165b..965ef9e203be28 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/TilingInterface.h"
 
 #include <deque>
@@ -52,6 +53,14 @@ struct SCFTilingOptions {
     return *this;
   }
 
+  /// Specify which loop construct to use for tile and fuse.
+  enum class LoopType { ForOp, ForallOp };
+  LoopType loopType = LoopType::ForOp;
+  SCFTilingOptions &setLoopType(LoopType type) {
+    loopType = type;
+    return *this;
+  }
+
   /// Specify mapping of loops to devices. This is only respected when the loop
   /// constructs support such a mapping (like `scf.forall`). Will be ignored
   /// when using loop constructs that dont support such a mapping (like
@@ -71,7 +80,7 @@ struct SCFTilingResult {
   /// of the last op.
   SmallVector<Operation *> tiledOps;
   /// The `scf.for` operations that iterate over the tiles.
-  SmallVector<Operation *> loops;
+  SmallVector<LoopLikeOpInterface> loops;
   /// Values to use as replacements for the untiled op. Is the same size as the
   /// number of results of the untiled op.
   SmallVector<Value> replacements;
@@ -79,15 +88,9 @@ struct SCFTilingResult {
 
 /// Method to tile an op that implements the `TilingInterface` using
 /// `scf.for` for iterating over the tiles.
-FailureOr<SCFTilingResult> tileUsingSCFForOp(RewriterBase &rewriter,
-                                             TilingInterface op,
-                                             const SCFTilingOptions &options);
-
-/// Method to tile an op that implements the `TilingInterface` using
-/// `scf.forall`.
-FailureOr<SCFTilingResult>
-tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
-                     const SCFTilingOptions &options);
+FailureOr<SCFTilingResult> tileUsingSCF(RewriterBase &rewriter,
+                                        TilingInterface op,
+                                        const SCFTilingOptions &options);
 
 /// Options used to control tile + fuse.
 struct SCFTileAndFuseOptions {
@@ -135,7 +138,7 @@ struct SCFFuseProducerOfSliceResult {
 std::optional<SCFFuseProducerOfSliceResult>
 tileAndFuseProducerOfSlice(RewriterBase &rewriter,
                            tensor::ExtractSliceOp candidateSliceOp,
-                           MutableArrayRef<scf::ForOp> loops);
+                           MutableArrayRef<LoopLikeOpInterface> loops);
 
 /// Reconstruct the fused producer from within the tiled-and-fused code. Based
 /// on the slice of the producer computed in place it is possible that within
@@ -187,10 +190,10 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
 /// where `%0` had other uses as well. If not reconstructed from within the loop
 /// body, uses of `%0` could not be replaced, making it still live and the
 /// fusion immaterial.
-void yieldReplacementForFusedProducer(
+LogicalResult yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
-    MutableArrayRef<scf::ForOp> loops);
+    MutableArrayRef<LoopLikeOpInterface> loops);
 
 /// Transformation information returned after tile and fuse.
 struct SCFTileAndFuseResult {
@@ -201,7 +204,7 @@ struct SCFTileAndFuseResult {
   /// generated operation.
   llvm::SetVector<Operation *> tiledAndFusedOps;
   /// The `scf.for` operations that iterate over the tiles.
-  SmallVector<Operation *> loops;
+  SmallVector<LoopLikeOpInterface> loops;
   /// The replacement values to use for the tiled and fused operations.
   llvm::DenseMap<Value, Value> replacements;
 };
@@ -232,9 +235,9 @@ struct SCFTileAndFuseResult {
 /// }
 /// ```
 FailureOr<SCFTileAndFuseResult>
-tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
-    RewriterBase &rewriter, TilingInterface consumer,
-    const SCFTileAndFuseOptions &options);
+tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
+                                     TilingInterface consumer,
+                                     const SCFTileAndFuseOptions &options);
 
 /// Method to lower an `op` that implements the `TilingInterface` to
 /// loops/scalars.
@@ -249,8 +252,8 @@ struct SCFReductionTilingResult {
   Operation *mergeOp;
   /// Initial op
   Operation *initialOp;
-  /// The `scf.for` operations that iterate over the tiles.
-  SmallVector<scf::ForOp> loops;
+  /// The loop operations that iterate over the tiles.
+  SmallVector<LoopLikeOpInterface> loops;
 };
 
 /// Method to tile a reduction and generate a parallel op within a serial loop.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 97d2b4a3be5c56..5efa2b3eb7476f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -485,8 +485,8 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
       tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
       [&](TilingInterface tilingInterfaceOp)
           -> FailureOr<scf::SCFTileAndFuseResult> {
-        return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
-            rewriter, tilingInterfaceOp, tileAndFuseOptions);
+        return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
+                                                    tileAndFuseOptions);
       });
   return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
                         : DiagnosedSilenceableFailure::success();
@@ -584,7 +584,7 @@ static Operation *replaceForAllWithNewSignature(
   Operation *firstYieldOp = yieldingOps.front();
   rewriter.setInsertionPoint(firstYieldOp);
   Value src = tileAndFuseResult.tiledValues[0];
-  Value dst = newforallOp.getOutputBlockArguments().back();
+  Value dst = newforallOp.getRegionIterArgs().back();
   SmallVector<OpFoldResult> strides(offsets.size(), rewriter.getIndexAttr(1));
   rewriter.create<tensor::ParallelInsertSliceOp>(firstYieldOp->getLoc(), src,
                                                  dst, offsets, sizes, strides);
@@ -2063,7 +2063,7 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter,
   });
   SmallVector<int64_t> emptyTileSizes;
   rewriter.setInsertionPoint(target);
-  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
+  FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCF(
       rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
   if (failed(maybeTilingResult))
     return emitDefaultDefiniteFailure(target);
@@ -2647,7 +2647,7 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter,
 
     tilingOptions.setInterchange(getInterchange());
     FailureOr<scf::SCFTilingResult> maybeTilingResult =
-        tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions);
+        tileUsingSCF(rewriter, tilingInterface, tilingOptions);
     if (failed(maybeTilingResult))
       return DiagnosedSilenceableFailure::definiteFailure();
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index 7f3ab1f1a24b2f..339d06cdeaf60a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -358,7 +358,7 @@ static FailureOr<ForallTilingResult> tileToForallOpImpl(
 
   // 3. Clone the tileable op and update its destination operands to use the
   // output bbArgs of the ForallOp.
-  ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
+  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
   Operation *tiledOp = nullptr;
   SmallVector<Value> tiledValues;
   {
@@ -695,7 +695,7 @@ FailureOr<linalg::ForallReductionTilingResult> linalg::tileReductionUsingForall(
   // 4. Clone the tileable op and update its destination operands to use the
   // output bbArgs of the ForallOp.
   SmallVector<Value> tilingResults;
-  ArrayRef<BlockArgument> destBbArgs = forallOp.getOutputBlockArguments();
+  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
   {
     // 4.a. RAII guard, inserting within forallOp, before terminator.
     OpBuilder::InsertionGuard g(b);
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5570c2ec688c8a..7dbc8f250b25eb 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -527,6 +527,10 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
 
 SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
 
+Block::BlockArgListType ForOp::getRegionIterArgs() {
+  return getBody()->getArguments().drop_front(getNumInductionVars());
+}
+
 MutableArrayRef<OpOperand> ForOp::getInitsMutable() {
   return getInitArgsMutable();
 }
@@ -622,6 +626,47 @@ LogicalResult scf::ForallOp::promoteIfSingleIteration(RewriterBase &rewriter) {
   return success();
 }
 
+Block::BlockArgListType ForallOp::getRegionIterArgs() {
+  return getBody()->getArguments().drop_front(getRank());
+}
+
+MutableArrayRef<OpOperand> ForallOp::getInitsMutable() {
+  return getOutputsMutable();
+}
+
+FailureOr<LoopLikeOpInterface>
+ForallOp::replaceWithAdditionalYields(RewriterBase &rewriter,
+                                      ValueRange newInitOperands,
+                                      bool replaceInitOperandUsesInLoop,
+                                      const NewYieldValuesFn &newYieldValueFn) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(getOperation());
+  auto inits = llvm::to_vector(getOutputs());
+  inits.append(newInitOperands.begin(), newInitOperands.end());
+  auto newLoop = rewriter.create<scf::ForallOp>(
+      getLoc(), getMixedLowerBound(), getMixedUpperBound(), getMixedStep(),
+      inits, getMapping(), [](OpBuilder &, Location, ValueRange) {});
+
+  // Move the region of the current block to the newly created op.
+  Block *newLoopBody = newLoop.getBody();
+  rewriter.mergeBlocks(
+      getBody(), newLoopBody,
+      newLoopBody->getArguments().take_front(getBody()->getNumArguments()));
+
+  // Update the terminator.
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    auto terminator = cast<scf::InParallelOp>(newLoopBody->getTerminator());
+    rewriter.setInsertionPointToEnd(terminator.getBody());
+    newYieldValueFn(
+        rewriter, getLoc(),
+        newLoopBody->getArguments().take_back(newInitOperands.size()));
+  }
+  rewriter.replaceOp(getOperation(),
+                     newLoop->getResults().take_front(getNumResults()));
+  return cast<LoopLikeOpInterface>(newLoop.getOperation());
+}
+
 /// Promotes the loop body of a scf::ForallOp to its containing block.
 void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
   OpBuilder::InsertionGuard g(rewriter);
@@ -1630,9 +1675,8 @@ struct FoldTensorCastOfOutputIntoForallOp
     // mapped to the tensor.cast old-typed results of the output bbArgs. The
     // destination have to be updated to point to the output bbArgs directly.
     auto terminator = newForallOp.getTerminator();
-    for (auto [yieldingOp, outputBlockArg] :
-         llvm::zip(terminator.getYieldingOps(),
-                   newForallOp.getOutputBlockArguments())) {
+    for (auto [yieldingOp, outputBlockArg] : llvm::zip(
+             terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
       auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
       insertSliceOp.getDestMutable().assign(outputBlockArg);
     }
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 38e0625d7ce093..20baa5c8dfcfa5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -55,32 +55,8 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
   return filledVector;
 }
 
-/// Convert a list of ops of type `SrcOpTy` to list of `Operation *`.
-template <typename SrcOpTy>
-static SmallVector<Operation *> getAsOperations(ArrayRef<SrcOpTy> ops) {
-  return llvm::to_vector(
-      llvm::map_range(ops, [](auto op) -> Operation * { return op; }));
-}
-template <typename SrcOpTy>
-static SmallVector<Operation *>
-getAsOperations(const SmallVector<SrcOpTy> &ops) {
-  return getAsOperations(ArrayRef<SrcOpTy>(ops));
-}
-
-/// Convert a list of `Operation *` to a list of `DstOpTy.
-template <typename DstOpTy>
-static SmallVector<DstOpTy> castToTypedOperations(ArrayRef<Operation *> ops) {
-  return llvm::to_vector(
-      llvm::map_range(ops, [](Operation *op) { return cast<DstOpTy>(op); }));
-}
-template <typename DstOpTy>
-static SmallVector<DstOpTy>
-castToTypedOperations(const SmallVector<Operation *> &ops) {
-  return castToTypedOperations<DstOpTy>(ArrayRef<Operation *>(ops));
-}
-
 //===----------------------------------------------------------------------===//
-// tileUsingSCFForOp implementation.
+// tileUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
 // Check if `stride` evenly divides the trip count `size - offset`.
@@ -135,66 +111,204 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
   return clonedOp;
 }
 
-/// Generate an empty loop nest that represents the tiled loop nest shell.
+/// Type of the call back function used to generate the body of the tiled
+/// loops. The loop generation methods use this callback to generate
+/// the body of the inner-most tile loop. The call back is provided with
+/// - `rewriter`: with insertion point set to the end of the inner most loop
+///    body
+/// - `ivs`: the induction variables for the surrounding loops.
+/// - `regionIterArgs`: the basic block arguments of the inner most loop that
+///   correspond to the init/result of the loop.
+/// - `resultOffsets/Sizes`: The call back returns the result of tiling the
+///   operation, but in `resultOffsets` and `resultSizes` the callback function
+///   is expected to populate the offsets/sizes to use while inserting the
+///   result back into the corresponding `regionIterArgs`. These values
+///   are used by the loop generation methods to create the appropriate yield
+///   values from the inner most loop.
+/// The call back returns the `TilingResult` obtained from tiling the operation.
+using TileLoopBodyFn = std::function<FailureOr<TilingResult>(
+    RewriterBase &rewriter, Location loc, ValueRange ivs,
+    ValueRange regionIterArgs,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
+
+/// Generate the tile-loop nest using `scf.for` operation.
 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
-/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
-///   the tile processed within the inner most loop.
-/// Note that this methods adds `scf.yield` operation for all but the innermost
-/// loop. These yield the value returned by the immediately inner loop. The
-/// caller is expected to add the scf.yield operation for the innermost loop.
-static SmallVector<scf::ForOp> generateTileLoopNest(
-    OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
-    ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
-    SmallVector<OpFoldResult> &sizes, ValueRange destinationTensors = {}) {
-  if (loopRanges.empty())
-    return {};
+/// - `destinationTensors` are the init values to use for the outer most loop.
+/// - `tileLoopBodyFn` is called to generated the loop body of the inner most
+///    loop.
+/// - `loops` is an in-out parameter into which the generated loops are
+///    populated.
+/// The `TilingResult` returned by calling `tileLoopBodyFn` is returned back
+/// to the caller.
+static FailureOr<TilingResult> generateLoopNestUsingForOp(
+    RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
+    ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
+    TileLoopBodyFn tileLoopBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
+  assert(!loopRanges.empty() && "unexpected empty loop ranges");
   assert(loopRanges.size() == tileSizes.size() &&
          "expected as many tile sizes as loop ranges");
-  OpBuilder::InsertionGuard guard(builder);
-  SmallVector<scf::ForOp> loops;
-  offsets.resize(loopRanges.size());
-  sizes.resize(loopRanges.size());
-
-  for (auto loopRange : llvm::enumerate(loopRanges)) {
-    Value offset =
-        getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset);
-    Value size =
-        getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size);
-    Value tileSize = getValueOrCreateConstantIndexOp(
-        builder, loc, tileSizes[loopRange.index()]);
+  OpBuilder::InsertionGuard guard(rewriter);
+  SmallVector<Value> ivs;
+
+  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
     // No loops if tile size is zero. Set offset and size to the loop
     // offset and size.
-    if (matchPattern(tileSize, m_Zero())) {
-      offsets[loopRange.index()] = offset;
-      sizes[loopRange.index()] = size;
+    if (isConstantIntValue(tileSize, 0)) {
       continue;
     }
 
-    auto loop = builder.create<scf::ForOp>(
-        loc, offset, size, tileSize, destinationTensors,
-        [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
-            ValueRange /*iterArgs*/) {
-          sizes[loopRange.index()] =
-   ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/77874

