[Mlir-commits] [flang] [mlir] [Flang][MLIR] Add `!$omp unroll` and `omp.unroll_heuristic` (PR #144785)
Kareem Ergawy
llvmlistbot at llvm.org
Mon Jun 30 05:08:47 PDT 2025
================
@@ -2128,6 +2128,161 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
return loopOp;
}
+/// Lower the DO loop nested in `eval` to an `omp.canonical_loop` operation.
+///
+/// The loop's lower bound, upper bound and (optional) step are evaluated and
+/// converted to the integer type of the loop variable. From these a trip count
+/// is computed: the bounds are normalized so that the loop counts upwards, the
+/// count is `(ub - lb) / step + 1`, and it falls back to 0 when the normalized
+/// lower bound is greater than the upper bound. The trip count and a freshly
+/// created CLI handle (`mlir::omp::NewCliOp`) are the operands of the new
+/// operation.
+///
+/// Inside the loop region the value of the user's loop variable is recomputed
+/// from the logical iteration number as `lb + iteration * step` (using the
+/// original, non-normalized bound and step) and stored through
+/// `createAndSetPrivatizedLoopVar`.
+///
+/// Limitations: exactly one induction variable is supported (`ivs.size()` must
+/// be 1), and DO CONCURRENT loops are reported through `TODO`.
+///
+/// \param ivs        induction variable symbols; only `ivs[0]` is used.
+/// \param directive  directive forwarded to `OpWithBodyGenInfo`.
+/// \param dsp        data-sharing processor attached to the generated body.
+/// \returns the created `omp.canonical_loop` operation. On return, the
+///          builder's insertion point is set right after that operation.
+static mlir::omp::CanonicalLoopOp
+genCanonicalLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+                   semantics::SemanticsContext &semaCtx,
+                   lower::pft::Evaluation &eval, mlir::Location loc,
+                   const ConstructQueue &queue,
+                   ConstructQueue::const_iterator item,
+                   llvm::ArrayRef<const semantics::Symbol *> ivs,
+                   llvm::omp::Directive directive, DataSharingProcessor &dsp) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  assert(ivs.size() == 1 && "Nested loops not yet implemented");
+  const semantics::Symbol *iv = ivs[0];
+
+  // The first nested evaluation is expected to be the DO construct itself.
+  // Check the pointer before dereferencing it, like the other lookups below.
+  auto &nestedEval = eval.getFirstNestedEvaluation();
+  const auto *doConstruct = nestedEval.getIf<parser::DoConstruct>();
+  assert(doConstruct && "Expected do construct in the nested evaluation");
+  if (doConstruct->IsDoConcurrent()) {
+    TODO(loc, "Do Concurrent in unroll construct");
+  }
+
+  // Get the loop bounds (and increment)
+  auto &doLoopEval = nestedEval.getFirstNestedEvaluation();
+  auto *doStmt = doLoopEval.getIf<parser::NonLabelDoStmt>();
+  assert(doStmt && "Expected do loop to be in the nested evaluation");
+  auto &loopControl = std::get<std::optional<parser::LoopControl>>(doStmt->t);
+  assert(loopControl.has_value());
+  auto *bounds = std::get_if<parser::LoopControl::Bounds>(&loopControl->u);
+  assert(bounds && "Expected bounds for canonical loop");
+  lower::StatementContext stmtCtx;
+  mlir::Value loopLBVar = fir::getBase(
+      converter.genExprValue(*semantics::GetExpr(bounds->lower), stmtCtx));
+  mlir::Value loopUBVar = fir::getBase(
+      converter.genExprValue(*semantics::GetExpr(bounds->upper), stmtCtx));
+  mlir::Value loopStepVar = [&]() {
+    if (bounds->step) {
+      return fir::getBase(
+          converter.genExprValue(*semantics::GetExpr(bounds->step), stmtCtx));
+    } else {
+      // If `step` is not present, assume it is `1`. The constant is created as
+      // i32 here and converted to the loop variable type below.
+      return firOpBuilder.createIntegerConstant(loc, firOpBuilder.getI32Type(),
+                                                1);
+    }
+  }();
+
+  // Get the integer kind for the loop variable and cast the loop bounds
+  size_t loopVarTypeSize = bounds->name.thing.symbol->GetUltimate().size();
+  mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
+  loopLBVar = firOpBuilder.createConvert(loc, loopVarType, loopLBVar);
+  loopUBVar = firOpBuilder.createConvert(loc, loopVarType, loopUBVar);
+  loopStepVar = firOpBuilder.createConvert(loc, loopVarType, loopStepVar);
+
+  // Start lowering
+  mlir::Value zero = firOpBuilder.createIntegerConstant(loc, loopVarType, 0);
+  mlir::Value one = firOpBuilder.createIntegerConstant(loc, loopVarType, 1);
+  mlir::Value isDownwards = firOpBuilder.create<mlir::arith::CmpIOp>(
+      loc, mlir::arith::CmpIPredicate::slt, loopStepVar, zero);
+
+  // Ensure we are counting upwards. If not, negate step and swap lb and ub.
+  mlir::Value negStep =
+      firOpBuilder.create<mlir::arith::SubIOp>(loc, zero, loopStepVar);
+  mlir::Value incr = firOpBuilder.create<mlir::arith::SelectOp>(
+      loc, isDownwards, negStep, loopStepVar);
+  mlir::Value lb = firOpBuilder.create<mlir::arith::SelectOp>(
+      loc, isDownwards, loopUBVar, loopLBVar);
+  mlir::Value ub = firOpBuilder.create<mlir::arith::SelectOp>(
+      loc, isDownwards, loopLBVar, loopUBVar);
+
+  // Compute the trip count assuming lb <= ub. This guarantees that the result
+  // is non-negative and we can use unsigned arithmetic.
+  mlir::Value span = firOpBuilder.create<mlir::arith::SubIOp>(
+      loc, ub, lb, ::mlir::arith::IntegerOverflowFlags::nuw);
+  mlir::Value tcMinusOne =
+      firOpBuilder.create<mlir::arith::DivUIOp>(loc, span, incr);
+  mlir::Value tcIfLooping = firOpBuilder.create<mlir::arith::AddIOp>(
+      loc, tcMinusOne, one, ::mlir::arith::IntegerOverflowFlags::nuw);
+
+  // Fall back to 0 if lb > ub (signed comparison on the normalized bounds).
+  mlir::Value isZeroTC = firOpBuilder.create<mlir::arith::CmpIOp>(
+      loc, mlir::arith::CmpIPredicate::slt, ub, lb);
+  mlir::Value tripcount = firOpBuilder.create<mlir::arith::SelectOp>(
+      loc, isZeroTC, zero, tcIfLooping);
+
+  // Create the CLI handle.
+  auto newcli = firOpBuilder.create<mlir::omp::NewCliOp>(loc);
+  mlir::Value cli = newcli.getResult();
+
+  // Region-entry callback: builds the entry block of the loop region and
+  // materializes the user's loop variable from the block argument.
+  auto ivCallback = [&](mlir::Operation *op)
+      -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
+    mlir::Region &region = op->getRegion(0);
+
+    // Create the op's region skeleton (BB taking the iv as argument)
+    firOpBuilder.createBlock(&region, {}, {loopVarType}, {loc});
+
+    // Compute the value of the loop variable from the logical iteration number.
+    // The original (non-normalized) lower bound and step are used here, so the
+    // step keeps its sign.
+    mlir::Value natIterNum = fir::getBase(region.front().getArgument(0));
+    mlir::Value scaled =
+        firOpBuilder.create<mlir::arith::MulIOp>(loc, natIterNum, loopStepVar);
+    mlir::Value userVal =
+        firOpBuilder.create<mlir::arith::AddIOp>(loc, loopLBVar, scaled);
+
+    // The argument is not currently in memory, so make a temporary for the
+    // argument, and store it there, then bind that location to the argument.
+    mlir::Operation *storeOp =
+        createAndSetPrivatizedLoopVar(converter, loc, userVal, iv);
+
+    // Continue emitting the loop body right after the store.
+    firOpBuilder.setInsertionPointAfter(storeOp);
+    return {iv};
+  };
+
+  // Create the omp.canonical_loop operation
+  auto canonLoop = genOpWithBody<mlir::omp::CanonicalLoopOp>(
+      OpWithBodyGenInfo(converter, symTable, semaCtx, loc, nestedEval,
+                        directive)
+          .setClauses(&item->clauses)
+          .setDataSharingProcessor(&dsp)
+          .setGenRegionEntryCb(ivCallback),
+      queue, item, tripcount, cli);
+
+  firOpBuilder.setInsertionPointAfter(canonLoop);
+  return canonLoop;
+}
+
+static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::SymMap &symTable,
+ lower::StatementContext &stmtCtx,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval,
+ mlir::Location loc, const ConstructQueue &queue,
+ ConstructQueue::const_iterator item) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+ mlir::omp::LoopRelatedClauseOps loopInfo;
+ llvm::SmallVector<const semantics::Symbol *> iv;
+ collectLoopRelatedInfo(converter, loc, eval, item->clauses, loopInfo, iv);
----------------
ergawy wrote:
Should this be moved to `genCanonicalLoopOp` instead?
https://github.com/llvm/llvm-project/pull/144785
More information about the Mlir-commits
mailing list