[flang-commits] [flang] WIP: Trying to get a call treated as an intrinsic (PR #120020)

Renaud Kauffmann via flang-commits flang-commits at lists.llvm.org
Tue Dec 17 17:44:24 PST 2024


https://github.com/Renaud-K updated https://github.com/llvm/llvm-project/pull/120020

>From b7a6883a8d483712a2f2ab33e39052ffe9f497ae Mon Sep 17 00:00:00 2001
From: Renaud-K <rkauffmann at nvidia.com>
Date: Tue, 17 Dec 2024 17:43:50 -0800
Subject: [PATCH] Using NVVM intrinsics for the syncthreads and threadfence
 families of calls

---
 .../flang/Optimizer/Builder/IntrinsicCall.h   |  7 ++
 flang/lib/Optimizer/Builder/IntrinsicCall.cpp | 85 +++++++++++++++++++
 flang/module/cudadevice.f90                   | 14 +--
 flang/test/Lower/CUDA/cuda-device-proc.cuf    |  8 +-
 4 files changed, 103 insertions(+), 11 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index bc0020e614db24..c16c717779a61e 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -392,6 +392,10 @@ struct IntrinsicLibrary {
   fir::ExtendedValue genSum(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   void genSignalSubroutine(llvm::ArrayRef<fir::ExtendedValue>);
   void genSleep(llvm::ArrayRef<fir::ExtendedValue>);
+  void genSyncThreads(llvm::ArrayRef<fir::ExtendedValue>);
+  mlir::Value genSyncThreadsAnd(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genSyncThreadsCount(mlir::Type, llvm::ArrayRef<mlir::Value>);
+  mlir::Value genSyncThreadsOr(mlir::Type, llvm::ArrayRef<mlir::Value>);
   fir::ExtendedValue genSystem(std::optional<mlir::Type>,
                                mlir::ArrayRef<fir::ExtendedValue> args);
   void genSystemClock(llvm::ArrayRef<fir::ExtendedValue>);
@@ -401,6 +405,9 @@ struct IntrinsicLibrary {
                                  llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genTranspose(mlir::Type,
                                   llvm::ArrayRef<fir::ExtendedValue>);
+  void genThreadFence(llvm::ArrayRef<fir::ExtendedValue>);
+  void genThreadFenceBlock(llvm::ArrayRef<fir::ExtendedValue>);
+  void genThreadFenceSystem(llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genTrim(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genUbound(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
   fir::ExtendedValue genUnpack(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 547cebefd2df47..fe449d95c1605f 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -642,6 +642,10 @@ static constexpr IntrinsicHandler handlers[]{
        {"dim", asValue},
        {"mask", asBox, handleDynamicOptional}}},
      /*isElemental=*/false},
+    {"syncthreads", &I::genSyncThreads, {}, /*isElemental=*/false},
+    {"syncthreads_and", &I::genSyncThreadsAnd, {}, /*isElemental=*/false},
+    {"syncthreads_count", &I::genSyncThreadsCount, {}, /*isElemental=*/false},
+    {"syncthreads_or", &I::genSyncThreadsOr, {}, /*isElemental=*/false},
     {"system",
      &I::genSystem,
      {{{"command", asBox}, {"exitstat", asBox, handleDynamicOptional}}},
@@ -660,6 +664,9 @@ static constexpr IntrinsicHandler handlers[]{
      &I::genTranspose,
      {{{"matrix", asAddr}}},
      /*isElemental=*/false},
+    {"threadfence", &I::genThreadFence, {}, /*isElemental=*/false},
+    {"threadfence_block", &I::genThreadFenceBlock, {}, /*isElemental=*/false},
+    {"threadfence_system", &I::genThreadFenceSystem, {}, /*isElemental=*/false},
     {"trim", &I::genTrim, {{{"string", asAddr}}}, /*isElemental=*/false},
     {"ubound",
      &I::genUbound,
@@ -7290,6 +7297,52 @@ IntrinsicLibrary::genSum(mlir::Type resultType,
                       resultType, args);
 }
 
+// SYNCTHREADS: emit a call to llvm.nvvm.barrier0 (no operands, no results).
+void IntrinsicLibrary::genSyncThreads(llvm::ArrayRef<fir::ExtendedValue> args) {
+  constexpr llvm::StringLiteral intrinsicName = "llvm.nvvm.barrier0";
+  mlir::MLIRContext *ctx = builder.getContext();
+  mlir::FunctionType voidFnTy = mlir::FunctionType::get(ctx, {}, {});
+  auto callee = builder.createFunction(loc, intrinsicName, voidFnTy);
+  llvm::SmallVector<mlir::Value> noOperands;
+  builder.create<fir::CallOp>(loc, callee, noOperands);
+}
+
+// SYNCTHREADS_AND: returns the result of llvm.nvvm.barrier0.and(args[0]).
+mlir::Value
+IntrinsicLibrary::genSyncThreadsAnd(mlir::Type resultType,
+                                    llvm::ArrayRef<mlir::Value> args) {
+  constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.and";
+  mlir::MLIRContext *context = builder.getContext();
+  mlir::FunctionType ftype = mlir::FunctionType::get(
+      context, /*inputs=*/{args[0].getType()}, /*results=*/{resultType});
+  auto funcOp = builder.createFunction(loc, funcName, ftype);
+  return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
+}
+
+// SYNCTHREADS_COUNT: returns the result of llvm.nvvm.barrier0.popc(args[0]).
+mlir::Value
+IntrinsicLibrary::genSyncThreadsCount(mlir::Type resultType,
+                                      llvm::ArrayRef<mlir::Value> args) {
+  constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.popc";
+  mlir::MLIRContext *context = builder.getContext();
+  mlir::FunctionType ftype = mlir::FunctionType::get(
+      context, /*inputs=*/{args[0].getType()}, /*results=*/{resultType});
+  auto funcOp = builder.createFunction(loc, funcName, ftype);
+  return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
+}
+
+// SYNCTHREADS_OR: returns the result of llvm.nvvm.barrier0.or(args[0]).
+mlir::Value
+IntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType,
+                                   llvm::ArrayRef<mlir::Value> args) {
+  constexpr llvm::StringLiteral funcName = "llvm.nvvm.barrier0.or";
+  mlir::MLIRContext *context = builder.getContext();
+  mlir::FunctionType ftype = mlir::FunctionType::get(
+      context, /*inputs=*/{args[0].getType()}, /*results=*/{resultType});
+  auto funcOp = builder.createFunction(loc, funcName, ftype);
+  return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
+}
+
 // SYSTEM
 fir::ExtendedValue
 IntrinsicLibrary::genSystem(std::optional<mlir::Type> resultType,
@@ -7420,6 +7473,38 @@ IntrinsicLibrary::genTranspose(mlir::Type resultType,
   return readAndAddCleanUp(resultMutableBox, resultType, "TRANSPOSE");
 }
 
+// THREADFENCE: emit a call to llvm.nvvm.membar.gl (no operands, no results).
+void IntrinsicLibrary::genThreadFence(llvm::ArrayRef<fir::ExtendedValue> args) {
+  constexpr llvm::StringLiteral intrinsicName = "llvm.nvvm.membar.gl";
+  mlir::MLIRContext *ctx = builder.getContext();
+  mlir::FunctionType voidFnTy = mlir::FunctionType::get(ctx, {}, {});
+  auto callee = builder.createFunction(loc, intrinsicName, voidFnTy);
+  llvm::SmallVector<mlir::Value> noOperands;
+  builder.create<fir::CallOp>(loc, callee, noOperands);
+}
+
+// THREADFENCE_BLOCK: emit a call to llvm.nvvm.membar.cta (no operands/results).
+void IntrinsicLibrary::genThreadFenceBlock(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  constexpr llvm::StringLiteral intrinsicName = "llvm.nvvm.membar.cta";
+  mlir::MLIRContext *ctx = builder.getContext();
+  mlir::FunctionType voidFnTy = mlir::FunctionType::get(ctx, {}, {});
+  auto callee = builder.createFunction(loc, intrinsicName, voidFnTy);
+  llvm::SmallVector<mlir::Value> noOperands;
+  builder.create<fir::CallOp>(loc, callee, noOperands);
+}
+
+// THREADFENCE_SYSTEM: emit a call to llvm.nvvm.membar.sys (no operands/results).
+void IntrinsicLibrary::genThreadFenceSystem(
+    llvm::ArrayRef<fir::ExtendedValue> args) {
+  constexpr llvm::StringLiteral intrinsicName = "llvm.nvvm.membar.sys";
+  mlir::MLIRContext *ctx = builder.getContext();
+  mlir::FunctionType voidFnTy = mlir::FunctionType::get(ctx, {}, {});
+  auto callee = builder.createFunction(loc, intrinsicName, voidFnTy);
+  llvm::SmallVector<mlir::Value> noOperands;
+  builder.create<fir::CallOp>(loc, callee, noOperands);
+}
+
 // TRIM
 fir::ExtendedValue
 IntrinsicLibrary::genTrim(mlir::Type resultType,
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index 0224ecfdde7c60..1402bd4e150419 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -18,27 +18,27 @@ module cudadevice
   ! Synchronization Functions
 
   interface
-    attributes(device) subroutine syncthreads() bind(c, name='__syncthreads')
+    attributes(device) subroutine syncthreads()
     end subroutine
   end interface
   public :: syncthreads
 
   interface
-    attributes(device) integer function syncthreads_and(value) bind(c, name='__syncthreads_and')
+    attributes(device) integer function syncthreads_and(value)
       integer :: value
     end function
   end interface
   public :: syncthreads_and
 
   interface
-    attributes(device) integer function syncthreads_count(value) bind(c, name='__syncthreads_count')
+    attributes(device) integer function syncthreads_count(value)
       integer :: value
     end function
   end interface
   public :: syncthreads_count
 
   interface
-    attributes(device) integer function syncthreads_or(value) bind(c, name='__syncthreads_or')
+    attributes(device) integer function syncthreads_or(value)
       integer :: value
     end function
   end interface
@@ -54,19 +54,19 @@ attributes(device) subroutine syncwarp(mask) bind(c, name='__syncwarp')
   ! Memory Fences
 
   interface
-    attributes(device) subroutine threadfence() bind(c, name='__threadfence')
+    attributes(device) subroutine threadfence()
     end subroutine
   end interface
   public :: threadfence
 
   interface
-    attributes(device) subroutine threadfence_block() bind(c, name='__threadfence_block')
+    attributes(device) subroutine threadfence_block()
     end subroutine
   end interface
   public :: threadfence_block
 
   interface
-    attributes(device) subroutine threadfence_system() bind(c, name='__threadfence_system')
+    attributes(device) subroutine threadfence_system()
     end subroutine
   end interface
   public :: threadfence_system
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 1331b644130c87..1951fb1d43c817 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -17,14 +17,14 @@ attributes(global) subroutine devsub()
 end
 
 ! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
-! CHECK: fir.call @__syncthreads()
+! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
 ! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> ()
 ! CHECK: fir.call @__threadfence()
 ! CHECK: fir.call @__threadfence_block()
 ! CHECK: fir.call @__threadfence_system()
-! CHECK: %{{.*}} = fir.call @__syncthreads_and(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32
-! CHECK: %{{.*}} = fir.call @__syncthreads_count(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32
-! CHECK: %{{.*}} = fir.call @__syncthreads_or(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (!fir.ref<i32>) -> i32
+! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.and(%{{.*}}) fastmath<contract> : (i32) -> i32
+! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.popc(%{{.*}}) fastmath<contract> : (i32) -> i32
+! CHECK: %{{.*}} = fir.call @llvm.nvvm.barrier0.or(%{{.*}}) fastmath<contract> : (i32) -> i32
 
 ! CHECK: func.func private @__syncthreads() attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncthreads", fir.proc_attrs = #fir.proc_attrs<bind_c>}
 ! CHECK: func.func private @__syncwarp(!fir.ref<i32> {cuf.data_attr = #cuf.cuda<device>}) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}



More information about the flang-commits mailing list