[flang-commits] [flang] [flang] Parse !$CUF KERNEL DO <<< (*) (PR #85338)

Peter Klausler via flang-commits flang-commits at lists.llvm.org
Thu Mar 14 16:51:58 PDT 2024


https://github.com/klausler created https://github.com/llvm/llvm-project/pull/85338

Accept and represent asterisks within the parenthesized grid and block specification lists.

From e974b92346cdf07eb88ed1fbc3434c0c3ac1dab6 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Thu, 14 Mar 2024 15:43:40 -0700
Subject: [PATCH] [flang] Parse !$CUF KERNEL DO <<< (*)

Accept and represent asterisks within the parenthesized grid
and block specification lists.
---
 flang/include/flang/Parser/dump-parse-tree.h  |  1 +
 flang/include/flang/Parser/parse-tree.h       | 10 ++++---
 flang/lib/Lower/Bridge.cpp                    | 29 ++++++++++++++-----
 flang/lib/Parser/executable-parsers.cpp       | 20 ++++++-------
 flang/lib/Parser/misc-parsers.h               |  5 ++++
 flang/lib/Parser/unparse.cpp                  |  7 +++++
 .../Lower/CUDA/cuda-kernel-loop-directive.cuf |  5 +---
 flang/test/Parser/cuf-sanity-tree.CUF         |  6 ++--
 flang/test/Semantics/cuf09.cuf                |  9 ++++++
 9 files changed, 63 insertions(+), 29 deletions(-)

diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 048008a8d80c79..b2c3d92909375c 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -233,6 +233,7 @@ class ParseTreeDumper {
   NODE(parser, CriticalStmt)
   NODE(parser, CUDAAttributesStmt)
   NODE(parser, CUFKernelDoConstruct)
+  NODE(CUFKernelDoConstruct, StarOrExpr)
   NODE(CUFKernelDoConstruct, Directive)
   NODE(parser, CycleStmt)
   NODE(parser, DataComponentDefStmt)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index f7b72c3af09164..c96abfba491d4b 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -4297,16 +4297,18 @@ struct OpenACCConstruct {
 // CUF-kernel-do-construct ->
 //     !$CUF KERNEL DO [ (scalar-int-constant-expr) ] <<< grid, block [, stream]
 //     >>> do-construct
-// grid -> * | scalar-int-expr | ( scalar-int-expr-list )
-// block -> * | scalar-int-expr | ( scalar-int-expr-list )
+// star-or-expr -> * | scalar-int-expr
+// grid -> * | scalar-int-expr | ( star-or-expr-list )
+// block -> * | scalar-int-expr | ( star-or-expr-list )
 // stream -> 0, scalar-int-expr | STREAM = scalar-int-expr
 struct CUFKernelDoConstruct {
   TUPLE_CLASS_BOILERPLATE(CUFKernelDoConstruct);
+  WRAPPER_CLASS(StarOrExpr, std::optional<ScalarIntExpr>);
   struct Directive {
     TUPLE_CLASS_BOILERPLATE(Directive);
     CharBlock source;
-    std::tuple<std::optional<ScalarIntConstantExpr>, std::list<ScalarIntExpr>,
-        std::list<ScalarIntExpr>, std::optional<ScalarIntExpr>>
+    std::tuple<std::optional<ScalarIntConstantExpr>, std::list<StarOrExpr>,
+        std::list<StarOrExpr>, std::optional<ScalarIntExpr>>
         t;
   };
   std::tuple<Directive, std::optional<DoConstruct>> t;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index a668ba4116faab..e6511e0c61c8c3 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2508,19 +2508,32 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     if (nestedLoops > 1)
       n = builder->getIntegerAttr(builder->getI64Type(), nestedLoops);
 
-    const std::list<Fortran::parser::ScalarIntExpr> &grid = std::get<1>(dir.t);
-    const std::list<Fortran::parser::ScalarIntExpr> &block = std::get<2>(dir.t);
+    const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &grid =
+        std::get<1>(dir.t);
+    const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &block =
+        std::get<2>(dir.t);
     const std::optional<Fortran::parser::ScalarIntExpr> &stream =
         std::get<3>(dir.t);
 
     llvm::SmallVector<mlir::Value> gridValues;
-    for (const Fortran::parser::ScalarIntExpr &expr : grid)
-      gridValues.push_back(fir::getBase(
-          genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+    for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr : grid) {
+      if (expr.v) {
+        gridValues.push_back(fir::getBase(
+            genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
+      } else {
+        // TODO: '*'
+      }
+    }
     llvm::SmallVector<mlir::Value> blockValues;
-    for (const Fortran::parser::ScalarIntExpr &expr : block)
-      blockValues.push_back(fir::getBase(
-          genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+    for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
+         block) {
+      if (expr.v) {
+        blockValues.push_back(fir::getBase(
+            genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
+      } else {
+        // TODO: '*'
+      }
+    }
     mlir::Value streamValue;
     if (stream)
       streamValue = fir::getBase(
diff --git a/flang/lib/Parser/executable-parsers.cpp b/flang/lib/Parser/executable-parsers.cpp
index de2be017508c37..6f258f131c3c22 100644
--- a/flang/lib/Parser/executable-parsers.cpp
+++ b/flang/lib/Parser/executable-parsers.cpp
@@ -542,19 +542,19 @@ TYPE_CONTEXT_PARSER("UNLOCK statement"_en_US,
 // CUF-kernel-do-directive ->
 //     !$CUF KERNEL DO [ (scalar-int-constant-expr) ] <<< grid, block [, stream]
 //     >>> do-construct
-// grid -> * | scalar-int-expr | ( scalar-int-expr-list )
-// block -> * | scalar-int-expr | ( scalar-int-expr-list )
+// star-or-expr -> * | scalar-int-expr
+// grid -> * | scalar-int-expr | ( star-or-expr-list )
+// block -> * | scalar-int-expr | ( star-or-expr-list )
 // stream -> ( 0, | STREAM = ) scalar-int-expr
+const auto starOrExpr{construct<CUFKernelDoConstruct::StarOrExpr>(
+    "*" >> pure<std::optional<ScalarIntExpr>>() ||
+    applyFunction(presentOptional<ScalarIntExpr>, scalarIntExpr))};
+constexpr auto gridOrBlock{parenthesized(nonemptyList(starOrExpr)) ||
+    applyFunction(singletonList<CUFKernelDoConstruct::StarOrExpr>, starOrExpr)};
 TYPE_PARSER(sourced(beginDirective >> "$CUF KERNEL DO"_tok >>
     construct<CUFKernelDoConstruct::Directive>(
-        maybe(parenthesized(scalarIntConstantExpr)),
-        "<<<" >>
-            ("*" >> pure<std::list<ScalarIntExpr>>() ||
-                parenthesized(nonemptyList(scalarIntExpr)) ||
-                applyFunction(singletonList<ScalarIntExpr>, scalarIntExpr)),
-        "," >> ("*" >> pure<std::list<ScalarIntExpr>>() ||
-                   parenthesized(nonemptyList(scalarIntExpr)) ||
-                   applyFunction(singletonList<ScalarIntExpr>, scalarIntExpr)),
+        maybe(parenthesized(scalarIntConstantExpr)), "<<<" >> gridOrBlock,
+        "," >> gridOrBlock,
         maybe((", 0 ,"_tok || ", STREAM ="_tok) >> scalarIntExpr) / ">>>" /
             endDirective)))
 TYPE_CONTEXT_PARSER("!$CUF KERNEL DO construct"_en_US,
diff --git a/flang/lib/Parser/misc-parsers.h b/flang/lib/Parser/misc-parsers.h
index e9b52b7d0fcd0f..4a318e05bb4b8c 100644
--- a/flang/lib/Parser/misc-parsers.h
+++ b/flang/lib/Parser/misc-parsers.h
@@ -57,5 +57,10 @@ template <typename A> common::IfNoLvalue<std::list<A>, A> singletonList(A &&x) {
   result.emplace_back(std::move(x));
   return result;
 }
+
+template <typename A>
+common::IfNoLvalue<std::optional<A>, A> presentOptional(A &&x) {
+  return std::make_optional(std::move(x));
+}
 } // namespace Fortran::parser
 #endif
diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp
index 600aa01999dab7..baba4863f5775f 100644
--- a/flang/lib/Parser/unparse.cpp
+++ b/flang/lib/Parser/unparse.cpp
@@ -2729,6 +2729,13 @@ class UnparseVisitor {
   WALK_NESTED_ENUM(OmpOrderModifier, Kind) // OMP order-modifier
 #undef WALK_NESTED_ENUM
 
+  void Unparse(const CUFKernelDoConstruct::StarOrExpr &x) {
+    if (x.v) {
+      Walk(*x.v);
+    } else {
+      Word("*");
+    }
+  }
   void Unparse(const CUFKernelDoConstruct::Directive &x) {
     Word("!$CUF KERNEL DO");
     Walk(" (", std::get<std::optional<ScalarIntConstantExpr>>(x.t), ")");
diff --git a/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf b/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
index db628fe756b952..c017561447f85d 100644
--- a/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
+++ b/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
@@ -42,10 +42,7 @@ subroutine sub1()
 ! CHECK: fir.cuda_kernel<<<%c1{{.*}}, (%c256{{.*}}, %c1{{.*}})>>> (%{{.*}} : index, %{{.*}} : index) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index)
 ! CHECK: {n = 2 : i64}
 
-! TODO: currently these trigger error in the parser
+! TODO: lowering for these cases
 ! !$cuf kernel do(2) <<< (1,*), (256,1) >>>
 ! !$cuf kernel do(2) <<< (*,*), (32,4) >>>
 end
-
-
-
diff --git a/flang/test/Parser/cuf-sanity-tree.CUF b/flang/test/Parser/cuf-sanity-tree.CUF
index f6cf9bbdd6b0cc..dc12759d3ce52f 100644
--- a/flang/test/Parser/cuf-sanity-tree.CUF
+++ b/flang/test/Parser/cuf-sanity-tree.CUF
@@ -144,11 +144,11 @@ include "cuf-sanity-common"
 !CHECK: | | | | | | EndDoStmt -> 
 !CHECK: | | | | ExecutionPartConstruct -> ExecutableConstruct -> CUFKernelDoConstruct
 !CHECK: | | | | | Directive
-!CHECK: | | | | | | Scalar -> Integer -> Expr = '1_4'
+!CHECK: | | | | | | StarOrExpr -> Scalar -> Integer -> Expr = '1_4'
 !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
-!CHECK: | | | | | | Scalar -> Integer -> Expr = '2_4'
+!CHECK: | | | | | | StarOrExpr -> Scalar -> Integer -> Expr = '2_4'
 !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '2'
-!CHECK: | | | | | | Scalar -> Integer -> Expr = '3_4'
+!CHECK: | | | | | | StarOrExpr -> Scalar -> Integer -> Expr = '3_4'
 !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '3'
 !CHECK: | | | | | | Scalar -> Integer -> Expr = '1_4'
 !CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
diff --git a/flang/test/Semantics/cuf09.cuf b/flang/test/Semantics/cuf09.cuf
index dd70c3b1ff5efd..4bc93132044fdd 100644
--- a/flang/test/Semantics/cuf09.cuf
+++ b/flang/test/Semantics/cuf09.cuf
@@ -10,6 +10,15 @@ module m
 end
 
 program main
+  !$cuf kernel do <<< *, * >>> ! ok
+  do j = 1, 0
+  end do
+  !$cuf kernel do <<< (*), (*) >>> ! ok
+  do j = 1, 0
+  end do
+  !$cuf kernel do <<< (1,*), (2,*) >>> ! ok
+  do j = 1, 0
+  end do
   !ERROR: !$CUF KERNEL DO (1) must be followed by a DO construct with tightly nested outer levels of counted DO loops
   !$cuf kernel do <<< 1, 2 >>>
   do while (.false.)



More information about the flang-commits mailing list