[Mlir-commits] [mlir] 2aa1258 - [mlir][python] Don't emit diagnostics when printing invalid ops
Rahul Kayaith
llvmlistbot at llvm.org
Sun Feb 26 20:50:23 PST 2023
Author: Rahul Kayaith
Date: 2023-02-26T23:50:18-05:00
New Revision: 2aa12583e6ace921a1e58f39d54a35296273758e
URL: https://github.com/llvm/llvm-project/commit/2aa12583e6ace921a1e58f39d54a35296273758e
DIFF: https://github.com/llvm/llvm-project/commit/2aa12583e6ace921a1e58f39d54a35296273758e.diff
LOG: [mlir][python] Don't emit diagnostics when printing invalid ops
The asm printer grew the ability to automatically fall back to the
generic format for invalid ops, so this logic doesn't need to be in the
bindings anymore. The printer already handles suppressing diagnostics
that get emitted while checking if the op is valid.
Reviewed By: mehdi_amini, stellaraccident
Differential Revision: https://reviews.llvm.org/D144805
Added:
Modified:
mlir/include/mlir-c/IR.h
mlir/lib/Bindings/Python/IRCore.cpp
mlir/lib/CAPI/IR/IR.cpp
mlir/test/python/ir/operation.py
Removed:
################################################################################
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index a349eb9f3913..023b99f42ba4 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -404,6 +404,10 @@ mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags);
MLIR_CAPI_EXPORTED void
mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags);
+/// Do not verify the operation when using custom operation printers.
+MLIR_CAPI_EXPORTED void
+mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags);
+
//===----------------------------------------------------------------------===//
// Operation API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 2ecfc36d41a5..e09f0fdeee90 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1075,17 +1075,6 @@ void PyOperationBase::print(py::object fileObject, bool binary,
if (fileObject.is_none())
fileObject = py::module::import("sys").attr("stdout");
- if (!assumeVerified && !printGenericOpForm &&
- !mlirOperationVerify(operation)) {
- std::string message("// Verification failed, printing generic form\n");
- if (binary) {
- fileObject.attr("write")(py::bytes(message));
- } else {
- fileObject.attr("write")(py::str(message));
- }
- printGenericOpForm = true;
- }
-
MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
if (largeElementsLimit)
mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
@@ -1096,6 +1085,8 @@ void PyOperationBase::print(py::object fileObject, bool binary,
mlirOpPrintingFlagsPrintGenericOpForm(flags);
if (useLocalScope)
mlirOpPrintingFlagsUseLocalScope(flags);
+ if (assumeVerified)
+ mlirOpPrintingFlagsAssumeVerified(flags);
PyFileAccumulator accum(fileObject, binary);
mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 7d3479736c55..e83f0f8240ae 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -141,6 +141,10 @@ void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) {
unwrap(flags)->useLocalScope();
}
+void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) {
+ unwrap(flags)->assumeVerified();
+}
+
//===----------------------------------------------------------------------===//
// Location API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py
index 9190c301f1a9..f393cf92c3c1 100644
--- a/mlir/test/python/ir/operation.py
+++ b/mlir/test/python/ir/operation.py
@@ -681,7 +681,6 @@ def testInvalidOperationStrSoftFails():
with Location.unknown(ctx):
invalid_op = create_invalid_operation()
# Verify that we fallback to the generic printer for safety.
- # CHECK: // Verification failed, printing generic form
# CHECK: "builtin.module"() ({
# CHECK: }) : () -> ()
print(invalid_op)
@@ -698,7 +697,8 @@ def testInvalidModuleStrSoftFails():
with InsertionPoint(module.body):
invalid_op = create_invalid_operation()
# Verify that we fallback to the generic printer for safety.
- # CHECK: // Verification failed, printing generic form
+ # CHECK: "builtin.module"() ({
+ # CHECK: }) : () -> ()
print(module)
@@ -709,7 +709,7 @@ def testInvalidOperationGetAsmBinarySoftFails():
with Location.unknown(ctx):
invalid_op = create_invalid_operation()
# Verify that we fallback to the generic printer for safety.
- # CHECK: b'// Verification failed, printing generic form\n
+ # CHECK: b'"builtin.module"() ({\n^bb0:\n}, {\n}) : () -> ()\n'
print(invalid_op.get_asm(binary=True))
More information about the Mlir-commits
mailing list