[Mlir-commits] [mlir] [mlir][scf] Extend fuse producer to multi-level candidates case (PR #97803)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Sep 3 01:19:58 PDT 2024


https://github.com/Yun-Fly updated https://github.com/llvm/llvm-project/pull/97803

>From f7dd9d8ff798cc083bc7a5b41a65499e730c814b Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Fri, 5 Jul 2024 02:12:59 -0700
Subject: [PATCH 1/2] extend fuse producer to multi-level extractSliceOp

---
 .../SCF/Transforms/TileUsingInterface.h       |   4 +
 .../SCF/Transforms/TileUsingInterface.cpp     | 149 +++++++++++++++++-
 .../tile-and-fuse-producer.mlir               |  86 ++++++++++
 .../TestTilingInterfaceTransformOps.cpp       |  50 ++++++
 .../TestTilingInterfaceTransformOps.td        |  19 +++
 5 files changed, 303 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/Interfaces/TilingInterface/tile-and-fuse-producer.mlir

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 1f21af6d6a29ac..76fdda3645a017 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -157,11 +157,15 @@ struct SCFFuseProducerOfSliceResult {
   Value tiledAndFusedProducer; // Tile and fused producer value.
   SmallVector<Operation *> tiledOps;
 };
+
 std::optional<SCFFuseProducerOfSliceResult>
 tileAndFuseProducerOfSlice(RewriterBase &rewriter,
                            tensor::ExtractSliceOp candidateSliceOp,
                            MutableArrayRef<LoopLikeOpInterface> loops);
 
+std::optional<SCFFuseProducerOfSliceResult>
+tileAndFuseProducerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp);
+
 /// Reconstruct the fused producer from within the tiled-and-fused code. Based
 /// on the slice of the producer computed in place it is possible that within
 /// the loop nest same slice of the producer is computed multiple times. It is
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e404c01010a325..ef4235c6015ad5 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1068,12 +1068,12 @@ getUntiledProducerFromSliceSource(OpOperand *source,
   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
 }
 
-/// Implementation of fusing producer of a single slice by computing the
+/// Basic implementation of fusing producer of a single slice by computing the
 /// slice of the producer in-place.
-std::optional<scf::SCFFuseProducerOfSliceResult>
-mlir::scf::tileAndFuseProducerOfSlice(
-    RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
-    MutableArrayRef<LoopLikeOpInterface> loops) {
+static std::optional<scf::SCFFuseProducerOfSliceResult>
+tileAndFuseProducerOfSliceImpl(RewriterBase &rewriter,
+                               tensor::ExtractSliceOp candidateSliceOp,
+                               MutableArrayRef<LoopLikeOpInterface> loops) {
   // 1. Get the producer of the source (potentially walking through
   // `iter_args` of nested `scf.for`)
   auto [fusableProducer, destinationInitArg] =
@@ -1185,6 +1185,145 @@ mlir::scf::tileAndFuseProducerOfSlice(
                                            tileAndFuseResult->tiledOps};
 }
 
+/// Get the real producer of candidate ExtractSliceOp `%4` below, tracing its
+/// source through loop iter_args and enclosing extract_slice ops:
+///
+/// ```
+/// %0 = producer
+/// %1 = scf.for(%arg1 = %0)
+///   %2 = extract %arg1
+///   %3 = scf.for(%arg2 = %2)
+///      %4 = extract %arg2
+/// ```
+///
+/// `backwardSlice` receives every visited slice (here %4 then %2). Fails if
+/// `candidateSliceOp` is not an extract_slice, a block argument is not tied to
+/// a loop init, or the slice chain recurses deeper than `maxDepth`.
+static FailureOr<OpResult> getRealProducerFromExtractSliceOp(
+    Operation *candidateSliceOp,
+    SmallVector<tensor::ExtractSliceOp> &backwardSlice, int curDepth = 0,
+    int maxDepth = 5) {
+  auto extractOp = dyn_cast<tensor::ExtractSliceOp>(candidateSliceOp);
+  // Bound the recursion depth to avoid a stack overflow.
+  if (!extractOp || curDepth > maxDepth)
+    return failure();
+
+  backwardSlice.push_back(extractOp);
+  Value rootSource = extractOp.getSource();
+
+  // Walk outwards through loop iter_args until the source is an op result.
+  while (auto iterArg = dyn_cast<BlockArgument>(rootSource)) {
+    auto outerLoop = dyn_cast_or_null<LoopLikeOpInterface>(
+        iterArg.getOwner()->getParentOp());
+    // Only iter_args tied to a loop init can be traced further; any other
+    // block argument (e.g. a function argument) has no producer to fuse.
+    OpOperand *tiedInit =
+        outerLoop ? outerLoop.getTiedLoopInit(iterArg) : nullptr;
+    if (!tiedInit)
+      return failure();
+    rootSource = tiedInit->get();
+  }
+
+  // The source is itself a larger slice one level up: recurse, forwarding
+  // `maxDepth` so a caller-provided limit holds at every level.
+  if (auto sliceOp = rootSource.getDefiningOp<tensor::ExtractSliceOp>())
+    return getRealProducerFromExtractSliceOp(sliceOp, backwardSlice,
+                                             curDepth + 1, maxDepth);
+  return cast<OpResult>(rootSource);
+}
+
+/// Collect `loop` together with its enclosing loops, walking outwards
+/// iteratively for as long as each enclosing loop satisfies `pred`.
+///
+/// @param loop: target loop; it is always included, so if no enclosing loop
+///              qualifies the result is just `{loop}`. Must be non-null.
+/// @param pred: predicate evaluated on each enclosing loop; the walk stops at
+///              the first loop where it fails, or at a non-loop parent op.
+/// @return the collected loops, sorted from outermost to innermost.
+///
+/// E.g.
+///
+/// ```
+///  %0 = scf.for()
+///    %1 = scf.for()
+///      %2 = scf.for()
+/// ```
+///
+/// If `%2 = scf.for` is given without specific prediction function, this
+/// function will return three nest loops: %0 + %1 + %2.
+static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
+    LoopLikeOpInterface loop,
+    const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
+  SmallVector<LoopLikeOpInterface> nestLoops = {loop};
+  auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp());
+  while (outerLoop && succeeded(pred(outerLoop))) {
+    nestLoops.push_back(outerLoop);
+    outerLoop = dyn_cast<LoopLikeOpInterface>(outerLoop->getParentOp());
+  }
+  // Reverse the inner-to-outer collection order: outermost loop first.
+  return {nestLoops.rbegin(), nestLoops.rend()};
+}
+
+/// Enhanced version for basic implementation of fusing producer, which can deal
+/// with multi-level candidates. E.g.
+///
+/// ```
+/// %0 = untiled_producer
+/// %1 = scf.for(%arg1 = %0)
+///   %2 = tensor.extract_slice %arg1
+///   %3 = scf.for(%arg2 = %2)
+///      %4 = tensor.extract_slice %arg2
+///      %5 = tiled_consumer ins(%4)
+/// ```
+///
+/// The producer is fused one level at a time, outermost slice first (here at
+/// %2, then %4). NOTE: every slice must be nested in a LoopLikeOpInterface op.
+std::optional<scf::SCFFuseProducerOfSliceResult>
+mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
+                                      Operation *candidateSliceOp) {
+  SmallVector<tensor::ExtractSliceOp> backwardSlice;
+  if (failed(
+          getRealProducerFromExtractSliceOp(candidateSliceOp, backwardSlice))) {
+    return std::nullopt;
+  }
+
+  std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
+  // reverse from outer to inner
+  std::reverse(backwardSlice.begin(), backwardSlice.end());
+  // Apply `tileAndFuseProducerOfSliceImpl` once per slice level.
+  for (auto &&[index, sliceOp] : llvm::enumerate(backwardSlice)) {
+    // get nest loops between next candidate sliceOp and tiled producer.
+    auto whileProducerOutOfLoopBlock =
+        [&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult {
+      if (fuseProducerResult) {
+        Block &body = loop->getRegion(0).front();
+        if (fuseProducerResult->tiledAndFusedProducer.getDefiningOp()
+                ->getBlock() == &body)
+          return failure();
+      }
+      return success();
+    };
+    SmallVector<LoopLikeOpInterface> outerLoops =
+        getOuterNestLoopsWhile(sliceOp->getParentOfType<LoopLikeOpInterface>(),
+                               whileProducerOutOfLoopBlock);
+    fuseProducerResult =
+        tileAndFuseProducerOfSliceImpl(rewriter, sliceOp, outerLoops);
+    if (!fuseProducerResult) {
+      return std::nullopt;
+    }
+  }
+  return fuseProducerResult;
+}
+
+/// Public entry point: fuse the producer of a single slice by computing the
+/// producer slice in place (forwards to `tileAndFuseProducerOfSliceImpl`).
+std::optional<scf::SCFFuseProducerOfSliceResult>
+mlir::scf::tileAndFuseProducerOfSlice(
+    RewriterBase &rewriter, tensor::ExtractSliceOp candidateSliceOp,
+    MutableArrayRef<LoopLikeOpInterface> loops) {
+  return tileAndFuseProducerOfSliceImpl(rewriter, candidateSliceOp, loops);
+}
+
 /// Reconstruct the fused producer from within the tiled-and-fused code.
 LogicalResult mlir::scf::yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-producer.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-producer.mlir
new file mode 100644
index 00000000000000..ef1c6952a55e1a
--- /dev/null
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-producer.mlir
@@ -0,0 +1,86 @@
+// RUN: mlir-opt --transform-interpreter --cse --split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 * 128)>
+module {
+  func.func @gemm_fill_fusion_multi_level_extract_slice(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
+    %c0 = arith.constant 0 : index
+    %c64 = arith.constant 64 : index
+    %c128 = arith.constant 128 : index
+    %cst = arith.constant 0.000000e+00 : f32
+    %dest0 = tensor.empty() : tensor<256x256xf32>
+    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
+    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
+      %iv0 = affine.apply #map(%arg3)
+      %iv1 = affine.apply #map(%arg4)
+      %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
+      %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
+      %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
+      %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
+        %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
+          %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
+          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
+          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
+          %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
+          %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
+          scf.yield %insert_slice : tensor<128x128xf32>
+        }
+        scf.yield %3 : tensor<128x128xf32>
+      }
+      scf.forall.in_parallel {
+         tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
+      }
+    }
+    return %1 : tensor<256x256xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %slice = transform.get_producer_of_operand %matmul[2]
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_producer %slice
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+//      CHECK: #[[MAP0:.*]] = affine_map<(d0) -> (d0 * 128)>
+//      CHECK: func.func @gemm_fill_fusion_multi_level_extract_slice(
+// CHECK-SAME:     %[[ARG0:[a-zA-Z0-9]+]]: tensor<256x512xf32>
+// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: tensor<512x256xf32>
+// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: tensor<256x256xf32>
+//      CHECK:   %[[C0:.*]] = arith.constant 0 : index
+//      CHECK:   %[[dest0:.*]] = tensor.empty() : tensor<256x256xf32>
+//      CHECK:   %[[FORALL_RESULT:.*]] = scf.forall (%[[IV1:.*]], %[[IV2:.*]]) in (2, 2)
+// CHECK-SAME:      shared_outs(%[[INIT_ARG0:.*]] = %[[dest0]])
+// CHECK-SAME:   {
+//      CHECK:      %[[AFFINE_IV1:.*]] = affine.apply #[[MAP0]](%[[IV1]])
+//      CHECK:      %[[AFFINE_IV2:.*]] = affine.apply #[[MAP0]](%[[IV2]])
+//      CHECK:      %[[FILL_OUT_SLICE0:.*]] = tensor.extract_slice %[[INIT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:      %[[INPUT_SLICE0:.*]] = tensor.extract_slice %[[ARG0]][%[[AFFINE_IV1]], 0] [128, 512] [1, 1]
+//      CHECK:      %[[WEIGHT_SLICE0:.*]] = tensor.extract_slice %[[ARG1]][0, %[[AFFINE_IV2]]] [512, 128] [1, 1]
+//      CHECK:      %[[LOOP_RESULT1:.*]] = scf.for %[[IV3:.*]] = %[[C0]]
+// CHECK-SAME:          iter_args(%[[INIT_ARG1:.*]] = %[[FILL_OUT_SLICE0]])
+// CHECK-SAME:      {
+//      CHECK:          %[[LOOP_RESULT2:.*]] = scf.for %[[IV4:.*]] = %[[C0]]
+// CHECK-SAME:            iter_args(%[[INIT_ARG2:.*]] = %[[INIT_ARG1]])
+// CHECK-SAME:          {
+//      CHECK:            %[[FILL_OUT_SLICE1:.*]] = tensor.extract_slice %[[INIT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            %[[TILED_FILL_OUT:.*]] = linalg.fill
+// CHECK-SAME:                  outs(%[[FILL_OUT_SLICE1]] :
+//      CHECK:            %[[INPUT_SLICE1:.*]] = tensor.extract_slice %[[INPUT_SLICE0]][%[[IV3]], 0] [64, 512] [1, 1]
+//      CHECK:            %[[WEIGHT_SLICE1:.*]] = tensor.extract_slice %[[WEIGHT_SLICE0]][0, %[[IV4]]] [512, 64] [1, 1]
+//      CHECK:            %[[TILED_MAT_OUT:.*]] = linalg.matmul
+// CHECK-SAME:                  outs(%[[TILED_FILL_OUT]] :
+//      CHECK:            %[[INSERT_MAT:.*]] = tensor.insert_slice %[[TILED_MAT_OUT]] into %[[INIT_ARG2]][%[[IV3]], %[[IV4]]] [64, 64] [1, 1]
+//      CHECK:            scf.yield %[[INSERT_MAT]] :
+//      CHECK:          }
+//      CHECK:          scf.yield %[[LOOP_RESULT2]] :
+//      CHECK:      }
+//      CHECK:      scf.forall.in_parallel {
+//      CHECK:          tensor.parallel_insert_slice %[[LOOP_RESULT1]] into %[[INIT_ARG0]][%[[AFFINE_IV1]], %[[AFFINE_IV2]]] [128, 128] [1, 1]
+//      CHECK:       }
+//      CHECK:   }
+//      CHECK:   return %[[FORALL_RESULT]] :
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
index 7aa7b58433f36c..b4dad98e2399c8 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -160,6 +160,56 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter,
                         : DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// TestFuseProducerOp
+//===----------------------------------------------------------------------===//
+
+/// Fuse the producer of every payload op (expected to be a tensor
+/// extract_slice) and record its original producer and the first tiled op.
+template <typename Range>
+static LogicalResult
+applyFuseProducer(RewriterBase &rewriter, Operation *transformOp,
+                  Range &&payloadOps, TransformResults &transformResults) {
+  SmallVector<Operation *> originalProducerOps;
+  SmallVector<Operation *> fusedProducerOps;
+
+  for (Operation *target : payloadOps) {
+    rewriter.setInsertionPoint(target);
+
+    std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResults =
+        scf::tileAndFuseProducerOfSlice(rewriter, target);
+    // Bail out if fusion failed or produced no tiled op to report below.
+    if (!fuseProducerResults || fuseProducerResults->tiledOps.empty())
+      return failure();
+
+    // Report back the relevant handles to the transform op.
+    originalProducerOps.push_back(fuseProducerResults->origProducer.getOwner());
+    fusedProducerOps.push_back(fuseProducerResults->tiledOps[0]);
+  }
+
+  transformResults.set(transformOp->getOpResult(0), originalProducerOps);
+  transformResults.set(transformOp->getOpResult(1), fusedProducerOps);
+  return success();
+}
+
+DiagnosedSilenceableFailure
+transform::TestFuseProducerOp::apply(TransformRewriter &rewriter,
+                                     TransformResults &transformResults,
+                                     TransformState &state) {
+  auto targets = state.getPayloadOps(getTarget());
+  if (failed(applyFuseProducer(rewriter, getOperation(), targets,
+                               transformResults)))
+    return DiagnosedSilenceableFailure::definiteFailure();
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestFuseProducerOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTargetMutable(), effects);
+  producesHandle(getOperation()->getResults(), effects);
+  modifiesPayload(effects);
+}
+
 //===----------------------------------------------------------------------===//
 // TestFuseConsumerOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
index d55d746bd6aa90..6e73478c35c4af 100644
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -49,6 +49,25 @@ def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
   }];
 }
 
+def TestFuseProducerOp : Op<Transform_Dialect, "test.fuse_producer",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Fuses the producer of each tensor.extract_slice op pointed to by the target
+    handle, tracing its source through nested slices and loop iter_args.
+    Returns handles to the original producer and to the fused producer.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$producer,
+                      TransformHandleTypeInterface:$fused_producer);
+
+  let assemblyFormat = [{
+    $target attr-dict `:` functional-type(operands, results)
+  }];
+}
+
 def TestFuseConsumerOp : Op<Transform_Dialect, "test.fuse_consumer",
        [DeclareOpInterfaceMethods<TransformOpInterface>,
         DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,

>From 23796bfe3d1483950cf9175ff1c87978b6eb0bf1 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Tue, 6 Aug 2024 21:25:05 -0700
Subject: [PATCH 2/2] add `isForOpYieldResultOfInnerLoop` check

---
 .../SCF/Transforms/TileUsingInterface.cpp     | 47 +++++++++++++------
 1 file changed, 33 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index ef4235c6015ad5..1b8b5a3e7f3dba 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -1251,9 +1251,9 @@ static FailureOr<OpResult> getRealProducerFromExtractSliceOp(
 ///
 /// If `%2 = scf.for` is given without specific prediction function, this
 /// function will return three nest loops: %0 + %1 + %2.
-static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
-    LoopLikeOpInterface loop,
-    const std::function<LogicalResult(LoopLikeOpInterface)> &pred) {
+static SmallVector<LoopLikeOpInterface>
+getOuterNestLoopsWhile(LoopLikeOpInterface loop,
+                       function_ref<LogicalResult(LoopLikeOpInterface)> pred) {
   SmallVector<LoopLikeOpInterface> nestLoops = {loop};
   auto outerLoop = dyn_cast<LoopLikeOpInterface>(loop->getParentOp());
   while (outerLoop && succeeded(pred(outerLoop))) {
@@ -1264,6 +1264,21 @@ static SmallVector<LoopLikeOpInterface> getOuterNestLoopsWhile(
   return {nestLoops.rbegin(), nestLoops.rend()};
 }
 
+/// Check if `loop` is an scf.for with exactly one inner loop, which sits right
+/// before the scf.yield and whose results the yield forwards unchanged.
+static LogicalResult isForOpYieldResultOfInnerLoop(LoopLikeOpInterface loop) {
+  auto forOp = dyn_cast<scf::ForOp>(loop.getOperation());
+  if (!forOp)
+    return failure();
+  Block *body = forOp.getBody();
+  Operation *yieldOp = body->getTerminator();
+  Operation *innerLoop = yieldOp->getPrevNode();
+  if (!isa_and_nonnull<LoopLikeOpInterface>(innerLoop) ||
+      !llvm::hasSingleElement(body->getOps<LoopLikeOpInterface>()))
+    return failure();
+  return success(llvm::equal(yieldOp->getOperands(), innerLoop->getResults()));
+}
+
 /// Enhanced version for basic implementation of fusing producer, which can deal
 /// with multi-level candidates. E.g.
 ///
@@ -1282,10 +1297,10 @@ std::optional<scf::SCFFuseProducerOfSliceResult>
 mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
                                       Operation *candidateSliceOp) {
   SmallVector<tensor::ExtractSliceOp> backwardSlice;
-  if (failed(
-          getRealProducerFromExtractSliceOp(candidateSliceOp, backwardSlice))) {
+  FailureOr<OpResult> realProducer =
+      getRealProducerFromExtractSliceOp(candidateSliceOp, backwardSlice);
+  if (failed(realProducer))
     return std::nullopt;
-  }
 
   std::optional<scf::SCFFuseProducerOfSliceResult> fuseProducerResult;
   // reverse from outer to inner
@@ -1294,14 +1309,18 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
   for (auto &&[index, sliceOp] : llvm::enumerate(backwardSlice)) {
     // get nest loops between next candidate sliceOp and tiled producer.
     auto whileProducerOutOfLoopBlock =
-        [&fuseProducerResult](LoopLikeOpInterface loop) -> LogicalResult {
-      if (fuseProducerResult) {
-        Block &body = loop->getRegion(0).front();
-        if (fuseProducerResult->tiledAndFusedProducer.getDefiningOp()
-                ->getBlock() == &body)
-          return failure();
-      }
-      return success();
+        [&fuseProducerResult,
+         &realProducer](LoopLikeOpInterface loop) -> LogicalResult {
+      // ensure that all surrounding outer loops are just yielding the result of
+      // the inner loops.
+      if (failed(isForOpYieldResultOfInnerLoop(loop)))
+        return failure();
+      Operation *originalOp =
+          fuseProducerResult
+              ? fuseProducerResult->tiledAndFusedProducer.getDefiningOp()
+              : realProducer->getDefiningOp();
+      Block &body = loop->getRegion(0).front();
+      return success(originalOp->getBlock() != &body);
     };
     SmallVector<LoopLikeOpInterface> outerLoops =
         getOuterNestLoopsWhile(sliceOp->getParentOfType<LoopLikeOpInterface>(),



More information about the Mlir-commits mailing list