[flang-commits] [flang] [mlir] [Flang][OpenMP] Support conditional lastprivate on host (PR #200086)
Tom Eccles via flang-commits
flang-commits at lists.llvm.org
Mon Jun 22 07:16:22 PDT 2026
================
@@ -4297,6 +4843,319 @@ getReductionType(lower::AbstractConverter &converter,
return reductionType;
}
+/// Compute a flattened canonical (0-based, always ascending) iteration number
+/// from all loop IVs. For a single loop, this is simply (IV - LB) / step.
+/// For collapsed loops with dimensions d0..dk, the flattened index is:
+///   c0 * (N1*N2*...*Nk) + c1 * (N2*...*Nk) + ... + ck
+/// where ci = (IVi - LBi) / stepi and Ni = (UBi - LBi) / stepi + 1.
+/// This yields a unique monotonic index regardless of loop direction,
+/// which is essential for the combiner's `sgt` comparison to correctly
+/// identify the sequentially last iteration.
+///
+/// \param builder    Builder through which all index arithmetic is created.
+/// \param loc        Location attached to every created operation.
+/// \param loopNestOp Loop nest whose lower bounds, upper bounds, steps and
+///                   entry-block arguments (the IVs) are read.
+/// \return The flattened iteration index as an i64 value.
+static mlir::Value
+computeFlattenedCanonicalIV(fir::FirOpBuilder &builder, mlir::Location loc,
+                            mlir::omp::LoopNestOp loopNestOp) {
+  mlir::Region &region = loopNestOp.getRegion();
+  auto lbs = loopNestOp.getLoopLowerBounds();
+  auto ubs = loopNestOp.getLoopUpperBounds();
+  auto steps = loopNestOp.getLoopSteps();
+  unsigned numDims = lbs.size();
+
+  // Use i64 for the flattened index to avoid overflow.
+  // NOTE(review): the per-dimension subtraction and division below are still
+  // performed in the IV's own type and only widened to i64 afterwards.
+  mlir::Type i64Ty = builder.getI64Type();
+
+  // Compute canonical IV and trip count for each dimension.
+  llvm::SmallVector<mlir::Value> canonIVs(numDims);
+  llvm::SmallVector<mlir::Value> tripCounts(numDims);
+  for (unsigned d = 0; d < numDims; ++d) {
+    // The d-th entry-block argument of the loop nest region is the IV of
+    // dimension d.
+    mlir::Value iv = region.front().getArgument(d);
+    mlir::Type ivType = iv.getType();
+    mlir::Value lb = lbs[d];
+    mlir::Value ub = ubs[d];
+    mlir::Value step = steps[d];
+    // Bring the bounds and step to the IV type so the arith ops below see
+    // matching operand types.
+    if (lb.getType() != ivType)
+      lb = fir::ConvertOp::create(builder, loc, ivType, lb);
+    if (ub.getType() != ivType)
+      ub = fir::ConvertOp::create(builder, loc, ivType, ub);
+    if (step.getType() != ivType)
+      step = fir::ConvertOp::create(builder, loc, ivType, step);
+
+    // Canonical IV: ci = (IV - LB) / step, widened to i64.
+    mlir::Value diff = mlir::arith::SubIOp::create(builder, loc, iv, lb);
+    mlir::Value ci = mlir::arith::DivSIOp::create(builder, loc, diff, step);
+    canonIVs[d] = fir::ConvertOp::create(builder, loc, i64Ty, ci);
+
+    // Trip count: (UB - LB) / step + 1 (loop bounds are inclusive).
+    // The flattening loop below starts at d = 1, so tripCounts[0] is computed
+    // here but never consumed.
+    mlir::Value range = mlir::arith::SubIOp::create(builder, loc, ub, lb);
+    mlir::Value trips = mlir::arith::DivSIOp::create(builder, loc, range, step);
+    mlir::Value one = builder.createIntegerConstant(loc, ivType, 1);
+    trips = mlir::arith::AddIOp::create(builder, loc, trips, one);
+    tripCounts[d] = fir::ConvertOp::create(builder, loc, i64Ty, trips);
+  }
+
+  // Flatten: result = c0*N1*N2*...*Nk + c1*N2*...*Nk + ... + ck
+  // Evaluated Horner-style: ((c0 * N1 + c1) * N2 + c2) ...
+  // This indexes canonIVs[0] unconditionally, so the loop nest is assumed to
+  // have at least one dimension.
+  mlir::Value flatIdx = canonIVs[0];
+  for (unsigned d = 1; d < numDims; ++d) {
+    flatIdx = mlir::arith::MulIOp::create(builder, loc, flatIdx, tripCounts[d]);
+    flatIdx = mlir::arith::AddIOp::create(builder, loc, flatIdx, canonIVs[d]);
+  }
+  return flatIdx;
+}
+
+/// Bind conditional lastprivate symbols to their value fields inside the
+/// reduction struct. This must be called \b before body lowering so that all
+/// references to the LP symbols resolve to struct field addresses directly,
+/// avoiding the need for a post-hoc address-replacement rewrite.
+///
+/// Returns a map from the newly-created struct-field addresses to symbol names
+/// so that \c injectCondLpIndexStores can later locate writes to these fields.
+static llvm::MapVector<mlir::Value, std::string> bindCondLpSymsToStructFields(
+ lower::AbstractConverter &converter, mlir::Location loc,
+ fir::RecordType lpType, mlir::Value structArg,
+ const llvm::SetVector<const semantics::Symbol *> &condLpSyms) {
+ fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+ llvm::MapVector<mlir::Value, std::string> valAddrToSymName;
+ for (const auto *sym : condLpSyms) {
+ std::string symName = sym->name().ToString();
+ unsigned valFieldIdx = lpType.getFieldIndex(symName);
+ mlir::Type valType = lpType.getType(valFieldIdx);
+
+ fir::IntOrValue valFIdx =
+ mlir::IntegerAttr::get(builder.getI32Type(), valFieldIdx);
+ mlir::Value valAddr = fir::CoordinateOp::create(
+ builder, loc, builder.getRefType(valType), structArg,
+ llvm::SmallVector<fir::IntOrValue, 1>{valFIdx});
+
+ converter.bindSymbol(*sym, valAddr);
----------------
tblah wrote:
This won't work for charboxes.
This would have been easier to spot if you had added test coverage for all kinds of character when you fixed characters.
https://github.com/llvm/llvm-project/pull/200086
More information about the flang-commits
mailing list