[Mlir-commits] [mlir] [MLIR][SCF] Add callbacks to have control over tile ordering within a scf.forall loop (PR #158074)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Sep 11 10:45:08 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-scf
Author: None (sebvince)
<details>
<summary>Changes</summary>
This PR adds a callback to `SCFTilingOptions` that controls the order in which tiles are iterated when the `scf.forall` loop is created.
---
Full diff: https://github.com/llvm/llvm-project/pull/158074.diff
2 Files Affected:
- (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+32)
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+35-10)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 3205da6e448fc..305336a7689f9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -31,6 +31,28 @@ namespace scf {
using SCFTileSizeComputationFunction =
std::function<SmallVector<OpFoldResult>(OpBuilder &, Operation *)>;
+/// Computes the original tile indices from the induction variables of a newly
+/// created scf.forall loop.
+///
+/// \param ivs Induction variables of the newly formed scf.forall loop.
+/// \returns SmallVector<Value> containing the original tile indices.
+using SCFUpdateInductionVarFn = std::function<SmallVector<Value>(
+ RewriterBase &, Location &, ValueRange ivs)>;
+
+/// Controls tile iteration and distribution for an scf.forall loop.
+///
+/// \param loopRanges Array of Range objects specifying the iteration domain.
+/// \param tileSizes Array of tile sizes for each loop dimension.
+/// \returns A tuple containing:
+/// - ranges : loop bounds for the scf.forall loop (lbs, ubs, steps).
+/// - updateInductionVarFn: Function to compute original tile indices from
+/// new induction variables.
+
+using SCFTileDistributionFn =
+ std::function<std::tuple<SmallVector<Range>, SCFUpdateInductionVarFn>(
+ RewriterBase &, Location, ArrayRef<Range> loopRanges,
+ ArrayRef<OpFoldResult> tileSizes)>;
+
/// Options to use to control tiling.
struct SCFTilingOptions {
/// Computation function that returns the tile sizes to use for each loop.
@@ -39,6 +61,11 @@ struct SCFTilingOptions {
/// loops are not tiled. If the size of the returned vector is larger, then
/// the vector is truncated to number of loops.
SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
+ /// Function to have control over tile ordering within the scf.forall loop.
+ /// This function takes the iterationDomain as parameter and returns:
+ /// loop bounds : (lbs, ubs, steps)
+ /// InductionVarFn : compute old tile indices from new ones.
+ SCFTileDistributionFn tileDistributionFunction = nullptr;
SCFTilingOptions &
setTileSizeComputationFunction(SCFTileSizeComputationFunction fun) {
@@ -95,6 +122,11 @@ struct SCFTilingOptions {
return *this;
}
+ SCFTilingOptions &setTileDistributionFunction(SCFTileDistributionFn fun) {
+ tileDistributionFunction = std::move(fun);
+ return *this;
+ }
+
//-------------------------------------------------------------------------//
// Options related reduction tiling
//-------------------------------------------------------------------------//
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 834c02126fa53..c942f657750de 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -16,16 +16,19 @@
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -509,7 +512,9 @@ static LogicalResult generateLoopNestUsingForallOp(
RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
- YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
+ YieldTiledValuesFn tiledBodyFn,
+ scf::SCFTileDistributionFn tileDistributionFn,
+ SmallVector<LoopLikeOpInterface> &loops) {
assert(!loopRanges.empty() && "unexpected empty loop ranges");
assert(loopRanges.size() == tileSizes.size() &&
"expected as many tile sizes as loop ranges");
@@ -521,6 +526,7 @@ static LogicalResult generateLoopNestUsingForallOp(
scf::ForallOp forallOp;
bool useNumThreads = !numThreads.empty();
+ scf::SCFUpdateInductionVarFn updateInductionVar = nullptr;
if (useNumThreads) {
// Prune the zero numthreads.
@@ -534,8 +540,19 @@ static LogicalResult generateLoopNestUsingForallOp(
destinationTensors, mappingAttr);
} else {
SmallVector<OpFoldResult> lbs, ubs, steps;
- std::tie(lbs, ubs, steps) =
- getLoopBounds(rewriter, loc, loopRanges, tileSizes);
+ if (tileDistributionFn) {
+ SmallVector<Range> ranges;
+ std::tie(ranges, updateInductionVar) =
+ tileDistributionFn(rewriter, loc, loopRanges, tileSizes);
+ for (const Range &range : ranges) {
+ lbs.push_back(range.offset);
+ ubs.push_back(range.size);
+ steps.push_back(range.stride);
+ }
+ } else {
+ std::tie(lbs, ubs, steps) =
+ getLoopBounds(rewriter, loc, loopRanges, tileSizes);
+ }
forallOp = scf::ForallOp::create(rewriter, loc, lbs, ubs, steps,
destinationTensors, mappingAttr);
}
@@ -546,7 +563,13 @@ static LogicalResult generateLoopNestUsingForallOp(
SmallVector<Value> tiledResults;
SmallVector<SmallVector<OpFoldResult>> resultOffsets, resultSizes;
- if (failed(tiledBodyFn(rewriter, loc, forallOp.getInductionVars(),
+ SmallVector<Value> originalInductionVars = forallOp.getInductionVars();
+ SmallVector<Value> updatedInductionVars = originalInductionVars;
+ if (updateInductionVar) {
+ updatedInductionVars =
+ updateInductionVar(rewriter, loc, originalInductionVars);
+ }
+ if (failed(tiledBodyFn(rewriter, loc, updatedInductionVars,
destinationTensors, tiledResults, resultOffsets,
resultSizes)))
return rewriter.notifyMatchFailure(loc, "failed to generate loop body");
@@ -580,7 +603,9 @@ static LogicalResult generateLoopNest(
scf::SCFTilingOptions::LoopType loopType, ArrayRef<Range> loopRanges,
ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
ValueRange destinationTensors, ArrayRef<Attribute> mappingVector,
- YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
+ YieldTiledValuesFn tiledBodyFn,
+ scf::SCFTileDistributionFn tileDistributionFn,
+ SmallVector<LoopLikeOpInterface> &loops) {
// If the tile sizes are all zero, no loops are generated. Just call the
// callback function to handle untiled case.
if (llvm::all_of(tileSizes, isZeroInteger)) {
@@ -596,7 +621,7 @@ static LogicalResult generateLoopNest(
if (loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
return generateLoopNestUsingForallOp(
rewriter, loc, loopRanges, tileSizes, numThreads, mappingVector,
- destinationTensors, tiledBodyFn, loops);
+ destinationTensors, tiledBodyFn, tileDistributionFn, loops);
}
return rewriter.notifyMatchFailure(loc, "unhandled loop type");
}
@@ -1116,10 +1141,10 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
// 7. Generate the tiled loops nest using the callback defined above.
SmallVector<LoopLikeOpInterface> loops;
- if (failed(generateLoopNest(rewriter, op.getLoc(), options.loopType,
- iterationDomain, tileSizes, numThreads,
- initTensors, options.mappingVector,
- innerYieldTiledValuesFn, loops)))
+ if (failed(generateLoopNest(
+ rewriter, op.getLoc(), options.loopType, iterationDomain, tileSizes,
+ numThreads, initTensors, options.mappingVector,
+ innerYieldTiledValuesFn, options.tileDistributionFunction, loops)))
return op.emitOpError("failed to generate tiling loops");
assert(succeeded(tilingResult) &&
"expected tiling result to be computed after loop generation");
``````````
</details>
https://github.com/llvm/llvm-project/pull/158074
More information about the Mlir-commits mailing list