[flang-commits] [flang] 070c888 - [flang][cuda] Lower syncwarp to NVVM intrinsic (#126164)
via flang-commits
flang-commits at lists.llvm.org
Thu Feb 6 19:43:24 PST 2025
Author: Valentin Clement (バレンタイン クレメン)
Date: 2025-02-06T19:43:21-08:00
New Revision: 070c88829251defc268fbfe1c1fe18d2066bdce2
URL: https://github.com/llvm/llvm-project/commit/070c88829251defc268fbfe1c1fe18d2066bdce2
DIFF: https://github.com/llvm/llvm-project/commit/070c88829251defc268fbfe1c1fe18d2066bdce2.diff
LOG: [flang][cuda] Lower syncwarp to NVVM intrinsic (#126164)
Added:
Modified:
flang/include/flang/Optimizer/Builder/IntrinsicCall.h
flang/lib/Optimizer/Builder/IntrinsicCall.cpp
flang/module/cudadevice.f90
flang/test/Lower/CUDA/cuda-device-proc.cuf
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 32010ae83641e3f..47e8a77fa6aecb3 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -406,6 +406,7 @@ struct IntrinsicLibrary {
mlir::Value genSyncThreadsAnd(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genSyncThreadsCount(mlir::Type, llvm::ArrayRef<mlir::Value>);
mlir::Value genSyncThreadsOr(mlir::Type, llvm::ArrayRef<mlir::Value>);
+ void genSyncWarp(llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genSystem(std::optional<mlir::Type>,
mlir::ArrayRef<fir::ExtendedValue> args);
void genSystemClock(llvm::ArrayRef<fir::ExtendedValue>);
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index a6a77dd58677b17..9b684520ec07820 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -680,6 +680,7 @@ static constexpr IntrinsicHandler handlers[]{
{"syncthreads_and", &I::genSyncThreadsAnd, {}, /*isElemental=*/false},
{"syncthreads_count", &I::genSyncThreadsCount, {}, /*isElemental=*/false},
{"syncthreads_or", &I::genSyncThreadsOr, {}, /*isElemental=*/false},
+ {"syncwarp", &I::genSyncWarp, {}, /*isElemental=*/false},
{"system",
&I::genSystem,
{{{"command", asBox}, {"exitstat", asBox, handleDynamicOptional}}},
@@ -7704,6 +7705,18 @@ IntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType,
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
}
+// SYNCWARP: emit fir.call @llvm.nvvm.bar.warp.sync(mask); produces no value.
+void IntrinsicLibrary::genSyncWarp(llvm::ArrayRef<fir::ExtendedValue> args) {
+ assert(args.size() == 1); // single argument: the warp mask
+ constexpr llvm::StringLiteral funcName = "llvm.nvvm.bar.warp.sync";
+ mlir::Value mask = fir::getBase(args[0]); // integer, passed by value
+ mlir::FunctionType funcType =
+ mlir::FunctionType::get(builder.getContext(), {mask.getType()}, {}); // (mask) -> ()
+ auto funcOp = builder.createFunction(loc, funcName, funcType); // private decl
+ llvm::SmallVector<mlir::Value> argsList{mask};
+ builder.create<fir::CallOp>(loc, funcOp, argsList); // call has no results
+}
+
// SYSTEM
fir::ExtendedValue
IntrinsicLibrary::genSystem(std::optional<mlir::Type> resultType,
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index 47526bccd98fe6c..45b9f2c83863835 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -49,7 +49,7 @@ attributes(device) integer function syncthreads_or(value)
public :: syncthreads_or
interface
- attributes(device) subroutine syncwarp(mask) bind(c, name='__syncwarp')
+ attributes(device) subroutine syncwarp(mask) ! no bind(c): flang lowers this to llvm.nvvm.bar.warp.sync
integer, value :: mask
end subroutine
end interface
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index ec825263474c1ee..17a6a1d965640e9 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -47,7 +47,7 @@ end
! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
-! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
+! CHECK: fir.call @llvm.nvvm.bar.warp.sync(%c1{{.*}}) fastmath<contract> : (i32) -> ()
! CHECK: fir.call @llvm.nvvm.membar.gl() fastmath<contract> : () -> ()
! CHECK: fir.call @llvm.nvvm.membar.cta() fastmath<contract> : () -> ()
! CHECK: fir.call @llvm.nvvm.membar.sys() fastmath<contract> : () -> ()
@@ -102,13 +102,13 @@ end
! CHECK-LABEL: func.func @_QPhost1()
! CHECK: cuf.kernel
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
-! CHECK: fir.call @__syncwarp(%c1{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
+! CHECK: fir.call @llvm.nvvm.bar.warp.sync(%c1{{.*}}) fastmath<contract> : (i32) -> ()
! CHECK: fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
! CHECK: func.func private @llvm.nvvm.barrier0()
-! CHECK: func.func private @__syncwarp(i32) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
+! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
! CHECK: func.func private @llvm.nvvm.membar.gl()
! CHECK: func.func private @llvm.nvvm.membar.cta()
! CHECK: func.func private @llvm.nvvm.membar.sys()
More information about the flang-commits mailing list