[Mlir-commits] [mlir] [mlir][linalg] Enable fuse consumer (PR #85528)

donald chen llvmlistbot at llvm.org
Wed Apr 17 06:34:24 PDT 2024


https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/85528

>From 8f7208b63c2a0b50084fe5983fd8ea6929e0187e Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Sat, 16 Mar 2024 09:12:12 +0800
Subject: [PATCH] [mlir][linalg] Enable fuse consumer

This patch adds support for consumer fusion to the tiling interface.

- Add interface method 'getIterationDomainTileFromOperandTile' to tiling
  interface which get iteration domain position from operand position.
- Add interface method 'getTiledImplementationFromOperandTile' to tiling
  interface which generate tiled implementation according to operand position.
---
 .../mlir/Interfaces/TilingInterface.td        | 55 +++++++++++
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 95 +++++++++++++++----
 2 files changed, 129 insertions(+), 21 deletions(-)

diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..9fff3465ad9e72 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -96,6 +96,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the position of iteration domain tile computed by the
+          tiled operation.
+        }],
+        /*retType=*/"LogicalResult",
+        /*methodName=*/"getIterationDomainTileFromOperandTile",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult> ":$offsets,
+          "ArrayRef<OpFoldResult> ":$sizes,
+          "SmallVector<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVector<OpFoldResult> &":$iterDomainSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to generate the code that produces a tile of the result.
@@ -131,6 +150,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to generate the tiled implementation of an operation from
+          operand tile position.
+
+          Generates the IR that computes the tiled implementation of an
+          operation from operand tile.  The `offsets` and `sizes`
+          describe the tile of the operand required. This is different from
+          `getTiledImplementation` which generates the tiled
+          implementation of the operation given a tile of the
+          iteration space. This method generates a tiled
+          implementation of the operation based on the tile of the
+          operand required. This method enables fusion consumer by using
+          tile and fuse. The method returns failure if the operation
+          can't be tiled to generate the operand tile. In practical terms
+          this implies it cannot be tiled and fused with its producers.
+
+          - `offsets` provides the offset of the tile in the coordinate system
+            of the original iteration space, i.e., if an iteration space
+            dimension had non-zero offset, it must be included in the offset
+            provided here (as opposed to zero-based offset "relative" to the
+            iteration space).
+          - `sizes` provides the size of the tile.
+        }],
+        /*retType=*/"FailureOr<TilingResult>",
+        /*methodName=*/"getTiledImplementationFromOperandTile",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult>":$offsets,
+          "ArrayRef<OpFoldResult>":$sizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Generates the scalar implementation of the operation.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5d..094ae81d2a01fc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -132,6 +132,58 @@ struct LinalgOpTilingInterface
     return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
   }
 
+  void getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b,
+                              AffineMap indexingMap,
+                              ArrayRef<OpFoldResult> offsets,
+                              ArrayRef<OpFoldResult> sizes,
+                              SmallVector<OpFoldResult> &mappedOffsets,
+                              SmallVector<OpFoldResult> &mappedSizes) const {
+    auto numLoops = linalgOp.getNumLoops();
+    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+    mappedOffsets.resize(numLoops);
+    mappedSizes.resize(numLoops);
+    if (!indexingMap.isPermutation()) {
+      SmallVector<Range> iterationDomain =
+          tilingInterfaceOp.getIterationDomain(b);
+      for (const auto &range : llvm::enumerate(iterationDomain)) {
+        mappedOffsets[range.index()] = range.value().offset;
+        mappedSizes[range.index()] = range.value().size;
+      }
+    }
+    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
+      unsigned dimPosition =
+          cast<AffineDimExpr>(resultExpr.value()).getPosition();
+      mappedOffsets[dimPosition] = offsets[resultExpr.index()];
+      mappedSizes[dimPosition] = sizes[resultExpr.index()];
+    }
+  }
+
+  // Return the details of the output tile generated by the tiled
+  // implementation.
+  LogicalResult getIterationDomainTileFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVector<OpFoldResult> &iterDomainOffsets,
+      SmallVector<OpFoldResult> &iterDomainSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Check that the indexing map used for the operand is a projected
+    // permutation. This could be relaxed with a more general approach that can
+    // map the offsets and sizes from the operand to iteration space tiles
+    // (filling in full extent for dimensions not used to access the result).
+    AffineMap indexingMap =
+        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+    if (!indexingMap.isProjectedPermutation()) {
+      return op->emitOpError(
+          "unhandled get iter domain position when operand is not "
+          "accessed using a permuted projection");
+    }
+
+    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+                           iterDomainOffsets, iterDomainSizes);
+    return success();
+  }
+
   // Return the details of the output tile generated by the tiled
   // implementation.
   LogicalResult
@@ -160,6 +212,20 @@ struct LinalgOpTilingInterface
     return success();
   }
 
+  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
+            b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+      return op->emitOpError(
+          "unable to obtain the iter domain position of the operation.");
+    }
+    return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+                                                    mappedSizes);
+  }
+
   FailureOr<TilingResult>
   generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                           ArrayRef<OpFoldResult> offsets,
@@ -177,29 +243,16 @@ struct LinalgOpTilingInterface
           "unhandled tiled implementation generation when result is not "
           "accessed using a permuted projection");
     }
-
-    auto numLoops = linalgOp.getNumLoops();
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+                           mappedOffsets, mappedSizes);
     auto tilingInterfaceOp = cast<TilingInterface>(op);
-    SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
-        iterationTileSizes(numLoops);
-    if (!indexingMap.isPermutation()) {
-      SmallVector<Range> iterationDomain =
-          tilingInterfaceOp.getIterationDomain(b);
-      for (const auto &range : llvm::enumerate(iterationDomain)) {
-        iterationTileOffsets[range.index()] = range.value().offset;
-        iterationTileSizes[range.index()] = range.value().size;
-      }
-    }
-    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
-      unsigned dimPosition =
-          cast<AffineDimExpr>(resultExpr.value()).getPosition();
-      iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
-      iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
-    }
-
     FailureOr<TilingResult> tilingResult =
-        tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
-                                                 iterationTileSizes);
+        tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+    if (failed(tilingResult))
+      return failure();
+
     if (tilingResult->tiledOps.size() != 1)
       return op->emitOpError("failed to generate tiled implementation");
 



More information about the Mlir-commits mailing list