[Mlir-commits] [mlir] [mlir][Interfaces] Add `ExecutionProgressOpInterface` + folding pattern (PR #180348)

Matthias Springer llvmlistbot at llvm.org
Sat Feb 7 05:41:23 PST 2026


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/180348

Add the `ExecutionProgressOpInterface` with an interface method to check if an operation "must progress". Add `mustProgress` attributes to `scf.for` and `scf.while` (default value is "true").

`mustProgress` corresponds to the [`llvm.loop.mustprogress` metadata](https://llvm.org/docs/LangRef.html#langref-llvm-loop-mustprogress).

Also add a canonicalization pattern to erase `RegionBranchOpInterface` ops that must progress but loop infinitely (and are non-side-effecting). This canonicalization pattern is enabled for `scf.for` and `scf.while`.

Registered operations are assumed to be "must progress" by default; unregistered operations are treated conservatively as "might not progress".

RFC: https://discourse.llvm.org/t/infinite-loops-and-dead-code/89530

This PR is a re-upload of #179039.


>From f6219e83eccf301614c91677364fef7ad4241440 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 2 Feb 2026 08:25:45 +0100
Subject: [PATCH] [mlir][Interfaces] Add `ExecutionProgressOpInterface` +
 folding pattern (#179039)

Add the `ExecutionProgressOpInterface` with an interface method to check
if an operation "must progress". Add `mustProgress` attributes to
`scf.for` and `scf.while` (default value is "true").

`mustProgress` corresponds to the [`llvm.loop.mustprogress`
metadata](https://llvm.org/docs/LangRef.html#langref-llvm-loop-mustprogress).

Also add a canonicalization pattern to erase `RegionBranchOpInterface`
ops that must progress but loop infinitely (and are non-side-effecting).
This canonicalization pattern is enabled for `scf.for` and `scf.while`.

RFC: https://discourse.llvm.org/t/infinite-loops-and-dead-code/89530

[mlir] Fix build after #179039 (#179180)

Fix build after #179039.
---
 mlir/include/mlir/Dialect/SCF/IR/SCF.h        |   1 +
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    |  17 ++-
 mlir/include/mlir/Dialect/UB/IR/UBOps.h       |  10 ++
 mlir/include/mlir/Interfaces/CMakeLists.txt   |   1 +
 .../mlir/Interfaces/ControlFlowInterfaces.h   |  12 ++
 .../mlir/Interfaces/ControlFlowInterfaces.td  |   2 +-
 .../Interfaces/ExecutionProgressOpInterface.h |  29 +++++
 .../ExecutionProgressOpInterface.td           |  48 ++++++++
 mlir/lib/Dialect/SCF/IR/CMakeLists.txt        |   2 +
 mlir/lib/Dialect/SCF/IR/SCF.cpp               |  48 +++++++-
 .../lib/Dialect/SCF/Transforms/ForToWhile.cpp |   5 +-
 mlir/lib/Dialect/UB/IR/CMakeLists.txt         |   4 +
 mlir/lib/Dialect/UB/IR/UBOps.cpp              |  39 +++++++
 mlir/lib/Interfaces/CMakeLists.txt            |   2 +
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 103 +++++++++++++-----
 .../ExecutionProgressOpInterface.cpp          |  29 +++++
 mlir/test/CAPI/ir.c                           |   2 +-
 mlir/test/Dialect/SCF/canonicalize.mlir       |  51 +++++++++
 mlir/test/Dialect/SCF/ops.mlir                |   7 +-
 19 files changed, 368 insertions(+), 44 deletions(-)
 create mode 100644 mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.h
 create mode 100644 mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.td
 create mode 100644 mlir/lib/Interfaces/ExecutionProgressOpInterface.cpp

diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCF.h b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
index e754a04b0903a..de60ed99dd336 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCF.h
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCF.h
@@ -20,6 +20,7 @@
 #include "mlir/IR/RegionKindInterface.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/ExecutionProgressOpInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
 #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index a08cf3c95e6ce..b259e33f1d75f 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -18,6 +18,7 @@ include "mlir/Interfaces/LoopLikeInterface.td"
 include "mlir/IR/RegionKindInterface.td"
 include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
 include "mlir/Interfaces/DestinationStyleOpInterface.td"
+include "mlir/Interfaces/ExecutionProgressOpInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/ParallelCombiningOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -40,7 +41,7 @@ def SCF_Dialect : Dialect {
     and then lowered to some final target like LLVM or SPIR-V.
   }];
 
-  let dependentDialects = ["arith::ArithDialect"];
+  let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
 }
 
 // Base class for SCF dialect ops.
@@ -161,6 +162,8 @@ def ForOp : SCF_Op<"for",
        ConditionallySpeculatable,
        DeclareOpInterfaceMethods<RegionBranchOpInterface,
         ["getEntrySuccessorOperands", "getSuccessorInputs"]>,
+       DeclareOpInterfaceMethods<ExecutionProgressOpInterface,
+        ["mustProgress"]>,
        SingleBlockImplicitTerminator<"scf::YieldOp">,
        RecursiveMemoryEffects]> {
   let summary = "for operation";
@@ -265,7 +268,8 @@ def ForOp : SCF_Op<"for",
                        AnySignlessIntegerOrIndex:$upperBound,
                        AnySignlessIntegerOrIndex:$step,
                        Variadic<AnyType>:$initArgs,
-                       UnitAttr:$unsignedCmp);
+                       UnitAttr:$unsignedCmp,
+                       DefaultValuedAttr<BoolAttr, "true">:$mustProgress);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$region);
 
@@ -986,6 +990,7 @@ def WhileOp : SCF_Op<"while",
         ["getEntrySuccessorOperands", "getSuccessorInputs"]>,
      DeclareOpInterfaceMethods<LoopLikeOpInterface,
         ["getRegionIterArgs", "getYieldedValuesMutable"]>,
+     DeclareOpInterfaceMethods<ExecutionProgressOpInterface, ["mustProgress"]>,
      RecursiveMemoryEffects, SingleBlock]> {
   let summary = "a generic 'while' loop";
   let description = [{
@@ -1101,14 +1106,18 @@ def WhileOp : SCF_Op<"while",
     ```
   }];
 
-  let arguments = (ins Variadic<AnyType>:$inits);
+  let arguments = (ins Variadic<AnyType>:$inits,
+                       DefaultValuedAttr<BoolAttr, "true">:$mustProgress);
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$before, SizedRegion<1>:$after);
 
+  let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$inits,
       "function_ref<void(OpBuilder &, Location, ValueRange)>":$beforeBuilder,
-      "function_ref<void(OpBuilder &, Location, ValueRange)>":$afterBuilder)>
+      "function_ref<void(OpBuilder &, Location, ValueRange)>":$afterBuilder)>,
+    OpBuilder<(ins "TypeRange":$resultTypes, "ValueRange":$inits,
+                   CArg<"bool", "true">:$mustProgress)>
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
index 21de5cb0c182a..281bd3ed4e805 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
@@ -12,6 +12,7 @@
 #include "mlir/Bytecode/BytecodeOpInterface.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
 #include "mlir/Dialect/UB/IR/UBOpsInterfaces.h.inc"
@@ -24,4 +25,13 @@
 
 #include "mlir/Dialect/UB/IR/UBOpsDialect.h.inc"
 
+namespace mlir::ub {
+/// Populate a canonicalization pattern that erases "must progress" region
+/// branch ops that loop infinitely and replaces their results with poison
+/// values.
+void populateEraseInfiniteRegionBranchLoopPattern(RewritePatternSet &patterns,
+                                                  StringRef opName,
+                                                  PatternBenefit benefit = 1);
+} // namespace mlir::ub
+
 #endif // MLIR_DIALECT_UB_IR_OPS_H
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
index eb96a68861116..e0c75aee29c00 100644
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -4,6 +4,7 @@ add_mlir_interface(CastInterfaces)
 add_mlir_interface(ControlFlowInterfaces)
 add_mlir_interface(DerivedAttributeOpInterface)
 add_mlir_interface(DestinationStyleOpInterface)
+add_mlir_interface(ExecutionProgressOpInterface)
 add_mlir_interface(FunctionInterfaces)
 add_mlir_interface(IndexingMapOpInterface)
 add_mlir_interface(InferIntRangeInterface)
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
index a76dce6f2ffc5..33e139f6b0cea 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
@@ -314,6 +314,11 @@ Region *getEnclosingRepetitiveRegion(Operation *op);
 /// exists.
 Region *getEnclosingRepetitiveRegion(Value value);
 
+/// Return "true" if the given region branch op is guaranteed to loop
+/// infinitely. Every path starting from "parent" enters the region, but the
+/// "parent" is not reachable from there.
+bool isGuaranteedToLoopInfinitely(RegionBranchOpInterface op);
+
 /// Populate canonicalization patterns that simplify successor operands/inputs
 /// of region branch operations. Only operations with the given name are
 /// matched.
@@ -359,6 +364,13 @@ void populateRegionBranchOpInterfaceInliningPattern(
     PatternMatcherFn matcherFn = detail::defaultMatcherFn,
     PatternBenefit benefit = 1);
 
+/// Return all successor regions when branching from the given region branch
+/// point. This helper function extracts all constant operand values and
+/// passes them to the `RegionBranchOpInterface`.
+SmallVector<RegionSuccessor>
+getSuccessorRegionsWithAttrs(RegionBranchOpInterface op,
+                             RegionBranchPoint point);
+
 //===----------------------------------------------------------------------===//
 // ControlFlow Traits
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
index 8975b1235a7e3..1dacde297efa2 100644
--- a/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
+++ b/mlir/include/mlir/Interfaces/ControlFlowInterfaces.td
@@ -350,7 +350,7 @@ def RegionBranchOpInterface : OpInterface<"RegionBranchOpInterface"> {
       "bool", "areTypesCompatible",
       (ins "::mlir::Type":$lhs, "::mlir::Type":$rhs), [{}],
       /*defaultImplementation=*/[{ return lhs == rhs; }]
-    >,
+    >
   ];
 
   let verify = [{
diff --git a/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.h b/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.h
new file mode 100644
index 0000000000000..e395f909de092
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.h
@@ -0,0 +1,29 @@
+//===- ExecutionProgressOpInterface.h ---------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_EXECUTIONPROGRESSOPINTERFACE_H_
+#define MLIR_INTERFACES_EXECUTIONPROGRESSOPINTERFACE_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+#include "mlir/Interfaces/ExecutionProgressOpInterface.h.inc"
+
+namespace mlir {
+/// Return "true" if the operation must progress.
+///
+/// Unregistered operations are treated conservatively: they may not
+/// necessarily progress (i.e., return "false"). Registered operations are
+/// assumed to progress by default. This can be overridden by the
+/// ExecutionProgressOpInterface.
+bool mustProgress(Operation *op);
+
+/// Return "true" if the operation might not progress.
+inline bool mightNotProgress(Operation *op) { return !mustProgress(op); }
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_EXECUTIONPROGRESSOPINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.td b/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.td
new file mode 100644
index 0000000000000..4b7923ce3612e
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ExecutionProgressOpInterface.td
@@ -0,0 +1,48 @@
+//===- ExecutionProgressOpInterface.td - Interface Decl. -*- tablegen -*---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the definition file for the ExecutionProgressOpInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_EXECUTION_PROGRESS_OP_INTERFACE
+#define MLIR_INTERFACES_EXECUTION_PROGRESS_OP_INTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def ExecutionProgressOpInterface : OpInterface<"ExecutionProgressOpInterface"> {
+  let description = [{
+    This interface models execution progress properties of operations.
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Operations that "must progress" are required to return normally (control
+        flow reaches the next operation) or interact with the environment in an
+        observable way (e.g., volatile memory access, I/O, synchronization or
+        program termination). If a "must progress" op executes indefinitely
+        without any observable interaction, it may be erased.
+        
+        See LLVM "llvm.loop.mustprogress" / "mustprogress" function attribute
+        for more details.
+
+        Operations must progress by default.
+      }],
+      /*retTy=*/"bool",
+      /*methodName=*/"mustProgress",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return true;
+      }]
+    >
+  ];
+}
+
+#endif // MLIR_INTERFACES_EXECUTION_PROGRESS_OP_INTERFACE
diff --git a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
index b111117410ba3..8c3b93b3c580b 100644
--- a/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/SCF/IR/CMakeLists.txt
@@ -13,11 +13,13 @@ add_mlir_dialect_library(MLIRSCFDialect
   MLIRArithDialect
   MLIRControlFlowDialect
   MLIRDialectUtils
+  MLIRExecutionProgressOpInterface
   MLIRFunctionInterfaces
   MLIRIR
   MLIRLoopLikeInterface
   MLIRSideEffectInterfaces
   MLIRTensorDialect
+  MLIRUBDialect
   MLIRValueBoundsOpInterface
   MLIRTransformUtils
   )
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index c46a0577c4b96..0116620bdd3a3 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Matchers.h"
@@ -509,8 +510,10 @@ void ForOp::print(OpAsmPrinter &p) {
   p.printRegion(getRegion(),
                 /*printEntryBlockArgs=*/false,
                 /*printBlockTerminators=*/!getInitArgs().empty());
-  p.printOptionalAttrDict((*this)->getAttrs(),
-                          /*elidedAttrs=*/getUnsignedCmpAttrName().strref());
+  SmallVector<StringRef> elidedAttrs = {getUnsignedCmpAttrName().strref()};
+  if (getMustProgress()) // "true" is the default, elide attribute.
+    elidedAttrs.push_back(getMustProgressAttrName().strref());
+  p.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
 }
 
 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -691,6 +694,24 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
     }
   }
 
+  // Infinite loops (lb < ub and step size 0) enter the loop body and never
+  // leave it.
+  std::optional<std::pair<APInt, bool>> lbCst =
+      getConstantAPIntValue(getLowerBound());
+  std::optional<std::pair<APInt, bool>> ubCst =
+      getConstantAPIntValue(getUpperBound());
+  std::optional<std::pair<APInt, bool>> stepCst =
+      getConstantAPIntValue(getStep());
+  if (lbCst.has_value() && ubCst.has_value() && stepCst.has_value()) {
+    bool atLeastOneIteration =
+        (getUnsignedCmp() && lbCst->first.ult(ubCst->first)) ||
+        (!getUnsignedCmp() && lbCst->first.slt(ubCst->first));
+    if (atLeastOneIteration && stepCst->first.isZero()) {
+      regions.push_back(RegionSuccessor(&getRegion()));
+      return;
+    }
+  }
+
   // Both the operation itself and the region may be branching into the body or
   // back into the operation itself. It is possible for loop not to enter the
   // body.
@@ -703,6 +724,8 @@ ValueRange ForOp::getSuccessorInputs(RegionSuccessor successor) {
                               : ValueRange(getRegionIterArgs());
 }
 
+bool ForOp::mustProgress() { return getMustProgress(); }
+
 SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
 
 /// Promotes the loop body of a forallOp to its containing block if it can be
@@ -1004,6 +1027,8 @@ void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
         auto forOp = cast<ForOp>(blockArg.getOwner()->getParentOp());
         return forOp.getLowerBound();
       });
+  ub::populateEraseInfiniteRegionBranchLoopPattern(results,
+                                                   ForOp::getOperationName());
 }
 
 std::optional<APInt> ForOp::getConstantStep() {
@@ -3210,6 +3235,16 @@ void WhileOp::build(::mlir::OpBuilder &odsBuilder,
     afterBuilder(odsBuilder, odsState.location, afterBlock->getArguments());
 }
 
+void WhileOp::build(::mlir::OpBuilder &odsBuilder,
+                    ::mlir::OperationState &odsState, TypeRange resultTypes,
+                    ValueRange inits, bool mustProgress) {
+  odsState.addOperands(inits);
+  for (unsigned i = 0; i < 2; ++i)
+    (void)odsState.addRegion();
+  odsState.addTypes(resultTypes);
+  odsState.addAttribute("mustProgress", odsBuilder.getBoolAttr(mustProgress));
+}
+
 ConditionOp WhileOp::getConditionOp() {
   return cast<ConditionOp>(getBeforeBody()->getTerminator());
 }
@@ -3273,6 +3308,8 @@ ValueRange WhileOp::getSuccessorInputs(RegionSuccessor successor) {
   llvm_unreachable("invalid region successor");
 }
 
+bool WhileOp::mustProgress() { return getMustProgress(); }
+
 SmallVector<Region *> WhileOp::getLoopRegions() {
   return {&getBefore(), &getAfter()};
 }
@@ -3332,7 +3369,10 @@ void scf::WhileOp::print(OpAsmPrinter &p) {
   p.printRegion(getBefore(), /*printEntryBlockArgs=*/false);
   p << " do ";
   p.printRegion(getAfter());
-  p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
+  SmallVector<StringRef> elidedAttrs;
+  if (getMustProgress()) // "true" is the default, elide attribute.
+    elidedAttrs.push_back(getMustProgressAttrName().strref());
+  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
 }
 
 /// Verifies that two ranges of types match, i.e. have the same number of
@@ -3708,6 +3748,8 @@ void WhileOp::getCanonicalizationPatterns(RewritePatternSet &results,
       results, WhileOp::getOperationName());
   populateRegionBranchOpInterfaceInliningPattern(results,
                                                  WhileOp::getOperationName());
+  ub::populateEraseInfiniteRegionBranchLoopPattern(results,
+                                                   WhileOp::getOperationName());
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index ddcbda86cf1f3..152fb226993e9 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -49,8 +49,9 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
     SmallVector<Value> initArgs;
     initArgs.push_back(forOp.getLowerBound());
     llvm::append_range(initArgs, forOp.getInitArgs());
-    auto whileOp = WhileOp::create(rewriter, forOp.getLoc(), lcvTypes, initArgs,
-                                   forOp->getAttrs());
+    auto whileOp =
+        WhileOp::create(rewriter, forOp.getLoc(), lcvTypes, initArgs);
+    whileOp->setAttrs(forOp->getAttrDictionary());
 
     // 'before' region contains the loop condition and forwarding of iteration
     // arguments to the 'after' region.
diff --git a/mlir/lib/Dialect/UB/IR/CMakeLists.txt b/mlir/lib/Dialect/UB/IR/CMakeLists.txt
index 84125ea0b5718..3baac5045b8db 100644
--- a/mlir/lib/Dialect/UB/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/UB/IR/CMakeLists.txt
@@ -5,9 +5,13 @@ add_mlir_dialect_library(MLIRUBDialect
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/UB
 
   DEPENDS
+  MLIRControlFlowInterfaces
   MLIRUBOpsIncGen
   MLIRUBOpsInterfacesIncGen
 
   LINK_LIBS PUBLIC
+  MLIRControlFlowInterfaces
   MLIRIR
+  MLIRExecutionProgressOpInterface
+  MLIRSideEffectInterfaces
   )
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index ee523f9522953..2310fc5af8cb8 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -8,6 +8,8 @@
 
 #include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/ExecutionProgressOpInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
 
 #include "mlir/IR/Builders.h"
@@ -66,3 +68,40 @@ OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
 
 #define GET_OP_CLASSES
 #include "mlir/Dialect/UB/IR/UBOps.cpp.inc"
+
+namespace {
+/// Canonicalization pattern for RegionBranchOpInterface ops that loop
+/// infinitely. Such ops are replaced with poison values if they "must
+/// progress".
+struct EraseInfiniteRegionBranchLoop : public RewritePattern {
+  EraseInfiniteRegionBranchLoop(MLIRContext *context, StringRef name,
+                                PatternBenefit benefit = 1)
+      : RewritePattern(name, benefit, context) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto regionBranchOp = cast<RegionBranchOpInterface>(op);
+    if (mightNotProgress(op))
+      return rewriter.notifyMatchFailure(
+          op, "only loops that must progress are removed");
+    if (!wouldOpBeTriviallyDead(op))
+      return rewriter.notifyMatchFailure(op,
+                                         "only trivially dead ops are removed");
+    if (!isGuaranteedToLoopInfinitely(regionBranchOp))
+      return rewriter.notifyMatchFailure(
+          op, "only loops that loop infinitely are removed");
+    SmallVector<Value> replacements =
+        llvm::map_to_vector(op->getResultTypes(), [&](Type type) {
+          return PoisonOp::create(rewriter, op->getLoc(), type).getResult();
+        });
+    rewriter.replaceOp(op, replacements);
+    return success();
+  }
+};
+} // namespace
+
+void mlir::ub::populateEraseInfiniteRegionBranchLoopPattern(
+    RewritePatternSet &patterns, StringRef opName, PatternBenefit benefit) {
+  patterns.add<EraseInfiniteRegionBranchLoop>(patterns.getContext(), opName,
+                                              benefit);
+}
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
index ad3e2b61be418..7919d64b4cc47 100644
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -6,6 +6,7 @@ set(LLVM_OPTIONAL_SOURCES
   DataLayoutInterfaces.cpp
   DerivedAttributeOpInterface.cpp
   DestinationStyleOpInterface.cpp
+  ExecutionProgressOpInterface.cpp
   FunctionImplementation.cpp
   FunctionInterfaces.cpp
   IndexingMapOpInterface.cpp
@@ -49,6 +50,7 @@ add_mlir_interface_library(ControlFlowInterfaces)
 add_mlir_interface_library(DataLayoutInterfaces)
 add_mlir_interface_library(DerivedAttributeOpInterface)
 add_mlir_interface_library(DestinationStyleOpInterface)
+add_mlir_interface_library(ExecutionProgressOpInterface)
 
 add_mlir_library(MLIRFunctionInterfaces
   FunctionInterfaces.cpp
diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index 2f95531455b2b..873685368d996 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -577,6 +577,49 @@ Region *mlir::getEnclosingRepetitiveRegion(Value value) {
   return nullptr;
 }
 
+/// Return "true" if the given region branch op is guaranteed to loop
+/// infinitely. Every path starting from "parent" enters the region, but the
+/// "parent" is not reachable from there.
+bool mlir::isGuaranteedToLoopInfinitely(RegionBranchOpInterface op) {
+  llvm::SmallDenseSet<Region *> visited;
+
+  // Path starts with "parent".
+  SmallVector<RegionBranchPoint> worklist;
+  worklist.push_back(RegionBranchPoint::parent());
+  bool enteredRegion = false;
+  while (!worklist.empty()) {
+    RegionBranchPoint next = worklist.pop_back_val();
+    SmallVector<RegionSuccessor> successors =
+        getSuccessorRegionsWithAttrs(op, next);
+    for (RegionSuccessor successor : successors) {
+      if (successor.isParent()) {
+        // Found path that ends with "parent".
+        return false;
+      }
+      enteredRegion = true;
+      Region *region = successor.getSuccessor();
+      if (!visited.insert(region).second) {
+        // We have already visited this region.
+        continue;
+      }
+      for (Block &block : *region) {
+        auto terminator =
+            dyn_cast<RegionBranchTerminatorOpInterface>(block.back());
+        if (!terminator) {
+          // Region has no RegionBranchTerminatorOpInterface terminator. E.g.,
+          // the terminator could be a "ub.unreachable" op or a "cf.br" op.
+          continue;
+        }
+        worklist.push_back(RegionBranchPoint(terminator));
+      }
+    }
+  }
+  // We visited all paths through the region branch op and the parent was not
+  // reached. If we visited at least one region, it means that we got stuck
+  // inside the region branch op, indicating an infinite loop.
+  return enteredRegion;
+}
+
 /// Return "true" if `a` can be used in lieu of `b`, where `b` is a region
 /// successor input and `a` is a "reachable value" of `b`. Reachable values
 /// are successor operand values that are (maybe transitively) forwarded to
@@ -679,6 +722,36 @@ static llvm::SmallDenseSet<Value> computeReachableValuesFromSuccessorInput(
   return reachableValues;
 }
 
+/// Given a range of values, return a vector of attributes of the same size,
+/// where the i-th attribute is the constant value of the i-th value. If a
+/// value is not constant, the corresponding attribute is null.
+static SmallVector<Attribute> extractConstants(ValueRange values) {
+  return llvm::map_to_vector(values, [](Value value) {
+    Attribute attr;
+    matchPattern(value, m_Constant(&attr));
+    return attr;
+  });
+}
+
+/// Return all successor regions when branching from the given region branch
+/// point. This helper function extracts all constant operand values and
+/// passes them to the `RegionBranchOpInterface`.
+SmallVector<RegionSuccessor>
+mlir::getSuccessorRegionsWithAttrs(RegionBranchOpInterface op,
+                                   RegionBranchPoint point) {
+  SmallVector<RegionSuccessor> successors;
+  if (point.isParent()) {
+    op.getEntrySuccessorRegions(extractConstants(op->getOperands()),
+                                successors);
+    return successors;
+  }
+  RegionBranchTerminatorOpInterface terminator =
+      point.getTerminatorPredecessorOrNull();
+  terminator.getSuccessorRegions(extractConstants(terminator->getOperands()),
+                                 successors);
+  return successors;
+}
+
 namespace {
 /// Try to make successor inputs dead by replacing their uses with values that
 /// are not successor inputs. This pattern enables additional canonicalization
@@ -1045,36 +1118,6 @@ struct RemoveDuplicateSuccessorInputUses : public RewritePattern {
   }
 };
 
-/// Given a range of values, return a vector of attributes of the same size,
-/// where the i-th attribute is the constant value of the i-th value. If a
-/// value is not constant, the corresponding attribute is null.
-static SmallVector<Attribute> extractConstants(ValueRange values) {
-  return llvm::map_to_vector(values, [](Value value) {
-    Attribute attr;
-    matchPattern(value, m_Constant(&attr));
-    return attr;
-  });
-}
-
-/// Return all successor regions when branching from the given region branch
-/// point. This helper functions extracts all constant operand values and
-/// passes them to the `RegionBranchOpInterface`.
-static SmallVector<RegionSuccessor>
-getSuccessorRegionsWithAttrs(RegionBranchOpInterface op,
-                             RegionBranchPoint point) {
-  SmallVector<RegionSuccessor> successors;
-  if (point.isParent()) {
-    op.getEntrySuccessorRegions(extractConstants(op->getOperands()),
-                                successors);
-    return successors;
-  }
-  RegionBranchTerminatorOpInterface terminator =
-      point.getTerminatorPredecessorOrNull();
-  terminator.getSuccessorRegions(extractConstants(terminator->getOperands()),
-                                 successors);
-  return successors;
-}
-
 /// Find the single acyclic path through the given region branch op. Return an
 /// empty vector if no such path or multiple such paths exist.
 ///
diff --git a/mlir/lib/Interfaces/ExecutionProgressOpInterface.cpp b/mlir/lib/Interfaces/ExecutionProgressOpInterface.cpp
new file mode 100644
index 0000000000000..64ff8bdf00e4c
--- /dev/null
+++ b/mlir/lib/Interfaces/ExecutionProgressOpInterface.cpp
@@ -0,0 +1,29 @@
+//===- ExecutionProgressOpInterface.cpp -- Execution Progress Interface ---===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/ExecutionProgressOpInterface.h"
+
+using namespace mlir;
+
+namespace mlir {
+#include "mlir/Interfaces/ExecutionProgressOpInterface.cpp.inc"
+} // namespace mlir
+
+bool mlir::mustProgress(Operation *op) {
+  // Unregistered operations have unknown semantics, so we conservatively
+  // assume that they do not necessarily progress.
+  if (!op->getName().isRegistered())
+    return false;
+  // Registered operations are assumed to progress by default. This can be
+  // overridden by the ExecutionProgressOpInterface.
+  auto executionProgressOpInterface =
+      dyn_cast<ExecutionProgressOpInterface>(op);
+  if (!executionProgressOpInterface)
+    return true;
+  return executionProgressOpInterface.mustProgress();
+}
diff --git a/mlir/test/CAPI/ir.c b/mlir/test/CAPI/ir.c
index c2caae4e795f3..7608ad1b968ef 100644
--- a/mlir/test/CAPI/ir.c
+++ b/mlir/test/CAPI/ir.c
@@ -320,7 +320,7 @@ int collectStats(MlirOperation operation) {
   // clang-format off
   // CHECK-LABEL: @stats
   // CHECK: Number of operations: 12
-  // CHECK: Number of attributes: 5
+  // CHECK: Number of attributes: 6
   // CHECK: Number of blocks: 3
   // CHECK: Number of regions: 3
   // CHECK: Number of values: 9
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index f65046ecee6da..e9f9e1a964963 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2284,3 +2284,54 @@ func.func @iter_args_cycles_non_cycle_start(%lb : index, %ub : index, %step : in
   }
   return %res#0, %res#1, %res#2 : i32, i32, i32
 }
+
+// -----
+
+// CHECK-LABEL: @erase_infinite_scf_for_loop
+//       CHECK:   %[[poison:.*]] = ub.poison : index
+//       CHECK:   return %[[poison]]
+func.func @erase_infinite_scf_for_loop(%init: index) -> index {
+  %lb = arith.constant 3 : index
+  %ub = arith.constant 4 : index
+  %step = arith.constant 0 : index
+  %res = scf.for %iv = %lb to %ub step %step iter_args(%iter = %init) -> index {
+    %0 = arith.addi %iter, %iter : index
+    scf.yield %0 : index
+  }
+  return %res : index
+}
+
+// -----
+
+// CHECK-LABEL: @do_not_erase_infinite_loop_with_side_effect
+//       CHECK:   %[[res:.*]] = scf.for
+//       CHECK:     vector.print
+//       CHECK:   return %[[res]]
+func.func @do_not_erase_infinite_loop_with_side_effect(%init: index) -> index {
+  %lb = arith.constant 3 : index
+  %ub = arith.constant 4 : index
+  %step = arith.constant 0 : index
+  %res = scf.for %iv = %lb to %ub step %step iter_args(%iter = %init) -> index {
+    %0 = arith.addi %iter, %iter : index
+    vector.print %0 : index
+    scf.yield %0 : index
+  }
+  return %res : index
+}
+
+// -----
+
+// CHECK-LABEL: @erase_infinite_scf_while_loop
+//       CHECK:   %[[poison:.*]] = ub.poison : index
+//       CHECK:   return %[[poison]]
+func.func @erase_infinite_scf_while_loop(%init: index) -> index {
+  %res = scf.while (%arg0 = %init) : (index) -> (index) {
+    %true = arith.constant true
+    scf.condition(%true) %arg0 : index
+  } do {
+  ^bb0(%arg1: index):
+    %0 = arith.addi %arg1, %arg1 : index
+    scf.yield %0 : index
+  }
+  return %res : index
+}
diff --git a/mlir/test/Dialect/SCF/ops.mlir b/mlir/test/Dialect/SCF/ops.mlir
index 5930a1df04266..bee08216165b0 100644
--- a/mlir/test/Dialect/SCF/ops.mlir
+++ b/mlir/test/Dialect/SCF/ops.mlir
@@ -12,7 +12,7 @@ func.func @std_for(%arg0 : index, %arg1 : index, %arg2 : index) {
       %max_cmp = arith.cmpi sge, %i0, %i1 : index
       %max = arith.select %max_cmp, %i0, %i1 : index
       scf.for %i2 = %min to %max step %i1 {
-      }
+      } {mustProgress = false}
     }
   }
   return
@@ -25,6 +25,7 @@ func.func @std_for(%arg0 : index, %arg1 : index, %arg2 : index) {
 //  CHECK-NEXT:       %{{.*}} = arith.cmpi sge, %{{.*}}, %{{.*}} : index
 //  CHECK-NEXT:       %{{.*}} = arith.select %{{.*}}, %{{.*}}, %{{.*}} : index
 //  CHECK-NEXT:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
+//  CHECK-NEXT:       } {mustProgress = false}
 
 func.func @std_for_i32(%arg0 : i32, %arg1 : i32, %arg2 : i32) {
   scf.for %i0 = %arg0 to %arg1 step %arg2 : i32 {
@@ -280,8 +281,8 @@ func.func @while() {
     %5:2 = "test.some_operation"(%arg2, %arg3): (i64, f64) -> (i32, f32)
     // CHECK: scf.yield %{{.*}}, %{{.*}} : i32, f32
     scf.yield %5#0, %5#1 : i32, f32
-  // CHECK: attributes {foo = "bar"}
-  } attributes {foo="bar"}
+  // CHECK: attributes {foo = "bar", mustProgress = false}
+  } attributes {foo="bar", mustProgress=false}
   return
 }
 



More information about the Mlir-commits mailing list