[Mlir-commits] [mlir] [mlir][TilingInterface] Handle multi operand consumer fusion. (PR #145193)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sat Jun 21 16:14:14 PDT 2025


llvmbot wrote:



@llvm/pr-subscribers-mlir-linalg

Author: None (MaheshRavishankar)

<details>
<summary>Changes</summary>

For consumer fusion cases of this form

```
%0:2 = scf.forall .. shared_outs(%arg0 = ..., %arg1 = ...) {

  tensor.parallel_insert_slice ... into %arg0
  tensor.parallel_insert_slice ... into %arg1
}
%1 = linalg.generic ... ins(%0#0, %0#1)
```

the current consumer fusion, which handles one slice at a time, cannot fuse the consumer into the loop: fusing along one slice creates an SSA violation on the other use of the `scf.forall` results. The solution is to let consumer fusion consider multiple slices at once. This PR changes the `TilingInterface` methods related to consumer fusion, i.e.

- `getTiledImplementationFromOperandTile` (now `getTiledImplementationFromOperandTiles`)
- `getIterationDomainTileFromOperandTile` (now `getIterationDomainTileFromOperandTiles`)

to allow fusion while considering multiple operands. It is up to the `TilingInterface` implementation to return an error if a list of operand tiles cannot result in a consistent implementation of the tiled operation.
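The contract above can be sketched in standalone C++. All names here (`ConsumerTilingHooks`, `SingleSourceConsumer`, `TileVec`, `IterationDomainTile`) are hypothetical; plain integers stand in for MLIR's `OpFoldResult`, and `std::nullopt` stands in for `failure()`. The base class refuses fusion by default, mirroring the interface's default implementation, while the derived class accepts only a single tile of operand 0, similar to the restriction the patch places on the `tensor.pack` tiling implementation.

```cpp
#include <optional>
#include <vector>

// Plain integers stand in for mlir::OpFoldResult in this sketch.
using TileVec = std::vector<long>;

struct IterationDomainTile {
  TileVec offsets;
  TileVec sizes;
};

// Hypothetical mirror of the multi-operand consumer-fusion hook. Entry k of
// operandNumbers/allOffsets/allSizes describes the tile of one operand.
// std::nullopt plays the role of failure().
class ConsumerTilingHooks {
public:
  virtual ~ConsumerTilingHooks() = default;

  // Default: refuse consumer fusion, like the interface's default body.
  virtual std::optional<IterationDomainTile>
  getIterationDomainTileFromOperandTiles(
      const std::vector<unsigned> & /*operandNumbers*/,
      const std::vector<TileVec> & /*allOffsets*/,
      const std::vector<TileVec> & /*allSizes*/) const {
    return std::nullopt;
  }
};

// A consumer that only accepts a single tile of operand 0 and rejects every
// other combination of operand tiles.
class SingleSourceConsumer : public ConsumerTilingHooks {
public:
  std::optional<IterationDomainTile>
  getIterationDomainTileFromOperandTiles(
      const std::vector<unsigned> &operandNumbers,
      const std::vector<TileVec> &allOffsets,
      const std::vector<TileVec> &allSizes) const override {
    if (operandNumbers.size() != 1 || operandNumbers[0] != 0 ||
        allOffsets.size() != 1 || allSizes.size() != 1)
      return std::nullopt; // unsupported operands for consumer fusion
    // Simplification: the source tile maps one-to-one onto the iteration
    // domain. The real tensor.pack mapping is more involved.
    return IterationDomainTile{allOffsets[0], allSizes[0]};
  }
};
```

Returning "no result" rather than guessing keeps the decision of what counts as a consistent set of operand tiles inside each implementation, which is where the PR places it.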

The Linalg implementation of `TilingInterface` has been updated accordingly: operand tiles that map to a consistent tile of the iteration space are now handled, and inconsistent ones are rejected.
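A minimal standalone sketch of the multi-operand mapping that `getMappedOffsetAndSize` performs in the patch: each operand tile is pushed through its indexing map, a loop dimension reached from several operands must receive identical offsets and sizes, and loop dimensions reached by no operand keep the full iteration-domain range. Assumptions: the indexing map is encoded as a hypothetical `resultDims` vector, offsets and sizes are plain integers (the patch compares `OpFoldResult`s directly), and `std::nullopt` replaces the emitted "inconsistent iteration space mapping" error.

```cpp
#include <cstddef>
#include <map>
#include <optional>
#include <vector>

struct Range {
  long offset;
  long size;
};

// Hypothetical encoding of a projected-permutation indexing map: resultDims[d]
// is the loop dimension that operand dimension d is indexed by, or -1 when
// that result expression is not a plain loop dimension. offsets and sizes are
// assumed to have the same length as resultDims.
struct OperandTile {
  std::vector<int> resultDims;
  std::vector<long> offsets;
  std::vector<long> sizes;
};

// Maps every operand tile into the iteration space. A loop dimension reached
// from several operands must get the same offset and size; otherwise the tiles
// are inconsistent and std::nullopt is returned. Loop dimensions reached by no
// operand keep their full iteration-domain range.
std::optional<std::vector<Range>>
mapOperandTilesToIterationDomain(const std::vector<OperandTile> &tiles,
                                 const std::vector<Range> &iterationDomain) {
  std::map<int, Range> mapped;
  for (const OperandTile &tile : tiles) {
    for (std::size_t d = 0; d < tile.resultDims.size(); ++d) {
      int loopDim = tile.resultDims[d];
      if (loopDim < 0)
        continue; // not a plain loop dimension, skip it
      if (static_cast<std::size_t>(loopDim) >= iterationDomain.size())
        return std::nullopt; // malformed map
      Range r{tile.offsets[d], tile.sizes[d]};
      auto it = mapped.find(loopDim);
      if (it == mapped.end())
        mapped.emplace(loopDim, r);
      else if (it->second.offset != r.offset || it->second.size != r.size)
        return std::nullopt; // inconsistent iteration space mapping
    }
  }
  std::vector<Range> result(iterationDomain);
  for (const auto &[loopDim, r] : mapped)
    result[static_cast<std::size_t>(loopDim)] = r;
  return result;
}
```

For a consumer `C[i, j] = A[i, j] + B[j, i]`, an `A` tile at offsets `[0, 4]` with sizes `[8, 4]` and a `B` tile at offsets `[4, 0]` with sizes `[4, 8]` agree on both loops, so both operands can be fused at once; moving the `B` tile to offsets `[0, 0]` produces a conflict on loop `j`.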

Additional change: add `LLVM_DUMP_METHOD` to `OpFoldResult::dump` so that the symbol is preserved in Debug builds.
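Why the annotation matters: an inline member function defined in a class is typically emitted only where it is called, so a `dump()` that no code calls may have no symbol for a debugger to invoke. In LLVM's `llvm/Support/Compiler.h`, `LLVM_DUMP_METHOD` roughly expands to "no inline" plus "used" attributes when dumps are enabled. The sketch below imitates that with a hypothetical `SKETCH_DUMP_METHOD` macro and GCC/Clang attributes; it is not LLVM's exact definition.

```cpp
#include <iostream>
#include <ostream>
#include <string>

// Hypothetical stand-in for LLVM_DUMP_METHOD: keep the function out of line
// (noinline) and force the compiler to emit it even if nothing calls it
// (used), so a debugger can still invoke it.
#if defined(__GNUC__) || defined(__clang__)
#define SKETCH_DUMP_METHOD __attribute__((noinline, used))
#else
#define SKETCH_DUMP_METHOD
#endif

struct FoldResultLike {
  std::string text;

  void print(std::ostream &os) const { os << text; }

  // Without the attributes, this in-class definition would normally only be
  // emitted in translation units that use it; with them, a symbol is kept.
  SKETCH_DUMP_METHOD void dump() const {
    print(std::cerr);
    std::cerr << "\n";
  }
};
```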

---

Patch is 60.31 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/145193.diff


10 Files Affected:

- (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+14-9) 
- (modified) mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h (+9-5) 
- (modified) mlir/include/mlir/IR/OpDefinition.h (+1-1) 
- (modified) mlir/include/mlir/Interfaces/TilingInterface.td (+27-26) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+118-64) 
- (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+142-67) 
- (modified) mlir/lib/Dialect/Tensor/Transforms/SwapExtractSliceWithProducerPatterns.cpp (+37-10) 
- (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+289) 
- (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+27-20) 
- (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td (+4-3) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index f686ae07b9a99..7b6e3cba5723d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -319,19 +319,24 @@ tileConsumerAndFuseProducersUsingSCF(RewriterBase &rewriter,
                                      TilingInterface consumer,
                                      const SCFTileAndFuseOptions &options);
 
-/// Fuse the consumer of the source of `candidateSliceOp` by computing the
-/// required slice of the consumer in-place.  Note that the method
-/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
-/// value but does not delete the slice operation.
+/// Fuse the consumer of the result of every element of `candidateSlices` by
+/// computing the required slice of the consumer in-place. All the entries of
+/// `candidateSlices` are expected to map to the same consumer. The method
+/// returns an error if the consumer cannot be tiled in a manner that is
+/// consistent for all the passed slices. Note that the method replaces the uses
+/// of `candidateSlices` with the tiled and fused consumer value but does not
+/// delete the slice operations.
 struct SCFFuseConsumerOfSliceResult {
-  OpOperand *origConsumerOperand; // Original untiled consumer's operand.
-  OpOperand
-      *tiledAndFusedConsumerOperand; // Tiled and fused consumer's operand.
+  // Original untiled consumer's operands.
+  SmallVector<OpOperand *> origConsumerOperands;
+  // Tiled and fused consumer's operands.
+  SmallVector<OpOperand *> tiledAndFusedConsumerOperands;
   SmallVector<Operation *> tiledOps;
 };
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
-tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
-                           MutableArrayRef<LoopLikeOpInterface> loops);
+tileAndFuseConsumerOfSlices(RewriterBase &rewriter,
+                            ArrayRef<Operation *> candidateSlices,
+                            MutableArrayRef<LoopLikeOpInterface> loops);
 
 /// Method to lower an `op` that implements the `TilingInterface` to
 /// loops/scalars.
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index 18981337742eb..8f6eb1bd47782 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -31,12 +31,16 @@ namespace tensor {
 FailureOr<TilingResult> replaceExtractSliceWithTiledProducer(
     OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp);
 
-/// Method to swap an `tensor.insert_slice` with its consumer when the
-/// consumer implements the `TilingInterface`.
+/// Method to swap `tensor.insert_slice`s with their consumer when the
+/// consumer implements the `TilingInterface`. The size of `sliceOps` and
+/// `consumerOperands` is expected to be the same. Every entry in
+/// `consumerOperands` represents the use of the result of the corresponding
+/// entry in `sliceOps`. All entries of `consumerOperands` are expected to be
+/// uses in the same consumer.
 FailureOr<TilingResult>
-replaceInsertSliceWithTiledConsumer(OpBuilder &builder,
-                                    OffsetSizeAndStrideOpInterface sliceOp,
-                                    OpOperand &consumerOp);
+replaceInsertSlicesWithTiledConsumer(OpBuilder &builder,
+                                     ArrayRef<tensor::InsertSliceOp> sliceOps,
+                                     ArrayRef<OpOperand *> consumerOperands);
 
 //===----------------------------------------------------------------------===//
 // Populate functions.
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index 31f54413a5ff0..663c256c848df 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -272,7 +272,7 @@ class OpFoldResult : public PointerUnion<Attribute, Value> {
   using PointerUnion<Attribute, Value>::PointerUnion;
 
 public:
-  void dump() const { llvm::errs() << *this << "\n"; }
+  LLVM_DUMP_METHOD void dump() const { llvm::errs() << *this << "\n"; }
 
   MLIRContext *getContext() const {
     PointerUnion pu = *this;
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index cdf3d01ce8a84..7ebdd8907e964 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -202,28 +202,28 @@ def TilingInterface : OpInterface<"TilingInterface"> {
       InterfaceMethod<
         /*desc=*/[{
           Method to generate the tiled implementation of an operation that uses
-          exactly a tile of the given operand.
+          exactly tiles of the given operands.
 
           This method is required to allow operations to be "tiled and fused"
-          with an (already tiled) producer. Given a tile of the producer, this
-          method generates the tile of the consumer that uses exactly this
-          produced tile. In some sense it is the "reverse" of
+          with an (already tiled) producer. Given tiles of the producer, this
+          method generates the tile of the consumer that uses exactly these
+          produced tiles. In some sense it is the "reverse" of
           `generateResultTileValue`.
-          - `operandNumber` is the result of the producer used by the consumer.
-          - `offsets` is the offset of the slice of the producer result used by
-            the tiled implementation of the consumer.
-          - `sizes` is the size of the slice of the producer result used by the
+          - `operandNumbers` is the list of operands whose tiles are "produced".
+          - `allOffsets` holds the offsets of the slices of the producer
+            results used by the tiled implementation of the consumer.
+          - `allSizes` holds the sizes of the slices of the producer results used by the
             consumer.
-          If it is illegal to fuse with a producer along the given operand for
+          If it is illegal to fuse with a producer along the given operand tiles for
           an operation, the implementation should return a failure.
         }],
         /*retType=*/"::mlir::FailureOr<::mlir::TilingResult>",
-        /*methodName=*/"getTiledImplementationFromOperandTile",
+        /*methodName=*/"getTiledImplementationFromOperandTiles",
         /*args=*/(ins
           "::mlir::OpBuilder &":$b,
-          "unsigned":$operandNumber,
-          "::mlir::ArrayRef<::mlir::OpFoldResult>":$offsets,
-          "::mlir::ArrayRef<::mlir::OpFoldResult>":$sizes),
+          "::mlir::ArrayRef<unsigned>":$operandNumbers,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allOffsets,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>>":$allSizes),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -235,13 +235,14 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           tile of the operand.
 
           This method is required to allow operations to be "tiled and fused"
-          with an (already tiled) producer. Given a tile of an operand,
-          returns the tile of the iteration space that uses this tile.
-          - `operandNumber` is the result of the producer used by the consumer.
-          - `offsets` is the offset of the slice of the producer result used by
-            the tiled implementation of the consumer.
-          - `sizes` is the size of the slice of the producer result used by the
-            consumer.
+          with an (already tiled) producer. Given tiles of one or more operands,
+          returns the tile of the iteration space that uses these tiles.
+          - `operandNumbers` is the list of operands whose tiles are "produced"
+            by the producer(s).
+          - `allOffsets` holds the offsets of the slices of the producer
+            results used by the tiled implementation of the consumer.
+          - `allSizes` holds the sizes of the slices of the producer results
+            used by the consumer.
           If it is illegal to fuse with a producer along the given operand for
           an operation, or if this mapping cannot be computed, the
           implementation should return a failure.
@@ -285,17 +286,17 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           transformation. It does not provide guarantees on whether such a
           transformation is profitable.
 
-          For most cases `getTiledImplementationFromOperandTile` could be a
-          implemented using `getIterationDomainTileFromOperandTile` +
+          For most cases `getTiledImplementationFromOperandTiles` could be
+          implemented using `getIterationDomainTileFromOperandTiles` +
           `getTiledImplementation` methods.
         }],
         /*retType=*/"::llvm::LogicalResult",
-        /*methodName=*/"getIterationDomainTileFromOperandTile",
+        /*methodName=*/"getIterationDomainTileFromOperandTiles",
         /*args=*/(ins
           "::mlir::OpBuilder &":$b,
-          "unsigned":$operandNumber,
-          "::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
-          "::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
+          "::mlir::ArrayRef<unsigned>":$operandNumbers,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allOffsets,
+          "::mlir::ArrayRef<::mlir::SmallVector<::mlir::OpFoldResult>> ":$allSizes,
           "::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainOffsets,
           "::mlir::SmallVectorImpl<::mlir::OpFoldResult> &":$iterDomainSizes),
         /*methodBody=*/"",
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index 7c14cc16437fe..86045b54075bc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -147,55 +147,80 @@ struct LinalgOpTilingInterface
   /// Utility to fetch the offsets and sizes when applied as per the indexing
   /// map of the linalg op. This helps in fusing the linalg op as a consumer of
   /// a given slice op.
-  void
-  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
-                         ArrayRef<OpFoldResult> offsets,
-                         ArrayRef<OpFoldResult> sizes,
-                         SmallVectorImpl<OpFoldResult> &mappedOffsets,
-                         SmallVectorImpl<OpFoldResult> &mappedSizes) const {
-    unsigned numLoops = linalgOp.getNumLoops();
-    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
-    mappedOffsets.resize(numLoops);
-    mappedSizes.resize(numLoops);
-    if (!indexingMap.isPermutation()) {
-      SmallVector<Range> iterationDomain =
-          tilingInterfaceOp.getIterationDomain(b);
-      for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
-        mappedOffsets[index] = value.offset;
-        mappedSizes[index] = value.size;
+  static LogicalResult
+  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
+                         ArrayRef<AffineMap> indexingMaps,
+                         ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+                         ArrayRef<SmallVector<OpFoldResult>> allSizes,
+                         SmallVectorImpl<OpFoldResult> &mappedOffsetsVec,
+                         SmallVectorImpl<OpFoldResult> &mappedSizesVec) {
+    DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;
+
+    for (auto [indexingMap, offsets, sizes] :
+         llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
+      for (auto [resultExpr, offset, size] :
+           llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
+        auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
+        if (!dimExpr)
+          continue;
+        unsigned position = dimExpr.getPosition();
+        auto it = mappedOffsets.find(position);
+        if (it != mappedOffsets.end()) {
+          OpFoldResult seenOffset = it->second;
+          OpFoldResult seenSize = mappedSizes.lookup(position);
+          if (seenOffset != offset || seenSize != size) {
+            return linalgOp->emitOpError(
+                "inconsistent iteration space mapping from offsets/sizes of "
+                "operands/results");
+          }
+        } else {
+          mappedOffsets[position] = offset;
+          mappedSizes[position] = size;
+        }
       }
     }
-    for (const auto &&[index, value] :
-         llvm::enumerate(indexingMap.getResults())) {
-      unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
-      mappedOffsets[dimPosition] = offsets[index];
-      mappedSizes[dimPosition] = sizes[index];
+
+    // Aggregate from the given operand offsets and sizes, or default to
+    // iteration space values.
+    SmallVector<Range> iterationDomain =
+        cast<TilingInterface>(linalgOp.getOperation()).getIterationDomain(b);
+    mappedOffsetsVec.resize(iterationDomain.size());
+    mappedSizesVec.resize(iterationDomain.size());
+    for (auto [index, domain] : llvm::enumerate(iterationDomain)) {
+      auto it = mappedOffsets.find(index);
+      if (it != mappedOffsets.end()) {
+        mappedOffsetsVec[index] = it->second;
+        mappedSizesVec[index] = mappedSizes.lookup(index);
+        continue;
+      }
+      mappedOffsetsVec[index] = domain.offset;
+      mappedSizesVec[index] = domain.size;
     }
+    return success();
   }
 
   /// Method to return the position of the result tile computed by the tiled
   /// operation.
-  LogicalResult getIterationDomainTileFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+  LogicalResult getIterationDomainTileFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes,
       SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
       SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
     auto linalgOp = cast<LinalgOp>(op);
 
-    // Check that the indexing map used for the operand is a projected
-    // permutation. This could be relaxed with a more general approach that can
-    // map the offsets and sizes from the operand to iteration space tiles
-    // (filling in full extent for dimensions not used to access the result).
-    AffineMap indexingMap =
-        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
-    if (!indexingMap.isProjectedPermutation()) {
-      return op->emitError()
-             << "unhandled get iter domain position when operand is not "
-                "accessed using a permuted projection";
+    std::optional<SmallVector<OpFoldResult>> iterationSpaceOffsets,
+        iterationSpaceSizes;
+    SmallVector<AffineMap> indexingMaps =
+        llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
+          OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
+          return linalgOp.getMatchingIndexingMap(&opOperand);
+        });
+    if (failed(getMappedOffsetAndSize(linalgOp, b, indexingMaps, allOffsets,
+                                      allSizes, iterDomainOffsets,
+                                      iterDomainSizes))) {
+      return failure();
     }
-
-    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
-                           iterDomainOffsets, iterDomainSizes);
     return success();
   }
 
@@ -246,8 +271,13 @@ struct LinalgOpTilingInterface
           "accessed using a permuted projection");
     }
 
-    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
-                           iterDomainOffsets, iterDomainSizes);
+    SmallVector<OpFoldResult> allOffsets = llvm::to_vector(offsets);
+    SmallVector<OpFoldResult> allSizes = llvm::to_vector(sizes);
+    auto status =
+        getMappedOffsetAndSize(linalgOp, b, indexingMap, {allOffsets},
+                               {allSizes}, iterDomainOffsets, iterDomainSizes);
+    (void)status;
+    assert(succeeded(status) && "unexpected error in offset calculation");
     return success();
   }
 
@@ -278,12 +308,13 @@ struct LinalgOpTilingInterface
 
   /// Method to generate the tiled implementation of an operation from the tile
   /// of the operand.
-  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
     SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
-    if (failed(getIterationDomainTileFromOperandTile(
-            op, b, operandNumber, offsets, sizes, mappedOffsets,
+    if (failed(getIterationDomainTileFromOperandTiles(
+            op, b, operandNumbers, allOffsets, allSizes, mappedOffsets,
             mappedSizes))) {
       return failure();
     }
@@ -750,13 +781,17 @@ struct PackOpTiling
   /// Method to return the position of iteration domain tile computed by the
   /// tiled operation. In current `tensor.pack` context, the `resultOffsets` and
   /// `resultSizes` only cover outer dimensions.
-  LogicalResult getIterationDomainTileFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+  LogicalResult getIterationDomainTileFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes,
       SmallVectorImpl<OpFoldResult> &resultOffsets,
       SmallVectorImpl<OpFoldResult> &resultSizes) const {
-    if (operandNumber != 0)
-      return failure();
+    if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
+      return op->emitOpError("unsupported operands for consumer fusion");
+
+    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+    ArrayRef<OpFoldResult> sizes(allSizes[0]);
 
     auto packOp = cast<PackOp>(op);
     // It is not trivial to infer dest tile from source tile if `packOp` has
@@ -817,11 +852,15 @@ struct PackOpTiling
   }
 
   /// Method to return the tiled implementation of tensor.pack as a consumer.
-  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
-    if (operandNumber != 0)
-      return failure();
+  FailureOr<TilingResult> getTiledImplementationFromOperandTiles(
+      Operation *op, OpBuilder &b, ArrayRef<unsigned> operandNumbers,
+      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
+      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
+    if (operandNumbers.size() != 1 || operandNumbers[0] != 0)
+      return op->emitOpError("unhandled operands for consumer fusion");
+
+    ArrayRef<OpFoldResult> offsets(allOffsets[0]);
+    ArrayRef<OpFoldResult> sizes(allSizes[0]);
 
     auto packOp = cast<PackOp>(op);
     Location loc = packOp.getLoc();
@@ -836,8 +875,8 @@ struct PackOpTiling
     tiledOperands.push_back(sourceSlice);
 
     SmallVector<OpFoldResult> outerDimOffsets, outerDimSizes;
-    if (failed(getIterationDomainTileFromOperandTile(
-            op, b, /*operandNumber=*/0, offsets, sizes, outerDimOffsets,
+    if (failed(getIterationDomainTileFromOperandTiles(
+            op, b, operandNumbers, allOffsets, allSizes, outerDimOffsets,
             outerDimSizes)))
       return failure();
 
@@ -1095,12 +1134,20 @@ struct UnPackOpTiling
 
   /// Method to return the position of iteration domain tile computed by the
   /// tiled operation.
-  LogicalResult getIterationDomainTileFromOperandTile(
-      Operation *op, OpBuilder &b, unsigned operandNumber,
-      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+  LogicalResult getIterationDomainTileFromOperandTiles(
+      Operation *op, OpB...
[truncated]

``````````
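The updated header comments state two preconditions for the multi-slice entry points: `sliceOps` and `consumerOperands` have the same size, and all consumer operands (like all entries of `candidateSlices`) belong to the same consumer. A hedged standalone sketch of that check follows; `FakeOp`, `FakeOperand` and `getCommonConsumer` are hypothetical stand-ins, not MLIR API, and rejecting an empty list is an extra assumption of the sketch.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for mlir::Operation and mlir::OpOperand.
struct FakeOp {};
struct FakeOperand {
  FakeOp *owner;
  unsigned operandNumber;
};

// Returns the single consumer shared by all operands when the documented
// preconditions hold, or nullptr otherwise: one consumer operand per slice,
// and every consumer operand owned by the same operation. An empty list is
// rejected (an assumption of this sketch).
FakeOp *getCommonConsumer(std::size_t numSlices,
                          const std::vector<const FakeOperand *> &consumerOperands) {
  if (numSlices == 0 || consumerOperands.size() != numSlices)
    return nullptr;
  FakeOp *consumer = consumerOperands.front()->owner;
  for (const FakeOperand *operand : consumerOperands)
    if (operand->owner != consumer)
      return nullptr;
  return consumer;
}
```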

</details>


https://github.com/llvm/llvm-project/pull/145193


More information about the Mlir-commits mailing list