[flang-commits] [flang] [flang][cuda] CUF kernel loop directive (PR #82836)

via flang-commits flang-commits at lists.llvm.org
Fri Feb 23 14:06:44 PST 2024


github-actions[bot] wrote:

<!--LLVM CODE FORMAT COMMENT: {clang-format}-->


:warning: The C/C++ code formatter, clang-format, found issues in your code. :warning:

<details>
<summary>
You can test this locally with the following command:
</summary>

``````````bash
git-clang-format --diff a73e9244621a6186859644012c295740465ad844 67b14f6215a8ef0776dba3e8f241f063bcb85372 -- flang/include/flang/Lower/PFTBuilder.h flang/lib/Lower/Bridge.cpp flang/lib/Optimizer/Dialect/FIROps.cpp
``````````

</details>

<details>
<summary>
View the diff from clang-format here.
</summary>

``````````diff
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 2c4825fafd..c7b8cd9602 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2463,7 +2463,7 @@ private:
     Fortran::lower::StatementContext stmtCtx;
 
     unsigned nestedLoops = 1;
-    
+
     const auto &nLoops =
         std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(dir.t);
     if (nLoops)
@@ -2475,17 +2475,21 @@ private:
 
     const std::list<Fortran::parser::ScalarIntExpr> &grid = std::get<1>(dir.t);
     const std::list<Fortran::parser::ScalarIntExpr> &block = std::get<2>(dir.t);
-    const std::optional<Fortran::parser::ScalarIntExpr> &stream = std::get<3>(dir.t);
+    const std::optional<Fortran::parser::ScalarIntExpr> &stream =
+        std::get<3>(dir.t);
 
     llvm::SmallVector<mlir::Value> gridValues;
     for (const Fortran::parser::ScalarIntExpr &expr : grid)
-      gridValues.push_back(fir::getBase(genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+      gridValues.push_back(fir::getBase(
+          genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
     llvm::SmallVector<mlir::Value> blockValues;
     for (const Fortran::parser::ScalarIntExpr &expr : block)
-      blockValues.push_back(fir::getBase(genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+      blockValues.push_back(fir::getBase(
+          genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
     mlir::Value streamValue;
     if (stream)
-      streamValue = fir::getBase(genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx));
+      streamValue = fir::getBase(
+          genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx));
 
     const auto &outerDoConstruct =
         std::get<std::optional<Fortran::parser::DoConstruct>>(kernel.t);
@@ -2501,12 +2505,14 @@ private:
     llvm::SmallVector<mlir::Value> ivValues;
     for (unsigned i = 0; i < nestedLoops; ++i) {
       const Fortran::parser::LoopControl *loopControl;
-      Fortran::lower::pft::Evaluation *loopEval = &getEval().getFirstNestedEvaluation();
+      Fortran::lower::pft::Evaluation *loopEval =
+          &getEval().getFirstNestedEvaluation();
 
       mlir::Location crtLoc = loc;
       if (i == 0) {
         loopControl = &*outerDoConstruct->GetLoopControl();
-        crtLoc = genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct));
+        crtLoc =
+            genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct));
       } else {
         auto *doCons = loopEval->getIf<Fortran::parser::DoConstruct>();
         assert(doCons && "expect do construct");
@@ -2524,18 +2530,20 @@ private:
           bounds->name.thing.symbol->GetUltimate();
       ivValues.push_back(getSymbolAddress(ivSym));
 
-      lbs.push_back(builder->createConvert(crtLoc, idxTy,
-          fir::getBase(genExprValue(
-          *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))));
-      ubs.push_back(builder->createConvert(crtLoc, idxTy,
-          fir::getBase(genExprValue(
-          *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))));
+      lbs.push_back(builder->createConvert(
+          crtLoc, idxTy,
+          fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->lower),
+                                    stmtCtx))));
+      ubs.push_back(builder->createConvert(
+          crtLoc, idxTy,
+          fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->upper),
+                                    stmtCtx))));
       if (bounds->step)
-        steps.push_back(fir::getBase(genExprValue(
-            *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
+        steps.push_back(fir::getBase(
+            genExprValue(*Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
       else // If `step` is not present, assume it is `1`.
         steps.push_back(builder->createIntegerConstant(loc, idxTy, 1));
-      
+
       ivTypes.push_back(idxTy);
       ivLocs.push_back(crtLoc);
       if (i < nestedLoops - 1)
@@ -2544,13 +2552,15 @@ private:
 
     auto op = builder->create<fir::CUDAKernelOp>(
         loc, gridValues, blockValues, streamValue, lbs, ubs, steps, n);
-    builder->createBlock(&op.getRegion(), op.getRegion().end(), ivTypes, ivLocs);
+    builder->createBlock(&op.getRegion(), op.getRegion().end(), ivTypes,
+                         ivLocs);
     mlir::Block &b = op.getRegion().back();
     builder->setInsertionPointToStart(&b);
 
     for (auto [arg, value] : llvm::zip(
-           op.getLoopRegions().front()->front().getArguments(), ivValues)) {
-      mlir::Value convArg = builder->createConvert(loc, fir::unwrapRefType(value.getType()), arg);
+             op.getLoopRegions().front()->front().getArguments(), ivValues)) {
+      mlir::Value convArg =
+          builder->createConvert(loc, fir::unwrapRefType(value.getType()), arg);
       builder->create<fir::StoreOp>(loc, convArg, value);
     }
 
@@ -2564,8 +2574,6 @@ private:
         crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
     }
 
-
-
     // Generate loop body
     for (Fortran::lower::pft::Evaluation &e : crtEval->getNestedEvaluations())
       genFIR(e);
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index c2facb5a00..9bb10a42a3 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3870,13 +3870,13 @@ llvm::SmallVector<mlir::Region *> fir::CUDAKernelOp::getLoopRegions() {
   return {&getRegion()};
 }
 
-mlir::ParseResult
-parseCUFKernelValues(mlir::OpAsmParser &parser,
-                     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &values,
-                     llvm::SmallVectorImpl<mlir::Type> &types) {
+mlir::ParseResult parseCUFKernelValues(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &values,
+    llvm::SmallVectorImpl<mlir::Type> &types) {
   if (mlir::succeeded(parser.parseOptionalStar()))
     return mlir::success();
-  
+
   if (parser.parseOptionalLParen()) {
     if (mlir::failed(parser.parseCommaSeparatedList(
             mlir::AsmParser::Delimiter::None, [&]() {
@@ -3902,14 +3902,13 @@ void printCUFKernelValues(mlir::OpAsmPrinter &p, mlir::Operation *op,
 
   if (values.size() > 1)
     p << "(";
-  llvm::interleaveComma(values, p,
-                        [&p](mlir::Value v) { p << v; });
+  llvm::interleaveComma(values, p, [&p](mlir::Value v) { p << v; });
   if (values.size() > 1)
     p << ")";
 }
 
-mlir::ParseResult
-parseCUFKernelLoopControl(mlir::OpAsmParser &parser, mlir::Region &region,
+mlir::ParseResult parseCUFKernelLoopControl(
+    mlir::OpAsmParser &parser, mlir::Region &region,
     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &lowerbound,
     llvm::SmallVectorImpl<mlir::Type> &lowerboundType,
     llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &upperbound,
@@ -3919,8 +3918,9 @@ parseCUFKernelLoopControl(mlir::OpAsmParser &parser, mlir::Region &region,
 
   llvm::SmallVector<mlir::OpAsmParser::Argument> inductionVars;
   if (parser.parseLParen() ||
-      parser.parseArgumentList(inductionVars, mlir::OpAsmParser::Delimiter::None,
-                                /*allowType=*/true) ||
+      parser.parseArgumentList(inductionVars,
+                               mlir::OpAsmParser::Delimiter::None,
+                               /*allowType=*/true) ||
       parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
       parser.parseOperandList(lowerbound, inductionVars.size(),
                               mlir::OpAsmParser::Delimiter::None) ||
@@ -3937,17 +3937,16 @@ parseCUFKernelLoopControl(mlir::OpAsmParser &parser, mlir::Region &region,
   return parser.parseRegion(region, inductionVars);
 }
 
-void printCUFKernelLoopControl(mlir::OpAsmPrinter &p, mlir::Operation *op,
-                      mlir::Region &region, mlir::ValueRange lowerbound,
-                      mlir::TypeRange lowerboundType,
-                      mlir::ValueRange upperbound,
-                      mlir::TypeRange upperboundType, mlir::ValueRange steps,
-                      mlir::TypeRange stepType) {
+void printCUFKernelLoopControl(
+    mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Region &region,
+    mlir::ValueRange lowerbound, mlir::TypeRange lowerboundType,
+    mlir::ValueRange upperbound, mlir::TypeRange upperboundType,
+    mlir::ValueRange steps, mlir::TypeRange stepType) {
   mlir::ValueRange regionArgs = region.front().getArguments();
   if (!regionArgs.empty()) {
     p << "(";
-    llvm::interleaveComma(regionArgs, p,
-                          [&p](mlir::Value v) { p << v << " : " << v.getType(); });
+    llvm::interleaveComma(
+        regionArgs, p, [&p](mlir::Value v) { p << v << " : " << v.getType(); });
     p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
       << upperbound << " : " << upperboundType << ") "
       << " step (" << steps << " : " << stepType << ") ";

``````````

</details>


https://github.com/llvm/llvm-project/pull/82836


More information about the flang-commits mailing list