[Mlir-commits] [mlir] e479aec - Revert "[mlir][scf][Transform] Refactor transform.fuse_into_containing_op so it is iterative and supports output fusion."
Nicolas Vasilache
llvmlistbot at llvm.org
Wed Sep 14 08:52:08 PDT 2022
Author: Nicolas Vasilache
Date: 2022-09-14T08:51:30-07:00
New Revision: e479aecd56d20bea409c507dd237c3f37a766702
URL: https://github.com/llvm/llvm-project/commit/e479aecd56d20bea409c507dd237c3f37a766702
DIFF: https://github.com/llvm/llvm-project/commit/e479aecd56d20bea409c507dd237c3f37a766702.diff
LOG: Revert "[mlir][scf][Transform] Refactor transform.fuse_into_containing_op so it is iterative and supports output fusion."
This reverts commit 54a5f606281d05203dca1d81d135e691b10bc513 which is a WIP that was pushed by mistake.
Added:
Modified:
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 49328a6cb708..29b13e27de7e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -17,12 +17,9 @@
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
-#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-#include "llvm/ADT/STLExtras.h"
-#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringSet.h"
using namespace mlir;
@@ -229,168 +226,78 @@ LogicalResult transform::FuseOp::verify() {
// FuseIntoContainingOp
//===----------------------------------------------------------------------===//
-/// Find the first "extract" user of `producerOp` and tile it right before its
-/// use. The tiled op is now fused under the `containingOp`.
-/// Return this fused op on success or nullptr if anything fails.
-static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
- Operation *containingOp,
- RewriterBase &rewriter) {
+static FailureOr<SmallVector<Operation *>> tileAndFuse(Operation *producerOp,
+ Operation *containingOp,
+ RewriterBase &rewriter) {
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer)
- return nullptr;
+ return failure();
// Search the producer slices accessed within the containing operation.
- // TODO: Generalize to more extract/insert/parallel_insert triples.
- // Maybe evolve into an interface.
- auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
+ // TODO: Generalize to more extract/insert/parallel_insert triples. Maybe
+ // evolve into an interface.
+ SmallVector<tensor::ExtractSliceOp> sliceOps;
+ for (Operation *user : tileableProducer->getUsers()) {
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
- return sliceOp && containingOp->isProperAncestor(sliceOp);
- });
-
- // Check for a non-empty fusion opportunity.
- if (it == tileableProducer->getUsers().end())
- return nullptr;
- auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
-
- // Try to fuse the producer in-place.
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(sliceOpToTile);
-
- // Tile the producer.
- FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
- rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
- sliceOpToTile.getMixedSizes());
- if (failed(tiledProducer))
- return nullptr;
-
- // Replace the extract op.
- Operation *fusedOp = tiledProducer->getDefiningOp();
- rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
- return fusedOp;
-}
-
-/// Find the first "extract" user of `producerOp` and tile it right before its
-/// use. The tiled op is now fused under the `containingOp`.
-/// Return this fused op on success or nullptr if anything fails.
-static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
- Operation *producerOp, Operation *containingOp, RewriterBase &rewriter) {
-
- auto foreachThreadOp = dyn_cast<scf::ForeachThreadOp>(containingOp);
- if (!foreachThreadOp)
- return nullptr;
-
- auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
- if (!tileableProducer)
- return nullptr;
-
- // Search the producer slices accessed within the containing
- // operation.
- // TODO: Generalize to more extract/insert/parallel_insert triples.
- // Maybe evolve into an interface.
- OpOperand *pUse;
- BlockArgument bbArg;
- tensor::ExtractSliceOp sliceOpToTile;
- // Only consider slices that may come from the containingOp args.
- for (OpOperand &use : tileableProducer->getUses()) {
- if (use.getOwner() != containingOp)
+ if (!sliceOp)
continue;
- pUse = &use;
- bbArg = foreachThreadOp.getTiedBlockArgument(&use);
- for (Operation *user : bbArg.getUsers()) {
- auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
- if (!sliceOp)
- continue;
- if (!containingOp->isAncestor(sliceOp))
- continue;
- sliceOpToTile = sliceOp;
- break;
- }
- if (sliceOpToTile)
- break;
+ if (!containingOp->isProperAncestor(sliceOp))
+ continue;
+ sliceOps.push_back(sliceOp);
}
// Check for a non-empty list of fusion opportunities.
- if (!sliceOpToTile || !pUse)
- return nullptr;
-
- // Ensure there is exactly one destination operand that we can replace the
- // ForeachThreadOp bbArg with.
- auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
- if (destinationOperands.size() != 1)
- return nullptr;
+ if (sliceOps.empty())
+ return failure();
// Try to fuse the producer in-place.
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(sliceOpToTile);
-
- // Replace the use in the tileableProducer before tiling, replace and then
- // tile.
- BlockAndValueMapping bvm;
- bvm.map(destinationOperands.front(), bbArg);
- auto tileableProducerClone =
- cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
- auto scopeGuard =
- llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
-
- // Tile the producer.
- FailureOr<Value> tiledProducer =
- tileableProducerClone.generateResultTileValue(
- rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
- sliceOpToTile.getMixedSizes());
- if (failed(tiledProducer))
- return nullptr;
-
- // Replace the extract op.
- Operation *fusedOp = tiledProducer->getDefiningOp();
- rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
+ SmallVector<Operation *> fusedOps;
+ for (tensor::ExtractSliceOp sliceOp : sliceOps) {
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(sliceOp);
- // Replace the use in containingOp.
- rewriter.startRootUpdate(fusedOp);
- containingOp->setOperand(pUse->getOperandNumber(),
- destinationOperands.front());
- rewriter.finalizeRootUpdate(fusedOp);
+ // Tile the producer.
+ FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
+ rewriter, /*resultNumber=*/0, sliceOp.getMixedOffsets(),
+ sliceOp.getMixedSizes());
+ if (failed(tiledProducer))
+ return failure();
+ fusedOps.push_back(tiledProducer->getDefiningOp());
+ }
- return fusedOp;
+ // Replace the extract op.
+ for (const auto &en : enumerate(sliceOps))
+ rewriter.replaceOp(en.value(), fusedOps[en.index()]->getResult(0));
+ return fusedOps;
}
-static Operation *cloneAndFuseFirstUse(Operation *producerOp,
- Operation *containingOp,
- RewriterBase &rewriter) {
+static FailureOr<SmallVector<Operation *>>
+cloneAndFuse(Operation *producerOp, Operation *containingOp,
+ RewriterBase &rewriter) {
// Gather all uses inside the containing op.
SmallVector<OpOperand *> uses;
- for (OpResult result : producerOp->getOpResults()) {
- for (OpOperand &use : result.getUses()) {
- if (containingOp->isProperAncestor(use.getOwner())) {
+ for (OpResult result : producerOp->getOpResults())
+ for (OpOperand &use : result.getUses())
+ if (containingOp->isProperAncestor(use.getOwner()))
uses.push_back(&use);
- continue;
- }
- // Cannot clone and fuse if the use is fom the containing op itself: fail.
- if (containingOp == use.getOwner())
- return nullptr;
- }
- }
// Check for a non-empty list of fusion opportunities.
if (uses.empty())
- return nullptr;
+ return failure();
// Clone and fuse inside the containing op.
- Operation *fusedOp = nullptr;
+ SmallVector<Operation *> fusedOps;
for (OpOperand *use : uses) {
- // Parallel insert slice is not a valid clone destination.
- // TODO: Generalize to other type of ops.
- assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
- "Parallel insert slice is not a valid clone destination");
unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(use->getOwner());
- fusedOp = rewriter.clone(*producerOp);
+ Operation *cloned = rewriter.clone(*producerOp);
rewriter.updateRootInPlace(
- use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
- break;
+ use->getOwner(), [&] { use->set(cloned->getOpResult(resultNumber)); });
+ fusedOps.push_back(cloned);
}
- return fusedOp;
+ return fusedOps;
}
DiagnosedSilenceableFailure
@@ -405,7 +312,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
}
for (Operation *producerOp : producerOps) {
if (producerOp->getNumResults() != 1) {
- Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
+ Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
diag << "op with != 1 results not supported";
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
@@ -424,17 +331,15 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
auto getNextProducer = [&]() -> FailureOr<Operation *> {
for (const auto &it : enumerate(remainingProducers)) {
Operation *producerOp = it.value();
- // The containing op may be a user of producerOp: use isAncestor.
- int64_t numUsesInContainingOp =
- llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
- return containingOp->isAncestor(op);
+ bool hasUseInContainingOp =
+ any_of(producerOp->getUsers(), [&](Operation *op) {
+ return containingOp->isProperAncestor(op);
});
- // TODO: When resolving the TODO below (no duplicate ops), take an op
- // that has no use among the remaining producers. This is a topological
+ // TODO: When resolving the TODO below (no duplicate ops), take an op that
+ // has no use among the remaining producers. This is a topological
// sorting.
- if (numUsesInContainingOp > 0) {
- if (numUsesInContainingOp == 1)
- remainingProducers.erase(remainingProducers.begin() + it.index());
+ if (hasUseInContainingOp) {
+ remainingProducers.erase(remainingProducers.begin() + it.index());
return producerOp;
}
}
@@ -445,42 +350,29 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
while (!remainingProducers.empty()) {
auto nextProducer = getNextProducer();
if (failed(nextProducer)) {
- Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark);
+ Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Note);
diag << "could not fuse ops into container";
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
Operation *producerOp = *nextProducer;
- // TODO: If there are multiple uses of the producer in the containing op,
- // we currently tile/clone the op multiple times (once per use). In some
- // cases, we can tile/clone once and reuse the value for each use.
- // Futhermore, producers should then be traversed according to a
- // topological sorting.
- Operation *tiled =
- tileAndFuseFirstExtractUse(producerOp, containingOp, rewriter);
- if (tiled) {
- fusedOps.push_back(tiled);
- continue;
- }
-
- Operation *tiledContainingOpOperand =
- tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
- producerOp, containingOp, rewriter);
- if (tiledContainingOpOperand) {
- fusedOps.push_back(tiledContainingOpOperand);
- continue;
- }
-
- Operation *cloned =
- cloneAndFuseFirstUse(producerOp, containingOp, rewriter);
- if (cloned) {
- fusedOps.push_back(cloned);
- continue;
+ // TODO: If there are multiple uses of the producer in the containing op, we
+ // currently tile/clone the op multiple times (once per use). In some cases,
+ // we can tile/clone once and reuse the value for each use. Futhermore,
+ // producers should then be traversed according to a topological sorting.
+ auto tiled = tileAndFuse(producerOp, containingOp, rewriter);
+ if (succeeded(tiled))
+ fusedOps.append(*tiled);
+
+ auto cloned = cloneAndFuse(producerOp, containingOp, rewriter);
+ if (succeeded(cloned))
+ fusedOps.append(*cloned);
+
+ if (failed(tiled) && failed(cloned)) {
+ Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
+ diag << "could not fuse into containing op";
+ return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
-
- Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
- diag << "could not fuse " << *producerOp << "into " << *containingOp;
- return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
results.set(getFusedOp().cast<OpResult>(), fusedOps);
@@ -734,9 +626,9 @@ LogicalResult transform::PadOp::verify() {
extractFromI64ArrayAttr(getPaddingDimensions());
if (any_of(paddingDimensions,
[](int64_t paddingDimension) { return paddingDimension < 0; })) {
- return emitOpError() << "expects padding_dimensions to contain positive "
- "integers, found "
- << getPaddingDimensions();
+ return emitOpError()
+ << "expects padding_dimensions to contain positive integers, found "
+ << getPaddingDimensions();
}
SmallVector<int64_t> hoistPaddings =
@@ -807,8 +699,8 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
transform::TransformState &state) {
LinalgTilingOptions tilingOptions;
tilingOptions.scalarizeDynamicDims();
- // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the
- // tile sizes and asserts that it is not already set.
+ // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
+ // sizes and asserts that it is not already set.
SmallVector<int64_t> emptyTileSizes;
LinalgTilingPattern pattern(getContext(), tilingOptions);
SimpleRewriter rewriter(getContext());
@@ -955,8 +847,8 @@ LogicalResult SplitOp::verify() {
if ((static_cast<int64_t>(getStaticSplitPoint()) !=
ShapedType::kDynamicSize) ^
(getDynamicSplitPoint() == nullptr)) {
- return emitOpError() << "expects either a dynamic or a static split "
- "point to be provided";
+ return emitOpError()
+ << "expects either a dynamic or a static split point to be provided";
}
return success();
}
@@ -1310,8 +1202,8 @@ transform::VectorizeOp::applyToOne(Operation *target,
//===----------------------------------------------------------------------===//
namespace {
-/// Registers new ops and declares PDL as dependent dialect since the
-/// additional ops are using PDL types for operands and results.
+/// Registers new ops and declares PDL as dependent dialect since the additional
+/// ops are using PDL types for operands and results.
class LinalgTransformDialectExtension
: public transform::TransformDialectExtension<
LinalgTransformDialectExtension> {
More information about the Mlir-commits mailing list