[Mlir-commits] [mlir] [mlir][python] source line info (PR #149166)
Maksim Levental
llvmlistbot at llvm.org
Wed Jul 16 11:50:53 PDT 2025
https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/149166
None
>From bec65c38fae17e00627c1dcaf7d7c4fdb40dd660 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 16 Jul 2025 14:50:30 -0400
Subject: [PATCH] [mlir][python] source line info
---
mlir/lib/Bindings/Python/MainModule.cpp | 2 +
mlir/lib/Bindings/Python/Traceback.cpp | 443 ++++++++++++++++++++++++
mlir/lib/Bindings/Python/Traceback.h | 63 ++++
mlir/python/CMakeLists.txt | 4 +
mlir/python/mlir/source_info_util.py | 367 ++++++++++++++++++++
mlir/python/mlir/traceback_util.py | 238 +++++++++++++
6 files changed, 1117 insertions(+)
create mode 100644 mlir/lib/Bindings/Python/Traceback.cpp
create mode 100644 mlir/lib/Bindings/Python/Traceback.h
create mode 100644 mlir/python/mlir/source_info_util.py
create mode 100644 mlir/python/mlir/traceback_util.py
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 6f49431006605..489d8e21a56cd 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -12,6 +12,7 @@
#include "NanobindUtils.h"
#include "Pass.h"
#include "Rewrite.h"
+#include "Traceback.h"
#include "mlir/Bindings/Python/Nanobind.h"
namespace nb = nanobind;
@@ -105,6 +106,7 @@ NB_MODULE(_mlir, m) {
"typeid"_a, nb::kw_only(), "replace"_a = false,
"Register a value caster for casting MLIR values to custom user values.");
+ BuildTracebackSubmodule(m);
// Define and populate IR submodule.
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
populateIRCore(irModule);
diff --git a/mlir/lib/Bindings/Python/Traceback.cpp b/mlir/lib/Bindings/Python/Traceback.cpp
new file mode 100644
index 0000000000000..a3584f749046b
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Traceback.cpp
@@ -0,0 +1,443 @@
+/* Copyright 2020 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "Traceback.h"
+
+#include <Python.h>
+
+#include <array>
+#include <atomic>
+#include <bit>
+#include <cstddef>
+#include <cstring>
+#include <optional>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "nanobind/nanobind.h"
+#include "nanobind/stl/optional.h" // IWYU pragma: keep
+#include "nanobind/stl/string.h" // IWYU pragma: keep
+#include "nanobind/stl/string_view.h" // IWYU pragma: keep
+#include "nanobind/stl/vector.h" // IWYU pragma: keep
+
+#include "llvm/ADT/StringExtras.h"
+
+#ifdef PLATFORM_GOOGLE
+#define Py_BUILD_CORE
+#include "internal/pycore_frame.h"
+#undef Py_BUILD_CORE
+#endif // PLATFORM_GOOGLE
+
+namespace mlir::python {
+struct TracebackEntry;
+struct TracebackObject;
+} // namespace mlir::python
+namespace nb = nanobind;
+
+// std::hash specialization for TracebackObject. Only the call operator is
+// declared here; its definition appears at the end of this file.
+template <>
+struct std::hash<mlir::python::TracebackObject> {
+ std::size_t
+ operator()(const mlir::python::TracebackObject &tb) const noexcept;
+};
+
+// std::hash specialization for a single TracebackEntry. As above, the call
+// operator is defined at the end of this file.
+template <>
+struct std::hash<mlir::python::TracebackEntry> {
+ std::size_t
+ operator()(const mlir::python::TracebackEntry &tbe) const noexcept;
+};
+
+namespace mlir::python {
+
+std::atomic<bool> traceback_enabled_ = true;
+
+static constexpr int kMaxFrames = 512;
+
+PyTypeObject *traceback_type_ = nullptr;
+
+// Entry in a traceback. Must be POD.
+struct TracebackEntry {
+  // Code object of the frame, and the instruction offset that is later handed
+  // to PyCode_Addr2Line to recover the line number.
+  PyCodeObject *code;
+  int lasti;
+
+  TracebackEntry() = default;
+  TracebackEntry(PyCodeObject *code, int lasti) : code(code), lasti(lasti) {}
+
+  // Two entries are equal when they name the same code object at the same
+  // instruction offset.
+  bool operator==(const TracebackEntry &other) const {
+    return lasti == other.lasti && code == other.code;
+  }
+  bool operator!=(const TracebackEntry &other) const {
+    return !(*this == other);
+  }
+};
+static_assert(std::is_trivial_v<TracebackEntry> == true);
+
+// Python object layout for a traceback: a variable-size object header followed
+// by the captured frame entries. Py_SIZE() of the object is used throughout
+// this file as the number of valid entries in `frames`.
+struct TracebackObject {
+ PyObject_VAR_HEAD;
+ // Trailing array of entries; its storage is allocated together with the
+ // object (see the PyObject_NewVar call in Traceback::Get).
+ TracebackEntry frames[];
+};
+
+static_assert(sizeof(TracebackObject) % alignof(PyObject) == 0);
+static_assert(sizeof(TracebackEntry) % alignof(void *) == 0);
+
+// Returns true iff `o` is an instance of exactly the Traceback type created by
+// BuildTracebackSubmodule (an exact type match, not an isinstance check).
+bool traceback_check(nb::handle o) {
+  PyTypeObject *actualType = Py_TYPE(o.ptr());
+  return actualType == traceback_type_;
+}
+
+// tp_hash slot: hashes the traceback through std::hash<TracebackObject> and
+// maps the result onto a valid Python hash value.
+Py_hash_t traceback_tp_hash(PyObject *o) {
+  const TracebackObject &tb = *reinterpret_cast<TracebackObject *>(o);
+  const size_t unsignedHash = std::hash<TracebackObject>{}(tb);
+  // Python hashes are signed, so reinterpret the bits rather than convert.
+  const Py_hash_t signedHash = llvm::bit_cast<Py_hash_t>(unsignedHash);
+  // -1 must not be used as a Python hash value.
+  if (signedHash == -1) {
+    return -2;
+  }
+  return signedHash;
+}
+
+// tp_richcompare slot. Only == and != are supported; every other operator
+// yields NotImplemented. Two tracebacks are equal when they hold the same
+// number of entries and all entries compare equal pairwise.
+//
+// Returns a new reference to Py_True, Py_False or Py_NotImplemented.
+PyObject *traceback_tp_richcompare(PyObject *self, PyObject *other, int op) {
+  if (op != Py_EQ && op != Py_NE) {
+    return Py_NewRef(Py_NotImplemented);
+  }
+
+  // Translates "are the operands equal?" into the answer for the requested
+  // operator, so that == and != are always the negation of each other.
+  const auto answer = [op](bool equal) -> PyObject * {
+    return Py_NewRef((op == Py_EQ) == equal ? Py_True : Py_False);
+  };
+
+  // A non-Traceback is never equal to a Traceback. Previously this returned
+  // False for both operators, which made `tb != 1` evaluate to False.
+  if (!traceback_check(other)) {
+    return answer(false);
+  }
+  TracebackObject *tb_self = reinterpret_cast<TracebackObject *>(self);
+  TracebackObject *tb_other = reinterpret_cast<TracebackObject *>(other);
+  const Py_ssize_t size = Py_SIZE(tb_self);
+  if (size != Py_SIZE(tb_other)) {
+    return answer(false);
+  }
+  for (Py_ssize_t i = 0; i < size; ++i) {
+    if (tb_self->frames[i] != tb_other->frames[i]) {
+      return answer(false);
+    }
+  }
+  return answer(true);
+}
+
+// tp_dealloc slot: drops the reference held on every entry's code object,
+// frees the object storage, then releases the reference on its type.
+static void traceback_tp_dealloc(PyObject *self) {
+  auto *tb = reinterpret_cast<TracebackObject *>(self);
+  const Py_ssize_t numEntries = Py_SIZE(tb);
+  for (Py_ssize_t idx = 0; idx != numEntries; ++idx) {
+    Py_XDECREF(tb->frames[idx].code);
+  }
+  // Read the type before freeing: `self` must not be touched afterwards.
+  PyTypeObject *type = Py_TYPE(self);
+  type->tp_free(self);
+  Py_DECREF(type);
+}
+
+// Converts a raw (code, lasti) entry into a Traceback::Frame carrying the
+// file name, qualified function name, first line of the function and the
+// line number corresponding to `lasti`.
+//
+// Uses positional aggregate initialization: designated initializers are a
+// C++20 feature and are rejected by some compilers in C++17 mode. The order
+// below must match the member order of Traceback::Frame.
+Traceback::Frame DecodeFrame(const TracebackEntry &frame) {
+  return Traceback::Frame{
+      /*file_name=*/nb::borrow<nb::str>(frame.code->co_filename),
+      /*function_name=*/nb::borrow<nb::str>(frame.code->co_qualname),
+      /*function_start_line=*/frame.code->co_firstlineno,
+      /*line_num=*/PyCode_Addr2Line(frame.code, frame.lasti),
+  };
+}
+
+// Renders a traceback as text: one decoded frame per line, in the order the
+// entries are stored, separated by '\n'.
+std::string traceback_to_string(const TracebackObject *tb) {
+  const Py_ssize_t numFrames = Py_SIZE(tb);
+  std::vector<std::string> lines;
+  lines.reserve(numFrames);
+  for (Py_ssize_t idx = 0; idx < numFrames; ++idx) {
+    const Traceback::Frame decoded = DecodeFrame(tb->frames[idx]);
+    lines.push_back(decoded.ToString());
+  }
+  return llvm::join(lines, "\n");
+}
+
+// tp_str slot: returns a new reference to a Python str holding the textual
+// form produced by traceback_to_string.
+PyObject *traceback_tp_str(PyObject *self) {
+  const auto *tb = reinterpret_cast<const TracebackObject *>(self);
+  nb::object text = nb::cast(traceback_to_string(tb));
+  return text.release().ptr();
+}
+
+// It turns out to be slightly faster to define a tp_hash slot rather than
+// defining __hash__ and __eq__ on the class.
+// Slot table handed to PyType_FromSpec in BuildTracebackSubmodule. It wires
+// hashing, ==/!= comparison, deallocation and str() to the functions defined
+// above; the {0, nullptr} entry terminates the table.
+PyType_Slot traceback_slots_[] = {
+ {Py_tp_hash, reinterpret_cast<void *>(traceback_tp_hash)},
+ {Py_tp_richcompare, reinterpret_cast<void *>(traceback_tp_richcompare)},
+ {Py_tp_dealloc, reinterpret_cast<void *>(traceback_tp_dealloc)},
+ {Py_tp_str, reinterpret_cast<void *>(traceback_tp_str)},
+ {0, nullptr},
+};
+
+// Converts `tb` into a chain of real Python traceback objects (instances of
+// PyTraceBack_Type) built from synthetic code and frame objects.
+//
+// Entries are visited in storage order and each new traceback object takes the
+// previously built one as its tb_next. Returns None for an empty traceback.
+// Throws nb::python_error if any CPython call fails, rather than passing a
+// null pointer on to the next call.
+nb::object AsPythonTraceback(const Traceback &tb) {
+  nb::object traceback = nb::none();
+  nb::dict globals;
+  nb::handle traceback_type(reinterpret_cast<PyObject *>(&PyTraceBack_Type));
+  TracebackObject *tb_obj = reinterpret_cast<TracebackObject *>(tb.ptr());
+  for (Py_ssize_t i = 0; i < Py_SIZE(tb_obj); ++i) {
+    const TracebackEntry &frame = tb_obj->frames[i];
+    int lineno = PyCode_Addr2Line(frame.code, frame.lasti);
+    // Under Python 3.11 we observed crashes when using a fake PyFrameObject
+    // with a real PyCodeObject (https://github.com/google/jax/issues/16027).
+    // because the frame does not have fields necessary to compute the locals,
+    // notably the closure object, leading to crashes in CPython in
+    // _PyFrame_FastToLocalsWithError
+    // https://github.com/python/cpython/blob/deaf509e8fc6e0363bd6f26d52ad42f976ec42f2/Objects/frameobject.c#LL1116C2-L1116C2
+    // We therefore always build a fake code object to go along with our fake
+    // frame.
+    const char *file_name = PyUnicode_AsUTF8(frame.code->co_filename);
+    if (!file_name) {
+      throw nb::python_error();
+    }
+    const char *function_name = PyUnicode_AsUTF8(frame.code->co_name);
+    if (!function_name) {
+      throw nb::python_error();
+    }
+    PyCodeObject *py_code = PyCode_NewEmpty(file_name, function_name, lineno);
+    if (!py_code) {
+      throw nb::python_error();
+    }
+    PyFrameObject *py_frame = PyFrame_New(PyThreadState_Get(), py_code,
+                                          globals.ptr(), /*locals=*/nullptr);
+    // The frame holds its own reference to the code object.
+    Py_DECREF(py_code);
+    if (!py_frame) {
+      throw nb::python_error();
+    }
+
+    traceback = traceback_type(
+        /*tb_next=*/std::move(traceback),
+        /*tb_frame=*/
+        nb::steal<nb::object>(reinterpret_cast<PyObject *>(py_frame)),
+        /*tb_lasti=*/0,
+        /*tb_lineno=*/lineno);
+  }
+  return traceback;
+}
+
+// Decodes every stored entry into a Traceback::Frame, preserving order.
+std::vector<Traceback::Frame> Traceback::Frames() const {
+  // We require the GIL because we manipulate Python strings.
+  assert(PyGILState_Check());
+  TracebackObject *tb = reinterpret_cast<TracebackObject *>(ptr());
+  const Py_ssize_t numFrames = Py_SIZE(tb);
+  std::vector<Traceback::Frame> decoded;
+  decoded.reserve(numFrames);
+  for (Py_ssize_t idx = 0; idx < numFrames; ++idx) {
+    decoded.push_back(DecodeFrame(tb->frames[idx]));
+  }
+  return decoded;
+}
+
+// Formats the frame as "<file_name>:<line_num> (<function_name>)".
+std::string Traceback::Frame::ToString() const {
+  std::string out = nb::cast<std::string>(file_name);
+  out += ":";
+  out += std::to_string(line_num);
+  out += " (";
+  out += nb::cast<std::string>(function_name);
+  out += ")";
+  return out;
+}
+
+// Textual form of the whole traceback; see traceback_to_string.
+std::string Traceback::ToString() const {
+  const auto *self = reinterpret_cast<const TracebackObject *>(ptr());
+  return traceback_to_string(self);
+}
+
+// Returns the stored (code, lasti) pairs in order. The code pointers are
+// copied as-is; this function does not touch their reference counts.
+std::vector<std::pair<PyCodeObject *, int>> Traceback::RawFrames() const {
+  const auto *tb = reinterpret_cast<const TracebackObject *>(ptr());
+  const Py_ssize_t numFrames = Py_SIZE(tb);
+  std::vector<std::pair<PyCodeObject *, int>> pairs;
+  pairs.reserve(numFrames);
+  for (Py_ssize_t idx = 0; idx < numFrames; ++idx) {
+    const TracebackEntry &entry = tb->frames[idx];
+    pairs.emplace_back(entry.code, entry.lasti);
+  }
+  return pairs;
+}
+
+// Type predicate used by NB_OBJECT; forwards to the file-local exact-type test.
+/*static*/ bool Traceback::Check(PyObject *o) {
+  return traceback_check(nb::handle(o));
+}
+
+// Captures the Python stack of the calling thread.
+//
+// Returns std::nullopt when traceback collection is disabled. Otherwise walks
+// the current thread's frames, innermost first, recording at most kMaxFrames
+// (code, lasti) entries, and returns a new Traceback object that owns one
+// reference to each recorded code object.
+//
+// Throws nb::python_error if the Traceback object cannot be allocated; the
+// code references collected so far are released first.
+/*static*/ std::optional<Traceback> Traceback::Get() {
+  // We use a thread_local here mostly to avoid requiring a large amount of
+  // space.
+  thread_local std::array<TracebackEntry, kMaxFrames> frames;
+  int count = 0;
+
+  assert(PyGILState_Check());
+
+  if (!traceback_enabled_.load()) {
+    return std::nullopt;
+  }
+
+  PyThreadState *thread_state = PyThreadState_GET();
+
+#ifdef PLATFORM_GOOGLE
+// This code is equivalent to the version using public APIs, but it saves us
+// an allocation of one object per stack frame. However, this is definitely
+// violating the API contract of CPython, so we only use this where we can be
+// confident we know exactly which CPython we are using (internal to Google).
+// Feel free to turn this on if you like, but it might break at any time!
+#if PY_VERSION_HEX < 0x030d0000
+  for (_PyInterpreterFrame *f = thread_state->cframe->current_frame;
+       f != nullptr && count < kMaxFrames; f = f->previous) {
+    if (_PyFrame_IsIncomplete(f))
+      continue;
+    Py_INCREF(f->f_code);
+    frames[count] = {f->f_code, static_cast<int>(_PyInterpreterFrame_LASTI(f) *
+                                                 sizeof(_Py_CODEUNIT))};
+    ++count;
+  }
+#else  // PY_VERSION_HEX < 0x030d0000
+  for (_PyInterpreterFrame *f = thread_state->current_frame;
+       f != nullptr && count < kMaxFrames; f = f->previous) {
+    if (_PyFrame_IsIncomplete(f))
+      continue;
+    Py_INCREF(f->f_executable);
+    frames[count] = {
+        reinterpret_cast<PyCodeObject *>(f->f_executable),
+        static_cast<int>(_PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT))};
+    ++count;
+  }
+#endif // PY_VERSION_HEX < 0x030d0000
+
+#else  // PLATFORM_GOOGLE
+  PyFrameObject *next;
+  for (PyFrameObject *py_frame = PyThreadState_GetFrame(thread_state);
+       py_frame != nullptr && count < kMaxFrames; py_frame = next) {
+    frames[count] = {PyFrame_GetCode(py_frame), PyFrame_GetLasti(py_frame)};
+    ++count;
+    next = PyFrame_GetBack(py_frame);
+    Py_XDECREF(py_frame);
+  }
+#endif // PLATFORM_GOOGLE
+
+  PyObject *raw = PyObject_NewVar(PyObject, traceback_type_, count);
+  if (!raw) {
+    // Allocation failed: give back the references taken above instead of
+    // leaking them, then surface the pending Python error.
+    for (int i = 0; i < count; ++i) {
+      Py_XDECREF(frames[i].code);
+    }
+    throw nb::python_error();
+  }
+  Traceback traceback = nb::steal<Traceback>(raw);
+  TracebackObject *tb = reinterpret_cast<TracebackObject *>(traceback.ptr());
+  // Ownership of the code references moves into the new object; its dealloc
+  // slot releases them.
+  std::memcpy(tb->frames, frames.data(), sizeof(TracebackEntry) * count);
+  return traceback;
+}
+
+// Builds a Python `property` object whose getter wraps `get`; the setter and
+// deleter are None and the doc string is empty.
+template <typename Func>
+nanobind::object nb_property_readonly(Func &&get) {
+  nanobind::handle propertyType(reinterpret_cast<PyObject *>(&PyProperty_Type));
+  nanobind::object getter = nanobind::cpp_function(std::forward<Func>(get));
+  return propertyType(getter, nanobind::none(), nanobind::none(), "");
+}
+
+// Registers the traceback bindings on module `m`:
+//   * the `Frame` class (read-only fields plus __repr__),
+//   * the `Traceback` heap type, created with PyType_FromSpec and stored in
+//     traceback_type_,
+//   * `tracebacks_enabled` / `set_tracebacks_enabled` module functions,
+//   * the attributes attached to `Traceback`: get_traceback, frames,
+//     raw_frames, as_python_traceback, traceback_from_frames, code_addr2line
+//     and code_addr2location.
+//
+// Throws nb::python_error if the Traceback type cannot be created.
+void BuildTracebackSubmodule(nb::module_ &m) {
+  nb::class_<Traceback::Frame>(m, "Frame")
+      .def(nb::init<const nb::str &, const nb::str &, int, int>())
+      .def_ro("file_name", &Traceback::Frame::file_name)
+      .def_ro("function_name", &Traceback::Frame::function_name)
+      .def_ro("function_start_line", &Traceback::Frame::function_start_line)
+      .def_ro("line_num", &Traceback::Frame::line_num)
+      .def("__repr__", [](const Traceback::Frame &frame) {
+        std::string s = nb::cast<std::string>(frame.function_name);
+        s += ";" + nb::cast<std::string>(frame.file_name);
+        s += ":" + std::to_string(frame.line_num);
+        return s;
+      });
+
+  std::string name = nb::cast<std::string>(m.attr("__name__"));
+  name += ".Traceback";
+
+  PyType_Spec traceback_spec = {
+      /*.name=*/name.c_str(),
+      /*.basicsize=*/static_cast<int>(sizeof(TracebackObject)),
+      /*.itemsize=*/static_cast<int>(sizeof(TracebackEntry)),
+      /*.flags=*/Py_TPFLAGS_DEFAULT,
+      /*.slots=*/traceback_slots_,
+  };
+
+  traceback_type_ =
+      reinterpret_cast<PyTypeObject *>(PyType_FromSpec(&traceback_spec));
+  if (!traceback_type_) {
+    throw nb::python_error();
+  }
+
+  auto type = nb::borrow<nb::object>(traceback_type_);
+  m.attr("Traceback") = type;
+
+  m.def("tracebacks_enabled", []() { return traceback_enabled_.load(); });
+  m.def("set_tracebacks_enabled",
+        [](bool value) { traceback_enabled_.store(value); });
+
+  // The doc string previously referred to a nonexistent `Traceback.enabled`
+  // attribute and claimed collection was disabled by default; the flag is
+  // toggled through the module functions above and starts out enabled.
+  type.attr("get_traceback") = nb::cpp_function(Traceback::Get,
+                                                R"doc(
+    Returns a :class:`Traceback` for the current thread.
+
+    If traceback collection is enabled (see ``tracebacks_enabled`` and
+    ``set_tracebacks_enabled``), returns a :class:`Traceback` object that
+    describes the Python stack of the calling thread. Collection is enabled
+    by default. If traceback collection is disabled, returns ``None``. )doc");
+  type.attr("frames") = nb_property_readonly(&Traceback::Frames);
+  type.attr("raw_frames") = nb::cpp_function(
+      [](const Traceback &tb) -> nb::tuple {
+        // We return a tuple of lists, rather than a list of tuples, because it
+        // is cheaper to allocate only three Python objects for everything
+        // rather than one per frame.
+        std::vector<std::pair<PyCodeObject *, int>> frames = tb.RawFrames();
+        nb::list out_code = nb::steal<nb::list>(PyList_New(frames.size()));
+        nb::list out_lasti = nb::steal<nb::list>(PyList_New(frames.size()));
+        for (size_t i = 0; i < frames.size(); ++i) {
+          const auto &frame = frames[i];
+          PyObject *code = reinterpret_cast<PyObject *>(frame.first);
+          // PyList_SET_ITEM steals a reference, so take one for the list.
+          Py_INCREF(code);
+          PyList_SET_ITEM(out_code.ptr(), i, code);
+          PyList_SET_ITEM(out_lasti.ptr(), i,
+                          nb::int_(frame.second).release().ptr());
+        }
+        return nb::make_tuple(out_code, out_lasti);
+      },
+      nb::is_method());
+  type.attr("as_python_traceback") =
+      nb::cpp_function(AsPythonTraceback, nb::is_method());
+
+  type.attr("traceback_from_frames") = nb::cpp_function(
+      [](std::vector<Traceback::Frame> frames) {
+        nb::object traceback = nb::none();
+        nb::dict globals;
+        nb::handle traceback_type(
+            reinterpret_cast<PyObject *>(&PyTraceBack_Type));
+        for (const Traceback::Frame &frame : frames) {
+          PyCodeObject *py_code =
+              PyCode_NewEmpty(frame.file_name.c_str(),
+                              frame.function_name.c_str(), frame.line_num);
+          // Fail with the pending Python error instead of handing a null
+          // code object to PyFrame_New.
+          if (!py_code) {
+            throw nb::python_error();
+          }
+          PyFrameObject *py_frame = PyFrame_New(PyThreadState_Get(), py_code,
+                                                globals.ptr(), /*locals=*/
+                                                nullptr);
+          Py_DECREF(py_code);
+          if (!py_frame) {
+            throw nb::python_error();
+          }
+          traceback = traceback_type(
+              /*tb_next=*/std::move(traceback),
+              /*tb_frame=*/
+              nb::steal<nb::object>(reinterpret_cast<PyObject *>(py_frame)),
+              /*tb_lasti=*/0,
+              /*tb_lineno=*/
+              frame.line_num);
+        }
+        return traceback;
+      },
+      "Creates a traceback from a list of frames.");
+
+  type.attr("code_addr2line") = nb::cpp_function(
+      [](nb::handle code, int lasti) {
+        if (!PyCode_Check(code.ptr())) {
+          throw std::runtime_error("code argument must be a code object");
+        }
+        return PyCode_Addr2Line(reinterpret_cast<PyCodeObject *>(code.ptr()),
+                                lasti);
+      },
+      "Python wrapper around the Python C API function PyCode_Addr2Line");
+
+  type.attr("code_addr2location") = nb::cpp_function(
+      [](nb::handle code, int lasti) {
+        if (!PyCode_Check(code.ptr())) {
+          throw std::runtime_error("code argument must be a code object");
+        }
+        int start_line, start_column, end_line, end_column;
+        if (!PyCode_Addr2Location(reinterpret_cast<PyCodeObject *>(code.ptr()),
+                                  lasti, &start_line, &start_column, &end_line,
+                                  &end_column)) {
+          throw nb::python_error();
+        }
+        return nb::make_tuple(start_line, start_column, end_line, end_column);
+      },
+      "Python wrapper around the Python C API function PyCode_Addr2Location");
+}
+} // namespace mlir::python
+
+// Hashes a traceback from its length and the (code, lasti) fields of every
+// entry.
+//
+// The fields are hashed individually rather than as a raw byte range: the
+// entry holds a pointer followed by an int, so on typical 64-bit targets it
+// ends in padding bytes whose contents are unspecified. Hashing those bytes
+// could give two tracebacks that compare equal different hash values.
+std::size_t std::hash<mlir::python::TracebackObject>::operator()(
+    const mlir::python::TracebackObject &tb) const noexcept {
+  const Py_ssize_t length = Py_SIZE(&tb);
+  std::size_t seed = llvm::hash_combine(length);
+  for (Py_ssize_t i = 0; i < length; ++i) {
+    const mlir::python::TracebackEntry &entry = tb.frames[i];
+    seed = llvm::hash_combine(seed, entry.code, entry.lasti);
+  }
+  return seed;
+}
+
+// Hashes one entry from its code pointer and instruction offset.
+std::size_t std::hash<mlir::python::TracebackEntry>::operator()(
+    const mlir::python::TracebackEntry &tbe) const noexcept {
+  const llvm::hash_code combined = llvm::hash_combine(tbe.code, tbe.lasti);
+  return combined;
+}
\ No newline at end of file
diff --git a/mlir/lib/Bindings/Python/Traceback.h b/mlir/lib/Bindings/Python/Traceback.h
new file mode 100644
index 0000000000000..0aab15ddc5da3
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Traceback.h
@@ -0,0 +1,63 @@
+/* Copyright 2020 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_TRACEBACK_H_
+#define JAXLIB_TRACEBACK_H_
+
+#include <Python.h>
+
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+// placeholder for index annotation headers
+#include "nanobind/nanobind.h"
+
+namespace mlir::python {
+
+// nanobind object wrapper around the Python-level Traceback type registered by
+// BuildTracebackSubmodule. Instances are obtained from Get().
+class Traceback : public nanobind::object {
+public:
+ NB_OBJECT(Traceback, nanobind::object, "Traceback", Traceback::Check);
+
+ // Returns a traceback if it is enabled, otherwise returns nullopt.
+ static std::optional<Traceback> Get();
+
+ // Returns a string representation of the traceback.
+ std::string ToString() const;
+
+ // Returns a list of (code, lasti) pairs for each frame in the traceback.
+ std::vector<std::pair<PyCodeObject *, int>> RawFrames() const;
+
+ // Decoded, human-readable description of a single stack frame.
+ struct Frame {
+ // File containing the frame's code.
+ nanobind::str file_name;
+ // Name of the function the frame is executing.
+ nanobind::str function_name;
+ // First line of that function.
+ int function_start_line;
+ // Line being executed within the function.
+ int line_num;
+
+ // Returns a one-line textual form of the frame.
+ std::string ToString() const;
+ };
+ // Returns a list of Frames for the traceback.
+ std::vector<Frame> Frames() const;
+
+private:
+ // Type predicate handed to NB_OBJECT above.
+ static bool Check(PyObject *o);
+};
+
+void BuildTracebackSubmodule(nanobind::module_ &m);
+
+} // namespace mlir::python
+
+#endif // JAXLIB_TRACEBACK_H_
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 7a0c95ebb8200..243fbe64de900 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -23,6 +23,8 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
passmanager.py
rewrite.py
dialects/_ods_common.py
+ source_info_util.py
+ traceback_util.py
# The main _mlir module has submodules: include stubs from each.
_mlir_libs/_mlir/__init__.pyi
@@ -486,6 +488,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
IRTypes.cpp
Pass.cpp
Rewrite.cpp
+ Traceback.cpp
# Headers must be included explicitly so they are installed.
Globals.h
@@ -493,6 +496,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
Pass.h
NanobindUtils.h
Rewrite.h
+ Traceback.h
PRIVATE_LINK_LIBS
LLVMSupport
EMBED_CAPI_LINK_LIBS
diff --git a/mlir/python/mlir/source_info_util.py b/mlir/python/mlir/source_info_util.py
new file mode 100644
index 0000000000000..ff7ee09149c60
--- /dev/null
+++ b/mlir/python/mlir/source_info_util.py
@@ -0,0 +1,367 @@
+# Copyright 2020 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Iterator
+import contextlib
+import dataclasses
+import functools
+import itertools
+import os.path
+import re
+import sysconfig
+import threading
+import types
+from typing import NamedTuple
+
+# import jax.version
+# from jax._src.lib import xla_client
+
+from . import traceback_util
+
+traceback_util.register_exclusion(__file__)
+
+from ._mlir_libs._mlir import Traceback
+
+
+class Frame(NamedTuple):
+ file_name: str
+ function_name: str
+ start_line: int
+ start_column: int
+ end_line: int
+ end_column: int
+
+
+_exclude_paths: list[str] = [
+ # Attach the separator to make sure that .../jax does not end up matching
+ # .../jax_triton and other packages that might have a jax prefix.
+ # os.path.dirname(jax.version.__file__) + os.sep,
+ # Also exclude stdlib as user frames. In a non-standard Python runtime,
+ # the following two may be different.
+ sysconfig.get_path("stdlib"),
+ os.path.dirname(sysconfig.__file__),
+]
+
+
+ at functools.cache
+def _exclude_path_regex() -> re.Pattern[str]:
+ # The regex below would not handle an empty set of exclusions correctly.
+ assert len(_exclude_paths) > 0
+ return re.compile("|".join(f"^{re.escape(path)}" for path in _exclude_paths))
+
+
+def register_exclusion(path: str):
+ _exclude_paths.append(path)
+ _exclude_path_regex.cache_clear()
+ is_user_filename.cache_clear()
+
+
+# Explicit inclusions take priority over exclude paths.
+_include_paths: list[str] = []
+
+
+ at functools.cache
+def _include_path_regex() -> re.Pattern[str]:
+ patterns = [f"^{re.escape(path)}" for path in _include_paths]
+ patterns.append("_test.py$")
+ return re.compile("|".join(patterns))
+
+
+def register_inclusion(path: str):
+ _include_paths.append(path)
+ _include_path_regex.cache_clear()
+ is_user_filename.cache_clear()
+
+
+class Scope(NamedTuple):
+ name: str
+
+ def wrap(self, stack: list[str]):
+ stack.append(self.name)
+
+
+class Transform(NamedTuple):
+ name: str
+
+ def wrap(self, stack: list[str]):
+ if stack:
+ stack[-1] = f"{self.name}({stack[-1]})"
+ else:
+ stack.append(f"{self.name}()")
+
+
+ at dataclasses.dataclass(frozen=True)
+class NameStack:
+ stack: tuple[Scope | Transform, ...] = ()
+
+ def extend(self, name: str) -> NameStack:
+ return NameStack((*self.stack, Scope(name)))
+
+ def transform(self, transform_name: str) -> NameStack:
+ return NameStack((*self.stack, Transform(transform_name)))
+
+ def __getitem__(self, idx: slice) -> NameStack:
+ return NameStack(self.stack[idx])
+
+ def __len__(self):
+ return len(self.stack)
+
+ def __add__(self, other: NameStack) -> NameStack:
+ return NameStack(self.stack + other.stack)
+
+ def __radd__(self, other: NameStack) -> NameStack:
+ return NameStack(other.stack + self.stack)
+
+ def __str__(self) -> str:
+ scope: list[str] = []
+ for elem in self.stack[::-1]:
+ elem.wrap(scope)
+ return "/".join(reversed(scope))
+
+
+def new_name_stack(name: str = "") -> NameStack:
+ name_stack = NameStack()
+ if name:
+ name_stack = name_stack.extend(name)
+ return name_stack
+
+
+class SourceInfo:
+ traceback: Traceback | None
+ name_stack: NameStack
+
+ # It's slightly faster to use a class with __slots__ than a NamedTuple.
+ __slots__ = ["traceback", "name_stack"]
+
+ def __init__(self, traceback: Traceback | None, name_stack: NameStack):
+ self.traceback = traceback
+ self.name_stack = name_stack
+
+ def replace(
+ self, *, traceback: Traceback | None = None, name_stack: NameStack | None = None
+ ) -> SourceInfo:
+ return SourceInfo(
+ self.traceback if traceback is None else traceback,
+ self.name_stack if name_stack is None else name_stack,
+ )
+
+
+def new_source_info() -> SourceInfo:
+ return SourceInfo(None, NameStack())
+
+
+ at functools.cache
+def is_user_filename(filename: str) -> bool:
+ """Heuristic that guesses the identity of the user's code in a stack trace."""
+ return (
+ _include_path_regex().search(filename) is not None
+ or _exclude_path_regex().search(filename) is None
+ )
+
+
+def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame:
+ loc = Traceback.code_addr2location(code, lasti)
+ start_line, start_column, end_line, end_column = loc
+ return Frame(
+ file_name=code.co_filename,
+ function_name=code.co_qualname,
+ start_line=start_line,
+ start_column=start_column,
+ end_line=end_line,
+ end_column=end_column,
+ )
+
+
+def user_frames(traceback: Traceback | None) -> Iterator[Frame]:
+ """Iterator over the user's frames, filtering jax-internal frames."""
+ # Guess the user's frame is the innermost frame not in the jax source tree or
+ # Python stdlib. We don't use traceback_util.path_starts_with because that
+ # incurs filesystem access, which may be slow; we call this function when
+ # e.g. adding source provenance annotations to XLA lowerings, so we don't
+ # want to incur the cost. We consider files that end with _test.py as user
+ # frames, to allow testing this mechanism from tests.
+ code, lasti = traceback.raw_frames() if traceback else ([], [])
+ return (
+ raw_frame_to_frame(code[i], lasti[i])
+ for i in range(len(code))
+ if is_user_filename(code[i].co_filename)
+ )
+
+
+ at functools.lru_cache(maxsize=64)
+def user_frame(traceback: Traceback | None) -> Frame | None:
+ return next(user_frames(traceback), None)
+
+
+def _summarize_frame(frame: Frame) -> str:
+ if frame.start_column != 0:
+ return (
+ f"{frame.file_name}:{frame.start_line}:{frame.start_column} "
+ f"({frame.function_name})"
+ )
+ else:
+ return f"{frame.file_name}:{frame.start_line} ({frame.function_name})"
+
+
+def summarize(source_info: SourceInfo, num_frames=1) -> str:
+ frames = itertools.islice(user_frames(source_info.traceback), num_frames)
+ frame_strs = [_summarize_frame(frame) if frame else "unknown" for frame in frames]
+ return "\n".join(reversed(frame_strs))
+
+
+class _SourceInfoContext(threading.local):
+ context: SourceInfo
+
+ def __init__(self):
+ self.context = new_source_info()
+
+
+_source_info_context = _SourceInfoContext()
+
+
+def current() -> SourceInfo:
+ source_info = _source_info_context.context
+ if not source_info.traceback:
+ source_info = source_info.replace(traceback=Traceback.get_traceback())
+ return source_info
+
+
+class JaxStackTraceBeforeTransformation(Exception):
+ pass
+
+
+_message = (
+ "The preceding stack trace is the source of the JAX operation that, once "
+ "transformed by JAX, triggered the following exception.\n"
+ "\n--------------------"
+)
+
+
+def has_user_context(e):
+ while e is not None:
+ if isinstance(e, JaxStackTraceBeforeTransformation):
+ return True
+ e = e.__cause__
+ return False
+
+
+class UserContextManager:
+ __slots__ = ["traceback", "name_stack", "prev"]
+
+ def __init__(
+ self, traceback: Traceback | None, *, name_stack: NameStack | None = None
+ ):
+ self.traceback = traceback
+ self.name_stack = name_stack
+
+ def __enter__(self):
+ self.prev = _source_info_context.context
+ _source_info_context.context = _source_info_context.context.replace(
+ traceback=self.traceback, name_stack=self.name_stack
+ )
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ _source_info_context.context = self.prev
+ if exc_type is None or exc_value is None:
+ return
+
+ if self.traceback is None or has_user_context(exc_value):
+ return
+
+ filtered_tb = traceback_util.filter_traceback(
+ self.traceback.as_python_traceback()
+ )
+ if filtered_tb:
+ msg = traceback_util.format_exception_only(exc_value)
+ msg = f"{msg}\n\n{_message}"
+ exp = JaxStackTraceBeforeTransformation(msg).with_traceback(filtered_tb)
+ exp.__context__ = exc_value.__context__
+ exp.__cause__ = exc_value.__cause__
+ exp.__suppress_context__ = exc_value.__suppress_context__
+ exc_value.__context__ = None
+ exc_value.__cause__ = exp
+
+
+user_context = UserContextManager
+
+
+def current_name_stack() -> NameStack:
+ return _source_info_context.context.name_stack
+
+
+class ExtendNameStackContextManager(contextlib.ContextDecorator):
+ __slots__ = ["name", "prev"]
+
+ def __init__(self, name: str):
+ self.name = name
+
+ def __enter__(self):
+ self.prev = prev = _source_info_context.context
+ name_stack = prev.name_stack.extend(self.name)
+ _source_info_context.context = prev.replace(name_stack=name_stack)
+ return name_stack
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ _source_info_context.context = self.prev
+
+
+extend_name_stack = ExtendNameStackContextManager
+
+
+class SetNameStackContextManager(contextlib.ContextDecorator):
+ __slots__ = ["name_stack", "prev"]
+
+ def __init__(self, name_stack: NameStack):
+ self.name_stack = name_stack
+
+ def __enter__(self):
+ self.prev = prev = _source_info_context.context
+ _source_info_context.context = prev.replace(name_stack=self.name_stack)
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ _source_info_context.context = self.prev
+
+
+set_name_stack = SetNameStackContextManager
+
+
+# TODO(mattjj,phawkins): figure out why the commented-out reset_name_stack
+# implementation doesn't work. Luckily this context manager isn't called much so
+# the performance shouldn't matter. See blame commit message for repro.
+# reset_name_stack = lambda: SetNameStackContextManager(NameStack())
+ at contextlib.contextmanager
+def reset_name_stack() -> Iterator[None]:
+ with set_name_stack(NameStack()):
+ yield
+
+
+class TransformNameStackContextManager(contextlib.ContextDecorator):
+ __slots__ = ["name", "prev"]
+
+ def __init__(self, name: str):
+ self.name = name
+
+ def __enter__(self):
+ self.prev = prev = _source_info_context.context
+ name_stack = prev.name_stack.transform(self.name)
+ _source_info_context.context = prev.replace(name_stack=name_stack)
+ return name_stack
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ _source_info_context.context = self.prev
+
+
+transform_name_stack = TransformNameStackContextManager
diff --git a/mlir/python/mlir/traceback_util.py b/mlir/python/mlir/traceback_util.py
new file mode 100644
index 0000000000000..1794fcd1d08ba
--- /dev/null
+++ b/mlir/python/mlir/traceback_util.py
@@ -0,0 +1,238 @@
+# Copyright 2020 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Callable
+import functools
+import os
+import traceback
+import types
+from typing import Any, TypeVar, cast
+
+
+# Type variable for callables; ``api_boundary`` uses it to declare that it
+# returns the same callable type it was given.
+C = TypeVar("C", bound=Callable[..., Any])
+
+# Path prefixes whose frames are treated as excluded by ``include_filename``.
+# Seeded with this module's own file; extended through ``register_exclusion``.
+_exclude_paths: list[str] = [__file__]
+
+
+def register_exclusion(path: str):
+    """Add ``path`` to the list of path prefixes whose frames are filtered out.
+
+    Args:
+      path: file or directory path; frames whose source file lies at or below
+        it are rejected by ``include_filename``.
+    """
+    _exclude_paths.append(path)
+
+
+# Text appended (after a blank line) to the formatted exception when
+# ``api_boundary`` builds the ``UnfilteredStackTrace`` message in
+# "remove_frames" mode.
+_jax_message_append = (
+ "The stack trace below excludes JAX-internal frames.\n"
+ "The preceding is the original exception that occurred, unmodified.\n"
+ "\n--------------------"
+)
+
+
+def _path_starts_with(path: str, path_prefix: str) -> bool:
+    """Report whether ``path`` lies at or below ``path_prefix``.
+
+    Both arguments are made absolute first. The answer is True when their
+    common path equals the prefix, either textually or because both name the
+    same filesystem entry.
+
+    Args:
+      path: path to classify.
+      path_prefix: candidate ancestor (or identical) path.
+
+    Returns:
+      True if ``path`` is under ``path_prefix``; False otherwise, including
+      when the comparison cannot be carried out.
+    """
+    abs_path = os.path.abspath(path)
+    abs_prefix = os.path.abspath(path_prefix)
+    try:
+        shared = os.path.commonpath([abs_path, abs_prefix])
+    except ValueError:
+        # Both inputs are absolute, so the only remaining cause of ValueError
+        # is paths on different drives.
+        # https://docs.python.org/3/library/os.path.html#os.path.commonpath
+        return False
+    if shared == abs_prefix:
+        return True
+    try:
+        return os.path.samefile(shared, abs_prefix)
+    except OSError:
+        # samefile() needs both paths to exist on disk.
+        return False
+
+
+def include_frame(f: types.FrameType) -> bool:
+    """Return True if frame ``f`` should be kept, judged by its source filename."""
+    code = f.f_code
+    return include_filename(code.co_filename)
+
+
+def include_filename(filename: str) -> bool:
+    """Return True unless ``filename`` falls under a registered exclusion path."""
+    for excluded in _exclude_paths:
+        if _path_starts_with(filename, excluded):
+            return False
+    return True
+
+
+def _ignore_known_hidden_frame(f: types.FrameType) -> bool:
+    """Return True for frames that CPython itself omits from printed traces.
+
+    Stack scans can run into interpreter frames (for example from parts of
+    importlib) that never show up in a printed stack trace. They are detected
+    heuristically: the frame's code filename mentions ``importlib._bootstrap``.
+    """
+    return f.f_code.co_filename.find("importlib._bootstrap") != -1
+
+
+def _add_tracebackhide_to_hidden_frames(tb: types.TracebackType):
+    """Set ``__tracebackhide__ = True`` in the locals of every excluded frame of ``tb``."""
+    excluded = (
+        frame for frame, _ in traceback.walk_tb(tb) if not include_frame(frame)
+    )
+    for frame in excluded:
+        frame.f_locals["__tracebackhide__"] = True
+
+
+def filter_traceback(tb: types.TracebackType) -> types.TracebackType | None:
+    """Rebuild ``tb`` keeping only the frames accepted by ``include_frame``.
+
+    Returns:
+      A new traceback chain with the kept frames in their original order, or
+      None when no frame is kept.
+    """
+    walked = list(traceback.walk_tb(tb))
+    walked.reverse()
+    rebuilt = None
+    # Each new node points at the previously built one as its ``tb_next``, so
+    # the chain is assembled from the innermost frame outwards.
+    for frame, line in walked:
+        if not include_frame(frame):
+            continue
+        rebuilt = types.TracebackType(rebuilt, frame, frame.f_lasti, line)
+    return rebuilt
+
+
+def _add_call_stack_frames(tb: types.TracebackType) -> types.TracebackType:
+    """Prepend frames from the call stack above ``tb.tb_frame`` to ``tb``.
+
+    The walk continues up the call stack but should not go too far, e.g. past
+    the exec/eval point of a REPL such as IPython. It therefore stops just
+    past the first contiguous run of module-level frames, if any are reached.
+    This is a heuristic and may stop before the REPL boundary: if the stack
+    holds module-level frames of a module A that was imported from inside a
+    function F elsewhere, the resulting trace is cut off at F's frame.
+    """
+    extended = tb
+    seen_module_frame = False
+    for frame, line in traceback.walk_stack(tb.tb_frame):
+        if _ignore_known_hidden_frame(frame):
+            continue
+        at_module_level = frame.f_code.co_name == "<module>"
+        if seen_module_frame and not at_module_level:
+            # First non-module frame after the module-level run: stop here.
+            break
+        if include_frame(frame):
+            extended = types.TracebackType(extended, frame, frame.f_lasti, line)
+        seen_module_frame = seen_module_frame or at_module_level
+    return extended
+
+
+def _is_reraiser_frame(f: traceback.FrameSummary) -> bool:
+    """Return True if ``f`` is this module's ``reraise_with_filtered_traceback`` frame."""
+    if f.filename != __file__:
+        return False
+    return f.name == "reraise_with_filtered_traceback"
+
+
+def _is_under_reraiser(e: BaseException) -> bool:
+    """Return True if an outer ``reraise_with_filtered_traceback`` frame is on the stack.
+
+    The stack is taken from the frame at the head of ``e.__traceback__``; that
+    frame itself (the last extracted entry) is not considered. An exception
+    without a traceback yields False.
+    """
+    head = e.__traceback__
+    if head is None:
+        return False
+    callers = traceback.extract_stack(head.tb_frame)[:-1]
+    for summary in callers:
+        if _is_reraiser_frame(summary):
+            return True
+    return False
+
+
+def format_exception_only(e: BaseException) -> str:
+    """Return the ``Type: message`` rendering of ``e`` with outer whitespace stripped."""
+    pieces = traceback.format_exception_only(type(e), e)
+    return "".join(pieces).strip()
+
+
+class UnfilteredStackTrace(Exception):
+    """Exception that ``api_boundary`` attaches as ``__cause__`` in "remove_frames" mode."""
+
+
+# Message attached as a note by ``api_boundary`` in "quiet_remove_frames" mode
+# and returned by ``SimplifiedTraceback.__str__``.
+_simplified_tb_msg = (
+ "For simplicity, JAX has removed its internal frames from the "
+ "traceback of the following exception. Set "
+ "JAX_TRACEBACK_FILTERING=off to include these."
+)
+
+
+class SimplifiedTraceback(Exception):
+    """Exception whose string form is always the fixed simplified-traceback message."""
+
+    # Reported module name for this class.
+    # NOTE(review): "jax.errors" looks inherited from the JAX original - confirm
+    # it is still the intended value in this package.
+    __module__ = "jax.errors"
+
+    def __str__(self):
+        return _simplified_tb_msg
+
+
+def _running_under_ipython() -> bool:
+    """Returns true if we appear to be in an IPython session."""
+    try:
+        # Calling the name fails with NameError when it is not defined.
+        get_ipython()  # type: ignore
+    except NameError:
+        return False
+    return True
+
+
+def _ipython_supports_tracebackhide() -> bool:
+    """Returns true if the IPython version supports __tracebackhide__."""
+    import IPython  # pytype: disable=import-error
+
+    # 7.17 is the threshold used here for __tracebackhide__ support.
+    major_minor = IPython.version_info[:2]
+    return major_minor >= (7, 17)
+
+
+def _filtering_mode() -> str:
+    """Pick the traceback filtering mode used by ``api_boundary``.
+
+    Returns:
+      "tracebackhide" when running under an IPython that supports
+      ``__tracebackhide__``, otherwise "quiet_remove_frames".
+    """
+    # No configuration source is consulted here: the requested mode starts out
+    # unset, so the automatic choice below is always the one taken.
+    requested = None
+    if requested is not None and requested != "auto":
+        return requested
+    if _running_under_ipython() and _ipython_supports_tracebackhide():
+        return "tracebackhide"
+    return "quiet_remove_frames"
+
+
+def api_boundary(fun: C) -> C:
+    """Wraps ``fun`` to form a boundary for filtering exception tracebacks.
+
+    When an exception occurs below ``fun``, this appends to it a custom
+    ``__cause__`` that carries a filtered traceback. The traceback imitates the
+    stack trace of the original exception, but with JAX-internal frames removed.
+
+    This boundary annotation works in composition with itself. The topmost frame
+    corresponding to an :func:`~api_boundary` is the one below which stack traces
+    are filtered. In other words, if ``api_boundary(f)`` calls
+    ``api_boundary(g)``, directly or indirectly, the filtered stack trace provided
+    is the same as if ``api_boundary(f)`` were to simply call ``g`` instead.
+
+    This annotation is primarily useful in wrapping functions output by JAX's
+    transformations. For example, consider ``g = jax.jit(f)``. When ``g`` is
+    called, JAX's JIT compilation machinery is invoked, which in turn calls ``f``
+    in order to trace and translate it. If the function ``f`` raises an exception,
+    the stack unwinds through JAX's JIT internals up to the original call site of
+    ``g``. Because the function returned by :func:`~jax.jit` is annotated as an
+    :func:`~api_boundary`, such an exception is accompanied by an additional
+    traceback that excludes the frames specific to JAX's implementation.
+
+    Behaviour per filtering mode (as returned by ``_filtering_mode``):
+
+    * "off", or an outer boundary is already on the stack: re-raise untouched.
+    * "tracebackhide": flag excluded frames with ``__tracebackhide__`` and
+      re-raise.
+    * "quiet_remove_frames": replace the traceback with the filtered one and,
+      where the interpreter provides ``add_note`` (Python 3.11+), attach an
+      explanatory note.
+    * "remove_frames": replace the traceback with the filtered one and chain
+      an ``UnfilteredStackTrace`` as ``__cause__``.
+    * anything else: ``ValueError``.
+
+    Args:
+      fun: the callable to wrap.
+
+    Returns:
+      A wrapper with the same call signature as ``fun``.
+    """
+
+    @functools.wraps(fun)
+    def reraise_with_filtered_traceback(*args, **kwargs):
+        # Flags this wrapper's own frame for tools that honour the marker.
+        __tracebackhide__ = True
+        try:
+            return fun(*args, **kwargs)
+        except Exception as e:
+            mode = _filtering_mode()
+            if _is_under_reraiser(e) or mode == "off":
+                # Either an outer boundary will do the filtering, or it is
+                # disabled: leave the exception untouched.
+                raise
+            if mode == "tracebackhide":
+                _add_tracebackhide_to_hidden_frames(e.__traceback__)
+                raise
+
+            filtered_tb = None
+            try:
+                tb = e.__traceback__
+                filtered_tb = filter_traceback(tb)
+                e.with_traceback(filtered_tb)
+                if mode == "quiet_remove_frames":
+                    # add_note() only exists on Python 3.11+. Calling it
+                    # unconditionally would raise AttributeError on older
+                    # interpreters and mask the user's exception, so the note
+                    # is skipped there.
+                    add_note = getattr(e, "add_note", None)
+                    if callable(add_note):
+                        add_note("--------------------\n" + _simplified_tb_msg)
+                else:
+                    if mode == "remove_frames":
+                        msg = format_exception_only(e)
+                        msg = f"{msg}\n\n{_jax_message_append}"
+                        jax_error = UnfilteredStackTrace(msg)
+                        jax_error.with_traceback(_add_call_stack_frames(tb))
+                    else:
+                        raise ValueError(
+                            f"JAX_TRACEBACK_FILTERING={mode} is not a valid value."
+                        )
+                    # Splice jax_error in between ``e`` and whatever ``e`` was
+                    # previously chained to.
+                    jax_error.__cause__ = e.__cause__
+                    jax_error.__context__ = e.__context__
+                    jax_error.__suppress_context__ = e.__suppress_context__
+                    e.__cause__ = jax_error
+                    e.__context__ = None
+                raise
+            finally:
+                # Drop local references held by this handler's frame.
+                del filtered_tb
+                del mode
+
+    return cast(C, reraise_with_filtered_traceback)
More information about the Mlir-commits
mailing list