[flang-commits] [flang] [flang][HLFIR] Relax InlineElementals to support more than two users (PR #186916)

Tom Eccles via flang-commits flang-commits at lists.llvm.org
Fri May 8 09:22:38 PDT 2026


================
@@ -31,29 +33,265 @@ namespace hlfir {
 #include "flang/Optimizer/HLFIR/Passes.h.inc"
 } // namespace hlfir
 
+/// Collects the memory values (buffers/references) that the body of the
+/// elemental may read from and appends them to \p deps.
+///
+/// Operations implementing MemoryEffectOpInterface contribute the values
+/// attached to their Read effects. Operations without the interface, or whose
+/// Read effect is not attached to a value, conservatively contribute every
+/// reference-like or box-like operand defined outside the elemental region.
+///
+/// NOTE(review): memory that is read without being reachable through an
+/// operand (e.g. a global accessed by a callee) is not represented in
+/// \p deps -- confirm that callers account for this.
+///
+/// \param elemental the hlfir.elemental whose region is scanned.
+/// \param deps output list; values are appended, duplicates are not removed.
+static void getReadDependencies(hlfir::ElementalOp elemental,
+                                llvm::SmallVectorImpl<mlir::Value> &deps) {
+  mlir::Region &body = elemental.getRegion();
+  body.walk([&](mlir::Operation *op) {
+    // Prefer the precise information from the memory effect interface.
+    if (auto memInterface = mlir::dyn_cast<mlir::MemoryEffectOpInterface>(op)) {
+      llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 4> effects;
+      memInterface.getEffects(effects);
+      bool hasUnspecifiedRead = false;
+      for (const auto &effect : effects) {
+        if (!mlir::isa<mlir::MemoryEffects::Read>(effect.getEffect()))
+          continue;
+        if (mlir::Value val = effect.getValue())
+          deps.push_back(val);
+        else
+          // Read effect on an unspecified resource (e.g., global state).
+          hasUnspecifiedRead = true;
+      }
+      // Every read is attached to a value: nothing else to record. Otherwise
+      // fall through and conservatively capture the external operands.
+      if (!hasUnspecifiedRead)
+        return;
+    }
+
+    // Fail-safe fallback: for operations without the interface or with
+    // unspecified reads, capture any memory-like operand that is not defined
+    // directly in the elemental region.
+    for (mlir::Value operand : op->getOperands()) {
+      if (operand.getParentRegion() == &body)
+        continue;
+      // fir::BaseBoxType matches both fir.box and fir.class, so polymorphic
+      // boxes are tracked as dependencies as well.
+      if (mlir::isa<fir::ReferenceType, fir::PointerType, fir::HeapType,
+                    fir::BaseBoxType>(operand.getType()))
+        deps.push_back(operand);
+    }
+  });
+}
+
+/// Returns true if \p op, or any operation nested inside it, may write to or
+/// free memory that aliases one of the values in \p deps (the locations the
+/// elemental reads from).
+///
+/// \param op operation to inspect; nested regions are visited as well.
+/// \param deps read dependencies of the elemental.
+/// \param aa alias analysis used to compare written values against \p deps.
+static bool isConflictingWrite(mlir::Operation *op,
+                               const llvm::SmallVectorImpl<mlir::Value> &deps,
+                               mlir::AliasAnalysis &aa) {
+  // Decides whether a single memory effect clashes with the read set.
+  auto clobbersDeps = [&](const mlir::MemoryEffects::EffectInstance &effect) {
+    // Only effects that modify memory or release resources matter.
+    if (!mlir::isa<mlir::MemoryEffects::Write, mlir::MemoryEffects::Free>(
+            effect.getEffect()))
+      return false;
+
+    // Fail-safe: an effect on an unknown resource is assumed to conflict.
+    mlir::Value target = effect.getValue();
+    if (!target)
+      return true;
+
+    // Anything that is not provably disjoint from a dependency conflicts.
+    return llvm::any_of(
+        deps, [&](mlir::Value dep) { return !aa.alias(target, dep).isNo(); });
+  };
+
+  // Walk so that nested regions (fir.if, fir.do_loop, etc.) are covered too.
+  mlir::WalkResult walkStatus = op->walk([&](mlir::Operation *inner) {
+    // Operations without memory effects, and the HLFIR/FIR metadata and
+    // lifetime operations listed here, are treated as safe.
+    const bool knownHarmless =
+        mlir::isMemoryEffectFree(inner) ||
+        mlir::isa<hlfir::DeclareOp, hlfir::AssociateOp, hlfir::EndAssociateOp,
+                  fir::AllocaOp, hlfir::NoReassocOp>(inner);
+    if (knownHarmless)
+      return mlir::WalkResult::advance();
+
+    auto effectIface = mlir::dyn_cast<mlir::MemoryEffectOpInterface>(inner);
+    if (!effectIface) {
+      // Conservative fallback: an operation that does not implement the
+      // interface and has no regions (e.g. a fir.call) may modify anything.
+      // Operations with regions are covered by visiting their nested ops.
+      if (inner->getNumRegions() == 0)
+        return mlir::WalkResult::interrupt();
+      return mlir::WalkResult::advance();
+    }
+
+    // Inspect the explicit memory effects reported by the interface.
+    llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 4> effects;
+    effectIface.getEffects(effects);
+    if (llvm::any_of(effects, clobbersDeps))
+      return mlir::WalkResult::interrupt();
+    return mlir::WalkResult::advance();
+  });
+
+  // An interrupted walk means a conflicting operation was found.
+  return walkStatus.wasInterrupted();
+}
+
+/// Returns true when \p producer can be inlined at \p applySite: the producer
+/// must properly dominate the apply, and no operation executing between the
+/// two may write to (or free) memory that the elemental body reads.
+///
+/// The answer is conservative: false is returned whenever safety cannot be
+/// established, including when the producer is not nested in a func.func.
+///
+/// \param producer the hlfir.elemental considered for inlining.
+/// \param applySite the hlfir.apply that would receive the inlined body.
+/// \param aa alias analysis used to compare writes against the read set.
+bool isSafeToInline(hlfir::ElementalOp producer, hlfir::ApplyOp applySite,
+                    mlir::AliasAnalysis &aa) {
+  mlir::DominanceInfo domInfo(producer->getParentOp());
+  if (!domInfo.properlyDominates(producer.getOperation(),
+                                 applySite.getOperation()))
+    return false;
+
+  // The scan below is scoped to the enclosing func.func. getParentOfType
+  // returns a null op when there is none (the elemental may be nested in some
+  // other top-level operation); bail out instead of dereferencing null.
+  mlir::Operation *func = producer->getParentOfType<mlir::func::FuncOp>();
+  if (!func)
+    return false;
+
+  llvm::SmallVector<mlir::Value> deps;
+  getReadDependencies(producer, deps);
+
+  mlir::WalkResult scan = func->walk([&](mlir::Operation *op) {
+    // Skip the producer and applySite themselves.
+    if (op == producer.getOperation() || op == applySite.getOperation())
+      return mlir::WalkResult::advance();
+
+    // Skip operations that contain the applySite: only operations that
+    // execute before the applySite starts, or between the producer and the
+    // start of the enclosing loop, are of interest.
+    if (op->isAncestor(applySite.getOperation()))
+      return mlir::WalkResult::advance();
+
+    // Only check operations that execute between definition and use.
+    const bool betweenDefAndUse =
+        domInfo.properlyDominates(producer.getOperation(), op) &&
+        domInfo.dominates(op, applySite.getOperation());
+    if (betweenDefAndUse && isConflictingWrite(op, deps, aa))
+      return mlir::WalkResult::interrupt();
+
+    return mlir::WalkResult::advance();
+  });
+
+  // The walk is interrupted exactly when a conflicting write was found.
+  return !scan.wasInterrupted();
+}
+
 /// If the elemental has only two uses and those two are an apply operation and
 /// a destroy operation, return those two, otherwise return {}
 static std::optional<std::pair<hlfir::ApplyOp, hlfir::DestroyOp>>
-getTwoUses(hlfir::ElementalOp elemental) {
-  mlir::Operation::user_range users = elemental->getUsers();
-  // don't inline anything with more than one use (plus hfir.destroy)
-  if (std::distance(users.begin(), users.end()) != 2) {
-    return std::nullopt;
-  }
-
+getTwoUses(hlfir::ElementalOp elemental, mlir::AliasAnalysis &aliasAnalysis) {
   // If the ElementalOp must produce a temporary (e.g. for
   // finalization purposes), then we cannot inline it.
   if (hlfir::elementalOpMustProduceTemp(elemental))
     return std::nullopt;
 
   hlfir::ApplyOp apply;
   hlfir::DestroyOp destroy;
-  for (mlir::Operation *user : users)
-    mlir::TypeSwitch<mlir::Operation *, void>(user)
-        .Case([&](hlfir::ApplyOp op) { apply = op; })
-        .Case([&](hlfir::DestroyOp op) { destroy = op; });
+  unsigned applyCount = 0;
+  bool hasOtherUsers = false;
+
+  llvm::SmallVector<mlir::Value> worklist;
+  worklist.push_back(elemental.getResult());
+  llvm::SmallPtrSet<mlir::Value, 16> visited;
+  llvm::SmallPtrSet<mlir::Operation *, 4> uniqueApplies;
+
+  while (!worklist.empty()) {
+    mlir::Value current = worklist.pop_back_val();
+    if (!current || !visited.insert(current).second)
+      continue;
 
-  if (!apply || !destroy)
+    for (mlir::OpOperand &use : current.getUses()) {
+      mlir::Operation *user = use.getOwner();
+
+      mlir::TypeSwitch<mlir::Operation *, void>(user)
+          .Case<hlfir::ApplyOp>([&](hlfir::ApplyOp op) {
+            // Use raw operation pointer to ensure each apply site is
+            // counted only once.
+            if (uniqueApplies.insert(op.getOperation()).second) {
+              apply = op;
+              applyCount++;
+            }
+          })
+          .Case<hlfir::DestroyOp>([&](hlfir::DestroyOp op) {
+            // Track the mandatory destroy operation for the elemental expr.
+            destroy = op;
+          })
+          .Case<hlfir::DeclareOp, fir::ConvertOp>([&](mlir::Operation *op) {
+            // Follow the dataflow through all results of the operation.
+            // For hlfir.declare, this catches both the variable and base
+            // results. For fir.convert, this catches the converted result.
+            for (mlir::Value result : op->getResults()) {
+              worklist.push_back(result);
+            }
+          })
+          // Buffer Consumers - These require the destroy to stay.
+          .Case<hlfir::AssociateOp, hlfir::SumOp, hlfir::AssignOp, fir::CallOp>(
+              [&](mlir::Operation *) { hasOtherUsers = true; })
+          .Case<mlir::BranchOpInterface>([&](mlir::BranchOpInterface branch) {
+            for (unsigned i = 0; i < branch->getNumSuccessors(); ++i) {
+              mlir::SuccessorOperands operands = branch.getSuccessorOperands(i);
+              for (unsigned j = 0; j < operands.size(); ++j) {
+                if (operands[j] == current) {
+                  // The j-th operand of the branch maps to the j-th block
+                  // argument of the successor block.
+                  mlir::Block *successor = branch->getSuccessor(i);
+                  worklist.push_back(successor->getArgument(j));
+                }
+              }
+            }
+          })
+          .Case<fir::ResultOp>([&](fir::ResultOp op) {
+            mlir::Operation *parent = op->getParentOp();
+            // Only forward if the parent is an op that yields values out.
+            if (parent &&
+                mlir::isa<mlir::RegionBranchOpInterface, fir::IfOp,
+                          fir::DoLoopOp, hlfir::ElementalOp>(parent)) {
+              for (auto it : llvm::enumerate(op.getOperands())) {
+                if (it.value() == current) {
+                  // Map the result index to the parent's result index.
+                  unsigned i = it.index();
+                  if (i < parent->getNumResults()) {
+                    worklist.push_back(parent->getResult(i));
+                  }
+                }
+              }
+            } else {
+              // If it's a terminator for an unknown op.
+              hasOtherUsers = true;
+            }
+          })
+          .Default([&](mlir::Operation *op) {
+            if (op->getNumRegions() > 0) {
+              // Follow the value through metadata ops (declare, convert, etc.)
+              // nested inside regions.
+              op->walk([&](mlir::Operation *innerOp) {
+                for (mlir::Value operand : innerOp->getOperands()) {
+                  if (operand == current) {
+                    if (auto nestedApply =
+                            mlir::dyn_cast<hlfir::ApplyOp>(innerOp)) {
+                      // Use a set to prevent double-counting if walker
+                      // and worklist hit the same apply site.
+                      if (uniqueApplies.insert(nestedApply.getOperation())
+                              .second) {
+                        apply = nestedApply;
+                        applyCount++;
+                      }
+                    } else if (mlir::isa<hlfir::DeclareOp, fir::ConvertOp>(
+                                   innerOp)) {
+                      // Feed internal metadata results back into the worklist.
+                      for (mlir::Value res : innerOp->getResults())
+                        worklist.push_back(res);
+                    } else if (!mlir::isa<hlfir::DestroyOp, fir::ResultOp,
+                                          mlir::BranchOpInterface>(innerOp)) {
+                      // If it's an intrinsic or unknown consumer, it needs the
+                      // buffer.
+                      hasOtherUsers = true;
+                    }
+                  }
+                }
+              });
+            } else {
+              // Non-region op not handled by specific Case<> (e.g. hlfir.sum)
+              hasOtherUsers = true;
+            }
+          });
+    }
+  }
+
+  // Only inline if there is a unique 'apply' site. Other users (such as
+  // intrinsic operations) are allowed because scalarizing the elemental
+  // renders the original array result redundant.
+  if (applyCount != 1 || !destroy)
----------------
tblah wrote:

Could we terminate the walk early when `applyCount > 1` to save on compile time?

https://github.com/llvm/llvm-project/pull/186916


More information about the flang-commits mailing list