[Mlir-commits] [mlir] [mlir][SCF] Allow tiling by specifying maximum number of tiles. (PR #91878)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat May 11 22:43:43 PDT 2024
https://github.com/MaheshRavishankar created https://github.com/llvm/llvm-project/pull/91878
None
>From ff45ad2e0dc347f9e5cfff8eba65a2e7a886b6ef Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Sat, 11 May 2024 13:38:36 -0700
Subject: [PATCH] [mlir][SCF] Allow tiling by specifying maximum number of
tiles.
---
.../Linalg/TransformOps/LinalgTransformOps.h | 6 +-
.../SCF/Transforms/TileUsingInterface.h | 35 ++-
.../TransformOps/LinalgTransformOps.cpp | 31 +-
.../SCF/Transforms/TileUsingInterface.cpp | 287 +++++++++++++-----
mlir/test/Dialect/Linalg/tile-to-forall.mlir | 1 -
.../TestTilingInterfaceTransformOps.cpp | 6 +-
6 files changed, 270 insertions(+), 96 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 3af642752724c..db25c9b241734 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -30,6 +30,11 @@ class GenericOp;
class LinalgOp;
} // namespace linalg
+// Forward declaration; the tiling result type is only named by reference here.
+namespace scf {
+struct SCFTilingResult;
+} // namespace scf
+
namespace tensor {
class InsertSliceOp;
class PackOp;
@@ -60,7 +65,7 @@ tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state,
ArrayRef<OpFoldResult> mixedNumThreads,
ArrayRef<OpFoldResult> mixedTileSizes,
std::optional<ArrayAttr> mapping,
- linalg::ForallTilingResult &tilingResult);
+ scf::SCFTilingResult &tilingResult);
} // namespace transform
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be2..c1775ea4818c7 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -31,9 +31,13 @@ using SCFTileSizeComputationFunction =
/// Options to use to control tiling.
struct SCFTilingOptions {
- /// Computation function that returns the tile sizes for each operation.
- /// Delayed construction of constant tile sizes should occur to interoperate
- /// with folding.
+ /// Computation function that returns the tile sizes to use for each loop.
+ /// Returning a tile size of zero implies no tiling for that loop. If the
+ /// size of the returned vector is smaller than the number of loops, the inner
+ /// loops are not tiled. If the size of the returned vector is larger, then
+ /// the vector is truncated to the number of loops. Exactly one of
+ /// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` must
+ /// be set.
SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
SCFTilingOptions &
@@ -44,7 +48,25 @@ struct SCFTilingOptions {
/// Convenience function to set the `tileSizeComputationFunction` to a
/// function that computes tile sizes at the point they are needed. Allows
/// proper interaction with folding.
- SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
+ SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
+
+ /// Computation function that returns the maximum number of tiles to use for
+ /// each loop; the tile size is derived as ceilDiv(numIterations, maxNumTiles)
+ /// so fewer tiles may be generated. Returning zero implies no tiling for that
+ /// loop. A returned vector shorter than the number of loops leaves the inner
+ /// loops untiled; a longer one is truncated to the number of loops. Exactly
+ /// one of `tileSizeComputationFunction` or `maxNumTilesComputationFunction`
+ /// must be set.
+ SCFTileSizeComputationFunction maxNumTilesComputationFunction = nullptr;
+
+ SCFTilingOptions &
+ setMaxNumTilesComputationFunction(SCFTileSizeComputationFunction fun) {
+ maxNumTilesComputationFunction = std::move(fun);
+ return *this;
+ }
+ /// Convenience function to set the `maxNumTilesComputationFunction` to a
+ /// function that returns the given maximum number of tiles for each loop.
+ SCFTilingOptions &setMaxNumTiles(ArrayRef<OpFoldResult> numTiles);
/// The interchange vector to reorder the tiled loops.
SmallVector<int64_t> interchangeVector = {};
@@ -66,9 +88,8 @@ struct SCFTilingOptions {
/// when using loop constructs that dont support such a mapping (like
/// `scf.for`)
SmallVector<Attribute> mappingVector = {};
- SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
- mappingVector = llvm::map_to_vector(
- mapping, [](auto attr) -> Attribute { return attr; });
+ SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
+ mappingVector = llvm::to_vector(mapping);
return *this;
}
};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 13582a140a965..9fa463763068f 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2917,7 +2917,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
TransformOpInterface transformOp, Operation *target,
ArrayRef<OpFoldResult> mixedNumThreads,
ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
- linalg::ForallTilingResult &tilingResult) {
+ scf::SCFTilingResult &tilingResult) {
// Transform all targets one by one.
auto tileableOp = dyn_cast<TilingInterface>(target);
if (!tileableOp) {
@@ -2928,18 +2928,27 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
return diag;
}
rewriter.setInsertionPoint(tileableOp);
- FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
+ scf::SCFTilingOptions options;
+ options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
if (!mixedNumThreads.empty()) {
- maybeTilingResult =
- linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
+ options.setMaxNumTiles(mixedNumThreads);
} else {
- maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
- rewriter, tileableOp, mixedTileSizes, mapping);
+ options.setTileSizes(mixedTileSizes);
}
+ if (mapping) {
+ options.setMapping(mapping.value().getValue());
+ }
+ FailureOr<scf::SCFTilingResult> maybeTilingResult =
+ scf::tileUsingSCF(rewriter, tileableOp, options);
if (failed(maybeTilingResult))
return transformOp.emitDefaultSilenceableFailure(tileableOp);
- rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
+ rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
+ // Callers use `loops.front()` as the generated scf.forall, but no loop is
+ // created when every tile size / number of threads is zero.
+ if (maybeTilingResult->loops.empty())
+ return transformOp.emitSilenceableError()
+ << "expected tiling to generate an scf.forall loop";
tilingResult = *maybeTilingResult;
return DiagnosedSilenceableFailure::success();
@@ -2975,14 +2984,14 @@ DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
return status;
for (Operation *target : state.getPayloadOps(getTarget())) {
- linalg::ForallTilingResult tilingResult;
+ scf::SCFTilingResult tilingResult;
DiagnosedSilenceableFailure diag = tileToForallOpImpl(
rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
getMapping(), tilingResult);
if (!diag.succeeded())
return diag;
- tileOps.push_back(tilingResult.tileOp);
- tiledOps.push_back(tilingResult.tiledOp);
+ tileOps.push_back(tilingResult.loops.front());
+ tiledOps.append(tilingResult.tiledOps);
}
transformResults.set(cast<OpResult>(getForallOp()), tileOps);
@@ -3460,7 +3469,7 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
// OpBuilder only used to compute attributes.
OpBuilder b(getContext());
- linalg::ForallTilingResult tilingResult;
+ scf::SCFTilingResult tilingResult;
DiagnosedSilenceableFailure diag = tileToForallOpImpl(
/*rewriter=*/rewriter,
/*state=*/state,
@@ -3473,8 +3482,9 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
if (!diag.succeeded())
return diag;
- results.push_back(tilingResult.tileOp);
- results.push_back(tilingResult.tiledOp);
+ results.push_back(tilingResult.loops.front());
+ for (auto op : tilingResult.tiledOps)
+ results.push_back(op);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1a84a59ddb69d..83bb8532a8152 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -41,6 +41,16 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
return *this;
}
+scf::SCFTilingOptions &
+scf::SCFTilingOptions::setMaxNumTiles(ArrayRef<OpFoldResult> mnt) {
+ assert(!maxNumTilesComputationFunction && "max num tiles already set");
+ auto maxNumTiles = llvm::to_vector(mnt);
+ maxNumTilesComputationFunction = [maxNumTiles](OpBuilder &b, Operation *op) {
+ return maxNumTiles;
+ };
+ return *this;
+}
+
/// Helper method to adjust the interchange vector to match the iteration
/// domain.
static SmallVector<int64_t>
@@ -60,6 +70,101 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
// tileUsingSCF implementation.
//===----------------------------------------------------------------------===//
+/// Verify the tile size options are set in a consistent manner.
+static LogicalResult
+verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
+ const scf::SCFTilingOptions &options) {
+ if (!options.tileSizeComputationFunction &&
+ !options.maxNumTilesComputationFunction) {
+ return rewriter.notifyMatchFailure(
+ loc, "at least one of tile size computation function or max num tiles "
+ "computation must be specified.");
+ }
+ if (options.tileSizeComputationFunction &&
+ options.maxNumTilesComputationFunction) {
+ return rewriter.notifyMatchFailure(
+ loc, "only one of tile size computation function or max num tiles "
+ "computation function can be specified");
+ }
+
+ // If specified, check that the interchange vector is a permutation.
+ if (!options.interchangeVector.empty()) {
+ if (!isPermutationVector(options.interchangeVector)) {
+ return rewriter.notifyMatchFailure(
+ loc, "invalid intechange vector, not a permutation of the entire "
+ "iteration space");
+ }
+ }
+ return success();
+}
+
+/// Compute the tile sizes and num tiles values. The `numTiles`
+/// is empty if the `maxNumTilesComputationFunction` is not specified.
+static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
+getTileSizesAndNumTiles(RewriterBase &rewriter, TilingInterface op,
+ ArrayRef<Range> iterationDomain,
+ const scf::SCFTilingOptions &options) {
+ SmallVector<OpFoldResult> tileSizes, numTiles;
+
+ // Enforce the convention that "tiling by zero"
+ // skips tiling a particular dimension. This convention is significantly
+ // simpler to handle instead of adjusting affine maps to account for missing
+ // dimensions.
+ auto numLoops = iterationDomain.size();
+ if (options.tileSizeComputationFunction) {
+ tileSizes = options.tileSizeComputationFunction(rewriter, op);
+ tileSizes.resize(numLoops, rewriter.getIndexAttr(0));
+ return {tileSizes, numTiles};
+ }
+
+ assert(options.maxNumTilesComputationFunction &&
+ "expected at least one of tile sizes computation function or max num "
+ "tiles computation function");
+ // Enforce the convention that "maxNumTiles to zero"
+ // skips tiling a particular dimension. This convention is significantly
+ // simpler to handle instead of adjusting affine maps to account for missing
+ // dimensions.
+ SmallVector<OpFoldResult> maxNumTiles =
+ options.maxNumTilesComputationFunction(rewriter, op);
+ maxNumTiles.resize(numLoops, rewriter.getIndexAttr(0));
+
+ // Use the maxNumTiles to compute the tile sizes as
+ // - niters = ceilDiv(ub - lb, step)
+ // - tileSize = ceilDiv(niters, maxNumTiles)
+ AffineExpr s0, s1, s2, s3;
+ bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
+ AffineExpr numIters = (s1 - s0).ceilDiv(s2);
+ AffineExpr tileSizeExpr = numIters.ceilDiv(s3);
+ tileSizes.resize(numLoops, rewriter.getIndexAttr(0));
+ for (auto [index, maxNumTile] : llvm::enumerate(maxNumTiles)) {
+ if (isConstantIntValue(maxNumTile, 0))
+ continue;
+
+ tileSizes[index] = affine::makeComposedFoldedAffineApply(
+ rewriter, op.getLoc(), tileSizeExpr,
+ {iterationDomain[index].offset, iterationDomain[index].size,
+ iterationDomain[index].stride, maxNumTile});
+ }
+
+ // After computing the tile size, recompute the number of tiles as
+ // ceilDiv(niters, tileSize). E.g. [lb, ub, step] = [0, 300, 1] with
+ // maxNumTiles = 21 gives tileSize = 15, i.e. only 20 tiles; a 21st tile would
+ // be an empty slice. ceilDiv (not floorDiv) keeps a trailing partial tile,
+ // e.g. niters = 10, maxNumTiles = 4 gives tileSize = 3 and 4 tiles.
+ AffineExpr numTileExpr = numIters.ceilDiv(s3);
+ numTiles.resize(tileSizes.size(), rewriter.getIndexAttr(0));
+ for (auto [index, tileSize] : llvm::enumerate(tileSizes)) {
+ if (isConstantIntValue(tileSize, 0))
+ continue;
+ numTiles[index] = affine::makeComposedFoldedAffineApply(
+ rewriter, op.getLoc(), numTileExpr,
+ {iterationDomain[index].offset, iterationDomain[index].size,
+ iterationDomain[index].stride, tileSize});
+ }
+
+ return {tileSizes, numTiles};
+}
+
// Check if `stride` evenly divides the trip count `size - offset`.
static bool tileDividesIterationDomain(Range loopRange) {
std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
@@ -99,6 +204,54 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
}
+/// Compute the offsets and sizes of the tile for each loop of the iteration
+/// domain. Untiled loops (tile size zero) use the full loop range. Tiled loops
+/// consume the induction variables in `ivs` in order. When `isLoopNormalized`
+/// is set, the loops iterate over tile indices (lower bound 0, step 1), so the
+/// offset is reconstructed as `lb + iv * step * tileSize`; otherwise the
+/// induction variable is the offset itself.
+static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
+getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
+ ArrayRef<Range> iterationDomain,
+ ArrayRef<OpFoldResult> tileSizes, bool isLoopNormalized) {
+ SmallVector<OpFoldResult> offsets, sizes;
+ int materializedLoopNum = 0;
+
+ AffineExpr d0, s0, s1, s2;
+ AffineExpr offsetExpr;
+ if (isLoopNormalized) {
+ bindDims(rewriter.getContext(), d0);
+ bindSymbols(rewriter.getContext(), s0, s1, s2);
+ offsetExpr = s0 + d0 * s1 * s2;
+ }
+
+ for (auto [tileSize, loopRange] :
+ llvm::zip_equal(tileSizes, iterationDomain)) {
+ if (isConstantIntValue(tileSize, 0)) {
+ offsets.push_back(loopRange.offset);
+ sizes.push_back(loopRange.size);
+ continue;
+ }
+ Value iv = ivs[materializedLoopNum++];
+ OpFoldResult offset = getAsOpFoldResult(iv);
+ Value offsetValue = iv;
+ if (isLoopNormalized) {
+ // Loop iterates over tile indices: offset = lb + iv * step * tileSize.
+ offset = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, offsetExpr,
+ ArrayRef<OpFoldResult>{iv, loopRange.offset, loopRange.stride,
+ tileSize});
+ offsetValue = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
+ }
+ offsets.push_back(offset);
+ // Bound the (possibly partial) last tile using the actual offset; for
+ // normalized loops `iv` is a tile index, not an offset.
+ sizes.push_back(
+ getBoundedTileSize(rewriter, loc, loopRange, offsetValue, tileSize));
+ }
+ return {offsets, sizes};
+}
+
/// A function that allows returning additional yielded values during
/// `yieldTiledValuesAndReplace`.
/// - `ivs` induction variable for the loop.
@@ -144,8 +297,8 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
/// populated.
static LogicalResult generateLoopNestUsingForOp(
RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
- ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
- YieldTiledValuesFn yieldTiledValuesFn,
+ ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numTiles,
+ ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn,
SmallVector<LoopLikeOpInterface> &loops) {
assert(!loopRanges.empty() && "unexpected empty loop ranges");
assert(loopRanges.size() == tileSizes.size() &&
@@ -153,15 +306,32 @@ static LogicalResult generateLoopNestUsingForOp(
OpBuilder::InsertionGuard guard(rewriter);
SmallVector<Value> ivs;
- for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
+ assert((numTiles.empty() || numTiles.size() == loopRanges.size()) &&
+ "expected max number of tiles to be either empty or equal to number "
+ "of loops");
+ Value zero, one;
+ if (!numTiles.empty()) {
+ zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ }
+
+ for (auto [index, loopRange, tileSize] :
+ llvm::enumerate(loopRanges, tileSizes)) {
// No loops if tile size is zero. Set offset and size to the loop
// offset and size.
if (isConstantIntValue(tileSize, 0))
continue;
- Value lb = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
- Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
- Value step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize);
+ Value lb, ub, step;
+ if (numTiles.empty()) {
+ lb = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
+ ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
+ step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize);
+ } else {
+ lb = zero;
+ ub = getValueOrCreateConstantIndexOp(rewriter, loc, numTiles[index]);
+ step = one;
+ }
auto loop =
rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
[](OpBuilder &bodyBuilder, Location bodyLoc,
@@ -220,32 +390,45 @@ static LogicalResult generateLoopNestUsingForOp(
/// populated.
static LogicalResult generateLoopNestUsingForallOp(
RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
- ArrayRef<OpFoldResult> tileSizes, ArrayRef<Attribute> mappingVector,
- ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn,
- SmallVector<LoopLikeOpInterface> &loops) {
- SmallVector<OpFoldResult> lbs, ubs, steps;
+ ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numTiles,
+ ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
+ YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
assert(!loopRanges.empty() && "unexpected empty loop ranges");
assert(loopRanges.size() == tileSizes.size() &&
"expected as many tile sizes as loop ranges");
+ assert((numTiles.empty() || numTiles.size() == loopRanges.size()) &&
+ "expected max number of tiles to be either empty or equal to number "
+ "of loops");
OpBuilder::InsertionGuard guard(rewriter);
SmallVector<OpFoldResult> offsets(loopRanges.size()),
sizes(loopRanges.size());
- for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
- if (isConstantIntValue(tileSize, 0))
- continue;
- lbs.push_back(loopRange.offset);
- ubs.push_back(loopRange.size);
- steps.push_back(tileSize);
- }
- assert(!lbs.empty() && "Expected at least one loop range");
-
std::optional<ArrayAttr> mappingAttr;
if (!mappingVector.empty())
mappingAttr = rewriter.getArrayAttr(mappingVector);
- auto forallOp = rewriter.create<scf::ForallOp>(
- loc, lbs, ubs, steps, destinationTensors, mappingAttr);
+ scf::ForallOp forallOp;
+ SmallVector<OpFoldResult> lbs, ubs, steps;
+ if (numTiles.empty()) {
+ for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
+ if (isConstantIntValue(tileSize, 0))
+ continue;
+ lbs.push_back(loopRange.offset);
+ ubs.push_back(loopRange.size);
+ steps.push_back(tileSize);
+ }
+ assert(!lbs.empty() && "Expected at least one loop range");
+ forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
+ destinationTensors, mappingAttr);
+ } else {
+ SmallVector<OpFoldResult> numThreads;
+ for (auto [tileSize, numTile] : llvm::zip_equal(tileSizes, numTiles)) {
+ if (!isConstantIntValue(tileSize, 0))
+ numThreads.push_back(numTile);
+ }
+ forallOp = rewriter.create<scf::ForallOp>(loc, numThreads,
+ destinationTensors, mappingAttr);
+ }
loops.push_back(forallOp);
rewriter.setInsertionPoint(forallOp.getTerminator());
@@ -282,13 +465,11 @@ static LogicalResult generateLoopNestUsingForallOp(
/// loop.
/// - `loops` is an in-out parameter into which the generated loops are
/// populated.
-static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
- const scf::SCFTilingOptions &options,
- ArrayRef<Range> loopRanges,
- ArrayRef<OpFoldResult> tileSizes,
- ValueRange destinationTensors,
- YieldTiledValuesFn tiledBodyFn,
- SmallVector<LoopLikeOpInterface> &loops) {
+static LogicalResult generateLoopNest(
+ RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
+ ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
+ ArrayRef<OpFoldResult> numTiles, ValueRange destinationTensors,
+ YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
// If the tile sizes are all zero, no loops are generated. Just call the
// callback function to handle untiled case.
if (llvm::all_of(tileSizes, isZeroIndex)) {
@@ -299,11 +480,12 @@ static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
}
if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
- destinationTensors, tiledBodyFn, loops);
+ numTiles, destinationTensors, tiledBodyFn,
+ loops);
}
if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
return generateLoopNestUsingForallOp(
- rewriter, loc, loopRanges, tileSizes, options.mappingVector,
+ rewriter, loc, loopRanges, tileSizes, numTiles, options.mappingVector,
destinationTensors, tiledBodyFn, loops);
}
return rewriter.notifyMatchFailure(loc, "unhandled loop type");
@@ -527,28 +709,20 @@ static LogicalResult addInitOperandsToLoopNest(
FailureOr<scf::SCFTilingResult>
mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
const scf::SCFTilingOptions &options) {
+ if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
+ return failure();
+ }
+
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(op);
- if (!options.tileSizeComputationFunction) {
- return rewriter.notifyMatchFailure(
- op, "missing tile size computation function");
- }
-
// 1. Get the range of the loops that are represented by the operation.
SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
- size_t numLoops = iterationDomain.size();
- // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
- // skips tiling a particular dimension. This convention is significantly
- // simpler to handle instead of adjusting affine maps to account for missing
- // dimensions.
- SmallVector<OpFoldResult> tileSizes =
- options.tileSizeComputationFunction(rewriter, op);
- if (tileSizes.size() < iterationDomain.size()) {
- auto zero = rewriter.getIndexAttr(0);
- tileSizes.append(numLoops - tileSizes.size(), zero);
- }
+ // 2. Materialize the tile sizes or max num tiles;
+ SmallVector<OpFoldResult> tileSizes, numTiles;
+ std::tie(tileSizes, numTiles) =
+ getTileSizesAndNumTiles(rewriter, op, iterationDomain, options);
// 3. If there is an interchange specified, permute the iteration domain and
// the tile sizes.
@@ -556,16 +730,18 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
if (!options.interchangeVector.empty()) {
interchangeVector = fillInterchangeVector(options.interchangeVector,
iterationDomain.size());
- }
- if (!interchangeVector.empty()) {
- if (!isPermutationVector(interchangeVector)) {
- return rewriter.notifyMatchFailure(
- op, "invalid intechange vector, not a permutation of the entire "
- "iteration space");
- }
+ // Re-check after filling: verifyTileSizeOptions only validated the
+ // user-provided vector, not the one adjusted to the iteration domain size.
+ if (!isPermutationVector(interchangeVector)) {
+ return rewriter.notifyMatchFailure(
+ op, "invalid intechange vector, not a permutation of the entire "
+ "iteration space");
+ }
applyPermutationToVector(iterationDomain, interchangeVector);
applyPermutationToVector(tileSizes, interchangeVector);
+ if (!numTiles.empty())
+ applyPermutationToVector(numTiles, interchangeVector);
}
FailureOr<TilingResult> tilingResult;
@@ -579,21 +755,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
-> LogicalResult {
// 4a. Compute the `offsets` and `sizes` to use for tiling.
SmallVector<OpFoldResult> offsets, sizes;
- {
- int materializedLoopNum = 0;
- for (auto [tileSize, loopRange] :
- llvm::zip_equal(tileSizes, iterationDomain)) {
- if (isConstantIntValue(tileSize, 0)) {
- offsets.push_back(loopRange.offset);
- sizes.push_back(loopRange.size);
- continue;
- }
- Value iv = ivs[materializedLoopNum++];
- offsets.push_back(iv);
- sizes.push_back(
- getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
- }
- }
+ std::tie(offsets, sizes) = getTileOffsetAndSizes(
+ rewriter, loc, ivs, iterationDomain, tileSizes, !numTiles.empty());
// 4b. If interchange was provided, apply inverse of the interchange
// to get back the offsets/sizes in the order to be specified.
@@ -661,7 +824,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
// 7. Generate the tiled loops nest using the callback defined above.
SmallVector<LoopLikeOpInterface> loops;
if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
- tileSizes, destinationTensors,
+ tileSizes, numTiles, destinationTensors,
innerYieldTiledValuesFn, loops)))
return op.emitOpError("failed to generate tiling loops");
assert(succeeded(tilingResult) &&
@@ -774,6 +937,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
scf::SCFTilingOptions options;
options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
+ /*numTiles=*/ArrayRef<OpFoldResult>{},
destinationTensors, innerYieldTiledValuesFn,
loops)))
return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index 8545dfd25eccf..f33739f119eaf 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -177,7 +177,6 @@ module attributes {transform.with_named_sequence} {
}
}
-
// -----
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 335db1a61f476..d4126f04a2f35 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -182,11 +182,8 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
if (mapping) {
- auto mappingAttrs =
- llvm::map_to_vector(mapping.value(), [](Attribute attr) {
- return cast<DeviceMappingAttrInterface>(attr);
- });
- tilingOptions.setMapping(mappingAttrs);
+ // Forward the mapping attributes as-is (no DeviceMappingAttrInterface cast).
+ tilingOptions.setMapping(mapping.value().getValue());
}
tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
More information about the Mlir-commits
mailing list