[Mlir-commits] [mlir] 21df325 - [mlir, python] Expose replaceAllUsesExcept to Python bindings (#115850)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Nov 19 16:01:02 PST 2024


Author: Perry Gibson
Date: 2024-11-19T19:00:57-05:00
New Revision: 21df32511b558b2c1e24fe23f677fffaad4da333

URL: https://github.com/llvm/llvm-project/commit/21df32511b558b2c1e24fe23f677fffaad4da333
DIFF: https://github.com/llvm/llvm-project/commit/21df32511b558b2c1e24fe23f677fffaad4da333.diff

LOG: [mlir,python] Expose replaceAllUsesExcept to Python bindings (#115850)

Problem originally described in [the forums
here](https://discourse.llvm.org/t/mlir-python-expose-replaceallusesexcept/83068/1).

In the MLIR Python bindings, the `Value` method
[`replaceAllUsesWith`](https://mlir.llvm.org/doxygen/classmlir_1_1Value.html#ac56b0fdb6246bcf7fa1805ba0eb71aa2)
is exposed as `replace_all_uses_with`, e.g.,

```python
orig_value.replace_all_uses_with(
    new_value               
)
```

However, in my use case I am splitting a block into multiple blocks, and
thus want to exclude certain Operations from having their Values
replaced (since I want them to diverge).

The C++ `Value` class provides
[`replaceAllUsesExcept`](https://mlir.llvm.org/doxygen/classmlir_1_1Value.html#a9ec8d5c61f8a6aada4062f609372cce4),
which accepts the Operations whose uses should be skipped.

This is not currently exposed in the Python bindings; this PR fixes
that by adding `replace_all_uses_except`, which accepts either a single
Operation or a list of Operations.

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/python/ir/value.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index b8a6f08b159817..0a515bbea3b504 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -956,6 +956,15 @@ MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value);
 MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of,
                                                       MlirValue with);
 
+/// Replace all uses of 'of' value with 'with' value, updating anything in the
+/// IR that uses 'of' to use 'with' instead, except if the user is listed in
+/// 'exceptions'. The 'exceptions' parameter is an array of MlirOperation
+/// pointers with a length of 'numExceptions'.
+MLIR_CAPI_EXPORTED void
+mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with,
+                              intptr_t numExceptions,
+                              MlirOperation *exceptions);
+
 //===----------------------------------------------------------------------===//
 // OpOperand API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 3562ff38201dc3..3e96f8c60ba7cd 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -178,6 +178,12 @@ static const char kValueReplaceAllUsesWithDocstring[] =
 the IR that uses 'self' to use the other value instead.
 )";
 
+static const char kValueReplaceAllUsesExceptDocstring[] =
+    R"("Replace all uses of this value with the 'with' value, except for those
+in 'exceptions'. 'exceptions' can be either a single operation or a list of
+operations.
+)";
+
 //------------------------------------------------------------------------------
 // Utilities.
 //------------------------------------------------------------------------------
@@ -3718,6 +3724,29 @@ void mlir::python::populateIRCore(py::module &m) {
             mlirValueReplaceAllUsesOfWith(self.get(), with.get());
           },
           kValueReplaceAllUsesWithDocstring)
+      .def(
+          "replace_all_uses_except",
+          [](MlirValue self, MlirValue with, PyOperation &exception) {
+            MlirOperation exceptedUser = exception.get();
+            mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser);
+          },
+          py::arg("with"), py::arg("exceptions"),
+          kValueReplaceAllUsesExceptDocstring)
+      .def(
+          "replace_all_uses_except",
+          [](MlirValue self, MlirValue with, py::list exceptions) {
+            // Convert Python list to a SmallVector of MlirOperations
+            llvm::SmallVector<MlirOperation> exceptionOps;
+            for (py::handle exception : exceptions) {
+              exceptionOps.push_back(exception.cast<PyOperation &>().get());
+            }
+
+            mlirValueReplaceAllUsesExcept(
+                self, with, static_cast<intptr_t>(exceptionOps.size()),
+                exceptionOps.data());
+          },
+          py::arg("with"), py::arg("exceptions"),
+          kValueReplaceAllUsesExceptDocstring)
       .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
            [](PyValue &self) { return self.maybeDownCast(); });
   PyBlockArgument::bind(m);

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index e7e6b11c81b9d3..24dc8854048532 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/Visitors.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Parser/Parser.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/ThreadPool.h"
 
 #include <cstddef>
@@ -1009,6 +1010,20 @@ void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) {
   unwrap(oldValue).replaceAllUsesWith(unwrap(newValue));
 }
 
+void mlirValueReplaceAllUsesExcept(MlirValue oldValue, MlirValue newValue,
+                                   intptr_t numExceptions,
+                                   MlirOperation *exceptions) {
+  Value oldValueCpp = unwrap(oldValue);
+  Value newValueCpp = unwrap(newValue);
+
+  llvm::SmallPtrSet<mlir::Operation *, 4> exceptionSet;
+  for (intptr_t i = 0; i < numExceptions; ++i) {
+    exceptionSet.insert(unwrap(exceptions[i]));
+  }
+
+  oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptionSet);
+}
+
 //===----------------------------------------------------------------------===//
 // OpOperand API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 50b0e8403a7f21..9a8146bd9350bc 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -148,6 +148,77 @@ def testValueReplaceAllUsesWith():
         print(f"Use operand_number: {use.operand_number}")
 
 
+# CHECK-LABEL: TEST: testValueReplaceAllUsesWithExcept
+@run
+def testValueReplaceAllUsesWithExcept():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            value = Operation.create("custom.op1", results=[i32]).results[0]
+            op1 = Operation.create("custom.op1", operands=[value])
+            op2 = Operation.create("custom.op2", operands=[value])
+            value2 = Operation.create("custom.op3", results=[i32]).results[0]
+            value.replace_all_uses_except(value2, op1)
+
+    assert len(list(value.uses)) == 1
+
+    # CHECK: Use owner: "custom.op2"
+    # CHECK: Use operand_number: 0
+    for use in value2.uses:
+        assert use.owner in [op2]
+        print(f"Use owner: {use.owner}")
+        print(f"Use operand_number: {use.operand_number}")
+
+    # CHECK: Use owner: "custom.op1"
+    # CHECK: Use operand_number: 0
+    for use in value.uses:
+        assert use.owner in [op1]
+        print(f"Use owner: {use.owner}")
+        print(f"Use operand_number: {use.operand_number}")
+
+
+# CHECK-LABEL: TEST: testValueReplaceAllUsesWithMultipleExceptions
+@run
+def testValueReplaceAllUsesWithMultipleExceptions():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            value = Operation.create("custom.op1", results=[i32]).results[0]
+            op1 = Operation.create("custom.op1", operands=[value])
+            op2 = Operation.create("custom.op2", operands=[value])
+            op3 = Operation.create("custom.op3", operands=[value])
+            value2 = Operation.create("custom.op4", results=[i32]).results[0]
+
+            # Replace all uses of `value` with `value2`, except for `op1` and `op2`.
+            value.replace_all_uses_except(value2, [op1, op2])
+
+    # After replacement, only `op3` should use `value2`, while `op1` and `op2` should still use `value`.
+    assert len(list(value.uses)) == 2
+    assert len(list(value2.uses)) == 1
+
+    # CHECK: Use owner: "custom.op3"
+    # CHECK: Use operand_number: 0
+    for use in value2.uses:
+        assert use.owner in [op3]
+        print(f"Use owner: {use.owner}")
+        print(f"Use operand_number: {use.operand_number}")
+
+    # CHECK: Use owner: "custom.op2"
+    # CHECK: Use operand_number: 0
+    # CHECK: Use owner: "custom.op1"
+    # CHECK: Use operand_number: 0
+    for use in value.uses:
+        assert use.owner in [op1, op2]
+        print(f"Use owner: {use.owner}")
+        print(f"Use operand_number: {use.operand_number}")
+
+
 # CHECK-LABEL: TEST: testValuePrintAsOperand
 @run
 def testValuePrintAsOperand():


        


More information about the Mlir-commits mailing list