[Mlir-commits] [mlir] [mlir] Fix consumer fusion for producer with multiple results (PR #125915)

Prashant Kumar llvmlistbot at llvm.org
Thu Feb 6 03:54:06 PST 2025


https://github.com/pashu123 updated https://github.com/llvm/llvm-project/pull/125915

>From 191008bdb7952e028b806f98d4db28db80399bcb Mon Sep 17 00:00:00 2001
From: Prashant Kumar <pk5561 at gmail.com>
Date: Mon, 3 Feb 2025 22:12:01 +0530
Subject: [PATCH] [mlir] Fix consumer fusion for producer with multiple results

When fusing a consumer into a producer loop that yields multiple results,
all of which are used by that single consumer, e.g.,

%results:3 = scf.for ... -> (tensor<...>, tensor<...>, tensor<...>) {
// Produces 3 results
     scf.yield %a, %b, %c : tensor<...>, tensor<...>, tensor<...>}
// Consumer uses all 3 results
%final = consumer %results#0, %results#1, %results#2

every operand of the tiled consumer that comes from the producer loop needs to be updated, not only the operand tied to the candidate slice.
---
 .../SCF/Transforms/TileUsingInterface.cpp     | 125 +++++++++++++++--
 .../tile-and-fuse-consumer.mlir               | 132 +++++++++++++++++-
 2 files changed, 238 insertions(+), 19 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index b548f8ce8b560b1..3c2324d62021125 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1949,6 +1949,60 @@ getUntiledConsumerFromSlice(RewriterBase &rewriter, Operation *sliceOp) {
   }
 }
 
+// If the producer of `operand` is an scf.for or scf.forall, returns the slice
+// op that writes the tiled value into the loop result feeding `operand`:
+//   - scf.for: the tensor.insert_slice that produces the yielded value;
+//   - scf.forall: the last tensor.parallel_insert_slice in the terminator
+//     whose destination is the corresponding shared_out block argument.
+// Returns failure() for any other producer, or when no such slice op exists.
+static FailureOr<Operation *>
+getSliceOpFromConsumerOperand(OpOperand &operand) {
+  auto producerResult = dyn_cast<OpResult>(operand.get());
+  if (!producerResult)
+    return failure();
+
+  // Only scf.for and scf.forall are handled below. Check this first so the
+  // result number is only used to index region iter_args for loops whose
+  // results map one-to-one onto them.
+  Operation *producerOp = producerResult.getOwner();
+  if (!isa<scf::ForOp, scf::ForallOp>(producerOp))
+    return failure();
+  auto loopLikeOp = cast<LoopLikeOpInterface>(producerOp);
+
+  // Obtain the BlockArgument corresponding to the result.
+  BlockArgument bbArg =
+      loopLikeOp.getRegionIterArgs()[producerResult.getResultNumber()];
+
+  if (isa<scf::ForOp>(producerOp)) {
+    // The yielded value must be produced by a tensor.insert_slice. It can
+    // also be a block argument with no defining op, so use getDefiningOp
+    // instead of dereferencing an unchecked dyn_cast<OpResult>.
+    OpOperand *yieldVal = loopLikeOp.getTiedLoopYieldedValue(bbArg);
+    if (!yieldVal)
+      return failure();
+    auto insertSliceOp =
+        yieldVal->get().getDefiningOp<tensor::InsertSliceOp>();
+    if (!insertSliceOp)
+      return failure();
+    return insertSliceOp.getOperation();
+  }
+
+  // scf.forall: find the last tensor.parallel_insert_slice in the terminator
+  // whose *destination* is the shared_out block argument. Matching on any
+  // operand would also accept a slice that merely uses `bbArg` as source.
+  auto forallOp = cast<scf::ForallOp>(producerOp);
+  Operation *lastOp = nullptr;
+  forallOp.getTerminator()->walk([&](tensor::ParallelInsertSliceOp op) {
+    if (op.getDest() == bbArg)
+      lastOp = op;
+  });
+
+  if (!lastOp)
+    return failure();
+
+  return lastOp;
+}
+
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1979,6 +2033,26 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
         consumerOp, "consumer op's operand doesn't seem to be an OpResult");
   }
 
+  SmallVector<OpOperand *> potentialOperands = {*maybeConsumerOpOperand};
+  SmallVector<unsigned> potentialOperandResultNos = {
+      consumerOpOperand->getOperandNumber()};
+  SmallVector<Operation *> potentialSliceOps = {candidateSliceOp};
+
+  // 1b. Get all the other operands of the consumer op and their corresponding
+  // slice ops. In the case of the consumer using multiple results
+  // from the producer, we need to update every operand.
+  for (OpOperand &otherOperand : consumerOp->getOpOperands()) {
+    if (&otherOperand == *maybeConsumerOpOperand)
+      continue;
+    auto maybePotentialSlice = getSliceOpFromConsumerOperand(otherOperand);
+    if (failed(maybePotentialSlice)) {
+      continue;
+    }
+    potentialSliceOps.push_back(*maybePotentialSlice);
+    potentialOperands.push_back(&otherOperand);
+    potentialOperandResultNos.push_back(otherOperand.getOperandNumber());
+  }
+
   // There are two possible cases regarding `oldLoopOp` here:
   // 1. single `scf.forall` or `scf.for`.
   // 2. inner-most `scf.for` insider nest `scf.loop` structure, where the
@@ -2037,18 +2111,29 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
   // tensor.insert_slice. In the scf.for case this is a clone of the
   // candidateSliceOp whereas in the scf.forall case this is created from the
   // operands of tensor.parallel_insert_slice.
-  tensor::InsertSliceOp clonedInsertSliceOp;
+
+  SmallVector<tensor::InsertSliceOp> allClonedInsertSliceOps;
+
+  scf::ForallOp newForallOp;
   if (auto sliceOp =
           dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
     auto newForallOp = cast<scf::ForallOp>(innerMostLoop.getOperation());
     rewriter.setInsertionPoint(newForallOp.getTerminator());
-    clonedInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
-        loc, sliceOp.getSource(), sliceOp.getDest(), sliceOp.getMixedOffsets(),
-        sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
   } else {
-    rewriter.setInsertionPoint(candidateSliceOp);
-    clonedInsertSliceOp =
-        cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp));
+    rewriter.setInsertionPoint(potentialSliceOps.back());
+  }
+
+  for (auto *candidateSliceOp : potentialSliceOps) {
+    if (auto sliceOp =
+            dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
+      allClonedInsertSliceOps.push_back(rewriter.create<tensor::InsertSliceOp>(
+          loc, sliceOp.getSource(), sliceOp.getDest(),
+          sliceOp.getMixedOffsets(), sliceOp.getMixedSizes(),
+          sliceOp.getMixedStrides()));
+    } else {
+      allClonedInsertSliceOps.push_back(
+          cast<tensor::InsertSliceOp>(rewriter.clone(*candidateSliceOp)));
+    }
   }
 
   // 5.a. Clone consumer op.
@@ -2056,24 +2141,34 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
 
   // 5.b. Replace all uses of the loop result with the result of the cloned
   // tensor.insert_slice.
-  OpOperand &operandToReplace = clonedConsumerOp->getOpOperand(operandNumber);
-  rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
-    operandToReplace.set(clonedInsertSliceOp.getResult());
-  });
+  for (const auto &it : llvm::enumerate(allClonedInsertSliceOps)) {
+    OpOperand &operandToReplace =
+        clonedConsumerOp->getOpOperand(potentialOperandResultNos[it.index()]);
+    rewriter.modifyOpInPlace(clonedConsumerOp, [&]() {
+      operandToReplace.set(it.value().getResult());
+    });
+  }
 
   // 6. Perform tiling of the cloned consumer and replace the operand at
   // `operandNumber` with the source of the cloned tensor.insert_slice op.
-  auto ossSliceOp =
-      cast<OffsetSizeAndStrideOpInterface>(clonedInsertSliceOp.getOperation());
+  auto ossSliceOp = cast<OffsetSizeAndStrideOpInterface>(
+      allClonedInsertSliceOps.front().getOperation());
   FailureOr<TilingResult> tileAndFuseResult =
       tensor::replaceInsertSliceWithTiledConsumer(
           rewriter, ossSliceOp, clonedConsumerOp->getOpOperand(operandNumber));
+
   if (failed(tileAndFuseResult)) {
     return failure();
   }
+
   auto tiledConsumerOp = cast<TilingInterface>(tileAndFuseResult->tiledOps[0]);
-  rewriter.replaceAllUsesWith(tiledConsumerOp->getOperand(operandNumber),
-                              clonedInsertSliceOp.getSource());
+
+  // 6b. Update the tiled consumer op with the new operands.
+  for (const auto &it : llvm::enumerate(allClonedInsertSliceOps)) {
+    rewriter.replaceAllUsesWith(
+        tiledConsumerOp->getOperand(potentialOperandResultNos[it.index()]),
+        it.value().getSource());
+  }
 
   // 7. Reconstruct [nested] loop with new inits.
   YieldTiledValuesFn newYieldValuesFn =
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
index a2871b30698c527..14b9ec504c1585e 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir
@@ -282,7 +282,7 @@ module {
         return %unpack : tensor<2048xf32>
     }
 }
-  
+
 module attributes {transform.with_named_sequence} {
     transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
         %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
@@ -343,7 +343,7 @@ module {
         return %unpack : tensor<2047xf32>
     }
 }
-  
+
 module attributes {transform.with_named_sequence} {
     transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
         %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
@@ -404,7 +404,7 @@ module {
         return %pack : tensor<4x32x16xf32>
     }
 }
-  
+
 module attributes {transform.with_named_sequence} {
     transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
         %slice_op = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
@@ -610,7 +610,7 @@ module attributes {transform.with_named_sequence} {
 // CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
 //      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
 //      CHECK:   %[[LOOP_RESULT:.*]]:3 = scf.for %[[IV1:.*]] = %[[C0]]
-// CHECK-SAME:       iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]]) 
+// CHECK-SAME:       iter_args(%[[FIRST_OUT_ARG:.*]] = %[[dest0]], %[[SECOND_OUT_ARG:.*]] = %[[dest0]], %[[THIRD_OUT_ARG:.*]] = %[[dest0]])
 // CHECK-SAME:   {
 //      CHECK:          %[[ADD_OUT_SLICE:.*]] = tensor.extract_slice %[[FIRST_OUT_ARG]][%[[IV1]], 0] [64, 256] [1, 1]
 //      CHECK:          %[[ADD_INS0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[IV1]], 0] [64, 256] [1, 1]
@@ -676,3 +676,127 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:   }
 //      CHECK:   %[[RES_SLICE:.+]] = tensor.insert_slice
 //      CHECK:   return %[[LOOP_RESULT]]#1, %[[RES_SLICE]]
+
+// -----
+
+module {
+ func.func @forall_producer_multiple_result_single_consumer(%arg2: tensor<64x64xf32>) -> tensor<64x64xf32> {
+   %c4 = arith.constant 4 : index
+   %c64 = arith.constant 64 : index
+   %c0 = arith.constant 0 : index
+   %1:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %arg2, %arg6 = %arg2) -> (tensor<64x64xf32>, tensor<64x64xf32>) {
+     %outs =  tensor.empty() : tensor<32x32xf32>
+     %extracted_slice = tensor.extract_slice %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<64x64xf32> to tensor<32x32xf32>
+     %3 = linalg.matmul ins(%extracted_slice, %extracted_slice : tensor<32x32xf32>, tensor<32x32xf32>) outs(%outs : tensor<32x32xf32>) -> tensor<32x32xf32>
+     scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg6[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+        tensor.parallel_insert_slice %extracted_slice into %arg5[%arg3, %arg4] [32, 32] [1, 1] : tensor<32x32xf32> into tensor<64x64xf32>
+     }
+   }
+   %final_out = tensor.empty() : tensor<64x64xf32>
+   %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#0, %1#1 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%final_out : tensor<64x64xf32>) -> tensor<64x64xf32>
+   return %2 : tensor<64x64xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+   transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+     %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+     %1:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+     %consumer, %fused_consumer = transform.test.fuse_consumer %1#0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+     transform.yield
+   }
+}
+
+// CHECK-LABEL: func.func @forall_producer_multiple_result_single_consumer(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<64x64xf32>
+
+// CHECK:     %[[INIT:.+]] = tensor.empty() : tensor<64x64xf32>
+// CHECK:     %[[LOOP_RESULT:.+]]:3 = scf.forall (%[[I:.+]], %[[J:.+]]) in (2, 2) shared_outs(%[[SHARED0:.+]] = %[[ARG0]], %[[SHARED1:.+]] = %[[ARG0]], %[[SHARED2:.+]] = %[[INIT]])
+
+// CHECK:       %[[TILE_INIT:.+]] = tensor.empty() : tensor<32x32xf32>
+// CHECK:       %[[EXTRACTED_SLICE:.+]] = tensor.extract_slice %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK:       %[[MATMUL:.+]] = linalg.matmul ins(%[[EXTRACTED_SLICE]], %[[EXTRACTED_SLICE]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[TILE_INIT]] : tensor<32x32xf32>)
+// CHECK:       %[[INSERTED_SLICE:.+]] = tensor.insert_slice %[[MATMUL]] into %[[SHARED1]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK:       %[[INSERTED_SLICE0:.+]] = tensor.insert_slice %[[EXTRACTED_SLICE]] into %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK:       %[[EXTRACTED_SLICE1:.+]] = tensor.extract_slice %[[SHARED2]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK:       %[[ADD:.+]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%[[EXTRACTED_SLICE]], %[[MATMUL]] : tensor<32x32xf32>, tensor<32x32xf32>) outs(%[[EXTRACTED_SLICE1]] : tensor<32x32xf32>)
+
+// CHECK:       scf.forall.in_parallel {
+// CHECK:         tensor.parallel_insert_slice %[[MATMUL]] into %[[SHARED1]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK:         tensor.parallel_insert_slice %[[EXTRACTED_SLICE]] into %[[SHARED0]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK:         tensor.parallel_insert_slice %[[ADD]] into %[[SHARED2]][%[[I]], %[[J]]] [32, 32] [1, 1]
+// CHECK:       }
+
+// CHECK:     return %[[LOOP_RESULT]]#2 : tensor<64x64xf32>
+
+
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+module {
+  func.func @for_producer_producing_multiple_result_single_consumer(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+    %c4 = arith.constant 4 : index
+    %c64 = arith.constant 64 : index
+    %c0 = arith.constant 0 : index
+    %1:2 = scf.for %arg3 = %c0 to %c64 step %c4 iter_args(%arg4 = %arg2, %arg5 = %arg2) -> (tensor<64xf32>, tensor<64xf32>) {
+      %extracted_slice = tensor.extract_slice %arg4[%arg3] [32] [1] : tensor<64xf32> to tensor<32xf32>
+      %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%arg0, %arg1 : tensor<32xf32>, tensor<32xf32>) outs(%extracted_slice : tensor<32xf32>) {
+        ^bb0(%in: f32, %in_16: f32, %out: f32):
+          %13 = arith.mulf %in, %in_16 : f32
+          %14 = arith.addf %out, %13 : f32
+          linalg.yield %14 : f32
+        } -> tensor<32xf32>
+      %4 = tensor.insert_slice %3 into %arg4[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+      %5 = tensor.insert_slice %3 into %arg5[%arg3] [32] [1] : tensor<32xf32> into tensor<64xf32>
+      scf.yield %5, %4 : tensor<64xf32>, tensor<64xf32>
+    }
+    %out_operand = tensor.empty() : tensor<64xf32>
+    %2 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%1#1, %1#0 : tensor<64xf32>, tensor<64xf32>) outs(%out_operand : tensor<64xf32>) -> tensor<64xf32>
+    return %2 : tensor<64xf32>
+  }
+  }
+
+  module attributes {transform.with_named_sequence} {
+    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+      %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+      %1:2 = transform.split_handle %0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+      %consumer, %fused_consumer = transform.test.fuse_consumer %1#0 : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+      transform.yield
+    }
+  }
+
+// CHECK-LABEL: func.func @for_producer_producing_multiple_result_single_consumer(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<32xf32>,
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<32xf32>,
+// CHECK-SAME:   %[[ARG2:[a-zA-Z0-9_]+]]: tensor<64xf32>
+
+// CHECK:       %[[C4:.+]] = arith.constant 4 : index
+// CHECK:       %[[C64:.+]] = arith.constant 64 : index
+// CHECK:       %[[C0:.+]] = arith.constant 0 : index
+// CHECK:       %[[INIT:.+]] = tensor.empty() : tensor<64xf32>
+
+// CHECK:       %[[LOOP_RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C64]] step %[[C4]]
+// CHECK-SAME:        iter_args(%[[ITER0:.+]] = %[[ARG2]], %[[ITER1:.+]] = %[[ARG2]], %[[ITER2:.+]] = %[[INIT]])
+// CHECK-SAME:         -> (tensor<64xf32>, tensor<64xf32>, tensor<64xf32>)
+
+// CHECK:           %[[EXTRACT_SLICE:.+]] = tensor.extract_slice %[[ITER0]][%[[IV]]] [32] [1]
+// CHECK:           %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:        ins(%[[ARG0]], %[[ARG1]] : tensor<32xf32>, tensor<32xf32>)
+// CHECK-SAME:        outs(%[[EXTRACT_SLICE]] : tensor<32xf32>)
+// CHECK:             ^{{.*}}(%[[IN0:.+]]: f32, %[[IN1:.+]]: f32, %[[OUT:.+]]: f32):
+// CHECK:             %[[MUL:.+]] = arith.mulf %[[IN0]], %[[IN1]] : f32
+// CHECK:             %[[ADD:.+]] = arith.addf %[[OUT]], %[[MUL]] : f32
+// CHECK:           linalg.yield %[[ADD]] : f32
+
+// CHECK:       %[[INSERT_SLICE0:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER0]][%[[IV]]] [32] [1]
+// CHECK:       %[[INSERT_SLICE1:.+]] = tensor.insert_slice %[[GENERIC]] into %[[ITER1]][%[[IV]]] [32] [1]
+// CHECK:       %[[EXTRACT_SLICE2:.+]] = tensor.extract_slice %[[ITER2]][%[[IV]]] [32] [1]
+// CHECK:       %[[BINARY:.+]] = linalg.elemwise_binary {fun = #linalg.binary_fn<add>}
+// CHECK-SAME:      ins(%[[GENERIC]], %[[GENERIC]] : tensor<32xf32>, tensor<32xf32>)
+// CHECK-SAME:      outs(%[[EXTRACT_SLICE2]] : tensor<32xf32>)
+// CHECK:       %[[INSERT_SLICE2:.+]] = tensor.insert_slice %[[BINARY]] into %[[ITER2]][%[[IV]]] [32] [1]
+
+// CHECK:       scf.yield %[[INSERT_SLICE1]], %[[INSERT_SLICE0]], %[[INSERT_SLICE2]]
+
+// CHECK:     return %[[LOOP_RESULT]]#2 : tensor<64xf32>



More information about the Mlir-commits mailing list