[llvm-branch-commits] [mlir] 1c21594 - Use the generic form when printing from the python bindings and the verifier fails

Mehdi Amini via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Dec 3 10:49:57 PST 2020


Author: Mehdi Amini
Date: 2020-12-03T18:45:00Z
New Revision: 1c2159494d0731d3ecf4c02651f8b14ab9eb0338

URL: https://github.com/llvm/llvm-project/commit/1c2159494d0731d3ecf4c02651f8b14ab9eb0338
DIFF: https://github.com/llvm/llvm-project/commit/1c2159494d0731d3ecf4c02651f8b14ab9eb0338.diff

LOG: Use the generic form when printing from the python bindings and the verifier fails

This reduces the chance of a segfault. While it is good practice to ensure
robust custom printers, it is unfortunately common for them to crash on
invalid input.

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D92536

Added: 
    

Modified: 
    mlir/include/mlir-c/IR.h
    mlir/lib/Bindings/Python/IRModules.cpp
    mlir/lib/CAPI/IR/IR.cpp
    mlir/test/Bindings/Python/ir_operation.py

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index d8f63263763b..02d1e54d20e5 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -386,6 +386,9 @@ MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op,
 /// Prints an operation to stderr.
 MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op);
 
+/// Verify the operation and return true if it passes, false if it fails.
+MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op);
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp
index 8beed22e5372..fde95a57b6ce 100644
--- a/mlir/lib/Bindings/Python/IRModules.cpp
+++ b/mlir/lib/Bindings/Python/IRModules.cpp
@@ -809,6 +809,12 @@ void PyOperationBase::print(py::object fileObject, bool binary,
   operation.checkValid();
   if (fileObject.is_none())
     fileObject = py::module::import("sys").attr("stdout");
+
+  if (!printGenericOpForm && !mlirOperationVerify(operation)) {
+    fileObject.attr("write")("// Verification failed, printing generic form\n");
+    printGenericOpForm = true;
+  }
+
   MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
   if (largeElementsLimit)
     mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);

diff  --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 139c878d76e0..c34da0daeb5a 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Types.h"
+#include "mlir/IR/Verifier.h"
 #include "mlir/Parser.h"
 
 using namespace mlir;
@@ -339,6 +340,10 @@ void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
 
 void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); }
 
+bool mlirOperationVerify(MlirOperation op) { // C API wrapper around mlir::verify.
+  return succeeded(verify(unwrap(op))); // Map the verifier's LogicalResult to a plain bool.
+}
+
 //===----------------------------------------------------------------------===//
 // Region API.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Bindings/Python/ir_operation.py b/mlir/test/Bindings/Python/ir_operation.py
index e3867b99a9b4..d23e0b6c0b4e 100644
--- a/mlir/test/Bindings/Python/ir_operation.py
+++ b/mlir/test/Bindings/Python/ir_operation.py
@@ -537,3 +537,17 @@ def testSingleResultProperty():
   print(module.body.operations[2])
 
 run(testSingleResultProperty)
+
+# CHECK-LABEL: TEST: testPrintInvalidOperation
+def testPrintInvalidOperation():  # print() must not crash on an op that fails verification.
+  ctx = Context()
+  with Location.unknown(ctx):
+    module = Operation.create("module", regions=1)
+    # This block has no terminator, so it may crash the custom printer.
+    # Verify that we fall back to the generic printer for safety.
+    block = module.regions[0].blocks.append()  # Appended for its effect on region 0; the variable is unused.
+    print(module)
+    # CHECK: // Verification failed, printing generic form
+    # CHECK: "module"() ( {
+    # CHECK: }) : () -> ()
+run(testPrintInvalidOperation)


        


More information about the llvm-branch-commits mailing list