[Mlir-commits] [mlir] [mlir][python] Expose AsmState python side. (PR #66819)

Jacques Pienaar llvmlistbot at llvm.org
Tue Sep 19 14:09:37 PDT 2023


https://github.com/jpienaar updated https://github.com/llvm/llvm-project/pull/66819

From d7ce21ac198d587eb659da0efc03fbe4836abd4e Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Tue, 19 Sep 2023 13:59:53 -0700
Subject: [PATCH 1/2] [mlir][python] Expose AsmState python side.

This does basic plumbing. Ideally we would want a context-based approach
to avoid having to thread these states through manually, but the current
approach is already useful as is.
---
 mlir/lib/Bindings/Python/IRCore.cpp | 38 ++++++++++++++++++++++-------
 mlir/lib/Bindings/Python/IRModule.h | 25 +++++++++++++++++++
 mlir/test/python/ir/value.py        | 12 ++++++++-
 3 files changed, 65 insertions(+), 10 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index af713547cccbb27..2ab1219016006d8 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3425,19 +3425,35 @@ void mlir::python::populateIRCore(py::module &m) {
           kValueDunderStrDocstring)
       .def(
           "get_name",
-          [](PyValue &self, bool useLocalScope) {
+          [](PyValue &self, std::optional<bool> useLocalScope,
+             std::optional<std::reference_wrapper<PyAsmState>> state) {
             PyPrintAccumulator printAccum;
-            MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
-            if (useLocalScope)
-              mlirOpPrintingFlagsUseLocalScope(flags);
-            MlirAsmState state = mlirAsmStateCreateForValue(self.get(), flags);
-            mlirValuePrintAsOperand(self.get(), state, printAccum.getCallback(),
+            MlirOpPrintingFlags flags;
+            MlirAsmState valueState;
+            // Use state if provided, else create a new state.
+            if (state) {
+              valueState = state.value().get().get();
+              // Don't allow setting using local scope and state at same time.
+              if (useLocalScope)
+                throw py::value_error(
+                    "setting AsmState and local scope together not supported");
+            } else {
+              flags = mlirOpPrintingFlagsCreate();
+              if (useLocalScope.value_or(false))
+                mlirOpPrintingFlagsUseLocalScope(flags);
+              valueState = mlirAsmStateCreateForValue(self.get(), flags);
+            }
+            mlirValuePrintAsOperand(self.get(), valueState, printAccum.getCallback(),
                                     printAccum.getUserData());
-            mlirOpPrintingFlagsDestroy(flags);
-            mlirAsmStateDestroy(state);
+            // Release state if allocated locally.
+            if (!state) {
+              mlirOpPrintingFlagsDestroy(flags);
+              mlirAsmStateDestroy(valueState);
+            }
             return printAccum.join();
           },
-          py::arg("use_local_scope") = false, kGetNameAsOperand)
+          py::arg("use_local_scope") = std::nullopt,
+          py::arg("state") = std::nullopt, kGetNameAsOperand)
       .def_property_readonly(
           "type", [](PyValue &self) { return mlirValueGetType(self.get()); })
       .def(
@@ -3456,6 +3472,10 @@ void mlir::python::populateIRCore(py::module &m) {
   PyOpResult::bind(m);
   PyOpOperand::bind(m);
 
+  py::class_<PyAsmState>(m, "AsmState", py::module_local())
+      .def(py::init<PyValue &, bool>(), py::arg("value"),
+           py::arg("use_local_scope"));
+
   //----------------------------------------------------------------------------
   // Mapping of SymbolTable.
   //----------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index d1911730c1ede03..23338f7fdb38add 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -748,6 +748,31 @@ class PyRegion {
   MlirRegion region;
 };
 
+/// Owning wrapper around an MlirAsmState and the OpPrintingFlags it uses.
+class PyAsmState {
+public:
+  PyAsmState(MlirValue value, bool useLocalScope) {
+    flags = mlirOpPrintingFlagsCreate();
+    // The OpPrintingFlags are not exposed Python side, create locally and
+    // associate lifetime with the state.
+    if (useLocalScope)
+      mlirOpPrintingFlagsUseLocalScope(flags);
+    state = mlirAsmStateCreateForValue(value, flags);
+  }
+  ~PyAsmState() {
+    mlirAsmStateDestroy(state);
+    mlirOpPrintingFlagsDestroy(flags);
+  }
+  // Non-copyable: a copy would destroy the same C handles twice.
+  PyAsmState(const PyAsmState &other) = delete;
+  PyAsmState &operator=(const PyAsmState &other) = delete;
+  MlirAsmState get() { return state; }
+
+private:
+  MlirAsmState state;
+  MlirOpPrintingFlags flags;
+};
+
 /// Wrapper around an MlirBlock.
 /// Blocks are managed completely by their containing operation. Unlike the
 /// C++ API, the python API does not support detached blocks.
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 46a50ac5291e8d9..2a47c8d820eaf4f 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -1,4 +1,4 @@
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: %PYTHON %s | FileCheck %s --enable-var-scope=false
 
 import gc
 from mlir.ir import *
@@ -199,6 +199,16 @@ def testValuePrintAsOperand():
         # CHECK: %[[VAL4]]
         print(value4.get_name())
 
+        print("With AsmState")
+        # CHECK-LABEL: With AsmState
+        state = AsmState(value3, use_local_scope=True)
+        # CHECK: %0
+        print(value3.get_name(state=state))
+        # CHECK: %1
+        print(value4.get_name(state=state))
+
+        print("With use_local_scope")
+        # CHECK-LABEL: With use_local_scope
         # CHECK: %0
         print(value3.get_name(use_local_scope=True))
         # CHECK: %1

From 9de8ec41675aafea37a44f5847064e0eaa4b658f Mon Sep 17 00:00:00 2001
From: Jacques Pienaar <jpienaar at google.com>
Date: Tue, 19 Sep 2023 14:09:18 -0700
Subject: [PATCH 2/2] Default use_local_scope to False in the AsmState constructor

---
 mlir/lib/Bindings/Python/IRCore.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 2ab1219016006d8..cafcdd19ad9a0c1 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3474,7 +3474,7 @@ void mlir::python::populateIRCore(py::module &m) {
 
   py::class_<PyAsmState>(m, "AsmState", py::module_local())
       .def(py::init<PyValue &, bool>(), py::arg("value"),
-           py::arg("use_local_scope"));
+           py::arg("use_local_scope") = false);
 
   //----------------------------------------------------------------------------
   // Mapping of SymbolTable.



More information about the Mlir-commits mailing list