[Mlir-commits] [mlir] [mlir][python] source line info (PR #149166)

Maksim Levental llvmlistbot at llvm.org
Tue Jul 29 14:49:39 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/149166

>From ebf792d978e44d0df8ba8d538972096fe9fd00bd Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Wed, 16 Jul 2025 14:50:30 -0400
Subject: [PATCH 1/4] [mlir][python] source line info

---
 mlir/lib/Bindings/Python/MainModule.cpp |   2 +
 mlir/lib/Bindings/Python/Traceback.cpp  | 442 ++++++++++++++++++++++++
 mlir/lib/Bindings/Python/Traceback.h    |  63 ++++
 mlir/python/CMakeLists.txt              |   4 +
 mlir/python/mlir/source_info_util.py    | 367 ++++++++++++++++++++
 mlir/python/mlir/traceback_util.py      | 238 +++++++++++++
 6 files changed, 1116 insertions(+)
 create mode 100644 mlir/lib/Bindings/Python/Traceback.cpp
 create mode 100644 mlir/lib/Bindings/Python/Traceback.h
 create mode 100644 mlir/python/mlir/source_info_util.py
 create mode 100644 mlir/python/mlir/traceback_util.py

diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 6f49431006605..489d8e21a56cd 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -12,6 +12,7 @@
 #include "NanobindUtils.h"
 #include "Pass.h"
 #include "Rewrite.h"
+#include "Traceback.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 
 namespace nb = nanobind;
@@ -105,6 +106,7 @@ NB_MODULE(_mlir, m) {
       "typeid"_a, nb::kw_only(), "replace"_a = false,
       "Register a value caster for casting MLIR values to custom user values.");
 
+  BuildTracebackSubmodule(m);
   // Define and populate IR submodule.
   auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
   populateIRCore(irModule);
diff --git a/mlir/lib/Bindings/Python/Traceback.cpp b/mlir/lib/Bindings/Python/Traceback.cpp
new file mode 100644
index 0000000000000..fee85cca6574f
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Traceback.cpp
@@ -0,0 +1,442 @@
+/* Copyright 2020 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "Traceback.h"
+
+#include <Python.h>
+
+#include <array>
+#include <atomic>
+#include <cstddef>
+#include <cstring>
+#include <optional>
+#include <string>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include "nanobind/nanobind.h"
+#include "nanobind/stl/optional.h"    // IWYU pragma: keep
+#include "nanobind/stl/string.h"      // IWYU pragma: keep
+#include "nanobind/stl/string_view.h" // IWYU pragma: keep
+#include "nanobind/stl/vector.h"      // IWYU pragma: keep
+
+#include "llvm/ADT/StringExtras.h"
+
+#ifdef PLATFORM_GOOGLE
+#define Py_BUILD_CORE
+#include "internal/pycore_frame.h"
+#undef Py_BUILD_CORE
+#endif // PLATFORM_GOOGLE
+
+namespace mlir::python {
+struct TracebackEntry;
+struct TracebackObject;
+} // namespace mlir::python
+namespace nb = nanobind;
+
+template <>
+struct std::hash<mlir::python::TracebackObject> {
+  std::size_t
+  operator()(const mlir::python::TracebackObject &tb) const noexcept;
+};
+
+template <>
+struct std::hash<mlir::python::TracebackEntry> {
+  std::size_t
+  operator()(const mlir::python::TracebackEntry &tbe) const noexcept;
+};
+
+namespace mlir::python {
+
+std::atomic<bool> traceback_enabled_ = true;
+
+static constexpr int kMaxFrames = 512;
+
+PyTypeObject *traceback_type_ = nullptr;
+
+// Entry in a traceback. Must be POD.
+struct TracebackEntry {
+  TracebackEntry() = default;
+  TracebackEntry(PyCodeObject *code, int lasti) : code(code), lasti(lasti) {}
+  PyCodeObject *code;
+  int lasti;
+
+  bool operator==(const TracebackEntry &other) const {
+    return code == other.code && lasti == other.lasti;
+  }
+  bool operator!=(const TracebackEntry &other) const {
+    return !operator==(other);
+  }
+};
+static_assert(std::is_trivial_v<TracebackEntry> == true);
+
+struct TracebackObject {
+  PyObject_VAR_HEAD;
+  TracebackEntry frames[];
+};
+
+static_assert(sizeof(TracebackObject) % alignof(PyObject) == 0);
+static_assert(sizeof(TracebackEntry) % alignof(void *) == 0);
+
+bool traceback_check(nb::handle o) {
+  return Py_TYPE(o.ptr()) == traceback_type_;
+}
+
+Py_hash_t traceback_tp_hash(PyObject *o) {
+  TracebackObject *tb = reinterpret_cast<TracebackObject *>(o);
+  std::hash<TracebackObject> hasher{};
+  size_t h = hasher(*tb);
+  Py_hash_t s = llvm::bit_cast<Py_hash_t>(h); // Python hashes are signed.
+  return s == -1 ? -2 : s; // -1 must not be used as a Python hash value.
+}
+
+// Rich comparison slot: only == and != are supported, compared frame by frame.
+PyObject *traceback_tp_richcompare(PyObject *self, PyObject *other, int op) {
+  if (op != Py_EQ && op != Py_NE) {
+    return Py_NewRef(Py_NotImplemented);
+  }
+
+  // A non-Traceback is never equal to a Traceback, so `!=` must yield True
+  // here; the answer depends on `op` just like the comparisons below.
+  if (!traceback_check(other)) {
+    return Py_NewRef(op == Py_EQ ? Py_False : Py_True);
+  }
+  TracebackObject *tb_self = reinterpret_cast<TracebackObject *>(self);
+  TracebackObject *tb_other = reinterpret_cast<TracebackObject *>(other);
+  bool equal = Py_SIZE(tb_self) == Py_SIZE(tb_other);
+  for (Py_ssize_t i = 0; equal && i < Py_SIZE(tb_self); ++i) {
+    equal = tb_self->frames[i] == tb_other->frames[i];
+  }
+  // True for a satisfied == (all frames match) or a satisfied != (mismatch).
+  return Py_NewRef(equal == (op == Py_EQ) ? Py_True : Py_False);
+}
+
+static void traceback_tp_dealloc(PyObject *self) {
+  TracebackObject *tb = reinterpret_cast<TracebackObject *>(self);
+  for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) {
+    Py_XDECREF(tb->frames[i].code);
+  }
+  PyTypeObject *tp = Py_TYPE(self);
+  tp->tp_free((PyObject *)self);
+  Py_DECREF(tp);
+}
+
+// Decodes a raw traceback entry; fields are positional (no C++20 designators).
+Traceback::Frame DecodeFrame(const TracebackEntry &frame) {
+  PyCodeObject *code = frame.code;
+  return {/*file_name=*/nb::borrow<nb::str>(code->co_filename),
+          /*function_name=*/nb::borrow<nb::str>(code->co_qualname),
+          /*function_start_line=*/code->co_firstlineno,
+          /*line_num=*/PyCode_Addr2Line(code, frame.lasti)};
+}
+
+std::string traceback_to_string(const TracebackObject *tb) {
+  std::vector<std::string> frame_strs;
+  frame_strs.reserve(Py_SIZE(tb));
+  for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) {
+    frame_strs.push_back(DecodeFrame(tb->frames[i]).ToString());
+  }
+  return llvm::join(frame_strs, "\n");
+}
+
+PyObject *traceback_tp_str(PyObject *self) {
+  TracebackObject *tb = reinterpret_cast<TracebackObject *>(self);
+  return nb::cast(traceback_to_string(tb)).release().ptr();
+}
+
+// It turns out to be slightly faster to define a tp_hash slot rather than
+// defining __hash__ and __eq__ on the class.
+PyType_Slot traceback_slots_[] = {
+    {Py_tp_hash, reinterpret_cast<void *>(traceback_tp_hash)},
+    {Py_tp_richcompare, reinterpret_cast<void *>(traceback_tp_richcompare)},
+    {Py_tp_dealloc, reinterpret_cast<void *>(traceback_tp_dealloc)},
+    {Py_tp_str, reinterpret_cast<void *>(traceback_tp_str)},
+    {0, nullptr},
+};
+
+nb::object AsPythonTraceback(const Traceback &tb) {
+  nb::object traceback = nb::none();
+  nb::dict globals;
+  nb::handle traceback_type(reinterpret_cast<PyObject *>(&PyTraceBack_Type));
+  TracebackObject *tb_obj = reinterpret_cast<TracebackObject *>(tb.ptr());
+  for (Py_ssize_t i = 0; i < Py_SIZE(tb_obj); ++i) {
+    const TracebackEntry &frame = tb_obj->frames[i];
+    int lineno = PyCode_Addr2Line(frame.code, frame.lasti);
+    // Under Python 3.11 we observed crashes when using a fake PyFrameObject
+    // with a real PyCodeObject (https://github.com/google/jax/issues/16027).
+    // because the frame does not have fields necessary to compute the locals,
+    // notably the closure object, leading to crashes in CPython in
+    // _PyFrame_FastToLocalsWithError
+    // https://github.com/python/cpython/blob/deaf509e8fc6e0363bd6f26d52ad42f976ec42f2/Objects/frameobject.c#LL1116C2-L1116C2
+    // We therefore always build a fake code object to go along with our fake
+    // frame.
+    PyCodeObject *py_code =
+        PyCode_NewEmpty(PyUnicode_AsUTF8(frame.code->co_filename),
+                        PyUnicode_AsUTF8(frame.code->co_name), lineno);
+    PyFrameObject *py_frame = PyFrame_New(PyThreadState_Get(), py_code,
+                                          globals.ptr(), /*locals=*/nullptr);
+    Py_DECREF(py_code);
+
+    traceback = traceback_type(
+        /*tb_next=*/std::move(traceback),
+        /*tb_frame=*/
+        nb::steal<nb::object>(reinterpret_cast<PyObject *>(py_frame)),
+        /*tb_lasti=*/0,
+        /*tb_lineno=*/lineno);
+  }
+  return traceback;
+}
+
+std::vector<Traceback::Frame> Traceback::Frames() const {
+  // We require the GIL because we manipulate Python strings.
+  assert(PyGILState_Check());
+  std::vector<Traceback::Frame> frames;
+  TracebackObject *tb = reinterpret_cast<TracebackObject *>(ptr());
+  frames.reserve(Py_SIZE(tb));
+  for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) {
+    const TracebackEntry &frame = tb->frames[i];
+    frames.push_back(DecodeFrame(frame));
+  }
+  return frames;
+}
+
+std::string Traceback::Frame::ToString() const {
+  std::string s = nb::cast<std::string>(file_name);
+  s += ":" + std::to_string(line_num) + " ";
+  s += "(" + nb::cast<std::string>(function_name) + ")";
+  return s;
+}
+
+std::string Traceback::ToString() const {
+  return traceback_to_string(reinterpret_cast<const TracebackObject *>(ptr()));
+}
+
+std::vector<std::pair<PyCodeObject *, int>> Traceback::RawFrames() const {
+  const TracebackObject *tb = reinterpret_cast<const TracebackObject *>(ptr());
+  std::vector<std::pair<PyCodeObject *, int>> frames;
+  frames.reserve(Py_SIZE(tb));
+  for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) {
+    frames.push_back(std::make_pair(tb->frames[i].code, tb->frames[i].lasti));
+  }
+  return frames;
+}
+
+/*static*/ bool Traceback::Check(PyObject *o) { return traceback_check(o); }
+
+// Captures the calling thread's Python stack (at most kMaxFrames frames).
+/*static*/ std::optional<Traceback> Traceback::Get() {
+  // We use a thread_local here mostly to avoid requiring a large amount of
+  // space.
+  thread_local std::array<TracebackEntry, kMaxFrames> frames;
+  int count = 0;
+
+  assert(PyGILState_Check());
+  if (!traceback_enabled_.load()) {
+    return std::nullopt;
+  }
+
+  PyThreadState *thread_state = PyThreadState_GET();
+
+#ifdef PLATFORM_GOOGLE
+// This code is equivalent to the version using public APIs, but it saves us
+// an allocation of one object per stack frame. However, this is definitely
+// violating the API contract of CPython, so we only use this where we can be
+// confident we know exactly which CPython we are using (internal to Google).
+// Feel free to turn this on if you like, but it might break at any time!
+#if PY_VERSION_HEX < 0x030d0000
+  for (_PyInterpreterFrame *f = thread_state->cframe->current_frame;
+       f != nullptr && count < kMaxFrames; f = f->previous) {
+    if (_PyFrame_IsIncomplete(f))
+      continue;
+    Py_INCREF(f->f_code);
+    frames[count] = {f->f_code, static_cast<int>(_PyInterpreterFrame_LASTI(f) *
+                                                 sizeof(_Py_CODEUNIT))};
+    ++count;
+  }
+#else  // PY_VERSION_HEX < 0x030d0000
+  for (_PyInterpreterFrame *f = thread_state->current_frame;
+       f != nullptr && count < kMaxFrames; f = f->previous) {
+    if (_PyFrame_IsIncomplete(f))
+      continue;
+    Py_INCREF(f->f_executable);
+    frames[count] = {
+        reinterpret_cast<PyCodeObject *>(f->f_executable),
+        static_cast<int>(_PyInterpreterFrame_LASTI(f) * sizeof(_Py_CODEUNIT))};
+    ++count;
+  }
+#endif // PY_VERSION_HEX < 0x030d0000
+
+#else  // PLATFORM_GOOGLE
+  PyFrameObject *py_frame = PyThreadState_GetFrame(thread_state);
+  for (; py_frame != nullptr && count < kMaxFrames; ++count) {
+    frames[count] = {PyFrame_GetCode(py_frame), PyFrame_GetLasti(py_frame)};
+    PyFrameObject *next = PyFrame_GetBack(py_frame);
+    Py_DECREF(py_frame);
+    py_frame = next;
+  }
+  Py_XDECREF(py_frame); // Drop the frame left over when kMaxFrames was hit.
+#endif // PLATFORM_GOOGLE
+
+  Traceback traceback =
+      nb::steal<Traceback>(PyObject_NewVar(PyObject, traceback_type_, count));
+  TracebackObject *tb = reinterpret_cast<TracebackObject *>(traceback.ptr());
+  std::memcpy(tb->frames, frames.data(), sizeof(TracebackEntry) * count);
+  return traceback;
+}
+
+template <typename Func>
+nanobind::object nb_property_readonly(Func &&get) {
+  nanobind::handle property(reinterpret_cast<PyObject *>(&PyProperty_Type));
+  return property(nanobind::cpp_function(std::forward<Func>(get)),
+                  nanobind::none(), nanobind::none(), "");
+}
+
+// Registers Frame, Traceback and the traceback helper functions on `m`.
+void BuildTracebackSubmodule(nb::module_ &m) {
+  nb::class_<Traceback::Frame>(m, "Frame")
+      .def(nb::init<const nb::str &, const nb::str &, int, int>())
+      .def_ro("file_name", &Traceback::Frame::file_name)
+      .def_ro("function_name", &Traceback::Frame::function_name)
+      .def_ro("function_start_line", &Traceback::Frame::function_start_line)
+      .def_ro("line_num", &Traceback::Frame::line_num)
+      .def("__repr__", [](const Traceback::Frame &frame) {
+        std::string s = nb::cast<std::string>(frame.function_name);
+        s += ";" + nb::cast<std::string>(frame.file_name);
+        s += ":" + std::to_string(frame.line_num);
+        return s;
+      });
+
+  std::string name = nb::cast<std::string>(m.attr("__name__"));
+  name += ".Traceback";
+
+  PyType_Spec traceback_spec = {
+      /*.name=*/name.c_str(),
+      /*.basicsize=*/static_cast<int>(sizeof(TracebackObject)),
+      /*.itemsize=*/static_cast<int>(sizeof(TracebackEntry)),
+      /*.flags=*/Py_TPFLAGS_DEFAULT,
+      /*.slots=*/traceback_slots_,
+  };
+
+  traceback_type_ =
+      reinterpret_cast<PyTypeObject *>(PyType_FromSpec(&traceback_spec));
+  if (!traceback_type_) {
+    throw nb::python_error();
+  }
+
+  auto type = nb::borrow<nb::object>(traceback_type_);
+  m.attr("Traceback") = type;
+
+  m.def("tracebacks_enabled", []() { return traceback_enabled_.load(); });
+  m.def("set_tracebacks_enabled",
+        [](bool value) { traceback_enabled_.store(value); });
+
+  type.attr("get_traceback") = nb::cpp_function(Traceback::Get,
+                                                R"doc(
+      Returns a :class:`Traceback` for the current thread.
+
+      If traceback collection is enabled (see ``tracebacks_enabled``), returns
+      a :class:`Traceback` that describes the Python stack of the calling
+      thread. Collection has a small overhead and is enabled by default; after
+      ``set_tracebacks_enabled(False)`` this returns ``None``. )doc");
+  type.attr("frames") = nb_property_readonly(&Traceback::Frames);
+  type.attr("raw_frames") = nb::cpp_function(
+      [](const Traceback &tb) -> nb::tuple {
+        // We return a tuple of lists, rather than a list of tuples, because it
+        // is cheaper to allocate only three Python objects for everything
+        // rather than one per frame.
+        std::vector<std::pair<PyCodeObject *, int>> frames = tb.RawFrames();
+        nb::list out_code = nb::steal<nb::list>(PyList_New(frames.size()));
+        nb::list out_lasti = nb::steal<nb::list>(PyList_New(frames.size()));
+        for (size_t i = 0; i < frames.size(); ++i) {
+          const auto &frame = frames[i];
+          PyObject *code = reinterpret_cast<PyObject *>(frame.first);
+          Py_INCREF(code);
+          PyList_SET_ITEM(out_code.ptr(), i, code);
+          PyList_SET_ITEM(out_lasti.ptr(), i,
+                          nb::int_(frame.second).release().ptr());
+        }
+        return nb::make_tuple(out_code, out_lasti);
+      },
+      nb::is_method());
+  type.attr("as_python_traceback") =
+      nb::cpp_function(AsPythonTraceback, nb::is_method());
+
+  type.attr("traceback_from_frames") = nb::cpp_function(
+      [](std::vector<Traceback::Frame> frames) {
+        nb::object traceback = nb::none();
+        nb::dict globals;
+        nb::handle traceback_type(
+            reinterpret_cast<PyObject *>(&PyTraceBack_Type));
+        for (const Traceback::Frame &frame : frames) {
+          PyCodeObject *py_code =
+              PyCode_NewEmpty(frame.file_name.c_str(),
+                              frame.function_name.c_str(), frame.line_num);
+          PyFrameObject *py_frame = PyFrame_New(PyThreadState_Get(), py_code,
+                                                globals.ptr(), /*locals=*/
+                                                nullptr);
+          Py_DECREF(py_code);
+          traceback = traceback_type(
+              /*tb_next=*/std::move(traceback),
+              /*tb_frame=*/
+              nb::steal<nb::object>(reinterpret_cast<PyObject *>(py_frame)),
+              /*tb_lasti=*/0,
+              /*tb_lineno=*/frame.line_num);
+        }
+        return traceback;
+      },
+      "Creates a traceback from a list of frames.");
+
+  type.attr("code_addr2line") = nb::cpp_function(
+      [](nb::handle code, int lasti) {
+        if (!PyCode_Check(code.ptr())) {
+          throw std::runtime_error("code argument must be a code object");
+        }
+        return PyCode_Addr2Line(reinterpret_cast<PyCodeObject *>(code.ptr()),
+                                lasti);
+      },
+      "Python wrapper around the Python C API function PyCode_Addr2Line");
+
+  type.attr("code_addr2location") = nb::cpp_function(
+      [](nb::handle code, int lasti) {
+        if (!PyCode_Check(code.ptr())) {
+          throw std::runtime_error("code argument must be a code object");
+        }
+        int start_line, start_column, end_line, end_column;
+        if (!PyCode_Addr2Location(reinterpret_cast<PyCodeObject *>(code.ptr()),
+                                  lasti, &start_line, &start_column, &end_line,
+                                  &end_column)) {
+          throw nb::python_error();
+        }
+        return nb::make_tuple(start_line, start_column, end_line, end_column);
+      },
+      "Python wrapper around the Python C API function PyCode_Addr2Location");
+}
+} // namespace mlir::python
+
+// Hashes fields rather than raw bytes, which could include struct padding.
+std::size_t std::hash<mlir::python::TracebackObject>::operator()(
+    const mlir::python::TracebackObject &tb) const noexcept {
+  const Py_ssize_t length = Py_SIZE(&tb);
+  llvm::hash_code h = llvm::hash_value(length);
+  for (Py_ssize_t i = 0; i < length; ++i)
+    h = llvm::hash_combine(h, tb.frames[i].code, tb.frames[i].lasti);
+  return h;
+}
+
+std::size_t std::hash<mlir::python::TracebackEntry>::operator()(
+    const mlir::python::TracebackEntry &tbe) const noexcept {
+  return llvm::hash_combine(tbe.code, tbe.lasti);
+}
\ No newline at end of file
diff --git a/mlir/lib/Bindings/Python/Traceback.h b/mlir/lib/Bindings/Python/Traceback.h
new file mode 100644
index 0000000000000..0aab15ddc5da3
--- /dev/null
+++ b/mlir/lib/Bindings/Python/Traceback.h
@@ -0,0 +1,63 @@
+/* Copyright 2020 The JAX Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef JAXLIB_TRACEBACK_H_
+#define JAXLIB_TRACEBACK_H_
+
+#include <Python.h>
+
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+// placeholder for index annotation headers
+#include "nanobind/nanobind.h"
+
+namespace mlir::python {
+
+class Traceback : public nanobind::object {
+public:
+  NB_OBJECT(Traceback, nanobind::object, "Traceback", Traceback::Check);
+
+  // Returns a traceback if it is enabled, otherwise returns nullopt.
+  static std::optional<Traceback> Get();
+
+  // Returns a string representation of the traceback.
+  std::string ToString() const;
+
+  // Returns a list of (code, lasti) pairs for each frame in the traceback.
+  std::vector<std::pair<PyCodeObject *, int>> RawFrames() const;
+
+  struct Frame {
+    nanobind::str file_name;
+    nanobind::str function_name;
+    int function_start_line;
+    int line_num;
+
+    std::string ToString() const;
+  };
+  // Returns a list of Frames for the traceback.
+  std::vector<Frame> Frames() const;
+
+private:
+  static bool Check(PyObject *o);
+};
+
+void BuildTracebackSubmodule(nanobind::module_ &m);
+
+} // namespace mlir::python
+
+#endif // JAXLIB_TRACEBACK_H_
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 7a0c95ebb8200..243fbe64de900 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -23,6 +23,8 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python
     passmanager.py
     rewrite.py
     dialects/_ods_common.py
+    source_info_util.py
+    traceback_util.py
 
     # The main _mlir module has submodules: include stubs from each.
     _mlir_libs/_mlir/__init__.pyi
@@ -486,6 +488,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
     IRTypes.cpp
     Pass.cpp
     Rewrite.cpp
+    Traceback.cpp
 
     # Headers must be included explicitly so they are installed.
     Globals.h
@@ -493,6 +496,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
     Pass.h
     NanobindUtils.h
     Rewrite.h
+    Traceback.h
   PRIVATE_LINK_LIBS
     LLVMSupport
   EMBED_CAPI_LINK_LIBS
diff --git a/mlir/python/mlir/source_info_util.py b/mlir/python/mlir/source_info_util.py
new file mode 100644
index 0000000000000..ff7ee09149c60
--- /dev/null
+++ b/mlir/python/mlir/source_info_util.py
@@ -0,0 +1,367 @@
+# Copyright 2020 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Iterator
+import contextlib
+import dataclasses
+import functools
+import itertools
+import os.path
+import re
+import sysconfig
+import threading
+import types
+from typing import NamedTuple
+
+# import jax.version
+# from jax._src.lib import xla_client
+
+from . import traceback_util
+
+traceback_util.register_exclusion(__file__)
+
+from ._mlir_libs._mlir import Traceback
+
+
+class Frame(NamedTuple):
+    file_name: str
+    function_name: str
+    start_line: int
+    start_column: int
+    end_line: int
+    end_column: int
+
+
+_exclude_paths: list[str] = [
+    # Attach the separator to make sure that .../jax does not end up matching
+    # .../jax_triton and other packages that might have a jax prefix.
+    # os.path.dirname(jax.version.__file__) + os.sep,
+    # Also exclude stdlib as user frames. In a non-standard Python runtime,
+    # the following two may be different.
+    sysconfig.get_path("stdlib"),
+    os.path.dirname(sysconfig.__file__),
+]
+
+
+@functools.cache
+def _exclude_path_regex() -> re.Pattern[str]:
+    # The regex below would not handle an empty set of exclusions correctly.
+    assert len(_exclude_paths) > 0
+    return re.compile("|".join(f"^{re.escape(path)}" for path in _exclude_paths))
+
+
+def register_exclusion(path: str):
+    _exclude_paths.append(path)
+    _exclude_path_regex.cache_clear()
+    is_user_filename.cache_clear()
+
+
+# Explicit inclusions take priority over exclude paths.
+_include_paths: list[str] = []
+
+
+@functools.cache
+def _include_path_regex() -> re.Pattern[str]:
+    patterns = [f"^{re.escape(path)}" for path in _include_paths]
+    patterns.append("_test.py$")
+    return re.compile("|".join(patterns))
+
+
+def register_inclusion(path: str):
+    _include_paths.append(path)
+    _include_path_regex.cache_clear()
+    is_user_filename.cache_clear()
+
+
+class Scope(NamedTuple):
+    name: str
+
+    def wrap(self, stack: list[str]):
+        stack.append(self.name)
+
+
+class Transform(NamedTuple):
+    name: str
+
+    def wrap(self, stack: list[str]):
+        if stack:
+            stack[-1] = f"{self.name}({stack[-1]})"
+        else:
+            stack.append(f"{self.name}()")
+
+
+@dataclasses.dataclass(frozen=True)
+class NameStack:
+    stack: tuple[Scope | Transform, ...] = ()
+
+    def extend(self, name: str) -> NameStack:
+        return NameStack((*self.stack, Scope(name)))
+
+    def transform(self, transform_name: str) -> NameStack:
+        return NameStack((*self.stack, Transform(transform_name)))
+
+    def __getitem__(self, idx: slice) -> NameStack:
+        return NameStack(self.stack[idx])
+
+    def __len__(self):
+        return len(self.stack)
+
+    def __add__(self, other: NameStack) -> NameStack:
+        return NameStack(self.stack + other.stack)
+
+    def __radd__(self, other: NameStack) -> NameStack:
+        return NameStack(other.stack + self.stack)
+
+    def __str__(self) -> str:
+        scope: list[str] = []
+        for elem in self.stack[::-1]:
+            elem.wrap(scope)
+        return "/".join(reversed(scope))
+
+
+def new_name_stack(name: str = "") -> NameStack:
+    name_stack = NameStack()
+    if name:
+        name_stack = name_stack.extend(name)
+    return name_stack
+
+
+class SourceInfo:
+    traceback: Traceback | None
+    name_stack: NameStack
+
+    # It's slightly faster to use a class with __slots__ than a NamedTuple.
+    __slots__ = ["traceback", "name_stack"]
+
+    def __init__(self, traceback: Traceback | None, name_stack: NameStack):
+        self.traceback = traceback
+        self.name_stack = name_stack
+
+    def replace(
+        self, *, traceback: Traceback | None = None, name_stack: NameStack | None = None
+    ) -> SourceInfo:
+        return SourceInfo(
+            self.traceback if traceback is None else traceback,
+            self.name_stack if name_stack is None else name_stack,
+        )
+
+
+def new_source_info() -> SourceInfo:
+    return SourceInfo(None, NameStack())
+
+
+@functools.cache
+def is_user_filename(filename: str) -> bool:
+    """Heuristic that guesses the identity of the user's code in a stack trace."""
+    return (
+        _include_path_regex().search(filename) is not None
+        or _exclude_path_regex().search(filename) is None
+    )
+
+
+def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame:
+    loc = Traceback.code_addr2location(code, lasti)
+    start_line, start_column, end_line, end_column = loc
+    return Frame(
+        file_name=code.co_filename,
+        function_name=code.co_qualname,
+        start_line=start_line,
+        start_column=start_column,
+        end_line=end_line,
+        end_column=end_column,
+    )
+
+
+def user_frames(traceback: Traceback | None) -> Iterator[Frame]:
+    """Iterator over the user's frames, filtering jax-internal frames."""
+    # Guess the user's frame is the innermost frame not in the jax source tree or
+    # Python stdlib. We don't use traceback_util.path_starts_with because that
+    # incurs filesystem access, which may be slow; we call this function when
+    # e.g. adding source provenance annotations to XLA lowerings, so we don't
+    # want to incur the cost. We consider files that end with _test.py as user
+    # frames, to allow testing this mechanism from tests.
+    code, lasti = traceback.raw_frames() if traceback else ([], [])
+    return (
+        raw_frame_to_frame(code[i], lasti[i])
+        for i in range(len(code))
+        if is_user_filename(code[i].co_filename)
+    )
+
+
+@functools.lru_cache(maxsize=64)
+def user_frame(traceback: Traceback | None) -> Frame | None:
+    return next(user_frames(traceback), None)
+
+
+def _summarize_frame(frame: Frame) -> str:
+    if frame.start_column != 0:
+        return (
+            f"{frame.file_name}:{frame.start_line}:{frame.start_column} "
+            f"({frame.function_name})"
+        )
+    else:
+        return f"{frame.file_name}:{frame.start_line} ({frame.function_name})"
+
+
+def summarize(source_info: SourceInfo, num_frames=1) -> str:
+    frames = itertools.islice(user_frames(source_info.traceback), num_frames)
+    frame_strs = [_summarize_frame(frame) if frame else "unknown" for frame in frames]
+    return "\n".join(reversed(frame_strs))
+
+
+class _SourceInfoContext(threading.local):
+    context: SourceInfo
+
+    def __init__(self):
+        self.context = new_source_info()
+
+
+_source_info_context = _SourceInfoContext()
+
+
+def current() -> SourceInfo:
+    source_info = _source_info_context.context
+    if not source_info.traceback:
+        source_info = source_info.replace(traceback=Traceback.get_traceback())
+    return source_info
+
+
+class JaxStackTraceBeforeTransformation(Exception):
+    pass
+
+
+_message = (
+    "The preceding stack trace is the source of the JAX operation that, once "
+    "transformed by JAX, triggered the following exception.\n"
+    "\n--------------------"
+)
+
+
+def has_user_context(e):
+    while e is not None:
+        if isinstance(e, JaxStackTraceBeforeTransformation):
+            return True
+        e = e.__cause__
+    return False
+
+
+class UserContextManager:
+    __slots__ = ["traceback", "name_stack", "prev"]
+
+    def __init__(
+        self, traceback: Traceback | None, *, name_stack: NameStack | None = None
+    ):
+        self.traceback = traceback
+        self.name_stack = name_stack
+
+    def __enter__(self):
+        self.prev = _source_info_context.context
+        _source_info_context.context = _source_info_context.context.replace(
+            traceback=self.traceback, name_stack=self.name_stack
+        )
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        _source_info_context.context = self.prev
+        if exc_type is None or exc_value is None:
+            return
+
+        if self.traceback is None or has_user_context(exc_value):
+            return
+
+        filtered_tb = traceback_util.filter_traceback(
+            self.traceback.as_python_traceback()
+        )
+        if filtered_tb:
+            msg = traceback_util.format_exception_only(exc_value)
+            msg = f"{msg}\n\n{_message}"
+            exp = JaxStackTraceBeforeTransformation(msg).with_traceback(filtered_tb)
+            exp.__context__ = exc_value.__context__
+            exp.__cause__ = exc_value.__cause__
+            exp.__suppress_context__ = exc_value.__suppress_context__
+            exc_value.__context__ = None
+            exc_value.__cause__ = exp
+
+
+user_context = UserContextManager
+
+
+def current_name_stack() -> NameStack:
+    return _source_info_context.context.name_stack
+
+
+class ExtendNameStackContextManager(contextlib.ContextDecorator):
+    __slots__ = ["name", "prev"]
+
+    def __init__(self, name: str):
+        self.name = name
+
+    def __enter__(self):
+        self.prev = prev = _source_info_context.context
+        name_stack = prev.name_stack.extend(self.name)
+        _source_info_context.context = prev.replace(name_stack=name_stack)
+        return name_stack
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        _source_info_context.context = self.prev
+
+
+extend_name_stack = ExtendNameStackContextManager
+
+
+class SetNameStackContextManager(contextlib.ContextDecorator):
+    __slots__ = ["name_stack", "prev"]
+
+    def __init__(self, name_stack: NameStack):
+        self.name_stack = name_stack
+
+    def __enter__(self):
+        self.prev = prev = _source_info_context.context
+        _source_info_context.context = prev.replace(name_stack=self.name_stack)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        _source_info_context.context = self.prev
+
+
+set_name_stack = SetNameStackContextManager
+
+
+# TODO(mattjj,phawkins): figure out why the commented-out reset_name_stack
+# implementation doesn't work. Luckily this context manager isn't called much so
+# the performance shouldn't matter. See blame commit message for repro.
+# reset_name_stack = lambda: SetNameStackContextManager(NameStack())
+@contextlib.contextmanager
+def reset_name_stack() -> Iterator[None]:
+    with set_name_stack(NameStack()):
+        yield
+
+
+class TransformNameStackContextManager(contextlib.ContextDecorator):
+    __slots__ = ["name", "prev"]
+
+    def __init__(self, name: str):
+        self.name = name
+
+    def __enter__(self):
+        self.prev = prev = _source_info_context.context
+        name_stack = prev.name_stack.transform(self.name)
+        _source_info_context.context = prev.replace(name_stack=name_stack)
+        return name_stack
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        _source_info_context.context = self.prev
+
+
+transform_name_stack = TransformNameStackContextManager
diff --git a/mlir/python/mlir/traceback_util.py b/mlir/python/mlir/traceback_util.py
new file mode 100644
index 0000000000000..1794fcd1d08ba
--- /dev/null
+++ b/mlir/python/mlir/traceback_util.py
@@ -0,0 +1,238 @@
+# Copyright 2020 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from collections.abc import Callable
+import functools
+import os
+import traceback
+import types
+from typing import Any, TypeVar, cast
+
+
+C = TypeVar("C", bound=Callable[..., Any])
+
+_exclude_paths: list[str] = [__file__]
+
+
+def register_exclusion(path: str):
+    _exclude_paths.append(path)
+
+
+_jax_message_append = (
+    "The stack trace below excludes JAX-internal frames.\n"
+    "The preceding is the original exception that occurred, unmodified.\n"
+    "\n--------------------"
+)
+
+
+def _path_starts_with(path: str, path_prefix: str) -> bool:
+    path = os.path.abspath(path)
+    path_prefix = os.path.abspath(path_prefix)
+    try:
+        common = os.path.commonpath([path, path_prefix])
+    except ValueError:
+        # path and path_prefix are both absolute, the only case will raise a
+        # ValueError is different drives.
+        # https://docs.python.org/3/library/os.path.html#os.path.commonpath
+        return False
+    try:
+        return common == path_prefix or os.path.samefile(common, path_prefix)
+    except OSError:
+        # One of the paths may not exist.
+        return False
+
+
+def include_frame(f: types.FrameType) -> bool:
+    return include_filename(f.f_code.co_filename)
+
+
+def include_filename(filename: str) -> bool:
+    return not any(_path_starts_with(filename, path) for path in _exclude_paths)
+
+
+# When scanning stack traces, we might encounter frames from cpython that are
+# removed from printed stack traces, such as frames from parts of importlib. We
+# ignore these frames heuristically based on source and name match.
+def _ignore_known_hidden_frame(f: types.FrameType) -> bool:
+    return "importlib._bootstrap" in f.f_code.co_filename
+
+
+def _add_tracebackhide_to_hidden_frames(tb: types.TracebackType):
+    for f, _lineno in traceback.walk_tb(tb):
+        if not include_frame(f):
+            f.f_locals["__tracebackhide__"] = True
+
+
+def filter_traceback(tb: types.TracebackType) -> types.TracebackType | None:
+    out = None
+    # Scan the traceback and collect relevant frames.
+    frames = list(traceback.walk_tb(tb))
+    for f, lineno in reversed(frames):
+        if include_frame(f):
+            out = types.TracebackType(out, f, f.f_lasti, lineno)
+    return out
+
+
+def _add_call_stack_frames(tb: types.TracebackType) -> types.TracebackType:
+    # Continue up the call stack.
+    #
+    # We would like to avoid stepping too far up, e.g. past the exec/eval point of
+    # a REPL such as IPython. To that end, we stop past the first contiguous bunch
+    # of module-level frames, if we reach any such frames at all. This is a
+    # heuristic that might stop in advance of the REPL boundary. For example, if
+    # the call stack includes module-level frames from the current module A, and
+    # the current module A was imported from within a function F elsewhere, then
+    # the stack trace we produce will be truncated at F's frame.
+    out = tb
+
+    reached_module_level = False
+    for f, lineno in traceback.walk_stack(tb.tb_frame):
+        if _ignore_known_hidden_frame(f):
+            continue
+        if reached_module_level and f.f_code.co_name != "<module>":
+            break
+        if include_frame(f):
+            out = types.TracebackType(out, f, f.f_lasti, lineno)
+        if f.f_code.co_name == "<module>":
+            reached_module_level = True
+    return out
+
+
+def _is_reraiser_frame(f: traceback.FrameSummary) -> bool:
+    return f.filename == __file__ and f.name == "reraise_with_filtered_traceback"
+
+
+def _is_under_reraiser(e: BaseException) -> bool:
+    if e.__traceback__ is None:
+        return False
+    tb = traceback.extract_stack(e.__traceback__.tb_frame)
+    return any(_is_reraiser_frame(f) for f in tb[:-1])
+
+
+def format_exception_only(e: BaseException) -> str:
+    return "".join(traceback.format_exception_only(type(e), e)).strip()
+
+
+class UnfilteredStackTrace(Exception):
+    pass
+
+
+_simplified_tb_msg = (
+    "For simplicity, JAX has removed its internal frames from the "
+    "traceback of the following exception. Set "
+    "JAX_TRACEBACK_FILTERING=off to include these."
+)
+
+
+class SimplifiedTraceback(Exception):
+    def __str__(self):
+        return _simplified_tb_msg
+
+
+SimplifiedTraceback.__module__ = "jax.errors"
+
+
+def _running_under_ipython() -> bool:
+    """Returns true if we appear to be in an IPython session."""
+    try:
+        get_ipython()  # type: ignore
+        return True
+    except NameError:
+        return False
+
+
+def _ipython_supports_tracebackhide() -> bool:
+    """Returns true if the IPython version supports __tracebackhide__."""
+    import IPython  # pytype: disable=import-error
+
+    return IPython.version_info[:2] >= (7, 17)
+
+
+def _filtering_mode() -> str:
+    mode = None
+    if mode is None or mode == "auto":
+        if _running_under_ipython() and _ipython_supports_tracebackhide():
+            mode = "tracebackhide"
+        else:
+            mode = "quiet_remove_frames"
+    return mode
+
+
+def api_boundary(fun: C) -> C:
+    """Wraps ``fun`` to form a boundary for filtering exception tracebacks.
+
+    When an exception occurs below ``fun``, this appends to it a custom
+    ``__cause__`` that carries a filtered traceback. The traceback imitates the
+    stack trace of the original exception, but with JAX-internal frames removed.
+
+    This boundary annotation works in composition with itself. The topmost frame
+    corresponding to an :func:`~api_boundary` is the one below which stack traces
+    are filtered. In other words, if ``api_boundary(f)`` calls
+    ``api_boundary(g)``, directly or indirectly, the filtered stack trace provided
+    is the same as if ``api_boundary(f)`` were to simply call ``g`` instead.
+
+    This annotation is primarily useful in wrapping functions output by JAX's
+    transformations. For example, consider ``g = jax.jit(f)``. When ``g`` is
+    called, JAX's JIT compilation machinery is invoked, which in turn calls ``f``
+    in order to trace and translate it. If the function ``f`` raises an exception,
+    the stack unwinds through JAX's JIT internals up to the original call site of
+    ``g``. Because the function returned by :func:`~jax.jit` is annotated as an
+    :func:`~api_boundary`, such an exception is accompanied by an additional
+    traceback that excludes the frames specific to JAX's implementation.
+    """
+
+    @functools.wraps(fun)
+    def reraise_with_filtered_traceback(*args, **kwargs):
+        __tracebackhide__ = True
+        try:
+            return fun(*args, **kwargs)
+        except Exception as e:
+            mode = _filtering_mode()
+            if _is_under_reraiser(e) or mode == "off":
+                raise
+            if mode == "tracebackhide":
+                _add_tracebackhide_to_hidden_frames(e.__traceback__)
+                raise
+
+            filtered_tb = None
+            try:
+                tb = e.__traceback__
+                filtered_tb = filter_traceback(tb)
+                e.with_traceback(filtered_tb)
+                if mode == "quiet_remove_frames":
+                    if hasattr(e, "add_note"):  # add_note exists on 3.11+ only.
+                        e.add_note("--------------------\n" + _simplified_tb_msg)
+                else:
+                    if mode == "remove_frames":
+                        msg = format_exception_only(e)
+                        msg = f"{msg}\n\n{_jax_message_append}"
+                        jax_error = UnfilteredStackTrace(msg)
+                        jax_error.with_traceback(_add_call_stack_frames(tb))
+                    else:
+                        raise ValueError(
+                            f"JAX_TRACEBACK_FILTERING={mode} is not a valid value."
+                        )
+                    jax_error.__cause__ = e.__cause__
+                    jax_error.__context__ = e.__context__
+                    jax_error.__suppress_context__ = e.__suppress_context__
+                    e.__cause__ = jax_error
+                    e.__context__ = None
+                raise
+            finally:
+                del filtered_tb
+                del mode
+
+    return cast(C, reraise_with_filtered_traceback)

>From 02f11a9c77b930f76a8b28459457ed27c1180438 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Mon, 28 Jul 2025 21:12:21 -0400
Subject: [PATCH 2/4] compat patches

---
 mlir/lib/Bindings/Python/Traceback.cpp | 51 +++++++++++++++++++++-----
 1 file changed, 42 insertions(+), 9 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Traceback.cpp b/mlir/lib/Bindings/Python/Traceback.cpp
index fee85cca6574f..35e711f5ebaef 100644
--- a/mlir/lib/Bindings/Python/Traceback.cpp
+++ b/mlir/lib/Bindings/Python/Traceback.cpp
@@ -41,6 +41,31 @@ limitations under the License.
 #undef Py_BUILD_CORE
 #endif // PLATFORM_GOOGLE
 
+// Py_NewRef was introduced in Python 3.10; provide a fallback for older ones.
+#if PY_VERSION_HEX < 0x030a00f0
+static inline PyObject *Py_NewRef(PyObject *o) {
+  Py_INCREF(o);
+  return o;
+}
+#endif
+
+// bpo-40421 added PyFrame_GetLasti() to Python 3.11.0b1
+#if PY_VERSION_HEX < 0x030B00B1 && !defined(PYPY_VERSION)
+static inline int PyFrame_GetLasti(PyFrameObject *frame) {
+#if PY_VERSION_HEX >= 0x030A00A7
+  // bpo-27129: Since Python 3.10.0a7, f_lasti is an instruction offset,
+  // not a bytes offset anymore. Python uses 16-bit "wordcode" (2 bytes)
+  // instructions.
+  if (frame->f_lasti < 0) {
+    return -1;
+  }
+  return frame->f_lasti * 2;
+#else
+  return frame->f_lasti;
+#endif
+}
+#endif
+
 namespace mlir::python {
 struct TracebackEntry;
 struct TracebackObject;
@@ -135,11 +160,17 @@ static void traceback_tp_dealloc(PyObject *self) {
 }
 
 Traceback::Frame DecodeFrame(const TracebackEntry &frame) {
+  // python 3.11
+#if PY_VERSION_HEX < 0x030b00f0
+  PyObject *name = frame.code->co_name;
+#else
+  PyObject *name = frame.code->co_qualname;
+#endif
   return Traceback::Frame{
-      .file_name = nb::borrow<nb::str>(frame.code->co_filename),
-      .function_name = nb::borrow<nb::str>(frame.code->co_qualname),
-      .function_start_line = frame.code->co_firstlineno,
-      .line_num = PyCode_Addr2Line(frame.code, frame.lasti),
+      /*file_name=*/nb::borrow<nb::str>(frame.code->co_filename),
+      /*function_name=*/nb::borrow<nb::str>(name),
+      /*function_start_line=*/frame.code->co_firstlineno,
+      /*line_num=*/PyCode_Addr2Line(frame.code, frame.lasti),
   };
 }
 
@@ -415,11 +446,13 @@ void BuildTracebackSubmodule(nb::module_ &m) {
           throw std::runtime_error("code argument must be a code object");
         }
         int start_line, start_column, end_line, end_column;
-        if (!PyCode_Addr2Location(reinterpret_cast<PyCodeObject *>(code.ptr()),
-                                  lasti, &start_line, &start_column, &end_line,
-                                  &end_column)) {
-          throw nb::python_error();
-        }
+        // if (!PyCode_Addr2Location(reinterpret_cast<PyCodeObject
+        // *>(code.ptr()),
+        //                           lasti, &start_line, &start_column,
+        //                           &end_line, &end_column)) {
+        //   throw nb::python_error();
+        // }
+        throw nb::python_error();
         return nb::make_tuple(start_line, start_column, end_line, end_column);
       },
       "Python wrapper around the Python C API function PyCode_Addr2Location");

>From f60824e425351d730e860c6e6a5b448b37345dba Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Tue, 29 Jul 2025 13:58:33 -0400
Subject: [PATCH 3/4] refactor

---
 mlir/lib/Bindings/Python/Traceback.cpp |  14 +--
 mlir/python/mlir/source_info_util.py   | 162 ++++++++++++++++++++++---
 mlir/python/mlir/traceback_util.py     |  22 ++++
 mlir/test/python/ir/line_info.py       |  39 ++++++
 4 files changed, 213 insertions(+), 24 deletions(-)
 create mode 100644 mlir/test/python/ir/line_info.py

diff --git a/mlir/lib/Bindings/Python/Traceback.cpp b/mlir/lib/Bindings/Python/Traceback.cpp
index 35e711f5ebaef..66cd9b37963da 100644
--- a/mlir/lib/Bindings/Python/Traceback.cpp
+++ b/mlir/lib/Bindings/Python/Traceback.cpp
@@ -440,22 +440,22 @@ void BuildTracebackSubmodule(nb::module_ &m) {
       },
       "Python wrapper around the Python C API function PyCode_Addr2Line");
 
+#if PY_VERSION_HEX >= 0x030b00f0
   type.attr("code_addr2location") = nb::cpp_function(
       [](nb::handle code, int lasti) {
         if (!PyCode_Check(code.ptr())) {
           throw std::runtime_error("code argument must be a code object");
         }
         int start_line, start_column, end_line, end_column;
-        // if (!PyCode_Addr2Location(reinterpret_cast<PyCodeObject
-        // *>(code.ptr()),
-        //                           lasti, &start_line, &start_column,
-        //                           &end_line, &end_column)) {
-        //   throw nb::python_error();
-        // }
-        throw nb::python_error();
+        if (!PyCode_Addr2Location(reinterpret_cast<PyCodeObject *>(code.ptr()),
+                                  lasti, &start_line, &start_column, &end_line,
+                                  &end_column)) {
+          throw nb::python_error();
+        }
         return nb::make_tuple(start_line, start_column, end_line, end_column);
       },
       "Python wrapper around the Python C API function PyCode_Addr2Location");
+#endif
 }
 } // namespace mlir::python
 
diff --git a/mlir/python/mlir/source_info_util.py b/mlir/python/mlir/source_info_util.py
index ff7ee09149c60..fdc325c7b134b 100644
--- a/mlir/python/mlir/source_info_util.py
+++ b/mlir/python/mlir/source_info_util.py
@@ -14,6 +14,7 @@
 
 from __future__ import annotations
 
+import sys
 from collections.abc import Iterator
 import contextlib
 import dataclasses
@@ -24,25 +25,32 @@
 import sysconfig
 import threading
 import types
-from typing import NamedTuple
+from typing import NamedTuple, Optional
+from .ir import Location
 
 # import jax.version
 # from jax._src.lib import xla_client
 
 from . import traceback_util
+from .traceback_util import (
+    TracebackCaches,
+    Traceback,
+    _traceback_caches,
+    _traceback_in_locations_limit,
+    _include_full_tracebacks_in_locations,
+)
 
 traceback_util.register_exclusion(__file__)
 
-from ._mlir_libs._mlir import Traceback
-
 
-class Frame(NamedTuple):
+@dataclasses.dataclass(frozen=True)
+class Frame:
     file_name: str
     function_name: str
     start_line: int
-    start_column: int
-    end_line: int
-    end_column: int
+    start_column: Optional[int] = None
+    end_line: Optional[int] = None
+    end_column: Optional[int] = None
 
 
 _exclude_paths: list[str] = [
@@ -173,16 +181,27 @@ def is_user_filename(filename: str) -> bool:
 
 
 def raw_frame_to_frame(code: types.CodeType, lasti: int) -> Frame:
-    loc = Traceback.code_addr2location(code, lasti)
-    start_line, start_column, end_line, end_column = loc
-    return Frame(
-        file_name=code.co_filename,
-        function_name=code.co_qualname,
-        start_line=start_line,
-        start_column=start_column,
-        end_line=end_line,
-        end_column=end_column,
-    )
+    if sys.version_info >= (3, 11):  # code_addr2location/co_qualname need 3.11+
+        loc = Traceback.code_addr2location(code, lasti)
+        start_line, start_column, end_line, end_column = loc
+        frame = Frame(
+            file_name=code.co_filename,
+            function_name=code.co_qualname,
+            start_line=start_line,
+            start_column=start_column,
+            end_line=end_line,
+            end_column=end_column,
+        )
+    else:
+        start_line = Traceback.code_addr2line(code, lasti)
+        frame = Frame(
+            file_name=code.co_filename,
+            function_name=code.co_name,
+            start_line=start_line,
+            start_column=0,
+        )
+
+    return frame
 
 
 def user_frames(traceback: Traceback | None) -> Iterator[Frame]:
@@ -226,6 +245,7 @@ class _SourceInfoContext(threading.local):
     context: SourceInfo
 
     def __init__(self):
+        super().__init__()
         self.context = new_source_info()
 
 
@@ -365,3 +385,111 @@ def __exit__(self, exc_type, exc_value, traceback):
 
 
 transform_name_stack = TransformNameStackContextManager
+
+
+def get_canonical_source_file(file_name: str, caches: TracebackCaches) -> str:
+    canonical_file_name = caches.canonical_name_cache.get(file_name, None)
+    if canonical_file_name is not None:
+        return canonical_file_name
+
+    # pattern = config.hlo_source_file_canonicalization_regex.value
+    # if pattern:
+    #     file_name = re.sub(pattern, "", file_name)
+    caches.canonical_name_cache[file_name] = file_name
+    return file_name
+
+
+def _is_user_file(file_name: str) -> bool:
+    is_user = _traceback_caches.is_user_file_cache.get(file_name, None)
+    if is_user is not None:
+        return is_user
+    out = is_user_filename(file_name)
+    _traceback_caches.is_user_file_cache[file_name] = out
+    return out
+
+
+def _traceback_to_location(tb: Traceback) -> Location:
+    """Converts a full traceback to a callsite() MLIR location."""
+    loc = _traceback_caches.traceback_cache.get(tb, None)
+    if loc is not None:
+        return loc
+
+    frame_locs = []
+    frames_limit = _traceback_in_locations_limit
+    frames_limit = frames_limit if frames_limit >= 0 else 1000
+
+    codes, lastis = tb.raw_frames()
+    for i, code in enumerate(codes):
+        if not _is_user_file(code.co_filename):
+            continue
+
+        lasti = lastis[i]
+        code_lasti = code, lasti
+        loc = _traceback_caches.location_cache.get(code_lasti, None)
+        if loc is None:
+            frame = raw_frame_to_frame(code, lasti)
+            if (
+                frame.start_column is not None
+                and frame.end_line is not None
+                and frame.end_column is not None
+            ):
+                file_loc = Location.file(
+                    get_canonical_source_file(frame.file_name, _traceback_caches),
+                    frame.start_line,
+                    frame.start_column,
+                    frame.end_line,
+                    frame.end_column,
+                )
+            else:
+                file_loc = Location.file(
+                    get_canonical_source_file(frame.file_name, _traceback_caches),
+                    frame.start_line,
+                    frame.start_column,
+                )
+            loc = Location.name(frame.function_name, childLoc=file_loc)
+            _traceback_caches.location_cache[code_lasti] = loc
+        frame_locs.append(loc)
+        if len(frame_locs) >= frames_limit:
+            break
+
+    n = len(frame_locs)
+    if n == 0:
+        loc = Location.unknown()
+    elif n == 1:
+        loc = frame_locs[0]
+    else:
+        loc = Location.callsite(frame_locs[0], frame_locs[1:])
+    _traceback_caches.traceback_cache[tb] = loc
+    return loc
+
+
+def source_info_to_location(
+    primitive: None,
+    name_stack: NameStack,
+    traceback: Traceback | None,
+) -> Location:
+    if _include_full_tracebacks_in_locations:
+        if traceback is None:
+            loc = Location.unknown()
+        else:
+            loc = _traceback_to_location(traceback)
+    else:
+        frame = user_frame(traceback)
+        if frame is None:
+            loc = Location.unknown()
+        else:
+            loc = Location.file(
+                get_canonical_source_file(frame.file_name, _traceback_caches),
+                frame.start_line,
+                frame.start_column,
+            )
+    if primitive is None:
+        if name_stack.stack:
+            loc = Location.name(str(name_stack), childLoc=loc)
+    else:
+        eqn_str = (
+            f"{name_stack}/{primitive.name}" if name_stack.stack else primitive.name
+        )
+        loc = Location.name(eqn_str, childLoc=loc)
+        loc = Location.name(f"{primitive.name}:", childLoc=loc)
+    return loc
diff --git a/mlir/python/mlir/traceback_util.py b/mlir/python/mlir/traceback_util.py
index 1794fcd1d08ba..2bace9782687b 100644
--- a/mlir/python/mlir/traceback_util.py
+++ b/mlir/python/mlir/traceback_util.py
@@ -14,12 +14,15 @@
 
 from __future__ import annotations
 
+import dataclasses
 from collections.abc import Callable
 import functools
 import os
 import traceback
 import types
 from typing import Any, TypeVar, cast
+from ._mlir_libs._mlir import Traceback
+from .ir import Location
 
 
 C = TypeVar("C", bound=Callable[..., Any])
@@ -236,3 +239,22 @@ def reraise_with_filtered_traceback(*args, **kwargs):
                 del mode
 
     return cast(C, reraise_with_filtered_traceback)
+
+
+@dataclasses.dataclass
+class TracebackCaches:
+    traceback_cache: dict[Traceback, Location]
+    location_cache: dict[tuple[types.CodeType, int], Location]
+    canonical_name_cache: dict[str, str]
+    is_user_file_cache: dict[str, bool]
+
+    def __init__(self):
+        self.traceback_cache = {}
+        self.location_cache = {}
+        self.canonical_name_cache = {}
+        self.is_user_file_cache = {}
+
+
+_traceback_caches = TracebackCaches()
+_traceback_in_locations_limit = 100
+_include_full_tracebacks_in_locations = True
diff --git a/mlir/test/python/ir/line_info.py b/mlir/test/python/ir/line_info.py
new file mode 100644
index 0000000000000..6c0eea255d2df
--- /dev/null
+++ b/mlir/test/python/ir/line_info.py
@@ -0,0 +1,39 @@
+# RUN: %PYTHON %s | FileCheck %s
+import gc
+import traceback
+
+from mlir import source_info_util
+from mlir.source_info_util import _traceback_to_location
+from mlir import traceback_util
+from mlir.ir import Context
+
+# CHECK: hello
+print("hello")
+
+
+# traceback_util.register_exclusion(__file__)
+
+
+def run(f):
+    print("\nTEST:", f.__name__)
+    with Context() as ctx:
+        f()
+    gc.collect()
+    # assert Context._get_live_count() == 0
+    return f
+
+
+@run
+def foo():
+    def bar():
+        curr = source_info_util.current()
+        print(curr.name_stack)
+        print(curr.traceback)
+        traceback.print_tb(
+            traceback_util.filter_traceback(curr.traceback.as_python_traceback())
+        )
+
+        loc = _traceback_to_location(curr.traceback)
+        print(loc)
+
+    bar()

>From b13abe9c49f5b377aad54da4e82e71a576489fb5 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Tue, 29 Jul 2025 17:49:27 -0400
Subject: [PATCH 4/4] refactor to just emit inferred locs

---
 mlir/lib/Bindings/Python/IRCore.cpp    |  28 +++++--
 mlir/lib/Bindings/Python/IRModule.h    |   4 +-
 mlir/lib/Bindings/Python/Traceback.cpp |  60 ++++++++++++++
 mlir/lib/Bindings/Python/Traceback.h   |   6 ++
 mlir/python/mlir/source_info_util.py   | 106 ++++++++++++-------------
 mlir/python/mlir/traceback_util.py     |   4 +-
 6 files changed, 145 insertions(+), 63 deletions(-)

diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 5feed95f96f53..2975e7add7d49 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -9,6 +9,7 @@
 #include "Globals.h"
 #include "IRModule.h"
 #include "NanobindUtils.h"
+#include "Traceback.h"
 #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/Debug.h"
@@ -1523,7 +1524,7 @@ nb::object PyOperation::create(std::string_view name,
                                llvm::ArrayRef<MlirValue> operands,
                                std::optional<nb::dict> attributes,
                                std::optional<std::vector<PyBlock *>> successors,
-                               int regions, DefaultingPyLocation location,
+                               int regions, PyLocation location,
                                const nb::object &maybeIp, bool inferType) {
   llvm::SmallVector<MlirType, 4> mlirResults;
   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
@@ -1627,7 +1628,7 @@ nb::object PyOperation::create(std::string_view name,
   if (!operation.ptr)
     throw nb::value_error("Operation creation failed");
   PyOperationRef created =
-      PyOperation::createDetached(location->getContext(), operation);
+      PyOperation::createDetached(location.getContext(), operation);
   maybeInsertOperation(created, maybeIp);
 
   return created.getObject();
@@ -1937,9 +1938,9 @@ nb::object PyOpView::buildGeneric(
     std::optional<nb::list> resultTypeList, nb::list operandList,
     std::optional<nb::dict> attributes,
     std::optional<std::vector<PyBlock *>> successors,
-    std::optional<int> regions, DefaultingPyLocation location,
+    std::optional<int> regions, PyLocation location,
     const nb::object &maybeIp) {
-  PyMlirContextRef context = location->getContext();
+  PyMlirContextRef context = location.getContext();
 
   // Class level operation construction metadata.
   // Operand and result segment specs are either none, which does no
@@ -3456,6 +3457,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
              std::optional<std::vector<PyBlock *>> successors, int regions,
              DefaultingPyLocation location, const nb::object &maybeIp,
              bool inferType) {
+            // Infer the location from the Python stack when one is captured.
+            std::optional<Traceback> tb = Traceback::Get();
+            PyMlirContextRef ctx = location->getContext();
+            PyLocation pyLoc =
+                tb ? PyLocation{ctx, tb->tracebackToLocation(ctx->get())}
+                   : *location.get();
+
             // Unpack/validate operands.
             llvm::SmallVector<MlirValue, 4> mlirOperands;
             if (operands) {
@@ -3468,7 +3476,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
             }
 
             return PyOperation::create(name, results, mlirOperands, attributes,
-                                       successors, regions, location, maybeIp,
+                                       successors, regions, pyLoc, maybeIp,
                                        inferType);
           },
           nb::arg("name"), nb::arg("results").none() = nb::none(),
@@ -3517,7 +3525,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                 new (self) PyOpView(PyOpView::buildGeneric(
                     name, opRegionSpec, operandSegmentSpecObj,
                     resultSegmentSpecObj, resultTypeList, operandList,
-                    attributes, successors, regions, location, maybeIp));
+                    attributes, successors, regions, *location.get(), maybeIp));
               },
               nb::arg("name"), nb::arg("opRegionSpec"),
               nb::arg("operandSegmentSpecObj").none() = nb::none(),
@@ -3553,6 +3561,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
          std::optional<std::vector<PyBlock *>> successors,
          std::optional<int> regions, DefaultingPyLocation location,
          const nb::object &maybeIp) {
+        // Infer the location from the Python stack when one is captured.
+        std::optional<Traceback> tb = Traceback::Get();
+        PyMlirContextRef ctx = location->getContext();
+        PyLocation pyLoc =
+            tb ? PyLocation{ctx, tb->tracebackToLocation(ctx->get())}
+               : *location.get();
         std::string name = nb::cast<std::string>(cls.attr("OPERATION_NAME"));
         std::tuple<int, bool> opRegionSpec =
             nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
@@ -3561,7 +3575,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
         return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
                                       resultSegmentSpec, resultTypeList,
                                       operandList, attributes, successors,
-                                      regions, location, maybeIp);
+                                      regions, pyLoc, maybeIp);
       },
       nb::arg("cls"), nb::arg("results").none() = nb::none(),
       nb::arg("operands").none() = nb::none(),
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9c22dea157c06..e21d8660e8434 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -722,7 +722,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
          llvm::ArrayRef<MlirValue> operands,
          std::optional<nanobind::dict> attributes,
          std::optional<std::vector<PyBlock *>> successors, int regions,
-         DefaultingPyLocation location, const nanobind::object &ip,
+         PyLocation location, const nanobind::object &ip,
          bool inferType);
 
   /// Creates an OpView suitable for this operation.
@@ -781,7 +781,7 @@ class PyOpView : public PyOperationBase {
                nanobind::list operandList,
                std::optional<nanobind::dict> attributes,
                std::optional<std::vector<PyBlock *>> successors,
-               std::optional<int> regions, DefaultingPyLocation location,
+               std::optional<int> regions, PyLocation location,
                const nanobind::object &maybeIp);
 
   /// Construct an instance of a class deriving from OpView, bypassing its
diff --git a/mlir/lib/Bindings/Python/Traceback.cpp b/mlir/lib/Bindings/Python/Traceback.cpp
index 66cd9b37963da..812f30534bac9 100644
--- a/mlir/lib/Bindings/Python/Traceback.cpp
+++ b/mlir/lib/Bindings/Python/Traceback.cpp
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include "Traceback.h"
+#include "IRModule.h"
 
 #include <Python.h>
 
@@ -244,6 +245,58 @@ std::vector<Traceback::Frame> Traceback::Frames() const {
   return frames;
 }
 
+MlirLocation Traceback::tracebackToLocation(MlirContext ctx) const {
+  // Builds name(file:line:col) per frame and nests them as call sites.
+  // We require the GIL because we manipulate Python strings.
+  assert(PyGILState_Check());
+
+  constexpr size_t kFramesLimit = 100; // Max frames kept in the location.
+  std::vector<MlirLocation> frame_locs{};
+  TracebackObject *tb = reinterpret_cast<TracebackObject *>(ptr());
+  frame_locs.reserve(Py_SIZE(tb));
+  for (Py_ssize_t i = 0; i < Py_SIZE(tb); ++i) {
+    const TracebackEntry &frame = tb->frames[i];
+    // TODO: port the Python-side frame filtering (`_is_user_file`) and file
+    // name canonicalization (`get_canonical_source_file`); for now every
+    // captured frame is emitted with its raw `co_filename`.
+    MlirStringRef fileName = mlirStringRefCreateFromCString(
+        nb::borrow<nb::str>(frame.code->co_filename).c_str());
+#if PY_VERSION_HEX < 0x030b00f0
+    MlirStringRef funcName = mlirStringRefCreateFromCString(
+        nb::borrow<nb::str>(frame.code->co_name).c_str());
+    auto line = PyCode_Addr2Line(frame.code, frame.lasti);
+    auto loc = mlirLocationFileLineColGet(ctx, fileName, line, 0);
+#else
+    MlirStringRef funcName = mlirStringRefCreateFromCString(
+        nb::borrow<nb::str>(frame.code->co_qualname).c_str());
+    int start_line, start_column, end_line, end_column;
+    if (!PyCode_Addr2Location(frame.code, frame.lasti, &start_line,
+                              &start_column, &end_line, &end_column)) {
+      throw nb::python_error();
+    }
+    auto loc = mlirLocationFileLineColRangeGet(
+        ctx, fileName, start_line, start_column, end_line, end_column);
+#endif
+    frame_locs.push_back(mlirLocationNameGet(ctx, funcName, loc));
+    if (frame_locs.size() >= kFramesLimit)
+      break;
+  }
+
+  if (frame_locs.empty())
+    return mlirLocationUnknownGet(ctx);
+  if (frame_locs.size() == 1)
+    return frame_locs.front();
+
+  MlirLocation callee = frame_locs.front();
+  frame_locs.erase(frame_locs.begin());
+  MlirLocation caller = frame_locs.back();
+  for (const MlirLocation &frame :
+       llvm::reverse(llvm::ArrayRef(frame_locs).drop_back()))
+    caller = mlirLocationCallSiteGet(frame, caller);
+
+  return mlirLocationCallSiteGet(callee, caller);
+}
+
 std::string Traceback::Frame::ToString() const {
   std::string s = nb::cast<std::string>(file_name);
   s += ":" + std::to_string(line_num) + " ";
@@ -381,6 +434,13 @@ void BuildTracebackSubmodule(nb::module_ &m) {
       object that describes the Python stack of the calling thread. Stack
       trace collection has a small overhead, so it is disabled by default. If
       traceback collection is disabled, returns ``None``. )doc");
+  type.attr("_infer_location") = nb::cpp_function(
+      [](DefaultingPyMlirContext context) {
+        auto tb = Traceback::Get();
+        assert(tb);
+        return tb->tracebackToLocation(context->get());
+      },
+      nb::arg("context") = nb::none());
   type.attr("frames") = nb_property_readonly(&Traceback::Frames);
   type.attr("raw_frames") = nb::cpp_function(
       [](const Traceback &tb) -> nb::tuple {
diff --git a/mlir/lib/Bindings/Python/Traceback.h b/mlir/lib/Bindings/Python/Traceback.h
index 0aab15ddc5da3..0dcfedd8b974b 100644
--- a/mlir/lib/Bindings/Python/Traceback.h
+++ b/mlir/lib/Bindings/Python/Traceback.h
@@ -23,6 +23,8 @@ limitations under the License.
 #include <utility>
 #include <vector>
 
+#include "mlir-c/IR.h"
+
 // placeholder for index annotation headers
 #include "nanobind/nanobind.h"
 
@@ -49,9 +51,13 @@ class Traceback : public nanobind::object {
 
     std::string ToString() const;
   };
+
   // Returns a list of Frames for the traceback.
   std::vector<Frame> Frames() const;
 
+  // Converts the traceback to an MLIR location built from its frames.
+  MlirLocation tracebackToLocation(MlirContext ctx) const;
+
 private:
   static bool Check(PyObject *o);
 };
diff --git a/mlir/python/mlir/source_info_util.py b/mlir/python/mlir/source_info_util.py
index fdc325c7b134b..937f139164491 100644
--- a/mlir/python/mlir/source_info_util.py
+++ b/mlir/python/mlir/source_info_util.py
@@ -408,59 +408,59 @@ def _is_user_file(file_name: str) -> bool:
     return out
 
 
-def _traceback_to_location(tb: Traceback) -> Location:
-    """Converts a full traceback to a callsite() MLIR location."""
-    loc = _traceback_caches.traceback_cache.get(tb, None)
-    if loc is not None:
-        return loc
-
-    frame_locs = []
-    frames_limit = _traceback_in_locations_limit
-    frames_limit = frames_limit if frames_limit >= 0 else 1000
-
-    codes, lastis = tb.raw_frames()
-    for i, code in enumerate(codes):
-        if not _is_user_file(code.co_filename):
-            continue
-
-        lasti = lastis[i]
-        code_lasti = code, lasti
-        loc = _traceback_caches.location_cache.get(code_lasti, None)
-        if loc is None:
-            frame = raw_frame_to_frame(code, lasti)
-            if (
-                frame.start_column is not None
-                and frame.end_line is not None
-                and frame.end_column is not None
-            ):
-                file_loc = Location.file(
-                    get_canonical_source_file(frame.file_name, _traceback_caches),
-                    frame.start_line,
-                    frame.start_column,
-                    frame.end_line,
-                    frame.end_column,
-                )
-            else:
-                file_loc = Location.file(
-                    get_canonical_source_file(frame.file_name, _traceback_caches),
-                    frame.start_line,
-                    frame.start_column,
-                )
-            loc = Location.name(frame.function_name, childLoc=file_loc)
-            _traceback_caches.location_cache[code_lasti] = loc
-        frame_locs.append(loc)
-        if len(frame_locs) >= frames_limit:
-            break
-
-    n = len(frame_locs)
-    if n == 0:
-        loc = Location.unknown()
-    elif n == 1:
-        loc = frame_locs[0]
-    else:
-        loc = Location.callsite(frame_locs[0], frame_locs[1:])
-    _traceback_caches.traceback_cache[tb] = loc
-    return loc
+# def _traceback_to_location(tb: Traceback) -> Location:
+#     """Converts a full traceback to a callsite() MLIR location."""
+#     loc = _traceback_caches.traceback_cache.get(tb, None)
+#     if loc is not None:
+#         return loc
+#
+#     frame_locs = []
+#     frames_limit = _traceback_in_locations_limit
+#     frames_limit = frames_limit if frames_limit >= 0 else 1000
+#
+#     codes, lastis = tb.raw_frames()
+#     for i, code in enumerate(codes):
+#         if not _is_user_file(code.co_filename):
+#             continue
+#
+#         lasti = lastis[i]
+#         code_lasti = code, lasti
+#         loc = _traceback_caches.location_cache.get(code_lasti, None)
+#         if loc is None:
+#             frame = raw_frame_to_frame(code, lasti)
+#             if (
+#                 frame.start_column is not None
+#                 and frame.end_line is not None
+#                 and frame.end_column is not None
+#             ):
+#                 file_loc = Location.file(
+#                     get_canonical_source_file(frame.file_name, _traceback_caches),
+#                     frame.start_line,
+#                     frame.start_column,
+#                     frame.end_line,
+#                     frame.end_column,
+#                 )
+#             else:
+#                 file_loc = Location.file(
+#                     get_canonical_source_file(frame.file_name, _traceback_caches),
+#                     frame.start_line,
+#                     frame.start_column,
+#                 )
+#             loc = Location.name(frame.function_name, childLoc=file_loc)
+#             _traceback_caches.location_cache[code_lasti] = loc
+#         frame_locs.append(loc)
+#         if len(frame_locs) >= frames_limit:
+#             break
+#
+#     n = len(frame_locs)
+#     if n == 0:
+#         loc = Location.unknown()
+#     elif n == 1:
+#         loc = frame_locs[0]
+#     else:
+#         loc = Location.callsite(frame_locs[0], frame_locs[1:])
+#     _traceback_caches.traceback_cache[tb] = loc
+#     return loc
 
 
 def source_info_to_location(
diff --git a/mlir/python/mlir/traceback_util.py b/mlir/python/mlir/traceback_util.py
index 2bace9782687b..893fb43588dfb 100644
--- a/mlir/python/mlir/traceback_util.py
+++ b/mlir/python/mlir/traceback_util.py
@@ -21,10 +21,12 @@
 import traceback
 import types
 from typing import Any, TypeVar, cast
-from ._mlir_libs._mlir import Traceback
+from ._mlir_libs._mlir import Traceback, set_tracebacks_enabled
 from .ir import Location
 
 
+set_tracebacks_enabled(True)
+
 C = TypeVar("C", bound=Callable[..., Any])
 
 _exclude_paths: list[str] = [__file__]



More information about the Mlir-commits mailing list