[Mlir-commits] [mlir] [mlir] Expose replaceAllUsesExcept to Python bindings (PR #115850)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Nov 12 02:46:00 PST 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Perry Gibson (Wheest)
<details>
<summary>Changes</summary>
This problem was originally described in [this forum thread](https://discourse.llvm.org/t/mlir-python-expose-replaceallusesexcept/83068/1).
The MLIR Python bindings expose the `Value` method [`replaceAllUsesWith`](https://mlir.llvm.org/doxygen/classmlir_1_1Value.html#ac56b0fdb6246bcf7fa1805ba0eb71aa2), e.g.,
```python
orig_value.replace_all_uses_with(
new_value
)
```
However, in my use case I am splitting a block into multiple blocks, so I want to exclude certain Operations from having their Values replaced (since I want them to diverge).
`Value` also provides [`replaceAllUsesExcept`](https://mlir.llvm.org/doxygen/classmlir_1_1Value.html#a9ec8d5c61f8a6aada4062f609372cce4), which takes the Operations whose uses should be skipped.
This method is not currently exposed in the Python bindings; this PR fixes that by adding `replace_all_uses_except`, which accepts either a single Operation or a list of Operations.
---
Full diff: https://github.com/llvm/llvm-project/pull/115850.diff
4 Files Affected:
- (modified) mlir/include/mlir-c/IR.h (+16)
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+32)
- (modified) mlir/lib/CAPI/IR/IR.cpp (+26)
- (modified) mlir/test/python/ir/value.py (+71)
``````````diff
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index b8a6f08b159817..012353993c341a 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -956,6 +956,22 @@ MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value);
MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of,
MlirValue with);
+/// Replace all uses of 'of' value with 'with' value, updating anything in the
+/// IR that uses 'of' to use 'with' instead, except if the user is listed in
+/// 'exceptions'. 'exceptions' points to an array of 'numExceptions'
+/// MlirOperation values (not pointers to operations); only the first
+/// 'numExceptions' elements are read, so it is not dereferenced when
+/// 'numExceptions' is 0.
+MLIR_CAPI_EXPORTED void
+mlirValueReplaceAllUsesExceptWithSet(MlirValue of, MlirValue with,
+ MlirOperation *exceptions,
+ intptr_t numExceptions);
+
+/// Replace all uses of 'of' value with 'with' value, updating anything in the
+/// IR that uses 'of' to use 'with' instead, except for uses owned by the
+/// single operation 'exceptedUser'.
+MLIR_CAPI_EXPORTED void
+mlirValueReplaceAllUsesExceptWithSingle(MlirValue of, MlirValue with,
+ MlirOperation exceptedUser);
+
//===----------------------------------------------------------------------===//
// OpOperand API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 3562ff38201dc3..4bddcab8ccda6d 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -178,6 +178,12 @@ static const char kValueReplaceAllUsesWithDocstring[] =
the IR that uses 'self' to use the other value instead.
)";
+static const char kValueReplaceAllUsesExceptDocstring[] =
+    R"(Replace all uses of this value with the 'with' value, except for those
+in 'exceptions'. 'exceptions' can be either a single operation or a list of
+operations.
+)";
+
//------------------------------------------------------------------------------
// Utilities.
//------------------------------------------------------------------------------
@@ -3718,6 +3724,32 @@ void mlir::python::populateIRCore(py::module &m) {
mlirValueReplaceAllUsesOfWith(self.get(), with.get());
},
kValueReplaceAllUsesWithDocstring)
+      .def(
+          "replace_all_uses_except",
+          [](PyValue &self, PyValue &with, py::object exceptions) {
+            MlirValue selfValue = self.get();
+            MlirValue withValue = with.get();
+
+            // A list or tuple names several excepted users; anything else is
+            // treated as a single Operation.
+            if (py::isinstance<py::list>(exceptions) ||
+                py::isinstance<py::tuple>(exceptions)) {
+              std::vector<MlirOperation> exceptionOps;
+              exceptionOps.reserve(py::len(exceptions));
+              for (py::handle exception : exceptions)
+                exceptionOps.push_back(exception.cast<PyOperation &>().get());
+              // data() may be null for an empty sequence; the C API only
+              // reads the first 'numExceptions' elements.
+              mlirValueReplaceAllUsesExceptWithSet(
+                  selfValue, withValue, exceptionOps.data(),
+                  static_cast<intptr_t>(exceptionOps.size()));
+            } else {
+              // pybind11 throws py::cast_error here if 'exceptions' is not
+              // an Operation.
+              MlirOperation exceptedUser =
+                  exceptions.cast<PyOperation &>().get();
+              mlirValueReplaceAllUsesExceptWithSingle(selfValue, withValue,
+                                                      exceptedUser);
+            }
+          },
+          // NOTE(review): 'with' is a Python keyword, so this argument cannot
+          // be passed by keyword with normal call syntax; consider 'with_'.
+          py::arg("with"), py::arg("exceptions"),
+          kValueReplaceAllUsesExceptDocstring)
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
[](PyValue &self) { return self.maybeDownCast(); });
PyBlockArgument::bind(m);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index e7e6b11c81b9d3..5fd5f0a8f36457 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -28,6 +28,7 @@
#include "mlir/IR/Visitors.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Parser/Parser.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/ThreadPool.h"
#include <cstddef>
@@ -1009,6 +1010,31 @@ void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) {
unwrap(oldValue).replaceAllUsesWith(unwrap(newValue));
}
+void mlirValueReplaceAllUsesExceptWithSet(MlirValue oldValue,
+                                          MlirValue newValue,
+                                          MlirOperation *exceptions,
+                                          intptr_t numExceptions) {
+  // The C++ API takes the excepted users as a set of Operation pointers, so
+  // unwrap each of the 'numExceptions' handles into one.
+  llvm::SmallPtrSet<mlir::Operation *, 4> exceptedUsers;
+  for (intptr_t idx = 0; idx < numExceptions; ++idx)
+    exceptedUsers.insert(unwrap(exceptions[idx]));
+
+  unwrap(oldValue).replaceAllUsesExcept(unwrap(newValue), exceptedUsers);
+}
+
+void mlirValueReplaceAllUsesExceptWithSingle(MlirValue oldValue,
+                                             MlirValue newValue,
+                                             MlirOperation exceptedUser) {
+  // Like mlirValueReplaceAllUsesOfWith, but uses owned by 'exceptedUser'
+  // keep referring to 'oldValue'.
+  unwrap(oldValue).replaceAllUsesExcept(unwrap(newValue),
+                                        unwrap(exceptedUser));
+}
+
//===----------------------------------------------------------------------===//
// OpOperand API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 50b0e8403a7f21..0991d71151c894 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -148,6 +148,77 @@ def testValueReplaceAllUsesWith():
print(f"Use operand_number: {use.operand_number}")
+# CHECK-LABEL: TEST: testValueReplaceAllUsesWithExcept
+@run
+def testValueReplaceAllUsesWithExcept():
+    # Exercises the single-Operation form of `replace_all_uses_except`; the
+    # list form is covered by testValueReplaceAllUsesWithMultipleExceptions.
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            value = Operation.create("custom.op1", results=[i32]).results[0]
+            op1 = Operation.create("custom.op1", operands=[value])
+            op2 = Operation.create("custom.op2", operands=[value])
+            value2 = Operation.create("custom.op3", results=[i32]).results[0]
+            value.replace_all_uses_except(value2, op1)
+
+        # Only the excepted user `op1` still uses `value`; `op2` moved over.
+        assert len(list(value.uses)) == 1
+        assert len(list(value2.uses)) == 1
+
+        # CHECK: Use owner: "custom.op2"
+        # CHECK: Use operand_number: 0
+        for use in value2.uses:
+            assert use.owner in [op2]
+            print(f"Use owner: {use.owner}")
+            print(f"Use operand_number: {use.operand_number}")
+
+        # CHECK: Use owner: "custom.op1"
+        # CHECK: Use operand_number: 0
+        for use in value.uses:
+            assert use.owner in [op1]
+            print(f"Use owner: {use.owner}")
+            print(f"Use operand_number: {use.operand_number}")
+
+
+# CHECK-LABEL: TEST: testValueReplaceAllUsesWithMultipleExceptions
+@run
+def testValueReplaceAllUsesWithMultipleExceptions():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            value = Operation.create("custom.op1", results=[i32]).results[0]
+            op1 = Operation.create("custom.op1", operands=[value])
+            op2 = Operation.create("custom.op2", operands=[value])
+            op3 = Operation.create("custom.op3", operands=[value])
+            value2 = Operation.create("custom.op4", results=[i32]).results[0]
+
+            # Replace all uses of `value` with `value2`, except for `op1` and `op2`.
+            value.replace_all_uses_except(value2, [op1, op2])
+
+        # After replacement, only `op3` should use `value2`, while `op1` and
+        # `op2` should still use `value`.
+        assert len(list(value.uses)) == 2
+        assert len(list(value2.uses)) == 1
+
+        # CHECK: Use owner: "custom.op3"
+        # CHECK: Use operand_number: 0
+        for use in value2.uses:
+            assert use.owner in [op3]
+            print(f"Use owner: {use.owner}")
+            print(f"Use operand_number: {use.operand_number}")
+
+        # Use-list iteration order need not match op creation order, so the
+        # two remaining uses are matched order-independently.
+        # CHECK-DAG: Use owner: "custom.op1"
+        # CHECK-DAG: Use owner: "custom.op2"
+        # CHECK-DAG: Use operand_number: 0
+        # CHECK-DAG: Use operand_number: 0
+        for use in value.uses:
+            assert use.owner in [op1, op2]
+            print(f"Use owner: {use.owner}")
+            print(f"Use operand_number: {use.operand_number}")
+
+
# CHECK-LABEL: TEST: testValuePrintAsOperand
@run
def testValuePrintAsOperand():
``````````
</details>
https://github.com/llvm/llvm-project/pull/115850
More information about the Mlir-commits
mailing list