[Mlir-commits] [mlir] [mlir][vector] Constrain broadcast->shape_cast folding (PR #190230)
Benoit Jacob
llvmlistbot at llvm.org
Mon Apr 6 12:43:27 PDT 2026
https://github.com/bjacob updated https://github.com/llvm/llvm-project/pull/190230
>From eab3ba2512200222d88d70f113982d4771775a69 Mon Sep 17 00:00:00 2001
From: Benoit Jacob <benoit.jacob at amd.com>
Date: Thu, 2 Apr 2026 17:53:41 +0000
Subject: [PATCH] fix-ShapeCastBroadcastFolder
Signed-off-by: Benoit Jacob <benoit.jacob at amd.com>
---
mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 48 +++++++++++--
.../Linalg/transform-op-mmt4d-to-fma.mlir | 69 -------------------
.../canonicalize/vector-to-shape-cast.mlir | 16 +++++
3 files changed, 58 insertions(+), 75 deletions(-)
delete mode 100644 mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 939816262a2b14..cf2da8ff54a1d3 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6703,6 +6703,22 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
}
};
+/// vector.broadcast has two distinct semantic modes: duplication across leading
+/// dimensions, and stretching of unit source dimensions. Returns the product of
+/// the stretched destination sizes, accumulated in int64_t to avoid truncation.
+static int64_t getBroadcastStretchingFactor(ArrayRef<int64_t> srcShape,
+                                            ArrayRef<int64_t> dstShape) {
+  assert(dstShape.size() >= srcShape.size() && "dst rank must be >= src rank");
+  int64_t stretchingFactor = 1;
+  size_t numLeadingDims = dstShape.size() - srcShape.size();
+  for (size_t i = 0, e = srcShape.size(); i < e; ++i) {
+    int64_t dstDim = dstShape[numLeadingDims + i];
+    if (srcShape[i] == 1 && dstDim != 1)
+      stretchingFactor *= dstDim;
+  }
+  return stretchingFactor;
+}
+
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
@@ -6725,13 +6741,33 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
// to
// %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
VectorType dstVectorType = shapeCastOp.getResultVectorType();
- if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
- BroadcastableToResult::Success) {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- shapeCastOp, dstVectorType, broadcastOp.getSource());
- return success();
+ ArrayRef<int64_t> dstShape = dstVectorType.getShape();
+ ArrayRef<int64_t> srcShape =
+ srcIsScalar ? ArrayRef<int64_t>{} : srcVectorType.getShape();
+ ArrayRef<int64_t> broadcastShape =
+ broadcastOp.getResultVectorType().getShape();
+
+ if (!srcIsScalar) {
+ if (isBroadcastableTo(srcVectorType, dstVectorType) !=
+ BroadcastableToResult::Success) {
+ return failure();
+ }
+ // Avoid folding if this would result in switching between the two
+ // distinct semantic modes of vector.broadcast (duplication vs
+ // stretching). See https://github.com/llvm/llvm-project/issues/190614.
+ // This is detected by a change in the stretching factor. However, if the
+ // source has a single element, there is no ambiguity.
+ if (srcVectorType.getNumElements() != 1) {
+ if (getBroadcastStretchingFactor(srcShape, dstShape) !=
+ getBroadcastStretchingFactor(srcShape, broadcastShape)) {
+ return failure();
+ }
+ }
}
- return failure();
+
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shapeCastOp, dstVectorType,
+ broadcastOp.getSource());
+ return success();
}
};
diff --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
deleted file mode 100644
index b5c6e610f58f92..00000000000000
--- a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
+++ /dev/null
@@ -1,69 +0,0 @@
-// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
-
-func.func @mmt4d_to_fma(%A: tensor<16x16x8x1xf32>, %B: tensor<16x16x8x1xf32>, %C_in: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
- %res = linalg.mmt4d
- ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
- outs(%C_in: tensor<16x16x8x8xf32>)
- -> tensor<16x16x8x8xf32>
- return %res : tensor<16x16x8x8xf32>
-}
-
-
-// CHECK-LABEL: @mmt4d_to_fma
-// CHECK-COUNT-8: vector.fma
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
- %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
-
- %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %func : (!transform.op<"func.func">) -> !transform.any_op
-
- // Step 1: Tile
- // Tile parallel dims
- %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d tile_sizes [1, 1, 0, 8, 8, 0]
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
- // Tile reduction dims
- %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p tile_sizes [0, 0, 1, 0, 0, 1]
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
-
- // Step 2: Vectorize
- transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op
-
- // Step 3: Simplify
- // vector.multi_reduction --> vector.contract
- // Generates a 6-dim vector.contract with the dim matching the original MMT4D Op
- // and with the following split into parallel and reduction dims:
- // * parallel, parallel, reduction, parallel, parallel, reduction
- transform.apply_patterns to %func {
- transform.apply_patterns.vector.reduction_to_contract
- // Reduce the rank of xfer ops. This transforms vector.contract to be
- // more matmul-like and to enable the lowering to outer product Ops.
- transform.apply_patterns.vector.transfer_permutation_patterns
- } : !transform.op<"func.func">
-
- // Hoisting and LICM - not strictly required
- %func_h = transform.structured.hoist_redundant_vector_transfers %func
- : (!transform.op<"func.func">) -> !transform.op<"func.func">
- %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
- : (!transform.op<"func.func">) -> !transform.any_op
- transform.apply_licm to %all_loops : !transform.any_op
- transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
-
- // Simplify the 6-dim vector.contract into a 3-dim matmul-like
- // vector.contract with the following split into parallel and reduction
- // dims:
- // * parallel, parallel, reduction
- transform.apply_patterns to %func_h {
- transform.apply_patterns.vector.reduction_to_contract
- transform.apply_patterns.vector.cast_away_vector_leading_one_dim
- transform.apply_patterns.canonicalization
- } : !transform.op<"func.func">
-
- // Step 4: Lower vector.contract to vector.fma via vector.outerproduct
- transform.apply_patterns to %func_h {
- transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
- transform.apply_patterns.vector.lower_outerproduct
- } : !transform.op<"func.func">
- transform.yield
- }
-}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
index 1d082179207dc9..7ae190d0b3c56c 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
@@ -128,3 +128,19 @@ func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
return %1: vector<1xf32>
}
+// -----
+
+// Regression test for https://github.com/llvm/llvm-project/issues/190614.
+// ShapeCastBroadcastFolder must not fold shape_cast(broadcast) into a single
+// broadcast when doing so turns duplication along a new leading dimension into
+// stretching of a unit source dimension. For %arg0 = [[1], [2]], broadcasting
+// to 2x2x1 and shape_casting to 2x2 yields [[1, 2], [1, 2]], whereas a direct
+// broadcast from 2x1 to 2x2 would incorrectly yield [[1, 1], [2, 2]].
+// CHECK-LABEL: @no_fold_bcast_mode_switch
+// CHECK: vector.broadcast %{{.*}} : vector<2x1xf32> to vector<2x2x1xf32>
+// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<2x2x1xf32> to vector<2x2xf32>
+func.func @no_fold_bcast_mode_switch(%arg0: vector<2x1xf32>) -> vector<2x2xf32> {
+ %0 = vector.broadcast %arg0 : vector<2x1xf32> to vector<2x2x1xf32>
+ %1 = vector.shape_cast %0 : vector<2x2x1xf32> to vector<2x2xf32>
+ return %1 : vector<2x2xf32>
+}
More information about the Mlir-commits
mailing list