[Mlir-commits] [mlir] [mlir][linalg] do not break outs from block argument (PR #73572)
Okwan Kwon
llvmlistbot at llvm.org
Mon Nov 27 13:57:54 PST 2023
https://github.com/okkwon created https://github.com/llvm/llvm-project/pull/73572
When a block argument is used as an `outs` operand of a `linalg.generic`, do not break the dependency on it by replacing it with a `tensor.empty`. Otherwise, the argument is left without any uses and gets optimized away.
>From c70cb6218fa2bf8fcd04352eebe81d651c4cca80 Mon Sep 17 00:00:00 2001
From: Okwan Kwon <okkwon at gmail.com>
Date: Mon, 27 Nov 2023 13:53:07 -0800
Subject: [PATCH] [mlir][linalg] do not break outs from block argument
When a block argument is used as an `outs` operand of a `linalg.generic`,
do not break the dependency on it by replacing it with a `tensor.empty`.
Otherwise, the argument is left without any uses and gets optimized away.
---
.../Linalg/Transforms/ElementwiseOpFusion.cpp | 5 +++++
.../Dialect/Linalg/fusion-elementwise-ops.mlir | 14 ++++++--------
2 files changed, 11 insertions(+), 8 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
index f0393e44fc00c27..0d33cc99ae55e48 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1818,6 +1818,11 @@ struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
continue;
+ // If outs is wired from a block argument, keep the dependency to
+ // prevent the argument from being optimized away.
+ if (isa<BlockArgument>(operandVal))
+ continue;
+
// If outs is already an `empty` operation, nothing to do.
auto definingOp = operandVal.getDefiningOp<tensor::EmptyOp>();
if (definingOp)
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
index 9d8421cbab49d8e..b0038966e074798 100644
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -730,11 +730,8 @@ func.func @break_outs_dependency(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>)
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK-DAG: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]])
// CHECK: %[[GENERIC1:.+]] = linalg.generic
-// CHECK-SAME: outs(%[[INIT]] : tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[ARG0]] : tensor<?x?xf32>)
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[GENERIC1]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[GENERIC1]], %[[C1]]
// CHECK-DAG: %[[INIT:.+]] = tensor.empty(%[[D0]], %[[D1]])
@@ -976,11 +973,10 @@ func.func @fusion_different_axes(%arg0 : tensor<5000xi64>, %arg1 : tensor<5000xi
// CHECK: func @fusion_different_axes(
// CHECK-SAME: %[[ARG0:.+]]: tensor<5000xi64>
// CHECK-SAME: %[[ARG1:.+]]: tensor<5000xi32>
-// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<5000xi64>
// CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<5000xi32>
// CHECK: %[[RESULT:.+]]:2 = linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
-// CHECK-SAME: outs(%[[INIT0]], %[[INIT1]] :
+// CHECK-SAME: outs(%[[ARG0]], %[[INIT1]] :
// CHECK-NEXT: ^bb0(
// CHECK-SAME: %[[B0:.+]]: i64
// CHECK-SAME: %[[B1:.+]]: i32
@@ -1097,10 +1093,12 @@ module {
// CHECK-LABEL: func.func @fuse_multi_result_producer
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<f32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor<f32>
-// CHECK: %[[INIT:.+]] = tensor.empty
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: tensor<f32>
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9_]+]]: tensor<f32>
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
-// CHECK-SAME: outs(%[[INIT]] :
+// CHECK-SAME: outs(%[[ARG4]] :
// CHECK-NEXT: ^bb0
// CHECK-SAME: %[[B0:[a-zA-Z0-9_]+]]: f32
// CHECK-SAME: %[[B1:[a-zA-Z0-9_]+]]: f32
More information about the Mlir-commits
mailing list