[Mlir-commits] [mlir] [mlir][linalg] do not break outs from block argument (PR #73572)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Nov 27 13:58:11 PST 2023


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-linalg

Author: Okwan Kwon (okkwon)

<details>
<summary>Changes</summary>

When a block argument is used as an output (`outs`) operand, do not break the dependency. Otherwise, the argument will be marked as unused and optimized away.

---
Full diff: https://github.com/llvm/llvm-project/pull/73572.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp (+5) 
- (modified) mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir (+6-8) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f0393e44fc00c27..0d33cc99ae55e48 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1818,6 +1818,11 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
         if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
           continue;
 
+        // If outs is wired from a block argument, keep the dependency to
+        // prevent the argument from being optimized away.
+        if (isa<BlockArgument>(operandVal))
+          continue;
+
         // If outs is already an `empty` operation, nothing to do.
         auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
         if (definingOp)
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 9d8421cbab49d8e..b0038966e074798 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -730,11 +730,8 @@ func.func @break_outs_dependency(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xf32>)
 //  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 //  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-//  CHECK-DAG:   %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]])
 //      CHECK:   %[[GENERIC1:.+]] = linalg.generic
-// CHECK-SAME:     outs(%[[INIT]] : tensor<?x?xf32>)
+// CHECK-SAME:     outs(%[[ARG0]] : tensor<?x?xf32>)
 //  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[GENERIC1]], %[[C0]]
 //  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[GENERIC1]], %[[C1]]
 //  CHECK-DAG:   %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]])
@@ -976,11 +973,10 @@ func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi
 //      CHECK: func @fusion_different_axes(
 // CHECK-SAME:     %[[ARG0:.+]]: tensor<5000xi64>
 // CHECK-SAME:     %[[ARG1:.+]]: tensor<5000xi32>
-//  CHECK-DAG:   %[[INIT0:.+]] = tensor.empty() : tensor<5000xi64>
 //  CHECK-DAG:   %[[INIT1:.+]] = tensor.empty() : tensor<5000xi32>
 //      CHECK:   %[[RESULT:.+]]:2 = linalg.generic
 // CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME:       outs(%[[INIT0]], %[[INIT1]] :
+// CHECK-SAME:       outs(%[[ARG0]], %[[INIT1]] :
 // CHECK-NEXT:   ^bb0(
 // CHECK-SAME:       %[[B0:.+]]: i64
 // CHECK-SAME:       %[[B1:.+]]: i32
@@ -1097,10 +1093,12 @@ module {
 // CHECK-LABEL: func.func @fuse_multi_result_producer
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: tensor<f32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: tensor<f32>
-//       CHECK:   %[[INIT:.+]] = tensor.empty
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: tensor<f32>
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: tensor<f32>
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9_]+]]: tensor<f32>
 //       CHECK:   %[[GENERIC:.+]] = linalg.generic
 //  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
-//  CHECK-SAME:       outs(%[[INIT]] :
+//  CHECK-SAME:       outs(%[[ARG4]] :
 //  CHECK-NEXT:     ^bb0
 //  CHECK-SAME:         %[[B0:[a-zA-Z0-9_]+]]: f32
 //  CHECK-SAME:         %[[B1:[a-zA-Z0-9_]+]]: f32

``````````

</details>


https://github.com/llvm/llvm-project/pull/73572


More information about the Mlir-commits mailing list