[Mlir-commits] [mlir] [mlir] Don't require extract_slice in fusion with transform op (PR #112755)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Oct 17 11:04:21 PDT 2024
https://github.com/Max191 updated https://github.com/llvm/llvm-project/pull/112755
>From 432341c07ac5a1fc90aecd66e8ebeae3a4507c08 Mon Sep 17 00:00:00 2001
From: Max Dawkins <max.dawkins at gmail.com>
Date: Thu, 17 Oct 2024 10:04:01 -0500
Subject: [PATCH] [mlir] Don't require extract_slice in fusion with transform
op
Signed-off-by: Max Dawkins <max.dawkins at gmail.com>
---
.../TransformOps/LinalgTransformOps.cpp | 54 +++++++++++++------
.../transform-op-fuse-into-containing.mlir | 40 ++++++++++++++
2 files changed, 77 insertions(+), 17 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ad72b5d7beccde..2bc1d5dde6b5d9 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -818,27 +818,23 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
// Search the producer slices accessed within the containing operation.
// TODO: Generalize to more extract/insert/parallel_insert triples, maybe
// evolve into an interface.
+ if (bbArg.getUsers().empty()) {
+ diag.attachNote(containingOp->getLoc())
+ << "could not find fusion opportunity for bbArg: " << bbArg;
+ return {};
+ }
auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
return sliceOp && containingOp->isProperAncestor(sliceOp);
});
-
- // Find a fusion opportunity.
+ OpBuilder::InsertionGuard guard(rewriter);
+ tensor::ExtractSliceOp sliceOpToTile;
if (itBBArgUsers == bbArg.getUsers().end()) {
- diag.attachNote(containingOp->getLoc())
- << "could not find fusion opportunity for bbArg: " << bbArg;
- return {};
+ rewriter.setInsertionPoint(&bbArg.getOwner()->front());
+ } else {
+ sliceOpToTile = llvm::cast<tensor::ExtractSliceOp>(*itBBArgUsers);
+ rewriter.setInsertionPoint(sliceOpToTile);
}
- auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
-
- // Try to fuse the producer in-place.
- OpBuilder::InsertionGuard guard(rewriter);
- rewriter.setInsertionPoint(sliceOpToTile);
-
- // Replace the use in the tileableProducer before tiling: clone, replace and
- // then tile.
- int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
- LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
// Gather destination tensors.
SmallVector<Value> destinationTensors;
@@ -850,14 +846,38 @@ tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
return {};
}
+ // Replace the use in the tileableProducer before tiling: clone, replace and
+ // then tile.
+ SmallVector<Operation *> oldBbArgUsers(bbArg.getUsers());
+ int64_t resultNumber = cast<OpResult>(pUse->get()).getResultNumber();
+ LLVM_DEBUG(DBGS() << "resultNumber: " << resultNumber << "\n");
IRMapping bvm;
bvm.map(destinationTensors[resultNumber], bbArg);
auto tileableProducerClone =
cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
- auto scopeGuard =
- llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
+
+ // If there was no extract_slice user, then no need to tile.
+ if (!sliceOpToTile) {
+ LLVM_DEBUG(DBGS() << "No extract_slice user. No need to tile cloned op.\n");
+ // Replace the old uses of bbArg with the cloned op, except for any parallel
+ // insert ops.
+ rewriter.replaceUsesWithIf(
+ bbArg, tileableProducerClone->getResult(resultNumber),
+ [&](OpOperand &operand) {
+ return !isa<tensor::ParallelInsertSliceOp>(operand.getOwner()) &&
+ operand.getOwner() != tileableProducerClone.getOperation();
+ });
+ // Replace the use in containingOp.
+ rewriter.modifyOpInPlace(containingOp, [&]() {
+ containingOp->setOperand(pUse->getOperandNumber(),
+ destinationTensors.front());
+ });
+ return {tileableProducerClone};
+ }
// Tile the producer.
+ auto scopeGuard =
+ llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
FailureOr<TilingResult> tileAndFuseResult =
tileableProducerClone.generateResultTileValue(
rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
index 4115f2857a20c6..c0c7b8ec9598bc 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir
@@ -202,6 +202,46 @@ module {
// -----
+module {
+ // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_no_slice
+ // CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+ // CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<?xf32>
+ // CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<?xf32>
+ func.func @fuse_tileable_op_through_bbarg_no_slice(%arg0: index, %arg1: tensor<?xf32>, %arg2: tensor<?xf32>) -> tensor<?xf32> {
+ %cst = arith.constant 4.200000e+01 : f32
+ %c0 = arith.constant 0 : index
+ %d0 = tensor.dim %arg1, %c0 : tensor<?xf32>
+
+ %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<?xf32>) -> tensor<?xf32>
+ // CHECK: scf.forall {{.*}} shared_outs(%[[BBARGOUT:.*]] = %[[OUT]]) -> (tensor<?xf32>) {
+ %1 = scf.forall (%arg3) in (%arg0) shared_outs(%o = %0) -> (tensor<?xf32>) {
+ // CHECK: %[[T0:.*]] = linalg.fill {{.*}} outs(%[[BBARGOUT]]
+
+ // CHECK: %[[T1:.*]] = linalg.elemwise_unary {{.*}} outs(%[[T0]]
+ %2 = linalg.elemwise_unary ins(%arg1 : tensor<?xf32>) outs(%o : tensor<?xf32>) -> tensor<?xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %2 into %o[0] [%d0] [1] : tensor<?xf32> into tensor<?xf32>
+ }
+ }
+ // CHECK: }
+ func.return %1 : tensor<?xf32>
+ }
+
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %1 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+
+ // linalg.fill is tileable, but the bbArg has no extract_slice user here, so the op is cloned into the loop body and fused without tiling.
+ transform.structured.fuse_into_containing_op %0 into %1
+ : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ }
+}
+
+// -----
+
#map0 = affine_map<()[s0, s1] -> (s0 ceildiv s1)>
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0, s1] -> (-(d0 * s1) + s0, s1)>
More information about the Mlir-commits
mailing list