[Mlir-commits] [mlir] 16585af - [mlir][acc] Fix bindNameValue for RoutineOp (#187307)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Wed Mar 18 09:23:39 PDT 2026
Author: Razvan Lupusoru
Date: 2026-03-18T09:23:34-07:00
New Revision: 16585af33b4dd4e922de479402dac9c36a81d9a0
URL: https://github.com/llvm/llvm-project/commit/16585af33b4dd4e922de479402dac9c36a81d9a0
DIFF: https://github.com/llvm/llvm-project/commit/16585af33b4dd4e922de479402dac9c36a81d9a0.diff
LOG: [mlir][acc] Fix bindNameValue for RoutineOp (#187307)
If the routine op has only one of the bind string or bind id attributes,
getBindNameValue(DeviceType) crashed because it searched both device-type
arrays unconditionally, dereferencing the absent one. Guard each search individually.
Added:
Modified:
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
index c1b2e0d714e08..0709b9ebb95aa 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -4673,23 +4673,22 @@ RoutineOp::getBindNameValue() {
std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
- if (!hasDeviceTypeValues(getBindIdNameDeviceType()) &&
- !hasDeviceTypeValues(getBindStrNameDeviceType())) {
- return std::nullopt;
- }
-
- if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
- auto attr = (*getBindIdName())[*pos];
- auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
- assert(symbolRefAttr && "expected SymbolRef");
- return symbolRefAttr;
+ if (hasDeviceTypeValues(getBindIdNameDeviceType())) {
+ if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
+ auto attr = (*getBindIdName())[*pos];
+ auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
+ assert(symbolRefAttr && "expected SymbolRef");
+ return symbolRefAttr;
+ }
}
- if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
- auto attr = (*getBindStrName())[*pos];
- auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
- assert(stringAttr && "expected String");
- return stringAttr;
+ if (hasDeviceTypeValues(getBindStrNameDeviceType())) {
+ if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
+ auto attr = (*getBindStrName())[*pos];
+ auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
+ assert(stringAttr && "expected String");
+ return stringAttr;
+ }
}
return std::nullopt;
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
index 825bca99b4ba7..f277001ca1cfd 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCOpsTest.cpp
@@ -562,6 +562,46 @@ TEST_F(OpenACCOpsTest, routineOpTest) {
op->removeBindStrNameAttr();
}
+TEST_F(OpenACCOpsTest, routineOpGetBindNameValueOnlyBindStrOrOnlyBindId) {
+ // Regression test: getBindNameValue(DeviceType) must not dereference a missing
+ // device-type array when only one of bind(id) or bind(name) is set.
+ OwningOpRef<RoutineOp> op =
+ RoutineOp::create(b, loc, TypeRange{}, ValueRange{});
+
+ auto dtypeNone = DeviceTypeAttr::get(&context, DeviceType::None);
+ auto dtypeNvidia = DeviceTypeAttr::get(&context, DeviceType::Nvidia);
+
+ // Only bind(name) is set, for Nvidia; bindIdName/bindIdNameDeviceType are
+ // absent. The lookup must return the string for Nvidia and nothing for Host.
+ op->setBindStrNameDeviceTypeAttr(b.getArrayAttr({dtypeNvidia}));
+ op->setBindStrNameAttr(b.getArrayAttr({b.getStringAttr("only_str_bind")}));
+ EXPECT_TRUE(op->getBindNameValue(DeviceType::Nvidia).has_value());
+ EXPECT_EQ(std::visit(
+ [](const auto &attr) -> std::string {
+ if constexpr (std::is_same_v<std::decay_t<decltype(attr)>,
+ mlir::StringAttr>) {
+ return attr.str();
+ } else {
+ return attr.getLeafReference().str();
+ }
+ },
+ op->getBindNameValue(DeviceType::Nvidia).value()),
+ "only_str_bind");
+ EXPECT_FALSE(op->getBindNameValue(DeviceType::Host).has_value());
+ op->removeBindStrNameDeviceTypeAttr();
+ op->removeBindStrNameAttr();
+
+ // Only bind(id) is set, for DeviceType::None; no bindStrName/DeviceType attrs.
+ // The no-argument lookup must have a value; the Nvidia lookup must not.
+ op->setBindIdNameDeviceTypeAttr(b.getArrayAttr({dtypeNone}));
+ op->setBindIdNameAttr(
+ b.getArrayAttr({SymbolRefAttr::get(&context, "only_id_bind")}));
+ EXPECT_TRUE(op->getBindNameValue().has_value());
+ EXPECT_FALSE(op->getBindNameValue(DeviceType::Nvidia).has_value());
+ op->removeBindIdNameDeviceTypeAttr();
+ op->removeBindIdNameAttr();
+}
+
template <typename Op>
static void testShortDataEntryOpBuilders(OpBuilder &b, MLIRContext &context,
Location loc, DataClause dataClause) {
More information about the Mlir-commits
mailing list