[Mlir-commits] [mlir] [mlir][Linalg] Deprecate `linalg::tileToForallOp` and `linalg::tileToForallOpUsingTileSizes` (PR #91878)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jun 14 00:03:39 PDT 2024
https://github.com/MaheshRavishankar updated https://github.com/llvm/llvm-project/pull/91878
>From 4151c40f962d74f9b7953cc3b5e4f8282a6538f6 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh at nod-labs.com>
Date: Sat, 11 May 2024 13:38:36 -0700
Subject: [PATCH 1/9] [mlir][SCF] Allow tiling by specifying maximum number of
tiles.
---
.../Linalg/TransformOps/LinalgTransformOps.h | 6 +-
.../Dialect/Linalg/Transforms/Transforms.h | 24 --
.../SCF/Transforms/TileUsingInterface.h | 35 ++-
.../TransformOps/LinalgTransformOps.cpp | 47 ++-
mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp | 182 ------------
.../SCF/Transforms/TileUsingInterface.cpp | 271 +++++++++++++-----
mlir/test/Dialect/Linalg/tile-to-forall.mlir | 1 -
.../TestTilingInterfaceTransformOps.cpp | 6 +-
8 files changed, 270 insertions(+), 302 deletions(-)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 3af642752724c..db25c9b241734 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -30,6 +30,10 @@ class GenericOp;
class LinalgOp;
} // namespace linalg
+namespace scf {
+struct SCFTilingResult;
+} // namespace scf
+
namespace tensor {
class InsertSliceOp;
class PackOp;
@@ -60,7 +64,7 @@ tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state,
ArrayRef<OpFoldResult> mixedNumThreads,
ArrayRef<OpFoldResult> mixedTileSizes,
std::optional<ArrayAttr> mapping,
- linalg::ForallTilingResult &tilingResult);
+ scf::SCFTilingResult &tilingResult);
} // namespace transform
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 308ce92e35520..e9b10de68bc44 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -846,30 +846,6 @@ FailureOr<StaticMultiSizeSpecification>
computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
int64_t divisor);
-/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
-/// tiling by `numThreads`.
-/// If non-empty, the `mapping` is added as an attribute to the
-/// resulting `scf.forall`.
-/// Zero tile sizes indicate that the dimension is not tiled, and can be
-/// thought of as tiling by the full size of data. It is the user's
-/// responsibility to ensure that `numThreads` is a valid tiling specification
-/// (i.e. that only tiles parallel dimensions, e.g. in the Linalg case).
-struct ForallTilingResult {
- Operation *tileOp;
- Operation *tiledOp;
-};
-FailureOr<ForallTilingResult> tileToForallOp(RewriterBase &builder,
- TilingInterface op,
- ArrayRef<OpFoldResult> numThreads,
- std::optional<ArrayAttr> mapping);
-
-/// Same as `tileToForallOp`, but calculate the number of threads
-/// required using the given tileSizes.
-FailureOr<ForallTilingResult>
-tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
- ArrayRef<OpFoldResult> tileSizes,
- std::optional<ArrayAttr> mapping);
-
/// Transformation information returned after reduction tiling.
struct ForallReductionTilingResult {
/// The partial reduction tiled op generated.
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index dac79111af3c9..451a21c766175 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -32,9 +32,13 @@ using SCFTileSizeComputationFunction =
/// Options to use to control tiling.
struct SCFTilingOptions {
- /// Computation function that returns the tile sizes for each operation.
- /// Delayed construction of constant tile sizes should occur to interoperate
- /// with folding.
+ /// Computation function that returns the tile sizes to use for each loop.
+ /// Returning a tile size of zero implies no tiling for that loop. If the
+ /// size of the returned vector is smaller than the number of loops, the inner
+ /// loops are not tiled. If the size of the returned vector is larger, then
+  /// the vector is truncated to the number of loops. Only one of
+ /// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
+ /// be used.
SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
SCFTilingOptions &
@@ -45,7 +49,25 @@ struct SCFTilingOptions {
/// Convenience function to set the `tileSizeComputationFunction` to a
/// function that computes tile sizes at the point they are needed. Allows
/// proper interaction with folding.
- SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
+ SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
+
+  /// Computation function that returns the maximum number of tiles to use for
+  /// each loop. Returning zero implies no tiling for that loop.
+  /// If the size of the returned vector is smaller than the number of loops,
+  /// the inner loops are not tiled. If the size of the returned vector is
+  /// larger, then the vector is truncated to the number of loops. Only one of
+  /// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
+  /// be used.
+ SCFTileSizeComputationFunction maxNumTilesComputationFunction = nullptr;
+
+ SCFTilingOptions &
+ setMaxNumTilesComputationFunction(SCFTileSizeComputationFunction fun) {
+ maxNumTilesComputationFunction = std::move(fun);
+ return *this;
+ }
+  /// Convenience function to set the `maxNumTilesComputationFunction` to a
+  /// function that returns the given max number of tiles when needed.
+ SCFTilingOptions &setMaxNumTiles(ArrayRef<OpFoldResult> numTiles);
/// The interchange vector to reorder the tiled loops.
SmallVector<int64_t> interchangeVector = {};
@@ -67,9 +89,8 @@ struct SCFTilingOptions {
/// when using loop constructs that dont support such a mapping (like
/// `scf.for`)
SmallVector<Attribute> mappingVector = {};
- SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
- mappingVector = llvm::map_to_vector(
- mapping, [](auto attr) -> Attribute { return attr; });
+ SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
+ mappingVector = llvm::to_vector(mapping);
return *this;
}
};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 9b3121774ab3a..bce0d8a0f65db 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2919,7 +2919,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
TransformOpInterface transformOp, Operation *target,
ArrayRef<OpFoldResult> mixedNumThreads,
ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
- linalg::ForallTilingResult &tilingResult) {
+ scf::SCFTilingResult &tilingResult) {
// Transform all targets one by one.
auto tileableOp = dyn_cast<TilingInterface>(target);
if (!tileableOp) {
@@ -2930,18 +2930,38 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
return diag;
}
rewriter.setInsertionPoint(tileableOp);
- FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
+ scf::SCFTilingOptions options;
+ options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
if (!mixedNumThreads.empty()) {
- maybeTilingResult =
- linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
+ options.setMaxNumTiles(mixedNumThreads);
} else {
- maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
- rewriter, tileableOp, mixedTileSizes, mapping);
+ SmallVector<Range> loopRanges = tileableOp.getIterationDomain(rewriter);
+ unsigned nLoops = loopRanges.size();
+ SmallVector<OpFoldResult> numThreads;
+ numThreads.reserve(nLoops);
+ AffineExpr s0, s1;
+ bindSymbols(rewriter.getContext(), s0, s1);
+ AffineExpr divExpr = s0.ceilDiv(s1);
+ for (int i = 0, e = std::min(mixedTileSizes.size(), loopRanges.size());
+ i < e; ++i) {
+ OpFoldResult numTiles = mixedTileSizes[i];
+ if (!isConstantIntValue(numTiles, 0))
+ numTiles = affine::makeComposedFoldedAffineApply(
+ rewriter, tileableOp.getLoc(), divExpr,
+ {loopRanges[i].size, numTiles});
+ numThreads.push_back(numTiles);
+ }
+ options.setMaxNumTiles(numThreads);
+ }
+ if (mapping) {
+ options.setMapping(mapping.value().getValue());
}
+ FailureOr<scf::SCFTilingResult> maybeTilingResult =
+ scf::tileUsingSCF(rewriter, tileableOp, options);
if (failed(maybeTilingResult))
return transformOp.emitDefaultSilenceableFailure(tileableOp);
- rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
+ rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
tilingResult = *maybeTilingResult;
return DiagnosedSilenceableFailure::success();
@@ -2977,14 +2997,14 @@ DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
return status;
for (Operation *target : state.getPayloadOps(getTarget())) {
- linalg::ForallTilingResult tilingResult;
+ scf::SCFTilingResult tilingResult;
DiagnosedSilenceableFailure diag = tileToForallOpImpl(
rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
getMapping(), tilingResult);
if (!diag.succeeded())
return diag;
- tileOps.push_back(tilingResult.tileOp);
- tiledOps.push_back(tilingResult.tiledOp);
+ tileOps.push_back(tilingResult.loops.front());
+ tiledOps.append(tilingResult.tiledOps);
}
transformResults.set(cast<OpResult>(getForallOp()), tileOps);
@@ -3462,7 +3482,7 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
// OpBuilder only used to compute attributes.
OpBuilder b(getContext());
- linalg::ForallTilingResult tilingResult;
+ scf::SCFTilingResult tilingResult;
DiagnosedSilenceableFailure diag = tileToForallOpImpl(
/*rewriter=*/rewriter,
/*state=*/state,
@@ -3475,8 +3495,9 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
if (!diag.succeeded())
return diag;
- results.push_back(tilingResult.tileOp);
- results.push_back(tilingResult.tiledOp);
+ results.push_back(tilingResult.loops.front());
+ for (auto op : tilingResult.tiledOps)
+ results.push_back(op);
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index a0a0e11a6903d..aa2f3a7db2946 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -304,188 +304,6 @@ static void calculateTileOffsetsAndSizes(
}
}
-/// Returns a vector of bools representing if, for each axis, `op` can be tiled
-/// without incurring in a race condition and thus it is thread-safe to do the
-/// tiling. This is checked by iterating over numThreads and ensuring that the
-/// corresponding iterator type is "parallel". If it is not, then we know that
-/// such dimension is unsafe to tile.
-SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
- ArrayRef<OpFoldResult> numThreads) {
- auto iterators = linalgOp.getIteratorTypesArray();
- SmallVector<bool> safeToTile(numThreads.size(), true);
-
- for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
- if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
- if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
- safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
- }
- } else {
- safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
- }
- }
- return safeToTile;
-}
-
-/// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The
-/// tiling is specified by the number of tiles/threads `numThreads` and the
-/// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is
-/// not specified, then it is derived from `numThreads` as `ceilDiv(dimSize[i],
-/// numThreads[i])`. If non-empty, the `mapping` is added as an
-/// attribute to the resulting `scf.forall`. A zero tile sizes indicate
-/// that the dimension is not tiled, and can be thought of as tiling by the full
-/// size of data.
-/// It is the user's responsibility to ensure that `numThreads` is a valid
-/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
-/// Linalg case). If the dimension is not parallelizable, a warning is issued to
-/// notify the user that the generated code is not safe to parallelize. If
-/// `omitTileOffsetBoundsCheck` is true, then the function will assume that
-/// `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
-static FailureOr<ForallTilingResult> tileToForallOpImpl(
- RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
- std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
- std::optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
- Location loc = op->getLoc();
- OpBuilder::InsertionGuard g(b);
-
- SmallVector<Range> loopRanges = op.getIterationDomain(b);
- if (loopRanges.empty())
- return op->emitOpError("expected non-empty loop ranges");
- auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
- if (llvm::any_of(loopRanges, hasStrideOne))
- return op->emitOpError("only stride-1 supported atm");
-
- // Gather destination tensors.
- SmallVector<Value> dest;
- if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
- return op->emitOpError("failed to get destination tensors");
-
- SmallVector<OpFoldResult> nonZeroNumThreads =
- llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
- return !isConstantIntValue(ofr, 0);
- }));
- SmallVector<Value> materializedNonZeroNumThreads =
- llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
- return getValueOrCreateConstantIndexOp(b, loc, ofr);
- }));
-
- LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
- if (linalgOp) {
- // Check if tiling is thread safe and print a warning if not.
- SmallVector<bool> tilingSafety =
- safeToTileToForall(b.getContext(), linalgOp, numThreads);
- for (size_t i = 0; i < tilingSafety.size(); i++)
- if (!tilingSafety[i])
- op.emitWarning() << "tiling is not thread safe at axis #" << i;
- }
-
- // 1. Create the ForallOp. We don't use the lambda body-builder
- // version because we require the use of RewriterBase in the body, so we
- // manually move the insertion point to the body below.
- scf::ForallOp forallOp = b.create<scf::ForallOp>(
- loc, getAsOpFoldResult((materializedNonZeroNumThreads)), dest, mapping);
-
- // 2. Fill out the ForallOp body.
- SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
- calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, loopRanges,
- omitTileOffsetBoundsCheck, nominalTileSizes,
- tiledOffsets, tiledSizes);
-
- // 3. Clone the tileable op and update its destination operands to use the
- // output bbArgs of the ForallOp.
- ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
- Operation *tiledOp = nullptr;
- SmallVector<Value> tiledValues;
- {
- // 3.a. RAII guard, inserting within forallOp, before terminator.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(forallOp.getTerminator());
- Operation *clonedOp = b.clone(*op.getOperation());
- auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
- if (destinationStyleOp) {
- for (OpOperand &outOperand : destinationStyleOp.getDpsInitsMutable()) {
- // Swap tensor inits with the corresponding block argument of the
- // scf.forall op. Memref inits remain as is.
- if (isa<TensorType>(outOperand.get().getType())) {
- auto *it = llvm::find(dest, outOperand.get());
- assert(it != dest.end() && "could not find destination tensor");
- unsigned destNum = std::distance(dest.begin(), it);
- outOperand.set(destBbArgs[destNum]);
- }
- }
- }
-
- // 4. Tile the cloned op and delete the clone.
- FailureOr<TilingResult> tilingResult =
- cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
- tiledSizes);
- if (failed(tilingResult))
- return clonedOp->emitError("Failed to tile op: ");
- if (tilingResult->tiledOps.size() != 1) {
- return clonedOp->emitError("expected a single produced tiled op, got ")
- << tilingResult->tiledOps.size();
- }
-
- b.eraseOp(clonedOp);
- tiledOp = tilingResult->tiledOps.front();
- tiledValues = tilingResult->tiledValues;
- }
-
- // 5. Parallel insert back into the result tensor.
- for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
- tiledValues, destBbArgs)) {
- // 5.a. Partial subset information is inserted just before the terminator.
- OpBuilder::InsertionGuard g(b);
- b.setInsertionPoint(forallOp.getTerminator());
-
- SmallVector<OpFoldResult> resultOffsets, resultSizes;
- if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
- tiledSizes, resultOffsets,
- resultSizes)))
- return op->emitOpError("output offsets couldn't be calculated");
- SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
-
- // 5.b. Parallel insertions are inserted at the end of the combining
- // terminator.
- b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
- b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
- std::get<2>(it), resultOffsets,
- resultSizes, strides);
- }
- return ForallTilingResult{forallOp, tiledOp};
-}
-
-FailureOr<ForallTilingResult>
-linalg::tileToForallOp(RewriterBase &b, TilingInterface op,
- ArrayRef<OpFoldResult> numThreads,
- std::optional<ArrayAttr> mapping) {
- return tileToForallOpImpl(b, op, numThreads,
- /*nominalTileSizes=*/std::nullopt, mapping,
- /*omitTileOffsetBoundsCheck=*/false);
-}
-
-FailureOr<ForallTilingResult>
-linalg::tileToForallOpUsingTileSizes(RewriterBase &b, TilingInterface op,
- ArrayRef<OpFoldResult> tileSizes,
- std::optional<ArrayAttr> mapping) {
- SmallVector<Range> loopRanges = op.getIterationDomain(b);
- unsigned nLoops = loopRanges.size();
- SmallVector<OpFoldResult> numThreads;
- numThreads.reserve(nLoops);
- AffineExpr s0, s1;
- bindSymbols(b.getContext(), s0, s1);
- AffineExpr divExpr = s0.ceilDiv(s1);
- for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
- OpFoldResult numTiles = std::get<0>(it);
- if (!isConstantIntValue(numTiles, 0))
- numTiles = makeComposedFoldedAffineApply(
- b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
- numThreads.push_back(numTiles);
- }
- return tileToForallOpImpl(b, op, numThreads,
- /*nominalTileSizes=*/tileSizes, mapping,
- /*omitTileOffsetBoundsCheck=*/true);
-}
-
template <typename LoopTy>
static FailureOr<TiledLinalgOp>
tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..c328e1068dccc 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -42,6 +42,16 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
return *this;
}
+scf::SCFTilingOptions &
+scf::SCFTilingOptions::setMaxNumTiles(ArrayRef<OpFoldResult> mnt) {
+ assert(!maxNumTilesComputationFunction && "max num tiles already set");
+ auto maxNumTiles = llvm::to_vector(mnt);
+ maxNumTilesComputationFunction = [maxNumTiles](OpBuilder &b, Operation *op) {
+ return maxNumTiles;
+ };
+ return *this;
+}
+
/// Helper method to adjust the interchange vector to match the iteration
/// domain.
static SmallVector<int64_t>
@@ -61,6 +71,85 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
// tileUsingSCF implementation.
//===----------------------------------------------------------------------===//
+/// Verify the tile size options are set in a consistent manner.
+static LogicalResult
+verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
+ const scf::SCFTilingOptions &options) {
+ if (!options.tileSizeComputationFunction &&
+ !options.maxNumTilesComputationFunction) {
+ return rewriter.notifyMatchFailure(
+ loc, "at least one of tile size computation function or max num tiles "
+ "computation must be specified.");
+ }
+ if (options.tileSizeComputationFunction &&
+ options.maxNumTilesComputationFunction) {
+ return rewriter.notifyMatchFailure(
+ loc, "only one of tile size computation function or max num tiles "
+ "computation function can be specified");
+ }
+
+ // If specified, check that the interchange vector is a permutation.
+ if (!options.interchangeVector.empty()) {
+ if (!isPermutationVector(options.interchangeVector)) {
+ return rewriter.notifyMatchFailure(
+ loc, "invalid intechange vector, not a permutation of the entire "
+ "iteration space");
+ }
+ }
+ return success();
+}
+
+/// Compute the tile sizes and num tiles values. The `numTiles`
+/// is empty if the `maxNumTilesComputationFunction` is not specified.
+static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
+getTileSizesAndNumTiles(RewriterBase &rewriter, TilingInterface op,
+ ArrayRef<Range> iterationDomain,
+ const scf::SCFTilingOptions &options) {
+ SmallVector<OpFoldResult> tileSizes, numTiles;
+
+ // Enforce the convention that "tiling by zero"
+ // skips tiling a particular dimension. This convention is significantly
+ // simpler to handle instead of adjusting affine maps to account for missing
+ // dimensions.
+ auto numLoops = iterationDomain.size();
+ if (options.tileSizeComputationFunction) {
+ tileSizes = options.tileSizeComputationFunction(rewriter, op);
+ tileSizes.resize(numLoops, rewriter.getIndexAttr(0));
+ return {tileSizes, numTiles};
+ }
+
+ assert(options.maxNumTilesComputationFunction &&
+ "expected at least one of tile sizes cpomputation function or max num "
+ "tiles computation function");
+  // Enforce the convention that a "maxNumTiles of zero"
+ // skips tiling a particular dimension. This convention is significantly
+ // simpler to handle instead of adjusting affine maps to account for missing
+ // dimensions.
+ SmallVector<OpFoldResult> maxNumTiles =
+ options.maxNumTilesComputationFunction(rewriter, op);
+ maxNumTiles.resize(numLoops, rewriter.getIndexAttr(0));
+
+ // Use the maxNumTiles to compute the tile sizes as
+ // - niters = ceilDiv(ub - lb, step)
+ // - tileSize = ceilDiv(niters, maxNumTiles)
+ AffineExpr s0, s1, s2, s3;
+ bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
+ AffineExpr numIters = (s1 - s0).ceilDiv(s2);
+ AffineExpr tileSizeExpr = numIters.ceilDiv(s3);
+ tileSizes.resize(numLoops, rewriter.getIndexAttr(0));
+ for (auto [index, maxNumTile] : llvm::enumerate(maxNumTiles)) {
+ if (isConstantIntValue(maxNumTile, 0))
+ continue;
+
+ tileSizes[index] = affine::makeComposedFoldedAffineApply(
+ rewriter, op.getLoc(), tileSizeExpr,
+ {iterationDomain[index].offset, iterationDomain[index].size,
+ iterationDomain[index].stride, maxNumTile});
+ }
+
+ return {tileSizes, maxNumTiles};
+}
+
// Check if `stride` evenly divides the trip count `size - offset`.
static bool tileDividesIterationDomain(Range loopRange) {
std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
@@ -100,6 +189,46 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
}
+/// Compute the tile offsets and sizes.
+static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
+getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
+ ArrayRef<Range> iterationDomain,
+ ArrayRef<OpFoldResult> tileSizes, bool isLoopNormalized) {
+ SmallVector<OpFoldResult> offsets, sizes;
+ int materializedLoopNum = 0;
+
+ AffineExpr d0, s0, s1, s2;
+ AffineExpr offsetExpr;
+ if (isLoopNormalized) {
+ bindDims(rewriter.getContext(), d0);
+ bindSymbols(rewriter.getContext(), s0, s1, s2);
+ offsetExpr = s0 + d0 * s1 * s2;
+ }
+
+ for (auto [tileSize, loopRange] :
+ llvm::zip_equal(tileSizes, iterationDomain)) {
+ if (isConstantIntValue(tileSize, 0)) {
+ offsets.push_back(loopRange.offset);
+ sizes.push_back(loopRange.size);
+ continue;
+ }
+ // If loop is normalized, the offset is (lb + iv * step * tileSize)
+ Value iv = ivs[materializedLoopNum++];
+ OpFoldResult offset;
+ if (isLoopNormalized) {
+ offset = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, offsetExpr,
+ ArrayRef<OpFoldResult>{iv, loopRange.offset, loopRange.stride,
+ tileSize});
+ } else {
+ offset = getAsOpFoldResult(iv);
+ }
+ offsets.push_back(offset);
+ sizes.push_back(getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+ }
+ return {offsets, sizes};
+}
+
/// A function that allows returning additional yielded values during
/// `yieldTiledValuesAndReplace`.
/// - `ivs` induction variable for the loop.
@@ -145,8 +274,8 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
/// populated.
static LogicalResult generateLoopNestUsingForOp(
RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
- ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
- YieldTiledValuesFn yieldTiledValuesFn,
+ ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numTiles,
+ ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn,
SmallVector<LoopLikeOpInterface> &loops) {
assert(!loopRanges.empty() && "unexpected empty loop ranges");
assert(loopRanges.size() == tileSizes.size() &&
@@ -154,15 +283,30 @@ static LogicalResult generateLoopNestUsingForOp(
OpBuilder::InsertionGuard guard(rewriter);
SmallVector<Value> ivs;
- for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
+ Value zero, one;
+ if (!numTiles.empty()) {
+ zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+ ;
+ one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+ }
+
+ for (auto [index, loopRange, tileSize] :
+ llvm::enumerate(loopRanges, tileSizes)) {
// No loops if tile size is zero. Set offset and size to the loop
// offset and size.
if (isConstantIntValue(tileSize, 0))
continue;
- Value lb = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
- Value ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
- Value step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize);
+ Value lb, ub, step;
+ if (numTiles.empty()) {
+ lb = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
+ ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
+ step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize);
+ } else {
+ lb = zero;
+ ub = getValueOrCreateConstantIndexOp(rewriter, loc, numTiles[index]);
+ step = one;
+ }
auto loop =
rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
[](OpBuilder &bodyBuilder, Location bodyLoc,
@@ -224,32 +368,45 @@ static LogicalResult generateLoopNestUsingForOp(
/// populated.
static LogicalResult generateLoopNestUsingForallOp(
RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
- ArrayRef<OpFoldResult> tileSizes, ArrayRef<Attribute> mappingVector,
- ValueRange destinationTensors, YieldTiledValuesFn tiledBodyFn,
- SmallVector<LoopLikeOpInterface> &loops) {
- SmallVector<OpFoldResult> lbs, ubs, steps;
+ ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numTiles,
+ ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
+ YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
assert(!loopRanges.empty() && "unexpected empty loop ranges");
assert(loopRanges.size() == tileSizes.size() &&
"expected as many tile sizes as loop ranges");
+ assert((numTiles.empty() || numTiles.size() == loopRanges.size()) &&
+ "expected max number of tiles to be either empty or equal to number "
+ "of loops");
OpBuilder::InsertionGuard guard(rewriter);
SmallVector<OpFoldResult> offsets(loopRanges.size()),
sizes(loopRanges.size());
- for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
- if (isConstantIntValue(tileSize, 0))
- continue;
- lbs.push_back(loopRange.offset);
- ubs.push_back(loopRange.size);
- steps.push_back(tileSize);
- }
- assert(!lbs.empty() && "Expected at least one loop range");
-
std::optional<ArrayAttr> mappingAttr;
if (!mappingVector.empty())
mappingAttr = rewriter.getArrayAttr(mappingVector);
- auto forallOp = rewriter.create<scf::ForallOp>(
- loc, lbs, ubs, steps, destinationTensors, mappingAttr);
+ scf::ForallOp forallOp;
+ SmallVector<OpFoldResult> lbs, ubs, steps;
+ if (numTiles.empty()) {
+ for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
+ if (isConstantIntValue(tileSize, 0))
+ continue;
+ lbs.push_back(loopRange.offset);
+ ubs.push_back(loopRange.size);
+ steps.push_back(tileSize);
+ }
+ assert(!lbs.empty() && "Expected at least one loop range");
+ forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
+ destinationTensors, mappingAttr);
+ } else {
+ SmallVector<OpFoldResult> numThreads;
+ for (auto maxNumTile : numTiles) {
+ if (!isConstantIntValue(maxNumTile, 0))
+ numThreads.push_back(maxNumTile);
+ }
+ forallOp = rewriter.create<scf::ForallOp>(loc, numThreads,
+ destinationTensors, mappingAttr);
+ }
loops.push_back(forallOp);
rewriter.setInsertionPoint(forallOp.getTerminator());
@@ -286,13 +443,11 @@ static LogicalResult generateLoopNestUsingForallOp(
/// loop.
/// - `loops` is an in-out parameter into which the generated loops are
/// populated.
-static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
- const scf::SCFTilingOptions &options,
- ArrayRef<Range> loopRanges,
- ArrayRef<OpFoldResult> tileSizes,
- ValueRange destinationTensors,
- YieldTiledValuesFn tiledBodyFn,
- SmallVector<LoopLikeOpInterface> &loops) {
+static LogicalResult generateLoopNest(
+ RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
+ ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
+ ArrayRef<OpFoldResult> numTiles, ValueRange destinationTensors,
+ YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
// If the tile sizes are all zero, no loops are generated. Just call the
// callback function to handle untiled case.
if (llvm::all_of(tileSizes, isZeroIndex)) {
@@ -303,11 +458,12 @@ static LogicalResult generateLoopNest(RewriterBase &rewriter, Location loc,
}
if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
- destinationTensors, tiledBodyFn, loops);
+ numTiles, destinationTensors, tiledBodyFn,
+ loops);
}
if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
return generateLoopNestUsingForallOp(
- rewriter, loc, loopRanges, tileSizes, options.mappingVector,
+ rewriter, loc, loopRanges, tileSizes, numTiles, options.mappingVector,
destinationTensors, tiledBodyFn, loops);
}
return rewriter.notifyMatchFailure(loc, "unhandled loop type");
@@ -531,28 +687,20 @@ static LogicalResult addInitOperandsToLoopNest(
FailureOr<scf::SCFTilingResult>
mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
const scf::SCFTilingOptions &options) {
+ if (failed(verifyTileSizeOptions(rewriter, op.getLoc(), options))) {
+ return failure();
+ }
+
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(op);
- if (!options.tileSizeComputationFunction) {
- return rewriter.notifyMatchFailure(
- op, "missing tile size computation function");
- }
-
// 1. Get the range of the loops that are represented by the operation.
SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
- size_t numLoops = iterationDomain.size();
- // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
- // skips tiling a particular dimension. This convention is significantly
- // simpler to handle instead of adjusting affine maps to account for missing
- // dimensions.
- SmallVector<OpFoldResult> tileSizes =
- options.tileSizeComputationFunction(rewriter, op);
- if (tileSizes.size() < iterationDomain.size()) {
- auto zero = rewriter.getIndexAttr(0);
- tileSizes.append(numLoops - tileSizes.size(), zero);
- }
+  // 2. Materialize the tile sizes and, if requested, the max num tiles.
+ SmallVector<OpFoldResult> tileSizes, numTiles;
+ std::tie(tileSizes, numTiles) =
+ getTileSizesAndNumTiles(rewriter, op, iterationDomain, options);
// 3. If there is an interchange specified, permute the iteration domain and
// the tile sizes.
@@ -560,16 +708,13 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
if (!options.interchangeVector.empty()) {
interchangeVector = fillInterchangeVector(options.interchangeVector,
iterationDomain.size());
- }
- if (!interchangeVector.empty()) {
- if (!isPermutationVector(interchangeVector)) {
- return rewriter.notifyMatchFailure(
- op, "invalid intechange vector, not a permutation of the entire "
- "iteration space");
- }
+ assert(isPermutationVector(interchangeVector) &&
+ "expected interchange vector to be a permutation");
applyPermutationToVector(iterationDomain, interchangeVector);
applyPermutationToVector(tileSizes, interchangeVector);
+ if (!numTiles.empty())
+ applyPermutationToVector(numTiles, interchangeVector);
}
FailureOr<TilingResult> tilingResult;
@@ -583,21 +728,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
-> LogicalResult {
// 4a. Compute the `offsets` and `sizes` to use for tiling.
SmallVector<OpFoldResult> offsets, sizes;
- {
- int materializedLoopNum = 0;
- for (auto [tileSize, loopRange] :
- llvm::zip_equal(tileSizes, iterationDomain)) {
- if (isConstantIntValue(tileSize, 0)) {
- offsets.push_back(loopRange.offset);
- sizes.push_back(loopRange.size);
- continue;
- }
- Value iv = ivs[materializedLoopNum++];
- offsets.push_back(iv);
- sizes.push_back(
- getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
- }
- }
+ std::tie(offsets, sizes) = getTileOffsetAndSizes(
+ rewriter, loc, ivs, iterationDomain, tileSizes, !numTiles.empty());
// 4b. If interchange was provided, apply inverse of the interchange
// to get back the offsets/sizes in the order to be specified.
@@ -665,7 +797,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
// 7. Generate the tiled loops nest using the callback defined above.
SmallVector<LoopLikeOpInterface> loops;
if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
- tileSizes, destinationTensors,
+ tileSizes, numTiles, destinationTensors,
innerYieldTiledValuesFn, loops)))
return op.emitOpError("failed to generate tiling loops");
assert(succeeded(tilingResult) &&
@@ -774,6 +906,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
scf::SCFTilingOptions options;
options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
+ /*numTiles=*/ArrayRef<OpFoldResult>{},
initTensors, innerYieldTiledValuesFn, loops)))
return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index 8545dfd25eccf..f33739f119eaf 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -177,7 +177,6 @@ module attributes {transform.with_named_sequence} {
}
}
-
// -----
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 833fb3cc65b81..abe41782c9b60 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -235,11 +235,7 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
if (mapping) {
- auto mappingAttrs =
- llvm::map_to_vector(mapping.value(), [](Attribute attr) {
- return cast<DeviceMappingAttrInterface>(attr);
- });
- tilingOptions.setMapping(mappingAttrs);
+ tilingOptions.setMapping(mapping.value().getValue());
}
tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
>From 38398d04edbf733274641806cc4c4d7c9c1d1c11 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Mon, 20 May 2024 23:17:56 -0700
Subject: [PATCH 2/9] Allow specifying both numThreads and tileSizes to keep
the same existing semantics of distribution using number of threads.
---
.../SCF/Transforms/TileUsingInterface.h | 28 +-
.../TransformOps/LinalgTransformOps.cpp | 5 +-
.../SCF/Transforms/TileUsingInterface.cpp | 266 ++++++++++--------
mlir/test/Dialect/Linalg/tile-to-forall.mlir | 52 ++--
.../Dialect/Linalg/transform-op-tile.mlir | 29 +-
.../tile-pad-using-interface.mlir | 10 +-
.../TilingInterface/tile-using-interface.mlir | 50 ++--
7 files changed, 234 insertions(+), 206 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 451a21c766175..20081e853c882 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -36,9 +36,7 @@ struct SCFTilingOptions {
/// Returning a tile size of zero implies no tiling for that loop. If the
/// size of the returned vector is smaller than the number of loops, the inner
/// loops are not tiled. If the size of the returned vector is larger, then
- /// the vector is truncated to number of loops. Only one of
- /// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
- /// be used.
+ /// the vector is truncated to number of loops.
SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
SCFTilingOptions &
@@ -51,23 +49,25 @@ struct SCFTilingOptions {
/// proper interaction with folding.
SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
- /// Computation function that returns the maximum number of tile to use for
- /// each loop. Returning a tile size of zero implies no tiling for that loop.
- /// If the size of the returned vector is smaller than the number of loops,
- /// the inner loops are not tiled. If the size of the returned vector is
- /// larger, then the vector is truncated to number of loops. Only one of
- /// `tileSizeComputationFunction` or `maxNumTilesComputationFunction` should
- /// be used.
- SCFTileSizeComputationFunction maxNumTilesComputationFunction = nullptr;
+ /// Computation function that returns the number of threads to use for
+ /// each loop. Returning a num threads of zero implies no tiling for that
+ /// loop. If the size of the returned vector is smaller than the number of
+ /// loops, the inner loops are not tiled. If the size of the returned vector
+ /// is larger, then the vector is truncated to number of loops. Note: This
+ /// option is only supported with loopType set to `LoopType::ForallOp`. If the
+ /// tile size function is not specified while the num threads computation is,
+ /// then the tile size is determined automatically to map at most one tile per
+ /// thread.
+ SCFTileSizeComputationFunction numThreadsComputationFunction = nullptr;
SCFTilingOptions &
- setMaxNumTilesComputationFunction(SCFTileSizeComputationFunction fun) {
- maxNumTilesComputationFunction = std::move(fun);
+ setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) {
+ numThreadsComputationFunction = std::move(fun);
return *this;
}
/// Convenience function to set the `tileSizeComputationFunction` to a
/// function that computes tile sizes at the point they are needed.
- SCFTilingOptions &setMaxNumTiles(ArrayRef<OpFoldResult> numTiles);
+ SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);
/// The interchange vector to reorder the tiled loops.
SmallVector<int64_t> interchangeVector = {};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bce0d8a0f65db..8bf7db2e15061 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2933,7 +2933,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
scf::SCFTilingOptions options;
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
if (!mixedNumThreads.empty()) {
- options.setMaxNumTiles(mixedNumThreads);
+ options.setNumThreads(mixedNumThreads);
} else {
SmallVector<Range> loopRanges = tileableOp.getIterationDomain(rewriter);
unsigned nLoops = loopRanges.size();
@@ -2951,7 +2951,8 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
{loopRanges[i].size, numTiles});
numThreads.push_back(numTiles);
}
- options.setMaxNumTiles(numThreads);
+ options.setNumThreads(numThreads);
+ options.setTileSizes(mixedTileSizes);
}
if (mapping) {
options.setMapping(mapping.value().getValue());
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index c328e1068dccc..c9a56c66e9255 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -43,11 +43,11 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
}
scf::SCFTilingOptions &
-scf::SCFTilingOptions::setMaxNumTiles(ArrayRef<OpFoldResult> mnt) {
- assert(!maxNumTilesComputationFunction && "max num tiles already set");
- auto maxNumTiles = llvm::to_vector(mnt);
- maxNumTilesComputationFunction = [maxNumTiles](OpBuilder &b, Operation *op) {
- return maxNumTiles;
+scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) {
+ assert(!numThreadsComputationFunction && "num tiles already set");
+ auto numThreads = llvm::to_vector(nt);
+ numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
+ return numThreads;
};
return *this;
}
@@ -75,17 +75,12 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
static LogicalResult
verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
const scf::SCFTilingOptions &options) {
- if (!options.tileSizeComputationFunction &&
- !options.maxNumTilesComputationFunction) {
+ // Specifying the number of threads is only supported on `scf.forall` op.
+ if (options.numThreadsComputationFunction &&
+ options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
return rewriter.notifyMatchFailure(
- loc, "at least one of tile size computation function or max num tiles "
- "computation must be specified.");
- }
- if (options.tileSizeComputationFunction &&
- options.maxNumTilesComputationFunction) {
- return rewriter.notifyMatchFailure(
- loc, "only one of tile size computation function or max num tiles "
- "computation function can be specified");
+ loc, "number of tiles/threads can only by specified when loop type is "
+ "set to use `scf.forall`");
}
// If specified, check that the interchange vector is a permutation.
@@ -99,58 +94,94 @@ verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
return success();
}
-/// Compute the tile sizes and num tiles values. The `numTiles`
-/// is empty if the `maxNumTilesComputationFunction` is not specified.
+/// Compute the tile sizes and num threads values passed in.
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
-getTileSizesAndNumTiles(RewriterBase &rewriter, TilingInterface op,
- ArrayRef<Range> iterationDomain,
- const scf::SCFTilingOptions &options) {
- SmallVector<OpFoldResult> tileSizes, numTiles;
+getTileSizes(RewriterBase &rewriter, TilingInterface op,
+ ArrayRef<Range> iterationDomain,
+ const scf::SCFTilingOptions &options) {
+ OpFoldResult zero = rewriter.getIndexAttr(0);
+ SmallVector<OpFoldResult> tileSizes, numThreads;
+ size_t numLoops = iterationDomain.size();
+
+ // Check whether the number of tiles to use is specified.
+ if (options.numThreadsComputationFunction) {
+ numThreads = options.numThreadsComputationFunction(rewriter, op);
+ numThreads.resize(numLoops, zero);
+
+ // If the tile sizes are also specified, use them.
+ if (options.tileSizeComputationFunction) {
+ tileSizes = options.tileSizeComputationFunction(rewriter, op);
+ } else {
+ // Compute the tile sizes from the iteration domain and number
+ // of tiles as follows
+ // - niters = ceilDiv(ub - lb, step)
+ // - tileSize = ceilDiv(niters, numThreads)
+ AffineExpr s0, s1, s2, s3;
+ bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
+ AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2);
+ AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s3);
+ tileSizes.resize(numLoops, zero);
+ for (auto [index, range, nt] :
+ llvm::enumerate(iterationDomain, numThreads)) {
+ if (isConstantIntValue(nt, 0))
+ continue;
+
+ tileSizes[index] = affine::makeComposedFoldedAffineApply(
+ rewriter, op.getLoc(), tileSizeExpr,
+ {range.offset, range.size, range.stride, nt});
+ }
+ }
+ tileSizes.resize(numLoops, zero);
+ return {tileSizes, numThreads};
+ }
// Enforce the convention that "tiling by zero"
// skips tiling a particular dimension. This convention is significantly
// simpler to handle instead of adjusting affine maps to account for missing
// dimensions.
- auto numLoops = iterationDomain.size();
if (options.tileSizeComputationFunction) {
tileSizes = options.tileSizeComputationFunction(rewriter, op);
- tileSizes.resize(numLoops, rewriter.getIndexAttr(0));
- return {tileSizes, numTiles};
}
+ tileSizes.resize(numLoops, zero);
- assert(options.maxNumTilesComputationFunction &&
- "expected at least one of tile sizes cpomputation function or max num "
- "tiles computation function");
- // Enforce the convention that "maxNumTiles to zero"
- // skips tiling a particular dimension. This convention is significantly
- // simpler to handle instead of adjusting affine maps to account for missing
- // dimensions.
- SmallVector<OpFoldResult> maxNumTiles =
- options.maxNumTilesComputationFunction(rewriter, op);
- maxNumTiles.resize(numLoops, rewriter.getIndexAttr(0));
-
- // Use the maxNumTiles to compute the tile sizes as
- // - niters = ceilDiv(ub - lb, step)
- // - tileSize = ceilDiv(niters, maxNumTiles)
- AffineExpr s0, s1, s2, s3;
- bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
- AffineExpr numIters = (s1 - s0).ceilDiv(s2);
- AffineExpr tileSizeExpr = numIters.ceilDiv(s3);
- tileSizes.resize(numLoops, rewriter.getIndexAttr(0));
- for (auto [index, maxNumTile] : llvm::enumerate(maxNumTiles)) {
- if (isConstantIntValue(maxNumTile, 0))
+ return {tileSizes, numThreads};
+}
+
+/// Checks if any of the tiled loops are not parallel.
+static void checkSafeToTileToForall(TilingInterface op,
+ ArrayRef<OpFoldResult> tileSizes,
+ ArrayRef<OpFoldResult> numThreads) {
+ auto iterators = op.getLoopIteratorTypes();
+ assert(iterators.size() == tileSizes.size() &&
+ "expected as many tile size values as number of loops");
+ assert((numThreads.empty() || (numThreads.size() == iterators.size())) &&
+ "when specified, expected number of threads to use for each loop");
+
+ for (auto [index, iterator, tileSize] :
+ llvm::enumerate(iterators, tileSizes)) {
+ // If num threads is specified, check that it is greater than one only for
+ // parallel dimensions.
+ if (!numThreads.empty()) {
+ if (std::optional<int64_t> constNumThreads =
+ getConstantIntValue(numThreads[index])) {
+ if (constNumThreads.value() > 1 &&
+ iterator != utils::IteratorType::parallel) {
+ op.emitWarning() << "tiling is not thread safe at axis #" << index;
+ }
+ }
continue;
+ }
- tileSizes[index] = affine::makeComposedFoldedAffineApply(
- rewriter, op.getLoc(), tileSizeExpr,
- {iterationDomain[index].offset, iterationDomain[index].size,
- iterationDomain[index].stride, maxNumTile});
+ if (std::optional<int64_t> constTileSize = getConstantIntValue(tileSize)) {
+ if (constTileSize.value() > 0 &&
+ iterator != utils::IteratorType::parallel) {
+ op.emitWarning() << "tiling is not thread safe at axis #" << index;
+ }
+ }
}
-
- return {tileSizes, maxNumTiles};
}
-// Check if `stride` evenly divides the trip count `size - offset`.
+/// Check if `stride` evenly divides the trip count `size - offset`.
static bool tileDividesIterationDomain(Range loopRange) {
std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
if (!offsetAsInt)
@@ -164,10 +195,10 @@ static bool tileDividesIterationDomain(Range loopRange) {
return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
}
-/// Returns the bounded tile size given the current `iv`, `loopRange` and
-/// `tileSize`, i.e., `min(tileSize, range.end() - iv)`.
+/// Returns the bounded tile size given the current `offset`, `loopRange` and
+/// `tileSize`, i.e., `min(tileSize, range.end() - offset)`.
static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
- Range loopRange, Value iv,
+ Range loopRange, OpFoldResult offset,
OpFoldResult tileSize) {
std::optional<int64_t> ts = getConstantIntValue(tileSize);
if (ts && ts.value() == 1)
@@ -186,7 +217,7 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
return affine::makeComposedFoldedAffineMin(
- b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size});
+ b, loc, minMap, SmallVector<OpFoldResult>{offset, tileSize, size});
}
/// Compute the tile offsets and sizes.
@@ -224,11 +255,29 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
offset = getAsOpFoldResult(iv);
}
offsets.push_back(offset);
- sizes.push_back(getBoundedTileSize(rewriter, loc, loopRange, iv, tileSize));
+ sizes.push_back(
+ getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize));
}
return {offsets, sizes};
}
+/// Function to return the bounds of the loops to be generated.
+static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>,
+ SmallVector<OpFoldResult>>
+getLoopBounds(RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
+ ArrayRef<OpFoldResult> tileSizes) {
+ SmallVector<OpFoldResult> lbs, ubs, steps;
+ for (auto [loopRange, tileSize] : llvm::zip_equal(loopRanges, tileSizes)) {
+ // No loop if the tile size is 0.
+ if (isConstantIntValue(tileSize, 0))
+ continue;
+ lbs.push_back(loopRange.offset);
+ ubs.push_back(loopRange.size);
+ steps.push_back(tileSize);
+ }
+ return {lbs, ubs, steps};
+}
+
/// A function that allows returning additional yielded values during
/// `yieldTiledValuesAndReplace`.
/// - `ivs` induction variable for the loop.
@@ -274,39 +323,26 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
/// populated.
static LogicalResult generateLoopNestUsingForOp(
RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
- ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numTiles,
- ValueRange destinationTensors, YieldTiledValuesFn yieldTiledValuesFn,
+ ArrayRef<OpFoldResult> tileSizes, ValueRange destinationTensors,
+ YieldTiledValuesFn yieldTiledValuesFn,
SmallVector<LoopLikeOpInterface> &loops) {
assert(!loopRanges.empty() && "unexpected empty loop ranges");
assert(loopRanges.size() == tileSizes.size() &&
"expected as many tile sizes as loop ranges");
OpBuilder::InsertionGuard guard(rewriter);
- SmallVector<Value> ivs;
- Value zero, one;
- if (!numTiles.empty()) {
- zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
- ;
- one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
- }
-
- for (auto [index, loopRange, tileSize] :
- llvm::enumerate(loopRanges, tileSizes)) {
- // No loops if tile size is zero. Set offset and size to the loop
- // offset and size.
- if (isConstantIntValue(tileSize, 0))
- continue;
+ SmallVector<OpFoldResult> lbs, ubs, steps;
+ std::tie(lbs, ubs, steps) =
+ getLoopBounds(rewriter, loc, loopRanges, tileSizes);
+ SmallVector<Value> lbVals =
+ getValueOrCreateConstantIndexOp(rewriter, loc, lbs);
+ SmallVector<Value> ubVals =
+ getValueOrCreateConstantIndexOp(rewriter, loc, ubs);
+ SmallVector<Value> stepVals =
+ getValueOrCreateConstantIndexOp(rewriter, loc, steps);
- Value lb, ub, step;
- if (numTiles.empty()) {
- lb = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset);
- ub = getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size);
- step = getValueOrCreateConstantIndexOp(rewriter, loc, tileSize);
- } else {
- lb = zero;
- ub = getValueOrCreateConstantIndexOp(rewriter, loc, numTiles[index]);
- step = one;
- }
+ SmallVector<Value> ivs;
+ for (auto [lb, ub, step] : llvm::zip_equal(lbVals, ubVals, stepVals)) {
auto loop =
rewriter.create<scf::ForOp>(loc, lb, ub, step, destinationTensors,
[](OpBuilder &bodyBuilder, Location bodyLoc,
@@ -368,15 +404,12 @@ static LogicalResult generateLoopNestUsingForOp(
/// populated.
static LogicalResult generateLoopNestUsingForallOp(
RewriterBase &rewriter, Location loc, ArrayRef<Range> loopRanges,
- ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numTiles,
+ ArrayRef<OpFoldResult> tileSizes, ArrayRef<OpFoldResult> numThreads,
ArrayRef<Attribute> mappingVector, ValueRange destinationTensors,
YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
assert(!loopRanges.empty() && "unexpected empty loop ranges");
assert(loopRanges.size() == tileSizes.size() &&
"expected as many tile sizes as loop ranges");
- assert((numTiles.empty() || numTiles.size() == loopRanges.size()) &&
- "expected max number of tiles to be either empty or equal to number "
- "of loops");
OpBuilder::InsertionGuard guard(rewriter);
SmallVector<OpFoldResult> offsets(loopRanges.size()),
sizes(loopRanges.size());
@@ -386,25 +419,23 @@ static LogicalResult generateLoopNestUsingForallOp(
mappingAttr = rewriter.getArrayAttr(mappingVector);
scf::ForallOp forallOp;
- SmallVector<OpFoldResult> lbs, ubs, steps;
- if (numTiles.empty()) {
- for (auto [tileSize, loopRange] : llvm::zip_equal(tileSizes, loopRanges)) {
- if (isConstantIntValue(tileSize, 0))
+ bool useNumThreads = !numThreads.empty();
+
+ if (useNumThreads) {
+ // Prune the zero numthreads.
+ SmallVector<OpFoldResult> nonZeroNumThreads;
+ for (auto nt : numThreads) {
+ if (isConstantIntValue(nt, 0))
continue;
- lbs.push_back(loopRange.offset);
- ubs.push_back(loopRange.size);
- steps.push_back(tileSize);
+ nonZeroNumThreads.push_back(nt);
}
- assert(!lbs.empty() && "Expected at least one loop range");
- forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
+ forallOp = rewriter.create<scf::ForallOp>(loc, nonZeroNumThreads,
destinationTensors, mappingAttr);
} else {
- SmallVector<OpFoldResult> numThreads;
- for (auto maxNumTile : numTiles) {
- if (!isConstantIntValue(maxNumTile, 0))
- numThreads.push_back(maxNumTile);
- }
- forallOp = rewriter.create<scf::ForallOp>(loc, numThreads,
+ SmallVector<OpFoldResult> lbs, ubs, steps;
+ std::tie(lbs, ubs, steps) =
+ getLoopBounds(rewriter, loc, loopRanges, tileSizes);
+ forallOp = rewriter.create<scf::ForallOp>(loc, lbs, ubs, steps,
destinationTensors, mappingAttr);
}
loops.push_back(forallOp);
@@ -446,7 +477,7 @@ static LogicalResult generateLoopNestUsingForallOp(
static LogicalResult generateLoopNest(
RewriterBase &rewriter, Location loc, const scf::SCFTilingOptions &options,
ArrayRef<Range> loopRanges, ArrayRef<OpFoldResult> tileSizes,
- ArrayRef<OpFoldResult> numTiles, ValueRange destinationTensors,
+ ArrayRef<OpFoldResult> numThreads, ValueRange destinationTensors,
YieldTiledValuesFn tiledBodyFn, SmallVector<LoopLikeOpInterface> &loops) {
// If the tile sizes are all zero, no loops are generated. Just call the
// callback function to handle untiled case.
@@ -458,12 +489,11 @@ static LogicalResult generateLoopNest(
}
if (options.loopType == scf::SCFTilingOptions::LoopType::ForOp) {
return generateLoopNestUsingForOp(rewriter, loc, loopRanges, tileSizes,
- numTiles, destinationTensors, tiledBodyFn,
- loops);
+ destinationTensors, tiledBodyFn, loops);
}
if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
return generateLoopNestUsingForallOp(
- rewriter, loc, loopRanges, tileSizes, numTiles, options.mappingVector,
+ rewriter, loc, loopRanges, tileSizes, numThreads, options.mappingVector,
destinationTensors, tiledBodyFn, loops);
}
return rewriter.notifyMatchFailure(loc, "unhandled loop type");
@@ -697,10 +727,16 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
// 1. Get the range of the loops that are represented by the operation.
SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
- // 2. Materialize the tile sizes or max num tiles;
- SmallVector<OpFoldResult> tileSizes, numTiles;
- std::tie(tileSizes, numTiles) =
- getTileSizesAndNumTiles(rewriter, op, iterationDomain, options);
+ // 2. Materialize the tile sizes and/or number of threads;
+ SmallVector<OpFoldResult> tileSizes, numThreads;
+ std::tie(tileSizes, numThreads) =
+ getTileSizes(rewriter, op, iterationDomain, options);
+
+ // Check if it is safe to tile. This is a holdover from previous iterations
+ // of tile to for-all. Consider dropping it.
+ if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) {
+ checkSafeToTileToForall(op, tileSizes, numThreads);
+ }
// 3. If there is an interchange specified, permute the iteration domain and
// the tile sizes.
@@ -713,8 +749,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
applyPermutationToVector(iterationDomain, interchangeVector);
applyPermutationToVector(tileSizes, interchangeVector);
- if (!numTiles.empty())
- applyPermutationToVector(numTiles, interchangeVector);
+ if (!numThreads.empty())
+ applyPermutationToVector(numThreads, interchangeVector);
}
FailureOr<TilingResult> tilingResult;
@@ -729,7 +765,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
// 4a. Compute the `offsets` and `sizes` to use for tiling.
SmallVector<OpFoldResult> offsets, sizes;
std::tie(offsets, sizes) = getTileOffsetAndSizes(
- rewriter, loc, ivs, iterationDomain, tileSizes, !numTiles.empty());
+ rewriter, loc, ivs, iterationDomain, tileSizes, !numThreads.empty());
// 4b. If interchange was provided, apply inverse of the interchange
// to get back the offsets/sizes in the order to be specified.
@@ -797,7 +833,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
// 7. Generate the tiled loops nest using the callback defined above.
SmallVector<LoopLikeOpInterface> loops;
if (failed(generateLoopNest(rewriter, op.getLoc(), options, iterationDomain,
- tileSizes, numTiles, destinationTensors,
+ tileSizes, numThreads, destinationTensors,
innerYieldTiledValuesFn, loops)))
return op.emitOpError("failed to generate tiling loops");
assert(succeeded(tilingResult) &&
@@ -906,7 +942,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
scf::SCFTilingOptions options;
options.setLoopType(scf::SCFTilingOptions::LoopType::ForOp);
if (failed(generateLoopNest(b, loc, options, iterationDomain, tileSizesVector,
- /*numTiles=*/ArrayRef<OpFoldResult>{},
+ /*numThreads=*/ArrayRef<OpFoldResult>{},
initTensors, innerYieldTiledValuesFn, loops)))
return b.notifyMatchFailure(op, "failed to tile for parallel reduction");
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index f33739f119eaf..d1ed468fce323 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -3,9 +3,9 @@
// Offset per thread:
// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
// Per thread tile size.
-// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 10)) + s0, s0 ceildiv 10)>
+// CHECK-DAG: affine_map<(d0)[s0] -> (s0 ceildiv 10, -(d0 * (s0 ceildiv 10)) + s0)>
// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 20))>
-// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 20)) + s0, s0 ceildiv 20)>
+// CHECK-DAG: affine_map<(d0)[s0] -> (s0 ceildiv 20, -(d0 * (s0 ceildiv 20)) + s0)>
module {
// CHECK-LABEL: matmul(
@@ -96,7 +96,7 @@ module {
// In this test case, matmul dims and tile size are dynamic.
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
-// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (s0, -(d0 * s0) + s1)>
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
// CHECK-LABEL: matmul_tile_size_dynamic_dynamic(
@@ -140,7 +140,7 @@ module attributes {transform.with_named_sequence} {
// Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
-// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 300, 15)>
+// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (15, d0 * -15 + 300)>
// CHECK-DAG: #[[$map1:.+]] = affine_map<(d0) -> (0, d0)>
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 10)>
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0 * 15)>
@@ -176,30 +176,29 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-
// -----
-// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
-// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
-// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
-// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
-// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
-// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0) -> (d0 * 20)>
-// CHECK-LABEL: matmul_tile_size_dynamic(
+// CHECK: matmul_tile_size_dynamic(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor<?x?xf32>
func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
// CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
- // CHECK: %[[NT0:.+]] = affine.apply #map()[%[[M]]]
- // CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
+ // CHECK: %[[NT0:.+]] = affine.apply #[[MAP0]]()[%[[M]]]
+ // CHECK: %[[NT1:.+]] = affine.apply #[[MAP1]]()[%[[N]]]
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
- // CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
- // CHECK: %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
- // CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
- // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
+ // CHECK: %[[TS0:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[M]]]
+ // CHECK: %[[TS1:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N]]]
+ // CHECK: %[[LB0:.+]] = affine.apply #[[MAP4]](%[[IV0]])
+ // CHECK: %[[LB1:.+]] = affine.apply #[[MAP5]](%[[IV1]])
// CHECK: tensor.extract_slice %[[A]]
// CHECK: tensor.extract_slice %[[B]]
// CHECK: tensor.extract_slice %[[C_BLK]]
@@ -219,26 +218,25 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-
// -----
// Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
-// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -21 + 300, 21)>
-// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 10)>
-// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0 * 21)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * -21 + 300, 21)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 21)>
-// CHECK-LABEL: matmul_tile_size_static(
+// CHECK: matmul_tile_size_static(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor
// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor
// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor
func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (10, 15) shared_outs(%[[C_BLK:.*]] = %[[C]])
- // CHECK: %[[TS:.+]] = affine.min #[[$map0]](%[[IV1]])
+ // CHECK: %[[TS:.+]] = affine.min #[[MAP0]](%[[IV1]])
// CHECK-NOT: affine.max
// CHECK-NOT: affine.min
- // CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
- // CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
+ // CHECK: %[[LB0:.+]] = affine.apply #[[MAP1]](%[[IV0]])
+ // CHECK: %[[LB1:.+]] = affine.apply #[[MAP2]](%[[IV1]])
// CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
// CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
@@ -298,7 +296,7 @@ module {
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
-// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (s0, -(d0 * s0) + s1)>
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index d244670f73754..3467a539496b8 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt --transform-interpreter --mlir-print-local-scope --split-input-file --verify-diagnostics %s | FileCheck %s
+// RUN: mlir-opt --transform-interpreter --mlir-print-local-scope --split-input-file --verify-diagnostics --cse %s | FileCheck %s
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
@@ -178,12 +178,11 @@ module {
// CHECK-LABEL: func.func @scalable_tile(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<?xf32>, %[[ARG_1:.*]]: tensor<?xf32>, %[[ARG_2:.*]]: tensor<?xf32>,
-// CHECK: %[[C4:.*]] = arith.constant 0 : index
-// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG_0]], %[[C4]] : tensor<?xf32>
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = tensor.dim %[[ARG_0]], %[[C0]] : tensor<?xf32>
// CHECK: %[[VEC_SIZE:.*]] = arith.constant 4 : index
// CHECK: %[[VS:.*]] = vector.vscale
// CHECK: %[[STEP:.*]] = arith.muli %[[VEC_SIZE]], %[[VS]] : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[DIM]] step %[[STEP]] iter_args(%[[VAL:.*]] = %[[ARG_2]]) -> (tensor<?xf32>) {
// CHECK: %[[SIZE:.*]] = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%[[IV]])[%[[STEP]], %[[DIM]]]
// CHECK: %[[SLICE_ARG0:.*]] = tensor.extract_slice %[[ARG_0]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
@@ -202,20 +201,14 @@ module {
// -----
// CHECK-LABEL: func.func @scalable_and_fixed_length_tile
-// CHECK: %[[C4:.*]] = arith.constant 4 : index
-// CHECK: %[[VS:.*]] = vector.vscale
-// CHECK: %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[C128:.*]] = arith.constant 128 : index
-// CHECK: %[[STEP_0:.*]] = arith.constant 4 : index
-// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[STEP_0]]
-// CHECK: %[[C0_1:.*]] = arith.constant 0 : index
-// CHECK: %[[C128_1:.*]] = arith.constant 128 : index
-// CHECK: %[[STEP_1:.*]] = arith.constant 4 : index
-// CHECK: scf.for %[[VAL_16:.*]] = %[[C0_1]] to %[[C128_1]] step %[[STEP_1]]
-// CHECK: %[[C0_2:.*]] = arith.constant 0 : index
-// CHECK: %[[C128_2:.*]] = arith.constant 128 : index
-// CHECK: scf.for %{{.*}} = %[[C0_2]] to %[[C128_2]] step %[[STEP_2]]
+// CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
+// CHECK-DAG: %[[VS:.*]] = vector.vscale
+// CHECK-DAG: %[[STEP_2:.*]] = arith.muli %[[C4]], %[[VS]] : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C128:.*]] = arith.constant 128 : index
+// CHECK: scf.for %[[VAL_11:.*]] = %[[C0]] to %[[C128]] step %[[C4]]
+// CHECK: scf.for %[[VAL_16:.*]] = %[[C0]] to %[[C128]] step %[[C4]]
+// CHECK: scf.for %{{.*}} = %[[C0]] to %[[C128]] step %[[STEP_2]]
func.func @scalable_and_fixed_length_tile(
%arg0: tensor<128x128xf32>, %arg1: tensor<128x128xf32>, %arg2: tensor<128x128xf32>)
diff --git a/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
index 7d247aefcf6b1..ccf8e37c094f4 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir
@@ -31,8 +31,8 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[DIM_IN1:.+]] = tensor.dim %[[IN]], %[[C1]]
// CHECK-DAG: %[[DIM1:.+]] = affine.apply #[[MAP1]]()[%[[DIM_IN1]]]
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK: %[[RESULT:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[DIM0]] step %[[C2]]
-// CHECK: %[[C3:.+]] = arith.constant 3 : index
// CHECK: scf.for {{.*}} = %[[C0]] to %[[DIM1]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
// CHECK: %[[SWAP_RESULT:.*]] = scf.if
// CHECK: tensor.generate
@@ -62,8 +62,8 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 8)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 7)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 + 7)>
// CHECK: func @dynamic_2d_pad_tensor_inner_tiling(
// CHECK-SAME: %[[IN:.*]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
@@ -107,9 +107,9 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK: %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[C15]] step %[[C2]]
-// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
-// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
// CHECK: scf.for {{.*}} = %[[C0]] to %[[C16]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
// CHECK: %[[SWAP_RESULT:.*]] = scf.if
// CHECK: tensor.generate
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
index 488a52e8e3e91..08be9737f4302 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -24,13 +24,13 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
// CHECK: %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
// CHECK-SAME: iter_args(%[[INIT0:.+]] = %[[ARG2]])
-// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
// CHECK: %[[INNER:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[INIT0]])
// CHECK-DAG: %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]]
@@ -77,14 +77,14 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
// CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
-// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
-// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]]
// CHECK-DAG: %[[TS_M:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]]
// CHECK-DAG: %[[TS_N:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[N]]]
@@ -130,15 +130,15 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (10, -d0 + 128)>
// CHECK-LABEL: func.func @multi_result(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
-// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
// CHECK-DAG: %[[INIT0:.+]] = tensor.empty()
// CHECK-DAG: %[[INIT1:.+]] = tensor.empty()
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
+// CHECK-DAG: %[[C300:.+]] = arith.constant 300 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
// CHECK: %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]]
// CHECK-SAME: iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
-// CHECK-DAG: %[[C300:.+]] = arith.constant 300 : index
-// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
// CHECK: %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
// CHECK-SAME: iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
// CHECK-DAG: %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])
@@ -193,7 +193,6 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
-// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
// CHECK-DAG: %[[N:.+]] = tensor.dim %[[INPUT]], %[[C0]]
// CHECK-DAG: %[[C:.+]] = tensor.dim %[[INPUT]], %[[C3]]
// CHECK-DAG: %[[P:.+]] = tensor.dim %[[FILTER]], %[[C0]]
@@ -201,12 +200,13 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: %[[F:.+]] = tensor.dim %[[FILTER]], %[[C3]]
// CHECK-DAG: %[[R:.+]] = tensor.dim %[[INIT]], %[[C1]]
// CHECK-DAG: %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]]
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
// CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C10]]
// CHECK-SAME: iter_args(%[[INIT0:.+]] = %[[INIT]])
-// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
// CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C20]]
// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[INIT0]])
-// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
// CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C30]]
// CHECK-SAME: iter_args(%[[INIT2:.+]] = %[[INIT1]])
// CHECK-DAG: %[[TS_P:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[P]]]
@@ -259,15 +259,15 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-// CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
-// CHECK-LABEL: @indexed_semantics
-// CHECK: scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
-// CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
-// CHECK: %[[INDEX0:.+]] = linalg.index 0
-// CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]])
-// CHECK: %[[INDEX1:.+]] = linalg.index 1
-// CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
-// CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
+// CHECK: #[[MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK: @indexed_semantics
+// CHECK: scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
+// CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
+// CHECK: %[[INDEX0:.+]] = linalg.index 0
+// CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[MAP_ADD]](%[[INDEX0]], %[[I0]])
+// CHECK: %[[INDEX1:.+]] = linalg.index 1
+// CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[MAP_ADD]](%[[INDEX1]], %[[I1]])
+// CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
// -----
@@ -296,16 +296,16 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
// CHECK: %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]]
// CHECK-SAME: iter_args(%[[INIT0:.+]] = %[[ARG2]])
-// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index
// CHECK: %[[INNER1:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]]
// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[INIT0]])
-// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
// CHECK: %[[INNER2:[a-zA-Z0-9]+]] = scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]]
// CHECK-SAME: iter_args(%[[INIT2:.+]] = %[[INIT1]])
// CHECK-DAG: %[[TS_N:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[N]]]
>From c2c13ce5dd0ce2407767cdcca0f8aa22796f2693 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 23 May 2024 12:00:55 -0700
Subject: [PATCH 3/9] Add logic to account for negative tile sizes.
---
.../SCF/Transforms/TileUsingInterface.cpp | 102 +++++++++++++-----
mlir/test/Dialect/Linalg/tile-to-forall.mlir | 49 +++++----
2 files changed, 102 insertions(+), 49 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index c9a56c66e9255..b67764c18f23e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -220,45 +220,93 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
b, loc, minMap, SmallVector<OpFoldResult>{offset, tileSize, size});
}
+/// Returns true if the maximum tile offset `tileSize * (numThreads - 1)` is
+/// provably less than `iterationSize` (requires all three to be constants).
+static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
+ OpFoldResult numThreads,
+ OpFoldResult iterationSize) {
+ std::optional<int64_t> tileSizeConst = getConstantIntValue(tileSize);
+ std::optional<int64_t> numThreadsConst = getConstantIntValue(numThreads);
+ std::optional<int64_t> iterSizeConst = getConstantIntValue(iterationSize);
+ if (!tileSizeConst || !numThreadsConst || !iterSizeConst)
+ return false;
+ return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
+}
+
/// Compute the tile offsets and sizes.
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
ArrayRef<Range> iterationDomain,
- ArrayRef<OpFoldResult> tileSizes, bool isLoopNormalized) {
+ ArrayRef<OpFoldResult> tileSizes,
+ ArrayRef<OpFoldResult> numThreads) {
SmallVector<OpFoldResult> offsets, sizes;
int materializedLoopNum = 0;
- AffineExpr d0, s0, s1, s2;
- AffineExpr offsetExpr;
- if (isLoopNormalized) {
- bindDims(rewriter.getContext(), d0);
+ if (!numThreads.empty()) {
+ AffineExpr d0, d1, s0, s1, s2;
+ AffineExpr offsetExpr, residualTileSizeExpr;
+ bindDims(rewriter.getContext(), d0, d1);
bindSymbols(rewriter.getContext(), s0, s1, s2);
- offsetExpr = s0 + d0 * s1 * s2;
- }
+ offsetExpr = d0 + d1 * s0 * s1;
+ residualTileSizeExpr = s2 - (d0 + d1 * s0 * s1);
- for (auto [tileSize, loopRange] :
- llvm::zip_equal(tileSizes, iterationDomain)) {
- if (isConstantIntValue(tileSize, 0)) {
- offsets.push_back(loopRange.offset);
- sizes.push_back(loopRange.size);
- continue;
- }
- // If loop is normalized, the offset is (lb + iv * step * tileSize)
- Value iv = ivs[materializedLoopNum++];
- OpFoldResult offset;
- if (isLoopNormalized) {
- offset = affine::makeComposedFoldedAffineApply(
+ for (auto [nt, tileSize, loopRange] :
+ llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
+
+ if (isConstantIntValue(nt, 0) || isConstantIntValue(nt, 1)) {
+ offsets.push_back(loopRange.offset);
+ sizes.push_back(loopRange.size);
+ continue;
+ }
+
+ Value iv = ivs[materializedLoopNum++];
+ OpFoldResult offset = affine::makeComposedFoldedAffineApply(
rewriter, loc, offsetExpr,
- ArrayRef<OpFoldResult>{iv, loopRange.offset, loopRange.stride,
+ ArrayRef<OpFoldResult>{loopRange.offset, iv, loopRange.stride,
tileSize});
- } else {
- offset = getAsOpFoldResult(iv);
+ OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, residualTileSizeExpr,
+ {loopRange.offset, nt, loopRange.stride, tileSize, loopRange.size});
+ OpFoldResult size = tileSize;
+ if (!isConstantIntValue(residualTileSize, 0)) {
+ OpFoldResult sizeMinusOffsetPerThread =
+ affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0,
+ {offset, loopRange.size});
+ size = affine::makeComposedFoldedAffineMin(
+ rewriter, loc,
+ AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
+ {sizeMinusOffsetPerThread, tileSize});
+ }
+ if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
+ AffineMap maxMap =
+ AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
+ size = affine::makeComposedFoldedAffineMax(
+ rewriter, loc, maxMap, {rewriter.getIndexAttr(0), size});
+ }
+
+ offsets.push_back(offset);
+ sizes.push_back(size);
+ }
+ return {offsets, sizes};
+ } else {
+ for (auto [tileSize, loopRange] :
+ llvm::zip_equal(tileSizes, iterationDomain)) {
+
+ if (isConstantIntValue(tileSize, 0)) {
+ offsets.push_back(loopRange.offset);
+ sizes.push_back(loopRange.size);
+ continue;
+ }
+
+ Value iv = ivs[materializedLoopNum++];
+ OpFoldResult offset = getAsOpFoldResult(iv);
+ offsets.push_back(offset);
+ OpFoldResult size =
+ getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize);
+ sizes.push_back(size);
}
- offsets.push_back(offset);
- sizes.push_back(
- getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize));
+ return {offsets, sizes};
}
- return {offsets, sizes};
}
/// Function to return the bounds of the loops to be generated.
@@ -765,7 +813,7 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
// 4a. Compute the `offsets` and `sizes` to use for tiling.
SmallVector<OpFoldResult> offsets, sizes;
std::tie(offsets, sizes) = getTileOffsetAndSizes(
- rewriter, loc, ivs, iterationDomain, tileSizes, !numThreads.empty());
+ rewriter, loc, ivs, iterationDomain, tileSizes, numThreads);
// 4b. If interchange was provided, apply inverse of the interchange
// to get back the offsets/sizes in the order to be specified.
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index d1ed468fce323..c0ba5a8402d5f 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -3,9 +3,9 @@
// Offset per thread:
// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
// Per thread tile size.
-// CHECK-DAG: affine_map<(d0)[s0] -> (s0 ceildiv 10, -(d0 * (s0 ceildiv 10)) + s0)>
+// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 10)) + s0, s0 ceildiv 10)>
// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 20))>
-// CHECK-DAG: affine_map<(d0)[s0] -> (s0 ceildiv 20, -(d0 * (s0 ceildiv 20)) + s0)>
+// CHECK-DAG: affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 20)) + s0, s0 ceildiv 20)>
module {
// CHECK-LABEL: matmul(
@@ -96,7 +96,7 @@ module {
// In this test case, matmul dims and tile size are dynamic.
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
-// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (s0, -(d0 * s0) + s1)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
// CHECK-LABEL: matmul_tile_size_dynamic_dynamic(
@@ -140,7 +140,7 @@ module attributes {transform.with_named_sequence} {
// Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
-// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (15, d0 * -15 + 300)>
+// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 300, 15)>
// CHECK-DAG: #[[$map1:.+]] = affine_map<(d0) -> (0, d0)>
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 10)>
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0 * 15)>
@@ -176,6 +176,7 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
// -----
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
@@ -296,7 +297,7 @@ module {
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
-// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (s0, -(d0 * s0) + s1)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * s0)>
// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
@@ -339,7 +340,6 @@ module attributes {transform.with_named_sequence} {
// -----
// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -15 + 100, 15)>
-// CHECK-DAG: #[[$map1:.+]] = affine_map<(d0) -> (0, d0)>
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 15)>
// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0)>
@@ -352,8 +352,7 @@ module attributes {transform.with_named_sequence} {
%OUT1: tensor<100xf32>, %OUT2: tensor<100xf32>)
-> (tensor<100xf32>, tensor<100xf32>) {
// CHECK: scf.forall (%[[IV0:.+]]) in (7) shared_outs(%[[OUT1:[0-9a-z]+]] = %[[ORGOUT1]], %[[OUT2:[0-9a-z]+]] = %[[ORGOUT2]])
-// CHECK: %[[TSMIN:.+]] = affine.min #[[$map0]](%[[IV0]])
-// CHECK: %[[TS:.+]] = affine.max #[[$map1]](%[[TSMIN]])
+// CHECK: %[[TS:.+]] = affine.min #[[$map0]](%[[IV0]])
// CHECK-NOT: affine.min
// CHECK-NOT: affine.max
// CHECK: %[[LB:.+]] = affine.apply #[[$map2]](%[[IV0]])
@@ -453,9 +452,10 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
-// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
-// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
-// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
// CHECK-LABEL: matmul_tile_size_dynamic(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
@@ -470,10 +470,12 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
// CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
// CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
- // CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
- // CHECK: %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
- // CHECK: %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
- // CHECK: %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
+ // CHECK: %[[TSMIN0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+ // CHECK: %[[TS0:.+]] = affine.max #[[$map3]](%[[TSMIN0]])
+ // CHECK: %[[TSMIN1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
+ // CHECK: %[[TS1:.+]] = affine.max #[[$map3]](%[[TSMIN1]])
+ // CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+ // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
// CHECK: tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
// CHECK: tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
// CHECK: tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :
@@ -521,9 +523,10 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
-// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
-// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
-// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (0, d0)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
// CHECK-LABEL: matmul_tile_size_dynamic(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
@@ -538,10 +541,12 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
// CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
// CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
- // CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
- // CHECK: %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
- // CHECK: %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
- // CHECK: %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
+ // CHECK: %[[TSMIN0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+ // CHECK: %[[TS0:.+]] = affine.max #[[$map3]](%[[TSMIN0]])
+ // CHECK: %[[TSMIN1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
+ // CHECK: %[[TS1:.+]] = affine.max #[[$map3]](%[[TSMIN1]])
+ // CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+ // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
// CHECK: tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
// CHECK: tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
// CHECK: tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :
>From 8f500bd5509f5a9068bcef3b47352950339ccd1e Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 23 May 2024 21:51:44 -0700
Subject: [PATCH 4/9] Put back CHECK-LABELs
---
mlir/test/Dialect/Linalg/tile-to-forall.mlir | 42 +++++++++----------
.../TilingInterface/tile-using-interface.mlir | 18 ++++----
2 files changed, 30 insertions(+), 30 deletions(-)
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index c0ba5a8402d5f..6e92deaf4cf0d 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -179,27 +179,27 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
-// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0) -> (d0 * 10)>
-// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0) -> (d0 * 20)>
-
-// CHECK: matmul_tile_size_dynamic(
+// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
+// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
+
+// CHECK-LABEL: matmul_tile_size_dynamic(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor<?x?xf32>
func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
// CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
- // CHECK: %[[NT0:.+]] = affine.apply #[[MAP0]]()[%[[M]]]
- // CHECK: %[[NT1:.+]] = affine.apply #[[MAP1]]()[%[[N]]]
+ // CHECK: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]]]
+ // CHECK: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]]
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
- // CHECK: %[[TS0:.+]] = affine.min #[[MAP2]](%[[IV0]])[%[[M]]]
- // CHECK: %[[TS1:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[N]]]
- // CHECK: %[[LB0:.+]] = affine.apply #[[MAP4]](%[[IV0]])
- // CHECK: %[[LB1:.+]] = affine.apply #[[MAP5]](%[[IV1]])
+ // CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+ // CHECK: %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
+ // CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+ // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
// CHECK: tensor.extract_slice %[[A]]
// CHECK: tensor.extract_slice %[[B]]
// CHECK: tensor.extract_slice %[[C_BLK]]
@@ -223,21 +223,21 @@ module attributes {transform.with_named_sequence} {
// Tests that dimension 0 can eliminate affine.min/max, dimension 1 cannot.
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> (d0 * -21 + 300, 21)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0 * 10)>
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0 * 21)>
+// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * -21 + 300, 21)>
+// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (d0 * 21)>
-// CHECK: matmul_tile_size_static(
+// CHECK-LABEL: matmul_tile_size_static(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor
// CHECK-SAME: %[[B:[0-9a-z]+]]: tensor
// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor
func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (10, 15) shared_outs(%[[C_BLK:.*]] = %[[C]])
- // CHECK: %[[TS:.+]] = affine.min #[[MAP0]](%[[IV1]])
+ // CHECK: %[[TS:.+]] = affine.min #[[$map0]](%[[IV1]])
// CHECK-NOT: affine.max
// CHECK-NOT: affine.min
- // CHECK: %[[LB0:.+]] = affine.apply #[[MAP1]](%[[IV0]])
- // CHECK: %[[LB1:.+]] = affine.apply #[[MAP2]](%[[IV1]])
+ // CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
+ // CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
// CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
// CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
index 08be9737f4302..0a4d4c45f10be 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -259,15 +259,15 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-// CHECK: #[[MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
-// CHECK: @indexed_semantics
-// CHECK: scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
-// CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
-// CHECK: %[[INDEX0:.+]] = linalg.index 0
-// CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[MAP_ADD]](%[[INDEX0]], %[[I0]])
-// CHECK: %[[INDEX1:.+]] = linalg.index 1
-// CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[MAP_ADD]](%[[INDEX1]], %[[I1]])
-// CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
+// CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: @indexed_semantics
+// CHECK: scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
+// CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
+// CHECK: %[[INDEX0:.+]] = linalg.index 0
+// CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]])
+// CHECK: %[[INDEX1:.+]] = linalg.index 1
+// CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
+// CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
// -----
>From 1857e5c0035416711411aa084fab2ea141366645 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 23 May 2024 21:59:11 -0700
Subject: [PATCH 5/9] Address comments
---
mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h | 4 ++--
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 2 +-
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 20081e853c882..0f30b149e9a08 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -65,8 +65,8 @@ struct SCFTilingOptions {
numThreadsComputationFunction = std::move(fun);
return *this;
}
- /// Convenience function to set the `tileSizeComputationFunction` to a
- /// function that computes tile sizes at the point they are needed.
+ /// Convenience function to set the `numThreadsComputationFunction` to a
+ /// function that computes num threads at the point they are needed.
SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);
/// The interchange vector to reorder the tiled loops.
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index b67764c18f23e..9121114166d0b 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -87,7 +87,7 @@ verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
if (!options.interchangeVector.empty()) {
if (!isPermutationVector(options.interchangeVector)) {
return rewriter.notifyMatchFailure(
- loc, "invalid intechange vector, not a permutation of the entire "
+ loc, "invalid interchange vector, not a permutation of the entire "
"iteration space");
}
}
>From c82857630424f83415f642a8586e4864c6ad5d0d Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Sun, 26 May 2024 17:38:05 -0700
Subject: [PATCH 6/9] Next round of comments.
---
.../SCF/Transforms/TileUsingInterface.cpp | 46 ++++++++++---------
1 file changed, 24 insertions(+), 22 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 9121114166d0b..2c06f91f4daf0 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -75,11 +75,11 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
static LogicalResult
verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
const scf::SCFTilingOptions &options) {
- // Specifying number of tile is only supported on `scf.forall` op.
+ // Specifying number of threads is only supported on `scf.forall` op.
if (options.numThreadsComputationFunction &&
options.loopType != scf::SCFTilingOptions::LoopType::ForallOp) {
return rewriter.notifyMatchFailure(
- loc, "number of tiles/threads can only by specified when loop type is "
+ loc, "number of threads can only by specified when loop type is "
"set to use `scf.forall`");
}
@@ -111,25 +111,27 @@ getTileSizes(RewriterBase &rewriter, TilingInterface op,
// If the number of tiles is also specified, use that.
if (options.tileSizeComputationFunction) {
tileSizes = options.tileSizeComputationFunction(rewriter, op);
- } else {
- // Compute the tile sizes from the iteration domain and number
- // of tiles as follows
- // - niters = ceilDiv(ub - lb, step)
- // - tileSize = ceilDiv(niters, numThreads)
- AffineExpr s0, s1, s2, s3;
- bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
- AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2);
- AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s3);
tileSizes.resize(numLoops, zero);
- for (auto [index, range, nt] :
- llvm::enumerate(iterationDomain, numThreads)) {
- if (isConstantIntValue(nt, 0))
- continue;
+ return {tileSizes, numThreads};
+ }
- tileSizes[index] = affine::makeComposedFoldedAffineApply(
- rewriter, op.getLoc(), tileSizeExpr,
- {range.offset, range.size, range.stride, nt});
- }
+ // Compute the tile sizes from the iteration domain and number
+ // of tiles as follows
+ // - niters = ceilDiv(ub - lb, step)
+ // - tileSize = ceilDiv(niters, numThreads)
+ AffineExpr s0, s1, s2, s3;
+ bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
+ AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2);
+ AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s3);
+ tileSizes.resize(numLoops, zero);
+ for (auto [index, range, nt] :
+ llvm::enumerate(iterationDomain, numThreads)) {
+ if (isConstantIntValue(nt, 0))
+ continue;
+
+ tileSizes[index] = affine::makeComposedFoldedAffineApply(
+ rewriter, op.getLoc(), tileSizeExpr,
+ {range.offset, range.size, range.stride, nt});
}
tileSizes.resize(numLoops, zero);
return {tileSizes, numThreads};
@@ -139,9 +141,9 @@ getTileSizes(RewriterBase &rewriter, TilingInterface op,
// skips tiling a particular dimension. This convention is significantly
// simpler to handle instead of adjusting affine maps to account for missing
// dimensions.
- if (options.tileSizeComputationFunction) {
- tileSizes = options.tileSizeComputationFunction(rewriter, op);
- }
+ assert(options.tileSizeComputationFunction &&
+ "expected tile sizes to be specified");
+ tileSizes = options.tileSizeComputationFunction(rewriter, op);
tileSizes.resize(numLoops, zero);
return {tileSizes, numThreads};
>From 6dea0cf080f314c05fe0bd838e08af29579a2337 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 30 May 2024 16:54:38 -0700
Subject: [PATCH 7/9] Drop support for non-unit strides, and assert that
strides of iteration domain are 1.
---
.../SCF/Transforms/TileUsingInterface.cpp | 42 +++++++++++++------
1 file changed, 29 insertions(+), 13 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 2c06f91f4daf0..510454dd96717 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -99,6 +99,11 @@ static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
getTileSizes(RewriterBase &rewriter, TilingInterface op,
ArrayRef<Range> iterationDomain,
const scf::SCFTilingOptions &options) {
+ assert(
+ llvm::all_of(iterationDomain,
+ [](Range r) { return isConstantIntValue(r.stride, 1); }) &&
+ "tile size computation assumes that all dimensions of the iteration "
+ "domain have stride 1");
OpFoldResult zero = rewriter.getIndexAttr(0);
SmallVector<OpFoldResult> tileSizes, numThreads;
size_t numLoops = iterationDomain.size();
@@ -119,10 +124,11 @@ getTileSizes(RewriterBase &rewriter, TilingInterface op,
// of tiles as follows
// - niters = ceilDiv(ub - lb, step)
// - tileSize = ceilDiv(niters, numThreads)
- AffineExpr s0, s1, s2, s3;
- bindSymbols(rewriter.getContext(), s0, s1, s2, s3);
- AffineExpr numItersExpr = (s1 - s0).ceilDiv(s2);
- AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s3);
+ AffineExpr s0, s1, s2;
+ bindSymbols(rewriter.getContext(), s0, s1, s2);
+ // TODO: The step here is assumed to be 1.
+ AffineExpr numItersExpr = (s1 - s0);
+ AffineExpr tileSizeExpr = numItersExpr.ceilDiv(s2);
tileSizes.resize(numLoops, zero);
for (auto [index, range, nt] :
llvm::enumerate(iterationDomain, numThreads)) {
@@ -130,8 +136,7 @@ getTileSizes(RewriterBase &rewriter, TilingInterface op,
continue;
tileSizes[index] = affine::makeComposedFoldedAffineApply(
- rewriter, op.getLoc(), tileSizeExpr,
- {range.offset, range.size, range.stride, nt});
+ rewriter, op.getLoc(), tileSizeExpr, {range.offset, range.size, nt});
}
tileSizes.resize(numLoops, zero);
return {tileSizes, numThreads};
@@ -244,13 +249,19 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
SmallVector<OpFoldResult> offsets, sizes;
int materializedLoopNum = 0;
+ assert(
+ llvm::all_of(iterationDomain,
+ [](Range r) { return isConstantIntValue(r.stride, 1); }) &&
+ "the offset and tile size computation assumes stride 1 for all "
+ "dimensions of the iteration domain");
+
if (!numThreads.empty()) {
- AffineExpr d0, d1, s0, s1, s2;
+ AffineExpr d0, d1, s0, s1;
AffineExpr offsetExpr, residualTileSizeExpr;
bindDims(rewriter.getContext(), d0, d1);
- bindSymbols(rewriter.getContext(), s0, s1, s2);
- offsetExpr = d0 + d1 * s0 * s1;
- residualTileSizeExpr = s2 - (d0 + d1 * s0 * s1);
+ bindSymbols(rewriter.getContext(), s0, s1);
+ offsetExpr = d0 + d1 * s0;
+ residualTileSizeExpr = s1 - (d0 + d1 * s0);
for (auto [nt, tileSize, loopRange] :
llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
@@ -264,11 +275,11 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
Value iv = ivs[materializedLoopNum++];
OpFoldResult offset = affine::makeComposedFoldedAffineApply(
rewriter, loc, offsetExpr,
- ArrayRef<OpFoldResult>{loopRange.offset, iv, loopRange.stride,
- tileSize});
+ ArrayRef<OpFoldResult>{loopRange.offset, iv, tileSize});
OpFoldResult residualTileSize = affine::makeComposedFoldedAffineApply(
rewriter, loc, residualTileSizeExpr,
- {loopRange.offset, nt, loopRange.stride, tileSize, loopRange.size});
+ {loopRange.offset, nt, tileSize, loopRange.size});
+
OpFoldResult size = tileSize;
if (!isConstantIntValue(residualTileSize, 0)) {
OpFoldResult sizeMinusOffsetPerThread =
@@ -776,6 +787,11 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
// 1. Get the range of the loops that are represented by the operation.
SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
+ if (llvm::any_of(iterationDomain,
+ [](Range r) { return !isConstantIntValue(r.stride, 1); })) {
+ return rewriter.notifyMatchFailure(
+ op, "unhandled tiling of iteration domain with non-unit stride");
+ }
// 2. Materialize the tile sizes and/or number of threads;
SmallVector<OpFoldResult> tileSizes, numThreads;
>From e4ecd3c9cdd54a99656fa76e12a511e81d700cda Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 30 May 2024 18:32:40 -0700
Subject: [PATCH 8/9] Remove use of `getLoopBounds` to avoid unnecessary lit
test churn.
---
.../SCF/Transforms/TileUsingInterface.cpp | 47 ++++++++++---------
1 file changed, 24 insertions(+), 23 deletions(-)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 510454dd96717..8eb2ab59ac81d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -94,16 +94,12 @@ verifyTileSizeOptions(RewriterBase &rewriter, Location loc,
return success();
}
-/// Compute the tile sizes and num threads values passed in.
+/// Method to instantiate the tile sizes and/or number of threads specified
+/// by the user.
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
-getTileSizes(RewriterBase &rewriter, TilingInterface op,
- ArrayRef<Range> iterationDomain,
- const scf::SCFTilingOptions &options) {
- assert(
- llvm::all_of(iterationDomain,
- [](Range r) { return isConstantIntValue(r.stride, 1); }) &&
- "tile size computation assumes that all dimensions of the iteration "
- "domain have stride 1");
+getUserTileSizesAndNumThreads(RewriterBase &rewriter, TilingInterface op,
+ ArrayRef<Range> iterationDomain,
+ const scf::SCFTilingOptions &options) {
OpFoldResult zero = rewriter.getIndexAttr(0);
SmallVector<OpFoldResult> tileSizes, numThreads;
size_t numLoops = iterationDomain.size();
@@ -240,7 +236,9 @@ static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
return *tileSizeConst * (*numThreadsConst - 1) < *iterSizeConst;
}
-/// Compute the tile offsets and sizes.
+/// Compute the `OpFoldResult`s that represent the multi-dimensional
+/// `offset`s and `size`s of the tile of the iteration space that the
+/// innermost loop body of the generated tiled loops corresponds to.
static std::tuple<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>
getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
ArrayRef<Range> iterationDomain,
@@ -249,12 +247,6 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
SmallVector<OpFoldResult> offsets, sizes;
int materializedLoopNum = 0;
- assert(
- llvm::all_of(iterationDomain,
- [](Range r) { return isConstantIntValue(r.stride, 1); }) &&
- "the offset and tile size computation assumes stride 1 for all "
- "dimensions of the iteration domain");
-
if (!numThreads.empty()) {
AffineExpr d0, d1, s0, s1;
AffineExpr offsetExpr, residualTileSizeExpr;
@@ -266,7 +258,9 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
for (auto [nt, tileSize, loopRange] :
llvm::zip_equal(numThreads, tileSizes, iterationDomain)) {
- if (isConstantIntValue(nt, 0) || isConstantIntValue(nt, 1)) {
+ // Non-tiled cases, set the offset and size to the
+ // `loopRange.offset/size`.
+ if (isConstantIntValue(nt, 0)) {
offsets.push_back(loopRange.offset);
sizes.push_back(loopRange.size);
continue;
@@ -290,6 +284,16 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()),
{sizeMinusOffsetPerThread, tileSize});
}
+
+ // Consider the case where the original loop was `[0, 10)`.
+ // If the number of threads is `7`, the tile size would be computed as
+ // `ceilDiv(10, 7) = 2`. For the last thread (thread_id = 6)
+ // - `offset = 0 + 6 * 2 = 12`
+ // - `tileSize = min(2, 10 - 12) = -2`
+ // To avoid negative tile sizes, we need to do a further
+ // `nonNegativeTileSize = affine.max(0, tileSize)`.
+ // This `max` can be avoided if
+ // `tileSize * (numThreads - 1) < (ub - lb)`.
if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) {
AffineMap maxMap =
AffineMap::getMultiDimIdentityMap(2, rewriter.getContext());
@@ -305,6 +309,8 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs,
for (auto [tileSize, loopRange] :
llvm::zip_equal(tileSizes, iterationDomain)) {
+ // Non-tiled cases, set the offset and size to the
+ // `loopRange.offset/size`.
if (isConstantIntValue(tileSize, 0)) {
offsets.push_back(loopRange.offset);
sizes.push_back(loopRange.size);
@@ -787,16 +793,11 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
// 1. Get the range of the loops that are represented by the operation.
SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter);
- if (llvm::any_of(iterationDomain,
- [](Range r) { return !isConstantIntValue(r.stride, 1); })) {
- return rewriter.notifyMatchFailure(
- op, "unhandled tiling of iteration domain with non-unit stride");
- }
// 2. Materialize the tile sizes and/or number of threads;
SmallVector<OpFoldResult> tileSizes, numThreads;
std::tie(tileSizes, numThreads) =
- getTileSizes(rewriter, op, iterationDomain, options);
+ getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options);
// Check if it is safe to tile. This is hold over from previous iterations
// of tile to for-all. Consider dropping it.
>From 075633609d9f430de8966a4f128d5aad11365081 Mon Sep 17 00:00:00 2001
From: MaheshRavishankar <mahesh.ravishankar at gmail.com>
Date: Thu, 13 Jun 2024 20:19:18 -0700
Subject: [PATCH 9/9] Add method to normalize `scf.forall` op.
---
mlir/include/mlir/Dialect/SCF/Utils/Utils.h | 7 ++
.../TransformOps/LinalgTransformOps.cpp | 117 +++++++++++++++---
.../SCF/Transforms/TileUsingInterface.cpp | 4 +-
mlir/lib/Dialect/SCF/Utils/Utils.cpp | 34 +++++
mlir/test/Dialect/Linalg/tile-tensors.mlir | 2 +-
mlir/test/Dialect/Linalg/tile-to-forall.mlir | 68 +++++-----
.../Dialect/Linalg/transform-op-tile.mlir | 2 +-
.../tile-and-fuse-using-interface.mlir | 2 +-
.../TilingInterface/tile-using-interface.mlir | 24 ++--
.../TilingInterface/tile-using-scfforall.mlir | 20 +--
10 files changed, 199 insertions(+), 81 deletions(-)
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index f719c00213987..727615fc44aa4 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -203,6 +203,13 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
RewriterBase &rewriter);
+/// Normalize an `scf.forall` operation so that all lower bounds are 0 and all
+/// steps are 1. Returns `failure()` if normalization fails. On `success()`,
+/// returns the normalized operation (the input itself if already normalized),
+/// with all uses of the original operation replaced by its results.
+FailureOr<scf::ForallOp> normalizeForallOp(RewriterBase &rewriter,
+ scf::ForallOp forallOp);
+
} // namespace mlir
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 8bf7db2e15061..bb062f849984d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -2914,6 +2915,94 @@ void transform::TileUsingForallOp::build(OpBuilder &builder,
/*mapping=*/mapping);
}
+/// Given `lbs`, `ubs` and `steps` of loops, return for each loop the
+/// normalized upper bound `ceilDiv(ub - lb, step)`, i.e. its trip count.
+static SmallVector<OpFoldResult>
+normalizeUpperBounds(RewriterBase &rewriter, Location loc,
+ ArrayRef<OpFoldResult> lbs, ArrayRef<OpFoldResult> ubs,
+ ArrayRef<OpFoldResult> steps) {
+ AffineExpr s0, s1, s2;
+ bindSymbols(rewriter.getContext(), s0, s1, s2);
+ AffineExpr normalizedUbExpr = (s1 - s0).ceilDiv(s2);
+ SmallVector<OpFoldResult> normalizedUbs;
+ for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
+ OpFoldResult normalizedUb = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, normalizedUbExpr, {lb, ub, step});
+ normalizedUbs.push_back(normalizedUb);
+ }
+ return normalizedUbs;
+}
+
+/// Returns, for each normalized induction variable in `ivs`, the value of the
+/// original induction variable, `lb + normalized_iv * step`, one per loop.
+static SmallVector<Value> denormalizeIndVar(RewriterBase &rewriter,
+ Location loc, ValueRange ivs,
+ ArrayRef<OpFoldResult> lbs,
+ ArrayRef<OpFoldResult> steps) {
+ AffineExpr s0, s1;
+ AffineExpr d0;
+ bindSymbols(rewriter.getContext(), s0, s1);
+ bindDims(rewriter.getContext(), d0);
+ AffineExpr denormExpr = s0 + d0 * s1;
+ SmallVector<Value> denormalizedIvs;
+ // Operands are {iv, lb, step}; dims come first, so d0 = iv, s0 = lb, s1 = step.
+ for (auto [iv, lb, step] : llvm::zip_equal(ivs, lbs, steps)) {
+ OpFoldResult denormValue = affine::makeComposedFoldedAffineApply(
+ rewriter, loc, denormExpr, ArrayRef<OpFoldResult>{iv, lb, step});
+ denormalizedIvs.push_back(
+ getValueOrCreateConstantIndexOp(rewriter, loc, denormValue));
+ }
+ return denormalizedIvs;
+}
+
+/// Given an `scf.forall` loop, return an equivalent loop whose lower bounds
+/// are 0 and steps are 1, replacing all uses of the original loop. Returns the
+/// original loop unchanged if it is already normalized.
+/// TODO: Replace this with a general utility to normalize `scf.forall`.
+/// At the time of writing, this wasn't done since adding it to the `scf`
+/// dialect would disallow the use of `affine.apply` operations due to cyclic
+/// dependencies. To avoid lit test churn in this change, defer to a follow-up.
+static scf::ForallOp normalizeForallLoopOp(RewriterBase &rewriter,
+ scf::ForallOp loop) {
+ SmallVector<OpFoldResult> lbs = loop.getMixedLowerBound();
+ SmallVector<OpFoldResult> ubs = loop.getMixedUpperBound();
+ SmallVector<OpFoldResult> steps = loop.getMixedStep();
+
+ if (llvm::all_of(
+ lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
+ llvm::all_of(
+ steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
+ return loop;
+ }
+
+ Location loc = loop.getLoc();
+ SmallVector<OpFoldResult> normalizedUbs =
+ normalizeUpperBounds(rewriter, loc, lbs, ubs, steps);
+ SmallVector<OpFoldResult> normalizedLbs(normalizedUbs.size(),
+ rewriter.getIndexAttr(0));
+ SmallVector<OpFoldResult> normalizedSteps(normalizedUbs.size(),
+ rewriter.getIndexAttr(1));
+
+ auto normalizedForallOp = rewriter.create<scf::ForallOp>(
+ loc, normalizedLbs, normalizedUbs, normalizedSteps, loop.getOutputs(),
+ loop.getMapping(), [](OpBuilder &, Location, ValueRange) {});
+ // The body builder is a no-op; the original body is merged in below.
+ auto normalizedLoopIvs = normalizedForallOp.getInductionVars();
+ OpBuilder::InsertionGuard g(rewriter);
+ Block *normalizedLoopBlock = normalizedForallOp.getBody();
+ rewriter.setInsertionPointToStart(normalizedLoopBlock);
+ // Original IV uses get `lb + iv * step`; region iter args are forwarded 1:1.
+ SmallVector<Value> argValues =
+ denormalizeIndVar(rewriter, loc, normalizedLoopIvs, lbs, steps);
+ argValues.append(normalizedForallOp.getRegionIterArgs().begin(),
+ normalizedForallOp.getRegionIterArgs().end());
+ Block *origLoopBlock = loop.getBody();
+ rewriter.mergeBlocks(origLoopBlock, normalizedLoopBlock, argValues);
+
+ rewriter.replaceOp(loop, normalizedForallOp);
+ return normalizedForallOp;
+}
+
DiagnosedSilenceableFailure transform::tileToForallOpImpl(
RewriterBase &rewriter, transform::TransformState &state,
TransformOpInterface transformOp, Operation *target,
@@ -2935,23 +3024,6 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
if (!mixedNumThreads.empty()) {
options.setNumThreads(mixedNumThreads);
} else {
- SmallVector<Range> loopRanges = tileableOp.getIterationDomain(rewriter);
- unsigned nLoops = loopRanges.size();
- SmallVector<OpFoldResult> numThreads;
- numThreads.reserve(nLoops);
- AffineExpr s0, s1;
- bindSymbols(rewriter.getContext(), s0, s1);
- AffineExpr divExpr = s0.ceilDiv(s1);
- for (int i = 0, e = std::min(mixedTileSizes.size(), loopRanges.size());
- i < e; ++i) {
- OpFoldResult numTiles = mixedTileSizes[i];
- if (!isConstantIntValue(numTiles, 0))
- numTiles = affine::makeComposedFoldedAffineApply(
- rewriter, tileableOp.getLoc(), divExpr,
- {loopRanges[i].size, numTiles});
- numThreads.push_back(numTiles);
- }
- options.setNumThreads(numThreads);
options.setTileSizes(mixedTileSizes);
}
if (mapping) {
@@ -2962,9 +3034,20 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
if (failed(maybeTilingResult))
return transformOp.emitDefaultSilenceableFailure(tileableOp);
+
rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
tilingResult = *maybeTilingResult;
+
+ if (mixedNumThreads.empty()) {
+ auto generatedForallOp = cast<scf::ForallOp>(tilingResult.loops.front());
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(generatedForallOp);
+ scf::ForallOp normalizedForallOp =
+ normalizeForallLoopOp(rewriter, generatedForallOp);
+ tilingResult.loops.front() = normalizedForallOp;
+ }
+
return DiagnosedSilenceableFailure::success();
}
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 8eb2ab59ac81d..4d3112d626de9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -217,10 +217,10 @@ static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc,
AffineExpr s0, s1, d0;
bindDims(b.getContext(), d0);
bindSymbols(b.getContext(), s0, s1);
- AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext());
+ AffineMap minMap = AffineMap::get(1, 2, {s0 - d0, s1}, b.getContext());
Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size);
return affine::makeComposedFoldedAffineMin(
- b, loc, minMap, SmallVector<OpFoldResult>{offset, tileSize, size});
+ b, loc, minMap, SmallVector<OpFoldResult>{offset, size, tileSize});
}
/// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index a031e53fe0ffb..3dd02ae4aa885 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1164,3 +1164,37 @@ scf::ForOp mlir::fuseIndependentSiblingForLoops(scf::ForOp target,
return fusedLoop;
}
+
+FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
+                                                 scf::ForallOp forallOp) {
+  SmallVector<OpFoldResult> lbs = forallOp.getMixedLowerBound();
+  SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
+  SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
+  if (llvm::all_of(lbs, [](auto v) { return isConstantIntValue(v, 0); }) &&
+      llvm::all_of(steps, [](auto v) { return isConstantIntValue(v, 1); }))
+    return forallOp; // Already normalized; nothing to do.
+  Location loc = forallOp.getLoc();
+  SmallVector<OpFoldResult> newUbs;
+  for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps))
+    newUbs.push_back(
+        emitNormalizedLoopBounds(rewriter, loc, lb, ub, step).upperBound);
+  // Lower bounds are all 0 and steps all 1: use the ub-only builder.
+  auto normalizedForallOp = rewriter.create<scf::ForallOp>(
+      loc, newUbs, forallOp.getOutputs(), forallOp.getMapping(),
+      [](OpBuilder &, Location, ValueRange) {});
+  // Original IV uses must see `lb + normalizedIv * step` (arith, not affine).
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPointToStart(normalizedForallOp.getBody());
+  SmallVector<Value> args;
+  for (auto [iv, lb, step] :
+       llvm::zip_equal(normalizedForallOp.getInductionVars(), lbs, steps)) {
+    Value mul = rewriter.create<arith::MulIOp>(
+        loc, iv, getValueOrCreateConstantIndexOp(rewriter, loc, step));
+    args.push_back(rewriter.create<arith::AddIOp>(
+        loc, mul, getValueOrCreateConstantIndexOp(rewriter, loc, lb)));
+  }
+  llvm::append_range(args, normalizedForallOp.getRegionIterArgs());
+  rewriter.mergeBlocks(forallOp.getBody(), normalizedForallOp.getBody(), args);
+  rewriter.replaceOp(forallOp, normalizedForallOp);
+  return normalizedForallOp;
+}
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir
index 89183813c080b..8f13c69070457 100644
--- a/mlir/test/Dialect/Linalg/tile-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir
@@ -119,7 +119,7 @@ module attributes {transform.with_named_sequence} {
// -----
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (2, -d0 + s0)>
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
// CHECK: fold_extract_slice
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<?x128xf32>
diff --git a/mlir/test/Dialect/Linalg/tile-to-forall.mlir b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
index 6e92deaf4cf0d..778d5bb8b9c84 100644
--- a/mlir/test/Dialect/Linalg/tile-to-forall.mlir
+++ b/mlir/test/Dialect/Linalg/tile-to-forall.mlir
@@ -196,10 +196,10 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
// CHECK: %[[NT0:.+]] = affine.apply #[[$map0]]()[%[[M]]]
// CHECK: %[[NT1:.+]] = affine.apply #[[$map1]]()[%[[N]]]
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
- // CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
- // CHECK: %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
- // CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
- // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
+ // CHECK-DAG: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+ // CHECK-DAG: %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
+ // CHECK-DAG: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
+ // CHECK-DAG: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
// CHECK: tensor.extract_slice %[[A]]
// CHECK: tensor.extract_slice %[[B]]
// CHECK: tensor.extract_slice %[[C_BLK]]
@@ -233,11 +233,11 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[C:[0-9a-z]+]]: tensor
func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: tensor<100x300xf32>) -> tensor<100x300xf32> {
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (10, 15) shared_outs(%[[C_BLK:.*]] = %[[C]])
- // CHECK: %[[TS:.+]] = affine.min #[[$map0]](%[[IV1]])
+ // CHECK-DAG: %[[TS:.+]] = affine.min #[[$map0]](%[[IV1]])
+ // CHECK-DAG: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
+ // CHECK-DAG: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
// CHECK-NOT: affine.max
// CHECK-NOT: affine.min
- // CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
- // CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
// CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
// CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
@@ -452,10 +452,9 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
-// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (0, d0)>
-// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
-// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
-// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
// CHECK-LABEL: matmul_tile_size_dynamic(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
@@ -464,18 +463,16 @@ module attributes {transform.with_named_sequence} {
func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[c0:.*]] = arith.constant 0 : index
- // CHECK: %[[M:.+]] = tensor.dim %[[A]], %[[c0]] :
- // CHECK: %[[N:.+]] = tensor.dim %[[B]], %[[c1]] :
- // CHECK: %[[NT0:.+]] = affine.apply #map()[%[[M]]]
- // CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
- // CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
+ // CHECK-DAG: %[[M:.+]] = tensor.dim %[[A]], %[[c0]] :
+ // CHECK-DAG: %[[N:.+]] = tensor.dim %[[B]], %[[c1]] :
+ // CHECK-DAG: %[[NT0:.+]] = affine.apply #map()[%[[M]]]
+ // CHECK-DAG: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
+ // CHECK-DAG: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
- // CHECK: %[[TSMIN0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
- // CHECK: %[[TS0:.+]] = affine.max #[[$map3]](%[[TSMIN0]])
- // CHECK: %[[TSMIN1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
- // CHECK: %[[TS1:.+]] = affine.max #[[$map3]](%[[TSMIN1]])
- // CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
- // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
+ // CHECK-DAG: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+ // CHECK-DAG: %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
+ // CHECK-DAG: %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
+ // CHECK-DAG: %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
// CHECK: tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
// CHECK: tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
// CHECK: tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :
@@ -523,10 +520,9 @@ module attributes {transform.with_named_sequence} {
// CHECK-DAG: #[[$map0:.+]] = affine_map<()[s0] -> (s0 ceildiv 10)>
// CHECK-DAG: #[[$map1:.+]] = affine_map<()[s0] -> (s0 ceildiv 20)>
// CHECK-DAG: #[[$map2:.+]] = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
-// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0) -> (0, d0)>
-// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
-// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 10)>
-// CHECK-DAG: #[[$map6:.+]] = affine_map<(d0) -> (d0 * 20)>
+// CHECK-DAG: #[[$map3:.+]] = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
+// CHECK-DAG: #[[$map4:.+]] = affine_map<(d0) -> (d0 * 10)>
+// CHECK-DAG: #[[$map5:.+]] = affine_map<(d0) -> (d0 * 20)>
// CHECK-LABEL: matmul_tile_size_dynamic(
// CHECK-SAME: %[[A:[0-9a-z]+]]: tensor<?x?xf32>
@@ -535,18 +531,16 @@ module attributes {transform.with_named_sequence} {
func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[c0:.*]] = arith.constant 0 : index
- // CHECK: %[[M:.+]] = tensor.dim %[[A]], %[[c0]] :
- // CHECK: %[[N:.+]] = tensor.dim %[[B]], %[[c1]] :
- // CHECK: %[[NT0:.+]] = affine.apply #map()[%[[M]]]
- // CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
- // CHECK: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
+ // CHECK-DAG: %[[M:.+]] = tensor.dim %[[A]], %[[c0]] :
+ // CHECK-DAG: %[[N:.+]] = tensor.dim %[[B]], %[[c1]] :
+ // CHECK-DAG: %[[NT0:.+]] = affine.apply #map()[%[[M]]]
+ // CHECK-DAG: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
+ // CHECK-DAG: %[[K:.+]] = tensor.dim %[[A]], %[[c1]] :
// CHECK: scf.forall (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]]) shared_outs(%[[C_BLK:.*]] = %[[C]])
- // CHECK: %[[TSMIN0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
- // CHECK: %[[TS0:.+]] = affine.max #[[$map3]](%[[TSMIN0]])
- // CHECK: %[[TSMIN1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
- // CHECK: %[[TS1:.+]] = affine.max #[[$map3]](%[[TSMIN1]])
- // CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
- // CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
+ // CHECK-DAG: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
+ // CHECK-DAG: %[[TS1:.+]] = affine.min #[[$map3]](%[[IV1]])[%[[N]]]
+ // CHECK-DAG: %[[LB0:.+]] = affine.apply #[[$map4]](%[[IV0]])
+ // CHECK-DAG: %[[LB1:.+]] = affine.apply #[[$map5]](%[[IV1]])
// CHECK: tensor.extract_slice %[[A]][%[[LB0]], 0] [%[[TS0]], %[[K]]] [1, 1] :
// CHECK: tensor.extract_slice %[[B]][0, %[[LB1]]] [%[[K]], %[[TS1]]] [1, 1] :
// CHECK: tensor.extract_slice %[[C_BLK]][%[[LB0]], %[[LB1]]] [%[[TS0]], %[[TS1]]] [1, 1] :
diff --git a/mlir/test/Dialect/Linalg/transform-op-tile.mlir b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
index 3467a539496b8..a261f03983275 100644
--- a/mlir/test/Dialect/Linalg/transform-op-tile.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-tile.mlir
@@ -184,7 +184,7 @@ module {
// CHECK: %[[VS:.*]] = vector.vscale
// CHECK: %[[STEP:.*]] = arith.muli %[[VEC_SIZE]], %[[VS]] : index
// CHECK: scf.for %[[IV:.*]] = %[[C0]] to %[[DIM]] step %[[STEP]] iter_args(%[[VAL:.*]] = %[[ARG_2]]) -> (tensor<?xf32>) {
-// CHECK: %[[SIZE:.*]] = affine.min affine_map<(d0)[s0, s1] -> (s0, -d0 + s1)>(%[[IV]])[%[[STEP]], %[[DIM]]]
+// CHECK: %[[SIZE:.*]] = affine.min affine_map<(d0)[s0, s1] -> (-d0 + s0, s1)>(%[[IV]])[%[[DIM]], %[[STEP]]]
// CHECK: %[[SLICE_ARG0:.*]] = tensor.extract_slice %[[ARG_0]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
// CHECK: %[[SLICE_ARG1:.*]] = tensor.extract_slice %[[ARG_1]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
// CHECK: %[[SLICE_ARG2:.*]] = tensor.extract_slice %[[VAL]][%[[IV]]] [%[[SIZE]]] [1] : tensor<?xf32> to tensor<?xf32>
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index 11ab30a7d237c..d1aed593f4545 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -428,7 +428,7 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
+// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
// CHECK: func @matmul_sequence_fusion(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?xf32>
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
index 0a4d4c45f10be..8eb1311170c66 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -16,8 +16,8 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 20)>
// CHECK-LABEL: func.func @simple_matmul(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
@@ -68,9 +68,9 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (30, -d0 + s0)>
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 20)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 30)>
// CHECK-LABEL: func.func @simple_matmul_memref(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
@@ -127,7 +127,7 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (10, -d0 + 128)>
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (-d0 + 128, 10)>
// CHECK-LABEL: func.func @multi_result(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
// CHECK-DAG: %[[INIT0:.+]] = tensor.empty()
@@ -180,9 +180,9 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (30, -d0 + s0)>
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 20)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 30)>
// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 2 - 2)>
// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 3 - 3)>
// CHECK-LABEL: func.func @conv2D(
@@ -287,9 +287,9 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (30, -d0 + s0)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 20)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 30)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
// CHECK-LABEL: func.func @interchange_matmul(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
index c5aff744b57ee..53dd0c6a2425c 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir
@@ -17,8 +17,8 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 20)>
// CHECK: func.func @simple_matmul(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
@@ -65,8 +65,8 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 20)>
// CHECK-LABEL: func.func @simple_matmul_memref(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32>
@@ -117,7 +117,7 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (10, -d0 + 128)>
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (-d0 + 128, 10)>
// CHECK-LABEL: func.func @multi_result(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
// CHECK-DAG: %[[INIT0:.+]] = tensor.empty()
@@ -161,9 +161,9 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
-// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (30, -d0 + s0)>
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 20)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 30)>
// CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 2 - 2)>
// CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0)[s0] -> (d0 + s0 * 3 - 3)>
// CHECK-LABEL: func.func @conv2D(
@@ -264,8 +264,8 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
-// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)>
-// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)>
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 20)>
+// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)>
// CHECK-LABEL: func.func @interchange_matmul(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?xf32>
More information about the Mlir-commits
mailing list