[Mlir-commits] [mlir] [mlir, python] Expose replaceAllUsesExcept to Python bindings (PR #115850)

Perry Gibson llvmlistbot at llvm.org
Tue Nov 12 09:03:13 PST 2024


https://github.com/Wheest updated https://github.com/llvm/llvm-project/pull/115850

>From 7f37794e65091ea1c130c2012ba8a0509851c9a4 Mon Sep 17 00:00:00 2001
From: pez <perry at gibsonic.org>
Date: Tue, 12 Nov 2024 10:39:07 +0000
Subject: [PATCH 1/5] Expose replaceAllUsesExcept to Python bindings

---
 mlir/include/mlir-c/IR.h            | 16 +++++++
 mlir/lib/Bindings/Python/IRCore.cpp | 32 +++++++++++++
 mlir/lib/CAPI/IR/IR.cpp             | 26 +++++++++++
 mlir/test/python/ir/value.py        | 71 +++++++++++++++++++++++++++++
 4 files changed, 145 insertions(+)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index b8a6f08b159817..012353993c341a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -956,6 +956,22 @@ MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value);
 MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of,
                                                       MlirValue with);
 
+/// Replace all uses of 'of' value with 'with' value, updating anything in the
+/// IR that uses 'of' to use 'with' instead, except if the user is listed in
+/// 'exceptions'. The 'exceptions' parameter is an array of MlirOperation
+/// pointers with a length of 'numExceptions'.
+MLIR_CAPI_EXPORTED void
+mlirValueReplaceAllUsesExceptWithSet(MlirValue of, MlirValue with,
+                                     MlirOperation *exceptions,
+                                     intptr_t numExceptions);
+
+/// Replace all uses of 'of' value with 'with' value, updating anything in the
+/// IR that uses 'of' to use 'with' instead, except if the user is
+/// 'exceptedUser'.
+MLIR_CAPI_EXPORTED void
+mlirValueReplaceAllUsesExceptWithSingle(MlirValue of, MlirValue with,
+                                        MlirOperation exceptedUser);
+
 //===----------------------------------------------------------------------===//
 // OpOperand API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 3562ff38201dc3..4bddcab8ccda6d 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -178,6 +178,12 @@ static const char kValueReplaceAllUsesWithDocstring[] =
 the IR that uses 'self' to use the other value instead.
 )";
 
+static const char kValueReplaceAllUsesExceptDocstring[] =
+    R"(Replace all uses of this value with the 'with' value, except for those
+in 'exceptions'. 'exceptions' can be either a single operation or a list of
+operations.
+)";
+
 //------------------------------------------------------------------------------
 // Utilities.
 //------------------------------------------------------------------------------
@@ -3718,6 +3724,32 @@ void mlir::python::populateIRCore(py::module &m) {
             mlirValueReplaceAllUsesOfWith(self.get(), with.get());
           },
           kValueReplaceAllUsesWithDocstring)
+      .def(
+          "replace_all_uses_except",
+          [](PyValue &self, PyValue &with, py::object exceptions) {
+            MlirValue selfValue = self.get();
+            MlirValue withValue = with.get();
+
+            // Check if 'exceptions' is a list
+            if (py::isinstance<py::list>(exceptions)) {
+              // Convert Python list to a vector of MlirOperations
+              std::vector<MlirOperation> exceptionOps;
+              for (py::handle exception : exceptions) {
+                exceptionOps.push_back(exception.cast<PyOperation &>().get());
+              }
+              mlirValueReplaceAllUsesExceptWithSet(
+                  selfValue, withValue, exceptionOps.data(),
+                  static_cast<intptr_t>(exceptionOps.size()));
+            } else {
+              // Assume 'exceptions' is a single Operation
+              MlirOperation exceptedUser =
+                  exceptions.cast<PyOperation &>().get();
+              mlirValueReplaceAllUsesExceptWithSingle(selfValue, withValue,
+                                                      exceptedUser);
+            }
+          },
+          py::arg("with"), py::arg("exceptions"),
+          kValueReplaceAllUsesExceptDocstring)
       .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
            [](PyValue &self) { return self.maybeDownCast(); });
   PyBlockArgument::bind(m);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index e7e6b11c81b9d3..5fd5f0a8f36457 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser/Parser.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/ThreadPool.h"
 
 #include <cstddef>
@@ -1009,6 +1010,31 @@ void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) {
   unwrap(oldValue).replaceAllUsesWith(unwrap(newValue));
 }
 
+void mlirValueReplaceAllUsesExceptWithSet(MlirValue oldValue,
+                                          MlirValue newValue,
+                                          MlirOperation *exceptions,
+                                          intptr_t numExceptions) {
+  auto oldValueCpp = unwrap(oldValue);
+  auto newValueCpp = unwrap(newValue);
+
+  llvm::SmallPtrSet<mlir::Operation *, 4> exceptionSet;
+  for (intptr_t i = 0; i < numExceptions; ++i) {
+    exceptionSet.insert(unwrap(exceptions[i]));
+  }
+
+  oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptionSet);
+}
+
+void mlirValueReplaceAllUsesExceptWithSingle(MlirValue oldValue,
+                                             MlirValue newValue,
+                                             MlirOperation exceptedUser) {
+  auto oldValueCpp = unwrap(oldValue);
+  auto newValueCpp = unwrap(newValue);
+  auto exceptedUserCpp = unwrap(exceptedUser);
+
+  oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptedUserCpp);
+}
+
 //===----------------------------------------------------------------------===//
 // OpOperand API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 50b0e8403a7f21..e38c84172b877b 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -148,6 +148,77 @@ def testValueReplaceAllUsesWith():
         print(f"Use operand_number: {use.operand_number}")
 
 
+# CHECK-LABEL: TEST: testValueReplaceAllUsesWithExcept
+@run
+def testValueReplaceAllUsesWithExcept():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            value = Operation.create("custom.op1", results=[i32]).results[0]
+            op1 = Operation.create("custom.op1", operands=[value])
+            op2 = Operation.create("custom.op2", operands=[value])
+            value2 = Operation.create("custom.op3", results=[i32]).results[0]
+            value.replace_all_uses_except(value2, [op1])
+
+    assert len(list(value.uses)) == 1
+
+    # CHECK: Use owner: "custom.op2"
+    # CHECK: Use operand_number: 0
+    for use in value2.uses:
+        assert use.owner in [op2]
+        print(f"Use owner: {use.owner}")
+        print(f"Use operand_number: {use.operand_number}")
+
+    # CHECK: Use owner: "custom.op1"
+    # CHECK: Use operand_number: 0
+    for use in value.uses:
+        assert use.owner in [op1]
+        print(f"Use owner: {use.owner}")
+        print(f"Use operand_number: {use.operand_number}")
+
+
+# CHECK-LABEL: TEST: testValueReplaceAllUsesWithMultipleExceptions
+@run
+def testValueReplaceAllUsesWithMultipleExceptions():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            value = Operation.create("custom.op1", results=[i32]).results[0]
+            op1 = Operation.create("custom.op1", operands=[value])
+            op2 = Operation.create("custom.op2", operands=[value])
+            op3 = Operation.create("custom.op3", operands=[value])
+            value2 = Operation.create("custom.op4", results=[i32]).results[0]
+
+            # Replace all uses of `value` with `value2`, except for `op1` and `op2`.
+            value.replace_all_uses_except(value2, [op1, op2])
+
+    # After replacement, only `op3` should use `value2`, while `op1` and `op2` should still use `value`.
+    assert len(list(value.uses)) == 2
+    assert len(list(value2.uses)) == 1
+
+    # CHECK: Use owner: "custom.op3"
+    # CHECK: Use operand_number: 0
+    for use in value2.uses:
+        assert use.owner in [op3]
+        print(f"Use owner: {use.owner}")
+        print(f"Use operand_number: {use.operand_number}")
+
+    # CHECK: Use owner: "custom.op2"
+    # CHECK: Use operand_number: 0
+    # CHECK: Use owner: "custom.op1"
+    # CHECK: Use operand_number: 0
+    for use in value.uses:
+        assert use.owner in [op1, op2]
+        print(f"Use owner: {use.owner}")
+        print(f"Use operand_number: {use.operand_number}")
+
+
 # CHECK-LABEL: TEST: testValuePrintAsOperand
 @run
 def testValuePrintAsOperand():

>From f91e5fd54c9dcfdf746656ec0d5e56f7df79aaf6 Mon Sep 17 00:00:00 2001
From: pez <perry at gibsonic.org>
Date: Tue, 12 Nov 2024 12:27:27 +0000
Subject: [PATCH 2/5] Replace auto with explicit types

---
 mlir/lib/CAPI/IR/IR.cpp | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 5fd5f0a8f36457..0325c00ea5d738 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -1014,8 +1014,8 @@ void mlirValueReplaceAllUsesExceptWithSet(MlirValue oldValue,
                                           MlirValue newValue,
                                           MlirOperation *exceptions,
                                           intptr_t numExceptions) {
-  auto oldValueCpp = unwrap(oldValue);
-  auto newValueCpp = unwrap(newValue);
+  Value oldValueCpp = unwrap(oldValue);
+  Value newValueCpp = unwrap(newValue);
 
   llvm::SmallPtrSet<mlir::Operation *, 4> exceptionSet;
   for (intptr_t i = 0; i < numExceptions; ++i) {
@@ -1028,9 +1028,9 @@ void mlirValueReplaceAllUsesExceptWithSet(MlirValue oldValue,
 void mlirValueReplaceAllUsesExceptWithSingle(MlirValue oldValue,
                                              MlirValue newValue,
                                              MlirOperation exceptedUser) {
-  auto oldValueCpp = unwrap(oldValue);
-  auto newValueCpp = unwrap(newValue);
-  auto exceptedUserCpp = unwrap(exceptedUser);
+  Value oldValueCpp = unwrap(oldValue);
+  Value newValueCpp = unwrap(newValue);
+  Operation *exceptedUserCpp = unwrap(exceptedUser);
 
   oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptedUserCpp);
 }

>From 825d9f6e0cea21c50a767e80be5ba80479d65f0e Mon Sep 17 00:00:00 2001
From: pez <perry at gibsonic.org>
Date: Tue, 12 Nov 2024 14:04:57 +0000
Subject: [PATCH 3/5] Use ADL to split single op and list calls

---
 mlir/lib/Bindings/Python/IRCore.cpp | 38 ++++++++++++++++-------------
 1 file changed, 21 insertions(+), 17 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 4bddcab8ccda6d..681b8b31e2917f 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3726,27 +3726,31 @@ void mlir::python::populateIRCore(py::module &m) {
           kValueReplaceAllUsesWithDocstring)
       .def(
           "replace_all_uses_except",
-          [](PyValue &self, PyValue &with, py::object exceptions) {
+          [](PyValue &self, PyValue &with, PyOperation &exception) {
             MlirValue selfValue = self.get();
             MlirValue withValue = with.get();
+            MlirOperation exceptedUser = exception.get();
 
-            // Check if 'exceptions' is a list
-            if (py::isinstance<py::list>(exceptions)) {
-              // Convert Python list to a vector of MlirOperations
-              std::vector<MlirOperation> exceptionOps;
-              for (py::handle exception : exceptions) {
-                exceptionOps.push_back(exception.cast<PyOperation &>().get());
-              }
-              mlirValueReplaceAllUsesExceptWithSet(
-                  selfValue, withValue, exceptionOps.data(),
-                  static_cast<intptr_t>(exceptionOps.size()));
-            } else {
-              // Assume 'exceptions' is a single Operation
-              MlirOperation exceptedUser =
-                  exceptions.cast<PyOperation &>().get();
-              mlirValueReplaceAllUsesExceptWithSingle(selfValue, withValue,
-                                                      exceptedUser);
+            mlirValueReplaceAllUsesExceptWithSingle(selfValue, withValue,
+                                                    exceptedUser);
+          },
+          py::arg("with"), py::arg("exceptions"),
+          kValueReplaceAllUsesExceptDocstring)
+      .def(
+          "replace_all_uses_except",
+          [](PyValue &self, PyValue &with, py::list exceptions) {
+            MlirValue selfValue = self.get();
+            MlirValue withValue = with.get();
+
+            // Convert Python list to a SmallVector of MlirOperations
+            llvm::SmallVector<MlirOperation, 4> exceptionOps;
+            for (py::handle exception : exceptions) {
+              exceptionOps.push_back(exception.cast<PyOperation &>().get());
             }
+
+            mlirValueReplaceAllUsesExceptWithSet(
+                selfValue, withValue, exceptionOps.data(),
+                static_cast<intptr_t>(exceptionOps.size()));
           },
           py::arg("with"), py::arg("exceptions"),
           kValueReplaceAllUsesExceptDocstring)

>From e97696fca021b15de619b7f63d5139ab9caaa331 Mon Sep 17 00:00:00 2001
From: pez <perry at gibsonic.org>
Date: Tue, 12 Nov 2024 16:27:36 +0000
Subject: [PATCH 4/5] Leverage PybindAdaptors and fix test

---
 mlir/lib/Bindings/Python/IRCore.cpp | 16 ++++------------
 mlir/test/python/ir/value.py        |  2 +-
 2 files changed, 5 insertions(+), 13 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 681b8b31e2917f..c29b94aba9620b 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3726,22 +3726,14 @@ void mlir::python::populateIRCore(py::module &m) {
           kValueReplaceAllUsesWithDocstring)
       .def(
           "replace_all_uses_except",
-          [](PyValue &self, PyValue &with, PyOperation &exception) {
-            MlirValue selfValue = self.get();
-            MlirValue withValue = with.get();
-            MlirOperation exceptedUser = exception.get();
-
-            mlirValueReplaceAllUsesExceptWithSingle(selfValue, withValue,
-                                                    exceptedUser);
+          [](MlirValue self, MlirValue with, MlirOperation exception) {
+            mlirValueReplaceAllUsesExceptWithSingle(self, with, exception);
           },
           py::arg("with"), py::arg("exceptions"),
           kValueReplaceAllUsesExceptDocstring)
       .def(
           "replace_all_uses_except",
-          [](PyValue &self, PyValue &with, py::list exceptions) {
-            MlirValue selfValue = self.get();
-            MlirValue withValue = with.get();
-
+          [](MlirValue self, MlirValue with, py::list exceptions) {
             // Convert Python list to a SmallVector of MlirOperations
             llvm::SmallVector<MlirOperation, 4> exceptionOps;
             for (py::handle exception : exceptions) {
@@ -3749,7 +3741,7 @@ void mlir::python::populateIRCore(py::module &m) {
             }
 
             mlirValueReplaceAllUsesExceptWithSet(
-                selfValue, withValue, exceptionOps.data(),
+                self, with, exceptionOps.data(),
                 static_cast<intptr_t>(exceptionOps.size()));
           },
           py::arg("with"), py::arg("exceptions"),
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index e38c84172b877b..9a8146bd9350bc 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -161,7 +161,7 @@ def testValueReplaceAllUsesWithExcept():
             op1 = Operation.create("custom.op1", operands=[value])
             op2 = Operation.create("custom.op2", operands=[value])
             value2 = Operation.create("custom.op3", results=[i32]).results[0]
-            value.replace_all_uses_except(value2, [op1])
+            value.replace_all_uses_except(value2, op1)
 
     assert len(list(value.uses)) == 1
 

>From d9c3cd74ff2b6f223146c92bd2c5faa969b427f6 Mon Sep 17 00:00:00 2001
From: pez <perry at gibsonic.org>
Date: Tue, 12 Nov 2024 17:02:22 +0000
Subject: [PATCH 5/5] Revert MlirOperation PybindAdaptor for ADL

---
 mlir/lib/Bindings/Python/IRCore.cpp | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index c29b94aba9620b..821784a6007a22 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3726,8 +3726,9 @@ void mlir::python::populateIRCore(py::module &m) {
           kValueReplaceAllUsesWithDocstring)
       .def(
           "replace_all_uses_except",
-          [](MlirValue self, MlirValue with, MlirOperation exception) {
-            mlirValueReplaceAllUsesExceptWithSingle(self, with, exception);
+          [](MlirValue self, MlirValue with, PyOperation &exception) {
+            MlirOperation exceptedUser = exception.get();
+            mlirValueReplaceAllUsesExceptWithSingle(self, with, exceptedUser);
           },
           py::arg("with"), py::arg("exceptions"),
           kValueReplaceAllUsesExceptDocstring)



More information about the Mlir-commits mailing list