[flang-commits] [flang] [Flang][OpenMP] Minimize host ops remaining in device compilation (PR #137200)
Pranav Bhandarkar via flang-commits
flang-commits at lists.llvm.org
Wed Aug 27 22:54:30 PDT 2025
================
@@ -102,5 +147,305 @@ class FunctionFilteringPass
return WalkResult::advance();
});
}
+
+private:
+ /// Rewrite the given host device function containing \c omp.target
+ /// operations, to remove host-only operations that are not used by device
+ /// codegen.
+ ///
+ /// It is based on the expected form of the MLIR module as produced by Flang
+ /// lowering and it performs the following mutations:
+ /// - Replace all values returned by the function with \c fir.undefined.
+ /// - \c omp.target operations are moved to the end of the function. If they
+ /// are nested inside of any other operations, they are hoisted out of
+ /// them.
+ /// - \c depend, \c device and \c if clauses are removed from these target
+ /// functions. Values used to initialize other clauses are replaced by
+ /// placeholders as follows:
+ /// - Values defined by block arguments are replaced by placeholders only
+ /// if they are not attached to the parent \c func.func operation. In
+ /// that case, they are passed unmodified.
+ /// - \c arith.constant and \c fir.address_of ops are maintained.
+ /// - Values of type \c fir.boxchar are replaced with a combination of
+ /// \c fir.alloca for a single bit and a \c fir.emboxchar.
+ /// - Other values are replaced by a combination of an \c fir.alloca for a
+ /// single bit and an \c fir.convert to the original type of the value.
+ /// This can be done because the code eventually generated for these
+ /// operations will be discarded, as they aren't runnable by the target
+ /// device.
+ /// - \c omp.map.info operations associated to these target regions are
+ /// preserved. These are moved above all \c omp.target and sorted to
+ /// satisfy dependencies among them.
+ /// - \c bounds arguments are removed from \c omp.map.info operations.
+ /// - \c var_ptr and \c var_ptr_ptr arguments of \c omp.map.info are
+ /// handled as follows:
+ /// - \c var_ptr_ptr is expected to be defined by a \c fir.box_offset
+ /// operation which is preserved. Otherwise, the pass will fail.
+ /// - \c var_ptr can be defined by an \c hlfir.declare which is also
+ /// preserved. Its \c memref argument is replaced by a placeholder or
+ /// maintained, similarly to non-map clauses of target operations
+ /// described above. If it has \c shape or \c typeparams arguments, they
+ /// are replaced by applicable constants. \c dummy_scope arguments
+ /// are discarded.
+ /// - Every other operation not located inside of an \c omp.target is
+ /// removed.
+ LogicalResult rewriteHostFunction(func::FuncOp funcOp) {
+ Region &region = funcOp.getRegion();
+
+ // Collect target operations inside of the function.
+ llvm::SmallVector<omp::TargetOp> targetOps;
+ region.walk<WalkOrder::PreOrder>([&](Operation *op) {
+ // Skip the inside of omp.target regions, since these contain device code.
+ if (auto targetOp = dyn_cast<omp::TargetOp>(op)) {
+ targetOps.push_back(targetOp);
+ return WalkResult::skip();
+ }
+
+ // Replace omp.target_data entry block argument uses with the value used
+ // to initialize the associated omp.map.info operation. This way,
+ // references are still valid once the omp.target operation has been
+ // extracted out of the omp.target_data region.
+ if (auto targetDataOp = dyn_cast<omp::TargetDataOp>(op)) {
+ llvm::SmallVector<std::pair<Value, BlockArgument>> argPairs;
+ cast<omp::BlockArgOpenMPOpInterface>(*targetDataOp)
+ .getBlockArgsPairs(argPairs);
+ for (auto [operand, blockArg] : argPairs) {
+ auto mapInfo = cast<omp::MapInfoOp>(operand.getDefiningOp());
+ Value varPtr = mapInfo.getVarPtr();
+
+ // If the var_ptr operand of the omp.map.info op defining this entry
+ // block argument is an hlfir.declare, the uses of all users of that
+ // entry block argument that are themselves hlfir.declare are replaced
+ // by values produced by the outer one.
+ //
+ // This prevents this pass from producing chains of hlfir.declare of
+ // the type:
+ // %0 = ...
+ // %1:2 = hlfir.declare %0
+ // %2:2 = hlfir.declare %1#1...
+ // %3 = omp.map.info var_ptr(%2#1 ...
+ if (auto outerDeclare = varPtr.getDefiningOp<hlfir::DeclareOp>())
+ for (Operation *user : blockArg.getUsers())
+ if (isa<hlfir::DeclareOp>(user))
+ user->replaceAllUsesWith(outerDeclare);
+
+ // All remaining uses of the entry block argument are replaced with
+ // the var_ptr initialization value.
+ blockArg.replaceAllUsesWith(varPtr);
+ }
+ }
+ return WalkResult::advance();
+ });
+
+ // Make a temporary clone of the parent operation with an empty region,
+ // and update all references to entry block arguments to those of the new
+ // region. Users will later either be moved to the new region or deleted
+ // when the original region is replaced by the new.
+ OpBuilder builder(&getContext());
+ builder.setInsertionPointAfter(funcOp);
+ Operation *newOp = builder.cloneWithoutRegions(funcOp);
+ Block &block = newOp->getRegion(0).emplaceBlock();
+
+ llvm::SmallVector<Location> locs;
+ locs.reserve(region.getNumArguments());
+ llvm::transform(region.getArguments(), std::back_inserter(locs),
+ [](const BlockArgument &arg) { return arg.getLoc(); });
+ block.addArguments(region.getArgumentTypes(), locs);
+
+ for (auto [oldArg, newArg] :
+ llvm::zip_equal(region.getArguments(), block.getArguments()))
+ oldArg.replaceAllUsesWith(newArg);
+
+ // Collect omp.map.info ops while satisfying interdependencies and remove
+ // operands that aren't used by target device codegen.
+ //
+ // This logic must be updated whenever operands to omp.target change.
+ llvm::SetVector<Value> rewriteValues;
+ llvm::SetVector<omp::MapInfoOp> mapInfos;
+ for (omp::TargetOp targetOp : targetOps) {
+ assert(targetOp.getHostEvalVars().empty() &&
+ "unexpected host_eval in target device module");
+
+ // Variables unused by the device.
+ targetOp.getDependVarsMutable().clear();
+ targetOp.setDependKindsAttr(nullptr);
+ targetOp.getDeviceMutable().clear();
+ targetOp.getIfExprMutable().clear();
+
+ // TODO: Clear some of these operands rather than rewriting them,
+ // depending on whether they are needed by device codegen once support for
+ // them is fully implemented.
+ for (Value allocVar : targetOp.getAllocateVars())
+ collectRewrite(allocVar, rewriteValues);
+ for (Value allocVar : targetOp.getAllocatorVars())
+ collectRewrite(allocVar, rewriteValues);
+ for (Value inReduction : targetOp.getInReductionVars())
+ collectRewrite(inReduction, rewriteValues);
+ for (Value isDevPtr : targetOp.getIsDevicePtrVars())
+ collectRewrite(isDevPtr, rewriteValues);
+ for (Value mapVar : targetOp.getHasDeviceAddrVars())
+ collectRewrite(cast<omp::MapInfoOp>(mapVar.getDefiningOp()), mapInfos);
+ for (Value mapVar : targetOp.getMapVars())
+ collectRewrite(cast<omp::MapInfoOp>(mapVar.getDefiningOp()), mapInfos);
+ for (Value privateVar : targetOp.getPrivateVars())
+ collectRewrite(privateVar, rewriteValues);
+ if (Value threadLimit = targetOp.getThreadLimit())
+ collectRewrite(threadLimit, rewriteValues);
+ }
+
+ // Move omp.map.info ops to the new block and collect dependencies.
+ llvm::SetVector<hlfir::DeclareOp> declareOps;
+ llvm::SetVector<fir::BoxOffsetOp> boxOffsets;
+ for (omp::MapInfoOp mapOp : mapInfos) {
+ if (auto declareOp = dyn_cast_if_present<hlfir::DeclareOp>(
+ mapOp.getVarPtr().getDefiningOp()))
+ collectRewrite(declareOp, declareOps);
+ else
+ collectRewrite(mapOp.getVarPtr(), rewriteValues);
+
+ if (Value varPtrPtr = mapOp.getVarPtrPtr()) {
+ if (auto boxOffset = llvm::dyn_cast_if_present<fir::BoxOffsetOp>(
+ varPtrPtr.getDefiningOp()))
+ collectRewrite(boxOffset, boxOffsets);
+ else
+ return mapOp->emitOpError() << "var_ptr_ptr rewrite only supported "
+ "if defined by fir.box_offset";
+ }
+
+ // Bounds are not used during target device codegen.
+ mapOp.getBoundsMutable().clear();
+ mapOp->moveBefore(&block, block.end());
+ }
+
+ // Create a temporary marker to simplify the op moving process below.
+ builder.setInsertionPointToStart(&block);
+ auto marker = builder.create<fir::UndefOp>(builder.getUnknownLoc(),
+ builder.getNoneType());
+ builder.setInsertionPoint(marker);
+
+ // Handle dependencies of hlfir.declare ops.
+ for (hlfir::DeclareOp declareOp : declareOps) {
+ collectRewrite(declareOp.getMemref(), rewriteValues);
+
+ // Shape and typeparams aren't needed for target device codegen, but
+ // removing them would break verifiers.
+ Value zero;
+ if (declareOp.getShape() || !declareOp.getTypeparams().empty())
+ zero = builder.create<arith::ConstantOp>(declareOp.getLoc(),
+ builder.getI64IntegerAttr(0));
+
+ if (auto shape = declareOp.getShape()) {
+ // The pre-cg rewrite pass requires the shape to be defined by one of
+ // fir.shape, fir.shapeshift or fir.shift, so we need to make sure it's
+ // still defined by one of these after this pass.
+ Operation *shapeOp = shape.getDefiningOp();
+ llvm::SmallVector<Value> extents(shapeOp->getNumOperands(), zero);
+ Value newShape =
+ llvm::TypeSwitch<Operation *, Value>(shapeOp)
+ .Case([&](fir::ShapeOp op) {
+ return builder.create<fir::ShapeOp>(op.getLoc(), extents);
+ })
+ .Case([&](fir::ShapeShiftOp op) {
+ auto type = fir::ShapeShiftType::get(op.getContext(),
+ extents.size() / 2);
+ return builder.create<fir::ShapeShiftOp>(op.getLoc(), type,
+ extents);
+ })
+ .Case([&](fir::ShiftOp op) {
+ auto type =
+ fir::ShiftType::get(op.getContext(), extents.size());
+ return builder.create<fir::ShiftOp>(op.getLoc(), type,
+ extents);
+ })
+ .Default([](Operation *op) {
+ op->emitOpError()
+ << "hlfir.declare shape expected to be one of: "
+ "fir.shape, fir.shapeshift or fir.shift";
+ return nullptr;
+ });
+
+ if (!newShape)
+ return failure();
+
+ declareOp.getShapeMutable().assign(newShape);
+ }
+
+ for (OpOperand &typeParam : declareOp.getTypeparamsMutable())
+ typeParam.assign(zero);
+
+ declareOp.getDummyScopeMutable().clear();
+ }
+
+ // We don't actually need the proper initialization, but rather just
+ // maintain the basic form of these operands. We create 1-bit placeholder
+ // allocas that we "typecast" to the expected type and replace all uses.
+ // Using fir.undefined here instead is not possible because these variables
+ // cannot be constants, as that would trigger different codegen for target
+ // regions.
+ for (Value value : rewriteValues) {
+ Location loc = value.getLoc();
+ Value rewriteValue;
+ // If it's defined by fir.address_of, then we need to keep that op as
+ // well because it might be pointing to a 'declare target' global.
+ // Constants can also trigger different codegen paths, so we keep them as
+ // well.
+ if (isa_and_present<arith::ConstantOp, fir::AddrOfOp>(
+ value.getDefiningOp())) {
+ rewriteValue = builder.clone(*value.getDefiningOp())->getResult(0);
+ } else if (auto boxCharType =
+ dyn_cast<fir::BoxCharType>(value.getType())) {
+ // !fir.boxchar types cannot be directly obtained by converting a
+ // !fir.ref<i1>, as they aren't reference types. Since they can appear
+ // representing some `target firstprivate` clauses, we need to create
+ // a special case here based on creating a placeholder fir.emboxchar op.
+ MLIRContext *ctx = &getContext();
+ fir::KindTy kind = boxCharType.getKind();
+ auto placeholder = builder.create<fir::AllocaOp>(
+ loc, fir::CharacterType::getSingleton(ctx, kind));
+ auto one = builder.create<arith::ConstantOp>(
+ loc, builder.getI32Type(), builder.getI32IntegerAttr(1));
+ rewriteValue = builder.create<fir::EmboxCharOp>(loc, boxCharType,
+ placeholder, one);
+ } else {
+ Value placeholder =
----------------
bhandarkar-pranav wrote:
nit: why not simply create a `fir.alloca` of `value.getType()` directly, since all of this is going to be discarded for device codegen anyway?
https://github.com/llvm/llvm-project/pull/137200
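
For readers following the nit, the placeholder selection described by the quoted doc comment and the `rewriteValues` loop can be modelled as a small decision function. This is a minimal sketch, not Flang code: `ToyValue`, `DefKind`, `Rewrite` and `chooseRewrite` are hypothetical names, and handling function block arguments before everything else is an assumption, since the `collectRewrite` helper is not part of this hunk.

```cpp
#include <cassert>

// Hypothetical stand-ins for the MLIR concepts in the patch. None of these
// names exist in Flang or MLIR.
enum class DefKind {
  FuncBlockArg, // entry block argument of the parent func.func
  Constant,     // defined by arith.constant
  AddressOf,    // defined by fir.address_of
  Other         // any other op result, or a non-function block argument
};

struct ToyValue {
  DefKind def;    // what defines the value
  bool isBoxChar; // whether its type is !fir.boxchar
};

enum class Rewrite {
  PassThrough,         // keep the value unmodified
  CloneDefiningOp,     // keep arith.constant / fir.address_of
  AllocaPlusEmboxChar, // placeholder fir.alloca + fir.emboxchar
  AllocaPlusConvert    // placeholder fir.alloca + fir.convert
};

// Mirrors the order of checks in the quoted loop: defining op first, then
// the !fir.boxchar special case, then the generic placeholder.
Rewrite chooseRewrite(const ToyValue &v) {
  // Assumption: function arguments are filtered out before the loop runs.
  if (v.def == DefKind::FuncBlockArg)
    return Rewrite::PassThrough;
  if (v.def == DefKind::Constant || v.def == DefKind::AddressOf)
    return Rewrite::CloneDefiningOp;
  if (v.isBoxChar)
    return Rewrite::AllocaPlusEmboxChar;
  return Rewrite::AllocaPlusConvert;
}
```

In this model, the nit concerns only what the final branch creates. The other three branches would be unaffected.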