[flang-commits] [flang] [flang][cuda] CUF kernel loop directive (PR #82836)

Valentin Clement バレンタイン クレメン via flang-commits flang-commits at lists.llvm.org
Tue Feb 27 09:11:41 PST 2024

https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/82836

>From 46b322d8e4b00124b5cf2a06c2ceada8a96679c6 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 15 Feb 2024 14:45:50 -0800
Subject: [PATCH] [flang][cuda] CUF kernel loop directive

This patch introduces a new operation to represent the CUDA Fortran
kernel loop directive. The operation implements LoopLikeOpInterface and
is modelled in a similar way to acc.loop.

The patch also implements lowering of the directive from the flang parse
tree to MLIR.
 flang/include/flang/Lower/PFTBuilder.h        |   5 +-
 .../include/flang/Optimizer/Dialect/FIROps.td |  27 ++++
 flang/lib/Lower/Bridge.cpp                    | 129 ++++++++++++++++++
 flang/lib/Optimizer/Dialect/FIROps.cpp        |  97 +++++++++++++
 .../Lower/CUDA/cuda-kernel-loop-directive.cuf |  51 +++++++
 5 files changed, 307 insertions(+), 2 deletions(-)
 create mode 100644 flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf

diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index c2b0fdbf357cde..9913f584133faa 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -138,7 +138,8 @@ using Directives =
     std::tuple<parser::CompilerDirective, parser::OpenACCConstruct,
                parser::OpenACCDeclarativeConstruct, parser::OpenMPConstruct,
-               parser::OpenMPDeclarativeConstruct, parser::OmpEndLoopDirective>;
+               parser::OpenMPDeclarativeConstruct, parser::OmpEndLoopDirective,
+               // CUDA Fortran kernel loop directive (`!$cuf kernel do`).
+               parser::CUFKernelDoConstruct>;
 using DeclConstructs = std::tuple<parser::OpenMPDeclarativeConstruct,
@@ -178,7 +179,7 @@ static constexpr bool isNopConstructStmt{common::HasMember<
 template <typename A>
 static constexpr bool isExecutableDirective{common::HasMember<
     A, std::tuple<parser::CompilerDirective, parser::OpenACCConstruct,
-                  parser::OpenMPConstruct>>};
+                  // CUFKernelDoConstruct is lowered by
+                  // genFIR(const parser::CUFKernelDoConstruct &) in Bridge.cpp.
+                  parser::OpenMPConstruct, parser::CUFKernelDoConstruct>>};
 template <typename A>
 static constexpr bool isFunctionLike{common::HasMember<
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 08239230f793f1..db5e5f4bc682e6 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -3127,4 +3127,31 @@ def fir_BoxOffsetOp : fir_Op<"box_offset", [NoMemoryEffect]> {
+def fir_CUDAKernelOp : fir_Op<"cuda_kernel", [AttrSizedOperandSegments,
+    DeclareOpInterfaceMethods<LoopLikeOpInterface>]> {
+  let summary = "CUDA Fortran kernel loop directive";
+  let description = [{
+    Represents the CUDA Fortran `!$cuf kernel do` directive together with the
+    loop nest it applies to.
+
+    `grid` and `block` hold the launch configuration values; an empty list
+    means `*`. `stream` is optional. `lowerbound`, `upperbound` and `step`
+    hold one `index` value per loop covered by the directive, and the
+    region's entry block takes one induction value per loop as an argument.
+    The optional `n` attribute records the number of covered loops; lowering
+    only sets it when more than one loop is covered.
+
+    Example:
+
+    ```mlir
+    fir.cuda_kernel<<<%c1_i32, %c2_i32>>> (%arg0 : index) = (%lb : index)
+        to (%ub : index) step (%step : index) {
+      ...
+    }
+    ```
+  }];
+  let arguments = (ins
+    Variadic<I32>:$grid, // empty means `*`
+    Variadic<I32>:$block, // empty means `*`
+    Optional<I32>:$stream,
+    Variadic<Index>:$lowerbound,
+    Variadic<Index>:$upperbound,
+    Variadic<Index>:$step,
+    OptionalAttr<I64Attr>:$n
+  );
+  let regions = (region AnyRegion:$region);
+  let assemblyFormat = [{
+    `<` `<` `<` custom<CUFKernelValues>($grid, type($grid)) `,`
+                custom<CUFKernelValues>($block, type($block))
+        ( `,` `stream` `=` $stream^ )? `>` `>` `>`
+        custom<CUFKernelLoopControl>($region, $lowerbound, type($lowerbound),
+            $upperbound, type($upperbound), $step, type($step))
+        attr-dict
+  }];
+  let hasVerifier = 1;
+}
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 2d7f748cefa2d8..c7b8cd96021ce2 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2453,6 +2453,135 @@ class FirConverter : public Fortran::lower::AbstractConverter {
     // Handled by genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &)
+  /// Lower a CUDA Fortran kernel loop directive (`!$cuf kernel do`) to a
+  /// fir.cuda_kernel operation.
+  ///
+  /// The `n` outermost loops named by the directive (1 when `n` is absent)
+  /// are folded into the operation. Their lower bounds, upper bounds and steps
+  /// become `index` operands. Each induction variable becomes an `index`
+  /// block argument, which is converted and stored into the Fortran variable
+  /// at the start of the region. The nested evaluations of the innermost
+  /// folded loop are then lowered inside the region, after those stores and
+  /// before the fir.end terminator.
+  void genFIR(const Fortran::parser::CUFKernelDoConstruct &kernel) {
+    localSymbols.pushScope();
+    const Fortran::parser::CUFKernelDoConstruct::Directive &dir =
+        std::get<Fortran::parser::CUFKernelDoConstruct::Directive>(kernel.t);
+    mlir::Location loc = genLocation(dir.source);
+    Fortran::lower::StatementContext stmtCtx;
+    // Number of loops covered by the directive: `!$cuf kernel do(n)`.
+    unsigned nestedLoops = 1;
+    const auto &nLoops =
+        std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(dir.t);
+    if (nLoops)
+      nestedLoops = *Fortran::semantics::GetIntValue(*nLoops);
+    // The `n` attribute is only attached when more than one loop is covered.
+    mlir::IntegerAttr n;
+    if (nestedLoops > 1)
+      n = builder->getIntegerAttr(builder->getI64Type(), nestedLoops);
+    // Launch configuration `<<< grid, block [, stream] >>>`. An empty grid or
+    // block list stands for `*`.
+    const std::list<Fortran::parser::ScalarIntExpr> &grid = std::get<1>(dir.t);
+    const std::list<Fortran::parser::ScalarIntExpr> &block = std::get<2>(dir.t);
+    const std::optional<Fortran::parser::ScalarIntExpr> &stream =
+        std::get<3>(dir.t);
+    llvm::SmallVector<mlir::Value> gridValues;
+    for (const Fortran::parser::ScalarIntExpr &expr : grid)
+      gridValues.push_back(fir::getBase(
+          genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+    llvm::SmallVector<mlir::Value> blockValues;
+    for (const Fortran::parser::ScalarIntExpr &expr : block)
+      blockValues.push_back(fir::getBase(
+          genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+    mlir::Value streamValue;
+    if (stream)
+      streamValue = fir::getBase(
+          genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx));
+    const auto &outerDoConstruct =
+        std::get<std::optional<Fortran::parser::DoConstruct>>(kernel.t);
+    assert(outerDoConstruct && "expect do construct after cuf kernel do");
+    llvm::SmallVector<mlir::Value> lbs, ubs, steps;
+    mlir::Type idxTy = builder->getIndexType();
+    llvm::SmallVector<mlir::Type> ivTypes;
+    llvm::SmallVector<mlir::Location> ivLocs;
+    llvm::SmallVector<mlir::Value> ivValues;
+    // Cursor walking down the nested DO constructs. It is declared outside
+    // the loop so that each iteration starts from the construct reached by
+    // the previous one. If it were re-initialized per iteration, every inner
+    // loop would reuse the outermost loop's control.
+    Fortran::lower::pft::Evaluation *loopEval =
+        &getEval().getFirstNestedEvaluation();
+    for (unsigned i = 0; i < nestedLoops; ++i) {
+      const Fortran::parser::LoopControl *loopControl;
+      mlir::Location crtLoc = loc;
+      if (i == 0) {
+        loopControl = &*outerDoConstruct->GetLoopControl();
+        crtLoc =
+            genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct));
+      } else {
+        auto *doCons = loopEval->getIf<Fortran::parser::DoConstruct>();
+        assert(doCons && "expect do construct");
+        loopControl = &*doCons->GetLoopControl();
+        crtLoc = genLocation(Fortran::parser::FindSourceLocation(*doCons));
+      }
+      const Fortran::parser::LoopControl::Bounds *bounds =
+          std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
+      assert(bounds && "Expected bounds on the loop construct");
+      Fortran::semantics::Symbol &ivSym =
+          bounds->name.thing.symbol->GetUltimate();
+      ivValues.push_back(getSymbolAddress(ivSym));
+      // fir.cuda_kernel requires `index` bounds and steps, so every value is
+      // converted, including a user-provided step.
+      lbs.push_back(builder->createConvert(
+          crtLoc, idxTy,
+          fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->lower),
+                                    stmtCtx))));
+      ubs.push_back(builder->createConvert(
+          crtLoc, idxTy,
+          fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->upper),
+                                    stmtCtx))));
+      if (bounds->step)
+        steps.push_back(builder->createConvert(
+            crtLoc, idxTy,
+            fir::getBase(genExprValue(
+                *Fortran::semantics::GetExpr(bounds->step), stmtCtx))));
+      else // If `step` is not present, assume it is `1`.
+        steps.push_back(builder->createIntegerConstant(crtLoc, idxTy, 1));
+      ivTypes.push_back(idxTy);
+      ivLocs.push_back(crtLoc);
+      // Step into the next nested DO construct: the second nested evaluation
+      // of a DO construct evaluation (the first being its DO statement).
+      if (i < nestedLoops - 1)
+        loopEval = &*std::next(loopEval->getNestedEvaluations().begin());
+    }
+    auto op = builder->create<fir::CUDAKernelOp>(
+        loc, gridValues, blockValues, streamValue, lbs, ubs, steps, n);
+    builder->createBlock(&op.getRegion(), op.getRegion().end(), ivTypes,
+                         ivLocs);
+    mlir::Block &b = op.getRegion().back();
+    builder->setInsertionPointToStart(&b);
+    // Store each `index` block argument into its Fortran induction variable.
+    for (auto [arg, value] : llvm::zip(
+             op.getLoopRegions().front()->front().getArguments(), ivValues)) {
+      mlir::Value convArg =
+          builder->createConvert(loc, fir::unwrapRefType(value.getType()), arg);
+      builder->create<fir::StoreOp>(loc, convArg, value);
+    }
+    // Insert the body right before the terminator, i.e. after the stores
+    // above, so that the body observes the stored induction values.
+    auto endOp = builder->create<fir::FirEndOp>(loc);
+    builder->setInsertionPoint(endOp.getOperation());
+    // Locate the innermost covered DO construct, using the same walk as above.
+    Fortran::lower::pft::Evaluation *crtEval = &getEval();
+    if (crtEval->lowerAsStructured()) {
+      crtEval = &crtEval->getFirstNestedEvaluation();
+      for (unsigned i = 1; i < nestedLoops; ++i)
+        crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
+    }
+    // Generate loop body
+    for (Fortran::lower::pft::Evaluation &e : crtEval->getNestedEvaluations())
+      genFIR(e);
+    builder->setInsertionPointAfter(op);
+    localSymbols.popScope();
+  }
   void genFIR(const Fortran::parser::OpenMPConstruct &omp) {
     mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
     genOpenMPConstruct(*this, localSymbols, bridge.getSemanticsContext(),
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 0a534cdb3c4871..bbf5c0c09101be 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3866,6 +3866,103 @@ mlir::LogicalResult fir::DeclareOp::verify() {
   return fortranVar.verifyDeclareLikeOpImpl(getMemref());
+/// LoopLikeOpInterface hook: the loop body is the op's single `region`.
+llvm::SmallVector<mlir::Region *> fir::CUDAKernelOp::getLoopRegions() {
+  return {&getRegion()};
+}
+
+/// Parse the grid or block values of a fir.cuda_kernel launch configuration.
+/// Accepted forms, mirroring printCUFKernelValues:
+///   `*`               - no values;
+///   `%v`              - a single value;
+///   `(%v0, %v1, ...)` - a parenthesized, comma-separated list.
+/// The assembly does not spell the types, so every parsed value is given
+/// the `i32` type required by the `grid` and `block` operands.
+mlir::ParseResult parseCUFKernelValues(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &values,
+    llvm::SmallVectorImpl<mlir::Type> &types) {
+  if (mlir::succeeded(parser.parseOptionalStar()))
+    return mlir::success();
+  // A ParseResult is `true` on failure, so test explicitly for success here.
+  if (mlir::succeeded(parser.parseOptionalLParen())) {
+    if (mlir::failed(parser.parseCommaSeparatedList(
+            mlir::AsmParser::Delimiter::None, [&]() -> mlir::ParseResult {
+              return parser.parseOperand(values.emplace_back());
+            })))
+      return mlir::failure();
+    if (parser.parseRParen())
+      return mlir::failure();
+  } else if (parser.parseOperand(values.emplace_back())) {
+    return mlir::failure();
+  }
+  // One type per operand is needed to resolve `type($grid)`/`type($block)`.
+  types.append(values.size(), parser.getBuilder().getI32Type());
+  return mlir::success();
+}
+
+/// Print the grid or block values of a fir.cuda_kernel launch configuration:
+/// `*` when there are none, the bare value when there is exactly one, and a
+/// parenthesized comma-separated list otherwise. Types are not printed.
+void printCUFKernelValues(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                          mlir::ValueRange values, mlir::TypeRange types) {
+  if (values.empty()) {
+    p << "*";
+    return;
+  }
+  const bool parenthesize = values.size() != 1;
+  if (parenthesize)
+    p << "(";
+  llvm::interleaveComma(values, p, [&p](mlir::Value val) { p << val; });
+  if (parenthesize)
+    p << ")";
+}
+
+/// Parse the loop control of a fir.cuda_kernel followed by its body region:
+///   `(%iv0 : type, ...) = (%lb0, ... : types) to (%ub0, ... : types)
+///    step (%st0, ... : types) { ... }`
+/// Each bound group must contain exactly one value per induction variable,
+/// and the induction variables become the entry block arguments of `region`.
+mlir::ParseResult parseCUFKernelLoopControl(
+    mlir::OpAsmParser &parser, mlir::Region &region,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &lowerbound,
+    llvm::SmallVectorImpl<mlir::Type> &lowerboundType,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &upperbound,
+    llvm::SmallVectorImpl<mlir::Type> &upperboundType,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &step,
+    llvm::SmallVectorImpl<mlir::Type> &stepType) {
+  llvm::SmallVector<mlir::OpAsmParser::Argument> ivs;
+  if (parser.parseLParen() ||
+      parser.parseArgumentList(ivs, mlir::OpAsmParser::Delimiter::None,
+                               /*allowType=*/true) ||
+      parser.parseRParen() || parser.parseEqual())
+    return mlir::failure();
+
+  // Parse one `( operands : types )` group with one value per induction
+  // variable.
+  auto parseGroup =
+      [&](llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+          llvm::SmallVectorImpl<mlir::Type> &types) -> mlir::ParseResult {
+    if (parser.parseLParen() ||
+        parser.parseOperandList(operands, ivs.size(),
+                                mlir::OpAsmParser::Delimiter::None) ||
+        parser.parseColonTypeList(types) || parser.parseRParen())
+      return mlir::failure();
+    return mlir::success();
+  };
+
+  if (parseGroup(lowerbound, lowerboundType) || parser.parseKeyword("to") ||
+      parseGroup(upperbound, upperboundType) ||
+      parser.parseKeyword("step") || parseGroup(step, stepType))
+    return mlir::failure();
+  return parser.parseRegion(region, ivs);
+}
+
+/// Print the loop control of a fir.cuda_kernel followed by its body region:
+///   `(%iv : type, ...) = (lbs : types) to (ubs : types) step (steps : types)`
+/// The induction variables are printed here, so the region is printed
+/// without its entry block arguments. The loop control is omitted when the
+/// region has no entry block or the entry block has no arguments.
+/// NOTE(review): parseCUFKernelLoopControl always expects the loop control,
+/// so an op without induction variables does not round-trip.
+void printCUFKernelLoopControl(
+    mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Region &region,
+    mlir::ValueRange lowerbound, mlir::TypeRange lowerboundType,
+    mlir::ValueRange upperbound, mlir::TypeRange upperboundType,
+    mlir::ValueRange steps, mlir::TypeRange stepType) {
+  // Check for an entry block first so printing an op with an empty region
+  // does not dereference a missing block.
+  if (!region.empty() && !region.front().args_empty()) {
+    mlir::ValueRange regionArgs = region.front().getArguments();
+    p << "(";
+    llvm::interleaveComma(
+        regionArgs, p, [&p](mlir::Value v) { p << v << " : " << v.getType(); });
+    p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
+      << upperbound << " : " << upperboundType << ") step (" << steps
+      << " : " << stepType << ") ";
+  }
+  p.printRegion(region, /*printEntryBlockArgs=*/false);
+}
+
+/// Verify the structural invariants of fir.cuda_kernel:
+///  - lowerbound, upperbound and step hold the same number of values;
+///  - the body region has an entry block with one argument per loop.
+mlir::LogicalResult fir::CUDAKernelOp::verify() {
+  if (getLowerbound().size() != getUpperbound().size() ||
+      getLowerbound().size() != getStep().size())
+    return emitOpError(
+        "expect same number of values in lowerbound, upperbound and step");
+  // Both the custom parser and the lowering create one entry block argument
+  // per bound, and the custom printer reads the induction variables from that
+  // block. Reject ops that break this.
+  mlir::Region &body = getRegion();
+  if (body.empty() ||
+      body.front().getNumArguments() != getLowerbound().size())
+    return emitOpError("expect one region argument per loop bound");
+  return mlir::success();
+}
+
 // FIROpsDialect
diff --git a/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf b/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
new file mode 100644
index 00000000000000..db628fe756b952
--- /dev/null
+++ b/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
@@ -0,0 +1,51 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+! Test lowering of CUDA kernel loop directive.
+subroutine sub1()
+  integer :: i, j
+  integer, parameter :: n = 100
+  real :: a(n), b(n)
+  real :: c(n,n), d(n,n)
+! CHECK-LABEL: func.func @_QPsub1()
+! CHECK: %[[IV:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  ! Single loop with scalar grid and block values; bounds are converted to
+  ! index and the block argument is stored back into `i`.
+  !$cuf kernel do <<< 1, 2 >>>
+  do i = 1, n
+    a(i) = a(i) * b(i)
+  end do
+! CHECK: %[[LB:.*]] = fir.convert %c1{{.*}} : (i32) -> index
+! CHECK: %[[UB:.*]] = fir.convert %c100{{.*}} : (i32) -> index
+! CHECK: %[[STEP:.*]] = arith.constant 1 : index
+! CHECK: fir.cuda_kernel<<<%c1_i32, %c2_i32>>> (%[[ARG0:.*]] : index) = (%[[LB]] : index) to (%[[UB]] : index)  step (%[[STEP]] : index)
+! CHECK-NOT: fir.do_loop
+! CHECK: %[[ARG0_I32:.*]] = fir.convert %[[ARG0]] : (index) -> i32
+! CHECK: fir.store %[[ARG0_I32]] to %[[IV]]#1 : !fir.ref<i32>
+  ! Grid and block given as `*` (printed as `*`, no operands).
+  !$cuf kernel do <<< *, * >>>
+  do i = 1, n
+    a(i) = a(i) * b(i)
+  end do
+! CHECK: fir.cuda_kernel<<<*, *>>> (%{{.*}} : index) = (%{{.*}} : index) to (%{{.*}} : index)  step (%{{.*}} : index)
+  ! Directive covering two nested loops, with a parenthesized block list;
+  ! the loop count is recorded in the `n` attribute.
+  !$cuf kernel do(2) <<< 1, (256,1) >>>
+  do i = 1, n
+    do j = 1, n
+      c(i,j) = c(i,j) * d(i,j)
+    end do
+  end do
+! CHECK: fir.cuda_kernel<<<%c1{{.*}}, (%c256{{.*}}, %c1{{.*}})>>> (%{{.*}} : index, %{{.*}} : index) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index)
+! CHECK: {n = 2 : i64}
+! TODO: these forms currently trigger an error in the parser.
+! !$cuf kernel do(2) <<< (1,*), (256,1) >>>
+! !$cuf kernel do(2) <<< (*,*), (32,4) >>>

More information about the flang-commits mailing list