[Mlir-commits] [clang] [flang] [mlir] [mlir][Interfaces] Single interface method to query constant region-based CF (PR #193486)
Matthias Springer
llvmlistbot at llvm.org
Wed Apr 22 05:23:00 PDT 2026
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/193486
>From 56af902487800cf6c6d49250194f29b9eadc0943 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Wed, 22 Apr 2026 11:37:59 +0000
Subject: [PATCH] [mlir][Interfaces] Use single interface method to query
region-based control flow
---
clang/include/clang/CIR/Dialect/IR/CIROps.td | 1 -
clang/lib/CIR/Dialect/IR/CIRDialect.cpp | 17 ----
.../include/flang/Optimizer/Dialect/FIROps.td | 2 +-
flang/lib/Optimizer/Dialect/FIROps.cpp | 11 ++-
mlir/include/mlir/Dialect/EmitC/IR/EmitC.td | 4 +-
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 9 ++-
.../mlir/Interfaces/ControlFlowInterfaces.h | 45 +++++++++++
.../mlir/Interfaces/ControlFlowInterfaces.td | 80 +++++++++----------
.../Analysis/DataFlow/DeadCodeAnalysis.cpp | 10 ++-
mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp | 5 +-
mlir/lib/Dialect/EmitC/IR/EmitC.cpp | 21 ++++-
mlir/lib/Dialect/SCF/IR/SCF.cpp | 62 +++++++++-----
.../XeGPU/Transforms/XeGPULayoutImpl.cpp | 5 +-
mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 38 ++++++---
mlir/test/Transforms/sccp.mlir | 5 +-
15 files changed, 209 insertions(+), 106 deletions(-)
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index f20ba262d6480..7ed36869afe85 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -908,7 +908,6 @@ def CIR_IfOp : CIR_Op<"if", [
def CIR_ConditionOp : CIR_Op<"condition", [
Terminator,
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface, [
- "getSuccessorRegions",
"getMutableSuccessorOperands"
]>
]> {
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 35a31b0dbda63..35d84f8255642 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -379,23 +379,6 @@ LogicalResult cir::BreakOp::verify() {
// BranchOpTerminatorInterface Methods
//===----------------------------------
-void cir::ConditionOp::getSuccessorRegions(
- ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
- // TODO(cir): The condition value may be folded to a constant, narrowing
- // down its list of possible successors.
-
- // Parent is a loop: condition may branch to the body or to the parent op.
- if (auto loopOp = dyn_cast<LoopOpInterface>(getOperation()->getParentOp())) {
- regions.emplace_back(&loopOp.getBody());
- regions.push_back(RegionSuccessor::parent());
- }
-
- // Parent is an await: condition may branch to resume or suspend regions.
- auto await = cast<AwaitOp>(getOperation()->getParentOp());
- regions.emplace_back(&await.getResume());
- regions.emplace_back(&await.getSuspend());
-}
-
MutableOperandRange
cir::ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
// No values are yielded to the successor region.
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index eba4ddda4b6ad..0789c70fdd694 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -2459,7 +2459,7 @@ def fir_IfOp
: region_Op<
"if", [DeclareOpInterfaceMethods<
RegionBranchOpInterface, ["getRegionInvocationBounds",
- "getEntrySuccessorRegions",
+ "getSuccessorRegionsWithConstants",
"getSuccessorInputs"]>,
RecursiveMemoryEffects, NoRegionArguments,
WeightedRegionBranchOpInterface]> {
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 4705033945611..453dbd8a280a2 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -5109,9 +5109,16 @@ fir::IfOp::getSuccessorInputs(mlir::RegionSuccessor successor) {
return mlir::ValueRange();
}
-void fir::IfOp::getEntrySuccessorRegions(
- llvm::ArrayRef<mlir::Attribute> operands,
+void fir::IfOp::getSuccessorRegionsWithConstants(
+ mlir::RegionBranchPoint point,
+ const mlir::RegionBranchPointOperandConstants &operandConstants,
llvm::SmallVectorImpl<mlir::RegionSuccessor> &regions) {
+ llvm::ArrayRef<mlir::Attribute> operands =
+ operandConstants.getOperandConstants(point);
+ if (!point.isParent() || operands.empty()) {
+ getSuccessorRegions(point, regions);
+ return;
+ }
FoldAdaptor adaptor(operands);
auto boolAttr =
mlir::dyn_cast_or_null<mlir::BoolAttr>(adaptor.getCondition());
diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
index 7d9bb8907eb8b..12cc1add40ee3 100644
--- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
+++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
@@ -1471,7 +1471,7 @@ def EmitC_YieldOp : EmitC_Op<"yield",
def EmitC_IfOp : EmitC_Op<"if",
[DeclareOpInterfaceMethods<RegionBranchOpInterface, [
"getNumRegionInvocations", "getRegionInvocationBounds",
- "getEntrySuccessorRegions", "getSuccessorInputs"]>,
+ "getSuccessorRegionsWithConstants", "getSuccessorInputs"]>,
OpAsmOpInterface, SingleBlock,
SingleBlockImplicitTerminator<"emitc::YieldOp">,
RecursiveMemoryEffects, NoRegionArguments]> {
@@ -1587,7 +1587,7 @@ def EmitC_SwitchOp : EmitC_Op<"switch", [RecursiveMemoryEffects,
OpAsmOpInterface, SingleBlockImplicitTerminator<"emitc::YieldOp">,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getRegionInvocationBounds",
- "getEntrySuccessorRegions"]>]> {
+ "getSuccessorRegionsWithConstants"]>]> {
let summary = "Switch operation";
let description = [{
The `emitc.switch` is a control-flow operation that branches to one of
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 0b33ecb48b7f2..1afcdc4b51395 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -55,7 +55,7 @@ class SCF_Op<string mnemonic, list<Trait> traits = []> :
def ConditionOp : SCF_Op<"condition", [
HasParent<"WhileOp">,
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface,
- ["getSuccessorRegions", "getMutableSuccessorOperands"]>,
+ ["getMutableSuccessorOperands"]>,
Pure,
Terminator
]> {
@@ -704,7 +704,7 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
"getNumRegionInvocations", "getRegionInvocationBounds",
- "getEntrySuccessorRegions", "getSuccessorInputs"]>,
+ "getSuccessorRegionsWithConstants", "getSuccessorInputs"]>,
DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">,
RecursiveMemoryEffects, RecursivelySpeculatable, NoRegionArguments]> {
@@ -991,7 +991,8 @@ def ReduceReturnOp :
def WhileOp : SCF_Op<"while",
[DeclareOpInterfaceMethods<RegionBranchOpInterface,
- ["getEntrySuccessorOperands", "getSuccessorInputs"]>,
+ ["getEntrySuccessorOperands", "getSuccessorInputs",
+ "getSuccessorRegionsWithConstants"]>,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getRegionIterArgs", "getYieldedValuesMutable"]>,
DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
@@ -1147,7 +1148,7 @@ def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
DeclareOpInterfaceMethods<PromotableRegionOpInterface>,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getRegionInvocationBounds",
- "getEntrySuccessorRegions",
+ "getSuccessorRegionsWithConstants",
"getSuccessorInputs"]>]> {
let summary = "switch-case operation on an index argument";
let description = [{
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index a76dce6f2ffc5..8138e95e69d36 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -270,6 +270,51 @@ class RegionBranchPoint {
Operation *predecessor = nullptr;
};
+/// A bundle of constant-operand information for a `RegionBranchOpInterface`
+/// implementation and any of the region branch terminators in its regions.
+///
+/// This is used by
+/// `RegionBranchOpInterface::getSuccessorRegionsWithConstants` to allow
+/// analyses/transformations to pass along constant operand values for a given
+/// branch point so that the implementation can refine the returned successors.
+/// Operand constants are optional: an implementation that doesn't recognize
+/// the provided constants (or for which no constants are provided) must
+/// return the same successors as `getSuccessorRegions`.
+///
+/// For each branch point (the parent op or a region branch terminator), the
+/// associated `ArrayRef<Attribute>` either has one entry per operand of that
+/// op (null for non-constant operands) or is empty (no constant information
+/// available). The arrays are not copied: callers must keep them alive for as
+/// long as this object is used.
+class RegionBranchPointOperandConstants {
+public:
+ /// Default constructor: no constants known for any branch point.
+ RegionBranchPointOperandConstants() = default;
+
+ /// Set the constant operand information for the parent (region branch) op.
+ void setParentOperandConstants(ArrayRef<Attribute> constants) {
+ parentOperandConstants = constants;
+ }
+
+ /// Set the constant operand information for a specific region branch
+ /// terminator.
+ void setTerminatorOperandConstants(Operation *terminator,
+ ArrayRef<Attribute> constants) {
+ terminatorOperandConstants[terminator] = constants;
+ }
+
+ /// Returns the constant operand information for the given branch point.
+ /// Returns an empty range if no constants were provided for this point.
+ ArrayRef<Attribute> getOperandConstants(RegionBranchPoint point) const;
+
+private:
+ /// Constant operand information for the parent op of the region branch.
+ ArrayRef<Attribute> parentOperandConstants;
+
+ /// Constant operand information for region branch terminators.
+ DenseMap<Operation *, ArrayRef<Attribute>> terminatorOperandConstants;
+};
+
/// This class represents upper and lower bounds on the number of times a region
/// of a `RegionBranchOpInterface` can be invoked. The lower bound is at least
/// zero, but the upper bound may not be known.
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 06fa724e05fab..ead2a32139555 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -137,6 +137,13 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
specified with `getEntrySuccessorOperands` /
`RegionBranchTerminatorOpInterface::getSuccessorOperands`.
+ Implementations may also optionally refine the set of successors based on
+ constant operand values. To do so, implement
+ `getSuccessorRegionsWithConstants`, which takes a
+ `RegionBranchPointOperandConstants` argument; that method is invoked by
+ analyses that have constant-folded information. By default,
+ `getSuccessorRegionsWithConstants` dispatches to `getSuccessorRegions`.
+
A "region successor" indicates the target of a branch. It can indicate:
1. A region of this op.
2. `RegionSuccessor::parent()`, i.e., the control flow leaves this op.
@@ -183,7 +190,7 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
Returns the operands of this operation that are forwarded to the
successor inputs when branching to `successor`. `successor` is
guaranteed to be among the successors that are returned by
- `getEntrySuccessorRegions`/`getSuccessorRegions(parent())`.
+ `getSuccessorRegions(parent())`.
Example: In the above example, this method returns the operand %b of the
`scf.for` op, regardless of the value of `successor`. I.e., this op always
@@ -198,31 +205,39 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
}]
>,
InterfaceMethod<[{
- Returns the potential region successors when first executing the op.
+ Returns the potential region successors when branching from `point`.
+
+ This is a "constants-aware" variant of `getSuccessorRegions`. The
+ `operandConstants` parameter bundles up the constant values of the
+ operands of the parent (region branch) op and any of its region
+ branch terminators. Based on these, the implementation may filter out
+ certain successors that are statically known not to be taken.
- Unlike `getSuccessorRegions`, this method also passes along the
- constant operands of this op. Based on these, the implementation may
- filter out certain successors. By default, simply dispatches to
- `getSuccessorRegions`. `operands` contains an entry for every
- operand of this op, with a null attribute if the operand has no constant
- value.
+ By default, this method simply dispatches to `getSuccessorRegions`,
+ ignoring the constant operand information.
- Note: The control flow does not necessarily have to enter any region of
- this op.
+ Implementations should handle the case where no constant information
+ is available for a given branch point (i.e.,
+ `operandConstants.getOperandConstants(point)` returns an empty range)
+ by falling back to `getSuccessorRegions`.
Example: In the above example, this method may return two region
- region successors: the single region of the `scf.for` op and the
- `scf.for` operation (that implements this interface). If %lb, %ub, %step
- are constants and it can be determined the loop does not have any
- iterations, this method may choose to return only this operation.
- Similarly, if it can be determined that the loop has at least one
- iteration, this method may choose to return only the region of the loop.
+ successors for the "parent" branch point: the single region
+ of the `scf.for` op and the `scf.for` operation (that implements this
+ interface). If %lb, %ub, %step are constants and it can be determined
+ the loop does not have any iterations, this method may choose to return
+ only this operation. Similarly, if it can be determined that the loop
+ has at least one iteration, this method may choose to return only the
+ region of the loop.
}],
- "void", "getEntrySuccessorRegions",
- (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
- "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}],
+ "void", "getSuccessorRegionsWithConstants",
+ (ins "::mlir::RegionBranchPoint":$point,
+ "const ::mlir::RegionBranchPointOperandConstants &"
+ :$operandConstants,
+ "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions),
+ [{}],
/*defaultImplementation=*/[{
- $_op.getSuccessorRegions(mlir::RegionBranchPoint::parent(), regions);
+ $_op.getSuccessorRegions(point, regions);
}]
>,
InterfaceMethod<[{
@@ -423,6 +438,11 @@ def RegionBranchTerminatorOpInterface :
(However, there may be other block terminators in the same region that
implement the `RegionBranchTerminatorOpInterface`, so the enclosing region
may have region successors.)
+
+ To query the region successors that may be branched to after a given
+ terminator, use `RegionBranchOpInterface::getSuccessorRegions(point, ...)` /
+ `getSuccessorRegionsWithConstants(point, ...)` on the parent op (passing the
+ terminator as the branch point).
}];
let cppNamespace = "::mlir";
@@ -438,26 +458,6 @@ def RegionBranchTerminatorOpInterface :
return ::mlir::MutableOperandRange($_op);
}]
>,
- InterfaceMethod<[{
- Returns the potential region successors that are branched to after this
- terminator based on the given constant operands.
-
- This method also passes along the constant operands of this op.
- `operands` contains an entry for every operand of this op, with a null
- attribute if the operand has no constant value.
-
- The default implementation simply dispatches to the parent
- `RegionBranchOpInterface`'s `getSuccessorRegions` implementation.
- }],
- "void", "getSuccessorRegions",
- (ins "::llvm::ArrayRef<::mlir::Attribute>":$operands,
- "::llvm::SmallVectorImpl<::mlir::RegionSuccessor> &":$regions), [{}],
- /*defaultImplementation=*/[{
- ::mlir::Operation *op = $_op;
- ::llvm::cast<::mlir::RegionBranchOpInterface>(op->getParentOp())
- .getSuccessorRegions(::llvm::cast<::mlir::RegionBranchTerminatorOpInterface>(op), regions);
- }]
- >,
];
let verify = [{
diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
index 38811d06ecd8c..8cf323d13589d 100644
--- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp
@@ -497,7 +497,10 @@ void DeadCodeAnalysis::visitRegionBranchOperation(
return;
SmallVector<RegionSuccessor> successors;
- branch.getEntrySuccessorRegions(*operands, successors);
+ RegionBranchPointOperandConstants operandConstants;
+ operandConstants.setParentOperandConstants(*operands);
+ branch.getSuccessorRegionsWithConstants(RegionBranchPoint::parent(),
+ operandConstants, successors);
visitRegionBranchEdges(branch, branch.getOperation(), successors);
}
@@ -513,7 +516,10 @@ void DeadCodeAnalysis::visitRegionTerminator(Operation *op,
auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op);
if (!terminator)
return;
- terminator.getSuccessorRegions(*operands, successors);
+ RegionBranchPointOperandConstants operandConstants;
+ operandConstants.setTerminatorOperandConstants(terminator, *operands);
+ branch.getSuccessorRegionsWithConstants(RegionBranchPoint(terminator),
+ operandConstants, successors);
visitRegionBranchEdges(branch, op, successors);
}
diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
index 90f2a588d1ca4..5a2ac40f311d8 100644
--- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp
@@ -617,7 +617,10 @@ void AbstractSparseBackwardDataFlowAnalysis::visitRegionSuccessors(
Operation *op = branch.getOperation();
SmallVector<RegionSuccessor> successors;
SmallVector<Attribute> operands(op->getNumOperands(), nullptr);
- branch.getEntrySuccessorRegions(operands, successors);
+ RegionBranchPointOperandConstants operandConstants;
+ operandConstants.setParentOperandConstants(operands);
+ branch.getSuccessorRegionsWithConstants(RegionBranchPoint::parent(),
+ operandConstants, successors);
for (RegionSuccessor &successor : successors) {
if (successor.isParent())
continue;
diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
index 8def84fc49378..dc0ed005526b2 100644
--- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
+++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp
@@ -991,8 +991,15 @@ ValueRange IfOp::getSuccessorInputs(RegionSuccessor successor) {
: ValueRange();
}
-void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
- SmallVectorImpl<RegionSuccessor> &regions) {
+void IfOp::getSuccessorRegionsWithConstants(
+ RegionBranchPoint point,
+ const RegionBranchPointOperandConstants &operandConstants,
+ SmallVectorImpl<RegionSuccessor> &regions) {
+ ArrayRef<Attribute> operands = operandConstants.getOperandConstants(point);
+ if (!point.isParent() || operands.empty()) {
+ getSuccessorRegions(point, regions);
+ return;
+ }
FoldAdaptor adaptor(operands, *this);
auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
if (!boolAttr || boolAttr.getValue())
@@ -1564,9 +1571,15 @@ static std::optional<int64_t> getIntAttrValue(IntegerAttr attr) {
return std::nullopt;
}
-void SwitchOp::getEntrySuccessorRegions(
- ArrayRef<Attribute> operands,
+void SwitchOp::getSuccessorRegionsWithConstants(
+ RegionBranchPoint point,
+ const RegionBranchPointOperandConstants &operandConstants,
SmallVectorImpl<RegionSuccessor> &successors) {
+ ArrayRef<Attribute> operands = operandConstants.getOperandConstants(point);
+ if (!point.isParent() || operands.empty()) {
+ getSuccessorRegions(point, successors);
+ return;
+ }
FoldAdaptor adaptor(operands, *this);
// If a constant was not provided, all regions are possible successors.
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 9f4f4dc9f58e6..e632077de678a 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -296,21 +296,6 @@ ConditionOp::getMutableSuccessorOperands(RegionSuccessor point) {
return getArgsMutable();
}
-void ConditionOp::getSuccessorRegions(
- ArrayRef<Attribute> operands, SmallVectorImpl<RegionSuccessor> &regions) {
- FoldAdaptor adaptor(operands, *this);
-
- WhileOp whileOp = getParentOp();
-
- // Condition can either lead to the after region or back to the parent op
- // depending on whether the condition is true or not.
- auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
- if (!boolAttr || boolAttr.getValue())
- regions.emplace_back(&whileOp.getAfter());
- if (!boolAttr || !boolAttr.getValue())
- regions.push_back(RegionSuccessor::parent());
-}
-
//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//
@@ -2120,8 +2105,16 @@ ValueRange IfOp::getSuccessorInputs(RegionSuccessor successor) {
: ValueRange();
}
-void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
- SmallVectorImpl<RegionSuccessor> &regions) {
+void IfOp::getSuccessorRegionsWithConstants(
+ RegionBranchPoint point,
+ const RegionBranchPointOperandConstants &operandConstants,
+ SmallVectorImpl<RegionSuccessor> &regions) {
+ // Constant folding only applies when entering from the parent op.
+ ArrayRef<Attribute> operands = operandConstants.getOperandConstants(point);
+ if (!point.isParent() || operands.empty()) {
+ getSuccessorRegions(point, regions);
+ return;
+ }
FoldAdaptor adaptor(operands, *this);
auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
if (!boolAttr || boolAttr.getValue())
@@ -3291,6 +3284,30 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point,
regions.emplace_back(&getAfter());
}
+void WhileOp::getSuccessorRegionsWithConstants(
+ RegionBranchPoint point,
+ const RegionBranchPointOperandConstants &operandConstants,
+ SmallVectorImpl<RegionSuccessor> &regions) {
+ // Constant folding only applies when branching from the `scf.condition`
+ // terminator of the "before" region. For all other branch points, fall back
+ // to the unfiltered behavior.
+ auto conditionOp =
+ dyn_cast_or_null<ConditionOp>(point.getTerminatorPredecessorOrNull());
+ ArrayRef<Attribute> operands = operandConstants.getOperandConstants(point);
+ if (!conditionOp || operands.empty()) {
+ getSuccessorRegions(point, regions);
+ return;
+ }
+ ConditionOp::FoldAdaptor adaptor(operands, conditionOp);
+ // Condition can either lead to the after region or back to the parent op
+ // depending on whether the condition is true or not.
+ auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
+ if (!boolAttr || boolAttr.getValue())
+ regions.emplace_back(&getAfter());
+ if (!boolAttr || !boolAttr.getValue())
+ regions.push_back(RegionSuccessor::parent());
+}
+
ValueRange WhileOp::getSuccessorInputs(RegionSuccessor successor) {
if (successor.isParent())
return getOperation()->getResults();
@@ -3842,9 +3859,16 @@ ValueRange IndexSwitchOp::getSuccessorInputs(RegionSuccessor successor) {
: ValueRange();
}
-void IndexSwitchOp::getEntrySuccessorRegions(
- ArrayRef<Attribute> operands,
+void IndexSwitchOp::getSuccessorRegionsWithConstants(
+ RegionBranchPoint point,
+ const RegionBranchPointOperandConstants &operandConstants,
SmallVectorImpl<RegionSuccessor> &successors) {
+ // Constant folding only applies when entering from the parent op.
+ ArrayRef<Attribute> operands = operandConstants.getOperandConstants(point);
+ if (!point.isParent() || operands.empty()) {
+ getSuccessorRegions(point, successors);
+ return;
+ }
FoldAdaptor adaptor(operands, *this);
// If a constant was not provided, all regions are possible successors.
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
index 7d48315eec6ff..b919d46a07ca7 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutImpl.cpp
@@ -174,7 +174,10 @@ static void propagateRegionResultsToYieldOperands(
SmallVector<RegionSuccessor> successors;
SmallVector<Attribute> operandAttrs(yieldOp->getNumOperands(), nullptr);
- yieldOp.getSuccessorRegions(operandAttrs, successors);
+ RegionBranchPointOperandConstants operandConstants;
+ operandConstants.setTerminatorOperandConstants(yieldOp, operandAttrs);
+ regionBranchOp.getSuccessorRegionsWithConstants(RegionBranchPoint(yieldOp),
+ operandConstants, successors);
for (const RegionSuccessor &successor : successors) {
OperandRange succOps = yieldOp.getSuccessorOperands(successor);
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index c3fb73acf5ef0..686a8c3b15484 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -152,6 +152,21 @@ LogicalResult detail::verifyRegionBranchWeights(Operation *op) {
return verifyWeights(op, weights, op->getNumRegions(), "region", "regions");
}
+//===----------------------------------------------------------------------===//
+// RegionBranchPointOperandConstants
+//===----------------------------------------------------------------------===//
+
+ArrayRef<Attribute> RegionBranchPointOperandConstants::getOperandConstants(
+ RegionBranchPoint point) const {
+ if (point.isParent())
+ return parentOperandConstants;
+ Operation *terminator = point.getTerminatorPredecessorOrNull();
+ auto it = terminatorOperandConstants.find(terminator);
+ if (it == terminatorOperandConstants.end())
+ return {};
+ return it->second;
+}
+
//===----------------------------------------------------------------------===//
// RegionBranchOpInterface
//===----------------------------------------------------------------------===//
@@ -246,7 +261,6 @@ static bool traverseRegionGraph(Region *begin,
SmallVector<Region *> worklist;
auto enqueueAllSuccessors = [&](Region *region) {
LDBG() << "Enqueuing successors for region #" << region->getRegionNumber();
- SmallVector<Attribute> operandAttributes(op->getNumOperands());
for (Block &block : *region) {
if (block.empty())
continue;
@@ -255,8 +269,9 @@ static bool traverseRegionGraph(Region *begin,
if (!terminator)
continue;
SmallVector<RegionSuccessor> successors;
- operandAttributes.resize(terminator->getNumOperands());
- terminator.getSuccessorRegions(operandAttributes, successors);
+ // No constant operand information is provided here; the unfiltered set
+ // of successors is sufficient for the region-graph traversal.
+ op.getSuccessorRegions(RegionBranchPoint(terminator), successors);
LDBG() << "Found " << successors.size()
<< " successors from terminator in block";
for (RegionSuccessor successor : successors) {
@@ -1088,15 +1103,18 @@ static SmallVector<RegionSuccessor>
getSuccessorRegionsWithAttrs(RegionBranchOpInterface op,
RegionBranchPoint point) {
SmallVector<RegionSuccessor> successors;
+ RegionBranchPointOperandConstants operandConstants;
+ SmallVector<Attribute> constants;
if (point.isParent()) {
- op.getEntrySuccessorRegions(extractConstants(op->getOperands()),
- successors);
- return successors;
+ constants = extractConstants(op->getOperands());
+ operandConstants.setParentOperandConstants(constants);
+ } else {
+ RegionBranchTerminatorOpInterface terminator =
+ point.getTerminatorPredecessorOrNull();
+ constants = extractConstants(terminator->getOperands());
+ operandConstants.setTerminatorOperandConstants(terminator, constants);
}
- RegionBranchTerminatorOpInterface terminator =
- point.getTerminatorPredecessorOrNull();
- terminator.getSuccessorRegions(extractConstants(terminator->getOperands()),
- successors);
+ op.getSuccessorRegionsWithConstants(point, operandConstants, successors);
return successors;
}
diff --git a/mlir/test/Transforms/sccp.mlir b/mlir/test/Transforms/sccp.mlir
index f5a5183593d03..ec61644dbb5ea 100644
--- a/mlir/test/Transforms/sccp.mlir
+++ b/mlir/test/Transforms/sccp.mlir
@@ -290,8 +290,9 @@ func.func @no_crash_acc_kernel_environment(%data: memref<8xi32>) {
// -----
// Regression test for https://github.com/llvm/llvm-project/issues/187973
-// SwitchOp::getEntrySuccessorRegions must not call IntegerAttr::getInt() on
-// an unsigned integer type — that function asserts signless/index only.
+// SwitchOp::getSuccessorRegionsWithConstants must not call
+// IntegerAttr::getInt() on an unsigned integer type — that function asserts
+// signless/index only.
// CHECK-LABEL: no_crash_emitc_switch_unsigned_condition
func.func @no_crash_emitc_switch_unsigned_condition() {
More information about the Mlir-commits
mailing list