[Mlir-commits] [mlir] [mlir] Better Python diagnostics (PR #128581)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 24 13:25:15 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Nikhil Kalra (nikalra)
<details>
<summary>Changes</summary>
Updated the Python diagnostics handler to emit notes (in addition to errors) into the output stream so that users have more context as to where in the IR the error is occurring.
To test this, I also updated the CAPI with an option to set `printStackTraceOnDiagnostic` so that notes are available in the diagnostic for the Python test.
---
Full diff: https://github.com/llvm/llvm-project/pull/128581.diff
5 Files Affected:
- (modified) mlir/include/mlir-c/IR.h (+5)
- (modified) mlir/include/mlir/Bindings/Python/Diagnostics.h (+20-11)
- (modified) mlir/lib/CAPI/IR/IR.cpp (+4)
- (modified) mlir/test/python/ir/diagnostic_handler.py (+14)
- (modified) mlir/test/python/lib/PythonTestModuleNanobind.cpp (+12)
``````````diff
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 14ccae650606a..f661e90105704 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -162,6 +162,11 @@ MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context,
MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context,
MlirLlvmThreadPool threadPool);
+/// Sets the context to attach the stack trace for the source code location at
+/// which a diagnostic is emitted.
+MLIR_CAPI_EXPORTED void
+mlirContextPrintStackTraceOnDiagnostic(MlirContext context, bool enable);
+
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Bindings/Python/Diagnostics.h b/mlir/include/mlir/Bindings/Python/Diagnostics.h
index ea80e14dde0f3..4f9be844dc1ac 100644
--- a/mlir/include/mlir/Bindings/Python/Diagnostics.h
+++ b/mlir/include/mlir/Bindings/Python/Diagnostics.h
@@ -9,12 +9,13 @@
#ifndef MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
#define MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H
-#include <cassert>
-#include <string>
-
#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
-#include "llvm/ADT/StringRef.h"
+
+#include <cassert>
+#include <cstdint>
+#include <sstream>
+#include <string>
namespace mlir {
namespace python {
@@ -28,29 +29,37 @@ class CollectDiagnosticsToStringScope {
/*deleteUserData=*/nullptr);
}
~CollectDiagnosticsToStringScope() {
- assert(errorMessage.empty() && "unchecked error message");
mlirContextDetachDiagnosticHandler(context, handlerID);
}
- [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); }
+ [[nodiscard]] std::string takeMessage() {
+ std::ostringstream stream;
+ std::swap(stream, errorMessage);
+ return stream.str();
+ }
private:
static MlirLogicalResult handler(MlirDiagnostic diag, void *data) {
auto printer = +[](MlirStringRef message, void *data) {
- *static_cast<std::string *>(data) +=
- llvm::StringRef(message.data, message.length);
+ *static_cast<std::ostringstream *>(data)
+ << std::string_view(message.data, message.length);
};
MlirLocation loc = mlirDiagnosticGetLocation(diag);
- *static_cast<std::string *>(data) += "at ";
+ *static_cast<std::ostringstream *>(data) << "at ";
mlirLocationPrint(loc, printer, data);
- *static_cast<std::string *>(data) += ": ";
+ *static_cast<std::ostringstream *>(data) << ": ";
mlirDiagnosticPrint(diag, printer, data);
+ for (intptr_t i = 0; i < mlirDiagnosticGetNumNotes(diag); i++) {
+ *static_cast<std::ostringstream *>(data) << "\n";
+ MlirDiagnostic note = mlirDiagnosticGetNote(diag, i);
+ handler(note, data);
+ }
return mlirLogicalResultSuccess();
}
MlirContext context;
MlirDiagnosticHandlerID handlerID;
- std::string errorMessage = "";
+ std::ostringstream errorMessage;
};
} // namespace python
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 999e8cbda1295..2249519ad4eef 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -114,6 +114,10 @@ void mlirContextSetThreadPool(MlirContext context,
unwrap(context)->setThreadPool(*unwrap(threadPool));
}
+void mlirContextPrintStackTraceOnDiagnostic(MlirContext context, bool enable) {
+ unwrap(context)->printStackTraceOnDiagnostic(enable);
+}
+
//===----------------------------------------------------------------------===//
// Dialect API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/diagnostic_handler.py b/mlir/test/python/ir/diagnostic_handler.py
index d516cda819897..5f6696850682a 100644
--- a/mlir/test/python/ir/diagnostic_handler.py
+++ b/mlir/test/python/ir/diagnostic_handler.py
@@ -2,6 +2,7 @@
import gc
from mlir.ir import *
+from mlir._mlir_libs._mlirPythonTestNanobind import test_diagnostics_with_errors_and_notes
def run(f):
@@ -222,3 +223,16 @@ def callback2(d):
# CHECK: CALLBACK2: foobar
# CHECK: CALLBACK1: foobar
loc.emit_error("foobar")
+
+# CHECK-LABEL: TEST: testBuiltInDiagnosticsHandler
+@run
+def testBuiltInDiagnosticsHandler():
+ ctx = Context()
+
+ try:
+ test_diagnostics_with_errors_and_notes(ctx)
+ except ValueError as e:
+ # CHECK: created error
+ # CHECK: MLIRPythonCAPI
+ print(e)
+
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index 99c81eae97a0c..daf3b4602b367 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -11,9 +11,12 @@
#include "PythonTestCAPI.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
+#include "mlir-c/Diagnostics.h"
#include "mlir-c/IR.h"
+#include "mlir/Bindings/Python/Diagnostics.h"
#include "mlir/Bindings/Python/Nanobind.h"
#include "mlir/Bindings/Python/NanobindAdaptors.h"
+#include "nanobind/nanobind.h"
namespace nb = nanobind;
using namespace mlir::python::nanobind_adaptors;
@@ -45,6 +48,15 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
},
nb::arg("registry"));
+ m.def("test_diagnostics_with_errors_and_notes", [](MlirContext ctx) {
+ mlirContextPrintStackTraceOnDiagnostic(ctx, true);
+ mlir::python::CollectDiagnosticsToStringScope handler(ctx);
+
+ auto loc = mlirLocationUnknownGet(ctx);
+ mlirEmitError(loc, "created error");
+ throw nb::value_error(handler.takeMessage().c_str());
+ });
+
mlir_attribute_subclass(m, "TestAttr",
mlirAttributeIsAPythonTestTestAttribute,
mlirPythonTestTestAttributeGetTypeID)
``````````
</details>
https://github.com/llvm/llvm-project/pull/128581
More information about the Mlir-commits
mailing list