[flang-commits] [flang] 91f9f0f - [flang][cuda] Update cuf.kernel_launch stream and conversion (#136179)

via flang-commits flang-commits at lists.llvm.org
Thu Apr 17 12:55:11 PDT 2025


Author: Valentin Clement (バレンタイン クレメン)
Date: 2025-04-17T12:55:08-07:00
New Revision: 91f9f0fa1e3a776006fcbdb9a9c975682f35b10b

URL: https://github.com/llvm/llvm-project/commit/91f9f0fa1e3a776006fcbdb9a9c975682f35b10b
DIFF: https://github.com/llvm/llvm-project/commit/91f9f0fa1e3a776006fcbdb9a9c975682f35b10b.diff

LOG: [flang][cuda] Update cuf.kernel_launch stream and conversion (#136179)

Update `cuf.kernel_launch` to take the stream as a reference
(`!fir.ref<i64>`) instead of an integer value. Update the conversion to
insert a `cuf.stream_cast` op so that the stream can be set as an async
dependency of the generated `gpu.launch_func`.

Added: 
    

Modified: 
    flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
    flang/lib/Lower/ConvertCall.cpp
    flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
    flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
    flang/test/Fir/CUDA/cuda-launch.fir
    flang/test/Lower/CUDA/cuda-kernel-calls.cuf

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
index f55f3e8a4466d..ccf9969e73a8e 100644
--- a/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
+++ b/flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
@@ -200,7 +200,7 @@ def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface,
 
   let arguments = (ins SymbolRefAttr:$callee, I32:$grid_x, I32:$grid_y,
       I32:$grid_z, I32:$block_x, I32:$block_y, I32:$block_z,
-      Optional<I32>:$bytes, Optional<AnyIntegerType>:$stream,
+      Optional<I32>:$bytes, Optional<fir_ReferenceType>:$stream,
       Variadic<AnyType>:$args, OptionalAttr<DictArrayAttr>:$arg_attrs,
       OptionalAttr<DictArrayAttr>:$res_attrs);
 
@@ -237,6 +237,8 @@ def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface,
           *this, getNbNoArgOperand(), getArgs().size() - 1);
     }
   }];
+
+  let hasVerifier = 1;
 }
 
 def cuf_KernelOp : cuf_Op<"kernel", [AttrSizedOperandSegments,

diff  --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index 31f2650917781..f28778ce6c1c9 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -589,7 +589,7 @@ Fortran::lower::genCallOpAndResult(
 
     mlir::Value stream; // stream is optional.
     if (caller.getCallDescription().chevrons().size() > 3)
-      stream = fir::getBase(converter.genExprValue(
+      stream = fir::getBase(converter.genExprAddr(
           caller.getCallDescription().chevrons()[3], stmtCtx));
 
     builder.create<cuf::KernelLaunchOp>(

diff  --git a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
index ce197d48d4860..2c6d22f6f6c7d 100644
--- a/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
+++ b/flang/lib/Optimizer/Dialect/CUF/CUFOps.cpp
@@ -139,6 +139,24 @@ llvm::LogicalResult cuf::DeallocateOp::verify() {
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// KernelLaunchOp
+//===----------------------------------------------------------------------===//
+
+template <typename OpTy>
+static llvm::LogicalResult checkStreamType(OpTy op) {
+  if (!op.getStream())
+    return mlir::success();
+  auto refTy = mlir::dyn_cast<fir::ReferenceType>(op.getStream().getType());
+  if (!refTy.getEleTy().isInteger(64))
+    return op.emitOpError("stream is expected to be a i64 reference");
+  return mlir::success();
+}
+
+llvm::LogicalResult cuf::KernelLaunchOp::verify() {
+  return checkStreamType(*this);
+}
+
 //===----------------------------------------------------------------------===//
 // KernelOp
 //===----------------------------------------------------------------------===//
@@ -324,10 +342,7 @@ void cuf::SharedMemoryOp::build(
 //===----------------------------------------------------------------------===//
 
 llvm::LogicalResult cuf::StreamCastOp::verify() {
-  auto refTy = mlir::dyn_cast<fir::ReferenceType>(getStream().getType());
-  if (!refTy.getEleTy().isInteger(64))
-    return emitOpError("stream is expected to be a i64 reference");
-  return mlir::success();
+  return checkStreamType(*this);
 }
 
 // Tablegen operators

diff  --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
index caa59c6c17d0f..77364cb837c3c 100644
--- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
+++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
@@ -879,8 +879,13 @@ struct CUFLaunchOpConversion
       gpuLaunchOp.getClusterSizeYMutable().assign(clusterDimY);
       gpuLaunchOp.getClusterSizeZMutable().assign(clusterDimZ);
     }
-    if (op.getStream())
-      gpuLaunchOp.getAsyncObjectMutable().assign(op.getStream());
+    if (op.getStream()) {
+      mlir::OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.setInsertionPoint(gpuLaunchOp);
+      mlir::Value stream =
+          rewriter.create<cuf::StreamCastOp>(loc, op.getStream());
+      gpuLaunchOp.getAsyncDependenciesMutable().append(stream);
+    }
     if (procAttr)
       gpuLaunchOp->setAttr(cuf::getProcAttrName(), procAttr);
     rewriter.replaceOp(op, gpuLaunchOp);
@@ -933,6 +938,7 @@ class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
                                          /*forceUnifiedTBAATree=*/false, *dl);
     target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
                            mlir::gpu::GPUDialect>();
+    target.addLegalOp<cuf::StreamCastOp>();
     cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab,
                                             patterns);
     if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,

diff  --git a/flang/test/Fir/CUDA/cuda-launch.fir b/flang/test/Fir/CUDA/cuda-launch.fir
index 621772efff415..319991546d3fe 100644
--- a/flang/test/Fir/CUDA/cuda-launch.fir
+++ b/flang/test/Fir/CUDA/cuda-launch.fir
@@ -146,8 +146,7 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e
     %1:2 = hlfir.declare %0 {uniq_name = "_QMtest_callFhostEstream"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
     %c1_i32 = arith.constant 1 : i32
     %c0_i32 = arith.constant 0 : i32
-    %2 = fir.load %1#0 : !fir.ref<i64>
-    cuf.kernel_launch @_QMdevptrPtest<<<%c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c0_i32, %2 : i64>>>()
+    cuf.kernel_launch @_QMdevptrPtest<<<%c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c1_i32, %c0_i32, %1#0 : !fir.ref<i64>>>>()
     return
   }
 }
@@ -155,5 +154,5 @@ module attributes {gpu.container_module, dlti.dl_spec = #dlti.dl_spec<#dlti.dl_e
 // CHECK-LABEL: func.func @_QQmain()
 // CHECK: %[[STREAM:.*]] = fir.alloca i64 {bindc_name = "stream", uniq_name = "_QMtest_callFhostEstream"}
 // CHECK: %[[DECL_STREAM:.*]]:2 = hlfir.declare %[[STREAM]] {uniq_name = "_QMtest_callFhostEstream"} : (!fir.ref<i64>) -> (!fir.ref<i64>, !fir.ref<i64>)
-// CHECK: %[[STREAM_LOADED:.*]] = fir.load %[[DECL_STREAM]]#0 : !fir.ref<i64>
-// CHECK: gpu.launch_func <%[[STREAM_LOADED]] : i64> @cuda_device_mod::@_QMdevptrPtest
+// CHECK: %[[TOKEN:.*]] = cuf.stream_cast %[[DECL_STREAM]]#0 : <i64>
+// CHECK: gpu.launch_func [%[[TOKEN]]] @cuda_device_mod::@_QMdevptrPtest

diff  --git a/flang/test/Lower/CUDA/cuda-kernel-calls.cuf b/flang/test/Lower/CUDA/cuda-kernel-calls.cuf
index d66d2811f7a8b..71e594e4742ec 100644
--- a/flang/test/Lower/CUDA/cuda-kernel-calls.cuf
+++ b/flang/test/Lower/CUDA/cuda-kernel-calls.cuf
@@ -45,8 +45,8 @@ contains
     call dev_kernel0<<<10, 20, 2>>>()
 ! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}>>>() 
 
-    call dev_kernel0<<<10, 20, 2, 0>>>()
-! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}, %c0{{.*}}>>>()
+    call dev_kernel0<<<10, 20, 2, 0_8>>>()
+! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel0<<<%c10{{.*}}, %c1{{.*}}, %c1{{.*}}, %c20{{.*}}, %c1{{.*}}, %c1{{.*}}, %c2{{.*}}, %{{.*}} : !fir.ref<i64>>>>()
 
     call dev_kernel1<<<1, 32>>>(a)
 ! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%{{.*}}) : (!fir.ref<f32>)
@@ -55,7 +55,7 @@ contains
 ! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c-1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%{{.*}})
 
     call dev_kernel1<<<*,32,0,stream>>>(a)
-! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c-1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}, %c0{{.*}}, %{{.*}} : i64>>>(%{{.*}}) : (!fir.ref<f32>)
+! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c-1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}, %c0{{.*}}, %{{.*}} : !fir.ref<i64>>>>(%{{.*}}) : (!fir.ref<f32>)
 
   end
 


        


More information about the flang-commits mailing list