[flang-commits] [flang] [flang][cuda] Add CUFFunctionRewrite pass (PR #174650)
via flang-commits
flang-commits at lists.llvm.org
Tue Jan 6 13:10:54 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
This pass rewrites calls to some CUDA Fortran specific functions, such as `on_device`, into constant boolean values.
---
Full diff: https://github.com/llvm/llvm-project/pull/174650.diff
4 Files Affected:
- (modified) flang/include/flang/Optimizer/Transforms/Passes.td (+5)
- (modified) flang/lib/Optimizer/Transforms/CMakeLists.txt (+1)
- (added) flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp (+103)
- (added) flang/test/Fir/CUDA/cuda-function-rewrite.mlir (+44)
``````````diff
diff --git a/flang/include/flang/Optimizer/Transforms/Passes.td b/flang/include/flang/Optimizer/Transforms/Passes.td
index 47ffc4be93b33..dd2023223f1d2 100644
--- a/flang/include/flang/Optimizer/Transforms/Passes.td
+++ b/flang/include/flang/Optimizer/Transforms/Passes.td
@@ -533,6 +533,11 @@ def CUFDeviceFuncTransform
/*default=*/"0", "CUDA compute capability version">];
}
+def CUFFunctionRewrite : Pass<"cuf-function-rewrite", ""> {
+  let summary = "Convert some CUDA Fortran specific call";
+  let description = [{
+    Rewrite calls to selected CUDA Fortran specific functions into constant
+    values. Currently a call to `on_device` is replaced by an `arith.constant`
+    of type `i1` (converted to the call's result type) that is true when the
+    call is nested inside a `gpu.module` and false otherwise.
+  }];
+  // The rewrite materializes arith.constant operations, so the arith dialect
+  // must be loaded before the pass runs, not only the FIR dialect.
+  let dependentDialects = ["fir::FIROpsDialect", "mlir::arith::ArithDialect"];
+}
+
def CUFLaunchAttachAttr : Pass<"cuf-launch-attach-attr", ""> {
let summary = "Attach CUDA attribute to CUF kernel generated launch";
let description = [{
diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt
index 4496e80aa7c40..4ee5eab6247e1 100644
--- a/flang/lib/Optimizer/Transforms/CMakeLists.txt
+++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_flang_library(FIRTransforms
ControlFlowConverter.cpp
CUDA/CUFAllocationConversion.cpp
CUDA/CUFDeviceFuncTransform.cpp
+ CUDA/CUFFunctionRewrite.cpp
CUDA/CUFLaunchAttachAttr.cpp
CUDA/CUFPredefinedVarToGPU.cpp
CUFAddConstructor.cpp
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp
new file mode 100644
index 0000000000000..d6f9dc097831c
--- /dev/null
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFFunctionRewrite.cpp
@@ -0,0 +1,103 @@
+//===-- CUFFunctionRewrite.cpp --------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/CodeGen/TypeConverter.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/Support/DataLayout.h"
+#include "flang/Optimizer/Transforms/Passes.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/ValueRange.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
+#include <string_view>
+
+#define DEBUG_TYPE "flang-cuf-function-rewrite"
+
+namespace fir {
+#define GEN_PASS_DEF_CUFFUNCTIONREWRITE
+#include "flang/Optimizer/Transforms/Passes.h.inc"
+} // namespace fir
+
+using namespace mlir;
+
+namespace {
+
+/// Signature of a rewrite generator. Given the rewriter and the matched
+/// fir.call, a generator builds the replacement IR and returns the value that
+/// replaces the call result, or a null value when the call must simply be
+/// erased. A plain function pointer is sufficient because every generator is a
+/// static member function, and it avoids depending on <functional>, which this
+/// file does not include.
+using genFunctionType = mlir::Value (*)(mlir::PatternRewriter &, fir::CallOp);
+
+/// Rewrite pattern for fir.call operations whose callee name has a generator
+/// registered in genMappings_. The call is replaced by the value produced by
+/// the generator, or erased when the generator returns a null value.
+class CallConversion : public OpRewritePattern<fir::CallOp> {
+public:
+  CallConversion(MLIRContext *context)
+      : OpRewritePattern<fir::CallOp>(context) {}
+
+  /// Match a fir.call with a callee symbol known to genMappings_ and rewrite
+  /// it. Returns failure() when the call has no callee symbol or when no
+  /// generator is registered for its name.
+  LogicalResult
+  matchAndRewrite(fir::CallOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    auto callee = op.getCallee();
+    if (!callee)
+      return failure();
+
+    // Single map lookup instead of contains() followed by find().
+    auto it = genMappings_.find(callee->getRootReference().getValue());
+    if (it == genMappings_.end())
+      return failure();
+
+    // A null result means the generator has nothing to substitute for the
+    // call, so the call is dropped.
+    if (mlir::Value result = it->second(rewriter, op))
+      rewriter.replaceOp(op, result);
+    else
+      rewriter.eraseOp(op);
+    return success();
+  }
+
+private:
+  /// Generator for `on_device`: creates an i1 constant equal to 1 when the
+  /// call is nested in a gpu.module and 0 otherwise, then converts it to the
+  /// type of the call result.
+  static mlir::Value genOnDevice(mlir::PatternRewriter &rewriter,
+                                 fir::CallOp op) {
+    assert(op.getArgs().size() == 0 && "expect 0 arguments");
+    // The result type is read from result #0 below, so it must exist.
+    assert(op->getNumResults() == 1 && "expect 1 result");
+    mlir::Location loc = op.getLoc();
+    unsigned inGPUMod = op->getParentOfType<gpu::GPUModuleOp>() ? 1 : 0;
+    mlir::Type i1Ty = rewriter.getIntegerType(1);
+    mlir::Value t = mlir::arith::ConstantOp::create(
+        rewriter, loc, i1Ty, rewriter.getIntegerAttr(i1Ty, inGPUMod));
+    return fir::ConvertOp::create(rewriter, loc, op.getResult(0).getType(), t);
+  }
+
+  /// Callee name -> generator producing the replacement for the call.
+  const llvm::StringMap<genFunctionType> genMappings_ = {
+      {"on_device", &genOnDevice}};
+};
+
+/// Pass driver: applies the CallConversion pattern greedily on the operation
+/// the pass runs on, and signals a pass failure (after emitting an error) when
+/// the greedy driver reports failure.
+class CUFFunctionRewrite
+    : public fir::impl::CUFFunctionRewriteBase<CUFFunctionRewrite> {
+public:
+  void runOnOperation() override {
+    mlir::MLIRContext *context = &getContext();
+
+    // Only one pattern is registered: the fir.call rewrite.
+    mlir::RewritePatternSet rewritePatterns(context);
+    rewritePatterns.insert<CallConversion>(rewritePatterns.getContext());
+
+    const bool applied = mlir::succeeded(mlir::applyPatternsGreedily(
+        getOperation(), std::move(rewritePatterns)));
+    if (applied)
+      return;
+
+    mlir::emitError(mlir::UnknownLoc::get(context),
+                    "error in CUFFunctionRewrite op conversion\n");
+    signalPassFailure();
+  }
+};
+
+} // namespace
diff --git a/flang/test/Fir/CUDA/cuda-function-rewrite.mlir b/flang/test/Fir/CUDA/cuda-function-rewrite.mlir
new file mode 100644
index 0000000000000..da1d601a2eb8b
--- /dev/null
+++ b/flang/test/Fir/CUDA/cuda-function-rewrite.mlir
@@ -0,0 +1,44 @@
+// RUN: fir-opt --split-input-file --cuf-function-rewrite %s | FileCheck %s
+
+gpu.module @cuda_device_mod {
+ func.func @_QMmtestsPdo2(%arg0: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "c"}, %arg1: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "i"}) attributes {cuf.proc_attr = #cuf.cuda_proc<host_device>} {
+ %c2_i32 = arith.constant 2 : i32
+ %c1_i32 = arith.constant 1 : i32
+ %0 = fir.dummy_scope : !fir.dscope
+ %5 = fir.declare %arg0 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ec"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+ %8 = fir.declare %arg1 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ei"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+ %13 = fir.call @on_device() proc_attrs<bind_c> fastmath<contract> : () -> !fir.logical<4>
+ %14 = fir.convert %13 : (!fir.logical<4>) -> i1
+ fir.if %14 {
+ fir.store %c1_i32 to %5 : !fir.ref<i32>
+ } else {
+ fir.store %c2_i32 to %5 : !fir.ref<i32>
+ }
+ return
+ }
+}
+
+// CHECK-LABEL: gpu.module @cuda_device_mod
+// CHECK: func.func @_QMmtestsPdo2
+// CHECK: fir.if %true
+
+// -----
+
+func.func @_QMmtestsPdo3(%arg0: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "c"}, %arg1: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "i"}) attributes {cuf.proc_attr = #cuf.cuda_proc<host_device>} {
+ %c2_i32 = arith.constant 2 : i32
+ %c1_i32 = arith.constant 1 : i32
+ %0 = fir.dummy_scope : !fir.dscope
+ %5 = fir.declare %arg0 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ec"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+ %8 = fir.declare %arg1 dummy_scope %0 {uniq_name = "_QMmtestsFdo2Ei"} : (!fir.ref<i32>, !fir.dscope) -> !fir.ref<i32>
+ %13 = fir.call @on_device() proc_attrs<bind_c> fastmath<contract> : () -> !fir.logical<4>
+ %14 = fir.convert %13 : (!fir.logical<4>) -> i1
+ fir.if %14 {
+ fir.store %c1_i32 to %5 : !fir.ref<i32>
+ } else {
+ fir.store %c2_i32 to %5 : !fir.ref<i32>
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @_QMmtestsPdo3
+// CHECK: fir.if %false
``````````
</details>
https://github.com/llvm/llvm-project/pull/174650
More information about the flang-commits
mailing list