[Mlir-commits] [mlir] [mlir][Interfaces] Fix canonicalize crash from dead successor-input pruning (PR #192222)

Hocky Yudhiono llvmlistbot at llvm.org
Wed Apr 15 02:40:48 PDT 2026


https://github.com/hockyy created https://github.com/llvm/llvm-project/pull/192222

This fixes a canonicalization crash triggered by `scf.for` when running with expensive pattern checks. `RemoveDeadRegionBranchOpSuccessorInputs` relied only on edge-local successor-input ties. For single-region region-branch ops (like `scf.for`) whose control flow has been pruned by constants, this could miss the structural coupling between op results and region block arguments. As a result, the rewrite could temporarily drop one side but not the other, producing invalid IR (an `scf.for` mismatch between loop-carried values and defined values) and aborting under expensive checks. The fix additionally ties parent successor inputs to region successor inputs by slot for single-region region-branch ops, so that canonicalization only erases such values together.

>From 15d73c01af363e5e1e0d5aefc4cb60a58a13e7f5 Mon Sep 17 00:00:00 2001
From: Hocky Yudhiono <hocky.yudhiono at gmail.com>
Date: Wed, 15 Apr 2026 17:38:51 +0800
Subject: [PATCH] [mlir][Interfaces] Fix canonicalize crash from dead
 successor-input pruning in single-region region-branch ops

---
 mlir/lib/Interfaces/ControlFlowInterfaces.cpp | 39 +++++++++++++++++++
 mlir/test/Dialect/SCF/canonicalize.mlir       | 21 ++++++++++
 2 files changed, 60 insertions(+)

diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
index c3fb73acf5ef0..7607f7068d708 100644
--- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
+++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp
@@ -799,6 +799,44 @@ static llvm::EquivalenceClasses<Value> computeTiedSuccessorInputs(
   return tiedSuccessorInputs;
 }
 
+/// Successor input mappings are edge-local: when control flow paths are pruned
+/// by constants, some edge pairs may disappear and a pure edge-based tie
+/// relation can miss structural couplings between op results and region block
+/// arguments. For single-region region branch ops (e.g. `scf.for`), tie parent
+/// successor inputs and region successor inputs by slot so canonicalizations
+/// only erase such values together.
+static void tieRegionAndParentSuccessorInputs(
+    RegionBranchOpInterface regionBranchOp,
+    llvm::EquivalenceClasses<Value> &tiedSuccessorInputs) {
+  if (regionBranchOp->getNumRegions() != 1)
+    return;
+
+  ValueRange parentInputs =
+      regionBranchOp.getSuccessorInputs(RegionSuccessor::parent());
+  if (parentInputs.empty())
+    return;
+
+  SmallVector<ValueRange> regionInputs;
+  for (Region &region : regionBranchOp->getRegions()) {
+    ValueRange inputs =
+        regionBranchOp.getSuccessorInputs(RegionSuccessor(&region));
+    if (!inputs.empty())
+      regionInputs.push_back(inputs);
+  }
+  if (regionInputs.empty())
+    return;
+
+  for (ValueRange inputs : regionInputs) {
+    unsigned commonInputCount =
+        std::min<unsigned>(parentInputs.size(), inputs.size());
+    for (unsigned i = 0; i < commonInputCount; ++i) {
+      tiedSuccessorInputs.insert(parentInputs[i]);
+      tiedSuccessorInputs.insert(inputs[i]);
+      tiedSuccessorInputs.unionSets(parentInputs[i], inputs[i]);
+    }
+  }
+}
+
 /// Remove dead successor inputs from region branch ops. A successor input is
 /// dead if it has no uses. Successor inputs come in sets of tied values: if
 /// you remove one value from a set, you must remove all values from the set.
@@ -856,6 +894,7 @@ struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern {
     regionBranchOp.getSuccessorOperandInputMapping(operandToInputs);
     llvm::EquivalenceClasses<Value> tiedSuccessorInputs =
         computeTiedSuccessorInputs(operandToInputs);
+    tieRegionAndParentSuccessorInputs(regionBranchOp, tiedSuccessorInputs);
 
     // Determine which values to remove and group them by block and operation.
     SmallVector<Value> valuesToRemove;
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir
index c324d34942bf8..8ae709d30bec5 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2360,3 +2360,24 @@ func.func @fold_tensor_cast_into_forall_non_sequential_writes(
   // %0#0 contains %arg1 data; %0#1 contains %arg0 data.
   return %0#0, %0#1 : tensor<?x32xf32>, tensor<?x32xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @single_iteration_loop_keeps_tied_inputs_valid
+func.func @single_iteration_loop_keeps_tied_inputs_valid() {
+  // CHECK: %[[LB:.*]] = arith.constant 42
+  %c42 = arith.constant 42 : index
+  %c43 = arith.constant 43 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: %[[INIT:.*]] = "test.init"
+  %init = "test.init"() : () -> i32
+  // CHECK-NOT: scf.for
+  // CHECK: %[[VAL:.*]] = "test.op"(%[[LB]], %[[INIT]])
+  %0 = scf.for %i = %c42 to %c43 step %c1 iter_args(%arg = %init) -> (i32) {
+    %1 = "test.op"(%i, %init) : (index, i32) -> i32
+    scf.yield %1 : i32
+  }
+  // CHECK: "test.consume"(%[[VAL]]) : (i32) -> ()
+  "test.consume"(%0) : (i32) -> ()
+  return
+}



More information about the Mlir-commits mailing list