[flang-commits] [flang] a3201ce - [flang][cuda] Add option to disable warp function in semantic (#143640)

via flang-commits flang-commits at lists.llvm.org
Tue Jun 10 22:10:29 PDT 2025


Author: Valentin Clement (バレンタイン クレメン)
Date: 2025-06-10T22:10:26-07:00
New Revision: a3201ce9e114aa2ecd66e525607093e4dff2f574

URL: https://github.com/llvm/llvm-project/commit/a3201ce9e114aa2ecd66e525607093e4dff2f574
DIFF: https://github.com/llvm/llvm-project/commit/a3201ce9e114aa2ecd66e525607093e4dff2f574.diff

LOG: [flang][cuda] Add option to disable warp function in semantic (#143640)

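The warp match functions (match_all_sync and match_any_sync) are not
available on some lower compute capabilities. Add a language feature
(CudaWarpMatchFunction) so that, when it is disabled, semantic checking
rejects calls to these functions in device code.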

Added: 
    flang/test/Semantics/cuf22.cuf

Modified: 
    flang/include/flang/Support/Fortran-features.h
    flang/lib/Semantics/check-cuda.cpp
    flang/tools/bbc/bbc.cpp

Removed: 
    


################################################################################
diff  --git a/flang/include/flang/Support/Fortran-features.h b/flang/include/flang/Support/Fortran-features.h
index 3f6d825e2b66c..ea0845b7d605f 100644
--- a/flang/include/flang/Support/Fortran-features.h
+++ b/flang/include/flang/Support/Fortran-features.h
@@ -55,7 +55,7 @@ ENUM_CLASS(LanguageFeature, BackslashEscapes, OldDebugLines,
     SavedLocalInSpecExpr, PrintNamelist, AssumedRankPassedToNonAssumedRank,
     IgnoreIrrelevantAttributes, Unsigned, AmbiguousStructureConstructor,
     ContiguousOkForSeqAssociation, ForwardRefExplicitTypeDummy,
-    InaccessibleDeferredOverride)
+    InaccessibleDeferredOverride, CudaWarpMatchFunction)
 
 // Portability and suspicious usage warnings
 ENUM_CLASS(UsageWarning, Portability, PointerToUndefinable,

diff  --git a/flang/lib/Semantics/check-cuda.cpp b/flang/lib/Semantics/check-cuda.cpp
index c024640af1220..8decfb0149829 100644
--- a/flang/lib/Semantics/check-cuda.cpp
+++ b/flang/lib/Semantics/check-cuda.cpp
@@ -17,6 +17,7 @@
 #include "flang/Semantics/expression.h"
 #include "flang/Semantics/symbol.h"
 #include "flang/Semantics/tools.h"
+#include "llvm/ADT/StringSet.h"
 
 // Once labeled DO constructs have been canonicalized and their parse subtrees
 // transformed into parser::DoConstructs, scan the parser::Blocks of the program
@@ -61,6 +62,11 @@ bool CanonicalizeCUDA(parser::Program &program) {
 
 using MaybeMsg = std::optional<parser::MessageFormattedText>;
 
+static const llvm::StringSet<> warpFunctions_ = {"match_all_syncjj",
+    "match_all_syncjx", "match_all_syncjf", "match_all_syncjd",
+    "match_any_syncjj", "match_any_syncjx", "match_any_syncjf",
+    "match_any_syncjd"};
+
 // Traverses an evaluate::Expr<> in search of unsupported operations
 // on the device.
 
@@ -68,7 +74,7 @@ struct DeviceExprChecker
     : public evaluate::AnyTraverse<DeviceExprChecker, MaybeMsg> {
   using Result = MaybeMsg;
   using Base = evaluate::AnyTraverse<DeviceExprChecker, Result>;
-  DeviceExprChecker() : Base(*this) {}
+  explicit DeviceExprChecker(SemanticsContext &c) : Base(*this), context_{c} {}
   using Base::operator();
   Result operator()(const evaluate::ProcedureDesignator &x) const {
     if (const Symbol * sym{x.GetInterfaceSymbol()}) {
@@ -78,10 +84,17 @@ struct DeviceExprChecker
         if (auto attrs{subp->cudaSubprogramAttrs()}) {
           if (*attrs == common::CUDASubprogramAttrs::HostDevice ||
               *attrs == common::CUDASubprogramAttrs::Device) {
+            if (warpFunctions_.contains(sym->name().ToString()) &&
+                !context_.languageFeatures().IsEnabled(
+                    Fortran::common::LanguageFeature::CudaWarpMatchFunction)) {
+              return parser::MessageFormattedText(
+                  "warp match function disabled"_err_en_US);
+            }
             return {};
           }
         }
       }
+
       const Symbol &ultimate{sym->GetUltimate()};
       const Scope &scope{ultimate.owner()};
       const Symbol *mod{scope.IsModule() ? scope.symbol() : nullptr};
@@ -94,9 +107,12 @@ struct DeviceExprChecker
       // TODO(CUDA): Check for unsupported intrinsics here
       return {};
     }
+
     return parser::MessageFormattedText(
         "'%s' may not be called in device code"_err_en_US, x.GetName());
   }
+
+  SemanticsContext &context_;
 };
 
 struct FindHostArray
@@ -133,9 +149,10 @@ struct FindHostArray
   }
 };
 
-template <typename A> static MaybeMsg CheckUnwrappedExpr(const A &x) {
+template <typename A>
+static MaybeMsg CheckUnwrappedExpr(SemanticsContext &context, const A &x) {
   if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
-    return DeviceExprChecker{}(expr->typedExpr);
+    return DeviceExprChecker{context}(expr->typedExpr);
   }
   return {};
 }
@@ -144,104 +161,124 @@ template <typename A>
 static void CheckUnwrappedExpr(
     SemanticsContext &context, SourceName at, const A &x) {
   if (const auto *expr{parser::Unwrap<parser::Expr>(x)}) {
-    if (auto msg{DeviceExprChecker{}(expr->typedExpr)}) {
+    if (auto msg{DeviceExprChecker{context}(expr->typedExpr)}) {
       context.Say(at, std::move(*msg));
     }
   }
 }
 
 template <bool CUF_KERNEL> struct ActionStmtChecker {
-  template <typename A> static MaybeMsg WhyNotOk(const A &x) {
+  template <typename A>
+  static MaybeMsg WhyNotOk(SemanticsContext &context, const A &x) {
     if constexpr (ConstraintTrait<A>) {
-      return WhyNotOk(x.thing);
+      return WhyNotOk(context, x.thing);
     } else if constexpr (WrapperTrait<A>) {
-      return WhyNotOk(x.v);
+      return WhyNotOk(context, x.v);
     } else if constexpr (UnionTrait<A>) {
-      return WhyNotOk(x.u);
+      return WhyNotOk(context, x.u);
     } else if constexpr (TupleTrait<A>) {
-      return WhyNotOk(x.t);
+      return WhyNotOk(context, x.t);
     } else {
       return parser::MessageFormattedText{
           "Statement may not appear in device code"_err_en_US};
     }
   }
   template <typename A>
-  static MaybeMsg WhyNotOk(const common::Indirection<A> &x) {
-    return WhyNotOk(x.value());
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const common::Indirection<A> &x) {
+    return WhyNotOk(context, x.value());
   }
   template <typename... As>
-  static MaybeMsg WhyNotOk(const std::variant<As...> &x) {
-    return common::visit([](const auto &x) { return WhyNotOk(x); }, x);
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const std::variant<As...> &x) {
+    return common::visit(
+        [&context](const auto &x) { return WhyNotOk(context, x); }, x);
   }
   template <std::size_t J = 0, typename... As>
-  static MaybeMsg WhyNotOk(const std::tuple<As...> &x) {
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const std::tuple<As...> &x) {
     if constexpr (J == sizeof...(As)) {
       return {};
-    } else if (auto msg{WhyNotOk(std::get<J>(x))}) {
+    } else if (auto msg{WhyNotOk(context, std::get<J>(x))}) {
       return msg;
     } else {
-      return WhyNotOk<(J + 1)>(x);
+      return WhyNotOk<(J + 1)>(context, x);
     }
   }
-  template <typename A> static MaybeMsg WhyNotOk(const std::list<A> &x) {
+  template <typename A>
+  static MaybeMsg WhyNotOk(SemanticsContext &context, const std::list<A> &x) {
     for (const auto &y : x) {
-      if (MaybeMsg result{WhyNotOk(y)}) {
+      if (MaybeMsg result{WhyNotOk(context, y)}) {
         return result;
       }
     }
     return {};
   }
-  template <typename A> static MaybeMsg WhyNotOk(const std::optional<A> &x) {
+  template <typename A>
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const std::optional<A> &x) {
     if (x) {
-      return WhyNotOk(*x);
+      return WhyNotOk(context, *x);
     } else {
       return {};
     }
   }
   template <typename A>
-  static MaybeMsg WhyNotOk(const parser::UnlabeledStatement<A> &x) {
-    return WhyNotOk(x.statement);
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::UnlabeledStatement<A> &x) {
+    return WhyNotOk(context, x.statement);
   }
   template <typename A>
-  static MaybeMsg WhyNotOk(const parser::Statement<A> &x) {
-    return WhyNotOk(x.statement);
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::Statement<A> &x) {
+    return WhyNotOk(context, x.statement);
   }
-  static MaybeMsg WhyNotOk(const parser::AllocateStmt &) {
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::AllocateStmt &) {
     return {}; // AllocateObjects are checked elsewhere
   }
-  static MaybeMsg WhyNotOk(const parser::AllocateCoarraySpec &) {
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::AllocateCoarraySpec &) {
     return parser::MessageFormattedText(
         "A coarray may not be allocated on the device"_err_en_US);
   }
-  static MaybeMsg WhyNotOk(const parser::DeallocateStmt &) {
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::DeallocateStmt &) {
     return {}; // AllocateObjects are checked elsewhere
   }
-  static MaybeMsg WhyNotOk(const parser::AssignmentStmt &x) {
-    return DeviceExprChecker{}(x.typedAssignment);
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::AssignmentStmt &x) {
+    return DeviceExprChecker{context}(x.typedAssignment);
   }
-  static MaybeMsg WhyNotOk(const parser::CallStmt &x) {
-    return DeviceExprChecker{}(x.typedCall);
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::CallStmt &x) {
+    return DeviceExprChecker{context}(x.typedCall);
+  }
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::ContinueStmt &) {
+    return {};
   }
-  static MaybeMsg WhyNotOk(const parser::ContinueStmt &) { return {}; }
-  static MaybeMsg WhyNotOk(const parser::IfStmt &x) {
-    if (auto result{
-            CheckUnwrappedExpr(std::get<parser::ScalarLogicalExpr>(x.t))}) {
+  static MaybeMsg WhyNotOk(SemanticsContext &context, const parser::IfStmt &x) {
+    if (auto result{CheckUnwrappedExpr(
+            context, std::get<parser::ScalarLogicalExpr>(x.t))}) {
       return result;
     }
-    return WhyNotOk(
+    return WhyNotOk(context,
         std::get<parser::UnlabeledStatement<parser::ActionStmt>>(x.t)
             .statement);
   }
-  static MaybeMsg WhyNotOk(const parser::NullifyStmt &x) {
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::NullifyStmt &x) {
     for (const auto &y : x.v) {
-      if (MaybeMsg result{DeviceExprChecker{}(y.typedExpr)}) {
+      if (MaybeMsg result{DeviceExprChecker{context}(y.typedExpr)}) {
         return result;
       }
     }
     return {};
   }
-  static MaybeMsg WhyNotOk(const parser::PointerAssignmentStmt &x) {
-    return DeviceExprChecker{}(x.typedAssignment);
+  static MaybeMsg WhyNotOk(
+      SemanticsContext &context, const parser::PointerAssignmentStmt &x) {
+    return DeviceExprChecker{context}(x.typedAssignment);
   }
 };
 
@@ -435,12 +472,14 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
                 ErrorIfHostSymbol(assign->lhs, source);
                 ErrorIfHostSymbol(assign->rhs, source);
               }
-              if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
+              if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
+                      context_, x)}) {
                 context_.Say(source, std::move(*msg));
               }
             },
             [&](const auto &x) {
-              if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(x)}) {
+              if (auto msg{ActionStmtChecker<IsCUFKernelDo>::WhyNotOk(
+                      context_, x)}) {
                 context_.Say(source, std::move(*msg));
               }
             },
@@ -504,7 +543,7 @@ template <bool IsCUFKernelDo> class DeviceContextChecker {
     Check(DEREF(parser::Unwrap<parser::Expr>(x)));
   }
   void Check(const parser::Expr &expr) {
-    if (MaybeMsg msg{DeviceExprChecker{}(expr.typedExpr)}) {
+    if (MaybeMsg msg{DeviceExprChecker{context_}(expr.typedExpr)}) {
       context_.Say(expr.source, std::move(*msg));
     }
   }

diff  --git a/flang/test/Semantics/cuf22.cuf b/flang/test/Semantics/cuf22.cuf
new file mode 100644
index 0000000000000..36e0f0b2502df
--- /dev/null
+++ b/flang/test/Semantics/cuf22.cuf
@@ -0,0 +1,8 @@
+! RUN: not bbc -fcuda -fcuda-disable-warp-function %s -o - 2>&1 | FileCheck %s
+
+attributes(device) subroutine testMatch()
+  integer :: a, ipred, mask, v32
+  a = match_all_sync(mask, v32, ipred)
+end subroutine
+
+! CHECK:  warp match function disabled

diff  --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp
index c544008a24d56..c80872108ac8f 100644
--- a/flang/tools/bbc/bbc.cpp
+++ b/flang/tools/bbc/bbc.cpp
@@ -223,6 +223,11 @@ static llvm::cl::opt<bool> enableCUDA("fcuda",
                                       llvm::cl::desc("enable CUDA Fortran"),
                                       llvm::cl::init(false));
 
+static llvm::cl::opt<bool>
+    disableCUDAWarpFunction("fcuda-disable-warp-function",
+                            llvm::cl::desc("Disable CUDA Warp Function"),
+                            llvm::cl::init(false));
+
 static llvm::cl::opt<std::string>
     enableGPUMode("gpu", llvm::cl::desc("Enable GPU Mode managed|unified"),
                   llvm::cl::init(""));
@@ -600,6 +605,11 @@ int main(int argc, char **argv) {
     options.features.Enable(Fortran::common::LanguageFeature::CUDA);
   }
 
+  if (disableCUDAWarpFunction) {
+    options.features.Enable(
+        Fortran::common::LanguageFeature::CudaWarpMatchFunction, false);
+  }
+
   if (enableGPUMode == "managed") {
     options.features.Enable(Fortran::common::LanguageFeature::CudaManaged);
   } else if (enableGPUMode == "unified") {


        


More information about the flang-commits mailing list