[Mlir-commits] [mlir] 4a02001 - [NFC] Simplify the tiling implementation using cloning. (#72178)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Nov 20 09:05:54 PST 2023
Author: MaheshRavishankar
Date: 2023-11-20T09:05:48-08:00
New Revision: 4a020018ce7abdee21e976f7ed5746ef2eb2c0fd
URL: https://github.com/llvm/llvm-project/commit/4a020018ce7abdee21e976f7ed5746ef2eb2c0fd
DIFF: https://github.com/llvm/llvm-project/commit/4a020018ce7abdee21e976f7ed5746ef2eb2c0fd.diff
LOG: [NFC] Simplify the tiling implementation using cloning. (#72178)
The current implementation of tiling using `scf.for` is convoluted to
make sure that the destination passing style of the untiled program is
preserved. The addition of support to tile using `scf.forall` (adapted
from the transform operation in Linalg) in
https://github.com/llvm/llvm-project/pull/67083 used cloning of the
tiled operations to better streamline the implementation. This PR adapts
the other tiling methods to use a similar approach, making the
transformations (and handling destination passing style semantics) more
systematic.
---------
Co-authored-by: Abhishek-Varma <avarma094 at gmail.com>
Added:
Modified:
mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/test/Dialect/Tensor/tiling.mlir
mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 81325b62791c44b8..2f8f337bb8057ce9 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -83,6 +83,12 @@ FailureOr<SCFTilingResult> tileUsingSCFForOp(RewriterBase &rewriter,
TilingInterface op,
const SCFTilingOptions &options);
+/// Method to tile an op that implements the `TilingInterface` using
+/// `scf.forall`.
+FailureOr<SCFTilingResult>
+tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
+ const SCFTilingOptions &options);
+
/// Options used to control tile + fuse.
struct SCFTileAndFuseOptions {
/// The tiling options used to control the tiling of the consumer.
@@ -93,12 +99,6 @@ struct SCFTileAndFuseOptions {
}
};
-/// Method to tile an op that implements the `TilingInterface` using
-/// `scf.forall`.
-FailureOr<SCFTilingResult>
-tileUsingSCFForallOp(RewriterBase &rewriter, TilingInterface op,
- const SCFTilingOptions &options);
-
/// Fuse the producer of the source of `candidateSliceOp` by computing the
/// required slice of the producer in-place. Note that the method
/// replaces the uses of `candidateSliceOp` with the tiled and fused producer
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index df162d29a48eb89d..b91af4d246d99191 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -128,10 +128,10 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
Operation *op,
ValueRange newDestArgs) {
Operation *clonedOp = rewriter.clone(*op);
- if (auto destinationStyleOp =
- dyn_cast<DestinationStyleOpInterface>(clonedOp)) {
+ if (newDestArgs.empty())
+ return clonedOp;
+ if (auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp))
destinationStyleOp.getDpsInitsMutable().assign(newDestArgs);
- }
return clonedOp;
}
@@ -139,13 +139,16 @@ static Operation *cloneOpAndUpdateDestinationArgs(RewriterBase &rewriter,
/// - `loopRanges` specifies the lb, ub and step of the untiled iteration space.
/// - `tileSizes` is the tile sizes to use. Zero represent untiled loops.
/// - In `offsets` and `sizes` return the multi-dimensional offset and size of
-/// the
-/// tile processed within the inner most loop.
+/// the tile processed within the inner most loop.
+/// Note that this method adds `scf.yield` operations for all but the innermost
+/// loop. These yield the values returned by the immediately inner loop. The
+/// caller is expected to add the `scf.yield` operation for the innermost loop.
static SmallVector<scf::ForOp> generateTileLoopNest(
OpBuilder &builder, Location loc, ArrayRef<Range> loopRanges,
ArrayRef<OpFoldResult> tileSizes, SmallVector<OpFoldResult> &offsets,
- SmallVector<OpFoldResult> &sizes) {
- assert(!loopRanges.empty() && "expected at least one loop range");
+ SmallVector<OpFoldResult> &sizes, ValueRange destinationTensors = {}) {
+ if (loopRanges.empty())
+ return {};
assert(loopRanges.size() == tileSizes.size() &&
"expected as many tile sizes as loop ranges");
OpBuilder::InsertionGuard guard(builder);
@@ -169,136 +172,99 @@ static SmallVector<scf::ForOp> generateTileLoopNest(
}
auto loop = builder.create<scf::ForOp>(
- loc, offset, size, tileSize, ValueRange{},
+ loc, offset, size, tileSize, destinationTensors,
[&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv,
ValueRange /*iterArgs*/) {
sizes[loopRange.index()] =
getBoundedTileSize(bodyBuilder, bodyLoc, loopRange.value(), iv,
getAsOpFoldResult(tileSize));
- builder.create<scf::YieldOp>(loc);
});
offsets[loopRange.index()] = loop.getInductionVar();
loops.push_back(loop);
- builder.setInsertionPoint(loop.getBody()->getTerminator());
+ builder.setInsertionPointToEnd(loop.getBody());
+ destinationTensors = loop.getRegionIterArgs();
}
- return loops;
-}
-/// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`,
-/// construct the destructive update pattern that inserts the yielded
-/// value into a destination tensor provided by `initValue` at offset
-/// `tileOffsets` and size `tileSizes`. For example,
-///
-/// ```mlir
-/// scf.for %iv0 = ... {
-/// %0 = tiled_op
-/// }
-/// ```
-///
-/// is transformed to
-///
-/// ```mlir
-/// scf.for %iv0 = ... iter_args(%arg = %0) {
-/// %1 = tensor.extract_slice %arg
-/// %2 = tiled_op
-/// %3 = tensor.insert_slice %2 into %arg
-/// scf.yield %3
-/// }
-/// ```
-/// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
-static SmallVector<Value>
-yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
- ValueRange yieldedValues,
- ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
- ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
- MutableArrayRef<scf::ForOp> loops) {
- NewYieldValuesFn yieldValueFn =
- [&](OpBuilder &b, Location loc,
- ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
- SmallVector<Value> inserts;
- for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) {
- ArrayRef<OpFoldResult> tileOffsets =
- tileOffsetsList[yieldedValue.index()];
- ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()];
- SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
- b.getIndexAttr(1));
- Value insert = b.create<tensor::InsertSliceOp>(
- loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
- tileOffsets, tileSizes, tileStrides);
- inserts.push_back(insert);
+ // Add the scf.yield operations for all the outer loops.
+ if (!loops.empty()) {
+ for (auto [outerLoop, innerLoop] :
+ llvm::zip_equal(MutableArrayRef(loops).drop_back(),
+ MutableArrayRef(loops).drop_front())) {
+ builder.setInsertionPointToEnd(outerLoop.getBody());
+ builder.create<scf::YieldOp>(outerLoop.getLoc(), innerLoop.getResults());
}
- return inserts;
- };
-
- SmallVector<scf::ForOp> newLoops =
- replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
- /*replaceIterOperandsUsesInLoop =*/false);
- for (const auto &loop : llvm::enumerate(loops)) {
- loops[loop.index()] = newLoops[loop.index()];
}
- return llvm::to_vector(llvm::map_range(
- loops.front().getResults().take_back(yieldedValues.size()),
- [](OpResult r) -> Value { return r; }));
+ return loops;
}
-/// If the tiled operation is destination passing style, update the
-/// slice of the destination used (which refers to the untiled destination)
-/// to use the corresponding region argument of the innermost loop.
-///
-/// ```mlir
-/// %0 =
-/// scf.for %iv0 = ... iter_args(%arg = %0) {
-/// %1 = tensor.extract_slice %0
-/// %2 = tiled_op
-/// %3 = tensor.insert_slice %2 into %arg
-/// scf.yield %3
-/// }
-/// ```
-///
-/// is transformed to
-///
-/// ```mlir
-/// scf.for %iv0 = ... iter_args(%arg = %0) {
-/// %1 = tensor.extract_slice %arg
-/// %2 = tiled_op
-/// %3 = tensor.insert_slice %2 into %arg
-/// scf.yield %3
-/// }
-/// ```
-static void
-updateDestinationOperandsForTiledOp(OpBuilder &builder,
- ValueRange tiledOpDestinationValues,
- ValueRange bbArgsList) {
- for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) {
- auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>();
- if (!sliceOp)
- continue;
- sliceOp.setOperand(0, bbArgsList[destValue.index()]);
+/// Method to add new init values to a loop nest. Updates `loops` in-place with
+/// new loops that use the `newInitValues`.
+/// The outer-loops are updated to yield the new result values of the inner
+/// loop. For the innermost loop, the callback `getNewYieldValsFn` is invoked to
+/// get the additional values to yield from the innermost loop.
+static void addInitOperandsToLoopNest(
+ RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loops,
+ ValueRange newInitValues,
+ llvm::function_ref<SmallVector<Value>(RewriterBase &rewriter, Value iv,
+ ValueRange newRegionIterArgs)>
+ getNewYieldValsFn) {
+ SmallVector<scf::ForOp> newLoops;
+ if (loops.empty())
+ return;
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(loops.front());
+ for (auto &loop : loops) {
+ rewriter.setInsertionPoint(loop);
+
+ // Create a new loop with the new init values for this loop.
+ SmallVector<Value> newInits = llvm::to_vector(loop.getInitArgs());
+ newInits.append(newInitValues.begin(), newInitValues.end());
+ auto newLoop = rewriter.create<scf::ForOp>(
+ loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
+ loop.getStep(), newInits,
+ [&](OpBuilder &b, Location loc, Value iv, ValueRange iterArgs) {});
+
+ // Merge the body of the new loop with the body of the old loops.
+ SmallVector<Value> sourceBlockArgs;
+ sourceBlockArgs.push_back(newLoop.getInductionVar());
+ auto newRegionIterArgs = newLoop.getRegionIterArgs();
+ sourceBlockArgs.append(
+ newRegionIterArgs.begin(),
+ std::next(newRegionIterArgs.begin(), loop.getNumResults()));
+ rewriter.mergeBlocks(loop.getBody(), newLoop.getBody(), sourceBlockArgs);
+ rewriter.replaceOp(loop,
+ newLoop.getResults().take_front(loop.getNumResults()));
+ loop = newLoop;
+ newInitValues = newLoop.getRegionIterArgs().take_back(newInitValues.size());
}
-}
-/// Helper method to yield the values of the tiled op, as well as
-/// update the destination operands of the tiled op, if it is
-/// a destination passing style op.
-static SmallVector<Value>
-yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues,
- TilingResult tilingResult,
- ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
- ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
- MutableArrayRef<scf::ForOp> loops) {
- SmallVector<Value> replacements =
- yieldTiledValues(rewriter, initValues, tilingResult.tiledValues,
- tileOffsetsList, tileSizesList, loops);
- for (auto tiledOp : tilingResult.tiledOps) {
- if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) {
- auto innerMostLoop = loops.back();
- SmallVector<Value> tiledOpDestinationTensors =
- llvm::to_vector(dstOp.getDpsInits());
- updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors,
- innerMostLoop.getRegionIterArgs());
- }
+ // Update the loop body of the innermost loop to get new yield values.
+ scf::ForOp innerMostLoop = loops.back();
+ auto innerMostYieldOp =
+ cast<scf::YieldOp>(innerMostLoop.getBody()->getTerminator());
+ rewriter.setInsertionPoint(innerMostYieldOp);
+ SmallVector<Value> newYieldVals =
+ getNewYieldValsFn(rewriter, innerMostLoop.getInductionVar(),
+ innerMostLoop.getRegionIterArgs());
+ SmallVector<Value> newYieldOperands =
+ llvm::to_vector(innerMostYieldOp->getOperands());
+ newYieldOperands.append(newYieldVals);
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(innerMostYieldOp, newYieldOperands);
+
+ // Make all other loops except the innermost loops yield the values returned
+ // by the inner loop.
+ for (auto [outerLoop, innerLoop] :
+ llvm::zip_equal(loops.drop_back(), loops.drop_front())) {
+ auto outerLoopYield =
+ cast<scf::YieldOp>(outerLoop.getBody()->getTerminator());
+ SmallVector<Value> newYields =
+ llvm::to_vector(outerLoopYield.getOperands());
+ ValueRange additionalYields =
+ innerLoop.getResults().take_back(newInitValues.size());
+ newYields.append(additionalYields.begin(), additionalYields.end());
+ rewriter.setInsertionPoint(outerLoopYield);
+ rewriter.replaceOpWithNewOp<scf::YieldOp>(outerLoopYield, newYields);
}
- return replacements;
}
/// Implementation of tiling transformation of `op` that implements the
@@ -321,7 +287,6 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
return rewriter.notifyMatchFailure(
op, "unable to tile op with no iteration domain");
}
-
// 2. Materialize the tile sizes. Enforce the convention that "tiling by zero"
// skips tiling a particular dimension. This convention is significantly
// simpler to handle instead of adjusting affine maps to account for missing
@@ -333,6 +298,14 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
tileSizeVector.append(numLoops - tileSizeVector.size(), zero);
}
+ // 3. Find the destination tensors to use for the operation.
+ SmallVector<Value> destinationTensors;
+ if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
+ destinationTensors))) {
+ return rewriter.notifyMatchFailure(op,
+ "unable to create destination tensors");
+ }
+
SmallVector<OpFoldResult> offsets, sizes;
SmallVector<scf::ForOp> forLoops;
{
@@ -354,11 +327,12 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
applyPermutationToVector(tileSizeVector, interchangeVector);
}
- // 3. Materialize an empty loop nest that iterates over the tiles. These
+ // 4. Materialize an empty loop nest that iterates over the tiles. These
// loops for now do not return any values even if the original operation has
// results.
forLoops = generateTileLoopNest(rewriter, op.getLoc(), iterationDomain,
- tileSizeVector, offsets, sizes);
+ tileSizeVector, offsets, sizes,
+ destinationTensors);
if (!interchangeVector.empty()) {
auto inversePermutation = invertPermutationVector(interchangeVector);
@@ -375,17 +349,29 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
}
});
- // 4. Generate the tiled implementation within the inner most loop.
- if (!forLoops.empty())
- rewriter.setInsertionPoint(forLoops.back().getBody()->getTerminator());
- FailureOr<TilingResult> tiledImplementation =
- op.getTiledImplementation(rewriter, offsets, sizes);
+ // 5. Generate the tiled implementation within the inner most loop.
+ SmallVector<Value> clonedOpDestination = destinationTensors;
+ if (!forLoops.empty()) {
+ rewriter.setInsertionPointToEnd(forLoops.back().getBody());
+ clonedOpDestination =
+ llvm::map_to_vector(forLoops.back().getRegionIterArgs(),
+ [](BlockArgument b) -> Value { return b; });
+ }
- if (op->getNumResults() == 0) {
- return scf::SCFTilingResult{
- tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
+ // 5a. Clone the operation within the loop body.
+ auto clonedOp = cast<TilingInterface>(
+ cloneOpAndUpdateDestinationArgs(rewriter, op, clonedOpDestination));
+
+ // 5b. Tile the cloned operation.
+ FailureOr<TilingResult> tiledImplementation =
+ clonedOp.getTiledImplementation(rewriter, offsets, sizes);
+ if (failed(tiledImplementation)) {
+ return rewriter.notifyMatchFailure(op, "failed to tile operation");
}
+ // 5c. Delete the cloned operation.
+ rewriter.eraseOp(clonedOp);
+
// If loops are empty, the tiled op is used as the replacement for the untiled
// op.
if (forLoops.empty()) {
@@ -394,30 +380,39 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
tiledImplementation->tiledValues};
}
- // 5. Yield all the results of the tiled operation. The surrounding loop
- // nest is modified to insert a destructive update pattern to yield
- // from the loop nest values to replace the untiled op with.
+ if (op->getNumResults() == 0) {
+ // The innermost loop does not have a `scf.yield` yet. There is nothing to
+ // return, so generate an empty `scf.yield` operation.
+ rewriter.setInsertionPointToEnd(forLoops.back().getBody());
+ rewriter.create<scf::YieldOp>(op->getLoc());
+ return scf::SCFTilingResult{
+ tiledImplementation->tiledOps, getAsOperations(forLoops), {}};
+ }
+
+ // 6. Yield all the results of the tiled operation.
int64_t numResults = op->getNumResults();
SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
resultSizesList(numResults);
- for (const auto &result : llvm::enumerate(op->getResults())) {
- if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
- sizes,
- resultOffsetsList[result.index()],
- resultSizesList[result.index()]))) {
+ SmallVector<Value> yieldedValues;
+ for (auto [index, tiledValue] :
+ llvm::enumerate(tiledImplementation->tiledValues)) {
+ SmallVector<OpFoldResult> resultOffsets, resultSizes;
+ if (failed(op.getResultTilePosition(rewriter, index, offsets, sizes,
+ resultOffsets, resultSizes))) {
return rewriter.notifyMatchFailure(
op, "failed to get slice of result produced");
}
+ SmallVector<OpFoldResult> resultStrides(resultOffsets.size(),
+ rewriter.getIndexAttr(1));
+ auto insertSlice = rewriter.create<tensor::InsertSliceOp>(
+ op->getLoc(), tiledValue, clonedOpDestination[index], resultOffsets,
+ resultSizes, resultStrides);
+ yieldedValues.push_back(insertSlice);
}
+ rewriter.create<scf::YieldOp>(op->getLoc(), yieldedValues);
- SmallVector<Value> destinationTensors;
- if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op,
- destinationTensors)))
- return rewriter.notifyMatchFailure(op, "failed to get destinations");
-
- SmallVector<Value> replacements = yieldTiledValues(
- rewriter, destinationTensors, tiledImplementation.value(),
- resultOffsetsList, resultSizesList, forLoops);
+ SmallVector<Value> replacements = llvm::map_to_vector(
+ forLoops.front().getResults(), [](OpResult r) -> Value { return r; });
LLVM_DEBUG({
if (!forLoops.empty()) {
llvm::dbgs() << "After tiled implementation :\n";
@@ -457,42 +452,59 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
reductionDims.push_back(idx);
}
- // 1. create the inital tensor value.
+ // 2. Create the initial tensor value.
FailureOr<Operation *> identityTensor =
op.generateInitialTensorForPartialReduction(b, loc, tileSizesVector,
reductionDims);
if (failed(identityTensor))
return b.notifyMatchFailure(op,
"cannot create a tensor of identity value.");
- // 2. Create the nested loops.
+ // 3. Create the nested loops.
SmallVector<OpFoldResult> offsets, sizes;
- SmallVector<scf::ForOp> loops = generateTileLoopNest(
- b, loc, iterationDomain, tileSizesVector, offsets, sizes);
-
- // 3. Generate the tiled implementation within the inner most loop.
- b.setInsertionPoint(loops.back().getBody()->getTerminator());
- Operation *parallelOp = op.tileToPartialReduction(
- b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDims);
+ SmallVector<scf::ForOp> loops =
+ generateTileLoopNest(b, loc, iterationDomain, tileSizesVector, offsets,
+ sizes, identityTensor.value()->getResults());
- SmallVector<OpFoldResult> resultSizesList;
- for (size_t i = 0; i < offsets.size(); i++)
- resultSizesList.push_back(
+ // 4. Generate the tiled implementation within the inner most loop.
+ // 4a. Clone the operation within the loop body.
+ SmallVector<Value> clonedOpDestination =
+ llvm::map_to_vector(identityTensor.value()->getResults(),
+ [](OpResult res) -> Value { return res; });
+ if (!loops.empty()) {
+ b.setInsertionPointToEnd(loops.back().getBody());
+ clonedOpDestination =
+ llvm::map_to_vector(loops.back().getRegionIterArgs(),
+ [](BlockArgument b) -> Value { return b; });
+ }
+ auto clonedOp = cast<PartialReductionOpInterface>(
+ cloneOpAndUpdateDestinationArgs(b, op, clonedOpDestination));
+
+ // 4b. Tile the cloned operation.
+ Operation *parallelOp = clonedOp.tileToPartialReduction(
+ b, loc, clonedOpDestination, offsets, sizes, reductionDims);
+ // 4c. Delete the cloned operation.
+ b.eraseOp(clonedOp);
+
+ SmallVector<OpFoldResult> outSizes;
+ for (size_t i = 0; i < offsets.size(); i++) {
+ outSizes.push_back(
tensor::getMixedSize(b, loc, parallelOp->getResult(0), i));
+ }
SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
- SmallVector<Value> replacements = yieldTiledValues(
- b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets,
- resultSizesList, loops);
-
- auto dstOp = cast<DestinationStyleOpInterface>(parallelOp);
- auto innerMostLoop = loops.back();
- SmallVector<Value> destinationTensors = llvm::to_vector(dstOp.getDpsInits());
- assert(destinationTensors.size() ==
- innerMostLoop.getRegionIterArgs().size() &&
- "unexpected number of outputs");
- updateDestinationOperandsForTiledOp(b, destinationTensors,
- innerMostLoop.getRegionIterArgs());
-
- // 4. Apply the merge reduction to combine all the partial values.
+ SmallVector<OpFoldResult> outStrides(outOffsets.size(), b.getIndexAttr(1));
+ SmallVector<Value> yieldedVals;
+ auto bbArgs = loops.back().getRegionIterArgs();
+ for (auto [result, bbArg] : llvm::zip(parallelOp->getResults(), bbArgs)) {
+ Value insert = b.create<tensor::InsertSliceOp>(
+ loc, result, bbArg, outOffsets, outSizes, outStrides);
+ yieldedVals.push_back(insert);
+ }
+ b.create<scf::YieldOp>(loc, yieldedVals);
+
+ SmallVector<Value> replacements = llvm::map_to_vector(
+ loops.front().getResults(), [](OpResult r) -> Value { return r; });
+
+ // 5. Apply the merge reduction to combine all the partial values.
b.setInsertionPointAfter(*loops.begin());
Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
b.replaceOp(op, mergeOp->getResults());
@@ -544,17 +556,55 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
loops);
if (!fusableProducer)
return std::nullopt;
+ unsigned resultNumber = fusableProducer.getResultNumber();
- // 2. Generate the tiled implementation of the producer of the source
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(candidateSliceOp);
+
+ // 2. Clone the fused producer
+ // 2a. Compute the destination operands to use for the cloned operation.
+ SmallVector<Value> origDestinationTensors, clonedOpDestinationTensors;
+ Operation *fusableProducerOp = fusableProducer.getOwner();
+ if (isa<DestinationStyleOpInterface>(fusableProducerOp) &&
+ failed(tensor::getOrCreateDestinations(
+ rewriter, fusableProducerOp->getLoc(), fusableProducerOp,
+ origDestinationTensors)))
+ return std::nullopt;
+
+ clonedOpDestinationTensors = origDestinationTensors;
+ if (destinationInitArg &&
+ isa<DestinationStyleOpInterface>(fusableProducerOp)) {
+ // 2b. If the producer is also destination style, then to maintain the
+ // destination passing style, update the destination of the producer to be
+ // the source of the slice.
+ clonedOpDestinationTensors[resultNumber] = candidateSliceOp.getSource();
+ }
+ // 2c. Clone the fused producer.
+ Operation *clonedProducerOp = cloneOpAndUpdateDestinationArgs(
+ rewriter, fusableProducerOp, clonedOpDestinationTensors);
+ // 2d. Update the source of the candidateSlice to be the cloned producer.
+ // Easier to just clone the slice with different source since replacements
+ // and DCE of cloned ops becomes easier
+ SmallVector<Value> candidateSliceOpOperands =
+ llvm::to_vector(candidateSliceOp->getOperands());
+ candidateSliceOpOperands[0] = clonedProducerOp->getResult(resultNumber);
+ tensor::ExtractSliceOp clonedCandidateSliceOp =
+ mlir::clone(rewriter, candidateSliceOp,
+ candidateSliceOp->getResultTypes(), candidateSliceOpOperands);
+
+ // 3. Generate the tiled implementation of the producer of the source
FailureOr<TilingResult> tileAndFuseResult =
- tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp,
- fusableProducer);
+ tensor::replaceExtractSliceWithTiledProducer(
+ rewriter, clonedCandidateSliceOp,
+ clonedProducerOp->getResult(resultNumber));
if (failed(tileAndFuseResult))
return std::nullopt;
+ // Note: Do not delete the candidateSliceOp, since it's passed in from the
+ // caller.
rewriter.replaceAllUsesWith(candidateSliceOp,
tileAndFuseResult->tiledValues[0]);
+ rewriter.eraseOp(clonedCandidateSliceOp);
+ rewriter.eraseOp(clonedProducerOp);
// 3. If the slice is for a destination operand, for example,
//
@@ -576,7 +626,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
// %1 = linalg.fill
// %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) {
// %3 = scf.for .. iter_args(%arg1 = %arg0) {
- // %4 = tensor.extract_slice %0 /*incorrect value */ [..]
+ // %4 = tensor.extract_slice %arg1[..]
// %5 = linalg.fill .. outs(%4 : )
// .. = linalg.matmul .. outs(%5 : )
// }
@@ -585,46 +635,25 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
//
// The untiled `linalg.fill` is still used as the `init_value` since it
// was originally a destination operand of the untiled `linalg.matmul`.
- // When fusing an operand that is a destination operand.
- // - Update the iter_arg of the outer most loop to use the destination
- // of the untiled producer.
- // - Update the destination of the slice of the tiled producer generated
- // to use the same basic block argument as the slice that was used to
- // generate inplace the tiled implementation of the producer.
- // With this the IR will be.
+ // When fusing an operand that is a destination operand, the iter_arg of
+ // the outermost loop should be changed to use the destination of the
+ // fused operation. With this, the IR will be:
//
// ```
// %0 = linalg.init
// %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) {
// %2 = scf.for .. iter_args(%arg1 = %arg0) {
- // %3 = tensor.extract_slice %arg1 /* corrected value */ [..]
+ // %3 = tensor.extract_slice %arg1[..]
// %4 = linalg.fill .. outs(%3 : )
// .. = linalg.matmul .. outs(%4 : )
// }
// }
// ```
- // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
- // Update to use that when it does become available.
- scf::ForOp outerMostLoop = loops.front();
if (destinationInitArg &&
- (*destinationInitArg)->getOwner() == outerMostLoop) {
- unsigned iterArgNumber =
- outerMostLoop.getTiedLoopResult(*destinationInitArg).getResultNumber();
- int64_t resultNumber = fusableProducer.getResultNumber();
- if (auto dstOp =
- dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
- (*destinationInitArg)
- ->set(dstOp.getTiedOpOperand(fusableProducer)->get());
- }
- for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
- auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
- if (!dstOp)
- continue;
- scf::ForOp innerMostLoop = loops.back();
- updateDestinationOperandsForTiledOp(
- rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
- innerMostLoop.getRegionIterArgs()[iterArgNumber]);
- }
+ isa<DestinationStyleOpInterface>(fusableProducerOp) && !loops.empty()) {
+ loops.front()
+ ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
+ .set(origDestinationTensors[resultNumber]);
}
return scf::SCFFuseProducerOfSliceResult{fusableProducer,
tileAndFuseResult->tiledValues[0],
@@ -636,28 +665,46 @@ void mlir::scf::yieldReplacementForFusedProducer(
RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
MutableArrayRef<scf::ForOp> loops) {
- auto [fusableProducer, fusedProducerValue, tileAndFusedOps] =
- fusedProducerInfo;
- SmallVector<Value> initValues;
+ if (loops.empty())
+ return;
+
+ OpResult fusableProducer = fusedProducerInfo.origProducer;
+ Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
FailureOr<Value> initValue = tensor::getOrCreateDestination(
rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
if (succeeded(initValue)) {
- SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets();
- SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes();
- SmallVector<Value> yieldedVals =
- yieldTiledValues(rewriter, initValue.value(), fusedProducerValue,
- resultOffsets, resultSizes, loops);
- }
- for (auto tileAndFusedOp : tileAndFusedOps) {
- auto dstStyleProducer =
- dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
- if (!dstStyleProducer)
- continue;
- Value dstValue =
- dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber())
- ->get();
- updateDestinationOperandsForTiledOp(
- rewriter, dstValue, loops.back().getRegionIterArgs().back());
+
+ auto newYieldValuesFn =
+ [&](RewriterBase &innerRewriter, Value iv,
+ ValueRange newRegionIterArgs) -> SmallVector<Value> {
+ OpBuilder::InsertionGuard g(innerRewriter);
+ if (auto tiledDestStyleOp =
+ tiledAndFusedProducer
+ .getDefiningOp<DestinationStyleOpInterface>()) {
+ rewriter.setInsertionPoint(tiledDestStyleOp);
+ BlockArgument newRegionArg = loops.back().getRegionIterArgs().back();
+ auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
+ sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
+ sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
+ unsigned resultNumber = fusableProducer.getResultNumber();
+ rewriter.updateRootInPlace(tiledDestStyleOp, [&]() {
+ tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
+ });
+
+ Block *block = rewriter.getInsertionPoint()->getBlock();
+ rewriter.setInsertionPoint(block->getTerminator());
+ Value replacement = rewriter.create<tensor::InsertSliceOp>(
+ fusedProducerInfo.origProducer.getLoc(),
+ fusedProducerInfo.tiledAndFusedProducer,
+ loops.back().getRegionIterArgs().back(), sliceOp.getMixedOffsets(),
+ sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
+ return {replacement};
+ }
+ };
+
+ addInitOperandsToLoopNest(rewriter, loops,
+ SmallVector<Value>{initValue.value()},
+ newYieldValuesFn);
}
}
diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
index 51f33a96e571b83e..bb42f84afc50f94e 100644
--- a/mlir/test/Dialect/Tensor/tiling.mlir
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -374,8 +374,8 @@ module attributes {transform.with_named_sequence} {
// CHECK: %[[IN_J:.*]] = affine.apply #[[MAP2]](%[[J]])[%[[TILE_1]]]
// CHECK: %[[IN_J_SZ:.*]] = affine.min #[[MAP3]](%[[OUT_J_SZ]], %[[J]])[%[[TILE_1]], %[[IN_D1]]]
// CHECK: %[[SUB_IN:.*]] = tensor.extract_slice %[[IN]][%[[IN_I]], %[[IN_J]]] [%[[IN_I_SZ]], %[[IN_J_SZ]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
-// CHECK: %[[OUT_D2:.+]] = tensor.dim %[[OUT]], %[[C2]]
-// CHECK: %[[OUT_D3:.+]] = tensor.dim %[[OUT]], %[[C3]]
+// CHECK: %[[OUT_D2:.+]] = tensor.dim %[[ITER1]], %[[C2]]
+// CHECK: %[[OUT_D3:.+]] = tensor.dim %[[ITER1]], %[[C3]]
// CHECK: %[[SUB_OUT:.*]] = tensor.extract_slice %[[ITER1]][%[[I]], %[[J]], 0, 0] [%[[OUT_I_SZ]], %[[OUT_J_SZ]], %[[OUT_D2]], %[[OUT_D3]]] [1, 1, 1, 1] : tensor<?x?x?x?xf32> to tensor<?x?x?x?xf32>
// CHECK: %[[PACK:.*]] = tensor.pack
// CHECK-SAME: %[[SUB_IN]] padding_value(%[[PAD]] : f32) inner_dims_pos = [0, 1] inner_tiles = [%[[TILE_0]], %[[TILE_1]]]
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
index cf5a1b828f95b75f..2078b5b4dabb268a 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -369,15 +369,15 @@ func.func @matmul_sequence_fusion(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>
// CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[N0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[ORIG_GEMM1:.+]] = linalg.matmul ins(%[[ARG0]], %[[ARG1]] :
-// CHECK-DAG: %[[N1:.+]] = tensor.dim %[[ORIG_GEMM1]], %[[C1]]
// CHECK-DAG: %[[ORIG_GEMM2:.+]] = linalg.matmul ins(%[[ORIG_GEMM1]], %[[ARG3]] :
// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C0]]
// CHECK-DAG: %[[N2:.+]] = tensor.dim %[[ORIG_GEMM2]], %[[C1]]
// CHECK-DAG: %[[N3:.+]] = tensor.dim %[[ARG5]], %[[C1]]
// CHECK: %[[R0:.+]] = scf.for %[[IV:[a-zA-Z0-9_]+]] =
// CHECK-SAME: iter_args(%[[ARG8:.+]] = %[[ARG6]]) -> (tensor<?x?xf32>) {
+// CHECK-DAG: %[[N1:.+]] = tensor.dim %[[ORIG_GEMM1]], %[[C1]]
+// CHECK-DAG: %[[N0:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[TILE_M:.+]] = affine.min #[[MAP]](%[[IV]])[%[[M]]]
// CHECK-DAG: %[[SLICE_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[TILE_M]], %[[N0]]]
// CHECK-DAG: %[[SLICE_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, 0] [%[[N0]], %[[N1]]]
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
index 2153eb6f237fcfd6..e99ffc88066d69da 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir
@@ -100,37 +100,37 @@ func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200x
} -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>)
return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>
}
-// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (10, -d0 + 128)>
-// CHECK-LABEL: func.func @multi_result(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
-// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
-// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
-// CHECK-DAG: %[[C300:.+]] = arith.constant 300 : index
-// CHECK-DAG: %[[INIT0:.+]] = tensor.empty()
-// CHECK-DAG: %[[INIT1:.+]] = tensor.empty()
-// CHECK: %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]]
-// CHECK-SAME: iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
-// CHECK: %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])
-// CHECK: %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
-// CHECK-SAME: iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
-// CHECK-DAG: %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
-// CHECK-SAME: [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, 20] [1, 1, 1]
-// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG3]]
-// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
-// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG4]]
-// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
-// CHECK: %[[RESULT_TILE:.+]]:2 = linalg.generic
-// CHECK-SAME: ins(%[[ARG_TILE]] :
-// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
-// CHECK: %[[UPDATE0:.+]] = tensor.insert_slice %[[RESULT_TILE]]#0 into %[[ARG3]]
-// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
-// CHECK: %[[UPDATE1:.+]] = tensor.insert_slice %[[RESULT_TILE]]#1 into %[[ARG4]]
-// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
-// CHECK: scf.yield %[[UPDATE0]], %[[UPDATE1]]
-// CHECK: scf.yield %[[INNER]]#0, %[[INNER]]#1
-// CHECK: return %[[OUTER]]#0, %[[OUTER]]#1
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (10, -d0 + 128)>
+// CHECK-LABEL: func.func @multi_result(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>)
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[C128:.+]] = arith.constant 128 : index
+// CHECK-DAG: %[[C300:.+]] = arith.constant 300 : index
+// CHECK-DAG: %[[INIT0:.+]] = tensor.empty()
+// CHECK-DAG: %[[INIT1:.+]] = tensor.empty()
+// CHECK: %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]]
+// CHECK-SAME: iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
+// CHECK: %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])
+// CHECK: %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]]
+// CHECK-SAME: iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]])
+// CHECK-DAG: %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME: [%[[IV0]], 0, %[[IV1]]] [%[[TS_Y]], 200, 20] [1, 1, 1]
+// CHECK-DAG: %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ARG3]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
+// CHECK-DAG: %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ARG4]]
+// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
+// CHECK: %[[RESULT_TILE:.+]]:2 = linalg.generic
+// CHECK-SAME: ins(%[[ARG_TILE]] :
+// CHECK-SAME: outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
+// CHECK: %[[UPDATE0:.+]] = tensor.insert_slice %[[RESULT_TILE]]#0 into %[[ARG3]]
+// CHECK-SAME: [%[[IV0]], %[[IV1]], 0] [%[[TS_Y]], 20, 200] [1, 1, 1]
+// CHECK: %[[UPDATE1:.+]] = tensor.insert_slice %[[RESULT_TILE]]#1 into %[[ARG4]]
+// CHECK-SAME: [%[[IV1]], %[[IV0]], 0] [20, %[[TS_Y]], 200] [1, 1, 1]
+// CHECK: scf.yield %[[UPDATE0]], %[[UPDATE1]]
+// CHECK: scf.yield %[[INNER]]#0, %[[INNER]]#1
+// CHECK: return %[[OUTER]]#0, %[[OUTER]]#1
// -----
@@ -193,14 +193,9 @@ func.func @conv2D(%arg0 : tensor<?x?x?x?xf32>, %arg1 : tensor<?x?x?x?xf32>,
// -----
-// CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
-
-// CHECK-LABEL: @indexed_semantics
func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
// Check that we correctly amend "linalg.index" results.
- // CHECK: scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
- // CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
%0 = linalg.generic {
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d0, d1)>],
@@ -209,13 +204,8 @@ func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) ->
ins(%arg0: tensor<?x?xf32>)
outs(%arg1: tensor<?x?xf32>) {
^bb0(%arg2: f32, %arg3: f32):
- // CHECK: %[[INDEX0:.+]] = linalg.index 0
- // CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]])
%1 = linalg.index 0 : index
- // CHECK: %[[INDEX1:.+]] = linalg.index 1
- // CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
%2 = linalg.index 1 : index
- // CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
%3 = arith.addi %1, %2 : index
%4 = arith.index_cast %3 : index to i64
%5 = arith.uitofp %4 : i64 to f32
@@ -224,6 +214,15 @@ func.func @indexed_semantics(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) ->
} -> (tensor<?x?xf32>)
return %0 : tensor<?x?xf32>
}
+// CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: @indexed_semantics
+// CHECK: scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
+// CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}}
+// CHECK: %[[INDEX0:.+]] = linalg.index 0
+// CHECK: %[[INDEX0_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX0]], %[[I0]])
+// CHECK: %[[INDEX1:.+]] = linalg.index 1
+// CHECK: %[[INDEX1_AMENDED:.+]] = affine.apply #[[$MAP_ADD]](%[[INDEX1]], %[[I1]])
+// CHECK: arith.addi %[[INDEX0_AMENDED]], %[[INDEX1_AMENDED]]
// -----
@@ -276,14 +275,53 @@ func.func @interchange_matmul(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>,
// -----
+func.func @linalg_copy_matmul(%a: memref<?x?xf32>, %b: memref<?x?xf32>) {
+ linalg.copy {__internal_transform__ = "simple_copy_memref"}
+ ins(%a : memref<?x?xf32>) outs(%b : memref<?x?xf32>)
+ return
+}
// CHECK-LABEL: func @linalg_copy_matmul(
// CHECK: scf.for
// CHECK: scf.for
// CHECK: memref.subview
// CHECK: memref.subview
// CHECK: linalg.copy
-func.func @linalg_copy_matmul(%a: memref<?x?xf32>, %b: memref<?x?xf32>) {
- linalg.copy {__internal_transform__ = "simple_copy_memref"}
- ins(%a : memref<?x?xf32>) outs(%b : memref<?x?xf32>)
+
+// -----
+
+func.func @check_scalar_operation(%arg0 : tensor<f32>) -> tensor<f32> {
+ %init = tensor.empty() : tensor<f32>
+ %0 = linalg.generic {
+ indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+ iterator_types = []}
+ {__internal_transform__ = "scalar_op"}
+ ins(%arg0 : tensor<f32>) outs(%init : tensor<f32>){
+ ^bb0(%b0 : f32, %b1 : f32):
+ %1 = arith.mulf %b0, %b0 : f32
+ linalg.yield %1 : f32
+ } -> tensor<f32>
+ return %0 : tensor<f32>
+}
+// CHECK-LABEL: func @check_scalar_operation
+// CHECK-NOT: scf.for
+// CHECK: linalg.generic
+// CHECK-SAME: __internal_transform__ = "scalar_op"
+
+// -----
+
+func.func @check_scalar_memref_operation(%arg0 : memref<f32>, %arg1 : memref<f32>){
+ linalg.generic {
+ indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+ iterator_types = []}
+ {__internal_transform__ = "scalar_op"}
+ ins(%arg0 : memref<f32>) outs(%arg1 : memref<f32>){
+ ^bb0(%b0 : f32, %b1 : f32):
+ %1 = arith.mulf %b0, %b0 : f32
+ linalg.yield %1 : f32
+ }
return
}
+// CHECK-LABEL: func @check_scalar_memref_operation
+// CHECK-NOT: scf.for
+// CHECK: linalg.generic
+// CHECK-SAME: __internal_transform__ = "scalar_op"
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index e5d7dc54409e4473..112ad6cbde858943 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -579,6 +579,8 @@ void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
addPatternForTiling(context, patterns, "pad_outer_tiling", {2, 3});
// 10. Tiling M and N dims of `linalg.copy` on memrefs.
addPatternForTiling(context, patterns, "simple_copy_memref", {10, 20});
+ // 11. Tiling scalar operations.
+ addPatternForTiling(context, patterns, "scalar_op", {});
return;
}
if (testTilingForAll) {
More information about the Mlir-commits
mailing list