[Mlir-commits] [mlir] 53f7fb0 - [mlir][linalg] Do not fuse shape-only producers.

llvmlistbot at llvm.org
Thu Mar 24 03:27:23 PDT 2022


Author: gysit
Date: 2022-03-24T10:22:41Z
New Revision: 53f7fb0a8703ffcc5ccdb0d8133ca0df494ef68f

URL: https://github.com/llvm/llvm-project/commit/53f7fb0a8703ffcc5ccdb0d8133ca0df494ef68f
DIFF: https://github.com/llvm/llvm-project/commit/53f7fb0a8703ffcc5ccdb0d8133ca0df494ef68f.diff

LOG: [mlir][linalg] Do not fuse shape-only producers.

This revision introduces a heuristic to stop fusion for shape-only tensors. A shape-only tensor only defines the shape of the consumer computation; its data is never read. Pure producer-consumer fusion should therefore not fuse the producer of a shape-only tensor, in particular because the shape-only tensor usually has other uses that actually consume the data.
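For illustration, a minimal sketch of a shape-only operand (all op and value names below are hypothetical, not part of this commit): the linalg.fill result feeds the consumer only as the outs operand, and the tied block argument %out is never read, so the fill merely supplies the shape of the result:

  // Hypothetical sketch: %fill is a shape-only operand of the consumer.
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
  %res = linalg.generic {
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                       affine_map<(d0, d1) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<?x?xf32>) outs(%fill : tensor<?x?xf32>) {
    ^bb0(%in: f32, %out: f32):
      // %out has no uses, so %fill only defines the result shape.
      %exp = math.exp %in : f32
      linalg.yield %exp : f32
  } -> tensor<?x?xf32>

With the heuristic above, tile-and-fuse leaves the linalg.fill untouched instead of pulling it into the tile loop nest.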

The revision also enables fusion for consumers that have two uses of the same tensor: one as an input operand and one as a shape-only output operand. In these cases, we want to fuse only the input operand and avoid output fusion via iteration arguments.
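A hedged sketch of that two-use pattern (again with hypothetical names): %0 appears both as the ins operand, whose data is read, and as the outs operand, whose tied block argument is unused, so only the ins use is a candidate for fusion:

  // Hypothetical sketch: %0 is read via ins but shape-only via outs.
  %0 = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
  %1 = linalg.elemwise_unary {fun = #linalg.unary_fn<exp>}
      ins(%0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>

Fusing through the ins operand materializes the filled tile where the exponential reads it, while the outs use stays unfused because its data is never consumed.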

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D120981

Added: 
    

Modified: 
    mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
    mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir

Removed: 
    


################################################################################
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
index 1d46657018b39..3ce6570f4bf53 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -349,6 +349,12 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
       consumerOp->getBlock() != rootOp->getBlock())
     return failure();
 
+  // Check that `consumerOpOperand` is not shape-only, to avoid fusion when the
+  // data is not used by the `consumerOp` computation.
+  BlockArgument bbArg = consumerOp.getTiedBlockArgument(consumerOpOperand);
+  if (bbArg.getUses().empty())
+    return failure();
+
   // Check if the producer is a LinalgOp possibly passed by iteration argument.
   OpOperand *iterArg = nullptr;
   auto producerResult = sliceOp.source().dyn_cast<OpResult>();

diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir
index 85c6cca7e366b..7509fdb866e43 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-no-fuse.mlir
@@ -1,19 +1,40 @@
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul fuse tile-sizes=0,0,0 run-enable-pass=false" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul fuse tile-sizes=0,0,0 run-enable-pass=false" -split-input-file | FileCheck --check-prefix=MATMUL %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.elemwise_unary fuse tile-sizes=32,32,0 run-enable-pass=false" -split-input-file | FileCheck --check-prefix=UNARY %s
 
-func.func @no_fuse_gemm(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+// MATMUL-LABEL: @tile_sizes_zero(
+func.func @tile_sizes_zero(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %cst = arith.constant 0.0 : f32
   %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
   %d1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
   %init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
+
+  //   MATMUL-NOT:   scf.for
+  //       MATMUL:   linalg.fill
   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  //   MATMUL-NOT:   scf.for
+  //       MATMUL:   linalg.matmul
   %result = linalg.matmul ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
       outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
-  return %result : tensor<?x?xf32>
+  func.return %result : tensor<?x?xf32>
+}
+
+// -----
+
+// UNARY-LABEL: @shape_only(
+func.func @shape_only(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %cst = arith.constant 0.0 : f32
+
+  //       UNARY:   linalg.fill
+  %0 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+  //       UNARY:   scf.for
+  //       UNARY:     scf.for
+  //   UNARY-NOT:       linalg.fill
+  //       UNARY:       linalg.elemwise_unary
+  %1 = linalg.elemwise_unary {fun = #linalg.unary_fn<exp>}
+      ins(%arg0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+  func.return %1 : tensor<?x?xf32>
 }
-// CHECK-LABEL: @no_fuse_gemm(
-//   CHECK-NOT:   scf.for
-//       CHECK:   linalg.fill
-//   CHECK-NOT:   scf.for
-//       CHECK:   linalg.matmul

