[Mlir-commits] [mlir] [MLIR] Fix -remove-dead-values (#98935) (PR #99671)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jul 19 18:32:05 PDT 2024
https://github.com/huang-me updated https://github.com/llvm/llvm-project/pull/99671
>From 0bcdaa5e632e0045b145f5ff214e36d1f0cc26f2 Mon Sep 17 00:00:00 2001
From: huang-me <amos0107 at gmail.com>
Date: Sat, 20 Jul 2024 09:30:30 +0800
Subject: [PATCH 1/2] [MLIR] Removed arguments have no memory effect
---
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp | 4 ++--
mlir/lib/Transforms/RemoveDeadValues.cpp | 3 ++-
mlir/test/Transforms/remove-dead-values.mlir | 14 ++++++++++++++
3 files changed, 18 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cefaad9b22653..8b27e67393250 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1124,7 +1124,7 @@ static void getGenericEffectsImpl(
&effects,
LinalgOp linalgOp) {
for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
- if (!llvm::isa<MemRefType>(operand.getType()))
+ if (!operand || !llvm::isa<MemRefType>(operand.getType()))
continue;
effects.emplace_back(
MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
@@ -1132,7 +1132,7 @@ static void getGenericEffectsImpl(
}
for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
- if (!llvm::isa<MemRefType>(operand.get().getType()))
+ if (!operand.get() || !llvm::isa<MemRefType>(operand.get().getType()))
continue;
if (linalgOp.payloadUsesValueFromOperand(&operand)) {
effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 055256903a152..6d7760b7c338b 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -208,7 +208,8 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
arg.dropAllUses();
// Do (2).
- funcOp.eraseArguments(nonLiveArgs);
+ if (nonLiveArgs.size())
+ funcOp.eraseArguments(nonLiveArgs);
// Do (3).
SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 69426fdb62083..25b67a0c61734 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -357,3 +357,17 @@ func.func @kernel(%arg0: memref<18xf32>) {
// CHECK: gpu.launch blocks
// CHECK: memref.store
// CHECK-NEXT: gpu.terminator
+
+// -----
+
+// Removed arguments have no memroy effect
+//
+// CHECK-LABEL: func.func @transpose() {
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+func.func @transpose() {
+ %cst_2 = arith.constant dense<[[1., 2., 3.], [4., 5., 6.]]> : tensor<2x3xf64>
+ %cst_3 = arith.constant dense<0.> : tensor<3x2xf64>
+ %transposed = linalg.transpose ins(%cst_2 : tensor<2x3xf64>) outs(%cst_3 : tensor<3x2xf64>) permutation = [1, 0]
+ return
+}
>From 2dc9dd54f2104418a9b5dd68113fcec3e13198fb Mon Sep 17 00:00:00 2001
From: huang-me <amos0107 at gmail.com>
Date: Sat, 20 Jul 2024 09:31:19 +0800
Subject: [PATCH 2/2] [MLIR][Transform] Skip removing return values of external
functions
---
mlir/lib/Transforms/RemoveDeadValues.cpp | 3 ++
mlir/test/Transforms/remove-dead-values.mlir | 37 ++++++++++++++++++++
2 files changed, 40 insertions(+)
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 6d7760b7c338b..b7e05056c6251 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -226,6 +226,9 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
callOp->eraseOperands(nonLiveCallOperands);
}
+ if (funcOp.isExternal())
+ return;
+
// Get the list of unnecessary terminator operands (return values that are
// non-live across all callers) in `nonLiveRets`. There is a very important
// subtlety here. Unnecessary terminator operands are NOT the operands of the
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 25b67a0c61734..73637ca6ee83b 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -360,6 +360,43 @@ func.func @kernel(%arg0: memref<18xf32>) {
// -----
+// Skip removing the return values of external functions.
+//
+// CHECK-LABEL: func.func private @printF64(f64)
+// CHECK-NEXT: func.func @main() -> f64 {
+// CHECK-NEXT: %cst = arith.constant 1.500000e+00 : f64
+// CHECK-NEXT: %cst_0 = arith.constant 2.000000e+00 : f64
+// CHECK-NEXT: %0 = arith.addf %cst_0, %cst : f64
+// CHECK-NEXT: call @printF64(%cst) : (f64) -> ()
+// CHECK-NEXT: return %0 : f64
+// CHECK-NEXT: }
+// CHECK: func.func private @rtf32(f64) -> f32
+// CHECK-NEXT: func.func @main2() -> f32 {
+// CHECK-NEXT: %cst = arith.constant 1.500000e+00 : f64
+// CHECK-NEXT: %0 = call @rtf32(%cst) : (f64) -> f32
+// CHECK-NEXT: return %0 : f32
+// CHECK-NEXT: }
+func.func private @printF64(f64)
+func.func @main() -> f64 {
+ %cst = arith.constant 1.500000e+00 : f64
+ %cst_0 = arith.constant 1.500000e+00 : f64
+ %cst_1 = arith.constant 2.000000e+00 : f64
+ %cst_2 = arith.constant 2.000000e+00 : f64
+ %cst_3 = arith.constant 3.500000e+00 : f64
+ %0 = arith.addf %cst_1, %cst : f64
+ call @printF64(%cst) : (f64) -> ()
+ return %0 : f64
+}
+
+func.func private @rtf32(f64) -> f32
+func.func @main2() -> f32 {
+ %cst = arith.constant 1.500000e+00 : f64
+ %test = call @rtf32(%cst) : (f64) -> (f32)
+ return %test : f32
+}
+
+// -----
+
// Removed arguments have no memroy effect
//
// CHECK-LABEL: func.func @transpose() {
More information about the Mlir-commits
mailing list