[Mlir-commits] [clang] [flang] [mlir] [mlir][Interfaces] Single interface method to query constant region-based CF (PR #193486)

Matthias Springer llvmlistbot at llvm.org
Wed Apr 22 05:23:00 PDT 2026


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/193486

>From 56af902487800cf6c6d49250194f29b9eadc0943 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 22 Apr 2026 11:37:59 +0000
Subject: [PATCH] [mlir][Interfaces] Use single interface method to query
 region-based control flow

---
 clang/include/clang/CIR/Dialect/IR/CIROps.td  |  1 -
 clang/lib/CIR/Dialect/IR/CIRDialect.cpp       | 17 ----
 .../include/flang/Optimizer/Dialect/FIROps.td |  2 +-
 flang/lib/Optimizer/Dialect/FIROps.cpp        | 11 ++-
 mlir/include/mlir/Dialect/EmitC/IR/EmitC.td   |  4 +-
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    |  9 ++-
 .../mlir/Interfaces/ControlFlowInterfaces.h   | 45 +++++++++++
 .../mlir/Interfaces/ControlFlowInterfaces.td  | 80 +++++++++----------
 .../Analysis/DataFlow/DeadCodeAnalysis.cpp    | 10 ++-
 mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp |  5 +-
 mlir/lib/Dialect/EmitC/IR/EmitC.cpp           | 21 ++++-
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 62 +++++++++-----
 .../XeGPU/Transforms/XeGPULayoutImpl.cpp      |  5 +-
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 38 ++++++---
 mlir/test/Transforms/sccp.mlir                |  5 +-
 15 files changed, 209 insertions(+), 106 deletions(-)

diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index f20ba262d6480..7ed36869afe85 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -908,7 +908,6 @@ def CIR_IfOp : CIR_Op<"if", [
 def CIR_ConditionOp : CIR_Op<"condition", [
   Terminator,
   DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface, [
-    "getSuccessorRegions",
     "getMutableSuccessorOperands"
   ]>
 ]> {
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 35a31b0dbda63..35d84f8255642 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -379,23 +379,6 @@ LogicalResult cir::BreakOp::verify() {
 // BranchOpTerminatorInterface Methods
 //===----------------------------------
 
-void cir::ConditionOp::getSuccessorRegions(
-    ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
-  // TODO(cir): The condition value may be folded to a constant, narrowing
-  // down its list of possible successors.
-
-  // Parent is a loop: condition may branch to the body or to the parent op.
-  if (auto loopOp = dyn_cast<LoopOpInterface>(getOperation()->getParentOp())) {
-    regions.emplace_back(&loopOp.getBody());
-    regions.push_back(RegionSuccessor::parent());
-  }
-
-  // Parent is an await: condition may branch to resume or suspend regions.
-  auto await = cast<AwaitOp>(getOperation()->getParentOp());
-  regions.emplace_back(&await.getResume());
-  regions.emplace_back(&await.getSuspend());
-}
-
 MutableOperandRange
 cir::ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
   // No values are yielded to the successor region.
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index eba4ddda4b6ad..0789c70fdd694 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2459,7 +2459,7 @@ def fir_IfOp
     : region_Op<
           "if", [DeclareOpInterfaceMethods<
                      RegionBranchOpInterface, ["getRegionInvocationBounds",
-                                               "getEntrySuccessorRegions",
+                                               "getSuccessorRegionsWithConstants",
                                                "getSuccessorInputs"]>,
                  RecursiveMemoryEffects, NoRegionArguments,
                  WeightedRegionBranchOpInterface]> {
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 4705033945611..453dbd8a280a2 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -5109,9 +5109,16 @@ fir::IfOp::getSuccessorInputs(mlir::RegionSuccessor successor) {
   return mlir::ValueRange();
 }
 
-void fir::IfOp::getEntrySuccessorRegions(
-    llvm::ArrayRef<mlir::Attribute> operands,
+void fir::IfOp::getSuccessorRegionsWithConstants(
+    mlir::RegionBranchPoint point,
+    const mlir::RegionBranchPointOperandConstants &operandConstants,
     llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
+  llvm::ArrayRef<mlir::Attribute> operands =
+      operandConstants.getOperandConstants(point);
+  if (!point.isParent() || operands.empty()) {
+    getSuccessorRegions(point, regions);
+    return;
+  }
   FoldAdaptor adaptor(operands);
   auto boolAttr =
       mlir::dyn_cast_or_null<mlir::BoolAttr>(adaptor.getCondition());
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 7d9bb8907eb8b..12cc1add40ee3 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1471,7 +1471,7 @@ def EmitC_YieldOp : EmitC_Op<"yield",
 def EmitC_IfOp : EmitC_Op<"if",
     [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
     "getNumRegionInvocations", "getRegionInvocationBounds",
-    "getEntrySuccessorRegions", "getSuccessorInputs"]>,
+    "getSuccessorRegionsWithConstants", "getSuccessorInputs"]>,
     OpAsmOpInterface, SingleBlock,
     SingleBlockImplicitTerminator<"emitc::YieldOp">,
     RecursiveMemoryEffects, NoRegionArguments]> {
@@ -1587,7 +1587,7 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
     OpAsmOpInterface, SingleBlockImplicitTerminator<"emitc::YieldOp">,
     DeclareOpInterfaceMethods<RegionBranchOpInterface,
                               ["getRegionInvocationBounds",
-                               "getEntrySuccessorRegions"]>]> {
+                               "getSuccessorRegionsWithConstants"]>]> {
   let summary = "Switch operation";
   let description = [{
     The `emitc.switch` is a control-flow operation that branches to one of
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 0b33ecb48b7f2..1afcdc4b51395 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -55,7 +55,7 @@ class SCF_Op<string mnemonic, list<Trait> traits = []> :
 def ConditionOp : SCF_Op<"condition", [
   HasParent<"WhileOp">,
   DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface,
-    ["getSuccessorRegions", "getMutableSuccessorOperands"]>,
+    ["getMutableSuccessorOperands"]>,
   Pure,
   Terminator
 ]> {
@@ -704,7 +704,7 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
 
 def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
     "getNumRegionInvocations", "getRegionInvocationBounds",
-    "getEntrySuccessorRegions", "getSuccessorInputs"]>,
+    "getSuccessorRegionsWithConstants", "getSuccessorInputs"]>,
     DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
     InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">,
     RecursiveMemoryEffects, RecursivelySpeculatable, NoRegionArguments]> {
@@ -991,7 +991,8 @@ def ReduceReturnOp :
 
 def WhileOp : SCF_Op<"while",
     [DeclareOpInterfaceMethods<RegionBranchOpInterface,
-        ["getEntrySuccessorOperands", "getSuccessorInputs"]>,
+        ["getEntrySuccessorOperands", "getSuccessorInputs",
+         "getSuccessorRegionsWithConstants"]>,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,
         ["getRegionIterArgs", "getYieldedValuesMutable"]>,
      DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
@@ -1147,7 +1148,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
     DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
     DeclareOpInterfaceMethods<RegionBranchOpInterface,
                               ["getRegionInvocationBounds",
-                               "getEntrySuccessorRegions",
+                               "getSuccessorRegionsWithConstants",
                                "getSuccessorInputs"]>]> {
   let summary = "switch-case operation on an index argument";
   let description = [{
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index a76dce6f2ffc5..8138e95e69d36 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -270,6 +270,51 @@ class RegionBranchPoint {
   Operation *predecessor = nullptr;
 };
 
+/// A bundle of constant-operand information for a `RegionBranchOpInterface`
+/// implementation and any of the region branch terminators in its regions.
+///
+/// This is used by
+/// `RegionBranchOpInterface::getSuccessorRegionsWithConstants` to allow
+/// analyses/transformations to pass along constant operand values for a given
+/// branch point so that the implementation can refine the returned successors.
+/// Operand constants are optional: an implementation that doesn't recognize
+/// the provided constants (or for which no constants are provided) must
+/// return the same successors as `getSuccessorRegions`.
+///
+/// For each branch point (the parent op or a region branch terminator), the
+/// associated `ArrayRef<Attribute>` either has the same size as the number of
+/// operands of that op (with a null attribute for non-constant operands), or
+/// is empty (which means: no constant information is available for this
+/// branch point). The arrays are not copied; keep them alive while in use.
+class RegionBranchPointOperandConstants {
+public:
+  /// Default constructor: no constants known for any branch point.
+  RegionBranchPointOperandConstants() = default;
+
+  /// Set the constant operand information for the parent (region branch) op.
+  void setParentOperandConstants(ArrayRef<Attribute> constants) {
+    parentOperandConstants = constants;
+  }
+
+  /// Set the constant operand information for a specific region branch
+  /// terminator. Overwrites any constants previously set for `terminator`.
+  void setTerminatorOperandConstants(Operation *terminator,
+                                     ArrayRef<Attribute> constants) {
+    terminatorOperandConstants[terminator] = constants;
+  }
+
+  /// Returns the constant operand information for the given branch point.
+  /// Returns an empty range if no constants were provided for this point.
+  ArrayRef<Attribute> getOperandConstants(RegionBranchPoint point) const;
+
+private:
+  /// Non-owning view of the operand constants of the region branch op.
+  ArrayRef<Attribute> parentOperandConstants;
+
+  /// Non-owning views of the operand constants of region branch terminators.
+  DenseMap<Operation *, ArrayRef<Attribute>> terminatorOperandConstants;
+};
+
 /// This class represents upper and lower bounds on the number of times a region
 /// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
 /// zero, but the upper bound may not be known.
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 06fa724e05fab..ead2a32139555 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -137,6 +137,13 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
     specified with `getEntrySuccessorOperands` /
     `RegionBranchTerminatorOpInterface::getSuccessorOperands`.
 
+    Implementations may also optionally refine the set of successors based on
+    constant operand values. To do so, implement
+    `getSuccessorRegionsWithConstants`, which takes a
+    `RegionBranchPointOperandConstants` argument; that method is invoked by
+    analyses that have constant-folded information. By default,
+    `getSuccessorRegionsWithConstants` dispatches to `getSuccessorRegions`.
+
     A "region successor" indicates the target of a branch. It can indicate:
     1. A region of this op.
     2. `RegionSuccessor::parent()`, i.e., the control flow leaves this op.
@@ -183,7 +190,7 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
         Returns the operands of this operation that are forwarded to the
         successor inputs when branching to `successor`. `successor` is
         guaranteed to be among the successors that are returned by
-        `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`.
+        `getSuccessorRegions(parent())`.
 
         Example: In the above example, this method returns the operand %b of the
         `scf.for` op, regardless of the value of `successor`. I.e., this op always
@@ -198,31 +205,39 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
       }]
     >,
     InterfaceMethod<[{
-        Returns the potential region successors when first executing the op.
+        Returns the potential region successors when branching from `point`.
+
+        This is a "constants-aware" variant of `getSuccessorRegions`. The
+        `operandConstants` parameter bundles up the constant values of the
+        operands of the parent (region branch) op and any of its region
+        branch terminators. Based on these, the implementation may filter out
+        certain successors that are statically known not to be taken.
 
-        Unlike `getSuccessorRegions`, this method also passes along the
-        constant operands of this op. Based on these, the implementation may
-        filter out certain successors. By default, simply dispatches to
-        `getSuccessorRegions`. `operands` contains an entry for every
-        operand of this op, with a null attribute if the operand has no constant
-        value.
+        By default, this method simply dispatches to `getSuccessorRegions`,
+        ignoring the constant operand information.
 
-        Note: The control flow does not necessarily have to enter any region of
-        this op.
+        Implementations should handle the case where no constant information
+        is available for a given branch point (i.e.,
+        `operandConstants.getOperandConstants(point)` returns an empty range)
+        by falling back to `getSuccessorRegions`.
 
         Example: In the above example, this method may return two region
-        region successors: the single region of the `scf.for` op and the
-        `scf.for` operation (that implements this interface). If %lb, %ub, %step
-        are constants and it can be determined the loop does not have any
-        iterations, this method may choose to return only this operation.
-        Similarly, if it can be determined that the loop has at least one
-        iteration, this method may choose to return only the region of the loop.
+        successors for the "parent" branch point: the single region
+        of the `scf.for` op and the `scf.for` operation (that implements this
+        interface). If %lb, %ub, %step are constants and it can be determined
+        the loop does not have any iterations, this method may choose to return
+        only this operation. Similarly, if it can be determined that the loop
+        has at least one iteration, this method may choose to return only the
+        region of the loop.
       }],
-      "void", "getEntrySuccessorRegions",
-      (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
-           "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}],
+      "void", "getSuccessorRegionsWithConstants",
+      (ins "::mlir::RegionBranchPoint":$point,
+           "const ::mlir::RegionBranchPointOperandConstants &"
+             :$operandConstants,
+           "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
+      [{}],
       /*defaultImplementation=*/[{
-        $_op.getSuccessorRegions(mlir::RegionBranchPoint::parent(), regions);
+        $_op.getSuccessorRegions(point, regions);
       }]
     >,
     InterfaceMethod<[{
@@ -423,6 +438,11 @@ def RegionBranchTerminatorOpInterface :
     (However, there may be other block terminators in the same region that
     implement the `RegionBranchTerminatorOpInterface`, so the enclosing region
     may have region successors.)
+
+    To query the region successors that may be branched to after a given
+    terminator, use `RegionBranchOpInterface::getSuccessorRegions(point, ...)` /
+    `getSuccessorRegionsWithConstants(point, ...)` on the parent op (passing the
+    terminator as the branch point).
   }];
   let cppNamespace = "::mlir";
 
@@ -438,26 +458,6 @@ def RegionBranchTerminatorOpInterface :
         return ::mlir::MutableOperandRange($_op);
       }]
     >,
-    InterfaceMethod<[{
-        Returns the potential region successors that are branched to after this
-        terminator based on the given constant operands.
-
-        This method also passes along the constant operands of this op.
-        `operands` contains an entry for every operand of this op, with a null
-        attribute if the operand has no constant value.
-
-        The default implementation simply dispatches to the parent
-        `RegionBranchOpInterface`'s `getSuccessorRegions` implementation.
-      }],
-      "void", "getSuccessorRegions",
-      (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
-           "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}],
-      /*defaultImplementation=*/[{
-        ::mlir::Operation *op = $_op;
-        ::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
-          .getSuccessorRegions(::llvm::cast<::mlir::RegionBranchTerminatorOpInterface>(op), regions);
-      }]
-    >,
   ];
 
   let verify = [{
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 38811d06ecd8c..8cf323d13589d 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -497,7 +497,10 @@ void DeadCodeAnalysis::visitRegionBranchOperation(
     return;
 
   SmallVector<RegionSuccessor> successors;
-  branch.getEntrySuccessorRegions(*operands, successors);
+  RegionBranchPointOperandConstants operandConstants;
+  operandConstants.setParentOperandConstants(*operands);
+  branch.getSuccessorRegionsWithConstants(RegionBranchPoint::parent(),
+                                          operandConstants, successors);
 
   visitRegionBranchEdges(branch, branch.getOperation(), successors);
 }
@@ -513,7 +516,10 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
   auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op);
   if (!terminator)
     return;
-  terminator.getSuccessorRegions(*operands, successors);
+  RegionBranchPointOperandConstants operandConstants;
+  operandConstants.setTerminatorOperandConstants(terminator, *operands);
+  branch.getSuccessorRegionsWithConstants(RegionBranchPoint(terminator),
+                                          operandConstants, successors);
   visitRegionBranchEdges(branch, op, successors);
 }
 
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 90f2a588d1ca4..5a2ac40f311d8 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -617,7 +617,10 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
   Operation *op = branch.getOperation();
   SmallVector<RegionSuccessor> successors;
   SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
-  branch.getEntrySuccessorRegions(operands, successors);
+  RegionBranchPointOperandConstants operandConstants;
+  operandConstants.setParentOperandConstants(operands);
+  branch.getSuccessorRegionsWithConstants(RegionBranchPoint::parent(),
+                                          operandConstants, successors);
   for (RegionSuccessor &successor : successors) {
     if (successor.isParent())
       continue;
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 8def84fc49378..dc0ed005526b2 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -991,8 +991,15 @@ ValueRange IfOp::getSuccessorInputs(RegionSuccessor successor) {
                               : ValueRange();
 }
 
-void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
-                                    SmallVectorImpl<RegionSuccessor> &regions) {
+void IfOp::getSuccessorRegionsWithConstants(
+    RegionBranchPoint point,
+    const RegionBranchPointOperandConstants &operandConstants,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  ArrayRef<Attribute> operands = operandConstants.getOperandConstants(point);
+  if (!point.isParent() || operands.empty()) {
+    getSuccessorRegions(point, regions);
+    return;
+  }
   FoldAdaptor adaptor(operands, *this);
   auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
   if (!boolAttr || boolAttr.getValue())
@@ -1564,9 +1571,15 @@ static std::optional<int64_t> getIntAttrValue(IntegerAttr attr) {
   return std::nullopt;
 }
 
-void SwitchOp::getEntrySuccessorRegions(
-    ArrayRef<Attribute> operands,
+void SwitchOp::getSuccessorRegionsWithConstants(
+    RegionBranchPoint point,
+    const RegionBranchPointOperandConstants &operandConstants,
     SmallVectorImpl<RegionSuccessor> &successors) {
+  ArrayRef<Attribute> operands = operandConstants.getOperandConstants(point);
+  if (!point.isParent() || operands.empty()) {
+    getSuccessorRegions(point, successors);
+    return;
+  }
   FoldAdaptor adaptor(operands, *this);
 
   // If a constant was not provided, all regions are possible successors.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9f4f4dc9f58e6..e632077de678a 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -296,21 +296,6 @@ ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
   return getArgsMutable();
 }
 
-void ConditionOp::getSuccessorRegions(
-    ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
-  FoldAdaptor adaptor(operands, *this);
-
-  WhileOp whileOp = getParentOp();
-
-  // Condition can either lead to the after region or back to the parent op
-  // depending on whether the condition is true or not.
-  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
-  if (!boolAttr || boolAttr.getValue())
-    regions.emplace_back(&whileOp.getAfter());
-  if (!boolAttr || !boolAttr.getValue())
-    regions.push_back(RegionSuccessor::parent());
-}
-
 //===----------------------------------------------------------------------===//
 // ForOp
 //===----------------------------------------------------------------------===//
@@ -2120,8 +2105,16 @@ ValueRange IfOp::getSuccessorInputs(RegionSuccessor successor) {
                               : ValueRange();
 }
 
-void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
-                                    SmallVectorImpl<RegionSuccessor> &regions) {
+void IfOp::getSuccessorRegionsWithConstants(
+    RegionBranchPoint point,
+    const RegionBranchPointOperandConstants &operandConstants,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  // Constant folding only applies when entering from the parent op.
+  ArrayRef<Attribute> operands = operandConstants.getOperandConstants(point);
+  if (!point.isParent() || operands.empty()) {
+    getSuccessorRegions(point, regions);
+    return;
+  }
   FoldAdaptor adaptor(operands, *this);
   auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
   if (!boolAttr || boolAttr.getValue())
@@ -3291,6 +3284,30 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point,
   regions.emplace_back(&getAfter());
 }
 
+void WhileOp::getSuccessorRegionsWithConstants(
+    RegionBranchPoint point,
+    const RegionBranchPointOperandConstants &operandConstants,
+    SmallVectorImpl<RegionSuccessor> &regions) {
+  // Constant folding only applies when branching from the `scf.condition`
+  // terminator of the "before" region and constant info was provided for it.
+  // Otherwise, fall back to the unfiltered `getSuccessorRegions` result.
+  auto conditionOp =
+      dyn_cast_or_null<ConditionOp>(point.getTerminatorPredecessorOrNull());
+  ArrayRef<Attribute> operands = operandConstants.getOperandConstants(point);
+  if (!conditionOp || operands.empty()) {
+    getSuccessorRegions(point, regions);
+    return;
+  }
+  ConditionOp::FoldAdaptor adaptor(operands, conditionOp);
+  // A true condition branches to the "after" region and a false one exits to
+  // the parent op; if the condition is not a known BoolAttr, both may happen.
+  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
+  if (!boolAttr || boolAttr.getValue())
+    regions.emplace_back(&getAfter());
+  if (!boolAttr || !boolAttr.getValue())
+    regions.push_back(RegionSuccessor::parent());
+}
+
 ValueRange WhileOp::getSuccessorInputs(RegionSuccessor successor) {
   if (successor.isParent())
     return getOperation()->getResults();
@@ -3842,9 +3859,16 @@ ValueRange IndexSwitchOp::getSuccessorInputs(RegionSuccessor successor) {
                               : ValueRange();
 }
 
-void IndexSwitchOp::getEntrySuccessorRegions(
-    ArrayRef<Attribute> operands,
+void IndexSwitchOp::getSuccessorRegionsWithConstants(
+    RegionBranchPoint point,
+    const RegionBranchPointOperandConstants &operandConstants,
     SmallVectorImpl<RegionSuccessor> &successors) {
+  // Constant folding only applies when entering from the parent op.
+  ArrayRef<Attribute> operands = operandConstants.getOperandConstants(point);
+  if (!point.isParent() || operands.empty()) {
+    getSuccessorRegions(point, successors);
+    return;
+  }
   FoldAdaptor adaptor(operands, *this);
 
   // If a constant was not provided, all regions are possible successors.
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 7d48315eec6ff..b919d46a07ca7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -174,7 +174,10 @@ static void propagateRegionResultsToYieldOperands(
 
   SmallVector<RegionSuccessor> successors;
   SmallVector<Attribute> operandAttrs(yieldOp->getNumOperands(), nullptr);
-  yieldOp.getSuccessorRegions(operandAttrs, successors);
+  RegionBranchPointOperandConstants operandConstants;
+  operandConstants.setTerminatorOperandConstants(yieldOp, operandAttrs);
+  regionBranchOp.getSuccessorRegionsWithConstants(RegionBranchPoint(yieldOp),
+                                                  operandConstants, successors);
 
   for (const RegionSuccessor &successor : successors) {
     OperandRange succOps = yieldOp.getSuccessorOperands(successor);
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index c3fb73acf5ef0..686a8c3b15484 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -152,6 +152,21 @@ LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
   return verifyWeights(op, weights, op->getNumRegions(), "region", "regions");
 }
 
+//===----------------------------------------------------------------------===//
+// RegionBranchPointOperandConstants
+//===----------------------------------------------------------------------===//
+
+ArrayRef<Attribute> RegionBranchPointOperandConstants::getOperandConstants(
+    RegionBranchPoint point) const {
+  if (point.isParent())
+    return parentOperandConstants; // Empty if never set.
+  Operation *terminator = point.getTerminatorPredecessorOrNull();
+  auto it = terminatorOperandConstants.find(terminator);
+  if (it == terminatorOperandConstants.end())
+    return {}; // No constants were provided for this terminator.
+  return it->second;
+}
+
 //===----------------------------------------------------------------------===//
 // RegionBranchOpInterface
 //===----------------------------------------------------------------------===//
@@ -246,7 +261,6 @@ static bool traverseRegionGraph(Region *begin,
   SmallVector<Region *> worklist;
   auto enqueueAllSuccessors = [&](Region *region) {
     LDBG() << "Enqueuing successors for region #" << region->getRegionNumber();
-    SmallVector<Attribute> operandAttributes(op->getNumOperands());
     for (Block &block : *region) {
       if (block.empty())
         continue;
@@ -255,8 +269,9 @@ static bool traverseRegionGraph(Region *begin,
       if (!terminator)
         continue;
       SmallVector<RegionSuccessor> successors;
-      operandAttributes.resize(terminator->getNumOperands());
-      terminator.getSuccessorRegions(operandAttributes, successors);
+      // No constant operand information is provided here; the unfiltered set
+      // of successors is sufficient for the region-graph traversal.
+      op.getSuccessorRegions(RegionBranchPoint(terminator), successors);
       LDBG() << "Found " << successors.size()
              << " successors from terminator in block";
       for (RegionSuccessor successor : successors) {
@@ -1088,15 +1103,18 @@ static SmallVector<RegionSuccessor>
 getSuccessorRegionsWithAttrs(RegionBranchOpInterface op,
                              RegionBranchPoint point) {
   SmallVector<RegionSuccessor> successors;
+  RegionBranchPointOperandConstants operandConstants;
+  SmallVector<Attribute> constants; // Backs the ArrayRef in operandConstants.
   if (point.isParent()) {
-    op.getEntrySuccessorRegions(extractConstants(op->getOperands()),
-                                successors);
-    return successors;
+    constants = extractConstants(op->getOperands());
+    operandConstants.setParentOperandConstants(constants);
+  } else {
+    RegionBranchTerminatorOpInterface terminator =
+        point.getTerminatorPredecessorOrNull();
+    constants = extractConstants(terminator->getOperands());
+    operandConstants.setTerminatorOperandConstants(terminator, constants);
   }
-  RegionBranchTerminatorOpInterface terminator =
-      point.getTerminatorPredecessorOrNull();
-  terminator.getSuccessorRegions(extractConstants(terminator->getOperands()),
-                                 successors);
+  op.getSuccessorRegionsWithConstants(point, operandConstants, successors);
   return successors;
 }
 
diff --git a/mlir/test/Transforms/sccp.mlir b/mlir/test/Transforms/sccp.mlir
index f5a5183593d03..ec61644dbb5ea 100644
--- a/mlir/test/Transforms/sccp.mlir
+++ b/mlir/test/Transforms/sccp.mlir
@@ -290,8 +290,9 @@ func.func @no_crash_acc_kernel_environment(%data: memref<8xi32>) {
 // -----
 
 // Regression test for https://github.com/llvm/llvm-project/issues/187973
-// SwitchOp::getEntrySuccessorRegions must not call IntegerAttr::getInt() on
-// an unsigned integer type — that function asserts signless/index only.
+// SwitchOp::getSuccessorRegionsWithConstants must not call
+// IntegerAttr::getInt() on an unsigned integer type — that function asserts
+// signless/index only.
 
 // CHECK-LABEL: no_crash_emitc_switch_unsigned_condition
 func.func @no_crash_emitc_switch_unsigned_condition() {



More information about the Mlir-commits mailing list