[Mlir-commits] [mlir] 25b8433 - add set_type to ir.Value
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Jul 26 05:32:18 PDT 2023
Author: max
Date: 2023-07-26T07:28:21-05:00
New Revision: 25b8433b759791c49f3abec8a9971066cdc2975c
URL: https://github.com/llvm/llvm-project/commit/25b8433b759791c49f3abec8a9971066cdc2975c
DIFF: https://github.com/llvm/llvm-project/commit/25b8433b759791c49f3abec8a9971066cdc2975c.diff
LOG: add set_type to ir.Value
Differential Revision: https://reviews.llvm.org/D156289
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/test/python/ir/value.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 5312db09179451..b5c6a3094bc67d 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -801,6 +801,9 @@ MLIR_CAPI_EXPORTED intptr_t mlirOpResultGetResultNumber(MlirValue value);
/// Returns the type of the value.
MLIR_CAPI_EXPORTED MlirType mlirValueGetType(MlirValue value);
+/// Sets the type of the value in place. This wrapper performs no verification
+/// of its own; callers must ensure the IR remains valid with the new type
+/// (e.g. that the owning operation and all users of the value accept it).
+MLIR_CAPI_EXPORTED void mlirValueSetType(MlirValue value, MlirType type);
+
/// Prints the value to the standard error stream.
MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value);
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 971d2819ade44b..6b9de9a8c76cea 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3431,6 +3431,12 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("use_local_scope") = false, kGetNameAsOperand)
.def_property_readonly(
"type", [](PyValue &self) { return mlirValueGetType(self.get()); })
+ .def(
+ "set_type",
+ [](PyValue &self, const PyType &type) {
+ return mlirValueSetType(self.get(), type);
+ },
+ py::arg("type"))
.def(
"replace_all_uses_with",
[](PyValue &self, PyValue &with) {
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 5231fe5f94d8a0..ccdae142499856 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -823,6 +823,10 @@ MlirType mlirValueGetType(MlirValue value) {
return wrap(unwrap(value).getType());
}
+void mlirValueSetType(MlirValue value, MlirType type) {
+  // Convert the C handles to their C++ counterparts and update the value's
+  // type through the C++ Value API; no extra checking happens here.
+  auto cppValue = unwrap(value);
+  cppValue.setType(unwrap(type));
+}
+
void mlirValueDump(MlirValue value) { unwrap(value).dump(); }
void mlirValuePrint(MlirValue value, MlirStringCallback callback,
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 8a2ada1f78f1c2..46a50ac5291e8d 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -238,3 +238,25 @@ def testValuePrintAsOperand():
value2.owner.detach_from_parent()
# CHECK: %0
print(value2.get_name())
+
+
+# CHECK-LABEL: TEST: testValueSetType
+ at run
+def testValueSetType():
+ ctx = Context()
+ ctx.allow_unregistered_dialects = True
+ with Location.unknown(ctx):
+ i32 = IntegerType.get_signless(32)
+ i64 = IntegerType.get_signless(64)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ value = Operation.create("custom.op1", results=[i32]).results[0]
+ # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
+ print(value)
+
+ value.set_type(i64)
+ # CHECK: Value(%[[VAL1]] = "custom.op1"() : () -> i64)
+ print(value)
+
+ # CHECK: %[[VAL1]] = "custom.op1"() : () -> i64
+ print(value.owner)
More information about the Mlir-commits mailing list