[flang-commits] [flang] [flang][HLFIR] Relax InlineElementals to support more than two users (PR #186916)

via flang-commits flang-commits at lists.llvm.org
Wed Apr 22 21:51:58 PDT 2026


================
@@ -31,29 +33,265 @@ namespace hlfir {
 #include "flang/Optimizer/HLFIR/Passes.h.inc"
 } // namespace hlfir
 
+/// Gathers into `deps` the values the elemental body may read memory through.
+/// Uses MemoryEffectOpInterface when available, with an operand-based fallback.
+static void getReadDependencies(hlfir::ElementalOp elemental,
+                                llvm::SmallVectorImpl<mlir::Value> &deps) {
+  elemental.getRegion().walk([&](mlir::Operation *op) {
+    // Prefer the precise effects reported by the op itself, when it has any.
+    if (auto memInterface = mlir::dyn_cast<mlir::MemoryEffectOpInterface>(op)) {
+      llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 4> effects;
+      memInterface.getEffects(effects);
+      bool hasUnspecifiedRead = false;
+      for (const auto &effect : effects) {
+        if (mlir::isa<mlir::MemoryEffects::Read>(effect.getEffect())) {
+          if (mlir::Value val = effect.getValue()) {
+            deps.push_back(val);
+          } else {
+            // Read effect not attached to any value (e.g., global state).
+            hasUnspecifiedRead = true;
+          }
+        }
+      }
+      // Every read (if any) named its value, so this op is fully handled.
+      // Otherwise fall through and conservatively capture its operands.
+      if (!hasUnspecifiedRead)
+        return;
+    }
+
+    // Fallback for ops without the interface or with an unspecified read:
+    // record each reference/pointer/heap/box operand not defined directly in
+    for (mlir::Value operand : op->getOperands()) { // the elemental's region.
+      if (operand.getParentRegion() != &elemental.getRegion()) {
+        if (mlir::isa<fir::ReferenceType, fir::PointerType, fir::HeapType,
+                      fir::BoxType>(operand.getType())) {
+          deps.push_back(operand);
+        }
+      }
+    }
+  });
+}
+
+/// Returns true if 'op', or any operation nested inside it, may write to or
+/// free memory that aliases one of the elemental's read dependencies 'deps'.
+static bool isConflictingWrite(mlir::Operation *op,
+                               const llvm::SmallVectorImpl<mlir::Value> &deps,
+                               mlir::AliasAnalysis &aa) {
+  // Walk 'op' and everything nested in its regions (fir.if, fir.do_loop, etc.).
+  mlir::WalkResult result = op->walk([&](mlir::Operation *nestedOp) {
+    // Operations known to have no memory effects cannot conflict.
+    if (mlir::isMemoryEffectFree(nestedOp))
+      return mlir::WalkResult::advance();
+
+    // These HLFIR/FIR ops are treated as safe without inspecting effects.
+    if (mlir::isa<hlfir::DeclareOp, hlfir::AssociateOp, hlfir::EndAssociateOp,
+                  fir::AllocaOp, hlfir::NoReassocOp>(nestedOp))
+      return mlir::WalkResult::advance();
+
+    // Otherwise inspect the effects the op reports through the interface.
+    if (auto memInterface =
+            mlir::dyn_cast<mlir::MemoryEffectOpInterface>(nestedOp)) {
+      llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 4> effects;
+      memInterface.getEffects(effects);
+
+      for (const auto &effect : effects) {
+        // Only Write and Free effects can invalidate what the elemental reads.
+        if (mlir::isa<mlir::MemoryEffects::Write, mlir::MemoryEffects::Free>(
+                effect.getEffect())) {
+          mlir::Value accessedValue = effect.getValue();
+          // Fail-safe: an effect with no associated value could touch
+          // anything, so report a conflict.
+          if (!accessedValue)
+            return mlir::WalkResult::interrupt();
+
+          // Conflict unless alias analysis proves NoAlias with every dep.
+          for (mlir::Value dep : deps) {
+            if (!aa.alias(accessedValue, dep).isNo())
+              return mlir::WalkResult::interrupt();
+          }
+        }
+      }
+    } else if (nestedOp->getNumRegions() == 0) {
+      // Conservative fallback: an operation without the interface and
+      // without regions (e.g. a fir.call) is assumed to modify anything.
+      return mlir::WalkResult::interrupt();
+    }
+
+    return mlir::WalkResult::advance();
+  });
+
+  // The walk is interrupted exactly when a potential conflict was found.
+  return result.wasInterrupted();
+}
+
+/// Returns true when the body of 'producer' may be inlined at 'applySite'.
+///
+/// Two conditions must hold:
+///  - 'producer' properly dominates 'applySite';
+///  - no operation that is properly dominated by 'producer' and dominates
+///    'applySite' may write to, or free, memory that the elemental body reads
+///    (see getReadDependencies / isConflictingWrite).
+///
+/// Conservatively returns false when 'producer' is not nested in a
+/// func.func, because the operations to scan are looked up through the
+/// enclosing function.
+bool isSafeToInline(hlfir::ElementalOp producer, hlfir::ApplyOp applySite,
+                    mlir::AliasAnalysis &aa) {
+  mlir::DominanceInfo domInfo(producer->getParentOp());
+  if (!domInfo.properlyDominates(producer.getOperation(),
+                                 applySite.getOperation()))
+    return false;
+
+  // getParentOfType yields a null op when there is no enclosing func.func;
+  // walking it would dereference a null pointer, so refuse to inline instead.
+  mlir::Operation *func = producer->getParentOfType<mlir::func::FuncOp>();
+  if (!func)
+    return false;
+
+  llvm::SmallVector<mlir::Value> deps;
+  getReadDependencies(producer, deps);
+
+  mlir::WalkResult result = func->walk([&](mlir::Operation *op) {
+    // Skip the producer and applySite themselves.
+    if (op == producer.getOperation() || op == applySite.getOperation())
+      return mlir::WalkResult::advance();
+
+    // Skip any operation that contains the applySite: only operations that
+    // complete before the applySite starts are of interest here.
+    if (op->isAncestor(applySite.getOperation()))
+      return mlir::WalkResult::advance();
+
+    // Only check operations positioned between the definition and the use.
+    if (domInfo.properlyDominates(producer.getOperation(), op) &&
+        domInfo.dominates(op, applySite.getOperation()) &&
+        isConflictingWrite(op, deps, aa))
+      return mlir::WalkResult::interrupt();
+
+    return mlir::WalkResult::advance();
+  });
+
+  // The walk is interrupted exactly when a conflicting write was found.
+  return !result.wasInterrupted();
+}
+
 /// If the elemental has only two uses and those two are an apply operation and
 /// a destroy operation, return those two, otherwise return {}
 static std::optional<std::pair<hlfir::ApplyOp, hlfir::DestroyOp>>
-getTwoUses(hlfir::ElementalOp elemental) {
-  mlir::Operation::user_range users = elemental->getUsers();
-  // don't inline anything with more than one use (plus hfir.destroy)
-  if (std::distance(users.begin(), users.end()) != 2) {
-    return std::nullopt;
-  }
-
+getTwoUses(hlfir::ElementalOp elemental, mlir::AliasAnalysis &aliasAnalysis) {
   // If the ElementalOp must produce a temporary (e.g. for
   // finalization purposes), then we cannot inline it.
   if (hlfir::elementalOpMustProduceTemp(elemental))
     return std::nullopt;
 
   hlfir::ApplyOp apply;
   hlfir::DestroyOp destroy;
-  for (mlir::Operation *user : users)
-    mlir::TypeSwitch<mlir::Operation *, void>(user)
-        .Case([&](hlfir::ApplyOp op) { apply = op; })
-        .Case([&](hlfir::DestroyOp op) { destroy = op; });
+  unsigned applyCount = 0;
+  bool hasOtherUsers = false;
+
+  llvm::SmallVector<mlir::Value> worklist;
+  worklist.push_back(elemental.getResult());
+  llvm::SmallPtrSet<mlir::Value, 16> visited;
+  llvm::SmallPtrSet<mlir::Operation *, 4> uniqueApplies;
+
+  while (!worklist.empty()) {
+    mlir::Value current = worklist.pop_back_val();
+    if (!current || !visited.insert(current).second)
+      continue;
 
-  if (!apply || !destroy)
+    for (mlir::OpOperand &use : current.getUses()) {
+      mlir::Operation *user = use.getOwner();
+
+      mlir::TypeSwitch<mlir::Operation *, void>(user)
+          .Case<hlfir::ApplyOp>([&](hlfir::ApplyOp op) {
+            // Use raw operation pointer to ensure each apply site is
+            // counted only once.
+            if (uniqueApplies.insert(op.getOperation()).second) {
+              apply = op;
+              applyCount++;
+            }
+          })
+          .Case<hlfir::DestroyOp>([&](hlfir::DestroyOp op) {
+            // Track the mandatory destroy operation for the elemental expr.
+            destroy = op;
+          })
+          .Case<hlfir::DeclareOp, fir::ConvertOp>([&](mlir::Operation *op) {
+            // Follow the dataflow through all results of the operation.
+            // For hlfir.declare, this catches both the variable and base
+            // results. For fir.convert, this catches the converted result.
+            for (mlir::Value result : op->getResults()) {
+              worklist.push_back(result);
+            }
+          })
+          // Buffer Consumers - These require the destroy to stay.
+          .Case<hlfir::AssociateOp, hlfir::SumOp, hlfir::AssignOp, fir::CallOp>(
+              [&](mlir::Operation *) { hasOtherUsers = true; })
+          .Case<mlir::BranchOpInterface>([&](mlir::BranchOpInterface branch) {
+            for (unsigned i = 0; i < branch->getNumSuccessors(); ++i) {
+              mlir::SuccessorOperands operands = branch.getSuccessorOperands(i);
+              for (unsigned j = 0; j < operands.size(); ++j) {
+                if (operands[j] == current) {
+                  // The j-th operand of the branch maps to the j-th block
+                  // argument of the successor block.
+                  mlir::Block *successor = branch->getSuccessor(i);
+                  worklist.push_back(successor->getArgument(j));
+                }
+              }
+            }
+          })
+          .Case<fir::ResultOp>([&](fir::ResultOp op) {
+            mlir::Operation *parent = op->getParentOp();
+            // Only forward if the parent is an op that yields values out.
+            if (parent &&
+                mlir::isa<mlir::RegionBranchOpInterface, fir::IfOp,
+                          fir::DoLoopOp, hlfir::ElementalOp>(parent)) {
+              for (auto it : llvm::enumerate(op.getOperands())) {
+                if (it.value() == current) {
+                  // Map the result index to the parent's result index.
+                  unsigned i = it.index();
+                  if (i < parent->getNumResults()) {
+                    worklist.push_back(parent->getResult(i));
+                  }
+                }
+              }
+            } else {
+              // If it's a terminator for an unknown op.
+              hasOtherUsers = true;
+            }
+          })
+          .Default([&](mlir::Operation *op) {
----------------
anoopkg6 wrote:

I’ve updated the Default case to walk nested regions and feed results of metadata operations (like hlfir.declare and fir.convert) back into the main worklist. This ensures we don't miss apply sites hidden behind declarations inside control flow.
Cases involving complex control-flow boundaries (cross-block branches and loop exits) are now successfully handled. The worklist now traverses these boundaries by mapping branch operands to successor block arguments and region-terminator operands back to parent results. This allows unique apply sites to be found even when they are separated from the producer by structured or unstructured control flow.

https://github.com/llvm/llvm-project/pull/186916


More information about the flang-commits mailing list