[Mlir-commits] [mlir] 97f9198 - [mlir][TilingInterface] NFC Refactor of tile and fuse using `TilingInterface`.
Mahesh Ravishankar
llvmlistbot at llvm.org
Wed Sep 28 13:26:00 PDT 2022
Author: Mahesh Ravishankar
Date: 2022-09-28T20:25:33Z
New Revision: 97f919820b075fe49393405bf0ea990cf820ffeb
URL: https://github.com/llvm/llvm-project/commit/97f919820b075fe49393405bf0ea990cf820ffeb
DIFF: https://github.com/llvm/llvm-project/commit/97f919820b075fe49393405bf0ea990cf820ffeb.diff
LOG: [mlir][TilingInterface] NFC Refactor of tile and fuse using `TilingInterface`.
This patch refactors the tiling and tile + fuse implementation using
`TilingInterface`. Primarily, it exposes the functionality as plain
utility functions instead of as patterns, so that it can be called from
a pattern (as the test does today) or from within the transform
dialect (in the future). This is a step towards deprecating similar
methods in the Linalg dialect.
- The utility methods do not erase or replace the root operations; the
  caller is responsible for that.
- The returned result provides the values to use as replacements.
Differential Revision: https://reviews.llvm.org/D134144
Added:
Modified:
mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index a56b6b44e4657..1c374d62425d1 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -60,38 +60,48 @@ struct SCFTilingOptions {
}
};
+/// Transformation information returned after tiling.
struct SCFTilingResult {
+ /// The tiled operation generated.
Operation *tiledOp;
+ /// The `scf.for` operations that iterate over the tiles.
SmallVector<scf::ForOp> loops;
+ /// Values to use as replacements for the untiled op. Its size is the same as
+ /// the number of results of the untiled op.
+ SmallVector<Value> replacements;
};
-/// Pattern to tile an op that implements the `TilingInterface` using
+/// Method to tile an op that implements the `TilingInterface` using
/// `scf.for` for iterating over the tiles.
-struct TileUsingSCFForOp : public OpInterfaceRewritePattern<TilingInterface> {
- /// Construct a generic pattern applied to all TilingInterface ops.
- TileUsingSCFForOp(MLIRContext *context, SCFTilingOptions options,
- PatternBenefit benefit = 1);
-
- /// Construct a generic pattern applied to `opName`.
- TileUsingSCFForOp(StringRef opName, MLIRContext *context,
- SCFTilingOptions options, PatternBenefit benefit = 1);
-
- /// `matchAndRewrite` implementation that returns the significant transformed
- /// pieces of IR.
- FailureOr<SCFTilingResult>
- returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
-
- LogicalResult matchAndRewrite(TilingInterface op,
- PatternRewriter &rewriter) const override {
- return returningMatchAndRewrite(op, rewriter);
+FailureOr<SCFTilingResult> tileUsingSCFForOp(RewriterBase &rewriter,
+ TilingInterface op,
+ SCFTilingOptions options);
+
+/// Options used to control tile + fuse.
+struct SCFTileAndFuseOptions {
+ /// The tiling options used to control the tiling of the consumer.
+ SCFTilingOptions tilingOptions;
+ SCFTileAndFuseOptions &setTilingOptions(SCFTilingOptions options) {
+ tilingOptions = options;
+ return *this;
}
+};
-private:
- /// Options to control tiling;
- SCFTilingOptions options;
+/// Transformation information returned after tile and fuse.
+struct SCFTileAndFuseResult {
+ /// List of untiled operations that were fused with the tiled consumer.
+ llvm::SetVector<Operation *> fusedProducers;
+ /// List of tiled and fused operations generated. The first one in this list
+ /// is guaranteed to be the tiled operation generated during tiling of the
+ /// consumer operation.
+ llvm::SetVector<Operation *> tiledAndFusedOps;
+ /// The `scf.for` operations that iterate over the tiles.
+ SmallVector<scf::ForOp> loops;
+ /// Map from results of the untiled operations (e.g. the consumer) to the
+ /// values to use as their replacements.
+ llvm::DenseMap<Value, Value> replacements;
};
-/// Pattern to tile and fuse a sequence of operations, by tiling the consumer
+/// Method to tile and fuse a sequence of operations, by tiling the consumer
/// and fusing its producers. Note that this assumes that it is valid to
/// tile+fuse the producer into the innermost tiled loop. Its up to the caller
/// to ensure that the tile sizes provided make this fusion valid.
@@ -99,64 +109,32 @@ struct TileUsingSCFForOp : public OpInterfaceRewritePattern<TilingInterface> {
/// For example, for the following sequence
///
/// ```mlir
-/// %0 = linalg.fill ...
-/// %1 = linalg.matmul ... outs(%0 : ...) ...
+/// %0 = ...
+/// %1 = linalg.fill ... outs(%0 : ... )
+/// %2 = linalg.matmul ... outs(%1 : ...) ...
/// ```
///
/// it is legal to fuse the fill with the matmul only if the matmul is tiled
/// along the parallel dimensions and not the reduction dimension, i.e. the tile
-/// size for the reduction dimension should be 0.
-struct SCFTileAndFuseResult {
- SmallVector<Operation *> tiledAndFusedOps;
- SmallVector<scf::ForOp> loops;
-};
-struct TileConsumerAndFuseProducersUsingSCFForOp
- : public OpInterfaceRewritePattern<TilingInterface> {
-
- /// Construct a generic pattern applied to all TilingInterface ops.
- TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
- SCFTilingOptions options,
- PatternBenefit benefit = 1);
-
- /// Construct a generic pattern applied to `opName`.
- TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
- MLIRContext *context,
- SCFTilingOptions options,
- PatternBenefit benefit = 1);
-
- /// `matchAndRewrite` implementation that returns the significant transformed
- /// pieces of IR.
- FailureOr<SCFTileAndFuseResult>
- returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
-
- LogicalResult matchAndRewrite(TilingInterface op,
- PatternRewriter &rewriter) const override {
- return returningMatchAndRewrite(op, rewriter);
- }
-
-private:
- /// This pattern uses the tiling pattern. Instead of using inheritance, use
- /// the patterns as private object that is instantiated at the same time as
- /// this pattern.
- TileUsingSCFForOp tilingPattern;
-};
-
-/// Pattern to lower operations that implement the `TilingInterface` to
-/// loops/scalar IR using `scf.for`.
-struct LowerToLoopsUsingSCFForOp
- : public OpInterfaceRewritePattern<TilingInterface> {
- using OpInterfaceRewritePattern<TilingInterface>::OpInterfaceRewritePattern;
-
- /// `matchAndRewrite` implementation that returns the significant transformed
- /// pieces of IR.
- FailureOr<SmallVector<scf::ForOp>>
- returningMatchAndRewrite(TilingInterface op, PatternRewriter &rewriter) const;
-
- LogicalResult matchAndRewrite(TilingInterface op,
- PatternRewriter &rewriter) const override {
- return returningMatchAndRewrite(op, rewriter);
- }
-};
+/// size for the reduction dimension should be 0. The resulting fused
+/// transformation is
+///
+/// ```mlir
+/// %1 = scf.for ... iter_args(%arg0 = %0) {
+///   %2 = tensor.extract_slice %arg0
+///   %3 = linalg.fill ... outs(%2 : ... )
+///   %4 = linalg.matmul ... outs(%3 : ...)
+///   %5 = tensor.insert_slice %4 into %arg0
+///   scf.yield %5
+/// }
+/// ```
+FailureOr<SCFTileAndFuseResult>
+tileConsumerAndFuseProducerGreedilyUsingSCFForOp(RewriterBase &rewriter,
+ TilingInterface consumer,
+ SCFTileAndFuseOptions options);
+
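+/// Method to lower an `op` that implements the `TilingInterface` to
+/// loops/scalars. Fails for operations that have results. The original `op`
+/// is not erased; that is left to the caller.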
+FailureOr<SmallVector<scf::ForOp>>
+lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, TilingInterface op);
} // namespace scf
} // namespace mlir
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 0c6ba3d195da5..5342be5cfdc65 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -87,7 +87,7 @@ static bool isPermutation(ArrayRef<unsigned> interchange) {
}
//===----------------------------------------------------------------------===//
-// TileUsingSCFForOp pattern implementation.
+// tileUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
// Check if `stride` evenly divides the trip count `size - offset`.
@@ -167,7 +167,65 @@ generateTileLoopNest(OpBuilder &builder, Location loc,
return loops;
}
-/// If the tiled operation is in destination passing style, update the
+/// For each value to be yielded (`yieldedValues`) from within a loop nest
+/// `loops`, construct the destructive update pattern that inserts the yielded
+/// value into the destination tensor provided by the corresponding entry of
+/// `initValues`, at the offsets and sizes given by the corresponding entries of
+/// `tileOffsetsList` and `tileSizesList`. For example,
+///
+/// ```mlir
+/// scf.for %iv0 = ... {
+/// %0 = tiled_op
+/// }
+/// ```
+///
+/// is transformed to
+///
+/// ```mlir
+/// scf.for %iv0 = ... iter_args(%arg = %0) {
+/// %1 = tensor.extract_slice %arg
+/// %2 = tiled_op
+/// %3 = tensor.insert_slice %2 into %arg
+/// scf.yield %3
+/// }
+/// ```
+/// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`.
+static FailureOr<SmallVector<Value>>
+yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
+ ValueRange yieldedValues,
+ ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
+ ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
+ MutableArrayRef<scf::ForOp> loops) {
+ NewYieldValueFn yieldValueFn =
+ [&](OpBuilder &b, Location loc,
+ ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
+ SmallVector<Value> inserts;
+ for (auto yieldedValue : llvm::enumerate(yieldedValues)) {
+ ArrayRef<OpFoldResult> tileOffsets =
+ tileOffsetsList[yieldedValue.index()];
+ ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()];
+ SmallVector<OpFoldResult> tileStrides(tileOffsets.size(),
+ b.getIndexAttr(1));
+ Value insert = b.create<tensor::InsertSliceOp>(
+ loc, yieldedValue.value(), newBBArgs[yieldedValue.index()],
+ tileOffsets, tileSizes, tileStrides);
+ inserts.push_back(insert);
+ }
+ return inserts;
+ };
+
+ SmallVector<scf::ForOp> newLoops =
+ replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
+ /*replaceIterOperandsUsesInLoop =*/false);
+ for (const auto &loop : llvm::enumerate(loops)) {
+ rewriter.eraseOp(loop.value());
+ loops[loop.index()] = newLoops[loop.index()];
+ }
+ return llvm::to_vector(llvm::map_range(
+ loops.front().getResults().take_back(yieldedValues.size()),
+ [](OpResult r) -> Value { return r; }));
+}
+
+/// If the tiled operation is in destination-passing style, update the
/// slice of the destination used (which refers to the untiled destination)
/// to use the corresponding region argument of the innermost loop.
///
@@ -191,8 +249,6 @@ generateTileLoopNest(OpBuilder &builder, Location loc,
/// scf.yield %3
/// }
/// ```
-/// TODO: This can be made much cleaner when `DestinationStyleOp` interface is
-/// available generally.
static void
updateDestinationOperandsForTiledOp(OpBuilder &builder,
ValueRange tiledOpDestinationValues,
@@ -205,22 +261,11 @@ updateDestinationOperandsForTiledOp(OpBuilder &builder,
}
}
-scf::TileUsingSCFForOp::TileUsingSCFForOp(MLIRContext *context,
- scf::SCFTilingOptions options,
- PatternBenefit benefit)
- : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
- options(std::move(options)) {}
-
-scf::TileUsingSCFForOp::TileUsingSCFForOp(StringRef opName,
- MLIRContext *context,
- scf::SCFTilingOptions options,
- PatternBenefit benefit)
- : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
- options(std::move(options)) {}
-
+/// Implementation of tiling transformation of `op` that implements the
+/// `TilingInterface` using `scf.for` to iterate over the tiles.
FailureOr<scf::SCFTilingResult>
-scf::TileUsingSCFForOp::returningMatchAndRewrite(
- TilingInterface op, PatternRewriter &rewriter) const {
+mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
+ scf::SCFTilingOptions options) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(op);
@@ -282,132 +327,86 @@ scf::TileUsingSCFForOp::returningMatchAndRewrite(
offsets = applyPermutationToVector(offsets, inversePermutation);
sizes = applyPermutationToVector(sizes, inversePermutation);
}
+ }
- LLVM_DEBUG({
- if (!tilingResult.loops.empty()) {
- llvm::errs() << "LoopNest shell :\n";
- tilingResult.loops.front().dump();
- llvm::errs() << "\n";
- }
- });
-
- // 4. Generate the tiled implementation within the inner most loop.
- if (!tilingResult.loops.empty())
- rewriter.setInsertionPoint(
- tilingResult.loops.back().getBody()->getTerminator());
- SmallVector<Operation *> tiledImplementation =
- op.getTiledImplementation(rewriter, offsets, sizes);
- if (tiledImplementation.size() != 1) {
- return rewriter.notifyMatchFailure(
- op, "expected tiled implementation to return a single op");
+ LLVM_DEBUG({
+ if (!tilingResult.loops.empty()) {
+ llvm::dbgs() << "LoopNest shell :\n";
+ tilingResult.loops.front().dump();
+ llvm::dbgs() << "\n";
}
- tilingResult.tiledOp = tiledImplementation[0];
-
- LLVM_DEBUG({
- if (!tilingResult.loops.empty()) {
- llvm::errs() << "After tiled implementation :\n";
- tilingResult.loops.front().dump();
- llvm::errs() << "\n";
- }
- });
+ });
+
+ // 4. Generate the tiled implementation within the inner most loop.
+ if (!tilingResult.loops.empty())
+ rewriter.setInsertionPoint(
+ tilingResult.loops.back().getBody()->getTerminator());
+ SmallVector<Operation *> tiledImplementation =
+ op.getTiledImplementation(rewriter, offsets, sizes);
+ if (tiledImplementation.size() != 1) {
+ return rewriter.notifyMatchFailure(
+ op, "expected tiled implementation to return a single op");
}
-
+ tilingResult.tiledOp = tiledImplementation[0];
if (op->getNumResults() == 0) {
- rewriter.eraseOp(op);
+ // nothing more to do.
return tilingResult;
}
- // 5. If the original operations has results, modify the loop nest to yield
- // the replacement values.
+ // If loops are empty, the tiled op is used as the replacement for the untiled
+ // op.
if (tilingResult.loops.empty()) {
- // 5a. If there were no loops, the tiled implementation results are the
- // replacements.
- rewriter.replaceOp(op, tilingResult.tiledOp->getResults());
+ tilingResult.replacements = llvm::to_vector(
+ llvm::map_range(tiledImplementation[0]->getResults(),
+ [](OpResult result) -> Value { return result; }));
return tilingResult;
}
- // 6. Yield the results of the tiled operation from the loop nest as
- // replacements for the original untiled ops.
- if (tilingResult.tiledOp->getNumResults() != op->getNumResults()) {
- return rewriter.notifyMatchFailure(
- tilingResult.tiledOp,
- "expected tiled op to have as many results as the untiled operation");
+ // 5. Yield all the results of the tiled operation. The surrounding loop
+ // nest is modified to insert a destructive update pattern, so that the loop
+ // nest yields the values used to replace the untiled op.
+ unsigned numResults = op->getNumResults();
+ SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults),
+ resultSizesList(numResults);
+ for (auto result : llvm::enumerate(op->getResults())) {
+ if (failed(op.getResultTilePosition(rewriter, result.index(), offsets,
+ sizes,
+ resultOffsetsList[result.index()],
+ resultSizesList[result.index()]))) {
+ return rewriter.notifyMatchFailure(
+ op, "failed to get slice of result produced");
+ }
}
- // `scf.for` with tensor semantics requires the loop nest to yield the
- // replacement values using destructive updates. Use the `TilingInterface`
- // to get the position of the result tiles and use that to generate the
- // destructive update pattern, i.e.,
- //
- // ```mlir
- // scf.for %iv0 = ... {
- // %0 = tiled_op
- // }
- // ```
- //
- // is transformed to
- //
- // ```mlir
- // %result = scf.for %iv0 = ... iter_args(%arg = %init) -> .. {
- // %0 = tiled_op
- // %1 = tensor.insert_slice %0 into %arg[..] [..] [..]
- // scf.yield %1
- // }
- // ```
- NewYieldValueFn yieldValueFn =
- [&](OpBuilder &b, Location loc,
- ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
- SmallVector<Value> yieldedValues;
- Attribute one = b.getIndexAttr(1);
- for (auto resultNum : llvm::seq<unsigned>(0, op->getNumResults())) {
- SmallVector<OpFoldResult> resultTileOffsets, resultTileSizes;
- if (failed(op.getResultTilePosition(b, resultNum, offsets, sizes,
- resultTileOffsets,
- resultTileSizes))) {
- op.emitOpError("unable to get position of result ")
- << resultNum << " of the tiled implementation";
- return {};
- }
- SmallVector<OpFoldResult> resultTileStrides(resultTileOffsets.size(),
- one);
- Value yieldedValue = b.create<tensor::InsertSliceOp>(
- op->getLoc(), tilingResult.tiledOp->getResult(resultNum),
- newBBArgs[resultNum], resultTileOffsets, resultTileSizes,
- resultTileStrides);
- yieldedValues.push_back(yieldedValue);
- }
- return yieldedValues;
- };
- SmallVector<scf::ForOp> newLoops = replaceLoopNestWithNewYields(
- rewriter, tilingResult.loops, op.getDestinationOperands(rewriter),
- yieldValueFn, /*replaceIterOperandsUsesInLoops =*/false);
- for (const auto &loop : llvm::enumerate(tilingResult.loops)) {
- rewriter.eraseOp(loop.value());
- tilingResult.loops[loop.index()] = newLoops[loop.index()];
+ FailureOr<SmallVector<Value>> replacementOr =
+ yieldTiledValues(rewriter, op.getDestinationOperands(rewriter),
+ tilingResult.tiledOp->getResults(), resultOffsetsList,
+ resultSizesList, tilingResult.loops);
+ if (failed(replacementOr))
+ return rewriter.notifyMatchFailure(op, "failed to yield replacement");
+ if (auto tiledInterfaceOp = dyn_cast<TilingInterface>(tilingResult.tiledOp)) {
+ auto innerMostLoop = tilingResult.loops.back();
+ updateDestinationOperandsForTiledOp(
+ rewriter, tiledInterfaceOp.getDestinationOperands(rewriter),
+ innerMostLoop.getRegionIterArgs());
}
- rewriter.replaceOp(op, tilingResult.loops.front().getResults());
+
+ tilingResult.replacements = replacementOr.value();
+
+ LLVM_DEBUG({
+ if (!tilingResult.loops.empty()) {
+ llvm::dbgs() << "After tiled implementation :\n";
+ tilingResult.loops.front().dump();
+ llvm::dbgs() << "\n";
+ }
+ });
return tilingResult;
}
//===----------------------------------------------------------------------===//
-// TileConsumerAndFuseProducersUsingSCFForOp pattern implementation.
+// tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
-scf::TileConsumerAndFuseProducersUsingSCFForOp::
- TileConsumerAndFuseProducersUsingSCFForOp(MLIRContext *context,
- scf::SCFTilingOptions options,
- PatternBenefit benefit)
- : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
- tilingPattern(context, std::move(options)) {}
-
-scf::TileConsumerAndFuseProducersUsingSCFForOp::
- TileConsumerAndFuseProducersUsingSCFForOp(StringRef opName,
- MLIRContext *context,
- scf::SCFTilingOptions options,
- PatternBenefit benefit)
- : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
- tilingPattern(context, std::move(options)) {}
-
/// Return the untiled producer whose slice is used in a tiled consumer. The
/// method traverses the tile loop nest (`loops`) if needed, and returns the
/// `iter_args` of the outer most that is encountered. Traversing the iter_args
@@ -430,28 +429,41 @@ getUntiledProducerFromSliceSource(OpOperand *source,
return {source->get().dyn_cast<OpResult>(), destinationIterArg};
}
+/// Implementation of tiling a consumer and greedily fusing its producers.
FailureOr<scf::SCFTileAndFuseResult>
-scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
- TilingInterface op, PatternRewriter &rewriter) const {
+mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
+ RewriterBase &rewriter, TilingInterface consumer,
+ scf::SCFTileAndFuseOptions options) {
// This transformation is only valid for ops that return values (i.e. not
// valid to use with operations that have memref operands).
- if (!op->getNumResults()) {
+ if (!consumer->getNumResults()) {
return rewriter.notifyMatchFailure(
- op, "invalid pattern for op with no results");
+ consumer, "invalid pattern for op with no results");
}
// 1. First tile the consumer.
- SCFTileAndFuseResult tileAndFuseResult;
+ scf::SCFTileAndFuseResult tileAndFuseResult;
+ llvm::SmallDenseMap<Value, unsigned> yieldedValueToResultNumber;
{
- FailureOr<SCFTilingResult> tilingResult =
- tilingPattern.returningMatchAndRewrite(op, rewriter);
- if (failed(tilingResult)) {
- return failure();
- }
- tileAndFuseResult.tiledAndFusedOps.push_back(tilingResult->tiledOp);
+ FailureOr<scf::SCFTilingResult> tilingResult =
+ tileUsingSCFForOp(rewriter, consumer, options.tilingOptions);
+ if (failed(tilingResult))
+ return rewriter.notifyMatchFailure(consumer, "failed to tile consumer");
+ tileAndFuseResult.tiledAndFusedOps.insert(tilingResult->tiledOp);
tileAndFuseResult.loops = std::move(tilingResult->loops);
+ for (auto result : llvm::enumerate(
+ llvm::zip(consumer->getResults(), tilingResult->replacements))) {
+ tileAndFuseResult.replacements[std::get<0>(result.value())] =
+ std::get<1>(result.value());
+ yieldedValueToResultNumber[tilingResult->tiledOp->getResult(
+ result.index())] = result.index();
+ }
}
+ // If there are no loops generated, fusion is immaterial.
+ if (tileAndFuseResult.loops.empty())
+ return tileAndFuseResult;
+
// 2. Typically, the operands of the tiled operation are slices of the
// operands of the untiled operation. These are expressed in IR using
// `tensor.extract_slice` operations with source being the operands of the
@@ -495,7 +507,7 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
// values produced by operations that implement the `TilingInterface`.
// Add these operations to the worklist.
Operation *fusedProducer = fusedProducerValue->getDefiningOp();
- tileAndFuseResult.tiledAndFusedOps.push_back(fusedProducer);
+ tileAndFuseResult.tiledAndFusedOps.insert(fusedProducer);
addCandidateSlices(fusedProducer, candidates);
// 2e. If the slice is for a destination operand, for example,
@@ -577,20 +589,19 @@ scf::TileConsumerAndFuseProducersUsingSCFForOp::returningMatchAndRewrite(
}
//===----------------------------------------------------------------------===//
-// LowerToLoopsUsingSCFForOp
+// lowerToLoopsUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
FailureOr<SmallVector<scf::ForOp>>
-scf::LowerToLoopsUsingSCFForOp::returningMatchAndRewrite(
- TilingInterface op, PatternRewriter &rewriter) const {
- SmallVector<Range> domain = op.getIterationDomain(rewriter);
-
+mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter,
+ TilingInterface op) {
// TODO: Handle cases where the op has results if needed.
if (op->getNumResults() > 0) {
return rewriter.notifyMatchFailure(
op, "unable to lower to loops operations with return values");
}
+ SmallVector<Range> domain = op.getIterationDomain(rewriter);
SmallVector<Value> ivs;
SmallVector<scf::ForOp> loops;
Location loc = op.getLoc();
@@ -610,6 +621,5 @@ scf::LowerToLoopsUsingSCFForOp::returningMatchAndRewrite(
if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) {
return failure();
}
- rewriter.eraseOp(op);
return loops;
}
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
index b20cd640c3881..edb7ba3729173 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -36,38 +36,46 @@ namespace {
/// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using
/// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while
/// using a `filter` to avoid recursive application.
-struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
- TestTileUsingSCFForOpWithFilter(MLIRContext *context,
- scf::SCFTilingOptions options,
- linalg::LinalgTransformationFilter filter =
- linalg::LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : scf::TileUsingSCFForOp(context, std::move(options), benefit),
- filter(std::move(filter)) {}
+struct TestTileUsingSCFForOp
+ : public OpInterfaceRewritePattern<TilingInterface> {
+ TestTileUsingSCFForOp(MLIRContext *context, scf::SCFTilingOptions options,
+ linalg::LinalgTransformationFilter filter =
+ linalg::LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ options(std::move(options)), filter(std::move(filter)) {}
/// Construct a generic pattern applied to `opName`.
- TestTileUsingSCFForOpWithFilter(StringRef opName, MLIRContext *context,
- scf::SCFTilingOptions options,
- linalg::LinalgTransformationFilter filter =
- linalg::LinalgTransformationFilter(),
- PatternBenefit benefit = 1)
- : scf::TileUsingSCFForOp(context, std::move(options), benefit),
- filter(std::move(filter)) {}
+ TestTileUsingSCFForOp(StringRef opName, MLIRContext *context,
+ scf::SCFTilingOptions options,
+ linalg::LinalgTransformationFilter filter =
+ linalg::LinalgTransformationFilter(),
+ PatternBenefit benefit = 1)
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ options(std::move(options)), filter(std::move(filter)) {}
LogicalResult matchAndRewrite(TilingInterface op,
PatternRewriter &rewriter) const override {
if (failed(filter.checkAndNotify(rewriter, op)))
return failure();
- auto tilingResult = returningMatchAndRewrite(op, rewriter);
- if (failed(tilingResult)) {
- return failure();
+ FailureOr<scf::SCFTilingResult> tilingResult =
+ scf::tileUsingSCFForOp(rewriter, op, options);
+ if (failed(tilingResult))
+ return rewriter.notifyMatchFailure(op, "failed to tile operation");
+
+ if (op->getNumResults()) {
+ rewriter.replaceOp(op, tilingResult->replacements);
+ } else {
+ rewriter.eraseOp(op);
}
+
filter.replaceLinalgTransformationFilter(rewriter, tilingResult->tiledOp);
return success();
}
private:
+ scf::SCFTilingOptions options;
linalg::LinalgTransformationFilter filter;
};
@@ -75,45 +83,74 @@ struct TestTileUsingSCFForOpWithFilter : public scf::TileUsingSCFForOp {
/// (that tiles and fuses operations using the `TilingInterface` with `scf.for`
/// ops for iterating over the tiles) while using a `filter` to avoid recursive
/// application.
-struct TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter
- : public scf::TileConsumerAndFuseProducersUsingSCFForOp {
- TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter(
- MLIRContext *context, scf::SCFTilingOptions options,
+struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp
+ : public OpInterfaceRewritePattern<TilingInterface> {
+ TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp(
+ MLIRContext *context, scf::SCFTileAndFuseOptions options,
linalg::LinalgTransformationFilter filter =
linalg::LinalgTransformationFilter(),
PatternBenefit benefit = 1)
- : scf::TileConsumerAndFuseProducersUsingSCFForOp(
- context, std::move(options), benefit),
- filter(std::move(filter)) {}
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ options(std::move(options)), filter(std::move(filter)) {}
/// Construct a generic pattern applied to `opName`.
- TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter(
- StringRef opName, MLIRContext *context, scf::SCFTilingOptions options,
+ TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp(
+ StringRef opName, MLIRContext *context,
+ scf::SCFTileAndFuseOptions options,
linalg::LinalgTransformationFilter filter =
linalg::LinalgTransformationFilter(),
PatternBenefit benefit = 1)
- : scf::TileConsumerAndFuseProducersUsingSCFForOp(
- context, std::move(options), benefit),
- filter(std::move(filter)) {}
+ : OpInterfaceRewritePattern<TilingInterface>(context, benefit),
+ options(std::move(options)), filter(std::move(filter)) {}
LogicalResult matchAndRewrite(TilingInterface op,
PatternRewriter &rewriter) const override {
if (failed(filter.checkAndNotify(rewriter, op)))
return failure();
- auto tileAndFuseResult = returningMatchAndRewrite(op, rewriter);
+ FailureOr<scf::SCFTileAndFuseResult> tileAndFuseResult =
+ scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(rewriter, op,
+ options);
if (failed(tileAndFuseResult)) {
return failure();
}
+ // Replace the tiled op with replacements.
+ SmallVector<Value> replacements(op->getNumResults());
+ for (auto result : llvm::enumerate(op->getResults())) {
+ replacements[result.index()] =
+ tileAndFuseResult->replacements.lookup(result.value());
+ }
+ rewriter.replaceOp(op, replacements);
+
filter.replaceLinalgTransformationFilter(
rewriter, tileAndFuseResult->tiledAndFusedOps.front());
return success();
}
private:
+ scf::SCFTileAndFuseOptions options;
linalg::LinalgTransformationFilter filter;
};
+/// Pattern to lower operations that implement the `TilingInterface` to
+/// loops/scalar IR using `scf.for`.
+struct LowerToLoopsUsingSCFForOp
+ : public OpInterfaceRewritePattern<TilingInterface> {
+ using OpInterfaceRewritePattern<TilingInterface>::OpInterfaceRewritePattern;
+
+ /// Lowers `op` to loops using `scf::lowerToLoopsUsingSCFForOp` and erases
+ /// the original op on success.
+ LogicalResult matchAndRewrite(TilingInterface op,
+ PatternRewriter &rewriter) const override {
+ FailureOr<SmallVector<scf::ForOp>> loops =
+ scf::lowerToLoopsUsingSCFForOp(rewriter, op);
+ if (failed(loops))
+ return rewriter.notifyMatchFailure(op, "failed to lower to loops");
+ rewriter.eraseOp(op);
+ return loops;
+ }
+};
+
/// Test pass for testing the use of `TilingInterface`.
struct TestTilingInterfacePass
: public PassWrapper<TestTilingInterfacePass, OperationPass<func::FuncOp>> {
@@ -158,72 +195,78 @@ struct TestTilingInterfacePass
};
} // namespace
-template <class Pattern>
-static void
-addPatternForTiling(MLIRContext *context, RewritePatternSet &patterns,
- StringRef filterName, ArrayRef<int64_t> tileSizes,
- ArrayRef<unsigned> interchange = {}) {
+static void addPatternForTiling(MLIRContext *context,
+ RewritePatternSet &patterns,
+ StringRef filterName,
+ ArrayRef<int64_t> tileSizes,
+ ArrayRef<unsigned> interchange = {}) {
scf::SCFTilingOptions tilingOptions;
tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
linalg::LinalgTransformationFilter filter(
StringAttr::get(context, filterName), StringAttr::get(context, "tiled"));
- patterns.add<Pattern>(context, tilingOptions, filter);
+ patterns.add<TestTileUsingSCFForOp>(context, tilingOptions, filter);
+}
+
+static void addPatternForTileAndFuse(MLIRContext *context,
+ RewritePatternSet &patterns,
+ StringRef filterName,
+ ArrayRef<int64_t> tileSizes,
+ ArrayRef<unsigned> interchange = {}) {
+ scf::SCFTileAndFuseOptions tileAndFuseOptions;
+ tileAndFuseOptions.tilingOptions.setTileSizes(tileSizes).setInterchange(
+ interchange);
+ linalg::LinalgTransformationFilter filter(
+ StringAttr::get(context, filterName), StringAttr::get(context, "tiled"));
+ patterns.add<TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp>(
+ context, tileAndFuseOptions, filter);
}
void TestTilingInterfacePass::addTestPatterns(MLIRContext *context,
RewritePatternSet &patterns) {
if (testTiling) {
// 1. Tiling M and N dims of `linalg.matmul` on tensors.
- addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, patterns, "simple_gemm", {10, 20});
+ addPatternForTiling(context, patterns, "simple_gemm", {10, 20});
// 2. Tiling M, N and K of `linalg.matmul` on buffers.
- addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, patterns, "simple_gemm_memref", {10, 20, 30});
+ addPatternForTiling(context, patterns, "simple_gemm_memref", {10, 20, 30});
// 3. Tiling 3D parallel generic op which implements a transpose
- addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, patterns, "parallel_generic_transpose", {10, 0, 20});
+ addPatternForTiling(context, patterns, "parallel_generic_transpose",
+ {10, 0, 20});
// 4. Tiling 2D conv op.
- addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, patterns, "simple_conv", {0, 0, 0, 0, 10, 20, 30});
+ addPatternForTiling(context, patterns, "simple_conv",
+ {0, 0, 0, 0, 10, 20, 30});
// 5. Tiling a simple op with `linalg.index` inside.
- addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, patterns, "indexed_semantics", {10, 20});
+ addPatternForTiling(context, patterns, "indexed_semantics", {10, 20});
// 6. Tiling + interchange of an operation
- addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, patterns, "gemm_interchange", {10, 20, 30}, {1, 2, 0});
+ addPatternForTiling(context, patterns, "gemm_interchange", {10, 20, 30},
+ {1, 2, 0});
// 7. Tiling for 2D pad tensor operations.
- addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, patterns, "pad_2dtiling", {2, 3});
+ addPatternForTiling(context, patterns, "pad_2dtiling", {2, 3});
// 8. Tiling inner dimension of 2d pad tensor operations.
- addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, patterns, "pad_inner_tiling", {0, 3});
+ addPatternForTiling(context, patterns, "pad_inner_tiling", {0, 3});
// 9. Tiling inner dimension of 2d pad tensor operations.
- addPatternForTiling<TestTileUsingSCFForOpWithFilter>(
- context, patterns, "pad_outer_tiling", {2, 3});
+ addPatternForTiling(context, patterns, "pad_outer_tiling", {2, 3});
return;
}
if (testTileConsumerAndFuseProducer) {
- // 1. Tile and fuse of gemm with bias-add operation.
- addPatternForTiling<
- TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
- context, patterns, "fusion", {10, 20});
- addPatternForTiling<
- TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
- context, patterns, "gemm_fusion", {10});
- addPatternForTiling<
- TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
- context, patterns, "gemm_interchange_fusion", {10, 20}, {1, 0});
- addPatternForTiling<
- TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
- context, patterns, "gemm_plus_gemm_fusion", {10, 20});
- addPatternForTiling<
- TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
- context, patterns, "gemm_sequence_fusion", {10});
+ // 1. Tile and fuse of gemm with fill producer and bias-add consumer.
+ addPatternForTileAndFuse(context, patterns, "fusion", {10, 20});
+ // 2. Tile and fuse sequence of GEMMs, by fusing only along M.
+ addPatternForTileAndFuse(context, patterns, "gemm_fusion", {10});
+ // 3. Tile and fuse gemm with consumer + interchange of tiled loops.
+ addPatternForTileAndFuse(context, patterns, "gemm_interchange_fusion",
+ {10, 20}, {1, 0});
+ // 4. Tile and fuse matmul + transpose(matmul). Will introduce redundant
+ // computations.
+ addPatternForTileAndFuse(context, patterns, "gemm_plus_gemm_fusion",
+ {10, 20});
+ // 5. Tile and fuse a sequence of GEMMs by tiling and fusing only along M
+ // dimension.
+ addPatternForTileAndFuse(context, patterns, "gemm_sequence_fusion", {10});
return;
}
if (testLoweringToScalar) {
- patterns.add<scf::LowerToLoopsUsingSCFForOp>(context);
+ patterns.add<LowerToLoopsUsingSCFForOp>(context);
}
}
More information about the Mlir-commits
mailing list