[Mlir-commits] [mlir] [mlir][Transforms] Fix crash in `-remove-dead-values` on private functions (PR #169269)

Matthias Springer llvmlistbot at llvm.org
Thu Nov 27 21:55:09 PST 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/169269

>From a47b28cf95f5fca857184f05f0c3840b24e470d5 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 28 Nov 2025 04:25:38 +0000
Subject: [PATCH 1/3] [mlir][UB] Add `ub.unreachable` operation

---
 mlir/include/mlir/Dialect/UB/IR/UBOps.td      | 20 +++++++++++
 mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp     | 35 +++++++++++++++----
 mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp   | 14 +++++++-
 mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir |  6 ++++
 .../Conversion/UBToSPIRV/ub-to-spirv.mlir     | 15 ++++++++
 mlir/test/Dialect/UB/ops.mlir                 |  6 ++++
 6 files changed, 88 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/UB/IR/UBOps.td b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
index c400a2ef2cc7a..8a354da2db10c 100644
--- a/mlir/include/mlir/Dialect/UB/IR/UBOps.td
+++ b/mlir/include/mlir/Dialect/UB/IR/UBOps.td
@@ -66,4 +66,24 @@ def PoisonOp : UB_Op<"poison", [ConstantLike, Pure]> {
   let hasFolder = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// UnreachableOp
+//===----------------------------------------------------------------------===//
+
+def UnreachableOp : UB_Op<"unreachable", [Terminator]> {
+  let summary = "Unreachable operation.";
+  let description = [{
+    The `unreachable` operation has no defined semantics. This operation
+    indicates that its enclosing basic block is not reachable.
+
+    Example:
+
+    ```
+    ub.unreachable
+    ```
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
 #endif // MLIR_DIALECT_UB_IR_UBOPS_TD
diff --git a/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp b/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp
index 9921a06778dd7..feb04899cb33d 100644
--- a/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp
+++ b/mlir/lib/Conversion/UBToLLVM/UBToLLVM.cpp
@@ -23,8 +23,11 @@ namespace mlir {
 
 using namespace mlir;
 
-namespace {
+//===----------------------------------------------------------------------===//
+// PoisonOpLowering
+//===----------------------------------------------------------------------===//
 
+namespace {
 struct PoisonOpLowering : public ConvertOpToLLVMPattern<ub::PoisonOp> {
   using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
 
@@ -32,13 +35,8 @@ struct PoisonOpLowering : public ConvertOpToLLVMPattern<ub::PoisonOp> {
   matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override;
 };
-
 } // namespace
 
-//===----------------------------------------------------------------------===//
-// PoisonOpLowering
-//===----------------------------------------------------------------------===//
-
 LogicalResult
 PoisonOpLowering::matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
@@ -60,6 +58,29 @@ PoisonOpLowering::matchAndRewrite(ub::PoisonOp op, OpAdaptor adaptor,
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// UnreachableOpLowering
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct UnreachableOpLowering
+    : public ConvertOpToLLVMPattern<ub::UnreachableOp> {
+  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(ub::UnreachableOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+LogicalResult
+
+UnreachableOpLowering::matchAndRewrite(
+    ub::UnreachableOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  rewriter.replaceOpWithNewOp<LLVM::UnreachableOp>(op);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Pass Definition
 //===----------------------------------------------------------------------===//
@@ -93,7 +114,7 @@ struct UBToLLVMConversionPass
 
 void mlir::ub::populateUBToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<PoisonOpLowering>(converter);
+  patterns.add<PoisonOpLowering, UnreachableOpLowering>(converter);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
index 244d214cba196..3831387816eaf 100644
--- a/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
+++ b/mlir/lib/Conversion/UBToSPIRV/UBToSPIRV.cpp
@@ -40,6 +40,17 @@ struct PoisonOpLowering final : OpConversionPattern<ub::PoisonOp> {
   }
 };
 
+struct UnreachableOpLowering final : OpConversionPattern<ub::UnreachableOp> {
+  using Base::Base;
+
+  LogicalResult
+  matchAndRewrite(ub::UnreachableOp op, OpAdaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<spirv::UnreachableOp>(op);
+    return success();
+  }
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -75,5 +86,6 @@ struct UBToSPIRVConversionPass final
 
 void mlir::ub::populateUBToSPIRVConversionPatterns(
     const SPIRVTypeConverter &converter, RewritePatternSet &patterns) {
-  patterns.add<PoisonOpLowering>(converter, patterns.getContext());
+  patterns.add<PoisonOpLowering, UnreachableOpLowering>(converter,
+                                                        patterns.getContext());
 }
diff --git a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir
index 6c0b111d4c2c5..0fe63f5a3a89f 100644
--- a/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir
+++ b/mlir/test/Conversion/UBToLLVM/ub-to-llvm.mlir
@@ -17,3 +17,9 @@ func.func @check_poison() {
   %3 = ub.poison : !llvm.ptr
   return
 }
+
+// CHECK-LABEL: @check_unrechable
+func.func @check_unrechable() {
+// CHECK: llvm.unreachable
+  ub.unreachable
+}
diff --git a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
index f497eb3bc552c..edbe8b8001bba 100644
--- a/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
+++ b/mlir/test/Conversion/UBToSPIRV/ub-to-spirv.mlir
@@ -19,3 +19,18 @@ func.func @check_poison() {
 }
 
 }
+
+// -----
+
+// No successful test because the dialect conversion framework does not convert
+// unreachable blocks.
+
+module attributes {
+  spirv.target_env = #spirv.target_env<
+    #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
+} {
+func.func @check_unrechable() {
+// expected-error @+1{{cannot be used in reachable block}}
+  spirv.Unreachable
+}
+}
diff --git a/mlir/test/Dialect/UB/ops.mlir b/mlir/test/Dialect/UB/ops.mlir
index 724b6b4caac5d..730c1bd1380b8 100644
--- a/mlir/test/Dialect/UB/ops.mlir
+++ b/mlir/test/Dialect/UB/ops.mlir
@@ -38,3 +38,9 @@ func.func @poison_tensor() -> tensor<8x?xf64> {
   %0 = ub.poison : tensor<8x?xf64>
   return %0 : tensor<8x?xf64>
 }
+
+// CHECK-LABEL: func @unreachable()
+//       CHECK:   ub.unreachable
+func.func @unreachable() {
+  ub.unreachable
+}

>From c1b8c55aa730c27e4cc5c9910e2303604f1d50ed Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 24 Nov 2025 02:42:09 +0000
Subject: [PATCH 2/3] [mlir][Transforms] Fix crash in `-remove-dead-values` for
 private functions

---
 mlir/lib/Transforms/RemoveDeadValues.cpp     | 38 ++++++++++++++++++++
 mlir/test/Transforms/remove-dead-values.mlir | 11 ++++++
 2 files changed, 49 insertions(+)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 989c614ef6617..9d4d24c39c116 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -141,6 +141,33 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
   return false;
 }
 
+/// Return true iff at least one value in `values` is dead, given the liveness
+/// information in `la`.
+static bool hasDead(ValueRange values, const DenseSet<Value> &nonLiveSet,
+                    RunLivenessAnalysis &la) {
+  for (Value value : values) {
+    if (nonLiveSet.contains(value)) {
+      LDBG() << "Value " << value << " is already marked non-live (dead)";
+      return true;
+    }
+
+    const Liveness *liveness = la.getLiveness(value);
+    if (!liveness) {
+      LDBG() << "Value " << value
+             << " has no liveness info, conservatively considered live";
+      continue;
+    }
+    if (liveness->isLive) {
+      LDBG() << "Value " << value << " is live according to liveness analysis";
+      continue;
+    } else {
+      LDBG() << "Value " << value << " is dead according to liveness analysis";
+      return true;
+    }
+  }
+  return false;
+}
+
 /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
 /// i-th value in `values` is live, given the liveness information in `la`.
 static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
@@ -260,6 +287,17 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
                             RDVFinalCleanupList &cl) {
+  if (hasDead(op->getOperands(), nonLiveSet, la)) {
+    LDBG() << "Simple op has dead operands, so the op must be dead: "
+           << OpWithFlags(op, OpPrintingFlags().skipRegions());
+    assert(!hasLive(op->getResults(), nonLiveSet, la) &&
+           "expected the op to have no live results");
+    cl.operations.push_back(op);
+    collectNonLiveValues(nonLiveSet, op->getResults(),
+                         BitVector(op->getNumResults(), true));
+    return;
+  }
+
   if (!isMemoryEffectFree(op) || hasLive(op->getResults(), nonLiveSet, la)) {
     LDBG() << "Simple op is not memory effect free or has live results, "
               "preserving it: "
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 4bae85dcf4f7d..af157fc8bc5b0 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -118,6 +118,17 @@ func.func @main(%arg0 : i32) {
 
 // -----
 
+// CHECK-LABEL: func.func private @clean_func_op_remove_side_effecting_op() {
+// CHECK-NEXT:    return
+// CHECK-NEXT:  }
+func.func private @clean_func_op_remove_side_effecting_op(%arg0: i32) -> (i32) {
+  // vector.print has a side effect but the op is dead.
+  vector.print %arg0 : i32
+  return %arg0 : i32
+}
+
+// -----
+
 // %arg0 is not live because it is never used. %arg1 is not live because its
 // user `arith.addi` doesn't have any uses and the value that it is forwarded to
 // (%non_live_0) also doesn't have any uses.

>From d2349cae1f9e84dadeb8c2c9844755706011a01a Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Fri, 28 Nov 2025 05:54:43 +0000
Subject: [PATCH 3/3] address comments

---
 mlir/include/mlir/Transforms/Passes.td       |  1 +
 mlir/lib/Transforms/CMakeLists.txt           |  1 +
 mlir/lib/Transforms/RemoveDeadValues.cpp     | 65 +++++++++++---------
 mlir/test/Transforms/remove-dead-values.mlir | 16 +++++
 4 files changed, 53 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td
index 28b4a01cf0ecd..55addfdb693e4 100644
--- a/mlir/include/mlir/Transforms/Passes.td
+++ b/mlir/include/mlir/Transforms/Passes.td
@@ -248,6 +248,7 @@ def RemoveDeadValues : Pass<"remove-dead-values"> {
     ```
   }];
   let constructor = "mlir::createRemoveDeadValuesPass()";
+  let dependentDialects = ["ub::UBDialect"];
 }
 
 def PrintIRPass : Pass<"print-ir"> {
diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt
index 54b67f5c7a91e..06161293e907f 100644
--- a/mlir/lib/Transforms/CMakeLists.txt
+++ b/mlir/lib/Transforms/CMakeLists.txt
@@ -39,4 +39,5 @@ add_mlir_library(MLIRTransforms
   MLIRSideEffectInterfaces
   MLIRSupport
   MLIRTransformUtils
+  MLIRUBDialect
   )
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 9d4d24c39c116..e9ced064c9884 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -33,6 +33,7 @@
 
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/LivenessAnalysis.h"
+#include "mlir/Dialect/UB/IR/UBOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Dialect.h"
@@ -141,33 +142,6 @@ static bool hasLive(ValueRange values, const DenseSet<Value> &nonLiveSet,
   return false;
 }
 
-/// Return true iff at least one value in `values` is dead, given the liveness
-/// information in `la`.
-static bool hasDead(ValueRange values, const DenseSet<Value> &nonLiveSet,
-                    RunLivenessAnalysis &la) {
-  for (Value value : values) {
-    if (nonLiveSet.contains(value)) {
-      LDBG() << "Value " << value << " is already marked non-live (dead)";
-      return true;
-    }
-
-    const Liveness *liveness = la.getLiveness(value);
-    if (!liveness) {
-      LDBG() << "Value " << value
-             << " has no liveness info, conservatively considered live";
-      continue;
-    }
-    if (liveness->isLive) {
-      LDBG() << "Value " << value << " is live according to liveness analysis";
-      continue;
-    } else {
-      LDBG() << "Value " << value << " is dead according to liveness analysis";
-      return true;
-    }
-  }
-  return false;
-}
-
 /// Return a BitVector of size `values.size()` where its i-th bit is 1 iff the
 /// i-th value in `values` is live, given the liveness information in `la`.
 static BitVector markLives(ValueRange values, const DenseSet<Value> &nonLiveSet,
@@ -287,7 +261,12 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
 static void processSimpleOp(Operation *op, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
                             RDVFinalCleanupList &cl) {
-  if (hasDead(op->getOperands(), nonLiveSet, la)) {
+  // Operations that have dead operands can be erased regardless of their
+  // side effects. The liveness analysis would not have marked an SSA value as
+  // "dead" if it had a side-effecting user that is reachable.
+  bool hasDeadOperand =
+      markLives(op->getOperands(), nonLiveSet, la).flip().any();
+  if (hasDeadOperand) {
     LDBG() << "Simple op has dead operands, so the op must be dead: "
            << OpWithFlags(op, OpPrintingFlags().skipRegions());
     assert(!hasLive(op->getResults(), nonLiveSet, la) &&
@@ -399,6 +378,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
   // block other than the entry block, because every block has a terminator.
   for (Block &block : funcOp.getBlocks()) {
     Operation *returnOp = block.getTerminator();
+    if (!returnOp->hasTrait<OpTrait::ReturnLike>())
+      continue;
     if (returnOp && returnOp->getNumOperands() == numReturns)
       cl.operands.push_back({returnOp, nonLiveRets});
   }
@@ -738,7 +719,11 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
 }
 
 /// Steps to process a `BranchOpInterface` operation:
-/// Iterate through each successor block of `branchOp`.
+///
+/// When a non-forwarded operand is dead (e.g., the condition value of a
+/// conditional branch op), the entire operation is dead.
+///
+/// Otherwise, iterate through each successor block of `branchOp`.
 /// (1) For each successor block, gather all operands from all successors.
 /// (2) Fetch their associated liveness analysis data and collect for future
 ///     removal.
@@ -749,7 +734,22 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
                             RDVFinalCleanupList &cl) {
   LDBG() << "Processing branch op: " << *branchOp;
+
+  // Check for dead non-forwarded operands.
+  BitVector deadNonForwardedOperands =
+      markLives(branchOp->getOperands(), nonLiveSet, la).flip();
   unsigned numSuccessors = branchOp->getNumSuccessors();
+  for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
+    SuccessorOperands successorOperands =
+        branchOp.getSuccessorOperands(succIdx);
+    // Clear forwarded operands so only non-forwarded ones remain set.
+    for (OpOperand &opOperand : successorOperands.getMutableForwardedOperands())
+      deadNonForwardedOperands[opOperand.getOperandNumber()] = false;
+  }
+  if (deadNonForwardedOperands.any()) {
+    cl.operations.push_back(branchOp.getOperation());
+    return;
+  }
 
   for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
     Block *successorBlock = branchOp->getSuccessor(succIdx);
@@ -824,9 +824,14 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
 
   // 3. Operations
   LDBG() << "Cleaning up " << list.operations.size() << " operations";
-  for (auto &op : list.operations) {
+  for (Operation *op : list.operations) {
     LDBG() << "Erasing operation: "
            << OpWithFlags(op, OpPrintingFlags().skipRegions());
+    if (op->hasTrait<OpTrait::IsTerminator>()) {
+      // When erasing a terminator, insert an unreachable op in its place.
+      OpBuilder b(op);
+      ub::UnreachableOp::create(b, op->getLoc());
+    }
     op->dropAllUses();
     op->erase();
   }
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index af157fc8bc5b0..6ebc43cd4d3bd 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -698,3 +698,19 @@ func.func @op_block_have_dead_arg(%arg0: index, %arg1: index, %arg2: i1) {
   // CHECK-NEXT: return
   return
 }
+
+// -----
+
+// CHECK-LABEL: func private @remove_dead_branch_op()
+//  CHECK-NEXT:   ub.unreachable
+//  CHECK-NEXT: ^{{.*}}:
+//  CHECK-NEXT:   return
+//  CHECK-NEXT: ^{{.*}}:
+//  CHECK-NEXT:   return
+func.func private @remove_dead_branch_op(%c: i1, %arg0: i64, %arg1: i64) -> (i64) {
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  return %arg0 : i64
+^bb2:
+  return %arg1 : i64
+}



More information about the Mlir-commits mailing list