[flang-commits] [flang] [flang][cuda] Make launch configuration optional for cuf kernel (PR #115947)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Tue Nov 12 14:15:26 PST 2024


https://github.com/clementval created https://github.com/llvm/llvm-project/pull/115947

The launch configuration on the `CUF KERNEL DO` directive is optional, as shown in the example code at https://docs.nvidia.com/hpc-sdk/compilers/cuda-fortran-prog-guide/index.html#cfpg-cuda-fort-host-dev-code — for example, `!$cuf kernel do` may appear without any `<<<grid, block>>>` clause.

This patch updates the parse tree, parser, unparser, and lowering to take this into account, and adds a parser test for a directive without a launch configuration.

>From 3f5881d7d3532341b6b0195228dd17d3285d17b6 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Tue, 12 Nov 2024 14:12:24 -0800
Subject: [PATCH] [flang][cuda] Make launch configuration optional for cuf
 kernel

---
 flang/include/flang/Parser/dump-parse-tree.h |  1 +
 flang/include/flang/Parser/parse-tree.h      | 11 +++-
 flang/lib/Lower/Bridge.cpp                   | 69 +++++++++++---------
 flang/lib/Parser/executable-parsers.cpp      | 10 ++-
 flang/lib/Parser/unparse.cpp                 | 16 +++--
 flang/test/Parser/cuf-sanity-common          |  3 +
 6 files changed, 67 insertions(+), 43 deletions(-)

diff --git a/flang/include/flang/Parser/dump-parse-tree.h b/flang/include/flang/Parser/dump-parse-tree.h
index 4bbf9777a54ccb..2fb863738d62d0 100644
--- a/flang/include/flang/Parser/dump-parse-tree.h
+++ b/flang/include/flang/Parser/dump-parse-tree.h
@@ -236,6 +236,7 @@ class ParseTreeDumper {
   NODE(parser, CUFKernelDoConstruct)
   NODE(CUFKernelDoConstruct, StarOrExpr)
   NODE(CUFKernelDoConstruct, Directive)
+  NODE(CUFKernelDoConstruct, LaunchConfiguration)
   NODE(parser, CUFReduction)
   NODE(parser, CycleStmt)
   NODE(parser, DataComponentDefStmt)
diff --git a/flang/include/flang/Parser/parse-tree.h b/flang/include/flang/Parser/parse-tree.h
index 5f5650304f9987..ce0b6167de9fc8 100644
--- a/flang/include/flang/Parser/parse-tree.h
+++ b/flang/include/flang/Parser/parse-tree.h
@@ -4527,12 +4527,17 @@ struct CUFReduction {
 struct CUFKernelDoConstruct {
   TUPLE_CLASS_BOILERPLATE(CUFKernelDoConstruct);
   WRAPPER_CLASS(StarOrExpr, std::optional<ScalarIntExpr>);
+  struct LaunchConfiguration {
+    TUPLE_CLASS_BOILERPLATE(LaunchConfiguration);
+    std::tuple<std::list<StarOrExpr>, std::list<StarOrExpr>,
+        std::optional<ScalarIntExpr>>
+        t;
+  };
   struct Directive {
     TUPLE_CLASS_BOILERPLATE(Directive);
     CharBlock source;
-    std::tuple<std::optional<ScalarIntConstantExpr>, std::list<StarOrExpr>,
-        std::list<StarOrExpr>, std::optional<ScalarIntExpr>,
-        std::list<CUFReduction>>
+    std::tuple<std::optional<ScalarIntConstantExpr>,
+        std::optional<LaunchConfiguration>, std::list<CUFReduction>>
         t;
   };
   std::tuple<Directive, std::optional<DoConstruct>> t;
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 0e3011e73902da..da53edf7e734b0 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2862,14 +2862,11 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     if (nestedLoops > 1)
       n = builder->getIntegerAttr(builder->getI64Type(), nestedLoops);
 
-    const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &grid =
-        std::get<1>(dir.t);
-    const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &block =
-        std::get<2>(dir.t);
-    const std::optional<Fortran::parser::ScalarIntExpr> &stream =
-        std::get<3>(dir.t);
+    const auto &launchConfig = std::get<std::optional<
+        Fortran::parser::CUFKernelDoConstruct::LaunchConfiguration>>(dir.t);
+
     const std::list<Fortran::parser::CUFReduction> &cufreds =
-        std::get<4>(dir.t);
+        std::get<2>(dir.t);
 
     llvm::SmallVector<mlir::Value> reduceOperands;
     llvm::SmallVector<mlir::Attribute> reduceAttrs;
@@ -2913,35 +2910,45 @@ class FirConverter : public Fortran::lower::AbstractConverter {
         builder->createIntegerConstant(loc, builder->getI32Type(), 0);
 
     llvm::SmallVector<mlir::Value> gridValues;
-    if (!isOnlyStars(grid)) {
-      for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
-           grid) {
-        if (expr.v) {
-          gridValues.push_back(fir::getBase(
-              genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
-        } else {
-          gridValues.push_back(zero);
+    llvm::SmallVector<mlir::Value> blockValues;
+    mlir::Value streamValue;
+
+    if (launchConfig) {
+      const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr> &grid =
+          std::get<0>(launchConfig->t);
+      const std::list<Fortran::parser::CUFKernelDoConstruct::StarOrExpr>
+          &block = std::get<1>(launchConfig->t);
+      const std::optional<Fortran::parser::ScalarIntExpr> &stream =
+          std::get<2>(launchConfig->t);
+      if (!isOnlyStars(grid)) {
+        for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
+             grid) {
+          if (expr.v) {
+            gridValues.push_back(fir::getBase(
+                genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
+          } else {
+            gridValues.push_back(zero);
+          }
         }
       }
-    }
-    llvm::SmallVector<mlir::Value> blockValues;
-    if (!isOnlyStars(block)) {
-      for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
-           block) {
-        if (expr.v) {
-          blockValues.push_back(fir::getBase(
-              genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
-        } else {
-          blockValues.push_back(zero);
+      if (!isOnlyStars(block)) {
+        for (const Fortran::parser::CUFKernelDoConstruct::StarOrExpr &expr :
+             block) {
+          if (expr.v) {
+            blockValues.push_back(fir::getBase(
+                genExprValue(*Fortran::semantics::GetExpr(*expr.v), stmtCtx)));
+          } else {
+            blockValues.push_back(zero);
+          }
         }
       }
+
+      if (stream)
+        streamValue = builder->createConvert(
+            loc, builder->getI32Type(),
+            fir::getBase(
+                genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx)));
     }
-    mlir::Value streamValue;
-    if (stream)
-      streamValue = builder->createConvert(
-          loc, builder->getI32Type(),
-          fir::getBase(
-              genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx)));
 
     const auto &outerDoConstruct =
         std::get<std::optional<Fortran::parser::DoConstruct>>(kernel.t);
diff --git a/flang/lib/Parser/executable-parsers.cpp b/flang/lib/Parser/executable-parsers.cpp
index 5057e89164c9f2..730165613d91db 100644
--- a/flang/lib/Parser/executable-parsers.cpp
+++ b/flang/lib/Parser/executable-parsers.cpp
@@ -563,11 +563,15 @@ TYPE_PARSER(("REDUCTION"_tok || "REDUCE"_tok) >>
     parenthesized(construct<CUFReduction>(Parser<CUFReduction::Operator>{},
         ":" >> nonemptyList(scalar(variable)))))
 
+TYPE_PARSER("<<<" >>
+    construct<CUFKernelDoConstruct::LaunchConfiguration>(gridOrBlock,
+        "," >> gridOrBlock,
+        maybe((", 0 ,"_tok || ", STREAM ="_tok) >> scalarIntExpr) / ">>>"))
+
 TYPE_PARSER(sourced(beginDirective >> "$CUF KERNEL DO"_tok >>
     construct<CUFKernelDoConstruct::Directive>(
-        maybe(parenthesized(scalarIntConstantExpr)), "<<<" >> gridOrBlock,
-        "," >> gridOrBlock,
-        maybe((", 0 ,"_tok || ", STREAM ="_tok) >> scalarIntExpr) / ">>>",
+        maybe(parenthesized(scalarIntConstantExpr)),
+        maybe(Parser<CUFKernelDoConstruct::LaunchConfiguration>{}),
         many(Parser<CUFReduction>{}) / endDirective)))
 TYPE_CONTEXT_PARSER("!$CUF KERNEL DO construct"_en_US,
     extension<LanguageFeature::CUDA>(construct<CUFKernelDoConstruct>(
diff --git a/flang/lib/Parser/unparse.cpp b/flang/lib/Parser/unparse.cpp
index 20022f8fa984ce..4b511da69832c5 100644
--- a/flang/lib/Parser/unparse.cpp
+++ b/flang/lib/Parser/unparse.cpp
@@ -2932,11 +2932,9 @@ class UnparseVisitor {
       Word("*");
     }
   }
-  void Unparse(const CUFKernelDoConstruct::Directive &x) {
-    Word("!$CUF KERNEL DO");
-    Walk(" (", std::get<std::optional<ScalarIntConstantExpr>>(x.t), ")");
+  void Unparse(const CUFKernelDoConstruct::LaunchConfiguration &x) {
     Word(" <<<");
-    const auto &grid{std::get<1>(x.t)};
+    const auto &grid{std::get<0>(x.t)};
     if (grid.empty()) {
       Word("*");
     } else if (grid.size() == 1) {
@@ -2945,7 +2943,7 @@ class UnparseVisitor {
       Walk("(", grid, ",", ")");
     }
     Word(",");
-    const auto &block{std::get<2>(x.t)};
+    const auto &block{std::get<1>(x.t)};
     if (block.empty()) {
       Word("*");
     } else if (block.size() == 1) {
@@ -2953,10 +2951,16 @@ class UnparseVisitor {
     } else {
       Walk("(", block, ",", ")");
     }
-    if (const auto &stream{std::get<3>(x.t)}) {
+    if (const auto &stream{std::get<2>(x.t)}) {
       Word(",STREAM="), Walk(*stream);
     }
     Word(">>>");
+  }
+  void Unparse(const CUFKernelDoConstruct::Directive &x) {
+    Word("!$CUF KERNEL DO");
+    Walk(" (", std::get<std::optional<ScalarIntConstantExpr>>(x.t), ")");
+    Walk(std::get<std::optional<CUFKernelDoConstruct::LaunchConfiguration>>(
+        x.t));
     Walk(" ", std::get<std::list<CUFReduction>>(x.t), " ");
     Word("\n");
   }
diff --git a/flang/test/Parser/cuf-sanity-common b/flang/test/Parser/cuf-sanity-common
index 7005ef07b22650..816e03bed7220a 100644
--- a/flang/test/Parser/cuf-sanity-common
+++ b/flang/test/Parser/cuf-sanity-common
@@ -31,6 +31,9 @@ module m
     !$cuf kernel do <<<1, (2, 3), stream = 1>>>
     do j = 1, 10
     end do
+    !$cuf kernel do
+    do j = 1, 10
+    end do
     !$cuf kernel do <<<*, *>>> reduce(+:x,y) reduce(*:z)
     do j = 1, 10
       x = x + a(j)



More information about the flang-commits mailing list