[Mlir-commits] [mlir] Expose skipRegions option for Op printing in the C and Python bindings (PR #96150)
Jonas Rickert
llvmlistbot at llvm.org
Thu Jun 20 01:50:35 PDT 2024
https://github.com/jorickert created https://github.com/llvm/llvm-project/pull/96150
The MLIR C and Python bindings expose various methods from `mlir::OpPrintingFlags`. This PR adds a binding for the `skipRegions` method, which allows skipping the printing of regions when printing ops. It also exposes this option as a `skip_regions` parameter of the Python `get_asm` and `print` methods.
>From 78b0374291b541f193865a89cba82dcca1551f23 Mon Sep 17 00:00:00 2001
From: "Rickert, Jonas" <Jonas.Rickert at amd.com>
Date: Tue, 18 Jun 2024 05:57:05 -0600
Subject: [PATCH] Expose skipRegions option for Op printing in the C and Python
bindings
---
mlir/include/mlir-c/IR.h | 4 ++++
mlir/lib/Bindings/Python/IRCore.cpp | 22 +++++++++++++++-------
mlir/lib/Bindings/Python/IRModule.h | 5 +++--
mlir/lib/CAPI/IR/IR.cpp | 3 +++
mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 3 +++
mlir/test/CAPI/ir.c | 18 +++++++++++++++---
mlir/test/python/ir/operation.py | 9 ++++++++-
7 files changed, 51 insertions(+), 13 deletions(-)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index e3d69b7708b3d..694591fd99dc6 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -450,6 +450,10 @@ mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags);
MLIR_CAPI_EXPORTED void
mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags);
+/// Skip printing regions.
+MLIR_CAPI_EXPORTED void
+mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags);
+
//===----------------------------------------------------------------------===//
// Bytecode printing flags API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 4b6b54dc1c111..c12f75e7d224a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -108,6 +108,7 @@ static const char kOperationPrintDocstring[] =
and report failures in a more robust fashion. Set this to True if doing this
in order to avoid running a redundant verification. If the IR is actually
invalid, behavior is undefined.
+ skip_regions: Whether to skip printing regions. Defaults to False.
)";
static const char kOperationPrintStateDocstring[] =
@@ -1221,7 +1222,7 @@ void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
bool assumeVerified, py::object fileObject,
- bool binary) {
+ bool binary, bool skipRegions) {
PyOperation &operation = getOperation();
operation.checkValid();
if (fileObject.is_none())
@@ -1239,6 +1240,8 @@ void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
mlirOpPrintingFlagsUseLocalScope(flags);
if (assumeVerified)
mlirOpPrintingFlagsAssumeVerified(flags);
+ if (skipRegions)
+ mlirOpPrintingFlagsSkipRegions(flags);
PyFileAccumulator accum(fileObject, binary);
mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
@@ -1314,7 +1317,7 @@ py::object PyOperationBase::getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
- bool assumeVerified) {
+ bool assumeVerified, bool skipRegions) {
py::object fileObject;
if (binary) {
fileObject = py::module::import("io").attr("BytesIO")();
@@ -1328,7 +1331,8 @@ py::object PyOperationBase::getAsm(bool binary,
/*useLocalScope=*/useLocalScope,
/*assumeVerified=*/assumeVerified,
/*fileObject=*/fileObject,
- /*binary=*/binary);
+ /*binary=*/binary,
+ /*skipRegions=*/skipRegions);
return fileObject.attr("getvalue")();
}
@@ -3043,7 +3047,8 @@ void mlir::python::populateIRCore(py::module &m) {
/*prettyDebugInfo=*/false,
/*printGenericOpForm=*/false,
/*useLocalScope=*/false,
- /*assumeVerified=*/false);
+ /*assumeVerified=*/false,
+ /*skipRegions=*/false);
},
"Returns the assembly form of the operation.")
.def("print",
@@ -3053,7 +3058,8 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("binary") = false, kOperationPrintStateDocstring)
.def("print",
py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
- bool, py::object, bool>(&PyOperationBase::print),
+ bool, py::object, bool, bool>(
+ &PyOperationBase::print),
// Careful: Lots of arguments must match up with print method.
py::arg("large_elements_limit") = py::none(),
py::arg("enable_debug_info") = false,
@@ -3061,7 +3067,8 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("print_generic_op_form") = false,
py::arg("use_local_scope") = false,
py::arg("assume_verified") = false, py::arg("file") = py::none(),
- py::arg("binary") = false, kOperationPrintDocstring)
+ py::arg("binary") = false, py::arg("skip_regions") = false,
+ kOperationPrintDocstring)
.def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
py::arg("desired_version") = py::none(),
kOperationPrintBytecodeDocstring)
@@ -3073,7 +3080,8 @@ void mlir::python::populateIRCore(py::module &m) {
py::arg("pretty_debug_info") = false,
py::arg("print_generic_op_form") = false,
py::arg("use_local_scope") = false,
- py::arg("assume_verified") = false, kOperationGetAsmDocstring)
+ py::arg("assume_verified") = false, py::arg("skip_regions") = false,
+ kOperationGetAsmDocstring)
.def("verify", &PyOperationBase::verify,
"Verify the operation. Raises MLIRError if verification fails, and "
"returns true otherwise.")
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index f49efcd506ee9..172898cfda0c5 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -574,14 +574,15 @@ class PyOperationBase {
/// Implements the bound 'print' method and helps with others.
void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
- bool assumeVerified, py::object fileObject, bool binary);
+ bool assumeVerified, py::object fileObject, bool binary,
+ bool skipRegions);
void print(PyAsmState &state, py::object fileObject, bool binary);
pybind11::object getAsm(bool binary,
std::optional<int64_t> largeElementsLimit,
bool enableDebugInfo, bool prettyDebugInfo,
bool printGenericOpForm, bool useLocalScope,
- bool assumeVerified);
+ bool assumeVerified, bool skipRegions);
// Implement the bound 'writeBytecode' method.
void writeBytecode(const pybind11::object &fileObject,
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 4e823c866f433..2edc311e2f85f 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -219,6 +219,9 @@ void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) {
unwrap(flags)->assumeVerified();
}
+void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags) {
+ unwrap(flags)->skipRegions();
+}
//===----------------------------------------------------------------------===//
// Bytecode printing flags API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 1e1b2a8348b1d..317e688076304 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -209,6 +209,7 @@ class _OperationBase:
print_generic_op_form: bool = False,
use_local_scope: bool = False,
assume_verified: bool = False,
+ skip_regions: bool = False,
) -> Union[io.BytesIO, io.StringIO]:
"""
Gets the assembly form of the operation with all options available.
@@ -256,6 +257,7 @@ class _OperationBase:
assume_verified: bool = False,
file: Optional[Any] = None,
binary: bool = False,
+ skip_regions: bool = False,
) -> None:
"""
Prints the assembly form of the operation to a file like object.
@@ -281,6 +283,7 @@ class _OperationBase:
and report failures in a more robust fashion. Set this to True if doing this
in order to avoid running a redundant verification. If the IR is actually
invalid, behavior is undefined.
+ skip_regions: Whether to skip printing regions. Defaults to False.
"""
def verify(self) -> bool:
"""
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 3d05b2a12dd8e..15a3a1fb50dc9 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -340,9 +340,9 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
// function.
MlirRegion region = mlirOperationGetRegion(operation, 0);
MlirBlock block = mlirRegionGetFirstBlock(region);
- operation = mlirBlockGetFirstOperation(block);
- region = mlirOperationGetRegion(operation, 0);
- MlirOperation parentOperation = operation;
+ MlirOperation function = mlirBlockGetFirstOperation(block);
+ region = mlirOperationGetRegion(function, 0);
+ MlirOperation parentOperation = function;
block = mlirRegionGetFirstBlock(region);
operation = mlirBlockGetFirstOperation(block);
assert(mlirModuleIsNull(mlirModuleFromOperation(operation)));
@@ -490,6 +490,18 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
// CHECK: Op print with all flags: %{{.*}} = "arith.constant"() <{value = 0 : index}> {elts = dense_resource<__elided__> : tensor<4xi32>} : () -> index loc(unknown)
// clang-format on
+ mlirOpPrintingFlagsDestroy(flags);
+ flags = mlirOpPrintingFlagsCreate();
+ mlirOpPrintingFlagsSkipRegions(flags);
+ fprintf(stderr, "Op print with skip regions flag: ");
+ mlirOperationPrintWithFlags(function, flags, printToStderr, NULL);
+ fprintf(stderr, "\n");
+ // clang-format off
+ // CHECK: Op print with skip regions flag: func.func @add(%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>)
+ // CHECK-NOT: constant
+ // CHECK-NOT: return
+ // clang-format on
+
fprintf(stderr, "With state: |");
mlirValuePrintAsOperand(value, state, printToStderr, NULL);
// CHECK: With state: |%0|
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 3a5d850b86e3a..d5d72c98b66ad 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -631,7 +631,7 @@ def testOperationPrint():
# CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32>
module.operation.print(state)
- # Test get_asm with options.
+ # Test print with options.
# CHECK: value = dense_resource<__elided__> : tensor<4xi32>
# CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
module.operation.print(
@@ -642,6 +642,13 @@ def testOperationPrint():
use_local_scope=True,
)
+ # Test print with skip_regions option
+ # CHECK: func.func @f1(%arg0: i32) -> i32
+ # CHECK-NOT: func.return
+ module.body.operations[0].print(
+ skip_regions=True,
+ )
+
# CHECK-LABEL: TEST: testKnownOpView
@run
More information about the Mlir-commits
mailing list