[Mlir-commits] [mlir] [MLIR][SCF] Add callbacks to have control over tile ordering within a scf.forall loop (PR #158074)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 11 06:42:23 PDT 2025
https://github.com/sebvince created https://github.com/llvm/llvm-project/pull/158074
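This PR adds a callback to `SCFTilingOptions` to control tile ordering when the `scf.forall` loop is created. The callback receives the iteration domain and tile sizes and returns the loop bounds for the `scf.forall`. It can also return a function that maps the new induction variables back to the original tile indices. The callback is only consulted when `numThreads` is not specified.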
>From bef8b55c63c739f244fc4a0fcc143d44a539829e Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Thu, 28 Aug 2025 17:10:23 +0000
Subject: [PATCH 1/5] Add callbacks to SCFTilingOptions to control tile
ordering on scf forall creation
---
.../SCF/Transforms/TileUsingInterface.h | 35 +++++++++++++++++
.../SCF/Transforms/TileUsingInterface.cpp | 39 ++++++++++++++-----
2 files changed, 64 insertions(+), 10 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 3205da6e448fc..431368ca12640 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -31,6 +31,31 @@ namespace scf {
using SCFTileSizeComputationFunction =
std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;
+/// Computes the original tile indices from the induction variables of a newly
+/// created scf.forall loop.
+///
+/// \param ivs Induction variables of the newly formed scf.forall loop.
+/// \returns SmallVector<Value> containing the original tile indices.
+using SCFUpdateConductionVarFn = std::function<SmallVector<Value>(
+ RewriterBase &, Location &, ValueRange ivs)>;
+
+/// Controls tile iteration and distribution for an scf.forall loop.
+///
+/// \param loopRanges Array of Range objects specifying the iteration domain.
+/// \param tileSizes Array of tile sizes for each loop dimension.
+/// \returns A tuple containing:
+/// - lbs: Lower bounds for the scf.forall loop.
+/// - ubs: Upper bounds for the scf.forall loop.
+/// - steps: Step sizes for the scf.forall loop.
+/// - updateConductionVarFn: Function to compute original tile indices from
+/// new induction variables.
+
+using SCFTileDistributionFn = std::function<
+ std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
+ SmallVector<OpFoldResult>, SCFUpdateConductionVarFn>(
+ RewriterBase &, Location, ArrayRef<Range> loopRanges,
+ ArrayRef<OpFoldResult> tileSizes)>;
+
/// Options to use to control tiling.
struct SCFTilingOptions {
/// Computation function that returns the tile sizes to use for each loop.
@@ -39,6 +64,11 @@ struct SCFTilingOptions {
/// loops are not tiled. If the size of the returned vector is larger, then
/// the vector is truncated to number of loops.
SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
+ /// Function to have control over tile ordering within the scf.forall loop.
+ /// This function takes the iterationDomain as parameter and returns:
+ /// loop bounds : (lbs, ubs, steps)
+ /// ConductionVarFn : compute old tile indexes from old ones.
+ SCFTileDistributionFn tileDistributionFunction = nullptr;
SCFTilingOptions &
setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) {
@@ -95,6 +125,11 @@ struct SCFTilingOptions {
return *this;
}
+ SCFTilingOptions &setTileDistributionFunction(SCFTileDistributionFn fun) {
+ tileDistributionFunction = std::move(fun);
+ return *this;
+ }
+
//-------------------------------------------------------------------------//
// Options related reduction tiling
//-------------------------------------------------------------------------//
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 834c02126fa53..2ea733da1a0bb 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -16,16 +16,19 @@
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -509,7 +512,9 @@ static LogicalResult generateLoopNestUsingForallOp(
RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
- YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
+ YieldTiledValuesFn tiledBodyFn,
+ scf::SCFTileDistributionFn tileDistributionFn,
+ SmallVector<LoopLikeOpInterface> &loops) {
assert(!loopRanges.empty() && "unexpected empty loop ranges");
assert(loopRanges.size() == tileSizes.size() &&
"expected as many tile sizes as loop ranges");
@@ -521,6 +526,7 @@ static LogicalResult generateLoopNestUsingForallOp(
scf::ForallOp forallOp;
bool useNumThreads = !numThreads.empty();
+ scf::SCFUpdateConductionVarFn updateConductionVar = nullptr;
if (useNumThreads) {
// Prune the zero numthreads.
@@ -534,8 +540,13 @@ static LogicalResult generateLoopNestUsingForallOp(
destinationTensors, mappingAttr);
} else {
SmallVector<OpFoldResult> lbs, ubs, steps;
- std::tie(lbs, ubs, steps) =
- getLoopBounds(rewriter, loc, loopRanges, tileSizes);
+ if (tileDistributionFn) {
+ std::tie(lbs, ubs, steps, updateConductionVar) =
+ tileDistributionFn(rewriter, loc, loopRanges, tileSizes);
+ } else {
+ std::tie(lbs, ubs, steps) =
+ getLoopBounds(rewriter, loc, loopRanges, tileSizes);
+ }
forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps,
destinationTensors, mappingAttr);
}
@@ -546,7 +557,13 @@ static LogicalResult generateLoopNestUsingForallOp(
SmallVector<Value> tiledResults;
SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
- if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
+ SmallVector<Value> originalInductionVars = forallOp.getInductionVars();
+ SmallVector<Value> updatedInductionVars = originalInductionVars;
+ if (updateConductionVar) {
+ updatedInductionVars =
+ updateConductionVar(rewriter, loc, originalInductionVars);
+ }
+ if (failed(tiledBodyFn(rewriter, loc, updatedInductionVars,
destinationTensors, tiledResults, resultOffsets,
resultSizes)))
return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
@@ -580,7 +597,9 @@ static LogicalResult generateLoopNest(
scf::SCFTilingOptions::LoopType loopType, ArrayRef<Range> loopRanges,
ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
ValueRange destinationTensors, ArrayRef<Attribute> mappingVector,
- YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
+ YieldTiledValuesFn tiledBodyFn,
+ scf::SCFTileDistributionFn tileDistributionFn,
+ SmallVector<LoopLikeOpInterface> &loops) {
// If the tile sizes are all zero, no loops are generated. Just call the
// callback function to handle untiled case.
if (llvm::all_of(tileSizes, isZeroInteger)) {
@@ -596,7 +615,7 @@ static LogicalResult generateLoopNest(
if (loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
return generateLoopNestUsingForallOp(
rewriter, loc, loopRanges, tileSizes, numThreads, mappingVector,
- destinationTensors, tiledBodyFn, loops);
+ destinationTensors, tiledBodyFn, tileDistributionFn, loops);
}
return rewriter.notifyMatchFailure(loc, "unhandled loop type");
}
@@ -1116,10 +1135,10 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
// 7. Generate the tiled loops nest using the callback defined above.
SmallVector<LoopLikeOpInterface> loops;
- if (failed(generateLoopNest(rewriter, op.getLoc(), options.loopType,
- iterationDomain, tileSizes, numThreads,
- initTensors, options.mappingVector,
- innerYieldTiledValuesFn, loops)))
+ if (failed(generateLoopNest(
+ rewriter, op.getLoc(), options.loopType, iterationDomain, tileSizes,
+ numThreads, initTensors, options.mappingVector,
+ innerYieldTiledValuesFn, options.tileDistributionFunction, loops)))
return op.emitOpError("failed to generate tiling loops");
assert(succeeded(tilingResult) &&
"expected tiling result to be computed after loop generation");
>From 90fe3d5337cdbb760ebd0b67b9b840922be61821 Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Thu, 28 Aug 2025 17:42:13 +0000
Subject: [PATCH 2/5] Fix typo Conduction vs Induction vars
---
.../mlir/Dialect/SCF/Transforms/TileUsingInterface.h | 8 ++++----
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 8 ++++----
2 files changed, 8 insertions(+), 8 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 431368ca12640..56705c3f2860c 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -36,7 +36,7 @@ using SCFTileSizeComputationFunction =
///
/// \param ivs Induction variables of the newly formed scf.forall loop.
/// \returns SmallVector<Value> containing the original tile indices.
-using SCFUpdateConductionVarFn = std::function<SmallVector<Value>(
+using SCFUpdateInductionVarFn = std::function<SmallVector<Value>(
RewriterBase &, Location &, ValueRange ivs)>;
/// Controls tile iteration and distribution for an scf.forall loop.
@@ -47,12 +47,12 @@ using SCFUpdateConductionVarFn = std::function<SmallVector<Value>(
/// - lbs: Lower bounds for the scf.forall loop.
/// - ubs: Upper bounds for the scf.forall loop.
/// - steps: Step sizes for the scf.forall loop.
-/// - updateConductionVarFn: Function to compute original tile indices from
+/// - updateInductionVarFn: Function to compute original tile indices from
/// new induction variables.
using SCFTileDistributionFn = std::function<
std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
- SmallVector<OpFoldResult>, SCFUpdateConductionVarFn>(
+ SmallVector<OpFoldResult>, SCFUpdateInductionVarFn>(
RewriterBase &, Location, ArrayRef<Range> loopRanges,
ArrayRef<OpFoldResult> tileSizes)>;
@@ -67,7 +67,7 @@ struct SCFTilingOptions {
/// Function to have control over tile ordering within the scf.forall loop.
/// This function takes the iterationDomain as parameter and returns:
/// loop bounds : (lbs, ubs, steps)
- /// ConductionVarFn : compute old tile indexes from old ones.
+ /// InductionVarFn : compute old tile indexes from old ones.
SCFTileDistributionFn tileDistributionFunction = nullptr;
SCFTilingOptions &
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 2ea733da1a0bb..14214b60135c8 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -526,7 +526,7 @@ static LogicalResult generateLoopNestUsingForallOp(
scf::ForallOp forallOp;
bool useNumThreads = !numThreads.empty();
- scf::SCFUpdateConductionVarFn updateConductionVar = nullptr;
+ scf::SCFUpdateInductionVarFn updateInductionVar = nullptr;
if (useNumThreads) {
// Prune the zero numthreads.
@@ -541,7 +541,7 @@ static LogicalResult generateLoopNestUsingForallOp(
} else {
SmallVector<OpFoldResult> lbs, ubs, steps;
if (tileDistributionFn) {
- std::tie(lbs, ubs, steps, updateConductionVar) =
+ std::tie(lbs, ubs, steps, updateInductionVar) =
tileDistributionFn(rewriter, loc, loopRanges, tileSizes);
} else {
std::tie(lbs, ubs, steps) =
@@ -559,9 +559,9 @@ static LogicalResult generateLoopNestUsingForallOp(
SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
SmallVector<Value> originalInductionVars = forallOp.getInductionVars();
SmallVector<Value> updatedInductionVars = originalInductionVars;
- if (updateConductionVar) {
+ if (updateInductionVar) {
updatedInductionVars =
- updateConductionVar(rewriter, loc, originalInductionVars);
+ updateInductionVar(rewriter, loc, originalInductionVars);
}
if (failed(tiledBodyFn(rewriter, loc, updatedInductionVars,
destinationTensors, tiledResults, resultOffsets,
>From 75e1ff9c57b1450c48e709141ac93797eb9ee1e9 Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Mon, 1 Sep 2025 08:43:37 +0000
Subject: [PATCH 3/5] Use Range to represent loop bounds
---
.../mlir/Dialect/SCF/Transforms/TileUsingInterface.h | 9 +++------
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 8 +++++++-
2 files changed, 10 insertions(+), 7 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 56705c3f2860c..927e0cec38ca9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -44,15 +44,12 @@ using SCFUpdateInductionVarFn = std::function<SmallVector<Value>(
/// \param loopRanges Array of Range objects specifying the iteration domain.
/// \param tileSizes Array of tile sizes for each loop dimension.
/// \returns A tuple containing:
-/// - lbs: Lower bounds for the scf.forall loop.
-/// - ubs: Upper bounds for the scf.forall loop.
-/// - steps: Step sizes for the scf.forall loop.
+/// - ranges : loop bounds for the scf.forall loop (lbs, ubs, steps).
/// - updateInductionVarFn: Function to compute original tile indices from
/// new induction variables.
-using SCFTileDistributionFn = std::function<
- std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
- SmallVector<OpFoldResult>, SCFUpdateInductionVarFn>(
+using SCFTileDistributionFn =
+ std::function<std::tuple<SmallVector<Range>, SCFUpdateInductionVarFn>(
RewriterBase &, Location, ArrayRef<Range> loopRanges,
ArrayRef<OpFoldResult> tileSizes)>;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 14214b60135c8..0086605dadf2a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -541,8 +541,14 @@ static LogicalResult generateLoopNestUsingForallOp(
} else {
SmallVector<OpFoldResult> lbs, ubs, steps;
if (tileDistributionFn) {
- std::tie(lbs, ubs, steps, updateInductionVar) =
+ SmallVector<Range> ranges;
+ std::tie(ranges, updateInductionVar) =
tileDistributionFn(rewriter, loc, loopRanges, tileSizes);
+ for (const auto& range : ranges) {
+ lbs.push_back(range.offset);
+ ubs.push_back(range.size);
+ steps.push_back(range.stride);
+ }
} else {
std::tie(lbs, ubs, steps) =
getLoopBounds(rewriter, loc, loopRanges, tileSizes);
>From 79ea26d146eabb606ebe721d748475fff6664b02 Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Thu, 11 Sep 2025 13:25:58 +0000
Subject: [PATCH 4/5] Fix formatting
---
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 0086605dadf2a..c942f657750de 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -544,10 +544,10 @@ static LogicalResult generateLoopNestUsingForallOp(
SmallVector<Range> ranges;
std::tie(ranges, updateInductionVar) =
tileDistributionFn(rewriter, loc, loopRanges, tileSizes);
- for (const auto& range : ranges) {
- lbs.push_back(range.offset);
- ubs.push_back(range.size);
- steps.push_back(range.stride);
+ for (const Range &range : ranges) {
+ lbs.push_back(range.offset);
+ ubs.push_back(range.size);
+ steps.push_back(range.stride);
}
} else {
std::tie(lbs, ubs, steps) =
>From 23af14e7908e0f48442d3a46b6ffdc356843986c Mon Sep 17 00:00:00 2001
From: Seb Vince <sebvince at amd.com>
Date: Thu, 11 Sep 2025 13:35:32 +0000
Subject: [PATCH 5/5] Fix typo
---
mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 927e0cec38ca9..305336a7689f9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -64,7 +64,7 @@ struct SCFTilingOptions {
/// Function to have control over tile ordering within the scf.forall loop.
/// This function takes the iterationDomain as parameter and returns:
/// loop bounds : (lbs, ubs, steps)
- /// InductionVarFn : compute old tile indexes from old ones.
+ /// InductionVarFn : compute old tile indices from new ones.
SCFTileDistributionFn tileDistributionFunction = nullptr;
SCFTilingOptions &
More information about the Mlir-commits
mailing list