[flang-commits] [flang] [flang][cuda] Allow * in call chevron syntax (PR #115381)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Thu Nov 7 17:25:55 PST 2024


https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/115381

>From 7f99ff74fe20f3423906c1a2a5326f7271d5fa41 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 7 Nov 2024 14:26:00 -0800
Subject: [PATCH 1/3] [flang][cuda] Allow * in call chevron syntax

---
 flang/include/flang/Parser/dump-parse-tree.h |  1 +
 flang/include/flang/Parser/parse-tree.h      |  5 ++--
 flang/lib/Parser/program-parsers.cpp         |  7 +++--
 flang/lib/Parser/unparse.cpp                 |  7 +++++
 flang/lib/Semantics/expression.cpp           | 28 ++++++++++++++------
 flang/test/Parser/cuf-sanity-common          |  3 +++
 flang/test/Parser/cuf-sanity-tree.CUF        | 12 ++++-----
 flang/test/Parser/cuf-sanity-unparse.CUF     |  3 +++
 8 files changed, 48 insertions(+), 18 deletions(-)

diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index bfeb23de535392..675faeb33668f3 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -177,6 +177,7 @@ class ParseTreeDumper {
   NODE(parser, Call)
   NODE(parser, CallStmt)
   NODE(CallStmt, Chevrons)
+  NODE(CallStmt, StarOrExpr)
   NODE(parser, CaseConstruct)
   NODE(CaseConstruct, Case)
   NODE(parser, CaseSelector)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index d2c5b45d995813..f84fd7565b006d 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -3247,13 +3247,14 @@ struct FunctionReference {
 
 // R1521 call-stmt -> CALL procedure-designator [ chevrons ]
 //         [( [actual-arg-spec-list] )]
-// (CUDA) chevrons -> <<< scalar-expr, scalar-expr [,
+// (CUDA) chevrons -> <<< * | scalar-expr, * | scalar-expr [,
 //          scalar-int-expr [, scalar-int-expr ] ] >>>
 struct CallStmt {
   BOILERPLATE(CallStmt);
+  WRAPPER_CLASS(StarOrExpr, std::optional<ScalarExpr>);
   struct Chevrons {
     TUPLE_CLASS_BOILERPLATE(Chevrons);
-    std::tuple<ScalarExpr, ScalarExpr, std::optional<ScalarIntExpr>,
+    std::tuple<StarOrExpr, StarOrExpr, std::optional<ScalarIntExpr>,
         std::optional<ScalarIntExpr>>
         t;
   };
diff --git a/flang/lib/Parser/program-parsers.cpp b/flang/lib/Parser/program-parsers.cpp
index 2b7da18a09bb30..a11b9b87765f81 100644
--- a/flang/lib/Parser/program-parsers.cpp
+++ b/flang/lib/Parser/program-parsers.cpp
@@ -474,10 +474,13 @@ TYPE_CONTEXT_PARSER("function reference"_en_US,
 
 // R1521 call-stmt -> CALL procedure-designator [chevrons]
 ///                          [( [actual-arg-spec-list] )]
-// (CUDA) chevrons -> <<< scalar-expr, scalar-expr [, scalar-int-expr
+// (CUDA) chevrons -> <<< * | scalar-expr, * | scalar-expr [, scalar-int-expr
 //                      [, scalar-int-expr ] ] >>>
+constexpr auto starOrExpr{
+    construct<CallStmt::StarOrExpr>("*" >> pure<std::optional<ScalarExpr>>() ||
+        applyFunction(presentOptional<ScalarExpr>, scalarExpr))};
 TYPE_PARSER(extension<LanguageFeature::CUDA>(
-    "<<<" >> construct<CallStmt::Chevrons>(scalarExpr, "," >> scalarExpr,
+    "<<<" >> construct<CallStmt::Chevrons>(starOrExpr, ", " >> starOrExpr,
                  maybe("," >> scalarIntExpr), maybe("," >> scalarIntExpr)) /
         ">>>"))
 constexpr auto actualArgSpecList{optionalList(actualArgSpec)};
diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp
index bbb126dcdb6d5e..5d70f3433b4453 100644
--- a/flang/lib/Parser/unparse.cpp
+++ b/flang/lib/Parser/unparse.cpp
@@ -1703,6 +1703,13 @@ class UnparseVisitor {
   void Unparse(const IntrinsicStmt &x) { // R1519
     Word("INTRINSIC :: "), Walk(x.v, ", ");
   }
+  void Unparse(const CallStmt::StarOrExpr &x) {
+    if (x.v) {
+      Walk(*x.v);
+    } else {
+      Word("*");
+    }
+  }
   void Unparse(const CallStmt::Chevrons &x) { // CUDA
     Walk(std::get<0>(x.t)); // grid
     Word(","), Walk(std::get<1>(x.t)); // block
diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index c70c8a8aecc2f8..e380d9532ee181 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -3066,17 +3066,29 @@ std::optional<Chevrons> ExpressionAnalyzer::AnalyzeChevrons(
     return false;
   }};
   if (const auto &chevrons{call.chevrons}) {
-    if (auto expr{Analyze(std::get<0>(chevrons->t))};
-        expr && checkLaunchArg(*expr, "grid")) {
-      result.emplace_back(*expr);
+    auto &starOrExpr0{std::get<0>(chevrons->t)};
+    if (starOrExpr0.v) {
+      if (auto expr{Analyze(*starOrExpr0.v)};
+          expr && checkLaunchArg(*expr, "grid")) {
+        result.emplace_back(*expr);
+      } else {
+        return std::nullopt;
+      }
     } else {
-      return std::nullopt;
+      result.emplace_back(
+          AsGenericExpr(evaluate::Constant<evaluate::SubscriptInteger>{-1}));
     }
-    if (auto expr{Analyze(std::get<1>(chevrons->t))};
-        expr && checkLaunchArg(*expr, "block")) {
-      result.emplace_back(*expr);
+    auto &starOrExpr1{std::get<1>(chevrons->t)};
+    if (starOrExpr1.v) {
+      if (auto expr{Analyze(*starOrExpr1.v)};
+          expr && checkLaunchArg(*expr, "block")) {
+        result.emplace_back(*expr);
+      } else {
+        return std::nullopt;
+      }
     } else {
-      return std::nullopt;
+      result.emplace_back(
+          AsGenericExpr(evaluate::Constant<evaluate::SubscriptInteger>{-1}));
     }
     if (const auto &maybeExpr{std::get<2>(chevrons->t)}) {
       if (auto expr{Analyze(*maybeExpr)}) {
diff --git a/flang/test/Parser/cuf-sanity-common b/flang/test/Parser/cuf-sanity-common
index 9341f054d79d46..d08048058adbec 100644
--- a/flang/test/Parser/cuf-sanity-common
+++ b/flang/test/Parser/cuf-sanity-common
@@ -40,6 +40,9 @@ module m
     call globalsub<<<1, 2>>>
     call globalsub<<<1, 2, 3>>>
     call globalsub<<<1, 2, 3, 4>>>
+    call globalsub<<<*,*>>>
+    call globalsub<<<*,5>>>
+    call globalsub<<<1,*>>>
     allocate(pa(32), pinned = isPinned)
   end subroutine
 end module
diff --git a/flang/test/Parser/cuf-sanity-tree.CUF b/flang/test/Parser/cuf-sanity-tree.CUF
index 2820441d5b5f0a..7f097ab6c9c659 100644
--- a/flang/test/Parser/cuf-sanity-tree.CUF
+++ b/flang/test/Parser/cuf-sanity-tree.CUF
@@ -166,17 +166,17 @@ include "cuf-sanity-common"
 !CHECK: | | | | | Call
 !CHECK: | | | | | | ProcedureDesignator -> Name = 'globalsub'
 !CHECK: | | | | | Chevrons
-!CHECK: | | | | | | Scalar -> Expr = '1_4'
+!CHECK: | | | | | | StarOrExpr -> Scalar -> Expr = '1_4'
 !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
-!CHECK: | | | | | | Scalar -> Expr = '2_4'
+!CHECK: | | | | | | StarOrExpr -> Scalar -> Expr = '2_4'
 !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '2'
 !CHECK: | | | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> CallStmt = 'CALL globalsub<<<1_4,2_4,3_4>>>()'
 !CHECK: | | | | | Call
 !CHECK: | | | | | | ProcedureDesignator -> Name = 'globalsub'
 !CHECK: | | | | | Chevrons
-!CHECK: | | | | | | Scalar -> Expr = '1_4'
+!CHECK: | | | | | | StarOrExpr -> Scalar -> Expr = '1_4'
 !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
-!CHECK: | | | | | | Scalar -> Expr = '2_4'
+!CHECK: | | | | | | StarOrExpr -> Scalar -> Expr = '2_4'
 !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '2'
 !CHECK: | | | | | | Scalar -> Integer -> Expr = '3_4'
 !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '3'
@@ -184,9 +184,9 @@ include "cuf-sanity-common"
 !CHECK: | | | | | Call
 !CHECK: | | | | | | ProcedureDesignator -> Name = 'globalsub'
 !CHECK: | | | | | Chevrons
-!CHECK: | | | | | | Scalar -> Expr = '1_4'
+!CHECK: | | | | | | StarOrExpr -> Scalar -> Expr = '1_4'
 !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
-!CHECK: | | | | | | Scalar -> Expr = '2_4'
+!CHECK: | | | | | | StarOrExpr -> Scalar -> Expr = '2_4'
 !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '2'
 !CHECK: | | | | | | Scalar -> Integer -> Expr = '3_4'
 !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '3'
diff --git a/flang/test/Parser/cuf-sanity-unparse.CUF b/flang/test/Parser/cuf-sanity-unparse.CUF
index d4be347dd044ea..938caa5982c6e5 100644
--- a/flang/test/Parser/cuf-sanity-unparse.CUF
+++ b/flang/test/Parser/cuf-sanity-unparse.CUF
@@ -43,6 +43,9 @@ include "cuf-sanity-common"
 !CHECK:    CALL globalsub<<<1_4,2_4>>>()
 !CHECK:    CALL globalsub<<<1_4,2_4,3_4>>>()
 !CHECK:    CALL globalsub<<<1_4,2_4,3_4,4_4>>>()
+!CHECK:    CALL globalsub<<<-1_8,-1_8>>>()
+!CHECK:    CALL globalsub<<<-1_8,5_4>>>()
+!CHECK:    CALL globalsub<<<1_4,-1_8>>>()
 !CHECK:   ALLOCATE(pa(32_4), PINNED=ispinned)
 !CHECK:  END SUBROUTINE
 !CHECK: END MODULE

>From c9dabfae6dc10f0d0e2a417ef2e7436b2b327041 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 7 Nov 2024 16:35:50 -0800
Subject: [PATCH 2/3] Allow * for at most one of grid and block

---
 flang/lib/Semantics/expression.cpp       | 8 ++++++++
 flang/test/Parser/cuf-sanity-common      | 1 -
 flang/test/Parser/cuf-sanity-unparse.CUF | 1 -
 flang/test/Semantics/cuf04.cuf           | 2 ++
 4 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index e380d9532ee181..b492fe1291b80b 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -3065,7 +3065,10 @@ std::optional<Chevrons> ExpressionAnalyzer::AnalyzeChevrons(
         which);
     return false;
   }};
+
   if (const auto &chevrons{call.chevrons}) {
+    bool gridIsStar{false};
+    bool blockIsStar{false};
     auto &starOrExpr0{std::get<0>(chevrons->t)};
     if (starOrExpr0.v) {
       if (auto expr{Analyze(*starOrExpr0.v)};
@@ -3075,6 +3078,7 @@ std::optional<Chevrons> ExpressionAnalyzer::AnalyzeChevrons(
         return std::nullopt;
       }
     } else {
+      gridIsStar = true;
       result.emplace_back(
           AsGenericExpr(evaluate::Constant<evaluate::SubscriptInteger>{-1}));
     }
@@ -3087,9 +3091,13 @@ std::optional<Chevrons> ExpressionAnalyzer::AnalyzeChevrons(
         return std::nullopt;
       }
     } else {
+      blockIsStar = true;
       result.emplace_back(
           AsGenericExpr(evaluate::Constant<evaluate::SubscriptInteger>{-1}));
     }
+    if (gridIsStar && blockIsStar) {
+      Say("Grid and block can not be * in kernel launch parameter"_err_en_US);
+    }
     if (const auto &maybeExpr{std::get<2>(chevrons->t)}) {
       if (auto expr{Analyze(*maybeExpr)}) {
         result.emplace_back(*expr);
diff --git a/flang/test/Parser/cuf-sanity-common b/flang/test/Parser/cuf-sanity-common
index d08048058adbec..ed8ffd09768257 100644
--- a/flang/test/Parser/cuf-sanity-common
+++ b/flang/test/Parser/cuf-sanity-common
@@ -40,7 +40,6 @@ module m
     call globalsub<<<1, 2>>>
     call globalsub<<<1, 2, 3>>>
     call globalsub<<<1, 2, 3, 4>>>
-    call globalsub<<<*,*>>>
     call globalsub<<<*,5>>>
     call globalsub<<<1,*>>>
     allocate(pa(32), pinned = isPinned)
diff --git a/flang/test/Parser/cuf-sanity-unparse.CUF b/flang/test/Parser/cuf-sanity-unparse.CUF
index 938caa5982c6e5..9345e837d9e184 100644
--- a/flang/test/Parser/cuf-sanity-unparse.CUF
+++ b/flang/test/Parser/cuf-sanity-unparse.CUF
@@ -43,7 +43,6 @@ include "cuf-sanity-common"
 !CHECK:    CALL globalsub<<<1_4,2_4>>>()
 !CHECK:    CALL globalsub<<<1_4,2_4,3_4>>>()
 !CHECK:    CALL globalsub<<<1_4,2_4,3_4,4_4>>>()
-!CHECK:    CALL globalsub<<<-1_8,-1_8>>>()
 !CHECK:    CALL globalsub<<<-1_8,5_4>>>()
 !CHECK:    CALL globalsub<<<1_4,-1_8>>>()
 !CHECK:   ALLOCATE(pa(32_4), PINNED=ispinned)
diff --git a/flang/test/Semantics/cuf04.cuf b/flang/test/Semantics/cuf04.cuf
index 2e2faa90b490db..32b2102ec43072 100644
--- a/flang/test/Semantics/cuf04.cuf
+++ b/flang/test/Semantics/cuf04.cuf
@@ -20,5 +20,7 @@ module m
     call globsubr
     !ERROR: Kernel launch parameters in chevrons may not be used unless calling a kernel subroutine
     call boring<<<1,2>>>
+    !ERROR: Grid and block can not be * in kernel launch parameter
+    call globsubr<<<*, *>>>
   end subroutine
 end module

>From d02922c3afb5d962b50fce3fb2ece45792a6e79a Mon Sep 17 00:00:00 2001
From: Valentin Clement (バレンタイン クレメン) <clementval at gmail.com>
Date: Thu, 7 Nov 2024 17:25:48 -0800
Subject: [PATCH 3/3] Update flang/lib/Semantics/expression.cpp: say grid and block cannot both be *

---
 flang/lib/Semantics/expression.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/flang/lib/Semantics/expression.cpp b/flang/lib/Semantics/expression.cpp
index b492fe1291b80b..84818106ce56e6 100644
--- a/flang/lib/Semantics/expression.cpp
+++ b/flang/lib/Semantics/expression.cpp
@@ -3096,7 +3096,7 @@ std::optional<Chevrons> ExpressionAnalyzer::AnalyzeChevrons(
           AsGenericExpr(evaluate::Constant<evaluate::SubscriptInteger>{-1}));
     }
     if (gridIsStar && blockIsStar) {
-      Say("Grid and block can not be * in kernel launch parameter"_err_en_US);
+      Say("Grid and block can not both be * in kernel launch parameter"_err_en_US);
     }
     if (const auto &maybeExpr{std::get<2>(chevrons->t)}) {
       if (auto expr{Analyze(*maybeExpr)}) {



More information about the flang-commits mailing list