[Mlir-commits] [mlir] [mlir][linalg] Enable fuse consumer (PR #89893)
donald chen
llvmlistbot at llvm.org
Wed Apr 24 01:50:11 PDT 2024
https://github.com/cxy-1993 updated https://github.com/llvm/llvm-project/pull/89893
>From cb24711f662f0a657f77e58630fd469519d3c495 Mon Sep 17 00:00:00 2001
From: cxy <chenxunyu1993 at gmail.com>
Date: Tue, 23 Apr 2024 12:28:19 +0000
Subject: [PATCH] [mlir][linalg] Enable fuse consumer
This patch adds support for consumer fusion to the tiling interface, and
exercises it through a new test pass (-test-linalg-fuse-consumer) that fuses
a consumer into an scf.forall containing op.
- Add interface method 'getIterationDomainTileFromOperandTile' to the
tiling interface, which maps a tile of an operand to the corresponding
iteration domain tile.
- Add interface method 'getTiledImplementationFromOperandTile' to the
tiling interface, which generates the tiled implementation that consumes
a given operand tile.
- Implement the above two methods for Linalg ops and use them to fuse a
consumer into an scf.forall in the test pass.
---
.../mlir/Interfaces/TilingInterface.td | 67 +++-
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 4 +-
.../Linalg/Transforms/TilingInterfaceImpl.cpp | 106 +++++--
.../Tensor/IR/TensorTilingInterfaceImpl.cpp | 12 +-
.../Dialect/Linalg/test-fuse-consumer.mlir | 103 ++++++
mlir/test/lib/Dialect/Linalg/CMakeLists.txt | 1 +
.../Dialect/Linalg/TestLinalgFuseConsumer.cpp | 294 ++++++++++++++++++
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
8 files changed, 549 insertions(+), 40 deletions(-)
create mode 100644 mlir/test/Dialect/Linalg/test-fuse-consumer.mlir
create mode 100644 mlir/test/lib/Dialect/Linalg/TestLinalgFuseConsumer.cpp
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..84f7dec2f4003d 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
The method returns the operation that is the tiled
implementation.
}],
- /*retType=*/"FailureOr<TilingResult>",
+ /*retType=*/"FailureOr<::mlir::TilingResult>",
/*methodName=*/"getTiledImplementation",
/*args=*/(ins
"OpBuilder &":$b,
@@ -82,15 +82,34 @@ def TilingInterface : OpInterface<"TilingInterface"> {
by the tiled implementation. Expects the same `offsets` and `sizes` as
used to obtain the tiled implementation of the operation.
}],
- /*retType=*/"LogicalResult",
+ /*retType=*/"::mlir::LogicalResult",
/*methodName=*/"getResultTilePosition",
/*args=*/(ins
"OpBuilder &":$b,
"unsigned":$resultNumber,
"ArrayRef<OpFoldResult> ":$offsets,
"ArrayRef<OpFoldResult> ":$sizes,
- "SmallVector<OpFoldResult> &":$resultOffsets,
- "SmallVector<OpFoldResult> &":$resultSizes),
+ "SmallVectorImpl<OpFoldResult> &":$resultOffsets,
+ "SmallVectorImpl<OpFoldResult> &":$resultSizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to return the iteration domain tile that corresponds to the
+ given tile (`offsets`, `sizes`) of operand `operandNumber`.
+ }],
+ /*retType=*/"::mlir::LogicalResult",
+ /*methodName=*/"getIterationDomainTileFromOperandTile",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$operandNumber,
+ "ArrayRef<OpFoldResult> ":$offsets,
+ "ArrayRef<OpFoldResult> ":$sizes,
+ "SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
+ "SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
@@ -119,7 +138,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
iteration space).
- `sizes` provides the size of the tile.
}],
- /*retType=*/"FailureOr<TilingResult>",
+ /*retType=*/"FailureOr<::mlir::TilingResult>",
/*methodName=*/"generateResultTileValue",
/*args=*/(ins
"OpBuilder &":$b,
@@ -131,6 +150,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
+    InterfaceMethod<
+      /*desc=*/[{
+        Method to generate the tiled implementation of an operation from an
+        operand tile.
+
+        Generates the IR that computes the tiled implementation of the
+        operation which consumes the given tile of operand `operandNumber`.
+        The `offsets` and `sizes` describe that operand tile. This is
+        different from `getTiledImplementation`, which generates the tiled
+        implementation of the operation given a tile of the iteration
+        space. This method enables consumer fusion: the operation can be
+        tiled and fused into the loop that produces the operand tile. The
+        method returns failure if the operation can't be tiled to consume
+        the operand tile. In practical terms this implies it cannot be
+        fused as a consumer of the operation producing that operand.
+
+        - `operandNumber` is the index of the operand whose tile is given.
+        - `offsets` provides the offset of the tile in the coordinate system
+          of the operand (not of the iteration space).
+        - `sizes` provides the size of the tile.
+      }],
+      /*retType=*/"FailureOr<::mlir::TilingResult>",
+      /*methodName=*/"getTiledImplementationFromOperandTile",
+      /*args=*/(ins
+        "OpBuilder &":$b,
+        "unsigned":$operandNumber,
+        "ArrayRef<OpFoldResult>":$offsets,
+        "ArrayRef<OpFoldResult>":$sizes),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return failure();
+      }]
+    >,
InterfaceMethod<
/*desc=*/[{
Generates the scalar implementation of the operation.
@@ -142,7 +197,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
transformations are done, this method can be used to lower to scalar
code that can then be lowered to LLVM or SPIR-V dialects.
}],
- /*retType=*/"LogicalResult",
+ /*retType=*/"::mlir::LogicalResult",
/*methodName=*/"generateScalarImplementation",
/*args=*/(ins
"OpBuilder &":$b,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9c5c58fa1fabfb..e9999c34d0face 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2425,8 +2425,8 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
LogicalResult SoftmaxOp::getResultTilePosition(
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes) {
+ ArrayRef<OpFoldResult> sizes, SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) {
if (resultNumber == 0) {
resultOffsets.assign(offsets.begin(), offsets.end());
resultSizes.assign(sizes.begin(), sizes.end());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5d..71e9c3771dcded 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
}));
}
- // Instantiate the tiled implementation of the operation.
+ /// Instantiate the tiled implementation of the operation.
FailureOr<TilingResult>
getTiledImplementation(Operation *op, OpBuilder &b,
ArrayRef<OpFoldResult> offsets,
@@ -132,14 +132,66 @@ struct LinalgOpTilingInterface
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
- // Return the details of the output tile generated by the tiled
- // implementation.
+  /// Map a tile (`offsets`, `sizes`) of a value accessed through `indexingMap`
+  /// onto the iteration domain of `linalgOp`.
+  ///
+  /// Every result of `indexingMap` must be an AffineDimExpr (callers check
+  /// that the map is a projected permutation); `offsets[i]` / `sizes[i]` are
+  /// assigned to the loop dimension used by map result `i`. When the map is
+  /// not a full permutation, loop dimensions it does not reference keep the
+  /// full extent of the iteration domain. On return `mappedOffsets` and
+  /// `mappedSizes` hold one entry per loop.
+  void
+  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
+                         ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes,
+                         SmallVectorImpl<OpFoldResult> &mappedOffsets,
+                         SmallVectorImpl<OpFoldResult> &mappedSizes) const {
+    assert(offsets.size() == indexingMap.getNumResults() &&
+           sizes.size() == indexingMap.getNumResults() &&
+           "expected one offset and one size per indexing map result");
+    unsigned numLoops = linalgOp.getNumLoops();
+    mappedOffsets.resize(numLoops);
+    mappedSizes.resize(numLoops);
+    // A permutation references every loop, so the full-extent defaults (and
+    // the iteration domain query) are only needed for other maps.
+    if (!indexingMap.isPermutation()) {
+      auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+      SmallVector<Range> iterationDomain =
+          tilingInterfaceOp.getIterationDomain(b);
+      for (auto [index, range] : llvm::enumerate(iterationDomain)) {
+        mappedOffsets[index] = range.offset;
+        mappedSizes[index] = range.size;
+      }
+    }
+    for (auto [index, expr] : llvm::enumerate(indexingMap.getResults())) {
+      unsigned dimPosition = cast<AffineDimExpr>(expr).getPosition();
+      mappedOffsets[dimPosition] = offsets[index];
+      mappedSizes[dimPosition] = sizes[index];
+    }
+  }
+
+  /// Map the tile (`offsets`, `sizes`) of operand `operandNumber` to the
+  /// iteration-domain tile (`iterDomainOffsets`, `iterDomainSizes`) that
+  /// accesses it. Loop dimensions not used by the operand's indexing map are
+  /// given the full extent of the iteration domain. Fails (with a diagnostic)
+  /// when the operand's indexing map is not a projected permutation.
+  LogicalResult getIterationDomainTileFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Check that the indexing map used for the operand is a projected
+    // permutation. This could be relaxed with a more general approach that can
+    // map the offsets and sizes from the operand to iteration space tiles
+    // (filling in full extent for dimensions not used to access the operand).
+    AffineMap indexingMap =
+        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+    if (!indexingMap.isProjectedPermutation()) {
+      return emitError(op->getLoc(),
+                       "unhandled get iter domain position when operand is not "
+                       "accessed using a permuted projection");
+    }
+
+    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+                           iterDomainOffsets, iterDomainSizes);
+    return success();
+  }
+
+ /// Return the details of the output tile generated by the tiled
+ /// implementation.
LogicalResult
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
- SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes) const {
+ SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
Location loc = op->getLoc();
LinalgOp linalgOp = cast<LinalgOp>(op);
@@ -160,6 +212,21 @@ struct LinalgOpTilingInterface
return success();
}
+  /// Generate the tiled implementation of `op` that consumes the tile
+  /// (`offsets`, `sizes`) of operand `operandNumber`. The operand tile is first
+  /// mapped to an iteration-domain tile, which is then passed to
+  /// `getTiledImplementation`.
+  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    // getIterationDomainTileFromOperandTile already emits a diagnostic that
+    // says why the mapping failed; a second, vaguer error would only be noise.
+    if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
+            b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes)))
+      return failure();
+    return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+                                                    mappedSizes);
+  }
+
FailureOr<TilingResult>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
"unhandled tiled implementation generation when result is not "
"accessed using a permuted projection");
}
-
- auto numLoops = linalgOp.getNumLoops();
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+ mappedOffsets, mappedSizes);
auto tilingInterfaceOp = cast<TilingInterface>(op);
- SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
- iterationTileSizes(numLoops);
- if (!indexingMap.isPermutation()) {
- SmallVector<Range> iterationDomain =
- tilingInterfaceOp.getIterationDomain(b);
- for (const auto &range : llvm::enumerate(iterationDomain)) {
- iterationTileOffsets[range.index()] = range.value().offset;
- iterationTileSizes[range.index()] = range.value().size;
- }
- }
- for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
- unsigned dimPosition =
- cast<AffineDimExpr>(resultExpr.value()).getPosition();
- iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
- iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
- }
-
FailureOr<TilingResult> tilingResult =
- tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
- iterationTileSizes);
+ tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+ if (failed(tilingResult))
+ return failure();
+
if (tilingResult->tiledOps.size() != 1)
return op->emitOpError("failed to generate tiled implementation");
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index d25efcf50ec566..296c5fc7a5c2bd 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -61,8 +61,8 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
- SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes) const {
+ SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
resultOffsets.assign(offsets.begin(), offsets.end());
resultSizes.assign(sizes.begin(), sizes.end());
return success();
@@ -199,8 +199,8 @@ struct PackOpTiling
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
- SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes) const {
+ SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
// The iteration domain is over outer dimensions of packed layout. In this
// context, the outer dimensions of `resultOffsets` are `offsets`. The
// inner dimensions of `resultOffsets` are zeros because tiling is not
@@ -452,8 +452,8 @@ struct UnPackOpTiling
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
- SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes) const {
+ SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
resultOffsets = llvm::to_vector(offsets);
resultSizes = llvm::to_vector(sizes);
return success();
diff --git a/mlir/test/Dialect/Linalg/test-fuse-consumer.mlir b/mlir/test/Dialect/Linalg/test-fuse-consumer.mlir
new file mode 100644
index 00000000000000..e7edbf0b2c25d4
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/test-fuse-consumer.mlir
@@ -0,0 +1,103 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-fuse-consumer | FileCheck %s
+
+// Fuses the elementwise consumer tagged "consumer" into the scf.forall tagged
+// "containing": the loop gains a shared_out for the consumer's init, and the
+// tiled consumer result is inserted into it next to the producer's tile.
+#map = affine_map<()[s0] -> (64 ceildiv s0)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>
+// CHECK-LABEL: func.func @fuse_tileable_consumer
+// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+// CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<64xf32>
+// CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<64xf32>
+func.func @fuse_tileable_consumer(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+ // CHECK: %[[SLICE:.*]] = tensor.empty(%[[CHUNK_SIZE]]) : tensor<?xf32>
+ %0 = tensor.empty(%arg0) : tensor<?xf32>
+ %1 = affine.apply #map()[%arg0]
+ // CHECK: %[[EMPTY0:[0-9a-z]+]] = tensor.empty() : tensor<64xf32>
+ %2 = tensor.empty() : tensor<64xf32>
+ // CHECK: %[[EMPTY1:[0-9a-z]+]] = tensor.empty() : tensor<64xf32>
+ %3 = tensor.empty() : tensor<64xf32>
+ // CHECK: %[[RES:[0-9a-z]+]]:2 = scf.forall {{.*}} shared_outs(%[[LOOP_ARG0:.*]] = %[[OUT]], %[[LOOP_ARG1:.*]] = %[[EMPTY1]]
+ %4 = scf.forall (%arg3) in (%1) shared_outs(%arg4 = %arg2) -> (tensor<64xf32>) {
+ %6 = affine.apply #map1(%arg3)[%arg0]
+ %7 = affine.min #map2(%arg3)[%arg0]
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[LOOP_ARG0]][%{{.*}}] [%{{.*}}] [{{.*}}]
+ %extracted_slice = tensor.extract_slice %arg4[%6] [%7] [1] : tensor<64xf32> to tensor<?xf32>
+ // CHECK: %[[T1:[0-9a-z]+]] = linalg.elemwise_unary
+ %8 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%extracted_slice : tensor<?xf32>) -> tensor<?xf32>
+
+ // CHECK: %[[T2:.*]] = tensor.extract_slice %[[EMPTY0]][%{{.*}}] [%{{.*}}] [{{.*}}]
+ // CHECK: %[[T3:.*]] = tensor.extract_slice %[[LOOP_ARG1]][%{{.*}}] [%{{.*}}] [{{.*}}]
+ // CHECK: %[[T4:.*]] = linalg.elemwise_binary {{.*}} ins(%[[T1]], %[[T2]] : {{.*}} outs(%[[T3]]
+
+ scf.forall.in_parallel {
+ // CHECK: tensor.parallel_insert_slice %[[T4]] into %[[LOOP_ARG1]]
+ // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[LOOP_ARG0]]
+ tensor.parallel_insert_slice %8 into %arg4[%6] [%7] [1] : tensor<?xf32> into tensor<64xf32>
+ }
+ } {"containing"}
+ // CHECK: %[[ORI_OUTPUT:.*]] = linalg.elemwise_binary
+ %5 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>, "consumer"} ins(%4, %2 : tensor<64xf32>, tensor<64xf32>) outs(%3 : tensor<64xf32>) -> tensor<64xf32>
+ // CHECK: return %[[RES]]#1
+ return %5 : tensor<64xf32>
+}
+// -----
+
+// Fuses a two-result linalg.generic consumer (its second init is accessed
+// through a transposed map) into the scf.forall that tiles a matmul: the loop
+// gains one shared_out per consumer init.
+#map = affine_map<(d0) -> (d0 * -50 + 123, 50)>
+#map1 = affine_map<(d0) -> (d0 * -16 + 789, 16)>
+#map2 = affine_map<(d0) -> (d0 * 50)>
+#map3 = affine_map<(d0) -> (d0 * 16)>
+#map4 = affine_map<(d0, d1) -> (d0, d1)>
+#map5 = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-LABEL: func.func @fuse_consumer_multi_output
+// CHECK-SAME: %[[IN0:[0-9a-z]+]]: tensor<123x456xf32>
+// CHECK-SAME: %[[IN1:[0-9a-z]+]]: tensor<456x789xf32>
+// CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<123x789xf32>
+func.func @fuse_consumer_multi_output(%arg0: tensor<123x456xf32>, %arg1: tensor<456x789xf32>, %arg2: tensor<123x789xf32>) -> (tensor<123x789xf32>, tensor<789x123xf32>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[INIT:.*]] = linalg.fill
+ %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<123x789xf32>) -> tensor<123x789xf32>
+ // CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<123x789xf32>
+ %1 = tensor.empty() : tensor<123x789xf32>
+ // CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<789x123xf32>
+ %2 = tensor.empty() : tensor<789x123xf32>
+ // CHECK: %[[RES:[0-9a-z]+]]:3 = scf.forall {{.*}} shared_outs(%[[LOOP_ARG0:.*]] = %[[INIT]], %[[LOOP_ARG1:.*]] = %[[EMPTY0]], %[[LOOP_ARG2:.*]] = %[[EMPTY1]]
+ %3 = scf.forall (%arg3, %arg4) in (3, 50) shared_outs(%arg5 = %0) -> (tensor<123x789xf32>) {
+ %5 = affine.min #map(%arg3)
+ %6 = affine.min #map1(%arg4)
+ %7 = affine.apply #map2(%arg3)
+ %8 = affine.apply #map3(%arg4)
+ %9 = affine.apply #map2(%arg3)
+ %10 = affine.apply #map3(%arg4)
+ // CHECK: %[[EXTRACT_IN0:.*]] = tensor.extract_slice %[[IN0]]
+ %extracted_slice = tensor.extract_slice %arg0[%7, 0] [%5, 456] [1, 1] : tensor<123x456xf32> to tensor<?x456xf32>
+ // CHECK: %[[EXTRACT_IN1:.*]] = tensor.extract_slice %[[IN1]]
+ %extracted_slice_0 = tensor.extract_slice %arg1[0, %8] [456, %6] [1, 1] : tensor<456x789xf32> to tensor<456x?xf32>
+ // CHECK: %[[EXTRACT_OUT:.*]] = tensor.extract_slice %[[LOOP_ARG0]]
+ %extracted_slice_1 = tensor.extract_slice %arg5[%9, %10] [%5, %6] [1, 1] : tensor<123x789xf32> to tensor<?x?xf32>
+ // CHECK: %[[MATMUL_RES:.*]] = linalg.matmul ins(%[[EXTRACT_IN0]], %[[EXTRACT_IN1]] {{.*}} outs(%[[EXTRACT_OUT]]
+ %11 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<?x456xf32>, tensor<456x?xf32>) outs(%extracted_slice_1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+ // CHECK: %[[EXTRACT_EMPTY0:.*]] = tensor.extract_slice %[[LOOP_ARG1]]
+ // CHECK: %[[EXTRACT_EMPTY1:.*]] = tensor.extract_slice %[[LOOP_ARG2]]
+ // CHECK: %[[GENERIC_RES:.*]]:2 = linalg.generic {{.*}} ins(%[[MATMUL_RES]] : tensor<?x?xf32>) outs(%[[EXTRACT_EMPTY0]], %[[EXTRACT_EMPTY1]]
+
+ %12 = affine.apply #map2(%arg3)
+ %13 = affine.apply #map3(%arg4)
+ scf.forall.in_parallel {
+ // CHECK: tensor.parallel_insert_slice %[[GENERIC_RES]]#0 into %[[LOOP_ARG1]]
+ // CHECK: tensor.parallel_insert_slice %[[GENERIC_RES]]#1 into %[[LOOP_ARG2]]
+ // CHECK: tensor.parallel_insert_slice %[[MATMUL_RES]] into %[[LOOP_ARG0]]
+ tensor.parallel_insert_slice %11 into %arg5[%12, %13] [%5, %6] [1, 1] : tensor<?x?xf32> into tensor<123x789xf32>
+ }
+ } {"containing"}
+ // CHECK: %[[ORI_OUTPUT:.*]]:2 = linalg.generic
+ %4:2 = linalg.generic {"consumer", indexing_maps = [#map4, #map4, #map5], iterator_types = ["parallel", "parallel"]} ins(%3 : tensor<123x789xf32>) outs(%1, %2 : tensor<123x789xf32>, tensor<789x123xf32>) {
+ ^bb0(%in: f32, %out: f32, %out_0: f32):
+ %5 = arith.addf %in, %out : f32
+ %6 = arith.addf %5, %out_0 : f32
+ linalg.yield %5, %6 : f32, f32
+ } -> (tensor<123x789xf32>, tensor<789x123xf32>)
+ // CHECK: return %[[RES]]#1, %[[RES]]#2
+ return %4#0, %4#1 : tensor<123x789xf32>, tensor<789x123xf32>
+}
+
+
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index b28f2b3564662a..23479912a865cd 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_library(MLIRLinalgTestPasses
TestLinalgFusionTransforms.cpp
TestLinalgTransforms.cpp
TestPadFusion.cpp
+ TestLinalgFuseConsumer.cpp
EXCLUDE_FROM_LIBMLIR
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgFuseConsumer.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgFuseConsumer.cpp
new file mode 100644
index 00000000000000..ae52684b204483
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgFuseConsumer.cpp
@@ -0,0 +1,294 @@
+//===- TestLinalgFuseConsumer.cpp - Test Linalg fuse consumer ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a pass for testing fuse consumer of Linalg ops.
+// This is a temporary pass used to verify the correctness of the tiling
+// interface in linalg op and the related interface of fuse consumer. It should
+// be replaced with that implementation when the corresponding fusion transform
+// op is completed.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Diagnostics.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/IR/Value.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+
+#define DEBUG_TYPE "fuse-consumer"
+
+namespace {
+/// Test pass that fuses a single consumer op (tagged with the "consumer"
+/// attribute) into a single `scf.forall` (tagged with the "containing"
+/// attribute) using the TilingInterface methods
+/// `getIterationDomainTileFromOperandTile`, `getResultTilePosition` and
+/// `getTiledImplementationFromOperandTile`. The forall is rebuilt with one
+/// extra shared_out per consumer init; uses of the consumer results are
+/// redirected to the new loop results.
+struct TestLinalgFuseConsumer
+    : public PassWrapper<TestLinalgFuseConsumer, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestLinalgFuseConsumer)
+
+  TestLinalgFuseConsumer() = default;
+  TestLinalgFuseConsumer(const TestLinalgFuseConsumer &pass)
+      : PassWrapper(pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<scf::SCFDialect, linalg::LinalgDialect,
+                    tensor::TensorDialect>();
+  }
+  StringRef getArgument() const final { return "test-linalg-fuse-consumer"; }
+  StringRef getDescription() const final {
+    return "Test Linalg fuse consumer interface";
+  }
+
+  void runOnOperation() override {
+    // Locate the unique consumer and containing ops via their marker
+    // attributes; a second op carrying either marker aborts the walk.
+    Operation *consumerOp = nullptr, *containingOp = nullptr;
+    auto walkRes = getOperation()->walk([&](Operation *op) {
+      if (op->hasAttr("consumer")) {
+        if (consumerOp)
+          return WalkResult::interrupt();
+        consumerOp = op;
+      }
+      if (op->hasAttr("containing")) {
+        if (containingOp)
+          return WalkResult::interrupt();
+        containingOp = op;
+      }
+      return WalkResult::advance();
+    });
+
+    if (!consumerOp || !containingOp || walkRes.wasInterrupted()) {
+      emitError(getOperation()->getLoc())
+          << "expect 1 consumer and 1 containing op.";
+      return signalPassFailure();
+    }
+
+    // The consumer must implement the tiling interface.
+    auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+    if (!tileableConsumer) {
+      emitError(consumerOp->getLoc())
+          << "consumer is not a TileableInterface: " << *consumerOp;
+      return signalPassFailure();
+    }
+
+    // The containing op must be an scf.forall.
+    auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+    if (!forallOp) {
+      emitError(containingOp->getLoc())
+          << "containing op is not a scf.forall: " << *containingOp;
+      return signalPassFailure();
+    }
+
+    // Every consumer operand must either be a result of the containing op or
+    // properly dominate it, so it is still available once the consumer is
+    // moved into the loop.
+    DominanceInfo domInfo(getOperation());
+    if (llvm::any_of(consumerOp->getOperands(), [&](Value v) {
+          return v.getDefiningOp() != containingOp &&
+                 !domInfo.properlyDominates(v, containingOp);
+        })) {
+      emitError(consumerOp->getLoc())
+          << "consumer's operand can't dominate containing op";
+      return signalPassFailure();
+    }
+
+    // Collect the consumer operands that read a result of the containing op;
+    // they must all read the same result (the "bridge" value).
+    Value bridge;
+    SmallVector<unsigned> operandNums;
+    for (auto [idx, opd] : llvm::enumerate(consumerOp->getOperands())) {
+      if (opd.getDefiningOp() != containingOp)
+        continue;
+      if (bridge && bridge != opd) {
+        emitError(consumerOp->getLoc())
+            << "consumer's operand use more than one containingOp's result";
+        return signalPassFailure();
+      }
+      bridge = opd;
+      operandNums.push_back(idx);
+    }
+
+    // Without a bridge there is nothing to fuse; bail out before `bridge` and
+    // `operandNums.front()` are dereferenced below.
+    if (!bridge) {
+      emitError(consumerOp->getLoc())
+          << "consumer doesn't use any result of containingOp";
+      return signalPassFailure();
+    }
+
+    // The consumer must be destination-style so that its inits can become new
+    // shared_outs of the loop.
+    auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+    if (!dstOp) {
+      emitError(consumerOp->getLoc())
+          << "consumer op should have destination style op interface";
+      return signalPassFailure();
+    }
+
+    // The consumer must not use the scf.forall result as one of its inits.
+    SmallVector<Value> dpsInits = llvm::to_vector(dstOp.getDpsInits());
+    if (llvm::is_contained(dpsInits, bridge)) {
+      emitError(consumerOp->getLoc())
+          << "consumer op take result of scf.forall as init";
+      return signalPassFailure();
+    }
+
+    // The bridge result must be written by exactly one
+    // tensor.parallel_insert_slice; its offsets/sizes describe the operand tile.
+    unsigned bridgeResultIdx = cast<OpResult>(bridge).getResultNumber();
+    auto bridgeBlockArg = forallOp.getRegionOutArgs()[bridgeResultIdx];
+    scf::InParallelOp terminatorOp = forallOp.getTerminator();
+
+    tensor::ParallelInsertSliceOp targetInsertOp;
+    for (Operation &op : terminatorOp.getYieldingOps()) {
+      auto insertOp = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
+      if (!insertOp) {
+        emitError(containingOp->getLoc())
+            << "containingOp's terminator yields an op that is not a "
+               "tensor.parallel_insert_slice";
+        return signalPassFailure();
+      }
+      if (insertOp.getDest() != bridgeBlockArg)
+        continue;
+      if (targetInsertOp) {
+        emitError(containingOp->getLoc())
+            << "containingOp's result inserted multi time";
+        return signalPassFailure();
+      }
+      targetInsertOp = insertOp;
+    }
+
+    if (!targetInsertOp) {
+      emitError(containingOp->getLoc())
+          << "containingOp's result was not inserted";
+      return signalPassFailure();
+    }
+
+    SmallVector<OpFoldResult> offsets = targetInsertOp.getMixedOffsets();
+    SmallVector<OpFoldResult> sizes = targetInsertOp.getMixedSizes();
+    SmallVector<OpFoldResult> strides = targetInsertOp.getMixedStrides();
+
+    // Only unit-stride insertions are supported; dynamic strides are rejected.
+    bool hasNonUnitStride = llvm::any_of(strides, [](OpFoldResult stride) {
+      auto attr = stride.dyn_cast<Attribute>();
+      return !attr || cast<IntegerAttr>(attr).getInt() != 1;
+    });
+    if (hasNonUnitStride) {
+      emitError(containingOp->getLoc())
+          << "containingOp's result yield with stride";
+      return signalPassFailure();
+    }
+
+    // New IR for the tiled consumer is created right before the terminator.
+    IRRewriter rewriter(terminatorOp);
+    Location loc = forallOp.getLoc();
+
+    // Map the operand tile to the consumer's iteration-domain tile.
+    SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+    if (failed(tileableConsumer.getIterationDomainTileFromOperandTile(
+            rewriter, operandNums.front(), offsets, sizes, iterDomainOffsets,
+            iterDomainSizes))) {
+      emitError(consumerOp->getLoc())
+          << "can't get iter domain position from input position";
+      return signalPassFailure();
+    }
+
+    // Compute where each tiled consumer result lands in the full result.
+    SmallVector<std::pair<SmallVector<OpFoldResult>, SmallVector<OpFoldResult>>>
+        resultPositions(consumerOp->getNumResults());
+    for (unsigned idx = 0, e = consumerOp->getNumResults(); idx < e; ++idx) {
+      if (failed(tileableConsumer.getResultTilePosition(
+              rewriter, idx, iterDomainOffsets, iterDomainSizes,
+              resultPositions[idx].first, resultPositions[idx].second))) {
+        emitError(consumerOp->getLoc())
+            << "can't get result domain position from iter domain position";
+        return signalPassFailure();
+      }
+    }
+
+    // All checks passed: create the tiled implementation of the consumer op.
+    FailureOr<TilingResult> tileAndFuseResult =
+        tileableConsumer.getTiledImplementationFromOperandTile(
+            rewriter, operandNums.front(), offsets, sizes);
+    if (failed(tileAndFuseResult)) {
+      emitError(consumerOp->getLoc()) << "get tiled implementation failed";
+      return signalPassFailure();
+    }
+
+    auto &tiledOps = tileAndFuseResult->tiledOps;
+    if (tiledOps.size() != 1) {
+      emitError(consumerOp->getLoc())
+          << "failed to tile consumer op: " << *consumerOp;
+      return signalPassFailure();
+    }
+    Operation *tiledConsumer = tiledOps.front();
+
+    // Make the tiled consumer read the tile produced inside the loop body
+    // instead of the loop result.
+    for (unsigned operandNum : operandNums)
+      tiledConsumer->setOperand(operandNum, targetInsertOp.getSource());
+    // Any other in-loop uses of the loop result (presumably slices created
+    // while tiling the consumer) are redirected to the shared_out init.
+    rewriter.replaceUsesWithIf(bridge, forallOp.getOutputs()[bridgeResultIdx],
+                               [&](OpOperand &use) {
+                                 return forallOp->isProperAncestor(
+                                     use.getOwner());
+                               });
+
+    // Rebuild the scf.forall with the consumer inits appended to its
+    // shared_outs, moving the original body into it.
+    SmallVector<Value> newOuts(forallOp.getOutputs());
+    newOuts.append(dpsInits);
+
+    rewriter.setInsertionPoint(forallOp);
+    auto newForallOp = rewriter.create<scf::ForallOp>(
+        loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+        forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+    rewriter.eraseBlock(newForallOp.getBody());
+    newForallOp.getRegion().takeBody(forallOp.getRegion());
+
+    // Add one block argument per consumer init and route the in-loop uses of
+    // each init through it.
+    Block *newBody = newForallOp.getBody();
+    for (Value init : dpsInits) {
+      BlockArgument initArg =
+          newBody->addArgument(init.getType(), init.getLoc());
+      rewriter.replaceUsesWithIf(init, initArg, [&](OpOperand &use) {
+        return newForallOp->isProperAncestor(use.getOwner());
+      });
+    }
+
+    // Insert every tiled consumer result into its new shared_out, ahead of
+    // the existing yielding ops.
+    scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
+    Operation *firstYieldOp = &*newTerminatorOp.getYieldingOps().begin();
+    rewriter.setInsertionPoint(firstYieldOp);
+    // Block arguments are [induction vars, original shared_outs, consumer
+    // inits]; the consumer results go into the last group.
+    unsigned firstInitArgIdx =
+        forallOp.getRank() + forallOp.getOutputs().size();
+    auto bbArgs = newBody->getArguments();
+    for (auto [idx, tiledResult] :
+         llvm::enumerate(tiledConsumer->getResults())) {
+      auto &[resultOffsets, resultSizes] = resultPositions[idx];
+      SmallVector<OpFoldResult> unitStrides(resultOffsets.size(),
+                                            rewriter.getIndexAttr(1));
+      rewriter.create<tensor::ParallelInsertSliceOp>(
+          firstYieldOp->getLoc(), tiledResult, bbArgs[firstInitArgIdx + idx],
+          resultOffsets, resultSizes, unitStrides);
+    }
+
+    // Redirect uses of the old loop results and of the consumer results to
+    // the new loop. The original consumer is left in place without uses.
+    for (auto [idx, result] : llvm::enumerate(forallOp.getResults()))
+      rewriter.replaceAllUsesWith(result, newForallOp->getResult(idx));
+
+    unsigned numOrigOutputs = forallOp.getOutputs().size();
+    for (auto [idx, result] : llvm::enumerate(consumerOp->getResults()))
+      rewriter.replaceAllUsesWith(
+          result, newForallOp->getResult(numOrigOutputs + idx));
+    forallOp.erase();
+  }
+};
+} // namespace
+
+namespace mlir::test {
+/// Registers the `test-linalg-fuse-consumer` pass with the global pass
+/// registry so it can be requested from the command line.
+void registerTestLinalgFuseConsumer() {
+  PassRegistration<TestLinalgFuseConsumer> registration;
+}
+} // namespace mlir::test
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 237ebeb166dc99..87600c5c161d28 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -106,6 +106,7 @@ void registerTestLinalgDropUnitDims();
void registerTestLinalgElementwiseFusion();
void registerTestLinalgGreedyFusion();
void registerTestLinalgTransforms();
+void registerTestLinalgFuseConsumer();
void registerTestLivenessAnalysisPass();
void registerTestLivenessPass();
void registerTestLoopFusion();
@@ -235,6 +236,7 @@ void registerTestPasses() {
mlir::test::registerTestLinalgElementwiseFusion();
mlir::test::registerTestLinalgGreedyFusion();
mlir::test::registerTestLinalgTransforms();
+ mlir::test::registerTestLinalgFuseConsumer();
mlir::test::registerTestLivenessAnalysisPass();
mlir::test::registerTestLivenessPass();
mlir::test::registerTestLoopFusion();
More information about the Mlir-commits
mailing list