[flang-commits] [flang] [flang][cuda] CUF kernel loop directive (PR #82836)
via flang-commits
flang-commits at lists.llvm.org
Fri Feb 23 14:06:44 PST 2024
github-actions[bot] wrote:
<!--LLVM CODE FORMAT COMMENT: {clang-format}-->
:warning: C/C++ code formatter, clang-format found issues in your code. :warning:
<details>
<summary>
You can test this locally with the following command:
</summary>
``````````bash
git-clang-format --diff a73e9244621a6186859644012c295740465ad844 67b14f6215a8ef0776dba3e8f241f063bcb85372 -- flang/include/flang/Lower/PFTBuilder.h flang/lib/Lower/Bridge.cpp flang/lib/Optimizer/Dialect/FIROps.cpp
``````````
</details>
<details>
<summary>
View the diff from clang-format here.
</summary>
``````````diff
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 2c4825fafd..c7b8cd9602 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2463,7 +2463,7 @@ private:
Fortran::lower::StatementContext stmtCtx;
unsigned nestedLoops = 1;
-
+
const auto &nLoops =
std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(dir.t);
if (nLoops)
@@ -2475,17 +2475,21 @@ private:
const std::list<Fortran::parser::ScalarIntExpr> &grid = std::get<1>(dir.t);
const std::list<Fortran::parser::ScalarIntExpr> &block = std::get<2>(dir.t);
- const std::optional<Fortran::parser::ScalarIntExpr> &stream = std::get<3>(dir.t);
+ const std::optional<Fortran::parser::ScalarIntExpr> &stream =
+ std::get<3>(dir.t);
llvm::SmallVector<mlir::Value> gridValues;
for (const Fortran::parser::ScalarIntExpr &expr : grid)
- gridValues.push_back(fir::getBase(genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+ gridValues.push_back(fir::getBase(
+ genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
llvm::SmallVector<mlir::Value> blockValues;
for (const Fortran::parser::ScalarIntExpr &expr : block)
- blockValues.push_back(fir::getBase(genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+ blockValues.push_back(fir::getBase(
+ genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
mlir::Value streamValue;
if (stream)
- streamValue = fir::getBase(genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx));
+ streamValue = fir::getBase(
+ genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx));
const auto &outerDoConstruct =
std::get<std::optional<Fortran::parser::DoConstruct>>(kernel.t);
@@ -2501,12 +2505,14 @@ private:
llvm::SmallVector<mlir::Value> ivValues;
for (unsigned i = 0; i < nestedLoops; ++i) {
const Fortran::parser::LoopControl *loopControl;
- Fortran::lower::pft::Evaluation *loopEval = &getEval().getFirstNestedEvaluation();
+ Fortran::lower::pft::Evaluation *loopEval =
+ &getEval().getFirstNestedEvaluation();
mlir::Location crtLoc = loc;
if (i == 0) {
loopControl = &*outerDoConstruct->GetLoopControl();
- crtLoc = genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct));
+ crtLoc =
+ genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct));
} else {
auto *doCons = loopEval->getIf<Fortran::parser::DoConstruct>();
assert(doCons && "expect do construct");
@@ -2524,18 +2530,20 @@ private:
bounds->name.thing.symbol->GetUltimate();
ivValues.push_back(getSymbolAddress(ivSym));
- lbs.push_back(builder->createConvert(crtLoc, idxTy,
- fir::getBase(genExprValue(
- *Fortran::semantics::GetExpr(bounds->lower), stmtCtx))));
- ubs.push_back(builder->createConvert(crtLoc, idxTy,
- fir::getBase(genExprValue(
- *Fortran::semantics::GetExpr(bounds->upper), stmtCtx))));
+ lbs.push_back(builder->createConvert(
+ crtLoc, idxTy,
+ fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->lower),
+ stmtCtx))));
+ ubs.push_back(builder->createConvert(
+ crtLoc, idxTy,
+ fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->upper),
+ stmtCtx))));
if (bounds->step)
- steps.push_back(fir::getBase(genExprValue(
- *Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
+ steps.push_back(fir::getBase(
+ genExprValue(*Fortran::semantics::GetExpr(bounds->step), stmtCtx)));
else // If `step` is not present, assume it is `1`.
steps.push_back(builder->createIntegerConstant(loc, idxTy, 1));
-
+
ivTypes.push_back(idxTy);
ivLocs.push_back(crtLoc);
if (i < nestedLoops - 1)
@@ -2544,13 +2552,15 @@ private:
auto op = builder->create<fir::CUDAKernelOp>(
loc, gridValues, blockValues, streamValue, lbs, ubs, steps, n);
- builder->createBlock(&op.getRegion(), op.getRegion().end(), ivTypes, ivLocs);
+ builder->createBlock(&op.getRegion(), op.getRegion().end(), ivTypes,
+ ivLocs);
mlir::Block &b = op.getRegion().back();
builder->setInsertionPointToStart(&b);
for (auto [arg, value] : llvm::zip(
- op.getLoopRegions().front()->front().getArguments(), ivValues)) {
- mlir::Value convArg = builder->createConvert(loc, fir::unwrapRefType(value.getType()), arg);
+ op.getLoopRegions().front()->front().getArguments(), ivValues)) {
+ mlir::Value convArg =
+ builder->createConvert(loc, fir::unwrapRefType(value.getType()), arg);
builder->create<fir::StoreOp>(loc, convArg, value);
}
@@ -2564,8 +2574,6 @@ private:
crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
}
-
-
// Generate loop body
for (Fortran::lower::pft::Evaluation &e : crtEval->getNestedEvaluations())
genFIR(e);
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index c2facb5a00..9bb10a42a3 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3870,13 +3870,13 @@ llvm::SmallVector<mlir::Region *> fir::CUDAKernelOp::getLoopRegions() {
return {&getRegion()};
}
-mlir::ParseResult
-parseCUFKernelValues(mlir::OpAsmParser &parser,
- llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &values,
- llvm::SmallVectorImpl<mlir::Type> &types) {
+mlir::ParseResult parseCUFKernelValues(
+ mlir::OpAsmParser &parser,
+ llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &values,
+ llvm::SmallVectorImpl<mlir::Type> &types) {
if (mlir::succeeded(parser.parseOptionalStar()))
return mlir::success();
-
+
if (parser.parseOptionalLParen()) {
if (mlir::failed(parser.parseCommaSeparatedList(
mlir::AsmParser::Delimiter::None, [&]() {
@@ -3902,14 +3902,13 @@ void printCUFKernelValues(mlir::OpAsmPrinter &p, mlir::Operation *op,
if (values.size() > 1)
p << "(";
- llvm::interleaveComma(values, p,
- [&p](mlir::Value v) { p << v; });
+ llvm::interleaveComma(values, p, [&p](mlir::Value v) { p << v; });
if (values.size() > 1)
p << ")";
}
-mlir::ParseResult
-parseCUFKernelLoopControl(mlir::OpAsmParser &parser, mlir::Region &region,
+mlir::ParseResult parseCUFKernelLoopControl(
+ mlir::OpAsmParser &parser, mlir::Region &region,
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &lowerbound,
llvm::SmallVectorImpl<mlir::Type> &lowerboundType,
llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &upperbound,
@@ -3919,8 +3918,9 @@ parseCUFKernelLoopControl(mlir::OpAsmParser &parser, mlir::Region &region,
llvm::SmallVector<mlir::OpAsmParser::Argument> inductionVars;
if (parser.parseLParen() ||
- parser.parseArgumentList(inductionVars, mlir::OpAsmParser::Delimiter::None,
- /*allowType=*/true) ||
+ parser.parseArgumentList(inductionVars,
+ mlir::OpAsmParser::Delimiter::None,
+ /*allowType=*/true) ||
parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
parser.parseOperandList(lowerbound, inductionVars.size(),
mlir::OpAsmParser::Delimiter::None) ||
@@ -3937,17 +3937,16 @@ parseCUFKernelLoopControl(mlir::OpAsmParser &parser, mlir::Region &region,
return parser.parseRegion(region, inductionVars);
}
-void printCUFKernelLoopControl(mlir::OpAsmPrinter &p, mlir::Operation *op,
- mlir::Region &region, mlir::ValueRange lowerbound,
- mlir::TypeRange lowerboundType,
- mlir::ValueRange upperbound,
- mlir::TypeRange upperboundType, mlir::ValueRange steps,
- mlir::TypeRange stepType) {
+void printCUFKernelLoopControl(
+ mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Region &region,
+ mlir::ValueRange lowerbound, mlir::TypeRange lowerboundType,
+ mlir::ValueRange upperbound, mlir::TypeRange upperboundType,
+ mlir::ValueRange steps, mlir::TypeRange stepType) {
mlir::ValueRange regionArgs = region.front().getArguments();
if (!regionArgs.empty()) {
p << "(";
- llvm::interleaveComma(regionArgs, p,
- [&p](mlir::Value v) { p << v << " : " << v.getType(); });
+ llvm::interleaveComma(
+ regionArgs, p, [&p](mlir::Value v) { p << v << " : " << v.getType(); });
p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
<< upperbound << " : " << upperboundType << ") "
<< " step (" << steps << " : " << stepType << ") ";
``````````
</details>
https://github.com/llvm/llvm-project/pull/82836
More information about the flang-commits
mailing list