[Mlir-commits] [mlir] [mlir][vector] Constrain broadcast->shape_cast folding (PR #190230)
Benoit Jacob
llvmlistbot at llvm.org
Mon Apr 6 08:39:44 PDT 2026
https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/190230
>From f3acf91cce32a0f76ad96aacd5479ab0657652cc Mon Sep 17 00:00:00 2001
From: Benoit Jacob <benoit.jacob at amd.com>
Date: Thu, 2 Apr 2026 17:53:41 +0000
Subject: [PATCH] fix-ShapeCastBroadcastFolder
Signed-off-by: Benoit Jacob <benoit.jacob at amd.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 16 ++++++++++++++++
.../Linalg/transform-op-mmt4d-to-fma.mlir | 7 ++++++-
.../canonicalize/vector-to-shape-cast.mlir | 17 +++++++++++++++++
3 files changed, 39 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 939816262a2b1..3c306c38e1f31 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6725,6 +6725,22 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
// to
// %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
VectorType dstVectorType = shapeCastOp.getResultVectorType();
+ VectorType intermediateType = broadcastOp.getResultVectorType();
+ // Avoid folding if this would result in switching between the two distinct
+ // semantic modes of vector.broadcast (duplication vs stretching).
+ // See https://github.com/llvm/llvm-project/issues/190614.
+ if (!srcIsScalar && srcVectorType.getRank() >= 2) {
+ auto hasUnitDim = [](ArrayRef<int64_t> shape) {
+ return llvm::any_of(shape, [](int64_t d) { return d == 1; });
+ };
+ if (hasUnitDim(srcVectorType.getShape()) &&
+ intermediateType.getRank() > dstVectorType.getRank()) {
+ ArrayRef<int64_t> droppedTrailingDims =
+ intermediateType.getShape().drop_front(dstVectorType.getRank());
+ if (hasUnitDim(droppedTrailingDims))
+ return failure();
+ }
+ }
if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
BroadcastableToResult::Success) {
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
diff --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
index b5c6e610f58f9..a19bcbdabc264 100644
--- a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
@@ -10,7 +10,12 @@ func.func @mmt4d_to_fma(%A: tensor<16x16x8x1xf32>, %B: tensor<16x16x8x1xf32>, %C
// CHECK-LABEL: @mmt4d_to_fma
-// CHECK-COUNT-8: vector.fma
+// CHECK-NOT: linalg.mmt4d
+// Lowering may produce vector.fma (outerproduct) or elementwise arith.mulf +
+// vector.reduction depending on vector.contract lowering; both are valid, but the checks below match only the latter form.
+// CHECK: arith.mulf
+// CHECK: vector.reduction
+// CHECK: vector.transfer_write
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
index 1d082179207dc..c6e9aaf31af73 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
@@ -128,3 +128,20 @@ func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
return %1: vector<1xf32>
}
+// -----
+
+// https://github.com/llvm/llvm-project/issues/190614 — ShapeCastBroadcastFolder
+// must not rewrite shape_cast(broadcast) into a lower-rank broadcast when that would
+// switch between duplication (new leading dims) and stretching (source unit dim).
+// For %arg0 = [[1], [2]]: broadcast to 2x2x1 then shape_cast to 2x2 yields
+// [[1, 2], [1, 2]]; folding to broadcast 2x1 -> 2x2 would incorrectly yield
+// [[1, 1], [2, 2]].
+// CHECK-LABEL: @shape_cast_broadcast_folder_issue_190614
+// CHECK: vector.broadcast %{{.*}} : vector<2x1xf32> to vector<2x2x1xf32>
+// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<2x2x1xf32> to vector<2x2xf32>
+func.func @shape_cast_broadcast_folder_issue_190614(%arg0: vector<2x1xf32>) -> vector<2x2xf32> {
+ %0 = vector.broadcast %arg0 : vector<2x1xf32> to vector<2x2x1xf32>
+ %1 = vector.shape_cast %0 : vector<2x2x1xf32> to vector<2x2xf32>
+ return %1 : vector<2x2xf32>
+}
+
More information about the Mlir-commits
mailing list