[Mlir-commits] [mlir] [acc] Introduce ACCRoutineLowering for `acc routine` specialization (PR #186243)
Razvan Lupusoru
llvmlistbot at llvm.org
Thu Mar 12 13:51:12 PDT 2026
https://github.com/razvanlupusoru updated https://github.com/llvm/llvm-project/pull/186243
>From 5df48c46db4c2ae87d576f26b85fb4b7dd5aceff Mon Sep 17 00:00:00 2001
From: Delaram Talaashrafi <dtalaashrafi at nvidia.com>
Date: Thu, 12 Mar 2026 13:31:08 -0700
Subject: [PATCH 1/2] [acc] Introduce ACCRoutineLowering for `acc routine`
specialization
This pass handles the `acc routine` directive by creating specialized
functions with appropriate parallelism information that can be used for
the eventual creation of a device function.
For each acc.routine that is not bound by name, the pass creates a new
function (the "device" copy) whose body is a single acc.compute_region
containing a clone of the original (host) function body. Parallelism is
expressed by one acc.par_width derived from the routine's clauses (seq,
vector, worker, gang). The device copy created is simply a staging
place for an eventual move to a device-module-level function.
---
.../mlir/Dialect/OpenACC/OpenACCUtilsCG.h | 9 +-
.../mlir/Dialect/OpenACC/Transforms/Passes.td | 25 ++
.../OpenACC/Transforms/ACCRoutineLowering.cpp | 252 ++++++++++++++++++
.../Dialect/OpenACC/Transforms/CMakeLists.txt | 1 +
.../Dialect/OpenACC/Utils/OpenACCUtilsCG.cpp | 41 ++-
.../Dialect/OpenACC/acc-routine-lowering.mlir | 125 +++++++++
.../Dialect/OpenACC/OpenACCUtilsCGTest.cpp | 66 +++++
7 files changed, 509 insertions(+), 10 deletions(-)
create mode 100644 mlir/lib/Dialect/OpenACC/Transforms/ACCRoutineLowering.cpp
create mode 100644 mlir/test/Dialect/OpenACC/acc-routine-lowering.mlir
diff --git a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtilsCG.h b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtilsCG.h
index 5a0d70c53bece..f72d080858747 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCUtilsCG.h
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCUtilsCG.h
@@ -43,6 +43,12 @@ std::optional<DataLayout> getDataLayout(Operation *op,
/// The `mapping` is used and updated during cloning, allowing callers to
/// track value correspondences. Optional `output`, `kernelFuncName`,
/// `kernelModuleName`, and `stream` arguments are forwarded to the op.
+///
+/// When `inputArgsToMap` is non-empty, it is used as the key set for the
+/// clone mapping (instead of `inputArgs`). Use this when cloning a region
+/// that references one set of values (e.g. the source function's args) while
+/// the op's operands are another set (e.g. the current block's args).
+/// `inputArgsToMap` must have the same size as `inputArgs` when provided.
ComputeRegionOp buildComputeRegion(Location loc, ValueRange launchArgs,
ValueRange inputArgs, llvm::StringRef origin,
Region &regionToClone,
@@ -50,7 +56,8 @@ ComputeRegionOp buildComputeRegion(Location loc, ValueRange launchArgs,
ValueRange output = {},
FlatSymbolRefAttr kernelFuncName = {},
FlatSymbolRefAttr kernelModuleName = {},
- Value stream = {});
+ Value stream = {},
+ ValueRange inputArgsToMap = {});
} // namespace acc
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
index 9ab99208f83c7..3dfbca478c16b 100644
--- a/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/OpenACC/Transforms/Passes.td
@@ -450,4 +450,29 @@ def ACCComputeLowering : Pass<"acc-compute-lowering", "mlir::func::FuncOp"> {
let options = [ AccDeviceTypeOption ];
}
+def ACCRoutineLowering : Pass<"acc-routine-lowering", "mlir::ModuleOp"> {
+ let summary = "Specialize `acc routine` functions for device";
+ let description = [{
+ This pass handles `acc routine` directive by creating specialized
+ functions with appropriate parallelism information that can be used for
+ eventual creation of device function.
+
+ For each acc.routine that is not bound by name, the pass creates a new
+ function (the "device" copy) whose body is a single acc.compute_region
+ containing a clone of the original (host) function body. Parallelism is
+ expressed by one acc.par_width derived from the routine's clauses (seq,
+ vector, worker, gang). The pass does not use acc.kernel_environment. It
+ sets acc.specialized_routine on the new function and updates the
+ acc.routine's func_name to point to it. For nohost routines, all uses of
+ the host symbol are replaced with the device symbol and the host function
+ is erased. Routines with bind(name) and external functions are skipped.
+ }];
+ let dependentDialects = [
+ "mlir::acc::OpenACCDialect",
+ "mlir::func::FuncDialect",
+ "mlir::scf::SCFDialect"
+ ];
+ let options = [ AccDeviceTypeOption ];
+}
+
#endif // MLIR_DIALECT_OPENACC_TRANSFORMS_PASSES
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCRoutineLowering.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCRoutineLowering.cpp
new file mode 100644
index 0000000000000..f0d8dea006970
--- /dev/null
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCRoutineLowering.cpp
@@ -0,0 +1,252 @@
+//===- ACCRoutineLowering.cpp - Wrap ACC routines in compute_region -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass handles `acc routine` directive by creating specialized
+// functions with appropriate parallelism information that can be used for
+// eventual creation of device function.
+//
+// Overview:
+// ---------
+// For each acc.routine that is not bound by name, the pass creates a new
+// function (the "device" copy) whose body is a single acc.compute_region
+// containing a clone of the original (host) function body. Parallelism is
+// expressed by one acc.par_width derived from the routine's clauses (seq,
+// vector, worker, gang). The device copy created is simply a staging
+// place for eventual move to device module level function.
+//
+// Transformations:
+// ----------------
+// 1. Device function: Same signature as the host; attributes copied except
+// acc.routine_info. The acc.specialized_routine attribute is set with the
+// routine symbol, par level, and original function name.
+//
+// 2. Body: One acc.par_width, one acc.compute_region that clones the host
+// body. Multi-block host bodies are wrapped in scf.execute_region inside
+// the compute_region.
+//
+// 3. Finalization: acc.routine's func_name is updated to the device function.
+// For nohost routines, all uses of the host symbol are replaced with the
+// device symbol and the host function is erased.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h"
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/OpenACC/OpenACC.h"
+#include "mlir/Dialect/OpenACC/OpenACCParMapping.h"
+#include "mlir/Dialect/OpenACC/OpenACCUtilsCG.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/IR/Value.h"
+
+namespace mlir {
+namespace acc {
+#define GEN_PASS_DEF_ACCROUTINELOWERING
+#include "mlir/Dialect/OpenACC/Transforms/Passes.h.inc"
+} // namespace acc
+} // namespace mlir
+
+#define DEBUG_TYPE "acc-routine-lowering"
+
+using namespace mlir;
+using namespace mlir::acc;
+
+namespace {
+
+/// Compute the ParLevel from an acc.routine op for specialization.
+static ParLevel computeParLevel(RoutineOp routineOp, DeviceType deviceType) {
+ auto gangDim = routineOp.getGangDimValue(deviceType);
+ if (!gangDim)
+ gangDim = routineOp.getGangDimValue();
+ if (gangDim) {
+ switch (*gangDim) {
+ case 1:
+ return ParLevel::gang_dim1;
+ case 2:
+ return ParLevel::gang_dim2;
+ case 3:
+ return ParLevel::gang_dim3;
+ default:
+ break;
+ }
+ }
+ if (routineOp.hasGang(deviceType) || routineOp.hasGang())
+ return ParLevel::gang_dim1;
+ if (routineOp.hasWorker(deviceType) || routineOp.hasWorker())
+ return ParLevel::worker;
+ if (routineOp.hasVector(deviceType) || routineOp.hasVector())
+ return ParLevel::vector;
+ return ParLevel::seq;
+}
+
+/// Collect return operands from the function (first block with func.return).
+static void getReturnValues(func::FuncOp func, SmallVectorImpl<Value> &result) {
+ result.clear();
+ for (Block &block : func.getBody().getBlocks()) {
+ if (auto returnOp = dyn_cast<func::ReturnOp>(block.getTerminator())) {
+ result.assign(returnOp.operand_begin(), returnOp.operand_end());
+ break;
+ }
+ }
+}
+
+/// Create the device function with the same signature as the host, set
+/// specialized_routine, and add a single block with the same block arguments.
+static func::FuncOp createFunctionForDeviceStaging(func::FuncOp hostFunc,
+ RoutineOp routineOp,
+ ParLevel parLevel, MLIRContext *ctx,
+ IRRewriter &rewriter) {
+ Location loc = hostFunc.getLoc();
+ FunctionType funcType = hostFunc.getFunctionType();
+ func::FuncOp deviceFunc =
+ func::FuncOp::create(rewriter, loc, hostFunc.getName(), funcType);
+ deviceFunc->setAttrs(hostFunc->getAttrs());
+ deviceFunc->removeAttr(getRoutineInfoAttrName());
+ deviceFunc->setAttr(
+ getSpecializedRoutineAttrName(),
+ SpecializedRoutineAttr::get(
+ ctx, SymbolRefAttr::get(ctx, routineOp.getSymName()),
+ ParLevelAttr::get(ctx, parLevel),
+ StringAttr::get(ctx, hostFunc.getName())));
+
+ Block *sourceBlock = &hostFunc.getBody().front();
+ Block *newBlock = rewriter.createBlock(&deviceFunc.getRegion());
+ for (BlockArgument arg : sourceBlock->getArguments())
+ newBlock->addArgument(arg.getType(), hostFunc.getLoc());
+
+ return deviceFunc;
+}
+
+/// Fill the device function body: one acc.par_width, one acc.compute_region
+/// (cloning the host body with inputArgsToMap), then func.return.
+static LogicalResult buildRoutineBody(func::FuncOp deviceFunc,
+ func::FuncOp hostFunc,
+ ArrayRef<Value> funcReturnVals,
+ ParLevel parLevel,
+ DefaultACCToGPUMappingPolicy &policy,
+ IRRewriter &rewriter) {
+ Block *newBlock = &deviceFunc.getBody().front();
+ Block *sourceBlock = &hostFunc.getBody().front();
+ Location loc = hostFunc.getLoc();
+ MLIRContext *ctx = rewriter.getContext();
+
+ rewriter.setInsertionPointToStart(newBlock);
+ GPUParallelDimAttr parDim = policy.map(ctx, parLevel);
+ Value parWidthVal = ParWidthOp::create(rewriter, loc, Value(), parDim);
+ SmallVector<Value, 4> inputArgs(newBlock->getArguments().begin(),
+ newBlock->getArguments().end());
+
+ // Normally the region passed to buildComputeRegion is something in the
+ // current function. Here we pass the body of the original (host) function as
+ // an optimization to avoid cloning twice (once for a staged device copy and
+ // again when creating the compute region). Since we clone only once, we must
+ // also provide the original function's arguments so the mapping is correct
+ // when cloning the body.
+ ValueRange sourceArgsToMap = sourceBlock->getArguments();
+
+ IRMapping mapping;
+ rewriter.setInsertionPointAfter(parWidthVal.getDefiningOp());
+ ComputeRegionOp computeRegion = buildComputeRegion(
+ loc, {parWidthVal}, inputArgs, RoutineOp::getOperationName(),
+ hostFunc.getBody(), rewriter, mapping,
+ /*output=*/funcReturnVals, /*kernelFuncName=*/{},
+ /*kernelModuleName=*/{}, /*stream=*/{}, sourceArgsToMap);
+ if (!computeRegion)
+ return failure();
+
+ rewriter.setInsertionPointAfter(computeRegion);
+ if (funcReturnVals.empty())
+ func::ReturnOp::create(rewriter, loc);
+ else
+ func::ReturnOp::create(rewriter, loc, computeRegion.getResults());
+
+ return success();
+}
+
+/// Update acc.routine refs and optionally erase host for nohost routines.
+static LogicalResult
+finalizeRoutines(SmallVectorImpl<std::tuple<func::FuncOp, func::FuncOp, RoutineOp>> &accRoutineInfo,
+ ModuleOp mod, MLIRContext *ctx) {
+ for (auto &[hostFunc, deviceFunc, routineOp] : accRoutineInfo) {
+ routineOp.setFuncNameAttr(
+ SymbolRefAttr::get(ctx, deviceFunc.getName()));
+ routineOp->moveBefore(deviceFunc);
+
+ if (routineOp.getNohost()) {
+ if (failed(SymbolTable::replaceAllSymbolUses(
+ StringAttr::get(ctx, hostFunc.getName()),
+ StringAttr::get(ctx, deviceFunc.getName()), mod))) {
+ routineOp.emitError("cannot replace symbol uses for acc routine");
+ return failure();
+ }
+ hostFunc->erase();
+ }
+ }
+ return success();
+}
+
+class ACCRoutineLowering
+ : public acc::impl::ACCRoutineLoweringBase<ACCRoutineLowering> {
+public:
+ using ACCRoutineLoweringBase::ACCRoutineLoweringBase;
+
+ void runOnOperation() override {
+ ModuleOp mod = getOperation();
+ if (mod.getOps<RoutineOp>().empty()) {
+ LLVM_DEBUG(llvm::dbgs()
+ << "Skipping ACCRoutineLowering - no acc.routine ops\n");
+ return;
+ }
+
+ SymbolTable symTab(mod);
+ MLIRContext *ctx = mod.getContext();
+ IRRewriter rewriter(ctx);
+ DefaultACCToGPUMappingPolicy policy;
+
+ // Tuple: host function, device function, routine operation
+ SmallVector<std::tuple<func::FuncOp, func::FuncOp, RoutineOp>, 4>
+ accRoutineInfo;
+
+ for (RoutineOp routineOp : mod.getOps<RoutineOp>()) {
+ if (routineOp.getBindNameValue() ||
+ routineOp.getBindNameValue(deviceType))
+ continue;
+
+ func::FuncOp hostFunc = symTab.lookup<func::FuncOp>(
+ routineOp.getFuncName().getLeafReference());
+ if (!hostFunc) {
+ routineOp.emitError("acc routine function not found in symbol table");
+ return signalPassFailure();
+ }
+ if (hostFunc.isExternal())
+ continue;
+
+ SmallVector<Value, 4> funcReturnVals;
+ getReturnValues(hostFunc, funcReturnVals);
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ ParLevel parLevel = computeParLevel(routineOp, deviceType);
+ func::FuncOp deviceFunc =
+ createFunctionForDeviceStaging(hostFunc, routineOp, parLevel, ctx, rewriter);
+ if (failed(buildRoutineBody(deviceFunc, hostFunc, funcReturnVals,
+ parLevel, policy, rewriter)))
+ return signalPassFailure();
+
+ accRoutineInfo.push_back({hostFunc, deviceFunc, routineOp});
+ symTab.insert(deviceFunc);
+ }
+
+ if (failed(finalizeRoutines(accRoutineInfo, mod, ctx)))
+ return signalPassFailure();
+ }
+};
+
+} // namespace
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
index 1bb16b4b9642d..2e81988b6610b 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/OpenACC/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_mlir_dialect_library(MLIROpenACCTransforms
ACCComputeLowering.cpp
+ ACCRoutineLowering.cpp
ACCDeclareGPUModuleInsertion.cpp
ACCIfClauseLowering.cpp
ACCImplicitData.cpp
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtilsCG.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtilsCG.cpp
index f5e0e5c33fee4..61ac0574cb6e4 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtilsCG.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtilsCG.cpp
@@ -15,6 +15,7 @@
#include "mlir/Dialect/OpenACC/OpenACCUtilsLoop.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
+#include "llvm/ADT/STLExtras.h"
namespace mlir {
namespace acc {
@@ -59,7 +60,8 @@ buildComputeRegion(Location loc, ValueRange launchArgs, ValueRange inputArgs,
llvm::StringRef origin, Region &regionToClone,
RewriterBase &rewriter, IRMapping &mapping,
ValueRange output, FlatSymbolRefAttr kernelFuncName,
- FlatSymbolRefAttr kernelModuleName, Value stream) {
+ FlatSymbolRefAttr kernelModuleName, Value stream,
+ ValueRange inputArgsToMap) {
SmallVector<Type> resultTypes;
for (auto val : output)
resultTypes.push_back(val.getType());
@@ -71,6 +73,10 @@ buildComputeRegion(Location loc, ValueRange launchArgs, ValueRange inputArgs,
"empty region for acc.compute_region");
OpBuilder::InsertionGuard guard(rewriter);
+ ValueRange mapKeys = inputArgsToMap.empty() ? inputArgs : inputArgsToMap;
+ assert(mapKeys.size() == inputArgs.size() &&
+ "inputArgsToMap must have same size as inputArgs when provided");
+
auto parWidthType = ParWidthType::get(rewriter.getContext());
Block *entryBlock = rewriter.createBlock(&computeRegion.getRegion());
for (size_t i = 0; i < launchArgs.size(); ++i)
@@ -78,7 +84,7 @@ buildComputeRegion(Location loc, ValueRange launchArgs, ValueRange inputArgs,
for (Value input : inputArgs)
entryBlock->addArgument(input.getType(), loc);
for (size_t i = 0; i < inputArgs.size(); ++i)
- mapping.map(inputArgs[i], entryBlock->getArgument(launchArgs.size() + i));
+ mapping.map(mapKeys[i], entryBlock->getArgument(launchArgs.size() + i));
rewriter.setInsertionPointToStart(entryBlock);
if (regionToClone.getBlocks().size() == 1) {
for (auto &op : regionToClone.front().getOperations()) {
@@ -86,21 +92,38 @@ buildComputeRegion(Location loc, ValueRange launchArgs, ValueRange inputArgs,
break;
rewriter.clone(op, mapping);
}
+ SmallVector<Value> yieldOperands;
+ for (auto val : output)
+ yieldOperands.push_back(mapping.lookup(val));
+ rewriter.setInsertionPointToEnd(entryBlock);
+ YieldOp::create(rewriter, loc, yieldOperands);
} else {
auto exeRegion = mlir::acc::wrapMultiBlockRegionWithSCFExecuteRegion(
- regionToClone, mapping, loc, rewriter);
+ regionToClone, mapping, loc, rewriter, /*convertFuncReturn=*/true);
if (!exeRegion) {
rewriter.eraseOp(computeRegion);
return nullptr;
}
+ SmallVector<scf::YieldOp> yieldOps(
+ llvm::to_vector(exeRegion.getOps<scf::YieldOp>()));
+ assert(!yieldOps.empty() &&
+ "multi-block region must contain at least one scf.yield");
+ assert(llvm::all_of(yieldOps,
+ [&output](scf::YieldOp yieldOp) {
+ return yieldOp.getNumOperands() ==
+ static_cast<int64_t>(output.size()) &&
+ llvm::all_of(
+ llvm::zip(yieldOp.getOperands(), output),
+ [](auto pair) {
+ return std::get<0>(pair).getType() ==
+ std::get<1>(pair).getType();
+ });
+ }) &&
+ "each scf.yield operand count and types must match output");
+ rewriter.setInsertionPointToEnd(entryBlock);
+ YieldOp::create(rewriter, loc, exeRegion.getResults());
}
- SmallVector<Value> yieldOperands;
- for (auto val : output)
- yieldOperands.push_back(mapping.lookup(val));
- rewriter.setInsertionPointToEnd(entryBlock);
- YieldOp::create(rewriter, loc, yieldOperands);
-
return computeRegion;
}
diff --git a/mlir/test/Dialect/OpenACC/acc-routine-lowering.mlir b/mlir/test/Dialect/OpenACC/acc-routine-lowering.mlir
new file mode 100644
index 0000000000000..56197ca7a7bdd
--- /dev/null
+++ b/mlir/test/Dialect/OpenACC/acc-routine-lowering.mlir
@@ -0,0 +1,125 @@
+// RUN: mlir-opt %s -acc-routine-lowering -split-input-file | FileCheck %s
+
+// Test seq routine: body is wrapped in acc.compute_region with one
+// acc.par_width (sequential).
+acc.routine @routine_seq func(@host_foo) seq
+// CHECK: acc.routine @routine_seq func(@
+// CHECK: acc.specialized_routine = #acc.specialized_routine<@routine_seq, <seq>, "host_foo">
+// CHECK-NOT: acc.kernel_environment
+// CHECK: acc.par_width {par_dim = #acc.par_dim<sequential>}
+// CHECK: acc.compute_region
+// CHECK: origin = "acc.routine"
+func.func @host_foo(%buf: memref<8xi32>) {
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ memref.store %c0_i32, %buf[%c0] : memref<8xi32>
+ return
+}
+
+// -----
+
+// Test vector routine: par_width has thread_x with default mapping of vector dimension.
+acc.routine @routine_vec func(@host_bar) vector
+// CHECK: acc.routine @routine_vec func(@
+// CHECK: acc.specialized_routine = #acc.specialized_routine<@routine_vec, <vector>, "host_bar">
+// CHECK-NOT: acc.kernel_environment
+// CHECK: acc.par_width {par_dim = #acc.par_dim<thread_x>}
+// CHECK: acc.compute_region
+// CHECK: origin = "acc.routine"
+func.func @host_bar(%buf: memref<4xi32>) {
+ %c0 = arith.constant 0 : index
+ %c0_i32 = arith.constant 0 : i32
+ memref.store %c0_i32, %buf[%c0] : memref<4xi32>
+ return
+}
+
+// -----
+
+// Test worker routine: par_width has thread_y with default mapping of worker dimension.
+acc.routine @routine_worker func(@host_worker) worker
+// CHECK: acc.routine @routine_worker func(@
+// CHECK: acc.specialized_routine = #acc.specialized_routine<@routine_worker, <worker>, "host_worker">
+// CHECK: acc.par_width {par_dim = #acc.par_dim<thread_y>}
+// CHECK: acc.compute_region
+// CHECK: origin = "acc.routine"
+func.func @host_worker(%x: i32) {
+ return
+}
+
+// -----
+
+// Test gang routine: par_width has block_x (gang dim 1) with default mapping of gang dimension.
+acc.routine @routine_gang func(@host_gang) gang
+// CHECK: acc.routine @routine_gang func(@
+// CHECK: acc.specialized_routine = #acc.specialized_routine<@routine_gang, <gang_dim1>, "host_gang">
+// CHECK: acc.par_width {par_dim = #acc.par_dim<block_x>}
+// CHECK: acc.compute_region
+// CHECK: origin = "acc.routine"
+func.func @host_gang() {
+ return
+}
+
+// -----
+
+// Test routine with single return value: device func returns compute_region result.
+acc.routine @routine_ret func(@host_ret) seq
+// CHECK: acc.routine @routine_ret func(@
+// CHECK: acc.specialized_routine = #acc.specialized_routine<@routine_ret, <seq>, "host_ret">
+// CHECK: %[[CR:[0-9]+]] = acc.compute_region
+// CHECK: acc.yield %{{.*}} : i32
+// CHECK: return %[[CR]] : i32
+func.func @host_ret(%cond: i1) -> i32 {
+ %c0 = arith.constant 0 : i32
+ %c1 = arith.constant 1 : i32
+ %r = arith.select %cond, %c0, %c1 : i32
+ return %r : i32
+}
+
+// -----
+
+// Test routine with unstructured control flow.
+acc.routine @routine_cf func(@host_cf) seq
+// CHECK: acc.routine @routine_cf func(@
+// CHECK: acc.specialized_routine = #acc.specialized_routine<@routine_cf, <seq>, "host_cf">
+// CHECK: acc.par_width {par_dim = #acc.par_dim<sequential>}
+// CHECK: %[[CR:[0-9]+]] = acc.compute_region
+// CHECK: %[[EXE:[0-9]+]] = scf.execute_region
+// CHECK: scf.yield %{{.*}} : i32
+// CHECK: acc.yield %[[EXE]] : i32
+// CHECK: } {origin = "acc.routine"}
+// CHECK: return %[[CR]] : i32
+func.func @host_cf(%cond: i1) -> i32 {
+ cf.cond_br %cond, ^then, ^else
+^then:
+ %c0 = arith.constant 0 : i32
+ return %c0 : i32
+^else:
+ %c1 = arith.constant 1 : i32
+ return %c1 : i32
+}
+
+// -----
+
+// Test routine with bind(name): pass skips it, routine and func remain unchanged.
+acc.routine @routine_bind func(@host_bind) seq bind("myname")
+// CHECK: acc.routine @routine_bind func(@host_bind)
+// CHECK-NOT: acc.specialized_routine
+// CHECK: func.func @host_bind()
+func.func @host_bind() {
+ return
+}
+
+// -----
+
+// Test multiple routines in one module: each gets its own device copy.
+acc.routine @r_a func(@f_a) seq
+acc.routine @r_b func(@f_a) vector
+// CHECK: acc.routine @r_a func(@
+// CHECK: acc.specialized_routine = #acc.specialized_routine<@r_a, <seq>, "f_a">
+// CHECK: acc.par_width {par_dim = #acc.par_dim<sequential>}
+// CHECK: acc.routine @r_b func(@
+// CHECK: acc.specialized_routine = #acc.specialized_routine<@r_b, <vector>, "f_a">
+// CHECK: acc.par_width {par_dim = #acc.par_dim<thread_x>}
+func.func @f_a() {
+ return
+}
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
index 671fa6c5560eb..c1962564b04f3 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
@@ -148,3 +148,69 @@ TEST_F(OpenACCUtilsCGTest, buildComputeRegionWithLaunchArgs) {
func::ReturnOp::create(rewriter, loc);
}
+
+// Test buildComputeRegion with inputArgsToMap: clone a region whose block args
+// are the "source" values, while the op's inputArgs are "device" values. The
+// mapping should map source -> compute_region block args so the cloned body
+// uses the correct values.
+TEST_F(OpenACCUtilsCGTest, buildComputeRegionWithInputArgsToMap) {
+ OwningOpRef<ModuleOp> module = ModuleOp::create(b, loc);
+ IRRewriter rewriter(&context);
+ rewriter.setInsertionPointToEnd(module->getBody());
+
+ // Source function: one block with one index arg, body uses it (addi), then
+ // return (terminator is not cloned).
+ auto funcTy = b.getFunctionType({b.getIndexType()}, {});
+ auto sourceFunc =
+ func::FuncOp::create(rewriter, loc, "source", funcTy);
+ Block *sourceBlock = sourceFunc.addEntryBlock();
+ rewriter.setInsertionPointToStart(sourceBlock);
+ auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
+ auto addOp = arith::AddIOp::create(rewriter, loc, sourceBlock->getArgument(0),
+ c1.getResult());
+ (void)addOp;
+ func::ReturnOp::create(rewriter, loc);
+
+ // Set insertion back to module so hostFunc is also added to the module.
+ rewriter.setInsertionPointToEnd(module->getBody());
+
+ // Current function: we have a "device" block with one index arg. We will
+ // clone sourceFunc's body into a compute_region, with inputArgs = [device arg]
+ // and inputArgsToMap = [source block arg], so the clone maps source arg ->
+ // compute region block arg.
+ auto hostFuncTy = b.getFunctionType({b.getIndexType()}, {});
+ auto hostFunc = func::FuncOp::create(rewriter, loc, "host", hostFuncTy);
+ Block *deviceBlock = hostFunc.addEntryBlock();
+ rewriter.setInsertionPointToStart(deviceBlock);
+
+ Region &sourceRegion = sourceFunc.getBody();
+ ValueRange sourceArgsToMap = sourceRegion.front().getArguments();
+ ValueRange inputArgs = deviceBlock->getArguments();
+
+ IRMapping mapping;
+ auto cr = buildComputeRegion(
+ loc, /*launchArgs=*/{}, inputArgs, SerialOp::getOperationName(),
+ sourceRegion, rewriter, mapping,
+ /*output=*/{}, /*kernelFuncName=*/{}, /*kernelModuleName=*/{},
+ /*stream=*/{}, sourceArgsToMap);
+
+ ASSERT_TRUE(cr);
+ EXPECT_EQ(cr.getInputArgs().size(), 1u);
+ EXPECT_EQ(cr.getInputArgs()[0], deviceBlock->getArgument(0));
+ Block &crBlock = cr.getRegion().front();
+ EXPECT_EQ(crBlock.getNumArguments(), 1u);
+ // The cloned body should use the compute_region's block arg (mapped from
+ // source arg). So the only non-constant operand of the addi in the clone
+ // should be crBlock.getArgument(0).
+ bool foundAddI = false;
+ for (Operation &op : crBlock.getOperations()) {
+ if (isa<arith::AddIOp>(op)) {
+ foundAddI = true;
+ EXPECT_EQ(op.getOperand(0), crBlock.getArgument(0));
+ break;
+ }
+ }
+ EXPECT_TRUE(foundAddI);
+
+ func::ReturnOp::create(rewriter, loc);
+}
>From dc5f235749259095cbcf7bc62a938eed8e7559d4 Mon Sep 17 00:00:00 2001
From: Razvan Lupusoru <rlupusoru at nvidia.com>
Date: Thu, 12 Mar 2026 13:51:01 -0700
Subject: [PATCH 2/2] Fix format
---
.../OpenACC/Transforms/ACCRoutineLowering.cpp | 50 +++++++++----------
.../Dialect/OpenACC/Utils/OpenACCUtilsCG.cpp | 15 +++---
.../Dialect/OpenACC/OpenACCUtilsCGTest.cpp | 9 ++--
3 files changed, 36 insertions(+), 38 deletions(-)
diff --git a/mlir/lib/Dialect/OpenACC/Transforms/ACCRoutineLowering.cpp b/mlir/lib/Dialect/OpenACC/Transforms/ACCRoutineLowering.cpp
index f0d8dea006970..04b2260045838 100644
--- a/mlir/lib/Dialect/OpenACC/Transforms/ACCRoutineLowering.cpp
+++ b/mlir/lib/Dialect/OpenACC/Transforms/ACCRoutineLowering.cpp
@@ -101,21 +101,21 @@ static void getReturnValues(func::FuncOp func, SmallVectorImpl<Value> &result) {
/// Create the device function with the same signature as the host, set
/// specialized_routine, and add a single block with the same block arguments.
static func::FuncOp createFunctionForDeviceStaging(func::FuncOp hostFunc,
- RoutineOp routineOp,
- ParLevel parLevel, MLIRContext *ctx,
- IRRewriter &rewriter) {
+ RoutineOp routineOp,
+ ParLevel parLevel,
+ MLIRContext *ctx,
+ IRRewriter &rewriter) {
Location loc = hostFunc.getLoc();
FunctionType funcType = hostFunc.getFunctionType();
func::FuncOp deviceFunc =
func::FuncOp::create(rewriter, loc, hostFunc.getName(), funcType);
deviceFunc->setAttrs(hostFunc->getAttrs());
deviceFunc->removeAttr(getRoutineInfoAttrName());
- deviceFunc->setAttr(
- getSpecializedRoutineAttrName(),
- SpecializedRoutineAttr::get(
- ctx, SymbolRefAttr::get(ctx, routineOp.getSymName()),
- ParLevelAttr::get(ctx, parLevel),
- StringAttr::get(ctx, hostFunc.getName())));
+ deviceFunc->setAttr(getSpecializedRoutineAttrName(),
+ SpecializedRoutineAttr::get(
+ ctx, SymbolRefAttr::get(ctx, routineOp.getSymName()),
+ ParLevelAttr::get(ctx, parLevel),
+ StringAttr::get(ctx, hostFunc.getName())));
Block *sourceBlock = &hostFunc.getBody().front();
Block *newBlock = rewriter.createBlock(&deviceFunc.getRegion());
@@ -127,12 +127,10 @@ static func::FuncOp createFunctionForDeviceStaging(func::FuncOp hostFunc,
/// Fill the device function body: one acc.par_width, one acc.compute_region
/// (cloning the host body with inputArgsToMap), then func.return.
-static LogicalResult buildRoutineBody(func::FuncOp deviceFunc,
- func::FuncOp hostFunc,
- ArrayRef<Value> funcReturnVals,
- ParLevel parLevel,
- DefaultACCToGPUMappingPolicy &policy,
- IRRewriter &rewriter) {
+static LogicalResult
+buildRoutineBody(func::FuncOp deviceFunc, func::FuncOp hostFunc,
+ ArrayRef<Value> funcReturnVals, ParLevel parLevel,
+ DefaultACCToGPUMappingPolicy &policy, IRRewriter &rewriter) {
Block *newBlock = &deviceFunc.getBody().front();
Block *sourceBlock = &hostFunc.getBody().front();
Location loc = hostFunc.getLoc();
@@ -142,8 +140,8 @@ static LogicalResult buildRoutineBody(func::FuncOp deviceFunc,
GPUParallelDimAttr parDim = policy.map(ctx, parLevel);
Value parWidthVal = ParWidthOp::create(rewriter, loc, Value(), parDim);
SmallVector<Value, 4> inputArgs(newBlock->getArguments().begin(),
- newBlock->getArguments().end());
-
+ newBlock->getArguments().end());
+
// Normally the region passed to buildComputeRegion is something in the
// current function. Here we pass the body of the original (host) function as
// an optimization to avoid cloning twice (once for a staged device copy and
@@ -151,7 +149,7 @@ static LogicalResult buildRoutineBody(func::FuncOp deviceFunc,
// also provide the original function's arguments so the mapping is correct
// when cloning the body.
ValueRange sourceArgsToMap = sourceBlock->getArguments();
-
+
IRMapping mapping;
rewriter.setInsertionPointAfter(parWidthVal.getDefiningOp());
ComputeRegionOp computeRegion = buildComputeRegion(
@@ -172,12 +170,12 @@ static LogicalResult buildRoutineBody(func::FuncOp deviceFunc,
}
/// Update acc.routine refs and optionally erase host for nohost routines.
-static LogicalResult
-finalizeRoutines(SmallVectorImpl<std::tuple<func::FuncOp, func::FuncOp, RoutineOp>> &accRoutineInfo,
- ModuleOp mod, MLIRContext *ctx) {
+static LogicalResult finalizeRoutines(
+ SmallVectorImpl<std::tuple<func::FuncOp, func::FuncOp, RoutineOp>>
+ &accRoutineInfo,
+ ModuleOp mod, MLIRContext *ctx) {
for (auto &[hostFunc, deviceFunc, routineOp] : accRoutineInfo) {
- routineOp.setFuncNameAttr(
- SymbolRefAttr::get(ctx, deviceFunc.getName()));
+ routineOp.setFuncNameAttr(SymbolRefAttr::get(ctx, deviceFunc.getName()));
routineOp->moveBefore(deviceFunc);
if (routineOp.getNohost()) {
@@ -234,10 +232,10 @@ class ACCRoutineLowering
OpBuilder::InsertionGuard guard(rewriter);
ParLevel parLevel = computeParLevel(routineOp, deviceType);
- func::FuncOp deviceFunc =
- createFunctionForDeviceStaging(hostFunc, routineOp, parLevel, ctx, rewriter);
+ func::FuncOp deviceFunc = createFunctionForDeviceStaging(
+ hostFunc, routineOp, parLevel, ctx, rewriter);
if (failed(buildRoutineBody(deviceFunc, hostFunc, funcReturnVals,
- parLevel, policy, rewriter)))
+ parLevel, policy, rewriter)))
return signalPassFailure();
accRoutineInfo.push_back({hostFunc, deviceFunc, routineOp});
diff --git a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtilsCG.cpp b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtilsCG.cpp
index 61ac0574cb6e4..661074444a055 100644
--- a/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtilsCG.cpp
+++ b/mlir/lib/Dialect/OpenACC/Utils/OpenACCUtilsCG.cpp
@@ -55,13 +55,14 @@ std::optional<DataLayout> getDataLayout(Operation *op, bool allowDefault) {
return std::nullopt;
}
-ComputeRegionOp
-buildComputeRegion(Location loc, ValueRange launchArgs, ValueRange inputArgs,
- llvm::StringRef origin, Region &regionToClone,
- RewriterBase &rewriter, IRMapping &mapping,
- ValueRange output, FlatSymbolRefAttr kernelFuncName,
- FlatSymbolRefAttr kernelModuleName, Value stream,
- ValueRange inputArgsToMap) {
+ComputeRegionOp buildComputeRegion(Location loc, ValueRange launchArgs,
+ ValueRange inputArgs, llvm::StringRef origin,
+ Region &regionToClone,
+ RewriterBase &rewriter, IRMapping &mapping,
+ ValueRange output,
+ FlatSymbolRefAttr kernelFuncName,
+ FlatSymbolRefAttr kernelModuleName,
+ Value stream, ValueRange inputArgsToMap) {
SmallVector<Type> resultTypes;
for (auto val : output)
resultTypes.push_back(val.getType());
diff --git a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
index c1962564b04f3..2940145e40c74 100644
--- a/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
+++ b/mlir/unittests/Dialect/OpenACC/OpenACCUtilsCGTest.cpp
@@ -161,8 +161,7 @@ TEST_F(OpenACCUtilsCGTest, buildComputeRegionWithInputArgsToMap) {
// Source function: one block with one index arg, body uses it (addi), then
// return (terminator is not cloned).
auto funcTy = b.getFunctionType({b.getIndexType()}, {});
- auto sourceFunc =
- func::FuncOp::create(rewriter, loc, "source", funcTy);
+ auto sourceFunc = func::FuncOp::create(rewriter, loc, "source", funcTy);
Block *sourceBlock = sourceFunc.addEntryBlock();
rewriter.setInsertionPointToStart(sourceBlock);
auto c1 = arith::ConstantIndexOp::create(rewriter, loc, 1);
@@ -175,9 +174,9 @@ TEST_F(OpenACCUtilsCGTest, buildComputeRegionWithInputArgsToMap) {
rewriter.setInsertionPointToEnd(module->getBody());
// Current function: we have a "device" block with one index arg. We will
- // clone sourceFunc's body into a compute_region, with inputArgs = [device arg]
- // and inputArgsToMap = [source block arg], so the clone maps source arg ->
- // compute region block arg.
+ // clone sourceFunc's body into a compute_region, with inputArgs = [device
+ // arg] and inputArgsToMap = [source block arg], so the clone maps source arg
+ // -> compute region block arg.
auto hostFuncTy = b.getFunctionType({b.getIndexType()}, {});
auto hostFunc = func::FuncOp::create(rewriter, loc, "host", hostFuncTy);
Block *deviceBlock = hostFunc.addEntryBlock();
More information about the Mlir-commits
mailing list