[Mlir-commits] [mlir] 4b17710 - [mlir][Linalg] Support multi-output fusion in FuseIntoContainingOp
Nicolas Vasilache
llvmlistbot at llvm.org
Fri Oct 14 03:55:04 PDT 2022
Author: Nicolas Vasilache
Date: 2022-10-14T03:54:54-07:00
New Revision: 4b17710369df7f1ba73ce63d4312726b9a2b52cc
URL: https://github.com/llvm/llvm-project/commit/4b17710369df7f1ba73ce63d4312726b9a2b52cc
DIFF: https://github.com/llvm/llvm-project/commit/4b17710369df7f1ba73ce63d4312726b9a2b52cc.diff
LOG: [mlir][Linalg] Support multi-output fusion in FuseIntoContainingOp
This revision adds support for fusing tileable ops with multiple results to
the transform.fuse_into_containing_op operation.
Differential Revision: https://reviews.llvm.org/D135955
Added:
Modified:
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index be4efaafc6ca9..5c304f5efb6ea 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -71,9 +71,8 @@ def FuseIntoContainingOp :
let description = [{Fuse a producer into a containing operation.}];
let summary = [{
- Fuses the `producer_op` into the `containing_op`. Only producers with a
- single result are supported at the moment. Returns a handle to the fused
- ops.
+ Fuses the `producer_op` into the `containing_op`.
+ Returns a handle to the fused ops.
The producer is typically a slice of a tileable op (i.e., implements
TilingInterface). In that case, this transform computes the accessed
@@ -98,8 +97,10 @@ def FuseIntoContainingOp :
This is the case when tiling fails or when no producer op could be found
among the remaining producers that has at least one use within the
containing op. I.e., "producers" that are not consumed within the containing
- op are rejected by this operation. This operation reads and frees the
- producer handle. It reads the containing op handle.
+ op are rejected by this operation.
+
+ This operation reads and frees the producer handle.
+ This operation reads the containing op handle.
}];
let arguments = (ins Arg<PDL_Operation, "",
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ed74de7f61f58..e47e8e51c6830 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -22,11 +22,14 @@
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::transform;
+#define DEBUG_TYPE "linalg-transforms"
+
/// Extracts a vector of unsigned from an array attribute. Asserts if the
/// attribute contains values other than integers. May truncate.
static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
@@ -258,6 +261,7 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
Diagnostic &diag,
Operation *producerOp,
Operation *containingOp) {
+ LLVM_DEBUG(llvm::dbgs() << "Try to fuse a direct extract use\n");
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer) {
diag.attachNote(producerOp->getLoc())
@@ -286,18 +290,23 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
rewriter.setInsertionPoint(sliceOpToTile);
// Tile the producer.
+ int64_t resultNumber =
+ sliceOpToTile.getSource().cast<OpResult>().getResultNumber();
+ LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
+
FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
- rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
+ rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
sliceOpToTile.getMixedSizes());
if (failed(tiledProducer)) {
diag.attachNote(tileableProducer->getLoc())
<< "failed to tile producer op: " << *tileableProducer;
return nullptr;
}
+ LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n");
// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
- rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
+ rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
return fusedOp;
}
@@ -310,6 +319,8 @@ static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
Operation *containingOp) {
+ LLVM_DEBUG(
+ llvm::dbgs() << "Try to fuse an extract use through block argument\n");
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer) {
@@ -318,16 +329,6 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
return nullptr;
}
- // Ensure `tileableProducer` has exactly one destination operand that we can
- // replace the ForeachThreadOp bbArg with.
- auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
- if (destinationOperands.size() != 1) {
- diag.attachNote(tileableProducer->getLoc())
- << "tileableProducer must have exactly one destination operand: "
- << *tileableProducer;
- return nullptr;
- }
-
// Search the first use by a "scf::ForeachThreadOp" user.
scf::ForeachThreadOp foreachThreadOp;
auto itProducerUses =
@@ -371,8 +372,13 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
// Replace the use in the tileableProducer before tiling: clone, replace and
// then tile.
+ int64_t resultNumber = pUse->get().cast<OpResult>().getResultNumber();
+ LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
+
+ auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
+
BlockAndValueMapping bvm;
- bvm.map(destinationOperands.front(), bbArg);
+ bvm.map(destinationOperands[resultNumber], bbArg);
auto tileableProducerClone =
cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
auto scopeGuard =
@@ -381,17 +387,18 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
// Tile the producer.
FailureOr<Value> tiledProducer =
tileableProducerClone.generateResultTileValue(
- rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
+ rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
sliceOpToTile.getMixedSizes());
if (failed(tiledProducer)) {
diag.attachNote(tileableProducer->getLoc())
<< "failed to tile producer op: " << *tileableProducer;
return nullptr;
}
+ LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n");
// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
- rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
+ rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(resultNumber));
// Replace the use in containingOp.
rewriter.updateRootInPlace(containingOp, [&]() {
@@ -405,6 +412,8 @@ static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
Operation *producerOp,
Operation *containingOp) {
+ LLVM_DEBUG(llvm::dbgs() << "Try to fuse an use by cloning\n");
+
// Gather all uses inside the containing op.
SmallVector<OpOperand *> uses;
for (OpResult result : producerOp->getOpResults()) {
@@ -437,6 +446,8 @@ static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
"Parallel insert slice is not a valid clone destination");
unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
+ LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n");
+
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(use->getOwner());
fusedOp = rewriter.clone(*producerOp);
@@ -453,21 +464,17 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
// If nothing to fuse, propagate success.
if (producerOps.empty()) {
- results.set(getResult().cast<OpResult>(), SmallVector<mlir::Operation *>{});
+ results.set(getFusedOp().cast<OpResult>(),
+ SmallVector<mlir::Operation *>{});
return DiagnosedSilenceableFailure::success();
}
- for (Operation *producerOp : producerOps) {
- if (producerOp->getNumResults() != 1) {
- Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
- diag << "op with != 1 results not supported";
- return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
- }
- }
ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
- if (containingOps.size() != 1)
+ if (containingOps.size() != 1) {
+ // Definite failure.
return DiagnosedSilenceableFailure(
this->emitOpError("requires exactly one containing_op handle (got ")
<< containingOps.size() << ")");
+ }
Operation *containingOp = containingOps.front();
// Helper function to find the next producer that should be fused. Take any
@@ -498,6 +505,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
while (!remainingProducers.empty()) {
auto nextProducer = getNextProducer();
if (failed(nextProducer)) {
+ results.set(getFusedOp().cast<OpResult>(), ArrayRef<Operation *>());
Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark);
diag << "could not find next producer to fuse into container";
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
@@ -505,7 +513,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
Operation *producerOp = *nextProducer;
- // Detaul diagnostic, to be complemented with more failure information.
+ // Default diagnostic, to be complemented with more failure information.
Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
diag << "could not fuse " << *producerOp << " into " << *containingOp;
@@ -517,6 +525,8 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
Operation *tiled =
tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
if (tiled) {
+ LLVM_DEBUG(llvm::dbgs() << "\nFused a direct extract use\n"
+ << *containingOp);
fusedOps.push_back(tiled);
continue;
}
@@ -525,6 +535,9 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
rewriter, diag, producerOp, containingOp);
if (tiledContainingOpOperand) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "\nFused an extract use through block argument\n"
+ << *containingOp);
fusedOps.push_back(tiledContainingOpOperand);
continue;
}
@@ -532,10 +545,12 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
Operation *cloned =
cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
if (cloned) {
+ LLVM_DEBUG(llvm::dbgs() << "\nFused an use by cloning\n"
+ << *containingOp);
fusedOps.push_back(cloned);
continue;
}
-
+ results.set(getFusedOp().cast<OpResult>(), ArrayRef<Operation *>());
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index b1af4ef2869be..141e8f59b5a21 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -141,3 +141,63 @@ module {
transform.structured.fuse_into_containing_op %0 into %1
}
}
+
+// -----
+
+#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
+
+module {
+ // CHECK-LABEL: func.func @fuse_tileable_multi_output_op
+ // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+ // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_1:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_2:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT_3:[0-9a-z]+]]: tensor<?xf32>
+ func.func @fuse_tileable_multi_output_op(%idx: index, %in: tensor<?xf32>, %out_1: tensor<?xf32>, %out_2: tensor<?xf32>, %out_3: tensor<?xf32>) -> tensor<?xf32> {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+
+ %0:2 = linalg.generic {
+ indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+ iterator_types = ["parallel"]
+ } ins(%in : tensor<?xf32>) outs(%out_1, %out_3 : tensor<?xf32>, tensor<?xf32>) {
+ ^bb0(%a: f32, %b: f32, %c: f32):
+ %d = arith.addf %a, %b : f32
+ %e = arith.addf %d, %c : f32
+ linalg.yield %d, %e : f32, f32
+ } -> (tensor<?xf32>, tensor<?xf32>)
+ %d0 = tensor.dim %out_1, %c0 : tensor<?xf32>
+
+ %1 = affine.apply #map0()[%d0, %idx]
+
+ // CHECK: scf.foreach_thread {{.*}} {
+ %2 = scf.foreach_thread (%i) in (%1) shared_outs(%o = %out_2) -> (tensor<?xf32>) {
+ %3 = affine.apply #map1(%i)[%idx]
+ %4 = affine.min #map2(%i)[%d0, %idx]
+ %5 = tensor.extract_slice %o[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[IN]][%{{.*}}] [%{{.*}}] [{{.*}}]
+ // CHECK: %[[T1:.*]]:2 = linalg.generic {{.*}} ins(%[[T0]]
+ %6 = tensor.extract_slice %0#0[%3] [%4] [1] : tensor<?xf32> to tensor<?xf32>
+
+ // CHECK: %[[T2:.*]] = linalg.elemwise_unary ins(%[[T1]]#0
+ %7 = linalg.elemwise_unary ins(%6 : tensor<?xf32>) outs(%5 : tensor<?xf32>) -> tensor<?xf32>
+ scf.foreach_thread.perform_concurrently {
+ tensor.parallel_insert_slice %7 into %o[%3] [%4] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ // CHECK: }
+ func.return %2 : tensor<?xf32>
+ }
+
+ transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ %1 = transform.structured.match ops{["scf.foreach_thread"]} in %arg1
+
+ // linalg.generic is tileable. The op is tiled and fused.
+ transform.structured.fuse_into_containing_op %0 into %1
+ }
+}
More information about the Mlir-commits
mailing list