[flang-commits] [flang] [flang][cuda] CUF kernel loop directive (PR #82836)
Valentin Clement バレンタイン クレメン via flang-commits
flang-commits at lists.llvm.org
Tue Feb 27 09:11:41 PST 2024
https://github.com/clementval updated https://github.com/llvm/llvm-project/pull/82836
>From 46b322d8e4b00124b5cf2a06c2ceada8a96679c6 Mon Sep 17 00:00:00 2001
From: Valentin Clement <clementval at gmail.com>
Date: Thu, 15 Feb 2024 14:45:50 -0800
Subject: [PATCH] [flang][cuda] CUF kernel loop directive
This patch introduces a new operation to represent the CUDA Fortran
kernel loop directive. This operation is modelled as a LoopLikeOp
operation in a similar way to acc.loop.
Lowering from the flang parse-tree to MLIR is also done.
---
flang/include/flang/Lower/PFTBuilder.h | 5 +-
.../include/flang/Optimizer/Dialect/FIROps.td | 27 ++++
flang/lib/Lower/Bridge.cpp | 129 ++++++++++++++++++
flang/lib/Optimizer/Dialect/FIROps.cpp | 97 +++++++++++++
.../Lower/CUDA/cuda-kernel-loop-directive.cuf | 51 +++++++
5 files changed, 307 insertions(+), 2 deletions(-)
create mode 100644 flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
diff --git a/flang/include/flang/Lower/PFTBuilder.h b/flang/include/flang/Lower/PFTBuilder.h
index c2b0fdbf357cde..9913f584133faa 100644
--- a/flang/include/flang/Lower/PFTBuilder.h
+++ b/flang/include/flang/Lower/PFTBuilder.h
@@ -138,7 +138,8 @@ using Directives =
std::tuple<parser::CompilerDirective, parser::OpenACCConstruct,
parser::OpenACCRoutineConstruct,
parser::OpenACCDeclarativeConstruct, parser::OpenMPConstruct,
- parser::OpenMPDeclarativeConstruct, parser::OmpEndLoopDirective>;
+ parser::OpenMPDeclarativeConstruct, parser::OmpEndLoopDirective,
+ parser::CUFKernelDoConstruct>;
using DeclConstructs = std::tuple<parser::OpenMPDeclarativeConstruct,
parser::OpenACCDeclarativeConstruct>;
@@ -178,7 +179,7 @@ static constexpr bool isNopConstructStmt{common::HasMember<
template <typename A>
static constexpr bool isExecutableDirective{common::HasMember<
A, std::tuple<parser::CompilerDirective, parser::OpenACCConstruct,
- parser::OpenMPConstruct>>};
+ parser::OpenMPConstruct, parser::CUFKernelDoConstruct>>};
template <typename A>
static constexpr bool isFunctionLike{common::HasMember<
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index 08239230f793f1..db5e5f4bc682e6 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -3127,4 +3127,31 @@ def fir_BoxOffsetOp : fir_Op<"box_offset", [NoMemoryEffect]> {
];
}
+def fir_CUDAKernelOp : fir_Op<"cuda_kernel", [AttrSizedOperandSegments,
+ DeclareOpInterfaceMethods<LoopLikeOpInterface>]> {
+ // Operation representing the CUDA Fortran kernel loop directive; it implements LoopLikeOpInterface.
+ let arguments = (ins
+ Variadic<I32>:$grid, // empty means `*`
+ Variadic<I32>:$block, // empty means `*`
+ Optional<I32>:$stream, // optional; written as `stream = ...` in the assembly format
+ Variadic<Index>:$lowerbound, // the verifier requires lowerbound, upperbound and step to have equal sizes
+ Variadic<Index>:$upperbound,
+ Variadic<Index>:$step,
+ OptionalAttr<I64Attr>:$n // loop count of `do(n)`; the lowering in this patch sets it only when > 1
+ );
+ // Single region; the custom loop-control parser passes the induction variables as its arguments.
+ let regions = (region AnyRegion:$region);
+ // Launch configuration between `<<<` and `>>>`, then the loop control and region, then attributes.
+ let assemblyFormat = [{
+ `<` `<` `<` custom<CUFKernelValues>($grid, type($grid)) `,`
+ custom<CUFKernelValues>($block, type($block))
+ ( `,` `stream` `=` $stream^ )? `>` `>` `>`
+ custom<CUFKernelLoopControl>($region, $lowerbound, type($lowerbound),
+ $upperbound, type($upperbound), $step, type($step))
+ attr-dict
+ }];
+ // Requests the C++ verify() hook, implemented as fir::CUDAKernelOp::verify().
+ let hasVerifier = 1;
+}
+
#endif
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 2d7f748cefa2d8..c7b8cd96021ce2 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2453,6 +2453,135 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// Handled by genFIR(const Fortran::parser::OpenACCDeclarativeConstruct &)
}
+ /// Lower a CUDA Fortran kernel loop directive to a fir.cuda_kernel operation.
+ ///
+ /// The grid, block and optional stream expressions of the directive are
+ /// evaluated first. The lower bound, upper bound and step of each of the
+ /// `nestedLoops` DO constructs are then converted to `index` and passed as
+ /// operands, and one `index` block argument per loop is added to the
+ /// operation's region. Each block argument is stored to the address of the
+ /// matching Fortran induction variable before the body of the innermost
+ /// associated loop is generated inside the region.
+ void genFIR(const Fortran::parser::CUFKernelDoConstruct &kernel) {
+   localSymbols.pushScope();
+   const Fortran::parser::CUFKernelDoConstruct::Directive &dir =
+       std::get<Fortran::parser::CUFKernelDoConstruct::Directive>(kernel.t);
+
+   mlir::Location loc = genLocation(dir.source);
+
+   Fortran::lower::StatementContext stmtCtx;
+
+   // Number of DO constructs covered by the directive; 1 when `do(n)` is
+   // not specified.
+   unsigned nestedLoops = 1;
+
+   const auto &nLoops =
+       std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(dir.t);
+   if (nLoops)
+     nestedLoops = *Fortran::semantics::GetIntValue(*nLoops);
+
+   // The `n` attribute is only attached when more than one loop is covered.
+   mlir::IntegerAttr n;
+   if (nestedLoops > 1)
+     n = builder->getIntegerAttr(builder->getI64Type(), nestedLoops);
+
+   const std::list<Fortran::parser::ScalarIntExpr> &grid = std::get<1>(dir.t);
+   const std::list<Fortran::parser::ScalarIntExpr> &block = std::get<2>(dir.t);
+   const std::optional<Fortran::parser::ScalarIntExpr> &stream =
+       std::get<3>(dir.t);
+
+   // Evaluate the launch configuration. Empty lists are kept empty.
+   llvm::SmallVector<mlir::Value> gridValues;
+   for (const Fortran::parser::ScalarIntExpr &expr : grid)
+     gridValues.push_back(fir::getBase(
+         genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+   llvm::SmallVector<mlir::Value> blockValues;
+   for (const Fortran::parser::ScalarIntExpr &expr : block)
+     blockValues.push_back(fir::getBase(
+         genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx)));
+   mlir::Value streamValue;
+   if (stream)
+     streamValue = fir::getBase(
+         genExprValue(*Fortran::semantics::GetExpr(*stream), stmtCtx));
+
+   const auto &outerDoConstruct =
+       std::get<std::optional<Fortran::parser::DoConstruct>>(kernel.t);
+
+   llvm::SmallVector<mlir::Value> lbs, ubs, steps;
+
+   mlir::Type idxTy = builder->getIndexType();
+
+   llvm::SmallVector<mlir::Type> ivTypes;
+   llvm::SmallVector<mlir::Location> ivLocs;
+   llvm::SmallVector<mlir::Value> ivValues;
+
+   // Evaluation of the DO construct currently being processed. It must live
+   // outside of the loop so that the step to the next nested DO construct
+   // made at the end of an iteration is seen by the following one.
+   Fortran::lower::pft::Evaluation *loopEval =
+       &getEval().getFirstNestedEvaluation();
+   for (unsigned i = 0; i < nestedLoops; ++i) {
+     const Fortran::parser::LoopControl *loopControl;
+
+     mlir::Location crtLoc = loc;
+     if (i == 0) {
+       loopControl = &*outerDoConstruct->GetLoopControl();
+       crtLoc =
+           genLocation(Fortran::parser::FindSourceLocation(outerDoConstruct));
+     } else {
+       auto *doCons = loopEval->getIf<Fortran::parser::DoConstruct>();
+       assert(doCons && "expect do construct");
+       loopControl = &*doCons->GetLoopControl();
+       crtLoc = genLocation(Fortran::parser::FindSourceLocation(*doCons));
+     }
+
+     const Fortran::parser::LoopControl::Bounds *bounds =
+         std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
+     assert(bounds && "Expected bounds on the loop construct");
+
+     Fortran::semantics::Symbol &ivSym =
+         bounds->name.thing.symbol->GetUltimate();
+     ivValues.push_back(getSymbolAddress(ivSym));
+
+     lbs.push_back(builder->createConvert(
+         crtLoc, idxTy,
+         fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->lower),
+                                   stmtCtx))));
+     ubs.push_back(builder->createConvert(
+         crtLoc, idxTy,
+         fir::getBase(genExprValue(*Fortran::semantics::GetExpr(bounds->upper),
+                                   stmtCtx))));
+     if (bounds->step)
+       // The operation takes `index` steps, so convert like the bounds.
+       steps.push_back(builder->createConvert(
+           crtLoc, idxTy,
+           fir::getBase(genExprValue(
+               *Fortran::semantics::GetExpr(bounds->step), stmtCtx))));
+     else // If `step` is not present, assume it is `1`.
+       steps.push_back(builder->createIntegerConstant(loc, idxTy, 1));
+
+     ivTypes.push_back(idxTy);
+     ivLocs.push_back(crtLoc);
+     // Step to the second nested evaluation of the current DO construct,
+     // which the next iteration expects to be the next DO construct.
+     if (i < nestedLoops - 1)
+       loopEval = &*std::next(loopEval->getNestedEvaluations().begin());
+   }
+
+   auto op = builder->create<fir::CUDAKernelOp>(
+       loc, gridValues, blockValues, streamValue, lbs, ubs, steps, n);
+   builder->createBlock(&op.getRegion(), op.getRegion().end(), ivTypes,
+                        ivLocs);
+   mlir::Block &b = op.getRegion().back();
+   builder->setInsertionPointToStart(&b);
+
+   // Store each `index` block argument, converted to the variable's type,
+   // to the address of the matching Fortran induction variable.
+   for (auto [arg, value] : llvm::zip(
+            op.getLoopRegions().front()->front().getArguments(), ivValues)) {
+     mlir::Value convArg =
+         builder->createConvert(loc, fir::unwrapRefType(value.getType()), arg);
+     builder->create<fir::StoreOp>(loc, convArg, value);
+   }
+
+   // Terminate the block, then move back to its start so that the loop body
+   // is emitted before the terminator.
+   builder->create<fir::FirEndOp>(loc);
+   builder->setInsertionPointToStart(&b);
+
+   // Find the evaluation of the innermost DO construct covered by the
+   // directive; its nested evaluations form the loop body.
+   Fortran::lower::pft::Evaluation *crtEval = &getEval();
+   if (crtEval->lowerAsStructured()) {
+     crtEval = &crtEval->getFirstNestedEvaluation();
+     for (unsigned i = 1; i < nestedLoops; ++i)
+       crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
+   }
+
+   // Generate loop body
+   for (Fortran::lower::pft::Evaluation &e : crtEval->getNestedEvaluations())
+     genFIR(e);
+
+   builder->setInsertionPointAfter(op);
+   localSymbols.popScope();
+ }
+
void genFIR(const Fortran::parser::OpenMPConstruct &omp) {
mlir::OpBuilder::InsertPoint insertPt = builder->saveInsertionPoint();
genOpenMPConstruct(*this, localSymbols, bridge.getSemanticsContext(),
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 0a534cdb3c4871..bbf5c0c09101be 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -3866,6 +3866,103 @@ mlir::LogicalResult fir::DeclareOp::verify() {
return fortranVar.verifyDeclareLikeOpImpl(getMemref());
}
+/// LoopLikeOpInterface hook: the operation's only region is its loop region.
+llvm::SmallVector<mlir::Region *> fir::CUDAKernelOp::getLoopRegions() {
+  llvm::SmallVector<mlir::Region *> loopRegions;
+  loopRegions.push_back(&getRegion());
+  return loopRegions;
+}
+
+/// Parse the grid or block values of a fir.cuda_kernel launch configuration.
+///
+/// Three spellings are accepted: `*` (no value), a single operand, or a
+/// parenthesized comma-separated list of operands. One i32 type is recorded
+/// in `types` for every parsed operand, matching the `Variadic<I32>`
+/// declaration of the grid and block operands.
+///
+/// \param parser assembly parser positioned at the values.
+/// \param values receives the parsed operands; left empty for `*`.
+/// \param types receives one i32 type per parsed operand.
+/// \return failure if an operand or the closing `)` cannot be parsed.
+mlir::ParseResult parseCUFKernelValues(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &values,
+    llvm::SmallVectorImpl<mlir::Type> &types) {
+  if (mlir::succeeded(parser.parseOptionalStar()))
+    return mlir::success();
+
+  // ParseResult converts to `true` on failure, so the presence of `(` has
+  // to be tested with succeeded().
+  if (mlir::succeeded(parser.parseOptionalLParen())) {
+    if (mlir::failed(parser.parseCommaSeparatedList(
+            mlir::AsmParser::Delimiter::None,
+            [&]() { return parser.parseOperand(values.emplace_back()); })) ||
+        parser.parseRParen())
+      return mlir::failure();
+  } else if (parser.parseOperand(values.emplace_back())) {
+    return mlir::failure();
+  }
+
+  // The types are not part of the textual form; provide them here.
+  types.append(values.size(), parser.getBuilder().getI32Type());
+  return mlir::success();
+}
+
+/// Print the grid or block values of a fir.cuda_kernel launch configuration:
+/// `*` for an empty list, the bare value for a single one, and a
+/// parenthesized comma-separated list otherwise. Types are not printed.
+void printCUFKernelValues(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                          mlir::ValueRange values, mlir::TypeRange types) {
+  if (values.empty()) {
+    p << "*";
+    return;
+  }
+
+  const bool needsParens = values.size() > 1;
+  if (needsParens)
+    p << "(";
+  llvm::interleaveComma(values, p,
+                        [&p](mlir::Value operand) { p << operand; });
+  if (needsParens)
+    p << ")";
+}
+
+/// Parse the loop control and the region of a fir.cuda_kernel operation:
+///
+///   (ivs) = (lbs : types) to (ubs : types) step (steps : types) region
+///
+/// The induction variables are parsed as typed region arguments. The lower
+/// bound, upper bound and step lists must each hold as many operands as there
+/// are induction variables.
+///
+/// \param parser assembly parser positioned at the opening `(`.
+/// \param region receives the parsed region, with the induction variables as
+///        its entry block arguments.
+/// \return failure if any part of the loop control or the region is malformed.
+mlir::ParseResult parseCUFKernelLoopControl(
+    mlir::OpAsmParser &parser, mlir::Region &region,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &lowerbound,
+    llvm::SmallVectorImpl<mlir::Type> &lowerboundType,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &upperbound,
+    llvm::SmallVectorImpl<mlir::Type> &upperboundType,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &step,
+    llvm::SmallVectorImpl<mlir::Type> &stepType) {
+
+  llvm::SmallVector<mlir::OpAsmParser::Argument> inductionVars;
+  if (parser.parseLParen() ||
+      parser.parseArgumentList(inductionVars,
+                               mlir::OpAsmParser::Delimiter::None,
+                               /*allowType=*/true) ||
+      parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
+      parser.parseOperandList(lowerbound, inductionVars.size(),
+                              mlir::OpAsmParser::Delimiter::None) ||
+      parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
+      parser.parseKeyword("to") || parser.parseLParen() ||
+      parser.parseOperandList(upperbound, inductionVars.size(),
+                              mlir::OpAsmParser::Delimiter::None) ||
+      parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
+      parser.parseKeyword("step") || parser.parseLParen() ||
+      parser.parseOperandList(step, inductionVars.size(),
+                              mlir::OpAsmParser::Delimiter::None) ||
+      parser.parseColonTypeList(stepType) || parser.parseRParen())
+    return mlir::failure();
+  return parser.parseRegion(region, inductionVars);
+}
+
+/// Print the loop control and the region of a fir.cuda_kernel operation.
+///
+/// When the entry block has arguments they are printed as the induction
+/// variables, followed by the lower bound, upper bound and step lists with
+/// their types. The region is then printed without its entry block arguments,
+/// since they already appear in the loop control.
+void printCUFKernelLoopControl(
+    mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Region &region,
+    mlir::ValueRange lowerbound, mlir::TypeRange lowerboundType,
+    mlir::ValueRange upperbound, mlir::TypeRange upperboundType,
+    mlir::ValueRange steps, mlir::TypeRange stepType) {
+  mlir::ValueRange regionArgs = region.front().getArguments();
+  if (!regionArgs.empty()) {
+    p << "(";
+    llvm::interleaveComma(
+        regionArgs, p, [&p](mlir::Value v) { p << v << " : " << v.getType(); });
+    p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
+      << upperbound << " : " << upperboundType << ") " << " step (" << steps
+      << " : " << stepType << ") ";
+  }
+  p.printRegion(region, /*printEntryBlockArgs=*/false);
+}
+
+/// Verify that the lowerbound, upperbound and step operand lists all have the
+/// same number of values.
+mlir::LogicalResult fir::CUDAKernelOp::verify() {
+  const auto numLowerbounds = getLowerbound().size();
+  const bool sameCount = numLowerbounds == getUpperbound().size() &&
+                         numLowerbounds == getStep().size();
+  if (sameCount)
+    return mlir::success();
+  return emitOpError(
+      "expect same number of values in lowerbound, upperbound and step");
+}
+
//===----------------------------------------------------------------------===//
// FIROpsDialect
//===----------------------------------------------------------------------===//
diff --git a/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf b/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
new file mode 100644
index 00000000000000..db628fe756b952
--- /dev/null
+++ b/flang/test/Lower/CUDA/cuda-kernel-loop-directive.cuf
@@ -0,0 +1,51 @@
+! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
+
+! Test lowering of CUDA kernel loop directive.
+
+subroutine sub1()
+ integer :: i, j
+ integer, parameter :: n = 100
+ real :: a(n), b(n)
+ real :: c(n,n), d(n,n)
+
+! CHECK-LABEL: func.func @_QPsub1()
+! CHECK: %[[IV:.*]]:2 = hlfir.declare %{{.*}} {uniq_name = "_QFsub1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+ ! Case 1: single loop with one grid value and one block value.
+ !$cuf kernel do <<< 1, 2 >>>
+ do i = 1, n
+ a(i) = a(i) * b(i)
+ end do
+
+! CHECK: %[[LB:.*]] = fir.convert %c1{{.*}} : (i32) -> index
+! CHECK: %[[UB:.*]] = fir.convert %c100{{.*}} : (i32) -> index
+! CHECK: %[[STEP:.*]] = arith.constant 1 : index
+! CHECK: fir.cuda_kernel<<<%c1_i32, %c2_i32>>> (%[[ARG0:.*]] : index) = (%[[LB]] : index) to (%[[UB]] : index) step (%[[STEP]] : index)
+! CHECK-NOT: fir.do_loop
+! CHECK: %[[ARG0_I32:.*]] = fir.convert %[[ARG0]] : (index) -> i32
+! CHECK: fir.store %[[ARG0_I32]] to %[[IV]]#1 : !fir.ref<i32>
+
+ ! Case 2: single loop with `*` given for both grid and block.
+ !$cuf kernel do <<< *, * >>>
+ do i = 1, n
+ a(i) = a(i) * b(i)
+ end do
+
+! CHECK: fir.cuda_kernel<<<*, *>>> (%{{.*}} : index) = (%{{.*}} : index) to (%{{.*}} : index) step (%{{.*}} : index)
+ ! Case 3: two nested loops with a parenthesized list of block values.
+ !$cuf kernel do(2) <<< 1, (256,1) >>>
+ do i = 1, n
+ do j = 1, n
+ c(i,j) = c(i,j) * d(i,j)
+ end do
+ end do
+
+! CHECK: fir.cuda_kernel<<<%c1{{.*}}, (%c256{{.*}}, %c1{{.*}})>>> (%{{.*}} : index, %{{.*}} : index) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index)
+! CHECK: {n = 2 : i64}
+
+! TODO: currently these trigger error in the parser
+! !$cuf kernel do(2) <<< (1,*), (256,1) >>>
+! !$cuf kernel do(2) <<< (*,*), (32,4) >>>
+end
+
+
+
More information about the flang-commits
mailing list