[Mlir-commits] [mlir] [mlir][vector] Constrain broadcast->shape_cast folding (PR #190230)

Benoit Jacob llvmlistbot at llvm.org
Mon Apr 6 12:43:27 PDT 2026


https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/190230

>From eab3ba2512200222d88d70f113982d4771775a69 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <benoit.jacob at amd.com>
Date: Thu, 2 Apr 2026 17:53:41 +0000
Subject: [PATCH] fix-ShapeCastBroadcastFolder

Signed-off-by: Benoit Jacob <benoit.jacob at amd.com>
---
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp      | 48 +++++++++++--
 .../Linalg/transform-op-mmt4d-to-fma.mlir     | 69 -------------------
 .../canonicalize/vector-to-shape-cast.mlir    | 16 +++++
 3 files changed, 58 insertions(+), 75 deletions(-)
 delete mode 100644 mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 939816262a2b14..cf2da8ff54a1d3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6703,6 +6703,22 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
   }
 };
 
+/// vector.broadcast has two distinct semantic modes: duplication across leading
+/// dimensions, and stretching of unit source dimensions. Returns the product
+/// of the stretched destination extents; expects dstShape rank >= srcShape's.
+int64_t getBroadcastStretchingFactor(ArrayRef<int64_t> srcShape,
+                                     ArrayRef<int64_t> dstShape) {
+  // Accumulate in int64_t: a plain int could truncate the product of extents.
+  int64_t stretchingFactor = 1;
+  int64_t numLeadingDims = dstShape.size() - srcShape.size();
+  for (int64_t i = 0, e = srcShape.size(); i < e; ++i) {
+    int64_t dstDim = dstShape[numLeadingDims + i];
+    if (srcShape[i] == 1 && dstDim != 1)
+      stretchingFactor *= dstDim;
+  }
+  return stretchingFactor;
+}
+
 /// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
 class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
 public:
@@ -6725,13 +6741,33 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
     // to
     // %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
     VectorType dstVectorType = shapeCastOp.getResultVectorType();
-    if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
-                           BroadcastableToResult::Success) {
-      rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
-          shapeCastOp, dstVectorType, broadcastOp.getSource());
-      return success();
+    ArrayRef<int64_t> dstShape = dstVectorType.getShape();
+    ArrayRef<int64_t> srcShape =
+        srcIsScalar ? ArrayRef<int64_t>{} : srcVectorType.getShape();
+    ArrayRef<int64_t> broadcastShape =
+        broadcastOp.getResultVectorType().getShape();
+
+    if (!srcIsScalar) {
+      if (isBroadcastableTo(srcVectorType, dstVectorType) !=
+          BroadcastableToResult::Success) {
+        return failure();
+      }
+      // Avoid folding if this would result in switching between the two
+      // distinct semantic modes of vector.broadcast (duplication vs
+      // stretching). See https://github.com/llvm/llvm-project/issues/190614.
+      // This is detected by a change in the stretching factor. However, if the
+      // source has a single element, there is no ambiguity.
+      if (srcVectorType.getNumElements() != 1) {
+        if (getBroadcastStretchingFactor(srcShape, dstShape) !=
+            getBroadcastStretchingFactor(srcShape, broadcastShape)) {
+          return failure();
+        }
+      }
     }
-    return failure();
+
+    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shapeCastOp, dstVectorType,
+                                                     broadcastOp.getSource());
+    return success();
   }
 };
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
deleted file mode 100644
index b5c6e610f58f92..00000000000000
--- a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
+++ /dev/null
@@ -1,69 +0,0 @@
-// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
-
-func.func @mmt4d_to_fma(%A: tensor<16x16x8x1xf32>, %B: tensor<16x16x8x1xf32>, %C_in: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
-  %res = linalg.mmt4d
-                   ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
-                   outs(%C_in: tensor<16x16x8x8xf32>)
-                   -> tensor<16x16x8x8xf32>
-  return %res : tensor<16x16x8x8xf32>
-}
-
-
-// CHECK-LABEL:     @mmt4d_to_fma
-// CHECK-COUNT-8:         vector.fma
-
-module attributes {transform.with_named_sequence} {
-  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
-    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
-
-    %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %func : (!transform.op<"func.func">) -> !transform.any_op
-
-    // Step 1: Tile
-    // Tile parallel dims
-    %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d tile_sizes [1, 1, 0, 8, 8, 0]
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
-    // Tile reduction dims
-    %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p tile_sizes [0, 0, 1, 0, 0, 1]
-      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
-
-    // Step 2: Vectorize
-    transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op
-
-    // Step 3: Simplify
-    // vector.multi_reduction --> vector.contract
-    // Generates a 6-dim vector.contract with the dim matching the original MMT4D Op
-    // and with the following split into parallel and reduction dims:
-    //    * parallel, parallel, reduction, parallel, parallel, reduction
-    transform.apply_patterns to %func {
-      transform.apply_patterns.vector.reduction_to_contract
-      // Reduce the rank of xfer ops. This transforms vector.contract to be
-      // more matmul-like and to enable the lowering to outer product Ops.
-      transform.apply_patterns.vector.transfer_permutation_patterns
-    } : !transform.op<"func.func">
-
-    // Hoisting and LICM - not strictly required
-    %func_h = transform.structured.hoist_redundant_vector_transfers %func
-      : (!transform.op<"func.func">) -> !transform.op<"func.func">
-    %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
-      : (!transform.op<"func.func">) -> !transform.any_op
-    transform.apply_licm to %all_loops : !transform.any_op
-    transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
-
-    // Simplify the 6-dim vector.contract into a 3-dim matmul-like
-    // vector.contract with the following split into parallel and reduction
-    // dims:
-    //    * parallel, parallel, reduction
-    transform.apply_patterns to %func_h {
-      transform.apply_patterns.vector.reduction_to_contract
-      transform.apply_patterns.vector.cast_away_vector_leading_one_dim
-      transform.apply_patterns.canonicalization
-    } : !transform.op<"func.func">
-
-    // Step 4: Lower vector.contract to vector.fma via vector.outerproduct
-    transform.apply_patterns to %func_h {
-      transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
-      transform.apply_patterns.vector.lower_outerproduct
-    } : !transform.op<"func.func">
-    transform.yield
-  }
-}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
index 1d082179207dc9..7ae190d0b3c56c 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
@@ -128,3 +128,19 @@ func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
   return %1: vector<1xf32>
 }
 
+// -----
+
+// https://github.com/llvm/llvm-project/issues/190614 — ShapeCastBroadcastFolder
+// must not rewrite shape_cast(broadcast) into a lower-rank broadcast when that
+// changes duplication (new leading dims) vs stretching (source unit dim).
+// For %arg0 = [[1], [2]]: broadcast to 2x2x1 then shape_cast to 2x2 yields
+// [[1, 2], [1, 2]]; folding to broadcast 2x1 -> 2x2 would incorrectly yield
+// [[1, 1], [2, 2]].
+// CHECK-LABEL: @no_fold_bcast_mode_switch
+// CHECK:         vector.broadcast %{{.*}} : vector<2x1xf32> to vector<2x2x1xf32>
+// CHECK-NEXT:    vector.shape_cast %{{.*}} : vector<2x2x1xf32> to vector<2x2xf32>
+func.func @no_fold_bcast_mode_switch(%arg0: vector<2x1xf32>) -> vector<2x2xf32> {
+  %0 = vector.broadcast %arg0 : vector<2x1xf32> to vector<2x2x1xf32>
+  %1 = vector.shape_cast %0 : vector<2x2x1xf32> to vector<2x2xf32>
+  return %1 : vector<2x2xf32>
+}



More information about the Mlir-commits mailing list