[Mlir-commits] [mlir] [MLIR] Fix -remove-dead-values (#98935) (PR #99671)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Fri Jul 19 18:32:05 PDT 2024


https://github.com/huang-me updated https://github.com/llvm/llvm-project/pull/99671

>From 0bcdaa5e632e0045b145f5ff214e36d1f0cc26f2 Mon Sep 17 00:00:00 2001
From: huang-me <amos0107 at gmail.com>
Date: Sat, 20 Jul 2024 09:30:30 +0800
Subject: [PATCH 1/2] [MLIR] Removed arguments have no memory effect

---
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp     |  4 ++--
 mlir/lib/Transforms/RemoveDeadValues.cpp     |  3 ++-
 mlir/test/Transforms/remove-dead-values.mlir | 14 ++++++++++++++
 3 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index cefaad9b22653..8b27e67393250 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1124,7 +1124,7 @@ static void getGenericEffectsImpl(
         &effects,
     LinalgOp linalgOp) {
   for (auto [index, operand] : llvm::enumerate(linalgOp.getDpsInputs())) {
-    if (!llvm::isa<MemRefType>(operand.getType()))
+    if (!operand || !llvm::isa<MemRefType>(operand.getType()))
       continue;
     effects.emplace_back(
         MemoryEffects::Read::get(), &linalgOp->getOpOperand(index), /*stage=*/0,
@@ -1132,7 +1132,7 @@ static void getGenericEffectsImpl(
   }
 
   for (OpOperand &operand : linalgOp.getDpsInitsMutable()) {
-    if (!llvm::isa<MemRefType>(operand.get().getType()))
+    if (!operand.get() || !llvm::isa<MemRefType>(operand.get().getType()))
       continue;
     if (linalgOp.payloadUsesValueFromOperand(&operand)) {
       effects.emplace_back(MemoryEffects::Read::get(), &operand, /*stage=*/0,
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 055256903a152..6d7760b7c338b 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -208,7 +208,8 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
       arg.dropAllUses();
 
   // Do (2).
-  funcOp.eraseArguments(nonLiveArgs);
+  if (nonLiveArgs.size())
+    funcOp.eraseArguments(nonLiveArgs);
 
   // Do (3).
   SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 69426fdb62083..25b67a0c61734 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -357,3 +357,17 @@ func.func @kernel(%arg0: memref<18xf32>) {
 // CHECK: gpu.launch blocks
 // CHECK: memref.store
 // CHECK-NEXT: gpu.terminator
+
+// -----
+
+// Removed arguments have no memroy effect
+// 
+// CHECK-LABEL:  func.func @transpose() {
+// CHECK-NEXT:     return
+// CHECK-NEXT:   }
+func.func @transpose() {
+  %cst_2 = arith.constant dense<[[1., 2., 3.], [4., 5., 6.]]> : tensor<2x3xf64>
+  %cst_3 = arith.constant dense<0.> : tensor<3x2xf64>
+  %transposed = linalg.transpose ins(%cst_2 : tensor<2x3xf64>) outs(%cst_3 : tensor<3x2xf64>) permutation = [1, 0]
+  return
+}

>From 2dc9dd54f2104418a9b5dd68113fcec3e13198fb Mon Sep 17 00:00:00 2001
From: huang-me <amos0107 at gmail.com>
Date: Sat, 20 Jul 2024 09:31:19 +0800
Subject: [PATCH 2/2] [MLIR][Transform] Skip removing return values of external
 functions

---
 mlir/lib/Transforms/RemoveDeadValues.cpp     |  3 ++
 mlir/test/Transforms/remove-dead-values.mlir | 37 ++++++++++++++++++++
 2 files changed, 40 insertions(+)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 6d7760b7c338b..b7e05056c6251 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -226,6 +226,9 @@ static void cleanFuncOp(FunctionOpInterface funcOp, Operation *module,
     callOp->eraseOperands(nonLiveCallOperands);
   }
 
+  if (funcOp.isExternal())
+    return;
+
   // Get the list of unnecessary terminator operands (return values that are
   // non-live across all callers) in `nonLiveRets`. There is a very important
   // subtlety here. Unnecessary terminator operands are NOT the operands of the
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 25b67a0c61734..73637ca6ee83b 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -360,6 +360,43 @@ func.func @kernel(%arg0: memref<18xf32>) {
 
 // -----
 
+// Skip removing return values of external functions.
+//
+// CHECK-LABEL: func.func private @printF64(f64)
+// CHECK-NEXT:  func.func @main() -> f64 {
+// CHECK-NEXT:    %cst = arith.constant 1.500000e+00 : f64
+// CHECK-NEXT:    %cst_0 = arith.constant 2.000000e+00 : f64
+// CHECK-NEXT:    %0 = arith.addf %cst_0, %cst : f64
+// CHECK-NEXT:    call @printF64(%cst) : (f64) -> ()
+// CHECK-NEXT:    return %0 : f64
+// CHECK-NEXT:  }
+// CHECK:       func.func private @rtf32(f64) -> f32
+// CHECK-NEXT:  func.func @main2() -> f32 {
+// CHECK-NEXT:    %cst = arith.constant 1.500000e+00 : f64
+// CHECK-NEXT:    %0 = call @rtf32(%cst) : (f64) -> f32
+// CHECK-NEXT:    return %0 : f32
+// CHECK-NEXT:  }
+func.func private @printF64(f64)
+func.func @main() -> f64 {
+    %cst = arith.constant 1.500000e+00 : f64
+    %cst_0 = arith.constant 1.500000e+00 : f64
+    %cst_1 = arith.constant 2.000000e+00 : f64
+    %cst_2 = arith.constant 2.000000e+00 : f64
+    %cst_3 = arith.constant 3.500000e+00 : f64
+    %0 = arith.addf %cst_1, %cst : f64
+    call @printF64(%cst) : (f64) -> ()
+    return %0 : f64
+}
+
+func.func private @rtf32(f64) -> f32
+func.func @main2() -> f32 {
+    %cst = arith.constant 1.500000e+00 : f64
+    %test = call @rtf32(%cst) : (f64) -> (f32)
+    return %test : f32
+}
+
+// -----
+
 // Removed arguments have no memroy effect
 // 
 // CHECK-LABEL:  func.func @transpose() {



More information about the Mlir-commits mailing list