[Mlir-commits] [mlir] [mlir][vector] Constrain broadcast->shape_cast folding (PR #190230)

Benoit Jacob llvmlistbot at llvm.org
Mon Apr 6 08:39:44 PDT 2026


https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/190230

>From f3acf91cce32a0f76ad96aacd5479ab0657652cc Mon Sep 17 00:00:00 2001
From: Benoit Jacob <benoit.jacob at amd.com>
Date: Thu, 2 Apr 2026 17:53:41 +0000
Subject: [PATCH] fix-ShapeCastBroadcastFolder

Signed-off-by: Benoit Jacob <benoit.jacob at amd.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp        | 16 ++++++++++++++++
 .../Linalg/transform-op-mmt4d-to-fma.mlir       |  7 ++++++-
 .../canonicalize/vector-to-shape-cast.mlir      | 17 +++++++++++++++++
 3 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 939816262a2b1..3c306c38e1f31 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6725,6 +6725,22 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     // to
     // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
     VectorType dstVectorType = shapeCastOp.getResultVectorType();
+    VectorType intermediateType = broadcastOp.getResultVectorType();
+    // Avoid folding if this would result in switching between the two distinct
+    // semantic modes of vector.broadcast (duplication vs stretching).
+    // See https://github.com/llvm/llvm-project/issues/190614.
+    if (!srcIsScalar && srcVectorType.getRank() >= 2) {
+      auto hasUnitDim = [](ArrayRef<int64_t> shape) {
+        return llvm::any_of(shape, [](int64_t d) { return d == 1; });
+      };
+      if (hasUnitDim(srcVectorType.getShape()) &&
+          intermediateType.getRank() > dstVectorType.getRank()) {
+        ArrayRef<int64_t> droppedTrailingDims =
+            intermediateType.getShape().drop_front(dstVectorType.getRank());
+        if (hasUnitDim(droppedTrailingDims))
+          return failure();
+      }
+    }
     if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
                            BroadcastableToResult::Success) {
       rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
diff --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
index b5c6e610f58f9..a19bcbdabc264 100644
--- a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
@@ -10,7 +10,12 @@ func.func @mmt4d_to_fma(%A: tensor<16x16x8x1xf32>, %B: tensor<16x16x8x1xf32>, %C
 
 
 // CHECK-LABEL:     @mmt4d_to_fma
-// CHECK-COUNT-8:         vector.fma
+// CHECK-NOT:         linalg.mmt4d
+// Depending on the vector.contract lowering, the result may use vector.fma
+// (outerproduct) or arith.mulf + vector.reduction; the latter is checked here.
+// CHECK:               arith.mulf
+// CHECK:               vector.reduction
+// CHECK:               vector.transfer_write
 
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
index 1d082179207dc..c6e9aaf31af73 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
@@ -128,3 +128,20 @@ func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
   return %1: vector<1xf32>
 }
 
+// -----
+
+// https://github.com/llvm/llvm-project/issues/190614 — ShapeCastBroadcastFolder
+// must not rewrite shape_cast(broadcast) into a lower-rank broadcast when that
+// turns duplication (new leading dims) into stretching (a source unit dim).
+// For %arg0 = [[1], [2]]: broadcast to 2x2x1 then shape_cast to 2x2 yields
+// [[1, 2], [1, 2]]; folding to broadcast 2x1 -> 2x2 would incorrectly yield
+// [[1, 1], [2, 2]].
+// CHECK-LABEL: @shape_cast_broadcast_folder_issue_190614
+// CHECK:         vector.broadcast %{{.*}} : vector<2x1xf32> to vector<2x2x1xf32>
+// CHECK-NEXT:    vector.shape_cast %{{.*}} : vector<2x2x1xf32> to vector<2x2xf32>
+func.func @shape_cast_broadcast_folder_issue_190614(%arg0: vector<2x1xf32>) -> vector<2x2xf32> {
+  %0 = vector.broadcast %arg0 : vector<2x1xf32> to vector<2x2x1xf32>
+  %1 = vector.shape_cast %0 : vector<2x2x1xf32> to vector<2x2xf32>
+  return %1 : vector<2x2xf32>
+}
+



More information about the Mlir-commits mailing list