[Mlir-commits] [mlir] 7545371 - [mlir][python] Expose AsmState python side. (#66819)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Sep 20 15:12:10 PDT 2023
Author: Jacques Pienaar
Date: 2023-09-20T15:12:06-07:00
New Revision: 75453714f06c86a6096e4c3e45243fb7797a6756
URL: https://github.com/llvm/llvm-project/commit/75453714f06c86a6096e4c3e45243fb7797a6756
DIFF: https://github.com/llvm/llvm-project/commit/75453714f06c86a6096e4c3e45243fb7797a6756.diff
LOG: [mlir][python] Expose AsmState python side. (#66819)
This does basic plumbing. Ideally we would want a context-based approach
to avoid having to thread these through manually, but the current
approach is already useful in this state.
Made the Value.get_name change backwards compatible, so callers can either
set `use_local_scope` or create an AsmState and pass it in via `state`.
Added:
Modified:
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/Bindings/Python/IRModule.h
mlir/test/python/ir/value.py
Removed:
################################################################################
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 504ed8f3eadb00b..c74b37a51c0df06 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -3430,19 +3430,35 @@ void mlir::python::populateIRCore(py::module &m) {
kValueDunderStrDocstring)
.def(
"get_name",
- [](PyValue &self, bool useLocalScope) {
+ [](PyValue &self, std::optional<bool> useLocalScope,
+ std::optional<std::reference_wrapper<PyAsmState>> state) {
PyPrintAccumulator printAccum;
- MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
- if (useLocalScope)
- mlirOpPrintingFlagsUseLocalScope(flags);
- MlirAsmState state = mlirAsmStateCreateForValue(self.get(), flags);
- mlirValuePrintAsOperand(self.get(), state, printAccum.getCallback(),
+ MlirOpPrintingFlags flags;
+ MlirAsmState valueState;
+ // Use state if provided, else create a new state.
+ if (state) {
+ valueState = state.value().get().get();
+ // Don't allow setting using local scope and state at same time.
+ if (useLocalScope.has_value())
+ throw py::value_error(
+ "setting AsmState and local scope together not supported");
+ } else {
+ flags = mlirOpPrintingFlagsCreate();
+ if (useLocalScope.value_or(false))
+ mlirOpPrintingFlagsUseLocalScope(flags);
+ valueState = mlirAsmStateCreateForValue(self.get(), flags);
+ }
+ mlirValuePrintAsOperand(self.get(), valueState, printAccum.getCallback(),
printAccum.getUserData());
- mlirOpPrintingFlagsDestroy(flags);
- mlirAsmStateDestroy(state);
+ // Release state if allocated locally.
+ if (!state) {
+ mlirOpPrintingFlagsDestroy(flags);
+ mlirAsmStateDestroy(valueState);
+ }
return printAccum.join();
},
- py::arg("use_local_scope") = false, kGetNameAsOperand)
+ py::arg("use_local_scope") = std::nullopt,
+ py::arg("state") = std::nullopt, kGetNameAsOperand)
.def_property_readonly(
"type", [](PyValue &self) { return mlirValueGetType(self.get()); })
.def(
@@ -3461,6 +3477,10 @@ void mlir::python::populateIRCore(py::module &m) {
PyOpResult::bind(m);
PyOpOperand::bind(m);
+ py::class_<PyAsmState>(m, "AsmState", py::module_local())
+ .def(py::init<PyValue &, bool>(), py::arg("value"),
+ py::arg("use_local_scope") = false);
+
//----------------------------------------------------------------------------
// Mapping of SymbolTable.
//----------------------------------------------------------------------------
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index d1911730c1ede03..23338f7fdb38add 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -748,6 +748,38 @@ class PyRegion {
MlirRegion region;
};
+/// Wrapper around an MlirAsmState.
+/// Owns both the MlirAsmState and the MlirOpPrintingFlags it was created
+/// with and releases both on destruction. Non-copyable, since a copy would
+/// release the same C handles twice.
+class PyAsmState {
+public:
+  /// Creates an asm state for `value` via mlirAsmStateCreateForValue. When
+  /// `useLocalScope` is true, the printing flags are set to use local scope.
+  PyAsmState(MlirValue value, bool useLocalScope) {
+    // The OpPrintingFlags are not exposed Python side, create locally and
+    // associate lifetime with the state.
+    flags = mlirOpPrintingFlagsCreate();
+    if (useLocalScope)
+      mlirOpPrintingFlagsUseLocalScope(flags);
+    state = mlirAsmStateCreateForValue(value, flags);
+  }
+  ~PyAsmState() {
+    // Both handles are owned by this object; release in reverse creation order.
+    mlirAsmStateDestroy(state);
+    mlirOpPrintingFlagsDestroy(flags);
+  }
+  PyAsmState(const PyAsmState &) = delete;
+  PyAsmState &operator=(const PyAsmState &) = delete;
+
+  /// Returns the underlying handle; ownership remains with this object.
+  MlirAsmState get() { return state; }
+
+private:
+  MlirAsmState state;
+  MlirOpPrintingFlags flags;
+};
+
/// Wrapper around an MlirBlock.
/// Blocks are managed completely by their containing operation. Unlike the
/// C++ API, the python API does not support detached blocks.
diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py
index 46a50ac5291e8d9..2a47c8d820eaf4f 100644
--- a/mlir/test/python/ir/value.py
+++ b/mlir/test/python/ir/value.py
@@ -1,4 +1,4 @@
-# RUN: %PYTHON %s | FileCheck %s
+# RUN: %PYTHON %s | FileCheck %s --enable-var-scope=false
import gc
from mlir.ir import *
@@ -199,6 +199,16 @@ def testValuePrintAsOperand():
# CHECK: %[[VAL4]]
print(value4.get_name())
+ print("With AsmState")
+ # CHECK-LABEL: With AsmState
+ state = AsmState(value3, use_local_scope=True)
+ # CHECK: %0
+ print(value3.get_name(state=state))
+ # CHECK: %1
+ print(value4.get_name(state=state))
+
+ print("With use_local_scope")
+ # CHECK-LABEL: With use_local_scope
# CHECK: %0
print(value3.get_name(use_local_scope=True))
# CHECK: %1
More information about the Mlir-commits
mailing list