[flang-commits] [clang] [flang] [mlir] [mlir][Interfaces] Single interface method to query constant region-based CF (PR #193486)

Matthias Springer via flang-commits flang-commits at lists.llvm.org
Wed Apr 22 05:14:21 PDT 2026


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/193486

Query all control flow edges through `RegionBranchOpInterface`. The `RegionBranchTerminatorOpInterface::getSuccessorRegions` and `RegionBranchOpInterface::getEntrySuccessorRegions` interface methods were removed and merged into a new `RegionBranchOpInterface::getSuccessorRegionsWithConstants` interface method. This new method handles all region branch points: the parent op as well as region branch terminators.

In summary:
* Query region branch flow without constant information: `RegionBranchOpInterface::getSuccessorRegions` (no change)
* Query region branch flow with constant information: `RegionBranchOpInterface::getSuccessorRegionsWithConstants` (new)
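As a rough illustration of the two query paths, here is a self-contained C++17 model. `IfOpModel`, `OperandConstants`, `BranchPoint`, and the `std::optional<bool>` encoding of constant operands are hypothetical simplifications, not MLIR classes. They only mirror the shape of `getSuccessorRegions`, `getSuccessorRegionsWithConstants`, and `RegionBranchPointOperandConstants` as they appear in the patch. As in the patch's `scf.if` implementation, the constants-aware query refines successors only for the parent branch point and otherwise falls back to the constant-agnostic query.

```cpp
#include <cassert>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical, simplified stand-ins for MLIR types (NOT the real MLIR API).
// A constant operand is modeled as std::optional<bool>; std::nullopt plays the
// role of a null (non-constant) mlir::Attribute.
using Attr = std::optional<bool>;
struct Op { int id; };

// A branch point is either the parent op or a terminator in one of its regions.
struct BranchPoint {
  const Op *terminator = nullptr;  // nullptr models RegionBranchPoint::parent()
  bool isParent() const { return terminator == nullptr; }
};

// A successor is a region index, or kParent when control leaves the op.
using Successor = int;
constexpr Successor kParent = -1;

// Models RegionBranchPointOperandConstants. This sketch owns its vectors,
// whereas the class in the patch stores non-owning ArrayRefs.
class OperandConstants {
public:
  void setParent(std::vector<Attr> constants) { parent = std::move(constants); }
  void setTerminator(const Op *terminator, std::vector<Attr> constants) {
    terminators.emplace_back(terminator, std::move(constants));
  }
  // An empty result means "no constant information for this branch point".
  const std::vector<Attr> &get(BranchPoint point) const {
    if (point.isParent())
      return parent;
    for (const auto &entry : terminators)
      if (entry.first == point.terminator)
        return entry.second;
    return empty;
  }

private:
  std::vector<Attr> parent;
  std::vector<std::pair<const Op *, std::vector<Attr>>> terminators;
  std::vector<Attr> empty;
};

// Models an if-like op with two non-empty regions (0 = then, 1 = else).
struct IfOpModel {
  Op thenYield{1}, elseYield{2};

  // Constant-agnostic query (models getSuccessorRegions).
  std::vector<Successor> getSuccessorRegions(BranchPoint point) const {
    if (point.isParent())
      return {0, 1};
    return {kParent};  // both yields branch back to the parent op
  }

  // Constants-aware query (models getSuccessorRegionsWithConstants): refine
  // only when entering from the parent and the condition is a known constant;
  // otherwise fall back to the constant-agnostic query.
  std::vector<Successor>
  getSuccessorRegionsWithConstants(BranchPoint point,
                                   const OperandConstants &constants) const {
    const std::vector<Attr> &operands = constants.get(point);
    if (!point.isParent() || operands.empty() || !operands[0].has_value())
      return getSuccessorRegions(point);
    return {*operands[0] ? 0 : 1};
  }
};
```

In this model, calling `getSuccessorRegionsWithConstants(BranchPoint{}, c)` after `c.setParent({true})` yields only region `0`, while a default-constructed `OperandConstants` yields the full `{0, 1}`, matching the "no constants means same result as `getSuccessorRegions`" contract described in the patch.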

No change in functionality. This commit is just combining two interface methods.

Also fix two "incorrect" API implementations. (They are not technically incorrect, just overly complex. The original authors of these functions seem to have assumed that all control flow originating from a non-parent region branch point must be queried through the respective terminator op interface. That is **not** the case. The refactored API makes this explicit: the interface method in question has been removed from the terminator op interface.)

* `cir::ConditionOp::getSuccessorRegions` did not inspect `operands`, so it provided no functionality beyond the `getSuccessorRegions` implementation of the parent op. It was deleted as part of this commit.
* `traverseRegionGraph` used to call `RegionBranchTerminatorOpInterface::getSuccessorRegions` but passed only null (non-constant) attributes. It now calls `RegionBranchOpInterface::getSuccessorRegions` on the region branch op, with the terminator as the branch point.
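The constant-based refinement that used to live in `scf::ConditionOp::getSuccessorRegions` now lives in the parent op (`WhileOp::getSuccessorRegionsWithConstants` in the patch), and callers query the parent op with the terminator as the branch point. The standalone C++17 sketch here models that pattern. `WhileOpModel`, `OperandConstants`, `BranchPoint`, and the `std::optional<bool>` encoding of constants are hypothetical simplifications, not MLIR classes; the successor sets loosely follow `scf.while` (parent to "before", condition to "after" or parent, yield to "before").

```cpp
#include <cassert>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-ins (NOT the real MLIR API): std::nullopt models a null
// (non-constant) attribute; an Op is only used for its identity.
using Attr = std::optional<bool>;
struct Op { int id; };

struct BranchPoint {
  const Op *terminator = nullptr;  // nullptr models RegionBranchPoint::parent()
  bool isParent() const { return terminator == nullptr; }
};

// Successors: region indices of the while-like op, or kParent (leave the op).
using Successor = int;
constexpr Successor kParent = -1, kBefore = 0, kAfter = 1;

// Minimal model of RegionBranchPointOperandConstants.
class OperandConstants {
public:
  void setParent(std::vector<Attr> constants) { parent = std::move(constants); }
  void setTerminator(const Op *terminator, std::vector<Attr> constants) {
    terminators.emplace_back(terminator, std::move(constants));
  }
  const std::vector<Attr> &get(BranchPoint point) const {
    if (point.isParent())
      return parent;
    for (const auto &entry : terminators)
      if (entry.first == point.terminator)
        return entry.second;
    return empty;  // no constant information for this branch point
  }

private:
  std::vector<Attr> parent;
  std::vector<std::pair<const Op *, std::vector<Attr>>> terminators;
  std::vector<Attr> empty;
};

// Models a while-like op with a "before" region (ending in a condition op)
// and an "after" region (ending in a yield op).
struct WhileOpModel {
  Op conditionOp{1};  // models scf.condition; operand #0 is the condition
  Op yieldOp{2};      // models scf.yield at the end of the "after" region

  std::vector<Successor> getSuccessorRegions(BranchPoint point) const {
    if (point.isParent())
      return {kBefore};
    if (point.terminator == &conditionOp)
      return {kAfter, kParent};
    return {kBefore};
  }

  // The parent op (not the terminator) refines the successors of the
  // condition terminator when its condition operand is a known constant.
  std::vector<Successor>
  getSuccessorRegionsWithConstants(BranchPoint point,
                                   const OperandConstants &constants) const {
    const std::vector<Attr> &operands = constants.get(point);
    if (point.terminator != &conditionOp || operands.empty() ||
        !operands[0].has_value())
      return getSuccessorRegions(point);
    if (*operands[0])
      return {kAfter};
    return {kParent};
  }
};
```

The design point this illustrates is that a terminator never has to answer "where do I branch to?" itself: the parent op knows its own region structure, so it can interpret constants for any of its branch points, and a terminator without constant information simply gets the unfiltered result.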

Discussion: https://discourse.llvm.org/t/rfc-structural-and-reachable-control-flow-edges-in-regionbranchopinterface/90496/16


>From 4b7878706a0bda6d0a384f4f82c72bc466eeb5c1 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 22 Apr 2026 11:37:59 +0000
Subject: [PATCH] [mlir][Interfaces] Use single interface method to query
 region-based control flow

---
 clang/include/clang/CIR/Dialect/IR/CIROps.td  |  1 -
 clang/lib/CIR/Dialect/IR/CIRDialect.cpp       | 17 ----
 .../include/flang/Optimizer/Dialect/FIROps.td |  2 +-
 flang/lib/Optimizer/Dialect/FIROps.cpp        | 11 ++-
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   |  4 +-
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    |  9 ++-
 .../mlir/Interfaces/ControlFlowInterfaces.h   | 46 +++++++++++
 .../mlir/Interfaces/ControlFlowInterfaces.td  | 80 +++++++++----------
 .../Analysis/DataFlow/DeadCodeAnalysis.cpp    | 10 ++-
 mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp |  5 +-
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           | 21 ++++-
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 62 +++++++++-----
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      |  5 +-
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 38 ++++++---
 mlir/test/Transforms/sccp.mlir                |  5 +-
 15 files changed, 210 insertions(+), 106 deletions(-)

diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index f20ba262d6480..7ed36869afe85 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -908,7 +908,6 @@ def CIR_IfOp : CIR_Op<"if", [
 def CIR_ConditionOp : CIR_Op<"condition", [
   Terminator,
   DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface, [
-    "getSuccessorRegions",
     "getMutableSuccessorOperands"
   ]>
 ]> {
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 35a31b0dbda63..35d84f8255642 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -379,23 +379,6 @@ LogicalResult cir::BreakOp::verify() {
 // BranchOpTerminatorInterface Methods
 //===----------------------------------
 
-void cir::ConditionOp::getSuccessorRegions(
-    ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
-  // TODO(cir): The condition value may be folded to a constant, narrowing
-  // down its list of possible successors.
-
-  // Parent is a loop: condition may branch to the body or to the parent op.
-  if (auto loopOp = dyn_cast<LoopOpInterface>(getOperation()->getParentOp())) {
-    regions.emplace_back(&loopOp.getBody());
-    regions.push_back(RegionSuccessor::parent());
-  }
-
-  // Parent is an await: condition may branch to resume or suspend regions.
-  auto await = cast<AwaitOp>(getOperation()->getParentOp());
-  regions.emplace_back(&await.getResume());
-  regions.emplace_back(&await.getSuspend());
-}
-
 MutableOperandRange
 cir::ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
   // No values are yielded to the successor region.
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index eba4ddda4b6ad..0789c70fdd694 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2459,7 +2459,7 @@ def fir_IfOp
     : region_Op<
           "if", [DeclareOpInterfaceMethods<
                      RegionBranchOpInterface, ["getRegionInvocationBounds",
-                                               "getEntrySuccessorRegions",
+                                               "getSuccessorRegionsWithConstants",
                                                "getSuccessorInputs"]>,
                  RecursiveMemoryEffects, NoRegionArguments,
                  WeightedRegionBranchOpInterface]> {
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 4705033945611..453dbd8a280a2 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -5109,9 +5109,16 @@ fir::IfOp::getSuccessorInputs(mlir::RegionSuccessor successor) {
   return mlir::ValueRange();
 }
 
-void fir::IfOp::getEntrySuccessorRegions(
-    llvm::ArrayRef<mlir::Attribute> operands,
+void fir::IfOp::getSuccessorRegionsWithConstants(
+    mlir::RegionBranchPoint point,
+    const mlir::RegionBranchPointOperandConstants &operandConstants,
     llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
+  llvm::ArrayRef<mlir::Attribute> operands =
+      operandConstants.getOperandConstants(point);
+  if (!point.isParent() || operands.empty()) {
+    getSuccessorRegions(point, regions);
+    return;
+  }
   FoldAdaptor adaptor(operands);
   auto boolAttr =
       mlir::dyn_cast_or_null<mlir::BoolAttr>(adaptor.getCondition());
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 7d9bb8907eb8b..12cc1add40ee3 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1471,7 +1471,7 @@ def EmitC_YieldOp : EmitC_Op<"yield",
 def EmitC_IfOp : EmitC_Op<"if",
     [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
     "getNumRegionInvocations", "getRegionInvocationBounds",
-    "getEntrySuccessorRegions", "getSuccessorInputs"]>,
+    "getSuccessorRegionsWithConstants", "getSuccessorInputs"]>,
     OpAsmOpInterface, SingleBlock,
     SingleBlockImplicitTerminator<"emitc::YieldOp">,
     RecursiveMemoryEffects, NoRegionArguments]> {
@@ -1587,7 +1587,7 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
     OpAsmOpInterface, SingleBlockImplicitTerminator<"emitc::YieldOp">,
     DeclareOpInterfaceMethods<RegionBranchOpInterface,
                               ["getRegionInvocationBounds",
-                               "getEntrySuccessorRegions"]>]> {
+                               "getSuccessorRegionsWithConstants"]>]> {
   let summary = "Switch operation";
   let description = [{
     The `emitc.switch` is a control-flow operation that branches to one of
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 0b33ecb48b7f2..1afcdc4b51395 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -55,7 +55,7 @@ class SCF_Op<string mnemonic, list<Trait> traits = []> :
 def ConditionOp : SCF_Op<"condition", [
   HasParent<"WhileOp">,
   DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface,
-    ["getSuccessorRegions", "getMutableSuccessorOperands"]>,
+    ["getMutableSuccessorOperands"]>,
   Pure,
   Terminator
 ]> {
@@ -704,7 +704,7 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
 
 def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
     "getNumRegionInvocations", "getRegionInvocationBounds",
-    "getEntrySuccessorRegions", "getSuccessorInputs"]>,
+    "getSuccessorRegionsWithConstants", "getSuccessorInputs"]>,
     DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
     InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">,
     RecursiveMemoryEffects, RecursivelySpeculatable, NoRegionArguments]> {
@@ -991,7 +991,8 @@ def ReduceReturnOp :
 
 def WhileOp : SCF_Op<"while",
     [DeclareOpInterfaceMethods<RegionBranchOpInterface,
-        ["getEntrySuccessorOperands", "getSuccessorInputs"]>,
+        ["getEntrySuccessorOperands", "getSuccessorInputs",
+         "getSuccessorRegionsWithConstants"]>,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,
         ["getRegionIterArgs", "getYieldedValuesMutable"]>,
      DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
@@ -1147,7 +1148,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
     DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
     DeclareOpInterfaceMethods<RegionBranchOpInterface,
                               ["getRegionInvocationBounds",
-                               "getEntrySuccessorRegions",
+                               "getSuccessorRegionsWithConstants",
                                "getSuccessorInputs"]>]> {
   let summary = "switch-case operation on an index argument";
   let description = [{
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index a76dce6f2ffc5..ef818019350eb 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -270,6 +270,52 @@ class RegionBranchPoint {
   Operation *predecessor = nullptr;
 };
 
+/// A bundle of constant-operand information for a `RegionBranchOpInterface`
+/// implementation and any of the region branch terminators in its regions.
+///
+/// This is used by
+/// `RegionBranchOpInterface::getSuccessorRegionsWithConstants` to allow
+/// analyses/transformations to pass along constant operand values for a given
+/// branch point so that the implementation can refine the returned successors.
+/// Operand constants are optional: an implementation that doesn't recognize
+/// the provided constants (or for which no constants are provided) must
+/// return the same successors as the no-constants overload.
+///
+/// For each branch point (the parent op or a region branch terminator), the
+/// associated `ArrayRef<Attribute>` either has the same size as the number of
+/// operands of that op (with a null attribute for non-constant operands), or
+/// is empty (which means: no constant information is available for this
+/// branch point).
+class RegionBranchPointOperandConstants {
+public:
+  /// Default constructor: no constants known for any branch point.
+  RegionBranchPointOperandConstants() = default;
+
+  /// Set the constant operand information for the parent (region branch) op.
+  void setParentOperandConstants(ArrayRef<Attribute> constants) {
+    parentOperandConstants = constants;
+  }
+
+  /// Set the constant operand information for a specific region branch
+  /// terminator.
+  void setTerminatorOperandConstants(Operation *terminator,
+                                     ArrayRef<Attribute> constants) {
+    terminatorOperandConstants.emplace_back(terminator, constants);
+  }
+
+  /// Returns the constant operand information for the given branch point.
+  /// Returns an empty range if no constants were provided for this point.
+  ArrayRef<Attribute> getOperandConstants(RegionBranchPoint point) const;
+
+private:
+  /// Constant operand information for the parent op of the region branch.
+  ArrayRef<Attribute> parentOperandConstants;
+
+  /// Constant operand information for region branch terminators.
+  SmallVector<std::pair<Operation *, ArrayRef<Attribute>>, 1>
+      terminatorOperandConstants;
+};
+
 /// This class represents upper and lower bounds on the number of times a region
 /// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
 /// zero, but the upper bound may not be known.
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 06fa724e05fab..ead2a32139555 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -137,6 +137,13 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     specified with `getEntrySuccessorOperands` /
     `RegionBranchTerminatorOpInterface::getSuccessorOperands`.
 
+    Implementations may also optionally refine the set of successors based on
+    constant operand values. To do so, implement
+    `getSuccessorRegionsWithConstants`, which takes a
+    `RegionBranchPointOperandConstants` argument; that method is invoked by
+    analyses that have constant-folded information. By default, the
+    constants-aware overload dispatches to the no-constants overload.
+
     A "region successor" indicates the target of a branch. It can indicate:
     1. A region of this op.
     2. `RegionSuccessor::parent()`, i.e., the control flow leaves this op.
@@ -183,7 +190,7 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
         Returns the operands of this operation that are forwarded to the
         successor inputs when branching to `successor`. `successor` is
         guaranteed to be among the successors that are returned by
-        `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`.
+        `getSuccessorRegions(parent())`.
 
         Example: In the above example, this method returns the operand %b of the
         `scf.for` op, regardless of the value of `successor`. I.e., this op always
@@ -198,31 +205,39 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
       }]
     >,
     InterfaceMethod<[{
-        Returns the potential region successors when first executing the op.
+        Returns the potential region successors when branching from `point`.
+
+        This is a "constants-aware" variant of `getSuccessorRegions`. The
+        `operandConstants` parameter bundles up the constant values of the
+        operands of the parent (region branch) op and any of its region
+        branch terminators. Based on these, the implementation may filter out
+        certain successors that are statically known not to be taken.
 
-        Unlike `getSuccessorRegions`, this method also passes along the
-        constant operands of this op. Based on these, the implementation may
-        filter out certain successors. By default, simply dispatches to
-        `getSuccessorRegions`. `operands` contains an entry for every
-        operand of this op, with a null attribute if the operand has no constant
-        value.
+        By default, this method simply dispatches to `getSuccessorRegions`,
+        ignoring the constant operand information.
 
-        Note: The control flow does not necessarily have to enter any region of
-        this op.
+        Implementations should handle the case where no constant information
+        is available for a given branch point (i.e.,
+        `operandConstants.getOperandConstants(point)` returns an empty range)
+        by falling back to the no-constants overload.
 
         Example: In the above example, this method may return two region
-        region successors: the single region of the `scf.for` op and the
-        `scf.for` operation (that implements this interface). If %lb, %ub, %step
-        are constants and it can be determined the loop does not have any
-        iterations, this method may choose to return only this operation.
-        Similarly, if it can be determined that the loop has at least one
-        iteration, this method may choose to return only the region of the loop.
+        successors for the "parent" branch point: the single region
+        of the `scf.for` op and the `scf.for` operation (that implements this
+        interface). If %lb, %ub, %step are constants and it can be determined
+        the loop does not have any iterations, this method may choose to return
+        only this operation. Similarly, if it can be determined that the loop
+        has at least one iteration, this method may choose to return only the
+        region of the loop.
       }],
-      "void", "getEntrySuccessorRegions",
-      (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
-           "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}],
+      "void", "getSuccessorRegionsWithConstants",
+      (ins "::mlir::RegionBranchPoint":$point,
+           "const ::mlir::RegionBranchPointOperandConstants &"
+             :$operandConstants,
+           "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
+      [{}],
       /*defaultImplementation=*/[{
-        $_op.getSuccessorRegions(mlir::RegionBranchPoint::parent(), regions);
+        $_op.getSuccessorRegions(point, regions);
       }]
     >,
     InterfaceMethod<[{
@@ -423,6 +438,11 @@ def RegionBranchTerminatorOpInterface :
     (However, there may be other block terminators in the same region that
     implement the `RegionBranchTerminatorOpInterface`, so the enclosing region
     may have region successors.)
+
+    To query the region successors that may be branched to after a given
+    terminator, use `RegionBranchOpInterface::getSuccessorRegions(point, ...)` /
+    `getSuccessorRegionsWithConstants(point, ...)` on the parent op (passing the
+    terminator as the branch point).
   }];
   let cppNamespace = "::mlir";
 
@@ -438,26 +458,6 @@ def RegionBranchTerminatorOpInterface :
         return ::mlir::MutableOperandRange($_op);
       }]
     >,
-    InterfaceMethod<[{
-        Returns the potential region successors that are branched to after this
-        terminator based on the given constant operands.
-
-        This method also passes along the constant operands of this op.
-        `operands` contains an entry for every operand of this op, with a null
-        attribute if the operand has no constant value.
-
-        The default implementation simply dispatches to the parent
-        `RegionBranchOpInterface`'s `getSuccessorRegions` implementation.
-      }],
-      "void", "getSuccessorRegions",
-      (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
-           "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}],
-      /*defaultImplementation=*/[{
-        ::mlir::Operation *op = $_op;
-        ::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
-          .getSuccessorRegions(::llvm::cast<::mlir::RegionBranchTerminatorOpInterface>(op), regions);
-      }]
-    >,
   ];
 
   let verify = [{
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 38811d06ecd8c..8cf323d13589d 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -497,7 +497,10 @@ void DeadCodeAnalysis::visitRegionBranchOperation(
     return;
 
   SmallVector<RegionSuccessor> successors;
-  branch.getEntrySuccessorRegions(*operands, successors);
+  RegionBranchPointOperandConstants operandConstants;
+  operandConstants.setParentOperandConstants(*operands);
+  branch.getSuccessorRegionsWithConstants(RegionBranchPoint::parent(),
+                                          operandConstants, successors);
 
   visitRegionBranchEdges(branch, branch.getOperation(), successors);
 }
@@ -513,7 +516,10 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
   auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op);
   if (!terminator)
     return;
-  terminator.getSuccessorRegions(*operands, successors);
+  RegionBranchPointOperandConstants operandConstants;
+  operandConstants.setTerminatorOperandConstants(terminator, *operands);
+  branch.getSuccessorRegionsWithConstants(RegionBranchPoint(terminator),
+                                          operandConstants, successors);
   visitRegionBranchEdges(branch, op, successors);
 }
 
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 90f2a588d1ca4..5a2ac40f311d8 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -617,7 +617,10 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
   Operation *op = branch.getOperation();
   SmallVector<RegionSuccessor> successors;
   SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
-  branch.getEntrySuccessorRegions(operands, successors);
+  RegionBranchPointOperandConstants operandConstants;
+  operandConstants.setParentOperandConstants(operands);
+  branch.getSuccessorRegionsWithConstants(RegionBranchPoint::parent(),
+                                          operandConstants, successors);
   for (RegionSuccessor &successor : successors) {
     if (successor.isParent())
       continue;
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 8def84fc49378..dc0ed005526b2 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -991,8 +991,15 @@ ValueRange IfOp::getSuccessorInputs(RegionSuccessor successor) {
                               : ValueRange();
 }
 
-void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
-                                    SmallVectorImpl<RegionSuccessor> &regions) {
+void IfOp::getSuccessorRegionsWithConstants(
+    RegionBranchPoint point,
+    const RegionBranchPointOperandConstants &operandConstants,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  ArrayRef<Attribute> operands = operandConstants.getOperandConstants(point);
+  if (!point.isParent() || operands.empty()) {
+    getSuccessorRegions(point, regions);
+    return;
+  }
   FoldAdaptor adaptor(operands, *this);
   auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
   if (!boolAttr || boolAttr.getValue())
@@ -1564,9 +1571,15 @@ static std::optional<int64_t> getIntAttrValue(IntegerAttr attr) {
   return std::nullopt;
 }
 
-void SwitchOp::getEntrySuccessorRegions(
-    ArrayRef<Attribute> operands,
+void SwitchOp::getSuccessorRegionsWithConstants(
+    RegionBranchPoint point,
+    const RegionBranchPointOperandConstants &operandConstants,
     SmallVectorImpl<RegionSuccessor> &successors) {
+  ArrayRef<Attribute> operands = operandConstants.getOperandConstants(point);
+  if (!point.isParent() || operands.empty()) {
+    getSuccessorRegions(point, successors);
+    return;
+  }
   FoldAdaptor adaptor(operands, *this);
 
   // If a constant was not provided, all regions are possible successors.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9f4f4dc9f58e6..e632077de678a 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -296,21 +296,6 @@ ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
   return getArgsMutable();
 }
 
-void ConditionOp::getSuccessorRegions(
-    ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
-  FoldAdaptor adaptor(operands, *this);
-
-  WhileOp whileOp = getParentOp();
-
-  // Condition can either lead to the after region or back to the parent op
-  // depending on whether the condition is true or not.
-  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
-  if (!boolAttr || boolAttr.getValue())
-    regions.emplace_back(&whileOp.getAfter());
-  if (!boolAttr || !boolAttr.getValue())
-    regions.push_back(RegionSuccessor::parent());
-}
-
 //===----------------------------------------------------------------------===//
 // ForOp
 //===----------------------------------------------------------------------===//
@@ -2120,8 +2105,16 @@ ValueRange IfOp::getSuccessorInputs(RegionSuccessor successor) {
                               : ValueRange();
 }
 
-void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
-                                    SmallVectorImpl<RegionSuccessor> &regions) {
+void IfOp::getSuccessorRegionsWithConstants(
+    RegionBranchPoint point,
+    const RegionBranchPointOperandConstants &operandConstants,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  // Constant folding only applies when entering from the parent op.
+  ArrayRef<Attribute> operands = operandConstants.getOperandConstants(point);
+  if (!point.isParent() || operands.empty()) {
+    getSuccessorRegions(point, regions);
+    return;
+  }
   FoldAdaptor adaptor(operands, *this);
   auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
   if (!boolAttr || boolAttr.getValue())
@@ -3291,6 +3284,30 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point,
   regions.emplace_back(&getAfter());
 }
 
+void WhileOp::getSuccessorRegionsWithConstants(
+    RegionBranchPoint point,
+    const RegionBranchPointOperandConstants &operandConstants,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  // Constant folding only applies when branching from the `scf.condition`
+  // terminator of the "before" region. For all other branch points, fall back
+  // to the unfiltered behavior.
+  auto conditionOp =
+      dyn_cast_or_null<ConditionOp>(point.getTerminatorPredecessorOrNull());
+  ArrayRef<Attribute> operands = operandConstants.getOperandConstants(point);
+  if (!conditionOp || operands.empty()) {
+    getSuccessorRegions(point, regions);
+    return;
+  }
+  ConditionOp::FoldAdaptor adaptor(operands, conditionOp);
+  // Condition can either lead to the after region or back to the parent op
+  // depending on whether the condition is true or not.
+  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
+  if (!boolAttr || boolAttr.getValue())
+    regions.emplace_back(&getAfter());
+  if (!boolAttr || !boolAttr.getValue())
+    regions.push_back(RegionSuccessor::parent());
+}
+
 ValueRange WhileOp::getSuccessorInputs(RegionSuccessor successor) {
   if (successor.isParent())
     return getOperation()->getResults();
@@ -3842,9 +3859,16 @@ ValueRange IndexSwitchOp::getSuccessorInputs(RegionSuccessor successor) {
                               : ValueRange();
 }
 
-void IndexSwitchOp::getEntrySuccessorRegions(
-    ArrayRef<Attribute> operands,
+void IndexSwitchOp::getSuccessorRegionsWithConstants(
+    RegionBranchPoint point,
+    const RegionBranchPointOperandConstants &operandConstants,
     SmallVectorImpl<RegionSuccessor> &successors) {
+  // Constant folding only applies when entering from the parent op.
+  ArrayRef<Attribute> operands = operandConstants.getOperandConstants(point);
+  if (!point.isParent() || operands.empty()) {
+    getSuccessorRegions(point, successors);
+    return;
+  }
   FoldAdaptor adaptor(operands, *this);
 
   // If a constant was not provided, all regions are possible successors.
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 7d48315eec6ff..b919d46a07ca7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -174,7 +174,10 @@ static void propagateRegionResultsToYieldOperands(
 
   SmallVector<RegionSuccessor> successors;
   SmallVector<Attribute> operandAttrs(yieldOp->getNumOperands(), nullptr);
-  yieldOp.getSuccessorRegions(operandAttrs, successors);
+  RegionBranchPointOperandConstants operandConstants;
+  operandConstants.setTerminatorOperandConstants(yieldOp, operandAttrs);
+  regionBranchOp.getSuccessorRegionsWithConstants(RegionBranchPoint(yieldOp),
+                                                  operandConstants, successors);
 
   for (const RegionSuccessor &successor : successors) {
     OperandRange succOps = yieldOp.getSuccessorOperands(successor);
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index c3fb73acf5ef0..2d53c2d66bd07 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -152,6 +152,21 @@ LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
   return verifyWeights(op, weights, op->getNumRegions(), "region", "regions");
 }
 
+//===----------------------------------------------------------------------===//
+// RegionBranchPointOperandConstants
+//===----------------------------------------------------------------------===//
+
+ArrayRef<Attribute> RegionBranchPointOperandConstants::getOperandConstants(
+    RegionBranchPoint point) const {
+  if (point.isParent())
+    return parentOperandConstants;
+  Operation *terminator = point.getTerminatorPredecessorOrNull();
+  for (const auto &entry : terminatorOperandConstants)
+    if (entry.first == terminator)
+      return entry.second;
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // RegionBranchOpInterface
 //===----------------------------------------------------------------------===//
@@ -246,7 +261,6 @@ static bool traverseRegionGraph(Region *begin,
   SmallVector<Region *> worklist;
   auto enqueueAllSuccessors = [&](Region *region) {
     LDBG() << "Enqueuing successors for region #" << region->getRegionNumber();
-    SmallVector<Attribute> operandAttributes(op->getNumOperands());
     for (Block &block : *region) {
       if (block.empty())
         continue;
@@ -255,8 +269,9 @@ static bool traverseRegionGraph(Region *begin,
       if (!terminator)
         continue;
       SmallVector<RegionSuccessor> successors;
-      operandAttributes.resize(terminator->getNumOperands());
-      terminator.getSuccessorRegions(operandAttributes, successors);
+      // No constant operand information is provided here; the unfiltered set
+      // of successors is sufficient for the region-graph traversal.
+      op.getSuccessorRegions(RegionBranchPoint(terminator), successors);
       LDBG() << "Found " << successors.size()
              << " successors from terminator in block";
       for (RegionSuccessor successor : successors) {
@@ -1088,15 +1103,18 @@ static SmallVector<RegionSuccessor>
 getSuccessorRegionsWithAttrs(RegionBranchOpInterface op,
                              RegionBranchPoint point) {
   SmallVector<RegionSuccessor> successors;
+  RegionBranchPointOperandConstants operandConstants;
+  SmallVector<Attribute> constants;
   if (point.isParent()) {
-    op.getEntrySuccessorRegions(extractConstants(op->getOperands()),
-                                successors);
-    return successors;
+    constants = extractConstants(op->getOperands());
+    operandConstants.setParentOperandConstants(constants);
+  } else {
+    RegionBranchTerminatorOpInterface terminator =
+        point.getTerminatorPredecessorOrNull();
+    constants = extractConstants(terminator->getOperands());
+    operandConstants.setTerminatorOperandConstants(terminator, constants);
   }
-  RegionBranchTerminatorOpInterface terminator =
-      point.getTerminatorPredecessorOrNull();
-  terminator.getSuccessorRegions(extractConstants(terminator->getOperands()),
-                                 successors);
+  op.getSuccessorRegionsWithConstants(point, operandConstants, successors);
   return successors;
 }
 
diff --git a/mlir/test/Transforms/sccp.mlir b/mlir/test/Transforms/sccp.mlir
index f5a5183593d03..ec61644dbb5ea 100644
--- a/mlir/test/Transforms/sccp.mlir
+++ b/mlir/test/Transforms/sccp.mlir
@@ -290,8 +290,9 @@ func.func @no_crash_acc_kernel_environment(%data: memref<8xi32>) {
 // -----
 
 // Regression test for https://github.com/llvm/llvm-project/issues/187973
-// SwitchOp::getEntrySuccessorRegions must not call IntegerAttr::getInt() on
-// an unsigned integer type — that function asserts signless/index only.
+// SwitchOp::getSuccessorRegionsWithConstants must not call
+// IntegerAttr::getInt() on an unsigned integer type — that function asserts
+// signless/index only.
 
 // CHECK-LABEL: no_crash_emitc_switch_unsigned_condition
 func.func @no_crash_emitc_switch_unsigned_condition() {


