[Mlir-commits] [mlir] Expose skipRegions option for Op printing in the C and Python bindings (PR #96150)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Jun 20 01:51:24 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Jonas Rickert (jorickert)

<details>
<summary>Changes</summary>

The MLIR C and Python bindings expose various methods from `mlir::OpPrintingFlags`. This PR adds a binding for the `skipRegions` method, which allows skipping the printing of regions when printing ops. It also exposes this option as a parameter of the Python `get_asm` and `print` methods.

---
Full diff: https://github.com/llvm/llvm-project/pull/96150.diff


7 Files Affected:

- (modified) mlir/include/mlir-c/IR.h (+4) 
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+15-7) 
- (modified) mlir/lib/Bindings/Python/IRModule.h (+3-2) 
- (modified) mlir/lib/CAPI/IR/IR.cpp (+3) 
- (modified) mlir/python/mlir/_mlir_libs/_mlir/ir.pyi (+3) 
- (modified) mlir/test/CAPI/ir.c (+15-3) 
- (modified) mlir/test/python/ir/operation.py (+8-1) 


``````````diff
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index e3d69b7708b3d..694591fd99dc6 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -450,6 +450,10 @@ mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags);
 MLIR_CAPI_EXPORTED void
 mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags);
 
+/// Skip printing regions.
+MLIR_CAPI_EXPORTED void
+mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags);
+
 //===----------------------------------------------------------------------===//
 // Bytecode printing flags API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 4b6b54dc1c111..c12f75e7d224a 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -108,6 +108,7 @@ static const char kOperationPrintDocstring[] =
     and report failures in a more robust fashion. Set this to True if doing this
     in order to avoid running a redundant verification. If the IR is actually
     invalid, behavior is undefined.
+  skip_regions: Whether to skip printing regions. Defaults to False.
 )";
 
 static const char kOperationPrintStateDocstring[] =
@@ -1221,7 +1222,7 @@ void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
                             bool enableDebugInfo, bool prettyDebugInfo,
                             bool printGenericOpForm, bool useLocalScope,
                             bool assumeVerified, py::object fileObject,
-                            bool binary) {
+                            bool binary, bool skipRegions) {
   PyOperation &operation = getOperation();
   operation.checkValid();
   if (fileObject.is_none())
@@ -1239,6 +1240,8 @@ void PyOperationBase::print(std::optional<int64_t> largeElementsLimit,
     mlirOpPrintingFlagsUseLocalScope(flags);
   if (assumeVerified)
     mlirOpPrintingFlagsAssumeVerified(flags);
+  if (skipRegions)
+    mlirOpPrintingFlagsSkipRegions(flags);
 
   PyFileAccumulator accum(fileObject, binary);
   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
@@ -1314,7 +1317,7 @@ py::object PyOperationBase::getAsm(bool binary,
                                    std::optional<int64_t> largeElementsLimit,
                                    bool enableDebugInfo, bool prettyDebugInfo,
                                    bool printGenericOpForm, bool useLocalScope,
-                                   bool assumeVerified) {
+                                   bool assumeVerified, bool skipRegions) {
   py::object fileObject;
   if (binary) {
     fileObject = py::module::import("io").attr("BytesIO")();
@@ -1328,7 +1331,8 @@ py::object PyOperationBase::getAsm(bool binary,
         /*useLocalScope=*/useLocalScope,
         /*assumeVerified=*/assumeVerified,
         /*fileObject=*/fileObject,
-        /*binary=*/binary);
+        /*binary=*/binary,
+        /*skipRegions=*/skipRegions);
 
   return fileObject.attr("getvalue")();
 }
@@ -3043,7 +3047,8 @@ void mlir::python::populateIRCore(py::module &m) {
                                /*prettyDebugInfo=*/false,
                                /*printGenericOpForm=*/false,
                                /*useLocalScope=*/false,
-                               /*assumeVerified=*/false);
+                               /*assumeVerified=*/false,
+                               /*skipRegions=*/false);
           },
           "Returns the assembly form of the operation.")
       .def("print",
@@ -3053,7 +3058,8 @@ void mlir::python::populateIRCore(py::module &m) {
            py::arg("binary") = false, kOperationPrintStateDocstring)
       .def("print",
            py::overload_cast<std::optional<int64_t>, bool, bool, bool, bool,
-                             bool, py::object, bool>(&PyOperationBase::print),
+                             bool, py::object, bool, bool>(
+               &PyOperationBase::print),
            // Careful: Lots of arguments must match up with print method.
            py::arg("large_elements_limit") = py::none(),
            py::arg("enable_debug_info") = false,
@@ -3061,7 +3067,8 @@ void mlir::python::populateIRCore(py::module &m) {
            py::arg("print_generic_op_form") = false,
            py::arg("use_local_scope") = false,
            py::arg("assume_verified") = false, py::arg("file") = py::none(),
-           py::arg("binary") = false, kOperationPrintDocstring)
+           py::arg("binary") = false, py::arg("skip_regions") = false,
+           kOperationPrintDocstring)
       .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"),
            py::arg("desired_version") = py::none(),
            kOperationPrintBytecodeDocstring)
@@ -3073,7 +3080,8 @@ void mlir::python::populateIRCore(py::module &m) {
            py::arg("pretty_debug_info") = false,
            py::arg("print_generic_op_form") = false,
            py::arg("use_local_scope") = false,
-           py::arg("assume_verified") = false, kOperationGetAsmDocstring)
+           py::arg("assume_verified") = false, py::arg("skip_regions") = false,
+           kOperationGetAsmDocstring)
       .def("verify", &PyOperationBase::verify,
            "Verify the operation. Raises MLIRError if verification fails, and "
            "returns true otherwise.")
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index f49efcd506ee9..172898cfda0c5 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -574,14 +574,15 @@ class PyOperationBase {
   /// Implements the bound 'print' method and helps with others.
   void print(std::optional<int64_t> largeElementsLimit, bool enableDebugInfo,
              bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope,
-             bool assumeVerified, py::object fileObject, bool binary);
+             bool assumeVerified, py::object fileObject, bool binary,
+             bool skipRegions);
   void print(PyAsmState &state, py::object fileObject, bool binary);
 
   pybind11::object getAsm(bool binary,
                           std::optional<int64_t> largeElementsLimit,
                           bool enableDebugInfo, bool prettyDebugInfo,
                           bool printGenericOpForm, bool useLocalScope,
-                          bool assumeVerified);
+                          bool assumeVerified, bool skipRegions);
 
   // Implement the bound 'writeBytecode' method.
   void writeBytecode(const pybind11::object &fileObject,
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 4e823c866f433..2edc311e2f85f 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -219,6 +219,9 @@ void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) {
   unwrap(flags)->assumeVerified();
 }
 
+void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags) {
+  unwrap(flags)->skipRegions();
+}
 //===----------------------------------------------------------------------===//
 // Bytecode printing flags API.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 1e1b2a8348b1d..317e688076304 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -209,6 +209,7 @@ class _OperationBase:
         print_generic_op_form: bool = False,
         use_local_scope: bool = False,
         assume_verified: bool = False,
+        skip_regions: bool = False,
     ) -> Union[io.BytesIO, io.StringIO]:
         """
         Gets the assembly form of the operation with all options available.
@@ -256,6 +257,7 @@ class _OperationBase:
         assume_verified: bool = False,
         file: Optional[Any] = None,
         binary: bool = False,
+        skip_regions: bool = False,
     ) -> None:
         """
         Prints the assembly form of the operation to a file like object.
@@ -281,6 +283,7 @@ class _OperationBase:
             and report failures in a more robust fashion. Set this to True if doing this
             in order to avoid running a redundant verification. If the IR is actually
             invalid, behavior is undefined.
+          skip_regions: Whether to skip printing regions. Defaults to False.
         """
     def verify(self) -> bool:
         """
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index 3d05b2a12dd8e..15a3a1fb50dc9 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -340,9 +340,9 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   // function.
   MlirRegion region = mlirOperationGetRegion(operation, 0);
   MlirBlock block = mlirRegionGetFirstBlock(region);
-  operation = mlirBlockGetFirstOperation(block);
-  region = mlirOperationGetRegion(operation, 0);
-  MlirOperation parentOperation = operation;
+  MlirOperation function = mlirBlockGetFirstOperation(block);
+  region = mlirOperationGetRegion(function, 0);
+  MlirOperation parentOperation = function;
   block = mlirRegionGetFirstBlock(region);
   operation = mlirBlockGetFirstOperation(block);
   assert(mlirModuleIsNull(mlirModuleFromOperation(operation)));
@@ -490,6 +490,18 @@ static void printFirstOfEach(MlirContext ctx, MlirOperation operation) {
   // CHECK: Op print with all flags: %{{.*}} = "arith.constant"() <{value = 0 : index}> {elts = dense_resource<__elided__> : tensor<4xi32>} : () -> index loc(unknown)
   // clang-format on
 
+  mlirOpPrintingFlagsDestroy(flags);
+  flags = mlirOpPrintingFlagsCreate();
+  mlirOpPrintingFlagsSkipRegions(flags);
+  fprintf(stderr, "Op print with skip regions flag: ");
+  mlirOperationPrintWithFlags(function, flags, printToStderr, NULL);
+  fprintf(stderr, "\n");
+  // clang-format off
+  // CHECK: Op print with skip regions flag: func.func @add(%[[ARG0:.*]]: memref<?xf32>, %[[ARG1:.*]]: memref<?xf32>)
+  // CHECK-NOT: constant
+  // CHECK-NOT: return
+  // clang-format on
+
   fprintf(stderr, "With state: |");
   mlirValuePrintAsOperand(value, state, printToStderr, NULL);
   // CHECK: With state: |%0|
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 3a5d850b86e3a..d5d72c98b66ad 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -631,7 +631,7 @@ def testOperationPrint():
     # CHECK: constant dense<[1, 2, 3, 4]> : tensor<4xi32>
     module.operation.print(state)
 
-    # Test get_asm with options.
+    # Test print with options.
     # CHECK: value = dense_resource<__elided__> : tensor<4xi32>
     # CHECK: "func.return"(%arg0) : (i32) -> () -:4:7
     module.operation.print(
@@ -642,6 +642,13 @@ def testOperationPrint():
         use_local_scope=True,
     )
 
+    # Test print with skip_regions option
+    # CHECK: func.func @f1(%arg0: i32) -> i32
+    # CHECK-NOT: func.return
+    module.body.operations[0].print(
+        skip_regions=True,
+    )
+
 
 # CHECK-LABEL: TEST: testKnownOpView
 @run

``````````

</details>


https://github.com/llvm/llvm-project/pull/96150


More information about the Mlir-commits mailing list