[Mlir-commits] [mlir] [mlir][CAPI][python] bind CallSiteLoc, FileLineColRange, FusedLoc, NameLoc (PR #129351)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Mar 8 19:59:57 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Maksim Levental (makslevental)
<details>
<summary>Changes</summary>
This PR extends the Python bindings for CallSiteLoc, FileLineColRange, FusedLoc, and NameLoc with field accessors. It also adds the missing `value.location` accessor.
I also did some "spring cleaning" here (`cast` -> `dyn_cast`) after running into some of my own illegal casts.
---
Patch is 24.83 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/129351.diff
5 Files Affected:
- (modified) mlir/include/mlir-c/IR.h (+80)
- (modified) mlir/include/mlir/Bindings/Python/NanobindAdaptors.h (+10)
- (modified) mlir/lib/Bindings/Python/IRCore.cpp (+41-6)
- (modified) mlir/lib/CAPI/IR/IR.cpp (+106-6)
- (modified) mlir/test/python/ir/location.py (+131-43)
``````````diff
diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index d562da1f90757..7fd6a41fb435b 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -261,15 +261,75 @@ MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet(
MlirContext context, MlirStringRef filename, unsigned start_line,
unsigned start_col, unsigned end_line, unsigned end_col);
+/// Returns the filename of a FileLineColRange location. `location` must be a
+/// FileLineColRange (see mlirLocationIsAFileLineColRange).
+MLIR_CAPI_EXPORTED MlirIdentifier
+mlirLocationFileLineColRangeGetFilename(MlirLocation location);
+
+/// Returns the start line of a FileLineColRange location, or -1 if
+/// `location` is not a FileLineColRange.
+MLIR_CAPI_EXPORTED int
+mlirLocationFileLineColRangeGetStartLine(MlirLocation location);
+
+/// Returns the start column of a FileLineColRange location, or -1 if
+/// `location` is not a FileLineColRange.
+MLIR_CAPI_EXPORTED int
+mlirLocationFileLineColRangeGetStartColumn(MlirLocation location);
+
+/// Returns the end line of a FileLineColRange location, or -1 if
+/// `location` is not a FileLineColRange.
+MLIR_CAPI_EXPORTED int
+mlirLocationFileLineColRangeGetEndLine(MlirLocation location);
+
+/// Returns the end column of a FileLineColRange location, or -1 if
+/// `location` is not a FileLineColRange.
+MLIR_CAPI_EXPORTED int
+mlirLocationFileLineColRangeGetEndColumn(MlirLocation location);
+
+/// Returns the TypeID of FileLineColRange.
+MLIR_CAPI_EXPORTED MlirTypeID mlirLocationFileLineColRangeGetTypeID(void);
+
+/// Checks whether the given location is a FileLineColRange.
+MLIR_CAPI_EXPORTED bool mlirLocationIsAFileLineColRange(MlirLocation location);
+
/// Creates a call site location with a callee and a caller.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee,
MlirLocation caller);
+/// Returns the callee of a CallSiteLoc. `location` must be a CallSiteLoc
+/// (see mlirLocationIsACallSite).
+MLIR_CAPI_EXPORTED MlirLocation
+mlirLocationCallSiteGetCallee(MlirLocation location);
+
+/// Returns the caller of a CallSiteLoc. `location` must be a CallSiteLoc
+/// (see mlirLocationIsACallSite).
+MLIR_CAPI_EXPORTED MlirLocation
+mlirLocationCallSiteGetCaller(MlirLocation location);
+
+/// Returns the TypeID of CallSiteLoc.
+MLIR_CAPI_EXPORTED MlirTypeID mlirLocationCallSiteGetTypeID(void);
+
+/// Checks whether the given location is a CallSiteLoc.
+MLIR_CAPI_EXPORTED bool mlirLocationIsACallSite(MlirLocation location);
+
/// Creates a fused location with an array of locations and metadata.
MLIR_CAPI_EXPORTED MlirLocation
mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
MlirLocation const *locations, MlirAttribute metadata);
+/// Returns the number of locations fused together in a FusedLoc, or 0 if
+/// `location` is not a FusedLoc.
+MLIR_CAPI_EXPORTED unsigned
+mlirLocationFusedGetNumLocations(MlirLocation location);
+
+/// Copies the locations of a FusedLoc into `locationsCPtr`, which must point
+/// to caller-allocated storage for mlirLocationFusedGetNumLocations(location)
+/// MlirLocation elements. Writes nothing if `location` is not a FusedLoc.
+MLIR_CAPI_EXPORTED void
+mlirLocationFusedGetLocations(MlirLocation location,
+ MlirLocation *locationsCPtr);
+
+/// Returns the metadata attribute of a FusedLoc. `location` must be a
+/// FusedLoc (see mlirLocationIsAFused).
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirLocationFusedGetMetadata(MlirLocation location);
+
+/// Returns the TypeID of FusedLoc.
+MLIR_CAPI_EXPORTED MlirTypeID mlirLocationFusedGetTypeID(void);
+
+/// Checks whether the given location is a FusedLoc.
+MLIR_CAPI_EXPORTED bool mlirLocationIsAFused(MlirLocation location);
+
/// Creates a name location owned by the given context. Providing null location
/// for childLoc is allowed and if childLoc is null location, then the behavior
/// is the same as having unknown child location.
@@ -277,6 +337,20 @@ MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context,
MlirStringRef name,
MlirLocation childLoc);
+/// Returns the name of a NameLoc. `location` must be a NameLoc
+/// (see mlirLocationIsAName).
+MLIR_CAPI_EXPORTED MlirIdentifier
+mlirLocationNameGetName(MlirLocation location);
+
+/// Returns the child location of a NameLoc. `location` must be a NameLoc
+/// (see mlirLocationIsAName).
+MLIR_CAPI_EXPORTED MlirLocation
+mlirLocationNameGetChildLoc(MlirLocation location);
+
+/// Returns the TypeID of NameLoc.
+MLIR_CAPI_EXPORTED MlirTypeID mlirLocationNameGetTypeID(void);
+
+/// Checks whether the given location is a NameLoc.
+MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location);
+
/// Creates a location with unknown position owned by the given context.
MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context);
@@ -978,6 +1052,12 @@ mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with,
intptr_t numExceptions,
MlirOperation *exceptions);
+/// Returns the source location of the value.
+MLIR_CAPI_EXPORTED MlirLocation mlirValueGetLocation(MlirValue v);
+
+/// Returns the context associated with the value.
+MLIR_CAPI_EXPORTED MlirContext mlirValueGetContext(MlirValue v);
+
//===----------------------------------------------------------------------===//
// OpOperand API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
index 0608182f00b7e..3646bf42e415f 100644
--- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
+++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h
@@ -321,6 +321,16 @@ struct type_caster<MlirType> {
}
};
+/// Casts MlirStringRef -> Python `str`: a new nanobind::str is built from the
+/// `s.length` bytes starting at `s.data`. Only the C++ -> Python direction
+/// (from_cpp) is provided; the policy and cleanup list are unused.
+template <>
+struct type_caster<MlirStringRef> {
+ NB_TYPE_CASTER(MlirStringRef, const_name("MlirStringRef"))
+ static handle from_cpp(MlirStringRef s, rv_policy,
+ cleanup_list *cleanup) noexcept {
+ return nanobind::str(s.data, s.length).release();
+ }
+};
+
} // namespace detail
} // namespace nanobind
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp
index 12793f7dd15be..9fd061d1c8dd9 100644
--- a/mlir/lib/Bindings/Python/IRCore.cpp
+++ b/mlir/lib/Bindings/Python/IRCore.cpp
@@ -2943,6 +2943,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("callee"), nb::arg("frames"),
nb::arg("context").none() = nb::none(),
kContextGetCallSiteLocationDocstring)
+ .def("is_a_callsite", mlirLocationIsACallSite)
+ .def_prop_ro("callee", mlirLocationCallSiteGetCallee)
+ .def_prop_ro("caller", mlirLocationCallSiteGetCaller)
.def_static(
"file",
[](std::string filename, int line, int col,
@@ -2967,6 +2970,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"),
nb::arg("end_line"), nb::arg("end_col"),
nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring)
+ .def("is_a_file", mlirLocationIsAFileLineColRange)
+ .def_prop_ro("filename",
+ [](MlirLocation loc) {
+ return mlirIdentifierStr(
+ mlirLocationFileLineColRangeGetFilename(loc));
+ })
+ .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine)
+ .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn)
+ .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine)
+ .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn)
.def_static(
"fused",
[](const std::vector<PyLocation> &pyLocations,
@@ -2984,6 +2997,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("locations"), nb::arg("metadata").none() = nb::none(),
nb::arg("context").none() = nb::none(),
kContextGetFusedLocationDocstring)
+ .def("is_a_fused", mlirLocationIsAFused)
+ .def_prop_ro("locations",
+ [](MlirLocation loc) {
+ unsigned numLocations =
+ mlirLocationFusedGetNumLocations(loc);
+ std::vector<MlirLocation> locations(numLocations);
+ if (numLocations)
+ mlirLocationFusedGetLocations(loc, locations.data());
+ return locations;
+ })
.def_static(
"name",
[](std::string name, std::optional<PyLocation> childLoc,
@@ -2998,6 +3021,12 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("name"), nb::arg("childLoc").none() = nb::none(),
nb::arg("context").none() = nb::none(),
kContextGetNameLocationDocString)
+ .def("is_a_name", mlirLocationIsAName)
+ .def_prop_ro("name_str",
+ [](MlirLocation loc) {
+ return mlirIdentifierStr(mlirLocationNameGetName(loc));
+ })
+ .def_prop_ro("child_loc", mlirLocationNameGetChildLoc)
.def_static(
"from_attr",
[](PyAttribute &attribute, DefaultingPyMlirContext context) {
@@ -3148,9 +3177,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
auto &concreteOperation = self.getOperation();
concreteOperation.checkValid();
MlirOperation operation = concreteOperation.get();
- MlirStringRef name =
- mlirIdentifierStr(mlirOperationGetName(operation));
- return nb::str(name.data, name.length);
+ return mlirIdentifierStr(mlirOperationGetName(operation));
})
.def_prop_ro("operands",
[](PyOperationBase &self) {
@@ -3738,8 +3765,7 @@ void mlir::python::populateIRCore(nb::module_ &m) {
.def_prop_ro(
"name",
[](PyNamedAttribute &self) {
- return nb::str(mlirIdentifierStr(self.namedAttr.name).data,
- mlirIdentifierStr(self.namedAttr.name).length);
+ return mlirIdentifierStr(self.namedAttr.name);
},
"The name of the NamedAttribute binding")
.def_prop_ro(
@@ -3972,7 +3998,16 @@ void mlir::python::populateIRCore(nb::module_ &m) {
nb::arg("with"), nb::arg("exceptions"),
kValueReplaceAllUsesExceptDocstring)
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
- [](PyValue &self) { return self.maybeDownCast(); });
+ [](PyValue &self) { return self.maybeDownCast(); })
+ .def_prop_ro(
+ "location",
+ [](MlirValue self) {
+ return PyLocation(
+ PyMlirContext::forContext(mlirValueGetContext(self)),
+ mlirValueGetLocation(self));
+ },
+ "Returns the source location the value");
+
PyBlockArgument::bind(m);
PyOpResult::bind(m);
PyOpOperand::bind(m);
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 6cd9ba2aef233..b5226d50e6b43 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -259,7 +259,7 @@ MlirAttribute mlirLocationGetAttribute(MlirLocation location) {
}
MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) {
- return wrap(Location(llvm::cast<LocationAttr>(unwrap(attribute))));
+  // Do not wrap an unchecked dyn_cast result: for a non-location attribute it
+  // is a null LocationAttr. Check explicitly and return a null MlirLocation
+  // (testable with mlirLocationIsNull) instead.
+  if (auto locAttr = llvm::dyn_cast<LocationAttr>(unwrap(attribute)))
+    return wrap(Location(locAttr));
+  return MlirLocation{nullptr};
}
MlirLocation mlirLocationFileLineColGet(MlirContext context,
@@ -278,10 +278,62 @@ mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename,
startLine, startCol, endLine, endCol)));
}
+// Accessors for FileLineColRange locations. When `location` is not a
+// FileLineColRange, the integer getters return -1 and the filename getter
+// returns a null identifier.
+MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location) {
+  // Never call a member on an unchecked dyn_cast result: it is null for
+  // locations of any other kind.
+  if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
+    return wrap(loc.getFilename());
+  return MlirIdentifier{nullptr};
+}
+
+int mlirLocationFileLineColRangeGetStartLine(MlirLocation location) {
+  if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
+    return loc.getStartLine();
+  return -1;
+}
+
+int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location) {
+  if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
+    return loc.getStartColumn();
+  return -1;
+}
+
+int mlirLocationFileLineColRangeGetEndLine(MlirLocation location) {
+  if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
+    return loc.getEndLine();
+  return -1;
+}
+
+int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location) {
+  if (auto loc = llvm::dyn_cast<FileLineColRange>(unwrap(location)))
+    return loc.getEndColumn();
+  return -1;
+}
+
+MlirTypeID mlirLocationFileLineColRangeGetTypeID() {
+  return wrap(FileLineColRange::getTypeID());
+}
+
+bool mlirLocationIsAFileLineColRange(MlirLocation location) {
+  return isa<FileLineColRange>(unwrap(location));
+}
+
MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) {
return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller))));
}
+// Accessors for CallSiteLoc. The callee/caller getters return a null
+// MlirLocation when `location` is not a CallSiteLoc, instead of calling a
+// member on a null dyn_cast result.
+MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location) {
+  if (auto loc = llvm::dyn_cast<CallSiteLoc>(unwrap(location)))
+    return wrap(Location(loc.getCallee()));
+  return MlirLocation{nullptr};
+}
+
+MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location) {
+  if (auto loc = llvm::dyn_cast<CallSiteLoc>(unwrap(location)))
+    return wrap(Location(loc.getCaller()));
+  return MlirLocation{nullptr};
+}
+
+MlirTypeID mlirLocationCallSiteGetTypeID() {
+  return wrap(CallSiteLoc::getTypeID());
+}
+
+bool mlirLocationIsACallSite(MlirLocation location) {
+  return isa<CallSiteLoc>(unwrap(location));
+}
+
MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
MlirLocation const *locations,
MlirAttribute metadata) {
@@ -290,6 +342,30 @@ MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations,
return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx)));
}
+// Accessors for FusedLoc. When `location` is not a FusedLoc, the count is 0,
+// no locations are written, and the metadata getter returns a null attribute.
+unsigned mlirLocationFusedGetNumLocations(MlirLocation location) {
+  if (auto fusedLoc = llvm::dyn_cast<FusedLoc>(unwrap(location)))
+    return fusedLoc.getLocations().size();
+  return 0;
+}
+
+void mlirLocationFusedGetLocations(MlirLocation location,
+                                   MlirLocation *locationsCPtr) {
+  if (auto fusedLoc = llvm::dyn_cast<FusedLoc>(unwrap(location))) {
+    // `locationsCPtr` must have room for
+    // mlirLocationFusedGetNumLocations(location) entries. The loop variable is
+    // named `subLoc` so it does not shadow the `location` parameter.
+    for (auto [i, subLoc] : llvm::enumerate(fusedLoc.getLocations()))
+      locationsCPtr[i] = wrap(subLoc);
+  }
+}
+
+MlirAttribute mlirLocationFusedGetMetadata(MlirLocation location) {
+  if (auto fusedLoc = llvm::dyn_cast<FusedLoc>(unwrap(location)))
+    return wrap(fusedLoc.getMetadata());
+  return MlirAttribute{nullptr};
+}
+
+MlirTypeID mlirLocationFusedGetTypeID() { return wrap(FusedLoc::getTypeID()); }
+
+bool mlirLocationIsAFused(MlirLocation location) {
+  return isa<FusedLoc>(unwrap(location));
+}
+
MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
MlirLocation childLoc) {
if (mlirLocationIsNull(childLoc))
@@ -299,6 +375,21 @@ MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name,
StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc))));
}
+// Accessors for NameLoc. The name/child getters return a null handle when
+// `location` is not a NameLoc, instead of calling a member on a null
+// dyn_cast result.
+MlirIdentifier mlirLocationNameGetName(MlirLocation location) {
+  if (auto nameLoc = llvm::dyn_cast<NameLoc>(unwrap(location)))
+    return wrap(nameLoc.getName());
+  return MlirIdentifier{nullptr};
+}
+
+MlirLocation mlirLocationNameGetChildLoc(MlirLocation location) {
+  if (auto nameLoc = llvm::dyn_cast<NameLoc>(unwrap(location)))
+    return wrap(Location(nameLoc.getChildLoc()));
+  return MlirLocation{nullptr};
+}
+
+MlirTypeID mlirLocationNameGetTypeID() { return wrap(NameLoc::getTypeID()); }
+
+bool mlirLocationIsAName(MlirLocation location) {
+  return isa<NameLoc>(unwrap(location));
+}
+
MlirLocation mlirLocationUnknownGet(MlirContext context) {
return wrap(Location(UnknownLoc::get(unwrap(context))));
}
@@ -975,25 +1066,26 @@ bool mlirValueIsAOpResult(MlirValue value) {
}
MlirBlock mlirBlockArgumentGetOwner(MlirValue value) {
- return wrap(llvm::cast<BlockArgument>(unwrap(value)).getOwner());
+  // A failed dyn_cast yields a null BlockArgument that must not be used, so
+  // check it and return a null block for values that are not block arguments.
+  if (auto blockArg = llvm::dyn_cast<BlockArgument>(unwrap(value)))
+    return wrap(blockArg.getOwner());
+  return MlirBlock{nullptr};
}
intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) {
- return static_cast<intptr_t>(
- llvm::cast<BlockArgument>(unwrap(value)).getArgNumber());
+  // Returns -1 for values that are not block arguments.
+  if (auto blockArg = llvm::dyn_cast<BlockArgument>(unwrap(value)))
+    return static_cast<intptr_t>(blockArg.getArgNumber());
+  return -1;
}
void mlirBlockArgumentSetType(MlirValue value, MlirType type) {
- llvm::cast<BlockArgument>(unwrap(value)).setType(unwrap(type));
+  // No-op for values that are not block arguments.
+ if (auto blockArg = llvm::dyn_cast<BlockArgument>(unwrap(value)))
+ blockArg.setType(unwrap(type));
}
MlirOperation mlirOpResultGetOwner(MlirValue value) {
- return wrap(llvm::cast<OpResult>(unwrap(value)).getOwner());
+  // Returns a null operation for values that are not op results.
+  if (auto opResult = llvm::dyn_cast<OpResult>(unwrap(value)))
+    return wrap(opResult.getOwner());
+  return MlirOperation{nullptr};
}
intptr_t mlirOpResultGetResultNumber(MlirValue value) {
- return static_cast<intptr_t>(
- llvm::cast<OpResult>(unwrap(value)).getResultNumber());
+  // Returns -1 for values that are not op results.
+  if (auto opResult = llvm::dyn_cast<OpResult>(unwrap(value)))
+    return static_cast<intptr_t>(opResult.getResultNumber());
+  return -1;
}
MlirType mlirValueGetType(MlirValue value) {
@@ -1047,6 +1139,14 @@ void mlirValueReplaceAllUsesExcept(MlirValue oldValue, MlirValue newValue,
oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptionSet);
}
+// Returns the location of the underlying mlir::Value, wrapped for the C API.
+MlirLocation mlirValueGetLocation(MlirValue v) {
+ return wrap(unwrap(v).getLoc());
+}
+
+// Returns the context of the underlying mlir::Value, wrapped for the C API.
+MlirContext mlirValueGetContext(MlirValue v) {
+ return wrap(unwrap(v).getContext());
+}
+
//===----------------------------------------------------------------------===//
// OpOperand API.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/python/ir/location.py b/mlir/test/python/ir/location.py
index 59d8a89e770dd..3e54dc922cd67 100644
--- a/mlir/test/python/ir/location.py
+++ b/mlir/test/python/ir/location.py
@@ -43,22 +43,64 @@ def testLocationAttr():
run(testLocationAttr)
+
# CHECK-LABEL: TEST: testFileLineCol
def testFileLineCol():
+ # Covers str/repr and the new field accessors of file locations, plus the
+ # new `value.location` accessor.
with Context() as ctx:
- loc = Location.file("foo.txt", 123, 56)
- range = Location.file("foo.txt", 123, 56, 123, 100)
+ loc = Location.file("foo1.txt", 123, 56)
+ range = Location.file("foo2.txt", 123, 56, 124, 100)
+
+ # Drop the context reference and collect before using the locations;
+ # presumably this checks that locations keep their context alive.
ctx = None
gc.collect()
- # CHECK: file str: loc("foo.txt":123:56)
+
+ # CHECK: file str: loc("foo1.txt":123:56)
print("file str:", str(loc))
- # CHECK: file repr: loc("foo.txt":123:56)
+ # CHECK: file repr: loc("foo1.txt":123:56)
print("file repr:", repr(loc))
- # CHECK: file range str: loc("foo.txt":123:56 to :100)
+ # CHECK: file range str: loc("foo2.txt":123:56 to 124:100)
print("file range str:", str(range))
- # CHECK: file range repr: loc("foo.txt":123:56 to :100)
+ # CHECK: file range repr: loc("foo2.txt":123:56 to 124:100)
print("file range repr:", repr(range))
+ # The single-position location is also a file (FileLineColRange) location;
+ # its end line/column equal its start line/column.
+ assert loc.is_a_file()
+ assert not loc.is_a_name()
+ assert not loc.is_a_callsite()
+ assert not loc.is_a_fused()
+
+ # CHECK: file filename: foo1.txt
+ print("file filename:", loc.filename)
+ # CHECK: file start_line: 123
+ print("file start_line:", loc.start_line)
+ # CHECK: file start_col: 56
+ print("file start_col:", loc.start_col)
+ # CHECK: file end_line: 123
+ print("file end_line:", loc.end_line)
+ # CHECK: file end_col: 56
+ print("file end_col:", loc.end_col)
+
+ assert range.is_a_file()
+ # CHECK: file filename: foo2.txt
+ print("file filename:", range.filename)
+ # CHECK: file start_line: 123
+ print("file start_line:", range.start_line)
+ # CHECK: file start_col: 56
+ print("file start_col:", range.start_col)
+ # CHECK: file end_line: 124
+ print("file end_line:", range.end_line)
+ # CHECK: file end_col: 100
+ print("file end_col:", range.end_col)
+
+ # A value created while `loc` is the active location should report it.
+ with Context() as ctx:
+ ctx.allow_unregistered_dialects = True
+ loc = Location.file("foo3.txt", 127, 61)
+ with loc:
+ i32 = IntegerType.get_signless(32)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ new_value = Operation.create("custom.op1", results=[i32]).result
+ # CHECK: new_value location: loc("foo3.txt":127:61)
+ print("new_value location: ", new_value.location)
+
run(testFileLineCol)
@@ -67,17 +109,31 @@ def testFileLineCol():
def testName():
with Context() as ctx:
loc = Location.name("nombre")
- locWithChildLoc = Location.name("naam", loc)
+ loc_with_child_loc = Location.name("naam", loc)
+
ctx = None
gc.collect()
- # CHECK: file str: loc("nombre")
- print("file str:", str(loc))
- # CHECK: file repr: loc("nombre")
- print("file repr:", repr(loc))
- # CHECK: file str: loc("naam"("nombre"))
- print("file str:", str(locWithChildLoc))
- # CHECK: file repr: loc("naam"("nombre"))
- print("file repr:", repr(locWithChildLoc))
+
+ # CHECK: name str: loc("nombre")
+ print("name str:", str(loc))
+ # CHECK: name repr: loc("nombre")
+ print("name repr:", repr(loc))
+ # CHECK: name str: loc("naam"("nombre"))
+ print("name str:", str(loc_with_child_loc))
+ # CHEC...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/129351
More information about the Mlir-commits mailing list