[flang] [llvm] [flang][cuda] Lower set/get default stream (PR #181775)
via llvm-commits
llvm-commits at lists.llvm.org
Mon Feb 16 21:49:21 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-fir-hlfir
Author: Valentin Clement (バレンタイン クレメン) (clementval)
<details>
<summary>Changes</summary>
---
Full diff: https://github.com/llvm/llvm-project/pull/181775.diff
4 Files Affected:
- (modified) flang-rt/unittests/Runtime/CUDA/DefaultStream.cpp (+15)
- (modified) flang/include/flang/Optimizer/Builder/CUDAIntrinsicCall.h (+4)
- (modified) flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp (+37)
- (modified) flang/module/cuda_runtime_api.f90 (+2-2)
``````````diff
diff --git a/flang-rt/unittests/Runtime/CUDA/DefaultStream.cpp b/flang-rt/unittests/Runtime/CUDA/DefaultStream.cpp
index af04487c59d02..7e255905ceb1d 100644
--- a/flang-rt/unittests/Runtime/CUDA/DefaultStream.cpp
+++ b/flang-rt/unittests/Runtime/CUDA/DefaultStream.cpp
@@ -25,3 +25,18 @@ TEST(DefaultStreamTest, GetAndSetTest) {
cudaStream_t outStream = RTDECL(CUFGetDefaultStream)();
EXPECT_EQ(outStream, stream);
}
+
+// Exercises a set/get round trip of the runtime's default stream.
+// The test establishes its own starting state and restores it on exit. An
+// earlier test in this file (GetAndSetTest) installs a non-null default stream
+// without resetting it, so assuming an initial nullptr would make the result
+// depend on test execution order.
+TEST(DefaultStreamTest, GetAndSetArrayTest) {
+  // Start from a known state: no default stream installed.
+  RTDECL(CUFSetDefaultStream)(nullptr);
+  EXPECT_EQ(RTDECL(CUFGetDefaultStream)(), nullptr);
+
+  cudaStream_t stream;
+  cudaStreamCreate(&stream);
+  EXPECT_EQ(cudaSuccess, cudaGetLastError());
+  RTDECL(CUFSetDefaultStream)(stream);
+  cudaStream_t outStream = RTDECL(CUFGetDefaultStream)();
+  EXPECT_EQ(outStream, stream);
+
+  // Restore the default and release the stream so later tests are unaffected.
+  RTDECL(CUFSetDefaultStream)(nullptr);
+  cudaStreamDestroy(stream);
+}
diff --git a/flang/include/flang/Optimizer/Builder/CUDAIntrinsicCall.h b/flang/include/flang/Optimizer/Builder/CUDAIntrinsicCall.h
index d92f0c72dde0d..3e23a4dfa0203 100644
--- a/flang/include/flang/Optimizer/Builder/CUDAIntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/CUDAIntrinsicCall.h
@@ -51,12 +51,16 @@ struct CUDAIntrinsicLibrary : IntrinsicLibrary {
mlir::Value genBarrierTryWaitSleep(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genClusterBlockIndex(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genClusterDimBlocks(mlir::Type, llvm::ArrayRef<mlir::Value>);
+ fir::ExtendedValue
+ genCUDASetDefaultStream(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue
genCUDASetDefaultStreamArray(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue
genCUDAGetDefaultStreamArg(mlir::Type,
llvm::ArrayRef<fir::ExtendedValue>);
+ mlir::Value genCUDAGetDefaultStreamNull(mlir::Type,
+ llvm::ArrayRef<mlir::Value>);
void genFenceProxyAsync(llvm::ArrayRef<fir::ExtendedValue>);
template <const char *fctName, int extent>
fir::ExtendedValue genLDXXFunc(mlir::Type,
diff --git a/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
index 4c4403dcd71a9..4986c57048081 100644
--- a/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/CUDAIntrinsicCall.cpp
@@ -388,11 +388,21 @@ static constexpr IntrinsicHandler cudaHandlers[]{
&CI::genCUDAGetDefaultStreamArg),
{{{"devptr", asAddr}}},
/*isElemental=*/false},
+ {"cudagetstreamdefaultnull",
+ static_cast<CUDAIntrinsicLibrary::ElementalGenerator>(
+ &CI::genCUDAGetDefaultStreamNull),
+ {},
+ /*isElemental=*/false},
{"cudasetstreamarray",
static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
&CI::genCUDASetDefaultStreamArray),
{{{"devptr", asAddr}, {"stream", asValue}}},
/*isElemental=*/false},
+ {"cudasetstreamdefault",
+ static_cast<CUDAIntrinsicLibrary::ExtendedGenerator>(
+ &CI::genCUDASetDefaultStream),
+ {{{"stream", asValue}}},
+ /*isElemental=*/false},
{"fence_proxy_async",
static_cast<CUDAIntrinsicLibrary::SubroutineGenerator>(
&CI::genFenceProxyAsync),
@@ -1114,6 +1124,20 @@ CUDAIntrinsicLibrary::genClusterDimBlocks(mlir::Type resultType,
return res;
}
+// CUDASETSTREAMDEFAULT
+// Lowers `cudasetstreamdefault(stream)` to a call to the CUFSetDefaultStream
+// runtime entry point and returns the call's integer result to Fortran.
+fir::ExtendedValue CUDAIntrinsicLibrary::genCUDASetDefaultStream(
+    mlir::Type resTy, llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 1 && "cudasetstreamdefault takes one argument");
+  mlir::Type i32Ty = builder.getI32Type();
+  mlir::Type i64Ty = builder.getI64Type();
+  // The Fortran interface uses IGNORE_TKR (K) on `stream`, so the actual
+  // argument may have any integer kind; normalize it to the i64 parameter of
+  // the runtime signature below.
+  mlir::Value stream =
+      builder.createConvert(loc, i64Ty, fir::getBase(args[0]));
+  auto ctx = builder.getContext();
+  // The Fortran interface is an `integer function`, so the runtime call must
+  // declare a result for getResult(0) below to be valid.
+  // NOTE(review): assumes the runtime entry point returns a 32-bit int status
+  // matching the default-integer result; confirm against flang-rt.
+  mlir::FunctionType ftype = mlir::FunctionType::get(ctx, {i64Ty}, {i32Ty});
+  auto funcOp =
+      builder.createFunction(loc, RTNAME_STRING(CUFSetDefaultStream), ftype);
+  auto call = fir::CallOp::create(builder, loc, funcOp, {stream});
+  return call.getResult(0);
+}
+
// CUDASETSTREAMARRAY
fir::ExtendedValue CUDAIntrinsicLibrary::genCUDASetDefaultStreamArray(
mlir::Type resTy, llvm::ArrayRef<fir::ExtendedValue> args) {
@@ -1154,6 +1178,19 @@ fir::ExtendedValue CUDAIntrinsicLibrary::genCUDAGetDefaultStreamArg(
return call.getResult(0);
}
+// CUDAGETSTREAMDEFAULTNULL
+// Lowers `cudagetstreamdefaultnull()` to a call to the CUFGetDefaultStream
+// runtime entry point. The call returns the current default stream handle as
+// an i64; the Fortran result is declared integer(kind=cuda_stream_kind).
+mlir::Value CUDAIntrinsicLibrary::genCUDAGetDefaultStreamNull(
+    mlir::Type resultType, llvm::ArrayRef<mlir::Value> args) {
+  assert(args.empty() && "cudagetstreamdefaultnull takes no arguments");
+  mlir::Type i64Ty = builder.getI64Type();
+  mlir::FunctionType ftype =
+      mlir::FunctionType::get(builder.getContext(), {}, {i64Ty});
+  auto funcOp =
+      builder.createFunction(loc, RTNAME_STRING(CUFGetDefaultStream), ftype);
+  auto call = fir::CallOp::create(builder, loc, funcOp, {});
+  return call.getResult(0);
+}
+
// FENCE_PROXY_ASYNC
void CUDAIntrinsicLibrary::genFenceProxyAsync(
llvm::ArrayRef<fir::ExtendedValue> args) {
diff --git a/flang/module/cuda_runtime_api.f90 b/flang/module/cuda_runtime_api.f90
index d6cb6d8c0f715..7c6968cabc373 100644
--- a/flang/module/cuda_runtime_api.f90
+++ b/flang/module/cuda_runtime_api.f90
@@ -17,13 +17,13 @@ integer(kind=cuda_stream_kind) function cudagetstreamdefaultarg(devptr)
!DIR$ IGNORE_TKR (TKR) devptr
integer, device :: devptr(*)
end function
- integer(kind=cuda_stream_kind) function cudastreamgetdefaultnull()
+ integer(kind=cuda_stream_kind) function cudagetstreamdefaultnull()
import cuda_stream_kind
end function
end interface
interface cudaforsetdefaultstream
- integer function cudasetdefaultstream(stream)
+ integer function cudasetstreamdefault(stream)
import cuda_stream_kind
!DIR$ IGNORE_TKR (K) stream
integer(kind=cuda_stream_kind), value :: stream
``````````
</details>
https://github.com/llvm/llvm-project/pull/181775
More information about the llvm-commits mailing list