[Mlir-commits] [mlir] [mlir][TilingInterface] Extend option to yield replacement for multiple results case (PR #93144)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu May 23 00:23:11 PDT 2024


https://github.com/Yun-Fly created https://github.com/llvm/llvm-project/pull/93144

Currently, the only option is to yield a replacement for a single result of `fusableProducer`, like this:
```
%1 = op1
%2 = op2(%1)
%3, %y2 = scf.for(%dest1, %dest2){
   %tiled_2 = tiled_op2(...)
   %tiled_3 = tiled_op3(%tiled_2)
   %ins_3 = insert %tiled_3 into %dest1
   %ins_2 = insert %tiled_2 into %dest2 // also yield the tiled value; the caller decides this, usually when the tiled value has multiple uses.
   yield  %ins_3, %ins_2
}
// another user of %2
%4 = op4(%y2)
```

However, there is no way to yield replacements when the producer has multiple results, as in the following:
```
%1 = op1
%2_1, %2_2  = op2(%1)
%3, %y2_1, %y2_2 = scf.for(%dest1, %dest2, %dest3){
   %t2_1, %t2_2 = tiled_op2(...)
   %t3 = tiled_op3(%t2_1)
   %ins3 = insert %t3 into %dest1
   %ins2_1 = insert %t2_1 into %dest2
   %ins2_2 = insert %t2_2 into %dest3
   yield  %ins3, %ins2_1, %ins2_2
}
// another user of %2_1
%4 = op4(%y2_1)
// user of %2_2
%5 = op5(%y2_2)
```
With this approach, the original untiled `op2` no longer has any uses and is expected to be cleaned up later; otherwise it would remain as unnecessary computation.
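
As a small illustration of how the extra loop results line up with the producer results, here is a standalone sketch of the index arithmetic. It assumes that every result of the producer is yielded and that the new loop results are appended, in order, after the existing ones. The function name `replacementLoopResultIndex` is hypothetical and is not part of MLIR.

```cpp
#include <cassert>

// Index of the loop result that replaces producer result `resultNumber`.
// Assumption: all `numProducerResults` results were appended, in order, at
// the end of the loop's result list, which now has `numLoopResults` entries.
unsigned replacementLoopResultIndex(unsigned numLoopResults,
                                    unsigned numProducerResults,
                                    unsigned resultNumber) {
  assert(resultNumber < numProducerResults &&
         numProducerResults <= numLoopResults);
  return numLoopResults - (numProducerResults - resultNumber);
}
```

For the example above, the loop has three results and `op2` has two, so `%2_1` maps to loop result 1 (`%y2_1`) and `%2_2` maps to loop result 2 (`%y2_2`).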

Based on the earlier discussion with @MaheshRavishankar on [Discourse](https://discourse.llvm.org/t/what-if-producer-has-multiple-results-or-result-with-multiple-uses-for-tileandfuseproducerofslice/78795), this PR extends the yield-replacement functionality to the multiple-results case. Note that, as before, the caller still decides whether a replacement needs to be yielded.

Two major changes:
1. Extract a new tiling interface method called `getIterationDomainTileFromResultTile`, which maps the result tile of the candidate `sliceOp` to an iteration-domain tile, from which the tiles of the other results are computed. This utility is very similar to `getIterationDomainTileFromOperandTile` in [this PR](https://github.com/llvm/llvm-project/pull/88712). I think the two can be unified once both are merged.
2. Enhance `yieldReplacementForFusedProducer` to handle multiple `OpResult`s, and add an optional argument called `yieldResultNumber` indicating which results need to be yielded. If it is not given, all results are yielded by default.
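
To make the first change concrete, here is a minimal standalone sketch of the tile mapping, assuming each result's indexing map is a plain permutation of the loop dimensions and that offsets and sizes are static integers. `Tile`, `PermutationMap`, and both functions are hypothetical stand-ins written for illustration; they are not the MLIR API, which works on `OpFoldResult` values and affine maps.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// A tile described by per-dimension offsets and sizes.
struct Tile {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

// Hypothetical stand-in for a permutation indexing map:
// map[d] is the loop dimension that indexes result dimension d.
using PermutationMap = std::vector<unsigned>;

// Result tile -> iteration-domain tile (the role played by
// `getIterationDomainTileFromResultTile` in this PR).
Tile iterationDomainTileFromResultTile(const PermutationMap &map,
                                       const Tile &resultTile) {
  Tile domain{std::vector<int64_t>(map.size()),
              std::vector<int64_t>(map.size())};
  for (unsigned d = 0; d < map.size(); ++d) {
    domain.offsets[map[d]] = resultTile.offsets[d];
    domain.sizes[map[d]] = resultTile.sizes[d];
  }
  return domain;
}

// Iteration-domain tile -> tile of another result (the role played by
// `getResultTilePosition`).
Tile resultTileFromIterationDomainTile(const PermutationMap &map,
                                       const Tile &domainTile) {
  Tile result{std::vector<int64_t>(map.size()),
              std::vector<int64_t>(map.size())};
  for (unsigned d = 0; d < map.size(); ++d) {
    result.offsets[d] = domainTile.offsets[map[d]];
    result.sizes[d] = domainTile.sizes[map[d]];
  }
  return result;
}
```

With the maps from the new test case, `(i, j) -> (i, j)` for the sliced result and `(i, j) -> (j, i)` for the other one, a slice at offsets `[16, 0]` with sizes `[16, 32]` maps to the same iteration-domain tile and then to offsets `[0, 16]` with sizes `[32, 16]` on the second result, which matches the transposed `extract_slice` in the CHECK lines.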

Considering the downstream impact, I am not sure whether it would be better to break this option down and add a new one to the current `fusionControlFn`.

@MaheshRavishankar could you help review this PR? Thanks.

From 4235f25b6755d60e88fafc69b3ce78bab3a076b0 Mon Sep 17 00:00:00 2001
From: "Song, Yunfei" <yunfei.song at intel.com>
Date: Wed, 22 May 2024 23:15:36 -0700
Subject: [PATCH] yield replacement for multiple results

---
 .../SCF/Transforms/TileUsingInterface.h       |   6 +-
 .../mlir/Interfaces/TilingInterface.td        |  19 +++
 .../Linalg/Transforms/TilingInterfaceImpl.cpp |  32 +++-
 .../SCF/Transforms/TileUsingInterface.cpp     | 148 +++++++++++++-----
 .../tile-fuse-and-yield-using-interface.mlir  |  62 ++++++++
 5 files changed, 222 insertions(+), 45 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 6d567171e185a..32249b90644a8 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -190,10 +190,14 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
 /// where `%0` had other uses as well. If not reconstructed from within the loop
 /// body, uses of `%0` could not be replaced, making it still live and the
 /// fusion immaterial.
+///
+/// The parameter `yieldResultNumber` selects which results are yielded. If not
+/// given, all `OpResult`s of the fused producer are yielded.
 LogicalResult yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
-    MutableArrayRef<LoopLikeOpInterface> loops);
+    MutableArrayRef<LoopLikeOpInterface> loops,
+    std::optional<ArrayRef<unsigned>> yieldResultNumber = std::nullopt);
 
 /// Transformation information returned after tile and fuse.
 struct SCFTileAndFuseResult {
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 14d775d986d20..5ac8eaac402b2 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -96,6 +96,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the tile offsets and sizes of the iteration domain
+          based on the given tile of a particular result.
+        }],
+        /*retType=*/"LogicalResult",
+        /*methodName=*/"getIterationDomainTileFromResultTile",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$resultNumber,
+          "ArrayRef<OpFoldResult> ":$resultOffsets,
+          "ArrayRef<OpFoldResult> ":$resultSizes,
+          "SmallVector<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVector<OpFoldResult> &":$iterDomainSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to generate the code that produces a tile of the result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index f512be46cc13d..b46c8135d1c7a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -160,10 +160,11 @@ struct LinalgOpTilingInterface
     return success();
   }
 
-  FailureOr<TilingResult>
-  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
-                          ArrayRef<OpFoldResult> offsets,
-                          ArrayRef<OpFoldResult> sizes) const {
+  LogicalResult getIterationDomainTileFromResultTile(
+      Operation *op, OpBuilder &b, unsigned resultNumber,
+      ArrayRef<OpFoldResult> resultOffsets, ArrayRef<OpFoldResult> resultSizes,
+      SmallVector<OpFoldResult> &iterDomainOffsets,
+      SmallVector<OpFoldResult> &iterDomainSizes) const {
     auto linalgOp = cast<LinalgOp>(op);
 
     // Check that the indexing map used for the output is a projected
@@ -193,8 +194,27 @@ struct LinalgOpTilingInterface
     for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
       unsigned dimPosition =
           cast<AffineDimExpr>(resultExpr.value()).getPosition();
-      iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
-      iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
+      iterationTileOffsets[dimPosition] = resultOffsets[resultExpr.index()];
+      iterationTileSizes[dimPosition] = resultSizes[resultExpr.index()];
+    }
+
+    iterDomainOffsets = iterationTileOffsets;
+    iterDomainSizes = iterationTileSizes;
+
+    return success();
+  }
+
+  FailureOr<TilingResult>
+  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+                          ArrayRef<OpFoldResult> offsets,
+                          ArrayRef<OpFoldResult> sizes) const {
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+
+    SmallVector<OpFoldResult> iterationTileOffsets, iterationTileSizes;
+    if (failed(tilingInterfaceOp.getIterationDomainTileFromResultTile(
+            b, resultNumber, offsets, sizes, iterationTileOffsets,
+            iterationTileSizes))) {
+      return failure();
     }
 
     FailureOr<TilingResult> tilingResult =
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index a72dafe725177..ddd0e94f9bd4c 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -939,49 +939,114 @@ mlir::scf::tileAndFuseProducerOfSlice(
 LogicalResult mlir::scf::yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
-    MutableArrayRef<LoopLikeOpInterface> loops) {
+    MutableArrayRef<LoopLikeOpInterface> loops,
+    std::optional<ArrayRef<unsigned>> yieldResultNumber) {
   if (loops.empty())
     return success();
 
-  OpResult fusableProducer = fusedProducerInfo.origProducer;
-  Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
-  FailureOr<Value> initValue = tensor::getOrCreateDestination(
-      rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
-  if (succeeded(initValue)) {
-
-    YieldTiledValuesFn newYieldValuesFn =
-        [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
-            ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
-            SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
-            SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
-        -> LogicalResult {
-      OpBuilder::InsertionGuard g(innerRewriter);
-      if (auto tiledDestStyleOp =
-              tiledAndFusedProducer
-                  .getDefiningOp<DestinationStyleOpInterface>()) {
-        rewriter.setInsertionPoint(tiledDestStyleOp);
-        Value newRegionArg = newRegionIterArgs.back();
+  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
+            *tiledOwner = fusedProducerInfo.tiledOps[0];
+
+  Location loc = originalOwner->getLoc();
+  // a. collect all init Value to be appended
+  SmallVector<unsigned> initNumberList =
+      yieldResultNumber ? llvm::to_vector(yieldResultNumber.value())
+                        : llvm::to_vector(llvm::seq<unsigned>(
+                              0, originalOwner->getNumResults()));
+  SmallVector<Value> initValueList;
+  for (const auto &resultNumber : initNumberList) {
+    FailureOr<Value> initValue = tensor::getOrCreateDestination(
+        rewriter, loc, originalOwner->getResult(resultNumber));
+    if (succeeded(initValue)) {
+      initValueList.push_back(initValue.value());
+    } else {
+      return failure();
+    }
+  }
+
+  YieldTiledValuesFn newYieldValuesFn =
+      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+    OpBuilder::InsertionGuard g(innerRewriter);
+
+    // get sliceOp tile information
+    SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
+                              sliceSizes = sliceOp.getMixedSizes();
+
+    // expect all strides of sliceOp being 1
+    if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
+          return !isConstantIntValue(ofr, 1);
+        }))
+      return failure();
+
+    unsigned sliceResultNumber =
+        fusedProducerInfo.origProducer.getResultNumber();
+
+    auto tilableOp = cast<TilingInterface>(originalOwner);
+    // b. get iterDomain Offset and Sizes based on sliceOp tile
+    SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
+    // skip tensor.pack/unpack/pad, which expects single opResult
+    if (tilableOp->getNumResults() > 1 &&
+        failed(tilableOp.getIterationDomainTileFromResultTile(
+            rewriter, sliceResultNumber, sliceOffset, sliceSizes,
+            iterDomainOffset, iterDomainSizes))) {
+      return failure();
+    }
+
+    // c. calculate offsets and sizes info of all OpResults respectively based
+    // on iteration Domain Tile
+    SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
+    for (const auto &resultNumber : initNumberList) {
+      if (resultNumber == fusedProducerInfo.origProducer.getResultNumber()) {
+        offsetList.push_back(sliceOffset);
+        sizesList.push_back(sliceSizes);
+      } else {
+        assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
+        // infer result tile according to the iteration domain tile
+        SmallVector<OpFoldResult> offset, sizes;
+        if (failed(tilableOp.getResultTilePosition(
+                rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
+                offset, sizes))) {
+          return failure();
+        }
+        offsetList.push_back(offset);
+        sizesList.push_back(sizes);
+      }
+    }
+
+    // d. create `extract_slice` for `iter_args` for DPS operation if necessary
+    if (auto tiledDestStyleOp =
+            dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
+      rewriter.setInsertionPoint(tiledDestStyleOp);
+      for (const auto &&[index, newRegionArg] :
+           llvm::enumerate(newRegionIterArgs)) {
         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
-            sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
-            sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
-        unsigned resultNumber = fusableProducer.getResultNumber();
+            loc, newRegionArg, offsetList[index], sizesList[index],
+            SmallVector<OpFoldResult>(offsetList[index].size(),
+                                      rewriter.getIndexAttr(1)));
+        unsigned resultNumber = initNumberList[index];
         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
           tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
         });
       }
-      Block *block = rewriter.getInsertionPoint()->getBlock();
-      rewriter.setInsertionPoint(block->getTerminator());
-      tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer);
-      tiledOffset.emplace_back(sliceOp.getMixedOffsets());
-      tiledSizes.emplace_back(sliceOp.getMixedSizes());
-      return success();
-    };
+    }
 
-    return addInitOperandsToLoopNest(rewriter, loops,
-                                     SmallVector<Value>{initValue.value()},
-                                     newYieldValuesFn);
-  }
-  return success();
+    // e. prepare tiled offset and sizes for later `insert_slice` creation by
+    // caller
+    Block *block = rewriter.getInsertionPoint()->getBlock();
+    rewriter.setInsertionPoint(block->getTerminator());
+    for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
+      tiledResult.push_back(tiledOwner->getResult(resultNumber));
+      tiledOffset.emplace_back(offsetList[index]);
+      tiledSizes.emplace_back(sizesList[index]);
+    }
+    return success();
+  };
+
+  return addInitOperandsToLoopNest(rewriter, loops, initValueList,
+                                   newYieldValuesFn);
 }
 
 /// Implementation of tile consumer and fuse producer greedily.
@@ -1071,14 +1136,21 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
       continue;
 
     if (yieldReplacement) {
+      // Reconstruct and yield all results of fusableProducerOp by default. The
+      // caller can specify which ones to yield via the optional argument
+      // `yieldResultNumber` of `yieldReplacementForFusedProducer`.
+      Operation *fusableProducerOp = fusableProducer.getOwner();
       if (failed(yieldReplacementForFusedProducer(
               rewriter, candidateSliceOp, fusedResult.value(), loops))) {
         return rewriter.notifyMatchFailure(
-            fusableProducer.getOwner(), "failed to replacement value for this "
-                                        "oepration from within the tiled loop");
+            fusableProducerOp, "failed to replacement value for this "
+                               "operation from within the tiled loop");
+      }
+      for (const auto &result : fusableProducerOp->getResults()) {
+        origValToResultNumber[result] =
+            loops.front()->getNumResults() -
+            (fusableProducerOp->getNumResults() - result.getResultNumber());
       }
-      origValToResultNumber[fusableProducer] =
-          loops.front()->getNumResults() - 1;
     }
 
     if (Operation *tiledAndFusedOp =
diff --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
index 7356c11e85ac0..3c0ada9d2cabc 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
@@ -58,3 +58,65 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:     %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
 //      CHECK:     scf.yield %[[INSERT0]], %[[INSERT1]]
 //      CHECK:   return %[[RESULT]]#1, %[[RESULT]]#0
+
+// -----
+
+func.func @multiple_outputs_fusion_yield_all(%lhs0: tensor<32x32xf32>,
+                       %rhs0: tensor<32x32xf32>, %init0: tensor<32x32xf32>, %init1: tensor<32x32xf32>, 
+                       %rhs1: tensor<32x32xf32>, %init2: tensor<32x32xf32>) 
+                       -> (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) {
+  %out0, %out1 = linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (j, i)>],
+    iterator_types = ["parallel", "parallel"]
+  }
+  ins(%lhs0, %rhs0: tensor<32x32xf32>, tensor<32x32xf32>)
+  outs(%init0, %init1: tensor<32x32xf32>, tensor<32x32xf32>) {
+  ^bb0(%0: f32, %1: f32, %2: f32, %3: f32):
+    %4 = arith.mulf %0, %1 : f32
+    %5 = arith.addf %0, %1 : f32
+    linalg.yield %4, %5: f32, f32
+  } -> (tensor<32x32xf32>, tensor<32x32xf32>)
+
+  %out3 = linalg.add ins(%out0, %rhs1: tensor<32x32xf32>, tensor<32x32xf32>) outs(%init2: tensor<32x32xf32>) -> tensor<32x32xf32>
+
+  return %out0, %out1, %out3 : tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %add = transform.structured.match ops{["linalg.add"]} in %arg0
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_and_yield %add [16]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @multiple_outputs_fusion_yield_all(
+// CHECK-SAME:     %[[LHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[RHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:     %[[INIT0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<32x32xf32>)
+//      CHECK:   %[[RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9]+]] =
+// CHECK-SAME:       iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT2]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
+//  CHECK-DAG:     %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
+//  CHECK-DAG:     %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][%[[IV]], 0]
+//  CHECK-DAG:     %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
+//  CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG2]][0, %[[IV]]]
+//      CHECK:     %[[GENERIC_TILE:.+]]:2 = linalg.generic
+// CHECK-SAME:         ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
+// CHECK-SAME:         outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
+//  CHECK-DAG:     %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][%[[IV]], 0]
+//  CHECK-DAG:     %[[INIT2_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
+//      CHECK:     %[[ADD_TILE:.+]] = linalg.add
+// CHECK-SAME:         ins(%[[GENERIC_TILE]]#0, %[[RHS1_TILE]] :
+// CHECK-SAME:         outs(%[[INIT2_TILE]] :
+//      CHECK:     %[[INSERT0:.+]] = tensor.insert_slice %[[ADD_TILE]] into %[[ITERARG0]][%[[IV]], 0]
+//      CHECK:     %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[IV]], 0]
+//      CHECK:     %[[INSERT2:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#1 into %[[ITERARG2]][0, %[[IV]]]
+//      CHECK:     scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]]
+//      CHECK:   return %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#0


