[Mlir-commits] [mlir] da92bc0 - [mlir][acc] Support call target handling for bind(name) (#187390)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Mar 19 07:30:34 PDT 2026
Author: Razvan Lupusoru
Date: 2026-03-19T07:30:28-07:00
New Revision: da92bc06ff47bb060dc9f303856d12e751763312
URL: https://github.com/llvm/llvm-project/commit/da92bc06ff47bb060dc9f303856d12e751763312
DIFF: https://github.com/llvm/llvm-project/commit/da92bc06ff47bb060dc9f303856d12e751763312.diff
LOG: [mlir][acc] Support call target handling for bind(name) (#187390)
The OpenACC `routine` directive may specify a `bind(name)` clause to
associate the routine with a different symbol for device code. This pass
`ACCBindRoutine` finds calls inside offload regions that target such
routines and rewrites the callee to the bound symbol.
---------
Co-authored-by: Delaram Talaashrafi <dtalaashrafi at nvidia.com>
Added:
mlir/lib/Dialect/OpenACC/Transforms/ACCBindRoutine.cpp
mlir/test/Dialect/OpenACC/acc-bind-routine.mlir
Modified:
mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
Removed:
################################################################################
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
index bb98759457166..786d338cea600 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
@@ -486,4 +486,15 @@ def ACCRoutineToGPUFunc : Pass<"acc-routine-to-gpu-func", "mlir::ModuleOp"> {
let options = [ AccDeviceTypeOption ];
}
+// Pass anchored on `mlir::func::FuncOp` and registered under the command-line
+// name `acc-bind-routine`. It accepts the AccDeviceTypeOption option.
+def ACCBindRoutine : Pass<"acc-bind-routine", "mlir::func::FuncOp"> {
+ let summary = "Apply bind clause to function calls in ACC compute regions";
+ let description = [{
+ For calls inside offload regions that target a function with an
+ `acc routine` directive and a `bind(name)` clause, rewrite the
+ call to use the bound symbol so device code calls the correct
+ call target.
+ }];
+ let options = [ AccDeviceTypeOption ];
+}
+
#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCBindRoutine.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCBindRoutine.cpp
new file mode 100644
index 0000000000000..f4c8f18e5b04b
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCBindRoutine.cpp
@@ -0,0 +1,152 @@
+//===- ACCBindRoutine.cpp - OpenACC bind routine transform ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The OpenACC `routine` directive may specify a `bind(name)` clause to
+// associate the routine with a different symbol for device code. This pass
+// finds calls inside offload regions that target such routines and rewrites the
+// callee to the bound symbol.
+//
+// Overview:
+// ---------
+// For each function, walk operations that implement OffloadRegionOpInterface.
+// For each call inside the offload region, if the callee is a function with
+// an acc routine that has bind(name), replace the call to use the bound
+// symbol.
+//
+// Requirements:
+// -------------
+// - OffloadRegionOpInterface: the pass walks operations implementing this
+// interface to discover offload regions (e.g. acc.compute_region) and
+// rewrites calls inside their getOffloadRegion().
+// - CallOpInterface with working setCalleeFromCallable: call operations
+// must implement CallOpInterface and setCalleeFromCallable so the pass
+// can rewrite the callee to the symbol without invalidating the call.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCBINDROUTINE
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+#define DEBUG_TYPE "acc-bind-routine"
+
+using namespace mlir;
+using namespace mlir::acc;
+
+namespace {
+
+/// Resolves the `acc.routine` operation associated with `funcOp` by looking
+/// its symbol up in `symTab`.
+///
+/// For a specialized acc routine function the routine named by its
+/// specialized-routine attribute is used. For any other function the first
+/// entry of its routine-info attribute is used; that attribute is asserted to
+/// be present and non-empty. The returned value is exactly what the symbol
+/// table lookup produces.
+static RoutineOp getFirstAccRoutineOp(FunctionOpInterface funcOp,
+                                      const SymbolTable &symTab) {
+  // Ordinary acc routine function: go through the routine-info attribute.
+  if (!isSpecializedAccRoutine(funcOp)) {
+    auto info =
+        funcOp->getAttrOfType<RoutineInfoAttr>(getRoutineInfoAttrName());
+    assert(info && "expected acc.routine_info for acc routine function");
+    auto routineRefs = info.getAccRoutines();
+    assert(!routineRefs.empty() && "expected at least one acc routine");
+    return symTab.lookup<RoutineOp>(routineRefs[0].getLeafReference());
+  }
+
+  // Specialized function: the attribute names the routine directly.
+  auto specializedAttr = funcOp->getAttrOfType<SpecializedRoutineAttr>(
+      getSpecializedRoutineAttrName());
+  return symTab.lookup<RoutineOp>(
+      specializedAttr.getRoutine().getLeafReference());
+}
+
+/// Returns true when `op` has a bind name that applies either through the
+/// argument-less `getBindNameValue()` query or through the query for
+/// `deviceType`. Returns false immediately when the routine carries neither
+/// an identifier-form nor a string-form bind name.
+static bool isACCRoutineBindDefaultOrDeviceType(RoutineOp op,
+                                                DeviceType deviceType) {
+  const bool hasAnyBindName = op.getBindIdName() || op.getBindStrName();
+  if (!hasAnyBindName)
+    return false;
+
+  // The argument-less query is tried first; the device-type query is only
+  // evaluated when it yields nothing.
+  if (op.getBindNameValue().has_value())
+    return true;
+  return op.getBindNameValue(deviceType).has_value();
+}
+
+/// Function pass that redirects calls inside OpenACC offload regions to the
+/// symbol named by the callee routine's `bind(name)` clause.
+///
+/// For every operation implementing `acc::OffloadRegionOpInterface` nested in
+/// the function, each `CallOpInterface` operation inside its offload region is
+/// inspected. A call is rewritten only when all of the following hold:
+///   - its callee is a symbol reference (calls through a value are skipped);
+///   - the symbol resolves, in the enclosing module, to a function that is an
+///     acc routine or a specialized acc routine;
+///   - the associated `acc.routine` op resolves and yields a bind name for
+///     the pass's `deviceType` option or, failing that, from the argument-less
+///     `getBindNameValue()` query.
+///
+/// A callee whose routine-info attribute lists more than one routine is
+/// reported through `OpenACCSupport::emitNYI` and makes the pass fail.
+class ACCBindRoutine : public acc::impl::ACCBindRoutineBase<ACCBindRoutine> {
+public:
+  using acc::impl::ACCBindRoutineBase<ACCBindRoutine>::ACCBindRoutineBase;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+    ModuleOp module = func->getParentOfType<ModuleOp>();
+    // Callee symbols are resolved against the enclosing module; without one
+    // there is nothing to look up.
+    if (!module)
+      return;
+
+    SymbolTable symTab(module);
+    // Reuse the parent's cached OpenACCSupport analysis when one exists;
+    // otherwise request the analysis for this function.
+    auto cachedAnalysis =
+        getCachedParentAnalysis<OpenACCSupport>(func->getParentOp());
+    OpenACCSupport &accSupport =
+        cachedAnalysis ? cachedAnalysis->get() : getAnalysis<OpenACCSupport>();
+
+    // Set inside the walk and acted upon afterwards so that every offending
+    // call gets reported. Named to avoid shadowing `mlir::failed`.
+    bool sawUnsupportedCallee = false;
+
+    func.walk([&](acc::OffloadRegionOpInterface offload) {
+      Region &region = offload.getOffloadRegion();
+      region.walk([&](CallOpInterface callOp) {
+        // Only direct calls through a symbol can be retargeted.
+        CallInterfaceCallable callable = callOp.getCallableForCallee();
+        if (!callable)
+          return;
+        SymbolRefAttr calleeSymbolRef = dyn_cast<SymbolRefAttr>(callable);
+        if (!calleeSymbolRef)
+          return;
+
+        FunctionOpInterface callee = symTab.lookup<FunctionOpInterface>(
+            calleeSymbolRef.getLeafReference());
+        if (!callee)
+          return;
+
+        if (!(isAccRoutine(callee) || isSpecializedAccRoutine(callee)))
+          return;
+
+        // Several routines attached to one function are not handled yet.
+        if (auto routineInfo = callee->getAttrOfType<RoutineInfoAttr>(
+                getRoutineInfoAttrName())) {
+          if (routineInfo.getAccRoutines().size() > 1) {
+            (void)accSupport.emitNYI(callOp.getLoc(),
+                                     "multiple `acc routine`s");
+            sawUnsupportedCallee = true;
+            return;
+          }
+        }
+
+        RoutineOp routine = getFirstAccRoutineOp(callee, symTab);
+        // The symbol lookup yields a null op when the referenced routine is
+        // missing or is not an `acc.routine`; querying it would dereference
+        // null, so leave the call untouched.
+        if (!routine) {
+          LLVM_DEBUG(llvm::dbgs()
+                     << "no acc.routine op found for callee "
+                     << calleeSymbolRef << "; call left unchanged\n");
+          return;
+        }
+        if (!isACCRoutineBindDefaultOrDeviceType(routine, this->deviceType))
+          return;
+
+        // A bind name for the requested device type takes precedence; the
+        // argument-less query is the fallback.
+        auto bindNameOpt = routine.getBindNameValue(this->deviceType);
+        if (!bindNameOpt)
+          bindNameOpt = routine.getBindNameValue();
+        if (!bindNameOpt)
+          return;
+
+        // The bind name is either already a symbol reference or a string
+        // that has to be turned into a flat symbol reference.
+        SymbolRefAttr calleeRef;
+        if (auto *symRef = std::get_if<SymbolRefAttr>(&*bindNameOpt)) {
+          calleeRef = *symRef;
+        } else {
+          calleeRef = FlatSymbolRefAttr::get(
+              callOp.getContext(),
+              std::get<StringAttr>(*bindNameOpt).getValue());
+        }
+        callOp.setCalleeFromCallable(calleeRef);
+      });
+    });
+
+    if (sawUnsupportedCallee)
+      signalPassFailure();
+  }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
index eb4eecfff129f..5bb92592a6512 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
add_mlir_dialect_library(MLIROpenACCTransforms
+ ACCBindRoutine.cpp
ACCComputeLowering.cpp
ACCRoutineLowering.cpp
ACCRoutineToGPUFunc.cpp
diff --git a/mlir/test/Dialect/OpenACC/acc-bind-routine.mlir b/mlir/test/Dialect/OpenACC/acc-bind-routine.mlir
new file mode 100644
index 0000000000000..52e1b9c675ceb
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/acc-bind-routine.mlir
@@ -0,0 +1,91 @@
+// RUN: mlir-opt %s -acc-bind-routine -split-input-file | FileCheck %s
+
+// Call to routine with bind is rewritten to the bound symbol inside
+// offload region.
+module {
+ acc.routine @r_bind func(@foo) seq bind(@bar)
+ func.func @foo() attributes {acc.routine_info = #acc.routine_info<[@r_bind]>} {
+ return
+ }
+ func.func @bar() {
+ return
+ }
+ func.func @main() {
+ acc.serial {
+ func.call @foo() : () -> ()
+ acc.yield
+ }
+ return
+ }
+}
+
+// CHECK: acc.routine @r_bind func(@foo) bind(@bar) seq
+// CHECK: func.call @bar() : () -> ()
+
+// -----
+
+// Bind with string name: call is rewritten to the string symbol.
+module {
+ acc.routine @r_bind_str func(@wrapped) seq bind("actual_impl")
+ func.func @wrapped() attributes {acc.routine_info = #acc.routine_info<[@r_bind_str]>} {
+ return
+ }
+ func.func @actual_impl() {
+ return
+ }
+ func.func @entry() {
+ acc.serial {
+ func.call @wrapped() : () -> ()
+ acc.yield
+ }
+ return
+ }
+}
+
+// CHECK: func.call @actual_impl() : () -> ()
+
+// -----
+
+// Call outside the offload region is unchanged; the call inside is rewritten.
+module {
+ acc.routine @r_bind func(@foo) seq bind(@bar)
+ func.func @foo() attributes {acc.routine_info = #acc.routine_info<[@r_bind]>} {
+ return
+ }
+ func.func @bar() {
+ return
+ }
+ func.func @main() {
+ func.call @foo() : () -> ()
+ acc.serial {
+ func.call @foo() : () -> ()
+ acc.yield
+ }
+ return
+ }
+}
+
+// CHECK: call @foo() : () -> ()
+// CHECK: func.call @bar() : () -> ()
+
+// -----
+
+// Indirect call (callee is an SSA value, not a symbol) is skipped; no crash.
+module {
+ acc.routine @r_bind func(@target) seq bind(@bound)
+ func.func @target() attributes {acc.routine_info = #acc.routine_info<[@r_bind]>} {
+ return
+ }
+ func.func @bound() {
+ return
+ }
+ func.func @caller(%callee: () -> ()) {
+ acc.serial {
+ func.call_indirect %callee() : () -> ()
+ acc.yield
+ }
+ return
+ }
+}
+
+// CHECK: func.call_indirect %{{.*}}() : () -> ()
More information about the Mlir-commits
mailing list