[Mlir-commits] [mlir] [MLIR][Linalg] Safely Unwind Block Split After Detensorizing (PR #171918)

Miloš Poletanović llvmlistbot at llvm.org
Wed Dec 24 04:13:02 PST 2025


https://github.com/milos1397 updated https://github.com/llvm/llvm-project/pull/171918

>From bfee810c08caff18db80e9ea15872d65f201d974 Mon Sep 17 00:00:00 2001
From: Milos Poletanovic <mpoletanovic at syrmia.com>
Date: Thu, 11 Dec 2025 13:48:01 +0100
Subject: [PATCH 1/2] [MLIR][Linalg] Fix: Safely Unwind Block Split After
 Detensorizing

Only merge the dummy 'postEntryBlock' back into the 'entryBlock' when the
post-entry block has no block arguments. The detensorizing conversion pass
often rewrites the function's control-flow graph (CFG), introducing block
arguments and back-edges that prevent the two blocks from being merged.
---
 .../Dialect/Linalg/Transforms/Detensorize.cpp | 15 +++++--
 .../detensorize_entry_block_skip_merge.mlir   | 42 +++++++++++++++++++
 2 files changed, 53 insertions(+), 4 deletions(-)
 create mode 100644 mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 830905495e759..0ac2beb819370 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -560,10 +560,17 @@ struct LinalgDetensorize
     if (failed(applyPatternsGreedily(getOperation(), std::move(canonPatterns))))
       signalPassFailure();
 
-    // Get rid of the dummy entry block we created in the beginning to work
-    // around dialect conversion signature rewriting.
-    rewriter.eraseOp(branch);
-    rewriter.mergeBlocks(postEntryBlock, entryBlock);
+    // Only attempt to unwind the initial block split (merge postEntryBlock
+    // back into entryBlock) if the dialect conversion did NOT modify
+    // postEntryBlock block signature.
+    // If the postEntryBlock block has arguments the merge is unsafe because
+    // mergeBlocks requires 0 block arguments and no predecessors.
+    if (postEntryBlock->getNumArguments() == 0) {
+      // Get rid of the dummy entry block we created in the beginning to work
+      // around dialect conversion signature rewriting.
+      rewriter.eraseOp(branch);
+      rewriter.mergeBlocks(postEntryBlock, entryBlock);
+    }
   }
 };
 } // namespace
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir
new file mode 100644
index 0000000000000..ae98cdb7ebc30
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline="builtin.module(func.func(linalg-detensorize))" | FileCheck %s
+
+module {
+  memref.global "private" constant @__constant_4x4xf32 : memref<4x4xf32> = dense<8.899000e+01> {alignment = 64 : i64}
+  func.func private @parallel_compute_fn_with_aligned_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %arg9: index, %arg10: memref<4x4xf32>, %arg11: memref<4x4xf32>) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    cf.br ^bb1(%c0 : index)
+  ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
+    %1 = arith.cmpi slt, %0, %c4 : index
+    cf.cond_br %1, ^bb2(%0 : index), ^bb6
+  ^bb2(%2: index):  // pred: ^bb1
+    %3 = arith.addi %2, %c1 : index
+    cf.br ^bb3(%c0 : index)
+  ^bb3(%4: index):  // 2 preds: ^bb2, ^bb4
+    %5 = arith.cmpi slt, %4, %c4 : index
+    cf.cond_br %5, ^bb4(%4 : index), ^bb5
+  ^bb4(%6: index):  // pred: ^bb3
+    %7 = arith.addi %6, %c1 : index
+    %8 = memref.load %arg10[%2, %6] : memref<4x4xf32>
+    %9 = llvm.intr.tanh(%8) : (f32) -> f32
+    memref.store %9, %arg11[%2, %6] : memref<4x4xf32>
+    cf.br ^bb3(%7 : index)
+  ^bb5:  // pred: ^bb3
+    cf.br ^bb1(%3 : index)
+  ^bb6:  // pred: ^bb1
+    return
+  }
+}
+
+// CHECK-LABEL: @parallel_compute_fn_with_aligned_loops
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index, %[[ARG4:.+]]: index, %[[ARG5:.+]]: index, %[[ARG6:.+]]: index, %[[ARG7:.+]]: index, %[[ARG8:.+]]: index, %[[ARG9:.+]]: index, %[[ARG10:.+]]: memref<4x4xf32>, %[[ARG11:.+]]: memref<4x4xf32>)
+// CHECK: cf.br ^{{.*}}
+// CHECK: ^{{.*}}:
+// CHECK: arith.cmpi slt
+// CHECK: cf.cond_br
+// CHECK: arith.addi
+// CHECK: memref.load
+// CHECK: llvm.intr.tanh
+// CHECK: memref.store
+// CHECK: return

>From 60a67fb845371fe95ea31d65e9c6e09e76e1b78b Mon Sep 17 00:00:00 2001
From: Milos Poletanovic <mpoletanovic at syrmia.com>
Date: Wed, 24 Dec 2025 11:45:15 +0100
Subject: [PATCH 2/2] Move the test into the existing one.

---
 .../Linalg/detensorize_entry_block.mlir       | 44 +++++++++++++++++++
 .../detensorize_entry_block_skip_merge.mlir   | 42 ------------------
 2 files changed, 44 insertions(+), 42 deletions(-)
 delete mode 100644 mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir

diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
index 50a2d6bf532aa..28b79568002a1 100644
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -19,3 +19,47 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
 // CHECK: ^{{.*}}:
 // CHECK:   %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
 // CHECK:   return %[[ELEMENTS]] : tensor<f32>
+
+
+module {
+  memref.global "private" constant @__constant_4x4xf32 : memref<4x4xf32> = dense<8.899000e+01> {alignment = 64 : i64}
+  func.func private @parallel_compute_fn_with_aligned_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %arg9: index, %arg10: memref<4x4xf32>, %arg11: memref<4x4xf32>) {
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c4 = arith.constant 4 : index
+    cf.br ^bb1(%c0 : index)
+  ^bb1(%0: index):  // 2 preds: ^bb0, ^bb4
+    %1 = arith.cmpi slt, %0, %c4 : index
+    cf.cond_br %1, ^bb2(%0 : index), ^bb6
+  ^bb2(%2: index):  // pred: ^bb1
+    %3 = arith.addi %2, %c1 : index
+    cf.br ^bb3(%c0 : index)
+  ^bb3(%4: index):  // 2 preds: ^bb2, ^bb4
+    %5 = arith.cmpi slt, %4, %c4 : index
+    cf.br ^bb4
+  ^bb4:  // pred: ^bb3
+    cf.br ^bb1(%3 : index)
+  ^bb6:  // pred: ^bb1
+    return
+  }
+}
+
+// CHECK-LABEL: func.func private @parallel_compute_fn_with_aligned_loops(
+// CHECK-SAME:    %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, {{.*}}) {
+// CHECK:           %[[C0:.*]] = arith.constant 0 : index
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C4:.*]] = arith.constant 4 : index
+// CHECK:           cf.br ^bb1(%[[C0]] : index)
+// CHECK:         ^bb1(%[[VAL_0:.*]]: index):
+// CHECK:           %[[CMPI_0:.*]] = arith.cmpi slt, %[[VAL_0]], %[[C4]] : index
+// CHECK:           cf.cond_br %[[CMPI_0]], ^bb2(%[[VAL_0]] : index), ^bb5
+// CHECK:         ^bb2(%[[VAL_1:.*]]: index):
+// CHECK:           %[[ADDI_0:.*]] = arith.addi %[[VAL_1]], %[[C1]] : index
+// CHECK:           cf.br ^bb3(%[[C0]] : index)
+// CHECK:         ^bb3(%[[VAL_2:.*]]: index):
+// CHECK:           %[[CMPI_1:.*]] = arith.cmpi slt, %[[VAL_2]], %[[C4]] : index
+// CHECK:           cf.br ^bb4
+// CHECK:         ^bb4:
+// CHECK:           cf.br ^bb1(%[[ADDI_0]] : index)
+// CHECK:         ^bb5:
+// CHECK:           return
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir
deleted file mode 100644
index ae98cdb7ebc30..0000000000000
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir
+++ /dev/null
@@ -1,42 +0,0 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline="builtin.module(func.func(linalg-detensorize))" | FileCheck %s
-
-module {
-  memref.global "private" constant @__constant_4x4xf32 : memref<4x4xf32> = dense<8.899000e+01> {alignment = 64 : i64}
-  func.func private @parallel_compute_fn_with_aligned_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %arg9: index, %arg10: memref<4x4xf32>, %arg11: memref<4x4xf32>) {
-    %c0 = arith.constant 0 : index
-    %c1 = arith.constant 1 : index
-    %c4 = arith.constant 4 : index
-    cf.br ^bb1(%c0 : index)
-  ^bb1(%0: index):  // 2 preds: ^bb0, ^bb5
-    %1 = arith.cmpi slt, %0, %c4 : index
-    cf.cond_br %1, ^bb2(%0 : index), ^bb6
-  ^bb2(%2: index):  // pred: ^bb1
-    %3 = arith.addi %2, %c1 : index
-    cf.br ^bb3(%c0 : index)
-  ^bb3(%4: index):  // 2 preds: ^bb2, ^bb4
-    %5 = arith.cmpi slt, %4, %c4 : index
-    cf.cond_br %5, ^bb4(%4 : index), ^bb5
-  ^bb4(%6: index):  // pred: ^bb3
-    %7 = arith.addi %6, %c1 : index
-    %8 = memref.load %arg10[%2, %6] : memref<4x4xf32>
-    %9 = llvm.intr.tanh(%8) : (f32) -> f32
-    memref.store %9, %arg11[%2, %6] : memref<4x4xf32>
-    cf.br ^bb3(%7 : index)
-  ^bb5:  // pred: ^bb3
-    cf.br ^bb1(%3 : index)
-  ^bb6:  // pred: ^bb1
-    return
-  }
-}
-
-// CHECK-LABEL: @parallel_compute_fn_with_aligned_loops
-// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index, %[[ARG4:.+]]: index, %[[ARG5:.+]]: index, %[[ARG6:.+]]: index, %[[ARG7:.+]]: index, %[[ARG8:.+]]: index, %[[ARG9:.+]]: index, %[[ARG10:.+]]: memref<4x4xf32>, %[[ARG11:.+]]: memref<4x4xf32>)
-// CHECK: cf.br ^{{.*}}
-// CHECK: ^{{.*}}:
-// CHECK: arith.cmpi slt
-// CHECK: cf.cond_br
-// CHECK: arith.addi
-// CHECK: memref.load
-// CHECK: llvm.intr.tanh
-// CHECK: memref.store
-// CHECK: return



More information about the Mlir-commits mailing list