[Mlir-commits] [mlir] 98fbd9d - [MLIR][python bindings] implement `replace_all_uses_with` on `PyValue`
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Apr 24 08:09:02 PDT 2023
Author: max
Date: 2023-04-24T10:08:43-05:00
New Revision: 98fbd9d3f940928e68566c2d3794d89e361fbf88
URL: https://github.com/llvm/llvm-project/commit/98fbd9d3f940928e68566c2d3794d89e361fbf88
DIFF: https://github.com/llvm/llvm-project/commit/98fbd9d3f940928e68566c2d3794d89e361fbf88.diff
LOG: [MLIR][python bindings] implement `replace_all_uses_with` on `PyValue`
Differential Revision: https://reviews.llvm.org/D148816
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/test/CAPI/ir.c
mlir/test/python/ir/value.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 84d226b40b71a..b45b955363f67 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -755,6 +755,12 @@ mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData);
/// operand if there are no uses.
MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value);
+/// Replaces all uses of the 'of' value with the 'with' value, updating
+/// everything in the IR that uses 'of' so that it uses 'with' instead.
+/// When this returns, 'of' has zero uses.
+MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of,
+ MlirValue with);
+
//===----------------------------------------------------------------------===//
// OpOperand API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index f3fd386779373..81c5cd2183107 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -13,11 +13,9 @@
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/BuiltinAttributes.h"
-#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Debug.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
-//#include "mlir-c/Registration.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
@@ -154,6 +152,11 @@ position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
+// Python docstring for `Value.replace_all_uses_with`.
+static const char kValueReplaceAllUsesWithDocstring[] =
+ R"(Replace all uses of this value with the given value, updating anything in
+the IR that uses 'self' to use the given value instead. When this returns,
+'self' has no remaining uses.
+)";
+
//------------------------------------------------------------------------------
// Utilities.
//------------------------------------------------------------------------------
@@ -3316,10 +3319,18 @@ void mlir::python::populateIRCore(py::module &m) {
return printAccum.join();
},
kValueDunderStrDocstring)
- .def_property_readonly("type", [](PyValue &self) {
- return PyType(self.getParentOperation()->getContext(),
- mlirValueGetType(self.get()));
- });
+ .def_property_readonly("type",
+ [](PyValue &self) {
+ return PyType(
+ self.getParentOperation()->getContext(),
+ mlirValueGetType(self.get()));
+ })
+ .def(
+ "replace_all_uses_with",
+ [](PyValue &self, PyValue &with) {
+ mlirValueReplaceAllUsesOfWith(self.get(), with.get());
+ },
+ kValueReplaceAllUsesWithDocstring);
PyBlockArgument::bind(m);
PyOpResult::bind(m);
PyOpOperand::bind(m);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 051559acd440c..0bbcb3083062f 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -751,6 +751,10 @@ MlirOpOperand mlirValueGetFirstUse(MlirValue value) {
return wrap(opOperand);
}
+// Parameter names mirror the declaration in mlir-c/IR.h: 'of' is the value
+// whose uses are rewritten, 'with' is the value that takes them over.
+void mlirValueReplaceAllUsesOfWith(MlirValue of, MlirValue with) {
+ unwrap(of).replaceAllUsesWith(unwrap(with));
+}
+
//===----------------------------------------------------------------------===//
// OpOperand API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 5f205c4ff5e2b..b816936c9a139 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -28,6 +28,25 @@
#include <stdlib.h>
#include <string.h>
+/// Creates a detached `arith.constant` operation whose `value` attribute is
+/// parsed from "<literalStr> : <typeStr>" and whose single result has type
+/// `typeStr`; returns that result. The defining operation is not inserted
+/// anywhere: callers either append it to a block or destroy it through
+/// mlirOpResultGetOwner().
+MlirValue makeConstantLiteral(MlirContext ctx, const char *literalStr,
+ const char *typeStr) {
+ MlirLocation loc = mlirLocationUnknownGet(ctx);
+ // Size the attribute string exactly: a fixed-size stack buffer filled with
+ // sprintf overflows for sufficiently long literal/type strings.
+ int attrLen = snprintf(NULL, 0, "%s : %s", literalStr, typeStr);
+ char *attrStr = malloc((size_t)attrLen + 1);
+ snprintf(attrStr, (size_t)attrLen + 1, "%s : %s", literalStr, typeStr);
+ MlirAttribute literal =
+ mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString(attrStr));
+ MlirNamedAttribute valueAttr = mlirNamedAttributeGet(
+ mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")), literal);
+ MlirOperationState constState = mlirOperationStateGet(
+ mlirStringRefCreateFromCString("arith.constant"), loc);
+ MlirType type =
+ mlirTypeParseGet(ctx, mlirStringRefCreateFromCString(typeStr));
+ mlirOperationStateAddResults(&constState, 1, &type);
+ mlirOperationStateAddAttributes(&constState, 1, &valueAttr);
+ MlirOperation constOp = mlirOperationCreate(&constState);
+ // Released at the same point the former stack buffer went out of scope.
+ free(attrStr);
+ return mlirOperationGetResult(constOp, 0);
+}
+
static void registerAllUpstreamDialects(MlirContext ctx) {
MlirDialectRegistry registry = mlirDialectRegistryCreate();
mlirRegisterAllDialects(registry);
@@ -115,26 +134,17 @@ MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
MlirOperation func = mlirOperationCreate(&funcState);
mlirBlockInsertOwnedOperation(moduleBody, 0, func);
- MlirType indexType =
- mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("index"));
- MlirAttribute indexZeroLiteral =
- mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
- MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
- mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
- indexZeroLiteral);
- MlirOperationState constZeroState = mlirOperationStateGet(
- mlirStringRefCreateFromCString("arith.constant"), location);
- mlirOperationStateAddResults(&constZeroState, 1, &indexType);
- mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
- MlirOperation constZero = mlirOperationCreate(&constZeroState);
+ MlirValue constZeroValue = makeConstantLiteral(ctx, "0", "index");
+ MlirOperation constZero = mlirOpResultGetOwner(constZeroValue);
mlirBlockAppendOwnedOperation(funcBody, constZero);
MlirValue funcArg0 = mlirBlockGetArgument(funcBody, 0);
- MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
MlirValue dimOperands[] = {funcArg0, constZeroValue};
MlirOperationState dimState = mlirOperationStateGet(
mlirStringRefCreateFromCString("memref.dim"), location);
mlirOperationStateAddOperands(&dimState, 2, dimOperands);
+ MlirType indexType =
+ mlirTypeParseGet(ctx, mlirStringRefCreateFromCString("index"));
mlirOperationStateAddResults(&dimState, 1, &indexType);
MlirOperation dim = mlirOperationCreate(&dimState);
mlirBlockAppendOwnedOperation(funcBody, dim);
@@ -153,11 +163,11 @@ MlirModule makeAndDumpAdd(MlirContext ctx, MlirLocation location) {
mlirStringRefCreateFromCString("arith.constant"), location);
mlirOperationStateAddResults(&constOneState, 1, &indexType);
mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
- MlirOperation constOne = mlirOperationCreate(&constOneState);
+ MlirValue constOneValue = makeConstantLiteral(ctx, "1", "index");
+ MlirOperation constOne = mlirOpResultGetOwner(constOneValue);
mlirBlockAppendOwnedOperation(funcBody, constOne);
MlirValue dimValue = mlirOperationGetResult(dim, 0);
- MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
MlirValue loopOperands[] = {constZeroValue, dimValue, constOneValue};
MlirOperationState loopState = mlirOperationStateGet(
mlirStringRefCreateFromCString("scf.for"), location);
@@ -820,11 +830,6 @@ static int printBuiltinTypes(MlirContext ctx) {
return 0;
}
-void callbackSetFixedLengthString(const char *data, intptr_t len,
- void *userData) {
- strncpy(userData, data, len);
-}
-
bool stringIsEqual(const char *lhs, MlirStringRef rhs) {
if (strlen(lhs) != rhs.length) {
return false;
@@ -1794,32 +1799,10 @@ int testOperands(void) {
mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("arith"));
mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("test"));
MlirLocation loc = mlirLocationUnknownGet(ctx);
- MlirType indexType = mlirIndexTypeGet(ctx);
// Create some constants to use as operands.
- MlirAttribute indexZeroLiteral =
- mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
- MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
- mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
- indexZeroLiteral);
- MlirOperationState constZeroState = mlirOperationStateGet(
- mlirStringRefCreateFromCString("arith.constant"), loc);
- mlirOperationStateAddResults(&constZeroState, 1, &indexType);
- mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
- MlirOperation constZero = mlirOperationCreate(&constZeroState);
- MlirValue constZeroValue = mlirOperationGetResult(constZero, 0);
-
- MlirAttribute indexOneLiteral =
- mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
- MlirNamedAttribute indexOneValueAttr = mlirNamedAttributeGet(
- mlirIdentifierGet(ctx, mlirStringRefCreateFromCString("value")),
- indexOneLiteral);
- MlirOperationState constOneState = mlirOperationStateGet(
- mlirStringRefCreateFromCString("arith.constant"), loc);
- mlirOperationStateAddResults(&constOneState, 1, &indexType);
- mlirOperationStateAddAttributes(&constOneState, 1, &indexOneValueAttr);
- MlirOperation constOne = mlirOperationCreate(&constOneState);
- MlirValue constOneValue = mlirOperationGetResult(constOne, 0);
+ MlirValue constZeroValue = makeConstantLiteral(ctx, "0", "index");
+ MlirValue constOneValue = makeConstantLiteral(ctx, "1", "index");
// Create the operation under test.
mlirContextSetAllowUnregisteredDialects(ctx, true);
@@ -1873,9 +1856,49 @@ int testOperands(void) {
return 3;
}
+ MlirOperationState op2State =
+ mlirOperationStateGet(mlirStringRefCreateFromCString("dummy.op2"), loc);
+ MlirValue initialOperands2[] = {constOneValue};
+ mlirOperationStateAddOperands(&op2State, 1, initialOperands2);
+ // Keep a handle: op2 is not inserted into any block, so the test owns it and
+ // must destroy it explicitly.
+ MlirOperation op2 = mlirOperationCreate(&op2State);
+
+ MlirOpOperand use3 = mlirValueGetFirstUse(constOneValue);
+ fprintf(stderr, "First use owner: ");
+ mlirOperationPrint(mlirOpOperandGetOwner(use3), printToStderr, NULL);
+ fprintf(stderr, "\n");
+ // CHECK: First use owner: "dummy.op2"
+
+ use3 = mlirOpOperandGetNextUse(mlirValueGetFirstUse(constOneValue));
+ fprintf(stderr, "Second use owner: ");
+ mlirOperationPrint(mlirOpOperandGetOwner(use3), printToStderr, NULL);
+ fprintf(stderr, "\n");
+ // CHECK: Second use owner: "dummy.op"
+
+ MlirValue constTwoValue = makeConstantLiteral(ctx, "2", "index");
+ mlirValueReplaceAllUsesOfWith(constOneValue, constTwoValue);
+
+ use3 = mlirValueGetFirstUse(constOneValue);
+ if (!mlirOpOperandIsNull(use3)) {
+ fprintf(stderr, "ERROR: Use should be null\n");
+ return 4;
+ }
+
+ MlirOpOperand use4 = mlirValueGetFirstUse(constTwoValue);
+ fprintf(stderr, "First replacement use owner: ");
+ mlirOperationPrint(mlirOpOperandGetOwner(use4), printToStderr, NULL);
+ fprintf(stderr, "\n");
+ // CHECK: First replacement use owner: "dummy.op"
+
+ use4 = mlirOpOperandGetNextUse(mlirValueGetFirstUse(constTwoValue));
+ fprintf(stderr, "Second replacement use owner: ");
+ mlirOperationPrint(mlirOpOperandGetOwner(use4), printToStderr, NULL);
+ fprintf(stderr, "\n");
+ // CHECK: Second replacement use owner: "dummy.op2"
+
mlirOperationDestroy(op);
+ // op2 still uses constTwoValue after the replacement; destroy it before the
+ // constant operations so no constant is destroyed while it still has uses.
+ mlirOperationDestroy(op2);
- mlirOperationDestroy(constZero);
- mlirOperationDestroy(constOne);
+ mlirOperationDestroy(mlirOpResultGetOwner(constZeroValue));
+ mlirOperationDestroy(mlirOpResultGetOwner(constOneValue));
+ mlirOperationDestroy(mlirOpResultGetOwner(constTwoValue));
mlirContextDestroy(ctx);
return 0;
@@ -1891,18 +1914,10 @@ int testClone(void) {
mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func"));
MlirLocation loc = mlirLocationUnknownGet(ctx);
- MlirType indexType = mlirIndexTypeGet(ctx);
MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");
- MlirAttribute indexZeroLiteral =
- mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
- MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
- mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
- MlirOperationState constZeroState = mlirOperationStateGet(
- mlirStringRefCreateFromCString("arith.constant"), loc);
- mlirOperationStateAddResults(&constZeroState, 1, &indexType);
- mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
- MlirOperation constZero = mlirOperationCreate(&constZeroState);
+ MlirValue constZeroValue = makeConstantLiteral(ctx, "0", "index");
+ MlirOperation constZero = mlirOpResultGetOwner(constZeroValue);
MlirAttribute indexOneLiteral =
mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("1 : index"));
@@ -1980,19 +1995,10 @@ int testTypeID(MlirContext ctx) {
}
MlirLocation loc = mlirLocationUnknownGet(ctx);
- MlirType indexType = mlirIndexTypeGet(ctx);
- MlirStringRef valueStringRef = mlirStringRefCreateFromCString("value");
// Create a registered operation, which should have a type id.
- MlirAttribute indexZeroLiteral =
- mlirAttributeParseGet(ctx, mlirStringRefCreateFromCString("0 : index"));
- MlirNamedAttribute indexZeroValueAttr = mlirNamedAttributeGet(
- mlirIdentifierGet(ctx, valueStringRef), indexZeroLiteral);
- MlirOperationState constZeroState = mlirOperationStateGet(
- mlirStringRefCreateFromCString("arith.constant"), loc);
- mlirOperationStateAddResults(&constZeroState, 1, &indexType);
- mlirOperationStateAddAttributes(&constZeroState, 1, &indexZeroValueAttr);
- MlirOperation constZero = mlirOperationCreate(&constZeroState);
+ MlirValue constZeroValue = makeConstantLiteral(ctx, "0", "index");
+ MlirOperation constZero = mlirOpResultGetOwner(constZeroValue);
if (!mlirOperationVerify(constZero)) {
fprintf(stderr, "ERROR: Expected operation to verify correctly\n");
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 98f55de41e150..90fe64ac1762a 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -111,3 +111,29 @@ def testValueUses():
assert use.owner in [op1, op2]
print(f"Use owner: {use.owner}")
print(f"Use operand_number: {use.operand_number}")
+
+# CHECK-LABEL: TEST: testValueReplaceAllUsesWith
+ at run
+def testValueReplaceAllUsesWith():
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ i32 = IntegerType.get_signless(32)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ value = Operation.create("custom.op1", results=[i32]).results[0]
+ op1 = Operation.create("custom.op2", operands=[value])
+ op2 = Operation.create("custom.op2", operands=[value])
+ value2 = Operation.create("custom.op3", results=[i32]).results[0]
+ value.replace_all_uses_with(value2)
+
+ assert len(list(value.uses)) == 0
+
+ # CHECK: Use owner: "custom.op2"
+ # CHECK: Use operand_number: 0
+ # CHECK: Use owner: "custom.op2"
+ # CHECK: Use operand_number: 0
+ for use in value2.uses:
+ assert use.owner in [op1, op2]
+ print(f"Use owner: {use.owner}")
+ print(f"Use operand_number: {use.operand_number}")
More information about the Mlir-commits
mailing list