[Mlir-commits] [mlir] [MLIR] Fix -remove-dead-values (#98935) (PR #99671)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 19 10:33:22 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: None (huang-me)

<details>
<summary>Changes</summary>

1. Skip removing the return values of external functions.
    > Previously, we tried to remove the unused return values of all functions; however, external functions have no function body and therefore cannot be modified.
2. Removed arguments of a `LinalgOp` have no memory effect, so skip them when checking memory effects.
  

---
Full diff: https://github.com/llvm/llvm-project/pull/99671.diff


3 Files Affected:

- (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+2-2) 
- (modified) mlir/lib/Transforms/RemoveDeadValues.cpp (+5-1) 
- (modified) mlir/test/Transforms/remove-dead-values.mlir (+51) 


``````````diff
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cefaad9b22653..8b27e67393250 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1124,7 +1124,7 @@ static void getGenericEffectsImpl(
         &effects,
     LinalgOp linalgOp) {
   for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
-    if (!llvm::isa<MemRefType>(operand.getType()))
+    if (!operand || !llvm::isa<MemRefType>(operand.getType()))
       continue;
     effects.emplace_back(
         MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
@@ -1132,7 +1132,7 @@ static void getGenericEffectsImpl(
   }
 
   for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
-    if (!llvm::isa<MemRefType>(operand.get().getType()))
+    if (!operand.get() || !llvm::isa<MemRefType>(operand.get().getType()))
       continue;
     if (linalgOp.payloadUsesValueFromOperand(&operand)) {
       effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 055256903a152..b7e05056c6251 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -208,7 +208,8 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
       arg.dropAllUses();
 
   // Do (2).
-  funcOp.eraseArguments(nonLiveArgs);
+  if (nonLiveArgs.size())
+    funcOp.eraseArguments(nonLiveArgs);
 
   // Do (3).
   SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
@@ -225,6 +226,9 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
     callOp->eraseOperands(nonLiveCallOperands);
   }
 
+  if (funcOp.isExternal())
+    return;
+
   // Get the list of unnecessary terminator operands (return values that are
   // non-live across all callers) in `nonLiveRets`. There is a very important
   // subtlety here. Unnecessary terminator operands are NOT the operands of the
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 69426fdb62083..73637ca6ee83b 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -357,3 +357,54 @@ func.func @kernel(%arg0: memref<18xf32>) {
 // CHECK: gpu.launch blocks
 // CHECK: memref.store
 // CHECK-NEXT: gpu.terminator
+
+// -----
+
+// skip removing return values of an external function
+// 
+// CHECK-LABEL: func.func private @printF64(f64)
+// CHECK-NEXT:  func.func @main() -> f64 {
+// CHECK-NEXT:    %cst = arith.constant 1.500000e+00 : f64
+// CHECK-NEXT:    %cst_0 = arith.constant 2.000000e+00 : f64
+// CHECK-NEXT:    %0 = arith.addf %cst_0, %cst : f64
+// CHECK-NEXT:    call @printF64(%cst) : (f64) -> ()
+// CHECK-NEXT:    return %0 : f64
+// CHECK-NEXT:  }
+// CHECK:       func.func private @rtf32(f64) -> f32
+// CHECK-NEXT:  func.func @main2() -> f32 {
+// CHECK-NEXT:    %cst = arith.constant 1.500000e+00 : f64
+// CHECK-NEXT:    %0 = call @rtf32(%cst) : (f64) -> f32
+// CHECK-NEXT:    return %0 : f32
+// CHECK-NEXT:  }
+func.func private @printF64(f64)
+func.func @main() -> f64 {
+    %cst = arith.constant 1.500000e+00 : f64
+    %cst_0 = arith.constant 1.500000e+00 : f64
+    %cst_1 = arith.constant 2.000000e+00 : f64
+    %cst_2 = arith.constant 2.000000e+00 : f64
+    %cst_3 = arith.constant 3.500000e+00 : f64
+    %0 = arith.addf %cst_1, %cst : f64
+    call @printF64(%cst) : (f64) -> ()
+    return %0 : f64
+}
+
+func.func private @rtf32(f64) -> f32
+func.func @main2() -> f32 {
+    %cst = arith.constant 1.500000e+00 : f64
+    %test = call @rtf32(%cst) : (f64) -> (f32)
+    return %test : f32
+}
+
+// -----
+
+// Removed arguments have no memroy effect
+// 
+// CHECK-LABEL:  func.func @transpose() {
+// CHECK-NEXT:     return
+// CHECK-NEXT:   }
+func.func @transpose() {
+  %cst_2 = arith.constant dense<[[1., 2., 3.], [4., 5., 6.]]> : tensor<2x3xf64>
+  %cst_3 = arith.constant dense<0.> : tensor<3x2xf64>
+  %transposed = linalg.transpose ins(%cst_2 : tensor<2x3xf64>) outs(%cst_3 : tensor<3x2xf64>) permutation = [1, 0]
+  return
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/99671


More information about the Mlir-commits mailing list