[Mlir-commits] [mlir] [MLIR] Introduce support for early exits (PR #166688)
Mehdi Amini
llvmlistbot at llvm.org
Wed Nov 5 18:49:07 PST 2025
https://github.com/joker-eph created https://github.com/llvm/llvm-project/pull/166688
WIP: documentation is mostly missing, and more dataflow fixes may be needed.
The traits supporting this need to be revisited, and region termination and the requirements on the numBreakingRegion index need to be formalized.
Inlining is currently incompatible with early returns; supporting it would require inserting an `scf.execute_region` somehow.
From 887e71d66e7991f09d3f18178c54b1e457fcfae5 Mon Sep 17 00:00:00 2001
From: Mehdi Amini <joker.eph at gmail.com>
Date: Sat, 26 Apr 2025 04:51:22 -0700
Subject: [PATCH] [MLIR] Introduce support for early exits
WIP: documentation is mostly missing, and more dataflow
fixes may be needed.
The traits supporting this need to be revisited, and region
termination and the requirements on the numBreakingRegion index
need to be formalized.
Inlining is currently incompatible with early returns; supporting
it would require inserting an `scf.execute_region` somehow.
---
mlir/include/mlir/Dialect/SCF/IR/SCF.h | 45 +++++
mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 122 ++++++++++++-
mlir/include/mlir/IR/Diagnostics.h | 2 +-
mlir/include/mlir/IR/OpDefinition.h | 23 ++-
mlir/include/mlir/IR/Operation.h | 56 ++++--
mlir/include/mlir/IR/OperationSupport.h | 5 +
mlir/include/mlir/IR/RegionKindInterface.h | 32 ++++
mlir/include/mlir/IR/RegionKindInterface.td | 37 ++++
.../mlir/Interfaces/ControlFlowInterfaces.td | 20 +++
mlir/lib/AsmParser/Parser.cpp | 50 ++++--
.../SCFToControlFlow/SCFToControlFlow.cpp | 141 ++++++++++++++-
mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp | 2 +-
.../ShapeToStandard/ShapeToStandard.cpp | 3 +-
.../OwnershipBasedBufferDeallocation.cpp | 8 +-
.../GPU/Transforms/AsyncRegionRewriter.cpp | 3 +-
.../Quant/Transforms/NormalizeQuantTypes.cpp | 3 +-
mlir/lib/Dialect/SCF/IR/SCF.cpp | 169 +++++++++++++++++-
.../TosaConvertIntegerTypeToSignless.cpp | 3 +-
mlir/lib/IR/AsmPrinter.cpp | 4 +
mlir/lib/IR/Diagnostics.cpp | 4 +-
mlir/lib/IR/Dominance.cpp | 23 +++
mlir/lib/IR/Operation.cpp | 45 +++--
mlir/lib/IR/PatternMatch.cpp | 5 +-
mlir/lib/IR/RegionKindInterface.cpp | 123 +++++++++++++
mlir/lib/IR/Verifier.cpp | 37 ++++
mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 33 ++--
mlir/lib/Transforms/Utils/CMakeLists.txt | 3 +
mlir/lib/Transforms/Utils/InliningUtils.cpp | 7 +
.../convert-early-exit-to-cfg.mlir | 47 +++++
mlir/test/IR/early-exit-invalid.mlir | 65 +++++++
mlir/test/IR/early-exit.mlir | 64 +++++++
.../Integration/Dialect/SCF/early_exit.mlir | 82 +++++++++
mlir/test/lib/Interfaces/CMakeLists.txt | 1 +
.../RegionBranchOpInterface/CMakeLists.txt | 9 +
.../TestRegionBranchOpInterface.cpp | 76 ++++++++
mlir/tools/mlir-opt/CMakeLists.txt | 1 +
mlir/tools/mlir-opt/mlir-opt.cpp | 2 +
mlir/tools/mlir-tblgen/FormatGen.cpp | 1 +
mlir/tools/mlir-tblgen/FormatGen.h | 2 +
mlir/tools/mlir-tblgen/OpFormatGen.cpp | 37 ++++
.../FileLineColLocBreakpointManagerTest.cpp | 3 +-
mlir/unittests/IR/OperationSupportTest.cpp | 5 +-
mlir/unittests/IR/ValueTest.cpp | 3 +-
.../Transforms/DialectConversion.cpp | 3 +-
44 files changed, 1318 insertions(+), 91 deletions(-)
create mode 100644 mlir/test/Conversion/SCFToControlFlow/convert-early-exit-to-cfg.mlir
create mode 100644 mlir/test/IR/early-exit-invalid.mlir
create mode 100644 mlir/test/IR/early-exit.mlir
create mode 100644 mlir/test/Integration/Dialect/SCF/early_exit.mlir
create mode 100644 mlir/test/lib/Interfaces/RegionBranchOpInterface/CMakeLists.txt
create mode 100644 mlir/test/lib/Interfaces/RegionBranchOpInterface/TestRegionBranchOpInterface.cpp
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index e754a04b0903a..4307a8f82211c 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -29,6 +29,12 @@
namespace mlir {
namespace scf {
void buildTerminatedBody(OpBuilder &builder, Location loc);
+
+// Implicit terminator types for scf.if/scf.loop; defined later in this header.
+namespace op_impl {
+struct IfOpImplicitTerminatorType;
+struct LoopOpImplicitTerminatorType;
+} // namespace op_impl
} // namespace scf
} // namespace mlir
@@ -111,6 +116,46 @@ SmallVector<Value> replaceAndCastForOpIterArg(RewriterBase &rewriter,
OpOperand &operand,
Value replacement,
const ValueTypeCastFnTy &castFn);
+namespace op_impl {
+
+//===----------------------------------------------------------------------===//
+// ControlFlowImplicitTerminatorOpType
+//===----------------------------------------------------------------------===//
+
+/// This class provides an interface compatible with
+/// SingleBlockImplicitTerminator, but allows multiple types of potential
+/// terminators aside from just one. If a terminator isn't present, this will
+/// generate an `ImplicitOpT` operation.
+template <typename ImplicitOpT, typename... OtherTerminatorOpTs>
+struct ControlFlowImplicitTerminatorOpType {
+ /// Implementation of `classof` that supports all of the potential terminator
+ /// operations.
+ static bool classof(Operation *op) {
+ return isa<ImplicitOpT, OtherTerminatorOpTs...>(op);
+ }
+
+ //===--------------------------------------------------------------------===//
+ // Implicit Terminator Methods
+
+ /// The methods below are used when interacting with the implicit terminator.
+
+ template <typename... Args>
+ static void build(Args &&...args) {
+ ImplicitOpT::build(std::forward<Args>(args)...);
+ }
+ static constexpr StringLiteral getOperationName() {
+ return ImplicitOpT::getOperationName();
+ }
+};
+/// An implicit terminator type for `if` operations, which can contain:
+/// break, continue, yield.
+struct IfOpImplicitTerminatorType
+ : public ControlFlowImplicitTerminatorOpType<YieldOp, BreakOp, ContinueOp> {
+};
+/// An implicit terminator type for `loop` operations: continue, break.
+struct LoopOpImplicitTerminatorType
+ : public ControlFlowImplicitTerminatorOpType<ContinueOp, BreakOp> {};
+} // namespace op_impl
/// Helper function to compute the difference between two values. This is used
/// by the loop implementations to compute the trip count.
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index cd033c140a233..eb9b99d426ba7 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -15,6 +15,7 @@
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
+include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface, used by LoopOp.
include "mlir/IR/RegionKindInterface.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
@@ -143,6 +144,123 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+def LoopOp : SCF_Op<"loop", [
+ AutomaticAllocationScope,
+ OpAsmOpInterface,
+ RecursiveMemoryEffects,
+ PropagateControlFlowBreak,
+ DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+ SingleBlockImplicitTerminator<"op_impl::LoopOpImplicitTerminatorType">,
+ HasBreakingControlFlowOpInterface
+ ]> {
+ let summary = "Loop until a break operation";
+ let description = [{
+ The `loop` operation represents an unstructured, infinite loop that executes
+ until a `break` is reached.
+
+ The loop consists of (1) a set of loop-carried values which are initialized by
+ `initValues` and updated by each iteration of the loop, and
+ (2) a region which represents the loop body.
+
+ The loop will execute the body of the loop until a `break` is dynamically executed.
+
+ Each control path of the loop must be terminated by:
+
+ - a `continue` that yields the next iteration's value for each loop-carried variable.
+ - a `break` that terminates the loop and yields the final loop-carried values.
+
+ As long as each loop iteration is terminated by one of these operations, they may be combined with other control
+ flow operations to express different control flow patterns.
+
+ The loop operation produces one return value for each loop-carried variable. The type of the `i`-th return
+ value is that of the `i`-th loop-carried variable and its value is the final value of the
+ `i`-th loop-carried variable.
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$initValues);
+ let results = (outs Variadic<AnyType>:$resultValues);
+ let regions = (region SizedRegion<1>:$region);
+
+ let extraClassDeclaration = [{
+ /// Return true if `predecessor` is a `scf.break` or `scf.continue` op.
+ static bool acceptsTerminator(Operation *predecessor) {
+ return isa<BreakOp, ContinueOp>(predecessor);
+ }
+
+ /// Return the iteration values of the loop region.
+ Block::BlockArgListType getRegionIterValues() {
+ return getRegion().getArguments();
+ }
+
+ /// Return the `index`-th region iteration value.
+ BlockArgument getRegionIterValue(unsigned index) {
+ return getRegionIterValues()[index];
+ }
+
+ /// Returns the number of region arguments for loop-carried values.
+ unsigned getNumRegionIterValues() { return getRegion().getNumArguments(); }
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+ let hasRegionVerifier = 1;
+}
+
+//===----------------------------------------------------------------------===//
+// BreakOp
+//===----------------------------------------------------------------------===//
+
+def BreakOp : SCF_Op<"break", [
+ ReturnLike, Terminator, RegionBranchTerminatorOpInterface, ParentOneOf<["IfOp", "LoopOp"]>
+ ]> {
+ let summary = "Break from loop";
+ let description = [{
+ The `break` operation is a terminator that exits an enclosing `scf.loop`;
+ it may appear directly in the loop body or inside a nested `scf.if`.
+ It may yield any number of operands to the parent loop upon termination.
+ The number of values yielded and the execution semantics of how they are
+ yielded are determined by the parent loop.
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$operands);
+ // Like `scf.continue`, a default-built break has one breaking control region.
+ let builders = [OpBuilder<(ins), [{
+ $_state.setNumBreakingControlRegions(1);
+ }]>];
+ let assemblyFormat = [{
+ num-breaking-regions attr-dict ($operands^ `:` type($operands))?
+ }];
+}
+
+//===----------------------------------------------------------------------===//
+// ContinueOp
+//===----------------------------------------------------------------------===//
+
+def ContinueOp : SCF_Op<"continue", [
+ Terminator, DeclareOpInterfaceMethods<RegionBranchTerminatorOpInterface>, ParentOneOf<["IfOp", "LoopOp"]>
+ ]> {
+ let summary = "Continue to next loop iteration";
+ let description = [{
+ The continue operation represents a block terminator that returns control to
+ a `scf.loop` operation. The operation may yield any number of operands to the parent
+ loop upon termination.
+
+ The requirements and semantics of the continue operation are defined by the parent loop
+ operation, see the loop operation's description for particular semantics.
+ }];
+
+ let arguments = (ins Variadic<AnyType>:$operands); // Next iteration's values.
+ let builders = [OpBuilder<(ins), [{
+ $_state.setNumBreakingControlRegions(1);
+ }]>]; // Default-built continues have one breaking control region.
+ let assemblyFormat = [{
+ num-breaking-regions ($operands^ `:` type($operands))? attr-dict
+ }];
+ let hasVerifier = 1;
+}
//===----------------------------------------------------------------------===//
// ForOp
@@ -700,8 +818,8 @@ def InParallelOp : SCF_Op<"forall.in_parallel", [
def IfOp : SCF_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
"getNumRegionInvocations", "getRegionInvocationBounds",
"getEntrySuccessorRegions"]>,
- InferTypeOpAdaptor, SingleBlockImplicitTerminator<"scf::YieldOp">,
- RecursiveMemoryEffects, RecursivelySpeculatable, NoRegionArguments]> {
+ InferTypeOpAdaptor, SingleBlockImplicitTerminator<"op_impl::IfOpImplicitTerminatorType">, // scf.yield, scf.break or scf.continue
+ RecursiveMemoryEffects, RecursivelySpeculatable, NoRegionArguments, PropagateControlFlowBreak]> {
let summary = "if-then-else operation";
let description = [{
The `scf.if` operation represents an if-then-else construct for
diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
index a0a99f4953822..8adf928984469 100644
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -200,7 +200,7 @@ class Diagnostic {
/// Stream in an Operation.
Diagnostic &operator<<(Operation &op);
- Diagnostic &operator<<(OpWithFlags op);
+ Diagnostic &operator<<(const OpWithFlags &opWithFlags); // By const ref: no copy.
Diagnostic &operator<<(Operation *op) { return *this << *op; }
/// Append an operation with the given printing flags.
Diagnostic &appendOp(Operation &op, const OpPrintingFlags &flags);
diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h
index be92fe0a6c7e3..27aa437a06ce2 100644
--- a/mlir/include/mlir/IR/OpDefinition.h
+++ b/mlir/include/mlir/IR/OpDefinition.h
@@ -891,7 +891,8 @@ struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> {
// Non-empty regions must contain a single basic block.
if (!region.hasOneBlock())
return op->emitOpError("expects region #")
- << i << " to have 0 or 1 blocks";
+ << i << " to have 0 or 1 blocks, found "
+ << llvm::range_size(region) << " blocks"; // Report the actual count.
if (!ConcreteType::template hasTrait<NoTerminator>()) {
Block &block = region.front();
@@ -1323,6 +1324,26 @@ struct HasParent {
};
};
+/// This class provides a verifier for ops that are expecting to have nested
+/// predecessors, each of which must be one of `NestedPredecessorOpTypes`.
+template <typename... NestedPredecessorOpTypes>
+struct HasNestedTerminators {
+ template <typename ConcreteType>
+ class Impl : public TraitBase<ConcreteType, Impl> {
+ public:
+ static LogicalResult verifyTrait(Operation *op) {
+ // NOTE(review): there is no success path yet, so this verifier always fails.
+ return op->emitOpError()
+ << "expects nested predecessor op "
+ << (sizeof...(NestedPredecessorOpTypes) != 1 ? "to be one of '"
+ : "'")
+ << llvm::ArrayRef(
+ {NestedPredecessorOpTypes::getOperationName()...})
+ << "'";
+ }
+ };
+};
+
/// A trait for operations that have an attribute specifying operand segments.
///
/// Certain operations can have multiple variadic operands and their size
diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h
index b2019574a820d..178ea90d2699e 100644
--- a/mlir/include/mlir/IR/Operation.h
+++ b/mlir/include/mlir/IR/Operation.h
@@ -94,7 +94,8 @@ class alignas(8) Operation final
TypeRange resultTypes, ValueRange operands,
NamedAttrList &&attributes,
OpaqueProperties properties, BlockRange successors,
- unsigned numRegions);
+ unsigned numRegions,
+ unsigned numBreakingControlRegions); // 0 if not breaking control flow.
/// Create a new Operation with the specific fields. This constructor uses an
/// existing attribute dictionary to avoid uniquing a list of attributes.
@@ -102,7 +103,8 @@ class alignas(8) Operation final
TypeRange resultTypes, ValueRange operands,
DictionaryAttr attributes,
OpaqueProperties properties, BlockRange successors,
- unsigned numRegions);
+ unsigned numRegions,
+ unsigned numBreakingControlRegions); // 0 if not breaking control flow.
/// Create a new Operation from the fields stored in `state`.
static Operation *create(const OperationState &state);
@@ -112,8 +114,8 @@ class alignas(8) Operation final
TypeRange resultTypes, ValueRange operands,
NamedAttrList &&attributes,
OpaqueProperties properties,
- BlockRange successors = {},
- RegionRange regions = {});
+ BlockRange successors = {}, RegionRange regions = {},
+ unsigned numBreakingControlRegions = 0); // 0 if not breaking control flow.
/// The name of an operation is the key identifier for it.
OperationName getName() { return name; }
@@ -705,6 +707,27 @@ class alignas(8) Operation final
bool hasSuccessors() { return numSuccs != 0; }
unsigned getNumSuccessors() { return numSuccs; }
+ /// Return true if this operation is flagged as breaking control flow. Only
+ /// such operations store a number of breaking control regions.
+ bool isBreakingControlFlow() { return isBreakingControlFlowFlag; }
+ /// Return the number of breaking control regions of this operation, or 0 if
+ /// it does not break control flow. The count is stored at the start of the
+ /// properties trailing object.
+ unsigned getNumBreakingControlRegions() {
+ if (!isBreakingControlFlow())
+ return 0;
+ return *reinterpret_cast<unsigned *>(
+ getTrailingObjects<detail::OpProperties>());
+ }
+ /// Set the number of breaking control regions of this operation. Asserts
+ /// that the operation breaks control flow.
+ void setNumBreakingControlRegions(unsigned numBreakingControlRegions) {
+ assert(isBreakingControlFlow() &&
+ "operation is not a breaking control flow operation");
+ *reinterpret_cast<unsigned *>(getTrailingObjects<detail::OpProperties>()) =
+ numBreakingControlRegions;
+ }
+
Block *getSuccessor(unsigned index) {
assert(index < getNumSuccessors());
return getBlockOperands()[index].get();
@@ -898,14 +914,26 @@ class alignas(8) Operation final
}
/// Returns the properties storage.
OpaqueProperties getPropertiesStorage() {
- if (propertiesStorageSize)
- return getPropertiesStorageUnsafe();
- return {nullptr};
- }
+ // Delegate to the const overload so that the offset computation for
+ // breaking control-flow operations lives in a single place.
+ return static_cast<const Operation *>(this)->getPropertiesStorage();
+ }
+
+ /// Size in bytes of the slot that breaking control-flow operations reserve at
+ /// the start of the properties storage for their region count (an
+ /// `unsigned`, in an 8-byte slot). The actual properties follow this slot.
+ static constexpr size_t breakingControlRegionsSlotSize = 8;
+
OpaqueProperties getPropertiesStorage() const {
- if (propertiesStorageSize)
- return {reinterpret_cast<void *>(const_cast<detail::OpProperties *>(
- getTrailingObjects<detail::OpProperties>()))};
+ if (propertiesStorageSize) {
+ void *properties =
+ reinterpret_cast<void *>(const_cast<detail::OpProperties *>(
+ getTrailingObjects<detail::OpProperties>()));
+ if (isBreakingControlFlowFlag)
+ properties = reinterpret_cast<void *>(
+ reinterpret_cast<char *>(properties) + breakingControlRegionsSlotSize);
+ return {properties};
+ }
return {nullptr};
}
/// Returns the properties storage without checking whether properties are
@@ -960,8 +988,9 @@ class alignas(8) Operation final
private:
Operation(Location location, OperationName name, unsigned numResults,
unsigned numSuccessors, unsigned numRegions,
- int propertiesStorageSize, DictionaryAttr attributes,
- OpaqueProperties properties, bool hasOperandStorage);
+ unsigned numBreakingControlRegions, int propertiesStorageSize,
+ DictionaryAttr attributes, OpaqueProperties properties,
+ bool hasOperandStorage);
// Operations are deleted through the destroy() member because they are
// allocated with malloc.
@@ -1048,13 +1077,16 @@ class alignas(8) Operation final
const unsigned numResults;
const unsigned numSuccs;
- const unsigned numRegions : 23;
+ const unsigned numRegions : 22;
/// This bit signals whether this operation has an operand storage or not. The
/// operand storage may be elided for operations that are known to never have
/// operands.
bool hasOperandStorage : 1;
+ /// This bit signals whether this operation breaks control flow.
+ bool isBreakingControlFlowFlag : 1;
+
/// The size of the storage for properties (if any), divided by 8: since the
/// Properties storage will always be rounded up to the next multiple of 8 we
/// save some bits here.
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h
index 1ff7c56ddca38..1334d7ac08d62 100644
--- a/mlir/include/mlir/IR/OperationSupport.h
+++ b/mlir/include/mlir/IR/OperationSupport.h
@@ -971,6 +971,9 @@ struct OperationState {
llvm::function_ref<void(OpaqueProperties)> propertiesDeleter;
llvm::function_ref<void(OpaqueProperties, const OpaqueProperties)>
propertiesSetter;
+ /// Number of breaking control regions of the operation to create; stays 0
+ /// for operations that do not break control flow.
+ unsigned numBreakingControlRegions = 0;
friend class Operation;
public:
@@ -1096,6 +1099,11 @@ struct OperationState {
}
void addSuccessors(BlockRange newSuccessors);
+ /// Set the number of breaking control regions of the operation to create.
+ void setNumBreakingControlRegions(unsigned numBreakingControlRegions) {
+ this->numBreakingControlRegions = numBreakingControlRegions;
+ }
+
/// Create a region that should be attached to the operation. These regions
/// can be filled in immediately without waiting for Operation to be
/// created. When it is, the region bodies will be transferred.
diff --git a/mlir/include/mlir/IR/RegionKindInterface.h b/mlir/include/mlir/IR/RegionKindInterface.h
index d6d3aeeb9bd05..6514d9ad83c25 100644
--- a/mlir/include/mlir/IR/RegionKindInterface.h
+++ b/mlir/include/mlir/IR/RegionKindInterface.h
@@ -36,6 +36,18 @@ class HasOnlyGraphRegion : public TraitBase<ConcreteType, HasOnlyGraphRegion> {
static RegionKind getRegionKind(unsigned index) { return RegionKind::Graph; }
static bool hasSSADominance(unsigned index) { return false; }
};
+
+/// Indicates that this operation may break control flow, by propagating the
+/// control flow break from a nested region.
+template <typename ConcreteType>
+class PropagateControlFlowBreak
+ : public TraitBase<ConcreteType, PropagateControlFlowBreak> {
+public:
+ static LogicalResult verifyTrait(Operation *op) {
+ return success(); // TODO
+ }
+};
+
} // namespace OpTrait
/// Return "true" if the given region may have SSA dominance. This function also
@@ -49,8 +61,36 @@ bool mayHaveSSADominance(Region &region);
/// implement the RegionKindInterface.
bool mayBeGraphRegion(Region &region);
+/// Return "true" if the given operation has predecessors nested within its
+/// regions.
+bool hasNestedPredecessors(Operation *op);
+
+/// Return "true" if the given operation may break control flow and contains
+/// nested operations that have a successor above this operation.
+bool hasBreakingControlFlowOps(Operation *op);
+
+/// Collect into `predecessors` the operations nested within `op` that are
+/// predecessors of `op`.
+void collectAllNestedPredecessors(Operation *op,
+ SmallVector<Operation *> &predecessors);
+
} // namespace mlir
#include "mlir/IR/RegionKindInterface.h.inc"
+namespace mlir {
+
+/// Return true if the given region may break control flow, i.e. if its parent
+/// operation propagates control flow breaks or accepts nested predecessors.
+/// Return false for a region that is not attached to an operation.
+inline bool hasBreakingControlFlow(Region *region) {
+ Operation *parentOp = region->getParentOp();
+ if (!parentOp)
+ return false;
+ return parentOp->hasTrait<OpTrait::PropagateControlFlowBreak>() ||
+ isa<HasBreakingControlFlowOpInterface>(parentOp);
+}
+
+} // namespace mlir
+
#endif // MLIR_IR_REGIONKINDINTERFACE_H_
diff --git a/mlir/include/mlir/IR/RegionKindInterface.td b/mlir/include/mlir/IR/RegionKindInterface.td
index 607001a89250e..f0c9ca28e0100 100644
--- a/mlir/include/mlir/IR/RegionKindInterface.td
+++ b/mlir/include/mlir/IR/RegionKindInterface.td
@@ -61,4 +61,48 @@ def GraphRegionNoTerminator : TraitList<[
HasOnlyGraphRegion
]>;
+// Op can break control flow by propagating a control flow break from a nested
+// region.
+def PropagateControlFlowBreak : NativeOpTrait<"PropagateControlFlowBreak">;
+
+// OpInterface for operation accepting predecessors that are potentially
+// in nested regions (not the immediately enclosed ones).
+def HasBreakingControlFlowOpInterface : OpInterface<"HasBreakingControlFlowOpInterface"> {
+ let description = [{
+ Interface for operations accepting predecessors that may be located in
+ nested regions, not only in the regions they immediately enclose.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ StaticInterfaceMethod<
+ /*desc=*/[{
+ Return true if the terminator `op` is accepted as a predecessor of this
+ operation. By default, every terminator is accepted.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"acceptsTerminator",
+ /*args=*/(ins "Operation *":$op),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return true;
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Return true if this operation has predecessors nested in its regions.
+ Defaults to `::mlir::hasNestedPredecessors`.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"hasNestedPredecessors",
+ /*args=*/(ins),
+ /*methodBody=*/[{}],
+ /*defaultImplementation=*/[{
+ return ::mlir::hasNestedPredecessors(this->getOperation());
+ }]
+ >
+ ];
+}
+
+
#endif // MLIR_IR_REGIONKINDINTERFACE
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 94242e3ba39ce..f5c60300f2079 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -449,6 +449,29 @@ def SelectLikeOpInterface : OpInterface<"SelectLikeOpInterface"> {
];
}
+
+//===----------------------------------------------------------------------===//
+// Early-exit Interfaces
+//===----------------------------------------------------------------------===//
+
+// NOTE(review): the descriptions below are still empty. Presumably
+// `getSuccessor` returns the operation that receives control when this op
+// interrupts control flow; confirm and document.
+def InterruptControlFlowOpInterface : OpInterface<"InterruptControlFlowOpInterface"> {
+ let description = [{
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<[{
+ }],
+ "::mlir::Operation *", "getSuccessor",
+ (ins)
+ >,
+ ];
+}
+
+
//===----------------------------------------------------------------------===//
// WeightedBranchOpInterface
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
index 74936e32bd9d9..18363e86c258a 100644
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -638,7 +638,8 @@ class OperationParser : public Parser {
ParseResult parseSuccessor(Block *&dest);
/// Parse a comma-separated list of operation successors in brackets.
- ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations);
+ ParseResult parseSuccessors(SmallVectorImpl<Block *> &destinations,
+ bool parseOpeningBracket = true); // false: '[' already consumed.
/// Parse an operation instance that is in the generic form.
Operation *parseGenericOperation();
@@ -657,7 +658,8 @@ class OperationParser : public Parser {
std::nullopt,
std::optional<ArrayRef<NamedAttribute>> parsedAttributes = std::nullopt,
std::optional<Attribute> propertiesAttribute = std::nullopt,
- std::optional<FunctionType> parsedFnType = std::nullopt);
+ std::optional<FunctionType> parsedFnType = std::nullopt,
+ std::optional<int> parsedNumBreakingControlRegions = std::nullopt);
/// Parse an operation instance that is in the generic form and insert it at
/// the provided insertion point.
@@ -1199,7 +1201,7 @@ Value OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
auto *op = Operation::create(
getEncodedSourceLocation(loc), name, type, /*operands=*/{},
/*attributes=*/NamedAttrList(), /*properties=*/nullptr,
- /*successors=*/{}, /*numRegions=*/0);
+ /*successors=*/{}, /*numRegions=*/0, /*numBreakingControlRegions=*/0);
forwardRefPlaceholders[op->getResult(0)] = loc;
forwardRefOps.insert(op);
return op->getResult(0);
@@ -1343,8 +1345,10 @@ ParseResult OperationParser::parseSuccessor(Block *&dest) {
/// successor-list ::= `[` successor (`,` successor )* `]`
///
ParseResult
-OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations) {
- if (parseToken(Token::l_square, "expected '['"))
+OperationParser::parseSuccessors(SmallVectorImpl<Block *> &destinations,
+ bool parseOpeningBracket) {
+ // Callers that already consumed the '[' pass `parseOpeningBracket = false`.
+ if (parseOpeningBracket && parseToken(Token::l_square, "expected '['"))
return failure();
auto parseElt = [this, &destinations] {
@@ -1382,7 +1385,8 @@ ParseResult OperationParser::parseGenericOperationAfterOpName(
std::optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions,
std::optional<ArrayRef<NamedAttribute>> parsedAttributes,
std::optional<Attribute> propertiesAttribute,
- std::optional<FunctionType> parsedFnType) {
+ std::optional<FunctionType> parsedFnType,
+ std::optional<int> parsedNumBreakingControlRegions) {
// Parse the operand list, if not explicitly provided.
SmallVector<UnresolvedOperand, 8> opInfo;
@@ -1396,19 +1400,46 @@ ParseResult OperationParser::parseGenericOperationAfterOpName(
}
// Parse the successor list, if not explicitly provided.
- if (!parsedSuccessors) {
+ // The bracketed list after the operands holds either the successors
+ // (`[^bb0, ...]`) or the number of breaking control regions (`[N]`). Each is
+ // only parsed from the input when the caller did not provide it explicitly.
+ if (parsedSuccessors)
+ result.addSuccessors(*parsedSuccessors);
+ if (parsedNumBreakingControlRegions)
+ result.setNumBreakingControlRegions(*parsedNumBreakingControlRegions);
+ if (!parsedSuccessors || !parsedNumBreakingControlRegions) {
if (getToken().is(Token::l_square)) {
+ consumeToken(Token::l_square);
+
// Check if the operation is not a known terminator.
if (!result.name.mightHaveTrait<OpTrait::IsTerminator>())
return emitError("successors in non-terminator");
- SmallVector<Block *, 2> successors;
- if (parseSuccessors(successors))
- return failure();
- result.addSuccessors(successors);
+ // If we don't have a ^, then we expect a single integer for the number
+ // of breaking control regions.
+ if (!getToken().is(Token::caret_identifier)) {
+ if (!getToken().is(Token::integer))
+ return emitError("expected `^` or integer after '['");
+ if (parsedNumBreakingControlRegions)
+ return emitError("unexpected breaking control region count");
+ std::optional<unsigned> numBreakingControlRegions =
+ getToken().getUnsignedIntegerValue();
+ if (!numBreakingControlRegions)
+ return emitError("breaking control region count is too large");
+ consumeToken(Token::integer);
+ result.setNumBreakingControlRegions(*numBreakingControlRegions);
+ if (failed(parseToken(Token::r_square,
+ "expected ']' to end breaking control regions")))
+ return failure();
+ } else {
+ if (parsedSuccessors)
+ return emitError("unexpected successor list");
+ SmallVector<Block *, 2> successors;
+ if (parseSuccessors(successors, /*parseOpeningBracket=*/false))
+ return failure();
+ result.addSuccessors(successors);
+ }
}
- } else {
- result.addSuccessors(*parsedSuccessors);
}
// Parse the properties, if not explicitly provided.
diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
index 37cfc9f2c23e6..04d3356ce7b0c 100644
--- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
+++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp
@@ -310,6 +310,16 @@ struct ForallLowering : public OpRewritePattern<mlir::scf::ForallOp> {
PatternRewriter &rewriter) const override;
};
+/// Lowers `scf.loop` to CFG by converting breaks/continues to branches.
+struct LoopOpLowering : public OpConversionPattern<LoopOp> {
+ using OpConversionPattern<LoopOp>::OpConversionPattern;
+ void initialize() { setHasBoundedRewriteRecursion(); }
+
+ LogicalResult
+ matchAndRewrite(LoopOp loopOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
} // namespace
static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) {
@@ -400,7 +410,14 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
PatternRewriter &rewriter) const {
auto loc = ifOp.getLoc();
-
+ // The `scf.if` regions are inlined into the parent region below; decrement
+ // the breaking control region count of every nested breaking operation.
+ // TODO: this is incorrect for nested loops
+ ifOp.walk([&](Operation *op) {
+ if (unsigned count = op->getNumBreakingControlRegions())
+ rewriter.modifyOpInPlace(
+ op, [&] { op->setNumBreakingControlRegions(count - 1); });
+ });
// Start by splitting the block containing the 'scf.if' into two parts.
// The part before will contain the condition, the part after will be the
// continuation point.
@@ -424,8 +441,11 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
Operation *thenTerminator = thenRegion.back().getTerminator();
ValueRange thenTerminatorOperands = thenTerminator->getOperands();
rewriter.setInsertionPointToEnd(&thenRegion.back());
- cf::BranchOp::create(rewriter, loc, continueBlock, thenTerminatorOperands);
- rewriter.eraseOp(thenTerminator);
+ // Break/continue terminators are left for the enclosing loop lowering.
+ if (isa<scf::YieldOp>(thenTerminator)) {
+ cf::BranchOp::create(rewriter, loc, continueBlock, thenTerminatorOperands);
+ rewriter.eraseOp(thenTerminator);
+ }
rewriter.inlineRegionBefore(thenRegion, continueBlock);
// Move blocks from the "else" region (if present) to the region containing
@@ -438,8 +458,11 @@ LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
Operation *elseTerminator = elseRegion.back().getTerminator();
ValueRange elseTerminatorOperands = elseTerminator->getOperands();
rewriter.setInsertionPointToEnd(&elseRegion.back());
- cf::BranchOp::create(rewriter, loc, continueBlock, elseTerminatorOperands);
- rewriter.eraseOp(elseTerminator);
+ if (isa<scf::YieldOp>(elseTerminator)) {
+ cf::BranchOp::create(rewriter, loc, continueBlock,
+ elseTerminatorOperands);
+ rewriter.eraseOp(elseTerminator);
+ }
rewriter.inlineRegionBefore(elseRegion, continueBlock);
}
@@ -719,11 +739,113 @@ LogicalResult ForallLowering::matchAndRewrite(ForallOp forallOp,
return scf::forallToParallelLoop(rewriter, forallOp);
}
+LogicalResult
+LoopOpLowering::matchAndRewrite(LoopOp loopOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+ // Lower `scf.loop` to a CFG: the body's entry block becomes the loop header,
+ // `scf.continue` branches back to it and `scf.break` exits to a new block.
+ {
+ OpBuilder::InsertionGuard guard(rewriter);
+ if (failed(rewriter.legalize(&loopOp.getRegion())))
+ return rewriter.notifyMatchFailure(loopOp,
+ "failed to convert nested region");
+ }
+
+ // After legalizing the body, every predecessor must sit directly in it (one
+ // breaking region); deeper ones cannot be rewritten into branches here.
+ SmallVector<Operation *> predecessors;
+ collectAllNestedPredecessors(loopOp, predecessors);
+ for (Operation *predecessor : predecessors) {
+ if (predecessor->getNumBreakingControlRegions() > 1) {
+ return rewriter.notifyMatchFailure(loopOp,
+ "loop op with nested predecessors");
+ }
+ }
+
+ // Only ops breaking out past loopOp lose a region once its body is inlined.
+ loopOp.walk([&](Operation *op) {
+ unsigned numBreaking = op->getNumBreakingControlRegions(), depth = 0;
+ for (Operation *p = op; p != loopOp.getOperation(); p = p->getParentOp()) ++depth;
+ if (numBreaking > depth) op->setNumBreakingControlRegions(numBreaking - 1);
+ });
+
+ Location loc = loopOp.getLoc();
+ Region &bodyRegion = loopOp.getRegion();
+ if (bodyRegion.empty() || bodyRegion.front().empty()) {
+ // Degenerate case: no body. Remove the op before any block is split so
+ // the surrounding IR is not left half-rewritten.
+ rewriter.eraseOp(loopOp);
+ return success();
+ }
+ Block *loopBody = &bodyRegion.front();
+
+ // Split the block containing loopOp into the init block and continuation;
+ // the continuation receives the loop results as block arguments.
+ Block *initBlock = rewriter.getInsertionBlock();
+ auto initPos = rewriter.getInsertionPoint();
+ Block *continueBlock = rewriter.splitBlock(initBlock, initPos);
+ continueBlock->addArguments(
+ loopOp.getResultTypes(),
+ SmallVector<Location>(loopOp.getNumResults(), loc));
+
+ // Move all blocks from the scf.loop region before continueBlock. The former
+ // entry block (loopBody) now acts as the loop header.
+ rewriter.inlineRegionBefore(bodyRegion, continueBlock);
+ // We will remember all break/continue ops to fix up after.
+ SmallVector<Operation *> toErase;
+
+ // break -> branch to continueBlock (loop results); continue -> branch to
+ // loopBody (next iteration's arguments).
+ for (auto predecessor : predecessors) {
+ if (auto breakOp = dyn_cast<scf::BreakOp>(predecessor)) {
+ rewriter.setInsertionPoint(breakOp);
+ cf::BranchOp::create(rewriter, breakOp->getLoc(), continueBlock,
+ ValueRange{breakOp.getOperands()});
+ } else if (auto contOp = dyn_cast<scf::ContinueOp>(predecessor)) {
+ rewriter.setInsertionPoint(contOp);
+ cf::BranchOp::create(rewriter, contOp->getLoc(), loopBody,
+ ValueRange{contOp.getOperands()});
+ }
+ // NOTE(review): any other predecessor kind is erased below without a
+ // replacement branch.
+ toErase.push_back(predecessor);
+ }
+
+ // Erase the old scf.break/scf.continue ops.
+ for (Operation *op : toErase)
+ rewriter.eraseOp(op);
+
+ // The loop region is now a CFG. Jump from initBlock to the loop body with
+ // the init values as remapped by the conversion (the adaptor operands).
+ rewriter.setInsertionPointToEnd(initBlock);
+ cf::BranchOp::create(rewriter, loc, loopBody, adaptor.getOperands());
+
+ // Replace any scf.yield terminating one of the inlined loop blocks with a
+ // branch to the loop header. Only [loopBody, continueBlock) came from the
+ // loop; the other blocks of the enclosing region must not be touched.
+ for (Block &block : llvm::make_range(loopBody->getIterator(),
+ continueBlock->getIterator())) {
+ if (auto yield = dyn_cast<scf::YieldOp>(block.getTerminator())) {
+ rewriter.setInsertionPoint(yield);
+ // For plain scf.yield at the end of the loop (i.e., loop-carried values),
+ // treat as continue.
+ cf::BranchOp::create(rewriter, yield.getLoc(), loopBody,
+ yield.getOperands());
+ rewriter.eraseOp(yield);
+ }
+ }
+
+ // Uses of the loop results now refer to continueBlock's arguments, which
+ // are fed by the break branches.
+ rewriter.replaceOp(loopOp, continueBlock->getArguments());
+ return success();
+}
+
void mlir::populateSCFToControlFlowConversionPatterns(
RewritePatternSet &patterns) {
- patterns.add<ForallLowering, ForLowering, IfLowering, ParallelLowering,
- WhileLowering, ExecuteRegionLowering, IndexSwitchLowering>(
- patterns.getContext());
+ MLIRContext *ctx = patterns.getContext();
+ patterns.add<ExecuteRegionLowering, ForallLowering, ForLowering, IfLowering,
+ IndexSwitchLowering, LoopOpLowering, ParallelLowering, WhileLowering>(ctx);
patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
}
@@ -734,7 +856,8 @@ void SCFToControlFlowPass::runOnOperation() {
// Configure conversion to lower out SCF operations.
ConversionTarget target(getContext());
target.addIllegalOp<scf::ForallOp, scf::ForOp, scf::IfOp, scf::IndexSwitchOp,
- scf::ParallelOp, scf::WhileOp, scf::ExecuteRegionOp>();
+ scf::LoopOp, scf::ParallelOp, scf::WhileOp, // scf.loop presumably handled by LoopOpLowering
+ scf::ExecuteRegionOp>();
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
index 71e3f88a63f34..8f5e37220d5cd 100644
--- a/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
+++ b/mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp
@@ -56,7 +56,7 @@ void mlir::registerConvertSCFToEmitCInterface(DialectRegistry &registry) {
namespace {
-struct SCFToEmitCPass : public impl::SCFToEmitCBase<SCFToEmitCPass> {
+struct SCFToEmitCPass : public ::mlir::impl::SCFToEmitCBase<SCFToEmitCPass> { // NOTE(review): fully qualified, presumably to avoid an ambiguous `impl` lookup
void runOnOperation() override;
};
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
index 0ff9fb3f628ab..da85f7c2d2eaa 100644
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -681,7 +681,8 @@ namespace {
namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
- : public impl::ConvertShapeToStandardPassBase<ConvertShapeToStandardPass> {
+ : public ::mlir::impl::ConvertShapeToStandardPassBase< // NOTE(review): fully qualified, presumably to avoid an ambiguous `impl` lookup
+ ConvertShapeToStandardPass> {
void runOnOperation() override;
};
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index 36a759c279eb7..ab9fb52e47e8b 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -671,10 +671,10 @@ Operation *BufferDeallocation::appendOpResults(Operation *op,
ArrayRef<Type> types) {
SmallVector<Type> newTypes(op->getResultTypes());
newTypes.append(types.begin(), types.end());
- auto *newOp = Operation::create(op->getLoc(), op->getName(), newTypes,
- op->getOperands(), op->getAttrDictionary(),
- op->getPropertiesStorage(),
- op->getSuccessors(), op->getNumRegions());
+ auto *newOp = Operation::create(
+ op->getLoc(), op->getName(), newTypes, op->getOperands(),
+ op->getAttrDictionary(), op->getPropertiesStorage(), op->getSuccessors(),
+ op->getNumRegions(), op->getNumBreakingControlRegions()); // carry over the original op's breaking-region count
for (auto [oldRegion, newRegion] :
llvm::zip(op->getRegions(), newOp->getRegions()))
newRegion.takeBody(oldRegion);
diff --git a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
index cd138401e3177..841e479fe5828 100644
--- a/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/AsyncRegionRewriter.cpp
@@ -111,7 +111,8 @@ struct GpuAsyncRegionPass::ThreadTokenCallback {
auto *newOp = Operation::create(
op->getLoc(), op->getName(), resultTypes, op->getOperands(),
op->getDiscardableAttrDictionary(), op->getPropertiesStorage(),
- op->getSuccessors(), op->getNumRegions());
+ op->getSuccessors(), op->getNumRegions(),
+ op->getNumBreakingControlRegions()); // carry over the original op's breaking-region count
// Clone regions into new op.
IRMapping mapping;
diff --git a/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp b/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp
index b04ccde2e633b..57d918d7f4cef 100644
--- a/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp
+++ b/mlir/lib/Dialect/Quant/Transforms/NormalizeQuantTypes.cpp
@@ -124,7 +124,8 @@ class ConvertGenericOpwithSubChannelType : public ConversionPattern {
auto *newOp = Operation::create(
op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
- op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
+ op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions(),
+ op->getNumBreakingControlRegions()); // carry over the original op's breaking-region count
for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
Region &before = std::get<0>(regions);
Region &parent = std::get<1>(regions);
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 881e256a8797b..c2ccb63f30d3c 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -430,6 +430,151 @@ void ConditionOp::getSuccessorRegions(
regions.emplace_back(whileOp.getOperation(), whileOp.getResults());
}
+//===----------------------------------------------------------------------===//
+// LoopOp
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// Control Flow Op Utilities
+//===----------------------------------------------------------------------===//
+
+template <typename OpT>
+static ParseResult
+parseControlFlowRegion(OpAsmParser &p, Region &region,
+ ArrayRef<OpAsmParser::Argument> arguments = {}) {
+ if (failed(p.parseRegion(region, arguments)))
+ return failure();
+ OpT::ensureTerminator(region, p.getBuilder(), // re-create an elided implicit terminator
+ p.getEncodedSourceLoc(p.getNameLoc()));
+ return success();
+}
+
+template <typename ImplicitTerminatorOpT, typename OpT>
+static void printControlFlowRegion(OpAsmPrinter &p, OpT op, Region &region) {
+ // Elide the entry block's terminator when it is the implicit one w/o operands.
+ Operation *terminator = region.front().getTerminator();
+ bool printBlockTerminators = terminator->getNumOperands() != 0 ||
+ !isa<ImplicitTerminatorOpT>(terminator);
+ p.printRegion(region, /*printEntryBlockArgs=*/false, printBlockTerminators);
+}
+
+LogicalResult ContinueOp::verify() {
+ if ((*this)->getNumBreakingControlRegions() != 0)
+ return success();
+ return emitOpError(
+ "continue op must have at least one breaking control region");
+}
+
+MutableOperandRange
+ContinueOp::getMutableSuccessorOperands(RegionSuccessor point) {
+ return MutableOperandRange(getOperation()); // all operands, whichever successor `point` is
+}
+
+LogicalResult LoopOp::verifyRegions() {
+ // Check matching between the operands and the region arguments.
+ if (getRegion().empty())
+ return emitOpError("region cannot be empty");
+ if (getRegion().front().getNumArguments() != getNumOperands())
+ return emitOpError(
+ "mismatch in number of loop-carried values and defined values");
+ for (auto [index, argAndOperand] : llvm::enumerate(
+ llvm::zip(getRegion().front().getArguments(), getOperands()))) {
+ auto argType = std::get<0>(argAndOperand).getType();
+ auto operandType = std::get<1>(argAndOperand).getType();
+ if (argType != operandType)
+ return emitOpError() << "types mismatch between " << index
+ << "th iter operand (" << operandType
+ << ") and defined region argument (" << argType
+ << ")";
+ }
+ return success();
+}
+
+void LoopOp::print(OpAsmPrinter &p) {
+ // Custom form: `[iter_args(%a = %init, ...) : types] [-> results] region`.
+ p << " ";
+ if (!getInitValues().empty()) {
+ p << "iter_args(";
+ auto printBinding = [&](auto binding) {
+ p << std::get<0>(binding) << " = " << std::get<1>(binding);
+ };
+ llvm::interleaveComma(llvm::zip(getRegionIterValues(), getInitValues()), p,
+ printBinding);
+ p << ") : ";
+ p << getInitValues().getTypes();
+ p << " ";
+ }
+ if (!getResultTypes().empty()) {
+ p << "-> ";
+ p << getResultTypes();
+ p << " ";
+ }
+ // The implicit scf.continue terminator is elided when it has no operands.
+ printControlFlowRegion<ContinueOp>(p, *this, getRegion());
+ p.printOptionalAttrDict((*this)->getAttrs());
+}
+
+ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
+ // Custom form:
+ // [iter_args(%a = %init, ...) : types] [-> result-types] region attr-dict
+ SmallVector<OpAsmParser::Argument, 4> regionArgs;
+ SmallVector<OpAsmParser::UnresolvedOperand, 4> iterOperands;
+ SmallVector<Type, 4> iterTypes;
+
+ bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
+ if (hasIterArgs) {
+ // The assignment list must be followed by a colon and one type per entry.
+ if (parser.parseAssignmentList(regionArgs, iterOperands) ||
+ parser.parseColon() || parser.parseTypeList(iterTypes))
+ return failure();
+ if (regionArgs.size() != iterTypes.size())
+ return parser.emitError(parser.getCurrentLocation(),
+ "found different number of iter_args and types");
+ }
+
+ // The result type list is optional whether or not iter_args are present.
+ if (succeeded(parser.parseOptionalArrow()) &&
+ parser.parseTypeList(result.types))
+ return failure();
+
+ // Each region argument takes the type of its iter_args entry (nothing to do
+ // without iter_args).
+ for (auto [regionArg, type] : llvm::zip_equal(regionArgs, iterTypes))
+ regionArg.type = type;
+
+ // Parse the body region (materializing the implicit terminator if it was
+ // elided) and the optional attribute dictionary.
+ Region *body = result.addRegion();
+ if (parseControlFlowRegion<LoopOp>(parser, *body, regionArgs))
+ return failure();
+ if (parser.parseOptionalAttrDict(result.attributes))
+ return failure();
+
+ // Resolve the init values against the declared iter_args types.
+ return parser.resolveOperands(iterOperands, iterTypes, parser.getNameLoc(),
+ result.operands);
+}
+
+void LoopOp::getSuccessorRegions(RegionBranchPoint point,
+ SmallVectorImpl<RegionSuccessor> &regions) {
+ if (point.isParent()) {
+ regions.push_back(RegionSuccessor(&getRegion())); // NOTE(review): no successor inputs here, unlike the continue edge; confirm against the entry successor operands
+ return;
+ }
+
+ // Otherwise, it depends on the terminator: a continue branches back to the
+ // body and a break to the parent.
+ if (isa<ContinueOp>(point.getTerminatorPredecessorOrNull())) {
+ regions.push_back(
+ RegionSuccessor(&getRegion(), getRegion().getArguments()));
+ return;
+ }
+ assert(isa<BreakOp>(point.getTerminatorPredecessorOrNull()) &&
+ "expected continue or break terminator");
+
+ regions.push_back(RegionSuccessor(getOperation(), getResults()));
+}
+
//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//
@@ -2177,13 +2322,21 @@ IfOp::inferReturnTypes(MLIRContext *ctx, std::optional<Location> loc,
Region *r = &adaptor.getThenRegion();
if (r->empty())
return failure();
- Block &b = r->front();
- if (b.empty())
- return failure();
- auto yieldOp = llvm::dyn_cast<YieldOp>(b.back());
- if (!yieldOp)
+ Block *b = &r->front();
+ if (b->empty())
return failure();
- TypeRange types = yieldOp.getOperandTypes();
+ Operation *terminator = &b->back();
+ if (terminator->getNumBreakingControlRegions() > 1) { // then-branch exits past the scf.if: infer from the else branch
+ if (adaptor.getElseRegion().empty())
+ return success();
+ b = &adaptor.getElseRegion().front();
+ if (b->empty())
+ return success();
+ terminator = &b->back();
+ if (terminator->getNumBreakingControlRegions() > 1) // both branches exit: infer no results
+ return success();
+ }
+ TypeRange types = terminator->getOperandTypes();
llvm::append_range(inferredReturnTypes, types);
return success();
}
@@ -2308,7 +2461,9 @@ ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
}
void IfOp::print(OpAsmPrinter &p) {
- bool printBlockTerminators = false;
+ auto endsInYield = [](Block *block) { return isa<YieldOp>(block->back()); };
+ bool printBlockTerminators =
+ !endsInYield(thenBlock()) || (elseBlock() && !endsInYield(elseBlock()));
p << " " << getCondition();
if (!getResults().empty()) {
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
index 4b131333b956a..1b632e4853d51 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
@@ -85,7 +85,8 @@ class ConvertGenericOpWithIntegerTensorType : public ConversionPattern {
// Create new op with replaced operands and results
auto *newOp = Operation::create(
op->getLoc(), op->getName(), resultTypes, operands, op->getAttrs(),
- op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions());
+ op->getPropertiesStorage(), op->getSuccessors(), op->getNumRegions(),
+ op->getNumBreakingControlRegions()); // carry over the original op's breaking-region count
// Handle regions in e.g. tosa.cond_if and tosa.while_loop
for (auto regions : llvm::zip(op->getRegions(), newOp->getRegions())) {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
index 9b23dd6e4f283..3b7b54e7291f9 100644
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -3722,6 +3722,10 @@ void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
os << ')';
+ // A nonzero breaking-region count is printed after the operands as ` [N]`.
+ if (auto numBreaking = op->getNumBreakingControlRegions())
+ os << " [" << numBreaking << "]";
+
// For terminators, print the list of successors and their operands.
if (op->getNumSuccessors() != 0) {
os << '[';
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
index f4c9242ed3479..879be5a204c5d 100644
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -138,8 +138,8 @@ Diagnostic &Diagnostic::operator<<(Operation &op) {
return appendOp(op, OpPrintingFlags());
}
-Diagnostic &Diagnostic::operator<<(OpWithFlags op) {
- return appendOp(*op.getOperation(), op.flags());
+Diagnostic &Diagnostic::operator<<(const OpWithFlags &opWithFlags) {
+ return appendOp(*opWithFlags.getOperation(), opWithFlags.flags()); // print the wrapped op with its own flags
}
Diagnostic &Diagnostic::appendOp(Operation &op, const OpPrintingFlags &flags) {
diff --git a/mlir/lib/IR/Dominance.cpp b/mlir/lib/IR/Dominance.cpp
index 0e53b431b5d31..32f88cd2c3847 100644
--- a/mlir/lib/IR/Dominance.cpp
+++ b/mlir/lib/IR/Dominance.cpp
@@ -14,8 +14,11 @@
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/RegionKindInterface.h"
+#include "llvm/Support/DebugLog.h"
#include "llvm/Support/GenericDomTreeConstruction.h"
+#define DEBUG_TYPE "dominance"
+
using namespace mlir;
using namespace mlir::detail;
@@ -289,6 +292,26 @@ bool DominanceInfoBase<IsPostDom>::properlyDominatesImpl(
// regions kinds, uses and defs can come in any order inside a block.
if (!hasSSADominance(aBlock))
return true;
+
+ // Scan the ops after the earlier of aIt/bIt up to the later one: if one of them propagates
+ // a break out of a nested op, control may leave the block first, so post-dominance fails.
+ if (IsPostDom && hasBreakingControlFlow(aBlock->getParent())) {
+ bool inRange = false;
+ for (Operation &op : *aBlock) {
+ if (inRange) {
+ if (op.hasTrait<OpTrait::PropagateControlFlowBreak>() &&
+ hasBreakingControlFlowOps(&op)) {
+ LDBG() << "Breaking control flow: "
+ << OpWithFlags(&op, OpPrintingFlags().skipRegions());
+ return false;
+ }
+ if (&op == &*bIt || &op == &*aIt) // NOTE(review): the later endpoint is checked just above, the earlier one never is -- confirm intended
+ break;
+ } else if (&op == &*bIt || &op == &*aIt) {
+ inRange = true;
+ }
+ }
+ }
if constexpr (IsPostDom) {
return isBeforeInBlock(aBlock, bIt, aIt);
} else {
diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp
index 8212d6d3d1eba..6e699eee2dcd4 100644
--- a/mlir/lib/IR/Operation.cpp
+++ b/mlir/lib/IR/Operation.cpp
@@ -13,6 +13,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
@@ -32,10 +33,10 @@ using namespace mlir;
/// Create a new Operation from operation state.
Operation *Operation::create(const OperationState &state) {
- Operation *op =
- create(state.location, state.name, state.types, state.operands,
- state.attributes.getDictionary(state.getContext()),
- state.properties, state.successors, state.regions);
+ Operation *op = create(
+ state.location, state.name, state.types, state.operands,
+ state.attributes.getDictionary(state.getContext()), state.properties,
+ state.successors, state.regions, state.numBreakingControlRegions); // count forwarded from the OperationState
if (LLVM_UNLIKELY(state.propertiesAttr)) {
assert(!state.properties);
LogicalResult result =
@@ -52,11 +53,12 @@ Operation *Operation::create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
NamedAttrList &&attributes,
OpaqueProperties properties, BlockRange successors,
- RegionRange regions) {
+ RegionRange regions,
+ unsigned numBreakingControlRegions) {
unsigned numRegions = regions.size();
Operation *op =
create(location, name, resultTypes, operands, std::move(attributes),
- properties, successors, numRegions);
+ properties, successors, numRegions, numBreakingControlRegions);
for (unsigned i = 0; i < numRegions; ++i)
if (regions[i])
op->getRegion(i).takeBody(*regions[i]);
@@ -68,13 +70,14 @@ Operation *Operation::create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
NamedAttrList &&attributes,
OpaqueProperties properties, BlockRange successors,
- unsigned numRegions) {
+ unsigned numRegions,
+ unsigned numBreakingControlRegions) {
// Populate default attributes.
name.populateDefaultAttrs(attributes);
return create(location, name, resultTypes, operands,
attributes.getDictionary(location.getContext()), properties,
- successors, numRegions);
+ successors, numRegions, numBreakingControlRegions);
}
/// Overload of create that takes an existing DictionaryAttr to avoid
@@ -83,7 +86,8 @@ Operation *Operation::create(Location location, OperationName name,
TypeRange resultTypes, ValueRange operands,
DictionaryAttr attributes,
OpaqueProperties properties, BlockRange successors,
- unsigned numRegions) {
+ unsigned numRegions,
+ unsigned numBreakingControlRegions) {
assert(llvm::all_of(resultTypes, [](Type t) { return t; }) &&
"unexpected null result type");
@@ -93,7 +97,10 @@ Operation *Operation::create(Location location, OperationName name,
unsigned numSuccessors = successors.size();
unsigned numOperands = operands.size();
unsigned numResults = resultTypes.size();
- int opPropertiesAllocSize = llvm::alignTo<8>(name.getOpPropertyByteSize());
+ size_t opPropertiesByteSize = name.getOpPropertyByteSize();
+ if (numBreakingControlRegions)
+ opPropertiesByteSize += 8; // slot for the breaking-region count; NOTE(review): counts against propertiesCapacity
+ int opPropertiesAllocSize = llvm::alignTo<8>(opPropertiesByteSize);
// If the operation is known to have no operands, don't allocate an operand
// storage.
@@ -115,12 +122,16 @@ Operation *Operation::create(Location location, OperationName name,
void *rawMem = mallocMem + prefixByteSize;
// Create the new Operation.
- Operation *op = ::new (rawMem) Operation(
- location, name, numResults, numSuccessors, numRegions,
- opPropertiesAllocSize, attributes, properties, needsOperandStorage);
+ Operation *op = ::new (rawMem)
+ Operation(location, name, numResults, numSuccessors, numRegions,
+ numBreakingControlRegions, opPropertiesAllocSize, attributes,
+ properties, needsOperandStorage);
assert((numSuccessors == 0 || op->mightHaveTrait<OpTrait::IsTerminator>()) &&
"unexpected successors in a non-terminator operation");
+ assert((numBreakingControlRegions == 0 ||
+ op->mightHaveTrait<OpTrait::IsTerminator>()) && // only terminators may break out of regions
+ "unexpected breaking control regions in a non-terminator operation");
// Initialize the results.
auto resultTypeIt = resultTypes.begin();
@@ -154,10 +165,12 @@ Operation *Operation::create(Location location, OperationName name,
Operation::Operation(Location location, OperationName name, unsigned numResults,
unsigned numSuccessors, unsigned numRegions,
+ unsigned numBreakingControlRegions,
int fullPropertiesStorageSize, DictionaryAttr attributes,
OpaqueProperties properties, bool hasOperandStorage)
: location(location), numResults(numResults), numSuccs(numSuccessors),
numRegions(numRegions), hasOperandStorage(hasOperandStorage),
+ isBreakingControlFlowFlag(numBreakingControlRegions != 0), // normalize to 0/1: the flag may be a narrow bitfield (declaration not shown)
propertiesStorageSize((fullPropertiesStorageSize + 7) / 8), name(name) {
assert(attributes && "unexpected null attribute dictionary");
assert(fullPropertiesStorageSize <= propertiesCapacity &&
@@ -170,6 +183,9 @@ Operation::Operation(Location location, OperationName name, unsigned numResults,
"allowUnregisteredDialects() on the MLIRContext, or use "
"-allow-unregistered-dialect with the MLIR tool used.");
#endif
+ if (numBreakingControlRegions) // NOTE(review): count lives in the first properties slot; assumes getPropertiesStorage() skips it
+ *reinterpret_cast<unsigned *>(getTrailingObjects<detail::OpProperties>()) =
+ numBreakingControlRegions;
if (fullPropertiesStorageSize)
name.initOpProperties(getPropertiesStorage(), properties);
}
@@ -732,7 +748,8 @@ Operation *Operation::clone(IRMapping &mapper, CloneOptions options) {
// Create the new operation.
auto *newOp = create(getLoc(), getName(), getResultTypes(), operands, attrs,
- getPropertiesStorage(), successors, getNumRegions());
+ getPropertiesStorage(), successors, getNumRegions(),
+ getNumBreakingControlRegions()); // NOTE(review): copied verbatim; a clone placed at a different nesting depth would need it adjusted
mapper.map(this, newOp);
// Clone the regions.
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index 9332f55bd9393..dd798078c1d65 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -10,6 +10,9 @@
#include "mlir/IR/Iterators.h"
#include "mlir/IR/RegionKindInterface.h"
#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "pattern-match"
using namespace mlir;
@@ -225,7 +228,8 @@ void RewriterBase::eraseOp(Operation *op) {
// Then erase the enclosing op.
eraseSingleOp(op);
};
-
+ LDBG() << "RewriterBase::eraseOp: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
eraseTree(op);
}
diff --git a/mlir/lib/IR/RegionKindInterface.cpp b/mlir/lib/IR/RegionKindInterface.cpp
index 007f4cf92dbc7..a7ff4c2a211d9 100644
--- a/mlir/lib/IR/RegionKindInterface.cpp
+++ b/mlir/lib/IR/RegionKindInterface.cpp
@@ -12,6 +12,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/RegionKindInterface.h"
+#include "mlir/Support/WalkResult.h"
+
+#include "llvm/Support/DebugLog.h"
+
+#define DEBUG_TYPE "region-kind-interface"
using namespace mlir;
@@ -32,3 +37,121 @@ bool mlir::mayBeGraphRegion(Region &region) {
return false;
return !regionKindOp.hasSSADominance(region.getRegionNumber());
}
+
+namespace {
+// Iterator over every operation of a region, block by block (no reachability
+// analysis is performed). `nestedLevel` records how many regions deep the
+// region is relative to the walk's root op.
+struct NestedOpIterator {
+ NestedOpIterator(Region *region, int nestedLevel)
+ : region(region), nestedLevel(nestedLevel) {
+ regionIt = region->begin();
+ // Only touch the block when there is one: `end()` must not be dereferenced.
+ if (regionIt != region->end())
+ blockIt = regionIt->begin();
+ }
+ // Step to the next operation; at the end of a block, move to the next block.
+ void advance() {
+ assert(regionIt != region->end());
+ if (blockIt == regionIt->end()) {
+ ++regionIt;
+ if (regionIt != region->end())
+ blockIt = regionIt->begin();
+ return;
+ }
+ ++blockIt;
+ if (blockIt != regionIt->end()) {
+ LDBG() << this << " - Incrementing block iterator, next op: "
+ << OpWithFlags(&*blockIt, OpPrintingFlags().skipRegions());
+ }
+ }
+
+ // The region we're iterating over.
+ Region *region;
+ // The Block currently being iterated over.
+ Region::iterator regionIt;
+ // The Operation currently being iterated over.
+ Block::iterator blockIt;
+ // Nesting depth of this region below the walk's root op (root regions: 1).
+ int nestedLevel = 0;
+};
+} // namespace
+
+static void walk(Operation *rootOp,
+ function_ref<WalkResult(Operation *, int)> callback) {
+ // Worklist of regions to visit to drive the traversal.
+ SmallVector<NestedOpIterator> worklist;
+
+ // Depth-first traversal over every operation nested under rootOp (rootOp
+ // itself is not visited); the callback also receives the nesting level.
+ for (Region &region : rootOp->getRegions()) {
+ if (region.empty())
+ continue;
+ worklist.push_back({&region, 1});
+ }
+ while (!worklist.empty()) {
+ NestedOpIterator &it = worklist.back();
+ if (it.regionIt == it.region->end()) {
+ // We're done with this region.
+ worklist.pop_back();
+ continue;
+ }
+ if (it.blockIt == it.regionIt->end()) {
+ // We're done with this block.
+ it.advance();
+ continue;
+ }
+ Operation *op = &*it.blockIt;
+ LDBG() << "Visiting op: "
+ << OpWithFlags(op, OpPrintingFlags().skipRegions());
+
+ if (callback(op, it.nestedLevel).wasInterrupted())
+ return;
+
+ // Advance before pushing nested regions to avoid reference invalidation.
+ int currentNestedLevel = it.nestedLevel;
+ it.advance();
+
+ // Recursively visit the nested regions.
+ for (Region &nestedRegion : op->getRegions()) {
+ if (nestedRegion.empty())
+ continue;
+ worklist.push_back({&nestedRegion, currentNestedLevel + 1});
+ }
+ }
+}
+
+/// True if a nested op breaks out to exactly `op` (break count == depth).
+bool mlir::hasNestedPredecessors(Operation *op) {
+ bool found = false;
+ walk(op, [&](Operation *visitedOp, int nestedLevel) {
+ int breakDepth = static_cast<int>(visitedOp->getNumBreakingControlRegions());
+ found = breakDepth == nestedLevel;
+ return found ? WalkResult::interrupt() : WalkResult::advance();
+ });
+ return found;
+}
+
+bool mlir::hasBreakingControlFlowOps(Operation *op) {
+ // True if a nested op breaks out of more regions than it is deep, i.e. its
+ // control flow escapes `op` itself.
+ bool found = false;
+ walk(op, [&](Operation *visitedOp, int nestedLevel) {
+ found = static_cast<int>(visitedOp->getNumBreakingControlRegions()) > nestedLevel;
+ return found ? WalkResult::interrupt() : WalkResult::advance();
+ });
+ return found;
+}
+
+void mlir::collectAllNestedPredecessors(
+ Operation *op, SmallVector<Operation *> &predecessors) {
+ walk(op, [&](Operation *visitedOp, int nestedLevel) {
+ LDBG() << "Visiting op: " << OpWithFlags(visitedOp, OpPrintingFlags().skipRegions())
+ << " at nested level " << nestedLevel;
+ // A predecessor of `op` breaks out of exactly as many regions as it is deep.
+ int breakDepth = static_cast<int>(visitedOp->getNumBreakingControlRegions());
+ if (breakDepth == nestedLevel)
+ predecessors.push_back(visitedOp);
+ return WalkResult::advance();
+ });
+}
\ No newline at end of file
diff --git a/mlir/lib/IR/Verifier.cpp b/mlir/lib/IR/Verifier.cpp
index 3ced663a87be1..62e6720a66215 100644
--- a/mlir/lib/IR/Verifier.cpp
+++ b/mlir/lib/IR/Verifier.cpp
@@ -29,6 +29,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Operation.h"
+#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/IR/Threading.h"
#include "llvm/ADT/PointerIntPair.h"
@@ -129,6 +130,42 @@ LogicalResult OperationVerifier::verifyOnEntrance(Block &block) {
if (op.getNumSuccessors() != 0 && &op != &block.back())
return op.emitError(
"operation with block successors must terminate its parent block");
+
+ // An op breaking out of N regions: its first N-1 ancestors must propagate
+ // the break and the N-th must accept `op` as a terminator.
+ unsigned numBreaking = op.getNumBreakingControlRegions();
+ Operation *currentOp = &op;
+ for (unsigned i : llvm::seq<unsigned>(0, numBreaking)) {
+ currentOp = currentOp->getParentOp();
+ if (!currentOp)
+ return op.emitError("operation with breaking control regions "
+ "exceeding the number of enclosing parent ops");
+ if (i + 1 == numBreaking) {
+ auto successorOp =
+ dyn_cast<HasBreakingControlFlowOpInterface>(currentOp);
+ if (!successorOp)
+ return currentOp
+ ->emitError(
+ "operation has a nested predecessor but does not "
+ "have "
+ "the HasBreakingControlFlowOpInterface interface.")
+ .attachNote(op.getLoc())
+ << " for this predecessor operation (" << op.getName()
+ << ")";
+
+ if (!successorOp.acceptsTerminator(&op))
+ return currentOp->emitError(
+ "operation with breaking control regions "
+ "does not accept terminator: ")
+ << OpWithFlags(&op, OpPrintingFlags().skipRegions());
+ } else {
+ if (!currentOp->hasTrait<OpTrait::PropagateControlFlowBreak>())
+ return op.emitError("breaking control regions through an op that "
+ "does not have "
+ "the PropagateControlFlowBreak trait: ")
+ << OpWithFlags(currentOp, OpPrintingFlags().skipRegions());
+ }
+ }
}
return success();
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 1e56810ff7aaf..2e2ae2a6ffaac 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -72,14 +72,14 @@ detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
LogicalResult
detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
const SuccessorOperands &operands) {
- LDBG() << "Verifying branch successor operands for successor #" << succNo
- << " in operation " << op->getName();
+ LDBG(3) << "Verifying branch successor operands for successor #" << succNo // level 3: these logs run for every branch successor being verified
+ << " in operation " << op->getName();
// Check the count.
unsigned operandCount = operands.size();
Block *destBB = op->getSuccessor(succNo);
- LDBG() << "Branch has " << operandCount << " operands, target block has "
- << destBB->getNumArguments() << " arguments";
+ LDBG(3) << "Branch has " << operandCount << " operands, target block has "
+ << destBB->getNumArguments() << " arguments";
if (operandCount != destBB->getNumArguments())
return op->emitError() << "branch has " << operandCount
@@ -88,22 +88,22 @@ detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
<< destBB->getNumArguments();
// Check the types.
- LDBG() << "Checking type compatibility for "
- << (operandCount - operands.getProducedOperandCount())
- << " forwarded operands";
+ LDBG(3) << "Checking type compatibility for "
+ << (operandCount - operands.getProducedOperandCount())
+ << " forwarded operands";
for (unsigned i = operands.getProducedOperandCount(); i != operandCount;
++i) {
Type operandType = operands[i].getType();
Type argType = destBB->getArgument(i).getType();
- LDBG() << "Checking type compatibility: operand type " << operandType
- << " vs argument type " << argType;
+ LDBG(3) << "Checking type compatibility: operand type " << operandType
+ << " vs argument type " << argType;
if (!cast<BranchOpInterface>(op).areTypesCompatible(operandType, argType))
return op->emitError() << "type mismatch for bb argument #" << i
<< " of successor #" << succNo;
}
- LDBG() << "Branch successor operand verification successful";
+ LDBG(3) << "Branch successor operand verification successful";
return success();
}
@@ -234,12 +234,15 @@ LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
for (Region ®ion : op->getRegions()) {
// Collect all return-like terminators in the region.
SmallVector<RegionBranchTerminatorOpInterface> regionReturnOps;
- for (Block &block : region)
- if (!block.empty())
- if (auto terminator =
- dyn_cast<RegionBranchTerminatorOpInterface>(block.back()))
+ for (Block &block : region) {
+ if (!block.empty()) {
+ Operation *op = &block.back();
+ if (op->getNumBreakingControlRegions() > 1)
+ continue;
+ if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op))
regionReturnOps.push_back(terminator);
-
+ }
+ }
// If there is no return-like terminator, the op itself should verify
// type consistency.
if (regionReturnOps.empty())
diff --git a/mlir/lib/Transforms/Utils/CMakeLists.txt b/mlir/lib/Transforms/Utils/CMakeLists.txt
index 3ca16239ba33c..93dd73102ad52 100644
--- a/mlir/lib/Transforms/Utils/CMakeLists.txt
+++ b/mlir/lib/Transforms/Utils/CMakeLists.txt
@@ -14,6 +14,9 @@ add_mlir_library(MLIRTransformUtils
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Transforms
+ DEPENDS
+ MLIRRegionKindInterfaceIncGen
+
LINK_LIBS PUBLIC
MLIRAnalysis
MLIRCallInterfaces
diff --git a/mlir/lib/Transforms/Utils/InliningUtils.cpp b/mlir/lib/Transforms/Utils/InliningUtils.cpp
index 73107cfc36ea9..11e23ae45d2cf 100644
--- a/mlir/lib/Transforms/Utils/InliningUtils.cpp
+++ b/mlir/lib/Transforms/Utils/InliningUtils.cpp
@@ -286,6 +286,13 @@ static LogicalResult inlineRegionImpl(
[&](BlockArgument arg) { return !mapper.contains(arg); }))
return failure();
+ // Inlining does not support breaking control flow yet: bail out when the
+ // parent op has nested predecessors, e.g. a nested return inside a function.
+ if (auto opWithBreakingControlPredecessor =
+ dyn_cast<HasBreakingControlFlowOpInterface>(src->getParentOp()))
+ if (opWithBreakingControlPredecessor.hasNestedPredecessors())
+ return failure();
+
// Check that the operations within the source region are valid to inline.
Region *insertRegion = inlineBlock->getParent();
if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion,
diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-early-exit-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-early-exit-to-cfg.mlir
new file mode 100644
index 0000000000000..ad2b0cee6add5
--- /dev/null
+++ b/mlir/test/Conversion/SCFToControlFlow/convert-early-exit-to-cfg.mlir
@@ -0,0 +1,47 @@
+// RUN: mlir-opt -convert-scf-to-cf -split-input-file %s | FileCheck %s
+
+
+// An scf.if nested in an scf.loop leaves the loop early with `scf.break 2`
+// (parent #1 is the scf.if, parent #2 is the scf.loop). The check lines below
+// only cover the lowering up to the entry block of the scf.if.
+func.func @loop_break(%cond : i1) {
+ // CHECK: test.op1
+ "test.op1"() : () -> ()
+ // CHECK-NEXT: cf.br [[LOOP1_ENTRY:.*]]
+ // CHECK-NEXT: [[LOOP1_ENTRY]]
+ scf.loop {
+ // CHECK-NEXT: test.op2
+ "test.op2"() : () -> ()
+ // CHECK-NEXT: cf.cond_br %arg0, [[IF_ENTRY:.*]], [[IF_CONTINUE:.*]]
+ // CHECK-NEXT: [[IF_ENTRY]]
+ scf.if %cond {
+ "test.op3"() : () -> ()
+ scf.break 2 loc("break1")
+ }
+ "test.op3"() : () -> ()
+ } loc("loop1")
+ "test.op4"() : () -> ()
+ return
+}
+
+// -----
+
+// Nested scf.loop ops: `scf.continue 2` (parents: scf.if, inner loop) continues
+// the inner loop "loop3", while `scf.break 3` (parents: scf.if, loop3, loop2)
+// leaves the outer loop "loop2".
+// NOTE(review): this case has no check lines, so it only exercises the
+// conversion without verifying its output.
+func.func @loop_continue(%cond1 : i1, %cond2 : i1) {
+ "test.op1"() : () -> ()
+ scf.loop {
+ "test.op2"() : () -> ()
+ scf.loop {
+ "test.op3"() : () -> ()
+ scf.if %cond1 {
+ "test.op4"() : () -> ()
+ scf.continue 2 loc("continue1")
+ }
+ "test.op5"() : () -> ()
+ scf.if %cond2 {
+ "test.op6"() : () -> ()
+ scf.break 3 loc("break2")
+ }
+ "test.op7"() : () -> ()
+ } loc("loop3")
+ "test.op8"() : () -> ()
+ } loc("loop2")
+ "test.op9"() : () -> ()
+ return
+}
diff --git a/mlir/test/IR/early-exit-invalid.mlir b/mlir/test/IR/early-exit-invalid.mlir
new file mode 100644
index 0000000000000..c3817bf8e5b1d
--- /dev/null
+++ b/mlir/test/IR/early-exit-invalid.mlir
@@ -0,0 +1,65 @@
+
+// RUN: mlir-opt %s --split-input-file --verify-diagnostics
+
+
+// `scf.continue 2` placed directly in the loop body targets the enclosing
+// func.func, which does not implement HasBreakingControlFlowOpInterface.
+// expected-error @+1 {{operation has a nested predessor but does not have the HasBreakingControlFlowOpInterface trait}}
+ func.func @loop_continue() {
+ scf.loop {
+// expected-note @+1 {{for this predecessor operation (scf.continue)}}
+ scf.continue 2
+ } loc("loop1")
+ return
+}
+
+// -----
+
+// The f32 operand forwarded by `scf.break 1` does not match the i32 loop result.
+func.func @loop_result_mismatch(%value : f32) {
+ // expected-error @+1 {{'scf.loop' op along control flow edge from Operation scf.break to parent results: source type #0 'f32' should match input type #0 'i32'}}
+ %result = scf.loop -> i32 {
+ scf.break 1 %value : f32
+ }
+ return
+}
+
+// -----
+
+// `scf.break 1` forwards a single value but the loop declares two results.
+func.func @loop_result_number_mismatch(%value : f32) {
+ // expected-error @+1 {{'scf.loop' op region control flow edge from Operation scf.break to parent results: source has 1 operands, but target successor <to parent> needs 2}}
+ %result:2 = scf.loop -> f32, f32 {
+ scf.break 1 %value : f32
+ }
+ return
+}
+
+// -----
+
+// `scf.continue 1` forwards an f32 while the loop's iter_arg is an i32.
+func.func @loop_continue_mismatch(%init : i32, %value : f32) {
+ // expected-error @+1 {{'scf.loop' op along control flow edge from Operation scf.continue to Region #0: source type #0 'f32' should match input type #0 'i32'}}
+ scf.loop iter_args(%next = %init) : i32 {
+ scf.continue 1 %value : f32
+ }
+ return
+}
+
+
+// -----
+
+// Uses the generic form to give the body block argument (f32) a type that
+// differs from the i32 init operand; `scf.continue 1 %init : i32` then does
+// not match that block argument.
+func.func @loop_iterargs_mismatch(%init : i32, %value : f32) {
+ // expected-error @+1 {{'scf.loop' op along control flow edge from Operation scf.continue to Region #0: source type #0 'i32' should match input type #0 'f32'}}
+ "scf.loop"(%init) ({
+ ^body(%next : f32):
+ scf.continue 1 %init : i32
+ }) : (i32) -> ()
+ return
+}
+
+// -----
+
+// The body block takes two arguments but `scf.continue 1` forwards only one.
+// NOTE(review): this reuses the name @loop_iterargs_mismatch from the previous
+// case. That is legal because --split-input-file parses each case separately,
+// but a distinct name would make failures easier to locate.
+func.func @loop_iterargs_mismatch(%init : i32, %value : f32) {
+ // expected-error @+1 {{'scf.loop' op region control flow edge from Operation scf.continue to Region #0: source has 1 operands, but target successor <to region #0 with 2 inputs> needs 2}}
+ "scf.loop"(%init) ({
+ ^body(%next : i32, %next2 : f32):
+ scf.continue 1 %init : i32
+ }) : (i32) -> ()
+ return
+}
diff --git a/mlir/test/IR/early-exit.mlir b/mlir/test/IR/early-exit.mlir
new file mode 100644
index 0000000000000..ea83bc194ec3f
--- /dev/null
+++ b/mlir/test/IR/early-exit.mlir
@@ -0,0 +1,64 @@
+// RUN: mlir-opt --print-region-branch-op-interface %s --split-input-file | FileCheck %s
+// RUN: mlir-opt %s --mlir-print-debuginfo --mlir-print-op-generic --split-input-file | mlir-opt --print-region-branch-op-interface --split-input-file | FileCheck %s
+
+
+// The only explicit exit is `scf.break 2` inside the scf.if. The second
+// predecessor listed below (`scf.continue 1`) has no counterpart in the source;
+// presumably it is the terminator implicitly added to the loop body (confirm).
+func.func @loop_break(%cond : i1) {
+ // CHECK: Found RegionBranchOpInterface operation: scf.loop {...} loc("loop1")
+ // CHECK: - Successor is region #0
+ // CHECK: - Found 2 predecessor(s)
+ // CHECK: - Predecessor is scf.break 2 loc("break1")
+ // CHECK: - Predecessor is scf.continue 1
+ scf.loop {
+ scf.if %cond {
+ scf.break 2 loc("break1")
+ }
+ } loc("loop1")
+ return
+}
+
+// -----
+
+// `scf.continue 2` (parents: scf.if, loop3) targets the inner loop "loop3",
+// while `scf.break 3` (parents: scf.if, loop3, loop2) targets the outer loop
+// "loop2"; each loop therefore reports a different nested predecessor.
+func.func @loop_continue(%cond1 : i1, %cond2 : i1) {
+ // CHECK: Found RegionBranchOpInterface operation: scf.loop {...} loc("loop2")
+ // CHECK: - Successor is region #0
+ // CHECK: - Found 2 predecessor(s)
+ // CHECK: - Predecessor is scf.break 3 loc("break2")
+ // CHECK: - Predecessor is scf.continue 1
+ scf.loop {
+ // CHECK: Found RegionBranchOpInterface operation: scf.loop {...} loc("loop3")
+ // CHECK: - Successor is region #0
+ // CHECK: - Found 2 predecessor(s)
+ // CHECK: - Predecessor is scf.continue 2 loc("continue1")
+ // CHECK: - Predecessor is scf.continue 1
+ scf.loop {
+ scf.if %cond1 {
+ scf.continue 2 loc("continue1")
+ }
+ scf.if %cond2 {
+ scf.break 3 loc("break2")
+ }
+ } loc("loop3")
+ } loc("loop2")
+ return
+}
+
+// -----
+
+// A loop whose result is produced by `scf.break 1`. The label check only
+// verifies that the function survives both run lines.
+// CHECK-LABEL: func @loop_with_results(
+func.func @loop_with_results(%value : f32) -> f32 {
+ %result = scf.loop -> f32 {
+ scf.break 1 %value : f32
+ }
+ return %result : f32
+}
+
+// -----
+
+// A loop with an i32 iter_arg that `scf.continue 1` forwards unchanged.
+// CHECK-LABEL: func @loop_continue_iterargs(
+func.func @loop_continue_iterargs(%init : i32) {
+ scf.loop iter_args(%next = %init) : i32 {
+ scf.continue 1 %next : i32
+ }
+ return
+}
+
diff --git a/mlir/test/Integration/Dialect/SCF/early_exit.mlir b/mlir/test/Integration/Dialect/SCF/early_exit.mlir
new file mode 100644
index 0000000000000..974ad681e3200
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SCF/early_exit.mlir
@@ -0,0 +1,82 @@
+// RUN: mlir-opt %s -convert-scf-to-cf --canonicalize --convert-cf-to-llvm --convert-to-llvm | \
+// RUN: mlir-runner -e entry -entry-point-result=void \
+// RUN: -shared-libs=%mlir_c_runner_utils | \
+// RUN: FileCheck %s
+
+
+
+// End-to-end test of early exits (scf.break / scf.continue) out of nested
+// scf.loop ops, lowered through the CF and LLVM dialects and executed with
+// mlir-runner.
+module {
+ llvm.func @entry() {
+ // Constants for the counter increment, the exit conditions and the initial
+ // counter value.
+ %one = llvm.mlir.constant(1 : i64) : i64
+ %two = llvm.mlir.constant(2 : i64) : i64
+ %three = llvm.mlir.constant(3 : i64) : i64
+ %four = llvm.mlir.constant(4 : i64) : i64
+ %counter_init = llvm.mlir.constant(0 : i64) : i64
+
+
+// CHECK: Outer Loop Begin with counter: 0
+// CHECK-NEXT: Inner Loop Begin, counter: 1
+// CHECK-NEXT: continue inner loop
+// CHECK-NEXT: Inner Loop Begin, counter: 2
+// CHECK-NEXT: Iteration 2, loop back to outer loop
+// CHECK-NEXT: Outer Loop Begin with counter: 2
+// CHECK-NEXT: Inner Loop Begin, counter: 3
+// CHECK-NEXT: continue inner loop
+// CHECK-NEXT: Inner Loop Begin, counter: 4
+// CHECK-NEXT: continue inner loop
+// CHECK-NEXT: Inner Loop Begin, counter: 5
+// CHECK-NEXT: Last iteration, break out of outer loop
+// CHECK-NEXT: Outer loop finished with result: 4
+
+
+ %result = scf.loop iter_args(%counter_out = %counter_init) : i64 -> i64 {
+ // Outer loop iteration
+ vector.print str "Outer Loop Begin with counter: "
+ vector.print %counter_out : i64
+
+ scf.loop iter_args(%counter = %counter_out) : i64 {
+ // %counter will go from 0 to 4
+ // %counter_update will go from 1 to 5
+ %counter_update = llvm.add %counter, %one : i64
+
+ // Inner loop iteration
+ // print from 1..5
+ vector.print str "Inner Loop Begin, counter: "
+ vector.print %counter_update : i64
+
+ // When the updated counter equals 2, continue the outer loop (parent #3:
+ // scf.if, inner loop, outer loop) with the updated counter.
+ %cond1 = llvm.icmp "eq" %counter_update, %two : i64
+ scf.if %cond1 {
+ vector.print str "Iteration 2, loop back to outer loop\n"
+ scf.continue 3 %counter_update : i64
+ }
+
+ // Exit condition when counter >= 4: break out of the outer loop, which
+ // then yields %counter as %result.
+ %cond2 = llvm.icmp "sge" %counter, %four : i64
+ scf.if %cond2 {
+ vector.print str "Last iteration, break out of outer loop\n"
+ // return the counter from the previous iteration here (pre-update)
+ scf.break 3 %counter : i64
+ }
+
+ // NOTE(review): %cond3 is unused, and the scf.if below tests %cond2, so
+ // its branch can never be taken: whenever %cond2 holds, the scf.break 3
+ // above has already left both loops. As a result the inner loop never
+ // falls through, and "continue outer loop" is never printed. Switching
+ // the condition to %cond3 would restart the outer loop with an unchanged
+ // counter indefinitely, so confirm the intent before changing it.
+ %cond3 = llvm.icmp "eq" %counter_update, %three : i64
+ scf.if %cond2 {
+ vector.print str "Iteration 3, break out of inner loop"
+ scf.break 2
+ }
+ vector.print str "continue inner loop\n"
+ scf.continue 1 %counter_update : i64
+ }
+ vector.print str "continue outer loop\n"
+ scf.continue 1 %counter_out : i64
+ }
+
+// After the loop nest finishes
+ vector.print str "Outer loop finished with result: "
+ vector.print %result : i64
+
+ llvm.return
+ }
+}
diff --git a/mlir/test/lib/Interfaces/CMakeLists.txt b/mlir/test/lib/Interfaces/CMakeLists.txt
index 6a21ed10eec6f..3aa5097b7ed20 100644
--- a/mlir/test/lib/Interfaces/CMakeLists.txt
+++ b/mlir/test/lib/Interfaces/CMakeLists.txt
@@ -1,2 +1,3 @@
add_subdirectory(LoopLikeInterface)
+add_subdirectory(RegionBranchOpInterface)
add_subdirectory(TilingInterface)
diff --git a/mlir/test/lib/Interfaces/RegionBranchOpInterface/CMakeLists.txt b/mlir/test/lib/Interfaces/RegionBranchOpInterface/CMakeLists.txt
new file mode 100644
index 0000000000000..8e003942e41c0
--- /dev/null
+++ b/mlir/test/lib/Interfaces/RegionBranchOpInterface/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Test library providing the `print-region-branch-op-interface` pass; it is
+# linked into mlir-opt.
+# NOTE(review): TestRegionBranchOpInterface.cpp includes
+# mlir/Dialect/Func/IR/FuncOps.h, but this library does not depend on
+# MLIRFuncDialect; confirm that include is actually needed.
+add_mlir_library(MLIRTestRegionBranchOpInterface
+ TestRegionBranchOpInterface.cpp
+
+ EXCLUDE_FROM_LIBMLIR
+ )
+mlir_target_link_libraries(MLIRTestRegionBranchOpInterface PUBLIC
+ MLIRControlFlowInterfaces
+ MLIRPass
+ )
diff --git a/mlir/test/lib/Interfaces/RegionBranchOpInterface/TestRegionBranchOpInterface.cpp b/mlir/test/lib/Interfaces/RegionBranchOpInterface/TestRegionBranchOpInterface.cpp
new file mode 100644
index 0000000000000..a1910465d93eb
--- /dev/null
+++ b/mlir/test/lib/Interfaces/RegionBranchOpInterface/TestRegionBranchOpInterface.cpp
@@ -0,0 +1,76 @@
+//===- TestRegionBranchOpInterface.cpp - Test RegionBranchOpInterface -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+namespace {
+/// Test pass that prints the control-flow edges described by
+/// RegionBranchOpInterface. For every implementing op, visited in pre-order,
+/// it prints the op itself, the successors reachable when control enters the
+/// op from its parent, and, if the op also implements
+/// HasBreakingControlFlowOpInterface, the nested predecessor operations
+/// returned by collectAllNestedPredecessors.
+struct PrintRegionBranchOpInterfacePass
+    : public PassWrapper<PrintRegionBranchOpInterfacePass, OperationPass<>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintRegionBranchOpInterfacePass)
+
+  StringRef getArgument() const final {
+    return "print-region-branch-op-interface";
+  }
+  StringRef getDescription() const final {
+    return "Print control-flow edges represented by "
+           "mlir::RegionBranchOpInterface";
+  }
+
+  void runOnOperation() override {
+    // Print ops without their regions but with their locations, so each op
+    // fits on one line and can be identified by its loc.
+    OpPrintingFlags flags = OpPrintingFlags().skipRegions().enableDebugInfo();
+
+    // Pre-order: an enclosing op is reported before the ops nested in it.
+    getOperation()->walk<WalkOrder::PreOrder>(
+        [&](RegionBranchOpInterface branchOp) {
+          llvm::outs() << "Found RegionBranchOpInterface operation: "
+                       << OpWithFlags(branchOp, flags) << "\n";
+
+          // Successors reachable when control enters the op from its parent.
+          SmallVector<RegionSuccessor> regions;
+          branchOp.getSuccessorRegions(RegionBranchPoint::parent(), regions);
+          for (auto &successor : regions) {
+            if (successor.isParent()) {
+              llvm::outs() << " - Successor is parent\n";
+            } else {
+              llvm::outs() << " - Successor is region #"
+                           << successor.getSuccessor()->getRegionNumber()
+                           << "\n";
+            }
+          }
+
+          // Nested operations that transfer control to this op through
+          // breaking control flow; only available on ops implementing
+          // HasBreakingControlFlowOpInterface.
+          auto breakingControlFlowOp =
+              dyn_cast<HasBreakingControlFlowOpInterface>(
+                  branchOp.getOperation());
+          if (!breakingControlFlowOp)
+            return;
+
+          SmallVector<Operation *> predecessors;
+          llvm::outs() << " - Collecting all nested predecessors\n";
+          collectAllNestedPredecessors(breakingControlFlowOp, predecessors);
+          llvm::outs() << " - Found " << predecessors.size()
+                       << " predecessor(s)\n";
+          for (Operation *predecessor : predecessors)
+            llvm::outs() << " - Predecessor is "
+                         << OpWithFlags(predecessor, flags) << "\n";
+        });
+  }
+};
+
+} // namespace
+
+namespace mlir {
+/// Registers the `print-region-branch-op-interface` test pass. It is invoked
+/// from mlir-opt's registerTestPasses().
+void registerRegionBranchOpInterfaceTestPasses() {
+  // Constructing the registration object is what registers the pass.
+  PassRegistration<PrintRegionBranchOpInterfacePass> registration;
+  (void)registration;
+}
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
index c607ccfa80e3c..821bad3a4166a 100644
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -22,6 +22,7 @@ if(MLIR_INCLUDE_TESTS)
MLIRGPUTestPasses
MLIRLinalgTestPasses
MLIRLoopLikeInterfaceTestPasses
+ MLIRTestRegionBranchOpInterface
MLIRMathTestPasses
MLIRTestMathToVCIX
MLIRMemRefTestPasses
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index ac739be8c5cb5..f79f8997f572c 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -38,6 +38,7 @@ void registerLazyLoadingTestPasses();
void registerLoopLikeInterfaceTestPasses();
void registerPassManagerTestPass();
void registerPrintSpirvAvailabilityPass();
+void registerRegionBranchOpInterfaceTestPasses();
void registerRegionTestPasses();
void registerPrintTosaAvailabilityPass();
void registerShapeFunctionTestPasses();
@@ -191,6 +192,7 @@ void registerTestPasses() {
registerPassManagerTestPass();
registerPrintSpirvAvailabilityPass();
registerRegionTestPasses();
+ registerRegionBranchOpInterfaceTestPasses();
registerShapeFunctionTestPasses();
registerSideEffectTestPasses();
registerSliceAnalysisTestPass();
diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp
index 04d3ed1f3b70d..2611922ae6d52 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/FormatGen.cpp
@@ -191,6 +191,7 @@ FormatToken FormatLexer::lexIdentifier(const char *tokStart) {
.Case("successors", FormatToken::kw_successors)
.Case("type", FormatToken::kw_type)
.Case("qualified", FormatToken::kw_qualified)
+ .Case("num-breaking-regions", FormatToken::kw_num_breaking_regions)
.Default(FormatToken::identifier);
return FormatToken(kind, str);
}
diff --git a/mlir/tools/mlir-tblgen/FormatGen.h b/mlir/tools/mlir-tblgen/FormatGen.h
index 8e7d49bb37e71..0b44ab69dcd1f 100644
--- a/mlir/tools/mlir-tblgen/FormatGen.h
+++ b/mlir/tools/mlir-tblgen/FormatGen.h
@@ -72,6 +72,7 @@ class FormatToken {
kw_struct,
kw_successors,
kw_type,
+ kw_num_breaking_regions,
keyword_end,
// String valued tokens.
@@ -305,6 +306,7 @@ class DirectiveElement : public FormatElementBase<FormatElement::Directive> {
Results,
Successors,
Type,
+ NumBreakingRegions,
Params,
Struct
};
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index ccf21d16005af..635dabd2d66b9 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -222,6 +222,14 @@ class TypeDirective : public DirectiveElementBase<DirectiveElement::Type> {
bool shouldBeQualifiedFlag = false;
};
+/// Format element for the `num-breaking-regions` directive. The generated
+/// parser reads an integer and records it with
+/// `setNumBreakingControlRegions`; the generated printer emits the operation's
+/// `getNumBreakingControlRegions()`. The element itself carries no state.
+struct NumBreakingRegionsDirective
+    : public DirectiveElementBase<DirectiveElement::NumBreakingRegions> {};
+
/// This class represents a group of order-independent optional clauses. Each
/// clause starts with a literal element and has a coressponding parsing
/// element. A parsing element is a continous sequence of format elements.
@@ -1691,6 +1699,16 @@ void OperationFormat::genElementParser(FormatElement *element, MethodBody &body,
body << formatv(functionalTypeParserCode,
getTypeListName(dir->getInputs(), ignored),
getTypeListName(dir->getResults(), ignored));
+ } else if (isa<NumBreakingRegionsDirective>(element)) {
+ body.indent() << "{\n";
+ body.indent()
+ << "auto loc = parser.getCurrentLocation();(void)loc;\n"
+ << "int32_t numBreakingRegions = 0;\n"
+ << "if (parser.parseInteger(numBreakingRegions))\n"
+ << " return ::mlir::failure();\n"
+ << "result.setNumBreakingControlRegions(numBreakingRegions);\n";
+ body.unindent() << "}\n";
+ body.unindent();
} else {
llvm_unreachable("unknown format element");
}
@@ -2547,6 +2565,13 @@ void OperationFormat::genElementPrinter(FormatElement *element,
return;
}
+ // Emit the num-breaking-regions.
+ if (isa<NumBreakingRegionsDirective>(element)) {
+ body << " _odsPrinter << \" \" << "
+ "getOperation()->getNumBreakingControlRegions();\n";
+ return;
+ }
+
// Optionally insert a space before the next element. The AttrDict printer
// already adds a space as necessary.
if (shouldEmitSpace || !lastWasPunctuation)
@@ -2813,6 +2838,8 @@ class OpFormatParser : public FormatParser {
FailureOr<FormatElement *> parseTypeDirective(SMLoc loc, Context context);
FailureOr<FormatElement *> parseTypeDirectiveOperand(SMLoc loc,
bool isRefChild = false);
+ FailureOr<FormatElement *> parseNumBreakingRegionsDirective(SMLoc loc,
+ Context context);
//===--------------------------------------------------------------------===//
// Fields
@@ -3440,6 +3467,8 @@ OpFormatParser::parseDirectiveImpl(SMLoc loc, FormatToken::Kind kind,
return parseTypeDirective(loc, ctx);
case FormatToken::kw_oilist:
return parseOIListDirective(loc, ctx);
+ case FormatToken::kw_num_breaking_regions:
+ return parseNumBreakingRegionsDirective(loc, ctx);
default:
return emitError(loc, "unsupported directive kind");
@@ -3690,6 +3719,14 @@ FailureOr<FormatElement *> OpFormatParser::parseTypeDirective(SMLoc loc,
return create<TypeDirective>(*operand);
}
+FailureOr<FormatElement *>
+OpFormatParser::parseNumBreakingRegionsDirective(SMLoc loc, Context context) {
+  // The directive only makes sense at the top level of an assembly format;
+  // nested contexts (e.g. inside other directives) are rejected.
+  if (context == TopLevelContext)
+    return create<NumBreakingRegionsDirective>();
+  return emitError(
+      loc, "'num-breaking-regions' is only valid as a top-level directive");
+}
+
LogicalResult OpFormatParser::markQualified(SMLoc loc, FormatElement *element) {
return TypeSwitch<FormatElement *, LogicalResult>(element)
.Case<AttributeVariable, TypeDirective>([](auto *element) {
diff --git a/mlir/unittests/Debug/FileLineColLocBreakpointManagerTest.cpp b/mlir/unittests/Debug/FileLineColLocBreakpointManagerTest.cpp
index 6a81422b6b66b..8f75fe05cd3dc 100644
--- a/mlir/unittests/Debug/FileLineColLocBreakpointManagerTest.cpp
+++ b/mlir/unittests/Debug/FileLineColLocBreakpointManagerTest.cpp
@@ -23,9 +23,10 @@ static Operation *createOp(MLIRContext *context, Location loc,
StringRef operationName,
unsigned int numRegions = 0) {
context->allowUnregisteredDialects();
+
return Operation::create(loc, OperationName(operationName, context), {}, {},
NamedAttrList(), OpaqueProperties(nullptr), {},
- numRegions);
+ numRegions, /*numBreakingControlRegions=*/0);
}
namespace {
diff --git a/mlir/unittests/IR/OperationSupportTest.cpp b/mlir/unittests/IR/OperationSupportTest.cpp
index 9f3e7ed34a27d..7e6dcc586e506 100644
--- a/mlir/unittests/IR/OperationSupportTest.cpp
+++ b/mlir/unittests/IR/OperationSupportTest.cpp
@@ -24,7 +24,8 @@ static Operation *createOp(MLIRContext *context, ArrayRef<Value> operands = {},
context->allowUnregisteredDialects();
return Operation::create(UnknownLoc::get(context),
OperationName("foo.bar", context), resultTypes,
- operands, NamedAttrList(), nullptr, {}, numRegions);
+ operands, NamedAttrList(), nullptr, {}, numRegions,
+ /*numBreakingControlRegions=*/0);
}
namespace {
@@ -236,7 +237,7 @@ TEST(OperationFormatPrintTest, CanPrintNameAsPrefix) {
Operation *op = Operation::create(
NameLoc::get(StringAttr::get(&context, "my_named_loc")),
OperationName("t.op", &context), builder.getIntegerType(16), {},
- NamedAttrList(), nullptr, {}, 0);
+ NamedAttrList(), nullptr, {}, 0, /*numBreakingControlRegions=*/0);
std::string str;
OpPrintingFlags flags;
diff --git a/mlir/unittests/IR/ValueTest.cpp b/mlir/unittests/IR/ValueTest.cpp
index 97e32d474d522..18ed8d9929175 100644
--- a/mlir/unittests/IR/ValueTest.cpp
+++ b/mlir/unittests/IR/ValueTest.cpp
@@ -22,7 +22,8 @@ static Operation *createOp(MLIRContext *context, ArrayRef<Value> operands = {},
context->allowUnregisteredDialects();
return Operation::create(UnknownLoc::get(context),
OperationName("foo.bar", context), resultTypes,
- operands, NamedAttrList(), nullptr, {}, numRegions);
+ operands, NamedAttrList(), nullptr, {}, numRegions,
+ /*numBreakingControlRegions=*/0);
}
namespace {
diff --git a/mlir/unittests/Transforms/DialectConversion.cpp b/mlir/unittests/Transforms/DialectConversion.cpp
index 6418c9dc0ac5b..2c7b35e9ef69c 100644
--- a/mlir/unittests/Transforms/DialectConversion.cpp
+++ b/mlir/unittests/Transforms/DialectConversion.cpp
@@ -15,7 +15,8 @@ static Operation *createOp(MLIRContext *context) {
context->allowUnregisteredDialects();
return Operation::create(UnknownLoc::get(context),
OperationName("foo.bar", context), {}, {},
- NamedAttrList(), /*properties=*/nullptr, {}, 0);
+ NamedAttrList(), /*properties=*/nullptr, {}, 0,
+ /*numBreakingControlRegions=*/0);
}
namespace {
More information about the Mlir-commits
mailing list