[Mlir-commits] [mlir] [mlir][linalg] Enable fuse consumer (PR #85528)

donald chen llvmlistbot at llvm.org
Sat Apr 20 02:05:22 PDT 2024


https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/85528

>From e906efbe8f77e03dfe975b1aa4956693153ebffe Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Sat, 16 Mar 2024 09:12:12 +0800
Subject: [PATCH] [mlir][linalg] Enable fuse consumer

This patch adds support for consumer fusion to the tiling interface.

- Add interface method 'getIterationDomainTileFromOperandTile' to the tiling
  interface, which gets the iteration domain position from the operand position.
- Add interface method 'getTiledImplementationFromOperandTile' to the tiling
  interface, which generates the tiled implementation from the operand position.
---
 .../mlir/Interfaces/TilingInterface.td        |  67 ++++++++++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |   4 +-
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 106 +++++++++++++-----
 .../Tensor/IR/TensorTilingInterfaceImpl.cpp   |  12 +-
 4 files changed, 149 insertions(+), 40 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..84f7dec2f4003d 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           The method returns the operation that is the tiled
           implementation.
         }],
-        /*retType=*/"FailureOr<TilingResult>",
+        /*retType=*/"FailureOr<::mlir::TilingResult>",
         /*methodName=*/"getTiledImplementation",
         /*args=*/(ins
             "OpBuilder &":$b,
@@ -82,15 +82,34 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           by the tiled implementation. Expects the same `offsets` and `sizes` as
           used to obtain the tiled implementation of the operation.
         }],
-        /*retType=*/"LogicalResult",
+        /*retType=*/"::mlir::LogicalResult",
         /*methodName=*/"getResultTilePosition",
         /*args=*/(ins
           "OpBuilder &":$b,
           "unsigned":$resultNumber,
           "ArrayRef<OpFoldResult> ":$offsets,
           "ArrayRef<OpFoldResult> ":$sizes,
-          "SmallVector<OpFoldResult> &":$resultOffsets,
-          "SmallVector<OpFoldResult> &":$resultSizes),
+          "SmallVectorImpl<OpFoldResult> &":$resultOffsets,
+          "SmallVectorImpl<OpFoldResult> &":$resultSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the iteration domain tile (offsets and sizes) that
+          corresponds to the given tile of operand `operandNumber`.
+        }],
+        /*retType=*/"::mlir::LogicalResult",
+        /*methodName=*/"getIterationDomainTileFromOperandTile",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult> ":$offsets,
+          "ArrayRef<OpFoldResult> ":$sizes,
+          "SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -119,7 +138,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
             iteration space).
           - `sizes` provides the size of the tile.
         }],
-        /*retType=*/"FailureOr<TilingResult>",
+        /*retType=*/"FailureOr<::mlir::TilingResult>",
         /*methodName=*/"generateResultTileValue",
         /*args=*/(ins
           "OpBuilder &":$b,
@@ -131,6 +150,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to generate the tiled implementation of an operation from
+          operand tile position.
+
+          Generates the IR that computes the tiled implementation of an
+          operation from operand tile.  The `offsets` and `sizes`
+          describe the tile of the operand required. This is different from
+          `getTiledImplementation` which generates the tiled
+          implementation of the operation given a tile of the
+          iteration space. This method generates a tiled
+          implementation of the operation based on the tile of the
+          operand required. This method enables consumer fusion by using
+          tile and fuse. The method returns failure if the operation
+          can't be tiled to generate the operand tile. In practical terms
+          this implies it cannot be tiled and fused with its producers.
+
+          - `offsets` provides the offset of the tile in the coordinate system
+            of the original iteration space, i.e., if an iteration space
+            dimension had non-zero offset, it must be included in the offset
+            provided here (as opposed to zero-based offset "relative" to the
+            iteration space).
+          - `sizes` provides the size of the tile.
+        }],
+        /*retType=*/"FailureOr<::mlir::TilingResult>",
+        /*methodName=*/"getTiledImplementationFromOperandTile",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult>":$offsets,
+          "ArrayRef<OpFoldResult>":$sizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Generates the scalar implementation of the operation.
@@ -142,7 +197,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           transformations are done, this method can be used to lower to scalar
           code that can then be lowered to LLVM or SPIR-V dialects.
         }],
-        /*retType=*/"LogicalResult",
+        /*retType=*/"::mlir::LogicalResult",
         /*methodName=*/"generateScalarImplementation",
         /*args=*/(ins
             "OpBuilder &":$b,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9c5c58fa1fabfb..e9999c34d0face 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2425,8 +2425,8 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
 
 LogicalResult SoftmaxOp::getResultTilePosition(
     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
-    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
-    SmallVector<OpFoldResult> &resultSizes) {
+    ArrayRef<OpFoldResult> sizes, SmallVectorImpl<OpFoldResult> &resultOffsets,
+    SmallVectorImpl<OpFoldResult> &resultSizes) {
   if (resultNumber == 0) {
     resultOffsets.assign(offsets.begin(), offsets.end());
     resultSizes.assign(sizes.begin(), sizes.end());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5d..71e9c3771dcded 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
         }));
   }
 
-  // Instantiate the tiled implementation of the operation.
+  /// Instantiate the tiled implementation of the operation.
   FailureOr<TilingResult>
   getTiledImplementation(Operation *op, OpBuilder &b,
                          ArrayRef<OpFoldResult> offsets,
@@ -132,14 +132,66 @@ struct LinalgOpTilingInterface
     return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
   }
 
-  // Return the details of the output tile generated by the tiled
-  // implementation.
+  void
+  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
+                         ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes,
+                         SmallVectorImpl<OpFoldResult> &mappedOffsets,
+                         SmallVectorImpl<OpFoldResult> &mappedSizes) const {
+    unsigned numLoops = linalgOp.getNumLoops();
+    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+    mappedOffsets.resize(numLoops);
+    mappedSizes.resize(numLoops);
+    if (!indexingMap.isPermutation()) {
+      SmallVector<Range> iterationDomain =
+          tilingInterfaceOp.getIterationDomain(b);
+      for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
+        mappedOffsets[index] = value.offset;
+        mappedSizes[index] = value.size;
+      }
+    }
+    for (const auto &&[index, value] :
+         llvm::enumerate(indexingMap.getResults())) {
+      unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
+      mappedOffsets[dimPosition] = offsets[index];
+      mappedSizes[dimPosition] = sizes[index];
+    }
+  }
+
+  /// Return the offsets and sizes of the iteration domain tile that
+  /// corresponds to the given tile of operand `operandNumber`.
+  LogicalResult getIterationDomainTileFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Check that the indexing map used for the operand is a projected
+    // permutation. This could be relaxed with a more general approach that can
+    // map the offsets and sizes from the operand to iteration space tiles
+    // (filling in full extent for dimensions not used to access the operand).
+    AffineMap indexingMap =
+        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+    if (!indexingMap.isProjectedPermutation()) {
+      return emitError(op->getLoc(),
+                       "unhandled get iter domain position when operand is not "
+                       "accessed using a permuted projection");
+    }
+
+    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+                           iterDomainOffsets, iterDomainSizes);
+    return success();
+  }
+
+  /// Return the details of the output tile generated by the tiled
+  /// implementation.
   LogicalResult
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     Location loc = op->getLoc();
     LinalgOp linalgOp = cast<LinalgOp>(op);
 
@@ -160,6 +212,21 @@ struct LinalgOpTilingInterface
     return success();
   }
 
+  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
+            b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+      return emitError(
+          op->getLoc(),
+          "unable to obtain the iter domain position of the operation.");
+    }
+    return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+                                                    mappedSizes);
+  }
+
   FailureOr<TilingResult>
   generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                           ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
           "unhandled tiled implementation generation when result is not "
           "accessed using a permuted projection");
     }
-
-    auto numLoops = linalgOp.getNumLoops();
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+                           mappedOffsets, mappedSizes);
     auto tilingInterfaceOp = cast<TilingInterface>(op);
-    SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
-        iterationTileSizes(numLoops);
-    if (!indexingMap.isPermutation()) {
-      SmallVector<Range> iterationDomain =
-          tilingInterfaceOp.getIterationDomain(b);
-      for (const auto &range : llvm::enumerate(iterationDomain)) {
-        iterationTileOffsets[range.index()] = range.value().offset;
-        iterationTileSizes[range.index()] = range.value().size;
-      }
-    }
-    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
-      unsigned dimPosition =
-          cast<AffineDimExpr>(resultExpr.value()).getPosition();
-      iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
-      iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
-    }
-
     FailureOr<TilingResult> tilingResult =
-        tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
-                                                 iterationTileSizes);
+        tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+    if (failed(tilingResult))
+      return failure();
+
     if (tilingResult->tiledOps.size() != 1)
       return op->emitOpError("failed to generate tiled implementation");
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index d25efcf50ec566..296c5fc7a5c2bd 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -61,8 +61,8 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     resultOffsets.assign(offsets.begin(), offsets.end());
     resultSizes.assign(sizes.begin(), sizes.end());
     return success();
@@ -199,8 +199,8 @@ struct PackOpTiling
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     // The iteration domain is over outer dimensions of packed layout. In this
     // context, the outer dimensions of `resultOffsets` are `offsets`. The
     // inner dimensions of `resultOffsets` are zeros because tiling is not
@@ -452,8 +452,8 @@ struct UnPackOpTiling
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     resultOffsets = llvm::to_vector(offsets);
     resultSizes = llvm::to_vector(sizes);
     return success();



More information about the Mlir-commits mailing list