[Mlir-commits] [mlir] [MLIR] Fix -remove-dead-values (#98935) (PR #99671)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 19 10:32:54 PDT 2024


https://github.com/huang-me created https://github.com/llvm/llvm-project/pull/99671

1. Skip removing an external function's return values.
    > Previously, we tried to remove unused return values of all functions; however, external functions have no body and therefore cannot be modified.
2. Operands removed from a `LinalgOp` (now null) have no memory effect, so skip them when collecting the op's memory effects.
  

>From 5960ac35e9149aac76dd686e787471a04c89c5b1 Mon Sep 17 00:00:00 2001
From: huang-me <amos0107 at gmail.com>
Date: Fri, 19 Jul 2024 21:07:18 +0800
Subject: [PATCH 1/2] [MLIR][Transform] Skip removing return values of an
 external function (#98935)

---
 mlir/lib/Transforms/RemoveDeadValues.cpp     |  3 ++
 mlir/test/Transforms/remove-dead-values.mlir | 49 ++++++++++++++++++++
 2 files changed, 52 insertions(+)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 055256903a152..a3f0bfbf81fb7 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -225,6 +225,9 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
     callOp->eraseOperands(nonLiveCallOperands);
   }
 
+  if (funcOp.isExternal())
+    return;
+
   // Get the list of unnecessary terminator operands (return values that are
   // non-live across all callers) in `nonLiveRets`. There is a very important
   // subtlety here. Unnecessary terminator operands are NOT the operands of the
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 69426fdb62083..695f6e724cd64 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -357,3 +357,52 @@ func.func @kernel(%arg0: memref<18xf32>) {
 // CHECK: gpu.launch blocks
 // CHECK: memref.store
 // CHECK-NEXT: gpu.terminator
+
+// -----
+
+// skip removing return values of an external function
+// 
+// CHECK-LABEL: func.func private @printF64(f64)
+// CHECK-NEXT:  func.func @main() -> f64 {
+// CHECK-NEXT:    %cst = arith.constant 1.500000e+00 : f64
+// CHECK-NEXT:    %cst_0 = arith.constant 2.000000e+00 : f64
+// CHECK-NEXT:    %0 = arith.addf %cst_0, %cst : f64
+// CHECK-NEXT:    call @printF64(%cst) : (f64) -> ()
+// CHECK-NEXT:    return %0 : f64
+// CHECK-NEXT:  }
+// CHECK:       func.func private @rtf32(f64) -> f32
+// CHECK-NEXT:  func.func @main2() -> f32 {
+// CHECK-NEXT:    %cst = arith.constant 1.500000e+00 : f64
+// CHECK-NEXT:    %0 = call @rtf32(%cst) : (f64) -> f32
+// CHECK-NEXT:    return %0 : f32
+// CHECK-NEXT:  }
+func.func private @printF64(f64)
+func.func @main() -> f64 {
+    %cst = arith.constant 1.500000e+00 : f64
+    %cst_0 = arith.constant 1.500000e+00 : f64
+    %cst_1 = arith.constant 2.000000e+00 : f64
+    %cst_2 = arith.constant 2.000000e+00 : f64
+    %cst_3 = arith.constant 3.500000e+00 : f64
+    %0 = arith.addf %cst_1, %cst : f64
+    call @printF64(%cst) : (f64) -> ()
+    return %0 : f64
+}
+
+func.func private @rtf32(f64) -> f32
+func.func @main2() -> f32 {
+    %cst = arith.constant 1.500000e+00 : f64
+    %test = call @rtf32(%cst) : (f64) -> (f32)
+    return %test : f32
+}
+
+// -----
+
+// Removed arguments have no memory effect
+// 
+// CHECK-LABEL:  func.func @transpose() {
+// CHECK-NEXT:     return
+// CHECK-NEXT:   }
+func.func @transpose() {
+  %cst_2 = arith.constant dense<[[1., 2., 3.], [4., 5., 6.]]> : tensor<2x3xf64>
+  %cst_3 = arith.constant dense<0.> : tensor<3x2xf64>
+  %transposed = linalg.transpose ins(%cst_2 : tensor<2x3xf64>) outs(%cst_3 : tensor<3x2xf64>) permutation = [1, 0]

>From 9ca751777bf746b7ecbd437188c4d8b3c18947b3 Mon Sep 17 00:00:00 2001
From: huang-me <amos0107 at gmail.com>
Date: Fri, 19 Jul 2024 21:06:37 +0800
Subject: [PATCH 2/2] [MLIR][Transform] Removed arguments have no memory effect
 (#98935)

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp     | 4 ++--
 mlir/lib/Transforms/RemoveDeadValues.cpp     | 3 ++-
 mlir/test/Transforms/remove-dead-values.mlir | 2 ++
 3 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cefaad9b22653..8b27e67393250 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1124,7 +1124,7 @@ static void getGenericEffectsImpl(
         &effects,
     LinalgOp linalgOp) {
   for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
-    if (!llvm::isa<MemRefType>(operand.getType()))
+    if (!operand || !llvm::isa<MemRefType>(operand.getType()))
       continue;
     effects.emplace_back(
         MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
@@ -1132,7 +1132,7 @@ static void getGenericEffectsImpl(
   }
 
   for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
-    if (!llvm::isa<MemRefType>(operand.get().getType()))
+    if (!operand.get() || !llvm::isa<MemRefType>(operand.get().getType()))
       continue;
     if (linalgOp.payloadUsesValueFromOperand(&operand)) {
       effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index a3f0bfbf81fb7..b7e05056c6251 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -208,7 +208,8 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
       arg.dropAllUses();
 
   // Do (2).
-  funcOp.eraseArguments(nonLiveArgs);
+  if (nonLiveArgs.size())
+    funcOp.eraseArguments(nonLiveArgs);
 
   // Do (3).
   SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 695f6e724cd64..73637ca6ee83b 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -406,3 +406,5 @@ func.func @transpose() {
   %cst_2 = arith.constant dense<[[1., 2., 3.], [4., 5., 6.]]> : tensor<2x3xf64>
   %cst_3 = arith.constant dense<0.> : tensor<3x2xf64>
   %transposed = linalg.transpose ins(%cst_2 : tensor<2x3xf64>) outs(%cst_3 : tensor<3x2xf64>) permutation = [1, 0]
+  return
+}



More information about the Mlir-commits mailing list