[flang-commits] [flang] 7f746be - [flang][cuda] Rewrite predefined var in cuf kernel inside host function (#205974)

via flang-commits flang-commits at lists.llvm.org
Fri Jun 26 08:57:06 PDT 2026


Author: Valentin Clement (バレンタイン クレメン)
Date: 2026-06-26T08:57:01-07:00
New Revision: 7f746be13f23d02f4f29dfebbbc37868443f5961

URL: https://github.com/llvm/llvm-project/commit/7f746be13f23d02f4f29dfebbbc37868443f5961
DIFF: https://github.com/llvm/llvm-project/commit/7f746be13f23d02f4f29dfebbbc37868443f5961.diff

LOG: [flang][cuda] Rewrite predefined var in cuf kernel inside host function (#205974)

Added: 
    

Modified: 
    flang/lib/Optimizer/Transforms/CUDA/CUFPredefinedVarToGPU.cpp
    flang/test/Fir/CUDA/predefined-variables.mlir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Optimizer/Transforms/CUDA/CUFPredefinedVarToGPU.cpp b/flang/lib/Optimizer/Transforms/CUDA/CUFPredefinedVarToGPU.cpp
index 901f15d792173..2dd332488269d 100644
--- a/flang/lib/Optimizer/Transforms/CUDA/CUFPredefinedVarToGPU.cpp
+++ b/flang/lib/Optimizer/Transforms/CUDA/CUFPredefinedVarToGPU.cpp
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/Dialect/FIROpsSupport.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
@@ -101,11 +102,60 @@ processDeclareOp(mlir::OpBuilder &builder, mlir::Location loc,
 struct CUFPredefinedVarToGPU
     : public fir::impl::CUFPredefinedVarToGPUBase<CUFPredefinedVarToGPU> {
 
+  void rewritePredefinedVars(mlir::Region &region, mlir::Location loc) {
+    if (region.empty())
+      return;
+
+    bool hasPredefinedDeclares = false;
+    region.walk([&](fir::DeclareOp declareOp) {
+      llvm::StringRef uniqName = declareOp.getUniqName();
+      hasPredefinedDeclares |= uniqName == mangleBuiltin(threadidx) ||
+                               uniqName == mangleBuiltin(blockidx) ||
+                               uniqName == mangleBuiltin(blockdim) ||
+                               uniqName == mangleBuiltin(griddim);
+    });
+    if (!hasPredefinedDeclares)
+      return;
+
+    mlir::OpBuilder builder(region.getContext());
+    builder.setInsertionPointToStart(&region.front());
+    auto c1 = mlir::arith::ConstantOp::create(
+        builder, loc, builder.getI32Type(), builder.getI32IntegerAttr(1));
+    llvm::SmallVector<mlir::Value, 3> threadids, blockids, blockdims, griddims;
+    createForAllDimensions<mlir::NVVM::ThreadIdXOp, mlir::NVVM::ThreadIdYOp,
+                           mlir::NVVM::ThreadIdZOp>(builder, loc, c1, threadids,
+                                                    /*incrementByOne=*/true);
+    createForAllDimensions<mlir::NVVM::BlockIdXOp, mlir::NVVM::BlockIdYOp,
+                           mlir::NVVM::BlockIdZOp>(builder, loc, c1, blockids,
+                                                   /*incrementByOne=*/true);
+    createForAllDimensions<mlir::NVVM::GridDimXOp, mlir::NVVM::GridDimYOp,
+                           mlir::NVVM::GridDimZOp>(builder, loc, c1, griddims);
+    createForAllDimensions<mlir::NVVM::BlockDimXOp, mlir::NVVM::BlockDimYOp,
+                           mlir::NVVM::BlockDimZOp>(builder, loc, c1,
+                                                    blockdims);
+
+    llvm::SmallVector<mlir::Operation *> opsToDelete;
+    region.walk([&](fir::DeclareOp declareOp) {
+      processDeclareOp(builder, loc, declareOp, mangleBuiltin(threadidx),
+                       threadids, opsToDelete);
+      processDeclareOp(builder, loc, declareOp, mangleBuiltin(blockidx),
+                       blockids, opsToDelete);
+      processDeclareOp(builder, loc, declareOp, mangleBuiltin(blockdim),
+                       blockdims, opsToDelete);
+      processDeclareOp(builder, loc, declareOp, mangleBuiltin(griddim),
+                       griddims, opsToDelete);
+    });
+
+    for (auto *op : opsToDelete)
+      op->erase();
+  }
+
   void runOnOperation() override {
     func::FuncOp funcOp = getOperation();
     if (funcOp.getBody().empty())
       return;
 
+    bool rewrittenWholeFunction = false;
     if (auto cudaProcAttr =
             funcOp.getOperation()->getAttrOfType<cuf::ProcAttributeAttr>(
                 cuf::getProcAttrName())) {
@@ -113,42 +163,19 @@ struct CUFPredefinedVarToGPU
           cudaProcAttr.getValue() == cuf::ProcAttribute::Global ||
           cudaProcAttr.getValue() == cuf::ProcAttribute::GridGlobal ||
           cudaProcAttr.getValue() == cuf::ProcAttribute::HostDevice) {
-        mlir::Location loc = funcOp.getLoc();
-        mlir::OpBuilder builder(funcOp.getContext());
-        builder.setInsertionPointToStart(&funcOp.getBody().front());
-        auto c1 = mlir::arith::ConstantOp::create(
-            builder, loc, builder.getI32Type(), builder.getI32IntegerAttr(1));
-        llvm::SmallVector<mlir::Value, 3> threadids, blockids, blockdims,
-            griddims;
-        createForAllDimensions<mlir::NVVM::ThreadIdXOp, mlir::NVVM::ThreadIdYOp,
-                               mlir::NVVM::ThreadIdZOp>(
-            builder, loc, c1, threadids, /*incrementByOne=*/true);
-        createForAllDimensions<mlir::NVVM::BlockIdXOp, mlir::NVVM::BlockIdYOp,
-                               mlir::NVVM::BlockIdZOp>(
-            builder, loc, c1, blockids, /*incrementByOne=*/true);
-        createForAllDimensions<mlir::NVVM::GridDimXOp, mlir::NVVM::GridDimYOp,
-                               mlir::NVVM::GridDimZOp>(builder, loc, c1,
-                                                       griddims);
-        createForAllDimensions<mlir::NVVM::BlockDimXOp, mlir::NVVM::BlockDimYOp,
-                               mlir::NVVM::BlockDimZOp>(builder, loc, c1,
-                                                        blockdims);
-
-        llvm::SmallVector<mlir::Operation *> opsToDelete;
-        funcOp.walk([&](fir::DeclareOp declareOp) {
-          processDeclareOp(builder, loc, declareOp, mangleBuiltin(threadidx),
-                           threadids, opsToDelete);
-          processDeclareOp(builder, loc, declareOp, mangleBuiltin(blockidx),
-                           blockids, opsToDelete);
-          processDeclareOp(builder, loc, declareOp, mangleBuiltin(blockdim),
-                           blockdims, opsToDelete);
-          processDeclareOp(builder, loc, declareOp, mangleBuiltin(griddim),
-                           griddims, opsToDelete);
-        });
-
-        for (auto *op : opsToDelete)
-          op->erase();
+        rewritePredefinedVars(funcOp.getRegion(), funcOp.getLoc());
+        rewrittenWholeFunction = true;
       }
     }
+
+    if (rewrittenWholeFunction)
+      return;
+
+    // Host functions containing cuf.kernel regions can still carry predefined
+    // vars in the kernel body. Rewrite them in-place as well.
+    funcOp.walk([&](cuf::KernelOp kernelOp) {
+      rewritePredefinedVars(kernelOp.getRegion(), kernelOp.getLoc());
+    });
   }
 };
 

diff  --git a/flang/test/Fir/CUDA/predefined-variables.mlir b/flang/test/Fir/CUDA/predefined-variables.mlir
index 90b72f3bc1bcb..50f269e4c65db 100644
--- a/flang/test/Fir/CUDA/predefined-variables.mlir
+++ b/flang/test/Fir/CUDA/predefined-variables.mlir
@@ -237,7 +237,7 @@ func.func @_QMbarPgfoo2(%arg0: !fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>,
 
 // -----
 
-func.func @surviving_predefined_vars(%arg0: i32, %arg1: i32, %arg2: i32) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
+func.func @surviving_predefined_vars(%arg0: i32, %arg1: i32, %arg2: i32) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
@@ -279,7 +279,7 @@ func.func @surviving_predefined_vars(%arg0: i32, %arg1: i32, %arg2: i32) attribu
 
 // -----
 
-func.func @surviving_predefined_vars(%arg0: i32, %arg1: i32, %arg2: i32) attributes {cuf.proc_attr = #cuf.cuda_proc<global>} {
+func.func @surviving_predefined_vars(%arg0: i32, %arg1: i32, %arg2: i32) {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c4 = arith.constant 4 : index
@@ -317,3 +317,4 @@ func.func @surviving_predefined_vars(%arg0: i32, %arg1: i32, %arg2: i32) attribu
 // CHECK-NOT: _QM__fortran_builtinsE__builtin_blockidx
 // CHECK-NOT: _QM__fortran_builtinsE__builtin_griddim
 // CHECK-NOT: _QM__fortran_builtinsE__builtin_threadidx
+// CHECK: nvvm.read.ptx.sreg.tid.x


        


More information about the flang-commits mailing list