[Mlir-commits] [mlir] [MLIR][Linalg] Safely Unwind Block Split After Detensorizing (PR #171918)
Miloš Poletanović
llvmlistbot at llvm.org
Wed Dec 24 04:13:02 PST 2025
https://github.com/milos1397 updated https://github.com/llvm/llvm-project/pull/171918
>From bfee810c08caff18db80e9ea15872d65f201d974 Mon Sep 17 00:00:00 2001
From: Milos Poletanovic <mpoletanovic at syrmia.com>
Date: Thu, 11 Dec 2025 13:48:01 +0100
Subject: [PATCH 1/2] [MLIR][Linalg] Fix: Safely Unwind Block Split After
Detensorizing
Condition the final merge of the dummy 'postEntryBlock' into the 'entryBlock'
on the post-entry block having zero block arguments. The detensorizing conversion
pass often rewrites the function's control-flow graph (CFG), introducing arguments
and back-edges that prevent the blocks from merging.
---
.../Dialect/Linalg/Transforms/Detensorize.cpp | 15 +++++--
.../detensorize_entry_block_skip_merge.mlir | 42 +++++++++++++++++++
2 files changed, 53 insertions(+), 4 deletions(-)
create mode 100644 mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 830905495e759..0ac2beb819370 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -560,10 +560,17 @@ struct LinalgDetensorize
if (failed(applyPatternsGreedily(getOperation(), std::move(canonPatterns))))
signalPassFailure();
- // Get rid of the dummy entry block we created in the beginning to work
- // around dialect conversion signature rewriting.
- rewriter.eraseOp(branch);
- rewriter.mergeBlocks(postEntryBlock, entryBlock);
+ // Only attempt to unwind the initial block split (merge postEntryBlock
+ // back into entryBlock) if the dialect conversion did NOT modify the
+ // signature of postEntryBlock.
+ // If postEntryBlock has arguments, the merge is unsafe because
+ // mergeBlocks requires 0 block arguments and no predecessors.
+ if (postEntryBlock->getNumArguments() == 0) {
+ // Get rid of the dummy entry block we created in the beginning to work
+ // around dialect conversion signature rewriting.
+ rewriter.eraseOp(branch);
+ rewriter.mergeBlocks(postEntryBlock, entryBlock);
+ }
}
};
} // namespace
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir
new file mode 100644
index 0000000000000..ae98cdb7ebc30
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline="builtin.module(func.func(linalg-detensorize))" | FileCheck %s
+
+module {
+ memref.global "private" constant @__constant_4x4xf32 : memref<4x4xf32> = dense<8.899000e+01> {alignment = 64 : i64}
+ func.func private @parallel_compute_fn_with_aligned_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %arg9: index, %arg10: memref<4x4xf32>, %arg11: memref<4x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ cf.br ^bb1(%c0 : index)
+ ^bb1(%0: index): // 2 preds: ^bb0, ^bb5
+ %1 = arith.cmpi slt, %0, %c4 : index
+ cf.cond_br %1, ^bb2(%0 : index), ^bb6
+ ^bb2(%2: index): // pred: ^bb1
+ %3 = arith.addi %2, %c1 : index
+ cf.br ^bb3(%c0 : index)
+ ^bb3(%4: index): // 2 preds: ^bb2, ^bb4
+ %5 = arith.cmpi slt, %4, %c4 : index
+ cf.cond_br %5, ^bb4(%4 : index), ^bb5
+ ^bb4(%6: index): // pred: ^bb3
+ %7 = arith.addi %6, %c1 : index
+ %8 = memref.load %arg10[%2, %6] : memref<4x4xf32>
+ %9 = llvm.intr.tanh(%8) : (f32) -> f32
+ memref.store %9, %arg11[%2, %6] : memref<4x4xf32>
+ cf.br ^bb3(%7 : index)
+ ^bb5: // pred: ^bb3
+ cf.br ^bb1(%3 : index)
+ ^bb6: // pred: ^bb1
+ return
+ }
+}
+
+// CHECK-LABEL: @parallel_compute_fn_with_aligned_loops
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index, %[[ARG4:.+]]: index, %[[ARG5:.+]]: index, %[[ARG6:.+]]: index, %[[ARG7:.+]]: index, %[[ARG8:.+]]: index, %[[ARG9:.+]]: index, %[[ARG10:.+]]: memref<4x4xf32>, %[[ARG11:.+]]: memref<4x4xf32>)
+// CHECK: cf.br ^{{.*}}
+// CHECK: ^{{.*}}:
+// CHECK: arith.cmpi slt
+// CHECK: cf.cond_br
+// CHECK: arith.addi
+// CHECK: memref.load
+// CHECK: llvm.intr.tanh
+// CHECK: memref.store
+// CHECK: return
>From 60a67fb845371fe95ea31d65e9c6e09e76e1b78b Mon Sep 17 00:00:00 2001
From: Milos Poletanovic <mpoletanovic at syrmia.com>
Date: Wed, 24 Dec 2025 11:45:15 +0100
Subject: [PATCH 2/2] Move the test into existing one.
---
.../Linalg/detensorize_entry_block.mlir | 44 +++++++++++++++++++
.../detensorize_entry_block_skip_merge.mlir | 42 ------------------
2 files changed, 44 insertions(+), 42 deletions(-)
delete mode 100644 mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
index 50a2d6bf532aa..28b79568002a1 100644
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block.mlir
@@ -19,3 +19,47 @@ func.func @main(%arg0: tensor<f32>) -> tensor<f32> {
// CHECK: ^{{.*}}:
// CHECK: %[[ELEMENTS:.+]] = tensor.from_elements %[[EXTRACTED]] : tensor<f32>
// CHECK: return %[[ELEMENTS]] : tensor<f32>
+
+
+module {
+ memref.global "private" constant @__constant_4x4xf32 : memref<4x4xf32> = dense<8.899000e+01> {alignment = 64 : i64}
+ func.func private @parallel_compute_fn_with_aligned_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %arg9: index, %arg10: memref<4x4xf32>, %arg11: memref<4x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ cf.br ^bb1(%c0 : index)
+ ^bb1(%0: index): // 2 preds: ^bb0, ^bb4
+ %1 = arith.cmpi slt, %0, %c4 : index
+ cf.cond_br %1, ^bb2(%0 : index), ^bb6
+ ^bb2(%2: index): // pred: ^bb1
+ %3 = arith.addi %2, %c1 : index
+ cf.br ^bb3(%c0 : index)
+ ^bb3(%4: index): // pred: ^bb2
+ %5 = arith.cmpi slt, %4, %c4 : index
+ cf.br ^bb4
+ ^bb4: // pred: ^bb3
+ cf.br ^bb1(%3 : index)
+ ^bb6: // pred: ^bb1
+ return
+ }
+}
+
+// CHECK-LABEL: func.func private @parallel_compute_fn_with_aligned_loops(
+// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index, {{.*}}) {
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C4:.*]] = arith.constant 4 : index
+// CHECK: cf.br ^bb1(%[[C0]] : index)
+// CHECK: ^bb1(%[[VAL_0:.*]]: index):
+// CHECK: %[[CMPI_0:.*]] = arith.cmpi slt, %[[VAL_0]], %[[C4]] : index
+// CHECK: cf.cond_br %[[CMPI_0]], ^bb2(%[[VAL_0]] : index), ^bb5
+// CHECK: ^bb2(%[[VAL_1:.*]]: index):
+// CHECK: %[[ADDI_0:.*]] = arith.addi %[[VAL_1]], %[[C1]] : index
+// CHECK: cf.br ^bb3(%[[C0]] : index)
+// CHECK: ^bb3(%[[VAL_2:.*]]: index):
+// CHECK: %[[CMPI_1:.*]] = arith.cmpi slt, %[[VAL_2]], %[[C4]] : index
+// CHECK: cf.br ^bb4
+// CHECK: ^bb4:
+// CHECK: cf.br ^bb1(%[[ADDI_0]] : index)
+// CHECK: ^bb5:
+// CHECK: return
\ No newline at end of file
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir
deleted file mode 100644
index ae98cdb7ebc30..0000000000000
--- a/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir
+++ /dev/null
@@ -1,42 +0,0 @@
-// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline="builtin.module(func.func(linalg-detensorize))" | FileCheck %s
-
-module {
- memref.global "private" constant @__constant_4x4xf32 : memref<4x4xf32> = dense<8.899000e+01> {alignment = 64 : i64}
- func.func private @parallel_compute_fn_with_aligned_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %arg9: index, %arg10: memref<4x4xf32>, %arg11: memref<4x4xf32>) {
- %c0 = arith.constant 0 : index
- %c1 = arith.constant 1 : index
- %c4 = arith.constant 4 : index
- cf.br ^bb1(%c0 : index)
- ^bb1(%0: index): // 2 preds: ^bb0, ^bb5
- %1 = arith.cmpi slt, %0, %c4 : index
- cf.cond_br %1, ^bb2(%0 : index), ^bb6
- ^bb2(%2: index): // pred: ^bb1
- %3 = arith.addi %2, %c1 : index
- cf.br ^bb3(%c0 : index)
- ^bb3(%4: index): // 2 preds: ^bb2, ^bb4
- %5 = arith.cmpi slt, %4, %c4 : index
- cf.cond_br %5, ^bb4(%4 : index), ^bb5
- ^bb4(%6: index): // pred: ^bb3
- %7 = arith.addi %6, %c1 : index
- %8 = memref.load %arg10[%2, %6] : memref<4x4xf32>
- %9 = llvm.intr.tanh(%8) : (f32) -> f32
- memref.store %9, %arg11[%2, %6] : memref<4x4xf32>
- cf.br ^bb3(%7 : index)
- ^bb5: // pred: ^bb3
- cf.br ^bb1(%3 : index)
- ^bb6: // pred: ^bb1
- return
- }
-}
-
-// CHECK-LABEL: @parallel_compute_fn_with_aligned_loops
-// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index, %[[ARG4:.+]]: index, %[[ARG5:.+]]: index, %[[ARG6:.+]]: index, %[[ARG7:.+]]: index, %[[ARG8:.+]]: index, %[[ARG9:.+]]: index, %[[ARG10:.+]]: memref<4x4xf32>, %[[ARG11:.+]]: memref<4x4xf32>)
-// CHECK: cf.br ^{{.*}}
-// CHECK: ^{{.*}}:
-// CHECK: arith.cmpi slt
-// CHECK: cf.cond_br
-// CHECK: arith.addi
-// CHECK: memref.load
-// CHECK: llvm.intr.tanh
-// CHECK: memref.store
-// CHECK: return
More information about the Mlir-commits
mailing list