[Mlir-commits] [mlir] [MLIR][SCF] Add callbacks to have control over tile ordering within an scf.forall loop (PR #158074)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Sep 11 06:42:23 PDT 2025


https://github.com/sebvince created https://github.com/llvm/llvm-project/pull/158074

This PR adds a callback to `SCFTilingOptions` to control tile ordering when an `scf.forall` loop is created.

>From bef8b55c63c739f244fc4a0fcc143d44a539829e Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Thu, 28 Aug 2025 17:10:23 +0000
Subject: [PATCH 1/5] Add callbacks to SCFTilingOptions to control tile
 ordering on scf.forall creation

---
 .../SCF/Transforms/TileUsingInterface.h       | 35 +++++++++++++++++
 .../SCF/Transforms/TileUsingInterface.cpp     | 39 ++++++++++++++-----
 2 files changed, 64 insertions(+), 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 3205da6e448fc..431368ca12640 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -31,6 +31,31 @@ namespace scf {
 using SCFTileSizeComputationFunction =
     std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;
 
+/// Computes the original tile indices from the induction variables of a newly
+/// created scf.forall loop.
+///
+/// \param ivs Induction variables of the newly formed scf.forall loop.
+/// \returns SmallVector<Value> containing the original tile indices.
+using SCFUpdateConductionVarFn = std::function<SmallVector<Value>(
+    RewriterBase &, Location &, ValueRange ivs)>;
+
+/// Controls tile iteration and distribution for an scf.forall loop.
+///
+/// \param loopRanges Array of Range objects specifying the iteration domain.
+/// \param tileSizes Array of tile sizes for each loop dimension.
+/// \returns A tuple containing:
+///   - lbs: Lower bounds for the scf.forall loop.
+///   - ubs: Upper bounds for the scf.forall loop.
+///   - steps: Step sizes for the scf.forall loop.
+///   - updateConductionVarFn: Function to compute original tile indices from
+///   new induction variables.
+
+using SCFTileDistributionFn = std::function<
+    std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
+               SmallVector<OpFoldResult>, SCFUpdateConductionVarFn>(
+        RewriterBase &, Location, ArrayRef<Range> loopRanges,
+        ArrayRef<OpFoldResult> tileSizes)>;
+
 /// Options to use to control tiling.
 struct SCFTilingOptions {
   /// Computation function that returns the tile sizes to use for each loop.
@@ -39,6 +64,11 @@ struct SCFTilingOptions {
   /// loops are not tiled. If the size of the returned vector is larger, then
   /// the vector is truncated to number of loops.
   SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
+  /// Function to have control over tile ordering within the scf.forall loop.
+  /// This function takes the iterationDomain as parameter and returns:
+  /// loop bounds : (lbs, ubs, steps)
+  /// ConductionVarFn : compute old tile indexes from old ones.
+  SCFTileDistributionFn tileDistributionFunction = nullptr;
 
   SCFTilingOptions &
   setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) {
@@ -95,6 +125,11 @@ struct SCFTilingOptions {
     return *this;
   }
 
+  SCFTilingOptions &setTileDistributionFunction(SCFTileDistributionFn fun) {
+    tileDistributionFunction = std::move(fun);
+    return *this;
+  }
+
   //-------------------------------------------------------------------------//
   // Options related reduction tiling
   //-------------------------------------------------------------------------//
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 834c02126fa53..2ea733da1a0bb 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -16,16 +16,19 @@
 #include "mlir/Analysis/TopologicalSortUtils.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Rewrite/FrozenRewritePatternSet.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include <optional>
@@ -509,7 +512,9 @@ static LogicalResult generateLoopNestUsingForallOp(
     RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
     ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
     ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
-    YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
+    YieldTiledValuesFn tiledBodyFn,
+    scf::SCFTileDistributionFn tileDistributionFn,
+    SmallVector<LoopLikeOpInterface> &loops) {
   assert(!loopRanges.empty() && "unexpected empty loop ranges");
   assert(loopRanges.size() == tileSizes.size() &&
          "expected as many tile sizes as loop ranges");
@@ -521,6 +526,7 @@ static LogicalResult generateLoopNestUsingForallOp(
 
   scf::ForallOp forallOp;
   bool useNumThreads = !numThreads.empty();
+  scf::SCFUpdateConductionVarFn updateConductionVar = nullptr;
 
   if (useNumThreads) {
     // Prune the zero numthreads.
@@ -534,8 +540,13 @@ static LogicalResult generateLoopNestUsingForallOp(
                                      destinationTensors, mappingAttr);
   } else {
     SmallVector<OpFoldResult> lbs, ubs, steps;
-    std::tie(lbs, ubs, steps) =
-        getLoopBounds(rewriter, loc, loopRanges, tileSizes);
+    if (tileDistributionFn) {
+      std::tie(lbs, ubs, steps, updateConductionVar) =
+          tileDistributionFn(rewriter, loc, loopRanges, tileSizes);
+    } else {
+      std::tie(lbs, ubs, steps) =
+          getLoopBounds(rewriter, loc, loopRanges, tileSizes);
+    }
     forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps,
                                      destinationTensors, mappingAttr);
   }
@@ -546,7 +557,13 @@ static LogicalResult generateLoopNestUsingForallOp(
 
   SmallVector<Value> tiledResults;
   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
-  if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
+  SmallVector<Value> originalInductionVars = forallOp.getInductionVars();
+  SmallVector<Value> updatedInductionVars = originalInductionVars;
+  if (updateConductionVar) {
+    updatedInductionVars =
+        updateConductionVar(rewriter, loc, originalInductionVars);
+  }
+  if (failed(tiledBodyFn(rewriter, loc, updatedInductionVars,
                          destinationTensors, tiledResults, resultOffsets,
                          resultSizes)))
     return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
@@ -580,7 +597,9 @@ static LogicalResult generateLoopNest(
     scf::SCFTilingOptions::LoopType loopType, ArrayRef<Range> loopRanges,
     ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
     ValueRange destinationTensors, ArrayRef<Attribute> mappingVector,
-    YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
+    YieldTiledValuesFn tiledBodyFn,
+    scf::SCFTileDistributionFn tileDistributionFn,
+    SmallVector<LoopLikeOpInterface> &loops) {
   // If the tile sizes are all zero, no loops are generated. Just call the
   // callback function to handle untiled case.
   if (llvm::all_of(tileSizes, isZeroInteger)) {
@@ -596,7 +615,7 @@ static LogicalResult generateLoopNest(
   if (loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
     return generateLoopNestUsingForallOp(
         rewriter, loc, loopRanges, tileSizes, numThreads, mappingVector,
-        destinationTensors, tiledBodyFn, loops);
+        destinationTensors, tiledBodyFn, tileDistributionFn, loops);
   }
   return rewriter.notifyMatchFailure(loc, "unhandled loop type");
 }
@@ -1116,10 +1135,10 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
 
   // 7. Generate the tiled loops nest using the callback defined above.
   SmallVector<LoopLikeOpInterface> loops;
-  if (failed(generateLoopNest(rewriter, op.getLoc(), options.loopType,
-                              iterationDomain, tileSizes, numThreads,
-                              initTensors, options.mappingVector,
-                              innerYieldTiledValuesFn, loops)))
+  if (failed(generateLoopNest(
+          rewriter, op.getLoc(), options.loopType, iterationDomain, tileSizes,
+          numThreads, initTensors, options.mappingVector,
+          innerYieldTiledValuesFn, options.tileDistributionFunction, loops)))
     return op.emitOpError("failed to generate tiling loops");
   assert(succeeded(tilingResult) &&
          "expected tiling result to be computed after loop generation");

>From 90fe3d5337cdbb760ebd0b67b9b840922be61821 Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Thu, 28 Aug 2025 17:42:13 +0000
Subject: [PATCH 2/5] Fix typo: rename ConductionVar to InductionVar

---
 .../mlir/Dialect/SCF/Transforms/TileUsingInterface.h      | 8 ++++----
 mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp    | 8 ++++----
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 431368ca12640..56705c3f2860c 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -36,7 +36,7 @@ using SCFTileSizeComputationFunction =
 ///
 /// \param ivs Induction variables of the newly formed scf.forall loop.
 /// \returns SmallVector<Value> containing the original tile indices.
-using SCFUpdateConductionVarFn = std::function<SmallVector<Value>(
+using SCFUpdateInductionVarFn = std::function<SmallVector<Value>(
     RewriterBase &, Location &, ValueRange ivs)>;
 
 /// Controls tile iteration and distribution for an scf.forall loop.
@@ -47,12 +47,12 @@ using SCFUpdateConductionVarFn = std::function<SmallVector<Value>(
 ///   - lbs: Lower bounds for the scf.forall loop.
 ///   - ubs: Upper bounds for the scf.forall loop.
 ///   - steps: Step sizes for the scf.forall loop.
-///   - updateConductionVarFn: Function to compute original tile indices from
+///   - updateInductionVarFn: Function to compute original tile indices from
 ///   new induction variables.
 
 using SCFTileDistributionFn = std::function<
     std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
-               SmallVector<OpFoldResult>, SCFUpdateConductionVarFn>(
+               SmallVector<OpFoldResult>, SCFUpdateInductionVarFn>(
         RewriterBase &, Location, ArrayRef<Range> loopRanges,
         ArrayRef<OpFoldResult> tileSizes)>;
 
@@ -67,7 +67,7 @@ struct SCFTilingOptions {
   /// Function to have control over tile ordering within the scf.forall loop.
   /// This function takes the iterationDomain as parameter and returns:
   /// loop bounds : (lbs, ubs, steps)
-  /// ConductionVarFn : compute old tile indexes from old ones.
+  /// InductionVarFn : compute old tile indexes from old ones.
   SCFTileDistributionFn tileDistributionFunction = nullptr;
 
   SCFTilingOptions &
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 2ea733da1a0bb..14214b60135c8 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -526,7 +526,7 @@ static LogicalResult generateLoopNestUsingForallOp(
 
   scf::ForallOp forallOp;
   bool useNumThreads = !numThreads.empty();
-  scf::SCFUpdateConductionVarFn updateConductionVar = nullptr;
+  scf::SCFUpdateInductionVarFn updateInductionVar = nullptr;
 
   if (useNumThreads) {
     // Prune the zero numthreads.
@@ -541,7 +541,7 @@ static LogicalResult generateLoopNestUsingForallOp(
   } else {
     SmallVector<OpFoldResult> lbs, ubs, steps;
     if (tileDistributionFn) {
-      std::tie(lbs, ubs, steps, updateConductionVar) =
+      std::tie(lbs, ubs, steps, updateInductionVar) =
           tileDistributionFn(rewriter, loc, loopRanges, tileSizes);
     } else {
       std::tie(lbs, ubs, steps) =
@@ -559,9 +559,9 @@ static LogicalResult generateLoopNestUsingForallOp(
   SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
   SmallVector<Value> originalInductionVars = forallOp.getInductionVars();
   SmallVector<Value> updatedInductionVars = originalInductionVars;
-  if (updateConductionVar) {
+  if (updateInductionVar) {
     updatedInductionVars =
-        updateConductionVar(rewriter, loc, originalInductionVars);
+        updateInductionVar(rewriter, loc, originalInductionVars);
   }
   if (failed(tiledBodyFn(rewriter, loc, updatedInductionVars,
                          destinationTensors, tiledResults, resultOffsets,

>From 75e1ff9c57b1450c48e709141ac93797eb9ee1e9 Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Mon, 1 Sep 2025 08:43:37 +0000
Subject: [PATCH 3/5] Use Range to represent loop bounds

---
 .../mlir/Dialect/SCF/Transforms/TileUsingInterface.h     | 9 +++------
 mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp   | 8 +++++++-
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 56705c3f2860c..927e0cec38ca9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -44,15 +44,12 @@ using SCFUpdateInductionVarFn = std::function<SmallVector<Value>(
 /// \param loopRanges Array of Range objects specifying the iteration domain.
 /// \param tileSizes Array of tile sizes for each loop dimension.
 /// \returns A tuple containing:
-///   - lbs: Lower bounds for the scf.forall loop.
-///   - ubs: Upper bounds for the scf.forall loop.
-///   - steps: Step sizes for the scf.forall loop.
+///   - ranges : loop bounds for the scf.forall loop (lbs, ubs, steps).
 ///   - updateInductionVarFn: Function to compute original tile indices from
 ///   new induction variables.
 
-using SCFTileDistributionFn = std::function<
-    std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
-               SmallVector<OpFoldResult>, SCFUpdateInductionVarFn>(
+using SCFTileDistributionFn =
+    std::function<std::tuple<SmallVector<Range>, SCFUpdateInductionVarFn>(
         RewriterBase &, Location, ArrayRef<Range> loopRanges,
         ArrayRef<OpFoldResult> tileSizes)>;
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 14214b60135c8..0086605dadf2a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -541,8 +541,14 @@ static LogicalResult generateLoopNestUsingForallOp(
   } else {
     SmallVector<OpFoldResult> lbs, ubs, steps;
     if (tileDistributionFn) {
-      std::tie(lbs, ubs, steps, updateInductionVar) =
+      SmallVector<Range> ranges;
+      std::tie(ranges, updateInductionVar) =
           tileDistributionFn(rewriter, loc, loopRanges, tileSizes);
+      for (const auto& range : ranges) {
+          lbs.push_back(range.offset);
+          ubs.push_back(range.size);
+          steps.push_back(range.stride);
+      }
     } else {
       std::tie(lbs, ubs, steps) =
           getLoopBounds(rewriter, loc, loopRanges, tileSizes);

>From 79ea26d146eabb606ebe721d748475fff6664b02 Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Thu, 11 Sep 2025 13:25:58 +0000
Subject: [PATCH 4/5] Fix formatting

---
 mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 0086605dadf2a..c942f657750de 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -544,10 +544,10 @@ static LogicalResult generateLoopNestUsingForallOp(
       SmallVector<Range> ranges;
       std::tie(ranges, updateInductionVar) =
           tileDistributionFn(rewriter, loc, loopRanges, tileSizes);
-      for (const auto& range : ranges) {
-          lbs.push_back(range.offset);
-          ubs.push_back(range.size);
-          steps.push_back(range.stride);
+      for (const Range &range : ranges) {
+        lbs.push_back(range.offset);
+        ubs.push_back(range.size);
+        steps.push_back(range.stride);
       }
     } else {
       std::tie(lbs, ubs, steps) =

>From 23af14e7908e0f48442d3a46b6ffdc356843986c Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Thu, 11 Sep 2025 13:35:32 +0000
Subject: [PATCH 5/5] Fix typo in tileDistributionFunction doc comment

---
 mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 927e0cec38ca9..305336a7689f9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -64,7 +64,7 @@ struct SCFTilingOptions {
   /// Function to have control over tile ordering within the scf.forall loop.
   /// This function takes the iterationDomain as parameter and returns:
   /// loop bounds : (lbs, ubs, steps)
-  /// InductionVarFn : compute old tile indexes from old ones.
+  /// InductionVarFn : compute old tile indices from new ones.
   SCFTileDistributionFn tileDistributionFunction = nullptr;
 
   SCFTilingOptions &



More information about the Mlir-commits mailing list