[flang-commits] [flang] [flang][cuda] Rewrite predefined var in cuf kernel inside host function (PR #205974)

via flang-commits flang-commits at lists.llvm.org
Thu Jun 25 22:26:52 PDT 2026


llvmorg-github-actions[bot] wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>

Predefined variables are normally set inside device functions, but they can also appear in a cuf kernel when a device function is inlined inside the cuf kernel.

Previously, the pass only rewrote the predefined variables inside device functions and not inside host functions containing cuf kernels. Update the pass to support this case.

---
Full diff: https://github.com/llvm/llvm-project/pull/205974.diff


2 Files Affected:

- (modified) flang/lib/Optimizer/Transforms/CUDA/CUFPredefinedVarToGPU.cpp (+61-34) 
- (modified) flang/test/Fir/CUDA/predefined-variables.mlir (+3-2) 


``````````diff
diff --git a/flang/lib/Optimizer/Transforms/CUDA/CUFPredefinedVarToGPU.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFPredefinedVarToGPU.cpp
index 901f15d792173..2dd332488269d 100644
--- a/flang/lib/Optimizer/Transforms/CUDA/CUFPredefinedVarToGPU.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFPredefinedVarToGPU.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
@@ -101,11 +102,60 @@ processDeclareOp(mlir::OpBuilder &builder, mlir::Location loc,
 struct CUFPredefinedVarToGPU
     : public fir::impl::CUFPredefinedVarToGPUBase<CUFPredefinedVarToGPU> {
 
+  void rewritePredefinedVars(mlir::Region &region, mlir::Location loc) {
+    if (region.empty())
+      return;
+
+    bool hasPredefinedDeclares = false;
+    region.walk([&](fir::DeclareOp declareOp) {
+      llvm::StringRef uniqName = declareOp.getUniqName();
+      hasPredefinedDeclares |= uniqName == mangleBuiltin(threadidx) ||
+                               uniqName == mangleBuiltin(blockidx) ||
+                               uniqName == mangleBuiltin(blockdim) ||
+                               uniqName == mangleBuiltin(griddim);
+    });
+    if (!hasPredefinedDeclares)
+      return;
+
+    mlir::OpBuilder builder(region.getContext());
+    builder.setInsertionPointToStart(&region.front());
+    auto c1 = mlir::arith::ConstantOp::create(
+        builder, loc, builder.getI32Type(), builder.getI32IntegerAttr(1));
+    llvm::SmallVector<mlir::Value, 3> threadids, blockids, blockdims, griddims;
+    createForAllDimensions<mlir::NVVM::ThreadIdXOp, mlir::NVVM::ThreadIdYOp,
+                           mlir::NVVM::ThreadIdZOp>(builder, loc, c1, threadids,
+                                                    /*incrementByOne=*/true);
+    createForAllDimensions<mlir::NVVM::BlockIdXOp, mlir::NVVM::BlockIdYOp,
+                           mlir::NVVM::BlockIdZOp>(builder, loc, c1, blockids,
+                                                   /*incrementByOne=*/true);
+    createForAllDimensions<mlir::NVVM::GridDimXOp, mlir::NVVM::GridDimYOp,
+                           mlir::NVVM::GridDimZOp>(builder, loc, c1, griddims);
+    createForAllDimensions<mlir::NVVM::BlockDimXOp, mlir::NVVM::BlockDimYOp,
+                           mlir::NVVM::BlockDimZOp>(builder, loc, c1,
+                                                    blockdims);
+
+    llvm::SmallVector<mlir::Operation *> opsToDelete;
+    region.walk([&](fir::DeclareOp declareOp) {
+      processDeclareOp(builder, loc, declareOp, mangleBuiltin(threadidx),
+                       threadids, opsToDelete);
+      processDeclareOp(builder, loc, declareOp, mangleBuiltin(blockidx),
+                       blockids, opsToDelete);
+      processDeclareOp(builder, loc, declareOp, mangleBuiltin(blockdim),
+                       blockdims, opsToDelete);
+      processDeclareOp(builder, loc, declareOp, mangleBuiltin(griddim),
+                       griddims, opsToDelete);
+    });
+
+    for (auto *op : opsToDelete)
+      op->erase();
+  }
+
   void runOnOperation() override {
     func::FuncOp funcOp = getOperation();
     if (funcOp.getBody().empty())
       return;
 
+    bool rewrittenWholeFunction = false;
     if (auto cudaProcAttr =
             funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
                 cuf::getProcAttrName())) {
@@ -113,42 +163,19 @@ struct CUFPredefinedVarToGPU
           cudaProcAttr.getValue() == cuf::ProcAttribute::Global ||
           cudaProcAttr.getValue() == cuf::ProcAttribute::GridGlobal ||
           cudaProcAttr.getValue() == cuf::ProcAttribute::HostDevice) {
-        mlir::Location loc = funcOp.getLoc();
-        mlir::OpBuilder builder(funcOp.getContext());
-        builder.setInsertionPointToStart(&funcOp.getBody().front());
-        auto c1 = mlir::arith::ConstantOp::create(
-            builder, loc, builder.getI32Type(), builder.getI32IntegerAttr(1));
-        llvm::SmallVector<mlir::Value, 3> threadids, blockids, blockdims,
-            griddims;
-        createForAllDimensions<mlir::NVVM::ThreadIdXOp, mlir::NVVM::ThreadIdYOp,
-                               mlir::NVVM::ThreadIdZOp>(
-            builder, loc, c1, threadids, /*incrementByOne=*/true);
-        createForAllDimensions<mlir::NVVM::BlockIdXOp, mlir::NVVM::BlockIdYOp,
-                               mlir::NVVM::BlockIdZOp>(
-            builder, loc, c1, blockids, /*incrementByOne=*/true);
-        createForAllDimensions<mlir::NVVM::GridDimXOp, mlir::NVVM::GridDimYOp,
-                               mlir::NVVM::GridDimZOp>(builder, loc, c1,
-                                                       griddims);
-        createForAllDimensions<mlir::NVVM::BlockDimXOp, mlir::NVVM::BlockDimYOp,
-                               mlir::NVVM::BlockDimZOp>(builder, loc, c1,
-                                                        blockdims);
-
-        llvm::SmallVector<mlir::Operation *> opsToDelete;
-        funcOp.walk([&](fir::DeclareOp declareOp) {
-          processDeclareOp(builder, loc, declareOp, mangleBuiltin(threadidx),
-                           threadids, opsToDelete);
-          processDeclareOp(builder, loc, declareOp, mangleBuiltin(blockidx),
-                           blockids, opsToDelete);
-          processDeclareOp(builder, loc, declareOp, mangleBuiltin(blockdim),
-                           blockdims, opsToDelete);
-          processDeclareOp(builder, loc, declareOp, mangleBuiltin(griddim),
-                           griddims, opsToDelete);
-        });
-
-        for (auto *op : opsToDelete)
-          op->erase();
+        rewritePredefinedVars(funcOp.getRegion(), funcOp.getLoc());
+        rewrittenWholeFunction = true;
       }
     }
+
+    if (rewrittenWholeFunction)
+      return;
+
+    // Host functions containing cuf.kernel regions can still carry predefined
+    // vars in the kernel body. Rewrite them in-place as well.
+    funcOp.walk([&](cuf::KernelOp kernelOp) {
+      rewritePredefinedVars(kernelOp.getRegion(), kernelOp.getLoc());
+    });
   }
 };
 
diff --git a/flang/test/Fir/CUDA/predefined-variables.mlir b/flang/test/Fir/CUDA/predefined-variables.mlir
index 90b72f3bc1bcb..50f269e4c65db 100644
--- a/flang/test/Fir/CUDA/predefined-variables.mlir
+++ b/flang/test/Fir/CUDA/predefined-variables.mlir
@@ -237,7 +237,7 @@ func.func @_QMbarPgfoo2(%arg0: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>,
 
 // -----
 
-func.func @surviving_predefined_vars(%arg0: i32, %arg1: i32, %arg2: i32) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
+func.func @surviving_predefined_vars(%arg0: i32, %arg1: i32, %arg2: i32) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
@@ -279,7 +279,7 @@ func.func @surviving_predefined_vars(%arg0: i32, %arg1: i32, %arg2: i32) attribu
 
 // -----
 
-func.func @surviving_predefined_vars(%arg0: i32, %arg1: i32, %arg2: i32) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
+func.func @surviving_predefined_vars(%arg0: i32, %arg1: i32, %arg2: i32) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
@@ -317,3 +317,4 @@ func.func @surviving_predefined_vars(%arg0: i32, %arg1: i32, %arg2: i32) attribu
 // CHECK-NOT: _QM__fortran_builtinsE__builtin_blockidx
 // CHECK-NOT: _QM__fortran_builtinsE__builtin_griddim
 // CHECK-NOT: _QM__fortran_builtinsE__builtin_threadidx
+// CHECK: nvvm.read.ptx.sreg.tid.x

``````````

</details>


https://github.com/llvm/llvm-project/pull/205974


More information about the flang-commits mailing list