[flang-commits] [flang] [flang][cuda] Add option to disable warp function in semantics (PR #143640)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Tue Jun 10 19:03:15 PDT 2025
https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/143640
>From 86bc1a8dc17df981eb05556e83e2077526186d1d Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 10 Jun 2025 17:27:18 -0700
Subject: [PATCH 1/2] [flang][cuda] Add option to disable warp function in
semantics
---
.../include/flang/Support/Fortran-features.h | 2 +-
flang/lib/Semantics/check-cuda.cpp | 125 ++++++++++++------
flang/tools/bbc/bbc.cpp | 10 ++
3 files changed, 93 insertions(+), 44 deletions(-)
diff --git a/flang/include/flang/Support/Fortran-features.h b/flang/include/flang/Support/Fortran-features.h
index 3f6d825e2b66c..ea0845b7d605f 100644
--- a/flang/include/flang/Support/Fortran-features.h
+++ b/flang/include/flang/Support/Fortran-features.h
@@ -55,7 +55,7 @@ ENUM_CLASS(LanguageFeature, BackslashEscapes, OldDebugLines,
SavedLocalInSpecExpr, PrintNamelist, AssumedRankPassedToNonAssumedRank,
IgnoreIrrelevantAttributes, Unsigned, AmbiguousStructureConstructor,
ContiguousOkForSeqAssociation, ForwardRefExplicitTypeDummy,
- InaccessibleDeferredOverride)
+ InaccessibleDeferredOverride, CudaWarpMatchFunction)
// Portability and suspicious usage warnings
ENUM_CLASS(UsageWarning, Portability, PointerToUndefinable,
diff --git a/flang/lib/Semantics/check-cuda.cpp b/flang/lib/Semantics/check-cuda.cpp
index c024640af1220..8decfb0149829 100644
--- a/flang/lib/Semantics/check-cuda.cpp
+++ b/flang/lib/Semantics/check-cuda.cpp
@@ -17,6 +17,7 @@
#include "flang/Semantics/expression.h"
#include "flang/Semantics/symbol.h"
#include "flang/Semantics/tools.h"
+#include "llvm/ADT/StringSet.h"
// Once labeled DO constructs have been canonicalized and their parse subtrees
// transformed into parser::DoConstructs, scan the parser::Blocks of the program
@@ -61,6 +62,11 @@ bool CanonicalizeCUDA(parser::Program &program) {
using MaybeMsg = std::optional<parser::MessageFormattedText>;
+static const llvm::StringSet<> warpFunctions_ = {"match_all_syncjj",
+ "match_all_syncjx", "match_all_syncjf", "match_all_syncjd",
+ "match_any_syncjj", "match_any_syncjx", "match_any_syncjf",
+ "match_any_syncjd"};
+
// Traverses an evaluate::Expr<> in search of unsupported operations
// on the device.
@@ -68,7 +74,7 @@ struct DeviceExprChecker
: public evaluate::AnyTraverse<DeviceExprChecker, MaybeMsg> {
using Result = MaybeMsg;
using Base = evaluate::AnyTraverse<DeviceExprChecker, Result>;
- DeviceExprChecker() : Base(*this) {}
+ explicit DeviceExprChecker(SemanticsContext &c) : Base(*this), context_{c} {}
using Base::operator();
Result operator()(const evaluate::ProcedureDesignator &x) const {
if (const Symbol * sym{x.GetInterfaceSymbol()}) {
@@ -78,10 +84,17 @@ struct DeviceExprChecker
if (auto attrs{subp->cudaSubprogramAttrs()}) {
if (*attrs == common::CUDASubprogramAttrs::HostDevice ||
*attrs == common::CUDASubprogramAttrs::Device) {
+ if (warpFunctions_.contains(sym->name().ToString()) &&
+ !context_.languageFeatures().IsEnabled(
+ Fortran::common::LanguageFeature::CudaWarpMatchFunction)) {
+ return parser::MessageFormattedText(
+ "warp match function disabled"_err_en_US);
+ }
return {};
}
}
}
+
const Symbol &ultimate{sym->GetUltimate()};
const Scope &scope{ultimate.owner()};
const Symbol *mod{scope.IsModule() ? scope.symbol() : nullptr};
@@ -94,9 +107,12 @@ struct DeviceExprChecker
// TODO(CUDA): Check for unsupported intrinsics here
return {};
}
+
return parser::MessageFormattedText(
"'%s' may not be called in device code"_err_en_US, x.GetName());
}
+
+ SemanticsContext &context_;
};
struct FindHostArray
@@ -133,9 +149,10 @@ struct FindHostArray
}
};
-template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) {
+template <typename A>
+static MaybeMsg CheckUnwrappedExpr(SemanticsContext &context, const A &x) {
if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
- return DeviceExprChecker{}(expr->typedExpr);
+ return DeviceExprChecker{context}(expr->typedExpr);
}
return {};
}
@@ -144,104 +161,124 @@ template <typename A>
static void CheckUnwrappedExpr(
SemanticsContext &context, SourceName at, const A &x) {
if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
- if (auto msg{DeviceExprChecker{}(expr->typedExpr)}) {
+ if (auto msg{DeviceExprChecker{context}(expr->typedExpr)}) {
context.Say(at, std::move(*msg));
}
}
}
template <bool CUF_KERNEL> struct ActionStmtChecker {
- template <typename A> static MaybeMsg WhyNotOk(const A &x) {
+ template <typename A>
+ static MaybeMsg WhyNotOk(SemanticsContext &context, const A &x) {
if constexpr (ConstraintTrait<A>) {
- return WhyNotOk(x.thing);
+ return WhyNotOk(context, x.thing);
} else if constexpr (WrapperTrait<A>) {
- return WhyNotOk(x.v);
+ return WhyNotOk(context, x.v);
} else if constexpr (UnionTrait<A>) {
- return WhyNotOk(x.u);
+ return WhyNotOk(context, x.u);
} else if constexpr (TupleTrait<A>) {
- return WhyNotOk(x.t);
+ return WhyNotOk(context, x.t);
} else {
return parser::MessageFormattedText{
"Statement may not appear in device code"_err_en_US};
}
}
template <typename A>
- static MaybeMsg WhyNotOk(const common::Indirection<A> &x) {
- return WhyNotOk(x.value());
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const common::Indirection<A> &x) {
+ return WhyNotOk(context, x.value());
}
template <typename... As>
- static MaybeMsg WhyNotOk(const std::variant<As...> &x) {
- return common::visit([](const auto &x) { return WhyNotOk(x); }, x);
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const std::variant<As...> &x) {
+ return common::visit(
+ [&context](const auto &x) { return WhyNotOk(context, x); }, x);
}
template <std::size_t J = 0, typename... As>
- static MaybeMsg WhyNotOk(const std::tuple<As...> &x) {
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const std::tuple<As...> &x) {
if constexpr (J == sizeof...(As)) {
return {};
- } else if (auto msg{WhyNotOk(std::get<J>(x))}) {
+ } else if (auto msg{WhyNotOk(context, std::get<J>(x))}) {
return msg;
} else {
- return WhyNotOk<(J + 1)>(x);
+ return WhyNotOk<(J + 1)>(context, x);
}
}
- template <typename A> static MaybeMsg WhyNotOk(const std::list<A> &x) {
+ template <typename A>
+ static MaybeMsg WhyNotOk(SemanticsContext &context, const std::list<A> &x) {
for (const auto &y : x) {
- if (MaybeMsg result{WhyNotOk(y)}) {
+ if (MaybeMsg result{WhyNotOk(context, y)}) {
return result;
}
}
return {};
}
- template <typename A> static MaybeMsg WhyNotOk(const std::optional<A> &x) {
+ template <typename A>
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const std::optional<A> &x) {
if (x) {
- return WhyNotOk(*x);
+ return WhyNotOk(context, *x);
} else {
return {};
}
}
template <typename A>
- static MaybeMsg WhyNotOk(const parser::UnlabeledStatement<A> &x) {
- return WhyNotOk(x.statement);
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::UnlabeledStatement<A> &x) {
+ return WhyNotOk(context, x.statement);
}
template <typename A>
- static MaybeMsg WhyNotOk(const parser::Statement<A> &x) {
- return WhyNotOk(x.statement);
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::Statement<A> &x) {
+ return WhyNotOk(context, x.statement);
}
- static MaybeMsg WhyNotOk(const parser::AllocateStmt &) {
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::AllocateStmt &) {
return {}; // AllocateObjects are checked elsewhere
}
- static MaybeMsg WhyNotOk(const parser::AllocateCoarraySpec &) {
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::AllocateCoarraySpec &) {
return parser::MessageFormattedText(
"A coarray may not be allocated on the device"_err_en_US);
}
- static MaybeMsg WhyNotOk(const parser::DeallocateStmt &) {
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::DeallocateStmt &) {
return {}; // AllocateObjects are checked elsewhere
}
- static MaybeMsg WhyNotOk(const parser::AssignmentStmt &x) {
- return DeviceExprChecker{}(x.typedAssignment);
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::AssignmentStmt &x) {
+ return DeviceExprChecker{context}(x.typedAssignment);
}
- static MaybeMsg WhyNotOk(const parser::CallStmt &x) {
- return DeviceExprChecker{}(x.typedCall);
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::CallStmt &x) {
+ return DeviceExprChecker{context}(x.typedCall);
+ }
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::ContinueStmt &) {
+ return {};
}
- static MaybeMsg WhyNotOk(const parser::ContinueStmt &) { return {}; }
- static MaybeMsg WhyNotOk(const parser::IfStmt &x) {
- if (auto result{
- CheckUnwrappedExpr(std::get<parser::ScalarLogicalExpr>(x.t))}) {
+ static MaybeMsg WhyNotOk(SemanticsContext &context, const parser::IfStmt &x) {
+ if (auto result{CheckUnwrappedExpr(
+ context, std::get<parser::ScalarLogicalExpr>(x.t))}) {
return result;
}
- return WhyNotOk(
+ return WhyNotOk(context,
std::get<parser::UnlabeledStatement<parser::ActionStmt>>(x.t)
.statement);
}
- static MaybeMsg WhyNotOk(const parser::NullifyStmt &x) {
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::NullifyStmt &x) {
for (const auto &y : x.v) {
- if (MaybeMsg result{DeviceExprChecker{}(y.typedExpr)}) {
+ if (MaybeMsg result{DeviceExprChecker{context}(y.typedExpr)}) {
return result;
}
}
return {};
}
- static MaybeMsg WhyNotOk(const parser::PointerAssignmentStmt &x) {
- return DeviceExprChecker{}(x.typedAssignment);
+ static MaybeMsg WhyNotOk(
+ SemanticsContext &context, const parser::PointerAssignmentStmt &x) {
+ return DeviceExprChecker{context}(x.typedAssignment);
}
};
@@ -435,12 +472,14 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
ErrorIfHostSymbol(assign->lhs, source);
ErrorIfHostSymbol(assign->rhs, source);
}
- if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
+ if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
+ context_, x)}) {
context_.Say(source, std::move(*msg));
}
},
[&](const auto &x) {
- if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
+ if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
+ context_, x)}) {
context_.Say(source, std::move(*msg));
}
},
@@ -504,7 +543,7 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
Check(DEREF(parser::Unwrap<parser::Expr>(x)));
}
void Check(const parser::Expr &expr) {
- if (MaybeMsg msg{DeviceExprChecker{}(expr.typedExpr)}) {
+ if (MaybeMsg msg{DeviceExprChecker{context_}(expr.typedExpr)}) {
context_.Say(expr.source, std::move(*msg));
}
}
diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp
index c544008a24d56..c80872108ac8f 100644
--- a/flang/tools/bbc/bbc.cpp
+++ b/flang/tools/bbc/bbc.cpp
@@ -223,6 +223,11 @@ static llvm::cl::opt<bool> enableCUDA("fcuda",
llvm::cl::desc("enable CUDA Fortran"),
llvm::cl::init(false));
+static llvm::cl::opt<bool>
+ disableCUDAWarpFunction("fcuda-disable-warp-function",
+ llvm::cl::desc("Disable CUDA Warp Function"),
+ llvm::cl::init(false));
+
static llvm::cl::opt<std::string>
enableGPUMode("gpu", llvm::cl::desc("Enable GPU Mode managed|unified"),
llvm::cl::init(""));
@@ -600,6 +605,11 @@ int main(int argc, char **argv) {
options.features.Enable(Fortran::common::LanguageFeature::CUDA);
}
+ if (disableCUDAWarpFunction) {
+ options.features.Enable(
+ Fortran::common::LanguageFeature::CudaWarpMatchFunction, false);
+ }
+
if (enableGPUMode == "managed") {
options.features.Enable(Fortran::common::LanguageFeature::CudaManaged);
} else if (enableGPUMode == "unified") {
>From 9982dfaf9e458c1bf819c23e15dbe7e7c321f675 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 10 Jun 2025 19:03:05 -0700
Subject: [PATCH 2/2] Add test
---
flang/test/Semantics/cuf22.cuf | 8 ++++++++
1 file changed, 8 insertions(+)
create mode 100644 flang/test/Semantics/cuf22.cuf
diff --git a/flang/test/Semantics/cuf22.cuf b/flang/test/Semantics/cuf22.cuf
new file mode 100644
index 0000000000000..36e0f0b2502df
--- /dev/null
+++ b/flang/test/Semantics/cuf22.cuf
@@ -0,0 +1,8 @@
+! RUN: not bbc -fcuda -fcuda-disable-warp-function %s -o - 2>&1 | FileCheck %s
+
+attributes(device) subroutine testMatch()
+ integer :: a, ipred, mask, v32
+ a = match_all_sync(mask, v32, ipred)
+end subroutine
+
+! CHECK: warp match function disabled
More information about the flang-commits
mailing list