[Mlir-commits] [mlir] 12831be - [mlir][Linalg] NFC - Cleanup internal transform APIs and produce better messages on failure to apply.
Nicolas Vasilache
llvmlistbot at llvm.org
Mon Sep 19 04:21:15 PDT 2022
Author: Nicolas Vasilache
Date: 2022-09-19T04:16:15-07:00
New Revision: 12831be96cdc152e6b07e74e3841da3d9a7a93ab
URL: https://github.com/llvm/llvm-project/commit/12831be96cdc152e6b07e74e3841da3d9a7a93ab
DIFF: https://github.com/llvm/llvm-project/commit/12831be96cdc152e6b07e74e3841da3d9a7a93ab.diff
LOG: [mlir][Linalg] NFC - Cleanup internal transform APIs and produce better messages on failure to apply.
Added:
Modified:
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index bc3f10c717c3f..bf4396b446f9c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -228,12 +228,16 @@ LogicalResult transform::FuseOp::verify() {
/// Find the first "extract" user of `producerOp` and tile it right before its
/// use. The tiled op is fused under the `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
-static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
- Operation *containingOp,
- RewriterBase &rewriter) {
+static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
+ Diagnostic &diag,
+ Operation *producerOp,
+ Operation *containingOp) {
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
- if (!tileableProducer)
+ if (!tileableProducer) {
+ diag.attachNote(producerOp->getLoc())
+ << "producer is not a TileableInterface: " << *producerOp;
return nullptr;
+ }
// Search the producer slices accessed within the containing operation.
// TODO: Generalize to more extract/insert/parallel_insert triples, maybe
@@ -244,8 +248,11 @@ static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
});
// Find a fusion opportunity.
- if (it == tileableProducer->getUsers().end())
+ if (it == tileableProducer->getUsers().end()) {
+ diag.attachNote(tileableProducer->getLoc())
+ << "could not find fusion opportunity for: " << *tileableProducer;
return nullptr;
+ }
auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
// Try to fuse the producer in-place.
@@ -256,8 +263,11 @@ static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
sliceOpToTile.getMixedSizes());
- if (failed(tiledProducer))
+ if (failed(tiledProducer)) {
+ diag.attachNote(tileableProducer->getLoc())
+ << "failed to tile producer op: " << *tileableProducer;
return nullptr;
+ }
// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
@@ -272,11 +282,25 @@ static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
/// `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
- Operation *producerOp, Operation *containingOp, RewriterBase &rewriter) {
+ RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
+ Operation *containingOp) {
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
- if (!tileableProducer)
+ if (!tileableProducer) {
+ diag.attachNote(producerOp->getLoc())
+ << "producer is not a TileableInterface: " << *producerOp;
+ return nullptr;
+ }
+
+ // Ensure `tileableProducer` has exactly one destination operand that we can
+ // replace the ForeachThreadOp bbArg with.
+ auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
+ if (destinationOperands.size() != 1) {
+ diag.attachNote(tileableProducer->getLoc())
+ << "tileableProducer must have exactly one destination operand: "
+ << *tileableProducer;
return nullptr;
+ }
// Search the first use by a "scf::ForeachThreadOp" user.
scf::ForeachThreadOp foreachThreadOp;
@@ -286,8 +310,11 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
return foreachThreadOp;
});
// If it's not from the containing op, return.
- if (!foreachThreadOp || foreachThreadOp != containingOp)
+ if (!foreachThreadOp || foreachThreadOp != containingOp) {
+ diag.attachNote(tileableProducer->getLoc())
+ << "could not find a use by the containing op: " << *tileableProducer;
return nullptr;
+ }
// Search the producer slices accessed within the containing
// operation.
@@ -305,16 +332,13 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
});
// Find a fusion opportunity.
- if (itBBArgUsers == bbArg.getUsers().end())
+ if (itBBArgUsers == bbArg.getUsers().end()) {
+ diag.attachNote(containingOp->getLoc())
+ << "could not find fusion opportunity for bbArg: " << bbArg;
return nullptr;
+ }
auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
- // Ensure `tileableProducer` has exactly one destination operand that we can
- // replace the ForeachThreadOp bbArg with.
- auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
- if (destinationOperands.size() != 1)
- return nullptr;
-
// Try to fuse the producer in-place.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(sliceOpToTile);
@@ -333,8 +357,11 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
tileableProducerClone.generateResultTileValue(
rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
sliceOpToTile.getMixedSizes());
- if (failed(tiledProducer))
+ if (failed(tiledProducer)) {
+ diag.attachNote(tileableProducer->getLoc())
+ << "failed to tile producer op: " << *tileableProducer;
return nullptr;
+ }
// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
@@ -349,9 +376,9 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
return fusedOp;
}
-static Operation *cloneAndFuseFirstUse(Operation *producerOp,
- Operation *containingOp,
- RewriterBase &rewriter) {
+static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
+ Operation *producerOp,
+ Operation *containingOp) {
// Gather all uses inside the containing op.
SmallVector<OpOperand *> uses;
for (OpResult result : producerOp->getOpResults()) {
@@ -362,14 +389,19 @@ static Operation *cloneAndFuseFirstUse(Operation *producerOp,
}
// Cannot clone and fuse if the use is by the containing op itself: fail
// immediately.
- if (containingOp == use.getOwner())
+ if (containingOp == use.getOwner()) {
+ diag.attachNote(producerOp->getLoc())
+ << "producer op use by containing op cannot be fused by cloning";
return nullptr;
+ }
}
}
// Check for a non-empty list of fusion opportunities.
- if (uses.empty())
+ if (uses.empty()) {
+ diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
return nullptr;
+ }
// Clone and fuse inside the containing op.
Operation *fusedOp = nullptr;
@@ -441,18 +473,23 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
auto nextProducer = getNextProducer();
if (failed(nextProducer)) {
Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark);
- diag << "could not fuse ops into container";
+ diag << "could not find next producer to fuse into container";
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
Operation *producerOp = *nextProducer;
+
+ // Detaul diagnostic, to be complemented with more failure information.
+ Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
+ diag << "could not fuse " << *producerOp << " into " << *containingOp;
+
// TODO: If there are multiple uses of the producer in the containing op,
// we currently tile/clone the op multiple times (once per use). In some
// cases, we can tile/clone once and reuse the value for each use.
// Futhermore, producers should then be traversed according to a
// topological sorting.
Operation *tiled =
- tileAndFuseFirstExtractUse(producerOp, containingOp, rewriter);
+ tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
if (tiled) {
fusedOps.push_back(tiled);
continue;
@@ -460,21 +497,19 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
Operation *tiledContainingOpOperand =
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
- producerOp, containingOp, rewriter);
+ rewriter, diag, producerOp, containingOp);
if (tiledContainingOpOperand) {
fusedOps.push_back(tiledContainingOpOperand);
continue;
}
Operation *cloned =
- cloneAndFuseFirstUse(producerOp, containingOp, rewriter);
+ cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
if (cloned) {
fusedOps.push_back(cloned);
continue;
}
- Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
- diag << "could not fuse " << *producerOp << "into " << *containingOp;
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
More information about the Mlir-commits
mailing list