[Mlir-commits] [mlir] 53f7fb0 - [mlir][linalg] Do not fuse shape-only producers.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 24 03:27:23 PDT 2022
Author: gysit
Date: 2022-03-24T10:22:41Z
New Revision: 53f7fb0a8703ffcc5ccdb0d8133ca0df494ef68f
URL: https://github.com/llvm/llvm-project/commit/53f7fb0a8703ffcc5ccdb0d8133ca0df494ef68f
DIFF: https://github.com/llvm/llvm-project/commit/53f7fb0a8703ffcc5ccdb0d8133ca0df494ef68f.diff
LOG: [mlir][linalg] Do not fuse shape-only producers.
This revision introduces a heuristic that stops fusion for shape-only tensors. A shape-only tensor only defines the shape of the consumer computation; its data is not used. Pure producer-consumer fusion therefore should not fuse the producer of a shape-only tensor, in particular because the shape-only tensor will have other uses that actually consume the data.
The revision enables fusion for consumers that use the same tensor twice: once as an input operand and once as a shape-only output operand. In these cases, we want to fuse only the input operand and avoid output fusion via the iteration argument.
Reviewed By: hanchung
Differential Revision: https://reviews.llvm.org/D120981
Added:
Modified:
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 1d46657018b39..3ce6570f4bf53 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -349,6 +349,12 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
consumerOp->getBlock() != rootOp->getBlock())
return failure();
+ // Check `consumerOpOperand` is not shape-only to avoid fusion if the data is
+ // not used by the `consumerOp` computation.
+ BlockArgument bbArg = consumerOp.getTiedBlockArgument(consumerOpOperand);
+ if (bbArg.getUses().empty())
+ return failure();
+
// Check if the producer is a LinalgOp possibly passed by iteration argument.
OpOperand *iterArg = nullptr;
auto producerResult = sliceOp.source().dyn_cast<OpResult>();
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir
index 85c6cca7e366b..7509fdb866e43 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir
@@ -1,19 +1,40 @@
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul fuse tile-sizes=0,0,0 run-enable-pass=false" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul fuse tile-sizes=0,0,0 run-enable-pass=false" -split-input-file | FileCheck --check-prefix=MATMUL %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.elemwise_unary fuse tile-sizes=32,32,0 run-enable-pass=false" -split-input-file | FileCheck --check-prefix=UNARY %s
-func.func @no_fuse_gemm(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+// MATMUL-LABEL: @tile_sizes_zero(
+func.func @tile_sizes_zero(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%cst = arith.constant 0.0 : f32
%d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
%init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+
+ // MATMUL-NOT: scf.for
+ // MATMUL: linalg.fill
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+ // MATMUL-NOT: scf.for
+ // MATMUL: linalg.matmul
%result = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
- return %result : tensor<?x?xf32>
+ func.return %result : tensor<?x?xf32>
+}
+
+// -----
+
+// UNARY-LABEL: @shape_only(
+func.func @shape_only(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %cst = arith.constant 0.0 : f32
+
+ // UNARY: linalg.fill
+ %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+ // UNARY: scf.for
+ // UNARY: scf.for
+ // UNARY-NOT: linalg.fill
+ // UNARY: linalg.elemwise_unary
+ %1 = linalg.elemwise_unary {fun = #linalg.unary_fn<exp>}
+ ins(%arg0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+ func.return %1 : tensor<?x?xf32>
}
-// CHECK-LABEL: @no_fuse_gemm(
-// CHECK-NOT: scf.for
-// CHECK: linalg.fill
-// CHECK-NOT: scf.for
-// CHECK: linalg.matmul
More information about the Mlir-commits
mailing list