[Mlir-commits] [mlir] [MLIR] Introduce support for early exits (PR #166688)
Mehdi Amini
llvmlistbot at llvm.org
Fri Feb 20 06:44:33 PST 2026
https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/166688
From 6b05a50ed26d89ccb5077733840e80e80c6c8106 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Thu, 19 Feb 2026 04:45:19 -0800
Subject: [PATCH 1/2] [MLIR] format LangRef.md (NFC)
---
mlir/docs/LangRef.md | 103 +++++++++++++++++++++----------------------
1 file changed, 51 insertions(+), 52 deletions(-)
diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
index b1da4b9360592..5e53df83997e2 100644
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -23,7 +23,7 @@ transformations and analysis, and a compact serialized form suitable for storage
and transport. The different forms all describe the same semantic content. This
document describes the human-readable textual form.
-[TOC]
+\[TOC\]
## High-Level Structure
@@ -180,7 +180,6 @@ string-literal ::= `"` [^"\n\f\v\r]* `"` TODO: define escaping rules
Not listed here, but MLIR does support comments. They use standard BCPL syntax,
starting with a `//` and going until the end of the line.
-
### Top level Productions
```
@@ -188,8 +187,8 @@ starting with a `//` and going until the end of the line.
toplevel := (operation | attribute-alias-def | type-alias-def)*
```
-The production `toplevel` is the top level production that is parsed by any parsing
-consuming the MLIR syntax. [Operations](#operations),
+The production `toplevel` is the top level production that is parsed by any
+parser consuming the MLIR syntax. [Operations](#operations),
[Attribute aliases](#attribute-value-aliases), and [Type aliases](#type-aliases)
can be declared on the toplevel.
@@ -233,7 +232,7 @@ body. Particular operations may further limit which identifiers are in scope in
their regions. For instance, the scope of values in a region with
[SSA control flow semantics](#control-flow-and-ssacfg-regions) is constrained
according to the standard definition of
-[SSA dominance](https://en.wikipedia.org/wiki/Dominator_\(graph_theory\)).
+[SSA dominance](<https://en.wikipedia.org/wiki/Dominator_(graph_theory)>).
Another example is the [IsolatedFromAbove trait](Traits/#isolatedfromabove),
which restricts directly accessing values defined in containing regions.
@@ -257,12 +256,12 @@ between, and within, different dialects.
A few of the dialects supported by MLIR:
-* [Affine dialect](Dialects/Affine.md)
-* [Func dialect](Dialects/Func.md)
-* [GPU dialect](Dialects/GPU.md)
-* [LLVM dialect](Dialects/LLVM.md)
-* [SPIR-V dialect](Dialects/SPIR-V.md)
-* [Vector dialect](Dialects/Vector.md)
+- [Affine dialect](Dialects/Affine.md)
+- [Func dialect](Dialects/Func.md)
+- [GPU dialect](Dialects/GPU.md)
+- [LLVM dialect](Dialects/LLVM.md)
+- [SPIR-V dialect](Dialects/SPIR-V.md)
+- [Vector dialect](Dialects/Vector.md)
### Target specific operations
@@ -319,8 +318,8 @@ identified by a unique string (e.g. `dim`, `tf.Conv2d`, `x86.repmovsb`,
has storage for [properties](#properties), has a dictionary of
[attributes](#attributes), has zero or more successors, and zero or more
enclosed [regions](#regions). The generic printing form includes all these
-elements literally, with a function type to indicate the types of the
-results and operands.
+elements literally, with a function type to indicate the types of the results
+and operands.
Example:
@@ -424,12 +423,12 @@ func.func @simple(i64, i1) -> i64 {
**Context:** The "block argument" representation eliminates a number of special
cases from the IR compared to traditional "PHI nodes are operations" SSA IRs
(like LLVM). For example, the
-[parallel copy semantics](https://ieeexplore.ieee.org/document/4907656)
-of SSA is immediately apparent, and function arguments are no longer a special
-case: they become arguments to the entry block
-[[more rationale](Rationale/Rationale.md/#block-arguments-vs-phi-nodes)]. Blocks
-are also a fundamental concept that cannot be represented by operations because
-values defined in an operation cannot be accessed outside the operation.
+[parallel copy semantics](https://ieeexplore.ieee.org/document/4907656) of SSA
+is immediately apparent, and function arguments are no longer a special case:
+they become arguments to the entry block
+\[[more rationale](Rationale/Rationale.md/#block-arguments-vs-phi-nodes)\].
+Blocks are also a fundamental concept that cannot be represented by operations
+because values defined in an operation cannot be accessed outside the operation.
## Regions
@@ -440,8 +439,9 @@ region is not imposed by the IR. Instead, the containing operation defines the
semantics of the regions it contains. MLIR currently defines two kinds of
regions: [SSACFG regions](#control-flow-and-ssacfg-regions), which describe
control flow between blocks, and [Graph regions](#graph-regions), which do not
-require control flow between blocks. The kinds of regions within an operation are
-described using the [RegionKindInterface](Interfaces.md/#regionkindinterfaces).
+require control flow between blocks. The kinds of regions within an operation
+are described using the
+[RegionKindInterface](Interfaces.md/#regionkindinterfaces).
Regions do not have a name or an address, only the blocks contained in a region
do. Regions must be contained within operations and have no type or attributes.
@@ -463,10 +463,9 @@ arguments must match the result types of the function signature. Similarly, the
function arguments must match the types and count of the region arguments. In
general, operations with regions can define these correspondences arbitrarily.
-An *entry block* is a block with no label and no arguments that may occur at
-the beginning of a region. It enables a common pattern of using a region to
-open a new scope.
-
+An *entry block* is a block with no label and no arguments that may occur at the
+beginning of a region. It enables a common pattern of using a region to open a
+new scope.
### Value Scoping
@@ -478,8 +477,7 @@ the enclosing region, if any. By default, operations inside a region can
reference values defined outside of the region whenever it would have been legal
for operands of the enclosing operation to reference those values, but this can
be restricted using traits, such as
-[OpTrait::IsolatedFromAbove](Traits/#isolatedfromabove), or a custom
-verifier.
+[OpTrait::IsolatedFromAbove](Traits/#isolatedfromabove), or a custom verifier.
Example:
@@ -707,9 +705,9 @@ dialect-type-contents ::= dialect-type-body
| [^\[<({\]>)}\0]+
```
-Dialect types are generally specified in an opaque form, where the contents
-of the type are defined within a body wrapped with the dialect namespace
-and `<>`. Consider the following examples:
+Dialect types are generally specified in an opaque form, where the contents of
+the type are defined within a body wrapped with the dialect namespace and `<>`.
+Consider the following examples:
```mlir
// A tensorflow string type.
@@ -733,7 +731,8 @@ part of the syntax into an equivalent, but lighter weight form:
!foo.something<abcd>
```
-See [here](DefiningDialects/AttributesAndTypes.md) to learn how to define dialect types.
+See [here](DefiningDialects/AttributesAndTypes.md) to learn how to define
+dialect types.
### Builtin Types
@@ -761,29 +760,28 @@ attribute-value ::= attribute-alias | dialect-attribute | builtin-attribute
Attributes are the mechanism for specifying constant data on operations in
places where a variable is never allowed - e.g. the comparison predicate of a
-[`cmpi` operation](Dialects/ArithOps.md/#arithcmpi-arithcmpiop). Each operation has an
-attribute dictionary, which associates a set of attribute names to attribute
-values. MLIR's builtin dialect provides a rich set of
+[`cmpi` operation](Dialects/ArithOps.md/#arithcmpi-arithcmpiop). Each operation
+has an attribute dictionary, which associates a set of attribute names to
+attribute values. MLIR's builtin dialect provides a rich set of
[builtin attribute values](#builtin-attribute-values) out of the box (such as
arrays, dictionaries, strings, etc.). Additionally, dialects can define their
own [dialect attribute values](#dialect-attribute-values).
For dialects which haven't adopted properties yet, the top-level attribute
-dictionary attached to an operation has special semantics. The attribute
-entries are considered to be of two different kinds based on whether their
-dictionary key has a dialect prefix:
-
-- *inherent attributes* are inherent to the definition of an operation's
- semantics. The operation itself is expected to verify the consistency of
- these attributes. An example is the `predicate` attribute of the
- `arith.cmpi` op. These attributes must have names that do not start with a
- dialect prefix.
-
-- *discardable attributes* have semantics defined externally to the operation
- itself, but must be compatible with the operations's semantics. These
- attributes must have names that start with a dialect prefix. The dialect
- indicated by the dialect prefix is expected to verify these attributes. An
- example is the `gpu.container_module` attribute.
+dictionary attached to an operation has special semantics. The attribute entries
+are considered to be of two different kinds based on whether their dictionary
+key has a dialect prefix:
+
+- *inherent attributes* are inherent to the definition of an operation's
+ semantics. The operation itself is expected to verify the consistency of these
+ attributes. An example is the `predicate` attribute of the `arith.cmpi` op.
+ These attributes must have names that do not start with a dialect prefix.
+
+- *discardable attributes* have semantics defined externally to the operation
+ itself, but must be compatible with the operation's semantics. These
+ attributes must have names that start with a dialect prefix. The dialect
+ indicated by the dialect prefix is expected to verify these attributes. An
+ example is the `gpu.container_module` attribute.
Note that attribute values are allowed to themselves be dictionary attributes,
but only the top-level dictionary attribute attached to the operation is subject
@@ -851,15 +849,16 @@ and `<>`. Consider the following examples:
#foo<"a123^^^" + bar>
```
-Dialect attributes that are simple enough may use a prettier format, which unwraps
-part of the syntax into an equivalent, but lighter weight form:
+Dialect attributes that are simple enough may use a prettier format, which
+unwraps part of the syntax into an equivalent, but lighter weight form:
```mlir
// A string attribute.
#foo.string<"">
```
-See [here](DefiningDialects/AttributesAndTypes.md) on how to define dialect attribute values.
+See [here](DefiningDialects/AttributesAndTypes.md) on how to define dialect
+attribute values.
### Builtin Attribute Values
From 79ed734d7c1c82370ac55a2a1328c8458c212010 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 26 Apr 2025 04:51:22 -0700
Subject: [PATCH 2/2] [MLIR] Introduce support for early exits
WIP, mostly lacking documentation, possibly more dataflow
fixes as well.
Need to revisit the traits to support this, and better formalize region
termination and the requirements on the numBreakingRegion index.
Right now inlining is incompatible with early returns; it would
require inserting an scf.execute_region somehow.
---
mlir/docs/LangRef.md | 128 ++-
mlir/include/mlir/Dialect/SCF/IR/SCF.h | 45 ++
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 195 ++++-
mlir/include/mlir/IR/Diagnostics.h | 2 +-
mlir/include/mlir/IR/OpBase.td | 11 +
mlir/include/mlir/IR/OpDefinition.h | 16 +-
mlir/include/mlir/IR/Operation.h | 72 +-
mlir/include/mlir/IR/OperationSupport.h | 14 +
mlir/include/mlir/IR/RegionKindInterface.h | 145 ++++
mlir/include/mlir/IR/RegionKindInterface.td | 58 ++
mlir/lib/AsmParser/Parser.cpp | 49 +-
.../SCFToControlFlow/SCFToControlFlow.cpp | 144 +++-
mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp | 2 +-
.../ShapeToStandard/ShapeToStandard.cpp | 3 +-
.../OwnershipBasedBufferDeallocation.cpp | 8 +-
.../GPU/Transforms/AsyncRegionRewriter.cpp | 3 +-
.../Quant/Transforms/NormalizeQuantTypes.cpp | 3 +-
mlir/lib/Dialect/SCF/IR/SCF.cpp | 746 ++++++++++++++++--
.../SCF/IR/ValueBoundsOpInterfaceImpl.cpp | 4 +-
.../BufferizableOpInterfaceImpl.cpp | 4 +-
.../TosaConvertIntegerTypeToSignless.cpp | 3 +-
mlir/lib/IR/AsmPrinter.cpp | 4 +
mlir/lib/IR/Diagnostics.cpp | 4 +-
mlir/lib/IR/Dominance.cpp | 30 +
mlir/lib/IR/Operation.cpp | 45 +-
mlir/lib/IR/PatternMatch.cpp | 5 +-
mlir/lib/IR/RegionKindInterface.cpp | 156 ++++
mlir/lib/IR/Verifier.cpp | 41 +
mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 21 +-
mlir/lib/TableGen/Operator.cpp | 10 +-
mlir/lib/Transforms/Utils/CMakeLists.txt | 3 +
mlir/lib/Transforms/Utils/InliningUtils.cpp | 7 +
mlir/test/Analysis/test-dominance.mlir | 26 +
.../convert-early-exit-to-cfg.mlir | 84 ++
mlir/test/Dialect/SCF/loop_canonicalize.mlir | 331 ++++++++
mlir/test/IR/early-exit-invalid.mlir | 67 ++
mlir/test/IR/early-exit.mlir | 77 ++
.../Integration/Dialect/SCF/early_exit.mlir | 82 ++
mlir/test/lib/IR/TestDominance.cpp | 33 +
mlir/test/lib/Interfaces/CMakeLists.txt | 1 +
.../RegionBranchOpInterface/CMakeLists.txt | 9 +
.../TestRegionBranchOpInterface.cpp | 76 ++
mlir/test/mlir-tblgen/op-error.td | 2 +-
mlir/tools/mlir-opt/CMakeLists.txt | 1 +
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
mlir/tools/mlir-tblgen/FormatGen.cpp | 1 +
mlir/tools/mlir-tblgen/FormatGen.h | 2 +
mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 14 +
mlir/tools/mlir-tblgen/OpFormatGen.cpp | 46 ++
.../FileLineColLocBreakpointManagerTest.cpp | 3 +-
mlir/unittests/IR/OperationSupportTest.cpp | 5 +-
mlir/unittests/IR/ValueTest.cpp | 3 +-
.../Transforms/DialectConversion.cpp | 3 +-
53 files changed, 2666 insertions(+), 183 deletions(-)
create mode 100644 mlir/test/Conversion/SCFToControlFlow/convert-early-exit-to-cfg.mlir
create mode 100644 mlir/test/Dialect/SCF/loop_canonicalize.mlir
create mode 100644 mlir/test/IR/early-exit-invalid.mlir
create mode 100644 mlir/test/IR/early-exit.mlir
create mode 100644 mlir/test/Integration/Dialect/SCF/early_exit.mlir
create mode 100644 mlir/test/lib/Interfaces/RegionBranchOpInterface/CMakeLists.txt
create mode 100644 mlir/test/lib/Interfaces/RegionBranchOpInterface/TestRegionBranchOpInterface.cpp
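As a reading aid for the `num-breaking-regions` rules documented in the LangRef
hunk below, here is a minimal sketch of how a terminator's exit count could be
resolved to its receiving operation. It is a toy model under stated assumptions,
not the verifier from this patch: `ToyOp` and `resolveBreakTarget` are
hypothetical names, and the two booleans merely stand in for
`PropagateControlFlowBreak` and `HasBreakingControlFlowOpInterface`.

```cpp
#include <cassert>
#include <string>

// Toy stand-in for an operation (hypothetical; this is NOT mlir::Operation).
struct ToyOp {
  std::string name;
  bool propagatesBreak; // models the PropagateControlFlowBreak trait
  bool handlesBreak;    // models HasBreakingControlFlowOpInterface
  const ToyOp *parent;  // op owning the region this op lives in, or nullptr
};

// Given the op that immediately encloses a region terminator and the
// terminator's exit count `n`, return the op that receives control, or
// nullptr if the nesting does not satisfy the documented rules.
const ToyOp *resolveBreakTarget(const ToyOp *immediateParent, unsigned n) {
  if (n == 0 || immediateParent == nullptr)
    return nullptr;
  const ToyOp *op = immediateParent;
  // Levels 1 .. n-1 are crossed transparently, so each op must propagate.
  for (unsigned level = 1; level < n; ++level) {
    if (!op->propagatesBreak)
      return nullptr;
    op = op->parent;
    if (op == nullptr)
      return nullptr; // ran out of enclosing operations
  }
  // Level n is the receiver. n == 1 is the normal region exit; for n > 1 the
  // receiver must be able to handle a break coming from a nested region.
  if (n > 1 && !op->handlesBreak)
    return nullptr;
  return op;
}
```

With an `scf.if`-like op nested in an inner loop nested in an outer loop, as in
the LangRef example, counts 2 and 3 resolve to the inner and outer loop
respectively, and a count of 4 is rejected because the outer loop in that
example does not propagate.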
diff --git a/mlir/docs/LangRef.md b/mlir/docs/LangRef.md
index 5e53df83997e2..acf16ee67f955 100644
--- a/mlir/docs/LangRef.md
+++ b/mlir/docs/LangRef.md
@@ -296,7 +296,9 @@ generic-operation ::= string-literal `(` value-use-list? `)` successor-list
custom-operation ::= bare-id custom-operation-format
op-result-list ::= op-result (`,` op-result)* `=`
op-result ::= value-id (`:` integer-literal)?
-successor-list ::= `[` successor (`,` successor)* `]`
+successor-list ::= `[` successor-list-inner `]`
+successor-list-inner ::= successor (`,` successor)* | num-breaking-regions
+num-breaking-regions ::= integer-literal
successor ::= caret-id (`:` block-arg-list)?
dictionary-properties ::= `<` dictionary-attribute `>`
region-list ::= `(` region (`,` region)* `)`
@@ -504,10 +506,20 @@ In MLIR, control flow semantics of a region is indicated by
regions support semantics where operations in a region 'execute sequentially'.
Before an operation executes, its operands have well-defined values. After an
operation executes, the operands have the same values and results also have
-well-defined values. After an operation executes, the next operation in the
-block executes until the operation is the terminator operation at the end of a
-block, in which case some other operation will execute. The determination of the
-next instruction to execute is the 'passing of control flow'.
+well-defined values.
+
+Usually, after an operation executes, the next operation in the block executes
+until the operation is the terminator operation at the end of a block, in which
+case control will be transferred to another block or one of the parent
+operations. The determination of the next instruction to execute is the 'passing
+of control flow'. Control flow can be interrupted by an operation if it
+defines the `PropagateControlFlowBreak` trait. Such an operation does not handle
+the break itself; it transparently propagates it outward to an ancestor
+operation that implements `HasBreakingControlFlowOpInterface`. The actual break
+is initiated by a nested [Region Terminator](#region-terminator). Every
+operation between the `RegionTerminator` and the receiving
+`HasBreakingControlFlowOpInterface` operation must define
+`PropagateControlFlowBreak`.
In general, when control flow is passed to an operation, MLIR does not restrict
when control flow enters or exits the regions contained in that operation.
@@ -515,22 +527,23 @@ However, when control flow enters a region, it always begins in the first block
of the region, called the *entry* block. Terminator operations ending each block
represent control flow by explicitly specifying the successor blocks of the
block. Control flow can only pass to one of the specified successor blocks as in
-a `branch` operation, or back to the containing operation as in a `return`
-operation. Terminator operations without successors can only pass control back
-to the containing operation. Within these restrictions, the particular semantics
-of terminator operations is determined by the specific dialect operations
-involved. Blocks (other than the entry block) that are not listed as a successor
-of a terminator operation are defined to be unreachable and can be removed
-without affecting the semantics of the containing operation.
+a `branch` operation, or back to one of the enclosing parent operations, as in a
+`return` operation. Terminator operations without block successors can only pass
+control back to one of the enclosing parent operations; in this case an integer
+gives the number of parent regions to break through. Within these restrictions,
+the particular semantics of terminator operations is determined by the specific
+dialect operations involved. Blocks (other than the entry block) that are not
+listed as a successor of a terminator operation are defined to be unreachable
+and can be removed without affecting the semantics of the containing operation.
Although control flow always enters a region through the entry block, control
flow may exit a region through any block with an appropriate terminator. The
-standard dialect leverages this capability to define operations with
+LLVM dialect, for example, leverages this capability to define operations with
Single-Entry-Multiple-Exit (SEME) regions, possibly flowing through different
blocks in the region and exiting through any block with a `return` operation.
-This behavior is similar to that of a function body in most programming
-languages. In addition, control flow may also not reach the end of a block or
-region, for example if a function call does not return.
+This behavior can model that of a function body in most programming languages.
+In addition, control flow may also not reach the end of a block or region, for
+example if a function call does not return.
Example:
@@ -558,6 +571,89 @@ func.func @accelerator_compute(i64, i1) -> i64 { // An SSACFG region
}
```
+#### Region Terminator
+
+A `RegionTerminator` is a specialization of a block terminator (the `Terminator`
+trait) that transfers the control back to a parent operation. It can exit
+multiple nested regions in a single step, bypassing the normal
+`RegionBranchOpInterface` exit path for every intermediate level. In the generic
+operation format, the exit count appears in the successor-list brackets as a
+plain integer rather than a block label (see `num-breaking-regions` in the
+grammar above). Custom assembly formats may surface this as a literal integer
+argument (e.g. `scf.break 2`).
+
+`num-breaking-regions = N` means the operation exits **N region levels** in
+total, counting its own immediately enclosing region as 1:
+
+- `N = 1`: normal region exit — control is returned to the immediate parent
+ operation (e.g. `scf.yield` in `scf.if`, or `scf.break 1` in `scf.loop`).
+- `N = 2`: exits the current region and the immediate ancestor region; the
+ parent op must define `PropagateControlFlowBreak`.
+- `N = K`: exits K region levels; the K-1 intermediate parent operations
+ (levels 1 .. K-1) must all define `PropagateControlFlowBreak`.
+
+If the outermost operation *receives* a break from a `RegionTerminator` that is
+not in one of its immediate regions (the `N = 1` case above), then it must
+implement `HasBreakingControlFlowOpInterface`. This is distinct from
+intermediate ops that merely propagate the break (`PropagateControlFlowBreak`).
+For example, a loop operation nested inside the body of another loop operation
+may carry both traits simultaneously: it handles breaks targeting itself, and
+propagates breaks that target an outer loop through it.
+
+Region terminators may carry values, which are propagated to the target
+operation. For example, when breaking out of a loop that produces results, the
+terminator supplies those result values directly.
+
+Examples:
+
+```mlir
+// scf.yield is the standard region terminator (num-breaking-regions = 1).
+// It exits only its own immediately enclosing region.
+scf.if %cond {
+ scf.yield // returns control to scf.if, its immediate parent operation
+}
+```
+
+```mlir
+// Trait legend:
+// [H] = HasBreakingControlFlowOpInterface (receives/catches the break)
+// [P] = PropagateControlFlowBreak (passes the break upward unchanged)
+// [H][P] = both: handles breaks targeting it, and propagates breaks that
+// target an outer loop through it
+scf.loop { // [H]
+ scf.loop { // [H][P]
+ scf.if %cond1 { // [P]
+ // Exits if-region (1) + inner-loop-region (2) → breaks inner loop.
+ scf.break 2
+ }
+ scf.if %cond2 { // [P]
+ // Exits if-region (1) + inner-loop-region (2) + outer-loop-region (3)
+ // → breaks outer loop.
+ scf.break 3
+ }
+ scf.if %cond3 { // [P]
+ // Exits if-region (1) + inner-loop-region (2), re-entering inner loop.
+ scf.continue 2
+ }
+ }
+}
+return
+```
+
+```mlir
+// A loop that yields a result value on early exit.
+// scf.break N carries operands that become the loop's results.
+// scf.continue N carries operands that become the next iter_args.
+%result = scf.loop -> f32 { // [H]
+ scf.if %found { // [P]
+ // Exits if-region + loop-region; %value becomes the loop result.
+ scf.break 2 %value : f32
+ }
+ // Re-enter the loop for the next iteration (no iter_args here).
+ scf.continue 1
+}
+```
+
#### Operations with Multiple Regions
An operation containing multiple regions also completely determines the
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index e754a04b0903a..573586d8671ad 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -29,6 +29,11 @@
namespace mlir {
namespace scf {
void buildTerminatedBody(OpBuilder &builder, Location loc);
+
+namespace op_impl {
+struct IfOpImplicitTerminatorType;
+struct LoopOpImplicitTerminatorType;
+} // namespace op_impl
} // namespace scf
} // namespace mlir
@@ -111,6 +116,46 @@ SmallVector<Value> replaceAndCastForOpIterArg(RewriterBase &rewriter,
OpOperand &operand,
Value replacement,
const ValueTypeCastFnTy &castFn);
+namespace op_impl {
+
+//===----------------------------------------------------------------------===//
+// ControlFlowImplicitTerminatorOperation
+//===----------------------------------------------------------------------===//
+
+/// This class provides an interface compatible with
+/// SingleBlockImplicitTerminator, but allows multiple types of potential
+/// terminators aside from just one. If a terminator isn't present, this will
+/// generate a `ImplicitOpT` operation.
+template <typename ImplicitOpT, typename... OtherTerminatorOpTs>
+struct ControlFlowImplicitTerminatorOpType {
+ /// Implementation of `classof` that supports all of the potential terminator
+ /// operations.
+ static bool classof(Operation *op) {
+ return isa<ImplicitOpT, OtherTerminatorOpTs...>(op);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Implicit Terminator Methods
+
+ /// The following methods are all used when interacting with the "implicit"
+ /// terminator.
+
+ template <typename... Args>
+ static void build(Args &&...args) {
+ ImplicitOpT::build(std::forward<Args>(args)...);
+ }
+ static constexpr StringLiteral getOperationName() {
+ return ImplicitOpT::getOperationName();
+ }
+};
+/// An implicit terminator type for `if` operations, which can contain:
+/// break, continue, yield.
+struct IfOpImplicitTerminatorType
+ : public ControlFlowImplicitTerminatorOpType<YieldOp, BreakOp, ContinueOp> {
+};
+struct LoopOpImplicitTerminatorType
+ : public ControlFlowImplicitTerminatorOpType<ContinueOp, BreakOp> {};
+} // namespace op_impl
/// Helper function to compute the difference between two values. This is used
/// by the loop implementations to compute the trip count.
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index a08cf3c95e6ce..21f20d05ea116 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -15,6 +15,7 @@
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
+include "mlir/IR/OpAsmInterface.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
@@ -144,6 +145,184 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+def LoopOp : SCF_Op<"loop",[
+ AutomaticAllocationScope,
+ OpAsmOpInterface,
+ RecursiveMemoryEffects,
+ PropagateControlFlowBreak,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface,
+ ["getEntrySuccessorOperands", "getSuccessorInputs"]>,
+ SingleBlockImplicitTerminator<"op_impl::LoopOpImplicitTerminatorType">,
+ HasBreakingControlFlowOpInterface,
+ HasNestedTerminator<["ContinueOp", "BreakOp"]>
+ ]> {
+ let summary = "Loop until a break operation";
+ let description = [{
+ The `loop` operation represents an unstructured, infinite loop that executes
+ until a `break` is reached.
+
+ The loop consists of (1) a set of loop-carried values which are initialized by
+ `initValues` and updated by each iteration of the loop, and
+ (2) a region which represents the loop body.
+
+ The loop will execute the body of the loop until a `break` is dynamically executed.
+
+ Each control path of the loop must be terminated by:
+
+ - a `continue` that yields the next iteration's value for each loop carried variable.
+ - a `break` that terminates the loop and yields the final loop carried values.
+
+ As long as each loop iteration is terminated by one of these operations they may be combined with other control
+ flow operations to express different control flow patterns.
+
+ The loop operation produces one return value for each loop carried variable. The type of the `i`-th return
+ value is that of the `i`-th loop carried variable and its value is the final value of the
+ `i`-th loop carried variable.
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$initValues);
+ let results = (outs Variadic<AnyType>:$resultValues);
+ let regions = (region SizedRegion<1>:$region);
+
+ let extraClassDeclaration = [{
+ /// Required by HasBreakingControlFlowOpInterface. Returns true only for
+ /// scf.break and scf.continue, which are the RegionTerminators that can
+ /// target this loop as their HasBreakingControlFlowOpInterface receiver.
+ static bool acceptsTerminator(Operation *predecessor) {
+ return isa<BreakOp, ContinueOp>(predecessor);
+ }
+
+ /// Return the iteration values of the loop region.
+ Block::BlockArgListType getRegionIterValues() {
+ return getRegion().getArguments();
+ }
+
+ /// Return the `index`-th region iteration value.
+ BlockArgument getRegionIterValue(unsigned index) {
+ return getRegionIterValues()[index];
+ }
+
+ /// Returns the number of region arguments for loop-carried values.
+ unsigned getNumRegionIterValues() { return getRegion().getNumArguments(); }
+
+ /// Returns the loop block body
+ Block *getBody() { return &getRegion().front(); }
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+ let hasRegionVerifier = 1;
+ let hasCanonicalizer = 1;
+}
+
+
+//===----------------------------------------------------------------------===//
+// BreakOp
+//===----------------------------------------------------------------------===//
+
+def BreakOp : SCF_Op<"break", [
+ ReturnLike, Terminator, RegionTerminator, RegionBranchTerminatorOpInterface,
+ ParentOneOf<["IfOp", "LoopOp"]>
+ ]> {
+ let summary = "Break from loop";
+ let description = [{
+ The `break` operation is a `RegionTerminator` that exits one or more nested
+ regions and terminates a `scf.loop`. The mandatory `num-breaking-regions`
+ integer N indicates how many region levels to exit:
+
+ - `N = 1`: the break exits only its immediately enclosing region, which
+ must be a `scf.loop` body. This is the normal exit path of a loop.
+ - `N = K`: the break exits K nested region levels. Levels 1 .. K-1 must
+ be enclosed by operations carrying the `PropagateControlFlowBreak`
+ trait (e.g. `scf.if` or an inner `scf.loop`), and the operation at
+ level K must be a `scf.loop`.
+
+ The `break` may yield any number of operands; their types must match the
+ result types of the target `scf.loop`.
+
+ Example — break out of the immediately enclosing loop:
+ ```mlir
+ scf.loop -> i32 {
+ scf.break 1 %result : i32
+ }
+ ```
+
+ Example — break out of a loop through an enclosing `scf.if`:
+ ```mlir
+ scf.loop {
+ scf.if %cond {
+ scf.break 2 // exits the if-body (1) and the loop-body (2)
+ }
+ scf.continue 1
+ }
+ ```
+ }];
+
+
+ let arguments = (ins Variadic<AnyType>:$operands);
+ let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
+ let assemblyFormat = [{
+ num-breaking-regions attr-dict ($operands^ `:` type($operands))?
+ }];
+}
+
+
+//===----------------------------------------------------------------------===//
+// ContinueOp
+//===----------------------------------------------------------------------===//
+
+def ContinueOp : SCF_Op<"continue", [
+ Terminator, RegionTerminator, DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>, ParentOneOf<["IfOp", "LoopOp"]>
+ ]> {
+ let summary = "Continue to next loop iteration";
+ let description = [{
+ The `continue` operation is a `RegionTerminator` that re-enters a `scf.loop`
+ for its next iteration. Like `scf.break`, it carries a mandatory
+ `num-breaking-regions` integer N:
+
+ - `N = 1`: continues the immediately enclosing `scf.loop`.
+ - `N = K`: continues the `scf.loop` at region level K, exiting the K-1
+ intermediate regions. The intermediate parent ops (levels 1 .. K-1)
+ must carry `PropagateControlFlowBreak`.
+
+ The operands of `continue` become the loop-carried values (iter_args) for
+ the next iteration; their types must match the loop's iter_arg types.
+
+ Example — continue the immediately enclosing loop:
+ ```mlir
+ scf.loop iter_args(%i = %init) : i32 {
+ %next = arith.addi %i, %one : i32
+ scf.continue 1 %next : i32
+ }
+ ```
+
+ Example — continue an outer loop from inside a nested `scf.if`:
+ ```mlir
+ scf.loop iter_args(%counter = %init) : i64 {
+ scf.loop iter_args(%inner = %counter) : i64 {
+ scf.if %restart_outer {
+ // Exits if (1) + inner loop (2) + outer loop (3) regions: re-enters outer loop.
+ scf.continue 3 %inner : i64
+ }
+ scf.continue 1 %inner : i64
+ }
+ scf.continue 1 %counter : i64
+ }
+ ```
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$operands);
+ let builders = [OpBuilder<(ins), [{
+ $_state.setNumBreakingControlRegions(1);
+ }]>];
+ let assemblyFormat = [{
+ num-breaking-regions ($operands^ `:` type($operands))? attr-dict
+ }];
+ let hasVerifier = 1;
+}
//===----------------------------------------------------------------------===//
// ForOp
@@ -701,8 +880,8 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
"getNumRegionInvocations", "getRegionInvocationBounds",
"getEntrySuccessorRegions", "getSuccessorInputs"]>,
- InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">,
- RecursiveMemoryEffects, RecursivelySpeculatable, NoRegionArguments]> {
+ InferTypeOpAdaptor, SingleBlockImplicitTerminator<"op_impl::IfOpImplicitTerminatorType">,
+ RecursiveMemoryEffects, RecursivelySpeculatable, NoRegionArguments, PropagateControlFlowBreak]> {
let summary = "if-then-else operation";
let description = [{
The `scf.if` operation represents an if-then-else construct for
@@ -785,9 +964,19 @@ def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
: OpBuilder::atBlockEnd(body, listener);
}
Block* thenBlock();
+ [[deprecated("Use thenTerminator() instead")]]
YieldOp thenYield();
+ // Returns either scf.break, scf.continue, or scf.yield.
+ Operation *thenTerminator() {
+ return thenBlock()->getTerminator();
+ }
Block* elseBlock();
+ [[deprecated("Use elseTerminator() instead")]]
YieldOp elseYield();
+ // Returns either scf.break, scf.continue, or scf.yield.
+ Operation *elseTerminator() {
+ return elseBlock()->getTerminator();
+ }
}];
let hasFolder = 1;
let hasCanonicalizer = 1;
@@ -903,7 +1092,7 @@ def ParallelOp : SCF_Op<"parallel",
//===----------------------------------------------------------------------===//
def ReduceOp : SCF_Op<"reduce", [
- Terminator, HasParent<"ParallelOp">, RecursiveMemoryEffects,
+ Terminator, ImmediateRegionTerminator, HasParent<"ParallelOp">, RecursiveMemoryEffects,
DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>]> {
let summary = "reduce operation for scf.parallel";
let description = [{
diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
index 3b8fb46b06a48..b5b3c55d35c75 100644
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -200,7 +200,7 @@ class Diagnostic {
/// Stream in an Operation.
Diagnostic &operator<<(Operation &op);
- Diagnostic &operator<<(OpWithFlags op);
+ Diagnostic &operator<<(const OpWithFlags &opWithFlags);
Diagnostic &operator<<(Operation *op) { return *this << *op; }
/// Append an operation with the given printing flags.
Diagnostic &appendOp(Operation &op, const OpPrintingFlags &flags);
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 7a667d701ab71..99ddaa37e2983 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -98,6 +98,10 @@ def SameOperandsAndResultElementType :
NativeOpTrait<"SameOperandsAndResultElementType">;
// Op is a terminator.
def Terminator : NativeOpTrait<"IsTerminator">;
+// Op is a region terminator for the immediate region only.
+def ImmediateRegionTerminator : NativeOpTrait<"RegionTerminator", [Terminator]>;
+// Op is a region terminator, potentially breaking multiple regions.
+def RegionTerminator : NativeOpTrait<"RegionTerminator", [Terminator]>;
// Op can be safely normalized in the presence of MemRefs with
// non-identity maps.
def MemRefsNormalizable : NativeOpTrait<"MemRefsNormalizable">;
@@ -130,6 +134,13 @@ class SingleBlockImplicitTerminatorImpl<string op>
class SingleBlockImplicitTerminator<string op>
: TraitList<[SingleBlock, SingleBlockImplicitTerminatorImpl<op>]>;
+// This operation has nested regions with the supplied list of `RegionTerminator`
+// operations.
+class HasNestedTerminator<list<string> ops>
+ : ParamNativeOpTrait<"HasNestedTerminators", !interleave(ops, ", ")>,
+ StructuralOpTrait;
+
+
// Op's regions don't have terminator.
def NoTerminator : NativeOpTrait<"NoTerminator">, StructuralOpTrait;
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index be92fe0a6c7e3..4c21f6fd023b1 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -891,7 +891,8 @@ struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> {
// Non-empty regions must contain a single basic block.
if (!region.hasOneBlock())
return op->emitOpError("expects region #")
- << i << " to have 0 or 1 blocks";
+ << i << " to have 0 or 1 blocks, found "
+ << llvm::range_size(region) << " blocks";
if (!ConcreteType::template hasTrait<NoTerminator>()) {
Block &block = region.front();
@@ -1323,6 +1324,19 @@ struct HasParent {
};
};
+/// This trait lets an op restrict which nested `RegionTerminator` ops it
+/// accepts as breaking-control-flow predecessors.
+template <typename... NestedPredecessorOpTypes>
+struct HasNestedTerminators {
+ template <typename ConcreteType>
+ class Impl : public TraitBase<ConcreteType, Impl> {
+ public:
+ static bool acceptsTerminator(Operation *predecessor) {
+ return llvm::isa_and_nonnull<NestedPredecessorOpTypes...>(predecessor);
+ }
+ };
+};
+
/// A trait for operations that have an attribute specifying operand segments.
///
/// Certain operations can have multiple variadic operands and their size
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index b2019574a820d..31d7a98529364 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -94,7 +94,8 @@ class alignas(8) Operation final
TypeRange resultTypes, ValueRange operands,
NamedAttrList &&attributes,
OpaqueProperties properties, BlockRange successors,
- unsigned numRegions);
+ unsigned numRegions,
+ unsigned numBreakingControlRegions);
/// Create a new Operation with the specific fields. This constructor uses an
/// existing attribute dictionary to avoid uniquing a list of attributes.
@@ -102,7 +103,8 @@ class alignas(8) Operation final
TypeRange resultTypes, ValueRange operands,
DictionaryAttr attributes,
OpaqueProperties properties, BlockRange successors,
- unsigned numRegions);
+ unsigned numRegions,
+ unsigned numBreakingControlRegions);
/// Create a new Operation from the fields stored in `state`.
static Operation *create(const OperationState &state);
@@ -112,8 +114,8 @@ class alignas(8) Operation final
TypeRange resultTypes, ValueRange operands,
NamedAttrList &&attributes,
OpaqueProperties properties,
- BlockRange successors = {},
- RegionRange regions = {});
+ BlockRange successors = {}, RegionRange regions = {},
+ unsigned numBreakingControlRegions = 0);
/// The name of an operation is the key identifier for it.
OperationName getName() { return name; }
@@ -705,6 +707,33 @@ class alignas(8) Operation final
bool hasSuccessors() { return numSuccs != 0; }
unsigned getNumSuccessors() { return numSuccs; }
+ /// Return true if this operation carries a num-breaking-regions value, i.e.
+ /// it is a RegionTerminator (scf.break, scf.continue, or similar) that can
+ /// exit one or more nested region levels in a single step.
+ bool isBreakingControlFlow() { return isBreakingControlFlowFlag; }
+
+ /// Return the number of nested region levels this terminator exits.
+ /// Returns 0 for ordinary operations that are not region terminators.
+ /// For a RegionTerminator, N = 1 means a normal yield to the immediate
+ /// parent; N = K means the terminator exits K region levels, bypassing K-1
+ /// intermediate PropagateControlFlowBreak ops on the way to the
+ /// HasBreakingControlFlowOpInterface ancestor at level K.
+ int getNumBreakingControlRegions() {
+ if (!isBreakingControlFlow())
+ return 0;
+ return *reinterpret_cast<int *>(getTrailingObjects<detail::OpProperties>());
+ }
+
+ /// Set the num-breaking-regions value on this RegionTerminator. The
+ /// operation must have been created with isBreakingControlFlow == true
+ /// (i.e. numBreakingControlRegions > 0 was passed to Operation::create).
+ void setNumBreakingControlRegions(int numBreakingControlRegions) {
+ assert(isBreakingControlFlow() &&
+ "operation is not a breaking control flow operation");
+ *reinterpret_cast<int *>(getTrailingObjects<detail::OpProperties>()) =
+ numBreakingControlRegions;
+ }
+
Block *getSuccessor(unsigned index) {
assert(index < getNumSuccessors());
return getBlockOperands()[index].get();
@@ -898,14 +927,26 @@ class alignas(8) Operation final
}
/// Returns the properties storage.
OpaqueProperties getPropertiesStorage() {
- if (propertiesStorageSize)
- return getPropertiesStorageUnsafe();
+ if (propertiesStorageSize) {
+ void *properties =
+ reinterpret_cast<void *>(getTrailingObjects<detail::OpProperties>());
+ if (isBreakingControlFlowFlag)
+ properties =
+ reinterpret_cast<void *>(reinterpret_cast<char *>(properties) + 8);
+ return {properties};
+ }
return {nullptr};
}
OpaqueProperties getPropertiesStorage() const {
- if (propertiesStorageSize)
- return {reinterpret_cast<void *>(const_cast<detail::OpProperties *>(
- getTrailingObjects<detail::OpProperties>()))};
+ if (propertiesStorageSize) {
+ void *properties =
+ reinterpret_cast<void *>(const_cast<detail::OpProperties *>(
+ getTrailingObjects<detail::OpProperties>()));
+ if (isBreakingControlFlowFlag)
+ properties =
+ reinterpret_cast<void *>(reinterpret_cast<char *>(properties) + 8);
+ return {properties};
+ }
return {nullptr};
}
/// Returns the properties storage without checking whether properties are
@@ -960,8 +1001,9 @@ class alignas(8) Operation final
private:
Operation(Location location, OperationName name, unsigned numResults,
unsigned numSuccessors, unsigned numRegions,
- int propertiesStorageSize, DictionaryAttr attributes,
- OpaqueProperties properties, bool hasOperandStorage);
+ unsigned numBreakingControlRegions, int propertiesStorageSize,
+ DictionaryAttr attributes, OpaqueProperties properties,
+ bool hasOperandStorage);
// Operations are deleted through the destroy() member because they are
// allocated with malloc.
@@ -1048,13 +1090,19 @@ class alignas(8) Operation final
const unsigned numResults;
const unsigned numSuccs;
- const unsigned numRegions : 23;
+ const unsigned numRegions : 22;
/// This bit signals whether this operation has an operand storage or not. The
/// operand storage may be elided for operations that are known to never have
/// operands.
bool hasOperandStorage : 1;
+ /// Set when this operation carries a num-breaking-regions count (i.e. it is
+ /// a RegionTerminator). When true, the num-breaking-regions integer is stored
+ /// at the start of the trailing OpProperties storage area, ahead of any
+ /// dialect-specific properties.
+ bool isBreakingControlFlowFlag : 1;
+
/// The size of the storage for properties (if any), divided by 8: since the
/// Properties storage will always be rounded up to the next multiple of 8 we
/// save some bits here.
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 1ff7c56ddca38..b0ed588cd0a65 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -971,6 +971,12 @@ struct OperationState {
llvm::function_ref<void(OpaqueProperties)> propertiesDeleter;
llvm::function_ref<void(OpaqueProperties, const OpaqueProperties)>
propertiesSetter;
+ /// Number of nested region levels this operation exits as a RegionTerminator.
+ /// 0 means it is not a RegionTerminator. Values > 0 cause the Operation to
+ /// store this integer in its trailing OpProperties storage and set the
+ /// isBreakingControlFlowFlag bit. See
+ /// Operation::getNumBreakingControlRegions.
+ unsigned numBreakingControlRegions = 0;
friend class Operation;
public:
@@ -1096,6 +1102,14 @@ struct OperationState {
}
void addSuccessors(BlockRange newSuccessors);
+ /// Set the num-breaking-regions count for this operation state, marking the
+ /// resulting Operation as a RegionTerminator that exits `n` region levels.
+ /// Must be called from an op builder before the Operation is created; the
+ /// value is forwarded to Operation::create as numBreakingControlRegions.
+ void setNumBreakingControlRegions(int numBreakingControlRegions) {
+ this->numBreakingControlRegions = numBreakingControlRegions;
+ }
+
/// Create a region that should be attached to the operation. These regions
/// can be filled in immediately without waiting for Operation to be
/// created. When it is, the region bodies will be transferred.
diff --git a/mlir/include/mlir/IR/RegionKindInterface.h b/mlir/include/mlir/IR/RegionKindInterface.h
index d6d3aeeb9bd05..8cadac9db7e15 100644
--- a/mlir/include/mlir/IR/RegionKindInterface.h
+++ b/mlir/include/mlir/IR/RegionKindInterface.h
@@ -36,6 +36,48 @@ class HasOnlyGraphRegion : public TraitBase<ConcreteType, HasOnlyGraphRegion> {
static RegionKind getRegionKind(unsigned index) { return RegionKind::Graph; }
static bool hasSSADominance(unsigned index) { return false; }
};
+
+/// Indicates that this operation is transparent to breaking control flow:
+/// a RegionTerminator (e.g. scf.break / scf.continue) with
+/// num-breaking-regions > 1 can propagate through this op on its way to the
+/// enclosing HasBreakingControlFlowOpInterface ancestor. The op does NOT
+/// consume the break; it simply passes it upward. All ops that are "skipped
+/// over" by a multi-level region terminator must carry this trait.
+template <typename ConcreteType>
+class PropagateControlFlowBreak
+ : public TraitBase<ConcreteType, PropagateControlFlowBreak> {
+public:
+ static LogicalResult verifyTrait(Operation *op) {
+ // Verify the operation has regions and can handle breaking control flow.
+ if (op->getNumRegions() == 0)
+ return op->emitOpError(
+ "operation with PropagateControlFlowBreak trait must have regions");
+ return success();
+ }
+};
+
+/// Indicates that this operation is a block terminator that can exit multiple
+/// nested region levels in one step. The operation must carry a
+/// num-breaking-regions value N > 0:
+/// N = 1 — exits its own immediately enclosing region (normal yield).
+/// N = K — exits K region levels; the K-1 intermediate parent ops must
+/// each carry PropagateControlFlowBreak; the op at the K-th level
+/// must implement HasBreakingControlFlowOpInterface.
+/// This trait also requires IsTerminator (enforced by verifyTrait).
+template <typename ConcreteType>
+class RegionTerminator : public TraitBase<ConcreteType, RegionTerminator> {
+public:
+ static LogicalResult verifyTrait(Operation *op) {
+ if (op->getNumBreakingControlRegions() == 0)
+ return op->emitOpError("operation with region terminator trait must have "
+ "breaking control regions > 0");
+ if (!op->hasTrait<OpTrait::IsTerminator>())
+ return op->emitOpError(
+ "operation with region terminator trait must be a terminator");
+ return success();
+ }
+};
+
} // namespace OpTrait
/// Return "true" if the given region may have SSA dominance. This function also
@@ -49,8 +91,111 @@ bool mayHaveSSADominance(Region &region);
/// implement the RegionKindInterface.
bool mayBeGraphRegion(Region &region);
+/// Return true if `op` (which implements HasBreakingControlFlowOpInterface)
+/// contains at least one RegionTerminator that directly targets it — i.e. a
+/// terminator whose num-breaking-regions equals the region-nesting depth from
+/// `op`'s body to that terminator. Such a terminator is a "nested predecessor"
+/// of `op` because control flow may re-enter or exit `op` from a deeply nested
+/// site rather than only through the immediately enclosing terminator.
+bool hasNestedPredecessors(Operation *op);
+
+/// Return true if `op` contains any RegionTerminator that would "break
+/// through" `op` towards an outer HasBreakingControlFlowOpInterface ancestor,
+/// i.e. a terminator whose num-breaking-regions exceeds the nesting depth at
+/// which it appears inside `op`. This is used to detect whether an op's
+/// post-dominance is disrupted by an early-exit path that bypasses it.
+bool hasBreakingControlFlowOps(Operation *op);
+
+/// Collect all RegionTerminator operations nested inside `op` that directly
+/// target `op` (num-breaking-regions == their region-nesting depth from `op`).
+/// These are the ops that will transfer control flow to `op` on an early exit.
+void collectAllNestedPredecessors(Operation *op,
+ SmallVector<Operation *> &predecessors);
+
+namespace detail {
+/// Implementation helper for visitNestedBreakingControlFlowOps. Walks the
+/// regions of `op` and invokes `callback` for every RegionTerminator whose
+/// num-breaking-regions is >= its current nesting depth (i.e. the terminator
+/// either targets `op` or propagates further upward through `op`).
+/// The `nestedLevel` argument passed to the callback is the 1-based depth of
+/// the terminator relative to `op`'s outermost region.
+void visitNestedBreakingControlFlowOpsImpl(
+ Operation *op,
+ function_ref<WalkResult(Operation *, int nestedLevel)> callback);
+} // namespace detail
+
+/// Walk all RegionTerminator operations that are relevant to breaking control
+/// flow inside `op` (see visitNestedBreakingControlFlowOpsImpl). The callback
+/// receives the terminator op and its 1-based nesting level. The WalkResult-
+/// returning overload supports early termination via WalkResult::interrupt().
+template <typename CallbackT>
+std::enable_if_t<
+ std::is_same_v<decltype(std::declval<CallbackT>()(
+ std::declval<Operation *>(), std::declval<int>())),
+ WalkResult>>
+visitNestedBreakingControlFlowOps(Operation *op, CallbackT &&callback) {
+ detail::visitNestedBreakingControlFlowOpsImpl(op, callback);
+}
+
+/// Walk all RegionTerminator operations relevant to breaking control flow
+/// inside `op`. Void-returning callback overload (no early termination).
+template <typename CallbackT>
+std::enable_if_t<
+ std::is_same_v<decltype(std::declval<CallbackT>()(
+ std::declval<Operation *>(), std::declval<int>())),
+ void>>
+visitNestedBreakingControlFlowOps(Operation *op, CallbackT &&callback) {
+ detail::visitNestedBreakingControlFlowOpsImpl(
+ op, [&](Operation *visitedOp, int nestedLevel) {
+ callback(visitedOp, nestedLevel);
+ return WalkResult::advance();
+ });
+}
+
+/// Walk all RegionTerminator operations relevant to breaking control flow
+/// across all top-level ops in `region`. WalkResult-returning overload.
+template <typename CallbackT>
+std::enable_if_t<
+ std::is_same_v<decltype(std::declval<CallbackT>()(
+ std::declval<Operation *>(), std::declval<int>())),
+ WalkResult>>
+visitNestedBreakingControlFlowOps(Region &region, CallbackT &&callback) {
+ for (Operation &op : region.getOps())
+ detail::visitNestedBreakingControlFlowOpsImpl(&op, callback);
+}
+
+/// Walk all RegionTerminator operations relevant to breaking control flow
+/// across all top-level ops in `region`. Void-returning overload.
+template <typename CallbackT>
+std::enable_if_t<
+ std::is_same_v<decltype(std::declval<CallbackT>()(
+ std::declval<Operation *>(), std::declval<int>())),
+ void>>
+visitNestedBreakingControlFlowOps(Region &region, CallbackT &&callback) {
+ for (Operation &op : region.getOps())
+ detail::visitNestedBreakingControlFlowOpsImpl(
+ &op, [&](Operation *visitedOp, int nestedLevel) {
+ callback(visitedOp, nestedLevel);
+ return WalkResult::advance();
+ });
+}
+
} // namespace mlir
#include "mlir/IR/RegionKindInterface.h.inc"
+namespace mlir {
+
+/// Return true if the given region may contain breaking control flow — either
+/// because its parent op propagates breaks (PropagateControlFlowBreak) or
+/// because it is the body of a HasBreakingControlFlowOpInterface op. Used to
+/// decide whether post-dominance analysis must account for early-exit paths.
+inline bool hasBreakingControlFlow(Region *region) {
+ return region->getParentOp()
+ ->hasTrait<OpTrait::PropagateControlFlowBreak>() ||
+ isa<HasBreakingControlFlowOpInterface>(region->getParentOp());
+}
+
+} // namespace mlir
+
#endif // MLIR_IR_REGIONKINDINTERFACE_H_
diff --git a/mlir/include/mlir/IR/RegionKindInterface.td b/mlir/include/mlir/IR/RegionKindInterface.td
index 607001a89250e..e6fd3bd05f655 100644
--- a/mlir/include/mlir/IR/RegionKindInterface.td
+++ b/mlir/include/mlir/IR/RegionKindInterface.td
@@ -61,4 +61,62 @@ def GraphRegionNoTerminator : TraitList<[
HasOnlyGraphRegion
]>;
+// Indicates that this op may propagate a breaking control-flow event (a
+// RegionTerminator with num-breaking-regions > 1) from a nested region upward
+// to an enclosing HasBreakingControlFlowOpInterface operation. The op does NOT
+// consume the break itself; it is merely transparent to it. All ops that sit
+// between a RegionTerminator and the HasBreakingControlFlowOpInterface ancestor
+// that will ultimately receive the break must carry this trait.
+def PropagateControlFlowBreak : NativeOpTrait<"PropagateControlFlowBreak">;
+
+// OpInterface for operations that can receive a breaking control-flow event
+// originating from a RegionTerminator (scf.break / scf.continue) anywhere
+// inside their (possibly deeply nested) regions. Every op between the
+// RegionTerminator and this receiver must carry the PropagateControlFlowBreak
+// trait.
+def HasBreakingControlFlowOpInterface : OpInterface<"HasBreakingControlFlowOpInterface"> {
+ let description = [{
+ Interface for operations that act as the target of a breaking control-flow
+ event (e.g. `scf.break` or `scf.continue`). When a `RegionTerminator` has
+ `num-breaking-regions = N > 1`, it exits N region levels; the op at the
+ N-th level must implement this interface. Every intermediate op (levels
+ 1 .. N-1) must carry the `PropagateControlFlowBreak` trait.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ StaticInterfaceMethod<
+ /*desc=*/[{
+ Return true if this operation accepts the given terminator operation
+ as a breaking-control-flow predecessor. By default all terminators are
+ accepted; override to restrict to specific terminator op types.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"acceptsTerminator",
+ /*args=*/(ins "Operation *":$op),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return true;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return true if this operation has at least one RegionTerminator nested
+ inside it that targets this operation directly (i.e.
+ num-breaking-regions equals the nesting depth from this op's region to
+ the terminator). Used to decide whether post-dominance analysis must
+ account for early-exit paths.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"hasNestedPredecessors",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return ::mlir::hasNestedPredecessors(this->getOperation());
+ }]
+ >
+ ];
+}
+
+
#endif // MLIR_IR_REGIONKINDINTERFACE
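
Reviewer sketch (not part of the patch): the level-counting rule described for `RegionTerminator`, `PropagateControlFlowBreak` and `HasBreakingControlFlowOpInterface` can be modelled without any MLIR types. The `Parent` struct and `resolveBreakTarget` function below are hypothetical names used only for illustration. The sketch assumes that `N = 1` is always an ordinary yield to the immediate parent, and that an op may both propagate and receive breaks.

```cpp
#include <optional>
#include <vector>

// Hypothetical stand-in for one enclosing op, reduced to the two facts the
// rule cares about.
struct Parent {
  bool propagates; // op carries PropagateControlFlowBreak
  bool receives;   // op implements HasBreakingControlFlowOpInterface
};

// `parents` lists the enclosing ops from innermost (level 1) to outermost.
// Returns the zero-based index of the op that receives a terminator whose
// num-breaking-regions is `n`, or std::nullopt if the nesting is invalid.
std::optional<int> resolveBreakTarget(const std::vector<Parent> &parents,
                                      int n) {
  if (n < 1 || n > static_cast<int>(parents.size()))
    return std::nullopt;
  // N = 1 is treated as a normal yield to the immediate parent.
  if (n == 1)
    return 0;
  // Levels 1 .. N-1 must be transparent to the break.
  for (int level = 1; level < n; ++level)
    if (!parents[level - 1].propagates)
      return std::nullopt;
  // The op at level N must be able to receive the break.
  if (!parents[n - 1].receives)
    return std::nullopt;
  return n - 1;
}
```

For the nested `scf.continue 3` example in the op description, the parents would be the `scf.if`, the inner `scf.loop` and the outer `scf.loop`, so the target is index 2, provided the inner loop is allowed to propagate.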
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 22928df03fbc7..a92e836a306e5 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -676,7 +676,8 @@ class OperationParser : public Parser {
ParseResult parseSuccessor(Block *&dest);
/// Parse a comma-separated list of operation successors in brackets.
- ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations);
+ ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations,
+ bool parseOpeningBracket = true);
/// Parse an operation instance that is in the generic form.
Operation *parseGenericOperation();
@@ -695,7 +696,8 @@ class OperationParser : public Parser {
std::nullopt,
std::optional<ArrayRef<NamedAttribute>> parsedAttributes = std::nullopt,
std::optional<Attribute> propertiesAttribute = std::nullopt,
- std::optional<FunctionType> parsedFnType = std::nullopt);
+ std::optional<FunctionType> parsedFnType = std::nullopt,
+ std::optional<int> parsedNumBreakingControlRegions = std::nullopt);
/// Parse an operation instance that is in the generic form and insert it at
/// the provided insertion point.
@@ -1237,7 +1239,7 @@ Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
auto *op = Operation::create(
getEncodedSourceLocation(loc), name, type, /*operands=*/{},
/*attributes=*/NamedAttrList(), /*properties=*/nullptr,
- /*successors=*/{}, /*numRegions=*/0);
+ /*successors=*/{}, /*numRegions=*/0, /*numBreakingControlRegions=*/0);
forwardRefPlaceholders[op->getResult(0)] = loc;
forwardRefOps.insert(op);
return op->getResult(0);
@@ -1381,8 +1383,9 @@ ParseResult OperationParser::parseSuccessor(Block *&dest) {
/// successor-list ::= `[` successor (`,` successor )* `]`
///
ParseResult
-OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
- if (parseToken(Token::l_square, "expected '['"))
+OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations,
+ bool parseOpeningBracket) {
+ if (parseOpeningBracket && parseToken(Token::l_square, "expected '['"))
return failure();
auto parseElt = [this, &destinations] {
@@ -1420,7 +1423,8 @@ ParseResult OperationParser::parseGenericOperationAfterOpName(
std::optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions,
std::optional<ArrayRef<NamedAttribute>> parsedAttributes,
std::optional<Attribute> propertiesAttribute,
- std::optional<FunctionType> parsedFnType) {
+ std::optional<FunctionType> parsedFnType,
+ std::optional<int> parsedNumBreakingControlRegions) {
// Parse the operand list, if not explicitly provided.
SmallVector<UnresolvedOperand, 8> opInfo;
@@ -1434,19 +1438,38 @@ ParseResult OperationParser::parseGenericOperationAfterOpName(
}
// Parse the successor list, if not explicitly provided.
- if (!parsedSuccessors) {
+ if (parsedSuccessors)
+ result.addSuccessors(*parsedSuccessors);
+ if (parsedNumBreakingControlRegions)
+ result.setNumBreakingControlRegions(*parsedNumBreakingControlRegions);
+ if (!parsedSuccessors || !parsedNumBreakingControlRegions) {
if (getToken().is(Token::l_square)) {
+ consumeToken(Token::l_square);
+
// Check if the operation is not a known terminator.
if (!result.name.mightHaveTrait<OpTrait::IsTerminator>())
return emitError("successors in non-terminator");
- SmallVector<Block *, 2> successors;
- if (parseSuccessors(successors))
- return failure();
- result.addSuccessors(successors);
+ // If we don't have a ^, then we expect a single integer for the number
+ // of breaking control regions.
+ if (!getToken().is(Token::caret_identifier)) {
+ APInt numBreakingControlRegions;
+ OptionalParseResult parseResult =
+ parseOptionalInteger(numBreakingControlRegions);
+ if (!parseResult.has_value() || failed(*parseResult))
+ return emitError("expected `^` or integer after '['");
+ result.setNumBreakingControlRegions(
+ numBreakingControlRegions.getZExtValue());
+ if (failed(parseToken(Token::r_square,
+ "expected ']' to end breaking control regions")))
+ return failure();
+ } else {
+ SmallVector<Block *, 2> successors;
+ if (parseSuccessors(successors, /*parseOpeningBracket=*/false))
+ return failure();
+ result.addSuccessors(successors);
+ }
}
- } else {
- result.addSuccessors(*parsedSuccessors);
}
// Parse the properties, if not explicitly provided.
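
Reviewer sketch (not part of the patch): the generic-form parser change above reuses `[` for two purposes and decides between them by looking at the next token. A caret means a successor list; anything else must be a single integer followed by `]`. The standalone model below uses hypothetical names (`BracketContents`, `classifyBracket`) and works on raw characters rather than MLIR tokens, so it skips whitespace handling and the successor list itself.

```cpp
#include <cctype>
#include <cstddef>
#include <optional>
#include <string_view>

// Hypothetical result type; the real parser records this on OperationState.
struct BracketContents {
  bool isSuccessorList = false;           // input looked like `[^bb1, ...]`
  unsigned numBreakingControlRegions = 0; // input looked like `[N]`
};

// `text` is the input immediately after the opening `[`.
std::optional<BracketContents> classifyBracket(std::string_view text) {
  BracketContents result;
  // A caret identifier selects the pre-existing successor-list path.
  if (!text.empty() && text.front() == '^') {
    result.isSuccessorList = true;
    return result;
  }
  // Otherwise a decimal integer is required...
  std::size_t pos = 0;
  unsigned value = 0;
  while (pos < text.size() &&
         std::isdigit(static_cast<unsigned char>(text[pos]))) {
    value = value * 10 + static_cast<unsigned>(text[pos] - '0');
    ++pos;
  }
  if (pos == 0)
    return std::nullopt; // "expected `^` or integer after '['"
  // ...and it must be closed by `]`.
  if (pos >= text.size() || text[pos] != ']')
    return std::nullopt; // "expected ']' to end breaking control regions"
  result.numBreakingControlRegions = value;
  return result;
}
```

One consequence of this scheme, as far as the hunk shows, is that a single op cannot carry both a successor list and a num-breaking-regions value in the same bracket.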
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 03842cc9bd3a0..b3105199eadad 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -24,6 +24,8 @@
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/Passes.h"
+#include "llvm/Support/DebugLog.h"
+
namespace mlir {
#define GEN_PASS_DEF_SCFTOCONTROLFLOWPASS
#include "mlir/Conversion/Passes.h.inc"
@@ -311,6 +313,16 @@ struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
PatternRewriter &rewriter) const override;
};
+/// Lowers `scf.loop` to a CFG; breaks and continues become branches.
+struct LoopOpLowering : public OpConversionPattern<LoopOp> {
+ using OpConversionPattern<LoopOp>::OpConversionPattern;
+ void initialize() { setHasBoundedRewriteRecursion(); }
+
+ LogicalResult
+ matchAndRewrite(LoopOp loopOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
} // namespace
static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) {
@@ -401,7 +413,12 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const {
auto loc = ifOp.getLoc();
-
+ visitNestedBreakingControlFlowOps(ifOp, [&](Operation *op, int nestedLevel) {
+ if (auto numBreakingControlRegions = op->getNumBreakingControlRegions()) {
+ if (numBreakingControlRegions > nestedLevel)
+ op->setNumBreakingControlRegions(numBreakingControlRegions - 1);
+ }
+ });
// Start by splitting the block containing the 'scf.if' into two parts.
// The part before will contain the condition, the part after will be the
// continuation point.
@@ -425,8 +442,10 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
Operation *thenTerminator = thenRegion.back().getTerminator();
ValueRange thenTerminatorOperands = thenTerminator->getOperands();
rewriter.setInsertionPointToEnd(&thenRegion.back());
- cf::BranchOp::create(rewriter, loc, continueBlock, thenTerminatorOperands);
- rewriter.eraseOp(thenTerminator);
+ if (isa<scf::YieldOp>(thenTerminator)) {
+ cf::BranchOp::create(rewriter, loc, continueBlock, thenTerminatorOperands);
+ rewriter.eraseOp(thenTerminator);
+ }
rewriter.inlineRegionBefore(thenRegion, continueBlock);
// Move blocks from the "else" region (if present) to the region containing
@@ -439,8 +458,11 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
Operation *elseTerminator = elseRegion.back().getTerminator();
ValueRange elseTerminatorOperands = elseTerminator->getOperands();
rewriter.setInsertionPointToEnd(&elseRegion.back());
- cf::BranchOp::create(rewriter, loc, continueBlock, elseTerminatorOperands);
- rewriter.eraseOp(elseTerminator);
+ if (isa<scf::YieldOp>(elseTerminator)) {
+ cf::BranchOp::create(rewriter, loc, continueBlock,
+ elseTerminatorOperands);
+ rewriter.eraseOp(elseTerminator);
+ }
rewriter.inlineRegionBefore(elseRegion, continueBlock);
}
@@ -720,11 +742,114 @@ LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
return scf::forallToParallelLoop(rewriter, forallOp);
}
+LogicalResult
+LoopOpLowering::matchAndRewrite(LoopOp loopOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ if (failed(rewriter.legalize(&loopOp.getRegion())))
+ return rewriter.notifyMatchFailure(loopOp,
+ "failed to convert nested region");
+ }
+
+ SmallVector<Operation *> predecessors;
+ collectAllNestedPredecessors(loopOp, predecessors);
+ for (Operation *predecessor : predecessors) {
+ if (predecessor->getNumBreakingControlRegions() > 1) {
+ return rewriter.notifyMatchFailure(loopOp,
+ "loop op with nested predecessors");
+ }
+ }
+ visitNestedBreakingControlFlowOps(loopOp, [&](Operation *op,
+ int nestedLevel) {
+ if (auto numBreakingControlRegions = op->getNumBreakingControlRegions()) {
+ if (numBreakingControlRegions > nestedLevel)
+ op->setNumBreakingControlRegions(numBreakingControlRegions - 1);
+ }
+ });
+
+ // Lower `scf.loop` to CFG by converting breaks/continues to branches.
+ Location loc = loopOp.getLoc();
+ // Split the block containing loopOp into the init block and continuation.
+ Block *initBlock = rewriter.getInsertionBlock();
+ auto initPos = rewriter.getInsertionPoint();
+ Block *continueBlock = rewriter.splitBlock(initBlock, initPos);
+ continueBlock->addArguments(
+ loopOp.getResultTypes(),
+ SmallVector<Location>(loopOp.getNumResults(), loc));
+
+ // Inline the loop body region into the parent function just before
+ // continueBlock.
+ Region &bodyRegion = loopOp.getRegion();
+ if (bodyRegion.empty() || bodyRegion.front().empty()) {
+ // Degenerate case: no body. Just remove the op.
+ rewriter.eraseOp(loopOp);
+ return success();
+ }
+ Block *loopBody = &bodyRegion.front();
+
+ // Prepare the mapping of loop args to values.
+ SmallVector<Value> loopArgs;
+ for (auto arg : loopBody->getArguments())
+ loopArgs.push_back(arg);
+
+ // Create the loop entry block and move the body there.
+ rewriter.setInsertionPoint(initBlock, initBlock->end());
+ // Split out everything after loopOp into continueBlock.
+ // The block before loop is now initBlock.
+
+ // Move all blocks from the scf.loop region before continueBlock.
+ rewriter.inlineRegionBefore(bodyRegion, continueBlock);
+ // We will remember all break/continue ops to fix up after.
+ SmallVector<Operation *> toErase;
+
+ for (auto predecessor : predecessors) {
+ if (auto breakOp = dyn_cast<scf::BreakOp>(predecessor)) {
+ rewriter.setInsertionPointAfter(breakOp);
+ cf::BranchOp::create(rewriter, breakOp->getLoc(), continueBlock,
+ ValueRange{breakOp.getOperands()});
+ } else if (auto contOp = dyn_cast<scf::ContinueOp>(predecessor)) {
+ rewriter.setInsertionPointAfter(contOp);
+ cf::BranchOp::create(rewriter, contOp->getLoc(), loopBody,
+ ValueRange{contOp.getOperands()});
+ }
+ toErase.push_back(predecessor);
+ }
+
+ // Erase the old scf.break/scf.continue ops.
+ for (Operation *op : toErase)
+ rewriter.eraseOp(op);
+
+ // The loop region is now a CFG. Jump from initBlock to the loop body.
+ rewriter.setInsertionPointToEnd(initBlock);
+ cf::BranchOp::create(rewriter, loc, loopBody,
+ ValueRange{loopOp.getOperands()});
+
+ // Replace the scf.yield with a branch to the loop header (unless it was
+ // replaced above).
+ for (Block &block :
+ llvm::make_early_inc_range(loopBody->getParent()->getBlocks())) {
+ if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator())) {
+ rewriter.setInsertionPoint(yield);
+ // For plain scf.yield at the end of the loop (i.e., loop-carried values),
+ // treat as continue.
+ cf::BranchOp::create(rewriter, yield.getLoc(), loopBody,
+ yield.getOperands());
+ rewriter.eraseOp(yield);
+ }
+ }
+
+ // Replace the original scf.loop op with a branch to continueBlock assigning
+ // results.
+ rewriter.replaceOp(loopOp, continueBlock->getArguments());
+ return success();
+}
+
void mlir::populateSCFToControlFlowConversionPatterns(
RewritePatternSet &patterns) {
- patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
- WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
- patterns.getContext());
+ patterns.add<ExecuteRegionLowering, ForallLowering, ForLowering, IfLowering,
+ IndexSwitchLowering, LoopOpLowering, ParallelLowering,
+ WhileLowering>(patterns.getContext());
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
}
@@ -735,7 +860,8 @@ void SCFToControlFlowPass::runOnOperation() {
// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
- scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
+ scf::LoopOp, scf::ParallelOp, scf::WhileOp,
+ scf::ExecuteRegionOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
ConversionConfig config;
config.allowPatternRollback = allowPatternRollback;
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 71e3f88a63f34..8f5e37220d5cd 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -56,7 +56,7 @@ void mlir::registerConvertSCFToEmitCInterface(DialectRegistry &registry) {
namespace {
-struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
+struct SCFToEmitCPass : public ::mlir::impl::SCFToEmitCBase<SCFToEmitCPass> {
void runOnOperation() override;
};
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 0ff9fb3f628ab..da85f7c2d2eaa 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -681,7 +681,8 @@ namespace {
namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
- : public impl::ConvertShapeToStandardPassBase<ConvertShapeToStandardPass> {
+ : public ::mlir::impl::ConvertShapeToStandardPassBase<
+ ConvertShapeToStandardPass> {
void runOnOperation() override;
};
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index 6081e515d4e3a..8502532622ff8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -669,10 +669,10 @@ Operation *BufferDeallocation::appendOpResults(Operation *op,
ArrayRef<Type> types) {
SmallVector<Type> newTypes(op->getResultTypes());
newTypes.append(types.begin(), types.end());
- auto *newOp = Operation::create(op->getLoc(), op->getName(), newTypes,
- op->getOperands(), op->getAttrDictionary(),
- op->getPropertiesStorage(),
- op->getSuccessors(), op->getNumRegions());
+ auto *newOp = Operation::create(
+ op->getLoc(), op->getName(), newTypes, op->getOperands(),
+ op->getAttrDictionary(), op->getPropertiesStorage(), op->getSuccessors(),
+ op->getNumRegions(), op->getNumBreakingControlRegions());
for (auto [oldRegion, newRegion] :
llvm::zip(op->getRegions(), newOp->getRegions()))
newRegion.takeBody(oldRegion);
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
index e2e63abe0a11a..3b135d088bb76 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -111,7 +111,8 @@ struct GpuAsyncRegionPass::ThreadTokenCallback {
auto *newOp = Operation::create(
op->getLoc(), op->getName(), resultTypes, op->getOperands(),
op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
- op->getSuccessors(), op->getNumRegions());
+ op->getSuccessors(), op->getNumRegions(),
+ op->getNumBreakingControlRegions());
// Clone regions into new op.
IRMapping mapping;
diff --git a/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp
index 966d7589d42c3..e084c5a146fb8 100644
--- a/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp
@@ -125,7 +125,8 @@ class ConvertGenericOpwithSubChannelType : public ConversionPattern {
auto *newOp = Operation::create(
op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
- op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
+ op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions(),
+ op->getNumBreakingControlRegions());
for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
Region &before = std::get<0>(regions);
Region &parent = std::get<1>(regions);
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c46a0577c4b96..01bf49362fd2c 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -132,6 +132,41 @@ std::optional<llvm::APSInt> mlir::scf::computeUbMinusLb(Value lb, Value ub,
// ExecuteRegionOp
//===----------------------------------------------------------------------===//
+/// Replaces the given op with the contents of the given single-block region,
+/// using the operands of the block terminator to replace operation results.
+static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
+ Region &region, ValueRange blockArgs = {}) {
+ assert(region.hasOneBlock() && "expected single-block region");
+ Block *block = &region.front();
+ // Reduce the level of breaking control flow ops by 1 since we inline the
+ // region.
+ visitNestedBreakingControlFlowOps(
+ op, [&](Operation *visitedOp, int nestedLevel) {
+ LDBG() << "replaceOpWithRegion - Visiting op: "
+ << OpWithFlags(visitedOp, OpPrintingFlags().skipRegions())
+ << " at nested level " << nestedLevel;
+ if (nestedLevel <=
+ static_cast<int>(visitedOp->getNumBreakingControlRegions()))
+ visitedOp->setNumBreakingControlRegions(
+ visitedOp->getNumBreakingControlRegions() - 1);
+ });
+ Operation *terminator = block->getTerminator();
+ ValueRange results = terminator->getOperands();
+ rewriter.inlineBlockBefore(block, op, blockArgs);
+ if (terminator->getNumBreakingControlRegions() < 1) {
+ rewriter.replaceOp(op, results);
+ rewriter.eraseOp(terminator);
+ } else {
+ Operation *toDelete = &op->getBlock()->back();
+ Operation *prevOp = toDelete;
+ do {
+ toDelete = prevOp;
+ prevOp = prevOp->getPrevNode();
+ rewriter.eraseOp(toDelete);
+ } while (toDelete != op);
+ }
+}
+
///
/// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
/// block+
@@ -311,6 +346,188 @@ void ConditionOp::getSuccessorRegions(
regions.push_back(RegionSuccessor::parent());
}
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Control Flow Op Utilities
+//===----------------------------------------------------------------------===//
+
+template <typename OpT>
+static ParseResult
+parseControlFlowRegion(OpAsmParser &p, Region &region,
+ ArrayRef<OpAsmParser::Argument> arguments = {}) {
+ if (failed(p.parseRegion(region, arguments)))
+ return failure();
+ OpT::ensureTerminator(region, p.getBuilder(),
+ p.getEncodedSourceLoc(p.getNameLoc()));
+ return success();
+}
+
+template <typename ImplicitTerminatorOpT, typename OpT>
+static void printControlFlowRegion(OpAsmPrinter &p, OpT op, Region &region) {
+ // We do not print the terminator if it is implicit and has no operands.
+ bool printBlockTerminators =
+ region.front().getTerminator()->getNumOperands() != 0 ||
+ !isa<ImplicitTerminatorOpT>(region.front().getTerminator());
+ p.printRegion(region, /*printEntryBlockArgs=*/false, printBlockTerminators);
+}
+
+LogicalResult ContinueOp::verify() {
+ if (getOperation()->getNumBreakingControlRegions() == 0)
+ return emitOpError(
+ "continue op must have at least one breaking control region");
+ return success();
+}
+
+MutableOperandRange
+ContinueOp::getMutableSuccessorOperands(RegionSuccessor point) {
+ return MutableOperandRange(getOperation());
+}
+
+LogicalResult LoopOp::verifyRegions() {
+ // Check matching between the operands and the region arguments.
+ if (getRegion().empty())
+ return emitOpError("region cannot be empty");
+ if (getRegion().front().getNumArguments() != getNumOperands())
+ return emitOpError(
+ "mismatch in number of loop-carried values and defined values");
+ for (auto [index, argAndOperand] : llvm::enumerate(
+ llvm::zip(getRegion().front().getArguments(), getOperands()))) {
+ auto argType = std::get<0>(argAndOperand).getType();
+ auto operandType = std::get<1>(argAndOperand).getType();
+ if (argType != operandType)
+ return emitOpError() << "types mismatch between " << index
+ << "th iter operand (" << operandType
+ << ") and defined region argument (" << argType
+ << ")";
+ }
+ return success();
+}
+
+void LoopOp::print(OpAsmPrinter &p) {
+ p << " ";
+ bool hasIters = !getInitValues().empty();
+ bool hasReturn = !getResultTypes().empty();
+
+ if (hasIters) {
+ p << "iter_args(";
+ llvm::interleaveComma(
+ llvm::zip(getRegionIterValues(), getInitValues()), p,
+ [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); });
+ p << ") : ";
+ p << getInitValues().getTypes();
+ p << " ";
+ }
+ if (hasReturn) {
+ p << "-> ";
+ p << getResultTypes();
+ p << " ";
+ }
+
+ printControlFlowRegion<ContinueOp>(p, *this, getRegion());
+ p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
+ SmallVector<OpAsmParser::Argument, 4> regionArgs;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> iterOperands;
+ SmallVector<Type, 4> iterTypes;
+
+ if (failed(parser.parseOptionalKeyword("iter_args"))) {
+ // no iter_args, but can still have a return type
+ if (succeeded(parser.parseOptionalArrow()))
+ if (parser.parseTypeList(result.types))
+ return failure();
+ } else {
+ // iter_args are present and must have colon followed by types
+ if (parser.parseAssignmentList(regionArgs, iterOperands) ||
+ parser.parseColon() || parser.parseTypeList(iterTypes))
+ return failure();
+ if (regionArgs.size() != iterTypes.size())
+ return parser.emitError(parser.getCurrentLocation(),
+ "found different number of iter_args and types");
+ // check for optional result type(s)
+ if (succeeded(parser.parseOptionalArrow()))
+ if (parser.parseTypeList(result.types))
+ return failure();
+ // Set region argument types for loop body
+ for (auto [regionArg, type] : llvm::zip_equal(regionArgs, iterTypes)) {
+ regionArg.type = type;
+ }
+ }
+
+ // Parse region and attr dict.
+ if (parseControlFlowRegion<LoopOp>(parser, *result.addRegion(), regionArgs) ||
+ parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
+ // Resolve operands.
+ if (parser.resolveOperands(iterOperands, iterTypes, parser.getNameLoc(),
+ result.operands))
+ return failure();
+
+ return success();
+}
+
+void LoopOp::getSuccessorRegions(RegionBranchPoint point,
+ SmallVectorImpl<RegionSuccessor> &regions) {
+ if (point.isParent()) {
+ regions.push_back(RegionSuccessor(&getRegion()));
+ return;
+ }
+
+ // Otherwise, it depends on the terminator: a continue branches back to the
+ // body and a break to the parent.
+ if (isa<ContinueOp>(point.getTerminatorPredecessorOrNull())) {
+ regions.push_back(RegionSuccessor(&getRegion()));
+ return;
+ }
+ assert(isa<BreakOp>(point.getTerminatorPredecessorOrNull()) &&
+ "expected continue or break terminator");
+
+ regions.push_back(RegionSuccessor::parent());
+}
+
+OperandRange LoopOp::getEntrySuccessorOperands(RegionSuccessor successor) {
+ return getInitValues();
+}
+
+ValueRange LoopOp::getSuccessorInputs(RegionSuccessor successor) {
+ return successor.isParent() ? ValueRange(getResults())
+ : ValueRange(getRegion().getArguments());
+}
+
+namespace {
+
+/// Rewriting pattern that erases loops that have a single iteration.
+struct SimplifyTrivialLoops : public OpRewritePattern<LoopOp> {
+ using OpRewritePattern<LoopOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(LoopOp op,
+ PatternRewriter &rewriter) const override {
+ // Terminator must be a break.
+ if (!isa<BreakOp>(op.getBody()->getTerminator()))
+ return rewriter.notifyMatchFailure(op, "loop terminator isn't a break");
+
+ // If it has nested predecessors, it can't be trivially simplified.
+ if (hasNestedPredecessors(op))
+ return rewriter.notifyMatchFailure(op, "has nested predecessors");
+
+ // Great: it is a single iteration loop, we can simplify it.
+ replaceOpWithRegion(rewriter, op, op.getRegion());
+
+ return success();
+ }
+};
+} // namespace
+
+void LoopOp::getCanonicalizationPatterns(RewritePatternSet &results,
+ MLIRContext *context) {
+ results.add<SimplifyTrivialLoops>(context);
+}
+
//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//
@@ -928,6 +1145,194 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
}
namespace {
+// Fold away ForOp iter arguments when:
+// 1) The op yields the iter arguments.
+// 2) The argument's corresponding outer region iterators (inputs) are yielded.
+// 3) The iter arguments have no use and the corresponding (operation) results
+// have no use.
+//
+// These arguments must be defined outside of the ForOp region and can just be
+// forwarded after simplifying the op inits, yields and returns.
+//
+// The implementation uses `inlineBlockBefore` to steal the content of the
+// original ForOp and avoid cloning.
+struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
+ using OpRewritePattern<scf::ForOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(scf::ForOp forOp,
+ PatternRewriter &rewriter) const final {
+ bool canonicalize = false;
+
+ // An internal flat vector of block transfer
+ // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
+ // transformed block argument mappings. This plays the role of a
+ // IRMapping for the particular use case of calling into
+ // `inlineBlockBefore`.
+ int64_t numResults = forOp.getNumResults();
+ SmallVector<bool, 4> keepMask;
+ keepMask.reserve(numResults);
+ SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
+ newResultValues;
+ newBlockTransferArgs.reserve(1 + numResults);
+ newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
+ newIterArgs.reserve(forOp.getInitArgs().size());
+ newYieldValues.reserve(numResults);
+ newResultValues.reserve(numResults);
+ DenseMap<std::pair<Value, Value>, std::pair<Value, Value>> initYieldToArg;
+ for (auto [init, arg, result, yielded] :
+ llvm::zip(forOp.getInitArgs(), // iter from outside
+ forOp.getRegionIterArgs(), // iter inside region
+ forOp.getResults(), // op results
+ forOp.getYieldedValues() // iter yield
+ )) {
+ // Forwarded is `true` when:
+ // 1) The region `iter` argument is yielded.
+ // 2) The input corresponding to the region `iter` argument is yielded.
+ // 3) The region `iter` argument has no use, and the corresponding op
+ // result has no use.
+ bool forwarded = (arg == yielded) || (init == yielded) ||
+ (arg.use_empty() && result.use_empty());
+ if (forwarded) {
+ canonicalize = true;
+ keepMask.push_back(false);
+ newBlockTransferArgs.push_back(init);
+ newResultValues.push_back(init);
+ continue;
+ }
+
+ // Check if a previous kept argument always has the same values for init
+ // and yielded values.
+ if (auto it = initYieldToArg.find({init, yielded});
+ it != initYieldToArg.end()) {
+ canonicalize = true;
+ keepMask.push_back(false);
+ auto [sameArg, sameResult] = it->second;
+ rewriter.replaceAllUsesWith(arg, sameArg);
+ rewriter.replaceAllUsesWith(result, sameResult);
+ // The replacement value doesn't matter because there are no uses.
+ newBlockTransferArgs.push_back(init);
+ newResultValues.push_back(init);
+ continue;
+ }
+
+ // This value is kept.
+ initYieldToArg.insert({{init, yielded}, {arg, result}});
+ keepMask.push_back(true);
+ newIterArgs.push_back(init);
+ newYieldValues.push_back(yielded);
+ newBlockTransferArgs.push_back(Value()); // placeholder with null value
+ newResultValues.push_back(Value()); // placeholder with null value
+ }
+
+ if (!canonicalize)
+ return failure();
+
+ scf::ForOp newForOp =
+ scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
+ forOp.getUpperBound(), forOp.getStep(), newIterArgs,
+ /*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
+ newForOp->setAttrs(forOp->getAttrs());
+ Block &newBlock = newForOp.getRegion().front();
+
+ // Replace the null placeholders with newly constructed values.
+ newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
+ for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
+ idx != e; ++idx) {
+ Value &blockTransferArg = newBlockTransferArgs[1 + idx];
+ Value &newResultVal = newResultValues[idx];
+ assert((blockTransferArg && newResultVal) ||
+ (!blockTransferArg && !newResultVal));
+ if (!blockTransferArg) {
+ blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
+ newResultVal = newForOp.getResult(collapsedIdx++);
+ }
+ }
+
+ Block &oldBlock = forOp.getRegion().front();
+ assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
+ "unexpected argument size mismatch");
+
+ // No results case: the scf::ForOp builder already created a zero
+ // result terminator. Merge before this terminator and just get rid of the
+ // original terminator that has been merged in.
+ if (newIterArgs.empty()) {
+ auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
+ rewriter.inlineBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
+ rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
+ rewriter.replaceOp(forOp, newResultValues);
+ return success();
+ }
+
+ // No terminator case: merge and rewrite the merged terminator.
+ auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(mergedTerminator);
+ SmallVector<Value, 4> filteredOperands;
+ filteredOperands.reserve(newResultValues.size());
+ for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
+ if (keepMask[idx])
+ filteredOperands.push_back(mergedTerminator.getOperand(idx));
+ scf::YieldOp::create(rewriter, mergedTerminator.getLoc(),
+ filteredOperands);
+ };
+
+ rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
+ auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
+ cloneFilteredTerminator(mergedYieldOp);
+ rewriter.eraseOp(mergedYieldOp);
+ rewriter.replaceOp(forOp, newResultValues);
+ return success();
+ }
+};
+
+/// Rewriting pattern that erases loops that are known not to iterate, replaces
+/// single-iteration loops with their bodies, and removes empty loops that
+/// iterate at least once and only return values defined outside of the loop.
+struct SimplifyTrivialForLoops : public OpRewritePattern<ForOp> {
+ using OpRewritePattern<ForOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(ForOp op,
+ PatternRewriter &rewriter) const override {
+ std::optional<APInt> tripCount = op.getStaticTripCount();
+ if (!tripCount.has_value())
+ return rewriter.notifyMatchFailure(op,
+ "can't compute constant trip count");
+
+ if (tripCount->isZero()) {
+ LDBG() << "SimplifyTrivialForLoops tripCount is 0 for loop "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ rewriter.replaceOp(op, op.getInitArgs());
+ return success();
+ }
+
+ if (tripCount->getSExtValue() == 1) {
+ LDBG() << "SimplifyTrivialForLoops tripCount is 1 for loop "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ SmallVector<Value, 4> blockArgs;
+ blockArgs.reserve(op.getInitArgs().size() + 1);
+ blockArgs.push_back(op.getLowerBound());
+ llvm::append_range(blockArgs, op.getInitArgs());
+ replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
+ return success();
+ }
+
+ // Now we are left with loops that have more than 1 iteration.
+ Block &block = op.getRegion().front();
+ if (!llvm::hasSingleElement(block))
+ return failure();
+ // The loop is empty and iterates at least once, if it only returns values
+ // defined outside of the loop, remove it and replace it with yield values.
+ if (llvm::any_of(op.getYieldedValues(),
+ [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
+ return failure();
+ LDBG() << "SimplifyTrivialForLoops empty body loop allows replacement with "
+ "yield operands for loop "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+ rewriter.replaceOp(op, op.getYieldedValues());
+ return success();
+ }
+};
+
/// Fold scf.for iter_arg/result pairs that go through incoming/ougoing
/// a tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
///
@@ -990,7 +1395,9 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ForOpTensorCastFolder>(context);
+ results
+ .add<ForOpIterArgsFolder, SimplifyTrivialForLoops, ForOpTensorCastFolder>(
+ context);
populateRegionBranchOpInterfaceCanonicalizationPatterns(
results, ForOp::getOperationName());
populateRegionBranchOpInterfaceInliningPattern(
@@ -1912,13 +2319,21 @@ IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
Region *r = &adaptor.getThenRegion();
if (r->empty())
return failure();
- Block &b = r->front();
- if (b.empty())
- return failure();
- auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
- if (!yieldOp)
+ Block *b = &r->front();
+ if (b->empty())
return failure();
- TypeRange types = yieldOp.getOperandTypes();
+ Operation *terminator = &b->back();
+ if (terminator->getNumBreakingControlRegions() > 1) {
+ if (adaptor.getElseRegion().empty())
+ return success();
+ b = &adaptor.getElseRegion().front();
+ if (b->empty())
+ return success();
+ terminator = &b->back();
+ if (terminator->getNumBreakingControlRegions() > 1)
+ return success();
+ }
+ TypeRange types = terminator->getOperandTypes();
llvm::append_range(inferredReturnTypes, types);
return success();
}
@@ -2043,7 +2458,9 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
}
void IfOp::print(OpAsmPrinter &p) {
- bool printBlockTerminators = false;
+ bool printBlockTerminators =
+ !isa<YieldOp>(thenBlock()->back()) ||
+ (elseBlock() && !isa<YieldOp>(elseBlock()->back()));
p << " " << getCondition();
if (!getResults().empty()) {
@@ -2073,6 +2490,15 @@ void IfOp::getSuccessorRegions(RegionBranchPoint point,
// The `then` and the `else` region branch back to the parent operation or one
// of the recursive parent operations (early exit case).
if (!point.isParent()) {
+ // Propagating breaks/continues (getNumBreakingControlRegions() > 1) pass
+ // through this if-op to reach an enclosing loop. Don't report parent() as
+ // a successor for them — they don't yield values to this if-op.
+ if (auto terminator = point.getTerminatorPredecessorOrNull()) {
+ Operation *op = terminator.getOperation();
+ if ((isa<BreakOp, ContinueOp>(op)) &&
+ op->getNumBreakingControlRegions() > 1)
+ return;
+ }
regions.push_back(RegionSuccessor::parent());
return;
}
@@ -2147,6 +2573,82 @@ void IfOp::getRegionInvocationBounds(
}
namespace {
+// Pattern to remove unused IfOp results.
+struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
+ using OpRewritePattern<IfOp>::OpRewritePattern;
+
+ void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
+ PatternRewriter &rewriter) const {
+ // Move all operations to the destination block.
+ rewriter.mergeBlocks(source, dest);
+ // Replace the yield op by one that returns only the used values.
+ auto yieldOp = dyn_cast<scf::YieldOp>(dest->getTerminator());
+ if (!yieldOp)
+ return;
+ SmallVector<Value, 4> usedOperands;
+ llvm::transform(usedResults, std::back_inserter(usedOperands),
+ [&](OpResult result) {
+ return yieldOp.getOperand(result.getResultNumber());
+ });
+ rewriter.modifyOpInPlace(yieldOp,
+ [&]() { yieldOp->setOperands(usedOperands); });
+ }
+
+ LogicalResult matchAndRewrite(IfOp op,
+ PatternRewriter &rewriter) const override {
+ // Compute the list of used results.
+ SmallVector<OpResult, 4> usedResults;
+ llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
+ [](OpResult result) { return !result.use_empty(); });
+
+ // Replace the operation if only a subset of its results have uses.
+ if (usedResults.size() == op.getNumResults())
+ return failure();
+
+ // Compute the result types of the replacement operation.
+ SmallVector<Type, 4> newTypes;
+ llvm::transform(usedResults, std::back_inserter(newTypes),
+ [](OpResult result) { return result.getType(); });
+
+ // Create a replacement operation with empty then and else regions.
+ auto newOp =
+ IfOp::create(rewriter, op.getLoc(), newTypes, op.getCondition());
+ rewriter.createBlock(&newOp.getThenRegion());
+ rewriter.createBlock(&newOp.getElseRegion());
+
+ // Move the bodies and replace the terminators (note there is a then and
+ // an else region since the operation returns results).
+ transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
+ transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);
+
+ // Replace the operation by the new one.
+ SmallVector<Value, 4> repResults(op.getNumResults());
+ for (const auto &en : llvm::enumerate(usedResults))
+ repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
+ rewriter.replaceOp(op, repResults);
+ return success();
+ }
+};
+
+struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
+ using OpRewritePattern<IfOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(IfOp op,
+ PatternRewriter &rewriter) const override {
+ BoolAttr condition;
+ if (!matchPattern(op.getCondition(), m_Constant(&condition)))
+ return failure();
+
+ if (condition.getValue())
+ replaceOpWithRegion(rewriter, op, op.getThenRegion());
+ else if (!op.getElseRegion().empty())
+ replaceOpWithRegion(rewriter, op, op.getElseRegion());
+ else
+ rewriter.eraseOp(op);
+ return success();
+ }
+};
+
/// Hoist any yielded results whose operands are defined outside
/// the if, to a select instruction.
struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
@@ -2157,9 +2659,14 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
if (op->getNumResults() == 0)
return failure();
+ YieldOp thenYield = dyn_cast<YieldOp>(op.thenTerminator());
+ YieldOp elseYield = dyn_cast<YieldOp>(op.elseTerminator());
+ if (!thenYield || !elseYield)
+ return failure();
+
auto cond = op.getCondition();
- auto thenYieldArgs = op.thenYield().getOperands();
- auto elseYieldArgs = op.elseYield().getOperands();
+ auto thenYieldArgs = thenYield.getOperands();
+ auto elseYieldArgs = elseYield.getOperands();
SmallVector<Type> nonHoistable;
for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
@@ -2203,10 +2710,12 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
}
rewriter.setInsertionPointToEnd(replacement.thenBlock());
- rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenYield(), trueYields);
+ rewriter.replaceOpWithNewOp<YieldOp>(replacement.thenTerminator(),
+ trueYields);
rewriter.setInsertionPointToEnd(replacement.elseBlock());
- rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseYield(), falseYields);
+ rewriter.replaceOpWithNewOp<YieldOp>(replacement.elseTerminator(),
+ falseYields);
rewriter.replaceOp(op, results);
return success();
@@ -2358,36 +2867,35 @@ struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
if (op.getNumResults() == 0)
return failure();
- auto trueYield =
- cast<scf::YieldOp>(op.getThenRegion().back().getTerminator());
- auto falseYield =
- cast<scf::YieldOp>(op.getElseRegion().back().getTerminator());
+ YieldOp thenYield = dyn_cast<YieldOp>(op.thenTerminator());
+ YieldOp elseYield = dyn_cast<YieldOp>(op.elseTerminator());
+ if (!thenYield || !elseYield)
+ return failure();
rewriter.setInsertionPoint(op->getBlock(),
op.getOperation()->getIterator());
bool changed = false;
Type i1Ty = rewriter.getI1Type();
- for (auto [trueResult, falseResult, opResult] :
- llvm::zip(trueYield.getResults(), falseYield.getResults(),
- op.getResults())) {
- if (trueResult == falseResult) {
+ for (auto [thenResult, elseResult, opResult] : llvm::zip(
+ thenYield.getResults(), elseYield.getResults(), op.getResults())) {
+ if (thenResult == elseResult) {
if (!opResult.use_empty()) {
- opResult.replaceAllUsesWith(trueResult);
+ opResult.replaceAllUsesWith(thenResult);
changed = true;
}
continue;
}
- BoolAttr trueYield, falseYield;
- if (!matchPattern(trueResult, m_Constant(&trueYield)) ||
- !matchPattern(falseResult, m_Constant(&falseYield)))
+ BoolAttr thenYield, elseYield;
+ if (!matchPattern(thenResult, m_Constant(&thenYield)) ||
+ !matchPattern(elseResult, m_Constant(&elseYield)))
continue;
- bool trueVal = trueYield.getValue();
- bool falseVal = falseYield.getValue();
- if (!trueVal && falseVal) {
+ bool thenVal = thenYield.getValue();
+ bool elseVal = elseYield.getValue();
+ if (!thenVal && elseVal) {
if (!opResult.use_empty()) {
- Dialect *constDialect = trueResult.getDefiningOp()->getDialect();
+ Dialect *constDialect = thenResult.getDefiningOp()->getDialect();
Value notCond = arith::XOrIOp::create(
rewriter, op.getLoc(), op.getCondition(),
constDialect
@@ -2399,7 +2907,7 @@ struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
changed = true;
}
}
- if (trueVal && !falseVal) {
+ if (thenVal && !elseVal) {
if (!opResult.use_empty()) {
opResult.replaceAllUsesWith(op.getCondition());
changed = true;
@@ -2454,18 +2962,16 @@ struct CombineIfs : public OpRewritePattern<IfOp> {
nextThen = nextIf.thenBlock();
if (!nextIf.getElseRegion().empty())
nextElse = nextIf.elseBlock();
- }
- if (arith::XOrIOp notv =
- nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
+ } else if (arith::XOrIOp notv =
+ nextIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
if (notv.getLhs() == prevIf.getCondition() &&
matchPattern(notv.getRhs(), m_One())) {
nextElse = nextIf.thenBlock();
if (!nextIf.getElseRegion().empty())
nextThen = nextIf.elseBlock();
}
- }
- if (arith::XOrIOp notv =
- prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
+ } else if (arith::XOrIOp notv =
+ prevIf.getCondition().getDefiningOp<arith::XOrIOp>()) {
if (notv.getLhs() == nextIf.getCondition() &&
matchPattern(notv.getRhs(), m_One())) {
nextElse = nextIf.thenBlock();
@@ -2476,14 +2982,25 @@ struct CombineIfs : public OpRewritePattern<IfOp> {
if (!nextThen && !nextElse)
return failure();
+ // Check that the terminators are all YieldOp
+ if (!isa<YieldOp>(prevIf.thenTerminator()) ||
+ (nextThen && !isa<YieldOp>(nextThen->getTerminator())))
+ return failure();
+ if (!prevIf.getElseRegion().empty() &&
+ !isa<YieldOp>(prevIf.elseTerminator()))
+ return failure();
+ if (nextElse && !nextElse->empty() &&
+ !isa<YieldOp>(nextElse->getTerminator()))
+ return failure();
SmallVector<Value> prevElseYielded;
if (!prevIf.getElseRegion().empty())
- prevElseYielded = prevIf.elseYield().getOperands();
+ prevElseYielded = prevIf.elseTerminator()->getOperands();
// Replace all uses of return values of op within nextIf with the
// corresponding yields
- for (auto it : llvm::zip(prevIf.getResults(),
- prevIf.thenYield().getOperands(), prevElseYielded))
+ for (auto it :
+ llvm::zip(prevIf.getResults(), prevIf.thenTerminator()->getOperands(),
+ prevElseYielded))
for (OpOperand &use :
llvm::make_early_inc_range(std::get<0>(it).getUses())) {
if (nextThen && nextThen->getParent()->isAncestor(
@@ -2511,16 +3028,16 @@ struct CombineIfs : public OpRewritePattern<IfOp> {
combinedIf.getThenRegion().begin());
if (nextThen) {
- YieldOp thenYield = combinedIf.thenYield();
- YieldOp thenYield2 = cast<YieldOp>(nextThen->getTerminator());
+ Operation *thenTerminator = combinedIf.thenTerminator();
+ Operation *thenTerminator2 = nextThen->getTerminator();
rewriter.mergeBlocks(nextThen, combinedIf.thenBlock());
rewriter.setInsertionPointToEnd(combinedIf.thenBlock());
- SmallVector<Value> mergedYields(thenYield.getOperands());
- llvm::append_range(mergedYields, thenYield2.getOperands());
- YieldOp::create(rewriter, thenYield2.getLoc(), mergedYields);
- rewriter.eraseOp(thenYield);
- rewriter.eraseOp(thenYield2);
+ SmallVector<Value> mergedYields(thenTerminator->getOperands());
+ llvm::append_range(mergedYields, thenTerminator2->getOperands());
+ YieldOp::create(rewriter, thenTerminator->getLoc(), mergedYields);
+ rewriter.eraseOp(thenTerminator);
+ rewriter.eraseOp(thenTerminator2);
}
rewriter.inlineRegionBefore(prevIf.getElseRegion(),
@@ -2533,18 +3050,17 @@ struct CombineIfs : public OpRewritePattern<IfOp> {
combinedIf.getElseRegion(),
combinedIf.getElseRegion().begin());
} else {
- YieldOp elseYield = combinedIf.elseYield();
- YieldOp elseYield2 = cast<YieldOp>(nextElse->getTerminator());
+ Operation *elseTerminator = combinedIf.elseTerminator();
+ Operation *elseTerminator2 = nextElse->getTerminator();
rewriter.mergeBlocks(nextElse, combinedIf.elseBlock());
-
rewriter.setInsertionPointToEnd(combinedIf.elseBlock());
- SmallVector<Value> mergedElseYields(elseYield.getOperands());
- llvm::append_range(mergedElseYields, elseYield2.getOperands());
+ SmallVector<Value> mergedElseYields(elseTerminator->getOperands());
+ llvm::append_range(mergedElseYields, elseTerminator2->getOperands());
- YieldOp::create(rewriter, elseYield2.getLoc(), mergedElseYields);
- rewriter.eraseOp(elseYield);
- rewriter.eraseOp(elseYield2);
+ YieldOp::create(rewriter, elseTerminator->getLoc(), mergedElseYields);
+ rewriter.eraseOp(elseTerminator);
+ rewriter.eraseOp(elseTerminator2);
}
}
@@ -2572,7 +3088,8 @@ struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
if (ifOp.getNumResults())
return failure();
Block *elseBlock = ifOp.elseBlock();
- if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
+ if (!elseBlock || (!llvm::hasSingleElement(*elseBlock) ||
+ !isa<YieldOp>(elseBlock->getTerminator())))
return failure();
auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
rewriter.inlineRegionBefore(ifOp.getThenRegion(), newIfOp.getThenRegion(),
@@ -2608,21 +3125,32 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
if (!llvm::hasSingleElement(nestedOps))
return failure();
+ auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
+ if (!nestedIf)
+ return failure();
+
+ // Terminator must be a YieldOp
+ if (!isa<YieldOp>(op.thenTerminator()))
+ return failure();
+
// If there is an else block, it can only yield
- if (op.elseBlock() && !llvm::hasSingleElement(*op.elseBlock()))
+ if (op.elseBlock() && (!llvm::hasSingleElement(*op.elseBlock()) ||
+ !isa<YieldOp>(op.elseTerminator())))
return failure();
- auto nestedIf = dyn_cast<IfOp>(*nestedOps.begin());
- if (!nestedIf)
+ // Same for the nested if: the then and else blocks can only yield.
+ if (!isa<YieldOp>(nestedIf.thenTerminator()))
return failure();
- if (nestedIf.elseBlock() && !llvm::hasSingleElement(*nestedIf.elseBlock()))
+ if (nestedIf.elseBlock() &&
+ (!llvm::hasSingleElement(*nestedIf.elseBlock()) ||
+ !isa<YieldOp>(nestedIf.elseTerminator())))
return failure();
- SmallVector<Value> thenYield(op.thenYield().getOperands());
- SmallVector<Value> elseYield;
+ SmallVector<Value> thenTerminator(op.thenTerminator()->getOperands());
+ SmallVector<Value> elseTerminator;
if (op.elseBlock())
- llvm::append_range(elseYield, op.elseYield().getOperands());
+ llvm::append_range(elseTerminator, op.elseTerminator()->getOperands());
// A list of indices for which we should upgrade the value yielded
// in the else to a select.
@@ -2632,19 +3160,20 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
// only permit combining if the value yielded when the condition
// is false in the outer scf.if is the same value yielded when the
// inner scf.if condition is false.
- // Note that the array access to elseYield will not go out of bounds
- // since it must have the same length as thenYield, since they both
+ // Note that the array access to elseTerminator will not go out of bounds
+ // since it must have the same length as thenTerminator, since they both
// come from the same scf.if.
- for (const auto &tup : llvm::enumerate(thenYield)) {
+ for (const auto &tup : llvm::enumerate(thenTerminator)) {
if (tup.value().getDefiningOp() == nestedIf) {
auto nestedIdx = llvm::cast<OpResult>(tup.value()).getResultNumber();
- if (nestedIf.elseYield().getOperand(nestedIdx) !=
- elseYield[tup.index()]) {
+ if (nestedIf.elseTerminator()->getOperand(nestedIdx) !=
+ elseTerminator[tup.index()]) {
return failure();
}
// If the correctness test passes, we will yield
// corresponding value from the inner scf.if
- thenYield[tup.index()] = nestedIf.thenYield().getOperand(nestedIdx);
+ thenTerminator[tup.index()] =
+ nestedIf.thenTerminator()->getOperand(nestedIdx);
continue;
}
@@ -2676,28 +3205,90 @@ struct CombineNestedIfs : public OpRewritePattern<IfOp> {
for (auto idx : elseYieldsToUpgradeToSelect)
results[idx] =
arith::SelectOp::create(rewriter, op.getLoc(), op.getCondition(),
- thenYield[idx], elseYield[idx]);
+ thenTerminator[idx], elseTerminator[idx]);
rewriter.mergeBlocks(nestedIf.thenBlock(), newIfBlock);
rewriter.setInsertionPointToEnd(newIf.thenBlock());
- rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenYield(), thenYield);
- if (!elseYield.empty()) {
+ rewriter.replaceOpWithNewOp<YieldOp>(newIf.thenTerminator(),
+ thenTerminator);
+ if (!elseTerminator.empty()) {
rewriter.createBlock(&newIf.getElseRegion());
rewriter.setInsertionPointToEnd(newIf.elseBlock());
- YieldOp::create(rewriter, loc, elseYield);
+ YieldOp::create(rewriter, loc, elseTerminator);
}
rewriter.replaceOp(op, results);
return success();
}
};
+/// Simplify if with breaking control flow in both branches.
+/// For example:
+/// scf.if %cmp {
+/// scf.break 2 %arg1
+/// } else {
+/// scf.continue 2
+/// }
+/// print(...) // This is dead code
+/// becomes
+/// scf.if %cmp {
+/// scf.break 2 %arg1
+/// }
+/// scf.continue 1
+struct SimplifyIfWithBreakingControlFlowInBothBranches
+ : public OpRewritePattern<IfOp> {
+ using OpRewritePattern<IfOp>::OpRewritePattern;
+ LogicalResult matchAndRewrite(IfOp op,
+ PatternRewriter &rewriter) const override {
+ if (op.getElseRegion().empty() || isa<YieldOp>(op.thenTerminator()) ||
+ isa<YieldOp>(op.elseTerminator()))
+ return failure();
+
+ // Inline the else block after the current op and erase everything after.
+ Block *block = op.elseBlock();
+
+ // Reduce the level of breaking control flow ops by 1 since we inline the
+ // region.
+ visitNestedBreakingControlFlowOps(
+ op.getElseRegion(), [&](Operation *visitedOp, int nestedLevel) {
+ if (nestedLevel <=
+ static_cast<int>(visitedOp->getNumBreakingControlRegions()))
+ visitedOp->setNumBreakingControlRegions(
+ visitedOp->getNumBreakingControlRegions() - 1);
+ });
+ Operation *terminator = block->getTerminator();
+ terminator->setNumBreakingControlRegions(
+ terminator->getNumBreakingControlRegions() - 1);
+ // Inline the else block after the current op
+ rewriter.inlineBlockBefore(block, op->getNextNode());
+
+ // Erase everything after the inline block.
+ Operation *toDelete = &op->getBlock()->back();
+ Operation *prevOp = toDelete;
+ do {
+ toDelete = prevOp;
+ prevOp = prevOp->getPrevNode();
+ rewriter.eraseOp(toDelete);
+ } while (prevOp != terminator);
+
+ // The "else" region is now empty; clone the if op and inline the then
+ // region.
+ auto newIfOp = rewriter.cloneWithoutRegions(op);
+ rewriter.inlineRegionBefore(op.getThenRegion(), newIfOp.getThenRegion(),
+ newIfOp.getThenRegion().begin());
+ rewriter.eraseOp(op);
+
+ return success();
+ }
+};
} // namespace
void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CombineIfs, CombineNestedIfs, ConditionPropagation,
ConvertTrivialIfToSelect, RemoveEmptyElseBranch,
- ReplaceIfYieldWithConditionOrValue>(context);
+ RemoveStaticCondition, RemoveUnusedResults,
+ ReplaceIfYieldWithConditionOrValue,
+ SimplifyIfWithBreakingControlFlowInBothBranches>(context);
populateRegionBranchOpInterfaceCanonicalizationPatterns(
results, IfOp::getOperationName());
populateRegionBranchOpInterfaceInliningPattern(results,
@@ -3107,10 +3698,13 @@ void ParallelOp::getSuccessorRegions(
// ReduceOp
//===----------------------------------------------------------------------===//
-void ReduceOp::build(OpBuilder &builder, OperationState &result) {}
+void ReduceOp::build(OpBuilder &builder, OperationState &result) {
+ result.setNumBreakingControlRegions(1);
+}
void ReduceOp::build(OpBuilder &builder, OperationState &result,
ValueRange operands) {
+ result.setNumBreakingControlRegions(1);
result.addOperands(operands);
for (Value v : operands) {
OpBuilder::InsertionGuard guard(builder);
@@ -3439,8 +4033,8 @@ struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
auto it = llvm::find(ifOp->getResults(), arg);
if (it != ifOp->getResults().end()) {
size_t ifOpIdx = it.getIndex();
- Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx);
- Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx);
+ Value thenValue = ifOp.thenTerminator()->getOperand(ifOpIdx);
+ Value elseValue = ifOp.elseTerminator()->getOperand(ifOpIdx);
rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue);
rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue);
@@ -3488,7 +4082,7 @@ struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> {
});
// Inline ifOp then region into new whileOp after region.
- rewriter.eraseOp(ifOp.thenYield());
+ rewriter.eraseOp(ifOp.thenTerminator());
rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(),
newWhileOp.getAfterBody()->begin());
rewriter.eraseOp(ifOp);
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 496a7b036e65d..8a0b7f955267e 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -179,8 +179,8 @@ struct IfOpInterface
std::optional<int64_t> dim,
ValueBoundsConstraintSet &cstr) {
unsigned int resultNum = cast<OpResult>(value).getResultNumber();
- Value thenValue = ifOp.thenYield().getResults()[resultNum];
- Value elseValue = ifOp.elseYield().getResults()[resultNum];
+ Value thenValue = ifOp.thenTerminator()->getOperand(resultNum);
+ Value elseValue = ifOp.elseTerminator()->getOperand(resultNum);
auto boundsBuilder = cstr.bound(value);
if (dim)
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 9b6a5a96fbc6b..f913251b31a1a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -232,8 +232,8 @@ struct IfOpInterface
auto ifOp = cast<scf::IfOp>(op);
size_t resultNum = std::distance(op->getOpResults().begin(),
llvm::find(op->getOpResults(), value));
- OpOperand *thenOperand = &ifOp.thenYield()->getOpOperand(resultNum);
- OpOperand *elseOperand = &ifOp.elseYield()->getOpOperand(resultNum);
+ OpOperand *thenOperand = &ifOp.thenTerminator()->getOpOperand(resultNum);
+ OpOperand *elseOperand = &ifOp.elseTerminator()->getOpOperand(resultNum);
return {{thenOperand, BufferRelation::Equivalent, /*isDefinite=*/false},
{elseOperand, BufferRelation::Equivalent, /*isDefinite=*/false}};
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
index 4b131333b956a..1b632e4853d51 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
@@ -85,7 +85,8 @@ class ConvertGenericOpWithIntegerTensorType : public ConversionPattern {
// Create new op with replaced operands and results
auto *newOp = Operation::create(
op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
- op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
+ op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions(),
+ op->getNumBreakingControlRegions());
// Handle regions in e.g. tosa.cond_if and tosa.while_loop
for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 81455699421cc..19073d84a4441 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -3759,6 +3759,10 @@ void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
os << ')';
+ if (op->getNumBreakingControlRegions() != 0) {
+ os << " [" << op->getNumBreakingControlRegions() << "]";
+ }
+
// For terminators, print the list of successors and their operands.
if (op->getNumSuccessors() != 0) {
os << '[';
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index f4c9242ed3479..879be5a204c5d 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -138,8 +138,8 @@ Diagnostic &Diagnostic::operator<<(Operation &op) {
return appendOp(op, OpPrintingFlags());
}
-Diagnostic &Diagnostic::operator<<(OpWithFlags op) {
- return appendOp(*op.getOperation(), op.flags());
+Diagnostic &Diagnostic::operator<<(const OpWithFlags &opWithFlags) {
+ return appendOp(*opWithFlags.getOperation(), opWithFlags.flags());
}
Diagnostic &Diagnostic::appendOp(Operation &op, const OpPrintingFlags &flags) {
diff --git a/mlir/lib/IR/Dominance.cpp b/mlir/lib/IR/Dominance.cpp
index 0e53b431b5d31..99a7d7d45a1f8 100644
--- a/mlir/lib/IR/Dominance.cpp
+++ b/mlir/lib/IR/Dominance.cpp
@@ -14,8 +14,11 @@
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/RegionKindInterface.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/GenericDomTreeConstruction.h"
+#define DEBUG_TYPE "dominance"
+
using namespace mlir;
using namespace mlir::detail;
@@ -289,6 +292,32 @@ bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl(
// regions kinds, uses and defs can come in any order inside a block.
if (!hasSSADominance(aBlock))
return true;
+
+ // Any operation that propagates a control flow break invalidates the
+ // post-dominance relation.
+ LDBG() << "IsPostDom: " << IsPostDom
+ << " aIt != aBlock->end(): " << (aIt != aBlock->end())
+ << " bIt != bBlock->end(): " << (bIt != bBlock->end())
+ << " hasBreakingControlFlow(aBlock->getParent()): "
+ << hasBreakingControlFlow(aBlock->getParent());
+ if (IsPostDom && aIt != aBlock->end() && bIt != bBlock->end() &&
+ hasBreakingControlFlow(aBlock->getParent())) {
+ bool inRange = false;
+ for (Operation &op : *aBlock) {
+ if (&op == &*bIt)
+ inRange = true;
+ if (inRange) {
+ if (op.hasTrait<OpTrait::PropagateControlFlowBreak>() &&
+ hasBreakingControlFlowOps(&op)) {
+ LDBG() << "Breaking control flow: "
+ << OpWithFlags(&op, OpPrintingFlags().skipRegions());
+ return false;
+ }
+ }
+ if (&op == &*aIt)
+ break;
+ }
+ }
if constexpr (IsPostDom) {
return isBeforeInBlock(aBlock, bIt, aIt);
} else {
@@ -296,6 +325,7 @@ bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl(
}
}
+ // TODO: this should handle breaks in the block. This is not yet implemented.
// If the blocks are different, use DomTree to resolve the query.
return getDomTree(aRegion).properlyDominates(aBlock, bBlock);
}
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index bf8a918641dfb..c7f8d8671600e 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -13,6 +13,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
@@ -32,10 +33,10 @@ using namespace mlir;
/// Create a new Operation from operation state.
Operation *Operation::create(const OperationState &state) {
- Operation *op =
- create(state.location, state.name, state.types, state.operands,
- state.attributes.getDictionary(state.getContext()),
- state.properties, state.successors, state.regions);
+ Operation *op = create(
+ state.location, state.name, state.types, state.operands,
+ state.attributes.getDictionary(state.getContext()), state.properties,
+ state.successors, state.regions, state.numBreakingControlRegions);
if (LLVM_UNLIKELY(state.propertiesAttr)) {
assert(!state.properties);
LogicalResult result =
@@ -52,11 +53,12 @@ Operation *Operation::create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
NamedAttrList &&attributes,
OpaqueProperties properties, BlockRange successors,
- RegionRange regions) {
+ RegionRange regions,
+ unsigned numBreakingControlRegions) {
unsigned numRegions = regions.size();
Operation *op =
create(location, name, resultTypes, operands, std::move(attributes),
- properties, successors, numRegions);
+ properties, successors, numRegions, numBreakingControlRegions);
for (unsigned i = 0; i < numRegions; ++i)
if (regions[i])
op->getRegion(i).takeBody(*regions[i]);
@@ -68,13 +70,14 @@ Operation *Operation::create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
NamedAttrList &&attributes,
OpaqueProperties properties, BlockRange successors,
- unsigned numRegions) {
+ unsigned numRegions,
+ unsigned numBreakingControlRegions) {
// Populate default attributes.
name.populateDefaultAttrs(attributes);
return create(location, name, resultTypes, operands,
attributes.getDictionary(location.getContext()), properties,
- successors, numRegions);
+ successors, numRegions, numBreakingControlRegions);
}
/// Overload of create that takes an existing DictionaryAttr to avoid
@@ -83,7 +86,8 @@ Operation *Operation::create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
DictionaryAttr attributes,
OpaqueProperties properties, BlockRange successors,
- unsigned numRegions) {
+ unsigned numRegions,
+ unsigned numBreakingControlRegions) {
assert(llvm::all_of(resultTypes, [](Type t) { return t; }) &&
"unexpected null result type");
@@ -93,7 +97,10 @@ Operation *Operation::create(Location location, OperationName name,
unsigned numSuccessors = successors.size();
unsigned numOperands = operands.size();
unsigned numResults = resultTypes.size();
- int opPropertiesAllocSize = llvm::alignTo<8>(name.getOpPropertyByteSize());
+ size_t opPropertiesByteSize = name.getOpPropertyByteSize();
+ if (numBreakingControlRegions)
+ opPropertiesByteSize += 8;
+ int opPropertiesAllocSize = llvm::alignTo<8>(opPropertiesByteSize);
// If the operation is known to have no operands, don't allocate an operand
// storage.
@@ -115,12 +122,16 @@ Operation *Operation::create(Location location, OperationName name,
void *rawMem = mallocMem + prefixByteSize;
// Create the new Operation.
- Operation *op = ::new (rawMem) Operation(
- location, name, numResults, numSuccessors, numRegions,
- opPropertiesAllocSize, attributes, properties, needsOperandStorage);
+ Operation *op = ::new (rawMem)
+ Operation(location, name, numResults, numSuccessors, numRegions,
+ numBreakingControlRegions, opPropertiesAllocSize, attributes,
+ properties, needsOperandStorage);
assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) &&
"unexpected successors in a non-terminator operation");
+ assert((numBreakingControlRegions == 0 ||
+ op->mightHaveTrait<OpTrait::IsTerminator>()) &&
+ "unexpected breaking control regions in a non-terminator operation");
// Initialize the results.
auto resultTypeIt = resultTypes.begin();
@@ -154,10 +165,12 @@ Operation *Operation::create(Location location, OperationName name,
Operation::Operation(Location location, OperationName name, unsigned numResults,
unsigned numSuccessors, unsigned numRegions,
+ unsigned numBreakingControlRegions,
int fullPropertiesStorageSize, DictionaryAttr attributes,
OpaqueProperties properties, bool hasOperandStorage)
: location(location), numResults(numResults), numSuccs(numSuccessors),
numRegions(numRegions), hasOperandStorage(hasOperandStorage),
+ isBreakingControlFlowFlag(numBreakingControlRegions),
propertiesStorageSize((fullPropertiesStorageSize + 7) / 8), name(name) {
assert(attributes && "unexpected null attribute dictionary");
assert(fullPropertiesStorageSize <= propertiesCapacity &&
@@ -170,6 +183,9 @@ Operation::Operation(Location location, OperationName name, unsigned numResults,
"allowUnregisteredDialects() on the MLIRContext, or use "
"-allow-unregistered-dialect with the MLIR tool used.");
#endif
+ if (numBreakingControlRegions)
+ *reinterpret_cast<unsigned *>(getTrailingObjects<detail::OpProperties>()) =
+ numBreakingControlRegions;
if (fullPropertiesStorageSize)
name.initOpProperties(getPropertiesStorage(), properties);
}
@@ -729,7 +745,8 @@ Operation *Operation::clone(IRMapping &mapper, CloneOptions options) {
// Create the new operation.
auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs,
- getPropertiesStorage(), successors, getNumRegions());
+ getPropertiesStorage(), successors, getNumRegions(),
+ getNumBreakingControlRegions());
mapper.map(this, newOp);
// Clone the regions.
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index cd067f2cc25b3..c0ae9af2f5083 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -10,6 +10,9 @@
#include "mlir/IR/Iterators.h"
#include "mlir/IR/RegionKindInterface.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "pattern-match"
using namespace mlir;
@@ -225,7 +228,7 @@ void RewriterBase::eraseOp(Operation *op) {
// Then erase the enclosing op.
eraseSingleOp(op);
};
-
+ LDBG() << "RewriterBase::eraseOp: " << *op;
eraseTree(op);
}
diff --git a/mlir/lib/IR/RegionKindInterface.cpp b/mlir/lib/IR/RegionKindInterface.cpp
index 007f4cf92dbc7..8a4daa22c6d64 100644
--- a/mlir/lib/IR/RegionKindInterface.cpp
+++ b/mlir/lib/IR/RegionKindInterface.cpp
@@ -12,6 +12,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/RegionKindInterface.h"
+#include "mlir/Support/WalkResult.h"
+
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "region-kind-interface"
using namespace mlir;
@@ -32,3 +37,154 @@ bool mlir::mayBeGraphRegion(Region &region) {
return false;
return !regionKindOp.hasSSADominance(region.getRegionNumber());
}
+
+namespace {
+// Iterator on all reachable operations in the region.
+// Also keep track if we visited the nested regions of the current op
+// already to drive the traversal.
+struct NestedOpIterator {
+ NestedOpIterator(Region *region, int nestedLevel)
+ : region(region), nestedLevel(nestedLevel) {
+ regionIt = region->begin();
+ blockIt = regionIt->end();
+ if (regionIt != region->end())
+ blockIt = regionIt->begin();
+ }
+ // Advance the iterator to the next reachable operation.
+ void advance() {
+ assert(regionIt != region->end());
+ if (blockIt == regionIt->end()) {
+ ++regionIt;
+ if (regionIt != region->end())
+ blockIt = regionIt->begin();
+ return;
+ }
+ ++blockIt;
+ if (blockIt != regionIt->end()) {
+ LDBG() << this << " - Incrementing block iterator, next op: "
+ << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
+ }
+ }
+
+ // The region we're iterating over.
+ Region *region;
+ // The Block currently being iterated over.
+ Region::iterator regionIt;
+ // The Operation currently being iterated over.
+ Block::iterator blockIt;
+ // The nested level of the current region relative to the starting region.
+ int nestedLevel = 0;
+};
+} // namespace
+
+/// Worklist-driven walk that calls the callback only for the last operation of
+/// each block (its terminator), passing the nesting level relative to `rootOp`.
+static void walk(Operation *rootOp,
+ function_ref<WalkResult(Operation *, int)> callback) {
+ // Worklist of regions to visit to drive the traversal.
+ SmallVector<NestedOpIterator> worklist;
+
+ // Perform a traversal of the regions, visiting each
+ // reachable operation.
+ for (Region &region : rootOp->getRegions()) {
+ if (region.empty())
+ continue;
+ worklist.push_back({&region, 1});
+ }
+ while (!worklist.empty()) {
+ NestedOpIterator &it = worklist.back();
+ if (it.regionIt == it.region->end()) {
+ // We're done with this region.
+ worklist.pop_back();
+ continue;
+ }
+ if (it.blockIt == it.regionIt->end()) {
+ // We're done with this block.
+ it.advance();
+ continue;
+ }
+ Operation *op = &*it.blockIt;
+
+ // Only call the callback if we're at the end of the block.
+ if (std::next(it.blockIt) == it.regionIt->end() &&
+ callback(op, it.nestedLevel).wasInterrupted())
+ return;
+
+ // Advance before pushing nested regions to avoid reference invalidation.
+ int currentNestedLevel = it.nestedLevel;
+ it.advance();
+
+ // Recursively visit the nested regions.
+ for (Region &nestedRegion : op->getRegions()) {
+ if (nestedRegion.empty())
+ continue;
+ worklist.push_back({&nestedRegion, currentNestedLevel + 1});
+ }
+ }
+}
+
+/// Return true if `op` has at least one RegionTerminator nested inside it
+/// that directly targets `op` as its control-flow destination. A terminator
+/// directly targets `op` when its num-breaking-regions equals the nesting
+/// depth at which it appears inside `op`'s regions, AND that depth is > 1
+/// (depth 1 would mean the terminator exits only the immediately enclosing
+/// region, going to `op`'s parent rather than `op` itself — that case is
+/// handled by the normal RegionBranchOpInterface path).
+bool mlir::hasNestedPredecessors(Operation *op) {
+ bool found = false;
+ walk(op, [&](Operation *visitedOp, int nestedLevel) {
+ if (nestedLevel > 1 &&
+ nestedLevel ==
+ static_cast<int>(visitedOp->getNumBreakingControlRegions()))
+ found = true;
+ return found ? WalkResult::interrupt() : WalkResult::advance();
+ });
+ return found;
+}
+
+/// Return true if `op` contains any RegionTerminator whose num-breaking-
+/// regions value would carry it *past* `op` toward an outer ancestor. Such a
+/// terminator's nestedLevel (depth relative to `op`'s body) is strictly less
+/// than its num-breaking-regions, meaning `op` is one of the intermediate
+/// PropagateControlFlowBreak ops that is bypassed by this early exit.
+bool mlir::hasBreakingControlFlowOps(Operation *op) {
+ bool found = false;
+ walk(op, [&](Operation *visitedOp, int nestedLevel) {
+ if (nestedLevel <
+ static_cast<int>(visitedOp->getNumBreakingControlRegions()))
+ found = true;
+ return found ? WalkResult::interrupt() : WalkResult::advance();
+ });
+ return found;
+}
+
+/// Invoke `callback` for every RegionTerminator inside `op` whose
+/// num-breaking-regions is >= its current nesting depth (i.e. the terminator
+/// either terminates directly into `op` or propagates further upward). The
+/// `nestedLevel` passed to the callback is the 1-based depth of the terminator
+/// relative to `op`'s outermost region.
+void mlir::detail::visitNestedBreakingControlFlowOpsImpl(
+ Operation *op,
+ function_ref<WalkResult(Operation *, int nestedLevel)> callback) {
+ ::walk(op, [&](Operation *visitedOp, int nestedLevel) {
+ if (nestedLevel <=
+ static_cast<int>(visitedOp->getNumBreakingControlRegions()))
+ return callback(visitedOp, nestedLevel);
+ return WalkResult::advance();
+ });
+}
+
+/// Collect all RegionTerminator ops nested inside `op` that directly target
+/// `op` as their control-flow destination (num-breaking-regions ==
+/// nestedLevel). These are the ops that transfer control to `op` on an early
+/// exit path; they are the "nested predecessors" of `op`.
+void mlir::collectAllNestedPredecessors(
+ Operation *op, SmallVector<Operation *> &predecessors) {
+ visitNestedBreakingControlFlowOps(
+ op, [&](Operation *visitedOp, int nestedLevel) {
+ if (nestedLevel ==
+ static_cast<int>(visitedOp->getNumBreakingControlRegions()))
+ predecessors.push_back(visitedOp);
+ return WalkResult::advance();
+ });
+}
\ No newline at end of file
diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
index e19537a901d18..3fe50f459d55c 100644
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -29,6 +29,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/IR/Threading.h"
#include "llvm/ADT/PointerIntPair.h"
@@ -129,6 +130,46 @@ LogicalResult OperationVerifier::verifyOnEntrance(Block &block) {
if (op.getNumSuccessors() != 0 && &op != &block.back())
return op.emitError(
"operation with block successors must terminate its parent block");
+
+ Operation *currentOp = &op;
+ int numBreakingControlRegions =
+ static_cast<int>(op.getNumBreakingControlRegions());
+ if (numBreakingControlRegions) {
+ for (int i [[maybe_unused]] :
+ llvm::seq<int>(0, numBreakingControlRegions)) {
+ currentOp = currentOp->getParentOp();
+ if (!currentOp)
+ return op.emitError("operation with breaking control regions "
+ "exceeding the number of enclosing parent ops");
+ if (numBreakingControlRegions == 1)
+ continue;
+ if (i == numBreakingControlRegions - 1) {
+ auto successorOp =
+ dyn_cast<HasBreakingControlFlowOpInterface>(currentOp);
+ if (!successorOp)
+ return currentOp
+ ->emitError(
+ "operation has a nested predecessor but does not "
+ "have "
+ "the HasBreakingControlFlowOpInterface trait.")
+ .attachNote(op.getLoc())
+ << " for this predecessor operation (" << op.getName()
+ << ")";
+
+ if (!successorOp.acceptsTerminator(&op))
+ return currentOp->emitError(
+ "operation with breaking control regions "
+ "does not accept terminator: ")
+ << OpWithFlags(&op, OpPrintingFlags().skipRegions());
+ } else {
+ if (!currentOp->hasTrait<OpTrait::PropagateControlFlowBreak>())
+ return op.emitError("breaking control regions through an op that "
+ "does not have "
+ "the PropagateControlFlowBreak trait: ")
+ << OpWithFlags(currentOp, OpPrintingFlags().skipRegions());
+ }
+ }
+ }
}
return success();
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 2f95531455b2b..e77888a5b66e7 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -75,14 +75,14 @@ detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
LogicalResult
detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
const SuccessorOperands &operands) {
- LDBG() << "Verifying branch successor operands for successor #" << succNo
- << " in operation " << op->getName();
+ LDBG(3) << "Verifying branch successor operands for successor #" << succNo
+ << " in operation " << op->getName();
// Check the count.
unsigned operandCount = operands.size();
Block *destBB = op->getSuccessor(succNo);
- LDBG() << "Branch has " << operandCount << " operands, target block has "
- << destBB->getNumArguments() << " arguments";
+ LDBG(3) << "Branch has " << operandCount << " operands, target block has "
+ << destBB->getNumArguments() << " arguments";
if (operandCount != destBB->getNumArguments())
return op->emitError() << "branch has " << operandCount
@@ -91,22 +91,22 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
<< destBB->getNumArguments();
// Check the types.
- LDBG() << "Checking type compatibility for "
- << (operandCount - operands.getProducedOperandCount())
- << " forwarded operands";
+ LDBG(3) << "Checking type compatibility for "
+ << (operandCount - operands.getProducedOperandCount())
+ << " forwarded operands";
for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
++i) {
Type operandType = operands[i].getType();
Type argType = destBB->getArgument(i).getType();
- LDBG() << "Checking type compatibility: operand type " << operandType
- << " vs argument type " << argType;
+ LDBG(3) << "Checking type compatibility: operand type " << operandType
+ << " vs argument type " << argType;
if (!cast<BranchOpInterface>(op).areTypesCompatible(operandType, argType))
return op->emitError() << "type mismatch for bb argument #" << i
<< " of successor #" << succNo;
}
- LDBG() << "Branch successor operand verification successful";
+ LDBG(3) << "Branch successor operand verification successful";
return success();
}
@@ -217,6 +217,7 @@ LogicalResult detail::verifyRegionBranchOpInterface(Operation *op) {
}
}
}
+
return success();
}
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 82dfbcbfa4d4f..57f3249c8a598 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -747,12 +747,10 @@ void Operator::populateOpStructure() {
auto *dependentTraits = trait->getValueAsListInit("dependentTraits");
for (auto *traitInit : *dependentTraits)
if (!traitSet.contains(traitInit))
- PrintFatalError(
- def.getLoc(),
- trait->getValueAsString("trait") + " requires " +
- cast<DefInit>(traitInit)->getDef()->getValueAsString(
- "trait") +
- " to precede it in traits list");
+ PrintFatalError(def.getLoc(),
+ trait->getName() + " requires " +
+ cast<DefInit>(traitInit)->getDef()->getName() +
+ " to precede it in traits list");
};
std::function<void(const ListInit *)> insert;
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index 3ca16239ba33c..93dd73102ad52 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -14,6 +14,9 @@ add_mlir_library(MLIRTransformUtils
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
+ DEPENDS
+ MLIRRegionKindInterfaceIncGen
+
LINK_LIBS PUBLIC
MLIRAnalysis
MLIRCallInterfaces
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index 73107cfc36ea9..11e23ae45d2cf 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -286,6 +286,13 @@ static LogicalResult inlineRegionImpl(
[&](BlockArgument arg) { return !mapper.contains(arg); }))
return failure();
+ // Check that the region has no nested successors, e.g. a nested return inside
+ // a function.
+ if (auto opWithBreakingControlPredecessor =
+ dyn_cast<HasBreakingControlFlowOpInterface>(src->getParentOp()))
+ if (opWithBreakingControlPredecessor.hasNestedPredecessors())
+ return failure();
+
// Check that the operations within the source region are valid to inline.
Region *insertRegion = inlineBlock->getParent();
if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
diff --git a/mlir/test/Analysis/test-dominance.mlir b/mlir/test/Analysis/test-dominance.mlir
index a926a8271200a..6f2302eb26cdf 100644
--- a/mlir/test/Analysis/test-dominance.mlir
+++ b/mlir/test/Analysis/test-dominance.mlir
@@ -680,3 +680,29 @@ func.func @func_loop_nested_region(
// CHECK: ^{{.*}}
// CHECK: }
// CHECK: }
+
+
+// -----
+
+// CHECK-LABEL: Testing : func_loop_early_exit
+func.func @func_loop_early_exit(%cond : i1, %arg0 : index) -> index {
+ %0 = scf.loop -> index {
+ scf.loop {
+ scf.if %cond {
+ scf.break 3 %arg0 : index
+ }
+ "test.foo"() : () -> ()
+ scf.break 1 {test.print_dominance = true}
+ }
+ }
+ return %0 : index
+}
+
+// CHECK: postdominates(scf.break 1 {test.print_dominance = true} {{.*}} scf.break 3
+// CHECK-SAME: = 0
+// CHECK: postdominates(scf.break 1 {test.print_dominance = true} {{.*}} scf.if
+// CHECK-SAME: = 0
+// CHECK: postdominates(scf.break 1 {test.print_dominance = true} {{.*}} "test.foo"
+// CHECK-SAME: = 1
+// CHECK: postdominates(scf.break 1 {test.print_dominance = true} {{.*}} scf.break 1
+// CHECK-SAME: = 1
diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-early-exit-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-early-exit-to-cfg.mlir
new file mode 100644
index 0000000000000..e78cfa6e9538d
--- /dev/null
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-early-exit-to-cfg.mlir
@@ -0,0 +1,84 @@
+// RUN: mlir-opt -convert-scf-to-cf -split-input-file %s | FileCheck %s
+
+
+func.func @loop_break(%cond : i1) {
+ // CHECK: test.op1
+ "test.op1"() : () -> ()
+ // CHECK-NEXT: cf.br [[LOOP1_ENTRY:.*]]
+ // CHECK-NEXT: [[LOOP1_ENTRY]]
+ scf.loop {
+ // CHECK-NEXT: test.op2
+ "test.op2"() : () -> ()
+ // CHECK-NEXT: cf.cond_br %arg0, [[IF_ENTRY:.*]], [[IF_CONTINUE:.*]]
+ // CHECK-NEXT: [[IF_ENTRY]]
+ scf.if %cond {
+ "test.op3"() : () -> ()
+ scf.break 2 loc("break1")
+ }
+ "test.op3"() : () -> ()
+ } loc("loop1")
+ "test.op4"() : () -> ()
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @nested_loops_and_ifs(
+// CHECK-SAME: %[[COND1:.*]]: i1,
+// CHECK-SAME: %[[COND2:.*]]: i1
+func.func @nested_loops_and_ifs(%cond1 : i1, %cond2 : i1) {
+ // CHECK: test.op1
+ "test.op1"() : () -> ()
+ // CHECK-NEXT: cf.br ^[[OUTER_LOOP_ENTRY:.*]]
+ scf.loop {
+ // CHECK-NEXT: ^[[OUTER_LOOP_ENTRY]]:
+ // CHECK-NEXT: cf.cond_br %[[COND1]], ^[[IF1_THEN_BLOCK:.*]], ^[[IF1_EXIT:.*]]
+ scf.if %cond1 {
+ // CHECK-NEXT: ^[[IF1_THEN_BLOCK]]:
+ // CHECK-NEXT: test.op2
+ "test.op2"() : () -> ()
+ // CHECK-NEXT: cf.br ^[[INNER_LOOP_ENTRY:.*]]
+ scf.loop {
+ // CHECK-NEXT: ^[[INNER_LOOP_ENTRY]]:
+ // CHECK-NEXT: test.op3
+ "test.op3"() : () -> ()
+ // CHECK-NEXT: cf.cond_br %[[COND1]], ^[[IF2_THEN_BLOCK:.*]], ^[[IF2_EXIT:.*]]
+ scf.if %cond1 {
+ // CHECK-NEXT: ^[[IF2_THEN_BLOCK]]:
+ // CHECK-NEXT: test.op4
+ "test.op4"() : () -> ()
+ // CHECK-NEXT: cf.br ^[[INNER_LOOP_ENTRY]]
+ scf.continue 2 loc("continue1")
+ }
+ // CHECK-NEXT: ^[[IF2_EXIT]]:
+ // CHECK-NEXT: test.op5
+ "test.op5"() : () -> ()
+ // CHECK-NEXT: cf.cond_br %[[COND2]], ^[[IF3_THEN_BLOCK:.*]], ^[[IF3_EXIT:.*]]
+ scf.if %cond2 {
+ // CHECK-NEXT: ^[[IF3_THEN_BLOCK]]:
+ // CHECK-NEXT: test.op6
+ "test.op6"() : () -> ()
+ // CHECK-NEXT: cf.br ^[[FUNC_EXIT:.*]]
+ scf.break 4 loc("break2")
+ }
+ // CHECK-NEXT: ^[[IF3_EXIT]]:
+ // CHECK-NEXT: test.op7
+ "test.op7"() : () -> ()
+ // CHECK-NEXT: cf.br ^[[INNER_LOOP_ENTRY]]
+ scf.continue 1 loc("continue2")
+ } loc("loop3")
+ // CHECK-NEXT: ^[[AFTER_INNER_LOOP:.*]]:
+ // CHECK-NEXT: test.op8
+ "test.op8"() : () -> ()
+ // CHECK-NEXT: cf.br ^[[IF1_EXIT]]
+ } loc("if1")
+ // CHECK-NEXT: ^[[IF1_EXIT]]:
+ // CHECK-NEXT: cf.br ^[[OUTER_LOOP_ENTRY]]
+ scf.continue 1 loc("continue3")
+ } loc("loop2")
+ // CHECK-NEXT: ^[[FUNC_EXIT]]:
+ // CHECK-NEXT: test.op9
+ "test.op9"() : () -> ()
+ // CHECK-NEXT: return
+ return
+}
diff --git a/mlir/test/Dialect/SCF/loop_canonicalize.mlir b/mlir/test/Dialect/SCF/loop_canonicalize.mlir
new file mode 100644
index 0000000000000..86541778208d0
--- /dev/null
+++ b/mlir/test/Dialect/SCF/loop_canonicalize.mlir
@@ -0,0 +1,331 @@
+// RUN: mlir-opt %s -pass-pipeline='builtin.module(func.func(canonicalize{test-convergence}))' -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func @fold_single_iteration_loop1
+func.func @fold_single_iteration_loop1(%arg0 : index) -> index {
+ // CHECK-NOT: loop
+ %0 = scf.loop -> index {
+ scf.break 1 %arg0 : index
+ }
+ return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_single_iteration_loop_with_propagating_control_flow
+func.func @fold_single_iteration_loop_with_propagating_control_flow(%cond : i1, %arg0 : index) -> index {
+ %0 = scf.loop -> index {
+ scf.loop {
+ scf.if %cond {
+ scf.break 3 %arg0 : index
+ }
+ scf.break 1
+ }
+ }
+ return %0 : index
+}
+
+// -----
+
+// CHECK-LABEL: func @loop_not_combine_ifs
+func.func @loop_not_combine_ifs(%arg0 : i1, %arg2: i64) {
+ // Verify that we don't combine ifs when terminator mismatches
+ scf.loop {
+ // CHECK: scf.if
+ %res = scf.if %arg0 -> i32 {
+ %v = "test.firstCodeTrue"() : () -> i32
+ scf.yield %v : i32
+ } else {
+ %v2 = "test.firstCodeFalse"() : () -> i32
+ scf.break 2
+ }
+ // CHECK: scf.if
+ %res2 = scf.if %arg0 -> i32 {
+ %v = "test.secondCodeTrue"() : () -> i32
+ scf.yield %v : i32
+ } else {
+ %v2 = "test.secondCodeFalse"() : () -> i32
+ scf.continue 2
+ }
+ }
+ return
+}
+
+// -----
+
+// TODO: We should combine these but we don't right now
+// CHECK-LABEL: func @loop_combine_ifs
+func.func @loop_combine_ifs(%arg0 : i1, %arg2: i64) {
+ // Verify that we don't combine ifs when terminator matches
+ scf.loop {
+ // CHECK: scf.if
+ // TODO-CHECK-NOT: scf.if
+ %res = scf.if %arg0 -> i32 {
+ %v = "test.firstCodeTrue"() : () -> i32
+ scf.yield %v : i32
+ } else {
+ %v2 = "test.firstCodeFalse"() : () -> i32
+ scf.break 2
+ }
+ %res2 = scf.if %arg0 -> i32 {
+ %v = "test.secondCodeTrue"() : () -> i32
+ scf.yield %v : i32
+ } else {
+ %v2 = "test.secondCodeFalse"() : () -> i32
+ scf.break 2
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @do_not_merge_nested_if_with_breaking_control_flow1
+func.func @do_not_merge_nested_if_with_breaking_control_flow1(%arg0: i1, %arg1: i1) {
+// The outer if then terminator isn't a yield, blocking the merge.
+// CHECK: scf.loop
+// CHECK: scf.if
+// CHECK: scf.if
+ scf.loop {
+ scf.if %arg0 {
+ scf.if %arg1 {
+ "test.op"() : () -> ()
+ scf.yield
+ }
+ scf.break 2
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @do_not_merge_nested_if_with_breaking_control_flow2
+func.func @do_not_merge_nested_if_with_breaking_control_flow2(%arg0: i1, %arg1: i1) {
+// The outer if else block terminator isn't a yield, blocking the merge.
+// CHECK: scf.loop
+// CHECK: scf.if
+// CHECK: scf.if
+// CHECK: else
+// CHECK-NEXT: scf.break 2
+ scf.loop {
+ scf.if %arg0 {
+ scf.if %arg1 {
+ "test.op"() : () -> ()
+ scf.yield
+ }
+ scf.yield
+ } else {
+ scf.break 2
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @do_not_merge_nested_if_with_breaking_control_flow3
+func.func @do_not_merge_nested_if_with_breaking_control_flow3(%arg0: i1, %arg1: i1) {
+// The nested if then block terminator isn't a yield, blocking the merge.
+// CHECK: scf.loop
+// CHECK: scf.if
+// CHECK: scf.if
+// CHECK: test.op
+// CHECK-NEXT: scf.break 3
+ scf.loop {
+ scf.if %arg0 {
+ scf.if %arg1 {
+ "test.op"() : () -> ()
+ scf.break 3
+ }
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @do_not_merge_nested_if_with_breaking_control_flow4
+func.func @do_not_merge_nested_if_with_breaking_control_flow4(%arg0: i1, %arg1: i1) {
+// The nested if else block terminator isn't a yield, blocking the merge.
+// CHECK: scf.loop
+// CHECK: scf.if
+// CHECK: scf.if
+// CHECK: else
+// CHECK-NEXT: scf.break 3
+ scf.loop {
+ scf.if %arg0 {
+ scf.if %arg1 {
+ "test.op"() : () -> ()
+ } else {
+ scf.break 3
+ }
+ }
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_convert_if_to_select1
+func.func @do_not_convert_if_to_select1(%cond: i1, %arg0 : index, %arg1 : index) -> index {
+ %loop_res = scf.loop -> index {
+ // Inner then terminator is not a yield, blocking transform to select.
+ // CHECK: scf.if
+ %0 = scf.if %cond -> index {
+ scf.break 2 %arg0 : index
+ } else {
+ scf.yield %arg1 : index
+ }
+ scf.break 1 %0 : index
+ }
+ return %loop_res : index
+}
+
+// -----
+
+// CHECK-LABEL: func @do_not_convert_if_to_select2
+func.func @do_not_convert_if_to_select2(%cond: i1, %arg0 : index, %arg1 : index) -> index {
+ %loop_res = scf.loop -> index {
+ // Inner else terminator is not a yield, blocking transform to select.
+ // CHECK: scf.if
+ %0 = scf.if %cond -> index {
+ scf.yield %arg0 : index
+ } else {
+ scf.break 2 %arg1 : index
+ }
+ scf.break 1 %0 : index
+ }
+ return %loop_res : index
+}
+
+
+// -----
+
+// CHECK-LABEL: func @fold_constant_if_with_breaking_cf1
+func.func @fold_constant_if_with_breaking_cf1(%arg0 : index, %arg1 : index) -> index {
+ %cond = arith.constant true
+ // Infinite loop here, inner if can be simplified, the "break" is
+ // unreachable.
+ // CHECK: scf.loop
+ // CHECK-NEXT: }
+ %loop_res = scf.loop -> index {
+ %0 = scf.if %cond -> index {
+ scf.yield %arg0 : index
+ } else {
+ scf.break 2 %arg1 : index
+ }
+ }
+ return %loop_res : index
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_constant_if_with_breaking_cf2
+func.func @fold_constant_if_with_breaking_cf2(%arg0 : index, %arg1 : index) -> index {
+ %cond = arith.constant false
+ // Infinite loop here, inner if can be simplified, the "break" is
+ // unreachable.
+ // CHECK: scf.loop
+ // CHECK-NEXT: }
+ %loop_res = scf.loop -> index {
+ %0 = scf.if %cond -> index {
+ scf.break 2 %arg1 : index
+ } else {
+ scf.yield %arg0 : index
+ }
+ }
+ return %loop_res : index
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_constant_if_with_breaking_cf3
+func.func @fold_constant_if_with_breaking_cf3(%arg0 : index, %arg1 : index) -> index {
+ %cond = arith.constant true
+ // Single iteration loop here, inner if can be simplified, and then the
+ // loop itself.
+ // CHECK-NOT: scf.loop
+ %loop_res = scf.loop -> index {
+ %0 = scf.if %cond -> index {
+ scf.break 2 %arg1 : index
+ } else {
+ scf.yield %arg0 : index
+ }
+ }
+ return %loop_res : index
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_constant_if_with_breaking_cf4
+func.func @fold_constant_if_with_breaking_cf4(%arg0 : index, %arg1 : index) -> index {
+ %cond = arith.constant false
+ // Single iteration loop here, inner if can be simplified, and then the
+ // loop itself.
+ // CHECK-NOT: scf.loop
+ %loop_res = scf.loop -> index {
+ %0 = scf.if %cond -> index {
+ scf.yield %arg0 : index
+ } else {
+ scf.break 2 %arg1 : index
+ }
+ }
+ return %loop_res : index
+}
+
+// -----
+
+// Verify that removing the unused results of an if with nested breaking control flow
+// operation works.
+// CHECK-LABEL: func @remove_unused_if_results1
+func.func @remove_unused_if_results1(%cond : i1, %arg0 : index) -> index {
+ // CHECK: scf.loop
+ %0 = scf.loop -> index {
+ // CHECK: %[[FOO:.*]]:3 = "test.foo"
+ %foo:3 = "test.foo" () : () -> (i32, i64, index)
+ // CHECK-NOT: %[[RES:.*]] = scf.if
+ // CHECK: scf.if
+ %res:3 = scf.if %cond -> (i32, i64, index) {
+ // CHECK: scf.yield
+ scf.yield %foo#0, %foo#1, %foo#2 : i32, i64, index
+ } else {
+ // CHECK: scf.break 2 %[[FOO]]#2 : index
+ scf.break 2 %foo#2 : index
+ }
+ // CHECK: "test.op"(%[[FOO]]#1)
+ "test.op"(%res#1) : (i64) -> ()
+ }
+ return %0 : index
+}
+
+// -----
+
+// Verify that an if with breaking control flow in both branches is simplified:
+// the else region is inlined into the parent and break depths are updated.
+// CHECK-LABEL: func @simplify_if_with_breaking_controlflow_in_both_branches
+func.func @simplify_if_with_breaking_controlflow_in_both_branches(%cond : i1, %cond2 : i1, %arg0 : index) -> index {
+ // CHECK: scf.loop
+ %0 = scf.loop -> index {
+ // CHECK: %[[FOO:.*]] = "test.foo"
+ %foo = "test.foo" () : () -> (index)
+ // CHECK: scf.if
+ scf.if %cond {
+ // CHECK: scf.break 2 %[[FOO]] : index
+ scf.break 2 %foo : index
+ // CHECK-NOT: else
+ } else {
+ // CHECK: %[[BAR:.*]] = "test.bar"
+ %bar = "test.bar" () : () -> (index)
+ // CHECK: scf.if
+ scf.if %cond2 {
+ // verify that this is correctly updated when inlining the parent region.
+ // CHECK: scf.break 2 %[[BAR]] : index
+ scf.break 3 %bar : index
+ }
+ scf.continue 2
+ }
+ "test.op"() : () -> ()
+ }
+ return %0 : index
+}
+
diff --git a/mlir/test/IR/early-exit-invalid.mlir b/mlir/test/IR/early-exit-invalid.mlir
new file mode 100644
index 0000000000000..4d776b3fac055
--- /dev/null
+++ b/mlir/test/IR/early-exit-invalid.mlir
@@ -0,0 +1,67 @@
+
+// RUN: mlir-opt %s --split-input-file --verify-diagnostics
+
+
+// expected-error @+1 {{operation has a nested predecessor but does not have the HasBreakingControlFlowOpInterface trait}}
+ func.func @loop_continue() {
+ scf.loop {
+// expected-note @+1 {{for this predecessor operation (scf.continue)}}
+ scf.continue 2
+ } loc("loop1")
+ return
+}
+
+// -----
+
+func.func @loop_result_mismatch(%value : f32) {
+ // expected-error @+1 {{'scf.loop' op along control flow edge from Operation scf.break to parent: successor operand type #0 'f32' should match successor input type #0 'i32'}}
+ %result = scf.loop -> i32 {
+ scf.break 1 %value : f32 // expected-note {{region branch point}}
+ }
+ return
+}
+
+// -----
+
+func.func @loop_result_number_mismatch(%value : f32) {
+ // expected-error @+1 {{'scf.loop' op along control flow edge from Operation scf.break to parent: region branch point has 1 operands, but region successor needs 2 inputs}}
+ %result:2 = scf.loop -> f32, f32 {
+ scf.break 1 %value : f32 // expected-note {{region branch point}}
+ }
+ return
+}
+
+// -----
+
+func.func @loop_continue_mismatch(%init : i32, %value : f32) {
+ // expected-error @+1 {{'scf.loop' op along control flow edge from Operation scf.continue to Region #0: successor operand type #0 'f32' should match successor input type #0 'i32'}}
+ scf.loop iter_args(%next = %init) : i32 {
+ scf.continue 1 %value : f32 // expected-note {{region branch point}}
+ }
+ return
+}
+
+
+// -----
+
+func.func @loop_iterargs_mismatch(%init : i32, %value : f32) {
+ // expected-error @+2 {{'scf.loop' op along control flow edge from parent to Region #0: successor operand type #0 'i32' should match successor input type #0 'f32'}}
+ // expected-note @+1 {{region branch point}}
+ "scf.loop"(%init) ({
+ ^body(%next : f32):
+ scf.continue 1 %init : i32
+ }) : (i32) -> ()
+ return
+}
+
+// -----
+
+func.func @loop_iterargs_mismatch(%init : i32, %value : f32) {
+ // expected-error @+2 {{'scf.loop' op along control flow edge from parent to Region #0: region branch point has 1 operands, but region successor needs 2 inputs}}
+ // expected-note @+1 {{region branch point}}
+ "scf.loop"(%init) ({
+ ^body(%next : i32, %next2 : f32):
+ scf.continue 1 %init : i32
+ }) : (i32) -> ()
+ return
+}
diff --git a/mlir/test/IR/early-exit.mlir b/mlir/test/IR/early-exit.mlir
new file mode 100644
index 0000000000000..7e3c5420746a9
--- /dev/null
+++ b/mlir/test/IR/early-exit.mlir
@@ -0,0 +1,77 @@
+// RUN: mlir-opt --print-region-branch-op-interface %s --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-debuginfo --mlir-print-op-generic --split-input-file | mlir-opt --print-region-branch-op-interface --split-input-file | FileCheck %s
+
+
+// CHECK-LABEL: func @unregistered_op
+func.func @unregistered_op(%cond : i1) {
+ "test.some_loop"() ({
+ "test.some_if"(%cond) ({
+ "test.some_break"() [2] : () -> ()
+ }) : (i1) -> ()
+ "test.continue"() [1] : () -> ()
+ }) : () -> () return
+}
+
+
+// -----
+
+func.func @loop_break(%cond : i1) {
+ // CHECK: Found RegionBranchOpInterface operation: scf.loop {...} loc("loop1")
+ // CHECK: - Successor is region #0
+ // CHECK: - Found 2 predecessor(s)
+ // CHECK: - Predecessor is scf.break 2 loc("break1")
+ // CHECK: - Predecessor is scf.continue 1
+ scf.loop {
+ scf.if %cond {
+ scf.break 2 loc("break1")
+ }
+ } loc("loop1")
+ return
+}
+
+// -----
+
+func.func @loop_continue(%cond1 : i1, %cond2 : i1) {
+ // CHECK: Found RegionBranchOpInterface operation: scf.loop {...} loc("loop2")
+ // CHECK: - Successor is region #0
+ // CHECK: - Found 2 predecessor(s)
+ // CHECK: - Predecessor is scf.break 3 loc("break2")
+ // CHECK: - Predecessor is scf.continue 1
+ scf.loop {
+ // CHECK: Found RegionBranchOpInterface operation: scf.loop {...} loc("loop3")
+ // CHECK: - Successor is region #0
+ // CHECK: - Found 2 predecessor(s)
+ // CHECK: - Predecessor is scf.continue 2 loc("continue1")
+ // CHECK: - Predecessor is scf.continue 1
+ scf.loop {
+ scf.if %cond1 {
+ scf.continue 2 loc("continue1")
+ }
+ scf.if %cond2 {
+ scf.break 3 loc("break2")
+ }
+ } loc("loop3")
+ } loc("loop2")
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func @loop_with_results(
+func.func @loop_with_results(%value : f32) -> f32 {
+ %result = scf.loop -> f32 {
+ scf.break 1 %value : f32
+ }
+ return %result : f32
+}
+
+// -----
+
+// CHECK-LABEL: func @loop_continue_iterargs(
+func.func @loop_continue_iterargs(%init : i32) {
+ scf.loop iter_args(%next = %init) : i32 {
+ scf.continue 1 %next : i32
+ }
+ return
+}
+
diff --git a/mlir/test/Integration/Dialect/SCF/early_exit.mlir b/mlir/test/Integration/Dialect/SCF/early_exit.mlir
new file mode 100644
index 0000000000000..974ad681e3200
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SCF/early_exit.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s -convert-scf-to-cf --canonicalize --convert-cf-to-llvm --convert-to-llvm | \
+// RUN: mlir-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+
+
+// End-to-end test of early exits (scf.break / scf.continue) in nested loops.
+module {
+ llvm.func @entry() {
+ // Constant for the iteration space and various conditions
+ %one = llvm.mlir.constant(1 : i64) : i64
+ %two = llvm.mlir.constant(2 : i64) : i64
+ %three = llvm.mlir.constant(3 : i64) : i64
+ %four = llvm.mlir.constant(4 : i64) : i64
+ %counter_init = llvm.mlir.constant(0 : i64) : i64
+
+
+// CHECK: Outer Loop Begin with counter: 0
+// CHECK-NEXT: Inner Loop Begin, counter: 1
+// CHECK-NEXT: continue inner loop
+// CHECK-NEXT: Inner Loop Begin, counter: 2
+// CHECK-NEXT: Iteration 2, loop back to outer loop
+// CHECK-NEXT: Outer Loop Begin with counter: 2
+// CHECK-NEXT: Inner Loop Begin, counter: 3
+// CHECK-NEXT: continue inner loop
+// CHECK-NEXT: Inner Loop Begin, counter: 4
+// CHECK-NEXT: continue inner loop
+// CHECK-NEXT: Inner Loop Begin, counter: 5
+// CHECK-NEXT: Last iteration, break out of outer loop
+// CHECK-NEXT: Outer loop finished with result: 4
+
+
+ %result = scf.loop iter_args(%counter_out = %counter_init) : i64 -> i64 {
+ // Outer loop iteration
+ vector.print str "Outer Loop Begin with counter: "
+ vector.print %counter_out : i64
+
+ scf.loop iter_args(%counter = %counter_out) : i64 {
+ // %counter will go from 0 to 4
+ // %counter_update will go from 1 to 5
+ %counter_update = llvm.add %counter, %one : i64
+
+ // Inner loop iteration
+ // print from 1..5
+ vector.print str "Inner Loop Begin, counter: "
+ vector.print %counter_update : i64
+
+ // On the second iteration, print a message and loop back to the outer loop.
+ %cond1 = llvm.icmp "eq" %counter_update, %two : i64
+ scf.if %cond1 {
+ vector.print str "Iteration 2, loop back to outer loop\n"
+ scf.continue 3 %counter_update : i64
+ }
+
+ // Exit condition when counter >= 4
+ %cond2 = llvm.icmp "sge" %counter, %four : i64
+ scf.if %cond2 {
+ vector.print str "Last iteration, break out of outer loop\n"
+ // return the counter from the previous iteration here (pre-update)
+ scf.break 3 %counter : i64
+ }
+
+ %cond3 = llvm.icmp "eq" %counter_update, %three : i64
+ scf.if %cond2 {
+ vector.print str "Iteration 3, break out of inner loop"
+ scf.break 2
+ }
+ vector.print str "continue inner loop\n"
+ scf.continue 1 %counter_update : i64
+ }
+ vector.print str "continue outer loop\n"
+ scf.continue 1 %counter_out : i64
+ }
+
+// After the loop nest finishes
+ vector.print str "Outer loop finished with result: "
+ vector.print %result : i64
+
+ llvm.return
+ }
+}
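Reviewer note: the CHECK lines in this integration test can be reproduced with a plain C++ model of the loop nest. The model assumes `scf.continue N` and `scf.break N` behave like labeled continue and break that forward their operands to the targeted loop. `runModel` and its variable names are hypothetical. The third `scf.if` is left out because, in the test as written, it is guarded by `%cond2`, which is always false by the time control reaches it.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Models the nested scf.loop ops and returns the printed lines in order.
std::vector<std::string> runModel() {
  std::vector<std::string> out;
  int64_t counterOut = 0; // iter_arg of the outer loop
  int64_t result = 0;     // value carried by `scf.break 3`
  bool done = false;
  while (!done) {
    out.push_back("Outer Loop Begin with counter: " +
                  std::to_string(counterOut));
    int64_t counter = counterOut; // iter_arg of the inner loop
    while (true) {
      int64_t counterUpdate = counter + 1;
      out.push_back("Inner Loop Begin, counter: " +
                    std::to_string(counterUpdate));
      if (counterUpdate == 2) {
        // scf.continue 3 %counter_update: restart the outer loop.
        out.push_back("Iteration 2, loop back to outer loop");
        counterOut = counterUpdate;
        break;
      }
      if (counter >= 4) {
        // scf.break 3 %counter: leave the outer loop with the old counter.
        out.push_back("Last iteration, break out of outer loop");
        result = counter;
        done = true;
        break;
      }
      // scf.continue 1 %counter_update: next inner iteration.
      out.push_back("continue inner loop");
      counter = counterUpdate;
    }
  }
  out.push_back("Outer loop finished with result: " + std::to_string(result));
  return out;
}
```

In this model the inner loop is only ever left through the two depth-3 exits, so the "continue outer loop" message never appears, matching its absence from the expected output.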
diff --git a/mlir/test/lib/IR/TestDominance.cpp b/mlir/test/lib/IR/TestDominance.cpp
index b34149b3e2cbd..24696c8dd028c 100644
--- a/mlir/test/lib/IR/TestDominance.cpp
+++ b/mlir/test/lib/IR/TestDominance.cpp
@@ -14,6 +14,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dominance.h"
+#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
@@ -21,6 +22,14 @@ using namespace mlir;
/// Overloaded helper to call the right function based on whether we are testing
/// dominance or post-dominance.
+static bool dominatesOrPostDominates(DominanceInfo &dominanceInfo, Operation *a,
+ Operation *b) {
+ return dominanceInfo.dominates(a, b);
+}
+static bool dominatesOrPostDominates(PostDominanceInfo &dominanceInfo,
+ Operation *a, Operation *b) {
+ return dominanceInfo.postDominates(a, b);
+}
static bool dominatesOrPostDominates(DominanceInfo &dominanceInfo, Block *a,
Block *b) {
return dominanceInfo.dominates(a, b);
@@ -72,6 +81,30 @@ class DominanceTest {
template <typename DominanceT>
void printDominance(DominanceT &dominanceInfo,
bool printCommonDominatorInfo) {
+ if (printCommonDominatorInfo) {
+ operation->walk([&](Operation *op) {
+ if (!op->getDiscardableAttr("test.print_dominance"))
+ return;
+ operation->walk([&](Operation *nested) {
+ if (std::is_same<DominanceInfo, DominanceT>::value)
+ llvm::outs() << "dominates(";
+ else
+ llvm::outs() << "postdominates(";
+ bool isDominated =
+ dominatesOrPostDominates(dominanceInfo, op, nested);
+ llvm::outs() << OpWithFlags(op, OpPrintingFlags()
+ .skipRegions()
+ .enableDebugInfo()
+ .assumeVerified())
+ << ", "
+ << OpWithFlags(nested, OpPrintingFlags()
+ .skipRegions()
+ .enableDebugInfo()
+ .assumeVerified())
+ << ") = " << std::to_string(isDominated) << "\n";
+ });
+ });
+ }
DenseSet<Block *> parentVisited;
operation->walk([&](Operation *op) {
Block *block = op->getBlock();
diff --git a/mlir/test/lib/Interfaces/CMakeLists.txt b/mlir/test/lib/Interfaces/CMakeLists.txt
index 6a21ed10eec6f..3aa5097b7ed20 100644
--- a/mlir/test/lib/Interfaces/CMakeLists.txt
+++ b/mlir/test/lib/Interfaces/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(LoopLikeInterface)
+add_subdirectory(RegionBranchOpInterface)
add_subdirectory(TilingInterface)
diff --git a/mlir/test/lib/Interfaces/RegionBranchOpInterface/CMakeLists.txt b/mlir/test/lib/Interfaces/RegionBranchOpInterface/CMakeLists.txt
new file mode 100644
index 0000000000000..8e003942e41c0
--- /dev/null
+++ b/mlir/test/lib/Interfaces/RegionBranchOpInterface/CMakeLists.txt
@@ -0,0 +1,9 @@
+add_mlir_library(MLIRTestRegionBranchOpInterface
+ TestRegionBranchOpInterface.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+ )
+mlir_target_link_libraries(MLIRTestRegionBranchOpInterface PUBLIC
+ MLIRControlFlowInterfaces
+ MLIRPass
+ )
diff --git a/mlir/test/lib/Interfaces/RegionBranchOpInterface/TestRegionBranchOpInterface.cpp b/mlir/test/lib/Interfaces/RegionBranchOpInterface/TestRegionBranchOpInterface.cpp
new file mode 100644
index 0000000000000..a1910465d93eb
--- /dev/null
+++ b/mlir/test/lib/Interfaces/RegionBranchOpInterface/TestRegionBranchOpInterface.cpp
@@ -0,0 +1,76 @@
+//===- TestRegionBranchOpInterface.cpp - Region branch test pass ----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+namespace {
+/// This is a test pass that prints, for each RegionBranchOpInterface op, its
+/// successor regions and any nested breaking-control-flow predecessors.
+struct PrintRegionBranchOpInterfacePass
+ : public PassWrapper<PrintRegionBranchOpInterfacePass, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintRegionBranchOpInterfacePass)
+
+ StringRef getArgument() const final {
+ return "print-region-branch-op-interface";
+ }
+ StringRef getDescription() const final {
+ return "Print control-flow edges represented by "
+ "mlir::RegionBranchOpInterface";
+ }
+
+ void runOnOperation() override {
+ Operation *op = getOperation();
+ op->walk<WalkOrder::PreOrder>([&](RegionBranchOpInterface branchOp) {
+ llvm::outs() << "Found RegionBranchOpInterface operation: "
+ << OpWithFlags(
+ branchOp,
+ OpPrintingFlags().skipRegions().enableDebugInfo())
+ << "\n";
+ SmallVector<RegionSuccessor> regions;
+ branchOp.getSuccessorRegions(RegionBranchPoint::parent(), regions);
+ for (auto &successor : regions) {
+ if (successor.isParent()) {
+ llvm::outs() << " - Successor is parent\n";
+ } else {
+ llvm::outs() << " - Successor is region #"
+ << successor.getSuccessor()->getRegionNumber() << "\n";
+ }
+ }
+ if (auto breakingControlFlowOp =
+ dyn_cast<HasBreakingControlFlowOpInterface>(
+ branchOp.getOperation())) {
+ SmallVector<Operation *> predecessors;
+ llvm::outs() << " - Collecting all nested predecessors\n";
+ collectAllNestedPredecessors(breakingControlFlowOp, predecessors);
+ llvm::outs() << " - Found " << predecessors.size()
+ << " predecessor(s)\n";
+ for (auto &predecessor : predecessors) {
+ llvm::outs() << " - Predecessor is "
+ << OpWithFlags(
+ predecessor,
+ OpPrintingFlags().skipRegions().enableDebugInfo())
+ << "\n";
+ }
+ }
+ });
+ }
+};
+
+} // namespace
+
+namespace mlir {
+void registerRegionBranchOpInterfaceTestPasses() {
+ PassRegistration<PrintRegionBranchOpInterfacePass>();
+}
+} // namespace mlir
diff --git a/mlir/test/mlir-tblgen/op-error.td b/mlir/test/mlir-tblgen/op-error.td
index a2eab1f08df28..658a9a951da84 100644
--- a/mlir/test/mlir-tblgen/op-error.td
+++ b/mlir/test/mlir-tblgen/op-error.td
@@ -121,6 +121,6 @@ def OpInterfaceB : OpInterface<"OpInterfaceB"> {
let dependentTraits = [OpTraitA];
}
-// ERROR13: error: OpInterfaceB::Trait requires OpTraitA to precede it in traits list
+// ERROR13: error: OpInterfaceB requires OpTraitA to precede it in traits list
def OpInterfaceWithoutDependentTrait : Op<Test_Dialect, "default_value", [OpInterfaceB]> {}
#endif
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index c607ccfa80e3c..821bad3a4166a 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -22,6 +22,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRGPUTestPasses
MLIRLinalgTestPasses
MLIRLoopLikeInterfaceTestPasses
+ MLIRTestRegionBranchOpInterface
MLIRMathTestPasses
MLIRTestMathToVCIX
MLIRMemRefTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index a427132247e6d..d123fd85a9a2b 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -38,6 +38,7 @@ void registerLazyLoadingTestPasses();
void registerLoopLikeInterfaceTestPasses();
void registerPassManagerTestPass();
void registerPrintSpirvAvailabilityPass();
+void registerRegionBranchOpInterfaceTestPasses();
void registerRegionTestPasses();
void registerPrintTosaAvailabilityPass();
void registerShapeFunctionTestPasses();
@@ -191,6 +192,7 @@ static void registerTestPasses() {
registerPassManagerTestPass();
registerPrintSpirvAvailabilityPass();
registerRegionTestPasses();
+ registerRegionBranchOpInterfaceTestPasses();
registerShapeFunctionTestPasses();
registerSideEffectTestPasses();
registerSliceAnalysisTestPass();
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
index 04d3ed1f3b70d..2611922ae6d52 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -191,6 +191,7 @@ FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
.Case("successors", FormatToken::kw_successors)
.Case("type", FormatToken::kw_type)
.Case("qualified", FormatToken::kw_qualified)
+ .Case("num-breaking-regions", FormatToken::kw_num_breaking_regions)
.Default(FormatToken::identifier);
return FormatToken(kind, str);
}
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
index 8e7d49bb37e71..0b44ab69dcd1f 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -72,6 +72,7 @@ class FormatToken {
kw_struct,
kw_successors,
kw_type,
+ kw_num_breaking_regions,
keyword_end,
// String valued tokens.
@@ -305,6 +306,7 @@ class DirectiveElement : public FormatElementBase<FormatElement::Directive> {
Results,
Successors,
Type,
+ NumBreakingRegions,
Params,
Struct
};
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index d1f1e85371133..9344756435bef 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -721,6 +721,15 @@ class OpEmitter {
// one parameter. Similarly for operands and attributes.
void genCollectiveParamBuilder(CollectiveBuilderKind kind);
+ void emitImmediateRegionTerminator(MethodBody &body, const Operator &op) {
+ for (const auto &t : op.getTraits()) {
+ if (t.getDef().getName() == "ImmediateRegionTerminator") {
+ body << " " << builderOpState << ".setNumBreakingControlRegions(1);\n";
+ break;
+ }
+ }
+ }
+
// The kind of parameter to generate for result types in builders.
enum class TypeParamKind {
None, // No result type in parameter list.
@@ -2683,6 +2692,7 @@ void OpEmitter::genSeparateArgParamBuilder() {
genInlineCreateBody(paramList);
auto &body = m->body();
+ emitImmediateRegionTerminator(body, op);
genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
/*isRawValueAttr=*/attrType ==
AttrParamKind::UnwrappedValue);
@@ -2812,6 +2822,7 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder(
return;
genInlineCreateBody(paramList);
auto &body = m->body();
+ emitImmediateRegionTerminator(body, op);
// Operands
body << " " << builderOpState << ".addOperands(operands);\n";
@@ -3008,6 +3019,7 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
return;
genInlineCreateBody(paramList);
auto &body = m->body();
+ emitImmediateRegionTerminator(body, op);
genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
/*isRawValueAttr=*/attrType ==
AttrParamKind::UnwrappedValue);
@@ -3052,6 +3064,7 @@ void OpEmitter::genUseAttrAsResultTypeCollectiveParamBuilder(
genInlineCreateBody(paramList);
auto &body = m->body();
+ emitImmediateRegionTerminator(body, op);
// Push all result types to the operation state
std::string resultType;
@@ -3218,6 +3231,7 @@ void OpEmitter::genCollectiveParamBuilder(CollectiveBuilderKind kind) {
return;
genInlineCreateBody(paramList);
auto &body = m->body();
+ emitImmediateRegionTerminator(body, op);
// Operands
if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index b834c6f8d3aaf..367840e5bc6e8 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -222,6 +222,14 @@ class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> {
bool shouldBeQualifiedFlag = false;
};
+/// This class represents the `num-breaking-regions` directive. This directive
+/// represents the number of breaking regions of an operation.
+class NumBreakingRegionsDirective
+ : public DirectiveElementBase<DirectiveElement::NumBreakingRegions> {
+public:
+ NumBreakingRegionsDirective() = default;
+};
+
/// This class represents a group of order-independent optional clauses. Each
/// clause starts with a literal element and has a coressponding parsing
/// element. A parsing element is a continous sequence of format elements.
@@ -1402,6 +1410,15 @@ void OperationFormat::genParser(Operator &op, OpClass &opClass) {
auto *method = opClass.addStaticMethod("::mlir::ParseResult", "parse",
std::move(paramList));
auto &body = method->body();
+ // Ops with the ImmediateRegionTerminator trait default to one breaking
+ // region; a `num-breaking-regions` directive in the operation format
+ // overrides this default.
+ for (const auto &t : op.getTraits()) {
+ if (t.getDef().getName() == "ImmediateRegionTerminator") {
+ body << " result.setNumBreakingControlRegions(1);\n";
+ break;
+ }
+ }
// Generate variables to store the operands and type within the format. This
// allows for referencing these variables in the presence of optional
@@ -1691,6 +1708,16 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
body << formatv(functionalTypeParserCode,
getTypeListName(dir->getInputs(), ignored),
getTypeListName(dir->getResults(), ignored));
+ } else if (isa<NumBreakingRegionsDirective>(element)) {
+ body.indent() << "{\n";
+ body.indent()
+ << "auto loc = parser.getCurrentLocation();(void)loc;\n"
+ << "int32_t numBreakingRegions = 0;\n"
+ << "if (parser.parseInteger(numBreakingRegions))\n"
+ << " return ::mlir::failure();\n"
+ << "result.setNumBreakingControlRegions(numBreakingRegions);\n";
+ body.unindent() << "}\n";
+ body.unindent();
} else {
llvm_unreachable("unknown format element");
}
@@ -2553,6 +2580,13 @@ void OperationFormat::genElementPrinter(FormatElement *element,
return;
}
+ // Emit the num-breaking-regions.
+ if (isa<NumBreakingRegionsDirective>(element)) {
+ body << " _odsPrinter << \" \" << "
+ "getOperation()->getNumBreakingControlRegions();\n";
+ return;
+ }
+
// Optionally insert a space before the next element. The AttrDict printer
// already adds a space as necessary.
if (shouldEmitSpace || !lastWasPunctuation)
@@ -2819,6 +2853,8 @@ class OpFormatParser : public FormatParser {
FailureOr<FormatElement *> parseTypeDirective(SMLoc loc, Context context);
FailureOr<FormatElement *> parseTypeDirectiveOperand(SMLoc loc,
bool isRefChild = false);
+ FailureOr<FormatElement *> parseNumBreakingRegionsDirective(SMLoc loc,
+ Context context);
//===--------------------------------------------------------------------===//
// Fields
@@ -3446,6 +3482,8 @@ OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
return parseTypeDirective(loc, ctx);
case FormatToken::kw_oilist:
return parseOIListDirective(loc, ctx);
+ case FormatToken::kw_num_breaking_regions:
+ return parseNumBreakingRegionsDirective(loc, ctx);
default:
return emitError(loc, "unsupported directive kind");
@@ -3696,6 +3734,14 @@ FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,
return create<TypeDirective>(*operand);
}
+FailureOr<FormatElement *>
+OpFormatParser::parseNumBreakingRegionsDirective(SMLoc loc, Context context) {
+ if (context != TopLevelContext)
+ return emitError(
+ loc, "'num-breaking-regions' is only valid as a top-level directive");
+ return create<NumBreakingRegionsDirective>();
+}
+
LogicalResult OpFormatParser::markQualified(SMLoc loc, FormatElement *element) {
return TypeSwitch<FormatElement *, LogicalResult>(element)
.Case<AttributeVariable, TypeDirective>([](auto *element) {
diff --git a/mlir/unittests/Debug/FileLineColLocBreakpointManagerTest.cpp b/mlir/unittests/Debug/FileLineColLocBreakpointManagerTest.cpp
index 6a81422b6b66b..8f75fe05cd3dc 100644
--- a/mlir/unittests/Debug/FileLineColLocBreakpointManagerTest.cpp
+++ b/mlir/unittests/Debug/FileLineColLocBreakpointManagerTest.cpp
@@ -23,9 +23,10 @@ static Operation *createOp(MLIRContext *context, Location loc,
StringRef operationName,
unsigned int numRegions = 0) {
context->allowUnregisteredDialects();
+
return Operation::create(loc, OperationName(operationName, context), {}, {},
NamedAttrList(), OpaqueProperties(nullptr), {},
- numRegions);
+ numRegions, /*numBreakingControlRegions=*/0);
}
namespace {
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index 9f3e7ed34a27d..7e6dcc586e506 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -24,7 +24,8 @@ static Operation *createOp(MLIRContext *context, ArrayRef<Value> operands = {},
context->allowUnregisteredDialects();
return Operation::create(UnknownLoc::get(context),
OperationName("foo.bar", context), resultTypes,
- operands, NamedAttrList(), nullptr, {}, numRegions);
+ operands, NamedAttrList(), nullptr, {}, numRegions,
+ /*numBreakingControlRegions=*/0);
}
namespace {
@@ -236,7 +237,7 @@ TEST(OperationFormatPrintTest, CanPrintNameAsPrefix) {
Operation *op = Operation::create(
NameLoc::get(StringAttr::get(&context, "my_named_loc")),
OperationName("t.op", &context), builder.getIntegerType(16), {},
- NamedAttrList(), nullptr, {}, 0);
+ NamedAttrList(), nullptr, {}, 0, /*numBreakingControlRegions=*/0);
std::string str;
OpPrintingFlags flags;
diff --git a/mlir/unittests/IR/ValueTest.cpp b/mlir/unittests/IR/ValueTest.cpp
index 97e32d474d522..18ed8d9929175 100644
--- a/mlir/unittests/IR/ValueTest.cpp
+++ b/mlir/unittests/IR/ValueTest.cpp
@@ -22,7 +22,8 @@ static Operation *createOp(MLIRContext *context, ArrayRef<Value> operands = {},
context->allowUnregisteredDialects();
return Operation::create(UnknownLoc::get(context),
OperationName("foo.bar", context), resultTypes,
- operands, NamedAttrList(), nullptr, {}, numRegions);
+ operands, NamedAttrList(), nullptr, {}, numRegions,
+ /*numBreakingControlRegions=*/0);
}
namespace {
diff --git a/mlir/unittests/Transforms/DialectConversion.cpp b/mlir/unittests/Transforms/DialectConversion.cpp
index 6418c9dc0ac5b..2c7b35e9ef69c 100644
--- a/mlir/unittests/Transforms/DialectConversion.cpp
+++ b/mlir/unittests/Transforms/DialectConversion.cpp
@@ -15,7 +15,8 @@ static Operation *createOp(MLIRContext *context) {
context->allowUnregisteredDialects();
return Operation::create(UnknownLoc::get(context),
OperationName("foo.bar", context), {}, {},
- NamedAttrList(), /*properties=*/nullptr, {}, 0);
+ NamedAttrList(), /*properties=*/nullptr, {}, 0,
+ /*numBreakingControlRegions=*/0);
}
namespace {