[flang-commits] [flang] [mlir] [Flang][OpenMP] Support conditional lastprivate on host (PR #200086)
Sunil Shrestha via flang-commits
flang-commits at lists.llvm.org
Thu Jun 18 23:50:35 PDT 2026
================
@@ -4297,6 +4840,317 @@ getReductionType(lower::AbstractConverter &converter,
return reductionType;
}
+/// Compute a flattened canonical (0-based, always ascending) iteration number
+/// from all loop IVs. For a single loop, this is simply (IV - LB) / step.
+/// For collapsed loops with dimensions d0..dk, the flattened index is:
+///   c0 * (N1*N2*...*Nk) + c1 * (N2*...*Nk) + ... + ck
+/// where ci = (IVi - LBi) / stepi and Ni = (UBi - LBi) / stepi + 1.
+/// It is evaluated in Horner form, ((c0 * N1 + c1) * N2 + c2) * ..., so the
+/// outermost trip count N0 never contributes and is not emitted.
+/// This yields a unique monotonic index regardless of loop direction,
+/// which is essential for the combiner's `sgt` comparison to correctly
+/// identify the sequentially last iteration.
+///
+/// All arithmetic, including the per-dimension differences, is done in i64.
+/// This keeps (IV - LB) and (UB - LB) from overflowing a narrower IV type,
+/// for example an i32 loop whose range spans more than INT32_MAX values.
+static mlir::Value
+computeFlattenedCanonicalIV(fir::FirOpBuilder &builder, mlir::Location loc,
+                            mlir::omp::LoopNestOp loopNestOp) {
+  mlir::Block &body = loopNestOp.getRegion().front();
+  auto lbs = loopNestOp.getLoopLowerBounds();
+  auto ubs = loopNestOp.getLoopUpperBounds();
+  auto steps = loopNestOp.getLoopSteps();
+  unsigned numDims = lbs.size();
+  assert(numDims > 0 && "omp.loop_nest must have at least one loop");
+
+  // Widen every operand to i64 before any subtraction so intermediate
+  // results cannot wrap in the (possibly narrower) IV type.
+  mlir::Type i64Ty = builder.getI64Type();
+  auto toI64 = [&](mlir::Value v) -> mlir::Value {
+    if (v.getType() == i64Ty)
+      return v;
+    return fir::ConvertOp::create(builder, loc, i64Ty, v);
+  };
+
+  mlir::Value flatIdx;
+  for (unsigned d = 0; d < numDims; ++d) {
+    mlir::Value iv = toI64(body.getArgument(d));
+    mlir::Value lb = toI64(lbs[d]);
+    mlir::Value step = toI64(steps[d]);
+
+    // Canonical IV of this dimension: ci = (IV - LB) / step.
+    mlir::Value diff = mlir::arith::SubIOp::create(builder, loc, iv, lb);
+    mlir::Value ci = mlir::arith::DivSIOp::create(builder, loc, diff, step);
+    if (d == 0) {
+      flatIdx = ci;
+      continue;
+    }
+
+    // Trip count: Ni = (UB - LB) / step + 1 (loop bounds are inclusive).
+    mlir::Value ub = toI64(ubs[d]);
+    mlir::Value range = mlir::arith::SubIOp::create(builder, loc, ub, lb);
+    mlir::Value trips = mlir::arith::DivSIOp::create(builder, loc, range, step);
+    mlir::Value one = builder.createIntegerConstant(loc, i64Ty, 1);
+    trips = mlir::arith::AddIOp::create(builder, loc, trips, one);
+
+    // Horner step: flatIdx = flatIdx * Ni + ci.
+    flatIdx = mlir::arith::MulIOp::create(builder, loc, flatIdx, trips);
+    flatIdx = mlir::arith::AddIOp::create(builder, loc, flatIdx, ci);
+  }
+  return flatIdx;
+}
+
+/// Redirect each conditional-lastprivate symbol in \p condLpSyms to the value
+/// field of the same name inside the reduction struct \p structArg.
+///
+/// This must run \b before the loop body is lowered. Once a symbol is bound
+/// here, every reference to it in the body lowers to the struct field
+/// address, so no post-hoc address-replacement rewrite is needed.
+///
+/// \returns a map from each newly created field address (a
+/// fir.coordinate_of result) to the symbol name it was created for, in
+/// \p condLpSyms order. \c injectCondLpIndexStores uses this map to locate
+/// writes to those fields.
+static llvm::MapVector<mlir::Value, std::string> bindCondLpSymsToStructFields(
+    lower::AbstractConverter &converter, mlir::Location loc,
+    fir::RecordType lpType, mlir::Value structArg,
+    const llvm::SetVector<const semantics::Symbol *> &condLpSyms) {
+  fir::FirOpBuilder &firBuilder = converter.getFirOpBuilder();
+  llvm::MapVector<mlir::Value, std::string> fieldAddrToName;
+
+  for (const semantics::Symbol *symbol : condLpSyms) {
+    const std::string name = symbol->name().ToString();
+    const unsigned fieldPos = lpType.getFieldIndex(name);
+    mlir::Type fieldTy = lpType.getType(fieldPos);
+
+    // Address of the value field: coordinate_of(structArg, fieldPos).
+    llvm::SmallVector<fir::IntOrValue, 1> coord{
+        mlir::IntegerAttr::get(firBuilder.getI32Type(), fieldPos)};
+    mlir::Value fieldAddr = fir::CoordinateOp::create(
+        firBuilder, loc, firBuilder.getRefType(fieldTy), structArg, coord);
+
+    // From here on, lowering the symbol yields the struct field address.
+    converter.bindSymbol(*symbol, fieldAddr);
+    fieldAddrToName.insert({fieldAddr, name});
+  }
+  return fieldAddrToName;
+}
+
+/// Walk \p region for writes (hlfir.assign or fir.store) whose destination,
+/// after looking through one hlfir.declare, is one of the value-field
+/// addresses in \p valAddrToSymName. Immediately after each such write, store
+/// the index value produced by \p genIndexVal (converted to i64) into the
+/// matching "$<name>" index field of \p structArg.
+///
+/// Only direct hlfir.assign / fir.store destinations are matched. A write
+/// that reaches the field through any other operation is not recorded.
+/// NOTE(review): confirm that no such write path needs handling.
+static void injectCondLpIndexStores(
+    fir::FirOpBuilder &builder, mlir::Location loc, fir::RecordType lpType,
+    mlir::Value structArg, mlir::Region &region,
+    const llvm::MapVector<mlir::Value, std::string> &valAddrToSymName,
+    llvm::function_ref<mlir::Value(fir::FirOpBuilder &, mlir::Location)>
+        genIndexVal) {
+  // Look through hlfir.declare to find the underlying struct field address.
+  // When symbols are bound via bindCondLpSymsToStructFields, the lowering
+  // wraps the fir.coordinate_of result in hlfir.declare, so the actual write
+  // target is the declare result rather than the raw coordinate_of.
+  auto lookThroughDeclare = [](mlir::Value v) -> mlir::Value {
+    if (auto declOp = v.getDefiningOp<hlfir::DeclareOp>())
+      return declOp.getMemref();
+    return v;
+  };
+
+  // Destination of a supported write op, or a null value for any other op.
+  auto getWriteTarget = [&](mlir::Operation *op) -> mlir::Value {
+    if (auto assignOp = mlir::dyn_cast<hlfir::AssignOp>(op))
+      return lookThroughDeclare(assignOp.getLhs());
+    if (auto storeOp = mlir::dyn_cast<fir::StoreOp>(op))
+      return lookThroughDeclare(storeOp.getMemref());
+    return {};
+  };
+
+  // Collect first and mutate afterwards, so the walk never visits the
+  // stores inserted below.
+  llvm::SmallVector<std::pair<mlir::Operation *, mlir::Value>> writes;
+  region.walk([&](mlir::Operation *op) {
+    mlir::Value target = getWriteTarget(op);
+    if (target && valAddrToSymName.count(target))
+      writes.emplace_back(op, target);
+  });
+  if (writes.empty())
+    return;
+
+  // Compute the index value once at the region entry so that it dominates
+  // all write sites (which may be inside nested fir.if blocks).
+  mlir::Type i64Ty = builder.getI64Type();
+  mlir::Value indexVal;
+  {
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(&region.front());
+    indexVal = genIndexVal(builder, loc);
+    if (indexVal.getType() != i64Ty)
+      indexVal = fir::ConvertOp::create(builder, loc, i64Ty, indexVal);
+  }
+
+  for (auto [writeOp, target] : writes) {
+    // Presence was checked during the walk, so find() cannot fail here.
+    const std::string &symName = valAddrToSymName.find(target)->second;
+    unsigned idxFieldIdx = lpType.getFieldIndex("$" + symName);
+    mlir::Type idxType = lpType.getType(idxFieldIdx);
+
+    mlir::OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointAfter(writeOp);
+
+    fir::IntOrValue idxFIdx =
+        mlir::IntegerAttr::get(builder.getI32Type(), idxFieldIdx);
+    mlir::Value idxAddr = fir::CoordinateOp::create(
+        builder, loc, builder.getRefType(idxType), structArg,
+        llvm::SmallVector<fir::IntOrValue, 1>{idxFIdx});
+
+    fir::StoreOp::create(builder, loc, indexVal, idxAddr);
+  }
+}
+
+/// Build the omp.declare_reduction that implements conditional lastprivate
+/// for the struct type \p lpCondType. The declaration is produced by
+/// ReductionProcessor::createDeclareReductionHelper.
+///
+/// The struct holds one value field per variable and one "$<name>" index
+/// field per variable. The reduction is by-reference:
+///   - init:     initConditionalLpStruct fills the private copy in place;
+///   - combiner: for each variable, the side with the larger (signed) index
+///               wins, i.e. the sequentially later write.
+///
+/// NOTE(review): \p condLpSyms is not read here. The variable list is taken
+/// from the fields of \p lpCondType. Confirm whether the parameter is needed.
+static mlir::omp::DeclareReductionOp buildConditionalLastPrivateReduction(
+    lower::AbstractConverter &converter, fir::RecordType lpCondType,
+    const llvm::SetVector<const semantics::Symbol *> &condLpSyms) {
+
+  // Init callback: initialize all fields of the lp_t struct.
+  // Value fields get 0; index fields get -1.
+  //
+  // Returns a null mlir::Value to signal that initialization has already
+  // been performed directly on ompPriv. The reduction infrastructure
+  // (populateByRefInitAndCleanupRegions -> initAndCleanupUnboxedDerivedType)
+  // checks for a non-null scalarInitValue before emitting a store, so
+  // returning null here safely skips the redundant store.
+  auto genInitValueCB = [lpCondType](fir::FirOpBuilder &builder,
+                                     mlir::Location loc, mlir::Type type,
+                                     mlir::Value ompOrig,
+                                     mlir::Value ompPriv) -> mlir::Value {
+    initConditionalLpStruct(builder, loc, lpCondType, ompPriv);
+    return mlir::Value{};
+  };
+
+  // Combiner callback: for each (value, index) pair, pick the later iteration.
+  // Fields are arranged as: {val_0, ..., val_{N-1}, idx_0, ..., idx_{N-1}}
+  // where idx field names are "$" + val field name.
+  // If rhs.idx > lhs.idx, copy rhs value and index into lhs.
+  auto genCombinerCB = [lpCondType](fir::FirOpBuilder &builder,
+                                    mlir::Location loc, mlir::Type type,
+                                    mlir::Value lhs, mlir::Value rhs,
+                                    bool isByRef) {
+    fir::RecordType lpType = lpCondType; // non-const copy for getFieldIndex
+    auto fields = lpType.getTypeList();
+    unsigned numVars = fields.size() / 2;
+    mlir::Type i32Ty = builder.getI32Type();
+
+    // Walk the first half (value fields). Index field name = "$" +
+    // value name. The "$" character is invalid in Fortran identifiers,
+    // so the prefix cannot collide with any user variable name.
+    for (unsigned i = 0; i < numVars; ++i) {
+      // Bind by reference: a by-value binding would copy the name string.
+      const auto &[valName, valType] = fields[i];
+      std::string idxName = "$" + valName;
+      unsigned valIdx = lpType.getFieldIndex(valName);
+      unsigned idxIdx = lpType.getFieldIndex(idxName);
+      mlir::Type idxType = lpType.getType(idxIdx);
+
+      // Get addresses of LHS and RHS index fields
+      fir::IntOrValue idxFieldIdx = mlir::IntegerAttr::get(i32Ty, idxIdx);
+      mlir::Value lhsIdxAddr = fir::CoordinateOp::create(
+          builder, loc, builder.getRefType(idxType), lhs,
+          llvm::SmallVector<fir::IntOrValue, 1>{idxFieldIdx});
+      mlir::Value rhsIdxAddr = fir::CoordinateOp::create(
+          builder, loc, builder.getRefType(idxType), rhs,
+          llvm::SmallVector<fir::IntOrValue, 1>{idxFieldIdx});
+
+      mlir::Value lhsIdx = fir::LoadOp::create(builder, loc, lhsIdxAddr);
+      mlir::Value rhsIdx = fir::LoadOp::create(builder, loc, rhsIdxAddr);
+
+      // Compare: rhs index > lhs index (signed, iteration indices)
+      mlir::Value cmp = mlir::arith::CmpIOp::create(
+          builder, loc, mlir::arith::CmpIPredicate::sgt, rhsIdx, lhsIdx);
+
+      // If RHS comes from a later iteration, copy its value and index to LHS
+      auto ifOp =
+          fir::IfOp::create(builder, loc, cmp, /*withElseRegion=*/false);
+      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
+      // Copy value field: rhs.val_s -> lhs.val_s
+      fir::IntOrValue valFieldIdx = mlir::IntegerAttr::get(i32Ty, valIdx);
+      mlir::Value rhsValAddr = fir::CoordinateOp::create(
+          builder, loc, builder.getRefType(valType), rhs,
+          llvm::SmallVector<fir::IntOrValue, 1>{valFieldIdx});
+      mlir::Value lhsValAddr = fir::CoordinateOp::create(
+          builder, loc, builder.getRefType(valType), lhs,
+          llvm::SmallVector<fir::IntOrValue, 1>{valFieldIdx});
+      mlir::Value rhsVal = fir::LoadOp::create(builder, loc, rhsValAddr);
+      fir::StoreOp::create(builder, loc, rhsVal, lhsValAddr);
+
+      // Copy index field: rhs.idx_s -> lhs.idx_s
+      fir::StoreOp::create(builder, loc, rhsIdx, lhsIdxAddr);
+
+      // Continue emitting the next variable's comparison after the fir.if.
+      builder.setInsertionPointAfter(ifOp);
+    }
+
+    // By-ref: yield the accumulator (LHS)
+    mlir::omp::YieldOp::create(builder, loc, lhs);
+  };
+
+  // RecordType is always by-ref
+  bool isByRef = true;
+  mlir::Location loc = converter.getCurrentLocation();
+  mlir::Type redType = fir::ReferenceType::get(lpCondType);
+  std::string reductionName = ReductionProcessor::getReductionName(
+      "lp_cond", converter.getKindMap(), redType, isByRef);
+
+  return ReductionProcessor::createDeclareReductionHelper<
+      mlir::omp::DeclareReductionOp>(converter, reductionName, redType, loc,
+                                     isByRef, genCombinerCB, genInitValueCB);
+}
+
+/// Build a FIR RecordType for conditional lastprivate reduction.
+/// For symbols {x, y}, creates:
+/// !fir.type<_lp_cond_t.lN.M{x:T_x, y:T_y, kx:i64, ky:i64}>
+/// where N is the source line number and M is a monotonic counter.
+static fir::RecordType buildConditionalLpType(
+ lower::AbstractConverter &converter,
+ const llvm::SetVector<const semantics::Symbol *> &condLpSyms,
+ mlir::Location loc) {
+ fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+ mlir::MLIRContext *context = builder.getContext();
+
+ // Derive a unique suffix from the source location and a monotonic counter.
+ // The line number makes names traceable to source; the counter prevents
+ // collisions when INCLUDE files place directives on identical line numbers.
+ // Use atomic for thread-safety in case flang ever lowers in parallel.
+ static std::atomic<unsigned> counter{0};
+ unsigned line = 0;
+ if (auto fileLoc = mlir::dyn_cast<mlir::FileLineColLoc>(loc))
+ line = fileLoc.getLine();
+ else if (auto fusedLoc = mlir::dyn_cast<mlir::FusedLoc>(loc)) {
+ for (mlir::Location sub : fusedLoc.getLocations()) {
+ if (auto fileSub = mlir::dyn_cast<mlir::FileLineColLoc>(sub)) {
+ line = fileSub.getLine();
+ break;
+ }
+ }
+ }
+ std::string typeName =
+ "_lp_cond_t.l" + std::to_string(line) + "." + std::to_string(counter++);
+
+ auto lpCondType = fir::RecordType::get(context, typeName);
+
+ // if it exists return, else build
+ if (lpCondType.isFinalized())
+ return lpCondType;
+
+ // Build field list: first all value fields, then all index fields.
+ // Grouping values before indices (rather than interleaving value/index
+ // pairs) can reduce padding holes when value types differ from i64.
+ llvm::SmallVector<std::pair<std::string, mlir::Type>> fields;
+
+ // Value fields first.
+ for (const auto *sym : condLpSyms) {
+ std::string symName = sym->name().ToString();
+ mlir::Type symType = converter.genType(*sym);
----------------
sshrestha-aa wrote:
Non-scalars are not allowed with conditional lastprivate. Added a semantic check.
https://github.com/llvm/llvm-project/pull/200086
More information about the flang-commits
mailing list