[Mlir-commits] [mlir] [mlir][linalg] Enable fuse consumer (PR #89893)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 24 01:46:38 PDT 2024


llvmbot wrote:


@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: donald chen (cxy-1993)

<details>
<summary>Changes</summary>

This patch adds support for consumer fusion to the tiling interface and implements consumer fusion on FuseIntoContainingOp.

- Add the interface method 'getIterationDomainTileFromOperandTile' to the tiling interface, which computes the iteration domain tile (offsets and sizes) from an operand tile.
- Add the interface method 'getTiledImplementationFromOperandTile' to the tiling interface, which generates the tiled implementation from an operand tile.
- Implement the above two methods and support consumer fusion for FuseIntoContainingOp.
- Add tests for the fuse consumer interface.
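
To make the operand-to-iteration-domain mapping in the first bullet concrete, here is a minimal sketch in plain C++ with no MLIR dependency. It is a simplified model, not the code in this patch: `Tile`, `ProjectedPermutation`, and `mapOperandTileToIterationDomain` are hypothetical names, offsets and sizes are plain integers rather than `OpFoldResult`, and the indexing map is assumed to be a projected permutation stored as a list of dimension positions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one dimension of a tile or iteration domain.
struct Tile {
  int64_t offset;
  int64_t size;
};

// Simplified model of a projected-permutation indexing map (an assumption of
// this sketch): entry i is the iteration-domain dimension that indexes
// operand dimension i.
using ProjectedPermutation = std::vector<unsigned>;

// Maps a tile of an operand to a tile of the iteration domain. Dimensions
// that the operand does not index keep their full extent.
std::vector<Tile>
mapOperandTileToIterationDomain(const std::vector<Tile> &iterationDomain,
                                const ProjectedPermutation &indexingMap,
                                const std::vector<Tile> &operandTile) {
  assert(indexingMap.size() == operandTile.size());
  // Start from the full iteration domain.
  std::vector<Tile> result = iterationDomain;
  // Overwrite every dimension that the operand actually indexes.
  for (std::size_t i = 0; i < indexingMap.size(); ++i) {
    assert(indexingMap[i] < result.size());
    result[indexingMap[i]] = operandTile[i];
  }
  return result;
}
```

For a matmul with iteration domain (m, n, k) = (123, 789, 456), the left-hand operand is indexed by (m, k), so its map in this model is `{0, 2}`. A tile of that operand with offsets (50, 0) and sizes (50, 456) then maps to the iteration-domain tile m = [50, +50), n = [0, +789), k = [0, +456): the untouched n dimension keeps its full range.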

---

Patch is 33.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89893.diff


8 Files Affected:

- (modified) mlir/include/mlir/Interfaces/TilingInterface.td (+61-6) 
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+2-2) 
- (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+80-26) 
- (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (+6-6) 
- (added) mlir/test/Dialect/Linalg/test-fuse-consumer.mlir (+103) 
- (modified) mlir/test/lib/Dialect/Linalg/CMakeLists.txt (+1) 
- (added) mlir/test/lib/Dialect/Linalg/TestLinalgFuseConsumer.cpp (+294) 
- (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2) 


``````````diff
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..84f7dec2f4003d 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           The method returns the operation that is the tiled
           implementation.
         }],
-        /*retType=*/"FailureOr<TilingResult>",
+        /*retType=*/"FailureOr<::mlir::TilingResult>",
         /*methodName=*/"getTiledImplementation",
         /*args=*/(ins
             "OpBuilder &":$b,
@@ -82,15 +82,34 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           by the tiled implementation. Expects the same `offsets` and `sizes` as
           used to obtain the tiled implementation of the operation.
         }],
-        /*retType=*/"LogicalResult",
+        /*retType=*/"::mlir::LogicalResult",
         /*methodName=*/"getResultTilePosition",
         /*args=*/(ins
           "OpBuilder &":$b,
           "unsigned":$resultNumber,
           "ArrayRef<OpFoldResult> ":$offsets,
           "ArrayRef<OpFoldResult> ":$sizes,
-          "SmallVector<OpFoldResult> &":$resultOffsets,
-          "SmallVector<OpFoldResult> &":$resultSizes),
+          "SmallVectorImpl<OpFoldResult> &":$resultOffsets,
+          "SmallVectorImpl<OpFoldResult> &":$resultSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the position of iteration domain tile computed by the
+          tiled operation.
+        }],
+        /*retType=*/"::mlir::LogicalResult",
+        /*methodName=*/"getIterationDomainTileFromOperandTile",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult> ":$offsets,
+          "ArrayRef<OpFoldResult> ":$sizes,
+          "SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -119,7 +138,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
             iteration space).
           - `sizes` provides the size of the tile.
         }],
-        /*retType=*/"FailureOr<TilingResult>",
+        /*retType=*/"FailureOr<::mlir::TilingResult>",
         /*methodName=*/"generateResultTileValue",
         /*args=*/(ins
           "OpBuilder &":$b,
@@ -131,6 +150,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to generate the tiled implementation of an operation from
+          operand tile position.
+
+          Generates the IR that computes the tiled implementation of an
+          operation from operand tile.  The `offsets` and `sizes`
+          describe the tile of the operand required. This is different from
+          `getTiledImplementation` which generates the tiled
+          implementation of the operation given a tile of the
+          iteration space. This method generates a tiled
+          implementation of the operation based on the tile of the
+          operand required. This method enables consumer fusion by using
+          tile and fuse. The method returns failure if the operation
+          can't be tiled to generate the operand tile. In practical terms
+          this implies it cannot be tiled and fused with its producers.
+
+          - `offsets` provides the offset of the tile in the coordinate system
+            of the original iteration space, i.e., if an iteration space
+            dimension had non-zero offset, it must be included in the offset
+            provided here (as opposed to zero-based offset "relative" to the
+            iteration space).
+          - `sizes` provides the size of the tile.
+        }],
+        /*retType=*/"FailureOr<::mlir::TilingResult>",
+        /*methodName=*/"getTiledImplementationFromOperandTile",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult>":$offsets,
+          "ArrayRef<OpFoldResult>":$sizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Generates the scalar implementation of the operation.
@@ -142,7 +197,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           transformations are done, this method can be used to lower to scalar
           code that can then be lowered to LLVM or SPIR-V dialects.
         }],
-        /*retType=*/"LogicalResult",
+        /*retType=*/"::mlir::LogicalResult",
         /*methodName=*/"generateScalarImplementation",
         /*args=*/(ins
             "OpBuilder &":$b,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9c5c58fa1fabfb..e9999c34d0face 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2425,8 +2425,8 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
 
 LogicalResult SoftmaxOp::getResultTilePosition(
     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
-    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
-    SmallVector<OpFoldResult> &resultSizes) {
+    ArrayRef<OpFoldResult> sizes, SmallVectorImpl<OpFoldResult> &resultOffsets,
+    SmallVectorImpl<OpFoldResult> &resultSizes) {
   if (resultNumber == 0) {
     resultOffsets.assign(offsets.begin(), offsets.end());
     resultSizes.assign(sizes.begin(), sizes.end());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5d..71e9c3771dcded 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
         }));
   }
 
-  // Instantiate the tiled implementation of the operation.
+  /// Instantiate the tiled implementation of the operation.
   FailureOr<TilingResult>
   getTiledImplementation(Operation *op, OpBuilder &b,
                          ArrayRef<OpFoldResult> offsets,
@@ -132,14 +132,66 @@ struct LinalgOpTilingInterface
     return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
   }
 
-  // Return the details of the output tile generated by the tiled
-  // implementation.
+  void
+  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
+                         ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes,
+                         SmallVectorImpl<OpFoldResult> &mappedOffsets,
+                         SmallVectorImpl<OpFoldResult> &mappedSizes) const {
+    unsigned numLoops = linalgOp.getNumLoops();
+    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+    mappedOffsets.resize(numLoops);
+    mappedSizes.resize(numLoops);
+    if (!indexingMap.isPermutation()) {
+      SmallVector<Range> iterationDomain =
+          tilingInterfaceOp.getIterationDomain(b);
+      for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
+        mappedOffsets[index] = value.offset;
+        mappedSizes[index] = value.size;
+      }
+    }
+    for (const auto &&[index, value] :
+         llvm::enumerate(indexingMap.getResults())) {
+      unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
+      mappedOffsets[dimPosition] = offsets[index];
+      mappedSizes[dimPosition] = sizes[index];
+    }
+  }
+
+  /// Return the details of the output tile generated by the tiled
+  /// implementation.
+  LogicalResult getIterationDomainTileFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Check that the indexing map used for the operand is a projected
+    // permutation. This could be relaxed with a more general approach that can
+    // map the offsets and sizes from the operand to iteration space tiles
+    // (filling in full extent for dimensions not used to access the result).
+    AffineMap indexingMap =
+        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+    if (!indexingMap.isProjectedPermutation()) {
+      return emitError(op->getLoc(),
+                       "unhandled get iter domain position when operand is not "
+                       "accessed using a permuted projection");
+    }
+
+    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+                           iterDomainOffsets, iterDomainSizes);
+    return success();
+  }
+
+  /// Return the details of the output tile generated by the tiled
+  /// implementation.
   LogicalResult
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     Location loc = op->getLoc();
     LinalgOp linalgOp = cast<LinalgOp>(op);
 
@@ -160,6 +212,21 @@ struct LinalgOpTilingInterface
     return success();
   }
 
+  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
+            b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+      return emitError(
+          op->getLoc(),
+          "unable to obtain the iter domain position of the operation.");
+    }
+    return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+                                                    mappedSizes);
+  }
+
   FailureOr<TilingResult>
   generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                           ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
           "unhandled tiled implementation generation when result is not "
           "accessed using a permuted projection");
     }
-
-    auto numLoops = linalgOp.getNumLoops();
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+                           mappedOffsets, mappedSizes);
     auto tilingInterfaceOp = cast<TilingInterface>(op);
-    SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
-        iterationTileSizes(numLoops);
-    if (!indexingMap.isPermutation()) {
-      SmallVector<Range> iterationDomain =
-          tilingInterfaceOp.getIterationDomain(b);
-      for (const auto &range : llvm::enumerate(iterationDomain)) {
-        iterationTileOffsets[range.index()] = range.value().offset;
-        iterationTileSizes[range.index()] = range.value().size;
-      }
-    }
-    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
-      unsigned dimPosition =
-          cast<AffineDimExpr>(resultExpr.value()).getPosition();
-      iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
-      iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
-    }
-
     FailureOr<TilingResult> tilingResult =
-        tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
-                                                 iterationTileSizes);
+        tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+    if (failed(tilingResult))
+      return failure();
+
     if (tilingResult->tiledOps.size() != 1)
       return op->emitOpError("failed to generate tiled implementation");
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index d25efcf50ec566..296c5fc7a5c2bd 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -61,8 +61,8 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     resultOffsets.assign(offsets.begin(), offsets.end());
     resultSizes.assign(sizes.begin(), sizes.end());
     return success();
@@ -199,8 +199,8 @@ struct PackOpTiling
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     // The iteration domain is over outer dimensions of packed layout. In this
     // context, the outer dimensions of `resultOffsets` are `offsets`. The
     // inner dimensions of `resultOffsets` are zeros because tiling is not
@@ -452,8 +452,8 @@ struct UnPackOpTiling
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     resultOffsets = llvm::to_vector(offsets);
     resultSizes = llvm::to_vector(sizes);
     return success();
diff --git a/mlir/test/Dialect/Linalg/test-fuse-consumer.mlir b/mlir/test/Dialect/Linalg/test-fuse-consumer.mlir
new file mode 100644
index 00000000000000..e7edbf0b2c25d4
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/test-fuse-consumer.mlir
@@ -0,0 +1,103 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-fuse-consumer | FileCheck %s
+
+#map = affine_map<()[s0] -> (64 ceildiv s0)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>
+// CHECK-LABEL: func.func @fuse_tileable_consumer
+// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+// CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<64xf32>
+// CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<64xf32>
+func.func @fuse_tileable_consumer(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+  // CHECK: %[[SLICE:.*]] = tensor.empty(%[[CHUNK_SIZE]]) : tensor<?xf32>
+  %0 = tensor.empty(%arg0) : tensor<?xf32>
+  %1 = affine.apply #map()[%arg0]
+  // CHECK: %[[EMPTY0:[0-9a-z]+]] = tensor.empty() : tensor<64xf32>
+  %2 = tensor.empty() : tensor<64xf32>
+  // CHECK: %[[EMPTY1:[0-9a-z]+]] = tensor.empty() : tensor<64xf32>
+  %3 = tensor.empty() : tensor<64xf32>
+  // CHECK: %[[RES:[0-9a-z]+]]:2 = scf.forall {{.*}} shared_outs(%[[LOOP_ARG0:.*]] = %[[OUT]], %[[LOOP_ARG1:.*]] = %[[EMPTY1]]
+  %4 = scf.forall (%arg3) in (%1) shared_outs(%arg4 = %arg2) -> (tensor<64xf32>) {
+    %6 = affine.apply #map1(%arg3)[%arg0]
+    %7 = affine.min #map2(%arg3)[%arg0]
+    // CHECK: %[[T0:.*]] = tensor.extract_slice %[[LOOP_ARG0]][%{{.*}}] [%{{.*}}] [{{.*}}]
+    %extracted_slice = tensor.extract_slice %arg4[%6] [%7] [1] : tensor<64xf32> to tensor<?xf32>
+    // CHECK: %[[T1:[0-9a-z]+]] = linalg.elemwise_unary
+    %8 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%extracted_slice : tensor<?xf32>) -> tensor<?xf32>
+
+    // CHECK: %[[T2:.*]] = tensor.extract_slice %[[EMPTY0]][%{{.*}}] [%{{.*}}] [{{.*}}]
+    // CHECK: %[[T3:.*]] = tensor.extract_slice %[[LOOP_ARG1]][%{{.*}}] [%{{.*}}] [{{.*}}]
+    // CHECK: %[[T4:.*]] = linalg.elemwise_binary {{.*}} ins(%[[T1]], %[[T2]] : {{.*}} outs(%[[T3]]
+
+    scf.forall.in_parallel {
+      // CHECK: tensor.parallel_insert_slice %[[T4]] into %[[LOOP_ARG1]]
+      // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[LOOP_ARG0]]
+      tensor.parallel_insert_slice %8 into %arg4[%6] [%7] [1] : tensor<?xf32> into tensor<64xf32>
+    }
+  } {"containing"}
+  // CHECK: %[[ORI_OUTPUT:.*]] = linalg.elemwise_binary
+  %5 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>, "consumer"} ins(%4, %2 : tensor<64xf32>, tensor<64xf32>) outs(%3 : tensor<64xf32>) -> tensor<64xf32>
+  // CHECK: return %[[RES]]#1
+  return %5 : tensor<64xf32>
+}
+// -----
+
+#map = affine_map<(d0) -> (d0 * -50 + 123, 50)>
+#map1 = affine_map<(d0) -> (d0 * -16 + 789, 16)>
+#map2 = affine_map<(d0) -> (d0 * 50)>
+#map3 = affine_map<(d0) -> (d0 * 16)>
+#map4 = affine_map<(d0, d1) -> (d0, d1)>
+#map5 = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-LABEL: func.func @fuse_consumer_multi_output
+// CHECK-SAME: %[[IN0:[0-9a-z]+]]: tensor<123x456xf32>
+// CHECK-SAME: %[[IN1:[0-9a-z]+]]: tensor<456x789xf32>
+// CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<123x789xf32>
+func.func @fuse_consumer_multi_output(%arg0: tensor<123x456xf32>, %arg1: tensor<456x789xf32>, %arg2: tensor<123x789xf32>) -> (tensor<123x789xf32>, tensor<789x123xf32>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[INIT:.*]] = linalg.fill
+  %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<123x789xf32>) -> tensor<123x789xf32>
+  // CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<123x789xf32>
+  %1 = tensor.empty() : tensor<123x789xf32>
+  // CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<789x123xf32>
+  %2 = tensor.empty() : tensor<789x123xf32>
+  // CHECK: %[[RES:[0-9a-z]+]]:3 = scf.forall {{.*}} shared_outs(%[[LOOP_ARG0:.*]] = %[[INIT]], %[[LOOP_ARG1:.*]] = %[[EMPTY0]], %[[LOOP_ARG2:.*]] = %[[EMPTY1]]
+  %3 = scf.forall (%arg3, %arg4) in (3, 50) shared_outs(%arg5 = %0) -> (tensor<123x789xf32>) {
+    %5 = affine.min #map(%arg3)
+    %6 = affine.min #map1(%arg4)
+    %7 = affine.apply #map2(%arg3)
+    %8 = affine.apply #map3(%arg4)
+    %9 = affine.apply #map2(%arg3)
+    %10 = affine.apply #map3(%arg4)
+    // CHECK: %[[EXTRACT_IN0:.*]] = tensor.extract_slice %[[IN0]]
+    %extracted_slice = tensor.extract_slice %arg0[%7, 0] [%5, 456] [1, 1] : tensor<123x456xf32> to tensor<?x456xf32>
+    // CHECK: %[[EXTRACT_IN1:.*]] = tensor.extract_slice %[[IN1]]
+    %extracted_slice_0 = tensor.extract_slice %arg1[0, %8] [456, %6] [1, 1] : tensor<456x789xf32> to tensor<456x?xf32>
+    // CHECK: %[[EXTRACT_OUT:.*]] = tensor.extract_slice %[[LOOP_ARG0]]
+    %extracted_slice_1 = tensor.extract_slice %arg5[%9, %10] [%5, %6] [1, 1] : tensor<123x789xf32> to tensor<?x?xf32>
+    // CHECK: %[[MATMUL_RES:.*]] = linalg.matmul ins(%[[EXTRACT_IN0]], %[[EXTRACT_IN1]] {{.*}} outs(%[[EXTRACT_OUT]]
+    %11 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<?x456xf32>, tensor<456x?xf32>) outs(%extracted_slice_1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+    // CHECK: %[[EXTRACT_EMPTY0:.*]] = tensor.extract_slice %[[LOOP_ARG1]]
+    // CHECK: %[[EXTRACT_EMPTY1:.*]] = tensor.extract_slice %[[LOOP_ARG2]]
+    // CHECK: %[[GENERIC_RES:.*]]:2 = linalg.generic {{.*}} ins(%[[MATMUL_RES]] : tensor<?x?xf32>) outs(%[[EXTRACT_EMPTY0]], %[[EXTRACT_EMPTY1]]
+
+    %12 = affine.apply #map2(%arg3)
+    %13 = affine.apply #map3(%arg4)
+    scf.forall.in_parallel {
+      // CHECK: tensor.parallel_insert_slice %[[GENERIC_RES]]#0 into %[...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/89893

