[flang-commits] [flang] [mlir] [Flang][OpenMP] Support conditional lastprivate on host (PR #200086)

Sunil Shrestha via flang-commits flang-commits at lists.llvm.org
Thu Jun 18 23:50:35 PDT 2026


================
@@ -4297,6 +4840,317 @@ getReductionType(lower::AbstractConverter &converter,
   return reductionType;
 }
 
+/// Compute a flattened canonical (0-based, always ascending) iteration number
+/// from all loop IVs.  For a single loop, this is simply (IV - LB) / step.
+/// For collapsed loops with dimensions d0..dk, the flattened index is:
+///   c0 * (N1*N2*...*Nk) + c1 * (N2*...*Nk) + ... + ck
+/// where ci = (IVi - LBi) / stepi and Ni = (UBi - LBi) / stepi + 1.
+/// This yields a unique monotonic index regardless of loop direction,
+/// which is essential for the combiner's `sgt` comparison to correctly
+/// identify the sequentially last iteration.
+///
+/// The expression is evaluated in Horner form, ((c0*N1 + c1)*N2 + c2)...,
+/// entirely in i64.  Each IV, bound and step is first interpreted in the IV's
+/// own type and then widened to i64 *before* any subtraction, so a range that
+/// spans most of a narrow IV type (e.g. an i32 loop from near INT32_MIN to
+/// near INT32_MAX) cannot wrap around.
+///
+/// \returns an i64 value built at the builder's current insertion point.
+static mlir::Value
+computeFlattenedCanonicalIV(fir::FirOpBuilder &builder, mlir::Location loc,
+                            mlir::omp::LoopNestOp loopNestOp) {
+  mlir::Block &entry = loopNestOp.getRegion().front();
+  auto lbs = loopNestOp.getLoopLowerBounds();
+  auto ubs = loopNestOp.getLoopUpperBounds();
+  auto steps = loopNestOp.getLoopSteps();
+  const unsigned numDims = lbs.size();
+  assert(numDims > 0 && "expected at least one loop in omp.loop_nest");
+
+  mlir::Type i64Ty = builder.getI64Type();
+  mlir::Value one;
+  mlir::Value flatIdx;
+  for (unsigned d = 0; d < numDims; ++d) {
+    mlir::Value iv = entry.getArgument(d);
+    mlir::Type ivType = iv.getType();
+
+    // Convert to the IV type first (the bounds are interpreted in that type),
+    // then widen to i64 so the subtractions/divisions below cannot overflow.
+    auto toI64 = [&](mlir::Value v) -> mlir::Value {
+      if (v.getType() != ivType)
+        v = fir::ConvertOp::create(builder, loc, ivType, v);
+      if (v.getType() != i64Ty)
+        v = fir::ConvertOp::create(builder, loc, i64Ty, v);
+      return v;
+    };
+    mlir::Value iv64 = toI64(iv);
+    mlir::Value lb = toI64(lbs[d]);
+    mlir::Value step = toI64(steps[d]);
+
+    // ci = (IVi - LBi) / stepi: since IVi = LBi + k*stepi this is k, i.e. 0 on
+    // the first iteration and +1 per iteration whatever the sign of the step.
+    mlir::Value diff = mlir::arith::SubIOp::create(builder, loc, iv64, lb);
+    mlir::Value ci = mlir::arith::DivSIOp::create(builder, loc, diff, step);
+
+    if (d == 0) {
+      // The outermost trip count N0 never appears in the formula.
+      flatIdx = ci;
+      continue;
+    }
+
+    // Ni = (UBi - LBi) / stepi + 1  (loop bounds are inclusive).
+    mlir::Value ub = toI64(ubs[d]);
+    mlir::Value range = mlir::arith::SubIOp::create(builder, loc, ub, lb);
+    mlir::Value trips = mlir::arith::DivSIOp::create(builder, loc, range, step);
+    if (!one)
+      one = builder.createIntegerConstant(loc, i64Ty, 1);
+    trips = mlir::arith::AddIOp::create(builder, loc, trips, one);
+
+    // Horner step: flatIdx = flatIdx * Ni + ci.
+    flatIdx = mlir::arith::MulIOp::create(builder, loc, flatIdx, trips);
+    flatIdx = mlir::arith::AddIOp::create(builder, loc, flatIdx, ci);
+  }
+  return flatIdx;
+}
+
+/// Redirect each conditional-lastprivate symbol in \p condLpSyms to the value
+/// field of the same name in the \p lpType record addressed by \p structArg.
+///
+/// Call this \b before the construct body is lowered.  Every later reference
+/// to these symbols then resolves directly to the struct field, so no
+/// address rewriting is needed afterwards.
+///
+/// \returns a map from each emitted fir.coordinate_of field address to the
+/// symbol's name, used by \c injectCondLpIndexStores to recognise writes.
+static llvm::MapVector<mlir::Value, std::string> bindCondLpSymsToStructFields(
+    lower::AbstractConverter &converter, mlir::Location loc,
+    fir::RecordType lpType, mlir::Value structArg,
+    const llvm::SetVector<const semantics::Symbol *> &condLpSyms) {
+  fir::FirOpBuilder &firBuilder = converter.getFirOpBuilder();
+  mlir::IntegerType fieldIdxTy = firBuilder.getI32Type();
+  llvm::MapVector<mlir::Value, std::string> fieldAddrToName;
+
+  for (const semantics::Symbol *lpSym : condLpSyms) {
+    const std::string name = lpSym->name().ToString();
+    const unsigned fieldPos = lpType.getFieldIndex(name);
+    mlir::Type fieldRefTy = firBuilder.getRefType(lpType.getType(fieldPos));
+
+    // Address of the value field: coordinate_of(structArg, fieldPos).
+    llvm::SmallVector<fir::IntOrValue, 1> coorIndices{
+        mlir::IntegerAttr::get(fieldIdxTy, fieldPos)};
+    mlir::Value fieldAddr = fir::CoordinateOp::create(
+        firBuilder, loc, fieldRefTy, structArg, coorIndices);
+
+    converter.bindSymbol(*lpSym, fieldAddr);
+    fieldAddrToName[fieldAddr] = name;
+  }
+  return fieldAddrToName;
+}
+
+/// Instrument every write to a conditional-lastprivate value field inside
+/// \p region.  A write is an hlfir.assign or fir.store whose destination,
+/// after peeling at most one hlfir.declare, is a key of \p valAddrToSymName.
+/// Right after each such write, store the iteration index produced by
+/// \p genIndexVal (converted to i64 if needed) into the matching
+/// "$"-prefixed index field of \p structArg.
+static void injectCondLpIndexStores(
+    fir::FirOpBuilder &builder, mlir::Location loc, fir::RecordType lpType,
+    mlir::Value structArg, mlir::Region &region,
+    const llvm::MapVector<mlir::Value, std::string> &valAddrToSymName,
+    llvm::function_ref<mlir::Value(fir::FirOpBuilder &, mlir::Location)>
+        genIndexVal) {
+  // Body lowering wraps the bound fir.coordinate_of in an hlfir.declare, so
+  // peel one declare to recover the raw struct-field address.
+  auto stripDeclare = [](mlir::Value addr) -> mlir::Value {
+    auto decl = addr.getDefiningOp<hlfir::DeclareOp>();
+    return decl ? decl.getMemref() : addr;
+  };
+
+  // Collect (write op, targeted field address) pairs before mutating the IR.
+  llvm::SmallVector<std::pair<mlir::Operation *, mlir::Value>> writes;
+  region.walk([&](hlfir::AssignOp assign) {
+    mlir::Value dst = stripDeclare(assign.getLhs());
+    if (valAddrToSymName.count(dst))
+      writes.emplace_back(assign.getOperation(), dst);
+  });
+  region.walk([&](fir::StoreOp store) {
+    mlir::Value dst = stripDeclare(store.getMemref());
+    if (valAddrToSymName.count(dst))
+      writes.emplace_back(store.getOperation(), dst);
+  });
+  if (writes.empty())
+    return;
+
+  // Build the index once at the top of the entry block so that it dominates
+  // every write, including writes nested inside fir.if regions.
+  mlir::Value index;
+  {
+    mlir::OpBuilder::InsertionGuard entryGuard(builder);
+    builder.setInsertionPointToStart(&region.front());
+    index = genIndexVal(builder, loc);
+    mlir::Type i64Ty = builder.getI64Type();
+    if (index.getType() != i64Ty)
+      index = fir::ConvertOp::create(builder, loc, i64Ty, index);
+  }
+
+  mlir::IntegerType fieldIdxTy = builder.getI32Type();
+  for (auto &[op, dst] : writes) {
+    const unsigned idxField =
+        lpType.getFieldIndex("$" + valAddrToSymName.lookup(dst));
+    mlir::Type idxRefTy = builder.getRefType(lpType.getType(idxField));
+
+    mlir::OpBuilder::InsertionGuard afterGuard(builder);
+    builder.setInsertionPointAfter(op);
+    llvm::SmallVector<fir::IntOrValue, 1> coorIndices{
+        mlir::IntegerAttr::get(fieldIdxTy, idxField)};
+    mlir::Value idxAddr = fir::CoordinateOp::create(builder, loc, idxRefTy,
+                                                    structArg, coorIndices);
+    fir::StoreOp::create(builder, loc, index, idxAddr);
+  }
+}
+
+/// Create, via ReductionProcessor::createDeclareReductionHelper, the
+/// omp.declare_reduction that implements `lastprivate(conditional: ...)`
+/// over the record type \p lpCondType.
+///
+/// The reduction is by-reference on `!fir.ref<lpCondType>`.  Each value field
+/// of the record is paired with an index field named "$" + the value field's
+/// name.  Per variable, the combiner keeps the value whose index is larger.
+///
+/// NOTE(review): \p condLpSyms is not read in this function; the field layout
+/// is taken entirely from \p lpCondType.
+static mlir::omp::DeclareReductionOp buildConditionalLastPrivateReduction(
+    lower::AbstractConverter &converter, fir::RecordType lpCondType,
+    const llvm::SetVector<const semantics::Symbol *> &condLpSyms) {
+
+  // Init callback: initialize all fields of the lp_t struct.
+  // Value fields get 0; index fields get -1.
+  //
+  // Returns a null mlir::Value to signal that initialization has already
+  // been performed directly on ompPriv.  The reduction infrastructure
+  // (populateByRefInitAndCleanupRegions → initAndCleanupUnboxedDerivedType)
+  // checks for a non-null scalarInitValue before emitting a store, so
+  // returning null here safely skips the redundant store.
+  // (`type` and `ompOrig` are unused here.)
+  auto genInitValueCB = [lpCondType](fir::FirOpBuilder &builder,
+                                     mlir::Location loc, mlir::Type type,
+                                     mlir::Value ompOrig,
+                                     mlir::Value ompPriv) -> mlir::Value {
+    initConditionalLpStruct(builder, loc, lpCondType, ompPriv);
+    return mlir::Value{};
+  };
+
+  // Combiner callback: for each (value, index) pair, pick the later iteration.
+  // Fields are arranged as: {val_0, ..., val_{N-1}, idx_0, ..., idx_{N-1}}
+  // where idx field names are "$" + val field name.
+  // If rhs.idx > lhs.idx, copy rhs value and index into lhs.
+  // (`type` and `isByRef` are unused; the result is always yielded by ref.)
+  auto genCombinerCB = [lpCondType](fir::FirOpBuilder &builder,
+                                    mlir::Location loc, mlir::Type type,
+                                    mlir::Value lhs, mlir::Value rhs,
+                                    bool isByRef) {
+    fir::RecordType lpType = lpCondType; // non-const copy for getFieldIndex
+    auto fields = lpType.getTypeList();
+    // Relies on the layout above: exactly half of the fields are values.
+    unsigned numVars = fields.size() / 2;
+
+    // Walk the first half (value fields). Index field name = "$" +
+    // value name.  The "$" character is invalid in Fortran identifiers,
+    // so the prefix cannot collide with any user variable name.
+    for (unsigned i = 0; i < numVars; ++i) {
+      auto [valName, valType] = fields[i];
+      std::string idxName = "$" + valName;
+      unsigned valIdx = lpType.getFieldIndex(valName);
+      unsigned idxIdx = lpType.getFieldIndex(idxName);
+      mlir::Type idxType = lpType.getType(idxIdx);
+
+      // Get addresses of LHS and RHS index fields
+      fir::IntOrValue idxFieldIdx =
+          mlir::IntegerAttr::get(builder.getI32Type(), idxIdx);
+      mlir::Value lhsIdxAddr = fir::CoordinateOp::create(
+          builder, loc, builder.getRefType(idxType), lhs,
+          llvm::SmallVector<fir::IntOrValue, 1>{idxFieldIdx});
+      mlir::Value rhsIdxAddr = fir::CoordinateOp::create(
+          builder, loc, builder.getRefType(idxType), rhs,
+          llvm::SmallVector<fir::IntOrValue, 1>{idxFieldIdx});
+
+      mlir::Value lhsIdx = fir::LoadOp::create(builder, loc, lhsIdxAddr);
+      mlir::Value rhsIdx = fir::LoadOp::create(builder, loc, rhsIdxAddr);
+
+      // Compare: rhs index > lhs index  (signed, iteration indices)
+      // Strict comparison: when both indices are equal, LHS is kept as-is.
+      mlir::Value cmp = mlir::arith::CmpIOp::create(
+          builder, loc, mlir::arith::CmpIPredicate::sgt, rhsIdx, lhsIdx);
+
+      // If RHS comes from a later iteration, copy its value and index to LHS
+      // (the copies are emitted at the start of the fir.if then-block).
+      auto ifOp = fir::IfOp::create(builder, loc, cmp, /*else*/ false);
+      builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
+
+      // Copy value field: rhs.val_s → lhs.val_s
+      fir::IntOrValue valFieldIdx =
+          mlir::IntegerAttr::get(builder.getI32Type(), valIdx);
+      mlir::Value rhsValAddr = fir::CoordinateOp::create(
+          builder, loc, builder.getRefType(valType), rhs,
+          llvm::SmallVector<fir::IntOrValue, 1>{valFieldIdx});
+      mlir::Value lhsValAddr = fir::CoordinateOp::create(
+          builder, loc, builder.getRefType(valType), lhs,
+          llvm::SmallVector<fir::IntOrValue, 1>{valFieldIdx});
+      mlir::Value rhsVal = fir::LoadOp::create(builder, loc, rhsValAddr);
+      fir::StoreOp::create(builder, loc, rhsVal, lhsValAddr);
+
+      // Copy index field: rhs.idx_s → lhs.idx_s
+      fir::StoreOp::create(builder, loc, rhsIdx, lhsIdxAddr);
+
+      // Resume after the fir.if so the next variable's check is emitted
+      // unconditionally rather than nested inside this then-block.
+      builder.setInsertionPointAfter(ifOp);
+    }
+
+    // By-ref: yield the accumulator (LHS)
+    mlir::omp::YieldOp::create(builder, loc, lhs);
+  };
+
+  // RecordType is always by-ref
+  bool isByRef = true;
+  mlir::Location loc = converter.getCurrentLocation();
+  mlir::Type redType = fir::ReferenceType::get(lpCondType);
+  // The name is derived from the "lp_cond" prefix and the reference type
+  // (which embeds the record's name).
+  std::string reductionName = ReductionProcessor::getReductionName(
+      "lp_cond", converter.getKindMap(), redType, isByRef);
+
+  return ReductionProcessor::createDeclareReductionHelper<
+      mlir::omp::DeclareReductionOp>(converter, reductionName, redType, loc,
+                                     isByRef, genCombinerCB, genInitValueCB);
+}
+
+/// Build a FIR RecordType for conditional lastprivate reduction.
+/// For symbols {x, y}, creates:
+///   !fir.type<_lp_cond_t.lN.M{x:T_x, y:T_y, $x:i64, $y:i64}>
+/// where N is the source line number and M is a monotonic counter.
+static fir::RecordType buildConditionalLpType(
+    lower::AbstractConverter &converter,
+    const llvm::SetVector<const semantics::Symbol *> &condLpSyms,
+    mlir::Location loc) {
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+  mlir::MLIRContext *context = builder.getContext();
+
+  // Derive a unique suffix from the source location and a monotonic counter.
+  // The line number makes names traceable to source; the counter prevents
+  // collisions when INCLUDE files place directives on identical line numbers.
+  // Use atomic for thread-safety in case flang ever lowers in parallel.
+  static std::atomic<unsigned> counter{0};
+  unsigned line = 0;
+  if (auto fileLoc = mlir::dyn_cast<mlir::FileLineColLoc>(loc))
+    line = fileLoc.getLine();
+  else if (auto fusedLoc = mlir::dyn_cast<mlir::FusedLoc>(loc)) {
+    for (mlir::Location sub : fusedLoc.getLocations()) {
+      if (auto fileSub = mlir::dyn_cast<mlir::FileLineColLoc>(sub)) {
+        line = fileSub.getLine();
+        break;
+      }
+    }
+  }
+  std::string typeName =
+      "_lp_cond_t.l" + std::to_string(line) + "." + std::to_string(counter++);
+
+  auto lpCondType = fir::RecordType::get(context, typeName);
+
+  // if it exists return, else build
+  if (lpCondType.isFinalized())
+    return lpCondType;
+
+  // Build field list: first all value fields, then all index fields.
+  // Grouping values before indices (rather than interleaving value/index
+  // pairs) can reduce padding holes when value types differ from i64.
+  llvm::SmallVector<std::pair<std::string, mlir::Type>> fields;
+
+  // Value fields first.
+  for (const auto *sym : condLpSyms) {
+    std::string symName = sym->name().ToString();
+    mlir::Type symType = converter.genType(*sym);
----------------
sshrestha-aa wrote:

Non-scalars are not allowed with conditional lastprivate. Added a semantic check.

https://github.com/llvm/llvm-project/pull/200086


More information about the flang-commits mailing list