[flang-commits] [flang] [flang][cuda] Lower syncwarp to NVVM intrinsic (PR #126164)

via flang-commits flang-commits at lists.llvm.org
Thu Feb 6 17:11:07 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/126164.diff


4 Files Affected:

- (modified) flang/include/flang/Optimizer/Builder/IntrinsicCall.h (+1) 
- (modified) flang/lib/Optimizer/Builder/IntrinsicCall.cpp (+13) 
- (modified) flang/module/cudadevice.f90 (+1-1) 
- (modified) flang/test/Lower/CUDA/cuda-device-proc.cuf (+3-3) 


``````````diff
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 32010ae83641e3f..47e8a77fa6aecb3 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -406,6 +406,7 @@ struct IntrinsicLibrary {
   mlir::Value genSyncThreadsAnd(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genSyncThreadsCount(mlir::Type, llvm::ArrayRef<mlir::Value>);
   mlir::Value genSyncThreadsOr(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  void genSyncWarp(llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genSystem(std::optional<mlir::Type>,
                                mlir::ArrayRef<fir::ExtendedValue> args);
   void genSystemClock(llvm::ArrayRef<fir::ExtendedValue>);
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index a6a77dd58677b17..9b684520ec07820 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -680,6 +680,7 @@ static constexpr IntrinsicHandler handlers[]{
     {"syncthreads_and", &I::genSyncThreadsAnd, {}, /*isElemental=*/false},
     {"syncthreads_count", &I::genSyncThreadsCount, {}, /*isElemental=*/false},
     {"syncthreads_or", &I::genSyncThreadsOr, {}, /*isElemental=*/false},
+    {"syncwarp", &I::genSyncWarp, {}, /*isElemental=*/false},
     {"system",
      &I::genSystem,
      {{{"command", asBox}, {"exitstat", asBox, handleDynamicOptional}}},
@@ -7704,6 +7705,18 @@ IntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType,
   return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
 }
 
+// SYNCWARP
+void IntrinsicLibrary::genSyncWarp(llvm::ArrayRef<fir::ExtendedValue> args) {
+  assert(args.size() == 1);
+  constexpr llvm::StringLiteral funcName = "llvm.nvvm.bar.warp.sync";
+  mlir::Value mask = fir::getBase(args[0]);
+  mlir::FunctionType funcType =
+      mlir::FunctionType::get(builder.getContext(), {mask.getType()}, {});
+  auto funcOp = builder.createFunction(loc, funcName, funcType);
+  llvm::SmallVector<mlir::Value> argsList{mask};
+  builder.create<fir::CallOp>(loc, funcOp, argsList);
+}
+
 // SYSTEM
 fir::ExtendedValue
 IntrinsicLibrary::genSystem(std::optional<mlir::Type> resultType,
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index 47526bccd98fe6c..45b9f2c83863835 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -49,7 +49,7 @@ attributes(device) integer function syncthreads_or(value)
   public :: syncthreads_or
 
   interface
-    attributes(device) subroutine syncwarp(mask) bind(c, name='__syncwarp')
+    attributes(device) subroutine syncwarp(mask)
       integer, value :: mask
     end subroutine
   end interface
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index ec825263474c1ee..17a6a1d965640e9 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -47,7 +47,7 @@ end
 
 ! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
 ! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
-! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
+! CHECK: fir.call @llvm.nvvm.bar.warp.sync(%c1{{.*}}) fastmath<contract> : (i32) -> ()
 ! CHECK: fir.call @llvm.nvvm.membar.gl() fastmath<contract> : () -> ()
 ! CHECK: fir.call @llvm.nvvm.membar.cta() fastmath<contract> : () -> ()
 ! CHECK: fir.call @llvm.nvvm.membar.sys() fastmath<contract> : () -> ()
@@ -102,13 +102,13 @@ end
 ! CHECK-LABEL: func.func @_QPhost1()
 ! CHECK: cuf.kernel
 ! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
-! CHECK: fir.call @__syncwarp(%c1{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
+! CHECK: fir.call @llvm.nvvm.bar.warp.sync(%c1{{.*}}) fastmath<contract> : (i32) -> ()
 ! CHECK: fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
 ! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
 ! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
 
 ! CHECK: func.func private @llvm.nvvm.barrier0()
-! CHECK: func.func private @__syncwarp(i32) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
+! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
 ! CHECK: func.func private @llvm.nvvm.membar.gl()
 ! CHECK: func.func private @llvm.nvvm.membar.cta()
 ! CHECK: func.func private @llvm.nvvm.membar.sys()

``````````

</details>


https://github.com/llvm/llvm-project/pull/126164


More information about the flang-commits mailing list