[flang-commits] [flang] [flang] Parse !$CUF KERNEL DO <<< (*) (PR #85338)
Peter Klausler via flang-commits
flang-commits at lists.llvm.org
Thu Mar 14 16:51:58 PDT 2024
https://github.com/klausler created https://github.com/llvm/llvm-project/pull/85338
Accept and represent asterisks within the parenthesized grid and block specification lists.
From e974b92346cdf07eb88ed1fbc3434c0c3ac1dab6 Mon Sep 17 00:00:00 2001
From: Peter Klausler <pklausler at nvidia.com>
Date: Thu, 14 Mar 2024 15:43:40 -0700
Subject: [PATCH] [flang] Parse !$CUF KERNEL DO <<< (*)
Accept and represent asterisks within the parenthesized grid
and block specification lists.
---
flang/include/flang/Parser/dump-parse-tree.h | 1 +
flang/include/flang/Parser/parse-tree.h | 10 ++++---
flang/lib/Lower/Bridge.cpp | 29 ++++++++++++++-----
flang/lib/Parser/executable-parsers.cpp | 20 ++++++-------
flang/lib/Parser/misc-parsers.h | 5 ++++
flang/lib/Parser/unparse.cpp | 7 +++++
.../Lower/CUDA/cuda-kernel-loop-directive.cuf | 5 +---
flang/test/Parser/cuf-sanity-tree.CUF | 6 ++--
flang/test/Semantics/cuf09.cuf | 9 ++++++
9 files changed, 63 insertions(+), 29 deletions(-)
diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 048008a8d80c79..b2c3d92909375c 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -233,6 +233,7 @@ class ParseTreeDumper {
NODE(parser, CriticalStmt)
NODE(parser, CUDAAttributesStmt)
NODE(parser, CUFKernelDoConstruct)
NODE(CUFKernelDoConstruct, Directive)
+ NODE(CUFKernelDoConstruct, StarOrExpr)
NODE(parser, CycleStmt)
NODE(parser, DataComponentDefStmt)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index f7b72c3af09164..c96abfba491d4b 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -4297,16 +4297,20 @@ struct OpenACCConstruct {
// CUF-kernel-do-construct ->
// !$CUF KERNEL DO [ (scalar-int-constant-expr) ] <<< grid, block [, stream]
// >>> do-construct
-// grid -> * | scalar-int-expr | ( scalar-int-expr-list )
-// block -> * | scalar-int-expr | ( scalar-int-expr-list )
+// star-or-expr -> * | scalar-int-expr
+// grid -> * | scalar-int-expr | ( star-or-expr-list )
+// block -> * | scalar-int-expr | ( star-or-expr-list )
// stream -> 0, scalar-int-expr | STREAM = scalar-int-expr
struct CUFKernelDoConstruct {
TUPLE_CLASS_BOILERPLATE(CUFKernelDoConstruct);
+ // An absent expression represents '*'. A grid or block written without
+ // parentheses (including a bare '*') is stored as a one-element list.
+ WRAPPER_CLASS(StarOrExpr, std::optional<ScalarIntExpr>);
struct Directive {
TUPLE_CLASS_BOILERPLATE(Directive);
CharBlock source;
- std::tuple<std::optional<ScalarIntConstantExpr>, std::list<ScalarIntExpr>,
- std::list<ScalarIntExpr>, std::optional<ScalarIntExpr>>
+ std::tuple<std::optional<ScalarIntConstantExpr>, std::list<StarOrExpr>,
+ std::list<StarOrExpr>, std::optional<ScalarIntExpr>>
t;
};
std::tuple<Directive, std::optional<DoConstruct>> t;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index a668ba4116faab..e6511e0c61c8c3 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2508,19 +2508,41 @@ class FirConverter : public Fortran::lower::AbstractConverter {
if (nestedLoops > 1)
n = builder->getIntegerAttr(builder->getI64Type(), nestedLoops);
- const std::list<Fortran::parser::ScalarIntExpr> &grid = std::get<1>(dir.t);
- const std::list<Fortran::parser::ScalarIntExpr> &block = std::get<2>(dir.t);
+ const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &grid =
+ std::get<1>(dir.t);
+ const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &block =
+ std::get<2>(dir.t);
const std::optional<Fortran::parser::ScalarIntExpr> &stream =
std::get<3>(dir.t);
+ // Lowers one grid or block list into `values`. An absent StarOrExpr
+ // expression denotes '*'. A list of nothing but '*' lowers to no values,
+ // exactly as a bare '*' did before '*' could appear inside a list. A list
+ // mixing '*' with expressions is not supported yet; report it rather than
+ // silently dropping the '*' entries and changing the launch shape.
+ auto genStarOrExprList =
+ [&](const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr>
+ &list,
+ llvm::SmallVector<mlir::Value> &values) {
+ std::size_t stars{0};
+ for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
+ list)
+ if (!expr.v)
+ ++stars;
+ if (stars == list.size())
+ return;
+ if (stars != 0)
+ TODO(toLocation(),
+ "'*' mixed with expressions in !$CUF KERNEL DO grid or block");
+ for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
+ list)
+ values.push_back(fir::getBase(
+ genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
+ };
llvm::SmallVector<mlir::Value> gridValues;
- for (const Fortran::parser::ScalarIntExpr &expr : grid)
- gridValues.push_back(fir::getBase(
- genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+ genStarOrExprList(grid, gridValues);
llvm::SmallVector<mlir::Value> blockValues;
- for (const Fortran::parser::ScalarIntExpr &expr : block)
- blockValues.push_back(fir::getBase(
- genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+ genStarOrExprList(block, blockValues);
mlir::Value streamValue;
if (stream)
streamValue = fir::getBase(
diff --git a/flang/lib/Parser/executable-parsers.cpp b/flang/lib/Parser/executable-parsers.cpp
index de2be017508c37..6f258f131c3c22 100644
--- a/flang/lib/Parser/executable-parsers.cpp
+++ b/flang/lib/Parser/executable-parsers.cpp
@@ -542,19 +542,21 @@ TYPE_CONTEXT_PARSER("UNLOCK statement"_en_US,
// CUF-kernel-do-directive ->
// !$CUF KERNEL DO [ (scalar-int-constant-expr) ] <<< grid, block [, stream]
// >>> do-construct
-// grid -> * | scalar-int-expr | ( scalar-int-expr-list )
-// block -> * | scalar-int-expr | ( scalar-int-expr-list )
+// star-or-expr -> * | scalar-int-expr
+// grid -> * | scalar-int-expr | ( star-or-expr-list )
+// block -> * | scalar-int-expr | ( star-or-expr-list )
// stream -> ( 0, | STREAM = ) scalar-int-expr
+// A '*' is represented by a StarOrExpr whose expression is absent.
+// starOrExpr must be constexpr because the constexpr gridOrBlock copies it.
+constexpr auto starOrExpr{construct<CUFKernelDoConstruct::StarOrExpr>(
+ "*" >> pure<std::optional<ScalarIntExpr>>() ||
+ applyFunction(presentOptional<ScalarIntExpr>, scalarIntExpr))};
+constexpr auto gridOrBlock{parenthesized(nonemptyList(starOrExpr)) ||
+ applyFunction(singletonList<CUFKernelDoConstruct::StarOrExpr>, starOrExpr)};
TYPE_PARSER(sourced(beginDirective >> "$CUF KERNEL DO"_tok >>
construct<CUFKernelDoConstruct::Directive>(
- maybe(parenthesized(scalarIntConstantExpr)),
- "<<<" >>
- ("*" >> pure<std::list<ScalarIntExpr>>() ||
- parenthesized(nonemptyList(scalarIntExpr)) ||
- applyFunction(singletonList<ScalarIntExpr>, scalarIntExpr)),
- "," >> ("*" >> pure<std::list<ScalarIntExpr>>() ||
- parenthesized(nonemptyList(scalarIntExpr)) ||
- applyFunction(singletonList<ScalarIntExpr>, scalarIntExpr)),
+ maybe(parenthesized(scalarIntConstantExpr)), "<<<" >> gridOrBlock,
+ "," >> gridOrBlock,
maybe((", 0 ,"_tok || ", STREAM ="_tok) >> scalarIntExpr) / ">>>" /
endDirective)))
TYPE_CONTEXT_PARSER("!$CUF KERNEL DO construct"_en_US,
diff --git a/flang/lib/Parser/misc-parsers.h b/flang/lib/Parser/misc-parsers.h
index e9b52b7d0fcd0f..4a318e05bb4b8c 100644
--- a/flang/lib/Parser/misc-parsers.h
+++ b/flang/lib/Parser/misc-parsers.h
@@ -57,5 +57,11 @@ template <typename A> common::IfNoLvalue<std::list<A>, A> singletonList(A &&x) {
result.emplace_back(std::move(x));
return result;
}
+
+// Moves a parsed value into an engaged std::optional<A>; with applyFunction()
+// this adapts a parser of A into a parser of std::optional<A>.
+template <typename A>
+common::IfNoLvalue<std::optional<A>, A> presentOptional(A &&x) {
+ return std::optional<A>{std::move(x)};
+}
} // namespace Fortran::parser
#endif
diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp
index 600aa01999dab7..baba4863f5775f 100644
--- a/flang/lib/Parser/unparse.cpp
+++ b/flang/lib/Parser/unparse.cpp
@@ -2729,6 +2729,14 @@ class UnparseVisitor {
WALK_NESTED_ENUM(OmpOrderModifier, Kind) // OMP order-modifier
#undef WALK_NESTED_ENUM
+ // Emits '*' when no expression is present, otherwise the expression itself.
+ void Unparse(const CUFKernelDoConstruct::StarOrExpr &x) {
+ if (!x.v) {
+ Word("*");
+ return;
+ }
+ Walk(*x.v);
+ }
void Unparse(const CUFKernelDoConstruct::Directive &x) {
Word("!$CUF KERNEL DO");
Walk(" (", std::get<std::optional<ScalarIntConstantExpr>>(x.t), ")");
diff --git a/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf b/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
index db628fe756b952..c017561447f85d 100644
--- a/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
+++ b/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
@@ -42,10 +42,7 @@ subroutine sub1()
! CHECK: fir.cuda_kernel<<<%c1{{.*}}, (%c256{{.*}}, %c1{{.*}})>>> (%{{.*}} : index, %{{.*}} : index) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index)
! CHECK: {n = 2 : i64}
-! TODO: currently these trigger error in the parser
+! TODO: lowering of '*' within parenthesized grid and block lists
! !$cuf kernel do(2) <<< (1,*), (256,1) >>>
! !$cuf kernel do(2) <<< (*,*), (32,4) >>>
end
-
-
-
diff --git a/flang/test/Parser/cuf-sanity-tree.CUF b/flang/test/Parser/cuf-sanity-tree.CUF
index f6cf9bbdd6b0cc..dc12759d3ce52f 100644
--- a/flang/test/Parser/cuf-sanity-tree.CUF
+++ b/flang/test/Parser/cuf-sanity-tree.CUF
@@ -144,11 +144,11 @@ include "cuf-sanity-common"
!CHECK: | | | | | | EndDoStmt ->
!CHECK: | | | | ExecutionPartConstruct -> ExecutableConstruct -> CUFKernelDoConstruct
!CHECK: | | | | | Directive
-!CHECK: | | | | | | Scalar -> Integer -> Expr = '1_4'
+!CHECK: | | | | | | StarOrExpr -> Scalar -> Integer -> Expr = '1_4'
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
-!CHECK: | | | | | | Scalar -> Integer -> Expr = '2_4'
+!CHECK: | | | | | | StarOrExpr -> Scalar -> Integer -> Expr = '2_4'
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '2'
-!CHECK: | | | | | | Scalar -> Integer -> Expr = '3_4'
+!CHECK: | | | | | | StarOrExpr -> Scalar -> Integer -> Expr = '3_4'
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '3'
!CHECK: | | | | | | Scalar -> Integer -> Expr = '1_4'
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
diff --git a/flang/test/Semantics/cuf09.cuf b/flang/test/Semantics/cuf09.cuf
index dd70c3b1ff5efd..4bc93132044fdd 100644
--- a/flang/test/Semantics/cuf09.cuf
+++ b/flang/test/Semantics/cuf09.cuf
@@ -10,6 +10,15 @@ module m
end
program main
+ !$cuf kernel do <<< *, * >>> ! ok: bare '*' for grid and block
+ do j = 1, 0
+ end do
+ !$cuf kernel do <<< (*), (*) >>> ! ok: parenthesized '*'
+ do j = 1, 0
+ end do
+ !$cuf kernel do <<< (1,*), (2,*) >>> ! ok: '*' mixed with expressions
+ do j = 1, 0
+ end do
!ERROR: !$CUF KERNEL DO (1) must be followed by a DO construct with tightly nested outer levels of counted DO loops
!$cuf kernel do <<< 1, 2 >>>
do while (.false.)
More information about the flang-commits
mailing list