[Mlir-commits] [mlir] [mlir][Interfaces] Add `ExecutionProgressOpInterface` + folding pattern (PR #180348)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sat Feb 7 05:41:52 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-ub
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
Add the `ExecutionProgressOpInterface` with an interface method to check if an operation "must progress". Add `mustProgress` attributes to `scf.for` and `scf.while` (default value is "true").
`mustProgress` corresponds to the [`llvm.loop.mustprogress` metadata](https://llvm.org/docs/LangRef.html#langref-llvm-loop-mustprogress).
Also add a canonicalization pattern to erase `RegionBranchOpInterface` ops that must progress but loop infinitely (and are non-side-effecting). This canonicalization pattern is enabled for `scf.for` and `scf.while`.
Registered operations are assumed to be "must progress" by default; unregistered operations are treated conservatively (they might not progress).
RFC: https://discourse.llvm.org/t/infinite-loops-and-dead-code/89530
This PR is a re-upload of #179039.
---
Patch is 31.11 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/180348.diff
19 Files Affected:
- (modified) mlir/include/mlir/Dialect/SCF/IR/SCF.h (+1)
- (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+13-4)
- (modified) mlir/include/mlir/Dialect/UB/IR/UBOps.h (+10)
- (modified) mlir/include/mlir/Interfaces/CMakeLists.txt (+1)
- (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.h (+12)
- (modified) mlir/include/mlir/Interfaces/ControlFlowInterfaces.td (+1-1)
- (added) mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.h (+29)
- (added) mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.td (+48)
- (modified) mlir/lib/Dialect/SCF/IR/CMakeLists.txt (+2)
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+45-3)
- (modified) mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp (+3-2)
- (modified) mlir/lib/Dialect/UB/IR/CMakeLists.txt (+4)
- (modified) mlir/lib/Dialect/UB/IR/UBOps.cpp (+39)
- (modified) mlir/lib/Interfaces/CMakeLists.txt (+2)
- (modified) mlir/lib/Interfaces/ControlFlowInterfaces.cpp (+73-30)
- (added) mlir/lib/Interfaces/ExecutionProgressOpInterface.cpp (+29)
- (modified) mlir/test/CAPI/ir.c (+1-1)
- (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+51)
- (modified) mlir/test/Dialect/SCF/ops.mlir (+4-3)
``````````diff
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index e754a04b0903a..de60ed99dd336 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -20,6 +20,7 @@
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/ExecutionProgressOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index a08cf3c95e6ce..b259e33f1d75f 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -18,6 +18,7 @@ include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/RegionKindInterface.td"
include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/ExecutionProgressOpInterface.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ParallelCombiningOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -40,7 +41,7 @@ def SCF_Dialect : Dialect {
and then lowered to some final target like LLVM or SPIR-V.
}];
- let dependentDialects = ["arith::ArithDialect"];
+ let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
}
// Base class for SCF dialect ops.
@@ -161,6 +162,8 @@ def ForOp : SCF_Op<"for",
ConditionallySpeculatable,
DeclareOpInterfaceMethods<RegionBranchOpInterface,
["getEntrySuccessorOperands", "getSuccessorInputs"]>,
+ DeclareOpInterfaceMethods<ExecutionProgressOpInterface,
+ ["mustProgress"]>,
SingleBlockImplicitTerminator<"scf::YieldOp">,
RecursiveMemoryEffects]> {
let summary = "for operation";
@@ -265,7 +268,8 @@ def ForOp : SCF_Op<"for",
AnySignlessIntegerOrIndex:$upperBound,
AnySignlessIntegerOrIndex:$step,
Variadic<AnyType>:$initArgs,
- UnitAttr:$unsignedCmp);
+ UnitAttr:$unsignedCmp,
+ DefaultValuedAttr<BoolAttr, "true">:$mustProgress);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$region);
@@ -986,6 +990,7 @@ def WhileOp : SCF_Op<"while",
["getEntrySuccessorOperands", "getSuccessorInputs"]>,
DeclareOpInterfaceMethods<LoopLikeOpInterface,
["getRegionIterArgs", "getYieldedValuesMutable"]>,
+ DeclareOpInterfaceMethods<ExecutionProgressOpInterface, ["mustProgress"]>,
RecursiveMemoryEffects, SingleBlock]> {
let summary = "a generic 'while' loop";
let description = [{
@@ -1101,14 +1106,18 @@ def WhileOp : SCF_Op<"while",
```
}];
- let arguments = (ins Variadic<AnyType>:$inits);
+ let arguments = (ins Variadic<AnyType>:$inits,
+ DefaultValuedAttr<BoolAttr, "true">:$mustProgress);
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
+ let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$inits,
"function_ref<void(OpBuilder &, Location, ValueRange)>":$beforeBuilder,
- "function_ref<void(OpBuilder &, Location, ValueRange)>":$afterBuilder)>
+ "function_ref<void(OpBuilder &, Location, ValueRange)>":$afterBuilder)>,
+ OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$inits,
+ CArg<"bool", "true">:$mustProgress)>
];
let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
index 21de5cb0c182a..281bd3ed4e805 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
@@ -12,6 +12,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/UB/IR/UBOpsInterfaces.h.inc"
@@ -24,4 +25,13 @@
#include "mlir/Dialect/UB/IR/UBOpsDialect.h.inc"
+namespace mlir::ub {
+/// Populate a canonicalization pattern that erases "must progress" region
+/// branch ops that loop infinitely and replaces their results with poison
+/// values.
+void populateEraseInfiniteRegionBranchLoopPattern(RewritePatternSet &patterns,
+ StringRef opName,
+ PatternBenefit benefit = 1);
+} // namespace mlir::ub
+
#endif // MLIR_DIALECT_UB_IR_OPS_H
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index eb96a68861116..e0c75aee29c00 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_interface(CastInterfaces)
add_mlir_interface(ControlFlowInterfaces)
add_mlir_interface(DerivedAttributeOpInterface)
add_mlir_interface(DestinationStyleOpInterface)
+add_mlir_interface(ExecutionProgressOpInterface)
add_mlir_interface(FunctionInterfaces)
add_mlir_interface(IndexingMapOpInterface)
add_mlir_interface(InferIntRangeInterface)
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index a76dce6f2ffc5..33e139f6b0cea 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -314,6 +314,11 @@ Region *getEnclosingRepetitiveRegion(Operation *op);
/// exists.
Region *getEnclosingRepetitiveRegion(Value value);
+/// Return "true" if the given region branch op is guaranteed to loop
+/// infinitely. Every path starting from "parent" enters the region, but the
+/// "parent" is not reachable from there.
+bool isGuaranteedToLoopInfinitely(RegionBranchOpInterface op);
+
/// Populate canonicalization patterns that simplify successor operands/inputs
/// of region branch operations. Only operations with the given name are
/// matched.
@@ -359,6 +364,13 @@ void populateRegionBranchOpInterfaceInliningPattern(
PatternMatcherFn matcherFn = detail::defaultMatcherFn,
PatternBenefit benefit = 1);
+/// Return all successor regions when branching from the given region branch
+/// point. This helper function extracts all constant operand values and
+/// passes them to the `RegionBranchOpInterface`.
+SmallVector<RegionSuccessor>
+getSuccessorRegionsWithAttrs(RegionBranchOpInterface op,
+ RegionBranchPoint point);
+
//===----------------------------------------------------------------------===//
// ControlFlow Traits
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 8975b1235a7e3..1dacde297efa2 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -350,7 +350,7 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
"bool", "areTypesCompatible",
(ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
/*defaultImplementation=*/[{ return lhs == rhs; }]
- >,
+ >
];
let verify = [{
diff --git a/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.h b/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.h
new file mode 100644
index 0000000000000..e395f909de092
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.h
@@ -0,0 +1,29 @@
+//===- ExecutionProgressOpInterface.h ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_EXECUTIONPROGRESSOPINTERFACE_H_
+#define MLIR_INTERFACES_EXECUTIONPROGRESSOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+#include "mlir/Interfaces/ExecutionProgressOpInterface.h.inc"
+
+namespace mlir {
+/// Return "true" if the operation must progress.
+///
+/// Unregistered operations are treated conservatively: they may not
+/// necessarily progress (i.e., return "false"). Registered operations are
+/// assumed to progress by default. This can be overridden by the
+/// ExecutionProgressOpInterface.
+bool mustProgress(Operation *op);
+
+/// Return "true" if the operation might not progress.
+inline bool mightNotProgress(Operation *op) { return !mustProgress(op); }
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_EXECUTIONPROGRESSOPINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.td b/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.td
new file mode 100644
index 0000000000000..4b7923ce3612e
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.td
@@ -0,0 +1,48 @@
+//===- ExecutionProgressOpInterface.td - Interface Decl. -*- tablegen -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the definition file for the ExecutionProgressOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_EXECUTION_PROGRESS_OP_INTERFACE
+#define MLIR_INTERFACES_EXECUTION_PROGRESS_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def ExecutionProgressOpInterface : OpInterface<"ExecutionProgressOpInterface"> {
+ let description = [{
+ This interface models execution progress properties of operations.
+ }];
+ let cppNamespace = "::mlir";
+ let methods = [
+ InterfaceMethod<
+ /*desc=*/[{
+ Operations that "must progress" are required to return normally (control
+ flow reaches the next operation) or interact with the environment in an
+ observable way (e.g., volatile memory access, I/O, synchronization or
+ program termination). If a "must progress" op executes indefinitely
+ without any observable interaction, it may be erased.
+
+ See LLVM "llvm.loop.mustprogress" / "mustprogress" function attribute
+ for more details.
+
+ Operations must progress by default.
+ }],
+ /*retTy=*/"bool",
+ /*methodName=*/"mustProgress",
+ /*args=*/(ins),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return true;
+ }]
+ >
+ ];
+}
+
+#endif // MLIR_INTERFACES_EXECUTION_PROGRESS_OP_INTERFACE
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index b111117410ba3..8c3b93b3c580b 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -13,11 +13,13 @@ add_mlir_dialect_library(MLIRSCFDialect
MLIRArithDialect
MLIRControlFlowDialect
MLIRDialectUtils
+ MLIRExecutionProgressOpInterface
MLIRFunctionInterfaces
MLIRIR
MLIRLoopLikeInterface
MLIRSideEffectInterfaces
MLIRTensorDialect
+ MLIRUBDialect
MLIRValueBoundsOpInterface
MLIRTransformUtils
)
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c46a0577c4b96..0116620bdd3a3 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Matchers.h"
@@ -509,8 +510,10 @@ void ForOp::print(OpAsmPrinter &p) {
p.printRegion(getRegion(),
/*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/!getInitArgs().empty());
- p.printOptionalAttrDict((*this)->getAttrs(),
- /*elidedAttrs=*/getUnsignedCmpAttrName().strref());
+ SmallVector<StringRef> elidedAttrs = {getUnsignedCmpAttrName().strref()};
+ if (getMustProgress()) // "true" is the default, elide attribute.
+ elidedAttrs.push_back(getMustProgressAttrName().strref());
+ p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
}
ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -691,6 +694,24 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
}
}
+ // Infinite loops (lb < ub and step size 0) enter the loop body and never
+ // leave it.
+ std::optional<std::pair<APInt, bool>> lbCst =
+ getConstantAPIntValue(getLowerBound());
+ std::optional<std::pair<APInt, bool>> ubCst =
+ getConstantAPIntValue(getUpperBound());
+ std::optional<std::pair<APInt, bool>> stepCst =
+ getConstantAPIntValue(getStep());
+ if (lbCst.has_value() && ubCst.has_value() && stepCst.has_value()) {
+ bool atLeastOneIteration =
+ (getUnsignedCmp() && lbCst->first.ult(ubCst->first)) ||
+ (!getUnsignedCmp() && lbCst->first.slt(ubCst->first));
+ if (atLeastOneIteration && stepCst->first.isZero()) {
+ regions.push_back(RegionSuccessor(&getRegion()));
+ return;
+ }
+ }
+
// Both the operation itself and the region may be branching into the body or
// back into the operation itself. It is possible for loop not to enter the
// body.
@@ -703,6 +724,8 @@ ValueRange ForOp::getSuccessorInputs(RegionSuccessor successor) {
: ValueRange(getRegionIterArgs());
}
+bool ForOp::mustProgress() { return getMustProgress(); }
+
SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
/// Promotes the loop body of a forallOp to its containing block if it can be
@@ -1004,6 +1027,8 @@ void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
auto forOp = cast<ForOp>(blockArg.getOwner()->getParentOp());
return forOp.getLowerBound();
});
+ ub::populateEraseInfiniteRegionBranchLoopPattern(results,
+ ForOp::getOperationName());
}
std::optional<APInt> ForOp::getConstantStep() {
@@ -3210,6 +3235,16 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
}
+void WhileOp::build(::mlir::OpBuilder &odsBuilder,
+ ::mlir::OperationState &odsState, TypeRange resultTypes,
+ ValueRange inits, bool mustProgress) {
+ odsState.addOperands(inits);
+ for (unsigned i = 0; i < 2; ++i)
+ (void)odsState.addRegion();
+ odsState.addTypes(resultTypes);
+ odsState.addAttribute("mustProgress", odsBuilder.getBoolAttr(mustProgress));
+}
+
ConditionOp WhileOp::getConditionOp() {
return cast<ConditionOp>(getBeforeBody()->getTerminator());
}
@@ -3273,6 +3308,8 @@ ValueRange WhileOp::getSuccessorInputs(RegionSuccessor successor) {
llvm_unreachable("invalid region successor");
}
+bool WhileOp::mustProgress() { return getMustProgress(); }
+
SmallVector<Region *> WhileOp::getLoopRegions() {
return {&getBefore(), &getAfter()};
}
@@ -3332,7 +3369,10 @@ void scf::WhileOp::print(OpAsmPrinter &p) {
p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
p << " do ";
p.printRegion(getAfter());
- p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
+ SmallVector<StringRef> elidedAttrs;
+ if (getMustProgress()) // "true" is the default, elide attribute.
+ elidedAttrs.push_back(getMustProgressAttrName().strref());
+ p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
}
/// Verifies that two ranges of types match, i.e. have the same number of
@@ -3708,6 +3748,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
results, WhileOp::getOperationName());
populateRegionBranchOpInterfaceInliningPattern(results,
WhileOp::getOperationName());
+ ub::populateEraseInfiniteRegionBranchLoopPattern(results,
+ WhileOp::getOperationName());
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index ddcbda86cf1f3..152fb226993e9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -49,8 +49,9 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
SmallVector<Value> initArgs;
initArgs.push_back(forOp.getLowerBound());
llvm::append_range(initArgs, forOp.getInitArgs());
- auto whileOp = WhileOp::create(rewriter, forOp.getLoc(), lcvTypes, initArgs,
- forOp->getAttrs());
+ auto whileOp =
+ WhileOp::create(rewriter, forOp.getLoc(), lcvTypes, initArgs);
+ whileOp->setAttrs(forOp->getAttrDictionary());
// 'before' region contains the loop condition and forwarding of iteration
// arguments to the 'after' region.
diff --git a/mlir/lib/Dialect/UB/IR/CMakeLists.txt b/mlir/lib/Dialect/UB/IR/CMakeLists.txt
index 84125ea0b5718..3baac5045b8db 100644
--- a/mlir/lib/Dialect/UB/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/UB/IR/CMakeLists.txt
@@ -5,9 +5,13 @@ add_mlir_dialect_library(MLIRUBDialect
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/UB
DEPENDS
+ MLIRControlFlowInterfaces
MLIRUBOpsIncGen
MLIRUBOpsInterfacesIncGen
LINK_LIBS PUBLIC
+ MLIRControlFlowInterfaces
MLIRIR
+ MLIRExecutionProgressOpInterface
+ MLIRSideEffectInterfaces
)
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index ee523f9522953..2310fc5af8cb8 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -8,6 +8,8 @@
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/ExecutionProgressOpInterface.h"
#include "mlir/Transforms/InliningUtils.h"
#include "mlir/IR/Builders.h"
@@ -66,3 +68,40 @@ OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
#define GET_OP_CLASSES
#include "mlir/Dialect/UB/IR/UBOps.cpp.inc"
+
+namespace {
+/// Canonicalization pattern for RegionBranchOpInterface ops that loop
+/// infinitely. Such ops are replaced with poison values if they "must
+/// progress".
+struct EraseInfiniteRegionBranchLoop : public RewritePattern {
+ EraseInfiniteRegionBranchLoop(MLIRContext *context, StringRef name,
+ PatternBenefit benefit = 1)
+ : RewritePattern(name, benefit, context) {}
+
+ LogicalResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const override {
+ auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+ if (mightNotProgress(op))
+ return rewriter.notifyMatchFailure(
+ op, "only loops that must progress are removed");
+ if (!wouldOpBeTriviallyDead(op))
+ return rewriter.notifyMatchFailure(op,
+ "only trivially dead ops are removed");
+ if (!isGuaranteedToLoopInfinitely(regionBranchOp))
+ return rewriter.notifyMatchFailure(
+ op, "only loops that loop infinitely are removed");
+ SmallVector<Value> replacements =
+ llvm::map_to_vector(op->getResultTypes(), [&](Type type) {
+ return PoisonOp::create(rewriter, op->getLoc(), type).getResult();
+ });
+ rewriter.replaceOp(op, replacements);
+ return success();
+ }
+};
+} // namespace
+
+void mlir::ub::populateEraseInfiniteRegionBranchLoopPattern(
+ RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit) {
+ patterns.add<EraseInfiniteRegionBranchLoop>(patterns.getContext(), opName,
+ benefit);
+}
diff --git a/m...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/180348
More information about the Mlir-commits
mailing list