[Mlir-commits] [mlir] [MLIR][Linalg] Safely Unwind Block Split After Detensorizing (PR #171918)

Miloš Poletanović llvmlistbot at llvm.org
Sun Dec 14 07:56:02 PST 2025


https://github.com/milos1397 updated https://github.com/llvm/llvm-project/pull/171918

>From bfee810c08caff18db80e9ea15872d65f201d974 Mon Sep 17 00:00:00 2001
From: Milos Poletanovic <mpoletanovic at syrmia.com>
Date: Thu, 11 Dec 2025 13:48:01 +0100
Subject: [PATCH] [MLIR][Linalg] Fix: Safely Unwind Block Split After
 Detensorizing

Only merge the dummy 'postEntryBlock' back into 'entryBlock' at the end of the
pass when the post-entry block has no block arguments. The detensorizing
conversion can rewrite the function's control-flow graph (CFG), introducing
block arguments and back-edges that prevent the two blocks from being merged;
in that case the initial block split is now left in place.
---
 .../Dialect/Linalg/Transforms/Detensorize.cpp | 15 +++++--
 .../detensorize_entry_block_skip_merge.mlir   | 42 +++++++++++++++++++
 2 files changed, 53 insertions(+), 4 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 830905495e759..0ac2beb819370 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -560,10 +560,17 @@ struct LinalgDetensorize
     if (failed(applyPatternsGreedily(getOperation(), std::move(canonPatterns))))
       signalPassFailure();
 
-    // Get rid of the dummy entry block we created in the beginning to work
-    // around dialect conversion signature rewriting.
-    rewriter.eraseOp(branch);
-    rewriter.mergeBlocks(postEntryBlock, entryBlock);
+    // Only attempt to unwind the initial block split (merge postEntryBlock
+    // back into entryBlock) if the dialect conversion did NOT modify
+    // postEntryBlock block signature.
+    // If the postEntryBlock block has arguments the merge is unsafe because
+    // mergeBlocks requires 0 block arguments and no predecessors.
+    if (postEntryBlock->getNumArguments() == 0) {
+      // Get rid of the dummy entry block we created in the beginning to work
+      // around dialect conversion signature rewriting.
+      rewriter.eraseOp(branch);
+      rewriter.mergeBlocks(postEntryBlock, entryBlock);
+    }
   }
 };
 } // namespace
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir
new file mode 100644
index 0000000000000..ae98cdb7ebc30
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline="builtin.module(func.func(linalg-detensorize))" | FileCheck %s
+
+module {
+  memref.global "private" constant @__constant_4x4xf32 : memref<4x4xf32> = dense<8.899000e+01> {alignment = 64 : i64}
+  func.func private @parallel_compute_fn_with_aligned_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %arg9: index, %arg10: memref<4x4xf32>, %arg11: memref<4x4xf32>) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    cf.br ^bb1(%c0 : index)
+  ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
+    %1 = arith.cmpi slt, %0, %c4 : index
+    cf.cond_br %1, ^bb2(%0 : index), ^bb6
+  ^bb2(%2: index):  // pred: ^bb1
+    %3 = arith.addi %2, %c1 : index
+    cf.br ^bb3(%c0 : index)
+  ^bb3(%4: index):  // 2 preds: ^bb2, ^bb4
+    %5 = arith.cmpi slt, %4, %c4 : index
+    cf.cond_br %5, ^bb4(%4 : index), ^bb5
+  ^bb4(%6: index):  // pred: ^bb3
+    %7 = arith.addi %6, %c1 : index
+    %8 = memref.load %arg10[%2, %6] : memref<4x4xf32>
+    %9 = llvm.intr.tanh(%8) : (f32) -> f32
+    memref.store %9, %arg11[%2, %6] : memref<4x4xf32>
+    cf.br ^bb3(%7 : index)
+  ^bb5:  // pred: ^bb3
+    cf.br ^bb1(%3 : index)
+  ^bb6:  // pred: ^bb1
+    return
+  }
+}
+
+// CHECK-LABEL: @parallel_compute_fn_with_aligned_loops
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index, %[[ARG4:.+]]: index, %[[ARG5:.+]]: index, %[[ARG6:.+]]: index, %[[ARG7:.+]]: index, %[[ARG8:.+]]: index, %[[ARG9:.+]]: index, %[[ARG10:.+]]: memref<4x4xf32>, %[[ARG11:.+]]: memref<4x4xf32>)
+// CHECK: cf.br ^{{.*}}
+// CHECK: ^{{.*}}:
+// CHECK: arith.cmpi slt
+// CHECK: cf.cond_br
+// CHECK: arith.addi
+// CHECK: memref.load
+// CHECK: llvm.intr.tanh
+// CHECK: memref.store
+// CHECK: return



More information about the Mlir-commits mailing list