[Mlir-commits] [mlir] [mlir][python] source line info (PR #149166)

Maksim Levental llvmlistbot at llvm.org
Wed Jul 16 11:50:53 PDT 2025


https://github.com/makslevental created https://github.com/llvm/llvm-project/pull/149166

None

>From bec65c38fae17e00627c1dcaf7d7c4fdb40dd660 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 16 Jul 2025 14:50:30 -0400
Subject: [PATCH] [mlir][python] source line info

---
 mlir/lib/Bindings/Python/MainModule.cpp |   2 +
 mlir/lib/Bindings/Python/Traceback.cpp  | 443 ++++++++++++++++++++++++
 mlir/lib/Bindings/Python/Traceback.h    |  63 ++++
 mlir/python/CMakeLists.txt              |   4 +
 mlir/python/mlir/source_info_util.py    | 367 ++++++++++++++++++++
 mlir/python/mlir/traceback_util.py      | 238 +++++++++++++
 6 files changed, 1117 insertions(+)
 create mode 100644 mlir/lib/Bindings/Python/Traceback.cpp
 create mode 100644 mlir/lib/Bindings/Python/Traceback.h
 create mode 100644 mlir/python/mlir/source_info_util.py
 create mode 100644 mlir/python/mlir/traceback_util.py

diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 6f49431006605..489d8e21a56cd 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -12,6 +12,7 @@
 #include "NanobindUtils.h"
 #include "Pass.h"
 #include "Rewrite.h"
+#include "Traceback.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 
 namespace nb = nanobind;
@@ -105,6 +106,7 @@ NB_MODULE(_mlir, m) {
       "typeid"_a, nb::kw_only(), "replace"_a = false,
       "Register a value caster for casting MLIR values to custom user values.");
 
+  BuildTracebackSubmodule(m);
   // Define and populate IR submodule.
   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
   populateIRCore(irModule);
diff --git a/mlir/lib/Bindings/Python/Traceback.cpp b/mlir/lib/Bindings/Python/Traceback.cpp
new file mode 100644
index 0000000000000..a3584f749046b
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Traceback.cpp
@@ -0,0 +1,443 @@
+/* Copyright 2020 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "Traceback.h"
+
+#include <Python.h>
+
+#include <array>
+#include <atomic>
+#include <bit>
+#include <cstddef>
+#include <cstring>
+#include <optional>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "nanobind/nanobind.h"
+#include "nanobind/stl/optional.h"    // IWYU pragma: keep
+#include "nanobind/stl/string.h"      // IWYU pragma: keep
+#include "nanobind/stl/string_view.h" // IWYU pragma: keep
+#include "nanobind/stl/vector.h"      // IWYU pragma: keep
+
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/StringExtras.h"
+
+#ifdef PLATFORM_GOOGLE
+#define Py_BUILD_CORE
+#include "internal/pycore_frame.h"
+#undef Py_BUILD_CORE
+#endif // PLATFORM_GOOGLE
+
+// Forward declarations: the std::hash specializations below have to be
+// declared before their first use (traceback_tp_hash), while their
+// definitions live at the end of this file.
+namespace mlir::python {
+struct TracebackEntry;
+struct TracebackObject;
+} // namespace mlir::python
+namespace nb = nanobind;
+
+// Hash of a whole traceback (its length and entries); used by tp_hash.
+template <>
+struct std::hash<mlir::python::TracebackObject> {
+  std::size_t
+  operator()(const mlir::python::TracebackObject &tb) const noexcept;
+};
+
+// Hash of a single (code, lasti) entry.
+template <>
+struct std::hash<mlir::python::TracebackEntry> {
+  std::size_t
+  operator()(const mlir::python::TracebackEntry &tbe) const noexcept;
+};
+
+namespace mlir::python {
+
+// Global switch read by Traceback::Get() and toggled from Python through
+// set_tracebacks_enabled(); collection is on by default.
+// These globals are only used in this file, so give them internal linkage.
+static std::atomic<bool> traceback_enabled_{true};
+
+// Maximum number of Python frames recorded in a single traceback.
+static constexpr int kMaxFrames = 512;
+
+// The Traceback heap type; set by BuildTracebackSubmodule().
+static PyTypeObject *traceback_type_ = nullptr;
+
+// One captured stack frame: the code object that was running plus the
+// bytecode offset (lasti) within it. Arrays of entries are memcpy'd into a
+// TracebackObject, so the type must stay trivial.
+struct TracebackEntry {
+  TracebackEntry() = default;
+  TracebackEntry(PyCodeObject *codeObj, int lastInstr)
+      : code(codeObj), lasti(lastInstr) {}
+
+  PyCodeObject *code;
+  int lasti;
+
+  bool operator==(const TracebackEntry &rhs) const {
+    return lasti == rhs.lasti && code == rhs.code;
+  }
+  bool operator!=(const TracebackEntry &rhs) const { return !(*this == rhs); }
+};
+static_assert(std::is_trivial_v<TracebackEntry>,
+              "TracebackEntry is memcpy'd and must remain trivial");
+
+// Memory layout of a Traceback instance: a variable-size Python object whose
+// Py_SIZE is the number of entries in the trailing `frames` array. Each entry
+// owns a reference to its code object, released in traceback_tp_dealloc.
+struct TracebackObject {
+  PyObject_VAR_HEAD;
+  // Flexible array member (a compiler extension in C++); storage comes from
+  // PyObject_NewVar using the type's itemsize.
+  TracebackEntry frames[];
+};
+
+// The frames array must be properly aligned right after the object header.
+static_assert(sizeof(TracebackObject) % alignof(PyObject) == 0);
+static_assert(sizeof(TracebackEntry) % alignof(void *) == 0);
+
+// Exact type check: true iff `o` is an instance of the Traceback type.
+bool traceback_check(nb::handle o) {
+  PyTypeObject *actual = Py_TYPE(o.ptr());
+  return actual == traceback_type_;
+}
+
+// tp_hash slot: hashes the frames with std::hash<TracebackObject>.
+Py_hash_t traceback_tp_hash(PyObject *o) {
+  const auto *tb = reinterpret_cast<TracebackObject *>(o);
+  const size_t unsignedHash = std::hash<TracebackObject>{}(*tb);
+  // Python hashes are signed; reinterpret the bits.
+  const Py_hash_t signedHash = llvm::bit_cast<Py_hash_t>(unsignedHash);
+  // CPython reserves -1 as the error return of tp_hash.
+  if (signedHash == -1)
+    return -2;
+  return signedHash;
+}
+
+// tp_richcompare slot. Only == and != are supported: two Tracebacks are equal
+// iff they hold the same number of frames with pairwise-identical
+// (code, lasti) entries.
+PyObject *traceback_tp_richcompare(PyObject *self, PyObject *other, int op) {
+  if (op != Py_EQ && op != Py_NE) {
+    return Py_NewRef(Py_NotImplemented);
+  }
+
+  // For a non-Traceback operand, let Python try the reflected comparison and
+  // then fall back to identity. Returning Py_False unconditionally would make
+  // `tb != other` evaluate to False.
+  if (!traceback_check(other)) {
+    return Py_NewRef(Py_NotImplemented);
+  }
+  TracebackObject *tb_self = reinterpret_cast<TracebackObject *>(self);
+  TracebackObject *tb_other = reinterpret_cast<TracebackObject *>(other);
+  if (Py_SIZE(tb_self) != Py_SIZE(tb_other)) {
+    return Py_NewRef(op == Py_EQ ? Py_False : Py_True);
+  }
+  for (Py_ssize_t i = 0; i < Py_SIZE(tb_self); ++i) {
+    if (tb_self->frames[i] != tb_other->frames[i]) {
+      return Py_NewRef(op == Py_EQ ? Py_False : Py_True);
+    }
+  }
+  return Py_NewRef(op == Py_EQ ? Py_True : Py_False);
+}
+
+// tp_dealloc slot: drops the code-object reference held by every frame, frees
+// the instance, then releases the instance's reference to its heap type.
+static void traceback_tp_dealloc(PyObject *self) {
+  auto *tb = reinterpret_cast<TracebackObject *>(self);
+  const Py_ssize_t n = Py_SIZE(tb);
+  for (Py_ssize_t idx = 0; idx != n; ++idx)
+    Py_XDECREF(tb->frames[idx].code);
+  PyTypeObject *type = Py_TYPE(self);
+  type->tp_free(self);
+  Py_DECREF(type);
+}
+
+// Decodes a raw (code, lasti) entry into a Frame; the Frame holds new
+// references to the code object's file-name and qualified-name strings.
+// Positional aggregate initialization replaces designated initializers, which
+// are a C++20 feature (LLVM builds as C++17, where MSVC rejects them).
+Traceback::Frame DecodeFrame(const TracebackEntry &frame) {
+  return Traceback::Frame{
+      /*file_name=*/nb::borrow<nb::str>(frame.code->co_filename),
+      /*function_name=*/nb::borrow<nb::str>(frame.code->co_qualname),
+      /*function_start_line=*/frame.code->co_firstlineno,
+      /*line_num=*/PyCode_Addr2Line(frame.code, frame.lasti),
+  };
+}
+
+// Renders each frame on its own line (innermost frame first), joined by "\n".
+std::string traceback_to_string(const TracebackObject *tb) {
+  const Py_ssize_t count = Py_SIZE(tb);
+  std::vector<std::string> lines;
+  lines.reserve(count);
+  for (const TracebackEntry *entry = tb->frames, *last = tb->frames + count;
+       entry != last; ++entry)
+    lines.push_back(DecodeFrame(*entry).ToString());
+  return llvm::join(lines, "\n");
+}
+
+// tp_str slot: same text as Traceback::ToString().
+// CPython calls this directly, so C++ exceptions must not escape; failures
+// become a pending Python exception and nullptr is returned.
+PyObject *traceback_tp_str(PyObject *self) {
+  TracebackObject *tb = reinterpret_cast<TracebackObject *>(self);
+  try {
+    return nb::cast(traceback_to_string(tb)).release().ptr();
+  } catch (nb::python_error &e) {
+    e.restore();
+  } catch (const std::exception &e) {
+    PyErr_SetString(PyExc_RuntimeError, e.what());
+  }
+  return nullptr;
+}
+
+// Type slots for the Traceback heap type created by PyType_FromSpec in
+// BuildTracebackSubmodule().
+// It turns out to be slightly faster to define a tp_hash slot rather than
+// defining __hash__ and __eq__ on the class.
+PyType_Slot traceback_slots_[] = {
+    {Py_tp_hash, reinterpret_cast<void *>(traceback_tp_hash)},
+    {Py_tp_richcompare, reinterpret_cast<void *>(traceback_tp_richcompare)},
+    {Py_tp_dealloc, reinterpret_cast<void *>(traceback_tp_dealloc)},
+    {Py_tp_str, reinterpret_cast<void *>(traceback_tp_str)},
+    {0, nullptr}, // Sentinel terminating the slot list.
+};
+
+// Converts a Traceback into a chain of native Python traceback objects (the
+// head is the outermost frame; tb_next leads toward the innermost), e.g. for
+// Exception.with_traceback(). Returns None for an empty traceback.
+// Throws nb::python_error if a CPython allocation fails.
+nb::object AsPythonTraceback(const Traceback &tb) {
+  nb::object traceback = nb::none();
+  nb::dict globals;
+  nb::handle traceback_type(reinterpret_cast<PyObject *>(&PyTraceBack_Type));
+  TracebackObject *tb_obj = reinterpret_cast<TracebackObject *>(tb.ptr());
+  for (Py_ssize_t i = 0; i < Py_SIZE(tb_obj); ++i) {
+    const TracebackEntry &frame = tb_obj->frames[i];
+    int lineno = PyCode_Addr2Line(frame.code, frame.lasti);
+    // Under Python 3.11 we observed crashes when using a fake PyFrameObject
+    // with a real PyCodeObject (https://github.com/google/jax/issues/16027).
+    // because the frame does not have fields necessary to compute the locals,
+    // notably the closure object, leading to crashes in CPython in
+    // _PyFrame_FastToLocalsWithError
+    // https://github.com/python/cpython/blob/deaf509e8fc6e0363bd6f26d52ad42f976ec42f2/Objects/frameobject.c#LL1116C2-L1116C2
+    // We therefore always build a fake code object to go along with our fake
+    // frame.
+    const char *file_name = PyUnicode_AsUTF8(frame.code->co_filename);
+    const char *function_name = PyUnicode_AsUTF8(frame.code->co_name);
+    if (!file_name || !function_name)
+      throw nb::python_error();
+    PyCodeObject *py_code = PyCode_NewEmpty(file_name, function_name, lineno);
+    if (!py_code)
+      throw nb::python_error();
+    PyFrameObject *py_frame = PyFrame_New(PyThreadState_Get(), py_code,
+                                          globals.ptr(), /*locals=*/nullptr);
+    Py_DECREF(py_code);
+    if (!py_frame)
+      throw nb::python_error();
+
+    // Each new traceback node wraps the previous one, so frames[0]
+    // (innermost) ends up at the tail of the chain.
+    traceback = traceback_type(
+        /*tb_next=*/std::move(traceback),
+        /*tb_frame=*/
+        nb::steal<nb::object>(reinterpret_cast<PyObject *>(py_frame)),
+        /*tb_lasti=*/0,
+        /*tb_lineno=*/lineno);
+  }
+  return traceback;
+}
+
+std::vector<Traceback::Frame> Traceback::Frames() const {
+  // Decoding takes references to Python strings, so the GIL must be held.
+  assert(PyGILState_Check());
+  const auto *tb = reinterpret_cast<const TracebackObject *>(ptr());
+  const Py_ssize_t count = Py_SIZE(tb);
+  std::vector<Traceback::Frame> decoded;
+  decoded.reserve(count);
+  for (Py_ssize_t idx = 0; idx < count; ++idx)
+    decoded.push_back(DecodeFrame(tb->frames[idx]));
+  return decoded;
+}
+
+// Formats the frame as "<file>:<line> (<function>)".
+std::string Traceback::Frame::ToString() const {
+  std::string result;
+  result += nb::cast<std::string>(file_name);
+  result += ':';
+  result += std::to_string(line_num);
+  result += " (";
+  result += nb::cast<std::string>(function_name);
+  result += ')';
+  return result;
+}
+
+// Multi-line rendering of the traceback; one line per frame.
+std::string Traceback::ToString() const {
+  const auto *raw = reinterpret_cast<const TracebackObject *>(ptr());
+  return traceback_to_string(raw);
+}
+
+// Copies out the (code, lasti) pairs. The code pointers are borrowed: they
+// remain valid only while this Traceback keeps its references alive.
+std::vector<std::pair<PyCodeObject *, int>> Traceback::RawFrames() const {
+  const auto *tb = reinterpret_cast<const TracebackObject *>(ptr());
+  const TracebackEntry *first = tb->frames;
+  const TracebackEntry *last = first + Py_SIZE(tb);
+  std::vector<std::pair<PyCodeObject *, int>> result;
+  result.reserve(last - first);
+  for (const TracebackEntry *it = first; it != last; ++it)
+    result.emplace_back(it->code, it->lasti);
+  return result;
+}
+
+// Type predicate used by NB_OBJECT for nanobind casts.
+/*static*/ bool Traceback::Check(PyObject *o) {
+  return traceback_check(nb::handle(o));
+}
+
+// Captures the calling thread's Python stack (innermost frame first, at most
+// kMaxFrames frames) into a new Traceback, or returns nullopt when traceback
+// collection is disabled. Requires the GIL.
+// Throws nb::python_error if the Traceback object cannot be allocated.
+/*static*/ std::optional<Traceback> Traceback::Get() {
+  // We use a thread_local here mostly to avoid requiring a large amount of
+  // space.
+  thread_local std::array<TracebackEntry, kMaxFrames> frames;
+  int count = 0;
+
+  assert(PyGILState_Check());
+
+  if (!traceback_enabled_.load()) {
+    return std::nullopt;
+  }
+
+  PyThreadState *thread_state = PyThreadState_GET();
+
+#ifdef PLATFORM_GOOGLE
+// This code is equivalent to the version using public APIs, but it saves us
+// an allocation of one object per stack frame. However, this is definitely
+// violating the API contract of CPython, so we only use this where we can be
+// confident we know exactly which CPython we are using (internal to Google).
+// Feel free to turn this on if you like, but it might break at any time!
+#if PY_VERSION_HEX < 0x030d0000
+  for (_PyInterpreterFrame *f = thread_state->cframe->current_frame;
+       f != nullptr && count < kMaxFrames; f = f->previous) {
+    if (_PyFrame_IsIncomplete(f))
+      continue;
+    Py_INCREF(f->f_code);
+    frames[count] = {f->f_code, static_cast<int>(_PyInterpreterFrame_LASTI(f) *
+                                                 sizeof(_Py_CODEUNIT))};
+    ++count;
+  }
+#else  // PY_VERSION_HEX < 0x030d0000
+  for (_PyInterpreterFrame *f = thread_state->current_frame;
+       f != nullptr && count < kMaxFrames; f = f->previous) {
+    if (_PyFrame_IsIncomplete(f))
+      continue;
+    Py_INCREF(f->f_executable);
+    frames[count] = {
+        reinterpret_cast<PyCodeObject *>(f->f_executable),
+        static_cast<int>(_PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT))};
+    ++count;
+  }
+#endif // PY_VERSION_HEX < 0x030d0000
+
+#else  // PLATFORM_GOOGLE
+  // Walk outward from the innermost frame. Every frame object obtained here is
+  // a new reference, including the one we stop at once kMaxFrames entries are
+  // recorded, so whatever frame is still held after the loop is released.
+  PyFrameObject *py_frame = PyThreadState_GetFrame(thread_state);
+  while (py_frame != nullptr && count < kMaxFrames) {
+    // PyFrame_GetCode returns a new reference, which the Traceback takes over.
+    frames[count] = {PyFrame_GetCode(py_frame), PyFrame_GetLasti(py_frame)};
+    ++count;
+    PyFrameObject *next = PyFrame_GetBack(py_frame);
+    Py_DECREF(py_frame);
+    py_frame = next;
+  }
+  Py_XDECREF(py_frame);
+#endif // PLATFORM_GOOGLE
+
+  PyObject *raw = PyObject_NewVar(PyObject, traceback_type_, count);
+  if (!raw) {
+    // Allocation failed: release the code references collected above.
+    for (int i = 0; i < count; ++i)
+      Py_DECREF(frames[i].code);
+    throw nb::python_error();
+  }
+  Traceback traceback = nb::steal<Traceback>(raw);
+  TracebackObject *tb = reinterpret_cast<TracebackObject *>(traceback.ptr());
+  std::memcpy(tb->frames, frames.data(), sizeof(TracebackEntry) * count);
+  return traceback;
+}
+
+// Wraps `get` in a builtin read-only `property` object. This is used for the
+// Traceback type, which is created with PyType_FromSpec rather than
+// nb::class_.
+template <typename Func>
+nanobind::object nb_property_readonly(Func &&get) {
+  nanobind::handle propertyType(
+      reinterpret_cast<PyObject *>(&PyProperty_Type));
+  nanobind::object fget = nanobind::cpp_function(std::forward<Func>(get));
+  return propertyType(fget, nanobind::none(), nanobind::none(), "");
+}
+
+// Registers the Frame and Traceback types on `m`, together with the
+// module-level tracebacks_enabled() / set_tracebacks_enabled() functions.
+// Throws nb::python_error if the Traceback type cannot be created.
+void BuildTracebackSubmodule(nb::module_ &m) {
+  nb::class_<Traceback::Frame>(m, "Frame")
+      .def(nb::init<const nb::str &, const nb::str &, int, int>())
+      .def_ro("file_name", &Traceback::Frame::file_name)
+      .def_ro("function_name", &Traceback::Frame::function_name)
+      .def_ro("function_start_line", &Traceback::Frame::function_start_line)
+      .def_ro("line_num", &Traceback::Frame::line_num)
+      .def("__repr__", [](const Traceback::Frame &frame) {
+        std::string s = nb::cast<std::string>(frame.function_name);
+        s += ";" + nb::cast<std::string>(frame.file_name);
+        s += ":" + std::to_string(frame.line_num);
+        return s;
+      });
+
+  // Fully qualified type name: "<module __name__>.Traceback".
+  std::string name = nb::cast<std::string>(m.attr("__name__"));
+  name += ".Traceback";
+
+  PyType_Spec traceback_spec = {
+      /*.name=*/name.c_str(),
+      /*.basicsize=*/static_cast<int>(sizeof(TracebackObject)),
+      /*.itemsize=*/static_cast<int>(sizeof(TracebackEntry)),
+      /*.flags=*/Py_TPFLAGS_DEFAULT,
+      /*.slots=*/traceback_slots_,
+  };
+
+  traceback_type_ =
+      reinterpret_cast<PyTypeObject *>(PyType_FromSpec(&traceback_spec));
+  if (!traceback_type_) {
+    throw nb::python_error();
+  }
+
+  auto type = nb::borrow<nb::object>(traceback_type_);
+  m.attr("Traceback") = type;
+
+  m.def("tracebacks_enabled", []() { return traceback_enabled_.load(); });
+  m.def("set_tracebacks_enabled",
+        [](bool value) { traceback_enabled_.store(value); });
+
+  type.attr("get_traceback") = nb::cpp_function(Traceback::Get,
+                                                R"doc(
+      Returns a :class:`Traceback` for the current thread.
+
+      If traceback collection is enabled (see ``tracebacks_enabled()``),
+      returns a :class:`Traceback` object that describes the Python stack of
+      the calling thread. Stack trace collection has a small overhead; it is
+      enabled by default and can be turned off with
+      ``set_tracebacks_enabled(False)``, in which case ``None`` is returned. )doc");
+  type.attr("frames") = nb_property_readonly(&Traceback::Frames);
+  type.attr("raw_frames") = nb::cpp_function(
+      [](const Traceback &tb) -> nb::tuple {
+        // We return a tuple of lists, rather than a list of tuples, because it
+        // is cheaper to allocate only three Python objects for everything
+        // rather than one per frame.
+        std::vector<std::pair<PyCodeObject *, int>> frames = tb.RawFrames();
+        const Py_ssize_t n = static_cast<Py_ssize_t>(frames.size());
+        PyObject *code_list = PyList_New(n);
+        if (!code_list)
+          throw nb::python_error();
+        nb::list out_code = nb::steal<nb::list>(code_list);
+        PyObject *lasti_list = PyList_New(n);
+        if (!lasti_list)
+          throw nb::python_error();
+        nb::list out_lasti = nb::steal<nb::list>(lasti_list);
+        for (size_t i = 0; i < frames.size(); ++i) {
+          const auto &frame = frames[i];
+          // PyList_SET_ITEM steals a reference; RawFrames' pointers are
+          // borrowed, so take one first.
+          PyObject *code = reinterpret_cast<PyObject *>(frame.first);
+          Py_INCREF(code);
+          PyList_SET_ITEM(out_code.ptr(), i, code);
+          PyList_SET_ITEM(out_lasti.ptr(), i,
+                          nb::int_(frame.second).release().ptr());
+        }
+        return nb::make_tuple(out_code, out_lasti);
+      },
+      nb::is_method());
+  type.attr("as_python_traceback") =
+      nb::cpp_function(AsPythonTraceback, nb::is_method());
+
+  type.attr("traceback_from_frames") = nb::cpp_function(
+      [](std::vector<Traceback::Frame> frames) {
+        // Each new node wraps the previous one, so frames[0] ends up innermost
+        // (at the tail of the tb_next chain).
+        nb::object traceback = nb::none();
+        nb::dict globals;
+        nb::handle traceback_type(
+            reinterpret_cast<PyObject *>(&PyTraceBack_Type));
+        for (const Traceback::Frame &frame : frames) {
+          PyCodeObject *py_code =
+              PyCode_NewEmpty(frame.file_name.c_str(),
+                              frame.function_name.c_str(), frame.line_num);
+          if (!py_code)
+            throw nb::python_error();
+          PyFrameObject *py_frame = PyFrame_New(PyThreadState_Get(), py_code,
+                                                globals.ptr(), /*locals=*/
+                                                nullptr);
+          Py_DECREF(py_code);
+          if (!py_frame)
+            throw nb::python_error();
+          traceback = traceback_type(
+              /*tb_next=*/std::move(traceback),
+              /*tb_frame=*/
+              nb::steal<nb::object>(reinterpret_cast<PyObject *>(py_frame)),
+              /*tb_lasti=*/0,
+              /*tb_lineno=*/
+              frame.line_num);
+        }
+        return traceback;
+      },
+      "Creates a traceback from a list of frames.");
+
+  type.attr("code_addr2line") = nb::cpp_function(
+      [](nb::handle code, int lasti) {
+        if (!PyCode_Check(code.ptr())) {
+          throw std::runtime_error("code argument must be a code object");
+        }
+        return PyCode_Addr2Line(reinterpret_cast<PyCodeObject *>(code.ptr()),
+                                lasti);
+      },
+      "Python wrapper around the Python C API function PyCode_Addr2Line");
+
+  type.attr("code_addr2location") = nb::cpp_function(
+      [](nb::handle code, int lasti) {
+        if (!PyCode_Check(code.ptr())) {
+          throw std::runtime_error("code argument must be a code object");
+        }
+        int start_line, start_column, end_line, end_column;
+        if (!PyCode_Addr2Location(reinterpret_cast<PyCodeObject *>(code.ptr()),
+                                  lasti, &start_line, &start_column, &end_line,
+                                  &end_column)) {
+          throw nb::python_error();
+        }
+        return nb::make_tuple(start_line, start_column, end_line, end_column);
+      },
+      "Python wrapper around the Python C API function PyCode_Addr2Location");
+}
+} // namespace mlir::python
+
+// Hashes a traceback from its length and each entry's (code, lasti) fields.
+// Entries are hashed field by field rather than as raw memory. On 64-bit
+// targets TracebackEntry has padding after `lasti`, and padding bytes have
+// unspecified values. Hashing them could give tracebacks that compare equal in
+// traceback_tp_richcompare different hashes.
+std::size_t std::hash<mlir::python::TracebackObject>::operator()(
+    const mlir::python::TracebackObject &tb) const noexcept {
+  const Py_ssize_t length = Py_SIZE(&tb);
+  llvm::hash_code result = llvm::hash_value(length);
+  for (Py_ssize_t i = 0; i < length; ++i) {
+    const mlir::python::TracebackEntry &entry = tb.frames[i];
+    result = llvm::hash_combine(result, entry.code, entry.lasti);
+  }
+  return result;
+}
+
+// Hash of one entry: combines the code-object pointer with lasti.
+std::size_t std::hash<mlir::python::TracebackEntry>::operator()(
+    const mlir::python::TracebackEntry &tbe) const noexcept {
+  const llvm::hash_code combined = llvm::hash_combine(tbe.code, tbe.lasti);
+  return static_cast<std::size_t>(combined);
+}
\ No newline at end of file
diff --git a/mlir/lib/Bindings/Python/Traceback.h b/mlir/lib/Bindings/Python/Traceback.h
new file mode 100644
index 0000000000000..0aab15ddc5da3
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Traceback.h
@@ -0,0 +1,63 @@
+/* Copyright 2020 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_TRACEBACK_H_
+#define JAXLIB_TRACEBACK_H_
+
+#include <Python.h>
+
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+// placeholder for index annotation headers
+#include "nanobind/nanobind.h"
+
+namespace mlir::python {
+
+// Handle to a Python object of the Traceback type registered by
+// BuildTracebackSubmodule(). It stores a snapshot of the Python call stack as
+// (code object, bytecode offset) pairs, innermost frame first.
+class Traceback : public nanobind::object {
+public:
+  NB_OBJECT(Traceback, nanobind::object, "Traceback", Traceback::Check);
+
+  // Returns a traceback if it is enabled, otherwise returns nullopt.
+  // Captures the calling thread's Python stack; the GIL must be held.
+  static std::optional<Traceback> Get();
+
+  // Returns a string representation of the traceback: one
+  // "file:line (function)" line per frame.
+  std::string ToString() const;
+
+  // Returns a list of (code, lasti) pairs for each frame in the traceback.
+  // The code pointers are borrowed and stay valid while this object is alive.
+  std::vector<std::pair<PyCodeObject *, int>> RawFrames() const;
+
+  // Decoded, user-facing description of a single frame.
+  struct Frame {
+    nanobind::str file_name;     // co_filename of the frame's code object.
+    nanobind::str function_name; // co_qualname of the frame's code object.
+    int function_start_line;     // co_firstlineno of the code object.
+    int line_num;                // Line being executed in this frame.
+
+    std::string ToString() const;
+  };
+  // Returns a list of Frames for the traceback (requires the GIL).
+  std::vector<Frame> Frames() const;
+
+private:
+  // Exact type check against the registered Traceback type.
+  static bool Check(PyObject *o);
+};
+
+// Creates the Frame and Traceback types and the tracebacks_enabled /
+// set_tracebacks_enabled functions on module `m`.
+void BuildTracebackSubmodule(nanobind::module_ &m);
+
+} // namespace mlir::python
+
+#endif // JAXLIB_TRACEBACK_H_
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 7a0c95ebb8200..243fbe64de900 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -23,6 +23,8 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     passmanager.py
     rewrite.py
     dialects/_ods_common.py
+    source_info_util.py
+    traceback_util.py
 
     # The main _mlir module has submodules: include stubs from each.
     _mlir_libs/_mlir/__init__.pyi
@@ -486,6 +488,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
     IRTypes.cpp
     Pass.cpp
     Rewrite.cpp
+    Traceback.cpp
 
     # Headers must be included explicitly so they are installed.
     Globals.h
@@ -493,6 +496,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
     Pass.h
     NanobindUtils.h
     Rewrite.h
+    Traceback.h
   PRIVATE_LINK_LIBS
     LLVMSupport
   EMBED_CAPI_LINK_LIBS
diff --git a/mlir/python/mlir/source_info_util.py b/mlir/python/mlir/source_info_util.py
new file mode 100644
index 0000000000000..ff7ee09149c60
--- /dev/null
+++ b/mlir/python/mlir/source_info_util.py
@@ -0,0 +1,367 @@
+# Copyright 2020 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Iterator
+import contextlib
+import dataclasses
+import functools
+import itertools
+import os.path
+import re
+import sysconfig
+import threading
+import types
+from typing import NamedTuple
+
+# Ported from JAX's source_info_util; the JAX-specific imports it used
+# (jax.version, xla_client) are intentionally not needed here.
+
+from . import traceback_util
+
+# Make traceback_util.include_filename() treat this module's frames as excluded.
+traceback_util.register_exclusion(__file__)
+
+from ._mlir_libs._mlir import Traceback
+
+
+class Frame(NamedTuple):
+    """A source location: file, function, and line/column range."""
+
+    file_name: str
+    function_name: str  # co_qualname of the frame's code object
+    start_line: int
+    start_column: int
+    end_line: int
+    end_column: int
+
+
+# Path prefixes whose frames are not treated as user code (see
+# is_user_filename); register_exclusion() appends to this list.
+_exclude_paths: list[str] = [
+    # JAX also excluded its own package directory here, with a trailing
+    # separator so that .../jax would not match .../jax_triton:
+    # os.path.dirname(jax.version.__file__) + os.sep,
+    # Exclude the stdlib from user frames. In a non-standard Python runtime,
+    # the following two paths may differ.
+    sysconfig.get_path("stdlib"),
+    os.path.dirname(sysconfig.__file__),
+]
+
+
+ at functools.cache
+def _exclude_path_regex() -> re.Pattern[str]:
+    # The regex below would not handle an empty set of exclusions correctly.
+    assert len(_exclude_paths) > 0
+    return re.compile("|".join(f"^{re.escape(path)}" for path in _exclude_paths))
+
+
+def register_exclusion(path: str):
+    _exclude_paths.append(path)
+    _exclude_path_regex.cache_clear()
+    is_user_filename.cache_clear()
+
+
+# Explicit inclusions take priority over exclude paths.
+_include_paths: list[str] = []
+
+
+ at functools.cache
+def _include_path_regex() -> re.Pattern[str]:
+    patterns = [f"^{re.escape(path)}" for path in _include_paths]
+    patterns.append("_test.py$")
+    return re.compile("|".join(patterns))
+
+
+def register_inclusion(path: str):
+    _include_paths.append(path)
+    _include_path_regex.cache_clear()
+    is_user_filename.cache_clear()
+
+
+class Scope(NamedTuple):
+    name: str
+
+    def wrap(self, stack: list[str]):
+        stack.append(self.name)
+
+
+class Transform(NamedTuple):
+    name: str
+
+    def wrap(self, stack: list[str]):
+        if stack:
+            stack[-1] = f"{self.name}({stack[-1]})"
+        else:
+            stack.append(f"{self.name}()")
+
+
+ at dataclasses.dataclass(frozen=True)
+class NameStack:
+    stack: tuple[Scope | Transform, ...] = ()
+
+    def extend(self, name: str) -> NameStack:
+        return NameStack((*self.stack, Scope(name)))
+
+    def transform(self, transform_name: str) -> NameStack:
+        return NameStack((*self.stack, Transform(transform_name)))
+
+    def __getitem__(self, idx: slice) -> NameStack:
+        return NameStack(self.stack[idx])
+
+    def __len__(self):
+        return len(self.stack)
+
+    def __add__(self, other: NameStack) -> NameStack:
+        return NameStack(self.stack + other.stack)
+
+    def __radd__(self, other: NameStack) -> NameStack:
+        return NameStack(other.stack + self.stack)
+
+    def __str__(self) -> str:
+        scope: list[str] = []
+        for elem in self.stack[::-1]:
+            elem.wrap(scope)
+        return "/".join(reversed(scope))
+
+
+def new_name_stack(name: str = "") -> NameStack:
+    name_stack = NameStack()
+    if name:
+        name_stack = name_stack.extend(name)
+    return name_stack
+
+
+class SourceInfo:
+    traceback: Traceback | None
+    name_stack: NameStack
+
+    # It's slightly faster to use a class with __slots__ than a NamedTuple.
+    __slots__ = ["traceback", "name_stack"]
+
+    def __init__(self, traceback: Traceback | None, name_stack: NameStack):
+        self.traceback = traceback
+        self.name_stack = name_stack
+
+    def replace(
+        self, *, traceback: Traceback | None = None, name_stack: NameStack | None = None
+    ) -> SourceInfo:
+        return SourceInfo(
+            self.traceback if traceback is None else traceback,
+            self.name_stack if name_stack is None else name_stack,
+        )
+
+
+def new_source_info() -> SourceInfo:
+    return SourceInfo(None, NameStack())
+
+
+ at functools.cache
+def is_user_filename(filename: str) -> bool:
+    """Heuristic that guesses the identity of the user's code in a stack trace."""
+    return (
+        _include_path_regex().search(filename) is not None
+        or _exclude_path_regex().search(filename) is None
+    )
+
+
+def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame:
+    loc = Traceback.code_addr2location(code, lasti)
+    start_line, start_column, end_line, end_column = loc
+    return Frame(
+        file_name=code.co_filename,
+        function_name=code.co_qualname,
+        start_line=start_line,
+        start_column=start_column,
+        end_line=end_line,
+        end_column=end_column,
+    )
+
+
+def user_frames(traceback: Traceback | None) -> Iterator[Frame]:
+    """Iterator over the user's frames, filtering jax-internal frames."""
+    # Guess the user's frame is the innermost frame not in the jax source tree or
+    # Python stdlib. We don't use traceback_util.path_starts_with because that
+    # incurs filesystem access, which may be slow; we call this function when
+    # e.g. adding source provenance annotations to XLA lowerings, so we don't
+    # want to incur the cost. We consider files that end with _test.py as user
+    # frames, to allow testing this mechanism from tests.
+    code, lasti = traceback.raw_frames() if traceback else ([], [])
+    return (
+        raw_frame_to_frame(code[i], lasti[i])
+        for i in range(len(code))
+        if is_user_filename(code[i].co_filename)
+    )
+
+
+ at functools.lru_cache(maxsize=64)
+def user_frame(traceback: Traceback | None) -> Frame | None:
+    return next(user_frames(traceback), None)
+
+
+def _summarize_frame(frame: Frame) -> str:
+    if frame.start_column != 0:
+        return (
+            f"{frame.file_name}:{frame.start_line}:{frame.start_column} "
+            f"({frame.function_name})"
+        )
+    else:
+        return f"{frame.file_name}:{frame.start_line} ({frame.function_name})"
+
+
+def summarize(source_info: SourceInfo, num_frames=1) -> str:
+    frames = itertools.islice(user_frames(source_info.traceback), num_frames)
+    frame_strs = [_summarize_frame(frame) if frame else "unknown" for frame in frames]
+    return "\n".join(reversed(frame_strs))
+
+
+class _SourceInfoContext(threading.local):
+    context: SourceInfo
+
+    def __init__(self):
+        self.context = new_source_info()
+
+
+_source_info_context = _SourceInfoContext()
+
+
+def current() -> SourceInfo:
+    source_info = _source_info_context.context
+    if not source_info.traceback:
+        source_info = source_info.replace(traceback=Traceback.get_traceback())
+    return source_info
+
+
+class JaxStackTraceBeforeTransformation(Exception):
+    """Synthetic exception carrying the recorded user traceback.
+
+    ``UserContextManager`` attaches it as the ``__cause__`` of an exception
+    raised inside its block.
+    """
+
+    pass
+
+
+# Appended to the original exception's message in that synthetic exception.
+_message = (
+    "The preceding stack trace is the source of the JAX operation that, once "
+    "transformed by JAX, triggered the following exception.\n"
+    "\n--------------------"
+)
+
+
+def has_user_context(e):
+    while e is not None:
+        if isinstance(e, JaxStackTraceBeforeTransformation):
+            return True
+        e = e.__cause__
+    return False
+
+
+class UserContextManager:
+    """Installs ``traceback`` (and optionally ``name_stack``) as the calling
+    thread's current source info for the duration of a ``with`` block.
+
+    If an exception escapes the block, the recorded traceback is attached as a
+    ``JaxStackTraceBeforeTransformation`` in the exception's ``__cause__``
+    chain, unless such a cause is already present. The exception itself still
+    propagates (``__exit__`` returns None).
+    """
+
+    __slots__ = ["traceback", "name_stack", "prev"]
+
+    def __init__(
+        self, traceback: Traceback | None, *, name_stack: NameStack | None = None
+    ):
+        self.traceback = traceback
+        self.name_stack = name_stack
+
+    def __enter__(self):
+        # Remember the active context; replace() keeps any field passed as None.
+        self.prev = _source_info_context.context
+        _source_info_context.context = _source_info_context.context.replace(
+            traceback=self.traceback, name_stack=self.name_stack
+        )
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        _source_info_context.context = self.prev
+        if exc_type is None or exc_value is None:
+            return
+
+        # Nothing recorded to attach, or the cause chain already carries one.
+        if self.traceback is None or has_user_context(exc_value):
+            return
+
+        # Convert to a native Python traceback; filter_traceback (defined later
+        # in traceback_util, not shown here) presumably drops excluded frames.
+        filtered_tb = traceback_util.filter_traceback(
+            self.traceback.as_python_traceback()
+        )
+        if filtered_tb:
+            msg = traceback_util.format_exception_only(exc_value)
+            msg = f"{msg}\n\n{_message}"
+            # Splice the synthetic exception into the chain: it inherits the
+            # original's context/cause and becomes the original's new __cause__.
+            exp = JaxStackTraceBeforeTransformation(msg).with_traceback(filtered_tb)
+            exp.__context__ = exc_value.__context__
+            exp.__cause__ = exc_value.__cause__
+            exp.__suppress_context__ = exc_value.__suppress_context__
+            exc_value.__context__ = None
+            exc_value.__cause__ = exp
+
+
+user_context = UserContextManager
+
+
+def current_name_stack() -> NameStack:
+    return _source_info_context.context.name_stack
+
+
+class ExtendNameStackContextManager(contextlib.ContextDecorator):
+    __slots__ = ["name", "prev"]
+
+    def __init__(self, name: str):
+        self.name = name
+
+    def __enter__(self):
+        self.prev = prev = _source_info_context.context
+        name_stack = prev.name_stack.extend(self.name)
+        _source_info_context.context = prev.replace(name_stack=name_stack)
+        return name_stack
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        _source_info_context.context = self.prev
+
+
+extend_name_stack = ExtendNameStackContextManager
+
+
+class SetNameStackContextManager(contextlib.ContextDecorator):
+    __slots__ = ["name_stack", "prev"]
+
+    def __init__(self, name_stack: NameStack):
+        self.name_stack = name_stack
+
+    def __enter__(self):
+        self.prev = prev = _source_info_context.context
+        _source_info_context.context = prev.replace(name_stack=self.name_stack)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        _source_info_context.context = self.prev
+
+
+set_name_stack = SetNameStackContextManager
+
+
+# TODO(mattjj,phawkins): figure out why the commented-out reset_name_stack
+# implementation doesn't work. Luckily this context manager isn't called much so
+# the performance shouldn't matter. See blame commit message for repro.
+# reset_name_stack = lambda: SetNameStackContextManager(NameStack())
+ at contextlib.contextmanager
+def reset_name_stack() -> Iterator[None]:
+    with set_name_stack(NameStack()):
+        yield
+
+
+class TransformNameStackContextManager(contextlib.ContextDecorator):
+    __slots__ = ["name", "prev"]
+
+    def __init__(self, name: str):
+        self.name = name
+
+    def __enter__(self):
+        self.prev = prev = _source_info_context.context
+        name_stack = prev.name_stack.transform(self.name)
+        _source_info_context.context = prev.replace(name_stack=name_stack)
+        return name_stack
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        _source_info_context.context = self.prev
+
+
+transform_name_stack = TransformNameStackContextManager
diff --git a/mlir/python/mlir/traceback_util.py b/mlir/python/mlir/traceback_util.py
new file mode 100644
index 0000000000000..1794fcd1d08ba
--- /dev/null
+++ b/mlir/python/mlir/traceback_util.py
@@ -0,0 +1,238 @@
+# Copyright 2020 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Callable
+import functools
+import os
+import traceback
+import types
+from typing import Any, TypeVar, cast
+
+
+# Callable type variable so that api_boundary(fun) is typed as returning the
+# same callable type it was given.
+C = TypeVar("C", bound=Callable[..., Any])
+
+# Path prefixes whose frames are treated as internal and dropped from filtered
+# tracebacks. Starts with this module itself; extended via register_exclusion.
+_exclude_paths: list[str] = [__file__]
+
+
+def register_exclusion(path: str):
+    """Marks every file under ``path`` as internal so its frames get filtered.
+
+    ``path`` is matched as a path prefix (see ``_path_starts_with``).
+    """
+    _exclude_paths.extend([path])
+
+
+# Appended to the message of the UnfilteredStackTrace that api_boundary chains
+# onto the original exception in "remove_frames" mode. The original exception,
+# printed after it, carries the filtered traceback.
+_jax_message_append = (
+    "The stack trace below excludes MLIR-internal frames.\n"
+    "The preceding is the original exception that occurred, unmodified.\n"
+    "\n--------------------"
+)
+
+
+def _path_starts_with(path: str, path_prefix: str) -> bool:
+    """Returns True if ``path`` lies at or under ``path_prefix``.
+
+    Both arguments are made absolute before comparing.
+    """
+    full_path, full_prefix = os.path.abspath(path), os.path.abspath(path_prefix)
+    try:
+        shared = os.path.commonpath([full_path, full_prefix])
+    except ValueError:
+        # Both inputs are absolute, so commonpath can only fail when they sit on
+        # different drives.
+        # https://docs.python.org/3/library/os.path.html#os.path.commonpath
+        return False
+    if shared == full_prefix:
+        return True
+    try:
+        # Fall back to comparing on-disk identity (device and inode number).
+        return os.path.samefile(shared, full_prefix)
+    except OSError:
+        # Either path may not exist on disk.
+        return False
+
+
+def include_frame(f: types.FrameType) -> bool:
+    """Returns True if frame ``f`` should be kept, i.e. its source file is not excluded."""
+    code = f.f_code
+    return include_filename(code.co_filename)
+
+
+def include_filename(filename: str) -> bool:
+    """Returns True unless ``filename`` lies under one of the excluded path prefixes."""
+    for excluded in _exclude_paths:
+        if _path_starts_with(filename, excluded):
+            return False
+    return True
+
+
+# Walking the stack can surface CPython frames that printed tracebacks normally
+# hide, such as frames from parts of importlib. They are skipped with a simple
+# heuristic on the frame's source filename.
+def _ignore_known_hidden_frame(f: types.FrameType) -> bool:
+    filename = f.f_code.co_filename
+    return filename.find("importlib._bootstrap") != -1
+
+
+def _add_tracebackhide_to_hidden_frames(tb: types.TracebackType):
+    """Sets ``__tracebackhide__ = True`` in the locals of every excluded frame of ``tb``.
+
+    This is the marker consulted in "tracebackhide" mode, i.e. inside an
+    IPython version that supports ``__tracebackhide__``.
+    """
+    hidden = (frame for frame, _ in traceback.walk_tb(tb) if not include_frame(frame))
+    for frame in hidden:
+        frame.f_locals["__tracebackhide__"] = True
+
+
+def filter_traceback(tb: types.TracebackType) -> types.TracebackType | None:
+    """Returns a new traceback chain holding only the included frames of ``tb``.
+
+    Frames keep their original order and line numbers. Returns None when every
+    frame is excluded.
+    """
+    kept = [(f, lineno) for f, lineno in traceback.walk_tb(tb) if include_frame(f)]
+    filtered = None
+    # tb_next points from outer to inner frames, so build the chain starting
+    # from the innermost frame and wrap outward.
+    for frame, lineno in reversed(kept):
+        filtered = types.TracebackType(filtered, frame, frame.f_lasti, lineno)
+    return filtered
+
+
+def _add_call_stack_frames(tb: types.TracebackType) -> types.TracebackType:
+    """Extends ``tb`` outward with the included frames of its caller stack.
+
+    The walk starts at ``tb.tb_frame`` and follows ``f_back``. To avoid
+    climbing past the exec/eval point of a REPL such as IPython, it stops at
+    the first non-module frame that follows a contiguous run of module-level
+    (``<module>``) frames, if any module-level frame is reached at all.
+
+    This heuristic can stop before the REPL boundary. For example, if the
+    stack has module-level frames from module A, and A was imported from
+    within a function F elsewhere, the produced stack trace is truncated at
+    F's frame.
+    """
+    result = tb
+    seen_module_level = False
+    for frame, lineno in traceback.walk_stack(tb.tb_frame):
+        if _ignore_known_hidden_frame(frame):
+            continue
+        at_module_level = frame.f_code.co_name == "<module>"
+        if seen_module_level and not at_module_level:
+            break
+        if include_frame(frame):
+            # Prepend as the new outermost entry of the chain.
+            result = types.TracebackType(result, frame, frame.f_lasti, lineno)
+        seen_module_level = seen_module_level or at_module_level
+    return result
+
+
+def _is_reraiser_frame(f: traceback.FrameSummary) -> bool:
+    """Returns True if ``f`` is an ``api_boundary`` wrapper frame from this module."""
+    if f.filename != __file__:
+        return False
+    return f.name == "reraise_with_filtered_traceback"
+
+
+def _is_under_reraiser(e: BaseException) -> bool:
+    """Returns True if the frame where ``e``'s traceback begins runs under an api_boundary wrapper.
+
+    Used so that only the outermost ``api_boundary`` filters the traceback.
+    """
+    tb = e.__traceback__
+    if tb is None:
+        return False
+    # extract_stack lists frames oldest-first and ends with tb.tb_frame itself;
+    # only its callers are checked.
+    callers = traceback.extract_stack(tb.tb_frame)[:-1]
+    return any(map(_is_reraiser_frame, callers))
+
+
+def format_exception_only(e: BaseException) -> str:
+    """Returns the ``"ExcType: message"`` summary of ``e`` (no traceback), whitespace-stripped."""
+    lines = traceback.format_exception_only(type(e), e)
+    return "".join(lines).strip()
+
+
+class UnfilteredStackTrace(Exception):
+    """Carries the unfiltered stack trace in "remove_frames" mode.
+
+    ``api_boundary`` attaches an instance as the ``__cause__`` of the original
+    exception, whose own traceback has been filtered.
+    """
+
+
+# Note that api_boundary attaches (via add_note) to exceptions whose traceback
+# it filtered in "quiet_remove_frames" mode. No setting for turning filtering
+# off is read in this module, so the message does not point users at one.
+_simplified_tb_msg = (
+    "For simplicity, MLIR has removed its internal frames from the "
+    "traceback of the following exception."
+)
+
+
+class SimplifiedTraceback(Exception):
+    """Exception whose string form is the simplified-traceback explanation."""
+
+    def __str__(self):
+        return _simplified_tb_msg
+
+
+# __module__ is intentionally left as this defining module: pickle and
+# traceback display resolve classes through it, and re-homing it to a module
+# that is not part of MLIR (e.g. "jax.errors") breaks both.
+
+
+def _running_under_ipython() -> bool:
+    """Returns true if we appear to be in an IPython session."""
+    # Relies on the name get_ipython only resolving inside an IPython session;
+    # it is not defined or imported by this module.
+    try:
+        get_ipython()  # type: ignore
+    except NameError:
+        return False
+    else:
+        return True
+
+
+def _ipython_supports_tracebackhide() -> bool:
+    """Returns true if the IPython version supports __tracebackhide__."""
+    # Imported lazily: IPython is optional, and this is only consulted after
+    # _running_under_ipython() has returned true.
+    import IPython  # pytype: disable=import-error
+
+    major, minor = IPython.version_info[:2]
+    return (major, minor) >= (7, 17)
+
+
+def _filtering_mode() -> str:
+    """Chooses how ``api_boundary`` rewrites tracebacks.
+
+    Returns ``"tracebackhide"`` inside an IPython session whose version
+    supports ``__tracebackhide__``, and ``"quiet_remove_frames"`` otherwise.
+    The mode is always auto-detected: nothing here selects ``"off"`` or
+    ``"remove_frames"``, although ``api_boundary`` still handles them.
+    """
+    if not _running_under_ipython():
+        return "quiet_remove_frames"
+    if _ipython_supports_tracebackhide():
+        return "tracebackhide"
+    return "quiet_remove_frames"
+
+
+def api_boundary(fun: C) -> C:
+    """Wraps ``fun`` to form a boundary for filtering exception tracebacks.
+
+    When an ``Exception`` escapes ``fun``, its traceback is rewritten to drop
+    frames from excluded (library-internal) files. How it is rewritten depends
+    on ``_filtering_mode()``:
+
+    * ``"tracebackhide"``: excluded frames are marked with ``__tracebackhide__``
+      and the exception is re-raised with its traceback untouched.
+    * ``"quiet_remove_frames"``: the traceback is replaced by a filtered one
+      and, where ``BaseException.add_note`` exists (Python 3.11+), a note
+      explaining the simplification is attached.
+    * ``"remove_frames"``: as above, but instead of a note an
+      ``UnfilteredStackTrace`` carrying the full stack is chained in as the
+      exception's ``__cause__``.
+    * ``"off"``: the exception propagates untouched.
+
+    The boundary composes with itself: only the outermost active
+    ``api_boundary`` filters. If ``api_boundary(f)`` calls ``api_boundary(g)``,
+    directly or indirectly, the result is the same as if ``api_boundary(f)``
+    simply called ``g``.
+
+    Args:
+      fun: the callable to wrap.
+
+    Returns:
+      A ``functools.wraps``-decorated wrapper that calls ``fun`` and re-raises
+      any ``Exception`` from it as described above.
+
+    Raises:
+      ValueError: (from the wrapper) if the filtering mode is not one of the
+        values above.
+    """
+
+    @functools.wraps(fun)
+    def reraise_with_filtered_traceback(*args, **kwargs):
+        # Marks this wrapper frame for tools that honor __tracebackhide__.
+        __tracebackhide__ = True
+        try:
+            return fun(*args, **kwargs)
+        except Exception as e:
+            mode = _filtering_mode()
+            if _is_under_reraiser(e) or mode == "off":
+                raise
+            if mode == "tracebackhide":
+                _add_tracebackhide_to_hidden_frames(e.__traceback__)
+                raise
+
+            filtered_tb = None
+            try:
+                tb = e.__traceback__
+                filtered_tb = filter_traceback(tb)
+                e.with_traceback(filtered_tb)
+                if mode == "quiet_remove_frames":
+                    # add_note only exists on Python 3.11+. Calling it on an older
+                    # interpreter would replace the user's exception with an
+                    # AttributeError, so there we re-raise without the note.
+                    if hasattr(e, "add_note"):
+                        e.add_note("--------------------\n" + _simplified_tb_msg)
+                elif mode == "remove_frames":
+                    msg = format_exception_only(e)
+                    msg = f"{msg}\n\n{_jax_message_append}"
+                    unfiltered_error = UnfilteredStackTrace(msg)
+                    unfiltered_error.with_traceback(_add_call_stack_frames(tb))
+                    # Splice the unfiltered error between e and whatever e was
+                    # originally chained to.
+                    unfiltered_error.__cause__ = e.__cause__
+                    unfiltered_error.__context__ = e.__context__
+                    unfiltered_error.__suppress_context__ = e.__suppress_context__
+                    e.__cause__ = unfiltered_error
+                    e.__context__ = None
+                else:
+                    raise ValueError(
+                        f"traceback filtering mode {mode!r} is not a valid value."
+                    )
+                raise
+            finally:
+                # Drop these locals explicitly, presumably so the traceback
+                # objects don't keep this frame alive through a reference cycle.
+                del filtered_tb
+                del mode
+
+    return cast(C, reraise_with_filtered_traceback)



More information about the Mlir-commits mailing list