[Mlir-commits] [mlir] [mlir][scf] Extend option to yield replacement for multiple results case (PR #93144)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jun 5 18:48:35 PDT 2024


https://github.com/Yun-Fly updated https://github.com/llvm/llvm-project/pull/93144

>From 4377bf0ce4e2108e1de01ca2b1c0a7d455068264 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Wed, 22 May 2024 23:15:36 -0700
Subject: [PATCH 1/3] yield replacement for multiple results

---
 .../SCF/Transforms/TileUsingInterface.h       |   6 +-
 .../mlir/Interfaces/TilingInterface.td        |  19 +++
 .../Linalg/Transforms/TilingInterfaceImpl.cpp |  35 +++--
 .../SCF/Transforms/TileUsingInterface.cpp     | 148 +++++++++++++-----
 .../tile-fuse-and-yield-using-interface.mlir  |  62 ++++++++
 5 files changed, 220 insertions(+), 50 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index dac79111af3c9..807379d99b599 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -191,10 +191,14 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
 /// where `%0` had other uses as well. If not reconstructed from within the loop
 /// body, uses of `%0` could not be replaced, making it still live and the
 /// fusion immaterial.
+///
+/// `yieldResultNumber` selects which results of the fused producer to yield.
+/// If not given, all results of the fused producer are yielded.
 LogicalResult yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
-    MutableArrayRef<LoopLikeOpInterface> loops);
+    MutableArrayRef<LoopLikeOpInterface> loops,
+    std::optional<ArrayRef<unsigned>> yieldResultNumber = std::nullopt);
 
 /// Transformation information returned after tile and fuse.
 struct SCFTileAndFuseResult {
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index bc83c81c0086c..95e204cf5e1d2 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -115,6 +115,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the tile of the iteration domain that corresponds
+          to the given tile of the result identified by `resultNumber`.
+        }],
+        /*retType=*/"::mlir::LogicalResult",
+        /*methodName=*/"getIterationDomainTileFromResultTile",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$resultNumber,
+          "ArrayRef<OpFoldResult> ":$resultOffsets,
+          "ArrayRef<OpFoldResult> ":$resultSizes,
+          "SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to generate the code that produces a tile of the result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c3ab3cecfada7..0dbfe93fc5650 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -215,26 +215,39 @@ struct LinalgOpTilingInterface
     return success();
   }
 
-  FailureOr<TilingResult>
-  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
-                          ArrayRef<OpFoldResult> offsets,
-                          ArrayRef<OpFoldResult> sizes) const {
+  LogicalResult getIterationDomainTileFromResultTile(
+      Operation *op, OpBuilder &b, unsigned resultNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
     auto linalgOp = cast<LinalgOp>(op);
 
-    // Check that the indexing map used for the output is a projected
+    // Check that the indexing map used for the result is a projected
     // permutation. This could be relaxed with a more general approach that can
-    // map the offsets and sizes from the result to iteration space tiles
+    // map the offsets and sizes of the result tile to iteration space tiles
     // (filling in full extent for dimensions not used to access the result).
     AffineMap indexingMap =
         linalgOp.getIndexingMapMatchingResult(op->getResult(resultNumber));
     if (!indexingMap.isProjectedPermutation()) {
-      return op->emitOpError(
-          "unhandled tiled implementation generation when result is not "
-          "accessed using a permuted projection");
+      return op->emitError()
+             << "unhandled get iter domain position when operand is not "
+                "accessed using a permuted projection";
     }
-    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+
     getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
-                           mappedOffsets, mappedSizes);
+                           iterDomainOffsets, iterDomainSizes);
+    return success();
+  }
+
+  FailureOr<TilingResult>
+  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+                          ArrayRef<OpFoldResult> offsets,
+                          ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    if (failed(getIterationDomainTileFromResultTile(
+            op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+      return failure();
+    }
     auto tilingInterfaceOp = cast<TilingInterface>(op);
     FailureOr<TilingResult> tilingResult =
         tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..e8cf7298be681 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -940,49 +940,114 @@ mlir::scf::tileAndFuseProducerOfSlice(
 LogicalResult mlir::scf::yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
-    MutableArrayRef<LoopLikeOpInterface> loops) {
+    MutableArrayRef<LoopLikeOpInterface> loops,
+    std::optional<ArrayRef<unsigned>> yieldResultNumber) {
   if (loops.empty())
     return success();
 
-  OpResult fusableProducer = fusedProducerInfo.origProducer;
-  Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
-  FailureOr<Value> initValue = tensor::getOrCreateDestination(
-      rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
-  if (succeeded(initValue)) {
-
-    YieldTiledValuesFn newYieldValuesFn =
-        [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
-            ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
-            SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
-            SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
-        -> LogicalResult {
-      OpBuilder::InsertionGuard g(innerRewriter);
-      if (auto tiledDestStyleOp =
-              tiledAndFusedProducer
-                  .getDefiningOp<DestinationStyleOpInterface>()) {
-        rewriter.setInsertionPoint(tiledDestStyleOp);
-        Value newRegionArg = newRegionIterArgs.back();
+  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
+            *tiledOwner = fusedProducerInfo.tiledOps[0];
+
+  Location loc = originalOwner->getLoc();
+  // a. collect all init Value to be appended
+  ArrayRef<unsigned> initNumberList =
+      yieldResultNumber ? yieldResultNumber.value()
+                        : llvm::to_vector(llvm::seq<unsigned>(
+                              0, originalOwner->getNumResults()));
+  SmallVector<Value> initValueList;
+  for (const auto &resultNumber : initNumberList) {
+    FailureOr<Value> initValue = tensor::getOrCreateDestination(
+        rewriter, loc, originalOwner->getResult(resultNumber));
+    if (succeeded(initValue)) {
+      initValueList.push_back(initValue.value());
+    } else {
+      return failure();
+    }
+  }
+
+  YieldTiledValuesFn newYieldValuesFn =
+      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+    OpBuilder::InsertionGuard g(innerRewriter);
+
+    // get sliceOp tile information
+    SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
+                              sliceSizes = sliceOp.getMixedSizes();
+
+    // expect all strides of sliceOp being 1
+    if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
+          return !isConstantIntValue(ofr, 1);
+        }))
+      return failure();
+
+    unsigned sliceResultNumber =
+        fusedProducerInfo.origProducer.getResultNumber();
+
+    auto tilableOp = cast<TilingInterface>(originalOwner);
+    // b. get iterDomain Offset and Sizes based on sliceOp tile
+    SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
+    // skip tensor.pack/unpack/pad, which expects single opResult
+    if (tilableOp->getNumResults() > 1 &&
+        failed(tilableOp.getIterationDomainTileFromResultTile(
+            rewriter, sliceResultNumber, sliceOffset, sliceSizes,
+            iterDomainOffset, iterDomainSizes))) {
+      return failure();
+    }
+
+    // c. calculate offsets and sizes info of all OpResults respectively based
+    // on iteration Domain Tile
+    SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
+    for (const auto &resultNumber : initNumberList) {
+      if (resultNumber == fusedProducerInfo.origProducer.getResultNumber()) {
+        offsetList.push_back(sliceOffset);
+        sizesList.push_back(sliceSizes);
+      } else {
+        assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
+        // infer result tile according to the iteration domain tile
+        SmallVector<OpFoldResult> offset, sizes;
+        if (failed(tilableOp.getResultTilePosition(
+                rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
+                offset, sizes))) {
+          return failure();
+        }
+        offsetList.push_back(offset);
+        sizesList.push_back(sizes);
+      }
+    }
+
+    // d. create `extract_slice` for `iter_args` for DPS operation if necessary
+    if (auto tiledDestStyleOp =
+            dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
+      rewriter.setInsertionPoint(tiledDestStyleOp);
+      for (const auto &&[index, newRegionArg] :
+           llvm::enumerate(newRegionIterArgs)) {
         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
-            sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
-            sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
-        unsigned resultNumber = fusableProducer.getResultNumber();
+            loc, newRegionArg, offsetList[index], sizesList[index],
+            SmallVector<OpFoldResult>(offsetList[index].size(),
+                                      rewriter.getIndexAttr(1)));
+        unsigned resultNumber = initNumberList[index];
         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
           tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
         });
       }
-      Block *block = rewriter.getInsertionPoint()->getBlock();
-      rewriter.setInsertionPoint(block->getTerminator());
-      tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer);
-      tiledOffset.emplace_back(sliceOp.getMixedOffsets());
-      tiledSizes.emplace_back(sliceOp.getMixedSizes());
-      return success();
-    };
+    }
 
-    return addInitOperandsToLoopNest(rewriter, loops,
-                                     SmallVector<Value>{initValue.value()},
-                                     newYieldValuesFn);
-  }
-  return success();
+    // e. prepare tiled offset and sizes for later `insert_slice` creation by
+    // caller
+    Block *block = rewriter.getInsertionPoint()->getBlock();
+    rewriter.setInsertionPoint(block->getTerminator());
+    for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
+      tiledResult.push_back(tiledOwner->getResult(resultNumber));
+      tiledOffset.emplace_back(offsetList[index]);
+      tiledSizes.emplace_back(sizesList[index]);
+    }
+    return success();
+  };
+
+  return addInitOperandsToLoopNest(rewriter, loops, initValueList,
+                                   newYieldValuesFn);
 }
 
 /// Implementation of tile consumer and fuse producer greedily.
@@ -1072,14 +1137,21 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
       continue;
 
     if (yieldReplacement) {
+      // Reconstruct and yield all results of fusableProducerOp by default. The
+      // caller can specify which ones to yield through the optional argument
+      // `yieldResultNumber` of `yieldReplacementForFusedProducer`.
+      Operation *fusableProducerOp = fusableProducer.getOwner();
       if (failed(yieldReplacementForFusedProducer(
               rewriter, candidateSliceOp, fusedResult.value(), loops))) {
         return rewriter.notifyMatchFailure(
-            fusableProducer.getOwner(), "failed to replacement value for this "
-                                        "oepration from within the tiled loop");
+            fusableProducerOp, "failed to replacement value for this "
+                               "operation from within the tiled loop");
+      }
+      for (const auto &result : fusableProducerOp->getResults()) {
+        origValToResultNumber[result] =
+            loops.front()->getNumResults() -
+            (fusableProducerOp->getNumResults() - result.getResultNumber());
       }
-      origValToResultNumber[fusableProducer] =
-          loops.front()->getNumResults() - 1;
     }
 
     if (Operation *tiledAndFusedOp =
diff --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
index 7356c11e85ac0..3c0ada9d2cabc 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
@@ -58,3 +58,65 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:     %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
 //      CHECK:     scf.yield %[[INSERT0]], %[[INSERT1]]
 //      CHECK:   return %[[RESULT]]#1, %[[RESULT]]#0
+
+// -----
+
+func.func @multiple_outputs_fusion_yield_all(%lhs0: tensor<32x32xf32>,
+                       %rhs0: tensor<32x32xf32>, %init0: tensor<32x32xf32>, %init1: tensor<32x32xf32>, 
+                       %rhs1: tensor<32x32xf32>, %init2: tensor<32x32xf32>) 
+                       -> (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) {
+  %out0, %out1 = linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (j, i)>],
+    iterator_types = ["parallel", "parallel"]
+  }
+  ins(%lhs0, %rhs0: tensor<32x32xf32>, tensor<32x32xf32>)
+  outs(%init0, %init1: tensor<32x32xf32>, tensor<32x32xf32>) {
+  ^bb0(%0: f32, %1: f32, %2: f32, %3: f32):
+    %4 = arith.mulf %0, %1 : f32
+    %5 = arith.addf %0, %1 : f32
+    linalg.yield %4, %5: f32, f32
+  } -> (tensor<32x32xf32>, tensor<32x32xf32>)
+
+  %out3 = linalg.add ins(%out0, %rhs1: tensor<32x32xf32>, tensor<32x32xf32>) outs(%init2: tensor<32x32xf32>) -> tensor<32x32xf32>
+
+  return %out0, %out1, %out3 : tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %add = transform.structured.match ops{["linalg.add"]} in %arg0
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_and_yield %add [16]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @multiple_outputs_fusion_yield_all(
+// CHECK-SAME:     %[[LHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[RHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:     %[[INIT0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<32x32xf32>)
+//      CHECK:   %[[RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9]+]] =
+// CHECK-SAME:       iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT2]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
+//  CHECK-DAG:     %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
+//  CHECK-DAG:     %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][%[[IV]], 0]
+//  CHECK-DAG:     %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
+//  CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG2]][0, %[[IV]]]
+//      CHECK:     %[[GENERIC_TILE:.+]]:2 = linalg.generic
+// CHECK-SAME:         ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
+// CHECK-SAME:         outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
+//  CHECK-DAG:     %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][%[[IV]], 0]
+//  CHECK-DAG:     %[[INIT2_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
+//      CHECK:     %[[ADD_TILE:.+]] = linalg.add
+// CHECK-SAME:         ins(%[[GENERIC_TILE]]#0, %[[RHS1_TILE]] :
+// CHECK-SAME:         outs(%[[INIT2_TILE]] :
+//      CHECK:     %[[INSERT0:.+]] = tensor.insert_slice %[[ADD_TILE]] into %[[ITERARG0]][%[[IV]], 0]
+//      CHECK:     %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[IV]], 0]
+//      CHECK:     %[[INSERT2:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#1 into %[[ITERARG2]][0, %[[IV]]]
+//      CHECK:     scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]]
+//      CHECK:   return %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#0

>From 23464b6687bfbedd851752e0b6bebf35108617f0 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Wed, 5 Jun 2024 07:00:01 -0700
Subject: [PATCH 2/3] change default arguments

---
 .../mlir/Dialect/SCF/Transforms/TileUsingInterface.h      | 2 +-
 mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp    | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 807379d99b599..781da1b4ef8a2 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -198,7 +198,7 @@ LogicalResult yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
     MutableArrayRef<LoopLikeOpInterface> loops,
-    std::optional<ArrayRef<unsigned>> yieldResultNumber = std::nullopt);
+    ArrayRef<unsigned> yieldResultNumber = ArrayRef<unsigned>{});
 
 /// Transformation information returned after tile and fuse.
 struct SCFTileAndFuseResult {
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index e8cf7298be681..33142e61750d2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -941,7 +941,7 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
     MutableArrayRef<LoopLikeOpInterface> loops,
-    std::optional<ArrayRef<unsigned>> yieldResultNumber) {
+    ArrayRef<unsigned> yieldResultNumber) {
   if (loops.empty())
     return success();
 
@@ -951,9 +951,9 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
   Location loc = originalOwner->getLoc();
   // a. collect all init Value to be appended
   ArrayRef<unsigned> initNumberList =
-      yieldResultNumber ? yieldResultNumber.value()
-                        : llvm::to_vector(llvm::seq<unsigned>(
-                              0, originalOwner->getNumResults()));
+      yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
+                                      0, originalOwner->getNumResults()))
+                                : yieldResultNumber;
   SmallVector<Value> initValueList;
   for (const auto &resultNumber : initNumberList) {
     FailureOr<Value> initValue = tensor::getOrCreateDestination(

>From 74d925e77a8450768aedeaf4350f9b2ced531a4d Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Wed, 5 Jun 2024 18:47:54 -0700
Subject: [PATCH 3/3] fix CI

---
 mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 33142e61750d2..a6677393c73dc 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -950,10 +950,10 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
 
   Location loc = originalOwner->getLoc();
   // a. collect all init Value to be appended
-  ArrayRef<unsigned> initNumberList =
+  SmallVector<unsigned> initNumberList =
       yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
                                       0, originalOwner->getNumResults()))
-                                : yieldResultNumber;
+                                : llvm::to_vector(yieldResultNumber);
   SmallVector<Value> initValueList;
   for (const auto &resultNumber : initNumberList) {
     FailureOr<Value> initValue = tensor::getOrCreateDestination(
@@ -1000,7 +1000,7 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
     // on iteration Domain Tile
     SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
     for (const auto &resultNumber : initNumberList) {
-      if (resultNumber == fusedProducerInfo.origProducer.getResultNumber()) {
+      if (resultNumber == sliceResultNumber) {
         offsetList.push_back(sliceOffset);
         sizesList.push_back(sliceSizes);
       } else {



More information about the Mlir-commits mailing list