[Mlir-commits] [mlir] [mlir][linalg] Enable fuse consumer (PR #89893)

donald chen llvmlistbot at llvm.org
Wed Apr 24 01:50:11 PDT 2024


https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/89893

>From cb24711f662f0a657f77e58630fd469519d3c495 Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Tue, 23 Apr 2024 12:28:19 +0000
Subject: [PATCH] [mlir][linalg] Enable fuse consumer

This patch adds support for consumer fusion to the tiling interface, and
implements fuse consumers on FuseIntoContainingOp.

- Add interface method 'getIterationDomainTileFromOperandTile' to the
tiling interface, which gets the iteration domain tile from an operand
tile.
- Add interface method 'getTiledImplementationFromOperandTile' to the
tiling interface, which generates the tiled implementation from an
operand tile.
- Implemented the above two methods and supported consumer fusion for
FuseIntoContainingOp.
---
 .../mlir/Interfaces/TilingInterface.td        |  67 +++-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp      |   4 +-
 .../Linalg/Transforms/TilingInterfaceImpl.cpp | 106 +++++--
 .../Tensor/IR/TensorTilingInterfaceImpl.cpp   |  12 +-
 .../Dialect/Linalg/test-fuse-consumer.mlir    | 103 ++++++
 mlir/test/lib/Dialect/Linalg/CMakeLists.txt   |   1 +
 .../Dialect/Linalg/TestLinalgFuseConsumer.cpp | 294 ++++++++++++++++++
 mlir/tools/mlir-opt/mlir-opt.cpp              |   2 +
 8 files changed, 549 insertions(+), 40 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/test-fuse-consumer.mlir
 create mode 100644 mlir/test/lib/Dialect/Linalg/TestLinalgFuseConsumer.cpp

diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..84f7dec2f4003d 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           The method returns the operation that is the tiled
           implementation.
         }],
-        /*retType=*/"FailureOr<TilingResult>",
+        /*retType=*/"FailureOr<::mlir::TilingResult>",
         /*methodName=*/"getTiledImplementation",
         /*args=*/(ins
             "OpBuilder &":$b,
@@ -82,15 +82,34 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           by the tiled implementation. Expects the same `offsets` and `sizes` as
           used to obtain the tiled implementation of the operation.
         }],
-        /*retType=*/"LogicalResult",
+        /*retType=*/"::mlir::LogicalResult",
         /*methodName=*/"getResultTilePosition",
         /*args=*/(ins
           "OpBuilder &":$b,
           "unsigned":$resultNumber,
           "ArrayRef<OpFoldResult> ":$offsets,
           "ArrayRef<OpFoldResult> ":$sizes,
-          "SmallVector<OpFoldResult> &":$resultOffsets,
-          "SmallVector<OpFoldResult> &":$resultSizes),
+          "SmallVectorImpl<OpFoldResult> &":$resultOffsets,
+          "SmallVectorImpl<OpFoldResult> &":$resultSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the position of iteration domain tile computed by the
+          tiled operation.
+        }],
+        /*retType=*/"::mlir::LogicalResult",
+        /*methodName=*/"getIterationDomainTileFromOperandTile",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult> ":$offsets,
+          "ArrayRef<OpFoldResult> ":$sizes,
+          "SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -119,7 +138,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
             iteration space).
           - `sizes` provides the size of the tile.
         }],
-        /*retType=*/"FailureOr<TilingResult>",
+        /*retType=*/"FailureOr<::mlir::TilingResult>",
         /*methodName=*/"generateResultTileValue",
         /*args=*/(ins
           "OpBuilder &":$b,
@@ -131,6 +150,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to generate the tiled implementation of an operation from
+          operand tile position.
+
+          Generates the IR that computes the tiled implementation of an
+          operation from operand tile.  The `offsets` and `sizes`
+          describe the tile of the operand required. This is different from
+          `getTiledImplementation` which generates the tiled
+          implementation of the operation given a tile of the
+          iteration space. This method generates a tiled
+          implementation of the operation based on the tile of the
+          operand required. This method enables consumer fusion by using
+          tile and fuse. The method returns failure if the operation
+          can't be tiled to generate the operand tile. In practical terms
+          this implies it cannot be tiled and fused with its producers.
+
+          - `offsets` provides the offset of the tile in the coordinate system
+            of the original iteration space, i.e., if an iteration space
+            dimension had non-zero offset, it must be included in the offset
+            provided here (as opposed to zero-based offset "relative" to the
+            iteration space).
+          - `sizes` provides the size of the tile.
+        }],
+        /*retType=*/"FailureOr<::mlir::TilingResult>",
+        /*methodName=*/"getTiledImplementationFromOperandTile",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult>":$offsets,
+          "ArrayRef<OpFoldResult>":$sizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Generates the scalar implementation of the operation.
@@ -142,7 +197,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           transformations are done, this method can be used to lower to scalar
           code that can then be lowered to LLVM or SPIR-V dialects.
         }],
-        /*retType=*/"LogicalResult",
+        /*retType=*/"::mlir::LogicalResult",
         /*methodName=*/"generateScalarImplementation",
         /*args=*/(ins
             "OpBuilder &":$b,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9c5c58fa1fabfb..e9999c34d0face 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2425,8 +2425,8 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
 
 LogicalResult SoftmaxOp::getResultTilePosition(
     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
-    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
-    SmallVector<OpFoldResult> &resultSizes) {
+    ArrayRef<OpFoldResult> sizes, SmallVectorImpl<OpFoldResult> &resultOffsets,
+    SmallVectorImpl<OpFoldResult> &resultSizes) {
   if (resultNumber == 0) {
     resultOffsets.assign(offsets.begin(), offsets.end());
     resultSizes.assign(sizes.begin(), sizes.end());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5d..71e9c3771dcded 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
         }));
   }
 
-  // Instantiate the tiled implementation of the operation.
+  /// Instantiate the tiled implementation of the operation.
   FailureOr<TilingResult>
   getTiledImplementation(Operation *op, OpBuilder &b,
                          ArrayRef<OpFoldResult> offsets,
@@ -132,14 +132,66 @@ struct LinalgOpTilingInterface
     return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
   }
 
-  // Return the details of the output tile generated by the tiled
-  // implementation.
+  void
+  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
+                         ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes,
+                         SmallVectorImpl<OpFoldResult> &mappedOffsets,
+                         SmallVectorImpl<OpFoldResult> &mappedSizes) const {
+    unsigned numLoops = linalgOp.getNumLoops();
+    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+    mappedOffsets.resize(numLoops);
+    mappedSizes.resize(numLoops);
+    if (!indexingMap.isPermutation()) {
+      SmallVector<Range> iterationDomain =
+          tilingInterfaceOp.getIterationDomain(b);
+      for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
+        mappedOffsets[index] = value.offset;
+        mappedSizes[index] = value.size;
+      }
+    }
+    for (const auto &&[index, value] :
+         llvm::enumerate(indexingMap.getResults())) {
+      unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
+      mappedOffsets[dimPosition] = offsets[index];
+      mappedSizes[dimPosition] = sizes[index];
+    }
+  }
+
+  /// Return the iteration domain tile (offsets and sizes) that corresponds to
+  /// the given tile of the operand at `operandNumber`.
+  LogicalResult getIterationDomainTileFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Check that the indexing map used for the operand is a projected
+    // permutation. This could be relaxed with a more general approach that can
+    // map the offsets and sizes from the operand to iteration space tiles
+    // (filling in full extent for dimensions not used to access the operand).
+    AffineMap indexingMap =
+        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+    if (!indexingMap.isProjectedPermutation()) {
+      return emitError(op->getLoc(),
+                       "unhandled get iter domain position when operand is not "
+                       "accessed using a permuted projection");
+    }
+
+    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+                           iterDomainOffsets, iterDomainSizes);
+    return success();
+  }
+
+  /// Return the details of the output tile generated by the tiled
+  /// implementation.
   LogicalResult
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     Location loc = op->getLoc();
     LinalgOp linalgOp = cast<LinalgOp>(op);
 
@@ -160,6 +212,21 @@ struct LinalgOpTilingInterface
     return success();
   }
 
+  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
+            b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+      return emitError(
+          op->getLoc(),
+          "unable to obtain the iter domain position of the operation.");
+    }
+    return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+                                                    mappedSizes);
+  }
+
   FailureOr<TilingResult>
   generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                           ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
           "unhandled tiled implementation generation when result is not "
           "accessed using a permuted projection");
     }
-
-    auto numLoops = linalgOp.getNumLoops();
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+                           mappedOffsets, mappedSizes);
     auto tilingInterfaceOp = cast<TilingInterface>(op);
-    SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
-        iterationTileSizes(numLoops);
-    if (!indexingMap.isPermutation()) {
-      SmallVector<Range> iterationDomain =
-          tilingInterfaceOp.getIterationDomain(b);
-      for (const auto &range : llvm::enumerate(iterationDomain)) {
-        iterationTileOffsets[range.index()] = range.value().offset;
-        iterationTileSizes[range.index()] = range.value().size;
-      }
-    }
-    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
-      unsigned dimPosition =
-          cast<AffineDimExpr>(resultExpr.value()).getPosition();
-      iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
-      iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
-    }
-
     FailureOr<TilingResult> tilingResult =
-        tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
-                                                 iterationTileSizes);
+        tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+    if (failed(tilingResult))
+      return failure();
+
     if (tilingResult->tiledOps.size() != 1)
       return op->emitOpError("failed to generate tiled implementation");
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index d25efcf50ec566..296c5fc7a5c2bd 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -61,8 +61,8 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     resultOffsets.assign(offsets.begin(), offsets.end());
     resultSizes.assign(sizes.begin(), sizes.end());
     return success();
@@ -199,8 +199,8 @@ struct PackOpTiling
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     // The iteration domain is over outer dimensions of packed layout. In this
     // context, the outer dimensions of `resultOffsets` are `offsets`. The
     // inner dimensions of `resultOffsets` are zeros because tiling is not
@@ -452,8 +452,8 @@ struct UnPackOpTiling
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     resultOffsets = llvm::to_vector(offsets);
     resultSizes = llvm::to_vector(sizes);
     return success();
diff --git a/mlir/test/Dialect/Linalg/test-fuse-consumer.mlir b/mlir/test/Dialect/Linalg/test-fuse-consumer.mlir
new file mode 100644
index 00000000000000..e7edbf0b2c25d4
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/test-fuse-consumer.mlir
@@ -0,0 +1,103 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-fuse-consumer | FileCheck %s
+
+#map = affine_map<()[s0] -> (64 ceildiv s0)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>
+// CHECK-LABEL: func.func @fuse_tileable_consumer
+// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+// CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<64xf32>
+// CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<64xf32>
+func.func @fuse_tileable_consumer(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+  // CHECK: %[[SLICE:.*]] = tensor.empty(%[[CHUNK_SIZE]]) : tensor<?xf32>
+  %0 = tensor.empty(%arg0) : tensor<?xf32>
+  %1 = affine.apply #map()[%arg0]
+  // CHECK: %[[EMPTY0:[0-9a-z]+]] = tensor.empty() : tensor<64xf32>
+  %2 = tensor.empty() : tensor<64xf32>
+  // CHECK: %[[EMPTY1:[0-9a-z]+]] = tensor.empty() : tensor<64xf32>
+  %3 = tensor.empty() : tensor<64xf32>
+  // CHECK: %[[RES:[0-9a-z]+]]:2 = scf.forall {{.*}} shared_outs(%[[LOOP_ARG0:.*]] = %[[OUT]], %[[LOOP_ARG1:.*]] = %[[EMPTY1]]
+  %4 = scf.forall (%arg3) in (%1) shared_outs(%arg4 = %arg2) -> (tensor<64xf32>) {
+    %6 = affine.apply #map1(%arg3)[%arg0]
+    %7 = affine.min #map2(%arg3)[%arg0]
+    // CHECK: %[[T0:.*]] = tensor.extract_slice %[[LOOP_ARG0]][%{{.*}}] [%{{.*}}] [{{.*}}]
+    %extracted_slice = tensor.extract_slice %arg4[%6] [%7] [1] : tensor<64xf32> to tensor<?xf32>
+    // CHECK: %[[T1:[0-9a-z]+]] = linalg.elemwise_unary
+    %8 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%extracted_slice : tensor<?xf32>) -> tensor<?xf32>
+
+    // CHECK: %[[T2:.*]] = tensor.extract_slice %[[EMPTY0]][%{{.*}}] [%{{.*}}] [{{.*}}]
+    // CHECK: %[[T3:.*]] = tensor.extract_slice %[[LOOP_ARG1]][%{{.*}}] [%{{.*}}] [{{.*}}]
+    // CHECK: %[[T4:.*]] = linalg.elemwise_binary {{.*}} ins(%[[T1]], %[[T2]] : {{.*}} outs(%[[T3]]
+
+    scf.forall.in_parallel {
+      // CHECK: tensor.parallel_insert_slice %[[T4]] into %[[LOOP_ARG1]]
+      // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[LOOP_ARG0]]
+      tensor.parallel_insert_slice %8 into %arg4[%6] [%7] [1] : tensor<?xf32> into tensor<64xf32>
+    }
+  } {"containing"}
+  // CHECK: %[[ORI_OUTPUT:.*]] = linalg.elemwise_binary
+  %5 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>, "consumer"} ins(%4, %2 : tensor<64xf32>, tensor<64xf32>) outs(%3 : tensor<64xf32>) -> tensor<64xf32>
+  // CHECK: return %[[RES]]#1
+  return %5 : tensor<64xf32>
+}
+// -----
+
+#map = affine_map<(d0) -> (d0 * -50 + 123, 50)>
+#map1 = affine_map<(d0) -> (d0 * -16 + 789, 16)>
+#map2 = affine_map<(d0) -> (d0 * 50)>
+#map3 = affine_map<(d0) -> (d0 * 16)>
+#map4 = affine_map<(d0, d1) -> (d0, d1)>
+#map5 = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-LABEL: func.func @fuse_consumer_multi_output
+// CHECK-SAME: %[[IN0:[0-9a-z]+]]: tensor<123x456xf32>
+// CHECK-SAME: %[[IN1:[0-9a-z]+]]: tensor<456x789xf32>
+// CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<123x789xf32>
+func.func @fuse_consumer_multi_output(%arg0: tensor<123x456xf32>, %arg1: tensor<456x789xf32>, %arg2: tensor<123x789xf32>) -> (tensor<123x789xf32>, tensor<789x123xf32>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[INIT:.*]] = linalg.fill
+  %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<123x789xf32>) -> tensor<123x789xf32>
+  // CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<123x789xf32>
+  %1 = tensor.empty() : tensor<123x789xf32>
+  // CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<789x123xf32>
+  %2 = tensor.empty() : tensor<789x123xf32>
+  // CHECK: %[[RES:[0-9a-z]+]]:3 = scf.forall {{.*}} shared_outs(%[[LOOP_ARG0:.*]] = %[[INIT]], %[[LOOP_ARG1:.*]] = %[[EMPTY0]], %[[LOOP_ARG2:.*]] = %[[EMPTY1]]
+  %3 = scf.forall (%arg3, %arg4) in (3, 50) shared_outs(%arg5 = %0) -> (tensor<123x789xf32>) {
+    %5 = affine.min #map(%arg3)
+    %6 = affine.min #map1(%arg4)
+    %7 = affine.apply #map2(%arg3)
+    %8 = affine.apply #map3(%arg4)
+    %9 = affine.apply #map2(%arg3)
+    %10 = affine.apply #map3(%arg4)
+    // CHECK: %[[EXTRACT_IN0:.*]] = tensor.extract_slice %[[IN0]]
+    %extracted_slice = tensor.extract_slice %arg0[%7, 0] [%5, 456] [1, 1] : tensor<123x456xf32> to tensor<?x456xf32>
+    // CHECK: %[[EXTRACT_IN1:.*]] = tensor.extract_slice %[[IN1]]
+    %extracted_slice_0 = tensor.extract_slice %arg1[0, %8] [456, %6] [1, 1] : tensor<456x789xf32> to tensor<456x?xf32>
+    // CHECK: %[[EXTRACT_OUT:.*]] = tensor.extract_slice %[[LOOP_ARG0]]
+    %extracted_slice_1 = tensor.extract_slice %arg5[%9, %10] [%5, %6] [1, 1] : tensor<123x789xf32> to tensor<?x?xf32>
+    // CHECK: %[[MATMUL_RES:.*]] = linalg.matmul ins(%[[EXTRACT_IN0]], %[[EXTRACT_IN1]] {{.*}} outs(%[[EXTRACT_OUT]]
+    %11 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<?x456xf32>, tensor<456x?xf32>) outs(%extracted_slice_1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+    // CHECK: %[[EXTRACT_EMPTY0:.*]] = tensor.extract_slice %[[LOOP_ARG1]]
+    // CHECK: %[[EXTRACT_EMPTY1:.*]] = tensor.extract_slice %[[LOOP_ARG2]]
+    // CHECK: %[[GENERIC_RES:.*]]:2 = linalg.generic {{.*}} ins(%[[MATMUL_RES]] : tensor<?x?xf32>) outs(%[[EXTRACT_EMPTY0]], %[[EXTRACT_EMPTY1]]
+
+    %12 = affine.apply #map2(%arg3)
+    %13 = affine.apply #map3(%arg4)
+    scf.forall.in_parallel {
+      // CHECK: tensor.parallel_insert_slice %[[GENERIC_RES]]#0 into %[[LOOP_ARG1]]
+      // CHECK: tensor.parallel_insert_slice %[[GENERIC_RES]]#1 into %[[LOOP_ARG2]]
+      // CHECK: tensor.parallel_insert_slice %[[MATMUL_RES]] into %[[LOOP_ARG0]]
+      tensor.parallel_insert_slice %11 into %arg5[%12, %13] [%5, %6] [1, 1] : tensor<?x?xf32> into tensor<123x789xf32>
+    }
+  } {"containing"}
+  // CHECK: %[[ORI_OUTPUT:.*]]:2 = linalg.generic
+  %4:2 = linalg.generic {"consumer", indexing_maps = [#map4, #map4, #map5], iterator_types = ["parallel", "parallel"]} ins(%3 : tensor<123x789xf32>) outs(%1, %2 : tensor<123x789xf32>, tensor<789x123xf32>) {
+  ^bb0(%in: f32, %out: f32, %out_0: f32):
+    %5 = arith.addf %in, %out : f32
+    %6 = arith.addf %5, %out_0 : f32
+    linalg.yield %5, %6 : f32, f32
+  } -> (tensor<123x789xf32>, tensor<789x123xf32>)
+  // CHECK: return %[[RES]]#1, %[[RES]]#2
+  return %4#0, %4#1 : tensor<123x789xf32>, tensor<789x123xf32>
+}
+
+
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index b28f2b3564662a..23479912a865cd 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_library(MLIRLinalgTestPasses
   TestLinalgFusionTransforms.cpp
   TestLinalgTransforms.cpp
   TestPadFusion.cpp
+  TestLinalgFuseConsumer.cpp
 
   EXCLUDE_FROM_LIBMLIR
 
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFuseConsumer.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFuseConsumer.cpp
new file mode 100644
index 00000000000000..ae52684b204483
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFuseConsumer.cpp
@@ -0,0 +1,294 @@
+//===- TestLinalgFuseConsumer.cpp - Test Linalg fuse consumer  ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing consumer fusion of Linalg ops.
+// This is a temporary pass used to verify the correctness of the tiling
+// interface implementation for Linalg ops and the related consumer-fusion
+// interface methods. It should be replaced by the corresponding fusion
+// transform op once that op is implemented.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "fuse-consumer"
+
+namespace {
+struct TestLinalgFuseConsumer
+    : public PassWrapper<TestLinalgFuseConsumer, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFuseConsumer)
+
+  TestLinalgFuseConsumer() = default;
+  TestLinalgFuseConsumer(const TestLinalgFuseConsumer &pass)
+      : PassWrapper(pass){};
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<scf::SCFDialect, linalg::LinalgDialect,
+                    tensor::TensorDialect>();
+  }
+  StringRef getArgument() const final { return "test-linalg-fuse-consumer"; }
+  StringRef getDescription() const final {
+    return "Test Linalg fuse consumer interface";
+  }
+
+  void runOnOperation() override {
+    Operation *consumerOp = nullptr, *containingOp = nullptr;
+    auto walkRes = getOperation()->walk([&](Operation *op) {
+      if (op->hasAttr("consumer")) {
+        if (consumerOp) {
+          return WalkResult::interrupt();
+        }
+        consumerOp = op;
+      }
+      if (op->hasAttr("containing")) {
+        if (containingOp) {
+          return WalkResult::interrupt();
+        }
+        containingOp = op;
+      }
+      return WalkResult::advance();
+    });
+
+    if (!consumerOp || !containingOp || walkRes.wasInterrupted()) {
+      emitError(getOperation()->getLoc())
+          << "expect 1 consumer and 1 containing op.";
+      return;
+    }
+
+    // Check consumer has tiling interface.
+    auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+    if (!tileableConsumer) {
+      emitError(consumerOp->getLoc())
+          << "consumer is not a TileableInterface: " << *consumerOp;
+      return;
+    }
+
+    // Check containing op is "scf::ForallOp".
+    auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+    if (!forallOp) {
+      emitError(containingOp->getLoc())
+          << "containing op is not a scf.forall: " << containingOp;
+      return;
+    }
+
+    // Check dominance.
+    DominanceInfo domInfo(getOperation());
+    if (llvm::any_of(consumerOp->getOperands(), [&](Value v) {
+          return v.getDefiningOp() != containingOp &&
+                 !domInfo.properlyDominates(v, containingOp);
+        })) {
+      emitError(consumerOp->getLoc())
+          << "consumer's operand can't dominate containing op";
+      return;
+    }
+
+    // Check the consumer doesn't use more than one result of containingOp.
+    Value bridge(nullptr);
+    SmallVector<unsigned> operandNums;
+    for (auto [idx, opd] : llvm::enumerate((consumerOp->getOperands()))) {
+      if (opd.getDefiningOp() == containingOp) {
+        operandNums.push_back(idx);
+        if (!bridge) {
+          bridge = opd;
+        } else if (bridge != opd) {
+          emitError(consumerOp->getLoc())
+              << "consumer's operand use more than one containingOp's result";
+          return;
+        }
+      }
+    }
+
+    // Check consumer has DestinationStyleOpInterface.
+    auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+    if (!dstOp) {
+      emitError(consumerOp->getLoc())
+          << "consumer op should have destination style op interface";
+      return;
+    }
+
+    // Check the consumer doesn't use scf.forall's output as init.
+    SmallVector<Value> dpsInits = llvm::to_vector<4>(
+        llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+    if (llvm::is_contained(dpsInits, bridge)) {
+      emitError(consumerOp->getLoc())
+          << "consumer op take result of scf.forall as init";
+      return;
+    }
+
+    // Check result was inserted only once.
+    int64_t bridgeResultIdx = cast<OpResult>(bridge).getResultNumber();
+    auto bridgeBlockArg = forallOp.getRegionOutArgs()[bridgeResultIdx];
+    scf::InParallelOp terminatorOp = forallOp.getTerminator();
+
+    tensor::ParallelInsertSliceOp targetInsertOp(nullptr);
+    for (Operation &op : terminatorOp.getRegion().front().getOperations()) {
+      auto parallelInsertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
+      if (parallelInsertSliceOp.getDest() == bridgeBlockArg) {
+        if (!targetInsertOp) {
+          targetInsertOp = parallelInsertSliceOp;
+        } else {
+          emitError(containingOp->getLoc())
+              << "containingOp's result inserted multi time";
+          return;
+        }
+      }
+    }
+
+    if (!targetInsertOp) {
+      emitError(containingOp->getLoc())
+          << "containingOp's result was not inserted";
+      return;
+    }
+
+    SmallVector<OpFoldResult> offsets = targetInsertOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = targetInsertOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = targetInsertOp.getMixedStrides();
+
+    // Check that all insert strides are 1.
+    if (llvm::any_of(strides, [](OpFoldResult foldRes) {
+          if (auto attr = foldRes.dyn_cast<Attribute>()) {
+            return cast<IntegerAttr>(attr).getInt() != 1;
+          }
+          return true;
+        })) {
+      emitError(containingOp->getLoc())
+          << "containingOp's result yield with stride";
+      return;
+    }
+
+    IRRewriter rewriter(terminatorOp);
+    Location loc = forallOp.getLoc();
+
+    SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+
+    // Try to get iter domain position from input position.
+    if (failed(tileableConsumer.getIterationDomainTileFromOperandTile(
+            rewriter, operandNums.front(), offsets, sizes, iterDomainOffsets,
+            iterDomainSizes))) {
+      emitError(consumerOp->getLoc())
+          << "can't get iter domain position from input position";
+      return;
+    }
+
+    // Try to get all consumer op results' positions from iter domain position.
+    llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+                                llvm::SmallVector<OpFoldResult>>>
+        resultPositions(consumerOp->getNumResults());
+    for (auto [idx, v] : llvm::enumerate(consumerOp->getResults())) {
+      if (failed(tileableConsumer.getResultTilePosition(
+              rewriter, idx, iterDomainOffsets, iterDomainSizes,
+              resultPositions[idx].first, resultPositions[idx].second))) {
+        emitError(consumerOp->getLoc())
+            << "can't get result domain position from iter domain position";
+        return;
+      }
+    }
+
+    // All checks passed, try to fuse the consumer.
+    // Create the tiled implementation of the consumer op.
+    FailureOr<TilingResult> tileAndFuseResult =
+        tileableConsumer.getTiledImplementationFromOperandTile(
+            rewriter, operandNums.front(), offsets, sizes);
+    if (failed(tileAndFuseResult)) {
+      emitError(consumerOp->getLoc()) << "get tiled implementation failed";
+      return;
+    }
+
+    auto tiledOps = tileAndFuseResult->tiledOps;
+    if (failed(tileAndFuseResult) || tiledOps.size() != 1) {
+      emitError(consumerOp->getLoc())
+          << "failed to tile consumer op: " << *tileableConsumer;
+      return;
+    }
+
+    // Replace tiled op's operand.
+    for (auto operandNum : operandNums) {
+      tiledOps[0]->setOperand(operandNum, targetInsertOp.getSource());
+    }
+    rewriter.replaceUsesWithIf(bridge, forallOp.getOutputs()[bridgeResultIdx],
+                               [&](OpOperand &use) {
+                                 Operation *op = use.getOwner();
+                                 return forallOp->isProperAncestor(op);
+                               });
+
+    SmallVector<Value> newOuts(forallOp.getOutputs());
+    newOuts.append(dpsInits);
+
+    // Create new scf.forall op.
+    rewriter.setInsertionPoint(forallOp);
+    auto newforallOp = rewriter.create<scf::ForallOp>(
+        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+        forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+    rewriter.eraseBlock(newforallOp.getBody());
+    newforallOp.getRegion().takeBody(forallOp.getRegion());
+
+    for (auto v : dpsInits) {
+      newforallOp.getBody()->addArgument(v.getType(), v.getLoc());
+      auto bbArgs = newforallOp.getBody()->getArguments();
+      rewriter.replaceUsesWithIf(v, bbArgs.back(), [&](OpOperand &use) {
+        Operation *op = use.getOwner();
+        return newforallOp->isProperAncestor(op);
+      });
+    }
+
+    // Fix terminator.
+    scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
+    SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
+        newTerminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
+    Operation *firstYieldOp = yieldingOps.front();
+    rewriter.setInsertionPoint(firstYieldOp);
+    auto bbArgs = newforallOp.getBody()->getArguments();
+    for (auto [idx, v] : llvm::enumerate(tiledOps[0]->getResults())) {
+      SmallVector<OpFoldResult> strides(resultPositions[idx].first.size(),
+                                        rewriter.getIndexAttr(1));
+      rewriter.create<tensor::ParallelInsertSliceOp>(
+          firstYieldOp->getLoc(), v,
+          bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
+          resultPositions[idx].first, resultPositions[idx].second, strides);
+    }
+
+    // Replace the result of forall and consumer op.
+    for (auto result : llvm::enumerate(forallOp.getResults())) {
+      rewriter.replaceAllUsesWith(result.value(),
+                                  newforallOp->getResult(result.index()));
+    }
+
+    for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+      rewriter.replaceAllUsesWith(
+          consumerResult.value(),
+          newforallOp->getResult(forallOp.getOutputs().size() +
+                                 consumerResult.index()));
+    }
+    forallOp.erase();
+    return;
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestLinalgFuseConsumer() {
+  PassRegistration<TestLinalgFuseConsumer>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 237ebeb166dc99..87600c5c161d28 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -106,6 +106,7 @@ void registerTestLinalgDropUnitDims();
 void registerTestLinalgElementwiseFusion();
 void registerTestLinalgGreedyFusion();
 void registerTestLinalgTransforms();
+void registerTestLinalgFuseConsumer();
 void registerTestLivenessAnalysisPass();
 void registerTestLivenessPass();
 void registerTestLoopFusion();
@@ -235,6 +236,7 @@ void registerTestPasses() {
   mlir::test::registerTestLinalgElementwiseFusion();
   mlir::test::registerTestLinalgGreedyFusion();
   mlir::test::registerTestLinalgTransforms();
+  mlir::test::registerTestLinalgFuseConsumer();
   mlir::test::registerTestLivenessAnalysisPass();
   mlir::test::registerTestLivenessPass();
   mlir::test::registerTestLoopFusion();



More information about the Mlir-commits mailing list