[Mlir-commits] [mlir] [mlir][UB] Add `ub.unreachable` canonicalization (PR #169873)
Matthias Springer
llvmlistbot at llvm.org
Sat Nov 29 18:19:21 PST 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/169873
>From 04a3be8ccf052e0aa0eb4269228773206dbfdfca Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 28 Nov 2025 04:59:57 +0000
Subject: [PATCH 1/2] [mlir][UB] Add `ub.unreachable` canonicalization
---
.../Dialect/ControlFlow/IR/ControlFlowOps.td | 2 +-
mlir/include/mlir/Dialect/UB/IR/UBOps.h | 4 +++
mlir/include/mlir/Dialect/UB/IR/UBOps.td | 1 +
.../lib/Dialect/ControlFlow/IR/CMakeLists.txt | 1 +
.../Dialect/ControlFlow/IR/ControlFlowOps.cpp | 32 ++++++++++++++++++-
mlir/lib/Dialect/UB/IR/UBOps.cpp | 26 +++++++++++++++
.../Dialect/ControlFlow/canonicalize.mlir | 25 +++++++++++++++
mlir/test/Dialect/UB/canonicalize.mlir | 10 ++++++
8 files changed, 99 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
index a441fd82546e3..c9b4da44ffa01 100644
--- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
@@ -22,7 +22,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def ControlFlow_Dialect : Dialect {
let name = "cf";
let cppNamespace = "::mlir::cf";
- let dependentDialects = ["arith::ArithDialect"];
+ let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
let description = [{
This dialect contains low-level, i.e. non-region based, control flow
constructs. These constructs generally represent control flow directly
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
index 21de5cb0c182a..02081e2d6d15f 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
@@ -9,6 +9,10 @@
#ifndef MLIR_DIALECT_UB_IR_OPS_H
#define MLIR_DIALECT_UB_IR_OPS_H
+namespace mlir {
+class PatternRewriter;
+}
+
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.td b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
index 1bff39add691e..95c52db404912 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.td
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
@@ -84,6 +84,7 @@ def UnreachableOp : UB_Op<"unreachable", [Terminator]> {
}];
let assemblyFormat = "attr-dict";
+ let hasCanonicalizeMethod = 1;
}
#endif // MLIR_DIALECT_UB_IR_UBOPS_TD
diff --git a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
index 58551bb435c86..05a787fa53ec3 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/ControlFlow/IR/CMakeLists.txt
@@ -12,4 +12,5 @@ add_mlir_dialect_library(MLIRControlFlowDialect
MLIRControlFlowInterfaces
MLIRIR
MLIRSideEffectInterfaces
+ MLIRUBDialect
)
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index f1da1a125e9ef..218758bc0aac5 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
@@ -445,6 +446,35 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
return success(replaced);
}
};
+
+struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> {
+ using OpRewritePattern<CondBranchOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(CondBranchOp condbr,
+ PatternRewriter &rewriter) const override {
+ // If the "true" destination is unreachable, branch to the "false"
+ // destination.
+ Block *trueDest = condbr.getTrueDest();
+ Block *falseDest = condbr.getFalseDest();
+ if (llvm::hasSingleElement(*trueDest) &&
+ isa<ub::UnreachableOp>(trueDest->getTerminator())) {
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, falseDest,
+ condbr.getFalseOperands());
+ return success();
+ }
+
+ // If the "false" destination is unreachable, branch to the "true"
+ // destination.
+ if (llvm::hasSingleElement(*falseDest) &&
+ isa<ub::UnreachableOp>(falseDest->getTerminator())) {
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, trueDest,
+ condbr.getTrueOperands());
+ return success();
+ }
+
+ return failure();
+ }
+};
} // namespace
void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -452,7 +482,7 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
SimplifyCondBranchIdenticalSuccessors,
SimplifyCondBranchFromCondBranchOnSameCondition,
- CondBranchTruthPropagation>(context);
+ CondBranchTruthPropagation, DropUnreachableCondBranch>(context);
}
SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index ee523f9522953..419e3f9d76fb2 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -12,6 +12,7 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/UB/IR/UBOpsDialect.cpp.inc"
@@ -57,8 +58,33 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
return nullptr;
}
+//===----------------------------------------------------------------------===//
+// PoisonOp
+//===----------------------------------------------------------------------===//
+
OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
+//===----------------------------------------------------------------------===//
+// UnreachableOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult UnreachableOp::canonicalize(UnreachableOp unreachableOp,
+ PatternRewriter &rewriter) {
+ Block *block = unreachableOp->getBlock();
+ if (llvm::hasSingleElement(*block))
+ return rewriter.notifyMatchFailure(
+ unreachableOp, "unreachable op is the only operation in the block");
+
+ // Erase all other operations in the block. They must be dead.
+ for (Operation &op : llvm::make_early_inc_range(*block)) {
+ if (&op == unreachableOp.getOperation())
+ continue;
+ op.dropAllUses();
+ rewriter.eraseOp(&op);
+ }
+ return success();
+}
+
#include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc"
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
index 17f7d28ba59fb..75dec6dacde91 100644
--- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir
+++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
@@ -634,3 +634,28 @@ func.func @unsimplified_cycle_2(%c : i1) {
^bb7:
cf.br ^bb6
}
+
+// CHECK-LABEL: @drop_unreachable_branch_1
+// CHECK-NEXT: "test.foo"() : () -> ()
+// CHECK-NEXT: return
+func.func @drop_unreachable_branch_1(%c: i1) {
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ "test.foo"() : () -> ()
+ return
+^bb2:
+ "test.bar"() : () -> ()
+ ub.unreachable
+}
+
+// CHECK-LABEL: @drop_unreachable_branch_2
+// CHECK-NEXT: ub.unreachable
+func.func @drop_unreachable_branch_2(%c: i1) {
+ cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+ "test.foo"() : () -> ()
+ ub.unreachable
+^bb2:
+ "test.bar"() : () -> ()
+ ub.unreachable
+}
diff --git a/mlir/test/Dialect/UB/canonicalize.mlir b/mlir/test/Dialect/UB/canonicalize.mlir
index c3f286e49b09b..74ba9f1932384 100644
--- a/mlir/test/Dialect/UB/canonicalize.mlir
+++ b/mlir/test/Dialect/UB/canonicalize.mlir
@@ -9,3 +9,13 @@ func.func @merge_poison() -> (i32, i32) {
%1 = ub.poison : i32
return %0, %1 : i32, i32
}
+
+// -----
+
+// CHECK-LABEL: func @drop_ops_before_unreachable()
+// CHECK-NEXT: ub.unreachable
+func.func @drop_ops_before_unreachable() {
+ "test.foo"() : () -> ()
+ "test.bar"() : () -> ()
+ ub.unreachable
+}
>From 143a40b335b87f85b6ef35054272c1b32763d661 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Sun, 30 Nov 2025 02:18:02 +0000
Subject: [PATCH 2/2] remove UB canonicalization
---
.../Dialect/ControlFlow/IR/ControlFlowOps.td | 2 +-
mlir/include/mlir/Dialect/UB/IR/UBOps.h | 4 ---
mlir/include/mlir/Dialect/UB/IR/UBOps.td | 1 -
.../Dialect/ControlFlow/IR/ControlFlowOps.cpp | 2 ++
mlir/lib/Dialect/UB/IR/UBOps.cpp | 26 -------------------
.../Dialect/ControlFlow/canonicalize.mlir | 3 ---
mlir/test/Dialect/UB/canonicalize.mlir | 10 -------
7 files changed, 3 insertions(+), 45 deletions(-)
diff --git a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
index c9b4da44ffa01..a441fd82546e3 100644
--- a/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td
@@ -22,7 +22,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
def ControlFlow_Dialect : Dialect {
let name = "cf";
let cppNamespace = "::mlir::cf";
- let dependentDialects = ["arith::ArithDialect", "ub::UBDialect"];
+ let dependentDialects = ["arith::ArithDialect"];
let description = [{
This dialect contains low-level, i.e. non-region based, control flow
constructs. These constructs generally represent control flow directly
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.h b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
index 02081e2d6d15f..21de5cb0c182a 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.h
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.h
@@ -9,10 +9,6 @@
#ifndef MLIR_DIALECT_UB_IR_OPS_H
#define MLIR_DIALECT_UB_IR_OPS_H
-namespace mlir {
-class PatternRewriter;
-}
-
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpImplementation.h"
diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.td b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
index 95c52db404912..1bff39add691e 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.td
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
@@ -84,7 +84,6 @@ def UnreachableOp : UB_Op<"unreachable", [Terminator]> {
}];
let assemblyFormat = "attr-dict";
- let hasCanonicalizeMethod = 1;
}
#endif // MLIR_DIALECT_UB_IR_UBOPS_TD
diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index 218758bc0aac5..d2078d8ab5ca5 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -447,6 +447,8 @@ struct CondBranchTruthPropagation : public OpRewritePattern<CondBranchOp> {
}
};
+/// If the destination block of a conditional branch contains only
+/// ub.unreachable, unconditionally branch to the other destination.
struct DropUnreachableCondBranch : public OpRewritePattern<CondBranchOp> {
using OpRewritePattern<CondBranchOp>::OpRewritePattern;
diff --git a/mlir/lib/Dialect/UB/IR/UBOps.cpp b/mlir/lib/Dialect/UB/IR/UBOps.cpp
index 419e3f9d76fb2..ee523f9522953 100644
--- a/mlir/lib/Dialect/UB/IR/UBOps.cpp
+++ b/mlir/lib/Dialect/UB/IR/UBOps.cpp
@@ -12,7 +12,6 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
-#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/UB/IR/UBOpsDialect.cpp.inc"
@@ -58,33 +57,8 @@ Operation *UBDialect::materializeConstant(OpBuilder &builder, Attribute value,
return nullptr;
}
-//===----------------------------------------------------------------------===//
-// PoisonOp
-//===----------------------------------------------------------------------===//
-
OpFoldResult PoisonOp::fold(FoldAdaptor /*adaptor*/) { return getValue(); }
-//===----------------------------------------------------------------------===//
-// UnreachableOp
-//===----------------------------------------------------------------------===//
-
-LogicalResult UnreachableOp::canonicalize(UnreachableOp unreachableOp,
- PatternRewriter &rewriter) {
- Block *block = unreachableOp->getBlock();
- if (llvm::hasSingleElement(*block))
- return rewriter.notifyMatchFailure(
- unreachableOp, "unreachable op is the only operation in the block");
-
- // Erase all other operations in the block. They must be dead.
- for (Operation &op : llvm::make_early_inc_range(*block)) {
- if (&op == unreachableOp.getOperation())
- continue;
- op.dropAllUses();
- rewriter.eraseOp(&op);
- }
- return success();
-}
-
#include "mlir/Dialect/UB/IR/UBOpsInterfaces.cpp.inc"
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
index 75dec6dacde91..21a16784b81b2 100644
--- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir
+++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
@@ -644,7 +644,6 @@ func.func @drop_unreachable_branch_1(%c: i1) {
"test.foo"() : () -> ()
return
^bb2:
- "test.bar"() : () -> ()
ub.unreachable
}
@@ -653,9 +652,7 @@ func.func @drop_unreachable_branch_1(%c: i1) {
func.func @drop_unreachable_branch_2(%c: i1) {
cf.cond_br %c, ^bb1, ^bb2
^bb1:
- "test.foo"() : () -> ()
ub.unreachable
^bb2:
- "test.bar"() : () -> ()
ub.unreachable
}
diff --git a/mlir/test/Dialect/UB/canonicalize.mlir b/mlir/test/Dialect/UB/canonicalize.mlir
index 74ba9f1932384..c3f286e49b09b 100644
--- a/mlir/test/Dialect/UB/canonicalize.mlir
+++ b/mlir/test/Dialect/UB/canonicalize.mlir
@@ -9,13 +9,3 @@ func.func @merge_poison() -> (i32, i32) {
%1 = ub.poison : i32
return %0, %1 : i32, i32
}
-
-// -----
-
-// CHECK-LABEL: func @drop_ops_before_unreachable()
-// CHECK-NEXT: ub.unreachable
-func.func @drop_ops_before_unreachable() {
- "test.foo"() : () -> ()
- "test.bar"() : () -> ()
- ub.unreachable
-}
More information about the Mlir-commits mailing list