[Mlir-commits] [mlir] [mlir][python] auto-locs (PR #151246)

Maksim Levental llvmlistbot at llvm.org
Tue Jul 29 15:49:47 PDT 2025


https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/151246

>From c238c3f79b0cf197e0d3d3ac41c7724d5e10ca89 Mon Sep 17 00:00:00 2001
From: max <maksim.levental at gmail.com>
Date: Tue, 29 Jul 2025 18:21:53 -0400
Subject: [PATCH] [mlir][python] auto-locs

---
 mlir/lib/Bindings/Python/Globals.h      |  12 +++
 mlir/lib/Bindings/Python/IRCore.cpp     | 113 +++++++++++++++++++++---
 mlir/lib/Bindings/Python/IRModule.h     |   5 +-
 mlir/lib/Bindings/Python/MainModule.cpp |   4 +-
 mlir/test/python/ir/location.py         |  19 ++++
 5 files changed, 138 insertions(+), 15 deletions(-)

diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h
index 826a34a535176..23ebede7dac71 100644
--- a/mlir/lib/Bindings/Python/Globals.h
+++ b/mlir/lib/Bindings/Python/Globals.h
@@ -114,11 +114,28 @@ class PyGlobals {
   std::optional<nanobind::object>
   lookupOperationClass(llvm::StringRef operationName);
 
+  /// Returns whether operations and modules created through the Python
+  /// bindings get a location derived from the current Python traceback.
+  bool tracebacksEnabled() {
+    nanobind::ft_lock_guard lock(mutex);
+    return tracebackLocationsEnabled;
+  }
+
+  /// Enables or disables traceback-derived locations for operations and
+  /// modules created through the Python bindings.
+  void setTracebacksEnabled(bool value) {
+    nanobind::ft_lock_guard lock(mutex);
+    tracebackLocationsEnabled = value;
+  }
+
 private:
   static PyGlobals *instance;
 
   nanobind::ft_mutex mutex;
 
+  /// Backing flag for tracebacksEnabled(); guarded by `mutex`.
+  bool tracebackLocationsEnabled = false;
+
   /// Module name prefixes to search under for dialect implementation modules.
   std::vector<std::string> dialectSearchPrefixes;
   /// Map of dialect namespace to external dialect class object.
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 5feed95f96f53..d7cc83636fd8c 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -1523,7 +1523,7 @@ nb::object PyOperation::create(std::string_view name,
                                llvm::ArrayRef<MlirValue> operands,
                                std::optional<nb::dict> attributes,
                                std::optional<std::vector<PyBlock *>> successors,
-                               int regions, DefaultingPyLocation location,
+                               int regions, PyLocation location,
                                const nb::object &maybeIp, bool inferType) {
   llvm::SmallVector<MlirType, 4> mlirResults;
   llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
@@ -1627,7 +1627,7 @@ nb::object PyOperation::create(std::string_view name,
   if (!operation.ptr)
     throw nb::value_error("Operation creation failed");
   PyOperationRef created =
-      PyOperation::createDetached(location->getContext(), operation);
+      PyOperation::createDetached(location.getContext(), operation);
   maybeInsertOperation(created, maybeIp);
 
   return created.getObject();
@@ -1937,9 +1937,9 @@ nb::object PyOpView::buildGeneric(
     std::optional<nb::list> resultTypeList, nb::list operandList,
     std::optional<nb::dict> attributes,
     std::optional<std::vector<PyBlock *>> successors,
-    std::optional<int> regions, DefaultingPyLocation location,
+    std::optional<int> regions, PyLocation location,
     const nb::object &maybeIp) {
-  PyMlirContextRef context = location->getContext();
+  PyMlirContextRef context = location.getContext();
 
   // Class level operation construction metadata.
   // Operand and result segment specs are either none, which does no
@@ -2789,6 +2789,82 @@ class PyOpAttributeMap {
   PyOperationRef operation;
 };
 
+/// Builds an MLIR location describing the current Python call stack.
+///
+/// Returns std::nullopt when traceback-derived locations are disabled (see
+/// PyGlobals::setTracebacksEnabled), in which case callers keep their
+/// explicit or contextual location. Otherwise walks at most kMaxFrames
+/// Python frames starting from the innermost one, wraps each frame's source
+/// position in a NameLoc carrying the function name, and nests the results
+/// as callsite(frame0 at callsite(frame1 at ...)). Returns an unknown
+/// location if no Python frame is active. Must be called with the GIL held.
+std::optional<MlirLocation> tracebackToLocation(MlirContext ctx) {
+  assert(PyGILState_Check());
+
+  if (!PyGlobals::get().tracebacksEnabled())
+    return std::nullopt;
+
+  constexpr size_t kMaxFrames = 100;
+  // Reused per thread so the buffer is allocated once rather than on every
+  // op creation. It must be cleared on entry: otherwise locations from earlier
+  // calls (possibly owned by a since-destroyed context) would be mixed into
+  // this location and the size limit would cut the walk short.
+  thread_local std::vector<MlirLocation> frames;
+  frames.clear();
+  frames.reserve(kMaxFrames);
+
+  // PyThreadState_GetFrame, PyFrame_GetBack and PyFrame_GetCode all return
+  // strong references. Owning them through nb::object releases each one on
+  // every exit path, including stopping at kMaxFrames or an exception.
+  nb::object frame = nb::steal(reinterpret_cast<PyObject *>(
+      PyThreadState_GetFrame(PyThreadState_GET())));
+  while (frame.is_valid() && frames.size() < kMaxFrames) {
+    auto *pyFrame = reinterpret_cast<PyFrameObject *>(frame.ptr());
+    nb::object codeObj =
+        nb::steal(reinterpret_cast<PyObject *>(PyFrame_GetCode(pyFrame)));
+    auto *code = reinterpret_cast<PyCodeObject *>(codeObj.ptr());
+    // The UTF-8 buffers are owned by `code`'s string objects, which outlive
+    // their use here; MLIR copies the strings into the context.
+    MlirStringRef fileName = mlirStringRefCreateFromCString(
+        nb::borrow<nb::str>(code->co_filename).c_str());
+
+#if PY_VERSION_HEX < 0x030b00f0
+    // co_qualname, PyFrame_GetLasti and PyCode_Addr2Location need 3.11+.
+    MlirStringRef funcName = mlirStringRefCreateFromCString(
+        nb::borrow<nb::str>(code->co_name).c_str());
+    MlirLocation loc = mlirLocationFileLineColGet(
+        ctx, fileName, PyFrame_GetLineNumber(pyFrame), 0);
+#else
+    MlirStringRef funcName = mlirStringRefCreateFromCString(
+        nb::borrow<nb::str>(code->co_qualname).c_str());
+    int startLine = 0, startColumn = 0, endLine = 0, endColumn = 0;
+    if (!PyCode_Addr2Location(code, PyFrame_GetLasti(pyFrame), &startLine,
+                              &startColumn, &endLine, &endColumn)) {
+      // Failure is not documented to set a Python error, so don't throw
+      // nb::python_error here; fall back to the frame's current line.
+      startLine = endLine = PyFrame_GetLineNumber(pyFrame);
+      startColumn = endColumn = 0;
+    }
+    MlirLocation loc = mlirLocationFileLineColRangeGet(
+        ctx, fileName, startLine, startColumn, endLine, endColumn);
+#endif
+    frames.push_back(mlirLocationNameGet(ctx, funcName, loc));
+
+    // Step to the caller; the assignment releases the current frame.
+    frame = nb::steal(reinterpret_cast<PyObject *>(PyFrame_GetBack(pyFrame)));
+  }
+
+  if (frames.empty())
+    return mlirLocationUnknownGet(ctx);
+
+  // frames[0] is the innermost frame. Fold from the outermost caller inward
+  // so the result is callsite(frames[0] at callsite(frames[1] at ...)).
+  MlirLocation result = frames.back();
+  for (MlirLocation callee : llvm::reverse(llvm::ArrayRef(frames).drop_back()))
+    result = mlirLocationCallSiteGet(callee, result);
+  return result;
+}
+
 } // namespace
 
 //------------------------------------------------------------------------------
@@ -3241,6 +3305,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
       .def_static(
           "create",
           [](DefaultingPyLocation loc) {
+            // Prefer a location synthesized from the Python traceback when
+            // traceback locations are enabled; otherwise keep `loc`.
+            PyMlirContextRef ctx = loc->getContext();
+            MlirLocation mlirLoc = loc;
+            if (auto tloc = tracebackToLocation(ctx->get()))
+              mlirLoc = *tloc;
-            MlirModule module = mlirModuleCreateEmpty(loc);
+            MlirModule module = mlirModuleCreateEmpty(mlirLoc);
             return PyModule::forModule(module).releaseObject();
           },
@@ -3467,9 +3535,15 @@ void mlir::python::populateIRCore(nb::module_ &m) {
               }
             }
 
+            PyMlirContextRef ctx = location->getContext();
+            if (auto loc = tracebackToLocation(ctx->get())) {
+              return PyOperation::create(name, results, mlirOperands,
+                                         attributes, successors, regions,
+                                         {ctx, *loc}, maybeIp, inferType);
+            }
             return PyOperation::create(name, results, mlirOperands, attributes,
-                                       successors, regions, location, maybeIp,
-                                       inferType);
+                                       successors, regions, *location.get(),
+                                       maybeIp, inferType);
           },
           nb::arg("name"), nb::arg("results").none() = nb::none(),
           nb::arg("operands").none() = nb::none(),
@@ -3514,10 +3588,19 @@ void mlir::python::populateIRCore(nb::module_ &m) {
                  std::optional<std::vector<PyBlock *>> successors,
                  std::optional<int> regions, DefaultingPyLocation location,
                  const nb::object &maybeIp) {
-                new (self) PyOpView(PyOpView::buildGeneric(
-                    name, opRegionSpec, operandSegmentSpecObj,
-                    resultSegmentSpecObj, resultTypeList, operandList,
-                    attributes, successors, regions, location, maybeIp));
+                PyMlirContextRef ctx = location->getContext();
+                if (auto loc = tracebackToLocation(ctx->get())) {
+                  new (self) PyOpView(PyOpView::buildGeneric(
+                      name, opRegionSpec, operandSegmentSpecObj,
+                      resultSegmentSpecObj, resultTypeList, operandList,
+                      attributes, successors, regions, {ctx, *loc}, maybeIp));
+                } else {
+                  new (self) PyOpView(PyOpView::buildGeneric(
+                      name, opRegionSpec, operandSegmentSpecObj,
+                      resultSegmentSpecObj, resultTypeList, operandList,
+                      attributes, successors, regions, *location.get(),
+                      maybeIp));
+                }
               },
               nb::arg("name"), nb::arg("opRegionSpec"),
               nb::arg("operandSegmentSpecObj").none() = nb::none(),
@@ -3558,10 +3641,18 @@ void mlir::python::populateIRCore(nb::module_ &m) {
             nb::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
         nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS");
         nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS");
+
+        PyMlirContextRef ctx = location->getContext();
+        if (auto loc = tracebackToLocation(ctx->get())) {
+          return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
+                                        resultSegmentSpec, resultTypeList,
+                                        operandList, attributes, successors,
+                                        regions, {ctx, *loc}, maybeIp);
+        }
         return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec,
                                       resultSegmentSpec, resultTypeList,
                                       operandList, attributes, successors,
-                                      regions, location, maybeIp);
+                                      regions, *location.get(), maybeIp);
       },
       nb::arg("cls"), nb::arg("results").none() = nb::none(),
       nb::arg("operands").none() = nb::none(),
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h
index 9c22dea157c06..87e1a0b12da00 100644
--- a/mlir/lib/Bindings/Python/IRModule.h
+++ b/mlir/lib/Bindings/Python/IRModule.h
@@ -722,8 +722,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
          llvm::ArrayRef<MlirValue> operands,
          std::optional<nanobind::dict> attributes,
          std::optional<std::vector<PyBlock *>> successors, int regions,
-         DefaultingPyLocation location, const nanobind::object &ip,
-         bool inferType);
+         PyLocation location, const nanobind::object &ip, bool inferType);
 
   /// Creates an OpView suitable for this operation.
   nanobind::object createOpView();
@@ -781,7 +780,7 @@ class PyOpView : public PyOperationBase {
                nanobind::list operandList,
                std::optional<nanobind::dict> attributes,
                std::optional<std::vector<PyBlock *>> successors,
-               std::optional<int> regions, DefaultingPyLocation location,
+               std::optional<int> regions, PyLocation location,
                const nanobind::object &maybeIp);
 
   /// Construct an instance of a class deriving from OpView, bypassing its
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp
index 6f49431006605..dca4366109305 100644
--- a/mlir/lib/Bindings/Python/MainModule.cpp
+++ b/mlir/lib/Bindings/Python/MainModule.cpp
@@ -44,7 +44,9 @@ NB_MODULE(_mlir, m) {
       .def("_register_operation_impl", &PyGlobals::registerOperationImpl,
            "operation_name"_a, "operation_class"_a, nb::kw_only(),
            "replace"_a = false,
-           "Testing hook for directly registering an operation");
+           "Testing hook for directly registering an operation")
+      .def("tracebacks_enabled", &PyGlobals::tracebacksEnabled)
+      .def("set_tracebacks_enabled", &PyGlobals::setTracebacksEnabled);
 
   // Aside from making the globals accessible to python, having python manage
   // it is necessary to make sure it is destroyed (and releases its python
diff --git a/mlir/test/python/ir/location.py b/mlir/test/python/ir/location.py
index 3e54dc922cd67..5ff4449ddc8b4 100644
--- a/mlir/test/python/ir/location.py
+++ b/mlir/test/python/ir/location.py
@@ -1,7 +1,10 @@
 # RUN: %PYTHON %s | FileCheck %s
 
 import gc
+from contextlib import contextmanager
+
 from mlir.ir import *
+from mlir.dialects._ods_common import _cext
 
 
 def run(f):
@@ -27,6 +30,27 @@ def testUnknown():
 run(testUnknown)
 
 
+@contextmanager
+def with_infer_location():
+    _cext.globals.set_tracebacks_enabled(True)
+    try:
+        yield
+    finally:
+        _cext.globals.set_tracebacks_enabled(False)
+
+
+# CHECK-LABEL: TEST: testInferLocations
+def testInferLocations():
+    with Context() as ctx, Location.unknown(), with_infer_location():
+        ctx.allow_unregistered_dialects = True
+        op = Operation.create("custom.op1")
+        # CHECK: loc(callsite("testInferLocations"
+        print(op.location)
+
+
+run(testInferLocations)
+
+
 # CHECK-LABEL: TEST: testLocationAttr
 def testLocationAttr():
     with Context() as ctxt:



More information about the Mlir-commits mailing list