[Mlir-commits] [mlir] 2aa1258 - [mlir][python] Don't emit diagnostics when printing invalid ops

Rahul Kayaith llvmlistbot at llvm.org
Sun Feb 26 20:50:23 PST 2023


Author: Rahul Kayaith
Date: 2023-02-26T23:50:18-05:00
New Revision: 2aa12583e6ace921a1e58f39d54a35296273758e

URL: https://github.com/llvm/llvm-project/commit/2aa12583e6ace921a1e58f39d54a35296273758e
DIFF: https://github.com/llvm/llvm-project/commit/2aa12583e6ace921a1e58f39d54a35296273758e.diff

LOG: [mlir][python] Don't emit diagnostics when printing invalid ops

The asm printer grew the ability to automatically fall back to the
generic format for invalid ops, so this logic doesn't need to be in the
bindings anymore. The printer already handles suppressing diagnostics
that get emitted while checking if the op is valid.

Reviewed By: mehdi_amini, stellaraccident

Differential Revision: https://reviews.llvm.org/D144805

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRCore.cpp
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/python/ir/operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index a349eb9f3913..023b99f42ba4 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -404,6 +404,10 @@ mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags);
 MLIR_CAPI_EXPORTED void
 mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags);
 
+/// Do not verify the operation when using custom operation printers.
+MLIR_CAPI_EXPORTED void
+mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags);
+
 //===----------------------------------------------------------------------===//
 // Operation API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 2ecfc36d41a5..e09f0fdeee90 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1075,17 +1075,6 @@ void PyOperationBase::print(py::object fileObject, bool binary,
   if (fileObject.is_none())
     fileObject = py::module::import("sys").attr("stdout");
 
-  if (!assumeVerified && !printGenericOpForm &&
-      !mlirOperationVerify(operation)) {
-    std::string message("// Verification failed, printing generic form\n");
-    if (binary) {
-      fileObject.attr("write")(py::bytes(message));
-    } else {
-      fileObject.attr("write")(py::str(message));
-    }
-    printGenericOpForm = true;
-  }
-
   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
   if (largeElementsLimit)
     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
@@ -1096,6 +1085,8 @@ void PyOperationBase::print(py::object fileObject, bool binary,
     mlirOpPrintingFlagsPrintGenericOpForm(flags);
   if (useLocalScope)
     mlirOpPrintingFlagsUseLocalScope(flags);
+  if (assumeVerified)
+    mlirOpPrintingFlagsAssumeVerified(flags);
 
   PyFileAccumulator accum(fileObject, binary);
   mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 7d3479736c55..e83f0f8240ae 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -141,6 +141,10 @@ void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
   unwrap(flags)->useLocalScope();
 }
 
+void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) {
+  unwrap(flags)->assumeVerified();
+}
+
 //===----------------------------------------------------------------------===//
 // Location API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9190c301f1a9..f393cf92c3c1 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -681,7 +681,6 @@ def testInvalidOperationStrSoftFails():
   with Location.unknown(ctx):
     invalid_op = create_invalid_operation()
     # Verify that we fallback to the generic printer for safety.
-    # CHECK: // Verification failed, printing generic form
     # CHECK: "builtin.module"() ({
     # CHECK: }) : () -> ()
     print(invalid_op)
@@ -698,7 +697,8 @@ def testInvalidModuleStrSoftFails():
     with InsertionPoint(module.body):
       invalid_op = create_invalid_operation()
     # Verify that we fallback to the generic printer for safety.
-    # CHECK: // Verification failed, printing generic form
+    # CHECK: "builtin.module"() ({
+    # CHECK: }) : () -> ()
     print(module)
 
 
@@ -709,7 +709,7 @@ def testInvalidOperationGetAsmBinarySoftFails():
   with Location.unknown(ctx):
     invalid_op = create_invalid_operation()
     # Verify that we fallback to the generic printer for safety.
-    # CHECK: b'// Verification failed, printing generic form\n
+    # CHECK: b'"builtin.module"() ({\n^bb0:\n}, {\n}) : () -> ()\n'
     print(invalid_op.get_asm(binary=True))
 
 


        


More information about the Mlir-commits mailing list