[Mlir-commits] [mlir] [mlir][RemoveDeadValues] Simplify branch op handling using ub.poison (PR #182711)

Fedor Nikolaev llvmlistbot at llvm.org
Thu Feb 26 06:58:10 PST 2026


https://github.com/felichita updated https://github.com/llvm/llvm-project/pull/182711

>From 3994508402bbe17a37cf2d137edfa7a7862d171d Mon Sep 17 00:00:00 2001
From: Fedor Nikolaev <fridrixnm at gmail.com>
Date: Sun, 22 Feb 2026 00:01:54 +0100
Subject: [PATCH] [mlir][RemoveDeadValues] Simplify branch op handling using
 ub.poison

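Replace the complex block argument removal logic in processBranchOp with
a simpler design based on ub.poison replacement, similar to the approach
used for region branch ops.

Dead successor operands are now replaced with ub.poison and block
arguments are left intact, relying on the canonicalizer to remove them
once all incoming operands are poison. To support this, cf.br and
cf.cond_br gain a canonicalization that forwards a block argument to
ub.poison when every incoming edge passes that same poison value. The
pass's optional canonicalization step now also collects the patterns of
branch ops (not only region branch ops) and applies them to the whole
module.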

Fixes #182263
---
 .../Dialect/ControlFlow/IR/ControlFlowOps.cpp |  83 +++++++++++-
 mlir/lib/Transforms/RemoveDeadValues.cpp      | 128 ++++--------------
 .../Dialect/ControlFlow/canonicalize.mlir     |  24 ++++
 mlir/test/Transforms/remove-dead-values.mlir  |  93 ++++++++++---
 4 files changed, 204 insertions(+), 124 deletions(-)

diff --git a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
index d2078d8ab5ca5..10f3e74512a1a 100644
--- a/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/ControlFlow/IR/ControlFlowOps.cpp
@@ -86,7 +86,7 @@ LogicalResult AssertOp::canonicalize(AssertOp op, PatternRewriter &rewriter) {
   return failure();
 }
 
-// This side effect models "program termination". 
+// This side effect models "program termination".
 void AssertOp::getEffects(
     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
         &effects) {
@@ -196,9 +196,85 @@ static LogicalResult simplifyPassThroughBr(BranchOp op,
   return success();
 }
 
+/// If every predecessor edge into `dest` passes the same `ub.poison` value for
+/// a block argument, replace all uses of that argument with the poison value.
+/// The argument is then unused and can be removed by other simplifications.
+///
+/// Example:
+///   cf.br ^bb1(%poison : i32)      // pred 1
+///   cf.br ^bb1(%poison : i32)      // pred 2
+/// ^bb1(%arg0: i32):
+///   use(%arg0)
+/// ->
+/// ^bb1(%arg0: i32):
+///   use(%poison)                   // %arg0 now unused, folds away
+///
+/// Edges are inspected one by one: a predecessor whose terminator targets
+/// `dest` through several successor slots (e.g. a cf.switch whose default and
+/// one of its cases both branch to `dest`) contributes one operand per slot.
+/// Uniform values other than ub.poison are intentionally not forwarded.
+/// Returns true if any block argument was rewritten.
+static bool simplifyUniformBlockArgs(Block *dest, PatternRewriter &rewriter) {
+  bool changed = false;
+  for (BlockArgument arg : dest->getArguments()) {
+    if (arg.use_empty())
+      continue;
+
+    // The one value that reaches `arg` along every incoming edge, if any.
+    Value commonValue;
+    bool allSame = true;
+    for (auto it = dest->pred_begin(), e = dest->pred_end(); it != e; ++it) {
+      auto branch = dyn_cast<BranchOpInterface>((*it)->getTerminator());
+      if (!branch) {
+        allSame = false;
+        break;
+      }
+      // Look at the successor slot of *this* edge; taking the first slot that
+      // targets `dest` would miss differing operands on the other slots.
+      SuccessorOperands succOps =
+          branch.getSuccessorOperands(it.getSuccessorIndex());
+      // Out-of-range indices and produced operands both give a null Value.
+      Value incoming = arg.getArgNumber() < succOps.size()
+                           ? succOps[arg.getArgNumber()]
+                           : Value();
+      if (!incoming || (commonValue && commonValue != incoming)) {
+        allSame = false;
+        break;
+      }
+      commonValue = incoming;
+    }
+
+    // Forward the value only if it is the result of a ub.poison op.
+    if (!allSame || !commonValue || !commonValue.getDefiningOp<ub::PoisonOp>())
+      continue;
+    rewriter.replaceAllUsesWith(arg, commonValue);
+    changed = true;
+  }
+  return changed;
+}
+
+namespace {
+/// Runs simplifyUniformBlockArgs on every CondBranchOp successor, forwarding
+/// a successor argument that receives the same ub.poison along every edge.
+struct SimplifyCondBranchBlockArgWithUniformIncomingValues
+    : public OpRewritePattern<CondBranchOp> {
+  SimplifyCondBranchBlockArgWithUniformIncomingValues(MLIRContext *ctx)
+      : OpRewritePattern<CondBranchOp>(ctx, /*benefit=*/0) {}
+
+  LogicalResult matchAndRewrite(CondBranchOp op,
+                                PatternRewriter &rewriter) const override {
+    bool changed = false;
+    for (unsigned i = 0; i < op->getNumSuccessors(); ++i)
+      changed |= simplifyUniformBlockArgs(op->getSuccessor(i), rewriter);
+    return success(changed);
+  }
+};
+} // namespace
+
 LogicalResult BranchOp::canonicalize(BranchOp op, PatternRewriter &rewriter) {
   return success(succeeded(simplifyBrToBlockWithSinglePred(op, rewriter)) ||
-                 succeeded(simplifyPassThroughBr(op, rewriter)));
+                 succeeded(simplifyPassThroughBr(op, rewriter)) ||
+                 simplifyUniformBlockArgs(op.getDest(), rewriter));
 }
 
 void BranchOp::setDest(Block *block) { return setSuccessor(block); }
@@ -484,7 +560,8 @@ void CondBranchOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<SimplifyConstCondBranchPred, SimplifyPassThroughCondBranch,
               SimplifyCondBranchIdenticalSuccessors,
               SimplifyCondBranchFromCondBranchOnSameCondition,
-              CondBranchTruthPropagation, DropUnreachableCondBranch>(context);
+              CondBranchTruthPropagation, DropUnreachableCondBranch,
+              SimplifyCondBranchBlockArgWithUniformIncomingValues>(context);
 }
 
 SuccessorOperands CondBranchOp::getSuccessorOperands(unsigned index) {
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 12a47ba2fb65a..214574503429a 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -20,6 +20,8 @@
 // terminator operands of region branch ops, and,
 // (D) Removes simple and region branch ops that have all non-live results and
 // don't affect memory in any way,
+// (E) Replaces dead operands of branch ops with `ub.poison`, relying on the
+//     canonicalizer to remove the corresponding block arguments.
 //
 // iff
 //
@@ -101,24 +103,11 @@ struct OperandsToCleanup {
   bool replaceWithPoison = false;
 };
 
-struct BlockArgsToCleanup {
-  Block *b;
-  BitVector nonLiveArgs;
-};
-
-struct SuccessorOperandsToCleanup {
-  BranchOpInterface branch;
-  unsigned successorIndex;
-  BitVector nonLiveOperands;
-};
-
 struct RDVFinalCleanupList {
   SmallVector<Operation *> operations;
   SmallVector<FunctionToCleanUp> functions;
   SmallVector<OperandsToCleanup> operands;
   SmallVector<ResultsToCleanup> results;
-  SmallVector<BlockArgsToCleanup> blocks;
-  SmallVector<SuccessorOperandsToCleanup> successorOperands;
 };
 
 // Some helper functions...
@@ -470,11 +459,10 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
 ///
 /// Otherwise, iterate through each successor block of `branchOp`.
 /// (1) For each successor block, gather all operands from all successors.
-/// (2) Fetch their associated liveness analysis data and collect for future
-///     removal.
-/// (3) Identify and collect the dead operands from the successor block
-///     as well as their corresponding arguments.
-
+/// (2) Determine which operands are dead using liveness analysis.
+/// (3) Replace dead successor operands with ub.poison instead of erasing them.
+///     Block arguments are left intact — the canonicalizer will remove them
+///     once it sees all incoming operands are poison.
 static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
                             DenseSet<Value> &nonLiveSet,
                             RDVFinalCleanupList &cl) {
@@ -496,28 +484,20 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la,
     return;
   }
 
+  // For each successor, find dead forwarded operands and
+  // schedule them for replacement with ub.poison.
+  BitVector opNonLive(branchOp->getNumOperands(), false);
   for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) {
-    Block *successorBlock = branchOp->getSuccessor(succIdx);
-
-    // Do (1)
-    SuccessorOperands successorOperands =
-        branchOp.getSuccessorOperands(succIdx);
-    SmallVector<Value> operandValues;
-    for (unsigned operandIdx = 0; operandIdx < successorOperands.size();
-         ++operandIdx) {
-      operandValues.push_back(successorOperands[operandIdx]);
+    for (OpOperand &opOperand :
+         branchOp.getSuccessorOperands(succIdx).getMutableForwardedOperands()) {
+      if (!hasLive(opOperand.get(), nonLiveSet, la))
+        opNonLive.set(opOperand.getOperandNumber());
     }
-
-    // Do (2)
-    BitVector successorNonLive =
-        markLives(operandValues, nonLiveSet, la).flip();
-    collectNonLiveValues(nonLiveSet, successorBlock->getArguments(),
-                         successorNonLive);
-
-    // Do (3)
-    cl.blocks.push_back({successorBlock, successorNonLive});
-    cl.successorOperands.push_back({branchOp, succIdx, successorNonLive});
   }
+
+  if (opNonLive.any())
+    cl.operands.push_back({branchOp.getOperation(), opNonLive,
+                           /*callee=*/nullptr, /*replaceWithPoison=*/true});
 }
 
 /// Create ub.poison ops for the given values. If a value has no uses, return
@@ -558,56 +538,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
   TrackingListener listener;
   IRRewriter rewriter(ctx, &listener);
 
-  // 1. Blocks, We must remove the block arguments and successor operands before
-  // deleting the operation, as they may reside in the region operation.
-  LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
-  for (auto &b : list.blocks) {
-    // blocks that are accessed via multiple codepaths processed once
-    if (b.b->getNumArguments() != b.nonLiveArgs.size())
-      continue;
-    LDBG_OS([&](raw_ostream &os) {
-      os << "Erasing non-live arguments [";
-      llvm::interleaveComma(b.nonLiveArgs.set_bits(), os);
-      os << "] from block #" << b.b->computeBlockNumber() << " in region #"
-         << b.b->getParent()->getRegionNumber() << " of operation "
-         << OpWithFlags(b.b->getParent()->getParentOp(),
-                        OpPrintingFlags().skipRegions().printGenericOpForm());
-    });
-    // Note: Iterate from the end to make sure that that indices of not yet
-    // processes arguments do not change.
-    for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
-      if (!b.nonLiveArgs[i])
-        continue;
-      b.b->getArgument(i).dropAllUses();
-      b.b->eraseArgument(i);
-    }
-  }
-
-  // 2. Successor Operands
-  LDBG() << "Cleaning up " << list.successorOperands.size()
-         << " successor operand lists";
-  for (auto &op : list.successorOperands) {
-    SuccessorOperands successorOperands =
-        op.branch.getSuccessorOperands(op.successorIndex);
-    // blocks that are accessed via multiple codepaths processed once
-    if (successorOperands.size() != op.nonLiveOperands.size())
-      continue;
-    LDBG_OS([&](raw_ostream &os) {
-      os << "Erasing non-live successor operands [";
-      llvm::interleaveComma(op.nonLiveOperands.set_bits(), os);
-      os << "] from successor " << op.successorIndex << " of branch: "
-         << OpWithFlags(op.branch.getOperation(),
-                        OpPrintingFlags().skipRegions().printGenericOpForm());
-    });
-    // it iterates backwards because erase invalidates all successor indexes
-    for (int i = successorOperands.size() - 1; i >= 0; --i) {
-      if (!op.nonLiveOperands[i])
-        continue;
-      successorOperands.erase(i);
-    }
-  }
-
-  // 3. Functions
+  // 1. Functions
   LDBG() << "Cleaning up " << list.functions.size() << " functions";
   // Record which function arguments were erased so we can shrink call-site
   // argument segments for CallOpInterface operations (e.g. ops using
@@ -638,7 +569,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
     (void)f.funcOp.eraseResults(f.nonLiveRets);
   }
 
-  // 4. Operands
+  // 2. Operands
   LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
   for (OperandsToCleanup &o : list.operands) {
     // Handle call-specific cleanup only when we have a cached callee reference.
@@ -696,7 +627,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
     }
   }
 
-  // 5. Results
+  // 3. Results
   LDBG() << "Cleaning up " << list.results.size() << " result lists";
   for (auto &r : list.results) {
     LDBG_OS([&](raw_ostream &os) {
@@ -709,7 +640,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
     dropUsesAndEraseResults(rewriter, r.op, r.nonLive);
   }
 
-  // 6. Operations
+  // 4. Operations
   LDBG() << "Cleaning up " << list.operations.size() << " operations";
   for (Operation *op : list.operations) {
     LDBG() << "Erasing operation: "
@@ -746,7 +677,7 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
     rewriter.eraseOp(op);
   }
 
-  // 7. Remove all dead poison ops.
+  // 5. Remove all dead poison ops.
   for (ub::PoisonOp poisonOp : listener.poisonOps) {
     if (poisonOp.use_empty())
       poisonOp.erase();
@@ -796,20 +727,17 @@ void RemoveDeadValues::runOnOperation() {
   if (!canonicalize)
     return;
 
-  // Canonicalize all region branch ops.
-  SmallVector<Operation *> opsToCanonicalize;
-  module->walk([&](RegionBranchOpInterface regionBranchOp) {
-    opsToCanonicalize.push_back(regionBranchOp.getOperation());
-  });
-  // Collect all canonicalization patterns for region branch ops.
+  // Canonicalize all region branch ops and branch ops.
   RewritePatternSet owningPatterns(context);
   DenseSet<RegisteredOperationName> populatedPatterns;
-  for (Operation *op : opsToCanonicalize)
+  module->walk([&](Operation *op) {
+    if (!isa<RegionBranchOpInterface, BranchOpInterface>(op))
+      return;
     if (std::optional<RegisteredOperationName> info = op->getRegisteredInfo())
       if (populatedPatterns.insert(*info).second)
         info->getCanonicalizationPatterns(owningPatterns, context);
-  if (failed(applyOpPatternsGreedily(opsToCanonicalize,
-                                     std::move(owningPatterns)))) {
+  });
+  if (failed(applyPatternsGreedily(module, std::move(owningPatterns)))) {
     module->emitError("greedy pattern rewrite failed to converge");
     signalPassFailure();
   }
diff --git a/mlir/test/Dialect/ControlFlow/canonicalize.mlir b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
index 21a16784b81b2..cdf8ecd003630 100644
--- a/mlir/test/Dialect/ControlFlow/canonicalize.mlir
+++ b/mlir/test/Dialect/ControlFlow/canonicalize.mlir
@@ -656,3 +656,65 @@ func.func @drop_unreachable_branch_2(%c: i1) {
 ^bb2:
   ub.unreachable
 }
+
+// CHECK-LABEL: @uniform_poison_block_arg_br
+//       CHECK:   %[[P:.*]] = ub.poison : i32
+//       CHECK:   cf.br ^[[BB3:bb[0-9]+]]{{$}}
+//       CHECK:   cf.br ^[[BB3]]{{$}}
+//       CHECK: ^[[BB3]]:
+//  CHECK-NEXT:   "foo.use"(%[[P]]) : (i32) -> ()
+func.func @uniform_poison_block_arg_br(%c: i1) {
+  %0 = ub.poison : i32
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  "foo.a"() : () -> ()
+  cf.br ^bb3(%0 : i32)
+^bb2:
+  "foo.b"() : () -> ()
+  cf.br ^bb3(%0 : i32)
+^bb3(%arg0: i32):
+  "foo.use"(%arg0) : (i32) -> ()
+  return
+}
+
+// CHECK-LABEL: @uniform_poison_block_arg_cond_br
+//       CHECK:   %[[P:.*]] = ub.poison : i32
+//       CHECK:   cf.cond_br %{{.*}}, ^bb1, ^[[USE:bb[0-9]+]]{{$}}
+//       CHECK:   cf.cond_br %{{.*}}, ^[[USE]], ^bb3{{$}}
+//       CHECK: ^[[USE]]:
+//  CHECK-NEXT:   "foo.use"(%[[P]]) : (i32) -> ()
+func.func @uniform_poison_block_arg_cond_br(%c: i1, %d: i1) {
+  %0 = ub.poison : i32
+  cf.cond_br %c, ^bb1, ^bb2(%0 : i32)
+^bb1:
+  "foo.a"() : () -> ()
+  cf.cond_br %d, ^bb2(%0 : i32), ^bb3
+^bb2(%arg0: i32):
+  "foo.use"(%arg0) : (i32) -> ()
+  return
+^bb3:
+  return
+}
+
+// Edges from one predecessor are compared individually: only the switch's
+// default edge passes poison, so the block argument must be kept.
+// CHECK-LABEL: @poison_block_arg_multi_edge_pred
+//       CHECK:   cf.switch
+//       CHECK: ^{{bb[0-9]+}}(%[[ARG:.*]]: i32):
+//  CHECK-NEXT:   "foo.use"(%[[ARG]]) : (i32) -> ()
+func.func @poison_block_arg_multi_edge_pred(%c: i1, %flag: i32, %x: i32) {
+  %0 = ub.poison : i32
+  cf.cond_br %c, ^bb1, ^bb2
+^bb1:
+  "foo.a"() : () -> ()
+  cf.switch %flag : i32, [
+    default: ^bb3(%0 : i32),
+    42: ^bb3(%x : i32)
+  ]
+^bb2:
+  "foo.b"() : () -> ()
+  cf.br ^bb3(%0 : i32)
+^bb3(%arg0: i32):
+  "foo.use"(%arg0) : (i32) -> ()
+  return
+}
diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir
index 87e77b2eb700f..b2113e7439e5b 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -35,21 +35,25 @@ module @named_module_acceptable {
 func.func @acceptable_ir_has_cleanable_loop_of_conditional_and_branch_op(%arg0: i1) {
   %non_live = arith.constant 0 : i32
   // CHECK-NOT: arith.constant
+  // CHECK-CANONICALIZE-NOT: arith.constant
   cf.br ^bb1(%non_live : i32)
-  // CHECK: cf.br ^[[BB1:bb[0-9]+]]
-^bb1(%non_live_1 : i32):
-  // CHECK: ^[[BB1]]:
-  %non_live_5 = arith.constant 1 : i32
-  cf.br ^bb3(%non_live_1, %non_live_5 : i32, i32)
-  // CHECK: cf.br ^[[BB3:bb[0-9]+]]
-  // CHECK-NOT: i32
-^bb3(%non_live_2 : i32, %non_live_6 : i32):
-  // CHECK: ^[[BB3]]:
-  cf.cond_br %arg0, ^bb1(%non_live_2 : i32), ^bb4(%non_live_2 : i32)
-  // CHECK: cf.cond_br %arg0, ^[[BB1]], ^[[BB4:bb[0-9]+]]
-^bb4(%non_live_4 : i32):
-  // CHECK: ^[[BB4]]:
-  return
+  // CHECK: cf.br ^[[BB1:bb[0-9]+]](%{{.*}} : i32)
+  // CHECK-CANONICALIZE: cf.br ^[[BB1:bb[0-9]+]]
+^bb1(%non_live_1 : i32):
+  // CHECK: ^[[BB1]](%{{.*}}: i32):
+  // CHECK-CANONICALIZE: ^[[BB1]]:
+  %non_live_5 = arith.constant 1 : i32
+  cf.br ^bb3(%non_live_1, %non_live_5 : i32, i32)
+  // CHECK: cf.br ^[[BB3:bb[0-9]+]](%{{.*}}, %{{.*}} : i32, i32)
+  // CHECK-CANONICALIZE: cf.cond_br %arg0, ^[[BB1]], ^[[BB2:bb[0-9]+]]
+^bb3(%non_live_2 : i32, %non_live_6 : i32):
+  // CHECK: ^[[BB3]](%{{.*}}: i32, %{{.*}}: i32):
+  cf.cond_br %arg0, ^bb1(%non_live_2 : i32), ^bb4(%non_live_2 : i32)
+  // CHECK: cf.cond_br %arg0, ^[[BB1]](%{{.*}} : i32), ^[[BB4:bb[0-9]+]](%{{.*}} : i32)
+^bb4(%non_live_4 : i32):
+  // CHECK: ^[[BB4]](%{{.*}}: i32):
+  // CHECK-CANONICALIZE: ^[[BB2]]:
+  return
 }
 
 // -----
@@ -345,9 +349,9 @@ func.func private @identity(%arg1 : i32) -> (i32) {
 // Note that this cleanup cannot be done by the `canonicalize` pass.
 //
 // CHECK-CANONICALIZE:       func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) {
+// CHECK-CANONICALIZE:         %[[c10:.*]] = arith.constant 10
 // CHECK-CANONICALIZE-NEXT:    scf.index_switch %[[arg0]]
 // CHECK-CANONICALIZE-NEXT:    case 1 {
-// CHECK-CANONICALIZE-NEXT:      %[[c10:.*]] = arith.constant 10
 // CHECK-CANONICALIZE-NEXT:      memref.store %[[c10]], %[[arg1]][]
 // CHECK-CANONICALIZE:           scf.yield
 // CHECK-CANONICALIZE-NEXT:    }
@@ -476,6 +480,53 @@ func.func @kernel(%arg0: memref<18xf32>) {
 
 // -----
 
+// Test that RemoveDeadValues does not crash when gpu.launch appears in a block
+// with multiple predecessors. The dead branch operand (%c20) must be replaced
+// with ub.poison and the live block argument must remain intact.
+// The NOT checks sit where %c20 or leftover control flow would be printed; a
+// NOT check placed after the last positive match would never fire.
+//
+// CHECK-LABEL: func.func @gpu_launch_in_multi_predecessor_block
+// CHECK: arith.constant true
+// CHECK: cf.cond_br
+// CHECK: arith.constant 10
+// CHECK: cf.br ^[[BB3:bb[0-9]+]](%{{.*}} : i64)
+// CHECK-NOT: arith.constant 20
+// CHECK: ub.poison
+// CHECK-NOT: arith.constant 20
+// CHECK: cf.br ^[[BB3]](%{{.*}} : i64)
+// CHECK: ^[[BB3]](%{{.*}}: i64):
+// CHECK: return
+//
+// CHECK-CANONICALIZE-LABEL: func.func @gpu_launch_in_multi_predecessor_block
+// CHECK-CANONICALIZE-NOT: {{cf\.|arith.constant true|arith.constant 20}}
+// CHECK-CANONICALIZE: %[[C10:.*]] = arith.constant 10 : i64
+// CHECK-CANONICALIZE-NOT: {{cf\.|arith.constant true|arith.constant 20}}
+// CHECK-CANONICALIZE: return %[[C10]] : i64
+module {
+  func.func @gpu_launch_in_multi_predecessor_block() -> i64 {
+    %cond = arith.constant true
+    cf.cond_br %cond, ^bb1, ^bb2
+  ^bb1:
+    %c10 = arith.constant 10 : i64
+    cf.br ^bb3(%c10 : i64)
+  ^bb2:
+    %c20 = arith.constant 20 : i64
+    cf.br ^bb3(%c20 : i64)
+  ^bb3(%arg0: i64):
+    %c1 = arith.constant 1 : index
+    gpu.launch
+      blocks(%bx, %by, %bz) in (%gx = %c1, %gy = %c1, %gz = %c1)
+      threads(%tx, %ty, %tz) in (%bsx = %c1, %bsy = %c1, %bsz = %c1) {
+      %blk_x = gpu.block_id x
+      %thr_x = gpu.thread_id x
+      gpu.terminator
+    }
+    func.return %arg0 : i64
+  }
+}
+
+// -----
 
 // CHECK-LABEL: llvm_unreachable
 // CHECK-LABEL: @fn_with_llvm_unreachable
@@ -768,6 +819,7 @@ func.func @affine_loop_no_use_iv_has_side_effect_op() {
 // CHECK:         return %[[while]]#0
 
 // CHECK-CANONICALIZE-LABEL: func @scf_while_dead_iter_args()
+// CHECK-CANONICALIZE:         %[[p0:.*]] = ub.poison : i32
 // CHECK-CANONICALIZE:         %[[c5:.*]] = arith.constant 5 : i32
 // CHECK-CANONICALIZE:         %[[while:.*]] = scf.while (%[[arg0:.*]] = %[[c5]]) : (i32) -> i32 {
 // CHECK-CANONICALIZE:           vector.print %[[arg0]]
@@ -775,7 +827,6 @@ func.func @affine_loop_no_use_iv_has_side_effect_op() {
 // CHECK-CANONICALIZE:           scf.condition(%[[cmpi]]) %[[arg0]]
 // CHECK-CANONICALIZE:         } do {
 // CHECK-CANONICALIZE:         ^bb0(%[[arg1:.*]]: i32):
-// CHECK-CANONICALIZE:           %[[p0:.*]] = ub.poison : i32
 // CHECK-CANONICALIZE:           scf.yield %[[p0]]
 // CHECK-CANONICALIZE:         }
 // CHECK-CANONICALIZE:         return %[[while]]
@@ -800,6 +851,11 @@ func.func @scf_while_dead_iter_args() -> i32 {
 // -----
 
 // CHECK-LABEL: func.func @replace_dead_operation_results_with_poison
+// CHECK-CANONICALIZE-LABEL: func.func @replace_dead_operation_results_with_poison
+// CHECK-CANONICALIZE: %[[p:.*]] = ub.poison : vector<1xindex>
+// CHECK-CANONICALIZE: return %[[p]]
+// CHECK-CANONICALIZE-NOT: scf.while
+// CHECK-CANONICALIZE-NOT: "test.three"
 func.func @replace_dead_operation_results_with_poison(%0: vector<1xindex>) -> vector<1xindex> {
   %1 = scf.while (%arg0 = %0) : (vector<1xindex>) -> vector<1xindex> {
     %cond = arith.constant true
@@ -813,11 +869,6 @@ func.func @replace_dead_operation_results_with_poison(%0: vector<1xindex>) -> ve
     // the condition itself is well-formed IR. This prevents a crash in the
     // canonicalization phase which happens after the dead value removal phase.
     // Also check that only used results of an erased op are replaced with ub.poison.
-    // CHECK-CANONICALIZE:      %[[COND:.*]] = ub.poison : i1
-    // CHECK-CANONICALIZE-NEXT: %[[NEXT:.*]] = ub.poison : vector<1xindex>
-    // CHECK-CANONICALIZE-NEXT: scf.condition(%[[COND]]) %[[NEXT]]
-    // CHECK-CANONICALIZE-NOT: ub.poison : i32
-    // CHECK-CANONICALIZE-NOT: "test.three"
     %cond, %unused, %next = "test.three"(%1) : (vector<1xindex>) -> (i1, i32, vector<1xindex>)
     scf.condition(%cond) %next : vector<1xindex>
   } do {



More information about the Mlir-commits mailing list