[Mlir-commits] [mlir] 7ef08ea - [mlir][scf] Extend option to yield replacement for multiple results case (#93144)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jun 28 05:43:56 PDT 2024


Author: Yun-Fly
Date: 2024-06-28T20:43:52+08:00
New Revision: 7ef08eacd5e125eca0881c00f83c4b108ba7c502

URL: https://github.com/llvm/llvm-project/commit/7ef08eacd5e125eca0881c00f83c4b108ba7c502
DIFF: https://github.com/llvm/llvm-project/commit/7ef08eacd5e125eca0881c00f83c4b108ba7c502.diff

LOG: [mlir][scf] Extend option to yield replacement for multiple results case (#93144)

This patch extends the yield-replacement functionality to the
multiple-results case and adds an optional argument, `yieldResultNumber`,
indicating which result(s) need to be yielded. If it is not given, all results
are yielded by default.

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
    mlir/include/mlir/Interfaces/TilingInterface.td
    mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
    mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
    mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 6316f1d130d19..d68ca11207376 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -191,10 +191,14 @@ tileAndFuseProducerOfSlice(RewriterBase &rewriter,
 /// where `%0` had other uses as well. If not reconstructed from within the loop
 /// body, uses of `%0` could not be replaced, making it still live and the
 /// fusion immaterial.
+///
+/// The @param `yieldResultNumber` decides which result would be yield. If not
+/// given, yield all `opResult` of fused producer.
 LogicalResult yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
-    MutableArrayRef<LoopLikeOpInterface> loops);
+    MutableArrayRef<LoopLikeOpInterface> loops,
+    ArrayRef<unsigned> yieldResultNumber = ArrayRef<unsigned>{});
 
 /// Transformation information returned after tile and fuse.
 struct SCFTileAndFuseResult {

diff  --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 3f927865ccf67..ad0af0aaaab80 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -51,7 +51,8 @@ def TilingInterface : OpInterface<"TilingInterface"> {
     For an operation to be "tiled and fused" with its (already tiled) consumer,
     an operation has to implement the following additional method (see
     description below): 
-      - `generateResultTileValue
+      - `generateResultTileValue`
+      - `getIterationDomainTileFromResultTile`
 
     For an operation to be "tiled and fused" with its (already tiled) producer,
     an operation has to implement the following additional methods (see
@@ -302,6 +303,41 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the tile of the iteration domain based
+          on the given tile of the certain result.
+
+          This method is required to allow operations to be "tiled and fused"
+          with an (already tiled) consumer. Given a tile of an result,
+          returns the tile of the iteration space that uses this tile.
+          - `resultNumber` is the result of the producer used by the consumer.
+          - `offsets` is the offset of the slice of the producer result used by
+            the tiled implementation of the consumer.
+          - `sizes` is the size of the slice of the producer result used by the
+            consumer.
+          If fusion of the producer with the consumer is not legal for the
+          result, or if this mapping cannot be computed, the implementation 
+          should return a failure.
+          
+          For most cases `generateResultTileValue` could be a implemented using 
+          `getIterationDomainTileFromResultTile` + `getTiledImplementation` 
+          methods.
+        }],
+        /*retType=*/"::mlir::LogicalResult",
+        /*methodName=*/"getIterationDomainTileFromResultTile",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$resultNumber,
+          "ArrayRef<OpFoldResult> ":$offsets,
+          "ArrayRef<OpFoldResult> ":$sizes,
+          "SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Generates the scalar implementation of the operation.

diff  --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 39780490e9e49..2133458efe74c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -215,10 +215,11 @@ struct LinalgOpTilingInterface
     return success();
   }
 
-  FailureOr<TilingResult>
-  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
-                          ArrayRef<OpFoldResult> offsets,
-                          ArrayRef<OpFoldResult> sizes) const {
+  LogicalResult getIterationDomainTileFromResultTile(
+      Operation *op, OpBuilder &b, unsigned resultNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
     auto linalgOp = cast<LinalgOp>(op);
 
     // Check that the indexing map used for the output is a projected
@@ -232,9 +233,21 @@ struct LinalgOpTilingInterface
           "unhandled tiled implementation generation when result is not "
           "accessed using a permuted projection");
     }
-    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+
     getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
-                           mappedOffsets, mappedSizes);
+                           iterDomainOffsets, iterDomainSizes);
+    return success();
+  }
+
+  FailureOr<TilingResult>
+  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+                          ArrayRef<OpFoldResult> offsets,
+                          ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    if (failed(getIterationDomainTileFromResultTile(
+            op, b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+      return failure();
+    }
     auto tilingInterfaceOp = cast<TilingInterface>(op);
     FailureOr<TilingResult> tilingResult =
         tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);

diff  --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 35edd490f72eb..a1392813d6de3 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -953,49 +953,122 @@ mlir::scf::tileAndFuseProducerOfSlice(
 LogicalResult mlir::scf::yieldReplacementForFusedProducer(
     RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp,
     scf::SCFFuseProducerOfSliceResult fusedProducerInfo,
-    MutableArrayRef<LoopLikeOpInterface> loops) {
+    MutableArrayRef<LoopLikeOpInterface> loops,
+    ArrayRef<unsigned> yieldResultNumber) {
   if (loops.empty())
     return success();
 
-  OpResult fusableProducer = fusedProducerInfo.origProducer;
-  Value tiledAndFusedProducer = fusedProducerInfo.tiledAndFusedProducer;
-  FailureOr<Value> initValue = tensor::getOrCreateDestination(
-      rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer);
-  if (succeeded(initValue)) {
-
-    YieldTiledValuesFn newYieldValuesFn =
-        [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
-            ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
-            SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
-            SmallVector<SmallVector<OpFoldResult>> &tiledSizes)
-        -> LogicalResult {
-      OpBuilder::InsertionGuard g(innerRewriter);
-      if (auto tiledDestStyleOp =
-              tiledAndFusedProducer
-                  .getDefiningOp<DestinationStyleOpInterface>()) {
-        rewriter.setInsertionPoint(tiledDestStyleOp);
-        Value newRegionArg = newRegionIterArgs.back();
+  Operation *originalOwner = fusedProducerInfo.origProducer.getOwner(),
+            *tiledOwner = fusedProducerInfo.tiledOps[0];
+
+  Location loc = originalOwner->getLoc();
+  // a. collect all init Value to be appended
+  SmallVector<unsigned> initNumberList =
+      yieldResultNumber.empty() ? llvm::to_vector(llvm::seq<unsigned>(
+                                      0, originalOwner->getNumResults()))
+                                : llvm::to_vector(yieldResultNumber);
+  SmallVector<Value> initValueList;
+  for (const auto &resultNumber : initNumberList) {
+    FailureOr<Value> initValue = tensor::getOrCreateDestination(
+        rewriter, loc, originalOwner->getResult(resultNumber));
+    if (succeeded(initValue)) {
+      initValueList.push_back(initValue.value());
+    } else {
+      return failure();
+    }
+  }
+
+  YieldTiledValuesFn newYieldValuesFn =
+      [&](RewriterBase &innerRewriter, Location loc, ValueRange /*ivs*/,
+          ValueRange newRegionIterArgs, SmallVector<Value> &tiledResult,
+          SmallVector<SmallVector<OpFoldResult>> &tiledOffset,
+          SmallVector<SmallVector<OpFoldResult>> &tiledSizes) -> LogicalResult {
+    OpBuilder::InsertionGuard g(innerRewriter);
+
+    // get sliceOp tile information
+    SmallVector<OpFoldResult> sliceOffset = sliceOp.getMixedOffsets(),
+                              sliceSizes = sliceOp.getMixedSizes();
+
+    // expect all strides of sliceOp being 1
+    if (llvm::any_of(sliceOp.getMixedStrides(), [](OpFoldResult ofr) {
+          return !isConstantIntValue(ofr, 1);
+        }))
+      return failure();
+
+    unsigned sliceResultNumber =
+        fusedProducerInfo.origProducer.getResultNumber();
+
+    auto tilableOp = cast<TilingInterface>(originalOwner);
+    // b. get iterDomain Offset and Sizes based on sliceOp tile
+    SmallVector<OpFoldResult> iterDomainOffset, iterDomainSizes;
+    // skip tensor.pack/unpack/pad, which expects single opResult
+    if (tilableOp->getNumResults() > 1 &&
+        failed(tilableOp.getIterationDomainTileFromResultTile(
+            rewriter, sliceResultNumber, sliceOffset, sliceSizes,
+            iterDomainOffset, iterDomainSizes))) {
+      // In theory, it is unnecessary to raise an error here. Actually although
+      // it fails to reconstruct the result tensor, it should not broke current
+      // fusion anyway. The reason why we must return failure currently is that
+      // the callback function `newYieldValuesFn` will be called after new init
+      // operand(s) has already been appended. It will take more refactoring to
+      // make sure the init operands are added consistently in the future. For
+      // more details, please refer to:
+      // https://github.com/llvm/llvm-project/pull/93144#discussion_r1643760814
+      return failure();
+    }
+
+    // c. calculate offsets and sizes info of all OpResults respectively based
+    // on iteration Domain Tile
+    SmallVector<SmallVector<OpFoldResult>> offsetList, sizesList;
+    for (const auto &resultNumber : initNumberList) {
+      if (resultNumber == sliceResultNumber) {
+        offsetList.push_back(sliceOffset);
+        sizesList.push_back(sliceSizes);
+      } else {
+        assert(!iterDomainOffset.empty() && !iterDomainSizes.empty());
+        // infer result tile according to the iteration domain tile
+        SmallVector<OpFoldResult> offset, sizes;
+        if (failed(tilableOp.getResultTilePosition(
+                rewriter, resultNumber, iterDomainOffset, iterDomainSizes,
+                offset, sizes))) {
+          return failure();
+        }
+        offsetList.push_back(offset);
+        sizesList.push_back(sizes);
+      }
+    }
+
+    // d. create `extract_slice` for `iter_args` for DPS operation if necessary
+    if (auto tiledDestStyleOp =
+            dyn_cast<DestinationStyleOpInterface>(tiledOwner)) {
+      rewriter.setInsertionPoint(tiledDestStyleOp);
+      for (const auto &&[index, newRegionArg] :
+           llvm::enumerate(newRegionIterArgs)) {
         auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
-            sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
-            sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
-        unsigned resultNumber = fusableProducer.getResultNumber();
+            loc, newRegionArg, offsetList[index], sizesList[index],
+            SmallVector<OpFoldResult>(offsetList[index].size(),
+                                      rewriter.getIndexAttr(1)));
+        unsigned resultNumber = initNumberList[index];
         rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
           tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
         });
       }
-      Block *block = rewriter.getInsertionPoint()->getBlock();
-      rewriter.setInsertionPoint(block->getTerminator());
-      tiledResult.push_back(fusedProducerInfo.tiledAndFusedProducer);
-      tiledOffset.emplace_back(sliceOp.getMixedOffsets());
-      tiledSizes.emplace_back(sliceOp.getMixedSizes());
-      return success();
-    };
+    }
 
-    return addInitOperandsToLoopNest(rewriter, loops,
-                                     SmallVector<Value>{initValue.value()},
-                                     newYieldValuesFn);
-  }
-  return success();
+    // e. prepare tiled offset and sizes for later `insert_slice` creation by
+    // caller
+    Block *block = rewriter.getInsertionPoint()->getBlock();
+    rewriter.setInsertionPoint(block->getTerminator());
+    for (const auto &&[index, resultNumber] : llvm::enumerate(initNumberList)) {
+      tiledResult.push_back(tiledOwner->getResult(resultNumber));
+      tiledOffset.emplace_back(offsetList[index]);
+      tiledSizes.emplace_back(sizesList[index]);
+    }
+    return success();
+  };
+
+  return addInitOperandsToLoopNest(rewriter, loops, initValueList,
+                                   newYieldValuesFn);
 }
 
 /// Implementation of tile consumer and fuse producer greedily.
@@ -1085,14 +1158,22 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
       continue;
 
     if (yieldReplacement) {
+      // Reconstruct and yield all opResult of fusableProducerOp by default. The
+      // caller can specific which one to yield by designating optional argument
+      // named `yieldResultNumber` of `yieldReplacementForFusedProducer`.
+      Operation *fusableProducerOp = fusableProducer.getOwner();
       if (failed(yieldReplacementForFusedProducer(
               rewriter, candidateSliceOp, fusedResult.value(), loops))) {
         return rewriter.notifyMatchFailure(
-            fusableProducer.getOwner(), "failed to replacement value for this "
-                                        "oepration from within the tiled loop");
+            fusableProducerOp, "failed to replacement value for this "
+                               "operation from within the tiled loop");
+      }
+      for (auto [index, result] :
+           llvm::enumerate(fusableProducerOp->getResults())) {
+        origValToResultNumber[result] = loops.front()->getNumResults() -
+                                        fusableProducerOp->getNumResults() +
+                                        index;
       }
-      origValToResultNumber[fusableProducer] =
-          loops.front()->getNumResults() - 1;
     }
 
     if (Operation *tiledAndFusedOp =

diff  --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
index 7356c11e85ac0..3c0ada9d2cabc 100644
--- a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir
@@ -58,3 +58,65 @@ module attributes {transform.with_named_sequence} {
 //      CHECK:     %[[INSERT1:.+]] = tensor.insert_slice %[[GEMM0_TILE]] into %[[ITERARG1]][%[[IV]], 0]
 //      CHECK:     scf.yield %[[INSERT0]], %[[INSERT1]]
 //      CHECK:   return %[[RESULT]]#1, %[[RESULT]]#0
+
+// -----
+
+func.func @multiple_outputs_fusion_yield_all(%lhs0: tensor<32x32xf32>,
+                       %rhs0: tensor<32x32xf32>, %init0: tensor<32x32xf32>, %init1: tensor<32x32xf32>, 
+                       %rhs1: tensor<32x32xf32>, %init2: tensor<32x32xf32>) 
+                       -> (tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>) {
+  %out0, %out1 = linalg.generic {
+    indexing_maps = [affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (i, j)>,
+                     affine_map<(i, j) -> (j, i)>],
+    iterator_types = ["parallel", "parallel"]
+  }
+  ins(%lhs0, %rhs0: tensor<32x32xf32>, tensor<32x32xf32>)
+  outs(%init0, %init1: tensor<32x32xf32>, tensor<32x32xf32>) {
+  ^bb0(%0: f32, %1: f32, %2: f32, %3: f32):
+    %4 = arith.mulf %0, %1 : f32
+    %5 = arith.addf %0, %1 : f32
+    linalg.yield %4, %5: f32, f32
+  } -> (tensor<32x32xf32>, tensor<32x32xf32>)
+
+  %out3 = linalg.add ins(%out0, %rhs1: tensor<32x32xf32>, tensor<32x32xf32>) outs(%init2: tensor<32x32xf32>) -> tensor<32x32xf32>
+
+  return %out0, %out1, %out3 : tensor<32x32xf32>, tensor<32x32xf32>, tensor<32x32xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0 : !transform.any_op {transform.readonly}) {
+    %add = transform.structured.match ops{["linalg.add"]} in %arg0
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b = transform.test.fuse_and_yield %add [16]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+//      CHECK: func.func @multiple_outputs_fusion_yield_all(
+// CHECK-SAME:     %[[LHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
+// CHECK-SAME:     %[[RHS0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:     %[[INIT0:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:     %[[INIT1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:     %[[RHS1:[a-zA-Z0-9]+]]: tensor<32x32xf32>,
+// CHECK-SAME:     %[[INIT2:[a-zA-Z0-9]+]]: tensor<32x32xf32>)
+//      CHECK:   %[[RESULT:.+]]:3 = scf.for %[[IV:[a-zA-Z0-9]+]] =
+// CHECK-SAME:       iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT2]], %[[ITERARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ITERARG2:[a-zA-Z0-9]+]] = %[[INIT1]])
+//  CHECK-DAG:     %[[LHS0_TILE:.+]] = tensor.extract_slice %[[LHS0]][%[[IV]], 0]
+//  CHECK-DAG:     %[[RHS0_TILE:.+]] = tensor.extract_slice %[[RHS0]][%[[IV]], 0]
+//  CHECK-DAG:     %[[INIT0_TILE:.+]] = tensor.extract_slice %[[ITERARG1]][%[[IV]], 0]
+//  CHECK-DAG:     %[[INIT1_TILE:.+]] = tensor.extract_slice %[[ITERARG2]][0, %[[IV]]]
+//      CHECK:     %[[GENERIC_TILE:.+]]:2 = linalg.generic
+// CHECK-SAME:         ins(%[[LHS0_TILE]], %[[RHS0_TILE]] :
+// CHECK-SAME:         outs(%[[INIT0_TILE]], %[[INIT1_TILE]] :
+//  CHECK-DAG:     %[[RHS1_TILE:.+]] = tensor.extract_slice %[[RHS1]][%[[IV]], 0]
+//  CHECK-DAG:     %[[INIT2_TILE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0]
+//      CHECK:     %[[ADD_TILE:.+]] = linalg.add
+// CHECK-SAME:         ins(%[[GENERIC_TILE]]#0, %[[RHS1_TILE]] :
+// CHECK-SAME:         outs(%[[INIT2_TILE]] :
+//      CHECK:     %[[INSERT0:.+]] = tensor.insert_slice %[[ADD_TILE]] into %[[ITERARG0]][%[[IV]], 0]
+//      CHECK:     %[[INSERT1:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#0 into %[[ITERARG1]][%[[IV]], 0]
+//      CHECK:     %[[INSERT2:.+]] = tensor.insert_slice %[[GENERIC_TILE]]#1 into %[[ITERARG2]][0, %[[IV]]]
+//      CHECK:     scf.yield %[[INSERT0]], %[[INSERT1]], %[[INSERT2]]
+//      CHECK:   return %[[RESULT]]#1, %[[RESULT]]#2, %[[RESULT]]#0


        


More information about the Mlir-commits mailing list