[Mlir-commits] [mlir] 25b8433 - add set_type to ir.Value

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Jul 26 05:32:18 PDT 2023


Author: max
Date: 2023-07-26T07:28:21-05:00
New Revision: 25b8433b759791c49f3abec8a9971066cdc2975c

URL: https://github.com/llvm/llvm-project/commit/25b8433b759791c49f3abec8a9971066cdc2975c
DIFF: https://github.com/llvm/llvm-project/commit/25b8433b759791c49f3abec8a9971066cdc2975c.diff

LOG: add set_type to ir.Value

Differential Revision: https://reviews.llvm.org/D156289

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/python/ir/value.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 5312db09179451..b5c6a3094bc67d 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -801,6 +801,9 @@ MLIR_CAPI_EXPORTED intptr_t mlirOpResultGetResultNumber(MlirValue value);
 /// Returns the type of the value.
 MLIR_CAPI_EXPORTED MlirType mlirValueGetType(MlirValue value);
 
+/// Set the type of the value.
+MLIR_CAPI_EXPORTED void mlirValueSetType(MlirValue value, MlirType type);
+
 /// Prints the value to the standard error stream.
 MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value);
 

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 971d2819ade44b..6b9de9a8c76cea 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3431,6 +3431,12 @@ void mlir::python::populateIRCore(py::module &m) {
           py::arg("use_local_scope") = false, kGetNameAsOperand)
       .def_property_readonly(
           "type", [](PyValue &self) { return mlirValueGetType(self.get()); })
+      .def(
+          "set_type",
+          [](PyValue &self, const PyType &type) {
+            return mlirValueSetType(self.get(), type);
+          },
+          py::arg("type"))
       .def(
           "replace_all_uses_with",
           [](PyValue &self, PyValue &with) {

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 5231fe5f94d8a0..ccdae142499856 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -823,6 +823,10 @@ MlirType mlirValueGetType(MlirValue value) {
   return wrap(unwrap(value).getType());
 }
 
+void mlirValueSetType(MlirValue value, MlirType type) {
+  unwrap(value).setType(unwrap(type));
+}
+
 void mlirValueDump(MlirValue value) { unwrap(value).dump(); }
 
 void mlirValuePrint(MlirValue value, MlirStringCallback callback,

diff  --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 8a2ada1f78f1c2..46a50ac5291e8d 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -238,3 +238,25 @@ def testValuePrintAsOperand():
         value2.owner.detach_from_parent()
         # CHECK: %0
         print(value2.get_name())
+
+
+# CHECK-LABEL: TEST: testValueSetType
+@run
+def testValueSetType():
+    ctx = Context()
+    ctx.allow_unregistered_dialects = True
+    with Location.unknown(ctx):
+        i32 = IntegerType.get_signless(32)
+        i64 = IntegerType.get_signless(64)
+        module = Module.create()
+        with InsertionPoint(module.body):
+            value = Operation.create("custom.op1", results=[i32]).results[0]
+            # CHECK: Value(%[[VAL1:.*]] = "custom.op1"() : () -> i32)
+            print(value)
+
+            value.set_type(i64)
+            # CHECK: Value(%[[VAL1]] = "custom.op1"() : () -> i64)
+            print(value)
+
+            # CHECK: %[[VAL1]] = "custom.op1"() : () -> i64
+            print(value.owner)


        


More information about the Mlir-commits mailing list