[Mlir-commits] [mlir] [MLIR][Linalg] Safely Unwind Block Split After Detensorizing (PR #171918)
Miloš Poletanović
llvmlistbot at llvm.org
Sun Dec 14 07:56:02 PST 2025
https://github.com/milos1397 updated https://github.com/llvm/llvm-project/pull/171918
>From bfee810c08caff18db80e9ea15872d65f201d974 Mon Sep 17 00:00:00 2001
From: Milos Poletanovic <mpoletanovic at syrmia.com>
Date: Thu, 11 Dec 2025 13:48:01 +0100
Subject: [PATCH] [MLIR][Linalg] Fix: Safely Unwind Block Split After
Detensorizing
Condition the final merge of the dummy 'postEntryBlock' into the 'entryBlock'
on the post-entry block having zero block arguments. The detensorizing conversion
pass often rewrites the function's control-flow graph (CFG), introducing arguments
and back-edges that prevent the blocks from merging.
---
.../Dialect/Linalg/Transforms/Detensorize.cpp | 15 +++++--
.../detensorize_entry_block_skip_merge.mlir | 42 +++++++++++++++++++
2 files changed, 53 insertions(+), 4 deletions(-)
create mode 100644 mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 830905495e759..0ac2beb819370 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -560,10 +560,17 @@ struct LinalgDetensorize
if (failed(applyPatternsGreedily(getOperation(), std::move(canonPatterns))))
signalPassFailure();
- // Get rid of the dummy entry block we created in the beginning to work
- // around dialect conversion signature rewriting.
- rewriter.eraseOp(branch);
- rewriter.mergeBlocks(postEntryBlock, entryBlock);
+ // Only attempt to unwind the initial block split (merge postEntryBlock
+ // back into entryBlock) if the dialect conversion did NOT modify the
+ // signature of postEntryBlock.
+ // If postEntryBlock has arguments, the merge is unsafe because
+ // mergeBlocks requires 0 block arguments and no predecessors.
+ if (postEntryBlock->getNumArguments() == 0) {
+ // Get rid of the dummy entry block we created in the beginning to work
+ // around dialect conversion signature rewriting.
+ rewriter.eraseOp(branch);
+ rewriter.mergeBlocks(postEntryBlock, entryBlock);
+ }
}
};
} // namespace
diff --git a/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir b/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir
new file mode 100644
index 0000000000000..ae98cdb7ebc30
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/detensorize_entry_block_skip_merge.mlir
@@ -0,0 +1,42 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline="builtin.module(func.func(linalg-detensorize))" | FileCheck %s
+
+module {
+ memref.global "private" constant @__constant_4x4xf32 : memref<4x4xf32> = dense<8.899000e+01> {alignment = 64 : i64}
+ func.func private @parallel_compute_fn_with_aligned_loops(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index, %arg8: index, %arg9: index, %arg10: memref<4x4xf32>, %arg11: memref<4x4xf32>) {
+ %c0 = arith.constant 0 : index
+ %c1 = arith.constant 1 : index
+ %c4 = arith.constant 4 : index
+ cf.br ^bb1(%c0 : index)
+ ^bb1(%0: index): // 2 preds: ^bb0, ^bb5
+ %1 = arith.cmpi slt, %0, %c4 : index
+ cf.cond_br %1, ^bb2(%0 : index), ^bb6
+ ^bb2(%2: index): // pred: ^bb1
+ %3 = arith.addi %2, %c1 : index
+ cf.br ^bb3(%c0 : index)
+ ^bb3(%4: index): // 2 preds: ^bb2, ^bb4
+ %5 = arith.cmpi slt, %4, %c4 : index
+ cf.cond_br %5, ^bb4(%4 : index), ^bb5
+ ^bb4(%6: index): // pred: ^bb3
+ %7 = arith.addi %6, %c1 : index
+ %8 = memref.load %arg10[%2, %6] : memref<4x4xf32>
+ %9 = llvm.intr.tanh(%8) : (f32) -> f32
+ memref.store %9, %arg11[%2, %6] : memref<4x4xf32>
+ cf.br ^bb3(%7 : index)
+ ^bb5: // pred: ^bb3
+ cf.br ^bb1(%3 : index)
+ ^bb6: // pred: ^bb1
+ return
+ }
+}
+
+// CHECK-LABEL: @parallel_compute_fn_with_aligned_loops
+// CHECK-SAME: (%[[ARG0:.+]]: index, %[[ARG1:.+]]: index, %[[ARG2:.+]]: index, %[[ARG3:.+]]: index, %[[ARG4:.+]]: index, %[[ARG5:.+]]: index, %[[ARG6:.+]]: index, %[[ARG7:.+]]: index, %[[ARG8:.+]]: index, %[[ARG9:.+]]: index, %[[ARG10:.+]]: memref<4x4xf32>, %[[ARG11:.+]]: memref<4x4xf32>)
+// CHECK: cf.br ^{{.*}}
+// CHECK: ^{{.*}}:
+// CHECK: arith.cmpi slt
+// CHECK: cf.cond_br
+// CHECK: arith.addi
+// CHECK: memref.load
+// CHECK: llvm.intr.tanh
+// CHECK: memref.store
+// CHECK: return
More information about the Mlir-commits mailing list