[Mlir-commits] [mlir] [mlir][SCF] Allow using a custom operation to generate loops with `mlir::tileUsingSCF`. (PR #159660)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 18 15:02:01 PDT 2025


https://github.com/MaheshRavishankar created https://github.com/llvm/llvm-project/pull/159660

This change adds an option to use a custom operation to generate the
inter-tile loops during tiling. When the loop type is set to
scf::SCFTilingOptions::LoopType::CustomOp, the method
mlir::tileUsingSCF provides two callback functions:

- The first one generates the header of the loop.
- The second one generates the terminator of the loop.
These methods receive the information needed to generate the
loops/terminator and expect to return information needed to generate
the code for the intra-tile computation. See comments for more
details.

Presently this adds support only for tiling. Subsequent commits
will update this to add support for fusion as well.

The PR is split into two commits.

The first commit is an NFC that just refactors the code (and cleans up some naming) to make it easier to add the support for custom loop operations.
The second commit adds the support for using a custom loop operation, as well as a test to exercise this path.

Note that this is a duplicate of https://github.com/llvm/llvm-project/pull/159506 that was accidentally committed and was reverted in https://github.com/llvm/llvm-project/pull/159598 to wait for reviews.

Signed-off-by: MaheshRavishankar [mahesh.ravishankar at gmail.com](mailto:mahesh.ravishankar at gmail.com)

>From d4140b3561ed0f9f10d9b96bb92d70ad61cf90b1 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 17 Sep 2025 21:47:46 -0700
Subject: [PATCH 1/2] [mlir][SCF] NFC refactor for better demarcation of
 splitting to use different loop types for tiling.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 .../SCF/Transforms/TileUsingInterface.h       |  34 +-
 .../SCF/Transforms/TileUsingInterface.cpp     | 455 ++++++++++--------
 2 files changed, 279 insertions(+), 210 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 3205da6e448fc..117e1ce1371f2 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -33,6 +33,14 @@ using SCFTileSizeComputationFunction =
 
 /// Options to use to control tiling.
 struct SCFTilingOptions {
+  /// Specify which loop construct to use for tile and fuse.
+  enum class LoopType { ForOp, ForallOp};
+  LoopType loopType = LoopType::ForOp;
+  SCFTilingOptions &setLoopType(LoopType type) {
+    loopType = type;
+    return *this;
+  }
+
   /// Computation function that returns the tile sizes to use for each loop.
   /// Returning a tile size of zero implies no tiling for that loop. If the
   /// size of the returned vector is smaller than the number of loops, the inner
@@ -50,6 +58,17 @@ struct SCFTilingOptions {
   /// proper interaction with folding.
   SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
 
+  /// The interchange vector to reorder the tiled loops.
+  SmallVector<int64_t> interchangeVector = {};
+  SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
+    interchangeVector = llvm::to_vector(interchange);
+    return *this;
+  }
+
+  //-------------------------------------------------------------------------//
+  // Options related to tiling using `scf.forall`.
+  //-------------------------------------------------------------------------//
+
   /// Computation function that returns the number of threads to use for
   /// each loop. Returning a num threads of zero implies no tiling for that
   /// loop. If the size of the returned vector is smaller than the number of
@@ -70,21 +89,6 @@ struct SCFTilingOptions {
   /// function that computes num threads at the point they are needed.
   SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);
 
-  /// The interchange vector to reorder the tiled loops.
-  SmallVector<int64_t> interchangeVector = {};
-  SCFTilingOptions &setInterchange(ArrayRef<int64_t> interchange) {
-    interchangeVector = llvm::to_vector(interchange);
-    return *this;
-  }
-
-  /// Specify which loop construct to use for tile and fuse.
-  enum class LoopType { ForOp, ForallOp };
-  LoopType loopType = LoopType::ForOp;
-  SCFTilingOptions &setLoopType(LoopType type) {
-    loopType = type;
-    return *this;
-  }
-
   /// Specify mapping of loops to devices. This is only respected when the loop
   /// constructs support such a mapping (like `scf.forall`). Will be ignored
   /// when using loop constructs that dont support such a mapping (like
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 834c02126fa53..b77f66b701927 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -155,18 +155,18 @@ getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
 static LogicalResult checkTileSizes(TilingInterface op,
                                     scf::SCFTilingOptions::LoopType loopType,
                                     ReductionTilingStrategy reductionStrategy,
-                                    ArrayRef<OpFoldResult> tileSizes,
+                                    ArrayRef<OpFoldResult> givenTileSizes,
                                     ArrayRef<OpFoldResult> numThreads) {
   auto iterators = op.getLoopIteratorTypes();
-  assert(iterators.size() == tileSizes.size() &&
+  assert(iterators.size() == givenTileSizes.size() &&
          "expected as many tile size values as number of loops");
   assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
          "when specified, expected number of threads to use for each loop");
 
   bool isParallelTiling = false;
-  for (auto [index, iterator, tileSize] :
-       llvm::enumerate(iterators, tileSizes)) {
-    if (!isConstantIntValue(tileSize, 0)) {
+  for (auto [index, iterator, givenTileSize] :
+       llvm::enumerate(iterators, givenTileSizes)) {
+    if (!isConstantIntValue(givenTileSize, 0)) {
       isParallelTiling |= iterator == utils::IteratorType::parallel;
     }
 
@@ -186,7 +186,7 @@ static LogicalResult checkTileSizes(TilingInterface op,
       }
 
       if (std::optional<int64_t> constTileSize =
-              getConstantIntValue(tileSize)) {
+              getConstantIntValue(givenTileSize)) {
         if (constTileSize.value() > 0 &&
             iterator != utils::IteratorType::parallel) {
           op.emitWarning() << "tiling is not thread safe at axis #" << index;
@@ -207,11 +207,11 @@ static LogicalResult checkTileSizes(TilingInterface op,
 /// Get the reduction dims that are tiled. This accounts for reduction dims
 /// that are specified as tiled, but the tile size is 0.
 static SetVector<unsigned>
-getSanitizedReductionDims(ArrayRef<OpFoldResult> tileSizes,
+getSanitizedReductionDims(ArrayRef<OpFoldResult> givenTileSizes,
                           const scf::SCFTilingOptions &options) {
   SetVector<unsigned> reductionDims;
   for (auto dim : options.reductionDims) {
-    if (isConstantIntValue(tileSizes[dim], 0))
+    if (isConstantIntValue(givenTileSizes[dim], 0))
       continue;
     reductionDims.insert(dim);
   }
@@ -236,14 +236,14 @@ static bool tileDividesIterationDomain(Range loopRange) {
 /// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
 static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
                                        Range loopRange, OpFoldResult offset,
-                                       OpFoldResult tileSize) {
-  std::optional<int64_t> ts = getConstantIntValue(tileSize);
+                                       OpFoldResult givenTileSize) {
+  std::optional<int64_t> ts = getConstantIntValue(givenTileSize);
   if (ts && ts.value() == 1)
-    return tileSize;
+    return givenTileSize;
 
   if (tileDividesIterationDomain(
-          Range{loopRange.offset, loopRange.size, tileSize}))
-    return tileSize;
+          Range{loopRange.offset, loopRange.size, givenTileSize}))
+    return givenTileSize;
 
   // The tile size to use (to avoid out of bounds access) is  minimum of
   // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled
@@ -254,15 +254,15 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
   AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
   Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
   return affine::makeComposedFoldedAffineMin(
-      b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize});
+      b, loc, minMap, SmallVector<OpFoldResult>{offset, size, givenTileSize});
 }
 
 /// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
 /// than `iterationSize`.
-static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
+static bool canOmitTileOffsetInBoundsCheck(OpFoldResult givenTileSize,
                                            OpFoldResult numThreads,
                                            OpFoldResult iterationSize) {
-  std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
+  std::optional<int64_t> tileSizeConst = getConstantIntValue(givenTileSize);
   std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
   std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
   if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
@@ -274,114 +274,51 @@ static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
 /// `offset`s and `size`s of the tile of the iteration space that the
 /// innermost loop body of the generated tiled loops corresponds to.
 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
-getTileOffsetAndSizes(RewriterBase &rewriter, Location loc,
-                      ReductionTilingStrategy strategy, ValueRange ivs,
+getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
                       ArrayRef<Range> iterationDomain,
-                      ArrayRef<OpFoldResult> tileSizes,
-                      ArrayRef<OpFoldResult> numThreads,
-                      const llvm::SetVector<unsigned> &reductionDims) {
+                      ArrayRef<OpFoldResult> givenTileSizes) {
   SmallVector<OpFoldResult> offsets, sizes;
   int materializedLoopNum = 0;
-
-  if (!numThreads.empty()) {
-    AffineExpr d0, d1, s0, s1;
-    AffineExpr offsetExpr, residualTileSizeExpr;
-    bindDims(rewriter.getContext(), d0, d1);
-    bindSymbols(rewriter.getContext(), s0, s1);
-    offsetExpr = d0 + d1 * s0;
-    residualTileSizeExpr = s1 - (d0 + d1 * s0);
-
-    for (auto [index, nt, tileSize, loopRange] :
-         llvm::enumerate(numThreads, tileSizes, iterationDomain)) {
-
-      // Non-tiled cases, set the offset and size to the
-      // `loopRange.offset/size`.
-      if (isZeroInteger(nt)) {
-        offsets.push_back(loopRange.offset);
-        sizes.push_back(loopRange.size);
-        continue;
-      }
-
-      Value iv = ivs[materializedLoopNum++];
-      OpFoldResult offset = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, offsetExpr,
-          ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
-      OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, residualTileSizeExpr,
-          {loopRange.offset, nt, tileSize, loopRange.size});
-
-      OpFoldResult size = tileSize;
-      if (!isZeroInteger(residualTileSize)) {
-        OpFoldResult sizeMinusOffsetPerThread =
-            affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
-                                                  {offset, loopRange.size});
-        size = affine::makeComposedFoldedAffineMin(
-            rewriter, loc,
-            AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
-            {sizeMinusOffsetPerThread, tileSize});
-      }
-
-      // Consider the case where the original loop was `[0, 100)`.
-      // If number of threads are `7`, the tile size would be computed as
-      // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6)
-      // - `offset = 0 + 6 * 15 = 105`
-      // - `tileSize = min(15, 100 - 105) = -5`
-      // To avoid negative tile sizes, we need to do a further
-      // `nonNegativeTileSize = affine.max(0, tileSize)`.
-      // This `max` can be avoided if
-      //  `offset + tileSize * (numThreads - 1) < (ub - lb)`
-      if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
-        AffineMap maxMap =
-            AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
-        size = affine::makeComposedFoldedAffineMax(
-            rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
-      }
-
-      offsets.push_back(offset);
-      sizes.push_back(size);
+  for (auto [givenTileSize, loopRange] :
+       llvm::zip_equal(givenTileSizes, iterationDomain)) {
+
+    // Non-tiled cases, set the offset and size to the
+    // `loopRange.offset/size`.
+    if (isZeroInteger(givenTileSize)) {
+      offsets.push_back(loopRange.offset);
+      sizes.push_back(loopRange.size);
+      continue;
     }
-    return {offsets, sizes};
-  } else {
-    for (auto [tileSize, loopRange] :
-         llvm::zip_equal(tileSizes, iterationDomain)) {
-
-      // Non-tiled cases, set the offset and size to the
-      // `loopRange.offset/size`.
-      if (isZeroInteger(tileSize)) {
-        offsets.push_back(loopRange.offset);
-        sizes.push_back(loopRange.size);
-        continue;
-      }
 
-      Value iv = ivs[materializedLoopNum++];
-      OpFoldResult offset = getAsOpFoldResult(iv);
-      offsets.push_back(offset);
-      OpFoldResult size =
-          getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
-      sizes.push_back(size);
-    }
-    return {offsets, sizes};
+    Value iv = ivs[materializedLoopNum++];
+    OpFoldResult offset = getAsOpFoldResult(iv);
+    offsets.push_back(offset);
+    OpFoldResult size =
+        getBoundedTileSize(rewriter, loc, loopRange, offset, givenTileSize);
+    sizes.push_back(size);
   }
+  return {offsets, sizes};
 }
 
 /// Function to return the bounds of the loops to be generated.
 static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
                   SmallVector<OpFoldResult>>
 getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
-              ArrayRef<OpFoldResult> tileSizes) {
+              ArrayRef<OpFoldResult> givenTileSizes) {
   SmallVector<OpFoldResult> lbs, ubs, steps;
-  for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
+  for (auto [loopRange, givenTileSize] :
+       llvm::zip_equal(loopRanges, givenTileSizes)) {
     // No loop if the tile size is 0.
-    if (isZeroInteger(tileSize))
+    if (isZeroInteger(givenTileSize))
       continue;
     lbs.push_back(loopRange.offset);
     ubs.push_back(loopRange.size);
-    steps.push_back(tileSize);
+    steps.push_back(givenTileSize);
   }
   return {lbs, ubs, steps};
 }
 
-/// A function that allows returning additional yielded values during
+/// Typedef for function that allows returning additional yielded values during
 /// `yieldTiledValuesAndReplace`.
 /// - `ivs` induction variable for the loop.
 /// - `newBbArgs` basic block arguments corresponding to newly added iter_args.
@@ -402,6 +339,30 @@ using YieldTiledValuesFn = std::function<LogicalResult(
     SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
     SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
 
+/// Typedef for function that implements the body of a tiled loop.
+/// - `ivs` induction variable for the loop.
+/// - `tileOffsets` represents offsets for the tiled iteration space.
+/// - `tileSizes` represents the sizes for the tiled iteration space.
+/// - `outerDestinationTensors` tensors that hold the results. Same size
+///   as the destination operands of the original operation.
+/// - `tiledResults` results of the tiled computation, corresponding to
+///   tiles of the original operation computed by the loop body.
+///   Should be the same size as `outerDestinationTensors`.
+/// - `resultOffsets` is of the same size as `tiledResults` and represents
+///   the offset to use when writing the corresponding element from
+///   `tiledResults` into `outerDestinationTensors`.
+/// - `resultSizes` is of the same size as `tiledResults` and represents
+///   the size to use when writing the corresponding element from
+///   `tiledResults` into `outerDestinationTensors`.
+/// In case the method needs to return `failure()` the method is expected
+/// to clean up any inserted operations.
+using GenerateTiledBodyFn = std::function<LogicalResult(
+    RewriterBase &rewriter, Location Loc, ValueRange ivs,
+    ArrayRef<OpFoldResult> tileOffsets, ArrayRef<OpFoldResult> tileSizes,
+    ValueRange outerDestinationTensors, SmallVector<Value> &tiledResults,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes)>;
+
 /// Clones the operation and updates the destination if the operation
 /// implements the `DestinationStyleOpInterface`.
 static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
@@ -417,26 +378,25 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
 
 /// Generate the tile-loop nest using `scf.for` operation.
 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
-/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
-/// - `destinationTensors` are the init values to use for the outer most loop.
-/// - `yieldTiledValuesFn` is called to generated the loop body of the inner
+/// - `givenTileSizes` is the tile sizes to use. Zero represent untiled loops.
+/// - `outerDestinationTensors` are the init values to use for the outer most
+/// loop.
+/// - `tiledBodyFn` is called to generate the loop body of the inner
 /// most
 ///    loop.
-/// - `loops` is an in-out parameter into which the generated loops are
-///    populated.
-static LogicalResult generateLoopNestUsingForOp(
+/// Returns the generated `scf.for` loops on success.
+static FailureOr<SmallVector<LoopLikeOpInterface>> generateLoopNestUsingForOp(
     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
-    ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
-    YieldTiledValuesFn yieldTiledValuesFn,
-    SmallVector<LoopLikeOpInterface> &loops) {
+    ArrayRef<OpFoldResult> givenTileSizes, ValueRange outerDestinationTensors,
+    GenerateTiledBodyFn tiledBodyFn) {
   assert(!loopRanges.empty() && "unexpected empty loop ranges");
-  assert(loopRanges.size() == tileSizes.size() &&
+  assert(loopRanges.size() == givenTileSizes.size() &&
          "expected as many tile sizes as loop ranges");
   OpBuilder::InsertionGuard guard(rewriter);
 
   SmallVector<OpFoldResult> lbs, ubs, steps;
   std::tie(lbs, ubs, steps) =
-      getLoopBounds(rewriter, loc, loopRanges, tileSizes);
+      getLoopBounds(rewriter, loc, loopRanges, givenTileSizes);
   SmallVector<Value> lbVals =
       getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
   SmallVector<Value> ubVals =
@@ -445,34 +405,42 @@ static LogicalResult generateLoopNestUsingForOp(
       getValueOrCreateConstantIndexOp(rewriter, loc, steps);
 
   SmallVector<Value> ivs;
+  SmallVector<LoopLikeOpInterface> loops;
+  ValueRange innerDestinationTensors(outerDestinationTensors);
   for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
     auto loop =
-        scf::ForOp::create(rewriter, loc, lb, ub, step, destinationTensors,
+        scf::ForOp::create(rewriter, loc, lb, ub, step, innerDestinationTensors,
                            [](OpBuilder &bodyBuilder, Location bodyLoc,
                               Value iv, ValueRange /*iterArgs*/) {});
     loops.push_back(loop);
     ivs.push_back(loop.getInductionVar());
     rewriter.setInsertionPointToEnd(loop.getBody());
-    destinationTensors = loop.getRegionIterArgs();
+    innerDestinationTensors = loop.getRegionIterArgs();
   }
 
+  // Compute the `offsets` and `sizes` to use for tiling.
+  SmallVector<OpFoldResult> offsets, sizes;
+  std::tie(offsets, sizes) =
+      getTileOffsetAndSizes(rewriter, loc, ivs, loopRanges, givenTileSizes);
+
   SmallVector<Value> tiledResults;
   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
-  if (failed(yieldTiledValuesFn(rewriter, loc, ivs, destinationTensors,
-                                tiledResults, resultOffsets, resultSizes))) {
+  if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes,
+                         innerDestinationTensors, tiledResults, resultOffsets,
+                         resultSizes))) {
     return rewriter.notifyMatchFailure(
         loc, "failed to generate inner tile loop body");
   }
   if (loops.empty())
-    return success();
+    return loops;
 
-  assert(tiledResults.size() == destinationTensors.size() &&
+  assert(tiledResults.size() == innerDestinationTensors.size() &&
          "Number of results of body should be equal to number of iter args");
 
   // 6. Yield all the results of the tiled operation.
   SmallVector<Value> yieldedValues;
   for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
-       llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
+       llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets,
                        resultSizes)) {
     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                            rewriter.getIndexAttr(1));
@@ -491,27 +459,108 @@ static LogicalResult generateLoopNestUsingForOp(
         cast<scf::ForOp>(outerLoop.getOperation()).getBody());
     scf::YieldOp::create(rewriter, outerLoop.getLoc(), innerLoop->getResults());
   }
-  return success();
+  return loops;
+}
+
+/// Compute the `OpFoldResult`s that represent the multi-dimensional
+/// `offset`s and `size`s of the tile of the iteration space that the
+/// innermost loop body of the generated tiled loops corresponds to
+/// when tiling using the `scf.forall` op. This is handled separately due
+/// to the special-case handling needed when the tiling is done by
+/// specifying the number of threads.
+static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
+getTileOffsetAndSizesWithForAllOp(RewriterBase &rewriter, Location loc,
+                                  ValueRange ivs,
+                                  ArrayRef<Range> iterationDomain,
+                                  ArrayRef<OpFoldResult> givenTileSizes,
+                                  ArrayRef<OpFoldResult> numThreads) {
+  if (numThreads.empty()) {
+    return getTileOffsetAndSizes(rewriter, loc, ivs, iterationDomain,
+                                 givenTileSizes);
+  }
+
+  SmallVector<OpFoldResult> offsets, sizes;
+  int materializedLoopNum = 0;
+
+  AffineExpr d0, d1, s0, s1;
+  AffineExpr offsetExpr, residualTileSizeExpr;
+  bindDims(rewriter.getContext(), d0, d1);
+  bindSymbols(rewriter.getContext(), s0, s1);
+  offsetExpr = d0 + d1 * s0;
+  residualTileSizeExpr = s1 - (d0 + d1 * s0);
+
+  for (auto [index, nt, givenTileSize, loopRange] :
+       llvm::enumerate(numThreads, givenTileSizes, iterationDomain)) {
+
+    // Non-tiled cases, set the offset and size to the
+    // `loopRange.offset/size`.
+    if (isZeroInteger(nt)) {
+      offsets.push_back(loopRange.offset);
+      sizes.push_back(loopRange.size);
+      continue;
+    }
+
+    Value iv = ivs[materializedLoopNum++];
+    OpFoldResult offset = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, offsetExpr,
+        ArrayRef<OpFoldResult>{loopRange.offset, iv, givenTileSize});
+    OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
+        rewriter, loc, residualTileSizeExpr,
+        {loopRange.offset, nt, givenTileSize, loopRange.size});
+
+    OpFoldResult size = givenTileSize;
+    if (!isZeroInteger(residualTileSize)) {
+      OpFoldResult sizeMinusOffsetPerThread =
+          affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
+                                                {offset, loopRange.size});
+      size = affine::makeComposedFoldedAffineMin(
+          rewriter, loc,
+          AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
+          {sizeMinusOffsetPerThread, givenTileSize});
+    }
+
+    // Consider the case where the original loop was `[0, 100)`.
+    // If number of threads are `7`, the tile size would be computed as
+    // `ceilDiv(100, 7) = 15`. For the last thread (thread_id = 6)
+    // - `offset = 0 + 6 * 15 = 105`
+    // - `tileSize = min(15, 100 - 105) = -5`
+    // To avoid negative tile sizes, we need to do a further
+    // `nonNegativeTileSize = affine.max(0, tileSize)`.
+    // This `max` can be avoided if
+    //  `offset + tileSize * (numThreads - 1) < (ub - lb)`
+    if (!canOmitTileOffsetInBoundsCheck(givenTileSize, nt, loopRange.size)) {
+      AffineMap maxMap =
+          AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
+      size = affine::makeComposedFoldedAffineMax(
+          rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
+    }
+
+    offsets.push_back(offset);
+    sizes.push_back(size);
+  }
+  return {offsets, sizes};
 }
 
 /// Generate the tile-loop nest using `scf.forall` operation.
 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
-/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
-/// - `destinationTensors` are the init values to use for the outer most loop.
+/// - `givenTileSizes` is the tile sizes to use. Zero represents untiled loops.
+/// - `outerDestinationTensors` are the init values to use for the loop.
 /// - `mappingVector` is the mapping attributes to use for loop construction.
 ///   Can be empty.
-/// - `yieldTiledValuesFn` is called to generated the loop body of the inner
+/// - `tiledBodyFn` is called to generate the loop body of the inner
 /// most
 ///    loop.
-/// - `loops` is an in-out parameter into which the generated loops are
-///    populated.
-static LogicalResult generateLoopNestUsingForallOp(
-    RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
-    ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
-    ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
-    YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
+/// Returns the generated `scf.forall` loop on success.
+static FailureOr<SmallVector<LoopLikeOpInterface>>
+generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc,
+                              ArrayRef<Range> loopRanges,
+                              ArrayRef<OpFoldResult> givenTileSizes,
+                              ArrayRef<OpFoldResult> numThreads,
+                              ArrayRef<Attribute> mappingVector,
+                              ValueRange outerDestinationTensors,
+                              GenerateTiledBodyFn tiledBodyFn) {
   assert(!loopRanges.empty() && "unexpected empty loop ranges");
-  assert(loopRanges.size() == tileSizes.size() &&
+  assert(loopRanges.size() == givenTileSizes.size() &&
          "expected as many tile sizes as loop ranges");
   OpBuilder::InsertionGuard guard(rewriter);
 
@@ -522,6 +571,7 @@ static LogicalResult generateLoopNestUsingForallOp(
   scf::ForallOp forallOp;
   bool useNumThreads = !numThreads.empty();
 
+  SmallVector<LoopLikeOpInterface> loops;
   if (useNumThreads) {
     // Prune the zero numthreads.
     SmallVector<OpFoldResult> nonZeroNumThreads;
@@ -531,29 +581,35 @@ static LogicalResult generateLoopNestUsingForallOp(
       nonZeroNumThreads.push_back(nt);
     }
     forallOp = scf::ForallOp::create(rewriter, loc, nonZeroNumThreads,
-                                     destinationTensors, mappingAttr);
+                                     outerDestinationTensors, mappingAttr);
   } else {
     SmallVector<OpFoldResult> lbs, ubs, steps;
     std::tie(lbs, ubs, steps) =
-        getLoopBounds(rewriter, loc, loopRanges, tileSizes);
+        getLoopBounds(rewriter, loc, loopRanges, givenTileSizes);
     forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps,
-                                     destinationTensors, mappingAttr);
+                                     outerDestinationTensors, mappingAttr);
   }
   loops.push_back(forallOp);
 
   rewriter.setInsertionPoint(forallOp.getTerminator());
-  destinationTensors = forallOp.getRegionOutArgs();
+  ValueRange innerDestinationTensors = forallOp.getRegionOutArgs();
+  SmallVector<Value> ivs = forallOp.getInductionVars();
+
+  // Compute the `offsets` and `sizes` to use for tiling.
+  SmallVector<OpFoldResult> offsets, sizes;
+  std::tie(offsets, sizes) = getTileOffsetAndSizesWithForAllOp(
+      rewriter, loc, ivs, loopRanges, givenTileSizes, numThreads);
 
   SmallVector<Value> tiledResults;
   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
-  if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
-                         destinationTensors, tiledResults, resultOffsets,
+  if (failed(tiledBodyFn(rewriter, loc, ivs, offsets, sizes,
+                         innerDestinationTensors, tiledResults, resultOffsets,
                          resultSizes)))
     return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
 
   rewriter.setInsertionPointToEnd(forallOp.getTerminator().getBody());
   for (auto [tiledValue, destinationTensor, resultOffset, resultSize] :
-       llvm::zip_equal(tiledResults, destinationTensors, resultOffsets,
+       llvm::zip_equal(tiledResults, innerDestinationTensors, resultOffsets,
                        resultSizes)) {
     SmallVector<OpFoldResult> resultStride(resultOffset.size(),
                                            rewriter.getIndexAttr(1));
@@ -562,41 +618,48 @@ static LogicalResult generateLoopNestUsingForallOp(
                                           destinationTensor, resultOffset,
                                           resultSize, resultStride);
   }
-  return success();
+  return loops;
 }
 
 /// Generate the tile-loop nest using the loop construct specifed in `options`.
 /// - `options`: Tiling options specified.
 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
 /// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
-/// - `destinationTensors` are the init values to use for the outer most loop.
+/// - `destinationTensors` are the init values to use for the outer most
+///   loop.
 /// - `yieldTiledValuesFn` is called to generated the loop body of the inner
 /// most
 ///    loop.
-/// - `loops` is an in-out parameter into which the generated loops are
-///    populated.
-static LogicalResult generateLoopNest(
-    RewriterBase &rewriter, Location loc,
-    scf::SCFTilingOptions::LoopType loopType, ArrayRef<Range> loopRanges,
-    ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
-    ValueRange destinationTensors, ArrayRef<Attribute> mappingVector,
-    YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
+/// Returns the generated loops on success.
+static FailureOr<SmallVector<LoopLikeOpInterface>> generateLoopNest(
+    RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
+    ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> givenTileSizes,
+    ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
+    GenerateTiledBodyFn tiledBodyFn) {
   // If the tile sizes are all zero, no loops are generated. Just call the
   // callback function to handle untiled case.
-  if (llvm::all_of(tileSizes, isZeroInteger)) {
+  if (llvm::all_of(givenTileSizes, isZeroInteger)) {
     SmallVector<Value> tiledResults;
     SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
-    return tiledBodyFn(rewriter, loc, ValueRange{}, destinationTensors,
-                       tiledResults, resultOffsets, resultSizes);
+    auto tileOffsets =
+        llvm::map_to_vector(loopRanges, [](Range r) { return r.offset; });
+    auto tileSizes =
+        llvm::map_to_vector(loopRanges, [](Range r) { return r.size; });
+    if (failed(tiledBodyFn(rewriter, loc, ValueRange{}, tileOffsets, tileSizes,
+                           destinationTensors, tiledResults, resultOffsets,
+                           resultSizes))) {
+      return failure();
+    }
+    return SmallVector<LoopLikeOpInterface>{};
   }
-  if (loopType == scf::SCFTilingOptions::LoopType::ForOp) {
-    return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
-                                      destinationTensors, tiledBodyFn, loops);
+  if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
+    return generateLoopNestUsingForOp(rewriter, loc, loopRanges, givenTileSizes,
+                                      destinationTensors, tiledBodyFn);
   }
-  if (loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
+  if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
     return generateLoopNestUsingForallOp(
-        rewriter, loc, loopRanges, tileSizes, numThreads, mappingVector,
-        destinationTensors, tiledBodyFn, loops);
+        rewriter, loc, loopRanges, givenTileSizes, numThreads,
+        options.mappingVector, destinationTensors, tiledBodyFn);
   }
   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
 }
@@ -604,7 +667,7 @@ static LogicalResult generateLoopNest(
 static FailureOr<SmallVector<Value>> createInitialTensorsForTiling(
     RewriterBase &rewriter, TilingInterface op,
     ReductionTilingStrategy reductionStrategy, ArrayRef<Range> iterationDomain,
-    ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> tileSizes,
+    ArrayRef<OpFoldResult> numThreads, ArrayRef<OpFoldResult> givenTileSizes,
     const SetVector<unsigned> &reductionDims) {
   SmallVector<Value> initTensors;
   Location loc = op->getLoc();
@@ -626,7 +689,7 @@ static FailureOr<SmallVector<Value>> createInitialTensorsForTiling(
   AffineExpr sizeExpr = ((s0 - s1).ceilDiv(s2));
   AffineExpr divExpr = s0.ceilDiv(s1);
   for (auto [index, domain, tileSize] :
-       llvm::enumerate(iterationDomain, tileSizes)) {
+       llvm::enumerate(iterationDomain, givenTileSizes)) {
     if (!numThreads.empty()) {
       // Untiled case.
       if (isConstantIntValue(numThreads[index], 0)) {
@@ -672,7 +735,7 @@ static SmallVector<OpFoldResult>
 getSplitReductionIvs(RewriterBase &rewriter, Location loc,
                      ReductionTilingStrategy reductionStrategy, ValueRange ivs,
                      ArrayRef<OpFoldResult> numThreads,
-                     ArrayRef<OpFoldResult> tileSizes,
+                     ArrayRef<OpFoldResult> givenTileSizes,
                      const SetVector<unsigned> &reductionDims) {
   SmallVector<OpFoldResult> splitReductionIvs;
   splitReductionIvs.resize(reductionDims.size(), rewriter.getIndexAttr(0));
@@ -689,7 +752,7 @@ getSplitReductionIvs(RewriterBase &rewriter, Location loc,
       }
       splitReductionIvs[index] = affine::makeComposedFoldedAffineApply(
           rewriter, loc, divExpr,
-          ArrayRef<OpFoldResult>{ivs[ivIndex++], tileSizes[reductionDim]});
+          ArrayRef<OpFoldResult>{ivs[ivIndex++], givenTileSizes[reductionDim]});
     }
   }
   return splitReductionIvs;
@@ -701,7 +764,7 @@ getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
                        ValueRange regionIterArg, ArrayRef<OpFoldResult> offsets,
                        ArrayRef<OpFoldResult> sizes, ValueRange ivs,
                        ArrayRef<OpFoldResult> numThreads,
-                       ArrayRef<OpFoldResult> tileSizes,
+                       ArrayRef<OpFoldResult> givenTileSizes,
                        const SetVector<unsigned> &reductionDims) {
   if (reductionStrategy == ReductionTilingStrategy::FullReduction) {
     return op.getTiledImplementation(rewriter, offsets, sizes);
@@ -717,7 +780,7 @@ getTiledImplementation(RewriterBase &rewriter, TilingInterface op,
 
   SmallVector<OpFoldResult> splitReductionIvs =
       getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs,
-                           numThreads, tileSizes, reductionDims);
+                           numThreads, givenTileSizes, reductionDims);
   return redOp.tileToPartialReduction(rewriter, op.getLoc(), reductionStrategy,
                                       regionIterArg, offsets, sizes,
                                       reductionDims, splitReductionIvs);
@@ -728,7 +791,8 @@ static LogicalResult getResultTilePosition(
     int64_t index, Value tiledResult, TilingInterface op,
     ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
     ValueRange ivs, ArrayRef<OpFoldResult> numThreads,
-    ArrayRef<OpFoldResult> tileSizes, const SetVector<unsigned> &reductionDims,
+    ArrayRef<OpFoldResult> givenTileSizes,
+    const SetVector<unsigned> &reductionDims,
     SmallVector<OpFoldResult> &resultOffset,
     SmallVector<OpFoldResult> &resultSize) {
 
@@ -744,7 +808,7 @@ static LogicalResult getResultTilePosition(
   }
   SmallVector<OpFoldResult> splitReductionIvs =
       getSplitReductionIvs(rewriter, op.getLoc(), reductionStrategy, ivs,
-                           numThreads, tileSizes, reductionDims);
+                           numThreads, givenTileSizes, reductionDims);
   return redOp.getPartialResultTilePosition(
       rewriter, index, reductionStrategy, offsets, sizes, reductionDims,
       splitReductionIvs, resultOffset, resultSize);
@@ -999,20 +1063,20 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
 
   // 2. Materialize the tile sizes and/or number of threads;
-  SmallVector<OpFoldResult> tileSizes, numThreads;
-  std::tie(tileSizes, numThreads) =
+  SmallVector<OpFoldResult> givenTileSizes, numThreads;
+  std::tie(givenTileSizes, numThreads) =
       getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
 
   // Check if it is safe to tile. This is hold over from previous iterations
   // of tile to for-all. Consider dropping it.
   if (failed(checkTileSizes(op, options.loopType, options.reductionStrategy,
-                            tileSizes, numThreads))) {
+                            givenTileSizes, numThreads))) {
     return failure();
   }
 
   // Get the reduction dims
   SetVector<unsigned> reductionDims =
-      getSanitizedReductionDims(tileSizes, options);
+      getSanitizedReductionDims(givenTileSizes, options);
 
   // 3. If there is an interchange specified, permute the iteration domain and
   // the tile sizes.
@@ -1024,7 +1088,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
            "expected interchange vector to be a permutation");
 
     applyPermutationToVector(iterationDomain, interchangeVector);
-    applyPermutationToVector(tileSizes, interchangeVector);
+    applyPermutationToVector(givenTileSizes, interchangeVector);
     if (!numThreads.empty())
       applyPermutationToVector(numThreads, interchangeVector);
   }
@@ -1032,24 +1096,21 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   FailureOr<TilingResult> tilingResult;
   // 4. Define the lambda function used later to generate the body of the
   // innermost tiled loop.
-  YieldTiledValuesFn innerYieldTiledValuesFn =
+  GenerateTiledBodyFn innerYieldTiledValuesFn =
       [&](RewriterBase &rewriter, Location loc, ValueRange ivs,
+          ArrayRef<OpFoldResult> tileOffsets, ArrayRef<OpFoldResult> tileSizes,
           ValueRange regionIterArgs, SmallVector<Value> &tiledResults,
           SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
           SmallVector<SmallVector<OpFoldResult>> &resultSizes)
       -> LogicalResult {
-    // 4a. Compute the `offsets` and `sizes` to use for tiling.
-    SmallVector<OpFoldResult> offsets, sizes;
-    std::tie(offsets, sizes) = getTileOffsetAndSizes(
-        rewriter, loc, options.reductionStrategy, ivs, iterationDomain,
-        tileSizes, numThreads, reductionDims);
-
     // 4b. If interchange was provided, apply inverse of the interchange
     //     to get back the offsets/sizes in the order to be specified.
+    SmallVector<OpFoldResult> tileOffsetsVec = llvm::to_vector(tileOffsets);
+    SmallVector<OpFoldResult> tileSizesVec = llvm::to_vector(tileSizes);
     if (!interchangeVector.empty()) {
       auto inversePermutation = invertPermutationVector(interchangeVector);
-      applyPermutationToVector(offsets, inversePermutation);
-      applyPermutationToVector(sizes, inversePermutation);
+      applyPermutationToVector(tileOffsetsVec, inversePermutation);
+      applyPermutationToVector(tileSizesVec, inversePermutation);
     }
 
     // 5. Generate the tiled implementation within the inner most loop.
@@ -1061,7 +1122,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     // 5b. Early return cloned op if tiling is not happening. We can not
     // return the original op because it could lead to `rewriter.replaceOp(op,
     // op->getResults())` and users would get crash.
-    if (llvm::all_of(tileSizes, isZeroInteger)) {
+    if (llvm::all_of(givenTileSizes, isZeroInteger)) {
       tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
       tilingResult =
           TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
@@ -1070,9 +1131,10 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     }
 
     // 5c. Tile the cloned operation.
-    tilingResult = getTiledImplementation(
-        rewriter, clonedOp, options.reductionStrategy, regionIterArgs, offsets,
-        sizes, ivs, numThreads, tileSizes, reductionDims);
+    tilingResult =
+        getTiledImplementation(rewriter, clonedOp, options.reductionStrategy,
+                               regionIterArgs, tileOffsetsVec, tileSizesVec,
+                               ivs, numThreads, givenTileSizes, reductionDims);
     if (failed(tilingResult)) {
       rewriter.eraseOp(clonedOp);
       return op.emitOpError("faild to tile operation");
@@ -1089,8 +1151,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
       SmallVector<OpFoldResult> resultOffset, resultSize;
       if (failed(getResultTilePosition(
               rewriter, options.reductionStrategy, index, tiledValue, op,
-              offsets, sizes, ivs, numThreads, tileSizes, reductionDims,
-              resultOffset, resultSize))) {
+              tileOffsetsVec, tileSizesVec, ivs, numThreads, givenTileSizes,
+              reductionDims, resultOffset, resultSize))) {
         for (auto op : tilingResult->tiledOps) {
           rewriter.eraseOp(op);
         }
@@ -1107,7 +1169,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   // 6. Find the destination tensors to use for the operation.
   FailureOr<SmallVector<Value>> maybeInits = createInitialTensorsForTiling(
       rewriter, op, options.reductionStrategy, iterationDomain, numThreads,
-      tileSizes, reductionDims);
+      givenTileSizes, reductionDims);
   if (failed(maybeInits)) {
     return rewriter.notifyMatchFailure(
         op, "unable to create initial tensors for tiling");
@@ -1116,13 +1178,16 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
 
   // 7. Generate the tiled loops nest using the callback defined above.
   SmallVector<LoopLikeOpInterface> loops;
-  if (failed(generateLoopNest(rewriter, op.getLoc(), options.loopType,
-                              iterationDomain, tileSizes, numThreads,
-                              initTensors, options.mappingVector,
-                              innerYieldTiledValuesFn, loops)))
-    return op.emitOpError("failed to generate tiling loops");
-  assert(succeeded(tilingResult) &&
-         "expected tiling result to be computed after loop generation");
+  {
+    FailureOr<SmallVector<LoopLikeOpInterface>> loopsOr = generateLoopNest(
+        rewriter, op.getLoc(), options, iterationDomain, givenTileSizes,
+        numThreads, initTensors, innerYieldTiledValuesFn);
+    if (failed(loopsOr))
+      return op.emitOpError("failed to generate tiling loops");
+    assert(succeeded(tilingResult) &&
+           "expected tiling result to be computed after loop generation");
+    std::swap(loops, loopsOr.value());
+  }
 
   if (loops.empty()) {
     // If loops are empty, the tiled op is used as the replacement for the

>From 7f636c27a0f12190d201b0902181ff9dc61801f1 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Wed, 17 Sep 2025 21:48:50 -0700
Subject: [PATCH 2/2] [mlir][SCF] Allow using a custom operation to generate
 loops with `mlir::tileUsingSCF`.

This change adds an option to use a custom operation to generate the
inter-tile loops during tiling. When the loop type is set to
`scf::SCFTilingOptions::LoopType::CustomOp`, the method
`mlir::tileUsingSCF` uses two user-provided callback functions:

1. First one to generate the header of the loop.
2. Second one to generate the terminator of the loop.

These methods receive the information needed to generate the
loops/terminator and expect to return information needed to generate
the code for the intra-tile computation. See comments for more
details.

Presently this adds support only for tiling. Subsequent commits
will update this to add support for fusion as well.

Signed-off-by: MaheshRavishankar <mahesh.ravishankar at gmail.com>
---
 .../SCF/Transforms/TileUsingInterface.h       |  94 ++++++++++-
 .../SCF/Transforms/TileUsingInterface.cpp     |  57 +++++++
 .../TilingInterface/tile-using-custom-op.mlir |  60 +++++++
 .../TestTilingInterfaceTransformOps.cpp       | 148 ++++++++++++++++++
 .../TestTilingInterfaceTransformOps.td        |  23 +++
 5 files changed, 381 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Interfaces/TilingInterface/tile-using-custom-op.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 117e1ce1371f2..6b05ade37881c 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -34,7 +34,7 @@ using SCFTileSizeComputationFunction =
 /// Options to use to control tiling.
 struct SCFTilingOptions {
   /// Specify which loop construct to use for tile and fuse.
-  enum class LoopType { ForOp, ForallOp};
+  enum class LoopType { ForOp, ForallOp, CustomOp };
   LoopType loopType = LoopType::ForOp;
   SCFTilingOptions &setLoopType(LoopType type) {
     loopType = type;
@@ -121,6 +121,98 @@ struct SCFTilingOptions {
     reductionDims.insert(dims.begin(), dims.end());
     return *this;
   }
+
+  //-------------------------------------------------------------------------//
+  // Options related to tiling using custom loop.
+  //-------------------------------------------------------------------------//
+
+  // For generating the inter-tile loops using a custom loop, two callback
+  // functions are needed
+  // 1. That generates the "loop header", i.e. the loop that iterates over the
+  //    different tiles.
+  // 2. That generates the loop terminator
+  //
+  // For `scf.forall` case the call back to generate loop header would generate
+  //
+  // ```mlir
+  // scf.forall (...) = ... {
+  //   ..
+  // }
+  // ```
+  //
+  // and the call back to generate the loop terminator would generate the
+  // `scf.in_parallel` region
+  //
+  // ```mlir
+  // scf.forall (...) = ... {
+  //   scf.in_parallel {
+  //      tensor.parallel_insert_slice ...
+  //   }
+  // }
+  // ```
+  //
+
+  // Information that is to be returned by the callback to generate the loop
+  // header, needed for the rest of the tiled code generation.
+  // - `loops`: The generated loops.
+  // - `tileOffset`: The values that represent the offset of the iteration
+  //                 space tile.
+  // - `tileSizes` : The values that represent the size of the iteration
+  //                 space tile.
+  // - `destinationTensors` : The tensors to use as destinations during tiling.
+  struct CustomLoopHeaderInfo {
+    SmallVector<LoopLikeOpInterface> loops;
+    SmallVector<OpFoldResult> tileOffset;
+    SmallVector<OpFoldResult> tileSizes;
+    SmallVector<Value> destinationTensors;
+  };
+
+  // Type of the callback function that generates the loop headers.
+  // - `loopRanges` : Values that represent the full size of the iteration space
+  //                  being tiled.
+  // - `givenTileSizes` : The tile sizes that are to be used to tile the
+  //                      iteration space (zero denotes an untiled
+  //                      dimension, as elsewhere in tiling).
+  // - `destinationTensors` : The tensors to use as destinations for the results
+  //                          of the tiled loop for loops that implement
+  //                          `DestinationStyleOpInterface`.
+  // Returns the `CustomLoopHeaderInfo` object (described above). It is expected
+  // that this function sets the insertion point of `rewriter` to the program
+  // point where the intra-tile loop computation is to be generated.
+  using GenerateLoopHeaderFn = std::function<FailureOr<CustomLoopHeaderInfo>(
+      RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
+      ArrayRef<OpFoldResult> givenTileSizes, ValueRange destinationTensors)>;
+
+  // Type of the callback function that generates the loop terminator.
+  // - `tiledResults` : Tiles of the result computed for the iteration space
+  //                    tile.
+  // - `resultOffsets` : For each of the `tiledResults`, the offset at which
+  //                     the result tile is to be "inserted" back into the
+  //                     destination tensor.
+  // - `resultSizes` : For each of the `tiledResults`, the size of the result
+  //                   tile that is to be "inserted" back into the
+  //                   destination tensor.
+  // - `destinationTensors` : The `destinationTensors` returned by the header.
+  // Returns failure if the loop terminator could not be generated.
+  using GenerateLoopTerminatorFn = std::function<LogicalResult(
+      RewriterBase &rewriter, Location loc, ValueRange tiledResults,
+      ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> resultSizes,
+      ValueRange destinationTensors)>;
+
+  // Callback function to generate the inter-tile loop header.
+  GenerateLoopHeaderFn generateLoopHeaderFn = nullptr;
+  // Callback function to generate the inter-tile loop terminator.
+  GenerateLoopTerminatorFn generateLoopTerminatorFn = nullptr;
+  // Helper function to set the callbacks for inter-tile loop header and
+  // terminator functions when using a custom operation for the loop.
+  SCFTilingOptions &
+  setCustomLoopGenerationFns(GenerateLoopHeaderFn headerFn,
+                             GenerateLoopTerminatorFn terminatorFn) {
+    generateLoopHeaderFn = std::move(headerFn);
+    generateLoopTerminatorFn = std::move(terminatorFn);
+    return *this;
+  }
 };
 
 /// Transformation information returned after tiling.
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index b77f66b701927..c3899473289e2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -621,6 +621,57 @@ generateLoopNestUsingForallOp(RewriterBase &rewriter, Location loc,
   return loops;
 }
 
+/// Generate the tile-loop nest using a custom loop operation.
+/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
+/// - `givenTileSizes` are the tile sizes to use. Zero represent untiled loops.
+/// - `outerDestinationTensors` are the init values to use for the outer most
+///   loop.
+/// - `generateLoopHeaderFn`/`generateLoopTerminatorFn` build the loop header
+///   and the loop terminator respectively; both must be non-null.
+/// - `tiledBodyFn` is called, with empty `ivs`, to generate the loop body.
+///   NOTE(review): getSplitReductionIvs indexes `ivs`; confirm unreachable.
+/// Returns the loops reported by `generateLoopHeaderFn` on success.
+static FailureOr<SmallVector<LoopLikeOpInterface>>
+generateLoopNestUsingCustomOp(
+    RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
+    ArrayRef<OpFoldResult> givenTileSizes, ValueRange outerDestinationTensors,
+    const scf::SCFTilingOptions::GenerateLoopHeaderFn &generateLoopHeaderFn,
+    const scf::SCFTilingOptions::GenerateLoopTerminatorFn
+        &generateLoopTerminatorFn,
+    GenerateTiledBodyFn tiledBodyFn) {
+  assert(!loopRanges.empty() && "unexpected empty loop ranges");
+  assert(loopRanges.size() == givenTileSizes.size() &&
+         "expected as many tile sizes as loop ranges");
+  assert(generateLoopHeaderFn && generateLoopTerminatorFn &&
+         "expected loop header/terminator generation function");
+  OpBuilder::InsertionGuard guard(rewriter);
+
+  FailureOr<scf::SCFTilingOptions::CustomLoopHeaderInfo> loopHeaderInfo =
+      generateLoopHeaderFn(rewriter, loc, loopRanges, givenTileSizes,
+                           outerDestinationTensors);
+  if (failed(loopHeaderInfo)) {
+    return failure();
+  }
+
+  SmallVector<Value> ivs;
+  SmallVector<Value> tiledResults;
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
+  if (failed(tiledBodyFn(rewriter, loc, ivs, loopHeaderInfo->tileOffset,
+                         loopHeaderInfo->tileSizes,
+                         loopHeaderInfo->destinationTensors, tiledResults,
+                         resultOffsets, resultSizes))) {
+    return failure();
+  }
+
+  if (failed(generateLoopTerminatorFn(rewriter, loc, tiledResults,
+                                      resultOffsets, resultSizes,
+                                      loopHeaderInfo->destinationTensors))) {
+    return failure();
+  }
+
+  return loopHeaderInfo->loops;
+}
+
 /// Generate the tile-loop nest using the loop construct specifed in `options`.
 /// - `options`: Tiling options specified.
 /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
@@ -661,6 +712,12 @@ static FailureOr<SmallVector<LoopLikeOpInterface>> generateLoopNest(
         rewriter, loc, loopRanges, givenTileSizes, numThreads,
         options.mappingVector, destinationTensors, tiledBodyFn);
   }
+  if (options.loopType == scf::SCFTilingOptions::LoopType::CustomOp) {
+    return generateLoopNestUsingCustomOp(
+        rewriter, loc, loopRanges, givenTileSizes, destinationTensors,
+        options.generateLoopHeaderFn, options.generateLoopTerminatorFn,
+        tiledBodyFn);
+  }
   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
 }
 
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-custom-op.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-custom-op.mlir
new file mode 100644
index 0000000000000..d335e9c3fb5d0
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-custom-op.mlir
@@ -0,0 +1,60 @@
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file --mlir-print-local-scope %s | FileCheck %s
+
+module {
+  func.func @generic_parallel(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>) -> tensor<?x?xf32> {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+    %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+    %empty = tensor.empty(%d0, %d1) : tensor<?x?xf32>
+    %generic = linalg.generic {
+        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                         affine_map<(d0, d1) -> (d1)>,
+                         affine_map<(d0, d1) -> (d0, d1)>],
+        iterator_types = ["parallel", "parallel"]}
+        ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>) outs(%empty : tensor<?x?xf32>) {
+      ^bb(%b0 : f32, %b1 : f32, %b2 : f32):
+        %add = arith.addf %b0, %b1 : f32
+        linalg.yield %add : f32
+    } -> tensor<?x?xf32>
+    return %generic : tensor<?x?xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %op = transform.structured.match ops {["linalg.generic"]} in %arg1
+        : (!transform.any_op) -> !transform.any_op
+    %tiled_op, %loop = transform.test.tile_using_custom_loop %op tile_sizes = [10, 20]
+        : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK-LABEL: func @generic_parallel
+//  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
+//  CHECK-SAME:     %[[ARG1:.+]]: tensor<?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
+//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+//   CHECK-DAG:   %[[EMPTY:.+]] = tensor.empty(%[[D0]], %[[D1]]) : tensor<?x?xf32>
+//   CHECK-DAG:   %[[NITERS0:.+]] = affine.apply affine_map<()[s0] -> (s0 ceildiv 10)>()[%[[D0]]]
+//   CHECK-DAG:   %[[NITERS1:.+]] = affine.apply affine_map<()[s0] -> (s0 ceildiv 20)>()[%[[D1]]]
+//   CHECK-DAG:   %[[NITERS:.+]] = affine.apply affine_map<()[s0, s1] -> ((s0 ceildiv 10) * (s1 ceildiv 20))>()[%[[D0]], %[[D1]]]
+//       CHECK:   %[[FOR:.+]] = scf.for %[[IV:[a-zA-Z0-9]+]] = %[[C0]] to %[[NITERS]] step %[[C1]]
+//  CHECK-SAME:       iter_args(%[[INIT:.+]] = %[[EMPTY]])
+//       CHECK:     %[[DELINEARIZE:.+]]:2 = affine.delinearize_index %[[IV]] into (%[[NITERS0]], %[[NITERS1]])
+//   CHECK-DAG:     %[[SIZE0:.+]] = affine.min affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>(%[[DELINEARIZE]]#0)[%[[D0]]]
+//   CHECK-DAG:     %[[SIZE1:.+]] = affine.min affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>(%[[DELINEARIZE]]#1)[%[[D1]]]
+//   CHECK-DAG:     %[[OFFSET0:.+]] = affine.apply affine_map<(d0) -> (d0 * 10)>(%[[DELINEARIZE]]#0)
+//   CHECK-DAG:     %[[OFFSET1:.+]] = affine.apply affine_map<(d0) -> (d0 * 20)>(%[[DELINEARIZE]]#1)
+//   CHECK-DAG:     %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[OFFSET0]], %[[OFFSET1]]] [%[[SIZE0]], %[[SIZE1]]] [1, 1]
+//   CHECK-DAG:     %[[ARG1_SLICE:.+]] = tensor.extract_slice %[[ARG1]][%[[OFFSET1]]] [%[[SIZE1]]] [1]
+//   CHECK-DAG:     %[[INIT_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[OFFSET0]], %[[OFFSET1]]] [%[[SIZE0]], %[[SIZE1]]] [1, 1]
+//       CHECK:     %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:         ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] :
+//  CHECK-SAME:         outs(%[[INIT_SLICE]] :
+//       CHECK:     %[[INSERT_SLICE:.+]] = tensor.insert_slice %[[GENERIC]] into %[[INIT]]
+//  CHECK-SAME:         [%[[OFFSET0]], %[[OFFSET1]]] [%[[SIZE0]], %[[SIZE1]]] [1, 1]
+//       CHECK:     scf.yield %[[INSERT_SLICE]]
+//       CHECK:   return %[[FOR]]
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 3d24d4ecc4d0d..1e3d5371f1ea8 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -13,6 +13,7 @@
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 #include "mlir/Dialect/Transform/IR/TransformAttrs.h"
 #include "mlir/Dialect/Transform/IR/TransformDialect.h"
@@ -468,6 +469,180 @@ transform::TestTileAndFuseOuterParallelPartialReductionOp::apply(
                         : DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// TestTileUsingCustomLoopOp
+//===----------------------------------------------------------------------===//
+
+/// Tiles the single payload op of `root_op`, collapsing all the inter-tile
+/// loops into one `scf.for` over the linearized tile index. The induction
+/// variable is delinearized to compute the per-dimension tile offsets/sizes.
+DiagnosedSilenceableFailure transform::TestTileUsingCustomLoopOp::apply(
+    TransformRewriter &transformRewriter, TransformResults &transformResults,
+    TransformState &state) {
+  // Validate the handle before dereferencing its first payload op.
+  auto payloadOps = state.getPayloadOps(getRootOp());
+  if (!llvm::hasSingleElement(payloadOps)) {
+    emitOpError("expected `root_op` to map to exactly one payload operation");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+  auto target = dyn_cast<TilingInterface>(*payloadOps.begin());
+  if (!target) {
+    emitOpError("expected root operation to implement `TilingInterface`");
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  OpFoldResult oneOfr = transformRewriter.getIndexAttr(1);
+
+  // Creates one `scf.for` whose trip count is the product of the number of
+  // tiles along each loop. Returns the tile offsets/sizes computed from the
+  // delinearized induction variable and the loop's region iter_args, which
+  // become the destinations of the tiled computation.
+  scf::SCFTilingOptions::GenerateLoopHeaderFn loopHeaderFn =
+      [&](RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
+          ArrayRef<OpFoldResult> givenTileSizes,
+          ValueRange outerDestinationTensors)
+      -> FailureOr<scf::SCFTilingOptions::CustomLoopHeaderInfo> {
+    // There must be at least one loop to collapse.
+    if (loopRanges.empty()) {
+      return emitOpError("unhandled case of an empty iteration domain");
+    }
+    // Check that the strides are all 1 (to make it easier in the test).
+    if (llvm::any_of(loopRanges, [](Range r) {
+          return !isConstantIntValue(r.stride, 1);
+        })) {
+      return emitOpError("unable to handle loop ranges with strides != 1");
+    }
+    // For testing disallow any of the tile sizes being 0.
+    if (llvm::any_of(givenTileSizes, isZeroInteger)) {
+      return emitOpError("unhandled case of zero tile size");
+    }
+    // For testing, only handle tensor tiling.
+    if (outerDestinationTensors.empty()) {
+      return emitOpError("expected destination tensors");
+    }
+
+    // Number of tiles along each loop: ceil((ub - lb) / tileSize), with
+    // `Range::offset` as the lower bound and `Range::size` as the upper bound.
+    AffineExpr s0, s1, s2;
+    bindSymbols(rewriter.getContext(), s0, s1, s2);
+    AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2);
+
+    SmallVector<OpFoldResult> allNumIters;
+    allNumIters.reserve(loopRanges.size());
+    for (auto [loopRange, tileSize] :
+         llvm::zip_equal(loopRanges, givenTileSizes)) {
+      OpFoldResult numIters = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, numItersExpr,
+          {loopRange.offset, loopRange.size, tileSize});
+      allNumIters.push_back(numIters);
+    }
+
+    // Trip count of the collapsed loop: product of the per-loop tile counts.
+    AffineExpr mulExpr = s0 * s1;
+    OpFoldResult cumulative = oneOfr;
+    for (OpFoldResult numIters : allNumIters) {
+      cumulative = affine::makeComposedFoldedAffineApply(
+          rewriter, loc, mulExpr, {cumulative, numIters});
+    }
+
+    Value zeroVal = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    Value oneVal = arith::ConstantIndexOp::create(rewriter, loc, 1);
+    Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, cumulative);
+
+    SmallVector<OpFoldResult> offsets;
+    SmallVector<OpFoldResult> sizes;
+    SmallVector<Value> innerDestinationTensors;
+    offsets.reserve(loopRanges.size());
+    sizes.reserve(loopRanges.size());
+
+    AffineExpr d0;
+    bindDims(rewriter.getContext(), d0);
+    AffineExpr offsetExpr = s0 + d0 * s1; // lb + iv * tileSize
+    AffineMap minMap =
+        AffineMap::get(1, 2, {s0 - d0, s1},
+                       rewriter.getContext()); // min(ub - offset, tileSize)
+    auto forOp = scf::ForOp::create(
+        rewriter, loc, zeroVal, ub, oneVal, outerDestinationTensors,
+        [&](OpBuilder &b, Location bodyLoc, Value linearizedIv,
+            ValueRange destinations) {
+          auto delinearizeOp = affine::AffineDelinearizeIndexOp::create(
+              b, bodyLoc, linearizedIv, allNumIters);
+          for (auto [normalizedIv, range, tileSize] : llvm::zip_equal(
+                   delinearizeOp.getResults(), loopRanges, givenTileSizes)) {
+            OpFoldResult normalizedIvOfr = getAsOpFoldResult(normalizedIv);
+            OpFoldResult offset = affine::makeComposedFoldedAffineApply(
+                b, bodyLoc, offsetExpr,
+                {normalizedIvOfr, range.offset, tileSize});
+            offsets.push_back(offset);
+
+            OpFoldResult size = affine::makeComposedFoldedAffineMin(
+                b, bodyLoc, minMap, {offset, range.size, tileSize});
+            sizes.push_back(size);
+          }
+          innerDestinationTensors = llvm::to_vector(destinations);
+        });
+    // The body builder creates no yield; that is done by `terminatorFn`.
+    rewriter.setInsertionPointToEnd(forOp.getBody());
+    return scf::SCFTilingOptions::CustomLoopHeaderInfo{
+        {cast<LoopLikeOpInterface>(forOp.getOperation())},
+        offsets,
+        sizes,
+        innerDestinationTensors};
+  };
+
+  // Inserts each tiled result into its destination and yields the results.
+  scf::SCFTilingOptions::GenerateLoopTerminatorFn terminatorFn =
+      [&](RewriterBase &rewriter, Location loc, ValueRange tiledResults,
+          ArrayRef<SmallVector<OpFoldResult>> resultOffsets,
+          ArrayRef<SmallVector<OpFoldResult>> resultSizes,
+          ValueRange destinationTensors) -> LogicalResult {
+    SmallVector<Value> yieldValues;
+    yieldValues.reserve(destinationTensors.size());
+    for (auto [tiledResult, offsets, sizes, destination] : llvm::zip_equal(
+             tiledResults, resultOffsets, resultSizes, destinationTensors)) {
+      SmallVector<OpFoldResult> strides(offsets.size(), oneOfr);
+      Value insertedVal = tensor::InsertSliceOp::create(
+          rewriter, loc, tiledResult, destination, offsets, sizes, strides);
+      yieldValues.push_back(insertedVal);
+    }
+    scf::YieldOp::create(rewriter, loc, yieldValues);
+    return success();
+  };
+
+  scf::SCFTilingOptions tilingOptions;
+  SmallVector<int64_t> staticTileSizes =
+      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
+  SmallVector<OpFoldResult> tileSizes =
+      getAsIndexOpFoldResult(transformRewriter.getContext(), staticTileSizes);
+  tilingOptions.setTileSizes(tileSizes)
+      .setLoopType(scf::SCFTilingOptions::LoopType::CustomOp)
+      .setCustomLoopGenerationFns(loopHeaderFn, terminatorFn);
+
+  OpBuilder::InsertionGuard g(transformRewriter);
+  transformRewriter.setInsertionPoint(target);
+  FailureOr<scf::SCFTilingResult> tiledResults =
+      scf::tileUsingSCF(transformRewriter, target, tilingOptions);
+  if (failed(tiledResults)) {
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+
+  // `loops` is variadic: map every loop handle, and reject a count mismatch
+  // rather than indexing a result that may not exist.
+  if (getLoops().size() != tiledResults->loops.size()) {
+    emitOpError("expected ") << tiledResults->loops.size()
+                             << " loop handle(s), got " << getLoops().size();
+    return DiagnosedSilenceableFailure::definiteFailure();
+  }
+  transformRewriter.replaceOp(target, tiledResults->replacements);
+  transformResults.set(getOperation()->getResult(0), tiledResults->tiledOps);
+  for (auto [loopHandle, loop] :
+       llvm::zip_equal(getLoops(), tiledResults->loops)) {
+    transformResults.set(loopHandle, {loop.getOperation()});
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
 #define GET_OP_CLASSES
 #include "TestTilingInterfaceTransformOps.cpp.inc"
 
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index 58ccd30bb99a2..694c4229eef62 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -150,4 +150,27 @@ def TestTileAndFuseOuterParallelPartialReductionOp : Op<
   }];
 }
 
+def TestTileUsingCustomLoopOp : Op<
+    Transform_Dialect, "test.tile_using_custom_loop",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     DeclareOpInterfaceMethods<TransformOpInterface>,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Test Transform op to tile an operation using a custom loop operation.
+
+    All the inter-tile loops are folded into a single `scf.for` whose
+    induction variable is then delinearized into the per-loop tile indices.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$root_op,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
+  let results = (outs TransformHandleTypeInterface:$tiled_ops,
+                 Variadic<TransformHandleTypeInterface>:$loops);
+
+  let assemblyFormat = [{
+    $root_op `tile_sizes` `=` $tile_sizes
+    attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 #endif // TEST_TILINGINTERFACE_TRANSFORM_OPS



More information about the Mlir-commits mailing list