[Mlir-commits] [mlir] [mlir][linalg] do not break outs from block argument (PR #73572)

Okwan Kwon llvmlistbot at llvm.org
Mon Nov 27 13:57:54 PST 2023


https://github.com/okkwon created https://github.com/llvm/llvm-project/pull/73572

When a block argument is used as an `outs` operand, do not break the dependency on it. Otherwise, the argument becomes unused and may be optimized away.

>From c70cb6218fa2bf8fcd04352eebe81d651c4cca80 Mon Sep 17 00:00:00 2001
From: Okwan Kwon <okkwon at gmail.com>
Date: Mon, 27 Nov 2023 13:53:07 -0800
Subject: [PATCH] [mlir][linalg] do not break outs from block argument

When a block argument is used as an `outs` operand, do not break the
dependency on it. Otherwise, the argument becomes unused and may be
optimized away.
---
 .../Linalg/Transforms/ElementwiseOpFusion.cpp      |  5 +++++
 .../Dialect/Linalg/fusion-elementwise-ops.mlir     | 14 ++++++--------
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f0393e44fc00c27..0d33cc99ae55e48 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1818,6 +1818,11 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
         if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
           continue;
 
+        // If outs is wired from a block argument, keep the dependency to
+        // prevent the argument from being optimized away.
+        if (isa<BlockArgument>(operandVal))
+          continue;
+
         // If outs is already an `empty` operation, nothing to do.
         auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
         if (definingOp)
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 9d8421cbab49d8e..b0038966e074798 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -730,11 +730,8 @@ func.func @break_outs_dependency(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
 // CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xf32>)
 //  CHECK-DAG:   %[[C0:.+]] = arith.constant 0 : index
 //  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
-//  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-//  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-//  CHECK-DAG:   %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]])
 //      CHECK:   %[[GENERIC1:.+]] = linalg.generic
-// CHECK-SAME:     outs(%[[INIT]] : tensor<?x?xf32>)
+// CHECK-SAME:     outs(%[[ARG0]] : tensor<?x?xf32>)
 //  CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[GENERIC1]], %[[C0]]
 //  CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[GENERIC1]], %[[C1]]
 //  CHECK-DAG:   %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]])
@@ -976,11 +973,10 @@ func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi
 //      CHECK: func @fusion_different_axes(
 // CHECK-SAME:     %[[ARG0:.+]]: tensor<5000xi64>
 // CHECK-SAME:     %[[ARG1:.+]]: tensor<5000xi32>
-//  CHECK-DAG:   %[[INIT0:.+]] = tensor.empty() : tensor<5000xi64>
 //  CHECK-DAG:   %[[INIT1:.+]] = tensor.empty() : tensor<5000xi32>
 //      CHECK:   %[[RESULT:.+]]:2 = linalg.generic
 // CHECK-SAME:       indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME:       outs(%[[INIT0]], %[[INIT1]] :
+// CHECK-SAME:       outs(%[[ARG0]], %[[INIT1]] :
 // CHECK-NEXT:   ^bb0(
 // CHECK-SAME:       %[[B0:.+]]: i64
 // CHECK-SAME:       %[[B1:.+]]: i32
@@ -1097,10 +1093,12 @@ module {
 // CHECK-LABEL: func.func @fuse_multi_result_producer
 //  CHECK-SAME:     %[[ARG0:[a-zA-Z0-9_]+]]: tensor<f32>
 //  CHECK-SAME:     %[[ARG1:[a-zA-Z0-9_]+]]: tensor<f32>
-//       CHECK:   %[[INIT:.+]] = tensor.empty
+//  CHECK-SAME:     %[[ARG2:[a-zA-Z0-9_]+]]: tensor<f32>
+//  CHECK-SAME:     %[[ARG3:[a-zA-Z0-9_]+]]: tensor<f32>
+//  CHECK-SAME:     %[[ARG4:[a-zA-Z0-9_]+]]: tensor<f32>
 //       CHECK:   %[[GENERIC:.+]] = linalg.generic
 //  CHECK-SAME:       ins(%[[ARG0]], %[[ARG1]] :
-//  CHECK-SAME:       outs(%[[INIT]] :
+//  CHECK-SAME:       outs(%[[ARG4]] :
 //  CHECK-NEXT:     ^bb0
 //  CHECK-SAME:         %[[B0:[a-zA-Z0-9_]+]]: f32
 //  CHECK-SAME:         %[[B1:[a-zA-Z0-9_]+]]: f32



More information about the Mlir-commits mailing list