[flang-commits] [flang] [Flang][OpenMP] Minimize host ops remaining in device compilation (PR #137200)

Pranav Bhandarkar via flang-commits flang-commits at lists.llvm.org
Wed Aug 27 22:54:30 PDT 2025


================
@@ -102,5 +147,305 @@ class FunctionFilteringPass
       return WalkResult::advance();
     });
   }
+
+private:
+  /// Rewrite the given host function of a device compilation module, which
+  /// contains \c omp.target operations, to remove host-only operations that
+  /// are not used by device codegen.
+  ///
+  /// It relies on the expected form of the MLIR module as produced by Flang
+  /// lowering and performs the following mutations:
+  ///   - All values returned by the function are replaced with
+  ///     \c fir.undefined.
+  ///   - \c omp.target operations are moved to the end of the function. If
+  ///     they are nested inside any other operations, they are hoisted out of
+  ///     them.
+  ///   - \c depend, \c device and \c if clauses are removed from these target
+  ///     operations. Values used to initialize other clauses are replaced by
+  ///     placeholders as follows:
+  ///     - Values defined by block arguments are replaced by placeholders only
+  ///       if they are not attached to the parent \c func.func operation;
+  ///       arguments of the function itself are passed unmodified.
+  ///     - \c arith.constant and \c fir.address_of ops are maintained.
+  ///     - Values of type \c fir.boxchar are replaced with a combination of a
+  ///       \c fir.alloca of a single character and a \c fir.emboxchar.
+  ///     - Other values are replaced by a combination of a \c fir.alloca for a
+  ///       single bit and a \c fir.convert to the original type of the value
+  ///       (see the sketch below). This can be done because the code
+  ///       eventually generated for these operations will be discarded, as it
+  ///       isn't runnable by the target device.
+  ///   - \c omp.map.info operations associated with these target regions are
+  ///     preserved. These are moved above all \c omp.target operations and
+  ///     sorted to satisfy dependencies among them.
+  ///   - \c bounds arguments are removed from \c omp.map.info operations.
+  ///   - \c var_ptr and \c var_ptr_ptr arguments of \c omp.map.info are
+  ///     handled as follows:
+  ///     - \c var_ptr_ptr is expected to be defined by a \c fir.box_offset
+  ///       operation, which is preserved; otherwise, the pass fails.
+  ///     - \c var_ptr can be defined by an \c hlfir.declare, which is also
+  ///       preserved. Its \c memref argument is replaced by a placeholder or
+  ///       maintained, similarly to the non-map clauses of target operations
+  ///       described above. If it has \c shape or \c typeparams arguments, they
+  ///       are replaced by applicable constants. \c dummy_scope arguments
+  ///       are discarded.
+  ///   - Every other operation not located inside an \c omp.target is
+  ///     removed.
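+  ///
+  /// For instance (an illustrative sketch; names are hypothetical), a
+  /// host-defined value %v : !fir.ref<i32> used by one of the rewritten
+  /// clauses would be re-created as:
+  ///   %ph = fir.alloca i1
+  ///   %v = fir.convert %ph : (!fir.ref<i1>) -> !fir.ref<i32>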
+  LogicalResult rewriteHostFunction(func::FuncOp funcOp) {
+    Region &region = funcOp.getRegion();
+
+    // Collect target operations inside the function.
+    llvm::SmallVector<omp::TargetOp> targetOps;
+    region.walk<WalkOrder::PreOrder>([&](Operation *op) {
+      // Skip the inside of omp.target regions, since these contain device code.
+      if (auto targetOp = dyn_cast<omp::TargetOp>(op)) {
+        targetOps.push_back(targetOp);
+        return WalkResult::skip();
+      }
+
+      // Replace omp.target_data entry block argument uses with the value used
+      // to initialize the associated omp.map.info operation. This way,
+      // references are still valid once the omp.target operation has been
+      // extracted out of the omp.target_data region.
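+      //
+      // Sketch (operand syntax abbreviated):
+      //   %mi = omp.map.info var_ptr(%v ...)
+      //   omp.target_data ...(%mi -> %arg0 ...) { ... uses of %arg0 ... }
+      // becomes a region whose uses of %arg0 refer to %v instead.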
+      if (auto targetDataOp = dyn_cast<omp::TargetDataOp>(op)) {
+        llvm::SmallVector<std::pair<Value, BlockArgument>> argPairs;
+        cast<omp::BlockArgOpenMPOpInterface>(*targetDataOp)
+            .getBlockArgsPairs(argPairs);
+        for (auto [operand, blockArg] : argPairs) {
+          auto mapInfo = cast<omp::MapInfoOp>(operand.getDefiningOp());
+          Value varPtr = mapInfo.getVarPtr();
+
+          // If the var_ptr operand of the omp.map.info op defining this entry
+          // block argument is an hlfir.declare, then every user of that entry
+          // block argument that is itself an hlfir.declare has its uses
+          // replaced by the values produced by the outer one.
+          //
+          // This prevents this pass from producing chains of hlfir.declare of
+          // the form:
+          // %0 = ...
+          // %1:2 = hlfir.declare %0
+          // %2:2 = hlfir.declare %1#1...
+          // %3 = omp.map.info var_ptr(%2#1 ...
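+          // After the rewrite, the inner hlfir.declare becomes dead and the
+          // chain collapses to:
+          // %0 = ...
+          // %1:2 = hlfir.declare %0
+          // %3 = omp.map.info var_ptr(%1#1 ...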
+          if (auto outerDeclare = varPtr.getDefiningOp<hlfir::DeclareOp>())
+            for (Operation *user : blockArg.getUsers())
+              if (isa<hlfir::DeclareOp>(user))
+                user->replaceAllUsesWith(outerDeclare);
+
+          // All remaining uses of the entry block argument are replaced with
+          // the var_ptr initialization value.
+          blockArg.replaceAllUsesWith(varPtr);
+        }
+      }
+      return WalkResult::advance();
+    });
+
+    // Make a temporary clone of the function with an empty region, and update
+    // all references to entry block arguments to those of the new region.
+    // Users will later either be moved to the new region or deleted when the
+    // original region is replaced by the new one.
+    OpBuilder builder(&getContext());
+    builder.setInsertionPointAfter(funcOp);
+    Operation *newOp = builder.cloneWithoutRegions(funcOp);
+    Block &block = newOp->getRegion(0).emplaceBlock();
+
+    llvm::SmallVector<Location> locs;
+    locs.reserve(region.getNumArguments());
+    llvm::transform(region.getArguments(), std::back_inserter(locs),
+                    [](const BlockArgument &arg) { return arg.getLoc(); });
+    block.addArguments(region.getArgumentTypes(), locs);
+
+    for (auto [oldArg, newArg] :
+         llvm::zip_equal(region.getArguments(), block.getArguments()))
+      oldArg.replaceAllUsesWith(newArg);
+
+    // Collect omp.map.info ops while satisfying interdependencies and remove
+    // operands that aren't used by target device codegen.
+    //
+    // This logic must be updated whenever operands to omp.target change.
+    llvm::SetVector<Value> rewriteValues;
+    llvm::SetVector<omp::MapInfoOp> mapInfos;
+    for (omp::TargetOp targetOp : targetOps) {
+      assert(targetOp.getHostEvalVars().empty() &&
+             "unexpected host_eval in target device module");
+
+      // Variables unused by the device.
+      targetOp.getDependVarsMutable().clear();
+      targetOp.setDependKindsAttr(nullptr);
+      targetOp.getDeviceMutable().clear();
+      targetOp.getIfExprMutable().clear();
+
+      // TODO: Clear some of these operands rather than rewriting them,
+      // depending on whether they are needed by device codegen once support for
+      // them is fully implemented.
+      for (Value allocVar : targetOp.getAllocateVars())
+        collectRewrite(allocVar, rewriteValues);
+      for (Value allocVar : targetOp.getAllocatorVars())
+        collectRewrite(allocVar, rewriteValues);
+      for (Value inReduction : targetOp.getInReductionVars())
+        collectRewrite(inReduction, rewriteValues);
+      for (Value isDevPtr : targetOp.getIsDevicePtrVars())
+        collectRewrite(isDevPtr, rewriteValues);
+      for (Value mapVar : targetOp.getHasDeviceAddrVars())
+        collectRewrite(cast<omp::MapInfoOp>(mapVar.getDefiningOp()), mapInfos);
+      for (Value mapVar : targetOp.getMapVars())
+        collectRewrite(cast<omp::MapInfoOp>(mapVar.getDefiningOp()), mapInfos);
+      for (Value privateVar : targetOp.getPrivateVars())
+        collectRewrite(privateVar, rewriteValues);
+      if (Value threadLimit = targetOp.getThreadLimit())
+        collectRewrite(threadLimit, rewriteValues);
+    }
+
+    // Move omp.map.info ops to the new block and collect dependencies.
+    llvm::SetVector<hlfir::DeclareOp> declareOps;
+    llvm::SetVector<fir::BoxOffsetOp> boxOffsets;
+    for (omp::MapInfoOp mapOp : mapInfos) {
+      if (auto declareOp = dyn_cast_if_present<hlfir::DeclareOp>(
+              mapOp.getVarPtr().getDefiningOp()))
+        collectRewrite(declareOp, declareOps);
+      else
+        collectRewrite(mapOp.getVarPtr(), rewriteValues);
+
+      if (Value varPtrPtr = mapOp.getVarPtrPtr()) {
+        if (auto boxOffset = llvm::dyn_cast_if_present<fir::BoxOffsetOp>(
+                varPtrPtr.getDefiningOp()))
+          collectRewrite(boxOffset, boxOffsets);
+        else
+          return mapOp->emitOpError() << "var_ptr_ptr rewrite only supported "
+                                         "if defined by fir.box_offset";
+      }
+
+      // Bounds are not used during target device codegen.
+      mapOp.getBoundsMutable().clear();
+      mapOp->moveBefore(&block, block.end());
+    }
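+
+    // At this point, the new block contains the collected omp.map.info ops in
+    // an order that satisfies their interdependencies; the operations defining
+    // their remaining operands are recreated or rewritten below.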
+
+    // Create a temporary marker to simplify the op moving process below.
+    builder.setInsertionPointToStart(&block);
+    auto marker = builder.create<fir::UndefOp>(builder.getUnknownLoc(),
+                                               builder.getNoneType());
+    builder.setInsertionPoint(marker);
+
+    // Handle dependencies of hlfir.declare ops.
+    for (hlfir::DeclareOp declareOp : declareOps) {
+      collectRewrite(declareOp.getMemref(), rewriteValues);
+
+      // Shape and typeparams aren't needed for target device codegen, but
+      // removing them would break verifiers.
+      Value zero;
+      if (declareOp.getShape() || !declareOp.getTypeparams().empty())
+        zero = builder.create<arith::ConstantOp>(declareOp.getLoc(),
+                                                 builder.getI64IntegerAttr(0));
+
+      if (auto shape = declareOp.getShape()) {
+        // The pre-cg rewrite pass requires the shape to be defined by one of
+        // fir.shape, fir.shapeshift or fir.shift, so we need to make sure it's
+        // still defined by one of these after this pass.
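+        //
+        // Illustrative sketch: a shape such as
+        //   %s = fir.shape %ext0, %ext1
+        // is re-created with the same number of operands as
+        //   %s = fir.shape %zero, %zero
+        // since only the op kind and rank need to be preserved.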
+        Operation *shapeOp = shape.getDefiningOp();
+        llvm::SmallVector<Value> extents(shapeOp->getNumOperands(), zero);
+        Value newShape =
+            llvm::TypeSwitch<Operation *, Value>(shapeOp)
+                .Case([&](fir::ShapeOp op) {
+                  return builder.create<fir::ShapeOp>(op.getLoc(), extents);
+                })
+                .Case([&](fir::ShapeShiftOp op) {
+                  auto type = fir::ShapeShiftType::get(op.getContext(),
+                                                       extents.size() / 2);
+                  return builder.create<fir::ShapeShiftOp>(op.getLoc(), type,
+                                                           extents);
+                })
+                .Case([&](fir::ShiftOp op) {
+                  auto type =
+                      fir::ShiftType::get(op.getContext(), extents.size());
+                  return builder.create<fir::ShiftOp>(op.getLoc(), type,
+                                                      extents);
+                })
+                .Default([](Operation *op) {
+                  op->emitOpError()
+                      << "hlfir.declare shape expected to be one of: "
+                         "fir.shape, fir.shapeshift or fir.shift";
+                  return nullptr;
+                });
+
+        if (!newShape)
+          return failure();
+
+        declareOp.getShapeMutable().assign(newShape);
+      }
+
+      for (OpOperand &typeParam : declareOp.getTypeparamsMutable())
+        typeParam.assign(zero);
+
+      declareOp.getDummyScopeMutable().clear();
+    }
+
+    // We don't actually need these operands to be properly initialized, but
+    // only to keep their basic form. We create 1-bit placeholder allocas that
+    // we "typecast" to the expected type and use to replace all uses of the
+    // original value. Using fir.undefined here instead is not possible because
+    // these variables cannot be constants, as that would trigger different
+    // codegen for target regions.
+    for (Value value : rewriteValues) {
+      Location loc = value.getLoc();
+      Value rewriteValue;
+      // If it's defined by fir.address_of, then we need to keep that op as
+      // well because it might be pointing to a 'declare target' global.
+      // Constants can also trigger different codegen paths, so we keep them as
+      // well.
+      if (isa_and_present<arith::ConstantOp, fir::AddrOfOp>(
+              value.getDefiningOp())) {
+        rewriteValue = builder.clone(*value.getDefiningOp())->getResult(0);
+      } else if (auto boxCharType =
+                     dyn_cast<fir::BoxCharType>(value.getType())) {
+        // !fir.boxchar types cannot be directly obtained by converting a
+        // !fir.ref<i1>, as they aren't reference types. Since they can appear
+        // as variables of some `target firstprivate` clauses, we special-case
+        // them here by creating a placeholder fir.emboxchar op.
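+        //
+        // Illustrative result for a boxchar of kind 1:
+        //   %ch = fir.alloca !fir.char<1>
+        //   %c1 = arith.constant 1 : i32
+        //   %bc = fir.emboxchar %ch, %c1
+        //       : (!fir.ref<!fir.char<1>>, i32) -> !fir.boxchar<1>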
+        MLIRContext *ctx = &getContext();
+        fir::KindTy kind = boxCharType.getKind();
+        auto placeholder = builder.create<fir::AllocaOp>(
+            loc, fir::CharacterType::getSingleton(ctx, kind));
+        auto one = builder.create<arith::ConstantOp>(
+            loc, builder.getI32Type(), builder.getI32IntegerAttr(1));
+        rewriteValue = builder.create<fir::EmboxCharOp>(loc, boxCharType,
+                                                        placeholder, one);
+      } else {
+        Value placeholder =
----------------
bhandarkar-pranav wrote:

nit: why not simply create a `fir.alloca` of `value.getType()` directly, since it is all going to be discarded anyway for device codegen?

https://github.com/llvm/llvm-project/pull/137200

