[Mlir-commits] [mlir] [mlir][scf] Extend option to yield replacement for multiple results case (PR #93144)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 27 22:43:27 PDT 2024


================
@@ -940,49 +940,114 @@ mlir::scf::tileAndFuseProducerOfSlice(
 LogicalResult mlir::scf::yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
-    MutableArrayRef<LoopLikeOpInterface> loops) {
+    MutableArrayRef<LoopLikeOpInterface> loops,
+    ArrayRef<unsigned> yieldResultNumber) {
   if (loops.empty())
     return success();
 
-  OpResult fusableProducer = fusedProducerInfo.origProducer;
-  Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
-  FailureOr<Value> initValue = tensor::getOrCreateDestination(
-      rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
-  if (succeeded(initValue)) {
-
-    YieldTiledValuesFn newYieldValuesFn =
-        [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
-            ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
-            SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
-            SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
-        -> LogicalResult {
-      OpBuilder::InsertionGuard g(innerRewriter);
-      if (auto tiledDestStyleOp =
-              tiledAndFusedProducer
-                  .getDefiningOp<DestinationStyleOpInterface>()) {
-        rewriter.setInsertionPoint(tiledDestStyleOp);
-        Value newRegionArg = newRegionIterArgs.back();
+  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
+            *tiledOwner = fusedProducerInfo.tiledOps[0];
+
+  Location loc = originalOwner->getLoc();
+  // a. collect all init Value to be appended
+  SmallVector<unsigned> initNumberList =
+      yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
+                                      0, originalOwner->getNumResults()))
+                                : llvm::to_vector(yieldResultNumber);
+  SmallVector<Value> initValueList;
+  for (const auto &resultNumber : initNumberList) {
+    FailureOr<Value> initValue = tensor::getOrCreateDestination(
+        rewriter, loc, originalOwner->getResult(resultNumber));
+    if (succeeded(initValue)) {
+      initValueList.push_back(initValue.value());
+    } else {
+      return failure();
+    }
+  }
+
+  YieldTiledValuesFn newYieldValuesFn =
+      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+    OpBuilder::InsertionGuard g(innerRewriter);
+
+    // get sliceOp tile information
+    SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
+                              sliceSizes = sliceOp.getMixedSizes();
+
+    // expect all strides of sliceOp being 1
+    if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
+          return !isConstantIntValue(ofr, 1);
+        }))
+      return failure();
+
+    unsigned sliceResultNumber =
+        fusedProducerInfo.origProducer.getResultNumber();
+
+    auto tilableOp = cast<TilingInterface>(originalOwner);
+    // b. get iterDomain Offset and Sizes based on sliceOp tile
+    SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
+    // skip tensor.pack/unpack/pad, which expects single opResult
+    if (tilableOp->getNumResults() > 1 &&
+        failed(tilableOp.getIterationDomainTileFromResultTile(
----------------
Yun-Fly wrote:

> Can you just add a comment here as to why this is a failure for now?

Yes, sure. It will serve as a useful reminder for all of us.

> It looks good to me to land.

Thanks again!

https://github.com/llvm/llvm-project/pull/93144


More information about the Mlir-commits mailing list