[Mlir-commits] [mlir] 824fc35 - [mlir][vector] Constrain broadcast->shape_cast folding (#190230)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 7 11:26:06 PDT 2026
Author: Benoit Jacob
Date: 2026-04-07T14:26:01-04:00
New Revision: 824fc35bd9453e1776f098cc5d7123b1034ea510
URL: https://github.com/llvm/llvm-project/commit/824fc35bd9453e1776f098cc5d7123b1034ea510
DIFF: https://github.com/llvm/llvm-project/commit/824fc35bd9453e1776f098cc5d7123b1034ea510.diff
LOG: [mlir][vector] Constrain broadcast->shape_cast folding (#190230)
Fixes https://github.com/llvm/llvm-project/issues/190614.
Do not fold broadcast->shape_cast when that would result in switching
between the two distinct semantic modes of `vector.broadcast`, as
explained in https://github.com/llvm/llvm-project/issues/190614.
This fixes incorrect-result bugs in IREE:
https://github.com/iree-org/iree/issues/23952
---------
Signed-off-by: Benoit Jacob <benoit.jacob at amd.com>
Added:
mlir/test/Dialect/Vector/vector-multi-reduction-to-fma.mlir
Modified:
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
Removed:
mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
################################################################################
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 939816262a2b1..cf2da8ff54a1d 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6703,6 +6703,22 @@ class ShapeCastCreateMaskFolderTrailingOneDim final
}
};
+// vector.broadcast has two distinct semantic modes: duplication across leading
+// dimensions, and stretching of unit source dims. Returns the product of the
+// dstShape dims that stretch a unit dim of srcShape (shapes aligned on their
+// trailing dims). Uses int64_t throughout so large products are not truncated.
+int64_t getBroadcastStretchingFactor(ArrayRef<int64_t> srcShape,
+                                     ArrayRef<int64_t> dstShape) {
+  int64_t stretchingFactor = 1;
+  int64_t numLeadingDims = dstShape.size() - srcShape.size();
+  for (int64_t i = 0, e = srcShape.size(); i < e; i++) {
+    int64_t dstDim = dstShape[numLeadingDims + i];
+    if (srcShape[i] == 1 && dstDim != 1)
+      stretchingFactor *= dstDim;
+  }
+  return stretchingFactor;
+}
+
/// Pattern to rewrite Y = ShapeCast(Broadcast(X)) as Y = Broadcast(X)
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
@@ -6725,13 +6741,33 @@ class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
// to
// %1 = vector.broadcast %in : vector<3xf32> to vector<8x3xf32>
VectorType dstVectorType = shapeCastOp.getResultVectorType();
- if (srcIsScalar || isBroadcastableTo(srcVectorType, dstVectorType) ==
- BroadcastableToResult::Success) {
- rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
- shapeCastOp, dstVectorType, broadcastOp.getSource());
- return success();
+ ArrayRef<int64_t> dstShape = dstVectorType.getShape();
+ ArrayRef<int64_t> srcShape =
+ srcIsScalar ? ArrayRef<int64_t>{} : srcVectorType.getShape();
+ ArrayRef<int64_t> broadcastShape =
+ broadcastOp.getResultVectorType().getShape();
+
+ if (!srcIsScalar) {
+ if (isBroadcastableTo(srcVectorType, dstVectorType) !=
+ BroadcastableToResult::Success) {
+ return failure();
+ }
+ // Avoid folding if this would result in switching between the two
+ // distinct semantic modes of vector.broadcast (duplication vs
+ // stretching). See https://github.com/llvm/llvm-project/issues/190614.
+ // This is detected by a change in the stretching factor. However if the
+ // source has a single element, there is no ambiguity.
+ if (srcVectorType.getNumElements() != 1) {
+ if (getBroadcastStretchingFactor(srcShape, dstShape) !=
+ getBroadcastStretchingFactor(srcShape, broadcastShape)) {
+ return failure();
+ }
+ }
}
- return failure();
+
+ rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shapeCastOp, dstVectorType,
+ broadcastOp.getSource());
+ return success();
}
};
diff --git a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir b/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
deleted file mode 100644
index b5c6e610f58f9..0000000000000
--- a/mlir/test/Dialect/Linalg/transform-op-mmt4d-to-fma.mlir
+++ /dev/null
@@ -1,69 +0,0 @@
-// RUN: mlir-opt %s -transform-interpreter | FileCheck %s
-
-func.func @mmt4d_to_fma(%A: tensor<16x16x8x1xf32>, %B: tensor<16x16x8x1xf32>, %C_in: tensor<16x16x8x8xf32>) -> tensor<16x16x8x8xf32> {
- %res = linalg.mmt4d
- ins(%A, %B: tensor<16x16x8x1xf32>, tensor<16x16x8x1xf32>)
- outs(%C_in: tensor<16x16x8x8xf32>)
- -> tensor<16x16x8x8xf32>
- return %res : tensor<16x16x8x8xf32>
-}
-
-
-// CHECK-LABEL: @mmt4d_to_fma
-// CHECK-COUNT-8: vector.fma
-
-module attributes {transform.with_named_sequence} {
- transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
- %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
-
- %mmt4d = transform.structured.match ops{["linalg.mmt4d"]} in %func : (!transform.op<"func.func">) -> !transform.any_op
-
- // Step 1: Tile
- // Tile parallel dims
- %tiled_linalg_op_p, %loops:4 = transform.structured.tile_using_for %mmt4d tile_sizes [1, 1, 0, 8, 8, 0]
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
- // Tile reduction dims
- %tiled_linalg_op_r, %loops2:2 = transform.structured.tile_using_for %tiled_linalg_op_p tile_sizes [0, 0, 1, 0, 0, 1]
- : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
-
- // Step 2: Vectorize
- transform.structured.vectorize %tiled_linalg_op_r : !transform.any_op
-
- // Step 3: Simplify
- // vector.multi_reduction --> vector.contract
- // Generates a 6-dim vector.contract with the dim matching the original MMT4D Op
- // and with the following split into parallel and reduction dims:
- // * parallel, parallel, reduction, parallel, parallel, reduction
- transform.apply_patterns to %func {
- transform.apply_patterns.vector.reduction_to_contract
- // Reduce the rank of xfer ops. This transforms vector.contract to be
- // more matmul-like and to enable the lowering to outer product Ops.
- transform.apply_patterns.vector.transfer_permutation_patterns
- } : !transform.op<"func.func">
-
- // Hoisting and LICM - not strictly required
- %func_h = transform.structured.hoist_redundant_vector_transfers %func
- : (!transform.op<"func.func">) -> !transform.op<"func.func">
- %all_loops = transform.structured.match interface{LoopLikeInterface} in %func_h
- : (!transform.op<"func.func">) -> !transform.any_op
- transform.apply_licm to %all_loops : !transform.any_op
- transform.loop.hoist_loop_invariant_subsets %all_loops : !transform.any_op
-
- // Simplify the 6-dim vector.contract into a 3-dim matmul-like
- // vector.contract with the following split into parallel and reduction
- // dims:
- // * parallel, parallel, reduction
- transform.apply_patterns to %func_h {
- transform.apply_patterns.vector.reduction_to_contract
- transform.apply_patterns.vector.cast_away_vector_leading_one_dim
- transform.apply_patterns.canonicalization
- } : !transform.op<"func.func">
-
- // Step 4: Lower vector.contract to vector.fma via vector.outerproduct
- transform.apply_patterns to %func_h {
- transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
- transform.apply_patterns.vector.lower_outerproduct
- } : !transform.op<"func.func">
- transform.yield
- }
-}
diff --git a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
index 1d082179207dc..7ae190d0b3c56 100644
--- a/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize/vector-to-shape-cast.mlir
@@ -128,3 +128,19 @@ func.func @extract_from_broadcast(%src: vector<1x1x1xf32>) -> vector<1xf32> {
return %1: vector<1xf32>
}
+// -----
+
+// https://github.com/llvm/llvm-project/issues/190614 — ShapeCastBroadcastFolder
+// must not rewrite shape_cast(broadcast) into a lower-rank broadcast when that
+// changes duplication (new leading dims) vs stretching (source unit dim).
+// For %arg0 = [[1], [2]]: broadcast to 2x2x1 then shape_cast to 2x2 yields
+// [[1, 2], [1, 2]]; folding to broadcast 2x1 -> 2x2 would incorrectly yield
+// [[1, 1], [2, 2]].
+// CHECK-LABEL: @no_fold_bcast_mode_switch
+// CHECK: vector.broadcast %{{.*}} : vector<2x1xf32> to vector<2x2x1xf32>
+// CHECK-NEXT: vector.shape_cast %{{.*}} : vector<2x2x1xf32> to vector<2x2xf32>
+func.func @no_fold_bcast_mode_switch(%arg0: vector<2x1xf32>) -> vector<2x2xf32> {
+ %0 = vector.broadcast %arg0 : vector<2x1xf32> to vector<2x2x1xf32>
+ %1 = vector.shape_cast %0 : vector<2x2x1xf32> to vector<2x2xf32>
+ return %1 : vector<2x2xf32>
+}
diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-to-fma.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-to-fma.mlir
new file mode 100644
index 0000000000000..75ac060bcc666
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-multi-reduction-to-fma.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// `arith.mulf` + `vector.multi_reduction` matches what `linalg.matmul` produces
+// after `transform.structured.vectorize` (with matmul-style transfer_read layout).
+// This checks the follow-on stack: multi_reduction → contract → (transfer layout
+// cleanup) → outer-product-style contraction lowering → `vector.fma`.
+
+#map_lhs = affine_map<(d0, d1) -> (d0, 0, d1)>
+#map_rhs = affine_map<(d0, d1) -> (0, d1, d0)>
+
+// CHECK-LABEL: func @multi_reduction_to_fma
+// CHECK-SAME: memref<3x4xf32>
+// CHECK-SAME: memref<4x3xf32>
+// CHECK-SAME: memref<3x3xf32>
+// CHECK-DAG: vector.transfer_read {{.*}} : memref<3x4xf32>, vector<3x4xf32>
+// CHECK-DAG: vector.transfer_read {{.*}} : memref<4x3xf32>, vector<4x3xf32>
+// CHECK-DAG: vector.transfer_read {{.*}} : memref<3x3xf32>, vector<3x3xf32>
+// Each of the 4 K steps lowers to one outer product, which becomes one vector.fma per output row: 3 output rows × 4 K steps = 12.
+// CHECK-COUNT-12: vector.fma
+// CHECK-NOT: vector.multi_reduction
+// CHECK-NOT: vector.contract
+// CHECK-NOT: vector.outerproduct
+// CHECK: vector.transfer_write {{.*}} : vector<3x3xf32>, memref<3x3xf32>
+func.func @multi_reduction_to_fma(%A: memref<3x4xf32>, %B: memref<4x3xf32>, %C: memref<3x3xf32>) {
+ %c0 = arith.constant 0 : index
+ %p = ub.poison : f32
+ %va = vector.transfer_read %A[%c0, %c0], %p {permutation_map = #map_lhs} : memref<3x4xf32>, vector<3x3x4xf32>
+ %vb = vector.transfer_read %B[%c0, %c0], %p {permutation_map = #map_rhs} : memref<4x3xf32>, vector<3x3x4xf32>
+ %vc = vector.transfer_read %C[%c0, %c0], %p : memref<3x3xf32>, vector<3x3xf32>
+ %mul = arith.mulf %va, %vb : vector<3x3x4xf32>
+ %acc = vector.multi_reduction <add>, %mul, %vc [2] : vector<3x3x4xf32> to vector<3x3xf32>
+ vector.transfer_write %acc, %C[%c0, %c0] : vector<3x3xf32>, memref<3x3xf32>
+ return
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+ %f = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.any_op
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.reduction_to_contract
+ transform.apply_patterns.vector.transfer_permutation_patterns
+ } : !transform.any_op
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct"
+ } : !transform.any_op
+ transform.apply_patterns to %f {
+ transform.apply_patterns.vector.lower_outerproduct
+ } : !transform.any_op
+ transform.yield
+ }
+}
More information about the Mlir-commits
mailing list