[Mlir-commits] [mlir] da92bc0 - [mlir][acc] Support call target handling for bind(name) (#187390)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 19 07:30:34 PDT 2026


Author: Razvan Lupusoru
Date: 2026-03-19T07:30:28-07:00
New Revision: da92bc06ff47bb060dc9f303856d12e751763312

URL: https://github.com/llvm/llvm-project/commit/da92bc06ff47bb060dc9f303856d12e751763312
DIFF: https://github.com/llvm/llvm-project/commit/da92bc06ff47bb060dc9f303856d12e751763312.diff

LOG: [mlir][acc] Support call target handling for bind(name) (#187390)

The OpenACC `routine` directive may specify a `bind(name)` clause to
associate the routine with a different symbol for device code. This pass
`ACCBindRoutine` finds calls inside offload regions that target such
routines and rewrites the callee to the bound symbol.

---------

Co-authored-by: Delaram Talaashrafi <dtalaashrafi at nvidia.com>

Added: 
    mlir/lib/Dialect/OpenACC/Transforms/ACCBindRoutine.cpp
    mlir/test/Dialect/OpenACC/acc-bind-routine.mlir

Modified: 
    mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
    mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
index bb98759457166..786d338cea600 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
@@ -486,4 +486,15 @@ def ACCRoutineToGPUFunc : Pass<"acc-routine-to-gpu-func", "mlir::ModuleOp"> {
   let options = [ AccDeviceTypeOption ];
 }
 
+// Function-level pass that rewrites direct calls inside offload regions whose
+// callee carries an `acc routine` with a `bind(name)` clause, so that they
+// call the bound symbol instead. The `deviceType` option selects a
+// device-type-specific bind name when one exists; otherwise the default
+// (device-type-independent) bind name is used.
+def ACCBindRoutine : Pass<"acc-bind-routine", "mlir::func::FuncOp"> {
+  let summary = "Apply bind clause to function calls in ACC compute regions";
+  let description = [{
+    For calls inside offload regions that target a function with an
+    `acc routine` directive and a `bind(name)` clause, rewrite the
+    call to use the bound symbol so device code calls the correct
+    call target.
+  }];
+  let options = [ AccDeviceTypeOption ];
+}
+
 #endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES

diff  --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCBindRoutine.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCBindRoutine.cpp
new file mode 100644
index 0000000000000..f4c8f18e5b04b
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCBindRoutine.cpp
@@ -0,0 +1,152 @@
+//===- ACCBindRoutine.cpp - OpenACC bind routine transform ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The OpenACC `routine` directive may specify a `bind(name)` clause to
+// associate the routine with a different symbol for device code. This pass
+// finds calls inside offload regions that target such routines and rewrites the
+// callee to the bound symbol.
+//
+// Overview:
+// ---------
+// For each function, walk operations that implement OffloadRegionOpInterface.
+// For each call inside the offload region, if the callee is a function with
+// an acc routine that has bind(name), replace the call to use the bound
+// symbol.
+//
+// Requirements:
+// -------------
+// - OffloadRegionOpInterface: the pass walks operations implementing this
+//   interface to discover offload regions (e.g. acc.compute_region) and
+//   rewrites calls inside their getOffloadRegion().
+// - CallOpInterface with working setCalleeFromCallable: call operations
+//   must implement CallOpInterface and setCalleeFromCallable so the pass
+//   can rewrite the callee to the symbol without invalidating the call.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenACC/Analysis/OpenACCSupport.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/Debug.h"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCBINDROUTINE
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+#define DEBUG_TYPE "acc-bind-routine"
+
+using namespace mlir;
+using namespace mlir::acc;
+
+namespace {
+
+/// Returns the `acc.routine` op associated with `funcOp`, looked up by name in
+/// `symTab`. A specialized routine records its originating routine in its
+/// specialized-routine attribute; any other acc routine function uses the
+/// first entry of its `acc.routine_info` attribute. Because this goes through
+/// SymbolTable::lookup, the result may be a null op if the routine symbol is
+/// not present in `symTab`.
+static RoutineOp getFirstAccRoutineOp(FunctionOpInterface funcOp,
+                                      const SymbolTable &symTab) {
+  if (!isSpecializedAccRoutine(funcOp)) {
+    RoutineInfoAttr info =
+        funcOp->getAttrOfType<RoutineInfoAttr>(getRoutineInfoAttrName());
+    assert(info && "expected acc.routine_info for acc routine function");
+    auto routineRefs = info.getAccRoutines();
+    assert(!routineRefs.empty() && "expected at least one acc routine");
+    StringAttr firstRoutineName = routineRefs[0].getLeafReference();
+    return symTab.lookup<RoutineOp>(firstRoutineName);
+  }
+
+  auto specializedAttr = funcOp->getAttrOfType<SpecializedRoutineAttr>(
+      getSpecializedRoutineAttrName());
+  StringAttr originRoutineName =
+      specializedAttr.getRoutine().getLeafReference();
+  return symTab.lookup<RoutineOp>(originRoutineName);
+}
+
+/// Returns true when `op` has a bind clause that yields a name either with no
+/// device type (the default) or for `deviceType` specifically. Routines with
+/// no bind clause at all return false.
+static bool isACCRoutineBindDefaultOrDeviceType(RoutineOp op,
+                                                DeviceType deviceType) {
+  const bool hasAnyBind = op.getBindIdName() || op.getBindStrName();
+  if (!hasAnyBind)
+    return false;
+  // The default bind name is checked first, matching the original
+  // short-circuit order.
+  if (op.getBindNameValue())
+    return true;
+  return op.getBindNameValue(deviceType).has_value();
+}
+
+/// Function pass that retargets calls made from offload regions to the symbol
+/// named by the callee's `acc routine` bind(name) clause.
+class ACCBindRoutine : public acc::impl::ACCBindRoutineBase<ACCBindRoutine> {
+public:
+  using acc::impl::ACCBindRoutineBase<ACCBindRoutine>::ACCBindRoutineBase;
+
+  void runOnOperation() override {
+    func::FuncOp func = getOperation();
+    ModuleOp module = func->getParentOfType<ModuleOp>();
+    if (!module)
+      return;
+
+    SymbolTable symTab(module);
+    // Prefer an OpenACCSupport analysis already cached on the parent op;
+    // otherwise compute one for this function.
+    auto cachedAnalysis =
+        getCachedParentAnalysis<OpenACCSupport>(func->getParentOp());
+    OpenACCSupport &accSupport =
+        cachedAnalysis ? cachedAnalysis->get() : getAnalysis<OpenACCSupport>();
+
+    // Named `hadError` (not `failed`) so it does not shadow
+    // mlir::failed(LogicalResult), which is used below.
+    bool hadError = false;
+
+    // When offload ops are nested, walking an outer offload region also
+    // reaches calls inside the inner ones, so the same call can be visited
+    // more than once. Only process a call on its first visit. Otherwise an
+    // already-rewritten call would be looked up again under its new callee
+    // (and rebound if that symbol also has a bind clause), and diagnostics
+    // would be emitted twice.
+    llvm::SmallPtrSet<Operation *, 16> visitedCalls;
+
+    func.walk([&](acc::OffloadRegionOpInterface offload) {
+      offload.getOffloadRegion().walk([&](CallOpInterface callOp) {
+        if (!visitedCalls.insert(callOp.getOperation()).second)
+          return;
+        if (failed(bindCallee(callOp, symTab, accSupport)))
+          hadError = true;
+      });
+    });
+
+    if (hadError)
+      signalPassFailure();
+  }
+
+private:
+  /// Rewrites `callOp` to call the symbol named by the bind(name) clause of
+  /// its callee's `acc routine`. These calls are left unchanged and the
+  /// function returns success:
+  ///   - indirect calls,
+  ///   - callees not found in `symTab` or not marked as acc routines,
+  ///   - routines whose op is missing from `symTab`,
+  ///   - routines without an applicable bind name.
+  /// Emits an NYI diagnostic and returns failure when the callee lists more
+  /// than one `acc routine`.
+  LogicalResult bindCallee(CallOpInterface callOp, const SymbolTable &symTab,
+                           OpenACCSupport &accSupport) {
+    CallInterfaceCallable callable = callOp.getCallableForCallee();
+    if (!callable)
+      return success();
+    // Only direct (symbol) calls can be retargeted; indirect calls are skipped.
+    auto calleeSymbolRef = dyn_cast<SymbolRefAttr>(callable);
+    if (!calleeSymbolRef)
+      return success();
+
+    auto callee = symTab.lookup<FunctionOpInterface>(
+        calleeSymbolRef.getLeafReference());
+    if (!callee || !(isAccRoutine(callee) || isSpecializedAccRoutine(callee)))
+      return success();
+
+    if (auto routineInfo = callee->getAttrOfType<RoutineInfoAttr>(
+            getRoutineInfoAttrName())) {
+      if (routineInfo.getAccRoutines().size() > 1) {
+        (void)accSupport.emitNYI(callOp.getLoc(), "multiple `acc routine`s");
+        return failure();
+      }
+    }
+
+    // The routine symbol can be absent from the module; the lookup then
+    // yields a null op, which must not be dereferenced.
+    RoutineOp routine = getFirstAccRoutineOp(callee, symTab);
+    if (!routine ||
+        !isACCRoutineBindDefaultOrDeviceType(routine, this->deviceType))
+      return success();
+
+    // A device-type-specific bind name takes precedence over the default.
+    auto bindNameOpt = routine.getBindNameValue(this->deviceType);
+    if (!bindNameOpt)
+      bindNameOpt = routine.getBindNameValue();
+    if (!bindNameOpt)
+      return success();
+
+    SymbolRefAttr calleeRef;
+    if (auto *symRef = std::get_if<SymbolRefAttr>(&*bindNameOpt)) {
+      calleeRef = *symRef;
+    } else {
+      // bind("name") form: build a flat symbol reference from the string.
+      calleeRef = FlatSymbolRefAttr::get(
+          callOp.getContext(), std::get<StringAttr>(*bindNameOpt).getValue());
+    }
+    callOp.setCalleeFromCallable(calleeRef);
+    return success();
+  }
+};
+
+} // namespace

diff  --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
index eb4eecfff129f..5bb92592a6512 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_mlir_dialect_library(MLIROpenACCTransforms
+  ACCBindRoutine.cpp
   ACCComputeLowering.cpp
   ACCRoutineLowering.cpp
   ACCRoutineToGPUFunc.cpp

diff  --git a/mlir/test/Dialect/OpenACC/acc-bind-routine.mlir b/mlir/test/Dialect/OpenACC/acc-bind-routine.mlir
new file mode 100644
index 0000000000000..52e1b9c675ceb
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/acc-bind-routine.mlir
@@ -0,0 +1,91 @@
+// RUN: mlir-opt %s -acc-bind-routine -split-input-file | FileCheck %s
+
+// Call to routine with bind is rewritten to the bound symbol inside
+// offload region.
+module {
+  acc.routine @r_bind func(@foo) seq bind(@bar)
+  func.func @foo() attributes {acc.routine_info = #acc.routine_info<[@r_bind]>} {
+    return
+  }
+  func.func @bar() {
+    return
+  }
+  func.func @main() {
+    acc.serial {
+      func.call @foo() : () -> ()
+      acc.yield
+    }
+    return
+  }
+}
+
+// CHECK: acc.routine @r_bind func(@foo) bind(@bar) seq
+// CHECK-LABEL: func.func @main()
+// CHECK: func.call @bar() : () -> ()
+
+// -----
+
+// Bind with string name: call is rewritten to the string symbol.
+module {
+  acc.routine @r_bind_str func(@wrapped) seq bind("actual_impl")
+  func.func @wrapped() attributes {acc.routine_info = #acc.routine_info<[@r_bind_str]>} {
+    return
+  }
+  func.func @actual_impl() {
+    return
+  }
+  func.func @entry() {
+    acc.serial {
+      func.call @wrapped() : () -> ()
+      acc.yield
+    }
+    return
+  }
+}
+
+// CHECK-LABEL: func.func @entry()
+// CHECK: func.call @actual_impl() : () -> ()
+
+// -----
+
+// Call outside offload region is unchanged.
+module {
+  acc.routine @r_bind func(@foo) seq bind(@bar)
+  func.func @foo() attributes {acc.routine_info = #acc.routine_info<[@r_bind]>} {
+    return
+  }
+  func.func @bar() {
+    return
+  }
+  func.func @mixed_calls() {
+    func.call @foo() : () -> ()
+    acc.serial {
+      func.call @foo() : () -> ()
+      acc.yield
+    }
+    return
+  }
+}
+
+// CHECK-LABEL: func.func @mixed_calls()
+// CHECK: call @foo() : () -> ()
+// CHECK: func.call @bar() : () -> ()
+
+// -----
+
+// Indirect call (callee is value) is skipped; no crash.
+module {
+  acc.routine @r_bind func(@target) seq bind(@bound)
+  func.func @target() attributes {acc.routine_info = #acc.routine_info<[@r_bind]>} {
+    return
+  }
+  func.func @bound() {
+    return
+  }
+  func.func @caller(%callee: () -> ()) {
+    acc.serial {
+      func.call_indirect %callee() : () -> ()
+      acc.yield
+    }
+    return
+  }
+}
+
+// CHECK-LABEL: func.func @caller(
+// CHECK: func.call_indirect %{{.*}}() : () -> ()
+
+// -----
+
+// Routine without a bind clause: call inside offload region is unchanged.
+module {
+  acc.routine @r_nobind func(@baz) seq
+  func.func @baz() attributes {acc.routine_info = #acc.routine_info<[@r_nobind]>} {
+    return
+  }
+  func.func @nobind_user() {
+    acc.serial {
+      func.call @baz() : () -> ()
+      acc.yield
+    }
+    return
+  }
+}
+
+// CHECK-LABEL: func.func @nobind_user()
+// CHECK: func.call @baz() : () -> ()


        


More information about the Mlir-commits mailing list