[flang-commits] [flang] ae8dd63 - [flang][cuda] Add interface and lowering for all_sync (#134001)
via flang-commits
flang-commits at lists.llvm.org
Tue Apr 1 17:59:15 PDT 2025
Author: Valentin Clement (バレンタイン クレメン)
Date: 2025-04-01T17:59:11-07:00
New Revision: ae8dd63681bf93b04ff8a29e3cbbd152bd97c5c7
URL: https://github.com/llvm/llvm-project/commit/ae8dd63681bf93b04ff8a29e3cbbd152bd97c5c7
DIFF: https://github.com/llvm/llvm-project/commit/ae8dd63681bf93b04ff8a29e3cbbd152bd97c5c7.diff
LOG: [flang][cuda] Add interface and lowering for all_sync (#134001)
Added:
Modified:
flang/include/flang/Optimizer/Builder/IntrinsicCall.h
flang/lib/Optimizer/Builder/IntrinsicCall.cpp
flang/module/cudadevice.f90
flang/test/Lower/CUDA/cuda-device-proc.cuf
Removed:
################################################################################
diff --git a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
index 83f08bb88f7f3..a31bbd0a1bd88 100644
--- a/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
+++ b/flang/include/flang/Optimizer/Builder/IntrinsicCall.h
@@ -441,6 +441,7 @@ struct IntrinsicLibrary {
fir::ExtendedValue genUbound(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genUnpack(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
fir::ExtendedValue genVerify(mlir::Type, llvm::ArrayRef<fir::ExtendedValue>);
+ mlir::Value genVoteAllSync(mlir::Type, llvm::ArrayRef<mlir::Value>);
/// Implement all conversion functions like DBLE, the first argument is
/// the value to convert. There may be an additional KIND arguments that
diff --git a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
index 8bbec6d6a7535..9029ea69dd5c4 100644
--- a/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
+++ b/flang/lib/Optimizer/Builder/IntrinsicCall.cpp
@@ -260,6 +260,10 @@ static constexpr IntrinsicHandler handlers[]{
&I::genAll,
{{{"mask", asAddr}, {"dim", asValue}}},
/*isElemental=*/false},
+ {"all_sync",
+ &I::genVoteAllSync,
+ {{{"mask", asValue}, {"pred", asValue}}},
+ /*isElemental=*/false},
{"allocated",
&I::genAllocated,
{{{"array", asInquired}, {"scalar", asInquired}}},
@@ -6495,6 +6499,21 @@ IntrinsicLibrary::genMatchAllSync(mlir::Type resultType,
return value;
}
+// ALL_SYNC
+/// Lower a call to the CUDA Fortran `all_sync(mask, pred)` device function to
+/// a `fir.call` of the function named `llvm.nvvm.vote.all.sync`.
+///
+/// \p args must hold exactly two values (the `mask` and `pred` operands); they
+/// are forwarded unchanged as the call operands. \p resultType is not
+/// consulted: the callee is declared with type (i32, i32) -> i32 and the single
+/// result of the call is returned as-is.
+mlir::Value IntrinsicLibrary::genVoteAllSync(mlir::Type resultType,
+                                             llvm::ArrayRef<mlir::Value> args) {
+  assert(args.size() == 2 && "all_sync expects mask and pred operands");
+
+  // NOTE(review): LLVM appears to declare the NVVM vote.all.sync intrinsic as
+  // i1 (i32, i1); confirm that the (i32, i32) -> i32 type used here is
+  // intended and accepted by later lowering stages.
+  llvm::StringRef funcName = "llvm.nvvm.vote.all.sync";
+  mlir::MLIRContext *context = builder.getContext();
+  mlir::Type i32Ty = builder.getI32Type();
+  mlir::FunctionType ftype =
+      mlir::FunctionType::get(context, {i32Ty, i32Ty}, {i32Ty});
+  auto funcOp = builder.createFunction(loc, funcName, ftype);
+  return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
+}
+
// MATCH_ANY_SYNC
mlir::Value
IntrinsicLibrary::genMatchAnySync(mlir::Type resultType,
diff --git a/flang/module/cudadevice.f90 b/flang/module/cudadevice.f90
index baaa112f5d8c2..6b8aa4de74240 100644
--- a/flang/module/cudadevice.f90
+++ b/flang/module/cudadevice.f90
@@ -1015,6 +1015,13 @@ attributes(device) integer function match_any_syncjd(mask, val)
end function
end interface
+ ! ALL_SYNC
+ ! Generic interface with a single device function: takes two integer
+ ! arguments by value (mask, pred) and returns an integer.
+ interface all_sync
+ attributes(device) integer function all_sync(mask, pred)
+ ! NOTE(review): ignore_tkr presumably relaxes argument checking for mask and
+ ! pred as given by the letters in parentheses - confirm against the flang
+ ! directive documentation.
+ !dir$ ignore_tkr(d) mask, (td) pred
+ integer, value :: mask, pred
+ end function
+ end interface
+
! LDCG
interface __ldcg
attributes(device) pure integer(4) function __ldcg_i4(x) bind(c)
diff --git a/flang/test/Lower/CUDA/cuda-device-proc.cuf b/flang/test/Lower/CUDA/cuda-device-proc.cuf
index 617d57d097522..9758107c84031 100644
--- a/flang/test/Lower/CUDA/cuda-device-proc.cuf
+++ b/flang/test/Lower/CUDA/cuda-device-proc.cuf
@@ -296,6 +296,15 @@ end
! CHECK: fir.call @__ldlu_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
! CHECK: fir.call @__ldcv_r8x2_(%{{.*}}, %{{.*}}) fastmath<contract> : (!fir.ref<!fir.array<2xf64>>, !fir.ref<!fir.array<?xf64>>) -> ()
+! Lowering test: a call to all_sync from device code. The values of mask and
+! v32 are irrelevant here; only the generated call is checked.
+attributes(device) subroutine testVote()
+  integer :: a, mask, v32
+  a = all_sync(mask, v32)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPtestvote()
+! CHECK: fir.call @llvm.nvvm.vote.all.sync
+
! CHECK-DAG: func.func private @__ldca_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
! CHECK-DAG: func.func private @__ldcg_i4x4_(!fir.ref<!fir.array<4xi32>>, !fir.ref<!fir.array<4xi32>>)
More information about the flang-commits
mailing list