[flang-commits] [flang] [flang][cuda] Make launch configuration optional for cuf kernel (PR #115947)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Tue Nov 12 14:15:26 PST 2024
https://github.com/clementval created https://github.com/llvm/llvm-project/pull/115947
The launch configuration on the `CUF KERNEL DO` directive is optional, as shown in the example code at https://docs.nvidia.com/hpc-sdk/compilers/cuda-fortran-prog-guide/index.html#cfpg-cuda-fort-host-dev-code
This patch updates the parse tree, parser, unparser, and lowering to take this into account, and adds a parser test for a directive without a launch configuration.
From 3f5881d7d3532341b6b0195228dd17d3285d17b6 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 12 Nov 2024 14:12:24 -0800
Subject: [PATCH] [flang][cuda] Make launch configuration optional for cuf
kernel
---
flang/include/flang/Parser/dump-parse-tree.h | 1 +
flang/include/flang/Parser/parse-tree.h | 11 +++-
flang/lib/Lower/Bridge.cpp | 69 +++++++++++---------
flang/lib/Parser/executable-parsers.cpp | 10 ++-
flang/lib/Parser/unparse.cpp | 16 +++--
flang/test/Parser/cuf-sanity-common | 3 +
6 files changed, 67 insertions(+), 43 deletions(-)
diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 4bbf9777a54ccb..2fb863738d62d0 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -236,6 +236,7 @@ class ParseTreeDumper {
NODE(parser, CUFKernelDoConstruct)
NODE(CUFKernelDoConstruct, StarOrExpr)
NODE(CUFKernelDoConstruct, Directive)
+ NODE(CUFKernelDoConstruct, LaunchConfiguration)
NODE(parser, CUFReduction)
NODE(parser, CycleStmt)
NODE(parser, DataComponentDefStmt)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 5f5650304f9987..ce0b6167de9fc8 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -4527,12 +4527,17 @@ struct CUFReduction {
struct CUFKernelDoConstruct {
TUPLE_CLASS_BOILERPLATE(CUFKernelDoConstruct);
WRAPPER_CLASS(StarOrExpr, std::optional<ScalarIntExpr>);
+ struct LaunchConfiguration {
+ TUPLE_CLASS_BOILERPLATE(LaunchConfiguration);
+ std::tuple<std::list<StarOrExpr>, std::list<StarOrExpr>,
+ std::optional<ScalarIntExpr>>
+ t;
+ };
struct Directive {
TUPLE_CLASS_BOILERPLATE(Directive);
CharBlock source;
- std::tuple<std::optional<ScalarIntConstantExpr>, std::list<StarOrExpr>,
- std::list<StarOrExpr>, std::optional<ScalarIntExpr>,
- std::list<CUFReduction>>
+ std::tuple<std::optional<ScalarIntConstantExpr>,
+ std::optional<LaunchConfiguration>, std::list<CUFReduction>>
t;
};
std::tuple<Directive, std::optional<DoConstruct>> t;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 0e3011e73902da..da53edf7e734b0 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2862,14 +2862,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
if (nestedLoops > 1)
n = builder->getIntegerAttr(builder->getI64Type(), nestedLoops);
- const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &grid =
- std::get<1>(dir.t);
- const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &block =
- std::get<2>(dir.t);
- const std::optional<Fortran::parser::ScalarIntExpr> &stream =
- std::get<3>(dir.t);
+ const auto &launchConfig = std::get<std::optional<
+ Fortran::parser::CUFKernelDoConstruct::LaunchConfiguration>>(dir.t);
+
const std::list<Fortran::parser::CUFReduction> &cufreds =
- std::get<4>(dir.t);
+ std::get<2>(dir.t);
llvm::SmallVector<mlir::Value> reduceOperands;
llvm::SmallVector<mlir::Attribute> reduceAttrs;
@@ -2913,35 +2910,45 @@ class FirConverter : public Fortran::lower::AbstractConverter {
builder->createIntegerConstant(loc, builder->getI32Type(), 0);
llvm::SmallVector<mlir::Value> gridValues;
- if (!isOnlyStars(grid)) {
- for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
- grid) {
- if (expr.v) {
- gridValues.push_back(fir::getBase(
- genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
- } else {
- gridValues.push_back(zero);
+ llvm::SmallVector<mlir::Value> blockValues;
+ mlir::Value streamValue;
+
+ if (launchConfig) {
+ const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &grid =
+ std::get<0>(launchConfig->t);
+ const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr>
+ &block = std::get<1>(launchConfig->t);
+ const std::optional<Fortran::parser::ScalarIntExpr> &stream =
+ std::get<2>(launchConfig->t);
+ if (!isOnlyStars(grid)) {
+ for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
+ grid) {
+ if (expr.v) {
+ gridValues.push_back(fir::getBase(
+ genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
+ } else {
+ gridValues.push_back(zero);
+ }
}
}
- }
- llvm::SmallVector<mlir::Value> blockValues;
- if (!isOnlyStars(block)) {
- for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
- block) {
- if (expr.v) {
- blockValues.push_back(fir::getBase(
- genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
- } else {
- blockValues.push_back(zero);
+ if (!isOnlyStars(block)) {
+ for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
+ block) {
+ if (expr.v) {
+ blockValues.push_back(fir::getBase(
+ genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
+ } else {
+ blockValues.push_back(zero);
+ }
}
}
+
+ if (stream)
+ streamValue = builder->createConvert(
+ loc, builder->getI32Type(),
+ fir::getBase(
+ genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx)));
}
- mlir::Value streamValue;
- if (stream)
- streamValue = builder->createConvert(
- loc, builder->getI32Type(),
- fir::getBase(
- genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx)));
const auto &outerDoConstruct =
std::get<std::optional<Fortran::parser::DoConstruct>>(kernel.t);
diff --git a/flang/lib/Parser/executable-parsers.cpp b/flang/lib/Parser/executable-parsers.cpp
index 5057e89164c9f2..730165613d91db 100644
--- a/flang/lib/Parser/executable-parsers.cpp
+++ b/flang/lib/Parser/executable-parsers.cpp
@@ -563,11 +563,15 @@ TYPE_PARSER(("REDUCTION"_tok || "REDUCE"_tok) >>
parenthesized(construct<CUFReduction>(Parser<CUFReduction::Operator>{},
":" >> nonemptyList(scalar(variable)))))
+TYPE_PARSER("<<<" >>
+ construct<CUFKernelDoConstruct::LaunchConfiguration>(gridOrBlock,
+ "," >> gridOrBlock,
+ maybe((", 0 ,"_tok || ", STREAM ="_tok) >> scalarIntExpr) / ">>>"))
+
TYPE_PARSER(sourced(beginDirective >> "$CUF KERNEL DO"_tok >>
construct<CUFKernelDoConstruct::Directive>(
- maybe(parenthesized(scalarIntConstantExpr)), "<<<" >> gridOrBlock,
- "," >> gridOrBlock,
- maybe((", 0 ,"_tok || ", STREAM ="_tok) >> scalarIntExpr) / ">>>",
+ maybe(parenthesized(scalarIntConstantExpr)),
+ maybe(Parser<CUFKernelDoConstruct::LaunchConfiguration>{}),
many(Parser<CUFReduction>{}) / endDirective)))
TYPE_CONTEXT_PARSER("!$CUF KERNEL DO construct"_en_US,
extension<LanguageFeature::CUDA>(construct<CUFKernelDoConstruct>(
diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp
index 20022f8fa984ce..4b511da69832c5 100644
--- a/flang/lib/Parser/unparse.cpp
+++ b/flang/lib/Parser/unparse.cpp
@@ -2932,11 +2932,9 @@ class UnparseVisitor {
Word("*");
}
}
- void Unparse(const CUFKernelDoConstruct::Directive &x) {
- Word("!$CUF KERNEL DO");
- Walk(" (", std::get<std::optional<ScalarIntConstantExpr>>(x.t), ")");
+ void Unparse(const CUFKernelDoConstruct::LaunchConfiguration &x) {
Word(" <<<");
- const auto &grid{std::get<1>(x.t)};
+ const auto &grid{std::get<0>(x.t)};
if (grid.empty()) {
Word("*");
} else if (grid.size() == 1) {
@@ -2945,7 +2943,7 @@ class UnparseVisitor {
Walk("(", grid, ",", ")");
}
Word(",");
- const auto &block{std::get<2>(x.t)};
+ const auto &block{std::get<1>(x.t)};
if (block.empty()) {
Word("*");
} else if (block.size() == 1) {
@@ -2953,10 +2951,16 @@ class UnparseVisitor {
} else {
Walk("(", block, ",", ")");
}
- if (const auto &stream{std::get<3>(x.t)}) {
+ if (const auto &stream{std::get<2>(x.t)}) {
Word(",STREAM="), Walk(*stream);
}
Word(">>>");
+ }
+ void Unparse(const CUFKernelDoConstruct::Directive &x) {
+ Word("!$CUF KERNEL DO");
+ Walk(" (", std::get<std::optional<ScalarIntConstantExpr>>(x.t), ")");
+ Walk(std::get<std::optional<CUFKernelDoConstruct::LaunchConfiguration>>(
+ x.t));
Walk(" ", std::get<std::list<CUFReduction>>(x.t), " ");
Word("\n");
}
diff --git a/flang/test/Parser/cuf-sanity-common b/flang/test/Parser/cuf-sanity-common
index 7005ef07b22650..816e03bed7220a 100644
--- a/flang/test/Parser/cuf-sanity-common
+++ b/flang/test/Parser/cuf-sanity-common
@@ -31,6 +31,9 @@ module m
!$cuf kernel do <<<1, (2, 3), stream = 1>>>
do j = 1, 10
end do
+ !$cuf kernel do
+ do j = 1, 10
+ end do
!$cuf kernel do <<<*, *>>> reduce(+:x,y) reduce(*:z)
do j = 1, 10
x = x + a(j)
More information about the flang-commits
mailing list