[Mlir-commits] [mlir] [mlir] Better Python diagnostics (PR #128581)
Nikhil Kalra
llvmlistbot at llvm.org
Mon Feb 24 13:24:43 PST 2025
https://github.com/nikalra created https://github.com/llvm/llvm-project/pull/128581
Updated the Python diagnostics handler (`CollectDiagnosticsToStringScope`) to also write each diagnostic's attached notes, in addition to the error message itself, into the collected output, so that users have more context about where in the IR the error occurs.
To test this, I also added a C API function, `mlirContextPrintStackTraceOnDiagnostic`, which toggles the context's `printStackTraceOnDiagnostic` option so that the diagnostic emitted by the Python test carries a stack-trace note.
>From c60f112c9bbcf75da8eb8946339098a840769bd8 Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Mon, 24 Feb 2025 11:34:40 -0800
Subject: [PATCH 1/4] functional changes
---
mlir/include/mlir/Bindings/Python/Diagnostics.h | 6 ++++++
1 file changed, 6 insertions(+)
diff --git a/mlir/include/mlir/Bindings/Python/Diagnostics.h b/mlir/include/mlir/Bindings/Python/Diagnostics.h
index ea80e14dde0f3..47914e4107ebd 100644
--- a/mlir/include/mlir/Bindings/Python/Diagnostics.h
+++ b/mlir/include/mlir/Bindings/Python/Diagnostics.h
@@ -10,6 +10,7 @@
#define MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
#include <cassert>
+#include <cstdint>
#include <string>
#include "mlir-c/Diagnostics.h"
@@ -45,6 +46,11 @@ class CollectDiagnosticsToStringScope {
mlirLocationPrint(loc, printer, data);
*static_cast<std::string *>(data) += ": ";
mlirDiagnosticPrint(diag, printer, data);
+ for (intptr_t i = 0; i < mlirDiagnosticGetNumNotes(diag); i++) {
+ *static_cast<std::string *>(data) += "\n";
+ MlirDiagnostic note = mlirDiagnosticGetNote(diag, i);
+ handler(note, data);
+ }
return mlirLogicalResultSuccess();
}
>From 6d09f001f3fcc4ef0d5f1f799819ac78adbe16cf Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Mon, 24 Feb 2025 13:21:43 -0800
Subject: [PATCH 2/4] make more efficient
---
.../mlir/Bindings/Python/Diagnostics.h | 27 ++++++++++---------
1 file changed, 15 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Bindings/Python/Diagnostics.h b/mlir/include/mlir/Bindings/Python/Diagnostics.h
index 47914e4107ebd..4f9be844dc1ac 100644
--- a/mlir/include/mlir/Bindings/Python/Diagnostics.h
+++ b/mlir/include/mlir/Bindings/Python/Diagnostics.h
@@ -9,14 +9,14 @@
#ifndef MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
#define MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
+#include "mlir-c/Diagnostics.h"
+#include "mlir-c/IR.h"
+
#include <cassert>
#include <cstdint>
+#include <sstream>
#include <string>
-#include "mlir-c/Diagnostics.h"
-#include "mlir-c/IR.h"
-#include "llvm/ADT/StringRef.h"
-
namespace mlir {
namespace python {
@@ -29,25 +29,28 @@ class CollectDiagnosticsToStringScope {
/*deleteUserData=*/nullptr);
}
~CollectDiagnosticsToStringScope() {
- assert(errorMessage.empty() && "unchecked error message");
mlirContextDetachDiagnosticHandler(context, handlerID);
}
- [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
+ [[nodiscard]] std::string takeMessage() {
+ std::ostringstream stream;
+ std::swap(stream, errorMessage);
+ return stream.str();
+ }
private:
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
auto printer = +[](MlirStringRef message, void *data) {
- *static_cast<std::string *>(data) +=
- llvm::StringRef(message.data, message.length);
+ *static_cast<std::ostringstream *>(data)
+ << std::string_view(message.data, message.length);
};
MlirLocation loc = mlirDiagnosticGetLocation(diag);
- *static_cast<std::string *>(data) += "at ";
+ *static_cast<std::ostringstream *>(data) << "at ";
mlirLocationPrint(loc, printer, data);
- *static_cast<std::string *>(data) += ": ";
+ *static_cast<std::ostringstream *>(data) << ": ";
mlirDiagnosticPrint(diag, printer, data);
for (intptr_t i = 0; i < mlirDiagnosticGetNumNotes(diag); i++) {
- *static_cast<std::string *>(data) += "\n";
+ *static_cast<std::ostringstream *>(data) << "\n";
MlirDiagnostic note = mlirDiagnosticGetNote(diag, i);
handler(note, data);
}
@@ -56,7 +59,7 @@ class CollectDiagnosticsToStringScope {
MlirContext context;
MlirDiagnosticHandlerID handlerID;
- std::string errorMessage = "";
+ std::ostringstream errorMessage;
};
} // namespace python
>From 6afec89bb95990d91c8a8a6c88c78810024b0d37 Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Mon, 24 Feb 2025 13:21:54 -0800
Subject: [PATCH 3/4] capi to print stack trace on diagnostic
---
mlir/include/mlir-c/IR.h | 5 +++++
mlir/lib/CAPI/IR/IR.cpp | 4 ++++
2 files changed, 9 insertions(+)
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 14ccae650606a..f661e90105704 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -162,6 +162,11 @@ MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context,
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context,
MlirLlvmThreadPool threadPool);
+/// Sets whether the context attaches to each diagnostic a note containing the
+/// stack trace of the code that emitted it (on when `enable` is true).
+MLIR_CAPI_EXPORTED void
+mlirContextPrintStackTraceOnDiagnostic(MlirContext context, bool enable);
+
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 999e8cbda1295..2249519ad4eef 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -114,6 +114,10 @@ void mlirContextSetThreadPool(MlirContext context,
unwrap(context)->setThreadPool(*unwrap(threadPool));
}
+void mlirContextPrintStackTraceOnDiagnostic(MlirContext context, bool enable) {
+ unwrap(context)->printStackTraceOnDiagnostic(enable);
+}
+
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
>From 87cbc6c2624555b145a934bf1d33438af4e8ecc4 Mon Sep 17 00:00:00 2001
From: Nikhil Kalra <nkalra at apple.com>
Date: Mon, 24 Feb 2025 13:22:21 -0800
Subject: [PATCH 4/4] test to ensure diagnostics are logged
---
mlir/test/python/ir/diagnostic_handler.py | 14 ++++++++++++++
mlir/test/python/lib/PythonTestModuleNanobind.cpp | 12 ++++++++++++
2 files changed, 26 insertions(+)
diff --git a/mlir/test/python/ir/diagnostic_handler.py b/mlir/test/python/ir/diagnostic_handler.py
index d516cda819897..5f6696850682a 100644
--- a/mlir/test/python/ir/diagnostic_handler.py
+++ b/mlir/test/python/ir/diagnostic_handler.py
@@ -2,6 +2,7 @@
import gc
from mlir.ir import *
+from mlir._mlir_libs._mlirPythonTestNanobind import test_diagnostics_with_errors_and_notes
def run(f):
@@ -222,3 +223,16 @@ def callback2(d):
# CHECK: CALLBACK2: foobar
# CHECK: CALLBACK1: foobar
loc.emit_error("foobar")
+
+
+# CHECK-LABEL: TEST: testBuiltInDiagnosticsHandler
+@run
+def testBuiltInDiagnosticsHandler():
+    ctx = Context()
+    # The helper emits an error with a stack-trace note and raises ValueError.
+    try:
+        test_diagnostics_with_errors_and_notes(ctx)
+    except ValueError as e:
+        # CHECK: created error
+        # CHECK: MLIRPythonCAPI
+        print(e)
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index 99c81eae97a0c..daf3b4602b367 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -11,9 +11,12 @@
#include "PythonTestCAPI.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "nanobind/nanobind.h"
namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;
@@ -45,6 +48,15 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
},
nb::arg("registry"));
+ m.def("test_diagnostics_with_errors_and_notes", [](MlirContext ctx) {
+ mlirContextPrintStackTraceOnDiagnostic(ctx, true);
+ mlir::python::CollectDiagnosticsToStringScope handler(ctx);
+
+ auto loc = mlirLocationUnknownGet(ctx);
+ mlirEmitError(loc, "created error");
+ throw nb::value_error(handler.takeMessage().c_str());
+ });
+
mlir_attribute_subclass(m, "TestAttr",
mlirAttributeIsAPythonTestTestAttribute,
mlirPythonTestTestAttributeGetTypeID)
More information about the Mlir-commits
mailing list