[Mlir-commits] [mlir] [mlir] Fix consumer fusion for producer with multiple results (PR #125915)
llvmlistbot at llvm.org
Wed Feb 5 11:30:25 PST 2025
llvmbot wrote:
@llvm/pr-subscribers-mlir-scf
Author: Prashant Kumar (pashu123)
In the case of consumer fusion where the producer yields multiple results that are all used by a single consumer, e.g.,

```mlir
// Producer yields 3 results.
%results:3 = scf.forall ... -> (tensor<...>, tensor<...>, tensor<...>) {
  scf.yield %a, %b, %c : tensor<...>, tensor<...>, tensor<...>
}
// Consumer uses all 3 results.
%final = consumer %results#0, %results#1, %results#2
```

all other operands of the tiled consumer need to be updated, not only the one reachable through the fused slice.
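For reference, here is a simplified sketch of the fused result for the `scf.forall` case, adapted from the `forall_producer_multiple_result_single_consumer` test added below (value names like `%out0` and `%slice` are shortened here for readability; `%arg2` is the original function argument used as the shared init in that test):

```mlir
// A third shared_out is introduced for the fused consumer's init, and the
// tiled consumer reads the in-loop tiles of *both* producer results.
%init = tensor.empty() : tensor<64x64xf32>
%0:3 = scf.forall (%i, %j) in (2, 2)
    shared_outs(%out0 = %arg2, %out1 = %arg2, %out2 = %init)
    -> (tensor<64x64xf32>, tensor<64x64xf32>, tensor<64x64xf32>) {
  %empty = tensor.empty() : tensor<32x32xf32>
  %slice = tensor.extract_slice %out0[%i, %j] [32, 32] [1, 1]
      : tensor<64x64xf32> to tensor<32x32xf32>
  %mm = linalg.matmul ins(%slice, %slice : tensor<32x32xf32>, tensor<32x32xf32>)
                      outs(%empty : tensor<32x32xf32>) -> tensor<32x32xf32>
  // Fused (tiled) consumer: both operands are rewired to in-loop values.
  %dest = tensor.extract_slice %out2[%i, %j] [32, 32] [1, 1]
      : tensor<64x64xf32> to tensor<32x32xf32>
  %add = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
      ins(%slice, %mm : tensor<32x32xf32>, tensor<32x32xf32>)
      outs(%dest : tensor<32x32xf32>) -> tensor<32x32xf32>
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %mm into %out1[%i, %j] [32, 32] [1, 1]
        : tensor<32x32xf32> into tensor<64x64xf32>
    tensor.parallel_insert_slice %slice into %out0[%i, %j] [32, 32] [1, 1]
        : tensor<32x32xf32> into tensor<64x64xf32>
    tensor.parallel_insert_slice %add into %out2[%i, %j] [32, 32] [1, 1]
        : tensor<32x32xf32> into tensor<64x64xf32>
  }
}
```

The key point is that both operands of the tiled `linalg.elemwise_binary` are replaced with the in-loop tiles (`%slice` and `%mm`), which is what the per-operand update in this patch enables.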
---
Full diff: https://github.com/llvm/llvm-project/pull/125915.diff
2 Files Affected:
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+110-15)
- (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+128-4)
```diff
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index b548f8ce8b560b1..bca727de3ddb3f6 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1949,6 +1949,60 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
}
}
+// If the producer of the operand is a loop-like op, find the last
+// insertSlice/parallelInsertSlice in the producer op that uses the block
+// argument corresponding to the operand.
+static FailureOr<Operation *>
+getSliceOpFromConsumerOperand(OpOperand &operand) {
+
+ OpResult producerResult = dyn_cast<OpResult>(operand.get());
+ if (!producerResult)
+ return failure();
+
+ LoopLikeOpInterface loopLikeOp =
+ dyn_cast<LoopLikeOpInterface>(producerResult.getOwner());
+ if (!loopLikeOp)
+ return failure();
+
+ // Obtain the BlockArgument corresponding to the result.
+ BlockArgument bbArg =
+ loopLikeOp.getRegionIterArgs()[producerResult.getResultNumber()];
+
+ // Finally return the operation corresponding to the yielded value.
+ // Also check whether it's an InsertSliceOp.
+ if (dyn_cast<scf::ForOp>(producerResult.getOwner())) {
+ OpOperand *yieldVal = loopLikeOp.getTiedLoopYieldedValue(bbArg);
+ Operation *lastOp = dyn_cast<OpResult>(yieldVal->get()).getOwner();
+ auto isInsertSliceOp = isa<tensor::InsertSliceOp>(lastOp);
+ if (!isInsertSliceOp) {
+ return failure();
+ }
+ return lastOp;
+ }
+
+ auto forallOp = dyn_cast<scf::ForallOp>(producerResult.getOwner());
+ if (!forallOp)
+ return failure();
+
+ // Iterate over the terminator operation of the forallOp to find the last
+ // parallelInsertSliceOp that uses the blockArgument.
+ Operation *lastOp = nullptr;
+ forallOp.getTerminator()->walk([&](tensor::ParallelInsertSliceOp op) {
+ for (mlir::Value operand : op->getOperands()) {
+ if (auto maybeBlockArg = dyn_cast<BlockArgument>(operand)) {
+ if (maybeBlockArg == bbArg) {
+ lastOp = op;
+ }
+ }
+ }
+ });
+
+ if (!lastOp)
+ return failure();
+
+ return lastOp;
+}
+
/// Implementation of fusing consumer of a single slice by computing the
/// slice of the consumer in-place for scf loop.
FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1979,6 +2033,26 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
consumerOp, "consumer op's operand doesn't seem to be an OpResult");
}
+ SmallVector<OpOperand *> potentialOperands{*maybeConsumerOpOperand};
+ SmallVector<unsigned> potentialOperandResultNos{
+ consumerOpOperand->getOperandNumber()};
+ SmallVector<Operation *> potentialSliceOps{candidateSliceOp};
+
+ // 1b. Get all the other operands of the consumer op and their corresponding
+ // slice ops. When the consumer uses multiple results of the producer, every
+ // such operand needs to be updated.
+ for (OpOperand &otherOperand : consumerOp->getOpOperands()) {
+ if (&otherOperand == *maybeConsumerOpOperand)
+ continue;
+ auto maybePotentialSlice = getSliceOpFromConsumerOperand(otherOperand);
+ if (failed(maybePotentialSlice)) {
+ continue;
+ }
+ potentialSliceOps.push_back(*maybePotentialSlice);
+ potentialOperands.push_back(&otherOperand);
+ potentialOperandResultNos.push_back(otherOperand.getOperandNumber());
+ }
+
// There are two possible cases regarding `oldLoopOp` here:
// 1. single `scf.forall` or `scf.for`.
// 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
@@ -2037,18 +2111,29 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
// tensor.insert_slice. In the scf.for case this is a clone of the
// candidateSliceOp whereas in the scf.forall case this is created from the
// operands of tensor.parallel_insert_slice.
- tensor::InsertSliceOp clonedInsertSliceOp;
+
+ SmallVector<tensor::InsertSliceOp> allClonedInsertSliceOps;
+
+ scf::ForallOp newForallOp;
if (auto sliceOp =
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
rewriter.setInsertionPoint(newForallOp.getTerminator());
- clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
- loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
- sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
} else {
- rewriter.setInsertionPoint(candidateSliceOp);
- clonedInsertSliceOp =
- cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
+ rewriter.setInsertionPoint(potentialSliceOps.back());
+ }
+
+ for (auto *candidateSliceOp : potentialSliceOps) {
+ if (auto sliceOp =
+ dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+ allClonedInsertSliceOps.push_back(rewriter.create<tensor::InsertSliceOp>(
+ loc, sliceOp.getSource(), sliceOp.getDest(),
+ sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
+ sliceOp.getMixedStrides()));
+ } else {
+ allClonedInsertSliceOps.push_back(
+ cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp)));
+ }
}
// 5.a. Clone consumer op.
@@ -2056,24 +2141,34 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
// 5.b. Replace all uses of the loop result with the result of the cloned
// tensor.insert_slice.
- OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
- rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
- operandToReplace.set(clonedInsertSliceOp.getResult());
- });
+ for (const auto &it : llvm::enumerate(allClonedInsertSliceOps)) {
+ OpOperand &operandToReplace =
+ clonedConsumerOp->getOpOperand(potentialOperandResultNos[it.index()]);
+ rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+ operandToReplace.set(it.value().getResult());
+ });
+ }
// 6. Perform tiling of the cloned consumer and replace the operand at
// `operandNumber` with the source of the cloned tensor.insert_slice op.
- auto ossSliceOp =
- cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
+ auto ossSliceOp = cast<OffsetSizeAndStrideOpInterface>(
+ allClonedInsertSliceOps.front().getOperation());
FailureOr<TilingResult> tileAndFuseResult =
tensor::replaceInsertSliceWithTiledConsumer(
rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
+
if (failed(tileAndFuseResult)) {
return failure();
}
+
auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
- rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
- clonedInsertSliceOp.getSource());
+
+ // 6b. Update the tiled consumer op with the new operands.
+ for (const auto &it : llvm::enumerate(allClonedInsertSliceOps)) {
+ rewriter.replaceAllUsesWith(
+ tiledConsumerOp->getOperand(potentialOperandResultNos[it.index()]),
+ it.value().getSource());
+ }
// 7. Reconstruct [nested] loop with new inits.
YieldTiledValuesFn newYieldValuesFn =
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index a2871b30698c527..14b9ec504c1585e 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -282,7 +282,7 @@ module {
return %unpack : tensor<2048xf32>
}
}
-
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
@@ -343,7 +343,7 @@ module {
return %unpack : tensor<2047xf32>
}
}
-
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
@@ -404,7 +404,7 @@ module {
return %pack : tensor<4x32x16xf32>
}
}
-
+
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
%slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
@@ -610,7 +610,7 @@ module attributes {transform.with_named_sequence} {
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
// CHECK: %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
// CHECK: %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
-// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
+// CHECK-SAME: iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
// CHECK-SAME: {
// CHECK: %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
// CHECK: %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
@@ -676,3 +676,127 @@ module attributes {transform.with_named_sequence} {
// CHECK: }
// CHECK: %[[RES_SLICE:.+]] = tensor.insert_slice
// CHECK: return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
+
+// -----
+
+module {
+ func.func @forall_producer_multiple_result_single_consumer(%arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
+ %outs = tensor.empty() : tensor<32x32xf32>
+ %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+ %3 = linalg.matmul ins(%extracted_slice, %extracted_slice : tensor<32x32xf32>, tensor<32x32xf32>) outs(%outs : tensor<32x32xf32>) -> tensor<32x32xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+ tensor.parallel_insert_slice %extracted_slice into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+ }
+ }
+ %final_out = tensor.empty() : tensor<64x64xf32>
+ %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#0, %1#1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%final_out : tensor<64x64xf32>) -> tensor<64x64xf32>
+ return %2 : tensor<64x64xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %consumer, %fused_consumer = transform.test.fuse_consumer %1#0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK-LABEL: func.func @forall_producer_multiple_result_single_consumer(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<64x64xf32>
+
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<64x64xf32>
+// CHECK: %[[LOOP_RESULT:.+]]:3 = scf.forall (%[[I:.+]], %[[J:.+]]) in (2, 2) shared_outs(%[[SHARED0:.+]] = %[[ARG0]], %[[SHARED1:.+]] = %[[ARG0]], %[[SHARED2:.+]] = %[[INIT]])
+
+// CHECK: %[[TILE_INIT:.+]] = tensor.empty() : tensor<32x32xf32>
+// CHECK: %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[TILE_INIT]] : tensor<32x32xf32>)
+// CHECK: %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[MATMUL]] into %[[SHARED1]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: %[[INSERTED_SLICE0:.+]] = tensor.insert_slice %[[EXTRACTED_SLICE]] into %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: %[[EXTRACTED_SLICE1:.+]] = tensor.extract_slice %[[SHARED2]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: %[[ADD:.+]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%[[EXTRACTED_SLICE]], %[[MATMUL]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[EXTRACTED_SLICE1]] : tensor<32x32xf32>)
+
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[MATMUL]] into %[[SHARED1]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[EXTRACTED_SLICE]] into %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: tensor.parallel_insert_slice %[[ADD]] into %[[SHARED2]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK: }
+
+// CHECK: return %[[LOOP_RESULT]]#2 : tensor<64x64xf32>
+
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+module {
+ func.func @for_producer_producing_multiple_result_single_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+ %c4 = arith.constant 4 : index
+ %c64 = arith.constant 64 : index
+ %c0 = arith.constant 0 : index
+ %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+ %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+ %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+ ^bb0(%in: f32, %in_16: f32, %out: f32):
+ %13 = arith.mulf %in, %in_16 : f32
+ %14 = arith.addf %out, %13 : f32
+ linalg.yield %14 : f32
+ } -> tensor<32xf32>
+ %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+ %5 = tensor.insert_slice %3 into %arg5[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+ scf.yield %5, %4 : tensor<64xf32>, tensor<64xf32>
+ }
+ %out_operand = tensor.empty() : tensor<64xf32>
+ %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %1#0 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand : tensor<64xf32>) -> tensor<64xf32>
+ return %2 : tensor<64xf32>
+ }
+ }
+
+ module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %1:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ %consumer, %fused_consumer = transform.test.fuse_consumer %1#0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+ transform.yield
+ }
+ }
+
+// CHECK-LABEL: func.func @for_producer_producing_multiple_result_single_consumer(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<32xf32>,
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<32xf32>,
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<64xf32>
+
+// CHECK: %[[C4:.+]] = arith.constant 4 : index
+// CHECK: %[[C64:.+]] = arith.constant 64 : index
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[INIT:.+]] = tensor.empty() : tensor<64xf32>
+
+// CHECK: %[[LOOP_RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C64]] step %[[C4]]
+// CHECK-SAME: iter_args(%[[ITER0:.+]] = %[[ARG2]], %[[ITER1:.+]] = %[[ARG2]], %[[ITER2:.+]] = %[[INIT]])
+// CHECK-SAME: -> (tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
+
+// CHECK: %[[EXTRACT_SLICE:.+]] = tensor.extract_slice %[[ITER0]][%[[IV]]] [32] [1]
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<32xf32>, tensor<32xf32>)
+// CHECK-SAME: outs(%[[EXTRACT_SLICE]] : tensor<32xf32>)
+// CHECK: ^{{.*}}(%[[IN0:.+]]: f32, %[[IN1:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK: %[[MUL:.+]] = arith.mulf %[[IN0]], %[[IN1]] : f32
+// CHECK: %[[ADD:.+]] = arith.addf %[[OUT]], %[[MUL]] : f32
+// CHECK: linalg.yield %[[ADD]] : f32
+
+// CHECK: %[[INSERT_SLICE0:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER0]][%[[IV]]] [32] [1]
+// CHECK: %[[INSERT_SLICE1:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER1]][%[[IV]]] [32] [1]
+// CHECK: %[[EXTRACT_SLICE2:.+]] = tensor.extract_slice %[[ITER2]][%[[IV]]] [32] [1]
+// CHECK: %[[BINARY:.+]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+// CHECK-SAME: ins(%[[GENERIC]], %[[GENERIC]] : tensor<32xf32>, tensor<32xf32>)
+// CHECK-SAME: outs(%[[EXTRACT_SLICE2]] : tensor<32xf32>)
+// CHECK: %[[INSERT_SLICE2:.+]] = tensor.insert_slice %[[BINARY]] into %[[ITER2]][%[[IV]]] [32] [1]
+
+// CHECK: scf.yield %[[INSERT_SLICE1]], %[[INSERT_SLICE0]], %[[INSERT_SLICE2]]
+
+// CHECK: return %[[LOOP_RESULT]]#2 : tensor<64xf32>
```
https://github.com/llvm/llvm-project/pull/125915