[Mlir-commits] [mlir] Allow SymbolUserOpInterface operators to be used in RemoveDeadValues Pass (PR #117405)

M. Zeeshan Siddiqui llvmlistbot at llvm.org
Fri Nov 22 16:37:42 PST 2024


https://github.com/codemzs updated https://github.com/llvm/llvm-project/pull/117405

>From 391189f765cb038fb56ca67b3b6bf0906a4686a6 Mon Sep 17 00:00:00 2001
From: Zeeshan Siddiqui <mzs at ntdev.microsoft.com>
Date: Sat, 23 Nov 2024 00:02:11 +0000
Subject: [PATCH 1/2] Allow SymbolUserOpInterface operations to be used in
 RemoveDeadValues pass

---
 mlir/lib/Transforms/RemoveDeadValues.cpp     | 6 ++----
 mlir/test/Transforms/remove-dead-values.mlir | 3 ++-
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index b82280dda8ba73..0aa9dcb36681b3 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -577,10 +577,8 @@ void RemoveDeadValues::runOnOperation() {
   WalkResult acceptableIR = module->walk([&](Operation *op) {
     if (op == module)
       return WalkResult::advance();
-    if (isa<BranchOpInterface>(op) ||
-        (isa<SymbolUserOpInterface>(op) && !isa<CallOpInterface>(op))) {
-      op->emitError() << "cannot optimize an IR with "
-                         "non-call symbol user ops or branch ops\n";
+    if (isa<BranchOpInterface>(op)) {
+      op->emitError() << "cannot optimize an IR with branch ops\n";
       return WalkResult::interrupt();
     }
     return WalkResult::advance();
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 47137fc6430fea..7a8d49681a4b18 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -3,9 +3,10 @@
 // The IR is updated regardless of memref.global private constant
 //
 module {
-  memref.global "private" constant @__something_global : memref<i32> = dense<0>
+  memref.global "private" constant @global_buffer : memref<5xi32> = dense<[1, 2, 3, 4, 5]> : tensor<5xi32>
   func.func @main(%arg0: i32) -> i32 {
     %0 = tensor.empty() : tensor<10xbf16>
+    %1 = memref.get_global @global_buffer : memref<5xi32>
     // CHECK-NOT: tensor.empty
     return %arg0 : i32
   }

>From 801ca7f11f8abcbbb66ba4cfca6e9e03f2b21481 Mon Sep 17 00:00:00 2001
From: Zeeshan Siddiqui <mzs at ntdev.microsoft.com>
Date: Sat, 23 Nov 2024 00:37:05 +0000
Subject: [PATCH 2/2] Update remove-dead-values test global and expected error message

---
 mlir/test/Transforms/remove-dead-values.mlir | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 7a8d49681a4b18..c215a2b8fd77c6 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -3,10 +3,10 @@
 // The IR is updated regardless of memref.global private constant
 //
 module {
-  memref.global "private" constant @global_buffer : memref<5xi32> = dense<[1, 2, 3, 4, 5]> : tensor<5xi32>
+  memref.global "private" constant @__constant_4xi32 : memref<4xi32> = dense<[1, 2, 3, 4]> {alignment = 16 : i64}
   func.func @main(%arg0: i32) -> i32 {
     %0 = tensor.empty() : tensor<10xbf16>
-    %1 = memref.get_global @global_buffer : memref<5xi32>
+    %1 = memref.get_global @__constant_4xi32 : memref<4xi32>
     // CHECK-NOT: tensor.empty
     return %arg0 : i32
   }
@@ -30,7 +30,7 @@ module @named_module_acceptable {
 //
 func.func @dont_touch_unacceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) {
   %non_live = arith.constant 0 : i32
-  // expected-error @+1 {{cannot optimize an IR with non-call symbol user ops or branch ops}}
+  // expected-error @+1 {{cannot optimize an IR with branch ops}}
   cf.cond_br %arg0, ^bb1(%non_live : i32), ^bb2(%non_live : i32)
 ^bb1(%non_live_0 : i32):
   cf.br ^bb3



More information about the Mlir-commits mailing list