[Mlir-commits] [mlir] [MLIR] Fix -remove-dead-values (#98935) (PR #99671)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Jul 19 10:33:22 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: None (huang-me)
<details>
<summary>Changes</summary>
1. Skip removing the return values of external functions.
> Previously, we tried to remove the unused return values of all functions; however, external functions have no function body and thus cannot be modified.
2. Removed arguments of a `LinalgOp` have no memory effects; therefore, skip checking memory effects for them.
---
Full diff: https://github.com/llvm/llvm-project/pull/99671.diff
3 Files Affected:
- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+2-2)
- (modified) mlir/lib/Transforms/RemoveDeadValues.cpp (+5-1)
- (modified) mlir/test/Transforms/remove-dead-values.mlir (+51)
``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cefaad9b22653..8b27e67393250 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1124,7 +1124,7 @@ static void getGenericEffectsImpl(
&effects,
LinalgOp linalgOp) {
for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
- if (!llvm::isa<MemRefType>(operand.getType()))
+ if (!operand || !llvm::isa<MemRefType>(operand.getType()))
continue;
effects.emplace_back(
MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
@@ -1132,7 +1132,7 @@ static void getGenericEffectsImpl(
}
for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
- if (!llvm::isa<MemRefType>(operand.get().getType()))
+ if (!operand.get() || !llvm::isa<MemRefType>(operand.get().getType()))
continue;
if (linalgOp.payloadUsesValueFromOperand(&operand)) {
effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 055256903a152..b7e05056c6251 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -208,7 +208,8 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
arg.dropAllUses();
// Do (2).
- funcOp.eraseArguments(nonLiveArgs);
+ if (nonLiveArgs.size())
+ funcOp.eraseArguments(nonLiveArgs);
// Do (3).
SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
@@ -225,6 +226,9 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
callOp->eraseOperands(nonLiveCallOperands);
}
+ if (funcOp.isExternal())
+ return;
+
// Get the list of unnecessary terminator operands (return values that are
// non-live across all callers) in `nonLiveRets`. There is a very important
// subtlety here. Unnecessary terminator operands are NOT the operands of the
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 69426fdb62083..73637ca6ee83b 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -357,3 +357,54 @@ func.func @kernel(%arg0: memref<18xf32>) {
// CHECK: gpu.launch blocks
// CHECK: memref.store
// CHECK-NEXT: gpu.terminator
+
+// -----
+
+// skip removing return values of an external function
+//
+// CHECK-LABEL: func.func private @printF64(f64)
+// CHECK-NEXT: func.func @main() -> f64 {
+// CHECK-NEXT: %cst = arith.constant 1.500000e+00 : f64
+// CHECK-NEXT: %cst_0 = arith.constant 2.000000e+00 : f64
+// CHECK-NEXT: %0 = arith.addf %cst_0, %cst : f64
+// CHECK-NEXT: call @printF64(%cst) : (f64) -> ()
+// CHECK-NEXT: return %0 : f64
+// CHECK-NEXT: }
+// CHECK: func.func private @rtf32(f64) -> f32
+// CHECK-NEXT: func.func @main2() -> f32 {
+// CHECK-NEXT: %cst = arith.constant 1.500000e+00 : f64
+// CHECK-NEXT: %0 = call @rtf32(%cst) : (f64) -> f32
+// CHECK-NEXT: return %0 : f32
+// CHECK-NEXT: }
+func.func private @printF64(f64)
+func.func @main() -> f64 {
+ %cst = arith.constant 1.500000e+00 : f64
+ %cst_0 = arith.constant 1.500000e+00 : f64
+ %cst_1 = arith.constant 2.000000e+00 : f64
+ %cst_2 = arith.constant 2.000000e+00 : f64
+ %cst_3 = arith.constant 3.500000e+00 : f64
+ %0 = arith.addf %cst_1, %cst : f64
+ call @printF64(%cst) : (f64) -> ()
+ return %0 : f64
+}
+
+func.func private @rtf32(f64) -> f32
+func.func @main2() -> f32 {
+ %cst = arith.constant 1.500000e+00 : f64
+ %test = call @rtf32(%cst) : (f64) -> (f32)
+ return %test : f32
+}
+
+// -----
+
+// Removed arguments have no memory effect
+//
+// CHECK-LABEL: func.func @transpose() {
+// CHECK-NEXT: return
+// CHECK-NEXT: }
+func.func @transpose() {
+ %cst_2 = arith.constant dense<[[1., 2., 3.], [4., 5., 6.]]> : tensor<2x3xf64>
+ %cst_3 = arith.constant dense<0.> : tensor<3x2xf64>
+ %transposed = linalg.transpose ins(%cst_2 : tensor<2x3xf64>) outs(%cst_3 : tensor<3x2xf64>) permutation = [1, 0]
+ return
+}
``````````
</details>
https://github.com/llvm/llvm-project/pull/99671
More information about the Mlir-commits
mailing list